In [1]:
import warnings
# NOTE(review): with no `category=`/`module=` argument this silences *every* warning for the
# whole kernel session (deprecations, numerical warnings, etc. from all libraries).
# Consider narrowing it once the noisy warning source is known.
warnings.filterwarnings('ignore')  # "error", "ignore", "always", "default", "module" or "once"

import scanpy as sc
import yaml
import sys

# Make the parent directory importable so `tcr_embedding` below can be found.
# A relative sys.path entry is resolved against the kernel's current working
# directory (not this notebook's location), so the notebook must be started
# from a direct subfolder of the directory containing `tcr_embedding`.
sys.path.append('../')
import tcr_embedding as tcr  # tune needs to reload this module (NOTE(review): no tune usage is visible in this notebook — confirm this is still relevant)
from tcr_embedding.evaluation.Imputation import run_imputation_evaluation
from tcr_embedding.evaluation.WrapperFunctions import get_model_prediction_function

In [2]:
# AnnData object used both to build/train the model and to evaluate it below.
ADATA_PATH = '../data/10x_CD8TC/v6_supervised.h5ad'
adata = sc.read_h5ad(ADATA_PATH)

In [4]:
# Load the model / training hyperparameters from the YAML config.
# The `with` block closes the file handle once parsing is done (the previous bare
# `open()` call leaked it). An explicit UTF-8 encoding makes decoding independent
# of the platform's locale default, matching YAML's default encoding.
# FullLoader is kept for compatibility with the original config parsing.
with open('../config/transformer.yaml', encoding='utf-8') as config_file:
    params = yaml.load(config_file, Loader=yaml.FullLoader)

#### Choose between the PoE and MoE model: keep exactly one of the two lines in the next cell uncommented

In [5]:
# Model class to instantiate below. MoEModel is currently active; to use PoEModel
# instead, uncomment the PoE line and comment out the MoE line.
# init_model = tcr.models.poe.PoEModel
init_model = tcr.models.moe.MoEModel

In [6]:
# Architecture hyperparameters forwarded unchanged from the YAML config, in the
# order they are passed to the model constructor.
model_config_keys = [
    'seq_model_arch',           # TCR sequence model architecture
    'seq_model_hyperparams',    # dict of sequence-model hyperparameters
    'scRNA_model_arch',         # gene-expression model architecture
    'scRNA_model_hyperparams',  # dict of gene-expression model hyperparameters
    'zdim',                     # latent dimension
    'hdim',                     # hidden dimension of the scRNA and sequence encoders
    'activation',               # activation function of the autoencoder hidden layers
    'dropout',
    'batch_norm',
    'shared_hidden',            # hidden layers of the shared encoder / decoder
]

model = init_model(
    adatas=[adata],                  # AnnData objects holding gene expression and TCR-seq
    names=['10x'],
    aa_to_id=adata.uns['aa_to_id'],  # mapping {amino-acid character: integer id}
    **{key: params[key] for key in model_config_keys},
    gene_layers=[],                  # [] or one layer key (str) per dataset
    seq_keys=[],                     # [] or one sequence key (str) per dataset
)

In [7]:
# Training run configuration. Optimisation hyperparameters come from the YAML
# config; everything else is a per-run choice for this experiment.
train_settings = dict(
    experiment_name='test',
    n_iters=None,
    n_epochs=10,
    batch_size=params['batch_size'],
    lr=params['lr'],
    losses=params['losses'],              # one loss per modality: losses[0] is scRNA, losses[1] is TCR
    loss_weights=params['loss_weights'],  # [] or floats weighting the losses in order [scRNA, TCR, KLD]
    kl_annealing_epochs=None,
    val_split='set',                      # float: split chosen automatically; str: key of the train/val column
    metadata=['clonotype'],
    early_stop=100,                       # NOTE(review): if this counts epochs it can never trigger with n_epochs=10
    balanced_sampling=None,               # None or 'clonotype'
    validate_every=5,
    save_every=1000,
    save_path='saved_models_delete',
    num_workers=0,
    verbose=0,                            # 0: tqdm progress bar only, 1: + val loss, 2: + train and val loss
    continue_training=False,
    device=None,
    comet=None,
)
model.train(**train_settings)

Create Dataloader
Dataloader created


Epoch:  18%|█████████████▊                                                              | 2/11 [01:48<08:07, 54.18s/it]


KeyboardInterrupt: 

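#### Evaluate the latent space with kNN

Note: in the saved outputs, the training cell above was stopped early with a `KeyboardInterrupt`, so the scores below come from a partially trained model.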

In [8]:
# Wrap the trained model as an embedding function and score the validation
# queries with a kNN classifier on the resulting latent space.
test_embedding_func = get_model_prediction_function(model, batch_size=params['batch_size'])
summary = run_imputation_evaluation(
    adata,
    test_embedding_func,
    query_source='val',
    use_non_binder=True,
    use_reduced_binders=True,
)

# Headline number: support-weighted average F1 over all classes.
metrics = summary['knn']
weighted_f1 = metrics['weighted avg']['f1-score']
print(weighted_f1)

0.6473222063886667


In [10]:
# Full kNN report: per-class precision / recall / f1-score / support, plus the
# averaged rows (e.g. 'weighted avg', whose f1-score is printed above).
metrics

{'A0201_ELAGIGILTV_MART-1_Cancer_binder': {'precision': 0.0,
  'recall': 0.0,
  'f1-score': 0.0,
  'support': 48},
 'A0201_GILGFVFTL_Flu-MP_Influenza_binder': {'precision': 0.9715789473684211,
  'recall': 0.923,
  'f1-score': 0.9466666666666668,
  'support': 1000},
 'A0201_GLCTLVAML_BMLF1_EBV_binder': {'precision': 0.23333333333333334,
  'recall': 0.08333333333333333,
  'f1-score': 0.12280701754385966,
  'support': 84},
 'A0301_KLGGALQAK_IE-1_CMV_binder': {'precision': 0.38657105606258146,
  'recall': 0.1402554399243141,
  'f1-score': 0.205831308573412,
  'support': 4228},
 'A0301_RLRAEAQVK_EMNA-3A_EBV_binder': {'precision': 0.0,
  'recall': 0.0,
  'f1-score': 0.0,
  'support': 59},
 'A1101_AVFDRKSDAK_EBNA-3B_EBV_binder': {'precision': 0.9875141884222475,
  'recall': 0.7184145334434352,
  'f1-score': 0.8317399617590822,
  'support': 1211},
 'A1101_IVTDFSVIK_EBNA-3B_EBV_binder': {'precision': 0.7631578947368421,
  'recall': 0.19333333333333333,
  'f1-score': 0.30851063829787234,
  'supp