In [58]:
import os.path
import numpy as np
import torch
import torch.nn as nn
from torch import nn
from torch.nn import Linear
from torch.nn import Softplus
from torchvision.datasets import MNIST
from torchvision import transforms
from torch.utils.data import DataLoader
import torch.nn.functional as F
import matplotlib.pyplot as plt
import datasets

device = 'cuda' if torch.cuda.is_available() else 'cpu'
#datasets.mkdir_p('img_vae')

In [59]:
num_latent = 50
num_bins = 51
num_neurons = [1000, 1000]
batch_size = 50
num_epochs = 100
learning_rate = 1e-4
load_models = False

In [60]:
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.1307,), (0.3081,))
])

In [61]:
train_dataset = MNIST(root='./data', train=True, download=True, transform=transform)
test_dataset = MNIST(root='./data', train=False, download=True, transform=transform)

train_loader = DataLoader(train_dataset, batch_size=128, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=128, shuffle=False)

In [62]:
print(len(train_dataset))
print(len(test_dataset))
print(len(train_loader))
print(len(test_loader))

60000
10000
469
79


In [63]:
train_N = len(train_dataset)

first_img, _ = train_dataset[0]
train_D = first_img.numel() 

print(f"number of example: {train_N}")
print(f"dimension of flat image: {train_D}")

number of example: 60000
dimension of flat image: 784


In [64]:
class Encoder(nn.Module):
    def __init__(self, n_latent: int, in_channels: int = 1, n_conv_blocks: int = 2, base_filters: int = 32):
        """
        Encoder che usa convoluzioni e restituisce media e logvar per la distribuzione latente q(z|x)
        
        Args:
            n_latent (int): dimensione del vettore latente Z
            in_channels (int): numero di canali dell'immagine (1 per MNIST)
            n_conv_blocks (int): numero di blocchi convoluzionali
            base_filters (int): numero iniziale di filtri (verrà raddoppiato a ogni blocco)
        """
        super().__init__()

        layers = []
        filters = base_filters

        # Costruzione dei blocchi convoluzionali
        for _ in range(n_conv_blocks):
            layers.append(nn.Conv2d(in_channels, filters, kernel_size=3, stride=2, padding=1))
            layers.append(nn.BatchNorm2d(filters))
            layers.append(nn.ELU(inplace=True))
            in_channels = filters
            filters *= 2

        self.conv = nn.Sequential(*layers)

        # Calcola dimensione dell'output convoluzionale
        with torch.no_grad():
            dummy_input = torch.zeros(1, 1, 28, 28)  # immagine finta MNIST
            conv_out = self.conv(dummy_input)
            self.flattened_size = conv_out.view(1, -1).shape[1]

        # Strato finale che mappa su media e logvar del vettore latente
        self.fc_mu = nn.Linear(self.flattened_size, n_latent)
        self.fc_logvar = nn.Linear(self.flattened_size, n_latent)

    def forward(self, x):
        x = self.conv(x)
        x = x.view(x.size(0), -1)  # flatten
        mu = self.fc_mu(x)
        logvar = self.fc_logvar(x)
        return mu, logvar

