In [None]:
import os
import umap
import scanpy
import pickle
import warnings
import operator
import itertools
import umap.plot
import numpy as np
import pandas as pd
from utils import *
import seaborn as sns
from tqdm import tqdm as tqdm
import matplotlib.pyplot as plt
from collections import Counter
from matplotlib.gridspec import GridSpec
from sklearn.metrics.pairwise import cosine_similarity

# 1. Load data

In [None]:
# Dataset to evaluate; must be one of the keys of `data_keys` below.
dataset="pancreas" # "lung" or "pancreas"
# Map each dataset to the `.obs` columns that hold its batch and label
# annotations, and to the base name of its unintegrated .h5ad file under data/.
data_keys={"lung":{"batch_key":"batch","label_key":"cell_type", "file_name": "lung_unintegrated"},
           "pancreas":{"batch_key":"tech","label_key":"celltype", "file_name": "pancreas_unintegrated"}}
# Integration methods to compare. The names are used both to locate each
# method's integrated .h5ad output and to select/order rows of the score tables.
methods = [ "SciTuna",'Scanorama', 'fastMNN', 'Seurat', 'SAUCIE']

### 1.2  Original Datasets (unintegrated)

In [None]:
# Read the unintegrated AnnData file registered for the selected dataset.
print("Loading dataset")
unintegrated_data = scanpy.read_h5ad(
    "data/{}.h5ad".format(data_keys[dataset]["file_name"])
)

### 1.1 Batch Pairs 

In [None]:
# Retrieve every unordered pair of batches as an "<a>_<b>" string.
# The unique batch ids are computed once (np.unique returns them sorted) and
# paired with itertools.combinations, which yields the same pairs in the same
# order as a nested "a, then every later b" loop.
batch_ids = np.unique(unintegrated_data.obs[data_keys[dataset]["batch_key"]])
datasets = ["{}_{}".format(a, b) for a, b in itertools.combinations(batch_ids, 2)]
print("There are :",len(datasets)," batch pairs.")
datasets[:4]

# 2. Methods vs. Metrics

### 2.1 Metric Scores

In [None]:
# Folder holding the evaluation outputs of the selected dataset.
outputs_folder = "output/{}/".format(dataset)

# `load_and_combine_scores` fetches the results of each method and dataset
# individually and merges them into one aggregated table per method. It is
# useful when the evaluation notebook terminated before the per-method scores
# were combined into a single file. Otherwise `load_combined_scores`, which
# loads the already-combined results of each method, can be used instead
# (see the commented-out call below).
methods_vs_metrics = load_and_combine_scores(
    outputs_folder,
    methods,
    datasets,
)
#methods_vs_metrics = load_combined_scores(outputs_folder, methods, datasets)

# 3. Similarity scores

In [None]:
# Compute one score per batch pair (via utils' `calculate_sim_scores`), then
# list the (pair, score) items in ascending order of score.
similarity_scores = calculate_sim_scores(
    unintegrated_data,
    data_keys[dataset]["batch_key"],
    data_keys[dataset]["label_key"],
    datasets,
)
sorted_similarity_scores = sorted(
    similarity_scores.items(),
    key=lambda pair_and_score: pair_and_score[1],
)
sorted_similarity_scores

# 4. Aggregated scores

### 4.1 Mapping dictionaries for renaming methods and scores

In [None]:
# Metric columns averaged into the "Biological conservation" score.
bio_metrics = [
    "NMI cluster/label",
    "ARI cluster/label",
    "Cell type ASW",
    "Isolated label silhouette",
    "Isolated label F1",
    "CC conservation",
    "HVG conservation",
    "cLISI",
    "1 - Over correction",
]

# Metric columns averaged into the "Batch correction" score.
batch_metrics = [
    "Batch ASW",
    "PCR batch",
    "Graph connectivity",
    "iLISI",
]

# Total number of individual metrics (shown as the cell output).
len(bio_metrics) + len(batch_metrics)

In [None]:
# Weight of "Batch correction" in the overall score; "Biological conservation"
# receives the remaining (1 - alpha). The per-pair scores further down use the
# same `alpha`, so both tables stay consistent when it is changed.
alpha = 0.4

# For every method: drop the rows that contain a missing value, then average
# each remaining row over its columns. The per-method averages are combined in
# a single concat (one column per method) instead of growing a frame in a loop.
method_names = list(methods_vs_metrics)
agg_scores = pd.concat(
    [methods_vs_metrics[method].dropna(axis=0).mean(axis=1) for method in method_names],
    axis=1,
    keys=method_names,
)

