In [1]:
import torch
import torchvision.datasets
from torchvision.transforms import ToTensor
from torch.utils.data import DataLoader
from torch import nn

from torchsummary import summary

import matplotlib.pyplot as plt

## Datasets

In [2]:
mnist_dataset = torchvision.datasets.MNIST(
    root="Mnist_dataset",
    download=True,
    train=True,
    transform=ToTensor()
)

In [3]:
train,val,test = torch.utils.data.random_split(mnist_dataset,(0.8,0.1,0.1))

In [4]:
batch_size = 500
train_loader = DataLoader(
    dataset=train,
    shuffle=True,
    batch_size=batch_size
)

In [5]:
val_loader = DataLoader(
    dataset=val,
    shuffle=True,
    batch_size=batch_size
)

## VAE Encoder
Similar to standard Autoencoder, but with two independent lineal layer to represent mean and logvar


In [50]:
def Conv_block(in_channels,out_channels, kernel_size = 4, stride = 2, padding = 1):
    return nn.Sequential(
        nn.Conv2d(
            in_channels=in_channels,
            out_channels=out_channels,
            kernel_size=kernel_size,
            stride=stride,
            padding=padding,
        ),
        nn.BatchNorm2d(out_channels),
        nn.ReLU()
    )

class Encoder(nn.Module):
    def __init__(self, in_channels:int , latent_dims:int):
        super().__init__()
        self.convolutions = nn.Sequential(
            Conv_block(in_channels,128),
            Conv_block(128,256),
            Conv_block(256,512),
            Conv_block(512,1024)
        )

        self.mu = nn.Linear(1024,latent_dims)
        self.logvar = nn.Linear(1024,latent_dims)

    def forward(self, x:torch.Tensor):
        bs = x.shape[0]
        # print(x.shape)

        x = self.convolutions(x)
        x = x.reshape(bs,-1)
        # x = x.flatten(start_dim=1)
        mu = self.mu(x)
        logvar = self.logvar(x)

        return mu,logvar
        # return x

In [None]:
# VAE_encoder = Encoder(in_channels=1,latent_dims=2).cuda()

# # VAE_encoder(a.cuda()).shape

# summary(VAE_encoder, (1,28,28), batch_size=100)

----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
            Conv2d-1         [100, 128, 14, 14]           2,176
       BatchNorm2d-2         [100, 128, 14, 14]             256
              ReLU-3         [100, 128, 14, 14]               0
            Conv2d-4           [100, 256, 7, 7]         524,544
       BatchNorm2d-5           [100, 256, 7, 7]             512
              ReLU-6           [100, 256, 7, 7]               0
            Conv2d-7           [100, 512, 3, 3]       2,097,664
       BatchNorm2d-8           [100, 512, 3, 3]           1,024
              ReLU-9           [100, 512, 3, 3]               0
           Conv2d-10          [100, 1024, 1, 1]       8,389,632
      BatchNorm2d-11          [100, 1024, 1, 1]           2,048
             ReLU-12          [100, 1024, 1, 1]               0
           Linear-13                   [100, 2]           2,050
           Linear-14                   

## Decoder
Similar to standard decoder, but take a data from distribution defined by encoder.
Actually, we don't need to make non changes. 

In [47]:
def conv_transpose_block(
        in_channels,
        out_channels,
        kernel_size=3,
        stride=2,
        padding=1,
        output_padding=0,
        with_act=True
):
    modules = [
        nn.ConvTranspose2d(
            in_channels,
            out_channels,
            kernel_size=kernel_size,
            stride=stride,
            padding=padding,
            output_padding=output_padding
            )   
    ]

    if with_act:
        modules.append(nn.BatchNorm2d(out_channels))
        modules.append(nn.ReLU())

    return nn.Sequential(*modules)


class Decoder(nn.Module):
    def __init__(self,out_channels:int , latent_dims:int):
        super().__init__()

        self.lineal = nn.Linear(latent_dims,1024*4*4)
        self.t_conv = nn.Sequential(
            conv_transpose_block(1024,512),
            conv_transpose_block(512,256,output_padding=1),
            conv_transpose_block(256,out_channels,output_padding=1,with_act=False)
        )

    def forward(self, x:torch.Tensor):
        bs = x.shape[0]

        x = self.lineal(x)
        x = x.reshape((bs,1024,4,4))
        x = self.t_conv(x)
        
        return x

