In [None]:
import os
import sys
from pathlib import Path
import scanpy as sc
import pandas as pd
import numpy as np
import matplotlib
import matplotlib.pyplot as plt
import matplotlib.colors as mcolors 
import matplotlib.patches as mpatches
import seaborn as sns

In [None]:
import matplotlib as mpl
mpl.rcParams['pdf.fonttype'] = 42

In [None]:
plt.rcParams["figure.figsize"] = (4, 4)

In [None]:
plt.rcParams["figure.dpi"] = 500

In [None]:
#plt.rcParams["font.size"] = 24

## Import data

In [None]:
base_dir = Path("/path/to/tbi-seq")

## Input
data_dir = base_dir / "data/h5ad"
csv_dir = base_dir / "data/mapmycells"

## Output
output_dir = data_dir
outs = base_dir / "results"
fig_dir = outs / "figures"
os.makedirs(output_dir, exist_ok=True)
os.makedirs(outs, exist_ok=True)
os.makedirs(fig_dir, exist_ok=True)

In [None]:
sc.settings.figdir = fig_dir

In [None]:
sc.settings.figdir = fig_dir

In [None]:
adata = sc.read_h5ad(os.path.join(data_dir, "03_neurons-clean-scvi.h5ad"))

In [None]:
metadata = {
    '1':  {'group': 'Sham-GFP', 'group_id': 'A', 'condition': 'Sham', 'treatment': 'GFP',   'side': 'Ipsilateral'},
    '3':  {'group': 'Sham-VEGFC', 'group_id': 'B', 'condition': 'Sham', 'treatment': 'VEGFC', 'side': 'Ipsilateral'},
    '5':  {'group': 'TBI-GFP', 'group_id': 'C', 'condition': 'TBI',  'treatment': 'GFP',   'side': 'Ipsilateral'},
    '6':  {'group': 'TBI-GFP', 'group_id': 'D', 'condition': 'TBI',  'treatment': 'GFP',   'side': 'Contralateral'},
    '7':  {'group': 'TBI-VEGFC', 'group_id': 'E', 'condition': 'TBI',  'treatment': 'VEGFC', 'side': 'Ipsilateral'},
    '8':  {'group': 'TBI-VEGFC', 'group_id': 'F', 'condition': 'TBI',  'treatment': 'VEGFC', 'side': 'Contralateral'},
}

for key in ['group', 'group_id', 'condition', 'treatment', 'side']:
    adata.obs[key] = adata.obs['sample_id'].map({k: v[key] for k, v in metadata.items()})

adata.obs.group.value_counts()

In [None]:
adata.obs.group_id.value_counts()

In [None]:
# Snapshot the expression matrix at each preprocessing stage into named layers.
# NOTE(review): assumes adata.X holds raw counts when loaded. Confirm upstream.
adata.layers['counts'] = adata.X.copy()
sc.pp.normalize_total(adata)
adata.layers['normalized'] = adata.X.copy()
sc.pp.log1p(adata)
adata.layers['log1p'] = adata.X.copy()
# Freeze the log-normalized matrix as .raw. Scanpy functions called with use_raw=None
# typically read .raw when it is set (e.g. rank_genes_groups and heatmaps later on).
adata.raw = adata.copy()

In [None]:
adata.obs['cell_class'] = adata.obs['cell_type'].copy()

In [None]:
labels_df = pd.read_csv(
    os.path.join(csv_dir, '03_neurons-clean_10xWholeMouseBrain(CCN20230722)_HierarchicalMapping_UTC_1749244668001.csv'),
    skiprows=4
)

In [None]:
sc.pp.neighbors(adata, use_rep="X_scVI")
sc.tl.umap(adata)
sc.tl.leiden(adata, resolution=1, key_added="leiden")

In [None]:
sc.pl.umap(adata, color="leiden")

## Merge mapmycells annotations

In [None]:
print(adata.obs_names[:5])
print(labels_df['cell_id'].head())

In [None]:
labels_df = labels_df.set_index('cell_id')

columns_to_map = [
    'class_label', 'class_name', 'class_bootstrapping_probability',
    'subclass_label', 'subclass_name', 'subclass_bootstrapping_probability',
    'supertype_label', 'supertype_name', 'supertype_bootstrapping_probability',
    'cluster_label', 'cluster_name', 'cluster_alias', 'cluster_bootstrapping_probability'
]

for col in columns_to_map:
    adata.obs[col] = adata.obs_names.map(labels_df[col])

#print(adata.obs[columns_to_map].head())

In [None]:
adata

In [None]:
sc.pl.umap(adata, color=['supertype_bootstrapping_probability', 
                         'subclass_bootstrapping_probability', 
                         'class_bootstrapping_probability'], cmap='viridis', wspace=0.4)

In [None]:
adata.obs['subclass_name'].value_counts()

## Assign majority MapMyCells subclass labels to Leiden clusters

In [None]:
def _majority_label(labels):
    """Most frequent non-null label in `labels`, or NaN when none is present."""
    counts = labels.value_counts()
    return counts.idxmax() if not counts.empty else np.nan


# Majority MapMyCells subclass per Leiden cluster. observed=True restricts the
# groupby to clusters that contain cells. For a categorical key the default also
# yields empty groups, on which value_counts().idxmax() raises.
majority_subclass_per_cluster = (
    adata.obs.groupby('leiden', observed=True)['subclass_name']
    .agg(_majority_label)
)

# Every cell inherits its cluster's majority subclass.
adata.obs['cell_type'] = adata.obs['leiden'].map(majority_subclass_per_cluster)

# Check result
print(adata.obs[['leiden', 'cell_type']].head())

In [None]:
# Clean up the 'cell_type' column
adata.obs['cell_type'] = adata.obs['cell_type'].astype('category')
adata.obs['cell_type'] = adata.obs['cell_type'].cat.remove_unused_categories()

