# Import necessary packages

In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np
import matplotlib.pyplot as plt
from torchvision import datasets, models, transforms
import os
import copy
from IPython.display import clear_output
from torchsummary import summary
from tensorboardX import SummaryWriter
import torch.nn.functional as F

In [2]:
class MyModel(nn.Module):
    def __init__(self):
        super(MyModel, self).__init__()
        resnet_layers = list(models.resnet50(pretrained=True).children())
#         print(resnet_layers)
        self.resnet_1 = nn.Sequential(*resnet_layers[:-5])
        self.resnet_2 = nn.Sequential(*resnet_layers[-5])
        self.resnet_3 = nn.Sequential(*resnet_layers[-4])
        self.resnet_4 = nn.Sequential(*resnet_layers[-3])
        
        self.conv11_resnet_1 = nn.Conv2d(in_channels=256,out_channels=256, kernel_size=(1,1), stride=(1,1))
        self.conv11_resnet_2 = nn.Conv2d(in_channels=512,out_channels=256, kernel_size=(1,1), stride=(1,1))
        self.conv11_resnet_3 = nn.Conv2d(in_channels=1024,out_channels=256, kernel_size=(1,1), stride=(1,1))
        self.conv11_resnet_4 = nn.Conv2d(in_channels=2048,out_channels=256, kernel_size=(1,1), stride=(1,1))
        
        self.scaling_1 = nn.UpsamplingNearest2d(scale_factor=2)
        self.scaling_2 = nn.UpsamplingNearest2d(scale_factor=2)
        self.scaling_3 = nn.UpsamplingNearest2d(scale_factor=2)
        
        self.conv33_td1_1 = nn.Conv2d(in_channels=256, out_channels=128,kernel_size=(3,3),stride=(1,1), padding=(1,1))
        self.conv33_td1_2 = nn.Conv2d(in_channels=128, out_channels=128,kernel_size=(3,3),stride=(1,1), padding=(1,1))
        self.conv33_td2_1 = nn.Conv2d(in_channels=256, out_channels=128,kernel_size=(3,3),stride=(1,1), padding=(1,1))
        self.conv33_td2_2 = nn.Conv2d(in_channels=128, out_channels=128,kernel_size=(3,3),stride=(1,1), padding=(1,1))
        self.conv33_td3_1 = nn.Conv2d(in_channels=256, out_channels=128,kernel_size=(3,3),stride=(1,1), padding=(1,1))
        self.conv33_td3_2 = nn.Conv2d(in_channels=128, out_channels=128,kernel_size=(3,3),stride=(1,1), padding=(1,1))
        self.conv33_td4_1 = nn.Conv2d(in_channels=256, out_channels=128,kernel_size=(3,3),stride=(1,1), padding=(1,1))
        self.conv33_td4_2 = nn.Conv2d(in_channels=128, out_channels=128,kernel_size=(3,3),stride=(1,1), padding=(1,1))
        
        self.scaling_P1 = nn.UpsamplingNearest2d(scale_factor=8)
        self.scaling_P2 = nn.UpsamplingNearest2d(scale_factor=4)
        self.scaling_P3 = nn.UpsamplingNearest2d(scale_factor=2)
        
        self.conv3_final = nn.Conv2d(in_channels=512, out_channels=512, kernel_size=(3,3), stride=(1,1))
        self.batchnorm_final = nn.BatchNorm2d(512)
        self.activation = nn.ReLU()
        self.conv1_final = nn.Conv2d(in_channels=512, out_channels=2, kernel_size=(1,1), stride=(1,1))
        
        
#         self.scaling = nn.ConvTranspose2d(in_channels=256, out_channels=256,kernel_size=(2,2), stride=2)
        
    def forward(self, x):
        x_bu1 = self.resnet_1(x)
        x_bu2 = self.resnet_2(x_bu1)
        x_bu3 = self.resnet_3(x_bu2)
        x_bu4 = self.resnet_4(x_bu3)
