## we will implement vitdet3d with FPN here 

In [1]:
#| default_exp networks/vitdet3d

In [1]:
#| export 
import torch
import fastcore.all as fc
import torch.nn as nn

from medct.vitdet3d import VitDet3dBackbone, VitDetConfig, VitDet3dLayerNorm
from collections import OrderedDict

In [2]:
config = VitDetConfig(image_size=(96, 192, 192), 
                      patch_size=(4, 8, 8), 
                      hidden_size=96,
                      num_channels=1,
                      use_relative_position_embeddings=True, 
                      window_block_indices=list(range(4)),
                      window_size =(4, 4, 4), 
                      out_indices = [2, 4], 
                      num_hidden_layers= 4,
                      out_features = ["stage2", "stage4"], 
                      stage_names = ["stem"]+[f"stage{i}" for i in range(1, 5)])
model = VitDet3dBackbone(config)


In [3]:
model

VitDet3dBackbone(
  (embeddings): ViTDet3dEmbeddings(
    (projection): Conv3d(1, 96, kernel_size=(4, 8, 8), stride=(4, 8, 8))
  )
  (encoder): VitDet3dEncoder(
    (layer): ModuleList(
      (0-3): 4 x VitDet3dLayer(
        (norm1): LayerNorm((96,), eps=1e-06, elementwise_affine=True)
        (attention): VitDet3dAttention(
          (qkv): Linear(in_features=96, out_features=288, bias=True)
          (proj): Linear(in_features=96, out_features=96, bias=True)
        )
        (drop_path): Identity()
        (norm2): LayerNorm((96,), eps=1e-06, elementwise_affine=True)
        (mlp): VitDet3dMlp(
          (fc1): Linear(in_features=96, out_features=384, bias=True)
          (act): GELUActivation()
          (fc2): Linear(in_features=384, out_features=96, bias=True)
          (drop): Dropout(p=0.0, inplace=False)
        )
      )
    )
  )
)

In [4]:
x = torch.randn((1, 1, 96, 192, 192))
out = model(x, output_hidden_states=True)

In [5]:
[i.shape for i in out.feature_maps]

[torch.Size([1, 96, 24, 24, 24]), torch.Size([1, 96, 24, 24, 24])]

In [6]:
[i.shape for i in out.hidden_states]

[torch.Size([1, 96, 24, 24, 24]),
 torch.Size([1, 96, 24, 24, 24]),
 torch.Size([1, 96, 24, 24, 24]),
 torch.Size([1, 96, 24, 24, 24]),
 torch.Size([1, 96, 24, 24, 24])]

## Simple FPN3D 

> we got a stride of (4, 8, 8) from the backbone. Now we need to 
- downsample this to 12 and 6
- keep it the same as 24
- upsample it to 48

so that we have 4 layers to process through

In [7]:
scale_factors = [0.25, 0.5, 1, 2]
dim = 96
out_channels = 256

## Scale it to 2 

In [8]:
out_dim = dim // 2 
layers = [
    nn.ConvTranspose3d(dim, dim // 2, kernel_size=2, stride=2), 
    nn.Conv3d(out_dim, out_channels, kernel_size=1, bias=None), 
    nn.Conv3d(out_channels, out_channels, kernel_size=3, padding=1, bias=None)
]
layers = nn.Sequential(*layers)
layers

Sequential(
  (0): ConvTranspose3d(96, 48, kernel_size=(2, 2, 2), stride=(2, 2, 2))
  (1): Conv3d(48, 256, kernel_size=(1, 1, 1), stride=(1, 1, 1), bias=False)
  (2): Conv3d(256, 256, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1), bias=False)
)

In [9]:
layers(out.feature_maps[0]).shape

torch.Size([1, 256, 48, 48, 48])

## Scale to 1

In [10]:
out_dim = dim
layers = [
    nn.Conv3d(out_dim, out_channels, kernel_size=1, bias=None), 
    VitDet3dLayerNorm(out_channels), 
    nn.Conv3d(out_channels, out_channels, kernel_size=3, padding=1, bias=None),
    VitDet3dLayerNorm(out_channels), 
]
layers = nn.Sequential(*layers)
layers(out.feature_maps[0]).shape

torch.Size([1, 256, 24, 24, 24])

## Scale to 0.5

