# Project: CVAE on the Fashion-MNIST dataset

In [None]:
# Imports
import random
import torch
import torch.nn as nn
from torchvision import datasets, transforms
from torch.utils.data import DataLoader, random_split
import torch.nn.functional as F
import torch.optim as optim
import matplotlib.pyplot as plt
from tqdm import tqdm

In [None]:
# Load the Fashion-MNIST train and test sets (downloaded into ./data when missing).
# Each image is converted to a tensor by transforms.ToTensor().
# Note: to improve speed, the images could be stored already flattened (e.g. with a transform).
train_dataset = datasets.FashionMNIST(root='./data', train=True, transform=transforms.ToTensor(), download=True)
test_dataset = datasets.FashionMNIST(root='./data', train=False, transform=transforms.ToTensor(), download=True)

# Difference between a VAE and a Conditional VAE (CVAE)

Resources :
- https://proceedings.neurips.cc/paper/2015/file/8d55a249e6baa5c06772297520da2051-Paper.pdf
- https://arxiv.org/pdf/1312.6114
- https://lilianweng.github.io/posts/2018-08-12-vae/

In addition to the input data, the CVAE takes a conditioning variable (a class label in our case) as input to both the encoder and the decoder. The goal of a CVAE is to shape the latent space so that it corresponds to the conditioning variable.

# Hyperparameters

In [None]:
# Global values shared by every cell below
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")  # GPU when available, CPU otherwise
BATCH_SIZE = 64  # batch size used by all three dataloaders
MAX_EPOCHS = 30  # number of training epochs (rebound to 10 by the second training cell)
LEARNING_RATE = 1e-3  # initial learning rate given to Adam
CLASSES_TO_IDX = train_dataset.class_to_idx  # its keys are used below as class names for plot titles
FEATURE_SIZE = train_dataset.data[0].shape[0]  # first dimension of one image; CVAE squares it to size its linear layers
CLASS_SIZE = 10  # width of the one-hot condition vector
LATENT_SIZE = 10  # dimension of the latent vector

Choice of hyperparameters :
- `BATCH_SIZE` :
- `MAX_EPOCHS` :
- `LEARNING_RATE` :

# CVAE implementation

In [None]:
# Show one sample of each dataset: the training image at index 100 and the test image at index 0,
# each titled with its class name.
# Note: do not run this cell after the test set has been split into validation and test sets,
# because test_dataset is rebound by random_split in that cell and this code reads .data/.targets.
fig, axs = plt.subplots(nrows=1, ncols=2, figsize=(10, 3))

axs[0].imshow(train_dataset.data[100], cmap='gray')
axs[0].set_title(f"{list(CLASSES_TO_IDX.keys())[train_dataset.targets[100]]}")

axs[1].imshow(test_dataset.data[0], cmap='gray')
axs[1].set_title(f"{list(CLASSES_TO_IDX.keys())[test_dataset.targets[0]]}")

plt.tight_layout()
plt.show()

In [None]:
# Split the original test set into two halves: a validation set and a test set.
# test_dataset is rebound to the result of random_split, so it no longer exposes .data afterwards.
# NOTE(review): no generator/seed is passed to random_split, so the split is presumably
# different on every run unless the global torch seed is fixed elsewhere.
val_size = int(test_dataset.data.shape[0] / 2)
test_size = test_dataset.data.shape[0] - val_size
test_dataset, val_dataset = random_split(test_dataset, [test_size, val_size])

# Report the size of each set
print(f"Train dataset size: {len(train_dataset.data)}")
print(f"Validation dataset size: {len(val_dataset)}")
print(f"Test dataset size: {len(test_dataset)}")

# Build the dataloaders (only the training loader is shuffled)
train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=BATCH_SIZE, shuffle=False)
test_loader = DataLoader(test_dataset, batch_size=BATCH_SIZE, shuffle=False)

