# In [None]:
import scanpy as sc
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
from scipy.stats import spearmanr

##############################################################################
# 1. Load the three species datasets and attach the required obs annotations
##############################################################################
path_m = "../../M-MG/1.subset/M-MG_cleaned.h5ad"
path_r = "../../R-MG/1.subset/R-MG_cleaned.h5ad"
path_s = "../../S-MG/1.subset/S-MG_cleaned.h5ad"

# Read the files in M, R, S order.
adata_m, adata_r, adata_s = (
    sc.read_h5ad(h5ad_path) for h5ad_path in (path_m, path_r, path_s)
)

# Keep the three datasets in a dict keyed by species label for later lookups.
adata_dict = {"M": adata_m, "R": adata_r, "S": adata_s}

# Tag every cell with the label of the species it came from.
# NOTE: obs['stage'] is NOT set here; each dataset is assumed to already
# contain it with the values 'stage1', 'stage2', 'stage3', 'stage4'.
for species_label, species_adata in adata_dict.items():
    species_adata.obs['species'] = species_label

##############################################################################
# 2. DEG export helper (keeps only genes with p < 1e-5 by default)
##############################################################################
def export_deg_result(adata_sub, pval_threshold=1e-5):
    """Flatten ``rank_genes_groups`` results into a long table and filter by p-value.

    Parameters
    ----------
    adata_sub
        Object whose ``uns['rank_genes_groups']`` entry holds the per-group
        fields ``names``, ``logfoldchanges``, ``pvals`` and ``pvals_adj``.
    pval_threshold
        Rows are kept only when the raw (unadjusted) ``pval`` is strictly
        below this value.

    Returns
    -------
    pandas.DataFrame
        Columns ``group``, ``gene``, ``logfoldchange``, ``pval``, ``pval_adj``;
        one row per (group, gene) that passes the filter. The index is the row
        position in the unfiltered table, so it may have gaps.
    """
    ranked = adata_sub.uns['rank_genes_groups']
    # Field order here fixes the column order of each emitted row.
    fields = ('names', 'logfoldchanges', 'pvals', 'pvals_adj')

    rows = [
        [label, *record]
        for label in ranked['names'].dtype.names
        for record in zip(*(ranked[field][label] for field in fields))
    ]

    table = pd.DataFrame(
        rows,
        columns=['group', 'gene', 'logfoldchange', 'pval', 'pval_adj'],
    )
    # Filtering uses the raw p-value, not the adjusted one.
    significant = table['pval'] < pval_threshold
    return table[significant]

##############################################################################
# 3. Differential expression (DEG): for every species pair, restrict both
#    datasets to the intersection of their HVGs, then test stage by stage
##############################################################################
species_pairs = [("M", "R"), ("M", "S"), ("R", "S")]
stages = ["stage1", "stage2", "stage3", "stage4"]

# One {'comparison', 'deg_count'} record per (pair, stage) that was tested.
deg_summary = []

for sp1, sp2 in species_pairs:
    # Genes flagged as highly variable in each species independently.
    hvgs_sp1 = adata_dict[sp1].var_names[adata_dict[sp1].var['highly_variable']]
    hvgs_sp2 = adata_dict[sp2].var_names[adata_dict[sp2].var['highly_variable']]

    # Intersection of the two HVG sets. Sorted because the iteration order of
    # a set of strings changes between interpreter runs (hash randomisation),
    # which would make the gene order of the subset non-reproducible.
    common_hvgs_pair = sorted(set(hvgs_sp1).intersection(set(hvgs_sp2)))
    print(f"{sp1} vs {sp2} 共有 HVGs 数量: {len(common_hvgs_pair)}")

    # Nothing to test when the two species share no HVGs.
    if not common_hvgs_pair:
        print(f"[警告] {sp1} vs {sp2}: 没有共同的 HVGs，跳过。")
        continue

    # Keep only the shared HVGs in each species' data.
    a1 = adata_dict[sp1][:, common_hvgs_pair].copy()
    a2 = adata_dict[sp2][:, common_hvgs_pair].copy()
    # Optional: refresh var['highly_variable'] (every remaining gene is an HVG).
    a1.var['highly_variable'] = a1.var_names.isin(common_hvgs_pair)
    a2.var['highly_variable'] = a2.var_names.isin(common_hvgs_pair)

    # Merge the two datasets; 'species_batch' records which species each cell
    # came from.
    adata_pair = a1.concatenate(
        a2,
        batch_key="species_batch",
        batch_categories=[sp1, sp2],
        join='inner'
    )

    # Run the differential expression test separately for every stage.
    for st in stages:
        mask = (adata_pair.obs['stage'] == st)
        adata_sub = adata_pair[mask].copy()

        if adata_sub.n_obs < 2:
            print(f"[警告] {sp1} vs {sp2}, {st}: 样本过少，跳过。")
            continue

        groups_in_data = adata_sub.obs['species_batch'].unique().tolist()
        if len(groups_in_data) < 2:
            print(f"[警告] {sp1} vs {sp2}, {st}: 只有一个物种数据，跳过。")
            continue

        # rank_genes_groups raises ValueError when a group holds a single
        # cell, so the total-cell check above is not sufficient: require at
        # least two cells in each species. Cast to str so that only labels
        # actually present are counted.
        cells_per_group = adata_sub.obs['species_batch'].astype(str).value_counts()
        if cells_per_group.min() < 2:
            print(f"[警告] {sp1} vs {sp2}, {st}: 样本过少，跳过。")
            continue

        # Wilcoxon rank-sum test between the two species.
        sc.tl.rank_genes_groups(
            adata_sub,
            groupby='species_batch',
            method='wilcoxon',
            use_raw=False
        )

        # Export the DEG rows with p < 1e-5.
        deg_df = export_deg_result(adata_sub, pval_threshold=1e-5)
        out_csv = f"DEG_{sp1}_vs_{sp2}_{st}_pairHVGs_p1e-5.csv"
        deg_df.to_csv(out_csv, index=False)

        # No `reference` is passed, so each species is tested against the
        # other and a significant gene is listed once per group. Count
        # distinct genes so that each DEG contributes once to the summary
        # (the CSV above still contains the rows for both groups).
        deg_count = deg_df['gene'].nunique()
        deg_summary.append({
            'comparison': f"{sp1}_vs_{sp2}_{st}",
            'deg_count': deg_count
        })