## Variational Autoencoder

In [48]:
class VAE_AutoEncoder(nn.Module):
    def __init__(self, in_channels, latent_dims:int):
        super().__init__()
        self.encoder = Encoder(in_channels,latent_dims)
        self.decoder = Decoder(in_channels,latent_dims)

    def encode(self,x):
        return self.encoder(x)

    def decode(self,z):
        return self.decoder(z)


    def forward(self,x):
        mu,logvar = self.encode(x)
        
        #Take a sample from distribution
        std = torch.exp(0.5*logvar) # Compute standard desviation
        z = self.sample(mu,std)
        
        reconstructed = self.decode(z)

        # Return reconstructed image and, mean and logvar to calculate losses
        return reconstructed, mu, logvar



    def sample(self,mu,std):
        standard_sample = torch.randn_like(std)

        return mu + (standard_sample * std) # Take a sample of standarized distribution and transforming it to mean and desviation given
                                            # This method is used to simplify gradient calculus

In [51]:
VAE_model = VAE_AutoEncoder(in_channels=1,latent_dims=2).cuda()
summary(VAE_model,(1,28,28))

----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
            Conv2d-1          [-1, 128, 14, 14]           2,176
       BatchNorm2d-2          [-1, 128, 14, 14]             256
              ReLU-3          [-1, 128, 14, 14]               0
            Conv2d-4            [-1, 256, 7, 7]         524,544
       BatchNorm2d-5            [-1, 256, 7, 7]             512
              ReLU-6            [-1, 256, 7, 7]               0
            Conv2d-7            [-1, 512, 3, 3]       2,097,664
       BatchNorm2d-8            [-1, 512, 3, 3]           1,024
              ReLU-9            [-1, 512, 3, 3]               0
           Conv2d-10           [-1, 1024, 1, 1]       8,389,632
      BatchNorm2d-11           [-1, 1024, 1, 1]           2,048
             ReLU-12           [-1, 1024, 1, 1]               0
           Linear-13                    [-1, 2]           2,050
           Linear-14                   

## Training VAE
We create a new loss from combitanion of two losses:
* The old one, from recreated images
* A new from logvar: *Kullback–Leibler divergence* (KLD) is a functions that compares a distribution from another one, in this case, the standard

This combination is a sum of boths

In some cases, we can put more weigth on one of the lossesx, to balance the reconstruction of gaussian distribution.


In [69]:
import torch.nn.functional as F
def vae_loss(batch, reconstructed_image, mu, logvar):
    bs = batch.shape[0]


    reconstruction_loss = F.mse_loss(
        reconstructed_image.reshape(bs, -1),
        batch.reshape(bs,-1),
        reduction="none"
    ).sum(dim=-1)

    KL_loss = -0.5* torch.sum(1 + logvar - mu.pow(2) - logvar.exp(), dim=-1)

    loss = (reconstruction_loss + KL_loss).mean(dim=0)

    return (loss, reconstruction_loss, KL_loss)


