# Figure 3a-b
Characterization of the spatial and transcriptional state of intestinal epithelial cells (IECs) in response to acute nociceptor activation.

In [None]:
from pathlib import Path
import sys
import os
import scanpy as sc
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
from matplotlib.patches import Patch
import matplotlib.colors as clr
import seaborn as sns

In [None]:
# Raise the figure resolution to 150 dpi; the default figure size is left as is.
# plt.rcParams['figure.figsize'] = (4,4)
plt.rc("figure", dpi=150)

In [None]:
# Put the directory two levels above the working directory on sys.path so that
# the project's `config` package can be imported.
# NOTE(review): this depends on the current working directory — presumably the
# notebook must be started from its own folder; confirm.
project_root = Path.cwd().resolve().parents[1]
sys.path.append(str(project_root))

from config.paths import BASE_DIR

# Input AnnData file for the Fig. 3a panels.
input_dir = BASE_DIR / "data/h5ad/export_11"
input_file = input_dir / "adata-v3-annotations.h5ad"

# Destination folder for the exported figure panels (created if missing).
output_dir = BASE_DIR / "figures/Figure_3"
output_dir.mkdir(parents=True, exist_ok=True)

In [None]:
# Load the AnnData object used for the Fig. 3a panels and display its summary.
adata = sc.read_h5ad(input_file)
adata

In [None]:
# Named matplotlib colors assigned, in this order, to the highlighted groups
# by the plotting helpers in this notebook.
palette = [
    "red", "blue", "green", "orange",
    "cyan", "magenta", "blueviolet", "darkturquoise",
    "chartreuse", "black", "sienna", "navy",
    "tomato",
]

## Fig3a

### UMAP

In [None]:
def plot_umap_groups(
    adata, label_key, groups, size=1, unlabeled="lightgray", figsize=(5, 5)
):
    """Plot a UMAP with the selected groups highlighted.

    Args:
        adata: AnnData object forwarded to ``sc.pl.umap``.
        label_key: Column of ``adata.obs`` whose labels are colored.
        groups: Labels to highlight. Colors come from the module-level
            ``palette`` in order; the palette is reused cyclically when
            there are more groups than palette entries.
        size: Point size forwarded to ``sc.pl.umap``.
        unlabeled: Color given to every label that is not in ``groups``.
        figsize: Figure size passed to ``plt.subplots``.

    Returns:
        The matplotlib figure holding the UMAP axes.
    """
    groups = list(groups)
    highlighted = set(groups)

    # Index modulo the palette length so that no requested group is left
    # without a color when `groups` is longer than `palette`.
    color_dict = {
        group: palette[i % len(palette)] for i, group in enumerate(groups)
    }

    labels = adata.obs[label_key]
    background = labels.astype(str).unique().tolist()
    if isinstance(labels.dtype, pd.CategoricalDtype):
        # Cover declared-but-unobserved categories too, so the dict palette
        # has an entry for every category scanpy may look up.
        background.extend(labels.cat.categories.astype(str))

    for label in background:
        if label not in highlighted:
            color_dict.setdefault(label, unlabeled)

    fig, ax = plt.subplots(figsize=figsize)

    # show=False keeps the figure open so the caller can adjust and save it.
    sc.pl.umap(
        adata,
        color=label_key,
        palette=color_dict,
        frameon=False,
        title="",
        size=size,
        ax=ax,
        show=False,
    )

    return fig

In [None]:
# Display the summary of the currently loaded AnnData object.
adata

In [None]:
# Cell-type labels highlighted in the Fig. 3a panels; the order determines
# which palette color each label receives.
epithelial = [
    "Enterocyte_1",
    "Enterocyte_2",
    "Early_enterocyte",
    "ISC",
    "Mature_goblet",
    "Immature_goblet",
    "Paneth",
    "Transit_Amplifying",
    "Tuft_cell",
    "Enteroendocrine",
]

In [None]:
# Fig. 3a (left): UMAP with the epithelial labels highlighted, saved as PDF.
fig = plot_umap_groups(
    adata,
    label_key="cell_type",
    groups=epithelial,
    size=1,
    figsize=(5, 5),
)
fig.tight_layout()

umap_pdf = f"{output_dir}/Fig3a-left.pdf"
fig.savefig(umap_pdf, dpi=300, bbox_inches="tight")

### Spatial scatter