# Order categories by abundance
cell_type_counts = adata.obs['cell_type'].value_counts()
ordered_categories = cell_type_counts.index.tolist()

# Reorder the categories accordingly
adata.obs['cell_type'] = adata.obs['cell_type'].cat.reorder_categories(ordered_categories, ordered=True)

# Generate HUSL colors and assign to adata.uns
husl_colors = sns.color_palette('husl', n_colors=len(ordered_categories))
adata.uns['cell_type_colors'] = [mcolors.to_hex(c) for c in husl_colors]

# Plot UMAP with new colors
sc.pl.umap(adata, color=['leiden', 'Slc17a6', 'Slc17a7', 'Gad1', 'Gad2', 'cell_type'])

In [None]:
adata.obs.cell_type.value_counts()

In [None]:
# Count the number of cells per cell_type
cell_type_counts = adata.obs['cell_type'].value_counts().reset_index()
cell_type_counts.columns = ['cell_type', 'count']

# Optional: Sort by count (descending)
cell_type_counts = cell_type_counts.sort_values('count', ascending=False)

# Plot
plt.figure(figsize=(10, 8))
sns.barplot(data=cell_type_counts, y='cell_type', x='count', palette='viridis')

# Aesthetics
plt.title('Cell type composition (MapMyCells subclass labels)', fontsize=16)
plt.xlabel('Number of cells', fontsize=14)
plt.ylabel('Cell type', fontsize=14)
plt.xticks(fontsize=12)
plt.yticks(fontsize=12)

plt.tight_layout()
plt.show()

## Annotate hippocampal vs border/input neurons

In [None]:
import re

# Categorize cells into Hippocampus / Border-Input groups.
#
# cell_type values are MapMyCells subclass names of the form "<ID> <name>"
# (e.g. "037 DG Glut"). Membership is tested on the name with any leading numeric
# ID removed, so entries written without an ID (like "HPF CR Glut") still match
# the ID-prefixed labels.

# Hippocampal core subclasses (reviewer specified: CA1, CA2, CA3, DG, SUB, CR)
hippocampus_core = [
    "037 DG Glut",
    "016 CA1-ProS Glut",
    "025 CA2-FC-IG Glut",
    "017 CA3 Glut",
    "023 SUB-ProS Glut",
    "033 NP SUB Glut",
    "HPF CR Glut",   # CR neurons, if present (matched regardless of ID prefix)
]

# Hippocampal interneurons
hippocampus_interneurons = [
    "053 Sst Gaba",
    "052 Pvalb Gaba",
    "050 Lamp5 Lhx6 Gaba",
    "046 Vip Gaba",
    "049 Lamp5 Gaba",
]

_SUBCLASS_ID_PREFIX = re.compile(r"^\d+\s+")


def _strip_subclass_id(name):
    """Return `name` without a leading numeric taxonomy ID ("037 DG Glut" -> "DG Glut")."""
    return _SUBCLASS_ID_PREFIX.sub("", str(name)).strip()


_hippocampal_names = {
    _strip_subclass_id(n) for n in hippocampus_core + hippocampus_interneurons
}


def assign_region(cell_type):
    """Map a cell_type label to "Hippocampal neurons" or "Border/Input neurons".

    Non-string values (e.g. NaN for an unlabelled cluster) fall into Border/Input,
    as they did with plain list membership.
    """
    if isinstance(cell_type, str) and _strip_subclass_id(cell_type) in _hippocampal_names:
        return "Hippocampal neurons"
    return "Border/Input neurons"


adata.obs["region_assignment"] = adata.obs["cell_type"].apply(assign_region)
adata.obs["region_assignment"] = adata.obs["region_assignment"].astype("category")

# Hippocampal neurons only
is_hippocampal_cell = adata.obs["region_assignment"] == "Hippocampal neurons"
adata_hippocampus_only = adata[is_hippocampal_cell].copy()

# PLOT 1: hippocampal neurons only, cell types labelled on the embedding
sc.pl.umap(
    adata_hippocampus_only,
    color="cell_type",
    title="Hippocampus - Cell Types",
    legend_loc="on data",
)

# PLOT 2: every cell coloured by region (blue = hippocampal, red = border/input)
region_palette = {
    "Hippocampal neurons": "#1f77b4",
    "Border/Input neurons": "#d62728",
}
sc.pl.umap(
    adata,
    color="region_assignment",
    title="Region Assignment (All Cells)",
    palette=region_palette,
)

# PLOT 3: number of cells per region
region_counts = (
    adata.obs["region_assignment"]
    .value_counts()
    .rename_axis("region_assignment")
    .reset_index(name="count")
)

fig, ax = plt.subplots(figsize=(7, 5))
sns.barplot(data=region_counts, x="region_assignment", y="count", palette=region_palette, ax=ax)
ax.set_title("Region Assignment - Cell Counts")
ax.set_xlabel("Region Assignment")
ax.set_ylabel("Number of Cells")
ax.tick_params(axis="x", labelrotation=20)
fig.tight_layout()
plt.show()

In [None]:
# Two-ring donut: inner ring = region share, outer ring = cell types within each region.
# Inner ring
region_counts = adata.obs['region_assignment'].value_counts()
region_palette = {
    'Hippocampal neurons': '#4d4d4d',   # now dark gray
    'Border/Input neurons': '#d9d9d9',  # now light gray
}
region_colors = [region_palette[r] for r in region_counts.index]

# cell_type counts by region (rows: region, columns: cell_type)
ct = pd.crosstab(adata.obs['region_assignment'], adata.obs['cell_type'])
hip_cell_types = [c for c in ct.columns if ct.loc['Hippocampal neurons', c] > 0]
bdr_cell_types = [c for c in ct.columns if ct.loc['Border/Input neurons', c] > 0]

