In [None]:
import torch
import torch.nn as nn
import torchvision as tv
from torchvision.models._utils import IntermediateLayerGetter
from torchvision.ops.deform_conv import DeformConv2d

## Prerequisites for the FPN

In [None]:
# Backbone network whose intermediate feature maps feed the pyramid.
# No weights argument is passed, so no pretrained checkpoint is requested here.
model = tv.models.resnet18()

In [None]:
# backbone stages whose outputs will be collected
layers = ["layer1", "layer2", "layer3", "layer4"]
# pair every stage name with the string form of its position: "layer1" -> "0", ...
dict_modules = dict(zip(layers, map(str, range(len(layers)))))

In [None]:
# wrap the backbone so that a forward pass returns the outputs of the stages
# listed in dict_modules, stored under the names given there
feature_extractor = IntermediateLayerGetter(model, return_layers=dict_modules)

In [None]:
# all-zero placeholder input: batch of 1, 3 channels, 224 x 224
x = torch.zeros(1, 3, 224, 224)

In [None]:
# Run the dummy input through the extractor. `out` is an ordered mapping from
# the names chosen in dict_modules ("0" .. "3") to the matching feature maps.
out = feature_extractor(x)

In [None]:
# Summarise every feature map by its shape instead of printing the raw tensors:
# the full dump floods the notebook and hides the one thing that matters here
# (channel count and spatial size per level).
{name: tuple(feat.shape) for name, feat in out.items()}

## Defining the Feature Pyramid Network

In [None]:
# names under which the extracted feature maps are stored
out.keys()

odict_keys(['0', '1', '2', '3'])

In [None]:
# channel count (dim 1) of every extracted feature map; these are the FPN input widths
in_channels = [feat.shape[1] for feat in out.values()]
out_channels = 256

# integer share of out_channels assigned to each entry of in_channels
out_channs = out_channels // len(in_channels)

In [None]:
# defining the bottom up network
class _LateralDeformConv(nn.Module):
    """1x1 deformable convolution that predicts its own sampling offsets.

    ``DeformConv2d.forward`` needs an ``offset`` tensor in addition to the
    input, so a bare ``DeformConv2d`` cannot sit inside ``nn.Sequential``
    (which hands a single tensor from one module to the next). This wrapper
    derives the offsets from the input with a regular convolution and passes
    both tensors to the deformable convolution.

    Args:
        in_ch: number of channels of the incoming feature map.
        out_ch: number of channels produced by the deformable convolution.
    """

    def __init__(self, in_ch: int, out_ch: int) -> None:
        super().__init__()
        # one (dy, dx) pair per kernel tap; a 1x1 kernel has exactly one tap
        self.offset = nn.Conv2d(in_ch, 2, kernel_size=1)
        # zero offsets: the layer starts out sampling like a plain 1x1 conv
        nn.init.zeros_(self.offset.weight)
        nn.init.zeros_(self.offset.bias)
        self.conv = DeformConv2d(in_ch, out_ch, 1, bias=False)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply the deformable convolution using offsets predicted from ``x``."""
        return self.conv(x, self.offset(x))


# one branch per backbone feature map: project it to out_channels, then BN + ReLU
bottom_up_branches = nn.ModuleList(
    [
        nn.Sequential(
            _LateralDeformConv(channels, out_channels),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),
        )
        for channels in in_channels
    ]
)

In [None]:
# upsample branch: doubles the spatial size (bilinear, align_corners=True);
# generate_features applies it to a coarser map before adding the next finer one
upsample_branch = nn.Upsample(scale_factor=2, mode="bilinear", align_corners=True)

In [None]:
# defining the top down network
class _DeformConv3x3(nn.Module):
    """3x3 deformable convolution (padding 1) that predicts its own offsets.

    ``DeformConv2d.forward`` needs an ``offset`` tensor in addition to the
    input, so a bare ``DeformConv2d`` cannot sit inside ``nn.Sequential``
    (which hands a single tensor from one module to the next). This wrapper
    derives the offsets from the input with a regular 3x3 convolution and
    passes both tensors to the deformable convolution.

    Args:
        in_ch: number of channels of the incoming feature map.
        out_ch: number of channels produced by the deformable convolution.
    """

    def __init__(self, in_ch: int, out_ch: int) -> None:
        super().__init__()
        # one (dy, dx) pair for each of the 3 * 3 kernel taps; same kernel and
        # padding as the deformable conv so the offset map matches its output size
        self.offset = nn.Conv2d(in_ch, 2 * 3 * 3, kernel_size=3, padding=1)
        # zero offsets: the layer starts out sampling like a plain 3x3 conv
        nn.init.zeros_(self.offset.weight)
        nn.init.zeros_(self.offset.bias)
        self.conv = DeformConv2d(in_ch, out_ch, 3, padding=1, bias=False)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply the deformable convolution using offsets predicted from ``x``."""
        return self.conv(x, self.offset(x))


