## environment

In [2]:
# Loading the Packages
%reload_ext autoreload
%autoreload 2

import os
from pathlib import Path
import pickle
from tqdm import tqdm
import warnings
warnings.filterwarnings('ignore')

import numpy as np
import pandas as pd
import scanpy as sc
from scipy.signal import argrelextrema
from scipy.signal import find_peaks

import seaborn as sns
import matplotlib.pyplot as plt
# Matplotlib defaults shared by every figure in this notebook.
plt.rcParams.update({
    # Resolution of on-screen figures and of saved files.
    'figure.dpi': 300,
    'savefig.dpi': 300,
    # Text: serif family, rendered by Matplotlib's built-in engine (no LaTeX).
    'font.family': 'serif',
    'text.usetex': False,
    'axes.unicode_minus': False,  # use an ASCII hyphen for minus signs to avoid glyph problems
    # Embed fonts as TrueType (Type 42) so PDF/EPS text stays editable in Illustrator.
    'pdf.fonttype': 42,
    'ps.fonttype': 42,
    # PGF backend: XeLaTeX engine (can be omitted if LaTeX formula rendering is not needed).
    'pgf.texsystem': 'xelatex',
    'pgf.rcfonts': False,  # do not let the pgf backend set up fonts from rcParams
})

In [None]:
# workdir: every path below is derived from BASE_DIR and RUN_ID.
# BASE_DIR may be overridden via the PRISM_BASE_DIR environment variable so the
# notebook can run on machines that do not have the original drive layout.
BASE_DIR = Path(os.environ.get('PRISM_BASE_DIR', r'F:\spatial_data\processed'))
RUN_ID = '20230523_HCC_PRISM_probe_refined'

# Processed-run directory. `base_path` and `src_dir` name the same folder; both
# names are used further down, so keep them as aliases of one computed path.
base_path = BASE_DIR / f'{RUN_ID}_processed'
src_dir = base_path

# Load one slide exp
data_path = base_path / "segmented"
typ_path = base_path / "cell_typing"
output_path = base_path / "interaction_expression"
output_path.mkdir(exist_ok=True)

## Expression vs. distance correlation (exp-dis-corr)

### load data

In [4]:
# exp and cell type information
combine_adata = sc.read_h5ad(typ_path/'combine_adata_st.h5ad')
adata_direct = sc.read_h5ad(typ_path/'adata_leiden_res_1.h5ad')
# Keep only cells present in the typed object. .copy() yields a real AnnData
# instead of a view, so the assignments below do not trigger an implicit copy.
adata = adata_direct[adata_direct.obs_names.isin(combine_adata.obs_names)].copy()
# Align the typed annotations to the expression rows by cell ID. Assigning
# combine_adata.obs as-is would silently rely on both files listing cells in
# the same order: the .obs setter takes the new index without reordering X.
adata.obs = combine_adata.obs.loc[adata.obs_names].copy()

# exp batch information
adata.obs['batch'] = 'PRISM31_HCC'
adata.obs['batch'] = adata.obs['batch'].astype('category')

# spatial information
adata.obs = adata.obs.rename(columns={'X_pos':'X', 'Y_pos':'Y'})
adata.obsm['spatial'] = adata.obs.loc[:, ['X', 'Y']].values
print(adata)
adata.obs.head()

AnnData object with n_obs × n_vars = 80396 × 31
    obs: 'dataset', 'n_genes_by_counts', 'log1p_n_genes_by_counts', 'total_counts', 'log1p_total_counts', 'n_genes', 'n_counts', 'type', 'leiden', 'sample', 'tissue', 'tmp_leiden', 'leiden_subtype', 'subtype', 'leiden_type', 'Y', 'X', 'batch'
    var: 'n_cells_by_counts', 'mean_counts', 'log1p_mean_counts', 'pct_dropout_by_counts', 'total_counts', 'log1p_total_counts', 'n_cells', 'mean', 'std'
    uns: 'leiden', 'log1p', 'neighbors', 'pca', 'umap'
    obsm: 'X_pca', 'X_umap', 'spatial'
    varm: 'PCs'
    obsp: 'connectivities', 'distances'


