In [18]:
import math
from typing import Dict, List, Tuple, Optional, Union

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch import Tensor

In [14]:
class BasicConv2d(nn.Module):
    """
    A 2D convolution followed by an in-place ReLU.
    Args:
        in_channels (int): number of channels of the input feature
        out_channels (int): number of channels produced by the convolution
        **kwargs: forwarded unchanged to ``nn.Conv2d`` (e.g. ``kernel_size``, ``padding``)
    """

    def __init__(self, in_channels, out_channels, **kwargs):
        super(BasicConv2d, self).__init__()
        self.conv = nn.Conv2d(in_channels, out_channels, **kwargs)
        self.relu = nn.ReLU(inplace=True)

    def forward(self, x):
        """Apply the convolution, then the ReLU, and return the activated tensor."""
        x = self.conv(x)
        x = self.relu(x)
        # The result has to be returned explicitly: a bare ``return`` would hand
        # None to the caller and break any nn.Sequential built from this module.
        return x

In [15]:
class Inception(nn.Module):
    """
    Four parallel branches applied to the same input; their outputs are
    concatenated along dim 1.
    Args:
        in_channels (int): number of channels of the input feature
        ch1x1 (int): output channels of the 1x1 branch
        ch3x3red (int): output channels of the 1x1 reduction before the 3x3 conv
        ch3x3 (int): output channels of the 3x3 conv
        ch5x5red (int): output channels of the 1x1 reduction before the 5x5 conv
        ch5x5 (int): output channels of the 5x5 conv
        pool_proj (int): output channels of the 1x1 conv after the max-pool
    """

    def __init__(self, in_channels, ch1x1, ch3x3red, ch3x3, ch5x5red, ch5x5, pool_proj):
        super().__init__()

        # Branch 1: a single 1x1 convolution.
        self.branch1 = BasicConv2d(in_channels, ch1x1, kernel_size=1)

        # Branch 2: 1x1 reduction, then a 3x3 convolution.
        # padding=1 keeps the output size equal to the input size.
        reduce_3x3 = BasicConv2d(in_channels, ch3x3red, kernel_size=1)
        conv_3x3 = BasicConv2d(ch3x3red, ch3x3, kernel_size=3, padding=1)
        self.branch2 = nn.Sequential(reduce_3x3, conv_3x3)

        # Branch 3: 1x1 reduction, then a 5x5 convolution.
        # padding=2 keeps the output size equal to the input size.
        # The official implementation actually uses a 3x3 kernel here instead of
        # 5x5; that is deliberately not mirrored in this file.
        # Please see https://github.com/pytorch/vision/issues/906 for details.
        reduce_5x5 = BasicConv2d(in_channels, ch5x5red, kernel_size=1)
        conv_5x5 = BasicConv2d(ch5x5red, ch5x5, kernel_size=5, padding=2)
        self.branch3 = nn.Sequential(reduce_5x5, conv_5x5)

        # Branch 4: 3x3 max-pool (stride 1, padding 1), then a 1x1 projection.
        pool = nn.MaxPool2d(kernel_size=3, stride=1, padding=1)
        projection = BasicConv2d(in_channels, pool_proj, kernel_size=1)
        self.branch4 = nn.Sequential(pool, projection)

    def forward(self, x):
        """Run every branch on ``x`` and concatenate the results along dim 1."""
        branch_outputs = [
            branch(x)
            for branch in (self.branch1, self.branch2, self.branch3, self.branch4)
        ]
        return torch.cat(branch_outputs, dim=1)

In [20]:
class ClassificationHead(nn.Module):
    """
    A classification head for use in RetinaNet.
    Args:
        in_channels (int): number of channels of the input feature
        num_anchors (int): number of anchors to be predicted
        num_classes (int): number of classes to be predicted
        prior_probability (float): value in (0, 1) used to initialise the bias
            of the final classifier conv to ``-log((1 - p) / p)``, i.e. so that
            ``sigmoid(bias) == p``
    Raises:
        ValueError: if ``prior_probability`` is not strictly between 0 and 1
    """

    def __init__(self, in_channels, num_anchors, num_classes, prior_probability=0.01):
        super(ClassificationHead, self).__init__()

        # Reject values for which -log((1 - p) / p) is undefined.
        if not 0.0 < prior_probability < 1.0:
            raise ValueError(
                "prior_probability must be in the open interval (0, 1), got {}".format(prior_probability)
            )

        # The class subnet is four 3x3 conv layers (each followed by ReLU)
        # plus one 3x3 conv layer acting as the classifier.
        conv = []
        for _ in range(4):
            conv.append(nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=1, padding=1))
            conv.append(nn.ReLU(inplace=True))
        self.conv = nn.Sequential(*conv)

        self.cls_logits = nn.Conv2d(in_channels, num_anchors * num_classes, kernel_size=3, stride=1, padding=1)

        # Honour prior_probability: start every logit's bias at the value whose
        # sigmoid is prior_probability.
        # NOTE(review): this presumes the logits are passed through a sigmoid
        # downstream (focal-loss style training) -- confirm against the loss code.
        nn.init.constant_(self.cls_logits.bias, -math.log((1 - prior_probability) / prior_probability))

        self.num_classes = num_classes
        self.num_anchors = num_anchors

    def forward(self, x: List[Tensor]) -> Tensor:
        """
        Args:
            x: iterable of feature maps, one per prediction level; each is
               unpacked as (N, C, H, W) after the classifier conv
        Returns:
            Tensor of shape (N, sum over levels of H*W*A, K), where A is the
            number of anchors and K is ``num_classes``.
        """
        all_cls_logits = []

        # Iterate over every prediction feature level.
        for features in x:
            cls_logits = self.conv(features)
            cls_logits = self.cls_logits(cls_logits)

            # Permute classification output from (N, A * K, H, W) to (N, HWA, K).
            N, _, H, W = cls_logits.shape
            # [N, A * K, H, W] -> [N, A, K, H, W]
            cls_logits = cls_logits.view(N, -1, self.num_classes, H, W)
            # [N, A, K, H, W] -> [N, H, W, A, K]
            cls_logits = cls_logits.permute(0, 3, 4, 1, 2)
            # [N, H, W, A, K] -> [N, HWA, K]
            cls_logits = cls_logits.reshape(N, -1, self.num_classes)

            all_cls_logits.append(cls_logits)

        # Concatenate the per-level predictions along the anchor dimension.
        return torch.cat(all_cls_logits, dim=1)

