Skip to content

Commit

Permalink
Merge pull request #1278 from bacox/feature/serial-ext
Browse files Browse the repository at this point in the history
Add serialization of lists of ints and booleans
  • Loading branch information
qstokkink committed Mar 4, 2024
2 parents 3ba7ccd + 6f58aa1 commit e768329
Show file tree
Hide file tree
Showing 4 changed files with 98 additions and 6 deletions.
5 changes: 4 additions & 1 deletion doc/reference/serialization.rst
Original file line number Diff line number Diff line change
Expand Up @@ -114,7 +114,10 @@ A ``Serializer`` can be extended with additional data types by calling ``seriali
"varlenI", "4 + ?", "str (length < 4294967295)"
"doublevarlenH", "2 + ?", "str (length ? < 65356)"
"payload", "2 + ?", "Serializable"

"payload-list", "?", "[Serializable]"
"arrayH-?", "2 + ? * 1", "[bool]"
"arrayH-q", "2 + ? * 8", "[int]"
"arrayH-d", "2 + ? * 8", "[float]"

Some of these data types represent common usage of serializable classes:

Expand Down
10 changes: 8 additions & 2 deletions ipv8/messaging/payload_dataclass.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,23 +17,29 @@ def type_from_format(fmt: str) -> TypeVar:
return out


def type_map(t: Type) -> FormatListType:
def type_map(t: Type) -> FormatListType: # noqa: PLR0911
if t is bool:
return "?"
if t is int:
return "q"
if t is float:
return "d"
if t is bytes:
return "varlenH"
if t is str:
return "varlenHutf8"
if isinstance(t, TypeVar):
return t.__name__
if getattr(t, '__origin__', None) in (tuple, list, set):
return [t.__args__[0]]
fmt = t.__args__[0]
if issubclass(fmt, Serializable):
return [fmt]
return f"arrayH-{type_map(t.__args__[0])}"
if isinstance(t, (tuple, list, set)) or Serializable in getattr(t, "mro", list)():
return cast(Type[Serializable], t)
raise NotImplementedError(t, " unknown")


def dataclass(cls: type | None = None, *, # noqa: PLR0913
init: bool = True,
repr: bool = True, # noqa: A002
Expand Down
42 changes: 40 additions & 2 deletions ipv8/messaging/serialization.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import abc
import socket
import typing
from array import array
from binascii import hexlify
from contextlib import suppress
from struct import Struct, pack, unpack_from
Expand Down Expand Up @@ -331,6 +332,40 @@ def unpack(self, data: bytes, offset: int, unpack_list: list, *args: object) ->
return offset


class DefaultArray(Packer):
"""
A format known to the ``array`` module (like 'I', 'B', etc.).
Also adds support for '?'.
"""

def __init__(self, format_str: str, length_format: str) -> None:
"""
Create a new packer for the given ``array`` format string.
"""
self.format_str = format_str
self.real_format_str = "B" if format_str == "?" else format_str
self.length_format = length_format
self.length_size = Struct(length_format).size
self.base = array(self.real_format_str).itemsize

def pack(self, data: list) -> bytes:
"""
Pack a list of items by forwarding them to ``array``.
"""
return pack(self.length_format, len(data)) + array(self.real_format_str, data).tobytes()

def unpack(self, data: bytes, offset: int, unpack_list: list, *args: object) -> int:
"""
Unpack a list of items from the known ``array`` format.
"""
str_length = unpack_from(self.length_format, data, offset)[0] * self.base
a = array(self.real_format_str)
a.frombytes(data[offset + self.length_size: offset + self.length_size + str_length])
unpack_list.append([bool(b) for b in a] if self.format_str == "?" else list(a))
return offset + self.length_size + str_length


class DefaultStruct(Packer):
"""
A format known to the ``struct`` module (like 'I', '20s', etc.).
Expand All @@ -351,7 +386,7 @@ def pack(self, *data: list) -> bytes:

def unpack(self, data: bytes, offset: int, unpack_list: list, *args: object) -> int:
"""
Unpack a list of items from a the known ``struct`` format.
Unpack a list of items from the known ``struct`` format.
"""
result = unpack_from(self.format_str, data, offset)
unpack_list.append(result if len(result) > 1 else result[0])
Expand Down Expand Up @@ -407,7 +442,10 @@ def __init__(self) -> None:
'varlenI': VarLen('>I'),
'doublevarlenH': VarLen('>H'),
'payload': NestedPayload(self),
'payload-list': ListOf(NestedPayload(self))
'payload-list': ListOf(NestedPayload(self)),
'arrayH-?': DefaultArray("?", "H"),
'arrayH-q': DefaultArray("q", "H"),
'arrayH-d': DefaultArray("d", "H"),
}

def get_available_formats(self) -> list[str]:
Expand Down
47 changes: 46 additions & 1 deletion ipv8/test/messaging/test_payload_dataclass.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,21 @@ class NestedListType:

a: List[NativeInt] # Backward compatibility: Python >= 3.9 can use ``list[NativeInt]``

@dataclass
class ListIntType:
"""
A single list of integers.
"""

a: List[int]

@dataclass
class ListBoolType:
"""
A single list of booleans.
"""

a: List[bool]

@ogdataclass
class Unknown:
Expand Down Expand Up @@ -161,6 +176,8 @@ class Everything:
d: EverythingItem
e: List[EverythingItem] # Backward compatibility: Python >= 3.9 can use ``list[EverythingItem]``
f: str
g: List[int]
h: List[bool]


class TestDataclassPayload(TestBase):
Expand Down Expand Up @@ -348,6 +365,26 @@ def test_nested_payload(self) -> None:
self.assertEqual(payload.a, NativeInt(42))
self.assertEqual(deserialized.a, NativeInt(42))

def test_native_intlist_payload(self) -> None:
"""
Check if a list of native types works correctly.
"""
payload = ListIntType([1, 2])
deserialized = self._pack_and_unpack(ListIntType, payload)

self.assertListEqual(payload.a, [1, 2])
self.assertListEqual(deserialized.a, [1, 2])

def test_native_boollist_payload(self) -> None:
"""
Check if a list of native types works correctly.
"""
payload = ListBoolType([True, False])
deserialized = self._pack_and_unpack(ListBoolType, payload)

self.assertListEqual(payload.a, [True, False])
self.assertListEqual(deserialized.a, [True, False])

def test_nestedlist_empty_payload(self) -> None:
"""
Check if an empty list of nested payloads works correctly.
Expand Down Expand Up @@ -416,7 +453,9 @@ def test_everything(self) -> None:
b'1337',
EverythingItem(True),
[EverythingItem(False), EverythingItem(True)],
"hi")
"hi",
[3, 4],
[False, True])

self.assertTrue(is_dataclass(a))

Expand All @@ -439,3 +478,9 @@ def test_everything(self) -> None:

self.assertEqual(a.f, "hi")
self.assertEqual(r.f, "hi")

self.assertEqual(a.g, [3, 4])
self.assertEqual(r.g, [3, 4])

self.assertEqual(a.h, [False, True])
self.assertEqual(r.h, [False, True])

0 comments on commit e768329

Please sign in to comment.