In [1]:
import os
import time
from operator import itemgetter
from pathlib import Path
from typing import List, Optional

import scanpy as sc
import numpy as np
import pandas as pd
from sklearn.metrics import adjusted_rand_score, normalized_mutual_info_score
from scipy.sparse import csr_matrix
from joblib import Parallel, delayed
import torch

import scSLAT
from scSLAT.model import Cal_Spatial_Net, load_anndatas, run_SLAT_mlp_AGF, spatial_match, run_SLAT, compute_lisi_for_adata, run_SLAT_AGF_contrast
from scSLAT.viz import match_3D_multi, hist, Sankey, build_3D
from scSLAT.metrics import region_statistics

In [None]:
class match_3D_multi():
    r"""
    Plot the mapping result between 2 datasets

    NOTE(review): this local definition shadows ``scSLAT.viz.match_3D_multi`` imported above.

    Parameters
    ----------
    dataset_A
        pandas dataframe which contain ['index','x','y'], reference dataset
    dataset_B
        pandas dataframe which contain ['index','x','y'], target dataset
    matching
        matching results: an array with 2 rows and one column per matched pair;
        ``matching[0]`` holds ``index`` values of ``dataset_B`` and ``matching[1]``
        the paired ``index`` values of ``dataset_A``
    meta
        dataframe colname of meta, such as celltype
    expr
        dataframe colname of gene expr
    subsample_size
        number of matches (columns) randomly kept for plotting; ``None`` or a
        value <= 0 keeps every match
    reliability
        match score (cosine similarity score), one value per column of ``matching``;
        it is subset/subsampled together with ``matching`` so scores stay aligned with pairs
    scale_coordinate
        if scale coordinate via :math:`(data - min(data)) / (max(data) - min(data))`;
        a constant axis is scaled to 0
    rotate
        how to rotate the slides (forces scale_coordinate), such as ['x','y'], means dataset0
        rotates on x axis and dataset1 rotates on y axis
    exchange_xy
        exchange x and y on dataset_B
    subset
        column indices of ``matching`` (query cells) to be plotted

    Note
    ----
    dataset_A and dataset_B can have different lengths.
    """
    def __init__(self, dataset_A: pd.DataFrame,
                 dataset_B: pd.DataFrame,
                 matching: np.ndarray,
                 meta: Optional[str] = None,
                 expr: Optional[str] = None,
                 subsample_size: Optional[int] = 300,
                 reliability: Optional[np.ndarray] = None,
                 scale_coordinate: Optional[bool] = True,
                 rotate: Optional[List[str]] = None,
                 exchange_xy: Optional[bool] = False,
                 subset: Optional[List[int]] = None
                 ) -> None:
        self.dataset_A = dataset_A.copy()
        self.dataset_B = dataset_B.copy()
        self.meta = meta
        self.subset = subset  # column indices of `matching` (query cells) to be plotted
        # Rotation is expressed on the unit square, so it implies coordinate scaling.
        if rotate is not None:
            scale_coordinate = True

        assert all(item in dataset_A.columns.values for item in ['index', 'x', 'y'])
        assert all(item in dataset_B.columns.values for item in ['index', 'x', 'y'])

        if meta:
            set1 = list(set(self.dataset_A[meta]))
            set2 = list(set(self.dataset_B[meta]))
            self.celltypes = set1 + [x for x in set2 if x not in set1]
            self.celltypes.sort()  # make sure celltypes are in the same order
            overlap = [x for x in set2 if x in set1]
            print(f"dataset1: {len(set1)} cell types; dataset2: {len(set2)} cell types; \n\
                    Total :{len(self.celltypes)} celltypes; Overlap: {len(overlap)} cell types \n\
                    Not overlap :[{[y for y in (set1+set2) if y not in overlap]}]"
                    )
        self.expr = expr if expr else False

        if scale_coordinate:
            for i, dataset in enumerate([self.dataset_A, self.dataset_B]):
                for axis in ['x', 'y']:
                    lo = np.min(dataset[axis])
                    span = np.max(dataset[axis]) - lo
                    # A constant axis would otherwise become 0/0 = NaN.
                    dataset[axis] = (dataset[axis] - lo) / span if span != 0 else 0.0
                    if rotate is not None and axis in rotate[i]:
                        dataset[axis] = 1 - dataset[axis]
        if exchange_xy:
            # .to_numpy() makes the swap positional instead of relying on column alignment.
            self.dataset_B[['x', 'y']] = self.dataset_B[['y', 'x']].to_numpy()

        # Scores are indexed per column of `matching`, so every column selection applied
        # to `matching` must also be applied to them.
        conf = None if reliability is None else np.asarray(reliability)
        if subset is not None:
            matching = matching[:, subset]
            if conf is not None:
                conf = conf[subset]
        n_pairs = matching.shape[1]
        if subsample_size is not None and 0 < subsample_size < n_pairs:
            keep = np.random.choice(n_pairs, subsample_size, replace=False)
            matching = matching[:, keep]
            if conf is not None:
                conf = conf[keep]
        else:
            subsample_size = n_pairs
        self.matching = matching
        self.conf = conf
        print(f'Subsampled {subsample_size} pairs from {n_pairs}')

        self.datasets = [self.dataset_A, self.dataset_B]

    def draw_3D(self,
                size: Optional[List[int]] = [10, 10],
                conf_cutoff: Optional[float] = 0,
                point_size: Optional[List[float]] = [0.1, 0.1],
                line_width: Optional[float] = 0.3,
                line_color: Optional[str] = 'grey',
                line_alpha: Optional[float] = 0.7,
                hide_axis: Optional[bool] = False,
                show_error: Optional[bool] = True,
                show_celltype: Optional[bool] = False,
                cmap: Optional[str] = 'Reds',
                save: Optional[str] = None
                ) -> None:
        r"""
        Draw 3D picture of two datasets

        Parameters
        ----------
        size
            plt figure size
        conf_cutoff
            confidence cutoff of mapping to be plotted
        point_size
            point size of every dataset
        line_width
            pair line width
        line_color
            pair line color
        line_alpha
            pair line alpha
        hide_axis
            if hide axis
        show_error
            if show error celltype mapping with different color (requires meta)
        show_celltype
            if color correctly-matched lines by their celltype (requires meta)
        cmap
            color map when vis expr
        save
            save file path

        NOTE(review): relies on `plt` and `get_color`, which are not defined in the
        visible cells — confirm they are in scope.
        """
        self.conf_cutoff = conf_cutoff
        # Both line-coloring modes need the meta column.
        show_error = show_error if self.meta else False
        show_celltype = show_celltype if self.meta else False
        fig = plt.figure(figsize=(size[0], size[1]))
        ax = fig.add_subplot(111, projection='3d')
        # Celltype -> color lookup; kept separate from the expression colormap `cmap`
        # so line coloring by celltype still works when points are colored by expr.
        celltype_colors = {}
        if self.meta:
            palette = get_color(len(self.celltypes))
            celltype_colors = {celltype: palette[i] for i, celltype in enumerate(self.celltypes)}
            for z, dataset in enumerate(self.datasets):
                if self.expr:
                    # Each slice is normalized on its own expression range.
                    expr_values = dataset[self.expr].to_numpy()
                    norm = plt.Normalize(expr_values.min(), expr_values.max())
                for cell_type in self.celltypes:
                    cells = dataset[dataset[self.meta] == cell_type]
                    if self.expr:
                        ax.scatter(cells['x'], cells['y'], z, s=point_size[z],
                                   c=cells[self.expr], cmap=cmap, norm=norm)
                    else:
                        ax.scatter(cells['x'], cells['y'], z, s=point_size[z],
                                   c=celltype_colors[cell_type])
        else:
            # plot points without meta
            for z, dataset in enumerate(self.datasets):
                ax.scatter(dataset['x'], dataset['y'], z, s=point_size[z])
        self.c_map = celltype_colors
        self.draw_lines(ax, show_error, show_celltype, line_color, line_width, line_alpha)
        if hide_axis:
            plt.axis('off')
        if save is not None:
            plt.savefig(save)
        plt.show()

    def draw_lines(self, ax, show_error, show_celltype, line_color, line_width=0.3, line_alpha=0.7) -> None:
        r"""
        Draw dashed lines between paired cells in two datasets.

        Parameters
        ----------
        ax
            3D axes to draw on (only its ``plot`` method is used)
        show_error
            color lines '#ade8f4' when both cells share the meta value, '#ffafcc' otherwise
        show_celltype
            color lines by the shared celltype (from ``self.c_map``), '#696969' on mismatch;
            takes precedence over ``show_error``
        line_color
            color used when no meta-based coloring applies
        line_width
            pair line width
        line_alpha
            pair line alpha

        Pairs whose score in ``self.conf`` is below ``self.conf_cutoff`` are skipped.
        Dataset A is drawn at z=0 and dataset B at z=1.
        """
        conf_cutoff = getattr(self, 'conf_cutoff', 0)
        # Keys are stringified because celltypes below are read with `.astype(str)`.
        celltype_colors = {str(k): v for k, v in getattr(self, 'c_map', {}).items()}
        use_meta = self.meta is not None and (show_error or show_celltype)
        for i in range(self.matching.shape[1]):
            if self.conf is not None and self.conf[i] < conf_cutoff:
                continue
            pair = self.matching[:, i]
            color = line_color
            if use_meta:
                celltype1 = self.dataset_A.loc[self.dataset_A['index'] == pair[1], self.meta].astype(str).values[0]
                celltype2 = self.dataset_B.loc[self.dataset_B['index'] == pair[0], self.meta].astype(str).values[0]
                if show_error:
                    color = '#ade8f4' if celltype1 == celltype2 else '#ffafcc'  # blue / red
                if show_celltype:
                    color = celltype_colors[celltype1] if celltype1 == celltype2 else '#696969'
            point0 = np.append(self.dataset_A[self.dataset_A['index'] == pair[1]][['x', 'y']], 0)
            point1 = np.append(self.dataset_B[self.dataset_B['index'] == pair[0]][['x', 'y']], 1)
            coord = np.vstack((point0, point1))
            ax.plot(coord[:, 0], coord[:, 1], coord[:, 2], color=color, linestyle="dashed",
                    linewidth=line_width, alpha=line_alpha)