In [None]:
# See https://github.com/unnir/cVAE/blob/master/cvae.py
class CVAE(nn.Module):
    """Conditional variational auto-encoder made of fully connected layers.

    The condition vector ``c`` is concatenated to the input of both the
    encoder and the decoder. Images are handled as square grids of
    ``feature_size`` x ``feature_size`` values with a single channel.

    Args:
        feature_size: side length of the (square) input images.
        latent_size: dimension of the latent vector.
        class_size: length of the condition vector ``c``.
    """

    def __init__(self, feature_size, latent_size, class_size):
        super().__init__()
        self.feature_size = feature_size
        self.class_size = class_size

        n_pixels = feature_size * feature_size
        hidden_units = 400

        # Encoder: (flattened image + condition) -> hidden -> two latent-sized heads
        self.fc1 = nn.Linear(n_pixels + class_size, hidden_units)
        self.fc21 = nn.Linear(hidden_units, latent_size)  # head used as the latent mean
        self.fc22 = nn.Linear(hidden_units, latent_size)  # head used as the latent log-variance

        # Decoder: (latent vector + condition) -> hidden -> flattened image
        self.fc3 = nn.Linear(latent_size + class_size, hidden_units)
        self.fc4 = nn.Linear(hidden_units, n_pixels)

        self.elu = nn.ELU()
        self.sigmoid = nn.Sigmoid()

        # One learnable center per class in latent space. No method of this
        # class reads it; it is exposed for code outside the class.
        self.class_centers = nn.Parameter(torch.randn(class_size, latent_size))

    def encode(self, x, c):
        """Q(z|x, c): return the two encoder heads for flattened input ``x`` and condition ``c``."""
        hidden = self.elu(self.fc1(torch.cat((x, c), dim=1)))
        return self.fc21(hidden), self.fc22(hidden)

    def sample(self, mu, logvar):
        """Draw ``mu + eps * std`` with ``std = exp(0.5 * logvar)`` and ``eps`` standard normal noise."""
        std = (0.5 * logvar).exp()
        noise = torch.randn_like(std)
        return mu + noise * std

    def decode(self, z, c):
        """P(x|z, c): map latent ``z`` and condition ``c`` to an image batch shaped (batch, 1, side, side)."""
        conditioned = torch.cat((z, c), dim=1)
        hidden = self.elu(self.fc3(conditioned))
        pixels = self.sigmoid(self.fc4(hidden))
        side = self.feature_size
        return pixels.view(conditioned.size(0), 1, side, side)

    def forward(self, x, c):
        """Flatten ``x``, encode it, sample a latent vector and decode it.

        Returns:
            Tuple ``(reconstruction, mu, logvar)``.
        """
        flat = x.view(x.size(0), -1)
        mu, logvar = self.encode(flat, c)
        latent = self.sample(mu, logvar)
        return self.decode(latent, c), mu, logvar

In [None]:
def loss_function(recon_x, x, mu, logvar, beta=1):
    """Return the VAE objective: summed binary cross-entropy plus ``beta`` times the KL term.

    Args:
        recon_x: reconstructed batch; flattened per sample before comparison.
        x: target batch; flattened per sample before comparison.
        mu: latent means.
        logvar: latent log-variances.
        beta: weight applied to the KL term.

    Returns:
        Scalar tensor. Both terms are summed (not averaged) over the batch.
    """
    flat_recon = recon_x.view(recon_x.size(0), -1)
    flat_target = x.view(x.size(0), -1)

    reconstruction_term = F.binary_cross_entropy(flat_recon, flat_target, reduction='sum')
    # KL divergence between N(mu, exp(logvar)) and the standard normal
    kl_term = (1 + logvar - mu.pow(2) - logvar.exp()).sum() * -0.5

    return reconstruction_term + beta * kl_term

In [None]:
def one_hot(labels, class_size):
    """Convert a 1-D tensor of class indices into a float one-hot matrix on DEVICE.

    Uses the built-in one-hot encoding instead of filling the matrix one
    element at a time in a Python loop.

    Args:
        labels: 1-D integer tensor of class indices.
        class_size: number of classes (number of columns of the result).

    Returns:
        Float tensor of shape (labels.size(0), class_size), moved to DEVICE.
    """
    encoded = F.one_hot(labels.long(), num_classes=class_size)
    # F.one_hot returns integers; the model concatenates this with float inputs.
    return encoded.float().to(DEVICE)

In [None]:
def get_lr(optimizer):
    """Return the 'lr' value of the optimizer's first parameter group.

    Returns None when the optimizer has no parameter group.
    """
    first_group = next(iter(optimizer.param_groups), None)
    if first_group is None:
        return None
    return first_group['lr']

