Skip to content

Commit

Permalink
Implement RecoverY hint (#506)
Browse files Browse the repository at this point in the history
* bug fix

* Implemented the hint

* Added test for recoverY

* Added more tests

* Fix

* Update hintcode.go

* Update zerohint.go

* nit

* Refactored code

* nit

* Cleaned the code

* Modified IsQuadResidue to use helper function

* Update math_utils.go
  • Loading branch information
Sh0g0-1758 committed Jul 3, 2024
1 parent 60deaa1 commit a86cb4b
Show file tree
Hide file tree
Showing 7 changed files with 218 additions and 31 deletions.
39 changes: 37 additions & 2 deletions pkg/hintrunner/utils/math_utils.go
Original file line number Diff line number Diff line change
Expand Up @@ -129,10 +129,45 @@ func sign(n *big.Int) (int, big.Int) {

func SafeDiv(x, y *big.Int) (big.Int, error) {
if y.Cmp(big.NewInt(0)) == 0 {
return *big.NewInt(0), fmt.Errorf("Division by zero.")
return *big.NewInt(0), fmt.Errorf("division by zero")
}
if new(big.Int).Mod(x, y).Cmp(big.NewInt(0)) != 0 {
return *big.NewInt(0), fmt.Errorf("%v is not divisible by %v.", x, y)
return *big.NewInt(0), fmt.Errorf("%v is not divisible by %v", x, y)
}
return *new(big.Int).Div(x, y), nil
}

func IsQuadResidue(x *fp.Element) bool {
// Implementation adapted from sympy implementation which can be found here :
// https://github.com/sympy/sympy/blob/d91b8ad6d36a59a879cc70e5f4b379da5fdd46ce/sympy/ntheory/residue_ntheory.py#L689
// We have omitted the prime as it will be CAIRO_PRIME

return x.IsZero() || x.IsOne() || x.Legendre() == 1
}

func YSquaredFromX(x, beta, fieldPrime *big.Int) *big.Int {
// Computes y^2 using the curve equation:
// y^2 = x^3 + alpha * x + beta (mod field_prime)
// We ignore alpha as it is a constant with a value of 1

ySquaredBigInt := new(big.Int).Set(x)
ySquaredBigInt.Mul(ySquaredBigInt, x).Mod(ySquaredBigInt, fieldPrime)
ySquaredBigInt.Mul(ySquaredBigInt, x).Mod(ySquaredBigInt, fieldPrime)
ySquaredBigInt.Add(ySquaredBigInt, x).Mod(ySquaredBigInt, fieldPrime)
ySquaredBigInt.Add(ySquaredBigInt, beta).Mod(ySquaredBigInt, fieldPrime)

return ySquaredBigInt
}

func Sqrt(x, p *big.Int) *big.Int {
// Finds the minimum non-negative integer m such that (m*m) % p == x.

halfPrimeBigInt := new(big.Int).Rsh(p, 1)
m := new(big.Int).ModSqrt(x, p)

if m.Cmp(halfPrimeBigInt) > 0 {
m.Sub(p, m)
}

return m
}
1 change: 1 addition & 0 deletions pkg/hintrunner/zero/hintcode.go
Original file line number Diff line number Diff line change
Expand Up @@ -97,6 +97,7 @@ ids.multiplicities = segments.gen_arg([len(positions_dict[k]) for k in output])`
isZeroNondetCode string = "memory[ap] = to_felt_or_relocatable(x == 0)"
isZeroPackCode string = "from starkware.cairo.common.cairo_secp.secp_utils import SECP_P, pack\n\nx = pack(ids.x, PRIME) % SECP_P"
isZeroDivModCode string = "from starkware.cairo.common.cairo_secp.secp_utils import SECP_P\nfrom starkware.python.math_utils import div_mod\n\nvalue = x_inv = div_mod(1, x, SECP_P)"
recoverYCode string = "from starkware.crypto.signature.signature import ALPHA, BETA, FIELD_PRIME\nfrom starkware.python.math_utils import recover_y\nids.p.x = ids.x\n# This raises an exception if `x` is not on the curve.\nids.p.y = recover_y(ids.x, ALPHA, BETA, FIELD_PRIME)"

// ------ Signature hints related code ------
verifyECDSASignatureCode string = "ecdsa_builtin.add_signature(ids.ecdsa_ptr.address_, (ids.signature_r, ids.signature_s))"
Expand Down
2 changes: 2 additions & 0 deletions pkg/hintrunner/zero/zerohint.go
Original file line number Diff line number Diff line change
Expand Up @@ -149,6 +149,8 @@ func GetHintFromCode(program *zero.ZeroProgram, rawHint zero.Hint, hintPC uint64
return createIsZeroPackHinter(resolver)
case isZeroDivModCode:
return createIsZeroDivModHinter()
case recoverYCode:
return createRecoverYHinter(resolver)
// Blake hints
case blake2sAddUint256BigendCode:
return createBlake2sAddUint256Hinter(resolver, true)
Expand Down
83 changes: 83 additions & 0 deletions pkg/hintrunner/zero/zerohint_ec.go
Original file line number Diff line number Diff line change
Expand Up @@ -862,3 +862,86 @@ func newIsZeroDivModHint() hinter.Hinter {
func createIsZeroDivModHinter() (hinter.Hinter, error) {
return newIsZeroDivModHint(), nil
}

// RecoverY hint Recovers the y coordinate of a point on the elliptic curve
// y^2 = x^3 + alpha * x + beta (mod field_prime) of a given x coordinate.
//
// `newRecoverYHint` takes 2 operanders as arguments
// - `x` is the x coordinate of an elliptic curve point
// - `p` is one of the two EC points with the given x coordinate (x, y)
func newRecoverYHint(x, p hinter.ResOperander) hinter.Hinter {
return &GenericZeroHinter{
Name: "RecoverY",
Op: func(vm *VM.VirtualMachine, ctx *hinter.HintRunnerContext) error {
//> from starkware.crypto.signature.signature import ALPHA, BETA, FIELD_PRIME
//> from starkware.python.math_utils import recover_y
//> ids.p.x = ids.x
//> # This raises an exception if `x` is not on the curve.
//> ids.p.y = recover_y(ids.x, ALPHA, BETA, FIELD_PRIME)

pXAddr, err := p.GetAddress(vm)
if err != nil {
return err
}

pYAddr, err := pXAddr.AddOffset(1)
if err != nil {
return err
}

xFelt, err := hinter.ResolveAsFelt(vm, x)
if err != nil {
return err
}

valueX := mem.MemoryValueFromFieldElement(xFelt)

err = vm.Memory.WriteToAddress(&pXAddr, &valueX)
if err != nil {
return err
}

const betaString = "3141592653589793238462643383279502884197169399375105820974944592307816406665"
betaBigInt, ok := new(big.Int).SetString(betaString, 10)
if !ok {
panic("failed to convert BETA string to big.Int")
}

const fieldPrimeString = "3618502788666131213697322783095070105623107215331596699973092056135872020481"
fieldPrimeBigInt, ok := new(big.Int).SetString(fieldPrimeString, 10)
if !ok {
panic("failed to convert FIELD_PRIME string to big.Int")
}

xBigInt := new(big.Int)
xFelt.BigInt(xBigInt)

// y^2 = x^3 + alpha * x + beta (mod field_prime)
ySquaredBigInt := secp_utils.YSquaredFromX(xBigInt, betaBigInt, fieldPrimeBigInt)
ySquaredFelt := new(fp.Element).SetBigInt(ySquaredBigInt)

if secp_utils.IsQuadResidue(ySquaredFelt) {
result := new(fp.Element).SetBigInt(secp_utils.Sqrt(ySquaredBigInt, fieldPrimeBigInt))
value := mem.MemoryValueFromFieldElement(result)
return vm.Memory.WriteToAddress(&pYAddr, &value)
} else {
ySquaredString := ySquaredBigInt.String()
return fmt.Errorf("%s does not represent the x coordinate of a point on the curve", ySquaredString)
}
},
}
}

func createRecoverYHinter(resolver hintReferenceResolver) (hinter.Hinter, error) {
x, err := resolver.GetResOperander("x")
if err != nil {
return nil, err
}

p, err := resolver.GetResOperander("p")
if err != nil {
return nil, err
}

return newRecoverYHint(x, p), nil
}
80 changes: 80 additions & 0 deletions pkg/hintrunner/zero/zerohint_ec_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -974,6 +974,86 @@ func TestZeroHintEc(t *testing.T) {
check: varValueInScopeEquals("value", bigIntString("4", 10)),
},
},
"RecoverY": {
{
operanders: []*hintOperander{
{Name: "x", Kind: apRelative, Value: feltString("2497468900767850684421727063357792717599762502387246235265616708902555305129")},
{Name: "p.x", Kind: uninitialized},
{Name: "p.y", Kind: uninitialized},
},
makeHinter: func(ctx *hintTestContext) hinter.Hinter {
return newRecoverYHint(ctx.operanders["x"], ctx.operanders["p.x"])
},
check: allVarValueEquals(map[string]*fp.Element{
"p.x": feltString("2497468900767850684421727063357792717599762502387246235265616708902555305129"),
"p.y": feltString("205857351767627712295703269674687767888261140702556021834663354704341414042"),
}),
},
{
operanders: []*hintOperander{
{Name: "x", Kind: apRelative, Value: feltString("205857351767627712295703269674687767888261140702556021834663354704341414042")},
{Name: "p.x", Kind: uninitialized},
{Name: "p.y", Kind: uninitialized},
},
makeHinter: func(ctx *hintTestContext) hinter.Hinter {
return newRecoverYHint(ctx.operanders["x"], ctx.operanders["p.x"])
},
errCheck: errorTextContains("does not represent the x coordinate of a point on the curve"),
},
{
operanders: []*hintOperander{
{Name: "x", Kind: apRelative, Value: feltString("3004956058830981475544150447242655232275382685012344776588097793621230049020")},
{Name: "p.x", Kind: uninitialized},
{Name: "p.y", Kind: uninitialized},
},
makeHinter: func(ctx *hintTestContext) hinter.Hinter {
return newRecoverYHint(ctx.operanders["x"], ctx.operanders["p.x"])
},
check: allVarValueEquals(map[string]*fp.Element{
"p.x": feltString("3004956058830981475544150447242655232275382685012344776588097793621230049020"),
"p.y": feltString("386236054595386575795345623791920124827519018828430310912260655089307618738"),
}),
},
{
operanders: []*hintOperander{
{Name: "x", Kind: apRelative, Value: feltString("138597138396302485058562442936200017709939129389766076747102238692717075504")},
{Name: "p.x", Kind: uninitialized},
{Name: "p.y", Kind: uninitialized},
},
makeHinter: func(ctx *hintTestContext) hinter.Hinter {
return newRecoverYHint(ctx.operanders["x"], ctx.operanders["p.x"])
},
check: allVarValueEquals(map[string]*fp.Element{
"p.x": feltString("138597138396302485058562442936200017709939129389766076747102238692717075504"),
"p.y": feltString("1116947097676727397390632683964789044871379304271794004325353078455954290524"),
}),
},
{
operanders: []*hintOperander{
{Name: "x", Kind: apRelative, Value: feltString("71635783675677659163985681365816684268526846280467284682674852685628658265882465826464572245")},
{Name: "p.x", Kind: uninitialized},
{Name: "p.y", Kind: uninitialized},
},
makeHinter: func(ctx *hintTestContext) hinter.Hinter {
return newRecoverYHint(ctx.operanders["x"], ctx.operanders["p.x"])
},
check: allVarValueEquals(map[string]*fp.Element{
"p.x": feltString("71635783675677659163985681365816684268526846280467284682674852685628658265882465826464572245"),
"p.y": feltString("903372048565605391120071143811887302063650776015287438589675702929494830362"),
}),
},
{
operanders: []*hintOperander{
{Name: "x", Kind: apRelative, Value: feltString("42424242424242424242")},
{Name: "p.x", Kind: uninitialized},
{Name: "p.y", Kind: uninitialized},
},
makeHinter: func(ctx *hintTestContext) hinter.Hinter {
return newRecoverYHint(ctx.operanders["x"], ctx.operanders["p.x"])
},
errCheck: errorTextContains("does not represent the x coordinate of a point on the curve"),
},
},
},
)
}
42 changes: 14 additions & 28 deletions pkg/hintrunner/zero/zerohint_math.go
Original file line number Diff line number Diff line change
Expand Up @@ -1152,41 +1152,27 @@ func newIsQuadResidueHint(x, y hinter.ResOperander) hinter.Hinter {
xBigInt := math_utils.AsInt(x)

var value = memory.MemoryValue{}
var result *fp.Element = new(fp.Element)

if x.IsZero() || x.IsOne() {
value = memory.MemoryValueFromFieldElement(x)
const primeString = "3618502788666131213697322783095070105623107215331596699973092056135872020481"
primeBigInt, ok := new(big.Int).SetString(primeString, 10)
if !ok {
panic("failed to convert prime string to big.Int")
}

if math_utils.IsQuadResidue(x) {
result.SetBigInt(math_utils.Sqrt(&xBigInt, primeBigInt))
} else {
var result *fp.Element = new(fp.Element)

if x.Legendre() == 1 {
// result = x.Sqrt(x)

const primeString = "3618502788666131213697322783095070105623107215331596699973092056135872020481"
primeBigInt, ok := new(big.Int).SetString(primeString, 10)
if !ok {
panic("failed to convert prime string to big.Int")
}

// divide primeBigInt by 2
halfPrimeBigInt := new(big.Int).Rsh(primeBigInt, 1)

tempResult := new(big.Int).ModSqrt(&xBigInt, primeBigInt)

// ensures that tempResult is the smaller of the two possible square roots in the prime field.
if tempResult.Cmp(halfPrimeBigInt) > 0 {
tempResult.Sub(primeBigInt, tempResult)
}

result.SetBigInt(tempResult)

} else {
result = x.Sqrt(new(fp.Element).Div(x, new(fp.Element).SetUint64(3)))
y, err := math_utils.Divmod(&xBigInt, big.NewInt(3), primeBigInt)
if err != nil {
return err
}

value = memory.MemoryValueFromFieldElement(result)
result.SetBigInt(math_utils.Sqrt(&y, primeBigInt))
}

value = memory.MemoryValueFromFieldElement(result)

return vm.Memory.WriteToAddress(&yAddr, &value)
},
}
Expand Down
2 changes: 1 addition & 1 deletion pkg/vm/builtins/ecdsa.go
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,7 @@ func (e *ECDSA) CheckWrite(segment *memory.Segment, offset uint64, value *memory
pubKey := &ecdsa.PublicKey{A: key}
sig, ok := e.signatures[pubOffset]
if !ok {
return fmt.Errorf("signature is missing form ECDA builtin")
return fmt.Errorf("signature is missing from ECDSA builtin")
}

msgBytes := msgField.Bytes()
Expand Down

0 comments on commit a86cb4b

Please sign in to comment.