In [None]:
# | default_exp nets/detr_3d

# Imports

In [None]:
# | export

import numpy as np
import torch
from einops import rearrange, repeat
from huggingface_hub import PyTorchModelHubMixin
from torch import nn

from vision_architectures.layers.attention import Attention1D, Attention1DMLP

# Architecture

In [None]:
# | export


class DETR3DDecoderLayer(nn.Module):
    """One DETR-style decoder layer.

    The layer applies three sub-blocks in sequence, each with a residual connection:

    1. self-attention among the object queries,
    2. cross-attention from the queries to the encoder ``embeddings``,
    3. an MLP.

    Every sub-block output is layer-normalised *before* it is added back onto its
    residual input, i.e. ``x = x + LayerNorm(sublayer(x))``. Note that this is not
    the textbook post-norm form ``LayerNorm(x + sublayer(x))``.

    Args:
        dim: token embedding dimension.
        num_heads: number of attention heads (passed to ``Attention1D``).
        mlp_ratio: hidden-size ratio of the MLP (passed to ``Attention1DMLP``).
        layer_norm_eps: epsilon of the three ``nn.LayerNorm`` modules.
        attn_drop_prob: attention dropout of both attention sub-blocks.
        proj_drop_prob: output-projection dropout of both attention sub-blocks.
        mlp_drop_prob: dropout inside the MLP.
    """

    def __init__(
        self,
        dim,
        num_heads,
        mlp_ratio,
        layer_norm_eps,
        attn_drop_prob=0.0,
        proj_drop_prob=0.0,
        mlp_drop_prob=0.0,
    ):
        super().__init__()

        attention_kwargs = dict(attn_drop_prob=attn_drop_prob, proj_drop_prob=proj_drop_prob)

        # Submodules are created in this exact order so that parameter
        # initialisation consumes the RNG identically and state_dict keys stay stable.
        self.mhsa = Attention1D(dim, num_heads, **attention_kwargs)
        self.layernorm1 = nn.LayerNorm(dim, eps=layer_norm_eps)
        self.mhca = Attention1D(dim, num_heads, **attention_kwargs)
        self.layernorm2 = nn.LayerNorm(dim, eps=layer_norm_eps)
        self.mlp = Attention1DMLP(dim, mlp_ratio, mlp_drop_prob=mlp_drop_prob)
        self.layernorm3 = nn.LayerNorm(dim, eps=layer_norm_eps)

    def forward(self, object_queries: torch.Tensor, embeddings: torch.Tensor):
        """Run the three residual sub-blocks.

        Args:
            object_queries: (b, num_possible_objects, dim)
            embeddings: (b, num_embed_tokens, dim)

        Returns:
            Tensor of shape (b, num_possible_objects, dim).
        """
        # Attention1D is called with three tensors; presumably (query, key, value) — confirm in layers.attention.
        x = object_queries

        # Self-attention: the queries attend to each other.
        x = x + self.layernorm1(self.mhsa(x, x, x))

        # Cross-attention: the queries attend to the encoder embeddings.
        x = x + self.layernorm2(self.mhca(x, embeddings, embeddings))

        # Position-wise feed-forward.
        x = x + self.layernorm3(self.mlp(x))

        return x

In [None]:
# Smoke test: 10 object queries attending to 64 embedding tokens of width 54.
test = DETR3DDecoderLayer(dim=54, num_heads=6, mlp_ratio=2, layer_norm_eps=1e-6)
display(test)
demo_queries = torch.randn(2, 10, 54)
demo_memory = torch.randn(2, 64, 54)
display(test(demo_queries, demo_memory).shape)


