In [1]:
import numpy as np
import torch
import torch.nn as nn
import torchvision
from torchsummary import summary

In [54]:
class Up_FarSeg(nn.Module):
    def __init__(self, in_channels, out_channels):
        super(Up_FarSeg, self).__init__()
        self.up_block = nn.Sequential(nn.Upsample(scale_factor = 2, mode = 'nearest'),
                                          nn.Conv2d(in_channels, out_channels, kernel_size = 3, padding = 1),
                                          nn.BatchNorm2d(out_channels),
                                          nn.ReLU(inplace = True))
        
    def forward(self, x):
        return self.up_block(x)

In [55]:
class Trans_FarSeg(nn.Module):
    def __init__(self, in_channels, out_channels):
        super(Trans_FarSeg, self).__init__()
        self.trans_block = nn.Sequential(nn.Conv2d(in_channels, out_channels, kernel_size = 1),
                                         nn.BatchNorm2d(out_channels),
                                         nn.ReLU(inplace = True))
        
    def forward(self, x):
        return self.trans_block(x)

In [122]:
class FarSeg(nn.Module):
    def __init__(self, in_channels = 3, n_classes = 1):
        super(FarSeg, self).__init__()
        self.in_channels = in_channels
        self.n_classes = n_classes
        
        resnet = list(torchvision.models.resnet34(pretrained=True).children())
        self.c2 = nn.Sequential(*resnet[0:5])
        self.c3 = resnet[5]
        self.c4 = resnet[6]
        self.c5 = resnet[7]
        
        self.p5_up = Up_FarSeg(512 // 1, 512 // 2)
        self.p4_up = Up_FarSeg(512 // 2, 512 // 4)
        self.p3_up = Up_FarSeg(512 // 4, 512 // 8)
        self.p2_up = Up_FarSeg(512 // 8, 512 // 16)
        
        self.p4_trans = Trans_FarSeg(512 // 2, 512 // 2)
        self.p3_trans = Trans_FarSeg(512 // 4, 512 // 4)
        self.p2_trans = Trans_FarSeg(512 // 8, 512 // 8)
        
    def forward(self, x):
        x_c2 = self.c2(x)
        x_c3 = self.c3(x_c2)
        x_c4 = self.c4(x_c3)
        x_c5 = self.c5(x_c4)
        
        print("x_c5", x_c5.shape)
        x_p5 = self.p5_up(x_c5)
        print('x_p5', x_p5.shape)
        x_p4 = torch.add(x_p5, self.p4_trans(x_c4))
        print('x_p4', x_p4.shape)
        print(x_c3.shape, self.p3_trans(x_c3).shape)
#         x_p3 = torch.add(x_p4, self.p3_trans(x_c3))
#         x_p2 = torch.add(x_p3, self.p2_trans(x_c2))
        
        return x_p4

In [123]:
# net = DecoderBlock_FarSeg(512 // 1, 512 // 2).cuda()
# summary(net, (512, 8, 8))

In [124]:
net = FarSeg().cuda()
summary(net, (3, 256, 256))

x_c5 torch.Size([2, 512, 8, 8])
x_p5 torch.Size([2, 256, 16, 16])
x_p4 torch.Size([2, 256, 16, 16])
torch.Size([2, 128, 32, 32]) torch.Size([2, 128, 32, 32])
----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
            Conv2d-1         [-1, 64, 128, 128]           9,408
       BatchNorm2d-2         [-1, 64, 128, 128]             128
              ReLU-3         [-1, 64, 128, 128]               0
         MaxPool2d-4           [-1, 64, 64, 64]               0
            Conv2d-5           [-1, 64, 64, 64]          36,864
       BatchNorm2d-6           [-1, 64, 64, 64]             128
              ReLU-7           [-1, 64, 64, 64]               0
            Conv2d-8           [-1, 64, 64, 64]          36,864
       BatchNorm2d-9           [-1, 64, 64, 64]             128
             ReLU-10           [-1, 64, 64, 64]               0
       BasicBlock-11           [-1, 64, 64, 64]               0
         