In [None]:
import torch
from torchvision import models
import torch.nn as nn

In [None]:
class LocationAwareConv2d(torch.nn.Conv2d):
    """Conv2d whose output is resized to ``(w, h)`` and, optionally, offset by a
    learned per-location bias.

    When ``locationAware`` is true the layer owns:

    * ``locationBias``   -- learnable parameter of shape ``(w, h, 3)``, initialised to zero.
    * ``locationEncode`` -- fixed encoding of shape ``(w, h, 3)``, all ones by default.
      With ``gradient=True`` channel 0 ramps from 0 to 1 along the second axis
      (``h``) and channel 1 ramps from 0 to 1 along the first axis (``w``);
      channel 2 stays at one.

    The three channels of ``locationBias * locationEncode`` are added to the
    (resized) convolution output, broadcasting over batch and channel dimensions.

    Args:
        locationAware: enable the learned per-location bias.
        gradient: fill the first two encoding channels with 0..1 ramps
            (only used when ``locationAware`` is true).
        w, h: spatial size (dims 2 and 3) the convolution output is resized to.
        in_channels, out_channels, kernel_size, stride, padding, bias:
            forwarded unchanged to ``torch.nn.Conv2d``.
    """

    def __init__(self,locationAware,gradient,w,h,in_channels, out_channels, kernel_size, stride=1, padding=0, bias=True):

        super().__init__(in_channels, out_channels, kernel_size, stride=stride, padding=padding, bias=bias)

        self.w = w
        self.h = h
        self.locationAware = locationAware
        self.up = torch.nn.Upsample(size=(w, h), mode='bilinear', align_corners=False)

        if locationAware:
            self.locationBias = torch.nn.Parameter(torch.zeros(w, h, 3))

            encode = torch.ones(w, h, 3)
            if gradient:
                # Each axis gets its own ramp, so non-square (w != h) grids are
                # filled completely and normalised by their own length.
                encode[:, :, 0] = self._unit_ramp(h).unsqueeze(0)
                encode[:, :, 1] = self._unit_ramp(w).unsqueeze(1)

            # A buffer follows the module through .to()/.cuda()/.cpu(), so
            # forward() needs no manual device juggling. persistent=False keeps
            # it out of state_dict, so the checkpoint keys stay the same as
            # when this was a plain attribute.
            self.register_buffer('locationEncode', encode, persistent=False)

    @staticmethod
    def _unit_ramp(n):
        """Return ``[0, 1/(n-1), ..., 1]`` as float32; a single zero when ``n == 1``."""
        # Computed in float64 and then rounded so values match ``i / float(n - 1)``.
        steps = torch.arange(n, dtype=torch.float64) / max(n - 1, 1)
        return steps.to(torch.float32)

    def forward(self,inputs):
        """Convolve ``inputs``, resize the result to ``(w, h)`` if needed, and add
        the location bias when the layer is location aware."""
        convRes = super().forward(inputs)

        # Resize when *either* spatial dimension is off; the bias grid below
        # only broadcasts against an exact (w, h) map.
        if convRes.shape[2] != self.w or convRes.shape[3] != self.h:
            convRes = self.up(convRes)

        if not self.locationAware:
            return convRes

        b = self.locationBias * self.locationEncode
        return convRes + b[:, :, 0] + b[:, :, 1] + b[:, :, 2]

In [None]:
"""
NimbRoNet2 Model Class
"""

class NimbRoNet2(nn.Module):
    """Encoder-decoder network built on a torchvision ResNet-18 backbone.

    The encoder is split into four stages taken from the backbone's children.
    The outputs of the first three stages are passed through 1x1 convolutions
    and concatenated (along the channel axis) with the decoder feature maps as
    skip connections. The decoder upsamples three times with stride-2
    transposed convolutions and ends in a 1x1 transposed convolution that
    produces 3 output channels.

    Args:
        pretrained: forwarded to ``models.resnet18``; ``True`` (the default)
            loads the pretrained backbone weights.
    """

    def __init__(self, pretrained=True):
        super(NimbRoNet2, self).__init__()
        model = models.resnet18(pretrained=pretrained)
        backbone = list(model.children())

        # ---- Encoder -------------------------------------------------------
        # The slices below drop the last two children of the backbone
        # (presumably the average-pool and fully-connected head).
        self.e_block1 = nn.Sequential(*backbone[0:5])
        self.conv_1_1 = nn.Conv2d(in_channels=64, out_channels=128, kernel_size=1)

        self.e_block2 = nn.Sequential(*backbone[5:6])
        self.conv_1_2 = nn.Conv2d(in_channels=128, out_channels=256, kernel_size=1)

        self.e_block3 = nn.Sequential(*backbone[6:7])
        self.conv_1_3 = nn.Conv2d(in_channels=256, out_channels=256, kernel_size=1)

        self.e_block4 = nn.Sequential(*backbone[7:-2])

        # ---- Decoder -------------------------------------------------------
        # Blocks 2-4 consume a concatenation of the previous decoder output and
        # a skip connection, hence their 512 / 512 / 256 input channels.
        self.d_block1 = nn.Sequential(
                        nn.ReLU(),
                        nn.ConvTranspose2d(in_channels = 512, out_channels=256, kernel_size=2, stride=2, padding=0, output_padding=0))

        self.d_block2 = nn.Sequential(
                        nn.ReLU(),
                        nn.BatchNorm2d(512),
                        nn.ConvTranspose2d(in_channels = 512, out_channels=256, kernel_size=2, stride=2, padding=0, output_padding=0))

        self.d_block3 = nn.Sequential(
                        nn.ReLU(),
                        nn.BatchNorm2d(512),
                        nn.ConvTranspose2d(in_channels = 512, out_channels=128, kernel_size=2, stride=2, padding=0, output_padding=0))

        # NOTE(review): a location-dependent convolution was sketched here but
        # never wired in; the head is a plain 1x1 transposed convolution.
        self.d_block4 = nn.Sequential(
                        nn.ReLU(),
                        nn.BatchNorm2d(256),
                        nn.ConvTranspose2d(in_channels = 256,out_channels=3, kernel_size=1, stride=1, padding=0, output_padding=0))

    def forward(self, input):
        """Run the encoder, then the decoder with three skip connections.

        Args:
            input: batch fed to the first backbone stage.

        Returns:
            The output of ``d_block4`` (3 channels).
        """
        # ---- Encoder: keep a 1x1-projected copy of each stage for the skips.
        output = self.e_block1(input)
        res_1 = self.conv_1_1(output)      # 128 channels

        output = self.e_block2(output)
        res_2 = self.conv_1_2(output)      # 256 channels

        output = self.e_block3(output)
        res_3 = self.conv_1_3(output)      # 256 channels

        output = self.e_block4(output)

        # ---- Decoder: upsample, then concatenate the matching skip on dim 1.
        output = self.d_block1(output)
        output = torch.cat((output, res_3), 1)   # 256 + 256 channels

        output = self.d_block2(output)
        output = torch.cat((output, res_2), 1)   # 256 + 256 channels

        output = self.d_block3(output)
        output = torch.cat((output, res_1), 1)   # 128 + 128 channels

        output = self.d_block4(output)

        return output
        