In [15]:
import CeLEry as cel

import os,csv,re
import pandas as pd
import numpy as np
import scanpy as sc
import math
from skimage import io, color

from scipy.sparse import issparse
import random, torch
import warnings
warnings.filterwarnings("ignore")
import pickle
from sklearn.model_selection import train_test_split
import matplotlib.pyplot as plt
from anndata import AnnData, read_h5ad
import seaborn as sns
from tqdm import tqdm
import matplotlib.ticker as mtick
import scipy

import json

In [None]:
data_merfish = read_h5ad("data/liver_merfish.h5ad")

Rdata_ind, Qdata_ind, _, _ =train_test_split(range(data_merfish.shape[0]), data_merfish.obs['louvain'], test_size=0.5,random_state=1,stratify=data_merfish.obs['louvain'])

Rdata = data_merfish[np.sort(Rdata_ind), :]
Qdata = data_merfish[np.sort(Qdata_ind), :]
Qdata_all = Qdata.copy()


## 25% data for comparison
Rdata_ind, Qdata_ind, _, _ =train_test_split(range(Rdata.shape[0]), Rdata.obs['louvain'], test_size=0.25,random_state=1,stratify=Rdata.obs['louvain'])
Rdata = Rdata[np.sort(Qdata_ind), :]

Rdata_ind, Qdata_ind, _, _ =train_test_split(range(Qdata.shape[0]), Qdata.obs['louvain'], test_size=0.25,random_state=1,stratify=Qdata.obs['louvain'])
Qdata = Qdata[np.sort(Qdata_ind), :]

In [9]:
celery_pred = np.load("output/liver/celery_liver.npy", allow_pickle=False)
tangram_pred = np.load("output/liver/tangram_liver.npy", allow_pickle=False)
spaotsc_pred = np.load("output/liver/spaotsc_liver.npy", allow_pickle=False)
novosparc_pred = np.load("output/liver/novosparc_liver.npy", allow_pickle=False)

In [10]:
Qdata_all.obs['x_celery'] = celery_pred[:, 0]
Qdata_all.obs['y_celery'] = celery_pred[:, 1]

In [11]:
celery_pred = celery_pred[np.sort(Qdata_ind),:]
Qdata.obs['x_celery'] = celery_pred[:, 0]
Qdata.obs['y_celery'] = celery_pred[:, 1]

Qdata.obs['x_tangram'] = tangram_pred[:, 0]
Qdata.obs['y_tangram'] = tangram_pred[:, 1]

Qdata.obs['x_spaotsc'] = spaotsc_pred[:, 0]
Qdata.obs['y_spaotsc'] = spaotsc_pred[:, 1]

Qdata.obs['x_novosparc'] = novosparc_pred[:, 0]
Qdata.obs['y_novosparc'] = novosparc_pred[:, 1]

Spatial prediction plot colored by cell type

In [14]:
x_min = np.min([np.min(Qdata.obs['x_cord']), np.min(Qdata_all.obs['x_celery']), np.min(Qdata.obs['x_spaotsc']), np.min(Qdata.obs['x_novosparc'])]) - 150
y_min = np.min([np.min(Qdata.obs['y_cord']), np.min(Qdata_all.obs['y_celery']), np.min(Qdata.obs['y_spaotsc']), np.min(Qdata.obs['y_novosparc'])]) - 150
x_max = np.max([np.max(Qdata.obs['x_cord']), np.max(Qdata.obs['x_celery']), np.max(Qdata.obs['x_spaotsc']), np.max(Qdata.obs['x_novosparc'])]) + 150
y_max = np.max([np.max(Qdata.obs['y_cord']), np.max(Qdata.obs['y_celery']), np.max(Qdata.obs['y_spaotsc']), np.max(Qdata.obs['y_novosparc'])]) + 150

fig, axes = plt.subplots(1, 5, figsize=(30, 6))
sns.scatterplot(data=Qdata_all.obs, x="x_cord", y="y_cord",s=1, hue="louvain",hue_order=list(pd.Series(data_merfish.obs['louvain']).value_counts().keys()), 
                ax=axes[0], legend=False).set(title="Truth", xlabel=None, ylabel=None, xlim=(x_min, x_max), ylim=(y_min, y_max))
