we will be implementing `convNextv2` here. 

This is already implemented in [medct]. we will use this as backbone here.

In [1]:
#| default_exp networks/convnextv2

In [2]:
#| export 
import torch
import torch.nn as nn
import fastcore.all as fc

from collections import OrderedDict
from medct.convnextv2 import ConvNextV2Model3d, ConvNextV2Config3d
from monai.networks.blocks.feature_pyramid_network import FeaturePyramidNetwork, LastLevelMaxPool

In [3]:
cfg = ConvNextV2Config3d(num_channels=1, image_size=(96, 192, 192), num_stages=3, hidden_sizes=[80, 160, 320], depths=[3, 6, 3])
backbone = ConvNextV2Model3d(cfg)

In [4]:
count = 0
for name, params in backbone.state_dict().items():
    count+=params.numel()
count 

5134400

Atto: depths=[2, 2, 6, 2], dims=[40, 80, 160, 320]  
fempto: depths=[2, 2, 6, 2], dims=[48, 96, 192, 384]  
pico: depths=[2, 2, 6, 2], dims=[64, 128, 256, 512]  
nano: depths=[2, 2, 8, 2], dims=[80, 160, 320, 640]  
tiny: depths=[3, 3, 9, 3], dims=[96, 192, 384, 768]  
base: depths=[3, 3, 27, 3], dims=[128, 256, 512, 1024]  
large: depths=[3, 3, 27, 3], dims=[192, 384, 768, 1536]  
huge: depths=[3, 3, 27, 3], dims=[352, 704, 1408, 2816]  

In [5]:
model_sizes = dict(
atto = dict(depths=[2, 2, 6, 2], dims=[40, 80, 160, 320]), 
fempto = dict(depths=[2, 2, 6, 2], dims=[48, 96, 192, 384]),  
pico= dict(depths=[2, 2, 6, 2], dims=[64, 128, 256, 512]), 
nano= dict(depths=[2, 2, 8, 2], dims=[80, 160, 320, 640]), 
tiny= dict(depths=[3, 3, 9, 3], dims=[96, 192, 384, 768]), 
base= dict(depths=[3, 3, 27, 3], dims=[128, 256, 512, 1024]), 
large= dict(depths=[3, 3, 27, 3], dims=[192, 384, 768, 1536]), 
huge= dict(depths=[3, 3, 27, 3], dims=[352, 704, 1408, 2816]))

In [6]:
for name, params in model_sizes.items():
    cfg = ConvNextV2Config3d(num_channels=1, image_size=(96, 192, 192), num_stages=4, hidden_sizes=params["dims"], depths=params["depths"])
    backbone = ConvNextV2Model3d(cfg)
    count = 0
    for _, paramst in backbone.state_dict().items():
        count+=paramst.numel()
    print(name, params, count)

atto {'depths': [2, 2, 6, 2], 'dims': [40, 80, 160, 320]} 4197800
fempto {'depths': [2, 2, 6, 2], 'dims': [48, 96, 192, 384]} 5885232
pico {'depths': [2, 2, 6, 2], 'dims': [64, 128, 256, 512]} 10107968
nano {'depths': [2, 2, 8, 2], 'dims': [80, 160, 320, 640]} 17329360
tiny {'depths': [3, 3, 9, 3], 'dims': [96, 192, 384, 768]} 31363776
base {'depths': [3, 3, 27, 3], 'dims': [128, 256, 512, 1024]} 95753472
large {'depths': [3, 3, 27, 3], 'dims': [192, 384, 768, 1536]} 210575232
huge {'depths': [3, 3, 27, 3], 'dims': [352, 704, 1408, 2816]} 692885952


> what if we want to use only depth of 2 ?

In [7]:
model_sizes = dict(
atto = dict(depths=[3, 3], dims=[40, 80]), 
fempto = dict(depths=[3, 3], dims=[48, 96]),  
pico= dict(depths=[3, 3], dims=[64, 128]), 
nano= dict(depths=[3, 6], dims=[80, 160]), 
tiny= dict(depths=[3, 6], dims=[96, 192]), 
base= dict(depths=[3, 6], dims=[128, 256]), 
large= dict(depths=[3, 6], dims=[192, 384]), 
huge= dict(depths=[3, 6], dims=[352, 704]))

