Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Reduce allocations on VM instruction loop #79

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 11 additions & 10 deletions pkg/hintrunner/hint.go
Original file line number Diff line number Diff line change
Expand Up @@ -26,14 +26,14 @@ func (hint AllocSegment) Execute(vm *VM.VirtualMachine) error {
segmentIndex := vm.MemoryManager.Memory.AllocateEmptySegment()
memAddress := memory.MemoryValueFromSegmentAndOffset(segmentIndex, 0)

cell, err := hint.dst.Get(vm)
regAddr, err := hint.dst.Get(vm)
if err != nil {
return fmt.Errorf("get destination cell: %v", err)
return fmt.Errorf("get register %s: %w", hint.dst, err)
}

err = cell.Write(memAddress)
err = vm.MemoryManager.Memory.WriteToAddress(&regAddr, &memAddress)
if err != nil {
return fmt.Errorf("write cell: %v", err)
return fmt.Errorf("write to address %s: %v", regAddr, err)
}

return nil
Expand All @@ -52,12 +52,12 @@ func (hint TestLessThan) String() string {
func (hint TestLessThan) Execute(vm *VM.VirtualMachine) error {
lhsVal, err := hint.lhs.Resolve(vm)
if err != nil {
return fmt.Errorf("resolve lhs operand %s: %v", hint.lhs, err)
return fmt.Errorf("resolve lhs operand %s: %w", hint.lhs, err)
}

rhsVal, err := hint.rhs.Resolve(vm)
if err != nil {
return fmt.Errorf("resolve rhs operand %s: %v", hint.rhs, err)
return fmt.Errorf("resolve rhs operand %s: %w", hint.rhs, err)
}

lhsFelt, err := lhsVal.ToFieldElement()
Expand All @@ -75,14 +75,15 @@ func (hint TestLessThan) Execute(vm *VM.VirtualMachine) error {
resFelt.SetOne()
}

dstCell, err := hint.dst.Get(vm)
dstAddr, err := hint.dst.Get(vm)
if err != nil {
return fmt.Errorf("get destination cell: %v", err)
return fmt.Errorf("get dst address %s: %w", dstAddr, err)
}

err = dstCell.Write(memory.MemoryValueFromFieldElement(&resFelt))
mv := memory.MemoryValueFromFieldElement(&resFelt)
err = vm.MemoryManager.Memory.WriteToAddress(&dstAddr, &mv)
if err != nil {
return fmt.Errorf("write cell: %v", err)
return fmt.Errorf("write to dst address %s: %w", dstAddr, err)
}

return nil
Expand Down
10 changes: 8 additions & 2 deletions pkg/hintrunner/hintrunner_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,10 @@ func TestExistingHint(t *testing.T) {
10: allocHint,
})

vm.Context.Pc = memory.NewMemoryAddress(0, 10)
vm.Context.Pc = memory.MemoryAddress{
SegmentIndex: 0,
Offset: 10,
}
err := hr.RunHint(vm)
require.Nil(t, err)
require.Equal(
Expand All @@ -40,7 +43,10 @@ func TestNoHint(t *testing.T) {
10: allocHint,
})

vm.Context.Pc = memory.NewMemoryAddress(0, 100)
vm.Context.Pc = memory.MemoryAddress{
SegmentIndex: 0,
Offset: 100,
}
err := hr.RunHint(vm)
require.Nil(t, err)
require.Equal(t, 2, len(vm.MemoryManager.Memory.Segments))
Expand Down
79 changes: 46 additions & 33 deletions pkg/hintrunner/operand.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,35 +16,35 @@ import (
type CellRefer interface {
fmt.Stringer

Get(vm *VM.VirtualMachine) (*memory.Cell, error)
Get(vm *VM.VirtualMachine) (memory.MemoryAddress, error)
}

type ApCellRef int16

func (ap ApCellRef) String() string {
return "ApCellRef"
return fmt.Sprintf("ApCellRef(%d)", ap)
}

func (ap ApCellRef) Get(vm *VM.VirtualMachine) (*memory.Cell, error) {
func (ap ApCellRef) Get(vm *VM.VirtualMachine) (memory.MemoryAddress, error) {
res, overflow := safemath.SafeOffset(vm.Context.Ap, int16(ap))
if overflow {
return nil, safemath.NewSafeOffsetError(vm.Context.Ap, int16(ap))
return memory.MemoryAddress{}, safemath.NewSafeOffsetError(vm.Context.Ap, int16(ap))
}
return vm.MemoryManager.Memory.Peek(VM.ExecutionSegment, res)
return memory.MemoryAddress{SegmentIndex: VM.ExecutionSegment, Offset: res}, nil
}

type FpCellRef int16

func (fp FpCellRef) String() string {
return "FpCellRef"
return fmt.Sprintf("FpCellRef(%d)", fp)
}

func (fp FpCellRef) Get(vm *VM.VirtualMachine) (*memory.Cell, error) {
func (fp FpCellRef) Get(vm *VM.VirtualMachine) (memory.MemoryAddress, error) {
res, overflow := safemath.SafeOffset(vm.Context.Fp, int16(fp))
if overflow {
return nil, safemath.NewSafeOffsetError(vm.Context.Ap, int16(fp))
return memory.MemoryAddress{}, safemath.NewSafeOffsetError(vm.Context.Ap, int16(fp))
}
return vm.MemoryManager.Memory.Peek(VM.ExecutionSegment, res)
return memory.MemoryAddress{SegmentIndex: VM.ExecutionSegment, Offset: res}, nil
}

//
Expand All @@ -53,7 +53,7 @@ func (fp FpCellRef) Get(vm *VM.VirtualMachine) (*memory.Cell, error) {
type ResOperander interface {
fmt.Stringer

Resolve(vm *VM.VirtualMachine) (*memory.MemoryValue, error)
Resolve(vm *VM.VirtualMachine) (memory.MemoryValue, error)
}

type Deref struct {
Expand All @@ -64,41 +64,47 @@ func (deref Deref) String() string {
return "Deref"
}

func (deref Deref) Resolve(vm *VM.VirtualMachine) (*memory.MemoryValue, error) {
cell, err := deref.deref.Get(vm)
func (deref Deref) Resolve(vm *VM.VirtualMachine) (memory.MemoryValue, error) {
address, err := deref.deref.Get(vm)
if err != nil {
return nil, fmt.Errorf("get cell: %v", err)
return memory.MemoryValue{}, fmt.Errorf("get cell: %w", err)
}
return cell.Read(), nil
return vm.MemoryManager.Memory.ReadFromAddress(&address)
}

type DoubleDeref struct {
deref CellRefer
offset int16
}

func (dderef DoubleDeref) Resolve(vm *VM.VirtualMachine) (*memory.MemoryValue, error) {
cell, err := dderef.deref.Get(vm)
func (dderef DoubleDeref) Resolve(vm *VM.VirtualMachine) (memory.MemoryValue, error) {
lhsAddr, err := dderef.deref.Get(vm)
if err != nil {
return nil, fmt.Errorf("get cell: %v", err)
return memory.MemoryValue{}, fmt.Errorf("get lhs address %s: %w", lhsAddr, err)
}
lhs, err := vm.MemoryManager.Memory.ReadFromAddress(&lhsAddr)
if err != nil {
return memory.MemoryValue{}, fmt.Errorf("read lhs address %s: %w", lhsAddr, err)
}
lhs := cell.Read()

// Double deref implies the first value read must be an address
// Double deref implies the left hand side read must be an address
address, err := lhs.ToMemoryAddress()
if err != nil {
return nil, err
return memory.MemoryValue{}, err
}

newOffset, overflow := safemath.SafeOffset(address.Offset, dderef.offset)
if overflow {
return nil, safemath.NewSafeOffsetError(address.Offset, dderef.offset)
return memory.MemoryValue{}, safemath.NewSafeOffsetError(address.Offset, dderef.offset)
}
resAddr := memory.MemoryAddress{
SegmentIndex: address.SegmentIndex,
Offset: newOffset,
}
resAddr := memory.NewMemoryAddress(address.SegmentIndex, newOffset)

value, err := vm.MemoryManager.Memory.ReadFromAddress(resAddr)
value, err := vm.MemoryManager.Memory.ReadFromAddress(&resAddr)
if err != nil {
return nil, fmt.Errorf("read cell: %v", err)
return memory.MemoryValue{}, fmt.Errorf("read result at %s: %w", resAddr, err)
}

return value, nil
Expand All @@ -112,7 +118,7 @@ func (imm Immediate) String() string {

// todo(rodro): Specs from Starkware stablish this can be uint256 and not a felt.
// Should we respect that, or go straight to felt?
func (imm Immediate) Resolve(vm *VM.VirtualMachine) (*memory.MemoryValue, error) {
func (imm Immediate) Resolve(vm *VM.VirtualMachine) (memory.MemoryValue, error) {
felt := &f.Element{}
bigInt := (big.Int)(imm)
// todo(rodro): do we require to check that big int is lesser than P, or do we
Expand All @@ -139,24 +145,31 @@ func (bop BinaryOp) String() string {
return "BinaryOperator"
}

func (bop BinaryOp) Resolve(vm *VM.VirtualMachine) (*memory.MemoryValue, error) {
cell, err := bop.lhs.Get(vm)
func (bop BinaryOp) Resolve(vm *VM.VirtualMachine) (memory.MemoryValue, error) {
lhsAddr, err := bop.lhs.Get(vm)
if err != nil {
return memory.MemoryValue{}, fmt.Errorf("get lhs address %s: %w", bop.lhs, err)
}
lhs, err := vm.MemoryManager.Memory.ReadFromAddress(&lhsAddr)
if err != nil {
return nil, fmt.Errorf("get lhs operand %s: %v", bop.lhs, err)
return memory.MemoryValue{}, fmt.Errorf("read lhs address %s: %v", lhsAddr, err)
}
lhs := cell.Read()

rhs, err := bop.rhs.Resolve(vm)
if err != nil {
return nil, fmt.Errorf("resolve rhs operand %s: %v", rhs, err)
return memory.MemoryValue{}, fmt.Errorf("resolve rhs operand %s: %v", rhs, err)
}

switch bop.operator {
case Add:
return memory.EmptyMemoryValueAs(lhs.IsAddress()).Add(lhs, rhs)
mv := memory.EmptyMemoryValueAs(lhs.IsAddress() || rhs.IsAddress())
err := mv.Add(&lhs, &rhs)
return mv, err
case Mul:
return memory.EmptyMemoryValueAsFelt().Mul(lhs, rhs)
mv := memory.EmptyMemoryValueAsFelt()
err := mv.Mul(&lhs, &rhs)
return mv, err
default:
return nil, fmt.Errorf("unknown binary operator: %d", bop.operator)
return memory.MemoryValue{}, fmt.Errorf("unknown binary operator: %d", bop.operator)
}
}
17 changes: 10 additions & 7 deletions pkg/hintrunner/operand_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,12 +14,14 @@ func TestGetAp(t *testing.T) {
vm.Context.Ap = 5
writeTo(vm, VM.ExecutionSegment, vm.Context.Ap+7, memory.MemoryValueFromInt(11))

var apCell ApCellRef = 7
cell, err := apCell.Get(vm)
var apReg ApCellRef = 7
apAddr, err := apReg.Get(vm)

require.NoError(t, err)

value, err := vm.MemoryManager.Memory.ReadFromAddress(&apAddr)
require.NoError(t, err)

value := cell.Read()
require.Equal(t, memory.MemoryValueFromInt(11), value)
}

Expand All @@ -28,12 +30,13 @@ func TestGetFp(t *testing.T) {
vm.Context.Fp = 15
writeTo(vm, VM.ExecutionSegment, vm.Context.Fp-7, memory.MemoryValueFromInt(11))

var fpCell FpCellRef = -7
cell, err := fpCell.Get(vm)
var fpReg FpCellRef = -7
fpAddr, err := fpReg.Get(vm)
require.NoError(t, err)

value, err := vm.MemoryManager.Memory.ReadFromAddress(&fpAddr)
require.NoError(t, err)

value := cell.Read()
require.Equal(t, memory.MemoryValueFromInt(11), value)
}

Expand All @@ -46,8 +49,8 @@ func TestResolveDeref(t *testing.T) {
deref := Deref{apCell}

value, err := deref.Resolve(vm)

require.NoError(t, err)

require.Equal(t, memory.MemoryValueFromInt(11), value)
}

Expand Down
6 changes: 3 additions & 3 deletions pkg/hintrunner/testutils.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,11 +11,11 @@ func defaultVirtualMachine() *VM.VirtualMachine {
return vm
}

func writeTo(vm *VM.VirtualMachine, segment uint64, offset uint64, val *memory.MemoryValue) {
_ = vm.MemoryManager.Memory.Write(segment, offset, val)
func writeTo(vm *VM.VirtualMachine, segment uint64, offset uint64, val memory.MemoryValue) {
_ = vm.MemoryManager.Memory.Write(segment, offset, &val)
}

func readFrom(vm *VM.VirtualMachine, segment uint64, offset uint64) *memory.MemoryValue {
func readFrom(vm *VM.VirtualMachine, segment uint64, offset uint64) memory.MemoryValue {
val, _ := vm.MemoryManager.Memory.Read(segment, offset)
return val
}
Loading