# Colab Init

In [None]:
# On Colab only: mount Google Drive and cd into the repo's example/ folder, so the
# relative '../data', '../config' and '../saved_models' paths used below resolve
# inside the tcr-embedding checkout on Drive.
if 'google.colab' in str(get_ipython()):
    print('Running on Colab')
    from google.colab import drive
    drive.mount('/content/drive/', force_remount=True)
    %cd /content/drive/MyDrive/tcr-embedding/example/

In [None]:
if 'google.colab' in str(get_ipython()):
    !pip install comet-ml scanpy scirpy

# Config

In [None]:
from comet_ml import Experiment
import scanpy as sc
import scirpy as ir
import pandas as pd
import torch
import yaml
import numpy as np
import matplotlib.pyplot as plt
from datetime import datetime

In [None]:
import sys

# Put the parent directory (the repository root when run from example/) on the
# import path so the local tcr_embedding package can be imported.
sys.path.extend(['../'])
import tcr_embedding as tcr

In [None]:
CONFIG_NAME = 'bigru_paper_oriented'
# Minute-resolution timestamp (YYYYMMDD-HH.MM) distinguishes runs of the same config;
# the resulting name is also the prefix of this run's checkpoint files.
current_datetime = datetime.now().strftime("%Y%m%d-%H.%M")
experiment_name = f'10x_{CONFIG_NAME}_{current_datetime}'

In [None]:
# Show the run name; checkpoints are later loaded from ../saved_models/<experiment_name>_*.pt.
experiment_name

In [None]:
# Automatically reload edited modules (e.g. tcr_embedding) before each cell executes.
%load_ext autoreload
%autoreload 2

# Load dataset

### 10x Dataset

In [None]:
# Load the 10x dataset; obs['set'] holds each cell's split label ('train' / 'val' / 'test').
adata = sc.read_h5ad('../data/10x_CD8TC/v5_train_val_test.h5ad')
adata

### Filter out the test set to keep it untouched, then split the rest into train and val

In [None]:
# Fraction of all cells in each split.
adata.obs['set'].value_counts() / len(adata)

In [None]:
# Drop the test split so it stays untouched during development, then show the
# remaining train/val fractions. Note: adata is now an AnnData view of the original.
adata = adata[adata.obs['set'] != 'test']
adata.obs['set'].value_counts() / len(adata)

In [None]:
# Per-split views used for the evaluation cells; the model itself receives the
# combined adata and splits it via obs['set'] (val_split='set' in model.train).
train_adata, val_adata = (adata[adata.obs['set'] == split] for split in ('train', 'val'))

# Initialize and train model

In [None]:
# Load this run's hyperparameters from ../config/<CONFIG_NAME>.yaml and display them.
# yaml.safe_load: calling yaml.load() without an explicit Loader is deprecated since
# PyYAML 5.1 and raises TypeError on PyYAML >= 6.0; safe_load also refuses arbitrary
# Python object tags, which a plain hyperparameter file never needs.
with open(f'../config/{CONFIG_NAME}.yaml', encoding='utf-8') as file:
    params = yaml.safe_load(file)
params

#### If Comet ML is not wanted, set `experiment = None` (the later cells that call `experiment.log_*` or `experiment.end()` must then be skipped)

In [None]:
# Read the Comet API key from a local file that is kept out of the notebook.
# strip() removes a trailing newline (or other surrounding whitespace) the file may
# end with; otherwise it would become part of the key sent to Comet.
with open('../comet_ml_key/API_key.txt', encoding='utf-8') as f:
    COMET_ML_KEY = f.read().strip()

# Create the Comet experiment and record the config plus the run name.
experiment = Experiment(api_key=COMET_ML_KEY, workspace='tcr', project_name='10x_GRU')
experiment.log_parameters(params)
experiment.log_parameter('experiment_name', experiment_name)

In [None]:
# Build the joint scRNA + TCR model. Dataset-specific arguments are given here;
# the architecture and size hyperparameters are forwarded unchanged from the YAML config.
model = tcr.models.joint_model.JointModel(
    adatas=[adata],  # adatas containing gene expression and TCR-seq
    names=['10x'],
    aa_to_id=adata.uns['aa_to_id'],  # dict {aa_char: id}
    gene_layers=[],  # [] or list of str for layer keys of each dataset
    seq_keys=[],  # [] or list of str for seq keys of each dataset
    **{
        key: params[key]
        for key in (
            'seq_model_arch',           # seq model architecture
            'seq_model_hyperparams',    # dict of seq model hyperparameters
            'scRNA_model_arch',
            'scRNA_model_hyperparams',
            'zdim',                     # latent dimension
            'hdim',                     # hidden dimension of scRNA and seq encoders
            'activation',               # activation function of autoencoder hidden layers
            'dropout',
            'batch_norm',
            'shared_hidden',            # hidden layers of shared encoder / decoder
        )
    },
)