# Build divergent palettes: rocket_r shades for hippocampal types (indices 20-220 of 256
# skip the extreme ends of the map), reversed mako shades for border/input types.
full_rocket = sns.color_palette("rocket_r", 256)
idx = np.linspace(20, 220, len(hip_cell_types), dtype=int)
hip_pal = [full_rocket[i] for i in idx]
bdr_pal = sns.color_palette("mako", n_colors=len(bdr_cell_types))[::-1]

# Map each cell_type to color
cell_type_to_color = {}
for name, col in zip(hip_cell_types, hip_pal):
    cell_type_to_color[name] = matplotlib.colors.to_hex(col)
for name, col in zip(bdr_cell_types, bdr_pal):
    cell_type_to_color[name] = matplotlib.colors.to_hex(col)

# Build outer_counts/colors. Wedges are grouped by region in the same order as the
# inner ring (region_counts.index). With the same startangle, each cell-type wedge
# sits over its region. region is a function of cell_type, so each type appears once.
outer_counts, outer_colors = [], []
for region in region_counts.index:
    for name, cnt in ct.loc[region].items():
        if cnt > 0:
            outer_counts.append(cnt)
            outer_colors.append(cell_type_to_color[name])

# Plot multi-layer donut. Draw order matters: inner ring, outer ring, then the white disc.
fig, ax = plt.subplots(figsize=(6, 6))

# Inner ring: radius 0.7, width 0.3, with percentage labels near the centre.
ax.pie(
    region_counts,
    radius=0.7,
    colors=region_colors,
    startangle=160,
    wedgeprops=dict(width=0.3, edgecolor='white'),
    labels=None,
    autopct='%1.0f%%',
    pctdistance=0.30,
    textprops={'color': 'black', 'fontsize': 25}
)

# Outer ring
ax.pie(
    outer_counts,
    radius=1.0,
    colors=outer_colors,
    startangle=160,
    wedgeprops=dict(width=0.3, edgecolor='white')
)

# White center circle (radius 0.4 = inner edge of the inner ring)
centre = plt.Circle((0, 0), 0.4, fc='white')
ax.add_artist(centre)

ax.set_title('Anatomical Assignment and Cell Type Composition', fontsize=16)
plt.tight_layout()

export_path = fig_dir / "anatomical_assignment_donut.pdf"
fig.savefig(export_path, bbox_inches='tight')
plt.show()

In [None]:
def _top_cell_types(region, n=5):
    """Names of the `n` most abundant cell types among cells assigned to `region`."""
    in_region = adata.obs[adata.obs['region_assignment'] == region]
    sizes = in_region.groupby('cell_type').size().sort_values(ascending=False)
    return sizes.head(n).index.tolist()


# Top 5 cell types in each anatomical category
top5 = {
    region: _top_cell_types(region)
    for region in ['Hippocampal neurons', 'Border/Input neurons']
}

print("Top 5 cell types for Border/Input neurons:", top5['Border/Input neurons'], sep="\n")
print()
print("Top 5 cell types for Hippocampal neurons:", top5['Hippocampal neurons'], sep="\n")

In [None]:
# Recompute the donut's per-region palette so legend swatches match its wedges.
ct = pd.crosstab(adata.obs['region_assignment'], adata.obs['cell_type'])
hip_cell_types = [c for c in ct.columns if ct.loc['Hippocampal neurons', c] > 0]
bdr_cell_types = [c for c in ct.columns if ct.loc['Border/Input neurons', c] > 0]

full_rocket = sns.color_palette("rocket_r", 256)
idx = np.linspace(20, 220, len(hip_cell_types), dtype=int)
hip_pal = [full_rocket[i] for i in idx]
bdr_pal = sns.color_palette("mako", n_colors=len(bdr_cell_types))[::-1]

cell_type_to_color = {}
for name, col in zip(hip_cell_types, hip_pal):
    cell_type_to_color[name] = mcolors.to_hex(col)
for name, col in zip(bdr_cell_types, bdr_pal):
    cell_type_to_color[name] = mcolors.to_hex(col)

# Top-N cell types per region, derived from the same crosstab instead of a
# hand-typed list. A stale or mistyped name would raise KeyError in the colour
# lookups below, and the legend would stop tracking the data.
TOP_N = 5
top5 = {}
for region in ['Hippocampal neurons', 'Border/Input neurons']:
    region_row = ct.loc[region]
    top5[region] = (
        region_row[region_row > 0]
        .sort_values(ascending=False, kind='stable')
        .head(TOP_N)
        .index.tolist()
    )

hip_handles = [mpatches.Patch(color=cell_type_to_color[l], label=l)
               for l in top5['Hippocampal neurons']]
bdr_handles = [mpatches.Patch(color=cell_type_to_color[l], label=l)
               for l in top5['Border/Input neurons']]

fig, ax = plt.subplots(figsize=(12, 4))  # slightly narrower
ax.axis('off')

legend_kwargs = dict(
    frameon=False,
    fontsize=18,
    title_fontsize=20,
    handlelength=2.5,
    handleheight=1.0,
    labelspacing=1.2
)

# Hippocampal legend on the left
legend1 = ax.legend(
    handles=hip_handles,
    title=f'Hippocampal neurons\n(top {TOP_N})',
    loc='center left',
    bbox_to_anchor=(0.10, 0.5),
    **legend_kwargs
)
# Re-attach the first legend; the next ax.legend() call would otherwise replace it.
ax.add_artist(legend1)

# Border/Input legend on the right
legend2 = ax.legend(
    handles=bdr_handles,
    title=f'Border/Input neurons\n(top {TOP_N})',
    loc='center right',
    bbox_to_anchor=(0.90, 0.5),
    **legend_kwargs
)

plt.tight_layout()
export_path = fig_dir / "anatomical_assignment_donut_legend.pdf"
fig.savefig(export_path, bbox_inches='tight')
plt.show()

In [None]:
# Grey scheme shared by the region donut and the region UMAP.
region_palette = {
    'Hippocampal neurons': '#4d4d4d',   # dark gray
    'Border/Input neurons': '#d9d9d9',  # light gray
}

