Skip to content
This repository has been archived by the owner on Apr 5, 2024. It is now read-only.

Commit

Permalink
Merge pull request #17 from OpShin/fix/guaranteed_variables
Browse files Browse the repository at this point in the history
Fix/guaranteed variables
  • Loading branch information
nielstron committed Apr 15, 2023
2 parents 6b7a327 + 8fd7b28 commit 2660ed0
Show file tree
Hide file tree
Showing 9 changed files with 155 additions and 47 deletions.
8 changes: 4 additions & 4 deletions examples/complex_datum.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,8 +36,8 @@ class BatchOrder(PlutusData):
def validator(d: Union[Nothing, BatchOrder]) -> bytes:
if isinstance(d, BatchOrder):
c = d.sender.payment_credential
res = c.credential_hash
return c.credential_hash
elif isinstance(d, Nothing):
res = b""
# Throws a NameError if the instances don't match - this is fine, it means that the contract was not invoked correctly!
return res
return b""
else:
assert False, "Invalid datum"
4 changes: 2 additions & 2 deletions examples/datum_cast.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,5 +23,5 @@ def validator(d: Anything, r: Anything) -> bytes:
c = e.sender.payment_credential
# this actually checks that c is of the type PubKeyCredential
if isinstance(c, PubKeyCredential):
res = c.credential_hash
return res + r2
return c.credential_hash + r2
assert False, "Invalid sender"
2 changes: 2 additions & 0 deletions examples/showcase.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,8 @@ def validator(n: int) -> int:
if b < 5:
print("add")
c = 5
else:
c = 0
# list comprehensions for loops
d = sum([i for i in range(2)])

Expand Down
15 changes: 7 additions & 8 deletions examples/smart_contracts/marketplace.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,13 +58,12 @@ def validator(datum: Listing, redeemer: ListingAction, context: ScriptContext) -
if isinstance(purpose, Spending):
own_utxo = resolve_spent_utxo(tx_info.inputs, purpose)
own_addr = own_utxo.address
check_single_utxo_spent(tx_info.inputs, own_addr)
if isinstance(redeemer, Buy):
check_paid(tx_info.outputs, datum.vendor, datum.price)
elif isinstance(redeemer, Unlist):
check_owner_signed(tx_info.signatories, datum.owner)
else:
assert False, "Wrong redeemer"
else:
assert False, "Wrong script purpose"

check_single_utxo_spent(tx_info.inputs, own_addr)
if isinstance(redeemer, Buy):
check_paid(tx_info.outputs, datum.vendor, datum.price)
elif isinstance(redeemer, Unlist):
check_owner_signed(tx_info.signatories, datum.owner)
else:
assert False, "Wrong redeemer"
7 changes: 4 additions & 3 deletions examples/smart_contracts/wrapped_token.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,9 +28,8 @@ def own_policy_id(own_spent_utxo: TxOut) -> PolicyId:
# obtain the policy id for which this contract can validate minting/burning
cred = own_spent_utxo.address.payment_credential
if isinstance(cred, ScriptCredential):
policy_id = cred.credential_hash
# This throws a name error if the credential is not a ScriptCredential instance
return policy_id
return cred.credential_hash
assert False, "Wrong type of payment credential"