In [None]:
# Train the baseline CVAE with loss_function, track the per-sample train/validation
# loss of every epoch, plot both curves and save the weights.
cvae = CVAE(FEATURE_SIZE, LATENT_SIZE, CLASS_SIZE)
cvae.to(DEVICE)
optimizer = optim.Adam(cvae.parameters(), lr=LEARNING_RATE)
# StepLR scheduler (multiplies the learning rate by gamma, here 0.1, every 10 epochs)
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=10, gamma=0.1)

def get_lr(optimizer):
    """Return the 'lr' of the first parameter group (same helper as defined in the cell above)."""
    for param_group in optimizer.param_groups:
        return param_group['lr']

# One slot per epoch for the averaged losses
val_loss_l = torch.zeros(MAX_EPOCHS)
train_loss_l = torch.zeros(MAX_EPOCHS)

for epoch in range(1, MAX_EPOCHS + 1):
    # Print the current learning rate
    current_lr = get_lr(optimizer)
    print(f"Epoch {epoch}/{MAX_EPOCHS} - Learning Rate: {current_lr:.6f}")

    # Training part
    cvae.train()
    running_train_loss = 0.0
    t_train = tqdm(train_loader, desc=f"Epoch {epoch}/{MAX_EPOCHS} Training")
    for data, labels in t_train:
        data, labels = data.to(DEVICE), labels.to(DEVICE)
        labels = one_hot(labels, CLASS_SIZE)
        recon_batch, mu, logvar = cvae(data, labels)
        optimizer.zero_grad()
        loss = loss_function(recon_batch, data, mu, logvar)
        loss.backward()
        optimizer.step()

        running_train_loss += loss.item()
    # The loss is summed over each batch, so dividing by the dataset size gives a per-sample average
    train_loss = running_train_loss / len(train_loader.dataset)

    # Validation part
    cvae.eval()  # Set model to evaluation mode
    with torch.no_grad():  # Disable gradient calculation for validation
        running_val_loss = 0.0
        t_val = tqdm(val_loader, desc=f"Epoch {epoch}/{MAX_EPOCHS} Validation")
        for data, labels in t_val:
            data, labels = data.to(DEVICE), labels.to(DEVICE)
            labels = one_hot(labels, CLASS_SIZE)
            recon_batch, mu, logvar = cvae(data, labels)            
            loss = loss_function(recon_batch, data, mu, logvar)

            running_val_loss += loss.item()

    val_loss = running_val_loss / len(val_loader.dataset)

    train_loss_l[epoch - 1] = train_loss
    val_loss_l[epoch - 1] = val_loss
    
    print(f'Epoch {epoch}, Training loss: {train_loss:.4f}, Validation loss: {val_loss:.4f}')
    # Step the scheduler
    scheduler.step()
# Loss curves over the epochs
plt.plot(train_loss_l)
plt.title("Train loss")
plt.show()
plt.plot(val_loss_l)
plt.title("Validation loss")
plt.show()
# Saving part (if needed)
torch.save(cvae.state_dict(), 'cvae_weights.pth')

In [None]:
# Evaluate the baseline CVAE on the test set: reload the saved weights, accumulate the
# test loss, the squared error and the absolute residuals, and scatter the latent means.
# Loading model
cvae = CVAE(FEATURE_SIZE, LATENT_SIZE, CLASS_SIZE)
cvae.to(DEVICE)
cvae.load_state_dict(torch.load('cvae_weights.pth', weights_only=True))
cvae.eval()

# Plot
fig, ax = plt.subplots()
# Empty placeholder scatter; it is replaced by the per-batch scatter calls in the loop
scatter = ax.scatter([], [], c=[], cmap='tab10', vmin=0, vmax=9)

with torch.no_grad():
    running_test_loss = 0.0
    total_mse = 0.0
    number_images = 0
    residuals_sum = 0  # becomes a tensor after the first in-place addition below
    t_test = tqdm(test_loader, desc="Testing...")
    for data, labels in t_test:
        data, labels = data.to(DEVICE), labels.to(DEVICE)
        labels = one_hot(labels, CLASS_SIZE)
        recon_batch, mu, logvar = cvae(data, labels) 
        loss = loss_function(recon_batch, data, mu, logvar)
        # Squared error summed over every pixel of the batch
        mse_value = F.mse_loss(recon_batch, data, reduction='sum') 
        total_mse += mse_value.item()
        residuals_sum += torch.abs(data - recon_batch).sum(dim=0) # (1, 28, 28) where 1 is the channel number
        number_images += labels.size(0)

        running_test_loss += loss.item()

        # Plot part: only the first two latent coordinates of mu are drawn, colored by class index
        mu, labels = mu.cpu(), labels.cpu()
        scatter = ax.scatter(mu[:,0], mu[:,1], c=[labels.argmax(dim=1)], cmap='tab10', vmin=0, vmax=9)        
        
