# This is a notebook

In [None]:
import torch
from torch import cuda
from torch import nn, optim
import torch.nn.functional as F
from torchvision import datasets, transforms
from torchvision.utils import save_image

In [None]:
# Seed PyTorch's default RNG so that stochastic steps are repeatable across runs.
torch.manual_seed(22)

# Use the GPU when CUDA reports one as available; otherwise stay on the CPU.
device = torch.device("cuda") if cuda.is_available() else torch.device("cpu")
print(device)

## VAE - Existing Work
First, the existing VAE was migrated from TensorFlow to PyTorch
before starting work on the VSC.

In [None]:
class VAE(nn.Module):
    """Variational autoencoder with a convolutional encoder and a dense decoder.

    The encoder is three ``Conv2d(3x3, stride 1, "same" padding) -> ReLU ->
    MaxPool2d(2)`` stages (128, 64 and 32 filters) followed by two linear
    heads that produce the mean and the log-variance of the approximate
    posterior. The decoder is a two-layer MLP ending in a sigmoid, and it
    returns the reconstruction flattened to ``(batch, in_channels *
    image_size * image_size)``.

    Args:
        latent_dim: Size of the latent vector ``z``.
        in_channels: Number of channels of the input images. Defaults to 1.
        image_size: Height and width of the (square) input images. Must be
            at least 8 so that three 2x2 poolings leave a non-empty feature
            map. Defaults to 28, i.e. 784 pixels per single-channel image.
        hidden_dim: Width of the decoder's hidden layer. Defaults to 400.

    Raises:
        ValueError: If ``image_size`` is smaller than 8.
    """

    def __init__(self, latent_dim, in_channels=1, image_size=28, hidden_dim=400):
        super().__init__()
        self.latent_dim = latent_dim
        self.in_channels = in_channels
        self.image_size = image_size

        # Each MaxPool2d(2) halves the spatial size (rounding down); the
        # convolutions keep it unchanged thanks to padding=1.
        pooled_size = image_size
        for _ in range(3):
            pooled_size //= 2
        if pooled_size < 1:
            raise ValueError(
                "image_size must be at least 8 to survive three 2x2 poolings, "
                "got {}".format(image_size)
            )
        flat_features = 32 * pooled_size * pooled_size

        # Encoder.
        # For kernel_size=3 and stride=1, padding=1 gives "same" output size.
        self.encoder_conv1 = nn.Conv2d(in_channels=in_channels, out_channels=128,
                                       kernel_size=3, stride=1, padding=1)
        self.encoder_maxpool1 = nn.MaxPool2d(kernel_size=2)

        self.encoder_conv2 = nn.Conv2d(in_channels=128, out_channels=64,
                                       kernel_size=3, stride=1, padding=1)
        self.encoder_maxpool2 = nn.MaxPool2d(kernel_size=2)

        self.encoder_conv3 = nn.Conv2d(in_channels=64, out_channels=32,
                                       kernel_size=3, stride=1, padding=1)
        self.encoder_maxpool3 = nn.MaxPool2d(kernel_size=2)

        # Two independent heads on the flattened features: mean and log-variance.
        self.encoder_fc1 = nn.Linear(flat_features, self.latent_dim)
        self.encoder_fc2 = nn.Linear(flat_features, self.latent_dim)

        # Decoder: latent vector -> hidden layer -> flattened image.
        self.fc3 = nn.Linear(self.latent_dim, hidden_dim)
        self.fc4 = nn.Linear(hidden_dim, in_channels * image_size * image_size)

    def encode(self, x):
        """Map a batch of images to the posterior parameters.

        Args:
            x: Tensor of shape ``(batch, in_channels, image_size, image_size)``.

        Returns:
            Tuple ``(mu, logvar)``, each of shape ``(batch, latent_dim)``.
        """
        x = self.encoder_maxpool1(F.relu(self.encoder_conv1(x)))
        x = self.encoder_maxpool2(F.relu(self.encoder_conv2(x)))
        x = self.encoder_maxpool3(F.relu(self.encoder_conv3(x)))

        # Linear layers act on the last dimension only, so collapse
        # (channels, height, width) into a single feature axis first.
        x = torch.flatten(x, start_dim=1)

        mu = self.encoder_fc1(x)
        logvar = self.encoder_fc2(x)

        return mu, logvar

    def reparameterize(self, mu, logvar):
        """Sample ``z = mu + eps * std`` with ``eps ~ N(0, I)``.

        Drawing the noise separately keeps the sample differentiable with
        respect to ``mu`` and ``logvar``.
        """
        std = torch.exp(0.5 * logvar)
        eps = torch.randn_like(std)
        return mu + eps * std

    def decode(self, z):
        """Decode latent vectors into flattened reconstructions in ``[0, 1]``.

        Args:
            z: Tensor of shape ``(batch, latent_dim)``.

        Returns:
            Tensor of shape ``(batch, in_channels * image_size * image_size)``.
        """
        h3 = F.relu(self.fc3(z))
        return torch.sigmoid(self.fc4(h3))

    def forward(self, x):
        """Encode, sample and decode a batch.

        Args:
            x: Images either as ``(batch, in_channels, image_size, image_size)``
                or already flattened per sample; the input is reshaped to the
                4-D layout that the convolutional encoder requires.

        Returns:
            Tuple ``(reconstruction, mu, logvar)``.
        """
        x = x.reshape(-1, self.in_channels, self.image_size, self.image_size)
        mu, logvar = self.encode(x)
        z = self.reparameterize(mu, logvar)
        return self.decode(z), mu, logvar

In [None]:
# Reconstruction + KL divergence losses summed over all elements and batch
def loss_function(recon_x, x, mu, logvar):
    """Return the VAE loss: reconstruction BCE plus KL divergence.

    Both terms are summed over every element and over the batch.

    Args:
        recon_x: Reconstruction with values in ``[0, 1]``.
        x: Target data with the same number of elements as ``recon_x``; it is
            reshaped to ``recon_x``'s shape, so it may be passed either as
            images or already flattened.
        mu: Mean of the approximate posterior.
        logvar: Log-variance of the approximate posterior.

    Returns:
        Scalar tensor ``BCE + KLD``.
    """
    # Match the target to the reconstruction's layout instead of assuming a
    # fixed 784-pixel image; reshape also copes with non-contiguous targets.
    target = x.reshape(recon_x.shape)
    BCE = F.binary_cross_entropy(recon_x, target, reduction='sum')

    # see Appendix B from VAE paper:
    # Kingma and Welling. Auto-Encoding Variational Bayes. ICLR, 2014
    # https://arxiv.org/abs/1312.6114
    # KLD = -0.5 * sum(1 + log(sigma^2) - mu^2 - sigma^2)
    KLD = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())

    return BCE + KLD