In [1]:
import scanpy as sc
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
import os
from itertools import combinations

# --- 0. Parameters ---
# Sample identifiers; each one is substituted into input_file_template below.
datasetlist = ['M-MG', 'R-MG', 'S-MG', 'R-AG', 'S-AG', 'R-CG']

# Where the per-sample .h5ad files are read from.
input_dir = "D:/111/"
input_file_template = "{}_cleaned.h5ad"

# Where results are written: one sub-folder for heatmaps, one for the
# DEG tables and intersection CSVs. Both are created on first run.
output_base_dir = "D:/111/DEG_Analysis_Output/"
heatmap_output_dir = os.path.join(output_base_dir, "Individual_Heatmaps")
intersection_output_dir = os.path.join(output_base_dir, "Stage_Intersections")
os.makedirs(heatmap_output_dir, exist_ok=True)
os.makedirs(intersection_output_dir, exist_ok=True)

# DEG parameters
group_key = 'phase'      # obs column the cells are grouped by
method = 'wilcoxon'      # statistical test passed to rank_genes_groups
top_n = 20               # genes kept per group for heatmaps / "top" intersections
pval_cutoff = 0.01       # keep genes with adjusted p-value below this
lfc_cutoff = 0.25        # keep genes with log fold change above this

# Collapse the five original stage labels into three phases.
stage_to_phase = dict(
    stage0='1-juvenile',
    stage1='1-juvenile',
    stage2='2-adult',
    stage3='3-ges-la',
    stage4='3-ges-la',
)

# Per-sample results, filled by the processing loop:
#   {sample: {phase: [gene, ...]}}
all_samples_top_genes = {}      # at most top_n genes per phase
all_samples_notop_genes = {}    # every gene passing the cutoffs
print("===== Stage 1: Processing samples =====")
# Key under which the ranking is stored on each AnnData; it does not depend
# on the sample, so it is built once.
key = f"rank_{group_key}"
for ds in datasetlist:
    print(f"--- Sample: {ds} ---")
    # Register the sample in BOTH result dicts before any early exit, so the
    # intersection steps can always look the sample up.
    all_samples_top_genes[ds] = {}
    all_samples_notop_genes[ds] = {}

    path = os.path.join(input_dir, input_file_template.format(ds))
    if not os.path.exists(path):
        print(f"File not found: {path}, skipping.")
        continue
    adata = sc.read_h5ad(path)

    # The phase labels are derived from the 'stage' column; without it the
    # mapping below would raise, so skip the sample instead.
    if 'stage' not in adata.obs.columns:
        print(f"Column 'stage' not found in {ds}, skipping.")
        continue
    # Stage labels that are not keys of stage_to_phase become NaN here.
    adata.obs['phase'] = adata.obs['stage'].map(stage_to_phase)
    # The heatmap reads layers['normalized']; fall back to a copy of X.
    if 'normalized' not in adata.layers:
        adata.layers['normalized'] = adata.X.copy()
    if group_key not in adata.obs.columns:
        print(f"Column '{group_key}' not found in {ds}, skipping.")
        continue
    adata.obs[group_key] = adata.obs[group_key].astype('category')
    stages = adata.obs[group_key].cat.categories.tolist()

    # Run DEG
    sc.tl.rank_genes_groups(adata, groupby=group_key, method=method, pts=True, key_added=key)

    # Save the filtered ranking of all groups as one table per sample.
    df1 = sc.get.rank_genes_groups_df(adata, group=stages, key=key)
    df1 = df1[(df1['pvals_adj'] < pval_cutoff) & (df1['logfoldchanges'] > lfc_cutoff)]
    out_csv = os.path.join(intersection_output_dir, f'{ds}_rank_genes_{group_key}.csv')
    df1.to_csv(out_csv, index=False)

    # Per group: all genes passing the cutoffs ordered by score, and the
    # first top_n of that same ordering.
    sample_genes = {}
    sample_genes_notop = {}
    for st in stages:
        df = sc.get.rank_genes_groups_df(adata, group=st, key=key)
        df = df[(df['pvals_adj'] < pval_cutoff) & (df['logfoldchanges'] > lfc_cutoff)]
        ranked = df.sort_values('scores', ascending=False)['names'].tolist()
        sample_genes_notop[st] = ranked
        sample_genes[st] = ranked[:top_n]
    all_samples_top_genes[ds] = sample_genes
    all_samples_notop_genes[ds] = sample_genes_notop

    # Heatmap for all stages
    print(f"Generating heatmap for sample {ds} ...")
    # Union of the top genes in first-seen order; also safe when there are
    # no groups or no genes at all.
    genes_union = list(dict.fromkeys(g for genes in sample_genes.values() for g in genes))
    if not genes_union:
        print("No genes, skipping heatmap.")
        continue

    expr_df = pd.DataFrame(index=genes_union, columns=stages, dtype=float)
    idx_map = {g: i for i, g in enumerate(adata.var_names)}
    # Genes missing from adata.var_names keep NaN and are set to 0 below.
    present = [g for g in genes_union if g in idx_map]
    present_cols = np.array([idx_map[g] for g in present], dtype=int)
    for st in stages:
        in_stage = (adata.obs[group_key] == st).to_numpy()
        mat = adata[in_stage].layers['normalized']
        # np.asarray(...).ravel() flattens the per-gene mean to 1-D.
        mean_vec = np.asarray(mat.mean(axis=0)).ravel()
        expr_df.loc[present, st] = mean_vec[present_cols]
    expr_df = expr_df.fillna(0)

    # z-score by gene; a zero standard deviation is replaced by 1 so that
    # constant rows become 0 instead of dividing by zero.
    mu = expr_df.mean(axis=1)
    sigma = expr_df.std(axis=1).replace(0, 1)
    expr_z = expr_df.sub(mu, axis=0).div(sigma, axis=0)

    # plot: figure size grows with the number of genes and groups
    h = max(5, len(expr_z) * 0.25 + 1)
    w = max(4, len(stages) * 1.0 + 1)
    fig, ax = plt.subplots(figsize=(w, h))
    try:
        sns.heatmap(expr_z, cmap='RdBu_r', center=0, linewidths=0.5, ax=ax)
        ax.set_xlabel('Stage')
        ax.set_ylabel('Gene')
        ax.set_title(f'{ds} Marker Genes Heatmap')
        # Address the heatmap axes explicitly rather than pyplot's
        # "current axes".
        plt.setp(ax.get_xticklabels(), rotation=45, ha='right')
        fig.tight_layout()
        out_png = os.path.join(heatmap_output_dir, f"{ds}_heatmap.png")
        fig.savefig(out_png, dpi=300)
    finally:
        # Release the figure even if plotting or saving fails.
        plt.close(fig)