In [None]:
# Map every cell type to a colour, cycling through `pathology_color`.
# NOTE(review): `multi_align` (presumably a `match_3D_multi` instance built with `meta=...`,
# which provides `.celltypes`) and `pathology_color` (a list of colours) are not defined in
# the visible cells — confirm they exist before this cell runs.
unique_celltypes = sorted(set(multi_align.celltypes))
color_map = {celltype: pathology_color[i % len(pathology_color)] for i, celltype in enumerate(unique_celltypes)}


def draw_3D_with_colors(self, size=[7, 8], line_width=0.7, line_color='grey', point_size=[1.25, 2.5],
                        hide_axis=True, show_error=False, save='./Alignment.png', conf_cutoff=0):
    """Draw both slices of a ``match_3D_multi`` object in 3D with points coloured via ``color_map``.

    This is a plain function, so call it as ``draw_3D_with_colors(multi_align, ...)``.

    Parameters
    ----------
    self
        the alignment object; its ``datasets``, ``meta`` and ``draw_lines`` are used
    size
        figure size
    line_width, line_color
        width / colour of the dashed lines joining matched cells (same meaning as in ``draw_3D``)
    point_size
        marker size for each slice (slice i is drawn at z=i)
    hide_axis
        hide the 3D axes
    show_error
        colour match lines by whether both cells share the meta value (needs ``meta``)
    save
        output path; falsy to skip saving
    conf_cutoff
        matches with reliability below this value are not drawn
    """
    fig = plt.figure(figsize=size)
    ax = fig.add_subplot(111, projection='3d')

    for z, dataset in enumerate(self.datasets):
        # One colour per point, looked up from the point's own cell type.
        point_colors = dataset[self.meta].map(color_map).tolist() if self.meta else None
        ax.scatter(dataset['x'], dataset['y'], z, s=point_size[z], c=point_colors)

    # draw_lines reads conf_cutoff from the instance, as draw_3D sets it.
    self.conf_cutoff = conf_cutoff
    self.draw_lines(ax, show_error if self.meta else False, False, line_color, line_width)

    if hide_axis:
        ax.set_axis_off()

    if save:
        plt.savefig(save)
    plt.show()


