# Omniglot experiments

In [None]:
import sys
import os
sys.path.append('../')

import forget_me_not 

### Set the parameters here

In [None]:
# --- Dataset split proportions (forwarded to OmniglotDataModule) ---
TRAIN_FRACTION = 0.9
EVAL_FRACTION = 0.1

# --- Model hyperparameters ---
HIDDEN_DIM = 1_024   # hidden_dim of CNNVAE and dim of CNNDecoder
LATENT_DIM = 128     # latent_dim of CNNVAE
BETA = 10.0          # `beta` for the 'vanilla-beta-vae' loss
LAMBDA = 10.0        # `beta` for the 'self-critic' and 'nn-critic' losses

# --- Optimisation / trainer settings ---
LEARNING_RATE = 2e-4
BATCH_SIZE = 1_024
MAX_NUM_EPOCHS = 30
ACCELERATOR = 'cpu'

# --- Miscellaneous ---
REPORT_ROOT_DIR = None   # when None, the plotting cells pass report_dir=None
PBAR = True              # forwarded as enable_progress_bar

## Omniglot dataset

In [None]:
from forget_me_not.datasets.omniglot import OmniglotDataModule

# Image height/width used by this notebook (e.g. for the NLL `dim` metric parameter).
IMG_DIM = (64, 64)

# Build the Omniglot data module with the configured split proportions.
dm = OmniglotDataModule(
    data_dir='./dataset_store/omniglot/',
    train_fraction=TRAIN_FRACTION,
    eval_fraction=EVAL_FRACTION,
)

# Prepare the training/validation stage first, then the test stage.
dm.setup('fit')
dm.setup('test')

## Metrics

In [None]:
from forget_me_not import metrics 
from functools import partial
import torch

# Keyword arguments for each metric, keyed by metric name; the whole mapping is
# handed to metrics.compute_metrics, and two entries are reused for monitoring.
metric_and_its_params = dict(
    negative_log_likelihood=dict(
        dim=IMG_DIM[0] * IMG_DIM[1],   # product of the two image dimensions
        num_importance_sampling=500,
    ),
    active_units=dict(),
    mutual_information=dict(num_samples=1000),
    density_and_coverage=dict(nearest_k=5),
)

def add_monitoring_metrics(model):
    """Register NLL and active-units (AU) as extra validation metrics on ``model``.

    Both metrics are registered with ``timeit=True``. AU values are combined
    with ``torch.mean`` in float32 via ``agg_func``; no ``agg_func`` is passed
    for NLL, so whatever default ``add_additional_monitoring_metric`` uses applies.
    """
    nll_for_batch = partial(
        metrics.compute_negative_log_likelihood_for_batch,
        **metric_and_its_params['negative_log_likelihood'],
    )
    model.add_additional_monitoring_metric('validation', 'NLL', nll_for_batch, timeit=True)

    au_for_batch = partial(
        metrics.active_units_for_batch,
        **metric_and_its_params['active_units'],
    )
    float32_mean = partial(torch.mean, dtype=torch.float32)
    model.add_additional_monitoring_metric(
        'validation', 'AU', au_for_batch, timeit=True, agg_func=float32_mean
    )

## Training $\beta$-VAE 

In [None]:
from forget_me_not.models.cnn_vae import CNNVAE, CNNEncoder, CNNDecoder
from forget_me_not.training.train_beta_vae import BetaVAEModule, train


# Baseline experiment: a CNN VAE trained with the vanilla beta-VAE objective.
vae_model = CNNVAE(
    img_encoder = CNNEncoder(num_channels=1),
    img_decoder = CNNDecoder(num_channels=1, dim=HIDDEN_DIM),
    hidden_dim = HIDDEN_DIM,
    latent_dim = LATENT_DIM,
) 

model = BetaVAEModule(vae_model, loss='vanilla-beta-vae', beta=BETA, learning_rate=LEARNING_RATE)
add_monitoring_metrics(model)

# Validate with the same batch size as the self-critic and NN-critic runs.
# This cell previously used batch_size=None, which presumably loads the whole
# validation split as one batch. That made the batch-aggregated monitoring
# metrics (AU is mean-aggregated over batches), and thus early stopping, not
# comparable with the other experiments. It also ran NLL with 500 importance
# samples on the full split at once.
val_data_loader = dm.val_dataloader(batch_size=BATCH_SIZE)
train_data_loader = dm.train_dataloader(batch_size=BATCH_SIZE)


train(model, train_data_loader, val_data_loader, num_epochs=MAX_NUM_EPOCHS, accelerator=ACCELERATOR, enable_progress_bar=PBAR, early_stop=True)

### Metrics

In [None]:
# Evaluate the trained beta-VAE on the test split with every configured metric.
test_data_loader = dm.test_dataloader(batch_size=BATCH_SIZE)
results = metrics.compute_metrics(
    vae_model,
    test_data_loader,
    metric_and_its_params,
)

# One "name: value" line per metric.
for metric, res in results.items():
    print(f"{metric}: {res}")

### PCA on encodings of the test set

In [None]:
from forget_me_not.plots import plot_latent_representation_2d

# Plot the latent representation of the first batch of a batch_size=None test loader.
test_data_loader = dm.test_dataloader(batch_size=None)

if REPORT_ROOT_DIR is None:
    report_dir = None
else:
    report_dir = os.path.join(REPORT_ROOT_DIR, 'beta')

data, labels = next(iter(test_data_loader))
plot_latent_representation_2d(vae_model, data, labels, report_dir)

In [None]:
# Release the beta-VAE before the next experiment. The BetaVAEModule `model`
# was built around `vae_model` and presumably keeps a reference to it, so
# deleting `vae_model` alone would leave the network alive. `model` is not
# used by any later cell, so both names are dropped.
del vae_model, model

