# Cluster Scores
Clustering quality of the model embeddings is assessed with the following scores:
- ASW (Average Silhouette Width, i.e. the silhouette score)
- NMI (Normalized Mutual Information)
- ARI (Adjusted Rand Index)

In [None]:
import scanpy as sc
import matplotlib.pyplot as plt
import torch
import os

In [None]:
import sys
sys.path.append('../../..')
import tcr_embedding as tcr

from tcr_embedding.evaluation.Imputation import run_imputation_evaluation
from tcr_embedding.evaluation.WrapperFunctions import get_model_prediction_function
from tcr_embedding.evaluation.Clustering import run_clustering_evaluation

In [None]:
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt

In [None]:
%load_ext autoreload
%autoreload 2

In [None]:
# Seed the random number generators of torch, numpy and Python's `random`
# with the same fixed value so that repeated runs of the notebook are comparable.
random_seed = 42
import torch
import numpy as np
import random
torch.manual_seed(random_seed)
np.random.seed(random_seed)
random.seed(random_seed)

In [None]:
def init_and_load_model(donor_nr, model_arch, model_dir='trained_models/'):
    """
    Locate a trained checkpoint by file name, build the matching model and load it.

    A checkpoint matches when its file name contains both 'd' + donor_nr and
    model_arch. The model class is chosen from keywords in the file name
    ('single', 'moe', 'poe', 'separate'; anything else falls back to JointModel)
    and is constructed with the hyperparameters stored under 'params' in the checkpoint.

    NOTE: the model is built on the notebook-level global `adata`, which has to be
    set to the desired dataset before this function is called.

    :param donor_nr: str, donor number; matched as 'd' + donor_nr inside the file name
    :param model_arch: str, substring of the file name identifying the architecture
    :param model_dir: str, directory that is searched for checkpoints
    :return: the constructed model after `model.load` was called on the checkpoint
    :raises IndexError: if no file in model_dir matches donor_nr and model_arch
    """
    donor_tag = 'd' + donor_nr
    # os.listdir returns entries in arbitrary order, so sort to make the choice
    # deterministic when more than one checkpoint matches.
    candidates = sorted(file for file in os.listdir(model_dir)
                        if donor_tag in file and model_arch in file)
    if not candidates:
        raise IndexError('No trained model found in {!r} for donor tag {!r} and architecture {!r}'.format(
            model_dir, donor_tag, model_arch))
    if len(candidates) > 1:
        print('Warning: several checkpoints match, using the first of: ', candidates)
    model_fp = candidates[0]

    print('Loading: ', model_fp)
    file_path = os.path.join(model_dir, model_fp)
    # Only the stored hyperparameters are read from this object, so map it to the CPU;
    # the checkpoint is handed to model.load() separately below.
    model_file = torch.load(file_path, map_location='cpu')
    params = model_file['params']

    # The order of the checks matters: a name containing both 'single' and 'separate'
    # must resolve to the SeparateModel, not the SingleModel.
    if 'single' in model_fp and 'separate' not in model_fp:
        init_model = tcr.models.single_model.SingleModel
    elif 'moe' in model_fp:
        init_model = tcr.models.moe.MoEModel
    elif 'poe' in model_fp:
        init_model = tcr.models.poe.PoEModel
    elif 'separate' in model_fp:
        init_model = tcr.models.separate_model.SeparateModel
    else:
        init_model = tcr.models.joint_model.JointModel

    model = init_model(
        adatas=[adata],  # adatas containing gene expression and TCR-seq
        aa_to_id=adata.uns['aa_to_id'],  # dict {aa_char: id}
        seq_model_arch=params['seq_model_arch'],  # seq model architecture
        seq_model_hyperparams=params['seq_model_hyperparams'],  # dict of seq model hyperparameters
        scRNA_model_arch=params['scRNA_model_arch'],
        scRNA_model_hyperparams=params['scRNA_model_hyperparams'],
        zdim=params['zdim'],  # zdim
        hdim=params['hdim'],  # hidden dimension of scRNA and seq encoders
        activation=params['activation'],  # activation function of autoencoder hidden layers
        dropout=params['dropout'],
        batch_norm=params['batch_norm'],
        shared_hidden=params['shared_hidden'],  # hidden layers of shared encoder / decoder
        names=['10x'],
        gene_layers=[],  # [] or list of str for layer keys of each dataset
        seq_keys=[]  # [] or list of str for seq keys of each dataset
    )

    model.load(file_path)
    return model

