In [1]:
import pandas as pd
import numpy as np
import torch
import seaborn as sb
from tqdm import tqdm

In [2]:
import sys
sys.path.append('../mvTCR/')
import tcr_embedding.utils_training as utils_train
import config.constants_10x as const

from tcr_embedding.utils_preprocessing import stratified_group_shuffle_split, group_shuffle_split
from tcr_embedding.evaluation.Imputation import run_imputation_evaluation
from tcr_embedding.evaluation.Clustering import run_clustering_evaluation
from tcr_embedding.evaluation.kNN import run_knn_within_set_evaluation
from tcr_embedding.evaluation.WrapperFunctions import get_model_prediction_function

from tcr_embedding.models.model_selection_count_prediction import DecisionHead
from sklearn.preprocessing import OneHotEncoder

In [3]:
def embed_data(adata, path_model, key_counts):
    """Embed `adata` with a saved mvTCR model and attach the count targets.

    Parameters
    ----------
    adata : AnnData
        Input cells; must provide ``obs['set']``, ``obs['binding_name']`` and ``obsm[key_counts]``.
    path_model : str
        Path of the saved embedding model, passed to ``utils_train.load_model``.
    key_counts : str
        ``obsm`` key of the count matrix copied onto the embedding.

    Returns
    -------
    The latent representation returned by ``model.get_latent`` (mean embedding), with
    ``obsm[key_counts]`` copied from `adata`.
    """
    loaded_model = utils_train.load_model(adata, path_model)
    latent = loaded_model.get_latent(
        adata,
        metadata=['set', 'binding_name'],
        return_mean=True,
    )
    # NOTE(review): assumes get_latent keeps the cell order of `adata`, so rows line up.
    latent.obsm[key_counts] = adata.obsm[key_counts]
    return latent


def get_training_data(embedding, key_counts):
    """Split an embedding into train and validation datasets for the count head.

    Rows are selected by ``embedding.obs['set']`` ('train' / 'val'); features come from
    ``embedding.X`` and targets from ``embedding.obsm[key_counts]`` (densified via ``toarray``).

    Returns
    -------
    (train_dataset, val_dataset), each built as ``CustomDataset(features, targets)``.
    """
    # NOTE(review): CustomDataset is not defined or imported anywhere in this notebook as
    # shown; calling this function raises NameError unless it is provided elsewhere.
    def _subset(set_name):
        selected = embedding.obs['set'] == set_name
        features = embedding.X[selected]
        targets = embedding.obsm[key_counts][selected].toarray()
        return CustomDataset(features, targets)

    return _subset('train'), _subset('val')

In [4]:
def load_10x_data(donor, split):
    """Load the 10x dataset, optionally restrict it to one donor, and assign data splits.

    Parameters
    ----------
    donor : int, str or None
        Donor number; only cells with ``obs['donor'] == f'donor_{donor}'`` are kept.
        ``None`` or ``'None'`` keeps all donors and stores a dense one-hot donor encoding
        in ``obsm['donor']``.
    split : int or str
        Random seed for the clonotype-grouped split, or ``'full'`` to skip splitting
        (no ``obs['set']`` column is written in that case).

    Returns
    -------
    AnnData restricted to antigens in ``const.HIGH_COUNT_ANTIGENS``. Unless ``split == 'full'``,
    ``obs['set']`` holds 'train', 'val' or 'test' for every cell.
    """
    adata = utils_train.load_data('10x')

    if str(donor) != 'None':
        adata = adata[adata.obs['donor'] == f'donor_{donor}']
    else:
        # `sparse` was renamed to `sparse_output` in scikit-learn 1.2 and removed in 1.4.
        try:
            enc = OneHotEncoder(sparse_output=False)
        except TypeError:
            enc = OneHotEncoder(sparse=False)
        donor_labels = adata.obs['donor'].to_numpy().reshape(-1, 1)
        adata.obsm['donor'] = enc.fit_transform(donor_labels)

    # Materialise the subset so that writing obs['set'] below does not modify an AnnData view.
    adata = adata[adata.obs['binding_name'].isin(const.HIGH_COUNT_ANTIGENS)].copy()
    if split != 'full':
        random_seed = split

        # Grouping by clonotype keeps all cells of a clonotype in the same set.
        # Presumably 20% of the data goes to test, then 25% of the rest to validation;
        # the exact val_split semantics are defined by group_shuffle_split.
        train_val, test = group_shuffle_split(adata, group_col='clonotype', val_split=0.20, random_seed=random_seed)
        train, val = group_shuffle_split(train_val, group_col='clonotype', val_split=0.25, random_seed=random_seed)

        # Every cell is labelled, so no further filtering on obs['set'] is needed.
        adata.obs['set'] = 'train'
        adata.obs.loc[val.obs.index, 'set'] = 'val'
        adata.obs.loc[test.obs.index, 'set'] = 'test'
    return adata