# Bar chart of the DEG count for every (species pair, stage) comparison.
# The columns are named explicitly so the frame still has 'comparison' and
# 'deg_count' when deg_summary is empty (every comparison was skipped);
# otherwise the column lookups below raise KeyError.
deg_summary_df = pd.DataFrame(deg_summary, columns=['comparison', 'deg_count'])
plt.figure(figsize=(8,5))
plt.bar(deg_summary_df['comparison'], deg_summary_df['deg_count'], color='lightblue')
# Comparison labels are long; rotate them so they do not overlap.
plt.xticks(rotation=45, ha='right')
plt.xlabel("Species Pair & Stage")
plt.ylabel("Number of DEGs (p < 1e-5, pair HVGs)")
plt.title("DEG Counts (Pairwise HVGs)")
plt.tight_layout()
plt.savefig("DEG_barplot_pairHVGs_p1e-5.png", dpi=300)
plt.show()

##############################################################################
# 4. Spearman correlation: for every species pair, correlate the per-gene mean
#    expression of the two species within each stage, using the pair's shared
#    HVGs
##############################################################################
fig, axes = plt.subplots(nrows=1, ncols=3, figsize=(12,4), sharey=True)

for idx, (sp1, sp2) in enumerate(species_pairs):
    ax = axes[idx]

    # Shared HVGs of the current pair.
    adata_first, adata_second = adata_dict[sp1], adata_dict[sp2]
    hvgs_sp1 = adata_first.var_names[adata_first.var['highly_variable']]
    hvgs_sp2 = adata_second.var_names[adata_second.var['highly_variable']]
    common_hvgs_pair = list(set(hvgs_sp1) & set(hvgs_sp2))
    print(f"{sp1} vs {sp2} 共有 HVGs 数量（相关性分析）: {len(common_hvgs_pair)}")

    # Restrict both datasets to those genes, then stack them into one object
    # whose 'species_batch' column tells the two species apart.
    a1 = adata_first[:, common_hvgs_pair].copy()
    a2 = adata_second[:, common_hvgs_pair].copy()
    adata_pair = a1.concatenate(
        a2,
        batch_key="species_batch",
        batch_categories=[sp1, sp2],
        join='inner'
    )

    species_of_cell = adata_pair.obs['species_batch']
    stage_of_cell = adata_pair.obs['stage']

    corr_vals = []
    for st in stages:
        in_stage = stage_of_cell == st
        adata_sp1 = adata_pair[(species_of_cell == sp1) & in_stage]
        adata_sp2 = adata_pair[(species_of_cell == sp2) & in_stage]

        # No cells for one of the species at this stage: correlation undefined.
        if adata_sp1.n_obs == 0 or adata_sp2.n_obs == 0:
            corr_vals.append(np.nan)
            continue

        # Per-gene mean over cells, flattened to 1-D; NaN entries are replaced
        # by the mean of the remaining entries of the same profile.
        profiles = []
        for stage_cells in (adata_sp1, adata_sp2):
            gene_means = np.asarray(stage_cells.X.mean(axis=0)).flatten()
            profiles.append(np.nan_to_num(gene_means, nan=np.nanmean(gene_means)))
        mean_sp1, mean_sp2 = profiles

        corr_val, _ = spearmanr(mean_sp1, mean_sp2)
        corr_vals.append(corr_val)

    # 1 x 4 frame: a single row, one column per stage.
    corr_df = pd.DataFrame([corr_vals], columns=stages)

    # Draw it as a one-row strip heatmap; only the last panel gets a colorbar.
    sns.heatmap(
        corr_df,
        cmap='RdBu_r',
        annot=True,
        fmt=".2f",
        vmin=-1,
        vmax=1,
        ax=ax,
        cbar=(idx == 2),
    )
    ax.set_xticklabels(stages, rotation=45, ha='center')
    ax.set_yticks([])  # a single-row heatmap needs no y axis
    ax.set_title(f"{sp1} vs {sp2}", fontsize=12)

fig.suptitle("Spearman Correlation (Pairwise HVGs) Across 4 Stages", fontsize=14)
# Leave the top 5% of the figure free for the suptitle.
plt.tight_layout(rect=[0, 0, 1, 0.95])
plt.savefig("Spearman_correlation_pairHVGs.png", dpi=300)
plt.show()
