# Train Variational autoencoder
This notebook trains an autoencoder or a variational autoencoder on the MNIST data set. It can be run locally or on Google Colab; the first cell checks whether it is running in Colab and sets up the environment accordingly.

## Check if in Colab

In [1]:
import os
import subprocess
import sys

# Detect Colab by probing for the `google.colab` package. Only a failed import
# is treated as "not in Colab"; a bare `except:` would also swallow
# KeyboardInterrupt / SystemExit raised while the import is running.
try:
    import google.colab  # noqa: F401  (imported only to test availability)
    IN_COLAB = True
except ImportError:
    IN_COLAB = False

print("Is in Colab: ", IN_COLAB)
if IN_COLAB:
    # Clone only when the checkout is missing so that re-running this cell
    # does not fail on an already existing directory.
    if not os.path.isdir('JL-ML'):
        subprocess.run(
            ['git', 'clone', 'https://github.com/AllaVinner/JL-ML.git'],
            check=True,
        )
    # Install with the interpreter that runs this kernel; `check=True` makes a
    # failed install stop here instead of surfacing later as an import error.
    subprocess.run(
        [sys.executable, '-m', 'pip', 'install', '-e', 'JL-ML'],
        check=True,
    )
    # Re-run site initialisation after the editable install
    # (presumably so the new package is importable without a restart — TODO confirm).
    import site
    site.main()

# Directory where trained models are written: inside the cloned repository
# when in Colab, otherwise one level above the current working directory.
saved_path = os.path.join('JL-ML', 'saved-models') if IN_COLAB else os.path.join('..', 'saved-models')

Is in Colab:  False


## Setup

In [2]:
#Test to load
import yaml

import tensorflow as tf
from tensorflow import keras
import numpy as np

from jlauto.models.load_premade import load_premade_model
from jlauto.models.continuous_bernoulli_loss import continuous_bernoulli_loss

In [3]:
# Load MNIST and prepare the training images for the model:
# cast to float32, divide by 255 and append a trailing axis of length 1.
mnist_train, mnist_test = keras.datasets.mnist.load_data()
x_train, y_train = mnist_train
x_test, y_test = mnist_test

train_digits = x_train[..., np.newaxis].astype("float32") / 255

# Shape of a single sample (everything except the leading sample axis).
input_shape = train_digits.shape[1:]

In [5]:
# Default run configuration. Entries set to 'CHANGE' are placeholders that
# are overwritten per run from `changing_config` before training.
config = dict(
    name='CHANGE',
    model_type='variational_autoencoder',
    model_name='mnist_cnn_shallow',
    latent_dim='CHANGE',
    optimizer='adam',
    reconstruction_loss='binary_crossentropy',
    latent_loss='kl_divergence',
    reconstruction_factor=1000,
    latent_factor=1,
    batch_size=512,
    epochs=1,
)

In [6]:
# Parameters swept across runs. Every list must have one entry per run;
# entry i of each list is written into `config` for run i.
changing_config = {'latent_dim': [3, 4]}
# One model name per latent dimension, e.g. 'ae_latent_dim_3'.
changing_config['name'] = [
    f'ae_latent_dim_{dim}' for dim in changing_config['latent_dim']
]


## Train

In [10]:
# Train and save one model per entry in `changing_config['name']`.
for i in range(len(changing_config['name'])):
    # Overlay the i-th value of every swept parameter onto the shared config.
    for key, values in changing_config.items():
        config[key] = values[i]

    # Create the output directory, including missing parents, and tolerate
    # re-runs. A shell `mkdir` fails silently in both of those cases and is
    # not portable across platforms.
    model_path = os.path.join(saved_path, config['name'])
    os.makedirs(model_path, exist_ok=True)

    # Store the exact configuration used for this run next to the model.
    with open(os.path.join(model_path, 'config.yaml'), 'w') as yaml_file:
        yaml.dump(config, yaml_file)

    # Create and train model
    model = load_premade_model(model_type=config['model_type'],
                               model_name=config['model_name'],
                               latent_dim=config['latent_dim'],
                               input_shape=input_shape)

    model.compile(optimizer=config['optimizer'],
                  reconstruction_loss=config['reconstruction_loss'],
                  latent_loss=config['latent_loss'],
                  reconstruction_factor=config['reconstruction_factor'],
                  latent_factor=config['latent_factor'])

    # Only the inputs are passed to fit; no separate targets are given.
    model.fit(train_digits,
              epochs=config['epochs'],
              batch_size=config['batch_size'])

    # Saved into the same directory that holds config.yaml.
    model.save(model_path)




KeyboardInterrupt: 

## Zip if in Colab

In [None]:
# In Colab, pack every saved model directory into ./model_<index>.zip
# using the `zip` command-line tool.
if IN_COLAB:
    for i, name in enumerate(changing_config['name']):
        model_dir = os.path.join(saved_path, name)
        zip_command = f'zip -r ./model_{i}.zip {model_dir}'
        os.system(zip_command)


# Investigate model