In [19]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.data as data_
import os
import numpy as np
import torchvision.transforms as transforms
import torch.optim as optim
from PIL import Image
from torch.autograd import Variable
import ipdb

In [35]:
class BottleNet(nn.Module):
    """Bottleneck residual block: 1x1 conv -> 3x3 conv -> 1x1 conv plus a shortcut.

    Every convolution is bias-free and followed by batch normalisation. The
    result of the three-conv branch is added to the shortcut and passed
    through a final ReLU.

    Args:
        inplane: number of channels of the input tensor.
        plane: number of channels used by the first two convolutions; the
            block emits ``expansion * plane`` channels.
        stride: stride of the 3x3 convolution.
        downsample: optional module applied to the block input to build the
            shortcut branch. Its output must have the same shape as the main
            branch. With ``None`` the unmodified input is used as shortcut.
    """

    # Ratio between the block's output channels and ``plane``.
    expansion = 4

    def __init__(self, inplane, plane, stride=1, downsample=None):
        super(BottleNet, self).__init__()
        out_plane = self.expansion * plane

        self.conv1 = nn.Conv2d(inplane, plane, kernel_size=1, bias=False)
        self.bn1 = nn.BatchNorm2d(plane)
        self.conv2 = nn.Conv2d(plane, plane, kernel_size=3, stride=stride, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(plane)
        self.conv3 = nn.Conv2d(plane, out_plane, kernel_size=1, bias=False)
        self.bn3 = nn.BatchNorm2d(out_plane)
        self.relu = nn.ReLU(inplace=True)

        self.downsample = downsample

    def forward(self, x):
        """Return ``relu(main_branch(x) + shortcut(x))``."""
        out = self.relu(self.bn1(self.conv1(x)))
        out = self.relu(self.bn2(self.conv2(out)))
        out = self.bn3(self.conv3(out))

        # Compare against None explicitly: the truth value of a container
        # module (e.g. nn.Sequential) is derived from its length, so a plain
        # `if self.downsample:` could silently skip a supplied shortcut.
        shortcut = x
        if self.downsample is not None:
            shortcut = self.downsample(x)

        return self.relu(out + shortcut)

class ResNet(nn.Module):
    """Bottleneck ResNet backbone with a top-down feature-pyramid head.

    The input goes through a 7x7 stem and four stages of ``BottleNet``
    blocks. Each stage output is projected to ``out_channel`` channels by a
    1x1 lateral convolution, the maps are merged from the deepest stage
    towards the shallowest one, and every merged map is passed through one
    shared 3x3 smoothing convolution.

    Args:
        res: inner width (``plane``) of the bottleneck blocks of the four
            stages; a stage emits four times that many channels.
        blocks: number of bottleneck blocks in each of the four stages.
        out_channel: number of channels of every returned feature map.
    """

    def __init__(self, res=(64,128,256,512), blocks=(3,4,6,3), out_channel=256):
        super(ResNet, self).__init__()
        expansion = 4  # output-to-inner channel ratio of a BottleNet block
        self.out_channel = out_channel
        self.inplane = res[0]
        self.pre_conv = nn.Sequential(nn.Conv2d(3, res[0], kernel_size=7, stride=2, padding=3, bias=False),
                                      nn.BatchNorm2d(res[0]),
                                      nn.ReLU(inplace=True),
                                      nn.MaxPool2d(kernel_size=3, stride=2, padding=1))
        self.block1 = self._make_layer(res[0], blocks[0])
        self.block2 = self._make_layer(res[1], blocks[1], 2)
        self.block3 = self._make_layer(res[2], blocks[2], 2)
        self.block4 = self._make_layer(res[3], blocks[3], 2)

        # Lateral 1x1 convolutions. Their input widths follow `res` instead of
        # the literals 256/512/1024/2048, so non-default widths also work.
        self.top = nn.Conv2d(res[3] * expansion, out_channel, kernel_size=1)
        self.fpn3 = nn.Conv2d(res[2] * expansion, out_channel, kernel_size=1)
        self.fpn2 = nn.Conv2d(res[1] * expansion, out_channel, kernel_size=1)
        self.fpn1 = nn.Conv2d(res[0] * expansion, out_channel, kernel_size=1)
        self.smooth = nn.Conv2d(out_channel, out_channel, kernel_size=3, stride=1, padding=1)

        # Feature maps produced by the most recent forward() call.
        self.output_layers = []

    def forward(self, x):
        """Compute the pyramid feature maps for a batch of 3-channel images.

        Returns:
            A list of four tensors with ``out_channel`` channels each, ordered
            from the deepest (smallest) map to the shallowest (largest) one.
            The same list is also stored in ``self.output_layers``.
        """
        x = self.pre_conv(x)
        x = self.block1(x)
        f1 = self.fpn1(x)

        x = self.block2(x)
        f2 = self.fpn2(x)

        x = self.block3(x)
        f3 = self.fpn3(x)

        x = self.block4(x)
        x = self.top(x)

        # Top-down pathway: merge each level into the next shallower one.
        o3 = self.fpn_cont(x, f3)
        o2 = self.fpn_cont(o3, f2)
        o1 = self.fpn_cont(o2, f1)

        outputs = [self.smooth(level) for level in (x, o3, o2, o1)]

        # Replace rather than extend: extending made the returned list grow by
        # four entries per call and kept every earlier activation alive.
        self.output_layers = outputs
        return outputs

    def fpn_cont(self, up, down):
        """Upsample ``up`` to the spatial size of ``down`` and add the two maps."""
        # Resize to the lateral map's exact size; a fixed scale_factor=2
        # mismatches whenever a stage produced an odd height or width.
        up = F.interpolate(up, size=down.shape[-2:], mode='bilinear', align_corners=False)
        return up + down

    def _make_layer(self, plane, blocks, stride=1):
        """Build one stage consisting of ``blocks`` BottleNet blocks.

        Only the first block uses ``stride`` and, when the shape changes, a
        projection shortcut; the remaining blocks keep the shape. Updates
        ``self.inplane`` to the stage's output width.
        """
        expansion = 4
        out_plane = plane * expansion

        # Stays None when the first block keeps the input shape (previously
        # the name was unbound in that case).
        downsample = None
        if stride != 1 or self.inplane != out_plane:
            # Linear projection (conv + BN) only: a ReLU here would clamp the
            # shortcut before it is added to the main branch.
            downsample = nn.Sequential(nn.Conv2d(self.inplane, out_plane, kernel_size=1, stride=stride, bias=False),
                                       nn.BatchNorm2d(out_plane))

        layers = [BottleNet(self.inplane, plane, stride, downsample)]
        self.inplane = out_plane
        # The first block is already in the list, so add blocks - 1 more
        # (range(blocks) produced one block too many per stage).
        for _ in range(1, blocks):
            layers.append(BottleNet(self.inplane, plane))
        return nn.Sequential(*layers)

In [37]:
# Smoke test: build the network with its default configuration, feed it a
# single random 3x224x224 image (a batch axis is added in front) and report
# how many feature maps come back.
res = ResNet()
x = torch.rand(3, 224, 224)[None]
output = res(x)
print(len(output))

4


In [38]:
# Show the shape of every feature map returned by the forward pass above.
for feature_map in output:
    print(feature_map.shape)

torch.Size([1, 256, 7, 7])
torch.Size([1, 256, 14, 14])
torch.Size([1, 256, 28, 28])
torch.Size([1, 256, 56, 56])
