# Import necessary packages

In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np
import matplotlib.pyplot as plt
from torchvision import datasets, models, transforms
import os
import copy
from IPython.display import clear_output
from torchsummary import summary
from tensorboardX import SummaryWriter

In [5]:
class MyModel(nn.Module):
    """FPN-style network on a ResNet-50 backbone producing a 1-channel map.

    Bottom-up pathway: stages of a torchvision ``resnet50(pretrained=True)``
    (``resnet_1`` = stem + layer1 -> 256 ch, ``resnet_2`` = layer2 -> 512 ch,
    ``resnet_3`` = layer3 -> 1024 ch, ``resnet_4`` = layer4 -> 2048 ch).

    Top-down pathway: 1x1 lateral convs project every stage to 256 channels;
    each coarser merged map is upsampled to the next finer lateral map's size
    and added to it. Every merged map passes through two 3x3 convs
    (-> 128 ch); the four results are resized to the finest level,
    concatenated (512 ch), fused by conv3x3 + BN + ReLU, projected to one
    channel, resized to the input resolution and passed through a sigmoid.
    """

    def __init__(self):
        super(MyModel, self).__init__()
        # NOTE(review): newer torchvision deprecates ``pretrained=True`` in
        # favour of ``weights=...``; kept for the installed version.
        resnet_layers = list(models.resnet50(pretrained=True).children())
        # children(): conv1, bn1, relu, maxpool, layer1, layer2, layer3,
        # layer4, avgpool, fc  -> avgpool/fc are not used.
        self.resnet_1 = nn.Sequential(*resnet_layers[:-5])  # stem + layer1
        self.resnet_2 = nn.Sequential(*resnet_layers[-5])   # layer2
        self.resnet_3 = nn.Sequential(*resnet_layers[-4])   # layer3
        self.resnet_4 = nn.Sequential(*resnet_layers[-3])   # layer4
        self._build_head()

    def _build_head(self):
        """Create the lateral, smoothing and prediction layers (everything but the backbone)."""
        # 1x1 lateral projections of each bottom-up stage to 256 channels.
        self.conv11_resnet_1 = nn.Conv2d(in_channels=256, out_channels=256, kernel_size=(1, 1), stride=(1, 1))
        self.conv11_resnet_2 = nn.Conv2d(in_channels=512, out_channels=256, kernel_size=(1, 1), stride=(1, 1))
        self.conv11_resnet_3 = nn.Conv2d(in_channels=1024, out_channels=256, kernel_size=(1, 1), stride=(1, 1))
        self.conv11_resnet_4 = nn.Conv2d(in_channels=2048, out_channels=256, kernel_size=(1, 1), stride=(1, 1))

        # Two size-preserving 3x3 convs per pyramid level (256 -> 128 -> 128).
        self.conv33_td1_1 = nn.Conv2d(in_channels=256, out_channels=128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        self.conv33_td1_2 = nn.Conv2d(in_channels=128, out_channels=128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        self.conv33_td2_1 = nn.Conv2d(in_channels=256, out_channels=128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        self.conv33_td2_2 = nn.Conv2d(in_channels=128, out_channels=128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        self.conv33_td3_1 = nn.Conv2d(in_channels=256, out_channels=128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        self.conv33_td3_2 = nn.Conv2d(in_channels=128, out_channels=128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        self.conv33_td4_1 = nn.Conv2d(in_channels=256, out_channels=128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        self.conv33_td4_2 = nn.Conv2d(in_channels=128, out_channels=128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))

        # Fusion of the 4 x 128 concatenated levels. Padded like the other 3x3
        # convs so the fused map keeps the finest level's grid; unpadded, it
        # lost a border pixel that the final resize then stretched over.
        self.conv3_final = nn.Conv2d(in_channels=512, out_channels=512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        self.batchnorm_final = nn.BatchNorm2d(512)
        self.activation = nn.ReLU()
        self.conv1_final = nn.Conv2d(in_channels=512, out_channels=1, kernel_size=(1, 1), stride=(1, 1))

    @staticmethod
    def _upsample_to(x, ref):
        """Nearest-neighbour resize of ``x`` to the spatial size of ``ref``.

        Identical to a fixed ``scale_factor`` nearest upsample when the sizes
        are exact multiples, but also correct when the backbone rounded an odd
        size up (inputs whose height/width are not multiples of 32).
        """
        return nn.functional.interpolate(x, size=ref.shape[-2:], mode="nearest")

    def forward(self, x):
        """Predict a single-channel map with the same spatial size as ``x``.

        Args:
            x: image batch, assumed (N, 3, H, W) as used in this notebook;
                H and W need not be multiples of 32.

        Returns:
            Tensor of shape (N, 1, H, W) with values in [0, 1] (sigmoid).
        """
        # Bottom-up pathway.
        x_bu1 = self.resnet_1(x)
        x_bu2 = self.resnet_2(x_bu1)
        x_bu3 = self.resnet_3(x_bu2)
        x_bu4 = self.resnet_4(x_bu3)

        # Top-down pathway: upsample the coarser merged map to the *exact* size
        # of the next lateral map before adding. A fixed 2x factor fails when a
        # stride-2 stage rounded an odd size up (300px input: 10 * 2 != 19).
        x_td1 = self.conv11_resnet_4(x_bu4)
        x_td2 = self._upsample_to(x_td1, x_bu3) + self.conv11_resnet_3(x_bu3)
        x_td3 = self._upsample_to(x_td2, x_bu2) + self.conv11_resnet_2(x_bu2)
        x_td4 = self._upsample_to(x_td3, x_bu1) + self.conv11_resnet_1(x_bu1)

        # Per-level 3x3 heads.
        P_1 = self.conv33_td1_2(self.conv33_td1_1(x_td1))
        P_2 = self.conv33_td2_2(self.conv33_td2_1(x_td2))
        P_3 = self.conv33_td3_2(self.conv33_td3_1(x_td3))
        P_4 = self.conv33_td4_2(self.conv33_td4_1(x_td4))

        # Bring every level to the finest level's size and fuse.
        P_concat = torch.cat(
            (
                self._upsample_to(P_1, P_4),
                self._upsample_to(P_2, P_4),
                self._upsample_to(P_3, P_4),
                P_4,
            ),
            dim=1,
        )
        fused = self.activation(self.batchnorm_final(self.conv3_final(P_concat)))
        logits = self.conv1_final(fused)

        # bilinear + align_corners=True is what nn.UpsamplingBilinear2d does.
        logits = nn.functional.interpolate(
            logits, size=x.shape[-2:], mode="bilinear", align_corners=True
        )
        return torch.sigmoid(logits)
# Build the model once and reuse it: every MyModel() constructs (and loads
# pretrained weights for) a fresh ResNet-50 backbone.
graph_model = MyModel()

# Log the traced graph for TensorBoard.
with SummaryWriter("runs/MyModel") as writer:
    writer.add_graph(graph_model, torch.zeros(1, 3, 224, 224), verbose=False)

# torchsummary's summary() defaults to device="cuda" and then builds CUDA inputs
# whenever a GPU is present; the model here is on the CPU, so say so explicitly.
summary(graph_model, (3, 224, 224), device="cpu")



----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
            Conv2d-1         [-1, 64, 112, 112]           9,408
       BatchNorm2d-2         [-1, 64, 112, 112]             128
              ReLU-3         [-1, 64, 112, 112]               0
         MaxPool2d-4           [-1, 64, 56, 56]               0
            Conv2d-5           [-1, 64, 56, 56]           4,096
       BatchNorm2d-6           [-1, 64, 56, 56]             128
              ReLU-7           [-1, 64, 56, 56]               0
            Conv2d-8           [-1, 64, 56, 56]          36,864
       BatchNorm2d-9           [-1, 64, 56, 56]             128
             ReLU-10           [-1, 64, 56, 56]               0
           Conv2d-11          [-1, 256, 56, 56]          16,384
      BatchNorm2d-12          [-1, 256, 56, 56]             512
           Conv2d-13          [-1, 256, 56, 56]          16,384
      BatchNorm2d-14          [-1, 256,

In [None]:
model = MyModel()

# Shape-only smoke test: no autograd graph is needed, and keeping one for a
# 10 x 3 x 300 x 300 batch through ResNet-50 wastes a lot of memory.
# 300 is not a multiple of 32, so intermediate feature maps have odd sizes.
with torch.no_grad():
    out = model(torch.zeros((10, 3, 300, 300)))

assert out.shape == (10, 1, 300, 300), out.shape
print(out.shape)

In [None]:
# Display the prediction's shape (a torch.Size) as this cell's output.
out.size()