In [23]:

class RegressionHead(nn.Module):
    """
    A regression head for use in RetinaNet.

    Two 3x3 predictor convs (box regression and IoU-aware) are applied to the
    same four-layer conv tower output.
    Args:
        in_channels (int): number of channels of the input feature
        num_anchors (int): number of anchors to be predicted
    """

    def __init__(self, in_channels, num_anchors):
        super(RegressionHead, self).__init__()

        # The box subnet is four 3x3 conv layers (each followed by ReLU)
        # plus one 3x3 conv layer acting as the bounding-box regressor.
        conv = []
        for _ in range(4):
            conv.append(nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=1, padding=1))
            conv.append(nn.ReLU(inplace=True))
        self.conv = nn.Sequential(*conv)

        self.bbox_reg = nn.Conv2d(in_channels, num_anchors * 4, kernel_size=3, stride=1, padding=1)
        # NOTE(review): the IoU-aware branch emits 4 values per anchor, the same
        # as the box branch; kept as-is so the output shape callers see does not
        # change -- confirm whether a single value per anchor was intended.
        self.iou_aw = nn.Conv2d(in_channels, num_anchors * 4, kernel_size=3, stride=1, padding=1)

    @staticmethod
    def _flatten_per_anchor(pred: Tensor) -> Tensor:
        """Reshape a (N, 4 * A, H, W) prediction map into (N, HWA, 4)."""
        N, _, H, W = pred.shape
        # [N, 4 * A, H, W] -> [N, A, 4, H, W]
        pred = pred.view(N, -1, 4, H, W)
        # [N, A, 4, H, W] -> [N, H, W, A, 4]
        pred = pred.permute(0, 3, 4, 1, 2)
        # [N, H, W, A, 4] -> [N, HWA, 4]
        return pred.reshape(N, -1, 4)

    def forward(self, x: List[Tensor]) -> Tuple[Tensor, Tensor]:
        """
        Args:
            x: iterable of feature maps, one per prediction level
        Returns:
            ``(bbox_regression, iou_aware)``; each is the per-level predictions
            reshaped to (N, HWA, 4) and concatenated along dim 1.
        """
        all_bbox_regression = []
        all_iou_aware = []

        # Iterate over every prediction feature level.
        for features in x:
            # Both predictors were built with ``in_channels`` input channels, so
            # both must read the tower output. Feeding the bbox_reg output
            # (4 * A channels) into iou_aw raises a channel-mismatch error
            # unless in_channels happens to equal 4 * A.
            tower = self.conv(features)
            bbox_regression = self.bbox_reg(tower)
            iou_aware = self.iou_aw(tower)

            all_bbox_regression.append(self._flatten_per_anchor(bbox_regression))
            all_iou_aware.append(self._flatten_per_anchor(iou_aware))

        return torch.cat(all_bbox_regression, dim=1), torch.cat(all_iou_aware, dim=1)

In [25]:
class RetinaNetHead(nn.Module):
    """
    A regression and classification head for use in RetinaNet.
    Args:
        in_channels (int): number of channels of the input feature
        num_anchors (int): number of anchors to be predicted
        num_classes (int): number of classes to be predicted
    """

    def __init__(self, in_channels, num_anchors, num_classes):
        super(RetinaNetHead, self).__init__()
        self.classification_head = ClassificationHead(in_channels, num_anchors, num_classes)
        self.regression_head = RegressionHead(in_channels, num_anchors)

    def forward(self, x: List[Tensor]) -> Dict[str, Union[Tensor, Tuple[Tensor, Tensor]]]:
        """
        Run both sub-heads on the same feature maps.

        Args:
            x: feature maps, one per prediction level
        Returns:
            Dict with two keys:
              - "cls_logits": output of ``self.classification_head(x)``
              - "bbox_regression": output of ``self.regression_head(x)``, stored
                as returned. ``RegressionHead.forward`` returns a pair
                ``(bbox_regression, iou_aware)``, so this value is a tuple of
                two tensors rather than a single tensor.
        """
        return {
            "cls_logits": self.classification_head(x),
            "bbox_regression": self.regression_head(x)
        }