In [2]:
from model import *
from fastai.vision import *
from fastai.vision.models.unet import _get_sfs_idxs, model_sizes, hook_outputs


In [3]:
# Build the backbone with the star-imported `ResNet` class (presumably from `model` — confirm),
# selecting the "resnet101" architecture string with stage5 enabled.
resnet = ResNet("resnet101", stage5=True)

In [4]:
# Sanity check: report whether PyTorch can see a CUDA device in this kernel.
torch.cuda.is_available()

True

In [5]:
# Unpack the five values returned by `resnet.stages()`; they are used below as the C1..C5 backbone stages.
C1, C2, C3, C4, C5 = resnet.stages()

In [6]:
# Reference FPN (star-imported, presumably from `model` — confirm) built on the five stages, 256 output channels.
fpn = FPN(C1, C2, C3, C4, C5, out_channels=256)

In [7]:
# Random dummy input of shape (5, 3, 224, 224) used to exercise every FPN variant in this notebook.
x = torch.randn(5,3,224,224)

In [8]:
# Forward pass through the reference FPN; `res` is iterated below to print one shape per output.
res = fpn(x)



In [9]:
# Print the shape of every output of the reference FPN.
# A plain loop is used instead of a list comprehension so the cell does not
# build a throwaway list and echo `[None, None, ...]` as its result.
for output in res:
    print(output.shape)

torch.Size([5, 256, 56, 56])
torch.Size([5, 256, 28, 28])
torch.Size([5, 256, 14, 14])
torch.Size([5, 256, 7, 7])
torch.Size([5, 256, 4, 4])


[None, None, None, None, None]

In [10]:
class FPN2(nn.Module):
    """Feature-pyramid head built on top of five sequential backbone stages.

    The input is pushed through C1..C5 in order. The outputs of C2..C5 are each
    projected to `out_channels` with a 1x1 conv, merged top-down (the coarser
    map is upsampled and added to the finer one), and finally passed through a
    3x3 conv. A sixth level is produced by subsampling P5 with stride 2.

    Args:
        C1, C2, C3, C4, C5: backbone stage modules, applied one after another.
            The lateral 1x1 convs are hard-coded to expect 256 / 512 / 1024 /
            2048 channels out of C2 / C3 / C4 / C5 respectively.
        out_channels: number of channels of every returned pyramid level.

    Returns (from `forward`):
        list `[p2_out, p3_out, p4_out, p5_out, p6_out]` of tensors.
    """

    def __init__(self, C1, C2, C3, C4, C5, out_channels):
        super(FPN2, self).__init__()
        self.out_channels = out_channels
        self.C1 = C1
        self.C2 = C2
        self.C3 = C3
        self.C4 = C4
        self.C5 = C5
        # Kernel size 1 with stride 2 keeps every other pixel: pure subsampling.
        self.P6 = nn.MaxPool2d(kernel_size=1, stride=2)
        # For each level: `_conv1` is the lateral 1x1 projection, `_conv2` the
        # 3x3 conv applied after the top-down merge (padded via SamePad2d).
        self.P5_conv1 = nn.Conv2d(2048, self.out_channels, kernel_size=1, stride=1)
        self.P5_conv2 = nn.Sequential(
            SamePad2d(kernel_size=3, stride=1),
            nn.Conv2d(self.out_channels, self.out_channels, kernel_size=3, stride=1),
        )
        self.P4_conv1 =  nn.Conv2d(1024, self.out_channels, kernel_size=1, stride=1)
        self.P4_conv2 = nn.Sequential(
            SamePad2d(kernel_size=3, stride=1),
            nn.Conv2d(self.out_channels, self.out_channels, kernel_size=3, stride=1),
        )
        self.P3_conv1 = nn.Conv2d(512, self.out_channels, kernel_size=1, stride=1)
        self.P3_conv2 = nn.Sequential(
            SamePad2d(kernel_size=3, stride=1),
            nn.Conv2d(self.out_channels, self.out_channels, kernel_size=3, stride=1),
        )
        self.P2_conv1 = nn.Conv2d(256, self.out_channels, kernel_size=1, stride=1)
        self.P2_conv2 = nn.Sequential(
            SamePad2d(kernel_size=3, stride=1),
            nn.Conv2d(self.out_channels, self.out_channels, kernel_size=3, stride=1),
        )

    @staticmethod
    def _merge(lateral, top_down):
        """Upsample `top_down` to the spatial size of `lateral` and add them.

        Resizing to the lateral map's actual size (instead of a fixed factor of
        2) gives the same result when the sizes differ by exactly 2x, and keeps
        the addition valid when a stage produced an odd-sized map.
        """
        return lateral + F.interpolate(top_down, size=lateral.shape[-2:], mode="nearest")

    def forward(self, x):
        # Bottom-up pass; C1's output is only an intermediate.
        x = self.C1(x)
        c2_out = self.C2(x)
        c3_out = self.C3(c2_out)
        c4_out = self.C4(c3_out)
        c5_out = self.C5(c4_out)

        # Top-down pass with lateral connections.
        p5_out = self.P5_conv1(c5_out)
        p4_out = self._merge(self.P4_conv1(c4_out), p5_out)
        p3_out = self._merge(self.P3_conv1(c3_out), p4_out)
        p2_out = self._merge(self.P2_conv1(c2_out), p3_out)

        # 3x3 conv on every merged map.
        p5_out = self.P5_conv2(p5_out)
        p4_out = self.P4_conv2(p4_out)
        p3_out = self.P3_conv2(p3_out)
        p2_out = self.P2_conv2(p2_out)

        # P6 is used for the 5th anchor scale in RPN. Generated by
        # subsampling from P5 with stride of 2.
        p6_out = self.P6(p5_out)

        return [p2_out, p3_out, p4_out, p5_out, p6_out]