sns.scatterplot(data=Qdata_all.obs, x="x_celery", y="y_celery",s=1, hue="louvain",hue_order=list(pd.Series(data_merfish.obs['louvain']).value_counts().keys()), 
                ax=axes[1], legend=False).set(title="CeLEry", xlabel=None, ylabel=None, xlim=(x_min, x_max), ylim=(y_min, y_max))
sns.scatterplot(data=Qdata.obs, x="x_tangram", y="y_tangram",s=1, hue="louvain", hue_order=list(pd.Series(data_merfish.obs['louvain']).value_counts().keys()),
                ax=axes[2], legend=False).set(title="Tangram", xlabel=None, ylabel=None, xlim=(x_min, x_max), ylim=(y_min, y_max))
sns.scatterplot(data=Qdata.obs, x="x_spaotsc", y="y_spaotsc",s=1, hue="louvain", hue_order=list(pd.Series(data_merfish.obs['louvain']).value_counts().keys()), 
                ax=axes[3], legend=False).set(title="SpaOTsc", xlabel=None, ylabel=None, xlim=(x_min, x_max), ylim=(y_min, y_max))
sns.scatterplot(data=Qdata.obs, x="x_novosparc", y="y_novosparc",s=1, hue="louvain", hue_order=list(pd.Series(data_merfish.obs['louvain']).value_counts().keys()), 
                ax=axes[4]).set(title="novoSpaRc", xlabel=None, ylabel=None, xlim=(x_min, x_max), ylim=(y_min, y_max))

plt.legend(loc="upper right", title="Cell type", bbox_to_anchor=(1.3, 0.75))

figname = "output/liver/plot/pred_type.png"
fig.savefig(figname)
plt.close(fig)

Correlation between true X/Y value and predicted X/Y value

In [16]:
x_corr = [scipy.stats.pearsonr(Qdata.obs['x_cord'], Qdata.obs['x_celery']).statistic, scipy.stats.pearsonr(Qdata.obs['x_cord'], Qdata.obs['x_tangram']).statistic,
            scipy.stats.pearsonr(Qdata.obs['x_cord'], Qdata.obs['x_spaotsc']).statistic, scipy.stats.pearsonr(Qdata.obs['x_cord'], Qdata.obs['x_novosparc']).statistic]

y_corr = [scipy.stats.pearsonr(Qdata.obs['y_cord'], Qdata.obs['y_celery']).statistic, scipy.stats.pearsonr(Qdata.obs['y_cord'], Qdata.obs['y_tangram']).statistic,
            scipy.stats.pearsonr(Qdata.obs['y_cord'], Qdata.obs['y_spaotsc']).statistic, scipy.stats.pearsonr(Qdata.obs['y_cord'], Qdata.obs['y_novosparc']).statistic]

value = [*x_corr, *y_corr]
method = ['CeLEry', 'Tangram', 'SpaOTsc', 'novoSpaRc']*2
all = np.repeat(['X axis', 'Y axis'], 4)
rmse_df = pd.DataFrame(np.array([value, method, all])).T
rmse_df.columns = ['value', 'method', 'all']
rmse_df.value = rmse_df.value.astype('float')
cols = ['#CAE7B9', '#F3DE8A', '#EB9486', '#7E7F9A']

fig, axes = plt.subplots(1, 1, figsize=(6, 5))
sns.barplot(data=rmse_df, x='all',  y='value', hue="method",  palette=cols)
plt.legend(bbox_to_anchor=(1.05, 0.3), loc=3, borderaxespad=0,fontsize=10)
plt.ylabel('Correlation', {'size' :10})
plt.xlabel('')
plt.ylim((0,1))
fig.subplots_adjust(right=0.75)

plt.grid(axis='both', color='silver', alpha=0.3)

ax = plt.gca()
ax.yaxis.set_major_formatter(mtick.FormatStrFormatter('%.2f'))
ax.grid(True)
ax.set_xmargin(0.05)

figname = "output/liver/plot/pred_corr.pdf"
fig.savefig(figname)
plt.close(fig)

Spatial plot colored by Euclidian distance between true location and predicted location

In [17]:
sq = lambda x, y: (x - y)**2

