In [None]:
# CONSTANTS
# Training hyper-parameters read by `train()` (it raises NameError without them).
# NOTE(review): starting-point values (1e-3 is Adam's usual default) — tune as needed.
NUM_EPOCHS = 10
LEARNING_RATE = 1e-3

# About VAEs
VAEs (Variational Auto-Encoders), like plain auto-encoders, work as pairs of an encoder $e$ and a decoder $d$. The goals of these two components are to
- Maximize the amount of information retained by the encoding, $e^*$
- Minimize the amount of information lost when decoding, $d^*$

We express this objective as
$$ (e^*, d^*) = \underset{e,\, d}{\arg\min}\; \epsilon\left(x, d(e(x))\right) $$
where $d(e(x))$ is the reconstructed output and $\epsilon$ is the reconstruction error between it and the input $x$.

## Autoencoder

In [9]:
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F

import torchvision.datasets as datasets
import torchvision.transforms as transforms

In [None]:
def get_device():
    """Choose the device string to train on.

    Returns:
        str: ``'cuda:0'`` (the first GPU) when PyTorch reports CUDA is
        available, otherwise ``'cpu'``.
    """
    return 'cuda:0' if torch.cuda.is_available() else 'cpu'

# `trans` was referenced but never defined. ToTensor turns each MNIST PIL image
# into a float tensor of shape (1, 28, 28) with values in [0, 1].
trans = transforms.ToTensor()
# NOTE(review): DataLoader's default batch_size is 1; a larger batch trains much faster.
mnist_loader = torch.utils.data.DataLoader(
    datasets.MNIST(root='./data', train=True, download=True, transform=trans),
    shuffle=True)
# `train()` reads its batches from a global named `trainloader`.
trainloader = mnist_loader

In [12]:
class Autoencoder(nn.Module):
    """Fully-connected autoencoder for flattened inputs.

    The encoder maps ``layer_sizes[0]`` features down to ``layer_sizes[-1]``
    latent features through Linear layers. The decoder mirrors it back up.
    ReLU is applied between hidden layers only. The latent code and the
    reconstruction are left linear so an MSE loss can fit inputs in any
    value range.

    Args:
        layer_sizes: widths from input to latent code. The default is
            784 -> 64 -> 12 -> 3 (a flattened 28x28 image to a 3-d code).

    Raises:
        ValueError: if fewer than two sizes are given.
    """

    def __init__(self, layer_sizes=(28 * 28, 64, 12, 3)):
        super().__init__()
        sizes = list(layer_sizes)
        if len(sizes) < 2:
            raise ValueError("layer_sizes needs at least an input and a latent size")
        self.encoder = self._mlp(sizes)
        # Mirror the encoder with fresh layers (in/out features swapped).
        self.decoder = self._mlp(sizes[::-1])

    @staticmethod
    def _mlp(sizes):
        """Stack Linear layers along ``sizes`` with ReLU between them (none after the last)."""
        layers = []
        last = len(sizes) - 2
        for i, (n_in, n_out) in enumerate(zip(sizes[:-1], sizes[1:])):
            layers.append(nn.Linear(in_features=n_in, out_features=n_out))
            if i < last:
                layers.append(nn.ReLU())
        return nn.Sequential(*layers)

    def forward(self, x):
        """Reconstruct ``x``; expects shape (batch, layer_sizes[0])."""
        return self.decoder(self.encoder(x))

In [None]:
def train(loader=None, num_epochs=None, learning_rate=None, device=None,
          net=None, save_every=5):
    """Train an autoencoder to reconstruct its flattened inputs with MSE loss.

    Every argument is optional, so ``train()`` keeps working off the
    notebook's globals. Pass values explicitly to avoid depending on them.

    Args:
        loader: sized iterable of ``(images, labels)`` batches (labels ignored);
            defaults to the global ``trainloader``.
        num_epochs: passes over ``loader``; defaults to global ``NUM_EPOCHS``.
        learning_rate: Adam learning rate; defaults to global ``LEARNING_RATE``.
        device: device to train on; defaults to ``get_device()``.
        net: model to train; defaults to a fresh ``Autoencoder()``. It is
            moved to ``device``.
        save_every: call ``save_decoded_image`` on the last batch's
            reconstructions when ``epoch % save_every == 0``. ``0``/``None``
            disables saving.

    Returns:
        list[float]: mean per-batch training loss for each epoch.

    Raises:
        ValueError: if ``loader`` contains no batches.
    """
    # Resolve notebook globals lazily so explicit arguments never need them.
    if loader is None:
        loader = trainloader
    if num_epochs is None:
        num_epochs = NUM_EPOCHS
    if learning_rate is None:
        learning_rate = LEARNING_RATE
    if device is None:
        device = get_device()
    if net is None:
        net = Autoencoder()
    # Otherwise the epoch average divides by zero and `outputs` is never bound.
    if len(loader) == 0:
        raise ValueError("train(): the data loader yields no batches")

    # The model must live on the same device as the batches fed to it.
    net = net.to(device)
    net.train()

    criterion = nn.MSELoss()
    optimizer = optim.Adam(net.parameters(), lr=learning_rate)

    train_loss = []
    for epoch in range(num_epochs):
        running_loss = 0.0
        for img, _ in loader:
            img = img.to(device)
            # Flatten each sample: (batch, ...) -> (batch, features).
            img = img.view(img.size(0), -1)
            optimizer.zero_grad()
            outputs = net(img)
            loss = criterion(outputs, img)
            loss.backward()
            optimizer.step()
            running_loss += loss.item()

        epoch_loss = running_loss / len(loader)
        train_loss.append(epoch_loss)
        print('Epoch {} of {}, Train Loss: {:.3f}'.format(
            epoch + 1, num_epochs, epoch_loss))
        if save_every and epoch % save_every == 0:
            save_decoded_image(outputs.detach().cpu(), epoch)
    return train_loss