In [1]:
import numpy as np
import torch
from torch import nn
import torchvision
from torchvision.models import ResNet50_Weights

In [2]:
class BackBone(nn.Module):
    """ResNet-50 feature extractor that exposes the outputs of its four residual stages.

    The wrapped torchvision model is created with ``ResNet50_Weights.DEFAULT``.
    ``forward`` never calls the model's own ``forward``; it only runs the stem
    and ``layer1`` .. ``layer4``.
    """

    # TODO: later add a config parameter to choose between different architectures
    def __init__(self):
        super().__init__()
        self.model = torchvision.models.resnet50(weights=ResNet50_Weights.DEFAULT)

    def forward(self, x):
        """Return the outputs of ``layer1`` .. ``layer4`` as a 4-tuple, shallowest first."""
        resnet = self.model

        # Stem: convolution, batch norm, ReLU and max-pool, applied in that order.
        for stem_op in (resnet.conv1, resnet.bn1, resnet.relu, resnet.maxpool):
            x = stem_op(x)

        # Every intermediate output is kept because the FPN needs all four levels.
        pyramid = []
        for stage in (resnet.layer1, resnet.layer2, resnet.layer3, resnet.layer4):
            x = stage(x)
            pyramid.append(x)

        x1, x2, x3, x4 = pyramid
        return x1, x2, x3, x4

##### We can see that the output has 1000 values per image (shape `(1, 1000)` for our single image) due to the last fully connected layer. We don't want that, because we only want to use ResNet's feature-extraction power, so let's modify our forward function.

In [3]:
from torchsummary import summary

m = BackBone()

# Tip: `summary(m, (3, 256, 256))` prints a per-layer table if one is needed.
t = torch.rand(1, 3, 256, 256)
print(t.shape)

# This is only a shape check. Run it in eval mode so the batch-norm layers do
# not fold the statistics of this random tensor into the pretrained running
# averages, then put the model back in whatever mode it was in.
was_training = m.training
m.eval()
try:
    with torch.no_grad():
        x1, x2, x3, x4 = m(t)
finally:
    m.train(was_training)

print('x1.shape:', x1.shape, 'x2.shape:', x2.shape, 'x3.shape:', x3.shape, 'x4.shape:', x4.shape)

torch.Size([1, 3, 256, 256])
x1.shape: torch.Size([1, 256, 64, 64]) x2.shape: torch.Size([1, 512, 32, 32]) x3.shape: torch.Size([1, 1024, 16, 16]) x4.shape: torch.Size([1, 2048, 8, 8])


##### Good: the deepest output is now an 8x8 feature map with 2048 channels. This is what we want, since the downstream modules are convolutional too.

In [4]:
# tests

# Scratch check of the (in_channels, out_channels) pairs for four pyramid
# levels: the input width starts at 2048 and halves at every level.
lst = [(2048 >> level, 512) for level in range(4)]

print(lst)

[(2048, 512), (1024, 512), (512, 512), (256, 512)]


In [5]:
# Normalisation layer used by both Conv1x1 and Conv3x3 below; rebinding this
# alias swaps the layer in both blocks at once.
Norm = nn.BatchNorm2d


class Conv1x1(nn.Module):
    """Bias-free 1x1 convolution followed by ``Norm`` and an in-place ReLU."""

    def __init__(self, num_in, num_out):
        super().__init__()
        pointwise = nn.Conv2d(in_channels=num_in, out_channels=num_out, kernel_size=1, bias=False)
        # Each layer stays reachable by name (conv / norm / active), while
        # `block` chains the very same objects for the forward pass.
        self.conv, self.norm, self.active = pointwise, Norm(num_out), nn.ReLU(inplace=True)
        self.block = nn.Sequential(self.conv, self.norm, self.active)

    def forward(self, x):
        """Apply convolution, normalisation and activation to ``x``."""
        activated = self.block(x)
        return activated


class Conv3x3(nn.Module):
    """Bias-free 3x3 convolution (padding 1) followed by ``Norm`` and an in-place ReLU."""

    def __init__(self, num_in, num_out):
        super().__init__()
        # padding=1 with a 3x3 kernel and the default stride keeps height and width unchanged.
        self.conv = nn.Conv2d(
            in_channels=num_in,
            out_channels=num_out,
            kernel_size=3,
            padding=1,
            bias=False,
        )
        self.norm = Norm(num_out)
        self.active = nn.ReLU(inplace=True)
        # `block` chains the same three objects registered above.
        stages = (self.conv, self.norm, self.active)
        self.block = nn.Sequential(*stages)

    def forward(self, x):
        """Apply convolution, normalisation and activation to ``x``."""
        activated = self.block(x)
        return activated


