Skip to content

Commit

Permalink
Fixes #56 : new add functions and test for the changes
Browse files Browse the repository at this point in the history
  • Loading branch information
jkktom committed Sep 22, 2023
1 parent dead298 commit 610753b
Show file tree
Hide file tree
Showing 2 changed files with 134 additions and 17 deletions.
69 changes: 52 additions & 17 deletions pkg/vm/memory/memory_value.go
Original file line number Diff line number Diff line change
Expand Up @@ -25,17 +25,6 @@ func (address *MemoryAddress) Equal(other *MemoryAddress) bool {
return address.SegmentIndex == other.SegmentIndex && address.Offset == other.Offset
}

// Adds a memory address and a field element
func (address *MemoryAddress) Add(lhs *MemoryAddress, rhs *f.Element) (*MemoryAddress, error) {
if !rhs.IsUint64() {
return nil, fmt.Errorf("field element does not fit in uint64: %s", rhs.String())
}

address.SegmentIndex = lhs.SegmentIndex
address.Offset = lhs.Offset + rhs.Uint64()
return address, nil
}

// Subs from a memory address a felt or another memory address in the same segment
func (address *MemoryAddress) Sub(lhs *MemoryAddress, rhs any) (*MemoryAddress, error) {
// First match segment index
Expand Down Expand Up @@ -194,19 +183,65 @@ func (mv *MemoryValue) Equal(other *MemoryValue) bool {
return false
}

// Adds two memory values is the second one is a Felt
// // Adds two memory values is the second one is a Felt
// func (mv *MemoryValue) Add(lhs, rhs *MemoryValue) (*MemoryValue, error) {
// var err error
// if lhs.IsAddress() {
// if !rhs.IsFelt() {
// return nil, errors.New("rhs is not a felt")
// }
// // lhs : Address, rhs : Address
// mv.address, err = mv.address.Add(lhs.address, rhs.felt)
// } else {
// if rhs.IsAddress() {
// mv.address, err = mv.address.Add(rhs.address, lhs.felt)
// } else {
// mv.felt = mv.felt.Add(lhs.felt, rhs.felt)
// }
// }

// if err != nil {
// return nil, err
// }
// return mv, nil
// }

// Adds a memory address and a field element
func (address *MemoryAddress) Add(lhs *MemoryAddress, rhs *f.Element) (*MemoryAddress, error) {
if !rhs.IsUint64() {
return nil, fmt.Errorf("field element does not fit in uint64: %s", rhs.String())
}

address.SegmentIndex = lhs.SegmentIndex
address.Offset = lhs.Offset + rhs.Uint64()
return address, nil
}

func (mv *MemoryValue) Add(lhs, rhs *MemoryValue) (*MemoryValue, error) {
var err error
if lhs.IsAddress() {
if !rhs.IsFelt() {
return nil, errors.New("rhs is not a felt")
if rhs.IsAddress() {
// lhs : Address, rhs : Address
if lhs.address.SegmentIndex != rhs.address.SegmentIndex {
return nil, errors.New("cannot add addresses from different segments")
}
mv.address.SegmentIndex = lhs.address.SegmentIndex
mv.address.Offset = lhs.address.Offset + rhs.address.Offset
} else if rhs.IsFelt() {
// lhs : Address, rhs : felt
mv.address, err = mv.address.Add(lhs.address, rhs.felt)
} else {
return nil, errors.New("invalid rhs type")
}
mv.address, err = mv.address.Add(lhs.address, rhs.felt)
} else {
if rhs.IsAddress() {
mv.address, err = mv.address.Add(rhs.address, lhs.felt)
} else {
// lhs : felt, rhs : Address
return nil, errors.New("invalid operation: cannot add integer to memory address")
} else if rhs.IsFelt() {
// lhs : felt, rhs : felt
mv.felt = mv.felt.Add(lhs.felt, rhs.felt)
} else {
return nil, errors.New("invalid rhs type")
}
}

Expand Down
82 changes: 82 additions & 0 deletions pkg/vm/vm_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -301,6 +301,18 @@ func TestInferOperandSub(t *testing.T) {
assert.Equal(t, expectedOp0Cell, op0Cell)
}

/*
CaseA: Adding a MemoryAddress to an integer (already provided).
CaseB: Adding an integer to a MemoryAddress.
CaseC: Adding two MemoryAddress values from the same segment.
CaseD: Adding two MemoryAddress values from different segments (this should result in an error).
CaseE: Adding a MemoryAddress to a FieldElement (if supported).
CaseF: Adding a FieldElement to a MemoryAddress (if supported).
CaseG: Adding two FieldElement values.
CaseH: Adding a MemoryAddress to a value that exceeds the uint64 limit (should result in an error).
CaseI: Adding a FieldElement to a value that exceeds the uint64 limit (should result in an error).
CaseJ: Adding two values where one of them is neither a MemoryAddress nor a FieldElement (should result in an error).
*/
func TestComputeAddRes(t *testing.T) {
vm := defaultVirtualMachine()

Expand Down Expand Up @@ -330,6 +342,76 @@ func TestComputeAddRes(t *testing.T) {
assert.Equal(t, expected, res)
}

func TestComputeAddRes_Jake(t *testing.T) {
// Test cases
testCases := []struct {
name string
op0 *mem.MemoryValue
op1 *mem.MemoryValue
expected *mem.MemoryValue
shouldFail bool
}{
{
name: "Type A : MemoryAddressPlusInteger",
op0: mem.MemoryValueFromMemoryAddress(mem.NewMemoryAddress(2, 10)),
op1: mem.MemoryValueFromInt(15),
expected: mem.MemoryValueFromMemoryAddress(mem.NewMemoryAddress(2, 25)),
},
{
name: "Type B : MemoryAddressPlusMemoryAddress_SameSegment",
op0: mem.MemoryValueFromMemoryAddress(mem.NewMemoryAddress(2, 10)),
op1: mem.MemoryValueFromMemoryAddress(mem.NewMemoryAddress(2, 15)),
expected: mem.MemoryValueFromMemoryAddress(mem.NewMemoryAddress(2, 25)),
},
{
name: "Type C : MemoryAddressPlusMemoryAddress_DifferentSegment",
op0: mem.MemoryValueFromMemoryAddress(mem.NewMemoryAddress(2, 10)),
op1: mem.MemoryValueFromMemoryAddress(mem.NewMemoryAddress(3, 15)),
shouldFail: true,
},
{
name: "Type D : IntegerPlusInteger",
op0: mem.MemoryValueFromInt(10),
op1: mem.MemoryValueFromInt(15),
expected: mem.MemoryValueFromInt(25),
},
{
name: "Type E : IntegerPlusMemoryAddress",
op0: mem.MemoryValueFromInt(15),
op1: mem.MemoryValueFromMemoryAddress(mem.NewMemoryAddress(2, 10)),
shouldFail: true,
},
}

for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
vm := defaultVirtualMachine()

instruction := Instruction{
Res: AddOperands,
}

cellOp0 := &mem.Cell{
Accessed: true,
Value: tc.op0,
}

cellOp1 := &mem.Cell{
Accessed: true,
Value: tc.op1,
}

res, err := vm.computeRes(&instruction, cellOp0, cellOp1)
if tc.shouldFail {
require.Error(t, err)
return
}
require.NoError(t, err)
assert.Equal(t, tc.expected, res)
})
}
}

func TestOpcodeAssertionAssertEq(t *testing.T) {
vm := defaultVirtualMachine()

Expand Down

0 comments on commit 610753b

Please sign in to comment.