# One swatch per region, in palette order.
legend_handles_region = [
    mpatches.Patch(color=swatch, label=region_name)
    for region_name, swatch in region_palette.items()
]

legend_kwargs = {
    'frameon': False,
    'fontsize': 20,
    'title_fontsize': 24,
    'handlelength': 2.5,
    'handleheight': 1.0,
    'labelspacing': 1.2,
}

# Legend-only figure
fig, ax = plt.subplots(figsize=(4, 3))
ax.axis('off')
ax.legend(handles=legend_handles_region, title='Anatomical Assignment', loc='center', **legend_kwargs)

plt.tight_layout()
export_path = fig_dir / "anatomical_assignment_umap_legend.pdf"
fig.savefig(export_path, bbox_inches='tight')
plt.show()

In [None]:
sc.pl.umap(adata, 
           color='region_assignment', 
           palette=region_palette, title='', 
           frameon = False,
          save = "anatomical_assignment_umap_legend.pdf")

In [None]:
sc.pl.umap(adata, color='cell_type', legend_loc=None, frameon = False, title = '',
          save = 'cell_type.pdf') 

In [None]:
#Plot legend
categories = adata.obs['cell_type'].cat.categories
colors = adata.uns['cell_type_colors']

handles = [mpatches.Patch(color=col, label=cat) for col, cat in zip(colors, categories)]

fig, ax = plt.subplots(figsize=(6, len(categories) * 0.3)) 
ax.legend(handles=handles, loc='center', frameon=False, ncol=2)
ax.axis('off')
plt.tight_layout()
export_path = fig_dir / "cell_type_legend.pdf"
fig.savefig(export_path, bbox_inches='tight')
plt.show()

In [None]:
# plot legend
categories = adata.obs['region_assignment'].cat.categories
colors = adata.uns['region_assignment_colors']

handles = [mpatches.Patch(color=col, label=cat) for col, cat in zip(colors, categories)]

fig, ax = plt.subplots(figsize=(6, len(categories) * 0.3)) 
ax.legend(handles=handles, loc='center', frameon=False, ncol=1)
ax.axis('off')
plt.tight_layout()
plt.show()

In [None]:
border_input_counts = (
    adata.obs[adata.obs['region_assignment'] == 'Input / border']
    .assign(cell_type_str = adata.obs['cell_type'].astype(str))  # convert to string for clean grouping
    .groupby('cell_type_str')
    .size()
    .reset_index(name='count')
    .sort_values('count', ascending=False)
)

#border_input_counts = border_input_counts[border_input_counts['count'] > 0]

plt.figure(figsize=(5, 6))
sns.barplot(data=border_input_counts, y='cell_type_str', x='count', palette='viridis_r')

plt.title('Neuron Subtypes', fontsize=16)
plt.xlabel('Number of cells', fontsize=14)
plt.ylabel('Cell type', fontsize=14)
plt.xticks(fontsize=12)
plt.yticks(fontsize=12)

plt.tight_layout()
plt.show()

In [None]:
border_input_counts = (
    adata.obs[adata.obs['region_assignment'] == 'Input / border']
    .groupby('cell_type')
    .size()
    .sort_values(ascending=False)
)

print(border_input_counts)

In [None]:
adata.obs.columns

## Marker genes (hippocampal neurons)

In [None]:
adata_subset = adata[adata.obs['region_assignment'] == 'Hippocampal neurons'].copy()

print("Cells in hippocampus subset:", adata_subset.n_obs)
print(adata_subset.obs['region_assignment'].value_counts())

# Top 10 hippocampal subtypes
cell_type_counts = adata_subset.obs['cell_type'].value_counts()
top10_cell_types = cell_type_counts.index[:10].tolist()

print("Top10 cell types in hippocampus:", top10_cell_types)
print(cell_type_counts.head(10))

In [None]:
adata.obs.columns

In [None]:
sc.pl.umap(adata_subset, color = ['cell_type'])

In [None]:
adata_subset

In [None]:
# Curated marker genes for the ten subclasses shown in the top-5 legends (five
# hippocampal, five Border/Input). Source of these lists is not recorded here.
# NOTE(review): not referenced by the cells that follow; kept as reference.
subclass_markers = {
    "037 DG Glut":       ["Prox1", "Itpka",  "C1ql3", "Glis3", "Egr3"],
    "016 CA1-ProS Glut": ["Fibcd1","Spink8"],
    "017 CA3 Glut":      ["Tspan18","Slc17a7","Hes1"],
    "028 L6b/CT ENT Glut":["Cplx3","Nxph3","Npr3"],
    "019 L2/3 IT PPP Glut":["Cdc14a","Satb2","Meis2"],
    "113 MEA-COA-BMA Ccdc42 Glut":["Kcnh3","Scn5a","Adcyap1","Csmd3","Trhr","B130024G19Rik"],
    "053 Sst Gaba":      ["Rbp4","Lhx6","Reln"],
    "070 LSX Prdm12 Slit2 Gaba":["Prdm12","Hs3st2","Shisa9","Runx1t1"],
    "052 Pvalb Gaba":    ["Rbp4","Lhx6","Zfp804b"],
    "071 LSX Prdm12 Zeb2 Gaba":["Ano1","Myo5b","Zeb2","Slc18a2"],
}

