In [1]:
# | default_exp nets/cait_3d

# Imports

In [2]:
# | export

import torch
from einops import repeat
from huggingface_hub import PyTorchModelHubMixin
from torch import nn

from vision_architectures.vit_3d import ViT3DLayerMLP, ViT3DMHCA, ViT3DMHSA

# Architecture

### Attention Layers

In [3]:
# | export


class CaiT3DMHSA(ViT3DMHSA):
    """Multi-head self attention for CaiT3D.

    Adds nothing on top of ``ViT3DMHSA``; the subclass only provides a CaiT-specific name.
    """

In [4]:
# | export


class CaiT3DMHCA(ViT3DMHCA):
    """Multi-head class attention for CaiT3D.

    The class tokens are passed to ``ViT3DMHCA.forward`` as its first argument, and the class tokens
    concatenated with the embeddings (along dim 1) as its second argument.
    """

    def forward(self, class_tokens: torch.Tensor, embeddings: torch.Tensor):
        # NOTE(review): assumes ViT3DMHCA.forward takes (queries, key/value source) in that order -- confirm
        # against the parent class.
        tokens_with_class = torch.cat((class_tokens, embeddings), dim=1)
        return super().forward(class_tokens, tokens_with_class)

In [5]:
# Smoke test: a single class token attending over 64 embedding tokens.
test = CaiT3DMHCA(54, 6)
display(test)
o = test(torch.randn(2, 1, 54), torch.randn(2, 64, 54))
display(o.shape)

# Fail loudly if the output ever stops matching the shape of the class tokens.
assert o.shape == (2, 1, 54)