In [11]:
out_dim = dim
layers = [
    nn.MaxPool3d(kernel_size=2, stride=2),
    nn.Conv3d(out_dim, out_channels, kernel_size=1, bias=None), 
    VitDet3dLayerNorm(out_channels), 
    nn.Conv3d(out_channels, out_channels, kernel_size=3, padding=1, bias=None),
    VitDet3dLayerNorm(out_channels), 
]
layers = nn.Sequential(*layers)
layers(out.feature_maps[0]).shape

torch.Size([1, 256, 12, 12, 12])

In [12]:
out_dim = dim
layers = [
    nn.Conv3d(out_dim, out_dim, kernel_size=2, stride=2),
    VitDet3dLayerNorm(out_dim), 
    nn.GELU(),
    nn.Conv3d(out_dim, out_channels, kernel_size=1, bias=None), 
    VitDet3dLayerNorm(out_channels), 
    nn.Conv3d(out_channels, out_channels, kernel_size=3, padding=1, bias=None),
    VitDet3dLayerNorm(out_channels), 
]
layers = nn.Sequential(*layers)
layers(out.feature_maps[0]).shape

torch.Size([1, 256, 12, 12, 12])

## Scale to 0.25 

In [13]:
out_dim = dim
layers = [
    nn.Conv3d(out_dim, out_dim, kernel_size=2, stride=2),
    VitDet3dLayerNorm(out_dim), 
    nn.GELU(),
    nn.Conv3d(out_dim, out_dim, kernel_size=2, stride=2),
    VitDet3dLayerNorm(out_dim), 
    nn.GELU(),
    nn.Conv3d(out_dim, out_channels, kernel_size=1, bias=None), 
    VitDet3dLayerNorm(out_channels), 
    nn.Conv3d(out_channels, out_channels, kernel_size=3, padding=1, bias=None),
    VitDet3dLayerNorm(out_channels), 
]
layers = nn.Sequential(*layers)
layers(out.feature_maps[0]).shape

torch.Size([1, 256, 6, 6, 6])

## Combine everything 

In [14]:
#| export 
def conv3d_reduce(in_dim, out_dim):
    layers = [
                        nn.Conv3d(in_dim, out_dim, kernel_size=2, stride=2),
                        VitDet3dLayerNorm(out_dim), 
                        nn.GELU()
                ]
    return nn.Sequential(*layers)