def own_address(own_policy_id: PolicyId) -> Address:
Expand Down Expand Up @@ -78,6 +77,8 @@ def validator(
own_pid = own_policy_id(own_utxo)
own_addr = own_utxo.address
else:
own_addr = Address(PubKeyCredential(b""), NoStakingCredential())
own_pid = b""
assert False, "Incorrect purpose given"
token = Token(token_policy_id, token_name)
all_locked = all_tokens_locked_at_contract_address(
Expand Down
2 changes: 2 additions & 0 deletions hebi/compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@

from .optimize.optimize_remove_comments import OptimizeRemoveDeadconstants
from .rewrite.rewrite_forbidden_overwrites import RewriteForbiddenOverwrites
from .rewrite.rewrite_guaranteed_variables import RewriteGuaranteedVariables
from .rewrite.rewrite_import import RewriteImport
from .rewrite.rewrite_import_dataclasses import RewriteImportDataclasses
from .rewrite.rewrite_import_hashlib import RewriteImportHashlib
Expand Down Expand Up @@ -704,6 +705,7 @@ def compile(
RewriteImportDataclasses(),
RewriteInjectBuiltins(),
RewriteDuplicateAssignment(),
RewriteGuaranteedVariables(),
# The type inference needs to be run after complex python operations were rewritten
AggressiveTypeInferencer(),
# Rewrites that circumvent the type inference or use its results
Expand Down
5 changes: 2 additions & 3 deletions hebi/prelude.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,13 +69,12 @@ def resolve_datum_unsafe(txout: TxOut, tx_info: TxInfo) -> BuiltinData:
"""
attached_datum = txout.datum
if isinstance(attached_datum, SomeOutputDatumHash):
res = tx_info.data[attached_datum.datum_hash]
return tx_info.data[attached_datum.datum_hash]
elif isinstance(attached_datum, SomeOutputDatum):
res = attached_datum.datum
return attached_datum.datum
else:
# no datum attached
assert False, "No datum was attached to the given transaction output"
return res


def resolve_datum(
Expand Down
132 changes: 132 additions & 0 deletions hebi/rewrite/rewrite_guaranteed_variables.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,132 @@
from ast import *
from copy import copy
from collections import defaultdict

from ..util import CompilingNodeVisitor, CompilingNodeTransformer
from ..type_inference import INITIAL_SCOPE

"""
Checks that used variables are bound in every branch of preceding control flow
"""


class RewriteGuaranteedVariables(CompilingNodeTransformer):
step = "Ensure variables are bound"

loaded_vars = None
# names that are guaranteed to be available to the current node
# this acts differently to the type inferencer! in particular, ite/while/for all produce their own scope
guaranteed_avail_names = [
list(INITIAL_SCOPE.keys()) + ["List", "Dict", "Union", "isinstance"]
]

def guaranteed(self, name: str) -> bool:
name = name
for scope in reversed(self.guaranteed_avail_names):
if name in scope:
return True
return False

def enter_scope(self):
self.guaranteed_avail_names.append([])

def exit_scope(self):
self.guaranteed_avail_names.pop()

def set_guaranteed(self, name: str):
self.guaranteed_avail_names[-1].append(name)

def visit_Module(self, node: Module) -> Module:
# repeat until no more change due to removal
# i.e. b = a; c = b needs 2 passes to remove c and b
node_cp = copy(node)
self.enter_scope()
node_cp.body = [self.visit(s) for s in node_cp.body]
self.exit_scope()
return node_cp

def visit_If(self, node: If):
node_cp = copy(node)
node_cp.test = self.visit(node.test)
self.enter_scope()
node_cp.body = [self.visit(s) for s in node.body]
scope_body_cp = self.guaranteed_avail_names[-1].copy()
self.exit_scope()
self.enter_scope()
node_cp.orelse = [self.visit(s) for s in node.orelse]
scope_orelse_cp = self.guaranteed_avail_names[-1].copy()
self.exit_scope()
# what remains after this in the scope is the intersection of both
for var in set(scope_body_cp).intersection(scope_orelse_cp):
self.set_guaranteed(var)
return node_cp

def visit_While(self, node: While):
node_cp = copy(node)
node_cp.test = self.visit(node.test)
self.enter_scope()
node_cp.body = [self.visit(s) for s in node.body]
node_cp.orelse = [self.visit(s) for s in node.orelse]
self.exit_scope()
return node_cp

def visit_For(self, node: For):
node_cp = copy(node)
assert isinstance(node.target, Name), "Can only assign to singleton name"
self.enter_scope()
self.guaranteed(node.target.id)
node_cp.body = [self.visit(s) for s in node.body]
node_cp.orelse = [self.visit(s) for s in node.orelse]
self.exit_scope()
return node_cp

def visit_ListComp(self, node: ListComp):
assert len(node.generators) == 1, "Currently only one generator supported"
gen = node.generators[0]
assert isinstance(
gen.target, Name
), "Can only assign value to singleton element"
assert isinstance(gen.target, Name), "Can only assign to singleton name"
node_cp = copy(node)
node_cp.generators = [copy(gen)]
self.enter_scope()
self.set_guaranteed(gen.target.id)
node_cp.generators[0].ifs = [self.visit(if_expr) for if_expr in gen.ifs]
node_cp.elt = self.visit(node.elt)
self.exit_scope()
return node_cp

def visit_Assign(self, node: Assign):
for t in node.targets:
assert isinstance(t, Name), f"Need to have name, not {t.__class__}"
self.set_guaranteed(t.id)
return self.generic_visit(node)

def visit_AnnAssign(self, node: AnnAssign):
assert isinstance(
node.target, Name
), f"Need to have name, not {node.target.__class__}"
self.set_guaranteed(node.target.id)
return self.generic_visit(node)

def visit_ClassDef(self, node: ClassDef):
self.set_guaranteed(node.name)
return node

def visit_FunctionDef(self, node: FunctionDef):
node_cp = copy(node)
self.set_guaranteed(node.name)
self.enter_scope()
# variable names are available here
for a in node.args.args:
self.set_guaranteed(a.arg)
node_cp.body = [self.visit(s) for s in node.body]
self.exit_scope()
return node_cp

def visit_Name(self, node: Name):
if isinstance(node.ctx, Load):
assert self.guaranteed(
node.id
), f"Variable {node.id} is not initialized in (every branch of) preceding control flow"
return self.generic_visit(node)
27 changes: 0 additions & 27 deletions hebi/tests/test_misc.py
Original file line number Diff line number Diff line change
Expand Up @@ -398,33 +398,6 @@ def validator(x: Token) -> bool:
failed = True
self.assertTrue(failed, "Machine did validate the content")

def test_opt_shared_var(self):
# this tests that errors that are caused by assignments are actually triggered at the time of assigning
source_code = """
from hebi.prelude import *
def validator(x: Token) -> bool:
if False:
y = x
else:
a = y
return True
"""
ast = compiler.parse(source_code)
code = compiler.compile(ast)
code = code.compile()
f = code.term
# UPLC lambdas may only take one argument at a time, so we evaluate by repeatedly applying
try:
for d in [
uplc.PlutusConstr(0, []),
]:
f = uplc.Apply(f, d)
ret = uplc_eval(f)
failed = False
except Exception as e:
failed = True
self.assertTrue(failed, "Machine did validate the content")

def test_list_expr(self):
# this tests that the list expression is evaluated correctly
source_code = """
Expand Down

0 comments on commit 2660ed0

Please sign in to comment.