Skip to content

Commit

Permalink
Reduce allocations on VM instruction loop (#79)
Browse files Browse the repository at this point in the history
* Less pointers more value types in memory package

* Make vm operate on memory addresses instead of cells

* remove hint runner for the moment

* make vm.Pc a value

* UknownValue constant

* Restore hintrunner

* update hintrunner to use optimized mem

* Restore hintrunner to zero runner

* Linting on tests

---------

Co-authored-by: Rodrigo <rodrodpino@gmail.com>
  • Loading branch information
omerfirmak and rodrigo-pino committed Oct 2, 2023
1 parent 2301ad6 commit b13df19
Show file tree
Hide file tree
Showing 15 changed files with 597 additions and 680 deletions.
21 changes: 11 additions & 10 deletions pkg/hintrunner/hint.go
Original file line number Diff line number Diff line change
Expand Up @@ -26,14 +26,14 @@ func (hint AllocSegment) Execute(vm *VM.VirtualMachine) error {
segmentIndex := vm.MemoryManager.Memory.AllocateEmptySegment()
memAddress := memory.MemoryValueFromSegmentAndOffset(segmentIndex, 0)

cell, err := hint.dst.Get(vm)
regAddr, err := hint.dst.Get(vm)
if err != nil {
return fmt.Errorf("get destination cell: %v", err)
return fmt.Errorf("get register %s: %w", hint.dst, err)
}

err = cell.Write(memAddress)
err = vm.MemoryManager.Memory.WriteToAddress(&regAddr, &memAddress)
if err != nil {
return fmt.Errorf("write cell: %v", err)
return fmt.Errorf("write to address %s: %v", regAddr, err)
}

return nil
Expand All @@ -52,12 +52,12 @@ func (hint TestLessThan) String() string {
func (hint TestLessThan) Execute(vm *VM.VirtualMachine) error {
lhsVal, err := hint.lhs.Resolve(vm)
if err != nil {
return fmt.Errorf("resolve lhs operand %s: %v", hint.lhs, err)
return fmt.Errorf("resolve lhs operand %s: %w", hint.lhs, err)
}

rhsVal, err := hint.rhs.Resolve(vm)
if err != nil {
return fmt.Errorf("resolve rhs operand %s: %v", hint.rhs, err)
return fmt.Errorf("resolve rhs operand %s: %w", hint.rhs, err)
}

lhsFelt, err := lhsVal.ToFieldElement()
Expand All @@ -75,14 +75,15 @@ func (hint TestLessThan) Execute(vm *VM.VirtualMachine) error {
resFelt.SetOne()
}

dstCell, err := hint.dst.Get(vm)
dstAddr, err := hint.dst.Get(vm)
if err != nil {
return fmt.Errorf("get destination cell: %v", err)
return fmt.Errorf("get dst address %s: %w", dstAddr, err)
}

err = dstCell.Write(memory.MemoryValueFromFieldElement(&resFelt))
mv := memory.MemoryValueFromFieldElement(&resFelt)
err = vm.MemoryManager.Memory.WriteToAddress(&dstAddr, &mv)
if err != nil {
return fmt.Errorf("write cell: %v", err)
return fmt.Errorf("write to dst address %s: %w", dstAddr, err)
}

return nil
Expand Down
10 changes: 8 additions & 2 deletions pkg/hintrunner/hintrunner_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,10 @@ func TestExistingHint(t *testing.T) {
10: allocHint,
})

vm.Context.Pc = memory.NewMemoryAddress(0, 10)
vm.Context.Pc = memory.MemoryAddress{
SegmentIndex: 0,
Offset: 10,
}
err := hr.RunHint(vm)
require.Nil(t, err)
require.Equal(
Expand All @@ -40,7 +43,10 @@ func TestNoHint(t *testing.T) {
10: allocHint,
})

vm.Context.Pc = memory.NewMemoryAddress(0, 100)
vm.Context.Pc = memory.MemoryAddress{
SegmentIndex: 0,
Offset: 100,
}
err := hr.RunHint(vm)
require.Nil(t, err)
require.Equal(t, 2, len(vm.MemoryManager.Memory.Segments))
Expand Down
79 changes: 46 additions & 33 deletions pkg/hintrunner/operand.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,35 +16,35 @@ import (
type CellRefer interface {
fmt.Stringer

Get(vm *VM.VirtualMachine) (*memory.Cell, error)
Get(vm *VM.VirtualMachine) (memory.MemoryAddress, error)
}

type ApCellRef int16

func (ap ApCellRef) String() string {
return "ApCellRef"
return fmt.Sprintf("ApCellRef(%d)", ap)
}

func (ap ApCellRef) Get(vm *VM.VirtualMachine) (*memory.Cell, error) {
func (ap ApCellRef) Get(vm *VM.VirtualMachine) (memory.MemoryAddress, error) {
res, overflow := safemath.SafeOffset(vm.Context.Ap, int16(ap))
if overflow {
return nil, safemath.NewSafeOffsetError(vm.Context.Ap, int16(ap))
return memory.MemoryAddress{}, safemath.NewSafeOffsetError(vm.Context.Ap, int16(ap))
}
return vm.MemoryManager.Memory.Peek(VM.ExecutionSegment, res)
return memory.MemoryAddress{SegmentIndex: VM.ExecutionSegment, Offset: res}, nil
}

type FpCellRef int16

func (fp FpCellRef) String() string {
return "FpCellRef"
return fmt.Sprintf("FpCellRef(%d)", fp)
}

func (fp FpCellRef) Get(vm *VM.VirtualMachine) (*memory.Cell, error) {
func (fp FpCellRef) Get(vm *VM.VirtualMachine) (memory.MemoryAddress, error) {
res, overflow := safemath.SafeOffset(vm.Context.Fp, int16(fp))
if overflow {
return nil, safemath.NewSafeOffsetError(vm.Context.Ap, int16(fp))
return memory.MemoryAddress{}, safemath.NewSafeOffsetError(vm.Context.Ap, int16(fp))
}
return vm.MemoryManager.Memory.Peek(VM.ExecutionSegment, res)
return memory.MemoryAddress{SegmentIndex: VM.ExecutionSegment, Offset: res}, nil
}

//
Expand All @@ -53,7 +53,7 @@ func (fp FpCellRef) Get(vm *VM.VirtualMachine) (*memory.Cell, error) {
type ResOperander interface {
fmt.Stringer

Resolve(vm *VM.VirtualMachine) (*memory.MemoryValue, error)
Resolve(vm *VM.VirtualMachine) (memory.MemoryValue, error)
}

type Deref struct {
Expand All @@ -64,41 +64,47 @@ func (deref Deref) String() string {
return "Deref"
}

func (deref Deref) Resolve(vm *VM.VirtualMachine) (*memory.MemoryValue, error) {
cell, err := deref.deref.Get(vm)
func (deref Deref) Resolve(vm *VM.VirtualMachine) (memory.MemoryValue, error) {
address, err := deref.deref.Get(vm)
if err != nil {
return nil, fmt.Errorf("get cell: %v", err)
return memory.MemoryValue{}, fmt.Errorf("get cell: %w", err)
}
return cell.Read(), nil
return vm.MemoryManager.Memory.ReadFromAddress(&address)
}

type DoubleDeref struct {
deref CellRefer
offset int16
}

func (dderef DoubleDeref) Resolve(vm *VM.VirtualMachine) (*memory.MemoryValue, error) {
cell, err := dderef.deref.Get(vm)
func (dderef DoubleDeref) Resolve(vm *VM.VirtualMachine) (memory.MemoryValue, error) {
lhsAddr, err := dderef.deref.Get(vm)
if err != nil {
return nil, fmt.Errorf("get cell: %v", err)
return memory.MemoryValue{}, fmt.Errorf("get lhs address %s: %w", lhsAddr, err)
}
lhs, err := vm.MemoryManager.Memory.ReadFromAddress(&lhsAddr)
if err != nil {
return memory.MemoryValue{}, fmt.Errorf("read lhs address %s: %w", lhsAddr, err)
}
lhs := cell.Read()

// Double deref implies the first value read must be an address
// Double deref implies the left hand side read must be an address
address, err := lhs.ToMemoryAddress()
if err != nil {
return nil, err
return memory.MemoryValue{}, err
}

newOffset, overflow := safemath.SafeOffset(address.Offset, dderef.offset)
if overflow {
return nil, safemath.NewSafeOffsetError(address.Offset, dderef.offset)
return memory.MemoryValue{}, safemath.NewSafeOffsetError(address.Offset, dderef.offset)
}
resAddr := memory.MemoryAddress{
SegmentIndex: address.SegmentIndex,
Offset: newOffset,
}
resAddr := memory.NewMemoryAddress(address.SegmentIndex, newOffset)

value, err := vm.MemoryManager.Memory.ReadFromAddress(resAddr)
value, err := vm.MemoryManager.Memory.ReadFromAddress(&resAddr)
if err != nil {
return nil, fmt.Errorf("read cell: %v", err)
return memory.MemoryValue{}, fmt.Errorf("read result at %s: %w", resAddr, err)
}

return value, nil
Expand All @@ -112,7 +118,7 @@ func (imm Immediate) String() string {

// todo(rodro): Specs from Starkware stablish this can be uint256 and not a felt.
// Should we respect that, or go straight to felt?
func (imm Immediate) Resolve(vm *VM.VirtualMachine) (*memory.MemoryValue, error) {
func (imm Immediate) Resolve(vm *VM.VirtualMachine) (memory.MemoryValue, error) {
felt := &f.Element{}
bigInt := (big.Int)(imm)
// todo(rodro): do we require to check that big int is lesser than P, or do we
Expand All @@ -139,24 +145,31 @@ func (bop BinaryOp) String() string {
return "BinaryOperator"
}

func (bop BinaryOp) Resolve(vm *VM.VirtualMachine) (*memory.MemoryValue, error) {
cell, err := bop.lhs.Get(vm)
func (bop BinaryOp) Resolve(vm *VM.VirtualMachine) (memory.MemoryValue, error) {
lhsAddr, err := bop.lhs.Get(vm)
if err != nil {
return memory.MemoryValue{}, fmt.Errorf("get lhs address %s: %w", bop.lhs, err)
}
lhs, err := vm.MemoryManager.Memory.ReadFromAddress(&lhsAddr)
if err != nil {
return nil, fmt.Errorf("get lhs operand %s: %v", bop.lhs, err)
return memory.MemoryValue{}, fmt.Errorf("read lhs address %s: %v", lhsAddr, err)
}
lhs := cell.Read()

rhs, err := bop.rhs.Resolve(vm)
if err != nil {
return nil, fmt.Errorf("resolve rhs operand %s: %v", rhs, err)
return memory.MemoryValue{}, fmt.Errorf("resolve rhs operand %s: %v", rhs, err)
}

switch bop.operator {
case Add:
return memory.EmptyMemoryValueAs(lhs.IsAddress()).Add(lhs, rhs)
mv := memory.EmptyMemoryValueAs(lhs.IsAddress() || rhs.IsAddress())
err := mv.Add(&lhs, &rhs)
return mv, err
case Mul:
return memory.EmptyMemoryValueAsFelt().Mul(lhs, rhs)
mv := memory.EmptyMemoryValueAsFelt()
err := mv.Mul(&lhs, &rhs)
return mv, err
default:
return nil, fmt.Errorf("unknown binary operator: %d", bop.operator)
return memory.MemoryValue{}, fmt.Errorf("unknown binary operator: %d", bop.operator)
}
}
17 changes: 10 additions & 7 deletions pkg/hintrunner/operand_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,12 +14,14 @@ func TestGetAp(t *testing.T) {
vm.Context.Ap = 5
writeTo(vm, VM.ExecutionSegment, vm.Context.Ap+7, memory.MemoryValueFromInt(11))

var apCell ApCellRef = 7
cell, err := apCell.Get(vm)
var apReg ApCellRef = 7
apAddr, err := apReg.Get(vm)

require.NoError(t, err)

value, err := vm.MemoryManager.Memory.ReadFromAddress(&apAddr)
require.NoError(t, err)

value := cell.Read()
require.Equal(t, memory.MemoryValueFromInt(11), value)
}

Expand All @@ -28,12 +30,13 @@ func TestGetFp(t *testing.T) {
vm.Context.Fp = 15
writeTo(vm, VM.ExecutionSegment, vm.Context.Fp-7, memory.MemoryValueFromInt(11))

var fpCell FpCellRef = -7
cell, err := fpCell.Get(vm)
var fpReg FpCellRef = -7
fpAddr, err := fpReg.Get(vm)
require.NoError(t, err)

value, err := vm.MemoryManager.Memory.ReadFromAddress(&fpAddr)
require.NoError(t, err)

value := cell.Read()
require.Equal(t, memory.MemoryValueFromInt(11), value)
}

Expand All @@ -46,8 +49,8 @@ func TestResolveDeref(t *testing.T) {
deref := Deref{apCell}

value, err := deref.Resolve(vm)

require.NoError(t, err)

require.Equal(t, memory.MemoryValueFromInt(11), value)
}

Expand Down
6 changes: 3 additions & 3 deletions pkg/hintrunner/testutils.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,11 +11,11 @@ func defaultVirtualMachine() *VM.VirtualMachine {
return vm
}

func writeTo(vm *VM.VirtualMachine, segment uint64, offset uint64, val *memory.MemoryValue) {
_ = vm.MemoryManager.Memory.Write(segment, offset, val)
func writeTo(vm *VM.VirtualMachine, segment uint64, offset uint64, val memory.MemoryValue) {
_ = vm.MemoryManager.Memory.Write(segment, offset, &val)
}

func readFrom(vm *VM.VirtualMachine, segment uint64, offset uint64) *memory.MemoryValue {
func readFrom(vm *VM.VirtualMachine, segment uint64, offset uint64) memory.MemoryValue {
val, _ := vm.MemoryManager.Memory.Read(segment, offset)
return val
}
Loading

0 comments on commit b13df19

Please sign in to comment.