Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix scoping for shadowed variables #142

Merged
merged 2 commits into from
Apr 24, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions opshin/compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,9 @@
from .rewrite.rewrite_import_typing import RewriteImportTyping
from .rewrite.rewrite_inject_builtins import RewriteInjectBuiltins
from .rewrite.rewrite_inject_builtin_constr import RewriteInjectBuiltinsConstr
from .rewrite.rewrite_orig_name import RewriteOrigName
from .rewrite.rewrite_remove_type_stuff import RewriteRemoveTypeStuff
from .rewrite.rewrite_scoping import RewriteScoping
from .rewrite.rewrite_subscript38 import RewriteSubscript38
from .rewrite.rewrite_tuple_assign import RewriteTupleAssign
from .rewrite.rewrite_zero_ary import RewriteZeroAry
Expand Down Expand Up @@ -940,6 +942,9 @@ def compile(

# from here on raw uplc may occur, so we dont attempt to fix locations
compile_pipeline = [
# Save the original names of variables
RewriteOrigName(),
RewriteScoping(),
# Apply optimizations
OptimizeRemoveDeadvars(),
OptimizeVarlen(),
Expand Down
2 changes: 1 addition & 1 deletion opshin/optimize/optimize_remove_deadvars.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,7 +93,7 @@ def visit_Module(self, node: Module) -> Module:
# collect all variable names
collector = NameLoadCollector()
collector.visit(node_cp)
loaded_vars = set(collector.loaded.keys()) | {"validator"}
loaded_vars = set(collector.loaded.keys()) | {"validator_0"}
# break if the set of loaded vars did not change -> set of vars to remove does also not change
if loaded_vars == self.loaded_vars:
break
Expand Down
4 changes: 0 additions & 4 deletions opshin/optimize/optimize_varlen.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,26 +59,22 @@ def visit_Module(self, node: Module) -> Module:

def visit_Name(self, node: Name) -> Name:
nc = copy(node)
nc.orig_id = node.id
nc.id = self.varmap[node.id]
return nc

def visit_ClassDef(self, node: ClassDef) -> ClassDef:
node_cp = copy(node)
node_cp.orig_name = node.name
node_cp.name = self.varmap[node.name]
# ignore the content of class definitions
return node_cp

def visit_FunctionDef(self, node: FunctionDef) -> FunctionDef:
node_cp = copy(node)
node_cp.orig_name = node.name
node_cp.name = self.varmap[node.name]
node_cp.args = copy(node.args)
node_cp.args.args = []
for a in node.args.args:
a_cp = copy(a)
a_cp.orig_arg = a.arg
a_cp.arg = self.varmap[a.arg]
node_cp.args.args.append(a_cp)
node_cp.body = [self.visit(s) for s in node.body]
Expand Down
34 changes: 34 additions & 0 deletions opshin/rewrite/rewrite_orig_name.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
from ast import *
from copy import copy

from ..util import CompilingNodeTransformer

"""
Rewrites all occurrences of names to contain a pointer to the original name for good
"""


class RewriteOrigName(CompilingNodeTransformer):
step = "Assigning the orig_id/orig_name field with the variable name"

def visit_Name(self, node: Name) -> Name:
nc = copy(node)
nc.orig_id = node.id
return nc

def visit_ClassDef(self, node: ClassDef) -> ClassDef:
node_cp = copy(node)
node_cp.orig_name = node.name
return node_cp

def visit_FunctionDef(self, node: FunctionDef) -> FunctionDef:
node_cp = copy(node)
node_cp.orig_name = node.name
node_cp.args = copy(node.args)
node_cp.args.args = []
for a in node.args.args:
a_cp = copy(a)
a_cp.orig_arg = a.arg
node_cp.args.args.append(a_cp)
node_cp.body = [self.visit(s) for s in node.body]
return node_cp
115 changes: 115 additions & 0 deletions opshin/rewrite/rewrite_scoping.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,115 @@
from ast import *
from copy import copy
from collections import defaultdict

from ..type_inference import INITIAL_SCOPE
from ..util import CompilingNodeTransformer, CompilingNodeVisitor

"""
Rewrites all variable names to point to the definition in the nearest enclosing scope
"""


class ShallowNameDefCollector(CompilingNodeVisitor):
step = "Collecting occuring variable names"

def __init__(self):
self.vars = set()

def visit_Name(self, node: Name) -> None:
if isinstance(node.ctx, Store):
self.vars.add(node.id)

def visit_ClassDef(self, node: ClassDef):
self.vars.add(node.name)
# ignore the content (i.e. attribute names) of class definitions

def visit_FunctionDef(self, node: FunctionDef):
self.vars.add(node.name)
# ignore the recursive stuff


class RewriteScoping(CompilingNodeTransformer):
step = "Rewrite all variables to inambiguously point to the definition in the nearest enclosing scope"

def __init__(self):
self.latest_scope_id = 0
self.scopes = [(set(INITIAL_SCOPE.keys()), -1)]

def variable_scope_id(self, name: str) -> int:
"""find the id of the scope in which this variable is defined (closest to its usage)"""
name = name
for scope, scope_id in reversed(self.scopes):
if name in scope:
return scope_id
raise NameError(
f"free variable '{name}' referenced before assignment in enclosing scope"
)

def enter_scope(self):
self.scopes.append((set(), self.latest_scope_id))
self.latest_scope_id += 1

def exit_scope(self):
self.scopes.pop()

def set_variable_scope(self, name: str):
self.scopes[-1][0].add(name)

def map_name(self, name: str):
scope_id = self.variable_scope_id(name)
if scope_id == -1:
# do not rewrite Dict, Union, etc
return name
return f"{name}_{scope_id}"

def visit_Module(self, node: Module) -> Module:
node_cp = copy(node)
self.enter_scope()
# vars defined in this scope
shallow_node_def_collector = ShallowNameDefCollector()
for s in node.body:
shallow_node_def_collector.visit(s)
vars_def = shallow_node_def_collector.vars
for var_name in vars_def:
self.set_variable_scope(var_name)
node_cp.body = [self.visit(s) for s in node.body]
return node_cp

def visit_Name(self, node: Name) -> Name:
nc = copy(node)
# setting is handled in either enclosing module or function
nc.id = self.map_name(node.id)
return nc

def visit_ClassDef(self, node: ClassDef) -> ClassDef:
node_cp = copy(node)
# setting is handled in either enclosing module or function
node_cp.name = self.map_name(node.name)
# ignore the content of class definitions
return node_cp

def visit_FunctionDef(self, node: FunctionDef) -> FunctionDef:
node_cp = copy(node)
# setting is handled in either enclosing module or function
node_cp.name = self.map_name(node.name)
self.enter_scope()
node_cp.args = copy(node.args)
node_cp.args.args = []
# args are defined in this scope
for a in node.args.args:
a_cp = copy(a)
self.set_variable_scope(a.arg)
a_cp.arg = self.map_name(a.arg)
node_cp.args.args.append(a_cp)
# vars defined in this scope
shallow_node_def_collector = ShallowNameDefCollector()
for s in node.body:
shallow_node_def_collector.visit(s)
vars_def = shallow_node_def_collector.vars
for var_name in vars_def:
self.set_variable_scope(var_name)
# map all vars and recurse
node_cp.body = [self.visit(s) for s in node.body]
self.exit_scope()
return node_cp
2 changes: 0 additions & 2 deletions opshin/tests/test_misc.py
Original file line number Diff line number Diff line change
Expand Up @@ -953,7 +953,6 @@ def validator(x: Token) -> bool:
failed = True
self.assertTrue(failed, "Machine did validate the content")

@unittest.skip
def test_inner_outer_state_functions(self):
source_code = """
a = 2
Expand All @@ -969,7 +968,6 @@ def validator(_: None) -> int:
res = uplc_eval(uplc.Apply(code, uplc.PlutusConstr(0, [])))
self.assertEqual(res, uplc.PlutusInteger(2))

@unittest.skip
def test_inner_outer_state_functions_nonglobal(self):
source_code = """

Expand Down