In [2]:
import torch
import torch.nn as nn
from torch.autograd import Variable
import torch.nn.functional as F
import torch.optim as optim

import torchvision.transforms as transforms
import torchvision.models as models

from PIL import Image
import matplotlib.pyplot as plt

import copy

In [None]:
class GeneratorAE(nn.Module):
    """Convolutional auto-encoder with max-pool / max-unpool skip indices.

    Encoder: conv(3->16, k=4) -> ReLU -> maxpool(2)
             -> conv(16->32, k=5) -> ReLU -> maxpool(2)
             -> conv(32->64, k=3) -> ReLU
    Decoder: deconv(64->32, k=3) -> ReLU -> maxunpool(2)
             -> deconv(32->16, k=5) -> ReLU -> maxunpool(2)
             -> deconv(16->3, k=4) -> ReLU

    Each transposed convolution uses the same kernel size as the convolution
    it mirrors, and each unpool restores the exact pre-pool size, so the
    output has the same shape as the input.

    All convolution / transposed-convolution weights use Xavier (Glorot)
    uniform initialisation; biases keep PyTorch's default initialisation.
    """

    def __init__(self):
        super(GeneratorAE, self).__init__()

        # NOTE: the initialiser is applied right after each layer is built so
        # that the order in which random numbers are drawn stays stable for
        # seeded runs. `xavier_uniform_` is the in-place replacement of the
        # deprecated `xavier_uniform` alias.

        # Convolution 1
        self.conv1 = nn.Conv2d(in_channels=3, out_channels=16, kernel_size=4, stride=1, padding=0)
        nn.init.xavier_uniform_(self.conv1.weight)  # Xavier initialisation
        self.act1 = nn.ReLU()

        # Max Pool 1 (indices are kept so the decoder can unpool)
        self.maxpool1 = nn.MaxPool2d(kernel_size=2, return_indices=True)

        # Convolution 2
        self.conv2 = nn.Conv2d(in_channels=16, out_channels=32, kernel_size=5)
        nn.init.xavier_uniform_(self.conv2.weight)
        self.act2 = nn.ReLU()

        # Max Pool 2 (indices are kept so the decoder can unpool)
        self.maxpool2 = nn.MaxPool2d(kernel_size=2, return_indices=True)

        # Convolution 3
        self.conv3 = nn.Conv2d(in_channels=32, out_channels=64, kernel_size=3)
        nn.init.xavier_uniform_(self.conv3.weight)
        self.act3 = nn.ReLU()

        # Deconvolution 1 (mirrors Convolution 3)
        self.deconv1 = nn.ConvTranspose2d(in_channels=64, out_channels=32, kernel_size=3)
        nn.init.xavier_uniform_(self.deconv1.weight)
        self.act4 = nn.ReLU()

        # Max UnPool 1 (inverse of Max Pool 2)
        self.maxunpool1 = nn.MaxUnpool2d(kernel_size=2)

        # Deconvolution 2 (mirrors Convolution 2)
        self.deconv2 = nn.ConvTranspose2d(in_channels=32, out_channels=16, kernel_size=5)
        nn.init.xavier_uniform_(self.deconv2.weight)
        self.act5 = nn.ReLU()

        # Max UnPool 2 (inverse of Max Pool 1)
        self.maxunpool2 = nn.MaxUnpool2d(kernel_size=2)

        # Deconvolution 3 (mirrors Convolution 1)
        self.deconv3 = nn.ConvTranspose2d(in_channels=16, out_channels=3, kernel_size=4)
        nn.init.xavier_uniform_(self.deconv3.weight)
        self.act6 = nn.ReLU()

    def forward(self, x):
        """Encode and decode ``x``.

        Args:
            x: 4-D tensor with 3 channels in dimension 1. With the kernel
                sizes used here each spatial side must be at least 23,
                otherwise the third convolution has no valid output.

        Returns:
            Tensor with the same shape as ``x``; values are non-negative
            because the last layer is a ReLU.
        """
        # ---- Encoder ----
        out = self.conv1(x)
        out = self.act1(out)
        # Remember the pre-pool size: pooling floors odd sizes, so the
        # unpool layer needs the exact target size to invert it.
        size1 = out.size()
        out, indices1 = self.maxpool1(out)
        out = self.conv2(out)
        out = self.act2(out)
        size2 = out.size()
        out, indices2 = self.maxpool2(out)
        out = self.conv3(out)
        out = self.act3(out)

        # ---- Decoder (pool indices are consumed in reverse order) ----
        out = self.deconv1(out)
        out = self.act4(out)
        out = self.maxunpool1(out, indices2, size2)
        out = self.deconv2(out)
        out = self.act5(out)
        out = self.maxunpool2(out, indices1, size1)
        out = self.deconv3(out)
        out = self.act6(out)
        return out