In [1]:
from pathlib import Path

import pandas as pd
from tqdm import tqdm

In [2]:
import sys
sys.path.append('../mvTCR/')
import tcr_embedding.utils_training as utils
import config.constants_10x as const

from tcr_embedding.utils_preprocessing import stratified_group_shuffle_split, group_shuffle_split
from tcr_embedding.evaluation.WrapperFunctions import get_model_prediction_function

  from .autonotebook import tqdm as notebook_tqdm


In [3]:
def load_model(adata, dataset, split, model, donor=''):
    """
    Load the checkpoint of one trained model for a given dataset split.

    :param adata: data object handed through to `utils.load_model`
    :param dataset: dataset key; 'covid' checkpoints are stored in the 'Fischer' folder
    :param split: index of the split, part of the checkpoint file name
    :param model: model name, used as sub-folder and as file-name suffix
    :param donor: if non-empty, the donor-specific 10x checkpoint is loaded and `dataset` is ignored
    :return: whatever `utils.load_model` returns for the resolved '.pt' path
    """
    if donor != '':
        # Donor-wise 10x checkpoints follow their own naming scheme.
        checkpoint = f'saved_models/journal_2/10x/splits/{model}/10x_donor_{donor}_split_{split}_{model}'
    else:
        # Only the folder differs for 'covid'; the file name always carries the dataset key.
        folder = 'Fischer' if dataset == 'covid' else dataset
        checkpoint = f'saved_models/journal_2/{folder}/splits/{model}/{dataset}_split_{split}_{model}'
    return utils.load_model(adata, checkpoint + '.pt')
    

In [4]:
def load_data(split, dataset):
    """
    Load a dataset and restrict it to the train/validation cells of one split.

    The split index is used as random seed for both clonotype-grouped splits.

    :param split: index of the split, used as `random_seed`
    :param dataset: dataset key handed to `utils.load_data`
    :return: subset of the loaded data whose `obs['set']` is 'train' or 'val'
    """
    adata = utils.load_data(dataset)

    # First split separates the cells that are dropped, second split carves out the validation set.
    kept, dropped = group_shuffle_split(adata, group_col='clonotype', val_split=0.2, random_seed=split)
    _, validation = group_shuffle_split(kept, group_col='clonotype', val_split=0.20, random_seed=split)

    # Everything starts as 'train'; dropped and validation cells are relabelled afterwards.
    adata.obs['set'] = 'train'
    adata.obs.loc[dropped.obs.index, 'set'] = '-'
    adata.obs.loc[validation.obs.index, 'set'] = 'val'

    is_used = adata.obs['set'].isin(['train', 'val'])
    return adata[is_used]

In [5]:
# Per dataset key: name of the `obs` column that is evaluated as cell-type label.
dataset_2_celltype = dict([
    ('haniffa', 'full_clustering'),
    ('covid', 'cell_type'),
    ('borcherding', 'functional.cluster'),
    ('10x', 'celltype'),
])

## Improved Implementation Cluster Evaluation
The clustering evaluation provided in the mvTCR code does not scale well to this setting (large datasets, many resolutions, and several metrics). To avoid duplicate computation, we speed things up a little here. Even though this could be done in the main code as well, I decided to do it only here to keep the codebase stable between evaluation runs.
Specifically, we:
- compute the embedding only once
- compute the neighborhood on this embedding
- only then calculate the clustering for various resolutions
- for each resolution, directly evaluate the different labels

In [6]:
from anndata import AnnData
import scanpy as sc
import tcr_embedding.evaluation.Metrics as Metrics
def run_clustering_evaluation(data_full, embedding_function, source_data='val', labels=None,
                              resolutions=None, n_neighbors=5):
    """
    Score how well Leiden clusters of an embedding agree with annotation columns.

    The embedding and its neighborhood graph are computed once; only the Leiden clustering is
    repeated per resolution, and every clustering is scored against all label columns.

    :param data_full: data object with an `obs['set']` column and the columns named in `labels`
    :param embedding_function: callable mapping the selected subset to its embedding
    :param source_data: value of `obs['set']` selecting the cells to evaluate
    :param labels: name(s) of the `obs` columns to compare the clusters with (str or iterable of str)
    :param resolutions: iterable of Leiden resolutions to try
    :param n_neighbors: number of neighbors for the neighborhood graph
    :return: dict mapping each label column to its highest NMI over all resolutions
    :raises ValueError: if `labels` or `resolutions` is None or empty
    """
    # Validate before any heavy lifting: embedding + neighbors can take hours on large datasets.
    if isinstance(labels, str):
        labels = [labels]
    labels = list(labels) if labels is not None else []
    resolutions = list(resolutions) if resolutions is not None else []
    if not labels:
        raise ValueError('`labels` must name at least one obs column to evaluate.')
    if not resolutions:
        raise ValueError('`resolutions` must contain at least one Leiden resolution.')

    data_eval = data_full[data_full.obs['set'] == source_data]

    embeddings = AnnData(embedding_function(data_eval))
    for col_label in labels:
        embeddings.obs[col_label] = data_eval.obs[col_label].to_numpy()

    sc.pp.neighbors(embeddings, n_neighbors=n_neighbors, use_rep='X', random_state=0)

    # The ground-truth arrays do not depend on the resolution, so extract them only once.
    labels_true = {col_label: embeddings.obs[col_label].to_numpy() for col_label in labels}

    cluster_scores = {col_label: [] for col_label in labels}
    for res in resolutions:
        sc.tl.leiden(embeddings, resolution=res, random_state=0)
        # One clustering per resolution is shared by all label columns.
        labels_pred = embeddings.obs['leiden'].to_numpy()

        for col_label in labels:
            score = Metrics.get_normalized_mutual_information(labels_true[col_label], labels_pred)
            cluster_scores[col_label].append(score)

    # Report the best resolution per label column.
    return {col_label: max(values) for col_label, values in cluster_scores.items()}

