In [None]:
# | default_exp nets/vit_3d

# Imports

In [None]:
# | export

from typing import Literal

import numpy as np
import torch
from einops import rearrange, repeat
from huggingface_hub import PyTorchModelHubMixin
from torch import nn

from vision_architectures.blocks.transformer import TransformerDecoderBlock1D, TransformerEncoderBlock1D
from vision_architectures.layers.embeddings import AbsolutePositionEmbeddings3D, PatchEmbeddings3D
from vision_architectures.utils.custom_base_model import CustomBaseModel

# Config

In [None]:
# | export


class ViT3DEncoderConfig(CustomBaseModel):
    """Hyper-parameters of ``ViT3DEncoder``.

    The whole config (``model_dump()``) is handed to every ``TransformerEncoderBlock1D`` of the encoder, so all
    blocks share these values.
    """

    # Token embedding dimension (tokens are (b, num_tokens, dim)).
    dim: int
    num_heads: int
    # Presumably the MLP hidden size as a multiple of ``dim`` — interpreted by TransformerEncoderBlock1D.
    mlp_ratio: int
    layer_norm_eps: float
    attn_drop_prob: float = 0.0
    proj_drop_prob: float = 0.0
    mlp_drop_prob: float = 0.0
    # Presumably selects pre-norm vs. post-norm residual blocks — interpreted by TransformerEncoderBlock1D.
    norm_location: Literal["pre", "post"] = "pre"

    # Number of stacked TransformerEncoderBlock1D layers.
    encoder_depth: int


class ViT3DConfig(ViT3DEncoderConfig):
    """Configuration of the full ``ViT3D`` model: patch embedding, position embeddings and encoder.

    All encoder hyper-parameters are inherited from ``ViT3DEncoderConfig``.
    """

    # Patch size (pz, py, px) used by PatchEmbeddings3D to split the input volume into tokens.
    patch_size: tuple[int, int, int]
    # Number of channels of the input volume.
    in_channels: int
    # Number of learnable class tokens prepended to the patch tokens (0 disables them).
    num_class_tokens: int

    # Dropout probability applied to the tokens right after the absolute position embedding step.
    drop_prob: float = 0.0

    # For MIM
    # (z, y, x) size of the input volume; the MIM decoder needs it to reassemble patches. Unused by plain ViT3D.
    image_size: tuple[int, int, int] | None = None
    # Fraction of patches masked per sample during MIM pre-training. Unused by plain ViT3D.
    mask_ratio: float | None = None


class ViT3DDecoderConfig(CustomBaseModel):
    """Hyper-parameters of ``ViT3DDecoder``.

    The whole config (``model_dump()``) is handed to every ``TransformerDecoderBlock1D`` of the decoder, so all
    blocks share these values.
    """

    # Token embedding dimension of both queries and keys/values.
    dim: int
    num_heads: int
    # Presumably the MLP hidden size as a multiple of ``dim`` — interpreted by TransformerDecoderBlock1D.
    mlp_ratio: int
    layer_norm_eps: float
    attn_drop_prob: float = 0.0
    proj_drop_prob: float = 0.0
    mlp_drop_prob: float = 0.0
    # Presumably selects pre-norm vs. post-norm residual blocks — interpreted by TransformerDecoderBlock1D.
    norm_location: Literal["pre", "post"] = "pre"

    # Number of stacked TransformerDecoderBlock1D layers.
    decoder_depth: int

In [None]:
# Only ViT3DConfig fields are passed; everything else (dropouts, norm_location, MIM fields) keeps its default.
test_config = ViT3DConfig.model_validate(
    {
        "patch_size": (2, 2, 2),
        "in_channels": 3,
        "dim": 64,
        "num_heads": 8,
        "mlp_ratio": 4,
        "layer_norm_eps": 1e-6,
        "encoder_depth": 3,
        "num_class_tokens": 0,
    }
)
test_config


