# Toy experiment

In [None]:
import sys
import os
# Append the parent directory (relative to the working directory) to the module
# search path before importing the package below.
# NOTE(review): presumably this is where `forget_me_not` lives — verify against the repo layout.
sys.path.append('../')

import forget_me_not 

### Set the parameters here

In [None]:
# ---------------------------------------------------------------------------
# Dataset: Gaussian mixture (forwarded to GaussianMixtureDataModule)
# ---------------------------------------------------------------------------
GAUSSIAN_MIXTURE_CLASSES = 10   # forwarded as n_classes
GAUSSIAN_MIXTURE_DIM = 16       # forwarded as n_features; also the VAE input dim
NUM_SAMPLES_PER_CLASS = 20_000  # forwarded as n_samples
# NOTE(review): TRAIN_FRACTION + EVAL_FRACTION == 1.0 — confirm how the data
# module derives the test split that later cells plot and evaluate on.
TRAIN_FRACTION = 0.9
EVAL_FRACTION = 0.1

# ---------------------------------------------------------------------------
# Model
# ---------------------------------------------------------------------------
LATENT_DIM = 4
HIDDEN_DIM = 8
BETA = 10.0     # passed as `beta` for the 'vanilla-beta-vae' loss
LAMBDA = 10.0   # passed as `beta` for the 'self-critic' and 'nn-critic' losses

# ---------------------------------------------------------------------------
# Optimisation
# ---------------------------------------------------------------------------
ACCELERATOR = 'cpu'
MAX_NUM_EPOCHS = 30
BATCH_SIZE = 1024
LEARNING_RATE = 2e-4

## Gaussian mixture dataset

In [None]:
from forget_me_not.datasets.gaussian_mixture import GaussianMixtureDataModule

# Build the Gaussian-mixture data module from the configuration constants.
dm = GaussianMixtureDataModule(
    # Shape of the mixture
    n_classes=GAUSSIAN_MIXTURE_CLASSES,
    n_features=GAUSSIAN_MIXTURE_DIM,
    n_samples=NUM_SAMPLES_PER_CLASS,
    # Ranges handed to the generator for component means and variances
    mean_scale=(0, 40),
    variance_scale=(0, 8),
    # Split sizes
    train_fraction=TRAIN_FRACTION,
    eval_fraction=EVAL_FRACTION,
    # Fixed seed passed to the data module
    seed=1,
)

# Plot each split in turn: train, test, eval.
dm.plot_train()
dm.plot_test()
dm.plot_eval()


## Metrics

In [None]:
from forget_me_not import metrics 
from functools import partial
import torch

# Keyword arguments for each metric, keyed by metric name. The whole mapping is
# handed to `metrics.compute_metrics`, and individual entries are unpacked into
# the per-batch monitoring metrics.
metric_and_its_params = {
    "negative_log_likelihood": {"dim": GAUSSIAN_MIXTURE_DIM, "num_importance_sampling": 500},
    "active_units": {},
    "mutual_information": {"num_samples": 1000},
    "density_and_coverage": {"nearest_k": 5},
}

def add_monitoring_metrics(model):
    """Register the validation-time monitoring metrics on ``model``.

    Two metrics are attached through ``model.add_additional_monitoring_metric``,
    both under the ``'validation'`` stage and both with ``timeit=True``:

    * ``'NLL'`` -- ``compute_negative_log_likelihood_for_batch`` with the
      ``'negative_log_likelihood'`` entry of ``metric_and_its_params``.
    * ``'AU'``  -- ``active_units_for_batch`` with the ``'active_units'`` entry,
      aggregated with ``torch.mean(..., dtype=torch.float32)``.

    Args:
        model: Object exposing ``add_additional_monitoring_metric``
            (the notebook passes ``BetaVAEModule`` instances).
    """
    # Import under a function-local alias so this helper keeps working even if
    # the notebook-global name `metrics` is rebound (e.g. to a results dict)
    # between this definition and a later call.
    from forget_me_not import metrics as metrics_module

    nll_for_batch = partial(
        metrics_module.compute_negative_log_likelihood_for_batch,
        **metric_and_its_params['negative_log_likelihood']
    )
    active_units_for_batch = partial(
        metrics_module.active_units_for_batch,
        **metric_and_its_params['active_units']
    )

    model.add_additional_monitoring_metric('validation', 'NLL', nll_for_batch, timeit=True)
    model.add_additional_monitoring_metric(
        'validation',
        'AU',
        active_units_for_batch,
        timeit=True,
        # Force a float32 mean when aggregating the per-batch AU values.
        agg_func=partial(torch.mean, dtype=torch.float32),
    )