In [5]:
def load_avidity_model(donor, split, model_name):
    """Load the trained avidity (binding-count) DecisionHead for one donor/split/model.

    The checkpoint path is
    ``../mvTCR/saved_models/journal_2/10x/avidity/{model_name}/10x_avidity_{donor}_split_{split}_{model_name}.ckpt``.

    Returns
    -------
    The ``DecisionHead`` instance restored by ``DecisionHead.load_from_checkpoint``.
    """
    path_checkpoint = f'../mvTCR/saved_models/journal_2/10x/avidity/{model_name}/'
    path_checkpoint += f'10x_avidity_{donor}_split_{split}_{model_name}.ckpt'
    # load_from_checkpoint reads the file itself; a separate torch.load() would only
    # deserialise the whole checkpoint a second time.
    return DecisionHead.load_from_checkpoint(checkpoint_path=path_checkpoint)

In [6]:
def msle(x_pred, x_true):
    """Mean squared logarithmic error: mean((log(1 + pred) - log(1 + true)) ** 2).

    Parameters
    ----------
    x_pred, x_true : torch.Tensor
        Predicted and observed values (presumably the same shape and non-negative counts).

    Returns
    -------
    float
        The error as a plain Python float.
    """
    squared_error = torch.nn.MSELoss()
    pred_log = torch.log(1 + x_pred)
    true_log = torch.log(1 + x_true)
    return squared_error(pred_log, true_log).detach().numpy().item()

In [7]:
import sklearn
def r2(x_pred, x_true):
    """Coefficient of determination (R^2) of predictions against observed values.

    Parameters
    ----------
    x_pred, x_true : torch.Tensor
        Predicted and observed values; detached and converted to NumPy, so they must be
        CPU tensors. Values are used as-is (no log transform is applied).

    Returns
    -------
    float
        ``sklearn.metrics.r2_score(y_true, y_pred)``.
    """
    # Import the submodule explicitly: a bare ``import sklearn`` does not guarantee that
    # ``sklearn.metrics`` has been loaded.
    from sklearn.metrics import r2_score

    y_pred = x_pred.detach().numpy()
    y_true = x_true.detach().numpy()
    return r2_score(y_true, y_pred)

In [8]:
# scipy.stats.stats is a private module, deprecated since SciPy 1.8; import from the public API.
from scipy.stats import spearmanr, pearsonr


def pearson_corr(x, y):
    """Pearson correlation coefficient between two 1-D sequences of equal length.

    SciPy returns NaN (with a warning) when either input is constant.
    """
    corr = pearsonr(x, y)[0]
    return corr

In [9]:
# Evaluate the avidity (binding-count) heads on each held-out test set. For every donor
# (1-4, plus 'None' = no donor filtering), split seed (0-4) and embedding model, each antigen
# in const.HIGH_COUNT_ANTIGENS is scored by the MSLE and the Pearson correlation of
# predicted vs. observed counts. One record is collected per score, so the row fields
# cannot drift out of alignment.
records = []