# one branch per pyramid level: reduce to out_channs channels, BN + ReLU, then
# upsample by 2 ** idx (idx 0 keeps its size, each later branch is enlarged
# by a further factor of 2)
top_down_branches = nn.ModuleList(
    [
        nn.Sequential(
            _DeformConv3x3(out_channels, out_channs),
            nn.BatchNorm2d(out_channs),
            nn.ReLU(inplace=True),
            nn.Upsample(scale_factor=2 ** idx, mode="bilinear", align_corners=True),
        )
        for idx in range(len(in_channels))
    ]
)

## Generating the features

In [None]:
# defining a dummy tensor in case the image is not given
dummy_in = torch.randn((1, 3, 224, 224))

# defining the wrapper for generating the features
def generate_features(image: torch.Tensor=dummy_in) -> torch.Tensor:
    """Run the backbone and the pyramid branches and return one fused feature map.

    Steps:
      1. ``feature_extractor`` yields the backbone maps under the keys "0", "1", ...
      2. Each map goes through its own entry of ``bottom_up_branches``.
      3. Starting from the last map, ``upsample_branch`` enlarges the running
         result by 2x and the next earlier map is added to it, so every level
         must be exactly twice the spatial size of the following one.
      4. The merged maps, ordered from first level to last, go through
         ``top_down_branches`` and are concatenated along the channel axis.

    Args:
        image: input batch passed to ``feature_extractor``; defaults to the
            module-level ``dummy_in`` tensor.

    Returns:
        The channel-wise concatenation (dim=1) of all top-down branch outputs.

    Raises:
        AssertionError: if the number of extracted feature maps differs from
            the number of bottom-up or top-down branches.
    """
    features = feature_extractor(image)
    levels = [features[str(idx)] for idx in range(len(features))]

    # zip() below would silently drop levels on a length mismatch, so check both lists
    assert len(levels) == len(bottom_up_branches), (
        f"expected {len(bottom_up_branches)} feature maps for the bottom-up "
        f"branches, got {len(levels)}"
    )
    assert len(levels) == len(top_down_branches), (
        f"expected {len(top_down_branches)} feature maps for the top-down "
        f"branches, got {len(levels)}"
    )

    laterals = [branch(level) for branch, level in zip(bottom_up_branches, levels)]

    # walk from the last level back to the first, accumulating upsampled context
    merged = [laterals[-1]]
    for lateral in reversed(laterals[:-1]):
        merged.append(upsample_branch(merged[-1]) + lateral)

    # merged is ordered last -> first; the top-down branches expect first -> last
    fused = [
        branch(level) for branch, level in zip(top_down_branches, reversed(merged))
    ]

    return torch.cat(fused, dim=1)

In [None]:
# Keep the fused map in a variable and show only its shape; displaying the
# tensor itself would flood the notebook with numbers.
fpn_features = generate_features(dummy_in)

# every top-down branch contributes out_channs channels to the concatenation
assert fpn_features.shape[1] == out_channs * len(top_down_branches)

fpn_features.shape