In [None]:
# Print the model architecture: model.model is the network wrapped by JointModel
# (presumably a torch nn.Module, whose repr lists its layers — confirm).
model.model

In [None]:
# Train for a fixed 300 epochs. Optimisation settings (batch size, lr, losses,
# loss weights) come from the YAML config; the schedule and logging are set here.
model.train(
    experiment_name=experiment_name,
    # optimisation (from config)
    batch_size=params['batch_size'],
    lr=params['lr'],
    losses=params['losses'],  # list of losses for each modality: losses[0] := scRNA, losses[1] := TCR
    loss_weights=params['loss_weights'],  # [] or list of floats storing weighting of loss in order [scRNA, TCR, KLD]
    # schedule
    n_iters=None,
    n_epochs=300,
    validate_every=5,
    print_every=5,
    save_every=25,
    # data
    val_split='set',  # float or str, if float: split is determined automatically, if str: used as key for train-val column
    metadata=['tcr_seq'],
    num_workers=0,
    # runtime / logging
    verbose=1,  # 0: only tqdm progress bar, 1: val loss, 2: train and val loss
    device=None,
    comet=experiment,
    # continue_training=True,  # optional flag, presumably resumes from a saved checkpoint
)

In [None]:
# Validation history recorded during training (has 'epoch', 'loss', 'scRNA_loss',
# 'TCR_loss' and 'KLD_loss' entries, which are plotted as the 'Val' curves).
model.history

In [None]:
# Training history; it has the same keys as model.history (plotted as the 'Train' curves).
model.train_history

In [None]:
# 2x2 grid of train vs. validation curves: total loss and its scRNA, TCR and KLD parts.
# Every panel gets the '#Epochs' x-label; previously the KLD panel had a truncated
# '#' label and the top panels had none. plt comes from the imports cell.
loss_panels = [
    ('loss', ''),
    ('scRNA_loss', 'scRNA '),
    ('TCR_loss', 'TCR '),
    ('KLD_loss', 'KLD '),
]
plt.figure(figsize=(15, 10))
for panel_idx, (loss_key, loss_name) in enumerate(loss_panels, start=1):
    plt.subplot(2, 2, panel_idx)
    plt.plot(model.train_history['epoch'], model.train_history[loss_key], '.-', label=f'Train {loss_name}loss')
    plt.plot(model.history['epoch'], model.history[loss_key], '.-', label=f'Val {loss_name}loss')
    plt.xlabel('#Epochs')
    plt.legend()
plt.tight_layout()


# UMAP Plot of latent space

In [None]:
# The 8 most common antigen bindings, taken from David Fischer's paper.
# Each name must match a binding column / binding_name value in the data exactly,
# so they are kept verbatim (one per line, in this order).
high_antigen_count = """
A0201_ELAGIGILTV_MART-1_Cancer_binder
A0201_GILGFVFTL_Flu-MP_Influenza_binder
A0201_GLCTLVAML_BMLF1_EBV_binder
A0301_KLGGALQAK_IE-1_CMV_binder
A0301_RLRAEAQVK_EMNA-3A_EBV_binder
A1101_IVTDFSVIK_EBNA-3B_EBV_binder
A1101_AVFDRKSDAK_EBNA-3B_EBV_binder
B0801_RAKFKQLL_BZLF1_EBV_binder
""".split()

### On Val Data

Filter out cells without binding data; the UMAPs are then colored by each high-count antigen binding (plus one overall `binding_name` plot).

In [None]:
# Keep only validation cells that have binding data (obs['has_binding'] is used as a boolean mask).
val_adata = val_adata[val_adata.obs['has_binding']]

In [None]:
# Number of filtered validation cells per binding class.
val_adata.obs['binding_name'].value_counts()

In [None]:
# Show the filtered validation AnnData.
val_adata

Use last saved model

In [None]:
# Load this run's '_last_model.pt' checkpoint (presumably the weights from the final epoch).
model.load(f'../saved_models/{experiment_name}_last_model.pt')

In [None]:
# Embed the validation cells. The antigen columns and binding_name are requested as
# metadata so the UMAP cells can color z by them.
z = model.get_latent(
    adatas=[val_adata],
    names=['10x'],
    metadata=[*high_antigen_count, 'binding_name'],
    gene_layers=[],
    seq_keys=[],
    batch_size=256,
    num_workers=0,
)

In [None]:
# Show the latent AnnData returned by get_latent.
z