# Per-image averages (the squared error is summed over pixels, then divided by the image count)
test_loss = running_test_loss / len(test_loader.dataset)
MSE = total_mse / number_images
print(f"Test loss: {test_loss} -- MSE: {MSE}")

plt.colorbar(scatter)
plt.title("Latent Space")
plt.show()

In [None]:
# Plot residuals: per-pixel mean absolute reconstruction error, computed from the
# residuals_sum and number_images accumulated by the evaluation cell run just before.
residuals_mean = residuals_sum / number_images
residuals_mean_np = residuals_mean.cpu().numpy().squeeze()  # (28, 28) The squeeze is used to delete channel dim

plt.imshow(residuals_mean_np, cmap='hot')
plt.title('Residuals')
plt.colorbar()
plt.show()

In [None]:
# Makes some plots
def plot_test(cvae):
    """Plot two test images (top row) above their reconstructions (bottom row).

    The first two batches of ``test_loader`` are passed through ``cvae`` and
    the same randomly drawn position is displayed for both batches.

    Args:
        cvae: model called as ``cvae(data, one_hot_labels)`` and returning a
            tuple whose first element is the reconstructed batch.
    """
    fig, axs = plt.subplots(nrows=2, ncols=2, figsize=(6,6))
    class_names = list(CLASSES_TO_IDX.keys())

    loader_iter = iter(test_loader)
    batches = []
    # Plotting only: no gradient graph is needed for these forward passes.
    with torch.no_grad():
        for _ in range(2):
            data, labels = next(loader_iter)
            data, labels = data.to(DEVICE), labels.to(DEVICE)
            labels = one_hot(labels, CLASS_SIZE)
            recon, _, _ = cvae(data, labels)
            # Moving back to cpu
            batches.append((data.cpu(), labels.cpu(), recon.detach().cpu()))

    # Draw the index from the real batch sizes so that a batch shorter than
    # BATCH_SIZE cannot trigger an IndexError.
    smallest_batch = min(data.size(0) for data, _, _ in batches)
    i = random.randint(0, smallest_batch - 1)

    # Plots: one column per batch
    for col, (data, labels, recon) in enumerate(batches):
        axs[0, col].imshow(data[i].squeeze(), cmap='gray')
        axs[0, col].set_title(f"{class_names[labels[i].argmax()]}")
        axs[1, col].imshow(recon[i].squeeze(), cmap='gray')

    plt.tight_layout()
    plt.show()
plot_test(cvae)

### Another loss to avoid overlapping classes

In [None]:
def triplet_loss(mu, labels, margin=1.0):
    """Triplet loss on the latent means.

    For every sample used as an anchor, three points are involved:
      - anchor: the latent mean of the sample itself;
      - positive: the latent mean of ANOTHER sample of the same class;
      - negative: the latent mean of a sample of a different class.
    The loss pulls the anchor towards the positive and pushes it away from
    the negative, which can help to separate the classes.

    The positive is drawn among same-class samples excluding the anchor
    itself, so an anchor is never paired with itself (which would give a zero
    positive distance).

    Args:
        mu: 2-D tensor with one latent mean per row.
        labels: 1-D tensor of class indices aligned with the rows of ``mu``.
        margin: margin of the hinge ``relu(pos_dist - neg_dist + margin)``.

    Returns:
        Scalar tensor: sum of the hinge terms over the anchors that have both
        a positive and a negative, divided by the batch size. Zero when the
        batch is empty or contains no valid triplet.
    """
    batch_size = mu.size(0)
    if batch_size == 0:
        return mu.new_zeros(())

    labels = labels.to(mu.device)
    same_class = labels.unsqueeze(0) == labels.unsqueeze(1)
    not_self = ~torch.eye(batch_size, dtype=torch.bool, device=mu.device)
    pos_mask = same_class & not_self   # candidates for the positive of each row
    neg_mask = ~same_class             # candidates for the negative of each row

    # Only anchors with at least one positive and one negative contribute.
    anchors = (pos_mask.any(dim=1) & neg_mask.any(dim=1)).nonzero().view(-1)
    if anchors.numel() == 0:
        return mu.new_zeros(())

    # Draw one positive and one negative uniformly among the candidates of each anchor.
    pos_index = torch.multinomial(pos_mask[anchors].float(), 1).view(-1)
    neg_index = torch.multinomial(neg_mask[anchors].float(), 1).view(-1)

    # Triplet loss calculation (squared Euclidean distances)
    anchor_mu = mu[anchors]
    pos_dist = torch.sum((anchor_mu - mu[pos_index]) ** 2, dim=1)
    neg_dist = torch.sum((anchor_mu - mu[neg_index]) ** 2, dim=1)
    return torch.relu(pos_dist - neg_dist + margin).sum() / batch_size

