forked from tetratelabs/wazero
/
module.go
477 lines (422 loc) · 13.2 KB
/
module.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
package wasm
import (
"bytes"
"errors"
"fmt"
"io"
"github.com/aghosn/gasm/wasm/leb128"
)
var (
	// magic is the 4-byte WebAssembly binary magic number: "\0asm".
	magic = []byte{0x00, 0x61, 0x73, 0x6D}
	// version is the expected 4-byte binary-format version field (1, little-endian).
	version = []byte{0x01, 0x00, 0x00, 0x00}

	// ErrInvalidMagicNumber is returned by DecodeModule when the input does
	// not begin with the WebAssembly magic number.
	ErrInvalidMagicNumber = errors.New("invalid magic number")
	// ErrInvalidVersion is returned by DecodeModule when the version field
	// does not match the supported binary-format version.
	ErrInvalidVersion = errors.New("invalid version header")
)
type (
	// Module is a decoded WebAssembly module: one Sec* field per binary
	// section, plus the index spaces populated at instantiation time by
	// buildIndexSpaces.
	Module struct {
		SecTypes     []*FunctionType           // type section
		SecImports   []*ImportSegment          // import section
		SecFunctions []uint32                  // function section: type indices, parallel to SecCodes
		SecTables    []*TableType              // table section
		SecMemory    []*MemoryType             // memory section
		SecGlobals   []*GlobalSegment          // global section
		SecExports   map[string]*ExportSegment // export section, keyed by export name
		SecStart     []uint32                  // start section
		SecElements  []*ElementSegment         // element section
		SecCodes     []*CodeSegment            // code section
		SecData      []*DataSegment            // data section
		IndexSpace   *ModuleIndexSpace         // nil until buildIndexSpaces runs
	}

	// ModuleIndexSpace holds the module's runtime index spaces: imported
	// entries first, followed by locally-defined ones.
	ModuleIndexSpace struct {
		Function []VirtualMachineFunction
		Globals  []*Global
		Table    [][]*uint32
		Memory   [][]byte
	}

	// Global is an initialized global: its declared type together with the
	// value produced by evaluating its init expression.
	Global struct {
		Type *GlobalType
		Val  interface{}
	}
)
// DecodeModule decodes a raw WebAssembly module from r. The returned module's
// index spaces are not yet initialized; buildIndexSpaces does that.
//
// It returns ErrInvalidMagicNumber if the stream does not start with "\0asm",
// and ErrInvalidVersion if the version field is short or unsupported.
func DecodeModule(r io.Reader) (*Module, error) {
	// Magic number: the first 4 bytes must be "\0asm".
	buf := make([]byte, 4)
	if _, err := io.ReadFull(r, buf); err != nil {
		// A short read means the header is malformed; report it as such.
		return nil, ErrInvalidMagicNumber
	}
	if !bytes.Equal(buf, magic) {
		return nil, ErrInvalidMagicNumber
	}

	// Version: the next 4 bytes must match the supported format version.
	// The original code panicked on a read error here; a decoder must not
	// panic on malformed input, so return ErrInvalidVersion instead.
	if _, err := io.ReadFull(r, buf); err != nil {
		return nil, ErrInvalidVersion
	}
	if !bytes.Equal(buf, version) {
		return nil, ErrInvalidVersion
	}

	ret := &Module{}
	if err := ret.readSections(r); err != nil {
		return nil, fmt.Errorf("readSections failed: %w", err)
	}
	return ret, nil
}
// buildIndexSpaces populates the module's index spaces: imports are resolved
// against externModules first, then the locally-defined globals, functions,
// tables and memories are appended after the imported entries.
func (m *Module) buildIndexSpaces(externModules map[string]*Module) error {
	m.IndexSpace = new(ModuleIndexSpace)

	if err := m.resolveImports(externModules); err != nil {
		return fmt.Errorf("resolve imports: %w", err)
	}

	// Pad the table index space so locally-defined tables sit after the
	// imported ones. Note: the MVP restricts table index spaces to size 1.
	for len(m.IndexSpace.Table) < len(m.SecTables) {
		m.IndexSpace.Table = append(m.IndexSpace.Table, []*uint32{})
	}
	// Likewise for memories; the MVP restricts memory index spaces to size 1.
	for len(m.IndexSpace.Memory) < len(m.SecMemory) {
		m.IndexSpace.Memory = append(m.IndexSpace.Memory, []byte{})
	}

	// Build each local index space in order; error messages match the
	// space being built.
	steps := []struct {
		name  string
		build func() error
	}{
		{"global", m.buildGlobalIndexSpace},
		{"function", m.buildFunctionIndexSpace},
		{"table", m.buildTableIndexSpace},
		{"memory", m.buildMemoryIndexSpace},
	}
	for _, s := range steps {
		if err := s.build(); err != nil {
			return fmt.Errorf("build %s index space: %w", s.name, err)
		}
	}
	return nil
}
// resolveImports resolves every import segment of the module against the
// given external modules, annotating any failure with the import's name.
func (m *Module) resolveImports(externModules map[string]*Module) error {
	for _, seg := range m.SecImports {
		err := m.resolveImport(seg, externModules)
		if err != nil {
			return fmt.Errorf("%s: %w", seg.Name, err)
		}
	}
	return nil
}
// resolveImport links one import segment against the module it names in
// externModules, then dispatches on the import descriptor's kind byte
// (0x00 function, 0x01 table, 0x02 memory, 0x03 global).
func (m *Module) resolveImport(is *ImportSegment, externModules map[string]*Module) error {
	em, ok := externModules[is.Module]
	if !ok {
		return fmt.Errorf("failed to resolve import of module name %s", is.Module)
	}

	es, ok := em.SecExports[is.Name]
	if !ok {
		return fmt.Errorf("not exported in module %s", is.Module)
	}

	// The export must be of the same kind the import expects.
	if is.Desc.Kind != es.Desc.Kind {
		return fmt.Errorf("type mismatch on export: got %#x but want %#x", es.Desc.Kind, is.Desc.Kind)
	}

	switch kind := is.Desc.Kind; kind {
	case 0x00: // function
		if err := m.applyFunctionImport(is, em, es); err != nil {
			return fmt.Errorf("applyFunctionImport failed: %w", err)
		}
		return nil
	case 0x01: // table
		if err := m.applyTableImport(em, es); err != nil {
			return fmt.Errorf("applyTableImport failed: %w", err)
		}
		return nil
	case 0x02: // mem
		if err := m.applyMemoryImport(em, es); err != nil {
			return fmt.Errorf("applyMemoryImport: %w", err)
		}
		return nil
	case 0x03: // global
		if err := m.applyGlobalImport(em, es); err != nil {
			return fmt.Errorf("applyGlobalImport: %w", err)
		}
		return nil
	default:
		return fmt.Errorf("invalid kind of import: %#x", kind)
	}
}
// applyFunctionImport checks that the exported function at es.Desc.Index in
// em has the signature the import segment declares, then appends it to m's
// function index space.
//
// Fixes: corrected typos in error messages ("nill" -> "nil",
// "mimatch" -> "mismatch").
func (m *Module) applyFunctionImport(is *ImportSegment, em *Module, es *ExportSegment) error {
	if es.Desc.Index >= uint32(len(em.IndexSpace.Function)) {
		return fmt.Errorf("exported index out of range")
	}
	if is.Desc.TypeIndexPtr == nil {
		return fmt.Errorf("is.Desc.TypeIndexPtr is nil")
	}
	// iSig is the signature the importing module expects; it must match the
	// actual signature of the exported function exactly.
	iSig := m.SecTypes[*is.Desc.TypeIndexPtr]
	f := em.IndexSpace.Function[es.Desc.Index]
	if !hasSameSignature(iSig.ReturnTypes, f.FunctionType().ReturnTypes) {
		return fmt.Errorf("return signature mismatch: %#x != %#x", iSig.ReturnTypes, f.FunctionType().ReturnTypes)
	} else if !hasSameSignature(iSig.InputTypes, f.FunctionType().InputTypes) {
		return fmt.Errorf("input signature mismatch: %#x != %#x", iSig.InputTypes, f.FunctionType().InputTypes)
	}
	m.IndexSpace.Function = append(m.IndexSpace.Function, f)
	return nil
}
// applyTableImport appends the exported table at es.Desc.Index in em to m's
// table index space.
// note: MVP restricts the size of table index spaces to 1
func (m *Module) applyTableImport(em *Module, es *ExportSegment) error {
	idx := es.Desc.Index
	if idx >= uint32(len(em.IndexSpace.Table)) {
		return fmt.Errorf("exported index out of range")
	}
	m.IndexSpace.Table = append(m.IndexSpace.Table, em.IndexSpace.Table[idx])
	return nil
}
// applyMemoryImport appends the exported memory at es.Desc.Index in em to m's
// memory index space.
// note: MVP restricts the size of memory index spaces to 1
func (m *Module) applyMemoryImport(em *Module, es *ExportSegment) error {
	idx := es.Desc.Index
	if idx >= uint32(len(em.IndexSpace.Memory)) {
		return fmt.Errorf("exported index out of range")
	}
	m.IndexSpace.Memory = append(m.IndexSpace.Memory, em.IndexSpace.Memory[idx])
	return nil
}
// applyGlobalImport appends the exported (immutable) global at es.Desc.Index
// in em to m's global index space. Mutable globals cannot be imported in the
// MVP.
//
// BUG FIX: the original appended to em.IndexSpace.Globals (the exporting
// module's slice) instead of m.IndexSpace.Globals — that both discarded any
// globals m already held and could mutate em's backing array. The append now
// targets the importing module's own index space.
func (m *Module) applyGlobalImport(em *Module, es *ExportSegment) error {
	if es.Desc.Index >= uint32(len(em.IndexSpace.Globals)) {
		return fmt.Errorf("exported index out of range")
	}

	gb := em.IndexSpace.Globals[es.Desc.Index]
	if gb.Type.Mutable {
		return fmt.Errorf("cannot import mutable global")
	}
	m.IndexSpace.Globals = append(m.IndexSpace.Globals, gb)
	return nil
}
// buildGlobalIndexSpace evaluates the init expression of every global in the
// global section and appends the resulting initialized globals to the index
// space, after any imported globals.
func (m *Module) buildGlobalIndexSpace() error {
	for _, seg := range m.SecGlobals {
		val, err := m.executeConstExpression(seg.Init)
		if err != nil {
			return fmt.Errorf("execution failed: %w", err)
		}
		g := &Global{Type: seg.Type, Val: val}
		m.IndexSpace.Globals = append(m.IndexSpace.Globals, g)
	}
	return nil
}
// buildFunctionIndexSpace pairs each function-section entry (a type index)
// with the code-section body at the same position, pre-parses the body's
// block structure, and appends the resulting native function to the function
// index space.
func (m *Module) buildFunctionIndexSpace() error {
	for codeIndex, typeIndex := range m.SecFunctions {
		if typeIndex >= uint32(len(m.SecTypes)) {
			return fmt.Errorf("function type index out of range")
		}
		if codeIndex >= len(m.SecCodes) {
			return fmt.Errorf("code index out of range")
		}

		code := m.SecCodes[codeIndex]
		fn := &NativeFunction{
			Signature: m.SecTypes[typeIndex],
			Body:      code.Body,
			NumLocal:  code.NumLocals,
		}

		// Pre-compute block boundaries so the VM need not re-scan the body.
		blocks, err := m.parseBlocks(fn.Body)
		if err != nil {
			return fmt.Errorf("parse blocks: %w", err)
		}
		fn.Blocks = blocks

		m.IndexSpace.Function = append(m.IndexSpace.Function, fn)
	}
	return nil
}
// buildMemoryIndexSpace applies every data segment to its target memory in
// the index space, growing the memory when the segment extends past its
// current length. Returns an error on an out-of-range index, a non-i32
// offset expression, or a segment exceeding the declared memory maximum.
func (m *Module) buildMemoryIndexSpace() error {
	for _, d := range m.SecData {
		// note: MVP restricts the size of memory index spaces to 1
		if d.MemoryIndex >= uint32(len(m.IndexSpace.Memory)) {
			return fmt.Errorf("index out of range of index space")
		} else if d.MemoryIndex >= uint32(len(m.SecMemory)) {
			return fmt.Errorf("index out of range of memory section")
		}
		// The segment offset is a constant expression evaluated now; it must
		// produce an i32.
		rawOffset, err := m.executeConstExpression(d.OffsetExpression)
		if err != nil {
			return fmt.Errorf("calculate offset: %w", err)
		}
		offset, ok := rawOffset.(int32)
		if !ok {
			return fmt.Errorf("type assertion failed")
		}
		// size is the byte index one past the end of the region this segment
		// writes; enforce the declared maximum (counted in 64KiB pages).
		size := int(offset) + len(d.Init)
		if m.SecMemory[d.MemoryIndex].Max != nil && uint32(size) > *(m.SecMemory[d.MemoryIndex].Max)*vmPageSize {
			return fmt.Errorf("memory size out of limit %d * 64Ki", int(*(m.SecMemory[d.MemoryIndex].Max)))
		}
		memory := m.IndexSpace.Memory[d.MemoryIndex]
		if size > len(memory) {
			// Segment extends past the current memory: grow by copying the
			// old contents into a larger buffer, then write the segment.
			next := make([]byte, size)
			copy(next, memory)
			copy(next[offset:], d.Init)
			m.IndexSpace.Memory[d.MemoryIndex] = next
		} else {
			// Segment fits: write it in place.
			copy(memory[offset:], d.Init)
		}
	}
	return nil
}
// buildTableIndexSpace applies every element segment to its target table in
// the index space, growing the table when the segment extends past its
// current length. Entries are stored as pointers into the segment's Init
// slice (function indices).
func (m *Module) buildTableIndexSpace() error {
	for _, elem := range m.SecElements {
		// note: MVP restricts the size of table index spaces to 1
		if elem.TableIndex >= uint32(len(m.IndexSpace.Table)) {
			return fmt.Errorf("index out of range of index space")
		} else if elem.TableIndex >= uint32(len(m.SecTables)) {
			// this is just in case since we could assume len(SecTables) == len(IndexSpace.Table)
			return fmt.Errorf("index out of range of table section")
		}
		// The segment offset is a constant expression evaluated now; it must
		// produce an i32.
		rawOffset, err := m.executeConstExpression(elem.OffsetExpr)
		if err != nil {
			return fmt.Errorf("calculate offset: %w", err)
		}
		offset32, ok := rawOffset.(int32)
		if !ok {
			return fmt.Errorf("type assertion failed")
		}
		offset := int(offset32)
		// size is the table index one past the last entry this segment sets;
		// enforce the table's declared maximum, when present.
		size := offset + len(elem.Init)
		if m.SecTables[elem.TableIndex].Limit.Max != nil &&
			size > int(*(m.SecTables[elem.TableIndex].Limit.Max)) {
			return fmt.Errorf("table size out of limit of %d", int(*(m.SecTables[elem.TableIndex].Limit.Max)))
		}
		table := m.IndexSpace.Table[elem.TableIndex]
		if size > len(table) {
			// Segment extends past the current table: grow by copying the old
			// entries into a larger slice, then set the segment's entries.
			next := make([]*uint32, size)
			copy(next, table)
			for i := range elem.Init {
				next[i+offset] = &elem.Init[i]
			}
			m.IndexSpace.Table[elem.TableIndex] = next
		} else {
			// Segment fits: set the entries in place.
			for i := range elem.Init {
				table[i+offset] = &elem.Init[i]
			}
		}
	}
	return nil
}
// BlockType is an alias of FunctionType: a structured instruction's type
// describes the values it leaves on the stack.
type BlockType = FunctionType

