In [None]:
import os
from pathlib import Path

import numpy as np
import pandas as pd
import anndata as ad

from sklearn.preprocessing import SplineTransformer, StandardScaler
from sklearn.cluster import KMeans
from sklearn.decomposition import PCA
import hdbscan

import matplotlib.pyplot as plt
import seaborn as sns

from spida.utilities._ad_utils import normalize_adata

In [None]:
# Notebook-wide matplotlib defaults, applied in a single update call.
plt.rcParams.update({
    # Resolution for on-screen rendering and for saved figures.
    'figure.dpi': 300,
    'savefig.dpi': 300,
    'figure.figsize': (4, 4),
    # Font type 42 (TrueType) for PDF / PostScript output.
    'pdf.fonttype': 42,
    'ps.fonttype': 42,
    # Typography.
    'font.family': 'sans-serif',
    'font.sans-serif': 'Arial',
    'font.size': 8,
    'axes.facecolor': 'white',
})

In [None]:
# Location of the AnnData (.h5ad) file analysed in this notebook.
# NOTE(review): absolute, user-specific path — adjust before running on another machine.
ad_path = "/home/x-aklein2/projects/aklein/BICAN/BG/data/BICAN_BG_CPS.h5ad"
adata = ad.read_h5ad(ad_path)
# Last expression of the cell: displays the AnnData summary.
adata

# Functions

## Downsampling (stratified by cell type and spatial score)

In [None]:
def downsample_by_celltype_and_spatial(
    meta,
    cell_type_col="cell_type",
    spatial_col="spatial_score",
    max_cells_per_type=3000,
    n_bins=20,
    donor_col=None,
    random_state=0,
):
    """
    Downsample cells per cell type while keeping coverage of the spatial axis.

    Cell types with at most ``max_cells_per_type`` cells are kept whole.
    Larger cell types are split into (up to) ``n_bins`` quantile bins of
    ``spatial_col`` and about ``max_cells_per_type // n_bins_used`` cells are
    drawn without replacement from each bin.

    Parameters
    ----------
    meta : pd.DataFrame
        One row per cell, with at least [cell_type_col, spatial_col].
    cell_type_col : str
        Column holding the cell-type label.
    spatial_col : str
        Column holding the numeric spatial score used for binning.
    max_cells_per_type : int
        Upper bound on the number of cells returned per cell type.
    n_bins : int
        Number of quantile bins requested along the spatial score.
    donor_col : str or None
        If given and present in ``meta``, draws inside an over-full bin are
        split across donors proportionally to each donor's share of the bin.
    random_state : int
        Seed for ``np.random.default_rng``.

    Returns
    -------
    np.ndarray
        Index labels of the selected cells (subset of ``meta.index``).
        Empty if nothing was selected (e.g. ``meta`` has no rows).

    Notes
    -----
    Cells whose spatial score is missing get no bin and are therefore never
    selected for a cell type that needs downsampling. Because of integer
    division and donor rounding, a downsampled cell type may end up with
    slightly fewer than ``max_cells_per_type`` cells.
    """
    rng = np.random.default_rng(random_state)
    selected_idx = []

    # observed=True: when the column is categorical (common for AnnData .obs),
    # do not iterate over categories that have no cells in `meta`.
    for _, df_ct in meta.groupby(cell_type_col, observed=True):
        if len(df_ct) <= max_cells_per_type:
            selected_idx.append(df_ct.index)
            continue

        # Quantile-bin the spatial score; ranking with method="first" breaks
        # ties so that the bin edges are unique.
        q = pd.qcut(df_ct[spatial_col].rank(method="first"),
                    q=n_bins, labels=False, duplicates="drop")
        df_ct = df_ct.assign(_bin=q)

        target_per_bin = max_cells_per_type // df_ct["_bin"].nunique()

        chosen = []
        for _, df_bin in df_ct.groupby("_bin"):
            if len(df_bin) <= target_per_bin:
                chosen.append(df_bin.index)
            elif donor_col is not None and donor_col in df_bin.columns:
                # Split this bin's quota across donors by their share of the bin.
                for _, df_d in df_bin.groupby(donor_col, observed=True):
                    frac = len(df_d) / len(df_bin)
                    n_sample = int(round(frac * target_per_bin))
                    if n_sample > 0:
                        chosen.append(
                            rng.choice(df_d.index, size=n_sample, replace=False)
                        )
            else:
                chosen.append(
                    rng.choice(df_bin.index, size=target_per_bin, replace=False)
                )

        # Every donor quota can round down to zero; np.concatenate([]) would raise.
        if not chosen:
            continue
        chosen = np.concatenate(chosen)

        # Donor rounding can overshoot the cap; trim back down if so.
        if len(chosen) > max_cells_per_type:
            chosen = rng.choice(chosen, size=max_cells_per_type, replace=False)

        selected_idx.append(chosen)

    # Empty `meta` (or nothing selected): return an empty selection instead of
    # letting np.concatenate([]) raise ValueError.
    if not selected_idx:
        return meta.index[:0].to_numpy()

    return np.concatenate(selected_idx)


