In [None]:
import os
from typing import List, Optional

import torch
import torchvision
from torch import nn
from torch import optim
import torch.nn.functional as F
from torch.utils.data import DataLoader
from torchvision import transforms
from torchvision.utils import save_image
from torchvision.datasets import MNIST
import matplotlib.pyplot as plt

In [None]:
# Set the hyperparameters
num_epochs = 100  # number of passes over the training set in the training loop
batch_size = 128  # number of samples per mini-batch handed to the DataLoader
learning_rate = 1e-3  # optimizer learning rate

In [None]:
# Image transformations
def to_img(x):
    """Reshape a batch of flattened images into image tensors.

    Values are first clamped to the [0, 1] range, then the batch is viewed
    as ``(N, 1, 28, 28)`` where ``N`` is the size of the first dimension of
    ``x``. The input tensor itself is not modified.
    """
    n_images = x.shape[0]
    return torch.clamp(x, 0, 1).view(n_images, 1, 28, 28)

# Only convert the images to tensors. No Normalize step is applied, so the
# pixel values stay in the range that to_img clamps to and that the sigmoid
# at the end of the VAE decoder can reproduce.
img_transform = transforms.Compose([transforms.ToTensor()])

In [None]:
# Download (if not already present) and prepare the MNIST training split,
# then wrap it in a shuffling mini-batch loader
dataset = MNIST(
    root='./data',
    train=True,
    transform=img_transform,
    download=True,
)
dataloader = DataLoader(
    dataset=dataset,
    batch_size=batch_size,
    shuffle=True,
)

Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz to ./data/MNIST/raw/train-images-idx3-ubyte.gz


100%|██████████| 9912422/9912422 [00:00<00:00, 82808265.04it/s]


Extracting ./data/MNIST/raw/train-images-idx3-ubyte.gz to ./data/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz to ./data/MNIST/raw/train-labels-idx1-ubyte.gz


100%|██████████| 28881/28881 [00:00<00:00, 22757034.35it/s]


Extracting ./data/MNIST/raw/train-labels-idx1-ubyte.gz to ./data/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz to ./data/MNIST/raw/t10k-images-idx3-ubyte.gz


100%|██████████| 1648877/1648877 [00:00<00:00, 18859807.46it/s]


Extracting ./data/MNIST/raw/t10k-images-idx3-ubyte.gz to ./data/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz to ./data/MNIST/raw/t10k-labels-idx1-ubyte.gz


100%|██████████| 4542/4542 [00:00<00:00, 3696978.22it/s]

Extracting ./data/MNIST/raw/t10k-labels-idx1-ubyte.gz to ./data/MNIST/raw






In [None]:
# Create the folder that receives the reconstructed images. exist_ok=True
# replaces the check-then-create pair (which can race) and keeps the cell
# safe to re-run.
os.makedirs('./vae_img', exist_ok=True)

# Definiamo il modello

class VAE(nn.Module):
    """Fully connected variational autoencoder.

    The encoder maps a flattened input to the mean and log-variance of the
    latent distribution; the decoder maps a latent sample back to the input
    space and squashes it with a sigmoid.

    Args:
        input_dim: size of each flattened input vector (default 784).
        hidden_dim: width of the hidden layer in encoder and decoder
            (default 400).
        latent_dim: size of the latent space (default 20).
    """

    def __init__(self, input_dim=784, hidden_dim=400, latent_dim=20):
        super().__init__()

        self.fc1 = nn.Linear(input_dim, hidden_dim)
        self.fc21 = nn.Linear(hidden_dim, latent_dim)  # mean head
        self.fc22 = nn.Linear(hidden_dim, latent_dim)  # log-variance head
        self.fc3 = nn.Linear(latent_dim, hidden_dim)
        self.fc4 = nn.Linear(hidden_dim, input_dim)

    def encode(self, x):
        """Return ``(mu, logvar)`` of the latent distribution for ``x``."""
        h1 = F.relu(self.fc1(x))
        return self.fc21(h1), self.fc22(h1)

    def reparametrize(self, mu, logvar):
        """Draw ``z = mu + eps * std`` with ``eps ~ N(0, I)``.

        This is the VAE reparameterization trick: the randomness is isolated
        in ``eps`` so gradients can flow through ``mu`` and ``logvar``.
        """
        std = torch.exp(0.5 * logvar)
        # randn_like creates the noise on the same device and with the same
        # dtype as std, so it works wherever the model actually lives instead
        # of relying on a global "is CUDA available" check.
        eps = torch.randn_like(std)
        return mu + eps * std

    def decode(self, z):
        """Map a latent sample ``z`` back to the input space, in (0, 1)."""
        h3 = F.relu(self.fc3(z))
        return torch.sigmoid(self.fc4(h3))

    def forward(self, x):
        """Return ``(reconstruction, mu, logvar)`` for the batch ``x``."""
        # Learn mean and log-variance through the encoder
        mu, logvar = self.encode(x)
        # Reparameterization trick
        z = self.reparametrize(mu, logvar)
        # Run the decoder to reconstruct the input
        return self.decode(z), mu, logvar