[1;35mViT3DConfig[0m[1m([0m
    [33mdim[0m=[1;36m64[0m,
    [33mnum_heads[0m=[1;36m8[0m,
    [33mmlp_ratio[0m=[1;36m4[0m,
    [33mlayer_norm_eps[0m=[1;36m1e[0m[1;36m-06[0m,
    [33mattn_drop_prob[0m=[1;36m0[0m[1;36m.0[0m,
    [33mproj_drop_prob[0m=[1;36m0[0m[1;36m.0[0m,
    [33mmlp_drop_prob[0m=[1;36m0[0m[1;36m.0[0m,
    [33mnorm_location[0m=[32m'pre'[0m,
    [33mencoder_depth[0m=[1;36m3[0m,
    [33mpatch_size[0m=[1m([0m[1;36m2[0m, [1;36m2[0m, [1;36m2[0m[1m)[0m,
    [33min_channels[0m=[1;36m3[0m,
    [33mnum_class_tokens[0m=[1;36m0[0m,
    [33mdrop_prob[0m=[1;36m0[0m[1;36m.0[0m,
    [33mimage_size[0m=[3;35mNone[0m,
    [33mmask_ratio[0m=[3;35mNone[0m
[1m)[0m

# Architecture

### Encoder

In [None]:
# | export


class ViT3DEncoder(nn.Module, PyTorchModelHubMixin):
    """Stack of ``config.encoder_depth`` ``TransformerEncoderBlock1D`` layers applied to a token sequence.

    Args:
        config: Encoder configuration. Its ``model_dump()`` is passed to every block.
    """

    def __init__(self, config: ViT3DEncoderConfig):
        super().__init__()

        self.layers = nn.ModuleList(
            [TransformerEncoderBlock1D(config.model_dump()) for _ in range(config.encoder_depth)]
        )

    def forward(self, embeddings: torch.Tensor, return_intermediates: bool = False):
        """Run the tokens through every encoder layer in order.

        Args:
            embeddings: Token embeddings of shape (b, num_tokens, dim).
            return_intermediates: If True, also return the output of every layer.

        Returns:
            The final embeddings (b, num_tokens, dim) or, when ``return_intermediates`` is True, a tuple
            ``(embeddings, layer_outputs)`` with one tensor per layer (the last one is the final output).
        """
        # Keep per-layer outputs only when the caller wants them; otherwise they would stay referenced (and
        # allocated) until the end of the forward pass for nothing.
        layer_outputs = [] if return_intermediates else None

        for encoder_layer in self.layers:
            embeddings = encoder_layer(embeddings)
            # (b, num_tokens, dim)

            if layer_outputs is not None:
                layer_outputs.append(embeddings)

        if return_intermediates:
            return embeddings, layer_outputs
        return embeddings

In [None]:
test_config = ViT3DEncoderConfig.model_validate(
    {
        "dim": 54,
        "num_heads": 6,
        "mlp_ratio": 2,
        "layer_norm_eps": 1e-6,
        "attn_drop_prob": 0.0,
        "proj_drop_prob": 0.0,
        "mlp_drop_prob": 0.0,
        "encoder_depth": 3,
    }
)

test = ViT3DEncoder(test_config)
display(test)
encoded, layer_outputs = test(torch.randn(2, 64, 54), return_intermediates=True)
display((encoded.shape, [x.shape for x in layer_outputs]))

# Token shape is preserved by every layer; there is one intermediate per layer and the last one is the output.
assert encoded.shape == (2, 64, 54)
assert len(layer_outputs) == test_config.encoder_depth
assert all(x.shape == encoded.shape for x in layer_outputs)
assert layer_outputs[-1] is encoded


[1;35mViT3DEncoder[0m[1m([0m
  [1m([0mlayers[1m)[0m: [1;35mModuleList[0m[1m([0m
    [1m([0m[1;36m0[0m-[1;36m2[0m[1m)[0m: [1;36m3[0m x [1;35mTransformerEncoderBlock1D[0m[1m([0m
      [1m([0mattn[1m)[0m: [1;35mAttention1D[0m[1m([0m
        [1m([0mW_q[1m)[0m: [1;35mLinear[0m[1m([0m[33min_features[0m=[1;36m54[0m, [33mout_features[0m=[1;36m54[0m, [33mbias[0m=[3;92mTrue[0m[1m)[0m
        [1m([0mW_k[1m)[0m: [1;35mLinear[0m[1m([0m[33min_features[0m=[1;36m54[0m, [33mout_features[0m=[1;36m54[0m, [33mbias[0m=[3;92mTrue[0m[1m)[0m
        [1m([0mW_v[1m)[0m: [1;35mLinear[0m[1m([0m[33min_features[0m=[1;36m54[0m, [33mout_features[0m=[1;36m54[0m, [33mbias[0m=[3;92mTrue[0m[1m)[0m
        [1m([0mproj[1m)[0m: [1;35mLinear[0m[1m([0m[33min_features[0m=[1;36m54[0m, [33mout_features[0m=[1;36m54[0m, [33mbias[0m=[3;92mTrue[0m[1m)[0m
        [1m([0mproj_drop[1m)[0m: [1;35mDropout[0m[1

[1m([0m[1;35mtorch.Size[0m[1m([0m[1m[[0m[1;36m2[0m, [1;36m64[0m, [1;36m54[0m[1m][0m[1m)[0m, [1m[[0m[1;35mtorch.Size[0m[1m([0m[1m[[0m[1;36m2[0m, [1;36m64[0m, [1;36m54[0m[1m][0m[1m)[0m, [1;35mtorch.Size[0m[1m([0m[1m[[0m[1;36m2[0m, [1;36m64[0m, [1;36m54[0m[1m][0m[1m)[0m, [1;35mtorch.Size[0m[1m([0m[1m[[0m[1;36m2[0m, [1;36m64[0m, [1;36m54[0m[1m][0m[1m)[0m[1m][0m[1m)[0m

### Decoder

In [None]:
# | export


class ViT3DDecoder(nn.Module, PyTorchModelHubMixin):
    """Stack of ``config.decoder_depth`` ``TransformerDecoderBlock1D`` layers.

    Each layer updates the query tokens using the (fixed) key/value tokens.

    Args:
        config: Decoder configuration. Its ``model_dump()`` is passed to every block.
    """

    def __init__(self, config: ViT3DDecoderConfig):
        super().__init__()

        self.layers = nn.ModuleList(
            [TransformerDecoderBlock1D(config.model_dump()) for _ in range(config.decoder_depth)]
        )

    def forward(self, q: torch.Tensor, kv: torch.Tensor, return_intermediates: bool = False):
        """Run the query tokens through every decoder layer in order.

        Args:
            q: Query tokens of shape (b, num_q_tokens, dim).
            kv: Key/value tokens of shape (b, num_kv_tokens, dim), passed unchanged to every layer.
            return_intermediates: If True, also return the output of every layer.

        Returns:
            The final query embeddings (b, num_q_tokens, dim) or, when ``return_intermediates`` is True, a tuple
            ``(embeddings, layer_outputs)`` with one tensor per layer (the last one is the final output).
        """
        embeddings = q

        # Keep per-layer outputs only when the caller wants them; otherwise they would stay referenced (and
        # allocated) until the end of the forward pass for nothing.
        layer_outputs = [] if return_intermediates else None

        for decoder_layer in self.layers:
            embeddings = decoder_layer(embeddings, kv)
            # (b, num_q_tokens, dim)

            if layer_outputs is not None:
                layer_outputs.append(embeddings)

        if return_intermediates:
            return embeddings, layer_outputs
        return embeddings

In [None]:
test_config = ViT3DDecoderConfig.model_validate(
    {
        "dim": 54,
        "num_heads": 6,
        "mlp_ratio": 2,
        "layer_norm_eps": 1e-6,
        "attn_drop_prob": 0.0,
        "proj_drop_prob": 0.0,
        "mlp_drop_prob": 0.0,
        "decoder_depth": 5,
    }
)

test = ViT3DDecoder(test_config)
display(test)
decoded, layer_outputs = test(torch.randn(2, 64, 54), torch.randn(2, 128, 54), return_intermediates=True)
display((decoded.shape, [x.shape for x in layer_outputs]))

# The output follows the query length (64), not the key/value length (128); one intermediate per layer.
assert decoded.shape == (2, 64, 54)
assert len(layer_outputs) == test_config.decoder_depth
assert all(x.shape == decoded.shape for x in layer_outputs)


[1;35mViT3DDecoder[0m[1m([0m
  [1m([0mlayers[1m)[0m: [1;35mModuleList[0m[1m([0m
    [1m([0m[1;36m0[0m-[1;36m4[0m[1m)[0m: [1;36m5[0m x [1;35mTransformerDecoderBlock1D[0m[1m([0m
      [1m([0mattn1[1m)[0m: [1;35mAttention1D[0m[1m([0m
        [1m([0mW_q[1m)[0m: [1;35mLinear[0m[1m([0m[33min_features[0m=[1;36m54[0m, [33mout_features[0m=[1;36m54[0m, [33mbias[0m=[3;92mTrue[0m[1m)[0m
        [1m([0mW_k[1m)[0m: [1;35mLinear[0m[1m([0m[33min_features[0m=[1;36m54[0m, [33mout_features[0m=[1;36m54[0m, [33mbias[0m=[3;92mTrue[0m[1m)[0m
        [1m([0mW_v[1m)[0m: [1;35mLinear[0m[1m([0m[33min_features[0m=[1;36m54[0m, [33mout_features[0m=[1;36m54[0m, [33mbias[0m=[3;92mTrue[0m[1m)[0m
        [1m([0mproj[1m)[0m: [1;35mLinear[0m[1m([0m[33min_features[0m=[1;36m54[0m, [33mout_features[0m=[1;36m54[0m, [33mbias[0m=[3;92mTrue[0m[1m)[0m
        [1m([0mproj_drop[1m)[0m: [1;35mDropout[0m[


[1m([0m
    [1;35mtorch.Size[0m[1m([0m[1m[[0m[1;36m2[0m, [1;36m64[0m, [1;36m54[0m[1m][0m[1m)[0m,
    [1m[[0m
        [1;35mtorch.Size[0m[1m([0m[1m[[0m[1;36m2[0m, [1;36m64[0m, [1;36m54[0m[1m][0m[1m)[0m,
        [1;35mtorch.Size[0m[1m([0m[1m[[0m[1;36m2[0m, [1;36m64[0m, [1;36m54[0m[1m][0m[1m)[0m,
        [1;35mtorch.Size[0m[1m([0m[1m[[0m[1;36m2[0m, [1;36m64[0m, [1;36m54[0m[1m][0m[1m)[0m,
        [1;35mtorch.Size[0m[1m([0m[1m[[0m[1;36m2[0m, [1;36m64[0m, [1;36m54[0m[1m][0m[1m)[0m,
        [1;35mtorch.Size[0m[1m([0m[1m[[0m[1;36m2[0m, [1;36m64[0m, [1;36m54[0m[1m][0m[1m)[0m
    [1m][0m
[1m)[0m

# Models

In [None]:
# | export


class ViT3D(nn.Module, PyTorchModelHubMixin):
    """3D Vision Transformer: patch embedding -> absolute position embeddings -> transformer encoder.

    Optional learnable class tokens are prepended before the encoder and split off again afterwards. Patches can
    be replaced by a mask token before position embeddings are applied (used for masked image modeling).
    """

    def __init__(self, config: ViT3DConfig):
        super().__init__()

        self.patchify = PatchEmbeddings3D(patch_size=config.patch_size, in_channels=config.in_channels, dim=config.dim)
        # Non-learnable position embeddings; they receive the per-sample spacings in forward.
        self.absolute_position_embeddings = AbsolutePositionEmbeddings3D(dim=config.dim, learnable=False)
        self.pos_drop = nn.Dropout(config.drop_prob)
        self.num_class_tokens = config.num_class_tokens
        if self.num_class_tokens > 0:
            self.class_tokens = nn.Parameter(torch.randn(1, config.num_class_tokens, config.dim))
        self.encoder = ViT3DEncoder(config)

    def forward(
        self,
        pixel_values: torch.Tensor,
        spacings: torch.Tensor,
        mask_patches: torch.Tensor = None,
        mask_token: torch.Tensor = None,
        return_intermediates: bool = False,
    ):
        """Encode a batch of volumes.

        Args:
            pixel_values: Input volumes of shape (b, c, z, y, x).
            spacings: Per-sample spacings of shape (b, 3), forwarded to the position embeddings.
            mask_patches: Optional binary (0/1 or bool) mask of shape (b, num_patches_z, num_patches_y,
                num_patches_x). Patches where it is 1 are replaced by ``mask_token``.
            mask_token: Embedding of shape (1, dim, 1, 1, 1) used for masked patches. Required with
                ``mask_patches``.
            return_intermediates: If True, also return the output of every encoder layer.

        Returns:
            ``(encoded, class_tokens)`` or ``(encoded, class_tokens, layer_outputs)``. ``encoded`` is
            (b, num_tokens, dim) with num_tokens = num_patches_z * num_patches_y * num_patches_x; ``class_tokens``
            is (b, num_class_tokens, dim), or None when the model has no class tokens; each ``layer_outputs``
            entry still contains the class tokens at the front.

        Raises:
            ValueError: If ``mask_patches`` is given without ``mask_token``.
        """
        embeddings = self.patchify(pixel_values)
        # (b, dim, num_patches_z, num_patches_y, num_patches_x)

        if mask_patches is not None:
            if mask_token is None:
                raise ValueError("mask_token is required when mask_patches is given")
            # Add a singleton channel axis and let broadcasting spread it over `dim` instead of materialising a
            # full (b, dim, ...) copy. Casting to the embedding dtype also makes bool masks work (`1 - mask` is
            # not defined for bool tensors).
            mask = mask_patches.unsqueeze(1).to(embeddings.dtype)
            # (b, 1, num_patches_z, num_patches_y, num_patches_x)
            embeddings = embeddings * (1 - mask) + mask * mask_token

        embeddings = self.absolute_position_embeddings(embeddings, spacings=spacings, device=pixel_values.device)
        # (b, dim, num_patches_z, num_patches_y, num_patches_x)

        embeddings = rearrange(embeddings, "b e nz ny nx -> b (nz ny nx) e").contiguous()
        # (b, num_tokens, dim)

        embeddings = self.pos_drop(embeddings)
        # (b, num_tokens, dim)

        class_tokens = None
        if self.num_class_tokens > 0:
            class_tokens = repeat(self.class_tokens, "1 n d -> b n d", b=embeddings.shape[0])
            embeddings = torch.cat([class_tokens, embeddings], dim=1)
            # (b, num_class_tokens + num_tokens, dim)

        # Only ask the encoder for per-layer outputs when they are actually returned.
        layer_outputs = None
        if return_intermediates:
            encoded, layer_outputs = self.encoder(embeddings, return_intermediates=True)
        else:
            encoded = self.encoder(embeddings)
        # encoded: (b, num_class_tokens + num_tokens, dim)
        # layer_outputs: list of (b, num_class_tokens + num_tokens, dim), class tokens first

        if self.num_class_tokens > 0:
            class_tokens = encoded[:, : self.num_class_tokens]
            encoded = encoded[:, self.num_class_tokens :]

        if return_intermediates:
            return encoded, class_tokens, layer_outputs
        return encoded, class_tokens

In [None]:
test_config = ViT3DConfig.model_validate(
    {
        "num_class_tokens": 2,
        "attn_drop_prob": 0.2,
        "dim": 768,
        "drop_prob": 0.2,
        "encoder_depth": 4,
        "image_size": (32, 512, 512),
        "in_channels": 1,
        "mlp_ratio": 2,
        "layer_norm_eps": 1e-6,
        "mlp_drop_prob": 0.2,
        "num_heads": 4,
        "patch_size": (8, 16, 16),
        "proj_drop_prob": 0.2,
    }
)

test = ViT3D(test_config)
display(test)
encoded, class_tokens, layer_outputs = test(
    torch.randn(2, 1, 32, 512, 512),
    torch.randn(2, 3),
    return_intermediates=True,
)
display((encoded.shape, class_tokens.shape, [x.shape for x in layer_outputs]))

num_patches = int(np.prod([s // p for s, p in zip(test_config.image_size, test_config.patch_size)]))
assert encoded.shape == (2, num_patches, test_config.dim)
assert class_tokens.shape == (2, test_config.num_class_tokens, test_config.dim)
# Intermediates still carry the class tokens in front of the patch tokens.
assert len(layer_outputs) == test_config.encoder_depth
assert all(x.shape == (2, test_config.num_class_tokens + num_patches, test_config.dim) for x in layer_outputs)


[1;35mViT3DModel[0m[1m([0m
  [1m([0mpatchify[1m)[0m: [1;35mPatchEmbeddings3D[0m[1m([0m
    [1m([0mpatch_embeddings[1m)[0m: [1;35mConv3d[0m[1m([0m[1;36m1[0m, [1;36m768[0m, [33mkernel_size[0m=[1m([0m[1;36m8[0m, [1;36m16[0m, [1;36m16[0m[1m)[0m, [33mstride[0m=[1m([0m[1;36m8[0m, [1;36m16[0m, [1;36m16[0m[1m)[0m[1m)[0m
    [1m([0mnormalization[1m)[0m: [1;35mLayerNorm[0m[1m([0m[1m([0m[1;36m768[0m,[1m)[0m, [33meps[0m=[1;36m1e[0m[1;36m-05[0m, [33melementwise_affine[0m=[3;92mTrue[0m[1m)[0m
    [1m([0mcheckpointing_level1[1m)[0m: [1;35mActivationCheckpointing[0m[1m([0m[33menabled[0m=[3;91mFalse[0m, [33mcheckpointing_level[0m=[1;36m1[0m[1m)[0m
  [1m)[0m
  [1m([0mabsolute_position_embeddings[1m)[0m: [1;35mAbsolutePositionEmbeddings3D[0m[1m([0m[1m)[0m
  [1m([0mpos_drop[1m)[0m: [1;35mDropout[0m[1m([0m[33mp[0m=[1;36m0[0m[1;36m.2[0m, [33minplace[0m=[3;91mFalse[0m[1m)[0m
  [1m(


[1m([0m
    [1;35mtorch.Size[0m[1m([0m[1m[[0m[1;36m2[0m, [1;36m4096[0m, [1;36m768[0m[1m][0m[1m)[0m,
    [1;35mtorch.Size[0m[1m([0m[1m[[0m[1;36m2[0m, [1;36m2[0m, [1;36m768[0m[1m][0m[1m)[0m,
    [1m[[0m
        [1;35mtorch.Size[0m[1m([0m[1m[[0m[1;36m2[0m, [1;36m4098[0m, [1;36m768[0m[1m][0m[1m)[0m,
        [1;35mtorch.Size[0m[1m([0m[1m[[0m[1;36m2[0m, [1;36m4098[0m, [1;36m768[0m[1m][0m[1m)[0m,
        [1;35mtorch.Size[0m[1m([0m[1m[[0m[1;36m2[0m, [1;36m4098[0m, [1;36m768[0m[1m][0m[1m)[0m,
        [1;35mtorch.Size[0m[1m([0m[1m[[0m[1;36m2[0m, [1;36m4098[0m, [1;36m768[0m[1m][0m[1m)[0m
    [1m][0m
[1m)[0m

# Masked Image Modeling

In [None]:
# | export


class ViT3DMIMDecoder(nn.Module):
    """Linear head mapping every token back to the voxels of its patch, then reassembling the full volume.

    Args:
        dim: Token embedding dimension.
        image_size: (z, y, x) size of the volume to reconstruct.
        in_channels: Number of channels to reconstruct.
        patch_size: (pz, py, px) patch size used to tokenise the volume.
    """

    def __init__(self, dim, image_size, in_channels, patch_size):
        super().__init__()

        self.image_size = image_size
        self.in_channels = in_channels
        self.patch_size = patch_size

        # Each token predicts every voxel of its patch for every channel. np.prod returns a numpy scalar; cast it
        # so the Linear layer's sizes are plain Python ints (e.g. JSON-serialisable).
        out_dim = int(np.prod(self.patch_size)) * self.in_channels

        self.decoder = nn.Linear(dim, out_dim)

    def forward(self, encodings: torch.Tensor):
        """Decode tokens into a volume.

        Args:
            encodings: Token embeddings of shape (b, num_tokens, dim), in (z, y, x) raster order of the patch grid.

        Returns:
            Reconstructed volume of shape (b, in_channels, z, y, x).
        """
        decoded = self.decoder(encodings)
        # (b, num_tokens, in_channels * pz * py * px)

        decoded = rearrange(
            decoded,
            "b (nz ny nx) (c pz py px) -> b c (nz pz) (ny py) (nx px)",
            c=self.in_channels,
            pz=self.patch_size[0],
            py=self.patch_size[1],
            px=self.patch_size[2],
            nz=self.image_size[0] // self.patch_size[0],
            ny=self.image_size[1] // self.patch_size[1],
            nx=self.image_size[2] // self.patch_size[2],
        ).contiguous()
        # (b, c, z, y, x)

        return decoded

In [None]:
test = ViT3DMIMDecoder(768, (32, 512, 512), 1, (8, 16, 16))
display(test)
decoded = test(torch.randn(2, 4096, 768))
display(decoded.shape)

# 4096 tokens = (32 / 8) * (512 / 16) * (512 / 16) patches, each decoded back to a single-channel 8x16x16 patch.
assert decoded.shape == (2, 1, 32, 512, 512)


[1;35mViT3DMIMDecoder[0m[1m([0m
  [1m([0mdecoder[1m)[0m: [1;35mLinear[0m[1m([0m[33min_features[0m=[1;36m768[0m, [33mout_features[0m=[1;36m2048[0m, [33mbias[0m=[3;92mTrue[0m[1m)[0m
[1m)[0m

[1;35mtorch.Size[0m[1m([0m[1m[[0m[1;36m2[0m, [1;36m1[0m, [1;36m32[0m, [1;36m512[0m, [1;36m512[0m[1m][0m[1m)[0m

In [None]:
# | export


class ViT3DMIM(nn.Module):
    """Shared building blocks for masked image modeling (MIM) on top of ``ViT3D``.

    Holds the ViT backbone, a linear pixel decoder, a learnable mask token and the random patch-masking logic.
    This class defines no ``forward`` itself; subclasses such as ``ViT3DSimMIM`` combine these pieces.
    """

    def __init__(self, config: ViT3DConfig):
        super().__init__()

        assert config.num_class_tokens == 0, "MIM does not support class tokens"
        # Both fields are optional on ViT3DConfig but mandatory here: the decoder needs a fixed volume size to
        # reassemble patches and mask_image needs a ratio. Fail now instead of with a TypeError mid-forward.
        assert config.image_size is not None, "MIM requires config.image_size"
        assert config.mask_ratio is not None, "MIM requires config.mask_ratio"
        assert 0.0 <= config.mask_ratio <= 1.0, "config.mask_ratio must be in [0, 1]"

        self.image_size = config.image_size
        self.patch_size = config.patch_size
        self.in_channels = config.in_channels
        self.mask_ratio = config.mask_ratio

        self.vit = ViT3D(config)
        self.decoder = ViT3DMIMDecoder(config.dim, config.image_size, config.in_channels, config.patch_size)

        # Learnable embedding substituted for masked patches; shaped to broadcast over (b, dim, nz, ny, nx).
        self.mask_token = nn.Parameter(torch.randn(1, config.dim, 1, 1, 1))

    def mask_image(self, pixel_values: torch.Tensor):
        """Draw an independent random patch mask for every sample of the batch.

        Exactly ``int(mask_ratio * num_patches)`` patches are masked per sample, chosen uniformly at random.

        Args:
            pixel_values: Batch of volumes; only its batch size and device are used.

        Returns:
            int8 tensor of shape (b, num_patches_z, num_patches_y, num_patches_x), 1 for masked patches and 0
            elsewhere, on the same device as ``pixel_values``.
        """
        b = pixel_values.shape[0]
        device = pixel_values.device

        grid_size = tuple(size // patch for size, patch in zip(self.image_size, self.patch_size))
        num_tokens = int(np.prod(grid_size))
        num_masked = int(self.mask_ratio * num_tokens)

        # argsort of i.i.d. uniform noise is an independent random permutation per row: the whole batch is drawn
        # at once, directly on the target device (no per-sample Python loop, no CPU index into a device tensor).
        order = torch.rand(b, num_tokens, device=device).argsort(dim=1)
        mask_patches = torch.zeros(b, num_tokens, dtype=torch.int8, device=device)
        mask_patches.scatter_(1, order[:, :num_masked], 1)

        return mask_patches.reshape(b, *grid_size)

In [None]:
# | export


class ViT3DSimMIM(ViT3DMIM, PyTorchModelHubMixin):
    """SimMIM pre-training for ``ViT3D``.

    Random patches are replaced by the learnable mask token, the volume is reconstructed by a linear decoder, and an
    L1 loss is averaged over the masked voxels (and channels) only.
    """

    def __init__(self, config):
        super().__init__(config)

    @staticmethod
    def loss_fn(pred: torch.Tensor, target: torch.Tensor, reduction="mean"):
        """Element-wise L1 loss between ``pred`` and ``target`` (``torch.nn.functional.l1_loss``)."""
        return nn.functional.l1_loss(pred, target, reduction=reduction)

    def forward(self, pixel_values: torch.Tensor, spacings: torch.Tensor):
        """Mask, encode and reconstruct a batch, and compute the masked reconstruction loss.

        Args:
            pixel_values: Input volumes of shape (b, c, z, y, x) matching ``config.image_size``.
            spacings: Per-sample spacings of shape (b, 3), forwarded to the ViT.

        Returns:
            ``(decoded, loss, mask)``: the reconstruction (b, c, z, y, x), the scalar L1 loss averaged over masked
            voxels and channels, and the voxel-level mask (b, z, y, x) with 1 where the patch was masked.
        """
        mask_patches = self.mask_image(pixel_values)
        # (b, num_patches_z, num_patches_y, num_patches_x)

        encodings, _ = self.vit(pixel_values, spacings, mask_patches, self.mask_token)
        decoded = self.decoder(encodings)
        # (b, c, z, y, x)

        # Upsample the patch mask to voxel resolution.
        mask = repeat(
            mask_patches,
            "b z y x -> b (z pz) (y py) (x px)",
            pz=self.patch_size[0],
            py=self.patch_size[1],
            px=self.patch_size[2],
        )
        # (b, z, y, x)

        per_voxel_loss = self.loss_fn(decoded, pixel_values, reduction="none")
        # (b, c, z, y, x)

        # `mask` has no channel axis. Multiplying it in directly would broadcast it as (1, b, z, y, x) against
        # (b, c, z, y, x), pairing sample i's error with sample j's mask (and failing unless b == c or one of them
        # is 1). unsqueeze(1) keeps the samples aligned and broadcasts the mask over channels instead.
        masked_loss = per_voxel_loss * mask.unsqueeze(1)
        loss = masked_loss.sum() / ((mask.sum() + 1e-5) * self.in_channels)

        return decoded, loss, mask

In [None]:
test_config = ViT3DConfig.model_validate(
    {
        "num_class_tokens": 0,
        "attn_drop_prob": 0.2,
        "dim": 768,
        "drop_prob": 0.2,
        "encoder_depth": 4,
        "image_size": (32, 128, 128),
        "in_channels": 1,
        "mlp_ratio": 2,
        "layer_norm_eps": 1e-6,
        "mlp_drop_prob": 0.2,
        "num_heads": 4,
        "patch_size": (8, 16, 16),
        "proj_drop_prob": 0.2,
        "mask_ratio": 0.8,
    }
)

test = ViT3DSimMIM(test_config)
display(test)
decoded, loss, mask = test(
    torch.randn(2, 1, 32, 128, 128),
    torch.randn(2, 3),
)
display((decoded.shape, loss, mask.shape))

assert decoded.shape == (2, 1, 32, 128, 128)
assert loss.ndim == 0 and torch.isfinite(loss)
assert mask.shape == (2, 32, 128, 128)
# Every sample has exactly int(mask_ratio * num_patches) masked patches, each covering a full patch of voxels.
patches_per_sample = int(np.prod([s // p for s, p in zip(test_config.image_size, test_config.patch_size)]))
voxels_per_patch = int(np.prod(test_config.patch_size))
expected_masked_voxels = int(test_config.mask_ratio * patches_per_sample) * voxels_per_patch
assert mask.flatten(1).sum(1).tolist() == [expected_masked_voxels] * 2


[1;35mViT3DSimMIM[0m[1m([0m
  [1m([0mvit[1m)[0m: [1;35mViT3DModel[0m[1m([0m
    [1m([0mpatchify[1m)[0m: [1;35mPatchEmbeddings3D[0m[1m([0m
      [1m([0mpatch_embeddings[1m)[0m: [1;35mConv3d[0m[1m([0m[1;36m1[0m, [1;36m768[0m, [33mkernel_size[0m=[1m([0m[1;36m8[0m, [1;36m16[0m, [1;36m16[0m[1m)[0m, [33mstride[0m=[1m([0m[1;36m8[0m, [1;36m16[0m, [1;36m16[0m[1m)[0m[1m)[0m
      [1m([0mnormalization[1m)[0m: [1;35mLayerNorm[0m[1m([0m[1m([0m[1;36m768[0m,[1m)[0m, [33meps[0m=[1;36m1e[0m[1;36m-05[0m, [33melementwise_affine[0m=[3;92mTrue[0m[1m)[0m
      [1m([0mcheckpointing_level1[1m)[0m: [1;35mActivationCheckpointing[0m[1m([0m[33menabled[0m=[3;91mFalse[0m, [33mcheckpointing_level[0m=[1;36m1[0m[1m)[0m
    [1m)[0m
    [1m([0mabsolute_position_embeddings[1m)[0m: [1;35mAbsolutePositionEmbeddings3D[0m[1m([0m[1m)[0m
    [1m([0mpos_drop[1m)[0m: [1;35mDropout[0m[1m([0m[33mp[0m=[1;36m


[1m([0m
    [1;35mtorch.Size[0m[1m([0m[1m[[0m[1;36m2[0m, [1;36m1[0m, [1;36m32[0m, [1;36m128[0m, [1;36m128[0m[1m][0m[1m)[0m,
    [1;35mtensor[0m[1m([0m[1;36m2.1494[0m, [33mgrad_fn[0m=[1m<[0m[1;95mDivBackward0[0m[1m>[0m[1m)[0m,
    [1;35mtorch.Size[0m[1m([0m[1m[[0m[1;36m2[0m, [1;36m32[0m, [1;36m128[0m, [1;36m128[0m[1m][0m[1m)[0m
[1m)[0m

# Some more tests

### Overfitting tests

In [None]:
from tqdm.auto import tqdm

# Seed torch's RNGs (data, init, masks, dropout) so the overfitting run is reproducible.
torch.manual_seed(0)

sample_spacings = torch.tensor([[1, 0.1, 0.1], [2, 0.2, 0.2], [3, 0.3, 0.3]])
sample_batch = torch.rand(3, 1, 16, 128, 128)
sample_config = ViT3DConfig.model_validate(
    {
        "num_class_tokens": 0,
        "attn_drop_prob": 0.2,
        "dim": 384,
        "drop_prob": 0.2,
        "encoder_depth": 4,
        "image_size": (16, 128, 128),
        "in_channels": 1,
        "mlp_ratio": 2,
        "layer_norm_eps": 1e-6,
        "mlp_drop_prob": 0.2,
        "num_heads": 4,
        "patch_size": (8, 16, 16),
        "proj_drop_prob": 0.2,
        "mask_ratio": 0.8,
    }
)
# The MIM decoder reassembles exactly image_size, so it must match the spatial size of the batch.
assert tuple(sample_batch.shape[2:]) == tuple(sample_config.image_size)

model = ViT3DSimMIM(sample_config)

# Parameter counts: (ViT backbone, pixel decoder)
sum(x.numel() for x in model.vit.parameters()), sum(x.numel() for x in model.decoder.parameters())

[1m([0m[1;36m5523076[0m, [1;36m788480[0m[1m)[0m

In [None]:
# Deliberately large SGD step with a step-wise decay: the goal is only to check that one batch can be overfit.
LR, LR_DECAY_EVERY, LR_DECAY_FACTOR = 0.5, 5, 0.9
optimizer = torch.optim.SGD(params=model.parameters(), lr=LR)
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=LR_DECAY_EVERY, gamma=LR_DECAY_FACTOR)

In [None]:
# Fall back to CPU so the notebook still runs top-to-bottom on machines without a GPU. Module.to() moves the
# parameters in place, so the optimizer created above keeps tracking them.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
sample_batch = sample_batch.to(device)
sample_spacings = sample_spacings.to(device)
model = model.to(device)

In [None]:
NUM_STEPS = 200
losses = []

progress = tqdm(range(NUM_STEPS))
for _ in progress:
    optimizer.zero_grad()
    _, loss, _ = model(sample_batch, sample_spacings)
    # Read the LR before scheduler.step() so it is the one actually used for this update.
    lr = scheduler.get_last_lr()[0]
    loss.backward()
    optimizer.step()
    scheduler.step()

    losses.append(loss.item())
    # Show progress in the tqdm bar instead of printing one line per step.
    progress.set_postfix(loss=f"{losses[-1]:f}", lr=f"{lr:f}")

# For an overfitting test the loss should drop substantially between the first and last steps.
print(f"Loss: first {losses[0]:f}\tmin {min(losses):f}\tlast {losses[-1]:f}")

  0%|          | 0/200 [00:00<?, ?it/s]

Loss: 2.595103	LR: 0.500000
Loss: 2.010396	LR: 0.500000
Loss: 1.866450	LR: 0.500000
Loss: 1.758274	LR: 0.500000
Loss: 1.708617	LR: 0.500000
Loss: 1.616564	LR: 0.450000
Loss: 1.497164	LR: 0.450000
Loss: 1.458940	LR: 0.450000
Loss: 1.453419	LR: 0.450000
Loss: 1.416709	LR: 0.450000
Loss: 1.373173	LR: 0.405000
Loss: 1.327370	LR: 0.405000
Loss: 1.344934	LR: 0.405000
Loss: 1.332798	LR: 0.405000
Loss: 1.291747	LR: 0.405000
Loss: 1.290107	LR: 0.364500
Loss: 1.272306	LR: 0.364500
Loss: 1.237117	LR: 0.364500
Loss: 1.230260	LR: 0.364500
Loss: 1.220406	LR: 0.364500
Loss: 1.229550	LR: 0.328050
Loss: 1.217688	LR: 0.328050
Loss: 1.195621	LR: 0.328050
Loss: 1.182000	LR: 0.328050
Loss: 1.179293	LR: 0.328050
Loss: 1.177949	LR: 0.295245
Loss: 1.169717	LR: 0.295245
Loss: 1.165436	LR: 0.295245
Loss: 1.167471	LR: 0.295245
Loss: 1.160034	LR: 0.295245
Loss: 1.156658	LR: 0.265721
Loss: 1.145080	LR: 0.265721
Loss: 1.137138	LR: 0.265721
Loss: 1.134528	LR: 0.265721
Loss: 1.136742	LR: 0.265721
Loss: 1.133928	LR: 0

In [None]:
# Parameters that received no gradient from the last backward pass, i.e. that the loss likely does not depend on.
params_without_grad = [name for name, param in model.named_parameters() if param.grad is None]
for name in params_without_grad:
    print(name)

vit.encoder.layers.0.attn.logit_scale
vit.encoder.layers.1.attn.logit_scale
vit.encoder.layers.2.attn.logit_scale
vit.encoder.layers.3.attn.logit_scale


# nbdev

In [None]:
# Write every `# | export` cell of this notebook into the module named by `# | default_exp` (nets/vit_3d).
!nbdev_export