# This is a notebook

In [1]:
import torch
from torch import cuda
from torch import nn, optim
import torch.nn.functional as F
from torchvision import datasets, transforms
from torch.utils import data
from torchvision.utils import save_image

In [2]:
data_dir = '../../../Data/'
batch_size = 64

torch.manual_seed(22)
device = torch.device("cuda" if cuda.is_available() else "cpu")
print(device)


cuda


In [3]:
dataset = datasets.ImageFolder(data_dir, transform=transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225))
]))

loader = data.DataLoader(dataset,
                         batch_size=batch_size,
                         shuffle=True,
                         pin_memory=cuda.is_available(),
                         num_workers= 1 if cuda.is_available() else 4)

## VAE - Existing Work
First the VAE that was already developed was migrated from Tensorflow to Pytorch
before starting to work on the VSC

In [19]:
class VAE(nn.Module):
    def __init__(self, latent_dim):
        super(VAE, self).__init__()
        self.latent_dim = latent_dim

        channels = 'placeholder'
        # Encoder
        self.encoder_conv1 = self.getConvolutionLayer(3, 128)
        self.encoder_conv2 = self.getConvolutionLayer(128, 64)
        self.encoder_conv3 = self.getConvolutionLayer(64, 32)
        
        self.flatten = nn.Flatten()

        self.encoder_fc1 = nn.Linear(4608, self.latent_dim)
        self.encoder_fc2 = nn.Linear(4608, self.latent_dim)
        
        # Decoder
        self.decoder_fc1 = nn.Sequential(
            nn.Linear(self.latent_dim, 4608),
            nn.ReLU()
        )
        # Reshape to 32x12x12
        self.decoder_upsampler1 = nn.Upsample(scale_factor=(2,2) ,mode='nearest')
        
        self.decoder_deconv1 = nn.Sequential(
            nn.Conv2d(in_channels=32, out_channels=64, kernel_size=3, stride=1, padding=1),
            nn.Upsample(scale_factor=(2,2) ,mode='nearest')
        )
        # 48x48x64
        self.decoder_deconv2 = nn.Sequential(
            nn.Conv2d(in_channels=64, out_channels=128, kernel_size=3, stride=1, padding=1),
            nn.Upsample(scale_factor=(2,2) ,mode='nearest')
        )

        self.decoder_conv1 = nn.Conv2d(in_channels=128, out_channels=3, kernel_size=3, stride=1, padding=1)
        # 96x96x128
        

    def getConvolutionLayer(self, in_channels, out_channels):
        return nn.Sequential(
            nn.Conv2d(in_channels=in_channels, out_channels=out_channels, kernel_size=3, stride=1, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2)
        )


    def encode(self, x):
        print(x.shape)
        x = self.encoder_conv1(x)
        print(x.shape)
        x = self.encoder_conv2(x)
        print(x.shape)
        x = self.encoder_conv3(x)
        print(x.shape)
        x = self.flatten(x)
        print(x.shape)
        mu = self.encoder_fc1(x)
        sigma = self.encoder_fc2(x)

        return mu, sigma

    def reparameterize(self, mu, logvar):
        std = torch.exp(0.5*logvar)
        # Keeps shape, samples from normal dist with mean 0 and variance 1
        eps = torch.randn_like(std)
        return mu + eps*std

    def decode(self, z):
        print(z.shape)
        z = self.decoder_fc1(z)
        print(z.shape)
        z = self.decoder_upsampler1(z.view(-1, 32, 12, 12))
        print(z.shape)
        z = self.decoder_deconv1(z)
        print(z.shape)
        z = self.decoder_deconv2(z)
        print(z.shape)
        recon = self.decoder_conv1(z)
        print(recon.shape)
        
        return recon

    def forward(self, x):
        mu, logvar = self.encode(x)
        z = self.reparameterize(mu, logvar)
        return self.decode(z), mu, logvar

In [20]:
vae = VAE(256).to(device)
vae

VAE(
  (encoder_conv1): Sequential(
    (0): Conv2d(3, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (1): ReLU()
    (2): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  )
  (encoder_conv2): Sequential(
    (0): Conv2d(128, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (1): ReLU()
    (2): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  )
  (encoder_conv3): Sequential(
    (0): Conv2d(64, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (1): ReLU()
    (2): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  )
  (flatten): Flatten()
  (encoder_fc1): Linear(in_features=4608, out_features=256, bias=True)
  (encoder_fc2): Linear(in_features=4608, out_features=256, bias=True)
  (decoder_fc1): Sequential(
    (0): Linear(in_features=256, out_features=4608, bias=True)
    (1): ReLU()
  )
  (decoder_upsampler1): Upsample(scale_factor=(2.0, 2.0), mode=nearest)
  (decoder_deco

In [21]:
vae.eval()
images, _ = next(iter(loader))
images = images.to(device)
images[0].shape

torch.Size([3, 96, 96])

In [22]:
print('*** Encoder ***')
mu, logvar = vae.encode(images)
print('*** Decoder ***')
z = vae.reparameterize(mu, logvar)
recon = vae.decode(z)

*** Encoder ***
torch.Size([64, 3, 96, 96])
torch.Size([64, 128, 48, 48])
torch.Size([64, 64, 24, 24])
torch.Size([64, 32, 12, 12])
torch.Size([64, 4608])
*** Decoder ***
torch.Size([64, 256])
torch.Size([64, 4608])
torch.Size([64, 32, 24, 24])
torch.Size([64, 64, 48, 48])
torch.Size([64, 128, 96, 96])
torch.Size([64, 3, 96, 96])


In [None]:
# Reconstruction + KL divergence losses summed over all elements and batch
def loss_function(recon_x, x, mu, logvar):
    BCE = F.binary_cross_entropy(recon_x, x.view(-1, 784), reduction='sum')

    # see Appendix B from VAE paper:
    # Kingma and Welling. Auto-Encoding Variational Bayes. ICLR, 2014
    # https://arxiv.org/abs/1312.6114
    # 0.5 * sum(1 + log(sigma^2) - mu^2 - sigma^2)
    KLD = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())

    return BCE + KLD