[1;35mDETR3DDecoderLayer[0m[1m([0m
  [1m([0mmhsa[1m)[0m: [1;35mAttention1D[0m[1m([0m
    [1m([0mW_q[1m)[0m: [1;35mLinear[0m[1m([0m[33min_features[0m=[1;36m54[0m, [33mout_features[0m=[1;36m54[0m, [33mbias[0m=[3;92mTrue[0m[1m)[0m
    [1m([0mW_k[1m)[0m: [1;35mLinear[0m[1m([0m[33min_features[0m=[1;36m54[0m, [33mout_features[0m=[1;36m54[0m, [33mbias[0m=[3;92mTrue[0m[1m)[0m
    [1m([0mW_v[1m)[0m: [1;35mLinear[0m[1m([0m[33min_features[0m=[1;36m54[0m, [33mout_features[0m=[1;36m54[0m, [33mbias[0m=[3;92mTrue[0m[1m)[0m
    [1m([0mproj[1m)[0m: [1;35mLinear[0m[1m([0m[33min_features[0m=[1;36m54[0m, [33mout_features[0m=[1;36m54[0m, [33mbias[0m=[3;92mTrue[0m[1m)[0m
    [1m([0mproj_drop[1m)[0m: [1;35mDropout[0m[1m([0m[33mp[0m=[1;36m0[0m[1;36m.0[0m, [33minplace[0m=[3;91mFalse[0m[1m)[0m
  [1m)[0m
  [1m([0mlayernorm1[1m)[0m: [1;35mLayerNorm[0m[1m([0m[1m([0m[1;36m54[0m,[1m

[1;35mtorch.Size[0m[1m([0m[1m[[0m[1;36m2[0m, [1;36m10[0m, [1;36m54[0m[1m][0m[1m)[0m

### Decoder

In [None]:
# | export


class DETR3DDecoder(nn.Module, PyTorchModelHubMixin):
    """A stack of ``config["decoder_depth"]`` ``DETR3DDecoderLayer`` blocks.

    Each layer is built from the config keys ``dim``, ``num_heads``, ``mlp_ratio``,
    ``layer_norm_eps``, ``attn_drop_prob``, ``proj_drop_prob`` and ``mlp_drop_prob``.
    """

    def __init__(self, config):
        super().__init__()

        depth = config["decoder_depth"]

        self.layers = nn.ModuleList()
        for _ in range(depth):
            self.layers.append(
                DETR3DDecoderLayer(
                    config["dim"],
                    config["num_heads"],
                    config["mlp_ratio"],
                    config["layer_norm_eps"],
                    config["attn_drop_prob"],
                    config["proj_drop_prob"],
                    config["mlp_drop_prob"],
                )
            )

    def forward(self, object_queries: torch.Tensor, embeddings: torch.Tensor):
        """Feed the queries through every layer, cross-attending to ``embeddings`` each time.

        Args:
            object_queries: (b, num_possible_objects, dim)
            embeddings: (b, num_embed_tokens, dim)

        Returns:
            A tuple ``(final, per_layer)``. ``final`` is the last layer's output,
            shaped (b, num_possible_objects, dim). ``per_layer`` lists every layer's
            output in order; it is empty if the decoder has zero layers, in which
            case ``final`` is ``object_queries`` itself.
        """
        per_layer_outputs = []
        hidden = object_queries

        for decoder_layer in self.layers:
            hidden = decoder_layer(hidden, embeddings)
            per_layer_outputs.append(hidden)

        return hidden, per_layer_outputs

In [None]:
# Only dim, num_heads, mlp_ratio, layer_norm_eps, the *_drop_prob keys and
# decoder_depth are read by DETR3DDecoder; the remaining keys are carried along unused.
test_config = dict(
    attn_drop_prob=0.2,
    dim=54,
    drop_prob=0.2,
    embed_spacing_info=False,
    decoder_depth=4,
    in_channels=1,
    mlp_ratio=2,
    layer_norm_eps=1e-6,
    learnable_absolute_position_embeddings=False,
    mlp_drop_prob=0.2,
    num_heads=6,
    patch_size=(8, 16, 16),
    proj_drop_prob=0.2,
)

test = DETR3DDecoder(test_config)
display(test)
demo_queries, demo_memory = torch.randn(2, 10, 54), torch.randn(2, 64, 54)
o = test(demo_queries, demo_memory)
final_embeddings, per_layer = o
display((final_embeddings.shape, [layer_out.shape for layer_out in per_layer]))


[1;35mDETR3DDecoder[0m[1m([0m
  [1m([0mlayers[1m)[0m: [1;35mModuleList[0m[1m([0m
    [1m([0m[1;36m0[0m-[1;36m3[0m[1m)[0m: [1;36m4[0m x [1;35mDETR3DDecoderLayer[0m[1m([0m
      [1m([0mmhsa[1m)[0m: [1;35mAttention1D[0m[1m([0m
        [1m([0mW_q[1m)[0m: [1;35mLinear[0m[1m([0m[33min_features[0m=[1;36m54[0m, [33mout_features[0m=[1;36m54[0m, [33mbias[0m=[3;92mTrue[0m[1m)[0m
        [1m([0mW_k[1m)[0m: [1;35mLinear[0m[1m([0m[33min_features[0m=[1;36m54[0m, [33mout_features[0m=[1;36m54[0m, [33mbias[0m=[3;92mTrue[0m[1m)[0m
        [1m([0mW_v[1m)[0m: [1;35mLinear[0m[1m([0m[33min_features[0m=[1;36m54[0m, [33mout_features[0m=[1;36m54[0m, [33mbias[0m=[3;92mTrue[0m[1m)[0m
        [1m([0mproj[1m)[0m: [1;35mLinear[0m[1m([0m[33min_features[0m=[1;36m54[0m, [33mout_features[0m=[1;36m54[0m, [33mbias[0m=[3;92mTrue[0m[1m)[0m
        [1m([0mproj_drop[1m)[0m: [1;35mDropout[0m[1m([0m


[1m([0m
    [1;35mtorch.Size[0m[1m([0m[1m[[0m[1;36m2[0m, [1;36m10[0m, [1;36m54[0m[1m][0m[1m)[0m,
    [1m[[0m
        [1;35mtorch.Size[0m[1m([0m[1m[[0m[1;36m2[0m, [1;36m10[0m, [1;36m54[0m[1m][0m[1m)[0m,
        [1;35mtorch.Size[0m[1m([0m[1m[[0m[1;36m2[0m, [1;36m10[0m, [1;36m54[0m[1m][0m[1m)[0m,
        [1;35mtorch.Size[0m[1m([0m[1m[[0m[1;36m2[0m, [1;36m10[0m, [1;36m54[0m[1m][0m[1m)[0m,
        [1;35mtorch.Size[0m[1m([0m[1m[[0m[1;36m2[0m, [1;36m10[0m, [1;36m54[0m[1m][0m[1m)[0m
    [1m][0m
[1m)[0m

In [None]:
# | export


class DETR3DBBoxMLP(nn.Module):
    """Linear prediction head that maps each object embedding to box outputs.

    Output channel layout (size ``1 + 4 + num_classes``):

    * channels ``[0, 5)``: passed through a sigmoid, so they lie in (0, 1).
      Judging by the ``1 + 4`` sizing, this is presumably one confidence score
      plus four box parameters. TODO confirm; a full 3D box usually needs six.
    * channels ``[5, 5 + num_classes)``: a softmax distribution over classes.

    Config keys read: ``dim`` and ``num_classes``.
    """

    def __init__(self, config):
        super().__init__()

        dim = config["dim"]
        num_classes = config["num_classes"]

        self.linear = nn.Linear(dim, 1 + 4 + num_classes)

    def forward(self, object_embeddings: torch.Tensor):
        """Predict per-object box values and class probabilities.

        Args:
            object_embeddings: (..., dim). Typically (b, num_possible_objects, dim).

        Returns:
            Tensor of shape (..., 1 + 4 + num_classes), with the channel layout
            described in the class docstring.
        """
        logits = self.linear(object_embeddings)
        # (..., 1 + 4 + num_classes)

        # Build the result out-of-place rather than writing back into slices of
        # `logits`. In-place writes into an autograd-tracked activation only work as
        # long as no backward node saves that tensor, which is fragile.
        # `...` indexing also accepts inputs without a batch axis.
        box_logits = logits[..., :5]
        class_logits = logits[..., 5:]

        return torch.cat([box_logits.sigmoid(), class_logits.softmax(-1)], dim=-1)

In [None]:
test_config = dict(dim=54, num_classes=10)

test = DETR3DBBoxMLP(test_config)
display(test)
o = test(torch.randn(2, 10, 54))

# First 5 channels should lie in (0, 1); the class channels should sum to 1 per object.
box_values = o[:, :, :5]
class_probs = o[:, :, 5:]
display(o[0].shape, (box_values.min(), box_values.max()), class_probs.sum(dim=-1))


[1;35mDETR3DBBoxMLP[0m[1m([0m
  [1m([0mlinear[1m)[0m: [1;35mLinear[0m[1m([0m[33min_features[0m=[1;36m54[0m, [33mout_features[0m=[1;36m15[0m, [33mbias[0m=[3;92mTrue[0m[1m)[0m
[1m)[0m

[1;35mtorch.Size[0m[1m([0m[1m[[0m[1;36m10[0m, [1;36m15[0m[1m][0m[1m)[0m

[1m([0m[1;35mtensor[0m[1m([0m[1;36m0.2030[0m, [33mgrad_fn[0m=[1m<[0m[1;95mMinBackward1[0m[39m>[0m[1;39m)[0m[39m, [0m[1;35mtensor[0m[1;39m([0m[1;36m0.7401[0m[39m, [0m[33mgrad_fn[0m[39m=<MaxBackward1[0m[1m>[0m[1m)[0m[1m)[0m


[1;35mtensor[0m[1m([0m[1m[[0m[1m[[0m[1;36m1.0000[0m, [1;36m1.0000[0m, [1;36m1.0000[0m, [1;36m1.0000[0m, [1;36m1.0000[0m, [1;36m1.0000[0m, [1;36m1.0000[0m, [1;36m1.0000[0m, [1;36m1.0000[0m,
         [1;36m1.0000[0m[1m][0m,
        [1m[[0m[1;36m1.0000[0m, [1;36m1.0000[0m, [1;36m1.0000[0m, [1;36m1.0000[0m, [1;36m1.0000[0m, [1;36m1.0000[0m, [1;36m1.0000[0m, [1;36m1.0000[0m, [1;36m1.0000[0m,
         [1;36m1.0000[0m[1m][0m[1m][0m, [33mgrad_fn[0m=[1m<[0m[1;95mSumBackward1[0m[1m>[0m[1m)[0m

# Embeddings

### Position embeddings

In [None]:
# | export


def get_coords_grid(grid_size):
    """Return the integer coordinates of every cell of a 3D grid.

    Args:
        grid_size: ``(d, h, w)``, the number of cells along depth, height and width.

    Returns:
        int32 tensor of shape ``(3, d, h, w)``. For each cell, ``grid[0]`` holds
        its depth index, ``grid[1]`` its height index and ``grid[2]`` its width
        index.
    """
    d, h, w = grid_size

    grid_d = torch.arange(d, dtype=torch.int32)
    grid_h = torch.arange(h, dtype=torch.int32)
    grid_w = torch.arange(w, dtype=torch.int32)

    # With indexing="ij" the output shape follows the argument order. Passing
    # (d, h, w) yields (d, h, w)-shaped coordinate tensors with the depth
    # coordinates first, matching how callers flatten and interpret the grid.
    grid = torch.meshgrid(grid_d, grid_h, grid_w, indexing="ij")
    grid = torch.stack(grid, dim=0)
    # (3, d, h, w)

    return grid

In [None]:
# | export


def get_3d_position_embeddings(embedding_size, grid_size, patch_size=(1, 1, 1)):
    """Build fixed 3D sine-cosine position embeddings for a ``(d, h, w)`` token grid.

    The embedding channels are split into three equal parts, one per grid axis,
    taken in the channel order of ``get_coords_grid``. Each part holds
    ``embedding_size // 6`` sines followed by ``embedding_size // 6`` cosines of
    that axis' coordinate, at frequencies ``1 / 10000**(k / (embedding_size / 6))``.
    Each part is then scaled by ``patch_size[i] / min(patch_size)``.

    Args:
        embedding_size: total embedding dimension; must be divisible by 6.
        grid_size: ``(d, h, w)``, the number of tokens along each axis.
        patch_size: per-axis patch size; used only for the per-axis scale factor.

    Returns:
        float32 tensor of shape ``(1, embedding_size, d, h, w)``.

    Raises:
        ValueError: if ``embedding_size`` is not divisible by 6.
    """
    if embedding_size % 6 != 0:
        raise ValueError(f"embedding_size must be divisible by 6, got {embedding_size}")

    d, h, w = grid_size

    grid = get_coords_grid(grid_size)
    # (3, d, h, w)

    omega = torch.arange(embedding_size // 6, dtype=torch.float32)
    omega /= embedding_size / 6.0
    omega = 1.0 / 10000**omega
    # (embedding_size // 6,)

    patch_multiplier = torch.Tensor(patch_size) / min(patch_size)
    # (3,)

    position_embeddings = []
    for i, axis_coords in enumerate(grid):
        # The coordinate grid is int32 and the frequencies are float32; cast explicitly.
        axis_coords = axis_coords.reshape(-1).to(torch.float32)
        # (d * h * w,)
        out = torch.einsum("m,d->md", axis_coords, omega)
        # (d * h * w, embedding_size // 6)

        emb = torch.cat([torch.sin(out), torch.cos(out)], dim=1) * patch_multiplier[i]
        # (d * h * w, embedding_size // 3)
        position_embeddings.append(emb)

    position_embeddings = torch.cat(position_embeddings, dim=1)
    # (d * h * w, embedding_size)
    position_embeddings = rearrange(position_embeddings, "(d h w) e -> 1 e d h w", d=d, h=h, w=w)
    # (1, embedding_size, d, h, w)

    return position_embeddings

In [None]:
# | export


def embed_spacings_in_position_embeddings(embeddings: torch.Tensor, spacings: torch.Tensor):
    """Scale each third of the embedding channels by one per-sample spacing value.

    Channels ``[0, E/3)`` are multiplied by ``spacings[:, 0]``, channels
    ``[E/3, 2E/3)`` by ``spacings[:, 1]`` and channels ``[2E/3, E)`` by
    ``spacings[:, 2]``. The spacings are presumably physical voxel spacings per
    axis; confirm with the caller.

    Args:
        embeddings: 5D tensor whose dim 1 (size E) is the channel axis; E must be
            divisible by 3.
        spacings: 2D tensor ``(B, 3)``, one spacing triple per batch element.

    Returns:
        ``embeddings`` broadcast-multiplied by a ``(B, E, 1, 1, 1)`` scale tensor.
    """
    assert spacings.ndim == 2, "Please provide spacing information for each batch element"
    _, embedding_size, _, _, _ = embeddings.shape
    assert embedding_size % 3 == 0, "To embed spacing info, the embedding size must be divisible by 3"

    channels_per_axis = int(embedding_size / 3)
    # Repeat each of the 3 spacing values `channels_per_axis` times along channels.
    per_channel_scale = repeat(spacings, f"B S -> B (S {channels_per_axis}) 1 1 1", S=3)
    # (B, E, 1, 1, 1)

    return embeddings * per_channel_scale

In [None]:
# | export


class DETR3DPositionEmbeddings(nn.Module):
    """Add absolute 3D position embeddings to a feature volume and flatten it into tokens.

    Config keys read:
        ``dim``: channel dimension of the feature volume.
        ``image_size``, ``patch_size``: the token grid is ``image_size[i] // patch_size[i]``
            per axis.
        ``learnable_absolute_position_embeddings``: if true, the embeddings are a
            trainable ``nn.Parameter`` initialised with ``torch.randn``. Otherwise they
            are a fixed sine-cosine buffer.
        ``embed_spacing_info``: if true, the embeddings are rescaled per sample
            by ``spacings`` in ``forward``.
    """

    def __init__(self, config):
        super().__init__()

        self.config = config

        dim = config["dim"]
        grid_size = (
            config["image_size"][0] // config["patch_size"][0],
            config["image_size"][1] // config["patch_size"][1],
            config["image_size"][2] // config["patch_size"][2],
        )
        if config["learnable_absolute_position_embeddings"]:
            # Must be registered as a parameter, not a buffer: buffers are excluded
            # from .parameters(), so an optimizer would never update them.
            self.absolute_position_embeddings = nn.Parameter(
                torch.randn(1, dim, grid_size[0], grid_size[1], grid_size[2]),
                requires_grad=True,
            )
        else:
            self.register_buffer(
                "absolute_position_embeddings",
                get_3d_position_embeddings(dim, grid_size, config["patch_size"]),
            )

    def forward(
        self,
        embeddings: torch.Tensor,
        spacings: torch.Tensor = None,
    ):
        """Add the position embeddings and flatten the spatial axes into a token axis.

        Args:
            embeddings: (b, dim, num_tokens_z, num_tokens_y, num_tokens_x)
            spacings: (b, 3). Only read, and then required, when
                ``config["embed_spacing_info"]`` is true.

        Returns:
            Tensor of shape (b, num_tokens_z * num_tokens_y * num_tokens_x, dim).

        Raises:
            ValueError: if ``embeddings`` is not 5D, or if spacing info is enabled
                but ``spacings`` is None.
        """
        if embeddings.ndim != 5:
            raise ValueError(f"Expected 5D embeddings (b, dim, z, y, x), got shape {tuple(embeddings.shape)}")

        absolute_position_embeddings = self.absolute_position_embeddings
        # (1, dim, num_tokens_z, num_tokens_y, num_tokens_x)
        if self.config["embed_spacing_info"]:
            if spacings is None:
                raise ValueError("spacings must be provided when config['embed_spacing_info'] is enabled")
            absolute_position_embeddings = embed_spacings_in_position_embeddings(absolute_position_embeddings, spacings)
            # (b, dim, num_tokens_z, num_tokens_y, num_tokens_x)

        embeddings = embeddings + absolute_position_embeddings
        # (b, dim, num_tokens_z, num_tokens_y, num_tokens_x)

        # Same result as rearrange "b d nz ny nx -> b (nz ny nx) d".
        embeddings = embeddings.flatten(2).transpose(1, 2)
        # (b, num_tokens_z * num_tokens_y * num_tokens_x, dim)

        return embeddings

In [None]:
test_config = {
    "patch_size": (8, 16, 16),
    "in_channels": 1,
    "dim": 768,
    "learnable_absolute_position_embeddings": True,
    "embed_spacing_info": False,
    "image_size": (32, 512, 512),
}

test = DETR3DPositionEmbeddings(test_config)
display(test)
# The feature volume must carry `dim` channels. A singleton channel axis would
# only "work" by silently broadcasting against the (1, dim, ...) position embeddings.
# (4, 32, 32) is the token grid image_size // patch_size.
o = test(
    torch.randn(2, test_config["dim"], 4, 32, 32),
    torch.randn(2, 3),
)
display(o.shape)
assert o.shape == (2, 4 * 32 * 32, test_config["dim"])

[1;35mDETR3DPositionEmbeddings[0m[1m([0m[1m)[0m

[1;35mtorch.Size[0m[1m([0m[1m[[0m[1;36m2[0m, [1;36m4096[0m, [1;36m768[0m[1m][0m[1m)[0m

# Models

In [None]:
# | export


class DETR3DModel(nn.Module, PyTorchModelHubMixin):
    """DETR-style 3D detector head operating on a pre-computed feature volume.

    Pipeline:
      1. add position embeddings and flatten to tokens (``DETR3DPositionEmbeddings``);
      2. apply dropout;
      3. decode ``num_possible_objects`` learned object queries against those tokens
         (``DETR3DDecoder``);
      4. predict per-query outputs (``DETR3DBBoxMLP``).
    """

    def __init__(self, config):
        super().__init__()

        self.embeddings = DETR3DPositionEmbeddings(config)
        self.pos_drop = nn.Dropout(config.get("drop_prob", 0.0))
        self.num_possible_objects = config["num_possible_objects"]
        # The object queries are learned, so they must be a registered parameter.
        # Storing them via register_buffer hid them from .parameters(), so the
        # optimizer never updated them and they stayed at their random init.
        self.object_queries = nn.Parameter(torch.randn(1, self.num_possible_objects, config["dim"]))
        self.decoder = DETR3DDecoder(config)
        self.bbox_mlp = DETR3DBBoxMLP(config)

    def forward(
        self,
        embeddings: torch.Tensor,
        spacings: torch.Tensor = None,
    ):
        """Run the detector head.

        Args:
            embeddings: (b, dim, num_tokens_z, num_tokens_y, num_tokens_x)
            spacings: (b, 3). Only needed when the config enables ``embed_spacing_info``.

        Returns:
            A tuple of three items:
            ``bboxes``: (b, num_possible_objects, 1 + 4 + num_classes).
            ``object_embeddings``: (b, num_possible_objects, dim), the final decoder output.
            ``layer_outputs``: list of per-layer decoder outputs, each (b, num_possible_objects, dim).
        """
        embeddings = self.embeddings(embeddings, spacings)
        embeddings = self.pos_drop(embeddings)
        # (b, num_embed_tokens, dim)

        object_queries = repeat(self.object_queries, "1 n d -> b n d", b=embeddings.shape[0])
        # (b, num_possible_objects, dim)

        object_embeddings, layer_outputs = self.decoder(object_queries, embeddings)
        # object_embeddings: (b, num_possible_objects, dim)
        # layer_outputs: list of (b, num_possible_objects, dim)

        bboxes = self.bbox_mlp(object_embeddings)
        # (b, num_possible_objects, 1 + 4 + num_classes)

        return bboxes, object_embeddings, layer_outputs

In [None]:
test_config = {
    "patch_size": (8, 16, 16),
    "in_channels": 1,
    "dim": 54,
    "num_heads": 6,
    "mlp_ratio": 2,
    "layer_norm_eps": 1e-6,
    "attn_drop_prob": 0.2,
    "proj_drop_prob": 0.2,
    "mlp_drop_prob": 0.2,
    "learnable_absolute_position_embeddings": True,
    "embed_spacing_info": False,
    "image_size": (32, 512, 512),
    "num_possible_objects": 10,
    "num_classes": 3,
    "decoder_depth": 4,
}

test = DETR3DModel(test_config)
display(test)
batch_size = 2
# The feature volume must have `dim` channels; a singleton channel axis would only
# pass through silent broadcasting. (4, 32, 32) is the token grid image_size // patch_size.
o = test(
    torch.randn(batch_size, test_config["dim"], 4, 32, 32),
    torch.randn(batch_size, 3),
)
bboxes, object_embeddings, layer_outputs = o
display((bboxes.shape, object_embeddings.shape, [x.shape for x in layer_outputs]))

assert bboxes.shape == (batch_size, test_config["num_possible_objects"], 1 + 4 + test_config["num_classes"])
assert object_embeddings.shape == (batch_size, test_config["num_possible_objects"], test_config["dim"])
assert len(layer_outputs) == test_config["decoder_depth"]


[1;35mDETR3DModel[0m[1m([0m
  [1m([0membeddings[1m)[0m: [1;35mDETR3DPositionEmbeddings[0m[1m([0m[1m)[0m
  [1m([0mpos_drop[1m)[0m: [1;35mDropout[0m[1m([0m[33mp[0m=[1;36m0[0m[1;36m.0[0m, [33minplace[0m=[3;91mFalse[0m[1m)[0m
  [1m([0mdecoder[1m)[0m: [1;35mDETR3DDecoder[0m[1m([0m
    [1m([0mlayers[1m)[0m: [1;35mModuleList[0m[1m([0m
      [1m([0m[1;36m0[0m-[1;36m3[0m[1m)[0m: [1;36m4[0m x [1;35mDETR3DDecoderLayer[0m[1m([0m
        [1m([0mmhsa[1m)[0m: [1;35mAttention1D[0m[1m([0m
          [1m([0mW_q[1m)[0m: [1;35mLinear[0m[1m([0m[33min_features[0m=[1;36m54[0m, [33mout_features[0m=[1;36m54[0m, [33mbias[0m=[3;92mTrue[0m[1m)[0m
          [1m([0mW_k[1m)[0m: [1;35mLinear[0m[1m([0m[33min_features[0m=[1;36m54[0m, [33mout_features[0m=[1;36m54[0m, [33mbias[0m=[3;92mTrue[0m[1m)[0m
          [1m([0mW_v[1m)[0m: [1;35mLinear[0m[1m([0m[33min_features[0m=[1;36m54[0m, [33mout_feat


[1m([0m
    [1;35mtorch.Size[0m[1m([0m[1m[[0m[1;36m2[0m, [1;36m10[0m, [1;36m8[0m[1m][0m[1m)[0m,
    [1;35mtorch.Size[0m[1m([0m[1m[[0m[1;36m2[0m, [1;36m10[0m, [1;36m54[0m[1m][0m[1m)[0m,
    [1m[[0m
        [1;35mtorch.Size[0m[1m([0m[1m[[0m[1;36m2[0m, [1;36m10[0m, [1;36m54[0m[1m][0m[1m)[0m,
        [1;35mtorch.Size[0m[1m([0m[1m[[0m[1;36m2[0m, [1;36m10[0m, [1;36m54[0m[1m][0m[1m)[0m,
        [1;35mtorch.Size[0m[1m([0m[1m[[0m[1;36m2[0m, [1;36m10[0m, [1;36m54[0m[1m][0m[1m)[0m,
        [1;35mtorch.Size[0m[1m([0m[1m[[0m[1;36m2[0m, [1;36m10[0m, [1;36m54[0m[1m][0m[1m)[0m
    [1m][0m
[1m)[0m

# nbdev

In [None]:
!nbdev_export