In [8]:
for name, params in model_sizes.items():
    cfg = ConvNextV2Config3d(num_channels=1, image_size=(96, 192, 192), num_stages=2, hidden_sizes=params["dims"], depths=params["depths"])
    backbone = ConvNextV2Model3d(cfg)
    count = 0
    for _, paramst in backbone.state_dict().items():
        count+=paramst.numel()
    print(name, params, count)

atto {'depths': [3, 3], 'dims': [40, 80]} 349840
fempto {'depths': [3, 3], 'dims': [48, 96]} 472032
pico {'depths': [3, 3], 'dims': [64, 128]} 768640
nano {'depths': [3, 6], 'dims': [80, 160]} 1921600
tiny {'depths': [3, 6], 'dims': [96, 192]} 2662272
base {'depths': [3, 6], 'dims': [128, 256]} 4499968
large {'depths': [3, 6], 'dims': [192, 384]} 9600768
huge {'depths': [3, 6], 'dims': [352, 704]} 30667648


In [9]:
out = backbone(torch.randn((1, 1, 96, 192, 192)), output_hidden_states=True)

In [10]:
out.last_hidden_state.shape

torch.Size([1, 704, 12, 24, 24])

In [11]:
[i.shape for i in out.hidden_states]

[torch.Size([1, 352, 24, 48, 48]),
 torch.Size([1, 352, 24, 48, 48]),
 torch.Size([1, 704, 12, 24, 24])]

In [12]:
#| export 
class ConvNextV2BackbonewithFPN3D(nn.Module):
    def __init__(self, backbone_cfg, returned_layers=[1, 2], out_channels=256, extra_blocks=False):
        super().__init__()
        fc.store_attr(names=["backbone_cfg", "returned_layers", "out_channels", "extra_blocks"])
        from omegaconf import DictConfig, OmegaConf #during inference self.backbone_cfg is DictConfig which is not supported by transformers.
        if isinstance(self.backbone_cfg, DictConfig):
            self.backbone_cfg = OmegaConf.to_object(self.backbone_cfg)
        self.cfg = ConvNextV2Config3d(**self.backbone_cfg)
        self.body = ConvNextV2Model3d(self.cfg)
        self.fpn = FeaturePyramidNetwork( 
            spatial_dims=3, 
            in_channels_list=self.cfg.hidden_sizes,
            out_channels=out_channels,
            extra_blocks=LastLevelMaxPool(3) if extra_blocks else None,
        )
        
    def forward(self, x):
        out = self.body(x, output_hidden_states=True).hidden_states
        out = OrderedDict({f"layer{k}": v for k, v in enumerate(out) if k in self.returned_layers})
        y = self.fpn(out)
        return y

In [13]:
cfg = dict(num_channels=1, image_size=(96, 192, 192), num_stages=3, hidden_sizes=[80, 160, 320], depths=[3, 6, 3])
#cfg = dict(num_channels=1, image_size=(96, 192, 192), num_stages=2, hidden_sizes=[40, 80])
model = ConvNextV2BackbonewithFPN3D(cfg, returned_layers=[1, 2, 3])

In [19]:
model.body.embeddings.patch_embeddings

Conv3d(1, 80, kernel_size=(4, 4, 4), stride=(4, 4, 4))

In [18]:
size_divisible = tuple(2 * s * 2 ** max([1, 2]) for s in model.body.embeddings.patch_embeddings.stride)
size_divisible

(32, 32, 32)

In [18]:
model

