# Notebook containing code used for manuscript supplementary figure 5

### Note that most paths will need to be changed to match where the files were saved on your local machine. Note also that R (with the fitEC package, called from Python through rpy2) must be installed for one portion of the digitization benchmark (Figure S5g).

In [None]:
import anndata
import matplotlib.pyplot as plt
import plotly.express as px
import plotly.io as pio
pio.renderers.default='iframe'
import seaborn as sns
import numpy as np
import pandas as pd
import statsmodels.api as sm
from scipy import stats
from scipy.sparse import csr_matrix
from sklearn import linear_model

import cv2
from skimage.segmentation import find_boundaries

import commot as ct
import spateo as st
import dynamo as dyn
import ncem
from ncem.data import get_data_custom, customLoader

import torch
from torch import nn

import pickle
import os
import sys
from tqdm import tqdm

from scipy.spatial import KDTree

In [None]:
# Seed NumPy's global random number generator so runs are reproducible.
np.random.seed(0)

# Digitization

## Resources used for the digitization benchmark (simulated data, macaque cortex data, Visium mouse brain, MERFISH U2-OS, and MERFISH mouse cortex) can all be found here: https://www.dropbox.com/scl/fo/4tyq5nbex7e2yo3anuemy/AA2p0y9ZAnN6oLWgAtoWA-0?rlkey=jdnxyg8jwx17iexxyn8dvqi7y&st=rsf779qu&dl=0

## The fitEC R package and the slideseq_helpers module (which provides the functions used to run Belayer) can also be found in folders within this Dropbox folder.

In [None]:
# Add slideseq_helpers folder to path
# (edit this to the local folder containing the slideseq_helpers module, which provides
# alpha_shape and harmonic_slideseq for the Belayer comparison runs below)
sys.path.insert(0, '/mnt/d/SCAnalysis/raphael-group-belayer-bb7b493/src')

### Figure S5a- simulated case 1

In [None]:
one_slice_adata = st.read_h5ad("/mnt/d/SCData/digitization/simulated/half_circle_only_coor.h5ad")
# One constant label, so the whole simulated tissue is treated as a single cluster.
one_slice_adata.obs['label'] = "Simulated"
one_slice_adata

#### Perform Spateo digitization

In [None]:
# Rasterize the cluster labels (bin_size=1) and extract the region contour.
# gen_cluster_image presumably writes obs["cluster_img_label"], which the next line reads.
cluster_label_image_lowres = st.dd.gen_cluster_image(one_slice_adata, bin_size=1, spatial_key="spatial", cluster_key='label', show=False)
cluster_label_list = np.unique(one_slice_adata.obs["cluster_img_label"])
contours, cluster_image_close, cluster_image_contour = st.dd.extract_cluster_contours(cluster_label_image_lowres, cluster_label_list, bin_size=1, k_size=1, show=False)

In [None]:
# Inspect the contour image to choose the corner points used in the next cell.
px.imshow(cluster_image_contour)

In [None]:
# User input to specify a gridding direction
# (hand-picked pixel coordinates on the contour image; specific to this dataset)
pnt_xY = (116,273)
pnt_xy = (47,273)
pnt_Xy = (308,273)
pnt_XY = (243,273)

# Digitize the area of interest
# Expected to add obs['digital_layer'] and obs['digital_column'], which later cells read.
st.dd.digitize(
    adata=one_slice_adata,
    ctrs=contours,
    ctr_idx=0,
    pnt_xy=pnt_xy,
    pnt_xY=pnt_xY,
    pnt_Xy=pnt_Xy,
    pnt_XY=pnt_XY,
    spatial_key="spatial"
)

#### Run Belayer (for comparison in benchmarking)

In [None]:
import networkx as nx
from slideseq_helpers import alpha_shape, harmonic_slideseq

In [None]:
coords = one_slice_adata.obsm['spatial'] * 10
edges = list(alpha_shape(coords, alpha=10, only_outer=True))
G = nx.DiGraph(edges)

edge_pt=list(nx.simple_cycles(G))
edge_pts=[(coords[e,0],coords[e,1]) for e in edge_pt[0]]
len(edge_pts)

In [None]:
len(edge_pts)
#%%
fig, ax = plt.subplots(figsize=(3,5))
for i, j in edges:
    plt.plot(coords[[i, j], 0], coords[[i, j], 1], color='grey', ls='-',alpha=0.5)

plt.scatter(coords[:,0],coords[:,1],s=2,color='white')


x, y = edge_pts[162]
plt.scatter(x, y, s=10, color='red')

x, y = edge_pts[231]
plt.scatter(x, y, s=10, color='red')

x, y = edge_pts[396]
plt.scatter(x, y, s=10, color='red')

x, y = edge_pts[461]
plt.scatter(x, y, s=10, color='red')

In [None]:
boundary_array = []

# First layer boundary: wraps around the end of the boundary cycle
# (461 -> last point -> 0 -> 162), so it runs to the end of edge_pts rather than to a
# hard-coded length.
bound_2 = edge_pts[461:]
bound_2 = bound_2 + edge_pts[0:162+1]
boundary_array.append(np.array([[i[0], i[1]] for i in bound_2]))

# Second layer boundary: points 231-396 (both corners included).
boundary_array.append(np.array([[i[0], i[1]] for i in edge_pts[231:396+1]]))

one_slice_adata.obs['belayer'] = harmonic_slideseq(coords, boundary_array, grid_spacing=25, radius=1)

In [None]:
# Column boundaries: the two arcs between the remaining corner pairs (162-231 and 396-461).
boundary_array = []

boundary_array.append(np.array([[i[0], i[1]] for i in edge_pts[162:231+1]]))

bound_2 = edge_pts[396:461+1]
boundary_array.append(np.array([[i[0], i[1]] for i in bound_2]))

# NOTE(review): radius=10 here, while the layer run for this case and both Belayer runs
# for simulated case 2 use radius=1 — confirm the difference is intentional.
one_slice_adata.obs['belayer_column'] = harmonic_slideseq(coords, boundary_array, grid_spacing=25, radius=10)

In [None]:
one_slice_adata.obs['digital_layer'] = one_slice_adata.obs['digital_layer']/one_slice_adata.obs['digital_layer'].max() * 100 // 1
one_slice_adata.obs['belayer'] = one_slice_adata.obs['belayer']/one_slice_adata.obs['belayer'].max() * 100 // 1

In [None]:
one_slice_adata.obs['digital_layer_5'] = (one_slice_adata.obs['digital_layer']-1)//20
one_slice_adata.obs['belayer_5'] = (one_slice_adata.obs['belayer'])//20 
one_slice_adata.obs['digital_layer_5'] = one_slice_adata.obs['digital_layer_5'].astype(str)
one_slice_adata.obs['belayer_5'] = one_slice_adata.obs['belayer_5'].astype(str)
st.pl.space(
    one_slice_adata,
    color=['digital_layer_5', 'belayer_5'],
    ncols=2,
    pointsize=0.1,
    show_legend="upper left",
    figsize=(4, 5),
    color_key_cmap = "RdBu_r",
)

In [None]:
one_slice_adata.obs['digital_column'] = one_slice_adata.obs['digital_column']/one_slice_adata.obs['digital_column'].max() * 100 // 1
one_slice_adata.obs['belayer_column'] = one_slice_adata.obs['belayer_column']/one_slice_adata.obs['belayer_column'].max() * 100 // 1

In [None]:
one_slice_adata.obs['digital_column_10'] = (one_slice_adata.obs['digital_column']-1)//10
one_slice_adata.obs['belayer_column_10'] = (one_slice_adata.obs['belayer_column'])//10 
one_slice_adata.obs['digital_column_10'] = one_slice_adata.obs['digital_column_10'].astype(str)
one_slice_adata.obs['belayer_column_10'] = one_slice_adata.obs['belayer_column_10'].astype(str)
st.pl.space(
    one_slice_adata,
    color=['digital_column_10', 'belayer_column_10'],
    ncols=2,
    pointsize=0.1,
    show_legend="upper left",
    figsize=(4, 5),
    color_key_cmap = "RdBu_r",
)

### Figure S5b- simulated case 2

In [None]:
one_slice_adata = st.read_h5ad("/mnt/d/SCData/digitization/simulated/fanshape_only_coor.h5ad")
# One constant label, so the whole simulated tissue is treated as a single cluster.
one_slice_adata.obs['label'] = "Simulated"
one_slice_adata

#### Perform Spateo digitization

In [None]:
# Rasterize the cluster labels (bin_size=1) and extract the region contour.
# gen_cluster_image presumably writes obs["cluster_img_label"], which the next line reads.
cluster_label_image_lowres = st.dd.gen_cluster_image(one_slice_adata, bin_size=1, spatial_key="spatial", cluster_key='label', show=False)
cluster_label_list = np.unique(one_slice_adata.obs["cluster_img_label"])
contours, cluster_image_close, cluster_image_contour = st.dd.extract_cluster_contours(cluster_label_image_lowres, cluster_label_list, bin_size=1, k_size=1, show=False)

In [None]:
# Inspect the contour image to choose the corner points used in the next cell.
px.imshow(cluster_image_contour)

In [None]:
# User input to specify a gridding direction
# (hand-picked pixel coordinates on the contour image; specific to this dataset)
pnt_xY = (172,213)
pnt_xy = (141,164)
pnt_Xy = (214,164)
pnt_XY = (184,213)

# Digitize the area of interest
# Expected to add obs['digital_layer'] and obs['digital_column'], which later cells read.
st.dd.digitize(
    adata=one_slice_adata,
    ctrs=contours,
    ctr_idx=0,
    pnt_xy=pnt_xy,
    pnt_xY=pnt_xY,
    pnt_Xy=pnt_Xy,
    pnt_XY=pnt_XY,
    spatial_key="spatial"
)

#### Run Belayer (for comparison in benchmarking)

In [None]:
import networkx as nx
from slideseq_helpers import alpha_shape, harmonic_slideseq

In [None]:
# Scale coordinates (x10) and compute the outer alpha-shape boundary used by Belayer.
coords = one_slice_adata.obsm['spatial'] * 10
edges = list(alpha_shape(coords, alpha=10, only_outer=True))
G = nx.DiGraph(edges)

# Take the first boundary cycle and convert its node indices into (x, y) points.
edge_pt=list(nx.simple_cycles(G))
edge_pts=[(coords[e,0],coords[e,1]) for e in edge_pt[0]]
len(edge_pts)

In [None]:
# Plot the boundary and mark the hand-picked corner indices (49, 61, 110, 0) that split
# it into the layer and column boundaries.
fig, ax = plt.subplots(figsize=(2,2.5))
for i, j in edges:
    plt.plot(coords[[i, j], 0], coords[[i, j], 1], color='grey', ls='-',alpha=0.5)

plt.scatter(coords[:,0],coords[:,1],s=2,color='white')


x, y = edge_pts[49]
plt.scatter(x, y, s=10, color='red')

x, y = edge_pts[61]
plt.scatter(x, y, s=10, color='red')

x, y = edge_pts[110]
plt.scatter(x, y, s=10, color='red')

x, y = edge_pts[0]
plt.scatter(x, y, s=10, color='red')

In [None]:
boundary_array = []

# First layer boundary: wraps around the end of the boundary cycle
# (110 -> last point -> 0), so it runs to the end of edge_pts rather than to a
# hard-coded length.
bound_2 = edge_pts[110:]
bound_2 = bound_2 + edge_pts[0:0+1]
boundary_array.append(np.array([[i[0], i[1]] for i in bound_2]))

# Second layer boundary: points 49-61 (both corners included).
boundary_array.append(np.array([[i[0], i[1]] for i in edge_pts[49:61+1]]))

one_slice_adata.obs['belayer'] = harmonic_slideseq(coords, boundary_array, grid_spacing=25, radius=1)

In [None]:
# Column boundaries: the arcs between the remaining corner pairs (0-49 and 61-110).
boundary_array = []

boundary_array.append(np.array([[i[0], i[1]] for i in edge_pts[0:49+1]]))

bound_2 = edge_pts[61:110+1]
boundary_array.append(np.array([[i[0], i[1]] for i in bound_2]))

one_slice_adata.obs['belayer_column'] = harmonic_slideseq(coords, boundary_array, grid_spacing=25, radius=1)

In [None]:
one_slice_adata.obs['digital_layer'] = one_slice_adata.obs['digital_layer']/one_slice_adata.obs['digital_layer'].max() * 100 // 1
one_slice_adata.obs['belayer'] = one_slice_adata.obs['belayer']/one_slice_adata.obs['belayer'].max() * 100 // 1

In [None]:
one_slice_adata.obs['digital_layer_5'] = (one_slice_adata.obs['digital_layer']-1)//20
one_slice_adata.obs['belayer_5'] = (one_slice_adata.obs['belayer'])//20 
one_slice_adata.obs['digital_layer_5'] = one_slice_adata.obs['digital_layer_5'].astype(str)
one_slice_adata.obs['belayer_5'] = one_slice_adata.obs['belayer_5'].astype(str)
st.pl.space(
    one_slice_adata,
    color=['digital_layer_5', 'belayer_5'],
    ncols=2,
    pointsize=0.2,
    show_legend="upper left",
    figsize=(4, 5),
    color_key_cmap = "RdBu_r",
)

In [None]:
one_slice_adata.obs['digital_column'] = one_slice_adata.obs['digital_column']/one_slice_adata.obs['digital_column'].max() * 100 // 1
one_slice_adata.obs['belayer_column'] = one_slice_adata.obs['belayer_column']/one_slice_adata.obs['belayer_column'].max() * 100 // 1

In [None]:
one_slice_adata.obs['digital_column_10'] = (one_slice_adata.obs['digital_column']-1)//10
one_slice_adata.obs['belayer_column_10'] = (one_slice_adata.obs['belayer_column'])//10 
one_slice_adata.obs['digital_column_10'] = one_slice_adata.obs['digital_column_10'].astype(str)
one_slice_adata.obs['belayer_column_10'] = one_slice_adata.obs['belayer_column_10'].astype(str)
st.pl.space(
    one_slice_adata,
    color=['digital_column_10', 'belayer_column_10'],
    ncols=2,
    pointsize=0.2,
    show_legend="upper left",
    figsize=(4, 5),
    color_key_cmap = "RdBu_r",
)