In [None]:
# Build the kNN graph on the latent embedding stored in z.X, then compute a UMAP from it.
sc.pp.neighbors(z, use_rep='X')
sc.tl.umap(z)

In [None]:
# Log to Comet one validation UMAP per high-count antigen, plus one colored by binding_name.
# sc.pl.umap(..., return_fig=True) returns a Figure that stays registered with pyplot;
# clf() only emptied it, so figures piled up. plt.close() releases each one after logging.
for antigen in high_antigen_count:
    fig = sc.pl.umap(z, color=antigen, return_fig=True, alpha=0.3)
    experiment.log_figure(figure_name=f'val_{antigen}', figure=fig, step=model.epoch, overwrite=False)
    plt.close(fig)
fig = sc.pl.umap(z, color='binding_name', return_fig=True, alpha=0.2)
experiment.log_figure(figure_name='val_binding_name', figure=fig, step=model.epoch, overwrite=False)
plt.close(fig)

Use "best" saved model (currently based on val_loss)

In [None]:
# Load this run's '_best_model.pt' checkpoint (presumably selected by validation loss — confirm in JointModel.train).
model.load(f'../saved_models/{experiment_name}_best_model.pt')

In [None]:
# Re-embed the validation cells with the currently loaded checkpoint, keeping the
# antigen columns and binding_name as metadata for coloring the UMAPs.
z = model.get_latent(
    adatas=[val_adata],
    names=['10x'],
    metadata=[*high_antigen_count, 'binding_name'],
    gene_layers=[],
    seq_keys=[],
    batch_size=256,
    num_workers=0,
)

In [None]:
# Build the kNN graph on the latent embedding stored in z.X, then compute a UMAP from it.
sc.pp.neighbors(z, use_rep='X')
sc.tl.umap(z)

In [None]:
# Log to Comet one validation UMAP per high-count antigen, plus one colored by binding_name.
# sc.pl.umap(..., return_fig=True) returns a Figure that stays registered with pyplot;
# clf() only emptied it, so figures piled up. plt.close() releases each one after logging.
# NOTE(review): these figure names are the same ones used for the last-model UMAPs, so the
# two sets are told apart only by step=model.epoch — confirm load() updates model.epoch.
for antigen in high_antigen_count:
    fig = sc.pl.umap(z, color=antigen, return_fig=True, alpha=0.2)
    experiment.log_figure(figure_name=f'val_{antigen}', figure=fig, step=model.epoch, overwrite=False)
    plt.close(fig)
fig = sc.pl.umap(z, color='binding_name', return_fig=True, alpha=0.2)
experiment.log_figure(figure_name='val_binding_name', figure=fig, step=model.epoch, overwrite=False)
plt.close(fig)

### On Train Data

Filter out cells without binding data; the UMAPs are then colored by each high-count antigen binding (plus one overall `binding_name` plot).

In [None]:
# Keep only training cells that have binding data (obs['has_binding'] is used as a boolean mask).
train_adata = train_adata[train_adata.obs['has_binding']]

In [None]:
# Number of filtered training cells per binding class.
train_adata.obs['binding_name'].value_counts()

In [None]:
# Show the filtered training AnnData.
train_adata

Use last saved model

In [None]:
# Load this run's '_last_model.pt' checkpoint again (presumably the weights from the final epoch).
model.load(f'../saved_models/{experiment_name}_last_model.pt')

In [None]:
# Embed the training cells. The antigen columns and binding_name are requested as
# metadata so the UMAP cells can color z by them.
z = model.get_latent(
    adatas=[train_adata],
    names=['10x'],
    metadata=[*high_antigen_count, 'binding_name'],
    gene_layers=[],
    seq_keys=[],
    batch_size=256,
    num_workers=0,
)

In [None]:
# Show the latent AnnData of the training cells.
z

In [None]:
# Build the kNN graph on the latent embedding stored in z.X, then compute a UMAP from it.
sc.pp.neighbors(z, use_rep='X')
sc.tl.umap(z)

In [None]:
# Log to Comet one training UMAP per high-count antigen, plus one colored by binding_name.
# sc.pl.umap(..., return_fig=True) returns a Figure that stays registered with pyplot;
# clf() only emptied it, so figures piled up. plt.close() releases each one after logging.
for antigen in high_antigen_count:
    fig = sc.pl.umap(z, color=antigen, return_fig=True, alpha=0.3)
    experiment.log_figure(figure_name=f'train_{antigen}', figure=fig, step=model.epoch, overwrite=False)
    plt.close(fig)
fig = sc.pl.umap(z, color='binding_name', return_fig=True, alpha=0.2)
experiment.log_figure(figure_name='train_binding_name', figure=fig, step=model.epoch, overwrite=False)
plt.close(fig)

