In [None]:
#| default_exp networks/res_se_net

## Blogs 
- [About ResNet blog](https://medium.com/@14prakash/understanding-and-implementing-architectures-of-resnet-and-resnext-for-state-of-the-art-image-cf51669e1624)
- [ResNet paper](https://arxiv.org/pdf/1512.03385.pdf)
- [fastai implementation](https://github.com/fastai/course22p2/blob/master/nbs/13_resnet.ipynb)

In [None]:
#| export 
import torch
import torch.nn as nn
import fastcore.all as fc

from functools import partial
from typing import List

from voxdet.activations import GeneralRelu

In [None]:
# Keyword arguments shared by the MONAI reference ResNets built in this notebook
# (3D network, single input channel, 7x7x7 stem conv, max-pool kept).
backbone = dict(
  spatial_dims = 3,
  conv1_t_stride = [2, 2, 1],
  pretrained = False,
  progress = False,
  n_input_channels = 1,
  conv1_t_size = [7, 7, 7],
  no_max_pool= False
)

In [None]:
import monai.networks.nets.resnet as monai_res

In [None]:
m10 = monai_res.resnet10(**backbone)#(torch.ones((1, 1, 3, 224, 224))).shape
m50 = monai_res.resnet50(**backbone)

In [None]:
m10.in_planes

512

In [None]:
n = 0
for name, params in m10.named_parameters(): n+=params.numel()
print(n)

14561616


In [None]:
def count_params(layer):
    """Return the total number of scalar elements over all parameters of `layer`.

    Every tensor yielded by `layer.named_parameters()` contributes its
    `numel()`; a module without parameters yields 0.
    """
    return sum(tensor.numel() for _, tensor in layer.named_parameters())

In [None]:
count_params(m50), count_params(m10)

(46978512, 14561616)

In [None]:
%%time
out = m10.relu(m10.bn1(m10.conv1(torch.ones((1, 1, 192, 192, 96)))))
out.shape

CPU times: user 2.67 s, sys: 385 ms, total: 3.06 s
Wall time: 378 ms


torch.Size([1, 64, 96, 96, 96])

In [None]:
out = m10.maxpool(out)
out.shape

torch.Size([1, 64, 48, 48, 48])

In [None]:
m10.layer2(out).shape

torch.Size([1, 128, 24, 24, 24])

In [None]:
#| export 
# Default activation factory used throughout this module: calling `act_gr()` builds a
# `GeneralRelu` with leak=0.1 and sub=0.4.
act_gr = partial(GeneralRelu, leak=0.1, sub=0.4)

## Conv3D block  - Conv-act-norm: base

In [None]:
#| export 
def conv3d(ni, nf, ks=3, stride=2, act=None, norm=None, bias=False, padding=None, dilation=1):
    """Build a `Conv3d -> norm -> act` stack as an `nn.Sequential`.

    Args:
        ni: number of input channels.
        nf: number of output channels.
        ks: kernel size, an int or a per-axis sequence.
        stride: convolution stride (int or per-axis sequence).
        act: zero-argument factory for the activation layer, or None for no activation.
        norm: factory called as `norm(nf)` for the normalisation layer, or None for none.
        bias: whether the convolution has a bias term.
        padding: explicit padding; when None it defaults to `ks // 2` (per axis).
        dilation: convolution dilation.

    Returns:
        `nn.Sequential` holding the convolution, followed by the norm and the
        activation when they are given. `nn.Identity` layers are left out.
    """
    if padding is not None:
        pad = padding
    elif isinstance(ks, int):
        pad = ks // 2
    else:
        pad = tuple(k // 2 for k in ks)

    conv = nn.Conv3d(ni, nf, kernel_size=ks, stride=stride, padding=pad, dilation=dilation, bias=bias)
    act_layer = act() if act is not None else None
    norm_layer = norm(nf) if norm is not None else None

    stack = [conv]
    # Order in the stack is conv -> norm -> act.
    for extra in (norm_layer, act_layer):
        if extra is not None and not isinstance(extra, nn.Identity):
            stack.append(extra)
    return nn.Sequential(*stack)

> In our case we should probably use `3x3x3`, as nodules can be 3mm in size. Is aggregating with `ks=7` too much loss of information? From the outputs below, since the output dimension does not change much, `ks=7` might actually help us. We need to experiment and check.

In [None]:
base = conv3d(1, 64, (7, 7, 7), (1, 2, 2), act=act_gr, norm=nn.BatchNorm3d)
baseks3 = conv3d(1, 64, (3, 3, 3), (1, 2, 2), act=act_gr, norm=nn.BatchNorm3d)
base

Sequential(
  (0): Conv3d(1, 64, kernel_size=(7, 7, 7), stride=(1, 2, 2), padding=(3, 3, 3), bias=False)
  (1): BatchNorm3d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (2): GeneralRelu: leak:0.1-sub:0.4-maxv:None
)

In [None]:
%%time
img = torch.zeros((1, 1, 96, 192, 192))
base_out = base(img)
base_out.shape

CPU times: user 2.93 s, sys: 496 ms, total: 3.43 s
Wall time: 276 ms


torch.Size([1, 64, 96, 96, 96])

In [None]:
baseks3_out = baseks3(img)
baseks3_out.shape

torch.Size([1, 64, 96, 96, 96])

## basic Block 

> we have `[[3×3,64], [3×3,64]]` blocks. `x2` for `resnet18` and `x3` for `resnet34` and `x1` for `resnet10`

> we will not have activation at the end. this is performed post `skip-connection` addition in the network. 

> `ks=3` as mentioned, which is why we kept it as the default option.

> `stride` is applied to the 1st conv (the 2nd conv always uses stride 1); set it to 2 if you want to decrease the size of the feature map.

In [None]:
#| export 
def _conv3d_block(ni, nf, stride, act=act_gr, norm=None, ks=3):
    """Residual branch of a basic block: two stacked `conv3d` layers.

    The first layer maps `ni` -> `nf` channels with the given `stride` and is
    followed by `act`. The second keeps `nf` channels at stride 1 and has no
    activation of its own.
    """
    first = conv3d(ni, nf, ks=ks, stride=stride, norm=norm, act=act)
    second = conv3d(nf, nf, ks=ks, stride=1, norm=norm, act=None)
    return nn.Sequential(first, second)

In [None]:
m10.layer1

Sequential(
  (0): ResNetBlock(
    (conv1): Conv3d(64, 64, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1), bias=False)
    (bn1): BatchNorm3d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (relu): ReLU(inplace=True)
    (conv2): Conv3d(64, 64, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1), bias=False)
    (bn2): BatchNorm3d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  )
)

In [None]:
block1 = _conv3d_block(64, 64, 1, norm=nn.BatchNorm3d)
block1

Sequential(
  (0): Sequential(
    (0): Conv3d(64, 64, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1), bias=False)
    (1): BatchNorm3d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): GeneralRelu: leak:0.1-sub:0.4-maxv:None
  )
  (1): Sequential(
    (0): Conv3d(64, 64, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1), bias=False)
    (1): BatchNorm3d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  )
)