pred_dist_celery = np.sqrt(np.sum(sq(np.array(Qdata_all.obs[['x_cord', 'y_cord']]), np.array(Qdata_all.obs[['x_celery', 'y_celery']])), axis=1))
pred_dist_tangram = np.sqrt(np.sum(sq(np.array(Qdata.obs[['x_cord', 'y_cord']]), np.array(Qdata.obs[['x_tangram', 'y_tangram']])), axis=1))
pred_dist_spaotsc = np.sqrt(np.sum(sq(np.array(Qdata.obs[['x_cord', 'y_cord']]), np.array(Qdata.obs[['x_spaotsc', 'y_spaotsc']])), axis=1))
pred_dist_novosparc = np.sqrt(np.sum(sq(np.array(Qdata.obs[['x_cord', 'y_cord']]), np.array(Qdata.obs[['x_novosparc', 'y_novosparc']])), axis=1))
vmax = np.max([np.max(pred_dist_celery), np.max(pred_dist_tangram), np.max(pred_dist_spaotsc), np.max(pred_dist_novosparc)])

fig, (ax1, ax2, ax3, ax4) = plt.subplots(1,4,figsize=(25,6))
ax1.scatter(Qdata_all.obs['x_cord'], Qdata_all.obs['y_cord'], s=1, c=pred_dist_celery, vmin=0, vmax=vmax)
ax1.set_title('CeLEry')
ax2.scatter(Qdata.obs['x_cord'], Qdata.obs['y_cord'], s=2, c=pred_dist_tangram, vmin=0, vmax=vmax)
ax2.set_title('Tangram')
ax3.scatter(Qdata.obs['x_cord'], Qdata.obs['y_cord'], s=2, c=pred_dist_spaotsc, vmin=0, vmax=vmax)
ax3.set_title('spaOTsc')
a1 = ax4.scatter(Qdata.obs['x_cord'], Qdata.obs['y_cord'], s=2, c=pred_dist_novosparc, vmin=0, vmax=vmax)
ax4.set_title('novoSpaRc')

fig.colorbar(a1, fraction=0.046, pad=0.04)

figname = "output/liver/plot/pred_error.pdf"
fig.savefig(figname)
plt.close(fig)

Pairwise distance calculation and pairwise distance density plot

In [18]:
def distCompute(data_merfish):
    celery_dist = []
    tangram_dist = []
    true_dist = []
    spaotsc_dist = []
    novosparc_dist = []
    Qdata_loc = np.array(data_merfish.obs[['x_cord', 'y_cord']])
    celery_pred = np.array(data_merfish.obs[['x_celery', 'y_celery']])
    tangram_pred = np.array(data_merfish.obs[['x_tangram', 'y_tangram']])
    spaotsc_pred = np.array(data_merfish.obs[['x_spaotsc', 'y_spaotsc']])
    novosparc_pred = np.array(data_merfish.obs[['x_novosparc', 'y_novosparc']])

    for i in tqdm(range(Qdata_loc.shape[0])):
        celery_i = celery_pred[i, :]
        celery_points = celery_pred[i+1:, :]
        celery_dist.extend(np.sqrt(np.sum((celery_points - celery_i)**2, axis=1)))

        tangram_i = tangram_pred[i, :]
        tangram_points = tangram_pred[i+1:, :]
        tangram_dist.extend(np.sqrt(np.sum((tangram_points - tangram_i)**2, axis=1)))

        spaotsc_i = spaotsc_pred[i, :]
        spaotsc_points = spaotsc_pred[i+1:, :]
        spaotsc_dist.extend(np.sqrt(np.sum((spaotsc_points - spaotsc_i)**2, axis=1)))

        novosparc_i = novosparc_pred[i, :]
        novosparc_points = novosparc_pred[i+1:, :]
        novosparc_dist.extend(np.sqrt(np.sum((novosparc_points - novosparc_i)**2, axis=1)))

        true_i = Qdata_loc[i, :]
        true_points = Qdata_loc[i+1:, :]
        true_dist.extend(np.sqrt(np.sum((true_points - true_i)**2, axis=1)))
    return celery_dist, tangram_dist, spaotsc_dist, novosparc_dist, true_dist


celery_dist, tangram_dist, spaotsc_dist, novosparc_dist, true_dist = distCompute(Qdata)   

100%|██████████| 19556/19556 [01:01<00:00, 317.21it/s] 