## B-spline smoothing + k-means clustering

In [None]:
def compute_smoothed_trajectories(expr_ct, scores_ct,
                                  n_knots=6, degree=3,
                                  grid_points=50):
    """
    Fit one B-spline regression per gene along the spatial score and evaluate
    the fits on a regular grid.

    Parameters
    ----------
    expr_ct : pd.DataFrame
        Cells x genes expression for a single cell type.
    scores_ct : array-like, 1D
        Spatial score of each cell, in the same row order as ``expr_ct``.
    n_knots, degree : int
        Passed to ``SplineTransformer``.
    grid_points : int
        Number of evenly spaced evaluation positions on [-1, 1].

    Returns
    -------
    traj : np.ndarray, shape (genes, grid_points)
        Smoothed expression of each gene on the grid.
    grid : np.ndarray, shape (grid_points,)
        Grid positions along the spatial axis.

    Notes
    -----
    The grid is fixed to [-1, 1].
    NOTE(review): presumably the scores are already normalised to that range;
    grid positions outside the observed score range rely on the spline
    transformer's extrapolation — confirm.
    """
    positions = np.asarray(scores_ct).reshape(-1, 1)   # (n_cells, 1)
    expression = expr_ct.values                        # (n_cells, n_genes)

    # Spline design matrix fitted on the observed cell scores.
    basis = SplineTransformer(n_knots=n_knots, degree=degree, include_bias=True)
    design = basis.fit_transform(positions)            # (n_cells, n_basis)

    # Least-squares coefficients for all genes at once: design @ coef ~ expression.
    coef = np.linalg.lstsq(design, expression, rcond=None)[0]   # (n_basis, n_genes)

    # Evaluate the fitted splines at evenly spaced positions on [-1, 1].
    eval_points = np.linspace(-1, 1, grid_points).reshape(-1, 1)
    fitted = basis.transform(eval_points) @ coef       # (grid_points, n_genes)

    # Return genes along the first axis.
    return fitted.T, eval_points.ravel()


