Skip to content

Commit

Permalink
Implement SquashDict hint (#412)
Browse files Browse the repository at this point in the history
* Implement basic SquashDict

* Add more functionality

* Add error tests

* Add a successful test

* Test more successful scenarios

* Test with more diverse values

* Add integration test from lambda

* Use Pop util

* Add function docs

* Comment out integration test

---------

Co-authored-by: Tristan <122918260+TAdev0@users.noreply.github.com>
  • Loading branch information
har777 and TAdev0 committed Jun 17, 2024
1 parent 8cf54a1 commit fc520ad
Show file tree
Hide file tree
Showing 6 changed files with 422 additions and 0 deletions.
38 changes: 38 additions & 0 deletions integration_tests/cairo_files/squash_dict.cairo
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
// %builtins range_check

// from starkware.cairo.common.squash_dict import squash_dict
// from starkware.cairo.common.alloc import alloc
// from starkware.cairo.common.dict_access import DictAccess

// func main{range_check_ptr}() -> () {
// alloc_locals;
// let (dict_start: DictAccess*) = alloc();
// assert dict_start[0] = DictAccess(key=0, prev_value=100, new_value=100);
// assert dict_start[1] = DictAccess(key=1, prev_value=50, new_value=50);
// assert dict_start[2] = DictAccess(key=0, prev_value=100, new_value=200);
// assert dict_start[3] = DictAccess(key=1, prev_value=50, new_value=100);
// assert dict_start[4] = DictAccess(key=0, prev_value=200, new_value=300);
// assert dict_start[5] = DictAccess(key=1, prev_value=100, new_value=150);

// let dict_end = dict_start + 6 * DictAccess.SIZE;
// // (dict_start, dict_end) now represents the dictionary
// // {0: 100, 1: 50, 0: 200, 1: 100, 0: 300, 1: 150}.

// // Squash the dictionary from an array of 6 DictAccess structs
// // to an array of 2, with a single DictAccess entry per key.
// let (local squashed_dict_start: DictAccess*) = alloc();
// let (squashed_dict_end) = squash_dict{range_check_ptr=range_check_ptr}(
// dict_start, dict_end, squashed_dict_start
// );

// // Check the values of the squashed_dict
// // should be: {0: (100, 300), 1: (50, 150)}
// assert squashed_dict_start[0] = DictAccess(key=0, prev_value=100, new_value=300);
// assert squashed_dict_start[1] = DictAccess(key=1, prev_value=50, new_value=150);
// return ();
// }

func main() {
return ();
}

1 change: 1 addition & 0 deletions pkg/hintrunner/zero/hintcode.go
Original file line number Diff line number Diff line change
Expand Up @@ -131,6 +131,7 @@ const (
dictReadCode string = "dict_tracker = __dict_manager.get_tracker(ids.dict_ptr)\ndict_tracker.current_ptr += ids.DictAccess.SIZE\nids.value = dict_tracker.data[ids.key]"
dictWriteCode string = "dict_tracker = __dict_manager.get_tracker(ids.dict_ptr)\ndict_tracker.current_ptr += ids.DictAccess.SIZE\nids.dict_ptr.prev_value = dict_tracker.data[ids.key]\ndict_tracker.data[ids.key] = ids.new_value"
dictUpdateCode string = "# Verify dict pointer and prev value.\ndict_tracker = __dict_manager.get_tracker(ids.dict_ptr)\ncurrent_value = dict_tracker.data[ids.key]\nassert current_value == ids.prev_value, \\\n f'Wrong previous value in dict. Got {ids.prev_value}, expected {current_value}.'\n\n# Update value.\ndict_tracker.data[ids.key] = ids.new_value\ndict_tracker.current_ptr += ids.DictAccess.SIZE"
squashDictCode string = "dict_access_size = ids.DictAccess.SIZE\naddress = ids.dict_accesses.address_\nassert ids.ptr_diff % dict_access_size == 0, \\\n 'Accesses array size must be divisible by DictAccess.SIZE'\nn_accesses = ids.n_accesses\nif '__squash_dict_max_size' in globals():\n assert n_accesses <= __squash_dict_max_size, \\\n f'squash_dict() can only be used with n_accesses<={__squash_dict_max_size}. ' \\\n f'Got: n_accesses={n_accesses}.'\n# A map from key to the list of indices accessing it.\naccess_indices = {}\nfor i in range(n_accesses):\n key = memory[address + dict_access_size * i]\n access_indices.setdefault(key, []).append(i)\n# Descending list of keys.\nkeys = sorted(access_indices.keys(), reverse=True)\n# Are the keys used bigger than range_check bound.\nids.big_keys = 1 if keys[0] >= range_check_builtin.bound else 0\nids.first_key = key = keys.pop()"
squashDictInnerAssertLenKeys string = "assert len(keys) == 0"
squashDictInnerCheckAccessIndex string = "new_access_index = current_access_indices.pop()\nids.loop_temps.index_delta_minus1 = new_access_index - current_access_index - 1\ncurrent_access_index = new_access_index"
squashDictInnerContinueLoop string = "ids.loop_temps.should_continue = 1 if current_access_indices else 0"
Expand Down
2 changes: 2 additions & 0 deletions pkg/hintrunner/zero/zerohint.go
Original file line number Diff line number Diff line change
Expand Up @@ -185,6 +185,8 @@ func GetHintFromCode(program *zero.ZeroProgram, rawHint zero.Hint, hintPC uint64
return createDictWriteHinter(resolver)
case dictUpdateCode:
return createDictUpdateHinter(resolver)
case squashDictCode:
return createSquashDictHinter(resolver)
case squashDictInnerAssertLenKeys:
return createSquashDictInnerAssertLenKeysHinter()
case squashDictInnerCheckAccessIndex:
Expand Down
168 changes: 168 additions & 0 deletions pkg/hintrunner/zero/zerohint_dictionaries.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ import (
VM "github.com/NethermindEth/cairo-vm-go/pkg/vm"
"github.com/NethermindEth/cairo-vm-go/pkg/vm/memory"
"github.com/consensys/gnark-crypto/ecc/stark-curve/fp"
"golang.org/x/exp/maps"
)

// struct DictAccess {
Expand Down Expand Up @@ -366,6 +367,173 @@ func createDictUpdateHinter(resolver hintReferenceResolver) (hinter.Hinter, erro
return newDictUpdateHint(dictPtr, key, newValue, prevValue), nil
}

// SquashDict hint as part of the larger dict_squash cairo function does data validation
// and writes to scope a set of variables which indicate the largest used key in the dict,
// a map from key to the list of indices accessing it
// and a descending list of used keys except the largest key.
// It also writes to a cairo variable the largest used key
// and a boolean indicating if any of the keys used were bigger than the range_check
//
// `newSquashDictHint` takes 5 operanders as arguments
// - `dictAccesses` variable will be a pointer to the beginning of an array of DictAccess instances. The format of
// each entry is a triplet (key, prev_value, new_value)
// - `ptrDiff` variable will be the size of the above dictAccesses array
// - `nAccesses` variable will have a value indicating the number of times the dict was accessed
// - `bigKeys` variable will be written a value of 1 if the keys used are bigger than the range_check and 0 otherwise
// - `firstKey` variable will be written the value of the largest used key after the hint is run
func newSquashDictHint(dictAccesses, ptrDiff, nAccesses, bigKeys, firstKey hinter.ResOperander) hinter.Hinter {
return &GenericZeroHinter{
Name: "SquashDict",
Op: func(vm *VM.VirtualMachine, ctx *hinter.HintRunnerContext) error {
//> dict_access_size = ids.DictAccess.SIZE
//> address = ids.dict_accesses.address_
//> assert ids.ptr_diff % dict_access_size == 0, \
//> 'Accesses array size must be divisible by DictAccess.SIZE'
//> n_accesses = ids.n_accesses
//> if '__squash_dict_max_size' in globals():
//> assert n_accesses <= __squash_dict_max_size, \
//> f'squash_dict() can only be used with n_accesses<={__squash_dict_max_size}. ' \
//> f'Got: n_accesses={n_accesses}.'
//> # A map from key to the list of indices accessing it.
//> access_indices = {}
//> for i in range(n_accesses):
//> key = memory[address + dict_access_size * i]
//> access_indices.setdefault(key, []).append(i)
//> # Descending list of keys.
//> keys = sorted(access_indices.keys(), reverse=True)
//> # Are the keys used bigger than range_check bound.
//> ids.big_keys = 1 if keys[0] >= range_check_builtin.bound else 0
//> ids.first_key = key = keys.pop()

//> address = ids.dict_accesses.address_
address, err := dictAccesses.GetAddress(vm)
if err != nil {
return err
}

//> assert ids.ptr_diff % dict_access_size == 0, \
//> 'Accesses array size must be divisible by DictAccess.SIZE'
ptrDiffValue, err := hinter.ResolveAsUint64(vm, ptrDiff)
if err != nil {
return err
}
if ptrDiffValue%DictAccessSize != 0 {
return fmt.Errorf("Accesses array size must be divisible by DictAccess.SIZE")
}

//> n_accesses = ids.n_accesses
nAccessesValue, err := hinter.ResolveAsUint64(vm, nAccesses)
if err != nil {
return err
}

//> if '__squash_dict_max_size' in globals():
//> assert n_accesses <= __squash_dict_max_size, \
//> f'squash_dict() can only be used with n_accesses<={__squash_dict_max_size}. ' \
//> f'Got: n_accesses={n_accesses}.'
// __squash_dict_max_size is always in scope and has a value of 2**20,
squashDictMaxSize := uint64(1048576)
if nAccessesValue > squashDictMaxSize {
return fmt.Errorf("squash_dict() can only be used with n_accesses<={%d}. Got: n_accesses={%d}.", squashDictMaxSize, nAccessesValue)
}

//> # A map from key to the list of indices accessing it.
//> access_indices = {}
//> for i in range(n_accesses):
//> key = memory[address + dict_access_size * i]
//> access_indices.setdefault(key, []).append(i)
accessIndices := make(map[fp.Element][]uint64)
for i := uint64(0); i < nAccessesValue; i++ {
memoryAddress, err := address.AddOffset(int16(DictAccessSize * i))
if err != nil {
return err
}
key, err := vm.Memory.ReadFromAddressAsElement(&memoryAddress)
if err != nil {
return err
}
accessIndices[key] = append(accessIndices[key], i)
}

//> # Descending list of keys.
//> keys = sorted(access_indices.keys(), reverse=True)
keys := maps.Keys(accessIndices)
if len(keys) == 0 {
return fmt.Errorf("empty keys array")
}
sort.Slice(keys, func(i, j int) bool {
return keys[i].Cmp(&keys[j]) > 0
})

//> ids.big_keys = 1 if keys[0] >= range_check_builtin.bound else 0
bigKeysAddr, err := bigKeys.GetAddress(vm)
if err != nil {
return err
}
var bigKeysMv memory.MemoryValue
if utils.FeltIsPositive(&keys[0]) {
bigKeysMv = memory.MemoryValueFromFieldElement(&utils.FeltZero)
} else {
bigKeysMv = memory.MemoryValueFromFieldElement(&utils.FeltOne)
}
err = vm.Memory.WriteToAddress(&bigKeysAddr, &bigKeysMv)
if err != nil {
return err
}

//> ids.first_key = key = keys.pop()
firstKeyAddr, err := firstKey.GetAddress(vm)
if err != nil {
return err
}
firstKeyValue, err := utils.Pop(&keys)
if err != nil {
return err
}
firstKeyMv := memory.MemoryValueFromFieldElement(&firstKeyValue)
err = vm.Memory.WriteToAddress(&firstKeyAddr, &firstKeyMv)
if err != nil {
return err
}

err = ctx.ScopeManager.AssignVariable("access_indices", accessIndices)
if err != nil {
return err
}
err = ctx.ScopeManager.AssignVariable("keys", keys)
if err != nil {
return err
}
return ctx.ScopeManager.AssignVariable("key", firstKeyValue)
},
}
}

func createSquashDictHinter(resolver hintReferenceResolver) (hinter.Hinter, error) {
dictAccesses, err := resolver.GetResOperander("dict_accesses")
if err != nil {
return nil, err
}
ptrDiff, err := resolver.GetResOperander("ptr_diff")
if err != nil {
return nil, err
}
nAccesses, err := resolver.GetResOperander("n_accesses")
if err != nil {
return nil, err
}
bigKeys, err := resolver.GetResOperander("big_keys")
if err != nil {
return nil, err
}
firstKey, err := resolver.GetResOperander("first_key")
if err != nil {
return nil, err
}

return newSquashDictHint(dictAccesses, ptrDiff, nAccesses, bigKeys, firstKey), nil
}

// SquashDictInnerAssertLenKeys hint asserts that the length
// of the `keys` descending list is zero during the squashing process
//
Expand Down
Loading

0 comments on commit fc520ad

Please sign in to comment.