In [1]:
####

In [2]:
import torch
from torch import nn

In [3]:
# Run on the GPU when CUDA is available, otherwise fall back to the CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

In [4]:
def Reverse(lst):
    """Return a new list containing the items of ``lst`` in reverse order.

    ``lst`` itself is left untouched; any object accepted by the built-in
    ``reversed`` can be passed in.
    """
    return list(reversed(lst))

In [5]:

class Conv(nn.Module):
    """2-D convolution followed by optional instance norm, max-pool and LeakyReLU.

    Args:
        in_channels: Number of channels of the input tensor.
        out_channels: Number of channels produced by the convolution.
        kernel_size: Forwarded to ``nn.Conv2d`` (default ``(3, 3)``).
        stride: Forwarded to ``nn.Conv2d`` (default ``(1, 1)``).
        padding: Forwarded to ``nn.Conv2d`` (default ``1``).
        use_norm: When true, ``nn.InstanceNorm2d(out_channels)`` is applied.
        use_activation: When true, ``nn.LeakyReLU(0.2)`` is applied.
        use_pool: When true, ``nn.MaxPool2d(kernel_size=2, stride=2)`` is applied.

    The optional sub-modules only exist as attributes when their flag was
    truthy at construction time; the flags are re-read on every forward pass.
    """

    def __init__(self, in_channels, out_channels,
                 kernel_size=(3, 3), stride=(1, 1), padding=1,
                 use_norm=True, use_activation=True, use_pool=False):
        super().__init__()

        self.use_norm = use_norm
        self.use_activation = use_activation
        self.use_pool = use_pool

        self.conv1 = nn.Conv2d(in_channels,
                               out_channels,
                               kernel_size=kernel_size,
                               stride=stride,
                               padding=padding)

        # Optional stages are created lazily so unused ones never show up
        # among the module's children.
        if self.use_norm:
            self.norm = nn.InstanceNorm2d(out_channels)
        if self.use_activation:
            self.activation = nn.LeakyReLU(0.2)
        if self.use_pool:
            self.max_pool = nn.MaxPool2d(kernel_size=2, stride=2)

    def forward(self, x):
        """Apply the convolution, then norm, pool and activation (each if enabled)."""
        out = self.conv1(x)
        out = self.norm(out) if self.use_norm else out
        out = self.max_pool(out) if self.use_pool else out
        out = self.activation(out) if self.use_activation else out
        return out

In [6]:

class ConvT(nn.Module):
    """Transposed 2-D convolution followed by optional instance norm and LeakyReLU.

    Args:
        in_channels: Number of channels of the input tensor.
        out_channels: Number of channels produced by the transposed convolution.
        kernel_size: Forwarded to ``nn.ConvTranspose2d`` (default ``(2, 2)``).
        stride: Forwarded to ``nn.ConvTranspose2d`` (default ``(2, 2)``).
        padding: Forwarded to ``nn.ConvTranspose2d`` (default ``0``).
        use_norm: When true, ``nn.InstanceNorm2d(out_channels)`` is applied.
        use_activation: When true, ``nn.LeakyReLU(0.2)`` is applied.

    The optional sub-modules only exist as attributes when their flag was
    truthy at construction time; the flags are re-read on every forward pass.
    """

    def __init__(self, in_channels, out_channels,
                 kernel_size=(2, 2), stride=(2, 2), padding=0,
                 use_norm=True, use_activation=True):
        super().__init__()

        self.use_norm = use_norm
        self.use_activation = use_activation

        self.convT = nn.ConvTranspose2d(in_channels,
                                        out_channels,
                                        kernel_size=kernel_size,
                                        stride=stride,
                                        padding=padding)

        if self.use_norm:
            self.norm = nn.InstanceNorm2d(out_channels)
        if self.use_activation:
            self.activation = nn.LeakyReLU(0.2)

    def forward(self, x):
        """Apply the transposed convolution, then norm and activation (each if enabled)."""
        out = self.convT(x)
        out = self.norm(out) if self.use_norm else out
        out = self.activation(out) if self.use_activation else out
        return out

