In [None]:
from typing import Tuple

import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
from torchvision.ops import DeformConv2d
from torchvision.transforms import v2
from torchvision.tv_tensors import Image

In [None]:
# Plain ResNet-50 used only to inspect the layer layout; no `weights` argument is
# passed, so torchvision's default (no pretrained weights) applies.
model = torchvision.models.resnet50()

In [None]:
# Show the module tree (output below is truncated).
model

ResNet(
  (conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
  (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (relu): ReLU(inplace=True)
  (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
  (layer1): Sequential(
    (0): Bottleneck(
      (conv1): Conv2d(64, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv3): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn3): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (downsample): Sequential(
        (0): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 

In [None]:
# Dummy batch of one 3-channel 640x640 image filled with Gaussian noise,
# used for shape checks throughout the notebook.
img = torch.randn((1, 3, 640, 640))

In [None]:
# Feed the dummy image through each top-level child in turn and print the shape
# after every step. The Linear child receives a fully flattened tensor.
# NOTE(review): x.flatten() also folds the batch dimension, which only works
# here because the batch size is 1.
x = img
for i, c in enumerate(model.children()):
    x = c(x.flatten()) if isinstance(c, torch.nn.Linear) else c(x)
    print(i, x.shape)

0 torch.Size([1, 64, 320, 320])
1 torch.Size([1, 64, 320, 320])
2 torch.Size([1, 64, 320, 320])
3 torch.Size([1, 64, 160, 160])
4 torch.Size([1, 256, 160, 160])
5 torch.Size([1, 512, 80, 80])
6 torch.Size([1, 1024, 40, 40])
7 torch.Size([1, 2048, 20, 20])
8 torch.Size([1, 2048, 1, 1])
9 torch.Size([1000])


In [None]:
# Keep every child except the last two (indices 8 and 9 in the shape trace
# above), i.e. the trunk that ends at the 2048-channel feature map.
backbone = nn.Sequential(*list(model.children())[:-2])

In [None]:
# C2: run the first five backbone children (indices 0-4) on the dummy image.
c2 = backbone[:5](img)
c2.shape

torch.Size([1, 256, 160, 160])

In [None]:
# C3: backbone child 5 applied to C2.
c3 = backbone[5](c2)
c3.shape

torch.Size([1, 512, 80, 80])

In [None]:
# C4: backbone child 6 applied to C3.
c4 = backbone[6](c3)
c4.shape

torch.Size([1, 1024, 40, 40])

In [None]:
# C5: backbone child 7 (the last one kept) applied to C4.
c5 = backbone[7](c4)
c5.shape

torch.Size([1, 2048, 20, 20])

In [None]:
# Extra pyramid level: a stride-2 3x3 conv on C5 halves the spatial size and
# reduces 2048 channels to 256.
extra = nn.Conv2d(in_channels=2048, out_channels=256, kernel_size=3, stride=2, padding=1)
p6 = extra(c5)
p6.shape

torch.Size([1, 256, 10, 10])

In [None]:
# Project C2-C5 to a common 256 channels with 1x1 lateral convolutions

In [None]:
# Sanity check on C2: a 1x1, stride-1 conv leaves the spatial size unchanged.
lat_c2 = nn.Conv2d(in_channels=256, out_channels=256, kernel_size=1, stride=1)
lat_c2(c2).shape

torch.Size([1, 256, 160, 160])

In [None]:
# Channel counts of C2, C3, C4 and C5 (see the shapes printed above).
shapes = [256, 512, 1024, 2048]

In [None]:
# One 1x1 lateral conv per backbone stage C2..C5, each projecting to 256
# channels. The names are bound by plain tuple unpacking (in the same creation
# order as before) instead of being generated through exec(), so they are
# visible to readers and tooling.
lat_conn2, lat_conn3, lat_conn4, lat_conn5 = (
    nn.Conv2d(in_channels=channels, out_channels=256, kernel_size=1, stride=1)
    for channels in shapes
)

In [None]:
# M5: top of the top-down pathway, just the lateral projection of C5.
m5 = lat_conn5(c5)
m5.shape

torch.Size([1, 256, 20, 20])

In [None]:
# M4: lateral projection of C4 plus M5 upsampled 2x
# (F.interpolate defaults to nearest-neighbour mode).
m4 = lat_conn4(c4)+F.interpolate(m5, scale_factor=2)
m4.shape

torch.Size([1, 256, 40, 40])

In [None]:
# M3: lateral projection of C3 plus M4 upsampled 2x.
m3 = lat_conn3(c3)+F.interpolate(m4, scale_factor=2)
m3.shape

torch.Size([1, 256, 80, 80])

In [None]:
# M2: lateral projection of C2 plus M3 upsampled 2x.
m2 = lat_conn2(c2)+F.interpolate(m3, scale_factor=2)
m2.shape

torch.Size([1, 256, 160, 160])

In [None]:
# Shape check for a 3x3 conv with padding=1 applied to M2: output size matches the input.
nn.Conv2d(in_channels=256, out_channels=256, kernel_size=3, padding=1)(m2).shape

torch.Size([1, 256, 160, 160])

In [None]:
class FPN(nn.Module):
    """Feature Pyramid Network built on a pretrained torchvision ResNet-50.

    The ResNet-50 children, minus the last two, are kept as ``self.backbone``.
    The four stage outputs C2-C5 are projected to 256 channels by 1x1 lateral
    convolutions, merged top-down with nearest-neighbour upsampling, and
    smoothed by 3x3 convolutions to give P2-P5. P6 is computed directly from
    C5 by a stride-2 3x3 convolution.
    """

    def __init__(self):
        super().__init__()

        # Requests the default pretrained weights (downloaded if not cached).
        resnet50 = torchvision.models.resnet50(weights=torchvision.models.ResNet50_Weights.DEFAULT)
        # Drop the last two children; indices 0-7 remain.
        self.backbone = nn.Sequential(*list(resnet50.children())[:-2])

        # Lateral Connections: 1x1 convs mapping each stage to 256 channels
        self.lat_conn2 = nn.Conv2d(256, 256, 1, stride=1)
        self.lat_conn3 = nn.Conv2d(512, 256, 1, stride=1)
        self.lat_conn4 = nn.Conv2d(1024, 256, 1, stride=1)
        self.lat_conn5 = nn.Conv2d(2048, 256, 1, stride=1)

        # Smoothing Conv layers applied after the top-down merge
        self.p2_conv = nn.Conv2d(256, 256, 3, padding=1)
        self.p3_conv = nn.Conv2d(256, 256, 3, padding=1)
        self.p4_conv = nn.Conv2d(256, 256, 3, padding=1)
        self.p5_conv = nn.Conv2d(256, 256, 3, padding=1)

        # Extra level computed from C5; only this layer gets an explicit
        # Xavier-uniform weight / zero bias initialisation.
        self.p6_conv = nn.Conv2d(2048, 256, 3, stride=2, padding=1)
        nn.init.xavier_uniform_(self.p6_conv.weight)
        if self.p6_conv.bias is not None:
            nn.init.zeros_(self.p6_conv.bias)

    @staticmethod
    def _upsample_add(lateral: torch.Tensor, top: torch.Tensor) -> torch.Tensor:
        """Resize ``top`` to the spatial size of ``lateral`` (nearest) and add them.

        Resizing to the lateral map's exact size, rather than by a fixed factor
        of 2, keeps the addition valid when a stage's height or width is odd
        (the next stage then has ceil(size / 2) elements, and doubling it
        overshoots by one).
        """
        return lateral + F.interpolate(top, size=lateral.shape[-2:], mode="nearest")

    def forward(self, x: Image) -> Tuple[torch.Tensor, ...]:
        """
        Forward pass to compute the FPN feature maps.

        Args:
            x (Image): Input image tensor of shape (B, C, H, W).

        Returns:
            Tuple[Tensor, ...]: Multi-scale feature maps (P2, P3, P4, P5, P6),
            each with 256 channels.
        """
        # Bottom-Up pathway
        c2 = self.backbone[:5](x)  # Output from layer 1
        c3 = self.backbone[5](c2)  # Output from layer 2
        c4 = self.backbone[6](c3)  # Output from layer 3
        c5 = self.backbone[7](c4)  # Output from layer 4

        # Top-Down pathway
        m5 = self.lat_conn5(c5)
        m4 = self._upsample_add(self.lat_conn4(c4), m5)
        m3 = self._upsample_add(self.lat_conn3(c3), m4)
        m2 = self._upsample_add(self.lat_conn2(c2), m3)

        # Smoothing
        p2 = self.p2_conv(m2)
        p3 = self.p3_conv(m3)
        p4 = self.p4_conv(m4)
        p5 = self.p5_conv(m5)
        p6 = self.p6_conv(c5)  # P6 comes straight from C5, not from the merged maps

        return p2, p3, p4, p5, p6

In [None]:
# Instantiate the FPN; the pretrained ResNet-50 checkpoint is downloaded on first use (see log below).
fpn = FPN()

Downloading: "https://download.pytorch.org/models/resnet50-11ad3fa6.pth" to /root/.cache/torch/hub/checkpoints/resnet50-11ad3fa6.pth
100%|██████████| 97.8M/97.8M [00:01<00:00, 85.0MB/s]


In [None]:
# Tuple of pyramid feature maps (P2..P6) for the dummy image.
maps = fpn(img)

In [None]:
# Print the shape of every pyramid level.
for m in maps:
    print(m.shape)

torch.Size([1, 256, 160, 160])
torch.Size([1, 256, 80, 80])
torch.Size([1, 256, 40, 40])
torch.Size([1, 256, 20, 20])
torch.Size([1, 256, 10, 10])


In [None]:
# Scratch check: element-wise comparison of torch.tensor(...) and torch.Tensor(...)
# built from the same list.
torch.tensor([1, 2]) == torch.Tensor([1, 2])

tensor([True, True])

In [None]:
# Scratch check: v2.ToImage() wraps a plain tensor as a tv_tensors Image.
# (`v2` is imported in the notebook's top import cell.)
type(v2.ToImage()(img))

torchvision.tv_tensors._image.Image

# Anchor Boxes

In [None]:
# Reminder of the pyramid map shapes that the anchors are laid out on.
for m in maps:
    print(m.shape)

torch.Size([1, 256, 160, 160])
torch.Size([1, 256, 80, 80])
torch.Size([1, 256, 40, 40])
torch.Size([1, 256, 20, 20])
torch.Size([1, 256, 10, 10])


In [None]:
# Ratio between consecutive anchor scales: the cube root of 2, so three steps double the size.
step = 2**(1/3)
step

1.2599210498948732

In [None]:
# Downsampling factor of each pyramid map relative to the 640-pixel input
# (640/160, 640/80, 640/40, 640/20, 640/10).
strides = torch.tensor([4, 8, 16, 32, 64])

In [None]:
# Smallest anchor side per level, in pixels: 4x the level's stride.
size1 = strides*4
size1

tensor([ 16,  32,  64, 128, 256])

In [None]:
# Second anchor side per level: one scale step above size1.
size2 = size1 * step
size2

tensor([ 20.1587,  40.3175,  80.6349, 161.2699, 322.5398])

In [None]:
# Third anchor side per level: one scale step above size2.
size3 = size2 * step
size3

tensor([ 25.3984,  50.7968, 101.5937, 203.1873, 406.3747])

In [None]:
# Rows = pyramid levels, columns = the three scales. Dividing by 640 (the side of
# the dummy input image) expresses the sizes as fractions of the image side.
anchor_sizes = torch.stack([size1, size2, size3]).T/640
anchor_sizes

tensor([[0.0250, 0.0315, 0.0397],
        [0.0500, 0.0630, 0.0794],
        [0.1000, 0.1260, 0.1587],
        [0.2000, 0.2520, 0.3175],
        [0.4000, 0.5040, 0.6350]])

In [None]:
# Side length of each feature map. Only the last dimension is read, so this
# assumes square maps.
map_sizes = torch.tensor([m.shape[-1] for m in maps])
map_sizes

tensor([160,  80,  40,  20,  10])

In [None]:
# Presumably the number of default boxes (anchors) per feature-map cell; it
# matches the three columns of anchor_sizes — TODO confirm where it is used.
n_dbox = 3

In [None]:
# Five reference facial landmark positions, one (x, y) pair per row.
# NOTE(review): the values look like fractions of a unit box measured from its
# top-left corner (nose near the middle) — confirm the intended origin.
landmark_template = torch.tensor([
    [0.3, 0.4],  # Left eye
    [0.7, 0.4],  # Right eye
    [0.5, 0.55], # Nose
    [0.35, 0.75], # Left mouth corner
    [0.65, 0.75]  # Right mouth corner
])

In [None]:
# Scratch check: x-coordinates of the template scaled by the middle anchor size of the last level.
(landmark_template * anchor_sizes[-1][1])[:, 0]

tensor([0.1512, 0.3528, 0.2520, 0.1764, 0.3276])

In [None]:
def create_anchors(map_sizes, anchor_sizes, if_landmarks=False, template=None):
    """Build square anchor boxes (and optional default landmarks) per feature map.

    NOTE: a second ``create_anchors`` with a different return layout is defined
    later in this notebook and shadows this one once its cell has run.

    Args:
        map_sizes: Iterable with the side length of each (square) feature map.
        anchor_sizes: Iterable, parallel to ``map_sizes``, giving the anchor
            side lengths used on that map, in the same normalised units as the
            anchor centres.
        if_landmarks: When True, also return default landmark positions.
        template: Optional ``(K, 2)`` tensor of landmark positions given as
            fractions of a box measured from its top-left corner. Defaults to
            the notebook-level ``landmark_template``.

    Returns:
        A list with one float32 tensor per feature map, of shape
        ``(size * size * len(anchor_size), 4)`` with rows
        ``[center_x, center_y, width, height]``. With ``if_landmarks=True`` a
        tuple ``(anchors_per_map, landmarks_per_map)`` is returned instead,
        where each landmarks tensor has shape ``(n_anchors, K, 2)``.
    """
    if if_landmarks:
        if template is None:
            template = landmark_template
        # The template is measured from the box's top-left corner, while the
        # anchors store their centre; subtracting 0.5 turns the template into
        # offsets from the centre so the landmarks land inside the box.
        offsets = torch.as_tensor(template, dtype=torch.float32) - 0.5

    anchors_per_map = []
    landmarks_per_map = []
    for size, anchor_size in zip(map_sizes, anchor_sizes):
        # Plain Python numbers avoid creating one 0-d tensor per coordinate.
        n_cells = int(size)
        scales = [float(sz) for sz in anchor_size]

        rows = []
        # center_x follows the outer index and center_y the inner one.
        # NOTE(review): confirm this ordering matches how the prediction heads
        # flatten their outputs.
        for i in range(n_cells):
            center_x = (i + 0.5) / n_cells
            for j in range(n_cells):
                center_y = (j + 0.5) / n_cells
                rows.extend([center_x, center_y, sz, sz] for sz in scales)

        anchors = torch.tensor(rows, dtype=torch.float32).reshape(-1, 4)
        anchors_per_map.append(anchors)

        if if_landmarks:
            # (N, 1, 2) centre + (N, 1, 2) size * (K, 2) offsets -> (N, K, 2)
            landmarks_per_map.append(anchors[:, None, :2] + anchors[:, None, 2:] * offsets)

    if if_landmarks:
        return (anchors_per_map, landmarks_per_map)
    return anchors_per_map

In [None]:
# Build the per-map anchors and their default landmarks.
anchors_per_map, landmarks_per_map = create_anchors(map_sizes, anchor_sizes, if_landmarks=True)

In [None]:
# Iterate the list itself rather than a hard-coded range(5), so the check
# follows however many feature maps there are.
for anchors in anchors_per_map:
    print(anchors.shape)

torch.Size([76800, 4])
torch.Size([19200, 4])
torch.Size([4800, 4])
torch.Size([1200, 4])
torch.Size([300, 4])


In [None]:
# Iterate the list itself rather than a hard-coded range(5), so the check
# follows however many feature maps there are.
for landmarks in landmarks_per_map:
    print(landmarks.shape)

torch.Size([76800, 5, 2])
torch.Size([19200, 5, 2])
torch.Size([4800, 5, 2])
torch.Size([1200, 5, 2])
torch.Size([300, 5, 2])


In [None]:
# Expected anchor count per map: cells (size**2) times anchors per cell.
# map_sizes is already a tensor, so it is used directly; wrapping it in
# torch.tensor() again copy-constructs it and emits a UserWarning.
map_sizes**2 * n_dbox

tensor([76800, 19200,  4800,  1200,   300])

In [None]:
# Spot-check one anchor of the fourth map (index 3). The per-map list built
# above is named anchors_per_map; `dboxes_per_map` is not defined anywhere in
# the notebook and fails on a fresh kernel.
anchors_per_map[3][16]

tensor([0.0250, 0.2750, 0.2520, 0.2520])

In [None]:
# Concatenate the per-map anchors into one (total_anchors, 4) tensor.
all_anchors = torch.cat(anchors_per_map)
all_anchors.shape

torch.Size([102300, 4])

In [None]:
# Replicate the landmark template once per anchor: (n_anchors, 5, 2).
all_landmarks = landmark_template.unsqueeze(0).repeat(all_anchors.shape[0], 1, 1)
all_landmarks.shape

torch.Size([102300, 5, 2])

In [None]:
# Scale the template by each anchor's (width, height) via broadcasting.
# NOTE: this rebinds all_landmarks, so re-running the cell scales it again.
all_landmarks = all_anchors[:, 2:].unsqueeze(1) * all_landmarks

In [None]:
# Shift the scaled template by each anchor's (center_x, center_y).
# NOTE: this rebinds all_landmarks, so re-running the cell shifts it again.
# NOTE(review): offsets are added to the anchor centre, not its top-left
# corner — confirm that matches the template's origin.
all_landmarks = all_anchors[:, :2].unsqueeze(1) + all_landmarks

In [None]:
def create_anchors(map_sizes, anchor_sizes, if_landmarks=False, template=None):
    """Build all square anchors (and optional default landmarks) as flat tensors.

    NOTE: this redefines an earlier ``create_anchors`` in this notebook, which
    returned one tensor per feature map instead of a single concatenated one.

    Args:
        map_sizes: Iterable with the side length of each (square) feature map.
        anchor_sizes: Iterable, parallel to ``map_sizes``, giving the anchor
            side lengths used on that map, in the same normalised units as the
            anchor centres.
        if_landmarks: When True, also return default landmark positions.
        template: Optional ``(K, 2)`` tensor of landmark positions given as
            fractions of a box measured from its top-left corner. Defaults to
            the notebook-level ``landmark_template``.

    Returns:
        ``anchors`` of shape ``(n_anchors, 4)`` with rows
        ``[center_x, center_y, width, height]``, ordered map by map, then by
        x index, then y index, then anchor size. With ``if_landmarks=True``,
        a tuple ``(anchors, landmarks)`` where ``landmarks`` has shape
        ``(n_anchors, 2 * K)`` holding ``x1, y1, x2, y2, ...`` per anchor.
    """
    per_map = []
    for size, anchor_size in zip(map_sizes, anchor_sizes):
        n_cells = int(size)
        scales = torch.as_tensor(anchor_size, dtype=torch.float32).reshape(-1)
        centers = (torch.arange(n_cells, dtype=torch.float32) + 0.5) / n_cells

        # indexing="ij": the first grid varies along dim 0, so center_x follows
        # the outer index and center_y the inner one.
        # NOTE(review): confirm this ordering matches how the prediction heads
        # flatten their outputs.
        grid_x, grid_y = torch.meshgrid(centers, centers, indexing="ij")
        xy = torch.stack([grid_x, grid_y], dim=-1).reshape(-1, 1, 2)
        xy = xy.expand(-1, scales.numel(), -1)               # (cells, sizes, 2)
        wh = scales.reshape(1, -1, 1).expand(n_cells * n_cells, -1, 2)
        per_map.append(torch.cat([xy, wh], dim=-1).reshape(-1, 4))

    anchors = torch.cat(per_map) if per_map else torch.empty((0, 4))

    if not if_landmarks:
        return anchors

    if template is None:
        template = landmark_template
    # The template is measured from the box's top-left corner, while the
    # anchors store their centre; subtracting 0.5 turns the template into
    # offsets from the centre so the landmarks land inside the box.
    offsets = torch.as_tensor(template, dtype=anchors.dtype) - 0.5

    n_anchors = anchors.shape[0]
    # (N, 1, 2) centre + (N, 1, 2) size * (K, 2) offsets -> (N, K, 2)
    landmarks = anchors[:, None, :2] + anchors[:, None, 2:] * offsets

    return anchors, landmarks.reshape(n_anchors, -1)


In [None]:
# Rebuild with the function version; this rebinds all_anchors / all_landmarks from the cells above.
all_anchors, all_landmarks = create_anchors(map_sizes, anchor_sizes, if_landmarks=True)

In [None]:
# Shape check: one row per anchor in both tensors.
all_anchors.shape, all_landmarks.shape

(torch.Size([102300, 4]), torch.Size([102300, 10]))

# Take 2

In [None]:
# TODO:
# 1. Implement a DCN (deformable convolution) block class.
# 2. Replace the 3x3 Conv2d layers in FPN with the DCN block.
# 3. Implement a Context Module class.

In [None]:
# 1. DCN

In [None]:
# Reminder of the pyramid map shapes the DCN experiments below operate on.
for m in maps:
    print(m.shape)

torch.Size([1, 256, 160, 160])
torch.Size([1, 256, 80, 80])
torch.Size([1, 256, 40, 40])
torch.Size([1, 256, 20, 20])
torch.Size([1, 256, 10, 10])


In [None]:
# Offset predictor for a 3x3 deformable conv: 18 output channels, i.e. 2 offset
# values for each of the 3 * 3 = 9 kernel positions; padding=1 keeps the spatial size.
offsets_conv = nn.Conv2d(256, 18, 3, padding=1)
offsets_conv(maps[0]).shape

torch.Size([1, 18, 160, 160])

In [None]:
# Deformable 3x3 conv applied to the first pyramid map, using offsets predicted
# from that same map by offsets_conv.
dcn = DeformConv2d(256, 256, 3, padding=1)
dcn(maps[0], offsets_conv(maps[0])).shape

torch.Size([1, 256, 160, 160])

In [None]:
class DCNBlock(nn.Module):
    def __init__(in_channels, out_channels, kernel_size, padding):
        super().__init__()
        o = kernel_size
        self.offsets_conv = nn.Conv2d(in_channels, kernel_size, kernel_size, padding=padding)
        self.dcn = DeformConv2d(in_channels, out_channels, kernel_size, padding=padding)