Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Enhance nativescript.py type hint #129

Merged
merged 7 commits into from Nov 13, 2022
101 changes: 46 additions & 55 deletions pycardano/nativescript.py
Expand Up @@ -3,7 +3,7 @@
from __future__ import annotations

from dataclasses import dataclass, field
from typing import ClassVar, List, Type, Union
from typing import ClassVar, List, Type, Union, cast

from nacl.encoding import RawEncoder
from nacl.hash import blake2b
Expand Down Expand Up @@ -35,25 +35,37 @@ def from_primitive(
) -> Union[
ScriptPubkey, ScriptAll, ScriptAny, ScriptNofK, InvalidBefore, InvalidHereAfter
]:
script_type = value[0]
for t in [
ScriptPubkey,
ScriptAll,
ScriptAny,
ScriptNofK,
InvalidBefore,
InvalidHereAfter,
]:
if t._TYPE == script_type:
return super(NativeScript, t).from_primitive(value[1:])
if not isinstance(
value,
(
list,
tuple,
),
):
raise DeserializeException(
f"A list or a tuple is required for deserialization: {str(value)}"
)

script_type: int = value[0]
if script_type == ScriptPubkey._TYPE:
return super(NativeScript, ScriptPubkey).from_primitive(value[1:])
elif script_type == ScriptAll._TYPE:
return super(NativeScript, ScriptAll).from_primitive(value[1:])
elif script_type == ScriptAny._TYPE:
return super(NativeScript, ScriptAny).from_primitive(value[1:])
elif script_type == ScriptNofK._TYPE:
return super(NativeScript, ScriptNofK).from_primitive(value[1:])
elif script_type == InvalidBefore._TYPE:
return super(NativeScript, InvalidBefore).from_primitive(value[1:])
elif script_type == InvalidHereAfter._TYPE:
return super(NativeScript, InvalidHereAfter).from_primitive(value[1:])
else:
raise DeserializeException(f"Unknown script type indicator: {script_type}")

def hash(self) -> ScriptHash:
cbor_bytes = cast(bytes, self.to_cbor("bytes"))
return ScriptHash(
blake2b(
bytes(1) + self.to_cbor("bytes"), SCRIPT_HASH_SIZE, encoder=RawEncoder
)
blake2b(bytes(1) + cbor_bytes, SCRIPT_HASH_SIZE, encoder=RawEncoder)
)

@classmethod
Expand All @@ -63,43 +75,16 @@ def from_dict(
ScriptPubkey, ScriptAll, ScriptAny, ScriptNofK, InvalidBefore, InvalidHereAfter
]:
"""Parse a standard native script dictionary (potentially parsed from a JSON file)."""

types = {
p.json_tag: p
for p in [
ScriptPubkey,
ScriptAll,
ScriptAny,
ScriptNofK,
InvalidBefore,
InvalidHereAfter,
]
}
script_type = script_json["type"]
target_class = types[script_type]
script_primitive = cls._script_json_to_primitive(script_json)
return super(NativeScript, target_class).from_primitive(script_primitive[1:])
return cls.from_primitive(script_primitive)

@classmethod
def _script_json_to_primitive(
cls: Type[NativeScript], script_json: JsonDict
) -> List[Primitive]:
"""Serialize a standard JSON native script into a primitive array"""

types = {
p.json_tag: p
for p in [
ScriptPubkey,
ScriptAll,
ScriptAny,
ScriptNofK,
InvalidBefore,
InvalidHereAfter,
]
}

script_type: str = script_json["type"]
native_script = [types[script_type]._TYPE]
native_script: List[Primitive] = [JSON_TAG_TO_INT[script_type]]

for key, value in script_json.items():
if key == "type":
Expand All @@ -118,22 +103,18 @@ def _script_jsons_to_primitive(
native_script = [cls._script_json_to_primitive(i) for i in script_jsons]
return native_script

def to_dict(self) -> dict:
def to_dict(self) -> JsonDict:
"""Export to standard native script dictionary (potentially to dump to a JSON file)."""

script = {}

script: JsonDict = {}
for value in self.__dict__.values():
script["type"] = self.json_tag

if isinstance(value, list):
script["scripts"] = [i.to_dict() for i in value]

elif isinstance(value, int):
script[self.json_field] = value
else:
if isinstance(value, int):
script[self.json_field] = value
else:
script[self.json_field] = str(value)
script[self.json_field] = str(value)

return script

Expand Down Expand Up @@ -209,7 +190,7 @@ class InvalidBefore(NativeScript):
json_field: ClassVar[str] = "slot"
_TYPE: int = field(default=4, init=False)

before: int = None
before: int


@dataclass
Expand All @@ -218,4 +199,14 @@ class InvalidHereAfter(NativeScript):
json_field: ClassVar[str] = "slot"
_TYPE: int = field(default=5, init=False)

after: int = None
after: int


JSON_TAG_TO_INT = {
ScriptPubkey.json_tag: ScriptPubkey._TYPE,
ScriptAll.json_tag: ScriptAll._TYPE,
ScriptAny.json_tag: ScriptAny._TYPE,
ScriptNofK.json_tag: ScriptNofK._TYPE,
InvalidBefore.json_tag: InvalidBefore._TYPE,
InvalidHereAfter.json_tag: InvalidHereAfter._TYPE,
}
1 change: 0 additions & 1 deletion pyproject.toml
Expand Up @@ -75,7 +75,6 @@ exclude = [
'^pycardano/key.py$',
'^pycardano/logging.py$',
'^pycardano/metadata.py$',
'^pycardano/nativescript.py$',
'^pycardano/plutus.py$',
'^pycardano/transaction.py$',
'^pycardano/txbuilder.py$',
Expand Down
7 changes: 6 additions & 1 deletion test/pycardano/test_nativescript.py
Expand Up @@ -2,7 +2,7 @@

import pytest

from pycardano.exception import InvalidArgumentException
from pycardano.exception import DeserializeException, InvalidArgumentException
from pycardano.key import VerificationKey
from pycardano.nativescript import (
InvalidBefore,
Expand Down Expand Up @@ -163,6 +163,11 @@ def test_to_dict():
assert NativeScript.from_dict(script_dict) == script_nofk


def test_from_primitive_invalid_primitive_input():
with pytest.raises(DeserializeException):
NativeScript.from_primitive(1)


def test_from_dict():

vk1 = VerificationKey.from_cbor(
Expand Down