Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: usort hints integration tests #459

Merged
merged 17 commits into from
Jun 28, 2024
Merged
Show file tree
Hide file tree
Changes from 16 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
59 changes: 59 additions & 0 deletions integration_tests/cairo_zero_hint_tests/usort.small.cairo
Original file line number Diff line number Diff line change
@@ -0,0 +1,59 @@
%builtins range_check
from starkware.cairo.common.usort import usort
from starkware.cairo.common.alloc import alloc

func main{range_check_ptr}() -> () {
alloc_locals;
let (input_array: felt*) = alloc();
assert input_array[0] = 8;
assert input_array[1] = 9;
assert input_array[2] = 7;

let (output_len, output, multiplicities) = usort(input_len=3, input=input_array);

assert output_len = 3;
assert output[0] = 7;
assert output[1] = 8;
assert output[2] = 9;

assert multiplicities[0] = 1;
assert multiplicities[1] = 1;
assert multiplicities[2] = 1;

let (input_array: felt*) = alloc();
assert input_array[0] = 11;
assert input_array[1] = 24;
assert input_array[2] = 99;
assert input_array[3] = 2;
assert input_array[4] = 66;
assert input_array[5] = 49;
assert input_array[6] = 11;
assert input_array[7] = 23;
assert input_array[8] = 88;
assert input_array[9] = 7;

let (output_len, output, multiplicities) = usort(input_len=10, input=input_array);

assert output_len = 9;
assert output[0] = 2;
assert output[1] = 7;
assert output[2] = 11;
assert output[3] = 23;
assert output[4] = 24;
assert output[5] = 49;
assert output[6] = 66;
assert output[7] = 88;
assert output[8] = 99;

assert multiplicities[0] = 1;
assert multiplicities[1] = 1;
assert multiplicities[2] = 2;
assert multiplicities[3] = 1;
assert multiplicities[4] = 1;
assert multiplicities[5] = 1;
assert multiplicities[6] = 1;
assert multiplicities[7] = 1;
assert multiplicities[8] = 1;

return ();
}
39 changes: 19 additions & 20 deletions pkg/hintrunner/zero/hintcode.go
Original file line number Diff line number Diff line change
Expand Up @@ -60,26 +60,25 @@ const (
uint256MulDivModCode string = "a = (ids.a.high << 128) + ids.a.low\nb = (ids.b.high << 128) + ids.b.low\ndiv = (ids.div.high << 128) + ids.div.low\nquotient, remainder = divmod(a * b, div)\n\nids.quotient_low.low = quotient & ((1 << 128) - 1)\nids.quotient_low.high = (quotient >> 128) & ((1 << 128) - 1)\nids.quotient_high.low = (quotient >> 256) & ((1 << 128) - 1)\nids.quotient_high.high = quotient >> 384\nids.remainder.low = remainder & ((1 << 128) - 1)\nids.remainder.high = remainder >> 128"

// ------ Usort hints related code ------
usortBodyCode string = `
from collections import defaultdict

input_ptr = ids.input
input_len = int(ids.input_len)
if __usort_max_size is not None:
assert input_len <= __usort_max_size, (
f"usort() can only be used with input_len<={__usort_max_size}. "
f"Got: input_len={input_len}."
)

positions_dict = defaultdict(list)
for i in range(input_len):
val = memory[input_ptr + i]
positions_dict[val].append(i)

output = sorted(positions_dict.keys())
ids.output_len = len(output)
ids.output = segments.gen_arg(output)
ids.multiplicities = segments.gen_arg([len(positions_dict[k]) for k in output])`
usortBodyCode string = `from collections import defaultdict

input_ptr = ids.input
input_len = int(ids.input_len)
if __usort_max_size is not None:
assert input_len <= __usort_max_size, (
f"usort() can only be used with input_len<={__usort_max_size}. "
f"Got: input_len={input_len}."
)

positions_dict = defaultdict(list)
for i in range(input_len):
val = memory[input_ptr + i]
positions_dict[val].append(i)

output = sorted(positions_dict.keys())
ids.output_len = len(output)
ids.output = segments.gen_arg(output)
ids.multiplicities = segments.gen_arg([len(positions_dict[k]) for k in output])`
usortEnterScopeCode string = "vm_enter_scope(dict(__usort_max_size = globals().get('__usort_max_size')))"
usortVerifyMultiplicityAssertCode string = "assert len(positions) == 0"
usortVerifyCode string = "last_pos = 0\npositions = positions_dict[ids.value][::-1]"
Expand Down
191 changes: 96 additions & 95 deletions pkg/hintrunner/zero/zerohint_usort.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,33 +12,30 @@ import (
"github.com/consensys/gnark-crypto/ecc/stark-curve/fp"
)

func createUsortBodyHinter(resolver hintReferenceResolver) (hinter.Hinter, error) {
input, err := resolver.GetResOperander("input")
if err != nil {
return nil, err
}

input_len, err := resolver.GetResOperander("input_len")
if err != nil {
return nil, err
}

output, err := resolver.GetResOperander("output")
if err != nil {
return nil, err
}
// UsortEnterScope hint enters a new scope with `__usort_max_size` value
//
// `newUsortEnterScopeHint` doesn't take any operander as argument
//
// `newUsortEnterScopeHint` gets `__usort_max_size` value from the current
// scope and enters a new scope with this same value
func newUsortEnterScopeHint() hinter.Hinter {
Sh0g0-1758 marked this conversation as resolved.
Show resolved Hide resolved
return &GenericZeroHinter{
Name: "UsortEnterScope",
Op: func(vm *VM.VirtualMachine, ctx *hinter.HintRunnerContext) error {
//> vm_enter_scope(dict(__usort_max_size = globals().get('__usort_max_size')))
usortMaxSize := uint64(1 << 20)

output_len, err := resolver.GetResOperander("output_len")
if err != nil {
return nil, err
}
ctx.ScopeManager.EnterScope(map[string]any{
"__usort_max_size": usortMaxSize,
})

multiplicities, err := resolver.GetResOperander("multiplicities")
if err != nil {
return nil, err
return nil
},
}
}

return newUsortBodyHint(input, input_len, output, output_len, multiplicities), nil
func createUsortEnterScopeHinter() (hinter.Hinter, error) {
return newUsortEnterScopeHint(), nil
}

// UsortBody hint sorts the input array of field elements. The sorting results in generation of output array without duplicates and multiplicites array, where each element represents the number of times the corresponding element in the output array appears in the input array. The output and multiplicities arrays are written to the new, separate segments in memory.
Expand Down Expand Up @@ -109,6 +106,11 @@ func newUsortBodyHint(input, inputLen, output, outputLen, multiplicities hinter.
}
}

err = ctx.ScopeManager.AssignVariable("positions_dict", positionsDict)
if err != nil {
return err
}

outputArray := make([]fp.Element, len(positionsDict))
iterator := 0
for key := range positionsDict {
Expand Down Expand Up @@ -189,70 +191,33 @@ func newUsortBodyHint(input, inputLen, output, outputLen, multiplicities hinter.
}
}

// UsortEnterScope hint enters a new scope with `__usort_max_size` value
//
// `newUsortEnterScopeHint` doesn't take any operander as argument
//
// `newUsortEnterScopeHint` gets `__usort_max_size` value from the current
// scope and enters a new scope with this same value
func newUsortEnterScopeHint() hinter.Hinter {
return &GenericZeroHinter{
Name: "UsortEnterScope",
Op: func(vm *VM.VirtualMachine, ctx *hinter.HintRunnerContext) error {
//> vm_enter_scope(dict(__usort_max_size = globals().get('__usort_max_size')))
usortMaxSize, err := ctx.ScopeManager.GetVariableValue("__usort_max_size")
if err != nil {
return err
}

ctx.ScopeManager.EnterScope(map[string]any{
"__usort_max_size": usortMaxSize,
})

return nil
},
func createUsortBodyHinter(resolver hintReferenceResolver) (hinter.Hinter, error) {
input, err := resolver.GetResOperander("input")
if err != nil {
return nil, err
}
}

func createUsortEnterScopeHinter() (hinter.Hinter, error) {
return newUsortEnterScopeHint(), nil
}

// UsortVerifyMultiplicityAssert hint checks that the `positions` variable in scope
// doesn't contain any value
//
// `newUsortVerifyMultiplicityAssertHint` doesn't take any operander as argument
//
// This hint is used when sorting an array of field elements while removing duplicates
// in `usort` Cairo function
func newUsortVerifyMultiplicityAssertHint() hinter.Hinter {
return &GenericZeroHinter{
Name: "UsortVerifyMultiplicityAssert",
Op: func(vm *VM.VirtualMachine, ctx *hinter.HintRunnerContext) error {
//> assert len(positions) == 0

positionsInterface, err := ctx.ScopeManager.GetVariableValue("positions")

if err != nil {
return err
}
input_len, err := resolver.GetResOperander("input_len")
if err != nil {
return nil, err
}

positions, ok := positionsInterface.([]uint64)
if !ok {
return fmt.Errorf("casting positions into an array failed")
}
output, err := resolver.GetResOperander("output")
if err != nil {
return nil, err
}

if len(positions) != 0 {
return fmt.Errorf("assertion `len(positions) == 0` failed")
}
output_len, err := resolver.GetResOperander("output_len")
if err != nil {
return nil, err
}

return nil
},
multiplicities, err := resolver.GetResOperander("multiplicities")
if err != nil {
return nil, err
}
}

func createUsortVerifyMultiplicityAssertHinter() (hinter.Hinter, error) {
return newUsortEnterScopeHint(), nil
return newUsortBodyHint(input, input_len, output, output_len, multiplicities), nil
}

// UsortVerify hint prepares for verifying the presence of duplicates of
Expand All @@ -273,28 +238,29 @@ func newUsortVerifyHint(value hinter.ResOperander) hinter.Hinter {
//> positions = positions_dict[ids.value][::-1]

positionsDictInterface, err := ctx.ScopeManager.GetVariableValue("positions_dict")

if err != nil {
return err
}

positionsDict, ok := positionsDictInterface.(map[fp.Element][]uint64)

if !ok {
return fmt.Errorf("casting positions_dict into an dictionary failed")
}

value, err := hinter.ResolveAsFelt(vm, value)

if err != nil {
return err
}

positions := positionsDict[*value]
positionsToCopy := positionsDict[*value]

positions := make([]uint64, len(positionsToCopy))
copy(positions, positionsToCopy)

utils.Reverse(positions)

return ctx.ScopeManager.AssignVariables(map[string]any{
"last_pos": 0,
"last_pos": uint64(0),
"positions": positions,
})
},
Expand Down Expand Up @@ -333,9 +299,9 @@ func newUsortVerifyMultiplicityBodyHint(nextItemIndex hinter.ResOperander) hinte
return err
}

positions, ok := positionsInterface.([]fp.Element)
positions, ok := positionsInterface.([]uint64)
if !ok {
return fmt.Errorf("cannot cast positionsInterface to []fp.Element")
return fmt.Errorf("cannot cast positionsInterface to []uint64")
}

currentPos, err := utils.Pop(&positions)
Expand All @@ -348,33 +314,32 @@ func newUsortVerifyMultiplicityBodyHint(nextItemIndex hinter.ResOperander) hinte
return err
}

lastPos, err := ctx.ScopeManager.GetVariableValue("last_pos")
lastPosition, err := ctx.ScopeManager.GetVariableValue("last_pos")
if err != nil {
return err
}

lastPosFelt, ok := lastPos.(fp.Element)
lastPos, ok := lastPosition.(uint64)
if !ok {
return fmt.Errorf("cannot cast last_pos to felt")
return fmt.Errorf("cannot cast last_pos to uint64")
}

// Calculate `next_item_index` memory value
var newNextItemIndexValue fp.Element
newNextItemIndexValue.Sub(&currentPos, &lastPosFelt)
newNextItemIndexMemoryValue := memory.MemoryValueFromFieldElement(&newNextItemIndexValue)
newNextItemIndexValue := currentPos - lastPos
newNextItemIndexMemoryValue := memory.MemoryValueFromUint(newNextItemIndexValue)

// Save `next_item_index` value in address
addrNextItemIndex, err := nextItemIndex.GetAddress(vm)
if err != nil {
return err
}

err = vm.Memory.WriteToAddress(&addrNextItemIndex, &newNextItemIndexMemoryValue)
err = ctx.ScopeManager.AssignVariable("last_pos", currentPos+1)
if err != nil {
return err
}

return ctx.ScopeManager.AssignVariable("last_pos", *currentPos.Add(&currentPos, &utils.FeltOne))
return vm.Memory.WriteToAddress(&addrNextItemIndex, &newNextItemIndexMemoryValue)
},
}
}
Expand All @@ -387,3 +352,39 @@ func createUsortVerifyMultiplicityBodyHinter(resolver hintReferenceResolver) (hi

return newUsortVerifyMultiplicityBodyHint(nextItemIndex), nil
}

// UsortVerifyMultiplicityAssert hint checks that the `positions` variable in scope
// doesn't contain any value
//
// `newUsortVerifyMultiplicityAssertHint` doesn't take any operander as argument
//
// This hint is used when sorting an array of field elements while removing duplicates
// in `usort` Cairo function
func newUsortVerifyMultiplicityAssertHint() hinter.Hinter {
return &GenericZeroHinter{
Name: "UsortVerifyMultiplicityAssert",
Op: func(vm *VM.VirtualMachine, ctx *hinter.HintRunnerContext) error {
//> assert len(positions) == 0

positionsInterface, err := ctx.ScopeManager.GetVariableValue("positions")
if err != nil {
return err
}

positions, ok := positionsInterface.([]uint64)
if !ok {
return fmt.Errorf("casting positions into a []uint64 failed")
}

if len(positions) != 0 {
return fmt.Errorf("assertion `len(positions) == 0` failed")
}

return nil
},
}
}

func createUsortVerifyMultiplicityAssertHinter() (hinter.Hinter, error) {
return newUsortVerifyMultiplicityAssertHint(), nil
}
Loading
Loading