Unnamed: 0,dataset,n_genes_by_counts,log1p_n_genes_by_counts,total_counts,log1p_total_counts,n_genes,n_counts,type,leiden,sample,tissue,tmp_leiden,leiden_subtype,subtype,leiden_type,Y,X,batch
0,PRISM_HCC,9,2.302585,22.0,3.135494,9,22.0,Macrophage,0,,non_liver,0,18,Macrophage_LYVE1+,11,1109,21214,PRISM31_HCC
2,PRISM_HCC,10,2.397895,16.0,2.833213,10,16.0,Mast,13,,non_liver,13,12,Mast_CPA3+,8,1156,21143,PRISM31_HCC
3,PRISM_HCC,6,1.94591,8.0,2.197225,6,8.0,other,73,,non_liver,-2,-2,other,-2,1157,11785,PRISM31_HCC
4,PRISM_HCC,6,1.94591,7.0,2.079442,6,7.0,CD4+,28,,non_liver,28,23,"T_CD4+, PD1+, CTLA4+",12,1164,21132,PRISM31_HCC
5,PRISM_HCC,13,2.639057,41.0,3.73767,13,41.0,CD8+,45,,non_liver,45,28,"T_CD8+, GZMA+, CXCL13+",13,1177,11828,PRISM31_HCC


In [None]:
# Gene panel; removing cells below does not change var_names.
genes = adata.var_names
# Drop cells typed as 'other' from all downstream analyses.
adata = adata[~adata.obs['type'].eq('other')]

### functions

In [5]:
def get_expression(adata: "sc.AnnData", key: str) -> np.ndarray:
    """
    Retrieve per-cell values for a gene or an ``obs`` annotation.

    Genes take precedence: if ``key`` is both a gene name and an ``obs`` column,
    the expression values are returned.

    Args:
        adata: AnnData object holding the expression matrix and cell annotations.
        key: Gene name (looked up in ``adata.var_names``) or ``obs`` column name.

    Returns:
        A 1-D NumPy array with one value per cell of ``adata``.

    Raises:
        ValueError: If ``key`` is neither in ``var_names`` nor in the ``obs`` columns.
    """
    # The annotation is a string so that defining the function does not require
    # `anndata`/`ad` to be imported in this notebook.
    from scipy import sparse

    if key in adata.var_names:
        values = adata[:, key].X
        # X may be stored as a scipy sparse matrix, which has no .flatten().
        if sparse.issparse(values):
            values = values.toarray()
        # np.asarray also turns np.matrix into a plain ndarray so ravel() is 1-D.
        return np.asarray(values).ravel()
    if key in adata.obs.columns:
        return np.array(adata.obs[key])
    raise ValueError(f"{key} not found in object")

In [6]:
from scipy.spatial import distance


def get_closest_cell(adata: "sc.AnnData", obs_column: str, subtype_1: str, subtype_2: str) -> np.ndarray:
    """
    For every cell labelled ``subtype_1``, return the Euclidean distance to the
    nearest cell labelled ``subtype_2``.

    When ``subtype_1 == subtype_2`` each cell is its own nearest neighbour, so
    all returned distances are 0.

    Args:
        adata: AnnData object with coordinates in ``obsm["spatial"]`` and labels in ``obs``.
        obs_column: Column of ``adata.obs`` holding the labels.
        subtype_1: Label of the cells to measure from.
        subtype_2: Label of the cells to measure to.

    Returns:
        A 1-D NumPy array with one minimum distance per ``subtype_1`` cell, in
        ``adata`` row order.

    Raises:
        ValueError: If either label does not occur in ``adata.obs[obs_column]``.
    """
    from scipy.spatial import cKDTree

    labels = adata.obs[obs_column]
    if subtype_1 not in labels.unique():
        raise ValueError(f"{obs_column} {subtype_1} not found in adata")
    if subtype_2 not in labels.unique():
        raise ValueError(f"{obs_column} {subtype_2} not found in adata")

    coords = np.asarray(adata.obsm["spatial"], dtype=float)
    locations_1 = coords[(labels == subtype_1).to_numpy()]
    locations_2 = coords[(labels == subtype_2).to_numpy()]

    # Nearest-neighbour query: O(n log n) time and O(n) memory, instead of the
    # len(locations_1) x len(locations_2) matrix that cdist(...).min(axis=1) builds.
    distances_subtype, _ = cKDTree(locations_2).query(locations_1, k=1)
    return distances_subtype

