In [1]:
import torch
import numpy as np

from itertools import chain

In [2]:
# Dummy all-zeros input used to trace layer shapes; 3 channels to match the
# first ConvUnit's input width (PyTorch conv layers expect N, C, H, W).
X_t = torch.zeros(1, 3, 256, 256)

In [3]:
class ConvUnit(torch.nn.Module):
    """Downsampling block: 4x4 stride-2 Conv2d -> BatchNorm2d -> LeakyReLU.

    With kernel 4, stride 2 and padding 1 the spatial size goes from
    (H, W) to (H // 2, W // 2).

    Args:
        channels: ``(in_channels, out_channels)`` pair for the convolution.
    """

    def __init__(self, channels):
        super().__init__()
        in_ch, out_ch = channels
        # Keep this registration order (conv, batch_norm, relu_op): it fixes
        # the state_dict key order and the order random weights are drawn.
        self.conv = torch.nn.Conv2d(in_ch, out_ch, kernel_size=(4, 4), stride=2, padding=1)
        self.batch_norm = torch.nn.BatchNorm2d(out_ch)
        self.relu_op = torch.nn.LeakyReLU()

    def forward(self, input_x):
        """Apply conv, then batch norm, then LeakyReLU."""
        return self.relu_op(self.batch_norm(self.conv(input_x)))

In [4]:
class ConvTransposeUnit(torch.nn.Module):
    """Upsampling block: 4x4 stride-2 ConvTranspose2d -> BatchNorm2d -> LeakyReLU.

    With kernel 4, stride 2 and padding 1 the transposed convolution maps
    (H, W) to (2 * H, 2 * W).

    Args:
        channels: ``(in_channels, out_channels)`` pair for the transposed convolution.
    """

    def __init__(self, channels):
        super().__init__()
        in_ch, out_ch = channels
        # Registration order (conv, batch_norm, relu_op) is kept so state_dict
        # keys and random weight initialisation order are unchanged.
        self.conv = torch.nn.ConvTranspose2d(in_ch, out_ch, kernel_size=(4, 4), stride=2, padding=1)
        self.batch_norm = torch.nn.BatchNorm2d(out_ch)
        self.relu_op = torch.nn.LeakyReLU()

    def forward(self, input_x):
        """Apply transposed conv, then batch norm, then LeakyReLU."""
        return self.relu_op(self.batch_norm(self.conv(input_x)))

def fw_pass(input_x=None):
    """Prototype encoder pass: run ``input_x`` through a fresh stack of ConvUnits.

    Seven newly initialised ``ConvUnit`` layers are built on every call and
    discarded afterwards, so this is only useful for checking shapes. Each
    unit uses a 4x4 / stride-2 / padding-1 convolution, halving H and W.

    Args:
        input_x: input tensor. Defaults to the notebook-level ``X_t`` (looked
            up at call time) so the original zero-argument call ``fw_pass()``
            keeps working.

    Returns:
        list of the seven intermediate outputs, shallowest first.
    """
    if input_x is None:
        input_x = X_t

    input_channels = [3, 32, 64, 128, 256, 512, 1024]
    output_channels = [32, 64, 128, 256, 512, 1024, 1024]

    # ``zip`` already yields the (in, out) pairs; wrapping one iterable in
    # ``itertools.chain`` added nothing.
    conv_units_fw = [ConvUnit(pair) for pair in zip(input_channels, output_channels)]

    intermidiate_outputs = []
    for unit in conv_units_fw:
        input_x = unit(input_x)
        intermidiate_outputs.append(input_x)
    return intermidiate_outputs

# Scratch run of the prototype encoder on the dummy input (fresh random weights);
# its output shapes are listed in a later cell.
intermidiate_outputs = fw_pass()
def bw_pass(intermidiate_outputs):
    """Prototype decoder pass over the encoder outputs produced by ``fw_pass``.

    The deepest encoder output becomes the first attribute feature. It is then
    upsampled by six freshly initialised ``ConvTransposeUnit`` layers. After
    each one, the upsampled map is concatenated (dim 1) with the encoder output
    of matching depth. A final bilinear x2 upsample of the last concatenation
    is appended.

    Args:
        intermidiate_outputs: list of encoder outputs, shallowest first (as
            returned by ``fw_pass``). The list is NOT modified; previously
            ``pop()``/``reverse()`` truncated and flipped the caller's list.

    Returns:
        list of attribute feature tensors, deepest first.

    Raises:
        IndexError: if ``intermidiate_outputs`` is empty or shorter than the
            number of decoder stages.
    """
    input_channels = [3, 32, 64, 128, 256, 512, 1024]
    output_channels = [32, 64, 128, 256, 512, 1024, 1024]

    # Mirror the encoder with (out -> in) pairs, deepest first, dropping the
    # stage that would map back to the 3 input channels.
    rev_pairs = zip(output_channels[::-1][:-1], input_channels[::-1][:-1])
    conv_units_bw = [ConvTransposeUnit(pair) for pair in rev_pairs]

    # Reversed copy: deepest first. Element 0 is the bottleneck, the rest are skips.
    encoder_feats = intermidiate_outputs[::-1]
    z_att_0 = encoder_feats[0]
    skips = encoder_feats[1:]

    att_features = [z_att_0]
    z_att_k = z_att_0
    for i, unit in enumerate(conv_units_bw):
        # NOTE(review): the next stage is fed the decoder output z_att_k, not the
        # concatenation; the channel pairs above are sized for that. Confirm intent.
        z_att_k = unit(z_att_k)
        att_features.append(torch.cat([z_att_k, skips[i]], dim=1))

    # Equivalent to nn.UpsamplingBilinear2d(scale_factor=2), which is deprecated.
    att_features.append(
        torch.nn.functional.interpolate(
            att_features[-1], scale_factor=2, mode="bilinear", align_corners=True
        )
    )
    return att_features