In [11]:
# Instantiate the FPN2 class defined in the previous cell on the same backbone stages, 256 output channels.
fpn2 = FPN2(C1, C2, C3, C4, C5, out_channels=256)

In [12]:
# Run FPN2 on the dummy batch and print the shape of each returned level.
# A plain loop avoids the side-effect list comprehension, which would echo
# `[None, None, ...]` as the cell result.
res2 = fpn2(x)
for output in res2:
    print(output.shape)



torch.Size([5, 256, 56, 56])
torch.Size([5, 256, 28, 28])
torch.Size([5, 256, 14, 14])
torch.Size([5, 256, 7, 7])
torch.Size([5, 256, 4, 4])


[None, None, None, None, None]

In [15]:
class LateralUpsampleMerge(nn.Module):
    """Merge the features coming from the downsample path (in `hook`) with the upsample path.

    Args:
        ch: number of channels of the upsample path (output channels of the lateral 1x1 conv).
        ch_lat: number of channels of the hooked downsample-path activation.
        hook: object whose `.stored` attribute holds the downsample-path activation
            at the time `forward` is called.
    """
    def __init__(self, ch, ch_lat, hook):
        super().__init__()
        self.hook = hook
        self.Px_conv1 = conv2d(ch_lat, ch, ks=1, bias=True)

    def forward(self, x):
        # 1x1 conv on the features stored by the hook on the downsampling path.
        lateral = self.Px_conv1(self.hook.stored)
        # Upsample the coarser level to the lateral map's actual spatial size.
        # For an exact 2x ratio this equals `scale_factor=2`; unlike a fixed
        # factor it also works when the lateral map has an odd height/width.
        top_down = nn.functional.interpolate(x, size=lateral.shape[-2:], mode="nearest")
        return lateral + top_down

