From a86cb4bafa028db736461ab83129fbedb8e0ca65 Mon Sep 17 00:00:00 2001 From: Shourya Goel Date: Wed, 3 Jul 2024 19:54:56 +0530 Subject: [PATCH] Implement `RecoverY` hint (#506) * bug fix * Implemented the hint * Added test for recoverY * Added more tests * Fix * Update hintcode.go * Update zerohint.go * nit * Refactored code * nit * Cleaned the code * Modified IsQuadResidue to use helper function * Update math_utils.go --- pkg/hintrunner/utils/math_utils.go | 39 +++++++++++- pkg/hintrunner/zero/hintcode.go | 1 + pkg/hintrunner/zero/zerohint.go | 2 + pkg/hintrunner/zero/zerohint_ec.go | 83 +++++++++++++++++++++++++ pkg/hintrunner/zero/zerohint_ec_test.go | 80 ++++++++++++++++++++++++ pkg/hintrunner/zero/zerohint_math.go | 42 +++++-------- pkg/vm/builtins/ecdsa.go | 2 +- 7 files changed, 218 insertions(+), 31 deletions(-) diff --git a/pkg/hintrunner/utils/math_utils.go b/pkg/hintrunner/utils/math_utils.go index 6a112587..9b789b6f 100644 --- a/pkg/hintrunner/utils/math_utils.go +++ b/pkg/hintrunner/utils/math_utils.go @@ -129,10 +129,45 @@ func sign(n *big.Int) (int, big.Int) { func SafeDiv(x, y *big.Int) (big.Int, error) { if y.Cmp(big.NewInt(0)) == 0 { - return *big.NewInt(0), fmt.Errorf("Division by zero.") + return *big.NewInt(0), fmt.Errorf("division by zero") } if new(big.Int).Mod(x, y).Cmp(big.NewInt(0)) != 0 { - return *big.NewInt(0), fmt.Errorf("%v is not divisible by %v.", x, y) + return *big.NewInt(0), fmt.Errorf("%v is not divisible by %v", x, y) } return *new(big.Int).Div(x, y), nil } + +func IsQuadResidue(x *fp.Element) bool { + // Implementation adapted from sympy implementation which can be found here : + // https://github.com/sympy/sympy/blob/d91b8ad6d36a59a879cc70e5f4b379da5fdd46ce/sympy/ntheory/residue_ntheory.py#L689 + // We have omitted the prime as it will be CAIRO_PRIME + + return x.IsZero() || x.IsOne() || x.Legendre() == 1 +} + +func YSquaredFromX(x, beta, fieldPrime *big.Int) *big.Int { + // Computes y^2 using the curve equation: + // y^2 = x^3 + alpha * x + beta (mod field_prime) + // We ignore alpha as it is a constant with a value of 1 + + ySquaredBigInt := new(big.Int).Set(x) + ySquaredBigInt.Mul(ySquaredBigInt, x).Mod(ySquaredBigInt, fieldPrime) + ySquaredBigInt.Mul(ySquaredBigInt, x).Mod(ySquaredBigInt, fieldPrime) + ySquaredBigInt.Add(ySquaredBigInt, x).Mod(ySquaredBigInt, fieldPrime) + ySquaredBigInt.Add(ySquaredBigInt, beta).Mod(ySquaredBigInt, fieldPrime) + + return ySquaredBigInt +} + +func Sqrt(x, p *big.Int) *big.Int { + // Finds the minimum non-negative integer m such that (m*m) % p == x. + + halfPrimeBigInt := new(big.Int).Rsh(p, 1) + m := new(big.Int).ModSqrt(x, p) + + if m.Cmp(halfPrimeBigInt) > 0 { + m.Sub(p, m) + } + + return m +} diff --git a/pkg/hintrunner/zero/hintcode.go b/pkg/hintrunner/zero/hintcode.go index 34808f0c..c3d7e2c0 100644 --- a/pkg/hintrunner/zero/hintcode.go +++ b/pkg/hintrunner/zero/hintcode.go @@ -97,6 +97,7 @@ ids.multiplicities = segments.gen_arg([len(positions_dict[k]) for k in output])` isZeroNondetCode string = "memory[ap] = to_felt_or_relocatable(x == 0)" isZeroPackCode string = "from starkware.cairo.common.cairo_secp.secp_utils import SECP_P, pack\n\nx = pack(ids.x, PRIME) % SECP_P" isZeroDivModCode string = "from starkware.cairo.common.cairo_secp.secp_utils import SECP_P\nfrom starkware.python.math_utils import div_mod\n\nvalue = x_inv = div_mod(1, x, SECP_P)" + recoverYCode string = "from starkware.crypto.signature.signature import ALPHA, BETA, FIELD_PRIME\nfrom starkware.python.math_utils import recover_y\nids.p.x = ids.x\n# This raises an exception if `x` is not on the curve.\nids.p.y = recover_y(ids.x, ALPHA, BETA, FIELD_PRIME)" // ------ Signature hints related code ------ verifyECDSASignatureCode string = "ecdsa_builtin.add_signature(ids.ecdsa_ptr.address_, (ids.signature_r, ids.signature_s))" diff --git a/pkg/hintrunner/zero/zerohint.go b/pkg/hintrunner/zero/zerohint.go index f4fdc214..51c08dd8 100644 --- a/pkg/hintrunner/zero/zerohint.go +++ b/pkg/hintrunner/zero/zerohint.go @@ -149,6 +149,8 @@ func GetHintFromCode(program *zero.ZeroProgram, rawHint zero.Hint, hintPC uint64 return createIsZeroPackHinter(resolver) case isZeroDivModCode: return createIsZeroDivModHinter() + case recoverYCode: + return createRecoverYHinter(resolver) // Blake hints case blake2sAddUint256BigendCode: return createBlake2sAddUint256Hinter(resolver, true) diff --git a/pkg/hintrunner/zero/zerohint_ec.go b/pkg/hintrunner/zero/zerohint_ec.go index 3424ccd3..02838a13 100644 --- a/pkg/hintrunner/zero/zerohint_ec.go +++ b/pkg/hintrunner/zero/zerohint_ec.go @@ -862,3 +862,86 @@ func newIsZeroDivModHint() hinter.Hinter { func createIsZeroDivModHinter() (hinter.Hinter, error) { return newIsZeroDivModHint(), nil } + +// RecoverY hint Recovers the y coordinate of a point on the elliptic curve +// y^2 = x^3 + alpha * x + beta (mod field_prime) of a given x coordinate. +// +// `newRecoverYHint` takes 2 operanders as arguments +// - `x` is the x coordinate of an elliptic curve point +// - `p` is one of the two EC points with the given x coordinate (x, y) +func newRecoverYHint(x, p hinter.ResOperander) hinter.Hinter { + return &GenericZeroHinter{ + Name: "RecoverY", + Op: func(vm *VM.VirtualMachine, ctx *hinter.HintRunnerContext) error { + //> from starkware.crypto.signature.signature import ALPHA, BETA, FIELD_PRIME + //> from starkware.python.math_utils import recover_y + //> ids.p.x = ids.x + //> # This raises an exception if `x` is not on the curve. + //> ids.p.y = recover_y(ids.x, ALPHA, BETA, FIELD_PRIME) + + pXAddr, err := p.GetAddress(vm) + if err != nil { + return err + } + + pYAddr, err := pXAddr.AddOffset(1) + if err != nil { + return err + } + + xFelt, err := hinter.ResolveAsFelt(vm, x) + if err != nil { + return err + } + + valueX := mem.MemoryValueFromFieldElement(xFelt) + + err = vm.Memory.WriteToAddress(&pXAddr, &valueX) + if err != nil { + return err + } + + const betaString = "3141592653589793238462643383279502884197169399375105820974944592307816406665" + betaBigInt, ok := new(big.Int).SetString(betaString, 10) + if !ok { + panic("failed to convert BETA string to big.Int") + } + + const fieldPrimeString = "3618502788666131213697322783095070105623107215331596699973092056135872020481" + fieldPrimeBigInt, ok := new(big.Int).SetString(fieldPrimeString, 10) + if !ok { + panic("failed to convert FIELD_PRIME string to big.Int") + } + + xBigInt := new(big.Int) + xFelt.BigInt(xBigInt) + + // y^2 = x^3 + alpha * x + beta (mod field_prime) + ySquaredBigInt := secp_utils.YSquaredFromX(xBigInt, betaBigInt, fieldPrimeBigInt) + ySquaredFelt := new(fp.Element).SetBigInt(ySquaredBigInt) + + if secp_utils.IsQuadResidue(ySquaredFelt) { + result := new(fp.Element).SetBigInt(secp_utils.Sqrt(ySquaredBigInt, fieldPrimeBigInt)) + value := mem.MemoryValueFromFieldElement(result) + return vm.Memory.WriteToAddress(&pYAddr, &value) + } else { + ySquaredString := ySquaredBigInt.String() + return fmt.Errorf("%s does not represent the x coordinate of a point on the curve", ySquaredString) + } + }, + } +} + +func createRecoverYHinter(resolver hintReferenceResolver) (hinter.Hinter, error) { + x, err := resolver.GetResOperander("x") + if err != nil { + return nil, err + } + + p, err := resolver.GetResOperander("p") + if err != nil { + return nil, err + } + + return newRecoverYHint(x, p), nil +} diff --git a/pkg/hintrunner/zero/zerohint_ec_test.go b/pkg/hintrunner/zero/zerohint_ec_test.go index 5e5abb76..07f366c8 100644 --- a/pkg/hintrunner/zero/zerohint_ec_test.go +++ b/pkg/hintrunner/zero/zerohint_ec_test.go @@ -974,6 +974,86 @@ func TestZeroHintEc(t *testing.T) { check: varValueInScopeEquals("value", bigIntString("4", 10)), }, }, + "RecoverY": { + { + operanders: []*hintOperander{ + {Name: "x", Kind: apRelative, Value: feltString("2497468900767850684421727063357792717599762502387246235265616708902555305129")}, + {Name: "p.x", Kind: uninitialized}, + {Name: "p.y", Kind: uninitialized}, + }, + makeHinter: func(ctx *hintTestContext) hinter.Hinter { + return newRecoverYHint(ctx.operanders["x"], ctx.operanders["p.x"]) + }, + check: allVarValueEquals(map[string]*fp.Element{ + "p.x": feltString("2497468900767850684421727063357792717599762502387246235265616708902555305129"), + "p.y": feltString("205857351767627712295703269674687767888261140702556021834663354704341414042"), + }), + }, + { + operanders: []*hintOperander{ + {Name: "x", Kind: apRelative, Value: feltString("205857351767627712295703269674687767888261140702556021834663354704341414042")}, + {Name: "p.x", Kind: uninitialized}, + {Name: "p.y", Kind: uninitialized}, + }, + makeHinter: func(ctx *hintTestContext) hinter.Hinter { + return newRecoverYHint(ctx.operanders["x"], ctx.operanders["p.x"]) + }, + errCheck: errorTextContains("does not represent the x coordinate of a point on the curve"), + }, + { + operanders: []*hintOperander{ + {Name: "x", Kind: apRelative, Value: feltString("3004956058830981475544150447242655232275382685012344776588097793621230049020")}, + {Name: "p.x", Kind: uninitialized}, + {Name: "p.y", Kind: uninitialized}, + }, + makeHinter: func(ctx *hintTestContext) hinter.Hinter { + return newRecoverYHint(ctx.operanders["x"], ctx.operanders["p.x"]) + }, + check: allVarValueEquals(map[string]*fp.Element{ + "p.x": feltString("3004956058830981475544150447242655232275382685012344776588097793621230049020"), + "p.y": feltString("386236054595386575795345623791920124827519018828430310912260655089307618738"), + }), + }, + { + operanders: []*hintOperander{ + {Name: "x", Kind: apRelative, Value: feltString("138597138396302485058562442936200017709939129389766076747102238692717075504")}, + {Name: "p.x", Kind: uninitialized}, + {Name: "p.y", Kind: uninitialized}, + }, + makeHinter: func(ctx *hintTestContext) hinter.Hinter { + return newRecoverYHint(ctx.operanders["x"], ctx.operanders["p.x"]) + }, + check: allVarValueEquals(map[string]*fp.Element{ + "p.x": feltString("138597138396302485058562442936200017709939129389766076747102238692717075504"), + "p.y": feltString("1116947097676727397390632683964789044871379304271794004325353078455954290524"), + }), + }, + { + operanders: []*hintOperander{ + {Name: "x", Kind: apRelative, Value: feltString("71635783675677659163985681365816684268526846280467284682674852685628658265882465826464572245")}, + {Name: "p.x", Kind: uninitialized}, + {Name: "p.y", Kind: uninitialized}, + }, + makeHinter: func(ctx *hintTestContext) hinter.Hinter { + return newRecoverYHint(ctx.operanders["x"], ctx.operanders["p.x"]) + }, + check: allVarValueEquals(map[string]*fp.Element{ + "p.x": feltString("71635783675677659163985681365816684268526846280467284682674852685628658265882465826464572245"), + "p.y": feltString("903372048565605391120071143811887302063650776015287438589675702929494830362"), + }), + }, + { + operanders: []*hintOperander{ + {Name: "x", Kind: apRelative, Value: feltString("42424242424242424242")}, + {Name: "p.x", Kind: uninitialized}, + {Name: "p.y", Kind: uninitialized}, + }, + makeHinter: func(ctx *hintTestContext) hinter.Hinter { + return newRecoverYHint(ctx.operanders["x"], ctx.operanders["p.x"]) + }, + errCheck: errorTextContains("does not represent the x coordinate of a point on the curve"), + }, + }, }, ) } diff --git a/pkg/hintrunner/zero/zerohint_math.go b/pkg/hintrunner/zero/zerohint_math.go index 058edcff..3817411b 100644 --- a/pkg/hintrunner/zero/zerohint_math.go +++ b/pkg/hintrunner/zero/zerohint_math.go @@ -1152,41 +1152,27 @@ func newIsQuadResidueHint(x, y hinter.ResOperander) hinter.Hinter { xBigInt := math_utils.AsInt(x) var value = memory.MemoryValue{} + var result *fp.Element = new(fp.Element) - if x.IsZero() || x.IsOne() { - value = memory.MemoryValueFromFieldElement(x) + const primeString = "3618502788666131213697322783095070105623107215331596699973092056135872020481" + primeBigInt, ok := new(big.Int).SetString(primeString, 10) + if !ok { + panic("failed to convert prime string to big.Int") + } + if math_utils.IsQuadResidue(x) { + result.SetBigInt(math_utils.Sqrt(&xBigInt, primeBigInt)) } else { - var result *fp.Element = new(fp.Element) - - if x.Legendre() == 1 { - // result = x.Sqrt(x) - - const primeString = "3618502788666131213697322783095070105623107215331596699973092056135872020481" - primeBigInt, ok := new(big.Int).SetString(primeString, 10) - if !ok { - panic("failed to convert prime string to big.Int") - } - - // divide primeBigInt by 2 - halfPrimeBigInt := new(big.Int).Rsh(primeBigInt, 1) - - tempResult := new(big.Int).ModSqrt(&xBigInt, primeBigInt) - - // ensures that tempResult is the smaller of the two possible square roots in the prime field. - if tempResult.Cmp(halfPrimeBigInt) > 0 { - tempResult.Sub(primeBigInt, tempResult) - } - - result.SetBigInt(tempResult) - - } else { - result = x.Sqrt(new(fp.Element).Div(x, new(fp.Element).SetUint64(3))) + y, err := math_utils.Divmod(&xBigInt, big.NewInt(3), primeBigInt) + if err != nil { + return err } - value = memory.MemoryValueFromFieldElement(result) + result.SetBigInt(math_utils.Sqrt(&y, primeBigInt)) } + value = memory.MemoryValueFromFieldElement(result) + return vm.Memory.WriteToAddress(&yAddr, &value) }, } diff --git a/pkg/vm/builtins/ecdsa.go b/pkg/vm/builtins/ecdsa.go index 8a1f6686..9b11e6f2 100644 --- a/pkg/vm/builtins/ecdsa.go +++ b/pkg/vm/builtins/ecdsa.go @@ -60,7 +60,7 @@ func (e *ECDSA) CheckWrite(segment *memory.Segment, offset uint64, value *memory pubKey := &ecdsa.PublicKey{A: key} sig, ok := e.signatures[pubOffset] if !ok { - return fmt.Errorf("signature is missing form ECDA builtin") + return fmt.Errorf("signature is missing from ECDSA builtin") } msgBytes := msgField.Bytes()