# Same subclasses, presumably transcription-factor markers (per the variable name;
# confirm against the source taxonomy).
subclass_tf_markers = {
    "037 DG Glut":       ["Prox1","Glis3","Egr3"],
    "016 CA1-ProS Glut": ["Zfhx4","Neurod6","Fezf2","Bcl6","Satb2"],
    "017 CA3 Glut":      ["Lhx9","Neurod6","Foxg1","Hopx","Hes1"],
    "028 L6b/CT ENT Glut":["Foxp2","Satb2","Nr2f2","Zeb2"],
    "019 L2/3 IT PPP Glut":["Cux2","Satb2","Lhx2","Tshz2","Tox","Sox8","Tead1"],
    "113 MEA-COA-BMA Ccdc42 Glut":["Egr3","Zbtb7c","Nr2f2","Zim1","Id2","Myt1l","Meis2","Foxg1"],
    "053 Sst Gaba":      ["Mafb","Npas1","Sox6","Zeb2"],
    "070 LSX Prdm12 Slit2 Gaba":["Prdm12","Myt1l","Prdm16","Isl1","Egr3"],
    "052 Pvalb Gaba":    ["Creb5","Lhx6","Klf5"],
    "071 LSX Prdm12 Zeb2 Gaba":["Prdm12","Myt1l","Zeb2","Isl1"],
}

# List of per-subclass marker lists (dict insertion order), displayed for inspection.
subclass_id_markers = list(subclass_markers.values())
subclass_id_markers

In [None]:
genes = [
    'Prox1', 'Itpka', 'C1ql3', "Glis3", "Egr3",
    'Fibcd1', 'Spink8', 'Zfhx4', 'Neurod6', 'Bcl6', # Fezf2, Satb2
    'Tspan18', 'Slc17a7', 'Hes1', 'Lhx9', 'Neurod6', # Foxg1,Hopx,Hes1
    #'Cplx3', 'Nxph3', 'Npr3', 'Foxp2', 'Fezf2',
    #'Cdc14a', 'Satb2', 'Meis2', 'Cux2', 'Satb2',
    #'Kcnh3', 'Scn5a', 'Adcyap1', 'Csmd3', 'Trhr', 'B130024G19Rik',
    'Rbp4', 'Lhx6', 'Reln',
    #'Prdm12', 'Hs3st2', 'Shisa9', 'Runx1t1',
    'Rbp4', 'Lhx6', 'Zfp804b',
    #'Ano1', 'Myo5b', 'Zeb2', 'Slc18a2'
] 

In [None]:
#sc.tl.dendrogram(adata_subset, use_rep = "X_scVI", groupby = 'cell_type')
sc.pl.matrixplot(adata_subset, var_names=['Slc17a7', 'Slc17a6', 'Gad1', 'Gad2'], groupby="cell_type", dendrogram = True)

In [None]:
marker_genes_dict = {
    "Excitatory neuron": ["Slc17a6", "Slc17a7"],
    "Inhibitory neuron": ["Gad1", "Gad2"],
}

In [None]:
ax = sc.pl.heatmap(
    adata,
    marker_genes_dict,
    groupby="cluster",
    cmap="turbo",
    dendrogram=False,
    swap_axes = True,
    vmax=1.5
)

In [None]:
genes = [
    'Prox1', 'Itpka', 'C1ql3', "Glis3", "Egr3",
    'Fibcd1', 'Spink8', 'Zfhx4', 'Neurod6', 'Bcl6', # Fezf2, Satb2
    'Tspan18', 'Slc17a7', 'Hes1', 'Lhx9', 'Neurod6', # Foxg1,Hopx,Hes1
    #'Cplx3', 'Nxph3', 'Npr3', 'Foxp2', 'Fezf2',
    #'Cdc14a', 'Satb2', 'Meis2', 'Cux2', 'Satb2',
    #'Kcnh3', 'Scn5a', 'Adcyap1', 'Csmd3', 'Trhr', 'B130024G19Rik',
    'Rbp4', 'Lhx6', 'Reln',
    #'Prdm12', 'Hs3st2', 'Shisa9', 'Runx1t1',
    'Rbp4', 'Lhx6', 'Zfp804b',
    #'Ano1', 'Myo5b', 'Zeb2', 'Slc18a2'
] 

In [None]:
marker_genes_dict = {
    "Excitatory neuron": ["Slc17a6", "Slc17a7"],
    "Inhibitory neuron": ["Gad1", "Gad2"],
    "037 DG Glut": ['Prox1', 'Itpka', 'C1ql3', "Glis3", "Egr3",],
    "016 CA1-ProS Glut": ['Neurod6', 'Fibcd1', 'Spink8', 'Satb2', 'Bcl6'], #Zfhx4
    "017 CA3 Glut": ['Tspan18', 'Hes1', 'Lhx9', 'Hopx'], #'Slc17a7', 
    "053 Sst Gaba": ['Lhx6', 'Reln', 'Mafb', 'Rbp4', 'Npas1'], # Rbp4,Lhx6,Reln	Mafb,Npas1,Sox6,Zeb2
    "052 Pvalb Gaba": ['Zfp804b', 'Sox6', 'Zeb2'],
    "050 Lamp5 Lhx6 Gaba": [],
    "046 Vip Gaba": [],

}

In [None]:
sc.pl.heatmap(adata_subset, var_names=marker_genes_dict, 
              groupby="cell_type", standard_scale="var",
             vmax = 0.5, cmap = 'turbo')

In [None]:
sc.tl.rank_genes_groups(
    adata_subset,
    groupby='cell_type',
    method='wilcoxon',
    key_added='rank_genes_celltype'  # so we don't overwrite default
)

adata_subset.uns['rank_genes_celltype']

In [None]:
marker_genes_dict = {}

result = adata_subset.uns['rank_genes_celltype']
groups = result['names'].dtype.names

for group in groups:
    top_genes = result['names'][group][:10].tolist()
    marker_genes_dict[group] = top_genes

In [None]:
all_marker_genes = sum(marker_genes_dict.values(), [])

top_n = 50
top_marker_genes = all_marker_genes[:top_n]

sc.pl.heatmap(adata_subset, var_names=top_marker_genes, 
              groupby="cell_type", standard_scale="var",
              vmax=1, cmap='turbo',
             save = 'subtype_marker_genes.pdf')

In [None]:
sc.pl.heatmap(adata_subset, var_names=['Slc17a6', 'Slc17a7', 'Gad1', 'Gad2'],
              groupby="cell_type", standard_scale="var",
              vmax=1, cmap='turbo',
             save = 'excitatory_inhibitory_markers_by_subtype.pdf')