## Training $\beta$-VAE 

In [None]:
from forget_me_not.models.vae import VAE
from forget_me_not.training.train_beta_vae import BetaVAEModule, train

# Build the VAE and wrap it in the training module with the vanilla β-VAE loss.
vae_model = VAE(
    dim=GAUSSIAN_MIXTURE_DIM,
    hidden_dim=HIDDEN_DIM,
    latent_dim=LATENT_DIM,
)
model = BetaVAEModule(
    vae_model,
    loss='vanilla-beta-vae',
    beta=BETA,
    learning_rate=LEARNING_RATE,
)
add_monitoring_metrics(model)

# NOTE(review): validation uses batch_size=None here, while the self-critic and
# NN-critic sections pass BATCH_SIZE — confirm the difference is intentional.
val_data_loader = dm.val_dataloader(batch_size=None)
train_data_loader = dm.train_dataloader(batch_size=BATCH_SIZE)

train(
    model,
    train_data_loader,
    val_data_loader,
    num_epochs=MAX_NUM_EPOCHS,
    accelerator=ACCELERATOR,
    enable_progress_bar=True,
    early_stop=True,
)

### Metrics on the test set

In [None]:
# Evaluate the trained β-VAE on the test split, batched with BATCH_SIZE.
test_data_loader = dm.test_dataloader(batch_size=BATCH_SIZE)

# Store the results under their own name. Assigning them to `metrics` would
# replace the imported `forget_me_not.metrics` module with a dict and break
# every later `metrics.<function>` lookup in the notebook.
beta_vae_test_metrics = metrics.compute_metrics(vae_model, test_data_loader, metric_and_its_params)
for metric_name, metric_value in beta_vae_test_metrics.items():
    print(f"{metric_name}: {metric_value}")

### PCA on encodings of the test set

In [None]:
# Plot the β-VAE's latent encodings and reconstructions for the test split
# (as the helper's name suggests; see forget_me_not.plots for the details).
from forget_me_not.plots import plot_latent_and_reconstruction
# batch_size=None here, unlike the BATCH_SIZE loader used for the metrics.
# NOTE(review): presumably this yields the whole split at once — verify against the data module.
test_data_loader = dm.test_dataloader(batch_size=None)
plot_latent_and_reconstruction(vae_model, test_data_loader)

In [None]:
# Remove the notebook-level name for the β-VAE before the next experiment.
# NOTE(review): `model` was constructed from `vae_model` and may still reference
# it — confirm whether this actually releases the network's memory.
del vae_model

# Self-critic VAE

In [None]:
from forget_me_not.models.vae import VAE
from forget_me_not.training.train_beta_vae import BetaVAEModule, train
# Fresh VAE with the same architecture, trained with the self-critic loss.
vae_model_sc = VAE(
    dim=GAUSSIAN_MIXTURE_DIM,
    hidden_dim=HIDDEN_DIM,
    latent_dim=LATENT_DIM,
)
model_sc = BetaVAEModule(
    vae_model_sc,
    loss='self-critic',
    beta=LAMBDA,  # the module's `beta` argument carries LAMBDA for this loss
    learning_rate=LEARNING_RATE,
)
add_monitoring_metrics(model_sc)

# Both loaders are batched with BATCH_SIZE in this section.
val_data_loader = dm.val_dataloader(batch_size=BATCH_SIZE)
train_data_loader = dm.train_dataloader(batch_size=BATCH_SIZE)

train(
    model_sc,
    train_data_loader,
    val_data_loader,
    num_epochs=MAX_NUM_EPOCHS,
    accelerator=ACCELERATOR,
    enable_progress_bar=True,
    early_stop=True,
)

In [None]:
# Evaluate the trained self-critic VAE on the test split, batched with BATCH_SIZE.
test_data_loader = dm.test_dataloader(batch_size=BATCH_SIZE)

