Skip to content

Commit

Permalink
chore: monadic operations for memory values
Browse files Browse the repository at this point in the history
  • Loading branch information
ElijahVlasov committed Sep 8, 2023
1 parent fb1d625 commit 2b57126
Show file tree
Hide file tree
Showing 3 changed files with 84 additions and 67 deletions.
1 change: 1 addition & 0 deletions go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ require (
require (
github.com/consensys/gnark-crypto v0.11.1
github.com/go-playground/validator/v10 v10.4.1
github.com/samber/mo v1.8.0
github.com/stretchr/testify v1.8.4
golang.org/x/exp v0.0.0-20230811145659-89c5cff77bcb
)
2 changes: 2 additions & 0 deletions go.sum
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,8 @@ github.com/leodido/go-urn v1.2.0 h1:hpXL4XnriNwQ/ABnpepYM/1vCLWNDfUNts8dX3xTG6Y=
github.com/leodido/go-urn v1.2.0/go.mod h1:+8+nEpDfqqsY+g338gtMEUOtuK+4dEMhiQEgxpxOKII=
github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
github.com/samber/mo v1.8.0 h1:vYjHTfg14JF9tD2NLhpoUsRi9bjyRoYwa4+do0nvbVw=
github.com/samber/mo v1.8.0/go.mod h1:BfkrCPuYzVG3ZljnZB783WIJIGk1mcZr9c9CPf8tAxs=
github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME=
github.com/stretchr/testify v1.4.0/go.mod h1:j7eGeouHqKxXV5pUuKE4zz7dFj8WfuZ+81PSLYec5m4=
github.com/stretchr/testify v1.8.4 h1:CcVxjf3Q8PM0mHUKJCdn+eZZtm5yQwehR5yeSVQQcUk=
Expand Down
148 changes: 81 additions & 67 deletions pkg/vm/memory/memory_value.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ import (
"fmt"

f "github.com/consensys/gnark-crypto/ecc/stark-curve/fp"
mo "github.com/samber/mo"
"golang.org/x/exp/constraints"
)

Expand Down Expand Up @@ -99,40 +100,32 @@ func (address MemoryAddress) String() string {
)
}

// This is an abreviation for simplicity
type memoryValue = mo.Either[*f.Element, *MemoryAddress]

// Stores all posible types that can be stored in a Memory cell,
//
// - either a Felt value (an `f.Element`),
// - or a pointer to another Memory Cell (a `MemoryAddress`)
//
// Both members cannot be non-nil at the same time
type MemoryValue struct {
felt *f.Element
address *MemoryAddress
}
type MemoryValue memoryValue

func MemoryValueFromMemoryAddress(address *MemoryAddress) *MemoryValue {
return &MemoryValue{
address: address,
}
var mv MemoryValue = MemoryValue(mo.Right[*f.Element, *MemoryAddress](address))
return &mv
}

func MemoryValueFromFieldElement(felt *f.Element) *MemoryValue {
return &MemoryValue{
felt: felt,
}
mv := MemoryValue(mo.Left[*f.Element, *MemoryAddress](felt))
return &mv
}

func MemoryValueFromInt[T constraints.Integer](v T) *MemoryValue {
newElement := f.NewElement(uint64(v))
return &MemoryValue{
felt: &newElement,
}
return MemoryValueFromFieldElement(&newElement)
}

func MemoryValueFromSegmentAndOffset[T constraints.Integer](segmentIndex, offset T) *MemoryValue {
return &MemoryValue{
address: &MemoryAddress{SegmentIndex: uint64(segmentIndex), Offset: uint64(offset)},
}
return MemoryValueFromMemoryAddress(&MemoryAddress{SegmentIndex: uint64(segmentIndex), Offset: uint64(offset)})
}