In [None]:
def cluster_gene_trajectories(expr, meta,
                              cell_type,
                              n_clusters=6,
                              min_cells_pct=0.1,
                              n_knots=6,
                              degree=3,
                              grid_points=50,
                              cell_type_col = "cell_type",
                              axes = "spatial_score",
                              scale_trajectories=True,
                              random_state=0):
    """
    K-means clustering of genes by their smoothed spatial trajectory within
    one cell type.

    Parameters
    ----------
    expr : pd.DataFrame
        Cells x genes expression matrix, index aligned to ``meta``.
    meta : pd.DataFrame
        Per-cell metadata containing ``cell_type_col`` and ``axes``.
    cell_type : str
        Value of ``cell_type_col`` selecting the cells to use.
    n_clusters : int
        Number of k-means clusters.
    min_cells_pct : float
        Minimum fraction (0-1, despite the name) of the selected cells with
        expression > 0 for a gene to be kept.
    n_knots, degree, grid_points : int
        Forwarded to ``compute_smoothed_trajectories``.
    cell_type_col, axes : str
        Column names for the cell-type label and the spatial score.
    scale_trajectories : bool
        If True, z-score each gene's trajectory across the grid so that
        clustering is driven by shape.
    random_state : int
        Seed for KMeans.

    Returns
    -------
    cluster_labels : pd.Series
        Indexed by gene name, values are cluster IDs (0..K-1).
    grid : np.ndarray
        1D grid positions.
    traj : np.ndarray
        (genes x grid_points) trajectories used for clustering
        (scaled if ``scale_trajectories``).
    """
    # Cells belonging to the requested cell type.
    in_type = meta[cell_type_col] == cell_type
    scores_ct = meta.loc[in_type, axes]
    expr_ct = expr.loc[in_type]

    # Keep genes detected (expr > 0) in a large enough fraction of those cells.
    detected_frac = (expr_ct > 0).mean(axis=0)
    kept = detected_frac.loc[lambda frac: frac >= min_cells_pct].index
    expr_ct = expr_ct[kept]

    traj, grid = compute_smoothed_trajectories(
        expr_ct, scores_ct,
        n_knots=n_knots, degree=degree, grid_points=grid_points,
    )

    if scale_trajectories:
        # StandardScaler works column-wise, so transpose to put genes in
        # columns, scale, and transpose back to (genes x grid_points).
        zscorer = StandardScaler(with_mean=True, with_std=True)
        traj_scaled = zscorer.fit_transform(traj.T).T
    else:
        traj_scaled = traj

    labels = KMeans(
        n_clusters=n_clusters,
        random_state=random_state,
        n_init='auto',
    ).fit_predict(traj_scaled)

    cluster_labels = pd.Series(labels, index=expr_ct.columns, name='cluster')
    return cluster_labels, grid, traj_scaled


## Runners

In [None]:
def cluster_gene_trajectories_hdbscan(
    expr,
    meta,
    cell_type,
    cell_type_col="cell_type",
    axes="spatial_score",
    n_knots=6,
    degree=3,
    grid_points=50,
    scale_trajectories=True,
    detection_min_frac=0.05,
    n_pcs=10,
    min_cluster_size=20,
    min_samples=None,
    metric="euclidean",
    cluster_selection_epsilon=0.0,
    random_state=0,
):
    """
    HDBSCAN clustering of genes by their smoothed spatial trajectory within
    one cell type.

    Parameters
    ----------
    expr : pd.DataFrame
        Cells x genes expression matrix (normalized/log-transformed),
        index aligned to ``meta``.
    meta : pd.DataFrame
        Per-cell metadata containing ``cell_type_col`` and ``axes``.
    cell_type : str
        Value of ``cell_type_col`` selecting the cells to use.
    cell_type_col, axes : str
        Column names for the cell-type label and the spatial score.
    n_knots, degree, grid_points : int
        Forwarded to ``compute_smoothed_trajectories``.
    scale_trajectories : bool
        If True, z-score each gene's trajectory across the grid
        (shape-based clustering).
    detection_min_frac : float
        Minimum fraction of the selected cells with expr > 0 for a gene
        to be kept.
    n_pcs : int or None
        Number of principal components computed from the trajectories before
        clustering. PCA is skipped when this is None or not smaller than the
        number of grid points.
    min_cluster_size : int
        HDBSCAN ``min_cluster_size``.
    min_samples : int or None
        HDBSCAN ``min_samples``.
    metric : str
        HDBSCAN distance metric (e.g. 'euclidean', 'manhattan').
    cluster_selection_epsilon : float
        HDBSCAN ``cluster_selection_epsilon``.
    random_state : int
        Seed for PCA.

    Returns
    -------
    cluster_labels : pd.Series
        Index = gene names, values = cluster ID (int, -1 for noise).
    grid : np.ndarray
        1D grid positions along the spatial axis.
    traj_scaled : np.ndarray
        (genes x grid_points) trajectories after optional scaling. This is
        the pre-PCA matrix, not the PCA features passed to HDBSCAN.
    clusterer : hdbscan.HDBSCAN
        The fitted HDBSCAN object.
    """
    # Cells belonging to the requested cell type.
    in_type = meta[cell_type_col] == cell_type
    scores_ct = meta.loc[in_type, axes]
    expr_ct = expr.loc[in_type]

    # Drop genes detected (expr > 0) in too small a fraction of those cells.
    detected_frac = (expr_ct > 0).mean(axis=0)
    kept = detected_frac.loc[lambda frac: frac >= detection_min_frac].index
    expr_ct = expr_ct[kept]

    # Spline-smoothed trajectories, shape (n_genes x grid_points).
    traj, grid = compute_smoothed_trajectories(
        expr_ct, scores_ct,
        n_knots=n_knots, degree=degree, grid_points=grid_points,
    )

    if scale_trajectories:
        # StandardScaler works column-wise: transpose so each gene is a
        # column, scale, then transpose back.
        zscorer = StandardScaler(with_mean=True, with_std=True)
        traj_scaled = zscorer.fit_transform(traj.T).T
    else:
        traj_scaled = traj

    # Optional PCA on the trajectories before clustering.
    reduce_dims = n_pcs is not None and n_pcs < traj_scaled.shape[1]
    if reduce_dims:
        features = PCA(
            n_components=n_pcs, random_state=random_state
        ).fit_transform(traj_scaled)            # (n_genes x n_pcs)
    else:
        features = traj_scaled                  # (n_genes x grid_points)

    clusterer = hdbscan.HDBSCAN(
        min_cluster_size=min_cluster_size,
        min_samples=min_samples,
        metric=metric,
        cluster_selection_epsilon=cluster_selection_epsilon,
        core_dist_n_jobs=1,  # single-threaded core-distance computation
    )
    labels = clusterer.fit_predict(features)    # (n_genes,), -1 = noise

    cluster_labels = pd.Series(labels, index=expr_ct.columns, name="cluster")
    return cluster_labels, grid, traj_scaled, clusterer


