In [None]:
import os.path
import numpy as np
import torch
import torch.nn as nn
from torch import nn
from torchvision.datasets import MNIST
from torchvision import transforms
from torch.utils.data import DataLoader
import torch.nn.functional as F
import matplotlib.pyplot as plt
import datasets

# Select the compute device once: the GPU when CUDA is usable, the CPU otherwise.
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'
#datasets.mkdir_p('img_vae')

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# Experiment hyper-parameters.
# NOTE(review): none of these names is read by the cells visible in this
# notebook (the DataLoaders hard-code batch_size=128 and the model classes take
# their sizes as constructor arguments) — confirm which ones are still in use.
num_latent = 50  # presumably the latent dimension for Encoder/Decoder — TODO confirm
num_bins = 51  # purpose not visible in this notebook — TODO confirm or remove
num_neurons = [1000, 1000]  # presumably hidden-layer widths of an earlier MLP model — TODO confirm
batch_size = 50  # NOTE(review): differs from the 128 used by the DataLoaders
num_epochs = 100  # presumably the number of training epochs — TODO confirm
learning_rate = 1e-4  # presumably the optimizer step size — TODO confirm
load_models = False  # presumably toggles loading saved weights instead of training — TODO confirm

In [6]:
# Convert PIL images to float tensors with pixel values scaled to [0, 1].
#
# No mean/std normalisation is applied on purpose: the decoder ends in a
# Sigmoid and VAE.elbo uses binary cross-entropy, which both require targets in
# [0, 1]. Normalize((0.1307,), (0.3081,)) would map pixels to roughly
# [-0.42, 2.82], a range the decoder can never reproduce and for which the BCE
# term is not a valid likelihood.
transform = transforms.Compose([
    transforms.ToTensor(),
])

In [7]:
# MNIST train/test splits (fetched into ./data when missing), each wrapped in a
# DataLoader that yields mini-batches of 128 images. Only the training loader
# shuffles.
# NOTE(review): the batch size is hard-coded to 128 here; the `batch_size`
# constant from the configuration cell is not used — confirm which is intended.
train_dataset = MNIST(
    root='./data',
    train=True,
    download=True,
    transform=transform,
)
train_loader = DataLoader(train_dataset, shuffle=True, batch_size=128)

test_dataset = MNIST(
    root='./data',
    train=False,
    download=True,
    transform=transform,
)
test_loader = DataLoader(test_dataset, shuffle=False, batch_size=128)

100%|██████████| 9.91M/9.91M [01:44<00:00, 95.2kB/s]
100%|██████████| 28.9k/28.9k [00:00<00:00, 146kB/s]
100%|██████████| 1.65M/1.65M [00:20<00:00, 79.2kB/s]
100%|██████████| 4.54k/4.54k [00:00<?, ?B/s]


In [9]:
# Sanity check: number of samples per split, then number of batches per loader,
# one value per line.
print(
    len(train_dataset),
    len(test_dataset),
    len(train_loader),
    len(test_loader),
    sep='\n',
)

60000
10000
469
79


In [None]:
# Hold out the last 10,000 training rows (and their labels) as a validation
# set; the remaining rows stay in the training set.
# NOTE(review): `train_x` / `train_labels` are not defined by any cell visible
# above, so this cell raises NameError on a fresh kernel — confirm where they
# are supposed to come from. The cell is also not idempotent: running it twice
# removes another 10,000 rows from the training set.
valid_x, train_x = train_x[-10000:, :], train_x[:-10000, :]
valid_labels, train_labels = train_labels[-10000:], train_labels[:-10000]

In [None]:
# Turn the three splits into torch tensors living on `device`.
# NOTE(review): the validation tensor is bound to `val_x`, while the array it is
# built from is called `valid_x` — confirm the intended name.
train_x, test_x, val_x = (
    torch.tensor(split).to(device) for split in (train_x, test_x, valid_x)
)

# The training tensor's shape unpacks into (number of rows, values per row).
train_N, train_D = train_x.shape

In [14]:
# Read the dataset size and the per-image element count straight from the
# MNIST training set (the label of the first sample is discarded).
first_img, _ = train_dataset[0]
train_N = len(train_dataset)
train_D = first_img.numel()

print(
    f"number of example: {train_N}",
    f"dimension of flat image: {train_D}",
    sep="\n",
)

number of example: 60000
dimension of flat image: 784


