# Colab Init

In [1]:
if 'google.colab' in str(get_ipython()):
    print('Running on Colab')
    from google.colab import drive
    drive.mount('/content/drive/', force_remount=True)
    %cd /content/drive/MyDrive/tcr-embedding/example/

In [2]:
if 'google.colab' in str(get_ipython()):
    !pip install comet-ml scanpy scirpy

# Config

In [3]:
from comet_ml import Experiment, ExistingExperiment
import scanpy as sc
import scirpy as ir
import pandas as pd
import torch
import yaml
import numpy as np
import matplotlib.pyplot as plt
from datetime import datetime

In [4]:
import sys
sys.path.append('../')
import tcr_embedding as tcr

In [5]:
CONFIG_NAME = 'bigru_paper_oriented'
CHOSEN_DATE = ''
EXPERIMENT_KEY = 
experiment_name = '10x_' + CONFIG_NAME + '_' + CHOSEN_DATE

SyntaxError: invalid syntax (<ipython-input-5-fcfb4ca845ab>, line 3)

In [None]:
experiment_name

In [None]:
%load_ext autoreload
%autoreload 2

# Load dataset

### 10x Dataset

In [None]:
adata = sc.read_h5ad('../data/10x_CD8TC/v5_train_val_test.h5ad')
adata

### Split data into train and val, filter out test set to keep it untouched

In [None]:
adata.obs['set'].value_counts() / len(adata)

In [None]:
adata = adata[adata.obs['set'] != 'test']
adata.obs['set'].value_counts() / len(adata)

In [None]:
train_adata = adata[adata.obs['set'] == 'train']
val_adata = adata[adata.obs['set'] == 'val']

# Initialize and train model

In [None]:
with open(f'../config/{CONFIG_NAME}.yaml') as file:
     params = yaml.load(file)
params

#### If Comet ML is not wanted, set experiment=None

In [None]:
with open('../comet_ml_key/API_key.txt') as f:
    COMET_ML_KEY = f.read()

experiment = ExistingExperiment(api_key=COMET_ML_KEY, previous_experiment=EXPERIMENT_KEY)

In [None]:
model = tcr.models.joint_model.JointModel(
    adatas=[adata],  # adatas containing gene expression and TCR-seq
    names=['10x'],
    aa_to_id = adata.uns['aa_to_id'],  # dict {aa_char: id}
    seq_model_arch=params['seq_model_arch'],  # seq model architecture
    seq_model_hyperparams=params['seq_model_hyperparams'],  # dict of seq model hyperparameters
    scRNA_model_arch=params['scRNA_model_arch'],
    scRNA_model_hyperparams=params['scRNA_model_hyperparams'],
    zdim=params['zdim'],  # zdim
    hdim=params['hdim'],  # hidden dimension of scRNA and seq encoders
    activation=params['activation'],  # activation function of autoencoder hidden layers
    dropout=params['dropout'],
    batch_norm=params['batch_norm'],
    shared_hidden=params['shared_hidden'],  # hidden layers of shared encoder / decoder
    gene_layers=[],  # [] or list of str for layer keys of each dataset
    seq_keys=[]  # [] or list of str for seq keys of each dataset
)

In [None]:
# print model architecture
model.model

# UMAP Plot of latent space

In [None]:
# List of antigens from David Fischer's paper, basically the 8 most common antigens
high_antigen_count = ['A0201_ELAGIGILTV_MART-1_Cancer_binder', 
                      'A0201_GILGFVFTL_Flu-MP_Influenza_binder', 
                      'A0201_GLCTLVAML_BMLF1_EBV_binder', 
                      'A0301_KLGGALQAK_IE-1_CMV_binder', 
                      'A0301_RLRAEAQVK_EMNA-3A_EBV_binder', 
                      'A1101_IVTDFSVIK_EBNA-3B_EBV_binder', 
                      'A1101_AVFDRKSDAK_EBNA-3B_EBV_binder', 
                      'B0801_RAKFKQLL_BZLF1_EBV_binder']

### On Val Data

Filter cells with no binding data and only UMAP on high count antigen bindings

In [None]:
val_adata = val_adata[val_adata.obs['has_binding']]

In [None]:
val_adata.obs['binding_name'].value_counts()

In [None]:
val_adata

Use last saved model

In [None]:
model.load(f'../saved_models/{experiment_name}_last_model.pt')

In [None]:
z = model.get_latent(
    adatas=[val_adata],
    names=['10x'],
    batch_size=256,
    num_workers=0,
    gene_layers=[],
    seq_keys=[],
    metadata=high_antigen_count + ['binding_name']
)

In [None]:
z

In [None]:
sc.pp.neighbors(z, use_rep='X')
sc.tl.umap(z)

In [None]:
for antigen in high_antigen_count:
    ax = sc.pl.umap(z, color=antigen, return_fig=True, alpha=0.3)
    ax.tight_layout()
    experiment.log_figure(figure_name=f'val_{antigen}', figure=ax, step=model.epoch, overwrite=False)
    ax.clf()
ax = sc.pl.umap(z[z.obs['binding_name'].isin(high_antigen_count)], color='binding_name', return_fig=True, alpha=0.4)
ax.set_size_inches(8, 4.8)
ax.tight_layout()
experiment.log_figure(figure_name=f'val_binding_name', figure=ax, step=model.epoch, overwrite=False)

Use "best" saved model (currently based on val_loss)

In [None]:
model.load(f'../saved_models/{experiment_name}_best_model.pt')

In [None]:
z = model.get_latent(
    adatas=[val_adata],
    names=['10x'],
    batch_size=256,
    num_workers=0,
    gene_layers=[],
    seq_keys=[],
    metadata=high_antigen_count + ['binding_name']
)

