In [2]:
import torch
from torch import nn
import torch.nn.functional as F

In [16]:
# ResNet bottleneck residual block (1x1 -> 3x3 -> 1x1 convs), used to build the FPN backbone
class Bottleneck(nn.Module):
    """ResNet bottleneck residual block.

    Main branch: 1x1 conv (to ``planes``) -> 3x3 conv (carries ``stride``) ->
    1x1 conv (to ``planes * expansion``), each followed by BatchNorm, with a ReLU
    after the first two. The shortcut is ``x`` itself, or ``downsample(x)`` when a
    module is supplied. The block returns ``relu(main + shortcut)``.

    Args:
        in_planes: number of input channels.
        planes: bottleneck width (channels of the two inner convs).
        stride: stride of the 3x3 conv.
        downsample: optional module applied to ``x`` on the shortcut. Required
            whenever the main branch changes the spatial size or channel count,
            otherwise the residual addition has mismatched shapes.
    """

    # The block outputs planes * expansion channels.
    expansion = 4

    def __init__(self, in_planes, planes, stride=1, downsample=None):
        super(Bottleneck, self).__init__()
        width_out = planes * self.expansion
        # Layer order determines the state_dict keys (bottleneck.0 ... bottleneck.7).
        branch = [
            nn.Conv2d(in_planes, planes, kernel_size=1, bias=False),
            nn.BatchNorm2d(planes),
            nn.ReLU(True),
            nn.Conv2d(planes, planes, kernel_size=3, stride=stride, padding=1, bias=False),
            nn.BatchNorm2d(planes),
            nn.ReLU(True),
            nn.Conv2d(planes, width_out, kernel_size=1, bias=False),
            nn.BatchNorm2d(width_out),
        ]
        self.bottleneck = nn.Sequential(*branch)
        self.relu = nn.ReLU(True)
        self.downsample = downsample

    def forward(self, x):
        residual = self.bottleneck(x)
        shortcut = x if self.downsample is None else self.downsample(x)
        # In-place add is safe: `residual` is a fresh tensor, `x` is never written.
        residual += shortcut
        return self.relu(residual)