# Plot with the function above (it is not a method of the class, so pass the instance).
draw_3D_with_colors(multi_align, size=[7, 8], line_width=0.7, line_color='grey', point_size=[1.25, 2.5],
                    hide_axis=True, show_error=False, save='./Alignment.png')

In [2]:
# seqFISH slice: de-duplicate gene names and copy the refined cell-type labels
# into the 'annotation' column (the Stereo-seq slice already has one).
seqfish_h5ad = Path(r"D:\ppppaper\data\seqFISH\filtered_seqFish.h5ad")
adata_1 = sc.read_h5ad(seqfish_h5ad)
adata_1.var_names_make_unique(join="++")
refined_labels = adata_1.obs['celltype_mapped_refined']
adata_1.obs['annotation'] = refined_labels
adata_1

AnnData object with n_obs × n_vars = 11529 × 351
    obs: 'z', 'uniqueID', 'x_global', 'y_global', 'embryo', 'Estage', 'x_global_affine', 'y_global_affine', 'UMAP1', 'UMAP2', 'cluster', 'celltype_mapped_refined', 'celltype_mapped', 'annotation'
    var: 'gene_names'
    obsm: 'spatial'

In [3]:
# Stereo-seq slice: load and de-duplicate gene names.
stereoseq_h5ad = Path(r"D:\ppppaper\data\Stereo_seq\filtered_Stereoseq.h5ad")
adata_2 = sc.read_h5ad(stereoseq_h5ad)
adata_2.var_names_make_unique(join="++")
adata_2

