Skip to content

Commit

Permalink
Improve ZeroDictionary struct (#475)
Browse files Browse the repository at this point in the history
* improve zerodict pointers

* fmt

* fmt

---------

Co-authored-by: Shourya Goel <shouryagoel10000@gmail.com>
  • Loading branch information
TAdev0 and Sh0g0-1758 committed Jun 28, 2024
1 parent dd54197 commit d56008d
Show file tree
Hide file tree
Showing 7 changed files with 24 additions and 23 deletions.
21 changes: 11 additions & 10 deletions pkg/hintrunner/hinter/zero_dict.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,27 +11,27 @@ import (
// Used to keep track of all Dictionaries data
type ZeroDictionary struct {
// The Data contained in a dictionary
Data map[fp.Element]mem.MemoryValue
Data *map[fp.Element]mem.MemoryValue
// Default value for key not present in the dictionary
DefaultValue mem.MemoryValue
DefaultValue *mem.MemoryValue
// first free offset in memory segment of dictionary
FreeOffset *uint64
}

// Gets the memory value at certain key
func (d *ZeroDictionary) at(key fp.Element) (mem.MemoryValue, error) {
if value, ok := d.Data[key]; ok {
if value, ok := (*d.Data)[key]; ok {
return value, nil
}
if d.DefaultValue != mem.UnknownValue {
return d.DefaultValue, nil
if *d.DefaultValue != mem.UnknownValue {
return *d.DefaultValue, nil
}
return mem.UnknownValue, fmt.Errorf("no value for key: %v", key)
}

// Given a key and a value, it sets the value at the given key
func (d *ZeroDictionary) set(key fp.Element, value mem.MemoryValue) {
d.Data[key] = value
(*d.Data)[key] = value
}

// Given a incrementBy value, it increments the freeOffset field of dictionary by it
Expand Down Expand Up @@ -63,8 +63,8 @@ func (dm *ZeroDictionaryManager) NewDictionary(vm *VM.VirtualMachine, data map[f
newDictAddr := vm.Memory.AllocateEmptySegment()
freeOffset := uint64(0)
dm.Dictionaries[newDictAddr.SegmentIndex] = ZeroDictionary{
Data: data,
DefaultValue: mem.UnknownValue,
Data: &data,
DefaultValue: &mem.UnknownValue,
FreeOffset: &freeOffset,
}
return newDictAddr
Expand All @@ -76,10 +76,11 @@ func (dm *ZeroDictionaryManager) NewDictionary(vm *VM.VirtualMachine, data map[f
// querying the defaultValue will be returned instead.
func (dm *ZeroDictionaryManager) NewDefaultDictionary(vm *VM.VirtualMachine, defaultValue mem.MemoryValue) mem.MemoryAddress {
newDefaultDictAddr := vm.Memory.AllocateEmptySegment()
newData := make(map[fp.Element]mem.MemoryValue)
freeOffset := uint64(0)
dm.Dictionaries[newDefaultDictAddr.SegmentIndex] = ZeroDictionary{
Data: make(map[fp.Element]mem.MemoryValue),
DefaultValue: defaultValue,
Data: &newData,
DefaultValue: &defaultValue,
FreeOffset: &freeOffset,
}
return newDefaultDictAddr
Expand Down
8 changes: 4 additions & 4 deletions pkg/hintrunner/zero/hintparser.go
Original file line number Diff line number Diff line change
Expand Up @@ -215,7 +215,7 @@ func (expression OffsetExp) Evaluate() (*int, error) {
negNumber := -*expression.NegNumber
return &negNumber, nil
default:
return nil, fmt.Errorf("Expected a number")
return nil, fmt.Errorf("expected a number")
}
}

Expand All @@ -226,7 +226,7 @@ func (expression DerefExp) Evaluate() (any, error) {
}
cellRef, ok := cellRefExp.(hinter.CellRefer)
if !ok {
return nil, fmt.Errorf("Expected a CellRefer expression but got %s", cellRefExp)
return nil, fmt.Errorf("expected a CellRefer expression but got %s", cellRefExp)
}
return hinter.Deref{Deref: cellRef}, nil
}
Expand Down Expand Up @@ -301,7 +301,7 @@ func (expression LeftExp) Evaluate() (any, error) {
case expression.DerefExp != nil:
return expression.DerefExp.Evaluate()
}
return nil, fmt.Errorf("Unexpected left expression in binary operation")
return nil, fmt.Errorf("unexpected left expression in binary operation")
}

func (expression RightExp) Evaluate() (any, error) {
Expand All @@ -311,7 +311,7 @@ func (expression RightExp) Evaluate() (any, error) {
case expression.Offset != nil:
return expression.Offset.Evaluate()
}
return nil, fmt.Errorf("Unexpected right expression in binary operation")
return nil, fmt.Errorf("unexpected right expression in binary operation")
}

func ParseIdentifier(value string) (hinter.Reference, error) {
Expand Down
6 changes: 3 additions & 3 deletions pkg/hintrunner/zero/zerohint_dictionaries.go
Original file line number Diff line number Diff line change
Expand Up @@ -230,7 +230,7 @@ func newDictSquashCopyDictHint(dictAccessesEnd hinter.ResOperander) hinter.Hinte
}

dictionaryDataCopy := make(map[fp.Element]memory.MemoryValue)
for k, v := range dictionary.Data {
for k, v := range *dictionary.Data {
// Copy the key
keyCopy := fp.Element{}
keyCopy.Set(&k)
Expand Down Expand Up @@ -490,7 +490,7 @@ func newSquashDictHint(dictAccesses, ptrDiff, nAccesses, bigKeys, firstKey hinte
return err
}
if ptrDiffValue%DictAccessSize != 0 {
return fmt.Errorf("Accesses array size must be divisible by DictAccess.SIZE")
return fmt.Errorf("accesses array size must be divisible by DictAccess.SIZE")
}

//> n_accesses = ids.n_accesses
Expand All @@ -506,7 +506,7 @@ func newSquashDictHint(dictAccesses, ptrDiff, nAccesses, bigKeys, firstKey hinte
// __squash_dict_max_size is always in scope and has a value of 2**20,
squashDictMaxSize := uint64(1048576)
if nAccessesValue > squashDictMaxSize {
return fmt.Errorf("squash_dict() can only be used with n_accesses<={%d}. Got: n_accesses={%d}.", squashDictMaxSize, nAccessesValue)
return fmt.Errorf("squash_dict() can only be used with n_accesses<={%d}. Got: n_accesses={%d}", squashDictMaxSize, nAccessesValue)
}

//> # A map from key to the list of indices accessing it.
Expand Down
4 changes: 2 additions & 2 deletions pkg/hintrunner/zero/zerohint_dictionaries_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -599,7 +599,7 @@ func TestZeroHintDictionaries(t *testing.T) {
ctx.operanders["first_key"],
)
},
errCheck: errorTextContains("Accesses array size must be divisible by DictAccess.SIZE"),
errCheck: errorTextContains("accesses array size must be divisible by DictAccess.SIZE"),
},
{
operanders: []*hintOperander{
Expand All @@ -625,7 +625,7 @@ func TestZeroHintDictionaries(t *testing.T) {
ctx.operanders["first_key"],
)
},
errCheck: errorTextContains("squash_dict() can only be used with n_accesses<={1048576}. Got: n_accesses={1048577}."),
errCheck: errorTextContains("squash_dict() can only be used with n_accesses<={1048576}. Got: n_accesses={1048577}"),
},
{
operanders: []*hintOperander{
Expand Down
2 changes: 1 addition & 1 deletion pkg/hintrunner/zero/zerohint_keccak.go
Original file line number Diff line number Diff line change
Expand Up @@ -109,7 +109,7 @@ func newUnsafeKeccakHint(data, length, high, low hinter.ResOperander) hinter.Hin
//> f'Got: length={length}.'
keccakMaxSize := uint64(1 << 20)
if lengthVal > keccakMaxSize {
return fmt.Errorf("unsafe_keccak() can only be used with length<=%d.\n Got: length=%d.", keccakMaxSize, lengthVal)
return fmt.Errorf("unsafe_keccak() can only be used with length<=%d.\n Got: length=%d", keccakMaxSize, lengthVal)
}

dataPtr, err := hinter.ResolveAsAddress(vm, data)
Expand Down
2 changes: 1 addition & 1 deletion pkg/hintrunner/zero/zerohint_keccak_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,7 @@ func TestZeroHintKeccak(t *testing.T) {
makeHinter: func(ctx *hintTestContext) hinter.Hinter {
return newUnsafeKeccakHint(ctx.operanders["data"], ctx.operanders["length"], ctx.operanders["high"], ctx.operanders["low"])
},
errCheck: errorTextContains(fmt.Sprintf("unsafe_keccak() can only be used with length<=%d.\n Got: length=%d.", 1<<20, (1<<20)+1)),
errCheck: errorTextContains(fmt.Sprintf("unsafe_keccak() can only be used with length<=%d.\n Got: length=%d", 1<<20, (1<<20)+1)),
},
{
operanders: []*hintOperander{
Expand Down
4 changes: 2 additions & 2 deletions pkg/hintrunner/zero/zerohint_utils_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -297,8 +297,8 @@ func zeroDictInScopeEquals(dictAddress memory.MemoryAddress, expectedData map[fp
t.Fatal(fmt.Errorf("no dictionary at address: %s", dictAddress))
}

assert.Equal(t, expectedData, dictionary.Data)
assert.Equal(t, expectedDefaultValue, dictionary.DefaultValue)
assert.Equal(t, expectedData, *dictionary.Data)
assert.Equal(t, expectedDefaultValue, *dictionary.DefaultValue)
assert.Equal(t, expectedFreeOffset, *dictionary.FreeOffset)
}
}

0 comments on commit d56008d

Please sign in to comment.