In [7]:
class Resnet_Block(nn.Module):
    """Residual block: three stacked ``Conv`` layers summed with a projected skip path.

    Main path: ``conv1`` (in -> in) -> ``conv2`` (3x3, in -> in) -> ``conv3``
    (1x1, in -> out). Skip path: ``conv_skip`` (in -> out) applied to the
    block input. Both paths use ``Conv`` with its default norm/activation
    settings, and the block output is their element-wise sum.

    Args:
        in_channels: Number of channels of the input tensor.
        out_channels: Number of channels of the returned tensor.
        downsample: When truthy, ``conv1`` and ``conv_skip`` use a 2x2 kernel
            with stride 2; otherwise they use a 1x1 kernel with stride 1.
    """

    def __init__(self ,
                 in_channels ,
                 out_channels ,
                 downsample = False):
        super(Resnet_Block , self).__init__()

        self.downsample = downsample

        # The entry conv and the skip projection always share the same
        # kernel/stride pair, so both paths end up with matching spatial size.
        window = (2 , 2) if self.downsample else (1 , 1)

        # Creation order (conv1, conv_skip, conv2, conv3) is kept so that
        # seeded weight initialisation stays reproducible.
        self.conv1 = Conv(in_channels ,
                          in_channels ,
                          kernel_size = window ,
                          stride = window ,
                          padding = 0)

        self.conv_skip = Conv(in_channels ,
                              out_channels ,
                              kernel_size = window ,
                              stride = window ,
                              padding = 0)

        self.conv2 = Conv(in_channels ,
                          in_channels)

        self.conv3 = Conv(in_channels ,
                          out_channels ,
                          kernel_size = (1 , 1) ,
                          stride = (1 , 1) ,
                          padding = 0)

    def forward(self , x):
        """Return ``conv3(conv2(conv1(x))) + conv_skip(x)``."""
        # No layer here modifies its input in place, so both paths can read
        # the same tensor; copying it first would only cost memory and time.
        main = self.conv3(self.conv2(self.conv1(x)))
        shortcut = self.conv_skip(x)
        # Out-of-place add: avoids mutating a tensor tracked by autograd.
        return main + shortcut

In [8]:

class Resnet(nn.Module):
    """Convolutional backbone that returns the feature maps of five stages.

    Layout: a 7x7, stride-2 ``Conv`` stem producing 128 channels, a 2x2
    stride-2 max-pool, then four stacks of ``Resnet_Block`` with 3, 8, 36 and
    3 blocks producing 256, 512, 1024 and 2048 channels. The first stack is
    built without a downsampling block; the other three start with one.

    Args:
        in_channels: Number of channels of the input tensor.
        out_channels: Accepted for interface compatibility; no layer in this
            class currently uses it.
    """

    def __init__(self , in_channels , out_channels):
        super().__init__()

        self.conv1 = Conv(in_channels , 128 , kernel_size = (7 , 7) , stride = (2 , 2) , padding = 3)

        self.conv2 = self._make_repeated_blocks(128 , 256 , 3 , downsample = False)
        self.conv3 = self._make_repeated_blocks(256 , 512 , 8)
        self.conv4 = self._make_repeated_blocks(512 , 1024 , 36)
        self.conv5 = self._make_repeated_blocks(1024 , 2048 , 3)

    def _make_repeated_blocks(self , in_channels , out_channels , repeats , downsample = True):
        """Build ``repeats`` chained ``Resnet_Block`` modules as an ``nn.Sequential``.

        Only the first block maps ``in_channels`` to ``out_channels`` (and is
        the only one that may downsample); every later block maps
        ``out_channels`` to ``out_channels``.
        """
        def build(position):
            if position > 0:
                return Resnet_Block(out_channels , out_channels)
            # Equality with True (rather than plain truthiness) is what
            # decides whether the flag is forwarded to the first block.
            if downsample == True:
                return Resnet_Block(in_channels , out_channels , downsample = downsample)
            return Resnet_Block(in_channels , out_channels)

        return nn.Sequential(*[build(position) for position in range(repeats)])

    def forward(self , x):
        """Return ``[conv5, conv4, conv3, conv2, stem]`` outputs, last stage first.

        The stem entry is the ``conv1`` output taken before the max-pool.
        """
        stem = self.conv1(x)
        pooled = torch.max_pool2d(stem , kernel_size = (2 , 2) , stride = (2 , 2))
        stage1 = self.conv2(pooled)
        stage2 = self.conv3(stage1)
        stage3 = self.conv4(stage2)
        stage4 = self.conv5(stage3)
        return [stage4 , stage3 , stage2 , stage1 , stem]

In [None]:
# Smoke test: run the backbone on a random batch of two 3x224x224 inputs and
# print the shape of each of the five returned feature maps on one line.
resnet = Resnet(3 , 1000)
resnet = resnet.to(device)
x = torch.randn(2 , 3 , 224 , 224)
x = x.to(device)
z = resnet(x)
print(*[feature.shape for feature in z])

