# Imports

In [None]:
%pip install -i https://test.pypi.org/simple/ coupledvae==0.0.14 -q

In [None]:
from coupledvae.VAEMNIST import VAE
from coupledvae.experiment_utils import *
from coupledvae.setup_funcs import *
from datetime import datetime

# Mount GDRIVE

In [None]:
from google.colab import drive

# Mount Google Drive so that run outputs can be written under
# gdrive/My Drive/ (see save_path in the hyperparameter cell).
GDRIVE_MOUNT_POINT = '/content/gdrive'
drive.mount(GDRIVE_MOUNT_POINT)

# Load Data

In [None]:
# Random seed is the microsecond of the current time, so each run gets a
# different seed (it is recorded in the experiment tracker later on).
random_seed = datetime.now().microsecond

# Set the training and testing batch sizes.
BATCH_SIZE_TRAIN = 128
BATCH_SIZE_TEST = 5000
# The number at which to split the training and validation sets: training set
# size = mnist_split, validation set size = 60000 - mnist_split. Kept as a
# string, as passed to get_datasets_ below.
mnist_split = '55000'

# List the corrupted MNIST variants to load. 'mnist' itself is prepended to
# form datasets_names. Other variants tried: 'motion_blur', 'spatter', 'fog'.
corrupted_names = ['identity', 'shot_noise']

datasets_names = ['mnist'] + [
    f'mnist_corrupted/{corrupted_name}' for corrupted_name in corrupted_names
]

# Download the data sets.
datasets = get_datasets_(
    datasets_names,
    BATCH_SIZE_TRAIN,
    BATCH_SIZE_TEST,
    mnist_split,
    random_seed,
)

training_datasets = ['train/']

# Every loaded dataset except the clean 'mnist' one is used for testing,
# in the same order as the keys of `datasets`.
testing_datasets = [name for name in datasets if name != 'mnist']
testing_datasets_dict = {name: datasets[name] for name in testing_datasets}


# Set Hyperparameters

In [None]:
# Presumably reports whether a GPU is available to this runtime. The helper comes
# from the coupledvae star imports; its spelling is that package's.
check_gpu_availibility()

In [None]:
###
# VAE Initializing Parameters
###

# Latent dims: the dimensionalities of the latent space to try
# (other values explored: 2, 4, 8, 16, 32).
z_dim_vals = [2]
# Whether to use the analytical coupled divergence, or an approximation.
analytic_kl = True
# Set the weight to place on the coupled divergence (values explored: 1. to 10.).
beta = 1.
# Set the standard deviation of the prior distribution.
p_std = 1.
# Loss couplings to try (other values explored: 1e-6, 0.025, 0.05, 0.075,
# 0.1, 0.2, 0.3, 0.4, 0.5).
loss_coupling_vals = [0.5]
# Set the number of base filters in the CNN.
n_filter_base = 64
# Set the learning rate for the Adam optimizer.
# NOTE(review): learning_rate is recorded in param_dict below but is not passed
# to train_VAEs in the training cell -- confirm training actually uses it.
learning_rate = 0.0005


###
# VAE Training Parameters
###

# Set the number of training epochs.
n_epoch = 150
# Set the number of epochs before plots are displayed.
n_epoch_display = 10
# Whether or not to display plots while training.
show_display = False
# Whether to display samples; in the visible cells it is only recorded in
# param_dict.
display_sample = True


###
# Setting Paths
###
dataset_type = 'mnist'
# Set the version of the code being run.
version = 'v9_January_07_2023_MNIST'
# Create the root path where the data will be stored. The path is relative to
# the working directory (presumably /content on Colab, where Drive is mounted).
save_path = Path(
    f'gdrive/My Drive/Colab Notebooks/Coupled VAE Public/{version}/'
)
# If the path does not exist, make it.
save_path.mkdir(parents=True, exist_ok=True)

# Set the directory where this run's results will be saved (one per seed).
model_path = save_path / str(random_seed)
model_path.mkdir(parents=True, exist_ok=True)

# Create the folders for this run in the Google Drive. Existing version and
# seed folders are not overwritten.
create_gdrive_output_folders(model_path, img_folders=corrupted_names)

# Save the parameters in a dict for the experiment tracker.
param_dict = {
    'random_seed': random_seed,
    'z_dim_vals': z_dim_vals,
    'analytic_kl': analytic_kl,
    'beta': beta,
    'p_std': p_std,
    'loss_coupling_vals': loss_coupling_vals,
    'n_filter_base': n_filter_base,
    'learning_rate': learning_rate,
    'n_epoch': n_epoch,
    'n_epoch_display': n_epoch_display,
    'train_batch_size': BATCH_SIZE_TRAIN,
    'test_batch_size': BATCH_SIZE_TEST,
    'val_split': mnist_split,
    'datasets': datasets_names,
    'show_display': show_display,
    # Record the actual display_sample setting (this previously logged
    # show_display by mistake).
    'display_sample': display_sample,
    'model_path': model_path,
}

