# CVAE project on Fashion-MNIST data

In [None]:
# Imports
import random
import torch
import torch.nn as nn
from torchvision import datasets, transforms
from torch.utils.data import DataLoader, random_split
import torch.nn.functional as F
import torch.optim as optim
import matplotlib.pyplot as plt
from tqdm import tqdm

In [None]:
# Load the Fashion-MNIST train and test splits (downloaded into ./data on first run);
# every image is converted to a tensor by transforms.ToTensor().
# Note: to improve speed we could store the images already flattened (e.g. with a transform).
to_tensor = transforms.ToTensor()
train_dataset, test_dataset = (
    datasets.FashionMNIST(root='./data', train=is_train, transform=to_tensor, download=True)
    for is_train in (True, False)
)

### Affichage aléatoire d'un exemple de chaque classe

In [None]:
# Dictionnaire pour mapper les labels aux noms des classes
class_names = [  
    'class 0: T-shirt/top', 'class 1: Trouser', 'class 2: Pullover', 'class 3: Dress', 'class 4: Coat',
    'class 5: Sandal', 'class 6: Shirt', 'class 7: Sneaker', 'class 8: Bag', 'class 9: Ankle boot'
]

def plot_random_samples_per_class(dataset):
    # Initialiser un dictionnaire pour stocker un échantillon par classe
    samples = {i: None for i in range(10)}

    # Parcourir le dataset et sélectionner un échantillon aléatoire par classe
    for img, label in torch.utils.data.DataLoader(dataset, batch_size=1, shuffle=True):
        label = label.item()
        if samples[label] is None:  # Sélectionner le premier échantillon trouvé pour chaque classe
            samples[label] = img
        if all(v is not None for v in samples.values()):  # Arrêter une fois qu'on a un échantillon de chaque classe
            break

    # Créer la figure avec des sous-graphiques 2x5
    fig, axes = plt.subplots(2, 5, figsize=(12, 6))
    fig.suptitle("Random Sample from Each FashionMNIST Class")

    # Afficher chaque image dans le subplot correspondant
    for i, (label, img) in enumerate(samples.items()):
        ax = axes[i // 5, i % 5]  # Position dans la grille
        ax.imshow(img.squeeze(), cmap="gray")
        ax.set_title(class_names[label])
        ax.axis('off')

    plt.tight_layout(rect=[0, 0, 1, 0.95])  # Ajuster l'espacement
    plt.show()

# Appeler la fonction pour afficher les échantillons
plot_random_samples_per_class(train_dataset)


La classe 1 (Trouser) se distingue des autres puisqu'il s'agit d'un pantalon (vêtement du bas), alors que la plupart des autres classes sont des hauts, des chaussures ou des accessoires.
En nous intéressant au contenu des classes, nous observons qu'une même classe regroupe des vêtements/chaussures/accessoires assez différents (forme assez variable) : la variance intra-classe est grande, ce qui explique que les individus d'une même classe puissent être assez éloignés les uns des autres (clusters peu compacts). La variance inter-classe n'est pas non plus très grande, puisque certains articles ressemblent à ceux d'autres classes (par exemple Shirt, T-shirt/top, Pullover et Coat). Cela peut entraîner un chevauchement (overlapping) entre individus de classes différentes.

# Differences between VAE and Conditional VAE (CVAE)

Resources :
- https://proceedings.neurips.cc/paper/2015/file/8d55a249e6baa5c06772297520da2051-Paper.pdf
- https://arxiv.org/pdf/1312.6114
- https://lilianweng.github.io/posts/2018-08-12-vae/

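In addition to the input data, the CVAE takes a conditioning variable $c$ (here, the class label) as an extra input to both the encoder and the decoder. The encoder therefore approximates $q(z \mid x, c)$ and the decoder models $p(x \mid z, c)$. The goal of a CVAE is to shape the latent space according to the conditioning variable, so that at generation time we can choose the class of the sample to produce.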

# Hyperparameters

In [None]:
# Global values
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(DEVICE)

# Seed Python and torch RNGs (torch.manual_seed also seeds CUDA) so that the data split,
# weight initialisation and latent sampling below are reproducible across runs.
SEED = 42
random.seed(SEED)
torch.manual_seed(SEED)

BATCH_SIZE = 64
MAX_EPOCHS = 1
LEARNING_RATE = 1e-3
CLASSES_TO_IDX = train_dataset.class_to_idx
IDX_TO_CLASSES = {idx: cls for cls, idx in CLASSES_TO_IDX.items()}
# Number of rows of one image; the model uses FEATURE_SIZE * FEATURE_SIZE pixels (assumes square images)
FEATURE_SIZE = train_dataset.data[0].shape[0]
CLASS_SIZE = len(CLASSES_TO_IDX)  # width of the one-hot condition vector (10 for Fashion-MNIST)
LATENT_SIZE = 10

# CVAE implementation

In [None]:
# Show one image from the train set and one from the test set, titled with their class name.
# Safe to run after the split cell too: once `test_dataset` has been replaced by a
# `random_split` Subset, its underlying full dataset is available as `.dataset`.
base_test_dataset = getattr(test_dataset, "dataset", test_dataset)

fig, axs = plt.subplots(nrows=1, ncols=2, figsize=(10, 3))

for ax, dataset, idx in ((axs[0], train_dataset, 100), (axs[1], base_test_dataset, 0)):
    ax.imshow(dataset.data[idx], cmap='gray')
    ax.set_title(f"{IDX_TO_CLASSES[int(dataset.targets[idx])]}")

plt.tight_layout()
plt.show()

In [None]:
# Split the official test set into a test half and a validation half.
# Always split the *full* test set: if this cell is re-run, `test_dataset` is already a
# `random_split` Subset (which has no `.data`), and its full dataset is `test_dataset.dataset`.
full_test_dataset = getattr(test_dataset, "dataset", test_dataset)
val_size = len(full_test_dataset) // 2
test_size = len(full_test_dataset) - val_size
# A dedicated seeded generator makes the split identical on every run
split_generator = torch.Generator().manual_seed(42)
test_dataset, val_dataset = random_split(
    full_test_dataset, [test_size, val_size], generator=split_generator
)

# Some prints
print(f"Train dataset size: {len(train_dataset)}")
print(f"Validation dataset size: {len(val_dataset)}")
print(f"Test dataset size: {len(test_dataset)}")

# Make dataloaders (only the training set is shuffled)
train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=BATCH_SIZE, shuffle=False)
test_loader = DataLoader(test_dataset, batch_size=1, shuffle=False)

**Model used : Conditional Variational Auto-Encoders (CVAE)**

In [None]:
# Architecture adapted from https://github.com/unnir/cVAE/blob/master/cvae.py
class CVAE(nn.Module):
    """Conditional VAE with a single 400-unit hidden layer in both encoder and decoder.

    The condition ``c`` is concatenated to the flattened image before encoding and to the
    latent code before decoding.

    Args:
        feature_size: side length of the input image; the model works on
            ``feature_size * feature_size`` pixels.
        latent_size: dimension of the latent code z.
        class_size: width of the condition vector c.
    """

    def __init__(self, feature_size, latent_size, class_size):
        super().__init__()
        self.feature_size = feature_size
        self.class_size = class_size

        n_pixels = feature_size * feature_size
        hidden_units = 400

        # Encoder q(z | x, c): shared hidden layer, then mean head (fc21) and log-variance head (fc22)
        self.fc1 = nn.Linear(n_pixels + class_size, hidden_units)
        self.fc21 = nn.Linear(hidden_units, latent_size)
        self.fc22 = nn.Linear(hidden_units, latent_size)

        # Decoder p(x | z, c)
        self.fc3 = nn.Linear(latent_size + class_size, hidden_units)
        self.fc4 = nn.Linear(hidden_units, n_pixels)

        self.elu = nn.ELU()
        self.sigmoid = nn.Sigmoid()

        # Learnable centre of each class in latent space (one row per class)
        self.class_centers = nn.Parameter(torch.randn(class_size, latent_size))

    def encode(self, x, c):
        """Return ``(mu, logvar)`` of q(z | x, c); ``x`` must already be flattened per sample."""
        hidden = self.elu(self.fc1(torch.cat([x, c], 1)))
        return self.fc21(hidden), self.fc22(hidden)

    def sample(self, mu, logvar):
        """Reparameterisation trick: ``mu + eps * exp(0.5 * logvar)`` with ``eps ~ N(0, I)``."""
        std = torch.exp(0.5 * logvar)
        return mu + torch.randn_like(std) * std

    def decode(self, z, c):
        """Map ``(z, c)`` to sigmoid pixel values shaped ``(batch, 1, feature_size, feature_size)``."""
        hidden = self.elu(self.fc3(torch.cat([z, c], 1)))
        pixels = self.sigmoid(self.fc4(hidden))
        side = self.feature_size
        return pixels.view(z.size(0), 1, side, side)

    def forward(self, x, c):
        """Encode ``x`` given ``c``, sample z, decode; returns ``(reconstruction, mu, logvar)``."""
        mu, logvar = self.encode(x.view(x.size(0), -1), c)
        reconstruction = self.decode(self.sample(mu, logvar), c)
        return reconstruction, mu, logvar

**Losses**

In [None]:
def loss_function(recon_x, x, mu, logvar, beta=1):
    """Negative ELBO: summed binary cross-entropy plus ``beta`` times the KL term.

    ``recon_x`` and ``x`` are flattened per sample before the pixel-wise BCE (``recon_x`` must
    hold probabilities, as required by ``F.binary_cross_entropy``). The KL term is the closed
    form of KL(N(mu, exp(logvar)) || N(0, I)). Both terms are summed over the whole batch.
    """
    reconstruction = F.binary_cross_entropy(
        recon_x.view(recon_x.size(0), -1),
        x.view(x.size(0), -1),
        reduction='sum',
    )
    kl_divergence = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())
    return reconstruction + beta * kl_divergence

In [None]:
# This loss works by picking three examples:
#   Anchor:   a latent point of the target class.
#   Positive: another latent point of the same class.
#   Negative: a latent point of a different class.
# The triplet loss pulls the anchor and the positive together and pushes the anchor and the
# negative apart, which can be more effective for separating the classes.

def triplet_loss(mu, labels, margin=1.0):
    """Batch triplet loss with one randomly drawn positive and negative per anchor.

    Args:
        mu: latent vectors, one row per sample.
        labels: integer class index of each row of ``mu``.
        margin: required gap between the squared anchor-negative and anchor-positive distances.

    Returns:
        Scalar tensor: sum over anchors of ``relu(d(a, p) - d(a, n) + margin)`` (squared
        Euclidean distances) divided by the batch size. Anchors without another same-class
        sample or without a different-class sample contribute 0; an empty batch yields 0.
    """
    batch_size = mu.size(0)
    loss = mu.new_zeros(())
    if batch_size == 0:
        return loss

    for i in range(batch_size):
        anchor = mu[i]
        pos_indices = (labels == labels[i]).nonzero().view(-1)
        # Exclude the anchor itself so that it is never used as its own positive
        pos_indices = pos_indices[pos_indices != i]
        neg_indices = (labels != labels[i]).nonzero().view(-1)

        if len(pos_indices) == 0 or len(neg_indices) == 0:
            continue

        pos_index = pos_indices[torch.randint(0, len(pos_indices), (1,)).item()].item()
        neg_index = neg_indices[torch.randint(0, len(neg_indices), (1,)).item()].item()
        positive = mu[pos_index]
        negative = mu[neg_index]

        pos_dist = torch.sum((anchor - positive) ** 2)
        neg_dist = torch.sum((anchor - negative) ** 2)
        loss = loss + torch.relu(pos_dist - neg_dist + margin)

    return loss / batch_size

In [None]:
def new_loss(recon_x, x, mu, logvar, labels, class_centers, beta=1, lambda_center=0.1, lambda_triplet=5):
    """Negative ELBO plus two latent-space clustering terms.

    total = BCE + beta * KL + lambda_center * center_loss + lambda_triplet * triplet_loss

    Args:
        recon_x, x: reconstructions and targets (flattened per sample for the BCE).
        mu, logvar: encoder outputs.
        labels: one-hot condition vectors; the class index of each row is ``labels.argmax(dim=1)``.
        class_centers: one latent centre per class, indexed by class index.
        beta, lambda_center, lambda_triplet: weights of the KL, center and triplet terms.

    Returns:
        Scalar tensor with the weighted total loss.
    """
    # Reconstruction loss (summed over pixels and batch)
    recon_flat = recon_x.view(recon_x.size(0), -1)
    x_flat = x.view(x.size(0), -1)
    bce = F.binary_cross_entropy(recon_flat, x_flat, reduction='sum')

    # KL divergence to the N(0, I) prior
    kld = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())

    class_idx = labels.argmax(dim=1)

    # Center loss: batch mean of the squared distance between each latent mean and the centre
    # of its class (vectorised; an empty batch contributes 0 instead of dividing by zero).
    if mu.size(0) > 0:
        center_loss = (mu - class_centers[class_idx]).pow(2).sum(dim=1).mean()
    else:
        center_loss = mu.new_zeros(())

    # Triplet loss for better class separation
    triplet_loss_value = triplet_loss(mu, class_idx)

    return bce + beta * kld + lambda_center * center_loss + lambda_triplet * triplet_loss_value

**Useful functions**

In [None]:
def one_hot(labels, class_size, device=None):
    """Encode integer class labels as float32 one-hot rows.

    Args:
        labels: 1-D tensor of integer class indices in ``[0, class_size)``.
        class_size: number of classes, i.e. the width of each one-hot row.
        device: device of the result; ``None`` (default) uses the notebook-wide ``DEVICE``.

    Returns:
        Float tensor of shape ``(len(labels), class_size)`` on the chosen device.
    """
    # Vectorised encoding: replaces a Python loop doing one indexing op per sample.
    # F.one_hot requires int64 indices, hence .long().
    encoded = F.one_hot(labels.long(), num_classes=class_size).float()
    return encoded.to(DEVICE if device is None else device)

In [None]:
def get_lr(optimizer):
    """Return the learning rate currently stored in the optimizer's first parameter group."""
    lead_group = optimizer.param_groups[0]
    return lead_group['lr']

**Training part**

In [None]:
def train_model(
    model,
    train_loader,
    val_loader,
    optimizer,
    loss_name,
    device=DEVICE,
    scheduler=None,
    max_epochs=MAX_EPOCHS,
    class_size=CLASS_SIZE,
    save_model=False,
    save_path="cvae_weights.pth"
):
    """Train a (C)VAE and record per-epoch training and validation losses.

    Args:
        model: module called as ``model(data, labels)`` returning ``(recon, mu, logvar)``.
        train_loader, val_loader: loaders yielding ``(images, labels)`` batches.
        optimizer: optimizer over ``model``'s parameters.
        loss_name: ``"recons"`` trains with ``loss_function``; any other value trains with
            ``new_loss`` (ELBO + center + triplet terms) using per-epoch weights.
        device: device the model and batches are moved to.
        scheduler: optional LR scheduler, stepped once at the end of each epoch.
        max_epochs: number of epochs.
        class_size: if not None, labels are one-hot encoded to this width before the forward pass.
        save_model: if True, the final ``state_dict`` is written to ``save_path``.
        save_path: destination file for the weights.

    Returns:
        ``(train_loss_l, val_loss_l)``: tensors of length ``max_epochs``. Each entry is the sum
        of the batch losses divided by the number of samples in the loader's dataset.
        Validation always uses ``loss_function``, whatever ``loss_name`` is.
    """
    model = model.to(device)
    train_loss_l = torch.zeros(max_epochs)
    val_loss_l = torch.zeros(max_epochs)

    for epoch in range(1, max_epochs + 1):
        # Weights of the extra terms of `new_loss` for this epoch (constant within the epoch,
        # so computed once here instead of per batch).
        # beta increases by 0.01 per epoch until it reaches 0.2 ("increase until target_beta");
        # the previous `min(0.01, 0.2)` was a constant 0.01 and never increased.
        beta = min(0.01 * epoch, 0.2)
        lambda_triplet = min(2 + epoch * 0.5, 8)  # target_lambda_triplet could be 15 or 20

        # ** Training **
        model.train()
        running_train_loss = 0.0
        for data, labels in tqdm(train_loader, desc=f"Epoch {epoch}/{max_epochs} Training"):
            data, labels = data.to(device), labels.to(device)
            if class_size is not None:
                labels = one_hot(labels, class_size)

            # Forward pass
            recon_batch, mu, logvar = model(data, labels)
            if loss_name == "recons":
                loss = loss_function(recon_batch, data, mu, logvar)
            else:
                loss = new_loss(
                    recon_batch, data, mu, logvar, labels, model.class_centers,
                    beta=beta, lambda_triplet=lambda_triplet,
                )

            # Backward pass and update weights
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            running_train_loss += loss.item()
        train_loss = running_train_loss / len(train_loader.dataset)

        # ** Validation ** (always with the plain negative ELBO)
        model.eval()
        running_val_loss = 0.0
        with torch.no_grad():
            for data, labels in tqdm(val_loader, desc=f"Epoch {epoch}/{max_epochs} Validation"):
                data, labels = data.to(device), labels.to(device)
                if class_size is not None:
                    labels = one_hot(labels, class_size)

                recon_batch, mu, logvar = model(data, labels)
                loss = loss_function(recon_batch, data, mu, logvar)
                running_val_loss += loss.item()
        val_loss = running_val_loss / len(val_loader.dataset)

        # Saving losses
        train_loss_l[epoch - 1] = train_loss
        val_loss_l[epoch - 1] = val_loss

        print(f"Epoch {epoch}, Training Loss: {train_loss:.4f}, Validation Loss: {val_loss:.4f}")

        # Update LR (the printed value is the one used during this epoch)
        if scheduler:
            print(f"Learning Rate: {get_lr(optimizer):.6f}")
            scheduler.step()

    # Save trained model
    if save_model:
        torch.save(model.state_dict(), save_path)
        print(f"Model saved to {save_path}")

    return train_loss_l, val_loss_l

In [None]:
cvae = CVAE(FEATURE_SIZE, LATENT_SIZE, CLASS_SIZE)
optimizer = torch.optim.Adam(cvae.parameters(), lr=LEARNING_RATE)
# Multiply the learning rate by 0.1 every 10 epochs
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=10, gamma=0.1)

# Call the train function
train_loss, val_loss = train_model(
    model=cvae,
    train_loader=train_loader,
    val_loader=val_loader,
    optimizer=optimizer,
    loss_name="recons",
    scheduler=scheduler,
    save_model=True,
)

# Plot losses against 1-based epoch numbers. Markers keep the curves visible even for a
# single-epoch run, where a plain line plot of one point draws nothing.
epochs = range(1, len(train_loss) + 1)
fig, ax = plt.subplots(figsize=(10, 5))
ax.plot(epochs, train_loss, marker="o", label="Train Loss")
ax.plot(epochs, val_loss, marker="o", label="Validation Loss")
ax.set_title("Loss Curves")
ax.set_xlabel("Epochs")
ax.set_ylabel("Loss")
ax.legend()
ax.grid(True)
plt.show()

**Testing part**

In [None]:
def test_model(
    model,
    test_loader,
    loss_name,
    device=DEVICE,
    class_size=CLASS_SIZE,
    weights_path="cvae_weights.pth",
    plot_latent_space=True,
    cmap="tab10"
):
    """Evaluate a (C)VAE on ``test_loader`` and optionally plot its latent means.

    Args:
        model: module called as ``model(data, labels)`` returning ``(recon, mu, logvar)``.
        test_loader: loader yielding ``(images, labels)`` batches.
        loss_name: ``"recons"`` -> ``loss_function``; any other value -> ``new_loss`` (defaults).
        device: evaluation device.
        class_size: if not None, labels are one-hot encoded to this width.
        weights_path: ``state_dict`` file loaded into ``model`` before evaluation; ``None``
            skips loading. A missing file prints a warning and the current weights are used.
        plot_latent_space: scatter the first two coordinates of ``mu``, coloured by class.
        cmap: colormap of the scatter plot.

    Returns:
        ``(test_loss, mse, residuals_sum)``: summed loss and summed squared error divided by
        the number of test samples, and the per-pixel sum of ``|data - recon|`` over all samples.
    """
    # Load the saved weights (this parameter used to be ignored, so whatever weights the
    # model currently held -- e.g. a freshly initialised network -- were evaluated).
    if weights_path is not None:
        try:
            state_dict = torch.load(weights_path, map_location=device)
        except FileNotFoundError:
            print(f"Warning: {weights_path} not found; evaluating the model's current weights.")
        else:
            model.load_state_dict(state_dict)
    model = model.to(device)
    model.eval()

    running_test_loss = 0.0
    total_mse = 0.0
    residuals_sum = 0
    # Latent points are collected and plotted once at the end (a single scatter call
    # instead of one call per batch).
    latent_xy, latent_ids = [], []

    with torch.no_grad():
        for data, labels in tqdm(test_loader, desc="Testing..."):
            data, labels = data.to(device), labels.to(device)
            if class_size is not None:
                labels = one_hot(labels, class_size)

            # Forward pass
            recon_batch, mu, logvar = model(data, labels)
            if loss_name == "recons":
                loss = loss_function(recon_batch, data, mu, logvar)
            else:
                loss = new_loss(recon_batch, data, mu, logvar, labels, model.class_centers)

            # Squared error summed over pixels, and absolute residual per pixel
            total_mse += F.mse_loss(recon_batch, data, reduction='sum').item()
            residuals_sum += torch.abs(data - recon_batch).sum(dim=0)

            running_test_loss += loss.item()

            if plot_latent_space:
                latent_xy.append(mu[:, :2].cpu())
                class_ids = labels.argmax(dim=1) if class_size is not None else labels
                latent_ids.append(class_ids.cpu())

    # Compute means over samples
    n_samples = len(test_loader.dataset)
    test_loss = running_test_loss / n_samples
    mse = total_mse / n_samples

    print(f"Test loss: {test_loss:.4f} -- MSE: {mse:.4f}")

    if plot_latent_space and latent_xy:
        points = torch.cat(latent_xy)
        ids = torch.cat(latent_ids)
        vmax = class_size - 1 if class_size is not None else None
        fig, ax = plt.subplots()
        scatter = ax.scatter(points[:, 0], points[:, 1], c=ids, cmap=cmap, vmin=0, vmax=vmax, s=5)
        fig.colorbar(scatter, ax=ax)
        ax.set_xlabel("z[0]")
        ax.set_ylabel("z[1]")
        ax.set_title("Latent Space")
        plt.show()

    return test_loss, mse, residuals_sum

In [None]:
# Rebuild the model from the weights saved by the training cell, so the evaluation below
# runs on the trained network rather than on a freshly initialised one.
cvae = CVAE(FEATURE_SIZE, LATENT_SIZE, CLASS_SIZE)
cvae.load_state_dict(torch.load("cvae_weights.pth", map_location=DEVICE))

# Call the test function
test_loss, mse, residuals_sum = test_model(
    model=cvae,
    test_loader=test_loader,
    loss_name="recons",
    weights_path="cvae_weights.pth",
)

In [None]:
# Make some plots
def plot_test(
    model,
    test_loader,
    device=DEVICE,
    class_size=CLASS_SIZE,
    classes=IDX_TO_CLASSES,
    batch_size=1,
    cmap="gray",
    figsize=(8, 8),
    num_samples=2
):
    """Plot ``num_samples`` original / reconstructed pairs, one loader batch per row.

    Args:
        model: module called as ``model(data, labels)`` returning ``(recon, mu, logvar)``.
        test_loader: loader yielding ``(images, labels)`` batches; one batch is consumed per row.
        device: device the model and data are moved to.
        class_size: if not None, labels are one-hot encoded and the class is recovered with argmax.
        classes: mapping class index -> name used in the titles (falls back to "Class i").
        batch_size: if > 1, a random sample (within the batch actually returned) is shown per
            row; otherwise the first sample of each batch is shown.
        cmap: colormap for ``imshow``.
        figsize: figure size.
        num_samples: number of rows.
    """
    # Each row consumes one batch, so the loader must provide at least `num_samples` batches.
    # Checked before creating the figure so no empty figure is left behind.
    if num_samples > len(test_loader):
        print("Reduce the batch-size or the num_samples value!")
        return

    model = model.to(device)
    model.eval()
    # squeeze=False keeps `axs` 2-D, so axs[row, col] also works when num_samples == 1
    fig, axs = plt.subplots(nrows=num_samples, ncols=2, figsize=figsize, squeeze=False)
    loader_iter = iter(test_loader)

    with torch.no_grad():
        for row in range(num_samples):
            data, labels = next(loader_iter)
            data, labels = data.to(device), labels.to(device)
            if class_size is not None:
                labels = one_hot(labels, class_size)

            recon, _, _ = model(data, labels)

            # Choose one sample in the batch; bounded by the real batch size (the last batch
            # of a loader may be smaller than `batch_size`).
            n_candidates = min(batch_size, data.size(0))
            idx = random.randint(0, n_candidates - 1) if n_candidates > 1 else 0

            # Move to CPU for plotting
            original = data[idx].cpu().squeeze()
            reconstructed = recon[idx].cpu().squeeze()
            label_idx = int(labels[idx].argmax()) if class_size else int(labels[idx])
            label_name = classes[label_idx] if classes else f"Class {label_idx}"

            axs[row, 0].imshow(original, cmap=cmap)
            axs[row, 0].set_title(f"Original: {label_name}")
            axs[row, 0].axis("off")

            axs[row, 1].imshow(reconstructed, cmap=cmap)
            axs[row, 1].set_title("Reconstructed")
            axs[row, 1].axis("off")

    plt.tight_layout()
    plt.show()

In [None]:
plot_test(model=cvae, test_loader=test_loader)  # original vs. reconstruction for a few test images

**Sampling**

In [None]:
num_classes = CLASS_SIZE  # must match the condition width the model was built with
num_samples = 5  # number of generated samples per class

cvae.to(DEVICE)

# Generate images for each class
def generate_and_plot_samples(cvae_model, num_classes, num_samples, latent_size, feature_size, device=None):
    """Decode random latent codes z ~ N(0, I) for every class and plot them in a grid.

    Rows are classes and columns are samples; the first column is labelled with the class index.

    Args:
        cvae_model: model exposing ``decode(z, c)``.
        num_classes: number of classes, i.e. width of the one-hot condition.
        num_samples: number of samples generated per class.
        latent_size: dimension of z.
        feature_size: unused; kept for backward compatibility.
        device: device for z and c; ``None`` uses the device of the model's parameters.
    """
    if device is None:
        device = next(cvae_model.parameters()).device
    cvae_model.eval()

    # squeeze=False keeps `axes` 2-D even when num_classes or num_samples is 1
    fig, axes = plt.subplots(num_classes, num_samples, figsize=(15, 10), squeeze=False)
    fig.suptitle("Échantillons générés par classe")

    one_hot_rows = torch.eye(num_classes, device=device)
    for class_idx in range(num_classes):
        # One-hot condition of the current class, repeated for every sample
        class_condition = one_hot_rows[class_idx].unsqueeze(0).repeat(num_samples, 1)

        # Random latent codes from the prior
        z = torch.randn(num_samples, latent_size, device=device)

        with torch.no_grad():
            generated_images = cvae_model.decode(z, class_condition)

        for i in range(num_samples):
            ax = axes[class_idx, i]
            ax.imshow(generated_images[i].squeeze().cpu(), cmap="gray")
            # Hide ticks and frame instead of ax.axis("off"): axis("off") also hides the
            # y-label, which made the per-row class labels invisible.
            ax.set_xticks([])
            ax.set_yticks([])
            for spine in ax.spines.values():
                spine.set_visible(False)
            if i == 0:
                ax.set_ylabel(f"Classe {class_idx}")

    plt.tight_layout(rect=[0, 0, 1, 0.95])
    plt.show()

generate_and_plot_samples(cvae, num_classes, num_samples, latent_size=LATENT_SIZE, feature_size=FEATURE_SIZE)