In [None]:
def new_loss(recon_x, x, mu, logvar, labels, class_centers, beta=1, lambda_center=0.1, lambda_triplet=5):
    """VAE loss extended with a center loss and a triplet loss on the latent means.

    Args:
        recon_x: reconstructed batch; flattened per sample before comparison.
        x: target batch; flattened per sample before comparison.
        mu: latent means, one row per sample.
        logvar: latent log-variances.
        labels: one-hot labels, one row per sample (the class is its argmax).
        class_centers: tensor indexed by class that gives one latent center per class.
        beta: weight of the KL term.
        lambda_center: weight of the center loss.
        lambda_triplet: weight of the triplet loss.

    Returns:
        Scalar tensor ``BCE + beta * KLD + lambda_center * center + lambda_triplet * triplet``.
        BCE and KLD are summed over the batch, whereas the center and triplet
        terms are divided by the batch size.
    """
    # Reconstruction loss
    recon_x = recon_x.view(recon_x.size(0), -1)
    x = x.view(x.size(0), -1)
    BCE = F.binary_cross_entropy(recon_x, x, reduction='sum')

    # KL Divergence
    KLD = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())

    # Center loss: mean squared distance between each latent mean and the
    # center of its class, computed for the whole batch at once.
    class_ids = labels.argmax(dim=1)
    batch_size = mu.size(0)
    center_loss = torch.sum((mu - class_centers[class_ids]) ** 2) / batch_size

    # Triplet loss for better separation
    triplet_loss_value = triplet_loss(mu, class_ids)

    # Total loss with center loss and triplet loss
    return BCE + beta * KLD + lambda_center * center_loss + lambda_triplet * triplet_loss_value

In [None]:
# Train a second CVAE with new_loss (reconstruction + KL + center loss + triplet loss),
# track the per-sample train/validation loss of every epoch, plot both curves and save
# the weights to their own file.
cvae_new_loss = CVAE(FEATURE_SIZE, LATENT_SIZE, CLASS_SIZE)
cvae_new_loss.to(DEVICE)
optimizer = optim.Adam(cvae_new_loss.parameters(), lr=LEARNING_RATE)
# StepLR scheduler (multiplies the learning rate by gamma, here 0.1, every 10 epochs)
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=10, gamma=0.1)
# NOTE: this rebinds the global MAX_EPOCHS, which also affects earlier cells if they are re-run.
MAX_EPOCHS = 10
def get_lr(optimizer):
    """Return the 'lr' of the first parameter group (kept here so the cell is self-contained)."""
    for param_group in optimizer.param_groups:
        return param_group['lr']


# NOTE(review): this expression always evaluates to 0.01. The original comment said the
# value should increase progressively up to a target beta, but no epoch-dependent term
# is present; confirm the intended schedule.
beta = min(0.01 , 0.2 )