In [33]:
class FPN(nn.Module):
    """Feature Pyramid Network on a ResNet backbone built from ``Bottleneck`` blocks.

    Bottom-up pathway:
        - A stem (7x7 stride-2 conv, BN, ReLU, 3x3 stride-2 max-pool).
        - Four stages of ``Bottleneck`` blocks producing c2..c5.

    Top-down pathway:
        - c5 is projected to 256 channels (p5).
        - Each coarser map is upsampled and added to a 1x1 lateral projection of
          the next finer stage, giving p4, p3 and p2.
        - p4, p3 and p2 are then passed through a 3x3 conv.

    Args:
        layers: number of ``Bottleneck`` blocks in each of the four stages,
            e.g. ``[3, 4, 6, 3]``.
    """

    def __init__(self, layers):
        super(FPN, self).__init__()
        # Running input-channel count; read and updated by _make_layer.
        self.inplanes = 64
        # Stem (c1).
        self.conv1 = nn.Conv2d(3, 64, 7, 2, 3, bias=False)
        self.bn1 = nn.BatchNorm2d(64)
        self.relu = nn.ReLU(True)
        self.maxpooling = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
        # Bottom-up pathway: c2, c3, c4, c5 (stages 2-4 halve the spatial size).
        self.layer1 = self._make_layer(64, layers[0])
        self.layer2 = self._make_layer(128, layers[1], 2)
        self.layer3 = self._make_layer(256, layers[2], 2)
        self.layer4 = self._make_layer(512, layers[3], 2)
        expansion = Bottleneck.expansion
        # p5: 1x1 conv reducing c5 to 256 channels.
        self.toplayer = nn.Conv2d(512 * expansion, 256, 1, 1, 0)
        # Lateral 1x1 convs bringing c4, c3, c2 to the same 256 channels as the top-down path.
        self.latlayer1 = nn.Conv2d(256 * expansion, 256, 1, 1, 0)
        self.latlayer2 = nn.Conv2d(128 * expansion, 256, 1, 1, 0)
        self.latlayer3 = nn.Conv2d(64 * expansion, 256, 1, 1, 0)
        # 3x3 conv that smooths the merged maps.
        # NOTE(review): one conv is shared by p4, p3 and p2; many FPN implementations
        # use a separate conv per level. Kept shared so existing checkpoints still
        # load — confirm the weight sharing is intended.
        self.smooth = nn.Conv2d(256, 256, 3, 1, 1)

    def _make_layer(self, planes, blocks, strides=1):
        """Build one stage of ``Bottleneck`` blocks and advance ``self.inplanes``.

        Only the first block uses ``strides``. It also gets a 1x1 conv + BN
        projection on its shortcut when the stride or channel count changes.

        Args:
            planes: bottleneck width; the stage outputs ``planes * Bottleneck.expansion`` channels.
            blocks: number of blocks in the stage (at least one block is always built).
            strides: stride of the first block.

        Returns:
            ``nn.Sequential`` containing the stage's blocks.
        """
        out_planes = Bottleneck.expansion * planes
        downsample = None
        if strides != 1 or self.inplanes != out_planes:
            downsample = nn.Sequential(
                nn.Conv2d(self.inplanes, out_planes, 1, strides, bias=False),
                nn.BatchNorm2d(out_planes),
            )
        layers = [Bottleneck(self.inplanes, planes, strides, downsample)]
        self.inplanes = out_planes
        for _ in range(1, blocks):
            layers.append(Bottleneck(self.inplanes, planes))
        return nn.Sequential(*layers)

    def _upsample_add(self, x, y):
        """Bilinearly resize ``x`` to ``y``'s spatial size and add ``y``.

        Uses ``F.interpolate`` rather than the deprecated ``F.upsample``.
        ``align_corners=False`` is the value ``F.upsample`` fell back to when it was
        left unset, so results are unchanged and no warnings are emitted.
        """
        height, width = y.shape[-2:]
        upsampled = F.interpolate(x, size=(height, width), mode='bilinear', align_corners=False)
        return upsampled + y

    def forward(self, x):
        """Run the backbone and the top-down pathway.

        Args:
            x: image batch with 3 channels (``conv1`` is ``Conv2d(3, ...)``).

        Returns:
            Tuple ``(p2, p3, p4, p5)`` of 256-channel feature maps, finest first.
        """
        # Bottom-up pathway.
        c1 = self.maxpooling(self.relu(self.bn1(self.conv1(x))))
        c2 = self.layer1(c1)
        c3 = self.layer2(c2)
        c4 = self.layer3(c3)
        c5 = self.layer4(c4)
        # Top-down pathway with lateral connections.
        p5 = self.toplayer(c5)
        p4 = self._upsample_add(p5, self.latlayer1(c4))
        p3 = self._upsample_add(p4, self.latlayer2(c3))
        p2 = self._upsample_add(p3, self.latlayer3(c2))
        # Smoothing runs after all merges, so each finer level is built from the
        # *unsmoothed* coarser map.
        p4 = self.smooth(p4)
        p3 = self.smooth(p3)
        p2 = self.smooth(p2)
        return p2, p3, p4, p5

In [34]:
# Four Bottleneck stages with depths 3, 4, 6, 3 (the ResNet-50 layout).
fpn = FPN(layers=[3, 4, 6, 3])

In [35]:
# Stem convolution of the backbone.
fpn.conv1

Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)

In [36]:
# First bottom-up stage (produces c2): a Sequential of Bottleneck blocks.
fpn.layer1

Sequential(
  (0): Bottleneck(
    (bottleneck): Sequential(
      (0): Conv2d(64, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (2): ReLU(inplace=True)
      (3): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (4): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (5): ReLU(inplace=True)
      (6): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (7): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    )
    (relu): ReLU(inplace=True)
    (downsample): Sequential(
      (0): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    )
  )
  (1): Bottleneck(
    (bottleneck): Sequential(
      (0): Conv2d(256, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)
   

In [37]:
# 1x1 conv that maps c5 to the 256-channel p5.
fpn.toplayer

Conv2d(2048, 256, kernel_size=(1, 1), stride=(1, 1))

In [38]:
# Seed the RNG so the dummy batch (one 3-channel 224x224 image) is identical on every run.
torch.manual_seed(0)
data = torch.randn(1, 3, 224, 224)

In [40]:
# Shape inspection only:
# - eval() stops BatchNorm from updating its running statistics on this random input.
# - no_grad() skips building an autograd graph through the whole network.
fpn.eval()
with torch.no_grad():
    out = fpn(data)

In [43]:
# p2: finest pyramid level, 256 channels at 1/4 of the input resolution
out[0].shape

torch.Size([1, 256, 56, 56])

In [44]:
# p3: 256 channels at 1/8 of the input resolution
out[1].shape

torch.Size([1, 256, 28, 28])

In [45]:
# p4: 256 channels at 1/16 of the input resolution
out[2].shape

torch.Size([1, 256, 14, 14])

In [46]:
# p5: coarsest pyramid level, 256 channels at 1/32 of the input resolution
out[3].shape

torch.Size([1, 256, 7, 7])