In [1]:

import torch
import torch.nn as nn
from torch import Tensor
from collections import OrderedDict
from typing import Dict, List, Optional, Callable
import timm

import torchvision
from torchvision.models._utils import IntermediateLayerGetter
from torchvision import models
from torchvision.ops.feature_pyramid_network import FeaturePyramidNetwork, LastLevelMaxPool, ExtraFPNBlock, LastLevelP6P7
from torchvision.ops.misc import FrozenBatchNorm2d

from torchvision.models.detection.backbone_utils import BackboneWithFPN
from torchvision.models.detection.retinanet import RetinaNet


In [2]:
# backbone = models.__dict__['resnet18']
backbone = torchvision.models.resnet18(pretrained=False)

In [3]:
returned_layers = [2,3,4]
return_layers = {f"layer{k}": str(v) for v, k in enumerate(returned_layers)}
# #in_channels_stage2 = backbone.inplanes // 8
# in_channels_list = [in_channels_stage2 * 2 ** (i - 1) for i in returned_layers]
in_channels_list = [128, 256, 512]
out_channels = 256
extra_blocks = LastLevelP6P7(256,256)
body = IntermediateLayerGetter(backbone, return_layers=return_layers)
fpn = FeaturePyramidNetwork(
    in_channels_list=in_channels_list,
    out_channels=out_channels,
    extra_blocks=extra_blocks,
)


In [4]:

out = body(torch.randn(2,3,512,512))
out.keys()

odict_keys(['0', '1', '2'])

In [5]:
type(out)

collections.OrderedDict

In [6]:
out['0'].shape

torch.Size([2, 128, 64, 64])

In [7]:
out['1'].shape

torch.Size([2, 256, 32, 32])

In [8]:
out['2'].shape

torch.Size([2, 512, 16, 16])

In [9]:
fpn_out = fpn(out)

In [10]:
fpn_out.keys()

odict_keys(['0', '1', '2', 'p6', 'p7'])

In [11]:
fpn_out['0'].shape

torch.Size([2, 256, 64, 64])

In [12]:
fpn_out['1'].shape

torch.Size([2, 256, 32, 32])

In [13]:
fpn_out['2'].shape

torch.Size([2, 256, 16, 16])

In [14]:
fpn_out['p6'].shape

torch.Size([2, 256, 8, 8])

In [15]:
fpn_out['p7'].shape

torch.Size([2, 256, 4, 4])

In [16]:
# backbone = torchvision.models.resnet18(pretrained=False)
# backbone = timm.create_model('resnet18', features_only=True)

In [17]:
returned_layers = [2,3,4]
return_layers = {f"layer{k}": str(v) for v, k in enumerate(returned_layers)}
# #in_channels_stage2 = backbone.inplanes // 8
# in_channels_list = [in_channels_stage2 * 2 ** (i - 1) for i in returned_layers]
in_channels_list = [128, 256, 512]
out_channels = 256
extra_blocks = LastLevelP6P7(256,256)
body = timm.create_model('resnet18', features_only=True)
fpn = FeaturePyramidNetwork(
    in_channels_list=in_channels_list,
    out_channels=out_channels,
    extra_blocks=extra_blocks,
)


In [18]:

o = body(torch.randn(2,3,512,512))
out = OrderedDict()
out['0'] = o[-3]
out['1'] = o[-2]
out['2'] = o[-1]

In [19]:
out.keys()

odict_keys(['0', '1', '2'])

In [20]:
fpn_out = fpn(out)

In [21]:
fpn_out.keys()

odict_keys(['0', '1', '2', 'p6', 'p7'])

In [22]:
fpn_out['0'].shape

torch.Size([2, 256, 64, 64])

In [23]:
fpn_out['2'].shape

torch.Size([2, 256, 16, 16])

In [24]:
fpn_out['p6'].shape

torch.Size([2, 256, 8, 8])

