# Variational Autoencoder

## Data preparation

In [None]:
from dataclasses import dataclass
from typing import Optional

import numpy as np
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision import datasets
from torchvision.transforms import v2
from tqdm.auto import tqdm

# Seed torch's global RNG so weight initialisation and shuffling are repeatable.
torch.manual_seed(123)

# Hyperparameters shared by both models trained below.
batch_size = 512
learning_rate = 5*1e-3
num_epochs = 50
latent_dim = 2
hidden_dim = [512, 256]

# Convert every image to a float32 tensor (scale=True rescales the integer
# pixel values) and flatten it into a single vector.
transform = v2.Compose([
    v2.ToImage(),
    v2.ToDtype(torch.float32, scale=True),
    v2.Lambda(lambda x: x.view(-1)),
])

# download=False: the dataset files must already be present under FashionMNIST_data/.
train_data = datasets.FashionMNIST('FashionMNIST_data/', download=False, train=True, transform=transform)
test_data = datasets.FashionMNIST('FashionMNIST_data/', download=False, train=False, transform=transform)

train_loader = torch.utils.data.DataLoader(train_data, batch_size=batch_size, shuffle=True)
test_loader = torch.utils.data.DataLoader(test_data, batch_size=batch_size, shuffle=False)

# Peek at a single batch to read off the flattened input size.
# The transform already flattened each image, so X is 2-D: [N, D].
for X, _ in train_loader:
    input_dim = X.shape[1]
    print(f"Shape of X [N, D]: {X.shape}")
    break

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Using device: {device}")

## Autoencoder definition

In [2]:
@dataclass
class AEOutput:
    """
    Container for the results of one autoencoder forward pass.

    Attributes:
        z_proj (torch.Tensor): Latent code produced by the encoder.
        x_recon (torch.Tensor): Reconstruction produced by the decoder.
        loss (Optional[torch.Tensor]): Reconstruction loss, or None when the
            forward pass was asked not to compute it.
    """
    z_proj: torch.Tensor
    x_recon: torch.Tensor
    # Optional because the forward pass stores None here when compute_loss=False.
    loss: Optional[torch.Tensor] = None

class AE(nn.Module):
    """
    Deterministic fully connected autoencoder.

    The encoder maps a flat input vector to a ``latent_dim``-dimensional code.
    The decoder mirrors the encoder (hidden sizes in reverse order) and ends
    with a sigmoid.

    Args:
        input_dim (int): Dimensionality of the input data.
        hidden_dim (int | list[int] | tuple[int, ...]): Width of the single
            hidden layer, or one width per hidden layer in encoder order.
        latent_dim (int): Dimensionality of the latent space.
    """

    def __init__(self, input_dim, hidden_dim, latent_dim):
        super().__init__()

        self.input_dim = input_dim
        self.latent_dim = latent_dim

        if isinstance(hidden_dim, (list, tuple)):
            hidden_sizes = list(hidden_dim)
            # Store a private copy so later mutation by the caller cannot affect the model.
            self.hidden_dim = list(hidden_sizes)
        else:
            hidden_sizes = [hidden_dim]
            self.hidden_dim = hidden_dim

        # The encoder ends in exactly latent_dim features for both the scalar
        # and the list form of hidden_dim, matching what the decoder consumes.
        encoder_layers_list = self._mlp_layers([input_dim, *hidden_sizes, latent_dim])
        self.encoder = nn.Sequential(*encoder_layers_list)

        decoder_layers_list = self._mlp_layers([latent_dim, *reversed(hidden_sizes), input_dim])
        decoder_layers_list.append(nn.Sigmoid())
        self.decoder = nn.Sequential(*decoder_layers_list)

        print(f"Encoder: {encoder_layers_list}")
        print(f"Decoder: {decoder_layers_list}")

    @staticmethod
    def _mlp_layers(sizes):
        """Return Linear layers for consecutive sizes, with a LeakyReLU between each pair of them."""
        layers = []
        for idx, (n_in, n_out) in enumerate(zip(sizes[:-1], sizes[1:])):
            if idx > 0:
                layers.append(nn.LeakyReLU())
            layers.append(nn.Linear(n_in, n_out))
        return layers

    def encode(self, x):
        """
        Maps the input to its latent code.

        Args:
            x (torch.Tensor): Input data whose last dimension is input_dim.

        Returns:
            torch.Tensor: Latent code whose last dimension is latent_dim.
        """
        return self.encoder(x)

    def decode(self, z):
        """
        Maps a latent code back to the input space.

        Args:
            z (torch.Tensor): Latent code whose last dimension is latent_dim.

        Returns:
            torch.Tensor: Reconstruction whose last dimension is input_dim.
        """
        return self.decoder(z)

    def forward(self, x, compute_loss: bool = True):
        """
        Encodes and decodes the input, optionally computing the loss.

        Args:
            x (torch.Tensor): Input data; it is also used as the target of the
                binary cross-entropy loss.
            compute_loss (bool): Whether to compute the loss or not.

        Returns:
            AEOutput: Latent code, reconstruction and loss (None when
            compute_loss is False).
        """
        z_proj = self.encode(x)
        recon_x = self.decode(z_proj)

        if not compute_loss:
            return AEOutput(
                z_proj=z_proj,
                x_recon=recon_x,
                loss=None,
            )

        # Sum the element-wise BCE over the last dimension, then average the rest.
        loss = F.binary_cross_entropy(recon_x, x, reduction='none').sum(-1).mean()

        return AEOutput(
            z_proj=z_proj,
            x_recon=recon_x,
            loss=loss,
        )

