diff --git a/cluster/cluster_test.go b/cluster/cluster_test.go index e4996deb2..8eb2c1ccf 100644 --- a/cluster/cluster_test.go +++ b/cluster/cluster_test.go @@ -83,7 +83,6 @@ func TestEncode(t *testing.T) { require.NoError(t, err) require.Equal(t, b1, b2) - require.Equal(t, definition, definition2) lock := cluster.Lock{ Definition: definition, @@ -127,7 +126,6 @@ func TestEncode(t *testing.T) { require.NoError(t, err) require.Equal(t, b1, b2) - require.Equal(t, lock, lock2) } } diff --git a/cluster/definition.go b/cluster/definition.go index 824151c2d..a72b8e1cc 100644 --- a/cluster/definition.go +++ b/cluster/definition.go @@ -91,6 +91,12 @@ type Definition struct { // Operators define the charon nodes in the cluster and their operators. Operators []Operator + + // unmarshalledConfigHash defines config hash from config_hash field from JSON string after json.Unmarshal. + unmarshalledConfigHash []byte + + // unmarshalledDefinitionHash defines definition hash from definition_hash field from JSON string after json.Unmarshal. + unmarshalledDefinitionHash []byte } // NodeIdx returns the node index for the peer. @@ -169,6 +175,30 @@ func (d Definition) Verify() error { return nil } +func (d Definition) VerifyHashes() error { + // Verify config_hash + confHash, err := d.ConfigHash() + if err != nil { + return errors.Wrap(err, "config hash") + } + + if !bytes.Equal(d.unmarshalledConfigHash, confHash[:]) { + return errors.New("invalid config hash") + } + + // Verify definition_hash + defHash, err := d.HashTreeRoot() + if err != nil { + return errors.Wrap(err, "definition hash") + } + + if !bytes.Equal(d.unmarshalledDefinitionHash, defHash[:]) { + return errors.New("invalid definition hash") + } + + return nil +} + // ConfigHash returns the config hash of the definition object. func (d Definition) ConfigHash() ([32]byte, error) { return configHash(d) @@ -316,19 +346,17 @@ func (d *Definition) UnmarshalJSON(data []byte) error { } var ( - def Definition - configHashJSON []byte - defHashJSON []byte - err error + def Definition + err error ) switch { case isJSONv1x1(version.Version): - def, configHashJSON, defHashJSON, err = unmarshalDefinitionV1x1(data) + def, err = unmarshalDefinitionV1x1(data) if err != nil { return err } case isJSONv1x2(version.Version): - def, configHashJSON, defHashJSON, err = unmarshalDefinitionV1x2(data) + def, err = unmarshalDefinitionV1x2(data) if err != nil { return err } @@ -336,26 +364,6 @@ func (d *Definition) UnmarshalJSON(data []byte) error { return errors.New("unsupported version") } - // Verify config_hash - configHash, err := def.ConfigHash() - if err != nil { - return errors.Wrap(err, "config hash") - } - - if !bytes.Equal(configHashJSON, configHash[:]) { - return errors.New("invalid config hash") - } - - // Verify definition_hash - defHash, err := def.HashTreeRoot() - if err != nil { - return errors.Wrap(err, "definition hash") - } - - if !bytes.Equal(defHashJSON, defHash[:]) { - return errors.New("invalid definition hash") - } - *d = def return nil @@ -407,55 +415,55 @@ func marshalDefinitionV1x2(def Definition, configHash, defHash [32]byte) ([]byte return resp, nil } -func unmarshalDefinitionV1x1(data []byte) (def Definition, configHashJSON, defHashJSON []byte, err error) { +func unmarshalDefinitionV1x1(data []byte) (Definition, error) { var defJSON definitionJSONv1x1 if err := json.Unmarshal(data, &defJSON); err != nil { - return Definition{}, nil, nil, errors.Wrap(err, "unmarshal definition v1_1") + return Definition{}, errors.Wrap(err, "unmarshal definition v1_1") } operators, err := operatorsFromV1x1(defJSON.Operators) if err != nil { - return Definition{}, nil, nil, err + return Definition{}, err } - def = Definition{ - Name: defJSON.Name, - UUID: defJSON.UUID, - Version: defJSON.Version, - Timestamp: defJSON.Timestamp, - NumValidators: defJSON.NumValidators, - Threshold: defJSON.Threshold, - FeeRecipientAddress: defJSON.FeeRecipientAddress, - WithdrawalAddress: defJSON.WithdrawalAddress, - DKGAlgorithm: defJSON.DKGAlgorithm, - ForkVersion: defJSON.ForkVersion, - Operators: operators, - } - - return def, defJSON.ConfigHash, defJSON.DefinitionHash, nil + return Definition{ + Name: defJSON.Name, + UUID: defJSON.UUID, + Version: defJSON.Version, + Timestamp: defJSON.Timestamp, + NumValidators: defJSON.NumValidators, + Threshold: defJSON.Threshold, + FeeRecipientAddress: defJSON.FeeRecipientAddress, + WithdrawalAddress: defJSON.WithdrawalAddress, + DKGAlgorithm: defJSON.DKGAlgorithm, + ForkVersion: defJSON.ForkVersion, + Operators: operators, + unmarshalledConfigHash: defJSON.ConfigHash, + unmarshalledDefinitionHash: defJSON.DefinitionHash, + }, nil } -func unmarshalDefinitionV1x2(data []byte) (def Definition, configHashJSON, defHashJSON []byte, err error) { +func unmarshalDefinitionV1x2(data []byte) (Definition, error) { var defJSON definitionJSONv1x2 if err := json.Unmarshal(data, &defJSON); err != nil { - return Definition{}, nil, nil, errors.Wrap(err, "unmarshal definition v1v2") + return Definition{}, errors.Wrap(err, "unmarshal definition v1v2") } - def = Definition{ - Name: defJSON.Name, - UUID: defJSON.UUID, - Version: defJSON.Version, - Timestamp: defJSON.Timestamp, - NumValidators: defJSON.NumValidators, - Threshold: defJSON.Threshold, - FeeRecipientAddress: defJSON.FeeRecipientAddress, - WithdrawalAddress: defJSON.WithdrawalAddress, - DKGAlgorithm: defJSON.DKGAlgorithm, - ForkVersion: defJSON.ForkVersion, - Operators: operatorsFromV1x2(defJSON.Operators), - } - - return def, defJSON.ConfigHash, defJSON.DefinitionHash, nil + return Definition{ + Name: defJSON.Name, + UUID: defJSON.UUID, + Version: defJSON.Version, + Timestamp: defJSON.Timestamp, + NumValidators: defJSON.NumValidators, + Threshold: defJSON.Threshold, + FeeRecipientAddress: defJSON.FeeRecipientAddress, + WithdrawalAddress: defJSON.WithdrawalAddress, + DKGAlgorithm: defJSON.DKGAlgorithm, + ForkVersion: defJSON.ForkVersion, + Operators: operatorsFromV1x2(defJSON.Operators), + unmarshalledConfigHash: defJSON.ConfigHash, + unmarshalledDefinitionHash: defJSON.DefinitionHash, + }, nil } // definitionJSONv1x1 is the json formatter of Definition for versions v1.0.0 and v1.1.1. diff --git a/dkg/disk.go b/dkg/disk.go index bb98408bf..7751afcb9 100644 --- a/dkg/disk.go +++ b/dkg/disk.go @@ -30,6 +30,7 @@ import ( "github.com/coinbase/kryptology/pkg/signatures/bls/bls_sig" "github.com/obolnetwork/charon/app/errors" + "github.com/obolnetwork/charon/app/log" "github.com/obolnetwork/charon/app/z" "github.com/obolnetwork/charon/cluster" "github.com/obolnetwork/charon/core" @@ -59,6 +60,13 @@ func loadDefinition(ctx context.Context, conf Config) (cluster.Definition, error return cluster.Definition{}, errors.Wrap(err, "unmarshal definition") } + // Verify config hash and definition hash from json string and resultant cluster.Definition. + if err := res.VerifyHashes(); err != nil && !conf.NoVerify { + return cluster.Definition{}, errors.Wrap(err, "cluster definition hash verification failed. Run with --no-verify to bypass verification at own risk") + } else if err != nil && conf.NoVerify { + log.Warn(ctx, "Ignoring failed cluster definition hash verification due to --no-verify flag", err) + } + return res, nil } diff --git a/dkg/disk_internal_test.go b/dkg/disk_internal_test.go index f7d4ee1d0..b055b04f7 100644 --- a/dkg/disk_internal_test.go +++ b/dkg/disk_internal_test.go @@ -16,6 +16,7 @@ package dkg import ( + "bytes" "context" "encoding/json" "fmt" @@ -74,7 +75,7 @@ func TestFetchDefinition(t *testing.T) { return } require.NoError(t, err) - require.Equal(t, tt.want, got) + require.True(t, compareDefinitions(t, got, tt.want)) }) } } @@ -91,49 +92,88 @@ func TestLoadDefinition(t *testing.T) { // Invalid definition invalidDef := cluster.Definition{} - invalidFile := "invalid-cluster-definition.json" - err = os.WriteFile(invalidFile, []byte{1, 2, 3}, 0o666) + invalidFile1 := "invalid-cluster-definition.json" + err = os.WriteFile(invalidFile1, []byte{1, 2, 3}, 0o666) require.NoError(t, err) + // Invalid definition without definition_hash and config_hash + invalidFile2 := "invalid-cluster-definition2.json" + defJSON := []byte(`{"name":"test cluster","operators":[{"address":"0x9A4C8145c7457b0BDC84Ba46729c3c9e15b56106","enr":"enr:-Iu4QFTSWOu_OplK-CYUv29EqIoMGQGtuHjTxLohMxOEMSxYFqraJdtWfMiwzS9wiGH-gB32IrYdyXSH-i5nJbLTm4yAgmlkgnY0gmlwhH8AAAGJc2VjcDI1NmsxoQJ9hgLS0tOo-w3eLfoHVnENCJuN5QXDgtCX_cHQo5FpDoN0Y3CC52eDdWRwgudo","config_signature":"0x84b5e48201deebc59ac09c2aaec57870cad357d5a0b65a1954ee301b2760597d479a8ee2b6037963554b5323fb7944a389b9e86948ad89da08d6fffcc4ba5c5a01","enr_signature":"0x1530ab4bf5267f88c76850f9750e328698e0206f70a29bf2a6136cec3ffc365e620bddcd46ecfe667718dd6770b38c7c4c7bdf1e6c8ecaf0bc098d3959d9ae0c00"},{"address":"0x79AB788F445d5A689C34AD6e54e247865DE41E99","enr":"enr:-Iu4QDKAQ7dsqHud5m1T2FsjYcahgYRrzMiCZjjx9sRTOjnWH67n8ZEepVZ4WHp-XNn0c0CtFIB-KSBHeiKe8oDLztiAgmlkgnY0gmlwhH8AAAGJc2VjcDI1NmsxoQPFe0POIvtf3fkTXUjkj7yLDJ2APptRF6CK8_9C8g0NvYN0Y3CC52mDdWRwgudq","config_signature":"0xc8a306ec102a83dce35be9a552d0b30a09e74b1c1b2316596245c0e8f2ffa5d063051373993ecb694d82e2d65a398ecd5ef06599661903470385998d19a702f001","enr_signature":"0x61c44ee84dfc3991fe9de0d25f2bc3cad7feb0d3341208c894f455a233bcb1d2647b4fd29c376ceb2e5d2990fb2af7b6eae10bfe5b55963404ea2571222b350c01"},{"address":"0xfDdd1CF7733Fd8a638020e963792f9Fcc0182Bf4","enr":"enr:-Iu4QNgEtRy6wbpdPCXrj52_rF4Ur7OQf6mOg1xfRpmzPgRYW-QSA-oUslOTmIPL8etUIg95quQoRJg9FIILIa6990uAgmlkgnY0gmlwhH8AAAGJc2VjcDI1NmsxoQLM5474DZHwbqwSbQFLrAO8PNh2AdZOXTYGy1ZItyDJaoN0Y3CC52uDdWRwguds","config_signature":"0xddb059b67e0603c0073536225e3a5ae5a7eeae178ec4c31dda521ed9b0209dc90b7e3295ee5ad4f35fbe0b8d62adf4aa4cfbddd3106f7563d234e0b134d9e93701","enr_signature":"0x3c556e60c4b44a4ea015f860759fa01e2e8664d105455b71e55845f451d7ec5e048ee56999c9f5333024bfcc087471fcb7d8e8a2685cb59d63266b4df092104501"}],"uuid":"04513690-AA41-CE01-6281-7901E9FB6D87","version":"v1.2.0","timestamp":"2022-09-07T18:46:30+05:30","num_validators":1,"threshold":2,"fee_recipient_address":"0xd805a5fcea20d3d27d3eee59d5dd5749e3271617","withdrawal_address":"0x75e896f172869cf3ade31c97f681cc1a4015ceed","dkg_algorithm":"default","fork_version":"0x00000000"}`) + err = os.WriteFile(invalidFile2, defJSON, 0o666) + require.NoError(t, err) + var invalidDef2 cluster.Definition + require.NoError(t, invalidDef2.UnmarshalJSON(defJSON)) + defer func() { require.NoError(t, os.Remove(validFile)) - require.NoError(t, os.Remove(invalidFile)) + require.NoError(t, os.Remove(invalidFile1)) + require.NoError(t, os.Remove(invalidFile2)) }() tests := []struct { - name string - defFile string - want cluster.Definition - wantErr bool + name string + defFile string + want cluster.Definition + noVerify bool + wantErr bool }{ { - name: "Load valid definition", - defFile: validFile, - want: validDef, - wantErr: false, + name: "Load valid definition", + defFile: validFile, + want: validDef, + noVerify: false, + wantErr: false, }, { - name: "Definition file doesn't exist", - defFile: "", - want: invalidDef, - wantErr: true, + name: "Definition file doesn't exist", + defFile: "", + want: invalidDef, + noVerify: false, + wantErr: true, }, { - name: "Load invalid definition", - defFile: invalidFile, - want: invalidDef, - wantErr: true, + name: "Load invalid definition", + defFile: invalidFile1, + want: invalidDef, + noVerify: false, + wantErr: true, + }, + { + name: "Load invalid definition with no verify", + defFile: invalidFile2, + want: invalidDef2, + noVerify: true, + wantErr: false, + }, + { + name: "Load invalid definition without no verify", + defFile: invalidFile2, + want: invalidDef2, + noVerify: false, + wantErr: true, }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - got, err := loadDefinition(context.Background(), Config{DefFile: tt.defFile}) + got, err := loadDefinition(context.Background(), Config{DefFile: tt.defFile, NoVerify: tt.noVerify}) if tt.wantErr { require.Error(t, err) return } require.NoError(t, err) - require.Equal(t, tt.want, got) + require.True(t, compareDefinitions(t, got, tt.want)) }) } } + +func compareDefinitions(t *testing.T, a, b cluster.Definition) bool { + t.Helper() + + b1, err := json.Marshal(a) + require.NoError(t, err) + + b2, err := json.Marshal(b) + require.NoError(t, err) + + return bytes.Equal(b1, b2) +}