## Demo Notebook: Behavior clustering

In [None]:
import sys
sys.path.append('/Users/annateruel/socialhierarchy/code/')
from behavioral_clustering import load_features, sample_frames, run_umap, train_embedding_model, plot_umap_embedding, map_density, map_feature_by_cluster
import pandas as pd

In [None]:
# Sample a fixed number of frames (frames_total) from the CSV feature files in the directory.
sp_frames = sample_frames(
    directory='/Users/annateruel/Desktop/features',
    file_format='csv',
    frames_total=10000,
)

In [None]:
# Project the sampled frames with UMAP; the result is used as the 2D embedding below.
embedding = run_umap(
    sp_frames,
    min_dist=0.5,
)

In [None]:
# Train a model on the sampled frames and their embedding; save_model/save_path
# control whether and where it is written (presumably a Keras .h5 file — confirm).
train_embedding_model(
    embedding,
    sp_frames,
    save_model=True,
    save_path='/Users/annateruel/Desktop/embedding_model.h5',
)

In [None]:
# Plot the UMAP embedding and save it as PNG.
plot_umap_embedding(
    embedding,
    save=True,
    save_dir='/Users/annateruel/Desktop',
    format='png',
)

In [None]:
# Plot (and save as SVG) the smoothed density map of the embedding;
# the returned values are not kept here.
map_density(embedding, sigma=4, percentile=40, cmap='OrRd',
            save=True, save_dir='/Users/annateruel/Desktop', format='svg')

In [None]:
# Recompute the density clustering with the same sigma/percentile as above, but
# without plotting, so the cluster map and bin edges can be reused.
labeled_map, density_map, xe, ye = map_density(embedding, sigma=4, percentile=40, plot=False)

# Load every feature file and stack them into one table.
features = load_features(directory='/Users/annateruel/Desktop/features', file_format='csv')
features_all = pd.concat(list(features.values()), ignore_index=False)

# NOTE(review): `variable` holds one value per frame of *every* loaded file, while
# `embedding` was computed from `sp_frames` (frames_total=10000). Confirm that
# map_feature_by_cluster does not expect `variable` to align row-for-row with `embedding`.
variable = features_all['avg_centroid_distance'].values

# Colour the clusters by the chosen feature and save as SVG.
map_feature_by_cluster(embedding=embedding, variable=variable,
                       labeled_map=labeled_map, xe=xe, ye=ye,
                       cmap='plasma', save=True,
                       save_dir='/Users/annateruel/Desktop', format='svg')

##### TESTING CODE: hierarchical clustering of behavior clusters and transition analysis


In [None]:
def hierarchical_clustering(embedding, labeled_map, xe, ye, method='ward', plot=True):
    """
    Perform hierarchical clustering on behavior clusters and optionally plot a dendrogram.

    Every embedded point is assigned the label of the density-map bin it falls in.
    The mean embedding position (centroid) of each non-background cluster is then
    clustered hierarchically.

    Args:
        embedding (np.ndarray): UMAP 2D embedding (n_frames x 2).
        labeled_map (np.ndarray): Watershed cluster map from density-based clustering,
            indexed as labeled_map[x_bin, y_bin]. Labels <= 0 and NaN are background.
        xe (np.ndarray): x-axis bin edges from histogram2d.
        ye (np.ndarray): y-axis bin edges from histogram2d.
        method (str): Linkage method for hierarchical clustering (default 'ward').
        plot (bool): Whether to plot the dendrogram.

    Returns:
        tuple:
            Z (np.ndarray): The linkage matrix used to construct the dendrogram.
            cluster_labels (np.ndarray): Label of every point that falls inside the
                binned range, in input order. Background labels (0/NaN) are kept.

    Raises:
        ValueError: If fewer than two non-background clusters contain points.
    """
    def _bin_index(values, edges):
        # 0-based bin index using np.histogram2d's convention: bins are half-open
        # [e_i, e_{i+1}) except the last one, which also includes the right-most edge.
        # Plain np.digitize pushes values equal to edges[-1] (e.g. the data maximum)
        # one bin past the end, which would silently drop those points.
        idx = np.digitize(values, edges) - 1
        idx[values == edges[-1]] = len(edges) - 2
        return idx

    embedding = np.asarray(embedding)
    x_idx = _bin_index(embedding[:, 0], xe)
    y_idx = _bin_index(embedding[:, 1], ye)

    # Discard points outside the binned range of the cluster map.
    valid = (
        (x_idx >= 0) & (x_idx < labeled_map.shape[0]) &
        (y_idx >= 0) & (y_idx < labeled_map.shape[1])
    )
    x_idx = x_idx[valid]
    y_idx = y_idx[valid]
    embedding_valid = embedding[valid]

    # NOTE(review): assumes labeled_map is laid out [x_bin, y_bin] (np.histogram2d
    # layout, not transposed for imshow) — confirm against map_density.
    cluster_labels = labeled_map[x_idx, y_idx]

    # Background (label <= 0) and NaN bins do not belong to any cluster.
    valid_points = ~np.isnan(cluster_labels) & (cluster_labels > 0)
    labels = cluster_labels[valid_points].astype(int)
    embedding_valid = embedding_valid[valid_points]
    unique_labels = np.unique(labels)

    if unique_labels.size < 2:
        raise ValueError(
            "hierarchical clustering needs at least 2 populated clusters, "
            f"got {unique_labels.size}"
        )

    centroids = np.array([
        embedding_valid[labels == lbl].mean(axis=0)
        for lbl in unique_labels
    ])

    Z = linkage(centroids, method=method)

    if plot:
        plt.figure(figsize=(20, 15))
        dendrogram(Z, labels=unique_labels)
        plt.title("Hierarchical Clustering of Behavioral Clusters")
        plt.xlabel("Cluster")
        plt.ylabel("Distance")
        plt.tight_layout()
        plt.show()

    return Z, cluster_labels

