diff --git a/pkg/hintrunner/hint.go b/pkg/hintrunner/hint.go index cef41c3b3..d72c6e28b 100644 --- a/pkg/hintrunner/hint.go +++ b/pkg/hintrunner/hint.go @@ -26,14 +26,14 @@ func (hint AllocSegment) Execute(vm *VM.VirtualMachine) error { segmentIndex := vm.MemoryManager.Memory.AllocateEmptySegment() memAddress := memory.MemoryValueFromSegmentAndOffset(segmentIndex, 0) - cell, err := hint.dst.Get(vm) + regAddr, err := hint.dst.Get(vm) if err != nil { - return fmt.Errorf("get destination cell: %v", err) + return fmt.Errorf("get register %s: %w", hint.dst, err) } - err = cell.Write(memAddress) + err = vm.MemoryManager.Memory.WriteToAddress(®Addr, &memAddress) if err != nil { - return fmt.Errorf("write cell: %v", err) + return fmt.Errorf("write to address %s: %v", regAddr, err) } return nil @@ -52,12 +52,12 @@ func (hint TestLessThan) String() string { func (hint TestLessThan) Execute(vm *VM.VirtualMachine) error { lhsVal, err := hint.lhs.Resolve(vm) if err != nil { - return fmt.Errorf("resolve lhs operand %s: %v", hint.lhs, err) + return fmt.Errorf("resolve lhs operand %s: %w", hint.lhs, err) } rhsVal, err := hint.rhs.Resolve(vm) if err != nil { - return fmt.Errorf("resolve rhs operand %s: %v", hint.rhs, err) + return fmt.Errorf("resolve rhs operand %s: %w", hint.rhs, err) } lhsFelt, err := lhsVal.ToFieldElement() @@ -75,14 +75,15 @@ func (hint TestLessThan) Execute(vm *VM.VirtualMachine) error { resFelt.SetOne() } - dstCell, err := hint.dst.Get(vm) + dstAddr, err := hint.dst.Get(vm) if err != nil { - return fmt.Errorf("get destination cell: %v", err) + return fmt.Errorf("get dst address %s: %w", dstAddr, err) } - err = dstCell.Write(memory.MemoryValueFromFieldElement(&resFelt)) + mv := memory.MemoryValueFromFieldElement(&resFelt) + err = vm.MemoryManager.Memory.WriteToAddress(&dstAddr, &mv) if err != nil { - return fmt.Errorf("write cell: %v", err) + return fmt.Errorf("write to dst address %s: %w", dstAddr, err) } return nil diff --git a/pkg/hintrunner/hintrunner_test.go b/pkg/hintrunner/hintrunner_test.go index e88f7a09d..4692dd4e1 100644 --- a/pkg/hintrunner/hintrunner_test.go +++ b/pkg/hintrunner/hintrunner_test.go @@ -19,7 +19,10 @@ func TestExistingHint(t *testing.T) { 10: allocHint, }) - vm.Context.Pc = memory.NewMemoryAddress(0, 10) + vm.Context.Pc = memory.MemoryAddress{ + SegmentIndex: 0, + Offset: 10, + } err := hr.RunHint(vm) require.Nil(t, err) require.Equal( @@ -40,7 +43,10 @@ func TestNoHint(t *testing.T) { 10: allocHint, }) - vm.Context.Pc = memory.NewMemoryAddress(0, 100) + vm.Context.Pc = memory.MemoryAddress{ + SegmentIndex: 0, + Offset: 100, + } err := hr.RunHint(vm) require.Nil(t, err) require.Equal(t, 2, len(vm.MemoryManager.Memory.Segments)) diff --git a/pkg/hintrunner/operand.go b/pkg/hintrunner/operand.go index 24ea73ce7..6926db2c9 100644 --- a/pkg/hintrunner/operand.go +++ b/pkg/hintrunner/operand.go @@ -16,35 +16,35 @@ import ( type CellRefer interface { fmt.Stringer - Get(vm *VM.VirtualMachine) (*memory.Cell, error) + Get(vm *VM.VirtualMachine) (memory.MemoryAddress, error) } type ApCellRef int16 func (ap ApCellRef) String() string { - return "ApCellRef" + return fmt.Sprintf("ApCellRef(%d)", ap) } -func (ap ApCellRef) Get(vm *VM.VirtualMachine) (*memory.Cell, error) { +func (ap ApCellRef) Get(vm *VM.VirtualMachine) (memory.MemoryAddress, error) { res, overflow := safemath.SafeOffset(vm.Context.Ap, int16(ap)) if overflow { - return nil, safemath.NewSafeOffsetError(vm.Context.Ap, int16(ap)) + return memory.MemoryAddress{}, safemath.NewSafeOffsetError(vm.Context.Ap, int16(ap)) } - return vm.MemoryManager.Memory.Peek(VM.ExecutionSegment, res) + return memory.MemoryAddress{SegmentIndex: VM.ExecutionSegment, Offset: res}, nil } type FpCellRef int16 func (fp FpCellRef) String() string { - return "FpCellRef" + return fmt.Sprintf("FpCellRef(%d)", fp) } -func (fp FpCellRef) Get(vm *VM.VirtualMachine) (*memory.Cell, error) { +func (fp FpCellRef) Get(vm *VM.VirtualMachine) (memory.MemoryAddress, error) { res, overflow := safemath.SafeOffset(vm.Context.Fp, int16(fp)) if overflow { - return nil, safemath.NewSafeOffsetError(vm.Context.Ap, int16(fp)) + return memory.MemoryAddress{}, safemath.NewSafeOffsetError(vm.Context.Ap, int16(fp)) } - return vm.MemoryManager.Memory.Peek(VM.ExecutionSegment, res) + return memory.MemoryAddress{SegmentIndex: VM.ExecutionSegment, Offset: res}, nil } // @@ -53,7 +53,7 @@ func (fp FpCellRef) Get(vm *VM.VirtualMachine) (*memory.Cell, error) { type ResOperander interface { fmt.Stringer - Resolve(vm *VM.VirtualMachine) (*memory.MemoryValue, error) + Resolve(vm *VM.VirtualMachine) (memory.MemoryValue, error) } type Deref struct { @@ -64,12 +64,12 @@ func (deref Deref) String() string { return "Deref" } -func (deref Deref) Resolve(vm *VM.VirtualMachine) (*memory.MemoryValue, error) { - cell, err := deref.deref.Get(vm) +func (deref Deref) Resolve(vm *VM.VirtualMachine) (memory.MemoryValue, error) { + address, err := deref.deref.Get(vm) if err != nil { - return nil, fmt.Errorf("get cell: %v", err) + return memory.MemoryValue{}, fmt.Errorf("get cell: %w", err) } - return cell.Read(), nil + return vm.MemoryManager.Memory.ReadFromAddress(&address) } type DoubleDeref struct { @@ -77,28 +77,34 @@ type DoubleDeref struct { offset int16 } -func (dderef DoubleDeref) Resolve(vm *VM.VirtualMachine) (*memory.MemoryValue, error) { - cell, err := dderef.deref.Get(vm) +func (dderef DoubleDeref) Resolve(vm *VM.VirtualMachine) (memory.MemoryValue, error) { + lhsAddr, err := dderef.deref.Get(vm) if err != nil { - return nil, fmt.Errorf("get cell: %v", err) + return memory.MemoryValue{}, fmt.Errorf("get lhs address %s: %w", lhsAddr, err) + } + lhs, err := vm.MemoryManager.Memory.ReadFromAddress(&lhsAddr) + if err != nil { + return memory.MemoryValue{}, fmt.Errorf("read lhs address %s: %w", lhsAddr, err) } - lhs := cell.Read() - // Double deref implies the first value read must be an address + // Double deref implies the left hand side read must be an address address, err := lhs.ToMemoryAddress() if err != nil { - return nil, err + return memory.MemoryValue{}, err } newOffset, overflow := safemath.SafeOffset(address.Offset, dderef.offset) if overflow { - return nil, safemath.NewSafeOffsetError(address.Offset, dderef.offset) + return memory.MemoryValue{}, safemath.NewSafeOffsetError(address.Offset, dderef.offset) + } + resAddr := memory.MemoryAddress{ + SegmentIndex: address.SegmentIndex, + Offset: newOffset, } - resAddr := memory.NewMemoryAddress(address.SegmentIndex, newOffset) - value, err := vm.MemoryManager.Memory.ReadFromAddress(resAddr) + value, err := vm.MemoryManager.Memory.ReadFromAddress(&resAddr) if err != nil { - return nil, fmt.Errorf("read cell: %v", err) + return memory.MemoryValue{}, fmt.Errorf("read result at %s: %w", resAddr, err) } return value, nil @@ -112,7 +118,7 @@ func (imm Immediate) String() string { // todo(rodro): Specs from Starkware stablish this can be uint256 and not a felt. // Should we respect that, or go straight to felt? -func (imm Immediate) Resolve(vm *VM.VirtualMachine) (*memory.MemoryValue, error) { +func (imm Immediate) Resolve(vm *VM.VirtualMachine) (memory.MemoryValue, error) { felt := &f.Element{} bigInt := (big.Int)(imm) // todo(rodro): do we require to check that big int is lesser than P, or do we @@ -139,24 +145,31 @@ func (bop BinaryOp) String() string { return "BinaryOperator" } -func (bop BinaryOp) Resolve(vm *VM.VirtualMachine) (*memory.MemoryValue, error) { - cell, err := bop.lhs.Get(vm) +func (bop BinaryOp) Resolve(vm *VM.VirtualMachine) (memory.MemoryValue, error) { + lhsAddr, err := bop.lhs.Get(vm) + if err != nil { + return memory.MemoryValue{}, fmt.Errorf("get lhs address %s: %w", bop.lhs, err) + } + lhs, err := vm.MemoryManager.Memory.ReadFromAddress(&lhsAddr) if err != nil { - return nil, fmt.Errorf("get lhs operand %s: %v", bop.lhs, err) + return memory.MemoryValue{}, fmt.Errorf("read lhs address %s: %v", lhsAddr, err) } - lhs := cell.Read() rhs, err := bop.rhs.Resolve(vm) if err != nil { - return nil, fmt.Errorf("resolve rhs operand %s: %v", rhs, err) + return memory.MemoryValue{}, fmt.Errorf("resolve rhs operand %s: %v", rhs, err) } switch bop.operator { case Add: - return memory.EmptyMemoryValueAs(lhs.IsAddress()).Add(lhs, rhs) + mv := memory.EmptyMemoryValueAs(lhs.IsAddress() || rhs.IsAddress()) + err := mv.Add(&lhs, &rhs) + return mv, err case Mul: - return memory.EmptyMemoryValueAsFelt().Mul(lhs, rhs) + mv := memory.EmptyMemoryValueAsFelt() + err := mv.Mul(&lhs, &rhs) + return mv, err default: - return nil, fmt.Errorf("unknown binary operator: %d", bop.operator) + return memory.MemoryValue{}, fmt.Errorf("unknown binary operator: %d", bop.operator) } } diff --git a/pkg/hintrunner/operand_test.go b/pkg/hintrunner/operand_test.go index 4e8359461..950e9e2e5 100644 --- a/pkg/hintrunner/operand_test.go +++ b/pkg/hintrunner/operand_test.go @@ -14,12 +14,14 @@ func TestGetAp(t *testing.T) { vm.Context.Ap = 5 writeTo(vm, VM.ExecutionSegment, vm.Context.Ap+7, memory.MemoryValueFromInt(11)) - var apCell ApCellRef = 7 - cell, err := apCell.Get(vm) + var apReg ApCellRef = 7 + apAddr, err := apReg.Get(vm) + + require.NoError(t, err) + value, err := vm.MemoryManager.Memory.ReadFromAddress(&apAddr) require.NoError(t, err) - value := cell.Read() require.Equal(t, memory.MemoryValueFromInt(11), value) } @@ -28,12 +30,13 @@ func TestGetFp(t *testing.T) { vm.Context.Fp = 15 writeTo(vm, VM.ExecutionSegment, vm.Context.Fp-7, memory.MemoryValueFromInt(11)) - var fpCell FpCellRef = -7 - cell, err := fpCell.Get(vm) + var fpReg FpCellRef = -7 + fpAddr, err := fpReg.Get(vm) + require.NoError(t, err) + value, err := vm.MemoryManager.Memory.ReadFromAddress(&fpAddr) require.NoError(t, err) - value := cell.Read() require.Equal(t, memory.MemoryValueFromInt(11), value) } @@ -46,8 +49,8 @@ func TestResolveDeref(t *testing.T) { deref := Deref{apCell} value, err := deref.Resolve(vm) - require.NoError(t, err) + require.Equal(t, memory.MemoryValueFromInt(11), value) } diff --git a/pkg/hintrunner/testutils.go b/pkg/hintrunner/testutils.go index 83a77dcf9..b94e9b85d 100644 --- a/pkg/hintrunner/testutils.go +++ b/pkg/hintrunner/testutils.go @@ -11,11 +11,11 @@ func defaultVirtualMachine() *VM.VirtualMachine { return vm } -func writeTo(vm *VM.VirtualMachine, segment uint64, offset uint64, val *memory.MemoryValue) { - _ = vm.MemoryManager.Memory.Write(segment, offset, val) +func writeTo(vm *VM.VirtualMachine, segment uint64, offset uint64, val memory.MemoryValue) { + _ = vm.MemoryManager.Memory.Write(segment, offset, &val) } -func readFrom(vm *VM.VirtualMachine, segment uint64, offset uint64) *memory.MemoryValue { +func readFrom(vm *VM.VirtualMachine, segment uint64, offset uint64) memory.MemoryValue { val, _ := vm.MemoryManager.Memory.Read(segment, offset) return val } diff --git a/pkg/runners/zero/zero.go b/pkg/runners/zero/zero.go index 6dfa3f5ed..a5dc8ccad 100644 --- a/pkg/runners/zero/zero.go +++ b/pkg/runners/zero/zero.go @@ -33,7 +33,6 @@ func NewRunner(program *Program, proofmode bool, maxsteps uint64) (*ZeroRunner, return nil, fmt.Errorf("runner error: %w", err) } - // intialize hintrunner // todo(rodro): given the program get the appropiate hints hintrunner := hintrunner.NewHintRunner(make(map[uint64]hintrunner.Hinter)) @@ -57,7 +56,7 @@ func (runner *ZeroRunner) Run() error { return fmt.Errorf("initializing main entry point: %w", err) } - err = runner.RunUntilPc(end) + err = runner.RunUntilPc(&end) if err != nil { return err } @@ -77,78 +76,83 @@ func (runner *ZeroRunner) Run() error { return nil } -func (runner *ZeroRunner) InitializeMainEntrypoint() (*memory.MemoryAddress, error) { +func (runner *ZeroRunner) InitializeMainEntrypoint() (memory.MemoryAddress, error) { if runner.proofmode { startPc, ok := runner.program.Labels["__start__"] if !ok { - return nil, errors.New("start label not found. Try compiling with `--proof_mode`") + return memory.UnknownValue, errors.New("start label not found. Try compiling with `--proof_mode`") } endPc, ok := runner.program.Labels["__end__"] if !ok { - return nil, errors.New("end label not found. Try compiling with `--proof_mode`") + return memory.UnknownValue, errors.New("end label not found. Try compiling with `--proof_mode`") } offset := runner.segments()[VM.ExecutionSegment].Len() + dummyFPValue := memory.MemoryValueFromSegmentAndOffset( + VM.ProgramSegment, + runner.segments()[VM.ProgramSegment].Len()+offset+2, + ) // set dummy fp value err := runner.memory().Write( VM.ExecutionSegment, offset, - memory.MemoryValueFromSegmentAndOffset( - VM.ProgramSegment, - runner.segments()[VM.ProgramSegment].Len()+offset+2, - ), + &dummyFPValue, ) if err != nil { - return nil, err + return memory.UnknownValue, err } + + dummyPCValue := memory.MemoryValueFromUint[uint64](0) // set dummy pc value - err = runner.memory().Write(VM.ExecutionSegment, offset+1, memory.MemoryValueFromUint[uint64](0)) + err = runner.memory().Write(VM.ExecutionSegment, offset+1, &dummyPCValue) if err != nil { - return nil, err + return memory.UnknownValue, err } - runner.vm.Context.Pc = memory.NewMemoryAddress(VM.ProgramSegment, startPc) + runner.vm.Context.Pc = memory.MemoryAddress{SegmentIndex: VM.ProgramSegment, Offset: startPc} runner.vm.Context.Ap = offset + 2 runner.vm.Context.Fp = runner.vm.Context.Ap - return memory.NewMemoryAddress(VM.ProgramSegment, endPc), nil + return memory.MemoryAddress{SegmentIndex: VM.ProgramSegment, Offset: endPc}, nil } returnFp := memory.MemoryValueFromSegmentAndOffset( runner.memory().AllocateEmptySegment(), 0, ) - return runner.InitializeEntrypoint("main", nil, returnFp) + return runner.InitializeEntrypoint("main", nil, &returnFp) } func (runner *ZeroRunner) InitializeEntrypoint( funcName string, arguments []*f.Element, returnFp *memory.MemoryValue, -) (*memory.MemoryAddress, error) { +) (memory.MemoryAddress, error) { segmentIndex := runner.memory().AllocateEmptySegment() - end := memory.NewMemoryAddress(uint64(segmentIndex), 0) + end := memory.MemoryAddress{SegmentIndex: uint64(segmentIndex), Offset: 0} // write arguments for i := range arguments { - err := runner.memory().Write(VM.ExecutionSegment, uint64(i), memory.MemoryValueFromFieldElement(arguments[i])) + v := memory.MemoryValueFromFieldElement(arguments[i]) + err := runner.memory().Write(VM.ExecutionSegment, uint64(i), &v) if err != nil { - return nil, err + return memory.UnknownValue, err } } offset := runner.segments()[VM.ExecutionSegment].Len() err := runner.memory().Write(VM.ExecutionSegment, offset, returnFp) if err != nil { - return nil, err + return memory.UnknownValue, err } - err = runner.memory().Write(VM.ExecutionSegment, offset+1, memory.MemoryValueFromMemoryAddress(end)) + endMV := memory.MemoryValueFromMemoryAddress(&end) + err = runner.memory().Write(VM.ExecutionSegment, offset+1, &endMV) if err != nil { - return nil, err + return memory.UnknownValue, err } pc, ok := runner.program.Entrypoints[funcName] if !ok { - return nil, fmt.Errorf("unknwon entrypoint: %s", funcName) + return memory.UnknownValue, fmt.Errorf("unknwon entrypoint: %s", funcName) } - runner.vm.Context.Pc = memory.NewMemoryAddress(VM.ProgramSegment, pc) + runner.vm.Context.Pc = memory.MemoryAddress{SegmentIndex: VM.ProgramSegment, Offset: pc} runner.vm.Context.Ap = offset + 2 runner.vm.Context.Fp = runner.vm.Context.Ap @@ -166,7 +170,7 @@ func (runner *ZeroRunner) RunUntilPc(pc *memory.MemoryAddress) error { ) } - err := runner.vm.RunStep(runner.hintrunner) + err := runner.vm.RunStep(nil) if err != nil { return fmt.Errorf("pc %s step %d: %w", runner.pc(), runner.steps(), err) } @@ -185,7 +189,7 @@ func (runner *ZeroRunner) RunFor(steps uint64) error { ) } - err := runner.vm.RunStep(runner.hintrunner) + err := runner.vm.RunStep(nil) if err != nil { return fmt.Errorf( "pc %s step %d: %w", @@ -218,7 +222,7 @@ func (runner *ZeroRunner) segments() []*memory.Segment { return runner.vm.MemoryManager.Memory.Segments } -func (runner *ZeroRunner) pc() *memory.MemoryAddress { +func (runner *ZeroRunner) pc() memory.MemoryAddress { return runner.vm.Context.Pc } diff --git a/pkg/runners/zero/zero_test.go b/pkg/runners/zero/zero_test.go index b154407fc..a5eb76d1c 100644 --- a/pkg/runners/zero/zero_test.go +++ b/pkg/runners/zero/zero_test.go @@ -30,11 +30,11 @@ func TestSimpleProgram(t *testing.T) { endPc, err := runner.InitializeMainEntrypoint() require.NoError(t, err) - expectedPc := memory.NewMemoryAddress(3, 0) + expectedPc := memory.MemoryAddress{SegmentIndex: 3, Offset: 0} require.Equal(t, expectedPc, endPc) - err = runner.RunUntilPc(endPc) + err = runner.RunUntilPc(&endPc) require.NoError(t, err) executionSegment := runner.segments()[VM.ExecutionSegment] @@ -43,9 +43,9 @@ func TestSimpleProgram(t *testing.T) { t, createSegment( // return fp - memory.NewMemoryAddress(2, 0), + &memory.MemoryAddress{SegmentIndex: 2, Offset: 0}, // next pc - expectedPc, + &expectedPc, 2, 3, 4, @@ -76,10 +76,10 @@ func TestStepLimitExceeded(t *testing.T) { endPc, err := runner.InitializeMainEntrypoint() require.NoError(t, err) - expectedPc := memory.NewMemoryAddress(3, 0) + expectedPc := memory.MemoryAddress{SegmentIndex: 3, Offset: 0} require.Equal(t, expectedPc, endPc) - err = runner.RunUntilPc(endPc) + err = runner.RunUntilPc(&endPc) require.ErrorContains(t, err, "step limit exceeded") executionSegment := runner.segments()[VM.ExecutionSegment] @@ -88,9 +88,9 @@ func TestStepLimitExceeded(t *testing.T) { t, createSegment( // return fp - memory.NewMemoryAddress(2, 0), + &memory.MemoryAddress{SegmentIndex: 2, Offset: 0}, // next pc - expectedPc, + &expectedPc, 2, 3, 5, @@ -103,7 +103,7 @@ func TestStepLimitExceeded(t *testing.T) { assert.Equal(t, uint64(2), runner.vm.Context.Ap) assert.Equal(t, uint64(2), runner.vm.Context.Fp) // the fourth instruction starts at 0:6 because all previous one have size 2 - assert.Equal(t, memory.NewMemoryAddress(0, 6), runner.vm.Context.Pc) + assert.Equal(t, memory.MemoryAddress{SegmentIndex: 0, Offset: 6}, runner.vm.Context.Pc) // step limit exceeded assert.Equal(t, uint64(3), runner.steps()) } @@ -140,10 +140,7 @@ func TestStepLimitExceededProofMode(t *testing.T) { t, createSegment( // return fp - memory.NewMemoryAddress( - 0, - uint64(len(program.Bytecode)+2), - ), + &memory.MemoryAddress{SegmentIndex: 0, Offset: uint64(len(program.Bytecode) + 2)}, // next pc 0, 2, @@ -161,7 +158,7 @@ func TestStepLimitExceededProofMode(t *testing.T) { assert.Equal(t, uint64(2), runner.vm.Context.Ap) assert.Equal(t, uint64(2), runner.vm.Context.Fp) // it repeats the last instruction at 0:12 - assert.Equal(t, memory.NewMemoryAddress(0, 12), runner.vm.Context.Pc) + assert.Equal(t, memory.MemoryAddress{SegmentIndex: 0, Offset: 12}, runner.vm.Context.Pc) // step limit exceeded assert.Equal(t, uint64(maxstep), runner.steps()) } @@ -482,14 +479,14 @@ func BenchmarkRunnerWithFibonacci(b *testing.B) { } func createSegment(values ...any) *memory.Segment { - data := make([]*memory.Cell, len(values)) + data := make([]memory.MemoryValue, len(values)) for i := range values { if values[i] != nil { - memVal, err := memory.MemoryValueFromAny(values[i]) + var err error + data[i], err = memory.MemoryValueFromAny(values[i]) if err != nil { panic(err) } - data[i] = &memory.Cell{Value: memVal, Accessed: true} } } return &memory.Segment{ diff --git a/pkg/vm/memory/memory.go b/pkg/vm/memory/memory.go index 33d84d154..2949b33f2 100644 --- a/pkg/vm/memory/memory.go +++ b/pkg/vm/memory/memory.go @@ -7,43 +7,8 @@ import ( f "github.com/consensys/gnark-crypto/ecc/stark-curve/fp" ) -// Represents a write-once Memory Cell -type Cell struct { - Value *MemoryValue - Accessed bool -} - -func (cell *Cell) Write(value *MemoryValue) error { - if cell.Accessed && cell.Value != nil && !cell.Value.Equal(value) { - return fmt.Errorf( - "rewriting cell: old value: %s, new value: %s", - cell.Value.String(), - value.String(), - ) - } - - cell.Accessed = true - cell.Value = value - return nil -} - -func (cell *Cell) Read() *MemoryValue { - cell.Accessed = true - if cell.Value == nil { - cell.Value = EmptyMemoryValueAsFelt() - } - return cell.Value -} - -func (cell *Cell) String() string { - if !cell.Accessed { - return "-" - } - return cell.Value.String() -} - type Segment struct { - Data []*Cell + Data []MemoryValue // the max index where a value was written LastIndex int } @@ -51,21 +16,21 @@ type Segment struct { func EmptySegment() *Segment { // empty segments have capacity 100 as a default return &Segment{ - Data: make([]*Cell, 0, 100), + Data: make([]MemoryValue, 0, 100), LastIndex: -1, } } func EmptySegmentWithCapacity(capacity int) *Segment { return &Segment{ - Data: make([]*Cell, 0, capacity), + Data: make([]MemoryValue, 0, capacity), LastIndex: -1, } } func EmptySegmentWithLength(length int) *Segment { return &Segment{ - Data: make([]*Cell, length), + Data: make([]MemoryValue, length), LastIndex: length - 1, } } @@ -89,43 +54,42 @@ func (segment *Segment) Write(offset uint64, value *MemoryValue) error { if offset >= segment.Len() { segment.LastIndex = int(offset) } - if segment.Data[offset] == nil { - segment.Data[offset] = &Cell{} - } - err := segment.Data[offset].Write(value) - if err != nil { - return fmt.Errorf("write cell at segment offset %d: %v", offset, err) + cell := &segment.Data[offset] + if cell.Known() && !cell.Equal(value) { + return fmt.Errorf( + "rewriting cell: old value: %s, new value: %s", + cell.String(), + value.String(), + ) } + segment.Data[offset] = *value return nil } // Reads a memory value from a specified offset at the segment -func (segment *Segment) Read(offset uint64) *MemoryValue { +func (segment *Segment) Read(offset uint64) MemoryValue { if offset >= segment.RealLen() { segment.IncreaseSegmentSize(offset + 1) } if offset > segment.Len() { segment.LastIndex = int(offset) } - if segment.Data[offset] == nil { - segment.Data[offset] = &Cell{} - } - return segment.Data[offset].Read() + cell := &segment.Data[offset] + if !cell.Known() { + *cell = EmptyMemoryValueAsFelt() + } + return *cell } -func (segment *Segment) Peek(offset uint64) *Cell { +func (segment *Segment) Peek(offset uint64) MemoryValue { if offset >= segment.RealLen() { segment.IncreaseSegmentSize(offset + 1) } if offset >= segment.Len() { segment.LastIndex = int(offset) } - - if segment.Data[offset] == nil { - segment.Data[offset] = &Cell{} - } return segment.Data[offset] } @@ -140,11 +104,11 @@ func (segment *Segment) IncreaseSegmentSize(newSize uint64) { )) } - var newSegmentData []*Cell + var newSegmentData []MemoryValue if cap(segmentData) > int(newSize) { newSegmentData = segmentData[:cap(segmentData)] } else { - newSegmentData = make([]*Cell, safemath.Max(newSize, uint64(len(segmentData)*2))) + newSegmentData = make([]MemoryValue, safemath.Max(newSize, uint64(len(segmentData)*2))) copy(newSegmentData, segmentData) } segment.Data = newSegmentData @@ -176,7 +140,7 @@ func (segment *Segment) String() string { if i < int(segment.Len())-5 { continue } - if segment.Data[i].Accessed { + if segment.Data[i].Known() { header += fmt.Sprintf("[%d]-> %s\n", i, segment.Data[i].String()) } } @@ -202,7 +166,7 @@ func (memory *Memory) AllocateSegment(data []*f.Element) (int, error) { newSegment := EmptySegmentWithLength(len(data)) for i := range data { memVal := MemoryValueFromFieldElement(data[i]) - err := newSegment.Write(uint64(i), memVal) + err := newSegment.Write(uint64(i), &memVal) if err != nil { return 0, err } @@ -233,9 +197,9 @@ func (memory *Memory) WriteToAddress(address *MemoryAddress, value *MemoryValue) // Reads a memory value given the segment index and offset. Errors if reading from // an unallocated space. If reading a cell which hasn't been accesed before, it is // initalized with its default zero value -func (memory *Memory) Read(segmentIndex uint64, offset uint64) (*MemoryValue, error) { +func (memory *Memory) Read(segmentIndex uint64, offset uint64) (MemoryValue, error) { if segmentIndex > uint64(len(memory.Segments)) { - return nil, fmt.Errorf("unallocated segment at index %d", segmentIndex) + return MemoryValue{}, fmt.Errorf("unallocated segment at index %d", segmentIndex) } return memory.Segments[segmentIndex].Read(offset), nil } @@ -243,19 +207,19 @@ func (memory *Memory) Read(segmentIndex uint64, offset uint64) (*MemoryValue, er // Reads a memory value from a memory address. Errors if reading from an unallocated // space. If reading a cell which hasn't been accesed before, it is initalized with // its default zero value -func (memory *Memory) ReadFromAddress(address *MemoryAddress) (*MemoryValue, error) { +func (memory *Memory) ReadFromAddress(address *MemoryAddress) (MemoryValue, error) { return memory.Read(address.SegmentIndex, address.Offset) } // Given a segment index and offset returns a pointer to the Memory Cell -func (memory *Memory) Peek(segmentIndex uint64, offset uint64) (*Cell, error) { +func (memory *Memory) Peek(segmentIndex uint64, offset uint64) (MemoryValue, error) { if segmentIndex > uint64(len(memory.Segments)) { - return nil, fmt.Errorf("unallocated segment at index %d", segmentIndex) + return MemoryValue{}, fmt.Errorf("unallocated segment at index %d", segmentIndex) } return memory.Segments[segmentIndex].Peek(offset), nil } // Given a Memory Address returns a pointer to the Memory Cell -func (memory *Memory) PeekFromAddress(address *MemoryAddress) (*Cell, error) { +func (memory *Memory) PeekFromAddress(address *MemoryAddress) (MemoryValue, error) { return memory.Peek(address.SegmentIndex, address.Offset) } diff --git a/pkg/vm/memory/memory_manager.go b/pkg/vm/memory/memory_manager.go index 6f93d2e15..9922d18bb 100644 --- a/pkg/vm/memory/memory_manager.go +++ b/pkg/vm/memory/memory_manager.go @@ -43,15 +43,15 @@ func (mm *MemoryManager) RelocateMemory() []*f.Element { // fmt.Printf("s: %s", segment) for j := uint64(0); j < segment.Len(); j++ { cell := segment.Data[j] - if cell == nil || !cell.Accessed { + if !cell.Known() { continue } var felt *f.Element - if cell.Value.IsAddress() { - felt = cell.Value.address.Relocate(segmentsOffsets) + if cell.IsAddress() { + felt = cell.address.Relocate(segmentsOffsets) } else { - felt = cell.Value.felt + felt = &cell.felt } relocatedMemory[segmentsOffsets[i]+j] = felt diff --git a/pkg/vm/memory/memory_manager_test.go b/pkg/vm/memory/memory_manager_test.go index ed0c444a1..cb06383b0 100644 --- a/pkg/vm/memory/memory_manager_test.go +++ b/pkg/vm/memory/memory_manager_test.go @@ -70,20 +70,20 @@ func TestMemoryRelocationWithAddress(t *testing.T) { []memoryWrite{ // segment zero {0, 1, uint64(1)}, - {0, 3, NewMemoryAddress(1, 5)}, + {0, 3, &MemoryAddress{1, 5}}, // segment one {1, 0, uint64(1)}, - {1, 1, NewMemoryAddress(4, 3)}, + {1, 1, &MemoryAddress{4, 3}}, {1, 2, uint64(7)}, {1, 5, uint64(13)}, // segment two - {2, 0, NewMemoryAddress(0, 1)}, + {2, 0, &MemoryAddress{0, 1}}, // segment three - {3, 0, NewMemoryAddress(2, 0)}, + {3, 0, &MemoryAddress{2, 0}}, // segment four - {4, 0, NewMemoryAddress(0, 0)}, - {4, 1, NewMemoryAddress(1, 1)}, - {4, 2, NewMemoryAddress(1, 5)}, + {4, 0, &MemoryAddress{0, 0}}, + {4, 1, &MemoryAddress{1, 1}}, + {4, 2, &MemoryAddress{1, 5}}, {4, 3, uint64(15)}, }, ) @@ -141,7 +141,7 @@ func updateMemoryWithValues(memory *Memory, valuesToWrite []memoryWrite) { } // write the memory val - err = memory.Write(toWrite.SegmentIndex, toWrite.Offset, val) + err = memory.Write(toWrite.SegmentIndex, toWrite.Offset, &val) if err != nil { panic(err) } diff --git a/pkg/vm/memory/memory_test.go b/pkg/vm/memory/memory_test.go index 9369a8bea..18681dd45 100644 --- a/pkg/vm/memory/memory_test.go +++ b/pkg/vm/memory/memory_test.go @@ -7,55 +7,20 @@ import ( "github.com/stretchr/testify/require" ) -func TestCellWrite(t *testing.T) { - cell := Cell{} - - err := cell.Write(MemoryValueFromInt(1)) // Write 1 to a new cell - - assert.NoError(t, err) - assert.True(t, cell.Accessed) - assert.Equal(t, cell.Value, MemoryValueFromInt(1)) - - //Attemp to write again to the same cell - err = cell.Write(MemoryValueFromInt(51)) - assert.Error(t, err) - assert.True(t, cell.Accessed) - assert.Equal(t, cell.Value, MemoryValueFromInt(1)) //check that the value didn't change -} - -func TestCellRead(t *testing.T) { - cell := Cell{Accessed: false, Value: nil} - assert.Equal(t, cell.Read(), EmptyMemoryValueAsFelt()) //Read from empty cell - assert.True(t, cell.Accessed) - - cell = Cell{Accessed: false, Value: MemoryValueFromInt(51)} - assert.Equal(t, cell.Read(), MemoryValueFromInt(51)) - assert.True(t, cell.Accessed) -} - -func TestCellWriteAndRead(t *testing.T) { - cell := Cell{} - - err := cell.Write(MemoryValueFromInt(82)) - - assert.NoError(t, err) - assert.True(t, cell.Accessed) - assert.Equal(t, cell.Read(), MemoryValueFromInt(82)) -} - func TestSegmentRead(t *testing.T) { - segment := Segment{Data: []*Cell{ - {Accessed: false, Value: MemoryValueFromInt(3)}, - {Accessed: false, Value: MemoryValueFromInt(5)}, - {Accessed: true, Value: MemoryValueFromInt(9)}, + segment := Segment{Data: []MemoryValue{ + MemoryValueFromInt(3), + MemoryValueFromInt(5), + {}, }} assert.Equal(t, segment.Read(0), MemoryValueFromInt(3)) assert.Equal(t, segment.Read(1), MemoryValueFromInt(5)) - assert.Equal(t, segment.Read(2), MemoryValueFromInt(9)) - assert.True(t, segment.Data[0].Accessed) //Segment read should mark cell as accessed - assert.True(t, segment.Data[1].Accessed) - assert.True(t, segment.Data[2].Accessed) + assert.False(t, segment.Data[2].Known()) + assert.Equal(t, segment.Read(2), EmptyMemoryValueAsFelt()) + assert.True(t, segment.Data[0].Known()) //Segment read should mark cell as accessed + assert.True(t, segment.Data[1].Known()) + assert.True(t, segment.Data[2].Known()) assert.Equal(t, len(segment.Data), 3) //Check if we can read offsets higher than segment len @@ -64,61 +29,56 @@ func TestSegmentRead(t *testing.T) { } func TestSegmentPeek(t *testing.T) { - segment := Segment{Data: []*Cell{ - {Accessed: false, Value: MemoryValueFromInt(2)}, - {Accessed: true, Value: MemoryValueFromInt(4)}, + segment := Segment{Data: []MemoryValue{ + MemoryValueFromInt(2), + {}, }} - assert.Equal(t, segment.Peek(0).Value, MemoryValueFromInt(2)) - assert.Equal(t, segment.Peek(1).Value, MemoryValueFromInt(4)) - assert.False(t, segment.Data[0].Accessed) //Peek should not mark the cell as accessed - assert.True(t, segment.Data[1].Accessed) //Cell that was already accessed should stay accessed + assert.Equal(t, segment.Peek(0), MemoryValueFromInt(2)) + assert.Equal(t, segment.Peek(1), MemoryValue{}) + assert.True(t, segment.Data[0].Known()) //Cell that was already accessed should stay accessed + assert.False(t, segment.Data[1].Known()) //Peek should not mark the cell as accessed assert.Equal(t, len(segment.Data), 2) //Check if we can peek offsets higher than segment len - assert.Equal(t, segment.Peek(30).Read(), EmptyMemoryValueAsFelt()) + assert.Equal(t, segment.Peek(30), MemoryValue{}) assert.Equal(t, len(segment.Data), 31) //Verify that segment len was increased } func TestSegmentWrite(t *testing.T) { - segment := Segment{Data: []*Cell{ - {Accessed: false, Value: nil}, - {Accessed: false, Value: nil}, - }} + segment := Segment{Data: make([]MemoryValue, 2)} - err := segment.Write(0, MemoryValueFromInt(100)) + err := segment.Write(0, UseInTestOnlyMemoryValuePointerFromInt(100)) assert.NoError(t, err) - assert.Equal(t, segment.Data[0].Value, MemoryValueFromInt(100)) - assert.True(t, segment.Data[0].Accessed) - assert.False(t, segment.Data[1].Accessed) //Check that the other cell wasn't marked as accessed + assert.Equal(t, segment.Data[0], MemoryValueFromInt(100)) + assert.True(t, segment.Data[0].Known()) + assert.False(t, segment.Data[1].Known()) //Check that the other cell wasn't marked as accessed - err = segment.Write(1, MemoryValueFromInt(15)) + err = segment.Write(1, UseInTestOnlyMemoryValuePointerFromInt(15)) assert.NoError(t, err) - assert.Equal(t, segment.Data[1].Value, MemoryValueFromInt(15)) - assert.True(t, segment.Data[1].Accessed) + assert.Equal(t, segment.Data[1], MemoryValueFromInt(15)) + assert.True(t, segment.Data[1].Known()) //Atempt to write twice - err = segment.Write(0, MemoryValueFromInt(590)) + err = segment.Write(0, UseInTestOnlyMemoryValuePointerFromInt(590)) assert.Error(t, err) //Check that memory wasn't modified assert.Equal(t, segment.Read(0), MemoryValueFromInt(100)) - assert.True(t, segment.Peek(0).Accessed) + assert.True(t, segment.Data[0].Known()) } func TestSegmentReadAndWrite(t *testing.T) { - segment := Segment{Data: []*Cell{ - {Accessed: false, Value: nil}, - }} - err := segment.Write(0, MemoryValueFromInt(48)) + segment := Segment{Data: make([]MemoryValue, 1)} + err := segment.Write(0, UseInTestOnlyMemoryValuePointerFromInt(48)) assert.NoError(t, err) assert.Equal(t, segment.Read(0), MemoryValueFromInt(48)) - assert.True(t, segment.Peek(0).Accessed) + assert.True(t, segment.Data[0].Known()) } func TestIncreaseSegmentSizeSmallerSize(t *testing.T) { - segment := Segment{Data: []*Cell{ - {Accessed: true, Value: MemoryValueFromInt(1)}, - {Accessed: true, Value: MemoryValueFromInt(2)}, + segment := Segment{Data: []MemoryValue{ + MemoryValueFromInt(1), + MemoryValueFromInt(2), }} // Panic if we decrase the size require.Panics(t, func() { segment.IncreaseSegmentSize(0) }) @@ -127,10 +87,10 @@ func TestIncreaseSegmentSizeSmallerSize(t *testing.T) { } func TestIncreaseSegmentSizeMaxNewSize(t *testing.T) { - segment := Segment{Data: []*Cell{ - {Accessed: true, Value: MemoryValueFromInt(1)}, - {Accessed: true, Value: MemoryValueFromInt(2)}, - {Accessed: true, Value: MemoryValueFromInt(3)}, + segment := Segment{Data: []MemoryValue{ + MemoryValueFromInt(1), + MemoryValueFromInt(2), + MemoryValueFromInt(3), }} segment.IncreaseSegmentSize(1000) @@ -144,9 +104,9 @@ func TestIncreaseSegmentSizeMaxNewSize(t *testing.T) { } func TestIncreaseSegmentSizeDouble(t *testing.T) { - segment := Segment{Data: []*Cell{ - {Accessed: true, Value: MemoryValueFromInt(1)}, - {Accessed: true, Value: MemoryValueFromInt(2)}, + segment := Segment{Data: []MemoryValue{ + MemoryValueFromInt(1), + MemoryValueFromInt(2), }} segment.IncreaseSegmentSize(3) @@ -157,49 +117,22 @@ func TestIncreaseSegmentSizeDouble(t *testing.T) { assert.Equal(t, segment.Read(0), MemoryValueFromInt(1)) assert.Equal(t, segment.Read(1), MemoryValueFromInt(2)) } - -func TestIncreaseSegmentKeepReference(t *testing.T) { - segment := Segment{Data: []*Cell{ - {Accessed: true, Value: MemoryValueFromInt(1)}, - {Accessed: true, Value: MemoryValueFromInt(2)}, - {Accessed: true, Value: MemoryValueFromInt(3)}, - }} - segment.IncreaseSegmentSize(4) - require.Equal(t, len(segment.Data), 6) - require.Equal(t, cap(segment.Data), 6) - - fourthCell := segment.Peek(5) - - segment.IncreaseSegmentSize(8) - require.Equal(t, len(segment.Data), 12) - require.Equal(t, cap(segment.Data), 12) - - err := fourthCell.Write(MemoryValueFromInt(5)) - require.NoError(t, err) - - //Make sure no data was lost after incrase - assert.Equal(t, MemoryValueFromInt(1), segment.Read(0)) - assert.Equal(t, MemoryValueFromInt(2), segment.Read(1)) - assert.Equal(t, MemoryValueFromInt(3), segment.Read(2)) - assert.Equal(t, MemoryValueFromInt(5), segment.Read(5)) -} - func TestMemoryWriteAndRead(t *testing.T) { memory := InitializeEmptyMemory() memory.AllocateEmptySegment() - err := memory.Write(0, 0, MemoryValueFromInt(123)) + err := memory.Write(0, 0, UseInTestOnlyMemoryValuePointerFromInt(123)) assert.NoError(t, err) val, err := memory.Read(0, 0) assert.NoError(t, err) assert.Equal(t, val, MemoryValueFromInt(123)) //Attempt to write twice segment and offset - err = memory.Write(0, 0, MemoryValueFromInt(321)) + err = memory.Write(0, 0, UseInTestOnlyMemoryValuePointerFromInt(321)) assert.Error(t, err) //Attempt to write twice using address - err = memory.WriteToAddress(&MemoryAddress{0, 0}, MemoryValueFromInt(542)) + err = memory.WriteToAddress(&MemoryAddress{0, 0}, UseInTestOnlyMemoryValuePointerFromInt(542)) assert.Error(t, err) //Verify data wasn't modified @@ -208,7 +141,7 @@ func TestMemoryWriteAndRead(t *testing.T) { assert.Equal(t, val, MemoryValueFromInt(123)) addr := MemoryAddress{0, 6} - err = memory.WriteToAddress(&addr, MemoryValueFromInt(31)) + err = memory.WriteToAddress(&addr, UseInTestOnlyMemoryValuePointerFromInt(31)) assert.NoError(t, err) val, err = memory.Read(0, 6) assert.NoError(t, err) @@ -228,14 +161,14 @@ func TestMemoryReadOutOfRange(t *testing.T) { func TestMemoryPeek(t *testing.T) { memory := InitializeEmptyMemory() memory.AllocateEmptySegment() - err := memory.Write(0, 1, MemoryValueFromInt(412)) + err := memory.Write(0, 1, UseInTestOnlyMemoryValuePointerFromInt(412)) assert.NoError(t, err) cell, err := memory.Peek(0, 1) assert.NoError(t, err) - assert.Equal(t, cell.Value, MemoryValueFromInt(412)) + assert.Equal(t, cell, MemoryValueFromInt(412)) cell, err = memory.PeekFromAddress(&MemoryAddress{0, 1}) assert.NoError(t, err) - assert.Equal(t, cell.Value, MemoryValueFromInt(412)) + assert.Equal(t, cell, MemoryValueFromInt(412)) } diff --git a/pkg/vm/memory/memory_value.go b/pkg/vm/memory/memory_value.go index 248734074..f50af4be7 100644 --- a/pkg/vm/memory/memory_value.go +++ b/pkg/vm/memory/memory_value.go @@ -16,31 +16,28 @@ type MemoryAddress struct { Offset uint64 } -// Creates a new memory address -func NewMemoryAddress(segment uint64, offset uint64) *MemoryAddress { - return &MemoryAddress{SegmentIndex: segment, Offset: offset} -} +var UnknownValue = MemoryAddress{} func (address *MemoryAddress) Equal(other *MemoryAddress) bool { return address.SegmentIndex == other.SegmentIndex && address.Offset == other.Offset } // Adds a memory address and a field element -func (address *MemoryAddress) Add(lhs *MemoryAddress, rhs *f.Element) (*MemoryAddress, error) { +func (address *MemoryAddress) Add(lhs *MemoryAddress, rhs *f.Element) error { lhsOffset := new(f.Element).SetUint64(lhs.Offset) newOffset := new(f.Element).Add(lhsOffset, rhs) if !newOffset.IsUint64() { - return nil, fmt.Errorf("new offset bigger than uint64: %s", rhs.Text(10)) + return fmt.Errorf("new offset bigger than uint64: %s", rhs.Text(10)) } address.SegmentIndex = lhs.SegmentIndex address.Offset = newOffset.Uint64() - return address, nil + return nil } // Subs from a memory address a felt or another memory address in the same segment -func (address *MemoryAddress) Sub(lhs *MemoryAddress, rhs any) (*MemoryAddress, error) { +func (address *MemoryAddress) Sub(lhs *MemoryAddress, rhs any) error { // First match segment index address.SegmentIndex = lhs.SegmentIndex @@ -48,32 +45,32 @@ func (address *MemoryAddress) Sub(lhs *MemoryAddress, rhs any) (*MemoryAddress, switch rhs := rhs.(type) { case uint64: if rhs > lhs.Offset { - return nil, errors.New("rhs is greater than lhs offset") + return errors.New("rhs is greater than lhs offset") } address.Offset = lhs.Offset - rhs - return address, nil + return nil case *f.Element: if !rhs.IsUint64() { - return nil, fmt.Errorf("rhs field element does not fit in uint64: %s", rhs) + return fmt.Errorf("rhs field element does not fit in uint64: %s", rhs) } feltRhs64 := rhs.Uint64() if feltRhs64 > lhs.Offset { - return nil, fmt.Errorf("rhs %d is greater than lhs offset %d", feltRhs64, lhs.Offset) + return fmt.Errorf("rhs %d is greater than lhs offset %d", feltRhs64, lhs.Offset) } address.Offset = lhs.Offset - feltRhs64 - return address, nil + return nil case *MemoryAddress: if lhs.SegmentIndex != rhs.SegmentIndex { - return nil, fmt.Errorf("addresses are in different segments: rhs is in %d, lhs is in %d", + return fmt.Errorf("addresses are in different segments: rhs is in %d, lhs is in %d", rhs.SegmentIndex, lhs.SegmentIndex) } if rhs.Offset > lhs.Offset { - return nil, fmt.Errorf("rhs offset %d is greater than lhs offset %d", rhs.Offset, lhs.Offset) + return fmt.Errorf("rhs offset %d is greater than lhs offset %d", rhs.Offset, lhs.Offset) } address.Offset = lhs.Offset - rhs.Offset - return address, nil + return nil default: - return nil, fmt.Errorf("unknown rhs type: %T", rhs) + return fmt.Errorf("unknown rhs type: %T", rhs) } } @@ -96,50 +93,53 @@ func (address MemoryAddress) String() string { // // - either a Felt value (an `f.Element`), // - or a pointer to another Memory Cell (a `MemoryAddress`) -// -// Both members cannot be non-nil at the same time type MemoryValue struct { - felt *f.Element - address *MemoryAddress + felt f.Element + address MemoryAddress + isFelt bool + isAddress bool } -func MemoryValueFromMemoryAddress(address *MemoryAddress) *MemoryValue { - return &MemoryValue{ - address: address, +func MemoryValueFromMemoryAddress(address *MemoryAddress) MemoryValue { + return MemoryValue{ + address: *address, + isAddress: true, } } -func MemoryValueFromFieldElement(felt *f.Element) *MemoryValue { - return &MemoryValue{ - felt: felt, +func MemoryValueFromFieldElement(felt *f.Element) MemoryValue { + return MemoryValue{ + felt: *felt, + isFelt: true, } } -func MemoryValueFromInt[T constraints.Integer](v T) *MemoryValue { +func MemoryValueFromInt[T constraints.Integer](v T) MemoryValue { if v >= 0 { return MemoryValueFromUint(uint64(v)) } - lhs := &f.Element{} - rhs := new(f.Element).SetUint64(uint64(-v)) - return &MemoryValue{ - felt: new(f.Element).Sub(lhs, rhs), - } + + value := MemoryValue{isFelt: true} + rhs := f.NewElement(uint64(-v)) + value.felt.Sub(&value.felt, &rhs) + return value } -func MemoryValueFromUint[T constraints.Unsigned](v T) *MemoryValue { - newElement := f.NewElement(uint64(v)) - return &MemoryValue{ - felt: &newElement, +func MemoryValueFromUint[T constraints.Unsigned](v T) MemoryValue { + return MemoryValue{ + felt: f.NewElement(uint64(v)), + isFelt: true, } } -func MemoryValueFromSegmentAndOffset[T constraints.Integer](segmentIndex, offset T) *MemoryValue { - return &MemoryValue{ - address: &MemoryAddress{SegmentIndex: uint64(segmentIndex), Offset: uint64(offset)}, +func MemoryValueFromSegmentAndOffset[T constraints.Integer](segmentIndex, offset T) MemoryValue { + return MemoryValue{ + address: MemoryAddress{SegmentIndex: uint64(segmentIndex), Offset: uint64(offset)}, + isAddress: true, } } -func MemoryValueFromAny(anyType any) (*MemoryValue, error) { +func MemoryValueFromAny(anyType any) (MemoryValue, error) { switch anyType := anyType.(type) { case int: return MemoryValueFromInt(anyType), nil @@ -150,125 +150,116 @@ func MemoryValueFromAny(anyType any) (*MemoryValue, error) { case *MemoryAddress: return MemoryValueFromMemoryAddress(anyType), nil default: - return nil, fmt.Errorf("invalid type to convert to a MemoryValue: %T", anyType) + return MemoryValue{}, fmt.Errorf("invalid type to convert to a MemoryValue: %T", anyType) } } -func EmptyMemoryValueAsFelt() *MemoryValue { - return &MemoryValue{ - felt: new(f.Element), +func EmptyMemoryValueAsFelt() MemoryValue { + return MemoryValue{ + isFelt: true, } } -func EmptyMemoryValueAsAddress() *MemoryValue { - return &MemoryValue{ - address: new(MemoryAddress), +func EmptyMemoryValueAsAddress() MemoryValue { + return MemoryValue{ + isAddress: true, } } -func EmptyMemoryValueAs(address bool) *MemoryValue { - if address { - return EmptyMemoryValueAsAddress() +func EmptyMemoryValueAs(address bool) MemoryValue { + return MemoryValue{ + isAddress: address, + isFelt: !address, } - return EmptyMemoryValueAsFelt() } func (mv *MemoryValue) ToMemoryAddress() (*MemoryAddress, error) { - if mv.address == nil { + if !mv.isAddress { return nil, errors.New("memory value is not an address") } - return mv.address, nil + return &mv.address, nil } func (mv *MemoryValue) ToFieldElement() (*f.Element, error) { - if mv.felt == nil { + if !mv.isFelt { return nil, fmt.Errorf("memory value is not a field element") } - return mv.felt, nil + return &mv.felt, nil } func (mv *MemoryValue) ToAny() any { - if mv.felt != nil { - return mv.felt + if mv.isAddress { + return &mv.address } - return mv.address + return &mv.felt } func (mv *MemoryValue) IsAddress() bool { - return mv.address != nil + return mv.isAddress } func (mv *MemoryValue) IsFelt() bool { - return mv.felt != nil + return mv.isFelt +} + +func (mv *MemoryValue) Known() bool { + return mv.isAddress || mv.isFelt } func (mv *MemoryValue) Equal(other *MemoryValue) bool { if mv.IsAddress() && other.IsAddress() { - return mv.address.Equal(other.address) + return mv.address.Equal(&other.address) } if mv.IsFelt() && other.IsFelt() { - return mv.felt.Equal(other.felt) + return mv.felt.Equal(&other.felt) } return false } // Adds two memory values is the second one is a Felt -func (mv *MemoryValue) Add(lhs, rhs *MemoryValue) (*MemoryValue, error) { - var err error +func (mv *MemoryValue) Add(lhs, rhs *MemoryValue) error { if lhs.IsAddress() { if !rhs.IsFelt() { - return nil, errors.New("rhs is not a felt") - } - mv.address, err = mv.address.Add(lhs.address, rhs.felt) - } else { - if rhs.IsAddress() { - mv.address, err = mv.address.Add(rhs.address, lhs.felt) - } else { - mv.felt = mv.felt.Add(lhs.felt, rhs.felt) + return errors.New("rhs is not a felt") } + return mv.address.Add(&lhs.address, &rhs.felt) } - if err != nil { - return nil, err + if rhs.IsAddress() { + return mv.address.Add(&rhs.address, &lhs.felt) } - return mv, nil + mv.felt.Add(&lhs.felt, &rhs.felt) + return nil } // Subs two memory values if they're in the same segment or the rhs is a Felt. -func (mv *MemoryValue) Sub(lhs, rhs *MemoryValue) (*MemoryValue, error) { - var err error +func (mv *MemoryValue) Sub(lhs, rhs *MemoryValue) error { if lhs.IsAddress() { - mv.address, err = mv.address.Sub(lhs.address, rhs.ToAny()) - } else { - if rhs.IsAddress() { - return nil, errors.New("cannot substract an address from a felt") - } else { - mv.felt = mv.felt.Sub(lhs.felt, rhs.felt) - } + return mv.address.Sub(&lhs.address, rhs.ToAny()) } - if err != nil { - return nil, err + if rhs.IsAddress() { + return errors.New("cannot substract an address from a felt") } - return mv, nil + mv.felt.Sub(&lhs.felt, &rhs.felt) + return nil } -func (mv *MemoryValue) Mul(lhs, rhs *MemoryValue) (*MemoryValue, error) { +func (mv *MemoryValue) Mul(lhs, rhs *MemoryValue) error { if lhs.IsAddress() || rhs.IsAddress() { - return nil, errors.New("cannot multiply memory addresses") + return errors.New("cannot multiply memory addresses") } - mv.felt.Mul(lhs.felt, rhs.felt) - return mv, nil + mv.felt.Mul(&lhs.felt, &rhs.felt) + return nil } -func (mv *MemoryValue) Div(lhs, rhs *MemoryValue) (*MemoryValue, error) { +func (mv *MemoryValue) Div(lhs, rhs *MemoryValue) error { if lhs.IsAddress() || rhs.IsAddress() { - return nil, errors.New("cannot divide memory addresses") + return errors.New("cannot divide memory addresses") } - - mv.felt.Div(lhs.felt, rhs.felt) - return mv, nil + mv.felt.Div(&lhs.felt, &rhs.felt) + return nil } func (mv MemoryValue) String() string { diff --git a/pkg/vm/memory/memory_value_test.go b/pkg/vm/memory/memory_value_test.go index fe1140323..b078976f7 100644 --- a/pkg/vm/memory/memory_value_test.go +++ b/pkg/vm/memory/memory_value_test.go @@ -4,11 +4,17 @@ import ( "testing" "github.com/stretchr/testify/require" + "golang.org/x/exp/constraints" f "github.com/consensys/gnark-crypto/ecc/stark-curve/fp" "github.com/stretchr/testify/assert" ) +func UseInTestOnlyMemoryValuePointerFromInt[T constraints.Integer](v T) *MemoryValue { + mv := MemoryValueFromInt(v) + return &mv +} + func TestFeltPlusFelt(t *testing.T) { memVal := EmptyMemoryValueAsFelt() lhs := MemoryValueFromFieldElement(new(f.Element).SetUint64(3)) @@ -16,11 +22,9 @@ func TestFeltPlusFelt(t *testing.T) { expected := MemoryValueFromInt(10) - res, err := memVal.Add(lhs, rhs) + err := memVal.Add(&lhs, &rhs) require.NoError(t, err) - - assert.Equal(t, memVal, res) - assert.Equal(t, *expected, *res) + assert.Equal(t, expected, memVal) } func TestMemoryAddressPlusFelt(t *testing.T) { @@ -36,11 +40,9 @@ func TestMemoryAddressPlusFelt(t *testing.T) { Offset: 12, }) - res, err := memVal.Add(lhs, rhs) + err := memVal.Add(&lhs, &rhs) require.NoError(t, err) - - assert.Equal(t, memVal, res) - assert.Equal(t, *expected, *res) + assert.Equal(t, expected, memVal) } func TestFeltPlusMemoryAddress(t *testing.T) { @@ -56,11 +58,9 @@ func TestFeltPlusMemoryAddress(t *testing.T) { Offset: 12, }) - res, err := memVal.Add(lhs, rhs) + err := memVal.Add(&lhs, &rhs) require.NoError(t, err) - - assert.Equal(t, memVal, res) - assert.Equal(t, *expected, *res) + assert.Equal(t, expected, memVal) } func TestMemoryAddressPlusMemoryAddress(t *testing.T) { @@ -73,9 +73,7 @@ func TestMemoryAddressPlusMemoryAddress(t *testing.T) { SegmentIndex: 2, Offset: 2, }) - memVal, err := memVal.Add(lhs, rhs) - - assert.Nil(t, memVal) + err := memVal.Add(&lhs, &rhs) assert.Error(t, err) } @@ -86,11 +84,9 @@ func TestFeltSubFelt(t *testing.T) { expected := MemoryValueFromInt(1) - res, err := memVal.Sub(lhs, rhs) + err := memVal.Sub(&lhs, &rhs) require.NoError(t, err) - - assert.Equal(t, memVal, res) - assert.Equal(t, *expected, *res) + assert.Equal(t, expected, memVal) } func TestMemoryAddressSubFelt(t *testing.T) { @@ -106,11 +102,9 @@ func TestMemoryAddressSubFelt(t *testing.T) { Offset: 8, }) - res, err := memVal.Sub(lhs, rhs) - + err := memVal.Sub(&lhs, &rhs) require.NoError(t, err) - assert.Equal(t, memVal, res) - assert.Equal(t, *expected, *res) + assert.Equal(t, expected, memVal) } func TestFeltSubMemoryAddress(t *testing.T) { @@ -121,9 +115,7 @@ func TestFeltSubMemoryAddress(t *testing.T) { Offset: 10, }) - memVal, err := memVal.Sub(lhs, rhs) - - assert.Nil(t, memVal) + err := memVal.Sub(&lhs, &rhs) assert.Error(t, err) } @@ -142,11 +134,9 @@ func TestMemoryAddressSubMemoryAddressSameSegment(t *testing.T) { Offset: 8, }) - res, err := memVal.Sub(lhs, rhs) + err := memVal.Sub(&lhs, &rhs) require.NoError(t, err) - - assert.Equal(t, memVal, res) - assert.Equal(t, *expected, *res) + assert.Equal(t, expected, memVal) } func TestMemoryAddressSubMemoryAddressDiffSegment(t *testing.T) { @@ -160,9 +150,7 @@ func TestMemoryAddressSubMemoryAddressDiffSegment(t *testing.T) { Offset: 2, }) - memVal, err := memVal.Sub(lhs, rhs) - - assert.Nil(t, memVal) + err := memVal.Sub(&lhs, &rhs) assert.Error(t, err) } diff --git a/pkg/vm/vm.go b/pkg/vm/vm.go index 3966fd9ea..d398f965c 100644 --- a/pkg/vm/vm.go +++ b/pkg/vm/vm.go @@ -24,7 +24,7 @@ type HintRunner interface { // Represents the current execution context of the vm type Context struct { - Pc *mem.MemoryAddress + Pc mem.MemoryAddress Fp uint64 Ap uint64 } @@ -39,16 +39,16 @@ func (ctx *Context) String() string { ) } -func (ctx *Context) AddressAp() *mem.MemoryAddress { - return &mem.MemoryAddress{SegmentIndex: ExecutionSegment, Offset: ctx.Ap} +func (ctx *Context) AddressAp() mem.MemoryAddress { + return mem.MemoryAddress{SegmentIndex: ExecutionSegment, Offset: ctx.Ap} } -func (ctx *Context) AddressFp() *mem.MemoryAddress { - return &mem.MemoryAddress{SegmentIndex: ExecutionSegment, Offset: ctx.Fp} +func (ctx *Context) AddressFp() mem.MemoryAddress { + return mem.MemoryAddress{SegmentIndex: ExecutionSegment, Offset: ctx.Fp} } -func (ctx *Context) AddressPc() *mem.MemoryAddress { - return &mem.MemoryAddress{SegmentIndex: ctx.Pc.SegmentIndex, Offset: ctx.Pc.Offset} +func (ctx *Context) AddressPc() mem.MemoryAddress { + return mem.MemoryAddress{SegmentIndex: ctx.Pc.SegmentIndex, Offset: ctx.Pc.Offset} } // relocates pc, ap and fp to be their real address value @@ -108,7 +108,7 @@ func NewVirtualMachine(programBytecode []*f.Element, config VirtualMachineConfig Context: Context{ Fp: 0, Ap: 0, - Pc: &mem.MemoryAddress{ + Pc: mem.MemoryAddress{ SegmentIndex: ProgramSegment, Offset: 0, }, @@ -124,16 +124,10 @@ func NewVirtualMachine(programBytecode []*f.Element, config VirtualMachineConfig // todo(rodro): add a cache mechanism for not decoding the same instruction twice func (vm *VirtualMachine) RunStep(hintRunner HintRunner) error { - // Run hint - err := hintRunner.RunHint(vm) - if err != nil { - return fmt.Errorf("hint runner: %w", err) - } - // if instruction is not in cache, redecode and store it instruction, ok := vm.instructions[vm.Context.Pc.Offset] if !ok { - memoryValue, err := vm.MemoryManager.Memory.ReadFromAddress(vm.Context.Pc) + memoryValue, err := vm.MemoryManager.Memory.ReadFromAddress(&vm.Context.Pc) if err != nil { return fmt.Errorf("reading instruction: %w", err) } @@ -155,7 +149,7 @@ func (vm *VirtualMachine) RunStep(hintRunner HintRunner) error { vm.Trace = append(vm.Trace, vm.Context) } - err = vm.RunInstruction(instruction) + err := vm.RunInstruction(instruction) if err != nil { return fmt.Errorf("running instruction: %w", err) } @@ -165,48 +159,48 @@ func (vm *VirtualMachine) RunStep(hintRunner HintRunner) error { } func (vm *VirtualMachine) RunInstruction(instruction *Instruction) error { - dstCell, err := vm.getCellDst(instruction) + dstAddr, err := vm.getDstAddr(instruction) if err != nil { return fmt.Errorf("dst cell: %w", err) } - op0Cell, err := vm.getCellOp0(instruction) + op0Addr, err := vm.getOp0Addr(instruction) if err != nil { return fmt.Errorf("op0 cell: %w", err) } - op1Cell, err := vm.getCellOp1(instruction, op0Cell) + op1Addr, err := vm.getOp1Addr(instruction, &op0Addr) if err != nil { return fmt.Errorf("op1 cell: %w", err) } - res, err := vm.inferOperand(instruction, dstCell, op0Cell, op1Cell) + res, err := vm.inferOperand(instruction, &dstAddr, &op0Addr, &op1Addr) if err != nil { return fmt.Errorf("res infer: %w", err) } - if res == nil { - res, err = vm.computeRes(instruction, op0Cell, op1Cell) + if !res.Known() { + res, err = vm.computeRes(instruction, &op0Addr, &op1Addr) if err != nil { return fmt.Errorf("compute res: %w", err) } } - err = vm.opcodeAssertions(instruction, dstCell, op0Cell, res) + err = vm.opcodeAssertions(instruction, &dstAddr, &op0Addr, &res) if err != nil { return fmt.Errorf("opcode assertions: %w", err) } - nextPc, err := vm.updatePc(instruction, dstCell, op1Cell, res) + nextPc, err := vm.updatePc(instruction, &dstAddr, &op1Addr, &res) if err != nil { return fmt.Errorf("pc update: %w", err) } - nextAp, err := vm.updateAp(instruction, res) + nextAp, err := vm.updateAp(instruction, &res) if err != nil { return fmt.Errorf("ap update: %w", err) } - nextFp, err := vm.updateFp(instruction, dstCell) + nextFp, err := vm.updateFp(instruction, &dstAddr) if err != nil { return fmt.Errorf("fp update: %w", err) } @@ -231,7 +225,7 @@ func (vm *VirtualMachine) Proof() ([]Trace, []*f.Element, error) { return relocatedTrace, relocatedMemory, nil } -func (vm *VirtualMachine) getCellDst(instruction *Instruction) (*mem.Cell, error) { +func (vm *VirtualMachine) getDstAddr(instruction *Instruction) (mem.MemoryAddress, error) { var dstRegister uint64 if instruction.DstRegister == Ap { dstRegister = vm.Context.Ap @@ -241,12 +235,12 @@ func (vm *VirtualMachine) getCellDst(instruction *Instruction) (*mem.Cell, error addr, isOverflow := safemath.SafeOffset(dstRegister, instruction.OffDest) if isOverflow { - return nil, fmt.Errorf("offset overflow: %d + %d", dstRegister, instruction.OffDest) + return mem.UnknownValue, fmt.Errorf("offset overflow: %d + %d", dstRegister, instruction.OffDest) } - return vm.MemoryManager.Memory.Peek(ExecutionSegment, addr) + return mem.MemoryAddress{SegmentIndex: ExecutionSegment, Offset: addr}, nil } -func (vm *VirtualMachine) getCellOp0(instruction *Instruction) (*mem.Cell, error) { +func (vm *VirtualMachine) getOp0Addr(instruction *Instruction) (mem.MemoryAddress, error) { var op0Register uint64 if instruction.Op0Register == Ap { op0Register = vm.Context.Ap @@ -256,21 +250,26 @@ func (vm *VirtualMachine) getCellOp0(instruction *Instruction) (*mem.Cell, error addr, isOverflow := safemath.SafeOffset(op0Register, instruction.OffOp0) if isOverflow { - return nil, fmt.Errorf("offset overflow: %d + %d", op0Register, instruction.OffOp0) + return mem.UnknownValue, fmt.Errorf("offset overflow: %d + %d", op0Register, instruction.OffOp0) } - return vm.MemoryManager.Memory.Peek(ExecutionSegment, addr) + return mem.MemoryAddress{SegmentIndex: ExecutionSegment, Offset: addr}, nil } -func (vm *VirtualMachine) getCellOp1(instruction *Instruction, op0Cell *mem.Cell) (*mem.Cell, error) { - var op1Address *mem.MemoryAddress +func (vm *VirtualMachine) getOp1Addr(instruction *Instruction, op0Addr *mem.MemoryAddress) (mem.MemoryAddress, error) { + var op1Address mem.MemoryAddress switch instruction.Op1Source { case Op0: // in this case Op0 is being used as an address, and must be of unwrapped as it - op0Address, err := op0Cell.Read().ToMemoryAddress() + op0Value, err := vm.MemoryManager.Memory.ReadFromAddress(op0Addr) if err != nil { - return nil, fmt.Errorf("op0 is not an address: %w", err) + return mem.UnknownValue, fmt.Errorf("cannot read op0: %w", err) } - op1Address = mem.NewMemoryAddress(op0Address.SegmentIndex, op0Address.Offset) + + op0Address, err := op0Value.ToMemoryAddress() + if err != nil { + return mem.UnknownValue, fmt.Errorf("op0 is not an address: %w", err) + } + op1Address = mem.MemoryAddress{SegmentIndex: op0Address.SegmentIndex, Offset: op0Address.Offset} case Imm: op1Address = vm.Context.AddressPc() case FpPlusOffOp1: @@ -281,11 +280,10 @@ func (vm *VirtualMachine) getCellOp1(instruction *Instruction, op0Cell *mem.Cell addr, isOverflow := safemath.SafeOffset(op1Address.Offset, instruction.OffOp1) if isOverflow { - return nil, fmt.Errorf("offset overflow: %d + %d", op1Address.Offset, instruction.OffOp1) + return mem.UnknownValue, fmt.Errorf("offset overflow: %d + %d", op1Address.Offset, instruction.OffOp1) } op1Address.Offset = addr - - return vm.MemoryManager.Memory.PeekFromAddress(op1Address) + return op1Address, nil } // when there is an assertion with a substraction or division like : x = y - z @@ -293,104 +291,119 @@ func (vm *VirtualMachine) getCellOp1(instruction *Instruction, op0Cell *mem.Cell // dstCell value and either op0Cell xor op1Cell. This function infers the // unknow operand as well as the `res` auxiliar value func (vm *VirtualMachine) inferOperand( - instruction *Instruction, dstCell *mem.Cell, op0Cell *mem.Cell, op1Cell *mem.Cell, -) (*mem.MemoryValue, error) { + instruction *Instruction, dstAddr *mem.MemoryAddress, op0Addr *mem.MemoryAddress, op1Addr *mem.MemoryAddress, +) (mem.MemoryValue, error) { if instruction.Opcode != AssertEq || - (instruction.Res != AddOperands && instruction.Res != MulOperands) || - (op0Cell.Accessed && op1Cell.Accessed) { - return nil, nil + (instruction.Res != AddOperands && instruction.Res != MulOperands) { + return mem.MemoryValue{}, nil + } + + op0Value, err := vm.MemoryManager.Memory.PeekFromAddress(op0Addr) + if err != nil { + return mem.MemoryValue{}, fmt.Errorf("cannot read op0: %w", err) + } + op1Value, err := vm.MemoryManager.Memory.PeekFromAddress(op1Addr) + if err != nil { + return mem.MemoryValue{}, fmt.Errorf("cannot read op1: %w", err) + } + + if op0Value.Known() && op1Value.Known() { + return mem.MemoryValue{}, nil } - if !dstCell.Accessed { - return nil, fmt.Errorf("dst cell is unknown") + + dstValue, err := vm.MemoryManager.Memory.PeekFromAddress(dstAddr) + if err != nil { + return mem.MemoryValue{}, fmt.Errorf("cannot read dst: %w", err) + } + + if !dstValue.Known() { + return mem.MemoryValue{}, fmt.Errorf("dst cell is unknown") } - var knownOpCell *mem.Cell - var unknownOpCell *mem.Cell - if op0Cell.Accessed { - knownOpCell = op0Cell - unknownOpCell = op1Cell + var knownOpValue mem.MemoryValue + var unknownOpAddr *mem.MemoryAddress + if op0Value.Known() { + knownOpValue = op0Value + unknownOpAddr = op1Addr } else { - knownOpCell = op1Cell - unknownOpCell = op0Cell + knownOpValue = op1Value + unknownOpAddr = op0Addr } - var missingVal *mem.MemoryValue - var err error - dst := dstCell.Read() + var missingVal mem.MemoryValue if instruction.Res == AddOperands { - missingVal, err = mem.EmptyMemoryValueAs(dst.IsAddress()).Sub(dst, knownOpCell.Read()) + missingVal = mem.EmptyMemoryValueAs(dstValue.IsAddress()) + err = missingVal.Sub(&dstValue, &knownOpValue) } else { - missingVal, err = mem.EmptyMemoryValueAsFelt().Div(dst, knownOpCell.Read()) + missingVal = mem.EmptyMemoryValueAsFelt() + err = missingVal.Div(&dstValue, &knownOpValue) } if err != nil { - return nil, err + return mem.MemoryValue{}, err } - err = unknownOpCell.Write(missingVal) - if err != nil { - return nil, err + if err = vm.MemoryManager.Memory.WriteToAddress(unknownOpAddr, &missingVal); err != nil { + return mem.MemoryValue{}, err } - - return dst, nil + return dstValue, nil } func (vm *VirtualMachine) computeRes( - instruction *Instruction, op0Cell *mem.Cell, op1Cell *mem.Cell, -) (*mem.MemoryValue, error) { + instruction *Instruction, op0Addr *mem.MemoryAddress, op1Addr *mem.MemoryAddress, +) (mem.MemoryValue, error) { switch instruction.Res { case Unconstrained: - return nil, nil + return mem.MemoryValue{}, nil case Op1: - return op1Cell.Read(), nil - case AddOperands: - op0 := op0Cell.Read() - op1 := op1Cell.Read() - return mem.EmptyMemoryValueAs(op0.IsAddress()).Add(op0, op1) - case MulOperands: - op0 := op0Cell.Read() - op1 := op1Cell.Read() - return mem.EmptyMemoryValueAsFelt().Mul(op0, op1) - } - - return nil, fmt.Errorf("invalid res flag value: %d", instruction.Res) + return vm.MemoryManager.Memory.ReadFromAddress(op1Addr) + default: + op0, err := vm.MemoryManager.Memory.ReadFromAddress(op0Addr) + if err != nil { + return mem.MemoryValue{}, fmt.Errorf("cannot read op0: %w", err) + } + + op1, err := vm.MemoryManager.Memory.ReadFromAddress(op1Addr) + if err != nil { + return mem.MemoryValue{}, fmt.Errorf("cannot read op1: %w", err) + } + + if instruction.Res == AddOperands { + err = op0.Add(&op0, &op1) + } else if instruction.Res == MulOperands { + err = op0.Mul(&op0, &op1) + } else { + return mem.MemoryValue{}, fmt.Errorf("invalid res flag value: %d", instruction.Res) + } + return op0, err + } } func (vm *VirtualMachine) opcodeAssertions( instruction *Instruction, - dstCell *mem.Cell, - op0Cell *mem.Cell, + dstAddr *mem.MemoryAddress, + op0Addr *mem.MemoryAddress, res *mem.MemoryValue, ) error { switch instruction.Opcode { case Call: + fpAddr := vm.Context.AddressFp() + fpMv := mem.MemoryValueFromMemoryAddress(&fpAddr) // Store at [ap] the current fp - err := dstCell.Write(mem.MemoryValueFromMemoryAddress(vm.Context.AddressFp())) - if err != nil { - return err - } - err = dstCell.Write(mem.MemoryValueFromMemoryAddress(vm.Context.AddressFp())) - if err != nil { + if err := vm.MemoryManager.Memory.WriteToAddress(dstAddr, &fpMv); err != nil { return err } - // Write in [ap + 1] the next instruction to execute - err = op0Cell.Write( - mem.MemoryValueFromSegmentAndOffset( - vm.Context.Pc.SegmentIndex, - vm.Context.Pc.Offset+uint64(instruction.Size()), - ), + apMv := mem.MemoryValueFromSegmentAndOffset( + vm.Context.Pc.SegmentIndex, + vm.Context.Pc.Offset+uint64(instruction.Size()), ) - if err != nil { - return err - } - err = dstCell.Write(mem.MemoryValueFromMemoryAddress(vm.Context.AddressFp())) - if err != nil { + // Write in [ap + 1] the next instruction to execute + if err := vm.MemoryManager.Memory.WriteToAddress(op0Addr, &apMv); err != nil { return err } case AssertEq: // assert that the calculated res is stored in dst - err := dstCell.Write(res) - if err != nil { + if err := vm.MemoryManager.Memory.WriteToAddress(dstAddr, res); err != nil { return err } } @@ -399,45 +412,64 @@ func (vm *VirtualMachine) opcodeAssertions( func (vm *VirtualMachine) updatePc( instruction *Instruction, - dstCell *mem.Cell, - op1Cell *mem.Cell, + dstAddr *mem.MemoryAddress, + op1Addr *mem.MemoryAddress, res *mem.MemoryValue, -) (*mem.MemoryAddress, error) { +) (mem.MemoryAddress, error) { switch instruction.PcUpdate { case NextInstr: - return mem.NewMemoryAddress( - vm.Context.Pc.SegmentIndex, - vm.Context.Pc.Offset+uint64(instruction.Size()), - ), nil + return mem.MemoryAddress{ + SegmentIndex: vm.Context.Pc.SegmentIndex, + Offset: vm.Context.Pc.Offset + uint64(instruction.Size()), + }, nil case Jump: addr, err := res.ToMemoryAddress() if err != nil { - return nil, fmt.Errorf("absolute jump: %w", err) + return mem.UnknownValue, fmt.Errorf("absolute jump: %w", err) } - return addr, nil + return *addr, nil case JumpRel: val, err := res.ToFieldElement() if err != nil { - return nil, fmt.Errorf("relative jump: %w", err) + return mem.UnknownValue, fmt.Errorf("relative jump: %w", err) } - return new(mem.MemoryAddress).Add(vm.Context.Pc, val) + newPc := vm.Context.Pc + err = newPc.Add(&newPc, val) + return newPc, err case Jnz: - dest, err := dstCell.Read().ToFieldElement() + destMv, err := vm.MemoryManager.Memory.ReadFromAddress(dstAddr) if err != nil { - return nil, err + return mem.UnknownValue, err + } + + dest, err := destMv.ToFieldElement() + if err != nil { + return mem.UnknownValue, err } if dest.IsZero() { - return mem.NewMemoryAddress(vm.Context.Pc.SegmentIndex, vm.Context.Pc.Offset+uint64(instruction.Size())), nil + return mem.MemoryAddress{ + SegmentIndex: vm.Context.Pc.SegmentIndex, + Offset: vm.Context.Pc.Offset + uint64(instruction.Size()), + }, nil + } + + op1Mv, err := vm.MemoryManager.Memory.ReadFromAddress(op1Addr) + if err != nil { + return mem.UnknownValue, err } - val, err := op1Cell.Read().ToFieldElement() + + val, err := op1Mv.ToFieldElement() if err != nil { - return nil, err + return mem.UnknownValue, err } - return new(mem.MemoryAddress).Add(vm.Context.Pc, val) + + newPc := vm.Context.Pc + err = newPc.Add(&newPc, val) + return newPc, err } - return nil, fmt.Errorf("unkwon pc update value: %d", instruction.PcUpdate) + return mem.UnknownValue, fmt.Errorf("unkwon pc update value: %d", instruction.PcUpdate) } func (vm *VirtualMachine) updateAp(instruction *Instruction, res *mem.MemoryValue) (uint64, error) { @@ -458,14 +490,19 @@ func (vm *VirtualMachine) updateAp(instruction *Instruction, res *mem.MemoryValu return 0, fmt.Errorf("cannot update ap, unknown ApUpdate flag: %d", instruction.ApUpdate) } -func (vm *VirtualMachine) updateFp(instruction *Instruction, dstCell *mem.Cell) (uint64, error) { +func (vm *VirtualMachine) updateFp(instruction *Instruction, dstAddr *mem.MemoryAddress) (uint64, error) { switch instruction.Opcode { case Call: // [ap] and [ap + 1] are written to memory return vm.Context.Ap + 2, nil case Ret: // [dst] should be a memory address of the form (executionSegment, fp - 2) - dst, err := dstCell.Read().ToMemoryAddress() + destMv, err := vm.MemoryManager.Memory.ReadFromAddress(dstAddr) + if err != nil { + return 0, err + } + + dst, err := destMv.ToMemoryAddress() if err != nil { return 0, fmt.Errorf("ret: %w", err) } diff --git a/pkg/vm/vm_test.go b/pkg/vm/vm_test.go index 47152266f..9dd2e9368 100644 --- a/pkg/vm/vm_test.go +++ b/pkg/vm/vm_test.go @@ -48,11 +48,13 @@ func TestGetCellApDst(t *testing.T) { DstRegister: Ap, } - cell, err := vm.getCellDst(&instruction) + addr, err := vm.getDstAddr(&instruction) require.NoError(t, err) - assert.True(t, cell.Accessed) - assert.Equal(t, mem.MemoryValueFromInt(200), cell.Read()) + mv, err := vm.MemoryManager.Memory.ReadFromAddress(&addr) + require.NoError(t, err) + assert.True(t, mv.Known()) + assert.Equal(t, mem.MemoryValueFromInt(200), mv) } func TestGetCellFpDst(t *testing.T) { @@ -73,11 +75,13 @@ func TestGetCellFpDst(t *testing.T) { DstRegister: Fp, } - cell, err := vm.getCellDst(&instruction) + addr, err := vm.getDstAddr(&instruction) require.NoError(t, err) - assert.True(t, cell.Accessed) - assert.Equal(t, mem.MemoryValueFromInt(123), cell.Read()) + mv, err := vm.MemoryManager.Memory.ReadFromAddress(&addr) + require.NoError(t, err) + assert.True(t, mv.Known()) + assert.Equal(t, mem.MemoryValueFromInt(123), mv) } func TestGetCellDstApNegativeOffset(t *testing.T) { @@ -96,11 +100,13 @@ func TestGetCellDstApNegativeOffset(t *testing.T) { DstRegister: Ap, } - cell, err := vm.getCellDst(&instruction) + addr, err := vm.getDstAddr(&instruction) + require.NoError(t, err) + mv, err := vm.MemoryManager.Memory.ReadFromAddress(&addr) require.NoError(t, err) - assert.True(t, cell.Accessed) - assert.Equal(t, mem.MemoryValueFromInt(100), cell.Read()) + assert.True(t, mv.Known()) + assert.Equal(t, mem.MemoryValueFromInt(100), mv) } func TestGetCellDstFpNegativeOffset(t *testing.T) { @@ -119,10 +125,13 @@ func TestGetCellDstFpNegativeOffset(t *testing.T) { DstRegister: Fp, } - cell, err := vm.getCellDst(&instruction) + addr, err := vm.getDstAddr(&instruction) + require.NoError(t, err) + + mv, err := vm.MemoryManager.Memory.ReadFromAddress(&addr) require.NoError(t, err) - assert.True(t, cell.Accessed) - assert.Equal(t, mem.MemoryValueFromInt(100), cell.Read()) + assert.True(t, mv.Known()) + assert.Equal(t, mem.MemoryValueFromInt(100), mv) } func TestGetApCellOp0(t *testing.T) { @@ -141,11 +150,13 @@ func TestGetApCellOp0(t *testing.T) { Op0Register: Ap, } - cell, err := vm.getCellOp0(&instruction) + addr, err := vm.getOp0Addr(&instruction) require.NoError(t, err) - assert.True(t, cell.Accessed) - assert.Equal(t, mem.MemoryValueFromInt(123), cell.Read()) + mv, err := vm.MemoryManager.Memory.ReadFromAddress(&addr) + require.NoError(t, err) + assert.True(t, mv.Known()) + assert.Equal(t, mem.MemoryValueFromInt(123), mv) } func TestGetImmCellOp1(t *testing.T) { @@ -158,20 +169,21 @@ func TestGetImmCellOp1(t *testing.T) { ) // Prepare vm with dummy values - const offOp1 = 1 // target imm - vm.Context.Pc = mem.NewMemoryAddress(0, 1) // "current instruction" + const offOp1 = 1 // target imm + vm.Context.Pc = mem.MemoryAddress{SegmentIndex: 0, Offset: 1} // "current instruction" instruction := Instruction{ OffOp1: offOp1, Op1Source: Imm, } - cell, err := vm.getCellOp1(&instruction, nil) + addr, err := vm.getOp1Addr(&instruction, nil) require.NoError(t, err) - assert.NotNil(t, cell) - assert.True(t, cell.Accessed) - assert.Equal(t, mem.MemoryValueFromInt(1234), cell.Read()) + mv, err := vm.MemoryManager.Memory.ReadFromAddress(&addr) + require.NoError(t, err) + assert.True(t, mv.Known()) + assert.Equal(t, mem.MemoryValueFromInt(1234), mv) } func TestGetOp0PosCellOp1(t *testing.T) { @@ -183,25 +195,23 @@ func TestGetOp0PosCellOp1(t *testing.T) { newElementPtr(333), // op0+offset }, ) + writeToDataSegment(vm, 0, mem.MemoryValueFromSegmentAndOffset(0, 2)) + op0Addr := mem.MemoryAddress{SegmentIndex: ExecutionSegment, Offset: 0} // Prepare vm with dummy values const offOp1 = 1 // target relative to op0 offset - op0Cell := &mem.Cell{ - Accessed: true, - Value: mem.MemoryValueFromSegmentAndOffset(0, 2), - } - instruction := Instruction{ OffOp1: offOp1, Op1Source: Op0, } - cell, err := vm.getCellOp1(&instruction, op0Cell) + addr, err := vm.getOp1Addr(&instruction, &op0Addr) require.NoError(t, err) - assert.NotNil(t, cell) - assert.True(t, cell.Accessed) - assert.Equal(t, mem.MemoryValueFromInt(333), cell.Read()) + mv, err := vm.MemoryManager.Memory.ReadFromAddress(&addr) + require.NoError(t, err) + assert.True(t, mv.Known()) + assert.Equal(t, mem.MemoryValueFromInt(333), mv) } func TestGetOp0NegCellOp1(t *testing.T) { @@ -213,25 +223,23 @@ func TestGetOp0NegCellOp1(t *testing.T) { newElementPtr(444), // op0 - offset }, ) + writeToDataSegment(vm, 0, mem.MemoryValueFromSegmentAndOffset(0, 4)) + op0Addr := mem.MemoryAddress{SegmentIndex: ExecutionSegment, Offset: 0} // Prepare vm with dummy values const offOp1 = -1 // target relative to op0 offset - op0Cell := &mem.Cell{ - Accessed: true, - Value: mem.MemoryValueFromSegmentAndOffset(0, 4), - } - instruction := Instruction{ OffOp1: offOp1, Op1Source: Op0, } - cell, err := vm.getCellOp1(&instruction, op0Cell) + addr, err := vm.getOp1Addr(&instruction, &op0Addr) require.NoError(t, err) - assert.NotNil(t, cell) - assert.True(t, cell.Accessed) - assert.Equal(t, mem.MemoryValueFromInt(444), cell.Read()) + mv, err := vm.MemoryManager.Memory.ReadFromAddress(&addr) + require.NoError(t, err) + assert.True(t, mv.Known()) + assert.Equal(t, mem.MemoryValueFromInt(444), mv) } func TestGetFpPosCellOp1(t *testing.T) { @@ -247,12 +255,13 @@ func TestGetFpPosCellOp1(t *testing.T) { writeToDataSegment(vm, vm.Context.Fp+2, mem.MemoryValueFromInt(321)) //Write to Execution Segment at Fp+2 - cell, err := vm.getCellOp1(&instruction, nil) + addr, err := vm.getOp1Addr(&instruction, nil) require.NoError(t, err) - assert.NotNil(t, cell) - assert.True(t, cell.Accessed) - assert.Equal(t, mem.MemoryValueFromInt(321), cell.Read()) + mv, err := vm.MemoryManager.Memory.ReadFromAddress(&addr) + require.NoError(t, err) + assert.True(t, mv.Known()) + assert.Equal(t, mem.MemoryValueFromInt(321), mv) } func TestGetFpNegCellOp1(t *testing.T) { @@ -268,12 +277,13 @@ func TestGetFpNegCellOp1(t *testing.T) { writeToDataSegment(vm, vm.Context.Fp-2, mem.MemoryValueFromInt(123)) //Write to Execution Segment at Fp-2 - cell, err := vm.getCellOp1(&instruction, nil) + addr, err := vm.getOp1Addr(&instruction, nil) require.NoError(t, err) - assert.NotNil(t, cell) - assert.True(t, cell.Accessed) - assert.Equal(t, mem.MemoryValueFromInt(123), cell.Read()) + mv, err := vm.MemoryManager.Memory.ReadFromAddress(&addr) + require.NoError(t, err) + assert.True(t, mv.Known()) + assert.Equal(t, mem.MemoryValueFromInt(123), mv) } func TestGetApPosCellOp1(t *testing.T) { @@ -288,12 +298,13 @@ func TestGetApPosCellOp1(t *testing.T) { } writeToDataSegment(vm, vm.Context.Ap+2, mem.MemoryValueFromInt(41)) //Write to Execution Segment at Ap+2 - cell, err := vm.getCellOp1(&instruction, nil) + addr, err := vm.getOp1Addr(&instruction, nil) require.NoError(t, err) - assert.NotNil(t, cell) - assert.True(t, cell.Accessed) - assert.Equal(t, mem.MemoryValueFromInt(41), cell.Read()) + mv, err := vm.MemoryManager.Memory.ReadFromAddress(&addr) + require.NoError(t, err) + assert.True(t, mv.Known()) + assert.Equal(t, mem.MemoryValueFromInt(41), mv) } func TestGetApNegCellOp1(t *testing.T) { @@ -308,12 +319,13 @@ func TestGetApNegCellOp1(t *testing.T) { } writeToDataSegment(vm, vm.Context.Ap-2, mem.MemoryValueFromInt(57)) //Write to Execution Segment at Ap-2 - cell, err := vm.getCellOp1(&instruction, nil) + addr, err := vm.getOp1Addr(&instruction, nil) require.NoError(t, err) - assert.NotNil(t, cell) - assert.True(t, cell.Accessed) - assert.Equal(t, mem.MemoryValueFromInt(57), cell.Read()) + mv, err := vm.MemoryManager.Memory.ReadFromAddress(&addr) + require.NoError(t, err) + assert.True(t, mv.Known()) + assert.Equal(t, mem.MemoryValueFromInt(57), mv) } func TestInferOperandSub(t *testing.T) { @@ -322,84 +334,61 @@ func TestInferOperandSub(t *testing.T) { Opcode: AssertEq, Res: AddOperands, } + writeToDataSegment(vm, 0, mem.MemoryValueFromSegmentAndOffset(3, 15)) //destCell + writeToDataSegment(vm, 1, mem.MemoryValueFromSegmentAndOffset(3, 7)) //op1Cell + dstAddr := mem.MemoryAddress{SegmentIndex: ExecutionSegment, Offset: 0} + op1Addr := mem.MemoryAddress{SegmentIndex: ExecutionSegment, Offset: 1} + op0Addr := mem.MemoryAddress{SegmentIndex: ExecutionSegment, Offset: 2} - dstCell := &mem.Cell{ - Accessed: true, - Value: mem.MemoryValueFromSegmentAndOffset(3, 15), - } - op1Cell := &mem.Cell{ - Accessed: true, - Value: mem.MemoryValueFromSegmentAndOffset(3, 7), - } - - // unknown cell to infer - op0Cell := &mem.Cell{} - expectedOp0Cell := &mem.Cell{ - Accessed: true, - Value: mem.MemoryValueFromSegmentAndOffset(3, 8), - } - - inferedRes, err := vm.inferOperand(&instruction, dstCell, op0Cell, op1Cell) + expectedOp0Vaue := mem.MemoryValueFromSegmentAndOffset(3, 8) + inferedRes, err := vm.inferOperand(&instruction, &dstAddr, &op0Addr, &op1Addr) require.NoError(t, err) + assert.Equal(t, mem.MemoryValueFromSegmentAndOffset(3, 15), inferedRes) - assert.Equal(t, dstCell.Value, inferedRes) - assert.Equal(t, expectedOp0Cell, op0Cell) + op0Value, err := vm.MemoryManager.Memory.PeekFromAddress(&op0Addr) + require.NoError(t, err) + assert.Equal(t, expectedOp0Vaue, op0Value) } func TestComputeAddRes(t *testing.T) { vm := defaultVirtualMachine() + writeToDataSegment(vm, 0, mem.MemoryValueFromSegmentAndOffset(2, 10)) //op0Cell + writeToDataSegment(vm, 1, mem.MemoryValueFromInt(15)) //op1Cell + op0Addr := mem.MemoryAddress{SegmentIndex: ExecutionSegment, Offset: 0} + op1Addr := mem.MemoryAddress{SegmentIndex: ExecutionSegment, Offset: 1} instruction := Instruction{ Res: AddOperands, } - cellOp0 := &mem.Cell{ - Accessed: true, - Value: mem.MemoryValueFromMemoryAddress( - mem.NewMemoryAddress(2, 10), - ), - } - - cellOp1 := &mem.Cell{ - Accessed: true, - Value: mem.MemoryValueFromInt(15), - } - - res, err := vm.computeRes(&instruction, cellOp0, cellOp1) + res, err := vm.computeRes(&instruction, &op0Addr, &op1Addr) require.NoError(t, err) - expected := mem.MemoryValueFromMemoryAddress( - mem.NewMemoryAddress(2, 25), - ) - + expected := mem.MemoryValueFromSegmentAndOffset(2, 25) assert.Equal(t, expected, res) } func TestOpcodeAssertionAssertEq(t *testing.T) { vm := defaultVirtualMachine() + dstAddr := mem.MemoryAddress{SegmentIndex: ExecutionSegment, Offset: 0} instruction := Instruction{ Opcode: AssertEq, } - dstCell := mem.Cell{} - res := mem.MemoryValueFromMemoryAddress(mem.NewMemoryAddress(2, 10)) + res := mem.MemoryValueFromSegmentAndOffset(2, 10) + err := vm.opcodeAssertions(&instruction, &dstAddr, nil, &res) + require.NoError(t, err) - err := vm.opcodeAssertions(&instruction, &dstCell, nil, res) + op0Value, err := vm.MemoryManager.Memory.PeekFromAddress(&dstAddr) require.NoError(t, err) - assert.Equal( - t, - mem.Cell{ - Accessed: true, - Value: mem.MemoryValueFromMemoryAddress(mem.NewMemoryAddress(2, 10))}, - dstCell, - ) + assert.Equal(t, res, op0Value) } func TestUpdatePcNextInstr(t *testing.T) { vm := defaultVirtualMachine() - vm.Context.Pc = mem.NewMemoryAddress(0, 3) + vm.Context.Pc = mem.MemoryAddress{SegmentIndex: 0, Offset: 3} instruction := Instruction{ PcUpdate: NextInstr, Op1Source: Op0, // anything but imm @@ -407,13 +396,13 @@ func TestUpdatePcNextInstr(t *testing.T) { nextPc, err := vm.updatePc(&instruction, nil, nil, nil) require.NoError(t, err) - assert.Equal(t, mem.NewMemoryAddress(0, 4), nextPc) + assert.Equal(t, mem.MemoryAddress{SegmentIndex: 0, Offset: 4}, nextPc) } func TestUpdatePcNextInstrImm(t *testing.T) { vm := defaultVirtualMachine() - vm.Context.Pc = mem.NewMemoryAddress(0, 3) + vm.Context.Pc = mem.MemoryAddress{SegmentIndex: 0, Offset: 3} instruction := Instruction{ PcUpdate: NextInstr, Op1Source: Imm, @@ -421,102 +410,93 @@ func TestUpdatePcNextInstrImm(t *testing.T) { nextPc, err := vm.updatePc(&instruction, nil, nil, nil) require.NoError(t, err) - assert.Equal(t, mem.NewMemoryAddress(0, 5), nextPc) + assert.Equal(t, mem.MemoryAddress{SegmentIndex: 0, Offset: 5}, nextPc) } func TestUpdatePcJump(t *testing.T) { vm := defaultVirtualMachine() - vm.Context.Pc = mem.NewMemoryAddress(0, 3) + vm.Context.Pc = mem.MemoryAddress{SegmentIndex: 0, Offset: 3} jumpAddr := uint64(10) res := mem.MemoryValueFromSegmentAndOffset(0, jumpAddr) instruction := Instruction{ PcUpdate: Jump, } - nextPc, err := vm.updatePc(&instruction, nil, nil, res) + nextPc, err := vm.updatePc(&instruction, nil, nil, &res) require.NoError(t, err) - assert.Equal(t, mem.NewMemoryAddress(0, jumpAddr), nextPc) + assert.Equal(t, mem.MemoryAddress{SegmentIndex: 0, Offset: jumpAddr}, nextPc) } func TestUpdatePcJumpRel(t *testing.T) { vm := defaultVirtualMachine() - vm.Context.Pc = mem.NewMemoryAddress(0, 3) + vm.Context.Pc = mem.MemoryAddress{SegmentIndex: 0, Offset: 3} relAddr := uint64(10) res := mem.MemoryValueFromInt(relAddr) instruction := Instruction{ PcUpdate: JumpRel, } - nextPc, err := vm.updatePc(&instruction, nil, nil, res) + nextPc, err := vm.updatePc(&instruction, nil, nil, &res) require.NoError(t, err) - assert.Equal(t, mem.NewMemoryAddress(0, 3+relAddr), nextPc) + assert.Equal(t, mem.MemoryAddress{SegmentIndex: 0, Offset: 3 + relAddr}, nextPc) } func TestUpdatePcJnz(t *testing.T) { vm := defaultVirtualMachine() - - vm.Context.Pc = mem.NewMemoryAddress(0, 11) relAddr := uint64(10) + writeToDataSegment(vm, 0, mem.MemoryValueFromInt(10)) //dstCell + writeToDataSegment(vm, 1, mem.MemoryValueFromInt(relAddr)) //op1Cell + dstAddr := mem.MemoryAddress{SegmentIndex: ExecutionSegment, Offset: 0} + op1Addr := mem.MemoryAddress{SegmentIndex: ExecutionSegment, Offset: 1} + vm.Context.Pc = mem.MemoryAddress{SegmentIndex: 0, Offset: 11} res := mem.MemoryValueFromInt(10) - dstCell := &mem.Cell{ - Accessed: true, - Value: mem.MemoryValueFromInt(10), - } - op1Cell := &mem.Cell{ - Accessed: true, - Value: mem.MemoryValueFromInt(relAddr), - } instruction := Instruction{ PcUpdate: Jnz, Op1Source: Op0, } - nextPc, err := vm.updatePc(&instruction, dstCell, op1Cell, res) + nextPc, err := vm.updatePc(&instruction, &dstAddr, &op1Addr, &res) require.NoError(t, err) - assert.Equal(t, mem.NewMemoryAddress(0, 11+relAddr), nextPc) + assert.Equal(t, mem.MemoryAddress{SegmentIndex: 0, Offset: 11 + relAddr}, nextPc) } func TestUpdatePcJnzDstZero(t *testing.T) { vm := defaultVirtualMachine() + writeToDataSegment(vm, 0, mem.MemoryValueFromInt(0)) //dstCell + dstAddr := mem.MemoryAddress{SegmentIndex: ExecutionSegment, Offset: 0} - vm.Context.Pc = mem.NewMemoryAddress(0, 11) + vm.Context.Pc = mem.MemoryAddress{SegmentIndex: 0, Offset: 11} - dstCell := &mem.Cell{ - Accessed: true, - Value: mem.MemoryValueFromInt(0), - } instruction := Instruction{ PcUpdate: Jnz, Op1Source: Op0, } - nextPc, err := vm.updatePc(&instruction, dstCell, nil, nil) + nextPc, err := vm.updatePc(&instruction, &dstAddr, nil, nil) require.NoError(t, err) - assert.Equal(t, mem.NewMemoryAddress(0, 11+1), nextPc) + assert.Equal(t, mem.MemoryAddress{SegmentIndex: 0, Offset: 11 + 1}, nextPc) } func TestUpdatePcJnzDstZeroImm(t *testing.T) { vm := defaultVirtualMachine() + writeToDataSegment(vm, 0, mem.MemoryValueFromInt(0)) //dstCell + dstAddr := mem.MemoryAddress{SegmentIndex: ExecutionSegment, Offset: 0} - vm.Context.Pc = mem.NewMemoryAddress(0, 9) + vm.Context.Pc = mem.MemoryAddress{SegmentIndex: 0, Offset: 9} - dstCell := &mem.Cell{ - Accessed: true, - Value: mem.MemoryValueFromInt(0), - } instruction := Instruction{ PcUpdate: Jnz, Op1Source: Imm, } - nextPc, err := vm.updatePc(&instruction, dstCell, nil, nil) + nextPc, err := vm.updatePc(&instruction, &dstAddr, nil, nil) require.NoError(t, err) - assert.Equal(t, mem.NewMemoryAddress(0, 9+2), nextPc) + assert.Equal(t, mem.MemoryAddress{SegmentIndex: 0, Offset: 9 + 2}, nextPc) } func TestUpdateApAddOne(t *testing.T) { @@ -546,8 +526,8 @@ func TestUpdateFp(t *testing.T) { assert.Equal(t, vm.Context.Fp, nextFp) } -func writeToDataSegment(vm *VirtualMachine, index uint64, value *mem.MemoryValue) { - err := vm.MemoryManager.Memory.Write(ExecutionSegment, index, value) +func writeToDataSegment(vm *VirtualMachine, index uint64, value mem.MemoryValue) { + err := vm.MemoryManager.Memory.Write(ExecutionSegment, index, &value) if err != nil { panic("error in test util: writeToDataSegment") }