Skip to content

Commit

Permalink
Fix rewriting tuple indices in python3.9 and python3.10
Browse files Browse the repository at this point in the history
  • Loading branch information
nielstron committed Mar 22, 2023
1 parent 5a7f40b commit b12f385
Show file tree
Hide file tree
Showing 4 changed files with 41 additions and 50 deletions.
26 changes: 9 additions & 17 deletions opshin/compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
from .rewrite.rewrite_inject_builtins import RewriteInjectBuiltins
from .rewrite.rewrite_inject_builtin_constr import RewriteInjectBuiltinsConstr
from .rewrite.rewrite_remove_type_stuff import RewriteRemoveTypeStuff
from .rewrite.rewrite_subscript38 import RewriteSubscript38
from .rewrite.rewrite_tuple_assign import RewriteTupleAssign
from .rewrite.rewrite_zero_ary import RewriteZeroAry
from .optimize.optimize_remove_pass import OptimizeRemovePass
Expand Down Expand Up @@ -540,15 +541,12 @@ def visit_Subscript(self, node: TypedSubscript) -> plt.AST:
), "Can only access elements of instances, not classes"
if isinstance(node.value.typ.typ, TupleType):
assert isinstance(
node.slice, Index
), "Only single index slices for tuples are currently supported"
assert isinstance(
node.slice.value, Constant
node.slice, Constant
), "Only constant index access for tuples is supported"
assert isinstance(
node.slice.value.value, int
node.slice.value, int
), "Only constant index integer access for tuples is supported"
index = node.slice.value.value
index = node.slice.value
if index < 0:
index += len(node.value.typ.typ.typs)
assert isinstance(node.ctx, Load), "Tuples are read-only"
Expand All @@ -561,11 +559,8 @@ def visit_Subscript(self, node: TypedSubscript) -> plt.AST:
),
)
if isinstance(node.value.typ.typ, ListType):
assert isinstance(
node.slice, Index
), "Only single index slices for lists are currently supported"
assert (
node.slice.value.typ == IntegerInstanceType
node.slice.typ == IntegerInstanceType
), "Only single element list index access supported"
return plt.Lambda(
[STATEMONAD],
Expand All @@ -574,9 +569,7 @@ def visit_Subscript(self, node: TypedSubscript) -> plt.AST:
("l", plt.Apply(self.visit(node.value), plt.Var(STATEMONAD))),
(
"raw_i",
plt.Apply(
self.visit(node.slice.value), plt.Var(STATEMONAD)
),
plt.Apply(self.visit(node.slice), plt.Var(STATEMONAD)),
),
(
"i",
Expand All @@ -593,7 +586,7 @@ def visit_Subscript(self, node: TypedSubscript) -> plt.AST:
),
)
elif isinstance(node.value.typ.typ, ByteStringType):
if isinstance(node.slice, Index):
if not isinstance(node.slice, Slice):
return plt.Lambda(
[STATEMONAD],
plt.Let(
Expand All @@ -604,9 +597,7 @@ def visit_Subscript(self, node: TypedSubscript) -> plt.AST:
),
(
"raw_ix",
plt.Apply(
self.visit(node.slice.value), plt.Var(STATEMONAD)
),
plt.Apply(self.visit(node.slice), plt.Var(STATEMONAD)),
),
(
"ix",
Expand Down Expand Up @@ -847,6 +838,7 @@ def compile(prog: AST, filename=None, force_three_params=False):
# Important to call this one first - it imports all further files
RewriteImport(filename=filename),
# Rewrites that simplify the python code
RewriteSubscript38(),
RewriteAugAssign(),
RewriteTupleAssign(),
RewriteImportPlutusData(),
Expand Down
15 changes: 15 additions & 0 deletions opshin/rewrite/rewrite_subscript38.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
from ast import *
import typing

from ..util import CompilingNodeTransformer

"""
Rewrites all Index/Slice occurrences such that they look like in Python 3.9 onwards (not like Python 3.8).
"""


class RewriteSubscript38(CompilingNodeTransformer):
step = "Rewriting Subscripts"

def visit_Index(self, node: Index) -> AST:
return self.visit(node.value)
2 changes: 1 addition & 1 deletion opshin/rewrite/rewrite_tuple_assign.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ def visit_Assign(self, node: Assign) -> typing.List[stmt]:
[t],
Subscript(
value=Name(f"{uid}_tup", Load()),
slice=Index(value=Constant(i)),
slice=Constant(i),
ctx=Load(),
),
)
Expand Down
48 changes: 16 additions & 32 deletions opshin/type_inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,12 +86,11 @@ def type_from_annotation(self, ann: expr):
assert isinstance(
ann.value, Name
), "Only Union, Dict and List are allowed as Generic types"
assert isinstance(ann.slice, Index), "Generic types must be parameterized"
if ann.value.id == "Union":
assert isinstance(
ann.slice.value, Tuple
ann.slice, Tuple
), "Union must combine multiple classes"
ann_types = [self.type_from_annotation(e) for e in ann.slice.value.elts]
ann_types = [self.type_from_annotation(e) for e in ann.slice.elts]
assert all(
isinstance(e, RecordType) for e in ann_types
), "Union must combine multiple PlutusData classes"
Expand All @@ -100,7 +99,7 @@ def type_from_annotation(self, ann: expr):
), "Union must combine PlutusData classes with unique constructors"
return UnionType(FrozenFrozenList(ann_types))
if ann.value.id == "List":
ann_type = self.type_from_annotation(ann.slice.value)
ann_type = self.type_from_annotation(ann.slice)
assert isinstance(
ann_type, ClassType
), "List must have a single type as parameter"
Expand All @@ -109,13 +108,11 @@ def type_from_annotation(self, ann: expr):
), "List can currently not hold tuples"
return ListType(InstanceType(ann_type))
if ann.value.id == "Dict":
assert isinstance(
ann.slice.value, Tuple
), "Dict must combine two classes"
assert len(ann.slice.value.elts) == 2, "Dict must combine two classes"
assert isinstance(ann.slice, Tuple), "Dict must combine two classes"
assert len(ann.slice.elts) == 2, "Dict must combine two classes"
ann_types = self.type_from_annotation(
ann.slice.value.elts[0]
), self.type_from_annotation(ann.slice.value.elts[1])
ann.slice.elts[0]
), self.type_from_annotation(ann.slice.elts[1])
assert all(
isinstance(e, ClassType) for e in ann_types
), "Dict must combine two classes"
Expand All @@ -125,9 +122,9 @@ def type_from_annotation(self, ann: expr):
return DictType(*(InstanceType(a) for a in ann_types))
if ann.value.id == "Tuple":
assert isinstance(
ann.slice.value, Tuple
ann.slice, Tuple
), "Tuple must combine several classes"
ann_types = [self.type_from_annotation(e) for e in ann.slice.value.elts]
ann_types = [self.type_from_annotation(e) for e in ann.slice.elts]
assert all(
isinstance(e, ClassType) for e in ann_types
), "Tuple must combine classes"
Expand Down Expand Up @@ -411,9 +408,6 @@ def visit_Subscript(self, node: Subscript) -> TypedSubscript:
"Dict",
"List",
]:
assert isinstance(
ts.slice, Index
), "Only single index slices for generic types are currently supported"
ts.value = ts.typ = self.type_from_annotation(ts)
return ts