In [None]:

import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from scipy.cluster.hierarchy import linkage, dendrogram, leaves_list
from scipy.stats import binom_test
from matplotlib import gridspec

def compute_transition_matrix(cluster_labels, n_clusters):
    """
    Estimate a first-order transition probability matrix from a label sequence.

    Consecutive pairs (label[t], label[t + 1]) are counted only when both labels
    are positive. Labels are 1-based, so label k maps to row/column k - 1. Each row
    is normalised to sum to 1. Rows with no outgoing transitions are all zeros.

    Args:
        cluster_labels (array-like): Sequence of cluster labels (cast to int).
        n_clusters (int): Size of the square output matrix. Every positive label
            must be <= n_clusters.

    Returns:
        np.ndarray: (n_clusters, n_clusters) float matrix whose entry [i, j] is the
        fraction of transitions out of label i + 1 that go to label j + 1.

    Raises:
        IndexError: If a positive label exceeds n_clusters.
    """
    labels = np.asarray(cluster_labels).astype(int)
    counts = np.zeros((n_clusters, n_clusters))

    src, dst = labels[:-1], labels[1:]
    # Pairs touching a background/unlabeled entry (<= 0) are not transitions.
    keep = (src > 0) & (dst > 0)
    np.add.at(counts, (src[keep] - 1, dst[keep] - 1), 1)

    row_sums = counts.sum(axis=1, keepdims=True)
    # `out` must be given: with `where=` alone, entries where the condition is
    # False (empty rows) are left uninitialised instead of 0.
    return np.divide(counts, row_sums, out=np.zeros_like(counts), where=row_sums != 0)