In [7]:
from scipy import stats


def correlation_between_distance_and_expression(
    adata: "sc.AnnData", subtype: str, obs_column: str, key: str, method: str = "spearman"
) -> pd.DataFrame:
    """
    Correlate, across the cells of ``subtype``, the values of ``key`` with each
    cell's distance to the nearest cell of every label in ``obs_column``.

    Args:
        adata: AnnData object with spatial coordinates, labels and expression data.
        subtype: Label (in ``obs_column``) of the focal cells.
        obs_column: Column of ``adata.obs`` holding the labels.
        key: Gene name or ``obs`` column whose values are correlated with distance.
        method: ``"spearman"`` (default) or ``"pearson"``.

    Returns:
        DataFrame with columns 'subtype_1', 'subtype_2', 'pvalue' and 'correlation',
        one row per label in ``obs_column``. The row where subtype_2 == subtype
        uses all-zero distances, so its correlation is undefined (typically NaN).

    Raises:
        ValueError: If ``subtype`` is absent from ``adata`` or ``method`` is invalid.
    """
    if subtype not in adata.obs[obs_column].unique():
        raise ValueError(f"{obs_column} {subtype} not found in adata")

    allowed_methods = ["pearson", "spearman"]
    if method not in allowed_methods:
        raise ValueError(
            f"Invalid correlation method: {method}. Allowed methods are: {', '.join(allowed_methods)}"
        )
    correlate = stats.pearsonr if method == "pearson" else stats.spearmanr

    # The focal cells' values do not depend on subtype_2, so fetch them once.
    # Row order matches get_closest_cell, which also iterates the focal cells in adata order.
    expression = get_expression(adata[adata.obs[obs_column] == subtype], key=key)

    results = []
    for subtype_2 in adata.obs[obs_column].unique():
        distances = get_closest_cell(adata, obs_column=obs_column, subtype_1=subtype, subtype_2=subtype_2)
        corr, pval = correlate(distances, expression)
        results.append(
            {
                "subtype_1": subtype,
                "subtype_2": subtype_2,
                "pvalue": pval,
                "correlation": corr,
            }
        )

    return pd.DataFrame(results)

In [8]:
def get_batchwise_correlation_between_distance_and_expression(
    adata: "sc.AnnData", subtype: str, obs_column: str, key: str, method: str = "spearman"
) -> pd.DataFrame:
    """
    Run ``correlation_between_distance_and_expression`` separately in every batch
    and stack the results.

    Args:
        adata: AnnData object with spatial coordinates, labels, expression data and a
            categorical ``obs["batch"]`` column.
        subtype: Label (in ``obs_column``) of the focal cells.
        obs_column: Column of ``adata.obs`` holding the labels.
        key: Gene name or ``obs`` column whose values are correlated with distance.
        method: ``"spearman"`` (default) or ``"pearson"``.

    Returns:
        DataFrame with columns 'subtype_1', 'subtype_2', 'pvalue', 'correlation' and
        'batch'. 'batch' is categorical with the categories of ``adata.obs["batch"]``.

    Raises:
        ValueError: Propagated from the per-batch computation, e.g. when ``subtype``
            has no cells in one of the batches or ``method`` is invalid.
    """
    batch_categories = adata.obs["batch"].cat.categories

    per_batch = []
    for batch in batch_categories:
        adata_batch = adata[adata.obs["batch"] == batch]
        batch_df = correlation_between_distance_and_expression(
            adata_batch, subtype=subtype, obs_column=obs_column, key=key, method=method
        )
        batch_df["batch"] = batch
        per_batch.append(batch_df)

    combined = pd.concat(per_batch, ignore_index=True)
    combined["batch"] = pd.Categorical(combined["batch"], categories=batch_categories)
    return combined

