In [1]:
import sys
import os
import os.path as osp
import seaborn as sns
import matplotlib.pyplot as plt
import matplotlib.colors as clr
import matplotlib as mpl
from matplotlib import rcParams
import operator
import pandas as pd
import scanpy as sc
import numpy as np
from typing import Optional, Tuple, Sequence, Union
from scipy.spatial.distance import jensenshannon
from scipy.stats import pearsonr,ttest_ind,mannwhitneyu
from sklearn.metrics import mean_squared_error

In [2]:
project_root = os.getcwd()


def _project_path(relative):
    """Return ``relative`` resolved against ``project_root`` as an absolute, normalised path."""
    return osp.abspath(osp.join(project_root, relative))


# Input data (made available on request) is expected under the data directory.
data_root = _project_path('application/spot_deconvolution/data')
output_root = _project_path('application/spot_deconvolution/output')
dstg_root = _project_path('application/spot_deconvolution/code/DSTG/DSTG_Result')
evaluate_root = _project_path('application/spot_deconvolution/evaluate')

In [3]:
def adjust_order(real_file, pred_file):
    """Align a predicted proportion table to the ground-truth table's rows and columns.

    Parameters
    ----------
    real_file : pd.DataFrame
        Ground-truth table. Its index order and column order define the target layout.
    pred_file : pd.DataFrame
        Prediction with the same number of rows and columns. Spaces in its column
        names are replaced with dots so they match the ground-truth naming.

    Returns
    -------
    pd.DataFrame
        A new frame with columns in ``real_file.columns`` order. Rows are put in
        ``real_file.index`` order when both frames carry the same (unique) row labels;
        otherwise the original positional row order is kept and a warning is issued.
        ``pred_file`` itself is not modified.

    Raises
    ------
    AssertionError
        If the row or column counts differ.
    KeyError
        If, after renaming, a ground-truth column is missing from the prediction.
    """
    import warnings

    assert real_file.shape[0] == pred_file.shape[0], "the number of spots is not equal!"
    assert real_file.shape[1] == pred_file.shape[1], "the number of cell types is not equal!"

    # Rows: DataFrame.reindex returns a NEW frame, so its result has to be kept.
    # Only reindex when the label sets match; otherwise reindex would fill NaNs.
    if pred_file.index.is_unique and set(pred_file.index) == set(real_file.index):
        pred_file = pred_file.reindex(real_file.index)
    else:
        warnings.warn(
            "row labels of the prediction do not match the ground truth; "
            "rows are kept in their original (positional) order",
            stacklevel=2,
        )

    # Columns: "Cell type" -> "Cell.type" (rename returns a copy), then ground-truth order.
    pred_file = pred_file.rename(columns=lambda name: name.replace(" ", "."))
    return pred_file[real_file.columns.values]

In [4]:
n_cell = ['100', '50', '30', '20', '10', '5']
gd_res_list = []
coord_list = []
cell2location_res_list = []
destvi_res_list = []
dstg_res_list = []
rctd_res_list = []
seurat_res_list = []
spatialdwls_res_list = []
spotlight_res_list = []
stereoscope_res_list = []
tangram_res_list = []

# File-name prefix of each method's output in `output_root` -> list collecting that method's
# aligned predictions (one entry per value of n_cell, in n_cell order). DSTG is handled
# separately below because its output lives in `dstg_root` and has no header or index.
method_res_lists = {
    'Cell2location': cell2location_res_list,
    'DestVI': destvi_res_list,
    'RCTD': rctd_res_list,
    'Seurat': seurat_res_list,
    'spatialDWLS': spatialdwls_res_list,
    'SPOTlight': spotlight_res_list,
    'Stereoscope': stereoscope_res_list,
    'Tangram': tangram_res_list,
}


def dot_names(names):
    """Return ``names`` as a list with spaces replaced by dots (e.g. 'T cell' -> 'T.cell')."""
    return [name.replace(" ", ".") for name in names]


# The DSTG output columns are assumed to follow the order in which cell types first appear
# in the single-cell reference metadata. That file does not depend on n, so read it once.
sc_ref_meta = pd.read_csv(osp.abspath(osp.join(data_root, 'sc_ref_meta.csv')), index_col=0)
dstg_colnames = dot_names(sc_ref_meta['Cell_type'].unique())

