# Training the model for Analysis
In this notebook, we explain how to train the model for dataset analysis. Here, we optimize the model to preserve multiple annotations at the same time (e.g. clonotype and cell type). You can control the influence of both modalities (TCR via clonotype, GEX via cell types) by specifying a weight for each annotation. Finding a mixture suitable for your analysis may require retraining with a few different weight values.

In [1]:
# comet-ml must be imported before torch and sklearn
import comet_ml
import scanpy as sc

import os
import sys
sys.path.append('..')

## Data Preparation
First, we load the data via the Scanpy API.

In [2]:
path_data = '../data/Haniffa/haniffa_test.h5ad'
adata = sc.read(path_data)

Then we divide the data into training and validation sets. Typically, 20% of the data is held out for validation. The split is grouped by 'clonotype', so that each clonotype is assigned entirely to either the training or the validation set.

In [3]:
import tcr_embedding.utils_training as utils
from tcr_embedding.utils_preprocessing import group_shuffle_split
utils.fix_seeds(42)

train, val = group_shuffle_split(adata, group_col='clonotype', val_split=0.20, random_seed=42)
adata.obs['set'] = 'train'
adata.obs.loc[val.obs.index, 'set'] = 'val'
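
To confirm that the split keeps clonotypes disjoint, one can compare the clonotype IDs of both sets. The sketch below uses a small hypothetical pandas DataFrame as a stand-in for `adata.obs`; the column names `clonotype` and `set` follow the cell above, while the values and the helper `shared_clonotypes` are made up for illustration.

```python
import pandas as pd

# Hypothetical stand-in for adata.obs; in the notebook, pass adata.obs instead.
obs = pd.DataFrame({
    'clonotype': ['c1', 'c1', 'c2', 'c3', 'c3', 'c4'],
    'set':       ['train', 'train', 'val', 'train', 'train', 'val'],
})


def shared_clonotypes(obs, group_col='clonotype', set_col='set'):
    """Return the groups that occur in both the training and the validation set."""
    train_groups = set(obs.loc[obs[set_col] == 'train', group_col])
    val_groups = set(obs.loc[obs[set_col] == 'val', group_col])
    return train_groups & val_groups


overlap = shared_clonotypes(obs)
# Fraction of cells (rows) that ended up in the validation set.
val_fraction = (obs['set'] == 'val').mean()
print(f'shared clonotypes: {sorted(overlap)}, validation fraction: {val_fraction:.2f}')
```

An empty overlap means no clonotype leaks from training into validation. Because whole clonotypes are assigned to one set, the fraction of cells in the validation set may deviate somewhat from the requested 20%.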

## Defining the model parameters
We need to provide the model with a couple of parameters:
- study_name: name used for logging
- comet_workspace: we logged some of the experiments via Comet-ML. This gives you more information on the training process, but is not required. We will therefore use None here. Otherwise, specify the workspace name of your Comet-ML project.
- model_name: selects which model is used (one of rna, tcr, moe, poe, concat). We will use the best-performing model, moe.
- balanced_sampling: oversample rare elements of this column. We recommend using a column that stores the clonotype to avoid overfitting.
- metadata: annotations used to color the UMAPs when storing intermediate results on Comet-ML. If Comet-ML is not used, pass an empty list.
- save_path: path to store the trained models over multiple training runs
- conditional: name of a conditional variable (see preprocessing). The model partially removes batch effects over this column.
- n_epochs: number of epochs to train the model. For the paper we used 200 epochs. For this showcase, however, we reduce it to 5.

In [4]:
params_experiment = {
    'study_name': 'haniffa_tutorial',
    'comet_workspace': None,
    'model_name': 'moe',
    'balanced_sampling': 'clonotype',
    'metadata': [],
    'save_path': '../saved_models/haniffa_tutorial',
    'conditional': 'patient_id',
    'n_epochs': 5,
}
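
Before starting a long training run, it can help to sanity-check the parameter dictionary. The following is a minimal sketch, not part of the library: the helper `find_problems` and the list `obs_columns` are hypothetical, the key names are the ones used in the cell above, and the model names are those listed in the parameter description. In the notebook, `obs_columns` would be `adata.obs.columns`.

```python
# Keys as used in the cell above; whether the library itself validates them is not covered here.
EXPECTED_KEYS = {
    'study_name', 'comet_workspace', 'model_name', 'balanced_sampling',
    'metadata', 'save_path', 'conditional', 'n_epochs',
}
MODEL_NAMES = {'rna', 'tcr', 'moe', 'poe', 'concat'}


def find_problems(params, obs_columns):
    """Return a list of human-readable problems; the list is empty if none are found."""
    problems = []
    for key in sorted(EXPECTED_KEYS - set(params)):
        problems.append(f'missing key: {key}')
    if params.get('model_name') not in MODEL_NAMES:
        problems.append(f"unknown model_name: {params.get('model_name')}")
    # These two entries name columns of adata.obs, so the columns should exist.
    for key in ('balanced_sampling', 'conditional'):
        column = params.get(key)
        if column is not None and column not in obs_columns:
            problems.append(f'{key} refers to a missing column: {column}')
    return problems


params = {
    'study_name': 'haniffa_tutorial',
    'comet_workspace': None,
    'model_name': 'moe',
    'balanced_sampling': 'clonotype',
    'metadata': [],
    'save_path': '../saved_models/haniffa_tutorial',
    'conditional': 'patient_id',
    'n_epochs': 5,
}
# Hypothetical column list standing in for adata.obs.columns.
obs_columns = ['clonotype', 'patient_id', 'full_clustering']
problems = find_problems(params, obs_columns)
print(problems or 'no problems found')
```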


## Defining Optimization parameters
For this analysis, we optimize the model to preserve clonotype and cell type (column 'full_clustering'). This optimization mode is called 'pseudo_metric'. By specifying a weight for each label, we choose the balance between both modalities. In the cell below, clonotype receives weight 1 and cell type receives weight 5.

In [5]:
params_optimization = {
    'name': 'pseudo_metric',
    'prediction_labels':
        {'clonotype': 1,
         'full_clustering': 5}
}
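
Since a suitable mixture may only be found by retraining with several weights, the optimization dictionaries for such a comparison can be generated programmatically. This is a minimal sketch: the helper `make_optimization_params` and the candidate weights are hypothetical, and the printed share is only an aid for comparing settings. How the library combines the weights internally is not shown here.

```python
def make_optimization_params(weight_clonotype, weight_celltype):
    """Build a 'pseudo_metric' configuration with the given label weights."""
    return {
        'name': 'pseudo_metric',
        'prediction_labels': {
            'clonotype': weight_clonotype,
            'full_clustering': weight_celltype,
        },
    }


# Hypothetical cell type weights to compare; the clonotype weight stays at 1.
celltype_weights = [1, 5, 10]
sweep = [make_optimization_params(1, weight) for weight in celltype_weights]

for params in sweep:
    labels = params['prediction_labels']
    share = labels['full_clustering'] / sum(labels.values())
    print(f"cell type weight {labels['full_clustering']}: relative share {share:.2f}")
```

Each dictionary can then be used in place of `params_optimization`. Giving every run its own `study_name` and `save_path` is probably sensible, so that the runs stay separate.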

## Calling the training functions
Finally, we need to specify a couple of parameters for running the training. Training stops either after \<timeout\> seconds or after 3 models have been trained, using 1 available GPU.

In [6]:
from tcr_embedding.models.model_selection import run_model_selection

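timeout = 20 * 60  # stop after 20 minutes (value in seconds)
n_samples = 3      # number of models (trials) to train
n_gpus = 1         # number of available GPUs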
run_model_selection(adata, params_experiment, params_optimization, n_samples, timeout, n_gpus)

[I 2022-06-01 14:47:57,135] A new study created in RDB with name: haniffa_tutorial
100%|████████████████████████████████████████████████████████████████████████████████████| 5/5 [00:58<00:00, 11.63s/it]
[I 2022-06-01 14:48:58,469] Trial 0 finished with value: 1.4392686179739715 and parameters: {'dropout': 0.1, 'activation': 'linear', 'rna_hidden': 1500, 'hdim': 200, 'shared_hidden': 100, 'rna_num_layers': 1, 'tfmr_encoding_layers': 4, 'loss_weights_kl': 4.0428727350273357e-07, 'loss_weights_tcr': 0.034702669886504146, 'lr': 1.0994335574766187e-05, 'zdim': 50, 'tfmr_embedding_size': 16, 'tfmr_num_heads': 8, 'tfmr_dropout': 0.15000000000000002}. Best is trial 0 with value: 1.4392686179739715.
100%|████████████████████████████████████████████████████████████████████████████████████| 5/5 [00:24<00:00,  4.94s/it]
[I 2022-06-01 14:49:25,234] Trial 1 finished with value: 1.4204253402130869 and parameters: {'dropout': 0.1, 'activation': 'linear', 'rna_hidden'

RuntimeError: CUDA out of memory. Tried to allocate 120.00 MiB (GPU 0; 4.00 GiB total capacity; 2.38 GiB already allocated; 0 bytes free; 2.53 GiB reserved in total by PyTorch)

## Output
The console output reports the best model found during hyperparameter optimization. We now load this model and embed our data with it. Afterwards, we can continue with standard analysis.
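
When many trials are run, the best trial can also be picked out of the log programmatically. The sketch below parses abbreviated copies of the two "Trial ... finished" lines from the output above, with the parameter dictionaries elided. The helper `parse_trials` is hypothetical, and the path pattern assumes that every trial is stored like `trial_0` in this notebook.

```python
import re

# Abbreviated copies of the log lines above; the parameter dictionaries are elided.
log_lines = [
    'Trial 0 finished with value: 1.4392686179739715 and parameters: {...}',
    'Trial 1 finished with value: 1.4204253402130869 and parameters: {...}',
]

TRIAL_PATTERN = re.compile(r'Trial (\d+) finished with value: ([-+0-9.eE]+)')


def parse_trials(lines):
    """Map trial number to reported value for every matching log line."""
    trials = {}
    for line in lines:
        match = TRIAL_PATTERN.search(line)
        if match:
            trials[int(match.group(1))] = float(match.group(2))
    return trials


trials = parse_trials(log_lines)
# The log reports trial 0 (the larger value) as best, so higher is treated as better here.
best_trial = max(trials, key=trials.get)
# Assumption: each trial is saved in a folder named like the trial_0 folder used in this notebook.
path_model = f'../saved_models/haniffa_tutorial/trial_{best_trial}/best_model_by_metric.pt'
print(best_trial, path_model)
```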

In [None]:
path_model = '../saved_models/haniffa_tutorial/trial_0/best_model_by_metric.pt'
model = utils.load_model(adata, path_model)

In [None]:
# Embed the full dataset with the model loaded above
latent_moe = model.get_latent(adata, metadata=[], return_mean=True)
latent_moe.obs = adata.obs.copy()

In [None]:
sc.pp.neighbors(latent_moe, use_rep='X')
sc.tl.umap(latent_moe)
sc.pl.umap(latent_moe, color='full_clustering')