print("\n===== Stage 2: Calculating intersections =====")
def intersect(lists):
    """Return the sorted items common to every non-empty entry of *lists*.

    Empty entries are ignored rather than forcing an empty result. If no
    non-empty entry exists, an empty list is returned.
    """
    non_empty = [set(items) for items in lists if items]
    if not non_empty:
        return []
    common = non_empty[0]
    for other in non_empty[1:]:
        common = common & other
    return sorted(common)

# Every group label that appears in at least one sample's top-gene result.
all_stages = sorted({st for genes in all_samples_top_genes.values() for st in genes})
# Samples whose three-way overlap is reported in addition to the pairs.
mg = ['M-MG', 'R-MG', 'S-MG']
for st in all_stages:
    # Keep only samples that have a non-empty gene list for this group.
    # .get() tolerates samples that were never registered in the result dict.
    valid = {}
    for ds in datasetlist:
        genes = all_samples_top_genes.get(ds, {}).get(st)
        if genes:
            valid[ds] = genes
    if len(valid) < 2:
        print(f"Stage {st}: fewer than 2 valid datasets, skipping.")
        continue

    rows = []
    # pairwise
    for a, b in combinations(valid, 2):
        inter = intersect([valid[a], valid[b]])
        rows.extend({'Combination': f'{a}&{b}', 'Gene': gene} for gene in inter)
    # triplet for MG group (only when all three samples are valid)
    if all(x in valid for x in mg):
        inter = intersect([valid[x] for x in mg])
        rows.extend({'Combination': 'M-MG&R-MG&S-MG', 'Gene': gene} for gene in inter)
    # all valid samples ('ALL' means all samples valid for this group,
    # not necessarily every entry of datasetlist)
    inter_all = intersect(list(valid.values()))
    rows.extend({'Combination': 'ALL', 'Gene': gene} for gene in inter_all)

    # save one CSV per stage; explicit columns keep the header row even
    # when no intersection was found
    df_out = pd.DataFrame(rows, columns=['Combination', 'Gene'])
    out_csv = os.path.join(intersection_output_dir, f"intersection_stage_{st}.csv")
    df_out.to_csv(out_csv, index=False)
    print(f"Saved intersections for stage {st} to {out_csv}")

