In [None]:
from typing import Any, Callable, Self

from atria_types._utilities._repr import RepresentationMixin
import torch
from pydantic import BaseModel, ConfigDict, Field, PrivateAttr, model_validator


class TensorOperations:
    def __init__(self, model: "TensorDataModel"):
        self.model = model

    @property
    def model_fields(self):
        return self.model.__class__.model_fields

    def _map_tensors(self, fn: Callable) -> "TensorDataModel":
        """Apply function to all tensor fields."""
        updates = {}
        for field_name in self.model_fields.keys():
            if field_name == "metadata":
                continue
            val = getattr(self.model, field_name)
            updates[field_name] = fn(val) if isinstance(val, torch.Tensor) else val

        return self.model.model_copy(update=updates)

    def to(self, device: torch.device) -> "TensorDataModel":
        return self._map_tensors(lambda t: t.to(device))

    def cpu(self) -> "TensorDataModel":
        return self._map_tensors(lambda t: t.cpu())

    def cuda(self) -> "TensorDataModel":
        return self._map_tensors(lambda t: t.cuda())

    def numpy(self) -> "TensorDataModel":
        return self._map_tensors(lambda t: t.detach().cpu().numpy())



class TensorDataModel(RepresentationMixin, BaseModel):
    """Base model where all declared fields must be tensors."""

    model_config = ConfigDict(
        arbitrary_types_allowed=True,
        extra="allow",
        validate_assignment=True,
    )

    metadata: dict[str, Any] = Field(default_factory=dict, repr=False)
    _is_batched: bool = PrivateAttr(default=False)

    @property
    def ops(self) -> TensorOperations:
        return TensorOperations(self)

    @model_validator(mode="before")
    @classmethod
    def move_extras_to_metadata(cls, data: Any) -> dict[str, Any]:
        """Move non-declared fields to metadata."""
        if not isinstance(data, dict):
            return data

        declared_fields = set(cls.model_fields.keys())

        extras = {}
        cleaned_data = {}
        metadata = data.get("metadata", {})

        for key, value in data.items():
            if key in declared_fields:
                cleaned_data[key] = value
            elif key != "metadata":
                extras[key] = value

        if extras:
            metadata.update(extras)
            cleaned_data["metadata"] = metadata
        elif metadata:
            cleaned_data["metadata"] = metadata

        for key in declared_fields:
            if key not in cleaned_data and key in data:
                cleaned_data[key] = data[key]

        return cleaned_data

    @model_validator(mode="after")
    def validate_tensor_fields(self) -> Self:
        """Validate that all non-metadata fields are tensors."""
        for field_name in self.__class__.model_fields.keys():
            if field_name == "metadata":
                continue

            value = getattr(self, field_name)
            if value is not None and not isinstance(value, torch.Tensor):
                raise TypeError(
                    f"Field '{field_name}' must be a torch.Tensor, got {type(value).__name__}"
                )
        return self

    @classmethod
    def batch(cls, items: list[Self]) -> Self:
        """Create a batched instance from a list of instances."""
        if not items:
            raise ValueError("Cannot batch empty list")

        if not all(type(item) is type(items[0]) for item in items):
            raise TypeError("All items must be of the same type")

        field_values = {}

        for field_name in cls.model_fields.keys():
            if field_name == "metadata":
                # Batch metadata as lists
                batched_meta = {}
                for item in items:
                    for k, v in item.metadata.items():
                        batched_meta.setdefault(k, []).append(v)
                field_values[field_name] = batched_meta
                continue

            vals = [getattr(item, field_name) for item in items]

            if vals[0] is None:
                field_values[field_name] = None
            else:
                field_values[field_name] = torch.stack(vals, dim=0)

        batched_instance = cls(**field_values)
        batched_instance._is_batched = True
        return batched_instance

    def __len__(self):
        """Return batch size if batched, else 1."""
        if not self._is_batched:
            return 1

        for field_name in self.__class__.model_fields.keys():
            if field_name == "metadata":
                continue
            val = getattr(self, field_name)
            if isinstance(val, torch.Tensor):
                return val.shape[0]
        return 1
    
    def __rich_repr__(self):        
        yield from super().__rich_repr__()
        yield "is_batched", self._is_batched
        yield "batch_size", len(self) if self._is_batched else 1