In [None]:
# Create a model instance and move it to the GPU when CUDA is available
model = VAE()
model = model.cuda() if torch.cuda.is_available() else model

# Reconstruction loss: squared error summed over every element of the batch.
# reduction='sum' is the current spelling of the deprecated size_average=False.
reconstruction_function = nn.MSELoss(reduction='sum')

# Combine the reconstruction loss with the KL divergence to build the final
# VAE loss
def loss_function(recon_x, x, mu, logvar):
    """Return the VAE loss: summed reconstruction error plus KL divergence.

    Args:
        recon_x: reconstructed input
        x: original input
        mu: latent mean
        logvar: latent log-variance

    Returns:
        A scalar tensor, summed (not averaged) over the batch.
    """
    recon_loss = reconstruction_function(recon_x, x)  # summed MSE
    # KL divergence: -0.5 * sum(1 + log(sigma^2) - mu^2 - sigma^2)
    KLD = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())
    return recon_loss + KLD

# Define the optimizer. The step size is read from the `learning_rate`
# hyperparameter defined with the other settings at the top of the notebook
# instead of repeating the literal here.
optimizer = optim.Adam(model.parameters(), lr=learning_rate)



In [None]:
# Train the VAE
# Send each batch to the device the model actually lives on, rather than
# re-checking CUDA availability, so model and data can never end up apart.
device = next(model.parameters()).device
for epoch in range(num_epochs):
    model.train()
    train_loss = 0
    for batch_idx, data in enumerate(dataloader):
        img, _ = data  # labels are not used
        # Flatten every image into a vector for the fully connected encoder
        img = img.view(img.size(0), -1).to(device)
        optimizer.zero_grad()
        recon_batch, mu, logvar = model(img)
        loss = loss_function(recon_batch, img, mu, logvar)
        loss.backward()
        train_loss += loss.item()
        optimizer.step()
        if batch_idx % 100 == 0:
            print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
                epoch,
                batch_idx * len(img),
                len(dataloader.dataset), 100. * batch_idx / len(dataloader),
                loss.item() / len(img)))

    print('====> Epoch: {} Average loss: {:.4f}'.format(
        epoch, train_loss / len(dataloader.dataset)))
    if epoch % 10 == 0:
        # Save the reconstructions of the last batch of this epoch; detach()
        # replaces the legacy .data access
        save = to_img(recon_batch.detach().cpu())
        save_image(save, './vae_img/image_{}.png'.format(epoch))

# Persist the trained weights
torch.save(model.state_dict(), './vae.pth')

====> Epoch: 0 Average loss: 45.2655
====> Epoch: 1 Average loss: 35.0659
====> Epoch: 2 Average loss: 33.1507
====> Epoch: 3 Average loss: 32.2515
====> Epoch: 4 Average loss: 31.7557
====> Epoch: 5 Average loss: 31.3470
====> Epoch: 6 Average loss: 31.0969
====> Epoch: 7 Average loss: 30.8976
====> Epoch: 8 Average loss: 30.7506
====> Epoch: 9 Average loss: 30.6050
====> Epoch: 10 Average loss: 30.4532
====> Epoch: 11 Average loss: 30.3851
====> Epoch: 12 Average loss: 30.2688
====> Epoch: 13 Average loss: 30.1726
====> Epoch: 14 Average loss: 30.0947
====> Epoch: 15 Average loss: 30.0325
====> Epoch: 16 Average loss: 29.9425
====> Epoch: 17 Average loss: 29.8967
====> Epoch: 18 Average loss: 29.8212
====> Epoch: 19 Average loss: 29.7776
====> Epoch: 20 Average loss: 29.7143
====> Epoch: 21 Average loss: 29.6553
====> Epoch: 22 Average loss: 29.5846
====> Epoch: 23 Average loss: 29.5683
====> Epoch: 24 Average loss: 29.5171
====> Epoch: 25 Average loss: 29.4433
====> Epoch: 26 Averag

Il VAE che abbiamo utilizzato in questo notebook è molto semplice. È costruito con layer fully connected, mentre per lavorare con le immagini, come abbiamo visto, è molto più pratico utilizzare reti convoluzionali. Per esercizi più complessi, ti consiglio quindi di costruire VAE con reti convoluzionali, come ad esempio quello qui sotto.