In [None]:
def plot_cluster_trajectories_kmeans(cluster_labels, grid, traj,
                              max_genes_per_cluster=50,
                              random_state=None):
    """
    Quick visualization: one panel per cluster showing individual gene
    trajectories (faint) and the cluster mean (bold).

    Parameters
    ----------
    cluster_labels : pd.Series
        Index = genes, values = cluster ID, in the same row order as ``traj``.
    grid : np.ndarray, shape (grid_points,)
        Positions along the spatial axis.
    traj : np.ndarray, shape (genes, grid_points)
        Trajectory of each gene.
    max_genes_per_cluster : int
        At most this many individual trajectories are drawn per panel;
        larger clusters are randomly subsampled (the mean still uses all genes).
    random_state : int or None
        Seed for the subsampling. None (default) uses NumPy's global random
        state, as before.

    Notes
    -----
    Panels are created for the labels actually present (sorted), so labels do
    not have to be exactly 0..K-1: gaps and negative labels are plotted under
    their own ID.
    """
    labels = cluster_labels.values
    cluster_ids = np.sort(cluster_labels.unique())
    n_clusters = len(cluster_ids)

    rng = np.random if random_state is None else np.random.default_rng(random_state)

    # squeeze=False always returns a 2D Axes array, so a single cluster needs
    # no special-casing.
    fig, axes = plt.subplots(n_clusters, 1, figsize=(6, 2*n_clusters),
                             sharex=True, squeeze=False)
    axes = axes[:, 0]

    for ax, cid in zip(axes, cluster_ids):
        mask = labels == cid
        gene_idxs = np.where(mask)[0]

        # Subsample the individual trajectories when the cluster is large.
        if len(gene_idxs) > max_genes_per_cluster:
            gene_idxs = rng.choice(gene_idxs, size=max_genes_per_cluster, replace=False)

        # Individual trajectories (light)
        for idx in gene_idxs:
            ax.plot(grid, traj[idx, :], alpha=0.2)

        # Cluster mean over all genes in the cluster
        ax.plot(grid, traj[mask, :].mean(axis=0), linewidth=2)

        ax.set_title(f"Cluster {cid} (n={mask.sum()})")

    axes[-1].set_xlabel("Spatial score")
    plt.tight_layout()
    plt.show()