[1;35mCaiT3DMHCA[0m[1m([0m
  [1m([0mW_q[1m)[0m: [1;35mLinear[0m[1m([0m[33min_features[0m=[1;36m54[0m, [33mout_features[0m=[1;36m54[0m, [33mbias[0m=[3;92mTrue[0m[1m)[0m
  [1m([0mW_k[1m)[0m: [1;35mLinear[0m[1m([0m[33min_features[0m=[1;36m54[0m, [33mout_features[0m=[1;36m54[0m, [33mbias[0m=[3;92mTrue[0m[1m)[0m
  [1m([0mW_v[1m)[0m: [1;35mLinear[0m[1m([0m[33min_features[0m=[1;36m54[0m, [33mout_features[0m=[1;36m54[0m, [33mbias[0m=[3;92mTrue[0m[1m)[0m
  [1m([0mattn_drop[1m)[0m: [1;35mDropout[0m[1m([0m[33mp[0m=[1;36m0[0m[1;36m.0[0m, [33minplace[0m=[3;91mFalse[0m[1m)[0m
  [1m([0mproj[1m)[0m: [1;35mLinear[0m[1m([0m[33min_features[0m=[1;36m54[0m, [33mout_features[0m=[1;36m54[0m, [33mbias[0m=[3;92mTrue[0m[1m)[0m
  [1m([0mproj_drop[1m)[0m: [1;35mDropout[0m[1m([0m[33mp[0m=[1;36m0[0m[1;36m.0[0m, [33minplace[0m=[3;91mFalse[0m[1m)[0m
[1m)[0m

[1;35mtorch.Size[0m[1m([0m[1m[[0m[1;36m2[0m, [1;36m1[0m, [1;36m54[0m[1m][0m[1m)[0m

In [6]:
# | export

class CaiT3DLayerMLP(ViT3DLayerMLP):
    """MLP block used inside the CaiT3D layers.

    Adds nothing on top of ``ViT3DLayerMLP``; the subclass only provides a CaiT-specific name.
    """

### Basic Layers

In [7]:
# | export


class CaiT3DStage1Layer(nn.Module):
    """Self-attention layer that operates on the embeddings only (no class tokens).

    Two residual branches run one after the other: self attention, then the MLP. The output of each branch is
    multiplied by a learnable per-channel scale (``gamma1`` / ``gamma2``), layer-normalised, and added back to
    the input of that branch.
    """

    def __init__(
        self,
        dim,
        num_heads,
        intermediate_ratio,
        layer_norm_eps,
        attn_drop_prob=0.0,
        proj_drop_prob=0.0,
        mlp_drop_prob=0.0,
    ):
        super().__init__()

        # Attention branch
        self.mhsa = CaiT3DMHSA(dim, num_heads, attn_drop_prob, proj_drop_prob)
        self.gamma1 = nn.Parameter(torch.empty(1, 1, dim))
        self.layernorm1 = nn.LayerNorm(dim, eps=layer_norm_eps)

        # MLP branch
        self.mlp = CaiT3DLayerMLP(dim, intermediate_ratio, mlp_drop_prob)
        self.gamma2 = nn.Parameter(torch.empty(1, 1, dim))
        self.layernorm2 = nn.LayerNorm(dim, eps=layer_norm_eps)

        # Both scales start as small values drawn uniformly from [-1e-4, 1e-4]
        for gamma in (self.gamma1, self.gamma2):
            nn.init.uniform_(gamma, a=-1e-4, b=1e-4)

    def forward(self, embeddings: torch.Tensor):  # This uses post-normalization
        # embeddings: (b, num_tokens, dim)

        # NOTE(review): the gamma scale is applied *before* LayerNorm, so the normalisation may undo much of
        # the scaling -- confirm this ordering is intended.
        attn_branch = self.layernorm1(self.gamma1 * self.mhsa(embeddings))
        # (b, num_tokens, dim)

        after_attention = attn_branch + embeddings
        # (b, num_tokens, dim)

        mlp_branch = self.layernorm2(self.gamma2 * self.mlp(after_attention))
        # (b, num_tokens, dim)

        return mlp_branch + after_attention

In [8]:
# Smoke test: the layer must map (b, num_tokens, dim) to the same shape.
test = CaiT3DStage1Layer(52, 4, 2, 1e-6)
display(test)
o = test(torch.randn(2, 64, 52))
display(o.shape)

# Fail loudly if the layer ever changes the shape of its input.
assert o.shape == (2, 64, 52)


[1;35mCaiT3DStage1Layer[0m[1m([0m
  [1m([0mmhsa[1m)[0m: [1;35mCaiT3DMHSA[0m[1m([0m
    [1m([0mW_q[1m)[0m: [1;35mLinear[0m[1m([0m[33min_features[0m=[1;36m52[0m, [33mout_features[0m=[1;36m52[0m, [33mbias[0m=[3;92mTrue[0m[1m)[0m
    [1m([0mW_k[1m)[0m: [1;35mLinear[0m[1m([0m[33min_features[0m=[1;36m52[0m, [33mout_features[0m=[1;36m52[0m, [33mbias[0m=[3;92mTrue[0m[1m)[0m
    [1m([0mW_v[1m)[0m: [1;35mLinear[0m[1m([0m[33min_features[0m=[1;36m52[0m, [33mout_features[0m=[1;36m52[0m, [33mbias[0m=[3;92mTrue[0m[1m)[0m
    [1m([0mattn_drop[1m)[0m: [1;35mDropout[0m[1m([0m[33mp[0m=[1;36m0[0m[1;36m.0[0m, [33minplace[0m=[3;91mFalse[0m[1m)[0m
    [1m([0mproj[1m)[0m: [1;35mLinear[0m[1m([0m[33min_features[0m=[1;36m52[0m, [33mout_features[0m=[1;36m52[0m, [33mbias[0m=[3;92mTrue[0m[1m)[0m
    [1m([0mproj_drop[1m)[0m: [1;35mDropout[0m[1m([0m[33mp[0m=[1;36m0[0m[1;36m.0[0m, [33

[1;35mtorch.Size[0m[1m([0m[1m[[0m[1;36m2[0m, [1;36m64[0m, [1;36m52[0m[1m][0m[1m)[0m

In [9]:
# | export


class CaiT3DStage2Layer(nn.Module):
    """Class-attention layer: updates the class tokens while the embeddings are only read.

    Two residual branches run one after the other: class attention, then the MLP. The output of each branch is
    multiplied by a learnable per-channel scale (``gamma1`` / ``gamma2``), layer-normalised, and added back to
    the input of that branch.
    """

    def __init__(
        self,
        dim,
        num_heads,
        intermediate_ratio,
        layer_norm_eps,
        attn_drop_prob=0.0,
        proj_drop_prob=0.0,
        mlp_drop_prob=0.0,
    ):
        super().__init__()

        # Class-attention branch
        self.mhca = CaiT3DMHCA(dim, num_heads, attn_drop_prob, proj_drop_prob)
        self.gamma1 = nn.Parameter(torch.empty(1, 1, dim))
        self.layernorm1 = nn.LayerNorm(dim, eps=layer_norm_eps)

        # MLP branch
        self.mlp = CaiT3DLayerMLP(dim, intermediate_ratio, mlp_drop_prob)
        self.gamma2 = nn.Parameter(torch.empty(1, 1, dim))
        self.layernorm2 = nn.LayerNorm(dim, eps=layer_norm_eps)

        # Both scales start as small values drawn uniformly from [-1e-4, 1e-4]
        for gamma in (self.gamma1, self.gamma2):
            nn.init.uniform_(gamma, a=-1e-4, b=1e-4)

    def forward(self, class_tokens: torch.Tensor, embeddings: torch.Tensor):  # This uses post-normalization
        # class_tokens: (b, num_class_tokens, dim)
        # embeddings: (b, num_embedding_tokens, dim)

        # NOTE(review): the gamma scale is applied *before* LayerNorm, so the normalisation may undo much of
        # the scaling -- confirm this ordering is intended.
        attn_branch = self.layernorm1(self.gamma1 * self.mhca(class_tokens, embeddings))
        # (b, num_class_tokens, dim)

        after_attention = attn_branch + class_tokens
        # (b, num_class_tokens, dim)

        mlp_branch = self.layernorm2(self.gamma2 * self.mlp(after_attention))
        # (b, num_class_tokens, dim)

        return mlp_branch + after_attention

In [10]:
# Smoke test: only the class tokens are returned, with their shape unchanged.
test = CaiT3DStage2Layer(52, 4, 2, 1e-6)
display(test)
o = test(torch.randn(2, 1, 52), torch.randn(2, 64, 52))
display(o.shape)

# Fail loudly if the output ever stops matching the shape of the class tokens.
assert o.shape == (2, 1, 52)


[1;35mCaiT3DStage2Layer[0m[1m([0m
  [1m([0mmhca[1m)[0m: [1;35mCaiT3DMHCA[0m[1m([0m
    [1m([0mW_q[1m)[0m: [1;35mLinear[0m[1m([0m[33min_features[0m=[1;36m52[0m, [33mout_features[0m=[1;36m52[0m, [33mbias[0m=[3;92mTrue[0m[1m)[0m
    [1m([0mW_k[1m)[0m: [1;35mLinear[0m[1m([0m[33min_features[0m=[1;36m52[0m, [33mout_features[0m=[1;36m52[0m, [33mbias[0m=[3;92mTrue[0m[1m)[0m
    [1m([0mW_v[1m)[0m: [1;35mLinear[0m[1m([0m[33min_features[0m=[1;36m52[0m, [33mout_features[0m=[1;36m52[0m, [33mbias[0m=[3;92mTrue[0m[1m)[0m
    [1m([0mattn_drop[1m)[0m: [1;35mDropout[0m[1m([0m[33mp[0m=[1;36m0[0m[1;36m.0[0m, [33minplace[0m=[3;91mFalse[0m[1m)[0m
    [1m([0mproj[1m)[0m: [1;35mLinear[0m[1m([0m[33min_features[0m=[1;36m52[0m, [33mout_features[0m=[1;36m52[0m, [33mbias[0m=[3;92mTrue[0m[1m)[0m
    [1m([0mproj_drop[1m)[0m: [1;35mDropout[0m[1m([0m[33mp[0m=[1;36m0[0m[1;36m.0[0m, [33

[1;35mtorch.Size[0m[1m([0m[1m[[0m[1;36m2[0m, [1;36m1[0m, [1;36m52[0m[1m][0m[1m)[0m

### Stages

In [11]:
# | export


class CaiT3DStage1(nn.Module, PyTorchModelHubMixin):
    """Stack of ``CaiT3DStage1Layer`` blocks applied one after the other to the embeddings.

    ``config`` is indexed with the keys ``dim``, ``num_heads``, ``intermediate_ratio``, ``layer_norm_eps``,
    ``attn_drop_prob``, ``proj_drop_prob``, ``mlp_drop_prob`` and ``encoder_depth`` (the number of layers).
    """

    def __init__(self, config):
        super().__init__()

        def build_layer():
            # Every layer is built from the same hyperparameters
            return CaiT3DStage1Layer(
                config["dim"],
                config["num_heads"],
                config["intermediate_ratio"],
                config["layer_norm_eps"],
                config["attn_drop_prob"],
                config["proj_drop_prob"],
                config["mlp_drop_prob"],
            )

        depth = config["encoder_depth"]
        self.layers = nn.ModuleList([build_layer() for _ in range(depth)])

    def forward(self, embeddings: torch.Tensor):
        # embeddings: (b, num_tokens, dim)

        hidden = embeddings
        per_layer = []
        for block in self.layers:
            hidden = block(hidden)
            # (b, num_tokens, dim)
            per_layer.append(hidden)

        # Output of the last layer, plus the output of every layer in order
        return hidden, per_layer

In [12]:
# Smoke test: five self-attention layers over 64 tokens of width 54.
test_config = {
    "dim": 54,
    "num_heads": 6,
    "intermediate_ratio": 2,
    "layer_norm_eps": 1e-6,
    "attn_drop_prob": 0.0,
    "proj_drop_prob": 0.0,
    "mlp_drop_prob": 0.0,
    "encoder_depth": 5,
}

test = CaiT3DStage1(test_config)
display(test)
o = test(torch.randn(2, 64, 54))
display((o[0].shape, [x.shape for x in o[1]]))

# Fail loudly if the stage stops returning one same-shaped output per layer.
assert o[0].shape == (2, 64, 54)
assert len(o[1]) == test_config["encoder_depth"]
assert all(x.shape == (2, 64, 54) for x in o[1])
# The final output is the last per-layer output.
assert o[1][-1] is o[0]


[1;35mCaiT3DStage1[0m[1m([0m
  [1m([0mlayers[1m)[0m: [1;35mModuleList[0m[1m([0m
    [1m([0m[1;36m0[0m-[1;36m4[0m[1m)[0m: [1;36m5[0m x [1;35mCaiT3DStage1Layer[0m[1m([0m
      [1m([0mmhsa[1m)[0m: [1;35mCaiT3DMHSA[0m[1m([0m
        [1m([0mW_q[1m)[0m: [1;35mLinear[0m[1m([0m[33min_features[0m=[1;36m54[0m, [33mout_features[0m=[1;36m54[0m, [33mbias[0m=[3;92mTrue[0m[1m)[0m
        [1m([0mW_k[1m)[0m: [1;35mLinear[0m[1m([0m[33min_features[0m=[1;36m54[0m, [33mout_features[0m=[1;36m54[0m, [33mbias[0m=[3;92mTrue[0m[1m)[0m
        [1m([0mW_v[1m)[0m: [1;35mLinear[0m[1m([0m[33min_features[0m=[1;36m54[0m, [33mout_features[0m=[1;36m54[0m, [33mbias[0m=[3;92mTrue[0m[1m)[0m
        [1m([0mattn_drop[1m)[0m: [1;35mDropout[0m[1m([0m[33mp[0m=[1;36m0[0m[1;36m.0[0m, [33minplace[0m=[3;91mFalse[0m[1m)[0m
        [1m([0mproj[1m)[0m: [1;35mLinear[0m[1m([0m[33min_features[0m=[1;36m54[0m


[1m([0m
    [1;35mtorch.Size[0m[1m([0m[1m[[0m[1;36m2[0m, [1;36m64[0m, [1;36m54[0m[1m][0m[1m)[0m,
    [1m[[0m
        [1;35mtorch.Size[0m[1m([0m[1m[[0m[1;36m2[0m, [1;36m64[0m, [1;36m54[0m[1m][0m[1m)[0m,
        [1;35mtorch.Size[0m[1m([0m[1m[[0m[1;36m2[0m, [1;36m64[0m, [1;36m54[0m[1m][0m[1m)[0m,
        [1;35mtorch.Size[0m[1m([0m[1m[[0m[1;36m2[0m, [1;36m64[0m, [1;36m54[0m[1m][0m[1m)[0m,
        [1;35mtorch.Size[0m[1m([0m[1m[[0m[1;36m2[0m, [1;36m64[0m, [1;36m54[0m[1m][0m[1m)[0m,
        [1;35mtorch.Size[0m[1m([0m[1m[[0m[1;36m2[0m, [1;36m64[0m, [1;36m54[0m[1m][0m[1m)[0m
    [1m][0m
[1m)[0m

In [13]:
# | export


class CaiT3DStage2(nn.Module, PyTorchModelHubMixin):
    """Stack of ``CaiT3DStage2Layer`` blocks that refine the class tokens against fixed embeddings.

    ``config`` is indexed with the keys ``dim``, ``num_heads``, ``intermediate_ratio``, ``layer_norm_eps``,
    ``attn_drop_prob``, ``proj_drop_prob``, ``mlp_drop_prob`` and ``encoder_depth`` (the number of layers).
    """

    def __init__(self, config):
        super().__init__()

        def build_layer():
            # Every layer is built from the same hyperparameters
            return CaiT3DStage2Layer(
                config["dim"],
                config["num_heads"],
                config["intermediate_ratio"],
                config["layer_norm_eps"],
                config["attn_drop_prob"],
                config["proj_drop_prob"],
                config["mlp_drop_prob"],
            )

        depth = config["encoder_depth"]
        self.layers = nn.ModuleList([build_layer() for _ in range(depth)])

    def forward(self, class_tokens: torch.Tensor, embeddings: torch.Tensor):
        # class_tokens: (b, num_class_tokens, dim)
        # embeddings: (b, num_embed_tokens, dim)

        current = class_tokens
        per_layer = []
        for block in self.layers:
            # Only the class tokens change; every layer receives the same embeddings
            current = block(current, embeddings)
            # (b, num_class_tokens, dim)
            per_layer.append(current)

        # Class tokens after the last layer, plus the class tokens after every layer in order
        return current, per_layer

In [14]:
# Smoke test: five class-attention layers, one class token, 64 embedding tokens of width 54.
test_config = {
    "dim": 54,
    "num_heads": 6,
    "intermediate_ratio": 2,
    "layer_norm_eps": 1e-6,
    "attn_drop_prob": 0.0,
    "proj_drop_prob": 0.0,
    "mlp_drop_prob": 0.0,
    "encoder_depth": 5,
}

test = CaiT3DStage2(test_config)
display(test)
o = test(torch.randn(2, 1, 54), torch.randn(2, 64, 54))
display((o[0].shape, [x.shape for x in o[1]]))

# Fail loudly if the stage stops returning one class-token-shaped output per layer.
assert o[0].shape == (2, 1, 54)
assert len(o[1]) == test_config["encoder_depth"]
assert all(x.shape == (2, 1, 54) for x in o[1])
# The final output is the last per-layer output.
assert o[1][-1] is o[0]


[1;35mCaiT3DStage2[0m[1m([0m
  [1m([0mlayers[1m)[0m: [1;35mModuleList[0m[1m([0m
    [1m([0m[1;36m0[0m-[1;36m4[0m[1m)[0m: [1;36m5[0m x [1;35mCaiT3DStage2Layer[0m[1m([0m
      [1m([0mmhca[1m)[0m: [1;35mCaiT3DMHCA[0m[1m([0m
        [1m([0mW_q[1m)[0m: [1;35mLinear[0m[1m([0m[33min_features[0m=[1;36m54[0m, [33mout_features[0m=[1;36m54[0m, [33mbias[0m=[3;92mTrue[0m[1m)[0m
        [1m([0mW_k[1m)[0m: [1;35mLinear[0m[1m([0m[33min_features[0m=[1;36m54[0m, [33mout_features[0m=[1;36m54[0m, [33mbias[0m=[3;92mTrue[0m[1m)[0m
        [1m([0mW_v[1m)[0m: [1;35mLinear[0m[1m([0m[33min_features[0m=[1;36m54[0m, [33mout_features[0m=[1;36m54[0m, [33mbias[0m=[3;92mTrue[0m[1m)[0m
        [1m([0mattn_drop[1m)[0m: [1;35mDropout[0m[1m([0m[33mp[0m=[1;36m0[0m[1;36m.0[0m, [33minplace[0m=[3;91mFalse[0m[1m)[0m
        [1m([0mproj[1m)[0m: [1;35mLinear[0m[1m([0m[33min_features[0m=[1;36m54[0m


[1m([0m
    [1;35mtorch.Size[0m[1m([0m[1m[[0m[1;36m2[0m, [1;36m1[0m, [1;36m54[0m[1m][0m[1m)[0m,
    [1m[[0m
        [1;35mtorch.Size[0m[1m([0m[1m[[0m[1;36m2[0m, [1;36m1[0m, [1;36m54[0m[1m][0m[1m)[0m,
        [1;35mtorch.Size[0m[1m([0m[1m[[0m[1;36m2[0m, [1;36m1[0m, [1;36m54[0m[1m][0m[1m)[0m,
        [1;35mtorch.Size[0m[1m([0m[1m[[0m[1;36m2[0m, [1;36m1[0m, [1;36m54[0m[1m][0m[1m)[0m,
        [1;35mtorch.Size[0m[1m([0m[1m[[0m[1;36m2[0m, [1;36m1[0m, [1;36m54[0m[1m][0m[1m)[0m,
        [1;35mtorch.Size[0m[1m([0m[1m[[0m[1;36m2[0m, [1;36m1[0m, [1;36m54[0m[1m][0m[1m)[0m
    [1m][0m
[1m)[0m

# Models

In [15]:
# | export


class CaiT3DStage2OnlyModel(nn.Module, PyTorchModelHubMixin):
    """Class-attention-only model: learned class tokens attend over externally supplied embeddings.

    Besides the keys consumed by ``CaiT3DStage2``, ``config`` must provide ``num_class_tokens`` and ``dim``.
    Every class token has its own ``nn.Linear(dim, 1)`` head.
    """

    def __init__(self, config):
        super().__init__()

        num_class_tokens = config["num_class_tokens"]
        dim = config["dim"]

        self.num_class_tokens = num_class_tokens
        # Learned class tokens, shared across the batch
        self.class_tokens = nn.Parameter(torch.randn(1, num_class_tokens, dim))

        self.class_attention = CaiT3DStage2(config)
        # One single-logit head per class token
        self.classifiers = nn.ModuleList([nn.Linear(dim, 1) for _ in range(num_class_tokens)])

    def forward(self, embeddings: torch.Tensor):
        # embeddings: (b, num_embedding_tokens, dim)

        batch_size = embeddings.shape[0]
        class_tokens = repeat(self.class_tokens, "1 n d -> b n d", b=batch_size)
        # (b, num_class_tokens, dim)

        class_embeddings, layer_outputs = self.class_attention(class_tokens, embeddings)
        # class_embeddings: (b, num_class_tokens, dim)
        # layer_outputs: list of (b, num_class_tokens, dim), one entry per class-attention layer

        # Head i only sees class token i and yields (b, 1)
        per_token_logits = [head(class_embeddings[:, idx]) for idx, head in enumerate(self.classifiers)]
        class_logits = torch.cat(per_token_logits, dim=1)
        # (b, num_class_tokens)

        return class_logits, class_embeddings, layer_outputs

In [16]:
# Smoke test: three class tokens attending over 4096 embedding tokens of width 768.
test_config = {
    "num_class_tokens": 3,
    "attn_drop_prob": 0.2,
    "dim": 768,
    "drop_prob": 0.2,
    "encoder_depth": 4,
    "intermediate_ratio": 2,
    "layer_norm_eps": 1e-6,
    "mlp_drop_prob": 0.2,
    "num_heads": 4,
    "proj_drop_prob": 0.2,
}

test = CaiT3DStage2OnlyModel(test_config)
display(test)
o = test(
    torch.randn(2, 4096, 768),
)
display((o[0].shape, o[1].shape, [x.shape for x in o[2]]))

# Fail loudly if logits, class embeddings or per-layer outputs change shape.
assert o[0].shape == (2, test_config["num_class_tokens"])
assert o[1].shape == (2, test_config["num_class_tokens"], test_config["dim"])
assert len(o[2]) == test_config["encoder_depth"]
assert all(x.shape == o[1].shape for x in o[2])


[1;35mCaiT3DStage2OnlyModel[0m[1m([0m
  [1m([0mclass_attention[1m)[0m: [1;35mCaiT3DStage2[0m[1m([0m
    [1m([0mlayers[1m)[0m: [1;35mModuleList[0m[1m([0m
      [1m([0m[1;36m0[0m-[1;36m3[0m[1m)[0m: [1;36m4[0m x [1;35mCaiT3DStage2Layer[0m[1m([0m
        [1m([0mmhca[1m)[0m: [1;35mCaiT3DMHCA[0m[1m([0m
          [1m([0mW_q[1m)[0m: [1;35mLinear[0m[1m([0m[33min_features[0m=[1;36m768[0m, [33mout_features[0m=[1;36m768[0m, [33mbias[0m=[3;92mTrue[0m[1m)[0m
          [1m([0mW_k[1m)[0m: [1;35mLinear[0m[1m([0m[33min_features[0m=[1;36m768[0m, [33mout_features[0m=[1;36m768[0m, [33mbias[0m=[3;92mTrue[0m[1m)[0m
          [1m([0mW_v[1m)[0m: [1;35mLinear[0m[1m([0m[33min_features[0m=[1;36m768[0m, [33mout_features[0m=[1;36m768[0m, [33mbias[0m=[3;92mTrue[0m[1m)[0m
          [1m([0mattn_drop[1m)[0m: [1;35mDropout[0m[1m([0m[33mp[0m=[1;36m0[0m[1;36m.2[0m, [33minplace[0m=[3;91mFalse[0m[1


[1m([0m
    [1;35mtorch.Size[0m[1m([0m[1m[[0m[1;36m2[0m, [1;36m3[0m[1m][0m[1m)[0m,
    [1;35mtorch.Size[0m[1m([0m[1m[[0m[1;36m2[0m, [1;36m3[0m, [1;36m768[0m[1m][0m[1m)[0m,
    [1m[[0m
        [1;35mtorch.Size[0m[1m([0m[1m[[0m[1;36m2[0m, [1;36m3[0m, [1;36m768[0m[1m][0m[1m)[0m,
        [1;35mtorch.Size[0m[1m([0m[1m[[0m[1;36m2[0m, [1;36m3[0m, [1;36m768[0m[1m][0m[1m)[0m,
        [1;35mtorch.Size[0m[1m([0m[1m[[0m[1;36m2[0m, [1;36m3[0m, [1;36m768[0m[1m][0m[1m)[0m,
        [1;35mtorch.Size[0m[1m([0m[1m[[0m[1;36m2[0m, [1;36m3[0m, [1;36m768[0m[1m][0m[1m)[0m
    [1m][0m
[1m)[0m

In [17]:
# | export


class CaiT3DModel(nn.Module, PyTorchModelHubMixin):
    """Full CaiT3D model: self attention over the tokens, then class attention, then per-token heads.

    The same ``config`` is handed to both ``CaiT3DStage1`` and ``CaiT3DStage2``, so both stages are built with
    ``config["encoder_depth"]`` layers. ``config`` must also provide ``num_class_tokens`` and ``dim``.
    Every class token has its own ``nn.Linear(dim, 1)`` head.
    """

    def __init__(self, config):
        super().__init__()

        num_class_tokens = config["num_class_tokens"]
        dim = config["dim"]

        self.num_class_tokens = num_class_tokens
        # Learned class tokens, shared across the batch
        self.class_tokens = nn.Parameter(torch.randn(1, num_class_tokens, dim))

        self.self_attention = CaiT3DStage1(config)
        self.class_attention = CaiT3DStage2(config)
        # One single-logit head per class token
        self.classifiers = nn.ModuleList([nn.Linear(dim, 1) for _ in range(num_class_tokens)])

    def forward(self, tokens: torch.Tensor):
        # tokens: (b, num_embedding_tokens, dim)

        embeddings, self_attention_outputs = self.self_attention(tokens)
        # embeddings: (b, num_embedding_tokens, dim)
        # self_attention_outputs: list of (b, num_embedding_tokens, dim), one entry per self-attention layer

        batch_size = embeddings.shape[0]
        class_tokens = repeat(self.class_tokens, "1 n d -> b n d", b=batch_size)
        # (b, num_class_tokens, dim)

        class_embeddings, class_attention_outputs = self.class_attention(class_tokens, embeddings)
        # class_embeddings: (b, num_class_tokens, dim)
        # class_attention_outputs: list of (b, num_class_tokens, dim), one entry per class-attention layer

        # Head i only sees class token i and yields (b, 1)
        per_token_logits = [head(class_embeddings[:, idx]) for idx, head in enumerate(self.classifiers)]
        class_logits = torch.cat(per_token_logits, dim=1)
        # (b, num_class_tokens)

        return class_logits, class_embeddings, [self_attention_outputs, class_attention_outputs]

In [18]:
# Smoke test: full model on 4096 tokens of width 768 with three class tokens.
test_config = {
    "num_class_tokens": 3,
    "attn_drop_prob": 0.2,
    "dim": 768,
    "drop_prob": 0.2,
    "encoder_depth": 4,
    "intermediate_ratio": 2,
    "layer_norm_eps": 1e-6,
    "mlp_drop_prob": 0.2,
    "num_heads": 4,
    "proj_drop_prob": 0.2,
}

test = CaiT3DModel(test_config)
display(test)
# Only shapes are inspected, so skip autograd bookkeeping for the 4096-token self-attention pass.
with torch.no_grad():
    o = test(
        torch.randn(2, 4096, 768),
    )
display((o[0].shape, o[1].shape, [[x.shape for x in o[2][0]], [x.shape for x in o[2][1]]]))

# Fail loudly if logits, class embeddings or per-layer outputs of either stage change shape.
assert o[0].shape == (2, test_config["num_class_tokens"])
assert o[1].shape == (2, test_config["num_class_tokens"], test_config["dim"])
assert len(o[2]) == 2
assert len(o[2][0]) == test_config["encoder_depth"]
assert all(x.shape == (2, 4096, test_config["dim"]) for x in o[2][0])
assert len(o[2][1]) == test_config["encoder_depth"]
assert all(x.shape == o[1].shape for x in o[2][1])


[1;35mCaiT3DModel[0m[1m([0m
  [1m([0mself_attention[1m)[0m: [1;35mCaiT3DStage1[0m[1m([0m
    [1m([0mlayers[1m)[0m: [1;35mModuleList[0m[1m([0m
      [1m([0m[1;36m0[0m-[1;36m3[0m[1m)[0m: [1;36m4[0m x [1;35mCaiT3DStage1Layer[0m[1m([0m
        [1m([0mmhsa[1m)[0m: [1;35mCaiT3DMHSA[0m[1m([0m
          [1m([0mW_q[1m)[0m: [1;35mLinear[0m[1m([0m[33min_features[0m=[1;36m768[0m, [33mout_features[0m=[1;36m768[0m, [33mbias[0m=[3;92mTrue[0m[1m)[0m
          [1m([0mW_k[1m)[0m: [1;35mLinear[0m[1m([0m[33min_features[0m=[1;36m768[0m, [33mout_features[0m=[1;36m768[0m, [33mbias[0m=[3;92mTrue[0m[1m)[0m
          [1m([0mW_v[1m)[0m: [1;35mLinear[0m[1m([0m[33min_features[0m=[1;36m768[0m, [33mout_features[0m=[1;36m768[0m, [33mbias[0m=[3;92mTrue[0m[1m)[0m
          [1m([0mattn_drop[1m)[0m: [1;35mDropout[0m[1m([0m[33mp[0m=[1;36m0[0m[1;36m.2[0m, [33minplace[0m=[3;91mFalse[0m[1m)[0m
    


[1m([0m
    [1;35mtorch.Size[0m[1m([0m[1m[[0m[1;36m2[0m, [1;36m3[0m[1m][0m[1m)[0m,
    [1;35mtorch.Size[0m[1m([0m[1m[[0m[1;36m2[0m, [1;36m3[0m, [1;36m768[0m[1m][0m[1m)[0m,
    [1m[[0m
        [1m[[0m
            [1;35mtorch.Size[0m[1m([0m[1m[[0m[1;36m2[0m, [1;36m4096[0m, [1;36m768[0m[1m][0m[1m)[0m,
            [1;35mtorch.Size[0m[1m([0m[1m[[0m[1;36m2[0m, [1;36m4096[0m, [1;36m768[0m[1m][0m[1m)[0m,
            [1;35mtorch.Size[0m[1m([0m[1m[[0m[1;36m2[0m, [1;36m4096[0m, [1;36m768[0m[1m][0m[1m)[0m,
            [1;35mtorch.Size[0m[1m([0m[1m[[0m[1;36m2[0m, [1;36m4096[0m, [1;36m768[0m[1m][0m[1m)[0m
        [1m][0m,
        [1m[[0m
            [1;35mtorch.Size[0m[1m([0m[1m[[0m[1;36m2[0m, [1;36m3[0m, [1;36m768[0m[1m][0m[1m)[0m,
            [1;35mtorch.Size[0m[1m([0m[1m[[0m[1;36m2[0m, [1;36m3[0m, [1;36m768[0m[1m][0m[1m)[0m,
            [1;35mtorch.Size[0m[1m

# nbdev

In [19]:
# Run nbdev's export command so the cells marked `# | export` are written out to the library module.
!nbdev_export