# Self-critic VAE

In [None]:
from forget_me_not.models.cnn_vae import CNNVAE, CNNEncoder, CNNDecoder
from forget_me_not.training.train_beta_vae import BetaVAEModule, train


# Same CNN VAE architecture as the baseline, trained with the self-critic loss.
vae_model_sc = CNNVAE(
    img_encoder=CNNEncoder(num_channels=1),
    img_decoder=CNNDecoder(num_channels=1, dim=HIDDEN_DIM),
    hidden_dim=HIDDEN_DIM,
    latent_dim=LATENT_DIM,
)
# LAMBDA is supplied through BetaVAEModule's `beta` argument.
model_sc = BetaVAEModule(
    vae_model_sc,
    loss='self-critic',
    beta=LAMBDA,
    learning_rate=LEARNING_RATE,
)
add_monitoring_metrics(model_sc)

val_data_loader = dm.val_dataloader(batch_size=BATCH_SIZE)
train_data_loader = dm.train_dataloader(batch_size=BATCH_SIZE)

train(
    model_sc,
    train_data_loader,
    val_data_loader,
    num_epochs=MAX_NUM_EPOCHS,
    accelerator=ACCELERATOR,
    enable_progress_bar=PBAR,
    early_stop=True,
)

In [None]:
# Evaluate the trained self-critic VAE on the test split with every configured metric.
test_data_loader = dm.test_dataloader(batch_size=BATCH_SIZE)
results = metrics.compute_metrics(
    vae_model_sc,
    test_data_loader,
    metric_and_its_params,
)

# One "name: value" line per metric.
for metric, res in results.items():
    print(f"{metric}: {res}")

In [None]:
from forget_me_not.plots import plot_latent_representation_2d

# Plot the latent representation of the first batch of a batch_size=None test loader.
test_data_loader = dm.test_dataloader(batch_size=None)

if REPORT_ROOT_DIR is None:
    report_dir = None
else:
    report_dir = os.path.join(REPORT_ROOT_DIR, 'self_critic')

data, labels = next(iter(test_data_loader))
plot_latent_representation_2d(vae_model_sc, data, labels, report_dir)

In [None]:
# Release the self-critic VAE. `model_sc` wraps `vae_model_sc` and presumably
# holds a reference to it, so both names are dropped (neither is used later).
del vae_model_sc, model_sc

# NN-critic VAE

In [None]:
# Layer sizes for CriticNetworkForCNNVAE (passed as the same-named lowercase kwargs).
CONTRAST_DIM = 2 ** 8     # contrast_dim
HIDDEN_DIM_X = 2 ** 9     # hidden_dim_x
HIDDEN_DIM_Z = 3 * 2 ** 7 # hidden_dim_z

In [None]:
from forget_me_not.models.cnn_vae import (
    CNNEncoder,
    CNNDecoder,
    CNNVAE,
    CriticNetworkForCNNVAE,
    CNNVAEWithCriticNetwork,
)
from forget_me_not.training.train_beta_vae import BetaVAEModule, train

# The critic gets its own CNNEncoder instance, separate from the VAE's encoder below.
critic_network = CriticNetworkForCNNVAE(
    img_encoder=CNNEncoder(num_channels=1),
    img_enc_dim=HIDDEN_DIM,
    latent_dim=LATENT_DIM,
    contrast_dim=CONTRAST_DIM,
    hidden_dim_x=HIDDEN_DIM_X,
    hidden_dim_z=HIDDEN_DIM_Z,
    dtype=torch.float32,
)

# CNN VAE wired to the critic network above.
vae_model_nnc = CNNVAEWithCriticNetwork(
    critic_network,
    img_encoder=CNNEncoder(num_channels=1),
    img_decoder=CNNDecoder(num_channels=1, dim=HIDDEN_DIM),
    hidden_dim=HIDDEN_DIM,
    latent_dim=LATENT_DIM,
)

# LAMBDA is supplied through BetaVAEModule's `beta` argument.
model_nnc = BetaVAEModule(
    vae_model_nnc,
    loss='nn-critic',
    beta=LAMBDA,
    learning_rate=LEARNING_RATE,
)
add_monitoring_metrics(model_nnc)

val_data_loader = dm.val_dataloader(batch_size=BATCH_SIZE)
train_data_loader = dm.train_dataloader(batch_size=BATCH_SIZE)

train(
    model_nnc,
    train_data_loader,
    val_data_loader,
    num_epochs=MAX_NUM_EPOCHS,
    accelerator=ACCELERATOR,
    enable_progress_bar=PBAR,
    early_stop=True,
)

In [None]:
# Evaluate the trained NN-critic VAE on the test split with every configured metric.
test_data_loader = dm.test_dataloader(batch_size=BATCH_SIZE)
results = metrics.compute_metrics(
    vae_model_nnc,
    test_data_loader,
    metric_and_its_params,
)

# One "name: value" line per metric.
for metric, res in results.items():
    print(f"{metric}: {res}")

In [None]:
from forget_me_not.plots import plot_latent_representation_2d

# Plot the latent representation of the first batch of a batch_size=None test loader.
test_data_loader = dm.test_dataloader(batch_size=None)

if REPORT_ROOT_DIR is None:
    report_dir = None
else:
    report_dir = os.path.join(REPORT_ROOT_DIR, 'nn_critic')

data, labels = next(iter(test_data_loader))
plot_latent_representation_2d(vae_model_nnc, data, labels, report_dir)

In [None]:
# Release the NN-critic experiment. `model_nnc` wraps `vae_model_nnc`, which
# was built from `critic_network`, so each of these names presumably keeps the
# networks reachable; drop all three.
del vae_model_nnc, model_nnc, critic_network