### Figure S5c- macaque cortex

In [None]:
one_slice_adata = st.read_h5ad("/mnt/d/SCData/digitization/macaque_cortex/T40_adata_only_coor.h5ad")
# Slice label (the contour below is built from obs['cell_type'], not from this column).
one_slice_adata.obs['label'] = "T40"
one_slice_adata

#### Perform Spateo digitization

In [None]:
# Bin coordinates by 50 and keep only cells annotated with cortical layer types.
one_slice_adata.obsm['spatial_bin50'] = one_slice_adata.obsm['spatial']//50
subset = one_slice_adata[one_slice_adata.obs['cell_type'].isin(['L2','L2/3','L2/3/4','L3/4','L3/4/5','L4','L4/5','L4/5/6','L5/6','L6']), :].copy()

In [None]:
# Rasterize the layer annotations of the subset; the contour is extracted by hand in the next cell.
cluster_label_image_lowres = st.dd.gen_cluster_image(subset, bin_size=1, spatial_key="spatial_bin50", cluster_key='cell_type', show=False)
cluster_label_list = np.unique(subset.obs["cluster_img_label"])

In [None]:
from skimage import morphology

close_kernel=cv2.MORPH_ELLIPSE
bin_size=1
k_size = 9
min_area = 9000
cluster_label_image = cluster_label_image_lowres
cluster_labels = cluster_label_list

cluster_image_close = cluster_label_image.copy()
if type(cluster_labels) == int:
    cluster_image_close = np.where(cluster_image_close == cluster_labels, cluster_image_close, 0)
else:
    cluster_image_close = np.where(np.isin(cluster_image_close, cluster_labels), cluster_image_close, 0)

kernal = cv2.getStructuringElement(close_kernel, (k_size, k_size))
cluster_image_close = cv2.morphologyEx(cluster_image_close, cv2.MORPH_CLOSE, kernal)

cluster_image_close = morphology.remove_small_objects(
    cluster_image_close.astype(bool),
    min_area,
    connectivity=2,
).astype(np.uint8)


ksize=3
kernal = np.ones((k_size, k_size),np.uint8)
cluster_image_close = cv2.erode(cluster_image_close, kernal)
ksize=5
kernal = np.ones((k_size, k_size),np.uint8)
cluster_image_close = cv2.dilate(cluster_image_close, kernal)
px.imshow(cluster_image_close)

In [None]:
cluster_image_close[1139:1147,248:312] = 0
px.imshow(cluster_image_close)

In [None]:
contours, _ = cv2.findContours(cluster_image_close, cv2.RETR_LIST , cv2.CHAIN_APPROX_NONE)

cluster_image_contour = np.zeros((cluster_label_image.shape[0], cluster_label_image.shape[1]))
for i in range(len(contours)):
    cv2.drawContours(cluster_image_contour, contours, i, i + 1, bin_size)
px.imshow(cluster_image_contour)

In [None]:
# User input to specify a gridding direction
pnt_xY = (1324,895)
pnt_xy = (1287,909)
pnt_Xy = (832,1110)
pnt_XY = (798,1103)

# Digitize the area of interest
st.dd.digitize(
    adata=one_slice_adata,
    ctrs=contours,
    ctr_idx=0,
    pnt_xy=pnt_xy,
    pnt_xY=pnt_xY,
    pnt_Xy=pnt_Xy,
    pnt_XY=pnt_XY,
    spatial_key="spatial_bin50"
)

#### Run Belayer (for comparison in benchmarking)

In [None]:
import networkx as nx
from slideseq_helpers import alpha_shape, harmonic_slideseq

In [None]:
one_slice_adata.obsm['spatial_bin50'] = one_slice_adata.obsm['spatial']//50
subset = one_slice_adata[one_slice_adata.obs['cell_type'].isin(['L2','L2/3','L2/3/4','L3/4','L3/4/5','L4','L4/5','L4/5/6','L5/6','L6']), :].copy()
coords = subset.obsm['spatial_bin50']

edges = list(alpha_shape(coords, alpha=4, only_outer=True))
G = nx.DiGraph(edges)

edge_pt=list(nx.simple_cycles(G))
edge_pts=[(coords[e,0],coords[e,1]) for e in edge_pt[50]]

In [None]:
fig, ax = plt.subplots(figsize=(8,10))
for i, j in edges:
    plt.plot(coords[[i, j], 0], coords[[i, j], 1], color='grey', ls='-',alpha=0.5)

plt.scatter(coords[:,0],coords[:,1],s=2,color='white')

x, y = edge_pts[315]
plt.scatter(x, y, s=10, color='red')

x, y = edge_pts[1945]
plt.scatter(x, y, s=10, color='red')

x, y = edge_pts[1957]
plt.scatter(x, y, s=10, color='red')

x, y = edge_pts[303]
plt.scatter(x, y, s=10, color='red')

In [None]:
boundary_array = []
boundary_array.append(np.array([[i[0], i[1]] for i in edge_pts[315:1945+1]]))

bound_2 = edge_pts[1957:3753]
bound_2 = bound_2+edge_pts[0:303]
boundary_array.append(np.array([[i[0], i[1]] for i in bound_2]))

subset.obs['belayer'] = harmonic_slideseq(coords, boundary_array, grid_spacing=10, radius=10)

In [None]:
# Belayer was run on `subset` (layer cell types only), so its result lives in subset.obs.
# Copy it back onto the full object by cell name; cells outside the subset get NaN
# (shown as 'nan' after the str conversion below).
one_slice_adata.obs['belayer'] = subset.obs['belayer'].reindex(one_slice_adata.obs_names)

# Bin both digitizations into groups of 10 (values are not rescaled to 0-100 here).
one_slice_adata.obs['digital_layer_5'] = (one_slice_adata.obs['digital_layer']-1)//10
one_slice_adata.obs['belayer_5'] = (one_slice_adata.obs['belayer'])//10
one_slice_adata.obs['digital_layer_5'] = one_slice_adata.obs['digital_layer_5'].astype(str)
one_slice_adata.obs['belayer_5'] = one_slice_adata.obs['belayer_5'].astype(str)
st.pl.space(
    one_slice_adata,
    color=['digital_layer_5', 'belayer_5'],
    ncols=2,
    pointsize=0.1,
    show_legend="upper left",
    figsize=(4, 5),
    color_key_cmap = "RdBu_r",
)

### Figure S5d- Visium mouse brain

In [None]:
adata = st.read_h5ad("/mnt/d/SCData/digitization/visium_mouse/visium_adata.h5ad")
# NOTE(review): presumably flags the data as UMI counts for downstream spateo/dynamo calls — confirm.
adata.uns['__type'] = "UMI"
adata

In [None]:
st.pl.space(adata=adata, color='cluster', pointsize=0.5, show_legend="upper left")

In [None]:
# Bin spatial coordinates for contour extraction.
# NOTE(review): the key is named "bin30", but coordinates are divided by 100. The corner
# points in the next cell are in this //100 space (digitize uses spatial_key="spatial_bin30"),
# so the divisor and the points must change together.
adata.obsm['spatial_bin30'] = adata.obsm['spatial']//100
cluster_label_image_lowres = st.dd.gen_cluster_image(adata, bin_size=1, spatial_key="spatial_bin30", cluster_key='cluster', show=False)
# Image labels of clusters '1' and '3', whose combined region is outlined.
cluster_label_list = np.unique(adata[adata.obs['cluster'].isin(['1', '3']), :].obs["cluster_img_label"])
contours, cluster_image_close, cluster_image_contour = st.dd.extract_cluster_contours(cluster_label_image_lowres, cluster_label_list, bin_size=1, k_size=3, show=False, min_area=1000)
px.imshow(cluster_image_contour, width=500, height=500)

In [None]:
# User input to specify a gridding direction
pnt_xY = (10,63)
pnt_xy = (14,59)
pnt_Xy = (76,98)
pnt_XY = (64,98)

# Digitize the area of interest
st.dd.digitize(
    adata=adata,
    ctrs=contours,
    ctr_idx=0,
    pnt_xy=pnt_xy,
    pnt_xY=pnt_xY,
    pnt_Xy=pnt_Xy,
    pnt_XY=pnt_XY,
    spatial_key="spatial_bin30"
)

In [None]:
st.pl.space(adata, color="digital_column", cmap = "RdBu_r")

### Figure S5e- Gradients along the mouse cortex A-P axis

In [None]:
# Start with an empty preprocessing record; dynamo's normalization presumably stores its
# settings in adata.uns['pp'] — TODO confirm.
adata.uns['pp'] = {}
# Size-factor normalization of expression (dynamo).
dyn.pp.normalize_cell_expr_by_size_factors(adata)

In [None]:
def _dense_gene_values(adata_view, gene):
    """Return the expression of `gene` in `adata_view` as a 1-D dense numpy array."""
    values = adata_view[:, gene].X
    if hasattr(values, "toarray"):  # scipy sparse matrix
        values = values.toarray()
    return np.asarray(values).ravel()


def _mean_confidence_bounds(data, confidence=0.90):
    """Return (low, high) Student-t confidence bounds for the mean of `data`.

    A bound that comes out NaN (e.g. zero variance or a single observation) falls back
    to the sample mean.
    """
    import scipy.stats as stat

    mean = np.mean(data)
    # The confidence level is passed positionally: SciPy 1.9 renamed the `alpha=` keyword
    # to `confidence=`, and the old keyword has since been removed.
    low, high = stat.t.interval(confidence, len(data) - 1, loc=mean, scale=stat.sem(data))
    if np.isnan(low):
        low = mean
    if np.isnan(high):
        high = mean
    return low, high


def _weighted_kde_axes(regions, genes, weights, region_key):
    """Draw per-gene KDEs of `regions` weighted by `weights` on a throwaway figure.

    Returns the Axes; its line data is read back by `polarity`.
    """
    df_plt = pd.DataFrame({region_key: regions, "Gene": genes, "Mean expression": weights})
    plt.figure()
    kde_ax = sns.kdeplot(data=df_plt, x=region_key, common_norm=False, weights="Mean expression", hue="Gene")
    plt.close()
    return kde_ax


def polarity(
    adata,
    gene_dict: dict,
    region_key: str,
    palette: list,
    mode: str = "density",
    itv_rpt: int = 1,
    width: float = 5,
    height: float = 3.5,
):
    """Visualize how the expression of selected genes varies across regions.

    Args:
        adata (AnnData): expression data; ``adata.obs[region_key]`` holds the region value
            (e.g. a binned digitization coordinate) of each observation.
        gene_dict (dict): maps an annotation string to a list of gene names; each gene is
            labelled ``"<gene> <annotation>"`` in the plot.
        region_key (str): column of ``adata.obs`` holding the region values.
        palette (list): colors, used from the end of the list backwards, one per gene.
        mode (str, optional): "exp" scatters every observation's expression against its
            region (seaborn relplot). "density" draws, for each gene, a KDE over regions
            weighted by the per-region mean expression, with a shaded band derived from a
            90% t confidence interval of that mean. Defaults to "density".
        itv_rpt (int, optional): each observation is repeated this many times before the
            confidence interval is computed (narrows the band when > 1). Defaults to 1.
        width (float, optional): figure width in inches ("density" mode). Defaults to 5.
        height (float, optional): figure height in inches ("density" mode). Defaults to 3.5.

    Returns:
        The seaborn FacetGrid in "exp" mode, or the matplotlib Axes in "density" mode.

    Raises:
        ValueError: if `mode` is not "exp" or "density".
    """
    if mode not in ("exp", "density"):
        raise ValueError(f"mode must be 'exp' or 'density', got {mode!r}")

    regions, genes, means, lows, highs = [], [], [], [], []
    for region in np.unique(adata.obs[region_key]):
        adata_tmp = adata[adata.obs[region_key] == region, :]
        for anno, anno_genes in gene_dict.items():
            for gene in anno_genes:
                label = gene + " " + anno
                values = _dense_gene_values(adata_tmp, gene)
                if mode == "exp":
                    # One row per observation.
                    regions.extend([region] * len(values))
                    genes.extend([label] * len(values))
                    means.extend(values)
                else:
                    # One row per (region, gene): mean expression plus its CI bounds.
                    regions.append(region)
                    genes.append(label)
                    means.append(np.mean(values))
                    low, high = _mean_confidence_bounds(np.repeat(values, itv_rpt))
                    lows.append(max(0, low))
                    highs.append(high)

    if mode == "exp":
        df_plt = pd.DataFrame({region_key: regions, "Gene": genes, "Mean expression": means})
        return sns.relplot(data=df_plt, x=region_key, y="Mean expression", hue="Gene")

    regions = np.asarray(regions)
    genes = np.asarray(genes)
    means = np.asarray(means, dtype=float)
    # Tiny offset keeps the bound weights strictly positive for the weighted KDEs.
    lows = np.asarray(lows, dtype=float) + 1e-10
    highs = np.asarray(highs, dtype=float) + 1e-10

    kde_mean = _weighted_kde_axes(regions, genes, means, region_key)
    kde_low = _weighted_kde_axes(regions, genes, lows, region_key)
    kde_high = _weighted_kde_axes(regions, genes, highs, region_key)

    fig, ax = plt.subplots()
    fig.set_size_inches((width, height))
    unique_genes = np.unique(genes)
    for k in range(len(unique_genes)):
        # Line k of each KDE Axes is paired with the k-th label from the end of the sorted
        # gene labels. NOTE(review): this relies on seaborn drawing hue levels in reverse
        # order — confirm with the installed seaborn version.
        gene_mask = genes == unique_genes[-(k + 1)]
        color = palette[len(palette) - k - 1]
        x_grid, y_mean = kde_mean.get_children()[k].get_data()
        # Rescale the bound KDEs by (total bound weight / total mean weight) of this gene.
        scale_low = np.sum(lows[gene_mask]) / np.sum(means[gene_mask])
        scale_high = np.sum(highs[gene_mask]) / np.sum(means[gene_mask])
        ax.plot(x_grid, y_mean, '-', color=color)
        ax.fill_between(
            x_grid,
            kde_low.get_children()[k].get_data()[1] * scale_low,
            kde_high.get_children()[k].get_data()[1] * scale_high,
            color=color,
            alpha=0.1,
        )
    return ax