In [None]:
def plot_spatial_groups(
    adata,
    basis,
    label_key,
    groups,
    fov=None,
    size=50,
    unlabeled="lightgray",
    figsize=(6, 2),
):
    """Plot a spatial embedding with the selected groups highlighted.

    Args:
        adata: AnnData object forwarded to ``sc.pl.embedding``.
        basis: Name of the embedding to draw (passed as ``basis``).
        label_key: Column of ``adata.obs`` whose labels are colored.
        groups: Labels to highlight. Colors come from the module-level
            ``palette`` in order; the palette is reused cyclically when
            there are more groups than palette entries.
        fov: Optional ``(xmin, xmax, ymin, ymax)`` window; when given, the
            axes limits are set to it.
        size: Point size forwarded to ``sc.pl.embedding``.
        unlabeled: Color given to every label that is not in ``groups``.
        figsize: Figure size passed to ``plt.subplots``.

    Returns:
        The matplotlib figure holding the embedding axes.
    """
    groups = list(groups)
    highlighted = set(groups)

    # Index modulo the palette length so that no requested group is left
    # without a color when `groups` is longer than `palette`.
    color_dict = {
        group: palette[i % len(palette)] for i, group in enumerate(groups)
    }

    labels = adata.obs[label_key]
    background = labels.astype(str).unique().tolist()
    if isinstance(labels.dtype, pd.CategoricalDtype):
        # Cover declared-but-unobserved categories too, so the dict palette
        # has an entry for every category scanpy may look up.
        background.extend(labels.cat.categories.astype(str))

    for label in background:
        if label not in highlighted:
            color_dict.setdefault(label, unlabeled)

    fig, ax = plt.subplots(figsize=figsize)

    # show=False keeps the figure open so the caller can adjust and save it.
    sc.pl.embedding(
        adata,
        basis=basis,
        color=label_key,
        palette=color_dict,
        ax=ax,
        show=False,
        size=size,
        frameon=False,
        title="",
    )

    if fov:
        # Crop the view to the requested field of view.
        xmin, xmax, ymin, ymax = fov
        ax.set_xlim(xmin, xmax)
        ax.set_ylim(ymin, ymax)

    return fig

In [None]:
# Keep only the cells of sample "TIS09472_Control" as an independent copy.
in_selected_sample = adata.obs["sample_id"] == "TIS09472_Control"
bdata = adata[in_selected_sample].copy()

In [None]:
# Fig. 3a (right): spatial view of the selected sample, cropped to one window.
spatial_fov = (2916, 4321, 5000, 5700)  # (xmin, xmax, ymin, ymax)

fig = plot_spatial_groups(
    bdata, "spatial", "cell_type", epithelial, fov=spatial_fov, size=30
)

spatial_pdf = f"{output_dir}/Fig3a-right.pdf"
fig.savefig(spatial_pdf, dpi=300, bbox_inches="tight")

## Fig3b - Composition heatmap

In [None]:
# Fig. 3b reads from a different export folder; this rebinds `input_dir`,
# `input_file` and `adata`, replacing the objects used for Fig. 3a.
input_dir = BASE_DIR / "data/h5ad/export_10"
# NOTE(review): `input_file` is reassigned here but the read below uses a
# different file name — confirm whether this path is still needed.
input_file = input_dir / "resolvi-corrected-prepped.h5ad"


adata = sc.read_h5ad(input_dir / "iec-subset-resolvi-cc-v2.h5ad") # epithelial subset with cellcharter annotations

In [None]:
# Builds `final_df`: for every (zone, cell type) pair, the mean per-sample
# fraction of that cell type within the zone for each experimental group,
# the log2 fold change between groups, and a low-abundance flag.

# Pseudocount added to both group means so the log2 ratio stays finite at 0.
epsilon = 1e-5
# Pairs whose mean fraction is below this value in both groups are flagged.
min_fraction = 0.01

# Zone mapping
zone_mapping = {0: "Stem/Progenitor", 1: "Early", 2: "Late"}
adata.obs["cellcharter_zones"] = adata.obs["epithelial_cc_3"].map(zone_mapping)

zone_order = ["Stem/Progenitor", "Early", "Late"]
adata.obs["cellcharter_zones"] = pd.Categorical(
    adata.obs["cellcharter_zones"], categories=zone_order, ordered=True
)

# Rename groups (safe for categorical dtype). The isinstance check replaces
# the deprecated pd.api.types.is_categorical_dtype.
group_labels = {"Control": "hM3Dq", "Trpv1-cre": "TRPV1-hM3Dq"}
if isinstance(adata.obs["group"].dtype, pd.CategoricalDtype):
    adata.obs["group_label"] = adata.obs["group"].cat.rename_categories(
        group_labels
    )
else:
    adata.obs["group_label"] = adata.obs["group"].replace(group_labels)

# Cell types listed here are dropped from the composition (none by default).
exclude = []
all_dfs = []

# Loop over replicates
for sid in adata.obs["sample_id"].unique():
    df = adata.obs[adata.obs["sample_id"] == sid].copy()
    # The first group label found in the sample is used for the whole sample.
    group = df["group_label"].unique()[0]

    # Crosstab: counts per zone and cell type
    comp = pd.crosstab(df["cellcharter_zones"], df["cell_type"], dropna=False)
    comp = comp.drop(columns=[c for c in exclude if c in comp.columns], errors="ignore")
    comp = comp.reindex(zone_order, fill_value=0)

    # Normalize per zone (each row is divided by its own total)
    comp_norm = comp.div(comp.sum(axis=1), axis=0)

    # Melt for tidy DataFrame
    melted = comp_norm.reset_index().melt(
        id_vars="cellcharter_zones",
        var_name="Cell Type",
        value_name="Normalized Proportion",
    )
    melted["Sample ID"] = sid
    melted["Experimental Group"] = group

    all_dfs.append(melted)