# One row per method, one column per metric.
agg_scores = agg_scores.transpose()
agg_scores["Batch correction"] = agg_scores[batch_metrics].mean(axis=1)
agg_scores["Biological conservation"] = agg_scores[bio_metrics].mean(axis=1)
agg_scores["Overall score"] = (alpha * agg_scores["Batch correction"]) + ((1. - alpha) * agg_scores["Biological conservation"])
agg_scores = agg_scores.round(3)

In [None]:
# Summary columns of `agg_scores`, in the order they are reported below.
selected_metrics = [
    "Overall score",
    "Biological conservation",
    "Batch correction",
]

### 4.3 Overall scores

In [None]:
# Methods ranked by overall score, best first.
agg_scores.loc[:, [selected_metrics[0]]].sort_values(by=selected_metrics[0], ascending=False)

### 4.4 Biological scores

In [None]:
# Methods ranked by biological-conservation score, best first.
agg_scores.loc[:, [selected_metrics[1]]].sort_values(by=selected_metrics[1], ascending=False)

In [None]:
# Individual biological-conservation metrics for every method.
agg_scores.loc[:, bio_metrics]

### 4.5 Batch scores

In [None]:
# Methods ranked by batch-correction score, best first.
agg_scores.loc[:, [selected_metrics[2]]].sort_values(by=selected_metrics[2], ascending=False)

In [None]:
# Individual batch-correction metrics for every method.
agg_scores.loc[:, batch_metrics]

# 5. Plots - Aggregated Scores

In [None]:
# Make sure the folder for the metric plots exists. It is derived from
# `outputs_folder` (the same folder handed to `plot_metrics`) rather than from
# a machine-specific absolute path, and `exist_ok=True` tolerates an existing
# folder while still surfacing real failures such as permission errors.
# NOTE(review): presumably `plot_metrics` writes into "<outputs_folder>/metric_plots/"
# — confirm against utils.plot_metrics.
os.makedirs(os.path.join(outputs_folder, "metric_plots"), exist_ok=True)

#initialize the plot parameters
plt.rcParams["font.family"] = "Times New Roman"
# Plot options forwarded to `plot_metrics` for every figure.
params = {
    "width" : 0.6, 
    "color" : ['#7F2400', '#FFD23F', '#28A097', '#13426C', '#F26A26'], 
    "ylim" : (0.0,1.0), 
    "legend" : False, 
    "fontsize": 30
}

In [None]:
# Slice the aggregated table into the three score groups passed to
# `plot_metrics`, with rows in the order given by `methods`.
overall_score = agg_scores.loc[methods, ["Overall score"]]
bio_scores = agg_scores.loc[methods, ["Biological conservation", *bio_metrics]]
batch_scores = agg_scores.loc[methods, ["Batch correction", *batch_metrics]]
title = ""
file_name = "agg_scores.pdf"
plot_metrics(
    overall_score,
    bio_scores,
    batch_scores,
    title,
    file_name,
    params,
    outputs_folder,
)

# 6. Plots - Batch pairs

In [None]:
# Raw identifier -> display text. Used to look up batch ids for plot titles
# and, via `format_label`, as an ordered list of substring replacements for
# legend labels — so the insertion order of the entries must be preserved.
ids_map = {
    "_": " ",
    "-": " ",
    "celseq": "CEL-Seq",
    "celseq2": "CEL-Seq2",
    "smarter": "SMARTer",
    "smartseq2": "SMART-Seq2",
    "fluidigmc1": "Fluidigm C1",
    "inDrop1": "inDrop 1",
    "inDrop2": "inDrop 2",
    "inDrop3": "inDrop 3",
    "inDrop4": "inDrop 4",
}