def plot_transition_matrix_with_dendrogram(cluster_labels, n_clusters, method='ward', annotate_significance=True, alpha=0.05):
    """
    Plot the behavioral transition matrix, with rows/columns reordered by a
    hierarchical clustering of its rows, next to the corresponding dendrogram.

    Args:
        cluster_labels (array-like): Sequence of 1-based cluster labels; values <= 0
            do not form transitions.
        n_clusters (int): Number of clusters (size of the transition matrix).
        method (str): Linkage method used to cluster the rows of the matrix.
        annotate_significance (bool): If True, only cells whose raw transition count
            is significantly above chance (one-sided binomial test against
            p = 1 / n_clusters) are shown, each marked with '*'.
        alpha (float): Significance level of the per-cell binomial test. No
            multiple-comparison correction is applied.

    Returns:
        tuple: (T, order, linkage_matrix). T is the row-normalised transition matrix
        in original label order, order is the dendrogram leaf order, and
        linkage_matrix is the linkage of T's rows.
    """
    # scipy.stats.binom_test was deprecated in SciPy 1.10 and removed in 1.12.
    from scipy.stats import binomtest

    T = compute_transition_matrix(cluster_labels, n_clusters)

    # Hierarchical clustering on rows (outgoing-transition profiles).
    linkage_matrix = linkage(T, method=method)
    order = leaves_list(linkage_matrix)
    reorder = np.ix_(order, order)
    T_sorted = T[reorder]

    sig_sorted = None
    if annotate_significance:
        # The binomial test needs raw transition counts; T only holds row-normalised
        # fractions, so recount using the same rule as compute_transition_matrix.
        labels = np.asarray(cluster_labels).astype(int)
        src, dst = labels[:-1], labels[1:]
        keep = (src > 0) & (dst > 0)
        counts = np.zeros((n_clusters, n_clusters), dtype=int)
        np.add.at(counts, (src[keep] - 1, dst[keep] - 1), 1)
        row_totals = counts.sum(axis=1)

        chance = 1 / n_clusters
        p_values = np.ones((n_clusters, n_clusters))
        for i in range(n_clusters):
            if row_totals[i] == 0:
                continue  # no outgoing transitions: p stays 1, never significant
            for j in range(n_clusters):
                p_values[i, j] = binomtest(
                    int(counts[i, j]), int(row_totals[i]), chance, alternative='greater'
                ).pvalue
        # Reorder the mask exactly like T so it lines up with T_sorted.
        sig_sorted = (p_values < alpha)[reorder]
        T_display = np.where(sig_sorted, T_sorted, np.nan)
    else:
        T_display = T_sorted

    # Layout: dendrogram on the left, heatmap on the right.
    fig = plt.figure(figsize=(12, 10))
    gs = gridspec.GridSpec(1, 2, width_ratios=[1, 5], wspace=0.05)

    ax_dendro = fig.add_subplot(gs[0])
    dendrogram(linkage_matrix, orientation='left', labels=None, no_labels=True, ax=ax_dendro, color_threshold=0)
    # Put the first leaf at the top, matching heatmap row 0.
    ax_dendro.invert_yaxis()
    ax_dendro.axis('off')

    ax = fig.add_subplot(gs[1])
    sns.heatmap(
        T_display,
        cmap='OrRd',
        xticklabels=order + 1,
        yticklabels=order + 1,
        cbar_kws={"label": "P(start → exit)"},
        square=True,
        annot=False,
        ax=ax,
        mask=np.isnan(T_display)
    )

    # Mark significant cells (sig_sorted is already in display order).
    if annotate_significance:
        for i, j in zip(*np.nonzero(sig_sorted)):
            ax.text(j + 0.5, i + 0.5, '*', color='white', ha='center', va='center', fontsize=10)

    ax.set_title("Behavioral Transition Structure")
    ax.set_xlabel("exit behavior")
    ax.set_ylabel("start behavior")
    plt.tight_layout()
    plt.show()

    return T, order, linkage_matrix

In [None]:
# Recompute the cluster map with different settings (bins=200, sigma=3.5,
# percentile=30), without plotting. This rebinds labeled_map, density_map, xe and ye.
labeled_map, density_map, xe, ye = map_density(
    embedding,
    bins=200,
    sigma=3.5,
    percentile=30,
    plot=False,
)

In [None]:
def clean_cluster_labels(cluster_labels):
    """
    Keep only usable cluster labels and return them as an integer array.

    NaN entries and background entries (label <= 0) are removed; the
    surviving labels keep their original order.
    """
    is_missing = np.isnan(cluster_labels)
    is_foreground = cluster_labels > 0
    keep = np.logical_and(np.logical_not(is_missing), is_foreground)
    return cluster_labels[keep].astype(int)

In [None]:
# Label every embedded point with its density cluster (plot=True by default, so the
# centroid dendrogram is also drawn).
Z, cluster_labels = hierarchical_clustering(embedding, labeled_map, xe, ye)

# Remove background/NaN labels before counting transitions.
# NOTE(review): transitions are taken between consecutive entries of clean_labels;
# confirm sp_frames preserves temporal order, otherwise these are not consecutive frames.
clean_labels = clean_cluster_labels(cluster_labels)

# Size the transition matrix from the cleaned labels rather than the raw ones.
n_clusters = clean_labels.max()

# NOTE: Z is rebound here to the linkage of the transition-matrix rows,
# replacing the centroid linkage returned above.
T, order, Z = plot_transition_matrix_with_dendrogram(clean_labels, n_clusters=n_clusters)