In [1]:
! pip install vit-pytorch

Collecting vit-pytorch
  Downloading vit_pytorch-1.11.7-py3-none-any.whl.metadata (69 kB)
[?25l     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m0.0/69.7 kB[0m [31m?[0m eta [36m-:--:--[0m[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m69.7/69.7 kB[0m [31m1.9 MB/s[0m eta [36m0:00:00[0m
Downloading vit_pytorch-1.11.7-py3-none-any.whl (142 kB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m142.7/142.7 kB[0m [31m5.3 MB/s[0m eta [36m0:00:00[0m
[?25hInstalling collected packages: vit-pytorch
Successfully installed vit-pytorch-1.11.7


In [8]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from vit_pytorch.simple_vit_3d import SimpleViT
from vit_pytorch.simple_vit_3d import posemb_sincos_3d
from einops import rearrange


In [9]:
# Reference classifier: the stock 3D SimpleViT configured for short video clips.
vit_3d = SimpleViT(
    image_size=128,        # image size
    frames=16,             # number of frames
    image_patch_size=16,   # image patch size
    frame_patch_size=2,    # frame patch size
    num_classes=1000,
    dim=1024,
    depth=6,
    heads=8,
    mlp_dim=2048,
)

# Example usage (left commented out so the cell only builds the model):
#   video = torch.randn(4, 3, 16, 128, 128)  # (batch, channels, frames, height, width)
#   preds = vit_3d(video)                     # output size will be (4, 1000)

In [34]:
class SimpleViT3dSeg(SimpleViT):
    """3D SimpleViT variant that yields per-patch embeddings for segmentation.

    ``forward`` is overridden so that neither token pooling nor the parent's
    linear classification head is applied; the transformer output for every
    patch is kept instead.
    """

    def forward(self, video, return_grid=False):
        """Embed a volume/video and return its per-patch embeddings.

        Args:
            video: input tensor accepted by ``self.to_patch_embedding``
                (in this notebook: ``(batch, channels, frames, height, width)``).
            return_grid: when ``False`` (default) return the flat token
                sequence ``(batch, num_patches, dim)``, exactly as before.
                When ``True`` return the same embeddings laid out on the
                patch grid, ``(batch, dim, grid_depth, grid_height, grid_width)``,
                which is the channel-first layout a convolutional
                segmentation decoder would consume.

        Returns:
            A tensor in one of the two layouts described above.

        Note:
            The three shape ``print`` calls are kept so the notebook output
            stays the same; they run on every forward pass.
        """
        patches = self.to_patch_embedding(video)
        print(f"patch shape: {patches.shape}")
        # These are the NUMBER of patches along each axis (not the patch size).
        _, grid_depth, grid_height, grid_width, _ = patches.shape

        pos_emb = posemb_sincos_3d(patches)
        # Flatten the (depth, height, width) patch grid into one token axis.
        tokens = rearrange(patches, 'b ... d -> b (...) d') + pos_emb

        tokens = self.transformer(tokens)
        print(f"raw embeddings shape: {tokens.shape}")
        batch_size, _, embed_dim = tokens.shape

        # Convert to per-patch embedding format for segmentation: move the
        # embedding axis in front of the token axis, then unflatten the tokens
        # back onto the patch grid (same axis order they were flattened in).
        # credit: code for re-arranging to per-patch format generated from chatgpt
        feat_grid = (
            tokens.transpose(1, 2)
            .contiguous()
            .view(batch_size, embed_dim, grid_depth, grid_height, grid_width)
        )
        print(f"per-patch embedding dim: {feat_grid.shape}")

        # No pooling (x.mean(dim=1)) here: we want the per-patch embeddings.
        # to_latent is treated as a placeholder that does nothing, so it is
        # kept for parity with the parent class -- TODO confirm for the
        # installed vit-pytorch version.
        tokens = self.to_latent(tokens)

        return feat_grid if return_grid else tokens

In [35]:
# Segmentation backbone on a cubic volume: the "frames" axis is the slice/depth
# axis, and patches are cubic (frame patch size == image patch size).
vit_3d_seg = SimpleViT3dSeg(
    image_size=128,        # image size
    frames=128,            # for volumetric data: this is slice number/depth
    image_patch_size=16,   # image patch size
    frame_patch_size=16,   # for volumetric data: this should be same as image patch size
    num_classes=1000,
    dim=1024,
    depth=6,
    heads=8,
    mlp_dim=2048,
)

# Smoke test with a random batch of 4 three-channel 128x128x128 volumes.
video = torch.randn(4, 3, 128, 128, 128)

# Despite the name, these are per-patch embeddings, not class predictions.
preds = vit_3d_seg(video)

print(preds.shape)

patch shape: torch.Size([4, 8, 8, 8, 1024])
raw embeddings shape: torch.Size([4, 512, 1024])
per-patch embedding dim: torch.Size([4, 1024, 8, 8, 8])
torch.Size([4, 512, 1024])
