Skip to content

Commit

Permalink
Fix variant structs.
Browse files Browse the repository at this point in the history
  • Loading branch information
q-uint committed Dec 6, 2023
1 parent 47612e8 commit 6e17358
Show file tree
Hide file tree
Showing 18 changed files with 875 additions and 179 deletions.
25 changes: 17 additions & 8 deletions candid/idl/convert.go
Original file line number Diff line number Diff line change
Expand Up @@ -95,14 +95,6 @@ func EmptyOf(t Type) (any, error) {
return nil, UnknownTypeError{Type: t}
}

func IsType(v any, t Type) bool {
typ, err := TypeOf(v)
if err != nil {
return false
}
return typ.String() == t.String()
}

func TypeOf(v any) (Type, error) {
switch v := v.(type) {
case Null:
Expand Down Expand Up @@ -183,6 +175,23 @@ func TypeOf(v any) (Type, error) {
if err != nil {
return nil, err
}
if isVariantType(v) {
fields := make(map[string]Type)
for k, v := range m {
typ, err := TypeOf(v)
if err != nil {
return nil, err
}
switch t := typ.(type) {
case *OptionalType:
typ = t.Type
default:
return nil, UnknownValueTypeError{Value: v}
}
fields[k] = typ
}
return NewVariantType(fields), nil
}
return TypeOf(m)
case reflect.Ptr:
indirect := reflect.Indirect(reflect.ValueOf(v))
Expand Down
33 changes: 33 additions & 0 deletions candid/idl/encode_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
package idl

import (
"testing"
)

func TestEncode_issue7(t *testing.T) {
type ConsumerPermissionEnum = struct {
ReadOnly *Null `ic:"ReadOnly,variant"`
ReadAndWrite *Null `ic:"ReadAndWrite,variant"`
}

type SecretConsumer = struct {
Name string `ic:"name"`
PermissionType ConsumerPermissionEnum `ic:"permission_type"`
}

raw, err := Marshal([]any{
[]SecretConsumer{
{
Name: "test",
PermissionType: ConsumerPermissionEnum{ReadAndWrite: new(Null)},
},
},
})
if err != nil {
t.Fatal(err)
}
var v []SecretConsumer
if err := Unmarshal(raw, []any{&v}); err != nil {
t.Fatal(err)
}
}
2 changes: 1 addition & 1 deletion candid/idl/null.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ func (NullType) EncodeType(_ *TypeDefinitionTable) ([]byte, error) {
}

func (NullType) EncodeValue(v any) ([]byte, error) {
if v != nil {
if _, ok := v.(Null); !ok && v != nil {
return nil, NewEncodeValueError(v, nullType)
}
return []byte{}, nil
Expand Down
4 changes: 4 additions & 0 deletions candid/idl/optional.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,10 @@ import (
"reflect"
)

func Ptr[a any](v a) *a {
return &v
}

// OptionalType is the type of an optional value.
type OptionalType struct {
Type Type
Expand Down
48 changes: 47 additions & 1 deletion candid/idl/variant.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,25 @@ import (
"github.com/aviate-labs/leb128"
)

func isVariantType(value any) bool {
v := reflect.ValueOf(value)
switch v.Kind() {
case reflect.Struct:
for i := 0; i < v.NumField(); i++ {
field := v.Type().Field(i)
if !field.IsExported() {
continue
}

tag := ParseTags(field)
if tag.VariantType {
return true
}
}
}
return false
}

type Variant struct {
Name string
Value any
Expand Down Expand Up @@ -98,7 +117,11 @@ func (variant VariantType) EncodeType(tdt *TypeDefinitionTable) ([]byte, error)
func (variant VariantType) EncodeValue(value any) ([]byte, error) {
fs, ok := value.(Variant)
if !ok {
return nil, NewEncodeValueError(variant, varType)
v, err := variant.structToVariant(value)
if err != nil {
return nil, err
}
return variant.EncodeValue(*v)
}
for i, f := range variant.Fields {
if f.Name == fs.Name {
Expand Down Expand Up @@ -176,6 +199,29 @@ func (variant VariantType) UnmarshalGo(raw any, _v any) error {
return variant.unmarshalStruct(name, value, v)
}

func (variant VariantType) structToVariant(value any) (*Variant, error) {
v := reflect.ValueOf(value)
switch v.Kind() {
case reflect.Struct:
for i := 0; i < v.NumField(); i++ {
field := v.Type().Field(i)
if !field.IsExported() {
continue
}

tag := ParseTags(field)
if !v.Field(i).IsNil() {
return &Variant{
Name: tag.Name,
Value: v.Field(i).Elem().Interface(),
}, nil
}
}
}
fmt.Printf("%#v\n", value)
return nil, fmt.Errorf("invalid value kind: %s", v.Kind())
}

func (variant VariantType) unmarshalMap(name string, value any, _v *map[string]any) error {
for _, f := range variant.Fields {
if f.Name != name {
Expand Down
157 changes: 144 additions & 13 deletions gen/generator.go
Original file line number Diff line number Diff line change
Expand Up @@ -250,13 +250,21 @@ func (g *Generator) GenerateActorTypes() ([]byte, error) {
}

func (g *Generator) GenerateMock() ([]byte, error) {
definitions := make(map[string]did.Data)
for _, definition := range g.ServiceDescription.Definitions {
switch definition := definition.(type) {
case did.Type:
definitions[definition.Id] = definition.Data
}
}
var methods []agentArgsMethod
for _, service := range g.ServiceDescription.Services {
for _, method := range service.Methods {
name := rawName(method.Name)
f := method.Func

var argumentTypes []agentArgsMethodArgument
var filledArgumentTypes []agentArgsMethodArgument
for i, t := range f.ArgTypes {
var n string
if (t.Name != nil) && (*t.Name != "") {
Expand All @@ -268,11 +276,15 @@ func (g *Generator) GenerateMock() ([]byte, error) {
Name: n,
Type: g.dataToString(g.PackageName, t.Data),
})
filledArgumentTypes = append(filledArgumentTypes, agentArgsMethodArgument{
Name: n,
Type: g.dataToGoReturnValue(definitions, g.PackageName, t.Data),
})
}

var returnTypes []string
for _, t := range f.ResTypes {
returnTypes = append(returnTypes, g.dataToString(g.PackageName, t.Data))
returnTypes = append(returnTypes, g.dataToGoReturnValue(definitions, g.PackageName, t.Data))
}

typ := "Call"
Expand All @@ -281,11 +293,12 @@ func (g *Generator) GenerateMock() ([]byte, error) {
}

methods = append(methods, agentArgsMethod{
RawName: name,
Name: funcName("", name),
Type: typ,
ArgumentTypes: argumentTypes,
ReturnTypes: returnTypes,
RawName: name,
Name: funcName("", name),
Type: typ,
ArgumentTypes: argumentTypes,
FilledArgumentTypes: filledArgumentTypes,
ReturnTypes: returnTypes,
})
}
}
Expand All @@ -307,6 +320,121 @@ func (g *Generator) GenerateMock() ([]byte, error) {
return io.ReadAll(&tmpl)
}

func (g *Generator) dataToGoReturnValue(definitions map[string]did.Data, prefix string, data did.Data) string {
switch t := data.(type) {
case did.Primitive:
switch t {
case "nat":
g.usedIDL = true
return "idl.NewNat(uint(0))"
case "int":
g.usedIDL = true
return "idl.NewInt(0)"
default:
return fmt.Sprintf("*new(%s)", g.dataToString(prefix, data))
}
case did.DataId:
switch data := definitions[t.String()].(type) {
case did.Record:
var fields []string
for _, f := range data {
var data did.Data
if f.Data != nil {
data = *f.Data
} else {
data = did.DataId(*f.NameData)
}
fields = append(fields, g.dataToGoReturnValue(definitions, prefix, data))
}
if len(fields) == 0 {
return fmt.Sprintf("%s{}", g.dataToString(prefix, t))
}
return fmt.Sprintf("%s{\n%s,\n}", g.dataToString(prefix, t), strings.Join(fields, ",\n"))
case did.Variant:
f := data[0]
var d did.Data
if f.Data != nil {
d = *f.Data
} else {
d = did.DataId(*f.NameData)
}
field := g.dataToGoReturnValue(definitions, prefix, d)
if !strings.HasPrefix(field, "*") {
g.usedIDL = true
field = fmt.Sprintf("idl.Ptr(%s)", field)
} else {
field = strings.TrimPrefix(field, "*")
}
var name string
if f.Name != nil {
name = *f.Name
} else {
name = *f.NameData
}
return fmt.Sprintf("%s{\n%s: %s,\n}", g.dataToString(prefix, t), name, field)
default:
switch data := data.(type) {
case did.Primitive:
switch data {
case "nat":
g.usedIDL = true
return "idl.NewNat(uint(0))"
case "int":
g.usedIDL = true
return "idl.NewInt(0)"
}
}
if data != nil {
return fmt.Sprintf("*new(%s)", g.dataToString(prefix, data))
}
return "*new(idl.Null)"
}
case did.Record:
var fields []string
for _, f := range t {
var data did.Data
if f.Data != nil {
data = *f.Data
} else {
data = did.DataId(*f.NameData)
}
fields = append(fields, g.dataToGoReturnValue(definitions, prefix, data))
}
if len(fields) == 0 {
return fmt.Sprintf("%s{}", g.dataToString(prefix, data))
}
return fmt.Sprintf("%s{\n%s,\n}", g.dataToString(prefix, data), strings.Join(fields, ",\n"))
case did.Variant:
f := t[0]
var name string
var d did.Data
if f.Data != nil {
name = *f.Name
d = *f.Data
} else {
name = *f.NameData
d = did.DataId(*f.NameData)
}
field := g.dataToGoReturnValue(definitions, prefix, d)
if !strings.HasPrefix(field, "*") {
g.usedIDL = true
field = fmt.Sprintf("idl.Ptr(%s)", field)
} else {
field = strings.TrimPrefix(field, "*")
}
return fmt.Sprintf("%s{\n%s: %s,\n}", g.dataToString(prefix, data), name, field)
case did.Vector:
switch t.Data.(type) {
case did.DataId:
return fmt.Sprintf("[]%s{%s}", funcName(prefix, t.Data.String()), g.dataToGoReturnValue(definitions, prefix, t.Data))
default:
return fmt.Sprintf("[]%s{%s}", g.dataToString(prefix, t.Data), g.dataToGoReturnValue(definitions, prefix, t.Data))
}
default:
return fmt.Sprintf("*new(%s)", g.dataToString(prefix, data))
}
}

func (g *Generator) dataToMotokoReturnValue(s rand.Source, definitions map[string]did.Data, data did.Data) string {
r := rand.New(s)
switch t := data.(type) {
Expand Down Expand Up @@ -496,7 +624,8 @@ func (g *Generator) dataToString(prefix string, data did.Data) string {
case "text":
return "string"
case "null":
return "struct{}"
g.usedIDL = true
return "idl.Null"
default:
panic(fmt.Sprintf("unknown primitive: %s", t))
}
Expand Down Expand Up @@ -572,11 +701,12 @@ func (g *Generator) dataToString(prefix string, data did.Data) string {
if 8 > sizeType {
sizeType = 8
}
g.usedIDL = true
records = append(records, struct {
originalName string
name string
typ string
}{originalName: name, name: name, typ: "struct{}"})
}{originalName: name, name: name, typ: "idl.Null"})
} else {
name := funcName("", *field.Name)
if l := len(name); l > sizeName {
Expand Down Expand Up @@ -649,11 +779,12 @@ type agentArgsDefinition struct {
}

type agentArgsMethod struct {
RawName string
Name string
Type string
ArgumentTypes []agentArgsMethodArgument
ReturnTypes []string
RawName string
Name string
Type string
ArgumentTypes []agentArgsMethodArgument
FilledArgumentTypes []agentArgsMethodArgument
ReturnTypes []string
}

type agentArgsMethodArgument struct {
Expand Down
Loading

0 comments on commit 6e17358

Please sign in to comment.