In [10]:
class BiFPN_Layer(nn.Module):
    """Two-pass feature fusion over a list of feature maps.

    Top-down pass: starting from the first input, each level is passed through
    a ``ConvT`` and concatenated (channel axis) with the next input before the
    next ``ConvT``; the last input gets no ``ConvT`` of its own.

    Bottom-up pass: walking the inputs in reverse order, each level
    concatenates the input, the matching top-down result (absent for the final
    level) and the previous bottom-up output (absent for the first level), and
    feeds the result to a stride-2 ``Conv``.

    Args:
        in_channels_list: Channel count of every input feature map, in the
            order the maps are passed to ``forward``.
        out_channels: Channel count of every returned feature map, in the
            order they are returned (i.e. matching the reversed inputs).
    """

    def __init__(self ,
                 in_channels_list = [2048 , 1024 , 512 , 256 , 128] ,
                 out_channels = [128 , 256 , 512 , 1024 , 2048]):
        super().__init__()

        self.top_down = nn.ModuleList()
        self.bottom_up = nn.ModuleList()

        final_level = len(in_channels_list) - 1

        # One ConvT per level except the last. From the second level onwards
        # the ConvT input is the level's own map concatenated with the previous
        # ConvT output, hence the doubled channel count.
        for level , channels in enumerate(in_channels_list):
            if level == 0:
                self.top_down.append(ConvT(channels , in_channels_list[level + 1]))
            elif level < final_level:
                self.top_down.append(ConvT(channels * 2 , in_channels_list[level + 1]))

        # One stride-2 Conv per level, sized for whatever gets concatenated
        # at that level in ``forward``.
        for level , channels in enumerate(Reverse(in_channels_list)):
            if level == 0:
                fused_channels = channels * 2
            elif level == final_level:
                fused_channels = channels + out_channels[level - 1]
            else:
                fused_channels = channels * 2 + out_channels[level - 1]
            self.bottom_up.append(Conv(fused_channels , out_channels[level] , stride = (2 , 2)))

    def forward(self , x):
        """Fuse the feature maps in ``x`` and return one output per input level.

        ``x`` is indexed and measured with ``len``; outputs are ordered
        starting from the level that corresponds to the last entry of ``x``.
        """
        final_level = len(x) - 1

        # ---- top-down pass ----
        top_down_maps = []
        for level , feature in enumerate(x):
            if level == 0:
                merged = feature
            elif level < final_level:
                merged = torch.cat([feature , top_down_maps[-1]] , dim = 1)
            else:
                # The last input has no ConvT of its own.
                continue
            top_down_maps.append(self.top_down[level](merged))

        # ---- bottom-up pass ----
        top_down_reversed = Reverse(top_down_maps)
        inputs_reversed = Reverse(x)

        fused_maps = []
        for level in range(len(x)):
            if level == 0:
                pieces = [inputs_reversed[level] , top_down_reversed[level]]
            elif level == final_level:
                pieces = [inputs_reversed[level] , fused_maps[-1]]
            else:
                pieces = [inputs_reversed[level] , top_down_reversed[level] , fused_maps[-1]]
            fused_maps.append(self.bottom_up[level](torch.cat(pieces , dim = 1)))
        return fused_maps

In [None]:
# Smoke test: feed the feature maps currently held in `z` through a
# default-configured BiFPN layer.
bifpn = BiFPN_Layer()
bifpn = bifpn.to(device)
y = bifpn(z)

In [12]:
class Model(nn.Module):
    """``Resnet`` backbone followed by two stacked ``BiFPN_Layer`` modules.

    The first BiFPN layer uses its default channel lists; the second takes the
    same input channels and emits 128 channels at every level.

    Args:
        in_channels: Number of channels of the input tensor, forwarded to
            ``Resnet``.
    """

    def __init__(self , in_channels):
        super().__init__()

        self.resnet = Resnet(in_channels , 1000)
        self.bifpn = BiFPN_Layer()
        self.bifpn1 = BiFPN_Layer(
            in_channels_list = [2048 , 1024 , 512 , 256 , 128] ,
            out_channels = [128 , 128 , 128 , 128 , 128] ,
        )

    def forward(self , x):
        """Return the list of feature maps produced by the second BiFPN layer."""
        backbone_maps = self.resnet(x)
        fused_maps = self.bifpn(backbone_maps)
        # The first BiFPN emits its maps in the opposite channel order to the
        # one the second BiFPN is configured to receive, so flip the list.
        return self.bifpn1(Reverse(fused_maps))

In [None]:
# Smoke test: run the full model on a random batch of two 3x448x448 inputs.
model = Model(3)
model = model.to(device)
x = torch.randn(2 , 3 , 448 , 448)
x = x.to(device)
z = model(x)

In [14]:
# Print the shape of every tensor currently held in `z`, one per line.
for z_ in z:
    print(z_.shape)

torch.Size([2, 128, 56, 56])
torch.Size([2, 128, 28, 28])
torch.Size([2, 128, 14, 14])
torch.Size([2, 128, 7, 7])
torch.Size([2, 128, 4, 4])