ConvNextV2BackbonewithFPN3D(
  (body): ConvNextV2Model3d(
    (embeddings): ConvNextV2Embeddings3d(
      (patch_embeddings): Conv3d(1, 80, kernel_size=(4, 4, 4), stride=(4, 4, 4))
      (layernorm): ConvNextV2LayerNorm3d()
    )
    (encoder): ConvNextV2Encoder3d(
      (stages): ModuleList(
        (0): ConvNextV2Stage3d(
          (downsampling_layer): Identity()
          (layers): Sequential(
            (0): ConvNextV2Layer3d(
              (dwconv): Conv3d(80, 80, kernel_size=(7, 7, 7), stride=(1, 1, 1), padding=(3, 3, 3), groups=80)
              (layernorm): ConvNextV2LayerNorm3d()
              (pwconv1): Linear(in_features=80, out_features=320, bias=True)
              (act): GELUActivation()
              (grn): ConvNextV2GRN3d()
              (pwconv2): Linear(in_features=320, out_features=80, bias=True)
              (drop_path): Identity()
            )
            (1): ConvNextV2Layer3d(
              (dwconv): Conv3d(80, 80, kernel_size=(7, 7, 7), stride=(1, 1, 1), pad

In [16]:
out = model(torch.randn((1, 1, 96, 192, 192)))

In [17]:
[(k, v.shape) for k, v in out.items()]

[('layer1', torch.Size([1, 256, 24, 48, 48])),
 ('layer2', torch.Size([1, 256, 12, 24, 24])),
 ('layer3', torch.Size([1, 256, 6, 12, 12]))]

## Example-2 with extra_blocks=True

In [13]:
cfg = dict(num_channels=1, image_size=(96, 192, 192), num_stages=2, hidden_sizes=[40, 80])
model = ConvNextV2BackbonewithFPN3D(cfg, extra_blocks=True)
out = model(torch.randn((1, 1, 96, 192, 192)))
[(k, v.shape) for k, v in out.items()]

[('layer1', torch.Size([1, 256, 24, 48, 48])),
 ('layer2', torch.Size([1, 256, 12, 24, 24])),
 ('pool', torch.Size([1, 256, 6, 12, 12]))]

## Patch size is different in different dimensions 

In [19]:
cfg = dict(num_channels=1, image_size=(96, 192, 192), patch_size=(2, 4, 4), num_stages=2, hidden_sizes=[40, 80])
model = ConvNextV2BackbonewithFPN3D(cfg, extra_blocks=False)
out = model(torch.randn((1, 1, 96, 192, 192)))
[(k, v.shape) for k, v in out.items()]

[('layer1', torch.Size([1, 256, 48, 48, 48])),
 ('layer2', torch.Size([1, 256, 24, 24, 24]))]

In [20]:
tuple(2 * s * 2 ** max([1, 2]) for s in model.body.embeddings.patch_embeddings.stride)

(16, 32, 32)

In [29]:
fpn_cfg = dict(
    __class_fullname__ ="voxdet.networks.convnextv2.ConvNextV2BackbonewithFPN3D",
    out_channels = 256,
    returned_layers = [1, 2], #from backbone 
    extra_blocks = True)
backbone_cfg = dict(num_channels=1, 
                    image_size=(96, 192, 192), 
                    patch_size=(2, 4, 4), 
                    num_stages=2, 
                    hidden_sizes=[40, 80])


#TODO: size_divisible is pending 

In [30]:
from voxdet.utils import locate_cls
fe = locate_cls(fpn_cfg, return_partial=True)(backbone_cfg=backbone_cfg)

In [31]:
fe

ConvNextV2BackbonewithFPN3D(
  (body): ConvNextV2Model3d(
    (embeddings): ConvNextV2Embeddings3d(
      (patch_embeddings): Conv3d(1, 40, kernel_size=(2, 4, 4), stride=(2, 4, 4))
      (layernorm): ConvNextV2LayerNorm3d()
    )
    (encoder): ConvNextV2Encoder3d(
      (stages): ModuleList(
        (0): ConvNextV2Stage3d(
          (downsampling_layer): Identity()
          (layers): Sequential(
            (0): ConvNextV2Layer3d(
              (dwconv): Conv3d(40, 40, kernel_size=(7, 7, 7), stride=(1, 1, 1), padding=(3, 3, 3), groups=40)
              (layernorm): ConvNextV2LayerNorm3d()
              (pwconv1): Linear(in_features=40, out_features=160, bias=True)
              (act): GELUActivation()
              (grn): ConvNextV2GRN3d()
              (pwconv2): Linear(in_features=160, out_features=40, bias=True)
              (drop_path): Identity()
            )
            (1): ConvNextV2Layer3d(
              (dwconv): Conv3d(40, 40, kernel_size=(7, 7, 7), stride=(1, 1, 1), pad

In [1]:
#| hide
import nbdev; nbdev.nbdev_export()