In [None]:
# For every batch pair: collect each method's scores on that pair, derive the
# three summary scores with the same `alpha` weighting as the aggregated
# table, print the summary and plot it through `plot_metrics`.
with warnings.catch_warnings():
    warnings.simplefilter('ignore')
    for pair in datasets:

        print(pair)
        # One row per method, one column per metric; built with a single
        # concat instead of growing the frame method by method.
        pair_x_scores = pd.concat(
            [
                methods_vs_metrics[method][[pair]].transpose().rename(index={pair: method})
                for method in methods_vs_metrics
            ],
            axis=0,
        )

        pair_x_scores["Batch correction"] = pair_x_scores[batch_metrics].mean(axis=1)
        pair_x_scores["Biological conservation"] = pair_x_scores[bio_metrics].mean(axis=1)
        pair_x_scores["Overall score"] = (alpha * pair_x_scores["Batch correction"]) + ((1. - alpha) * pair_x_scores["Biological conservation"])
        pair_x_scores = pair_x_scores.round(3)
        print(pair_x_scores[["Overall score", "Biological conservation", "Batch correction"]])

        #plots
        overall_score = pair_x_scores[["Overall score"]].loc[methods]
        bio_scores = pair_x_scores[["Biological conservation"] + bio_metrics].loc[methods]
        batch_scores = pair_x_scores[["Batch correction"] + batch_metrics].loc[methods]
        file_name = "{}_scores.pdf".format(pair)

        # Use the display name of each batch when one is registered in
        # `ids_map`; otherwise fall back to the raw batch id instead of
        # raising KeyError (e.g. for datasets whose batches are not listed).
        pair_batches = pair.split("_")
        title = "{} | {}".format(
            ids_map.get(pair_batches[0], pair_batches[0]),
            ids_map.get(pair_batches[1], pair_batches[1]),
        )
        plot_metrics(overall_score, bio_scores, batch_scores, title,file_name, params, outputs_folder)
       
       

# 7. Plots - UMAP plots

### 7.1 Prepare UMAPs

In [None]:
# Make sure the folder for the UMAP figures exists. `exist_ok=True` tolerates
# an existing folder without silently hiding other failures (e.g. permission
# errors), which a bare `except: pass` would swallow.
os.makedirs(os.path.join(outputs_folder, "umap_plots"), exist_ok=True)

### 7.2 Analyzing UMAPs

In [None]:
# Colours and marker shapes available for the cell types.
colors = ['#6F0747', '#073B31', '#052A94', '#8D8CFF', '#1EBF2B', '#1DC7B3', '#128B8D', '#6CBC6A', '#BD8993', '#C07A48', '#EBAE83', '#000000', '#536E8B', '#773276', '#17BECF', '#AEC7E8', '#FFBB78']
sns_markers = ['o', 's', 'D', '^', 'v', '<', '>', 'P', 'X', '*', 'H', 'd', 'p', '8', 's', 'd', 'D']

# label -> colour / marker lookups shared by all UMAP panels.
color_palette = {}
markers = {}

# Give every cell type its own colour and marker. The lists are cycled so a
# dataset with more cell types than entries reuses styles instead of raising
# IndexError.
cell_types = np.unique(unintegrated_data.obs[data_keys[dataset]["label_key"]])
for i, ct in enumerate(cell_types):
    color_palette[ct] = colors[i % len(colors)]
    markers[ct] = sns_markers[i % len(sns_markers)]

# Give every batch a default colour (kept if one is already assigned) and the
# first marker shape.
for pair in datasets:
    pair_batches = pair.split("_")
    color_palette.setdefault(pair_batches[0], "#052A94")
    color_palette.setdefault(pair_batches[1], "#ACA106")
    markers[pair_batches[0]] = sns_markers[0]
    markers[pair_batches[1]] = sns_markers[0]

In [None]:
# Settings shared by all UMAP panels.
font_size = 13

# Row code -> `.obs` column used to colour the panels of that row.
types = dict(
    ct=data_keys[dataset]["label_key"],
    bt=data_keys[dataset]["batch_key"],
)

# Row code -> legend title shown for that row.
legend_title_map = dict(
    ct="Cell types",
    bt="Batches",
)


def format_label(label, ids_map):
    """Return a display version of ``label``.

    Every key of ``ids_map`` found in ``label`` is replaced by its value. The
    replacements are applied one after another in the mapping's iteration
    order, so a later entry sees the result of the earlier ones. The first
    character of the result is then upper-cased.

    Parameters
    ----------
    label : str
        Raw label text.
    ids_map : dict[str, str]
        Substring -> replacement text.

    Returns
    -------
    str
        The formatted label; an empty string if the label is (or becomes)
        empty, instead of raising IndexError.
    """
    for old, new in ids_map.items():
        label = label.replace(old, new)
    # Slicing (rather than indexing) keeps this safe for an empty string.
    return label[:1].upper() + label[1:]


# Render one UMAP figure per batch pair, visiting the pairs in the order of
# `sorted_similarity_scores`. Each figure has one row per entry of `types`
# and one column per embedding: the unintegrated data first, then every method
# in `methods`. Pairs whose figure file already exists are skipped.

