From 08d12e6fd6f66e439e8d7f8147eb89c05f3d4468 Mon Sep 17 00:00:00 2001 From: Tristan <122918260+TAdev0@users.noreply.github.com> Date: Thu, 11 Jul 2024 15:46:40 +0200 Subject: [PATCH 1/2] feat: allow exporting memory file and trace file in non proof mode (#476) * allow exporting memory file and trace file in non proof mode * fix * fmt * fix2 * add comment to RelocateTrace public function * add collectMemory and collectTrace flags * rename collectMemory to BuildMemory * fmt * fmt --------- Co-authored-by: Shourya Goel --- cmd/cli/main.go | 34 ++++++++++++++++++++++++++++++++-- pkg/runners/zero/zero.go | 35 ++++++++++++++++++----------------- pkg/vm/vm.go | 25 +++++++++++-------------- 3 files changed, 61 insertions(+), 33 deletions(-) diff --git a/cmd/cli/main.go b/cmd/cli/main.go index 33ddc4fa..9b65956a 100644 --- a/cmd/cli/main.go +++ b/cmd/cli/main.go @@ -13,6 +13,8 @@ import ( func main() { var proofmode bool + var buildMemory bool + var collectTrace bool var maxsteps uint64 var entrypointOffset uint64 var traceLocation string @@ -49,12 +51,24 @@ func main() { Value: 0, Destination: &entrypointOffset, }, + &cli.BoolFlag{ + Name: "collect_trace", + Usage: "collects the trace and builds the relocated trace after execution", + Required: false, + Destination: &collectTrace, + }, &cli.StringFlag{ Name: "tracefile", Usage: "location to store the relocated trace", Required: false, Destination: &traceLocation, }, + &cli.BoolFlag{ + Name: "build_memory", + Usage: "builds the relocated memory after execution", + Required: false, + Destination: &buildMemory, + }, &cli.StringFlag{ Name: "memoryfile", Usage: "location to store the relocated memory", @@ -82,10 +96,12 @@ func main() { if err != nil { return fmt.Errorf("cannot load program: %w", err) } + cairoZeroJson, err := zero.ZeroProgramFromJSON(content) if err != nil { return fmt.Errorf("cannot load program: %w", err) } + program, err := runnerzero.LoadCairoZeroProgram(cairoZeroJson) if err != nil { return fmt.Errorf("cannot load program: %w", err) @@ -95,6 +111,7 @@ func main() { if err != nil { return fmt.Errorf("cannot create hints: %w", err) } + fmt.Println("Running....") runner, err := runnerzero.NewRunner(program, hints, proofmode, maxsteps, layoutName) if err != nil { @@ -117,18 +134,31 @@ func main() { if proofmode { runner.EndRun() + if err := runner.FinalizeSegments(); err != nil { return fmt.Errorf("cannot finalize segments: %w", err) } - trace, memory, err := runner.BuildProof() + } + + if proofmode || collectTrace { + trace, err := runner.BuildTrace() if err != nil { - return fmt.Errorf("cannot build proof: %w", err) + return fmt.Errorf("cannot build trace: %w", err) } + if traceLocation != "" { if err := os.WriteFile(traceLocation, trace, 0644); err != nil { return fmt.Errorf("cannot write relocated trace: %w", err) } } + } + + if proofmode || buildMemory { + memory, err := runner.BuildMemory() + if err != nil { + return fmt.Errorf("cannot build memory: %w", err) + } + if memoryLocation != "" { if err := os.WriteFile(memoryLocation, memory, 0644); err != nil { return fmt.Errorf("cannot write relocated memory: %w", err) diff --git a/pkg/runners/zero/zero.go b/pkg/runners/zero/zero.go index 909761db..0e505ba0 100644 --- a/pkg/runners/zero/zero.go +++ b/pkg/runners/zero/zero.go @@ -12,7 +12,6 @@ import ( "github.com/NethermindEth/cairo-vm-go/pkg/vm/builtins" mem "github.com/NethermindEth/cairo-vm-go/pkg/vm/memory" "github.com/consensys/gnark-crypto/ecc/stark-curve/fp" - f "github.com/consensys/gnark-crypto/ecc/stark-curve/fp" ) type ZeroRunner struct { @@ -21,8 +20,9 @@ type ZeroRunner struct { vm *vm.VirtualMachine hintrunner hintrunner.HintRunner // config - proofmode bool - maxsteps uint64 + proofmode bool + collectTrace bool + maxsteps uint64 // auxiliar runFinished bool layout builtins.Layout @@ -162,7 +162,7 @@ func (runner *ZeroRunner) InitializeMainEntrypoint() (mem.MemoryAddress, error) } func (runner *ZeroRunner) initializeEntrypoint( - initialPCOffset uint64, arguments []*f.Element, returnFp *mem.MemoryValue, memory *mem.Memory, + initialPCOffset uint64, arguments []*fp.Element, returnFp *mem.MemoryValue, memory *mem.Memory, ) (mem.MemoryAddress, error) { stack, err := runner.initializeBuiltins(memory) if err != nil { @@ -223,7 +223,7 @@ func (runner *ZeroRunner) initializeVm( Pc: *initialPC, Ap: offset + uint64(len(stack)), Fp: offset + uint64(len(stack)), - }, memory, vm.VirtualMachineConfig{ProofMode: runner.proofmode}) + }, memory, vm.VirtualMachineConfig{ProofMode: runner.proofmode, CollectTrace: runner.collectTrace}) return err } @@ -273,12 +273,10 @@ func (runner *ZeroRunner) RunFor(steps uint64) error { // Since this vm always finishes the run of the program at the number of steps that is a power of two in the proof mode, // there is no need to run additional steps before the loop. func (runner *ZeroRunner) EndRun() { - if runner.proofmode { - for runner.checkUsedCells() != nil { - pow2Steps := utils.NextPowerOfTwo(runner.vm.Step + 1) - if err := runner.RunFor(pow2Steps); err != nil { - panic(err) - } + for runner.checkUsedCells() != nil { + pow2Steps := utils.NextPowerOfTwo(runner.vm.Step + 1) + if err := runner.RunFor(pow2Steps); err != nil { + panic(err) } } } @@ -371,13 +369,16 @@ func (runner *ZeroRunner) FinalizeSegments() error { return nil } -func (runner *ZeroRunner) BuildProof() ([]byte, []byte, error) { - relocatedTrace, err := runner.vm.ExecutionTrace() - if err != nil { - return nil, nil, err - } +// BuildMemory relocates the memory and returns it +func (runner *ZeroRunner) BuildMemory() ([]byte, error) { + relocatedMemory := runner.vm.RelocateMemory() + return vm.EncodeMemory(relocatedMemory), nil +} - return vm.EncodeTrace(relocatedTrace), vm.EncodeMemory(runner.vm.RelocateMemory()), nil +// BuildMemory relocates the trace and returns it +func (runner *ZeroRunner) BuildTrace() ([]byte, error) { + relocatedTrace := runner.vm.RelocateTrace() + return vm.EncodeTrace(relocatedTrace), nil } func (runner *ZeroRunner) pc() mem.MemoryAddress { diff --git a/pkg/vm/vm.go b/pkg/vm/vm.go index 223220d2..1436f7a4 100644 --- a/pkg/vm/vm.go +++ b/pkg/vm/vm.go @@ -71,8 +71,11 @@ type Trace struct { // This type represents the current execution context of the vm type VirtualMachineConfig struct { - // If true, the vm outputs the trace and the relocated memory at the end of execution + // If true, the vm outputs the trace and the relocated memory at the end of execution and finalize segments + // in order for the prover to create a proof ProofMode bool + // If true, the vm collects the relocated trace at the end of execution, without finalizing segments + CollectTrace bool } type VirtualMachine struct { @@ -91,9 +94,10 @@ type VirtualMachine struct { func NewVirtualMachine( initialContext Context, memory *mem.Memory, config VirtualMachineConfig, ) (*VirtualMachine, error) { + // Initialize the trace if necesary var trace []Context - if config.ProofMode { + if config.ProofMode || config.CollectTrace { trace = make([]Context, 0) } @@ -136,7 +140,7 @@ func (vm *VirtualMachine) RunStep(hintRunner HintRunner) error { } // store the trace before state change - if vm.config.ProofMode { + if vm.config.ProofMode || vm.config.CollectTrace { vm.Trace = append(vm.Trace, vm.Context) } @@ -216,15 +220,6 @@ func (vm *VirtualMachine) RunInstruction(instruction *a.Instruction) error { return nil } -// It returns the current trace entry, the public memory, and the occurrence of an error -func (vm *VirtualMachine) ExecutionTrace() ([]Trace, error) { - if !vm.config.ProofMode { - return nil, fmt.Errorf("proof mode is off") - } - - return vm.relocateTrace(), nil -} - func (vm *VirtualMachine) getDstAddr(instruction *a.Instruction) (mem.MemoryAddress, error) { var dstRegister uint64 if instruction.DstRegister == a.Ap { @@ -537,8 +532,10 @@ func (vm *VirtualMachine) updateFp(instruction *a.Instruction, dstAddr *mem.Memo } } -func (vm *VirtualMachine) relocateTrace() []Trace { - // one is added, because prover expect that the first element to be on +// It returns the trace after relocation, i.e, relocates pc, ap and fp for each step +// to be their real address value +func (vm *VirtualMachine) RelocateTrace() []Trace { + // one is added, because prover expect that the first element to be // indexed on 1 instead of 0 relocatedTrace := make([]Trace, len(vm.Trace)) totalBytecode := vm.Memory.Segments[ProgramSegment].Len() + 1 From 050d0e02ee170d91b36a43b5da168133211c3ff4 Mon Sep 17 00:00:00 2001 From: Harikrishnan Shaji Date: Thu, 11 Jul 2024 22:27:26 +0530 Subject: [PATCH 2/2] Implement RandomEcPoint hint (#513) * Implement RandomEcPoint * Add unit test * Update test assert * Add comment * Some test changes * Update pkg/hintrunner/zero/zerohint_ec.go Co-authored-by: Tristan <122918260+TAdev0@users.noreply.github.com> --------- Co-authored-by: Tristan <122918260+TAdev0@users.noreply.github.com> --- integration_tests/BenchMarks.txt | 96 +++++++------- integration_tests/cairozero_test.go | 3 - pkg/hintrunner/utils/math_utils.go | 16 ++- pkg/hintrunner/zero/hintcode.go | 1 + pkg/hintrunner/zero/zerohint.go | 2 + pkg/hintrunner/zero/zerohint_ec.go | 165 +++++++++++++++++++++--- pkg/hintrunner/zero/zerohint_ec_test.go | 46 +++++++ pkg/hintrunner/zero/zerohint_math.go | 11 +- pkg/vm/memory/memory_value.go | 18 +++ 9 files changed, 283 insertions(+), 75 deletions(-) diff --git a/integration_tests/BenchMarks.txt b/integration_tests/BenchMarks.txt index ef4db204..87da4f03 100644 --- a/integration_tests/BenchMarks.txt +++ b/integration_tests/BenchMarks.txt @@ -1,97 +1,99 @@ =========================================================================================================================== | File | PythonVM (ms) | GoVM (ms) | =========================================================================================================================== -| cmp.small.cairo | 902 | 106 | +| is_quad_residue.small.cairo | 815 | 122 | --------------------------------------------------------------------------------------------------------------------------- -| dict_squash.small.cairo | 783 | 107 | +| memset.cairo | 723 | 107 | --------------------------------------------------------------------------------------------------------------------------- -| ec.small.cairo | 4479 | 158 | +| pow.small.cairo | 810 | 109 | --------------------------------------------------------------------------------------------------------------------------- -| import_secp256R1P.small.cairo | 610 | 103 | +| search_sorted_lower.small.cairo | 849 | 107 | --------------------------------------------------------------------------------------------------------------------------- -| is_positive.small.cairo | 733 | 114 | +| bitwise_builtin_test.starknet_with_keccak.cairo| 1312 | 108 | --------------------------------------------------------------------------------------------------------------------------- -| signed_div_rem.small.cairo | 701 | 105 | +| dict.cairo | 840 | 106 | --------------------------------------------------------------------------------------------------------------------------- -| uint256_add.small.cairo | 720 | 105 | +| hintrefs.cairo | 1029 | 109 | --------------------------------------------------------------------------------------------------------------------------- -| blake.starknet_with_keccak.cairo | 46788 | 482 | +| uint256_sqrt.small.cairo | 929 | 109 | --------------------------------------------------------------------------------------------------------------------------- -| pedersen_test.small.cairo | 625 | 103 | +| assert_250_bits.small.cairo | 843 | 109 | --------------------------------------------------------------------------------------------------------------------------- -| pow.small.cairo | 712 | 105 | +| assert_not_equal.cairo | 923 | 182 | --------------------------------------------------------------------------------------------------------------------------- -| verify_ecdsa_signature.small.cairo | 754 | 103 | +| div_mod_n.small.cairo | 889 | 113 | --------------------------------------------------------------------------------------------------------------------------- -| memset.cairo | 682 | 104 | +| ec.small.cairo | 4754 | 158 | --------------------------------------------------------------------------------------------------------------------------- -| split64.small.cairo | 812 | 105 | +| import_secp256R1P.small.cairo | 705 | 103 | --------------------------------------------------------------------------------------------------------------------------- -| uint256_sqrt.small.cairo | 923 | 108 | +| usort.small.cairo | 916 | 106 | --------------------------------------------------------------------------------------------------------------------------- -| assert_250_bits.small.cairo | 969 | 103 | +| verify_ecdsa_signature.small.cairo | 747 | 104 | --------------------------------------------------------------------------------------------------------------------------- -| is_quad_residue.small.cairo | 793 | 123 | +| cmp.small.cairo | 913 | 110 | --------------------------------------------------------------------------------------------------------------------------- -| reduce_v1.small.cairo | 824 | 106 | +| signed_div_rem.small.cairo | 1000 | 115 | --------------------------------------------------------------------------------------------------------------------------- -| set_add.small.cairo | 662 | 104 | +| poseidon_test.starknet_with_keccak.cairo| 1318 | 108 | --------------------------------------------------------------------------------------------------------------------------- -| split_felt.small.cairo | 812 | 106 | +| assert_not_zero.cairo | 1052 | 104 | --------------------------------------------------------------------------------------------------------------------------- -| uint256_mul_div_mod.small.cairo | 916 | 108 | +| blake.starknet_with_keccak.cairo | 47471 | 512 | --------------------------------------------------------------------------------------------------------------------------- -| poseidon_test.starknet_with_keccak.cairo| 1315 | 107 | +| set_add.small.cairo | 707 | 107 | --------------------------------------------------------------------------------------------------------------------------- -| div_mod_n.small.cairo | 942 | 111 | +| split64.small.cairo | 804 | 106 | --------------------------------------------------------------------------------------------------------------------------- -| assert_not_zero.cairo | 666 | 105 | +| sqrt.small.cairo | 853 | 106 | --------------------------------------------------------------------------------------------------------------------------- -| memcpy.cairo | 631 | 103 | +| verify_zero.small.cairo | 738 | 105 | --------------------------------------------------------------------------------------------------------------------------- -| search_sorted_lower.small.cairo | 759 | 107 | +| simple.cairo | 600 | 103 | --------------------------------------------------------------------------------------------------------------------------- -| uint256_unsigned_div_rem.small.cairo | 873 | 108 | +| dict_squash.small.cairo | 1005 | 115 | --------------------------------------------------------------------------------------------------------------------------- -| unsigned_div_rem.small.cairo | 798 | 106 | +| memcpy.cairo | 627 | 103 | --------------------------------------------------------------------------------------------------------------------------- -| verify_zero.small.cairo | 716 | 105 | +| split_felt.small.cairo | 1041 | 106 | --------------------------------------------------------------------------------------------------------------------------- -| fib.cairo | 651 | 103 | +| uint256_add.small.cairo | 722 | 105 | --------------------------------------------------------------------------------------------------------------------------- -| assert_not_equal.cairo | 638 | 102 | +| factorial.cairo | 1047 | 106 | --------------------------------------------------------------------------------------------------------------------------- -| keccak_test.starknet_with_keccak.cairo| 1293 | 107 | +| pedersen_test.small.cairo | 639 | 103 | --------------------------------------------------------------------------------------------------------------------------- -| is_zero.small.cairo | 830 | 104 | +| is_zero.small.cairo | 872 | 107 | --------------------------------------------------------------------------------------------------------------------------- -| hintrefs.cairo | 714 | 104 | +| random_ec.cairo | 763 | 108 | --------------------------------------------------------------------------------------------------------------------------- -| unsafe_keccak.small.cairo | 707 | 105 | +| split_int.small.cairo | 793 | 106 | --------------------------------------------------------------------------------------------------------------------------- -| factorial.cairo | 1013 | 109 | +| uint256_signedNN.small.cairo | 767 | 105 | --------------------------------------------------------------------------------------------------------------------------- -| simple.cairo | 637 | 103 | +| uint256_unsigned_div_rem.small.cairo | 883 | 108 | --------------------------------------------------------------------------------------------------------------------------- -| ecdsa_test.starknet_with_keccak.cairo | 1385 | 108 | +| unsafe_keccak.small.cairo | 740 | 104 | --------------------------------------------------------------------------------------------------------------------------- -| get_point_from_x.small.cairo | 870 | 106 | +| unsigned_div_rem.small.cairo | 826 | 106 | --------------------------------------------------------------------------------------------------------------------------- -| dict.cairo | 653 | 105 | +| alloc.cairo | 737 | 620 | --------------------------------------------------------------------------------------------------------------------------- -| find_element.small.cairo | 774 | 105 | +| find_element.small.cairo | 813 | 105 | --------------------------------------------------------------------------------------------------------------------------- -| split_int.small.cairo | 713 | 105 | +| get_point_from_x.small.cairo | 1039 | 135 | --------------------------------------------------------------------------------------------------------------------------- -| sqrt.small.cairo | 823 | 105 | +| is_positive.small.cairo | 784 | 105 | --------------------------------------------------------------------------------------------------------------------------- -| uint256_signedNN.small.cairo | 683 | 105 | +| reduce_v1.small.cairo | 860 | 108 | --------------------------------------------------------------------------------------------------------------------------- -| unsafe_keccak_finalize.small.cairo | 675 | 103 | +| uint256_mul_div_mod.small.cairo | 977 | 110 | --------------------------------------------------------------------------------------------------------------------------- -| usort.small.cairo | 740 | 106 | +| unsafe_keccak_finalize.small.cairo | 662 | 104 | --------------------------------------------------------------------------------------------------------------------------- -| alloc.cairo | 715 | 596 | +| fib.cairo | 641 | 103 | --------------------------------------------------------------------------------------------------------------------------- -| bitwise_builtin_test.starknet_with_keccak.cairo| 1304 | 107 | +| ecdsa_test.starknet_with_keccak.cairo | 1506 | 108 | +--------------------------------------------------------------------------------------------------------------------------- +| keccak_test.starknet_with_keccak.cairo| 1369 | 110 | =========================================================================================================================== diff --git a/integration_tests/cairozero_test.go b/integration_tests/cairozero_test.go index bcfdc468..c2cea37d 100644 --- a/integration_tests/cairozero_test.go +++ b/integration_tests/cairozero_test.go @@ -78,9 +78,6 @@ func TestCairoFiles(t *testing.T) { errorExpected := false if name == "range_check.small.cairo" { errorExpected = true - } else if name == "ecop.starknet_with_keccak.cairo" { - // temporary, being fixed in another PR soon - continue } path := filepath.Join(root, name) diff --git a/pkg/hintrunner/utils/math_utils.go b/pkg/hintrunner/utils/math_utils.go index 9b789b6f..2ff985d7 100644 --- a/pkg/hintrunner/utils/math_utils.go +++ b/pkg/hintrunner/utils/math_utils.go @@ -145,7 +145,7 @@ func IsQuadResidue(x *fp.Element) bool { return x.IsZero() || x.IsOne() || x.Legendre() == 1 } -func YSquaredFromX(x, beta, fieldPrime *big.Int) *big.Int { +func ySquaredFromX(x, beta, fieldPrime *big.Int) *big.Int { // Computes y^2 using the curve equation: // y^2 = x^3 + alpha * x + beta (mod field_prime) // We ignore alpha as it is a constant with a value of 1 @@ -171,3 +171,17 @@ func Sqrt(x, p *big.Int) *big.Int { return m } + +func RecoverY(x, beta, fieldPrime *big.Int) (*big.Int, error) { + ySquared := ySquaredFromX(x, beta, fieldPrime) + if IsQuadResidue(new(fp.Element).SetBigInt(ySquared)) { + return Sqrt(ySquared, fieldPrime), nil + } + return nil, fmt.Errorf("%s does not represent the x coordinate of a point on the curve", ySquared.String()) +} + +func GetCairoPrime() (big.Int, bool) { + // 2**251 + 17 * 2**192 + 1 + cairoPrime, ok := new(big.Int).SetString("3618502788666131213697322783095070105623107215331596699973092056135872020481", 10) + return *cairoPrime, ok +} diff --git a/pkg/hintrunner/zero/hintcode.go b/pkg/hintrunner/zero/hintcode.go index 38056fb0..d0903759 100644 --- a/pkg/hintrunner/zero/hintcode.go +++ b/pkg/hintrunner/zero/hintcode.go @@ -99,6 +99,7 @@ ids.multiplicities = segments.gen_arg([len(positions_dict[k]) for k in output])` isZeroPackCode string = "from starkware.cairo.common.cairo_secp.secp_utils import SECP_P, pack\n\nx = pack(ids.x, PRIME) % SECP_P" isZeroDivModCode string = "from starkware.cairo.common.cairo_secp.secp_utils import SECP_P\nfrom starkware.python.math_utils import div_mod\n\nvalue = x_inv = div_mod(1, x, SECP_P)" recoverYCode string = "from starkware.crypto.signature.signature import ALPHA, BETA, FIELD_PRIME\nfrom starkware.python.math_utils import recover_y\nids.p.x = ids.x\n# This raises an exception if `x` is not on the curve.\nids.p.y = recover_y(ids.x, ALPHA, BETA, FIELD_PRIME)" + randomEcPointCode string = "from starkware.crypto.signature.signature import ALPHA, BETA, FIELD_PRIME\nfrom starkware.python.math_utils import random_ec_point\nfrom starkware.python.utils import to_bytes\n\n# Define a seed for random_ec_point that's dependent on all the input, so that:\n# (1) The added point s is deterministic.\n# (2) It's hard to choose inputs for which the builtin will fail.\nseed = b\"\".join(map(to_bytes, [ids.p.x, ids.p.y, ids.m, ids.q.x, ids.q.y]))\nids.s.x, ids.s.y = random_ec_point(FIELD_PRIME, ALPHA, BETA, seed)" // ------ Signature hints related code ------ verifyECDSASignatureCode string = "ecdsa_builtin.add_signature(ids.ecdsa_ptr.address_, (ids.signature_r, ids.signature_s))" diff --git a/pkg/hintrunner/zero/zerohint.go b/pkg/hintrunner/zero/zerohint.go index 1e0c1a1f..d1f18d5e 100644 --- a/pkg/hintrunner/zero/zerohint.go +++ b/pkg/hintrunner/zero/zerohint.go @@ -151,6 +151,8 @@ func GetHintFromCode(program *zero.ZeroProgram, rawHint zero.Hint, hintPC uint64 return createIsZeroDivModHinter() case recoverYCode: return createRecoverYHinter(resolver) + case randomEcPointCode: + return createRandomEcPointHinter(resolver) // Blake hints case blake2sAddUint256BigendCode: return createBlake2sAddUint256Hinter(resolver, true) diff --git a/pkg/hintrunner/zero/zerohint_ec.go b/pkg/hintrunner/zero/zerohint_ec.go index 02838a13..ab1f7b92 100644 --- a/pkg/hintrunner/zero/zerohint_ec.go +++ b/pkg/hintrunner/zero/zerohint_ec.go @@ -1,6 +1,9 @@ package zero import ( + "crypto/sha256" + "encoding/binary" + "encoding/hex" "fmt" "math/big" @@ -901,33 +904,25 @@ func newRecoverYHint(x, p hinter.ResOperander) hinter.Hinter { return err } - const betaString = "3141592653589793238462643383279502884197169399375105820974944592307816406665" - betaBigInt, ok := new(big.Int).SetString(betaString, 10) - if !ok { - panic("failed to convert BETA string to big.Int") - } + betaBigInt := new(big.Int) + utils.Beta.BigInt(betaBigInt) - const fieldPrimeString = "3618502788666131213697322783095070105623107215331596699973092056135872020481" - fieldPrimeBigInt, ok := new(big.Int).SetString(fieldPrimeString, 10) + fieldPrimeBigInt, ok := secp_utils.GetCairoPrime() if !ok { - panic("failed to convert FIELD_PRIME string to big.Int") + return fmt.Errorf("GetCairoPrime failed") } xBigInt := new(big.Int) xFelt.BigInt(xBigInt) // y^2 = x^3 + alpha * x + beta (mod field_prime) - ySquaredBigInt := secp_utils.YSquaredFromX(xBigInt, betaBigInt, fieldPrimeBigInt) - ySquaredFelt := new(fp.Element).SetBigInt(ySquaredBigInt) - - if secp_utils.IsQuadResidue(ySquaredFelt) { - result := new(fp.Element).SetBigInt(secp_utils.Sqrt(ySquaredBigInt, fieldPrimeBigInt)) - value := mem.MemoryValueFromFieldElement(result) - return vm.Memory.WriteToAddress(&pYAddr, &value) - } else { - ySquaredString := ySquaredBigInt.String() - return fmt.Errorf("%s does not represent the x coordinate of a point on the curve", ySquaredString) + resultBigInt, err := secp_utils.RecoverY(xBigInt, betaBigInt, &fieldPrimeBigInt) + if err != nil { + return err } + resultFelt := new(fp.Element).SetBigInt(resultBigInt) + resultMv := mem.MemoryValueFromFieldElement(resultFelt) + return vm.Memory.WriteToAddress(&pYAddr, &resultMv) }, } } @@ -945,3 +940,137 @@ func createRecoverYHinter(resolver hintReferenceResolver) (hinter.Hinter, error) return newRecoverYHint(x, p), nil } + +// RandomEcPoint hint returns a random non-zero point on the elliptic curve +// y^2 = x^3 + alpha * x + beta (mod field_prime). +// The point is created deterministically from the seed. +// +// `newRandomEcPointHint` takes 4 operanders as arguments +// - `p` is an EC point used for seed generation +// - `m` the multiplication coefficient of Q used for seed generation +// - `q` an EC point used for seed generation +// - `s` is where the generated random EC point is written to +func newRandomEcPointHint(p, m, q, s hinter.ResOperander) hinter.Hinter { + return &GenericZeroHinter{ + Name: "RandomEcPoint", + Op: func(vm *VM.VirtualMachine, ctx *hinter.HintRunnerContext) error { + //> from starkware.crypto.signature.signature import ALPHA, BETA, FIELD_PRIME + //> from starkware.python.math_utils import random_ec_point + //> from starkware.python.utils import to_bytes + //> + //> # Define a seed for random_ec_point that's dependent on all the input, so that: + //> # (1) The added point s is deterministic. + //> # (2) It's hard to choose inputs for which the builtin will fail. + //> seed = b"".join(map(to_bytes, [ids.p.x, ids.p.y, ids.m, ids.q.x, ids.q.y])) + //> ids.s.x, ids.s.y = random_ec_point(FIELD_PRIME, ALPHA, BETA, seed) + + pAddr, err := p.GetAddress(vm) + if err != nil { + return err + } + pValues, err := vm.Memory.ResolveAsEcPoint(pAddr) + if err != nil { + return err + } + mFelt, err := hinter.ResolveAsFelt(vm, m) + if err != nil { + return err + } + qAddr, err := q.GetAddress(vm) + if err != nil { + return err + } + qValues, err := vm.Memory.ResolveAsEcPoint(qAddr) + if err != nil { + return err + } + + var bytesArray []byte + writeFeltToBytesArray := func(n *fp.Element) { + for _, byteValue := range n.Bytes() { + bytesArray = append(bytesArray, byteValue) + } + } + for _, felt := range pValues { + writeFeltToBytesArray(felt) + } + writeFeltToBytesArray(mFelt) + for _, felt := range qValues { + writeFeltToBytesArray(felt) + } + seed := sha256.Sum256(bytesArray) + + alphaBig := new(big.Int) + utils.Alpha.BigInt(alphaBig) + betaBig := new(big.Int) + utils.Beta.BigInt(betaBig) + fieldPrime, ok := secp_utils.GetCairoPrime() + if !ok { + return fmt.Errorf("GetCairoPrime failed") + } + + for i := uint64(0); i < 100; i++ { + iBytes := make([]byte, 10) + binary.LittleEndian.PutUint64(iBytes, i) + concatenated := append(seed[1:], iBytes...) + hash := sha256.Sum256(concatenated) + hashHex := hex.EncodeToString(hash[:]) + x := new(big.Int) + x.SetString(hashHex, 16) + + yCoef := big.NewInt(1) + if seed[0]&1 == 1 { + yCoef.Neg(yCoef) + } + + // Try to recover y + if !ok { + return fmt.Errorf("failed to get field prime value") + } + if y, err := secp_utils.RecoverY(x, betaBig, &fieldPrime); err == nil { + y.Mul(yCoef, y) + y.Mod(y, &fieldPrime) + + sAddr, err := s.GetAddress(vm) + if err != nil { + return err + } + + sXFelt := new(fp.Element).SetBigInt(x) + sYFelt := new(fp.Element).SetBigInt(y) + sXMv := mem.MemoryValueFromFieldElement(sXFelt) + sYMv := mem.MemoryValueFromFieldElement(sYFelt) + + err = vm.Memory.WriteToNthStructField(sAddr, sXMv, 0) + if err != nil { + return err + } + return vm.Memory.WriteToNthStructField(sAddr, sYMv, 1) + } + } + + return fmt.Errorf("could not find a point on the curve") + }, + } +} + +func createRandomEcPointHinter(resolver hintReferenceResolver) (hinter.Hinter, error) { + p, err := resolver.GetResOperander("p") + if err != nil { + return nil, err + } + m, err := resolver.GetResOperander("m") + if err != nil { + return nil, err + } + q, err := resolver.GetResOperander("q") + if err != nil { + return nil, err + } + s, err := resolver.GetResOperander("s") + if err != nil { + return nil, err + } + + return newRandomEcPointHint(p, m, q, s), nil +} diff --git a/pkg/hintrunner/zero/zerohint_ec_test.go b/pkg/hintrunner/zero/zerohint_ec_test.go index 07f366c8..3892e2c2 100644 --- a/pkg/hintrunner/zero/zerohint_ec_test.go +++ b/pkg/hintrunner/zero/zerohint_ec_test.go @@ -1054,6 +1054,52 @@ func TestZeroHintEc(t *testing.T) { errCheck: errorTextContains("does not represent the x coordinate of a point on the curve"), }, }, + "RandomEcPoint": { + { + operanders: []*hintOperander{ + {Name: "p.x", Kind: apRelative, Value: feltString("3004956058830981475544150447242655232275382685012344776588097793621230049020")}, + {Name: "p.y", Kind: apRelative, Value: feltString("3232266734070744637901977159303149980795588196503166389060831401046564401743")}, + {Name: "m", Kind: apRelative, Value: feltUint64(34)}, + {Name: "q.x", Kind: apRelative, Value: feltString("2864041794633455918387139831609347757720597354645583729611044800117714995244")}, + {Name: "q.y", Kind: apRelative, Value: feltString("2252415379535459416893084165764951913426528160630388985542241241048300343256")}, + {Name: "s.x", Kind: uninitialized}, + }, + makeHinter: func(ctx *hintTestContext) hinter.Hinter { + return newRandomEcPointHint( + ctx.operanders["p.x"], + ctx.operanders["m"], + ctx.operanders["q.x"], + ctx.operanders["s.x"], + ) + }, + check: consecutiveVarValueEquals("s.x", []*fp.Element{ + feltString("96578541406087262240552119423829615463800550101008760434566010168435227837635"), + feltString("3412645436898503501401619513420382337734846074629040678138428701431530606439"), + }), + }, + { + operanders: []*hintOperander{ + {Name: "p.x", Kind: apRelative, Value: feltUint64(12345)}, + {Name: "p.y", Kind: apRelative, Value: feltUint64(6789)}, + {Name: "m", Kind: apRelative, Value: feltUint64(101)}, + {Name: "q.x", Kind: apRelative, Value: feltUint64(98765)}, + {Name: "q.y", Kind: apRelative, Value: feltUint64(4321)}, + {Name: "s.x", Kind: uninitialized}, + }, + makeHinter: func(ctx *hintTestContext) hinter.Hinter { + return newRandomEcPointHint( + ctx.operanders["p.x"], + ctx.operanders["m"], + ctx.operanders["q.x"], + ctx.operanders["s.x"], + ) + }, + check: consecutiveVarValueEquals("s.x", []*fp.Element{ + feltString("39190969885360777615413526676655883809466222002423777590585892821354159079496"), + feltString("533983185449702770508526175744869430974740140562200547506631069957329272485"), + }), + }, + }, }, ) } diff --git a/pkg/hintrunner/zero/zerohint_math.go b/pkg/hintrunner/zero/zerohint_math.go index 3817411b..1092f4da 100644 --- a/pkg/hintrunner/zero/zerohint_math.go +++ b/pkg/hintrunner/zero/zerohint_math.go @@ -1154,21 +1154,20 @@ func newIsQuadResidueHint(x, y hinter.ResOperander) hinter.Hinter { var value = memory.MemoryValue{} var result *fp.Element = new(fp.Element) - const primeString = "3618502788666131213697322783095070105623107215331596699973092056135872020481" - primeBigInt, ok := new(big.Int).SetString(primeString, 10) + primeBigInt, ok := math_utils.GetCairoPrime() if !ok { - panic("failed to convert prime string to big.Int") + return fmt.Errorf("GetCairoPrime failed") } if math_utils.IsQuadResidue(x) { - result.SetBigInt(math_utils.Sqrt(&xBigInt, primeBigInt)) + result.SetBigInt(math_utils.Sqrt(&xBigInt, &primeBigInt)) } else { - y, err := math_utils.Divmod(&xBigInt, big.NewInt(3), primeBigInt) + y, err := math_utils.Divmod(&xBigInt, big.NewInt(3), &primeBigInt) if err != nil { return err } - result.SetBigInt(math_utils.Sqrt(&y, primeBigInt)) + result.SetBigInt(math_utils.Sqrt(&y, &primeBigInt)) } value = memory.MemoryValueFromFieldElement(result) diff --git a/pkg/vm/memory/memory_value.go b/pkg/vm/memory/memory_value.go index 2dbd4a6c..1ab3f66a 100644 --- a/pkg/vm/memory/memory_value.go +++ b/pkg/vm/memory/memory_value.go @@ -375,3 +375,21 @@ func (memory *Memory) ResolveAsBigInt3(valAddr MemoryAddress) ([3]*f.Element, er return valValues, nil } + +func (memory *Memory) ResolveAsEcPoint(valAddr MemoryAddress) ([2]*f.Element, error) { + valMemoryValues, err := memory.GetConsecutiveMemoryValues(valAddr, int16(2)) + if err != nil { + return [2]*f.Element{}, err + } + + var valValues [2]*f.Element + for i := 0; i < 2; i++ { + valValue, err := valMemoryValues[i].FieldElement() + if err != nil { + return [2]*f.Element{}, err + } + valValues[i] = valValue + } + + return valValues, nil +}