Skip to content

Commit

Permalink
implement cairo0 hints needed for assert_le_felt
Browse files Browse the repository at this point in the history
In turn, `assert_le_felt` is needed for some other
functions that use cairo0 hints, like `is_nn`.
  • Loading branch information
quasilyte committed Feb 13, 2024
1 parent 4c3b130 commit b046a42
Show file tree
Hide file tree
Showing 7 changed files with 109 additions and 22 deletions.
13 changes: 13 additions & 0 deletions integration_tests/cairo_files/math.cairo
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
%builtins range_check

from starkware.cairo.common.math import assert_le_felt

func main{range_check_ptr}() {
alloc_locals;
local v1 = 543;
local v2 = 657;

assert_le_felt(v1, v2);

ret;
}
24 changes: 12 additions & 12 deletions pkg/hintrunner/core/hint.go
Original file line number Diff line number Diff line change
Expand Up @@ -1319,9 +1319,9 @@ func (hint *AllocConstantSize) Execute(vm *VM.VirtualMachine, ctx *hinter.HintRu
}

type AssertLeFindSmallArc struct {
a hinter.ResOperander
b hinter.ResOperander
rangeCheckPtr hinter.ResOperander
A hinter.ResOperander
B hinter.ResOperander
RangeCheckPtr hinter.ResOperander
}