AnnData object with n_obs × n_vars = 5031 × 25568
    obs: 'n_genes_by_counts', 'log1p_n_genes_by_counts', 'total_counts', 'log1p_total_counts', 'annotation', 'Regulon - 2310011J03Rik', 'Regulon - 5730507C01Rik', 'Regulon - Alx1', 'Regulon - Alx3', 'Regulon - Alx4', 'Regulon - Ar', 'Regulon - Arid3a', 'Regulon - Arid3c', 'Regulon - Arnt2', 'Regulon - Arx', 'Regulon - Ascl1', 'Regulon - Atf1', 'Regulon - Atf4', 'Regulon - Atf5', 'Regulon - Atf6', 'Regulon - Atf7', 'Regulon - Bach1', 'Regulon - Bach2', 'Regulon - Barhl1', 'Regulon - Barx1', 'Regulon - Batf', 'Regulon - Bcl11a', 'Regulon - Bcl3', 'Regulon - Bcl6', 'Regulon - Bcl6b', 'Regulon - Bclaf1', 'Regulon - Bdp1', 'Regulon - Bhlha15', 'Regulon - Bhlhe22', 'Regulon - Bhlhe23', 'Regulon - Bhlhe41', 'Regulon - Bmyc', 'Regulon - Boll', 'Regulon - Bptf', 'Regulon - Brca1', 'Regulon - Brf1', 'Regulon - Brf2', 'Regulon - Bsx', 'Regulon - Cdx1', 'Regulon - Cdx2', 'Regulon - Cebpa', 'Regulon - Cebpz', 'Regulon - Chd1', 'Regulon - Clock', 'Re

In [4]:
# Spatial coordinates of each slice as float32 tensors of shape (n_cells, 2),
# built as (2, n_cells) arrays and transposed.
spatial_1 = adata_1.obsm['spatial']
spatial_2 = adata_2.obsm['spatial']
x1_coords, y1_coords = spatial_1[:, 0], spatial_1[:, 1]
x2_coords, y2_coords = spatial_2[:, 0], spatial_2[:, 1]
locations_1 = np.array([x1_coords, y1_coords])
locations_2 = np.array([x2_coords, y2_coords])

locations_1_tensor = torch.tensor(locations_1).transpose(0, 1).to(dtype=torch.float32)
locations_2_tensor = torch.tensor(locations_2).transpose(0, 1).to(dtype=torch.float32)
location = [locations_1_tensor, locations_2_tensor]

In [5]:
# KNN spatial neighbour graphs; the seqFISH slice uses k=50, the Stereo-seq slice k=25.
Cal_Spatial_Net(adata_1, model='KNN', k_cutoff=50)
Cal_Spatial_Net(adata_2, model='KNN', k_cutoff=25)

