Skip to content

Commit

Permalink
Merge branch 'main' into _copy_inputs_hint
Browse files Browse the repository at this point in the history
  • Loading branch information
TAdev0 committed Jun 28, 2024
2 parents b021861 + d56008d commit 536dc50
Show file tree
Hide file tree
Showing 18 changed files with 559 additions and 290 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/integration-test.yml
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ jobs:
python-version: '3.9'

- name: Install cairo-lang
run: pip install cairo-lang==0.11
run: pip install cairo-lang==0.13.1

- name: Build
run: make build
Expand Down
31 changes: 31 additions & 0 deletions integration_tests/cairo_zero_hint_tests/is_zero.small.cairo
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
// Returns 1 if x == 0 (mod secp256k1_prime), and 0 otherwise.
// Serves as integration test for the following hints :
// isZeroNondetCode
// isZeroPackCode
// isZeroDivModCode

%builtins range_check

from starkware.cairo.common.cairo_secp.field import is_zero, SumBigInt3

func main{range_check_ptr}() -> () {

// Test One
let a = SumBigInt3(0, 0, 0);
let (res: felt) = is_zero(a);
assert res = 1;

// Test Two
let b = SumBigInt3(42, 0, 0);
let (res: felt) = is_zero(b);
assert res = 0;

// Test Three
let c = SumBigInt3(
77371252455336262886226991, 77371252455336267181195263, 19342813113834066795298815
);
let (res: felt) = is_zero(c);
assert res = 1;

return ();
}
59 changes: 59 additions & 0 deletions integration_tests/cairo_zero_hint_tests/usort.small.cairo
Original file line number Diff line number Diff line change
@@ -0,0 +1,59 @@
%builtins range_check
from starkware.cairo.common.usort import usort
from starkware.cairo.common.alloc import alloc

