# 02 - Training

## A - Libraries

In [None]:
import pickle
import pandas as pd
import numpy as np

from metaboDGD.util import data, train
from metaboDGD.src import model
import matplotlib.pyplot as plt
import torch
import torch.distributions as D

## B - Retrieve Dataframe and Cohorts

In [None]:
# Locations of the input artefacts read by the loading cell.
# NOTE(review): `dir` shadows the builtin of the same name; the name is kept
# because the loading cell builds its paths from it.
dir = "outputs/"
df_fname = "CombinedDataset_CAMP.csv"
df_exp_fname = "Exponent_CombinedDataset_CAMP.csv"
cohorts_fname = "cohorts.pkl"

In [None]:
# Load the combined dataset and use the column pandas labelled 'Unnamed: 0'
# as the row index (presumably the unlabelled index column written when the
# CSV was saved -- confirm against the export step). The index name is
# cleared afterwards.
df = pd.read_csv(dir + df_fname)
df.set_index('Unnamed: 0', inplace=True)
df.index.name = None

# Same treatment for the file named by df_exp_fname.
df_exp = pd.read_csv(dir + df_exp_fname)
df_exp.set_index('Unnamed: 0', inplace=True)
df_exp.index.name = None

# Read the pickled cohorts object. The context manager guarantees the file
# handle is closed even if unpickling raises (the handle was previously
# left open).
# NOTE: pickle.load can execute arbitrary code; only load trusted files.
with open(dir + cohorts_fname, 'rb') as f:
    cohorts = pickle.load(f)

## C - Preparing TrainLoader and DGD Model

In [None]:
# Start every bookkeeping container empty. Nothing in the visible cells
# fills them -- presumably they are populated elsewhere (TODO confirm).
train_dict, test_dict = {}, {}
train_lbls, test_lbls = [], []
plot_counts = {}

In [None]:
# Build the loaders from the cohorts object and df_exp (the frame read from
# the 'Exponent_' CSV), requesting batches of 64, then unpack the pair that
# create_dataloaders returns.
loaders = data.create_dataloaders(cohorts=cohorts, df=df_exp, batch_size=64)
train_loader, test_loader = loaders

In [None]:
# Hyperparameters for the MetaboDGD model, gathered in one place so they are
# easy to inspect or log.
# NOTE(review): output_dim=1915 presumably equals the number of feature
# columns in df_exp -- confirm against the dataset.
model_kwargs = {
    'latent_dim': 10,
    'output_dim': 1915,
    'dec_hidden_layers_dim': [500, 1000, 1500],
    'dec_output_prediction_type': 'mean',
    'dec_output_activation_type': 'softplus',
    'n_comp': 8,
    'cm_type': 'diagonal',
}
dgd_model = model.MetaboDGD(**model_kwargs)

In [None]:
# Optimisation settings passed through to train.train_dgd unchanged.
train_kwargs = {
    'n_epochs': 100,
    'lr_schedule_epochs': None,
    'lr_schedule': [1e-4, 1e-3, 1e-2],
    'optim_betas': [0.5, 0.7],
    'wd': 1e-4,
}

# Run training; test_loader is supplied as the validation loader. The call
# returns the model, two further objects bound to train_rep / test_rep, and
# the history object plotted in the cells below.
dgd_model, train_rep, test_rep, history = train.train_dgd(
    dgd_model=dgd_model,
    train_loader=train_loader,
    validation_loader=test_loader,
    **train_kwargs,
)

In [None]:
# Bare expression: as the last statement of a notebook cell this displays the
# history object returned by train.train_dgd.
history

In [None]:
# Overlay the train and validation total-loss curves against the epoch axis.
for series, lbl in (('train_loss', 'train'), ('val_loss', 'validation')):
    plt.plot(history['epoch'], history[series], label=lbl)
plt.xlabel("Epoch")
plt.ylabel("Total Loss")
plt.legend()
plt.title("Training Loss Curve - Total Loss")

In [None]:
# Plot the train and validation reconstruction-loss curves per epoch.
plt.plot(history['epoch'], history['train_recon_loss'], label='train')
plt.plot(history['epoch'], history['val_recon_loss']  , label='validation')
# The y-axis names the quantity actually plotted; it previously read
# "Total Loss", copied from the total-loss cell.
plt.ylabel("Reconstruction Loss")
plt.xlabel("Epoch")
plt.legend()
plt.title("Training Loss Curve - Reconstruction Loss")

In [None]:
# Plot the train and validation distribution-loss curves per epoch.
plt.plot(history['epoch'], history['train_dist_loss'], label='train')
plt.plot(history['epoch'], history['val_dist_loss']  , label='validation')
# The y-axis names the quantity actually plotted; it previously read
# "Total Loss", copied from the total-loss cell.
plt.ylabel("GMM Distribution Loss")
plt.xlabel("Epoch")
plt.legend()
plt.title("Training Loss Curve - GMM Distribution Loss")