In [None]:
%time block1(baseks3_out).shape

CPU times: user 9 s, sys: 2.37 s, total: 11.4 s
Wall time: 952 ms


torch.Size([1, 64, 96, 96, 96])

In [None]:
%time block1(base_out).shape

CPU times: user 8.01 s, sys: 1.42 s, total: 9.42 s
Wall time: 793 ms


torch.Size([1, 64, 96, 96, 96])

In [None]:
%time m10.layer1(baseks3_out).shape

CPU times: user 8.11 s, sys: 1.33 s, total: 9.44 s
Wall time: 780 ms


torch.Size([1, 64, 96, 96, 96])

## BottleNeck 

> `[1×1,64 3×3,64] [1x1,256]`.

> The stride is applied in the middle conv layer; use a value >1 if you want to reduce the size of the feature map. `ks` also defaults to 3 here.

> we expand `nf` by exp (default=4) in the final layer. 


In [None]:
#| export 
def _conv3d_bottleneck(ni, nf, stride, exp=4, act=act_gr, norm=None, ks=3):
    """Residual branch of a bottleneck block: 1x1x1 -> `ks` -> 1x1x1 convolutions.

    The first 1x1x1 conv maps `ni` -> `nf`, the middle conv (kernel `ks`) carries
    the `stride`, and the last 1x1x1 conv expands to `nf * exp` channels without
    an activation.
    """
    reduce = conv3d(ni, nf, ks=1, stride=1, norm=norm, act=act)
    spatial = conv3d(nf, nf, ks=ks, stride=stride, norm=norm, act=act)
    expand = conv3d(nf, nf * exp, ks=1, stride=1, norm=norm, act=None)
    return nn.Sequential(reduce, spatial, expand)

