/
generate.go
95 lines (79 loc) · 3.01 KB
/
generate.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
package lib
import (
"fmt"
"os"
. "github.com/dave/jennifer/jen"
)
// sdkPath is the import path of the stylus-go SDK that generated code calls.
// sdkAlias and contractAlias are the local package names the generated file
// uses for the SDK and for the user's contract package respectively.
const (
	sdkPath       = "github.com/0xys/stylus-go/sdk"
	sdkAlias      = "sdk"
	contractAlias = "contract"
)

// Status codes returned by the generated user_entrypoint. Stylus treats a
// zero status as success and a non-zero status (1) as a revert, so the two
// codes must differ.
//
// Both are shared jen statements: embed them (e.g. Return(FailureCode))
// rather than chaining methods on them, because jen's *Statement methods
// append to the receiver and would alter every later use.
var (
	// FailureCode makes the entrypoint revert.
	FailureCode = Lit(1)
	// SuccessCode marks a call that completed normally.
	SuccessCode = Lit(0)
)
// GenContract writes to out the Go source of a Stylus entrypoint program
// (package main) that dispatches calls to the contract described by cont.
//
// modID is the module path of the user's project; the contract type named
// cont.Name is imported from "<modID>/contract". The generated
// user_entrypoint fails with FailureCode when the calldata is shorter than
// 4 bytes, then switches on the 4-byte selector with one case per entry of
// cont.Functions plus a default case that fails for unknown selectors.
// Each case rejects a non-zero msg value for non-payable functions, calls
// sdk.SetPure / sdk.SetView as flagged, decodes parameter j from calldata
// offset 4+32*j, invokes the contract method, and on any error stores
// err.Error() via sdk.SetReturnString and returns FailureCode. Otherwise
// the entrypoint returns SuccessCode.
//
// A non-nil error is returned when jen cannot render the file or when
// writing to out fails.
func GenContract(modID string, cont *Contract, out *os.File) error {
	file := NewFile("main")
	file.HeaderComment(fmt.Sprintf("Code generated by %s. DO NOT EDIT.", sdkPath))
	file.ImportAlias(sdkPath, sdkAlias)
	contractPath := fmt.Sprintf("%s/contract", modID)
	file.ImportAlias(contractPath, contractAlias)

	// errGuard builds: if err != nil { sdk.SetReturnString(err.Error()); return FailureCode }
	errGuard := func() Code {
		return If(Err().Op("!=").Nil()).Block(
			Qual(sdkPath, "SetReturnString").Call(Err().Dot("Error").Call()),
			Return(FailureCode),
		)
	}

	// Capacity only (not length): make([]Code, n) followed by append would
	// leave n nil entries at the front of the slice. The extra slot holds
	// the default case.
	cases := make([]Code, 0, len(cont.Functions)+1)
	for _, f := range cont.Functions {
		sel := f.FuncMetadata.Selector.UInt32()
		var body []Code

		// Non-payable functions must not receive value.
		if !f.IsPayable() {
			body = append(body, If(Op("!").Qual(sdkPath, "MsgValue").Call().Dot("IsZero").Call()).Block(
				Return(FailureCode),
			))
		}
		if f.IsPure() {
			body = append(body, Qual(sdkPath, "SetPure").Call())
		}
		if f.IsView() {
			body = append(body, Qual(sdkPath, "SetView").Call())
		}

		// Parameter j is decoded from the calldata starting at byte 4+32*j,
		// i.e. just past the 4-byte selector.
		args := make([]Code, 0, len(f.Params))
		for j, p := range f.Params {
			arg := Id(fmt.Sprintf("param%d", j))
			body = append(body,
				List(arg, Err()).Op(":=").Qual(sdkPath, fmt.Sprintf("Decode%s", p.Type)).Call(Id("cd").Index(Lit(4+j*32), Empty())),
				errGuard(),
			)
			args = append(args, arg)
		}

		if len(f.Returns) > 1 {
			// ret, err := cont.Method(args...); ret.Bytes() becomes the return data.
			body = append(body,
				List(Id("ret"), Err()).Op(":=").Id("cont").Dot(f.Name).Call(args...),
				errGuard(),
				Qual(sdkPath, "SetReturnBytes").Call(Id("ret").Dot("Bytes").Call()),
			)
		} else {
			// err was already declared by the parameter decoding above when
			// there is at least one parameter; otherwise declare it here.
			assign := "="
			if len(f.Params) == 0 {
				assign = ":="
			}
			body = append(body,
				Err().Op(assign).Id("cont").Dot(f.Name).Call(args...),
				errGuard(),
			)
		}

		cases = append(cases, Case(Lit(sel)).Block(body...))
	}
	// Unknown selectors fail instead of falling through to SuccessCode.
	cases = append(cases, Default().Block(Return(FailureCode)))

	// jen renders Comment("x") as "// x"; a leading "//" disables that, so
	// the directive is emitted as "//export user_entrypoint" with no space,
	// which is the form TinyGo recognizes as an export.
	file.Comment("//export user_entrypoint")
	file.Func().Id("user_entrypoint").Params(Id("args_len").Uint32()).Uint32().Block(
		Qual(sdkPath, "Init").Call(Id("args_len")),
		Defer().Qual(sdkPath, "Flush").Call(),
		Line(),
		Id("cont").Op(":=").Op("&").Qual(contractPath, cont.Name).Values(),
		Id("cd").Op(":=").Qual(sdkPath, "GetCalldata").Call(),
		If(Len(Id("cd")).Op("<").Lit(4)).Block(
			Return(FailureCode),
		),
		Id("sel").Op(":=").Qual(sdkPath, "ToSelector").Call(Id("cd").Index(Empty(), Lit(4))),
		Switch(Id("sel")).Block(cases...),
		Return(SuccessCode),
	)

	file.Comment("dummy main is needed")
	file.Func().Id("main").Params().Block(
		Id("user_entrypoint").Call(Lit(0)),
	)

	// Render returns formatting and write errors to the caller, whereas
	// "%#v" (File.GoString) panics on a render failure.
	if err := file.Render(out); err != nil {
		return fmt.Errorf("rendering generated contract: %w", err)
	}
	return nil
}