In [14]:
class MultiLevelAttEncoder(torch.nn.Module):
    """Encoder/decoder with skip concatenation that returns multi-level attribute features.

    Structure:
        - Encoder: one ``ConvUnit`` per ``(input_channels[i], output_channels[i])``
          pair. Each unit halves H and W.
        - Decoder: one ``ConvTransposeUnit`` per reversed ``(out, in)`` pair,
          omitting the stage that would map back to ``input_channels[0]``.
          Each unit doubles H and W.
        - Skips: every decoder output is concatenated (dim 1) with the encoder
          output of the same depth.

    All layers are built once here and held in ``torch.nn.ModuleList``s. They
    are therefore registered sub-modules: they appear in ``parameters()`` and
    ``state_dict()``, follow ``.to(device)``, and keep their weights across
    calls.

    Args:
        input_channels: encoder input channel counts, one per stage.
        output_channels: encoder output channel counts, same length.
    """

    def __init__(self, input_channels, output_channels):
        super(MultiLevelAttEncoder, self).__init__()
        # Materialised lists, not one-shot ``zip`` iterators. With iterators,
        # only the first forward call ever saw any layers.
        self.fw_channels = list(zip(input_channels, output_channels))
        self.bw_channels = list(
            zip(list(output_channels)[::-1][:-1], list(input_channels)[::-1][:-1])
        )

        # Built once (encoder first, then decoder) so the weights are trained and reused.
        self.conv_units_fw = torch.nn.ModuleList(ConvUnit(pair) for pair in self.fw_channels)
        self.conv_units_bw = torch.nn.ModuleList(
            ConvTransposeUnit(pair) for pair in self.bw_channels
        )

    def fw_pass(self, input_x):
        """Run the encoder and return every ConvUnit output, shallowest first."""
        intermidiate_outputs = []
        for unit in self.conv_units_fw:
            input_x = unit(input_x)
            intermidiate_outputs.append(input_x)
        return intermidiate_outputs

    def bw_pass(self, intermidiate_outputs):
        """Run the decoder over encoder outputs (given shallowest first).

        Returns the attribute features, deepest first:
            - the deepest encoder output,
            - one concatenation per decoder stage,
            - a bilinear x2 upsample of the last concatenation.

        The input list is not modified.
        """
        encoder_feats = intermidiate_outputs[::-1]
        z_att_0 = encoder_feats[0]
        skips = encoder_feats[1:]

        att_features = [z_att_0]
        z_att_k = z_att_0
        for i, unit in enumerate(self.conv_units_bw):
            # NOTE(review): the next stage consumes z_att_k (decoder output only),
            # not the concatenation. The channel pairs in __init__ match this;
            # feeding the concatenation would need doubled decoder inputs. Confirm intent.
            z_att_k = unit(z_att_k)
            att_features.append(torch.cat([z_att_k, skips[i]], dim=1))

        # Equivalent to the deprecated nn.UpsamplingBilinear2d(scale_factor=2).
        att_features.append(
            torch.nn.functional.interpolate(
                att_features[-1], scale_factor=2, mode="bilinear", align_corners=True
            )
        )
        return att_features

    def forward(self, input_x):
        """Encode ``input_x`` and return the list produced by ``bw_pass``."""
        intermidiate_outputs = self.fw_pass(input_x)
        att_features = self.bw_pass(intermidiate_outputs)

        return att_features

In [15]:
# Channel schedule: each encoder stage's output width is the next stage's input
# width, and the deepest stage keeps its width (1024 -> 1024).
input_channels = [3, 32, 64, 128, 256, 512, 1024]
output_channels = input_channels[1:] + input_channels[-1:]

att_enc = MultiLevelAttEncoder(input_channels, output_channels)

In [16]:
# Run the dummy batch through the encoder/decoder; the result is a list of feature tensors.
att_features = att_enc(X_t)

In [None]:
# Shapes of the scratch fw_pass() outputs, shallowest first.
[feat.shape for feat in intermidiate_outputs]

In [18]:
[feat.shape for feat in att_features]

[torch.Size([1, 1024, 2, 2]),
 torch.Size([1, 2048, 4, 4]),
 torch.Size([1, 1024, 8, 8]),
 torch.Size([1, 512, 16, 16]),
 torch.Size([1, 256, 32, 32]),
 torch.Size([1, 128, 64, 64]),
 torch.Size([1, 64, 128, 128]),
 torch.Size([1, 64, 256, 256])]

In [None]:
# `conv_units_bw` only ever existed as a local inside bw_pass, so the bare name
# raised NameError on a fresh kernel. Instead, list whatever ConvTransposeUnit
# layers are registered on the encoder.
[module for module in att_enc.modules() if isinstance(module, ConvTransposeUnit)]