In [None]:
# Genes shown in Figure S5e. In polarity() each gene is labelled "<gene> <key>".
gene_dict = {
    '1': ['Epha7'],
    '2': ['Epha5'],
    '3': ['Cntnap2'],
    '4': ['Nr2f1'],
    '5': ['Lhx2'],
}

import matplotlib as mpl
# One tab10 color per gene_dict entry, as hex strings.
palette = [mpl.colors.to_hex(i) for i in sns.color_palette("tab10",n_colors=len(gene_dict))]

In [None]:
# Merge pairs of adjacent digital columns (integer division by 2) into coarser bins.
adata.obs['digital_column_25'] = adata.obs['digital_column'] // 2
ax = polarity(adata, gene_dict, region_key="digital_column_25", palette=palette,)
ax.tick_params(direction="out")
# Show only bins 10-40 along the x axis.
plt.xlim([10,40])

### Figure S5f- Nuclear boundary-enriched genes in the MERFISH U2-OS dataset

#### Digitization

In [None]:
# Pixel-level image object; its layers (e.g. 'batch0_cell_shape', 'batch0_nucleus_shape')
# drive the subcellular digitization below.
spot_adata = st.read_h5ad("/mnt/d/SCData/digitization/subcellular/subcellular_by_scc/spot.h5ad")
spot_adata

In [None]:
def digital_loop(a, cell_bound, cell_shape, nucleus_labels, eps=1e4):
    """Relax an initial digitization image into a smooth nucleus-to-membrane field.

    Applies `conv_and_deal_with_non_cell` repeatedly, first with a 51 x 51 box kernel and
    then with a 7 x 7 one. A stage stops once the absolute value of the summed change
    between consecutive iterates is at most `eps`. The field from before that last step is
    carried into the next stage, and the final 7 x 7 iterate is returned.
    """
    # NOTE(review): the stopping test is |sum(change)|, not sum(|change|), so increases and
    # decreases in different regions can cancel; kept as-is to preserve existing results.
    current = a.copy()
    result = None
    for kernel_size in (51, 7):
        while True:
            result = conv_and_deal_with_non_cell(current, kernel_size, cell_shape, cell_bound, nucleus_labels)
            if not abs(np.sum(result - current)) > eps:
                break
            current = result.copy()
    return result
        