# Store the results under their own name. Assigning them to `metrics` would
# replace the imported `forget_me_not.metrics` module with a dict and break
# every later `metrics.<function>` lookup in the notebook.
self_critic_test_metrics = metrics.compute_metrics(vae_model_sc, test_data_loader, metric_and_its_params)
for metric_name, metric_value in self_critic_test_metrics.items():
    print(f"{metric_name}: {metric_value}")

In [None]:
# Plot the self-critic VAE's latent encodings and reconstructions for the test
# split (as the helper's name suggests; see forget_me_not.plots for the details).
from forget_me_not.plots import plot_latent_and_reconstruction
# batch_size=None here, unlike the BATCH_SIZE loader used for the metrics.
# NOTE(review): presumably this yields the whole split at once — verify against the data module.
test_data_loader = dm.test_dataloader(batch_size=None)
plot_latent_and_reconstruction(vae_model_sc, test_data_loader)

In [None]:
# Remove the notebook-level name for the self-critic VAE before the next experiment.
# NOTE(review): `model_sc` was constructed from `vae_model_sc` and may still
# reference it — confirm whether this actually releases the network's memory.
del vae_model_sc

# NN-critic VAE

In [None]:
# Critic-network sizes, forwarded to CriticNetwork under the matching keyword names.
HIDDEN_DIM_X = 12   # hidden_dim_x
HIDDEN_DIM_Z = 6    # hidden_dim_z
CONTRAST_DIM = 8    # contrast_dim

In [None]:
from forget_me_not.models.vae import VAEWithCriticNetwork, CriticNetwork
from forget_me_not.training.train_beta_vae import BetaVAEModule, train

# Critic network built from the data/latent sizes and the critic-specific sizes.
critic_network = CriticNetwork(
    dim=GAUSSIAN_MIXTURE_DIM,
    latent_dim=LATENT_DIM,
    contrast_dim=CONTRAST_DIM,
    hidden_dim_x=HIDDEN_DIM_X,
    hidden_dim_z=HIDDEN_DIM_Z,
)

# VAE that takes the critic network, trained with the NN-critic loss.
vae_model_nnc = VAEWithCriticNetwork(
    critic_network,
    dim=GAUSSIAN_MIXTURE_DIM,
    hidden_dim=HIDDEN_DIM,
    latent_dim=LATENT_DIM,
)
model_nnc = BetaVAEModule(
    vae_model_nnc,
    loss='nn-critic',
    beta=LAMBDA,  # the module's `beta` argument carries LAMBDA for this loss
    learning_rate=LEARNING_RATE,
)
add_monitoring_metrics(model_nnc)

# Both loaders are batched with BATCH_SIZE in this section.
val_data_loader = dm.val_dataloader(batch_size=BATCH_SIZE)
train_data_loader = dm.train_dataloader(batch_size=BATCH_SIZE)

train(
    model_nnc,
    train_data_loader,
    val_data_loader,
    num_epochs=MAX_NUM_EPOCHS,
    accelerator=ACCELERATOR,
    enable_progress_bar=True,
    early_stop=True,
)

In [None]:
# Evaluate the trained NN-critic VAE on the test split, batched with BATCH_SIZE.
test_data_loader = dm.test_dataloader(batch_size=BATCH_SIZE)

# Store the results under their own name. Assigning them to `metrics` would
# replace the imported `forget_me_not.metrics` module with a dict and break
# every later `metrics.<function>` lookup in the notebook.
nn_critic_test_metrics = metrics.compute_metrics(vae_model_nnc, test_data_loader, metric_and_its_params)
for metric_name, metric_value in nn_critic_test_metrics.items():
    print(f"{metric_name}: {metric_value}")

In [None]:
# Plot the NN-critic VAE's latent encodings and reconstructions for the test
# split (as the helper's name suggests; see forget_me_not.plots for the details).
from forget_me_not.plots import plot_latent_and_reconstruction
# batch_size=None here, unlike the BATCH_SIZE loader used for the metrics.
# NOTE(review): presumably this yields the whole split at once — verify against the data module.
test_data_loader = dm.test_dataloader(batch_size=None)
plot_latent_and_reconstruction(vae_model_nnc, test_data_loader)

In [None]:
# Remove the notebook-level name for the NN-critic VAE.
# NOTE(review): `model_nnc` was constructed from `vae_model_nnc` and may still
# reference it — confirm whether this actually releases the network's memory.
del vae_model_nnc