In [1]:
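# Building blocks (Conv, ConvT, Resnet_Block), a Resnet backbone and PAN_Net, each followed by a quick shape check.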

In [2]:
import torch
from torch import nn

In [3]:
# Run on the GPU when PyTorch can see one, otherwise stay on the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')

In [4]:
def Reverse(lst):
    """Return a new list containing the items of ``lst`` in reverse order.

    ``lst`` itself is left untouched; any object accepted by the built-in
    ``reversed`` works.
    """
    return list(reversed(lst))

In [5]:
class Conv(nn.Module):
    """2-D convolution with optional instance norm, max-pooling and LeakyReLU.

    The forward pass applies, in order: the convolution, ``InstanceNorm2d``
    (if ``use_norm``), a 2x2 stride-2 ``MaxPool2d`` (if ``use_pool``) and
    ``LeakyReLU(0.2)`` (if ``use_activation``).

    Args:
        in_channels: number of channels of the input feature map.
        out_channels: number of channels produced by the convolution.
        kernel_size: convolution kernel size.
        stride: convolution stride.
        padding: convolution padding.
        use_norm: whether to normalise the convolution output.
        use_activation: whether to apply the LeakyReLU at the end.
        use_pool: whether to halve the spatial size with max-pooling.
    """

    def __init__(self,
                 in_channels,
                 out_channels,
                 kernel_size=(3, 3),
                 stride=(1, 1),
                 padding=1,
                 use_norm=True,
                 use_activation=True,
                 use_pool=False):
        super().__init__()

        self.use_norm = use_norm
        self.use_activation = use_activation
        self.use_pool = use_pool

        self.conv1 = nn.Conv2d(in_channels=in_channels,
                               out_channels=out_channels,
                               kernel_size=kernel_size,
                               stride=stride,
                               padding=padding)

        # Optional stages are only registered as sub-modules when enabled.
        if use_norm:
            self.norm = nn.InstanceNorm2d(out_channels)
        if use_activation:
            self.activation = nn.LeakyReLU(0.2)
        if use_pool:
            self.max_pool = nn.MaxPool2d(kernel_size=2, stride=2)

    def forward(self, x):
        """Run conv -> [norm] -> [pool] -> [activation] on ``x``."""
        out = self.conv1(x)
        if self.use_norm:
            out = self.norm(out)
        if self.use_pool:
            out = self.max_pool(out)
        if self.use_activation:
            out = self.activation(out)
        return out

In [None]:
# Smoke test: push a random 2x3x512x512 batch through a default Conv(3 -> 32).
x = torch.randn(2 , 3 , 512,  512).to(device)
conv = Conv(3 , 32).to(device)
z = conv(x)
# With the default 3x3 / stride 1 / padding 1 settings the spatial size is kept,
# so this should display torch.Size([2, 32, 512, 512]).
z.shape

In [7]:
class ConvT(nn.Module):
    """Transposed 2-D convolution with optional instance norm and LeakyReLU.

    With the default 2x2 kernel, stride 2 and no padding the layer doubles
    the spatial size of its input.

    Args:
        in_channels: number of channels of the input feature map.
        out_channels: number of channels produced by the transposed conv.
        kernel_size: transposed-convolution kernel size.
        stride: transposed-convolution stride.
        padding: transposed-convolution padding.
        use_norm: whether to apply ``InstanceNorm2d`` after the convolution.
        use_activation: whether to apply ``LeakyReLU(0.2)`` at the end.
    """

    def __init__(self,
                 in_channels,
                 out_channels,
                 kernel_size=(2, 2),
                 stride=(2, 2),
                 padding=0,
                 use_norm=True,
                 use_activation=True):
        super().__init__()

        self.use_norm = use_norm
        self.use_activation = use_activation

        self.convT = nn.ConvTranspose2d(in_channels=in_channels,
                                        out_channels=out_channels,
                                        kernel_size=kernel_size,
                                        stride=stride,
                                        padding=padding)

        # Optional stages are only registered as sub-modules when enabled.
        if use_norm:
            self.norm = nn.InstanceNorm2d(out_channels)
        if use_activation:
            self.activation = nn.LeakyReLU(0.2)

    def forward(self, x):
        """Run convT -> [norm] -> [activation] on ``x``."""
        out = self.convT(x)
        if self.use_norm:
            out = self.norm(out)
        if self.use_activation:
            out = self.activation(out)
        return out

In [None]:
# Smoke test: push a random 2x3x512x512 batch through a default ConvT(3 -> 32).
x = torch.randn(2 , 3 , 512 , 512).to(device)
convT = ConvT(3 , 32).to(device)
z = convT(x)
# The default 2x2 kernel with stride 2 doubles the spatial size,
# so this should display torch.Size([2, 32, 1024, 1024]).
z.shape

In [9]:
class Resnet_Block(nn.Module):
    """Residual block: three stacked ``Conv`` layers plus a projected skip path.

    Main path: ``conv1`` (in -> in) -> ``conv2`` (in -> in, ``Conv`` defaults)
    -> ``conv3`` (in -> out, 1x1). Skip path: ``conv_skip`` (in -> out).
    The block returns the sum of both paths.

    ``conv1`` and ``conv_skip`` share their geometry: a 2x2 kernel with
    stride 2 when ``downsample`` is true (halving the spatial size on both
    paths), otherwise a 1x1 kernel with stride 1.

    Args:
        in_channels: number of channels of the block input.
        out_channels: number of channels of the block output.
        downsample: whether the block halves the spatial resolution.
    """

    def __init__(self,
                 in_channels,
                 out_channels,
                 downsample=False):
        super().__init__()

        self.downsample = downsample

        # Kernel and stride are identical for the entry conv and the skip
        # projection, so both paths always end up with the same spatial size.
        window = (2, 2) if self.downsample else (1, 1)

        # Registration order (conv1, conv_skip, conv2, conv3) is kept so that
        # seeded initialisation and state_dict ordering stay the same.
        self.conv1 = Conv(in_channels,
                          in_channels,
                          kernel_size=window,
                          stride=window,
                          padding=0)

        self.conv_skip = Conv(in_channels,
                              out_channels,
                              kernel_size=window,
                              stride=window,
                              padding=0)

        self.conv2 = Conv(in_channels,
                          in_channels)

        self.conv3 = Conv(in_channels,
                          out_channels,
                          kernel_size=(1, 1),
                          stride=(1, 1),
                          padding=0)

    def forward(self, x):
        """Return ``conv3(conv2(conv1(x))) + conv_skip(x)``."""
        # No layer here modifies ``x`` in place, so both paths can read the
        # same tensor; copying it first would only cost memory.
        out = self.conv1(x)
        out = self.conv2(out)
        out = self.conv3(out)
        shortcut = self.conv_skip(x)
        # Out-of-place add: never overwrites a tensor autograd may have saved,
        # whatever the last stage of ``conv3`` happens to be.
        return out + shortcut

In [10]:
class Resnet(nn.Module):
    """Convolutional backbone: a strided stem followed by four block stacks.

    * ``conv1``: 7x7 stride-2 ``Conv`` producing 64 channels.
    * ``conv2`` .. ``conv5``: stacks of 3, 8, 36 and 3 ``Resnet_Block`` s that
      widen the features to 256, 512, 1024 and 2048 channels. Every stack
      except ``conv2`` downsamples in its first block.

    Args:
        in_channels: number of channels of the input image.
        out_channels: accepted for interface compatibility; currently unused
            because no classification head is built.
    """

    def __init__(self,
                 in_channels,
                 out_channels):
        super().__init__()

        # Stem
        self.conv1 = Conv(in_channels, 64, kernel_size=(7, 7), stride=(2, 2), padding=3)

        # Stages
        self.conv2 = self._make_repeated_blocks(64, 256, 3, downsample=False)
        self.conv3 = self._make_repeated_blocks(256, 512, 8)
        self.conv4 = self._make_repeated_blocks(512, 1024, 36)
        self.conv5 = self._make_repeated_blocks(1024, 2048, 3)

    def _make_repeated_blocks(self, in_channels, out_channels, repeats, downsample=True):
        """Stack ``repeats`` blocks; only the first one changes width/resolution.

        The first block maps ``in_channels`` to ``out_channels`` (downsampling
        when ``downsample`` equals ``True``); every following block maps
        ``out_channels`` to ``out_channels``.
        """
        stage = []
        for idx in range(repeats):
            if idx > 0:
                stage.append(Resnet_Block(out_channels, out_channels))
            elif downsample == True:  # equality test kept as in the original
                stage.append(Resnet_Block(in_channels, out_channels, downsample=downsample))
            else:
                stage.append(Resnet_Block(in_channels, out_channels))
        return nn.Sequential(*stage)

    def forward(self, x):
        """Return the four stage outputs ordered from deepest to shallowest."""
        stem = self.conv1(x)
        stem = torch.max_pool2d(stem, kernel_size=(2, 2), stride=(2, 2))

        c2 = self.conv2(stem)
        c3 = self.conv3(c2)
        c4 = self.conv4(c3)
        c5 = self.conv5(c4)

        return [c5, c4, c3, c2]

In [None]:
# Smoke test: run the backbone on a random 2x3x224x224 batch.
resnet = Resnet(3 , 1000).to(device)
x = torch.randn(2 , 3 , 224 , 224).to(device)
z = resnet(x)
# Resnet.forward returns the stage outputs deepest first, so for this input the
# shapes should be (2, 2048, 7, 7), (2, 1024, 14, 14), (2, 512, 28, 28), (2, 256, 56, 56).
print(z[0].shape , z[1].shape , z[2].shape , z[3].shape)

In [17]:
class PAN_Net(nn.Module):
    """Resnet backbone followed by a top-down and a bottom-up feature pass.

    * ``top_down``: one ``ConvT`` per entry of ``in_channels``; each halves
      the channel count (and, with the ``ConvT`` defaults, doubles the
      spatial size).
    * ``bottom_up``: one pooled ``Conv`` per entry of ``out_channels``; each
      doubles the channel count and halves the spatial size.

    Args:
        in_channels: channel counts of the backbone outputs, deepest first
            (the order in which ``Resnet.forward`` returns them).
        out_channels: input channel counts of the bottom-up layers, in the
            order of the reversed top-down outputs.
    """

    def __init__(self,
                 in_channels=(2048, 1024, 512, 256),
                 out_channels=(128, 256, 512, 1024)):
        super().__init__()

        self.top_down = nn.ModuleList()
        self.bottom_up = nn.ModuleList()
        self.resnet = Resnet(3, 1000)

        for channel in in_channels:
            self.top_down.append(ConvT(channel, channel // 2))

        for channel in out_channels:
            self.bottom_up.append(Conv(channel, channel * 2, use_pool=True))

    def forward(self, x):
        """Return the list of bottom-up feature maps, shallowest level first."""
        feats = self.resnet(x)  # deepest backbone map first

        # Top-down pass: upsample each backbone map; intermediate levels are
        # fused with the next (shallower) backbone map.
        # NOTE(review): level 0 is not fused with feats[1] although their
        # shapes match for the default configuration - confirm this is intended.
        last = len(self.top_down) - 1
        p = []
        for i, layer in enumerate(self.top_down):
            up = layer(feats[i])
            if i != 0 and i != last:
                up = up + feats[i + 1]
            p.append(up)

        # Bottom-up pass over the top-down outputs, shallowest first;
        # intermediate levels are fused with the next (deeper) top-down map.
        p_ = Reverse(p)
        last = len(p_) - 1  # same value as above: one top-down output per layer
        N = []
        for i, layer in enumerate(self.bottom_up):
            down = layer(p_[i])
            if i != 0 and i != last:
                down = down + p_[i + 1]
            N.append(down)

        return N

In [18]:
def test():
    """Smoke-test PAN_Net: run one random 2x3x224x224 batch and print each output shape."""
    model = PAN_Net().to(device)
    batch = torch.randn(2, 3, 224, 224).to(device)
    for feature_map in model(batch):
        print(feature_map.shape)

In [None]:
# Run the PAN_Net smoke test; it prints one output shape per pyramid level.
test()