## Plotting

In [None]:
def plot_cluster_trajectories(cluster_labels, grid, traj,
                              max_genes_per_cluster=50,
                              random_state=None):
    """
    Quick visualization: one panel per cluster showing individual gene
    trajectories (faint) and the cluster mean (bold).

    Parameters
    ----------
    cluster_labels : pd.Series
        Index = genes, values = cluster ID, in the same row order as ``traj``.
    grid : np.ndarray, shape (grid_points,)
        Positions along the spatial axis.
    traj : np.ndarray, shape (genes, grid_points)
        Trajectory of each gene.
    max_genes_per_cluster : int
        At most this many individual trajectories are drawn per panel;
        larger clusters are randomly subsampled (the mean still uses all genes).
    random_state : int or None
        Seed for the subsampling. None (default) uses NumPy's global random
        state, as before.

    Notes
    -----
    Panels are created for the labels actually present (sorted), so labels do
    not have to be exactly 0..K-1: HDBSCAN's noise label (-1) and gaps in the
    numbering are plotted under their own ID.
    """
    labels = cluster_labels.values
    cluster_ids = np.sort(cluster_labels.unique())
    n_clusters = len(cluster_ids)

    rng = np.random if random_state is None else np.random.default_rng(random_state)

    # squeeze=False always returns a 2D Axes array, so a single cluster needs
    # no special-casing.
    fig, axes = plt.subplots(n_clusters, 1, figsize=(6, 2*n_clusters),
                             sharex=True, squeeze=False)
    axes = axes[:, 0]

    for ax, cid in zip(axes, cluster_ids):
        mask = labels == cid
        gene_idxs = np.where(mask)[0]

        # Subsample the individual trajectories when the cluster is large.
        if len(gene_idxs) > max_genes_per_cluster:
            gene_idxs = rng.choice(gene_idxs, size=max_genes_per_cluster, replace=False)

        # Individual trajectories (light)
        for idx in gene_idxs:
            ax.plot(grid, traj[idx, :], alpha=0.2)

        # Cluster mean over all genes in the cluster
        ax.plot(grid, traj[mask, :].mean(axis=0), linewidth=2)

        ax.set_title(f"Cluster {cid} (n={mask.sum()})")

    axes[-1].set_xlabel("Spatial score")
    plt.tight_layout()
    plt.show()


In [None]:
def plot_cluster_mean_sd(cluster_labels, grid, traj, 
                         figsize=(6, 3), 
                         alpha_fill=0.25,
                         linewidth=2):
    """
    Plot mean ± SD trajectories of all clusters on one shared axis.

    Each cluster gets one mean line, a matching ± SD ribbon and a legend
    entry with its gene count.

    Parameters
    ----------
    cluster_labels : pd.Series
        Index = genes, values = cluster ID (int), in the same row order
        as ``traj``.
    grid : np.ndarray, shape (grid_points,)
        Spatial grid on which trajectories are evaluated
    traj : np.ndarray, shape (genes, grid_points)
        Smoothed trajectories for each gene
    figsize : tuple
        Size of the figure
    alpha_fill : float
        Transparency for SD ribbon
    linewidth : float
        Thickness of mean line
    """
    cluster_ids = np.sort(cluster_labels.unique())

    fig, ax = plt.subplots(1, 1, figsize=(figsize[0], figsize[1]))

    for cid in cluster_ids:
        mask = (cluster_labels == cid).values
        cluster_traj = traj[mask, :]  # (n_genes_in_cluster x grid_points)

        # Mean and SD across the genes of this cluster, per grid position
        mean_traj = cluster_traj.mean(axis=0)
        sd_traj = cluster_traj.std(axis=0)

        # All clusters share one axis, so the per-cluster information goes
        # into the legend label: a per-cluster title would be overwritten on
        # every iteration.
        (line,) = ax.plot(grid, mean_traj, linewidth=linewidth,
                          label=f"Cluster {cid} (n = {mask.sum()} genes)")

        # SD band in the same color as its mean line
        ax.fill_between(grid,
                        mean_traj - sd_traj,
                        mean_traj + sd_traj,
                        color=line.get_color(),
                        alpha=alpha_fill)

    ax.axhline(0, color='gray', linestyle='--', linewidth=1)
    ax.set_xlabel("Spatial axis")
    ax.set_ylabel("Expression")
    ax.set_title("Cluster Trajectories (Mean ± SD)")
    ax.legend(title="Clusters", loc="best")
    plt.tight_layout()
    plt.show()