In [65]:
class Decoder(nn.Module):
    def __init__(self, n_latent: int, output_channels: int = 1, n_conv_blocks: int = 2, base_filters: int = 32):
        """
        Decoder che prende un vettore latente z e ricostruisce un'immagine MNIST (28x28).
        
        Args:
            n_latent (int): dimensione del vettore latente z
            output_channels (int): numero di canali dell'immagine in output (1 per MNIST)
            n_conv_blocks (int): numero di blocchi deconvoluzionali
            base_filters (int): numero iniziale di filtri (decrescente durante il decoding)
        """
        super().__init__()

        self.n_conv_blocks = n_conv_blocks
        self.base_filters = base_filters
        self.output_channels = output_channels

        # Calcolo della dimensione della feature map dopo l'encoder (dipende dalla struttura simmetrica)
        self.init_spatial_dim = 28 // (2 ** n_conv_blocks)  # per MNIST 28x28
        self.init_filters = base_filters * (2 ** (n_conv_blocks - 1))
        self.projected_dim = self.init_filters * self.init_spatial_dim * self.init_spatial_dim

        # Proiezione lineare dal vettore latente alla mappa iniziale
        self.fc = nn.Linear(n_latent, self.projected_dim)

        # Costruzione dei blocchi trasposti
        layers = []
        filters = self.init_filters
        for i in range(n_conv_blocks - 1):
            layers.append(nn.ConvTranspose2d(filters, filters // 2, kernel_size=3, stride=2, padding=1, output_padding=1))
            layers.append(nn.BatchNorm2d(filters // 2))
            layers.append(nn.ELU(inplace=True))
            filters //= 2

        # Ultimo blocco: output a 1 canale (immagine ricostruita)
        layers.append(nn.ConvTranspose2d(filters, output_channels, kernel_size=3, stride=2, padding=1, output_padding=1))
        layers.append(nn.Sigmoid())  # output in [0, 1]

        self.deconv = nn.Sequential(*layers)

    def forward(self, z):
        x = self.fc(z)
        x = x.view(x.size(0), self.init_filters, self.init_spatial_dim, self.init_spatial_dim)
        x = self.deconv(x)
        return x

In [66]:
class VAE(nn.Module):
    def __init__(self, encoder: nn.Module, decoder: nn.Module):
        super().__init__()
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        self.encoder = encoder.to(self.device)
        self.decoder = decoder.to(self.device)

    def forward(self, x):
        """
        VAE: encode → reparametrize → decode.
        """
        mu, logvar = self.encoder(x)
        z = self.reparameterize(mu, logvar)
        x_recon = self.decoder(z)
        return x_recon, mu, logvar

    def reparameterize(self, mu, logvar):
        """
         z = mu + sigma * epsilon
        """
        std = torch.exp(0.5 * logvar)
        eps = torch.randn_like(std)
        return mu + eps * std

    def encode(self, x):
        """
        return (mu, logvar)
        """
        return self.encoder(x)

    def decode(self, z):
        """
        Reconstruct the image starting from latent vector z
        
        """
        return self.decoder(z)

    def reconstruct(self, x):
        """
        encode → reparametrize → decode
        """
        mu, logvar = self.encode(x)
        z = self.reparameterize(mu, logvar)
        return self.decode(z)

    def sample(self, n_samples: int):
        """
        sample z ~ N(0,I) .
        """
        with torch.no_grad():
            z = torch.randn(n_samples, self.encoder.n_latent, device=self.device)
            samples = self.decode(z)
        return samples

    def elbo(self, x, beta=1.0):
        """
        Compute ELBO
        """
        x_recon, mu, logvar = self.forward(x)

        # Ricostruzione con MSE o BCE
        recon_loss = nn.functional.binary_cross_entropy(
            x_recon, x, reduction='sum'
        )

        # KL divergente tra N(mu, sigma^2) e N(0, 1)
        kl_div = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())

        elbo = recon_loss + beta * kl_div
        return elbo, recon_loss, kl_div

In [67]:
def kl_loss(mu_z, var_z):
    return torch.mean(0.5 * torch.sum(mu_z**2 + var_z**2 - torch.log(var_z**2) - 1, dim=1))

In [68]:
import torch
from torch import nn
from torch.nn import Linear, ELU, BatchNorm2d, Conv2d, ModuleList, Softplus
import torch.nn.functional as F

class GaussianEncoder(nn.Module):
    def __init__(self, n_latent: int, in_channels: int = 1, n_conv_blocks: int = 3, n_filters: int = 32, n_fc: int = 2048):
        super().__init__()
        self.device = 'cuda' if torch.cuda.is_available() else 'cpu'
        self.n_latent = n_latent

        # Build convolutional encoder
        layers = []
        for _ in range(n_conv_blocks):
            layers.append(Conv2d(in_channels, n_filters, kernel_size=3, stride=2, padding=1))
            layers.append(BatchNorm2d(n_filters))
            layers.append(ELU(inplace=True))
            in_channels = n_filters
            n_filters *= 2
        self.conv_head = ModuleList(layers)

        # Linear layers to latent space
        self.flattened_size = n_fc
        self.mu = Linear(n_fc, n_latent)
        self.var = Linear(n_fc, n_latent)
        self.var_act = Softplus()

    def forward(self, x):
        out = x
        for layer in self.conv_head:
            out = layer(out)
        out = out.flatten(start_dim=1)  # shape: [batch, flattened]
        mu = self.mu(out)
        var = self.var_act(self.var(out))
        z = self.reparameterize(mu, var)
        return z, mu, var

    def reparameterize(self, mu, var):
        eps = torch.randn_like(var)
        return mu + eps * torch.sqrt(var)


In [69]:
import torch
from torch import nn
from torch.nn import Linear, ELU, BatchNorm2d, ConvTranspose2d, Sequential
from torch.distributions import Normal
import math

class GaussianDecoder(nn.Module):
    def __init__(self, n_pixels: int, n_latent: int, n_deconv_blocks: int = 3, init_filters: int = 128, n_fc: int = 2048, var: float = 0.05):
        super().__init__()
        self.n_latent = n_latent
        self.var = var

        self.n_fc = n_fc
        self.init_filters = init_filters
        self.deconv_input_size = int(math.sqrt(n_fc // init_filters))

        # MLP to project latent space to feature map
        self.linear = Linear(n_latent, n_fc)

        # Decoder: series of ConvTranspose2d to reconstruct 28x28
        filters = init_filters
        layers = []

        for _ in range(n_deconv_blocks - 1):
            layers.append(ConvTranspose2d(filters, filters // 2, kernel_size=3, stride=2, padding=1, output_padding=1))
            layers.append(BatchNorm2d(filters // 2))
            layers.append(ELU(inplace=True))
            filters //= 2

        # Final layer: output one channel for MNIST
        layers.append(ConvTranspose2d(filters, 1, kernel_size=3, stride=2, padding=1, output_padding=1))
        self.decoder = Sequential(*layers)

    def forward(self, z):
        out = self.linear(z)
        out = out.view(z.size(0), self.init_filters, self.deconv_input_size, self.deconv_input_size)
        out = self.decoder(out)
        return out  # mu

    def output(self, deconvoluted):
        return deconvoluted  # Already mu

    def reconstruct(self, z):
        mu = self.forward(z)
        x_hat = Normal(mu, self.var).rsample()  # Use rsample for autograd compatibility
        return x_hat


In [70]:
class GaussianVAE(nn.Module):
    def __init__(self, n_latent: int, input_shape: tuple, var: float = 0.05):
        super().__init__()
        self.device = 'cuda' if torch.cuda.is_available() else 'cpu'
        self.n_latent = n_latent
        self.var = var  # varianza del decoder

        self.encoder = GaussianEncoder(n_latent).to(self.device)
        self.decoder = GaussianDecoder(input_shape, n_latent, var=var).to(self.device)

    def forward(self, x):
        z, _, _ = self.encode(x)
        return self.decode(z)

    def encode(self, x):
        return self.encoder(x)

    def decode(self, z):
        return self.decoder(z)

    def reconstruct(self, x):
        return self.forward(x)

    def elbo(self, x):
        z, mu_z, var_z = self.encode(x)
        x_recon = self.decode(z)

        # Ricostruzione: calcolo MSE tra immagine originale e decodificata
        mse = (x_recon - x) ** 2
        recon_loss = 0.5 * torch.sum(torch.log(torch.tensor(2 * np.pi * self.var)) + mse / self.var)

        # KL Divergence
        kl = kl_loss(mu_z, var_z)

        return recon_loss + kl

In [71]:
def print_vae_params(vae: VAE):
    print(vae_gaussian.encoder)
    print(vae_gaussian.decoder)
    print(f'Number of encoder parameters: {sum(p.numel() for p in vae_gaussian.encoder.parameters())}')
    print(f'Number of decoder parameters: {sum(p.numel() for p in vae_gaussian.decoder.parameters())}')

In [72]:
vae_gaussian = GaussianVAE(num_latent, train_D)
print_vae_params(vae_gaussian)

GaussianEncoder(
  (conv_head): ModuleList(
    (0): Conv2d(1, 32, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
    (1): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): ELU(alpha=1.0, inplace=True)
    (3): Conv2d(32, 64, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
    (4): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (5): ELU(alpha=1.0, inplace=True)
    (6): Conv2d(64, 128, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
    (7): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (8): ELU(alpha=1.0, inplace=True)
  )
  (mu): Linear(in_features=2048, out_features=50, bias=True)
  (var): Linear(in_features=2048, out_features=50, bias=True)
  (var_act): Softplus(beta=1.0, threshold=20.0)
)
GaussianDecoder(
  (linear): Linear(in_features=50, out_features=2048, bias=True)
  (decoder): Sequential(
    (0): ConvTranspose2d(128, 64, kernel_size=(3, 3), stride=(2, 

In [73]:
class EarlyStopping:
    def __init__(self, threshold: float, patience=3):
        """
        :param threshold: The loss decrease threshold to consider for early stopping
        :param patience: The number of epochs to wait before stopping
        """
        self.threshold = threshold
        self.patience = patience
        self.last_loss = float("inf")
        self.counter = 0

    def __call__(self, loss: float) -> bool:
        """
        :param loss: The current loss
        :return: True if the training should stop, False otherwise
        """
        if loss >= self.last_loss or self.last_loss - loss < self.threshold:
            self.counter += 1
        else:
            self.counter = 0
        self.last_loss = loss
        return self.counter >= self.patience

In [75]:
from torch.nn import Module

def load_weights(model: Module, dist_type: str, model_type: str):
    model.load_state_dict(pickle.load(open(f"models/{dist_type}/{model_type}.pkl", "rb")))

In [None]:
from typing import List, Tuple
import torch
from tqdm import tqdm
import pickle

def train_vae(
    vae: GaussianVAE, 
    train_x: torch.Tensor, 
    #val_x: torch.Tensor, 
    stop_early, 
    learning_rate: float,
    batch_size: int,
    num_epochs: int,
    img_size: int,
    dist_type: str = "gaussian"
) -> Tuple[List, List]:

    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    vae.to(device)

    optimizer = torch.optim.Adam(
        list(vae.encoder.parameters()) + list(vae.decoder.parameters()), 
        lr=learning_rate
    )

    train_loss_history = []
    val_loss_history = []
    pbar = tqdm(range(num_epochs))

    for epoch in pbar:
        vae.train()
        train_loss = 0.0

        # Shuffle indices for training data
        shuffled_idx = torch.randperm(train_x.shape[0])
        idx_batches = torch.split(shuffled_idx, batch_size)

        for idx in idx_batches:
            optimizer.zero_grad()
            batch_x = train_x[idx].to(device)

            # Reshape batch_x to (batch_size, channels=1, height, width)
            input_x = batch_x.view(-1, 1, img_size, img_size)

            loss = vae.elbo(input_x)
            loss.backward()
            optimizer.step()

            train_loss += loss.item() * batch_x.size(0)

        train_loss /= train_x.shape[0]
        train_loss_history.append(train_loss)

        # Validation
        vae.eval()
        val_loss = 0.0
        with torch.no_grad():
            shuffled_idx = torch.randperm(val_x.shape[0])
            idx_batches = torch.split(shuffled_idx, batch_size)

            for idx in idx_batches:
                batch_x = val_x[idx].to(device)
                input_x = batch_x.view(-1, 1, img_size, img_size)

                loss = vae.elbo(input_x)
                val_loss += loss.item() * batch_x.size(0)

            val_loss /= val_x.shape[0]
            val_loss_history.append(val_loss)

        pbar.set_postfix({'Train Loss': train_loss, 'Validation Loss': val_loss})

        if stop_early(train_loss_history[-1]):
            print(f'Training criterion reached at epoch {epoch}. Stopping training...')
            # Save model weights
            torch.save(vae.encoder.state_dict(), f"models/{dist_type}/encoder.pth")
            torch.save(vae.decoder.state_dict(), f"models/{dist_type}/decoder.pth")
            break

    return train_loss_history, val_loss_history


In [84]:
if load_models:
    load_weights(vae_gaussian.encoder, dist_type, 'encoder')
    load_weights(vae_gaussian.decoder, dist_type, 'decoder')
else:
    print('Training a new Gaussian VAE...')
    stop_early = EarlyStopping(stop_threshold)

    gaussian_train_loss, gaussian_val_loss = train_vae(
        vae=vae_gaussian,
        train_x=train_dataset,      # o train_x, a seconda di come li chiami tu
        #val_x=val_dataset,          # o val_x
        stop_early=stop_early,
        learning_rate=learning_rate,
        batch_size=batch_size,
        num_epochs=num_epochs,
        img_size=28,
        dist_type=dist_type
    )


Training a new Gaussian VAE...


TypeError: train_vae() missing 1 required positional argument: 'val_x'