# Toy experiment

In [None]:
import sys
import os
sys.path.append('../')

import forget_me_not 

### Set the parameters here

In [None]:
# Dataset parameters (all forwarded to GaussianMixtureDataModule)
TRAIN_FRACTION = 0.9             # passed as train_fraction
EVAL_FRACTION = 0.1              # passed as eval_fraction
NUM_SAMPLES_PER_CLASS = 20_000   # passed as n_samples
GAUSSIAN_MIXTURE_DIM = 16        # passed as n_features; also the `dim` of every VAE below
GAUSSIAN_MIXTURE_CLASSES = 10    # passed as n_classes

# Model hyperparameters
HIDDEN_DIM = 32    # hidden_dim of every VAE built in this notebook
LATENT_DIM = 4     # latent_dim of every VAE (and of the critic network)
BETA = 20.0        # beta of the vanilla beta-VAE run only; the critic runs hard-code beta=10.0

# Training settings (shared by all three training runs)
LEARNING_RATE = 2e-4    # learning_rate given to BetaVAEModule
BATCH_SIZE = 1024       # batch_size of the train/val/test data loaders
MAX_NUM_EPOCHS = 9      # num_epochs given to train()

## Gaussian mixture dataset

In [None]:
from forget_me_not.datasets.gaussian_mixture import GaussianMixtureDataModule
# Build the Gaussian-mixture data module from the constants in the parameter cell.
# NOTE(review): train_fraction + eval_fraction sum to 1.0 here; how the test split
# used by plot_test()/test_dataloader() is obtained is not visible in this notebook — confirm.
dm = GaussianMixtureDataModule(
    # size of the problem
    n_samples=NUM_SAMPLES_PER_CLASS,
    n_features=GAUSSIAN_MIXTURE_DIM,
    n_classes=GAUSSIAN_MIXTURE_CLASSES,
    # split fractions
    train_fraction=TRAIN_FRACTION,
    eval_fraction=EVAL_FRACTION,
    # presumably the ranges used when generating per-class means/variances — TODO confirm
    mean_scale=(0, 40),
    variance_scale=(0, 8),
    # fixed seed passed to the data module
    seed=1,
)

# Show each split in turn: train, test, eval.
dm.plot_train()
dm.plot_test()
dm.plot_eval()


## Training $\beta$-VAE 

In [None]:
from forget_me_not.models.vae import VAE
from forget_me_not.training.train_beta_vae import BetaVAEModule, train
# Vanilla beta-VAE: a VAE network wrapped in a BetaVAEModule configured with the
# 'vanilla-beta-vae' loss and beta=BETA.
vae_model = VAE(
    dim=GAUSSIAN_MIXTURE_DIM,
    hidden_dim=HIDDEN_DIM,
    latent_dim=LATENT_DIM,
)
model = BetaVAEModule(
    vae_model,
    loss='vanilla-beta-vae',
    beta=BETA,
    learning_rate=LEARNING_RATE,
)

# Validation and training loaders share the same batch size.
val_data_loader = dm.val_dataloader(batch_size=BATCH_SIZE)
train_data_loader = dm.train_dataloader(batch_size=BATCH_SIZE)

# NOTE(review): early stopping is disabled for this run, whereas the self-critic and
# NN-critic runs later in the notebook pass early_stop=True — confirm the asymmetry is intended.
train(
    model,
    train_data_loader,
    val_data_loader,
    num_epochs=MAX_NUM_EPOCHS,
    accelerator='cpu',
    enable_progress_bar=True,
    early_stop=False,
)

## Metrics

In [None]:
from forget_me_not import metrics 

def print_metrics(vae_model, dm):
    test_data_loader = dm.test_dataloader(batch_size=BATCH_SIZE)
    nll = metrics.compute_negative_log_likelihood(vae_model, test_data_loader, dim=GAUSSIAN_MIXTURE_DIM, num_importance_sampling=500)
    print(f"Negative log likelihood: {nll}")

    au = metrics.active_units(vae_model, test_data_loader)
    print(f"Active units: {au} out of {LATENT_DIM}")

    mi = metrics.mutual_information(vae_model, test_data_loader, num_samples=1000)
    print(f"Mutual information: {mi}")

    dc = metrics.compute_density_and_coverage(vae_model, test_data_loader, num_samples=len(test_data_loader.dataset), nearest_k = 5)
    print(f"Density: {dc['density']}, Coverage: {dc['coverage']}")