In [None]:
def plot_clusters_mean_sem(
    cluster_labels,
    grid,
    traj,
    figsize=(7, 5),
    alpha_fill=0.25,
    linewidth=2,
    cmap="tab10"
):
    """
    Plot all cluster mean trajectories on one combined axis,
    with mean +/- SEM shading and automatically assigned colors.

    Parameters
    ----------
    cluster_labels : pd.Series
        Index = genes, values = cluster ID, in the same row order as ``traj``.
    grid : np.ndarray, shape (grid_points,)
        Spatial grid for trajectory evaluation
    traj : np.ndarray, shape (genes, grid_points)
        Trajectories for each gene
    figsize : tuple
        Figure size
    alpha_fill : float
        Transparency for the SEM band
    linewidth : float
        Thickness of the mean line
    cmap : str or Colormap
        Colormap passed to ``plt.get_cmap``. Cluster ``i`` (in sorted label
        order) gets color entry ``i % cmap.N``, so colors repeat only once
        the palette is exhausted. A qualitative palette such as "tab10" or
        "tab20" is the intended choice; with a continuous colormap
        neighbouring entries are nearly identical.
    """
    fig, ax = plt.subplots(figsize=figsize)

    cluster_ids = np.sort(cluster_labels.unique())
    palette = plt.get_cmap(cmap)
    # Number of entries in the palette (10 for the default "tab10"); used
    # instead of a hard-coded 10 so other palettes are cycled correctly.
    n_colors = palette.N

    for i, cid in enumerate(cluster_ids):
        mask = (cluster_labels == cid).values
        cluster_traj = traj[mask, :]

        # Mean and standard error across the genes of this cluster
        mean_traj = cluster_traj.mean(axis=0)
        sem_traj = cluster_traj.std(axis=0) / np.sqrt(cluster_traj.shape[0])

        color = palette(i % n_colors)

        # Plot mean
        ax.plot(grid, mean_traj, color=color, linewidth=linewidth, label=f"Cluster {cid}")

        # Plot SEM band
        ax.fill_between(
            grid,
            mean_traj - sem_traj,
            mean_traj + sem_traj,
            color=color,
            alpha=alpha_fill
        )

    ax.axhline(0, color='gray', linestyle='--', linewidth=1)
    ax.set_xlabel("Spatial axis")
    ax.set_ylabel("Expression (smoothed)")
    ax.set_title("Cluster Trajectories (Mean ± SEM)")
    ax.legend(title="Clusters", loc="best")

    plt.tight_layout()
    plt.show()


# Run

## HDBSCAN

In [None]:
# --- Select cells and build the expression / metadata tables ---
use_layer = "counts"
cell_type = "CN ST18 GABA"  # example
use_br = ['CAH', 'CAB', 'NAC', 'PU']

# Cells of the chosen subclass, with a spatial score, in the selected brain regions.
keep_cells = (
    (adata.obs['Subclass'] == cell_type)
    & (~adata.obs['MS_NORM'].isna())
    & (adata.obs['brain_region'].isin(use_br))
)
# Explicit copy: subsetting an AnnData returns a view, and assigning .X on a
# view forces an implicit copy (with a warning). Copying up front makes the
# subset independent of `adata`.
at = adata[keep_cells].copy()

# Start from the chosen layer; .copy() so that changes to .X leave the layer untouched.
at.X = at.layers[use_layer].copy()
# NOTE(review): normalize_adata is a project helper; presumably it normalises at.X in place — confirm.
normalize_adata(at)

# Dense cells x genes table (assumes at.X is a sparse matrix — TODO confirm).
expr = pd.DataFrame(at.X.toarray(), columns=at.var_names, index=at.obs_names)
meta = at.obs.copy()

