Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

implement cairo0 hints needed for assert_le_felt #203

Merged
merged 1 commit into from
Feb 14, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 13 additions & 5 deletions integration_tests/cairozero_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ func TestCairoZeroFiles(t *testing.T) {
continue
}

pyTraceFile, pyMemoryFile, err := runPythonVm(compiledOutput)
pyTraceFile, pyMemoryFile, err := runPythonVm(dirEntry.Name(), compiledOutput)
if err != nil {
t.Error(err)
continue
Expand Down Expand Up @@ -114,20 +114,28 @@ func compileZeroCode(path string) (string, error) {

// given a path to a compiled cairo zero file, execute it using the
// python vm and returns the trace and memory files location
func runPythonVm(path string) (string, string, error) {
func runPythonVm(testFilename, path string) (string, string, error) {
traceOutput := swapExtenstion(path, pyTraceSuffix)
memoryOutput := swapExtenstion(path, pyMemorySuffix)

cmd := exec.Command(
"cairo-run",
args := []string{
"--program",
path,
"--proof_mode",
"--trace_file",
traceOutput,
"--memory_file",
memoryOutput,
)
}

// If any other layouts are needed, add the suffix checks here.
// The convention would be: ".$layout.cairo"
// A file without this suffix will use the default ("plain") layout.
if strings.HasSuffix(testFilename, ".small.cairo") {
args = append(args, "--layout", "small")
}

cmd := exec.Command("cairo-run", args...)

res, err := cmd.CombinedOutput()
if err != nil {
Expand Down
24 changes: 12 additions & 12 deletions pkg/hintrunner/core/hint.go
Original file line number Diff line number Diff line change
Expand Up @@ -1319,9 +1319,9 @@ func (hint *AllocConstantSize) Execute(vm *VM.VirtualMachine, ctx *hinter.HintRu
}

type AssertLeFindSmallArc struct {
a hinter.ResOperander
b hinter.ResOperander
rangeCheckPtr hinter.ResOperander
A hinter.ResOperander
B hinter.ResOperander
RangeCheckPtr hinter.ResOperander
}

func (hint *AssertLeFindSmallArc) String() string {
Expand All @@ -1332,14 +1332,14 @@ func (hint *AssertLeFindSmallArc) Execute(vm *VM.VirtualMachine, ctx *hinter.Hin
primeOver3High := uint256.Int{6148914691236517206, 192153584101141168, 0, 0}
primeOver2High := uint256.Int{9223372036854775809, 288230376151711752, 0, 0}

a, err := hint.a.Resolve(vm)
a, err := hint.A.Resolve(vm)
if err != nil {
return fmt.Errorf("resolve a operand %s: %w", hint.a, err)
return fmt.Errorf("resolve a operand %s: %w", hint.A, err)
}

b, err := hint.b.Resolve(vm)
b, err := hint.B.Resolve(vm)
if err != nil {
return fmt.Errorf("resolve b operand %s: %w", hint.b, err)
return fmt.Errorf("resolve b operand %s: %w", hint.B, err)
}

aFelt, err := a.FieldElement()
Expand Down Expand Up @@ -1374,7 +1374,7 @@ func (hint *AssertLeFindSmallArc) Execute(vm *VM.VirtualMachine, ctx *hinter.Hin
// Exclude the largest arc after sorting
ctx.ExcludedArc = lengthsAndIndices[2].Position

rangeCheckPtrMemAddr, err := hinter.ResolveAsAddress(vm, hint.rangeCheckPtr)
rangeCheckPtrMemAddr, err := hinter.ResolveAsAddress(vm, hint.RangeCheckPtr)
if err != nil {
return fmt.Errorf("resolve range check pointer: %w", err)
}
Expand Down Expand Up @@ -1434,15 +1434,15 @@ func (hint *AssertLeFindSmallArc) Execute(vm *VM.VirtualMachine, ctx *hinter.Hin
}

type AssertLeIsFirstArcExcluded struct {
skipExcludeAFlag hinter.CellRefer
SkipExcludeAFlag hinter.CellRefer
}

func (hint *AssertLeIsFirstArcExcluded) String() string {
return "AssertLeIsFirstArcExcluded"
}

func (hint *AssertLeIsFirstArcExcluded) Execute(vm *VM.VirtualMachine, ctx *hinter.HintRunnerContext) error {
addr, err := hint.skipExcludeAFlag.Get(vm)
addr, err := hint.SkipExcludeAFlag.Get(vm)
if err != nil {
return fmt.Errorf("get skipExcludeAFlag addr: %v", err)
}
Expand All @@ -1458,15 +1458,15 @@ func (hint *AssertLeIsFirstArcExcluded) Execute(vm *VM.VirtualMachine, ctx *hint
}

type AssertLeIsSecondArcExcluded struct {
skipExcludeBMinusA hinter.CellRefer
SkipExcludeBMinusA hinter.CellRefer
}

func (hint *AssertLeIsSecondArcExcluded) String() string {
return "AssertLeIsSecondArcExcluded"
}

func (hint *AssertLeIsSecondArcExcluded) Execute(vm *VM.VirtualMachine, ctx *hinter.HintRunnerContext) error {
addr, err := hint.skipExcludeBMinusA.Get(vm)
addr, err := hint.SkipExcludeBMinusA.Get(vm)
if err != nil {
return fmt.Errorf("get skipExcludeBMinusA addr: %v", err)
}
Expand Down
10 changes: 5 additions & 5 deletions pkg/hintrunner/core/hint_benchmark_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -328,7 +328,7 @@ func BenchmarkAssertLeIsFirstArcExcluded(b *testing.B) {
for i := 0; i < b.N; i++ {

hint := AssertLeIsFirstArcExcluded{
skipExcludeAFlag: skipExcludeAFlag,
SkipExcludeAFlag: skipExcludeAFlag,
}

err := hint.Execute(vm, &ctx)
Expand Down Expand Up @@ -356,7 +356,7 @@ func BenchmarkAssertLeIsSecondArcExcluded(b *testing.B) {
for i := 0; i < b.N; i++ {

hint := AssertLeIsSecondArcExcluded{
skipExcludeBMinusA: skipExcludeBMinusA,
SkipExcludeBMinusA: skipExcludeBMinusA,
}

err := hint.Execute(vm, &ctx)
Expand Down Expand Up @@ -392,9 +392,9 @@ func BenchmarkAssertLeFindSmallArc(b *testing.B) {
r1 := utils.RandomFeltElement(rand)
r2 := utils.RandomFeltElement(rand)
hint := AssertLeFindSmallArc{
a: hinter.Immediate(r1),
b: hinter.Immediate(r2),
rangeCheckPtr: hinter.Deref{Deref: hinter.ApCellRef(0)},
A: hinter.Immediate(r1),
B: hinter.Immediate(r2),
RangeCheckPtr: hinter.Deref{Deref: hinter.ApCellRef(0)},
}

if err := hint.Execute(vm, &ctx); err != nil &&
Expand Down
10 changes: 5 additions & 5 deletions pkg/hintrunner/core/hint_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -987,9 +987,9 @@ func TestAssertLeFindSmallArc(t *testing.T) {
utils.WriteTo(vm, VM.ExecutionSegment, vm.Context.Ap, mem.MemoryValueFromMemoryAddress(&addr))

hint := AssertLeFindSmallArc{
a: hinter.Immediate(tc.aFelt),
b: hinter.Immediate(tc.bFelt),
rangeCheckPtr: hinter.Deref{Deref: hinter.ApCellRef(0)},
A: hinter.Immediate(tc.aFelt),
B: hinter.Immediate(tc.bFelt),
RangeCheckPtr: hinter.Deref{Deref: hinter.ApCellRef(0)},
}

ctx := hinter.HintRunnerContext{
Expand Down Expand Up @@ -1027,7 +1027,7 @@ func TestAssertLeIsFirstArcExcluded(t *testing.T) {
var flag hinter.ApCellRef = 0

hint := AssertLeIsFirstArcExcluded{
skipExcludeAFlag: flag,
SkipExcludeAFlag: flag,
}

err := hint.Execute(vm, &ctx)
Expand All @@ -1053,7 +1053,7 @@ func TestAssertLeIsSecondArcExcluded(t *testing.T) {
var flag hinter.ApCellRef = 0

hint := AssertLeIsSecondArcExcluded{
skipExcludeBMinusA: flag,
SkipExcludeBMinusA: flag,
}

err := hint.Execute(vm, &ctx)
Expand Down
14 changes: 14 additions & 0 deletions pkg/hintrunner/zero/hint_reference_resolver.go
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,20 @@ func (m *hintReferenceResolver) GetReference(name string) (hinter.Reference, err
return nil, fmt.Errorf("missing reference %s", name)
}

// GetResOperander returns the result of GetReference type-asserted to ResOperander.
// If reference is not found or it's not of ResOperander type, a non-nil error is returned.
func (m *hintReferenceResolver) GetResOperander(name string) (hinter.ResOperander, error) {
cicr99 marked this conversation as resolved.
Show resolved Hide resolved
ref, err := m.GetReference(name)
if err != nil {
return nil, err
}
op, ok := ref.(hinter.ResOperander)
if !ok {
return nil, fmt.Errorf("expected %s to be ResOperander (got %T)", name, ref)
}
return op, nil
}

// shortSymbolName turns a full symbol name like "a.b.c" into just "c".
func shortSymbolName(name string) string {
i := strings.LastIndexByte(name, '.')
Expand Down
10 changes: 8 additions & 2 deletions pkg/hintrunner/zero/hintcode.go
Original file line number Diff line number Diff line change
@@ -1,10 +1,16 @@
package zero

const (
AllocSegmentCode string = "memory[ap] = segments.add()"
allocSegmentCode string = "memory[ap] = segments.add()"

// This is a very simple Cairo0 hint that allows us to test
// the identifier resolution code.
// Depending on the context, ids.a may be a complex reference.
TestAssignCode string = "memory[ap] = ids.a"
testAssignCode string = "memory[ap] = ids.a"

// assert_le_felt() hints.
assertLeFeltCode string = "import itertools\n\nfrom starkware.cairo.common.math_utils import assert_integer\nassert_integer(ids.a)\nassert_integer(ids.b)\na = ids.a % PRIME\nb = ids.b % PRIME\nassert a <= b, f'a = {a} is not less than or equal to b = {b}.'\n\n# Find an arc less than PRIME / 3, and another less than PRIME / 2.\nlengths_and_indices = [(a, 0), (b - a, 1), (PRIME - 1 - b, 2)]\nlengths_and_indices.sort()\nassert lengths_and_indices[0][0] <= PRIME // 3 and lengths_and_indices[1][0] <= PRIME // 2\nexcluded = lengths_and_indices[2][1]\n\nmemory[ids.range_check_ptr + 1], memory[ids.range_check_ptr + 0] = (\n divmod(lengths_and_indices[0][0], ids.PRIME_OVER_3_HIGH))\nmemory[ids.range_check_ptr + 3], memory[ids.range_check_ptr + 2] = (\n divmod(lengths_and_indices[1][0], ids.PRIME_OVER_2_HIGH))"
assertLeFeltExcluded0Code string = "memory[ap] = 1 if excluded != 0 else 0"
assertLeFeltExcluded1Code string = "memory[ap] = 1 if excluded != 1 else 0"
assertLeFeltExcluded2Code string = "assert excluded == 2"
)
58 changes: 56 additions & 2 deletions pkg/hintrunner/zero/zerohint.go
Original file line number Diff line number Diff line change
Expand Up @@ -52,10 +52,18 @@ func GetHintFromCode(program *zero.ZeroProgram, rawHint zero.Hint, hintPC uint64
}

switch rawHint.Code {
case AllocSegmentCode:
case allocSegmentCode:
return CreateAllocSegmentHinter(resolver)
case TestAssignCode:
case testAssignCode:
return createTestAssignHinter(resolver)
case assertLeFeltCode:
return createAssertLeFeltHinter(resolver)
case assertLeFeltExcluded0Code:
return createAssertLeFeltExcluded0Hinter(resolver)
case assertLeFeltExcluded1Code:
return createAssertLeFeltExcluded1Hinter(resolver)
case assertLeFeltExcluded2Code:
return createAssertLeFeltExcluded2Hinter(resolver)
default:
return nil, fmt.Errorf("Not identified hint")
}
Expand Down Expand Up @@ -90,6 +98,52 @@ func createTestAssignHinter(resolver hintReferenceResolver) (hinter.Hinter, erro
return h, nil
}

func createAssertLeFeltHinter(resolver hintReferenceResolver) (hinter.Hinter, error) {
a, err := resolver.GetResOperander("a")
if err != nil {
return nil, err
}
b, err := resolver.GetResOperander("b")
if err != nil {
return nil, err
}
rangeCheckPtr, err := resolver.GetResOperander("range_check_ptr")
if err != nil {
return nil, err
}

h := &core.AssertLeFindSmallArc{
A: a,
B: b,
RangeCheckPtr: rangeCheckPtr,
}
return h, nil
}

func createAssertLeFeltExcluded0Hinter(resolver hintReferenceResolver) (hinter.Hinter, error) {
return &core.AssertLeIsFirstArcExcluded{SkipExcludeAFlag: hinter.ApCellRef(0)}, nil
}

func createAssertLeFeltExcluded1Hinter(resolver hintReferenceResolver) (hinter.Hinter, error) {
return &core.AssertLeIsSecondArcExcluded{SkipExcludeBMinusA: hinter.ApCellRef(0)}, nil
}

func createAssertLeFeltExcluded2Hinter(resolver hintReferenceResolver) (hinter.Hinter, error) {
// This hint is Cairo0-specific.
// It only does a python-scoped variable named "excluded" assert.
// We store that variable inside a hinter context.
h := &GenericZeroHinter{
Name: "AssertLeFeltExcluded2",
Op: func(vm *VM.VirtualMachine, ctx *hinter.HintRunnerContext) error {
if ctx.ExcludedArc != 2 {
return fmt.Errorf("assertion `excluded == 2` failed")
}
return nil
},
}
return h, nil
}

func getParameters(zeroProgram *zero.ZeroProgram, hint zero.Hint, hintPC uint64) (hintReferenceResolver, error) {
resolver := NewReferenceResolver()

Expand Down
Loading