# Libraries


In [1]:
import torch
import torch.nn as nn
import numpy as np
from torchvision import models

# RetinaNet definition


### FPN Backbone


In [None]:
def resnes_backbone(pretrained=True):
    """Build a ResNet-50 feature extractor without its classification head.

    Args:
        pretrained: If True (the default, matching the original behaviour),
            load ImageNet weights through torchvision, which may download
            them. Pass False to get a randomly initialised network, e.g. for
            offline experiments or tests.

    Returns:
        ``nn.Sequential`` of ResNet-50's children with the last two (global
        average pool and fully connected classifier) removed. Calling it
        directly returns only the final (C5) feature map. Code that needs the
        intermediate maps must run its children one by one.
    """
    # NOTE: the misspelled function name is kept so existing callers still work.
    # NOTE(review): torchvision >= 0.13 deprecates `pretrained=` in favour of
    # `weights=`; kept here for compatibility with older torchvision releases.
    resnet = models.resnet50(pretrained=pretrained)
    # Drop the trailing avgpool + fc layers; keep the convolutional trunk.
    trunk = list(resnet.children())[:-2]
    return nn.Sequential(*trunk)


class FPN(nn.Module):
    """Feature Pyramid Network producing 256-channel maps P3, P4 and P5.

    Top-down pathway with lateral connections: C5 is projected to 256
    channels, upsampled and added to the projected C4. The same is done to
    produce P3 from C3. Each merged map is then passed through a 3x3 conv.

    Args:
        backbone: Either an ``nn.Sequential`` whose last four children
            produce C2, C3, C4 and C5 (for example the ResNet-50 trunk built
            by ``resnes_backbone``, whose last children are layer1..layer4),
            or any module whose forward returns the four maps
            ``(c2, c3, c4, c5)``. C3, C4 and C5 must have 512, 1024 and 2048
            channels respectively.

    ``forward(x)`` returns ``(p3, p4, p5)``. Each has 256 channels and the
    spatial size of the matching C level.
    """

    def __init__(self, backbone):
        super().__init__()
        self.backbone = backbone

        # Lateral 1x1 convs project C5 / C4 / C3 to the shared 256-channel width.
        self.latLayer1 = nn.Conv2d(2048, 256, kernel_size=1)
        self.latLayer2 = nn.Conv2d(1024, 256, kernel_size=1)
        self.latLayer3 = nn.Conv2d(512, 256, kernel_size=1)

        # 3x3 convs applied to each merged map.
        self.conv1 = nn.Conv2d(256, 256, kernel_size=3, padding=1)
        self.conv2 = nn.Conv2d(256, 256, kernel_size=3, padding=1)
        self.conv3 = nn.Conv2d(256, 256, kernel_size=3, padding=1)

    def _backbone_features(self, x):
        """Return the ``(c2, c3, c4, c5)`` feature maps for input ``x``."""
        if isinstance(self.backbone, nn.Sequential):
            # Run the stages one at a time to keep every intermediate map;
            # calling the Sequential directly would return only the last one.
            features = []
            for stage in self.backbone:
                x = stage(x)
                features.append(x)
            if len(features) < 4:
                raise ValueError(
                    "FPN needs a Sequential backbone with at least 4 stages, "
                    f"got {len(features)}"
                )
            return features[-4:]
        return self.backbone(x)

    @staticmethod
    def _upsample_to(src, ref):
        """Nearest-neighbour upsample ``src`` to the spatial size of ``ref``."""
        # An exact target size (instead of scale_factor=2) keeps the maps
        # aligned when a stride-2 stage rounded an odd size up.
        return nn.functional.interpolate(src, size=ref.shape[-2:], mode="nearest")

    def forward(self, x):
        _c2, c3, c4, c5 = self._backbone_features(x)

        # Top-down pathway: lateral projection + upsampled coarser level.
        p5 = self.latLayer1(c5)
        p4 = self.latLayer2(c4) + self._upsample_to(p5, c4)
        p3 = self.latLayer3(c3) + self._upsample_to(p4, c3)

        # Final convolutions.
        p5 = self.conv1(p5)
        p4 = self.conv2(p4)
        p3 = self.conv3(p3)

        return p3, p4, p5

### Classification and Box Regression Subnetworks


In [None]:
class ClassificationSubnet(nn.Module):
    """RetinaNet classification head, applied to one 256-channel FPN level.

    Four 3x3 conv + ReLU layers (256 channels) followed by a 3x3 conv that
    predicts ``num_classes * num_anchors`` channels, then a sigmoid.

    Args:
        num_classes: Number of object classes.
        num_anchors: Anchors per spatial location. The default of 1 keeps the
            output width at ``num_classes`` channels.
        prior: Initial foreground probability, strictly between 0 and 1. The
            final conv's bias is set to ``-log((1 - prior) / prior)`` so every
            output starts near ``prior``. This is the RetinaNet initialisation
            that stabilises early training with focal loss.

    Input ``(N, 256, H, W)``; output ``(N, num_classes * num_anchors, H, W)``
    with values in (0, 1) because of the final sigmoid.
    """

    def __init__(self, num_classes, num_anchors=1, prior=0.01):
        super().__init__()
        if not 0.0 < prior < 1.0:
            raise ValueError(f"prior must be in (0, 1), got {prior}")
        self.num_classes = num_classes
        self.num_anchors = num_anchors
        self.conv1 = nn.Conv2d(256, 256, kernel_size=3, padding=1)
        self.conv2 = nn.Conv2d(256, 256, kernel_size=3, padding=1)
        self.conv3 = nn.Conv2d(256, 256, kernel_size=3, padding=1)
        self.conv4 = nn.Conv2d(256, 256, kernel_size=3, padding=1)
        self.act1 = nn.ReLU()
        self.conv5 = nn.Conv2d(256, num_classes * num_anchors, kernel_size=3, padding=1)
        self.act2 = nn.Sigmoid()
        self._init_weights(prior)

    def _init_weights(self, prior):
        """Gaussian(0, 0.01) weights and zero biases; prior bias on conv5."""
        import math

        for conv in (self.conv1, self.conv2, self.conv3, self.conv4, self.conv5):
            nn.init.normal_(conv.weight, mean=0.0, std=0.01)
            nn.init.zeros_(conv.bias)
        nn.init.constant_(self.conv5.bias, -math.log((1.0 - prior) / prior))

    def forward(self, x):
        out = x
        for conv in (self.conv1, self.conv2, self.conv3, self.conv4):
            out = self.act1(conv(out))
        return self.act2(self.conv5(out))