func main{range_check_ptr}() -> () {
alloc_locals;
let (input_array: felt*) = alloc();
assert input_array[0] = 8;
assert input_array[1] = 9;
assert input_array[2] = 7;

let (output_len, output, multiplicities) = usort(input_len=3, input=input_array);

assert output_len = 3;
assert output[0] = 7;
assert output[1] = 8;
assert output[2] = 9;

assert multiplicities[0] = 1;
assert multiplicities[1] = 1;
assert multiplicities[2] = 1;

let (input_array: felt*) = alloc();
assert input_array[0] = 11;
assert input_array[1] = 24;
assert input_array[2] = 99;
assert input_array[3] = 2;
assert input_array[4] = 66;
assert input_array[5] = 49;
assert input_array[6] = 11;
assert input_array[7] = 23;
assert input_array[8] = 88;
assert input_array[9] = 7;

let (output_len, output, multiplicities) = usort(input_len=10, input=input_array);

assert output_len = 9;
assert output[0] = 2;
assert output[1] = 7;
assert output[2] = 11;
assert output[3] = 23;
assert output[4] = 24;
assert output[5] = 49;
assert output[6] = 66;
assert output[7] = 88;
assert output[8] = 99;

assert multiplicities[0] = 1;
assert multiplicities[1] = 1;
assert multiplicities[2] = 2;
assert multiplicities[3] = 1;
assert multiplicities[4] = 1;
assert multiplicities[5] = 1;
assert multiplicities[6] = 1;
assert multiplicities[7] = 1;
assert multiplicities[8] = 1;

return ();
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
%builtins poseidon
from starkware.cairo.common.cairo_builtins import PoseidonBuiltin
from starkware.cairo.common.poseidon_state import PoseidonBuiltinState
from starkware.cairo.common.builtin_poseidon.poseidon import (
poseidon_hash,
poseidon_hash_single,
poseidon_hash_many,
)
from starkware.cairo.common.alloc import alloc

func main{poseidon_ptr: PoseidonBuiltin*}() {
// Hash one
let (x) = poseidon_hash_single(
218676008889449692916464780911713710628115973574242889792891157041292792362
);
assert x = 2835120893146788752888137145656423078969524407843035783270702964188823073934;
// Hash two
let (y) = poseidon_hash(1253795, 18540013156130945068);
assert y = 37282360750367388068593128053386029947772104009544220786084510532118246655;
// Hash five
let felts: felt* = alloc();
assert felts[0] = 84175983715088675913672849362079546;
assert felts[1] = 9384720329467203286234076408512594689579283578028960384690;
assert felts[2] = 291883989128409324823849293040390493094093;
assert felts[3] = 5849589438543859348593485948598349584395839402940940290490324;
assert felts[4] = 1836254780028456372728992049476335424263474849;
let (z) = poseidon_hash_many(5, felts);
assert z = 47102513329160951064697157194713013753695317629154835326726810042406974264;
return ();
}
21 changes: 11 additions & 10 deletions pkg/hintrunner/hinter/zero_dict.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,27 +11,27 @@ import (
// Used to keep track of all Dictionaries data
type ZeroDictionary struct {
// The Data contained in a dictionary
Data map[fp.Element]mem.MemoryValue
Data *map[fp.Element]mem.MemoryValue
// Default value for key not present in the dictionary
DefaultValue mem.MemoryValue
DefaultValue *mem.MemoryValue
// first free offset in memory segment of dictionary
FreeOffset *uint64
}

// Gets the memory value at certain key
func (d *ZeroDictionary) at(key fp.Element) (mem.MemoryValue, error) {
if value, ok := d.Data[key]; ok {
if value, ok := (*d.Data)[key]; ok {
return value, nil
}
if d.DefaultValue != mem.UnknownValue {
return d.DefaultValue, nil
if *d.DefaultValue != mem.UnknownValue {
return *d.DefaultValue, nil
}
return mem.UnknownValue, fmt.Errorf("no value for key: %v", key)
}

// Given a key and a value, it sets the value at the given key
func (d *ZeroDictionary) set(key fp.Element, value mem.MemoryValue) {
d.Data[key] = value
(*d.Data)[key] = value
}

// Given a incrementBy value, it increments the freeOffset field of dictionary by it
Expand Down Expand Up @@ -63,8 +63,8 @@ func (dm *ZeroDictionaryManager) NewDictionary(vm *VM.VirtualMachine, data map[f
newDictAddr := vm.Memory.AllocateEmptySegment()
freeOffset := uint64(0)
dm.Dictionaries[newDictAddr.SegmentIndex] = ZeroDictionary{
Data: data,
DefaultValue: mem.UnknownValue,
Data: &data,
DefaultValue: &mem.UnknownValue,
FreeOffset: &freeOffset,
}
return newDictAddr
Expand All @@ -76,10 +76,11 @@ func (dm *ZeroDictionaryManager) NewDictionary(vm *VM.VirtualMachine, data map[f
// querying the defaultValue will be returned instead.
func (dm *ZeroDictionaryManager) NewDefaultDictionary(vm *VM.VirtualMachine, defaultValue mem.MemoryValue) mem.MemoryAddress {
newDefaultDictAddr := vm.Memory.AllocateEmptySegment()
newData := make(map[fp.Element]mem.MemoryValue)
freeOffset := uint64(0)
dm.Dictionaries[newDefaultDictAddr.SegmentIndex] = ZeroDictionary{
Data: make(map[fp.Element]mem.MemoryValue),
DefaultValue: defaultValue,
Data: &newData,
DefaultValue: &defaultValue,
FreeOffset: &freeOffset,
}
return newDefaultDictAddr
Expand Down
43 changes: 22 additions & 21 deletions pkg/hintrunner/zero/hintcode.go
Original file line number Diff line number Diff line change
Expand Up @@ -59,26 +59,25 @@ const (
uint256MulDivModCode string = "a = (ids.a.high << 128) + ids.a.low\nb = (ids.b.high << 128) + ids.b.low\ndiv = (ids.div.high << 128) + ids.div.low\nquotient, remainder = divmod(a * b, div)\n\nids.quotient_low.low = quotient & ((1 << 128) - 1)\nids.quotient_low.high = (quotient >> 128) & ((1 << 128) - 1)\nids.quotient_high.low = (quotient >> 256) & ((1 << 128) - 1)\nids.quotient_high.high = quotient >> 384\nids.remainder.low = remainder & ((1 << 128) - 1)\nids.remainder.high = remainder >> 128"

// ------ Usort hints related code ------
usortBodyCode string = `
from collections import defaultdict
input_ptr = ids.input
input_len = int(ids.input_len)
if __usort_max_size is not None:
assert input_len <= __usort_max_size, (
f"usort() can only be used with input_len<={__usort_max_size}. "
f"Got: input_len={input_len}."
)
positions_dict = defaultdict(list)
for i in range(input_len):
val = memory[input_ptr + i]
positions_dict[val].append(i)
output = sorted(positions_dict.keys())
ids.output_len = len(output)
ids.output = segments.gen_arg(output)
ids.multiplicities = segments.gen_arg([len(positions_dict[k]) for k in output])`
usortBodyCode string = `from collections import defaultdict
input_ptr = ids.input
input_len = int(ids.input_len)
if __usort_max_size is not None:
assert input_len <= __usort_max_size, (
f"usort() can only be used with input_len<={__usort_max_size}. "
f"Got: input_len={input_len}."
)
positions_dict = defaultdict(list)
for i in range(input_len):
val = memory[input_ptr + i]
positions_dict[val].append(i)
output = sorted(positions_dict.keys())
ids.output_len = len(output)
ids.output = segments.gen_arg(output)
ids.multiplicities = segments.gen_arg([len(positions_dict[k]) for k in output])`
usortEnterScopeCode string = "vm_enter_scope(dict(__usort_max_size = globals().get('__usort_max_size')))"
usortVerifyMultiplicityAssertCode string = "assert len(positions) == 0"
usortVerifyCode string = "last_pos = 0\npositions = positions_dict[ids.value][::-1]"
Expand All @@ -95,7 +94,7 @@ const (
ecDoubleAssignNewXV1Code string = "from starkware.cairo.common.cairo_secp.secp_utils import SECP_P, pack\n\nslope = pack(ids.slope, PRIME)\nx = pack(ids.point.x, PRIME)\ny = pack(ids.point.y, PRIME)\n\nvalue = new_x = (pow(slope, 2, SECP_P) - 2 * x) % SECP_P"
ecDoubleAssignNewYV1Code string = "value = new_y = (slope * (x - new_x) - y) % SECP_P"
ecMulInnerCode string = "memory[ap] = (ids.scalar % PRIME) % 2"
isZeroNondetCode string = "x == 0"
isZeroNondetCode string = "memory[ap] = to_felt_or_relocatable(x == 0)"
isZeroPackCode string = "from starkware.cairo.common.cairo_secp.secp_utils import SECP_P, pack\n\nx = pack(ids.x, PRIME) % SECP_P"
isZeroDivModCode string = "from starkware.cairo.common.cairo_secp.secp_utils import SECP_P\nfrom starkware.python.math_utils import div_mod\n\nvalue = x_inv = div_mod(1, x, SECP_P)"

Expand Down Expand Up @@ -155,6 +154,8 @@ const (
vmEnterScopeCode string = "vm_enter_scope()"
vmExitScopeCode string = "vm_exit_scope()"
findElementCode string = "array_ptr = ids.array_ptr\nelm_size = ids.elm_size\nassert isinstance(elm_size, int) and elm_size > 0, \\\n f'Invalid value for elm_size. Got: {elm_size}.'\nkey = ids.key\n\nif '__find_element_index' in globals():\n ids.index = __find_element_index\n found_key = memory[array_ptr + elm_size * __find_element_index]\n assert found_key == key, \\\n f'Invalid index found in __find_element_index. index: {__find_element_index}, ' \\\n f'expected key {key}, found key: {found_key}.'\n # Delete __find_element_index to make sure it's not used for the next calls.\n del __find_element_index\nelse:\n n_elms = ids.n_elms\n assert isinstance(n_elms, int) and n_elms >= 0, \\\n f'Invalid value for n_elms. Got: {n_elms}.'\n if '__find_element_max_size' in globals():\n assert n_elms <= __find_element_max_size, \\\n f'find_element() can only be used with n_elms<={__find_element_max_size}. ' \\\n f'Got: n_elms={n_elms}.'\n\n for i in range(n_elms):\n if memory[array_ptr + elm_size * i] == key:\n ids.index = i\n break\n else:\n raise ValueError(f'Key {key} was not found.')"
nondetElementsOverTWoCode string = "memory[ap] = to_felt_or_relocatable(ids.n >= 2)"
nondetElementsOverTenCode string = "memory[ap] = to_felt_or_relocatable(ids.n >= 10)"
setAddCode string = "assert ids.elm_size > 0\nassert ids.set_ptr <= ids.set_end_ptr\nelm_list = memory.get_range(ids.elm_ptr, ids.elm_size)\nfor i in range(0, ids.set_end_ptr - ids.set_ptr, ids.elm_size):\n if memory.get_range(ids.set_ptr + i, ids.elm_size) == elm_list:\n ids.index = i // ids.elm_size\n ids.is_elm_in_set = 1\n break\nelse:\n ids.is_elm_in_set = 0"
searchSortedLowerCode string = `array_ptr = ids.array_ptr
elm_size = ids.elm_size
Expand Down
8 changes: 4 additions & 4 deletions pkg/hintrunner/zero/hintparser.go
Original file line number Diff line number Diff line change
Expand Up @@ -215,7 +215,7 @@ func (expression OffsetExp) Evaluate() (*int, error) {
negNumber := -*expression.NegNumber
return &negNumber, nil
default:
return nil, fmt.Errorf("Expected a number")
return nil, fmt.Errorf("expected a number")
}
}

Expand All @@ -226,7 +226,7 @@ func (expression DerefExp) Evaluate() (any, error) {
}
cellRef, ok := cellRefExp.(hinter.CellRefer)
if !ok {
return nil, fmt.Errorf("Expected a CellRefer expression but got %s", cellRefExp)
return nil, fmt.Errorf("expected a CellRefer expression but got %s", cellRefExp)
}
return hinter.Deref{Deref: cellRef}, nil
}
Expand Down Expand Up @@ -301,7 +301,7 @@ func (expression LeftExp) Evaluate() (any, error) {
case expression.DerefExp != nil:
return expression.DerefExp.Evaluate()
}
return nil, fmt.Errorf("Unexpected left expression in binary operation")
return nil, fmt.Errorf("unexpected left expression in binary operation")
}

func (expression RightExp) Evaluate() (any, error) {
Expand All @@ -311,7 +311,7 @@ func (expression RightExp) Evaluate() (any, error) {
case expression.Offset != nil:
return expression.Offset.Evaluate()
}
return nil, fmt.Errorf("Unexpected right expression in binary operation")
return nil, fmt.Errorf("unexpected right expression in binary operation")
}

func ParseIdentifier(value string) (hinter.Reference, error) {
Expand Down
4 changes: 4 additions & 0 deletions pkg/hintrunner/zero/zerohint.go
Original file line number Diff line number Diff line change
Expand Up @@ -238,6 +238,10 @@ func GetHintFromCode(program *zero.ZeroProgram, rawHint zero.Hint, hintPC uint64
return createTestAssignHinter(resolver)
case findElementCode:
return createFindElementHinter(resolver)
case nondetElementsOverTWoCode:
return createNondetElementsOverTWoHinter(resolver)
case nondetElementsOverTenCode:
return createNondetElementsOverTenHinter(resolver)
default:
return nil, fmt.Errorf("not identified hint")
}
Expand Down
6 changes: 3 additions & 3 deletions pkg/hintrunner/zero/zerohint_dictionaries.go
Original file line number Diff line number Diff line change
Expand Up @@ -230,7 +230,7 @@ func newDictSquashCopyDictHint(dictAccessesEnd hinter.ResOperander) hinter.Hinte
}

dictionaryDataCopy := make(map[fp.Element]memory.MemoryValue)
for k, v := range dictionary.Data {
for k, v := range *dictionary.Data {
// Copy the key
keyCopy := fp.Element{}
keyCopy.Set(&k)
Expand Down Expand Up @@ -490,7 +490,7 @@ func newSquashDictHint(dictAccesses, ptrDiff, nAccesses, bigKeys, firstKey hinte
return err
}
if ptrDiffValue%DictAccessSize != 0 {
return fmt.Errorf("Accesses array size must be divisible by DictAccess.SIZE")
return fmt.Errorf("accesses array size must be divisible by DictAccess.SIZE")
}

//> n_accesses = ids.n_accesses
Expand All @@ -506,7 +506,7 @@ func newSquashDictHint(dictAccesses, ptrDiff, nAccesses, bigKeys, firstKey hinte
// __squash_dict_max_size is always in scope and has a value of 2**20,
squashDictMaxSize := uint64(1048576)
if nAccessesValue > squashDictMaxSize {
return fmt.Errorf("squash_dict() can only be used with n_accesses<={%d}. Got: n_accesses={%d}.", squashDictMaxSize, nAccessesValue)
return fmt.Errorf("squash_dict() can only be used with n_accesses<={%d}. Got: n_accesses={%d}", squashDictMaxSize, nAccessesValue)
}

//> # A map from key to the list of indices accessing it.
Expand Down
4 changes: 2 additions & 2 deletions pkg/hintrunner/zero/zerohint_dictionaries_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -599,7 +599,7 @@ func TestZeroHintDictionaries(t *testing.T) {
ctx.operanders["first_key"],
)
},
errCheck: errorTextContains("Accesses array size must be divisible by DictAccess.SIZE"),
errCheck: errorTextContains("accesses array size must be divisible by DictAccess.SIZE"),
},
{
operanders: []*hintOperander{
Expand All @@ -625,7 +625,7 @@ func TestZeroHintDictionaries(t *testing.T) {
ctx.operanders["first_key"],
)
},
errCheck: errorTextContains("squash_dict() can only be used with n_accesses<={1048576}. Got: n_accesses={1048577}."),
errCheck: errorTextContains("squash_dict() can only be used with n_accesses<={1048576}. Got: n_accesses={1048577}"),
},
{
operanders: []*hintOperander{
Expand Down
Loading

0 comments on commit 536dc50

Please sign in to comment.