In [None]:
# | default_exp layers/embeddings

# Imports

In [None]:
# | export

from functools import lru_cache
from typing import Literal, Union

import numpy as np
import torch
from einops import rearrange, repeat
from torch import nn

from vision_architectures.blocks.cnn import CNNBlock3D, CNNBlockConfig
from vision_architectures.docstrings import populate_docstring
from vision_architectures.utils.activation_checkpointing import ActivationCheckpointing
from vision_architectures.utils.custom_base_model import CustomBaseModel, Field, model_validator
from vision_architectures.utils.rearrange import rearrange_channels

# Config

In [None]:
# | export


class RelativePositionEmbeddings3DConfig(CustomBaseModel):
    num_heads: int = Field(..., description="Number of query attention heads")
    grid_size: tuple[int, int, int] = Field(..., description="Size of entire patch matrix.")

    @property
    def num_patches(self) -> int:
        """Number of patches, i.e. the product of the grid size along all three axes."""
        return int(np.prod(self.grid_size))

    @model_validator(mode="before")
    @classmethod
    def validate_before(cls, data):
        # Accept a scalar grid_size as shorthand for a cubic grid.
        # - .get() lets a missing grid_size surface as pydantic's "field required" error instead of a raw KeyError
        # - a copy is returned so the caller's dict is not mutated
        # - non-dict inputs are passed through for pydantic to handle
        if isinstance(data, dict) and isinstance(data.get("grid_size"), int):
            size = data["grid_size"]
            data = {**data, "grid_size": (size, size, size)}
        return data

    @model_validator(mode="after")
    def validate(self):
        super().validate()
        # Defensive: grid_size is normally already a tuple after the "before" validator
        if isinstance(self.grid_size, int):
            self.grid_size = (self.grid_size, self.grid_size, self.grid_size)
        return self


class AbsolutePositionEmbeddings3DConfig(CustomBaseModel):
    # dim / grid_size may be left as None for non-learnable embeddings; they are then inferred from the input
    dim: int | None = Field(None, description="Dimension of the position embeddings")
    grid_size: tuple[int, int, int] | None = Field(None, description="Size of entire patch matrix.")
    learnable: bool = Field(False, description="Whether the position embeddings are learnable.")

    @property
    def num_patches(self) -> int | None:
        """Number of patches, or None when grid_size is not configured."""
        if self.grid_size is None:
            return None
        return int(np.prod(self.grid_size))

    @model_validator(mode="before")
    @classmethod
    def validate_before(cls, data):
        # Accept a scalar grid_size as shorthand for a cubic grid. A copy is returned so the caller's dict is not
        # mutated, and non-dict inputs are passed through for pydantic to handle.
        if isinstance(data, dict) and isinstance(data.get("grid_size"), int):
            size = data["grid_size"]
            data = {**data, "grid_size": (size, size, size)}
        return data

    @model_validator(mode="after")
    def validate(self):
        super().validate()
        # Learnable embeddings are a fixed-shape parameter, so their shape must be fully known up front
        if self.learnable and (self.dim is None or self.grid_size is None):
            raise ValueError("dim and grid_size must be provided if learnable is True")
        return self


class AbsolutePositionEmbeddings1DConfig(CustomBaseModel):
    # Fields left as None are inferred from the input at forward time (only allowed when not learnable)
    dim: int | None = Field(None, description="Dimension of the position embeddings")
    length: int | None = Field(None, description="Length of the sequence.")
    learnable: bool = Field(False, description="Whether the position embeddings are learnable.")

    @model_validator(mode="after")
    def validate(self):
        super().validate()
        # A learnable parameter needs a fully specified (length, dim) shape
        shape_incomplete = self.dim is None or self.length is None
        if shape_incomplete and self.learnable:
            raise ValueError("dim and length must be provided if learnable is True")
        return self


class RotaryPositionEmbeddings1DConfig(CustomBaseModel):
    # Configuration for 1D rotary position embeddings (RoPE).
    # When dim is None, the rotary module takes the dimension from the last axis of its input.
    dim: int | None = Field(None, description="Dimension of the position embeddings")

    # Base of the geometric frequency progression used for the rotation angles
    base: float = Field(10000.0, description="Base value for the exponent.")


class RotaryPositionEmbeddings3DConfig(RotaryPositionEmbeddings1DConfig):
    # Per-axis (z, y, x) share of the rotary dimension, either as fractions of dim or as absolute ints
    split: tuple[float, float, float] | tuple[int, int, int] = Field(
        (1 / 3, 1 / 3, 1 / 3),
        description="Split of the position embeddings. If float, converted to int based on self.dim",
    )

    def get_split_as_ints(self, dim: int | None = None) -> tuple[int, int, int]:
        """Return the per-axis (z, y, x) rotary dimensions as integers.

        Args:
            dim: Total embedding dimension used to convert a fractional split. Falls back to the configured dim when
                None. Ignored when the split is already given as integers.

        Returns:
            A tuple of three ints. Fractional splits are truncated with int(), so their sum may be less than dim.
        """
        if isinstance(self.split[0], int):
            return self.split

        if dim is None:
            dim = self.dim
        assert dim is not None, "dim must be provided if not specified in config"

        return tuple(int(s * dim) for s in self.split)

    @model_validator(mode="after")
    def validate(self):
        super().validate()
        # Integer splits can be checked against dim right away; fractional ones are resolved later
        if self.dim is not None and isinstance(self.split[0], int):
            assert sum(self.split) <= self.dim, "Sum of split must be less than or equal to dim"
        return self


class PatchEmbeddings3DConfig(CNNBlockConfig):
    patch_size: tuple[int, int, int] = Field(..., description="Size of the patches to extract from the input.")
    in_channels: int = Field(..., description="Number of input channels.")
    dim: int = Field(..., description="Dimension of the embeddings.")
    norm_layer: str = Field("layernorm", description="Normalization layer to use.")

    # The CNN block's generic names are disabled here; they are accepted as input aliases for patch_size / dim
    out_channels: None = None
    kernel_size: None = None

    @model_validator(mode="before")
    @classmethod
    def validate_before(cls, data: dict):
        # Non-dict inputs (e.g. an existing config instance) are left for pydantic to handle
        if not isinstance(data, dict):
            return data
        # Work on a copy so the caller's dict is not mutated
        data = dict(data)

        # kernel_size / out_channels are always removed so the None-typed fields stay None. Their values only fill
        # patch_size / dim when actually given; an explicit patch_size / dim takes precedence. When neither name is
        # given, the field stays missing and pydantic reports it as required.
        kernel_size = data.pop("kernel_size", None)
        out_channels = data.pop("out_channels", None)
        if kernel_size is not None:
            data.setdefault("patch_size", kernel_size)
        if out_channels is not None:
            data.setdefault("dim", out_channels)

        # Accept a scalar patch size as shorthand for a cubic patch
        patch_size = data.get("patch_size")
        if isinstance(patch_size, int):
            data["patch_size"] = (patch_size, patch_size, patch_size)
        return data