In [None]:
getattr(m50.layer1, "0")

ResNetBottleneck(
  (conv1): Conv3d(64, 64, kernel_size=(1, 1, 1), stride=(1, 1, 1), bias=False)
  (bn1): BatchNorm3d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (conv2): Conv3d(64, 64, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1), bias=False)
  (bn2): BatchNorm3d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (conv3): Conv3d(64, 256, kernel_size=(1, 1, 1), stride=(1, 1, 1), bias=False)
  (bn3): BatchNorm3d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (relu): ReLU(inplace=True)
  (downsample): Sequential(
    (0): Conv3d(64, 256, kernel_size=(1, 1, 1), stride=(1, 1, 1))
    (1): BatchNorm3d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  )
)

In [None]:
bottleneck = _conv3d_bottleneck(64, 64, 1, 4, norm=nn.BatchNorm3d)
bottleneck

Sequential(
  (0): Sequential(
    (0): Conv3d(64, 64, kernel_size=(1, 1, 1), stride=(1, 1, 1), bias=False)
    (1): BatchNorm3d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): GeneralRelu: leak:0.1-sub:0.4-maxv:None
  )
  (1): Sequential(
    (0): Conv3d(64, 64, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1), bias=False)
    (1): BatchNorm3d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): GeneralRelu: leak:0.1-sub:0.4-maxv:None
  )
  (2): Sequential(
    (0): Conv3d(64, 256, kernel_size=(1, 1, 1), stride=(1, 1, 1), bias=False)
    (1): BatchNorm3d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  )
)

In [None]:
bottleneck(baseks3_out).shape, bottleneck(base_out).shape, m50.layer1(baseks3_out).shape

(torch.Size([1, 256, 96, 96, 96]),
 torch.Size([1, 256, 96, 96, 96]),
 torch.Size([1, 256, 96, 96, 96]))

In [None]:
%time bottleneck(baseks3_out).shape

CPU times: user 10.7 s, sys: 4.42 s, total: 15.1 s
Wall time: 1.5 s


torch.Size([1, 256, 96, 96, 96])

In [None]:
%time m50.layer1(baseks3_out).shape

CPU times: user 34.2 s, sys: 13.4 s, total: 47.6 s
Wall time: 4.41 s


torch.Size([1, 256, 96, 96, 96])

> The time difference is because each `m50` res-block also does downsampling, addition and then `relu` (and `m50.layer1` stacks three such blocks, while we timed a single bottleneck). We will see how to add this next.

## ResBlock 

> We will initialize a `BasicBlock` or `BottleneckBlock`, and add a `downsample` layer on the skip connection when the block changes the shape of the feature map.

In [None]:
#| export 
class ResBlock(nn.Module):
    """Residual block: `act(convs(x) + downsample(x))`.

    Args:
        ni: number of input channels.
        nf: base width of the block; the block outputs `nf` channels for
            `block_type="basic"` and `nf * 4` for `block_type="bottleneck"`.
        stride: stride of the strided conv in the residual branch (int or per-axis sequence).
        ks: kernel size of the non-1x1x1 convolutions.
        act: zero-argument activation factory.
        norm: normalisation factory called with the channel count.
        block_type: "basic" or "bottleneck".

    Raises:
        NotImplementedError: if `block_type` is not one of the two supported values.
    """
    def __init__(self, ni, nf, stride=1, ks=3, act=act_gr, norm=nn.BatchNorm3d, block_type="basic"):
        super().__init__()
        fc.store_attr()
        if self.block_type not in ["basic", "bottleneck"]: raise NotImplementedError(f"block_type: {self.block_type} missing")
        exp = 4 if block_type!="basic" else 1
        if block_type=="basic": self.convs = _conv3d_block(ni, nf, stride, act=act, ks=ks, norm=norm)
        else: self.convs = _conv3d_bottleneck(ni, nf, stride, exp=exp, act=act, norm=norm, ks=ks)
        # The identity shortcut is only valid when neither the channel count nor the
        # spatial size changes. Checking channels alone would leave a strided block with
        # ni == nf*exp adding tensors of different spatial shapes.
        unit_stride = all(s==1 for s in stride) if isinstance(stride, (tuple, list)) else stride==1
        if ni==nf*exp and unit_stride: self.downsample = fc.noop
        # bias=True on the projection keeps the parameter count equal to the MONAI blocks
        # this notebook compares against.
        else: self.downsample = conv3d(ni, nf*exp, ks=1, stride=stride, norm=norm, act=None, bias=True)
        # Replaces the factory stored by store_attr with the instantiated activation module.
        self.act = act()

    def forward(self, x):
        """Apply the residual branch, add the (possibly projected) input, then activate."""
        return self.act(self.convs(x) + self.downsample(x))