for n in n_cell:
    # Ground truth: per-spot proportions plus the spot_x / spot_y coordinates.
    gd_file_dir = osp.abspath(osp.join(data_root, 'spot_n' + n + '_prop.csv'))
    gd_raw = pd.read_csv(gd_file_dir, index_col=0)

    # Order spots by the integer formed from the digits in their names, so that e.g.
    # 'spot_2' precedes 'spot_10' (plain string sorting would not).
    spot_name = sorted(gd_raw.index, key=lambda x: int("".join(ch for ch in x if ch.isdigit())))
    gd_raw = gd_raw.reindex(spot_name)

    coord = gd_raw[['spot_x', 'spot_y']]
    gd_res = gd_raw.drop(columns=['spot_x', 'spot_y'])
    # Cell types in alphabetical order, with spaces normalised to dots.
    gd_res = gd_res[sorted(gd_res.columns.values)]
    gd_res.columns = dot_names(gd_res.columns)

    gd_res_list.append(gd_res)
    coord_list.append(coord)

    # Every method stores <Method>_<n>.csv with spots as the index column.
    for method, res_list in method_res_lists.items():
        pred_file_dir = osp.abspath(osp.join(output_root, method + '_' + n + '.csv'))
        pred = pd.read_csv(pred_file_dir, index_col=0)
        res_list.append(adjust_order(real_file=gd_res, pred_file=pred))

    # DSTG: headerless matrix whose rows are assumed to be in the same order as the
    # (sorted) ground-truth spots, so the ground-truth index is attached positionally.
    dstg_file_dir = osp.abspath(osp.join(dstg_root, 'DSTG_' + n + '_predict_output.csv'))
    dstg_res = pd.read_csv(dstg_file_dir, header=None)
    dstg_res.index = gd_res.index
    dstg_res.columns = dstg_colnames
    dstg_res_list.append(adjust_order(real_file=gd_res, pred_file=dstg_res))
    

In [8]:
def ssim(im1,im2,M=1):
    """Global (single-window) structural similarity index between two profiles.

    Each input is first divided by its own maximum. Luminance, contrast and structure
    terms are then computed over the whole vector and multiplied together. ``M`` is the
    dynamic range used in the stabilising constants C1 = (0.01*M)**2, C2 = (0.03*M)**2
    and C3 = C2/2.
    """
    a = im1 / im1.max()
    b = im2 / im2.max()

    mean_a = a.mean()
    mean_b = b.mean()
    dev_a = a - mean_a
    dev_b = b - mean_b
    std_a = np.sqrt((dev_a ** 2).mean())
    std_b = np.sqrt((dev_b ** 2).mean())
    cov_ab = (dev_a * dev_b).mean()

    c1 = (0.01 * M) ** 2
    c2 = (0.03 * M) ** 2
    c3 = c2 / 2

    luminance = (2 * mean_a * mean_b + c1) / (mean_a ** 2 + mean_b ** 2 + c1)
    contrast = (2 * std_a * std_b + c2) / (std_a ** 2 + std_b ** 2 + c2)
    structure = (cov_ab + c3) / (std_a * std_b + c3)
    return luminance * contrast * structure

def rmse(x1,x2):
    """Root-mean-squared error between ``x1`` and ``x2``.

    Computed with NumPy rather than ``mean_squared_error(..., squared=False)``: the
    ``squared`` argument was deprecated in scikit-learn 1.4 and removed in 1.6, where that
    call raises ``TypeError``. Inputs are converted to float arrays first, so pandas index
    alignment never applies. For 2-D inputs, the RMSE of each column is averaged, like
    scikit-learn's default ``multioutput='uniform_average'``.

    Raises
    ------
    ValueError
        If the two inputs do not have the same shape (avoids silent broadcasting).
    """
    a = np.asarray(x1, dtype=float)
    b = np.asarray(x2, dtype=float)
    if a.shape != b.shape:
        raise ValueError(f"rmse: shape mismatch {a.shape} vs {b.shape}")
    return np.sqrt(np.mean((a - b) ** 2, axis=0)).mean()
def mae(x1,x2):
    """Mean of the element-wise absolute differences ``|x1 - x2|``.

    The subtraction follows the operands' own semantics: NumPy broadcasting for arrays,
    or pandas index alignment when both operands are Series.
    """
    difference = x1 - x2
    magnitude = np.absolute(difference)
    return np.mean(magnitude)