// readBlockType decodes a block type from r: either one of the negative
// single-value sentinels (0x40 empty, 0x7f i32, 0x7e i64, 0x7d f32, 0x7c f64
// in the original byte encoding) or a non-negative index into the type
// section. It returns the type, the number of bytes consumed, and an error.
func (m *Module) readBlockType(r io.Reader) (*BlockType, uint64, error) {
	raw, consumed, err := leb128.DecodeInt33AsInt64(r)
	if err != nil {
		return nil, 0, fmt.Errorf("decode int33: %w", err)
	}

	switch raw {
	case -64: // 0x40 in original byte = nil
		return &BlockType{}, consumed, nil
	case -1: // 0x7f in original byte = i32
		return &BlockType{ReturnTypes: []ValueType{ValueTypeI32}}, consumed, nil
	case -2: // 0x7e in original byte = i64
		return &BlockType{ReturnTypes: []ValueType{ValueTypeI64}}, consumed, nil
	case -3: // 0x7d in original byte = f32
		return &BlockType{ReturnTypes: []ValueType{ValueTypeF32}}, consumed, nil
	case -4: // 0x7c in original byte = f64
		return &BlockType{ReturnTypes: []ValueType{ValueTypeF64}}, consumed, nil
	}

	// Anything else must be a valid index into the type section.
	if raw < 0 || raw >= int64(len(m.SecTypes)) {
		return nil, 0, fmt.Errorf("invalid block type: %d", raw)
	}
	return m.SecTypes[raw], consumed, nil
}
// parseBlocks scans a function body and records every structured control
// instruction (block, loop, if) it contains, keyed by the byte offset of the
// instruction's opcode. To find those opcodes it must skip the
// variable-length immediates of all other instructions, which is what most
// of the loop below does.
func (m *Module) parseBlocks(body []byte) (map[uint64]*NativeFunctionBlock, error) {
	ret := map[uint64]*NativeFunctionBlock{}
	// stack holds the currently-open blocks; the top is the innermost.
	stack := make([]*NativeFunctionBlock, 0)
	for pc := uint64(0); pc < uint64(len(body)); pc++ {
		rawOc := body[pc]
		if 0x28 <= rawOc && rawOc <= 0x3e { // memory load,store
			pc++
			// align
			_, num, err := leb128.DecodeUint32(bytes.NewBuffer(body[pc:]))
			if err != nil {
				return nil, fmt.Errorf("read memory align: %w", err)
			}
			pc += num
			// offset
			_, num, err = leb128.DecodeUint32(bytes.NewBuffer(body[pc:]))
			if err != nil {
				return nil, fmt.Errorf("read memory offset: %w", err)
			}
			// The -1 compensates for the loop's pc++.
			pc += num - 1
			continue
		} else if 0x41 <= rawOc && rawOc <= 0x44 { // const instructions
			pc++
			switch OptCode(rawOc) {
			case OptCodeI32Const:
				_, num, err := leb128.DecodeInt32(bytes.NewBuffer(body[pc:]))
				if err != nil {
					return nil, fmt.Errorf("read immediate: %w", err)
				}
				pc += num - 1
			case OptCodeI64Const:
				_, num, err := leb128.DecodeInt64(bytes.NewBuffer(body[pc:]))
				if err != nil {
					return nil, fmt.Errorf("read immediate: %w", err)
				}
				pc += num - 1
			case OptCodeF32Const:
				// f32 immediate is a fixed 4 bytes; pc already advanced by 1.
				pc += 3
			case OptCodeF64Const:
				// f64 immediate is a fixed 8 bytes; pc already advanced by 1.
				pc += 7
			}
			continue
		} else if (0x3f <= rawOc && rawOc <= 0x40) || // memory grow,size
			(0x20 <= rawOc && rawOc <= 0x24) || // variable instructions
			(0x0c <= rawOc && rawOc <= 0x0d) || // br,br_if instructions
			(0x10 <= rawOc && rawOc <= 0x11) { // call,call_indirect
			// All of these carry a single LEB128 u32 immediate.
			pc++
			_, num, err := leb128.DecodeUint32(bytes.NewBuffer(body[pc:]))
			if err != nil {
				return nil, fmt.Errorf("read immediate: %w", err)
			}
			pc += num - 1
			if rawOc == 0x11 { // if call_indirect
				// call_indirect carries one extra (table index) byte.
				pc++
			}
			continue
		} else if rawOc == 0x0e { // br_table
			pc++
			r := bytes.NewBuffer(body[pc:])
			// nl is the number of branch targets in the table; each target is
			// a LEB128 u32, followed by one more u32 for the default target.
			nl, num, err := leb128.DecodeUint32(r)
			if err != nil {
				return nil, fmt.Errorf("read immediate: %w", err)
			}
			for i := uint32(0); i < nl; i++ {
				_, n, err := leb128.DecodeUint32(r)
				if err != nil {
					return nil, fmt.Errorf("read immediate: %w", err)
				}
				num += n
			}
			_, n, err := leb128.DecodeUint32(r)
			if err != nil {
				return nil, fmt.Errorf("read immediate: %w", err)
			}
			pc += n + num - 1
			continue
		}
		switch OptCode(rawOc) {
		case OptCodeBlock, OptCodeIf, OptCodeLoop:
			// Opening a structured instruction: record where it starts and
			// what block type it carries, then push it as the innermost block.
			bt, num, err := m.readBlockType(bytes.NewBuffer(body[pc+1:]))
			if err != nil {
				return nil, fmt.Errorf("read block: %w", err)
			}
			stack = append(stack, &NativeFunctionBlock{
				StartAt:        pc,
				BlockType:      bt,
				BlockTypeBytes: num,
			})
			pc += num
		case OptCodeElse:
			// NOTE(review): panics if 'else' appears with no open block —
			// assumes a validated body; confirm upstream validation.
			stack[len(stack)-1].ElseAt = pc
		case OptCodeEnd:
			// Close the innermost open block and index it by its start offset.
			// NOTE(review): panics if 'end' appears with no open block —
			// assumes a validated body; confirm upstream validation.
			bl := stack[len(stack)-1]
			stack = stack[:len(stack)-1]
			bl.EndAt = pc
			ret[bl.StartAt] = bl
		}
	}
	if len(stack) > 0 {
		return nil, fmt.Errorf("ill-nested block exists")
	}
	return ret, nil
}