In [1]:
#| default_exp networks/convnext

In [2]:
#| export 
import torch 
import torch.nn as nn
import torch.nn.functional as F
import fastcore.all as fc

from voxdet.networks.fpn import BackbonewithFPN3D
from voxdet.networks.res_se_net import conv3d

In [3]:
%load_ext autoreload
%autoreload 2

## ConvNext 
The ConvNeXt authors made a series of changes to ResNet to make it as robust as `Swin-Transformers`. We explore some of those changes here.
- Change the stage compute ratio: ResNet50 uses (3, 4, 6, 3) blocks per stage; ConvNeXt changes this to (3, 3, 9, 3). For a ResNet10-sized model, let's change it to `(1, 3, 3, 1)` (note: `convnext10` below actually uses `(1, 1, 1, 1)`).
- "Patchify" stem: a 4x4 conv with stride 4 and 96 output channels (about a 0.1% improvement). The 3D stem below uses a 4x4x4 kernel with stride (2, 4, 4).
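- Inverted bottleneck blocks: a depthwise conv with a large kernel followed by 1x1 convs that first expand and then project the channels (the block implemented below uses a 7x7x7 depthwise kernel):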

```
d3x3, 96x96
1x1, 96x384
1x1, 384x96
```
- Use GELU instead of ReLU, with a single GELU placed only between the two 1x1 convs.
- Use LayerNorm (LN) instead of BatchNorm (BN), and use fewer normalization layers: a single LN before the first 1x1 conv.
- Instead of downsampling inside the stages, downsample in a separate layer between stages, using a 2x2 conv with stride 2.


In [4]:
# dummy single-channel volume used throughout the notebook (N, C, D, H, W)
img = torch.zeros(1, 1, 96, 192, 192)
img.shape

torch.Size([1, 1, 96, 192, 192])

In [5]:
#| export
class LayerNorm3d(nn.LayerNorm):
    """LayerNorm over the channel axis of 5D `NCDHW` tensors.

    `F.layer_norm` normalizes the trailing dimension(s), so the channel axis is moved
    to the end, normalized with this module's weight/bias/eps, and moved back.

    Args:
        num_channels: size of the channel axis (C).
        eps: numerical-stability term added to the variance.
        affine: whether to learn a per-channel scale and shift.
    """
    def __init__(self, num_channels, eps=1e-6, affine=True):
        super().__init__(num_channels, eps=eps, elementwise_affine=affine)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        channels_last = x.permute(0, 2, 3, 4, 1)  # NCDHW -> NDHWC
        normed = F.layer_norm(channels_last, self.normalized_shape, self.weight, self.bias, self.eps)
        return normed.permute(0, 4, 1, 2, 3)  # NDHWC -> NCDHW

In [6]:
#| export 
def drop_path(x, drop_prob: float = 0., training: bool = False, scale_by_keep: bool = True):
    """Stochastic Depth: randomly zero whole samples of a residual branch.

    Copied from https://github.com/rwightman/pytorch-image-models/blob/7d9e321b761a673000af312ad21ef1dec491b1e9/timm/layers/drop.py#L137

    Args:
        x: input tensor; the first axis is treated as the batch axis, any rank works.
        drop_prob: probability of dropping each sample.
        training: the op is the identity unless this is True.
        scale_by_keep: divide kept samples by `1 - drop_prob` so the expectation is unchanged.

    Returns:
        `x` itself when inactive, otherwise `x` multiplied by a per-sample 0/1 (or rescaled) mask.
    """
    if not training or drop_prob == 0.:
        return x
    keep_prob = 1 - drop_prob
    # one Bernoulli draw per sample, broadcast over every remaining axis
    mask_shape = (x.shape[0], *([1] * (x.ndim - 1)))
    mask = x.new_empty(mask_shape).bernoulli_(keep_prob)
    if scale_by_keep and keep_prob > 0.0:
        mask.div_(keep_prob)
    return x * mask

In [7]:
#| export
class DropPath(nn.Module):
    """Module form of `drop_path` (Stochastic Depth) for residual branches.

    Uses the module's own `training` flag, so it is the identity in eval mode.

    Args:
        drop_prob: probability of dropping each sample.
        scale_by_keep: rescale kept samples by `1 / (1 - drop_prob)`.
    """
    def __init__(self, drop_prob: float = 0., scale_by_keep: bool = True):
        super().__init__()
        fc.store_attr()

    __repr__ = fc.basic_repr("drop_prob, scale_by_keep")

    def forward(self, x):
        return drop_path(x, drop_prob=self.drop_prob, training=self.training, scale_by_keep=self.scale_by_keep)

In [8]:
x = DropPath(drop_prob=0.1, scale_by_keep=True)  # demo module used in the next cell

In [9]:
# `x` is a freshly built module, so it is in training mode and drop_path may zero the sample
%time x(img).shape

CPU times: user 95.2 ms, sys: 0 ns, total: 95.2 ms
Wall time: 32.7 ms


torch.Size([1, 1, 96, 192, 192])

### Stem 

In [10]:
# 3D patchify stem: 4x4x4 kernel, stride 2 along depth and 4 in-plane; depth padded by 1
stem = conv3d(1, 96, ks=4, stride=(2, 4, 4), padding=(1, 0, 0), norm=LayerNorm3d)
stem

Sequential(
  (0): Conv3d(1, 96, kernel_size=(4, 4, 4), stride=(2, 4, 4), padding=(1, 0, 0), bias=False)
  (1): LayerNorm3d((96,), eps=1e-06, elementwise_affine=True)
)

In [11]:
# run the stem on the dummy volume; the recorded output shape is (1, 96, 48, 48, 48)
outs = stem(img)
outs.shape

torch.Size([1, 96, 48, 48, 48])

### Depthwise conv 
Requires the `fvcore` package for FLOP counting (`pip install fvcore`).

In [12]:
# same 96->96 conv in depthwise (groups=96) and dense form, plus a dense 64->64 for reference
dconv = nn.Conv3d(96, 96, kernel_size=4, stride=2, groups=96)
nconv = nn.Conv3d(96, 96, kernel_size=4, stride=2)
tconv = nn.Conv3d(64, 64, kernel_size=4, stride=2)

In [13]:
# Compare FLOPs of the depthwise vs. dense convs defined above on the stem output.
# NOTE(review): needs the third-party `fvcore` package; consider moving this import to the top import cell.
from fvcore.nn import FlopCountAnalysis
flops = FlopCountAnalysis(dconv, outs)  # depthwise 96->96
flops2 = FlopCountAnalysis(nconv, outs)  # dense 96->96
flops3 = FlopCountAnalysis(tconv, torch.zeros((1, 64, 48, 48, 48)))  # dense 64->64
# last two values: dense/depthwise ratio (equals the group count, 96) and dense-64/depthwise ratio
flops.total(), flops2.total(), flops3.total(), flops2.total()/flops.total(), flops3.total()/flops.total()

(74754048, 7176388608, 3189506048, 96.0, 42.666666666666664)

### ConvNextBlock 

In [14]:
#| export
class ConvNextBlock(nn.Module):
    """ConvNeXt residual block for 3D `NCDHW` feature maps; output shape equals input shape.

    Residual branch: 7x7x7 depthwise conv -> LayerNorm3d -> 1x1x1 conv (dim -> 4*dim)
    -> GELU -> 1x1x1 conv (4*dim -> dim) -> DropPath (or Identity), added back to the input.

    Args:
        dim: number of input/output channels.
        drop_path: stochastic-depth probability for the residual branch; 0 disables it.
    """
    def __init__(self, dim, drop_path=0.):
        super().__init__()
        self.dwconv = nn.Conv3d(dim, dim, kernel_size=7, padding=3, groups=dim)  # depthwise conv, keeps spatial size
        self.norm = LayerNorm3d(dim, eps=1e-6)
        # pointwise (1x1x1) convs applied directly on NCDHW tensors, so no channels-last permute is needed
        self.pwconv1 = nn.Conv3d(dim, 4 * dim, kernel_size=1)
        self.act = nn.GELU()
        self.pwconv2 = nn.Conv3d(4 * dim, dim, kernel_size=1)
        self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()

    def forward(self, x):
        shortcut = x
        x = self.dwconv(x)
        x = self.norm(x)
        x = self.pwconv1(x)
        x = self.act(x)
        x = self.pwconv2(x)
        return shortcut + self.drop_path(x)

In [15]:
block = ConvNextBlock(dim=96)  # single block at the stem width
block

ConvNextBlock(
  (dwconv): Conv3d(96, 96, kernel_size=(7, 7, 7), stride=(1, 1, 1), padding=(3, 3, 3), groups=96)
  (norm): LayerNorm3d((96,), eps=1e-06, elementwise_affine=True)
  (pwconv1): Conv3d(96, 384, kernel_size=(1, 1, 1), stride=(1, 1, 1))
  (act): GELU(approximate=none)
  (pwconv2): Conv3d(384, 96, kernel_size=(1, 1, 1), stride=(1, 1, 1))
  (drop_path): Identity()
)

In [16]:
# time one forward pass of a single ConvNextBlock on the stem output
%time block(outs).shape

CPU times: user 22.2 s, sys: 7.85 s, total: 30.1 s
Wall time: 768 ms


torch.Size([1, 96, 48, 48, 48])

### ConvNextStage

In [17]:
#| export 
class ConvNextStage(nn.Module):
    """A stack of `layers` ConvNextBlocks at a fixed channel width `dims`.

    Args:
        dims: number of channels of every block in the stage.
        layers: number of ConvNextBlocks.
        dp_rates: either a list/tuple with one drop-path rate per block, or a scalar
            maximum rate spread linearly from 0 to `dp_rates` across the blocks.
        normalize: optional norm-layer factory (e.g. `LayerNorm3d`). When given, one norm
            layer `norm{i}` is created per block and applied right after `layer{i}`.

    Raises:
        ValueError: if a list/tuple `dp_rates` has fewer entries than `layers`.
    """
    def __init__(self, dims, layers, dp_rates=0, normalize=None):
        super().__init__()  # initialise nn.Module before storing any attributes
        fc.store_attr()
        if not isinstance(dp_rates, (list, tuple)):
            dp_rates = [x.item() for x in torch.linspace(0, dp_rates, layers)]
        if len(dp_rates) < layers:
            raise ValueError(f"expected at least {layers} drop-path rates, got {len(dp_rates)}")
        for i in range(layers):
            setattr(self, f"layer{i}", ConvNextBlock(dims, drop_path=dp_rates[i]))
            if normalize is not None: setattr(self, f"norm{i}", normalize(dims))

    def forward(self, x):
        for i in range(self.layers):
            x = getattr(self, f"layer{i}")(x)
            # each block has its own norm; applying it inside the loop means every created
            # norm layer is used (and no loop variable leaks out when `layers == 0`)
            if self.normalize is not None: x = getattr(self, f"norm{i}")(x)
        return x

In [18]:
stage = ConvNextStage(dims=96, layers=2, dp_rates=0, normalize=LayerNorm3d)  # 2-block demo stage
stage

ConvNextStage(
  (layer0): ConvNextBlock(
    (dwconv): Conv3d(96, 96, kernel_size=(7, 7, 7), stride=(1, 1, 1), padding=(3, 3, 3), groups=96)
    (norm): LayerNorm3d((96,), eps=1e-06, elementwise_affine=True)
    (pwconv1): Conv3d(96, 384, kernel_size=(1, 1, 1), stride=(1, 1, 1))
    (act): GELU(approximate=none)
    (pwconv2): Conv3d(384, 96, kernel_size=(1, 1, 1), stride=(1, 1, 1))
    (drop_path): Identity()
  )
  (norm0): LayerNorm3d((96,), eps=1e-06, elementwise_affine=True)
  (layer1): ConvNextBlock(
    (dwconv): Conv3d(96, 96, kernel_size=(7, 7, 7), stride=(1, 1, 1), padding=(3, 3, 3), groups=96)
    (norm): LayerNorm3d((96,), eps=1e-06, elementwise_affine=True)
    (pwconv1): Conv3d(96, 384, kernel_size=(1, 1, 1), stride=(1, 1, 1))
    (act): GELU(approximate=none)
    (pwconv2): Conv3d(384, 96, kernel_size=(1, 1, 1), stride=(1, 1, 1))
    (drop_path): Identity()
  )
  (norm1): LayerNorm3d((96,), eps=1e-06, elementwise_affine=True)
)

In [19]:
# time a 2-block ConvNextStage on the stem output
%time stage(outs).shape

CPU times: user 49.8 s, sys: 16.2 s, total: 1min 5s
Wall time: 1.71 s


torch.Size([1, 96, 48, 48, 48])

### Adding ConvNext 

In [20]:
depths = [3, 3, 9, 3]
# linearly increasing stochastic-depth rates, one per block, printed split by stage
dp = torch.linspace(0, 0.1, sum(depths)).tolist()
start = 0
for depth in depths:
    print(dp[start:start + depth])
    start += depth

[0.0, 0.0058823530562222, 0.0117647061124444]
[0.01764705963432789, 0.0235294122248888, 0.029411764815449715]
[0.03529411926865578, 0.04117647185921669, 0.0470588244497776, 0.052941177040338516, 0.05882352963089943, 0.06470588594675064, 0.07058823853731155, 0.07647059112787247, 0.08235294371843338]
[0.0882352963089943, 0.0941176488995552, 0.10000000149011612]


In [21]:
#| export 
class ConvNext(nn.Module):
    """3D ConvNeXt backbone: a strided-conv stem followed by `len(depths)` ConvNextStages.

    Children are registered in execution order: `base`, `stage1`, `downsample1`, `stage2`, ...
    `downsample{i}` is a 2x2x2, stride-2 conv mapping `dims[i-1] -> dims[i]` channels, and
    `base` maps `ic -> dims[0]` channels.

    Args:
        ic: number of input channels.
        depths: number of ConvNextBlocks per stage.
        dims: channel width per stage; must have the same length as `depths`.
        c1_ks: kernel size of the stem conv.
        c1_s: stride of the stem conv. The depth axis is padded by 1 when
            `c1_s[0] != c1_s[1]` (e.g. the default `(2, 4, 4)`), otherwise by 0.
        drop_path_rate: maximum stochastic-depth rate; rates increase linearly from 0
            over all blocks of the network.

    Raises:
        ValueError: if `depths` and `dims` have different lengths.
    """
    def __init__(self, ic=1, depths=[3, 3, 9, 3], dims=[96, 192, 384, 768], c1_ks=(4, 4, 4), c1_s=(2, 4, 4),
                 drop_path_rate=0.):
        super().__init__()  # initialise nn.Module before storing any attributes
        # __init__ walks `depths` while forward walks `dims`: they must agree
        if len(depths) != len(dims):
            raise ValueError(f"`depths` and `dims` must have the same length, got {len(depths)} and {len(dims)}")
        fc.store_attr()
        dp_rates = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))]
        pad = 0 if c1_s[0] == c1_s[1] else 1
        start = 0  # offset of the current stage's first block in `dp_rates`
        for i, (depth, dim) in enumerate(zip(depths, dims)):
            # registration order matters: module walkers that replay children rely on it
            if i == 0:
                self.base = conv3d(ic, dim, ks=c1_ks, stride=c1_s, norm=LayerNorm3d, padding=(pad, 0, 0))
            else:
                setattr(self, f"downsample{i}", nn.Conv3d(dims[i-1], dim, kernel_size=2, stride=2, padding=(0, 0, 0)))
            setattr(self, f"stage{i+1}", ConvNextStage(dim, layers=depth, dp_rates=dp_rates[start:start + depth]))
            start += depth

    def forward(self, x):
        out = x
        for i in range(len(self.dims)):
            entry = self.base if i == 0 else getattr(self, f"downsample{i}")
            out = getattr(self, f"stage{i+1}")(entry(out))
        return out

In [22]:
c10 = ConvNext(ic=1, depths=(1,) * 4, dims=[96, 192, 384, 768])  # one block per stage
c10

ConvNext(
  (base): Sequential(
    (0): Conv3d(1, 96, kernel_size=(4, 4, 4), stride=(2, 4, 4), padding=(1, 0, 0), bias=False)
    (1): LayerNorm3d((96,), eps=1e-06, elementwise_affine=True)
  )
  (stage1): ConvNextStage(
    (layer0): ConvNextBlock(
      (dwconv): Conv3d(96, 96, kernel_size=(7, 7, 7), stride=(1, 1, 1), padding=(3, 3, 3), groups=96)
      (norm): LayerNorm3d((96,), eps=1e-06, elementwise_affine=True)
      (pwconv1): Conv3d(96, 384, kernel_size=(1, 1, 1), stride=(1, 1, 1))
      (act): GELU(approximate=none)
      (pwconv2): Conv3d(384, 96, kernel_size=(1, 1, 1), stride=(1, 1, 1))
      (drop_path): Identity()
    )
  )
  (downsample1): Conv3d(96, 192, kernel_size=(2, 2, 2), stride=(2, 2, 2))
  (stage2): ConvNextStage(
    (layer0): ConvNextBlock(
      (dwconv): Conv3d(192, 192, kernel_size=(7, 7, 7), stride=(1, 1, 1), padding=(3, 3, 3), groups=192)
      (norm): LayerNorm3d((192,), eps=1e-06, elementwise_affine=True)
      (pwconv1): Conv3d(192, 768, kernel_size=(1,

In [23]:
# time a full ConvNext-10 forward pass on the dummy volume
%time c10(img).shape

CPU times: user 41.4 s, sys: 8.47 s, total: 49.9 s
Wall time: 1.12 s


torch.Size([1, 768, 6, 6, 6])

### Does `IntermediateLayerGetter` work in this case?

In [24]:
import torchvision

In [25]:
returned_layers = [1, 2, 3, 4]
# child-module name -> output key, e.g. "stage1" -> "0"
return_layers = {f"stage{stage}": str(idx) for idx, stage in enumerate(returned_layers)}

In [26]:
# NOTE(review): `torchvision.models._utils` is a private module. This helper is understood to
# run the model's children in registration order (not `ConvNext.forward`) and return the
# outputs named in `return_layers` — confirm against the installed torchvision version.
body = torchvision.models._utils.IntermediateLayerGetter(c10, return_layers=return_layers)

In [27]:
# forward through the layer getter; the result is a dict of per-stage outputs (see next cell)
# NOTE(review): this rebinds `outs`, which earlier held the stem output tensor
%time outs = body(img)

CPU times: user 38 s, sys: 6.44 s, total: 44.5 s
Wall time: 1.03 s


In [28]:
[(name, feat.shape) for name, feat in outs.items()]

[('0', torch.Size([1, 96, 48, 48, 48])),
 ('1', torch.Size([1, 192, 24, 24, 24])),
 ('2', torch.Size([1, 384, 12, 12, 12])),
 ('3', torch.Size([1, 768, 6, 6, 6]))]

## ConvNext 10

In [29]:
#| export 
def convnext10(ic, dims=[96, 192, 384, 768], c1_ks=(4, 4, 4), c1_s=(2, 4, 4), drop_path_rate=0.):
    "Build a `ConvNext` with a single ConvNextBlock in each of its four stages."
    return ConvNext(ic=ic, depths=(1,) * 4, dims=dims, c1_ks=c1_ks, c1_s=c1_s, drop_path_rate=drop_path_rate)

In [30]:
#| export 
def convnext18(ic, dims=[96, 192, 384, 768], c1_ks=(4, 4, 4), c1_s=(2, 4, 4), drop_path_rate=0.):
    "Build a `ConvNext` with two ConvNextBlocks in each of its four stages."
    return ConvNext(ic=ic, depths=(2,) * 4, dims=dims, c1_ks=c1_ks, c1_s=c1_s, drop_path_rate=drop_path_rate)

In [31]:
#| export 
def convnext50(ic, dims=[96, 192, 384, 768], c1_ks=(4, 4, 4), c1_s=(2, 4, 4), drop_path_rate=0.):
    "Build a `ConvNext` using the ConvNeXt stage compute ratio (3, 3, 9, 3)."
    stage_depths = (3, 3, 9, 3)
    return ConvNext(ic=ic, depths=stage_depths, dims=dims, c1_ks=c1_ks, c1_s=c1_s, drop_path_rate=drop_path_rate)

## With FPN

In [32]:
[c10.dims[-1] // 8 * 2 ** (k - 1) for k in range(1, 5)]

[96, 192, 384, 768]

In [33]:
#| export 
def convnext_fpn3d_feature_extractor(backbone, out_channels=256, returned_layers=[1, 2, 3], extra_blocks:bool=False):
    """Wrap a `ConvNext` backbone in a `BackbonewithFPN3D` over selected stages.

    Args:
        backbone: `ConvNext`-like module exposing `dims` (channels per stage) and
            children named `stage1`, `stage2`, ...
        out_channels: number of FPN output channels.
        returned_layers: 1-based indices of the stages whose outputs feed the FPN.
        extra_blocks: forwarded unchanged to `BackbonewithFPN3D`.

    Returns:
        The `BackbonewithFPN3D` wrapper.

    Raises:
        ValueError: if an entry of `returned_layers` is outside `1..len(backbone.dims)`.
    """
    n_stages = len(backbone.dims)
    bad = [i for i in returned_layers if not 1 <= i <= n_stages]
    if bad:
        raise ValueError(f"returned_layers must be in 1..{n_stages}, got {bad}")
    # read each stage's width from the backbone; deriving it as dims[-1] // 8 * 2**(i-1)
    # only holds for exactly four stages that each double the channel count
    in_channels_list = [backbone.dims[i - 1] for i in returned_layers]
    return_layers = {f"stage{k}": str(v) for v, k in enumerate(returned_layers)}
    return BackbonewithFPN3D(backbone, return_layers, in_channels_list, out_channels, extra_blocks)

In [38]:
# NOTE(review): `extra_blocks` is annotated `bool` in `convnext_fpn3d_feature_extractor`, but 10 is
# passed here and forwarded as-is to `BackbonewithFPN3D` — confirm the intended type/meaning.
fpn = convnext_fpn3d_feature_extractor(c10, extra_blocks=10)

In [40]:
# inspect the stem of the backbone wrapped by the FPN
fpn.body.base

Sequential(
  (0): Conv3d(1, 96, kernel_size=(4, 4, 4), stride=(2, 4, 4), padding=(1, 0, 0), bias=False)
  (1): LayerNorm3d((96,), eps=1e-06, elementwise_affine=True)
)

In [37]:
# NOTE(review): magic constants — this is 8 * (stem stride) per axis, i.e. (16, 32, 32) for the
# (2, 4, 4) stem; document which overall stride this is meant to represent.
tuple(2 * s * 2 ** max([1, 2]) for s in fpn.body.base[0].stride)

(16, 32, 32)

In [None]:
# time the FPN forward pass on the dummy volume (not yet executed in this notebook)
%time out = fpn(img)

In [None]:
[(name, feat.shape) for name, feat in out.items()]

In [None]:
#| hide
# export the `#| export` cells of this notebook to the library module
import nbdev; nbdev.nbdev_export()