In [9]:
from scipy.cluster import hierarchy

def get_order(x, method="single", metric="euclidean"):
    """
    Return a row order that places similar rows of ``x`` next to each other.

    Rows are clustered hierarchically and the dendrogram leaves are arranged with
    optimal leaf ordering; use ``get_order(x.T)`` to order columns instead.

    Args:
        x: 2-D array-like (e.g. a DataFrame); each row is one observation.
        method: Linkage method passed to ``scipy.cluster.hierarchy.linkage``.
        metric: Distance metric used for both linkage and leaf ordering.

    Returns:
        Integer array of row positions in display order.
    """
    data = np.asarray(x, dtype=float)
    # Fewer than two rows cannot be clustered; keep the existing order.
    if data.ndim == 2 and data.shape[0] < 2:
        return np.arange(data.shape[0])
    # linkage rejects non-finite values (e.g. NaN correlations from constant
    # inputs). Treat them as 0 for the layout only; callers still plot the raw values.
    data = np.where(np.isfinite(data), data, 0.0)
    link = hierarchy.linkage(data, method=method, metric=metric)
    return hierarchy.leaves_list(hierarchy.optimal_leaf_ordering(link, data, metric=metric))

### Plot all genes for each cell type

In [None]:
cur_path = output_path / "type"
cur_path.mkdir(exist_ok=True)

df_path = cur_path / 'df'
df_path.mkdir(exist_ok=True)
fig_path = cur_path / 'heatmap_border_fix'
fig_path.mkdir(exist_ok=True)

for celltype in tqdm(adata.obs.type.unique()):
    # The per-gene correlations are expensive: compute them once and cache them
    # as CSV, so a fresh output directory still works under Restart & Run All.
    csv_file = df_path / f"{celltype}_correlation.csv"
    if csv_file.exists():
        df = pd.read_csv(csv_file)
    else:
        per_gene = []
        for gene in genes:
            df_gene = get_batchwise_correlation_between_distance_and_expression(
                adata=adata, subtype=celltype, obs_column='type', key=gene, method='spearman')
            df_gene["gene"] = gene
            per_gene.append(df_gene)
        df = pd.concat(per_gene, ignore_index=True)
        df.to_csv(csv_file, index=False)

    # Self-pairs use the distance to the cell itself (always 0) and carry no signal.
    df = df[df["subtype_1"] != df["subtype_2"]]
    # Average over batches: rows are neighbour types (subtype_2), columns are genes.
    mat = df.groupby(by=["gene", "subtype_2"])["correlation"].mean().reset_index()
    mat = mat.pivot(index="subtype_2", columns="gene", values="correlation")

    fig, ax = plt.subplots(figsize=(10, 5))
    sns.heatmap(mat.iloc[get_order(mat), get_order(mat.T)],
                cmap="coolwarm_r", vmax=0.4, vmin=-0.4, annot=False, linewidths=0.5, ax=ax)
    # Set the title before tight_layout so the layout accounts for it.
    ax.set_title(f"{celltype}")
    fig.tight_layout()
    fig.savefig(fig_path / f"{celltype}_correlation.png")
    plt.close(fig)

In [None]:
cur_path = output_path / "subtype"
cur_path.mkdir(exist_ok=True)

