In [None]:
#| default_exp networks/fpn

In [1]:
#| export 
import torch.nn as nn
import torchvision
import fastcore.all as fc
from typing import Dict, List
from monai.networks.blocks.feature_pyramid_network import FeaturePyramidNetwork, LastLevelMaxPool

In [2]:
import torch
from voxdet.networks.res_se_net import resnet10, conv3d

In [3]:
# 3D ResNet-10 backbone used throughout this notebook. Per the module repr printed further
# down, the stem becomes Conv3d(1, 64, kernel_size=(7, 7, 7), stride=(1, 2, 2)).
# NOTE(review): base_pool=False presumably omits a pooling layer after the stem — confirm in res_se_net.
backbone = resnet10(1, (7, 7, 7), (1, 2, 2), base_pool=False)

In [4]:
# Backbone stages to tap, and the output key ("0", "1", ...) assigned to each of them.
returned_layers = [1, 2, 3, 4]
return_layers = dict(
    (f"layer{stage}", str(position))
    for position, stage in enumerate(returned_layers)
)
return_layers

{'layer1': '0', 'layer2': '1', 'layer3': '2', 'layer4': '3'}

In [5]:
# Channel width of each tapped stage: 512 // 8 = 64 for stage 1, doubling with every later stage.
in_channels_list = [
    (512 // 8) * 2 ** (stage - 1)
    for stage in returned_layers
]
in_channels_list

[64, 128, 256, 512]

In [6]:
# Wrap the backbone so that a forward pass returns the outputs of the stages named in `return_layers`.
body = torchvision.models._utils.IntermediateLayerGetter(
    backbone,
    return_layers=return_layers,
)

In [7]:
body

IntermediateLayerGetter(
  (base): Sequential(
    (0): Conv3d(1, 64, kernel_size=(7, 7, 7), stride=(1, 2, 2), padding=(3, 3, 3), bias=False)
    (1): BatchNorm3d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): GeneralRelu: leak:0.1-sub:0.4-maxv:None
  )
  (layer1): ResStage(
    (block0): ResBlock(
      (convs): Sequential(
        (0): Sequential(
          (0): Conv3d(64, 64, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1), bias=False)
          (1): BatchNorm3d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (2): GeneralRelu: leak:0.1-sub:0.4-maxv:None
        )
        (1): Sequential(
          (0): Conv3d(64, 64, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1), bias=False)
          (1): BatchNorm3d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        )
      )
      (act): GeneralRelu: leak:0.1-sub:0.4-maxv:None
    )
  )
  (layer2): ResStage(
    (block0): ResBlock(
      (

In [8]:
# Dummy all-ones input volume of shape (1, 1, 96, 192, 192), used only to inspect output shapes.
img = torch.ones(1, 1, 96, 192, 192)
img.shape

torch.Size([1, 1, 96, 192, 192])

In [9]:
%%time
outs = body(img)
[(k, v.shape) for k, v in outs.items()]

CPU times: user 53.3 s, sys: 29 s, total: 1min 22s
Wall time: 1.54 s


[('0', torch.Size([1, 64, 96, 96, 96])),
 ('1', torch.Size([1, 128, 48, 48, 48])),
 ('2', torch.Size([1, 256, 24, 24, 24])),
 ('3', torch.Size([1, 512, 12, 12, 12]))]

### FeaturePyramidNetwork

- The FPN paper builds one pyramid level per backbone stage (C2, C3, C4, C5). As a worked example, we take C1 and C2 here, which have strides of (1, 2, 2) and (2, 4, 4) respectively.
- Take C2 and upsample it by a factor of 2 (using nearest-neighbor upsampling for simplicity).
- Take C1 and apply a 1x1 conv to reduce its channel dimension.
- Add the two resulting maps element-wise.
- Apply a 3x3 conv to each merged map to generate the final feature map.
- The result is P2.

The same procedure is repeated for the other levels.

- Set the number of output channels to d = 256 for every pyramid level.
- All levels of the pyramid use shared classifiers/regressors.

In [10]:
#| export 
class BackbonewithFPN3D(nn.Module):
    """Backbone whose intermediate feature maps are fed through a 3D feature pyramid network.

    The outputs of the backbone sub-modules selected by `return_layers` are collected with
    torchvision's `IntermediateLayerGetter` and passed to MONAI's `FeaturePyramidNetwork`
    built with `spatial_dims=3`.

    Args:
        backbone: module containing the sub-modules named by the keys of `return_layers`.
        return_layers: maps a backbone sub-module name to the key its output is stored under.
        in_channels_list: channel count of each selected feature map, handed to the FPN.
        out_channels: channel count requested from the FPN for its outputs.
        extra_blocks: when True, a `LastLevelMaxPool(3)` block is handed to the FPN.
    """

    def __init__(
        self,
        backbone,
        return_layers: Dict[str, str],
        in_channels_list: List[int],
        out_channels: int,
        extra_blocks: bool = False,
    ):
        super().__init__()
        # Keep the configuration on the instance so it can be inspected later.
        fc.store_attr(names=["return_layers", "in_channels_list", "out_channels"])

        layer_getter = torchvision.models._utils.IntermediateLayerGetter
        self.body = layer_getter(backbone, return_layers=return_layers)

        # Optional block given to the FPN; None leaves the FPN without extra blocks.
        top_block = LastLevelMaxPool(3) if extra_blocks else None
        self.fpn = FeaturePyramidNetwork(
            spatial_dims=3,
            in_channels_list=in_channels_list,
            out_channels=out_channels,
            extra_blocks=top_block,
        )

    def forward(self, x):
        """Run `x` through the backbone getter, then the FPN, and return the FPN result."""
        features = self.body(x)
        return self.fpn(features)

In [13]:
#| export 
def resnet_fpn3d_feature_extractor(backbone, out_channels=256, returned_layers=(1, 2, 3), extra_blocks:bool=False):
    """Build a `BackbonewithFPN3D` on top of a ResNet-style backbone.

    Args:
        backbone: network with stages named `layer<k>` and an `ip` sequence; `ip[-1] // 8`
            is used as the channel width of stage 1.
            NOTE(review): this assumes `ip[-1]` is the width of the last stage and that
            widths double from stage to stage — confirm against the backbone definition.
        out_channels: channel count requested from the FPN for its outputs.
        returned_layers: iterable of stage indices to tap (e.g. 1 selects `layer1`). Output
            keys are assigned by position in this iterable: "0", "1", ...
        extra_blocks: forwarded to `BackbonewithFPN3D`.

    Returns:
        The assembled `BackbonewithFPN3D` module.
    """
    # The indices are walked twice below; materialise them so one-shot iterables
    # (generators, map objects) are not exhausted after the first pass.
    returned_layers = list(returned_layers)
    in_channels_stage2 = backbone.ip[-1] // 8
    in_channels_list = [in_channels_stage2 * 2 ** (i - 1) for i in returned_layers]
    return_layers = {f"layer{k}": str(v) for v, k in enumerate(returned_layers)}
    return BackbonewithFPN3D(backbone, return_layers, in_channels_list, out_channels, extra_blocks)

In [14]:
# Build our extractor with the default returned layers and out_channels; extra_blocks=True
# requests the LastLevelMaxPool block, which shows up as the 'pool' entry in the shapes below.
network = resnet_fpn3d_feature_extractor(backbone, extra_blocks=True)

In [15]:
%%time
outs = network(img)

CPU times: user 1min 46s, sys: 1min 34s, total: 3min 20s
Wall time: 4.19 s


In [16]:
[(k, v.shape) for k, v in outs.items()]

[('0', torch.Size([1, 256, 96, 96, 96])),
 ('1', torch.Size([1, 256, 48, 48, 48])),
 ('2', torch.Size([1, 256, 24, 24, 24])),
 ('pool', torch.Size([1, 256, 12, 12, 12]))]

In [17]:
def count_params(layer):
    """Return the total number of elements across all parameters reported by `layer.named_parameters()`."""
    return sum(tensor.numel() for _, tensor in layer.named_parameters())

In [18]:
count_params(network)

9029568

## Comparing with the MONAI implementation

In [19]:
# Keyword arguments passed on to MONAI's resnet_fpn_feature_extractor in the next cell.
model_cfg = {
    "spatial_dims": 3,
    "pretrained_backbone": False,
    "trainable_backbone_layers": None,
    "returned_layers": [1, 2, 3],
}

In [20]:
from monai.apps.detection.networks.retinanet_network import resnet_fpn_feature_extractor as rffe
# NOTE(review): MONAI's extractor appears to derive the FPN input widths from
# `backbone.in_planes` (our own extractor reads `backbone.ip[-1]` instead), so the attribute
# is set by hand here. This mutates the shared `backbone` object — confirm it is intended.
backbone.in_planes = 512
# Build MONAI's reference extractor on the same backbone so both can be compared directly.
feature_extractor = rffe(backbone=backbone, **model_cfg)

In [21]:
count_params(feature_extractor)

9029568

In [22]:
%%time
outs = feature_extractor(img)

CPU times: user 2min 7s, sys: 1min 17s, total: 3min 24s
Wall time: 4.48 s


In [23]:
[(k, v.shape) for k, v in outs.items()]

[('0', torch.Size([1, 256, 96, 96, 96])),
 ('1', torch.Size([1, 256, 48, 48, 48])),
 ('2', torch.Size([1, 256, 24, 24, 24])),
 ('pool', torch.Size([1, 256, 12, 12, 12]))]

In [None]:
#| hide
import nbdev; nbdev.nbdev_export()