## Variational Autoencoder definition

In [None]:
@dataclass
class VAEOutput:
    """
    Container for the results of one VAE forward pass.

    Attributes:
        z_dist (torch.distributions.Distribution): The distribution of the latent variable z.
        z_sample (torch.Tensor): The sampled value of the latent variable z.
        x_recon (torch.Tensor): The reconstructed output from the VAE.
        loss (Optional[torch.Tensor]): The overall loss of the VAE.
        loss_recon (Optional[torch.Tensor]): The reconstruction loss component of the VAE loss.
        loss_kl (Optional[torch.Tensor]): The KL divergence component of the VAE loss.

    The three loss fields are None when the forward pass was asked not to
    compute the loss.
    """
    z_dist: torch.distributions.Distribution
    z_sample: torch.Tensor
    x_recon: torch.Tensor

    # Optional because the forward pass stores None here when compute_loss=False.
    loss: Optional[torch.Tensor] = None
    loss_recon: Optional[torch.Tensor] = None
    loss_kl: Optional[torch.Tensor] = None

class VAE(nn.Module):
    """
    Variational Autoencoder (VAE) class.

    The encoder outputs ``2 * latent_dim`` features per input: the first half
    is used as the mean of a diagonal Gaussian over the latent variable, the
    second half parameterises its per-dimension scale. The decoder mirrors the
    encoder (hidden sizes in reverse order) and ends with a sigmoid.

    Args:
        input_dim (int): Dimensionality of the input data.
        hidden_dim (int | list[int] | tuple[int, ...]): Width of the single
            hidden layer, or one width per hidden layer in encoder order.
        latent_dim (int): Dimensionality of the latent space.
    """

    def __init__(self, input_dim, hidden_dim, latent_dim):
        super().__init__()

        self.input_dim = input_dim
        self.latent_dim = latent_dim

        if isinstance(hidden_dim, (list, tuple)):
            hidden_sizes = list(hidden_dim)
            # Store a private copy so later mutation by the caller cannot affect the model.
            self.hidden_dim = list(hidden_sizes)
        else:
            hidden_sizes = [hidden_dim]
            self.hidden_dim = hidden_dim

        # Twice latent_dim outputs: one half for the means, one half for the scales.
        encoder_layers_list = self._mlp_layers([input_dim, *hidden_sizes, 2 * latent_dim])
        self.encoder = nn.Sequential(*encoder_layers_list)
        self.softplus = nn.Softplus()

        decoder_layers_list = self._mlp_layers([latent_dim, *reversed(hidden_sizes), input_dim])
        decoder_layers_list.append(nn.Sigmoid())
        self.decoder = nn.Sequential(*decoder_layers_list)

        print(f"Encoder: {encoder_layers_list}")
        print(f"Decoder: {decoder_layers_list}")

    @staticmethod
    def _mlp_layers(sizes):
        """Return Linear layers for consecutive sizes, with a LeakyReLU between each pair of them."""
        layers = []
        for idx, (n_in, n_out) in enumerate(zip(sizes[:-1], sizes[1:])):
            if idx > 0:
                layers.append(nn.LeakyReLU())
            layers.append(nn.Linear(n_in, n_out))
        return layers

    def encode(self, x, eps: float = 1e-8):
        """
        Encodes the input data into the latent space.

        Args:
            x (torch.Tensor): Input data.
            eps (float): Small value added to the scale to avoid numerical instability.

        Returns:
            torch.distributions.MultivariateNormal: Normal distribution of the encoded data.
        """
        stats = self.encoder(x)
        # First half of the features -> means, second half -> unconstrained scales.
        mu, raw_scale = torch.chunk(stats, 2, dim=-1)
        # Softplus keeps the scale positive; eps keeps it away from exactly zero.
        scale = self.softplus(raw_scale) + eps
        # A diagonal lower-triangular factor gives independent latent dimensions.
        scale_tril = torch.diag_embed(scale)
        return torch.distributions.MultivariateNormal(mu, scale_tril=scale_tril)

    def reparameterize(self, dist):
        """
        Reparameterizes the encoded data to sample from the latent space.

        Args:
            dist (torch.distributions.MultivariateNormal): Normal distribution of the encoded data.
        Returns:
            torch.Tensor: Sampled data from the latent space.
        """
        # rsample (not sample) so gradients can flow back into the distribution parameters.
        return dist.rsample()

    def decode(self, z):
        """
        Decodes the data from the latent space to the original input space.

        Args:
            z (torch.Tensor): Data in the latent space.

        Returns:
            torch.Tensor: Reconstructed data in the original input space.
        """
        return self.decoder(z)

    def forward(self, x, compute_loss: bool = True):
        """
        Performs a forward pass of the VAE.

        Args:
            x (torch.Tensor): Input data; it is also used as the target of the
                binary cross-entropy reconstruction loss.
            compute_loss (bool): Whether to compute the loss or not.

        Returns:
            VAEOutput: VAE output dataclass (loss fields are None when
            compute_loss is False).
        """
        dist = self.encode(x)
        z = self.reparameterize(dist)
        recon_x = self.decode(z)

        if not compute_loss:
            return VAEOutput(
                z_dist=dist,
                z_sample=z,
                x_recon=recon_x,
                loss=None,
                loss_recon=None,
                loss_kl=None,
            )

        # Reconstruction term: element-wise BCE summed over the last dimension, then averaged.
        loss_recon = F.binary_cross_entropy(recon_x, x, reduction='none').sum(-1).mean()

        # Prior: a standard normal with one identity scale factor per batch element.
        identity = torch.eye(z.shape[-1], device=z.device, dtype=z.dtype)
        std_normal = torch.distributions.MultivariateNormal(
            torch.zeros_like(z),
            scale_tril=identity.unsqueeze(0).expand(z.shape[0], -1, -1),
        )
        loss_kl = torch.distributions.kl.kl_divergence(dist, std_normal).mean()

        loss = loss_recon + loss_kl

        return VAEOutput(
            z_dist=dist,
            z_sample=z,
            x_recon=recon_x,
            loss=loss,
            loss_recon=loss_recon,
            loss_kl=loss_kl,
        )