def conv2d_by_torch(a, k=5, avg=False):
    """Convolve a 2-D array with a k x k all-ones kernel (box filter) using torch.

    Args:
        a: 2-D array; it is converted to float32.
        k: odd kernel size. Zero padding of (k - 1) // 2 keeps the output the same shape
            as `a`, so pixels near the border also sum zeros from outside the array.
        avg: if True, divide by k * k to get a box mean instead of a box sum.

    Returns:
        float32 numpy array with the same shape as `a`.

    Raises:
        ValueError: if `k` is not odd.
    """
    if k % 2 != 1:
        raise ValueError(f"kernel size k must be odd, got {k}")
    image = torch.from_numpy(np.ascontiguousarray(a, dtype=np.float32))[None, None]
    kernel = torch.ones((1, 1, k, k), dtype=torch.float32)
    # A fixed filter needs no autograd graph; skipping it saves memory on large images.
    with torch.no_grad():
        summed = torch.nn.functional.conv2d(image, kernel, padding=(k - 1) // 2)
    divisor = k * k if avg else 1
    return summed.numpy()[0, 0] / divisor
    
def conv_and_deal_with_non_cell(a_pre, k, cell_shape, cell_bound, nucleus_labels):
    """Take one smoothing step, then re-impose the fixed boundary values.

    The image is box-averaged with a k x k kernel. Pixels outside cells (cell_shape == 0)
    and cell-boundary pixels (cell_bound == 1) are then reset to 100, and nucleus pixels
    (nucleus_labels > 0) are reset to 0. The nucleus reset is applied last, so it wins.
    """
    smoothed = conv2d_by_torch(a_pre, k=k, avg=True)
    outside_cells = cell_shape == 0
    smoothed = np.where(outside_cells, 100, smoothed)
    smoothed[np.nonzero(cell_bound == 1)] = 100
    smoothed[np.nonzero(nucleus_labels > 0)] = 0
    return smoothed
    
def digital_adata(adata, cell_layer, nucleus_layer, out_layer, eps=1e3):
    """Digitize each nucleated cell from its nucleus (0) out to its membrane (100).

    Only cells that contain at least one nucleus pixel are kept. The field is relaxed by
    `digital_loop` and written to ``adata.layers[out_layer]``, with -1 for pixels outside
    the kept cells.
    """
    cells = adata.layers[cell_layer]
    nuclei = adata.layers[nucleus_layer]

    nucleated_ids = np.unique(cells[nuclei > 0])
    kept_cells = np.where(np.isin(cells, nucleated_ids), cells, 0)
    membrane = find_boundaries(kept_cells, mode="inner").astype(np.uint8)

    # Initial field: nucleus 0, remaining cell interior 50, everything else 100.
    start = np.where(nuclei > 0, 0, np.where(kept_cells > 0, 50, 100))
    field = digital_loop(start, membrane, kept_cells, nuclei, eps)
    adata.layers[out_layer] = np.where(kept_cells > 0, field, -1)

In [None]:
# Digitize every nucleated cell from nucleus (0) to membrane (100); pixels outside those
# cells become -1. eps=10 is the convergence threshold of the relaxation loop.
digital_adata(spot_adata, 'batch0_cell_shape', 'batch0_nucleus_shape', out_layer='batch0_nucleus_to_mem_digital', eps=10)

In [None]:
plt.imshow(spot_adata.layers['batch0_nucleus_to_mem_digital'])
# Draw the colorbar in its own axes at the right side of the figure.
cax = plt.axes([0.85, 0.2, 0.075, 0.6])
plt.colorbar(cax=cax)

In [None]:
# Transcript-level MERFISH data; uns['points'] holds the per-transcript table.
adata = st.read_h5ad("/mnt/d/SCData/digitization/subcellular/subcellular_by_scc/u2os_merfish.h5ad")
adata

In [None]:
def generate_long_matrix_df(adata, batch):
    """Return the transcripts of one batch as a long data frame.

    Reads ``adata.uns['points']`` (columns 'batch', 'x', 'y', 'gene'), keeps the rows of
    `batch`, drops 'notarget' control probes, and adds integer pixel coordinates
    ('x_int', 'y_int') plus a unit 'count' column. The stored points table is not modified.
    """
    points = adata.uns['points']
    df = points[points['batch'] == batch]
    # Drop rows whose gene name contains 'notarget' (rows with a missing name are dropped
    # too, since NaN == 0 is False). Copy only the filtered rows, so the column assignments
    # below act on an independent frame rather than a slice.
    df = df[df['gene'].str.count('notarget') == 0].copy()
    df['x_int'] = df['x'].astype(np.int32)
    df['y_int'] = df['y'].astype(np.int32)
    df['count'] = 1
    return df

def generate_X_for_one_gene(df, gene, shape=None):
    """Rasterize the transcripts of one gene into a dense (y, x) count image.

    Args:
        df: long transcript table with 'gene', 'x_int', 'y_int' and 'count' columns
            (as produced by generate_long_matrix_df).
        gene: name of the gene to rasterize.
        shape: optional (n_rows, n_cols) of the output image. By default the image is just
            large enough to hold this gene's largest y/x coordinates, so its size varies
            from gene to gene.

    Returns:
        2-D numpy array; counts falling on the same pixel are summed.
    """
    gene_df = df[df['gene'] == gene]
    X = csr_matrix(
        (gene_df['count'], (gene_df['y_int'], gene_df['x_int'])),
        shape=shape,
    ).toarray()
    return X

def reg_one_gene(df, spot_adata, digital_layer, gene):
    """Return the R^2 of a linear fit of expression on digitized layer for one gene.

    The per-(layer bin, cell) mean expression from get_digital_info_for_one_gene is
    regressed on the layer bin with ordinary least squares; the fit's R^2 on the same
    data is returned.
    """
    digital_one_gene_df = get_digital_info_for_one_gene(df, spot_adata, digital_layer, gene)
    # Build a 2-D numpy design matrix. Indexing a pandas Series with [:, None] was deprecated
    # in pandas 1.x and raises in pandas 2.0.
    X = digital_one_gene_df['digital_layer'].to_numpy()[:, None]
    y = digital_one_gene_df['expression'].to_numpy()
    clf = linear_model.LinearRegression()
    clf.fit(X, y)
    return clf.score(X, y)

def get_digital_info_for_one_gene(df, spot_adata, digital_layer, gene, cell_layer='batch0_cell_shape'):
    """Return the mean expression of `gene` per (digitized layer bin, cell).

    The gene's count image, the digitization image (binned with // 10), and the cell-id
    image ``spot_adata.layers[cell_layer]`` are aligned pixel by pixel over their common
    top-left extent. Pixels whose bin is <= 0 are discarded.
    NOTE(review): this drops bin 0 (digitization values 0-9) as well as negative
    (background) values — confirm that is intended.

    Returns:
        DataFrame with columns 'digital_layer', 'cell_ids', 'expression' (mean count per
        pixel in that bin and cell), and 'gene'.
    """
    X = generate_X_for_one_gene(df, gene)
    cell_id_image = spot_adata.layers[cell_layer]
    # Crop all three images to their common extent so the flattened arrays line up.
    n_rows = min(X.shape[0], digital_layer.shape[0], cell_id_image.shape[0])
    n_cols = min(X.shape[1], digital_layer.shape[1], cell_id_image.shape[1])
    X = X[:n_rows, :n_cols]
    cell_ids = cell_id_image[:n_rows, :n_cols]
    digital_one_gene = digital_layer[:n_rows, :n_cols] // 10

    digital_one_gene_df = pd.DataFrame({
        'expression': X.flatten(),
        'digital_layer': digital_one_gene.flatten(),
        'cell_ids': cell_ids.flatten(),
    })
    digital_one_gene_df = digital_one_gene_df[digital_one_gene_df['digital_layer'] > 0]
    digital_one_gene_df = digital_one_gene_df.groupby(['digital_layer', 'cell_ids']).agg('mean').reset_index()
    digital_one_gene_df['gene'] = gene
    return digital_one_gene_df

In [None]:
df = generate_long_matrix_df(adata, 0)
df

In [None]:
digital_layer = spot_adata.layers['batch0_nucleus_to_mem_digital'].astype(np.int32)

In [None]:
vars = []
for gene in np.unique(df['gene']):
    vars.append(reg_one_gene(df, spot_adata, digital_layer, gene))

var_df = pd.DataFrame({'var': vars, 'gene': np.unique(df['gene'])})

In [None]:
# Per-cell layer profiles for the 3 genes with the highest and the 3 with the lowest R^2
# (var_df['var']).
digital_6_gene_dfs = []
for gene in var_df.sort_values(by='var', ascending=False).head(3)['gene']:
    digital_6_gene_dfs.append(get_digital_info_for_one_gene(df, spot_adata, digital_layer, gene))
for gene in var_df.sort_values(by='var', ascending=False).tail(3)['gene']:
    digital_6_gene_dfs.append(get_digital_info_for_one_gene(df, spot_adata, digital_layer, gene))
digital_6_gene_df = pd.concat(digital_6_gene_dfs, ignore_index=True)

digital_6_gene_df

In [None]:
# Figure S5f: expression vs. digitized layer. The warm colors are presumably assigned to the
# top-3 genes, which appear first in the data, and the blue shades to the bottom-3.
sns.relplot(data=digital_6_gene_df, x='digital_layer', y='expression', kind="line", hue='gene', palette=['red', 'darkorange', 'darkred', 'dodgerblue', 'royalblue', 'blue'])

### Figure S5g- Nuclear centroid-enriched genes in the MERFISH U2-OS dataset

In [None]:
df = generate_long_matrix_df(adata, 0)
df['x'] = df['x'].astype(int)
df['y'] = df['y'].astype(int)

# Append one dummy transcript at (x=1, y=1), copied from the first row. .copy() makes sure
# editing the dummy row never writes back into df.
# NOTE(review): the purpose is not visible here — presumably to anchor the image origin for
# whatever reads tmp.csv; confirm.
dummy_row = df.iloc[:1, :].copy()
dummy_row['x'] = 1
dummy_row['y'] = 1

df = pd.concat([df, dummy_row], axis=0)
df['MIDCount'] = df['count'].copy()
df['geneID'] = df['gene'].copy()
df.to_csv("tmp.csv", index=False, sep="\t")
df

In [None]:
# One-time R setup for ellipse fitting with the fitEC package via rpy2 (hoisted out of the
# per-nucleus loop). R_HOME must be set before rpy2 is imported; point it at your local R.
os.environ['R_HOME'] = '/home/jingzh/.conda/envs/spaco_dev/lib/R'

from rpy2.robjects.packages import importr
import rpy2.robjects.numpy2ri

fitEC = importr("fitEC")
# Allow numpy arrays to be passed directly to R functions.
rpy2.robjects.numpy2ri.activate()

# Accumulates the per-nucleus digitization fields (each scaled to 0-100).
spot_adata.layers['img_digital'] = np.zeros_like(spot_adata.layers['batch0_nucleus_shape'])

for cell_n in np.unique(spot_adata.layers['batch0_nucleus_shape']):
    if cell_n == 0:
        continue  # label 0 is background

    # Mask of this nucleus only. get_cell_shape is expected to write its outline to the
    # 'tmp_img_boundary' layer (outline pixels == 255), which is read right after.
    spot_adata.layers['tmp_img'] = spot_adata.layers['batch0_nucleus_shape'].copy()
    spot_adata.layers['tmp_img'][spot_adata.layers['tmp_img'] != cell_n] = 0
    st.cs.utils.get_cell_shape(spot_adata, layer="tmp_img", thickness=1)
    outline_pts = np.array(np.where(spot_adata.layers['tmp_img_boundary'] == 255)).T

    # Fit an ellipse to the outline. As used below: efit[1] = centre, efit[2] / efit[3] =
    # semi-axes a / b, efit[4] = rotation angle (presumably fitEC's output order — confirm).
    efit = fitEC.fit_ellipse(outline_pts)
    a = efit[2][0]
    b = efit[3][0]
    center_x, center_y = efit[1][0], efit[1][1]
    theta = efit[4][0]

    # Distance transform from a segment through the centre with endpoints
    # centre +/- a * (-sin(theta), cos(theta)), scaled by a / b.
    f1_dis = np.ones_like(spot_adata.layers['tmp_img_boundary'])
    x1 = int(center_x - a * np.sin(theta))
    y1 = int(center_y + a * np.cos(theta))
    x2 = int(center_x + a * np.sin(theta))
    y2 = int(center_y - a * np.cos(theta))
    f1_dis = cv2.line(f1_dis, (y1, x1), (y2, x2), (0), 1)
    f1_dis = cv2.distanceTransform(f1_dis, cv2.DIST_L2, cv2.DIST_MASK_PRECISE)
    f1_dis = f1_dis / b * a

    # Distance transform from the segment with endpoints centre +/- a * (cos(theta), sin(theta)).
    f2_dis = np.ones_like(spot_adata.layers['tmp_img_boundary'])
    x1 = int(center_x - a * np.cos(theta))
    y1 = int(center_y - a * np.sin(theta))
    x2 = int(center_x + a * np.cos(theta))
    y2 = int(center_y + a * np.sin(theta))
    f2_dis = cv2.line(f2_dis, (y1, x1), (y2, x2), (0), 1)
    f2_dis = cv2.distanceTransform(f2_dis, cv2.DIST_L2, cv2.DIST_MASK_PRECISE)

    # Combine both distances, shift so the minimum is 0, and keep only this nucleus.
    combined = np.sqrt(f1_dis**2 + f2_dis**2)
    digi_field = combined - np.min(combined)
    nucleus_field = np.where(spot_adata.layers['tmp_img'] != 0, digi_field, 0)

    # Scale to 0-100. A degenerate nucleus can leave the field all zero; dividing by that
    # zero maximum would produce NaN and poison the accumulated 'img_digital' image.
    field_max = np.max(nucleus_field)
    if field_max > 0:
        nucleus_field = nucleus_field / field_max * 100
    spot_adata.layers['tmp_img_digital'] = nucleus_field

    spot_adata.layers['img_digital'] = spot_adata.layers['img_digital'] + spot_adata.layers['tmp_img_digital']

plt.figure(figsize=(10, 10))
plt.imshow(spot_adata.layers['img_digital'], cmap="RdBu_r")

In [None]:
# Thicken the cell outlines with a 3x3 dilation and overlay them (value 50) on the
# per-nucleus digitization image. The dilated layer was previously computed but never
# used; the overlay now draws it instead of the raw 1-pixel outline.
dilate_kernel = np.ones((3, 3), np.uint8)
spot_adata.layers['batch0_cell_shape_boundary_dilate'] = cv2.dilate(spot_adata.layers['batch0_cell_shape_boundary'], dilate_kernel, iterations=1)

plt.imshow(np.where(spot_adata.layers['batch0_cell_shape_boundary_dilate'] != 0, 50, spot_adata.layers['img_digital']), cmap="RdBu_r")

In [None]:
df = generate_long_matrix_df(adata, 0)
df

In [None]:
digital_layer = spot_adata.layers['img_digital'].astype(np.int32)

In [None]:
vars = []
for gene in np.unique(df['gene']):
    vars.append(reg_one_gene(df, spot_adata, digital_layer, gene))

var_df = pd.DataFrame({'var': vars, 'gene': np.unique(df['gene'])})

In [None]:
digital_6_gene_dfs = []
for gene in var_df.sort_values(by='var', ascending=False).head(3)['gene']:
    digital_6_gene_dfs.append(get_digital_info_for_one_gene(df, spot_adata, digital_layer, gene))
for gene in var_df.sort_values(by='var', ascending=False).tail(3)['gene']:
    digital_6_gene_dfs.append(get_digital_info_for_one_gene(df, spot_adata, digital_layer, gene))
digital_6_gene_df = pd.concat(digital_6_gene_dfs, ignore_index=True)

digital_6_gene_df

In [None]:
sns.relplot(data=digital_6_gene_df, x='digital_layer', y='expression', kind="line", hue='gene', palette=['red', 'darkorange', 'darkred', 'dodgerblue', 'royalblue', 'blue'])

### Figure S5h-k- Digitization benchmark on the MERFISH mouse cortex sample

In [None]:
one_slice_adata = st.read_h5ad("/mnt/d/SCData/digitization/merfish/one_slice_digitization.h5ad")
one_slice_adata

In [None]:
# Cell subclass annotation across the slice.
st.pl.space(adata=one_slice_adata, color='subclass', pointsize=0.5, show_legend="upper left")

#### Spateo digitization

In [None]:
# Bin coordinates by 30 and extract the contour of the region covered by the
# Glutamatergic, GABAergic and 'Other' classes.
one_slice_adata.obsm['spatial_bin30'] = one_slice_adata.obsm['spatial']//30
cluster_label_image_lowres = st.dd.gen_cluster_image(one_slice_adata, bin_size=1, spatial_key="spatial_bin30", cluster_key='class_label', show=False)
cluster_label_list = np.unique(one_slice_adata[one_slice_adata.obs['class_label'].isin(['Glutamatergic','GABAergic', 'Other']), :].obs["cluster_img_label"])
contours, cluster_image_close, cluster_image_contour = st.dd.extract_cluster_contours(cluster_label_image_lowres, cluster_label_list, bin_size=1, k_size=6, show=False)
px.imshow(cluster_image_contour, width=500, height=500)

In [None]:
# User input to specify a gridding direction
# (hand-picked points in the //30-binned coordinates of 'spatial_bin30')
pnt_xY = (41,133)
pnt_xy = (0,107)
pnt_Xy = (70,91)
pnt_XY = (53,133)

# Digitize the area of interest
st.dd.digitize(
    adata=one_slice_adata,
    ctrs=contours,
    ctr_idx=0,
    pnt_xy=pnt_xy,
    pnt_xY=pnt_xY,
    pnt_Xy=pnt_Xy,
    pnt_XY=pnt_XY,
    spatial_key="spatial_bin30"
)

In [None]:
# Visualize digitized layers and columns
# (only layers are plotted here; cells with digital_layer == 0 are excluded).
# NOTE(review): digital_layer is not rescaled to 0-100 before the 20-wide binning, unlike
# the simulated cases — confirm the expected range.
one_slice_adata.obs['digital_layer_5'] = (one_slice_adata.obs['digital_layer']-1)//20
one_slice_adata.obs['digital_layer_5'] = one_slice_adata.obs['digital_layer_5'].astype(str)

st.pl.space(
    one_slice_adata[one_slice_adata.obs['digital_layer']!=0],
    color=['digital_layer','digital_layer_5'],
    ncols=2,
    pointsize=0.1,
    show_legend="upper left",
    color_key_cmap="tab20",
    figsize=(4, 5),
)

#### Belayer digitization

In [None]:
import networkx as nx
from slideseq_helpers import alpha_shape, harmonic_slideseq

In [None]:
coords = one_slice_adata.obsm['spatial']

# Outer boundary of the tissue as an alpha shape; its edges are loaded into a directed graph
edges = list(alpha_shape(coords, alpha=100, only_outer=True))
G = nx.DiGraph(edges)

# Use the first cycle found as the ordered list of boundary vertices. `next` stops after that first cycle
# instead of materialising every simple cycle of the graph (the number of simple cycles can blow up).
edge_pts = next(nx.simple_cycles(G))
edge_pts = [(coords[e, 0], coords[e, 1]) for e in edge_pts]

In [None]:
fig, ax = plt.subplots(figsize=(2,3))

# Draw every alpha-shape edge in light grey
for i, j in edges:
    plt.plot(coords[[i, j], 0], coords[[i, j], 1], color='grey', ls='-', alpha=0.5)

# All cell positions
plt.scatter(coords[:, 0], coords[:, 1], s=2, color='grey')

# Highlight (in red) the boundary vertices that delimit the boundary segments passed to Belayer
for boundary_idx in (48, 94, 20, 13):
    x, y = edge_pts[boundary_idx]
    plt.scatter(x, y, s=10, color='red')

In [None]:
boundary_array = []
boundary_array.append(np.array([[i[0], i[1]] for i in edge_pts[48:94+1]]))
boundary_array.append(np.array([[i[0], i[1]] for i in edge_pts[13:20+1]]))
one_slice_adata.obs['belayer'] = harmonic_slideseq(coords, boundary_array, grid_spacing=20, radius=10)

In [None]:
# Visualize digitized layers and columns
one_slice_adata.obs['belayer_5'] = (one_slice_adata.obs['belayer'])//20
one_slice_adata.obs['belayer_5'] = one_slice_adata.obs['belayer_5'].astype(str)

st.pl.space(
    one_slice_adata[one_slice_adata.obs['digital_layer']!=0],
    color=['belayer','belayer_5'],
    ncols=2,
    pointsize=0.1,
    show_legend="upper left",
    color_key_cmap="tab20",
    figsize=(4, 5),
)

#### Distribution of layer-enriched cell types

In [None]:
# Digitized layer (Spateo) and Belayer depth for the layer-enriched IT neuron subclasses
df = one_slice_adata.obs[['subclass', 'digital_layer', 'belayer']]
# .copy() so the assignment below writes to an independent frame, not a view of adata.obs
# (avoids pandas' SettingWithCopyWarning / chained-assignment ambiguity)
df = df[df['subclass'].isin(['L2/3 IT', 'L4/5 IT', 'L5 IT', 'L6 IT'])].copy()
# Rebuild the categorical so only the four retained subclasses remain as categories
df['subclass'] = df['subclass'].astype('str').astype('category')

In [None]:
sns.kdeplot(data=df, x='digital_layer', hue='subclass', common_norm=False)
sns.kdeplot(data=df, x='belayer', hue='subclass', common_norm=False, linestyle="--")

In [None]:
sns.scatterplot(data=df, x='digital_layer', y='belayer', size=5, marker='x')

In [None]:
if os.path.exists("tmp.csv"):
    os.remove("tmp.csv")

# CCI

In [None]:
np.random.seed(888)

In [None]:
%config InlineBackend.print_figure_kwargs={'dpi': 300.0}

In [None]:
# For viewability purposes, collapse all collagen ("COL...") elements of an interaction label into one "Collagens" entry:
def replace_col_with_collagens(string):
    """Collapse collagen elements of an interaction label into a single "Collagens" entry.

    The label is split at its first ':'. The part before it is treated as a '/'-separated list of
    elements. The first element that starts with "COL" or "b_COL", or that already is "Collagens" or
    "b_Collagens", is replaced by "Collagens" ("b_Collagens" if it carries the "b_" prefix). Every later
    such element is dropped. Everything from the first ':' onward is kept unchanged.

    Example: "COL1A1/COL1A2/FN1:ITGB1" -> "Collagens/FN1:ITGB1"

    Parameters
    ----------
    string : str
        Interaction label.

    Returns
    -------
    str
        The label with collagen elements collapsed.
    """
    # partition keeps the full remainder after the first ':' (split(':') + parts[1] dropped any text
    # after a second ':')
    head, sep, tail = string.partition(':')

    kept = []
    seen_collagen = False
    for element in head.split('/'):
        is_collagen = element.startswith(("COL", "b_COL")) or element in ("Collagens", "b_Collagens")
        if not is_collagen:
            kept.append(element)
        elif not seen_collagen:
            # Among collagen elements, only "b_COL..." / "b_Collagens" start with "b_"
            kept.append("b_Collagens" if element.startswith("b_") else "Collagens")
            seen_collagen = True
        # Any further collagen element is dropped

    return '/'.join(kept) + sep + tail

## Resources used for the CosMx sample can be found: https://www.dropbox.com/scl/fo/z3bvppoq96vg442lma0rs/ACIXLqp-FXjuYQ2ZeAeHFEA?rlkey=84h21aoigdxrpfz9yrbsyepwg&st=55ozincu&dl=0
## Resources used for the MERFISH sample can be found: https://www.dropbox.com/scl/fo/s7mjpdgbk4f2mj1rndooo/AAkL3b4W3JazjGDn3pSKKrk?rlkey=3acd5da9bcl743x0byrrm8jxp&st=9g9b9o47&dl=0 
## Database files used here can be found: https://www.dropbox.com/scl/fo/dcd95so9zhkb8lnjkkxep/ANwmkFeb-sgtS89leHQezlU?rlkey=saiul4j5rr1vt6lwjl4hirmwh&st=brpjqw2c&dl=0

### Make sure to change each file path to the relevant local folder

In [None]:
# Set the Spateo database directory here:
database_dir = "/mnt/d/SCData/CCI_database"

## Load MERFISH brain sample FOVs

In [None]:
# Two FOVs are included- use this to change which is selected (options are 153 or 162)
fov_number = 153

In [None]:
# Replace with wherever this file is stored locally
path_to_merfish = f"/mnt/d/SCData/Spateo_data/MERFISH_mouse_cortex/MERFISH_mouse_brain_mouse1_fov{fov_number}.h5ad"
# Replace with wherever the L:R database is stored locally
lr_db = pd.read_csv("/mnt/c/Users/danie/Desktop/Github/Github/spateo-release-main/spateo/tools/database/lr_db_mouse.csv", index_col=0)

In [None]:
merfish_fov = anndata.read_h5ad(path_to_merfish)
merfish_fov.uns["__type"] = "UMI"

### Figure S5l- spatially-resolved cell types plot

In [None]:
spatial_coords = merfish_fov.obsm['spatial']
x_coords = spatial_coords[:, 0]
y_coords = spatial_coords[:, 1]
cell_types = merfish_fov.obs['general_cell_type']

In [None]:
unique_cell_types = np.unique(cell_types)
color_map = {
    'Astro': '#d70000',
    'Endo': '#00fdcf',
    'L23_IT': '#eeb9b9',
    'L45_IT': '#00af8a',
    'L56_NP': '#d38c8f',
    'L5_ET': '#c59f72',
    'L5_IT': '#00d6d5',
    'L6_CT': '#a9001f',
    'L6_IT': '#bfd57c',
    'L45_IT_SSp': '#f46200',
    'L6b': '#d2b75b',
    'Lamp5': '#ad94ec',
    'Micro': '#213400',
    'OPC': '#fb7cff',
    'Oligo': '#91a2ea',
    'PVM': '#ad3b30',
    'Peri': '#734abc',
    'Pvalb': '#602541',
    'SMC': '#e2b392',
    'Sncg': '#bc94d2',
    'Sst': '#1726ff',
    'VLMC': '#8a1323',
    'Vip': '#2f3ea8',
    'striatum': '#ffa500'
}

In [None]:
# Scatter plot of cell positions in the FOV, one color per cell type (colors from color_map)
# Note: these rcParams changes are global and also affect every figure drawn later in the notebook
plt.rcParams['xtick.labelsize'] = 16
plt.rcParams['ytick.labelsize'] = 16

fig, axes = plt.subplots(1, 1, figsize=(6, 5))
fig.suptitle(f'MERFISH mouse brain- FOV {fov_number}', fontsize=20)
    
# One scatter call per cell type so each gets its own legend entry
for cell_type in unique_cell_types:
    idx = cell_types == cell_type
    axes.scatter(x_coords[idx], y_coords[idx], color=color_map[cell_type], label=cell_type, s=5)

axes.set_ylim(axes.get_ylim())  # Pins the current autoscaled y-limits (turns off further y autoscaling)
# Remove plot borders
for spine in axes.spines.values():
    spine.set_visible(False)

# Remove tick marks and labels
axes.tick_params(left=False, bottom=False, labelleft=False, labelbottom=False)
axes.legend(title='Cell Type', bbox_to_anchor=(1.15, 1), loc='upper left', title_fontsize='12', fontsize='10')

plt.tight_layout(rect=[0, 0, 0.95, 1])  # Leave space on the right for the legend placed outside the axes
plt.show()

In [None]:
# Set to the folders that the inputs (only the targets list for this data) are contained in and that the outputs (model results) will save to:
cci_input_directory = "/mnt/d/SCAnalysis/Spateo_MERFISH_benchmark/CCI_inputs"
cci_output_directory = "/mnt/d/SCAnalysis/Spateo_MERFISH_benchmark/CCI_outputs"
cci_output_id = os.path.join(cci_output_directory, f"fov_{fov_number}_target_genes.csv")
cci_targets_file = os.path.join(cci_input_directory, f"target_genes_slice{fov_number}.txt")

#### Initialize and run CCI model (can skip if predictions .csv file was created locally or downloaded from the folder)

In [None]:
if fov_number == 153:
    lb = 50.0
    ub = 139.6
else:
    lb = 48.9
    ub = 136.4

In [None]:
# For clarity, this is how the distance bounds are determined
lb = st.tl.find_neighbors.find_bw_for_n_neighbors(
    merfish_fov,
    coords_key="spatial",
    n_anchors=2000,
    target_n_neighbors=9,
    initial_bw=200,
    exclude_self=True
)

In [None]:
ub = st.tl.find_neighbors.find_bw_for_n_neighbors(
    merfish_fov,
    coords_key="spatial",
    n_anchors=2000,
    target_n_neighbors=70,
    initial_bw=200,
    exclude_self=True
)

In [None]:
# Inputs for the Spateo CCI ("niche") model on the selected MERFISH FOV.
# The model must read the MERFISH AnnData loaded above: `path_to_cosmx` is only defined later, in the
# CosMx section, so using it here raised a NameError (or silently pointed at the lung sample on re-runs).
adata_path = path_to_merfish
output_path = cci_output_id
target_path = cci_targets_file
cci_dir_path = database_dir
mod_type = "niche"
distr = "poisson"
species = "mouse"
group_key = "general_cell_type"
coords_key = "spatial"
distance_membrane_bound = lb
distance_secreted = ub
minbw = lb
maxbw = ub * 1.5

# output_path is a file path; make sure the folder it lives in exists
os.makedirs(os.path.dirname(output_path), exist_ok=True)

In [None]:
parser, args_list = st.tl.define_spateo_argparse(
    adata_path=adata_path,
    targets_path=target_path,
    cci_dir=cci_dir_path,
    mod_type=mod_type,
    distr=distr,
    species=species,
    group_key=group_key,
    coords_key=coords_key,
    distance_membrane_bound=distance_membrane_bound,
    distance_secreted=distance_secreted,
    minbw=minbw,
    maxbw=maxbw,
    output_path=output_path,
)

In [None]:
import time

t1 = time.time()

swr_model = st.tl.MuSIC(parser, args_list)
swr_model._set_up_model()
swr_model.fit()
swr_model.predict_and_save()

t_last = time.time()

print("Total Time Elapsed:", np.round(t_last - t1, 2), "seconds")
print("-" * 60)

In [None]:
# Note that the predictions.csv file is also provided in the Dropbox

#### Run NCEM model for the comparison (can also skip to the next section- the predictions file is included in the Dropbox folder)

In [None]:
# These are the contents of the "target_genes" txt files, spelled out in the form of a list.
# The two FOVs differ in a single target: Lypd1 for FOV 153, Ptpru otherwise (FOV 162).
# Keyed on `fov_number` (set at the top of the CCI section); `sample_id` is never defined in this notebook.
if fov_number == 153:
    target_genes = ["Flt1", "Aqp4", "Parm1", "Rorb", "Syt6", "Calb1", "Prdm8", "Rspo1", "Lypd1", "Adamts4", "Vtn", "Lamp5"]
else:
    target_genes = ["Flt1", "Aqp4", "Parm1", "Rorb", "Syt6", "Calb1", "Prdm8", "Rspo1", "Ptpru", "Adamts4", "Vtn", "Lamp5"]

In [None]:
# 6779 is prime, need to randomly drop 1 cell for NCEM- this is saved as a separate file in the MERFISH dropbox
# Replace the file path with wherever this file is stored locally
mouse_brain_ncem = anndata.read_h5ad(f"/mnt/d/SCData/Spateo_data/MERFISH_mouse_cortex/MERFISH_mouse_brain_mouse1_fov{fov_number}_NCEM_processed.h5ad")

In [None]:
# Requirement for initializing interpreter
mouse_brain_ncem.uns["spatial"] = "Hello, world"

In [None]:
# Use the upper distance bound as the distance parameter for NCEM
dist = ub

In [None]:
interpreter = ncem.interpretation.interpreter.InterpreterInteraction()

In [None]:
interpreter.data = customLoader(
    adata=mouse_brain_ncem, cluster='general_cell_type', patient='Batch', library_id='Batch', radius=dist,
)
get_data_custom(interpreter=interpreter)

In [None]:
interpreter.n_eval_nodes_per_graph = 2

In [None]:
interpreter.get_sender_receiver_effects()

In [None]:
# Backsolve to get the design matrix:
img_keys = interpreter.img_keys_all
nodes_idx = interpreter.nodes_idx_all

In [None]:
(target, interactions, _, _, _), y = interpreter._get_np_data(image_keys=img_keys, nodes_idx=nodes_idx)
x_design = np.concatenate([target, interactions], axis=1)
x_design

In [None]:
def ols_fit(x_, y_):
    """Ordinary least squares fit of each column of ``y_`` on the design matrix ``x_``.

    Solves the normal equations with a pseudo-inverse, beta = (X^T X)^+ X^T y. Rank-deficient designs
    therefore still give the minimum-norm least-squares solution.

    Parameters
    ----------
    x_ : np.ndarray, shape (n_obs, n_features)
    y_ : np.ndarray, shape (n_obs, n_targets)

    Returns
    -------
    np.ndarray, shape (n_targets, n_features, 1)
        One coefficient column vector per target column of ``y_``.
    """
    # The projection matrix does not depend on the target, so build it once
    projection = np.linalg.pinv(x_.T @ x_) @ x_.T

    coefficient_columns = []
    for col_idx in range(y_.shape[1]):
        # y_[:, [col_idx]] keeps the column 2D (n_obs, 1), so each result is (n_features, 1)
        coefficient_columns.append(projection @ y_[:, [col_idx]])
    return np.array(coefficient_columns)

ols = ols_fit(x_=x_design, y_=y)
params = ols.squeeze()

params.shape

In [None]:
reconst = np.matmul(x_design, params.T)
reconst

In [None]:
reconst_df = pd.DataFrame(reconst, index=mouse_brain_ncem.obs_names, columns=target_genes)

In [None]:
# Change to an appropriate location on the local system
save_path = f"/mnt/d/SCAnalysis/Spateo_MERFISH_benchmark/NCEM_predictions_fov{fov_number}.csv"
reconst_df.to_csv(save_path)

### Figure S5m- barplots comparing performance of Spateo vs. NCEM

In [None]:
# Change to the location on the local system where the NCEM predictions were saved
ncem_save_path = f"/mnt/d/SCAnalysis/Spateo_MERFISH_benchmark/NCEM_predictions_fov{fov_number}.csv"
ncem_reconst_df = pd.read_csv(ncem_save_path, index_col=0)

In [None]:
spateo_save_path = os.path.join(cci_output_directory, "predictions.csv")
spateo_reconst_df = pd.read_csv(spateo_save_path, index_col=0)

In [None]:
# In case the previous section was skipped over.
# The two FOVs differ in a single target: Lypd1 for FOV 153, Ptpru otherwise (FOV 162).
# Keyed on `fov_number`; `sample_id` is never defined in this notebook.
if fov_number == 153:
    target_genes = ["Flt1", "Aqp4", "Parm1", "Rorb", "Syt6", "Calb1", "Prdm8", "Rspo1", "Lypd1", "Adamts4", "Vtn", "Lamp5"]
else:
    target_genes = ["Flt1", "Aqp4", "Parm1", "Rorb", "Syt6", "Calb1", "Prdm8", "Rspo1", "Ptpru", "Adamts4", "Vtn", "Lamp5"]

# 6779 is prime, need to randomly drop 1 cell for NCEM- this is saved as a separate file in the MERFISH dropbox
# Replace the file path with wherever this file is stored locally
mouse_brain_ncem = anndata.read_h5ad(f"/mnt/d/SCData/Spateo_data/MERFISH_mouse_cortex/MERFISH_mouse_brain_mouse1_fov{fov_number}_NCEM_processed.h5ad")

#### Bootstrap resampling- can skip over this section as well if these files have already been generated or the result files were downloaded from the resource folder

In [None]:
# Note that for the figure, only the R-squared comparison is included, but this computes additional metrics that can also be compared w/ modifications to code below:
def compute_metrics(y_true, y_pred):
    """Compute agreement metrics between observed and predicted values.

    Parameters
    ----------
    y_true, y_pred : 1D array-like of equal length
        Observed and predicted values.

    Returns
    -------
    dict
        'Pearson r', 'Spearman r', 'R-squared', 'RMSE' and 'Jaccard index'. The Jaccard index compares
        the sets of non-zero entries and is defined as 1.0 when both vectors are entirely zero.
    """
    # Imported here: pearsonr/spearmanr were otherwise only imported much later in the notebook, and
    # r2_score / mean_squared_error were never imported, so calling this function raised NameError.
    from scipy.stats import pearsonr, spearmanr

    y_true = np.asarray(y_true, dtype=float).reshape(-1)
    y_pred = np.asarray(y_pred, dtype=float).reshape(-1)

    rp, _ = pearsonr(y_true, y_pred)
    r, _ = spearmanr(y_true, y_pred)

    residuals = y_true - y_pred
    ss_res = np.sum(residuals ** 2)
    ss_tot = np.sum((y_true - y_true.mean()) ** 2)
    # Same convention as sklearn.metrics.r2_score (force_finite=True): for a constant y_true, return 1.0 for
    # a perfect prediction and 0.0 otherwise, instead of dividing by zero.
    if ss_tot == 0:
        r2 = 1.0 if ss_res == 0 else 0.0
    else:
        r2 = 1.0 - ss_res / ss_tot
    rmse = np.sqrt(np.mean(residuals ** 2))

    # Jaccard index of the "expressed" (non-zero) patterns
    binary_y_true = (y_true != 0).astype(int)
    binary_y_pred = (y_pred != 0).astype(int)
    intersection = np.sum(binary_y_true * binary_y_pred)
    union = np.sum(np.maximum(binary_y_true, binary_y_pred))
    ji = intersection / union if union != 0 else 1.0

    metrics = {
        'Pearson r': rp,
        'Spearman r': r,
        'R-squared': r2,
        'RMSE': rmse,
        'Jaccard index': ji,
    }
    return metrics

In [None]:
np.random.seed(888)

In [None]:
# Define the number of bootstrap samples
n_bootstrap_samples = 1000
n_samples_to_pick = 200
confidence_level = 0.95

# Initialize empty DataFrames
total_rs_df = pd.DataFrame()
bootstrap_df = pd.DataFrame()  # To store all bootstrap results

In [None]:
# Bootstrap for the Spateo results.
# Iterates over `target_genes` (defined above); `all_genes` is not defined anywhere in this notebook, so the
# previous loop raised a NameError.
for gene in tqdm(target_genes, desc="Computing metrics for all genes..."):
    y = merfish_fov[:, gene].X.toarray().reshape(-1)
    spateo_results_target = spateo_reconst_df[gene].values.reshape(-1)

    # Metrics are evaluated only on cells where the gene is observed (non-zero)
    non_zero_indices = np.nonzero(y)[0]
    y_non_zero = y[non_zero_indices]
    spateo_results_target_non_zero = spateo_results_target[non_zero_indices]

    if np.isnan(y).any() or np.isnan(spateo_results_target).any():
        print(f"Array contains NaN values for gene {gene}")
        continue

    # Resampling: each draw picks n_samples_to_pick distinct non-zero cells (replace=False), so every gene
    # needs at least n_samples_to_pick non-zero cells
    bootstrap_metrics = []
    for _ in range(n_bootstrap_samples):
        indices = np.random.choice(len(y_non_zero), size=n_samples_to_pick, replace=False)
        y_resampled = y_non_zero[indices]
        spateo_results_resampled = spateo_results_target_non_zero[indices]
        metrics = compute_metrics(y_resampled, spateo_results_resampled)
        bootstrap_metrics.append(metrics)

    # Convert bootstrap metrics to DataFrame
    spateo_bootstrap_df = pd.DataFrame(bootstrap_metrics)
    spateo_bootstrap_df.columns = [f"Spateo {col} {gene}" for col in spateo_bootstrap_df.columns]
    bootstrap_df = pd.concat([bootstrap_df, spateo_bootstrap_df], axis=1)

    # Percentile confidence intervals from the resampled metrics
    ci_lower = spateo_bootstrap_df.quantile((1 - confidence_level) / 2)
    ci_upper = spateo_bootstrap_df.quantile(1 - (1 - confidence_level) / 2)

    # Metrics on all non-zero cells for the gene
    original_metrics = compute_metrics(y_non_zero, spateo_results_target_non_zero)

    # Append metric and confidence intervals to DataFrame
    new_rs = pd.DataFrame([{'Gene names': gene, 'Model type': "Spateo", 'R-squared': original_metrics['R-squared'],
                            'CI Lower': ci_lower[f'Spateo R-squared {gene}'], 'CI Upper': ci_upper[f'Spateo R-squared {gene}']}])
    total_rs_df = pd.concat([total_rs_df, new_rs])

In [None]:
# Bootstrap for the NCEM results.
# Iterates over `target_genes` (defined above); `all_genes` is not defined anywhere in this notebook, so the
# previous loop raised a NameError.
for gene in tqdm(target_genes, desc="Computing NCEM metrics for all genes..."):
    y = mouse_brain_ncem[:, gene].X.toarray().reshape(-1)
    NCEM_results_target = ncem_reconst_df[gene].values.reshape(-1)

    # Metrics are evaluated only on cells where the gene is observed (non-zero)
    non_zero_indices = np.nonzero(y)[0]
    y_non_zero = y[non_zero_indices]
    NCEM_results_target_non_zero = NCEM_results_target[non_zero_indices]

    if np.isnan(y).any() or np.isnan(NCEM_results_target).any():
        print(f"Array contains NaN values for gene {gene}")
        continue

    # Resampling: each draw picks n_samples_to_pick distinct non-zero cells (replace=False), so every gene
    # needs at least n_samples_to_pick non-zero cells
    bootstrap_metrics = []
    for _ in range(n_bootstrap_samples):
        indices = np.random.choice(len(y_non_zero), size=n_samples_to_pick, replace=False)
        y_resampled = y_non_zero[indices]
        NCEM_resampled = NCEM_results_target_non_zero[indices]
        metrics = compute_metrics(y_resampled, NCEM_resampled)
        bootstrap_metrics.append(metrics)

    # Convert bootstrap metrics to DataFrame
    ncem_bootstrap_df = pd.DataFrame(bootstrap_metrics)
    ncem_bootstrap_df.columns = [f"NCEM {col} {gene}" for col in ncem_bootstrap_df.columns]
    bootstrap_df = pd.concat([bootstrap_df, ncem_bootstrap_df], axis=1)

    # Percentile confidence intervals from the resampled metrics
    ci_lower = ncem_bootstrap_df.quantile((1 - confidence_level) / 2)
    ci_upper = ncem_bootstrap_df.quantile(1 - (1 - confidence_level) / 2)

    # Metrics on all non-zero cells for the gene
    original_metrics = compute_metrics(y_non_zero, NCEM_results_target_non_zero)

    # Append metrics and confidence intervals to DataFrames
    new_rs = pd.DataFrame([{'Gene names': gene, 'Model type': "NCEM", 'R-squared': original_metrics['R-squared'],
                            'CI Lower': ci_lower[f'NCEM R-squared {gene}'], 'CI Upper': ci_upper[f'NCEM R-squared {gene}']}])
    total_rs_df = pd.concat([total_rs_df, new_rs])

In [None]:
# Change the path to an appropriate local directory
save_folder = f"/mnt/d/SCAnalysis/Spateo_MERFISH_benchmark"
total_rs_df.index = np.arange(len(total_rs_df))
total_rs_df.to_csv(os.path.join(save_folder, f"fov{fov_number}_benchmark_R-sq_results.csv"))

#### Barplot

In [None]:
def fisher_z_test_with_ci(r1, ci_lower1, ci_upper1, r2, ci_lower2, ci_upper2):
    """Two-sided z-test for the difference between two correlation-like statistics.

    Each statistic is Fisher z-transformed. Its standard error is approximated from the width of its
    confidence interval, assuming the interval spans +/-1.96 standard errors (a 95% CI).

    Returns
    -------
    float
        Two-sided p-value. NaN if a statistic lies outside (-1, 1] or both intervals have zero width.
    """
    # `norm` is not imported anywhere else in this notebook (only `scipy.stats` as `stats`), which made
    # this function raise NameError.
    from scipy.stats import norm

    z1 = 0.5 * np.log((1 + r1) / (1 - r1))
    z2 = 0.5 * np.log((1 + r2) / (1 - r2))
    # NOTE(review): the variances use CI widths on the untransformed scale while z1/z2 are on the Fisher-z
    # scale. Kept as-is to reproduce the published values; confirm this is intended.
    var_z1 = ((ci_upper1 - ci_lower1) / (2 * 1.96)) ** 2
    var_z2 = ((ci_upper2 - ci_lower2) / (2 * 1.96)) ** 2
    z_diff = (z1 - z2) / np.sqrt(var_z1 + var_z2)
    # norm.sf(x) == 1 - norm.cdf(x), but stays accurate for large |z| instead of collapsing p to exactly 0
    p_value = 2 * norm.sf(abs(z_diff))
    return p_value

In [None]:
# Change the path to the directory these were saved to
save_folder = f"/mnt/d/SCAnalysis/Spateo_MERFISH_benchmark"
metric_results = pd.read_csv(os.path.join(save_folder, f"fov{fov_number}_benchmark_R-sq_results.csv"), index_col=0)

In [None]:
from scipy.stats import ttest_ind
from statsmodels.stats.multitest import multipletests

In [None]:
results = []

# For each gene, compare Spateo against every other model type
for gene in metric_results["Gene names"].unique():
    spateo_row = metric_results[(metric_results["Gene names"] == gene) & (metric_results["Model type"] == "Spateo")].iloc[0]
    spateo_r = spateo_row["R-squared"]
    spateo_ci_lower = spateo_row["CI Lower"]
    spateo_ci_upper = spateo_row["CI Upper"]

    for _, row in metric_results[(metric_results["Gene names"] == gene) & (metric_results["Model type"] != "Spateo")].iterrows():
        model = row["Model type"]
        other_r = row["R-squared"]
        other_ci_lower = row["CI Lower"]
        other_ci_upper = row["CI Upper"]
        p_value = fisher_z_test_with_ci(spateo_r, spateo_ci_lower, spateo_ci_upper, other_r, other_ci_lower, other_ci_upper)
        results.append([gene, model, p_value])

# Convert results to a dataframe
results_df = pd.DataFrame(results, columns=["Gene", "Model", "p-value"])

# Adjust p-values for multiple comparisons using the Benjamini-Hochberg procedure.
# The previous manual formula (p * m / row_number) ranked by row order rather than by p-value, skipped the
# monotonicity step and did not cap at 1, so it was not a valid BH adjustment. multipletests (imported above)
# implements it correctly. NaN p-values (e.g. R-squared <= -1 in the Fisher transform) are left out of the
# correction and keep a NaN q-value.
results_df["q-value"] = np.nan
valid = results_df["p-value"].notna()
if valid.any():
    results_df.loc[valid, "q-value"] = multipletests(results_df.loc[valid, "p-value"], method="fdr_bh")[1]
results_df

In [None]:
pastel_colors = sns.color_palette("pastel")
# Hex codes of the pastel palette, with the first two colors swapped for the barplot below.
# (The previous comprehension built one identical palette per color and then kept only the first.)
colors_hex = pastel_colors.as_hex()
colors_hex[0], colors_hex[1] = colors_hex[1], colors_hex[0]
colors_hex

In [None]:
plt.figure(figsize=(12, 5))

# Metric shown on the y-axis. metric_results only holds R-squared (plus its CI); `col` was previously used
# in this section without being defined, which raised a NameError.
col = "R-squared"

# Use seaborn's barplot function with hue parameter for condition
ax = sns.barplot(data=metric_results, x="Gene names", y=col, hue="Model type", palette=colors_hex, edgecolor='black', dodge=True, ci=None)

# Add error bars
bar_width = 0.8 / len(metric_results["Model type"].unique())  # Adjusting for the number of hue categories
for i, (gene, model) in enumerate(zip(metric_results["Gene names"], metric_results["Model type"])):
    y = metric_results[(metric_results["Gene names"] == gene) & (metric_results["Model type"] == model)][col].values[0]
    ci_lower = metric_results[(metric_results["Gene names"] == gene) & (metric_results["Model type"] == model)]["CI Lower"].values[0]
    ci_upper = metric_results[(metric_results["Gene names"] == gene) & (metric_results["Model type"] == model)]["CI Upper"].values[0]

    # Calculate x position for each bar (gene slot, then offset for the model's position within the hue group)
    x = np.where(metric_results["Gene names"].unique() == gene)[0][0]
    x = x - bar_width / 2 * (len(metric_results["Model type"].unique()) - 1) + bar_width * list(metric_results["Model type"].unique()).index(model)

    # Plot the error bar (asymmetric: from the point estimate down to CI Lower and up to CI Upper)
    ax.errorbar(x, y, yerr=[[y - ci_lower], [ci_upper - y]], fmt='none', c='black', elinewidth=3, capsize=4.0, capthick=2.5)

# Add significance annotations
# Initialize a dictionary to track the number of annotations for each gene
annotation_count = {gene: 0 for gene in metric_results["Gene names"].unique()}

for _, row in results_df.iterrows():
    gene = row["Gene"]
    model = row["Model"]
    p_value = row["p-value"]
    q_value = row["q-value"]

    # Determine the asterisk symbol based on q-value
    if q_value < 0.00005:
        symbol = "****"
    elif q_value < 0.0005:
        symbol = "***"
    elif q_value < 0.005:
        symbol = "**"
    elif q_value < 0.05:
        symbol = "*"
    else:
        continue  # Skip if not significant

    # Find the y-value for the annotation
    y = metric_results[(metric_results["Gene names"] == gene) & (metric_results["Model type"] == model)][col].values[0]

    # Calculate x position for each bar
    x = np.where(metric_results["Gene names"].unique() == gene)[0][0]
    x = x - bar_width / 2 * (len(metric_results["Model type"].unique()) - 1) + bar_width * list(metric_results["Model type"].unique()).index(model)

    # Find the y-value for the Spateo model
    y_spateo = metric_results[(metric_results["Gene names"] == gene) & (metric_results["Model type"] == "Spateo")][col].values[0]
    x_spateo = np.where(metric_results["Gene names"].unique() == gene)[0][0]
    x_spateo = x_spateo - bar_width / 2 * (len(metric_results["Model type"].unique()) - 1) + bar_width * list(metric_results["Model type"].unique()).index("Spateo")

    # Calculate the vertical position for the annotation (stacked when a gene has several annotations)
    annotation_offset = annotation_count[gene] * 0.15 + 0.1
    y_max = max(y, y_spateo) + annotation_offset

    # Plot the line between the Spateo bar and the other model bar
    ax.plot([x, x, x_spateo, x_spateo], [y_max, y_max + 0.02, y_max + 0.02, y_max], lw=1.5, c='black')

    # Add the asterisk annotation above the line
    ax.text((x + x_spateo) / 2, y_max + 0.02, symbol, ha='center', va='bottom', color='black', fontsize=18)

    # Update the annotation count for the gene
    annotation_count[gene] += 1

# For better readability, place the legend outside of the plot
plt.legend(bbox_to_anchor=(1.01, 1), loc='upper left', fontsize=28)

plt.ylabel(r'Variance explained ($R^2$)', fontsize=24)
plt.xlabel('Target gene', fontsize=36)
plt.xticks(fontsize=30, rotation=90)
plt.yticks(fontsize=28)
plt.ylim(0.1, 1.2)

plt.show()

## Load FOV 4 of the CosMx lung cancer sample

In [None]:
# Replace with wherever this file is stored locally
path_to_cosmx = "/mnt/d/SCData/Spateo_data/CosMx/fov_4.h5ad"
# Replace with wherever the L:R database is stored locally.
# The CosMx lung sample is human (the models below use species="human"), so load the human
# ligand:receptor table rather than the mouse one.
lr_db = pd.read_csv("/mnt/c/Users/danie/Desktop/Github/Github/spateo-release-main/spateo/tools/database/lr_db_human.csv", index_col=0)

In [None]:
lung_fov4 = anndata.read_h5ad(path_to_cosmx)
lung_fov4.uns["__type"] = "UMI"

### Figure S5o- spatially-resolved cell types plot

In [None]:
st.pl.geo(
    lung_fov4, 
    color=["predicted_celltypes"], 
    show_legend='upper left', 
    save_show_or_return='show', 
    figsize=(5, 3), 
    color_key=lung_fov4.uns["celltype_colors"]
)

### Benchmark w/ the COMMOT CCI array

In [None]:
lb = 120.0
ub = 336.1

#### Run Spateo CCI model (can skip if predictions .csv file was created locally or downloaded from the folder)

In [None]:
# Set to the folders that the inputs (ligands list, receptors list, targets list) are contained in and that the outputs (model results) will save to:
cci_input_directory = "/mnt/d/SCAnalysis/Spateo_CosMx_benchmark/CCI_inputs"
cci_output_directory = "/mnt/d/SCAnalysis/Spateo_CosMx_benchmark/CCI_outputs"
cci_output_id = os.path.join(cci_output_directory, "lung_fov4_target_genes.csv")
cci_ligands_file = os.path.join(cci_input_directory, "ligands.txt")
cci_receptors_file = os.path.join(cci_input_directory, "receptors.txt")
cci_targets_file = os.path.join(cci_input_directory, "targets.txt")

In [None]:
# For clarity, this is how the distance bounds are determined
lb = st.tl.find_neighbors.find_bw_for_n_neighbors(
    lung_fov4,
    coords_key="spatial",
    target_n_neighbors=9,
    initial_bw=100,
    exclude_self=True
)

In [None]:
ub = st.tl.find_neighbors.find_bw_for_n_neighbors(
    lung_fov4,
    coords_key="spatial",
    target_n_neighbors=70,
    initial_bw=100,
    exclude_self=True
)

In [None]:
# Define inputs:
adata_path = path_to_cosmx
# Results file inside cci_output_directory, mirroring the MERFISH section (output_path = "*_target_genes.csv").
# Previously this was the directory itself, so the makedirs call below only created its parent folder and
# cci_output_id was never used.
output_path = cci_output_id
# Use the ligand/receptor paths from the model fitting:
ligand_path = cci_ligands_file
receptor_path = cci_receptors_file
target_path = cci_targets_file
cci_dir_path = database_dir
mod_type = "lr"
species = "human"
distr = "poisson"

# Key storing cell type information
group_key = "predicted_celltypes"

# Key storing your spatial coordinates
coords_key = "spatial"
distance_membrane_bound = lb
distance_secreted = ub
# NOTE(review): the MERFISH section uses minbw = lb, maxbw = ub * 1.5; here the 1.5 factor is on minbw.
# Kept as-is — confirm this difference is intended.
minbw = lb * 1.5
maxbw = ub

# Make sure the folder that will hold the results file exists
os.makedirs(os.path.dirname(output_path), exist_ok=True)

In [None]:
parser, args_list = st.tl.define_spateo_argparse(
    adata_path=adata_path,
    custom_lig_path=ligand_path,
    custom_rec_path=receptor_path,
    targets_path=target_path,
    cci_dir=cci_dir_path,
    mod_type=mod_type,
    distr=distr,
    species=species,
    group_key=group_key,
    coords_key=coords_key,
    distance_membrane_bound=distance_membrane_bound,
    distance_secreted=distance_secreted,
    minbw=minbw,
    maxbw=maxbw,
    output_path=output_path,
)

In [None]:
import time

t1 = time.time()

swr_model = st.tl.MuSIC(parser, args_list)
swr_model._set_up_model()
swr_model.fit()
swr_model.predict_and_save()

t_last = time.time()

print("Total Time Elapsed:", np.round(t_last - t1, 2), "seconds")
print("-" * 60)

In [None]:
# Note that the predictions.csv file is also provided in the Dropbox

#### Run COMMOT for the comparison (can skip if the AnnData object was already created locally or downloaded from the folder)

In [None]:
np.random.seed(42)

In [None]:
# Processing for secreted signaling (will add fields to the AnnData object)
df_cellchat = ct.pp.ligand_receptor_database(species='human', signaling_type='Secreted Signaling', database='CellChat')
df_cellchat_filtered = ct.pp.filter_lr_database(df_cellchat, lung_fov4, min_cell_pct=0.05)

ct.tl.spatial_communication(
    lung_fov4,
    database_name='cellchat', 
    df_ligrec=df_cellchat_filtered, 
    dis_thr=ub, 
    heteromeric=True, 
    pathway_sum=True
)

In [None]:
# Processing for ECM signaling
df_cellchat = ct.pp.ligand_receptor_database(species='human', signaling_type='ECM-Receptor', database='CellChat')
df_cellchat_filtered = ct.pp.filter_lr_database(df_cellchat, lung_fov4, min_cell_pct=0.05)

# Both models operate w/ the assumption that ECM components diffuse about as far as other extracellular factors
ct.tl.spatial_communication(
    lung_fov4,
    database_name='cellchat', 
    df_ligrec=df_cellchat_filtered, 
    dis_thr=ub, 
    heteromeric=True, 
    pathway_sum=True
)

In [None]:
# Processing for membrane-bound signaling
df_cellchat = ct.pp.ligand_receptor_database(species='human', signaling_type='Cell-Cell Contact', database='CellChat')
df_cellchat_filtered = ct.pp.filter_lr_database(df_cellchat, lung_fov4, min_cell_pct=0.05)

ct.tl.spatial_communication(
    lung_fov4,
    database_name='cellchat', 
    df_ligrec=df_cellchat_filtered, 
    dis_thr=lb, 
    heteromeric=True, 
    pathway_sum=True
)

In [None]:
# Save AnnData object with COMMOT info- this will also be uploaded to the Dropbox
path_to_cosmx_commot = "/mnt/d/SCData/Spateo_data/CosMx/fov_4_COMMOT.h5ad"
lung_fov4.write_h5ad(path_to_cosmx_commot)

In [None]:
# Compute signal received for each cell predicted by COMMOT: one column per obsp matrix, holding its column
# sums (axis=0). Presumably entry [i, j] is the signal from sender i to receiver j, making column sums the
# received signal; confirm against COMMOT's documentation.
# Built directly from float arrays: filling an int-typed zero DataFrame with float columns relied on an
# implicit dtype upcast, which pandas >= 2.1 deprecates.
commot_signal_received = pd.DataFrame(
    {key: np.asarray(lung_fov4.obsp[key].sum(axis=0)).reshape(-1) for key in lung_fov4.obsp.keys()},
    index=lung_fov4.obs_names,
)

In [None]:
save_path = "/mnt/d/SCData/Spateo_data/CosMx/fov_4_COMMOT_signal_received.csv"
commot_signal_received.to_csv(save_path)

### Figure S5p- comparison of COMMOT signal to Spateo signal

In [None]:
path_to_commot_signal_received = "/mnt/d/SCData/Spateo_data/CosMx/fov_4_COMMOT_signal_received.csv"
commot_signal_received = pd.read_csv(path_to_commot_signal_received, index_col=0)
lung_fov4_spateo = anndata.read_h5ad(path_to_cosmx)

In [None]:
# Path to Spateo model design matrix
spateo_dm_path = "/mnt/d/SCData/Spateo_data/CosMx/design_matrix_full.csv"
spateo_dm = pd.read_csv(spateo_dm_path, index_col=0)

In [None]:
# Per interaction, agreement between the Spateo and COMMOT signal arrays is measured as the Jaccard index of
# their binarized (signal present / absent) per-cell vectors
def jaccard(x, y):
    """Return the Jaccard index |x AND y| / |x OR y| of two binary pandas Series."""
    from scipy.spatial.distance import cdist

    # cdist works on 2D inputs, so each boolean vector becomes a single row
    row_x = x.astype(bool).values.reshape(1, -1)
    row_y = y.astype(bool).values.reshape(1, -1)

    # The Jaccard index is one minus the Jaccard distance
    distance_matrix = cdist(row_x, row_y, metric='jaccard')
    return 1 - distance_matrix[0][0]

In [None]:
# Pair COMMOT signal columns with the corresponding Spateo design-matrix columns
def match_columns(commot_columns, design_columns):
    """Map COMMOT column names onto Spateo design-matrix column names.

    Only COMMOT names with exactly three '-' are considered. A name "<p1>-<p2>-<p3>-<p4>" is paired with
    the design column "<p3>:<p4>" when that column exists in design_columns.

    Returns
    -------
    dict
        {COMMOT column name: design column name}
    """
    pairs = {}
    for name in commot_columns:
        pieces = name.split("-")
        # Exactly three dashes <=> exactly four pieces
        if len(pieces) != 4:
            continue
        candidate = f"{pieces[2]}:{pieces[3]}"
        if candidate in design_columns:
            pairs[name] = candidate
    return pairs

def compute_jaccard(matched_columns, commot_df, design_df):
    """Jaccard index between binarized (> 0) COMMOT and Spateo columns for each matched pair.

    Parameters
    ----------
    matched_columns : dict
        {COMMOT column name: design column name}.
    commot_df, design_df : pd.DataFrame
        COMMOT signal-received table and Spateo design matrix.

    Returns
    -------
    list of tuple
        (COMMOT column, design column, Jaccard index), in the iteration order of matched_columns.
    """
    return [
        (commot_col, design_col, jaccard(commot_df[commot_col] > 0, design_df[design_col] > 0))
        for commot_col, design_col in matched_columns.items()
    ]

In [None]:
matched_columns = match_columns(
    commot_signal_received.columns,
    spateo_dm.columns
)

In [None]:
jaccard_indices = compute_jaccard(
    matched_columns,
    commot_signal_received,
    spateo_dm
)

In [None]:
jaccard_df = pd.DataFrame(
    jaccard_indices,
    columns=['Commot Column', 'Design Column', 'Jaccard Index']
)
jaccard_df

In [None]:
# Number of "active signals" for each cell:
commot_features = jaccard_df["Commot Column"]
spateo_features = jaccard_df["Design Column"]

commot_sub = commot_signal_received[commot_features]
spateo_dm_sub = spateo_dm[spateo_features]

In [None]:
# Number of "active signals" for each cell:
# Binarize (1 where a signal is present, 0 otherwise) with a vectorized comparison. This replaces
# DataFrame.applymap, which is deprecated in pandas >= 2.1 and called a Python lambda per element.
# (NaN != 0 is True, so NaN entries still count as 1, as before.)
commot_sub_nz = (commot_sub != 0).astype(int)
spateo_dm_sub_nz = (spateo_dm_sub != 0).astype(int)

active_signals_commot = commot_sub_nz.sum(axis=1)
active_signals_spateo = spateo_dm_sub_nz.sum(axis=1)

In [None]:
from scipy.stats import pearsonr, spearmanr

# Agreement between the per-cell counts of active signals from the two methods.
# Computed and plotted over all cells, not only those with nonzero counts.
rp_nonzero, _ = pearsonr(active_signals_commot, active_signals_spateo)
r_nonzero, _ = spearmanr(active_signals_commot, active_signals_spateo)
scatter_title = (
    "Active signals for COMMOT and Spateo\n"
    f"Spearman r = {r_nonzero:.3f}, Pearson r = {rp_nonzero:.3f}"
)

# One point per cell: COMMOT count on x, Spateo count on y
plt.scatter(
    active_signals_commot,
    active_signals_spateo,
    s=50,
    facecolors='darkorange',
    edgecolors='black',
    linewidths=0.75,
)
plt.title(scatter_title, fontsize=16)
plt.xlabel("Predicted signals per cell- COMMOT", fontsize=16)
plt.ylabel("Predicted signals per cell- Spateo", fontsize=16)
plt.xticks(fontsize=14)
plt.yticks(fontsize=14)

plt.show()

### Figure S5q- Spearman comparison

In [None]:
# Location of the text file listing the target genes (one per line)
cci_input_directory = "/mnt/d/SCAnalysis/Spateo_CosMx_benchmark/CCI_inputs"
cci_targets_file = os.path.join(cci_input_directory, "targets.txt")

In [None]:
# Read the target gene list (one gene per line)
with open(cci_targets_file, "r") as file:
    lines = file.readlines()

# Strip whitespace and drop blank lines. An empty gene name would make the
# later `lung_fov4_spateo[:, targets]` subset fail.
targets = [line.strip() for line in lines if line.strip()]
targets

#### Use COMMOT signals to predict gene expression (skip to the "Plot" section below if the Spearman correlations file has already been saved)

In [None]:
# Load the CosMx AnnData object
lung_fov4_spateo = anndata.read_h5ad(filename=path_to_cosmx)

In [None]:
# Per-cell COMMOT received-signal table; the first CSV column is used as the index
path_to_commot_signal_received = "/mnt/d/SCData/Spateo_data/CosMx/fov_4_COMMOT_signal_received.csv"
commot_signal_received = pd.read_csv(filepath_or_buffer=path_to_commot_signal_received, index_col=0)

In [None]:
# Subset to the target genes as an independent copy rather than a view
adata_targets = lung_fov4_spateo[:, targets].copy()

In [None]:
# Cells x target-genes expression table. AnnData.X may be sparse or dense
# depending on how the file was written. A dense ndarray has no .toarray(),
# so only densify when X is actually sparse.
from scipy.sparse import issparse

targets_X = adata_targets.X.toarray() if issparse(adata_targets.X) else np.asarray(adata_targets.X)
targets_df = pd.DataFrame(targets_X, columns=targets, index=adata_targets.obs_names)
targets_df

In [None]:
# Iteratively fit a Poisson GLM for each target gene, using the COMMOT
# received-signal table as the design matrix. sm.GLM adds no intercept, and
# design rows are matched to targets_df by position.
models = {}
pearson_correlations_COMMOT = {}
spearman_correlations_COMMOT = {}
pearson_correlations_nz_subset_COMMOT = {}
spearman_correlations_nz_subset_COMMOT = {}
predictions = pd.DataFrame(0, columns=targets, index=adata_targets.obs_names)
not_modeled = []

# The design matrix is identical for every gene, so bind it once
X = commot_signal_received

for col in targets_df.columns:
    print(f"Performing Poisson regression on {col}")
    y = targets_df[col].values
    nonzero_names = targets_df[col][targets_df[col] != 0].index.tolist()
    y_nz = targets_df.loc[nonzero_names, col].values

    # Guard only fitting/prediction. A gene then ends up in not_modeled exactly
    # when it has no stored model or correlations. Catching Exception (not a
    # bare except) keeps KeyboardInterrupt working.
    try:
        model = sm.GLM(y, X, family=sm.families.Poisson()).fit()
        y_pred = model.predict(X).values
    except Exception as err:
        print(f"Could not model {col}: {err!r}")
        not_modeled.append(col)
        continue

    models[col] = model
    predictions[col] = y_pred
    y_pred_nz = predictions.loc[nonzero_names, col].values

    rp, _ = stats.pearsonr(y, y_pred)
    rs, _ = stats.spearmanr(y, y_pred)
    pearson_correlations_COMMOT[col] = rp
    spearman_correlations_COMMOT[col] = rs

    print(f"Pearson correlation coefficient for {col}: {rp}")
    print(f"Spearman correlation coefficient for {col}: {rs}")

    # pearsonr needs at least two points; record NaN when fewer cells express the gene
    if len(y_nz) >= 2:
        rp, _ = stats.pearsonr(y_nz, y_pred_nz)
        rs, _ = stats.spearmanr(y_nz, y_pred_nz)
    else:
        rp = rs = np.nan
    pearson_correlations_nz_subset_COMMOT[col] = rp
    spearman_correlations_nz_subset_COMMOT[col] = rs

    print(f"Pearson correlation coefficient for nonzero {col}: {rp}")
    print(f"Spearman correlation coefficient for nonzero {col}: {rs}")

In [None]:
# Persist results. The figure panel shows only the Spearman correlations, so only those are saved.
save_dir = "/mnt/d/SCData/Spateo_data/CosMx"
spearman_table = pd.DataFrame.from_dict(spearman_correlations_COMMOT, orient='index', columns=['spearman'])
spearman_table.to_csv(os.path.join(save_dir, "spearman_correlations.csv"))

# Genes whose COMMOT model could not be fit, one per line
with open(os.path.join(save_dir, "COMMOT_not_modeled.txt"), "w") as file:
    file.writelines(f"{g}\n" for g in not_modeled)

In [None]:
# Also save the fitted models, one pickle per target gene.
# exist_ok avoids the check-then-create race of os.path.exists + os.makedirs.
model_dir = "/mnt/d/SCData/Spateo_data/CosMx/COMMOT_models"
os.makedirs(model_dir, exist_ok=True)

for model_name, model in models.items():
    with open(os.path.join(model_dir, f"{model_name}.pkl"), "wb") as file:
        pickle.dump(model, file)

#### Plot

In [None]:
# Reload the data so the plotting section can run without the modeling cells
lung_fov4_spateo = anndata.read_h5ad(filename=path_to_cosmx)
adata_targets = lung_fov4_spateo[:, targets].copy()

In [None]:
# Cells x target-genes expression table. AnnData.X may be sparse or dense
# depending on how the file was written. A dense ndarray has no .toarray(),
# so only densify when X is actually sparse.
from scipy.sparse import issparse

targets_X = adata_targets.X.toarray() if issparse(adata_targets.X) else np.asarray(adata_targets.X)
targets_df = pd.DataFrame(targets_X, columns=targets, index=adata_targets.obs_names)
targets_df

In [None]:
# Spateo-reconstructed expression (rows are cells, columns are target genes),
# loaded to compute the same metrics for Spateo
cci_output_directory = "/mnt/d/SCAnalysis/Spateo_CosMx_benchmark/CCI_outputs"
spateo_save_path = os.path.join(cci_output_directory, "predictions.csv")
spateo_reconst_df = pd.read_csv(filepath_or_buffer=spateo_save_path, index_col=0)

In [None]:
# Saved COMMOT Spearman correlations (one row per modeled gene)
save_dir = "/mnt/d/SCData/Spateo_data/CosMx"
commot_correlations_path = os.path.join(save_dir, "spearman_correlations.csv")
spearman_correlations_COMMOT = pd.read_csv(filepath_or_buffer=commot_correlations_path, index_col=0)

In [None]:
# Genes for which COMMOT model fitting failed (one gene name per line)
with open(os.path.join(save_dir, "COMMOT_not_modeled.txt")) as file:
    content = file.read()
not_modeled = content.splitlines()

not_modeled

In [None]:
pearson_correlations_spateo = {}
spearman_correlations_spateo = {}
pearson_correlations_spateo_nz = {}
spearman_correlations_spateo_nz = {}

for col in spateo_reconst_df.columns:
    # Skip genes the COMMOT model could not fit, so both methods cover the same genes
    if col in not_modeled:
        continue

    y = targets_df[col].values.reshape(-1)
    # Look predictions up by cell label, as the nonzero subset below already does,
    # so the full-data correlation does not depend on the row order of predictions.csv.
    y_pred = spateo_reconst_df.loc[targets_df.index, col].values.reshape(-1)

    nonzero_names = targets_df[col][targets_df[col] != 0].index.tolist()
    y_nz = targets_df.loc[nonzero_names, col].values
    y_pred_nz = spateo_reconst_df.loc[nonzero_names, col].values

    rp, _ = stats.pearsonr(y, y_pred)
    rs, _ = stats.spearmanr(y, y_pred)
    pearson_correlations_spateo[col] = rp
    spearman_correlations_spateo[col] = rs

    print(f"Pearson correlation coefficient for {col}: {rp}")
    print(f"Spearman correlation coefficient for {col}: {rs}")

    # pearsonr needs at least two points; record NaN when fewer cells express the gene
    if len(y_nz) >= 2:
        rp, _ = stats.pearsonr(y_nz, y_pred_nz)
        rs, _ = stats.spearmanr(y_nz, y_pred_nz)
    else:
        rp = rs = np.nan
    pearson_correlations_spateo_nz[col] = rp
    spearman_correlations_spateo_nz[col] = rs

    print(f"Pearson correlation coefficient for {col}, nonzero subset: {rp}")
    print(f"Spearman correlation coefficient for {col}, nonzero subset: {rs}")

In [None]:
# Comparative barplot data: one row per (gene, model) pair.
# Pair the Spateo values with the COMMOT genes by name, not by dict order:
# spearman_correlations_spateo follows the column order of predictions.csv,
# which need not match the row order of the saved COMMOT correlations.
# A gene missing from the Spateo results shows up as NaN.
commot_spearman = spearman_correlations_COMMOT.iloc[:, 0]
spateo_spearman = pd.Series(spearman_correlations_spateo, dtype=float).reindex(commot_spearman.index)
spearman_df = pd.DataFrame({'Labels': list(commot_spearman.index),
                            'COMMOT-derived': commot_spearman.values,
                            'Spateo': spateo_spearman.values}).melt('Labels', var_name='Model', value_name='Correlation')

In [None]:
# Pastel palette for the comparative barplot; show its first color
pastel_colors = sns.color_palette(palette="pastel")
pastel_colors[0]

In [None]:
import matplotlib.ticker as mticker

fig, ax = plt.subplots(1, 1, figsize=(30, 8))
# Draw explicitly on `ax` instead of relying on the implicit current Axes
sns.barplot(x='Labels', y='Correlation', hue='Model', data=spearman_df, palette=pastel_colors, edgecolor='black', ax=ax)
ax.set_title('Spearman correlations for modeled genes', fontsize=48)
ax.set_xlabel('Genes', fontsize=36)
ax.set_ylabel(r'Spearman ${r}$', fontsize=36)
ax.tick_params(axis='x', labelrotation=90, labelsize=32)  # Rotate x-axis labels
ax.set_ylim(0, 1)
# Show y ticks with two decimals through a formatter. Overwriting the labels with
# fixed strings triggers matplotlib's FixedFormatter warning, and those labels can
# drift from their tick positions if the ticks change.
ax.yaxis.set_major_formatter(mticker.StrMethodFormatter('{x:.2f}'))
ax.tick_params(axis='y', labelsize=32)

ax.legend(loc='center left', bbox_to_anchor=(1, 0.5), fontsize=30)

plt.tight_layout()
plt.show()

### Figure S5r, s- examples of predicted effects for specific "signaling families"

In [None]:
# Reload the CosMx AnnData object for the signaling-family plots
lung_fov4_spateo = anndata.read_h5ad(filename=path_to_cosmx)

In [None]:
# Directory holding the pickled COMMOT-derived models (one file per target gene)
model_dir = "/mnt/d/SCData/Spateo_data/CosMx/COMMOT_models"

In [None]:
# Gene to inspect; the cells below define interaction sets for "DDR1" and "KRT19"
target_gene = "KRT19"

#### Plots for the COMMOT models

In [None]:
# Load this gene's fitted COMMOT-derived model. pickle.load can execute arbitrary
# code, so only load pickles this notebook wrote itself.
model_path = os.path.join(model_dir, f"{target_gene}.pkl")
with open(model_path, "rb") as file:
    model = pickle.load(file)

In [None]:
# Coefficient table from the statsmodels summary: the first row is the header,
# and the first column (regressor names) becomes the index
sm_coeffs_data = model.summary().tables[1].data
header, *rows = sm_coeffs_data
sm_coeffs_df = pd.DataFrame(rows, columns=header).set_index(header[0]).astype(float)
sm_coeffs_df

In [None]:
# Pick the interactions to plot for the chosen target gene
if target_gene == "DDR1":
    interactions = ["COL6A3-ITGA3_ITGB1", "COL4A2-ITGA2_ITGB1", "COL4A2-ITGA3_ITGB1", "COL6A1-ITGA2_ITGB1", "COL6A1-ITGA3_ITGB1", "COL9A2-ITGA3_ITGB1", "COL9A2-ITGA2_ITGB1"]
elif target_gene == "KRT19":
    interactions = [col for col in sm_coeffs_df.index if "WNT5A" in col and "FZD" in col]
else:
    # Without an interaction set, `interactions` would be undefined (NameError),
    # or silently stale from an earlier run of this cell with another gene.
    raise ValueError(f"No interaction set defined for target gene {target_gene!r}; expected 'DDR1' or 'KRT19'.")

# Keep the coefficient rows whose regressor name contains any selected interaction
matching_columns = [col for col in sm_coeffs_df.index if any(interaction in col for interaction in interactions)]
filtered_df = sm_coeffs_df.loc[matching_columns]
filtered_df

In [None]:
# Strip the COMMOT prefix from the labels, order by coefficient (largest first),
# and copy the labels into an "Interaction" column for plotting
filtered_df = (
    filtered_df
    .rename(index=lambda label: label.replace('commot-cellchat-', ''))
    .sort_values('coef', ascending=False)
    .assign(Interaction=lambda frame: frame.index)
)
filtered_df

In [None]:
# Coerce coefficients to numeric (unparseable values become NaN), then sort descending
coef_numeric = pd.to_numeric(filtered_df['coef'], errors='coerce')
filtered_df['coef'] = coef_numeric
filtered_df = filtered_df.sort_values(by='coef', ascending=False)

In [None]:
import matplotlib.colors as mcolors

# Color each bar by its coefficient on the diverging 'seismic' colormap, normalized
# to the fixed range [-1, 1]; values outside it get the colormap's end colors.
colormap = plt.get_cmap('seismic')
norm = mcolors.Normalize(vmin=-1, vmax=1)
colors = [colormap(norm(coef)) for coef in filtered_df['coef']]

fig, ax = plt.subplots(figsize=(3.5, 2))
sns.barplot(data=filtered_df, x='Interaction', y='coef', palette=colors, edgecolor='black', ax=ax)
ax.set_title(f'COMMOT-derived model \n predicted effects on {target_gene}', fontsize=14)
ax.set_xlabel('L:R interactions', fontsize=18)
ax.set_ylabel('Effect size', fontsize=18)
plt.xticks(rotation=90, fontsize=14)
plt.yticks(fontsize=10)
plt.show()

#### Plots for the Spateo models

In [None]:
# Row positions of the cells in which the target gene is detected (> 0).
# X may be sparse or dense; a dense ndarray has no .toarray().
from scipy.sparse import issparse

target_expr = lung_fov4_spateo[:, target_gene].X
target_expr = target_expr.toarray() if issparse(target_expr) else np.asarray(target_expr)
target_cells = np.where(target_expr > 0)[0]
adata_target = lung_fov4_spateo[target_cells].copy()

In [None]:
cci_output_directory = "/mnt/d/SCAnalysis/Spateo_CosMx_benchmark/CCI_outputs"
# If model training wrote to a different directory or file name, change the path below.
cci_output_id = os.path.join(cci_output_directory, f"lung_fov4_target_genes_{target_gene}.csv")
target_coeffs = pd.read_csv(cci_output_id, index_col=0)
# Keep only the columns whose names contain "b_" (the interaction terms used below)
coeff_cols = [c for c in target_coeffs.columns if "b_" in c]
target_coeffs = target_coeffs[coeff_cols]
target_coeffs

In [None]:
# Keep only the cells expressing the target gene. NOTE(review): this selects rows
# by position, so it assumes the CSV rows follow the cell order of lung_fov4_spateo.
target_coeffs = target_coeffs.take(target_cells)
target_coeffs

In [None]:
# Summarize the per-cell Spateo coefficients for one "signaling family" of the
# chosen target gene: for each interaction, average its coefficient over the cells
# where it is positive, then plot those means as bars.
if target_gene == "DDR1":
    target_coeffs_sub = target_coeffs.loc[:, [col for col in target_coeffs.columns if ("CD44" in col or "_IT" in col) and "ICAM" not in col and "IGF" not in col and "MIF" not in col and "CD40" not in col]]
    # replace_col_with_collagens is defined elsewhere in this notebook
    target_coeffs_sub.columns = [replace_col_with_collagens(c) for c in target_coeffs_sub.columns]
    means = target_coeffs_sub.apply(lambda x: x[x > 0].mean())
    means = pd.DataFrame(means, columns=["coeff"])
    means["Interaction"] = [idx.replace("b_", "") for idx in means.index]
    means = means.sort_values('coeff', ascending=False)
    # NOTE(review): iloc[1:20] skips the top-ranked interaction; confirm this is intended
    means = means.iloc[1:20]
    print(means)
    # Final set to plot:
    means = means.loc[["b_Collagens/FN1/VTN:ITGAV_ITGB8", "b_SPP1/VTN:ITGAV_ITGB5", "b_Collagens/FN1/THBS1/THBS2:ITGA3_ITGB1", "b_Collagens/SPP1:ITGA9_ITGB1", "b_CDH1/Collagens:ITGA2_ITGB1", "b_ANGPTL1/Collagens:ITGA1_ITGB1", "b_FN1/SPP1/VTN:ITGAV_ITGB1"]]
    signaling_family = "ECM"
elif target_gene == "KRT19":
    target_coeffs_sub = target_coeffs.loc[:, [col for col in target_coeffs.columns if "WNT5A" in col and "FZD" in col]]
    means = target_coeffs_sub.apply(lambda x: x[x > 0].mean())
    means = pd.DataFrame(means, columns=["coeff"])
    means["Interaction"] = [idx.replace("b_", "") for idx in means.index]
    means = means.sort_values('coeff', ascending=False)
    signaling_family = "WNT"
else:
    # Fail loudly instead of silently producing no plot
    raise ValueError(f"No signaling family defined for target gene {target_gene!r}; expected 'DDR1' or 'KRT19'.")

# Shared plotting for both families (previously duplicated per branch)
import matplotlib.colors as mcolors
# Use the seismic colormap, symmetric around zero up to the largest mean coefficient
colormap = plt.get_cmap('seismic')

# Determine the colors based on the coefficients
norm = mcolors.Normalize(vmin=-means['coeff'].max(), vmax=means['coeff'].max())
colors = [colormap(norm(value)) for value in means['coeff']]

fig, ax = plt.subplots(figsize=(3.5, 2))
sns.barplot(x='Interaction', y='coeff', data=means, ax=ax, palette=colors, edgecolor='black')
ax.set_title(f'Spateo model predicted \n effects on {target_gene}- {signaling_family}', fontsize=18)
ax.set_xlabel('L:R interactions', fontsize=14)
ax.set_ylabel('Normalized \n mean effect size', fontsize=14)
plt.xticks(rotation=90, fontsize=14)
plt.yticks(fontsize=10)
plt.show()