In [34]:
import scipy
import matplotlib.ticker as mtick
value = [scipy.stats.pearsonr(true_dist, celery_dist).statistic, scipy.stats.pearsonr(true_dist, tangram_dist).statistic,
            scipy.stats.pearsonr(true_dist, spaotsc_dist).statistic, scipy.stats.pearsonr(true_dist, novosparc_dist).statistic]
print(value)

method = ['CeLEry', 'Tangram', 'SpaOTsc', 'novoSpaRc']
all = np.repeat('Method', 4)
rmse_df = pd.DataFrame(np.array([value, method, all])).T
rmse_df.columns = ['value', 'method', 'all']
rmse_df.value = rmse_df.value.astype('float')
cols = ['#CAE7B9', '#F3DE8A', '#EB9486', '#7E7F9A']

fig, axes = plt.subplots(1, 1, figsize=(6, 5))
corr_plot = sns.barplot(data=rmse_df,  y='value', x="method",  palette=cols)
corr_plot.set(ylim=(0, 0.8))
plt.ylabel('Pairwise Distance Correlation', {'size' :10})
plt.xlabel('Methods')
plt.grid(axis='both', color='silver', alpha=0.3)

ax = plt.gca()
ax.yaxis.set_major_formatter(mtick.FormatStrFormatter('%.2f'))
ax.grid(True)
ax.set_xmargin(0.05)

figname = "output/liver/plot/pairwise_corr.pdf"
fig.savefig(figname)
plt.close(fig)

[0.6855667106274315, 0.33287112319580836, 0.36091639352901633, 0.3178124807092779]


In [12]:
import mpl_scatter_density
from matplotlib.colors import LinearSegmentedColormap


white_viridis = LinearSegmentedColormap.from_list('white_viridis', [
    (0, '#ffffff'),
    (1e-20, '#440053'),
    (0.2, '#404388'),
    (0.4, '#2a788e'),
    (0.6, '#21a784'),
    (0.8, '#78d151'),
    (1, '#fde624'),
], N=256)

max_lim = np.max([np.max(celery_dist), np.max(tangram_dist), np.max(true_dist), np.max(spaotsc_dist), np.max(novosparc_dist)])

def using_mpl_scatter_density(ax, x, y, title, label=True):
    density = ax.scatter_density(x, y, cmap=white_viridis)
    plt.title(title)
    if label:
        plt.ylabel("Pairwise distance between predicted coordinates", fontsize=15)
    lims = [0, max_lim]
    ax.plot(lims, lims, 'k-', alpha=0.75, zorder=0)
    plt.ylim((0, max_lim))
    plt.xlim((0, max_lim))
    fig.colorbar(density)
    
fig = plt.figure(figsize=(30,6))
ax1 = fig.add_subplot(1,4,1,projection='scatter_density')
using_mpl_scatter_density(ax1, true_dist, celery_dist, title="CeLEry")
ax2 = fig.add_subplot(1,4,2,projection='scatter_density')
using_mpl_scatter_density(ax2, true_dist, tangram_dist, title="Tangram", label=False)
ax3 = fig.add_subplot(1,4,3,projection='scatter_density')
using_mpl_scatter_density(ax3, true_dist, spaotsc_dist, title="SpaOTsc", label=False)
ax4 = fig.add_subplot(1,4,4,projection='scatter_density')
using_mpl_scatter_density(ax4, true_dist, novosparc_dist, title="novoSpaRc", label=False)

fig.text(0.5, 0.02, 'Pairwise distance between true coordinates', va='center', ha='center', fontsize=15)
figname = "output/liver/plot/pairwise.png"
fig.savefig(figname)
plt.close(fig)

Relative gene expression map

In [10]:
data_merfish = Qdata.copy()
Qdata_df = pd.DataFrame(data_merfish.X.copy())
Qdata_df.columns = data_merfish.var.index

data_merfish_all = Qdata_all.copy()
Qdata_df_all = pd.DataFrame(data_merfish_all.X.copy())
Qdata_df_all.columns = data_merfish_all.var.index

# gene_lst = Qdata_df.columns
gene_lst = ["CEACAM1", "TGFB1"]

