## Test PyTorch and dataclasses integration

Sources:
* [How to use dataclass with PyTorch](https://discuss.pytorch.org/t/how-to-use-dataclass-with-pytorch/53444)
* [Subclassing torch.Tensor](https://discuss.pytorch.org/t/subclassing-torch-tensor/23754)

In [115]:
import torch
from dataclasses import dataclass, field

if torch.cuda.is_available():
    DEVICE = torch.device('cuda')
elif torch.backends.mps.is_available():
    DEVICE = torch.device('mps')
else:
    DEVICE = torch.device('cpu')

DEVICE

device(type='mps')

### Scenario #1 – no subclassing, class stores tensors
* Possible: no errors are thrown when a field is annotated as `torch.Tensor`
* Any extra parameters for the `tensor` object (e.g., `dtype`, `device`) have to be applied either when initialising the class instance or in the `__post_init__` method

In [116]:
@dataclass
class A:
    tensor: torch.Tensor
    dtype: torch.dtype
    param: int

    def __post_init__(self):
        self.tensor = self.tensor.to(device=DEVICE, dtype=self.dtype)

a = A(torch.ones((3, 3)), dtype=torch.float16, param=1)
a

A(tensor=tensor([[1., 1., 1.],
        [1., 1., 1.],
        [1., 1., 1.]], device='mps:0', dtype=torch.float16), dtype=torch.float16, param=1)
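
One caveat for this scenario concerns the `__eq__` method that `dataclass` generates by default: it compares field values, and comparing two multi-element tensors with `==` yields a tensor whose truth value is ambiguous. The sketch below is a minimal illustration on CPU. `PlainRecord`, `SafeRecord`, and the use of `torch.equal` for the comparison are assumptions made for this example and are not part of the original notebook.

```python
import torch
from dataclasses import dataclass


@dataclass
class PlainRecord:
    # Default eq=True: dataclass generates a field-by-field __eq__.
    tensor: torch.Tensor
    param: int


@dataclass(eq=False)
class SafeRecord:
    # eq=False: no generated __eq__, so the explicit one below is used.
    tensor: torch.Tensor
    param: int

    def __eq__(self, other):
        if not isinstance(other, SafeRecord):
            return NotImplemented
        # torch.equal returns a plain bool (same shape and same values).
        return self.param == other.param and torch.equal(self.tensor, other.tensor)


try:
    PlainRecord(torch.ones(3, 3), 1) == PlainRecord(torch.ones(3, 3), 1)
    plain_eq_failed = False
except RuntimeError:
    # bool() of a multi-element tensor is ambiguous, so the comparison fails.
    plain_eq_failed = True

safe_eq = SafeRecord(torch.ones(3, 3), 1) == SafeRecord(torch.ones(3, 3), 1)
```

Defining `__eq__` by hand leaves the class unhashable unless a `__hash__` is also provided, which may or may not matter depending on how the instances are used.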

### Scenario #2 – subclassing `torch.Tensor`
* Doesn't work out of the box: even with default parameter values and a `__new__` that initialises the `torch.Tensor` parent class, the dataclass cannot be instantiated as a tensor and raises a `RecursionError`; the cause was not tracked down
* Could be dangerous: `dataclass` automatically generates some methods (e.g., `__init__`, `__eq__`) that might break the functionality of a subclass. Generation of `__eq__` can be turned off with `@dataclass(eq=False)`, and generation of `__init__` with `init=False`

In [117]:
@dataclass(eq=False)
class B(torch.Tensor):
    data: torch.Tensor = field(default_factory=torch.tensor)
    param: int = 1

    @staticmethod
    def __new__(cls, data, param, *args, **kwargs):
        return super().__new__(cls, data, *args, **kwargs)

b = B(torch.ones((3, 3)), 1)
b

RecursionError: maximum recursion depth exceeded
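
For comparison, a tensor subclass that carries an extra value can be sketched without `dataclass` at all. This is only one possible workaround under stated assumptions: `TaggedTensor` and `make_tagged` are hypothetical names introduced here, and the sketch relies on `torch.Tensor.as_subclass` to view an existing tensor as the subclass.

```python
import torch


class TaggedTensor(torch.Tensor):
    """Plain tensor subclass; extra state lives in ordinary attributes."""


def make_tagged(data, param=1):
    # Hypothetical helper: view `data` as a TaggedTensor and attach `param`.
    tagged = torch.as_tensor(data).as_subclass(TaggedTensor)
    tagged.param = param
    return tagged


b = make_tagged(torch.ones((3, 3)), param=1)
```

Here `param` is attached to this one instance only; whether it survives operations such as `b + 1` is not addressed by this sketch.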

### Scenario #3 – subclassing `torch.nn.Module`
* Code example from [How to use dataclass with PyTorch](https://discuss.pytorch.org/t/how-to-use-dataclass-with-pytorch/53444)
* Seems to work, but requires `__new__` to initialise the parent class and `__post_init__` to initialise the layers
* Additionally, `unsafe_hash=True` is required: with the default `eq=True`, `dataclass` sets `__hash__` to `None`, while `torch.nn.Module` instances need to be hashable
* Could lead to further problems: one of the comments in the source mentions problems with transferring weights to the GPU, which might be resolved with `@dataclass(eq=False)` (no generated `__eq__` method)

In [None]:
@dataclass
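class DataclassModule(torch.nn.Module):
    def __new__(cls, *args, **k):
        inst = super().__new__(cls)
        # The generated dataclass __init__ never calls torch.nn.Module.__init__,
        # so run it here, before any field or layer is assigned.
        torch.nn.Module.__init__(inst)
        return inst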

@dataclass(unsafe_hash=True)
class Net(DataclassModule):
    other_layer: torch.nn.Module
    input_feats: int = 10
    output_feats: int = 20

    def __post_init__(self):
        self.layer = torch.nn.Linear(self.input_feats, self.output_feats)

    def forward(self, x):
        return self.layer(self.other_layer(x))

net = Net(other_layer=torch.nn.Linear(10, 10))
assert net(torch.tensor([1.]*10)).shape == (20,)
assert len(list(net.parameters())) == 4

@dataclass(unsafe_hash=True)
class A(DataclassModule):
    x: int
    def __post_init__(self):
        self.layer1 = torch.nn.Linear(self.x, self.x)

@dataclass(unsafe_hash=True)
class B(A):
    y: int
    def __post_init__(self):
        super().__post_init__()
        self.layer2 = torch.nn.Linear(self.y, self.y)

assert len(list(A(1).parameters())) == 2
assert len(list(B(1, 2).parameters())) == 4
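
The difference between `unsafe_hash=True` and `eq=False` mentioned in the notes above can be sketched as follows. With `unsafe_hash=True`, two modules with equal field values compare equal and hash equal, while `eq=False` keeps the identity-based `__eq__` and `__hash__` inherited from `torch.nn.Module`. `ModuleBase`, `HashedBlock`, and `IdentityBlock` are hypothetical names for this illustration. `Module.children()` appears to deduplicate modules through set membership, so two field-equal modules may be yielded only once; this is offered as a possible explanation of the GPU transfer problem, not a confirmed diagnosis.

```python
import torch
from dataclasses import dataclass


@dataclass(eq=False)
class ModuleBase(torch.nn.Module):
    def __new__(cls, *args, **kwargs):
        inst = super().__new__(cls)
        # Initialise Module internals before dataclass fields are assigned.
        torch.nn.Module.__init__(inst)
        return inst


@dataclass(unsafe_hash=True)
class HashedBlock(ModuleBase):
    # Generated __eq__ and __hash__ depend only on the field `x`.
    x: int

    def __post_init__(self):
        self.layer = torch.nn.Linear(self.x, self.x)


@dataclass(eq=False)
class IdentityBlock(ModuleBase):
    # No generated __eq__; identity-based comparison and hashing are kept.
    x: int

    def __post_init__(self):
        self.layer = torch.nn.Linear(self.x, self.x)


hashed_equal = HashedBlock(1) == HashedBlock(1)        # distinct objects, equal fields
identity_equal = IdentityBlock(1) == IdentityBlock(1)  # distinct objects

# Count the children a container reports for two separate blocks of each kind.
hashed_children = len(list(
    torch.nn.Sequential(HashedBlock(1), HashedBlock(1)).children()))
identity_children = len(list(
    torch.nn.Sequential(IdentityBlock(1), IdentityBlock(1)).children()))
```

Comparing `hashed_children` with `identity_children` shows whether the container still sees both blocks; if the first count is lower, any operation that walks `children()` could skip one of the blocks.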