#         print("Bottom-up")
#         print(x_bu1.shape)
#         print(x_bu2.shape)
#         print(x_bu3.shape)
#         print(x_bu4.shape)
        
        
        x_td1 = self.conv11_resnet_4(x_bu4)
        P_1 = self.conv33_td1_1(x_td1)
        P_1 = self.conv33_td1_2(P_1)
        
        x_td2_1 = self.scaling_1(x_td1)
        x_td2_2 = self.conv11_resnet_3(x_bu3)
        x_td2 = x_td2_1+x_td2_2
        P_2 = self.conv33_td2_1(x_td2)
        P_2 = self.conv33_td2_2(P_2)
        
        x_td3_1 = self.scaling_2(x_td2)
        x_td3_2 = self.conv11_resnet_2(x_bu2)
        x_td3 = x_td3_1+x_td3_2
        P_3 = self.conv33_td3_1(x_td3)
        P_3 = self.conv33_td3_2(P_3)
        
        x_td4_1 = self.scaling_3(x_td3)
        x_td4_2 = self.conv11_resnet_1(x_bu1)
        x_td4 = x_td4_1+x_td4_2
        P_4 = self.conv33_td4_1(x_td4)
        P_4 = self.conv33_td4_2(P_4)
        
        
        P_1 = self.scaling_P1(P_1)
        P_2 = self.scaling_P2(P_2)
        P_3 = self.scaling_P3(P_3)
        
        P_concat = torch.cat((P_1, P_2, P_3, P_4), dim=1)
        F = P_concat
#         print("P_concat")
#         print(P_concat.shape)
        
        F = self.conv3_final(F)
        F = self.batchnorm_final(F)
        F = self.activation(F)
        
        F = self.conv1_final(F)
        F = nn.UpsamplingBilinear2d(size = (int(x.size()[-2]),
                                          int(x.size()[-1])))(F)
#         F = nn.Softmax2d()(F)
#         print("Top-down")
#         print(x_td4.shape)
#         print(x_td3.shape)
#         print(x_td2.shape)
#         print(x_td1.shape)
        
#         print("P")
#         print(P_1.shape)
#         print(P_2.shape)
#         print(P_3.shape)
#         print(P_4.shape)
#         print(P_concat.shape)
        
#         print("F")
#         print(x.shape)
#         print(F.shape)
        
        
        return F
with SummaryWriter("runs/MyModel") as w:
    w.add_graph(MyModel(),torch.zeros(1,3,224,224), False)
summary(MyModel(), (3,224,224))



