In [1]:
####

In [2]:
import torch
from torch import nn

In [3]:
# Run on the GPU when CUDA is available, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')

In [4]:
class Conv(nn.Module):
    """2-D convolution followed by optional normalisation, activation and pooling.

    The stages run in the fixed order conv -> norm -> activation -> pool. Each
    optional stage is only created (and only applied) when its flag is truthy.

    Args:
        in_channels: channel count passed to ``nn.Conv2d`` as ``in_channels``.
        out_channels: channel count produced by the convolution.
        kernel_size: forwarded unchanged to ``nn.Conv2d``.
        stride: forwarded unchanged to ``nn.Conv2d``.
        padding: forwarded unchanged to ``nn.Conv2d``.
        use_norm: add ``nn.InstanceNorm2d(out_channels)`` after the convolution.
        use_activation: add ``nn.LeakyReLU(0.2)``.
        use_pool: add ``nn.MaxPool2d(kernel_size=2, stride=2)``.
    """

    def __init__(self,
                 in_channels,
                 out_channels,
                 kernel_size=(3, 3),
                 stride=(1, 1),
                 padding=1,
                 use_norm=True,
                 use_activation=True,
                 use_pool=False):
        super().__init__()

        # The flags are kept on the instance because forward() consults them.
        self.use_norm, self.use_activation, self.use_pool = use_norm, use_activation, use_pool

        self.conv1 = nn.Conv2d(in_channels=in_channels,
                               out_channels=out_channels,
                               kernel_size=kernel_size,
                               stride=stride,
                               padding=padding)

        # Optional stages: the attribute exists only when the stage is enabled.
        if self.use_norm:
            self.norm = nn.InstanceNorm2d(out_channels)
        if self.use_activation:
            self.activation = nn.LeakyReLU(0.2)
        if self.use_pool:
            self.maxpool = nn.MaxPool2d(kernel_size=2, stride=2)

    def forward(self, x):
        """Apply the convolution and then every enabled optional stage to ``x``."""
        out = self.conv1(x)
        out = self.norm(out) if self.use_norm else out
        out = self.activation(out) if self.use_activation else out
        out = self.maxpool(out) if self.use_pool else out
        return out

In [5]:
# Smoke test for Conv: push a random batch of two 3x512x512 tensors through a
# 3 -> 32 channel block with pooling enabled and display the output shape.
# The recorded output shows the 2x2 max-pool halving 512 to 256.
x = torch.randn(2 , 3 , 512 , 512).to(device)
conv = Conv(3 , 32 , use_pool=True).to(device)
z = conv(x)
z.shape

  return torch.max_pool2d(input, kernel_size, stride, padding, dilation, ceil_mode)


torch.Size([2, 32, 256, 256])

In [6]:
class ConvT(nn.Module):
    """Transposed 2-D convolution with optional normalisation and activation.

    The stages run in the order transposed-conv -> norm -> activation. Each
    optional stage is only created (and only applied) when its flag is truthy.

    Args:
        in_channels: channel count passed to ``nn.ConvTranspose2d`` as ``in_channels``.
        out_channels: channel count produced by the transposed convolution.
        kernel_size: forwarded unchanged to ``nn.ConvTranspose2d``.
        stride: forwarded unchanged to ``nn.ConvTranspose2d``.
        padding: forwarded unchanged to ``nn.ConvTranspose2d``.
        use_norm: add ``nn.InstanceNorm2d(out_channels)``.
        use_activation: add ``nn.LeakyReLU(0.2)``.
    """

    def __init__(self,
                 in_channels,
                 out_channels,
                 kernel_size=(2, 2),
                 stride=(2, 2),
                 padding=0,
                 use_norm=True,
                 use_activation=True):
        super().__init__()

        # The flags are kept on the instance because forward() consults them.
        self.use_norm, self.use_activation = use_norm, use_activation

        self.convT1 = nn.ConvTranspose2d(in_channels=in_channels,
                                         out_channels=out_channels,
                                         kernel_size=kernel_size,
                                         stride=stride,
                                         padding=padding)

        # Optional stages: the attribute exists only when the stage is enabled.
        if self.use_norm:
            self.norm = nn.InstanceNorm2d(out_channels)
        if self.use_activation:
            self.activation = nn.LeakyReLU(0.2)

    def forward(self, x):
        """Apply the transposed convolution and then every enabled optional stage."""
        out = self.convT1(x)
        out = self.norm(out) if self.use_norm else out
        out = self.activation(out) if self.use_activation else out
        return out

