<a href="https://colab.research.google.com/github/AlbertFarkhutdinov/ml_lessons/blob/main/vae_by_mipt.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Variational Autoencoder

**Setup**

In [8]:
import os

import matplotlib.pyplot as plt
import numpy as np
import scipy.stats as stats

import torch
import torch.nn as nn
from torchsummary import summary
from torchvision import transforms
from torchvision.datasets import MNIST

from tqdm.notebook import tqdm, trange

In [2]:
# Colab-specific setup: `google.colab` is imported here rather than in the
# main import cell, so this cell presumably only works inside Google Colab.
from google.colab import drive
# Mount Google Drive at /content/drive; DATA_DIR below is built under a
# 'drive' directory relative to the current working directory.
drive.mount('/content/drive')

Mounted at /content/drive


In [3]:
%matplotlib inline

plt.style.use('seaborn')

In [4]:
# Directory used as the dataset root by the MNIST loaders below. It is
# resolved relative to the current working directory at the time this cell
# runs.
DATA_DIR = os.path.join(
    os.getcwd(),
    'drive',
    'My Drive',
    'Colab Notebooks',
    'data',
)
DATA_DIR

'/content/drive/My Drive/Colab Notebooks/data'

**The device on which a torch.Tensor will be allocated.**

In [5]:
# Use the first CUDA GPU when one is available, otherwise fall back to CPU.
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
device

device(type='cpu')

**Tensor transformations**

In [6]:
def rescale(image):
    """Min-max rescale ``image`` so its values span ``[0, 1]``.

    The minimum is subtracted first, then the result is divided by its new
    maximum. A constant image has zero range; dividing would yield 0/0 = NaN
    in every element, so in that case the all-zero shifted image is returned
    as is.

    Args:
        image: Array-like supporting ``.min()``, ``.max()`` and arithmetic
            (used in this notebook on the tensor produced by ``ToTensor``).

    Returns:
        A new object of the same shape with minimum 0 and, unless the input
        was constant, maximum 1. The input is not modified.
    """
    shifted = image - image.min()
    value_range = shifted.max()
    if value_range == 0:
        # Constant input: nothing to stretch, and dividing would produce NaN.
        return shifted
    return shifted / value_range

In [11]:
# Per-sample preprocessing applied by the MNIST datasets below:
# convert the image with ToTensor, then min-max rescale it with `rescale`.
mnist_transformations = transforms.Compose(
    [transforms.ToTensor(), rescale]
)

In [12]:
# Number of samples per batch; passed as `batch_size` to both DataLoaders below.
BATCH_SIZE = 256
# Passed as `num_workers` to both DataLoaders below.
NUM_DATALOADER_WORKERS = 1

In [16]:
# MNIST training split stored under DATA_DIR (downloaded if missing), with
# the shared preprocessing pipeline applied to every image.
train_dataset = MNIST(
    root=DATA_DIR,
    train=True,
    transform=mnist_transformations,
    download=True,
)

# Shuffled mini-batches for training.
train_loader = torch.utils.data.DataLoader(
    train_dataset,
    batch_size=BATCH_SIZE,
    shuffle=True,
    num_workers=NUM_DATALOADER_WORKERS,
)

In [None]:
# MNIST test split read from DATA_DIR with the same preprocessing as the
# training data. No `download` flag is passed here.
test_dataset = MNIST(
    root=DATA_DIR,
    train=False,
    transform=mnist_transformations,
)

# Fixed-order mini-batches for evaluation.
test_loader = torch.utils.data.DataLoader(
    test_dataset,
    batch_size=BATCH_SIZE,
    shuffle=False,
    num_workers=NUM_DATALOADER_WORKERS,
)

In [None]:
class Encoder:
    """Placeholder for a standalone encoder that was never written.

    This cell originally held only the bare header ``class Encoder`` (no
    colon, no body), which is a SyntaxError and stops "Restart & Run All".
    The encoder network actually used in this notebook is the ``encoder``
    attribute built inside ``VariationalAutoEncoder``.

    TODO: implement this class or delete the cell.
    """

