In [1]:
! pip install vit-pytorch

Collecting vit-pytorch
  Downloading vit_pytorch-1.11.7-py3-none-any.whl.metadata (69 kB)
[?25l     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m0.0/69.7 kB[0m [31m?[0m eta [36m-:--:--[0m[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m69.7/69.7 kB[0m [31m2.9 MB/s[0m eta [36m0:00:00[0m
Downloading vit_pytorch-1.11.7-py3-none-any.whl (142 kB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m142.7/142.7 kB[0m [31m7.8 MB/s[0m eta [36m0:00:00[0m
[?25hInstalling collected packages: vit-pytorch
Successfully installed vit-pytorch-1.11.7


In [2]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from vit_pytorch.simple_vit_3d import SimpleViT
from vit_pytorch.simple_vit_3d import posemb_sincos_3d
from einops import rearrange

import torch.optim as optim
from torch.utils.data import DataLoader, TensorDataset


In [3]:
vit_3d = SimpleViT(
    image_size = 128,          # image size
    frames = 16,               # number of frames
    image_patch_size = 16,     # image patch size
    frame_patch_size = 2,      # frame patch size
    num_classes = 1000,
    dim = 1024,
    depth = 6,
    heads = 8,
    mlp_dim = 2048
)

#example input:
#video = torch.randn(4, 3, 16, 128, 128) # (batch, channels, frames, height, width)

#embedding size will be (4, 1000)
#preds = vit_3d(video) # (4, 1000)

In [4]:
class SingleDeconv3DBlock(nn.Module):
    def __init__(self, in_planes, out_planes):
        super().__init__()
        self.block = nn.ConvTranspose3d(in_planes, out_planes, kernel_size=2, stride=2, padding=0, output_padding=0)

    def forward(self, x):
        return self.block(x)


class SingleConv3DBlock(nn.Module):
    def __init__(self, in_planes, out_planes, kernel_size):
        super().__init__()
        self.block = nn.Conv3d(in_planes, out_planes, kernel_size=kernel_size, stride=1,
                               padding=((kernel_size - 1) // 2))

    def forward(self, x):
        return self.block(x)


class Conv3DBlock(nn.Module):
    def __init__(self, in_planes, out_planes, kernel_size=3):
        super().__init__()
        self.block = nn.Sequential(
            SingleConv3DBlock(in_planes, out_planes, kernel_size),
            nn.BatchNorm3d(out_planes),
            nn.ReLU(True)
        )

    def forward(self, x):
        return self.block(x)


class Deconv3DBlock(nn.Module):
    def __init__(self, in_planes, out_planes, kernel_size=3):
        super().__init__()
        self.block = nn.Sequential(
            SingleDeconv3DBlock(in_planes, out_planes),
            SingleConv3DBlock(out_planes, out_planes, kernel_size),
            nn.BatchNorm3d(out_planes),
            nn.ReLU(True)
        )

    def forward(self, x):
        return self.block(x)

In [5]:
class SimpleViT3dSeg(SimpleViT):
  #add decoder attributes for segmentation
  def __init__(self, *args, **kwargs):
    super().__init__(*args, **kwargs)

    print(kwargs['dim'], kwargs['num_classes'])
    #upsampling decoder, from patch size 8x8x8 to 128x128x128
    self.decoder = nn.Sequential(
      Deconv3DBlock(kwargs['dim'], 256),
      Deconv3DBlock(256, 128),
      Deconv3DBlock(128, 64),
      nn.Conv3d(in_channels=64, out_channels=kwargs['num_classes'], kernel_size=1)
    )

  #override the forward function so embeddings get fed into decoder instead of linear classification head
  def forward(self, video):
        *_, h, w, dtype = *video.shape, video.dtype

        x = self.to_patch_embedding(video)
        print(f"patch shape: {x.shape}")
        _, depth_patch_size, height_patch_size, width_patch_size, _ = x.shape

        pe = posemb_sincos_3d(x)
        x = rearrange(x, 'b ... d -> b (...) d') + pe

        x = self.transformer(x)
        print(f"raw embeddings shape: {x.shape}")
        batch_size, patch_volume, embd_size = x.shape

        #convert to per-patch embedding format for segmentation
        #credit: code for re-arranging to per-patch format generated from chatgpt
        feat_grid_embeddings = x.transpose(1, 2).contiguous().view(batch_size, embd_size, depth_patch_size, height_patch_size, width_patch_size)
        print(f"per-patch embedding dim: {feat_grid_embeddings.shape}")

        #x = x.mean(dim = 1) #don't apply pooling since we want the per-patch embedding
        feat_grid_embeddings = self.to_latent(feat_grid_embeddings) #this is a palce holder, does nothing so we can keep it

        logits = self.decoder(feat_grid_embeddings)
        return logits

In [6]:
test_vit_seg_model = SimpleViT3dSeg(
    image_size = 64,          # image size
    frames = 64,               # for volumetric data: this is slice number/depth
    image_patch_size = 8,     # image patch size
    frame_patch_size = 8,      # for volumetric data: this should be same as image patch size
    num_classes = 3,
    dim = 1024,
    depth = 6,
    heads = 8,
    mlp_dim = 2048,
)

test_input = torch.randn(1, 3, 64, 64, 64)

preds = test_vit_seg_model(test_input)

print(preds.shape, type(preds))

1024 3
patch shape: torch.Size([1, 8, 8, 8, 1024])
raw embeddings shape: torch.Size([1, 512, 1024])
per-patch embedding dim: torch.Size([1, 1024, 8, 8, 8])
torch.Size([1, 3, 64, 64, 64]) <class 'torch.Tensor'>


In [7]:
#training loop
vit_3d_seg = SimpleViT3dSeg(
    image_size = 64,
    frames = 64,
    image_patch_size = 8,
    frame_patch_size = 8,
    num_classes = 3,
    dim = 1024,
    depth = 6,
    heads = 8,
    mlp_dim = 2048,
)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
epochs = 5
ce_loss = nn.CrossEntropyLoss()
vit_3d_seg.to(device)
optimizer = optim.Adam(vit_3d_seg.parameters(), lr=1e-3, weight_decay=1e-4)


#dummy data
X = torch.randn(8, 3, 64, 64, 64)            # (B, C, T, H, W)
Y = torch.randint(0, 3, (8, 64, 64, 64))     # (B, T, H, W) with 3 classes
train_loader = DataLoader(TensorDataset(X, Y), batch_size=2, shuffle=True)

for ep in range(epochs):
  vit_3d_seg.train()

  for x, y in train_loader:
    x, y = x.to(device), y.to(device)
    optimizer.zero_grad()
    logits = vit_3d_seg(x)
    loss = ce_loss(logits, y)
    loss.backward()
    optimizer.step()

  print(f"epoch: {ep}, loss: {loss.item():.4f}")


1024 3
patch shape: torch.Size([2, 8, 8, 8, 1024])
raw embeddings shape: torch.Size([2, 512, 1024])
per-patch embedding dim: torch.Size([2, 1024, 8, 8, 8])
patch shape: torch.Size([2, 8, 8, 8, 1024])
raw embeddings shape: torch.Size([2, 512, 1024])
per-patch embedding dim: torch.Size([2, 1024, 8, 8, 8])
patch shape: torch.Size([2, 8, 8, 8, 1024])
raw embeddings shape: torch.Size([2, 512, 1024])
per-patch embedding dim: torch.Size([2, 1024, 8, 8, 8])
patch shape: torch.Size([2, 8, 8, 8, 1024])
raw embeddings shape: torch.Size([2, 512, 1024])
per-patch embedding dim: torch.Size([2, 1024, 8, 8, 8])
epoch: 0, loss: 1.1150
patch shape: torch.Size([2, 8, 8, 8, 1024])
raw embeddings shape: torch.Size([2, 512, 1024])
per-patch embedding dim: torch.Size([2, 1024, 8, 8, 8])
patch shape: torch.Size([2, 8, 8, 8, 1024])
raw embeddings shape: torch.Size([2, 512, 1024])
per-patch embedding dim: torch.Size([2, 1024, 8, 8, 8])
patch shape: torch.Size([2, 8, 8, 8, 1024])
raw embeddings shape: torch.Size