In [None]:
# Print the test-loader metrics for the vanilla beta-VAE network trained above.
print_metrics(vae_model, dm)

## PCA on encodings and reconstructions of the test set

In [None]:
import matplotlib.pyplot as plt
import torch
from sklearn.decomposition import PCA

def _scatter_first_two_pcs(values, labels, title):
    """Scatter-plot ``values`` projected onto their first two principal components.

    Args:
        values: Tensor (or array-like) holding one row per sample; it is converted
            to a CPU NumPy array before being handed to PCA.
        labels: Tensor (or array-like) with one label per sample, used as the
            colour value of each point. It is flattened to 1-D.
        title: Title placed on the figure.
    """
    # Convert explicitly so scikit-learn and matplotlib always receive plain CPU
    # NumPy arrays, regardless of the device the tensors live on.
    points = torch.as_tensor(values).detach().cpu().numpy()
    colors = torch.as_tensor(labels).detach().cpu().numpy().ravel()

    projected = PCA(n_components=2).fit_transform(points)

    fig, ax = plt.subplots()
    ax.scatter(projected[:, 0], projected[:, 1], c=colors, marker='.')
    ax.set_xlabel("PC 1")
    ax.set_ylabel("PC 2")
    ax.set_title(title)
    plt.show()


def plot_latent_representation_2d(vae_model, samples, labels):
    """Plot a 2-D PCA projection of the latent representation of ``samples``.

    The model is switched to eval mode and queried under ``torch.no_grad()`` via
    ``get_latent_representation(samples, deterministic=False)``.

    Args:
        vae_model: Model exposing ``eval()`` and ``get_latent_representation``.
        samples: Batch of inputs passed to the model.
        labels: One label per sample; used to colour the points.
    """
    vae_model.eval()
    with torch.no_grad():
        latent_rep = vae_model.get_latent_representation(samples, deterministic=False)

    _scatter_first_two_pcs(latent_rep, labels, title="PCA of latent representations")


def plot_reconstruction_2d(vae_model, samples, labels):
    """Plot a 2-D PCA projection of the model's reconstruction of ``samples``.

    The model is switched to eval mode and run under ``torch.no_grad()`` via
    ``forward(samples, deterministic=False)``; the second element of the returned
    sequence is taken as the reconstruction.

    Args:
        vae_model: Model exposing ``eval()`` and ``forward``.
        samples: Batch of inputs passed to the model.
        labels: One label per sample; used to colour the points.
    """
    vae_model.eval()
    with torch.no_grad():
        _, recon, *_ = vae_model.forward(samples, deterministic=False)

    _scatter_first_two_pcs(recon, labels, title="PCA of reconstructions")
    
def plot_latent_and_reconstruction(vae_model, dm):
    test_data_loader = dm.test_dataloader(batch_size=None)
    vae_model.eval()
    with torch.no_grad():
        data, labels = next(iter(test_data_loader))
        plot_latent_representation_2d(vae_model, data, labels)
        plot_reconstruction_2d(vae_model, data, labels)


In [None]:
# PCA plots (latent representation, then reconstruction) for the vanilla beta-VAE network.
plot_latent_and_reconstruction(vae_model, dm)

In [None]:
# Remove the notebook-level name of the vanilla beta-VAE network.
# NOTE(review): `model` was constructed from this object and may still reference it,
# so this alone may not release its memory — confirm if freeing memory is the goal.
del vae_model

# Self-critic VAE

