Skip to content

Commit

Permalink
implement cairo0 hints needed for assert_le_felt (#203)
Browse files Browse the repository at this point in the history
In turn, `assert_le_felt` is needed for some other
functions that use cairo0 hints, like `is_nn`.
  • Loading branch information
quasilyte committed Feb 14, 2024
1 parent 4c3b130 commit d9abd6d
Show file tree
Hide file tree
Showing 7 changed files with 113 additions and 31 deletions.
18 changes: 13 additions & 5 deletions integration_tests/cairozero_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ func TestCairoZeroFiles(t *testing.T) {
continue
}

pyTraceFile, pyMemoryFile, err := runPythonVm(compiledOutput)
pyTraceFile, pyMemoryFile, err := runPythonVm(dirEntry.Name(), compiledOutput)
if err != nil {
t.Error(err)
continue
Expand Down Expand Up @@ -114,20 +114,28 @@ func compileZeroCode(path string) (string, error) {

// given a path to a compiled cairo zero file, execute it using the
// python vm and returns the trace and memory files location
func runPythonVm(path string) (string, string, error) {
func runPythonVm(testFilename, path string) (string, string, error) {
traceOutput := swapExtenstion(path, pyTraceSuffix)
memoryOutput := swapExtenstion(path, pyMemorySuffix)

cmd := exec.Command(
"cairo-run",
args := []string{
"--program",
path,
"--proof_mode",
"--trace_file",
traceOutput,
"--memory_file",
memoryOutput,
)
}

// If any other layouts are needed, add the suffix checks here.
// The convention would be: ".$layout.cairo"
// A file without this suffix will use the default ("plain") layout.
if strings.HasSuffix(testFilename, ".small.cairo") {
args = append(args, "--layout", "small")
}

cmd := exec.Command("cairo-run", args...)

res, err := cmd.CombinedOutput()
if err != nil {
Expand Down
24 changes: 12 additions & 12 deletions pkg/hintrunner/core/hint.go
Original file line number Diff line number Diff line change
Expand Up @@ -1319,9 +1319,9 @@ func (hint *AllocConstantSize) Execute(vm *VM.VirtualMachine, ctx *hinter.HintRu
}

type AssertLeFindSmallArc struct {
a hinter.ResOperander
b hinter.ResOperander
rangeCheckPtr hinter.ResOperander
A hinter.ResOperander
B hinter.ResOperander
RangeCheckPtr hinter.ResOperander
}

func (hint *AssertLeFindSmallArc) String() string {
Expand All @@ -1332,14 +1332,14 @@ func (hint *AssertLeFindSmallArc) Execute(vm *VM.VirtualMachine, ctx *hinter.Hin
primeOver3High := uint256.Int{6148914691236517206, 192153584101141168, 0, 0}
primeOver2High := uint256.Int{9223372036854775809, 288230376151711752, 0, 0}

a, err := hint.a.Resolve(vm)
a, err := hint.A.Resolve(vm)
if err != nil {
return fmt.Errorf("resolve a operand %s: %w", hint.a, err)
return fmt.Errorf("resolve a operand %s: %w", hint.A, err)
}

b, err := hint.b.Resolve(vm)
b, err := hint.B.Resolve(vm)
if err != nil {
return fmt.Errorf("resolve b operand %s: %w", hint.b, err)
return fmt.Errorf("resolve b operand %s: %w", hint.B, err)
}

aFelt, err := a.FieldElement()
Expand Down Expand Up @@ -1374,7 +1374,7 @@ func (hint *AssertLeFindSmallArc) Execute(vm *VM.VirtualMachine, ctx *hinter.Hin
// Exclude the largest arc after sorting
ctx.ExcludedArc = lengthsAndIndices[2].Position

rangeCheckPtrMemAddr, err := hinter.ResolveAsAddress(vm, hint.rangeCheckPtr)
rangeCheckPtrMemAddr, err := hinter.ResolveAsAddress(vm, hint.RangeCheckPtr)
if err != nil {
return fmt.Errorf("resolve range check pointer: %w", err)
}
Expand Down Expand Up @@ -1434,15 +1434,15 @@ func (hint *AssertLeFindSmallArc) Execute(vm *VM.VirtualMachine, ctx *hinter.Hin
}

type AssertLeIsFirstArcExcluded struct {
skipExcludeAFlag hinter.CellRefer
SkipExcludeAFlag hinter.CellRefer
}

func (hint *AssertLeIsFirstArcExcluded) String() string {
return "AssertLeIsFirstArcExcluded"
}

func (hint *AssertLeIsFirstArcExcluded) Execute(vm *VM.VirtualMachine, ctx *hinter.HintRunnerContext) error {
addr, err := hint.skipExcludeAFlag.Get(vm)
addr, err := hint.SkipExcludeAFlag.Get(vm)
if err != nil {
return fmt.Errorf("get skipExcludeAFlag addr: %v", err)
}
Expand All @@ -1458,15 +1458,15 @@ func (hint *AssertLeIsFirstArcExcluded) Execute(vm *VM.VirtualMachine, ctx *hint
}

type AssertLeIsSecondArcExcluded struct {
skipExcludeBMinusA hinter.CellRefer
SkipExcludeBMinusA hinter.CellRefer
}

func (hint *AssertLeIsSecondArcExcluded) String() string {
return "AssertLeIsSecondArcExcluded"
}

func (hint *AssertLeIsSecondArcExcluded) Execute(vm *VM.VirtualMachine, ctx *hinter.HintRunnerContext) error {
addr, err := hint.skipExcludeBMinusA.Get(vm)
addr, err := hint.SkipExcludeBMinusA.Get(vm)
if err != nil {
return fmt.Errorf("get skipExcludeBMinusA addr: %v", err)
}
Expand Down
10 changes: 5 additions & 5 deletions pkg/hintrunner/core/hint_benchmark_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -328,7 +328,7 @@ func BenchmarkAssertLeIsFirstArcExcluded(b *testing.B) {
for i := 0; i < b.N; i++ {

hint := AssertLeIsFirstArcExcluded{
skipExcludeAFlag: skipExcludeAFlag,
SkipExcludeAFlag: skipExcludeAFlag,
}

err := hint.Execute(vm, &ctx)
Expand Down Expand Up @@ -356,7 +356,7 @@ func BenchmarkAssertLeIsSecondArcExcluded(b *testing.B) {
for i := 0; i < b.N; i++ {

hint := AssertLeIsSecondArcExcluded{
skipExcludeBMinusA: skipExcludeBMinusA,
SkipExcludeBMinusA: skipExcludeBMinusA,
}

err := hint.Execute(vm, &ctx)
Expand Down Expand Up @@ -392,9 +392,9 @@ func BenchmarkAssertLeFindSmallArc(b *testing.B) {
r1 := utils.RandomFeltElement(rand)
r2 := utils.RandomFeltElement(rand)
hint := AssertLeFindSmallArc{
a: hinter.Immediate(r1),
b: hinter.Immediate(r2),
rangeCheckPtr: hinter.Deref{Deref: hinter.ApCellRef(0)},
A: hinter.Immediate(r1),
B: hinter.Immediate(r2),
RangeCheckPtr: hinter.Deref{Deref: hinter.ApCellRef(0)},
}

if err := hint.Execute(vm, &ctx); err != nil &&
Expand Down
10 changes: 5 additions & 5 deletions pkg/hintrunner/core/hint_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -987,9 +987,9 @@ func TestAssertLeFindSmallArc(t *testing.T) {
utils.WriteTo(vm, VM.ExecutionSegment, vm.Context.Ap, mem.MemoryValueFromMemoryAddress(&addr))

hint := AssertLeFindSmallArc{
a: hinter.Immediate(tc.aFelt),
b: hinter.Immediate(tc.bFelt),
rangeCheckPtr: hinter.Deref{Deref: hinter.ApCellRef(0)},
A: hinter.Immediate(tc.aFelt),
B: hinter.Immediate(tc.bFelt),
RangeCheckPtr: hinter.Deref{Deref: hinter.ApCellRef(0)},
}

ctx := hinter.HintRunnerContext{
Expand Down Expand Up @@ -1027,7 +1027,7 @@ func TestAssertLeIsFirstArcExcluded(t *testing.T) {
var flag hinter.ApCellRef = 0

hint := AssertLeIsFirstArcExcluded{
skipExcludeAFlag: flag,
SkipExcludeAFlag: flag,
}

err := hint.Execute(vm, &ctx)
Expand All @@ -1053,7 +1053,7 @@ func TestAssertLeIsSecondArcExcluded(t *testing.T) {
var flag hinter.ApCellRef = 0

hint := AssertLeIsSecondArcExcluded{
skipExcludeBMinusA: flag,
SkipExcludeBMinusA: flag,
}

err := hint.Execute(vm, &ctx)
Expand Down
14 changes: 14 additions & 0 deletions pkg/hintrunner/zero/hint_reference_resolver.go
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,20 @@ func (m *hintReferenceResolver) GetReference(name string) (hinter.Reference, err
return nil, fmt.Errorf("missing reference %s", name)
}

// GetResOperander returns the result of GetReference type-asserted to ResOperander.
// If reference is not found or it's not of ResOperander type, a non-nil error is returned.
func (m *hintReferenceResolver) GetResOperander(name string) (hinter.ResOperander, error) {
ref, err := m.GetReference(name)
if err != nil {
return nil, err
}
op, ok := ref.(hinter.ResOperander)
if !ok {
return nil, fmt.Errorf("expected %s to be ResOperander (got %T)", name, ref)
}
return op, nil
}

// shortSymbolName turns a full symbol name like "a.b.c" into just "c".
func shortSymbolName(name string) string {
i := strings.LastIndexByte(name, '.')
Expand Down
10 changes: 8 additions & 2 deletions pkg/hintrunner/zero/hintcode.go
Original file line number Diff line number Diff line change
@@ -1,10 +1,16 @@
package zero

const (
AllocSegmentCode string = "memory[ap] = segments.add()"
allocSegmentCode string = "memory[ap] = segments.add()"

// This is a very simple Cairo0 hint that allows us to test
// the identifier resolution code.
// Depending on the context, ids.a may be a complex reference.
TestAssignCode string = "memory[ap] = ids.a"
testAssignCode string = "memory[ap] = ids.a"

// assert_le_felt() hints.
assertLeFeltCode string = "import itertools\n\nfrom starkware.cairo.common.math_utils import assert_integer\nassert_integer(ids.a)\nassert_integer(ids.b)\na = ids.a % PRIME\nb = ids.b % PRIME\nassert a <= b, f'a = {a} is not less than or equal to b = {b}.'\n\n# Find an arc less than PRIME / 3, and another less than PRIME / 2.\nlengths_and_indices = [(a, 0), (b - a, 1), (PRIME - 1 - b, 2)]\nlengths_and_indices.sort()\nassert lengths_and_indices[0][0] <= PRIME // 3 and lengths_and_indices[1][0] <= PRIME // 2\nexcluded = lengths_and_indices[2][1]\n\nmemory[ids.range_check_ptr + 1], memory[ids.range_check_ptr + 0] = (\n divmod(lengths_and_indices[0][0], ids.PRIME_OVER_3_HIGH))\nmemory[ids.range_check_ptr + 3], memory[ids.range_check_ptr + 2] = (\n divmod(lengths_and_indices[1][0], ids.PRIME_OVER_2_HIGH))"
assertLeFeltExcluded0Code string = "memory[ap] = 1 if excluded != 0 else 0"
assertLeFeltExcluded1Code string = "memory[ap] = 1 if excluded != 1 else 0"
assertLeFeltExcluded2Code string = "assert excluded == 2"
)
58 changes: 56 additions & 2 deletions pkg/hintrunner/zero/zerohint.go
Original file line number Diff line number Diff line change
Expand Up @@ -52,10 +52,18 @@ func GetHintFromCode(program *zero.ZeroProgram, rawHint zero.Hint, hintPC uint64
}

switch rawHint.Code {
case AllocSegmentCode:
case allocSegmentCode:
return CreateAllocSegmentHinter(resolver)
case TestAssignCode:
case testAssignCode:
return createTestAssignHinter(resolver)
case assertLeFeltCode:
return createAssertLeFeltHinter(resolver)
case assertLeFeltExcluded0Code:
return createAssertLeFeltExcluded0Hinter(resolver)
case assertLeFeltExcluded1Code:
return createAssertLeFeltExcluded1Hinter(resolver)
case assertLeFeltExcluded2Code:
return createAssertLeFeltExcluded2Hinter(resolver)
default:
return nil, fmt.Errorf("Not identified hint")
}
Expand Down Expand Up @@ -90,6 +98,52 @@ func createTestAssignHinter(resolver hintReferenceResolver) (hinter.Hinter, erro
return h, nil
}

func createAssertLeFeltHinter(resolver hintReferenceResolver) (hinter.Hinter, error) {
a, err := resolver.GetResOperander("a")
if err != nil {
return nil, err
}
b, err := resolver.GetResOperander("b")
if err != nil {
return nil, err
}
rangeCheckPtr, err := resolver.GetResOperander("range_check_ptr")
if err != nil {
return nil, err
}

h := &core.AssertLeFindSmallArc{
A: a,
B: b,
RangeCheckPtr: rangeCheckPtr,
}
return h, nil
}

func createAssertLeFeltExcluded0Hinter(resolver hintReferenceResolver) (hinter.Hinter, error) {
return &core.AssertLeIsFirstArcExcluded{SkipExcludeAFlag: hinter.ApCellRef(0)}, nil
}

func createAssertLeFeltExcluded1Hinter(resolver hintReferenceResolver) (hinter.Hinter, error) {
return &core.AssertLeIsSecondArcExcluded{SkipExcludeBMinusA: hinter.ApCellRef(0)}, nil
}

func createAssertLeFeltExcluded2Hinter(resolver hintReferenceResolver) (hinter.Hinter, error) {
// This hint is Cairo0-specific.
// It only does a python-scoped variable named "excluded" assert.
// We store that variable inside a hinter context.
h := &GenericZeroHinter{
Name: "AssertLeFeltExcluded2",
Op: func(vm *VM.VirtualMachine, ctx *hinter.HintRunnerContext) error {
if ctx.ExcludedArc != 2 {
return fmt.Errorf("assertion `excluded == 2` failed")
}
return nil
},
}
return h, nil
}

func getParameters(zeroProgram *zero.ZeroProgram, hint zero.Hint, hintPC uint64) (hintReferenceResolver, error) {
resolver := NewReferenceResolver()

Expand Down

0 comments on commit d9abd6d

Please sign in to comment.