train_loss_triplet = torch.zeros(MAX_EPOCHS)
val_loss_triplet = torch.zeros(MAX_EPOCHS)
for epoch in range(1, MAX_EPOCHS + 1):
    # Print the current learning rate
    current_lr = get_lr(optimizer)
    print(f"Epoch {epoch}/{MAX_EPOCHS} - Learning Rate: {current_lr:.6f}")

    # The triplet weight only depends on the epoch, so it is computed once per epoch
    # instead of once per batch: it grows by 0.5 per epoch and is capped at 8.
    lambda_triplet = min(2+ epoch * 0.5, 8)

    # Training part
    cvae_new_loss.train()
    running_train_loss = 0.0
    t_train = tqdm(train_loader, desc=f"Epoch {epoch}/{MAX_EPOCHS} Training")
    for data, labels in t_train:
        data, labels = data.to(DEVICE), labels.to(DEVICE)
        labels = one_hot(labels, CLASS_SIZE)
        recon_batch, mu, logvar = cvae_new_loss(data, labels)
        optimizer.zero_grad()
        loss = new_loss(recon_batch, data, mu, logvar, labels, cvae_new_loss.class_centers, beta=beta, lambda_triplet=lambda_triplet)
        loss.backward()
        optimizer.step()

        running_train_loss += loss.item()
    train_loss = running_train_loss / len(train_loader.dataset)

    # Validation part
    # NOTE: validation uses the default weights of new_loss (beta=1, lambda_triplet=5), so
    # its values are not on the same scale as the training loss above.
    cvae_new_loss.eval()  # Set model to evaluation mode
    with torch.no_grad():  # Disable gradient calculation for validation
        running_val_loss = 0.0
        t_val = tqdm(val_loader, desc=f"Epoch {epoch}/{MAX_EPOCHS} Validation")
        for data, labels in t_val:
            data, labels = data.to(DEVICE), labels.to(DEVICE)
            labels = one_hot(labels, CLASS_SIZE)
            recon_batch, mu, logvar = cvae_new_loss(data, labels)
            loss = new_loss(recon_batch, data, mu, logvar, labels, cvae_new_loss.class_centers)

            running_val_loss += loss.item()

    val_loss = running_val_loss / len(val_loader.dataset)

    val_loss_triplet[epoch-1] = val_loss
    train_loss_triplet[epoch-1] = train_loss

    print(f'Epoch {epoch}, Training loss: {train_loss:.4f}, Validation loss: {val_loss:.4f}')
    # Step the scheduler
    scheduler.step()

plt.plot(train_loss_triplet)
plt.title("Train loss with the Triplet Loss")
plt.show()
plt.plot(val_loss_triplet)
plt.title("Validation loss with the Triplet Loss")
plt.show()

# Saving part (if needed). A dedicated file name is used so that the baseline checkpoint
# 'cvae_weights.pth', which the baseline evaluation cell reloads, is not overwritten.
torch.save(cvae_new_loss.state_dict(), 'cvae_new_loss_weights.pth')

In [None]:
# Evaluate the in-memory cvae_new_loss model on the test set: accumulate the test loss
# (new_loss with its default weights), the squared error and the absolute residuals,
# and scatter the latent means.
# NOTE(review): cvae_new_loss.eval() is not called in this cell; the model is in the mode
# left by the last cell that touched it.
# Plot
fig, ax = plt.subplots()
# Empty placeholder scatter; it is replaced by the per-batch scatter calls in the loop
scatter = ax.scatter([], [], c=[], cmap='tab10', vmin=0, vmax=9)

with torch.no_grad():
    running_test_loss = 0.0
    total_mse = 0.0
    number_images = 0
    residuals_sum = 0  # becomes a tensor after the first in-place addition below
    t_test = tqdm(test_loader, desc="Testing...")
    for data, labels in t_test:
        data, labels = data.to(DEVICE), labels.to(DEVICE)
        labels = one_hot(labels, CLASS_SIZE)
        recon_batch, mu, logvar = cvae_new_loss(data, labels) 
        loss = new_loss(recon_batch, data, mu, logvar, labels, cvae_new_loss.class_centers)
        # Squared error summed over every pixel of the batch
        mse_value = F.mse_loss(recon_batch, data, reduction='sum') 
        total_mse += mse_value.item()
        residuals_sum += torch.abs(data - recon_batch).sum(dim=0) # (1, 28, 28) where 1 is the channel number
        number_images += labels.size(0)

        running_test_loss += loss.item()

        # Plot part: only the first two latent coordinates of mu are drawn, colored by class index
        mu, labels = mu.cpu(), labels.cpu()
        scatter = ax.scatter(mu[:,0], mu[:,1], c=[labels.argmax(dim=1)], cmap='tab10', vmin=0, vmax=9)        
        
# Per-image averages (the squared error is summed over pixels, then divided by the image count)
test_loss = running_test_loss / len(test_loader.dataset)
MSE = total_mse / number_images
print(f"Test loss: {test_loss} -- MSE: {MSE}")

plt.colorbar(scatter)
plt.title("Latent Space")
plt.show()

In [None]:
plot_test(cvae_new_loss)