Use the "best" saved model (currently selected by validation loss)

In [None]:
# Load this run's '_best_model.pt' checkpoint; it stays loaded for the kNN cells that follow.
model.load(f'../saved_models/{experiment_name}_best_model.pt')

In [None]:
# Re-embed the training cells with the currently loaded checkpoint, keeping the
# antigen columns and binding_name as metadata for coloring the UMAPs.
z = model.get_latent(
    adatas=[train_adata],
    names=['10x'],
    metadata=[*high_antigen_count, 'binding_name'],
    gene_layers=[],
    seq_keys=[],
    batch_size=256,
    num_workers=0,
)

In [None]:
# Build the kNN graph on the latent embedding stored in z.X, then compute a UMAP from it.
sc.pp.neighbors(z, use_rep='X')
sc.tl.umap(z)

In [None]:
# Log to Comet one training UMAP per high-count antigen, plus one colored by binding_name.
# sc.pl.umap(..., return_fig=True) returns a Figure that stays registered with pyplot;
# clf() only emptied it, so figures piled up. plt.close() releases each one after logging.
# NOTE(review): these figure names are the same ones used for the last-model UMAPs, so the
# two sets are told apart only by step=model.epoch — confirm load() updates model.epoch.
for antigen in high_antigen_count:
    fig = sc.pl.umap(z, color=antigen, return_fig=True, alpha=0.2)
    experiment.log_figure(figure_name=f'train_{antigen}', figure=fig, step=model.epoch, overwrite=False)
    plt.close(fig)
fig = sc.pl.umap(z, color='binding_name', return_fig=True, alpha=0.2)
experiment.log_figure(figure_name='train_binding_name', figure=fig, step=model.epoch, overwrite=False)
plt.close(fig)

# kNN prediction

In [None]:
# Class balance of the kNN reference set (filtered training cells).
train_adata.obs['binding_name'].value_counts()

In [None]:
# Class balance of the kNN query set (filtered validation cells).
val_adata.obs['binding_name'].value_counts()

In [None]:
# Latent representation of the training cells, carrying the binding labels the kNN
# classifier uses. The embedding comes from whichever checkpoint was loaded last
# (the best model when the cells are run in order).
z_train = model.get_latent(
    adatas=[train_adata],
    names=['10x'],
    metadata=['binding_name', 'binding_label'],
    gene_layers=[],
    seq_keys=[],
    batch_size=256,
    num_workers=0,
)

In [None]:
# Latent representation of the validation cells (kNN queries), with the true
# binding labels kept for scoring the predictions.
z_val = model.get_latent(
    adatas=[val_adata],
    names=['10x'],
    metadata=['binding_name', 'binding_label'],
    gene_layers=[],
    seq_keys=[],
    batch_size=256,
    num_workers=0,
)

In [None]:
# Predict each validation cell's binding_name from its training-set neighbours in latent space.
# NOTE(review): the positional 5 and 'distance' presumably mean k=5 neighbours with
# distance-weighted voting — confirm against JointModel.kNN.
# The predictions are read afterwards from z_val.obs['pred_' + classes].
classes = 'binding_name'
model.kNN(z_train, z_val, classes, 5, 'distance')

In [None]:
# Restrict the evaluation to cells whose true label OR predicted label is one of the
# high-count antigens, then show the resulting shape.
z_val = z_val[z_val.obs[[classes, 'pred_' + classes]].isin(high_antigen_count).any(axis=1)]
z_val.shape

In [None]:
# How many kept cells are predicted as the A0201 FLYALALLL (LMP2A) binder. This class is
# not in high_antigen_count, so after the filter above every such cell has a high-count
# antigen as its true label, i.e. it is a misclassification.
(z_val.obs['pred_'+classes] == 'A0201_FLYALALLL_LMP2A_EBV_binder').sum()

In [None]:
from sklearn.metrics import classification_report

# Per-class precision / recall / F1 of the kNN predictions. Computed once, then both
# printed and logged to Comet (previously the report was computed twice).
report = classification_report(z_val.obs[classes], z_val.obs['pred_' + classes])
print(report)
experiment.log_text(text=report, step=model.epoch)

In [None]:
# Confusion matrix (rows: true binding_name, columns: predicted). Computed once, then
# logged to Comet as a CSV table and displayed (previously it was computed twice).
confusion_df = pd.crosstab(z_val.obs[classes], z_val.obs['pred_' + classes])
experiment.log_table('confusion_matrix.csv', confusion_df)
confusion_df

In [None]:
# Mark the Comet experiment as finished; no further experiment.log_* calls follow.
experiment.end()