diff --git a/pkg/vm/memory/memory_value.go b/pkg/vm/memory/memory_value.go index f2fa9babd..7ef2c3a82 100644 --- a/pkg/vm/memory/memory_value.go +++ b/pkg/vm/memory/memory_value.go @@ -25,17 +25,6 @@ func (address *MemoryAddress) Equal(other *MemoryAddress) bool { return address.SegmentIndex == other.SegmentIndex && address.Offset == other.Offset } -// Adds a memory address and a field element -func (address *MemoryAddress) Add(lhs *MemoryAddress, rhs *f.Element) (*MemoryAddress, error) { - if !rhs.IsUint64() { - return nil, fmt.Errorf("field element does not fit in uint64: %s", rhs.String()) - } - - address.SegmentIndex = lhs.SegmentIndex - address.Offset = lhs.Offset + rhs.Uint64() - return address, nil -} - // Subs from a memory address a felt or another memory address in the same segment func (address *MemoryAddress) Sub(lhs *MemoryAddress, rhs any) (*MemoryAddress, error) { // First match segment index @@ -194,19 +183,65 @@ func (mv *MemoryValue) Equal(other *MemoryValue) bool { return false } -// Adds two memory values is the second one is a Felt +// // Adds two memory values is the second one is a Felt +// func (mv *MemoryValue) Add(lhs, rhs *MemoryValue) (*MemoryValue, error) { +// var err error +// if lhs.IsAddress() { +// if !rhs.IsFelt() { +// return nil, errors.New("rhs is not a felt") +// } +// // lhs : Address, rhs : Address +// mv.address, err = mv.address.Add(lhs.address, rhs.felt) +// } else { +// if rhs.IsAddress() { +// mv.address, err = mv.address.Add(rhs.address, lhs.felt) +// } else { +// mv.felt = mv.felt.Add(lhs.felt, rhs.felt) +// } +// } + +// if err != nil { +// return nil, err +// } +// return mv, nil +// } + +// Adds a memory address and a field element +func (address *MemoryAddress) Add(lhs *MemoryAddress, rhs *f.Element) (*MemoryAddress, error) { + if !rhs.IsUint64() { + return nil, fmt.Errorf("field element does not fit in uint64: %s", rhs.String()) + } + + address.SegmentIndex = lhs.SegmentIndex + address.Offset = lhs.Offset + rhs.Uint64() + return address, nil +} + func (mv *MemoryValue) Add(lhs, rhs *MemoryValue) (*MemoryValue, error) { var err error if lhs.IsAddress() { - if !rhs.IsFelt() { - return nil, errors.New("rhs is not a felt") + if rhs.IsAddress() { + // lhs : Address, rhs : Address + if lhs.address.SegmentIndex != rhs.address.SegmentIndex { + return nil, errors.New("cannot add addresses from different segments") + } + mv.address.SegmentIndex = lhs.address.SegmentIndex + mv.address.Offset = lhs.address.Offset + rhs.address.Offset + } else if rhs.IsFelt() { + // lhs : Address, rhs : felt + mv.address, err = mv.address.Add(lhs.address, rhs.felt) + } else { + return nil, errors.New("invalid rhs type") } - mv.address, err = mv.address.Add(lhs.address, rhs.felt) } else { if rhs.IsAddress() { - mv.address, err = mv.address.Add(rhs.address, lhs.felt) - } else { + // lhs : felt, rhs : Address + return nil, errors.New("invalid operation: cannot add integer to memory address") + } else if rhs.IsFelt() { + // lhs : felt, rhs : felt mv.felt = mv.felt.Add(lhs.felt, rhs.felt) + } else { + return nil, errors.New("invalid rhs type") } } diff --git a/pkg/vm/vm_test.go b/pkg/vm/vm_test.go index be175264c..0d670a10e 100644 --- a/pkg/vm/vm_test.go +++ b/pkg/vm/vm_test.go @@ -301,6 +301,18 @@ func TestInferOperandSub(t *testing.T) { assert.Equal(t, expectedOp0Cell, op0Cell) } +/* +CaseA: Adding a MemoryAddress to an integer (already provided). +CaseB: Adding an integer to a MemoryAddress. +CaseC: Adding two MemoryAddress values from the same segment. +CaseD: Adding two MemoryAddress values from different segments (this should result in an error). +CaseE: Adding a MemoryAddress to a FieldElement (if supported). +CaseF: Adding a FieldElement to a MemoryAddress (if supported). +CaseG: Adding two FieldElement values. +CaseH: Adding a MemoryAddress to a value that exceeds the uint64 limit (should result in an error). +CaseI: Adding a FieldElement to a value that exceeds the uint64 limit (should result in an error). +CaseJ: Adding two values where one of them is neither a MemoryAddress nor a FieldElement (should result in an error). +*/ func TestComputeAddRes(t *testing.T) { vm := defaultVirtualMachine() @@ -330,6 +342,76 @@ func TestComputeAddRes(t *testing.T) { assert.Equal(t, expected, res) } +func TestComputeAddRes_Jake(t *testing.T) { + // Test cases + testCases := []struct { + name string + op0 *mem.MemoryValue + op1 *mem.MemoryValue + expected *mem.MemoryValue + shouldFail bool + }{ + { + name: "Type A : MemoryAddressPlusInteger", + op0: mem.MemoryValueFromMemoryAddress(mem.NewMemoryAddress(2, 10)), + op1: mem.MemoryValueFromInt(15), + expected: mem.MemoryValueFromMemoryAddress(mem.NewMemoryAddress(2, 25)), + }, + { + name: "Type B : MemoryAddressPlusMemoryAddress_SameSegment", + op0: mem.MemoryValueFromMemoryAddress(mem.NewMemoryAddress(2, 10)), + op1: mem.MemoryValueFromMemoryAddress(mem.NewMemoryAddress(2, 15)), + expected: mem.MemoryValueFromMemoryAddress(mem.NewMemoryAddress(2, 25)), + }, + { + name: "Type C : MemoryAddressPlusMemoryAddress_DifferentSegment", + op0: mem.MemoryValueFromMemoryAddress(mem.NewMemoryAddress(2, 10)), + op1: mem.MemoryValueFromMemoryAddress(mem.NewMemoryAddress(3, 15)), + shouldFail: true, + }, + { + name: "Type D : IntegerPlusInteger", + op0: mem.MemoryValueFromInt(10), + op1: mem.MemoryValueFromInt(15), + expected: mem.MemoryValueFromInt(25), + }, + { + name: "Type E : IntegerPlusMemoryAddress", + op0: mem.MemoryValueFromInt(15), + op1: mem.MemoryValueFromMemoryAddress(mem.NewMemoryAddress(2, 10)), + shouldFail: true, + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + vm := defaultVirtualMachine() + + instruction := Instruction{ + Res: AddOperands, + } + + cellOp0 := &mem.Cell{ + Accessed: true, + Value: tc.op0, + } + + cellOp1 := &mem.Cell{ + Accessed: true, + Value: tc.op1, + } + + res, err := vm.computeRes(&instruction, cellOp0, cellOp1) + if tc.shouldFail { + require.Error(t, err) + return + } + require.NoError(t, err) + assert.Equal(t, tc.expected, res) + }) + } +} + func TestOpcodeAssertionAssertEq(t *testing.T) { vm := defaultVirtualMachine()