Skip to content

Commit

Permalink
[Python] struct serialization support (#482)
Browse files Browse the repository at this point in the history
struct serializer for python
  • Loading branch information
chaokunyang committed Jun 18, 2023
1 parent ca84f50 commit 1e366ed
Showing 1 changed file with 202 additions and 0 deletions.
202 changes: 202 additions & 0 deletions python/pyfury/_struct.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,202 @@
import datetime
import logging
import typing

from pyfury._serializer import NOT_SUPPORT_CROSS_LANGUAGE
from pyfury.buffer import Buffer
from pyfury.error import ClassNotCompatibleError
from pyfury.serializer import (
ListSerializer,
MapSerializer,
PickleSerializer,
Serializer,
)
from pyfury.type import (
TypeVisitor,
infer_field,
FuryType,
Int8Type,
Int16Type,
Int32Type,
Int64Type,
Float32Type,
Float64Type,
is_py_array_type,
compute_string_hash,
qualified_class_name,
)

# Module-level logger named after this module.
logger = logging.getLogger(__name__)


# Types that the ``visit_other`` methods in this module resolve to a
# serializer registered in ``fury.class_resolver``. Any other type (apart
# from Python array types, which are checked separately with
# ``is_py_array_type``) makes ``visit_other`` return ``None``.
basic_types = {
    bool,
    Int8Type,
    Int16Type,
    Int32Type,
    Int64Type,
    Float32Type,
    Float64Type,
    int,
    float,
    str,
    bytes,
    datetime.datetime,
    datetime.date,
    datetime.time,
}


class ComplexTypeVisitor(TypeVisitor):
    """Type visitor that maps a field's declared type to a serializer.

    Every ``visit_*`` method returns either a serializer instance or ``None``.
    NOTE(review): ``None`` presumably means "pick the serializer from the
    runtime value" -- confirm against the code consuming these results.
    """

    def __init__(self, fury):
        self.fury = fury

    def visit_list(self, field_name, elem_type, types_path=None):
        """Return a ``ListSerializer`` whose element serializer is inferred."""
        # Recurse through infer_field so nested generics such as
        # List[Dict[str, str]] get a serializer at every level.
        return ListSerializer(
            self.fury,
            list,
            infer_field("item", elem_type, self, types_path=types_path),
        )

    def visit_dict(self, field_name, key_type, value_type, types_path=None):
        """Return a ``MapSerializer`` with inferred key and value serializers."""
        # Recurse through infer_field so nested generics such as
        # Dict[str, Dict[str, str]] get a serializer at every level.
        # The key is inferred before the value.
        for_keys = infer_field("key", key_type, self, types_path=types_path)
        for_values = infer_field("value", value_type, self, types_path=types_path)
        return MapSerializer(self.fury, dict, for_keys, for_values)

    def visit_customized(self, field_name, type_, types_path=None):
        """Customized types never get a statically chosen serializer here."""
        return None

    def visit_other(self, field_name, type_, types_path=None):
        """Return the registered serializer for basic and array types.

        Returns ``None`` for any type that is neither in ``basic_types`` nor
        a Python array type.
        """
        has_static_serializer = type_ in basic_types or is_py_array_type(type_)
        if not has_static_serializer:
            return None
        resolved = self.fury.class_resolver.get_serializer(type_)
        # A pickle-based serializer is not expected for these types.
        assert not isinstance(resolved, (PickleSerializer,))
        return resolved


def _get_hash(fury_, field_names: list, type_hints: dict):
    """Compute the struct hash for the given fields.

    Each name in ``field_names`` is looked up in ``type_hints`` and fed, in
    order, through ``infer_field`` with a fresh ``StructHashVisitor``; the
    visitor's accumulated hash is returned.

    Raises:
        KeyError: if a field name has no entry in ``type_hints``.
        AssertionError: if the resulting hash is 0.
    """
    hasher = StructHashVisitor(fury_)
    for name in field_names:
        infer_field(name, type_hints[name], hasher, types_path=[])
    struct_hash = hasher.get_hash()
    # Callers in this module use 0 as their "hash not computed yet" sentinel,
    # so a real hash must never be 0.
    assert struct_hash != 0
    return struct_hash


class ComplexObjectSerializer(Serializer):
    """Serializer that writes an object as a struct hash plus its fields.

    The fields are the names returned by ``typing.get_type_hints`` for the
    class, handled in sorted name order. On write, a 32-bit hash from
    ``_get_hash`` is emitted first; on read, the stored hash must match the
    locally computed one or ``ClassNotCompatibleError`` is raised.
    """

    def __init__(self, fury_, clz: type, type_tag: str):
        super().__init__(fury_, clz)
        self._type_tag = type_tag
        hints = typing.get_type_hints(clz)
        self._type_hints = hints
        # Sorting gives writer and reader the same deterministic field order.
        self._field_names = sorted(hints.keys())
        type_visitor = ComplexTypeVisitor(fury_)
        # One entry per field, aligned with self._field_names.
        self._serializers = [
            infer_field(name, hints[name], type_visitor, types_path=[])
            for name in self._field_names
        ]
        # NOTE(review): imported locally, presumably to avoid a circular
        # import with pyfury._fury -- confirm.
        from pyfury._fury import Language

        if self.fury_.language == Language.PYTHON:
            logger.warning(
                "Type of class %s shouldn't be serialized using cross-language "
                "serializer",
                clz,
            )
        # 0 means "not computed yet"; the hash is filled in on first use.
        self._hash = 0

    def _ensure_hash(self):
        """Compute the struct hash on first use and return the cached value."""
        if self._hash == 0:
            self._hash = _get_hash(self.fury_, self._field_names, self._type_hints)
        return self._hash

    def get_cross_language_type_id(self):
        return FuryType.FURY_TYPE_TAG.value

    def get_cross_language_type_tag(self):
        return self._type_tag

    def write(self, buffer, value):
        """Delegate to ``cross_language_write``."""
        return self.cross_language_write(buffer, value)

    def read(self, buffer):
        """Delegate to ``cross_language_read``."""
        return self.cross_language_read(buffer)

    def cross_language_write(self, buffer: Buffer, value):
        """Write the struct hash, then every field of ``value`` in order."""
        buffer.write_int32(self._ensure_hash())
        for name, field_serializer in zip(self._field_names, self._serializers):
            self.fury_.cross_language_serialize_referencable(
                buffer, getattr(value, name), serializer=field_serializer
            )

    def cross_language_read(self, buffer):
        """Read and verify the struct hash, then rebuild the object.

        Raises:
            ClassNotCompatibleError: if the hash read from ``buffer`` differs
                from the hash computed for this class.
        """
        expected = self._ensure_hash()
        actual = buffer.read_int32()
        if actual != expected:
            raise ClassNotCompatibleError(
                f"Hash {actual} is not consistent with {expected} "
                f"for class {self.type_}",
            )
        # Allocate the instance without running its __init__.
        instance = self.type_.__new__(self.type_)
        # Register the instance before its fields are read.
        # NOTE(review): presumably so references back to it can resolve -- confirm.
        self.fury_.reference_resolver.reference(instance)
        for name, field_serializer in zip(self._field_names, self._serializers):
            decoded = self.fury_.cross_language_deserialize_referencable(
                buffer, serializer=field_serializer
            )
            setattr(instance, name, decoded)
        return instance


class StructHashVisitor(TypeVisitor):
    """Type visitor that folds each visited field type into a running hash.

    The hash starts at 17 and is updated once per visited field through
    ``_compute_field_hash``; read the result with ``get_hash``.
    """

    def __init__(self, fury):
        self.fury = fury
        self._hash = 17

    def _mix(self, id_):
        """Fold ``id_`` into the running hash."""
        self._hash = self._compute_field_hash(self._hash, id_)

    def visit_list(self, field_name, elem_type, types_path=None):
        # TODO add list element type to hash.
        list_serializer = ListSerializer(self.fury, list)
        self._mix(abs(list_serializer.get_cross_language_type_id()))

    def visit_dict(self, field_name, key_type, value_type, types_path=None):
        # TODO add map key/value type to hash.
        map_serializer = MapSerializer(self.fury, dict)
        self._mix(abs(map_serializer.get_cross_language_type_id()))

    def visit_customized(self, field_name, type_, types_path=None):
        """Hash a customized type by its type tag or qualified class name."""
        resolved = self.fury.class_resolver.get_serializer(type_)
        # Use the cross-language tag when the serializer has one; otherwise
        # fall back to the class's qualified name.
        tag = (
            resolved.get_cross_language_type_tag()
            if resolved.get_cross_language_type_id() != NOT_SUPPORT_CROSS_LANGUAGE
            else qualified_class_name(type_)
        )
        self._mix(compute_string_hash(tag))

    def visit_other(self, field_name, type_, types_path=None):
        """Hash a basic or array type by its cross-language type id."""
        is_known = type_ in basic_types or is_py_array_type(type_)
        if not is_known:
            # FIXME ignore unknown types for hash calculation
            return None
        resolved = self.fury.class_resolver.get_serializer(type_)
        assert not isinstance(resolved, (PickleSerializer,))
        type_id = resolved.get_cross_language_type_id()
        assert type_id is not None, resolved
        self._mix(abs(type_id))

    @staticmethod
    def _compute_field_hash(hash_, id_):
        """Combine ``hash_`` with ``id_`` and shrink the result below 2**31 - 1."""
        limit = 2**31 - 1
        combined = hash_ * 31 + id_
        # Floor-divide by 7 until the value is strictly below the limit.
        while combined >= limit:
            combined //= 7
        return combined

    def get_hash(self):
        """Return the hash accumulated so far."""
        return self._hash

0 comments on commit 1e366ed

Please sign in to comment.