In [None]:
adata.X[:5, :5]

## Hippocampal cell-type composition (donut plot)

In [None]:
cell_type_counts = adata_subset.obs['cell_type'].value_counts()
cell_types = cell_type_counts.index.tolist()

palette = sns.color_palette("Paired", n_colors=len(cell_types))
cell_type_colors = [matplotlib.colors.to_hex(c) for c in palette]

# Function to show % only if > 5%
def autopct_format(pct):
    return ('%1.0f%%' % pct) if pct > 3 else ''

fig, ax = plt.subplots(figsize=(6, 6))

# Outer ring
ax.pie(
    cell_type_counts,
    radius=1.0,
    colors=cell_type_colors,
    startangle=160,
    wedgeprops=dict(width=0.3, edgecolor='white'),
    labels = cell_types[:7] + [''] * (len(cell_types) - 7),
    labeldistance=1.05,
    autopct=autopct_format,    # use custom function
    pctdistance=0.85,
    textprops={'fontsize': 11}
)

# White center circle
centre = plt.Circle((0, 0), 0.4, fc='white')
ax.add_artist(centre)

ax.set_title('Cell Type Composition', fontsize=16)
plt.tight_layout()
export_path = fig_dir / "hippocampal_only_subtype_donut.pdf"
fig.savefig(export_path, bbox_inches='tight')
plt.show()

In [None]:
handles = [mpatches.Patch(color=col, label=cat) for col, cat in zip(cell_type_colors, cell_types)]

# Plot the legend only
fig, ax = plt.subplots(figsize=(6, len(cell_types) * 0.3)) 
ax.legend(handles=handles, loc='center', frameon=False, ncol=1)
ax.axis('off')
plt.tight_layout()
export_path = fig_dir / "hippocampal_only_subtype_donut_legend.pdf"
fig.savefig(export_path, bbox_inches='tight')
plt.show()

In [None]:
# Excitatory (Slc17a6, Slc17a7) and inhibitory (Gad1, Gad2) markers, then the
# input file's original annotation preserved in 'cell_class'. One figure each.
for umap_key in ['Slc17a6', 'Slc17a7', 'Gad1', 'Gad2', 'cell_class']:
    sc.pl.umap(adata, color=umap_key, frameon=False)

## Hippocampal-neuron gene markers by group

In [None]:
genes = ['Prox1', 'Synpr', 'C1ql2', 'C1ql3', 'Camk2a', 'Camk2b', 'Tmem108', 'Ppfia2', 'Rfx3', 'Lrrtm4', 'Btbd9', 'Cntnap5a', 'Erc2']

In [None]:
sc.pl.matrixplot(adata_hippocampus_only, var_names = genes, groupby = 'group', standard_scale = 'var', 
                 colorbar_title = 'Scaled to gene',
                cmap = 'rocket')

In [None]:
sc.pl.matrixplot(adata_hippocampus_only, var_names = genes, groupby = 'cell_type', standard_scale = 'var', 
                 colorbar_title = 'Scaled to gene',
                cmap = 'viridis',
                swap_axes = True)

## Differential expression (Wilcoxon)

In [None]:
# WILCOXON DE FUNCTION — TEST ON ALL GENES

def run_wilcoxon_de_analysis(adata, cell_type, reference_group, comparison_group, min_cells_per_group=40, max_imbalance_ratio=0.9):
    """Wilcoxon DE of ``comparison_group`` vs ``reference_group`` within one cell type.

    Cells are selected where ``adata.obs['cell_type'] == cell_type`` and
    ``adata.obs['group']`` is one of the two groups. Instead of testing, a
    "Skipped: ..." message string is returned when the selection is empty, either
    group has no cells, either group has fewer than ``min_cells_per_group`` cells,
    or the larger group holds more than ``max_imbalance_ratio`` of the cells.

    Returns
    -------
    pandas.DataFrame or str
        One row per gene with columns gene, pvals_adj, log2fc, cell_type and
        comparison (scanpy's rank_genes_groups output for comparison_group with
        reference_group as reference), or a skip message.
    """
    pair = [reference_group, comparison_group]
    label = f"{reference_group} vs {comparison_group}"

    in_cell_type = adata.obs['cell_type'] == cell_type
    in_pair = adata.obs['group'].isin(pair)
    subset_data = adata[in_cell_type & in_pair].copy()

    if subset_data.shape[0] == 0:
        return f"Skipped: {label} (no cells of this type)"

    # Cells per group in [reference, comparison] order; a missing group counts as 0.
    n_per_group = subset_data.obs['group'].value_counts().reindex(pair).fillna(0)

    if (n_per_group == 0).any():
        return f"Skipped: {label} (no cells in groups)"
    if (n_per_group < min_cells_per_group).any():
        return f"Skipped: {label} (n is too low)"
    if n_per_group.max() / n_per_group.sum() > max_imbalance_ratio:
        return f"Skipped: {label} (imbalance)"

    sc.tl.rank_genes_groups(
        subset_data,
        groupby='group',
        groups=[comparison_group],
        reference=reference_group,
        method='wilcoxon',
        n_genes=subset_data.shape[1],
    )
    ranked = subset_data.uns['rank_genes_groups']

    # Each result field is a structured array with one named column per tested group.
    def _column(field):
        return pd.DataFrame(ranked[field])[comparison_group].values

    return pd.DataFrame({
        'gene': _column('names'),
        'pvals_adj': _column('pvals_adj'),
        'log2fc': _column('logfoldchanges'),
        'cell_type': cell_type,
        'comparison': f"{reference_group}_vs_{comparison_group}",
    })

In [None]:
adata.obs.side.value_counts()

In [None]:
adata.obs.group.value_counts()

In [None]:
adata = adata[adata.obs['side'] == 'Ipsilateral'].copy()
adata.obs.group.value_counts()

