## Simple Autoencoder

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F

In [15]:
class Encoder(nn.Module):
    """Convolutional encoder mapping a single-channel image to a latent vector.

    Three 3x3, stride-2, padding-1 convolutions (1 -> 32 -> 64 -> 128
    channels, each followed by ReLU) shrink each spatial side from ``s`` to
    ``ceil(s / 2)``. The resulting feature map is flattened and projected by
    a linear layer to ``latent_dim`` values (no activation on the output).

    ``linear1`` is sized for a 128x4x4 feature map (2048 features). The conv
    stack produces that map when both input height and width lie in
    [25, 32], e.g. 28x28 or 32x32 images. Other input sizes generally fail
    in ``linear1`` with a shape mismatch.

    Args:
        latent_dim: Size of the latent code. Defaults to 2, the value the
            output layer was originally hard-coded to.
    """

    def __init__(self, latent_dim: int = 2):
        super().__init__()
        self.conv1 = nn.Conv2d(in_channels=1, out_channels=32, kernel_size=(3, 3),
                               stride=2, padding=(1, 1))
        self.conv2 = nn.Conv2d(in_channels=32, out_channels=64, kernel_size=(3, 3),
                               stride=2, padding=(1, 1))
        self.conv3 = nn.Conv2d(in_channels=64, out_channels=128, kernel_size=(3, 3),
                               stride=2, padding=(1, 1))
        self.relu = nn.ReLU()
        self.flatten = nn.Flatten()
        # 128 channels x 4 x 4 spatial positions leave conv3 for supported inputs.
        self.linear1 = nn.Linear(in_features=128 * 4 * 4, out_features=latent_dim)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Encode a batch of images.

        Args:
            x: Images of shape (N, 1, H, W) with H and W in [25, 32].

        Returns:
            Latent codes of shape (N, latent_dim).
        """
        x = self.relu(self.conv1(x))  # (N, 32, ceil(H/2), ceil(W/2))
        x = self.relu(self.conv2(x))  # (N, 64, ...), spatial size halved again
        x = self.relu(self.conv3(x))  # (N, 128, 4, 4) for supported inputs
        x = self.flatten(x)           # (N, 2048)
        return self.linear1(x)        # (N, latent_dim)
    


In [18]:
class Decoder(nn.Module):
    """Convolutional decoder mapping a latent vector to a single-channel image.

    ``linear1`` expands the latent code to 2048 features, which are reshaped
    into a 128x4x4 feature map (the counterpart of a TensorFlow ``Reshape``
    layer). Three 3x3 transposed convolutions follow, each with stride 2,
    padding 1 and output_padding 1 and each followed by ReLU. They double the
    spatial size 4 -> 8 -> 16 -> 32. A final 3x3 "same" convolution with a
    sigmoid then produces one channel.

    The output is therefore always (N, 1, 32, 32) with values in [0, 1],
    regardless of the size of the images fed to a matching encoder.

    Args:
        latent_dim: Size of the latent code. Defaults to 2, the value the
            input layer was originally hard-coded to.
    """

    def __init__(self, latent_dim: int = 2):
        super().__init__()
        # Expands to 128 channels x 4 x 4, the shape the forward pass reshapes to.
        self.linear1 = nn.Linear(in_features=latent_dim, out_features=128 * 4 * 4)
        self.convt1 = nn.ConvTranspose2d(in_channels=128, out_channels=128,
                                         kernel_size=(3, 3), stride=2, padding=(1, 1),
                                         output_padding=(1, 1))
        self.convt2 = nn.ConvTranspose2d(in_channels=128, out_channels=64,
                                         kernel_size=(3, 3), stride=2, padding=(1, 1),
                                         output_padding=(1, 1))
        self.convt3 = nn.ConvTranspose2d(in_channels=64, out_channels=32,
                                         kernel_size=(3, 3), stride=2, padding=(1, 1),
                                         output_padding=(1, 1))
        self.conv1 = nn.Conv2d(in_channels=32, out_channels=1, kernel_size=(3, 3),
                               stride=1, padding="same")
        self.relu = nn.ReLU()
        self.sigmoid = nn.Sigmoid()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Decode a batch of latent codes.

        Args:
            x: Latent codes of shape (N, latent_dim).

        Returns:
            Reconstructed images of shape (N, 1, 32, 32) with values in [0, 1].
        """
        x = self.linear1(x)            # (N, 2048)
        # Does the work of TensorFlow's Reshape layer; -1 infers the batch size.
        x = x.view(-1, 128, 4, 4)      # (N, 128, 4, 4)
        x = self.relu(self.convt1(x))  # (N, 128, 8, 8)
        x = self.relu(self.convt2(x))  # (N, 64, 16, 16)
        x = self.relu(self.convt3(x))  # (N, 32, 32, 32)
        # Sigmoid keeps the reconstructed pixels in [0, 1], matching input
        # images normalized to that range. If the images were normalized to
        # [-1, 1] instead, tanh would be the matching output activation.
        return self.sigmoid(self.conv1(x))  # (N, 1, 32, 32)
        

In [21]:
class AutoEncoder(nn.Module):
    """Compose an encoder and a decoder into one module.

    The two components are stored as the ``encoder`` and ``decoder``
    attributes. When they are ``nn.Module`` instances they are registered as
    submodules, so their parameters belong to this module.
    """

    def __init__(self, encoder, decoder):
        super().__init__()
        self.encoder = encoder
        self.decoder = decoder

    def forward(self, x):
        """Return the reconstruction ``decoder(encoder(x))``."""
        code = self.encoder(x)
        return self.decoder(code)
    