In [None]:
basic = ResBlock(64, 64, ks=3, block_type="basic")
bottleneck = ResBlock(64, 64, ks=3, block_type="bottleneck")

In [None]:
basic

ResBlock(
  (convs): Sequential(
    (0): Sequential(
      (0): Conv3d(64, 64, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1), bias=False)
      (1): BatchNorm3d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (2): GeneralRelu: leak:0.1-sub:0.4-maxv:None
    )
    (1): Sequential(
      (0): Conv3d(64, 64, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1), bias=False)
      (1): BatchNorm3d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    )
  )
  (act): GeneralRelu: leak:0.1-sub:0.4-maxv:None
)

In [None]:
m10.layer1

Sequential(
  (0): ResNetBlock(
    (conv1): Conv3d(64, 64, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1), bias=False)
    (bn1): BatchNorm3d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (relu): ReLU(inplace=True)
    (conv2): Conv3d(64, 64, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1), bias=False)
    (bn2): BatchNorm3d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  )
)

In [None]:
bottleneck

ResBlock(
  (convs): Sequential(
    (0): Sequential(
      (0): Conv3d(64, 64, kernel_size=(1, 1, 1), stride=(1, 1, 1), bias=False)
      (1): BatchNorm3d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (2): GeneralRelu: leak:0.1-sub:0.4-maxv:None
    )
    (1): Sequential(
      (0): Conv3d(64, 64, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1), bias=False)
      (1): BatchNorm3d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (2): GeneralRelu: leak:0.1-sub:0.4-maxv:None
    )
    (2): Sequential(
      (0): Conv3d(64, 256, kernel_size=(1, 1, 1), stride=(1, 1, 1), bias=False)
      (1): BatchNorm3d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    )
  )
  (downsample): Sequential(
    (0): Conv3d(64, 256, kernel_size=(1, 1, 1), stride=(1, 1, 1))
    (1): BatchNorm3d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  )
  (act): GeneralRelu: leak:0.1-sub:0.4-maxv:None
)

In [None]:
%time bottleneck(baseks3_out).shape

CPU times: user 14.5 s, sys: 7.71 s, total: 22.2 s
Wall time: 2.28 s


torch.Size([1, 256, 96, 96, 96])

In [None]:
getattr(m50.layer1, "0")

ResNetBottleneck(
  (conv1): Conv3d(64, 64, kernel_size=(1, 1, 1), stride=(1, 1, 1), bias=False)
  (bn1): BatchNorm3d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (conv2): Conv3d(64, 64, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1), bias=False)
  (bn2): BatchNorm3d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (conv3): Conv3d(64, 256, kernel_size=(1, 1, 1), stride=(1, 1, 1), bias=False)
  (bn3): BatchNorm3d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (relu): ReLU(inplace=True)
  (downsample): Sequential(
    (0): Conv3d(64, 256, kernel_size=(1, 1, 1), stride=(1, 1, 1))
    (1): BatchNorm3d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  )
)

## ResStage
- `bottleneck` and `basic` blocks are repeated several times in each stage. In Res18, stage 1 uses the `basic` block 2 times. In Res50, stage 1 uses the `bottleneck` block 3 times.
- In each stage, there are two things happening wrt to `Basic` or `Bottleneck`


## `Basic`: 
Either we reduce the size by 2 using `downsample` or we do not. For resnet10-34 the reduction happens from the 2nd layer onwards.
#### Res10
- layer1: basic [ks=3, stride=1, 1] [ni=64, nf=64]
- layer2: basic [ks=3, stride=2, 1] [ni=64, nf=128] + downsample 
- layer3: basic [ks=3, stride=2, 1] [ni=128, nf=256] + downsample 
- layer4: basic [ks=3, stride=2, 1] [ni=256, nf=512] + downsample 