In [15]:
#| export 
# copied from https://github.com/facebookresearch/detectron2/blob/main/detectron2/modeling/backbone/vit.py
class SimpleFeaturePyramidNetwork(torch.nn.Module):
    def __init__(self, dim, out_channels, scales):
        super().__init__()
        self.scales = sorted(scales)[::-1]
        ##Scales should be always from high to low 
        for n, scale in enumerate(self.scales):
            if scale not in [2, 1, 0.5, 0.25]: raise NotImplementedError("These modules are not implemented.")
            if scale == 2:
                out_dim = dim // 2 
                layers = [nn.ConvTranspose3d(dim, dim // 2, kernel_size=2, stride=2)]
            
            if scale == 1:
                out_dim = dim
                layers = []
            
            if scale == 0.5:
                out_dim = dim 
                layers = [conv3d_reduce(out_dim, out_dim)] 
                
            if scale == 0.25:
                out_dim = dim
                layers = [conv3d_reduce(out_dim, out_dim), 
                          conv3d_reduce(out_dim, out_dim)]      
                
                
            
            layers.extend([
                nn.Conv3d(out_dim, out_channels, kernel_size=1, bias=None), 
                VitDet3dLayerNorm(out_channels), 
                nn.Conv3d(out_channels, out_channels, kernel_size=3, padding=1, bias=None),
                VitDet3dLayerNorm(out_channels), 
            ])
            
            layers = nn.Sequential(*layers)
            self.add_module(f"layer{n+1}", layers)
    
    def forward(self, x):
        out = OrderedDict()
        for n, _ in enumerate(self.scales):
            out[f"layer{n+1}"] = getattr(self, f"layer{n+1}")(x)
        return out    

In [16]:
sfpn = SimpleFeaturePyramidNetwork(dim=96, out_channels=256, scales=[0.25, 0.5, 1, 2])
sfpn

SimpleFeaturePyramidNetwork(
  (layer1): Sequential(
    (0): ConvTranspose3d(96, 48, kernel_size=(2, 2, 2), stride=(2, 2, 2))
    (1): Conv3d(48, 256, kernel_size=(1, 1, 1), stride=(1, 1, 1), bias=False)
    (2): VitDet3dLayerNorm()
    (3): Conv3d(256, 256, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1), bias=False)
    (4): VitDet3dLayerNorm()
  )
  (layer2): Sequential(
    (0): Conv3d(96, 256, kernel_size=(1, 1, 1), stride=(1, 1, 1), bias=False)
    (1): VitDet3dLayerNorm()
    (2): Conv3d(256, 256, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1), bias=False)
    (3): VitDet3dLayerNorm()
  )
  (layer3): Sequential(
    (0): Sequential(
      (0): Conv3d(96, 96, kernel_size=(2, 2, 2), stride=(2, 2, 2))
      (1): VitDet3dLayerNorm()
      (2): GELU(approximate='none')
    )
    (1): Conv3d(96, 256, kernel_size=(1, 1, 1), stride=(1, 1, 1), bias=False)
    (2): VitDet3dLayerNorm()
    (3): Conv3d(256, 256, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1,

In [17]:
fpn_out = sfpn(out.feature_maps[0])

In [18]:
[(k, v.shape) for k, v in fpn_out.items()]

[('layer1', torch.Size([1, 256, 48, 48, 48])),
 ('layer2', torch.Size([1, 256, 24, 24, 24])),
 ('layer3', torch.Size([1, 256, 12, 12, 12])),
 ('layer4', torch.Size([1, 256, 6, 6, 6]))]

In [19]:
count = 0
for name, params in sfpn.named_parameters():
    count+=params.numel()
count #Around 7 million and this is fine

7426960

In [20]:
#| export 
class VitDet3dBackbonewithFPN3D(nn.Module):
    def __init__(self, backbone_cfg, scales=[2, 1, 0.5, 0.25], out_channels=256):
        super().__init__()
        fc.store_attr(names=["backbone_cfg", "scales", "out_channels"])
        from omegaconf import DictConfig, OmegaConf #during inference self.backbone_cfg is DictConfig which is not supported by transformers.
        if isinstance(self.backbone_cfg, DictConfig):
            self.backbone_cfg = OmegaConf.to_object(self.backbone_cfg)
        self.cfg = VitDetConfig(**self.backbone_cfg)
        self.body = VitDet3dBackbone(self.cfg)
        self.fpn = SimpleFeaturePyramidNetwork(self.cfg.hidden_size, out_channels, scales)
        
    def forward(self, x):
        out = self.body(x).feature_maps[-1]
        y = self.fpn(out)
        return y

In [21]:
config = dict(image_size=(96, 192, 192), 
          patch_size=(4, 8, 8), 
          hidden_size=96,
          num_channels=1,
          use_relative_position_embeddings=True, 
          window_block_indices=list(range(4)),
          window_size =(4, 4, 4), 
          out_indices = [2, 4], 
          num_hidden_layers= 4,
          out_features = ["stage2", "stage4"], 
          stage_names = ["stem"]+[f"stage{i}" for i in range(1, 5)])
model = VitDet3dBackbonewithFPN3D(config, scales=[2, 1, 0.5])
model

VitDet3dBackbonewithFPN3D(
  (body): VitDet3dBackbone(
    (embeddings): ViTDet3dEmbeddings(
      (projection): Conv3d(1, 96, kernel_size=(4, 8, 8), stride=(4, 8, 8))
    )
    (encoder): VitDet3dEncoder(
      (layer): ModuleList(
        (0-3): 4 x VitDet3dLayer(
          (norm1): LayerNorm((96,), eps=1e-06, elementwise_affine=True)
          (attention): VitDet3dAttention(
            (qkv): Linear(in_features=96, out_features=288, bias=True)
            (proj): Linear(in_features=96, out_features=96, bias=True)
          )
          (drop_path): Identity()
          (norm2): LayerNorm((96,), eps=1e-06, elementwise_affine=True)
          (mlp): VitDet3dMlp(
            (fc1): Linear(in_features=96, out_features=384, bias=True)
            (act): GELUActivation()
            (fc2): Linear(in_features=384, out_features=96, bias=True)
            (drop): Dropout(p=0.0, inplace=False)
          )
        )
      )
    )
  )
  (fpn): SimpleFeaturePyramidNetwork(
    (layer1): Sequentia

In [22]:
x = torch.randn((1, 1, 96, 192, 192))
fpn_out = model(x)

In [23]:
[(k, v.shape) for k, v in fpn_out.items()]

[('layer1', torch.Size([1, 256, 48, 48, 48])),
 ('layer2', torch.Size([1, 256, 24, 24, 24])),
 ('layer3', torch.Size([1, 256, 12, 12, 12]))]

In [146]:
#| hide
import nbdev; nbdev.nbdev_export()