<a href="https://colab.research.google.com/github/apophis30/Feature-Pyramid-NN/blob/main/FeaturePyramidNetwork.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [12]:
import os
from typing import Union, Tuple, List, NamedTuple
import torch
import torch.nn as nn
from torch.nn import functional as F
from torch import nn, Tensor
import torchvision.models

class ResNet50:
    """Factory that splits a torchvision ResNet-50 into FPN building blocks.

    :meth:`features` returns the five bottom-up stages of the network together
    with newly created lateral (1x1) and de-aliasing (3x3) convolutions for a
    Feature Pyramid Network head.
    """

    class ConvLayers(NamedTuple):
        # Bottom-up stages C1..C5 of the backbone, in forward order.
        conv1: nn.Module
        conv2: nn.Module
        conv3: nn.Module
        conv4: nn.Module
        conv5: nn.Module

    class LateralLayers(NamedTuple):
        # 1x1 projections from stages C2..C5 to the common pyramid width.
        lateral_c2: nn.Module
        lateral_c3: nn.Module
        lateral_c4: nn.Module
        lateral_c5: nn.Module

    class DealiasingLayers(NamedTuple):
        # 3x3 convolutions applied to the merged P2..P4 maps.
        dealiasing_p2: nn.Module
        dealiasing_p3: nn.Module
        dealiasing_p4: nn.Module

    def __init__(self, pretrained: bool):
        # Forwarded to torchvision when features() builds the network.
        self._pretrained = pretrained

    def features(self) -> Tuple[ConvLayers, LateralLayers, DealiasingLayers, int]:
        """Build the backbone stages and the FPN head layers.

        Returns:
            ``(conv_layers, lateral_layers, dealiasing_layers, num_features_out)``
            where ``num_features_out`` (256) is the channel width shared by
            every lateral and de-aliasing convolution.

        The stem (``conv1``) and the first residual stage (``conv2``) are
        frozen by setting ``requires_grad = False`` on all their parameters.
        """
        # NOTE(review): newer torchvision releases deprecate ``pretrained=`` in
        # favour of ``weights=``; kept as-is for compatibility.
        backbone = torchvision.models.resnet50(pretrained=self._pretrained)
        blocks = list(backbone.children())

        # Children 0-3 are conv1/bn1/relu/maxpool (the stem); 4-7 are layer1..layer4.
        stem = nn.Sequential(*blocks[:4])
        stage2, stage3, stage4, stage5 = blocks[4:8]

        width = 256

        # Output channel counts of ResNet-50's layer1..layer4 (C2..C5).
        laterals = [
            nn.Conv2d(in_channels=channels, out_channels=width, kernel_size=1)
            for channels in (256, 512, 1024, 2048)
        ]
        smoothers = [
            nn.Conv2d(in_channels=width, out_channels=width, kernel_size=3, padding=1)
            for _ in range(3)
        ]

        # Freeze the stem and the first residual stage.
        for frozen in (stem, stage2):
            frozen.requires_grad_(False)

        return (
            self.ConvLayers(stem, stage2, stage3, stage4, stage5),
            self.LateralLayers(*laterals),
            self.DealiasingLayers(*smoothers),
            width,
        )

