Skip to content

Commit

Permalink
Fix non-empty arg for protobuf endpoints.
Browse files Browse the repository at this point in the history
  • Loading branch information
q-uint committed May 30, 2024
1 parent 1e0635b commit 1eeaf4b
Show file tree
Hide file tree
Showing 7 changed files with 75 additions and 30 deletions.
13 changes: 13 additions & 0 deletions .github/workflows/test-registry.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
on:
push:
paths:
- registry/**
jobs:
test:
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v4
- uses: actions/setup-go@v4
with:
go-version: '1.22.1'
- run: make test-registry
2 changes: 1 addition & 1 deletion .github/workflows/test.yml
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ jobs:
- uses: actions/checkout@v4
- uses: actions/setup-go@v4
with:
go-version: '1.21.0'
go-version: '1.22.1'
- uses: aviate-labs/setup-dfx@v0.3.2
with:
dfx-version: 0.18.0
Expand Down
28 changes: 16 additions & 12 deletions registry/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ type Client struct {
func New() (*Client, error) {
dp, err := NewDataProvider()
if err != nil {
return nil, err
return nil, fmt.Errorf("failed to create data provider: %w", err)
}
return &Client{
dp: dp,
Expand All @@ -26,15 +26,19 @@ func New() (*Client, error) {
func (c *Client) GetNNSSubnetID() (*principal.Principal, error) {
v, _, err := c.dp.GetValueUpdate([]byte("nns_subnet_id"), nil)
if err != nil {
return nil, err
return nil, fmt.Errorf("failed to get NNS subnet ID: %w", err)
}
var nnsSubnetID v1.SubnetId
if err := proto.Unmarshal(v, &nnsSubnetID); err != nil {
return nil, err
return nil, fmt.Errorf("failed to unmarshal NNS subnet ID: %w", err)
}
return &principal.Principal{Raw: nnsSubnetID.PrincipalId.Raw}, nil
}

func (c *Client) GetLatestVersion() (uint64, error) {
return c.dp.GetLatestVersion()
}

func (c *Client) GetNodeListSince(version uint64) (NodeMap, error) {
nnsSubnetID, err := c.GetNNSSubnetID()
if err != nil {
Expand All @@ -56,7 +60,7 @@ func (c *Client) GetNodeListSince(version uint64) (NodeMap, error) {
for {
records, _, err := c.dp.GetCertifiedChangesSince(currentVersion, nnsPublicKey)
if err != nil {
return nil, err
return nil, fmt.Errorf("failed to get certified changes: %w", err)
}
currentVersion = records[len(records)-1].Version
for _, record := range records {
Expand All @@ -66,7 +70,7 @@ func (c *Client) GetNodeListSince(version uint64) (NodeMap, error) {
} else {
var nodeRecord v1.NodeRecord
if err := proto.Unmarshal(record.Value, &nodeRecord); err != nil {
return nil, err
return nil, fmt.Errorf("failed to unmarshal node record: %w", err)
}
nodeMap[strings.TrimPrefix(record.Key, "node_record_")] = &nodeRecord
}
Expand All @@ -76,7 +80,7 @@ func (c *Client) GetNodeListSince(version uint64) (NodeMap, error) {
} else {
var nodeOperatorRecord v1.NodeOperatorRecord
if err := proto.Unmarshal(record.Value, &nodeOperatorRecord); err != nil {
return nil, err
return nil, fmt.Errorf("failed to unmarshal node operator record: %w", err)
}
nodeOperatorMap[strings.TrimPrefix(record.Key, "node_operator_record_")] = &nodeOperatorRecord
}
Expand Down Expand Up @@ -124,23 +128,23 @@ func (c *Client) GetNodeListSince(version uint64) (NodeMap, error) {
func (c *Client) GetSubnetDetails(subnetID principal.Principal) (*v1.SubnetRecord, error) {
v, _, err := c.dp.GetValueUpdate([]byte(fmt.Sprintf("subnet_record_%s", subnetID)), nil)
if err != nil {
return nil, err
return nil, fmt.Errorf("failed to get subnet details: %w", err)
}
var record v1.SubnetRecord
if err := proto.Unmarshal(v, &record); err != nil {
return nil, err
return nil, fmt.Errorf("failed to unmarshal subnet details: %w", err)
}
return &record, nil
}

func (c *Client) GetSubnetIDs() ([]principal.Principal, error) {
v, _, err := c.dp.GetValueUpdate([]byte("subnet_list"), nil)
if err != nil {
return nil, err
return nil, fmt.Errorf("failed to get subnet IDs: %w", err)
}
var list v1.SubnetListRecord
if err := proto.Unmarshal(v, &list); err != nil {
return nil, err
return nil, fmt.Errorf("failed to unmarshal subnet IDs: %w", err)
}
var subnets []principal.Principal
for _, subnet := range list.Subnets {
Expand All @@ -152,11 +156,11 @@ func (c *Client) GetSubnetIDs() ([]principal.Principal, error) {
func (c *Client) GetSubnetPublicKey(subnetID principal.Principal) ([]byte, error) {
v, _, err := c.dp.GetValueUpdate([]byte(fmt.Sprintf("crypto_threshold_signing_public_key_%s", subnetID)), nil)
if err != nil {
return nil, err
return nil, fmt.Errorf("failed to get subnet public key: %w", err)
}
var publicKey v1.PublicKey
if err := proto.Unmarshal(v, &publicKey); err != nil {
return nil, err
return nil, fmt.Errorf("failed to unmarshal subnet public key: %w", err)
}
if publicKey.Algorithm != v1.AlgorithmId_ALGORITHM_ID_THRES_BLS12_381 {
return nil, fmt.Errorf("unsupported public key algorithm")
Expand Down
14 changes: 11 additions & 3 deletions registry/client_test.go
Original file line number Diff line number Diff line change
@@ -1,17 +1,25 @@
package registry
package registry_test

import (
"github.com/aviate-labs/agent-go/registry"
"os"
"testing"
)

func TestClient_GetNodeListSince(t *testing.T) {
checkEnabled(t)
c, err := New()

c, err := registry.New()
if err != nil {
t.Fatal(err)
}

latestVersion, err := c.GetLatestVersion()
if err != nil {
t.Fatal(err)
}
if _, err := c.GetNodeListSince(0); err != nil {

if _, err := c.GetNodeListSince(latestVersion - 100); err != nil {
t.Fatal(err)
}
}
Expand Down
26 changes: 13 additions & 13 deletions registry/dataprovider.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ type DataProvider struct {
func NewDataProvider() (*DataProvider, error) {
a, err := agent.New(agent.DefaultConfig)
if err != nil {
return nil, err
return nil, fmt.Errorf("failed to create agent: %w", err)
}
return &DataProvider{a: a}, nil
}
Expand All @@ -37,36 +37,36 @@ func (d DataProvider) GetCertifiedChangesSince(version uint64, publicKey []byte)
},
&resp,
); err != nil {
return nil, 0, err
return nil, 0, fmt.Errorf("failed to get certified changes: %w", err)
}
ht, err := NewHashTree(resp.HashTree)
if err != nil {
return nil, 0, err
return nil, 0, fmt.Errorf("failed to create hash tree: %w", err)
}
rawCurrentVersion, err := ht.Lookup(hashtree.Label("current_version"))
if err != nil {
return nil, 0, err
return nil, 0, fmt.Errorf("failed to lookup current version: %w", err)
}
currentVersion, err := leb128.DecodeUnsigned(bytes.NewReader(rawCurrentVersion))
if err != nil {
return nil, 0, err
return nil, 0, fmt.Errorf("failed to decode current version: %w", err)
}

deltaNodes, err := ht.LookupSubTree(hashtree.Label("delta"))
if err != nil {
return nil, 0, err
return nil, 0, fmt.Errorf("failed to lookup delta nodes: %w", err)
}
rawDeltas, err := hashtree.AllChildren(deltaNodes)
if err != nil {
return nil, 0, err
return nil, 0, fmt.Errorf("failed to get all children: %w", err)
}

var deltas []VersionedRecord
lastVersion := version
for _, delta := range rawDeltas {
req := new(v1.RegistryAtomicMutateRequest)
if err := proto.Unmarshal(delta.Value, req); err != nil {
return nil, 0, err
return nil, 0, fmt.Errorf("failed to unmarshal atomic mutate request: %w", err)
}

v := binary.BigEndian.Uint64(delta.Path[0])
Expand Down Expand Up @@ -99,7 +99,7 @@ func (d DataProvider) GetCertifiedChangesSince(version uint64, publicKey []byte)
publicKey,
digest[:],
); err != nil {
return nil, 0, err
return nil, 0, fmt.Errorf("failed to verify certified data: %w", err)
}

return deltas, currentVersion.Uint64(), nil
Expand All @@ -116,7 +116,7 @@ func (d DataProvider) GetChangesSince(version uint64) ([]*v1.RegistryDelta, uint
},
&resp,
); err != nil {
return nil, 0, err
return nil, 0, fmt.Errorf("failed to get changes since: %w", err)
}
if resp.Error != nil {
return nil, 0, fmt.Errorf("error: %s", resp.Error.String())
Expand All @@ -132,7 +132,7 @@ func (d DataProvider) GetLatestVersion() (uint64, error) {
nil,
&resp,
); err != nil {
return 0, err
return 0, fmt.Errorf("failed to get latest version: %w", err)
}
return resp.Version, nil
}
Expand All @@ -154,7 +154,7 @@ func (d DataProvider) GetValue(key []byte, version *uint64) ([]byte, uint64, err
},
&resp,
); err != nil {
return nil, 0, err
return nil, 0, fmt.Errorf("failed to get value: %w", err)
}
if resp.Error != nil {
return nil, 0, fmt.Errorf("error: %s", resp.Error.String())
Expand All @@ -178,7 +178,7 @@ func (d DataProvider) GetValueUpdate(key []byte, version *uint64) ([]byte, uint6
},
&resp,
); err != nil {
return nil, 0, err
return nil, 0, fmt.Errorf("failed to get value: %w", err)
}
if resp.Error != nil {
return nil, 0, fmt.Errorf("error: %s", resp.Error.String())
Expand Down
18 changes: 18 additions & 0 deletions registry/dataprovider_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
package registry_test

import (
"github.com/aviate-labs/agent-go/registry"
"testing"
)

func TestDataProvider_GetLatestVersion(t *testing.T) {
checkEnabled(t)

dp, err := registry.NewDataProvider()
if err != nil {
t.Fatal(err)
}
if _, err := dp.GetLatestVersion(); err != nil {
t.Error(err)
}
}
4 changes: 3 additions & 1 deletion request.go
Original file line number Diff line number Diff line change
Expand Up @@ -79,7 +79,9 @@ func (r *Request) MarshalCBOR() ([]byte, error) {
if len(r.MethodName) != 0 {
m["method_name"] = r.MethodName
}
if len(r.Arguments) != 0 {
if r.Arguments != nil {
// Some endpoints require the argument to be an empty array, not null.
// This is the case with the protobuf endpoints on the registry.
m["arg"] = r.Arguments
}
if len(r.Sender.Raw) != 0 {
Expand Down

0 comments on commit 1eeaf4b

Please sign in to comment.