Skip to content

Commit

Permalink
Unit test speed up (#536)
Browse files Browse the repository at this point in the history
* first attempt

* seems a rusty way

* more minor constraints

* config sample num

* comments, del old code

* Refactor to parameterize inputs for test_method_config (#569)

Co-authored-by: Michael Diamant <michaeldiamant@users.noreply.github.com>
  • Loading branch information
ahangsu and michaeldiamant committed Oct 17, 2022
1 parent 1a1c93f commit 5d5f4d4
Showing 1 changed file with 120 additions and 81 deletions.
201 changes: 120 additions & 81 deletions pyteal/ast/router_test.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import pyteal as pt
import secrets
from pyteal.ast.router import ASTBuilder
import pytest
import typing
Expand Down Expand Up @@ -222,38 +223,72 @@ def power_set(no_dup_list: list, length_override: int = None):
yield [elem for mask, elem in zip(masks, no_dup_list) if i & mask]


def full_ordered_combination_gen(non_dup_list: list, perm_length: int):
class FullOrderCombinationGen:
"""
This function serves as a generator for all possible vectors of length `perm_length`,
each of whose entries are one of the elements in `non_dup_list`,
which is a list of non-duplicated elements.
Args:
non_dup_list: must be a list of elements with no duplication
perm_length: must be a non-negative number indicating resulting length of the vector
This class serves as a generator for all possible vectors of maximal length `largest_perm_length` (non-negative),
each of whose entries are one of the elements in `non_dup_list`, namely, a list of non-duplicated elements.
"""
if perm_length < 0:
raise pt.TealInputError("input permutation length must be non-negative")
elif len(set(non_dup_list)) != len(non_dup_list):
raise pt.TealInputError(f"input non_dup_list {non_dup_list} has duplications")
elif perm_length == 0:
yield []
return
# we can index all possible cases of vectors with an index in range
# [0, |non_dup_list| ^ perm_length - 1]
# by converting an index into |non_dup_list|-based number,
# we can get the vector mapped by the index.
for index in range(len(non_dup_list) ** perm_length):
index_list_basis = []
temp = index
for _ in range(perm_length):
index_list_basis.append(non_dup_list[temp % len(non_dup_list)])
temp //= len(non_dup_list)
yield index_list_basis

def __init__(self, non_dup_list: list, largest_perm_length: int) -> None:
if largest_perm_length < 0:
raise pt.TealInputError(
"largest input permutation length must be non-negative"
)
elif len(set(non_dup_list)) != len(non_dup_list):
raise pt.TealInputError(
f"input non_dup_list {non_dup_list} has duplications"
)
elif not len(non_dup_list):
raise pt.TealInputError("input non_dup_list must be non empty")

self.__basis_symbol = non_dup_list
self.__basis_size = len(self.__basis_symbol)
self.__pre_gen_table: list[list[int]] = [
[] for _ in range(self.__basis_size**largest_perm_length)
]

def oncomplete_is_in_oc_list(sth: pt.EnumInt, oc_list: list[pt.EnumInt]):
return any(map(lambda x: str(x) == str(sth), oc_list))
# we can index all possible cases of vectors with an index in range
# [0, |non_dup_list| ^ perm_length - 1]
# by converting an index into |non_dup_list|-based number,
# we can get the vector mapped by the index.

# we iterate through [0, |non_dup_list|^largest_perm_length - 1] to precompute permutation table.
lhs_scope = 0
discrete_log = 0
for expn in range(largest_perm_length + 1):
for index in range(lhs_scope, self.__basis_size**expn):
basis_repr = [0 for _ in range(discrete_log)]
if discrete_log:
temp = index
for i in range(discrete_log):
basis_repr[i] = temp % self.__basis_size
temp //= self.__basis_size
self.__pre_gen_table[index] = basis_repr

lhs_scope = self.__basis_size**expn
discrete_log = expn + 1

def sample_gen(self, perm_length: int, sample_num: int = 10):
if perm_length < 0:
raise pt.TealInputError("input permutation length must be non-negative")
elif perm_length == 0:
yield []
return

# since we are sampling for a permutation with length `perm_length`,
# this corresponds to sampling a value from [|non_dup_list|^(perm_length - 1), |non_dup_list|^perm_length - 1].
# if sample number is greater than interval size, by pigeonhole principle there is re-testing
# reduce back down to interval size
sample_num = min(sample_num, self.__basis_size ** (perm_length - 1))

for _ in range(sample_num):
take = secrets.choice(
range(
self.__basis_size ** (perm_length - 1),
self.__basis_size**perm_length,
)
)
yield [self.__basis_symbol[j] for j in self.__pre_gen_table[take]]


def assemble_helper(what: pt.Expr) -> pt.TealBlock:
Expand Down Expand Up @@ -308,72 +343,76 @@ def test_call_config():
)


def test_method_config():
def test_method_config_call_config_never():
never_mc = pt.MethodConfig(no_op=pt.CallConfig.NEVER)
assert never_mc.is_never()
assert never_mc.approval_cond() == 0
assert never_mc.clear_state_cond() == 0


def _gen_method_configs(sample_count: int = 10):
on_complete_pow_set = power_set(ON_COMPLETE_CASES)
focg = FullOrderCombinationGen(list(pt.CallConfig), len(ON_COMPLETE_CASES))

for on_complete_set in on_complete_pow_set:
oc_names = [camel_to_snake(oc.name) for oc in on_complete_set]
for call_configs in focg.sample_gen(len(on_complete_set), sample_count):
yield pt.MethodConfig(**dict(zip(oc_names, call_configs)))


@pytest.mark.parametrize("mc", _gen_method_configs())
def test_method_config(mc: pt.MethodConfig):
approval_check_names_n_ocs = [
(camel_to_snake(oc.name), oc)
for oc in ON_COMPLETE_CASES
if str(oc) != str(pt.OnComplete.ClearState)
]
for on_complete_set in on_complete_pow_set:
oc_names = [camel_to_snake(oc.name) for oc in on_complete_set]
ordered_call_configs = full_ordered_combination_gen(
list(pt.CallConfig), len(on_complete_set)
)
for call_configs in ordered_call_configs:
mc = pt.MethodConfig(**dict(zip(oc_names, call_configs)))
match mc.clear_state:
case pt.CallConfig.NEVER:
assert mc.clear_state_cond() == 0
case pt.CallConfig.CALL:
assert mc.clear_state_cond() == 1
case pt.CallConfig.CREATE | pt.CallConfig.ALL:
with pytest.raises(
pt.TealInputError,
match=r"Only CallConfig.CALL or CallConfig.NEVER are valid for a clear state CallConfig, since clear state can never be invoked during creation$",
):
mc.clear_state_cond()
if mc.is_never() or all(
getattr(mc, i) == pt.CallConfig.NEVER
for i, _ in approval_check_names_n_ocs
):
assert mc.approval_cond() == 0
continue
elif all(
getattr(mc, i) == pt.CallConfig.ALL
for i, _ in approval_check_names_n_ocs

match mc.clear_state:
case pt.CallConfig.NEVER:
assert mc.clear_state_cond() == 0
case pt.CallConfig.CALL:
assert mc.clear_state_cond() == 1
case pt.CallConfig.CREATE | pt.CallConfig.ALL:
with pytest.raises(
pt.TealInputError,
match=r"Only CallConfig.CALL or CallConfig.NEVER are valid for a clear state CallConfig, since clear state can never be invoked during creation$",
):
assert mc.approval_cond() == 1
continue
list_of_cc = [
(
typing.cast(
pt.CallConfig, getattr(mc, i)
).approval_condition_under_config(),
oc,
)
for i, oc in approval_check_names_n_ocs
]
list_of_expressions = []
for expr_or_int, oc in list_of_cc:
match expr_or_int:
case pt.Expr():
list_of_expressions.append(
pt.And(pt.Txn.on_completion() == oc, expr_or_int)
)
case 0:
continue
case 1:
list_of_expressions.append(pt.Txn.on_completion() == oc)
with pt.TealComponent.Context.ignoreExprEquality():
assert assemble_helper(mc.approval_cond()) == assemble_helper(
pt.Or(*list_of_expressions)
mc.clear_state_cond()
if mc.is_never() or all(
getattr(mc, i) == pt.CallConfig.NEVER for i, _ in approval_check_names_n_ocs
):
assert mc.approval_cond() == 0
return
elif all(
getattr(mc, i) == pt.CallConfig.ALL for i, _ in approval_check_names_n_ocs
):
assert mc.approval_cond() == 1
return
list_of_cc = [
(
typing.cast(
pt.CallConfig, getattr(mc, i)
).approval_condition_under_config(),
oc,
)
for i, oc in approval_check_names_n_ocs
]
list_of_expressions: list[pt.Expr] = []
for expr_or_int, oc in list_of_cc:
match expr_or_int:
case pt.Expr():
list_of_expressions.append(
pt.And(pt.Txn.on_completion() == oc, expr_or_int)
)
case 0:
continue
case 1:
list_of_expressions.append(pt.Txn.on_completion() == oc)
with pt.TealComponent.Context.ignoreExprEquality():
ac = mc.approval_cond()
assert isinstance(ac, pt.Expr)
assert assemble_helper(ac) == assemble_helper(pt.Or(*list_of_expressions))


def test_on_complete_action():
Expand Down

0 comments on commit 5d5f4d4

Please sign in to comment.