# Advanced Usage Tutorial
This notebook demonstrates a few advanced features of the translation model.

## 1. Importing the needed parts

In [None]:
# we will need to import a few things first
import numpy as np
import os

# import the encoders and decoders we want to use
from multimodal_autoencoders.model.encoders import CpdEncoder, PQSAREncoder
from multimodal_autoencoders.model.decoders import CpdDecoder, PQSARDecoder

# import the Autoencoder class
from multimodal_autoencoders.base.autoencoder import VariationalAutoencoder

# import a discriminator and a classifier
from multimodal_autoencoders.model.classifiers import Discriminator, SimpleClassifier

# import the JointTrainer 
from multimodal_autoencoders.joint_trainer import JointTrainer

## 2. Setting up the data
We use the same synthetic data setup as in the basic usage example.

In [2]:
# initialize the random generator
rng = np.random.default_rng(seed = 1234)

# create some latent information common to all modalities
train_latent_information = rng.random(size = (100, 25))

# small helper function for our synthetic data
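def generate_modality(latent_information: np.ndarray, n_random_dims: int, samples: int = 100):
    # append n_random_dims columns of random noise to the shared latent columns
    ar = np.concatenate((latent_information, rng.random(size = (samples, n_random_dims))), axis=1)
    # shuffle the columns so the latent dimensions do not sit at fixed positions
    rng.shuffle(ar, axis=1)

    return ar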

# define the data dictionary: one array per modality, with samples as rows
# when working with your own data, this dictionary is where it goes
train_data_dict = {
    "modality_1": generate_modality(train_latent_information, 25),
    "modality_2": generate_modality(train_latent_information, 50),
    "modality_3": generate_modality(train_latent_information, 75)}


# create a separate validation set whose latent information mixes 80% of the training latents with 20% new noise
val_latent_information = train_latent_information * 0.8 + rng.random(size = (100, 25)) * 0.2
val_data_dict = {
    "modality_1": generate_modality(val_latent_information, 25),
    "modality_2": generate_modality(val_latent_information, 50),
    "modality_3": generate_modality(val_latent_information, 75)}
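Each modality ends up with 25 shared latent columns plus its own noise columns, so the three arrays have 50, 75 and 100 columns. The minimal check below assumes (it is not stated in the notebook) that the first argument of each encoder in section 3.1 is its input dimension, and verifies that the generated arrays match those sizes. It re-creates a local copy of the helper so the snippet runs on its own.

```python
import numpy as np

rng = np.random.default_rng(seed=1234)


def generate_modality(latent_information: np.ndarray, n_random_dims: int) -> np.ndarray:
    # shared latent columns + modality-specific noise columns, then shuffle the columns
    n_samples = latent_information.shape[0]
    noise = rng.random(size=(n_samples, n_random_dims))
    ar = np.concatenate((latent_information, noise), axis=1)
    rng.shuffle(ar, axis=1)
    return ar


latent = rng.random(size=(100, 25))
data = {
    "modality_1": generate_modality(latent, 25),
    "modality_2": generate_modality(latent, 50),
    "modality_3": generate_modality(latent, 75),
}

# assumed input sizes: first argument of CpdEncoder(50, ...), PQSAREncoder(75, ...), PQSAREncoder(100, ...)
expected_input_dims = {"modality_1": 50, "modality_2": 75, "modality_3": 100}
for name, array in data.items():
    assert array.shape == (100, expected_input_dims[name]), name
    print(name, array.shape)
```

If you swap in your own data, a mismatch between the number of columns and the encoder's input size is the first thing a check like this catches.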

## 3. Setting up the models
### 3.1 Autoencoders with individual pretraining and frozen joint training
Depending on your use case, pretraining one or multiple models might improve overall performance. Additionally, you might want to keep a pretrained model fixed in its trained state during joint training and only let the other models adapt to it. For this scenario, each autoencoder accepts two additional arguments:
- pretrain_epochs: integer number of epochs the model is pretrained for
- train_joint: boolean flag indicating whether the model is also trained in joint mode
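The hypothetical sketch below spells out the semantics of these two arguments as described above: models with `pretrain_epochs > 0` are first trained on their own, and during joint training only models with `train_joint=True` are updated, while frozen models still take part so the others can adapt to them. `TrainingSettings` and `training_schedule` are illustrative names, not part of multimodal_autoencoders, and the real trainer's ordering may differ.

```python
from dataclasses import dataclass


@dataclass
class TrainingSettings:
    """Hypothetical stand-in for the two extra autoencoder arguments."""
    pretrain_epochs: int = 0   # epochs of solo pretraining before joint training
    train_joint: bool = True   # whether the weights keep updating in joint mode


def training_schedule(settings: dict, max_epochs: int) -> list:
    """Return (phase, models updated in that epoch) tuples; illustration only."""
    schedule = []
    # phase 1: every model with pretrain_epochs > 0 is trained on its own
    for name, cfg in settings.items():
        for _ in range(cfg.pretrain_epochs):
            schedule.append(("pretrain", (name,)))
    # phase 2: joint epochs only update models with train_joint=True;
    # frozen models still provide latents but keep their pretrained weights
    trainable = tuple(name for name, cfg in settings.items() if cfg.train_joint)
    for _ in range(max_epochs):
        schedule.append(("joint", trainable))
    return schedule


settings = {
    "modality_1": TrainingSettings(),
    "modality_2": TrainingSettings(),
    "modality_3": TrainingSettings(pretrain_epochs=10, train_joint=False),
}
schedule = training_schedule(settings, max_epochs=10)
print(schedule[0])   # ('pretrain', ('modality_3',))
print(schedule[-1])  # ('joint', ('modality_1', 'modality_2'))
```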

In [3]:
model_dict = {
    "modality_1": VariationalAutoencoder(CpdEncoder(50, 42), CpdDecoder(50, 42, 36), "adam", 0.001),
    "modality_2": VariationalAutoencoder(PQSAREncoder(75, 50), PQSARDecoder(75, 50, 36), "adam", 0.001),
    "modality_3": VariationalAutoencoder(PQSAREncoder(100, 75), PQSARDecoder(100, 75, 36), "adam", 0.001, pretrain_epochs = 10, train_joint = False)}

Initializing variational autoencoder model
Initializing optimizer
Initializing variational autoencoder model
Initializing optimizer
Initializing variational autoencoder model
Initializing optimizer


### 3.2 Discriminator and classifier
#### 3.2.1 Discriminator

In [4]:
discriminator = Discriminator("adam", 0.001, 36, len(model_dict), 50)

Initializing classifier model
Initializing optimizer


#### 3.2.2 Classifier and cluster labels

In [5]:
# one cluster label per sample: the first 50 samples are cluster 0, the last 50 cluster 1
cluster_data = np.concatenate((np.repeat(0, 50), np.repeat(1, 50)))
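The cluster labels hold one integer per sample, in the same row order as the arrays in `train_data_dict`, and the number of distinct labels presumably has to match the classifier's output size (the `2` passed to `SimpleClassifier` below). A quick sanity check, with a more compact but equivalent construction:

```python
import numpy as np

cluster_data = np.concatenate((np.repeat(0, 50), np.repeat(1, 50)))

# equivalent construction: repeat each class index 50 times
alt = np.repeat(np.arange(2), 50)
assert np.array_equal(cluster_data, alt)

# one label per sample (100 rows) and 50 samples per cluster
print(cluster_data.shape)         # (100,)
print(np.bincount(cluster_data))  # [50 50]
```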

In [6]:
classifier = SimpleClassifier("adam", 0.001, 36, 2)

Initializing classifier model
Initializing optimizer
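A useful reference point when reading the training logs is the loss of a model that guesses uniformly. Assuming both the discriminator (which tells the `len(model_dict) = 3` modalities apart) and the classifier (2 clusters) are trained with a cross-entropy objective, which the notebook does not state, chance level is ln(n_classes). The logged discriminator losses in this notebook stay near 1.10 ≈ ln(3), and the classifier losses near 0.69 ≈ ln(2).

```python
import numpy as np


def uniform_cross_entropy(n_classes: int) -> float:
    """Cross-entropy of a model that assigns probability 1/n_classes to every class."""
    probs = np.full(n_classes, 1.0 / n_classes)
    true_class = 0  # any class gives the same value for a uniform prediction
    return float(-np.log(probs[true_class]))


disc_chance = uniform_cross_entropy(3)  # three modalities in model_dict
clf_chance = uniform_cross_entropy(2)   # two cluster labels
print(round(disc_chance, 4), round(clf_chance, 4))  # 1.0986 0.6931
```

Under that assumption, a discriminator loss that stays near chance means the discriminator can barely tell which modality a latent code came from, which is what the adversarial alignment aims for.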


## 4. Setting up the JointTrainer

In [7]:
model = JointTrainer(
        model_dict = model_dict,
        discriminator = discriminator,
        max_epochs = 10,
        recon_weight = 3,
        beta = 0.001,
        disc_weight = 3,
        anchor_weight = 1,
        cl_weight = 3,
        classifier = classifier)

01/13/2023 05:30:21 PM initializing trainer


## 5. Train with early stopping
Now that all parts of the model are set up again, we can begin training. This time we don't just want to train the model; we also want training to stop as soon as the model no longer generalizes to unseen data. Early stopping does this automatically. The train call provides two additional parameters to customize the early stopping procedure:
- patience: integer number of consecutive epochs the model may go without improving on the validation data before training stops
- min_value: float value of the minimal difference between the validation loss of the previous and the current epoch needed to count as an improvement

Choose these parameters carefully. Too little patience stops training too early, even though the model might have recovered a few epochs later. Too much patience and the model might train longer than needed. The trainer stores a model checkpoint at the beginning of each consecutive series of non-improving (overfitting) epochs, so you can return to the point at which the model performed and generalized best.

A suitable min_value is very domain specific and depends on the scale of your loss values. If min_value is too large relative to the typical epoch-to-epoch change in validation loss, genuine improvements are not counted and training stops prematurely; if it is too small, marginal improvements keep resetting the patience counter and training may run longer than needed. A good practice is to do a short first training run and choose min_value based on the validation losses reported in the log.
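The patience/min_value logic can be sketched as a small helper. This is a hypothetical illustration based only on the definitions in the parameter list above (an improvement is the previous epoch's validation loss minus the current one, and it must reach min_value), not the implementation used by JointTrainer; `EarlyStopper` and its `step` method are illustrative names.

```python
class EarlyStopper:
    """Hypothetical early-stopping helper; not the library's implementation."""

    def __init__(self, patience: int, min_value: float):
        self.patience = patience
        self.min_value = min_value
        self.prev_loss = None     # validation loss of the previous epoch
        self.counter = 0          # consecutive epochs without sufficient improvement
        self.checkpoint = None    # state saved when a non-improving streak begins
        self._prev_state = None

    def step(self, val_loss: float, state) -> bool:
        """Record one epoch; return True once training should stop."""
        if self.prev_loss is not None:
            improvement = self.prev_loss - val_loss
            if improvement >= self.min_value:
                # enough improvement: reset the streak
                self.counter = 0
                self.checkpoint = None
            else:
                if self.counter == 0:
                    # a streak starts: keep the state from before this epoch
                    self.checkpoint = self._prev_state
                self.counter += 1
        self.prev_loss = val_loss
        self._prev_state = state
        return self.counter >= self.patience


stopper = EarlyStopper(patience=2, min_value=0.1)
stopped_at = None
for epoch, loss in enumerate([10.0, 9.0, 8.95, 8.9, 8.85]):
    if stopper.step(loss, state=f"state-{epoch}"):
        stopped_at = epoch
        break
print(stopped_at, stopper.checkpoint)  # 3 state-1
```

In this toy run, epoch 1 improves by 1.0, but epochs 2 and 3 each improve by only 0.05, so the patience of 2 is exhausted at epoch 3 and the checkpoint points back to the state after epoch 1.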

In [8]:
meter_dict = model.train(
    train_data_dict = train_data_dict,
    val_data_dict = val_data_dict,
    batch_size = 10,
    cluster_labels = cluster_data,
    use_gpu = False,
    patience = 2,
    min_value = 10)

print(meter_dict["loss"].avg)

01/13/2023 05:30:22 PM starting joint training
01/13/2023 05:30:22 PM Pretrain epoch 10
01/13/2023 05:30:22 PM Pretrain epoch 9
01/13/2023 05:30:22 PM Pretrain epoch 8
01/13/2023 05:30:23 PM Pretrain epoch 7
01/13/2023 05:30:23 PM Pretrain epoch 6
01/13/2023 05:30:23 PM Pretrain epoch 5
01/13/2023 05:30:23 PM Pretrain epoch 4
01/13/2023 05:30:23 PM Pretrain epoch 3
01/13/2023 05:30:23 PM Pretrain epoch 2
01/13/2023 05:30:23 PM Pretrain epoch 1


['epoch:0',
 'modality_3_pretrain_recon:0.0908463180065155',
 'modality_3_pretrain_kl:8.795190334320068',
 'modality_3_pretrain_loss:0.28133414387702943',
 'val_loss_diff:-8.552241325378418',
 'modality_1_recon:0.40513763427734373',
 'modality_1_kl:43.265543365478514',
 'discriminator_adverserial_loss:1.100578200817108',
 'modality_1_translation_loss:0.3637415438890457',
 'anchor_loss:1.166156788667043',
 'modality_2_recon:0.3078315854072571',
 'modality_2_kl:32.14865665435791',
 'modality_2_translation_loss:0.3867570996284485',
 'modality_3_recon:0.09010617956519126',
 'modality_3_kl:2.775472855567932',
 'modality_3_translation_loss:0.4915455669164658',
 'loss:7.536335897445679',
 'classifier_loss:0.7464352488517761',
 'discriminator_training_loss:1.0991130113601684',
 'val_modality_1_recon:0.34855607748031614',
 'val_modality_1_kl:37.28128275871277',
 'val_discriminator_adverserial_loss:1.0997918407122294',
 'val_modality_1_translation_loss:0.34117076396942136',
 'val_anchor_loss:1.1

01/13/2023 05:30:24 PM Stopping training early. Patience of 2 epochs was reached.


['epoch:2',
 'modality_3_pretrain_recon:0.0908463180065155',
 'modality_3_pretrain_kl:8.795190334320068',
 'modality_3_pretrain_loss:0.28133414387702943',
 'val_loss_diff:-0.13251852989196777',
 'modality_1_recon:0.2704917773604393',
 'modality_1_kl:35.03425521850586',
 'discriminator_adverserial_loss:1.102511970202128',
 'modality_1_translation_loss:0.29095015823841097',
 'anchor_loss:1.1465026100476583',
 'modality_2_recon:0.15984044820070267',
 'modality_2_kl:28.42070827484131',
 'modality_2_translation_loss:0.3294759154319763',
 'modality_3_recon:0.09030837416648865',
 'modality_3_kl:2.7754729270935057',
 'modality_3_translation_loss:0.36657718420028684',
 'loss:7.032061529159546',
 'classifier_loss:0.6784351527690887',
 'discriminator_training_loss:1.093507719039917',
 'val_modality_1_recon:0.23625757098197936',
 'val_modality_1_kl:37.144650840759276',
 'val_discriminator_adverserial_loss:1.1030825575192769',
 'val_modality_1_translation_loss:0.26793771982192993',
 'val_anchor_los

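The cell above prints `meter_dict["loss"].avg`, which suggests each entry of `meter_dict` is a running-average meter. The library's meter class is not shown in this notebook; the sketch below assumes it behaves like the common `AverageMeter` pattern from PyTorch's ImageNet example, where `avg` is the sample-weighted mean of all values passed to `update`.

```python
class AverageMeter:
    """Running average in the style of the AverageMeter helper from PyTorch's ImageNet example."""

    def __init__(self):
        self.val = 0.0    # most recent value
        self.sum = 0.0    # weighted sum of all values
        self.count = 0    # total number of samples seen
        self.avg = 0.0    # running average

    def update(self, val: float, n: int = 1) -> None:
        # val is the mean over a batch of n samples
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count


meter = AverageMeter()
meter.update(2.0, n=10)  # e.g. a batch of 10 samples with mean loss 2.0
meter.update(1.0, n=10)
print(meter.avg)  # 1.5
```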
## 6. Pre-training a classifier
In some cases it might be beneficial not only to pre-train an autoencoder but also to have it shaped by a cluster classifier. This reproduces the original use case published by Dai Yang et al.: pre-training an autoencoder together with a classifier, to which the other models are then aligned in the joint training phase. The train method of the JointTrainer class provides a "cluster_modality" parameter for this purpose. Simply pass the key of the model to which the classifier should be attached during pre-training.
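The pretraining losses logged in this notebook are consistent with a weighted sum of the reconstruction, KL and classifier terms using the JointTrainer weights (`recon_weight = 3`, `beta = 0.001`, `cl_weight = 3`). The sketch below reproduces the logged `modality_3_pretrain_loss` from its logged components, once without the classifier (section 5) and once with it (section 6). The weighting is inferred from these numbers, not taken from the library source, so treat `pretrain_loss` as a hypothetical helper.

```python
def pretrain_loss(recon: float, kl: float, classifier: float = 0.0,
                  recon_weight: float = 3.0, beta: float = 0.001,
                  cl_weight: float = 3.0) -> float:
    """Hypothetical weighting of the pretraining terms, inferred from the logs."""
    return recon_weight * recon + beta * kl + cl_weight * classifier


# modality_3 pretraining values logged in section 5 (no classifier attached)
without_clf = pretrain_loss(0.0908463180065155, 8.795190334320068)
# modality_3 pretraining values logged in section 6 (cluster_modality = "modality_3")
with_clf = pretrain_loss(0.09025356099009514, 73.07494506835937, 0.5069375306367874)
print(round(without_clf, 4), round(with_clf, 4))  # 0.2813 1.8646
```

With `cl_weight = 3`, the classifier term dominates the pretraining loss in this run, which is how the cluster labels shape the pretrained latent space.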

In [9]:
# re-create the autoencoders so the new run starts from freshly initialized models
model_dict = {
    "modality_1": VariationalAutoencoder(CpdEncoder(50, 42), CpdDecoder(50, 42, 36), "adam", 0.001),
    "modality_2": VariationalAutoencoder(PQSAREncoder(75, 50), PQSARDecoder(75, 50, 36), "adam", 0.001),
    "modality_3": VariationalAutoencoder(PQSAREncoder(100, 75), PQSARDecoder(100, 75, 36), "adam", 0.001, pretrain_epochs = 10, train_joint = False)}

# re-initialize the trainer (the discriminator and classifier objects from above are reused as-is)
model = JointTrainer(
        model_dict = model_dict,
        discriminator = discriminator,
        max_epochs = 10,
        recon_weight = 3,
        beta = 0.001,
        disc_weight = 3,
        anchor_weight = 1,
        cl_weight = 3,
        classifier = classifier)

# launch training with classifier in pre-training
# by providing an existing model key through the
# cluster_modality parameter
meter_dict = model.train(
    train_data_dict = train_data_dict,
    val_data_dict = val_data_dict,
    batch_size = 10,
    cluster_labels = cluster_data,
    cluster_modality = "modality_3",
    use_gpu = False)

print(meter_dict["loss"].avg)

01/13/2023 05:30:25 PM initializing trainer
01/13/2023 05:30:25 PM starting joint training
01/13/2023 05:30:25 PM Pretrain epoch 10
01/13/2023 05:30:25 PM Pretrain epoch 9
01/13/2023 05:30:25 PM Pretrain epoch 8
01/13/2023 05:30:25 PM Pretrain epoch 7
01/13/2023 05:30:25 PM Pretrain epoch 6


Initializing variational autoencoder model
Initializing optimizer
Initializing variational autoencoder model
Initializing optimizer
Initializing variational autoencoder model
Initializing optimizer


01/13/2023 05:30:26 PM Pretrain epoch 5
01/13/2023 05:30:26 PM Pretrain epoch 4
01/13/2023 05:30:26 PM Pretrain epoch 3
01/13/2023 05:30:26 PM Pretrain epoch 2
01/13/2023 05:30:26 PM Pretrain epoch 1


['epoch:0',
 'modality_3_pretrain_recon:0.09025356099009514',
 'modality_3_pretrain_kl:73.07494506835937',
 'modality_3_pretrain_classifier_loss:0.5069375306367874',
 'modality_3_pretrain_loss:1.8646482229232788',
 'val_loss_diff:-8.739702796936035',
 'modality_1_recon:0.44191673696041106',
 'modality_1_kl:44.84805335998535',
 'discriminator_adverserial_loss:1.1014357924461364',
 'modality_1_translation_loss:0.3671684920787811',
 'anchor_loss:1.2267947713534038',
 'modality_2_recon:0.3136404246091843',
 'modality_2_kl:30.411287689208983',
 'modality_2_translation_loss:0.39869125485420226',
 'modality_3_recon:0.08974274769425392',
 'modality_3_kl:36.80703964233398',
 'modality_3_translation_loss:0.5109993278980255',
 'loss:7.23685851097107',
 'classifier_loss:0.6077002704143524',
 'discriminator_training_loss:1.0964747428894044',
 'val_modality_1_recon:0.3684763342142105',
 'val_modality_1_kl:38.37733345031738',
 'val_discriminator_adverserial_loss:1.0985676447550456',
 'val_modality_1_