df_path = cur_path / 'df'
df_path.mkdir(exist_ok=True)
fig_path = cur_path / 'heatmap_border_fix'
fig_path.mkdir(exist_ok=True)

for subtype in tqdm(adata.obs.subtype.unique()):
    # The per-gene correlations are expensive: compute them once and cache them
    # as CSV, so a fresh output directory still works under Restart & Run All.
    csv_file = df_path / f"{subtype}_correlation.csv"
    if csv_file.exists():
        df = pd.read_csv(csv_file)
    else:
        per_gene = []
        for gene in genes:
            df_gene = get_batchwise_correlation_between_distance_and_expression(
                adata=adata, subtype=subtype, obs_column='subtype', key=gene, method='spearman')
            df_gene["gene"] = gene
            per_gene.append(df_gene)
        df = pd.concat(per_gene, ignore_index=True)
        df.to_csv(csv_file, index=False)

    # Self-pairs use the distance to the cell itself (always 0) and carry no signal.
    df = df[df["subtype_1"] != df["subtype_2"]]
    # Average over batches: rows are neighbour subtypes (subtype_2), columns are genes.
    mat = df.groupby(by=["gene", "subtype_2"])["correlation"].mean().reset_index()
    mat = mat.pivot(index="subtype_2", columns="gene", values="correlation")

    fig, ax = plt.subplots(figsize=(10, 8))
    sns.heatmap(mat.iloc[get_order(mat), get_order(mat.T)],
                cmap="coolwarm_r", vmax=0.4, vmin=-0.4, annot=False, linewidths=0.5, ax=ax)
    # Set the title before tight_layout so the layout accounts for it.
    ax.set_title(f"{subtype}")
    fig.tight_layout()
    fig.savefig(fig_path / f"{subtype}_correlation.png")
    plt.close(fig)

### Plot marker genes of each cell type

In [None]:
import yaml

with open(typ_path / "annotation_params.yaml", "r") as params_file:
    annotation_params = yaml.safe_load(params_file)
marker_gene_dict = annotation_params["marker_gene_dict"]


def _marker_correlation_matrix(csv_file, marker_genes):
    """Load a cached correlation CSV and return the batch-averaged matrix
    (rows: neighbour group ``subtype_2``; columns: ``marker_genes`` in the given order)."""
    df = pd.read_csv(csv_file)
    # Drop self-pairs: the distance to the own group is always 0, so the correlation
    # is undefined (NaN). This matches the all-gene heatmaps and keeps an all-NaN
    # row out of the clustering.
    df = df[df["subtype_1"] != df["subtype_2"]]
    mat = df.groupby(by=["gene", "subtype_2"])["correlation"].mean().reset_index()
    mat = mat.pivot(index="subtype_2", columns="gene", values="correlation")
    return mat[marker_genes]


def _marker_heatmap_order(mat):
    """Return (row_order, col_order): clustered when possible, original order otherwise.
    With a single column, rows are sorted by value, largest first."""
    try:
        row_order = get_order(mat)
    except Exception:  # clustering can fail, e.g. on NaN correlations
        row_order = list(range(mat.shape[0]))
    try:
        col_order = get_order(mat.T)
    except Exception:
        col_order = list(range(mat.shape[1]))
    if mat.shape[1] == 1:
        row_order = np.argsort(mat.iloc[:, 0].values)[::-1]
    return row_order, col_order


def _save_marker_heatmap(mat, title, out_file):
    """Draw ``mat`` as a correlation heatmap and write it to ``out_file``."""
    fig, ax = plt.subplots(figsize=(mat.shape[1] + 3, 16))
    fig.subplots_adjust(left=0.4, right=0.9, top=0.9, bottom=0.1)
    row_order, col_order = _marker_heatmap_order(mat)
    sns.heatmap(mat.iloc[row_order, col_order],
                cmap="coolwarm_r", vmax=0.4, vmin=-0.4, annot=False, linewidths=0.5, ax=ax)
    ax.set_title(title)
    fig.savefig(out_file)
    plt.close(fig)


