/
contract_mock_receiver.go
127 lines (107 loc) · 3.65 KB
/
contract_mock_receiver.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
package cltest
import (
"errors"
"reflect"
"testing"
"github.com/ethereum/go-ethereum"
"github.com/ethereum/go-ethereum/accounts/abi"
"github.com/ethereum/go-ethereum/common"
"github.com/ethereum/go-ethereum/common/hexutil"
"github.com/stretchr/testify/mock"
"github.com/stretchr/testify/require"
evmclimocks "github.com/O1MaGnUmO1/erinaceus-vrf/core/chains/evm/client/mocks"
)
// funcSigLength is the number of characters in a hex-encoded function
// selector as produced by hexutil.Encode: the "0x" prefix followed by two hex
// digits for each of the selector's 4 bytes (e.g. 0x1234abcd).
const funcSigLength = len("0x") + 2*4
func NewContractMockReceiver(t *testing.T, ethMock *evmclimocks.Client, abi abi.ABI, address common.Address) contractMockReceiver {
return contractMockReceiver{
t: t,
ethMock: ethMock,
abi: abi,
address: address,
}
}
// contractMockReceiver stubs CallContract responses on a mocked eth client for
// calls targeting a single contract address. Construct it with
// NewContractMockReceiver.
type contractMockReceiver struct {
t *testing.T // test used to report method-lookup and encoding failures
ethMock *evmclimocks.Client // mock client on which "CallContract" expectations are registered
abi abi.ABI // contract ABI used to resolve method selectors and encode outputs
address common.Address // destination address a call must have to match a stub
}
// MockResponse stubs CallContract so that any call to funcName on the
// receiver's contract address returns responseArgs, ABI-encoded as that
// method's outputs. With no responseArgs the stub returns an empty byte slice.
// The test fails immediately if funcName is not a method of the ABI.
func (receiver contractMockReceiver) MockResponse(funcName string, responseArgs ...interface{}) *mock.Call {
	funcSig := receiver.mustFuncSig(funcName)
	encoded := receiver.mustEncodeResponse(funcName, responseArgs...)
	return receiver.onCallContract(funcSig, nil).Return(encoded, nil)
}

// MockMatchedResponse is like MockResponse, but the stub only applies to calls
// that matcher also accepts, so different responses can be keyed on the call's
// arguments. A nil matcher accepts every call.
func (receiver contractMockReceiver) MockMatchedResponse(funcName string, matcher func(callArgs ethereum.CallMsg) bool, responseArgs ...interface{}) *mock.Call {
	funcSig := receiver.mustFuncSig(funcName)
	encoded := receiver.mustEncodeResponse(funcName, responseArgs...)
	return receiver.onCallContract(funcSig, matcher).Return(encoded, nil)
}

// MockRevertResponse stubs CallContract so that calls to funcName on the
// receiver's contract address return a nil result and a "revert" error.
func (receiver contractMockReceiver) MockRevertResponse(funcName string) *mock.Call {
	funcSig := receiver.mustFuncSig(funcName)
	return receiver.onCallContract(funcSig, nil).Return(nil, errors.New("revert"))
}

// mustFuncSig returns the 0x-prefixed hex selector of funcName and fails the
// test if the ABI has no method with that name.
func (receiver contractMockReceiver) mustFuncSig(funcName string) string {
	// A missing map entry yields a zero abi.Method whose nil ID encodes to
	// just "0x", which fails the length check below.
	funcSig := hexutil.Encode(receiver.abi.Methods[funcName].ID)
	if len(funcSig) != funcSigLength {
		receiver.t.Fatalf("Unable to find Registry contract function with name %s", funcName)
	}
	return funcSig
}

// onCallContract registers (without a return value) a CallContract
// expectation that matches calls to the receiver's address whose calldata
// starts with funcSig and, when matcher is non-nil, that matcher accepts.
func (receiver contractMockReceiver) onCallContract(funcSig string, matcher func(callArgs ethereum.CallMsg) bool) *mock.Call {
	return receiver.ethMock.On(
		"CallContract",
		mock.Anything,
		mock.MatchedBy(func(callArgs ethereum.CallMsg) bool {
			return receiver.callsMethod(callArgs, funcSig) &&
				(matcher == nil || matcher(callArgs))
		}),
		mock.Anything,
	)
}

// callsMethod reports whether callArgs is addressed to the receiver's contract
// and its calldata begins with the selector funcSig.
func (receiver contractMockReceiver) callsMethod(callArgs ethereum.CallMsg, funcSig string) bool {
	// A nil To (contract creation) or calldata shorter than a selector must
	// simply not match. Dereferencing or slicing them would panic inside the
	// mock's argument matcher and abort the test.
	if callArgs.To == nil || *callArgs.To != receiver.address {
		return false
	}
	data := hexutil.Encode(callArgs.Data)
	return len(data) >= funcSigLength && data[:funcSigLength] == funcSig
}
// mustEncodeResponse ABI-encodes responseArgs as the outputs of funcName and
// fails the test if packing returns an error. A single struct argument is
// expanded into its fields, in declaration order, so one struct can describe a
// multi-value return. A struct combined with further arguments is rejected.
// With no arguments it returns an empty, non-nil byte slice.
func (receiver contractMockReceiver) mustEncodeResponse(funcName string, responseArgs ...interface{}) []byte {
	if len(responseArgs) == 0 {
		return []byte{}
	}
	firstArg := responseArgs[0]
	// reflect.TypeOf(nil) returns a nil Type, and calling Kind on it panics.
	// An untyped nil is therefore passed through like any other non-struct
	// value instead of crashing here.
	isStruct := firstArg != nil && reflect.TypeOf(firstArg).Kind() == reflect.Struct

	var outputList []interface{}
	switch {
	case isStruct && len(responseArgs) > 1:
		receiver.t.Fatal("cannot encode response with struct and multiple return values")
	case isStruct:
		outputList = structToInterfaceSlice(firstArg)
	default:
		outputList = responseArgs
	}
	encoded, err := receiver.abi.Methods[funcName].Outputs.PackValues(outputList)
	require.NoError(receiver.t, err)
	return encoded
}
// structToInterfaceSlice returns the field values of structArg, in declaration
// order, as a positional argument list (never nil, even for an empty struct).
// structArg must be a struct value rather than a pointer. reflect panics on
// non-struct kinds and on unexported fields.
func structToInterfaceSlice(structArg interface{}) []interface{} {
	rv := reflect.ValueOf(structArg)
	n := rv.NumField()
	fields := make([]interface{}, 0, n)
	for idx := 0; idx < n; idx++ {
		fields = append(fields, rv.Field(idx).Interface())
	}
	return fields
}