Dataclass wrapper for ipv8 Payload #1011

Closed
grimadas opened this issue Jul 1, 2021 · 17 comments · Fixed by #1045

@grimadas

grimadas commented Jul 1, 2021

Request for dataclass-based message payloads.

Since message payloads are used very frequently, I suggest we simplify the creation and usage of Payloads.
Python data classes might be a promising direction.

Here is an example:
[image]

The payload is a wrapper:


class LazierPayload:
    serializer = default_serializer

    @staticmethod
    def _type_map(t: Type) -> str:
        if t == int:
            return "Q"
        elif t == bytes:
            return "varlenH"
        elif "Tuple" in str(t) or "List" in str(t) or "Set" in str(t):
            return (
                "varlenH-list"
                if "int" in str(t) or "bytes" in str(t)
                else [typing.get_args(t)[0]]
            )
        elif hasattr(t, "format_list"):
            return t
        else:
            raise NotImplementedError(t, " unknown")

    @classmethod
    def init_class(cls):
        # Copy all methods of VariablePayload except init
        d = {
            k: v
            for k, v in VariablePayload.__dict__.items()
            if not str(k).startswith("__")
            and str(k) != "names"
            and str(k) != "format_list"
        }

        for (name, method) in d.items():
            setattr(cls, name, method)
        # Populate names and format list
        fields = get_cls_fields(cls)

        for f, t in fields.items():
            cls.names.append(f)
            cls.format_list.append(cls._type_map(t))
        return cls

    def is_optimal_size(self):
        return check_size_limit(self.to_bytes())

    def to_bytes(self) -> bytes:
        return self.serializer.pack_serializable(self)

    @classmethod
    def from_bytes(cls, pack: bytes) -> "LazierPayload":
        return cls.serializer.unpack_serializable(cls, pack)[0]

def payload(cls):
    d = {k: v for k, v in LazierPayload.__dict__.items() if not str(k).startswith("__")}
    for k, v in d.items():
        setattr(cls, k, v)
    cls.names = list()
    cls.format_list = list()

    # Populate all by mro
    added_classes = set()
    new_mro = []
    has_imp_ser = False

    for superclass in cls.__mro__:
        if superclass == ImpSer:
            has_imp_ser = True
        if hasattr(superclass, "names"):
            cls.names.extend(superclass.names)
            cls.format_list.extend(superclass.format_list)

    new_mro.append(ImpSer)
    if ImpSer not in added_classes:
        added_classes.add(ImpSer)

    if not has_imp_ser:
        new_vals = tuple([ImpSer] + list(cls.__bases__))
        new_cls = type(cls.__name__, new_vals, dict(cls.__dict__))
    else:
        new_cls = cls

    return new_cls.init_class()

From the developer's point of view, the dataclass would act as a VariablePayload.

Downside: we have to write a mapping from Python types to IPv8 struct format types.

@grimadas grimadas added the priority: medium Enhancements, features or exotic bugs label Jul 1, 2021
@qstokkink
Collaborator

Thanks for the suggestion. This is definitely something that fits the IPv8 core library.

I spent some time minimizing your code to make it smaller (combining your @dataclass and @payload) and slightly faster:

Wrapper
from dataclasses import dataclass
from functools import partial
from typing import Type, get_args, get_type_hints

from ipv8.messaging.serialization import Serializable


def type_map(t: Type) -> str:
    if t == int:
        return "Q"
    elif t == bytes:
        return "varlenH"
    elif "Tuple" in str(t) or "List" in str(t) or "Set" in str(t):
        return (
            "varlenH-list"
            if "int" in str(t) or "bytes" in str(t)
            else [get_args(t)[0]]
        )
    elif hasattr(t, "format_list"):
        return t
    else:
        raise NotImplementedError(t, " unknown")


def dataclass_payload(cls=None, /, *, init=True, repr=True, eq=True, order=False, unsafe_hash=False, frozen=False):
    """
    Equivalent to ``@dataclass``, but also makes the wrapped class a ``Serializable``.

    See ``dataclasses.dataclass`` for argument descriptions.
    """
    if cls is None:
        return partial(dataclass_payload, init=init, repr=repr, eq=eq, order=order, unsafe_hash=unsafe_hash,
                       frozen=frozen)
    origin = dataclass(cls, init=init, repr=repr, eq=eq, order=order, unsafe_hash=unsafe_hash, frozen=frozen)

    class DataClassPayload(origin, Serializable):
        names = list(get_type_hints(origin).keys())
        format_list = list(map(type_map, get_type_hints(origin).values()))

        @classmethod
        def from_unpack_list(cls, *args):
            return DataClassPayload(*args)

        def to_pack_list(self):
            return [(self.format_list[i], getattr(self, self.names[i])) for i in range(len(self.names))]

    DataClassPayload.__name__ = origin.__name__
    DataClassPayload.__qualname__ = origin.__qualname__
    return DataClassPayload
Benchmark code
from dataclasses import asdict, dataclass, is_dataclass
from time import time

from testdc_bulat import payload
from testdc_quinten import dataclass_payload
from ipv8.messaging.serialization import Serializable, Payload, default_serializer


@payload
@dataclass
class BaseDataclass:
    other: bytes


@dataclass
class BaseDataclass2:
    other: bytes


t_load_1_start = time()


@payload
@dataclass(unsafe_hash=True)
class CellState(BaseDataclass):
    cell_additive: int
    cells: bytes


t_load_1_end = time()

t_init_1_start = time()
c1 = CellState(b"def", 1, b"abc")
t_init_1_end = time()

t_pack_1_start = time()
raw = default_serializer.pack_serializable(c1)
t_pack_1_end = time()

t_unpack_1_start = time()
default_serializer.unpack_serializable(CellState, raw)
t_unpack_1_end = time()

print(raw)
print(default_serializer.unpack_serializable(CellState, raw)[0])
print(f"{is_dataclass(c1)=}, {isinstance(c1, Serializable)=}, {isinstance(c1, Payload)=}")


t_load_2_start = time()


@dataclass_payload(unsafe_hash=True)
class CellState2(BaseDataclass2):
    cell_additive: int
    cells: bytes


t_load_2_end = time()

t_init_2_start = time()
c2 = CellState2(b"def", 1, b"abc")
t_init_2_end = time()

t_pack_2_start = time()
raw = default_serializer.pack_serializable(c2)
t_pack_2_end = time()

t_unpack_2_start = time()
default_serializer.unpack_serializable(CellState2, raw)
t_unpack_2_end = time()

print(raw)
print(default_serializer.unpack_serializable(CellState2, raw)[0])
print(f"{is_dataclass(c2)=}, {isinstance(c2, Serializable)=}, {isinstance(c2, Payload)=}")

assert asdict(c1) == asdict(c2)

print("=== RESULTS 1 ===")
print(f"Time load: {t_load_1_end-t_load_1_start} seconds")
print(f"Time init: {t_init_1_end-t_init_1_start} seconds")
print(f"Time pack: {t_pack_1_end-t_pack_1_start} seconds")
print(f"Time total: {(t_load_1_end-t_load_1_start+t_init_1_end-t_init_1_start+t_pack_1_end-t_pack_1_start+t_unpack_1_end-t_unpack_1_start)} seconds")
print(" - Important -")
print(f"Time unpack: {t_unpack_1_end-t_unpack_1_start} seconds")
print(f"Time init + pack: {(t_init_1_end-t_init_1_start+t_pack_1_end-t_pack_1_start)} seconds")

print("=== RESULTS 2 ===")
print(f"Time load: {t_load_2_end-t_load_2_start} seconds")
print(f"Time init: {t_init_2_end-t_init_2_start} seconds")
print(f"Time pack: {t_pack_2_end-t_pack_2_start} seconds")
print(f"Time total: {(t_load_2_end-t_load_2_start+t_init_2_end-t_init_2_start+t_pack_2_end-t_pack_2_start+t_unpack_2_end-t_unpack_2_start)} seconds")
print(" - Important -")
print(f"Time unpack: {t_unpack_2_end-t_unpack_2_start} seconds")
print(f"Time init + pack: {(t_init_2_end-t_init_2_start+t_pack_2_end-t_pack_2_start)} seconds")

If anyone thinks they can make this even smaller and faster, be my guest. This is probably where I'll leave it.

@qstokkink
Collaborator

By the way, I think this is the most controversial choice in this suggestion:

if t == int:
    return "Q"

All ints will now be unsigned long long. Perhaps we should use q for generality instead (signed long long).

I'd like to hear feedback on this.
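
To make the trade-off concrete, here is a minimal sketch that uses only Python's standard struct module rather than the IPv8 Serializer. The big-endian prefix ">" is an assumption made for illustration.

```python
import struct

# "Q" is an unsigned 64-bit integer: negative values cannot be packed.
try:
    struct.pack(">Q", -1)
    unsigned_accepts_negative = True
except struct.error:
    unsigned_accepts_negative = False

# "q" is a signed 64-bit integer: negative values round-trip.
packed = struct.pack(">q", -1)
(unpacked,) = struct.unpack(">q", packed)

# The price of "q" is the upper half of the unsigned range.
try:
    struct.pack(">q", 2**63)
    signed_accepts_large = True
except struct.error:
    signed_accepts_large = False

print(unsigned_accepts_negative, unpacked, signed_accepts_large)
```

Both formats occupy 8 bytes, so the choice only affects which value range is accepted, not the message size.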

@qstokkink
Collaborator

Summary of offline discussion:

  1. int should be q.
  2. The Type-to-Serializer-format mapping should be in a dict that can be overwritten per dataclass.
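
The two points above could be sketched roughly as follows. This is not the prototype's code: DEFAULT_TYPE_MAP, formats_for and the type_map class attribute are hypothetical names chosen for this illustration, and only the class's own annotations are inspected (inheritance is ignored).

```python
from typing import Any, Dict, List

# Hypothetical default mapping from Python types to Serializer format names.
DEFAULT_TYPE_MAP: Dict[Any, str] = {
    int: "q",
    bytes: "varlenH",
}


def formats_for(cls: type) -> List[str]:
    """Look up a format for each annotated field; the class may override defaults."""
    # `type_map` is an assumed, unannotated class attribute used for overrides.
    mapping = {**DEFAULT_TYPE_MAP, **getattr(cls, "type_map", {})}
    formats = []
    for name, annotation in cls.__annotations__.items():
        if annotation not in mapping:
            raise NotImplementedError(f"no format known for field {name!r}: {annotation!r}")
        formats.append(mapping[annotation])
    return formats


class Default:
    a: int
    b: bytes


class LegacyWire:
    type_map = {int: "I"}  # override: keep an older 32-bit wire format
    a: int
    b: bytes


print(formats_for(Default))
print(formats_for(LegacyWire))
```

Keeping the override on the class means a payload that must stay wire-compatible can opt out of the new default without affecting other payloads.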

@drew2a
Contributor

drew2a commented Oct 6, 2021

I'm not an expert in python black magic, but I can help with any other work on this feature.

@qstokkink
Collaborator

@drew2a thanks for offering to help, I'll fire up the latest prototype to see what still needs to be done in order to finish this. I'll follow up with the findings.

[Admin notice: unassigning @grimadas due to scheduling constraints.]

@qstokkink qstokkink assigned qstokkink and unassigned grimadas Oct 6, 2021
@qstokkink
Collaborator

qstokkink commented Oct 6, 2021

I double-checked the black magic and corrected it a bit to allow for all of the possible strange stuff you can do with inheritance and nesting. The WIP is available here https://github.com/qstokkink/py-ipv8/tree/add_dcpayload

Example usage from the test:

varlenH = TypeVar('varlenH')  # Can be any format string that the Serializer can handle


@dataclass_payload
class A:
    @dataclass_payload
    class Item:
        a: bool

    a: int
    b: bytes
    c: varlenH
    d: Item
    e: [Item]
    f: str
    g: List[Item]
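
For readers wondering how annotations such as [Item] and List[Item] can be told apart from plain types without string matching, here is a minimal sketch using typing.get_origin and typing.get_args. The helper name element_type is hypothetical and this is not necessarily how the prototype branch does it.

```python
from typing import List, get_args, get_origin


def element_type(annotation):
    """Return the element type of a list-like annotation, or None otherwise."""
    # Literal list form, e.g. `e: [Item]`.
    if isinstance(annotation, list) and len(annotation) == 1:
        return annotation[0]
    # Generic alias form, e.g. `g: List[Item]`.
    if get_origin(annotation) is list:
        return get_args(annotation)[0]
    return None


class Item:
    pass


class A:
    a: int
    e: [Item]
    g: List[Item]


for name, annotation in A.__annotations__.items():
    print(name, element_type(annotation))
```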

In order to get this to production, the following still needs to be done:

  • Write(/refactor into) unit tests for payload operations.
  • Write unit tests for dataclass operations.
  • Write documentation.
  • Sanity check that everything still works when hooked into a Community.

@drew2a If you have time, would you like to perform a sanity check using the prototype branch?

@drew2a
Contributor

drew2a commented Oct 6, 2021

Sure!

@qstokkink
Collaborator

qstokkink commented Oct 6, 2021

Thanks, I'll leave that to you and focus on the documentation myself.
Edit: done.

@drew2a
Contributor

drew2a commented Oct 6, 2021

First of all: thank you for your work!
It is an amazing feeling to declare a dataclass-like object in IPv8 👍

Community checks

I've tried to test this branch on PopularityCommunity.

Here is the code: https://github.com/Tribler/tribler/pull/6427/files

Conclusion: it works in local tests, but doesn't work on the wild network.
(maybe I just implemented it incorrectly)

Stacktrace:

[PID:40930] 2021-10-06 21:30:56,289 - ERROR <community:435> PopularityCommunity.on_packet(): Exception occurred while handling packet!
Traceback (most recent call last):
  File "/Users/<user>/Projects/github.com/Tribler/tribler/src/pyipv8/ipv8/messaging/serialization.py", line 392, in unpack_serializable
    offset = self._packers[fmt].unpack(data, offset, unpack_list)
TypeError: unhashable type: 'list'

During handling of the above exception, another exception occurred:

Traceback (most recent call last):
  File "/Users/<user>/Projects/github.com/Tribler/tribler/src/pyipv8/ipv8/messaging/serialization.py", line 392, in unpack_serializable
    offset = self._packers[fmt].unpack(data, offset, unpack_list)
  File "/Users/<user>/Projects/github.com/Tribler/tribler/src/pyipv8/ipv8/messaging/serialization.py", line 255, in unpack
    result = unpack_from(self.format_str, data, offset)
struct.error: unpack_from requires a buffer of at least 35108 bytes for unpacking 8 bytes at offset 35100 (actual buffer size is 395)

The above exception was the direct cause of the following exception:

Traceback (most recent call last):
  File "/Users/<user>/Projects/github.com/Tribler/tribler/src/pyipv8/ipv8/community.py", line 431, in on_packet
    result = handler(source_address, data)
  File "/Users/<user>/Projects/github.com/Tribler/tribler/src/pyipv8/ipv8/lazy_community.py", line 80, in wrapper
    unpacked = self.serializer.unpack_serializable_list(payloads, remainder, offset=23)
  File "/Users/<user>/Projects/github.com/Tribler/tribler/src/pyipv8/ipv8/messaging/serialization.py", line 420, in unpack_serializable_list
    payload, offset = self.unpack_serializable(serializable, data, offset)
  File "/Users/<user>/Projects/github.com/Tribler/tribler/src/pyipv8/ipv8/messaging/serialization.py", line 400, in unpack_serializable
    offset = self._packers['payload-list'].unpack(data, offset, unpack_list, fmt[0])
  File "/Users/<user>/Projects/github.com/Tribler/tribler/src/pyipv8/ipv8/messaging/serialization.py", line 241, in unpack
    offset = self.packer.unpack(data, offset, result, *args)
  File "/Users/<user>/Projects/github.com/Tribler/tribler/src/pyipv8/ipv8/messaging/serialization.py", line 86, in unpack
    unpacked, offset = self.serializer.unpack_serializable(serializable_class, data, offset=offset + 2)
  File "/Users/<user>/Projects/github.com/Tribler/tribler/src/pyipv8/ipv8/messaging/serialization.py", line 402, in unpack_serializable
    raise PackError("Could not unpack item: %s\n%s: %s" % (fmt, type(e).__name__, str(e))) from e
ipv8.messaging.serialization.PackError: Could not unpack item: q

Notes (not related to the wild network):

msg_id

If I specify msg_id with a type, e.g.:

timestamp_type = TypeVar('Q')


@dataclass_payload
class TorrentsHealthPayload:


    @dataclass_payload
    class Torrent:
        infohash: bytes
        seeders: int
        leechers: int
        timestamp: timestamp_type

    random_torrents_length: int
    torrents_checked_length: int
    random_torrents: List[Torrent]
    torrents_checked: List[Torrent]

    msg_id:int = 1

The following error occurs:

>       return TorrentsHealthPayload(
            random_torrents_length=len(random_torrents_checked),
            torrents_checked_length=len(popular_torrents_checked),
            random_torrents=to_list(random_torrents_checked),
            torrents_checked=to_list(popular_torrents_checked)
        )
E       TypeError: __init__() missing 1 required positional argument: 'msg_id'

Hints don't work

Here is the dataclass_payload example:
[image]

Here is just an ordinary dataclass:
[image]

@dataclass
class DataClass:
    any: int
    data: str

@qstokkink
Collaborator

@drew2a that was fast, thanks for the quick feedback 👍

I see the following going "wrong" (i.e., a mismatch between your expectations and the implementation - so this should change):

  • Right now the msg_id has to be the last thing defined (otherwise @dataclass will interpret it as the first argument of an instance instead of a class field).
  • Because the msg_id is typed, it is assumed to be part of the dataclass.
  • It seems my IDE uses a dirty workaround to determine what is and isn't a dataclass (it looks at the wrapper name, not at whether the class is actually a dataclass). You can get hinting either by importing the wrapper under the standard name (from ipv8.messaging.payload_dataclass import dataclass_payload as dataclass) or by using both @dataclass_payload and @dataclass on your class. Both are... suboptimal.
  • If you're working with the wild network the error you're getting is probably because int translates to q instead of I, making the wire formats incompatible.
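
The second bullet follows from how the standard library's @dataclass treats class attributes. A minimal sketch using only the standard library (it does not reproduce the prototype's exact error; the class names are made up for this example):

```python
from dataclasses import dataclass, fields
from typing import ClassVar


@dataclass
class Annotated:
    a: int
    msg_id: int = 1  # annotated: becomes a dataclass field and an __init__ argument


@dataclass
class Unannotated:
    a: int
    msg_id = 1  # not annotated: stays a plain class attribute


@dataclass
class WithClassVar:
    a: int
    msg_id: ClassVar[int] = 1  # ClassVar: explicitly excluded from the fields


def field_names(cls):
    """List the names that @dataclass treats as instance fields."""
    return [f.name for f in fields(cls)]


print(field_names(Annotated))
print(field_names(Unannotated))
print(field_names(WithClassVar))
```

Only the first class ends up with msg_id as a field, which is why a typed msg_id is picked up as part of the payload.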

Please let me know what you think of these possible changes to address these issues:

  1. Instead of adding a msg_id to the dataclass field, add a msg_id argument to the dataclass_payload. This would take care of both of the msg_id related issues. The downside of this would be that this is different from how all other Payloads define the message identifier.
  2. Rename dataclass_payload to dataclass to trick the IDE into type hinting. We're intentionally causing a namespace conflict though: this is a bit dirty.

I don't have a solution for the q vs I format. Without the need for backward compatibility this problem wouldn't exist. Perhaps this is not something to be "solved" and simply something to accept if we want backward compatibility.
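
To illustrate why the two formats cannot coexist on the wire, here is a small sketch using only the standard struct module. The big-endian layout and the two-field message are assumptions for the example, not the actual PopularityCommunity message.

```python
import struct

# A legacy sender writes two 32-bit unsigned integers ("I"), 8 bytes in total.
legacy = struct.pack(">II", 1, 2)

# A receiver that expects a 64-bit signed integer ("q") swallows both fields at once.
(misread,) = struct.unpack_from(">q", legacy, 0)

# Reading the second expected "q" fails because the buffer is already exhausted.
try:
    struct.unpack_from(">q", legacy, 8)
    second_read_ok = True
except struct.error:
    second_read_ok = False

print(len(legacy), misread, second_read_ok)
```

Once one field is misread, every later offset is wrong as well, which is consistent with the buffer-size error in the stack trace above.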

@qstokkink
Collaborator

qstokkink commented Oct 7, 2021

@drew2a I pushed two new commits to the repo that implement the possible changes from my previous post.

[5b88423]
Allows you to pass a msg_id to the wrapper. This lets you define message identifiers like this:

@dataclass_payload(msg_id=53)
class MyMessage:
    a: int
    b: bytes
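
For readers curious how a decorator can accept such an argument, the following is a minimal standard-library sketch. It is not the code from the commit: dataclass_with_msg_id is a hypothetical name and the sketch leaves out everything related to serialization.

```python
from dataclasses import dataclass, is_dataclass
from functools import partial


def dataclass_with_msg_id(cls=None, *, msg_id=None, **dataclass_kwargs):
    """Hypothetical wrapper: apply @dataclass and attach msg_id as a class attribute."""
    if cls is None:
        # Used with arguments: return a decorator that still expects the class.
        return partial(dataclass_with_msg_id, msg_id=msg_id, **dataclass_kwargs)
    wrapped = dataclass(cls, **dataclass_kwargs)
    if msg_id is not None:
        # Set after @dataclass has run, so msg_id never becomes an __init__ argument.
        wrapped.msg_id = msg_id
    return wrapped


@dataclass_with_msg_id(msg_id=53)
class MyMessage:
    a: int
    b: bytes


message = MyMessage(1, b"abc")
print(message, MyMessage.msg_id)
```

Because msg_id is attached after the dataclass machinery has run, its position relative to the other fields no longer matters.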

[b92a8e8]
Renames dataclass_payload to dataclass. This lets you enable type hinting like this:

from dataclasses import dataclass
from ipv8.messaging.payload_dataclass import dataclass

Please let me know if you think these are improvements, or if we should throw these changes out.

@drew2a
Contributor

drew2a commented Oct 7, 2021

I'll check it soon :)

@drew2a
Contributor

drew2a commented Oct 7, 2021

Rename dataclass_payload to dataclass to trick the IDE into type hinting. We're intentionally causing a namespace conflict though: this is a bit dirty.

@qstokkink pydantic does something very similar to your approach: https://github.com/samuelcolvin/pydantic/blob/master/pydantic/dataclasses.py#L218-L256

@drew2a
Contributor

drew2a commented Oct 7, 2021

msg_id

This improvement works and looks perfect to me:

@dataclass(msg_id=1)
class TorrentsHealthPayload:
    @dataclass
    class Torrent:
        infohash: bytes
        seeders: int
        leechers: int
        timestamp: timestamp_type

    random_torrents_length: int
    torrents_checked_length: int
    random_torrents: List[Torrent]
    torrents_checked: List[Torrent]

dataclass

This improvement also looks cool, but it doesn't solve the hint problem (at least for me):
[image]

It is probably not related to these improvements, but I see the following warning:
[image]

All my changes are located here: https://github.com/Tribler/tribler/pull/6427/files#diff-ada92791a877316681cd06afcbf338be4bee5b61357522ea75c700206f3e098f

@qstokkink
Collaborator

Thanks for checking, I'll keep the msg_id change and I'll try to fix the dataclass hinting.

After restarting my IDE my type hinting also no longer worked: PyCharm seems to somehow cache the type hints if you've first run it with a normal dataclass and then with the replacement dataclass 😕

@qstokkink
Collaborator

qstokkink commented Oct 7, 2021

🤔 I have it to the point where it (PyCharm) checks the types:

[image: type checking]

But it doesn't suggest the names:

[image: hints]

@qstokkink
Collaborator

🤦‍♂️ Never mind, Ctrl + Q is not something normal dataclasses support either. My bad. Works as expected:

[image: derp]

Well, I guess everything works then.
