From 6602f735ab3d6089bd0e0fc5d3be25408b44631d Mon Sep 17 00:00:00 2001 From: Quint Daenen Date: Wed, 15 May 2024 13:15:42 +0200 Subject: [PATCH] Add marshal interface to principal/account ids. --- principal/accountid.go | 20 ++++++++++++++++++++ principal/accountid_test.go | 16 ++++++++++++++++ principal/principal.go | 20 ++++++++++++++++++++ principal/principal_test.go | 17 +++++++++++++++++ 4 files changed, 73 insertions(+) diff --git a/principal/accountid.go b/principal/accountid.go index e618a85..0d0a0cd 100644 --- a/principal/accountid.go +++ b/principal/accountid.go @@ -4,6 +4,7 @@ import ( "crypto/sha256" "encoding/binary" "encoding/hex" + "encoding/json" "fmt" "hash/crc32" ) @@ -63,7 +64,26 @@ func (id AccountIdentifier) Encode() string { return hex.EncodeToString(id.Bytes()) } +// MarshalJSON encodes the account identifier into JSON bytes as a string. +func (id AccountIdentifier) MarshalJSON() ([]byte, error) { + return json.Marshal(id.String()) +} + // String returns the hexadecimal representation of the account identifier. func (id AccountIdentifier) String() string { return id.Encode() } + +// UnmarshalJSON decodes the given JSON bytes into an account identifier from a string. +func (id *AccountIdentifier) UnmarshalJSON(bytes []byte) error { + var accountID string + if err := json.Unmarshal(bytes, &accountID); err != nil { + return err + } + decoded, err := DecodeAccountID(accountID) + if err != nil { + return err + } + *id = decoded + return nil +} diff --git a/principal/accountid_test.go b/principal/accountid_test.go index 68d5fb5..c3e1792 100644 --- a/principal/accountid_test.go +++ b/principal/accountid_test.go @@ -1,6 +1,7 @@ package principal_test import ( + "encoding/json" "fmt" "github.com/aviate-labs/agent-go/principal" "testing" @@ -24,3 +25,18 @@ func TestAccountIdentifier(t *testing.T) { } } } + +func TestAccountIdentifier_MarshalJSON(t *testing.T) { + original := principal.NewAccountID(principal.AnonymousID, principal.DefaultSubAccount) + raw, err := json.Marshal(original) + if err != nil { + t.Error(err) + } + var decoded principal.AccountIdentifier + if err := json.Unmarshal(raw, &decoded); err != nil { + t.Error(err) + } + if original != decoded { + t.Errorf("expected %v, got %v", original, decoded) + } +} diff --git a/principal/principal.go b/principal/principal.go index 964b99b..4ebde2f 100644 --- a/principal/principal.go +++ b/principal/principal.go @@ -6,6 +6,7 @@ import ( "crypto/sha256" "encoding/base32" "encoding/binary" + "encoding/json" "fmt" "github.com/fxamacker/cbor/v2" "hash/crc32" @@ -130,6 +131,11 @@ func (p Principal) MarshalCBOR() ([]byte, error) { return cbor.Marshal(p.Raw) } +// MarshalJSON converts the principal to its JSON representation as a string. +func (p Principal) MarshalJSON() ([]byte, error) { + return json.Marshal(p.String()) +} + // String implements the Stringer interface. func (p Principal) String() string { return p.Encode() @@ -139,3 +145,17 @@ func (p Principal) String() string { func (p *Principal) UnmarshalCBOR(bytes []byte) error { return cbor.Unmarshal(bytes, &p.Raw) } + +// UnmarshalJSON converts the JSON bytes into a principal from a string. +func (p *Principal) UnmarshalJSON(bytes []byte) error { + var principal string + if err := json.Unmarshal(bytes, &principal); err != nil { + return err + } + decoded, err := Decode(principal) + if err != nil { + return err + } + *p = decoded + return nil +} diff --git a/principal/principal_test.go b/principal/principal_test.go index a74566d..0719309 100644 --- a/principal/principal_test.go +++ b/principal/principal_test.go @@ -1,7 +1,9 @@ package principal_test import ( + "bytes" "encoding/hex" + "encoding/json" "fmt" "github.com/aviate-labs/agent-go/ic" "testing" @@ -38,3 +40,18 @@ func TestPrincipal(t *testing.T) { t.Fatal("expected reserved principal") } } + +func TestPrincipal_MarshalJSON(t *testing.T) { + original := principal.MustDecode("em77e-bvlzu-aq") + raw, err := original.MarshalJSON() + if err != nil { + t.Error(err) + } + var decoded principal.Principal + if err := json.Unmarshal(raw, &decoded); err != nil { + t.Error(err) + } + if !bytes.Equal(original.Raw, decoded.Raw) { + t.Errorf("expected %v, got %v", original, decoded) + } +}