## Preservation of Cell Type and Clonotype

In [7]:
# Long-format accumulators: every evaluated model adds two rows (cell type and clonotype NMI).
model_names, splits, metrics, scores, datasets = [], [], [], [], []

for dataset in ['haniffa', 'borcherding', 'covid']:
    print(f'- {dataset}')
    col_celltype = dataset_2_celltype[dataset]
    for split in tqdm(range(5)):
        data = load_data(split, dataset)
        for model_name in ['moe', 'tcr', 'rna', 'concat', 'poe']:
            model = load_model(data, dataset, split, model_name)
            test_embedding_func = get_model_prediction_function(model)

            # Clusters are computed on the cells marked as 'train'.
            cluster_scores = run_clustering_evaluation(
                data, test_embedding_func, 'train',
                labels=[col_celltype, 'clonotype'],
                resolutions=[0.01, 0.1, 1.0, 10, 25],
            )

            model_names.extend([model_name, model_name])
            splits.extend([split, split])
            datasets.extend([dataset, dataset])
            metrics.extend(['NMI_cell_type', 'NMI_clonotype'])
            scores.extend([cluster_scores[col_celltype], cluster_scores['clonotype']])

results_perservance = pd.DataFrame({
    'model': model_names,
    'split': splits,
    'metric': metrics,
    'score': scores,
    'dataset': datasets,
})
results_perservance.to_csv('../results/performance_perservance.csv')
results_perservance

- haniffa


100%|███████████████████████████████████████████████████████████████████████████████████| 5/5 [22:50<00:00, 274.06s/it]


- borcherding


100%|████████████████████████████████████████████████████████████████████████████████| 5/5 [6:48:31<00:00, 4902.22s/it]


- covid


100%|████████████████████████████████████████████████████████████████████████████████████| 5/5 [01:09<00:00, 13.83s/it]


Unnamed: 0,model,split,metric,score,dataset
0,moe,0,NMI_cell_type,0.302889,haniffa
1,moe,0,NMI_clonotype,0.728880,haniffa
2,tcr,0,NMI_cell_type,0.166110,haniffa
3,tcr,0,NMI_clonotype,0.783031,haniffa
4,rna,0,NMI_cell_type,0.427773,haniffa
...,...,...,...,...,...
145,rna,4,NMI_clonotype,0.746569,covid
146,concat,4,NMI_cell_type,0.420390,covid
147,concat,4,NMI_clonotype,0.816002,covid
148,poe,4,NMI_cell_type,0.364541,covid


In [8]:
# Average each score over the splits, separately per dataset, model and metric.
group_keys = ['dataset', 'model', 'metric']
results_perservance.groupby(group_keys)['score'].mean()

dataset      model   metric       
borcherding  concat  NMI_cell_type    0.139368
                     NMI_clonotype    0.657497
             moe     NMI_cell_type    0.119940
                     NMI_clonotype    0.710212
             poe     NMI_cell_type    0.131371
                     NMI_clonotype    0.689875
             rna     NMI_cell_type    0.146637
                     NMI_clonotype    0.629866
             tcr     NMI_cell_type    0.077469
                     NMI_clonotype    0.801840
covid        concat  NMI_cell_type    0.393454
                     NMI_clonotype    0.815538
             moe     NMI_cell_type    0.410863
                     NMI_clonotype    0.812570
             poe     NMI_cell_type    0.366861
                     NMI_clonotype    0.807563
             rna     NMI_cell_type    0.461384
                     NMI_clonotype    0.749611
             tcr     NMI_cell_type    0.334614
                     NMI_clonotype    0.782649
haniffa      concat  NMI_

## Write Supplementary Material S1

In [9]:
# Add (or refresh) the 'CellCharacteristics' sheet of the shared supplementary workbook.
path_out = '../results/supplement/S1_benchmarking.xlsx'
path_workbook = Path(path_out)
path_workbook.parent.mkdir(parents=True, exist_ok=True)

if path_workbook.exists():
    # Keep the other sheets of the workbook; replace ours so the cell can be re-run.
    writer_kwargs = {'mode': 'a', 'if_sheet_exists': 'replace'}
else:
    # Append mode fails on a missing file, so the first run has to create the workbook.
    writer_kwargs = {'mode': 'w'}

# openpyxl is the engine that supports append mode; pin it so both paths behave the same.
with pd.ExcelWriter(path_out, engine='openpyxl', **writer_kwargs) as writer:
    results_perservance.to_excel(writer, sheet_name='CellCharacteristics')