In [1]:
import pandas as pd
from tqdm import tqdm

In [2]:
import sys
sys.path.append('../mvTCR/')
import tcr_embedding.utils_training as utils
import config.constants_10x as const

from tcr_embedding.utils_preprocessing import stratified_group_shuffle_split, group_shuffle_split
from tcr_embedding.evaluation.Imputation import run_imputation_evaluation
from tcr_embedding.evaluation.Clustering import run_clustering_evaluation
from tcr_embedding.evaluation.kNN import run_knn_within_set_evaluation
from tcr_embedding.evaluation.WrapperFunctions import get_model_prediction_function

In [3]:
def load_model(adata, dataset, split, model, donor=''):
    if donor != '':
        path_model = f'saved_models/journal_2/10x/splits/{model}/10x_donor_{donor}_split_{split}_{model}'
    else:
        path_model = f'saved_models/journal/Fischer/{model}/{model}_{dataset}_'
    path_model += '.pt'
    model = utils.load_model(adata, path_model)
    return model
    

In [4]:
def load_data(split, dataset):
    adata = utils.load_data(dataset)
    random_seed = split
    sub, non_sub = group_shuffle_split(adata, group_col='clonotype', val_split=0.2, random_seed=random_seed)
    train, val = group_shuffle_split(sub, group_col='clonotype', val_split=0.20, random_seed=random_seed)
    
    adata.obs['set'] = 'train'
    adata.obs.loc[non_sub.obs.index, 'set'] = '-'
    adata.obs.loc[val.obs.index, 'set'] = 'val'
    adata = adata[adata.obs['set'].isin(['train', 'val'])]
    return adata

In [6]:
dataset_2_celltype = {
    'Haniffa': 'full_clustering',
    'covid': 'cell_type',
    'Borcherding': 'functional.cluster',
}

In [8]:
adata = load_data(0, 'covid')

## Preservance of Cell type and Clonotype

In [9]:
metadata = ['T_cells', 'clonotype', 'responsive']

model_names = []
splits = []
metrics = []
scores = []
datasets = []
for dataset in ['covid', 'Haniffa', 'Borcherding']
    for split in range(0, 5):
        data = load_covid_data(split, dataset)
        for model_name in ['moe', 'tcr', 'rna']:
            print(f'split: {split},  model: {model_name}')
            model = load_model(data, 'covid', split, model_name)
            test_embedding_func = get_model_prediction_function(model)

            best_nmi = -99
            for resolution in [0.01, 0.1, 1.0]:
                cluster_result = run_clustering_evaluation(data, test_embedding_func, 'train', 
                                                           name_label=dataset_2_celltype[dataset], 
                                                   cluster_params={'resolution': resolution, 'num_neighbors': 5})
                best_nmi = max(cluster_result['NMI'], best_nmi)
            model_names.append(model_name)
            splits.append(split)
            metrics.append('NMI_cell_type')
            scores.append(best_nmi)
            datasets.append(dataset)
                
            best_nmi = -99
            for resolution in [0.01, 0.1, 1.0]:
                cluster_result = run_clustering_evaluation(data, test_embedding_func, 'train', name_label='clonotype', 
                                                   cluster_params={'resolution': resolution, 'num_neighbors': 5})
                best_nmi = max(cluster_result['NMI'], best_nmi)
            model_names.append(model_name)
            splits.append(split)
            metrics.append('NMI_clonotype')
            scores.append(best_nmi)
            datasets.append(dataset)


results_perservance = {
    'model': model_names,
    'split': splits,
    'metric': metrics,
    'score': scores,
    'dataset': datasets
}
results_perservance = pd.DataFrame(results_perservance)
results_perservance.to_csv('../results/performance_perservance.csv')
results_perservance

split: 0,  model: concat
split: 0,  model: moe
split: 0,  model: poe
split: 0,  model: tcr
split: 0,  model: rna
split: 1,  model: concat
split: 1,  model: moe
split: 1,  model: poe
split: 1,  model: tcr
split: 1,  model: rna
split: 2,  model: concat
split: 2,  model: moe
split: 2,  model: poe
split: 2,  model: tcr
split: 2,  model: rna
split: 3,  model: concat
split: 3,  model: moe
split: 3,  model: poe
split: 3,  model: tcr
split: 3,  model: rna
split: 4,  model: concat
split: 4,  model: moe
split: 4,  model: poe
split: 4,  model: tcr
split: 4,  model: rna


Unnamed: 0,model,split,metrics,scores
0,concat,0,NMI_cell_type,0.408943
1,concat,0,NMI_reactivity,0.156583
2,moe,0,NMI_cell_type,0.437105
3,moe,0,NMI_reactivity,0.198819
4,poe,0,NMI_cell_type,0.39922
5,poe,0,NMI_reactivity,0.218664
6,tcr,0,NMI_cell_type,0.428652
7,tcr,0,NMI_reactivity,0.263353
8,rna,0,NMI_cell_type,0.511926
9,rna,0,NMI_reactivity,0.140069


## Write Supplemantary Material S1

In [10]:
path_out = '../results/supplement/S1_benchmarking.xlsx'
with pd.ExcelWriter(path_out, mode='a') as writer:  
    results_perservance.to_excel(writer, sheet_name='Perservance')