#### Res50
first block in every stage has downsample layer. 
- layer1: 3x: [64-64-256], ks=[1, 3, 1] stride = [1, 1, 1], downsampling stride = 1
- layer2: 4x: [256-128-512], ks=[1, 3, 1] stride = [1, 2, 1], ds=2
- layer3: 6x: [512-256-1024], ks=[1, 3, 1] stride = [1, 2, 1], ds=2
- layer4: 3x: [1024-512-2048], ks=[1, 3, 1] stride = [1, 2, 1], ds=2

In [None]:
#| export 
class ResStage(nn.Module):
    """A stage of `layers` residual blocks registered as `block0`, `block1`, ...

    Args:
        ni: input channels of the first block.
        ip: base width passed to every `ResBlock` as its `nf`.
        nf: input channels of every block after the first.
        layers: number of blocks applied in `forward`.
        stride: stride of the first block; all later blocks use stride 1.
        ks, act, norm, block_type: forwarded unchanged to each `ResBlock`.
    """
    def __init__(self, ni: int, ip: int, nf: int, layers: int, stride=1, ks=3, act=act_gr, norm=nn.BatchNorm3d, block_type="basic"):
        super().__init__()
        fc.store_attr()
        shared = dict(ks=ks, act=act, norm=norm, block_type=block_type)
        # Only the first block receives the stage's input width and stride.
        self.block0 = ResBlock(ni, ip, stride=stride, **shared)
        for idx in range(1, layers):
            setattr(self, f"block{idx}", ResBlock(nf, ip, stride=1, **shared))

    def forward(self, x):
        """Run the input through `block0 .. block{layers-1}` in order."""
        out = x
        for idx in range(self.layers):
            block = getattr(self, f"block{idx}")
            out = block(out)
        return out

In [None]:
layer1 = ResStage(64, 64, 64, layers=1, stride=1, ks=3)
layer2 = ResStage(64, 128, 128, layers=1, stride=2, ks=3)

In [None]:
count_params(layer1), count_params(layer2), count_params(m10.layer1), count_params(m10.layer2)

(221440, 672640, 221440, 672640)

In [None]:
layer2

ResStage(
  (block0): ResBlock(
    (convs): Sequential(
      (0): Sequential(
        (0): Conv3d(64, 128, kernel_size=(3, 3, 3), stride=(2, 2, 2), padding=(1, 1, 1), bias=False)
        (1): BatchNorm3d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (2): GeneralRelu: leak:0.1-sub:0.4-maxv:None
      )
      (1): Sequential(
        (0): Conv3d(128, 128, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1), bias=False)
        (1): BatchNorm3d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      )
    )
    (downsample): Sequential(
      (0): Conv3d(64, 128, kernel_size=(1, 1, 1), stride=(2, 2, 2))
      (1): BatchNorm3d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    )
    (act): GeneralRelu: leak:0.1-sub:0.4-maxv:None
  )
)

In [None]:
m10.layer2

Sequential(
  (0): ResNetBlock(
    (conv1): Conv3d(64, 128, kernel_size=(3, 3, 3), stride=(2, 2, 2), padding=(1, 1, 1), bias=False)
    (bn1): BatchNorm3d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (relu): ReLU(inplace=True)
    (conv2): Conv3d(128, 128, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1), bias=False)
    (bn2): BatchNorm3d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (downsample): Sequential(
      (0): Conv3d(64, 128, kernel_size=(1, 1, 1), stride=(2, 2, 2))
      (1): BatchNorm3d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    )
  )
)

In [None]:
layer1 = ResStage(64, 64, 256, layers=3, stride=1, ks=3, block_type="bottleneck")
layer2 = ResStage(256, 128, 512, layers=4, stride=2, ks=3, block_type="bottleneck")

In [None]:
count_params(layer1), count_params(layer2), count_params(m50.layer1), count_params(m50.layer2)

(437248, 2399744, 437248, 2399744)

In [None]:
%time layer1(base_out).shape

CPU times: user 38.3 s, sys: 18.1 s, total: 56.4 s
Wall time: 5.73 s


torch.Size([1, 256, 96, 96, 96])

In [None]:
%time m50.layer1(base_out).shape

CPU times: user 33.5 s, sys: 14.2 s, total: 47.7 s
Wall time: 4.62 s


torch.Size([1, 256, 96, 96, 96])