In [None]:
# Union type aliases over the 1D and 3D config variants.
# NOTE(review): this cell has no `# | export` directive, so nbdev will not export these aliases to the generated
# module (other alias cells in this notebook do carry it) -- confirm whether that is intended.
AbsolutePositionEmbeddingsConfig = Union[AbsolutePositionEmbeddings1DConfig, AbsolutePositionEmbeddings3DConfig]
RotaryPositionEmbeddingsConfig = Union[RotaryPositionEmbeddings1DConfig, RotaryPositionEmbeddings3DConfig]

# Architecture

### Position Embeddings

In [None]:
# | export


def get_coords_grid(grid_size: tuple[int, int, int]) -> torch.Tensor:
    """Build an integer coordinate grid for a 3D patch grid.

    Args:
        grid_size: Number of positions along each axis, as (d, h, w).

    Returns:
        An int32 tensor of shape (3, d, h, w). Entry [k, i, j, l] is the k-th coordinate of position (i, j, l),
        i.e. slice 0 holds i, slice 1 holds j and slice 2 holds l.
    """
    d, h, w = grid_size
    axis_ranges = [torch.arange(n, dtype=torch.int32) for n in (d, h, w)]

    # "ij" indexing keeps the grid axes in (d, h, w) order; stacking puts the coordinate index first
    return torch.stack(torch.meshgrid(*axis_ranges, indexing="ij"), dim=0)

In [None]:
# | export


@populate_docstring
class RelativePositionEmbeddings3D(nn.Module):
    """Learnable 3D Relative Position Embeddings. This can be passed directly to the attention layers.
    {CLASS_DESCRIPTION_3D_DOC}"""

    @populate_docstring
    def __init__(self, config: RelativePositionEmbeddings3DConfig = {}, **kwargs):
        """Initialize RelativePositionEmbeddings3D.

        Args:
            config: {CONFIG_INSTANCE_DOC}
            **kwargs: {CONFIG_KWARGS_DOC}
        """
        super().__init__()

        self.config = RelativePositionEmbeddings3DConfig.model_validate(config | kwargs)

        num_heads = self.config.num_heads
        grid_size = self.config.grid_size

        # TODO: Add embed_spacing_info functionality

        # Number of distinct relative offsets along each axis: -(g - 1) ... (g - 1)
        relative_limits = (
            2 * grid_size[0] - 1,
            2 * grid_size[1] - 1,
            2 * grid_size[2] - 1,
        )

        # One learnable bias per head for every distinct (dz, dy, dx) offset
        self.relative_position_bias_table = nn.Parameter(torch.randn(num_heads, np.prod(relative_limits)))
        # (num_heads, (2 * d - 1) * (2 * h - 1) * (2 * w - 1))

        # Pair-wise relative position index for each token inside the window
        coords = get_coords_grid(grid_size)
        coords_flatten = rearrange(coords, "three d h w -> three (d h w)", three=3).contiguous()
        relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :]
        relative_coords = relative_coords.permute(1, 2, 0).contiguous()
        # (num_patches, num_patches, 3); shift each offset so it starts at 0
        relative_coords[:, :, 0] += grid_size[0] - 1
        relative_coords[:, :, 1] += grid_size[1] - 1
        relative_coords[:, :, 2] += grid_size[2] - 1
        # Row-major flattening of the shifted (dz, dy, dx) offset into a column of the bias table
        relative_position_index: torch.Tensor = (
            relative_coords[:, :, 0] * relative_limits[1] * relative_limits[2]
            + relative_coords[:, :, 1] * relative_limits[2]
            + relative_coords[:, :, 2]
        )
        # (num_patches, num_patches) -> stored flattened as (num_patches * num_patches,)
        # Non-persistent buffer: follows the module across devices without being added to the state_dict
        self.register_buffer("relative_position_index", relative_position_index.flatten().long(), persistent=False)

    def forward(self) -> torch.Tensor:
        """Get relative position embeddings as specified by the config.

        Returns:
            A tensor of shape (1, num_heads, num_patches, num_patches) containing the relative position embeddings.
        """
        num_patches = self.config.num_patches

        relative_position_embeddings = self.relative_position_bias_table[:, self.relative_position_index]
        # (num_heads, num_patches * num_patches)

        # The gathered tensor is head-major, so the flattened pair axis is split in place. Reshaping it to a
        # trailing heads axis instead would interleave heads with token pairs and break the relative structure.
        relative_position_embeddings = relative_position_embeddings.reshape(
            1, self.config.num_heads, num_patches, num_patches
        ).contiguous()
        # (1, num_heads, num_patches, num_patches)
        return relative_position_embeddings

In [None]:
rel_pos_embeddings = RelativePositionEmbeddings3D(num_heads=6, grid_size=4)
display(rel_pos_embeddings)
# Expected (1, num_heads, num_patches, num_patches) = (1, 6, 64, 64)
relative_bias = rel_pos_embeddings()
display(relative_bias.shape)

In [None]:
# | export