In [None]:
sc.pp.neighbors(z, use_rep='X')
sc.tl.umap(z)

In [None]:
for antigen in high_antigen_count:
    ax = sc.pl.umap(z, color=antigen, return_fig=True, alpha=0.3)
    ax.tight_layout()
    experiment.log_figure(figure_name=f'val_{antigen}', figure=ax, step=model.epoch, overwrite=False)
    ax.clf()
ax = sc.pl.umap(z[z.obs['binding_name'].isin(high_antigen_count)], color='binding_name', return_fig=True, alpha=0.4)
ax.set_size_inches(8, 4.8)
ax.tight_layout()
experiment.log_figure(figure_name=f'val_binding_name', figure=ax, step=model.epoch, overwrite=False)

### On Train Data

Filter cells with no binding data and only UMAP on high count antigen bindings

In [None]:
train_adata = train_adata[train_adata.obs['has_binding']]

In [None]:
train_adata.obs['binding_name'].value_counts()

In [None]:
train_adata

Use last saved model

In [None]:
model.load(f'../saved_models/{experiment_name}_last_model.pt')

In [None]:
z = model.get_latent(
    adatas=[train_adata],
    names=['10x'],
    batch_size=256,
    num_workers=0,
    gene_layers=[],
    seq_keys=[],
    metadata=high_antigen_count + ['binding_name']
)

In [None]:
z

In [None]:
sc.pp.neighbors(z, use_rep='X')
sc.tl.umap(z)

In [None]:
for antigen in high_antigen_count:
    ax = sc.pl.umap(z, color=antigen, return_fig=True, alpha=0.3)
    ax.tight_layout()
    experiment.log_figure(figure_name=f'train_{antigen}', figure=ax, step=model.epoch, overwrite=False)
    ax.clf()
ax = sc.pl.umap(z[z.obs['binding_name'].isin(high_antigen_count)], color='binding_name', return_fig=True, alpha=0.4)
ax.set_size_inches(8, 4.8)
ax.tight_layout()
experiment.log_figure(figure_name=f'train_binding_name', figure=ax, step=model.epoch, overwrite=False)

Use "best" saved model, (currently based on val loss)

In [None]:
model.load(f'../saved_models/{experiment_name}_best_model.pt')

In [None]:
z = model.get_latent(
    adatas=[train_adata],
    names=['10x'],
    batch_size=256,
    num_workers=0,
    gene_layers=[],
    seq_keys=[],
    metadata=high_antigen_count + ['binding_name']
)

In [None]:
sc.pp.neighbors(z, use_rep='X')
sc.tl.umap(z)

In [None]:
for antigen in high_antigen_count:
    ax = sc.pl.umap(z, color=antigen, return_fig=True, alpha=0.3)
    ax.tight_layout()
    experiment.log_figure(figure_name=f'train_{antigen}', figure=ax, step=model.epoch, overwrite=False)
    ax.clf()
ax = sc.pl.umap(z[z.obs['binding_name'].isin(high_antigen_count)], color='binding_name', return_fig=True, alpha=0.4)
ax.set_size_inches(8, 4.8)
ax.tight_layout()
experiment.log_figure(figure_name=f'train_binding_name', figure=ax, step=model.epoch, overwrite=False)

# kNN prediction

In [None]:
model.load(f'../saved_models/{experiment_name}_best_model.pt')

In [None]:
train_adata.obs['binding_name'].value_counts()

In [None]:
val_adata.obs['binding_name'].value_counts()

In [None]:
z_train = model.get_latent(
    adatas=[train_adata],
    names=['10x'],
    batch_size=256,
    num_workers=0,
    gene_layers=[],
    seq_keys=[],
    metadata=['binding_name', 'binding_label']
)

In [None]:
z_val = model.get_latent(
    adatas=[val_adata],
    names=['10x'],
    batch_size=256,
    num_workers=0,
    gene_layers=[],
    seq_keys=[],
    metadata=['binding_name', 'binding_label']
)

Filter out cells that have rare antigen specificity

In [None]:
classes = 'binding_name'
z_train = z_train[z_train.obs[classes].isin(high_antigen_count)]
z_train.obs[classes].value_counts()

In [None]:
z_val = z_val[z_val.obs[classes].isin(high_antigen_count)]
z_val.obs[classes].value_counts()

In [None]:
model.kNN(z_train, z_val, classes, 5, 'distance')

In [None]:
z_val = z_val[z_val.obs[classes].isin(high_antigen_count) | z_val.obs['pred_'+classes].isin(high_antigen_count)]
z_val.shape

In [None]:
from sklearn.metrics import classification_report
print(classification_report(z_val.obs[classes], z_val.obs['pred_'+classes]))
experiment.log_text(text=classification_report(z_val.obs[classes], z_val.obs['pred_'+classes]), step=model.epoch)

In [None]:
metrics = classification_report(z_val.obs[classes], z_val.obs['pred_'+classes], output_dict=True)
for antigen, metric in metrics.items():
    if antigen != 'accuracy':
        experiment.log_metrics(metric, prefix=antigen, step=model.epoch, epoch=model.epoch)
    else:
        experiment.log_metric('accuracy', metric, step=model.epoch, epoch=model.epoch)

In [None]:
experiment.log_confusion_matrix(matrix=pd.crosstab(z_val.obs[classes], z_val.obs['pred_'+classes]).values,
                                labels=pd.crosstab(z_val.obs[classes], z_val.obs['pred_'+classes]).index.to_list(),
                                step=model.epoch, epoch=model.epoch)

In [None]:
pd.crosstab(z_val.obs[classes], z_val.obs['pred_'+classes])

In [None]:
experiment.end()