func (hint *AssertLeFindSmallArc) String() string {
Expand All @@ -1332,14 +1332,14 @@ func (hint *AssertLeFindSmallArc) Execute(vm *VM.VirtualMachine, ctx *hinter.Hin
primeOver3High := uint256.Int{6148914691236517206, 192153584101141168, 0, 0}
primeOver2High := uint256.Int{9223372036854775809, 288230376151711752, 0, 0}

a, err := hint.a.Resolve(vm)
a, err := hint.A.Resolve(vm)
if err != nil {
return fmt.Errorf("resolve a operand %s: %w", hint.a, err)
return fmt.Errorf("resolve a operand %s: %w", hint.A, err)
}

b, err := hint.b.Resolve(vm)
b, err := hint.B.Resolve(vm)
if err != nil {
return fmt.Errorf("resolve b operand %s: %w", hint.b, err)
return fmt.Errorf("resolve b operand %s: %w", hint.B, err)
}

aFelt, err := a.FieldElement()
Expand Down Expand Up @@ -1374,7 +1374,7 @@ func (hint *AssertLeFindSmallArc) Execute(vm *VM.VirtualMachine, ctx *hinter.Hin
// Exclude the largest arc after sorting
ctx.ExcludedArc = lengthsAndIndices[2].Position

rangeCheckPtrMemAddr, err := hinter.ResolveAsAddress(vm, hint.rangeCheckPtr)
rangeCheckPtrMemAddr, err := hinter.ResolveAsAddress(vm, hint.RangeCheckPtr)
if err != nil {
return fmt.Errorf("resolve range check pointer: %w", err)
}
Expand Down Expand Up @@ -1434,15 +1434,15 @@ func (hint *AssertLeFindSmallArc) Execute(vm *VM.VirtualMachine, ctx *hinter.Hin
}

type AssertLeIsFirstArcExcluded struct {
skipExcludeAFlag hinter.CellRefer
SkipExcludeAFlag hinter.CellRefer
}

func (hint *AssertLeIsFirstArcExcluded) String() string {
return "AssertLeIsFirstArcExcluded"
}

func (hint *AssertLeIsFirstArcExcluded) Execute(vm *VM.VirtualMachine, ctx *hinter.HintRunnerContext) error {
addr, err := hint.skipExcludeAFlag.Get(vm)
addr, err := hint.SkipExcludeAFlag.Get(vm)
if err != nil {
return fmt.Errorf("get skipExcludeAFlag addr: %v", err)
}
Expand All @@ -1458,15 +1458,15 @@ func (hint *AssertLeIsFirstArcExcluded) Execute(vm *VM.VirtualMachine, ctx *hint
}

type AssertLeIsSecondArcExcluded struct {
skipExcludeBMinusA hinter.CellRefer
SkipExcludeBMinusA hinter.CellRefer
}

func (hint *AssertLeIsSecondArcExcluded) String() string {
return "AssertLeIsSecondArcExcluded"
}

func (hint *AssertLeIsSecondArcExcluded) Execute(vm *VM.VirtualMachine, ctx *hinter.HintRunnerContext) error {
addr, err := hint.skipExcludeBMinusA.Get(vm)
addr, err := hint.SkipExcludeBMinusA.Get(vm)
if err != nil {
return fmt.Errorf("get skipExcludeBMinusA addr: %v", err)
}
Expand Down
10 changes: 5 additions & 5 deletions pkg/hintrunner/core/hint_benchmark_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -328,7 +328,7 @@ func BenchmarkAssertLeIsFirstArcExcluded(b *testing.B) {
for i := 0; i < b.N; i++ {

hint := AssertLeIsFirstArcExcluded{
skipExcludeAFlag: skipExcludeAFlag,
SkipExcludeAFlag: skipExcludeAFlag,
}

err := hint.Execute(vm, &ctx)
Expand Down Expand Up @@ -356,7 +356,7 @@ func BenchmarkAssertLeIsSecondArcExcluded(b *testing.B) {
for i := 0; i < b.N; i++ {

hint := AssertLeIsSecondArcExcluded{
skipExcludeBMinusA: skipExcludeBMinusA,
SkipExcludeBMinusA: skipExcludeBMinusA,
}

err := hint.Execute(vm, &ctx)
Expand Down Expand Up @@ -392,9 +392,9 @@ func BenchmarkAssertLeFindSmallArc(b *testing.B) {
r1 := utils.RandomFeltElement(rand)
r2 := utils.RandomFeltElement(rand)
hint := AssertLeFindSmallArc{
a: hinter.Immediate(r1),
b: hinter.Immediate(r2),
rangeCheckPtr: hinter.Deref{Deref: hinter.ApCellRef(0)},
A: hinter.Immediate(r1),
B: hinter.Immediate(r2),
RangeCheckPtr: hinter.Deref{Deref: hinter.ApCellRef(0)},
}

if err := hint.Execute(vm, &ctx); err != nil &&
Expand Down
10 changes: 5 additions & 5 deletions pkg/hintrunner/core/hint_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -987,9 +987,9 @@ func TestAssertLeFindSmallArc(t *testing.T) {
utils.WriteTo(vm, VM.ExecutionSegment, vm.Context.Ap, mem.MemoryValueFromMemoryAddress(&addr))

hint := AssertLeFindSmallArc{
a: hinter.Immediate(tc.aFelt),
b: hinter.Immediate(tc.bFelt),
rangeCheckPtr: hinter.Deref{Deref: hinter.ApCellRef(0)},
A: hinter.Immediate(tc.aFelt),
B: hinter.Immediate(tc.bFelt),
RangeCheckPtr: hinter.Deref{Deref: hinter.ApCellRef(0)},
}

ctx := hinter.HintRunnerContext{
Expand Down Expand Up @@ -1027,7 +1027,7 @@ func TestAssertLeIsFirstArcExcluded(t *testing.T) {
var flag hinter.ApCellRef = 0

hint := AssertLeIsFirstArcExcluded{
skipExcludeAFlag: flag,
SkipExcludeAFlag: flag,
}

err := hint.Execute(vm, &ctx)
Expand All @@ -1053,7 +1053,7 @@ func TestAssertLeIsSecondArcExcluded(t *testing.T) {
var flag hinter.ApCellRef = 0

hint := AssertLeIsSecondArcExcluded{
skipExcludeBMinusA: flag,
SkipExcludeBMinusA: flag,
}

err := hint.Execute(vm, &ctx)
Expand Down
14 changes: 14 additions & 0 deletions pkg/hintrunner/zero/hint_reference_resolver.go
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,20 @@ func (m *hintReferenceResolver) GetReference(name string) (hinter.Reference, err
return nil, fmt.Errorf("missing reference %s", name)
}

// GetResOperander returns the result of GetReference type-asserted to ResOperander.
// If reference is not found or it's not of ResOperander type, a non-nil error is returned.
func (m *hintReferenceResolver) GetResOperander(name string) (hinter.ResOperander, error) {
ref, err := m.GetReference(name)
if err != nil {
return nil, err
}
op, ok := ref.(hinter.ResOperander)
if !ok {
return nil, fmt.Errorf("expected %s to be ResOperander (got %T)", name, ref)
}
return op, nil
}

// shortSymbolName turns a full symbol name like "a.b.c" into just "c".
func shortSymbolName(name string) string {
i := strings.LastIndexByte(name, '.')
Expand Down
6 changes: 6 additions & 0 deletions pkg/hintrunner/zero/hintcode.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,4 +7,10 @@ const (
// the identifier resolution code.
// Depending on the context, ids.a may be a complex reference.
TestAssignCode string = "memory[ap] = ids.a"

// assert_le_felt() hints.
assertLeFelt string = "import itertools\n\nfrom starkware.cairo.common.math_utils import assert_integer\nassert_integer(ids.a)\nassert_integer(ids.b)\na = ids.a % PRIME\nb = ids.b % PRIME\nassert a <= b, f'a = {a} is not less than or equal to b = {b}.'\n\n# Find an arc less than PRIME / 3, and another less than PRIME / 2.\nlengths_and_indices = [(a, 0), (b - a, 1), (PRIME - 1 - b, 2)]\nlengths_and_indices.sort()\nassert lengths_and_indices[0][0] <= PRIME // 3 and lengths_and_indices[1][0] <= PRIME // 2\nexcluded = lengths_and_indices[2][1]\n\nmemory[ids.range_check_ptr + 1], memory[ids.range_check_ptr + 0] = (\n divmod(lengths_and_indices[0][0], ids.PRIME_OVER_3_HIGH))\nmemory[ids.range_check_ptr + 3], memory[ids.range_check_ptr + 2] = (\n divmod(lengths_and_indices[1][0], ids.PRIME_OVER_2_HIGH))"
assertLeFeltExcluded0 string = "memory[ap] = 1 if excluded != 0 else 0"
assertLeFeltExcluded1 string = "memory[ap] = 1 if excluded != 1 else 0"
assertLeFeltExcluded2 string = "assert excluded == 2"
)
54 changes: 54 additions & 0 deletions pkg/hintrunner/zero/zerohint.go
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,14 @@ func GetHintFromCode(program *zero.ZeroProgram, rawHint zero.Hint, hintPC uint64
return CreateAllocSegmentHinter(resolver)
case TestAssignCode:
return createTestAssignHinter(resolver)
case assertLeFelt:
return createAssertLeFeltHinter(resolver)
case assertLeFeltExcluded0:
return createAssertLeFeltExcluded0Hinter(resolver)
case assertLeFeltExcluded1:
return createAssertLeFeltExcluded1Hinter(resolver)
case assertLeFeltExcluded2:
return createAssertLeFeltExcluded2Hinter(resolver)
default:
return nil, fmt.Errorf("Not identified hint")
}
Expand Down Expand Up @@ -90,6 +98,52 @@ func createTestAssignHinter(resolver hintReferenceResolver) (hinter.Hinter, erro
return h, nil
}

func createAssertLeFeltHinter(resolver hintReferenceResolver) (hinter.Hinter, error) {
a, err := resolver.GetResOperander("a")
if err != nil {
return nil, err
}
b, err := resolver.GetResOperander("b")
if err != nil {
return nil, err
}
rangeCheckPtr, err := resolver.GetResOperander("range_check_ptr")
if err != nil {
return nil, err
}

h := &core.AssertLeFindSmallArc{
A: a,
B: b,
RangeCheckPtr: rangeCheckPtr,
}
return h, nil
}

func createAssertLeFeltExcluded0Hinter(resolver hintReferenceResolver) (hinter.Hinter, error) {
return &core.AssertLeIsFirstArcExcluded{SkipExcludeAFlag: hinter.ApCellRef(0)}, nil
}

func createAssertLeFeltExcluded1Hinter(resolver hintReferenceResolver) (hinter.Hinter, error) {
return &core.AssertLeIsSecondArcExcluded{SkipExcludeBMinusA: hinter.ApCellRef(0)}, nil
}

func createAssertLeFeltExcluded2Hinter(resolver hintReferenceResolver) (hinter.Hinter, error) {
// This hint is Cairo0-specific.
// It only does a python-scoped variable named "excluded" assert.
// We store that variable inside a hinter context.
h := &GenericZeroHinter{
Name: "AssertLeFeltExcluded2",
Op: func(vm *VM.VirtualMachine, ctx *hinter.HintRunnerContext) error {
if ctx.ExcludedArc != 2 {
return fmt.Errorf("assertion `excluded == 2` failed")
}
return nil
},
}
return h, nil
}

func getParameters(zeroProgram *zero.ZeroProgram, hint zero.Hint, hintPC uint64) (hintReferenceResolver, error) {
resolver := NewReferenceResolver()

Expand Down

0 comments on commit b046a42

Please sign in to comment.