@populate_docstring
class RelativePositionEmbeddings3DMetaNetwork(nn.Module):
    """3D Relative Position Embeddings obtained from a meta network (inspired by SwinV2). This can be passed directly
    to the attention layers. {CLASS_DESCRIPTION_3D_DOC}"""

    @populate_docstring
    def __init__(self, config: RelativePositionEmbeddings3DConfig = {}, checkpointing_level: int = 0, **kwargs):
        """Initialize RelativePositionEmbeddings3DMetaNetwork.

        Args:
            config: {CONFIG_INSTANCE_DOC}
            checkpointing_level: {CHECKPOINTING_LEVEL_DOC}
            **kwargs: {CONFIG_KWARGS_DOC}
        """
        super().__init__()

        self.config = RelativePositionEmbeddings3DConfig.model_validate(config | kwargs)

        num_heads = self.config.num_heads
        grid_size = self.config.grid_size

        # TODO: Add embed_spacing_info functionality
        # Meta network: maps a normalized (z, y, x) offset to one bias value per head
        self.cpb_mlp = nn.Sequential(
            nn.Linear(3, 512, bias=True),
            nn.ReLU(inplace=True),
            nn.Linear(512, num_heads, bias=False),
        )

        # Number of distinct relative offsets along each axis: -(g - 1) ... (g - 1)
        relative_limits = (
            2 * grid_size[0] - 1,
            2 * grid_size[1] - 1,
            2 * grid_size[2] - 1,
        )

        # Relative coordinates table
        relative_coords_table = get_coords_grid(relative_limits).float()
        # (3, 2 * d - 1, 2 * h - 1, 2 * w - 1)
        for i in range(3):
            # Center each axis on its OWN grid size and scale it to roughly [-1, 1]. Using one axis' size for all
            # three would mis-center and mis-scale the other axes of non-cubic grids.
            relative_coords_table[i] = (relative_coords_table[i] - (grid_size[i] - 1)) / (
                grid_size[i] - 1 + 1e-5  # small value added to ensure there is no NaN when window size is 1
            )
        relative_coords_table = rearrange(
            relative_coords_table,
            "three num_patches_z num_patches_y num_patches_x -> num_patches_z num_patches_y num_patches_x three",
        ).contiguous()
        relative_coords_table *= 8  # Normalize to -8, 8
        # Log-spaced coordinates (SwinV2-style): sign(x) * log2(|x| + 1) / log2(8)
        relative_coords_table = (
            torch.sign(relative_coords_table) * torch.log2(torch.abs(relative_coords_table) + 1.0) / np.log2(8)
        )
        # (2 * d - 1, 2 * h - 1, 2 * w - 1, 3)
        # Allow moving this to and from cuda whenever required but don't save to state_dict
        self.register_buffer("relative_coords_table", relative_coords_table, persistent=False)

        # Pair-wise relative position index for each token inside the window
        coords = get_coords_grid(grid_size)
        coords_flatten = rearrange(coords, "three d h w -> three (d h w)", three=3).contiguous()
        relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :]
        relative_coords = rearrange(
            relative_coords, "three num_patches1 num_patches2 -> num_patches1 num_patches2 three"
        ).contiguous()
        # Shift each offset so it starts at 0
        relative_coords[:, :, 0] += grid_size[0] - 1
        relative_coords[:, :, 1] += grid_size[1] - 1
        relative_coords[:, :, 2] += grid_size[2] - 1
        # Row-major flattening of the shifted (dz, dy, dx) offset into a row of the embeddings table
        relative_position_index: torch.Tensor = (
            relative_coords[:, :, 0] * relative_limits[1] * relative_limits[2]
            + relative_coords[:, :, 1] * relative_limits[2]
            + relative_coords[:, :, 2]
        )
        # (num_patches, num_patches) -> stored flattened as (num_patches * num_patches,)
        # Non-persistent buffer: follows the module across devices without being added to the state_dict
        self.register_buffer("relative_position_index", relative_position_index.flatten().long(), persistent=False)

        self.checkpointing_level1 = ActivationCheckpointing(1, checkpointing_level)

    def get_relative_position_embeddings_table(self) -> torch.Tensor:
        """Get the relative position embeddings table from the meta network.

        Returns:
            A tensor of shape ((2 * d - 1) * (2 * h - 1) * (2 * w - 1), num_heads): one row per distinct relative
            offset, one column per head.
        """
        # (2 * d - 1, 2 * h - 1, 2 * w - 1, 3)
        relative_position_embeddings_table: torch.Tensor = self.cpb_mlp(self.relative_coords_table)
        # (2 * d - 1, 2 * h - 1, 2 * w - 1, num_heads)
        relative_position_embeddings_table = relative_position_embeddings_table.reshape(-1, self.config.num_heads)
        # (num_offsets, num_heads)
        return relative_position_embeddings_table

    def forward(self) -> torch.Tensor:
        """Get relative position embeddings as specified by the config.

        Returns:
            A tensor of shape (num_heads, num_patches, num_patches) containing the relative position embeddings.
        """
        relative_position_embeddings_table = self.checkpointing_level1(self.get_relative_position_embeddings_table)
        # (num_offsets, num_heads)
        relative_position_embeddings = relative_position_embeddings_table[self.relative_position_index]
        # (num_patches * num_patches, num_heads)
        relative_position_embeddings = rearrange(
            relative_position_embeddings,
            "(num_patches1 num_patches2) num_heads -> num_heads num_patches1 num_patches2",
            num_patches1=self.config.num_patches,
            num_patches2=self.config.num_patches,
            num_heads=self.config.num_heads,
        ).contiguous()
        # (num_heads, num_patches, num_patches)
        # Bound the bias to (0, 16) as in SwinV2's continuous position bias
        relative_position_embeddings = 16 * torch.sigmoid(relative_position_embeddings)
        # (num_heads, num_patches, num_patches)
        return relative_position_embeddings

In [None]:
meta_rel_pos_embeddings = RelativePositionEmbeddings3DMetaNetwork(num_heads=6, grid_size=(4, 4, 4))
display(meta_rel_pos_embeddings)
# Expected (num_heads, num_patches, num_patches) = (6, 64, 64)
meta_relative_bias = meta_rel_pos_embeddings()
display(meta_relative_bias.shape)

In [None]:
# | export

# Union type alias over the two relative position embedding module implementations
RelativePositionEmbeddings = Union[RelativePositionEmbeddings3D, RelativePositionEmbeddings3DMetaNetwork]

In [None]:
# | export