for celltype, subtype_markers in marker_gene_dict.items():
    for subtype, marker_gene_list in subtype_markers.items():
        mat = _marker_correlation_matrix(
            output_path / "subtype" / "df" / f"{subtype}_correlation.csv", marker_gene_list)
        _save_marker_heatmap(
            mat, f"{subtype}", output_path / "subtype" / f"marker_gene_{subtype}_correlation.png")

    # Union of the subtype marker lists, de-duplicated, keeping first occurrence.
    gene_list = list(dict.fromkeys(
        gene for subtype_genes in subtype_markers.values() for gene in subtype_genes))
    mat = _marker_correlation_matrix(
        output_path / "type" / "df" / f"{celltype}_correlation.csv", gene_list)
    _save_marker_heatmap(
        mat, f"{celltype}", output_path / "type" / f"marker_gene_{celltype}_correlation.png")

## Expression vs. neighbor-count correlation (exp-neighbor-corr)

### load adata

In [10]:
adata_STAGATE = sc.read_h5ad(src_dir / 'STAGATE' / 'rad_cutoff_250' / 'adata_STAGATE.h5ad')
# Materialise the filtered subset instead of keeping an AnnData view: the next
# cell writes into .uns, which on a view forces an implicit copy.
adata_STAGATE = adata_STAGATE[adata_STAGATE.obs.type != 'other'].copy()
adata_STAGATE

View of AnnData object with n_obs × n_vars = 60329 × 31
    obs: 'dataset', 'n_genes_by_counts', 'log1p_n_genes_by_counts', 'total_counts', 'log1p_total_counts', 'n_genes', 'n_counts', 'type', 'leiden', 'sample', 'tissue', 'tmp_leiden', 'leiden_subtype', 'subtype', 'leiden_type', 'Y', 'X', 'mclust_2', 'mclust_3', 'mclust_4', 'mclust_5', 'mclust_6', 'mclust_7', 'mclust_8', 'mclust_9', 'mclust_10', 'mclust_11', 'mclust_12', 'mclust_13', 'mclust_14', 'mclust_15', 'mclust_16', 'mclust_17', 'mclust_18', 'mclust_19'
    var: 'n_cells_by_counts', 'mean_counts', 'log1p_mean_counts', 'pct_dropout_by_counts', 'total_counts', 'log1p_total_counts', 'n_cells', 'mean', 'std'
    uns: 'Spatial_Net', 'leiden', 'log1p', 'neighbors', 'pca', 'umap'
    obsm: 'STAGATE', 'X_pca', 'X_umap', 'spatial'
    varm: 'PCs'
    obsp: 'connectivities', 'distances'

In [11]:
# Keep only edges whose two endpoints both survived the 'other' filter, so that
# every Cell1/Cell2 label can be looked up in adata_STAGATE.obs.
kept_cells = adata_STAGATE.obs.index
spatial_edges = adata_STAGATE.uns['Spatial_Net']
both_kept = spatial_edges['Cell1'].isin(kept_cells) & spatial_edges['Cell2'].isin(kept_cells)
Spatial_Net = spatial_edges.loc[both_kept]
adata_STAGATE.uns['Spatial_Net'] = Spatial_Net
adata_STAGATE.uns['Spatial_Net']

Unnamed: 0,Cell1,Cell2,Distance
0,0,4,98.737024
1,0,2,85.146932
2,0,29,204.787695
4,0,15,135.003704
5,0,27,225.002222
...,...,...,...
11,88845,88800,209.401528
12,88845,88811,192.338244
13,88845,88829,148.556387
16,88845,88827,86.267027


### Functions