----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
            Conv2d-1         [-1, 64, 112, 112]           9,408
       BatchNorm2d-2         [-1, 64, 112, 112]             128
              ReLU-3         [-1, 64, 112, 112]               0
         MaxPool2d-4           [-1, 64, 56, 56]               0
            Conv2d-5           [-1, 64, 56, 56]           4,096
       BatchNorm2d-6           [-1, 64, 56, 56]             128
              ReLU-7           [-1, 64, 56, 56]               0
            Conv2d-8           [-1, 64, 56, 56]          36,864
       BatchNorm2d-9           [-1, 64, 56, 56]             128
             ReLU-10           [-1, 64, 56, 56]               0
           Conv2d-11          [-1, 256, 56, 56]          16,384
      BatchNorm2d-12          [-1, 256, 56, 56]             512
           Conv2d-13          [-1, 256, 56, 56]          16,384
      BatchNorm2d-14          [-1, 256,

In [3]:
model = MyModel()
out=model(torch.randn((2,3,224,224)))
print(out.shape)
print(out[1,0,1,1], out[1,1,1,1])

torch.Size([2, 2, 224, 224])
tensor(-0.1556, grad_fn=<SelectBackward>) tensor(-0.1465, grad_fn=<SelectBackward>)


In [4]:
for param in model.parameters():
    print(param.requires_grad)

True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True


In [5]:
out.shape

torch.Size([2, 2, 224, 224])

In [6]:
def crossentropy_loss(input_tensor, target=None, weight=None ):
    n,c,h,w = input_tensor.size()
    
    
#     target = target.repeat(1,3,1,1)
#     print(target.size())
#     target_mask = target > 0
#     input_tensor_masked = input_tensor[target_mask]
#     target_masked = target[target_mask]
#     print(input_tensor_masked.size(), target_masked.size())

    input_tensor = input_tensor.transpose(2,1).transpose(2,3).contiguous()
# #     print(input_tensor,target.view(n,h,w,c))
#     print(c)
    input_tensor = input_tensor[target.view(n,h,w,1).repeat(1,1,1,c)>=0].view(-1,c)
    target_mask = target >= 0
    target = target[target_mask]
#     print(input_tensor.size(), input_tensor)
#     print(target.size(), target)
    loss = F.cross_entropy(input_tensor, target, weight=weight, size_average=False)
#     print(loss)
    return loss


In [7]:
input_image = torch.randn((1,3,224,224))*255
# print(input_image)
label = input_image[:,0,:,:]>0.5
label = label.long()
output=model(input_image)

crossentropy_loss(output, label)



tensor(35163.7812, grad_fn=<NllLossBackward>)

In [8]:
import torch.optim as optim
optimizer = optim.Adam(model.parameters())
summary = SummaryWriter("runs/Plots")
import matplotlib.pyplot as plt

input_image = plt.imread("./index.png")
# print(input_image.shape)
input_image = torch.FloatTensor(input_image).transpose(1,2).transpose(0,1)
# print(input_image.size())
# print(input_image)
input_image = input_image.repeat(1,1,1,1)[:,:,:224,:224]
label = input_image[:,0,:,:]>=0.5
label = label.long()
# print(label)
# print(label*255)

# input_image = torch.randn((1,3,224,224))*255



print(label.size())
summary.add_images("original_image",input_image)
summary.add_image("label_image", label.float())



for epoch in range(500):
    optimizer.zero_grad()
    # print(input_image)
    
    output=model(input_image)
    loss = crossentropy_loss(output, label)
    loss.backward()
    optimizer.step()
    print(epoch,loss)
    print(output.shape)
    print(nn.Softmax2d()(output)[0,0].shape)
#     print(output)
    summary.add_scalar("Loss",loss,epoch)
    summary.add_image("output_images_region1", nn.Softmax2d()(output)[0,0].view(1,224,224), epoch)
    summary.add_image("output_images_region2", nn.Softmax2d()(output)[0,1].view(1,224,224), epoch)
    
    
    

torch.Size([1, 224, 224])
0 tensor(34265.8477, grad_fn=<NllLossBackward>)
torch.Size([1, 2, 224, 224])
torch.Size([224, 224])
1 tensor(30528.4277, grad_fn=<NllLossBackward>)
torch.Size([1, 2, 224, 224])
torch.Size([224, 224])
2 tensor(67677.6484, grad_fn=<NllLossBackward>)
torch.Size([1, 2, 224, 224])
torch.Size([224, 224])
3 tensor(32483.7109, grad_fn=<NllLossBackward>)
torch.Size([1, 2, 224, 224])
torch.Size([224, 224])
4 tensor(24791.6016, grad_fn=<NllLossBackward>)
torch.Size([1, 2, 224, 224])
torch.Size([224, 224])
5 tensor(23256.0176, grad_fn=<NllLossBackward>)
torch.Size([1, 2, 224, 224])
torch.Size([224, 224])
6 tensor(20329.7168, grad_fn=<NllLossBackward>)
torch.Size([1, 2, 224, 224])
torch.Size([224, 224])
7 tensor(17884.5781, grad_fn=<NllLossBackward>)
torch.Size([1, 2, 224, 224])
torch.Size([224, 224])
8 tensor(16936.2539, grad_fn=<NllLossBackward>)
torch.Size([1, 2, 224, 224])
torch.Size([224, 224])
9 tensor(16223.1729, grad_fn=<NllLossBackward>)
torch.Size([1, 2, 224, 224

82 tensor(1803.2679, grad_fn=<NllLossBackward>)
torch.Size([1, 2, 224, 224])
torch.Size([224, 224])
83 tensor(1785.0641, grad_fn=<NllLossBackward>)
torch.Size([1, 2, 224, 224])
torch.Size([224, 224])
84 tensor(1767.2480, grad_fn=<NllLossBackward>)
torch.Size([1, 2, 224, 224])
torch.Size([224, 224])
85 tensor(1749.8171, grad_fn=<NllLossBackward>)
torch.Size([1, 2, 224, 224])
torch.Size([224, 224])
86 tensor(1732.8342, grad_fn=<NllLossBackward>)
torch.Size([1, 2, 224, 224])
torch.Size([224, 224])
87 tensor(1716.5768, grad_fn=<NllLossBackward>)
torch.Size([1, 2, 224, 224])
torch.Size([224, 224])
88 tensor(1700.2693, grad_fn=<NllLossBackward>)
torch.Size([1, 2, 224, 224])
torch.Size([224, 224])
89 tensor(1684.5549, grad_fn=<NllLossBackward>)
torch.Size([1, 2, 224, 224])
torch.Size([224, 224])
90 tensor(1668.9146, grad_fn=<NllLossBackward>)
torch.Size([1, 2, 224, 224])
torch.Size([224, 224])
91 tensor(1653.8530, grad_fn=<NllLossBackward>)
torch.Size([1, 2, 224, 224])
torch.Size([224, 224])


164 tensor(1087.8806, grad_fn=<NllLossBackward>)
torch.Size([1, 2, 224, 224])
torch.Size([224, 224])
165 tensor(1082.9126, grad_fn=<NllLossBackward>)
torch.Size([1, 2, 224, 224])
torch.Size([224, 224])
166 tensor(1078.4508, grad_fn=<NllLossBackward>)
torch.Size([1, 2, 224, 224])
torch.Size([224, 224])
167 tensor(1073.5222, grad_fn=<NllLossBackward>)
torch.Size([1, 2, 224, 224])
torch.Size([224, 224])
168 tensor(1068.7584, grad_fn=<NllLossBackward>)
torch.Size([1, 2, 224, 224])
torch.Size([224, 224])
169 tensor(1064.1161, grad_fn=<NllLossBackward>)
torch.Size([1, 2, 224, 224])
torch.Size([224, 224])
170 tensor(1059.6991, grad_fn=<NllLossBackward>)
torch.Size([1, 2, 224, 224])
torch.Size([224, 224])
171 tensor(1056.9243, grad_fn=<NllLossBackward>)
torch.Size([1, 2, 224, 224])
torch.Size([224, 224])
172 tensor(1056.9471, grad_fn=<NllLossBackward>)
torch.Size([1, 2, 224, 224])
torch.Size([224, 224])
173 tensor(1065.5975, grad_fn=<NllLossBackward>)
torch.Size([1, 2, 224, 224])
torch.Size([2

246 tensor(893.7197, grad_fn=<NllLossBackward>)
torch.Size([1, 2, 224, 224])
torch.Size([224, 224])
247 tensor(858.3022, grad_fn=<NllLossBackward>)
torch.Size([1, 2, 224, 224])
torch.Size([224, 224])
248 tensor(868.2557, grad_fn=<NllLossBackward>)
torch.Size([1, 2, 224, 224])
torch.Size([224, 224])
249 tensor(855.5798, grad_fn=<NllLossBackward>)
torch.Size([1, 2, 224, 224])
torch.Size([224, 224])
250 tensor(847.4069, grad_fn=<NllLossBackward>)
torch.Size([1, 2, 224, 224])
torch.Size([224, 224])
251 tensor(849.8184, grad_fn=<NllLossBackward>)
torch.Size([1, 2, 224, 224])
torch.Size([224, 224])
252 tensor(834.3759, grad_fn=<NllLossBackward>)
torch.Size([1, 2, 224, 224])
torch.Size([224, 224])
253 tensor(838.5348, grad_fn=<NllLossBackward>)
torch.Size([1, 2, 224, 224])
torch.Size([224, 224])
254 tensor(828.3619, grad_fn=<NllLossBackward>)
torch.Size([1, 2, 224, 224])
torch.Size([224, 224])
255 tensor(824.5114, grad_fn=<NllLossBackward>)
torch.Size([1, 2, 224, 224])
torch.Size([224, 224])


328 tensor(707.5904, grad_fn=<NllLossBackward>)
torch.Size([1, 2, 224, 224])
torch.Size([224, 224])
329 tensor(694.2781, grad_fn=<NllLossBackward>)
torch.Size([1, 2, 224, 224])
torch.Size([224, 224])
330 tensor(697.6096, grad_fn=<NllLossBackward>)
torch.Size([1, 2, 224, 224])
torch.Size([224, 224])
331 tensor(690.8198, grad_fn=<NllLossBackward>)
torch.Size([1, 2, 224, 224])
torch.Size([224, 224])
332 tensor(687.9754, grad_fn=<NllLossBackward>)
torch.Size([1, 2, 224, 224])
torch.Size([224, 224])
333 tensor(686.9012, grad_fn=<NllLossBackward>)
torch.Size([1, 2, 224, 224])
torch.Size([224, 224])
334 tensor(680.4639, grad_fn=<NllLossBackward>)
torch.Size([1, 2, 224, 224])
torch.Size([224, 224])
335 tensor(681.5935, grad_fn=<NllLossBackward>)
torch.Size([1, 2, 224, 224])
torch.Size([224, 224])
336 tensor(675.1682, grad_fn=<NllLossBackward>)
torch.Size([1, 2, 224, 224])
torch.Size([224, 224])
337 tensor(675.3313, grad_fn=<NllLossBackward>)
torch.Size([1, 2, 224, 224])
torch.Size([224, 224])


410 tensor(548.6523, grad_fn=<NllLossBackward>)
torch.Size([1, 2, 224, 224])
torch.Size([224, 224])
411 tensor(547.3492, grad_fn=<NllLossBackward>)
torch.Size([1, 2, 224, 224])
torch.Size([224, 224])
412 tensor(546.0530, grad_fn=<NllLossBackward>)
torch.Size([1, 2, 224, 224])
torch.Size([224, 224])
413 tensor(544.7650, grad_fn=<NllLossBackward>)
torch.Size([1, 2, 224, 224])
torch.Size([224, 224])
414 tensor(543.4839, grad_fn=<NllLossBackward>)
torch.Size([1, 2, 224, 224])
torch.Size([224, 224])
415 tensor(542.2116, grad_fn=<NllLossBackward>)
torch.Size([1, 2, 224, 224])
torch.Size([224, 224])
416 tensor(540.9538, grad_fn=<NllLossBackward>)
torch.Size([1, 2, 224, 224])
torch.Size([224, 224])
417 tensor(539.7169, grad_fn=<NllLossBackward>)
torch.Size([1, 2, 224, 224])
torch.Size([224, 224])
418 tensor(538.5327, grad_fn=<NllLossBackward>)
torch.Size([1, 2, 224, 224])
torch.Size([224, 224])
419 tensor(537.4626, grad_fn=<NllLossBackward>)
torch.Size([1, 2, 224, 224])
torch.Size([224, 224])


492 tensor(491.7773, grad_fn=<NllLossBackward>)
torch.Size([1, 2, 224, 224])
torch.Size([224, 224])
493 tensor(490.4610, grad_fn=<NllLossBackward>)
torch.Size([1, 2, 224, 224])
torch.Size([224, 224])
494 tensor(489.2342, grad_fn=<NllLossBackward>)
torch.Size([1, 2, 224, 224])
torch.Size([224, 224])
495 tensor(487.9112, grad_fn=<NllLossBackward>)
torch.Size([1, 2, 224, 224])
torch.Size([224, 224])
496 tensor(486.7084, grad_fn=<NllLossBackward>)
torch.Size([1, 2, 224, 224])
torch.Size([224, 224])
497 tensor(485.4657, grad_fn=<NllLossBackward>)
torch.Size([1, 2, 224, 224])
torch.Size([224, 224])
498 tensor(484.2604, grad_fn=<NllLossBackward>)
torch.Size([1, 2, 224, 224])
torch.Size([224, 224])
499 tensor(483.0560, grad_fn=<NllLossBackward>)
torch.Size([1, 2, 224, 224])
torch.Size([224, 224])


In [9]:
def test(imgdir):
    input_image = plt.imread(imgdir)
    input_image = torch.FloatTensor(input_image).transpose(1,2).transpose(0,1)
    input_image = input_image.repeat(1,1,1,1)[:,:,:224*3,:224*3]
    print(input_image.size())
    output=model(input_image)
    summary.add_image("test_image", input_image[0,0])
    summary.add_image("test_output_images_region1", nn.Softmax2d()(output)[0,0].view(1,224*3,224*3))
    summary.add_image("test_output_images_region2", nn.Softmax2d()(output)[0,1].view(1,224*3,224*3))

In [10]:
test("./index11.jpeg")

torch.Size([1, 3, 672, 672])


AssertionError: size of input tensor and input format are different