In [189]:
import dataclasses
from dataclasses import dataclass
from typing import IO, Any
from pathlib import Path
import functools

import numpy as np
from numpy.typing import NDArray

import h5py

In [243]:
FileType = str | Path | IO[bytes]


def _is_primitive(obj_or_class: Any):
    primitives = (int, float, str)
    if type(obj_or_class) == type:
        return  obj_or_class in primitives
    else:
        return isinstance(obj_or_class, primitives)


def _fields(T: Any) -> dict[str, type]:
    # TODO: two separate implementations - one for dataclass, one for pydantic model
    ret: dict[str, type] = {}
    for field in dataclasses.fields(T):
        ret[field.name] = field.type
    return ret


def serialisable(cls):
    serialisable_attrs = _fields(cls)

    def serialise(self, output: FileType | h5py.File | h5py.Group, mode: str = "w"):
        # TODO: replace mode with ADD/REPLACE/NEW?
        h5 = (
            output
            if isinstance(output, (h5py.File, h5py.Group))
            else h5py.File(output, mode)
        )

        for attr, _ in serialisable_attrs.items():
            val = getattr(self, attr)
            if val is None:
                continue

            if _is_primitive(val):
                h5.attrs[attr] = val
            # TODO: elif dict/list -- json? check size?
            elif isinstance(val, np.ndarray):
                h5.create_dataset(attr, data=val)
            elif getattr(val, "__is_serialisable", False):
                grp = h5.create_group(attr)
                val.serialise(output=grp)
            else:
                raise Exception(f"Unsupported type of attribute {attr}")

    @staticmethod
    def deserialise(input: FileType | h5py.File | h5py.Group) -> cls:
        h5 = (
            input
            if isinstance(input, (h5py.File, h5py.Group))
            else h5py.File(input, "r")
        )

        attrs = {}
        for attr, T in serialisable_attrs.items():
            val = None
            if _is_primitive(T):
                val = h5.attrs.get(attr)
            else:
                serialised = h5[attr]
                if isinstance(serialised, h5py.Dataset):
                    val = np.array(serialised)
                elif isinstance(serialised, h5py.Group):
                    val = T.deserialise(serialised)
                else:
                    raise Exception("Unknown type of data in hdf5")
            attrs[attr] = val
        return cls(**attrs)

    cls.serialise = serialise
    cls.deserialise = deserialise
    cls.__is_serialisable = True
    return cls

In [246]:
@dataclass
class Base:
    name: str

@serialisable
@dataclass
class DataClass(Base):
    data: NDArray[np.float_]


@serialisable
@dataclass
class Small:
    x: int
    y: str


@serialisable
@dataclass
class Big:
    a: int
    s: Small
    d: DataClass

In [245]:

s = Small(None, "three")
d = DataClass(data=np.random.rand(2,3), name="D")
b = Big(a=10, s=s, d=d)
b.serialise("/tmp/output.hdf5")

Exception: Unsupported type of attribute s

In [242]:
x = Big.deserialise("/tmp/output.hdf5")
x.s.x is None

True

In [238]:
type(int)

type