# Combine all samples
all_data_df = pd.concat(all_dfs, ignore_index=True)
all_data_df = all_data_df.rename(columns={"cellcharter_zones": "Zone"})

# Group-wise mean normalized proportions.
# observed=False pins the current pandas default explicitly, so results do not
# change (and no FutureWarning is raised) if a grouping column is categorical.
group_means = (
    all_data_df.groupby(
        ["Zone", "Cell Type", "Experimental Group"], observed=False
    )["Normalized Proportion"]
    .mean()
    .reset_index()
)

# Pivot to wide format: one column per experimental group
group_means_pivot = (
    group_means.pivot(
        index=["Zone", "Cell Type"],
        columns="Experimental Group",
        values="Normalized Proportion",
    )
    .fillna(0)
    .reset_index()
)
group_means_pivot.columns.name = None

# Rename columns directly (no conditional logic needed)
group_means_pivot = group_means_pivot.rename(
    columns={
        "hM3Dq": "Mean Proportion in hM3Dq",
        "TRPV1-hM3Dq": "Mean Proportion in TRPV1-hM3Dq",
    }
)

# Compute log2 Fold Change and low abundance filter
group_means_pivot["log2 Fold Change (TRPV1-hM3Dq vs hM3Dq)"] = np.log2(
    (group_means_pivot["Mean Proportion in TRPV1-hM3Dq"] + epsilon)
    / (group_means_pivot["Mean Proportion in hM3Dq"] + epsilon)
)

group_means_pivot["Low Abundance Filter (<1% in both groups)"] = (
    (group_means_pivot["Mean Proportion in hM3Dq"] < min_fraction)
    & (group_means_pivot["Mean Proportion in TRPV1-hM3Dq"] < min_fraction)
)

# Final DataFrame
final_df = group_means_pivot.sort_values(by=["Zone", "Cell Type"]).reset_index(drop=True)
final_df

In [None]:
#final_df.to_csv(os.path.join(output_dir, "iec-zone-comp.csv"), index=False)

In [None]:
# Row order (top to bottom) and column order (left to right) of the heatmap.
zone_order = ["Late", "Early", "Stem/Progenitor"]

cell_type_order = [
    "ISC",
    "Transit_Amplifying",
    "Early_enterocyte",
    "Enterocyte_1",
    "Enterocyte_2",
]

# Only the cell types listed above are plotted.
plot_df = final_df.loc[final_df["Cell Type"].isin(cell_type_order)].copy()


def _zone_by_cell_type(value_column):
    """Return a Zone x Cell Type table of ``value_column`` in plotting order."""
    wide = plot_df.pivot_table(
        index="Zone", columns="Cell Type", values=value_column, fill_value=0
    )
    return wide.reindex(index=zone_order, columns=cell_type_order)


# Pivot heatmap data
heatmap_data = _zone_by_cell_type("log2 Fold Change (TRPV1-hM3Dq vs hM3Dq)")

# Group means used for the mask and the fold change
control_mean = _zone_by_cell_type("Mean Proportion in hM3Dq")
trpv1_mean = _zone_by_cell_type("Mean Proportion in TRPV1-hM3Dq")

min_fraction = 0.01
epsilon = 1e-5

# Mask low-abundance cell types (<1% in both groups)
low_mask = (control_mean < min_fraction) & (trpv1_mean < min_fraction)

# log2 ratio of the pseudocount-shifted means; masked cells become NaN
ratio = (trpv1_mean + epsilon) / (control_mean + epsilon)
log2fc = np.log2(ratio).mask(low_mask)

# Masked (NaN) cells are drawn in light gray.
cmap = plt.get_cmap("seismic")
cmap.set_bad("lightgray")

In [None]:
# Fig. 3b: heatmap of the masked log2 fold changes, saved as PDF.
heatmap_fig, heatmap_ax = plt.subplots(figsize=(8, 2))

sns.heatmap(
    log2fc,
    ax=heatmap_ax,
    cmap=cmap,
    center=0,  # zero fold change sits at the middle of the colormap
    linewidths=0.5,
    annot=False,
    fmt=".2f",
    cbar_kws={"label": r"log$_2$FC (TRPV1$^{\mathregular{hM3Dq}}$/hM3Dq)"},
)

heatmap_ax.set_title(
    "Cell type relative abundance per zone (≥1% of Zone)", fontsize=14
)
# Single-space labels replace the axis names that seaborn would show.
heatmap_ax.set_xlabel(" ")
heatmap_ax.set_ylabel(" ")
plt.setp(heatmap_ax.get_xticklabels(), rotation=90, ha="right", fontsize=10)
plt.setp(heatmap_ax.get_yticklabels(), fontsize=10)
heatmap_fig.tight_layout()

output_path = output_dir / "Fig3b.pdf"
heatmap_fig.savefig(output_path, dpi=300, bbox_inches="tight")
plt.show()