Calculating spatial neighbor graph ...
The graph contains 642647 edges, 11529 cells.
55.741781594240614 neighbors per cell on average.
Calculating spatial neighbor graph ...
The graph contains 132251 edges, 5031 cells.
26.287219240707614 neighbors per cell on average.


In [6]:
# Graph edges and DPCA features for both slices.
edges, features = load_anndatas(
    [adata_1, adata_2],
    feature='DPCA',
)

Use DPCA feature to format graph



See the tutorial for concat at: https://anndata.readthedocs.io/en/latest/concatenation.html
  view_to_actual(adata)
  view_to_actual(adata)
  view_to_actual(adata)


In [7]:
# Train the SLAT AGF contrastive model on both slices.
# The last return value is bound to `run_time` rather than `time` so it no longer
# shadows the `time` module imported in the first cell.
# NOTE(review): presumably the training duration — confirm against run_SLAT_AGF_contrast.
embd0_0, embd0_1, embd1_0, embd1_1, embd0, embd1, run_time = run_SLAT_AGF_contrast(features, edges, location, limit_loss=0.001)

GPU is not available
Running
---------- epochs: 1 ----------
---- Ran row_normalize in 0.04 s ----

---- Ran row_normalize in 0.03 s ----

---------- epochs: 2 ----------
---- Ran row_normalize in 0.07 s ----

---- Ran row_normalize in 0.05 s ----

---------- epochs: 3 ----------
---- Ran row_normalize in 0.06 s ----

---- Ran row_normalize in 0.03 s ----

---------- epochs: 4 ----------
---- Ran row_normalize in 0.05 s ----

---- Ran row_normalize in 0.03 s ----

---------- epochs: 5 ----------
---- Ran row_normalize in 0.06 s ----

---- Ran row_normalize in 0.03 s ----

---------- epochs: 6 ----------
---- Ran row_normalize in 0.06 s ----

---- Ran row_normalize in 0.03 s ----

---- Ran row_normalize in 0.06 s ----

---- Ran row_normalize in 0.03 s ----

---------- epochs: 1 ----------
---- Ran row_normalize in 0.06 s ----



  complex_array_tensor = torch.tensor(complex_array, dtype=torch.float32)


---- Ran row_normalize in 0.03 s ----

---------- epochs: 2 ----------
---- Ran row_normalize in 0.06 s ----

---- Ran row_normalize in 0.03 s ----

---------- epochs: 3 ----------
---- Ran row_normalize in 0.06 s ----

---- Ran row_normalize in 0.03 s ----

---------- epochs: 4 ----------
---- Ran row_normalize in 0.06 s ----

---- Ran row_normalize in 0.05 s ----

---------- epochs: 5 ----------
---- Ran row_normalize in 0.06 s ----

---- Ran row_normalize in 0.04 s ----

---------- epochs: 6 ----------
---- Ran row_normalize in 0.08 s ----

---- Ran row_normalize in 0.03 s ----

---- Ran row_normalize in 0.07 s ----

---- Ran row_normalize in 0.02 s ----

---------- Combined epochs: 0 ----------
---------- Combined epochs: 1 ----------
---------- Combined epochs: 2 ----------
---------- Combined epochs: 3 ----------
---------- Combined epochs: 4 ----------
---------- Combined epochs: 5 ----------
---------- Combined epochs: 6 ----------
---------- Combined epochs: 7 ----------
-----

In [24]:
import psutil

# Memory used by this kernel process and by the whole machine, in GiB.
BYTES_PER_GIB = 1024 ** 3
process_rss = psutil.Process(os.getpid()).memory_info().rss
print('当前进程的内存使用：%.4f GB' % (process_rss / BYTES_PER_GIB))

info = psutil.virtual_memory()
print('电脑总内存：%.4f GB' % (info.total / BYTES_PER_GIB))
print('当前使用的总内存占比：', info.percent)
print('cpu个数：', psutil.cpu_count())

当前进程的内存使用：5.0918 GB
电脑总内存：15.7884 GB
当前使用的总内存占比： 65.8
cpu个数： 8