In [None]:
# Antigen specificities that are kept for each donor, keyed by donor number.
# (Named "high count"; presumably the most frequent binders per donor - confirm against the data.)
donor_specific_high_count_antigens = {
    '1': [
        'A1101_IVTDFSVIK_EBNA-3B_EBV_binder',
        'A0301_KLGGALQAK_IE-1_CMV_binder',
        'A0201_GILGFVFTL_Flu-MP_Influenza_binder',
        'A1101_AVFDRKSDAK_EBNA-3B_EBV_binder',
        'A0201_ELAGIGILTV_MART-1_Cancer_binder',
    ],
    '2': [
        'B0801_RAKFKQLL_BZLF1_EBV_binder',
        'A0201_GILGFVFTL_Flu-MP_Influenza_binder',
        'A0301_KLGGALQAK_IE-1_CMV_binder',
        'A0201_GLCTLVAML_BMLF1_EBV_binder',
        'A1101_AVFDRKSDAK_EBNA-3B_EBV_binder',
    ],
}

# Per-donor names bound to the very same list objects as the mapping above.
donor_1_high_count_antigens = donor_specific_high_count_antigens['1']
donor_2_high_count_antigens = donor_specific_high_count_antigens['2']

In [None]:
# Per-donor `min_dist` / `spread` settings for each model key.
# (Not referenced elsewhere in the visible part of this notebook.)
donor_1_umap_params = {
    'scRNA':              {'min_dist': 0.5, 'spread': 1.0},
    'single_transformer': {'min_dist': 1.2, 'spread': 1.0},
    'joint':              {'min_dist': 0.5, 'spread': 0.2},
    'poe':                {'min_dist': 1.0, 'spread': 0.5},
}
donor_2_umap_params = {
    'scRNA':              {'min_dist': 0.5, 'spread': 0.8},
    'single_transformer': {'min_dist': 1.2, 'spread': 0.5},
    'joint':              {'min_dist': 0.8, 'spread': 0.5},
    'poe':                {'min_dist': 0.8, 'spread': 0.5},
}

# Display name for each model key.
names = {
    'scRNA': 'RNA-only',
    'single_transformer': 'TCR-only',
    'joint': 'Concat',
    'poe': 'PoE',
}

In [None]:
# Genes of interest, and the same names with a trailing underscore appended.
selected_genes = [
    'CD8A', 'CD8B', 'GZMA', 'CCL5', 'GZMH', 'KLRD1', 'IFNG',
    'ISG15', 'IFI6', 'IFITM1', 'MX1', 'IFITM3', 'OAS1',
]
selected_genes_obs = [f'{gene_name}_' for gene_name in selected_genes]

In [None]:
adata_total = sc.read('../../../data/10x_CD8TC/v6_supervised.h5ad')

# Relabel every observation whose antigen is not in HIGH_COUNT_ANTIGENS.
# A single .loc assignment per column is used on purpose: the chained form
# obs[col][mask] = value writes through an intermediate object, which raises
# SettingWithCopyWarning and has no effect at all under pandas copy-on-write.
is_rare_binding = ~adata_total.obs['binding_name'].isin(tcr.constants.HIGH_COUNT_ANTIGENS)
adata_total.obs.loc[is_rare_binding, 'binding_label'] = -1
adata_total.obs.loc[is_rare_binding, 'binding_name'] = 'no_data'
# For visualization purpose, else the scanpy plot script thinks the rare specificities are still there and the colors get skewed
adata_total.obs['binding_name'] = adata_total.obs['binding_name'].astype(str)

Open questions and the settings chosen so far:
- Use X or the latent representation? Use the latent representation, as it separates the antigen binding better.
- Differential expression is done on the preprocessed data (log-transformed, normalized, etc.). Does that make sense?
- Include non-binders or not? Don't include them, since we are not interested in non-binders in general and they clutter the figures.
- How high should the resolution be? Whatever looks best in the clustering.
- Does it make sense to include TRAV/TRBV genes and HLAs?
- Use the train, val or test set? This is not supervised anymore, so we can also just use the train set.
- Maybe we should look into the individual bindings? It declutters the plots a lot, but maybe that's cheating?
- General question: Does it make sense to sample from the latent space? Instead we could always use mu, which would give more stable, and probably better, performance measures (truncation trick).


# Donor 1

In [None]:
donor_nr = '1'

# Restrict the data to this donor and to the antigens selected for it.
# Subsetting an AnnData yields a view; .copy() makes it an independent object so
# that adding the obs column below does not force an implicit copy with a warning.
adata = adata_total[adata_total.obs['donor'] == 'donor_'+donor_nr]
adata = adata[adata.obs['binding_name'].isin(donor_specific_high_count_antigens[donor_nr])].copy()
# Every observation is assigned to the 'train' set, the set name passed to the clustering evaluation.
adata.obs['set'] = 'train'

In [None]:
# Number of observations per `binding_name` in the current donor subset
adata.obs['binding_name'].value_counts()

In [None]:
# Run the clustering evaluation for every model at every resolution and
# collect one result record per (model, resolution) combination.
models = ['scRNA', 'single_transformer', 'joint', 'poe']
resolutions = [0.01, 0.1, 1.0]
cluster_results = []

