Load the DKD (diabetic kidney disease) mouse Slide-seq V2 slice dataset from pysodb (installation guide: https://pysodb.readthedocs.io/en/latest/installation/installation.html)

In [None]:
import pysodb 
import scanpy as sc
import anndata as ad 

import pysodb

# SODB dataset name; used both to list the experiments and to load each one.
DATASET_NAME = 'Marshall2022High_mouse'

sodb = pysodb.SODB()

experiment_list = sodb.list_experiment_by_dataset(DATASET_NAME)

# One AnnData per experiment; each experiment is treated as one slice below.
adata_list = []
for experiment in experiment_list:
    slice_adata = sodb.load_experiment(DATASET_NAME, experiment)
    # Expose the first two spatial coordinates as plain obs columns;
    # the SpaGCN cell further down reads obs['x'] / obs['y'].
    slice_adata.obs['x'] = slice_adata.obsm['spatial'][..., 0]
    slice_adata.obs['y'] = slice_adata.obsm['spatial'][..., 1]
    adata_list.append(slice_adata)

# `label='slice_id'` records which input each observation came from.
# NOTE(review): no `keys` are passed, so the ids should be the positions in
# `adata_list` rather than the experiment names -- confirm against anndata docs.
adata = ad.concat(adata_list, label='slice_id')

# Keep only DKD and healthy observations, then encode the condition as 1 (DKD) / 0 (normal).
adata = adata[adata.obs['disease'].isin(['diabetic kidney disease', 'normal'])].copy()
adata.obs['condition'] = (adata.obs['disease'] == 'diabetic kidney disease').astype(int)

# obs column names reused by the following cells.
ct_obs = 'cell_type'
batch_obs = 'slice_id'

Apply Taichi by DKD condition

In [None]:
import time
from Taichi.model import Taichi


import scanpy as sc 
import time
import numpy as np
import anndata as ad

# Time the whole Taichi pipeline (wall clock, in seconds).
start_time = time.time()

# Build the model on the concatenated slices, then run its stages in order.
model = Taichi(adata, ct_obs=ct_obs, slice_id=batch_obs)
model.mender_init(nn_mode='radius', radius=50, scale=4)
model.run_mender(n_process=200)
model.label_refinement()

# The AnnData returned here replaces `adata`; later cells read
# obs['new_labels'] from it.
adata = model.graph_diffusion()

end_time = time.time()
print(f'Total Running Time {end_time - start_time}')

In [None]:
# Split the Taichi output into independent copies by disease status.
disease_status = adata.obs['disease']
wt = adata[disease_status == 'normal'].copy()                     # healthy observations
obs = adata[disease_status == 'diabetic kidney disease'].copy()   # DKD observations

DKD-relevant niches visualization

In [None]:
import squidpy as sq 

import scanpy as sc 

import matplotlib.pyplot as plt

from matplotlib.colors import ListedColormap

#TNBC=16 4 5 (4 15)
#CRC= 1 4 5  (4 50)
#Diabetes=3 4 5 (4 50)
from sklearn.neighbors import LocalOutlierFactor

# Global matplotlib styling shared by every figure produced below.
plt.rcParams.update({
    'font.family': 'Arial',
    'axes.edgecolor': 'black',
    'axes.linewidth': 1.5,
    'axes.labelweight': 'bold',
    'axes.titleweight': 'bold',
    'axes.titlesize': 20,  # Adjust title size
    'axes.labelsize': 15,
})

# Settings applied to every slice: spatial outlier filter and marker size.
LOF_N_NEIGHBORS = 2
LOF_CONTAMINATION = 0.2
SPOT_SIZE = 10

res = adata

batch_obs = 'slice_id'
ct_obs = 'cell_type'

# Colours for cell types (assigned in category order) and for the Taichi labels.
palette = ['#CA9C91', '#8D5FA3', '#7E594D', '#CE8DAC', '#C24A7A', '#D2AD50', '#83B756',
           '#95D1D7', '#748EBB', '#CC625F', '#FFD377', '#FD9BA0', '#BE9E33', '#C0E56F']
res_palette = ['#DAE4EE', '#E59EA1']

# Work on a copy so the Taichi output in `adata` / `res` is left untouched.
train_adata = res.copy()

# cell type -> colour, shared by all slices. The modulo cycles the palette
# instead of raising IndexError when there are more cell types than colours.
d = {c: palette[i % len(palette)] for i, c in enumerate(train_adata.obs[ct_obs].cat.categories)}

train_adata.obs['new_labels'] = train_adata.obs['new_labels'].astype('category')

# Unique slice ids, in order of first appearance, for control (0) and condition (1).
_, control_index = train_adata[train_adata.obs['condition'] == 0].obs[batch_obs].factorize()
_, condition_index = train_adata[train_adata.obs['condition'] == 1].obs[batch_obs].factorize()

# The slice picks below are positional; fail with a readable message when
# `adata` does not hold the full multi-slice Taichi output.
if len(control_index) < 1 or len(condition_index) < 6:
    raise ValueError(
        f"Expected at least 1 control and 6 condition slices, found "
        f"{len(control_index)} and {len(condition_index)}; "
        "re-run the Taichi cell so that `adata` holds the full dataset."
    )

adata_1 = train_adata[train_adata.obs[batch_obs].isin([control_index[0]])].copy()
adata_2 = train_adata[train_adata.obs[batch_obs].isin([condition_index[3]])].copy()
adata_3 = train_adata[train_adata.obs[batch_obs].isin([condition_index[4]])].copy()
adata_4 = train_adata[train_adata.obs[batch_obs].isin([condition_index[5]])].copy()

# Only the three condition slices are plotted; adata_1 (control) is kept for reference.
adata_list = [adata_2, adata_3, adata_4]


def _plot_panel_with_legend(panel_adata, color_key, colors, title):
    """Draw one spatial scatter panel, then its legend as a separate figure.

    Parameters
    ----------
    panel_adata : AnnData
        Observations to plot; passed to ``sq.pl.spatial_scatter``.
    color_key : str
        Name of the obs column used to colour the spots.
    colors : list of str
        Colours wrapped in a ``ListedColormap`` and passed as ``palette``.
    title : str
        Title of the scatter panel.
    """
    ax = sq.pl.spatial_scatter(panel_adata, shape=None, color=[color_key], return_ax=True,
                               ncols=1, size=SPOT_SIZE, wspace=0.1, hspace=0.1,
                               palette=ListedColormap(colors))

    # Collect one handle per label before dropping the in-panel legend.
    by_label = {}
    handles, labels = ax.get_legend_handles_labels()
    for handle, label in zip(handles, labels):
        by_label.setdefault(label, handle)

    in_panel_legend = ax.get_legend()
    if in_panel_legend is not None:  # guard: nothing to remove if no legend was drawn
        in_panel_legend.remove()

    ax.set_title(title)
    ax.set_xlabel('')
    ax.set_ylabel('')
    ax.set_xticklabels([])
    ax.set_yticklabels([])

    plt.tight_layout()
    plt.show()

    # Create a new figure for the comprehensive legend
    fig_legend, ax_legend = plt.subplots(figsize=(2, 3))  # Adjust size as needed
    if by_label:
        fig_legend.legend(list(by_label.values()), list(by_label.keys()), loc='center')
    ax_legend.axis('off')
    plt.show()


# The loop variable is deliberately NOT called `adata`: reusing that name would
# replace the full dataset with a single slice and break a re-run of this cell.
for slice_adata in adata_list:
    # Flag spatially isolated spots and keep only the inliers
    # (fit_predict returns 1 for inliers).
    clf = LocalOutlierFactor(n_neighbors=LOF_N_NEIGHBORS, contamination=LOF_CONTAMINATION)
    is_inlier = clf.fit_predict(slice_adata.obsm['spatial']) == 1
    inliers = slice_adata[is_inlier]

    _plot_panel_with_legend(inliers, 'new_labels', res_palette, 'Taichi Outcome')

    # Colour only the cell types listed in the subset's categories, using the
    # global mapping so a given type has the same colour in every slice.
    present_cell_types = inliers.obs[ct_obs].cat.categories
    _plot_panel_with_legend(inliers, ct_obs, [d[c] for c in present_cell_types], 'Cell Type')

Identification of spatially variable genes specific to the DKD-relevant niches with SpaGCN (https://github.com/jianhuupenn/SpaGCN)

In [None]:
import numpy as np
import scanpy as sc 
import SpaGCN as spg
from scipy.sparse import issparse


import scanpy as sc 
import SpaGCN as spg
from scipy.sparse import issparse


# Work on a copy of one DKD slice so normalisation does not touch adata_4.
raw_1 = adata_4.copy()

sc.pp.normalize_total(raw_1, target_sum=1e5)
sc.pp.log1p(raw_1)

# Densify X only when it is actually sparse; a dense array has no .toarray().
if issparse(raw_1.X):
    raw_1.X = raw_1.X.toarray()

# Label (in obs['new_labels']) of the domain whose spatially variable genes are wanted.
target = 1

# Filtering criteria applied to the differential-expression table below.
min_in_group_fraction = 0.8
min_in_out_group_ratio = 1
min_fold_change = 1.5

# Per-spot inputs shared by the SpaGCN calls.
spot_ids = raw_1.obs.index.tolist()
spot_x = raw_1.obs["x"].tolist()
spot_y = raw_1.obs["y"].tolist()
domain_labels = raw_1.obs["new_labels"].tolist()

# Radius search bounds are taken from quantiles of the non-zero pairwise entries
# of the adjacency matrix. The search is asked for num_min=10 .. num_max=14
# (presumably the average neighbour count per target spot -- see SpaGCN docs).
adj_2d = spg.calculate_adj_matrix(x=spot_x, y=spot_y, histology=False)
start, end = np.quantile(adj_2d[adj_2d != 0], q=0.000001), np.quantile(adj_2d[adj_2d != 0], q=0.1)
r = spg.search_radius(target_cluster=target, cell_id=spot_ids, x=spot_x, y=spot_y,
                      pred=domain_labels, start=start, end=end,
                      num_min=10, num_max=14, max_run=100)
# NOTE(review): search_radius appears to return None when no suitable radius
# is found; stop here rather than failing inside the next call.
if r is None:
    raise RuntimeError("spg.search_radius did not return a radius; adjust start/end or num_min/num_max.")

# Detect neighboring domains
nbr_domians = spg.find_neighbor_clusters(target_cluster=target,
                                         cell_id=spot_ids,
                                         x=spot_x,
                                         y=spot_y,
                                         pred=domain_labels,
                                         radius=r,
                                         ratio=1/2)

# Keep at most the first three neighbouring domains.
nbr_domians = nbr_domians[0:3]
de_genes_info = spg.rank_genes_groups(input_adata=raw_1,
                                      target_cluster=target,
                                      nbr_list=nbr_domians,
                                      label_col="new_labels",
                                      adj_nbr=True,
                                      log=True)

# Filter genes: significance first, then the expression-based criteria.
de_genes_info = de_genes_info[de_genes_info["pvals_adj"] < 0.05]
filtered_info = de_genes_info[(de_genes_info["in_out_group_ratio"] > min_in_out_group_ratio) &
                              (de_genes_info["in_group_fraction"] > min_in_group_fraction) &
                              (de_genes_info["fold_change"] > min_fold_change)]

# .copy() so the column assignments below act on an independent frame.
filtered_info_1 = filtered_info.sort_values(by="in_group_fraction", ascending=False).copy()
filtered_info_1["target_dmain"] = target  # column name (sic) kept unchanged for downstream compatibility
filtered_info_1["neighbors"] = str(nbr_domians)
print("SVGs for domain ", str(target),":", filtered_info_1["genes"].tolist())


In [None]:
import matplotlib.colors as clr

# Domain label (in obs['new_labels']) for which the meta gene is built.
target = 1

# Gene id handed to find_meta_gene as `start_gene`.
start_gene = 'ENSMUSG00000028307'
domain_labels = raw_1.obs["new_labels"].tolist()

meta_name, meta_exp = spg.find_meta_gene(
    input_adata=raw_1,
    pred=domain_labels,
    target_domain=target,
    start_gene=start_gene,
    mean_diff=0,
    early_stop=True,
    max_iter=10,
    use_raw=False,
)

# Store the returned meta-gene values per observation and plot them spatially.
raw_1.obs['meta_1'] = meta_exp

g = sq.pl.spatial_scatter(raw_1, color=['meta_1'], shape=None, return_ax=True)