# Set the path for the experiment tracking CSV file.
experiment_tracker_path = save_path / 'experiment_tracker.csv'
# Update the file.
update_experiments(param_dict, experiment_tracker_path)

# Set the training and testing paths.
training_path = model_path / 'train'
testing_path = model_path / 'test'

# Train VAE

In [None]:
# Passed to train_VAEs as early_stop; presumably the number of epochs without
# validation improvement before training stops -- TODO confirm in coupledvae.
early_stop = 20

vae_dict = train_VAEs(
    loss_coupling_vals=loss_coupling_vals,
    z_dim_vals=z_dim_vals,
    n_filter_base=n_filter_base,
    beta=beta,
    p_std=p_std,
    analytic_kl=analytic_kl,
    # Train for the configured number of epochs so that the run matches the
    # n_epoch value logged to experiment_tracker.csv. This was temporarily
    # hard-coded to 1 and marked "TODO n_epoch".
    n_epoch=n_epoch,
    n_epoch_display=n_epoch_display,
    datasets=datasets,
    dataset_type=dataset_type,
    datasets_names=training_datasets,
    random_seed=random_seed,
    model_path=training_path,
    show_display=show_display,
    early_stop=early_stop,
    cvae_type='MNIST',
)

# Plot Training Performance

In [None]:
# Plot the latent space of every trained model whose latent space is 2-D.
# Each model is checked individually, so 2-D models are still plotted when
# several latent sizes were trained. If a model exposes no z_dim attribute,
# fall back to the original rule: plot only when z_dim_vals == [2].
for vae_key, trained_vae in vae_dict.items():
    latent_dim = getattr(trained_vae, 'z_dim', None)
    if latent_dim == 2 or (latent_dim is None and z_dim_vals == [2]):
        print(f'Latent Space for {vae_key}')
        plot_latent_images(trained_vae.model, n=15, digit_size=28)

In [None]:
# Plot the 'neg_elbo' (negative ELBO) training metric for the models in vae_dict.
plot_training(vae_dict, metric='neg_elbo')

In [None]:
# Plot the 'recon_loss' (reconstruction loss) training metric.
plot_training(vae_dict, metric='recon_loss')

In [None]:
# Plot the 'coupled_div' (coupled divergence) training metric.
plot_training(vae_dict, metric='coupled_div')

In [None]:
# Validation negative ELBO logged during training for the model stored under ''.
val_neg_elbo = vae_dict[''].val_metrics_df['val_neg_elbo']
# Index labels of every row that attains the minimum, shifted by one. The +1
# presumably turns a 0-based row index into a 1-based epoch number -- TODO
# confirm. This is a NumPy array; the first match is reported below.
best_model_epoch = (
    val_neg_elbo[val_neg_elbo == val_neg_elbo.min()].index.values + 1
)

print(f'The best model was saved at epoch {best_model_epoch[0]}.')

# Test

In [None]:
# Rebuild a fresh VAE from the constructor arguments stored on the trained
# model under key '' in vae_dict, then load the saved checkpoint into it.
# NOTE(review): training passed n_filter_base=n_filter_base to train_VAEs, but
# it is not forwarded here. Confirm the rebuilt VAE uses the same filter count,
# otherwise load_weights may not match the checkpoint.
vae = vae_dict['']
vae = VAE(**{
    arg_name: vars(vae)[arg_name]
    for arg_name in (
        'z_dim',
        'beta',
        'p_std',
        'loss_coupling',
        'analytic_kl',
        'dtype',
        'display_path',
    )
})

# Load the best model by validation set performance from the checkpoints.
vae.model.load_weights(str(model_path / 'train' / 'cp.ckpt'))

In [None]:
# Keep every dataset except the clean 'mnist' one for testing, preserving the
# key order of `datasets`. list.remove raises ValueError if 'mnist' is missing.
testing_datasets = list(datasets)
testing_datasets.remove('mnist')
testing_datasets_dict = {name: datasets[name] for name in testing_datasets}

In [None]:
# Evaluate the reloaded VAE on each test dataset. Results are presumably saved
# under testing_path (the test_path argument). test_coupling=1e-6 is a fixed
# value for evaluation -- TODO confirm its meaning in coupledvae.
test_VAE_loop(
    my_vae=vae,
    datasets=testing_datasets_dict,
    test_path=testing_path,
    random_seed=random_seed,
    test_coupling=1e-6,
    # Plots are always shown during testing, independent of show_display.
    show_display=True,
)

In [None]:
# January 07, 2023
# Assume the run has finished successfully.
# Grab an image at random.
# Produce 3 realizations of z: (1) get rid of seed=0 in the Sampler; (2) pass in
# vae, which is the instance of the object.

# Decode them and display them.