Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Implement DictWrite hint #364

Merged
merged 47 commits into from
Jun 9, 2024
Merged
Show file tree
Hide file tree
Changes from 43 commits
Commits
Show all changes
47 commits
Select commit Hold shift + click to select a range
593ff1b
Add basic skeleton
har777 Apr 22, 2024
88028ac
Implement DefaultDictNew
har777 Apr 22, 2024
076f7b9
Add tests
har777 Apr 22, 2024
99eada8
Implement DefaultRead
har777 Apr 22, 2024
93b9c52
Implement DictWrite
har777 Apr 22, 2024
1a6212c
Add simple integration test
har777 Apr 25, 2024
66c21d2
Merge branch 'defaultdictnew_hint' into dictread_hint
har777 Apr 25, 2024
2b3aff8
Update dict integration test
har777 Apr 25, 2024
69e9838
Merge branch 'dictread_hint' into dictwrite_hint
har777 Apr 25, 2024
1a071b2
Update dict integration test
har777 Apr 25, 2024
d60bf5f
Fix imports
har777 Apr 25, 2024
b5f9d28
Merge branch 'dictread_hint' into dictwrite_hint
har777 Apr 25, 2024
414772c
Fix imports
har777 Apr 25, 2024
1db67f0
Merge main
har777 May 13, 2024
667b45a
Add comments + minor changes
har777 May 13, 2024
a261b72
Merge defaultdictnew_hint
har777 May 13, 2024
31b2bc4
Merge dictread_hint
har777 May 13, 2024
92a609d
Remove unnecessary ctx init
har777 May 13, 2024
980f4ab
Add comment
har777 May 13, 2024
4adb057
Add comment
har777 May 13, 2024
a07a75d
Clean up dict integration test
har777 May 14, 2024
957676f
Clean up dict integration test
har777 May 14, 2024
7d9815e
Clean up dict integration test
har777 May 14, 2024
03910ad
Treat dicts in zero hints differently
har777 May 15, 2024
4decc42
Remove accidental newline
har777 May 15, 2024
e04aa54
Merge defaultdictnew_hint + update implementation
har777 May 15, 2024
cefb3f5
Fix ignoring err message
har777 May 15, 2024
efc3129
Remove some unnecessary code from tests
har777 May 15, 2024
9f08d23
Merge defaultread_hint + update implementation
har777 May 15, 2024
159674a
Merge main
har777 May 16, 2024
ff554ad
Fix typo
har777 May 16, 2024
ee997e4
Fix typo
har777 May 16, 2024
00494b4
Improved comments
har777 May 16, 2024
0088b58
Merge main
har777 May 16, 2024
2bccb0e
Add credit comment
har777 May 16, 2024
ca98db3
Merge dictread_hint
har777 May 16, 2024
faeacd5
Add better comments
har777 May 16, 2024
216ed2d
Merge main
har777 May 20, 2024
6f4bc71
Fix freeOffset bug + add tests
har777 May 20, 2024
3adece9
Fix tests
har777 May 21, 2024
ed75416
Add and use zeroDictInScopeEquals test util
har777 May 21, 2024
5911929
Merge branch 'main' into dictwrite_hint
TAdev0 Jun 5, 2024
0c1b794
Merge branch 'main' into dictwrite_hint
MaksymMalicki Jun 7, 2024
951eda1
Fix method usage
har777 Jun 7, 2024
58e12a7
Comment out integration test
har777 Jun 7, 2024
ce43643
Enable dict integration test
har777 Jun 7, 2024
54ed888
Merge branch 'main' into dictwrite_hint
har777 Jun 9, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
39 changes: 38 additions & 1 deletion integration_tests/cairo_files/dict.cairo
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
// inspired from the dict.cairo integration test in the lambdaclass cairo-vm codebase

from starkware.cairo.common.default_dict import default_dict_new
from starkware.cairo.common.dict import dict_read
from starkware.cairo.common.dict import dict_read, dict_write
from starkware.cairo.common.dict_access import DictAccess