In [None]:
# RUN DE WITH WILCOXON

cell_types = [
    '037 DG Glut', '016 CA1-ProS Glut', '017 CA3 Glut', 
    '025 CA2-FC-IG Glut', '023 SUB-ProS Glut', '033 NP SUB Glut', 
    '053 Sst Gaba', '052 Pvalb Gaba', '050 Lamp5 Lhx6 Gaba', '046 Vip Gaba'
]

comparisons = [('Sham-GFP', 'TBI-GFP'), ('TBI-GFP', 'TBI-VEGFC'), ('Sham-GFP', 'Sham-VEGFC')]

all_wilcoxon_results = []

for cell_type in cell_types:
    for reference_group, comparison_group in comparisons:
        try:
            result = run_wilcoxon_de_analysis(
                adata=adata,
                cell_type=cell_type,
                reference_group=reference_group,
                comparison_group=comparison_group
            )
            
            # Safely handle result
            if isinstance(result, pd.DataFrame):
                all_wilcoxon_results.append(result)
            else:
                print(f"Skipping {cell_type} {reference_group} vs {comparison_group}: {result}")
        
        except ValueError as e:
            print(f"Error in {cell_type} {reference_group} vs {comparison_group}: {e}")
            print("Skipping this comparison.")

if len(all_wilcoxon_results) > 0:
    final_wilcoxon_results = pd.concat(all_wilcoxon_results, ignore_index=True)
    print("Final DE results shape:", final_wilcoxon_results.shape)
else:
    print("No DE results were generated.")

In [None]:
adata.obs.group.value_counts()

In [None]:
# Group by side + group + group_id, and count
grouping = adata.obs.groupby(['side', 'group', 'group_id']).size().reset_index(name='n_cells')

# Show result
display(grouping)

In [None]:
#final_wilcoxon_results.to_csv(os.path.join(outs, '2025-06-07_neuron-wilcoxon-DE.csv'))

In [None]:
padj_threshold = 0.05
log2fc_threshold = 0.25

filtered_wilcoxon_results = final_wilcoxon_results[
    (final_wilcoxon_results['pvals_adj'] < padj_threshold) &
    (abs(final_wilcoxon_results['log2fc']) >= log2fc_threshold)
]

print(filtered_wilcoxon_results['comparison'].value_counts())  # How many DEGs per comparison

In [None]:
final_wilcoxon_results.comparison.value_counts()

In [None]:
comparisons = ['Sham-GFP_vs_TBI-GFP', 'TBI-GFP_vs_TBI-VEGFC', 'Sham-GFP_vs_Sham-VEGFC']

# Significance thresholds (same rule as the up/down tallies: padj < 0.05, |log2FC| > 0.25).
# The DE table has no 'mean' column, so no expression filter is applied.
padj_threshold = 0.05
log2fc_threshold = 0.25

# Rows: cell types that went through DE, in adata's category order (most abundant
# first). Without this restriction, untested Border/Input types would take rows in
# head(10) and read as "0 DEGs".
tested_cell_types = set(final_wilcoxon_results['cell_type'])
order = [c for c in adata.obs['cell_type'].cat.categories if c in tested_cell_types]

significant_degs = final_wilcoxon_results[
    (final_wilcoxon_results['pvals_adj'] < padj_threshold)
    & (final_wilcoxon_results['log2fc'].abs() > log2fc_threshold)
    & final_wilcoxon_results['comparison'].isin(comparisons)
]

# DEG counts per (cell_type, comparison); absent combinations become 0.
deg_matrix = (
    significant_degs.groupby(['cell_type', 'comparison']).size()
    .unstack(fill_value=0)
    .reindex(index=order, columns=comparisons, fill_value=0)
)
combined_df = deg_matrix.rename_axis(index='cell_type', columns=None).reset_index()

# Plot heatmap (top 10 rows)
plt.figure(figsize=(6, 5))
sns.heatmap(deg_matrix.head(10), annot=True, cmap='viridis', fmt="g")

plt.title('')
plt.ylabel('')
plt.xlabel('')
plt.grid(False)
plt.tight_layout()
plt.show()

In [None]:
from IPython.display import display

# Per-gene DE table for one cell type across the three comparisons, using the same
# significance rule as the DEG heatmaps (padj < 0.05 and |log2FC| > 0.25).
padj_thresh = 0.05
lfc_thresh  = 0.25
cluster     = "016 CA1-ProS Glut"
comps = [
    'Sham-GFP_vs_TBI-GFP',
    'TBI-GFP_vs_TBI-VEGFC',
    'Sham-GFP_vs_Sham-VEGFC'
]
top_n = 20  # rows shown per comparison

df = (
    final_wilcoxon_results
    .query("cell_type == @cluster and comparison in @comps")
    .copy()
)
df['neglog10_p'] = -np.log10(df['pvals_adj'] + 1e-300)  # offset avoids log10(0)

# Mark significance
df['sig'] = np.where(
    (df['pvals_adj'] < padj_thresh) & (df['log2fc'].abs() > lfc_thresh),
    np.where(df['log2fc'] > 0, 'up', 'down'),
    'ns'
)

# Within each comparison: significant genes (up or down) first, then 'ns'; most
# significant first inside each tier. A descending string sort on 'sig'
# ('up' > 'ns' > 'down') would rank every down-regulated gene below all 'ns' genes,
# keeping them out of the top-N view.
df['_is_ns'] = df['sig'].eq('ns')
table_cols = ['comparison', 'gene', 'log2fc', 'pvals_adj', 'sig']
table = (
    df
    .sort_values(['comparison', '_is_ns', 'neglog10_p'], ascending=[True, True, False])
    [table_cols]
)

# Top `top_n` rows per comparison
top_table = (
    table
    .groupby('comparison')
    .head(top_n)
    .reset_index(drop=True)
)

display(top_table)