for current_model in models:
    print(current_model)
    model = init_and_load_model(donor_nr, current_model)
    test_embedding_func = get_model_prediction_function(model, batch_size=1024)

    for resolution in resolutions:
        cluster_params = {'resolution': resolution, 'num_neighbors': 5}
        cluster_result = run_clustering_evaluation(
            adata, test_embedding_func, 'train',
            name_label='binding_name',
            cluster_params=cluster_params,
            visualize=True,
        )
        # Tag the record so it can be filtered by model and resolution later on
        cluster_result['resolution'] = resolution
        cluster_result['model'] = current_model
        cluster_results.append(cluster_result)

In [None]:
# Turn the list of result records into a table (one row per record) and display it.
# Note: this rebinds `cluster_results` from a list to a DataFrame.
cluster_results = pd.DataFrame(cluster_results)
cluster_results

In [None]:
# One figure per metric: score against clustering resolution, one line per model.
for metric in ['silhouette_score', 'AMI', 'NMI', 'ARI']:
    fig, ax = plt.subplots()

    for current_model in models:
        model_rows = cluster_results[cluster_results['model'] == current_model]
        ax.plot(model_rows['resolution'], model_rows[metric], label=current_model)

    ax.set_ylabel(metric)
    ax.set_xlabel('resolution')
    ax.legend()
    plt.show()

In [None]:
# Keep the donor 1 results under their own name, since `cluster_results` is reassigned
# in the donor 2 section. This binds a second name to the same object; no copy is made.
cluster_results1 = cluster_results

In [None]:
# Print the column-wise maximum of the donor 1 results for each model.
# The maximum is taken per column independently, so the printed values
# do not necessarily stem from the same resolution.
for current_model in models:
    is_current_model = cluster_results1['model'] == current_model
    print(current_model)
    print(cluster_results1[is_current_model].max())
    print('\n')

In [None]:
# Write the donor 1 results to a CSV file in the current working directory
cluster_results1.to_csv('donor1_clustering_results_resolution.csv')

# Donor 2

In [None]:
donor_nr = '2'

# Restrict the data to this donor and to the antigens selected for it.
# Subsetting an AnnData yields a view; .copy() makes it an independent object so
# that adding the obs column below does not force an implicit copy with a warning.
adata = adata_total[adata_total.obs['donor'] == 'donor_'+donor_nr]
adata = adata[adata.obs['binding_name'].isin(donor_specific_high_count_antigens[donor_nr])].copy()
# Every observation is assigned to the 'train' set, the set name passed to the clustering evaluation.
adata.obs['set'] = 'train'

In [None]:
# Number of observations per `binding_name` in the current donor subset
adata.obs['binding_name'].value_counts()

In [None]:
# Run the clustering evaluation for every model at every resolution and
# collect one result record per (model, resolution) combination.
models = ['scRNA', 'single_transformer', 'joint', 'poe']
resolutions = [0.01, 0.1, 1.0]
cluster_results = []

for current_model in models:
    print(current_model)
    model = init_and_load_model(donor_nr, current_model)
    test_embedding_func = get_model_prediction_function(model, batch_size=1024)

    for resolution in resolutions:
        cluster_params = {'resolution': resolution, 'num_neighbors': 5}
        cluster_result = run_clustering_evaluation(
            adata, test_embedding_func, 'train',
            name_label='binding_name',
            cluster_params=cluster_params,
            visualize=True,
        )
        # Tag the record so it can be filtered by model and resolution later on
        cluster_result['resolution'] = resolution
        cluster_result['model'] = current_model
        cluster_results.append(cluster_result)

In [None]:
# Turn the list of result records into a table (one row per record) and display it.
# Note: this rebinds `cluster_results` from a list to a DataFrame.
cluster_results = pd.DataFrame(cluster_results)
cluster_results

In [None]:
# One figure per metric: score against clustering resolution, one line per model.
for metric in ['silhouette_score', 'AMI', 'NMI', 'ARI']:
    fig, ax = plt.subplots()

    for current_model in models:
        model_rows = cluster_results[cluster_results['model'] == current_model]
        ax.plot(model_rows['resolution'], model_rows[metric], label=current_model)

    ax.set_ylabel(metric)
    ax.set_xlabel('resolution')
    ax.legend()
    plt.show()

In [None]:
# Keep the donor 2 results under their own name.
# This binds a second name to the same object; no copy is made.
cluster_results2 = cluster_results

In [None]:
# Print the column-wise maximum of the donor 2 results for each model.
# The maximum is taken per column independently, so the printed values
# do not necessarily stem from the same resolution.
for current_model in models:
    is_current_model = cluster_results2['model'] == current_model
    print(current_model)
    print(cluster_results2[is_current_model].max())
    print('\n')

In [None]:
# Write the donor 2 results to a CSV file in the current working directory
cluster_results2.to_csv('donor2_clustering_results_resolution.csv')