In [None]:
class ConvVAE(BaseVAE):
    """Convolutional variational autoencoder.

    The encoder is a stack of stride-2 convolutions followed by two linear
    heads that produce the mean and log-variance of the latent distribution.
    The decoder mirrors it with transposed convolutions and ends in a Tanh.

    NOTE(review): ``BaseVAE`` is not defined in this notebook; it has to be
    provided (presumably an ``nn.Module`` subclass) before this cell can run.
    """

    def __init__(self,
                 in_channels: int,
                 latent_dim: int,
                 hidden_dims: Optional[List[int]] = None,
                 **kwargs) -> None:
        """Build encoder, latent heads and decoder.

        Args:
            in_channels: number of channels of the input images.
            latent_dim: size of the latent space.
            hidden_dims: output channels of each encoder stage, in order.
                Defaults to ``[32, 64, 128, 256, 512]``. The list passed in
                is not modified.
            **kwargs: accepted for interface compatibility and ignored.
        """
        super().__init__()

        self.latent_dim = latent_dim
        # Work on a private copy so the caller's list is never mutated
        if hidden_dims is None:
            hidden_dims = [32, 64, 128, 256, 512]
        else:
            hidden_dims = list(hidden_dims)
        # Channels produced by the last encoder stage; decode() needs this
        # to un-flatten the latent projection.
        self.last_hidden_dim = hidden_dims[-1]

        # Build the encoder: one stride-2 convolution block per entry
        modules = []
        for h_dim in hidden_dims:
            modules.append(
                nn.Sequential(
                    nn.Conv2d(in_channels, out_channels=h_dim,
                              kernel_size=3, stride=2, padding=1),
                    nn.BatchNorm2d(h_dim),
                    nn.LeakyReLU())
            )
            in_channels = h_dim

        self.encoder = nn.Sequential(*modules)
        # The factor 4 assumes the encoder output is 2x2 spatially (e.g.
        # 64x64 inputs with the default five stages) - confirm against the
        # data actually fed to the model.
        flat_dim = self.last_hidden_dim * 4
        self.fc_mu = nn.Linear(flat_dim, latent_dim)
        self.fc_var = nn.Linear(flat_dim, latent_dim)

        # Build the decoder, walking the channel sizes in reverse order
        self.decoder_input = nn.Linear(latent_dim, flat_dim)
        decoder_dims = hidden_dims[::-1]

        modules = []
        for c_in, c_out in zip(decoder_dims[:-1], decoder_dims[1:]):
            modules.append(
                nn.Sequential(
                    nn.ConvTranspose2d(c_in,
                                       c_out,
                                       kernel_size=3,
                                       stride=2,
                                       padding=1,
                                       output_padding=1),
                    nn.BatchNorm2d(c_out),
                    nn.LeakyReLU())
            )
        self.decoder = nn.Sequential(*modules)

        # NOTE(review): the output always has 3 channels, independently of
        # the `in_channels` argument - confirm this is intended for
        # single-channel data such as MNIST.
        self.final_layer = nn.Sequential(
            nn.ConvTranspose2d(decoder_dims[-1],
                               decoder_dims[-1],
                               kernel_size=3,
                               stride=2,
                               padding=1,
                               output_padding=1),
            nn.BatchNorm2d(decoder_dims[-1]),
            nn.LeakyReLU(),
            nn.Conv2d(decoder_dims[-1], out_channels=3,
                      kernel_size=3, padding=1),
            nn.Tanh())

    def encode(self, input):
        """Encode ``input`` and return ``[mu, log_var]``."""
        result = self.encoder(input)
        result = torch.flatten(result, start_dim=1)

        mu = self.fc_mu(result)
        log_var = self.fc_var(result)

        return [mu, log_var]

    def decode(self, z):
        """Map latent codes ``z`` back to image space."""
        result = self.decoder_input(z)
        # Un-flatten using the real encoder width instead of a hard-coded 512
        result = result.view(-1, self.last_hidden_dim, 2, 2)
        result = self.decoder(result)
        result = self.final_layer(result)
        return result

    def reparameterize(self, mu, logvar):
        """Sample ``z = mu + eps * std`` with ``eps ~ N(0, I)``."""
        std = torch.exp(0.5 * logvar)
        eps = torch.randn_like(std)
        return eps * std + mu

    def forward(self, input, **kwargs):
        """Return ``[reconstruction, input, mu, log_var]``."""
        mu, log_var = self.encode(input)
        z = self.reparameterize(mu, log_var)
        return [self.decode(z), input, mu, log_var]