In [None]:
# | default_exp image_readers/safetensors_reader

# Imports

In [None]:
# | export

from typing import Any

import torch
from monai.data import ImageReader, MetaTensor, is_supported_format
from safetensors import safe_open

# Main class

In [None]:
# | export


class SafetensorsReader(ImageReader):
    """MONAI ``ImageReader`` that loads an image tensor from a safetensors file.

    Alongside the image, an optional spacing tensor and an optional set of
    extra tensors can be pulled from the same file. The spacing is turned into
    a diagonal affine attached to the returned ``MetaTensor``.
    """

    def __init__(
        self,
        image_key: str,
        spacing_key: str | None = None,
        other_keys: set[str] | None = None,
        add_channel_dim: bool = True,
        dtype=torch.float32,
    ):
        """Reader for Safetensors image files.

        Args:
            image_key: Key to access the image tensor in the safetensors file.
            spacing_key: Key to access the spacing tensor in the safetensors file. Leave blank if not applicable.
            other_keys: Set of keys to access other tensors in the safetensors file. Leave blank if not applicable.
            add_channel_dim: Whether to add a channel dimension to the image tensor.
            dtype: Desired data type for the image tensor.
        """
        self.image_key = image_key
        self.spacing_key = spacing_key
        self.other_keys = other_keys
        self.add_channel_dim = add_channel_dim
        self.dtype = dtype

    def verify_suffix(self, filename):
        """Ensure the file has a supported safetensors suffix."""
        return is_supported_format(filename, ["safetensors"])

    def read(self, filepath) -> dict[str, Any] | list[dict[str, Any]]:
        """Read image data from a safetensors file.

        Args:
            filepath: A single path, or a list/tuple of paths.

        Returns:
            For a single path, a dict with the keys ``"image"`` (tensor stored
            under ``image_key``), ``"spacing"`` (tensor stored under
            ``spacing_key``, or ``None`` when no spacing key was configured) and
            ``"others"`` (dict of the tensors named in ``other_keys``, empty when
            none were configured). For a list/tuple of paths, a list with one
            such dict per path.
        """
        if isinstance(filepath, (list, tuple)):
            return [self.read(fp) for fp in filepath]

        with safe_open(filepath, framework="pt") as f:
            image = f.get_tensor(self.image_key)
            spacing = f.get_tensor(self.spacing_key) if self.spacing_key else None
            others = {key: f.get_tensor(key) for key in self.other_keys} if self.other_keys else {}

        return {"image": image, "spacing": spacing, "others": others}

    def get_data(self, datapoint):
        """Extract and process image data from the datapoint.

        Args:
            datapoint: The value returned by ``read``: either a single dict or a
                list/tuple of dicts. When a sequence is given only its first
                element is used.

        Returns:
            A tuple ``(image, others)`` where ``image`` is a ``MetaTensor`` cast
            to ``self.dtype`` (with a leading channel dimension when
            ``add_channel_dim`` is set) carrying the spacing-derived affine, and
            ``others`` is the dict of extra tensors collected by ``read``.
        """
        # `read` yields a bare dict for a single path and a list for several
        # paths; accept both instead of blindly indexing with [0].
        if isinstance(datapoint, (list, tuple)):
            datapoint = datapoint[0]

        image = datapoint["image"].to(self.dtype)
        spacing = datapoint["spacing"]
        others = datapoint["others"]

        if self.add_channel_dim:
            image = image.unsqueeze(0)

        # Keep the requested dtype: a second hard-coded float32 cast here would
        # silently discard the `dtype` constructor argument.
        image = MetaTensor(image, affine=self._spacing_to_affine(spacing))

        return image, others

    @staticmethod
    def _spacing_to_affine(spacing):
        """Convert a spacing tensor into a diagonal homogeneous affine matrix.

        Args:
            spacing: 1-D tensor of per-axis spacings, or ``None`` to fall back to
                unit spacing along three axes.

        Returns:
            A square matrix with the spacings on the diagonal followed by a
            trailing 1 for the homogeneous coordinate.
        """
        if spacing is None:
            spacing = torch.ones(3)
        # The homogeneous corner must be 1; a 0 there makes the affine singular.
        return torch.diag(torch.cat([spacing, torch.ones(1)]))

In [None]:
# Quick demo: per-axis spacing values should end up on the diagonal of the affine.
spacing = torch.as_tensor([1, 2, 3])
SafetensorsReader._spacing_to_affine(spacing=spacing)


[1;35mtensor[0m[1m([0m[1m[[0m[1m[[0m[1;36m1[0m., [1;36m0[0m., [1;36m0[0m., [1;36m0[0m.[1m][0m,
        [1m[[0m[1;36m0[0m., [1;36m2[0m., [1;36m0[0m., [1;36m0[0m.[1m][0m,
        [1m[[0m[1;36m0[0m., [1;36m0[0m., [1;36m3[0m., [1;36m0[0m.[1m][0m,
        [1m[[0m[1;36m0[0m., [1;36m0[0m., [1;36m0[0m., [1;36m0[0m.[1m][0m[1m][0m[1m)[0m

# nbdev

In [None]:
!nbdev_export