From 722b84071ba7f743b30ccbb983f2456e47ce181e Mon Sep 17 00:00:00 2001 From: MaksymMalicki <81577596+MaksymMalicki@users.noreply.github.com> Date: Tue, 2 Apr 2024 16:27:52 +0200 Subject: [PATCH] Implement SignedDivRemHint (#318) * Implement SignedDivRemHint * Implement hint method * Debug implementation, add testcases * Debug implementation * Fix merge problems * Resolved comments from the PR * Replace PrimeHigh calculation with const, modify the bound check condition for value * Add as_int method * Added test case for the as_int() functionality * Applied suggestions from the comments * Move AsInt to math_utils, make AsInt return BigInt pointer --- pkg/hintrunner/utils/math_utils.go | 19 ++++ pkg/hintrunner/zero/hintcode.go | 1 + pkg/hintrunner/zero/zerohint.go | 2 + pkg/hintrunner/zero/zerohint_math.go | 103 +++++++++++++++++++++- pkg/hintrunner/zero/zerohint_math_test.go | 103 +++++++++++++++++++++- pkg/utils/constant.go | 4 + 6 files changed, 229 insertions(+), 3 deletions(-) create mode 100644 pkg/hintrunner/utils/math_utils.go diff --git a/pkg/hintrunner/utils/math_utils.go b/pkg/hintrunner/utils/math_utils.go new file mode 100644 index 000000000..8b2f9ad62 --- /dev/null +++ b/pkg/hintrunner/utils/math_utils.go @@ -0,0 +1,19 @@ +package utils + +import ( + "math/big" + + "github.com/consensys/gnark-crypto/ecc/stark-curve/fp" +) + +func AsInt(valueFelt *fp.Element) *big.Int { + var valueBig big.Int + valueFelt.BigInt(&valueBig) + boundBig := new(big.Int).Div(fp.Modulus(), big.NewInt(2)) + + // val if val < prime // 2 else val - prime + if valueBig.Cmp(boundBig) == -1 { + return &valueBig + } + return new(big.Int).Sub(&valueBig, fp.Modulus()) +} diff --git a/pkg/hintrunner/zero/hintcode.go b/pkg/hintrunner/zero/hintcode.go index 799da1a14..e31f8920d 100644 --- a/pkg/hintrunner/zero/hintcode.go +++ b/pkg/hintrunner/zero/hintcode.go @@ -33,6 +33,7 @@ const ( // split_int() hints. splitIntAssertRange string = "assert ids.value == 0, 'split_int(): value is out of range.'" splitIntCode string = "memory[ids.output] = res = (int(ids.value) % PRIME) % ids.base\nassert res < ids.bound, f'split_int(): Limb {res} is out of range.'" + signedDivRemCode string = "from starkware.cairo.common.math_utils import as_int, assert_integer\nassert_integer(ids.div)\nassert 0 < ids.div <= PRIME // range_check_builtin.bound, f'div={hex(ids.div)} is out of the valid range.'\nassert_integer(ids.bound)\nassert ids.bound <= range_check_builtin.bound // 2, f'bound={hex(ids.bound)} is out of the valid range.'\nint_value = as_int(ids.value, PRIME)\nq, ids.r = divmod(int_value, ids.div)\nassert -ids.bound <= q < ids.bound, f'{int_value} / {ids.div} = {q} is out of the range [{-ids.bound}, {ids.bound}).'\nids.biased_q = q + ids.bound" // pow hints powCode string = "ids.locs.bit = (ids.prev_locs.exp % PRIME) & 1" diff --git a/pkg/hintrunner/zero/zerohint.go b/pkg/hintrunner/zero/zerohint.go index ea022596d..80c76fe20 100644 --- a/pkg/hintrunner/zero/zerohint.go +++ b/pkg/hintrunner/zero/zerohint.go @@ -82,6 +82,8 @@ func GetHintFromCode(program *zero.ZeroProgram, rawHint zero.Hint, hintPC uint64 return createSplitIntAssertRangeHinter(resolver) case splitIntCode: return createSplitIntHinter(resolver) + case signedDivRemCode: + return createSignedDivRemHinter(resolver) case powCode: return createPowHinter(resolver) case splitFeltCode: diff --git a/pkg/hintrunner/zero/zerohint_math.go b/pkg/hintrunner/zero/zerohint_math.go index bfc7a7661..e09d58691 100644 --- a/pkg/hintrunner/zero/zerohint_math.go +++ b/pkg/hintrunner/zero/zerohint_math.go @@ -8,6 +8,7 @@ import ( "github.com/NethermindEth/cairo-vm-go/pkg/hintrunner/core" "github.com/NethermindEth/cairo-vm-go/pkg/hintrunner/hinter" + math_utils "github.com/NethermindEth/cairo-vm-go/pkg/hintrunner/utils" "github.com/NethermindEth/cairo-vm-go/pkg/utils" VM "github.com/NethermindEth/cairo-vm-go/pkg/vm" "github.com/NethermindEth/cairo-vm-go/pkg/vm/memory" @@ -643,6 +644,105 @@ func createSplitFeltHinter(resolver hintReferenceResolver) (hinter.Hinter, error return newSplitFeltHint(low, high, value), nil } +func newSignedDivRemHint(value, div, bound, r, biased_q hinter.ResOperander) hinter.Hinter { + return &GenericZeroHinter{ + Name: "SignedDivRem", + Op: func(vm *VM.VirtualMachine, _ *hinter.HintRunnerContext) error { + //> from starkware.cairo.common.math_utils import as_int, assert_integer + //> assert_integer(ids.div) + //> assert 0 < ids.div <= PRIME // range_check_builtin.bound, f'div={hex(ids.div)} is out of the valid range.' + //> assert_integer(ids.bound) + //> assert ids.bound <= range_check_builtin.bound // 2, f'bound={hex(ids.bound)} is out of the valid range.' + //> int_value = as_int(ids.value, PRIME) + //> q, ids.r = divmod(int_value, ids.div) + //> assert -ids.bound <= q < ids.bound, f'{int_value} / {ids.div} = {q} is out of the range [{-ids.bound}, {ids.bound}).' + //> ids.biased_q = q + ids.bound + + //> assert_integer(ids.div) + divFelt, err := hinter.ResolveAsFelt(vm, div) + if err != nil { + return err + } + //> assert 0 < ids.div <= PRIME // range_check_builtin.bound, f'div={hex(ids.div)} is out of the valid range.' + if divFelt.IsZero() || !utils.FeltLe(divFelt, &utils.PrimeHigh) { + return fmt.Errorf("div=%v is out of the valid range.", divFelt) + } + + //> assert_integer(ids.bound) + boundFelt, err := hinter.ResolveAsFelt(vm, bound) + if err != nil { + return err + } + //> assert ids.bound <= range_check_builtin.bound // 2, f'bound={hex(ids.bound)} is out of the valid range.' + if !utils.FeltLe(boundFelt, &utils.Felt127) { + return fmt.Errorf("bound=%v is out of the valid range.", boundFelt) + } + //> int_value = as_int(ids.value, PRIME) + valueFelt, err := hinter.ResolveAsFelt(vm, value) + if err != nil { + return err + } + intValueBig := math_utils.AsInt(valueFelt) + + //> q, ids.r = divmod(int_value, ids.div) + var divBig, boundBig big.Int + divFelt.BigInt(&divBig) + boundFelt.BigInt(&boundBig) + qBig, rBig := new(big.Int).DivMod(intValueBig, &divBig, new(big.Int)) + rFelt := new(fp.Element).SetBigInt(rBig) + rAddr, err := r.GetAddress(vm) + if err != nil { + return err + } + rValue := memory.MemoryValueFromFieldElement(rFelt) + err = vm.Memory.WriteToAddress(&rAddr, &rValue) + if err != nil { + return err + } + + //> assert -ids.bound <= q < ids.bound, f'{int_value} / {ids.div} = {q} is out of the range [{-ids.bound}, {ids.bound}).' + if !(qBig.Cmp(new(big.Int).Neg(&boundBig)) >= 0 && qBig.Cmp(&boundBig) == -1) { + return fmt.Errorf("%v / %v = %v is out of the range [-%v, %v].", valueFelt, divFelt, qBig, boundFelt, boundFelt) + } + + //> ids.biased_q = q + ids.bound + biasedQBig := new(big.Int).Add(qBig, &boundBig) + biasedQ := new(fp.Element).SetBigInt(biasedQBig) + biasedQAddr, err := biased_q.GetAddress(vm) + if err != nil { + return err + } + biasedQValue := memory.MemoryValueFromFieldElement(biasedQ) + return vm.Memory.WriteToAddress(&biasedQAddr, &biasedQValue) + }, + } +} + +func createSignedDivRemHinter(resolver hintReferenceResolver) (hinter.Hinter, error) { + value, err := resolver.GetResOperander("value") + if err != nil { + return nil, err + } + div, err := resolver.GetResOperander("div") + if err != nil { + return nil, err + } + bound, err := resolver.GetResOperander("bound") + if err != nil { + return nil, err + } + r, err := resolver.GetResOperander("r") + if err != nil { + return nil, err + } + biased_q, err := resolver.GetResOperander("biased_q") + if err != nil { + return nil, err + } + return newSignedDivRemHint(value, div, bound, r, biased_q), nil + +} + func newSqrtHint(root, value hinter.ResOperander) hinter.Hinter { return &GenericZeroHinter{ Name: "Sqrt", @@ -719,11 +819,10 @@ func newUnsignedDivRemHinter(value, div, q, r hinter.ResOperander) hinter.Hinter if err != nil { return err } - // (PRIME // range_check_builtin.bound) // 800000000000011000000000000000000000000000000000000000000000001 // 2**128 var divUpperBound big.Int - divUpperBound.SetString("8000000000000110000000000000000", 16) + utils.PrimeHigh.BigInt(&divUpperBound) var divBig big.Int div.BigInt(&divBig) diff --git a/pkg/hintrunner/zero/zerohint_math_test.go b/pkg/hintrunner/zero/zerohint_math_test.go index 317f9879d..5334f7ad4 100644 --- a/pkg/hintrunner/zero/zerohint_math_test.go +++ b/pkg/hintrunner/zero/zerohint_math_test.go @@ -1,6 +1,8 @@ package zero import ( + "fmt" + "math/big" "testing" "github.com/NethermindEth/cairo-vm-go/pkg/hintrunner/hinter" @@ -10,6 +12,7 @@ import ( ) func TestZeroHintMath(t *testing.T) { + runHinterTests(t, map[string][]hintTestCase{ "IsLeFelt": { { @@ -619,7 +622,105 @@ func TestZeroHintMath(t *testing.T) { }), }, }, - + "SignedDivRem": { + { + operanders: []*hintOperander{ + {Name: "value", Kind: apRelative, Value: &utils.FeltZero}, + {Name: "div", Kind: apRelative, Value: &utils.FeltMax128}, + {Name: "bound", Kind: apRelative, Value: &utils.Felt127}, + {Name: "r", Kind: reference, Value: addrBuiltin(starknet.RangeCheck, 0)}, + {Name: "biased_q", Kind: reference, Value: addrBuiltin(starknet.RangeCheck, 1)}, + }, + makeHinter: func(ctx *hintTestContext) hinter.Hinter { + return newSignedDivRemHint(ctx.operanders["value"], ctx.operanders["div"], ctx.operanders["bound"], ctx.operanders["r"], ctx.operanders["biased_q"]) + }, + errCheck: errorTextContains(fmt.Sprintf("div=%v is out of the valid range.", &utils.FeltMax128)), + }, + { + operanders: []*hintOperander{ + {Name: "value", Kind: apRelative, Value: &utils.FeltZero}, + {Name: "div", Kind: apRelative, Value: &utils.FeltZero}, + {Name: "bound", Kind: apRelative, Value: &utils.Felt127}, + {Name: "r", Kind: reference, Value: addrBuiltin(starknet.RangeCheck, 0)}, + {Name: "biased_q", Kind: reference, Value: addrBuiltin(starknet.RangeCheck, 1)}, + }, + makeHinter: func(ctx *hintTestContext) hinter.Hinter { + return newSignedDivRemHint(ctx.operanders["value"], ctx.operanders["div"], ctx.operanders["bound"], ctx.operanders["r"], ctx.operanders["biased_q"]) + }, + errCheck: errorTextContains(fmt.Sprintf("div=%v is out of the valid range.", &utils.FeltZero)), + }, + { + operanders: []*hintOperander{ + {Name: "value", Kind: apRelative, Value: &utils.FeltZero}, + {Name: "div", Kind: apRelative, Value: &utils.FeltOne}, + {Name: "bound", Kind: apRelative, Value: new(fp.Element).SetBigInt(new(big.Int).Lsh(big.NewInt(1), 130))}, + {Name: "r", Kind: reference, Value: addrBuiltin(starknet.RangeCheck, 0)}, + {Name: "biased_q", Kind: reference, Value: addrBuiltin(starknet.RangeCheck, 1)}, + }, + makeHinter: func(ctx *hintTestContext) hinter.Hinter { + return newSignedDivRemHint(ctx.operanders["value"], ctx.operanders["div"], ctx.operanders["bound"], ctx.operanders["r"], ctx.operanders["biased_q"]) + }, + errCheck: errorTextContains(fmt.Sprintf("bound=%v is out of the valid range", new(fp.Element).SetBigInt(new(big.Int).Lsh(big.NewInt(1), 130)))), + }, + { + operanders: []*hintOperander{ + {Name: "value", Kind: apRelative, Value: feltInt64(-6)}, + {Name: "div", Kind: apRelative, Value: feltInt64(2)}, + {Name: "bound", Kind: apRelative, Value: feltInt64(2)}, + {Name: "r", Kind: reference, Value: addrBuiltin(starknet.RangeCheck, 0)}, + {Name: "biased_q", Kind: reference, Value: addrBuiltin(starknet.RangeCheck, 1)}, + }, + makeHinter: func(ctx *hintTestContext) hinter.Hinter { + return newSignedDivRemHint(ctx.operanders["value"], ctx.operanders["div"], ctx.operanders["bound"], ctx.operanders["r"], ctx.operanders["biased_q"]) + }, + errCheck: errorTextContains(fmt.Sprintf("%v / %v = %v is out of the range [-%v, %v]", feltInt64(-6), feltInt64(2), feltInt64(-3), feltInt64(2), feltInt64(2))), + }, + { + operanders: []*hintOperander{ + {Name: "value", Kind: apRelative, Value: feltInt64(6)}, + {Name: "div", Kind: apRelative, Value: feltInt64(2)}, + {Name: "bound", Kind: apRelative, Value: feltInt64(3)}, + {Name: "r", Kind: reference, Value: addrBuiltin(starknet.RangeCheck, 0)}, + {Name: "biased_q", Kind: reference, Value: addrBuiltin(starknet.RangeCheck, 1)}, + }, + makeHinter: func(ctx *hintTestContext) hinter.Hinter { + return newSignedDivRemHint(ctx.operanders["value"], ctx.operanders["div"], ctx.operanders["bound"], ctx.operanders["r"], ctx.operanders["biased_q"]) + }, + errCheck: errorTextContains(fmt.Sprintf("%v / %v = %v is out of the range [-%v, %v].", feltInt64(6), feltInt64(2), feltInt64(3), feltInt64(3), feltInt64(3))), + }, + { + operanders: []*hintOperander{ + {Name: "value", Kind: apRelative, Value: feltInt64(5)}, + {Name: "div", Kind: apRelative, Value: feltInt64(2)}, + {Name: "bound", Kind: apRelative, Value: &utils.Felt127}, + {Name: "r", Kind: reference, Value: addrBuiltin(starknet.RangeCheck, 0)}, + {Name: "biased_q", Kind: reference, Value: addrBuiltin(starknet.RangeCheck, 1)}, + }, + makeHinter: func(ctx *hintTestContext) hinter.Hinter { + return newSignedDivRemHint(ctx.operanders["value"], ctx.operanders["div"], ctx.operanders["bound"], ctx.operanders["r"], ctx.operanders["biased_q"]) + }, + check: allVarValueEquals(map[string]*fp.Element{ + "r": &utils.FeltOne, + "biased_q": new(fp.Element).Add(feltInt64(2), &utils.Felt127), + }), + }, + { + operanders: []*hintOperander{ + {Name: "value", Kind: apRelative, Value: feltInt64(-3)}, + {Name: "div", Kind: apRelative, Value: feltInt64(2)}, + {Name: "bound", Kind: apRelative, Value: &utils.Felt127}, + {Name: "r", Kind: reference, Value: addrBuiltin(starknet.RangeCheck, 0)}, + {Name: "biased_q", Kind: reference, Value: addrBuiltin(starknet.RangeCheck, 1)}, + }, + makeHinter: func(ctx *hintTestContext) hinter.Hinter { + return newSignedDivRemHint(ctx.operanders["value"], ctx.operanders["div"], ctx.operanders["bound"], ctx.operanders["r"], ctx.operanders["biased_q"]) + }, + check: allVarValueEquals(map[string]*fp.Element{ + "r": &utils.FeltOne, + "biased_q": new(fp.Element).Sub(&utils.Felt127, feltInt64(2)), + }), + }, + }, "SqrtHint": { { operanders: []*hintOperander{ diff --git a/pkg/utils/constant.go b/pkg/utils/constant.go index 0202a2969..994957e56 100644 --- a/pkg/utils/constant.go +++ b/pkg/utils/constant.go @@ -26,6 +26,10 @@ var FeltMax128 = fp.Element{18446744073700081665, 17407, 18446744073709551584, 5 // 2 ** 250 var FeltUpperBound = fp.Element{0xfffffff5cdf80011, 0x4cc3fff, 0xfffffffffffdbe00, 0x7ffff52ad780230} +// (PRIME // range_check_builtin.bound) +// 800000000000011000000000000000000000000000000000000000000000001 // 2**128 +var PrimeHigh = fp.Element{1, 0, 18446744073709551615, 576460752303423504} + // // Uint256 Constants //