from collections.abc import Iterable
def compare_results(gd,result_list,metric='pcc',columns=None,axis=1):
    """Score one or more predicted proportion tables against a ground-truth table.

    Parameters
    ----------
    gd : pd.DataFrame
        Ground-truth table. Rows are scored when ``axis != 1``, columns when ``axis == 1``.
    result_list : pd.DataFrame or iterable of pd.DataFrame
        Prediction(s) with the same row/column order as ``gd``. Pairing is positional
        (``iloc``). Predictions are clipped to [0, 1] before scoring; ``gd`` is not.
    metric : {'pcc', 'mae', 'jsd', 'rmse', 'ssim'}
        For 'pcc' only the first element of ``pearsonr``'s result (the coefficient) is kept.
    columns : list, optional
        Output column labels, one per prediction. When None, default labels are kept.
    axis : int
        1 -> one score per column of ``gd``; any other value -> one score per row.

    Returns
    -------
    pd.DataFrame
        Indexed by ``gd.columns`` (axis=1) or ``gd.index`` (otherwise), one column per
        prediction.

    Raises
    ------
    ValueError
        If ``metric`` is not one of the supported names.
    """
    # if/elif (not a dict) so only the selected metric's function has to exist.
    if metric == 'pcc':
        func, r_ind = pearsonr, 0
    elif metric == 'mae':
        func, r_ind = mae, None
    elif metric == 'jsd':
        func, r_ind = jensenshannon, None
    elif metric == 'rmse':
        func, r_ind = rmse, None
    elif metric == 'ssim':
        func, r_ind = ssim, None
    else:
        raise ValueError(
            f"unknown metric {metric!r}; expected one of 'pcc', 'mae', 'jsd', 'rmse', 'ssim'"
        )

    def score(res):
        # One score per column (axis=1) or per row (otherwise) of `gd`, paired by position.
        if axis == 1:
            labels = gd.columns
            pairs = ((gd.iloc[:, i].values, res.iloc[:, i]) for i in range(len(labels)))
        else:
            labels = gd.index
            pairs = ((gd.iloc[i, :].values, res.iloc[i, :]) for i in range(len(labels)))
        scores = []
        for truth, pred in pairs:
            r = func(truth, np.clip(pred, 0, 1))
            if r_ind is not None:
                r = r[r_ind]  # pearsonr returns (statistic, pvalue)
            scores.append(r)
        return pd.DataFrame(scores, index=labels)

    # A single DataFrame is scored exactly like a one-element list.
    if isinstance(result_list, pd.DataFrame):
        result_list = [result_list]
    df = pd.concat([score(res) for res in result_list], axis=1)
    if columns is not None:
        df.columns = columns
    return df

In [None]:
dataset = {100:0, 50:1, 30:2, 20:3, 10:4, 5:5} # n_cell = 100, 50, 30, 20, 10, 5

# Output column label -> per-dataset prediction list (indexed like gd_res_list).
# The dict order fixes the column order of every metric table.
method_predictions = {
    'Cell2location': cell2location_res_list,
    'DestVI': destvi_res_list,
    'DSTG': dstg_res_list,
    'RCTD': rctd_res_list,
    'Seurat': seurat_res_list,
    'spatialDWLS': spatialdwls_res_list,
    'SPOTlight': spotlight_res_list,
    'Stereoscope': stereoscope_res_list,
    'Tangram': tangram_res_list,
}


def evaluate_metric(metric):
    """Score every method on every dataset in ``dataset`` with ``metric``.

    Returns
    -------
    (per_spot, per_cluster) : tuple of pd.DataFrame
        ``compare_results`` with axis=0 (per spot) and axis=1 (per cell type), stacked
        over all datasets, each row tagged with that dataset's ``n_cell`` value.
    """
    method_names = list(method_predictions)
    spot_frames, cluster_frames = [], []
    for n, d in dataset.items():
        preds = [res_list[d] for res_list in method_predictions.values()]
        for axis, frames in ((0, spot_frames), (1, cluster_frames)):
            df = compare_results(
                gd_res_list[d],
                preds,
                columns=method_names,
                axis=axis,
                metric=metric,
            )
            df['n_cell'] = n
            frames.append(df)
    return pd.concat(spot_frames, axis=0), pd.concat(cluster_frames, axis=0)


pcc_all_spot, pcc_all_cluster = evaluate_metric('pcc')
ssim_all_spot, ssim_all_cluster = evaluate_metric('ssim')
rmse_all_spot, rmse_all_cluster = evaluate_metric('rmse')
jsd_all_spot, jsd_all_cluster = evaluate_metric('jsd')

In [10]:
# Persist every metric table under `evaluate_root`, one CSV per table.
metric_tables = {
    'pcc_all_spot.csv': pcc_all_spot,
    'pcc_all_cluster.csv': pcc_all_cluster,
    'ssim_all_spot.csv': ssim_all_spot,
    'ssim_all_cluster.csv': ssim_all_cluster,
    'rmse_all_spot.csv': rmse_all_spot,
    'rmse_all_cluster.csv': rmse_all_cluster,
    'jsd_all_spot.csv': jsd_all_spot,
    'jsd_all_cluster.csv': jsd_all_cluster,
}
for file_name, table in metric_tables.items():
    table.to_csv(osp.abspath(osp.join(evaluate_root, file_name)))