class RegressionSubnet(nn.Module):
    """RetinaNet box-regression head, applied to one 256-channel FPN level.

    Four 3x3 conv + ReLU layers (256 channels) followed by a 3x3 conv with
    ``4 * num_anchors`` linear outputs per location.

    The output has no activation. Box offsets relative to an anchor can be
    negative or larger than 1, and a sigmoid would restrict them to (0, 1).

    Args:
        num_anchors: Anchors per spatial location. The default of 1 keeps the
            output width at 4 channels.

    Input ``(N, 256, H, W)``; output ``(N, 4 * num_anchors, H, W)``.
    """

    def __init__(self, num_anchors=1):
        super().__init__()
        self.num_anchors = num_anchors
        self.conv1 = nn.Conv2d(256, 256, kernel_size=3, padding=1)
        self.conv2 = nn.Conv2d(256, 256, kernel_size=3, padding=1)
        self.conv3 = nn.Conv2d(256, 256, kernel_size=3, padding=1)
        self.conv4 = nn.Conv2d(256, 256, kernel_size=3, padding=1)
        self.act1 = nn.ReLU()
        self.conv5 = nn.Conv2d(256, 4 * num_anchors, kernel_size=3, padding=1)
        # Same initialisation as the RetinaNet paper: N(0, 0.01) weights, zero bias.
        for conv in (self.conv1, self.conv2, self.conv3, self.conv4, self.conv5):
            nn.init.normal_(conv.weight, mean=0.0, std=0.01)
            nn.init.zeros_(conv.bias)

    def forward(self, x):
        out = x
        for conv in (self.conv1, self.conv2, self.conv3, self.conv4):
            out = self.act1(conv(out))
        return self.conv5(out)

### RetinaNet


In [None]:
class RetinaNet(nn.Module):
    """Single-stage detector: an FPN neck plus classification and box heads.

    Args:
        backbone: Feature extractor passed to ``FPN``; also registered on this
            module as ``self.backbone``.
        num_classes: Number of classes predicted by the classification head.

    ``forward(x)`` returns ``(cls_outputs, reg_outputs)``: two lists ordered
    as [P3, P4, P5]. One classification head and one regression head are
    shared across all pyramid levels.
    """

    def __init__(self, backbone, num_classes):
        super().__init__()
        self.backbone = backbone
        self.fpn = FPN(backbone)
        self.classification_subnet = ClassificationSubnet(num_classes)
        self.regression_subnet = RegressionSubnet()

    def forward(self, x):
        level3, level4, level5 = self.fpn(x)
        levels = (level3, level4, level5)
        # Run all classification predictions first, then all box regressions.
        cls_outputs = [self.classification_subnet(feat) for feat in levels]
        reg_outputs = [self.regression_subnet(feat) for feat in levels]
        return cls_outputs, reg_outputs

### Focal Loss


In [None]:
class FocalLoss(nn.Module):
    """Element-wise sigmoid focal loss.

    ``FL = -alpha_t * (1 - p_t) ** gamma * log(p_t)``, where
    ``p = sigmoid(inputs)``. For positives, ``p_t = p`` and
    ``alpha_t = alpha``. For everything else, ``p_t = 1 - p`` and
    ``alpha_t = 1 - alpha``.

    Args:
        alpha: Weight for positives; negatives are weighted by ``1 - alpha``.
        gamma: Focusing exponent (>= 0). ``gamma = 0`` reduces to
            alpha-balanced binary cross-entropy.
    """

    def __init__(self, alpha=0.25, gamma=2.0):
        super().__init__()
        self.alpha = alpha
        self.gamma = gamma

    def forward(self, inputs, targets):
        """Compute the un-reduced focal loss.

        Args:
            inputs: Raw logits (before any sigmoid), any shape.
            targets: Same shape as ``inputs``. Entries equal to 1 are
                positives; every other value is treated as negative.

        Returns:
            Loss tensor with the same shape as ``inputs``; no reduction applied.
        """
        # NOTE(review): ClassificationSubnet in this file ends in a Sigmoid, so
        # passing its output here applies the sigmoid twice. Feed logits instead.
        targets = torch.as_tensor(targets, device=inputs.device)
        positive = (targets == 1).to(inputs.dtype)

        probs = torch.sigmoid(inputs)
        # Built out-of-place: editing the sigmoid output in place breaks autograd.
        p_t = positive * probs + (1 - positive) * (1 - probs)
        alpha_t = positive * self.alpha + (1 - positive) * (1 - self.alpha)

        # BCE-with-logits equals -log(p_t) but stays finite for saturated logits.
        ce = nn.functional.binary_cross_entropy_with_logits(
            inputs, positive, reduction="none"
        )
        return alpha_t * (1 - p_t) ** self.gamma * ce

# Training