In [None]:
genes = ['Arpp21', 'R3hdm1', 'Rorb', 'Cux1', 'Cux2', 'Brinp3', 'Mef2c', 'Zbtb20']

sc.pl.matrixplot(adata_subset, 
                 var_names = genes, 
                 groupby = 'group', 
                 standard_scale = 'var')

In [None]:
sc.pl.dotplot(adata_subset, 
                 var_names = genes, 
                 groupby = 'group', 
                 standard_scale = 'var')

In [None]:
final_wilcoxon_results

In [None]:
#df_check[(df_check['log2fc'] >= log2fc_threshold)]

In [None]:
# PARAMETERS
log2fc_threshold = 0.25
qval_threshold = 0.05

# Every comparison × cell type present in the DE table
comparisons = final_wilcoxon_results['comparison'].unique()
cell_types = final_wilcoxon_results['cell_type'].unique()

# Rows passing the adjusted-p cut-off; direction is decided per pair below.
below_q = final_wilcoxon_results[final_wilcoxon_results['pvals_adj'] < qval_threshold]


def _tally(comp, ctype):
    """Up/down DEG counts for one (comparison, cell type) pair."""
    pair_rows = below_q[(below_q['comparison'] == comp) & (below_q['cell_type'] == ctype)]
    return {
        'comparison': comp,
        'cell_type': ctype,
        'upregulated': (pair_rows['log2fc'] > log2fc_threshold).sum(),
        'downregulated': (pair_rows['log2fc'] < -log2fc_threshold).sum(),
    }


tally_list = [_tally(comp, ctype) for comp in comparisons for ctype in cell_types]
tally_df = pd.DataFrame(tally_list)

display(tally_df)

In [None]:
# Get all unique comparisons and cell_types
all_comparisons = tally_df['comparison'].unique()
all_cell_types = tally_df['cell_type'].unique()

# Complete comparison × cell_type grid so absent pairs become explicit zeros.
full_index = pd.MultiIndex.from_product(
    [all_comparisons, all_cell_types],
    names=['comparison', 'cell_type']
)
tally_df_complete = tally_df.set_index(['comparison', 'cell_type']).reindex(full_index).reset_index()

for count_col in ['upregulated', 'downregulated']:
    tally_df_complete[count_col] = tally_df_complete[count_col].fillna(0).astype(int)

tally_df_complete['total_DE_genes'] = tally_df_complete['upregulated'] + tally_df_complete['downregulated']

# plot heatmaps

column_order = [
    'Sham-GFP_vs_TBI-GFP',
    'TBI-GFP_vs_TBI-VEGFC',
    'Sham-GFP_vs_Sham-VEGFC'
]


def _deg_heatmap(value_col, title):
    """Pivot tally_df_complete to cell_type × comparison for `value_col` and show it."""
    matrix = (
        tally_df_complete
        .pivot(index='cell_type', columns='comparison', values=value_col)
        .reindex(columns=column_order)
    )
    plt.figure(figsize=(5, 5))
    sns.heatmap(matrix, annot=True, cmap='rocket_r', fmt='g')
    plt.title(title)
    plt.ylabel('')
    plt.xlabel('')
    plt.tight_layout()
    plt.show()
    return matrix


heatmap_up = _deg_heatmap('upregulated', 'Upregulated genes')
heatmap_down = _deg_heatmap('downregulated', 'Downregulated genes')
heatmap_total = _deg_heatmap('total_DE_genes', 'Total DE genes (Up + Down)')

In [None]:
# Export the up / down / total DEG heatmaps to PDF. bbox_inches='tight' matches the
# other figure exports in this notebook so long cell-type labels are not clipped.
tally_df_complete['total_DE_genes'] = tally_df_complete['upregulated'] + tally_df_complete['downregulated']

heatmap_specs = [
    ('upregulated', 'Upregulated genes', "heatmap_upregulated.pdf"),
    ('downregulated', 'Downregulated genes', "heatmap_downregulated.pdf"),
    ('total_DE_genes', 'Total DE genes (Up + Down)', "heatmap_total_DEGs.pdf"),
]

heatmaps = {}
for value_col, title, filename in heatmap_specs:
    matrix = (
        tally_df_complete
        .pivot(index='cell_type', columns='comparison', values=value_col)
        .reindex(columns=column_order)
    )
    fig_hm, ax_hm = plt.subplots(figsize=(5, 5))
    sns.heatmap(matrix, annot=True, cmap='rocket_r', fmt='g', ax=ax_hm)
    ax_hm.set_title(title)
    ax_hm.set_ylabel('')
    ax_hm.set_xlabel('')
    fig_hm.tight_layout()
    fig_hm.savefig(fig_dir / filename, bbox_inches='tight')
    plt.show()
    heatmaps[value_col] = matrix

heatmap_up = heatmaps['upregulated']
heatmap_down = heatmaps['downregulated']
heatmap_total = heatmaps['total_DE_genes']

In [None]:
barplot_total_df = tally_df_complete[['comparison', 'cell_type', 'total_DE_genes']].copy()

# Total DEGs per cell type, one bar per comparison. The figure is created
# explicitly and that same figure is saved. A bare `fig.savefig` after plt.figure()
# would write whichever figure was last bound to `fig` earlier in the notebook.
fig, ax = plt.subplots(figsize=(3, 2))

sns.barplot(
    data=barplot_total_df,
    y='cell_type',
    x='total_DE_genes',
    hue='comparison',
    dodge=True,
    ci=None,
    palette='husl',
    ax=ax,
)

ax.set_xlabel('Total # of DE Genes (Up + Down)')
ax.set_ylabel('')
ax.set_title('# DEGs per Cell Type Comparison')

ax.legend(title='Comparison', frameon=False, bbox_to_anchor=(1.05, 1), loc='upper left')
sns.despine(ax=ax)
fig.tight_layout()
export_path = fig_dir / "DEGs_barplot.pdf"
fig.savefig(export_path, bbox_inches='tight')
plt.show()