Skip to content

Commit

Permalink
GetVariableAs generic method
Browse files Browse the repository at this point in the history
  • Loading branch information
TAdev0 committed Jul 11, 2024
1 parent 050d0e0 commit 25b18bb
Show file tree
Hide file tree
Showing 8 changed files with 116 additions and 193 deletions.
96 changes: 48 additions & 48 deletions integration_tests/BenchMarks.txt

Large diffs are not rendered by default.

29 changes: 7 additions & 22 deletions pkg/hintrunner/hinter/scope.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@ package hinter

import (
"fmt"
"math/big"
)

// ScopeManager handles all operations regarding scopes:
Expand Down Expand Up @@ -88,32 +87,18 @@ func (sm *ScopeManager) GetVariableValue(name string) (any, error) {
return nil, fmt.Errorf("variable %s not found in current scope", name)
}

func (sm *ScopeManager) GetVariableValueAsBigInt(name string) (*big.Int, error) {
// GetVariableAs retrieves a variable from the current scope and asserts its type
func GetVariableAs[T any](sm *ScopeManager, name string) (T, error) {
var zero T // Zero value of the generic type
value, err := sm.GetVariableValue(name)
if err != nil {
return nil, err
}

valueBig, ok := value.(*big.Int)
if !ok {
return nil, fmt.Errorf("value: %s is not a *big.Int", value)
}

return valueBig, nil
}

func (sm *ScopeManager) GetVariableValueAsUint64(name string) (uint64, error) {
value, err := sm.GetVariableValue(name)
if err != nil {
return 0, err
return zero, err
}

valueUint, ok := value.(uint64)
typedValue, ok := value.(T)
if !ok {
return 0, fmt.Errorf("value: %s is not a uint64", value)
return zero, fmt.Errorf("value has a different type")
}

return valueUint, nil
return typedValue, nil
}

func (sm *ScopeManager) getCurrentScope() (*map[string]any, error) {
Expand Down
71 changes: 12 additions & 59 deletions pkg/hintrunner/zero/zerohint_dictionaries.go
Original file line number Diff line number Diff line change
Expand Up @@ -48,14 +48,10 @@ func newDictNewHint() hinter.Hinter {
}
}

initialDictValue, err := ctx.ScopeManager.GetVariableValue("initial_dict")
initialDict, err := hinter.GetVariableAs[map[fp.Element]memory.MemoryValue](&ctx.ScopeManager, "initial_dict")
if err != nil {
return err
}
initialDict, ok := initialDictValue.(map[fp.Element]memory.MemoryValue)
if !ok {
return fmt.Errorf("value: %s is not a map[f.Element]mem.MemoryValue", initialDictValue)
}

//> memory[ap] = __dict_manager.new_dict(segments, initial_dict)
newDictAddr := dictionaryManager.NewDictionary(vm, initialDict)
Expand Down Expand Up @@ -616,12 +612,11 @@ func newSquashDictInnerAssertLenKeysHint() hinter.Hinter {
Op: func(vm *VM.VirtualMachine, ctx *hinter.HintRunnerContext) error {
//> assert len(keys) == 0

keys_, err := ctx.ScopeManager.GetVariableValue("keys")
keys, err := hinter.GetVariableAs[[]fp.Element](&ctx.ScopeManager, "keys")
if err != nil {
return err
}

keys := keys_.([]fp.Element)
if len(keys) != 0 {
return fmt.Errorf("assertion `len(keys) == 0` failed")
}
Expand Down Expand Up @@ -661,16 +656,11 @@ func newSquashDictInnerCheckAccessIndexHint(loopTemps hinter.ResOperander) hinte
//> ids.loop_temps.index_delta_minus1 = new_access_index - current_access_index - 1
//> current_access_index = new_access_index

currentAccessIndices_, err := ctx.ScopeManager.GetVariableValue("current_access_indices")
currentAccessIndices, err := hinter.GetVariableAs[[]fp.Element](&ctx.ScopeManager, "current_access_indices")
if err != nil {
return err
}

currentAccessIndices, ok := currentAccessIndices_.([]fp.Element)
if !ok {
return fmt.Errorf("casting currentAccessIndices_ into an array of felts failed")
}

newAccessIndex, err := utils.Pop(&currentAccessIndices)
if err != nil {
return err
Expand All @@ -681,16 +671,11 @@ func newSquashDictInnerCheckAccessIndexHint(loopTemps hinter.ResOperander) hinte
return err
}

currentAccessIndex_, err := ctx.ScopeManager.GetVariableValue("current_access_index")
currentAccessIndex, err := hinter.GetVariableAs[fp.Element](&ctx.ScopeManager, "current_access_index")
if err != nil {
return err
}

currentAccessIndex, ok := currentAccessIndex_.(fp.Element)
if !ok {
return fmt.Errorf("casting currentAccessIndex_ into a felt failed")
}

err = ctx.ScopeManager.AssignVariable("current_access_index", newAccessIndex)
if err != nil {
return err
Expand Down Expand Up @@ -746,16 +731,11 @@ func newSquashDictInnerContinueLoopHint(loopTemps hinter.ResOperander) hinter.Hi
Op: func(vm *VM.VirtualMachine, ctx *hinter.HintRunnerContext) error {
//> ids.loop_temps.should_continue = 1 if current_access_indices else 0

currentAccessIndices_, err := ctx.ScopeManager.GetVariableValue("current_access_indices")
currentAccessIndices, err := hinter.GetVariableAs[[]fp.Element](&ctx.ScopeManager, "current_access_indices")
if err != nil {
return err
}

currentAccessIndices, ok := currentAccessIndices_.([]fp.Element)
if !ok {
return fmt.Errorf("casting currentAccessIndices_ into an array of felts failed")
}

loopTempsAddr, err := loopTemps.GetAddress(vm)
if err != nil {
return err
Expand Down Expand Up @@ -800,26 +780,16 @@ func newSquashDictInnerFirstIterationHint(rangeCheckPtr hinter.ResOperander) hin
//> current_access_index = current_access_indices.pop()
//> memory[ids.range_check_ptr] = current_access_index

key_, err := ctx.ScopeManager.GetVariableValue("key")
key, err := hinter.GetVariableAs[fp.Element](&ctx.ScopeManager, "key")
if err != nil {
return err
}

accessIndices_, err := ctx.ScopeManager.GetVariableValue("access_indices")
accessIndices, err := hinter.GetVariableAs[map[fp.Element][]fp.Element](&ctx.ScopeManager, "access_indices")
if err != nil {
return err
}

accessIndices, ok := accessIndices_.(map[fp.Element][]fp.Element)
if !ok {
return fmt.Errorf("cannot cast access_indices_ to a mapping of felts")
}

key, ok := key_.(fp.Element)
if !ok {
return fmt.Errorf("cannot cast key_ to felt")
}

accessIndicesAtKey := accessIndices[key]

accessIndicesAtKeyCopy := make([]fp.Element, len(accessIndicesAtKey))
Expand Down Expand Up @@ -879,16 +849,11 @@ func newSquashDictInnerSkipLoopHint(shouldSkipLoop hinter.ResOperander) hinter.H
Op: func(vm *VM.VirtualMachine, ctx *hinter.HintRunnerContext) error {
//> ids.should_skip_loop = 0 if current_access_indices else 1

currentAccessIndices_, err := ctx.ScopeManager.GetVariableValue("current_access_indices")
currentAccessIndices, err := hinter.GetVariableAs[[]fp.Element](&ctx.ScopeManager, "current_access_indices")
if err != nil {
return err
}

currentAccessIndices, ok := currentAccessIndices_.([]fp.Element)
if !ok {
return fmt.Errorf("casting currentAccessIndices_ into an array of felts failed")
}

shouldSkipLoopAddr, err := shouldSkipLoop.GetAddress(vm)
if err != nil {
return err
Expand Down Expand Up @@ -927,12 +892,11 @@ func newSquashDictInnerLenAssertHint() hinter.Hinter {
Op: func(vm *VM.VirtualMachine, ctx *hinter.HintRunnerContext) error {
//> assert len(current_access_indices) == 0

currentAccessIndices_, err := ctx.ScopeManager.GetVariableValue("current_access_indices")
currentAccessIndices, err := hinter.GetVariableAs[[]fp.Element](&ctx.ScopeManager, "current_access_indices")
if err != nil {
return err
}

currentAccessIndices := currentAccessIndices_.([]fp.Element)
if len(currentAccessIndices) != 0 {
return fmt.Errorf("assertion `len(current_access_indices) == 0` failed")
}
Expand All @@ -958,12 +922,11 @@ func newSquashDictInnerNextKeyHint(nextKey hinter.ResOperander) hinter.Hinter {
//> assert len(keys) > 0, 'No keys left but remaining_accesses > 0.'
//> ids.next_key = key = keys.pop()

keys_, err := ctx.ScopeManager.GetVariableValue("keys")
keys, err := hinter.GetVariableAs[[]fp.Element](&ctx.ScopeManager, "keys")
if err != nil {
return err
}

keys := keys_.([]fp.Element)
if len(keys) == 0 {
return fmt.Errorf("no keys left but remaining_accesses > 0")
}
Expand Down Expand Up @@ -1015,26 +978,16 @@ func newSquashDictInnerUsedAccessesAssertHint(nUsedAccesses hinter.ResOperander)
Op: func(vm *VM.VirtualMachine, ctx *hinter.HintRunnerContext) error {
//> assert ids.n_used_accesses == len(access_indices[key])

accessIndices_, err := ctx.ScopeManager.GetVariableValue("access_indices")
accessIndices, err := hinter.GetVariableAs[map[fp.Element][]fp.Element](&ctx.ScopeManager, "access_indices")
if err != nil {
return err
}

accessIndices, ok := accessIndices_.(map[fp.Element][]fp.Element)
if !ok {
return fmt.Errorf("cannot cast access_indices_ to a mapping of felts")
}

key_, err := ctx.ScopeManager.GetVariableValue("key")
key, err := hinter.GetVariableAs[fp.Element](&ctx.ScopeManager, "key")
if err != nil {
return err
}

key, ok := key_.(fp.Element)
if !ok {
return fmt.Errorf("cannot cast key_ to felt")
}

accessIndicesAtKeyLen := uint64(len(accessIndices[key]))

nUsedAccesses, err := hinter.ResolveAsUint64(vm, nUsedAccesses)
Expand Down
6 changes: 2 additions & 4 deletions pkg/hintrunner/zero/zerohint_dictionaries_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -45,11 +45,10 @@ func TestZeroHintDictionaries(t *testing.T) {
t.Fatalf("incorrect apValue: %s expected %s", dictAddr.String(), "2:0")
}

dictionaryManagerValue, err := ctx.runnerContext.ScopeManager.GetVariableValue("__dict_manager")
dictionaryManager, err := hinter.GetVariableAs[hinter.ZeroDictionaryManager](&ctx.runnerContext.ScopeManager, "__dict_manager")
if err != nil {
t.Fatalf("__dict_manager missing")
}
dictionaryManager := dictionaryManagerValue.(hinter.ZeroDictionaryManager)

for _, key := range []fp.Element{*feltUint64(10), *feltUint64(20), *feltUint64(30)} {
value, err := dictionaryManager.At(dictAddr, key)
Expand Down Expand Up @@ -78,12 +77,11 @@ func TestZeroHintDictionaries(t *testing.T) {
return newDefaultDictNewHint(ctx.operanders["default_value"])
},
check: func(t *testing.T, ctx *hintTestContext) {
dictionaryManagerValue, err := ctx.runnerContext.ScopeManager.GetVariableValue("__dict_manager")
dictionaryManager, err := hinter.GetVariableAs[hinter.ZeroDictionaryManager](&ctx.runnerContext.ScopeManager, "__dict_manager")
if err != nil {
t.Fatalf("__dict_manager missing")
}

dictionaryManager := dictionaryManagerValue.(hinter.ZeroDictionaryManager)
apAddr := ctx.vm.Context.AddressAp()
dictAddr, err := ctx.vm.Memory.ReadFromAddressAsAddress(&apAddr)
if err != nil {
Expand Down
Loading

0 comments on commit 25b18bb

Please sign in to comment.