> **Note:** Run this notebook from the cloned `SummerSchool` repository using the `SummerSchool` conda environment described in the [README](README.md).

# Dimensionality Reduction & Clustering for Calcium Imaging

Modern calcium imaging experiments record population activity from hundreds of neurons with sub-second temporal resolution. The raw data live in a very high-dimensional space where each neuron contributes one axis, which makes it difficult to reason about population dynamics or coordinated patterns directly from the raw traces. In this notebook we will follow a practical workflow that reduces the dimensionality of the data, discovers structure through clustering, and interprets the resulting groups in neuroscientifically meaningful terms.

Along the way we will emphasize not just *how* to run the code, but *why* each step matters. Dimensionality reduction provides low-dimensional views that distill population activity into latent factors, while clustering reveals candidate functional ensembles. Both are indispensable tools when bridging raw fluorescence traces and hypotheses about circuit organization.

## Learning objectives
1. Motivate dimensionality reduction for neural population data and understand what information it preserves or discards.
2. Compare linear methods such as PCA with non-linear manifold learning techniques (t-SNE, UMAP) and know when to reach for each.
3. Apply correlation- and embedding-based clustering strategies to organize neurons into putative functional groups.
4. Evaluate cluster stability and biological plausibility so that downstream interpretations rest on solid evidence.

---

In [None]:
# Dimensionality Reduction & Clustering for Calcium Imaging
# Simplified Lecture Version

# --- Standard Library ---
import warnings

# --- Scientific Libraries ---
import numpy as np
import matplotlib.pyplot as plt
from matplotlib import cm
import seaborn as sns
from scipy.stats import zscore
from scipy.io import loadmat
from scipy.cluster.hierarchy import dendrogram, linkage
from scipy.spatial.distance import pdist, squareform
from umap import UMAP


# --- Machine Learning ---
from sklearn.decomposition import PCA
from sklearn.manifold import TSNE
from sklearn.cluster import AgglomerativeClustering
from hdbscan import HDBSCAN

warnings.filterwarnings("ignore")

# Set style: apply the seaborn theme first, because sns.set_theme() resets rcParams
# (fonts, sizes) and would otherwise overwrite the custom settings below
sns.set_theme(style="white", context="talk", palette="Set2")
plt.rcParams.update({
    "font.family": "serif",
    "font.serif": ["Computer Modern Roman", "Times New Roman", "DejaVu Serif"],
    "text.usetex": False,
    "axes.labelsize": 13,
    "axes.titlesize": 15,
    "legend.fontsize": 11,
    "xtick.labelsize": 11,
    "ytick.labelsize": 11,
    "axes.linewidth": 0.5,
    "lines.linewidth": 0.5,
})

print("=== DIMENSIONALITY REDUCTION & CLUSTERING FOR CALCIUM IMAGING ===")


# 1. INTRODUCTION & DATA LOADING

### Why start with dimensionality reduction?
Calcium imaging datasets are matrices of shape `neurons × time`. Each neuron defines a dimension, so even a modest recording of 500 cells lives in a 500-D space. Such spaces cannot be visualized directly, and the curse of dimensionality makes distances concentrate: for noisy, unstructured data, pairwise distances become nearly equal, so "near" and "far" lose much of their meaning. Dimensionality reduction seeks a low-dimensional latent space that captures the dominant co-activation patterns, providing both intuition and a staging ground for downstream analyses.
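A quick way to see distance concentration is to sample random points and compare the nearest and farthest distances from one query point as the number of dimensions grows. The sketch below uses i.i.d. Gaussian points as a stand-in for unstructured noise (an assumption, not calcium data); the helper `relative_contrast` is hypothetical.

```python
import numpy as np

rng = np.random.default_rng(0)


def relative_contrast(n_points, n_dims):
    """(farthest - nearest) / nearest distance from one query point to the others."""
    X = rng.standard_normal((n_points, n_dims))  # hypothetical i.i.d. Gaussian cloud
    d = np.linalg.norm(X[1:] - X[0], axis=1)     # distances from point 0 to all others
    return (d.max() - d.min()) / d.min()


for dims in (2, 50, 500):
    print(f"{dims:>4} dims: relative contrast = {relative_contrast(200, dims):.2f}")
```

The contrast should shrink sharply as dimensions increase. Real recordings contain shared structure and concentrate less, which is part of what dimensionality reduction exploits.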

### Workflow for this notebook
1. **Load and align the data** – bring neural activity and metadata into a tidy format, handling units and time axes carefully.
2. **Preprocess** – detrend, normalize, or smooth the traces to reduce measurement artifacts that could dominate the variance.
3. **Reduce dimensionality** – extract latent axes that summarize the dominant shared dynamics across neurons.
4. **Cluster** – group neurons with similar signatures to propose functional assemblies or cell-type groupings.
5. **Interpret** – inspect the latent spaces and cluster assignments to generate hypotheses about circuit organization.

Each step feeds the next: poor preprocessing contaminates PCA, and noisy embeddings derail clustering. Treat the pipeline holistically rather than as a sequence of independent buttons to press.
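Step 2 can be sketched with standard SciPy tools. The function below is a hypothetical helper, not the notebook's own preprocessing (which later uses ΔF/F0 or median subtraction): it removes a linear trend per neuron, smooths each trace with a Gaussian kernel, and z-scores it. The synthetic drift and the `smooth_sigma` value are illustrative assumptions.

```python
import numpy as np
from scipy.signal import detrend
from scipy.ndimage import gaussian_filter1d
from scipy.stats import zscore


def preprocess(traces, smooth_sigma=2.0):
    """Hypothetical step-2 sketch for a (neurons x time) array.

    smooth_sigma is in samples (frames) and is an illustrative choice.
    """
    x = detrend(traces, axis=1, type="linear")            # remove linear drift per neuron
    x = gaussian_filter1d(x, sigma=smooth_sigma, axis=1)  # temporal smoothing
    return zscore(x, axis=1)                              # zero mean, unit variance per neuron


rng = np.random.default_rng(1)
t = np.arange(1000)
drift = 0.01 * t                                  # synthetic bleaching-like drift (assumed)
traces = rng.standard_normal((5, t.size)) + drift
clean = preprocess(traces)
print(clean.shape, clean.mean(axis=1).round(3), clean.std(axis=1).round(3))
```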

---

In [None]:
from scipy.io import loadmat

print("1. Loading Sample Data...")

np.random.seed(42)

# Selector for choosing the dataset
dataset_choice = "population"  # Change to "dendrites" to use the other dataset

if dataset_choice == "dendrites":
    # Load the .mat file for the dendritic-tree dataset (variables stored at the top level of the file)
    mat_path = "Data/M1_dendritic_tree_data.mat"  # Update this path to the location of your .mat file
    mat_data = loadmat(mat_path)

    # Extract the variables
    behaviour = mat_data['behaviour'].squeeze()  # Squeeze to remove single-dimensional entries
    result = mat_data['result']
    tax = np.squeeze(mat_data['tax'].T)
    coords = mat_data['coords']

elif dataset_choice == "population":
    # Load the .mat file for the M1 population dataset (variables stored inside a MATLAB struct)
    mat_path = "Data/M1_population_data.mat"  # Data exported from matlab
    mat_data = loadmat(mat_path)

    # Extract the variables
    mat_data = mat_data['mat_data']
    behaviour = mat_data['behaviour'][0, 0].squeeze()  # Indexing to access the nested structure and squeeze to remove single-dimensional entries
    result = mat_data['result'][0, 0]  # Indexing to access the nested structure
    tax = mat_data['tax'][0, 0].squeeze()  # Indexing to access the nested structure and squeeze to remove single-dimensional entries

# Get the dimensions
T, N = len(behaviour), result.shape[0]

# Check shapes: result is (neurons, timepoints); tax and behaviour are 1D (timepoints,)
print(result.shape) # N x T
print(tax.shape) # (T,)
print(behaviour.shape) # (T,)

# Align with existing variable names in this notebook
behavior = behaviour
calcium_data = result
time_axis = tax

# Replace non-finite values with NaN and remove invalid entries
calcium_data = np.where(np.isfinite(calcium_data), calcium_data, np.nan)
behavior = np.where(np.isfinite(behavior), behavior, np.nan)
initial_n_neurons, initial_n_timepoints = calcium_data.shape
valid_neurons = ~np.all(np.isnan(calcium_data), axis=1)
calcium_data = calcium_data[valid_neurons]
removed_neurons = initial_n_neurons - calcium_data.shape[0]

valid_timepoints = ~np.isnan(behavior)
valid_timepoints &= ~np.any(np.isnan(calcium_data), axis=0)
calcium_data = calcium_data[:, valid_timepoints]
behavior = behavior[valid_timepoints]
time_axis = time_axis[valid_timepoints]
removed_timepoints = initial_n_timepoints - calcium_data.shape[1]

n_neurons, n_timepoints = calcium_data.shape
print(f"Removed {removed_neurons} neurons and {removed_timepoints} timepoints with NaN")
print(f"Dataset cleaned: {n_neurons} neurons, {n_timepoints} timepoints")

assert not np.isnan(calcium_data).any()
assert not np.isnan(behavior).any()

# 2. PRINCIPAL COMPONENT ANALYSIS (PCA)

### Conceptual overview
Principal Component Analysis (PCA) finds orthogonal directions (principal components) that capture maximal variance in the data. Projecting the neural activity matrix onto the first few components yields a low-dimensional representation that preserves global covariance structure. Because PCA is linear and deterministic, it is fast, interpretable, and often the first lens through which we inspect high-dimensional recordings.
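The definition above can be checked numerically: PCA's axes are the right singular vectors of the mean-centered data matrix, and the explained variance ratios are the normalized squared singular values. This sketch uses random correlated data (rows are samples, columns are features, following scikit-learn's convention) and compares a plain NumPy SVD with `sklearn.decomposition.PCA`.

```python
import numpy as np
from sklearn.decomposition import PCA

rng = np.random.default_rng(2)
# Synthetic correlated data: 100 samples x 20 features
X = rng.standard_normal((100, 20)) @ rng.standard_normal((20, 20))

Xc = X - X.mean(axis=0)                          # PCA centers each feature
U, S, Vt = np.linalg.svd(Xc, full_matrices=False)
evr_svd = S**2 / np.sum(S**2)                    # fraction of variance per direction

pca = PCA(n_components=5).fit(X)
print(np.allclose(evr_svd[:5], pca.explained_variance_ratio_))
print(np.allclose(np.abs(Vt[:5]), np.abs(pca.components_)))  # same axes up to sign
```

Singular vectors are defined only up to sign, which is why the axes are compared in absolute value.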

### Why PCA helps with calcium imaging
- **Noise reduction**: By discarding components with tiny variance, PCA filters out noise that is uncorrelated across neurons, yielding cleaner latent activity traces.
- **Population motifs**: Components often correspond to co-activated neuronal ensembles or behavioral epochs, helping uncover coordinated circuit motifs.
- **Downstream convenience**: PCA embeddings serve as inputs for visualization, clustering, or regression models that prefer compact feature spaces.

### Interpreting PCA output
- **Explained variance ratio** quantifies the fraction of total variance captured by each component. A sharp drop after the first few components suggests low-dimensional structure.
- **Loadings (eigenvectors)** tell you how strongly each neuron contributes to a component. Plotting loadings alongside anatomical locations can reveal spatial structure.
- **Scores (projected data)** are time courses in the latent space that you can relate to behavior or stimuli.
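The three outputs map onto scikit-learn attributes. In this sketch the data are arranged as `time × neurons` (timepoints are samples), which matches the interpretation in the list; the single sinusoidal latent signal and the coupling weights are synthetic assumptions.

```python
import numpy as np
from sklearn.decomposition import PCA

rng = np.random.default_rng(3)
n_time, n_neurons = 500, 30
latent = np.sin(np.linspace(0, 20, n_time))[:, None]          # one shared slow signal
weights = rng.uniform(0.5, 1.5, size=(1, n_neurons))           # per-neuron coupling
X = latent @ weights + 0.3 * rng.standard_normal((n_time, n_neurons))  # time x neurons

pca = PCA(n_components=3).fit(X)
scores = pca.transform(X)       # (time x 3): latent time courses
loadings = pca.components_      # (3 x neurons): each neuron's weight on each PC

# Scores are the centered data projected onto the loadings
assert np.allclose(scores, (X - pca.mean_) @ loadings.T)

evr = pca.explained_variance_ratio_
pc1_latent_r = abs(np.corrcoef(scores[:, 0], latent[:, 0])[0, 1])
print("Explained variance ratio:", evr.round(3))
print("|corr(PC1 score, latent)| =", round(pc1_latent_r, 3))
```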

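Remember that PCA assumes linear relationships and is sensitive to scaling: if neurons have widely different amplitudes, standardize or otherwise normalize the traces first. Also check which axis plays the role of samples. In the code below, PCA is fit on the `neurons × time` matrix, so each neuron is a sample: `pca_result` holds one point per neuron, and `pca.components_` holds temporal patterns. Transposing the matrix swaps the roles of loadings and scores described above.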
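To illustrate the scaling sensitivity, the sketch below adds one very bright but unrelated neuron to a group of co-active ones. Without standardization, the first PC mostly tracks the bright neuron; after z-scoring each neuron, it reflects the shared signal. Data are synthetic and arranged as `time × neurons`, so `zscore` is applied along `axis=0`; `pc1_weight_on_neuron0` is a hypothetical helper.

```python
import numpy as np
from scipy.stats import zscore
from sklearn.decomposition import PCA

rng = np.random.default_rng(4)
n_time, n_neurons = 400, 20
shared = rng.standard_normal((n_time, 1))
X = shared + 0.5 * rng.standard_normal((n_time, n_neurons))  # co-active neurons
X[:, 0] = 50 * rng.standard_normal(n_time)                   # one bright, unrelated neuron


def pc1_weight_on_neuron0(data):
    """Absolute loading of neuron 0 on the first principal component."""
    pc1 = PCA(n_components=1).fit(data).components_[0]
    return abs(pc1[0])


raw_w = pc1_weight_on_neuron0(X)
z_w = pc1_weight_on_neuron0(zscore(X, axis=0))   # z-score each neuron (column) over time
print(f"|PC1 loading| on bright neuron: raw={raw_w:.2f}, z-scored={z_w:.2f}")
```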

---

In [None]:
if dataset_choice == "dendrites":
    print("\nDataset: Dendritic tree activity during behavior")
    # subtract the population median (across neurons, at each timepoint) from every neuron
    calcium_norm = calcium_data - np.median(calcium_data, axis=0)
else:
    print("\nDataset: Population activity during behavior")
    # classic neuronal normalization: ΔF/F0, with F0 = 20th percentile of each trace
    # (assumes positive raw fluorescence; a baseline near zero inflates ΔF/F0)
    F0 = np.percentile(calcium_data, 20, axis=1, keepdims=True)
    calcium_norm = (calcium_data - F0) / F0 # ΔF/F0

# Show the normalized data; the label depends on the normalization used above
norm_label = "Activity (population median subtracted)" if dataset_choice == "dendrites" else "ΔF/F0"
plt.figure(figsize=(12, 4))
plt.imshow(
    calcium_norm,
    aspect="auto",
    cmap="viridis",
    extent=[time_axis[0], time_axis[-1], 0, n_neurons],
    vmin=-2,
    vmax=2,
    interpolation='none',  # Prevent line blending/smoothing
    )
cbar = plt.colorbar(orientation="horizontal", pad=0.2, label=norm_label, shrink=0.3)
cbar.outline.set_visible(False)
plt.xlabel("Time (s)")
plt.ylabel("Neuron #")
plt.title(f"Neural activity ({norm_label})")
sns.despine()
plt.grid(False)

print("\n2. Principal Component Analysis...")

# Apply PCA
pca = PCA(n_components=10)
pca_result = pca.fit_transform(calcium_norm)

# Figure 3: 3D PCA scatter plot
fig = plt.figure(figsize=(12, 5))

ax1 = fig.add_subplot(121, projection='3d')
scatter = ax1.scatter(pca_result[:, 0], pca_result[:, 1], pca_result[:, 2], 
                     c=range(n_neurons), cmap='Set2', s=40)
ax1.set_xlabel('PC1')
ax1.set_ylabel('PC2')
ax1.set_zlabel('PC3')
ax1.set_title('First 3 components capture main data variance')
sns.despine()
ax1.grid(False)

# Figure 4: Explained variance plot and cumulative explained variance
ax2 = fig.add_subplot(122)
ax2.plot(range(1, 11), pca.explained_variance_ratio_*100, 'bo-', linewidth=1, markersize=6)
ax2.set_xlabel('Principal Component')
ax2.set_ylabel('Explained Variance (%)')
ax2.set_title('How many components do we need?')

# Cumulative explained variance on right axis
ax2_right = ax2.twinx()
ax2_right.plot(range(1, 11), np.cumsum(pca.explained_variance_ratio_)*100, 'ro-', linewidth=1, markersize=6)
ax2_right.set_ylabel('Cumulative Explained Variance (%)', color='red')
ax2_right.tick_params(axis='y', labelcolor='red')

ax2.legend(['Explained Variance'], loc='upper left')
ax2_right.legend(['Cumulative Variance'], loc='upper right')
ax2.grid(False)
sns.despine(top=True,right=False)

# Cap to 100%
ax2_right.set_ylim(0, 105)

plt.tight_layout()
plt.show()

print(f"First 3 PCs explain {pca.explained_variance_ratio_[:3].sum()*100:.1f}% of variance")

# 3. NON-LINEAR DIMENSIONALITY REDUCTION

### When linear methods fall short
PCA excels when the data lie near a linear subspace, but neural activity often evolves on curved manifolds: think of trajectories that wrap around during cyclic behaviors or states that branch as animals switch tasks. Non-linear techniques attempt to preserve more nuanced geometry—local neighborhoods, geodesic distances, or global topology—that linear projections flatten.

### Two commonly used methods
- **t-SNE (t-distributed Stochastic Neighbor Embedding)** emphasizes local neighborhood preservation. It is excellent for visualizing clusters but can distort global distances. Perplexity controls the effective number of neighbors; values between 5 and 50 usually work, but tuning is encouraged, and perplexity must be smaller than the number of points being embedded.
- **UMAP (Uniform Manifold Approximation and Projection)** balances local and global structure, often running faster than t-SNE while producing stable embeddings. The `n_neighbors` parameter behaves similarly to perplexity, while `min_dist` adjusts how tightly points are packed.

### Practical guidance
- Always feed these methods a reasonably denoised representation (e.g., the top PCs) to avoid wasting effort on noise.
- Run multiple random seeds to check that qualitative structures persist; both methods involve randomness.
- Interpret relative positions, not absolute axes—rotations or reflections of the embedding are arbitrary.
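One way to act on the seed advice is to rerun t-SNE with several random seeds and score each embedding with `sklearn.manifold.trustworthiness`, which measures how well local neighborhoods are preserved (1.0 is perfect). The sketch uses three synthetic Gaussian clusters in 10 dimensions as a stand-in for PCA-reduced data and caps perplexity below the number of samples.

```python
import numpy as np
from sklearn.manifold import TSNE, trustworthiness

rng = np.random.default_rng(5)
centers = rng.normal(scale=6.0, size=(3, 10))
X = np.vstack([c + rng.standard_normal((40, 10)) for c in centers])  # 120 synthetic points

perplexity = min(30.0, (X.shape[0] - 1) / 3)  # must stay below n_samples

tws = []
for seed in (0, 1, 2):
    emb = TSNE(n_components=2, perplexity=perplexity, init="pca",
               random_state=seed).fit_transform(X)
    tws.append(trustworthiness(X, emb, n_neighbors=5))
    print(f"seed={seed}: trustworthiness={tws[-1]:.3f}")
```

Similar scores across seeds are reassuring; large differences suggest the layout depends on initialization and should not be over-interpreted. The same check applies to UMAP embeddings.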

Non-linear embeddings are powerful storytelling tools but should be complemented with quantitative analyses to avoid over-interpreting visualization artifacts.

---

In [None]:
# t-SNE on the first 10 PCA components
print("\n3. Non-linear Dimensionality Reduction...")

# PCA preprocessing for t-SNE (denoising, speed)
pca_for_tsne = PCA(n_components=10)
pca_data = pca_for_tsne.fit_transform(calcium_norm)

# Perplexity must be smaller than the number of samples (neurons here); 5-50 is the usual
# range, so scale it with dataset size (30 for large populations, smaller for few cells)
perplexity = float(min(30, max(2, (n_neurons - 1) // 3)))
tsne = TSNE(n_components=2, perplexity=perplexity, learning_rate='auto', init='pca', random_state=42)
tsne_result = tsne.fit_transform(pca_data)

# Figure: Compare PCA vs improved t-SNE
plt.figure(figsize=(12, 5))
plt.subplot(1,2,1)
plt.scatter(pca_data[:, 0], pca_data[:, 1], c=range(n_neurons), cmap='Set2', s=40)
plt.xlabel('PC1')
plt.ylabel('PC2')
plt.title('PCA (Linear)')
cbar = plt.colorbar()
cbar.outline.set_visible(False)
sns.despine()
plt.grid(False)

plt.subplot(1,2,2)
plt.scatter(tsne_result[:, 0], tsne_result[:, 1], c=range(n_neurons), cmap='Set2', s=40)
plt.xlabel('t-SNE 1')
plt.ylabel('t-SNE 2')
plt.title('Improved t-SNE (Non-linear structure)')
cbar = plt.colorbar()
cbar.outline.set_visible(False)
sns.despine()
plt.grid(False)
plt.tight_layout()
plt.show()

# Note: t-SNE may still show little structure if the data are noisy or lack clear clusters. Try several perplexity values and random seeds before interpreting the layout.

Next, we build a “consensus” UMAP embedding by running UMAP several times with different random seeds, aligning the results, and aggregating them to emphasize structure that is stable across runs. The helper function `consensus_umap_procrustes` takes a feature matrix `X` (here, the PCA-preprocessed neural data) and runs UMAP `runs` times, obtaining an embedding $Y$ each time. To make the embeddings comparable, each $Y$ is centered and globally normalized (subtract the mean, divide by the Frobenius norm), which removes translation and scale. The first embedding becomes the reference $Y_{\text{ref}}$. Each subsequent embedding is orthogonally aligned to the reference via Procrustes analysis: compute the SVD $Y^\top Y_{\text{ref}} = U\Sigma V^\top$ and set $R = UV^\top$, so that $YR$ removes the arbitrary rotations and reflections introduced by UMAP. After alignment, the function stacks the embeddings and returns their pointwise median across runs, which is robust to outlier runs and highlights consistent geometry. The helper's defaults, `n_neighbors=3` and `min_dist=0.1`, bias UMAP toward very local structure and tight clusters; this can be useful, but such a small neighborhood may fragment the data (the `umap-learn` default is `n_neighbors=15`).
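The SVD-based alignment can be verified on a toy example: rotate and reflect a centered, normalized point cloud, then recover the original. The helper `procrustes_align` is a hypothetical standalone version of the alignment step inside `consensus_umap_procrustes`, and `scipy.linalg.orthogonal_procrustes` is used only as an independent cross-check.

```python
import numpy as np
from scipy.linalg import orthogonal_procrustes


def procrustes_align(Y, ref):
    """Rotate/reflect Y onto ref: SVD of Y^T ref = U S V^T, then R = U V^T."""
    U, _, Vt = np.linalg.svd(Y.T @ ref, full_matrices=False)
    return Y @ (U @ Vt)


rng = np.random.default_rng(6)
ref = rng.standard_normal((50, 2))
ref -= ref.mean(0)
ref /= np.sqrt((ref**2).sum())                      # center + Frobenius-normalize

theta = 1.1
rot = np.array([[np.cos(theta), -np.sin(theta)],
                [np.sin(theta),  np.cos(theta)]])
flip = np.diag([1.0, -1.0])
Y = ref @ rot @ flip                                # arbitrary rotation + reflection

aligned = procrustes_align(Y, ref)
R_scipy, _ = orthogonal_procrustes(Y, ref)         # SciPy's solver for the same problem
print(np.allclose(aligned, ref), np.allclose(aligned, Y @ R_scipy))
```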

In [None]:
print("\nConsensus UMAP: Multiple runs to highlight stable structure")
# Consensus UMAP: run UMAP multiple times and visualize stable structure

def consensus_umap_procrustes(X, runs=20, n_neighbors=3, min_dist=0.1, seed=0, n_components=2):
    embs, ref = [], None
    for i in range(runs):
        print(f"Run {i+1}/{runs}")
        Y = UMAP(n_components=n_components, n_neighbors=n_neighbors,
                 min_dist=min_dist, random_state=seed+i).fit_transform(X)
        Y -= Y.mean(0); Y /= np.sqrt((Y**2).sum())      # center/normalize
        if ref is None: ref = Y
        U, _, Vt = np.linalg.svd(Y.T @ ref, full_matrices=False)
        embs.append(Y @ (U @ Vt))                       # orthogonal alignment
    return np.median(np.stack(embs), axis=0)

consensus_emb = consensus_umap_procrustes(pca_data)  # -> (n_samples, 2)


# Figure: Compare PCA vs consensus UMAP
plt.figure(figsize=(12, 6))

# PCA plot
plt.subplot(1, 2, 1)
plt.scatter(pca_data[:, 0], pca_data[:, 1], c=range(n_neurons), cmap='Set2', s=40)
plt.xlabel('PCA 1')
plt.ylabel('PCA 2')
plt.title('PCA')
sns.despine()
plt.grid(False)

# Consensus UMAP plot
plt.subplot(1, 2, 2)
plt.scatter(consensus_emb[:, 0], consensus_emb[:, 1], c=range(n_neurons), cmap='Set2', s=40)
plt.xlabel('Consensus UMAP 1')
plt.ylabel('Consensus UMAP 2')
plt.title('Consensus UMAP (average of multiple runs)')
cbar = plt.colorbar(label="Neuron #")
cbar.outline.set_visible(False)
sns.despine()
plt.grid(False)
plt.tight_layout()
plt.show()

# 4. CORRELATION ANALYSIS & CLUSTERING

### Why cluster neurons?
Clustering transforms a sea of individual traces into a handful of representative groups, making it easier to reason about circuit motifs and to link them with behavior. For population imaging, clusters can correspond to neurons that fire together, neurons driven by the same stimulus, or anatomically co-located cells.

### Two complementary perspectives
1. **Correlation-based clustering**: Operates directly on the activity matrix by comparing full time courses (e.g., Pearson correlation). Hierarchical agglomerative clustering is a natural fit: it builds a dendrogram that records how neurons merge as you relax the similarity threshold, and you can “cut” the tree at different heights to explore candidate cluster counts.
2. **Embedding-based clustering**: Works on low-dimensional coordinates (e.g., PCA, t-SNE, UMAP). Density-based methods such as HDBSCAN identify tight groups while flagging sparse points as noise, which is valuable when populations contain both structured ensembles and idiosyncratic neurons.
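A minimal version of correlation-based hierarchical clustering, on synthetic data: three hypothetical ensembles, each driven by its own random signal plus independent noise. Traces are `neurons × time`, as in the notebook; the distance is $1 - r$, and `fcluster` cuts the average-linkage tree into three groups.

```python
import numpy as np
from scipy.cluster.hierarchy import linkage, fcluster
from scipy.spatial.distance import squareform

rng = np.random.default_rng(7)
n_time = 600
drivers = rng.standard_normal((3, n_time))         # 3 hypothetical ensemble signals
true_labels = np.repeat([0, 1, 2], 10)             # 10 neurons per ensemble
traces = drivers[true_labels] + 0.8 * rng.standard_normal((30, n_time))  # neurons x time

corr = np.corrcoef(traces)
dist = 1.0 - corr                                  # sign-sensitive correlation distance
np.fill_diagonal(dist, 0.0)
dist = 0.5 * (dist + dist.T)                       # exact symmetry for squareform

Z = linkage(squareform(dist, checks=True), method="average")
found = fcluster(Z, t=3, criterion="maxclust")     # cut the tree into 3 clusters
print(found)
```

With this much shared signal the planted ensembles should separate cleanly; with weaker coupling, the dendrogram merge heights become the more informative output.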

### Good practices
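- Pearson correlation is already invariant to each trace's offset and positive gain, so z-scoring does not change it. Standardize traces when you compare them with Euclidean distances or covariances instead, so that high-variance neurons do not dominate similarity.
- Compare multiple linkage criteria (Ward, average, complete) in hierarchical clustering; each encodes a different notion of cluster compactness, and Ward additionally assumes Euclidean distances.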
- For HDBSCAN, examine both the labels and the membership strength probabilities to decide which neurons belong confidently to each cluster.
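Two facts behind these practices can be checked in a few lines. First, Pearson correlation does not change when each trace gets its own positive gain and offset. Second, for traces z-scored with `ddof=0` and divided by $\sqrt{T}$, each trace has unit norm and their dot product is $r$, so $\|z_a - z_b\|^2 = 1 + 1 - 2r$ and the Euclidean distance equals the chord distance $\sqrt{2(1-r)}$. That identity is why Ward linkage, which expects Euclidean distances, can be run on correlation-derived distances. The traces here are synthetic.

```python
import numpy as np

rng = np.random.default_rng(8)
a = rng.standard_normal(1000)
b = 0.6 * a + rng.standard_normal(1000)

r = np.corrcoef(a, b)[0, 1]
r_scaled = np.corrcoef(5.0 * a + 3.0, 0.1 * b - 2.0)[0, 1]  # positive gain + offset per trace

# Chord distance: unit-norm z-scored traces (ddof=0, divided by sqrt(T))
T = a.size
za = (a - a.mean()) / a.std() / np.sqrt(T)
zb = (b - b.mean()) / b.std() / np.sqrt(T)
chord = np.sqrt(2 * (1 - r))
print(np.isclose(r, r_scaled), np.isclose(np.linalg.norm(za - zb), chord))
```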

Clustering is exploratory by nature. Use it to generate hypotheses, then verify them through anatomical overlays, stimulus alignment, or follow-up experiments.

---

In [None]:
# 4. Correlation Analysis & Clustering
print("\n4. Correlation Analysis & Clustering...")

import numpy as np
import matplotlib.pyplot as plt
from scipy.stats import spearmanr
from scipy.cluster.hierarchy import linkage, dendrogram
from scipy.spatial.distance import squareform

# --- Config ---
corr_kind = 'pearson'       # 'pearson' or 'spearman'
linkage_method = 'average'  # default for non-Euclidean option
use_ward = True             # False → produce Z_abs & Z_signed; True → use Ward (Euclidean) and produce Z

# --- Correlation (neurons x neurons) ---
if corr_kind == 'pearson':
    corr_matrix = np.corrcoef(calcium_norm)
else:
    corr_matrix = spearmanr(calcium_norm.T).correlation

# Clean + enforce symmetry & unit diagonal
corr_matrix = np.nan_to_num(np.clip(corr_matrix, -1.0, 1.0), nan=0.0, posinf=0.0, neginf=0.0)
corr_matrix = 0.5 * (corr_matrix + corr_matrix.T)
np.fill_diagonal(corr_matrix, 1.0)

# --- Distances from correlation ---
# Sign-invariant (anti-correlated = near)
dist_abs = 1.0 - np.abs(corr_matrix)
np.fill_diagonal(dist_abs, 0.0)
dist_abs = 0.5 * (dist_abs + dist_abs.T)

# Sign-sensitive (anti-correlated = far)
dist_signed = 1.0 - corr_matrix
np.fill_diagonal(dist_signed, 0.0)
dist_signed = 0.5 * (dist_signed + dist_signed.T)

# Euclidean-compatible (for 'ward'/'centroid'/'median')
dist_chord = np.sqrt(2.0 * (1.0 - corr_matrix))
np.fill_diagonal(dist_chord, 0.0)
dist_chord = 0.5 * (dist_chord + dist_chord.T)

# --- Linkage (if/else option) ---
if use_ward:
    # If you switch to Ward/centroid/median, use dist_chord:
    linkage_method = 'ward'
    Z = linkage(squareform(dist_chord, checks=True), method='ward')
else:
    Z_abs    = linkage(squareform(dist_abs,    checks=True), method=linkage_method)
    Z_signed = linkage(squareform(dist_signed, checks=True), method=linkage_method)



In [None]:
from sklearn.metrics import silhouette_score, silhouette_samples
from scipy.cluster.hierarchy import fcluster
import matplotlib.pyplot as plt
import matplotlib.cm as cm

def find_optimal_clusters_simple(Z, dist_matrix, max_k=15):
    """Choose the number of clusters k that maximizes the average silhouette score.

    `dist_matrix` should be the square distance matrix that was used to build `Z`,
    so that silhouettes are measured in the same space as the tree.
    """
    dist_matrix = np.array(dist_matrix, dtype=float)  # copy; leave the caller's matrix untouched
    np.fill_diagonal(dist_matrix, 0)
    
    # Silhouette scores for k = 2 .. min(max_k, n - 1)
    k_range = range(2, min(max_k + 1, dist_matrix.shape[0]))
    scores = [silhouette_score(dist_matrix, fcluster(Z, k, criterion='maxclust'), metric='precomputed') 
              for k in k_range]
    
    optimal_k = k_range[int(np.argmax(scores))]
    
    # Create two plots: score vs k, and detailed silhouette plot
    fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(15, 5))
    
    # Plot 1: Silhouette score vs number of clusters
    ax1.plot(k_range, scores, 'bo-', linewidth=2, markersize=8)
    ax1.axvline(optimal_k, color='red', linestyle='--', alpha=0.7, linewidth=2)
    ax1.set_title('Silhouette Analysis', fontsize=14)
    ax1.set_xlabel('Number of Clusters')
    ax1.set_ylabel('Average Silhouette Score')
    ax1.grid(True, alpha=0.3)
    
    # Plot 2: Detailed silhouette plot for optimal k
    optimal_clusters = fcluster(Z, optimal_k, criterion='maxclust')
    silhouette_vals = silhouette_samples(dist_matrix, optimal_clusters, metric='precomputed')
    
    y_lower = 10
    colors = cm.nipy_spectral(np.linspace(0, 1, optimal_k))
    
    for i, color in zip(range(1, optimal_k + 1), colors):
        cluster_silhouette_vals = silhouette_vals[optimal_clusters == i]
        cluster_silhouette_vals.sort()
        
        size_cluster_i = cluster_silhouette_vals.shape[0]
        y_upper = y_lower + size_cluster_i
        
        ax2.fill_betweenx(np.arange(y_lower, y_upper), 0, cluster_silhouette_vals,
                         facecolor=color, edgecolor=color, alpha=0.7)
        
        # Label clusters at center
        ax2.text(-0.05, y_lower + 0.5 * size_cluster_i, str(i))
        y_lower = y_upper + 10
    
    avg_score = np.mean(silhouette_vals)
    ax2.axvline(x=avg_score, color="red", linestyle="--", linewidth=2)
    ax2.set_xlabel('Silhouette Coefficient Values')
    ax2.set_ylabel('Cluster Label')
    ax2.set_title(f'Silhouette Plot for {optimal_k} Clusters\n(Avg Score: {avg_score:.3f})')
    
    plt.tight_layout()
    plt.show()
    
    # Dendrogram height that yields exactly optimal_k clusters: Z[-k] is the merge that
    # leaves k clusters and Z[-k + 1] is the next one. fcluster(criterion='distance') keeps
    # merges with height <= t, so cut halfway between the two.
    threshold = 0.5 * (Z[-optimal_k, 2] + Z[-optimal_k + 1, 2])
    
    return optimal_k, threshold, max(scores)

# Usage: evaluate the tree that was actually built, with its own distance matrix
if use_ward:
    Z_eval, dist_eval = Z, dist_chord
else:
    Z_eval, dist_eval = Z_abs, dist_abs
optimal_k, threshold, best_score = find_optimal_clusters_simple(Z_eval, dist_eval)
print(f"Optimal: {optimal_k} clusters | Threshold: {threshold:.3f} | Silhouette: {best_score:.3f}")

In [None]:

# --- Dendrogram(s) ---
if use_ward:
    # Single dendrogram for Ward
    color_threshold = threshold  # silhouette-based cut height from the previous cell
    plt.figure(figsize=(6, 4))
    dendrogram(Z, no_labels=True, leaf_rotation=0, color_threshold=color_threshold)
    plt.title('Ward linkage (Euclidean chord distance)', fontsize=12, pad=10)
    plt.gca().spines['top'].set_visible(False)
    plt.gca().spines['right'].set_visible(False)
    plt.tight_layout()
    plt.show()
else:
    # Two dendrograms: abs vs signed
    fig, axes = plt.subplots(1, 2, figsize=(12, 4))

    # Same silhouette-based cut height for both trees (not tuned separately for each)
    color_threshold_abs = threshold
    color_threshold_signed = threshold

    dendrogram(Z_abs, no_labels=True, ax=axes[0], leaf_rotation=0, color_threshold=color_threshold_abs)
    axes[0].set_title('Absolute Distance (1 - |r|)', fontsize=12, pad=10)
    axes[0].spines['top'].set_visible(False); axes[0].spines['right'].set_visible(False)
    axes[0].tick_params(labelsize=10)

    dendrogram(Z_signed, no_labels=True, ax=axes[1], leaf_rotation=0, color_threshold=color_threshold_signed)
    axes[1].set_title('Signed Distance (1 - r)', fontsize=12, pad=10)
    axes[1].spines['top'].set_visible(False); axes[1].spines['right'].set_visible(False)
    axes[1].tick_params(labelsize=10)

    plt.tight_layout()
    plt.show()


In [None]:
# Sorted matrices, cluster extraction, and overlays

import numpy as np
import seaborn as sns
import matplotlib.pyplot as plt
import matplotlib.patches as patches
from scipy.cluster.hierarchy import fcluster, leaves_list

def contig_blocks(labels):
    """Yield (start, end, value) for contiguous runs in a 1D label array."""
    n = len(labels)
    if n == 0: 
        return
    start = 0
    cur = labels[0]
    for i in range(1, n):
        if labels[i] != cur:
            yield start, i, cur
            start = i
            cur = labels[i]
    yield start, n, cur

def add_cluster_overlays(ax, clusters_sorted, alpha=0.15):
    """Add translucent squares for contiguous cluster blocks on a heatmap."""
    unique_clusters = np.unique(clusters_sorted)
    palette = ['red','blue','green','orange','purple','brown','pink','gray','olive','cyan']
    color_map = {c: palette[i % len(palette)] for i, c in enumerate(unique_clusters)}
    for s, e, c in contig_blocks(clusters_sorted):
        size = e - s
        # outline
        rect = patches.Rectangle((s-0.5, s-0.5), size, size, 
                                 linewidth=2, edgecolor=color_map[c], facecolor='none', alpha=1.0)
        ax.add_patch(rect)
        # fill
        rect_fill = patches.Rectangle((s-0.5, s-0.5), size, size, 
                                      linewidth=0, facecolor=color_map[c], alpha=alpha)
        ax.add_patch(rect_fill)

if use_ward:
    # ----- Ward mode: one Z (Euclidean), one sorting, one cluster vector -----
    # Leaves / order
    sort_idx = leaves_list(Z)
    corr_sorted = corr_matrix[sort_idx][:, sort_idx]

    # Threshold & clusters: cut at the silhouette-based height from the previous cell
    color_threshold = threshold
    clusters = fcluster(Z, color_threshold, criterion='distance')
    clusters_sorted = clusters[sort_idx]

    # Plot original vs Ward-sorted with overlays
    fig, axes = plt.subplots(1, 2, figsize=(10, 4.5))
    sns.heatmap(corr_matrix, cmap="RdBu_r", vmin=-1, vmax=1, square=True,
                cbar_kws={"shrink": 0.8}, ax=axes[0], xticklabels=False, yticklabels=False)
    axes[0].set_title("Original Matrix", fontsize=12, pad=10)

    sns.heatmap(corr_sorted, cmap="RdBu_r", vmin=-1, vmax=1, square=True,
                cbar_kws={"shrink": 0.8}, ax=axes[1], xticklabels=False, yticklabels=False)
    add_cluster_overlays(axes[1], clusters_sorted, alpha=0.15)
    axes[1].set_title("Ward-sorted (chord distance) + Clusters", fontsize=12, pad=10)

    for ax in axes:
        ax.spines['top'].set_visible(False)
        ax.spines['right'].set_visible(False)
    plt.tight_layout()
    plt.show()

    # Stats & quick visuals
    print(f"\nUsing {corr_kind} correlation. Ward linkage.")
    print(f"Color threshold (Ward): {color_threshold:.3f}")
    uniq, counts = np.unique(clusters_sorted, return_counts=True)
    print(f"Ward clustering: {len(uniq)} clusters")
    print(f"Cluster sizes: {counts.tolist()}")
else:
    # ----- Non-Euclidean mode: abs & signed -----
    # Leaves / orders
    sort_idx_abs = leaves_list(Z_abs)
    sort_idx_signed = leaves_list(Z_signed)

    # Sorted matrices
    corr_sorted_abs = corr_matrix[sort_idx_abs][:, sort_idx_abs]
    corr_sorted_signed = corr_matrix[sort_idx_signed][:, sort_idx_signed]

    # Thresholds & clusters: the same silhouette-based height is used for both trees
    color_threshold_abs = threshold
    color_threshold_signed = threshold
    clusters_abs = fcluster(Z_abs, color_threshold_abs, criterion='distance')[sort_idx_abs]
    clusters_signed = fcluster(Z_signed, color_threshold_signed, criterion='distance')[sort_idx_signed]

    # Heatmaps with overlays
    fig, axes = plt.subplots(1, 3, figsize=(15, 4.5))

    sns.heatmap(corr_matrix, cmap="RdBu_r", vmin=-1, vmax=1, square=True,
                cbar_kws={"shrink": 0.8}, ax=axes[0], xticklabels=False, yticklabels=False)
    axes[0].set_title("Original Matrix", fontsize=12, pad=10)

    sns.heatmap(corr_sorted_abs, cmap="RdBu_r", vmin=-1, vmax=1, square=True,
                cbar_kws={"shrink": 0.8}, ax=axes[1], xticklabels=False, yticklabels=False)
    add_cluster_overlays(axes[1], clusters_abs, alpha=0.15)
    axes[1].set_title("Sorted by |Correlation| + Clusters", fontsize=12, pad=10)

    sns.heatmap(corr_sorted_signed, cmap="RdBu_r", vmin=-1, vmax=1, square=True,
                cbar_kws={"shrink": 0.8}, ax=axes[2], xticklabels=False, yticklabels=False)
    add_cluster_overlays(axes[2], clusters_signed, alpha=0.15)
    axes[2].set_title("Sorted by Correlation + Clusters", fontsize=12, pad=10)

    for ax in axes:
        ax.spines['top'].set_visible(False)
        ax.spines['right'].set_visible(False)
    plt.tight_layout(); plt.show()

    # Stats
    print(f"\nUsing {corr_kind} correlation.")
    print(f"Color threshold (abs): {color_threshold_abs:.3f}")
    uniq_abs, cnt_abs = np.unique(clusters_abs, return_counts=True)
    print(f"Absolute distance clustering: {len(uniq_abs)} clusters")
    print(f"Cluster sizes (abs): {cnt_abs.tolist()}")

    print(f"\nColor threshold (signed): {color_threshold_signed:.3f}")
    uniq_sig, cnt_sig = np.unique(clusters_signed, return_counts=True)
    print(f"Signed distance clustering: {len(uniq_sig)} clusters")
    print(f"Cluster sizes (signed): {cnt_sig.tolist()}")


In [None]:
# HDBSCAN on UMAP coordinates + correlation reordered by UMAP clusters
print("\n4. Clustering (HDBSCAN on UMAP) and corr matrix reordered by clusters...")

# Run HDBSCAN on the consensus UMAP embedding
clusterer = HDBSCAN(min_cluster_size=5)
cluster_labels = clusterer.fit_predict(consensus_emb)

# Visualization: UMAP colored by HDBSCAN clusters
plt.figure(figsize=(8, 6))
scatter = plt.scatter(
    consensus_emb[:, 0], consensus_emb[:, 1],
    c=cluster_labels, cmap='tab10', alpha=0.8, s=30
)
plt.colorbar(scatter, label="Cluster", shrink=0.8)
plt.title('UMAP with HDBSCAN Clusters', fontsize=14, pad=15)
plt.xlabel('UMAP 1', fontsize=12)
plt.ylabel('UMAP 2', fontsize=12)
plt.tick_params(labelsize=10)
plt.tight_layout()
plt.show()

# Build correlation matrix and reorder by clusters
corr_matrix = np.corrcoef(calcium_norm)

# Reorder indices by cluster label (noise = -1 at end)
unique_labels = [c for c in np.unique(cluster_labels) if c != -1]
ordered_indices = []
for c in unique_labels:
    ordered_indices.append(np.where(cluster_labels == c)[0])
noise_indices = np.where(cluster_labels == -1)[0]
if noise_indices.size > 0:
    ordered_indices.append(noise_indices)
sort_idx = np.concatenate(ordered_indices) if len(ordered_indices) else np.arange(corr_matrix.shape[0])

corr_sorted = corr_matrix[sort_idx][:, sort_idx]

# Comparison: original vs cluster-sorted correlation matrices
fig, axes = plt.subplots(1, 2, figsize=(14, 6))

sns.heatmap(corr_matrix, cmap="RdBu_r", vmin=-1, vmax=1, square=True,
            cbar_kws={"shrink": 0.8}, ax=axes[0], 
            xticklabels=False, yticklabels=False)
axes[0].set_title("Original Matrix", fontsize=12, pad=10)

sns.heatmap(corr_sorted, cmap="RdBu_r", vmin=-1, vmax=1, square=True,
            cbar_kws={"shrink": 0.8}, ax=axes[1],
            xticklabels=False, yticklabels=False)  
axes[1].set_title("Sorted by HDBSCAN Clusters", fontsize=12, pad=10)

plt.tight_layout()
plt.show()

print(f"Found {len(unique_labels)} clusters, {sum(cluster_labels == -1)} noise points")

# 5. CLUSTER VALIDATION & INTERPRETATION

### Guarding against over-interpretation
Clustering algorithms will always partition the data—even pure noise can yield seemingly neat groups. Validation is therefore critical before we claim the existence of functional ensembles. Ask whether clusters are reproducible across methods, stable under resampling, and interpretable within the biological context.

### Validation toolbox
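1. **Visual diagnostics**: Inspect embeddings colored by cluster labels. Coherent, well-separated clouds inspire more confidence than fuzzy gradients, but t-SNE and UMAP can exaggerate gaps between groups, so visual separation alone is not proof.
2. **Cross-method agreement**: Compare labels from correlation-based and embedding-based approaches (e.g., via the adjusted Rand index). Convergence across methods suggests signal rather than artifacts.
3. **Temporal and behavioral profiles**: Compute average traces or stimulus-triggered responses per cluster. Distinct dynamics are consistent with functional specialization.
4. **Null models**: Break the temporal alignment between neurons (e.g., by shuffling or circularly shifting each neuron's trace independently in time) and rerun clustering to check that the observed structure exceeds what chance correlations would produce. Permuting neuron identities alone is not a useful null, because it only relabels rows and leaves the correlation structure intact.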
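
One common null model is the circular shift: rotating each neuron's trace by its own random lag keeps that neuron's autocorrelation (important for slow calcium transients) while breaking the alignment between neurons, so whatever within-cluster correlation survives is what chance alone produces. This is a sketch under assumptions: the quality score (mean correlation between neurons that share a label), the hierarchical-clustering stand-in, the synthetic data, and the helper names (`cluster_by_correlation`, `mean_within_cluster_corr`, `circular_shift_null`) are illustrative choices, not the notebook's pipeline.

```python
import numpy as np
from scipy.cluster.hierarchy import fcluster, linkage
from scipy.spatial.distance import squareform


def cluster_by_correlation(traces, k):
    """Average-linkage clustering on 1 - Pearson correlation; returns (labels, corr)."""
    corr = np.corrcoef(traces)
    dist = np.clip(1.0 - corr, 0.0, None)
    np.fill_diagonal(dist, 0.0)
    Z = linkage(squareform(dist, checks=False), method="average")
    return fcluster(Z, k, criterion="maxclust") - 1, corr


def mean_within_cluster_corr(corr, labels):
    """Average correlation over all pairs of distinct neurons that share a label."""
    same = labels[:, None] == labels[None, :]
    np.fill_diagonal(same, False)
    return corr[same].mean()


def circular_shift_null(traces, k, n_shuffles=100, seed=0):
    """Cluster-quality scores after rotating every neuron's trace by its own random lag."""
    rng = np.random.default_rng(seed)
    n_neurons, n_time = traces.shape
    scores = np.empty(n_shuffles)
    for i in range(n_shuffles):
        # Independent lags keep each neuron's own dynamics (autocorrelation)
        # but break the temporal alignment between neurons
        lags = rng.integers(1, n_time, size=n_neurons)
        shifted = np.stack([np.roll(trace, lag) for trace, lag in zip(traces, lags)])
        labels, corr = cluster_by_correlation(shifted, k)
        scores[i] = mean_within_cluster_corr(corr, labels)
    return scores


# Synthetic data (assumption): 3 ensembles of 20 neurons sharing slow latent signals
rng = np.random.default_rng(2)
n_time = 1000
kernel = np.ones(20) / 20                                   # moving-average smoother
latents = np.stack([4 * np.convolve(rng.standard_normal(n_time), kernel, mode="same")
                    for _ in range(3)])
traces = np.vstack([latents[g] + 0.5 * rng.standard_normal((20, n_time)) for g in range(3)])

labels, corr = cluster_by_correlation(traces, k=3)
observed = mean_within_cluster_corr(corr, labels)
null = circular_shift_null(traces, k=3)
p_value = (1 + np.sum(null >= observed)) / (1 + null.size)  # empirical one-sided p-value
print(f"observed {observed:.2f}, null 95th pct {np.percentile(null, 95):.2f}, p = {p_value:.3f}")
```

With real data you would pass `calcium_norm` and the number of clusters found earlier; if the observed score falls inside the null distribution, the clusters are not distinguishable from chance structure.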

### From clusters to neuroscience insight
Once you are confident that the clusters are valid, map them back onto anatomy, cell-type markers, or behavioral epochs. Look for enrichment of known cell classes within a cluster, or for cluster activity that aligns with behaviorally relevant events. Treat clusters as hypotheses: design follow-up experiments to test whether the grouped neurons indeed share synaptic inputs, genetic identity, or behavioral roles.
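
If each neuron carries a cell-type annotation (for example from a genetic reporter or post hoc staining), enrichment can be tested per cluster with a one-sided hypergeometric test. This is a sketch: the `cell_types` array, the `"PV"` label, and the helper `enrichment_pvalue` are hypothetical, and it relies on SciPy's `hypergeom(M, n, N)` parameterization (population size, number of marked items, number drawn).

```python
import numpy as np
from scipy.stats import hypergeom


def enrichment_pvalue(cluster_labels, cell_types, cluster, cell_type):
    """One-sided hypergeometric test: is `cell_type` over-represented in `cluster`?"""
    cluster_labels = np.asarray(cluster_labels)
    cell_types = np.asarray(cell_types)
    in_cluster = cluster_labels == cluster
    is_type = cell_types == cell_type
    M = cluster_labels.size               # total neurons
    n = is_type.sum()                     # neurons of this type overall
    N = in_cluster.sum()                  # neurons in this cluster
    k = (in_cluster & is_type).sum()      # neurons of this type inside the cluster
    # sf(k - 1) = P(X > k - 1) = P(X >= k) for X ~ Hypergeom(M, n, N)
    return hypergeom.sf(k - 1, M, n, N), k, N


# Hypothetical annotation: 100 neurons, 20 labeled "PV", cluster 0 holds 10 of them
labels = np.array([0] * 15 + [1] * 85)
types = np.array(["PV"] * 10 + ["other"] * 5 + ["PV"] * 10 + ["other"] * 75)
p, k, N = enrichment_pvalue(labels, types, cluster=0, cell_type="PV")
print(f"{k}/{N} PV neurons in cluster 0, hypergeometric p = {p:.2e}")
```

When testing many cluster and cell-type combinations, correct for multiple comparisons (e.g., Bonferroni or Benjamini-Hochberg) before calling any cluster enriched.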

---

In [None]:
print("\n6. Cluster Validation...")

# Prepare both clustering labelings
labels_umap = cluster_labels  # HDBSCAN on UMAP (-1 = noise)

# Recover hierarchical labels by cutting the dendrogram into as many clusters as HDBSCAN found
from scipy.cluster.hierarchy import fcluster

try:
    if 'linkage_matrix' not in globals():
        # Fallback: rows of the correlation matrix are treated as feature vectors
        # (each neuron's correlation profile), compared with Euclidean distance
        corr_matrix = np.corrcoef(calcium_norm)
        linkage_matrix = linkage(corr_matrix, method='average')

    k_umap = len([c for c in np.unique(labels_umap) if c != -1])
    k_hier = max(k_umap, 2)  # ask for at least 2 clusters so the comparison is meaningful
    # fcluster returns 1-based labels and may return fewer than k_hier clusters
    labels_hier = fcluster(linkage_matrix, k_hier, criterion='maxclust') - 1
except Exception as e:
    print(f"Warning: hierarchical labels could not be computed: {e}")
    labels_hier = np.zeros_like(labels_umap)

# Compute cluster-average traces (HDBSCAN noise points, label -1, are excluded)
unique_umap = [c for c in np.unique(labels_umap) if c != -1]
unique_hier = np.unique(labels_hier)

means_umap = np.array([calcium_norm[labels_umap == c].mean(axis=0) for c in unique_umap])
means_hier = np.array([calcium_norm[labels_hier == c].mean(axis=0) for c in unique_hier])

# Plot cluster traces
fig, axes = plt.subplots(2, 1, figsize=(12, 8), sharex=True)

# UMAP-HDBSCAN traces
for j, c in enumerate(unique_umap):
    axes[0].plot(time_axis, means_umap[j] + j*2.0, linewidth=0.5, alpha=0.8,
                label=f'C{c} (n={sum(labels_umap==c)})')
axes[0].set_ylabel('Activity (offset)', fontsize=12)
axes[0].set_title('UMAP-HDBSCAN Cluster Traces', fontsize=13, pad=15)
axes[0].legend(bbox_to_anchor=(1.02, 1), loc='upper left', fontsize=10)
axes[0].tick_params(labelsize=10)

# Hierarchical traces  
for j, c in enumerate(unique_hier):
    axes[1].plot(time_axis, means_hier[j] + j*5, linewidth=0.5, alpha=0.8,
                label=f'C{c} (n={sum(labels_hier==c)})')
axes[1].set_ylabel('Activity (offset)', fontsize=12)
axes[1].set_xlabel('Time', fontsize=12)
axes[1].set_title('Hierarchical Cluster Traces', fontsize=13, pad=15)
axes[1].legend(bbox_to_anchor=(1.02, 1), loc='upper left', fontsize=10)
axes[1].tick_params(labelsize=10)

for ax in axes:
    for spine in ax.spines.values():
        spine.set_linewidth(0.5)
    ax.spines['top'].set_visible(False)
    ax.spines['right'].set_visible(False)

plt.tight_layout()
plt.show()

# Contingency table between methods. Cluster IDs are arbitrary, so rows and columns
# need not correspond; each cell counts the neurons a pair of clusters shares.
from sklearn.metrics.cluster import contingency_matrix

mask = labels_umap != -1  # exclude HDBSCAN noise points
if mask.sum() > 0:
    row_ids = np.unique(labels_umap[mask]).tolist()
    col_ids = np.unique(labels_hier[mask]).tolist()
    conf_mat = contingency_matrix(labels_umap[mask], labels_hier[mask])
else:
    row_ids, col_ids = [0], [0]
    conf_mat = np.zeros((1, 1), dtype=int)

plt.figure(figsize=(7, 5))
sns.heatmap(conf_mat, annot=True, fmt='d', cmap='Blues', square=True,
            xticklabels=col_ids, yticklabels=row_ids,
            cbar_kws={'shrink': 0.8}, annot_kws={'fontsize': 11})
plt.title('UMAP-HDBSCAN vs Hierarchical Clustering', fontsize=13, pad=15)
plt.xlabel('Hierarchical Cluster', fontsize=12)
plt.ylabel('UMAP-HDBSCAN Cluster', fontsize=12)
plt.tick_params(labelsize=10)
plt.tight_layout()
plt.show()

# Summary
print(f"\n=== RESULTS ===")
print(f"Dataset: {n_neurons} neurons, {n_timepoints} timepoints")
print(f"PCA: First 3 components = {pca.explained_variance_ratio_[:3].sum()*100:.1f}% variance")
print(f"UMAP-HDBSCAN: {len(unique_umap)} clusters + {sum(labels_umap==-1)} noise points")
print(f"Hierarchical: {len(unique_hier)} clusters")

print(f"\n=== CLUSTER SIZES ===")
for c in unique_umap:
    print(f"  UMAP C{c}: {sum(labels_umap==c)} neurons")
for c in unique_hier:
    print(f"  Hier C{c}: {sum(labels_hier==c)} neurons")