Skip to content

Commit

Permalink
Simplify error handling in memory package (#47)
Browse files Browse the repository at this point in the history
  • Loading branch information
joshklop committed Sep 12, 2023
1 parent 4342727 commit 1f15c48
Show file tree
Hide file tree
Showing 2 changed files with 40 additions and 59 deletions.
14 changes: 6 additions & 8 deletions pkg/vm/memory/memory.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,9 +14,7 @@ type Cell struct {

func (cell *Cell) Write(value *MemoryValue) error {
if cell.Accessed && cell.Value != nil && !cell.Value.Equal(value) {
return fmt.Errorf(
"rewriting cell old value: %d new value: %d", &cell.Value, &value,
)
return fmt.Errorf("rewriting cell: old value \"%d\", new value \"%d\"", &cell.Value, &value)
}

cell.Accessed = true
Expand Down Expand Up @@ -77,7 +75,7 @@ func (segment *Segment) Write(offset uint64, value *MemoryValue) error {

err := segment.Data[offset].Write(value)
if err != nil {
return fmt.Errorf("error at index %d: %w", offset, err)
return fmt.Errorf("write cell at segment offset %d: %v", offset, err)
}
return nil
}
Expand Down Expand Up @@ -149,7 +147,7 @@ func (memory *Memory) AllocateSegment(data []*f.Element) (int, error) {
memVal := MemoryValueFromFieldElement(data[i])
err := newSegment.Write(uint64(i), memVal)
if err != nil {
return 0, fmt.Errorf("cannot allocate new segment: %w", err)
return 0, err
}
}
memory.Segments = append(memory.Segments, newSegment)
Expand All @@ -166,7 +164,7 @@ func (memory *Memory) AllocateEmptySegment() int {
// space or if rewriting a specific cell
func (memory *Memory) Write(segmentIndex uint64, offset uint64, value *MemoryValue) error {
if segmentIndex > uint64(len(memory.Segments)) {
return fmt.Errorf("writing to unallocated segment %d", segmentIndex)
return fmt.Errorf("unallocated segment at index %d", segmentIndex)
}

return memory.Segments[segmentIndex].Write(offset, value)
Expand All @@ -181,7 +179,7 @@ func (memory *Memory) WriteToAddress(address *MemoryAddress, value *MemoryValue)
// initalized with its default zero value
func (memory *Memory) Read(segmentIndex uint64, offset uint64) (*MemoryValue, error) {
if segmentIndex > uint64(len(memory.Segments)) {
return nil, fmt.Errorf("reading from unallocated segment %d", segmentIndex)
return nil, fmt.Errorf("unallocated segment at index %d", segmentIndex)
}
return memory.Segments[segmentIndex].Read(offset), nil
}
Expand All @@ -196,7 +194,7 @@ func (memory *Memory) ReadFromAddress(address *MemoryAddress) (*MemoryValue, err
// Given a segment index and offset returns a pointer to the Memory Cell
func (memory *Memory) Peek(segmentIndex uint64, offset uint64) (*Cell, error) {
if segmentIndex > uint64(len(memory.Segments)) {
return nil, fmt.Errorf("peeking from unallocated segment %d", segmentIndex)
return nil, fmt.Errorf("unallocated segment at index %d", segmentIndex)
}
return memory.Segments[segmentIndex].Peek(offset), nil
}
Expand Down
85 changes: 34 additions & 51 deletions pkg/vm/memory/memory_value.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package memory

import (
"errors"
"fmt"

f "github.com/consensys/gnark-crypto/ecc/stark-curve/fp"
Expand All @@ -27,11 +28,7 @@ func (address *MemoryAddress) Equal(other *MemoryAddress) bool {
// Adds a memory address and a field element
func (address *MemoryAddress) Add(lhs *MemoryAddress, rhs *f.Element) (*MemoryAddress, error) {
if !rhs.IsUint64() {
return nil, fmt.Errorf(
"adding to %s a field element %s greater than uint64",
lhs.String(),
rhs.String(),
)
return nil, fmt.Errorf("field element does not fit in uint64: %s", rhs.String())
}

address.SegmentIndex = lhs.SegmentIndex
Expand All @@ -45,51 +42,35 @@ func (address *MemoryAddress) Sub(lhs *MemoryAddress, rhs any) (*MemoryAddress,
address.SegmentIndex = lhs.SegmentIndex

// Then update offset accordingly
switch t := rhs.(type) {
switch rhs := rhs.(type) {
case uint64:
rhs64 := rhs.(uint64)
if rhs64 > lhs.Offset {
return nil, fmt.Errorf("rhs offset greater than lhs offset")
if rhs > lhs.Offset {
return nil, errors.New("rhs is greater than lhs offset")
}
address.Offset = lhs.Offset - rhs64
address.Offset = lhs.Offset - rhs
return address, nil
case *f.Element:
feltRhs := rhs.(*f.Element)
if !feltRhs.IsUint64() {
return nil, fmt.Errorf(
"substracting from %s a field element %s greater than uint64",
lhs.String(),
feltRhs.String(),
)
if !rhs.IsUint64() {
return nil, fmt.Errorf("rhs field element does not fit in uint64: %s", rhs)
}
feltRhs64 := feltRhs.Uint64()
feltRhs64 := rhs.Uint64()
if feltRhs64 > lhs.Offset {
return nil, fmt.Errorf("rhs offset greater than lhs offset")
return nil, fmt.Errorf("rhs %d is greater than lhs offset %d", feltRhs64, lhs.Offset)
}
address.Offset = lhs.Offset - feltRhs64
return address, nil
case *MemoryAddress:
addressRhs := rhs.(*MemoryAddress)
if lhs.SegmentIndex != addressRhs.SegmentIndex {
return nil, fmt.Errorf(
"cannot substract %s from %s due to different segment location",
addressRhs.String(),
lhs.String(),
)
if lhs.SegmentIndex != rhs.SegmentIndex {
return nil, fmt.Errorf("addresses are in different segments: rhs is in %d, lhs is in %d",
rhs.SegmentIndex, lhs.SegmentIndex)
}
if addressRhs.Offset > lhs.Offset {
return nil, fmt.Errorf("rhs offset greater than lhs offset")
if rhs.Offset > lhs.Offset {
return nil, fmt.Errorf("rhs offset %d is greater than lhs offset %d", rhs.Offset, lhs.Offset)
}
address.Offset = lhs.Offset - addressRhs.Offset
address.Offset = lhs.Offset - rhs.Offset
return address, nil
default:
return nil,
fmt.Errorf(
"cannot substract from %s, invalid rhs type: %v. Expected a felt or another memory address",
address.String(),
t,
)

return nil, fmt.Errorf("unknown rhs type: %T", rhs)
}
}

Expand Down Expand Up @@ -143,15 +124,15 @@ func MemoryValueFromSegmentAndOffset[T constraints.Integer](segmentIndex, offset
}

func MemoryValueFromAny(anyType any) (*MemoryValue, error) {
switch t := anyType.(type) {
switch anyType := anyType.(type) {
case uint64:
return MemoryValueFromInt(anyType.(uint64)), nil
return MemoryValueFromInt(anyType), nil
case *f.Element:
return MemoryValueFromFieldElement(anyType.(*f.Element)), nil
return MemoryValueFromFieldElement(anyType), nil
case *MemoryAddress:
return MemoryValueFromMemoryAddress(anyType.(*MemoryAddress)), nil
return MemoryValueFromMemoryAddress(anyType), nil
default:
return nil, fmt.Errorf("invalid type to convert a memory value: %v", t)
return nil, fmt.Errorf("invalid type to convert to a MemoryValue: %T", anyType)
}
}

Expand All @@ -160,11 +141,13 @@ func EmptyMemoryValueAsFelt() *MemoryValue {
felt: new(f.Element),
}
}

func EmptyMemoryValueAsAddress() *MemoryValue {
return &MemoryValue{
address: new(MemoryAddress),
}
}

func EmptyMemoryValueAs(address bool) *MemoryValue {
if address {
return EmptyMemoryValueAsAddress()
Expand All @@ -174,14 +157,14 @@ func EmptyMemoryValueAs(address bool) *MemoryValue {

func (mv *MemoryValue) ToMemoryAddress() (*MemoryAddress, error) {
if mv.address == nil {
return nil, fmt.Errorf("error trying to read a memory value as an address")
return nil, errors.New("memory value is not an address")
}
return mv.address, nil
}

func (mv *MemoryValue) ToFieldElement() (*f.Element, error) {
if mv.felt == nil {
return nil, fmt.Errorf("error trying to read a memory value as a field element")
return nil, fmt.Errorf("memory value is not a field element")
}
return mv.felt, nil
}
Expand Down Expand Up @@ -216,7 +199,7 @@ func (mv *MemoryValue) Add(lhs, rhs *MemoryValue) (*MemoryValue, error) {
var err error
if lhs.IsAddress() {
if !rhs.IsFelt() {
return nil, fmt.Errorf("memory value addition requires a felt in the rhs")
return nil, errors.New("rhs is not a felt")
}
mv.address, err = mv.address.Add(lhs.address, rhs.felt)
} else {
Expand All @@ -228,7 +211,7 @@ func (mv *MemoryValue) Add(lhs, rhs *MemoryValue) (*MemoryValue, error) {
}

if err != nil {
return nil, fmt.Errorf("error adding two memory values: %w", err)
return nil, err
}
return mv, nil
}
Expand All @@ -240,30 +223,30 @@ func (mv *MemoryValue) Sub(lhs, rhs *MemoryValue) (*MemoryValue, error) {
mv.address, err = mv.address.Sub(lhs.address, rhs.ToAny())
} else {
if rhs.IsAddress() {
return nil, fmt.Errorf("cannot substract a an address from a felt")
return nil, errors.New("cannot substract an address from a felt")
} else {
mv.felt = mv.felt.Sub(lhs.felt, rhs.felt)
}
}

if err != nil {
return nil, fmt.Errorf("error substracting two memory values: %w", err)
return nil, err
}

return mv, nil
}

func (mv *MemoryValue) Mul(lhs, rhs *MemoryValue) (*MemoryValue, error) {
if lhs.IsAddress() || rhs.IsAddress() {
return nil, fmt.Errorf("cannot multiply memory addresses")
return nil, errors.New("cannot multiply memory addresses")
}
mv.felt.Mul(lhs.felt, rhs.felt)
return mv, nil
}

func (mv *MemoryValue) Div(lhs, rhs *MemoryValue) (*MemoryValue, error) {
if lhs.IsAddress() || rhs.IsAddress() {
return nil, fmt.Errorf("cannot divide memory addresses")
return nil, errors.New("cannot divide memory addresses")
}

mv.felt.Div(lhs.felt, rhs.felt)
Expand All @@ -280,10 +263,10 @@ func (mv MemoryValue) String() string {
// Retuns a MemoryValue holding a felt as uint if it fits
func (mv *MemoryValue) Uint64() (uint64, error) {
if mv.IsAddress() {
return 0, fmt.Errorf("cannot convert a memory address '%s' into uint64", *mv)
return 0, fmt.Errorf("cannot convert a memory address into uint64: %s", *mv)
}
if !mv.felt.IsUint64() {
return 0, fmt.Errorf("cannot convert a field element '%s' into uint64", *mv)
return 0, fmt.Errorf("field element does not fit in uint64: %s", mv.String())
}

return mv.felt.Uint64(), nil
Expand Down

0 comments on commit 1f15c48

Please sign in to comment.