-
Notifications
You must be signed in to change notification settings - Fork 157
/
vm.go
353 lines (303 loc) · 10.2 KB
/
vm.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
package vm
/*
#include <stdint.h>
#include <stdlib.h>
#include <stddef.h>
#define FELT_SIZE 32
typedef struct CallInfo {
unsigned char contract_address[FELT_SIZE];
unsigned char class_hash[FELT_SIZE];
unsigned char entry_point_selector[FELT_SIZE];
unsigned char** calldata;
size_t len_calldata;
} CallInfo;
typedef struct BlockInfo {
unsigned long long block_number;
unsigned long long block_timestamp;
unsigned char sequencer_address[FELT_SIZE];
unsigned char gas_price_wei[FELT_SIZE];
unsigned char gas_price_fri[FELT_SIZE];
char* version;
unsigned char block_hash_to_be_revealed[FELT_SIZE];
unsigned char data_gas_price_wei[FELT_SIZE];
unsigned char data_gas_price_fri[FELT_SIZE];
unsigned char use_blob_data;
} BlockInfo;
extern void cairoVMCall(CallInfo* call_info_ptr, BlockInfo* block_info_ptr, uintptr_t readerHandle, char* chain_id,
unsigned long long max_steps);
extern void cairoVMExecute(char* txns_json, char* classes_json, char* paid_fees_on_l1_json,
BlockInfo* block_info_ptr, uintptr_t readerHandle, char* chain_id,
unsigned char skip_charge_fee, unsigned char skip_validate, unsigned char err_on_revert);
#cgo vm_debug LDFLAGS: -L./rust/target/debug -ljuno_starknet_rs -ldl -lm
#cgo !vm_debug LDFLAGS: -L./rust/target/release -ljuno_starknet_rs -ldl -lm
*/
import "C"
import (
"encoding/json"
"errors"
"fmt"
"runtime"
"runtime/cgo"
"unsafe"
"github.com/NethermindEth/juno/core"
"github.com/NethermindEth/juno/core/felt"
"github.com/NethermindEth/juno/utils"
)
// VM is the entry point into the Cairo virtual machine: it runs single
// contract calls and batches of transactions against a state reader.
//
//go:generate mockgen -destination=../mocks/mock_vm.go -package=mocks github.com/NethermindEth/juno/vm VM
type VM interface {
	// Call invokes the entry point described by callInfo in the block context
	// blockInfo, reading contract storage through state, and returns the felts
	// produced by the call. maxSteps bounds the VM execution.
	Call(callInfo *CallInfo, blockInfo *BlockInfo, state core.StateReader, network *utils.Network,
		maxSteps uint64, useBlobData bool) ([]*felt.Felt, error)

	// Execute runs txns (with any declaredClasses they need) on top of state and
	// returns, in order, the actual fees charged, the data gas consumed, and one
	// trace per transaction. The skip/err flags are forwarded to the VM.
	Execute(txns []core.Transaction, declaredClasses []core.Class, paidFeesOnL1 []*felt.Felt, blockInfo *BlockInfo,
		state core.StateReader, network *utils.Network, skipChargeFee, skipValidate, errOnRevert, useBlobData bool,
	) ([]*felt.Felt, []*felt.Felt, []TransactionTrace, error)
}
// vm is the VM implementation backed by the juno_starknet_rs library that is
// linked in through cgo.
type vm struct {
	// log is handed to every callContext created by Call and Execute.
	log utils.SimpleLogger
}