x_min = np.min([np.min(data_merfish_all.obs['x_cord']), np.min(data_merfish_all.obs['x_celery']), np.min(data_merfish.obs['x_tangram']), np.min(data_merfish.obs['x_spaotsc']), np.min(data_merfish.obs['x_novosparc'])]) - 150
x_max = np.max([np.max(data_merfish_all.obs['x_cord']), np.max(data_merfish_all.obs['x_celery']), np.max(data_merfish.obs['x_tangram']), np.min(data_merfish.obs['x_spaotsc']), np.min(data_merfish.obs['x_novosparc'])]) + 150
y_min = np.min([np.min(data_merfish_all.obs['y_cord']), np.min(data_merfish_all.obs['y_celery']), np.min(data_merfish.obs['y_tangram']), np.min(data_merfish.obs['y_spaotsc']), np.min(data_merfish.obs['y_novosparc'])]) - 150
y_max = np.max([np.max(data_merfish_all.obs['y_cord']), np.max(data_merfish_all.obs['y_celery']), np.max(data_merfish.obs['y_tangram']), np.min(data_merfish.obs['y_spaotsc']), np.min(data_merfish.obs['y_novosparc'])]) + 150

for i in gene_lst:
    map_col = "lightgoldenrodyellow"
    point_col = "GnBu"

    cmap_spa = Qdata_df[i]
    cmap_spa = np.stack(cmap_spa)
    cmap_spa[cmap_spa > np.quantile(cmap_spa, 0.995)] = np.quantile(cmap_spa, 0.995)
    cmap_spa = cmap_spa/np.max(cmap_spa)

    cmap_spa_all = Qdata_df_all[i]
    cmap_spa_all = np.stack(cmap_spa_all)
    cmap_spa_all[cmap_spa_all > np.quantile(cmap_spa_all, 0.995)] = np.quantile(cmap_spa_all, 0.995)
    cmap_spa_all = cmap_spa_all/np.max(cmap_spa_all)

    min_val = np.min([np.min(cmap_spa)])
    max_val = np.max([np.max(cmap_spa)])


    fig = plt.figure(figsize=(25,5))
    ax1 = fig.add_subplot(1,5,1)
    ax1.set_facecolor(map_col)
    ax1.scatter(Qdata_all.obs['x_cord'], Qdata_all.obs['y_cord'],s=1,c=cmap_spa_all, cmap=plt.get_cmap(point_col), vmin=min_val, vmax=max_val)
    ax1.set_xlim(x_min, x_max)
    ax1.set_title("Truth")

    ax2 = fig.add_subplot(1,5,2)
    ax2.set_facecolor(map_col)
    ax2.scatter(Qdata_all.obs['x_celery'], Qdata_all.obs['y_celery'], s=1,c=cmap_spa_all, cmap=plt.get_cmap(point_col), vmin=min_val, vmax=max_val)
    ax2.set_xlim(x_min, x_max)
    ax2.set_title("CeLEry")

    ax3 = fig.add_subplot(1,5,3)
    ax3.set_facecolor(map_col)
    ax3.scatter(data_merfish.obs['x_tangram'], data_merfish.obs['y_tangram'], s=2,c=cmap_spa, cmap=plt.get_cmap(point_col), vmin=min_val, vmax=max_val)
    ax3.set_xlim(x_min, x_max)
    ax3.set_title("Tangram")

    ax4 = fig.add_subplot(1,5,4)
    ax4.set_facecolor(map_col)
    ax4.scatter(data_merfish.obs['x_spaotsc'], data_merfish.obs['y_spaotsc'], s=2,c=cmap_spa, cmap=plt.get_cmap(point_col), vmin=min_val, vmax=max_val)
    ax4.set_xlim(x_min, x_max)
    ax4.set_title("SpaOTsc")

    ax5 = fig.add_subplot(1,5,5)
    a5 = ax5.set_facecolor(map_col)
    ax5.scatter(data_merfish.obs['x_novosparc'], data_merfish.obs['y_novosparc'], s=2,c=cmap_spa, cmap=plt.get_cmap(point_col), vmin=min_val, vmax=max_val)
    ax5.set_xlim(x_min, x_max)
    ax5.set_title("novoSpaRc")

    colormap = plt.cm.get_cmap(point_col)
    sm = plt.cm.ScalarMappable(cmap=colormap)
    fig.colorbar(sm, fraction=0.046, pad=0.04)
    fig.text(0.07, 0.5, i, va='center', ha='center', rotation='vertical', fontsize = 12)

    figname = "output/liver/plot/" + i + ".png"
    fig.savefig(figname)
    plt.close(fig)