Skip to content

Commit

Permalink
implement is_nn Cairo0 hint (#207)
Browse files Browse the repository at this point in the history
implement several Cairo0 hint

is_nn hint uses the assert_felt_le beneath it, but
it was implemented beforehand.

This PR has no tests included since #204 is not solved yet.
I used a couple of Cairo0 scripts to test this functionality
with a set of different arguments to cover both hints
that are a part of `is_nn` function.
(One of them handles negatives while another is for the non-negatives.)

Refs #164
  • Loading branch information
quasilyte committed Feb 20, 2024
1 parent d9abd6d commit acba30e
Show file tree
Hide file tree
Showing 3 changed files with 191 additions and 0 deletions.
8 changes: 8 additions & 0 deletions pkg/hintrunner/zero/hintcode.go
Original file line number Diff line number Diff line change
@@ -1,7 +1,11 @@
package zero

const (
// This is a block for hint code strings where there is a single
// hint per function it belongs to (with some exceptions like testAssignCode).
allocSegmentCode string = "memory[ap] = segments.add()"
isLeFeltCode string = "memory[ap] = 0 if (ids.a % PRIME) <= (ids.b % PRIME) else 1"
assertLtFeltCode string = "from starkware.cairo.common.math_utils import assert_integer\nassert_integer(ids.a)\nassert_integer(ids.b)\nassert (ids.a % PRIME) < (ids.b % PRIME), \\\n f'a = {ids.a % PRIME} is not less than b = {ids.b % PRIME}.'"

// This is a very simple Cairo0 hint that allows us to test
// the identifier resolution code.
Expand All @@ -13,4 +17,8 @@ const (
assertLeFeltExcluded0Code string = "memory[ap] = 1 if excluded != 0 else 0"
assertLeFeltExcluded1Code string = "memory[ap] = 1 if excluded != 1 else 0"
assertLeFeltExcluded2Code string = "assert excluded == 2"

// is_nn() hints.
isNNCode string = "memory[ap] = 0 if 0 <= (ids.a % PRIME) < range_check_builtin.bound else 1"
isNNOutOfRangeCode string = "memory[ap] = 0 if 0 <= ((-ids.a - 1) % PRIME) < range_check_builtin.bound else 1"
)
170 changes: 170 additions & 0 deletions pkg/hintrunner/zero/zerohint.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,10 @@ import (
"github.com/NethermindEth/cairo-vm-go/pkg/hintrunner/core"
"github.com/NethermindEth/cairo-vm-go/pkg/hintrunner/hinter"
zero "github.com/NethermindEth/cairo-vm-go/pkg/parsers/zero"
"github.com/NethermindEth/cairo-vm-go/pkg/utils"
VM "github.com/NethermindEth/cairo-vm-go/pkg/vm"
"github.com/NethermindEth/cairo-vm-go/pkg/vm/memory"
"github.com/consensys/gnark-crypto/ecc/stark-curve/fp"
)

// GenericZeroHinter wraps an adhoc Cairo0 inline (pythonic) hint implementation.
Expand Down Expand Up @@ -54,6 +57,10 @@ func GetHintFromCode(program *zero.ZeroProgram, rawHint zero.Hint, hintPC uint64
switch rawHint.Code {
case allocSegmentCode:
return CreateAllocSegmentHinter(resolver)
case isLeFeltCode:
return createIsLeFeltHinter(resolver)
case assertLtFeltCode:
return createAssertLtFeltHinter(resolver)
case testAssignCode:
return createTestAssignHinter(resolver)
case assertLeFeltCode:
Expand All @@ -64,6 +71,10 @@ func GetHintFromCode(program *zero.ZeroProgram, rawHint zero.Hint, hintPC uint64
return createAssertLeFeltExcluded1Hinter(resolver)
case assertLeFeltExcluded2Code:
return createAssertLeFeltExcluded2Hinter(resolver)
case isNNCode:
return createIsNNHinter(resolver)
case isNNOutOfRangeCode:
return createIsNNOutOfRangeHinter(resolver)
default:
return nil, fmt.Errorf("Not identified hint")
}
Expand All @@ -73,6 +84,95 @@ func CreateAllocSegmentHinter(resolver hintReferenceResolver) (hinter.Hinter, er
return &core.AllocSegment{Dst: hinter.ApCellRef(0)}, nil
}

func createIsLeFeltHinter(resolver hintReferenceResolver) (hinter.Hinter, error) {
argA, err := resolver.GetResOperander("a")
if err != nil {
return nil, err
}
argB, err := resolver.GetResOperander("b")
if err != nil {
return nil, err
}

h := &GenericZeroHinter{
Name: "IsLeFelt",
Op: func(vm *VM.VirtualMachine, _ *hinter.HintRunnerContext) error {
//> memory[ap] = 0 if (ids.a % PRIME) <= (ids.b % PRIME) else 1
apAddr := vm.Context.AddressAp()

a, err := argA.Resolve(vm)
if err != nil {
return err
}
aFelt, err := a.FieldElement()
if err != nil {
return err
}
b, err := argB.Resolve(vm)
if err != nil {
return err
}
bFelt, err := b.FieldElement()
if err != nil {
return err
}

var v memory.MemoryValue
if utils.FeltLe(aFelt, bFelt) {
v = memory.MemoryValueFromFieldElement(&utils.FeltZero)
} else {
v = memory.MemoryValueFromFieldElement(&utils.FeltOne)
}
return vm.Memory.WriteToAddress(&apAddr, &v)
},
}
return h, nil
}

func createAssertLtFeltHinter(resolver hintReferenceResolver) (hinter.Hinter, error) {
argA, err := resolver.GetResOperander("a")
if err != nil {
return nil, err
}
argB, err := resolver.GetResOperander("b")
if err != nil {
return nil, err
}

h := &GenericZeroHinter{
Name: "AssertLtFelt",
Op: func(vm *VM.VirtualMachine, _ *hinter.HintRunnerContext) error {
//> from starkware.cairo.common.math_utils import assert_integer
//> assert_integer(ids.a)
//> assert_integer(ids.b)
//> assert (ids.a % PRIME) < (ids.b % PRIME),
//> f'a = {ids.a % PRIME} is not less than b = {ids.b % PRIME}.'
a, err := argA.Resolve(vm)
if err != nil {
return err
}
aFelt, err := a.FieldElement()
if err != nil {
return err
}
b, err := argB.Resolve(vm)
if err != nil {
return err
}
bFelt, err := b.FieldElement()
if err != nil {
return err
}

if !utils.FeltLt(aFelt, bFelt) {
return fmt.Errorf("a = %v is not less than b = %v", aFelt, bFelt)
}
return nil
},
}
return h, nil
}

func createTestAssignHinter(resolver hintReferenceResolver) (hinter.Hinter, error) {
arg, err := resolver.GetReference("a")
if err != nil {
Expand Down Expand Up @@ -144,6 +244,76 @@ func createAssertLeFeltExcluded2Hinter(resolver hintReferenceResolver) (hinter.H
return h, nil
}

func createIsNNHinter(resolver hintReferenceResolver) (hinter.Hinter, error) {
argA, err := resolver.GetResOperander("a")
if err != nil {
return nil, err
}

h := &GenericZeroHinter{
Name: "IsNN",
Op: func(vm *VM.VirtualMachine, _ *hinter.HintRunnerContext) error {
apAddr := vm.Context.AddressAp()
//> memory[ap] = 0 if 0 <= (ids.a % PRIME) < range_check_builtin.bound else 1
a, err := argA.Resolve(vm)
if err != nil {
return err
}
// aFelt is already modulo PRIME, no need to adjust it.
aFelt, err := a.FieldElement()
if err != nil {
return err
}
// range_check_builtin.bound is utils.FeltMax128 (1 << 128).
var v memory.MemoryValue
if utils.FeltLt(aFelt, &utils.FeltMax128) {
v = memory.MemoryValueFromFieldElement(&utils.FeltZero)
} else {
v = memory.MemoryValueFromFieldElement(&utils.FeltOne)
}
return vm.Memory.WriteToAddress(&apAddr, &v)
},
}
return h, nil
}

func createIsNNOutOfRangeHinter(resolver hintReferenceResolver) (hinter.Hinter, error) {
// This hint is executed for the negative values.
// If the value was non-negative, it's usually handled by the IsNN hint.

argA, err := resolver.GetResOperander("a")
if err != nil {
return nil, err
}

h := &GenericZeroHinter{
Name: "IsNNOutOfRange",
Op: func(vm *VM.VirtualMachine, _ *hinter.HintRunnerContext) error {
apAddr := vm.Context.AddressAp()
//> memory[ap] = 0 if 0 <= ((-ids.a - 1) % PRIME) < range_check_builtin.bound else 1
a, err := argA.Resolve(vm)
if err != nil {
return err
}
aFelt, err := a.FieldElement()
if err != nil {
return err
}
var lhs fp.Element
lhs.Sub(&utils.FeltZero, aFelt) //> -ids.a
lhs.Sub(&lhs, &utils.FeltOne)
var v memory.MemoryValue
if utils.FeltLt(aFelt, &utils.FeltMax128) {
v = memory.MemoryValueFromFieldElement(&utils.FeltZero)
} else {
v = memory.MemoryValueFromFieldElement(&utils.FeltOne)
}
return vm.Memory.WriteToAddress(&apAddr, &v)
},
}
return h, nil
}

func getParameters(zeroProgram *zero.ZeroProgram, hint zero.Hint, hintPC uint64) (hintReferenceResolver, error) {
resolver := NewReferenceResolver()

Expand Down
13 changes: 13 additions & 0 deletions pkg/utils/math.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@ import (
"math/bits"

"golang.org/x/exp/constraints"

"github.com/consensys/gnark-crypto/ecc/stark-curve/fp"
)

// Takes a uint64 and an int16 and outputs their addition as well
Expand Down Expand Up @@ -59,3 +61,14 @@ func Max[T constraints.Integer](a, b T) T {
}
return b
}

// FeltLt implements `a < b` felt comparison.
func FeltLt(a, b *fp.Element) bool {
return a.Cmp(b) == -1
}

// FeltLe implements `a <= b` felt comparison.
func FeltLe(a, b *fp.Element) bool {
// a is less or equal than b if it's not greater than b.
return a.Cmp(b) != 1
}

0 comments on commit acba30e

Please sign in to comment.