In [17]:
class VariationalAutoEncoder(nn.Module):
    """Fully connected variational autoencoder.

    The encoder flattens the input and maps it through two hidden layers;
    two linear heads then predict a mean and a log standard deviation for
    every latent dimension. A latent sample is drawn with the
    reparametrisation ``exp(log_sigma) * eps + mu`` and passed through the
    decoder, which ends in a sigmoid followed by ``RestoreShape``.

    Args:
        intermediate_dims: Sequence whose first two entries are the widths of
            the hidden layers (encoder uses them in order, decoder in
            reverse).
        latent_dim: Size of the latent vector.
        input_shape: Shape of one input sample; its product is the flattened
            input size. Also passed to ``RestoreShape`` (defined outside this
            block; presumably reshapes decoder output back to this shape).
    """

    def __init__(self, intermediate_dims, latent_dim, input_shape):
        super().__init__()
        # Parameters of the standard-normal prior, kept as buffers so they
        # follow the module through .to(device) / dtype casts.
        self.register_buffer('_initial_mu', torch.zeros(latent_dim))
        self.register_buffer('_initial_sigma', torch.ones(latent_dim))

        # np.prod returns a NumPy scalar; layer sizes are plain Python ints.
        input_dim = int(np.prod(input_shape))
        first_hidden, second_hidden = intermediate_dims[0], intermediate_dims[1]

        self.encoder = nn.Sequential(
            nn.Flatten(),
            nn.Linear(input_dim, first_hidden),
            nn.ReLU(),
            nn.BatchNorm1d(first_hidden),
            nn.Dropout(0.3),
            nn.Linear(first_hidden, second_hidden),
            nn.ReLU(),
            nn.BatchNorm1d(second_hidden),
            nn.Dropout(0.3)
        )

        # Separate heads for the posterior mean and log standard deviation.
        self.mu_repr = nn.Linear(second_hidden, latent_dim)
        self.log_sigma_repr = nn.Linear(second_hidden, latent_dim)

        self.decoder = nn.Sequential(
            nn.Linear(latent_dim, second_hidden),
            nn.LeakyReLU(),
            nn.BatchNorm1d(second_hidden),
            nn.Dropout(0.3),
            nn.Linear(second_hidden, first_hidden),
            nn.LeakyReLU(),
            nn.BatchNorm1d(first_hidden),
            nn.Dropout(0.3),
            nn.Linear(first_hidden, input_dim),
            nn.Sigmoid(),
            RestoreShape(input_shape)
        )

    @property
    def latent_distribution(self):
        """Prior over the latent space, built from the current buffers.

        Constructing the distribution once in ``__init__`` would pin it to
        the tensors that existed at construction time: moving or casting the
        module replaces the buffer tensors, so a cached distribution would
        keep sampling on the old device/dtype. Building it on access keeps
        it in sync with the module.
        """
        return torch.distributions.normal.Normal(
            loc=self._initial_mu,
            scale=self._initial_sigma
        )

    def _encode(self, x):
        """Run the encoder and both heads.

        Returns:
            Tuple ``(mu_values, log_sigma_values, latent_repr)`` where
            ``latent_repr`` is the encoder output fed to the two heads.
        """
        latent_repr = self.encoder(x)
        mu_values = self.mu_repr(latent_repr)
        log_sigma_values = self.log_sigma_repr(latent_repr)
        return mu_values, log_sigma_values, latent_repr

    def _reparametrize(self, sample, mu_values, log_sigma_values):
        """Scale and shift ``sample``: ``exp(log_sigma) * sample + mu``."""
        latent_sample = torch.exp(log_sigma_values) * sample + mu_values
        return latent_sample

    def forward(self, x, raw_sample=None):
        """Encode ``x``, draw a latent sample and decode it.

        Args:
            x: Input batch accepted by the encoder.
            raw_sample: Optional noise with the same shape as the predicted
                means. When ``None`` it is drawn with ``torch.randn_like``.

        Returns:
            Tuple ``(reconstructed_repr, latent_sample, mu_values,
            log_sigma_values)``.
        """
        mu_values, log_sigma_values, _ = self._encode(x)

        if raw_sample is None:
            raw_sample = torch.randn_like(mu_values)

        latent_sample = self._reparametrize(raw_sample, mu_values, log_sigma_values)

        reconstructed_repr = self.decoder(latent_sample)

        return reconstructed_repr, latent_sample, mu_values, log_sigma_values