// New builds a VM whose executions carry the given logger.
func New(log utils.SimpleLogger) VM {
	impl := &vm{log: log}
	return impl
}
// callContext is the per-invocation state shared between a Call/Execute and
// the exported Juno* callbacks; it is reached from C through a cgo.Handle.
type callContext struct {
	// state is the reader the VM execution runs against.
	state core.StateReader
	log   utils.SimpleLogger
	// err holds the error message reported through JunoReportError, if any.
	err string
	// errTxnIndex is the transaction index reported alongside err. Execute
	// treats a negative value as an error not tied to a specific transaction.
	errTxnIndex int64
	// response accumulates the felts returned by the executed Cairo function.
	response []*felt.Felt
	// actualFees accumulates the fee charged for each executed transaction.
	actualFees []*felt.Felt
	// traces accumulates the raw JSON trace of each executed transaction.
	traces []json.RawMessage
	// dataGasConsumed accumulates the data gas reported per transaction.
	dataGasConsumed []*felt.Felt
}
// unwrapContext resolves a cgo handle received from C back into the
// *callContext it was created for. It panics if the handle holds any other
// type, since that indicates a programming error on the caller's side.
func unwrapContext(readerHandle C.uintptr_t) *callContext {
	if cc, ok := cgo.Handle(readerHandle).Value().(*callContext); ok {
		return cc
	}
	panic("cannot cast reader")
}
// JunoReportError is exported to C so the VM can report a failure. It records
// the message str and the index of the offending transaction on the
// callContext behind readerHandle.
//
//export JunoReportError
func JunoReportError(readerHandle C.uintptr_t, txnIndex C.long, str *C.char) {
	cc := unwrapContext(readerHandle)
	cc.errTxnIndex, cc.err = int64(txnIndex), C.GoString(str)
}
// JunoAppendTrace is exported to C. It copies bytesLen bytes of JSON from
// jsonBytes into Go memory and appends them as one trace on the callContext
// behind readerHandle.
//
//export JunoAppendTrace
func JunoAppendTrace(readerHandle C.uintptr_t, jsonBytes *C.void, bytesLen C.size_t) {
	cc := unwrapContext(readerHandle)
	trace := json.RawMessage(C.GoBytes(unsafe.Pointer(jsonBytes), C.int(bytesLen)))
	cc.traces = append(cc.traces, trace)
}
// JunoAppendResponse is exported to C. It decodes the felt stored at ptr and
// appends it to the call response on the callContext behind readerHandle.
//
//export JunoAppendResponse
func JunoAppendResponse(readerHandle C.uintptr_t, ptr unsafe.Pointer) {
	cc := unwrapContext(readerHandle)
	value := makeFeltFromPtr(ptr)
	cc.response = append(cc.response, value)
}
// JunoAppendActualFee is exported to C. It decodes the felt stored at ptr and
// appends it to the per-transaction fee list on the callContext behind
// readerHandle.
//
//export JunoAppendActualFee
func JunoAppendActualFee(readerHandle C.uintptr_t, ptr unsafe.Pointer) {
	cc := unwrapContext(readerHandle)
	fee := makeFeltFromPtr(ptr)
	cc.actualFees = append(cc.actualFees, fee)
}
// JunoAppendDataGasConsumed is exported to C. It decodes the felt stored at
// ptr and appends it to the data-gas list on the callContext behind
// readerHandle.
//
//export JunoAppendDataGasConsumed
func JunoAppendDataGasConsumed(readerHandle C.uintptr_t, ptr unsafe.Pointer) {
	cc := unwrapContext(readerHandle)
	gas := makeFeltFromPtr(ptr)
	cc.dataGasConsumed = append(cc.dataGasConsumed, gas)
}
// makeFeltFromPtr copies felt.Bytes bytes starting at ptr into Go memory and
// decodes them into a newly allocated felt.
func makeFeltFromPtr(ptr unsafe.Pointer) *felt.Felt {
	raw := C.GoBytes(ptr, felt.Bytes)
	f := new(felt.Felt)
	return f.SetBytes(raw)
}
// makePtrFromFelt copies the byte representation of val into a buffer
// allocated with C.CBytes (C malloc). The caller owns the returned memory and
// must release it with C.free.
func makePtrFromFelt(val *felt.Felt) unsafe.Pointer {
	raw := val.Bytes()
	//nolint:gocritic
	return C.CBytes(raw[:])
}
// CallInfo describes the contract entry point targeted by VM.Call and the
// arguments passed to it. Nil felt pointers are sent to the VM as zero bytes.
type CallInfo struct {
	// ContractAddress is copied into the C contract_address field.
	ContractAddress *felt.Felt
	// ClassHash is copied into the C class_hash field.
	ClassHash *felt.Felt
	// Selector is copied into the C entry_point_selector field.
	Selector *felt.Felt
	// Calldata holds the call arguments; each felt gets its own C buffer.
	Calldata []felt.Felt
}
// BlockInfo is the block context handed to the VM for Call and Execute.
type BlockInfo struct {
	// Header supplies block number, timestamp, sequencer, gas prices,
	// protocol version and L1 data-availability settings.
	Header *core.Header
	// BlockHashToBeRevealed is copied into the C block_hash_to_be_revealed
	// field; nil is sent as zero bytes.
	BlockHashToBeRevealed *felt.Felt
}
// copyFeltIntoCArray writes the byte representation of f into the C array
// starting at cArrPtr, which must have room for felt.Bytes elements. A nil
// felt is a no-op, leaving the destination untouched.
func copyFeltIntoCArray(f *felt.Felt, cArrPtr *C.uchar) {
	if f == nil {
		return
	}
	src := f.Bytes()
	dst := unsafe.Slice(cArrPtr, len(src))
	for i, b := range src {
		dst[i] = C.uchar(b)
	}
}
// makeCCallInfo converts callInfo into the C CallInfo struct consumed by
// cairoVMCall.
//
// The address, class hash and selector are copied by value into the struct's
// fixed-size arrays (nil felts stay zeroed). Calldata is laid out in Go memory
// as an array of pointers to per-felt byte arrays. Every one of those
// allocations is pinned, so C may dereference them while they are reachable
// only through the returned struct. The caller must call Unpin on the returned
// Pinner once the C call has returned. With empty Calldata, calldata stays
// NULL and len_calldata stays 0.
func makeCCallInfo(callInfo *CallInfo) (C.CallInfo, runtime.Pinner) {
	var cCallInfo C.CallInfo
	var pinner runtime.Pinner

	copyFeltIntoCArray(callInfo.ContractAddress, &cCallInfo.contract_address[0])
	copyFeltIntoCArray(callInfo.ClassHash, &cCallInfo.class_hash[0])
	copyFeltIntoCArray(callInfo.Selector, &cCallInfo.entry_point_selector[0])

	if len(callInfo.Calldata) > 0 {
		// len_calldata is declared size_t. size_t is not unsigned long on every
		// target (e.g. LLP64), so convert through C.size_t rather than C.ulong.
		cCallInfo.len_calldata = C.size_t(len(callInfo.Calldata))

		// Prepare calldata in the Go heap and pin it for the duration of the call.
		calldataPtrs := make([]*C.uchar, 0, len(callInfo.Calldata))
		for i := range callInfo.Calldata {
			cArr := make([]C.uchar, felt.Bytes)
			copyFeltIntoCArray(&callInfo.Calldata[i], &cArr[0])
			pinner.Pin(&cArr[0])
			calldataPtrs = append(calldataPtrs, &cArr[0])
		}
		pinner.Pin(&calldataPtrs[0])
		cCallInfo.calldata = &calldataPtrs[0]
	}
	return cCallInfo, pinner
}
// makeCBlockInfo converts blockInfo into the C BlockInfo struct passed to the
// VM entry points.
//
// The version field holds the C string produced by cstring from the header's
// protocol version; the caller must release it with C.free. The data-gas
// prices and the use_blob_data flag are filled in only when the header's L1 DA
// mode is core.Blob; otherwise they keep their zero value.
func makeCBlockInfo(blockInfo *BlockInfo, useBlobData bool) C.BlockInfo {
	header := blockInfo.Header
	var cBlockInfo C.BlockInfo

	cBlockInfo.block_number = C.ulonglong(header.Number)
	cBlockInfo.block_timestamp = C.ulonglong(header.Timestamp)
	copyFeltIntoCArray(header.SequencerAddress, &cBlockInfo.sequencer_address[0])
	copyFeltIntoCArray(header.GasPrice, &cBlockInfo.gas_price_wei[0])
	copyFeltIntoCArray(header.GasPriceSTRK, &cBlockInfo.gas_price_fri[0])
	cBlockInfo.version = cstring([]byte(header.ProtocolVersion))
	copyFeltIntoCArray(blockInfo.BlockHashToBeRevealed, &cBlockInfo.block_hash_to_be_revealed[0])

	if header.L1DAMode != core.Blob {
		return cBlockInfo
	}

	copyFeltIntoCArray(header.L1DataGasPrice.PriceInWei, &cBlockInfo.data_gas_price_wei[0])
	copyFeltIntoCArray(header.L1DataGasPrice.PriceInFri, &cBlockInfo.data_gas_price_fri[0])
	// The struct starts zeroed, so only the "true" case needs an explicit write.
	if useBlobData {
		cBlockInfo.use_blob_data = 1
	}
	return cBlockInfo
}
// Call executes the entry point described by callInfo against state in the
// block context blockInfo, and returns the felts the VM appended to the
// response through JunoAppendResponse.
//
// maxSteps is forwarded to the VM unchanged. useBlobData only has an effect
// when the header's L1 DA mode is core.Blob (see makeCBlockInfo). If the VM
// reports an error through JunoReportError, it is returned as a plain error.
func (v *vm) Call(callInfo *CallInfo, blockInfo *BlockInfo, state core.StateReader,
	network *utils.Network, maxSteps uint64, useBlobData bool,
) ([]*felt.Felt, error) {
	context := &callContext{
		state:    state,
		response: []*felt.Felt{},
		log:      v.log,
	}
	handle := cgo.NewHandle(context)
	defer handle.Delete()

	// Each resource handed to C is released via defer right after it is
	// acquired, so it is freed on every exit path from this function.
	cCallInfo, callInfoPinner := makeCCallInfo(callInfo)
	defer callInfoPinner.Unpin()

	cBlockInfo := makeCBlockInfo(blockInfo, useBlobData)
	defer C.free(unsafe.Pointer(cBlockInfo.version))

	chainID := C.CString(network.L2ChainID)
	defer C.free(unsafe.Pointer(chainID))

	C.cairoVMCall(
		&cCallInfo,
		&cBlockInfo,
		C.uintptr_t(handle),
		chainID,
		C.ulonglong(maxSteps), //nolint:gocritic
	)

	if context.err != "" {
		return nil, errors.New(context.err)
	}
	return context.response, nil
}
// Execute executes a given transaction set and returns, per transaction, the
// actual fee charged, the data gas consumed and the decoded execution trace.
//
// txns and declaredClasses are serialized to JSON (see
// marshalTxnsAndDeclaredClasses) and passed to the VM together with
// paidFeesOnL1. The skipChargeFee, skipValidate and errOnRevert flags are
// forwarded as 0/1 bytes. useBlobData only has an effect when the header's L1
// DA mode is core.Blob.
//
// If the VM reports an error with a non-negative transaction index, a
// TransactionExecutionError carrying that index is returned. Any other
// reported error is returned as a plain error. A trace that fails to decode
// is returned wrapped, so callers can inspect the cause with errors.Is/As.
func (v *vm) Execute(txns []core.Transaction, declaredClasses []core.Class, paidFeesOnL1 []*felt.Felt,
	blockInfo *BlockInfo, state core.StateReader, network *utils.Network,
	skipChargeFee, skipValidate, errOnRevert, useBlobData bool,
) ([]*felt.Felt, []*felt.Felt, []TransactionTrace, error) {
	context := &callContext{
		state: state,
		log:   v.log,
	}
	handle := cgo.NewHandle(context)
	defer handle.Delete()

	txnsJSON, classesJSON, err := marshalTxnsAndDeclaredClasses(txns, declaredClasses)
	if err != nil {
		return nil, nil, nil, err
	}
	paidFeesOnL1Bytes, err := json.Marshal(paidFeesOnL1)
	if err != nil {
		return nil, nil, nil, err
	}

	// Every C allocation is released via defer next to where it is made, so
	// it is freed on every exit path from this function.
	paidFeesOnL1CStr := cstring(paidFeesOnL1Bytes)
	defer C.free(unsafe.Pointer(paidFeesOnL1CStr))
	txnsJSONCstr := cstring(txnsJSON)
	defer C.free(unsafe.Pointer(txnsJSONCstr))
	classesJSONCStr := cstring(classesJSON)
	defer C.free(unsafe.Pointer(classesJSONCStr))

	var skipChargeFeeByte, skipValidateByte, errOnRevertByte byte
	if skipChargeFee {
		skipChargeFeeByte = 1
	}
	if skipValidate {
		skipValidateByte = 1
	}
	if errOnRevert {
		errOnRevertByte = 1
	}

	cBlockInfo := makeCBlockInfo(blockInfo, useBlobData)
	defer C.free(unsafe.Pointer(cBlockInfo.version))
	chainID := C.CString(network.L2ChainID)
	defer C.free(unsafe.Pointer(chainID))

	C.cairoVMExecute(txnsJSONCstr,
		classesJSONCStr,
		paidFeesOnL1CStr,
		&cBlockInfo,
		C.uintptr_t(handle),
		chainID,
		C.uchar(skipChargeFeeByte),
		C.uchar(skipValidateByte),
		C.uchar(errOnRevertByte), //nolint:gocritic
	)

	if context.err != "" {
		if context.errTxnIndex >= 0 {
			return nil, nil, nil, TransactionExecutionError{
				Index: uint64(context.errTxnIndex),
				Cause: errors.New(context.err),
			}
		}
		return nil, nil, nil, errors.New(context.err)
	}

	traces := make([]TransactionTrace, len(context.traces))
	for index, traceJSON := range context.traces {
		if err := json.Unmarshal(traceJSON, &traces[index]); err != nil {
			// %w (not %v) keeps the decoder error available to errors.Is/As.
			return nil, nil, nil, fmt.Errorf("unmarshal trace: %w", err)
		}
	}
	return context.actualFees, context.dataGasConsumed, traces, nil
}
// marshalTxnsAndDeclaredClasses encodes txns and declaredClasses as two JSON
// arrays. Elements come from marshalTxn and marshalClassInfo respectively.
// Both intermediate slices are non-nil, so empty inputs encode as "[]" rather
// than "null". The first encoding error aborts and is returned unchanged.
func marshalTxnsAndDeclaredClasses(txns []core.Transaction, declaredClasses []core.Class) (json.RawMessage, json.RawMessage, error) { //nolint:lll
	encodedTxns := make([]json.RawMessage, 0, len(txns))
	for _, txn := range txns {
		encoded, err := marshalTxn(txn)
		if err != nil {
			return nil, nil, err
		}
		encodedTxns = append(encodedTxns, encoded)
	}

	encodedClasses := make([]json.RawMessage, 0, len(declaredClasses))
	for _, class := range declaredClasses {
		encoded, err := marshalClassInfo(class)
		if err != nil {
			return nil, nil, err
		}
		encodedClasses = append(encodedClasses, encoded)
	}

	txnsJSON, err := json.Marshal(encodedTxns)
	if err != nil {
		return nil, nil, err
	}
	classesJSON, err := json.Marshal(encodedClasses)
	if err != nil {
		return nil, nil, err
	}
	return txnsJSON, classesJSON, nil
}