In [16]:
class Encoder(nn.Module):
    """Convolutional encoder producing the parameters of q(z|x).

    The network is a stack of stride-2 convolution blocks
    (Conv2d -> BatchNorm2d -> ELU) followed by two linear heads that output the
    mean and the log-variance of the latent Gaussian.
    """

    def __init__(self, n_latent: int, in_channels: int = 1, n_conv_blocks: int = 2, base_filters: int = 32):
        """
        Args:
            n_latent (int): size of the latent vector z.
            in_channels (int): number of image channels (1 for MNIST).
            n_conv_blocks (int): number of convolutional blocks.
            base_filters (int): filters in the first block (doubled at every
                following block).
        """
        super().__init__()

        # Kept as attributes so that other modules (e.g. a VAE drawing
        # z ~ N(0, I)) can query the latent size and the expected input depth.
        self.n_latent = n_latent
        self.in_channels = in_channels

        layers = []
        block_in = in_channels
        filters = base_filters

        # Each block halves the spatial size (stride 2) and doubles the filters.
        for _ in range(n_conv_blocks):
            layers.append(nn.Conv2d(block_in, filters, kernel_size=3, stride=2, padding=1))
            layers.append(nn.BatchNorm2d(filters))
            layers.append(nn.ELU(inplace=True))
            block_in = filters
            filters *= 2

        self.conv = nn.Sequential(*layers)

        # Probe the conv stack with a dummy 28x28 image to learn the flattened
        # feature size. The probe runs in eval mode so that it neither updates
        # the BatchNorm running statistics nor trips BatchNorm's
        # "more than 1 value per channel" check on 1x1 feature maps; it uses the
        # caller's `in_channels` rather than assuming a single channel.
        self.conv.eval()
        with torch.no_grad():
            probe = torch.zeros(1, in_channels, 28, 28)
            self.flattened_size = int(self.conv(probe).flatten(start_dim=1).shape[1])
        self.conv.train()

        # Linear heads mapping the features to the mean and log-variance of z.
        self.fc_mu = nn.Linear(self.flattened_size, n_latent)
        self.fc_logvar = nn.Linear(self.flattened_size, n_latent)

    def forward(self, x):
        """Encode a batch of images.

        Args:
            x: image batch; the constructor sizes the linear heads for inputs of
                shape (batch, in_channels, 28, 28).

        Returns:
            Tuple ``(mu, logvar)``, each of shape (batch, n_latent).
        """
        features = self.conv(x)
        features = features.flatten(start_dim=1)
        mu = self.fc_mu(features)
        logvar = self.fc_logvar(features)
        return mu, logvar