print("Script completed.")


===== Stage 1: Processing samples =====
--- Sample: M-MG ---
Generating heatmap for sample M-MG ...
--- Sample: R-MG ---
Generating heatmap for sample R-MG ...
--- Sample: S-MG ---
Generating heatmap for sample S-MG ...
--- Sample: R-AG ---
Generating heatmap for sample R-AG ...
--- Sample: S-AG ---
Generating heatmap for sample S-AG ...
--- Sample: R-CG ---
Generating heatmap for sample R-CG ...

===== Stage 2: Calculating intersections =====
Saved intersections for stage 1-juvenile to D:/111/DEG_Analysis_Output/Stage_Intersections\intersection_stage_1-juvenile.csv
Saved intersections for stage 2-adult to D:/111/DEG_Analysis_Output/Stage_Intersections\intersection_stage_2-adult.csv
Saved intersections for stage 3-ges-la to D:/111/DEG_Analysis_Output/Stage_Intersections\intersection_stage_3-ges-la.csv
Script completed.


In [6]:
# Display the current value of `key`, the name assigned in the Stage 1 loop
# as f"rank_{group_key}" and passed to rank_genes_groups as key_added.
key

'rank_phase'

In [2]:
print("\n===== Stage 3: Calculating intersections for DEG result =====")
def intersect(lists):
    """Sorted intersection of all non-empty sequences in *lists*.

    Empty sequences are skipped. The result is an empty list when nothing
    non-empty was supplied.
    """
    shared = None
    for items in lists:
        if not items:
            continue
        members = set(items)
        shared = members if shared is None else shared & members
    return [] if shared is None else sorted(shared)

# Every group label that appears in at least one sample's full DEG result.
all_stages = sorted({st for genes in all_samples_notop_genes.values() for st in genes})
# Samples whose three-way overlap is reported in addition to the pairs.
mg = ['M-MG', 'R-MG', 'S-MG']
for st in all_stages:
    # Keep only samples that have a non-empty gene list for this group.
    # .get() avoids a KeyError for samples with no entry in
    # all_samples_notop_genes (possible when a sample was skipped upstream).
    valid = {}
    for ds in datasetlist:
        genes = all_samples_notop_genes.get(ds, {}).get(st)
        if genes:
            valid[ds] = genes
    if len(valid) < 2:
        print(f"Stage {st}: fewer than 2 valid datasets, skipping.")
        continue

    rows = []
    # pairwise
    for a, b in combinations(valid, 2):
        inter = intersect([valid[a], valid[b]])
        rows.extend({'Combination': f'{a}&{b}', 'Gene': gene} for gene in inter)
    # triplet for MG group (only when all three samples are valid)
    if all(x in valid for x in mg):
        inter = intersect([valid[x] for x in mg])
        rows.extend({'Combination': 'M-MG&R-MG&S-MG', 'Gene': gene} for gene in inter)
    # all valid samples ('ALL' means all samples valid for this group,
    # not necessarily every entry of datasetlist)
    inter_all = intersect(list(valid.values()))
    rows.extend({'Combination': 'ALL', 'Gene': gene} for gene in inter_all)

    # save one CSV per stage; explicit columns keep the header row even
    # when no intersection was found
    df_out = pd.DataFrame(rows, columns=['Combination', 'Gene'])
    out_csv = os.path.join(intersection_output_dir, f"intersection_stage_{st}_all.csv")
    df_out.to_csv(out_csv, index=False)
    print(f"Saved intersections for stage {st} to {out_csv}")
print("Script completed again.")



===== Stage 3: Calculating intersections for DEG result =====
Saved intersections for stage 1-juvenile to D:/111/DEG_Analysis_Output/Stage_Intersections\intersection_stage_1-juvenile_all.csv
Saved intersections for stage 2-adult to D:/111/DEG_Analysis_Output/Stage_Intersections\intersection_stage_2-adult_all.csv
Saved intersections for stage 3-ges-la to D:/111/DEG_Analysis_Output/Stage_Intersections\intersection_stage_3-ges-la_all.csv
Script completed again.
