Skip to content

Commit

Permalink
Merge pull request #136 from OpShin/feat/to_cbor
Browse files Browse the repository at this point in the history
Implement to_cbor, mapping to serialiseData
  • Loading branch information
nielstron committed May 8, 2023
2 parents 4b4f116 + 0a29367 commit 778e2d3
Show file tree
Hide file tree
Showing 6 changed files with 280 additions and 449 deletions.
4 changes: 1 addition & 3 deletions opshin/tests/test_misc.py
Original file line number Diff line number Diff line change
Expand Up @@ -513,9 +513,7 @@ def validator(_: None) -> SomeOutputDatum:
ret = uplc_eval(f)
self.assertEqual(
ret,
uplc.data_from_cbor(
prelude.SomeOutputDatum(b"a").to_cbor(encoding="bytes")
),
uplc.data_from_cbor(prelude.SomeOutputDatum(b"a").to_cbor()),
"Wrapping to generic data failed",
)

Expand Down
69 changes: 68 additions & 1 deletion opshin/tests/test_stdlib.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,10 @@
from dataclasses import dataclass
import unittest

from hypothesis import example, given
from hypothesis import example, given, settings
from hypothesis import strategies as st
from uplc import ast as uplc, eval as uplc_eval
from pycardano import PlutusData

from .. import compiler

Expand Down Expand Up @@ -279,3 +281,68 @@ def validator(x: None) -> bool:
f = uplc.Apply(f, d)
ret = uplc_eval(f).value == 1
self.assertEqual(ret, x, "literal bool returned wrong value")

@given(st.integers(), st.binary())
@settings(deadline=None)
def test_plutusdata_to_cbor(self, x: int, y: bytes):
source_code = f"""
from opshin.prelude import *
@dataclass
class Test(PlutusData):
x: int
y: bytes
def validator(x: int, y: bytes) -> bytes:
return Test(x, y).to_cbor()
"""

@dataclass
class Test(PlutusData):
x: int
y: bytes

ast = compiler.parse(source_code)
code = compiler.compile(ast)
code = code.compile()
f = code.term
# UPLC lambdas may only take one argument at a time, so we evaluate by repeatedly applying
for d in [uplc.PlutusInteger(x), uplc.PlutusByteString(y)]:
f = uplc.Apply(f, d)
ret = uplc_eval(f).value
self.assertEqual(ret, Test(x, y).to_cbor(), "to_cbor returned wrong value")

@given(st.integers())
@settings(deadline=None)
def test_union_to_cbor(self, x: int):
source_code = f"""
from opshin.prelude import *
@dataclass
class Test(PlutusData):
CONSTR_ID = 1
x: int
y: bytes
@dataclass
class Test2(PlutusData):
x: int
def validator(x: int) -> bytes:
y: Union[Test, Test2] = Test2(x)
return y.to_cbor()
"""

@dataclass
class Test2(PlutusData):
x: int

ast = compiler.parse(source_code)
code = compiler.compile(ast)
code = code.compile()
f = code.term
# UPLC lambdas may only take one argument at a time, so we evaluate by repeatedly applying
for d in [uplc.PlutusInteger(x)]:
f = uplc.Apply(f, d)
ret = uplc_eval(f).value
self.assertEqual(ret, Test2(x).to_cbor(), "to_cbor returned wrong value")
76 changes: 51 additions & 25 deletions opshin/typed_ast.py
Original file line number Diff line number Diff line change
Expand Up @@ -108,6 +108,10 @@ def attribute_type(self, attr: str) -> Type:
for n, t in self.record.fields:
if n == attr:
return t
if attr == "to_cbor":
return InstanceType(
FunctionType(FrozenFrozenList([]), ByteStringInstanceType)
)
raise TypeInferenceError(
f"Type {self.record.name} does not have attribute {attr}"
)
Expand All @@ -120,18 +124,27 @@ def attribute(self, attr: str) -> plt.AST:
["self"],
plt.Constructor(plt.Var("self")),
)
attr_typ = self.attribute_type(attr)
pos = next(i for i, (n, _) in enumerate(self.record.fields) if n == attr)
# access to normal fields
return plt.Lambda(
["self"],
transform_ext_params_map(attr_typ)(
plt.NthField(
if attr in (n for n, t in self.record.fields):
attr_typ = self.attribute_type(attr)
pos = next(i for i, (n, _) in enumerate(self.record.fields) if n == attr)
# access to normal fields
return plt.Lambda(
["self"],
transform_ext_params_map(attr_typ)(
plt.NthField(
plt.Var("self"),
plt.Integer(pos),
),
),
)
if attr == "to_cbor":
return plt.Lambda(
["self", "_"],
plt.SerialiseData(
plt.Var("self"),
plt.Integer(pos),
),
),
)
)
raise NotImplementedError(f"Attribute {attr} not implemented for type {self}")

def cmp(self, op: cmpop, o: "Type") -> plt.AST:
"""The implementation of comparing this type to type o via operator op. Returns a lambda that expects as first argument the object itself and as second the comparison."""
Expand Down Expand Up @@ -213,6 +226,10 @@ def attribute_type(self, attr) -> "Type":
)
# return Anytype
return InstanceType(AnyType())
if attr == "to_cbor":
return InstanceType(
FunctionType(FrozenFrozenList([]), ByteStringInstanceType)
)
raise TypeInferenceError(
f"Can not access attribute {attr} of Union type. Cast to desired type with an 'if isinstance(_, _):' branch."
)
Expand All @@ -225,24 +242,33 @@ def attribute(self, attr: str) -> plt.AST:
plt.Constructor(plt.Var("self")),
)
# iterate through all names/types of the unioned records by position
attr_typ = self.attribute_type(attr)
pos = next(
i
for i, (ns, _) in enumerate(
map(lambda x: zip(*x), zip(*(t.record.fields for t in self.typs)))
if any(attr in (n for n, t in r.record.fields) for r in self.typs):
attr_typ = self.attribute_type(attr)
pos = next(
i
for i, (ns, _) in enumerate(
map(lambda x: zip(*x), zip(*(t.record.fields for t in self.typs)))
)
if all(n == attr for n in ns)
)
if all(n == attr for n in ns)
)
# access to normal fields
return plt.Lambda(
["self"],
transform_ext_params_map(attr_typ)(
plt.NthField(
# access to normal fields
return plt.Lambda(
["self"],
transform_ext_params_map(attr_typ)(
plt.NthField(
plt.Var("self"),
plt.Integer(pos),
),
),
)
if attr == "to_cbor":
return plt.Lambda(
["self", "_"],
plt.SerialiseData(
plt.Var("self"),
plt.Integer(pos),
),
),
)
)
raise NotImplementedError(f"Attribute {attr} not implemented for type {self}")

def __ge__(self, other):
if isinstance(other, UnionType):
Expand Down
2 changes: 1 addition & 1 deletion opshin/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -565,7 +565,7 @@ def data_from_json(j: typing.Dict[str, typing.Any]) -> uplc.PlutusData:


def datum_to_cbor(d: pycardano.Datum) -> bytes:
return pycardano.PlutusData.to_cbor(d, encoding="bytes")
return pycardano.PlutusData.to_cbor(d)


def datum_to_json(d: pycardano.Datum) -> str:
Expand Down
Loading

0 comments on commit 778e2d3

Please sign in to comment.