# Downsample to at most 3000 cells, spread across quantile bins of MS_NORM.
sel = downsample_by_celltype_and_spatial(meta, max_cells_per_type=3000, cell_type_col="Subclass", spatial_col="MS_NORM")
expr_ds = expr.loc[sel]
meta_ds = meta.loc[sel]

In [None]:
# Cluster genes by the shape of their spline-smoothed MS_NORM trajectory with HDBSCAN
# (trajectories are z-scored per gene and reduced to 10 PCs before clustering).
cluster_labels, grid, traj_scaled, clusterer = cluster_gene_trajectories_hdbscan(
    expr_ds,
    meta_ds,
    cell_type=cell_type,
    cell_type_col="Subclass",
    axes="MS_NORM",
    n_knots=4,
    degree=3,
    grid_points=50,
    scale_trajectories=True,
    detection_min_frac=0.1,
    n_pcs=10,
    min_cluster_size=10,
    min_samples=None,   # defaults to min_cluster_size internally
    metric="manhattan",
)

In [None]:
# Number of genes per HDBSCAN cluster (label -1 = noise).
cluster_labels.value_counts()

In [None]:
# Optional: drop noise genes (label = -1) before plotting.
# Currently disabled, so the noise group is drawn as its own "Cluster -1".
# non_noise = cluster_labels != -1
# cluster_labels_nn = cluster_labels[non_noise]
# traj_nn = traj_scaled[non_noise.values, :]

plot_clusters_mean_sem(cluster_labels, grid, traj_scaled)

In [None]:
# MSNs: HDBSCAN cluster assignment of selected genes.
# NOTE: raises KeyError if a gene was removed by the detection filter.
cluster_labels['PDYN'], cluster_labels['KIRREL3'], cluster_labels['BACH2']

In [None]:
# HDBSCAN cluster assignment of a second set of selected genes
# (KeyError if a gene was removed by the detection filter).
cluster_labels['PDYN'], cluster_labels['KIRREL3'], cluster_labels['GALNT17'], cluster_labels['GLP1R'],  cluster_labels['RASGRF2']

In [None]:
# Mean ± SD trajectory of every HDBSCAN cluster (noise label -1 included) on one axis.
plot_cluster_mean_sd(cluster_labels, grid, traj_scaled)

In [None]:
# Preview of the labels shifted by +1, so the noise label (-1) becomes 0.
cluster_labels + 1

In [None]:
# Per-cluster panels; labels are shifted by +1 so the HDBSCAN noise group (-1)
# is shown as "Cluster 0" and all labels are non-negative.
plot_cluster_trajectories(cluster_labels+1, grid, traj_scaled)

## K-means

In [None]:
# K-means (6 clusters) on the same downsampled cells, keeping genes detected in >= 20% of cells.
# Note: this rebinds `grid`, which the HDBSCAN cells above also use.
labels, grid, traj = cluster_gene_trajectories(
    expr_ds, meta_ds, cell_type, cell_type_col="Subclass", axes="MS_NORM",
    n_clusters=6, min_cells_pct=0.2, n_knots=4, degree=3, grid_points=50,
    scale_trajectories=True
)

In [None]:
# Mean ± SD trajectory of every k-means cluster on one axis.
plot_cluster_mean_sd(labels, grid, traj)

In [None]:
# Mean ± SEM trajectory of every k-means cluster, color-coded with a legend.
plot_clusters_mean_sem(labels, grid, traj)

In [None]:
# Genes assigned to k-means cluster 1.
labels[labels == 1].index

In [None]:
# Genes assigned to k-means cluster 5.
labels[labels == 5].index

In [None]:
# K-means cluster assignment of selected genes
# (KeyError if a gene was removed by the detection filter).
labels['PDYN'], labels['KIRREL3'], labels['GALNT17'], labels['GLP1R'],  labels['RASGRF2']

In [None]:
# Per-cluster panels of individual gene trajectories plus the cluster mean (k-means labels).
plot_cluster_trajectories(labels, grid, traj)