## Train and Test functions

In [4]:
def train(model, dataloader, optimizer, prev_updates):
    """
    Trains the model for one pass over the given data.

    Relies on the module-level ``device`` (where batches are moved) and
    ``batch_size`` (used only for the sample count in the summary line).

    Args:
        model (nn.Module): The model to train. Its forward pass must return an
            object exposing a scalar ``loss`` attribute.
        dataloader (torch.utils.data.DataLoader): The data loader; it yields
            (data, label) pairs and the labels are ignored.
        optimizer: The optimizer.
        prev_updates (int): Number of updates performed before this call.

    Returns:
        tuple: ``prev_updates + len(dataloader)`` and the list of per-batch
        loss values recorded during this pass.
    """
    model.train()  # Set the model to training mode

    loss_history = []
    # Initialised up front so the summary line still works if the loader yields nothing.
    n_upd = prev_updates
    for batch_idx, (data, _) in enumerate(tqdm(dataloader, desc='Training')):
        n_upd = prev_updates + batch_idx

        data = data.to(device)
        data = data.view(data.size(0), -1)  # Flatten the data, as test() does

        optimizer.zero_grad()  # Zero the gradients

        output = model(data)  # Forward pass

        loss = output.loss
        loss.backward()

        optimizer.step()  # Update the model parameters

        loss_history.append(loss.item())

    # Avoid numpy's empty-mean warning when no batch was processed.
    mean_loss = np.mean(loss_history) if loss_history else float('nan')
    print(f'Step {n_upd:,} (N samples: {n_upd*batch_size:,}), Train set loss: {mean_loss:.4f}')

    return prev_updates + len(dataloader), loss_history

