In [1]:
# | default_exp layers/embeddings

# Imports

In [2]:
# | export

from typing import Union

import numpy as np
import torch
import torch.nn.functional as F
from einops import rearrange, repeat
from torch import nn

from vision_architectures.utils.normalizations import get_norm_layer

# Architecture

### Position Embeddings

In [3]:
# | export


def get_coords_grid(grid_size):
    """Build an integer coordinate grid over a 3D volume.

    Args:
        grid_size: ``(d, h, w)`` extent of the grid along each axis.

    Returns:
        int32 tensor of shape ``(3, d, h, w)`` where ``grid[k, i, j, l]`` is the k-th
        coordinate of voxel ``(i, j, l)`` (matrix / ``"ij"`` indexing).
    """
    depth, height, width = grid_size

    axis_ranges = (
        torch.arange(depth, dtype=torch.int32),
        torch.arange(height, dtype=torch.int32),
        torch.arange(width, dtype=torch.int32),
    )

    # (3, d, h, w)
    return torch.stack(torch.meshgrid(*axis_ranges, indexing="ij"), dim=0)

In [4]:
# | export


class RelativePositionEmbeddings3D(nn.Module):
    """Learned per-head bias for every 3D relative offset between patches of a window.

    A table holds one learnable value per head for each possible relative offset
    ``(dz, dy, dx)`` inside a window of ``grid_size`` patches. ``forward`` gathers it into a
    dense ``(1, num_heads, N, N)`` bias with ``N = d * h * w``.

    Args:
        num_heads: Number of attention heads (one bias per head and offset).
        grid_size: Window size ``(d, h, w)`` in patches.
    """

    def __init__(
        self,
        num_heads,
        grid_size: tuple[int, int, int],
    ):
        super().__init__()

        self.num_heads = num_heads
        # NOTE: despite its name this stores the (d, h, w) window shape, not a count.
        self.num_patches = grid_size

        # TODO: Add embed_spacing_info functionality

        # Number of distinct offsets per axis: -(g - 1) .. (g - 1)
        relative_limits = (2 * grid_size[0] - 1, 2 * grid_size[1] - 1, 2 * grid_size[2] - 1)

        self.relative_position_bias_table = nn.Parameter(torch.randn(num_heads, int(np.prod(relative_limits))))
        # (num_heads, (2d - 1) * (2h - 1) * (2w - 1))

        # Pair-wise relative position index for each token inside the window
        coords = get_coords_grid(grid_size)
        coords_flatten = rearrange(coords, "three d h w -> three (d h w)", three=3)
        relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :]
        # (3, num_patches, num_patches): coords of query i minus coords of key j
        relative_coords = relative_coords.permute(1, 2, 0).contiguous()
        # Shift offsets to start at 0, then flatten (dz, dy, dx) row-major into one table index
        relative_coords[:, :, 0] += grid_size[0] - 1
        relative_coords[:, :, 1] += grid_size[1] - 1
        relative_coords[:, :, 2] += grid_size[2] - 1
        relative_position_index: torch.Tensor = (
            relative_coords[:, :, 0] * relative_limits[1] * relative_limits[2]
            + relative_coords[:, :, 1] * relative_limits[2]
            + relative_coords[:, :, 2]
        )
        # (num_patches, num_patches)
        # Non-persistent buffer: follows the module across devices (.to / .cuda) but is not
        # written to the state_dict, so existing checkpoints keep loading.
        self.register_buffer("relative_position_index", relative_position_index.flatten().long(), persistent=False)
        # (num_patches * num_patches,)

    def forward(self):
        """Return the dense relative position bias of shape ``(1, num_heads, N, N)``."""
        num_patches = int(np.prod(self.num_patches))

        relative_position_embeddings = self.relative_position_bias_table[:, self.relative_position_index]
        # (num_heads, num_patches * num_patches)
        # The gather is head-major, so split the last axis into (query, key) directly. Reshaping
        # to (..., num_heads) last and permuting would mix values across heads and patch pairs.
        relative_position_embeddings = relative_position_embeddings.reshape(
            1, self.num_heads, num_patches, num_patches
        )
        # (1, num_heads, num_patches, num_patches)
        return relative_position_embeddings.contiguous()