func test_default_dict() {
Expand All @@ -24,9 +24,46 @@ func test_read() {
return ();
}

func test_write() {
alloc_locals;
let (local my_dict: DictAccess*) = default_dict_new(123);

let (local val1: felt) = dict_read{dict_ptr=my_dict}(key=1);
assert val1 = 123;

let (local val2: felt) = dict_read{dict_ptr=my_dict}(key=2);
assert val2 = 123;

dict_write{dict_ptr=my_dict}(key=1, new_value=512);
let (local val3: felt) = dict_read{dict_ptr=my_dict}(key=1);
assert val3 = 512;

let (local val4: felt) = dict_read{dict_ptr=my_dict}(key=2);
assert val4 = 123;

dict_write{dict_ptr=my_dict}(key=1, new_value=1024);
let (local val5: felt) = dict_read{dict_ptr=my_dict}(key=1);
assert val5 = 1024;

let (local val6: felt) = dict_read{dict_ptr=my_dict}(key=2);
assert val6 = 123;

dict_write{dict_ptr=my_dict}(key=1, new_value=888);
dict_write{dict_ptr=my_dict}(key=2, new_value=999);
let (local val7: felt) = dict_read{dict_ptr=my_dict}(key=1);
assert val7 = 888;
let (local val8: felt) = dict_read{dict_ptr=my_dict}(key=2);
assert val8 = 999;
let (local val9: felt) = dict_read{dict_ptr=my_dict}(key=3);
assert val9 = 123;

return ();
}

func main() {
test_default_dict();
test_read();
test_write();

return ();
}
32 changes: 17 additions & 15 deletions pkg/hintrunner/hinter/zero_dict.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,33 +10,33 @@ import (

// Used to keep track of all dictionaries data
type ZeroDictionary struct {
// The data contained on a dictionary
data map[f.Element]mem.MemoryValue
// The Data contained on a dictionary
Data map[f.Element]mem.MemoryValue
// Default value for key not present in the dictionary
defaultValue mem.MemoryValue
DefaultValue mem.MemoryValue
// first free offset in memory segment of dictionary
freeOffset uint64
FreeOffset *uint64
}

// Gets the memory value at certain key
func (d *ZeroDictionary) At(key f.Element) (mem.MemoryValue, error) {
if value, ok := d.data[key]; ok {
if value, ok := d.Data[key]; ok {
return value, nil
}
if d.defaultValue != mem.UnknownValue {
return d.defaultValue, nil
if d.DefaultValue != mem.UnknownValue {
return d.DefaultValue, nil
}
return mem.UnknownValue, fmt.Errorf("no value for key: %v", key)
}

// Given a key and a value, it sets the value at the given key
func (d *ZeroDictionary) Set(key f.Element, value mem.MemoryValue) {
d.data[key] = value
d.Data[key] = value
}

// Given a incrementBy value, it increments the freeOffset field of dictionary by it
func (d *ZeroDictionary) IncrementFreeOffset(freeOffset uint64) {
d.freeOffset += freeOffset
*d.FreeOffset += freeOffset
}

// Used to manage dictionaries creation
Expand All @@ -56,10 +56,11 @@ func NewZeroDictionaryManager() ZeroDictionaryManager {
// to the start of this segment
func (dm *ZeroDictionaryManager) NewDictionary(vm *VM.VirtualMachine) mem.MemoryAddress {
newDictAddr := vm.Memory.AllocateEmptySegment()
freeOffset := uint64(0)
dm.dictionaries[newDictAddr.SegmentIndex] = ZeroDictionary{
data: make(map[f.Element]mem.MemoryValue),
defaultValue: mem.UnknownValue,
freeOffset: 0,
Data: make(map[f.Element]mem.MemoryValue),
DefaultValue: mem.UnknownValue,
FreeOffset: &freeOffset,
}
return newDictAddr
}
Expand All @@ -70,10 +71,11 @@ func (dm *ZeroDictionaryManager) NewDictionary(vm *VM.VirtualMachine) mem.Memory
// querying the defaultValue will be returned instead.
func (dm *ZeroDictionaryManager) NewDefaultDictionary(vm *VM.VirtualMachine, defaultValue mem.MemoryValue) mem.MemoryAddress {
newDefaultDictAddr := vm.Memory.AllocateEmptySegment()
freeOffset := uint64(0)
dm.dictionaries[newDefaultDictAddr.SegmentIndex] = ZeroDictionary{
data: make(map[f.Element]mem.MemoryValue),
defaultValue: defaultValue,
freeOffset: 0,
Data: make(map[f.Element]mem.MemoryValue),
DefaultValue: defaultValue,
FreeOffset: &freeOffset,
}
return newDefaultDictAddr
}
Expand Down
1 change: 1 addition & 0 deletions pkg/hintrunner/zero/hintcode.go
Original file line number Diff line number Diff line change
Expand Up @@ -123,6 +123,7 @@ const (
// ------ Dictionaries hints related code ------
defaultDictNewCode string = "if '__dict_manager' not in globals():\n from starkware.cairo.common.dict import DictManager\n __dict_manager = DictManager()\n\nmemory[ap] = __dict_manager.new_default_dict(segments, ids.default_value)"
dictReadCode string = "dict_tracker = __dict_manager.get_tracker(ids.dict_ptr)\ndict_tracker.current_ptr += ids.DictAccess.SIZE\nids.value = dict_tracker.data[ids.key]"
dictWriteCode string = "dict_tracker = __dict_manager.get_tracker(ids.dict_ptr)\ndict_tracker.current_ptr += ids.DictAccess.SIZE\nids.dict_ptr.prev_value = dict_tracker.data[ids.key]\ndict_tracker.data[ids.key] = ids.new_value"
squashDictInnerAssertLenKeys string = "assert len(keys) == 0"
squashDictInnerContinueLoop string = "ids.loop_temps.should_continue = 1 if current_access_indices else 0"
squashDictInnerSkipLoop string = "ids.should_skip_loop = 0 if current_access_indices else 1"
Expand Down
2 changes: 2 additions & 0 deletions pkg/hintrunner/zero/zerohint.go
Original file line number Diff line number Diff line change
Expand Up @@ -171,6 +171,8 @@ func GetHintFromCode(program *zero.ZeroProgram, rawHint zero.Hint, hintPC uint64
return createDefaultDictNewHinter(resolver)
case dictReadCode:
return createDictReadHinter(resolver)
case dictWriteCode:
return createDictWriteHinter(resolver)
case squashDictInnerAssertLenKeys:
return createSquashDictInnerAssertLenKeysHinter()
case squashDictInnerContinueLoop:
Expand Down
80 changes: 80 additions & 0 deletions pkg/hintrunner/zero/zerohint_dictionaries.go
Original file line number Diff line number Diff line change
Expand Up @@ -125,6 +125,86 @@
return newDictReadHint(dictPtr, key, value), nil
}

// DictWrite hint writes a value for a given key in a dictionary
// and writes to memory the previous value for the key in the dictionary
//
// `newDictWriteHint` takes 3 operanders as argument
// - `dictPtr` variable will be pointer to the dictionary to update
// - `key` variable will be the key whose value is updated in the dictionary
// - `newValue` variable will be the new value for given key in the dictionary
func newDictWriteHint(dictPtr, key, newValue hinter.ResOperander) hinter.Hinter {
return &GenericZeroHinter{
Name: "DictWrite",
Op: func(vm *VM.VirtualMachine, ctx *hinter.HintRunnerContext) error {
//> dict_tracker = __dict_manager.get_tracker(ids.dict_ptr)
//> dict_tracker.current_ptr += ids.DictAccess.SIZE
//> ids.dict_ptr.prev_value = dict_tracker.data[ids.key]
//> dict_tracker.data[ids.key] = ids.new_value

//> dict_tracker = __dict_manager.get_tracker(ids.dict_ptr)
dictPtr, err := hinter.ResolveAsAddress(vm, dictPtr)
if err != nil {
return err
}
dictionaryManager, ok := ctx.ScopeManager.GetZeroDictionaryManager()
if !ok {
return fmt.Errorf("__dict_manager not in scope")
}

//> dict_tracker.current_ptr += ids.DictAccess.SIZE
err = dictionaryManager.IncrementFreeOffset(*dictPtr, 3)
har777 marked this conversation as resolved.
Show resolved Hide resolved
if err != nil {
return err
}

key, err := hinter.ResolveAsFelt(vm, key)
if err != nil {
return err
}

//> ids.dict_ptr.prev_value = dict_tracker.data[ids.key]
//> # dict_ptr points to a DictAccess
//> struct DictAccess {
//> key: felt,
//> prev_value: felt,
//> new_value: felt,
//> }
prevKeyValue, err := dictionaryManager.At(*dictPtr, *key)
if err != nil {
return err
}
err = hinter.WriteToNthStructField(vm, *dictPtr, prevKeyValue, 1)

Check failure on line 176 in pkg/hintrunner/zero/zerohint_dictionaries.go

View workflow job for this annotation

GitHub Actions / lint

undefined: hinter.WriteToNthStructField

Check failure on line 176 in pkg/hintrunner/zero/zerohint_dictionaries.go

View workflow job for this annotation

GitHub Actions / lint

undefined: hinter.WriteToNthStructField

Check failure on line 176 in pkg/hintrunner/zero/zerohint_dictionaries.go

View workflow job for this annotation

GitHub Actions / lint

undefined: hinter.WriteToNthStructField

Check failure on line 176 in pkg/hintrunner/zero/zerohint_dictionaries.go

View workflow job for this annotation

GitHub Actions / build

undefined: hinter.WriteToNthStructField
if err != nil {
return err
}

//> dict_tracker.data[ids.key] = ids.new_value
newValue, err := hinter.ResolveAsFelt(vm, newValue)
if err != nil {
return err
}
newValueMv := memory.MemoryValueFromFieldElement(newValue)
return dictionaryManager.Set(*dictPtr, *key, newValueMv)
},
}
}

func createDictWriteHinter(resolver hintReferenceResolver) (hinter.Hinter, error) {
dictPtr, err := resolver.GetResOperander("dict_ptr")
if err != nil {
return nil, err
}
key, err := resolver.GetResOperander("key")
if err != nil {
return nil, err
}
newValue, err := resolver.GetResOperander("new_value")
if err != nil {
return nil, err
}
return newDictWriteHint(dictPtr, key, newValue), nil
}

// SquashDictInnerAssertLenKeys hint asserts that the length
// of the `keys` descending list is zero during the squashing process
//
Expand Down
43 changes: 42 additions & 1 deletion pkg/hintrunner/zero/zerohint_dictionaries_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,48 @@ func TestZeroHintDictionaries(t *testing.T) {
dictionaryManager.NewDefaultDictionary(ctx.vm, defaultValueMv)
return newDictReadHint(ctx.operanders["dict_ptr"], ctx.operanders["key"], ctx.operanders["value"])
},
check: varValueEquals("value", feltUint64(12345)),
check: func(t *testing.T, ctx *hintTestContext) {
varValueEquals("value", feltUint64(12345))(t, ctx)

dictPtr := addrWithSegment(2, 0)
expectedData := map[fp.Element]memory.MemoryValue{}
expectedDefaultValue := memory.MemoryValueFromInt(12345)
expectedFreeOffset := uint64(3)
zeroDictInScopeEquals(*dictPtr, expectedData, expectedDefaultValue, expectedFreeOffset)(t, ctx)
},
},
},
"DictWrite": {
{
operanders: []*hintOperander{
{Name: "key", Kind: apRelative, Value: feltUint64(100)},
{Name: "new_value", Kind: apRelative, Value: feltUint64(9999)},
{Name: "dict_ptr", Kind: apRelative, Value: addrWithSegment(2, 0)},
{Name: "dict_ptr.prev_value", Kind: apRelative, Value: addrWithSegment(2, 1)},
},
makeHinter: func(ctx *hintTestContext) hinter.Hinter {
dictionaryManager := hinter.NewZeroDictionaryManager()
err := ctx.runnerContext.ScopeManager.AssignVariable("__dict_manager", dictionaryManager)
if err != nil {
t.Fatal(err)
}
defaultValueMv := memory.MemoryValueFromInt(12345)
dictionaryManager.NewDefaultDictionary(ctx.vm, defaultValueMv)
return newDictWriteHint(ctx.operanders["dict_ptr"], ctx.operanders["key"], ctx.operanders["new_value"])
},
check: func(t *testing.T, ctx *hintTestContext) {
consecutiveVarAddrResolvedValueEquals(
"dict_ptr.prev_value",
[]*fp.Element{
feltString("12345"),
})(t, ctx)

dictPtr := addrWithSegment(2, 0)
expectedData := map[fp.Element]memory.MemoryValue{*feltUint64(100): memory.MemoryValueFromInt(9999)}
expectedDefaultValue := memory.MemoryValueFromInt(12345)
expectedFreeOffset := uint64(3)
zeroDictInScopeEquals(*dictPtr, expectedData, expectedDefaultValue, expectedFreeOffset)(t, ctx)
},
},
},
"SquashDictInnerAssertLenKeys": {
Expand Down
1 change: 0 additions & 1 deletion pkg/hintrunner/zero/zerohint_others_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@ import (
)

func TestZeroHintMemcpy(t *testing.T) {

runHinterTests(t, map[string][]hintTestCase{
"MemcpyContinueCopying": {
{
Expand Down
18 changes: 18 additions & 0 deletions pkg/hintrunner/zero/zerohint_utils_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ import (
"github.com/NethermindEth/cairo-vm-go/pkg/vm"
"github.com/NethermindEth/cairo-vm-go/pkg/vm/memory"
"github.com/consensys/gnark-crypto/ecc/stark-curve/fp"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)

Expand Down Expand Up @@ -225,3 +226,20 @@ func varListInScopeEquals(expectedValues map[string]any) func(t *testing.T, ctx
}
}
}

func zeroDictInScopeEquals(dictAddress memory.MemoryAddress, expectedData map[fp.Element]memory.MemoryValue, expectedDefaultValue memory.MemoryValue, expectedFreeOffset uint64) func(t *testing.T, ctx *hintTestContext) {
return func(t *testing.T, ctx *hintTestContext) {
dictionaryManager, ok := ctx.runnerContext.ScopeManager.GetZeroDictionaryManager()
if !ok {
t.Fatal("failed to fetch dictionary manager")
}
dictionary, err := dictionaryManager.GetDictionary(dictAddress)
if err != nil {
t.Fatal(err)
}

assert.Equal(t, expectedData, dictionary.Data)
assert.Equal(t, expectedDefaultValue, dictionary.DefaultValue)
assert.Equal(t, expectedFreeOffset, *dictionary.FreeOffset)
}
}
Loading