Skip to content

Commit

Permalink
Implement DictWrite hint (#364)
Browse files Browse the repository at this point in the history
* Add basic skeleton

* Implement DefaultDictNew

* Add tests

* Implement DefaultRead

* Implement DictWrite

* Add simple integration test

* Update dict integration test

* Update dict integration test

* Fix imports

* Fix imports

* Add comments + minor changes

* Remove unnecessary ctx init

* Add comment

* Add comment

* Clean up dict integration test

* Clean up dict integration test

* Clean up dict integration test

* Treat dicts in zero hints differently

* Remove accidental newline

* Fix ignoring err message

* Remove some unnecessary code from tests

* Fix typo

* Fix typo

* Improved comments

* Add credit comment

* Add better comments

* Fix freeOffset bug + add tests

* Fix tests

* Add and use zeroDictInScopeEquals test util

* Fix method usage

* Comment out integration test

* Enable dict integration test

---------

Co-authored-by: Tristan <122918260+TAdev0@users.noreply.github.com>
Co-authored-by: MaksymMalicki <81577596+MaksymMalicki@users.noreply.github.com>
  • Loading branch information
3 people committed Jun 9, 2024
1 parent 057220f commit 77b0805
Show file tree
Hide file tree
Showing 8 changed files with 199 additions and 18 deletions.
40 changes: 39 additions & 1 deletion integration_tests/cairo_files/dict.cairo
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
// inspired from the dict.cairo integration test in the lambdaclass cairo-vm codebase

from starkware.cairo.common.default_dict import default_dict_new
from starkware.cairo.common.dict import dict_read
from starkware.cairo.common.dict import dict_read, dict_write
from starkware.cairo.common.dict_access import DictAccess

func test_default_dict() {
Expand All @@ -24,9 +24,47 @@ func test_read() {
return ();
}

func test_write() {
alloc_locals;
let (local my_dict: DictAccess*) = default_dict_new(123);

let (local val1: felt) = dict_read{dict_ptr=my_dict}(key=1);
assert val1 = 123;

let (local val2: felt) = dict_read{dict_ptr=my_dict}(key=2);
assert val2 = 123;

dict_write{dict_ptr=my_dict}(key=1, new_value=512);
let (local val3: felt) = dict_read{dict_ptr=my_dict}(key=1);
assert val3 = 512;

let (local val4: felt) = dict_read{dict_ptr=my_dict}(key=2);
assert val4 = 123;

dict_write{dict_ptr=my_dict}(key=1, new_value=1024);
let (local val5: felt) = dict_read{dict_ptr=my_dict}(key=1);
assert val5 = 1024;

let (local val6: felt) = dict_read{dict_ptr=my_dict}(key=2);
assert val6 = 123;

dict_write{dict_ptr=my_dict}(key=1, new_value=888);
dict_write{dict_ptr=my_dict}(key=2, new_value=999);
let (local val7: felt) = dict_read{dict_ptr=my_dict}(key=1);
assert val7 = 888;
let (local val8: felt) = dict_read{dict_ptr=my_dict}(key=2);
assert val8 = 999;
let (local val9: felt) = dict_read{dict_ptr=my_dict}(key=3);
assert val9 = 123;

return ();
}

func main() {
test_default_dict();
test_read();
test_write();

return ();
}

32 changes: 17 additions & 15 deletions pkg/hintrunner/hinter/zero_dict.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,33 +10,33 @@ import (

// Used to keep track of all dictionaries data
type ZeroDictionary struct {
// The data contained on a dictionary
data map[f.Element]mem.MemoryValue
// The Data contained on a dictionary
Data map[f.Element]mem.MemoryValue
// Default value for key not present in the dictionary
defaultValue mem.MemoryValue
DefaultValue mem.MemoryValue
// first free offset in memory segment of dictionary
freeOffset uint64
FreeOffset *uint64
}

// Gets the memory value at certain key
func (d *ZeroDictionary) At(key f.Element) (mem.MemoryValue, error) {
if value, ok := d.data[key]; ok {
if value, ok := d.Data[key]; ok {
return value, nil
}
if d.defaultValue != mem.UnknownValue {
return d.defaultValue, nil
if d.DefaultValue != mem.UnknownValue {
return d.DefaultValue, nil
}
return mem.UnknownValue, fmt.Errorf("no value for key: %v", key)
}

// Given a key and a value, it sets the value at the given key
func (d *ZeroDictionary) Set(key f.Element, value mem.MemoryValue) {
d.data[key] = value
d.Data[key] = value
}

// Given a incrementBy value, it increments the freeOffset field of dictionary by it
func (d *ZeroDictionary) IncrementFreeOffset(freeOffset uint64) {
d.freeOffset += freeOffset
*d.FreeOffset += freeOffset
}

// Used to manage dictionaries creation
Expand All @@ -56,10 +56,11 @@ func NewZeroDictionaryManager() ZeroDictionaryManager {
// to the start of this segment
func (dm *ZeroDictionaryManager) NewDictionary(vm *VM.VirtualMachine) mem.MemoryAddress {
newDictAddr := vm.Memory.AllocateEmptySegment()
freeOffset := uint64(0)
dm.dictionaries[newDictAddr.SegmentIndex] = ZeroDictionary{
data: make(map[f.Element]mem.MemoryValue),
defaultValue: mem.UnknownValue,
freeOffset: 0,
Data: make(map[f.Element]mem.MemoryValue),
DefaultValue: mem.UnknownValue,
FreeOffset: &freeOffset,
}
return newDictAddr
}
Expand All @@ -70,10 +71,11 @@ func (dm *ZeroDictionaryManager) NewDictionary(vm *VM.VirtualMachine) mem.Memory
// querying the defaultValue will be returned instead.
func (dm *ZeroDictionaryManager) NewDefaultDictionary(vm *VM.VirtualMachine, defaultValue mem.MemoryValue) mem.MemoryAddress {
newDefaultDictAddr := vm.Memory.AllocateEmptySegment()
freeOffset := uint64(0)
dm.dictionaries[newDefaultDictAddr.SegmentIndex] = ZeroDictionary{
data: make(map[f.Element]mem.MemoryValue),
defaultValue: defaultValue,
freeOffset: 0,
Data: make(map[f.Element]mem.MemoryValue),
DefaultValue: defaultValue,
FreeOffset: &freeOffset,
}
return newDefaultDictAddr
}
Expand Down
1 change: 1 addition & 0 deletions pkg/hintrunner/zero/hintcode.go
Original file line number Diff line number Diff line change
Expand Up @@ -125,6 +125,7 @@ const (
// ------ Dictionaries hints related code ------
defaultDictNewCode string = "if '__dict_manager' not in globals():\n from starkware.cairo.common.dict import DictManager\n __dict_manager = DictManager()\n\nmemory[ap] = __dict_manager.new_default_dict(segments, ids.default_value)"
dictReadCode string = "dict_tracker = __dict_manager.get_tracker(ids.dict_ptr)\ndict_tracker.current_ptr += ids.DictAccess.SIZE\nids.value = dict_tracker.data[ids.key]"
dictWriteCode string = "dict_tracker = __dict_manager.get_tracker(ids.dict_ptr)\ndict_tracker.current_ptr += ids.DictAccess.SIZE\nids.dict_ptr.prev_value = dict_tracker.data[ids.key]\ndict_tracker.data[ids.key] = ids.new_value"
squashDictInnerAssertLenKeys string = "assert len(keys) == 0"
squashDictInnerContinueLoop string = "ids.loop_temps.should_continue = 1 if current_access_indices else 0"
squashDictInnerSkipLoop string = "ids.should_skip_loop = 0 if current_access_indices else 1"
Expand Down
2 changes: 2 additions & 0 deletions pkg/hintrunner/zero/zerohint.go
Original file line number Diff line number Diff line change
Expand Up @@ -173,6 +173,8 @@ func GetHintFromCode(program *zero.ZeroProgram, rawHint zero.Hint, hintPC uint64
return createDefaultDictNewHinter(resolver)
case dictReadCode:
return createDictReadHinter(resolver)
case dictWriteCode:
return createDictWriteHinter(resolver)
case squashDictInnerAssertLenKeys:
return createSquashDictInnerAssertLenKeysHinter()
case squashDictInnerContinueLoop:
Expand Down
80 changes: 80 additions & 0 deletions pkg/hintrunner/zero/zerohint_dictionaries.go
Original file line number Diff line number Diff line change
Expand Up @@ -125,6 +125,86 @@ func createDictReadHinter(resolver hintReferenceResolver) (hinter.Hinter, error)
return newDictReadHint(dictPtr, key, value), nil
}

// DictWrite hint writes a value for a given key in a dictionary
// and writes to memory the previous value for the key in the dictionary
//
// `newDictWriteHint` takes 3 operanders as argument
// - `dictPtr` variable will be pointer to the dictionary to update
// - `key` variable will be the key whose value is updated in the dictionary
// - `newValue` variable will be the new value for given key in the dictionary
func newDictWriteHint(dictPtr, key, newValue hinter.ResOperander) hinter.Hinter {
return &GenericZeroHinter{
Name: "DictWrite",
Op: func(vm *VM.VirtualMachine, ctx *hinter.HintRunnerContext) error {
//> dict_tracker = __dict_manager.get_tracker(ids.dict_ptr)
//> dict_tracker.current_ptr += ids.DictAccess.SIZE
//> ids.dict_ptr.prev_value = dict_tracker.data[ids.key]
//> dict_tracker.data[ids.key] = ids.new_value

//> dict_tracker = __dict_manager.get_tracker(ids.dict_ptr)
dictPtr, err := hinter.ResolveAsAddress(vm, dictPtr)
if err != nil {
return err
}
dictionaryManager, ok := ctx.ScopeManager.GetZeroDictionaryManager()
if !ok {
return fmt.Errorf("__dict_manager not in scope")
}

//> dict_tracker.current_ptr += ids.DictAccess.SIZE
err = dictionaryManager.IncrementFreeOffset(*dictPtr, 3)
if err != nil {
return err
}

key, err := hinter.ResolveAsFelt(vm, key)
if err != nil {
return err
}

//> ids.dict_ptr.prev_value = dict_tracker.data[ids.key]
//> # dict_ptr points to a DictAccess
//> struct DictAccess {
//> key: felt,
//> prev_value: felt,
//> new_value: felt,
//> }
prevKeyValue, err := dictionaryManager.At(*dictPtr, *key)
if err != nil {
return err
}
err = vm.Memory.WriteToNthStructField(*dictPtr, prevKeyValue, 1)
if err != nil {
return err
}

//> dict_tracker.data[ids.key] = ids.new_value
newValue, err := hinter.ResolveAsFelt(vm, newValue)
if err != nil {
return err
}
newValueMv := memory.MemoryValueFromFieldElement(newValue)
return dictionaryManager.Set(*dictPtr, *key, newValueMv)
},
}
}

func createDictWriteHinter(resolver hintReferenceResolver) (hinter.Hinter, error) {
dictPtr, err := resolver.GetResOperander("dict_ptr")
if err != nil {
return nil, err
}
key, err := resolver.GetResOperander("key")
if err != nil {
return nil, err
}
newValue, err := resolver.GetResOperander("new_value")
if err != nil {
return nil, err
}
return newDictWriteHint(dictPtr, key, newValue), nil
}

// SquashDictInnerAssertLenKeys hint asserts that the length
// of the `keys` descending list is zero during the squashing process
//
Expand Down
43 changes: 42 additions & 1 deletion pkg/hintrunner/zero/zerohint_dictionaries_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,48 @@ func TestZeroHintDictionaries(t *testing.T) {
dictionaryManager.NewDefaultDictionary(ctx.vm, defaultValueMv)
return newDictReadHint(ctx.operanders["dict_ptr"], ctx.operanders["key"], ctx.operanders["value"])
},
check: varValueEquals("value", feltUint64(12345)),
check: func(t *testing.T, ctx *hintTestContext) {
varValueEquals("value", feltUint64(12345))(t, ctx)

dictPtr := addrWithSegment(2, 0)
expectedData := map[fp.Element]memory.MemoryValue{}
expectedDefaultValue := memory.MemoryValueFromInt(12345)
expectedFreeOffset := uint64(3)
zeroDictInScopeEquals(*dictPtr, expectedData, expectedDefaultValue, expectedFreeOffset)(t, ctx)
},
},
},
"DictWrite": {
{
operanders: []*hintOperander{
{Name: "key", Kind: apRelative, Value: feltUint64(100)},
{Name: "new_value", Kind: apRelative, Value: feltUint64(9999)},
{Name: "dict_ptr", Kind: apRelative, Value: addrWithSegment(2, 0)},
{Name: "dict_ptr.prev_value", Kind: apRelative, Value: addrWithSegment(2, 1)},
},
makeHinter: func(ctx *hintTestContext) hinter.Hinter {
dictionaryManager := hinter.NewZeroDictionaryManager()
err := ctx.runnerContext.ScopeManager.AssignVariable("__dict_manager", dictionaryManager)
if err != nil {
t.Fatal(err)
}
defaultValueMv := memory.MemoryValueFromInt(12345)
dictionaryManager.NewDefaultDictionary(ctx.vm, defaultValueMv)
return newDictWriteHint(ctx.operanders["dict_ptr"], ctx.operanders["key"], ctx.operanders["new_value"])
},
check: func(t *testing.T, ctx *hintTestContext) {
consecutiveVarAddrResolvedValueEquals(
"dict_ptr.prev_value",
[]*fp.Element{
feltString("12345"),
})(t, ctx)

dictPtr := addrWithSegment(2, 0)
expectedData := map[fp.Element]memory.MemoryValue{*feltUint64(100): memory.MemoryValueFromInt(9999)}
expectedDefaultValue := memory.MemoryValueFromInt(12345)
expectedFreeOffset := uint64(3)
zeroDictInScopeEquals(*dictPtr, expectedData, expectedDefaultValue, expectedFreeOffset)(t, ctx)
},
},
},
"SquashDictInnerAssertLenKeys": {
Expand Down
1 change: 0 additions & 1 deletion pkg/hintrunner/zero/zerohint_others_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@ import (
)

func TestZeroHintMemcpy(t *testing.T) {

runHinterTests(t, map[string][]hintTestCase{
"MemcpyContinueCopying": {
{
Expand Down
18 changes: 18 additions & 0 deletions pkg/hintrunner/zero/zerohint_utils_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ import (
"github.com/NethermindEth/cairo-vm-go/pkg/vm"
"github.com/NethermindEth/cairo-vm-go/pkg/vm/memory"
"github.com/consensys/gnark-crypto/ecc/stark-curve/fp"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)

Expand Down Expand Up @@ -225,3 +226,20 @@ func varListInScopeEquals(expectedValues map[string]any) func(t *testing.T, ctx
}
}
}

func zeroDictInScopeEquals(dictAddress memory.MemoryAddress, expectedData map[fp.Element]memory.MemoryValue, expectedDefaultValue memory.MemoryValue, expectedFreeOffset uint64) func(t *testing.T, ctx *hintTestContext) {
return func(t *testing.T, ctx *hintTestContext) {
dictionaryManager, ok := ctx.runnerContext.ScopeManager.GetZeroDictionaryManager()
if !ok {
t.Fatal("failed to fetch dictionary manager")
}
dictionary, err := dictionaryManager.GetDictionary(dictAddress)
if err != nil {
t.Fatal(err)
}

assert.Equal(t, expectedData, dictionary.Data)
assert.Equal(t, expectedDefaultValue, dictionary.DefaultValue)
assert.Equal(t, expectedFreeOffset, *dictionary.FreeOffset)
}
}

0 comments on commit 77b0805

Please sign in to comment.