In [7]:
# Smoke test for ConvT: with the default 2x2 kernel and stride 2, the recorded
# output shows the spatial size doubling from 512 to 1024.
x = torch.randn(2 , 3 , 512 , 512).to(device)
convT = ConvT(3 , 32).to(device)
z = convT(x)
z.shape

torch.Size([2, 32, 1024, 1024])

In [8]:
class Linear(nn.Module):
    """Fully connected layer with optional batch normalisation and activation.

    The stages run in the order linear -> norm -> activation. Each optional
    stage is only created (and only applied) when its flag is truthy.

    Args:
        in_channels: passed to ``nn.Linear`` as ``in_features``.
        out_channels: passed to ``nn.Linear`` as ``out_features``.
        use_norm: add ``nn.BatchNorm1d(out_channels)``.
        use_activation: add ``nn.LeakyReLU(0.2)``.
    """

    def __init__(self,
                 in_channels,
                 out_channels,
                 use_norm=True,
                 use_activation=True):
        super().__init__()

        # The flags are kept on the instance because forward() consults them.
        self.use_norm, self.use_activation = use_norm, use_activation

        self.linear1 = nn.Linear(in_features=in_channels,
                                 out_features=out_channels)

        # Optional stages: the attribute exists only when the stage is enabled.
        if self.use_norm:
            self.norm = nn.BatchNorm1d(out_channels)
        if self.use_activation:
            self.activation = nn.LeakyReLU(0.2)

    def forward(self, x):
        """Apply the linear layer and then every enabled optional stage."""
        out = self.linear1(x)
        out = self.norm(out) if self.use_norm else out
        out = self.activation(out) if self.use_activation else out
        return out

In [9]:
# Smoke test for Linear: map a batch of two 512-dim vectors to 256 features
# and display the output shape (recorded as [2, 256]).
x = torch.randn(2 , 512).to(device)
linear = Linear(512 , 256).to(device)
z = linear(x)
z.shape

torch.Size([2, 256])

In [11]:
class M_Block(nn.Module):
    """Two-headed block: a shared Conv followed by class and bbox 1x1 Conv heads.

    The input first goes through ``conv1`` (a Conv block that keeps the channel
    count). The result is fed independently to ``conv_cls`` and ``conv_bbox``.
    Both heads are Conv blocks built with Conv's default flags, so each also
    carries that class's default normalisation and activation stages.

    Args:
        in_channels: channel count of the input (and of the shared Conv output).
        out_channels_cls: channel count produced by the class head.
        out_channels_bbox: channel count produced by the bbox head.
    """

    def __init__(self,
                 in_channels,
                 out_channels_cls=20,
                 out_channels_bbox=4):
        super().__init__()

        # Both heads use the same 1x1, stride-1, unpadded convolution settings.
        pointwise = dict(kernel_size=(1, 1), stride=(1, 1), padding=0)

        self.conv1 = Conv(in_channels, in_channels)
        self.conv_cls = Conv(in_channels, out_channels_cls, **pointwise)
        self.conv_bbox = Conv(in_channels, out_channels_bbox, **pointwise)

    def forward(self, x):
        """Return ``(x_cls, x_bbox)`` computed from the shared Conv features."""
        shared = self.conv1(x)
        return self.conv_cls(shared), self.conv_bbox(shared)

In [12]:
# Smoke test for M_Block: the block returns a (class, bbox) pair of tensors;
# display both shapes. No output was recorded for this cell in this export.
x = torch.randn(2 , 16 , 128 , 128).to(device)
m = M_Block(16).to(device)
x_cls , z_bbox = m(x)
x_cls.shape , z_bbox.shape