In [12]:
def get_neighbor_type_counts(adata, target_type, obs_column='type'):
    """
    Count, for each cell of ``target_type``, its spatial neighbours per label.

    Neighbours come from the edge table ``adata.uns['Spatial_Net']``: only edges
    whose ``Cell1`` is a target cell are used, and ``Cell2`` is the neighbour.

    Args:
        adata: AnnData object whose ``uns['Spatial_Net']`` has ``Cell1``/``Cell2``
            columns holding obs names.
        target_type: Label (in ``obs_column``) of the cells to profile.
        obs_column: Column of ``adata.obs`` holding the labels.

    Returns:
        DataFrame with one row per target cell (in ``adata.obs`` order) and one
        column per neighbour label reported by ``value_counts``. Target cells
        without any edge get a row of zeros.
    """
    labels = adata.obs[obs_column]
    targets = adata.obs.index[(labels == target_type).to_numpy()]

    edges = adata.uns['Spatial_Net']
    target_edges = edges[edges['Cell1'].isin(targets)]

    def _count_neighbor_labels(neighbor_ids):
        # Label counts among one target cell's neighbours.
        return labels.loc[neighbor_ids].value_counts()

    per_cell = target_edges.groupby('Cell1', group_keys=True)['Cell2'].apply(_count_neighbor_labels)
    counts = per_cell.unstack(fill_value=0)

    # Cells with no edge are missing from the groupby result; add them back as zeros.
    return counts.reindex(targets, fill_value=0)

In [13]:
def get_expression_matrix(adata, target_cell_indices):
    """
    Extract a dense expression matrix for the given cells.

    Args:
        adata: AnnData object.
        target_cell_indices: Obs names (or any AnnData row indexer) of the cells to extract.

    Returns:
        NumPy array with one row per selected cell (in the given order) and one
        column per gene of ``adata``.
    """
    from scipy import sparse

    matrix = adata[target_cell_indices, :].X
    # .X may be a scipy sparse matrix or already dense; only the former has .toarray().
    if sparse.issparse(matrix):
        return matrix.toarray()
    return np.asarray(matrix)

In [14]:
def calculate_correlations(expression_matrix, neighbor_type_counts, adata, method='pearson'):
    """
    Correlate each gene's expression with each neighbour-type count across cells.

    Args:
        expression_matrix: Array of shape (n_cells, n_genes); column j holds
            ``adata.var_names[j]``.
        neighbor_type_counts: DataFrame of shape (n_cells, n_types), rows in the
            same cell order as ``expression_matrix``.
        adata: AnnData object supplying ``var_names`` (the gene labels).
        method: 'pearson' (default), 'spearman', 'kendall' or 'pointbiserial'.

    Returns:
        Float DataFrame indexed by gene, with one column per neighbour type.

    Raises:
        ValueError: If ``method`` is not supported. This is checked before any work is done.
    """
    from scipy.stats import pearsonr, spearmanr, kendalltau, pointbiserialr

    correlation_functions = {
        'pearson': pearsonr,
        'spearman': spearmanr,
        'kendall': kendalltau,
        'pointbiserial': pointbiserialr,
    }
    if method not in correlation_functions:
        raise ValueError(f"Unsupported correlation method: {method}")
    correlate = correlation_functions[method]

    expression = np.asarray(expression_matrix)
    # Numeric result (not object dtype); every cell is filled below.
    correlation_matrix = pd.DataFrame(
        np.nan, index=adata.var_names, columns=neighbor_type_counts.columns, dtype=float
    )
    count_columns = [neighbor_type_counts[c].to_numpy() for c in neighbor_type_counts.columns]

    for gene_idx in range(len(adata.var_names)):
        x = expression[:, gene_idx]
        for type_idx, y in enumerate(count_columns):
            correlation, _ = correlate(x, y)
            # Positional assignment stays correct even with duplicated gene names.
            correlation_matrix.iat[gene_idx, type_idx] = correlation

    return correlation_matrix

