Skip to content

Commit

Permalink
Merge branch 'main' into export_trace_memory_execution_mode
Browse files Browse the repository at this point in the history
  • Loading branch information
TAdev0 committed Jul 4, 2024
2 parents de7ecac + 420549c commit 7afac88
Show file tree
Hide file tree
Showing 16 changed files with 1,746 additions and 131 deletions.
21 changes: 21 additions & 0 deletions integration_tests/cairo_zero_hint_tests/unsafe_keccak.small.cairo
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
%builtins output

from starkware.cairo.common.alloc import alloc
from starkware.cairo.common.keccak import unsafe_keccak

func main{output_ptr: felt*}() {
alloc_locals;

let (data: felt*) = alloc();

assert data[0] = 500;
assert data[1] = 2;
assert data[2] = 3;
assert data[3] = 6;
assert data[4] = 1;
assert data[5] = 4444;

let (low: felt, high: felt) = unsafe_keccak(data, 6);

return ();
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
%builtins output

from starkware.cairo.common.alloc import alloc
from starkware.cairo.common.keccak import unsafe_keccak_finalize, KeccakState
from starkware.cairo.common.uint256 import Uint256

func main{output_ptr: felt*}() {
alloc_locals;

let (data: felt*) = alloc();

assert data[0] = 0;
assert data[1] = 1;
assert data[2] = 2;

let keccak_state = KeccakState(start_ptr=data, end_ptr=data + 2);

let res: Uint256 = unsafe_keccak_finalize(keccak_state);

return ();
}

This file was deleted.

Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
%builtins range_check bitwise

from starkware.cairo.common.cairo_keccak.keccak import cairo_keccak, finalize_keccak
from starkware.cairo.common.uint256 import Uint256
from starkware.cairo.common.cairo_builtins import BitwiseBuiltin
from starkware.cairo.common.alloc import alloc

func main{range_check_ptr: felt, bitwise_ptr: BitwiseBuiltin*}() {
alloc_locals;

let (keccak_ptr: felt*) = alloc();
let keccak_ptr_start = keccak_ptr;

let (inputs: felt*) = alloc();

assert inputs[0] = 8031924123371070792;
assert inputs[1] = 560229490;

let n_bytes = 16;

let (res: Uint256) = cairo_keccak{keccak_ptr=keccak_ptr}(inputs=inputs, n_bytes=n_bytes);

assert res.low = 293431514620200399776069983710520819074;
assert res.high = 317109767021952548743448767588473366791;

finalize_keccak(keccak_ptr_start=keccak_ptr_start, keccak_ptr_end=keccak_ptr);

return ();
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
%builtins output range_check bitwise

from starkware.cairo.common.cairo_keccak.keccak import _keccak
from starkware.cairo.common.cairo_builtins import BitwiseBuiltin
from starkware.cairo.common.alloc import alloc
from starkware.cairo.common.serialize import serialize_word

func fill_array(array: felt*, base: felt, array_length: felt, iterator: felt) {
if (iterator == array_length) {
return ();
}

assert array[iterator] = base;

return fill_array(array, base, array_length, iterator + 1);
}

func main{output_ptr: felt*, range_check_ptr, bitwise_ptr: BitwiseBuiltin*}() {
alloc_locals;

let (output: felt*) = alloc();
let keccak_output = output;

let (inputs: felt*) = alloc();
let inputs_start = inputs;
fill_array(inputs, 9, 3, 0);

let (state: felt*) = alloc();
let state_start = state;
fill_array(state, 5, 25, 0);

let n_bytes = 24;

let (res: felt*) = _keccak{keccak_ptr=keccak_output}(
inputs=inputs_start, n_bytes=n_bytes, state=state_start
);

return ();
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
%builtins output range_check bitwise

from starkware.cairo.common.keccak_utils.keccak_utils import keccak_add_uint256
from starkware.cairo.common.uint256 import Uint256
from starkware.cairo.common.cairo_builtins import BitwiseBuiltin
from starkware.cairo.common.alloc import alloc
from starkware.cairo.common.serialize import serialize_word

func main{output_ptr: felt*, range_check_ptr, bitwise_ptr: BitwiseBuiltin*}() {
alloc_locals;

let (inputs) = alloc();
let inputs_start = inputs;

let num = Uint256(34623634663146736, 598249824422424658356);

keccak_add_uint256{inputs=inputs_start}(num=num, bigend=0);

return ();
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
%builtins range_check bitwise keccak
from starkware.cairo.common.cairo_builtins import KeccakBuiltin, BitwiseBuiltin
from starkware.cairo.common.builtin_keccak.keccak import keccak_uint256s
from starkware.cairo.common.alloc import alloc
from starkware.cairo.common.uint256 import Uint256

func main{range_check_ptr, bitwise_ptr: BitwiseBuiltin*, keccak_ptr: KeccakBuiltin*}() {
let elements: Uint256* = alloc();
assert elements[0] = Uint256(713458135386519, 18359173571);
assert elements[1] = Uint256(1536741637546373185, 84357893467438914);
assert elements[2] = Uint256(2842949328439284983294, 39248298942938492384);
assert elements[3] = Uint256(27518568234293478923754395731931, 981587843715983274);
assert elements[4] = Uint256(326848123647324823482, 93453458349589345);
let (res) = keccak_uint256s(5, elements);

return ();
}
39 changes: 37 additions & 2 deletions pkg/hintrunner/utils/math_utils.go
Original file line number Diff line number Diff line change
Expand Up @@ -129,10 +129,45 @@ func sign(n *big.Int) (int, big.Int) {

func SafeDiv(x, y *big.Int) (big.Int, error) {
if y.Cmp(big.NewInt(0)) == 0 {
return *big.NewInt(0), fmt.Errorf("Division by zero.")
return *big.NewInt(0), fmt.Errorf("division by zero")
}
if new(big.Int).Mod(x, y).Cmp(big.NewInt(0)) != 0 {
return *big.NewInt(0), fmt.Errorf("%v is not divisible by %v.", x, y)
return *big.NewInt(0), fmt.Errorf("%v is not divisible by %v", x, y)
}
return *new(big.Int).Div(x, y), nil
}

func IsQuadResidue(x *fp.Element) bool {
// Implementation adapted from sympy implementation which can be found here :
// https://github.com/sympy/sympy/blob/d91b8ad6d36a59a879cc70e5f4b379da5fdd46ce/sympy/ntheory/residue_ntheory.py#L689
// We have omitted the prime as it will be CAIRO_PRIME

return x.IsZero() || x.IsOne() || x.Legendre() == 1
}

func YSquaredFromX(x, beta, fieldPrime *big.Int) *big.Int {
// Computes y^2 using the curve equation:
// y^2 = x^3 + alpha * x + beta (mod field_prime)
// We ignore alpha as it is a constant with a value of 1

ySquaredBigInt := new(big.Int).Set(x)
ySquaredBigInt.Mul(ySquaredBigInt, x).Mod(ySquaredBigInt, fieldPrime)
ySquaredBigInt.Mul(ySquaredBigInt, x).Mod(ySquaredBigInt, fieldPrime)
ySquaredBigInt.Add(ySquaredBigInt, x).Mod(ySquaredBigInt, fieldPrime)
ySquaredBigInt.Add(ySquaredBigInt, beta).Mod(ySquaredBigInt, fieldPrime)

return ySquaredBigInt
}

func Sqrt(x, p *big.Int) *big.Int {
// Finds the minimum non-negative integer m such that (m*m) % p == x.

halfPrimeBigInt := new(big.Int).Rsh(p, 1)
m := new(big.Int).ModSqrt(x, p)

if m.Cmp(halfPrimeBigInt) > 0 {
m.Sub(p, m)
}

return m
}
25 changes: 14 additions & 11 deletions pkg/hintrunner/zero/hintcode.go
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,7 @@ const (

// is_quad_residue() hint
isQuadResidueCode string = "from starkware.crypto.signature.signature import FIELD_PRIME\nfrom starkware.python.math_utils import div_mod, is_quad_residue, sqrt\n\nx = ids.x\nif is_quad_residue(x, FIELD_PRIME):\n ids.y = sqrt(x, FIELD_PRIME)\nelse:\n ids.y = sqrt(div_mod(x, 3, FIELD_PRIME), FIELD_PRIME)"

// ------ Uint256 hints related code ------
uint256AddCode string = "sum_low = ids.a.low + ids.b.low\nids.carry_low = 1 if sum_low >= ids.SHIFT else 0\nsum_high = ids.a.high + ids.b.high + ids.carry_low\nids.carry_high = 1 if sum_high >= ids.SHIFT else 0"
split64Code string = "ids.low = ids.a & ((1<<64) - 1)\nids.high = ids.a >> 64"
Expand Down Expand Up @@ -97,6 +98,7 @@ ids.multiplicities = segments.gen_arg([len(positions_dict[k]) for k in output])`
isZeroNondetCode string = "memory[ap] = to_felt_or_relocatable(x == 0)"
isZeroPackCode string = "from starkware.cairo.common.cairo_secp.secp_utils import SECP_P, pack\n\nx = pack(ids.x, PRIME) % SECP_P"
isZeroDivModCode string = "from starkware.cairo.common.cairo_secp.secp_utils import SECP_P\nfrom starkware.python.math_utils import div_mod\n\nvalue = x_inv = div_mod(1, x, SECP_P)"
recoverYCode string = "from starkware.crypto.signature.signature import ALPHA, BETA, FIELD_PRIME\nfrom starkware.python.math_utils import recover_y\nids.p.x = ids.x\n# This raises an exception if `x` is not on the curve.\nids.p.y = recover_y(ids.x, ALPHA, BETA, FIELD_PRIME)"

// ------ Signature hints related code ------
verifyECDSASignatureCode string = "ecdsa_builtin.add_signature(ids.ecdsa_ptr.address_, (ids.signature_r, ids.signature_s))"
Expand All @@ -113,20 +115,21 @@ ids.multiplicities = segments.gen_arg([len(positions_dict[k]) for k in output])`
blake2sComputeCode string = "from starkware.cairo.common.cairo_blake2s.blake2s_utils import compute_blake2s_func\ncompute_blake2s_func(segments=segments, output_ptr=ids.output)"

// ------ Keccak hints related code ------
unsafeKeccakFinalizeCode string = "from eth_hash.auto import keccak\nkeccak_input = bytearray()\nn_elms = ids.keccak_state.end_ptr - ids.keccak_state.start_ptr\nfor word in memory.get_range(ids.keccak_state.start_ptr, n_elms):\n keccak_input += word.to_bytes(16, 'big')\nhashed = keccak(keccak_input)\nids.high = int.from_bytes(hashed[:16], 'big')\nids.low = int.from_bytes(hashed[16:32], 'big')"
unsafeKeccakCode string = "from eth_hash.auto import keccak\n\ndata, length = ids.data, ids.length\n\nif '__keccak_max_size' in globals():\n assert length <= __keccak_max_size, \\\n f'unsafe_keccak() can only be used with length<={__keccak_max_size}. ' \\\n f'Got: length={length}.'\n\nkeccak_input = bytearray()\nfor word_i, byte_i in enumerate(range(0, length, 16)):\n word = memory[data + word_i]\n n_bytes = min(16, length - byte_i)\n assert 0 <= word < 2 ** (8 * n_bytes)\n keccak_input += word.to_bytes(n_bytes, 'big')\n\nhashed = keccak(keccak_input)\nids.high = int.from_bytes(hashed[:16], 'big')\nids.low = int.from_bytes(hashed[16:32], 'big')"
cairoKeccakFinalizeCode string = `# Add dummy pairs of input and output.
_keccak_state_size_felts = int(ids.KECCAK_STATE_SIZE_FELTS)
_block_size = int(ids.BLOCK_SIZE)
assert 0 <= _keccak_state_size_felts < 100
assert 0 <= _block_size < 10
inp = [0] * _keccak_state_size_felts
padding = (inp + keccak_func(inp)) * _block_size
segments.write_arg(ids.keccak_ptr_end, padding)`
unsafeKeccakFinalizeCode string = "from eth_hash.auto import keccak\nkeccak_input = bytearray()\nn_elms = ids.keccak_state.end_ptr - ids.keccak_state.start_ptr\nfor word in memory.get_range(ids.keccak_state.start_ptr, n_elms):\n keccak_input += word.to_bytes(16, 'big')\nhashed = keccak(keccak_input)\nids.high = int.from_bytes(hashed[:16], 'big')\nids.low = int.from_bytes(hashed[16:32], 'big')"
unsafeKeccakCode string = "from eth_hash.auto import keccak\n\ndata, length = ids.data, ids.length\n\nif '__keccak_max_size' in globals():\n assert length <= __keccak_max_size, \\\n f'unsafe_keccak() can only be used with length<={__keccak_max_size}. ' \\\n f'Got: length={length}.'\n\nkeccak_input = bytearray()\nfor word_i, byte_i in enumerate(range(0, length, 16)):\n word = memory[data + word_i]\n n_bytes = min(16, length - byte_i)\n assert 0 <= word < 2 ** (8 * n_bytes)\n keccak_input += word.to_bytes(n_bytes, 'big')\n\nhashed = keccak(keccak_input)\nids.high = int.from_bytes(hashed[:16], 'big')\nids.low = int.from_bytes(hashed[16:32], 'big')"
cairoKeccakFinalizeCode string = "# Add dummy pairs of input and output.\n_keccak_state_size_felts = int(ids.KECCAK_STATE_SIZE_FELTS)\n_block_size = int(ids.BLOCK_SIZE)\nassert 0 <= _keccak_state_size_felts < 100\nassert 0 <= _block_size < 10\ninp = [0] * _keccak_state_size_felts\npadding = (inp + keccak_func(inp)) * _block_size\nsegments.write_arg(ids.keccak_ptr_end, padding)"
keccakWriteArgsCode string = "segments.write_arg(ids.inputs, [ids.low % 2 ** 64, ids.low // 2 ** 64])\nsegments.write_arg(ids.inputs + 2, [ids.high % 2 ** 64, ids.high // 2 ** 64])"
blockPermutationCode string = "from starkware.cairo.common.keccak_utils.keccak_utils import keccak_func\n_keccak_state_size_felts = int(ids.KECCAK_STATE_SIZE_FELTS)\nassert 0 <= _keccak_state_size_felts < 100\noutput_values = keccak_func(memory.get_range(\nids.keccak_ptr - _keccak_state_size_felts, _keccak_state_size_felts))\nsegments.write_arg(ids.keccak_ptr, output_values)"
blockPermutationCode string = "from starkware.cairo.common.keccak_utils.keccak_utils import keccak_func\n_keccak_state_size_felts = int(ids.KECCAK_STATE_SIZE_FELTS)\nassert 0 <= _keccak_state_size_felts < 100\n\noutput_values = keccak_func(memory.get_range(\n ids.keccak_ptr - _keccak_state_size_felts, _keccak_state_size_felts))\nsegments.write_arg(ids.keccak_ptr, output_values)"
compareBytesInWordCode string = "memory[ap] = to_felt_or_relocatable(ids.n_bytes < ids.BYTES_IN_WORD)"
compareKeccakFullRateInBytesCode string = "memory[ap] = to_felt_or_relocatable(ids.n_bytes >= ids.KECCAK_FULL_RATE_IN_BYTES)"
splitInput3Code string = "ids.high3, ids.low3 = divmod(memory[ids.inputs + 3], 256)"
splitInput6Code string = "ids.high6, ids.low6 = divmod(memory[ids.inputs + 6], 256 ** 2)"
splitInput9Code string = "ids.high9, ids.low9 = divmod(memory[ids.inputs + 9], 256 ** 3)"
splitInput12Code string = "ids.high12, ids.low12 = divmod(memory[ids.inputs + 12], 256 ** 4)"
splitInput15Code string = "ids.high15, ids.low15 = divmod(memory[ids.inputs + 15], 256 ** 5)"
splitOutputMidLowHighCode string = "tmp, ids.output1_low = divmod(ids.output1, 256 ** 7)\nids.output1_high, ids.output1_mid = divmod(tmp, 2 ** 128)"
splitOutput0Code string = "ids.output0_low = ids.output0 & ((1 << 128) - 1)\nids.output0_high = ids.output0 >> 128"
SplitNBytesCode string = "ids.n_words_to_copy, ids.n_bytes_left = divmod(ids.n_bytes, ids.BYTES_IN_WORD)"

// ------ Dictionaries hints related code ------
dictNewCode string = "if '__dict_manager' not in globals():\n from starkware.cairo.common.dict import DictManager\n __dict_manager = DictManager()\n\nmemory[ap] = __dict_manager.new_dict(segments, initial_dict)\ndel initial_dict"
Expand Down
18 changes: 18 additions & 0 deletions pkg/hintrunner/zero/zerohint.go
Original file line number Diff line number Diff line change
Expand Up @@ -149,6 +149,8 @@ func GetHintFromCode(program *zero.ZeroProgram, rawHint zero.Hint, hintPC uint64
return createIsZeroPackHinter(resolver)
case isZeroDivModCode:
return createIsZeroDivModHinter()
case recoverYCode:
return createRecoverYHinter(resolver)
// Blake hints
case blake2sAddUint256BigendCode:
return createBlake2sAddUint256Hinter(resolver, true)
Expand All @@ -173,6 +175,22 @@ func GetHintFromCode(program *zero.ZeroProgram, rawHint zero.Hint, hintPC uint64
return createBlockPermutationHinter(resolver)
case compareBytesInWordCode:
return createCompareBytesInWordNondetHinter(resolver)
case splitInput3Code:
return createSplitInput3Hinter(resolver)
case splitInput6Code:
return createSplitInput6Hinter(resolver)
case splitInput9Code:
return createSplitInput9Hinter(resolver)
case splitInput12Code:
return createSplitInput12Hinter(resolver)
case splitInput15Code:
return createSplitInput15Hinter(resolver)
case splitOutputMidLowHighCode:
return createSplitOutputMidLowHighHinter(resolver)
case splitOutput0Code:
return createSplitOutput0Hinter(resolver)
case SplitNBytesCode:
return createSplitNBytesHinter(resolver)
// Usort hints
case usortEnterScopeCode:
return createUsortEnterScopeHinter()
Expand Down
Loading

0 comments on commit 7afac88

Please sign in to comment.