In [1]:
import time
from pathlib import Path
from operator import itemgetter

import scanpy as sc
import numpy as np
import pandas as pd
from sklearn.metrics import adjusted_rand_score, normalized_mutual_info_score
from scipy.sparse import csr_matrix
from joblib import Parallel, delayed
import os
import torch

import scSLAT
from scSLAT.model import Cal_Spatial_Net, load_anndatas, run_SLAT_mlp_AGF, spatial_match, run_SLAT, compute_lisi_for_adata, run_SLAT_AGF_contrast
from scSLAT.viz import match_3D_multi, hist, Sankey, build_3D
from scSLAT.metrics import region_statistics

In [2]:
# Load the seqFISH slice.
# A raw string keeps the Windows backslashes literal. In a normal string,
# '\p', '\d', '\R', '\s' are invalid escape sequences: they happen to be kept
# as-is today but emit SyntaxWarning on Python 3.12+.
# TODO: replace the absolute local path with a configurable data directory.
adata_1 = sc.read_h5ad(r'D:\ppppaper\data\Result\seqFISH.h5ad')
# Make duplicate var (gene) names unique, using "++" as the join separator.
adata_1.var_names_make_unique(join="++")
adata_1

AnnData object with n_obs × n_vars = 11529 × 351
    obs: 'z', 'uniqueID', 'x_global', 'y_global', 'embryo', 'Estage', 'x_global_affine', 'y_global_affine', 'UMAP1', 'UMAP2', 'cluster', 'celltype_mapped_refined', 'celltype_mapped', 'annotation'
    var: 'gene_names'
    uns: 'Spatial_Net', 'annotation_colors'
    obsm: 'SLAT', 'banksy', 'spatial', 'stLVG'

In [3]:
# Load the Stereo-seq slice.
# Raw string for the same reason as the seqFISH path: '\p', '\d', '\R', '\S'
# are invalid escapes in a normal string literal (SyntaxWarning on 3.12+).
# TODO: replace the absolute local path with a configurable data directory.
adata_2 = sc.read_h5ad(r'D:\ppppaper\data\Result\Stereo_seq.h5ad')
# Make duplicate var (gene) names unique, using "++" as the join separator.
adata_2.var_names_make_unique(join="++")
adata_2

AnnData object with n_obs × n_vars = 5031 × 25568
    obs: 'n_genes_by_counts', 'log1p_n_genes_by_counts', 'total_counts', 'log1p_total_counts', 'annotation', 'Regulon - 2310011J03Rik', 'Regulon - 5730507C01Rik', 'Regulon - Alx1', 'Regulon - Alx3', 'Regulon - Alx4', 'Regulon - Ar', 'Regulon - Arid3a', 'Regulon - Arid3c', 'Regulon - Arnt2', 'Regulon - Arx', 'Regulon - Ascl1', 'Regulon - Atf1', 'Regulon - Atf4', 'Regulon - Atf5', 'Regulon - Atf6', 'Regulon - Atf7', 'Regulon - Bach1', 'Regulon - Bach2', 'Regulon - Barhl1', 'Regulon - Barx1', 'Regulon - Batf', 'Regulon - Bcl11a', 'Regulon - Bcl3', 'Regulon - Bcl6', 'Regulon - Bcl6b', 'Regulon - Bclaf1', 'Regulon - Bdp1', 'Regulon - Bhlha15', 'Regulon - Bhlhe22', 'Regulon - Bhlhe23', 'Regulon - Bhlhe41', 'Regulon - Bmyc', 'Regulon - Boll', 'Regulon - Bptf', 'Regulon - Brca1', 'Regulon - Brf1', 'Regulon - Brf2', 'Regulon - Bsx', 'Regulon - Cdx1', 'Regulon - Cdx2', 'Regulon - Cebpa', 'Regulon - Cebpz', 'Regulon - Chd1', 'Regulon - Clock', 'Re

In [17]:
# Pull the precomputed 'stLVG' embedding from each slice:
# embd0 <- seqFISH (adata_1), embd1 <- Stereo-seq (adata_2).
embd0, embd1 = (slice_adata.obsm['stLVG'] for slice_adata in (adata_1, adata_2))

In [18]:
# Match cells across the two slices in the shared stLVG embedding space.
# The AnnData objects are forwarded to spatial_match as well.
best, index, distance = spatial_match([embd0, embd1], adatas=[adata_1, adata_2])

# The cells below use `best` as positional row indices into adata1_df, one per
# query (embd1 / adata_2) cell. Fail early with a clear message if the result
# does not fit that layout.
# NOTE(review): scSLAT's spatial_match may swap its inputs depending on cell
# counts; confirm against the scSLAT docs if the inputs change.
assert len(best) == embd1.shape[0], "expected one match per adata_2 (embd1) cell"


