# Figure 7 Analysis Notebook (PE-Aged)

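Generates panels 7A–7H from the integrated PE-Aged dataset: for each gene, a 2×2 spatial scatter of expressing vs. non-expressing cells and a 2×2 stacked bar of the `final_cell_type` composition of expressing cells, with one subplot per condition (Control-Young, Control-Aged, PE-Young, PE-Aged).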

## 0. Setup & Imports

In [None]:
import os
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
from pandas import CategoricalDtype
from matplotlib.lines import Line2D
import scanpy as sc

# Directories
DATA_DIR = "./data"
RESULTS_DIR = "./results/figure7_PE_Aged"
os.makedirs(RESULTS_DIR, exist_ok=True)

# Load integrated AnnData
adata_all = sc.read_h5ad(f"{DATA_DIR}/integrated_PE_Aged.h5ad")  # adjust path if needed
# Work on a copy so the relabelling below does not modify adata_all.
adata = adata_all.copy()

# Map the raw condition labels to the display names used as subplot titles.
condition_labels = {
    "Healthy_young": "Control-Young",
    "Healthy_old":   "Control-Aged",
    "PE_young":      "PE-Young",
    "PE_old":        "PE-Aged"
}
condition_col = adata.obs["condition"]
if (isinstance(condition_col.dtype, CategoricalDtype)
        and set(condition_col.cat.categories).isdisjoint(condition_labels.values())):
    # Renaming categories through Series.replace is deprecated for categorical
    # columns (pandas 2.2); rename_categories does the same relabelling and
    # keeps the category order. Labels absent from the mapping pass through.
    adata.obs["condition"] = condition_col.cat.rename_categories(condition_labels)
else:
    # Non-categorical column, or a display name already exists as a category
    # (rename_categories would raise on the duplicate): keep the value-level
    # replace.
    adata.obs["condition"] = condition_col.replace(condition_labels)

## 1. Define Conditions, Categories, and Colors

In [None]:
# Panel order shared by every 2x2 figure (one subplot per condition).
conditions = ["Control-Young", "Control-Aged", "PE-Young", "PE-Aged"]

# final_cell_type categories
dtype = adata.obs["final_cell_type"].dtype
if isinstance(dtype, CategoricalDtype):
    # Keep the order stored in the categorical dtype.
    final_categories = list(dtype.categories)
else:
    # Drop missing labels before sorting: NaN cannot be ordered against
    # strings, so sorted() would raise TypeError on a column that mixes both.
    final_categories = sorted(adata.obs["final_cell_type"].dropna().unique())

# Color mapping
# Fixed colors for selected cell types; applied only to those actually present.
custom_colors = {
    "Cardiomyocyte": "#e41a1c",
    "Endothelial":   "#377eb8",
    "Fibroblast":    "#4daf4a",
    "Smooth muscle": "#e65100"
}
base_colors = plt.cm.tab20.colors
non_custom = [ct for ct in final_categories if ct not in custom_colors]
color_dict = {ct: custom_colors[ct] for ct in custom_colors if ct in final_categories}
# Remaining cell types take tab20 colors in category order, wrapping around
# if there are more cell types than palette entries.
for i, ct in enumerate(non_custom):
    color_dict[ct] = base_colors[i % len(base_colors)]

## 2. Plot Loop for Genes

In [None]:
# For every gene, write two PNGs into RESULTS_DIR:
#   <gene>_spatial_scatter.png             - 2x2 spatial scatter, one subplot per condition
#   <gene>_final_celltype_distribution.png - 2x2 stacked bars of final_cell_type
#                                            proportions among expressing cells
genes = ["C1qtnf3","Comp","Cthrc1","H19","Crlf1","Spp1","Runx1","Ptn"]


def _dense_expression(view, gene):
    """Return the expression of `gene` over all observations of `view` as a dense 1-D array."""
    values = view[:, gene].X
    if hasattr(values, "toarray"):  # sparse matrices expose toarray()
        values = values.toarray()
    return np.asarray(values).flatten()


# The per-condition subsets and their coordinates do not depend on the gene,
# so slice `adata` once per condition instead of twice per gene.
condition_views = {cond: adata[adata.obs["condition"] == cond] for cond in conditions}
condition_coords = {cond: view.obsm["spatial"] for cond, view in condition_views.items()}

for gene in genes:
    # One expression vector per condition, shared by both figures below.
    expr_by_cond = {
        cond: _dense_expression(condition_views[cond], gene) for cond in conditions
    }

    # 2A: Spatial scatter 2x2
    fig, axes = plt.subplots(2, 2, figsize=(12, 10))
    axes = axes.flatten()
    for ax, cond in zip(axes, conditions):
        coords = condition_coords[cond]
        expr = expr_by_cond[cond]
        zero = expr == 0
        pos  = expr > 0

        # Non-expressing cells are drawn first so expressing cells sit on top.
        if zero.any():
            ax.scatter(coords[zero, 0], coords[zero, 1],
                       s=0.1, c="lightgray", alpha=0.8)
        if pos.any():
            ax.scatter(coords[pos, 0], coords[pos, 1],
                       s=0.1, c="red", alpha=0.8)

        ax.set_title(cond, fontsize=20)
        ax.set_xlabel("Spatial 1", fontsize=16)
        ax.set_ylabel("Spatial 2", fontsize=16)
        ax.tick_params(labelsize=14)
        ax.set_aspect("equal")
        for spine in ax.spines.values():
            spine.set_visible(False)

    # Proxy handles: the scatter markers (s=0.1) are too small for a legend.
    legend_e = [
        Line2D([0], [0], marker="o", color="w", markerfacecolor="lightgray", markersize=8, label="No Expr"),
        Line2D([0], [0], marker="o", color="w", markerfacecolor="red", markersize=8, label=gene)
    ]
    fig.legend(handles=legend_e,
               title=f"{gene} expression",
               loc="upper center", bbox_to_anchor=(0.5, 1.02),
               ncol=2, fontsize=16, title_fontsize=18)

    fig.subplots_adjust(top=0.88, bottom=0.06, left=0.02,
                        right=0.98, wspace=0.05, hspace=0.35)
    fig.savefig(f"{RESULTS_DIR}/{gene}_spatial_scatter.png", dpi=300, bbox_inches="tight")
    plt.close(fig)

    # 2B: final_cell_type distribution bar 2x2
    fig, axes = plt.subplots(2, 2, figsize=(12, 7))
    axes = axes.flatten()
    for i, (ax, cond) in enumerate(zip(axes, conditions)):
        sub = condition_views[cond]
        mask = expr_by_cond[cond] > 0

        # Proportion of each final_cell_type among expressing cells; types
        # with no expressing cells (or no expressing cells at all) become 0.
        counts = (
            sub.obs.loc[mask, "final_cell_type"]
            .value_counts(normalize=True)
            .reindex(final_categories)
            .fillna(0)
        )

        # Single stacked horizontal bar: each segment starts where the
        # previous one ended.
        left = 0.0
        for ft in final_categories:
            w = counts[ft]
            ax.barh(0, w, left=left, height=0.6,
                    color=color_dict[ft], edgecolor="white")
            left += w

        ax.set_xlim(0, 1)
        ax.set_yticks([])
        ax.set_title(cond, fontsize=20)
        if i // 2 == 1:  # label the x-axis on the bottom row only
            ax.set_xlabel("Proportion", fontsize=18)
        ax.tick_params(axis="x", labelsize=16)

    handles = [
        Line2D([0],[0], marker="s", color="w",
               markerfacecolor=color_dict[ft], markersize=10, label=ft)
        for ft in final_categories
    ]
    fig.legend(handles=handles,
               title="final_cell_type",
               ncol=4, fontsize=14, title_fontsize=16,
               loc="upper center", bbox_to_anchor=(0.5, 1.05))

    fig.subplots_adjust(top=0.80, bottom=0.06, left=0.02,
                        right=0.90, wspace=0.2, hspace=0.35)
    fig.savefig(f"{RESULTS_DIR}/{gene}_final_celltype_distribution.png", dpi=300, bbox_inches="tight")
    plt.close(fig)

print("All Figure 7 panels saved to:", RESULTS_DIR)