In [80]:
def train_loop(loader:torch.utils.data.DataLoader, model : nn.Module, optimizer:torch.optim, loss_fn:callable):
    nlotes = len(loader)
    # train_size = len(loader.dataset)

    model.train()   #Preparo el modelo para el entrenamiento

    losses = {
        "loss": [],
        "reconstruction_loss": [],
        "KL_loss": []
    }

    train_losses = 0

    # losses_list = []

    for nlote,(x,_) in enumerate(loader):
        x = x.cuda()

        # Forward Pass
        reconstructed ,mu ,logvar = model(x)

        # Backpropagation
        loss,reconstruction_loss,KL_loss = loss_fn(x, reconstructed, mu, logvar)    #Calculo de loss
        
        loss.backward()             #Calculo de gradiente
        
        # Save Losses
        losses["loss"].append(
            loss.item())
        losses["reconstruction_loss"].append(
            reconstruction_loss.mean().item()
        )
        losses["KL_loss"].append(
            KL_loss.mean().item()
        )

        optimizer.step()            #Actualización de parámetros
        optimizer.zero_grad()       #Limpieza del optimizador

        #Guardamos algunas caractgerísticas para plotear al final
        train_losses += loss.item()

        #Muestra del proceso
        if nlote % 10 == 0:
            print("Nº de lote:\t",nlote)
            print("Loss:\t\t\t",loss.item())
            print("Reconstruction_loss:\t",reconstruction_loss.mean().item())
            print("KL_loss:\t\t",KL_loss.mean().item())
            print()

    train_losses /= nlotes
    print()
    print("\tAccuracy/Loss Promedio")
    # print(f"\t\tEntrenamiento: {(100*train_accuracy):>0.1f}% / {train_losses:>8f}")
    print(f"\t\tEntrenamiento: {train_losses:>8f}")

    return losses

In [110]:
def val_loop(loader:torch.utils.data.DataLoader, model:nn.Module, loss_fn:callable):
    
    # val_size = len(loader.dataset)
    nlotes = len(loader)

    model.eval()         #Preparo el modelo para inferencia

    val_losses = 0
    losses_list = []

    with torch.no_grad():       #Calcelo el calculo del gradiente
        for x,_ in loader:
            
            x = x.cuda()
            with torch.no_grad():
                recosntructed, mu, logvar = model(x)       # Inferencia
            
            loss,_,_ = loss_fn(x, recosntructed, mu, logvar)

            val_losses += loss.item()
            losses_list.append(loss.item())

    val_losses /= nlotes

    print(f"\t\t Validación: {val_losses:>8f}")

    return losses_list


In [63]:
epochs = 10
lr = 0.0001

optimizer = torch.optim.AdamW(VAE_model.parameters(),lr=lr,eps=1e-5)

In [112]:
training_losses = {
        "loss": [],
        "reconstruction_loss": [],
        "KL_loss": []
    }
val_losses = []


for epoch in range(epochs):
    print(f"Itenración: {(epoch + 1)} / {epochs} -----------------------------")
    
    #Train
    output_output = train_loop(train_loader,VAE_model,optimizer,vae_loss)
    training_losses["loss"] += output_output["loss"]
    
    #Validation
    val_output = val_loop(train_loader,VAE_model,vae_loss)
    val_losses += val_output

print("Finalizado entrenamiento del modelo!")
# output_losses = train_loop(train_loader,VAE_model,optimizer,vae_loss)

Itenración: 1 / 10 -----------------------------
Nº de lote:	 0
Loss:			 40.38923645019531
Reconstruction_loss:	 35.760196685791016
KL_loss:		 4.62903356552124

Nº de lote:	 10
Loss:			 42.82951736450195
Reconstruction_loss:	 38.284149169921875
KL_loss:		 4.5453715324401855

Nº de lote:	 20
Loss:			 40.888427734375
Reconstruction_loss:	 36.21786117553711
KL_loss:		 4.670566558837891

Nº de lote:	 30
Loss:			 40.95425796508789
Reconstruction_loss:	 36.30746841430664
KL_loss:		 4.646793365478516

Nº de lote:	 40
Loss:			 41.938846588134766
Reconstruction_loss:	 37.24623489379883
KL_loss:		 4.6926093101501465

Nº de lote:	 50
Loss:			 41.160621643066406
Reconstruction_loss:	 36.57461166381836
KL_loss:		 4.5860114097595215

Nº de lote:	 60
Loss:			 42.45769500732422
Reconstruction_loss:	 37.78338623046875
KL_loss:		 4.6743059158325195

Nº de lote:	 70
Loss:			 42.91216278076172
Reconstruction_loss:	 38.275794982910156
KL_loss:		 4.636368274688721

Nº de lote:	 80
Loss:			 41.30881500244140

KeyboardInterrupt: 

In [105]:
train_losses["loss"] += output_losses["loss"]
len(train_losses["loss"])

960