In [25]:
def resnet_fpn_backbone(
    backbone_name: str,
    pretrained: bool,
    norm_layer: Callable[..., nn.Module] =  FrozenBatchNorm2d, 
    trainable_layers: int = 3,
    returned_layers: Optional[List[int]] = None,
    extra_blocks: Optional[ExtraFPNBlock] = None,
) -> BackboneWithFPN:

    # select layers that wont be frozen
    backbone = torchvision.models.__dict__[backbone_name](pretrained=pretrained, norm_layer=norm_layer)
    assert 0 <= trainable_layers <= 5
    layers_to_train = ["layer4", "layer3", "layer2", "layer1", "conv1"][:trainable_layers]
    if trainable_layers == 5:
        layers_to_train.append("bn1")
    for name, parameter in backbone.named_parameters():
        if all([not name.startswith(layer) for layer in layers_to_train]):
            parameter.requires_grad_(False)

    if extra_blocks is None:
        extra_blocks = LastLevelMaxPool()

    if returned_layers is None:
        returned_layers = [1, 2, 3, 4]
    assert min(returned_layers) > 0 and max(returned_layers) < 5
    return_layers = {f"layer{k}": str(v) for v, k in enumerate(returned_layers)}

    in_channels_stage2 = backbone.inplanes // 8
    in_channels_list = [in_channels_stage2 * 2 ** (i - 1) for i in returned_layers]
    out_channels = 256
    return BackboneWithFPN(backbone, return_layers, in_channels_list, out_channels, extra_blocks=extra_blocks)



In [26]:
backbone = resnet_fpn_backbone('resnet18', pretrained=True, trainable_layers=3)

In [27]:
model = RetinaNet(backbone,
                  num_classes=1,
                  anchor_generator=None)

In [28]:
model.eval()
model(torch.randn(2,3,500,500))

  return _VF.meshgrid(tensors, **kwargs)  # type: ignore[attr-defined]


[{'boxes': tensor([], size=(0, 4), grad_fn=<StackBackward0>),
  'scores': tensor([], grad_fn=<IndexBackward0>),
  'labels': tensor([], dtype=torch.int64)},
 {'boxes': tensor([], size=(0, 4), grad_fn=<StackBackward0>),
  'scores': tensor([], grad_fn=<IndexBackward0>),
  'labels': tensor([], dtype=torch.int64)}]

In [39]:
class BackboneWithFPN(nn.Module):
    def __init__(
        self,
        backbone: str,
        in_channels_list: List[int] = [128,256,512],
        out_channels: int = 256,
        extra_blocks: Optional[ExtraFPNBlock] = None,
    ) -> None:
        super(BackboneWithFPN, self).__init__()

        if extra_blocks is None:
            extra_blocks = LastLevelMaxPool()
        self.in_channels_list = in_channels_list
        self.out_channels = out_channels
        self.extra_block = LastLevelP6P7(out_channels,out_channels)

        self.body = timm.create_model(backbone, features_only=True)
        self.fpn = FeaturePyramidNetwork(
            in_channels_list=self.in_channels_list,
            out_channels=self.out_channels,
            extra_blocks=self.extra_block,
        )
        self.out_channels = out_channels

    def forward(self, x: Tensor) -> Dict[str, Tensor]:
        out = OrderedDict()
        x = self.body(x)
        out['0'] = x[-3]
        out['1'] = x[-2]
        out['2'] = x[-1]        
        out = self.fpn(out)
        print(out.keys())
        return out

In [40]:
backbone = BackboneWithFPN('resnet18')

In [41]:
model = RetinaNet(backbone,
                  num_classes=1,
                  anchor_generator=None)

model.eval()
model(torch.randn(2,3,500,500))

odict_keys(['0', '1', '2', 'p6', 'p7'])


[{'boxes': tensor([], size=(0, 4), grad_fn=<StackBackward0>),
  'scores': tensor([], grad_fn=<IndexBackward0>),
  'labels': tensor([], dtype=torch.int64)},
 {'boxes': tensor([], size=(0, 4), grad_fn=<StackBackward0>),
  'scores': tensor([], grad_fn=<IndexBackward0>),
  'labels': tensor([], dtype=torch.int64)}]