func MemoryValueFromAny(anyType any) (*MemoryValue, error) {
Expand All @@ -147,14 +140,10 @@ func MemoryValueFromAny(anyType any) (*MemoryValue, error) {
}

func EmptyMemoryValueAsFelt() *MemoryValue {
return &MemoryValue{
felt: new(f.Element),
}
return MemoryValueFromFieldElement(new(f.Element))
}
func EmptyMemoryValueAsAddress() *MemoryValue {
return &MemoryValue{
address: new(MemoryAddress),
}
return MemoryValueFromMemoryAddress(new(MemoryAddress))
}
func EmptyMemoryValueAs(address bool) *MemoryValue {
if address {
Expand All @@ -163,92 +152,114 @@ func EmptyMemoryValueAs(address bool) *MemoryValue {
return EmptyMemoryValueAsFelt()
}

func (mv *MemoryValue) toMemoryAddress() *MemoryAddress {
return memoryValue(*mv).RightOrEmpty()
}

func (mv *MemoryValue) ToMemoryAddress() (*MemoryAddress, error) {
if mv.address == nil {
address, isAddress := memoryValue(*mv).Right()
if !isAddress {
return nil, fmt.Errorf("error trying to read a memory value as an address")
}
return mv.address, nil
return address, nil
}

func (mv *MemoryValue) toFieldElement() *f.Element {
return memoryValue(*mv).LeftOrEmpty()
}

func (mv *MemoryValue) ToFieldElement() (*f.Element, error) {
if mv.felt == nil {
felt, isFelt := memoryValue(*mv).Left()
if !isFelt {
return nil, fmt.Errorf("error trying to read a memory value as a field element")
}
return mv.felt, nil
return felt, nil
}

func (mv *MemoryValue) ToAny() any {
if mv.felt != nil {
return mv.felt
felt, isFelt := memoryValue(*mv).Left()

if isFelt {
return felt
}
return mv.address

return memoryValue(*mv).RightOrEmpty()
}

func (mv *MemoryValue) IsAddress() bool {
return mv.address != nil
return memoryValue(*mv).IsRight()
}

func (mv *MemoryValue) IsFelt() bool {
return mv.felt != nil
return memoryValue(*mv).IsLeft()
}

func (mv *MemoryValue) Equal(other *MemoryValue) bool {
func (mv *MemoryValue) Equal(other *MemoryValue) (isEqual bool) {
if mv.IsAddress() && other.IsAddress() {
return mv.address.Equal(other.address)
}
if mv.IsFelt() && other.IsFelt() {
return mv.felt.Equal(other.felt)
return mv.toMemoryAddress().Equal(other.toMemoryAddress())
} else if mv.IsFelt() && other.IsFelt() {
return mv.toFieldElement().Equal(other.toFieldElement())
}
return false
}

// Adds two memory values is the second one is a Felt
func (mv *MemoryValue) Add(lhs, rhs *MemoryValue) (*MemoryValue, error) {
var err error
if lhs.IsAddress() {
if !rhs.IsFelt() {
return nil, fmt.Errorf("memory value addition requires a felt in the rhs")
}
mv.address, err = mv.address.Add(lhs.address, rhs.felt)
} else {
func (mv *MemoryValue) Add(lhs, rhs *MemoryValue) (res *MemoryValue, err error) {
memoryValue(*lhs).MapLeft(func(e *f.Element) mo.Either[*f.Element, *MemoryAddress] {
if rhs.IsAddress() {
mv.address, err = mv.address.Add(rhs.address, lhs.felt)
_, err = mv.toMemoryAddress().Add(rhs.toMemoryAddress(), e)
} else {
mv.felt = mv.felt.Add(lhs.felt, rhs.felt)
mv.toFieldElement().Add(e, rhs.toFieldElement())
}
}
return mo.Left[*f.Element, *MemoryAddress](nil)
}).MapRight(func(ma *MemoryAddress) mo.Either[*f.Element, *MemoryAddress] {
if !rhs.IsFelt() {
err = fmt.Errorf("memory value addition requires a felt in the rhs")
} else {
_, err = mv.toMemoryAddress().Add(ma, rhs.toFieldElement())
}

return memoryValue(*mv)
})

res = mv
if err != nil {
return nil, fmt.Errorf("error adding two memory values: %w", err)
res = nil
err = fmt.Errorf("error adding two memory values: %w", err)
}
return mv, nil

return
}

// Subs two memory values if they're in the same segment or the rhs is a Felt.
func (mv *MemoryValue) Sub(lhs, rhs *MemoryValue) (*MemoryValue, error) {
var err error
if lhs.IsAddress() {
mv.address, err = mv.address.Sub(lhs.address, rhs.ToAny())
} else {
func (mv *MemoryValue) Sub(lhs, rhs *MemoryValue) (res *MemoryValue, err error) {
memoryValue(*lhs).MapLeft(func(e *f.Element) mo.Either[*f.Element, *MemoryAddress] {
if rhs.IsAddress() {
return nil, fmt.Errorf("cannot substract a an address from a felt")
err = fmt.Errorf("cannot substract a an address from a felt")
} else {
mv.felt = mv.felt.Sub(lhs.felt, rhs.felt)
mv.toFieldElement().Sub(e, rhs.toFieldElement())
}
}
return mo.Left[*f.Element, *MemoryAddress](nil)
}).MapRight(func(ma *MemoryAddress) mo.Either[*f.Element, *MemoryAddress] {
_, err = mv.toMemoryAddress().Sub(ma, rhs.ToAny())
return memoryValue(*mv)
})

res = mv

if err != nil {
return nil, fmt.Errorf("error substracting two memory values: %w", err)
res = nil
err = fmt.Errorf("error substracting two memory values: %w", err)
}

return mv, nil
return
}

func (mv *MemoryValue) Mul(lhs, rhs *MemoryValue) (*MemoryValue, error) {
if lhs.IsAddress() || rhs.IsAddress() {
return nil, fmt.Errorf("cannot multiply memory addresses")
}
mv.felt.Mul(lhs.felt, rhs.felt)
mv.toFieldElement().Mul(lhs.toFieldElement(), rhs.toFieldElement())
return mv, nil
}

Expand All @@ -257,27 +268,30 @@ func (mv *MemoryValue) Div(lhs, rhs *MemoryValue) (*MemoryValue, error) {
return nil, fmt.Errorf("cannot divide memory addresses")
}

mv.felt.Div(lhs.felt, rhs.felt)
mv.toFieldElement().Div(lhs.toFieldElement(), rhs.toFieldElement())
return mv, nil
}

func (mv MemoryValue) String() string {
if mv.IsAddress() {
return mv.address.String()
return mv.toMemoryAddress().String()
}
return mv.felt.String()
return mv.toFieldElement().String()
}

// Retuns a MemoryValue holding a felt as uint if it fits
func (mv *MemoryValue) Uint64() (uint64, error) {
if mv.IsAddress() {
return 0, fmt.Errorf("cannot convert a memory address '%s' into uint64", *mv)
}
if !mv.felt.IsUint64() {

felt := mv.toFieldElement()

if !felt.IsUint64() {
return 0, fmt.Errorf("cannot convert a field element '%s' into uint64", *mv)
}

return mv.felt.Uint64(), nil
return felt.Uint64(), nil
}

// Note: Commenting this function since relocation is possibly going to look
Expand Down

0 comments on commit 2b57126

Please sign in to comment.