In [13]:
class Model(nn.Module):
    """Feature Pyramid Network wired from a backbone's feature layers.

    The backbone supplies the bottom-up stages (C1..C5) plus the lateral and
    de-aliasing convolutions; this module builds the top-down pathway and
    returns the pyramid maps ``(p4, p3, p2, p6)``.
    """

    class ForwardInput(object):
        class Train(NamedTuple):
            image: Tensor

        class Eval(NamedTuple):
            image: Tensor

    class ForwardOutput(object):
        class Train(NamedTuple):
            output: Tensor

        class Eval(NamedTuple):
            output: Tensor

    def __init__(self, backbone, num_classes: int):
        """Create the FPN from ``backbone``.

        Args:
            backbone: Object whose ``features()`` returns ``(conv_layers,
                lateral_layers, dealiasing_layers, num_features_out)`` in the
                layout produced by ``ResNet50.features``. Passing ``None``
                falls back to ``ResNet50(pretrained=True)``.
            num_classes: Stored on the instance; not used by ``forward``.
        """
        super().__init__()

        # Previously the argument was ignored and a second pretrained ResNet-50
        # was always built here; honour the caller's backbone instead.
        if backbone is None:
            backbone = ResNet50(pretrained=True)
        conv_layers, lateral_layers, dealiasing_layers, num_features_out = backbone.features()
        self.conv1, self.conv2, self.conv3, self.conv4, self.conv5 = conv_layers
        self.lateral_c2, self.lateral_c3, self.lateral_c4, self.lateral_c5 = lateral_layers
        self.dealiasing_p2, self.dealiasing_p3, self.dealiasing_p4 = dealiasing_layers
        self.num_features_out = num_features_out

        # Every registered sub-module comes from the three groups above, so a
        # single scan finds the same BatchNorm2d layers, in the same order, as
        # walking conv1..conv5, the laterals and the de-aliasing convs one by one.
        self._bn_modules = [m for m in self.modules() if isinstance(m, nn.BatchNorm2d)]

        self.num_classes = num_classes

    def forward(self, forward_input: Union[ForwardInput.Train, ForwardInput.Eval]) -> Tuple[Tensor, Tensor, Tensor, Tensor]:
        """Run the bottom-up and top-down pathways.

        Args:
            forward_input: Carries ``image``. A 3-D ``(C, H, W)`` tensor gets a
                batch dimension added; a 4-D ``(N, C, H, W)`` tensor is used as-is.

        Returns:
            ``(p4, p3, p2, p6)``: the de-aliased maps P4, P3, P2 and P6, a
            stride-2 subsampling of P5. P5 itself is not returned.
        """
        # Freeze batch norm on every call in case the model was switched to
        # train() since the last forward.
        for bn_module in self._bn_modules:
            bn_module.eval()
            bn_module.requires_grad_(False)

        image = forward_input.image
        if image.dim() == 3:
            # Single image: add the batch dimension conv layers require.
            image = image.unsqueeze(dim=0)

        # Bottom-up pathway
        c1 = self.conv1(image)
        c2 = self.conv2(c1)
        c3 = self.conv3(c2)
        c4 = self.conv4(c3)
        c5 = self.conv5(c4)

        # Top-down pathway: upsample the coarser map to the finer stage's exact
        # spatial size (robust to odd sizes) and add the lateral projection.
        p5 = self.lateral_c5(c5)
        p4 = self.lateral_c4(c4) + F.interpolate(input=p5, size=(c4.shape[2], c4.shape[3]), mode='nearest')
        p3 = self.lateral_c3(c3) + F.interpolate(input=p4, size=(c3.shape[2], c3.shape[3]), mode='nearest')
        p2 = self.lateral_c2(c2) + F.interpolate(input=p3, size=(c2.shape[2], c2.shape[3]), mode='nearest')

        # Reduce the aliasing introduced by upsampling. This runs after the
        # whole top-down pass, so each merge above used the un-smoothed map.
        p4 = self.dealiasing_p4(p4)
        p3 = self.dealiasing_p3(p3)
        p2 = self.dealiasing_p2(p2)

        # A 1x1 max-pool with stride 2 just keeps every other pixel of P5.
        p6 = F.max_pool2d(input=p5, kernel_size=1, stride=2)

        return p4, p3, p2, p6


In [14]:
import torch
from torchvision.transforms import ToTensor
from PIL import Image

# Create the FPN; the backbone passed here supplies the ResNet-50 stages.
model = Model(backbone=ResNet50(pretrained=True), num_classes=10)

# Load a test image. convert('RGB') guarantees the 3 input channels the ResNet
# stem expects (PNG files may be RGBA, grayscale or palette-based), and the
# context manager closes the underlying file once the pixels have been read.
image_path = '/content/c98f79f71.png'
with Image.open(image_path) as pil_image:
    image = ToTensor()(pil_image.convert('RGB'))  # (3, H, W) float tensor
# NOTE(review): pretrained torchvision weights normally expect ImageNet
# mean/std normalisation; this demo only inspects output shapes, so it is omitted.

# Create a ForwardInput instance for testing
forward_input = Model.ForwardInput.Eval(image=image)

# Set the model to evaluation mode
model.eval()

# Inference only: no_grad avoids building (and holding) the autograd graph.
with torch.no_grad():
    output = model(forward_input)

# Print the shapes of the pyramid feature maps
p4, p3, p2, p6 = output
print("p4 shape:", p4.shape)
print("p3 shape:", p3.shape)
print("p2 shape:", p2.shape)
print("p6 shape:", p6.shape)


p4 shape: torch.Size([1, 256, 32, 32])
p3 shape: torch.Size([1, 256, 64, 64])
p2 shape: torch.Size([1, 256, 128, 128])
p6 shape: torch.Size([1, 256, 8, 8])