In [17]:
class Decoder(nn.Module):
    """Transposed-convolution decoder mapping a latent vector to a 28x28 image.

    A linear layer projects z onto a small feature map, which is then upsampled
    by stride-2 ConvTranspose2d blocks (ConvTranspose2d -> BatchNorm2d -> ELU)
    and a final ConvTranspose2d + Sigmoid producing values in [0, 1].
    """

    def __init__(self, n_latent: int, output_channels: int = 1, n_conv_blocks: int = 2, base_filters: int = 32):
        """
        Args:
            n_latent (int): size of the latent vector z.
            output_channels (int): channels of the output image (1 for MNIST).
            n_conv_blocks (int): number of transposed-convolution blocks (>= 1).
            base_filters (int): filters of the last hidden block; the first
                block starts at ``base_filters * 2 ** (n_conv_blocks - 1)`` and
                the count is halved at every step.

        Raises:
            ValueError: if ``n_conv_blocks`` is smaller than 1.
        """
        super().__init__()

        if n_conv_blocks < 1:
            raise ValueError(f"n_conv_blocks must be >= 1, got {n_conv_blocks}")

        self.n_latent = n_latent
        self.n_conv_blocks = n_conv_blocks
        self.base_filters = base_filters
        self.output_channels = output_channels

        # Spatial sizes a mirror-image encoder goes through: a convolution with
        # kernel 3, stride 2 and padding 1 maps s -> (s - 1) // 2 + 1, i.e.
        # 28 -> 14 -> 7 -> 4 -> 2 -> 1. Plain `28 // 2 ** n` only agrees with
        # this for n <= 2 and would otherwise yield a non-28x28 reconstruction.
        sizes = [28]
        for _ in range(n_conv_blocks):
            sizes.append((sizes[-1] - 1) // 2 + 1)

        self.init_spatial_dim = sizes[-1]
        self.init_filters = base_filters * (2 ** (n_conv_blocks - 1))
        self.projected_dim = self.init_filters * self.init_spatial_dim * self.init_spatial_dim

        # Linear projection from the latent vector to the initial feature map.
        self.fc = nn.Linear(n_latent, self.projected_dim)

        # Sizes to reach after each upsampling step, coarsest first, ending at 28.
        targets = sizes[-2::-1]

        layers = []
        filters = self.init_filters
        current = self.init_spatial_dim
        for step, target in enumerate(targets):
            # With kernel 3, stride 2, padding 1 a transposed convolution yields
            # 2 * current - 1 + output_padding, so the padding below (0 or 1)
            # makes every step land exactly on the mirrored encoder size.
            out_pad = target - (2 * current - 1)
            is_last = step == len(targets) - 1
            if is_last:
                # Final block: map to the image channels, squashed to [0, 1].
                layers.append(nn.ConvTranspose2d(filters, output_channels, kernel_size=3, stride=2, padding=1, output_padding=out_pad))
                layers.append(nn.Sigmoid())
            else:
                layers.append(nn.ConvTranspose2d(filters, filters // 2, kernel_size=3, stride=2, padding=1, output_padding=out_pad))
                layers.append(nn.BatchNorm2d(filters // 2))
                layers.append(nn.ELU(inplace=True))
                filters //= 2
            current = target

        self.deconv = nn.Sequential(*layers)

    def forward(self, z):
        """Decode latent vectors into images.

        Args:
            z: tensor of shape (batch, n_latent).

        Returns:
            Tensor of shape (batch, output_channels, 28, 28) with values in [0, 1].
        """
        x = self.fc(z)
        x = x.view(x.size(0), self.init_filters, self.init_spatial_dim, self.init_spatial_dim)
        x = self.deconv(x)
        return x

In [1]:
class VAE(nn.Module):
    """Variational auto-encoder wrapping an encoder and a decoder module.

    The encoder must return ``(mu, logvar)`` for a batch of inputs and the
    decoder must map latent vectors back to input space.
    """

    def __init__(self, encoder: nn.Module, decoder: nn.Module):
        """Store both sub-modules and move them to CUDA when available, else CPU."""
        super().__init__()
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        self.encoder = encoder.to(self.device)
        self.decoder = decoder.to(self.device)

    def _latent_dim(self) -> int:
        """Return the latent size, from ``encoder.n_latent`` or its ``fc_mu`` head."""
        n_latent = getattr(self.encoder, 'n_latent', None)
        if n_latent is None:
            # Encoders that do not expose `n_latent` still reveal the latent
            # size through the output width of their mean head.
            n_latent = self.encoder.fc_mu.out_features
        return int(n_latent)

    def _current_device(self) -> torch.device:
        """Return the device the parameters live on right now.

        ``self.device`` is fixed at construction time and becomes stale when the
        model is moved with ``.to(...)``; it is only used as a fallback for
        models without parameters.
        """
        try:
            return next(self.parameters()).device
        except StopIteration:
            return self.device

    def forward(self, x):
        """
        Full VAE pass: encode -> reparameterize -> decode.

        Returns:
            Tuple ``(x_recon, mu, logvar)``.
        """
        mu, logvar = self.encoder(x)
        z = self.reparameterize(mu, logvar)
        x_recon = self.decoder(z)
        return x_recon, mu, logvar

    def reparameterize(self, mu, logvar):
        """
        Reparameterization trick: z = mu + sigma * epsilon, with
        sigma = exp(0.5 * logvar) and epsilon ~ N(0, I).
        """
        std = torch.exp(0.5 * logvar)
        eps = torch.randn_like(std)
        return mu + eps * std

    def encode(self, x):
        """
        Return the encoder output ``(mu, logvar)`` for ``x``.
        """
        return self.encoder(x)

    def decode(self, z):
        """
        Reconstruct an image batch from the latent vectors ``z``.
        """
        return self.decoder(z)

    def reconstruct(self, x):
        """
        Encode ``x``, draw one latent sample and decode it
        (encode -> reparameterize -> decode).
        """
        mu, logvar = self.encode(x)
        z = self.reparameterize(mu, logvar)
        return self.decode(z)

    def sample(self, n_samples: int):
        """
        Draw ``n_samples`` latent vectors z ~ N(0, I) and decode them.

        Runs without gradient tracking. The latent size is taken from the
        encoder (``n_latent`` attribute, or the width of its ``fc_mu`` head).
        """
        with torch.no_grad():
            z = torch.randn(n_samples, self._latent_dim(), device=self._current_device())
            samples = self.decode(z)
        return samples

    def elbo(self, x, beta=1.0):
        """
        Compute the training loss, i.e. the negative (beta-weighted) ELBO.

        Args:
            x: input batch; because the reconstruction term is a binary
                cross-entropy, values are expected to lie in [0, 1].
            beta: weight of the KL term (1.0 gives the standard VAE objective).

        Returns:
            Tuple ``(loss, recon_loss, kl_div)`` where
            ``loss = recon_loss + beta * kl_div``; all three are summed over the
            batch and are quantities to minimise.
        """
        x_recon, mu, logvar = self(x)

        # Reconstruction term: binary cross-entropy summed over pixels and batch.
        recon_loss = nn.functional.binary_cross_entropy(
            x_recon, x, reduction='sum'
        )

        # KL divergence between N(mu, sigma^2) and N(0, 1), in closed form.
        kl_div = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())

        elbo = recon_loss + beta * kl_div
        return elbo, recon_loss, kl_div

NameError: name 'nn' is not defined