In [13]:
# Layer specification consumed by FPN, read in order. Entry kinds:
#   list -> Conv block:
#           [out_channels, kernel_size, stride, padding, use_norm, use_activation, use_pool]
#   int  -> M_Block head built with that many input channels
#   'U'  -> nn.Upsample(scale_factor=2)
#
# Shape annotations (channels x height x width) assume a 3 x 512 x 512 input:
#   pooled 3x3 stage : 16x256x256 -> 32x128x128 -> 64x64x64 -> 128x32x32 -> 256x16x16
#   1x1 / head / 'U' : 128x16x16 -> U -> 128x32x32
#                      64x32x32  -> U -> 64x64x64
#                      32x64x64  -> U -> 32x128x128
#                      16x128x128 -> U -> 16x256x256
config = [
    # 3x3 conv + norm + activation + 2x2 max-pool, one entry per width.
    [channels, (3, 3), (1, 1), 1, True, True, True]
    for channels in (16, 32, 64, 128, 256)
] + [
    # For each width: 1x1 conv (no pooling), an M_Block head, then an upsample.
    entry
    for channels in (128, 64, 32, 16)
    for entry in ([channels, (1, 1), (1, 1), 0, True, True, False], channels, 'U')
]

In [24]:
class FPN(nn.Module):
    """Sequential network of Conv, Upsample and M_Block layers built from a config list.

    Config entries are processed in order:
      * a list of seven items
        ``[out_channels, kernel_size, stride, padding, use_norm, use_activation, use_pool]``
        becomes a ``Conv`` whose input width is the previous Conv's output width
        (``in_channels`` for the first one);
      * the string ``'U'`` becomes ``nn.Upsample(scale_factor=2)``; any other
        string is ignored;
      * an int becomes ``M_Block(<that int>)``. The value is used as given and
        is not checked against the running channel count.

    Args:
        in_channels: channel count of the network input.
        config: layer specification; defaults to the module-level ``config``
            that exists when the class is defined.
    """

    def __init__(self,
                 in_channels,
                 config=config):
        super().__init__()

        self.layers = nn.ModuleList()

        channels = in_channels  # width fed to the next Conv
        for spec in config:
            if isinstance(spec, list):
                out_ch, kernel, step, pad, norm_on, act_on, pool_on = spec
                self.layers.append(
                    Conv(channels, out_ch, kernel, step, pad, norm_on, act_on, pool_on)
                )
                channels = out_ch
            elif isinstance(spec, str) and spec == 'U':
                self.layers.append(nn.Upsample(scale_factor=2))
            elif isinstance(spec, int):
                self.layers.append(M_Block(spec))

    def forward(self, x):
        """Run ``x`` through the layers; return one ``(cls, bbox)`` tuple per M_Block."""
        saved = []  # outputs of every Conv except layer 0, used as a LIFO stack
        heads = []  # M_Block results, in layer order

        for idx, module in enumerate(self.layers):
            if isinstance(module, M_Block):
                heads.append(module(x))
                # NOTE(review): in the default config every M_Block directly follows a
                # Conv whose output was just pushed, so the popped tensor is `x` itself
                # and this addition only doubles `x` instead of merging an earlier
                # feature map. Confirm whether an encoder skip connection was intended.
                x = x + saved.pop()
                continue

            if not isinstance(module, (Conv, nn.Upsample)):
                continue

            x = module(x)
            if idx != 0 and isinstance(module, Conv):
                saved.append(x)

        return heads

In [None]:
# Smoke test for FPN with the default config. forward() returns one
# (cls, bbox) tuple per M_Block, so len(z) counts the heads (the default
# config holds four int entries). This cell has no execution count or
# recorded output in this export.
x = torch.randn(2 , 3 , 512 , 512).to(device)
fpn = FPN(3).to(device)
z = fpn(x)
len(z)