Skip to content

Commit

Permalink
Simplify multi-expression clauses (#442)
Browse files Browse the repository at this point in the history
* Add `use_seq_if_multiple` helper

* Add multi-expression to Then/Else

* Add multi-expression to While/Do

* Add multi-expression to For/Do

* Add multi-expression to Cond

* Add multi-expression to Assert

* Tests for invalid multi-expression

* Use `list` instead of `List`

* Fix argument list type annotation

* Add incorrectly formatted `Cond` test case

* Fix type errors

* Improve type annotations

* Handle empty tuple in `use_seq_if_multiple`

* Raise exception instead of returning `None`

* Rename `Assert`’s additional parameter

* Include new behavior in docstrings

* Mark internal methods

* Use snake case for method parameters
  • Loading branch information
jdtzmn committed Jul 22, 2022
1 parent c936cbc commit d51857b
Show file tree
Hide file tree
Showing 12 changed files with 362 additions and 16 deletions.
16 changes: 12 additions & 4 deletions pyteal/ast/assert_.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
from pyteal.types import TealType, require_type
from pyteal.ir import TealOp, Op, TealBlock, TealSimpleBlock, TealConditionalBlock
from pyteal.ast.expr import Expr
from pyteal.ast.seq import Seq

if TYPE_CHECKING:
from pyteal.compiler import CompileOptions
Expand All @@ -11,23 +12,30 @@
class Assert(Expr):
"""A control flow expression to verify that a condition is true."""

def __init__(self, cond: Expr) -> None:
def __init__(self, cond: Expr, *additional_conds: Expr) -> None:
"""Create an assert statement that raises an error if the condition is false.
Args:
cond: The condition to check. Must evaluate to a uint64.
*additional_conds: Additional conditions to check. Must evaluate to uint64.
"""
super().__init__()
require_type(cond, TealType.uint64)
self.cond = cond
for cond_single in additional_conds:
require_type(cond_single, TealType.uint64)
self.cond = [cond] + list(additional_conds)

def __teal__(self, options: "CompileOptions"):
if len(self.cond) > 1:
asserts = [Assert(cond) for cond in self.cond]
return Seq(*asserts).__teal__(options)

if options.version >= Op.assert_.min_version:
# use assert op if available
return TealBlock.FromOp(options, TealOp(self, Op.assert_), self.cond)
return TealBlock.FromOp(options, TealOp(self, Op.assert_), self.cond[0])

# if assert op is not available, use branches and err
condStart, condEnd = self.cond.__teal__(options)
condStart, condEnd = self.cond[0].__teal__(options)

end = TealSimpleBlock([])
errBlock = TealSimpleBlock([TealOp(self, Op.err)])
Expand Down
37 changes: 37 additions & 0 deletions pyteal/ast/assert_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,22 @@ def test_teal_2_assert():
assert actual == expected


def test_teal_2_assert_multi():
args = [pt.Int(1), pt.Int(2)]
expr = pt.Assert(*args)
assert expr.type_of() == pt.TealType.none

firstAssert = pt.Assert(args[0])
secondAssert = pt.Assert(args[1])

expected, _ = pt.Seq(firstAssert, secondAssert).__teal__(teal2Options)

actual, _ = expr.__teal__(teal2Options)

with pt.TealComponent.Context.ignoreExprEquality():
assert actual == expected


def test_teal_3_assert():
arg = pt.Int(1)
expr = pt.Assert(arg)
Expand All @@ -39,6 +55,27 @@ def test_teal_3_assert():
assert actual == expected


def test_teal_3_assert_multi():
args = [pt.Int(1), pt.Int(2)]
expr = pt.Assert(*args)
assert expr.type_of() == pt.TealType.none

expected = pt.TealSimpleBlock(
[pt.TealOp(args[0], pt.Op.int, 1), pt.TealOp(expr, pt.Op.assert_)]
+ [pt.TealOp(args[1], pt.Op.int, 2), pt.TealOp(expr, pt.Op.assert_)]
)

actual, _ = expr.__teal__(teal3Options)
actual.addIncoming()
actual = pt.TealBlock.NormalizeBlocks(actual)

with pt.TealComponent.Context.ignoreExprEquality():
assert actual == expected


def test_assert_invalid():
with pytest.raises(pt.TealTypeError):
pt.Assert(pt.Txn.receiver())

with pytest.raises(pt.TealTypeError):
pt.Assert(pt.Int(1), pt.Txn.receiver())
34 changes: 27 additions & 7 deletions pyteal/ast/cond.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from typing import List, cast, TYPE_CHECKING
from pyteal.ast.seq import _use_seq_if_multiple

from pyteal.types import TealType, require_type
from pyteal.ir import TealOp, Op, TealSimpleBlock, TealConditionalBlock
Expand All @@ -9,16 +10,34 @@
from pyteal.compiler import CompileOptions


def _reformat_multi_argv(argv: tuple[list[Expr], ...]) -> list[list[Expr]]:
"""Reformat a list of lists of expressions with potentially multiple value expressions into a list of lists of expressions with only one value expression
by using Seq blocks where appropriate.
Example:
[ [a, b, c], [d, e, f] ] -> [ [a, Seq(b, c)], [d, Seq(e, f)] ]
"""
reformatted = []
for arg in argv:
# Note: this is not a valid Cond arg, but will be caught later in Cond.__init__()
if len(arg) <= 1:
reformatted.append(arg)
else:
reformatted.append([arg[0], _use_seq_if_multiple(arg[1:])])

return reformatted


class Cond(Expr):
"""A chainable branching expression that supports an arbitrary number of conditions."""

def __init__(self, *argv: List[Expr]):
"""Create a new Cond expression.
At least one argument must be provided, and each argument must be a list with two elements.
The first element is a condition which evalutes to uint64, and the second is the body of the
condition, which will execute if that condition is true. All condition bodies must have the
same return type. During execution, each condition is tested in order, and the first
At least one argument must be provided, and each argument must be a list with two or more elements.
The first element is a condition which evalutes to uint64, and the remaining elements are the body
of the condition, which will execute if that condition is true. The last elements of the condition bodies
must have the same return type. During execution, each condition is tested in order, and the first
condition to evaluate to a true value will cause its associated body to execute and become
the value for this Cond expression. If no condition evalutes to a true value, the Cond
expression produces an error and the TEAL program terminates.
Expand All @@ -27,7 +46,7 @@ def __init__(self, *argv: List[Expr]):
.. code-block:: python
Cond([Global.group_size() == Int(5), bid],
[Global.group_size() == Int(4), redeem],
[Global.group_size() == Int(4), redeem, log],
[Global.group_size() == Int(1), wrapup])
"""
super().__init__()
Expand All @@ -36,8 +55,9 @@ def __init__(self, *argv: List[Expr]):
raise TealInputError("Cond requires at least one [condition, value]")

value_type = None
sequenced_argv = _reformat_multi_argv(argv)

for arg in argv:
for arg in sequenced_argv:
msg = "Cond should be in the form of Cond([cond1, value1], [cond2, value2], ...), error in {}"
if not isinstance(arg, list):
raise TealInputError(msg.format(arg))
Expand All @@ -52,7 +72,7 @@ def __init__(self, *argv: List[Expr]):
require_type(arg[1], value_type)

self.value_type = value_type
self.args = argv
self.args = sequenced_argv

def __teal__(self, options: "CompileOptions"):
start = None
Expand Down
50 changes: 50 additions & 0 deletions pyteal/ast/cond_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -119,8 +119,58 @@ def test_cond_invalid():
with pytest.raises(pt.TealInputError):
pt.Cond([])

with pytest.raises(pt.TealInputError):
pt.Cond([pt.Int(1)], [pt.Int(2), pt.Pop(pt.Txn.receiver())])

with pytest.raises(pt.TealTypeError):
pt.Cond([pt.Int(1), pt.Int(2)], [pt.Int(2), pt.Txn.receiver()])

with pytest.raises(pt.TealTypeError):
pt.Cond([pt.Arg(0), pt.Int(2)])

with pytest.raises(pt.TealTypeError):
pt.Cond([pt.Int(1), pt.Int(2)], [pt.Int(2), pt.Pop(pt.Int(2))])

with pytest.raises(pt.TealTypeError):
pt.Cond([pt.Int(1), pt.Pop(pt.Int(1))], [pt.Int(2), pt.Int(2)])


def test_cond_two_pred_multi():
args = [
pt.Int(1),
[pt.Pop(pt.Int(1)), pt.Bytes("one")],
pt.Int(0),
[pt.Pop(pt.Int(2)), pt.Bytes("zero")],
]
expr = pt.Cond(
[args[0]] + args[1],
[args[2]] + args[3],
)
assert expr.type_of() == pt.TealType.bytes

cond1, _ = args[0].__teal__(options)
pred1, pred1End = pt.Seq(args[1]).__teal__(options)
cond1Branch = pt.TealConditionalBlock([])
cond2, _ = args[2].__teal__(options)
pred2, pred2End = pt.Seq(args[3]).__teal__(options)
cond2Branch = pt.TealConditionalBlock([])
end = pt.TealSimpleBlock([])

cond1.setNextBlock(cond1Branch)
cond1Branch.setTrueBlock(pred1)
cond1Branch.setFalseBlock(cond2)
pred1End.setNextBlock(end)

cond2.setNextBlock(cond2Branch)
cond2Branch.setTrueBlock(pred2)
cond2Branch.setFalseBlock(pt.Err().__teal__(options)[0])
pred2End.setNextBlock(end)

expected = cond1

actual, _ = expr.__teal__(options)
print(actual)
print(expected)

with pt.TealComponent.Context.ignoreExprEquality():
assert actual == expected
13 changes: 12 additions & 1 deletion pyteal/ast/for_.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from typing import TYPE_CHECKING, Optional
from pyteal.ast.seq import _use_seq_if_multiple

from pyteal.types import TealType, require_type
from pyteal.ir import TealSimpleBlock, TealConditionalBlock
Expand All @@ -23,6 +24,13 @@ def __init__(self, start: Expr, cond: Expr, step: Expr) -> None:
start: Expression setting the variable's initial value
cond: The condition to check. Must evaluate to uint64.
step: Expression to update the variable's value.
Example:
.. code-block:: python
i = ScratchVar()
For(i.store(Int(0)), i.load() < Int(10), i.store(i.load() + Int(1))
.Do(expr1, expr2, ...)
"""
super().__init__()
require_type(cond, TealType.uint64)
Expand Down Expand Up @@ -89,9 +97,12 @@ def type_of(self):
def has_return(self):
return False

def Do(self, doBlock: Expr):
def Do(self, doBlock: Expr, *do_block_multi: Expr):
if self.doBlock is not None:
raise TealCompileError("For expression already has a doBlock", self)

doBlock = _use_seq_if_multiple(doBlock, *do_block_multi)

require_type(doBlock, TealType.none)
self.doBlock = doBlock
return self
Expand Down
39 changes: 39 additions & 0 deletions pyteal/ast/for_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -187,10 +187,49 @@ def test_invalid_for():
pt.Int(0)
)

with pytest.raises(pt.TealTypeError):
i = pt.ScratchVar()
expr = pt.For(i.store(pt.Int(0)), pt.Int(1), i.store(i.load() + pt.Int(1))).Do(
pt.Pop(pt.Int(1)), pt.Int(2)
)

with pytest.raises(pt.TealCompileError):
expr = (
pt.For(i.store(pt.Int(0)), pt.Int(1), i.store(i.load() + pt.Int(1)))
.Do(pt.Continue())
.Do(pt.Continue())
)
expr.__str__()


def test_for_multi():
i = pt.ScratchVar()
items = [
(i.store(pt.Int(0))),
i.load() < pt.Int(10),
i.store(i.load() + pt.Int(1)),
[pt.Pop(pt.Int(1)), pt.App.globalPut(pt.Itob(i.load()), i.load() * pt.Int(2))],
]
expr = pt.For(items[0], items[1], items[2]).Do(*items[3])

assert expr.type_of() == pt.TealType.none
assert not expr.has_return()

expected, varEnd = items[0].__teal__(options)
condStart, condEnd = items[1].__teal__(options)
stepStart, stepEnd = items[2].__teal__(options)
do, doEnd = pt.Seq(items[3]).__teal__(options)
expectedBranch = pt.TealConditionalBlock([])
end = pt.TealSimpleBlock([])

varEnd.setNextBlock(condStart)
doEnd.setNextBlock(stepStart)

expectedBranch.setTrueBlock(do)
expectedBranch.setFalseBlock(end)
condEnd.setNextBlock(expectedBranch)
stepEnd.setNextBlock(condStart)

actual, _ = expr.__teal__(options)

assert actual == expected
9 changes: 7 additions & 2 deletions pyteal/ast/if_.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
from pyteal.types import TealType, require_type
from pyteal.ir import TealSimpleBlock, TealConditionalBlock
from pyteal.ast.expr import Expr
from pyteal.ast.seq import _use_seq_if_multiple

if TYPE_CHECKING:
from pyteal.compiler import CompileOptions
Expand Down Expand Up @@ -99,10 +100,12 @@ def has_return(self):
# otherwise, this expression has a return op only if both branches result in a return op
return self.thenBranch.has_return() and self.elseBranch.has_return()

def Then(self, thenBranch: Expr):
def Then(self, thenBranch: Expr, *then_branch_multi: Expr):
if not self.alternateSyntaxFlag:
raise TealInputError("Cannot mix two different If syntax styles")

thenBranch = _use_seq_if_multiple(thenBranch, *then_branch_multi)

if not self.elseBranch:
self.thenBranch = thenBranch
else:
Expand All @@ -123,13 +126,15 @@ def ElseIf(self, cond):
self.elseBranch.ElseIf(cond)
return self

def Else(self, elseBranch: Expr):
def Else(self, elseBranch: Expr, *else_branch_multi: Expr):
if not self.alternateSyntaxFlag:
raise TealInputError("Cannot mix two different If syntax styles")

if not self.thenBranch:
raise TealInputError("Must set Then branch before Else branch")

elseBranch = _use_seq_if_multiple(elseBranch, *else_branch_multi)

if not self.elseBranch:
require_type(elseBranch, self.thenBranch.type_of())
self.elseBranch = elseBranch
Expand Down

0 comments on commit d51857b

Please sign in to comment.