def test(model, dataloader):
    """
    Evaluates the model on the given data without tracking gradients.

    Args:
        model (nn.Module): The model to evaluate. Its forward pass must accept
            ``compute_loss`` and return an object exposing a ``loss`` attribute.
        dataloader (torch.utils.data.DataLoader): The data loader; it yields
            (data, label) pairs and the labels are ignored.

    Returns:
        float: Sum of the per-batch losses divided by the number of batches.
    """
    model.eval()  # Evaluation mode
    batch_losses = []

    with torch.no_grad():
        for batch, _ in tqdm(dataloader, desc='Testing'):
            # Move to the module-level device and flatten to (batch, features).
            flat = batch.to(device)
            flat = flat.view(flat.size(0), -1)

            result = model(flat, compute_loss=True)
            batch_losses.append(result.loss.item())

    test_loss = sum(batch_losses) / len(dataloader)
    print(f'====> Test set loss: {test_loss:.4f}')

    return test_loss

## Train the Autoencoder

In [None]:
# Build the plain autoencoder on the selected device and optimise it with AdamW.
print("Training Autoencoder...")
model_AE = AE(input_dim=input_dim, hidden_dim=hidden_dim, latent_dim=latent_dim).to(device)
print(model_AE)
optimizer = torch.optim.AdamW(model_AE.parameters(), lr=learning_rate)

# One entry per epoch: mean of the per-batch training losses, and the test loss.
train_loss_history_AE = []
test_loss_history_AE = []
prev_updates = 0  # running update counter threaded through train()
for epoch in range(num_epochs):
    print(f'Epoch {epoch+1}/{num_epochs}')
    # train() returns the updated counter and the per-batch losses of this epoch.
    prev_updates, train_loss = train(model_AE, train_loader, optimizer, prev_updates)
    test_loss = test(model_AE, test_loader)
    train_loss_history_AE.append(np.mean(train_loss))
    test_loss_history_AE.append(test_loss)

In [None]:
# Plot the autoencoder's per-epoch train and test losses against 1-based epoch numbers.
# Both curves share the x-range derived from the train history length.
plt.plot(range(1,len(train_loss_history_AE)+1), train_loss_history_AE, label="Train loss (epoch average)")
plt.plot(range(1,len(train_loss_history_AE)+1), test_loss_history_AE, label="Test loss")
plt.legend()
plt.xlabel("Epochs")
plt.title("Autoencoder")
plt.show()

## Train the Variational Autoencoder