class FPN(nn.Module):
    """Feature pyramid: 1x1 lateral projections, a top-down pass, then a bottom-up pass.

    Args:
        inplanes: Channel count of the deepest input feature map. Level ``c``
            (0 = deepest) is expected to carry ``inplanes // 2**c`` channels.
        outplanes: Channel count produced by every lateral projection.

    Attributes:
        out_channels: Channel count of the tensor returned by ``forward``
            (``4 * outplanes``: the bottom-up pass concatenates one
            ``outplanes``-wide slice per pyramid level).
    """

    def __init__(self, inplanes = 2048, outplanes = 256):
        super().__init__()

        # ModuleList rather than Sequential: these layers are indexed one by one,
        # never chained. Parameter names (`laterals.0...`, `smooths.0...`) are unchanged.
        self.laterals = nn.ModuleList([Conv1x1(inplanes // (2 ** c), outplanes) for c in range(4)])
        # smooths[k] processes the concatenation of k + 1 levels, hence (k + 1) * outplanes channels.
        self.smooths = nn.ModuleList([Conv3x3(outplanes * c, outplanes * c) for c in range(1, 5)])
        self.pooling = nn.MaxPool2d(2)

        self.out_channels = outplanes * 4  # the bottom-up pass stacks 4 levels

    @staticmethod
    def _upsample_like(x, ref):
        """Nearest-neighbour upsample ``x`` to the height/width of ``ref``.

        Targeting the reference size instead of a fixed factor of 2 gives the
        same result when ``ref`` is exactly twice as large, and also works
        when it is not (odd-sized feature maps).
        """
        return nn.functional.interpolate(x, size=tuple(ref.shape[-2:]), mode="nearest")

    def _downsample_like(self, x, ref):
        """Max-pool ``x`` down to the height/width of ``ref``."""
        pooled = self.pooling(x)
        if pooled.shape[-2:] != ref.shape[-2:]:
            # 2x2 pooling floors odd sizes, so the result can be one pixel short
            # of the reference map; pool straight to the reference size instead.
            pooled = nn.functional.adaptive_max_pool2d(x, tuple(ref.shape[-2:]))
        return pooled

    def forward(self, features):
        """Fuse four feature maps into one tensor at the coarsest resolution.

        Args:
            features: Indexable collection of four feature maps ordered from
                deepest (fewest pixels, ``inplanes`` channels) to shallowest,
                i.e. the reverse of the order ``BackBone.forward`` returns.

        Returns:
            Tensor with ``out_channels`` channels and the spatial size of
            ``features[0]``.
        """
        laterals = [lateral(features[f]) for f, lateral in enumerate(self.laterals)]

        # Top-down pass: each finer level adds the upsampled coarser level.
        map4 = laterals[0]
        map3 = laterals[1] + self._upsample_like(map4, laterals[1])
        map2 = laterals[2] + self._upsample_like(map3, laterals[2])
        map1 = laterals[3] + self._upsample_like(map2, laterals[3])

        # Bottom-up pass: each coarser level is concatenated with the pooled
        # finer level, so the channel count grows by `outplanes` at every step.
        map1 = self.smooths[0](map1)
        map2 = self.smooths[1](torch.cat([map2, self._downsample_like(map1, map2)], dim=1))
        map3 = self.smooths[2](torch.cat([map3, self._downsample_like(map2, map3)], dim=1))
        map4 = self.smooths[3](torch.cat([map4, self._downsample_like(map3, map4)], dim=1))

        return map4

In [6]:
fpn_test = FPN(2048)

# The FPN expects the deepest map first, so the backbone outputs go in reversed.
# This is only a shape check, so skip autograd bookkeeping as the backbone check does.
with torch.no_grad():
    out = fpn_test([x4, x3, x2, x1])

# The FPN returns its coarsest level: `out_channels` channels at x4's resolution.
assert out.shape == (x4.shape[0], fpn_test.out_channels, *x4.shape[-2:]), out.shape

print('out shape after FPN forward', out.shape)

out shape after FPN forward torch.Size([1, 1024, 8, 8])


In [7]:


class Decoder(nn.Module):
    """Upsampling stack of stride-2 transposed convolutions.

    Every stage is ``ConvTranspose2d -> BatchNorm2d -> ReLU``. With kernel
    sizes 2, 3 or 4 the padding rules below make each stage exactly double
    the height and width.

    Args:
        inplanes: Channel count of the input tensor.
        bn_momentum: Momentum passed to every ``BatchNorm2d``.
        num_filters: Output channels of each stage, in order. Its length sets
            the number of stages (default: five stages of 256 channels).
        num_kernels: Kernel size of each stage's transposed convolution; must
            have the same length as ``num_filters``.

    Attributes:
        inplanes: Updated while the stages are built; after construction it
            holds the channel count of the last stage, not the constructor
            argument.
        out_channels: Channel count of the tensor returned by ``forward``.

    Raises:
        ValueError: If ``num_filters`` and ``num_kernels`` differ in length.
    """

    def __init__(self, inplanes, bn_momentum=0.1,
                 num_filters=(256, 256, 256, 256, 256),
                 num_kernels=(4, 4, 4, 4, 4)):
        super().__init__()
        self.bn_momentum = bn_momentum
        # Channels of the tensor fed to the first stage: [b, inplanes, _h, _w]
        self.inplanes = inplanes
        self.deconv_with_bias = False
        self.deconv_layers = self._make_deconv_layer(
            num_layers=len(num_filters),
            num_filters=list(num_filters),
            num_kernels=list(num_kernels),
        )
        self.out_channels = self.inplanes

    def _make_deconv_layer(self, num_layers, num_filters, num_kernels):
        """Build ``num_layers`` upsampling stages and return them as one ``nn.Sequential``.

        ``self.inplanes`` is advanced to each stage's output width so that the
        next stage is wired to it.
        """
        if num_layers != len(num_filters) or num_layers != len(num_kernels):
            raise ValueError(
                'num_layers ({}) must match len(num_filters) ({}) and len(num_kernels) ({})'.format(
                    num_layers, len(num_filters), len(num_kernels)))

        layers = []
        for planes, kernel in zip(num_filters, num_kernels):
            # With stride 2 these choices give out = 2 * in for kernels 2, 3 and 4.
            padding = 0 if kernel == 2 else 1
            output_padding = 1 if kernel == 3 else 0
            layers.append(
                nn.ConvTranspose2d(
                    in_channels=self.inplanes,
                    out_channels=planes,
                    kernel_size=kernel,
                    stride=2,
                    padding=padding,
                    output_padding=output_padding,
                    bias=self.deconv_with_bias))
            layers.append(nn.BatchNorm2d(planes, momentum=self.bn_momentum))
            layers.append(nn.ReLU(inplace=True))
            self.inplanes = planes
        return nn.Sequential(*layers)

    def forward(self, x):
        """Run ``x`` through every upsampling stage."""
        return self.deconv_layers(x)

In [8]:
# `out` must still be the FPN output here. This cell overwrites `out`, so running
# it twice in a row would feed the decoder its own, much larger, output.
assert out.shape[1] == fpn_test.out_channels, 'stale `out`: re-run the FPN cell before this one'

decoder_test = Decoder(inplanes=fpn_test.out_channels)

# This is only a shape check, so no autograd graph is needed.
with torch.no_grad():
    out = decoder_test(out)

print('out shape after decoder forward', out.shape)

out shape after decoder forward torch.Size([1, 256, 256, 256])


In [9]:
class Heads(nn.Module):
    """Three parallel convolutional heads: heat maps, offsets and sizes.

    Every head is a 3x3 stride-2 convolution, a ReLU and a 1x1 stride-2
    convolution, so each spatial dimension shrinks by a factor of 4 when it
    is divisible by 4.

    Args:
        nclasses: Number of channels of the heat-map output.
        in_channels: Channel count of the input tensor.
    """

    def __init__(self, nclasses=53, in_channels = 256):
        super().__init__()

        self.nclasses = nclasses

        # The heat-map head narrows to 64 hidden channels; the two regression
        # heads keep the input width and each emit 2 channels.
        self.heat_maps = self._branch(in_channels, 64, self.nclasses)
        self.offset_maps = self._branch(in_channels, in_channels, 2)
        self.size_maps = self._branch(in_channels, in_channels, 2)

    @staticmethod
    def _branch(in_channels, hidden_channels, out_channels):
        """Build one head: 3x3/stride-2 conv -> ReLU -> 1x1/stride-2 conv."""
        reduce = nn.Conv2d(in_channels, hidden_channels, kernel_size=3, stride=2, padding=1, bias=True)
        project = nn.Conv2d(hidden_channels, out_channels, kernel_size=1, stride=2, padding=0, bias=True)
        return nn.Sequential(reduce, nn.ReLU(inplace=True), project)

    def forward(self, x):
        """Return ``(heat, offset, size)`` computed from the same input ``x``."""
        return self.heat_maps(x), self.offset_maps(x), self.size_maps(x)
    

In [10]:
heads_test = Heads()

# This is only a shape check, so no autograd graph is needed.
with torch.no_grad():
    heat, offset, size = heads_test(out)

# One heat-map channel per class, two channels for each regression head, and
# all three maps on the same spatial grid.
assert heat.shape[1] == heads_test.nclasses, heat.shape
assert offset.shape[1] == 2 and size.shape[1] == 2, (offset.shape, size.shape)
assert heat.shape[-2:] == offset.shape[-2:] == size.shape[-2:]

print('predicted heatmaps shape:', heat.shape)
print('predicted offset shape:', offset.shape)
print('predicted size shape:', size.shape)

predicted heatmaps shape: torch.Size([1, 53, 64, 64])
predicted offset shape: torch.Size([1, 2, 64, 64])
predicted size shape: torch.Size([1, 2, 64, 64])