In [None]:
m50.layer1

Sequential(
  (0): ResNetBottleneck(
    (conv1): Conv3d(64, 64, kernel_size=(1, 1, 1), stride=(1, 1, 1), bias=False)
    (bn1): BatchNorm3d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (conv2): Conv3d(64, 64, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1), bias=False)
    (bn2): BatchNorm3d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (conv3): Conv3d(64, 256, kernel_size=(1, 1, 1), stride=(1, 1, 1), bias=False)
    (bn3): BatchNorm3d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (relu): ReLU(inplace=True)
    (downsample): Sequential(
      (0): Conv3d(64, 256, kernel_size=(1, 1, 1), stride=(1, 1, 1))
      (1): BatchNorm3d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    )
  )
  (1): ResNetBottleneck(
    (conv1): Conv3d(256, 64, kernel_size=(1, 1, 1), stride=(1, 1, 1), bias=False)
    (bn1): BatchNorm3d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=Tru

## ResNet 

- After every stage the spatial size is reduced by 2, but in MONAI the size remains the same in the first stage and is halved from stage 2 onwards.
- `base` consists of a `conv3d` layer; a max-pool is optionally applied after it.
- Each stage consists of `layers[i]` residual blocks.
- For a basic block, ic = ip in the 1st stage and ic = ip/2 from the next stage onwards. For a bottleneck block, ic = ip in the 1st stage and ic = ip x 2 from the next stage onwards; its output width is ip x 4 in every stage.

In [None]:
#| export 
def kaiming_init_weights(m):
    """Initialise one module in place; meant to be passed to `nn.Module.apply`.

    - Conv1d/Conv2d/Conv3d: weight is drawn with `nn.init.kaiming_normal_`
      (the bias, if any, is left untouched).
    - BatchNorm1d/2d/3d: weight is set to 1 and bias to 0. Layers built with
      `affine=False` have neither tensor and are skipped instead of raising.
    - Any other module is left as is.
    """
    if isinstance(m, (nn.Conv1d, nn.Conv2d, nn.Conv3d)):
        nn.init.kaiming_normal_(m.weight)
    elif isinstance(m, (nn.BatchNorm1d, nn.BatchNorm2d, nn.BatchNorm3d)):
        # weight/bias are None when the norm layer was created with affine=False
        if m.weight is not None: nn.init.constant_(m.weight, 1)
        if m.bias is not None: nn.init.constant_(m.bias, 0)

In [None]:
#| export 
class ResNet(nn.Module):
    """3D ResNet feature extractor: stem conv, optional max-pool, residual stages, global average pool.

    One `ResStage` is built per `(layers[i], ip[i])` pair and registered as
    `layer1`, `layer2`, ... The first stage uses stride 1, every later stage stride 2.
    """
    def __init__(self,
                 ic: int,
                 ip: List[int],
                 layers: List[int],
                 c1_ks = [7, 7, 7],
                 c1_stride = [2, 2, 1],
                 base_pool=False,
                 norm=nn.BatchNorm3d,
                 act=act_gr,
                 block_type="basic",
                 init_type="kaiming",
                 dilated_conv_last_layer=False):
        """
        Args:
            ic: number of input channels.
            ip: base width ("planes") of each stage; `ip[0]` is also the stem's output width.
            layers: number of residual blocks in each stage.
            c1_ks: kernel size of the stem convolution.
            c1_stride: stride of the stem convolution.
            base_pool: if True, apply a 3x3x3 stride-2 max-pool after the stem.
            norm: normalisation factory used inside the residual stages.
            act: activation factory used inside the residual stages.
            block_type: "basic" or "bottleneck".
            init_type: weight initialisation scheme; only "kaiming" is supported.
            dilated_conv_last_layer: if True, append a dilated 3x3x3 conv after the last stage.

        Raises:
            NotImplementedError: if `init_type` is not "kaiming".
        """
        super().__init__()
        fc.store_attr()
        # Reject an unsupported init scheme before building any layers.
        if init_type != "kaiming": raise NotImplementedError("Only kaiming implmented")
        # NOTE(review): the stem (and the dilated conv below) always use act_gr / BatchNorm3d
        # and ignore the `act` / `norm` arguments. Left as is because changing it would alter
        # the module types of existing models -- confirm whether this is intended.
        self.base = conv3d(self.ic, ip[0], ks=c1_ks, stride=c1_stride, act=act_gr, norm=nn.BatchNorm3d)
        if self.base_pool: self.pool = nn.MaxPool3d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
        # Channel expansion of one block: a basic block outputs `ip`, a bottleneck `ip*4`.
        exp = 1 if block_type=="basic" else 4
        ni = ip[0]  # the stem outputs ip[0] channels
        for n, (layer, _ip) in enumerate(zip(self.layers, self.ip)):
            # `nf` is the width every block of the stage outputs, and therefore the input
            # width of blocks 1..layer-1 inside ResStage. It must be _ip*exp for both block
            # types; using _ip*2 for basic stages builds follow-up blocks that expect twice
            # the channels they actually receive.
            nf = _ip*exp
            stride = 1 if n==0 else 2
            setattr(self, f"layer{n+1}", ResStage(ni, _ip, nf, layers=layer, stride=stride,
                                                  ks=3, norm=norm, act=act, block_type=block_type))
            ni = nf  # the next stage consumes what this stage produces

        # `ni` now holds the channel count produced by the last stage (ip[-1]*exp).
        # NOTE(review): the dilated conv reuses the stem stride `c1_stride` -- confirm this is intended.
        if self.dilated_conv_last_layer: self.dil_conv = conv3d(ni, ni, ks=3, stride=c1_stride, act=act_gr, norm=nn.BatchNorm3d, padding=2, dilation=2)
        self.avg_pool = nn.AdaptiveAvgPool3d(1)
        self.apply(kaiming_init_weights)

    def forward(self, x):
        """Return the globally average-pooled features of the last stage (spatial dims pooled to 1)."""
        out = self.base(x)
        if self.base_pool: out = self.pool(out)
        # Run exactly the stages that were built, rather than assuming there are four.
        for i in range(min(len(self.layers), len(self.ip))): out = getattr(self, f"layer{i+1}")(out)
        if self.dilated_conv_last_layer: out = self.dil_conv(out)
        out = self.avg_pool(out)
        return out

In [None]:
import torch
import torch.nn as nn

input_tensor = torch.randn(1, 2, 96, 192, 192)

# Define the Conv3d layer
in_channels = 2
out_channels = 16

In [None]:
conv3d_layer = nn.Conv3d(in_channels, out_channels, kernel_size=(3, 3, 3), stride=(1, 2, 2), padding=1)
dil_conv3d_layer = nn.Conv3d(in_channels, in_channels, kernel_size=3, stride=(1, 2, 2), padding=2, dilation=2)

In [None]:
output_tensor = conv3d_layer(input_tensor)
dil_output_tensor = dil_conv3d_layer(input_tensor)

In [None]:
dil_output_tensor.shape

torch.Size([1, 2, 96, 96, 96])

In [None]:
k10 = ResNet(ic=1, ip=[64, 128, 256, 512], layers=[1, 1, 1, 1], block_type="basic")
k50 = ResNet(ic=1, ip=[64, 128, 256, 512], layers=[3, 4, 6, 3], block_type="bottleneck")

In [None]:
#k10.apply(lambda m: print(type(m).__name__));

In [None]:
count_params(k50), count_params(m50), count_params(k10), count_params(m10) #extra for FC layer

(46158912, 46978512, 14356416, 14561616)

In [None]:
for r in [1, 2, 3, 4]:
    print(count_params(getattr(k50, f"layer{r}")), count_params(getattr(m50, f"layer{r}")))
    print(count_params(getattr(k10, f"layer{r}")), count_params(getattr(m10, f"layer{r}")))

437248 437248
221440 221440
2399744 2399744
672640 672640
14177280 14177280
2688768 2688768
29122560 29122560
10751488 10751488


In [None]:
# Per-stage parameter counts must match the MONAI reference models.
# `assert` is used so that a mismatch actually fails the cell.
for r in [1, 2, 3, 4]:
    assert count_params(getattr(k50, f"layer{r}")) == count_params(getattr(m50, f"layer{r}")), f"k50.layer{r} != m50.layer{r}"
    assert count_params(getattr(k10, f"layer{r}")) == count_params(getattr(m10, f"layer{r}")), f"k10.layer{r} != m10.layer{r}"

### ResNet10

In [None]:
#| export 
def resnet10(ic, c1_ks, c1_stride, base_pool=True, norm=nn.BatchNorm3d, act=act_gr, dilated_conv_last_layer=False):
    """Build a ResNet-10: four basic-block stages of widths 64/128/256/512 with one block each.

    All arguments are forwarded to `ResNet` unchanged.
    """
    arch = dict(ip=[64, 128, 256, 512], layers=[1, 1, 1, 1], block_type="basic")
    return ResNet(ic=ic, c1_ks=c1_ks, c1_stride=c1_stride, base_pool=base_pool, norm=norm, act=act,
                  dilated_conv_last_layer=dilated_conv_last_layer, **arch)

In [None]:
r10 = resnet10(1, [7, 7, 7], [1, 2, 2], dilated_conv_last_layer=True)
r10.dil_conv

Sequential(
  (0): Conv3d(512, 512, kernel_size=(3, 3, 3), stride=(1, 2, 2), padding=(2, 2, 2), dilation=(2, 2, 2), bias=False)
  (1): BatchNorm3d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (2): GeneralRelu: leak:0.1-sub:0.4-maxv:None
)

In [None]:
x = r10(torch.zeros((1, 1, 96, 192, 192)))
[i.shape for i in x]

[torch.Size([512, 1, 1, 1])]

In [None]:
r10 = resnet10(1, [7, 7, 7], [1, 2, 2])
m10 = monai_res.resnet10(**backbone)
# assert (instead of a bare fc.equals call) so that a mismatch fails the cell
for r in [1, 2, 3, 4]: assert count_params(getattr(r10, f"layer{r}")) == count_params(getattr(m10, f"layer{r}")), f"layer{r} mismatch"

### ResNet18

In [None]:
#| export 
def resnet18(ic, c1_ks, c1_stride, base_pool=True, norm=nn.BatchNorm3d, act=act_gr, dilated_conv_last_layer=False):
    """Build a ResNet-18: four basic-block stages of widths 64/128/256/512 with two blocks each.

    `dilated_conv_last_layer` (default False) mirrors the option on `resnet10`.
    All arguments are forwarded to `ResNet`.
    """
    return ResNet(ic=ic, ip=[64, 128, 256, 512], layers=[2, 2, 2, 2], block_type="basic",
                  c1_ks=c1_ks, c1_stride=c1_stride, norm=norm, act=act, base_pool=base_pool,
                  dilated_conv_last_layer=dilated_conv_last_layer)

In [None]:
r18 = resnet18(1, [7, 7, 7], [1, 2, 2])
m18 = monai_res.resnet18(**backbone)
# NOTE(review): the value returned by fc.equals is discarded, so this loop can never fail.
# Wrap the comparison in `assert` to turn it into a real parity check.
for r in [1, 2, 3, 4]: fc.equals(count_params(getattr(r18, f"layer{r}")), count_params(getattr(m18, f"layer{r}")))

### ResNet50

In [None]:
#| export 
def resnet50(ic, c1_ks, c1_stride, base_pool=True, norm=nn.BatchNorm3d, act=act_gr):
    """Build a ResNet-50: four bottleneck stages of base widths 64/128/256/512 with 3/4/6/3 blocks.

    All arguments are forwarded to `ResNet` unchanged.
    """
    arch = dict(ip=[64, 128, 256, 512], layers=[3, 4, 6, 3], block_type="bottleneck")
    return ResNet(ic=ic, c1_ks=c1_ks, c1_stride=c1_stride, base_pool=base_pool, norm=norm, act=act, **arch)

In [None]:
r50 = resnet50(1, [7, 7, 7], [1, 2, 2])
m50 = monai_res.resnet50(**backbone)
# assert (instead of a bare fc.equals call) so that a mismatch fails the cell
for r in [1, 2, 3, 4]: assert count_params(getattr(r50, f"layer{r}")) == count_params(getattr(m50, f"layer{r}")), f"layer{r} mismatch"

In [None]:
# #| export 
# def resnet34(ic, c1_ks, c1_stride, max_pool=True, norm=nn.BatchNorm3d, act=act_gr):
#     return ResNet(ic=ic, ip=[64, 128, 256, 512], layers=[3, 4, 6, 3], block_type="basic", \
#                  c1_ks=c1_ks, c1_stride=c1_stride, norm=norm, act=act, max_pool=max_pool)

In [None]:
#| hide
import nbdev; nbdev.nbdev_export()