In [16]:
class FPN3(nn.Module):
    """
    Creates upsampling path by hooking activations from the downsampling path. Tested on ResNet50.
    
    INPUT:
        encoder: Indexable backbone (e.g. the `nn.Sequential` body of a ResNet50); required, no default
        chs: Number of intermediate channels to use between convolutions
    
    RETURNS:
        [p2,p3,p4,p5,p6]: [Tensor]
    """
    def __init__(self, encoder:nn.Module, chs:int):
    
        super().__init__()
        
        #This runs dummy data through the encoder to get the right channel numbers after each layer C1 through C5
        imsize = (256,256)
        sfs_szs = model_sizes(encoder, size=imsize)
        sfs_idxs = list(reversed(_get_sfs_idxs(sfs_szs)))
        
        self.encoder = encoder
        
        #The link between C5 and P5
        #TODO: will a stride 2 conv be better?
        self.c5_p5 = nn.Sequential(
            conv2d(sfs_szs[-1][1], chs, ks=1, bias=True),
        )
        
        #NOTE(review): this conv is never called in `forward` (P5 is smoothed by the last
        #entry of `final_convs`); kept only so existing attribute access keeps working.
        self.p5_conv2 = conv2d(chs,chs,ks=3,bias=True)
        
        #Link between P5 and P6: kernel size 1 with stride 2 is plain subsampling
        self.p5_p6 = nn.MaxPool2d(kernel_size=1, stride=2)
        
        #Attaching hooks to the relevant layers so we can get their activations during the
        #upsampling path. These are the idxs of C4, C3, and C2 respectively
        idx  = list(reversed(sfs_idxs[-2:-5:-1]))
        self.sfs = hook_outputs([encoder[i] for i in idx])
        
        #This handles the mapping from P5 -> P4 -> P3 -> P2
        self.merges = nn.ModuleList([LateralUpsampleMerge(chs, sfs_szs[layer_idx][1], hook) 
                                     for layer_idx,hook in zip(idx, self.sfs)])
        
        #One final conv per level (P2..P5) to smoothen things out after the merge.
        #Must be an nn.ModuleList: a plain Python list would not register the convs as
        #submodules, so they would be missing from .parameters()/state_dict() and would
        #not follow .to()/.cuda().
        self.final_convs = nn.ModuleList([conv2d(chs, chs, ks=3, stride=1, bias=True) for _ in idx+[1]])
           
    def forward(self, x):
        #Running the encoder also fills the hooks that the merges read from
        c5 = self.encoder(x)
        p_states = [self.c5_p5(c5.clone())]
        #Mapping P5 through P2 one by one; each new (finer) level is prepended,
        #so the list ends up ordered [P2, P3, P4, P5]
        for merge in self.merges: p_states = [merge(p_states[0])] + p_states
         
        #Extra convs after the lateral upsampling
        p_states = [conv(p) for conv, p in zip(self.final_convs, p_states)]
            
        #Doing P6 at the end, subsampled from the smoothed P5
        p_states.append(self.p5_p6(p_states[-1]))
        return p_states
    
    def __del__(self):
        #Detach the forward hooks from the encoder when this module is garbage collected
        if hasattr(self, "sfs"): self.sfs.remove()

In [17]:
# Build FPN3 on the body returned by `create_body(models.resnet50)`, with 256 channels per pyramid level.
fpn3 = FPN3(create_body(models.resnet50),256)


In [20]:
# Display the encoder's module structure to check which indices the hooks attach to.
fpn3.encoder

Sequential(
  (0): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
  (1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (2): ReLU(inplace)
  (3): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
  (4): Sequential(
    (0): Bottleneck(
      (conv1): Conv2d(64, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv3): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn3): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace)
      (downsample): Sequential(
        (0): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (1

In [19]:
# Run FPN3 on the dummy batch and print the shape of each returned level.
# A plain loop avoids the side-effect list comprehension, which would echo
# `[None, None, ...]` as the cell result.
res3 = fpn3(x)
for output in res3:
    print(output.shape)

torch.Size([5, 256, 56, 56])
torch.Size([5, 256, 28, 28])
torch.Size([5, 256, 14, 14])
torch.Size([5, 256, 7, 7])
torch.Size([5, 256, 4, 4])


[None, None, None, None, None]