class MyInput(TensorDataModel):
    x: torch.Tensor
    y: torch.Tensor


class MyOtherInput(TensorDataModel):
    features: torch.Tensor
    labels: torch.Tensor
    masks: torch.Tensor

In [19]:
# Test basic creation
inp = MyInput(
    x=torch.randn(3),
    y=torch.randn(5),
    extra_field="goes to metadata"
)

print("Input:", inp)
print("\nMetadata:", inp.metadata)
print("\nX shape:", inp.x.shape)
print("\nY shape:", inp.y.shape)

Input: MyInput(
    __class__='MyInput',
    metadata={'extra_field': 'goes to metadata'},
    x=tensor([-0.4323,  0.2839, -1.4108]),
    y=tensor([ 1.0940,  1.3503, -0.7195,  0.9806,  1.1266]),
    is_batched=False
)

Metadata: {'extra_field': 'goes to metadata'}

X shape: torch.Size([3])

Y shape: torch.Size([5])


In [28]:
# Test batching - now using the same class!
inp1 = MyInput(x=torch.randn(3), y=torch.randn(5))
inp2 = MyInput(x=torch.randn(3), y=torch.randn(5))
inp3 = MyInput(x=torch.randn(3), y=torch.randn(5))

batched = MyInput.batch([inp1, inp2, inp3])
print(batched)

# Test with the other model too
other1 = MyOtherInput(
    features=torch.randn(10), labels=torch.randn(2), masks=torch.randn(10)
)
other2 = MyOtherInput(
    features=torch.randn(10), labels=torch.randn(2), masks=torch.randn(10)
)

batched_other = MyOtherInput.batch([other1, other2])
print("\n\nBatched other:", batched_other)

MyInput(
    __class__='MyInput',
    metadata={},
    x=tensor([[-0.6975, -1.8428,  2.2085],
        [ 0.2949,  0.4840,  0.8585],
        [ 0.2605, -0.2446,  1.0338]]),
    y=tensor([[ 0.8363, -1.0062,  0.0951,  0.4698,  0.2463],
        [-0.2219,  1.3413,  0.8341,  0.0491,  1.1501],
        [-1.3777,  0.9197, -0.8478, -1.0484,  0.8841]]),
    is_batched=True
)


Batched other: MyOtherInput(
    __class__='MyOtherInput',
    metadata={},
    features=tensor([[-0.5417, -0.1178,  0.0296, -0.9935, -0.9857, -0.4276, -0.7270,  1.1856,
         -0.4716,  0.4992],
        [-0.7749, -0.4819,  1.4746, -0.5938, -0.4834, -0.2157, -0.7273,  0.4117,
         -0.0335,  1.6378]]),
    labels=tensor([[ 0.3722, -1.1239],
        [-0.2652,  0.3614]]),
    masks=tensor([[ 2.1070,  1.2946, -1.0728, -1.0005, -0.6491,  0.3386,  0.2027, -0.4565,
         -1.4883, -0.9343],
        [-0.7686, -0.0308,  1.3096, -1.3216,  1.3275,  1.4102,  0.4972, -1.1578,
         -0.0819,  1.6519]]),
    is_batched=True
)


In [4]:
# Test device operations on batched data
batched_cpu = batched.ops.cpu()
print("CPU device:", batched_cpu.x.device)

# Test numpy conversion on batched data
arr = batched.ops.numpy()
print("Numpy batched x shape:", arr.x.shape)

# Test on single instance too
inp_cpu = inp.ops.cpu()
print("\nSingle instance CPU device:", inp_cpu.x.device)

CPU device: cpu
Numpy batched x shape: (3, 3)

Single instance CPU device: cpu


In [5]:
# Test validation - should raise TypeError
try:
    bad_input = MyInput(x="not a tensor", y=torch.randn(5))
except TypeError as e:
    print(f"✓ Validation works: {e}")

ValidationError: 1 validation error for MyInput
x
  Input should be an instance of Tensor [type=is_instance_of, input_value='not a tensor', input_type=str]
    For further information visit https://errors.pydantic.dev/2.12/v/is_instance_of