In [15]:
def plot_correlation_heatmap(correlation_matrix, target_type):
    """
    Show a heatmap of gene-expression vs. neighbour-cell-type-count correlations.

    Args:
        correlation_matrix: DataFrame of correlations (rows: genes, columns: cell types).
        target_type: Target cell type; used in the figure title.
    """
    import matplotlib.pyplot as plt
    import seaborn as sns

    heading = f'Correlation between gene expression and neighbor cell type counts for {target_type}'
    plt.figure(figsize=(15, 10))
    values = correlation_matrix.astype(float)
    sns.heatmap(values, cmap='coolwarm', center=0)
    plt.title(heading)
    plt.xlabel('Cell Types')
    plt.ylabel('Genes')
    plt.show()

In [16]:
def plot_correlation_heatmap(correlation_matrix, target_type, ax=None):
    """
    Draw a heatmap of gene-expression vs. neighbour-cell-type-count correlations.

    NOTE: this name is also defined in an earlier cell. This later definition is
    the one in effect after Run All, so it must carry the full implementation.

    Args:
        correlation_matrix: DataFrame of correlations (rows: genes, columns: cell types).
        target_type: Target cell type; used in the figure title.
        ax: Optional Axes to draw into. When omitted, a new 15x10 figure is created
            and shown.

    Returns:
        The Axes containing the heatmap.
    """
    import matplotlib.pyplot as plt
    import seaborn as sns

    created_figure = ax is None
    if created_figure:
        _, ax = plt.subplots(figsize=(15, 10))
    sns.heatmap(correlation_matrix.astype(float), cmap='coolwarm', center=0, ax=ax)
    ax.set_title(f'Correlation between gene expression and neighbor cell type counts for {target_type}')
    ax.set_xlabel('Cell Types')
    ax.set_ylabel('Genes')
    if created_figure:
        plt.show()
    return ax
    

### Plot all genes for each cell type

In [18]:
for obs_column in ['type', 'subtype']:
    current_path = output_path / f"{obs_column}_neighbor"
    current_path.mkdir(exist_ok=True)
    # Subtype heatmaps get a taller canvas than type heatmaps.
    figsize = (15, 8) if obs_column == 'type' else (15, 15)

    for target_type in adata_STAGATE.obs[obs_column].unique():
        # Neighbour-label counts for each cell of the target label.
        neighbor_type_counts = get_neighbor_type_counts(adata_STAGATE, target_type, obs_column)

        # Expression matrix of the same cells, in the same order.
        target_cell_indices = adata_STAGATE.obs[adata_STAGATE.obs[obs_column] == target_type].index
        expression_matrix = get_expression_matrix(adata_STAGATE, target_cell_indices)

        # calculate_correlations returns genes x neighbour types. Transpose it so rows
        # are neighbour types and columns are genes, matching the distance heatmaps.
        mat = calculate_correlations(expression_matrix, neighbor_type_counts, adata_STAGATE, method='spearman')
        mat = mat.astype(float).T

        try:
            order1 = get_order(mat)
        except Exception:  # clustering can fail, e.g. on NaN correlations
            order1 = list(range(mat.shape[0]))
        try:
            order2 = get_order(mat.T)
        except Exception:
            order2 = list(range(mat.shape[1]))
        if mat.shape[1] == 1:
            # Single column: sort rows by correlation, largest first.
            order1 = np.argsort(mat.iloc[:, 0].values)[::-1]

        fig, ax = plt.subplots(figsize=figsize)
        sns.heatmap(mat.iloc[order1, order2], cmap="coolwarm_r", vmax=0.4, vmin=-0.4,
                    annot=False, linewidths=0.5, ax=ax)
        ax.set_title(f'{target_type}')
        # After the transpose, columns (x axis) are genes and rows (y axis) are neighbour types.
        ax.set_xlabel('Genes')
        ax.set_ylabel('Neighbor cell types')
        fig.tight_layout()
        fig.savefig(current_path / f"{target_type}.png")
        plt.close(fig)

### Plot marker genes of each cell type