In [5]:
# Smoke test: a 4x4x4 window with 6 heads gives one (64 x 64) bias map per head
test = RelativePositionEmbeddings3D(6, (4, 4, 4))
display(test)
display(test().shape)
assert test().shape == (1, 6, 4 * 4 * 4, 4 * 4 * 4)

[1;35mRelativePositionEmbeddings3D[0m[1m([0m[1m)[0m

[1;35mtorch.Size[0m[1m([0m[1m[[0m[1;36m1[0m, [1;36m6[0m, [1;36m64[0m, [1;36m64[0m[1m][0m[1m)[0m

In [6]:
# | export


class RelativePositionEmbeddings3DMetaNetwork(nn.Module):
    """Relative position bias produced by a small MLP over normalized 3D relative coordinates.

    Every possible relative offset inside a ``grid_size`` window is normalized per axis,
    log-spaced, and fed through ``cpb_mlp`` to get one value per head. ``forward`` gathers
    those into a ``(num_heads, N, N)`` bias (``N = d * h * w``) squashed to ``(0, 16)``.

    Args:
        num_heads: Number of attention heads (output size of the MLP).
        grid_size: Window size ``(d, h, w)`` in patches.
    """

    def __init__(
        self,
        num_heads,
        grid_size: tuple[int, int, int],
    ):
        super().__init__()

        self.num_heads = num_heads
        # NOTE: despite its name this stores the (d, h, w) window shape, not a count.
        self.num_patches = grid_size

        # TODO: Add embed_spacing_info functionality
        # Maps a 3D relative coordinate to one bias value per head
        self.cpb_mlp = nn.Sequential(
            nn.Linear(3, 512, bias=True),
            nn.ReLU(inplace=True),
            nn.Linear(512, num_heads, bias=False),
        )

        # Number of distinct offsets per axis: -(g - 1) .. (g - 1)
        relative_limits = (2 * grid_size[0] - 1, 2 * grid_size[1] - 1, 2 * grid_size[2] - 1)

        # Relative coordinates table
        relative_coords_table = get_coords_grid(relative_limits).float()
        for i in range(3):
            # Centre axis i on 0 and scale it to [-1, 1] using THAT axis' window size
            # (using grid_size[0] for every axis is wrong for non-cubic windows).
            relative_coords_table[i] = (relative_coords_table[i] - (grid_size[i] - 1)) / (
                grid_size[i] - 1 + 1e-8  # small value added to ensure there is no NaN when window size is 1
            )
        relative_coords_table = rearrange(
            relative_coords_table,
            "three num_patches_z num_patches_y num_patches_x -> num_patches_z num_patches_y num_patches_x three",
        ).contiguous()
        relative_coords_table *= 8  # Normalize to -8, 8
        # Log-spacing: x -> sign(x) * log2(|x| + 1) / log2(8)
        relative_coords_table = (
            torch.sign(relative_coords_table) * torch.log2(torch.abs(relative_coords_table) + 1.0) / np.log2(8)
        )
        # (2d - 1, 2h - 1, 2w - 1, 3)
        # Allow moving this to and from cuda whenever required but don't save to state_dict
        self.register_buffer("relative_coords_table", relative_coords_table, persistent=False)

        # Pair-wise relative position index for each token inside the window
        coords = get_coords_grid(grid_size)
        coords_flatten = rearrange(coords, "three d h w -> three (d h w)", three=3)
        relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :]
        # (3, num_patches, num_patches): coords of query i minus coords of key j
        relative_coords = relative_coords.permute(1, 2, 0).contiguous()
        # Shift offsets to start at 0, then flatten (dz, dy, dx) row-major into one table index
        relative_coords[:, :, 0] += grid_size[0] - 1
        relative_coords[:, :, 1] += grid_size[1] - 1
        relative_coords[:, :, 2] += grid_size[2] - 1
        relative_position_index: torch.Tensor = (
            relative_coords[:, :, 0] * relative_limits[1] * relative_limits[2]
            + relative_coords[:, :, 1] * relative_limits[2]
            + relative_coords[:, :, 2]
        )
        # (num_patches, num_patches)
        # Same device handling as relative_coords_table: follows the module, not in state_dict
        self.register_buffer("relative_position_index", relative_position_index.flatten().long(), persistent=False)
        # (num_patches * num_patches,)

    def forward(self):
        """Return the relative position bias of shape ``(num_heads, N, N)`` with values in (0, 16)."""
        num_patches = int(np.prod(self.num_patches))

        relative_position_embeddings_table = self.cpb_mlp(self.relative_coords_table)
        # (2d - 1, 2h - 1, 2w - 1, num_heads)
        relative_position_embeddings_table = relative_position_embeddings_table.reshape(-1, self.num_heads)
        # ((2d - 1) * (2h - 1) * (2w - 1), num_heads)
        relative_position_embeddings = relative_position_embeddings_table[self.relative_position_index]
        # (num_patches * num_patches, num_heads)
        relative_position_embeddings = (
            relative_position_embeddings.reshape(num_patches, num_patches, self.num_heads).permute(2, 0, 1).contiguous()
        )
        # (num_heads, num_patches, num_patches)
        # Squash to the open interval (0, 16)
        relative_position_embeddings = 16 * torch.sigmoid(relative_position_embeddings)
        # (num_heads, num_patches, num_patches)
        return relative_position_embeddings

In [7]:
# Smoke test: the MLP-based variant returns (num_heads, N, N) without a leading batch axis
test = RelativePositionEmbeddings3DMetaNetwork(6, (4, 4, 4))
display(test)
display(test().shape)
assert test().shape == (6, 4 * 4 * 4, 4 * 4 * 4)


[1;35mRelativePositionEmbeddings3DMetaNetwork[0m[1m([0m
  [1m([0mcpb_mlp[1m)[0m: [1;35mSequential[0m[1m([0m
    [1m([0m[1;36m0[0m[1m)[0m: [1;35mLinear[0m[1m([0m[33min_features[0m=[1;36m3[0m, [33mout_features[0m=[1;36m512[0m, [33mbias[0m=[3;92mTrue[0m[1m)[0m
    [1m([0m[1;36m1[0m[1m)[0m: [1;35mReLU[0m[1m([0m[33minplace[0m=[3;92mTrue[0m[1m)[0m
    [1m([0m[1;36m2[0m[1m)[0m: [1;35mLinear[0m[1m([0m[33min_features[0m=[1;36m512[0m, [33mout_features[0m=[1;36m6[0m, [33mbias[0m=[3;91mFalse[0m[1m)[0m
  [1m)[0m
[1m)[0m

[1;35mtorch.Size[0m[1m([0m[1m[[0m[1;36m6[0m, [1;36m64[0m, [1;36m64[0m[1m][0m[1m)[0m

In [8]:
# | export

# Type alias covering both relative position embedding implementations.
# NOTE: their forward() outputs differ in rank: RelativePositionEmbeddings3D returns
# (1, num_heads, N, N) while RelativePositionEmbeddings3DMetaNetwork returns (num_heads, N, N).
RelativePositionEmbeddings = Union[RelativePositionEmbeddings3D, RelativePositionEmbeddings3DMetaNetwork]

In [9]:
# | export


def get_absolute_position_embeddings_3d(dim, grid_size, spacing=(1, 1, 1)):
    """Fixed sine-cosine absolute position embeddings for a 3D grid.

    The channel axis is split into three consecutive chunks of ``dim // 3`` channels, one per
    grid axis (d, h, w). Each chunk holds ``dim // 6`` sines followed by ``dim // 6`` cosines
    of the integer coordinate along that axis, with frequencies ``1 / 10000 ** (k / (dim / 6))``.

    Args:
        dim: Embedding dimension; must be divisible by 6.
        grid_size: ``(d, h, w)`` number of positions along each axis.
        spacing: Per-axis spacing; axis ``i``'s chunk is multiplied by ``spacing[i] / min(spacing)``.

    Returns:
        Tensor of shape ``(1, dim, d, h, w)``.

    Raises:
        ValueError: if ``dim`` is not divisible by 6.
    """
    if dim % 6 != 0:
        raise ValueError("embed_dim must be divisible by 6")

    grid = get_coords_grid(grid_size)
    # (3, d, h, w)

    omega = torch.arange(dim // 6, dtype=torch.float32)
    omega /= dim / 6.0
    omega = 1.0 / 10000**omega
    # (dim // 6,)

    # Ratio of each axis' spacing to the finest spacing
    patch_multiplier = torch.Tensor(spacing) / min(spacing)

    position_embeddings = []
    for i, axis_coords in enumerate(grid):
        # Coordinates come back as int32; cast explicitly instead of relying on einsum to
        # promote mixed int/float operands.
        axis_coords = axis_coords.reshape(-1).to(omega.dtype)
        out = torch.einsum("m,d->md", axis_coords, omega)
        # (d * h * w, dim // 6)

        emb = torch.cat([torch.sin(out), torch.cos(out)], dim=1) * patch_multiplier[i]
        # (d * h * w, dim // 3)
        position_embeddings.append(emb)

    position_embeddings = torch.cat(position_embeddings, dim=1)
    # (d * h * w, dim)
    position_embeddings = rearrange(
        position_embeddings, "(d h w) e -> 1 e d h w", d=grid_size[0], h=grid_size[1], w=grid_size[2]
    )
    # (1, dim, d, h, w)

    return position_embeddings

In [10]:
# | export


class AbsolutePositionEmbeddings3D(nn.Module):
    def __init__(self, dim, grid_size: tuple[int, int, int] | None = None, learnable=False, spacing=(1, 1, 1)):
        super().__init__()

        self.dim = dim

        if learnable and grid_size is None:
            raise ValueError("grid_size must be provided when learnable=True")

        self.position_embeddings_cache = {}
        self.position_embeddings = None
        if grid_size is not None:
            self.position_embeddings_cache[grid_size] = get_absolute_position_embeddings_3d(
                dim, grid_size, spacing=spacing
            )
            self.position_embeddings = nn.Parameter(self.position_embeddings_cache[grid_size], requires_grad=learnable)

    def forward(self, batch_size=None, grid_size=None, spacings=None, device=torch.device('cpu')):
        assert (
            self.position_embeddings is not None or grid_size is not None
        ), "grid_size must be provided"
        assert batch_size is not None or spacings is not None, "Either batch_size or spacings must be provided"

        if self.position_embeddings is not None:
            position_embeddings = self.position_embeddings
        else:
            if grid_size not in self.position_embeddings_cache:
                self.position_embeddings_cache[grid_size] = get_absolute_position_embeddings_3d(self.dim, grid_size)
            position_embeddings = self.position_embeddings_cache[grid_size].to(device)
        # (1, dim, d, h, w)

        if batch_size is not None:
            b = batch_size
        else:
            assert spacings.ndim == 2 and spacings.shape[1] == 3, "spacings must be of shape (batch_size, 3)"
            assert self.dim % 3 == 0, "embed_dim must be divisible by 3"

            b = spacings.shape[0]

        position_embeddings = repeat(position_embeddings, "1 e d h w -> b e d h w", b=b)

        if spacings is not None:
            # (b, 3)
            spacings = repeat(spacings, "b three -> b (three dim_by_three) 1 1 1", three=3, dim_by_three=self.dim // 3)
            # (b, dim, 1, 1, 1)

            position_embeddings = position_embeddings * spacings
            # (b, dim, d, h, w)

        return position_embeddings

In [11]:
# Learnable embeddings for a fixed 4x4x4 grid, expanded by batch size or by per-sample spacings
test = AbsolutePositionEmbeddings3D(6, (4, 4, 4), learnable=True)
display(test)
batch_embeddings = test(batch_size=2)
display(batch_embeddings.shape)
assert batch_embeddings.shape == (2, 6, 4, 4, 4)
spacing_embeddings = test(spacings=torch.randn(4, 3))
display(spacing_embeddings.shape)
assert spacing_embeddings.shape == (4, 6, 4, 4, 4)

# Fixed sin-cos embeddings built lazily for the grid size requested at call time
test = AbsolutePositionEmbeddings3D(6)
lazy_embeddings = test(grid_size=(3, 3, 3), batch_size=2)
display(lazy_embeddings.shape)
assert lazy_embeddings.shape == (2, 6, 3, 3, 3)

[1;35mAbsolutePositionEmbeddings3D[0m[1m([0m[1m)[0m

[1;35mtorch.Size[0m[1m([0m[1m[[0m[1;36m2[0m, [1;36m6[0m, [1;36m4[0m, [1;36m4[0m, [1;36m4[0m[1m][0m[1m)[0m

[1;35mtorch.Size[0m[1m([0m[1m[[0m[1;36m4[0m, [1;36m6[0m, [1;36m4[0m, [1;36m4[0m, [1;36m4[0m[1m][0m[1m)[0m

[1;35mtorch.Size[0m[1m([0m[1m[[0m[1;36m2[0m, [1;36m6[0m, [1;36m3[0m, [1;36m3[0m, [1;36m3[0m[1m][0m[1m)[0m

### Patch Embeddings

In [12]:
# | export


class PatchEmbeddings3D(nn.Module):
    """Split a volume into non-overlapping 3D patches and project each to ``dim`` channels.

    A ``Conv3d`` with ``kernel_size == stride == patch_size`` does the patchify + projection.
    The result is then normalized over the channel axis.

    Args:
        patch_size: ``(pz, py, px)`` patch extent.
        in_channels: Channels of the input volume.
        dim: Output embedding channels.
        norm_layer: ``None`` for no normalization, an ``nn.Module`` subclass (instantiated
            with ``dim``), an ``nn.Module`` instance (used as-is), or anything else, which is
            resolved through ``get_norm_layer(norm_layer, dim)``. Defaults to ``"layernorm"``.
    """

    def __init__(self, patch_size: tuple[int, int, int], in_channels: int, dim: int, norm_layer="layernorm"):
        super().__init__()

        self.patch_embeddings = nn.Conv3d(
            in_channels=in_channels,
            out_channels=dim,
            kernel_size=patch_size,
            stride=patch_size,
        )

        if norm_layer is None:
            self.normalization = nn.Identity()
        elif isinstance(norm_layer, nn.Module):
            # Already-built module: calling it with `dim` would run its forward pass
            self.normalization = norm_layer
        elif isinstance(norm_layer, type) and issubclass(norm_layer, nn.Module):
            # Module class such as nn.LayerNorm: build it for `dim` channels
            self.normalization = norm_layer(dim)
        else:
            self.normalization = get_norm_layer(norm_layer, dim)

    def forward(self, pixel_values: torch.Tensor):
        """Embed ``pixel_values`` of shape ``(b, c, z, y, x)`` into ``(b, dim, z', y', x')``."""
        embeddings = self.patch_embeddings(pixel_values)
        # (b, dim, num_patches_z, num_patches_y, num_patches_x)
        # Channels last so normalization layers act over the embedding dimension
        embeddings = embeddings.movedim(1, -1)
        # (b, num_patches_z, num_patches_y, num_patches_x, dim)
        embeddings = self.normalization(embeddings)
        # (b, num_patches_z, num_patches_y, num_patches_x, dim)
        embeddings = embeddings.movedim(-1, 1)
        # (b, dim, num_patches_z, num_patches_y, num_patches_x)

        return embeddings

In [13]:
# Patchify a (2, 1, 32, 512, 512) volume with (1, 8, 8) patches into 12 channels
test = PatchEmbeddings3D((1, 8, 8), 1, 12)
display(test)
o = test(torch.randn(2, 1, 32, 512, 512))
display(o.shape)
assert o.shape == (2, 12, 32 // 1, 512 // 8, 512 // 8)


[1;35mPatchEmbeddings3D[0m[1m([0m
  [1m([0mpatch_embeddings[1m)[0m: [1;35mConv3d[0m[1m([0m[1;36m1[0m, [1;36m12[0m, [33mkernel_size[0m=[1m([0m[1;36m1[0m, [1;36m8[0m, [1;36m8[0m[1m)[0m, [33mstride[0m=[1m([0m[1;36m1[0m, [1;36m8[0m, [1;36m8[0m[1m)[0m[1m)[0m
  [1m([0mnormalization[1m)[0m: [1;35mLayerNorm[0m[1m([0m[1m([0m[1;36m12[0m,[1m)[0m, [33meps[0m=[1;36m1e[0m[1;36m-05[0m, [33melementwise_affine[0m=[3;92mTrue[0m[1m)[0m
[1m)[0m

[1;35mtorch.Size[0m[1m([0m[1m[[0m[1;36m2[0m, [1;36m12[0m, [1;36m32[0m, [1;36m64[0m, [1;36m64[0m[1m][0m[1m)[0m

# nbdev

In [14]:
# Regenerate the library module from the cells marked `# | export` (nbdev)
!nbdev_export

# Rough work