<a href="https://colab.research.google.com/github/ArghyaPal/Zero-shot-task-transfer/blob/master/Encoder_Decoder.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [0]:
'''
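A slight change from our paper description: in the paper we used ResNet-50 as the encoder.
However, we later switched to a UNet-style design, as we found it more concise.
(The encoder below is still assembled from residual Bottleneck stages with depths [3, 4, 6, 3];
its intermediate feature maps are fed to the decoder as UNet-style skip connections.)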

'''

import torch
from torch.autograd import Variable
import torchvision
import torch.nn as nn
import torch.nn.functional as F
from torchvision import datasets, models,transforms
import torch.optim as optim
from torch.optim import lr_scheduler
import numpy as np
import os
from torch.autograd import Function
from torch.autograd import Variable
from collections import OrderedDict
import math
import torchvision.models as models
import random
from resnet import *

# Module-level hyper-parameters. They are not referenced by the model
# definitions in this section; presumably consumed by training code elsewhere.
zsize = 100             # presumably a latent/feature size -- confirm usage
batch_size = 50
iterations = 100
learningRate = 0.001    # camelCase name kept for existing callers


def weights_init_normal(m):
    """Re-initialise the weights of a single module in place.

    Intended to be passed to ``nn.Module.apply``. Matching is done on the
    module's class name:

    * names containing ``"Conv"`` (Conv*/ConvTranspose*): weight ~ N(0, 0.02);
      any bias is left untouched;
    * names containing ``"BatchNorm2d"``: weight ~ N(1, 0.02), bias = 0.

    Modules whose ``weight``/``bias`` is missing or ``None`` (for example
    ``BatchNorm2d(affine=False)`` or a user container named ``ConvBlock``)
    are skipped instead of raising ``AttributeError``. All other modules are
    left unchanged.

    Args:
        m: the module to initialise.
    """
    classname = type(m).__name__
    # getattr(..., None) covers both "attribute absent" and "attribute is None".
    weight = getattr(m, "weight", None)
    bias = getattr(m, "bias", None)
    if "Conv" in classname:
        if weight is not None:
            torch.nn.init.normal_(weight, 0.0, 0.02)
    elif "BatchNorm2d" in classname:
        if weight is not None:
            torch.nn.init.normal_(weight, 1.0, 0.02)
        if bias is not None:
            torch.nn.init.constant_(bias, 0.0)

##############################
#           U-NET
##############################
class UNetUp(nn.Module):
    """Decoder stage: transposed conv -> InstanceNorm -> ReLU (-> Dropout), then a skip concat.

    The ConvTranspose2d uses kernel 3, stride 1, padding 1 and no bias, so the
    spatial size of ``x`` is preserved. The processed tensor is concatenated
    with ``skip_input`` along dimension 1, so the output has
    ``out_size + skip_input.size(1)`` channels. ``skip_input`` must match the
    processed tensor in every other dimension.

    Args:
        in_size: channels of ``x``.
        out_size: channels produced by the transposed convolution.
        dropout: dropout probability; a Dropout layer is appended only when
            this value is truthy.
    """

    def __init__(self, in_size, out_size, dropout=0.0):
        super().__init__()
        # Layer order fixes the state_dict keys (model.0.weight, ...).
        self.model = nn.Sequential(
            nn.ConvTranspose2d(in_size, out_size, 3, 1, 1, bias=False),
            nn.InstanceNorm2d(out_size),
            nn.ReLU(inplace=True),
            *([nn.Dropout(dropout)] if dropout else []),
        )

    def forward(self, x, skip_input):
        return torch.cat([self.model(x), skip_input], dim=1)
        
class Encoder(nn.Module):
    """ResNet-style encoder that returns the output of every stage.

    Args:
        block: residual block class. It must expose an integer ``expansion``
            attribute and accept ``(inplanes, planes, stride, downsample)`` as
            well as ``(inplanes, planes)``.
        layers: number of blocks in each of the four residual stages.
        num_classes: unused; accepted only for signature compatibility.

    ``forward(x)`` takes a 3-channel image batch and returns
    ``(d1, d2, d3, d4, d5)``: the stem output followed by the outputs of the
    four residual stages (stages 2-4 downsample with stride 2).
    """

    def __init__(self, block, layers, num_classes=23):
        super(Encoder, self).__init__()
        # Running input-channel count consumed and updated by _make_layer.
        self.inplanes = 64
        self.down1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3, bias=False)
        self.bn1 = nn.BatchNorm2d(64)
        self.relu = nn.ReLU(inplace=True)
        # NOTE: constructed for attribute compatibility but intentionally NOT
        # applied in forward: the Decoder in this file concatenates d1 and d2
        # at the same spatial size, which a pooling step between them breaks.
        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
        self.down2 = self._make_layer(block, 64, layers[0])
        self.down3 = self._make_layer(block, 128, layers[1], stride=2)
        self.down4 = self._make_layer(block, 256, layers[2], stride=2)
        self.down5 = self._make_layer(block, 512, layers[3], stride=2)

        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                # He-style normal init scaled by fan-out (k_h * k_w * out_channels).
                n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
                nn.init.normal_(m.weight, 0.0, math.sqrt(2.0 / n))
            elif isinstance(m, nn.BatchNorm2d):
                nn.init.ones_(m.weight)
                nn.init.zeros_(m.bias)

    def _make_layer(self, block, planes, blocks, stride=1):
        """Build one residual stage of ``blocks`` blocks.

        The first block gets ``stride`` and, when the stride or channel count
        changes, a 1x1 conv + BatchNorm projection as ``downsample``.
        Updates ``self.inplanes`` to ``planes * block.expansion``.
        """
        downsample = None
        if stride != 1 or self.inplanes != planes * block.expansion:
            downsample = nn.Sequential(
                nn.Conv2d(self.inplanes, planes * block.expansion,
                          kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(planes * block.expansion),
            )

        layers = [block(self.inplanes, planes, stride, downsample)]
        self.inplanes = planes * block.expansion
        layers.extend(block(self.inplanes, planes) for _ in range(1, blocks))
        return nn.Sequential(*layers)

    def forward(self, x):
        # Stem: 7x7 stride-2 conv -> BatchNorm -> ReLU (no max-pool, see __init__).
        d1 = self.relu(self.bn1(self.down1(x)))
        d2 = self.down2(d1)
        d3 = self.down3(d2)
        d4 = self.down4(d3)
        d5 = self.down5(d4)
        return d1, d2, d3, d4, d5

# Module-level encoder referenced by Autoencoder: Bottleneck blocks with the
# ResNet-50 stage depths [3, 4, 6, 3].
encoder = Encoder(block=Bottleneck, layers=[3, 4, 6, 3])


class Decoder(nn.Module):
    """UNet-style decoder that consumes the five encoder feature maps.

    The layer sizes require d1..d5 to carry 64, 256, 512, 1024 and 2048
    channels respectively. ``forward`` returns the output of ``up10``, which
    has ``out_channels`` channels and twice the spatial size of d1.

    NOTE(review): up3/up5 (kernel 3, stride 2, padding 1) produce
    ``2 * in - 1`` pixels per side, so d4 and d3 must be odd-sized for the
    skip concatenations to line up. Confirm the input resolutions in use.

    Args:
        out_channels: channels of the decoder output (default 3). Change
            this to suit the task instead of editing layer definitions.
    """

    def __init__(self, out_channels=3):
        super(Decoder, self).__init__()
        self.up1 = nn.ConvTranspose2d(2048, 512, 3, 1, 1, bias=False)
        self.up2 = UNetUp(512, 512, dropout=0.5)
        self.up3 = nn.ConvTranspose2d(2560, 256, 3, 2, 1, bias=False)
        self.up4 = UNetUp(256, 256, dropout=0.5)
        self.up5 = nn.ConvTranspose2d(1280, 128, 3, 2, 1, bias=False)
        self.up6 = UNetUp(128, 128, dropout=0.5)
        self.up7 = nn.ConvTranspose2d(640, 64, 4, 2, 1, bias=False)
        self.up8 = UNetUp(64, 64, dropout=0.5)
        self.up9 = UNetUp(320, 3, dropout=0.5)
        # Final layer actually used by forward; its width is the output width.
        self.up10 = nn.ConvTranspose2d(67, out_channels, 4, 2, 1, bias=False)
        # up11/up12 are never used by forward. They are kept so that existing
        # state_dicts (which contain their weights) still load strictly, and
        # so the parameter registration order is unchanged.
        self.up11 = UNetUp(3, 3, dropout=0.5)
        self.up12 = nn.ConvTranspose2d(6, 1, 1, bias=False)

    def forward(self, d1, d2, d3, d4, d5):
        u1 = self.up1(d5)
        u2 = self.up2(u1, d5)   # concat with d5 -> 512 + 2048 = 2560 channels
        u3 = self.up3(u2)
        u4 = self.up4(u3, d4)   # 256 + 1024 = 1280 channels
        u5 = self.up5(u4)
        u6 = self.up6(u5, d3)   # 128 + 512 = 640 channels
        u7 = self.up7(u6)
        u8 = self.up8(u7, d2)   # 64 + 256 = 320 channels
        u9 = self.up9(u8, d1)   # 3 + 64 = 67 channels
        u10 = self.up10(u9)
        return u10

# Module-level decoder instance referenced by Autoencoder.
decoder = Decoder()


#########################################################
class Autoencoder(nn.Module):
    """Encoder/decoder pair whose forward feeds all five encoder outputs to the decoder.

    By default every instance wraps the *same* module-level ``encoder`` and
    ``decoder`` objects, so all default-constructed Autoencoders share
    parameters. Pass ``encoder_net`` / ``decoder_net`` to give an instance
    its own, independent networks.

    Args:
        encoder_net: module returning a 5-tuple ``(d1, d2, d3, d4, d5)``;
            defaults to the module-level ``encoder``.
        decoder_net: module called as ``decoder_net(d1, d2, d3, d4, d5)``;
            defaults to the module-level ``decoder``.
    """

    def __init__(self, encoder_net=None, decoder_net=None):
        super(Autoencoder, self).__init__()
        # The module-level singletons are looked up only when no override is given.
        self.encoder = encoder if encoder_net is None else encoder_net
        self.decoder = decoder if decoder_net is None else decoder_net

    def forward(self, x):
        d1, d2, d3, d4, d5 = self.encoder(x)
        return self.decoder(d1, d2, d3, d4, d5)