@populate_docstring
def get_sinusoidal_embeddings_3d(
    dim: int,
    grid_size: tuple[int, int, int],
    spacing: tuple[float, float, float] = (1.0, 1.0, 1.0),
    crop_offset: tuple[float, float, float] | None = None,
    channels_first: bool = True,
) -> torch.Tensor:
    """Get 3D sinusoidal position embeddings.

    Each axis gets dim / 3 channels (dim / 6 sines followed by dim / 6 cosines); the three axes are concatenated in
    (d, h, w) order.

    Args:
        dim: Embedding dimension. Must be divisible by 6.
        grid_size: Size of the patch grid (d, h, w).
        spacing: Spacing between patches in each dimension. Useful for medical images. Each axis' embeddings are
            scaled by its spacing divided by the smallest spacing.
        crop_offset: Used if the embeddings required are of a crop of a larger image. If provided, the grid coordinates
            will be offset accordingly. Fractional offsets are kept as-is rather than truncated.
        channels_first: {CHANNELS_FIRST_DOC}

    Returns:
        {OUTPUT_3D_DOC}
    """
    if dim % 6 != 0:
        raise ValueError("dim must be divisible by 6")

    # Float coordinates so that adding a (possibly fractional) crop offset is not truncated back to integers
    grid = get_coords_grid(grid_size).to(torch.float32)
    # (3, d, h, w)

    # Apply offset if crop parameters are provided
    if crop_offset is not None:
        # Offset the grid coordinates to represent their position in the full volume
        for i in range(3):
            grid[i] = grid[i] + crop_offset[i]

    omega = torch.arange(dim // 6, dtype=torch.float32)
    omega /= dim / 6.0
    omega = 1.0 / (10000**omega)
    # (dim // 6)

    patch_multiplier = torch.Tensor(spacing) / min(spacing)

    embeddings = []
    for i, grid_subset in enumerate(grid):
        out = torch.einsum("m,d->md", grid_subset.reshape(-1), omega)
        # (d * h * w, dim // 6)

        emb = torch.cat([torch.sin(out), torch.cos(out)], dim=1) * patch_multiplier[i]
        # (d * h * w, dim // 3)
        embeddings.append(emb)

    embeddings = torch.cat(embeddings, dim=1)
    # (d * h * w, dim)
    embeddings = rearrange(
        embeddings,
        "(d h w) e -> 1 e d h w",
        d=grid_size[0],
        h=grid_size[1],
        w=grid_size[2],
    ).contiguous()
    # (1, dim, d, h, w)

    embeddings = rearrange_channels(embeddings, True, channels_first)
    # (1, [dim], d, h, w, [dim])

    return embeddings


get_absolute_position_embeddings_3d = get_sinusoidal_embeddings_3d

In [None]:
# | export


@populate_docstring
class AbsolutePositionEmbeddings3D(nn.Module):
    """3D Absolute Position Embeddings. May or may not be learnable. These have to be applied on the input manually
    and cannot be passed to attention layers directly. {CLASS_DESCRIPTION_3D_DOC}"""

    @populate_docstring
    def __init__(self, config: AbsolutePositionEmbeddings3DConfig = {}, **kwargs):
        """Initialize AbsolutePositionEmbeddings3D.

        Args:
            config: {CONFIG_INSTANCE_DOC}
            **kwargs: {CONFIG_KWARGS_DOC}
        """
        super().__init__()

        self.config = AbsolutePositionEmbeddings3DConfig.model_validate(config | kwargs)

        dim = self.config.dim
        grid_size = self.config.grid_size
        learnable = self.config.learnable

        # Sinusoidal embeddings built on the fly (when dim and grid_size are not both configured) are cached here,
        # keyed by (dim, grid_size, channels_first)
        self.position_embeddings_cache = {}
        self.position_embeddings = None
        if dim is not None and grid_size is not None:
            # Fixed-shape embeddings stored channels-first; frozen unless learnable is set
            self.position_embeddings = nn.Parameter(
                get_absolute_position_embeddings_3d(dim, grid_size), requires_grad=learnable
            )

    @populate_docstring
    def forward(
        self,
        x: torch.Tensor,
        embedding_type: Literal["add", "concat"] = "add",
        spacings: torch.Tensor = None,
        channels_first: bool = True,
        crop_offsets: torch.Tensor | None = None,
    ) -> torch.Tensor:
        """Apply absolute position embeddings to the input tensor.

        Args:
            x: {INPUT_3D_DOC}
            embedding_type: Type of embedding to apply. 'add' to add the position embeddings to the input,
                'concat' to concatenate them along the channel dimension.
            spacings: {SPACINGS_DOC}
            channels_first: {CHANNELS_FIRST_DOC}
            crop_offsets: Used if the embeddings required are of a crop of a larger image. If provided, the grid
                coordinates will be offset accordingly. Not used when the embeddings were built at initialization.

        Returns:
            {OUTPUT_3D_DOC}
        """
        assert x.ndim == 5, "Input tensor must be of shape (b, [d], z, y, x, [d])"

        dim = self.config.dim
        grid_size = self.config.grid_size
        if self.position_embeddings is None:
            # Fill in whatever the config leaves unspecified from the input itself
            if dim is None:
                dim = x.shape[1] if channels_first else x.shape[-1]
            if grid_size is None:
                grid_size = tuple(x.shape[2:5] if channels_first else x.shape[1:4])

        b = x.shape[0]

        # Get position embeddings, adjust based on crop offsets if applicable
        if self.position_embeddings is not None:
            position_embeddings = rearrange_channels(self.position_embeddings, True, channels_first)
            position_embeddings = repeat(position_embeddings, "1 ... -> b ...", b=b)
        else:
            if isinstance(grid_size, int):
                grid_size = (grid_size, grid_size, grid_size)

            if crop_offsets is None:
                # The cached tensor's layout depends on channels_first, so it must be part of the key. Otherwise a
                # channels-first tensor could be served to a channels-last call with the same dim / grid_size.
                cache_key = (dim, grid_size, channels_first)
                if cache_key not in self.position_embeddings_cache:
                    self.position_embeddings_cache[cache_key] = get_absolute_position_embeddings_3d(
                        dim, grid_size, channels_first=channels_first
                    )
                position_embeddings = self.position_embeddings_cache[cache_key]
                position_embeddings = repeat(position_embeddings, "1 ... -> b ...", b=b)
            else:
                if crop_offsets.ndim == 1:
                    crop_offsets = crop_offsets.unsqueeze(0)

                # One embedding per sample, each offset by its own crop position
                position_embeddings = torch.cat(
                    [
                        get_absolute_position_embeddings_3d(
                            dim, grid_size, crop_offset=crop_offset.tolist(), channels_first=channels_first
                        )
                        for crop_offset in crop_offsets
                    ],
                    dim=0,
                )
            # These embeddings are built on CPU in float32. Move them to x's device and, for floating inputs, to x's
            # dtype so that e.g. half-precision inputs are not silently promoted to float32.
            position_embeddings = position_embeddings.to(x.device)
            if x.is_floating_point():
                position_embeddings = position_embeddings.to(x.dtype)
        # (b, [dim], d, h, w, [dim])

        # Incorporate spacing information
        if spacings is not None:
            assert spacings.shape == (b, 3), "spacings must be of shape (batch_size, 3)"
            assert dim % 3 == 0, "dim must be divisible by 3"
            # (b, 3)
            # Each axis' block of dim / 3 channels is scaled by that axis' spacing
            spacings = repeat(spacings, "b three -> b (three dim_by_three) 1 1 1", three=3, dim_by_three=dim // 3)
            # (b, dim, 1, 1, 1)
            spacings = rearrange_channels(spacings, True, channels_first)
            # (b, [dim], 1, 1, 1, [dim])

            position_embeddings = position_embeddings * spacings.to(
                device=position_embeddings.device, dtype=position_embeddings.dtype
            )
            # (b, [dim], d, h, w, [dim])

        if embedding_type == "add":
            x = x + position_embeddings
        elif embedding_type == "concat":
            x = torch.cat([x, position_embeddings], dim=1 if channels_first else -1)
        else:
            raise NotImplementedError("Only 'add' and 'concat' are supported for embedding_type")

        return x

In [None]:
sample_input1 = torch.randn(2, 6, 4, 4, 4)
# (b, dim, z, y, x) -- channels-first

# Fully specified, learnable embeddings; also exercised with per-sample spacings
learnable_pe = AbsolutePositionEmbeddings3D(dim=6, grid_size=4, learnable=True)
display(learnable_pe)
display(learnable_pe(sample_input1).shape)
display(learnable_pe(sample_input1, spacings=torch.randn(2, 3)).shape)

# Only dim given: the grid size is inferred from the input
dim_only_pe = AbsolutePositionEmbeddings3D(dim=6)
display(dim_only_pe(sample_input1).shape)

# Nothing given: dim and grid size are inferred, and each sample gets its own crop offset
inferred_pe = AbsolutePositionEmbeddings3D()
display(inferred_pe(sample_input1, crop_offsets=torch.Tensor([(0, 0, 0), (10, 10, 10)])).shape)

In [None]:
sample_input1 = torch.randn(2, 4, 4, 4, 6)
# (b, z, y, x, dim) -- channels-last

learnable_pe = AbsolutePositionEmbeddings3D(dim=6, grid_size=4, learnable=True)
display(learnable_pe)
display(learnable_pe(sample_input1, channels_first=False).shape)

# Partially specified configs: first only dim, then nothing (everything inferred from the input)
for partial_kwargs in ({"dim": 6}, {}):
    partial_pe = AbsolutePositionEmbeddings3D(**partial_kwargs)
    display(partial_pe(sample_input1, channels_first=False).shape)

In [None]:
# | export


def get_specific_sinusoidal_embeddings_1d(dim: int, indices: torch.Tensor) -> torch.Tensor:
    """Compute 1D sinusoidal embeddings at arbitrary positions / timesteps.

    Channel 2k holds sin(index * omega_k) and channel 2k + 1 holds cos(index * omega_k), where
    omega_k = 1 / 10000^(2k / dim).

    Args:
        dim: Embedding dimension. Must be divisible by 2.
        indices: Indices for which to get the embeddings. Shape: (length,).

    Returns:
        A tensor of shape (1, length, dim) containing the position embeddings.
    """
    if dim % 2 != 0:
        raise ValueError("dim must be divisible by 2")

    # Geometric frequency ladder, built on the same device as the indices
    exponents = torch.arange(dim // 2, dtype=torch.float32, device=indices.device)
    exponents /= dim / 2.0
    frequencies = 1.0 / (10000**exponents)
    # (dim // 2,)

    phases = torch.einsum("n,d->nd", indices, frequencies)
    # (length, dim // 2)

    # Pair each sin with its cos along a new trailing axis, then flatten so they interleave
    interleaved = torch.stack((phases.sin(), phases.cos()), dim=-1).flatten(1)
    # (length, dim)

    return interleaved.unsqueeze(0).contiguous()


def get_sinusoidal_embeddings_1d(dim: int, length: int, device=torch.device("cpu")) -> torch.Tensor:
    """Compute 1D sinusoidal embeddings for every position 0 .. length - 1.

    Args:
        dim: Embedding dimension. Must be divisible by 2.
        length: Length of the sequence.
        device: Device to create the embeddings on.

    Returns:
        A tensor of shape (1, length, dim) containing the position embeddings.
    """
    return get_specific_sinusoidal_embeddings_1d(
        dim,
        torch.arange(length, dtype=torch.int32, device=device),
    )


# Alternative names for the 1D sinusoidal helpers:
# - embeddings at arbitrary indices / timesteps
get_timestep_embeddings_1d = get_specific_sinusoidal_embeddings_1d
# - embeddings for every position 0 .. length - 1
get_all_timestep_embeddings_1d = get_sinusoidal_embeddings_1d
get_absolute_position_embeddings_1d = get_sinusoidal_embeddings_1d

In [None]:
timestep_indices = torch.tensor([1, 5, 11])
specific_shape = get_timestep_embeddings_1d(2, timestep_indices).shape
all_positions_shape = get_all_timestep_embeddings_1d(2, 10).shape
specific_shape, all_positions_shape

In [None]:
# | export


@populate_docstring
class AbsolutePositionEmbeddings1D(nn.Module):
    """1D Absolute Position Embeddings. May or may not be learnable. These have to be applied on the input manually
    and cannot be passed to attention layers directly. {CLASS_DESCRIPTION_1D_DOC}"""

    @populate_docstring
    def __init__(self, config: AbsolutePositionEmbeddings1DConfig = {}, **kwargs):
        """Initialize AbsolutePositionEmbeddings1D.

        Args:
            config: {CONFIG_INSTANCE_DOC}
            **kwargs: {CONFIG_KWARGS_DOC}
        """
        super().__init__()

        self.config = AbsolutePositionEmbeddings1DConfig.model_validate(config | kwargs)

        dim = self.config.dim
        length = self.config.length
        learnable = self.config.learnable

        # Sinusoidal embeddings built on the fly (when dim and length are not both configured) are cached here,
        # keyed by (dim, length)
        self.position_embeddings_cache = {}
        self.position_embeddings = None
        if dim is not None and length is not None:
            # Fixed-shape embeddings; frozen unless learnable is set
            self.position_embeddings = nn.Parameter(
                get_absolute_position_embeddings_1d(dim, length), requires_grad=learnable
            )

    @populate_docstring
    def forward(
        self,
        x: torch.Tensor,
        embedding_type: Literal["add", "concat"] = "add",
    ) -> torch.Tensor:
        """Apply absolute position embeddings to the input tensor.

        Args:
            x: {INPUT_1D_DOC}
            embedding_type: Type of embedding to apply. 'add' to add the position embeddings to the input,
                'concat' to concatenate them along the last dimension.

        Returns:
            {OUTPUT_1D_DOC}
        """
        assert x.ndim == 3, "Input tensor must be of shape (b, length, dim)"

        dim = self.config.dim
        length = self.config.length
        if self.position_embeddings is None:
            # Fill in whatever the config leaves unspecified from the input itself
            if dim is None:
                dim = x.shape[2]
            if length is None:
                length = x.shape[1]

        b = x.shape[0]

        if self.position_embeddings is not None:
            position_embeddings = repeat(self.position_embeddings, "1 l d -> b l d", b=b)
        else:
            cache_key = (dim, length)
            if cache_key not in self.position_embeddings_cache:
                self.position_embeddings_cache[cache_key] = get_absolute_position_embeddings_1d(dim, length)
            position_embeddings = repeat(self.position_embeddings_cache[cache_key], "1 l d -> b l d", b=b)
            # These embeddings are built on CPU in float32. Move them to x's device and, for floating inputs, to x's
            # dtype so that e.g. half-precision inputs are not silently promoted to float32.
            position_embeddings = position_embeddings.to(x.device)
            if x.is_floating_point():
                position_embeddings = position_embeddings.to(x.dtype)
        # (b, length, dim)

        if embedding_type == "add":
            x = x + position_embeddings
        elif embedding_type == "concat":
            # Feature-wise concatenation (last dimension), as documented; dim=1 would extend the sequence instead
            x = torch.cat([x, position_embeddings], dim=-1)
        else:
            raise NotImplementedError("Only 'add' and 'concat' are supported for embedding_type")

        return x

In [None]:
sample_input1 = torch.randn(3, 6, 2)
# (b, length, dim)

learnable_pe_1d = AbsolutePositionEmbeddings1D(dim=2, length=6, learnable=True)
display(learnable_pe_1d)
display(learnable_pe_1d(sample_input1).shape)

# dim and length inferred from the input
inferred_pe_1d = AbsolutePositionEmbeddings1D()
display(inferred_pe_1d(sample_input1).shape)

### RoPE

In [None]:
# | export


def get_rope_rotation_coefficients_1d(
    dim: int, length: int, base: float = 10000.0
) -> tuple[torch.Tensor, torch.Tensor]:
    """Compute the cos / sin coefficients used by 1D rotary position embeddings.

    Channel pair (2k, 2k + 1) at position p is rotated by the angle p * base^(-k / (dim / 2)).

    Args:
        dim: Embedding dimension. Must be divisible by 2.
        length: Length of the sequence.
        base: Base value to use for the rotation coefficients.

    Returns:
        A (cos, sin) tuple, each of shape (length, dim), with every angle repeated for both channels of its pair.
    """
    if dim % 2:
        raise ValueError("Dimension must be even to apply RoPE.")

    num_pairs = dim // 2
    exponents = -torch.arange(num_pairs) / num_pairs
    inverse_frequency = base**exponents
    # (num_pairs,)

    angles = torch.arange(length)[:, None] * inverse_frequency[None, :]
    # (length, num_pairs)

    # Duplicate each angle so it drives both channels of its (even, odd) pair
    angles = angles.repeat_interleave(2, dim=-1)
    # (length, dim)

    return angles.cos(), angles.sin()

In [None]:
rope_cos, rope_sin = get_rope_rotation_coefficients_1d(dim=4, length=5)
rope_cos, rope_sin

In [None]:
# | export


@populate_docstring
class RotaryPositionEmbeddings1D(nn.Module):
    """1D Rotary Position Embeddings. {CLASS_DESCRIPTION_1D_DOC}"""

    @populate_docstring
    def __init__(self, config: RotaryPositionEmbeddings1DConfig = {}, **kwargs):
        """Initialize RotaryPositionEmbeddings1D.

        Args:
            config: {CONFIG_INSTANCE_DOC}
            **kwargs: {CONFIG_KWARGS_DOC}
        """
        super().__init__()
        self.config = RotaryPositionEmbeddings1DConfig.model_validate(config | kwargs)

    @staticmethod
    @lru_cache(maxsize=64)
    def get_rotation_coefficients(
        dim: int,
        length: int,
        device: torch.device,
        dtype: torch.dtype = torch.float32,
        base: float = 10000.0,
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """Get cos and sin rotation coefficients, cached per argument combination.

        Args:
            dim: Rotated dimension. Must be even.
            length: Number of positions.
            device: Device to place the coefficients on.
            dtype: Data type of the coefficients.
            base: Base value for the rotation frequencies.

        Returns:
            A (cos, sin) tuple, each of shape (length, dim). The tensors are shared through the cache, so callers
            must not modify them in place.
        """
        cos, sin = get_rope_rotation_coefficients_1d(dim=dim, length=length, base=base)
        cos, sin = cos.to(device=device, dtype=dtype), sin.to(device=device, dtype=dtype)
        return cos, sin

    @staticmethod
    def rearrange_for_sin_coefficients(x: torch.Tensor) -> torch.Tensor:
        """Split the tensor into pairs along the last axis, flip each pair's order, and then negate the first
        element. That is, for an input tensor [a, b, c, d], the output will be [-b, a, -d, c].

        Args:
            x: Input tensor with last dimension dim

        Returns:
            Rearranged tensor
        """
        x = rearrange(x, "... (half_d two) -> ... half_d two", two=2).contiguous()
        # (..., half_d, 2)
        x1, x2 = x.unbind(-1)
        # (..., half_d), (..., half_d)
        x = torch.stack([-x2, x1], dim=-1)
        # (..., half_d, 2)
        x = rearrange(x, "... half_d two -> ... (half_d two)").contiguous()
        # (..., dim)
        return x

    def apply_rope(self, x: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor) -> torch.Tensor:
        """Apply 1D Rotary Position Embeddings to the given tensor.

        Args:
            x: Input tensor with last dimension dim
            cos: Cosine rotation coefficients
            sin: Sine rotation coefficients

        Returns:
            Tensor after applying 1D Rotary Position Embeddings
        """
        # Pairwise 2D rotation: (a, b) -> (a cos - b sin, b cos + a sin)
        return x * cos + self.rearrange_for_sin_coefficients(x) * sin

    @populate_docstring
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply 1D Rotary Position Embeddings.

        Args:
            x: {INPUT_1D_DOC}

        Returns:
            {OUTPUT_1D_DOC}
        """
        # Identify dim
        dim = x.shape[-1] if self.config.dim is None else self.config.dim

        # The (length, dim) coefficients broadcast against the last two axes of x, so positions are read from
        # axis -2. For a (b, length, dim) input this equals x.shape[1]; it also supports extra leading axes.
        length = x.shape[-2]

        # Get rotation coefficients (the configured base is honored)
        cos, sin = self.get_rotation_coefficients(dim, length, x.device, x.dtype, self.config.base)
        # (length, dim)

        # Apply rotation
        x = self.apply_rope(x, cos, sin)
        # same shape as the input

        return x

    def extra_repr(self):
        return f"dim={self.config.dim}, base={self.config.base}"

In [None]:
# Demonstrate the RoPE relative-position property. sample_input2 is sample_input1 shifted right by one position, so
# pair (3, 5) in sample_input2 holds the same vectors as pair (2, 4) in sample_input1. Both pairs have the same
# relative offset, so their dot products after RoPE should match (up to float error), while the dot products of
# the raw inputs generally differ from the roped ones.
T = 5
sample_input1 = torch.randn((2, T, 4))
test = RotaryPositionEmbeddings1D(dim=4)
display(test)
roped_output1 = test(sample_input1)

print(sample_input1.shape)
print(roped_output1.shape)

i = 2
j = 4
print(f"Dot product between {i} and {j} input: {torch.dot(sample_input1[0, i], sample_input1[0, j])}")
print(f"Dot product between {i} and {j} output: {torch.dot(roped_output1[0, i], roped_output1[0, j])}")

# Prepend one random position: every original token moves one step to the right
sample_input2 = torch.cat([torch.randn(2, 1, 4), sample_input1], dim=1)
roped_output2 = test(sample_input2)

print(sample_input2.shape)
print(roped_output2.shape)

# Same tokens as (2, 4) above, now at absolute positions (3, 5)
i = 3
j = 5
print(f"Dot product between {i} and {j} input: {torch.dot(sample_input2[0, i], sample_input2[0, j])}")
print(f"Dot product between {i} and {j} output: {torch.dot(roped_output2[0, i], roped_output2[0, j])}")

torch.Size([2, 5, 4])
torch.Size([2, 5, 4])
Dot product between 2 and 4 input: 1.6273961067199707
Dot product between 2 and 4 output: 1.0360474586486816
torch.Size([2, 6, 4])
torch.Size([2, 6, 4])
Dot product between 3 and 5 input: 1.6273961067199707
Dot product between 3 and 5 output: 1.0360478162765503


In [None]:
# | export


@populate_docstring
class RotaryPositionEmbeddings3D(RotaryPositionEmbeddings1D):
    """3D Rotary Position Embeddings. {CLASS_DESCRIPTION_3D_DOC}"""

    @populate_docstring
    def __init__(self, config: RotaryPositionEmbeddings3DConfig = {}, **kwargs):
        """Initialize RotaryPositionEmbeddings3D.

        Args:
            config: {CONFIG_INSTANCE_DOC}
            **kwargs: {CONFIG_KWARGS_DOC}
        """
        super().__init__(config, **kwargs)

        # Replace whatever config the parent stored with the 3D config, which provides the per-axis
        # channel split (``split`` / ``get_split_as_ints``) used by ``forward`` and ``extra_repr``.
        self.config = RotaryPositionEmbeddings3DConfig.model_validate(config | kwargs)

    def apply_rope(self, x: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor, axis: int = 3) -> torch.Tensor:
        """Apply 1D Rotary Position Embeddings to the given tensor along one particular spatial axis.

        Args:
            x: Channels-last input tensor holding this axis' channel slice, laid out as (B, Z, Y, X, D_axis)
            cos: Cosine rotation coefficients of shape (length, D_axis) for this axis
            sin: Sine rotation coefficients of shape (length, D_axis) for this axis
            axis: Index of the spatial axis in the (B, Z, Y, X, D) layout that the coefficients vary along
                (1 = Z, 2 = Y, 3 = X). Defaults to 3, which inserts no broadcast dims and therefore behaves
                exactly like the parent ``apply_rope(x, cos, sin)``, keeping this override call-compatible.

        Returns:
            Tensor after applying 1D Rotary Position Embeddings
        """
        # Insert a singleton dim for every spatial axis that follows ``axis`` so the (length, D_axis)
        # coefficients broadcast against (B, Z, Y, X, D_axis)
        num_unsqueezes = 3 - axis
        for _ in range(num_unsqueezes):
            cos = cos.unsqueeze(1)
            sin = sin.unsqueeze(1)
        return super().apply_rope(x, cos, sin)

    @populate_docstring
    def forward(self, x: torch.Tensor, channels_first: bool = True) -> torch.Tensor:
        """Apply 3D Rotary Position Embeddings.

        Args:
            x: {INPUT_3D_DOC}
            channels_first: {CHANNELS_FIRST_DOC}

        Returns:
            {OUTPUT_3D_DOC}
        """
        # Rearrange to channels last format
        x = rearrange_channels(x, channels_first, False)
        # (B, Z, Y, X, D)

        # Decide on dim
        if self.config.dim is None:
            dim = x.shape[-1]
        else:
            dim = self.config.dim

        # Split tensor into three parts for each axis; channels beyond `dim` are passed through untouched
        z_dim, y_dim, x_dim = list(self.config.get_split_as_ints(dim))
        rest_dim = x.shape[-1] - z_dim - y_dim - x_dim
        z_part, y_part, x_part, rest = x.split([z_dim, y_dim, x_dim, rest_dim], dim=-1)
        # (B, Z, Y, X, D_z), (B, Z, Y, X, D_y), (B, Z, Y, X, D_x), (B, Z, Y, X, D_rest)

        # Get rotation coefficients
        cos_z, sin_z = self.get_rotation_coefficients(z_dim, x.shape[1], x.device, x.dtype)
        cos_y, sin_y = self.get_rotation_coefficients(y_dim, x.shape[2], x.device, x.dtype)
        cos_x, sin_x = self.get_rotation_coefficients(x_dim, x.shape[3], x.device, x.dtype)
        # (length, dim)

        # Apply rotation
        z_part = self.apply_rope(z_part, cos_z, sin_z, axis=1)
        y_part = self.apply_rope(y_part, cos_y, sin_y, axis=2)
        x_part = self.apply_rope(x_part, cos_x, sin_x, axis=3)

        # Concatenate the three parts along the channel dimension
        x = torch.cat([z_part, y_part, x_part, rest], dim=-1)
        # (B, Z, Y, X, D_z + D_y + D_x + D_rest)

        # Revert channels rearrangement
        x = rearrange_channels(x, False, channels_first)
        # (B, [D], Z, Y, X, [D])

        return x

    def extra_repr(self):
        return super().extra_repr() + f", split={self.config.split}"

In [None]:
torch.manual_seed(0)  # make the printed values reproducible across runs

sample_input1 = torch.randn((2, 6, 10, 11, 12))
test = RotaryPositionEmbeddings3D()
display(test)
roped_output1 = test(sample_input1)

print(sample_input1.shape)
print(roped_output1.shape)

i = (1, 2, 3)
j = (2, 3, 4)
print(
    f"Dot product between {i} and {j} input: {torch.dot(sample_input1[0, :, i[0], i[1], i[2]], sample_input1[0, :, j[0], j[1], j[2]])}"
)
print(
    f"Dot product between {i} and {j} output: {torch.dot(roped_output1[0, :, i[0], i[1], i[2]], roped_output1[0, :, j[0], j[1], j[2]])}"
)

# Embed the first volume at offset (1, 1, 1) inside a larger one: every voxel shifts by one along each axis
sample_input2 = torch.randn(2, 6, 11, 12, 13)
sample_input2[:, :, 1:, 1:, 1:] = sample_input1
roped_output2 = test(sample_input2)

print(sample_input2.shape)
print(roped_output2.shape)

i = (2, 3, 4)
j = (3, 4, 5)
print(
    f"Dot product between {i} and {j} input: {torch.dot(sample_input2[0, :, i[0], i[1], i[2]], sample_input2[0, :, j[0], j[1], j[2]])}"
)
print(
    f"Dot product between {i} and {j} output: {torch.dot(roped_output2[0, :, i[0], i[1], i[2]], roped_output2[0, :, j[0], j[1], j[2]])}"
)

# RoPE dot products depend only on relative positions, so the shifted copy must reproduce ALL pairwise
# dot products of the original volume (up to float rounding), not just the pair printed above.
assert roped_output1.shape == sample_input1.shape
tokens_original = roped_output1.flatten(2).transpose(1, 2)  # (B, Z*Y*X, D)
tokens_shifted = roped_output2[:, :, 1:, 1:, 1:].flatten(2).transpose(1, 2)  # (B, Z*Y*X, D)
torch.testing.assert_close(
    tokens_shifted @ tokens_shifted.transpose(1, 2),
    tokens_original @ tokens_original.transpose(1, 2),
    rtol=1e-4,
    atol=1e-4,
)

torch.Size([2, 6, 10, 11, 12])
torch.Size([2, 6, 10, 11, 12])
Dot product between (1, 2, 3) and (2, 3, 4) input: 0.8099715113639832
Dot product between (1, 2, 3) and (2, 3, 4) output: -2.4859845638275146
torch.Size([2, 6, 11, 12, 13])
torch.Size([2, 6, 11, 12, 13])
Dot product between (2, 3, 4) and (3, 4, 5) input: 0.8099715113639832
Dot product between (2, 3, 4) and (3, 4, 5) output: -2.4859840869903564


### Patch embeddings

In [None]:
# | export


@populate_docstring
class PatchEmbeddings3D(CNNBlock3D):
    """3D Patch Embeddings using a convolutional layer. {CLASS_DESCRIPTION_3D_DOC}"""

    @populate_docstring
    def __init__(self, config: PatchEmbeddings3DConfig = {}, checkpointing_level: int = 0, **kwargs):
        """Initialize PatchEmbeddings3D.

        The embedding is a convolution whose kernel size and stride both equal the patch size, with no
        padding, so each non-overlapping patch is projected to ``dim`` output channels.

        Args:
            config: {CONFIG_INSTANCE_DOC}
            checkpointing_level: {CHECKPOINTING_LEVEL_DOC}
            **kwargs: {CONFIG_KWARGS_DOC}
        """
        self.config = PatchEmbeddings3DConfig.model_validate(config | kwargs)

        # Settings that turn the generic CNN block into a patchifier
        patch_size = self.config.get("patch_size")
        conv_overrides = {
            "kernel_size": patch_size,
            "stride": patch_size,
            "padding": 0,
            "out_channels": self.config.get("dim"),
        }
        cnn_config = {**self.config.model_dump(), **conv_overrides}

        super().__init__(cnn_config, checkpointing_level, **kwargs)

In [None]:
test = PatchEmbeddings3D(patch_size=(1, 8, 8), in_channels=1, dim=12)
display(test)
o = test(torch.randn(2, 1, 32, 512, 512))
display(o.shape)
# Non-overlapping (1, 8, 8) patches map (32, 512, 512) -> (32, 64, 64), projected to dim=12 channels (first)
assert o.shape == (2, 12, 32, 64, 64)

In [None]:
test = PatchEmbeddings3D(in_channels=1, dim=12, patch_size=(1, 8, 8))
display(test)
channels_last_input = torch.randn(2, 32, 512, 512, 1)  # channels-last: single input channel is the final axis
o = test(channels_last_input, channels_first=False)
display(o.shape)

# nbdev

In [None]:
# nbdev: export the `# | export` cells of this notebook into the library module named by `# | default_exp`
!nbdev_export

# Rough work