for donor in list(range(1, 5)) + ['None']:
    for split in tqdm(range(0, 5)):
        adata = load_10x_data(donor, split)
        for model_name in ['moe', 'poe', 'concat', 'tcr', 'rna']:
            # NOTE(review): this path is relative to the notebook directory, while the avidity
            # checkpoints are read from '../mvTCR/saved_models/...'; confirm both resolve as intended.
            path_model_emb = f'saved_models/journal_2/10x/splits/{model_name}/'
            path_model_emb += f'10x_donor_{donor}_split_{split}_{model_name}.pt'

            embedded_data = embed_data(adata, path_model_emb, 'binding_counts')
            embedded_data = embedded_data[embedded_data.obs['set'] == 'test']

            model = load_avidity_model(donor, split, model_name)
            # Inference only: disable training-mode layers (dropout/batch-norm) and autograd.
            model.eval()
            with torch.no_grad():
                count_prediction = model(torch.from_numpy(embedded_data.X))
            count_truth = torch.from_numpy(embedded_data.obsm['binding_counts'])

            # Column i of the count matrices is taken to be antigen HIGH_COUNT_ANTIGENS[i].
            for i, binding in enumerate(const.HIGH_COUNT_ANTIGENS):
                records.append({
                    'model': model_name,
                    'split': split,
                    'metrics': f'MSLE_{binding}',
                    'scores': msle(count_prediction[:, i], count_truth[:, i]),
                    'donors': donor,
                })

            for i, binding in enumerate(const.HIGH_COUNT_ANTIGENS):
                pearson_val = pearson_corr(count_prediction[:, i].detach().numpy(),
                                           count_truth[:, i].detach().numpy())
                # Constant predictions or targets give NaN; score those as zero correlation.
                if np.isnan(pearson_val):
                    pearson_val = 0.0
                records.append({
                    'model': model_name,
                    'split': split,
                    'metrics': f'Pearson_{binding}',
                    'scores': pearson_val,
                    'donors': donor,
                })

results_10x_avd = pd.DataFrame(records, columns=['model', 'split', 'metrics', 'scores', 'donors'])
results_10x_avd.to_csv('../results/performance_avidity_10x.csv')
results_10x_avd

100%|████████████████████████████████████████████████████████████████████████████████████| 5/5 [02:04<00:00, 24.98s/it]
100%|████████████████████████████████████████████████████████████████████████████████████| 5/5 [03:52<00:00, 46.59s/it]
100%|████████████████████████████████████████████████████████████████████████████████████| 5/5 [03:10<00:00, 38.15s/it]
100%|████████████████████████████████████████████████████████████████████████████████████| 5/5 [01:42<00:00, 20.54s/it]
100%|████████████████████████████████████████████████████████████████████████████████████| 5/5 [06:28<00:00, 77.76s/it]


Unnamed: 0,model,split,metrics,scores,donors
0,moe,0,MSLE_A0201_ELAGIGILTV_MART-1_Cancer_binder,0.325847,1
1,moe,0,MSLE_A0201_GILGFVFTL_Flu-MP_Influenza_binder,1.679810,1
2,moe,0,MSLE_A0201_GLCTLVAML_BMLF1_EBV_binder,0.066037,1
3,moe,0,MSLE_A0301_KLGGALQAK_IE-1_CMV_binder,1.073541,1
4,moe,0,MSLE_A0301_RLRAEAQVK_EMNA-3A_EBV_binder,0.848029,1
...,...,...,...,...,...
1995,rna,4,Pearson_A0301_KLGGALQAK_IE-1_CMV_binder,0.665275,
1996,rna,4,Pearson_A0301_RLRAEAQVK_EMNA-3A_EBV_binder,0.631071,
1997,rna,4,Pearson_A1101_IVTDFSVIK_EBNA-3B_EBV_binder,0.398274,
1998,rna,4,Pearson_A1101_AVFDRKSDAK_EBNA-3B_EBV_binder,0.529994,


## Write to Supplementary

In [10]:
# Write the avidity results as sheet 'avidity_10x' of the supplementary benchmarking workbook.
# Appending requires the openpyxl engine and an existing file; with the default
# if_sheet_exists='error' a second run would raise, so the sheet is replaced instead, and
# the workbook is created when it does not exist yet.
import os

path_out = '../results/supplement/S1_benchmarking.xlsx'
if os.path.exists(path_out):
    writer_kwargs = {'mode': 'a', 'if_sheet_exists': 'replace'}
else:
    writer_kwargs = {'mode': 'w'}

with pd.ExcelWriter(path_out, engine='openpyxl', **writer_kwargs) as writer:
    results_10x_avd.to_excel(writer, sheet_name='avidity_10x')