columns_vs_rows = {}
umaps = {}
# One column for the unintegrated data plus one per integration method.
n_cols = len(methods) + 1
for pair, _score in sorted_similarity_scores:
    umaps = {}
    figure_path = '{}/umap_plots/{}_umap.png'.format(outputs_folder, pair)

    # The path is built from the pair name (not from the (pair, score) tuple),
    # so it matches the file written by savefig below.
    if os.path.isfile(figure_path):
        continue
    print(pair)
    pair_batches = pair.split("_")
    # squeeze=False keeps `axes` two-dimensional for any grid size.
    fig, axes = plt.subplots(len(types), n_cols, figsize=(25, 8), sharex=False, sharey=False, squeeze=False)
    umaps[pair] = {}
    # Fixed colours / marker for the two batches of this pair.
    color_palette[pair_batches[0]] = "#052A94"
    color_palette[pair_batches[1]] = "#ACA106"
    markers[pair_batches[0]] = sns_markers[0]
    markers[pair_batches[1]] = sns_markers[0]
    columns_vs_rows[pair] = {}
    for method in methods:
        print("\t",method)
        method_output = scanpy.read_h5ad("{}/{}/integrated/{}.h5ad".format(outputs_folder,pair,method))
        if "unintegrated" not in umaps[pair]:
            # The first method's output fixes the cells and genes used for
            # every embedding of this pair, including the unintegrated one.
            columns_vs_rows[pair]["cells"] = method_output.obs_names
            columns_vs_rows[pair]["genes"] = method_output.var_names
            in_pair = unintegrated_data.obs[data_keys[dataset]["batch_key"]].isin(pair_batches)
            # Copy so neighbors/umap write to a real object rather than a view.
            unintegratedPairX = unintegrated_data[in_pair][columns_vs_rows[pair]["cells"], columns_vs_rows[pair]["genes"]].copy()
            scanpy.pp.neighbors(unintegratedPairX)
            scanpy.tl.umap(unintegratedPairX, n_components=2)
            umaps[pair]["unintegrated"] = unintegratedPairX

        method_output = method_output[columns_vs_rows[pair]["cells"], columns_vs_rows[pair]["genes"]].copy()
        scanpy.pp.neighbors(method_output)
        scanpy.tl.umap(method_output, n_components=2)
        umaps[pair][method] = method_output

    for row, _type in enumerate(types):
        obs_column = types[_type]
        # Dict order is insertion order: "unintegrated" first, then `methods`.
        for col, method in enumerate(umaps[pair]):
            method_umap = umaps[pair][method]

            # Create a DataFrame for easier manipulation
            umap_df = pd.DataFrame(method_umap.obsm['X_umap'], columns=['UMAP1', 'UMAP2'])
            umap_df[obs_column] = method_umap.obs[obs_column].values

            # Plotting on the corresponding subplot
            ax = axes[row][col]
            sns.scatterplot(
                x='UMAP1', y='UMAP2',
                hue=obs_column,
                style=obs_column,
                palette=color_palette,
                markers=markers,
                data=umap_df,
                s=5,
                linewidth=0.1,
                ax=ax
            )

            # Panel title
            if method == "unintegrated":
                title = "Before integration"
            elif "SciTuna" in method:
                title = "SciTuna"
            else:
                title = method
            ax.set_title(title, fontsize=font_size)

            ax.set_xlabel("UMAP1", fontsize=font_size)
            ax.set_ylabel("UMAP2", fontsize=font_size)
            ax.tick_params(axis='both', which='major', labelsize=font_size)

            # Keep a single legend per row, attached to the last panel.
            if col == n_cols - 1:
                handles, labels = ax.get_legend_handles_labels()
                formatted_labels = [format_label(label, ids_map) for label in labels]
                ax.legend(
                    handles,
                    formatted_labels,
                    title=legend_title_map[_type],
                    title_fontsize=font_size,  # Adjust legend title font size here
                    bbox_to_anchor=(1.05, 1),
                    loc='upper left',
                    markerscale=1,
                    fontsize=font_size
                )
            else:
                legend = ax.get_legend()
                if legend is not None:
                    legend.remove()

    # Adjust layout, save and display the figure.
    fig.tight_layout()
    fig.savefig(figure_path, dpi=300, bbox_inches='tight')
    plt.show()
    # Release the figure and the per-pair AnnData objects before the next pair.
    plt.close(fig)
    del umaps