Skip to content

Commit

Permalink
Implement SignedDivRemHint (#318)
Browse files Browse the repository at this point in the history
* Implement SignedDivRemHint

* Implement hint method

* Debug implementation, add testcases

* Debug implementation

* Fix merge problems

* Resolved comments from the PR

* Replace PrimeHigh calculation with const, modify the bound check condition for value

* Add as_int method

* Added test case for the as_int() functionality

* Applied suggestions from the comments

* Move AsInt to math_utils, make AsInt return BigInt pointer
  • Loading branch information
MaksymMalicki committed Apr 2, 2024
1 parent 73a2a57 commit 722b840
Show file tree
Hide file tree
Showing 6 changed files with 229 additions and 3 deletions.
19 changes: 19 additions & 0 deletions pkg/hintrunner/utils/math_utils.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
package utils

import (
"math/big"

"github.com/consensys/gnark-crypto/ecc/stark-curve/fp"
)

func AsInt(valueFelt *fp.Element) *big.Int {
var valueBig big.Int
valueFelt.BigInt(&valueBig)
boundBig := new(big.Int).Div(fp.Modulus(), big.NewInt(2))

// val if val < prime // 2 else val - prime
if valueBig.Cmp(boundBig) == -1 {
return &valueBig
}
return new(big.Int).Sub(&valueBig, fp.Modulus())
}
1 change: 1 addition & 0 deletions pkg/hintrunner/zero/hintcode.go
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@ const (
// split_int() hints.
splitIntAssertRange string = "assert ids.value == 0, 'split_int(): value is out of range.'"
splitIntCode string = "memory[ids.output] = res = (int(ids.value) % PRIME) % ids.base\nassert res < ids.bound, f'split_int(): Limb {res} is out of range.'"
signedDivRemCode string = "from starkware.cairo.common.math_utils import as_int, assert_integer\nassert_integer(ids.div)\nassert 0 < ids.div <= PRIME // range_check_builtin.bound, f'div={hex(ids.div)} is out of the valid range.'\nassert_integer(ids.bound)\nassert ids.bound <= range_check_builtin.bound // 2, f'bound={hex(ids.bound)} is out of the valid range.'\nint_value = as_int(ids.value, PRIME)\nq, ids.r = divmod(int_value, ids.div)\nassert -ids.bound <= q < ids.bound, f'{int_value} / {ids.div} = {q} is out of the range [{-ids.bound}, {ids.bound}).'\nids.biased_q = q + ids.bound"

// pow hints
powCode string = "ids.locs.bit = (ids.prev_locs.exp % PRIME) & 1"
Expand Down
2 changes: 2 additions & 0 deletions pkg/hintrunner/zero/zerohint.go
Original file line number Diff line number Diff line change
Expand Up @@ -82,6 +82,8 @@ func GetHintFromCode(program *zero.ZeroProgram, rawHint zero.Hint, hintPC uint64
return createSplitIntAssertRangeHinter(resolver)
case splitIntCode:
return createSplitIntHinter(resolver)
case signedDivRemCode:
return createSignedDivRemHinter(resolver)
case powCode:
return createPowHinter(resolver)
case splitFeltCode:
Expand Down
103 changes: 101 additions & 2 deletions pkg/hintrunner/zero/zerohint_math.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ import (

"github.com/NethermindEth/cairo-vm-go/pkg/hintrunner/core"
"github.com/NethermindEth/cairo-vm-go/pkg/hintrunner/hinter"
math_utils "github.com/NethermindEth/cairo-vm-go/pkg/hintrunner/utils"
"github.com/NethermindEth/cairo-vm-go/pkg/utils"
VM "github.com/NethermindEth/cairo-vm-go/pkg/vm"
"github.com/NethermindEth/cairo-vm-go/pkg/vm/memory"
Expand Down Expand Up @@ -643,6 +644,105 @@ func createSplitFeltHinter(resolver hintReferenceResolver) (hinter.Hinter, error
return newSplitFeltHint(low, high, value), nil
}

func newSignedDivRemHint(value, div, bound, r, biased_q hinter.ResOperander) hinter.Hinter {
return &GenericZeroHinter{
Name: "SignedDivRem",
Op: func(vm *VM.VirtualMachine, _ *hinter.HintRunnerContext) error {
//> from starkware.cairo.common.math_utils import as_int, assert_integer
//> assert_integer(ids.div)
//> assert 0 < ids.div <= PRIME // range_check_builtin.bound, f'div={hex(ids.div)} is out of the valid range.'
//> assert_integer(ids.bound)
//> assert ids.bound <= range_check_builtin.bound // 2, f'bound={hex(ids.bound)} is out of the valid range.'
//> int_value = as_int(ids.value, PRIME)
//> q, ids.r = divmod(int_value, ids.div)
//> assert -ids.bound <= q < ids.bound, f'{int_value} / {ids.div} = {q} is out of the range [{-ids.bound}, {ids.bound}).'
//> ids.biased_q = q + ids.bound

//> assert_integer(ids.div)
divFelt, err := hinter.ResolveAsFelt(vm, div)
if err != nil {
return err
}
//> assert 0 < ids.div <= PRIME // range_check_builtin.bound, f'div={hex(ids.div)} is out of the valid range.'
if divFelt.IsZero() || !utils.FeltLe(divFelt, &utils.PrimeHigh) {
return fmt.Errorf("div=%v is out of the valid range.", divFelt)
}

//> assert_integer(ids.bound)
boundFelt, err := hinter.ResolveAsFelt(vm, bound)
if err != nil {
return err
}
//> assert ids.bound <= range_check_builtin.bound // 2, f'bound={hex(ids.bound)} is out of the valid range.'
if !utils.FeltLe(boundFelt, &utils.Felt127) {
return fmt.Errorf("bound=%v is out of the valid range.", boundFelt)
}
//> int_value = as_int(ids.value, PRIME)
valueFelt, err := hinter.ResolveAsFelt(vm, value)
if err != nil {
return err
}
intValueBig := math_utils.AsInt(valueFelt)

//> q, ids.r = divmod(int_value, ids.div)
var divBig, boundBig big.Int
divFelt.BigInt(&divBig)
boundFelt.BigInt(&boundBig)
qBig, rBig := new(big.Int).DivMod(intValueBig, &divBig, new(big.Int))
rFelt := new(fp.Element).SetBigInt(rBig)
rAddr, err := r.GetAddress(vm)
if err != nil {
return err
}
rValue := memory.MemoryValueFromFieldElement(rFelt)
err = vm.Memory.WriteToAddress(&rAddr, &rValue)
if err != nil {
return err
}

//> assert -ids.bound <= q < ids.bound, f'{int_value} / {ids.div} = {q} is out of the range [{-ids.bound}, {ids.bound}).'
if !(qBig.Cmp(new(big.Int).Neg(&boundBig)) >= 0 && qBig.Cmp(&boundBig) == -1) {
return fmt.Errorf("%v / %v = %v is out of the range [-%v, %v].", valueFelt, divFelt, qBig, boundFelt, boundFelt)
}

//> ids.biased_q = q + ids.bound
biasedQBig := new(big.Int).Add(qBig, &boundBig)
biasedQ := new(fp.Element).SetBigInt(biasedQBig)
biasedQAddr, err := biased_q.GetAddress(vm)
if err != nil {
return err
}
biasedQValue := memory.MemoryValueFromFieldElement(biasedQ)
return vm.Memory.WriteToAddress(&biasedQAddr, &biasedQValue)
},
}
}

func createSignedDivRemHinter(resolver hintReferenceResolver) (hinter.Hinter, error) {
value, err := resolver.GetResOperander("value")
if err != nil {
return nil, err
}
div, err := resolver.GetResOperander("div")
if err != nil {
return nil, err
}
bound, err := resolver.GetResOperander("bound")
if err != nil {
return nil, err
}
r, err := resolver.GetResOperander("r")
if err != nil {
return nil, err
}
biased_q, err := resolver.GetResOperander("biased_q")
if err != nil {
return nil, err
}
return newSignedDivRemHint(value, div, bound, r, biased_q), nil

}

func newSqrtHint(root, value hinter.ResOperander) hinter.Hinter {
return &GenericZeroHinter{
Name: "Sqrt",
Expand Down Expand Up @@ -719,11 +819,10 @@ func newUnsignedDivRemHinter(value, div, q, r hinter.ResOperander) hinter.Hinter
if err != nil {
return err
}

// (PRIME // range_check_builtin.bound)
// 800000000000011000000000000000000000000000000000000000000000001 // 2**128
var divUpperBound big.Int
divUpperBound.SetString("8000000000000110000000000000000", 16)
utils.PrimeHigh.BigInt(&divUpperBound)

var divBig big.Int
div.BigInt(&divBig)
Expand Down
103 changes: 102 additions & 1 deletion pkg/hintrunner/zero/zerohint_math_test.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
package zero

import (
"fmt"
"math/big"
"testing"

"github.com/NethermindEth/cairo-vm-go/pkg/hintrunner/hinter"
Expand All @@ -10,6 +12,7 @@ import (
)

func TestZeroHintMath(t *testing.T) {

runHinterTests(t, map[string][]hintTestCase{
"IsLeFelt": {
{
Expand Down Expand Up @@ -619,7 +622,105 @@ func TestZeroHintMath(t *testing.T) {
}),
},
},

"SignedDivRem": {
{
operanders: []*hintOperander{
{Name: "value", Kind: apRelative, Value: &utils.FeltZero},
{Name: "div", Kind: apRelative, Value: &utils.FeltMax128},
{Name: "bound", Kind: apRelative, Value: &utils.Felt127},
{Name: "r", Kind: reference, Value: addrBuiltin(starknet.RangeCheck, 0)},
{Name: "biased_q", Kind: reference, Value: addrBuiltin(starknet.RangeCheck, 1)},
},
makeHinter: func(ctx *hintTestContext) hinter.Hinter {
return newSignedDivRemHint(ctx.operanders["value"], ctx.operanders["div"], ctx.operanders["bound"], ctx.operanders["r"], ctx.operanders["biased_q"])
},
errCheck: errorTextContains(fmt.Sprintf("div=%v is out of the valid range.", &utils.FeltMax128)),
},
{
operanders: []*hintOperander{
{Name: "value", Kind: apRelative, Value: &utils.FeltZero},
{Name: "div", Kind: apRelative, Value: &utils.FeltZero},
{Name: "bound", Kind: apRelative, Value: &utils.Felt127},
{Name: "r", Kind: reference, Value: addrBuiltin(starknet.RangeCheck, 0)},
{Name: "biased_q", Kind: reference, Value: addrBuiltin(starknet.RangeCheck, 1)},
},
makeHinter: func(ctx *hintTestContext) hinter.Hinter {
return newSignedDivRemHint(ctx.operanders["value"], ctx.operanders["div"], ctx.operanders["bound"], ctx.operanders["r"], ctx.operanders["biased_q"])
},
errCheck: errorTextContains(fmt.Sprintf("div=%v is out of the valid range.", &utils.FeltZero)),
},
{
operanders: []*hintOperander{
{Name: "value", Kind: apRelative, Value: &utils.FeltZero},
{Name: "div", Kind: apRelative, Value: &utils.FeltOne},
{Name: "bound", Kind: apRelative, Value: new(fp.Element).SetBigInt(new(big.Int).Lsh(big.NewInt(1), 130))},
{Name: "r", Kind: reference, Value: addrBuiltin(starknet.RangeCheck, 0)},
{Name: "biased_q", Kind: reference, Value: addrBuiltin(starknet.RangeCheck, 1)},
},
makeHinter: func(ctx *hintTestContext) hinter.Hinter {
return newSignedDivRemHint(ctx.operanders["value"], ctx.operanders["div"], ctx.operanders["bound"], ctx.operanders["r"], ctx.operanders["biased_q"])
},
errCheck: errorTextContains(fmt.Sprintf("bound=%v is out of the valid range", new(fp.Element).SetBigInt(new(big.Int).Lsh(big.NewInt(1), 130)))),
},
{
operanders: []*hintOperander{
{Name: "value", Kind: apRelative, Value: feltInt64(-6)},
{Name: "div", Kind: apRelative, Value: feltInt64(2)},
{Name: "bound", Kind: apRelative, Value: feltInt64(2)},
{Name: "r", Kind: reference, Value: addrBuiltin(starknet.RangeCheck, 0)},
{Name: "biased_q", Kind: reference, Value: addrBuiltin(starknet.RangeCheck, 1)},
},
makeHinter: func(ctx *hintTestContext) hinter.Hinter {
return newSignedDivRemHint(ctx.operanders["value"], ctx.operanders["div"], ctx.operanders["bound"], ctx.operanders["r"], ctx.operanders["biased_q"])
},
errCheck: errorTextContains(fmt.Sprintf("%v / %v = %v is out of the range [-%v, %v]", feltInt64(-6), feltInt64(2), feltInt64(-3), feltInt64(2), feltInt64(2))),
},
{
operanders: []*hintOperander{
{Name: "value", Kind: apRelative, Value: feltInt64(6)},
{Name: "div", Kind: apRelative, Value: feltInt64(2)},
{Name: "bound", Kind: apRelative, Value: feltInt64(3)},
{Name: "r", Kind: reference, Value: addrBuiltin(starknet.RangeCheck, 0)},
{Name: "biased_q", Kind: reference, Value: addrBuiltin(starknet.RangeCheck, 1)},
},
makeHinter: func(ctx *hintTestContext) hinter.Hinter {
return newSignedDivRemHint(ctx.operanders["value"], ctx.operanders["div"], ctx.operanders["bound"], ctx.operanders["r"], ctx.operanders["biased_q"])
},
errCheck: errorTextContains(fmt.Sprintf("%v / %v = %v is out of the range [-%v, %v].", feltInt64(6), feltInt64(2), feltInt64(3), feltInt64(3), feltInt64(3))),
},
{
operanders: []*hintOperander{
{Name: "value", Kind: apRelative, Value: feltInt64(5)},
{Name: "div", Kind: apRelative, Value: feltInt64(2)},
{Name: "bound", Kind: apRelative, Value: &utils.Felt127},
{Name: "r", Kind: reference, Value: addrBuiltin(starknet.RangeCheck, 0)},
{Name: "biased_q", Kind: reference, Value: addrBuiltin(starknet.RangeCheck, 1)},
},
makeHinter: func(ctx *hintTestContext) hinter.Hinter {
return newSignedDivRemHint(ctx.operanders["value"], ctx.operanders["div"], ctx.operanders["bound"], ctx.operanders["r"], ctx.operanders["biased_q"])
},
check: allVarValueEquals(map[string]*fp.Element{
"r": &utils.FeltOne,
"biased_q": new(fp.Element).Add(feltInt64(2), &utils.Felt127),
}),
},
{
operanders: []*hintOperander{
{Name: "value", Kind: apRelative, Value: feltInt64(-3)},
{Name: "div", Kind: apRelative, Value: feltInt64(2)},
{Name: "bound", Kind: apRelative, Value: &utils.Felt127},
{Name: "r", Kind: reference, Value: addrBuiltin(starknet.RangeCheck, 0)},
{Name: "biased_q", Kind: reference, Value: addrBuiltin(starknet.RangeCheck, 1)},
},
makeHinter: func(ctx *hintTestContext) hinter.Hinter {
return newSignedDivRemHint(ctx.operanders["value"], ctx.operanders["div"], ctx.operanders["bound"], ctx.operanders["r"], ctx.operanders["biased_q"])
},
check: allVarValueEquals(map[string]*fp.Element{
"r": &utils.FeltOne,
"biased_q": new(fp.Element).Sub(&utils.Felt127, feltInt64(2)),
}),
},
},
"SqrtHint": {
{
operanders: []*hintOperander{
Expand Down
4 changes: 4 additions & 0 deletions pkg/utils/constant.go
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,10 @@ var FeltMax128 = fp.Element{18446744073700081665, 17407, 18446744073709551584, 5
// 2 ** 250
var FeltUpperBound = fp.Element{0xfffffff5cdf80011, 0x4cc3fff, 0xfffffffffffdbe00, 0x7ffff52ad780230}

// (PRIME // range_check_builtin.bound)
// 800000000000011000000000000000000000000000000000000000000000001 // 2**128
var PrimeHigh = fp.Element{1, 0, 18446744073709551615, 576460752303423504}

//
// Uint256 Constants
//
Expand Down

0 comments on commit 722b840

Please sign in to comment.