In [None]:
from forget_me_not.models.vae import VAE
from forget_me_not.training.train_beta_vae import BetaVAEModule, train
# Self-critic variant: same VAE architecture, wrapped in a BetaVAEModule configured
# with the 'self-critic' loss.
vae_model_sc = VAE(
    dim=GAUSSIAN_MIXTURE_DIM,
    hidden_dim=HIDDEN_DIM,
    latent_dim=LATENT_DIM,
)
# NOTE(review): beta is hard-coded to 10.0 here, while the vanilla run uses the
# BETA constant (20.0) — confirm the difference is intended.
model_sc = BetaVAEModule(
    vae_model_sc,
    loss='self-critic',
    beta=10.0,
    learning_rate=LEARNING_RATE,
)

# Fresh loaders for this run; these rebind the notebook-level loader names.
val_data_loader = dm.val_dataloader(batch_size=BATCH_SIZE)
train_data_loader = dm.train_dataloader(batch_size=BATCH_SIZE)

# Early stopping is enabled for this run.
train(
    model_sc,
    train_data_loader,
    val_data_loader,
    num_epochs=MAX_NUM_EPOCHS,
    accelerator='cpu',
    enable_progress_bar=True,
    early_stop=True,
)

In [None]:
# Print the test-loader metrics for the self-critic VAE network.
print_metrics(vae_model_sc, dm)

In [None]:
# PCA plots (latent representation, then reconstruction) for the self-critic VAE network.
plot_latent_and_reconstruction(vae_model_sc, dm)

In [None]:
# Remove the notebook-level name of the self-critic VAE network.
# NOTE(review): `model_sc` was constructed from this object and may still reference it,
# so this alone may not release its memory — confirm if freeing memory is the goal.
del vae_model_sc

# NN Critic

In [None]:
# Sizes forwarded to CriticNetwork as contrast_dim, hidden_dim_x and hidden_dim_z respectively.
CONTRAST_DIM = 8
HIDDEN_DIM_X = 12
HIDDEN_DIM_Z = 16

In [None]:
from forget_me_not.models.vae import VAEWithCriticNetwork, CriticNetwork
from forget_me_not.training.train_beta_vae import BetaVAEModule, train

# Critic network sized by the constants from the previous cell.
critic_network = CriticNetwork(
    dim=GAUSSIAN_MIXTURE_DIM,
    latent_dim=LATENT_DIM,
    contrast_dim=CONTRAST_DIM,
    hidden_dim_x=HIDDEN_DIM_X,
    hidden_dim_z=HIDDEN_DIM_Z,
)

# VAE that carries the critic network, wrapped in a BetaVAEModule configured with
# the 'nn-critic' loss.
vae_model_nnc = VAEWithCriticNetwork(
    critic_network,
    dim=GAUSSIAN_MIXTURE_DIM,
    hidden_dim=HIDDEN_DIM,
    latent_dim=LATENT_DIM,
)
# NOTE(review): beta is hard-coded to 10.0 here, while the vanilla run uses the
# BETA constant (20.0) — confirm the difference is intended.
model_nnc = BetaVAEModule(
    vae_model_nnc,
    loss='nn-critic',
    beta=10.0,
    learning_rate=LEARNING_RATE,
)

# Fresh loaders for this run; these rebind the notebook-level loader names.
val_data_loader = dm.val_dataloader(batch_size=BATCH_SIZE)
train_data_loader = dm.train_dataloader(batch_size=BATCH_SIZE)

# Early stopping is enabled for this run.
train(
    model_nnc,
    train_data_loader,
    val_data_loader,
    num_epochs=MAX_NUM_EPOCHS,
    accelerator='cpu',
    enable_progress_bar=True,
    early_stop=True,
)

In [None]:
# Print the test-loader metrics for the VAE trained with the NN critic.
print_metrics(vae_model_nnc, dm)

In [None]:
# PCA plots (latent representation, then reconstruction) for the VAE trained with the NN critic.
plot_latent_and_reconstruction(vae_model_nnc, dm)

In [None]:
# Remove the notebook-level name of the NN-critic VAE network.
# NOTE(review): `model_nnc` was constructed from this object and may still reference it,
# so this alone may not release its memory — confirm if freeing memory is the goal.
del vae_model_nnc