## VAE-MS
This is a Minimal Working Example of the VAE-MS. Firstly the nessecary modules are loaded

In [16]:
import torch
import pandas as pd
import numpy as np
from torch.utils.data import DataLoader, TensorDataset

from VAEMS_init import *
from train_loop import train
from utils import initialize_nmf, poisnll
import matplotlib.pyplot as plt
from sklearn.model_selection import train_test_split

For this tutorial we are using randomly generated integer data with 1000 observations with 96 'mutation types'. This data is 100x normalized, used to generate NMF-based priors and divided into train/validation/test sets.

In [17]:
catalogue = torch.randint(low = 0, high = 180, size = (1000, 96))
n_pat, n_mut = catalogue.shape

h_dim = 3

catalogue_norm = np.array(catalogue[:])
totalMutations = catalogue.sum(axis=1)
indices = np.where(totalMutations > (n_mut * 100))
norm_genome = (
        np.array(catalogue)[indices[0],:]
        / np.array(totalMutations)[indices[0]][:,np.newaxis]
        * (n_mut * 100)
    )
catalogue_norm[list(indices),:] = norm_genome
catalogue_norm = pd.DataFrame(catalogue_norm)

train_val_idx, test_idx = train_test_split(range(n_pat), test_size=0.2)
train_idx, validation_idx = train_test_split(train_val_idx, test_size=0.2)

lamdba_prior_train, lamdba_prior_val, lamdba_prior_test, start_sigs, start_error = initialize_nmf(
        catalogue_norm, train_idx, validation_idx, test_idx, h_dim=h_dim, tol=1e-10)


################## Trainloader ###################################################

# Normalized data for training
train_data_norm = TensorDataset(
        torch.tensor(catalogue_norm.values[train_idx, :], dtype=torch.float32),
        torch.tensor(lamdba_prior_train.values, dtype=torch.float32))
validation_data_norm = TensorDataset(
        torch.tensor(catalogue_norm.values[validation_idx, :], dtype=torch.float32),
        torch.tensor(lamdba_prior_val.values, dtype=torch.float32))

# Unnormalized data for evaluation
trainval_prior = pd.concat([lamdba_prior_train,lamdba_prior_val]).values
trainval_data = TensorDataset(
        torch.tensor(catalogue[train_val_idx, :], dtype=torch.float32),
        torch.tensor(trainval_prior, dtype=torch.float32))
test_data = TensorDataset(
        torch.tensor(catalogue[test_idx, :], dtype=torch.float32),
        torch.tensor(lamdba_prior_test.values, dtype=torch.float32))

  torch.tensor(catalogue[train_val_idx, :], dtype=torch.float32),
  torch.tensor(catalogue[test_idx, :], dtype=torch.float32),


Next, we configure the VAE-MS model with hyperparameters stored in the 'config' dictionary.

In [18]:
# Configuration
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
config = dict(
        l_dim = [60, 40],
        batch_size = 16,
        learning_rate = 1e-3,
        activation = "ReLU",
        beta_kl =0.01,
        optimizer = "Adam",
        h_dim = h_dim)

model = VAEMS(
            input_dim=n_mut,
            l_dim=config['l_dim'],
            h_dim=config['h_dim'],
            activation=config['activation'],
            start_sigs= start_sigs
        )

Now training is performed for 500 epochs with early stopping using the 'Adam' optimizer.

In [19]:
train_dl = DataLoader(train_data_norm, batch_size=config['batch_size'], shuffle=True)
val_dl = DataLoader(validation_data_norm, batch_size=config['batch_size'], shuffle=False)

optimizer = getattr(torch.optim, config['optimizer'])(model.parameters(), lr=config['learning_rate'])

best_model, _, best_val_loss, _, _ = train(
    device=device,
    num_epochs=500,
    model=model,
    input_dim=n_mut,
    optimizer=optimizer,
    loss_fn= poisnll,
    trainloader=train_dl,
    valloader=val_dl,
    beta_kl=config['beta_kl'],
     patience=10
)

Epoch [1/500], Train Loss: 701.4787, Val Loss: 116.8415
best val loss 116.8415 updated at 1 epochs after 0 epochs without improvement
Epoch [2/500], Train Loss: 575.0501, Val Loss: 116.3988
best val loss 116.3988 updated at 2 epochs after 0 epochs without improvement
Epoch [3/500], Train Loss: 522.3417, Val Loss: 115.7296
best val loss 115.7296 updated at 3 epochs after 0 epochs without improvement
Epoch [4/500], Train Loss: 487.0071, Val Loss: 115.5176
best val loss 115.5176 updated at 4 epochs after 0 epochs without improvement
Epoch [5/500], Train Loss: 458.3468, Val Loss: 114.8387
best val loss 114.8387 updated at 5 epochs after 0 epochs without improvement
Epoch [6/500], Train Loss: 432.4667, Val Loss: 114.3674
best val loss 114.3674 updated at 6 epochs after 0 epochs without improvement
Epoch [7/500], Train Loss: 409.9308, Val Loss: 113.0938
best val loss 113.0938 updated at 7 epochs after 0 epochs without improvement
Epoch [8/500], Train Loss: 388.9146, Val Loss: 112.5538
best v

Lastly, post training, variables are extracted and metrics are calculated

In [20]:
############################################ EVALUATION LOGIC ##############################################
best_model.eval()

############################################ EXTRACT PARAMS ###############################################
                
vhat, _, Poisson_dist, signatures_est = best_model(trainval_data.tensors) 
exp_train = pd.DataFrame(Poisson_dist.rate.detach().numpy())
vhat_te, _, Poisson_dist_te,_ = best_model(test_data.tensors)
exp_test = pd.DataFrame(Poisson_dist_te.rate.detach().numpy())

signatures_est = pd.DataFrame(signatures_est.data.detach().numpy().T)


#################################### EVAULATE LOSS #############################################
reconst_loss = poisnll(trainval_data.tensors[0], vhat)
kl_div = Poisson_dist.kl(trainval_data.tensors[1])
loss_train = reconst_loss + config['beta_kl'] * kl_div.mean()

reconst_loss_te = poisnll(test_data.tensors[0], vhat_te)
kl_div_te = Poisson_dist_te.kl(test_data.tensors[1]).mean()
loss_test = reconst_loss_te + config['beta_kl'] * kl_div_te

#################################### COMPUTE METRICS ####################################################
mse_est = np.mean(pd.DataFrame(((catalogue[train_val_idx,:] - vhat) ** 2).detach().numpy()))
mse_test = np.mean(pd.DataFrame(((catalogue[test_idx,:] - vhat_te) ** 2).detach().numpy()))