In [None]:
# Build the variational autoencoder on the selected device and optimise it with AdamW.
# Note: this rebinds the module-level `optimizer` and `prev_updates` names.
print("Training Variational Autoencoder...")
model_VAE = VAE(input_dim=input_dim, hidden_dim=hidden_dim, latent_dim=latent_dim).to(device)
print(model_VAE)
optimizer = torch.optim.AdamW(model_VAE.parameters(), lr=learning_rate)

# One entry per epoch: mean of the per-batch training losses, and the test loss.
train_loss_history_VAE = []
test_loss_history_VAE = []
prev_updates = 0  # running update counter threaded through train()
for epoch in range(num_epochs):
    print(f'Epoch {epoch+1}/{num_epochs}')
    # train() returns the updated counter and the per-batch losses of this epoch.
    prev_updates, train_loss = train(model_VAE, train_loader, optimizer, prev_updates)
    test_loss = test(model_VAE, test_loader)
    train_loss_history_VAE.append(np.mean(train_loss))
    test_loss_history_VAE.append(test_loss)

In [None]:
# Plot the VAE's per-epoch train and test losses against 1-based epoch numbers.
# Both curves share the x-range derived from the train history length.
plt.plot(range(1,len(train_loss_history_VAE)+1), train_loss_history_VAE, label="Train loss (epoch average)")
plt.plot(range(1,len(train_loss_history_VAE)+1), test_loss_history_VAE, label="Test loss")
plt.legend()
plt.xlabel("Epochs")
plt.title("Variational Autoencoder")
plt.show()

## Plot functions

In [None]:
def plot_latent_space(model):
    """
    Scatter-plots the first two latent coordinates of the training data,
    coloured by class label.

    Iterates over the module-level ``train_loader``.

    Args:
        model (nn.Module): Model exposing ``encode``. ``encode`` may return
            either the latent tensor itself or a torch distribution, in which
            case the distribution's mean is plotted.
    """
    model.eval()
    # Encode on whatever device the model's parameters live on.
    model_device = next(model.parameters()).device
    z_all = []
    y_all = []
    with torch.no_grad():
        for data, target in tqdm(train_loader, desc='Encoding'):
            data = data.to(model_device)
            data = data.view(data.size(0), -1)  # Flatten the data
            encoded = model.encode(data)
            if isinstance(encoded, torch.distributions.Distribution):
                # Use the mean as a deterministic projection instead of a random sample.
                encoded = encoded.mean
            z_all.append(encoded.cpu().numpy())
            y_all.append(target.numpy())

    z_all = np.concatenate(z_all, axis=0)
    y_all = np.concatenate(y_all, axis=0)

    plt.figure(figsize=(10, 10))
    plt.scatter(z_all[:, 0], z_all[:, 1], c=y_all, cmap='tab10')
    plt.xlim(-6,6)
    plt.ylim(-6,6)
    plt.colorbar()
    plt.title('Latent projection')
    plt.show()

def plot_random_samples(model):
    """
    Decodes 100 latent vectors drawn from a standard normal and shows them
    as a 10x10 grid of 28x28 grayscale images.

    Args:
        model (nn.Module): Model exposing a ``latent_dim`` attribute and a
            ``decode`` method whose output holds 28*28 values per sample.
    """
    model.eval()
    # Sample on whatever device the model's parameters live on.
    model_device = next(model.parameters()).device
    with torch.no_grad():
        z = torch.randn(100, model.latent_dim, device=model_device)
        samples = model.decode(z)

    fig, ax = plt.subplots(10, 10, figsize=(10, 10))
    for i in range(10):
        for j in range(10):
            ax[i, j].imshow(samples[i*10+j].view(28, 28).cpu().detach().numpy(), cmap='gray')
            ax[i, j].axis('off')
    plt.show()

## Show plots

In [None]:
# Show the latent projection of each trained model, one after the other.
print("Autoencoder:")
plot_latent_space(model_AE)
print("Variational Autoencoder:")
plot_latent_space(model_VAE)

In [None]:
# Show a grid of generated samples for each trained model, one after the other.
print("Autoencoder:")
plot_random_samples(model_AE)
print("Variational Autoencoder:")
plot_random_samples(model_VAE)