def _spot_table(adata, embd) -> pd.DataFrame:
    """Build a per-cell table for one slice.

    Columns: 'index' (0..n-1 positional id), 'x'/'y' (first two columns of
    adata.obsm['spatial']) and 'celltype' (adata.obs['annotation']).
    The row count is taken from `embd` and must match the number of cells in `adata`.
    """
    coords = adata.obsm['spatial']
    return pd.DataFrame({'index': range(embd.shape[0]),
                         'x': coords[:, 0],
                         'y': coords[:, 1],
                         'celltype': adata.obs['annotation']})


adata1_df = _spot_table(adata_1, embd0)
adata2_df = _spot_table(adata_2, embd1)

# Row 0: 0..index.shape[0]-1 (one entry per row of `index`).
# Row 1: `best`, the matched adata_1 positions used in the next cell.
matching = np.array([range(index.shape[0]), best])
# First column of `distance`, presumably the score of the top-ranked match.
# It is not used in the cells shown.
best_match = distance[:, 0]

In [19]:
# Label each Stereo-seq cell with the annotation of its matched seqFISH cell.
matched_rows = matching[1, :]
adata2_df['target_celltype'] = adata1_df['celltype'].iloc[matched_rows].to_list()

# Contingency table of counts.
# Rows: Stereo-seq cell type. Columns: matched seqFISH cell type.
matching_table = (
    adata2_df
    .groupby(['celltype', 'target_celltype'])
    .size()
    .unstack(fill_value=0)
)

In [20]:
# import pandas as pd
# import plotly.graph_objects as go
# from typing import Optional, List

# def Sankey(matching_table: pd.DataFrame,
#            filter_num: Optional[int] = 50,
#            color: Optional[List[str]] = 'red',
#            title: Optional[str] = '',
#            prefix: Optional[List[str]] = ['E11.5', 'E12.5'],
#            layout: Optional[List[int]] = [1300, 900],
#            font_size: Optional[float] = 15,
#            font_color: Optional[str] = 'Black',
#            save_name: Optional[str] = None,
#            format: Optional[str] = 'png',
#            width: Optional[int] = 1200,
#            height: Optional[int] = 1000,
#            return_fig: Optional[bool] = False
#            ) -> None:
#     source, target, value = [], [], []
#     label_ref = [a + f'_{prefix[0]}' for a in matching_table.columns.to_list()]
#     label_query = [a + f'_{prefix[1]}' for a in matching_table.index.to_list()]
#     label_all = label_query + label_ref
#     label2index = dict(zip(label_all, list(range(len(label_all)))))

#     for i, query in enumerate(label_query):
#         for j, ref in enumerate(label_ref):
#             if int(matching_table.iloc[i, j]) > filter_num:
#                 target.append(label2index[query])
#                 source.append(label2index[ref])
#                 value.append(int(matching_table.iloc[i, j]))

#     fig = go.Figure(
#         data=[go.Sankey(
#             node=dict(
#                 pad=50,
#                 thickness=50,
#                 line=dict(color="green", width=0.5),
#                 label=[''] * len(label_all),  # set every node label to an empty string (hide labels)
#                 color=color
#             ),
#             link=dict(
#                 source=source,  # indices correspond to labels, eg A1, A2, A1, B1, ...
#                 target=target,
#                 value=value
#             )
#         )],
#         layout=go.Layout(autosize=False, width=layout[0], height=layout[1])
#     )

#     fig.update_layout(title_text=title, font_size=font_size, font_color=font_color)
#     fig.show()
#     if save_name is not None:
#         fig.write_image(save_name + f'.{format}', width=width, height=height)
#     if return_fig:
#         return fig

In [21]:
# Total number of Sankey nodes: matched (seqFISH) cell types + query (Stereo-seq) cell types.
num_nodes = len(matching_table.columns) + len(matching_table.index)

# Sankey diagram of the cell-type correspondences.
# - Only flows carrying more than `filter_num` matched cells are drawn.
# - return_fig=True is needed for `fig` to hold the figure; otherwise Sankey
#   returns None and `fig.write_image(...)` cannot work.
# - `format` is only used when `save_name` is given, so it has no effect here.
# This follows the commented-out copy of Sankey in this notebook.
# NOTE(review): confirm it matches the installed scSLAT version.
fig = Sankey(
    matching_table,
    filter_num=60,
    color='red',
    prefix=['seqFISH', 'Stereo-seq'],
    layout=[1600, 900],
    format='pdf',
    return_fig=True,
)

In [22]:
# fig.write_image('./sankey_stLVG.pdf', engine="kaleido")