Skip to content

Commit

Permalink
Merge branch 'main' into Keccak_integration_tests
Browse files Browse the repository at this point in the history
  • Loading branch information
TAdev0 committed Jul 4, 2024
2 parents 4f280f1 + 6cab0ef commit db69fb7
Show file tree
Hide file tree
Showing 9 changed files with 1,128 additions and 31 deletions.
39 changes: 37 additions & 2 deletions pkg/hintrunner/utils/math_utils.go
Original file line number Diff line number Diff line change
Expand Up @@ -129,10 +129,45 @@ func sign(n *big.Int) (int, big.Int) {

func SafeDiv(x, y *big.Int) (big.Int, error) {
if y.Cmp(big.NewInt(0)) == 0 {
return *big.NewInt(0), fmt.Errorf("Division by zero.")
return *big.NewInt(0), fmt.Errorf("division by zero")
}
if new(big.Int).Mod(x, y).Cmp(big.NewInt(0)) != 0 {
return *big.NewInt(0), fmt.Errorf("%v is not divisible by %v.", x, y)
return *big.NewInt(0), fmt.Errorf("%v is not divisible by %v", x, y)
}
return *new(big.Int).Div(x, y), nil
}

func IsQuadResidue(x *fp.Element) bool {
// Implementation adapted from sympy implementation which can be found here :
// https://github.com/sympy/sympy/blob/d91b8ad6d36a59a879cc70e5f4b379da5fdd46ce/sympy/ntheory/residue_ntheory.py#L689
// We have omitted the prime as it will be CAIRO_PRIME

return x.IsZero() || x.IsOne() || x.Legendre() == 1
}

func YSquaredFromX(x, beta, fieldPrime *big.Int) *big.Int {
// Computes y^2 using the curve equation:
// y^2 = x^3 + alpha * x + beta (mod field_prime)
// We ignore alpha as it is a constant with a value of 1

ySquaredBigInt := new(big.Int).Set(x)
ySquaredBigInt.Mul(ySquaredBigInt, x).Mod(ySquaredBigInt, fieldPrime)
ySquaredBigInt.Mul(ySquaredBigInt, x).Mod(ySquaredBigInt, fieldPrime)
ySquaredBigInt.Add(ySquaredBigInt, x).Mod(ySquaredBigInt, fieldPrime)
ySquaredBigInt.Add(ySquaredBigInt, beta).Mod(ySquaredBigInt, fieldPrime)

return ySquaredBigInt
}

func Sqrt(x, p *big.Int) *big.Int {
// Finds the minimum non-negative integer m such that (m*m) % p == x.

halfPrimeBigInt := new(big.Int).Rsh(p, 1)
m := new(big.Int).ModSqrt(x, p)

if m.Cmp(halfPrimeBigInt) > 0 {
m.Sub(p, m)
}

return m
}
7 changes: 7 additions & 0 deletions pkg/hintrunner/zero/hintcode.go
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,7 @@ const (

// is_quad_residue() hint
isQuadResidueCode string = "from starkware.crypto.signature.signature import FIELD_PRIME\nfrom starkware.python.math_utils import div_mod, is_quad_residue, sqrt\n\nx = ids.x\nif is_quad_residue(x, FIELD_PRIME):\n ids.y = sqrt(x, FIELD_PRIME)\nelse:\n ids.y = sqrt(div_mod(x, 3, FIELD_PRIME), FIELD_PRIME)"

// ------ Uint256 hints related code ------
uint256AddCode string = "sum_low = ids.a.low + ids.b.low\nids.carry_low = 1 if sum_low >= ids.SHIFT else 0\nsum_high = ids.a.high + ids.b.high + ids.carry_low\nids.carry_high = 1 if sum_high >= ids.SHIFT else 0"
split64Code string = "ids.low = ids.a & ((1<<64) - 1)\nids.high = ids.a >> 64"
Expand Down Expand Up @@ -97,6 +98,7 @@ ids.multiplicities = segments.gen_arg([len(positions_dict[k]) for k in output])`
isZeroNondetCode string = "memory[ap] = to_felt_or_relocatable(x == 0)"
isZeroPackCode string = "from starkware.cairo.common.cairo_secp.secp_utils import SECP_P, pack\n\nx = pack(ids.x, PRIME) % SECP_P"
isZeroDivModCode string = "from starkware.cairo.common.cairo_secp.secp_utils import SECP_P\nfrom starkware.python.math_utils import div_mod\n\nvalue = x_inv = div_mod(1, x, SECP_P)"
recoverYCode string = "from starkware.crypto.signature.signature import ALPHA, BETA, FIELD_PRIME\nfrom starkware.python.math_utils import recover_y\nids.p.x = ids.x\n# This raises an exception if `x` is not on the curve.\nids.p.y = recover_y(ids.x, ALPHA, BETA, FIELD_PRIME)"

// ------ Signature hints related code ------
verifyECDSASignatureCode string = "ecdsa_builtin.add_signature(ids.ecdsa_ptr.address_, (ids.signature_r, ids.signature_s))"
Expand All @@ -120,6 +122,11 @@ ids.multiplicities = segments.gen_arg([len(positions_dict[k]) for k in output])`
blockPermutationCode string = "from starkware.cairo.common.keccak_utils.keccak_utils import keccak_func\n_keccak_state_size_felts = int(ids.KECCAK_STATE_SIZE_FELTS)\nassert 0 <= _keccak_state_size_felts < 100\n\noutput_values = keccak_func(memory.get_range(\n ids.keccak_ptr - _keccak_state_size_felts, _keccak_state_size_felts))\nsegments.write_arg(ids.keccak_ptr, output_values)"
compareBytesInWordCode string = "memory[ap] = to_felt_or_relocatable(ids.n_bytes < ids.BYTES_IN_WORD)"
compareKeccakFullRateInBytesCode string = "memory[ap] = to_felt_or_relocatable(ids.n_bytes >= ids.KECCAK_FULL_RATE_IN_BYTES)"
splitInput3Code string = "ids.high3, ids.low3 = divmod(memory[ids.inputs + 3], 256)"
splitInput6Code string = "ids.high6, ids.low6 = divmod(memory[ids.inputs + 6], 256 ** 2)"
splitInput9Code string = "ids.high9, ids.low9 = divmod(memory[ids.inputs + 9], 256 ** 3)"
splitInput12Code string = "ids.high12, ids.low12 = divmod(memory[ids.inputs + 12], 256 ** 4)"
splitInput15Code string = "ids.high15, ids.low15 = divmod(memory[ids.inputs + 15], 256 ** 5)"
splitOutputMidLowHighCode string = "tmp, ids.output1_low = divmod(ids.output1, 256 ** 7)\nids.output1_high, ids.output1_mid = divmod(tmp, 2 ** 128)"
SplitNBytesCode string = "ids.n_words_to_copy, ids.n_bytes_left = divmod(ids.n_bytes, ids.BYTES_IN_WORD)"

Expand Down
12 changes: 12 additions & 0 deletions pkg/hintrunner/zero/zerohint.go
Original file line number Diff line number Diff line change
Expand Up @@ -149,6 +149,8 @@ func GetHintFromCode(program *zero.ZeroProgram, rawHint zero.Hint, hintPC uint64
return createIsZeroPackHinter(resolver)
case isZeroDivModCode:
return createIsZeroDivModHinter()
case recoverYCode:
return createRecoverYHinter(resolver)
// Blake hints
case blake2sAddUint256BigendCode:
return createBlake2sAddUint256Hinter(resolver, true)
Expand All @@ -173,6 +175,16 @@ func GetHintFromCode(program *zero.ZeroProgram, rawHint zero.Hint, hintPC uint64
return createBlockPermutationHinter(resolver)
case compareBytesInWordCode:
return createCompareBytesInWordNondetHinter(resolver)
case splitInput3Code:
return createSplitInput3Hinter(resolver)
case splitInput6Code:
return createSplitInput6Hinter(resolver)
case splitInput9Code:
return createSplitInput9Hinter(resolver)
case splitInput12Code:
return createSplitInput12Hinter(resolver)
case splitInput15Code:
return createSplitInput15Hinter(resolver)
case splitOutputMidLowHighCode:
return createSplitOutputMidLowHighHinter(resolver)
case SplitNBytesCode:
Expand Down
83 changes: 83 additions & 0 deletions pkg/hintrunner/zero/zerohint_ec.go
Original file line number Diff line number Diff line change
Expand Up @@ -862,3 +862,86 @@ func newIsZeroDivModHint() hinter.Hinter {
func createIsZeroDivModHinter() (hinter.Hinter, error) {
return newIsZeroDivModHint(), nil
}

// RecoverY hint Recovers the y coordinate of a point on the elliptic curve
// y^2 = x^3 + alpha * x + beta (mod field_prime) of a given x coordinate.
//
// `newRecoverYHint` takes 2 operanders as arguments
// - `x` is the x coordinate of an elliptic curve point
// - `p` is one of the two EC points with the given x coordinate (x, y)
func newRecoverYHint(x, p hinter.ResOperander) hinter.Hinter {
return &GenericZeroHinter{
Name: "RecoverY",
Op: func(vm *VM.VirtualMachine, ctx *hinter.HintRunnerContext) error {
//> from starkware.crypto.signature.signature import ALPHA, BETA, FIELD_PRIME
//> from starkware.python.math_utils import recover_y
//> ids.p.x = ids.x
//> # This raises an exception if `x` is not on the curve.
//> ids.p.y = recover_y(ids.x, ALPHA, BETA, FIELD_PRIME)

pXAddr, err := p.GetAddress(vm)
if err != nil {
return err
}

pYAddr, err := pXAddr.AddOffset(1)
if err != nil {
return err
}

xFelt, err := hinter.ResolveAsFelt(vm, x)
if err != nil {
return err
}

valueX := mem.MemoryValueFromFieldElement(xFelt)

err = vm.Memory.WriteToAddress(&pXAddr, &valueX)
if err != nil {
return err
}

const betaString = "3141592653589793238462643383279502884197169399375105820974944592307816406665"
betaBigInt, ok := new(big.Int).SetString(betaString, 10)
if !ok {
panic("failed to convert BETA string to big.Int")
}

const fieldPrimeString = "3618502788666131213697322783095070105623107215331596699973092056135872020481"
fieldPrimeBigInt, ok := new(big.Int).SetString(fieldPrimeString, 10)
if !ok {
panic("failed to convert FIELD_PRIME string to big.Int")
}

xBigInt := new(big.Int)
xFelt.BigInt(xBigInt)

// y^2 = x^3 + alpha * x + beta (mod field_prime)
ySquaredBigInt := secp_utils.YSquaredFromX(xBigInt, betaBigInt, fieldPrimeBigInt)
ySquaredFelt := new(fp.Element).SetBigInt(ySquaredBigInt)

if secp_utils.IsQuadResidue(ySquaredFelt) {
result := new(fp.Element).SetBigInt(secp_utils.Sqrt(ySquaredBigInt, fieldPrimeBigInt))
value := mem.MemoryValueFromFieldElement(result)
return vm.Memory.WriteToAddress(&pYAddr, &value)
} else {
ySquaredString := ySquaredBigInt.String()
return fmt.Errorf("%s does not represent the x coordinate of a point on the curve", ySquaredString)
}
},
}
}

func createRecoverYHinter(resolver hintReferenceResolver) (hinter.Hinter, error) {
x, err := resolver.GetResOperander("x")
if err != nil {
return nil, err
}

p, err := resolver.GetResOperander("p")
if err != nil {
return nil, err
}

return newRecoverYHint(x, p), nil
}
80 changes: 80 additions & 0 deletions pkg/hintrunner/zero/zerohint_ec_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -974,6 +974,86 @@ func TestZeroHintEc(t *testing.T) {
check: varValueInScopeEquals("value", bigIntString("4", 10)),
},
},
"RecoverY": {
{
operanders: []*hintOperander{
{Name: "x", Kind: apRelative, Value: feltString("2497468900767850684421727063357792717599762502387246235265616708902555305129")},
{Name: "p.x", Kind: uninitialized},
{Name: "p.y", Kind: uninitialized},
},
makeHinter: func(ctx *hintTestContext) hinter.Hinter {
return newRecoverYHint(ctx.operanders["x"], ctx.operanders["p.x"])
},
check: allVarValueEquals(map[string]*fp.Element{
"p.x": feltString("2497468900767850684421727063357792717599762502387246235265616708902555305129"),
"p.y": feltString("205857351767627712295703269674687767888261140702556021834663354704341414042"),
}),
},
{
operanders: []*hintOperander{
{Name: "x", Kind: apRelative, Value: feltString("205857351767627712295703269674687767888261140702556021834663354704341414042")},
{Name: "p.x", Kind: uninitialized},
{Name: "p.y", Kind: uninitialized},
},
makeHinter: func(ctx *hintTestContext) hinter.Hinter {
return newRecoverYHint(ctx.operanders["x"], ctx.operanders["p.x"])
},
errCheck: errorTextContains("does not represent the x coordinate of a point on the curve"),
},
{
operanders: []*hintOperander{
{Name: "x", Kind: apRelative, Value: feltString("3004956058830981475544150447242655232275382685012344776588097793621230049020")},
{Name: "p.x", Kind: uninitialized},
{Name: "p.y", Kind: uninitialized},
},
makeHinter: func(ctx *hintTestContext) hinter.Hinter {
return newRecoverYHint(ctx.operanders["x"], ctx.operanders["p.x"])
},
check: allVarValueEquals(map[string]*fp.Element{
"p.x": feltString("3004956058830981475544150447242655232275382685012344776588097793621230049020"),
"p.y": feltString("386236054595386575795345623791920124827519018828430310912260655089307618738"),
}),
},
{
operanders: []*hintOperander{
{Name: "x", Kind: apRelative, Value: feltString("138597138396302485058562442936200017709939129389766076747102238692717075504")},
{Name: "p.x", Kind: uninitialized},
{Name: "p.y", Kind: uninitialized},
},
makeHinter: func(ctx *hintTestContext) hinter.Hinter {
return newRecoverYHint(ctx.operanders["x"], ctx.operanders["p.x"])
},
check: allVarValueEquals(map[string]*fp.Element{
"p.x": feltString("138597138396302485058562442936200017709939129389766076747102238692717075504"),
"p.y": feltString("1116947097676727397390632683964789044871379304271794004325353078455954290524"),
}),
},
{
operanders: []*hintOperander{
{Name: "x", Kind: apRelative, Value: feltString("71635783675677659163985681365816684268526846280467284682674852685628658265882465826464572245")},
{Name: "p.x", Kind: uninitialized},
{Name: "p.y", Kind: uninitialized},
},
makeHinter: func(ctx *hintTestContext) hinter.Hinter {
return newRecoverYHint(ctx.operanders["x"], ctx.operanders["p.x"])
},
check: allVarValueEquals(map[string]*fp.Element{
"p.x": feltString("71635783675677659163985681365816684268526846280467284682674852685628658265882465826464572245"),
"p.y": feltString("903372048565605391120071143811887302063650776015287438589675702929494830362"),
}),
},
{
operanders: []*hintOperander{
{Name: "x", Kind: apRelative, Value: feltString("42424242424242424242")},
{Name: "p.x", Kind: uninitialized},
{Name: "p.y", Kind: uninitialized},
},
makeHinter: func(ctx *hintTestContext) hinter.Hinter {
return newRecoverYHint(ctx.operanders["x"], ctx.operanders["p.x"])
},
errCheck: errorTextContains("does not represent the x coordinate of a point on the curve"),
},
},
},
)
}
Loading

0 comments on commit db69fb7

Please sign in to comment.