Expand All @@ -423,34 +417,24 @@ def visit_Subscript(self, node: Subscript) -> TypedSubscript:
assert (
ts.value.typ.typ.typs
), "Accessing elements from the empty tuple is not allowed"
assert isinstance(
ts.slice, Index
), "Only single index slices for tuples are currently supported"
if all(ts.value.typ.typ.typs[0] == t for t in ts.value.typ.typ.typs):
ts.typ = ts.value.typ.typ.typs[0]
elif isinstance(ts.slice.value, Constant) and isinstance(
ts.slice.value.value, int
):
ts.typ = ts.value.typ.typ.typs[ts.slice.value.value]
elif isinstance(ts.slice, Constant) and isinstance(ts.slice.value, int):
ts.typ = ts.value.typ.typ.typs[ts.slice.value]
else:
raise TypeInferenceError(
f"Could not infer type of subscript of typ {ts.value.typ.__class__}"
)
elif isinstance(ts.value.typ.typ, ListType):
assert isinstance(
ts.slice, Index
), "Only single index slices for lists are currently supported"
ts.typ = ts.value.typ.typ.typ
ts.slice.value = self.visit(node.slice.value)
assert (
ts.slice.value.typ == IntegerInstanceType
), "List indices must be integers"
ts.slice = self.visit(node.slice)
assert ts.slice.typ == IntegerInstanceType, "List indices must be integers"
elif isinstance(ts.value.typ.typ, ByteStringType):
if isinstance(ts.slice, Index):
if not isinstance(ts.slice, Slice):
ts.typ = IntegerInstanceType
ts.slice.value = self.visit(node.slice.value)
ts.slice = self.visit(node.slice)
assert (
ts.slice.value.typ == IntegerInstanceType
ts.slice.typ == IntegerInstanceType
), "bytes indices must be integers"
elif isinstance(ts.slice, Slice):
ts.typ = ByteStringInstanceType
Expand Down

0 comments on commit b12f385

Please sign in to comment.