# Cluster analyses

---

## 0. Environment setup

In [None]:
from __future__ import annotations

import pickle
from typing import Any, Dict, Iterable, List, Union, Mapping
from typing import Sequence, Optional
from typing import Tuple
import matplotlib.patches as mpatches

import plotly.io as pio

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns
from matplotlib.cm import ScalarMappable
from matplotlib.colors import Normalize
from sklearn.cluster import AgglomerativeClustering
from sklearn.decomposition import PCA
from sklearn.mixture import GaussianMixture
from sklearn.preprocessing import StandardScaler
from umap import UMAP
from yellowbrick.cluster import KElbowVisualizer
import plotly.graph_objects as go

# Global seaborn theme: plain white background, but keep tick marks on the
# bottom x-axis and the left y-axis, which the "white" style otherwise hides.
sns.set_style(style="white", rc={"xtick.bottom": True, "ytick.left": True})


In [None]:
# Fixed colour per bird ID (matplotlib "tab:" colour names).
bird_palette = {
    "g4r4": "tab:blue",
    "j8v8": "tab:orange",
    "o11y3": "tab:green",
    "r11n11": "tab:red",
    "r14n14": "tab:purple",
    "r15v15": "tab:brown",
}

# Fixed hex colour per integer cluster label (0-6).
cluster_palette = {
    0: "#e41a1c",
    1: "#377eb8",
    2: "#4daf4a",
    3: "#984ea3",
    4: "#ff7f00",
    5: "#ffff33",
    6: "#a65628",
}

In [None]:
def _safe_mean(x: Any) -> float:
    if x is None:
        return np.nan

    if isinstance(x, (int, float, np.number)):
        return float(x)

    try:
        arr = np.asarray(x, dtype=float)
    except Exception:
        return np.nan

    if arr.size == 0:
        return np.nan

    return float(np.nanmean(arr))

def _safe_std(x: Any) -> float:
    if x is None:
        return np.nan

    if isinstance(x, (int, float, np.number)):
        return float(x)

    try:
        arr = np.asarray(x, dtype=float)
    except Exception:
        return np.nan

    if arr.size == 0:
        return np.nan

    return float(np.nanstd(arr))


def birdsong_list_to_acoustic_df(
    songs: List[Dict[str, Any]],
    meta_keys: Iterable[str] = ("bird", "dph", "n_frames"),
    acoustic_key: str = "acoustic",
    id_col: str = "song_id",
    prefix: str = "mean_",
) -> pd.DataFrame:
    """Summarise each song's acoustic features into one DataFrame row.

    Every acoustic feature name found in any song gets two columns:
    ``f"{prefix}{feat}"`` (NaN-ignoring mean, via ``_safe_mean``) and
    ``f"std_{feat}"`` (NaN-ignoring std, via ``_safe_std``). Songs that lack
    a feature, or whose acoustic payload is missing / not a dict, get NaN.

    Args:
        songs: per-song dicts.
        meta_keys: keys copied verbatim from each song dict (NaN if absent).
        acoustic_key: key holding the ``{feature_name: values}`` dict.
        id_col: column holding each song's position in ``songs``.
        prefix: column prefix for the per-feature means.

    Returns:
        DataFrame with columns ``[id_col, *meta_keys, *sorted feature cols]``,
        one row per song (an empty frame with those columns if ``songs`` is
        empty).
    """
    # Materialise once: a generator would be exhausted by the first song and
    # silently drop metadata from every later row.
    meta_keys = list(meta_keys)

    def _acoustic_of(song: Dict[str, Any]) -> Dict[str, Any]:
        acoustic = song.get(acoustic_key, {}) or {}
        return acoustic if isinstance(acoustic, dict) else {}

    # Union of feature names across all songs, first-seen order.
    feature_names = list(
        dict.fromkeys(feat for song in songs for feat in _acoustic_of(song))
    )

    rows = []
    for song_id, song in enumerate(songs):
        acoustic = _acoustic_of(song)
        row = {id_col: song_id}
        row.update({k: song.get(k, np.nan) for k in meta_keys})
        for feat in feature_names:
            values = acoustic.get(feat)
            row[f"{prefix}{feat}"] = _safe_mean(values)
            row[f"std_{feat}"] = _safe_std(values)
        rows.append(row)

    # Build the feature column list explicitly (not by prefix matching), so
    # meta keys that happen to start with "mean_"/"std_" are never duplicated.
    feature_cols = sorted(
        [f"{prefix}{feat}" for feat in feature_names]
        + [f"std_{feat}" for feat in feature_names]
    )
    # Passing `columns=` also yields the right (empty) schema for empty input.
    return pd.DataFrame(rows, columns=[id_col] + meta_keys + feature_cols)


def birdsong_list_to_neural_df(
    songs: List[Dict[str, Any]],
    meta_keys: Iterable[str] = ("bird", "dph", "n_frames"),
    neural_key: str = "neural",
    id_col: str = "song_id",
    lag_col: str = "time_lag",
) -> pd.DataFrame:
    """Flatten per-song, per-lag neural feature vectors into a long DataFrame.

    ``song[neural_key]`` is read as a dict mapping a lag value to a payload
    dict with ``"fnames"`` (feature names) and ``"feats"`` (values aligned
    with ``fnames``). One row is emitted per (song, lag). Features absent
    for a row are NaN, as are values that cannot be converted to float.
    Songs whose neural payload is missing or not a dict contribute no rows.

    Returns:
        DataFrame with columns ``[id_col, *meta_keys, lag_col, *sorted
        feature names]``, sorted by ``(id_col, lag_col)``. Empty input gives
        an empty frame with the non-feature columns.
    """
    # Materialise once: a generator would be exhausted by the first song.
    meta_keys = list(meta_keys)
    non_feature_cols = {id_col, lag_col, *meta_keys}

    def _neural_of(song: Dict[str, Any]) -> Dict[Any, Any]:
        neural = song.get(neural_key, {}) or {}
        return neural if isinstance(neural, dict) else {}

    def _coerce_feats(feats: Any, n: int) -> np.ndarray:
        """Length-``n`` float array of ``feats[i]``; NaN where missing/unconvertible."""
        out = np.full(n, np.nan)
        if feats is None:
            return out
        arr = np.asarray(feats)
        for i in range(min(n, arr.size)):
            try:
                out[i] = float(arr[i])
            except Exception:
                pass  # best-effort: leave NaN for values that won't convert
        return out

    # Discover the union of feature names across all songs and lags.
    all_feature_names = set()
    for song in songs:
        for payload in _neural_of(song).values():
            if isinstance(payload, dict):
                all_feature_names.update(payload.get("fnames", []) or [])

    rows = []
    for song_id, song in enumerate(songs):
        meta = {k: song.get(k, np.nan) for k in meta_keys}
        for lag, payload in _neural_of(song).items():
            row = {id_col: song_id, **meta, lag_col: lag}
            # Every row carries every feature column, NaN unless filled below.
            row.update(dict.fromkeys(all_feature_names, np.nan))
            if isinstance(payload, dict):
                fnames = payload.get("fnames", []) or []
                values = _coerce_feats(payload.get("feats", None), len(fnames))
                row.update(zip(fnames, values))
            rows.append(row)

    if not rows:
        return pd.DataFrame(columns=[id_col, *meta_keys, lag_col])

    df = pd.DataFrame(rows)
    feature_cols = sorted(c for c in df.columns if c not in non_feature_cols)
    df = df[[id_col, *meta_keys, lag_col, *feature_cols]]
    return df.sort_values([id_col, lag_col], kind="mergesort").reset_index(drop=True)


In [None]:
def _build_unrolled_acoustic_matrix(
    songs: List[Dict],
    acoustic_key: str = "acoustic",
    n_frames_key: str = "n_frames",
    feature_order: Optional[List[str]] = None,
) -> Tuple[np.ndarray, List[str], int]:
    if not songs:
        return np.zeros((0, 0), dtype=np.float32), [], 0

    if feature_order is None:
        acoustic0 = songs[0].get(acoustic_key, {})
        if not isinstance(acoustic0, dict):
            raise ValueError(f"Expected '{acoustic_key}' to be a dict.")
        feature_order = list(acoustic0.keys())

    max_len = max(int(s[n_frames_key]) for s in songs)
    n_songs = len(songs)
    n_feats = len(feature_order)

    X = np.zeros((n_songs, n_feats * max_len), dtype=np.float32)

    for i, s in enumerate(songs):
        blocks = []
        acoustic = s.get(acoustic_key, {})

        for f in feature_order:
            vals = np.asarray(acoustic.get(f, []), dtype=np.float32)
            if vals.ndim != 1:
                vals = vals.ravel()

            pad = max_len - len(vals)
            if pad < 0:
                vals = vals[:max_len]
                pad = 0

            blocks.append(np.pad(vals, (0, pad)))

        X[i] = np.concatenate(blocks)

    return X, feature_order, max_len


def add_pca_umap_from_unrolled_acoustics(
    songs: List[Dict],
    acoustic_df: pd.DataFrame,
    acoustic_key: str = "acoustic",
    n_frames_key: str = "n_frames",
    n_pcs: int = 50,
    umap_min_dist: float = 0.1,
    umap_n_neighbors: Optional[int] = 15,
    random_state: int = 0,
    verbose: bool = False,
):
    """Append PCA and 2-D UMAP embeddings of the unrolled acoustic features.

    The frame-wise features of every song are unrolled into one row
    (``_build_unrolled_acoustic_matrix``), z-scored, and reduced with PCA to
    at most ``n_pcs`` components. UMAP (2 components, ``init="pca"``) is then
    fit on the PCA scores.

    Added columns are ``pc_1..pc_k``, ``pca_evr_j`` / ``pca_cum_evr_j``,
    ``pca_total_evr`` and ``umap_1`` / ``umap_2``. The explained-variance
    values are the same for every row.

    Args:
        songs: per-song dicts. Row ``i`` of ``acoustic_df`` must describe
            ``songs[i]``; results are aligned to ``acoustic_df.index`` by
            position.
        umap_n_neighbors: if ``None``, uses ~sqrt(n_songs), clamped to
            ``[2, n_songs - 1]``.
        verbose: print the shape of the unrolled matrix.

    Returns:
        ``(df_out, pca, reducer, scaler)``. For empty input this is
        ``(acoustic_df.copy(), None, None, None)``.

    Raises:
        ValueError: if ``songs`` and ``acoustic_df`` differ in length.
    """
    if acoustic_df.empty or not songs:
        return acoustic_df.copy(), None, None, None

    if len(songs) != len(acoustic_df):
        raise ValueError(
            f"`songs` ({len(songs)}) and `acoustic_df` ({len(acoustic_df)} rows) "
            "must describe the same songs in the same order."
        )

    # ---------- Unrolled acoustic matrix ----------
    X, _, _ = _build_unrolled_acoustic_matrix(
        songs,
        acoustic_key=acoustic_key,
        n_frames_key=n_frames_key,
        feature_order=None,
    )
    if verbose:
        print(X.shape)

    # ---------- Z-score ----------
    scaler = StandardScaler()
    Xz = scaler.fit_transform(X)

    # ---------- PCA ----------
    # PCA cannot return more components than min(n_samples, n_features).
    n_pcs_eff = min(n_pcs, Xz.shape[0], Xz.shape[1])
    pca = PCA(n_components=n_pcs_eff, random_state=random_state)
    pcs = pca.fit_transform(Xz)

    # ---------- Build PCA columns ----------
    pca_cols = {f"pc_{j+1}": pcs[:, j] for j in range(n_pcs_eff)}

    evr = pca.explained_variance_ratio_
    cum = np.cumsum(evr)
    for j in range(n_pcs_eff):
        pca_cols[f"pca_evr_{j+1}"] = float(evr[j])
        pca_cols[f"pca_cum_evr_{j+1}"] = float(cum[j])
    pca_cols["pca_total_evr"] = float(cum[-1]) if cum.size else np.nan

    pca_df = pd.DataFrame(pca_cols, index=acoustic_df.index)

    # ---------- UMAP ----------
    if umap_n_neighbors is None:
        umap_n_neighbors = min(max(2, int(np.sqrt(len(pca_df)))), len(pca_df) - 1)
    reducer = UMAP(
        n_components=2,
        min_dist=umap_min_dist,
        n_neighbors=umap_n_neighbors,
        random_state=random_state,
        init="pca",
    )
    umap_emb = reducer.fit_transform(pcs)
    umap_df = pd.DataFrame(umap_emb, columns=["umap_1", "umap_2"], index=acoustic_df.index)

    # ---------- Single concat (no fragmentation) ----------
    df_out = pd.concat([acoustic_df, pca_df, umap_df], axis=1)
    return df_out, pca, reducer, scaler


In [None]:
def filter_songs_by_n_frames(
    songs: List[Dict],
    max_frames: int,
    n_frames_key: str = "n_frames",
) -> List[Dict]:
    """Return the songs whose frame count is a number ``<= max_frames``.

    Songs whose ``n_frames_key`` is missing, ``None`` or non-numeric (e.g. a
    string) are dropped. NumPy integer/floating scalars are accepted as well
    as Python ``int``/``float``. ``np.int64`` and ``np.float32`` are not
    subclasses of ``int``/``float``, so they have to be listed explicitly.
    The original song dicts are returned, in their original order.
    """
    numeric_types = (int, float, np.integer, np.floating)

    def _keep(song: Dict) -> bool:
        n_frames = song.get(n_frames_key)
        return isinstance(n_frames, numeric_types) and n_frames <= max_frames

    return [song for song in songs if _keep(song)]


In [None]:
def add_mean_umap(data, min_dist=0.1, n_neighbors=15, random_state=1234):
    """Add a 2-D UMAP embedding of the ``mean_*`` feature columns of ``data``.

    The ``mean_*`` columns are z-scored and embedded with UMAP. The result is
    written to ``umap_mean_1`` / ``umap_mean_2``. ``data`` is modified in
    place and also returned.

    Raises:
        ValueError: if ``data`` has no ``mean_*`` columns.
    """
    # Take the columns from `data` itself (not from a notebook global), so the
    # function works on whatever frame is passed in.
    mean_cols = [c for c in data.columns if str(c).startswith("mean_")]
    if not mean_cols:
        raise ValueError("`data` has no 'mean_*' columns to embed.")

    X_mean = StandardScaler().fit_transform(data[mean_cols].to_numpy())
    umap_mean = UMAP(
        n_neighbors=n_neighbors,
        min_dist=min_dist,
        random_state=random_state,
    ).fit_transform(X_mean)

    data["umap_mean_1"] = umap_mean[:, 0]
    data["umap_mean_2"] = umap_mean[:, 1]
    return data

In [None]:
def add_neural_umap(data, feature_columns, lag_column, min_dist=0.1, n_neighbors=15, random_state=1234):
    """Add one 2-D UMAP embedding per distinct lag value to ``data`` (in place).

    For each distinct value ``lag`` of ``data[lag_column]``, the rows with
    that lag are z-scored on ``feature_columns`` and embedded with UMAP. The
    embedding is written to ``umap_lag{lag}_1`` / ``umap_lag{lag}_2`` for
    those rows only. ``data`` is also returned.
    """
    lags = data.loc[:, lag_column].unique()
    for lag in lags:
        in_lag = data.loc[:, lag_column] == lag
        subset = data.loc[in_lag, feature_columns]
        scaled = StandardScaler().fit_transform(subset.to_numpy())
        reducer = UMAP(
            n_neighbors=n_neighbors,
            min_dist=min_dist,
            random_state=random_state,
        )
        embedding = reducer.fit_transform(scaled)
        for dim in (1, 2):
            data.loc[subset.index, f"umap_lag{lag}_{dim}"] = embedding[:, dim - 1]
    return data

In [None]:
def seaborn_scatter_with_colorbar(
    data,
    x,
    y,
    hue,
    *,
    percentiles=(2, 98),
    cmap="viridis",
    s=4,
    figsize=(9, 6),
    ax=None,
    cbar_label=None,
    title=None,
    vmin=None,
    vmax=None,
):
    """Scatter ``x`` vs ``y`` coloured by a continuous ``hue`` column, with a colourbar.

    The colour range is ``[vmin, vmax]``. Any limit not supplied by the
    caller is taken from ``percentiles`` of ``data[hue]``, ignoring NaNs.

    Returns:
        ``(fig, ax, vmin, vmax)`` so the same limits can be reused on other plots.
    """
    if ax is None:
        fig, ax = plt.subplots(figsize=figsize)
    else:
        fig = ax.figure

    if vmin is None or vmax is None:
        vals = np.asarray(data[hue], dtype=float)
        # nanpercentile: one NaN hue value must not turn both limits into NaN.
        lo, hi = np.nanpercentile(vals, percentiles)
        # Only fill in the limit(s) the caller did not provide.
        vmin = lo if vmin is None else vmin
        vmax = hi if vmax is None else vmax

    sns.scatterplot(
        data=data,
        x=x,
        y=y,
        hue=hue,
        palette=cmap,
        s=s,
        ax=ax,
        legend=False,
        hue_norm=(vmin, vmax),
    )

    ax.set_xlabel(x)
    ax.set_ylabel(y)

    if title is not None:
        ax.set_title(title)

    # Standalone mappable so the colourbar matches the scatter's hue_norm.
    norm = Normalize(vmin=vmin, vmax=vmax)
    sm = ScalarMappable(norm=norm, cmap=cmap)
    sm.set_array([])

    cbar = fig.colorbar(sm, ax=ax)
    cbar.set_label(cbar_label or hue.replace("_", " ").replace("-", " ").capitalize())

    return fig, ax, vmin, vmax


In [None]:
def plot_k_elbow(
    df: pd.DataFrame,
    feature_cols: Sequence[str],
    model,
    k_range: Tuple[int, int] = (4, 12),
    metric: str = "calinski_harabasz",
    standardize: bool = True,
    timings: bool = False,
):
    """Run yellowbrick's elbow analysis for ``model`` on ``df[feature_cols]``.

    The features are optionally z-scored. ``k_range`` is forwarded to
    ``KElbowVisualizer`` as ``k``. The plot is shown immediately.

    Returns:
        ``(visualizer, visualizer.elbow_value_)``.

    Raises:
        ValueError: if any of ``feature_cols`` is not a column of ``df``.
    """
    absent = [col for col in feature_cols if col not in df.columns]
    if absent:
        raise ValueError(f"Missing feature columns: {absent}")

    features = df[list(feature_cols)].to_numpy()
    if standardize:
        features = StandardScaler().fit_transform(features)

    visualizer = KElbowVisualizer(model, k=k_range, metric=metric, timings=timings)
    visualizer.fit(features)
    visualizer.show()
    return visualizer, visualizer.elbow_value_


In [None]:
def add_cluster_labels(
    df: pd.DataFrame,
    feature_cols: Sequence[str],
    model,
    k: int,
    k_param: str = "n_clusters",   # use "n_components" for GaussianMixture
    cluster_col: str = "cluster",
    standardize: bool = True,
):
    """Fit ``model`` with ``k`` clusters on ``df[feature_cols]`` and label each row.

    ``model`` is reconfigured via ``set_params(**{k_param: k})`` and fitted in
    place on the (optionally z-scored) features. Labels come from
    ``model.labels_`` when it is set, otherwise from ``model.predict``.

    Returns:
        ``(copy of df with a cluster_col column, fitted model)``.

    Raises:
        ValueError: missing feature columns, or ``k_param`` not a model parameter.
        RuntimeError: model exposes neither ``labels_`` nor ``predict()``.
    """
    absent = [col for col in feature_cols if col not in df.columns]
    if absent:
        raise ValueError(f"Missing feature columns: {absent}")

    if k_param not in model.get_params(deep=True):
        raise ValueError(f"Model does not have parameter '{k_param}'.")

    features = df[list(feature_cols)].to_numpy()
    if standardize:
        features = StandardScaler().fit_transform(features)

    model.set_params(**{k_param: int(k)})
    model.fit(features)

    # Prefer labels_ (e.g. agglomerative/k-means); fall back to predict().
    labels = getattr(model, "labels_", None)
    if labels is None and not hasattr(model, "predict"):
        raise RuntimeError("Model has neither labels_ nor predict().")
    if labels is None:
        labels = model.predict(features)

    labelled = df.copy()
    labelled[cluster_col] = labels
    return labelled, model


In [None]:
def select_gmm_k_bic(
    df: pd.DataFrame,
    feature_cols: Sequence[str],
    k_range: Tuple[int, int] = (2, 12),
    covariance_type: str = "full",
    standardize: bool = True,
    random_state: int = 0,
    n_init: int = 3,
    reg_covar: float = 1e-6,
    plot: bool = True,
):
    """Pick the GaussianMixture component count with the lowest BIC.

    Fits one GMM per ``k`` in ``k_range[0]..k_range[1]`` (inclusive) on the
    optionally z-scored ``df[feature_cols]`` and scores each fit with BIC.

    Returns:
        ``(best_k, bic_table)`` where ``bic_table`` has columns ``k`` and
        ``bic`` and ``best_k`` is the ``k`` with the minimum BIC.

    Raises:
        ValueError: if any of ``feature_cols`` is not a column of ``df``.
    """
    absent = [col for col in feature_cols if col not in df.columns]
    if absent:
        raise ValueError(f"Missing feature columns: {absent}")

    features = df[list(feature_cols)].to_numpy()
    if standardize:
        features = StandardScaler().fit_transform(features)

    k_lo, k_hi = int(k_range[0]), int(k_range[1])
    ks = list(range(k_lo, k_hi + 1))

    bics = []
    for n_comp in ks:
        model = GaussianMixture(
            n_components=n_comp,
            covariance_type=covariance_type,
            random_state=random_state,
            n_init=n_init,
            reg_covar=reg_covar,
        )
        model.fit(features)
        bics.append(model.bic(features))

    bic_table = pd.DataFrame({"k": ks, "bic": bics})
    best_k = int(bic_table.loc[bic_table["bic"].idxmin(), "k"])

    if not plot:
        return best_k, bic_table

    plt.figure(figsize=(6, 4))
    plt.plot(bic_table["k"], bic_table["bic"], marker="o")
    plt.xlabel("n_components (k)")
    plt.ylabel("BIC (lower is better)")
    plt.title(f"GMM BIC selection (best k = {best_k})")
    plt.tight_layout()
    plt.show()
    return best_k, bic_table


In [None]:
def add_gmm_labels(
    df: pd.DataFrame,
    feature_cols: Sequence[str],
    k: int,
    cluster_col: str = "gmm_cluster",
    covariance_type: str = "full",
    standardize: bool = True,
    random_state: int = 0,
    n_init: int = 3,
    reg_covar: float = 1e-6,
    add_probs: bool = False,
    prob_prefix: str = "gmm_prob_",
):
    """Fit a ``k``-component GaussianMixture and attach its labels to a copy of ``df``.

    The hard assignment goes to ``cluster_col``. With ``add_probs=True``, the
    per-component posterior probabilities are also added as
    ``f"{prob_prefix}{j}"`` for ``j = 0..k-1``.

    Returns:
        ``(labelled copy of df, fitted GaussianMixture)``.

    Raises:
        ValueError: if any of ``feature_cols`` is not a column of ``df``.
    """
    absent = [col for col in feature_cols if col not in df.columns]
    if absent:
        raise ValueError(f"Missing feature columns: {absent}")

    features = df[list(feature_cols)].to_numpy()
    if standardize:
        features = StandardScaler().fit_transform(features)

    mixture = GaussianMixture(
        n_components=int(k),
        covariance_type=covariance_type,
        random_state=random_state,
        n_init=n_init,
        reg_covar=reg_covar,
    )
    mixture.fit(features)

    labelled = df.copy()
    labelled[cluster_col] = mixture.predict(features)

    if add_probs:
        posterior = mixture.predict_proba(features)
        for comp in range(posterior.shape[1]):
            labelled[f"{prob_prefix}{comp}"] = posterior[:, comp]

    return labelled, mixture


In [None]:
def matrixplot_median_by_cluster(
    df: pd.DataFrame,
    cluster_col: str,
    feature_cols: list[str],
    *,
    sort_clusters: bool = True,
    sort_features: bool = False,
    normalize: str | None = None,   # None | "rows" | "columns"
    cmap: str = "viridis",
    figsize: tuple[float, float] | None = None,
    vmin: float | None = None,
    vmax: float | None = None,
    cbar_label: str | None = None,
    annot: bool = False,
    annot_fmt: str = ".2f",
):
    """Heatmap of the median of each feature within each cluster.

    Args:
        sort_clusters: sort cluster rows by label. If False, clusters appear
            in order of first occurrence in ``df``.
        sort_features: sort feature columns by name.
        normalize: ``"columns"`` min-max scales each feature across clusters.
            ``"rows"`` min-max scales each cluster across features. ``None``
            plots raw medians. Constant lines become NaN (zero range).

    Returns:
        ``(fig, ax, med)`` where ``med`` is the raw (un-normalised) median table.

    Raises:
        ValueError: missing columns, or an unknown ``normalize`` value.
    """
    missing = [c for c in [cluster_col, *feature_cols] if c not in df.columns]
    if missing:
        raise ValueError(f"Missing columns: {missing}")

    # --- medians; groupby sorts keys by default, so pass the flag through
    # (otherwise sort_clusters=False would have no effect) ---
    med = df.groupby(cluster_col, sort=sort_clusters)[feature_cols].median()
    if sort_features:
        med = med.reindex(sorted(med.columns), axis=1)

    plot_mat = med.copy()

    # --- optional min-max normalisation ---
    if normalize is None:
        cbar_label = cbar_label or "median value"
    elif normalize in ("columns", "rows"):
        # Reduce along `axis`, then broadcast the result back along the other axis.
        axis = 0 if normalize == "columns" else 1
        other = 1 - axis
        lo = plot_mat.min(axis=axis)
        span = (plot_mat.max(axis=axis) - lo).replace(0, np.nan)  # avoid divide-by-zero
        plot_mat = plot_mat.sub(lo, axis=other).div(span, axis=other)
        default_label = "min–max (per feature)" if normalize == "columns" else "min–max (per cluster)"
        cbar_label = cbar_label or default_label
        if vmin is None and vmax is None:
            vmin, vmax = 0, 1
    else:
        raise ValueError("normalize must be one of: None, 'rows', 'columns'")

    # --- figure size heuristic: scale with number of features / clusters ---
    if figsize is None:
        figsize = (max(6, 0.28 * plot_mat.shape[1]), max(3, 0.35 * plot_mat.shape[0]))

    fig, ax = plt.subplots(figsize=figsize)

    sns.heatmap(
        plot_mat,
        ax=ax,
        cmap=cmap,
        vmin=vmin,
        vmax=vmax,
        cbar_kws={"label": cbar_label},
        annot=annot,
        fmt=annot_fmt,
        linewidths=0.5,
        linecolor="white",
    )

    ax.set_xlabel("Features")
    ax.set_ylabel(cluster_col)
    ax.set_title(f"Median feature values by {cluster_col}")
    ax.set_xticklabels(ax.get_xticklabels(), rotation=90)
    ax.set_yticklabels(ax.get_yticklabels(), rotation=0)

    plt.tight_layout()
    return fig, ax, med


In [None]:
def plot_cluster_percent_by_bird(
    df: pd.DataFrame,
    bird_col: str,
    cluster_col: str,
    ax: plt.Axes | None = None,
    palette: str | dict = "Set1",
    label_fmt: str = "{:.0f}%",
    min_label_pct: float = 5.0,
    sort_birds: bool = True,
    sort_clusters: bool = True,
):
    """Draw one horizontal 100 %-stacked bar per bird showing its cluster mix.

    Rows with NaN in either column are ignored. Segments of at least
    ``min_label_pct`` percent are labelled in black using ``label_fmt``.
    ``palette`` is either a colormap name (sampled evenly across clusters)
    or a dict ``{cluster: colour}`` that must cover every cluster.

    Returns:
        The matplotlib Axes drawn on (a new figure is created if ``ax`` is None).

    Raises:
        ValueError: a column is missing, or a palette dict lacks a cluster.
    """
    for required in (bird_col, cluster_col):
        if required not in df.columns:
            raise ValueError(f"'{required}' not found in dataframe.")

    pairs = df[[bird_col, cluster_col]].dropna()
    counts = pd.crosstab(pairs[bird_col], pairs[cluster_col])
    # Row-normalise: each bird's clusters sum to 100 %.
    pct = counts.div(counts.sum(axis=1), axis=0) * 100

    if sort_birds:
        pct = pct.sort_index(axis=0)
    if sort_clusters:
        pct = pct.sort_index(axis=1)

    if ax is None:
        _, ax = plt.subplots(figsize=(10, max(3, 0.35 * pct.shape[0])))

    clusters = list(pct.columns)

    if isinstance(palette, dict):
        uncovered = [c for c in clusters if c not in palette]
        if uncovered:
            raise ValueError(f"Palette dict missing colors for clusters: {uncovered}")
        colors = [palette[c] for c in clusters]
    else:
        colors = plt.get_cmap(palette)(np.linspace(0, 1, len(clusters)))

    n_birds = pct.shape[0]
    y = np.arange(n_birds)
    left = np.zeros(n_birds)

    for cluster, color in zip(clusters, colors):
        widths = pct[cluster].to_numpy()
        ax.barh(y, widths, left=left, color=color, edgecolor="white", label=str(cluster))

        # Centre each label inside its segment; annotations are always black.
        centres = left + widths / 2
        for y_pos, x_pos, width in zip(y, centres, widths):
            if width >= min_label_pct:
                ax.text(
                    x_pos,
                    y_pos,
                    label_fmt.format(width),
                    ha="center",
                    va="center",
                    fontsize=9,
                    color="black",
                )

        left = left + widths

    ax.set_yticks(y)
    ax.set_yticklabels(pct.index.astype(str))
    ax.set_xlim(0, 100)
    ax.set_xlabel("Percentage of songs")
    ax.set_ylabel(bird_col)
    ax.set_title(f"Cluster composition by {bird_col} ({cluster_col})")
    ax.legend(title="Cluster", bbox_to_anchor=(1.02, 1), loc="upper left", frameon=False)

    return ax


In [None]:
def plot_sankey_clusters(
    df,
    bird_col="bird",
    cluster1_col="ward_clusters_lag0",
    cluster2_col="ward_clusters",
    unit_col=None,
    unit_how="mode",  # "mode" or "first"
    cluster1_label="neural",
    cluster2_label="acoustic",
    bird_colors=None,        # dict: bird -> colour (hex, matplotlib name, rgb/rgba string)
    cluster1_colors=None,    # dict: cluster1 -> colour
    cluster2_colors=None,    # dict: cluster2 -> colour
    default_node_color="#d9d9d9",
    link_alpha=0.35,         # opacity for links (0..1)
    title=None,
    height=650,
):
    """Draw a bird -> cluster1 -> cluster2 Sankey diagram with plotly.

    Each row of ``df`` contributes one unit of flow from its bird to its
    ``cluster1_col`` label, and from that label to its ``cluster2_col`` label.
    With ``unit_col``, rows are first collapsed to one per (bird, unit) using
    ``unit_how`` (``"mode"`` or ``"first"``). Rows with NaN in any used column
    are dropped, and all values are compared as strings. Numeric labels are
    ordered numerically.

    Colour dict keys are matched as strings. Values may be any matplotlib
    colour spec (``"#e41a1c"``, ``"#abc"``, ``"tab:blue"``, ...) or a plotly
    ``"rgb(...)"`` / ``"rgba(...)"`` string. Each link takes its source
    node's colour at opacity ``link_alpha``; an explicit ``"rgba(...)"`` keeps
    its own alpha. Unparseable link colours fall back to black.

    Returns:
        ``(plotly Figure, DataFrame the flows were counted from)``.

    Raises:
        ValueError: if ``unit_how`` is not ``"mode"`` or ``"first"``.
    """

    # --- helpers ---
    def _norm_map(m):
        # Data is converted to str below, so compare keys as str too.
        return None if m is None else {str(k): v for k, v in m.items()}

    def _parse_color(color):
        """Return (r, g, b, a) with r/g/b in 0-255, or None if not understood."""
        if color is None:
            return None
        if isinstance(color, str):
            color = color.strip()
            if color.startswith("rgb(") and color.endswith(")"):
                try:
                    r, g, b = (float(p) for p in color[4:-1].split(","))
                except ValueError:
                    return None
                return r, g, b, 1.0
        # Local import: matplotlib is already a notebook dependency and is
        # only needed here.
        from matplotlib.colors import to_rgba

        try:
            r, g, b, a = to_rgba(color)
        except (ValueError, TypeError):
            return None
        return r * 255, g * 255, b * 255, a

    def _rgba_str(r, g, b, a):
        return f"rgba({round(r)},{round(g)},{round(b)},{a})"

    def _is_rgba_str(color):
        return isinstance(color, str) and color.strip().startswith("rgba(")

    def _node_color(color):
        # plotly does not know matplotlib-only names such as "tab:blue";
        # convert whatever matplotlib can parse and pass the rest through so
        # plotly reports it.
        if _is_rgba_str(color):
            return color.strip()
        parsed = _parse_color(color)
        return color if parsed is None else _rgba_str(*parsed)

    def _link_color(color, alpha):
        if _is_rgba_str(color):
            return color.strip()  # explicit alpha wins
        parsed = _parse_color(color)
        if parsed is None:
            return f"rgba(0,0,0,{alpha})"
        r, g, b, _ = parsed
        return _rgba_str(r, g, b, alpha)

    def _label_sort_key(label):
        # Numeric labels first, ordered numerically ("2" before "10");
        # everything else afterwards, lexicographically.
        try:
            value = float(label)
        except ValueError:
            return (1, 0.0, label)
        return (0, value, label) if value == value else (1, 0.0, label)

    bird_colors = _norm_map(bird_colors)
    cluster1_colors = _norm_map(cluster1_colors)
    cluster2_colors = _norm_map(cluster2_colors)

    def _lookup(mapping, key):
        return mapping.get(key, default_node_color) if mapping else default_node_color

    # --- prep data ---
    cols = [bird_col, cluster1_col, cluster2_col] + ([unit_col] if unit_col else [])
    d = df[cols].dropna().copy()
    for c in cols:
        d[c] = d[c].astype(str)

    # collapse to one row per unit if requested (e.g. per-frame/per-window data)
    if unit_col:
        def _mode(s):
            # Most frequent label; ties resolved by value_counts' ordering.
            return s.value_counts().index[0]

        if unit_how == "mode":
            d = (d.groupby([bird_col, unit_col], as_index=False)
                   .agg({cluster1_col: _mode, cluster2_col: _mode}))
        elif unit_how == "first":
            d = (d.sort_values([bird_col, unit_col])
                   .drop_duplicates([bird_col, unit_col], keep="first"))
        else:
            raise ValueError("unit_how must be 'mode' or 'first'")

    # --- aggregate flows ---
    b2c1 = d.groupby([bird_col, cluster1_col]).size().reset_index(name="value")
    c1c2 = d.groupby([cluster1_col, cluster2_col]).size().reset_index(name="value")

    # --- nodes (namespaced labels so equal values in different groups stay distinct) ---
    birds = sorted(d[bird_col].unique(), key=_label_sort_key)
    c1_vals = sorted(d[cluster1_col].unique(), key=_label_sort_key)
    c2_vals = sorted(d[cluster2_col].unique(), key=_label_sort_key)

    nodes = (
        [(f"bird: {b}", _lookup(bird_colors, b)) for b in birds]
        + [(f"{cluster1_label}: {c}", _lookup(cluster1_colors, c)) for c in c1_vals]
        + [(f"{cluster2_label}: {c}", _lookup(cluster2_colors, c)) for c in c2_vals]
    )
    labels = [label for label, _ in nodes]
    node_colors = [_node_color(color) for _, color in nodes]
    idx = {label: i for i, label in enumerate(labels)}

    # --- links; each link is coloured like its source node ---
    sources, targets, values, link_colors = [], [], [], []

    for b, c1, v in zip(b2c1[bird_col], b2c1[cluster1_col], b2c1["value"]):
        sources.append(idx[f"bird: {b}"])
        targets.append(idx[f"{cluster1_label}: {c1}"])
        values.append(int(v))
        link_colors.append(_link_color(_lookup(bird_colors, b), link_alpha))

    for c1, c2, v in zip(c1c2[cluster1_col], c1c2[cluster2_col], c1c2["value"]):
        sources.append(idx[f"{cluster1_label}: {c1}"])
        targets.append(idx[f"{cluster2_label}: {c2}"])
        values.append(int(v))
        link_colors.append(_link_color(_lookup(cluster1_colors, c1), link_alpha))

    if title is None:
        title = f"Agreement: bird → {cluster1_label} → {cluster2_label}"

    fig = go.Figure(go.Sankey(
        node=dict(
            label=labels,
            color=node_colors,
            pad=12,
            thickness=14,
            line=dict(color="rgba(0,0,0,0.25)", width=0.5),
        ),
        link=dict(
            source=sources,
            target=targets,
            value=values,
            color=link_colors,
        ),
    ))

    fig.update_layout(title=title, font_size=11, height=height)
    fig.show()
    return fig, d


In [None]:
def corr_by_bird(
    df: pd.DataFrame,
    features: list[str],
    group_col: str = "bird",
    target_col: str = "neural_change_statistic",
    method: str = "pearson",
    min_n: int = 3
) -> pd.DataFrame:
    """Correlate each feature with ``target_col`` separately within each group.

    Rows whose ``target_col`` equals 0 are excluded first. Groups left with
    fewer than ``min_n`` rows get an all-NaN row.

    Returns:
        DataFrame indexed by group (in order of first appearance) with one
        column per feature holding the ``method`` correlation.
    """
    def _group_corr(name, group):
        group = group[group[target_col] != 0]
        if len(group) < min_n:
            return pd.Series(index=features, dtype=float, name=name)
        return group[features].corrwith(group[target_col], method=method).rename(name)

    return pd.DataFrame(
        [_group_corr(name, group) for name, group in df.groupby(group_col, sort=False)]
    )


In [None]:
def merged_correlation_matrices(
    acoustic: pd.DataFrame,
    neural: pd.DataFrame,
    id_col: str = "song_id",
    how: str = "inner",
    suffixes: Tuple[str, str] = ("_acoustic", "_neural"),
    numeric_only: bool = True,
    drop_constant: bool = True,
    min_periods: int = 2,
) -> Dict[str, pd.DataFrame]:
    """Join acoustic and neural tables on ``id_col`` and correlate all features.

    After the merge, ``id_col`` is dropped. Non-numeric columns
    (``numeric_only``) and columns with at most one distinct non-NaN value
    (``drop_constant``) are removed.

    Returns:
        ``{"pearson": corr_matrix, "spearman": corr_matrix}``.

    Raises:
        KeyError: ``id_col`` missing from either frame.
        ValueError: ``id_col`` not unique in either frame, or no columns left.
    """
    frames = (("acoustic", acoustic), ("neural", neural))
    for label, frame in frames:
        if id_col not in frame.columns:
            raise KeyError(f"`{id_col}` not found in `{label}` columns.")
    for label, frame in frames:
        if not frame[id_col].is_unique:
            raise ValueError(f"`{id_col}` is not unique in `{label}`.")

    features = acoustic.merge(neural, on=id_col, how=how, suffixes=suffixes).drop(columns=[id_col])

    if numeric_only:
        features = features.select_dtypes(include=[np.number])

    # A constant column has an undefined (NaN) correlation.
    if drop_constant and features.shape[1] > 0:
        features = features.loc[:, features.nunique(dropna=True) > 1]

    if features.shape[1] == 0:
        raise ValueError(
            "No usable feature columns remain after dropping id / filtering numeric / dropping constants."
        )

    return {
        name: features.corr(method=name, min_periods=min_periods)
        for name in ("pearson", "spearman")
    }

In [None]:
Color = Any  # matplotlib accepts many "color-like" specs


def clustered_corr_heatmap(
    corr: pd.DataFrame,
    *,
    method: str = "average",
    metric: str = "euclidean",
    center: float = 0.0,
    cmap: Union[str, sns.palettes._ColorPalette] = "vlag",
    vmin: Optional[float] = -1.0,
    vmax: Optional[float] = 1.0,
    figsize: Tuple[float, float] = (10, 10),
    linewidths: float = 0.0,
    linecolor: str = "white",
    annot: bool = False,
    fmt: str = ".2f",
    annot_kws: Optional[dict] = None,
    dendrogram_ratio: Tuple[float, float] = (0.15, 0.15),
    cbar_pos: Tuple[float, float, float, float] = (0.02, 0.8, 0.05, 0.18),
    cbar_kws: Optional[dict] = None,
    xticklabels: Union[bool, int] = True,
    yticklabels: Union[bool, int] = True,
    row_cluster: bool = True,
    col_cluster: bool = True,
    robust: bool = False,
    rasterized: bool = False,
    # feature category annotations
    feature_to_category: Optional[Mapping[str, str]] = None,
    unknown_category: str = "Unknown",
    # palette can be str/list OR dict(category->color)
    category_palette: Union[str, list, Mapping[str, Color]] = "tab20",
    category_order: Optional[list] = None,
    unknown_color: Color = "lightgray",
    # legend
    show_category_legend: bool = True,
    legend_title: str = "Feature category",
    legend_loc: str = "upper right",
    legend_bbox_to_anchor: Tuple[float, float] = (1.02, 1.0),
    # title
    title: Optional[str] = None,
    title_kws: Optional[dict] = None,
    title_top: float = 0.92,
    title_y: float = 0.98,
) -> sns.matrix.ClusterGrid:
    """Hierarchically clustered heatmap of a square correlation matrix.

    Optionally annotates rows/columns with feature-category colour strips.
    ``feature_to_category`` maps feature name to category; unmapped features
    get ``unknown_category``. ``category_palette`` is either a seaborn palette
    name/list (assigned in ``category_order``) or an explicit
    ``{category: colour}`` dict that must cover every present category.
    Categories present in ``corr`` but missing from a user-supplied
    ``category_order`` are appended to it, so every feature gets a colour.

    Returns:
        The seaborn ``ClusterGrid``.

    Raises:
        ValueError: ``corr`` is not square, its index/columns differ, or a
            palette dict lacks a present category.
    """
    if corr.shape[0] != corr.shape[1]:
        raise ValueError("`corr` must be a square matrix.")
    if list(corr.index) != list(corr.columns):
        raise ValueError("`corr` index and columns should match and be in the same order.")

    row_colors = col_colors = None
    cat2color: Optional[Dict[str, Color]] = None
    categories = None

    if feature_to_category is not None:
        categories = pd.Series(
            [feature_to_category.get(col, unknown_category) for col in corr.columns],
            index=corr.columns,
            name="category",
        )

        # Categories present in this matrix, in first-seen order.
        present = list(dict.fromkeys(categories.tolist()))
        if category_order is None:
            category_order = present
        else:
            # Copy (never mutate the caller's list) and append omitted present
            # categories; otherwise those features would map to NaN colours.
            category_order = list(category_order) + [
                c for c in present if c not in category_order
            ]

        if isinstance(category_palette, Mapping):
            # User-supplied explicit mapping: {category: color}
            cat2color = dict(category_palette)
            if unknown_category in category_order and unknown_category not in cat2color:
                cat2color[unknown_category] = unknown_color

            missing = [c for c in category_order if c not in cat2color]
            if missing:
                raise ValueError(
                    "category_palette dict is missing colors for categories: "
                    + ", ".join(map(str, missing))
                )
        else:
            # Palette name or list of colours, assigned by position in category_order.
            pal = sns.color_palette(category_palette, n_colors=len(category_order))
            cat2color = {cat: pal[i] for i, cat in enumerate(category_order)}
            if unknown_category in category_order:
                cat2color[unknown_category] = unknown_color

        feature_colors = categories.map(cat2color)
        row_colors = feature_colors.reindex(corr.index)
        col_colors = feature_colors.reindex(corr.columns)

    cg = sns.clustermap(
        corr,
        method=method,
        metric=metric,
        cmap=cmap,
        center=center,
        vmin=vmin,
        vmax=vmax,
        figsize=figsize,
        linewidths=linewidths,
        linecolor=linecolor,
        annot=annot,
        fmt=fmt,
        annot_kws=annot_kws,
        dendrogram_ratio=dendrogram_ratio,
        cbar_pos=cbar_pos,
        cbar_kws=cbar_kws,
        xticklabels=xticklabels,
        yticklabels=yticklabels,
        row_cluster=row_cluster,
        col_cluster=col_cluster,
        robust=robust,
        rasterized=rasterized,
        row_colors=row_colors,
        col_colors=col_colors,
    )

    cg.ax_heatmap.set_xlabel("")
    cg.ax_heatmap.set_ylabel("")
    plt.setp(cg.ax_heatmap.get_xticklabels(), rotation=90)
    plt.setp(cg.ax_heatmap.get_yticklabels(), rotation=0)

    if title is not None:
        if title_kws is None:
            title_kws = {"fontsize": 14, "fontweight": "bold"}
        # Make room above the dendrogram for the suptitle.
        cg.fig.subplots_adjust(top=title_top)
        cg.fig.suptitle(title, y=title_y, **title_kws)

    if show_category_legend and cat2color is not None:
        # Only list categories that actually occur, in category_order.
        present_set = set(categories)
        handles = [
            mpatches.Patch(color=cat2color[c], label=str(c))
            for c in category_order
            if c in present_set
        ]
        cg.ax_heatmap.legend(
            handles=handles,
            title=legend_title,
            loc=legend_loc,
            bbox_to_anchor=legend_bbox_to_anchor,
            borderaxespad=0.0,
            frameon=True,
        )

    return cg


In [None]:
def abs_feature_change_from_baseline_day(
    df: pd.DataFrame,
    feature_cols: List[str],
    bird_col: str = "bird",
    day_col: str = "day",
    suffix: str = "_abs_change_in_pct",
) -> pd.DataFrame:
    """Compute each feature's absolute relative change from the bird's first day.

    ``df`` is first averaged per (bird, day). For every bird, the row with the
    lowest ``day_col`` value is the baseline. For every feature ``c`` a new
    column ``f"{c}{suffix}"`` is added, holding::

        |mean_c(day) - mean_c(first_day)| / (|mean_c(first_day)| + 1e-8)

    Despite the default suffix, the value is a fraction (1.0 == 100 %), not a
    percentage. If a feature's mean is missing on the first day, its change
    column is NaN for that bird; no later day is substituted.

    Parameters
    ----------
    df : pd.DataFrame
        Input rows containing ``bird_col``, ``day_col`` and ``feature_cols``.
    feature_cols : List[str]
        Numeric columns to average and compare against the baseline day.
    bird_col : str
        Column identifying the individual.
    day_col : str
        Column whose smallest value per bird defines the baseline day.
    suffix : str
        Suffix appended to each feature name for the new change columns.

    Returns
    -------
    pd.DataFrame
        One row per (bird, day), sorted by bird then day, with a fresh
        RangeIndex. It contains the grouping columns, the per-day feature means
        and the change columns.
    """
    # 1) mean per bird/day
    out = (
        df[[bird_col, day_col, *feature_cols]]
        .groupby([bird_col, day_col], as_index=False)
        .mean()
    )

    # 2) sort so the first row of each bird is its lowest day
    out = out.sort_values([bird_col, day_col])

    # 3) baseline = the first-day row of each bird, broadcast to all its rows.
    #    ``groupby(...).transform("first")`` would instead take the first
    #    non-NaN value per column. That silently mixes baseline days across
    #    features whenever the first day has missing values.
    first_day = out.drop_duplicates(subset=bird_col, keep="first").set_index(bird_col)
    baseline = first_day.loc[out[bird_col].to_numpy(), feature_cols].set_axis(
        out.index, axis=0
    )

    # 4) absolute change relative to the baseline. The epsilon avoids a
    #    division by zero for zero-valued baselines.
    out[[f"{c}{suffix}" for c in feature_cols]] = (
        (out[feature_cols] - baseline).abs() / (baseline.abs() + 1e-8)
    )

    return out.reset_index(drop=True)


In [None]:
from pathlib import Path


def save_2d_kde_plots(
    df: pd.DataFrame,
    features: Iterable[str],
    out_dir: Union[str, Path],
    x_col: str = "neural_change_statistic",
    cmap: str = "inferno",
    dpi: int = 300,
) -> List[Path]:
    """Save one filled 2-D KDE plot of ``x_col`` against each feature.

    For every ``y_col`` in ``features``, the pair ``(x_col, y_col)`` is taken
    from ``df``, +/-inf are treated as missing, and incomplete rows are dropped.
    Two cases are skipped: pairs with fewer than 10 remaining rows, and a
    feature equal to ``x_col`` itself. Each plot is written to
    ``out_dir / f"kde2d_{x_col}_vs_{y_col}.png"``, with both names reduced to
    filename-safe characters.

    Parameters
    ----------
    df : pd.DataFrame
        Source data containing ``x_col`` and every entry of ``features``.
    features : Iterable[str]
        Columns to plot on the y-axis, one figure each.
    out_dir : str or Path
        Output directory; created (with parents) if missing.
    x_col : str
        Column plotted on the x-axis of every figure.
    cmap : str
        Colormap name for the density fill.
    dpi : int
        Resolution of the saved PNG files.

    Returns
    -------
    List[Path]
        Paths of the files actually written, in the order of ``features``.
    """

    def _safe_name(name: str) -> str:
        # Keep alphanumerics and "._-". Anything else (e.g. "/") becomes "_",
        # so column names cannot introduce path separators into the filename.
        return "".join(c if c.isalnum() or c in "._-" else "_" for c in str(name))

    out_dir = Path(out_dir)
    out_dir.mkdir(parents=True, exist_ok=True)
    x_name = _safe_name(x_col)

    saved: List[Path] = []
    for y_col in features:
        # Selecting [x_col, x_col] would produce duplicate column labels, which
        # breaks the plotting call; a variable against itself is uninformative.
        if y_col == x_col:
            continue

        d = (
            df[[x_col, y_col]]
            .replace([np.inf, -np.inf], np.nan)
            .dropna()
        )
        if len(d) < 10:
            continue

        out_path = out_dir / f"kde2d_{x_name}_vs_{_safe_name(y_col)}.png"
        fig, ax = plt.subplots(figsize=(6, 5))
        try:
            sns.kdeplot(
                data=d,
                x=x_col,
                y=y_col,
                fill=True,
                cmap=cmap,
                levels=60,
                thresh=0,
                ax=ax,
            )
            ax.set_title(f"{x_col} vs {y_col}")
            ax.set_xlabel(x_col)
            ax.set_ylabel(y_col)
            fig.tight_layout()
            fig.savefig(out_path, dpi=dpi)
        finally:
            # Always release the figure, even if plotting or saving fails.
            plt.close(fig)

        saved.append(out_path)

    return saved


---

## 1. Read in data

In [None]:
# Load the per-song feature dictionaries. They are tabulated below by
# birdsong_list_to_acoustic_df / birdsong_list_to_neural_df.
# NOTE: pickle.load can execute arbitrary code; only load files you trust.
fpath = "all_birds_all_features_final.pkl"
with open(fpath, "rb") as f:
    song_dicts = pickle.load(f)

In [None]:
# Tabulate the acoustic features of all songs and preview the first rows.
acoustic_data = birdsong_list_to_acoustic_df(song_dicts)
acoustic_data.head()

In [None]:
# Tabulate the neural features of all songs and preview the first rows.
neural_data = birdsong_list_to_neural_df(song_dicts)
neural_data.head()

---

## 2. Exploratory data analysis

In the following, we will quickly explore the data by assessing the following:

- distribution of songs across birds
- distribution of ages across birds
- distribution of n_frames across birds

### 2.1. Number of songs

In [None]:
# Number of songs per bird before filtering; bars are ordered by song count.
# ``bird_order`` is reused by the following plots.
bird_order = acoustic_data["bird"].value_counts(ascending=False).index

fig, ax = plt.subplots(figsize=(4, 4))
ax = sns.countplot(
    data=acoustic_data,
    x="bird",
    hue="bird",
    palette=bird_palette,
    order=bird_order,
    ax=ax,
)

# Add counts on top of bars
for container in ax.containers:
    ax.bar_label(container, padding=3)

ax.set_ylabel("Number of songs")
ax.set_xlabel("Bird")
fig.tight_layout()

# Save before showing: interactive/inline backends may close the figure on
# show. Also make sure the output directory exists on a fresh checkout.
Path("plots/metadata").mkdir(parents=True, exist_ok=True)
fig.savefig("plots/metadata/n_songs_pre_filtering.png", dpi=300)
plt.show()
plt.close(fig)


### 2.2. Days post hatch (age)

In [None]:
# Recording ages (days post hatch) per bird: one point per unique (bird, dph).
unique_df = acoustic_data.drop_duplicates(subset=["bird", "dph"])
fig, ax = plt.subplots(figsize=(4, 4))

# Boxplot
sns.boxplot(
    data=unique_df,
    x="bird",
    y="dph",
    hue="bird",
    dodge=False,
    order=bird_order,
    palette=bird_palette,
    ax=ax,
    showfliers=False,  # every point is drawn by the stripplot below
)

# Jittered points
sns.stripplot(
    data=unique_df,
    x="bird",
    y="dph",
    color="black",
    order=bird_order,
    alpha=0.6,
    jitter=0.15,
    size=4,
    ax=ax,
)

ax.set_ylabel("Days post hatch")
ax.set_xlabel("Bird")
fig.tight_layout()

# Save before showing (the figure may be closed on show) into an existing dir.
Path("plots/metadata").mkdir(parents=True, exist_ok=True)
fig.savefig("plots/metadata/days_post_hatch_pre_filtering.png", dpi=300)
plt.show()
plt.close(fig)

In [None]:
unique_df.groupby("bird")["dph"].describe()  # per-bird summary of recording ages

### 2.3. Song length (n_frames)

In [None]:
# Song length (number of frames) per bird. The y-axis is truncated at
# ``n_frames_ylim`` frames, and the number of songs above it is reported.
n_frames_ylim = 500

fig, ax = plt.subplots(figsize=(4, 4))
sns.boxplot(
    data=acoustic_data,
    x="bird",
    y="n_frames",
    hue="bird",
    dodge=False,
    order=bird_order,
    palette=bird_palette,
    ax=ax,
    showfliers=False,  # hide outlier markers
)

ax.set_ylim(0, n_frames_ylim)
ax.set_ylabel("Number of frames")
ax.set_xlabel("Bird")
fig.tight_layout()

# Save before showing (the figure may be closed on show) into an existing dir.
Path("plots/metadata").mkdir(parents=True, exist_ok=True)
fig.savefig("plots/metadata/n_frames_pre_filtering.png", dpi=300)
plt.show()
plt.close(fig)

n_hidden_songs = int((acoustic_data["n_frames"] > n_frames_ylim).sum())
print("Data for {}/{} songs outside of the plotted range are not shown.".format(n_hidden_songs, len(acoustic_data)))


In [None]:
acoustic_data.groupby("bird")["n_frames"].describe()  # per-bird song-length summary

In [None]:
acoustic_data.groupby("bird")["n_frames"].quantile(q=0.95)  # per-bird 95th percentile of song length

---

## 4. Dimensionality reduction

To further study the data structure, we will perform a dimensionality reduction of the songbird data using both the acoustic and neural features. In particular, we will compute and visualize the following:

- Principal component analysis of the concatenated acoustic features representing a single song
- UMAP embedding of the resulting principal component representation
- UMAP embedding of the neural features for each song



### 4.1. Filtering out outliers (long songs)

Given the above plots, we will filter out any songs that contain more than 380 frames (cf. the per-bird 95th percentiles of the number of frames computed above).

In [None]:
# Songs longer than this many frames are treated as outliers and removed.
# A single constant keeps the filter and the reported message in sync.
MAX_SONG_FRAMES = 380
filtered_song_dicts = filter_songs_by_n_frames(song_dicts, max_frames=MAX_SONG_FRAMES)
n_dropped_songs = len(song_dicts) - len(filtered_song_dicts)
print("Dropped {}/{} songs with more than {} frames...".format(n_dropped_songs, len(song_dicts), MAX_SONG_FRAMES))

This removes 73 songs (see the printed count above). From now on, we will use only the filtered data for all subsequent analyses.

In [None]:
# Rebuild both tables from the filtered songs; the following cells use these.
acoustic_data = birdsong_list_to_acoustic_df(filtered_song_dicts)
neural_data = birdsong_list_to_neural_df(filtered_song_dicts)

In [None]:
# Inspect the available acoustic/metadata columns.
acoustic_data.columns

We will now briefly recreate the plots from above using the filtered data.

In [None]:
# Number of songs per bird after filtering; bars are ordered by song count.
song_counts = acoustic_data["bird"].value_counts(ascending=False)
bird_order = song_counts.index

fig, ax = plt.subplots(figsize=(4, 4))
ax = sns.countplot(
    data=acoustic_data,
    x="bird",
    hue="bird",
    palette=bird_palette,
    order=bird_order,
    ax=ax,
)

# Add counts on top of bars
for container in ax.containers:
    ax.bar_label(container, padding=3)

ax.set_ylabel("Number of songs (filtered)")
ax.set_xlabel("Bird")
# Keep the fixed upper limit of 10000, but raise it when needed so the
# tallest bar and its count label are never clipped.
ax.set_ylim(0, max(10000, 1.1 * song_counts.max()))
fig.tight_layout()

# Save before showing (the figure may be closed on show) into an existing dir.
Path("plots/metadata").mkdir(parents=True, exist_ok=True)
fig.savefig("plots/metadata/n_songs_post_filtering.png", dpi=300)
plt.show()
plt.close(fig)


In [None]:
# Recording ages per bird after filtering: one point per unique (bird, dph).
unique_df = acoustic_data.drop_duplicates(subset=["bird", "dph"])
fig, ax = plt.subplots(figsize=(4, 4))

# Boxplot
sns.boxplot(
    data=unique_df,
    x="bird",
    y="dph",
    hue="bird",
    dodge=False,
    order=bird_order,
    palette=bird_palette,
    ax=ax,
    showfliers=False,  # every point is drawn by the stripplot below
)

# Jittered points
sns.stripplot(
    data=unique_df,
    x="bird",
    y="dph",
    color="black",
    order=bird_order,
    alpha=0.6,
    jitter=0.15,
    size=4,
    ax=ax,
)

ax.set_ylabel("Days post hatch (filtered)")
ax.set_xlabel("Bird")
fig.tight_layout()

Path("plots/metadata").mkdir(parents=True, exist_ok=True)
fig.savefig("plots/metadata/days_post_hatch_post_filtering.png", dpi=300)
plt.show()
plt.close(fig)

In [None]:
# Song length (number of frames) per bird after filtering.
fig, ax = plt.subplots(figsize=(4, 4))
sns.boxplot(
    data=acoustic_data,
    x="bird",
    y="n_frames",
    hue="bird",
    dodge=False,
    order=bird_order,
    palette=bird_palette,
    ax=ax,
    showfliers=False,  # hide outlier markers
)

ax.set_ylabel("Number of frames (filtered)")
ax.set_xlabel("Bird")
fig.tight_layout()

# Save before showing (the figure may be closed on show) into an existing dir.
Path("plots/metadata").mkdir(parents=True, exist_ok=True)
fig.savefig("plots/metadata/n_frames_post_filtering.png", dpi=300)
plt.show()
plt.close(fig)


----

### 4.2. Compute PCA and UMAP embeddings

We will now expand the acoustic dataset by adding the first 30 principal components of the unrolled acoustic feature representation and the corresponding UMAP embedding, as well as a UMAP computed directly on the per-song mean features.

In [None]:
# Augment the filtered acoustic table. Per the section text, this adds the
# first `n_pcs` principal components of the unrolled acoustic features and
# their UMAP embedding (only the first return value is kept), then a UMAP of
# the per-song mean features. Later cells read umap_1/umap_2 and
# umap_mean_1/umap_mean_2 from the result.
full_acoustic_data, *_ = add_pca_umap_from_unrolled_acoustics(
    songs=filtered_song_dicts,
    acoustic_df=acoustic_data,
    n_pcs=30,
    umap_min_dist=0.3,
    umap_n_neighbors=15,
)
full_acoustic_data = add_mean_umap(full_acoustic_data, min_dist=0.3, n_neighbors=15)

full_acoustic_data.head()



In [None]:
# Inspect the available neural columns.
neural_data.columns

In [None]:
# Neural feature names, read from the first song's entry at time lag -100.
# This assumes every song and lag shares the same feature list (TODO confirm).
# Converted to a list so it can be concatenated with other column lists below.
neural_features = list(filtered_song_dicts[0]["neural"][-100]["fnames"])

# Add a UMAP embedding of the neural features, computed per value of
# `time_lag`; later cells read umap_lag{lag}_1 / umap_lag{lag}_2.
full_neural_data = add_neural_umap(
    neural_data,
    feature_columns=neural_features,
    lag_column="time_lag",
    min_dist=0.1,
    n_neighbors=15,
    random_state=1234,
)

---

### 4.3. Visualizations

We will now visualize the dataset, colouring the embeddings by each individual feature in turn.

In [None]:
# Metadata columns followed by every per-song mean_/std_ acoustic summary.
summary_columns = [col for col in full_acoustic_data.columns if col.startswith(("mean_", "std_"))]
all_feature_columns = ["dph", "n_frames"] + summary_columns
all_feature_columns

#### 4.3.1. Acoustic features

##### 4.3.1a. Unrolled features

In [None]:
# UMAP of the unrolled acoustic features, coloured by bird.
fig, ax = plt.subplots(figsize=(8, 6))

sns.scatterplot(
    data=full_acoustic_data,
    x="umap_1",
    y="umap_2",
    hue="bird",
    s=4,
    ax=ax,
    palette=bird_palette,
)

ax.set_xlabel("UMAP 1")
ax.set_ylabel("UMAP 2")
ax.set_title("Unrolled acoustic data")
legend = ax.legend(
    title="Bird",
    fontsize=12,  # label font size
    title_fontsize=12,  # title font size
    markerscale=3,  # scales up legend marker size
    frameon=False,
)
Path("plots/acoustic/unrolled").mkdir(parents=True, exist_ok=True)
fig.savefig("plots/acoustic/unrolled/unrolled_umap_acoustic_bird.png", dpi=300)
plt.show()
plt.close(fig)


In [None]:
# For every feature, save two figures:
#   1. the unrolled-acoustic UMAP of all birds, coloured by the feature;
#   2. a grid with one panel per bird, reusing the all-birds colour range
#      (vmin/vmax) so the panels are comparable.
unrolled_dir = Path("plots/acoustic/unrolled")
unrolled_dir.mkdir(parents=True, exist_ok=True)

# Size the panel grid from the number of birds (3 per row) instead of assuming
# exactly six, which raised IndexError for more birds.
birds = full_acoustic_data["bird"].unique()
n_panel_cols = 3
n_panel_rows = max(1, (len(birds) + n_panel_cols - 1) // n_panel_cols)

for f in all_feature_columns:
    fig, ax, vmin, vmax = seaborn_scatter_with_colorbar(
        data=full_acoustic_data,
        x="umap_1",
        y="umap_2",
        hue=f,
        title="Unrolled acoustic data",
        figsize=[8, 6],
    )
    ax.set_xlabel("UMAP 1")
    ax.set_ylabel("UMAP 2")
    fig.savefig(unrolled_dir / "unrolled_umap_{}_all_birds.png".format(f), dpi=300)
    plt.show()
    plt.close(fig)

    fig, ax = plt.subplots(
        figsize=(4 * n_panel_cols, 3 * n_panel_rows),
        ncols=n_panel_cols,
        nrows=n_panel_rows,
        sharex=True,
        sharey=True,
        squeeze=False,
    )
    ax = ax.flatten()
    for i, b in enumerate(birds):
        _, ax[i], *_ = seaborn_scatter_with_colorbar(
            data=full_acoustic_data.loc[full_acoustic_data.bird == b],
            x="umap_1",
            y="umap_2",
            hue=f,
            title="Unrolled acoustic data ({})".format(b),
            figsize=[8, 6],
            vmin=vmin,
            vmax=vmax,
            ax=ax[i],
            s=3,
        )
        ax[i].set_xlabel("UMAP 1")
        ax[i].set_ylabel("UMAP 2")
    # Hide panels left over when the bird count is not a multiple of 3.
    for spare_ax in ax[len(birds):]:
        spare_ax.set_visible(False)
    fig.tight_layout()
    fig.savefig(unrolled_dir / "unrolled_umap_{}_by_bird.png".format(f), dpi=300)
    plt.close(fig)

----

##### 4.3.1b. Mean features

In [None]:
# UMAP of the per-song mean acoustic features, coloured by bird.
fig, ax = plt.subplots(figsize=(8, 6))

sns.scatterplot(
    data=full_acoustic_data,
    x="umap_mean_1",
    y="umap_mean_2",
    hue="bird",
    s=4,
    ax=ax,
    palette=bird_palette,
)

ax.set_xlabel("UMAP 1")
ax.set_ylabel("UMAP 2")
ax.set_title("Mean acoustic data")
legend = ax.legend(
    title="Bird",
    fontsize=12,  # label font size
    title_fontsize=12,  # title font size
    markerscale=3,  # scales up legend marker size
    frameon=False,
)
Path("plots/acoustic/mean").mkdir(parents=True, exist_ok=True)
fig.savefig("plots/acoustic/mean/mean_umap_acoustic_bird.png", dpi=300)
plt.show()
plt.close(fig)


In [None]:
# For every feature, save two figures:
#   1. the mean-acoustic UMAP of all birds, coloured by the feature;
#   2. a grid with one panel per bird, reusing the all-birds colour range.
mean_dir = Path("plots/acoustic/mean")
mean_dir.mkdir(parents=True, exist_ok=True)

# Size the panel grid from the number of birds (3 per row) instead of assuming
# exactly six, which raised IndexError for more birds.
birds = full_acoustic_data["bird"].unique()
n_panel_cols = 3
n_panel_rows = max(1, (len(birds) + n_panel_cols - 1) // n_panel_cols)

for f in all_feature_columns:
    fig, ax, vmin, vmax = seaborn_scatter_with_colorbar(
        data=full_acoustic_data,
        x="umap_mean_1",
        y="umap_mean_2",
        hue=f,
        title="Mean acoustic data",
        figsize=[8, 6],
    )
    ax.set_xlabel("UMAP 1")
    ax.set_ylabel("UMAP 2")
    fig.savefig(mean_dir / "mean_umap_{}_all_birds.png".format(f), dpi=300)
    plt.show()
    plt.close(fig)

    fig, ax = plt.subplots(
        figsize=(4 * n_panel_cols, 3 * n_panel_rows),
        ncols=n_panel_cols,
        nrows=n_panel_rows,
        sharex=True,
        sharey=True,
        squeeze=False,
    )
    ax = ax.flatten()
    for i, b in enumerate(birds):
        _, ax[i], *_ = seaborn_scatter_with_colorbar(
            data=full_acoustic_data.loc[full_acoustic_data.bird == b],
            x="umap_mean_1",
            y="umap_mean_2",
            hue=f,
            title="Mean acoustic data ({})".format(b),
            figsize=[8, 6],
            vmin=vmin,
            vmax=vmax,
            ax=ax[i],
            s=3,
        )
        ax[i].set_xlabel("UMAP 1")
        ax[i].set_ylabel("UMAP 2")
    # Hide panels left over when the bird count is not a multiple of 3.
    for spare_ax in ax[len(birds):]:
        spare_ax.set_visible(False)
    fig.tight_layout()
    fig.savefig(mean_dir / "mean_umap_{}_by_bird.png".format(f), dpi=300)
    plt.close(fig)

---

#### 4.3.2. Neural features

In [None]:
# Show the neural feature names used for the UMAP embedding.
neural_features

In [None]:
# Inspect the neural table columns, including the added per-lag UMAP columns.
full_neural_data.columns

##### 4.3.2a. Time lag -100

In [None]:
# Neural-feature UMAPs at this time lag: one plot coloured by bird, then for
# each feature an all-birds plot and a per-bird panel grid sharing its colour range.
timelag = -100
lag_data = full_neural_data.loc[full_neural_data.time_lag == timelag]
umap_x = "umap_lag{}_1".format(timelag)
umap_y = "umap_lag{}_2".format(timelag)
lag_dir = Path("plots/neural/lag{}".format(timelag))
lag_dir.mkdir(parents=True, exist_ok=True)

fig, ax = plt.subplots(figsize=(8, 6))
sns.scatterplot(
    data=lag_data,
    x=umap_x,
    y=umap_y,
    hue="bird",
    s=4,
    ax=ax,
    palette=bird_palette,
)
ax.set_xlabel("UMAP 1")
ax.set_ylabel("UMAP 2")
ax.set_title("Neural data (lag={})".format(timelag))
legend = ax.legend(
    title="Bird",
    fontsize=12,  # label font size
    title_fontsize=12,  # title font size
    markerscale=3,  # scales up legend marker size
    frameon=False,
)
fig.savefig(lag_dir / "umap_neural_lag{}_bird.png".format(timelag), dpi=300)
plt.show()
plt.close(fig)

# Panels are built from the birds present in the *neural* data at this lag,
# with the grid sized from their number (3 per row) instead of a fixed 2x3.
birds = lag_data["bird"].unique()
n_panel_cols = 3
n_panel_rows = max(1, (len(birds) + n_panel_cols - 1) // n_panel_cols)

plot_neural_features = ["dph", "n_frames"] + list(neural_features)
for f in plot_neural_features:
    fig, ax, vmin, vmax = seaborn_scatter_with_colorbar(
        data=lag_data,
        x=umap_x,
        y=umap_y,
        hue=f,
        title="Neural data (lag={})".format(timelag),
        figsize=[8, 6],
    )
    ax.set_xlabel("UMAP 1")
    ax.set_ylabel("UMAP 2")
    fig.savefig(lag_dir / "umap_{}_lag{}_all_birds.png".format(f, timelag), dpi=300)
    plt.show()
    plt.close(fig)

    fig, ax = plt.subplots(
        figsize=(4 * n_panel_cols, 3 * n_panel_rows),
        ncols=n_panel_cols,
        nrows=n_panel_rows,
        sharex=True,
        sharey=True,
        squeeze=False,
    )
    ax = ax.flatten()
    for i, b in enumerate(birds):
        _, ax[i], *_ = seaborn_scatter_with_colorbar(
            data=lag_data.loc[lag_data.bird == b],
            x=umap_x,
            y=umap_y,
            hue=f,
            title="Neural data ({}, lag={})".format(b, timelag),
            figsize=[8, 6],
            vmin=vmin,
            vmax=vmax,
            ax=ax[i],
            s=3,
        )
        ax[i].set_xlabel("UMAP 1")
        ax[i].set_ylabel("UMAP 2")
    # Hide panels left over when the bird count is not a multiple of 3.
    for spare_ax in ax[len(birds):]:
        spare_ax.set_visible(False)
    fig.tight_layout()
    fig.savefig(lag_dir / "umap_{}_lag{}_by_bird.png".format(f, timelag), dpi=300)
    plt.close(fig)

----

##### 4.3.2b. Time lag -50

In [None]:
# Neural-feature UMAPs at this time lag: one plot coloured by bird, then for
# each feature an all-birds plot and a per-bird panel grid sharing its colour range.
timelag = -50
lag_data = full_neural_data.loc[full_neural_data.time_lag == timelag]
umap_x = "umap_lag{}_1".format(timelag)
umap_y = "umap_lag{}_2".format(timelag)
lag_dir = Path("plots/neural/lag{}".format(timelag))
lag_dir.mkdir(parents=True, exist_ok=True)

fig, ax = plt.subplots(figsize=(8, 6))
sns.scatterplot(
    data=lag_data,
    x=umap_x,
    y=umap_y,
    hue="bird",
    s=4,
    ax=ax,
    palette=bird_palette,
)
ax.set_xlabel("UMAP 1")
ax.set_ylabel("UMAP 2")
ax.set_title("Neural data (lag={})".format(timelag))
legend = ax.legend(
    title="Bird",
    fontsize=12,  # label font size
    title_fontsize=12,  # title font size
    markerscale=3,  # scales up legend marker size
    frameon=False,
)
fig.savefig(lag_dir / "umap_neural_lag{}_bird.png".format(timelag), dpi=300)
plt.show()
plt.close(fig)

# Panels are built from the birds present in the *neural* data at this lag,
# with the grid sized from their number (3 per row) instead of a fixed 2x3.
birds = lag_data["bird"].unique()
n_panel_cols = 3
n_panel_rows = max(1, (len(birds) + n_panel_cols - 1) // n_panel_cols)

plot_neural_features = ["dph", "n_frames"] + list(neural_features)
for f in plot_neural_features:
    fig, ax, vmin, vmax = seaborn_scatter_with_colorbar(
        data=lag_data,
        x=umap_x,
        y=umap_y,
        hue=f,
        title="Neural data (lag={})".format(timelag),
        figsize=[8, 6],
    )
    ax.set_xlabel("UMAP 1")
    ax.set_ylabel("UMAP 2")
    fig.savefig(lag_dir / "umap_{}_lag{}_all_birds.png".format(f, timelag), dpi=300)
    plt.show()
    plt.close(fig)

    fig, ax = plt.subplots(
        figsize=(4 * n_panel_cols, 3 * n_panel_rows),
        ncols=n_panel_cols,
        nrows=n_panel_rows,
        sharex=True,
        sharey=True,
        squeeze=False,
    )
    ax = ax.flatten()
    for i, b in enumerate(birds):
        _, ax[i], *_ = seaborn_scatter_with_colorbar(
            data=lag_data.loc[lag_data.bird == b],
            x=umap_x,
            y=umap_y,
            hue=f,
            title="Neural data ({}, lag={})".format(b, timelag),
            figsize=[8, 6],
            vmin=vmin,
            vmax=vmax,
            ax=ax[i],
            s=3,
        )
        ax[i].set_xlabel("UMAP 1")
        ax[i].set_ylabel("UMAP 2")
    # Hide panels left over when the bird count is not a multiple of 3.
    for spare_ax in ax[len(birds):]:
        spare_ax.set_visible(False)
    fig.tight_layout()
    fig.savefig(lag_dir / "umap_{}_lag{}_by_bird.png".format(f, timelag), dpi=300)
    plt.close(fig)

---

##### 4.3.2c. Time lag 0

In [None]:
# Neural-feature UMAPs at this time lag: one plot coloured by bird, then for
# each feature an all-birds plot and a per-bird panel grid sharing its colour range.
timelag = 0
lag_data = full_neural_data.loc[full_neural_data.time_lag == timelag]
umap_x = "umap_lag{}_1".format(timelag)
umap_y = "umap_lag{}_2".format(timelag)
lag_dir = Path("plots/neural/lag{}".format(timelag))
lag_dir.mkdir(parents=True, exist_ok=True)

fig, ax = plt.subplots(figsize=(8, 6))
sns.scatterplot(
    data=lag_data,
    x=umap_x,
    y=umap_y,
    hue="bird",
    s=4,
    ax=ax,
    palette=bird_palette,
)
ax.set_xlabel("UMAP 1")
ax.set_ylabel("UMAP 2")
ax.set_title("Neural data (lag={})".format(timelag))
legend = ax.legend(
    title="Bird",
    fontsize=12,  # label font size
    title_fontsize=12,  # title font size
    markerscale=3,  # scales up legend marker size
    frameon=False,
)
fig.savefig(lag_dir / "umap_neural_lag{}_bird.png".format(timelag), dpi=300)
plt.show()
plt.close(fig)

# Panels are built from the birds present in the *neural* data at this lag,
# with the grid sized from their number (3 per row) instead of a fixed 2x3.
birds = lag_data["bird"].unique()
n_panel_cols = 3
n_panel_rows = max(1, (len(birds) + n_panel_cols - 1) // n_panel_cols)

plot_neural_features = ["dph", "n_frames"] + list(neural_features)
for f in plot_neural_features:
    fig, ax, vmin, vmax = seaborn_scatter_with_colorbar(
        data=lag_data,
        x=umap_x,
        y=umap_y,
        hue=f,
        title="Neural data (lag={})".format(timelag),
        figsize=[8, 6],
    )
    ax.set_xlabel("UMAP 1")
    ax.set_ylabel("UMAP 2")
    fig.savefig(lag_dir / "umap_{}_lag{}_all_birds.png".format(f, timelag), dpi=300)
    plt.show()
    plt.close(fig)

    fig, ax = plt.subplots(
        figsize=(4 * n_panel_cols, 3 * n_panel_rows),
        ncols=n_panel_cols,
        nrows=n_panel_rows,
        sharex=True,
        sharey=True,
        squeeze=False,
    )
    ax = ax.flatten()
    for i, b in enumerate(birds):
        _, ax[i], *_ = seaborn_scatter_with_colorbar(
            data=lag_data.loc[lag_data.bird == b],
            x=umap_x,
            y=umap_y,
            hue=f,
            title="Neural data ({}, lag={})".format(b, timelag),
            figsize=[8, 6],
            vmin=vmin,
            vmax=vmax,
            ax=ax[i],
            s=3,
        )
        ax[i].set_xlabel("UMAP 1")
        ax[i].set_ylabel("UMAP 2")
    # Hide panels left over when the bird count is not a multiple of 3.
    for spare_ax in ax[len(birds):]:
        spare_ax.set_visible(False)
    fig.tight_layout()
    fig.savefig(lag_dir / "umap_{}_lag{}_by_bird.png".format(f, timelag), dpi=300)
    plt.close(fig)

In [None]:
# Spike counts per brain region and bird. ``full_neural_data`` holds rows for
# several time lags (see the ``time_lag`` filters above). Restrict it to lag 0
# so that each song contributes once instead of once per lag.
spike_areas = ['spikes_AreaX',
 'spikes_LMAN',
 'spikes_Pallium',
 'spikes_Striatum']
spike_data = full_neural_data.loc[full_neural_data.time_lag == 0]

fig, ax = plt.subplots(figsize=(8, 8), ncols=2, nrows=2)
ax = ax.flatten()
for i, f in enumerate(spike_areas):
    ax[i] = sns.boxplot(
        data=spike_data,
        x="bird",
        y=f,
        hue="bird",
        dodge=False,
        showfliers=False,
        order=bird_order,
        palette=bird_palette,
        ax=ax[i],
    )
    ax[i].set_xlabel("Bird")
    ax[i].set_ylabel("Spike count")
    ax[i].set_title("{}".format(f.split("_")[1]))  # region name after "spikes_"
fig.tight_layout()

# Save before showing (the figure may be closed on show) into an existing dir.
Path("plots/neural").mkdir(parents=True, exist_ok=True)
fig.savefig("plots/neural/boxplot_spikecount_by_region_by_bird.png", dpi=300)
plt.show()
plt.close(fig)

In [None]:
# Spike counts per frame (count / n_frames) per brain region and bird, again
# restricted to lag 0 so each song contributes once instead of once per lag.
spike_areas = ['spikes_AreaX',
 'spikes_LMAN',
 'spikes_Pallium',
 'spikes_Striatum']
spike_data = full_neural_data.loc[full_neural_data.time_lag == 0]

fig, ax = plt.subplots(figsize=(8, 8), ncols=2, nrows=2)
ax = ax.flatten()
for i, f in enumerate(spike_areas):
    # Put the ratio in a named column of the same frame, so values cannot be
    # misaligned with the rows passed as ``data``.
    per_frame = spike_data.assign(spikes_per_frame=spike_data[f] / spike_data["n_frames"])
    ax[i] = sns.boxplot(
        data=per_frame,
        x="bird",
        y="spikes_per_frame",
        hue="bird",
        dodge=False,
        showfliers=False,
        order=bird_order,
        palette=bird_palette,
        ax=ax[i],
    )
    ax[i].set_xlabel("Bird")
    ax[i].set_ylabel("Spike count per frame")
    ax[i].set_title("{}".format(f.split("_")[1]))  # region name after "spikes_"
fig.tight_layout()

# Save before showing (the figure may be closed on show) into an existing dir.
Path("plots/neural").mkdir(parents=True, exist_ok=True)
fig.savefig("plots/neural/boxplot_spikecount_by_region_by_bird_by_frame.png", dpi=300)
plt.show()
plt.close(fig)

---

##### 4.3.2d. Time lag 50

In [None]:
# Neural-feature UMAPs at this time lag: one plot coloured by bird, then for
# each feature an all-birds plot and a per-bird panel grid sharing its colour range.
timelag = 50
lag_data = full_neural_data.loc[full_neural_data.time_lag == timelag]
umap_x = "umap_lag{}_1".format(timelag)
umap_y = "umap_lag{}_2".format(timelag)
lag_dir = Path("plots/neural/lag{}".format(timelag))
lag_dir.mkdir(parents=True, exist_ok=True)

fig, ax = plt.subplots(figsize=(8, 6))
sns.scatterplot(
    data=lag_data,
    x=umap_x,
    y=umap_y,
    hue="bird",
    s=4,
    ax=ax,
    palette=bird_palette,
)
ax.set_xlabel("UMAP 1")
ax.set_ylabel("UMAP 2")
ax.set_title("Neural data (lag={})".format(timelag))
legend = ax.legend(
    title="Bird",
    fontsize=12,  # label font size
    title_fontsize=12,  # title font size
    markerscale=3,  # scales up legend marker size
    frameon=False,
)
fig.savefig(lag_dir / "umap_neural_lag{}_bird.png".format(timelag), dpi=300)
plt.show()
plt.close(fig)

# Panels are built from the birds present in the *neural* data at this lag,
# with the grid sized from their number (3 per row) instead of a fixed 2x3.
birds = lag_data["bird"].unique()
n_panel_cols = 3
n_panel_rows = max(1, (len(birds) + n_panel_cols - 1) // n_panel_cols)

plot_neural_features = ["dph", "n_frames"] + list(neural_features)
for f in plot_neural_features:
    fig, ax, vmin, vmax = seaborn_scatter_with_colorbar(
        data=lag_data,
        x=umap_x,
        y=umap_y,
        hue=f,
        title="Neural data (lag={})".format(timelag),
        figsize=[8, 6],
    )
    ax.set_xlabel("UMAP 1")
    ax.set_ylabel("UMAP 2")
    fig.savefig(lag_dir / "umap_{}_lag{}_all_birds.png".format(f, timelag), dpi=300)
    plt.show()
    plt.close(fig)

    fig, ax = plt.subplots(
        figsize=(4 * n_panel_cols, 3 * n_panel_rows),
        ncols=n_panel_cols,
        nrows=n_panel_rows,
        sharex=True,
        sharey=True,
        squeeze=False,
    )
    ax = ax.flatten()
    for i, b in enumerate(birds):
        _, ax[i], *_ = seaborn_scatter_with_colorbar(
            data=lag_data.loc[lag_data.bird == b],
            x=umap_x,
            y=umap_y,
            hue=f,
            title="Neural data ({}, lag={})".format(b, timelag),
            figsize=[8, 6],
            vmin=vmin,
            vmax=vmax,
            ax=ax[i],
            s=3,
        )
        ax[i].set_xlabel("UMAP 1")
        ax[i].set_ylabel("UMAP 2")
    # Hide panels left over when the bird count is not a multiple of 3.
    for spare_ax in ax[len(birds):]:
        spare_ax.set_visible(False)
    fig.tight_layout()
    fig.savefig(lag_dir / "umap_{}_lag{}_by_bird.png".format(f, timelag), dpi=300)
    plt.close(fig)

---

##### 4.3.2e. Time lag 100

In [None]:
# Neural-feature UMAPs at this time lag: one plot coloured by bird, then for
# each feature an all-birds plot and a per-bird panel grid sharing its colour range.
timelag = 100
lag_data = full_neural_data.loc[full_neural_data.time_lag == timelag]
umap_x = "umap_lag{}_1".format(timelag)
umap_y = "umap_lag{}_2".format(timelag)
lag_dir = Path("plots/neural/lag{}".format(timelag))
lag_dir.mkdir(parents=True, exist_ok=True)

fig, ax = plt.subplots(figsize=(8, 6))
sns.scatterplot(
    data=lag_data,
    x=umap_x,
    y=umap_y,
    hue="bird",
    s=4,
    ax=ax,
    palette=bird_palette,
)
ax.set_xlabel("UMAP 1")
ax.set_ylabel("UMAP 2")
ax.set_title("Neural data (lag={})".format(timelag))
legend = ax.legend(
    title="Bird",
    fontsize=12,  # label font size
    title_fontsize=12,  # title font size
    markerscale=3,  # scales up legend marker size
    frameon=False,
)
fig.savefig(lag_dir / "umap_neural_lag{}_bird.png".format(timelag), dpi=300)
plt.show()
plt.close(fig)

# Panels are built from the birds present in the *neural* data at this lag,
# with the grid sized from their number (3 per row) instead of a fixed 2x3.
birds = lag_data["bird"].unique()
n_panel_cols = 3
n_panel_rows = max(1, (len(birds) + n_panel_cols - 1) // n_panel_cols)

plot_neural_features = ["dph", "n_frames"] + list(neural_features)
for f in plot_neural_features:
    fig, ax, vmin, vmax = seaborn_scatter_with_colorbar(
        data=lag_data,
        x=umap_x,
        y=umap_y,
        hue=f,
        title="Neural data (lag={})".format(timelag),
        figsize=[8, 6],
    )
    ax.set_xlabel("UMAP 1")
    ax.set_ylabel("UMAP 2")
    fig.savefig(lag_dir / "umap_{}_lag{}_all_birds.png".format(f, timelag), dpi=300)
    plt.show()
    plt.close(fig)

    fig, ax = plt.subplots(
        figsize=(4 * n_panel_cols, 3 * n_panel_rows),
        ncols=n_panel_cols,
        nrows=n_panel_rows,
        sharex=True,
        sharey=True,
        squeeze=False,
    )
    ax = ax.flatten()
    for i, b in enumerate(birds):
        _, ax[i], *_ = seaborn_scatter_with_colorbar(
            data=lag_data.loc[lag_data.bird == b],
            x=umap_x,
            y=umap_y,
            hue=f,
            title="Neural data ({}, lag={})".format(b, timelag),
            figsize=[8, 6],
            vmin=vmin,
            vmax=vmax,
            ax=ax[i],
            s=3,
        )
        ax[i].set_xlabel("UMAP 1")
        ax[i].set_ylabel("UMAP 2")
    # Hide panels left over when the bird count is not a multiple of 3.
    for spare_ax in ax[len(birds):]:
        spare_ax.set_visible(False)
    fig.tight_layout()
    fig.savefig(lag_dir / "umap_{}_lag{}_by_bird.png".format(f, timelag), dpi=300)
    plt.close(fig)

----

## 5. Correlation analyses

Next, we will check whether the per-song acoustic feature summaries (means and standard deviations) correlate with all or a subset of the neural features. To this end, we will compute the Pearson correlation matrix based on the z-scored feature representations across all songs and visualize it.

In [None]:
# List all columns of the augmented acoustic table (used to pick the features below).
list(full_acoustic_data.columns)

In [None]:
# Columns entering the correlation analysis: the song_id join key, metadata,
# and the per-song mean/std acoustic summaries (PCA/UMAP columns excluded).
selected_acoustic_columns = ["song_id",'dph',
 'n_frames',
 'mean_amplitude_rms',
 'mean_frequency_modulation',
 'mean_pitch_hz',
 'mean_spectral_centroid',
 'mean_spectral_rolloff',
 'mean_wiener_entropy',
 'mean_zero_crossing_rate',
 'std_amplitude_rms',
 'std_frequency_modulation',
 'std_pitch_hz',
 'std_spectral_centroid',
 'std_spectral_rolloff',
 'std_wiener_entropy',
 'std_zero_crossing_rate']
acoustic_feature_data = full_acoustic_data.loc[:, selected_acoustic_columns]

In [None]:
# Neural columns entering the correlation analysis: the song_id join key plus
# the neural features and per-region spike counts (UMAP columns excluded).
selected_neural_columns = ["song_id", 'CO_Embed2_Dist_tau_d_expfit_meandiff', 'CO_FirstMin_ac',
       'CO_HistogramAMI_even_2_5', 'CO_f1ecac', 'CO_trev_1_num',
       'DN_HistogramMode_10', 'DN_HistogramMode_5',
       'DN_OutlierInclude_n_001_mdrmd', 'DN_OutlierInclude_p_001_mdrmd',
       'FC_LocalSimple_mean1_tauresrat', 'FC_LocalSimple_mean3_stderr',
       'IN_AutoMutualInfoStats_40_gaussian_fmmi', 'MD_hrv_classic_pnn40',
       'PD_PeriodicityWang_th0_01', 'SB_BinaryStats_diff_longstretch0',
       'SB_BinaryStats_mean_longstretch1', 'SB_MotifThree_quantile_hh',
       'SB_TransitionMatrix_3ac_sumdiagcov',
       'SC_FluctAnal_2_dfa_50_1_2_logi_prop_r1',
       'SC_FluctAnal_2_rsrangefit_50_1_logi_prop_r1',
       'SP_Summaries_welch_rect_area_5_1', 'SP_Summaries_welch_rect_centroid',
       'spikes_AreaX', 'spikes_LMAN', 'spikes_Pallium', 'spikes_Striatum']
# One neural feature table per time lag, each with the same selected columns.
neural_feature_data_lag_neg100 = full_neural_data.loc[full_neural_data.time_lag == -100, selected_neural_columns]
neural_feature_data_lag_neg50 = full_neural_data.loc[full_neural_data.time_lag == -50, selected_neural_columns]
neural_feature_data_lag0 = full_neural_data.loc[full_neural_data.time_lag == 0, selected_neural_columns]
neural_feature_data_lag_pos50 = full_neural_data.loc[full_neural_data.time_lag == 50, selected_neural_columns]
neural_feature_data_lag_pos100 = full_neural_data.loc[full_neural_data.time_lag == 100, selected_neural_columns]

In [None]:
# Map every feature column to the category used for the heatmap side colours.
# song_id / n_frames / dph are "Metadata"; everything else takes the category
# of the table it comes from (acoustic first, then the lag-0 neural table).
metadata_columns = ["song_id", "n_frames", "dph"]
feature_category_dict = {}
for category, feature_frame in (("Acoustic", acoustic_feature_data), ("Neural", neural_feature_data_lag0)):
    for c in feature_frame.columns:
        feature_category_dict[c] = "Metadata" if c in metadata_columns else category
feature_category_dict

In [None]:
# Correlations between the acoustic and neural features at lag -100; the two
# tables are inner-joined on song_id.
corrs_neg100 = merged_correlation_matrices(acoustic_feature_data, neural_feature_data_lag_neg100, id_col="song_id", how="inner")

cg = clustered_corr_heatmap(
    corrs_neg100["pearson"],
    feature_to_category=feature_category_dict,
    dendrogram_ratio=(0.06, 0.06),
    title="Pearson correlations (lag=-100)",
    category_palette={"Metadata": "k", "Neural": "tab:green", "Acoustic": "tab:orange"},
    show_category_legend=False,
    cmap="seismic",
    method="average",
)
# Narrow the colour bar to a width of 0.02 (figure coordinates) and stretch
# it to 1.2x its default height.
pos = cg.cax.get_position()
cg.cax.set_position([pos.x0, pos.y0, 0.02, pos.height * 1.2])

Path("plots/correlation_maps").mkdir(parents=True, exist_ok=True)
cg.savefig("plots/correlation_maps/pearson_correlation_acoustic_neural_features_lag-100.png", dpi=300)

In [None]:
# Correlations between the acoustic and neural features at lag -50; the two
# tables are inner-joined on song_id.
corrs_neg50 = merged_correlation_matrices(acoustic_feature_data, neural_feature_data_lag_neg50, id_col="song_id", how="inner")

cg = clustered_corr_heatmap(
    corrs_neg50["pearson"],
    feature_to_category=feature_category_dict,
    dendrogram_ratio=(0.06, 0.06),
    title="Pearson correlations (lag=-50)",
    category_palette={"Metadata": "k", "Neural": "tab:green", "Acoustic": "tab:orange"},
    show_category_legend=False,
    cmap="seismic",
    method="average",
)
# Narrow the colour bar to a width of 0.02 (figure coordinates) and stretch
# it to 1.2x its default height.
pos = cg.cax.get_position()
cg.cax.set_position([pos.x0, pos.y0, 0.02, pos.height * 1.2])

Path("plots/correlation_maps").mkdir(parents=True, exist_ok=True)
cg.savefig("plots/correlation_maps/pearson_correlation_acoustic_neural_features_lag-50.png", dpi=300)

In [None]:
# Clustered Pearson correlation map between acoustic and neural (lag=0) features.
corrs_0 = merged_correlation_matrices(
    acoustic_feature_data,
    neural_feature_data_lag0,
    id_col="song_id",
    how="inner",
)

cg = clustered_corr_heatmap(
    corrs_0["pearson"],
    feature_to_category=feature_category_dict,
    dendrogram_ratio=(0.06, 0.06),
    title="Pearson correlations (lag=0)",
    category_palette={"Metadata": "k", "Neural": "tab:green", "Acoustic": "tab:orange"},
    show_category_legend=False,
    cmap="seismic",
    method="average",
)
# Slim the colour-bar axes to a width of 0.02 (figure coordinates) and make it 20% taller.
pos = cg.cax.get_position()
cg.cax.set_position([pos.x0, pos.y0, 0.02, 1.2 * pos.height])
cg.savefig("plots/correlation_maps/pearson_correlation_acoustic_neural_features_lag0.png", dpi=300)

In [None]:
# Clustered Pearson correlation map between acoustic and neural (lag=50) features.
corrs_pos50 = merged_correlation_matrices(
    acoustic_feature_data,
    neural_feature_data_lag_pos50,
    id_col="song_id",
    how="inner",
)

cg = clustered_corr_heatmap(
    corrs_pos50["pearson"],
    feature_to_category=feature_category_dict,
    dendrogram_ratio=(0.06, 0.06),
    title="Pearson correlations (lag=50)",
    category_palette={"Metadata": "k", "Neural": "tab:green", "Acoustic": "tab:orange"},
    show_category_legend=False,
    cmap="seismic",
    method="average",
)
# Slim the colour-bar axes to a width of 0.02 (figure coordinates) and make it 20% taller.
pos = cg.cax.get_position()
cg.cax.set_position([pos.x0, pos.y0, 0.02, 1.2 * pos.height])
cg.savefig("plots/correlation_maps/pearson_correlation_acoustic_neural_features_lag50.png", dpi=300)

In [None]:
# Clustered Pearson correlation map between acoustic and neural (lag=100) features.
corrs_pos100 = merged_correlation_matrices(
    acoustic_feature_data,
    neural_feature_data_lag_pos100,
    id_col="song_id",
    how="inner",
)

cg = clustered_corr_heatmap(
    corrs_pos100["pearson"],
    feature_to_category=feature_category_dict,
    dendrogram_ratio=(0.06, 0.06),
    title="Pearson correlations (lag=100)",
    category_palette={"Metadata": "k", "Neural": "tab:green", "Acoustic": "tab:orange"},
    show_category_legend=False,
    cmap="seismic",
    method="average",
)
# Slim the colour-bar axes to a width of 0.02 (figure coordinates) and make it 20% taller.
pos = cg.cax.get_position()
cg.cax.set_position([pos.x0, pos.y0, 0.02, 1.2 * pos.height])
cg.savefig("plots/correlation_maps/pearson_correlation_acoustic_neural_features_lag100.png", dpi=300)

----

## 6. Cluster analyses

Next, we will use agglomerative clustering with Ward linkage to group songs into specific groups based on i) their acoustic, ii) neural and iii) both their acoustic and neural features together. To this end, we will use the following feature representations:

- Acoustic features of the songs represented by their averaged values
- Neural features are represented as they are.

Again, it goes without saying that we will normalize the data before fitting the clustering models.

Using the cluster results, we will assess the distribution of the individual features by cluster to identify which features determine the clustering and assess the co-clustering of the songs by neural and acoustic features using a number of different cluster metrics.

### 6.1. Acoustic data

In [None]:
# Song-level mean acoustic features; search k in [4, 12] for agglomerative clustering
# (sklearn's AgglomerativeClustering uses Ward linkage by default).
mean_cols = list(filter(lambda col: col.startswith("mean_"), full_acoustic_data.columns))
viz, best_k = plot_k_elbow(full_acoustic_data, mean_cols, AgglomerativeClustering(), k_range=(4, 12))
print("Suggested k:", best_k)

Given that we achieve a maximum Calinski-Harabasz score at 7 within the range from 4 to 12 clusters, we will choose this number of clusters for the further analyses.

In [None]:
# Attach 7 Ward clusters (computed on the mean acoustic features) as column "ward_clusters".
# linkage="ward" is sklearn's default; it is spelled out here to match the column name.
full_acoustic_data, _ = add_cluster_labels(
    full_acoustic_data,
    model=AgglomerativeClustering(linkage="ward"),
    feature_cols=mean_cols,
    k_param="n_clusters",
    k=7,
    cluster_col="ward_clusters",
)


In [None]:
# UMAP embedding ("umap_mean_*" columns) of the songs, coloured by acoustic Ward cluster.
sns.set_style("ticks")
fig, ax = plt.subplots(figsize=(8, 6))
sns.scatterplot(
    data=full_acoustic_data,
    x="umap_mean_1",
    y="umap_mean_2",
    hue="ward_clusters",
    s=4,
    ax=ax,
    # Use the explicit cluster -> colour mapping shared with the composition and
    # Sankey plots. A palette *name* with an integer hue is applied as a
    # continuous colormap, so cluster colours would not match those figures.
    palette=cluster_palette,
    # With more than 6 numeric hue levels seaborn defaults to a "brief" legend
    # that lists only a subset of the clusters; force every cluster to appear.
    legend="full",
)

ax.set_xlabel("UMAP 1")
ax.set_ylabel("UMAP 2")
ax.set_title("Mean acoustic data")
legend = ax.legend(
    title="Clusters",
    fontsize=12,        # label font size
    title_fontsize=12,  # title font size
    markerscale=3,      # scales up legend marker size
    frameon=False,
)

# Save before showing, so the file does not depend on the backend keeping the figure open.
fig.savefig("plots/cluster_analyses/mean_umap_acoustic_clusters.png", dpi=300)
plt.show()
plt.close(fig)


In [None]:
# UMAP embedding ("umap_1"/"umap_2" columns) of the songs, coloured by acoustic Ward cluster.
fig, ax = plt.subplots(figsize=(8, 6))
sns.scatterplot(
    data=full_acoustic_data,
    x="umap_1",
    y="umap_2",
    hue="ward_clusters",
    s=4,
    ax=ax,
    # Same explicit cluster -> colour mapping as the other acoustic-cluster figures;
    # a palette name with an integer hue would be applied as a continuous colormap.
    palette=cluster_palette,
    # Force all 7 clusters into the legend (seaborn's "brief" legend drops some).
    legend="full",
)

ax.set_xlabel("UMAP 1")
ax.set_ylabel("UMAP 2")
# NOTE(review): the previous cell's title was copy-pasted here although this plot uses
# the umap_1/umap_2 embedding; the title now follows the output file name ("unrolled").
# Confirm the wording.
ax.set_title("Unrolled acoustic data")
legend = ax.legend(
    title="Clusters",
    fontsize=12,        # label font size
    title_fontsize=12,  # title font size
    markerscale=3,      # scales up legend marker size
    frameon=False,
)

# Save before showing, so the file does not depend on the backend keeping the figure open.
fig.savefig("plots/cluster_analyses/umap_unrolled_all_birds_cluster.png", dpi=300)
plt.show()
plt.close(fig)


We now visualize the distribution of songs per cluster for the individual birds.

In [None]:
# Percentage of each bird's songs falling into each acoustic Ward cluster.
with sns.axes_style("whitegrid"):
    plot_cluster_percent_by_bird(
        full_acoustic_data,
        bird_col="bird",
        cluster_col="ward_clusters",
        palette=cluster_palette,
        min_label_pct=5,
    )
    # Capture the figure BEFORE plt.show(). The inline backend closes all figures
    # on show(), so a later plt.gcf() would create and save a new, blank figure.
    fig = plt.gcf()
    fig.tight_layout()
    fig.savefig("plots/cluster_analyses/acoustic_cluster_composition.png", dpi=300)
    plt.show()


In [None]:
# Column-normalised per-cluster medians of metadata + mean_/std_ acoustic features.
all_acoustic_feature_cols = [*mean_cols, *(c for c in full_acoustic_data.columns if c.startswith("std_"))]
sns.set_style("whitegrid")
fig, ax, med = matrixplot_median_by_cluster(
    full_acoustic_data,
    cluster_col="ward_clusters",
    feature_cols=["dph", "n_frames", *all_acoustic_feature_cols],
    normalize="columns",
    cmap="coolwarm",
    figsize=[8, 5],
    cbar_label="Scaled median",
)
ax.set(xlabel="Features", ylabel="Clusters")
plt.xticks(rotation=90)
ax.set_title("Feature profiles of the acoustic clusters")
plt.show()
fig.savefig("plots/cluster_analyses/feature_profiles_acoustic_cluster.png", dpi=300)


In [None]:
# Distribution of the std_pitch_hz feature within each acoustic cluster.
sns.set_style("ticks")
sns.boxplot(x="ward_clusters", y="std_pitch_hz", data=full_acoustic_data)
plt.show()

---

### 6.2. Neural features

Since we previously saw little difference between the time lags, we will only add the clusters for lag 0 here.

In [None]:
# Search k in [4, 12] for agglomerative (Ward) clustering of the lag-0 neural features.
viz, best_k = plot_k_elbow(
    full_neural_data[full_neural_data["time_lag"].eq(0)],
    neural_features,
    AgglomerativeClustering(),
    k_range=(4, 12),
)
print("Suggested k:", best_k)

In [None]:
# Keep only lag-0 rows and attach 4 Ward clusters on the neural features as "ward_clusters_lag0".
full_neural_data_lag0, _ = add_cluster_labels(
    full_neural_data[full_neural_data["time_lag"].eq(0)],
    model=AgglomerativeClustering(linkage="ward"),
    feature_cols=neural_features,
    k_param="n_clusters",
    k=4,
    cluster_col="ward_clusters_lag0",
)

In [None]:
# UMAP embedding ("umap_lag0_*" columns) of the lag-0 neural data, coloured by neural Ward cluster.
sns.set_style("ticks")
fig, ax = plt.subplots(figsize=(8, 6))
sns.scatterplot(
    data=full_neural_data_lag0,
    x="umap_lag0_1",
    y="umap_lag0_2",
    hue="ward_clusters_lag0",
    s=4,
    ax=ax,
    # Map the i-th sorted cluster label to the i-th Set2 colour. This is the same
    # mapping as the neural cluster palette used in later cells. A palette *name*
    # with an integer hue is applied as a continuous colormap and skips colours.
    palette=dict(
        zip(sorted(full_neural_data_lag0["ward_clusters_lag0"].unique()), sns.color_palette("Set2"))
    ),
    legend="full",
)

ax.set_xlabel("UMAP 1")
ax.set_ylabel("UMAP 2")
ax.set_title("Neural data (lag=0)")
legend = ax.legend(
    title="Clusters",
    fontsize=12,        # label font size
    title_fontsize=12,  # title font size
    markerscale=3,      # scales up legend marker size
    frameon=False,
)

# Save before showing (and at the same dpi as the other cluster figures).
fig.savefig("plots/cluster_analyses/umap_neural_features_lag0_cluster.png", dpi=300)
plt.show()
plt.close(fig)


In [None]:
# Neural cluster label -> colour (first four ColorBrewer Set2 colours).
neural_cluster_palette = dict(enumerate(["#66c2a5", "#fc8d62", "#8da0cb", "#e78ac3"]))

In [None]:
# Percentage of each bird's songs falling into each lag-0 neural Ward cluster.
with sns.axes_style("whitegrid"):
    plot_cluster_percent_by_bird(
        full_neural_data_lag0,
        bird_col="bird",
        cluster_col="ward_clusters_lag0",
        palette=neural_cluster_palette,
        min_label_pct=5,
    )
    # Capture the figure BEFORE plt.show(). The inline backend closes all figures
    # on show(), so a later plt.gcf() would create and save a new, blank figure.
    fig = plt.gcf()
    fig.tight_layout()
    fig.savefig("plots/cluster_analyses/neural_cluster_composition_by_bird.png", dpi=300)
    plt.show()


Again, the neural clusters appear less bird-specific. The exception is r14n14, which has the lowest absolute number of songs, a large share of which fall into cluster 1. We will now briefly characterize the feature profiles of these clusters, as we did for the acoustic clusters.

In [None]:
# Column-normalised per-cluster medians of metadata + neural features (lag 0).
fig, ax, med = matrixplot_median_by_cluster(
    full_neural_data_lag0,
    cluster_col="ward_clusters_lag0",
    feature_cols=["dph", "n_frames", *neural_features],
    normalize="columns",
    cmap="coolwarm",
    figsize=[8, 5],
    cbar_label="Scaled median",
)
ax.set(xlabel="Features", ylabel="Clusters")
plt.xticks(rotation=90)
ax.set_title("Feature profiles of the neural clusters")
plt.show()
fig.savefig("plots/cluster_analyses/feature_profile_neural_clusters_lag0.png", dpi=300)
plt.close()


---

### 6.3. Cluster agreement

In [None]:
# One row per song present in both tables: bird, neural (lag 0) and acoustic cluster.
# Labels are converted to strings (not pandas categoricals) so they act as discrete
# node names downstream.
cluster_summary = (
    full_neural_data_lag0[["song_id", "bird", "ward_clusters_lag0"]]
    .merge(full_acoustic_data[["song_id", "ward_clusters"]], on="song_id", how="inner")
    .rename(columns={"ward_clusters_lag0": "neural_clusters_lag0", "ward_clusters": "acoustic_clusters"})
    .astype({"neural_clusters_lag0": str, "acoustic_clusters": str, "bird": str})
)

cluster_summary.head()

In [None]:
# Bird colours restated as hex codes matching matplotlib's "tab:" colours used in
# bird_palette. Presumably plotly cannot resolve the "tab:" names.
hex_bird_palette = {
    "g4r4": "#1f77b4",    # tab:blue
    "j8v8": "#ff7f0e",    # tab:orange
    "o11y3": "#2ca02c",   # tab:green
    "r11n11": "#d62728",  # tab:red
    "r14n14": "#9467bd",  # tab:purple
    "r15v15": "#8c564b",  # tab:brown
}

# Sankey diagram of bird -> acoustic cluster -> neural (lag 0) cluster.
fig, song_level = plot_sankey_clusters(
    cluster_summary,
    title="Cluster agreement (bird - acoustic - neural (lag=0))",
    bird_col="bird",
    bird_colors=hex_bird_palette,
    cluster1_col="acoustic_clusters",
    cluster1_label="Acoustic",
    cluster1_colors=cluster_palette,
    cluster2_col="neural_clusters_lag0",
    cluster2_label="Neural",
    cluster2_colors=neural_cluster_palette,
    unit_col="song_id",  # IMPORTANT: aggregate to one row per song
    unit_how="mode",
)
# Render the plotly figure to PNG bytes with the kaleido engine and write them to disk.
img = pio.to_image(fig, format="png", engine="kaleido")
with open("plots/sankey_plot_clustering_birds.png", "wb") as f:
    f.write(img)


---

## 7. Analyses of the daily feature differences

Finally, we investigate whether the day-to-day changes of the average neural and acoustic features of each bird show similar patterns. To this end, we first compute the mean acoustic song features averaged per bird and day, and the neural features averaged per bird and day.

### 7.1. Across birds

In [None]:
# Preview: average of every numeric acoustic column per (bird, day post hatch).
(
    full_acoustic_data
    .groupby(["bird", "dph"])
    .mean(numeric_only=True)
    .reset_index()
    .head()
)

In [None]:
# Display the acoustic feature columns (mean_* followed by std_*) defined in section 6.1.
all_acoustic_feature_cols

In [None]:
# Per-bird, per-day change of the song-level mean acoustic features (mean_cols)
# relative to a baseline day. The baseline and the aggregation are defined by
# abs_feature_change_from_baseline_day earlier in the notebook; its output columns
# are referenced below as "<feature>_abs_change_in_pct".
acoustic_feature_change_per_daybird = abs_feature_change_from_baseline_day(
    full_acoustic_data,
    feature_cols=mean_cols,
    bird_col="bird",
    day_col="dph",
)
acoustic_feature_change_per_daybird.head()

In [None]:
# Display the neural feature columns used throughout this section.
neural_features

In [None]:
# Per-bird, per-day change of the lag-0 neural features relative to a baseline day,
# computed with the same helper as the acoustic changes above.
neural_feature_change_per_daybird = abs_feature_change_from_baseline_day(
    full_neural_data_lag0,
    feature_cols=neural_features,
    bird_col="bird",
    day_col="dph",
)
neural_feature_change_per_daybird.head()

We have now also computed the absolute difference of the day/bird averages with respect to the smallest day post hatch (the baseline day). To account for the different scales of the features, we normalized these differences by the absolute value of the baseline, so that we now have, for each neural and acoustic feature, its relative change with respect to the baseline. For the acoustic features, we did not use the frame-level features but the song-level averages described before.

We will now compute one summary statistic for the change in neural activity, which is simply given by the mean of the relative changes per feature at each day with respect to the value at the baseline day.

We will now assess the correlation of this neural change summary statistic with the change of each acoustic feature individually, to see whether the variation of any of the acoustic features can be explained by an overall change in neural activity.

In [None]:
# Join neural and acoustic day-level changes per (bird, dph). The neural summary
# statistic is the row-wise mean of all neural relative changes. Rows whose statistic
# is not > 0 (including NaN) are dropped; presumably this removes the baseline day,
# whose change is zero by construction.
combined_feature_change = pd.merge(
    neural_feature_change_per_daybird,
    acoustic_feature_change_per_daybird,
    on=["bird", "dph"],
    how="inner",
)
combined_feature_change["neural_change_statistic"] = combined_feature_change.loc[
    :, [f"{feat}_abs_change_in_pct" for feat in neural_features]
].mean(axis=1)
combined_feature_change = combined_feature_change[combined_feature_change["neural_change_statistic"].gt(0)]
combined_feature_change = combined_feature_change.loc[
    :, ["bird", "dph", "neural_change_statistic", *(f"{feat}_abs_change_in_pct" for feat in mean_cols)]
]
combined_feature_change.head()

In [None]:
# Relative-change columns of the seven song-level mean acoustic features to plot.
features = [
    f"mean_{name}_abs_change_in_pct"
    for name in (
        "amplitude_rms",
        "frequency_modulation",
        "pitch_hz",
        "spectral_centroid",
        "spectral_rolloff",
        "wiener_entropy",
        "zero_crossing_rate",
    )
]

# Write one 2-D KDE plot per acoustic change column in `features` to plots/kde_plots.
# NOTE(review): presumably each plot pairs the column with "neural_change_statistic";
# confirm against save_2d_kde_plots.
save_2d_kde_plots(
    combined_feature_change,
    features=features,
    out_dir="plots/kde_plots",
)

In addition to the summary statistic, we also assess the correlations between the changes of the individual neural and acoustic features.

In [None]:
# All neural and acoustic relative-change columns side by side, and their pairwise
# Pearson correlation matrix.
# NOTE(review): unlike combined_feature_change, baseline-day rows (presumably all
# zero changes) are kept here. That adds one point at the origin per bird and may
# inflate the correlations; confirm this is intended.
full_combine_feature_change = pd.merge(
    neural_feature_change_per_daybird,
    acoustic_feature_change_per_daybird,
    on=["bird", "dph"],
    how="inner",
).filter(like="_in_pct")
change_corr_mtx = full_combine_feature_change.corr()



In [None]:
# Label each correlation-matrix column by its source table: "Neural" if it exists
# in the neural change table, otherwise "Acoustic".
change_feature_categories = [
    "Neural" if col in neural_feature_change_per_daybird.columns else "Acoustic"
    for col in change_corr_mtx.columns
]


In [None]:
# Clustered heatmap of correlations between the relative day-level feature changes.
cg = clustered_corr_heatmap(
    change_corr_mtx,
    feature_to_category=dict(zip(change_corr_mtx.columns, change_feature_categories)),
    dendrogram_ratio=(0.06, 0.06),
    title="Pearson correlations abs. rel. day feature change (lag=0)",
    category_palette={"Metadata": "k", "Neural": "tab:green", "Acoustic": "tab:orange"},
    show_category_legend=False,
    cmap="seismic",
    method="average",
)
# Slim the colour-bar axes to a width of 0.02 (figure coordinates) and make it 20% taller.
pos = cg.cax.get_position()
cg.cax.set_position([pos.x0, pos.y0, 0.02, 1.2 * pos.height])
cg.savefig("plots/correlation_maps/neural_acoustic_feature_change_correlation.png", dpi=300)

In [None]:
# Filled 2-D density of Wiener-entropy change vs. DN_OutlierInclude_p_001_mdrmd change.
# NOTE(review): the fixed axis limits below may clip part of the density; confirm the ranges.
fig, ax = plt.subplots(figsize=(6, 6))

sns.kdeplot(
    data=full_combine_feature_change,
    x="mean_wiener_entropy_abs_change_in_pct",
    y="DN_OutlierInclude_p_001_mdrmd_abs_change_in_pct",
    fill=True,
    cmap="inferno",
    levels=60,
    thresh=0,
    ax=ax,
)

ax.set(
    title="",
    xlabel="Change in Wiener Entropy",
    ylabel="Change in DN Outliers (p001_mdrmd)",
    xlim=[-1, 2],
    ylim=[-4.5, 10],
)
fig.savefig("plots/kde_plots/wiener_entropy_vs_dn_outlier_change_inferno.png", dpi=300)
plt.show()

In [None]:
# Contour version of the same 2-D density, with default seaborn styling and automatic axis limits.
fig, ax = plt.subplots(figsize=(6, 6))

sns.kdeplot(
    data=full_combine_feature_change,
    x="mean_wiener_entropy_abs_change_in_pct",
    y="DN_OutlierInclude_p_001_mdrmd_abs_change_in_pct",
    ax=ax,
)

ax.set(title="", xlabel="Change in Wiener Entropy", ylabel="Change in DN Outliers (p001_mdrmd)")
fig.tight_layout()
fig.savefig("plots/kde_plots/wiener_entropy_vs_dn_outlier_change.png", dpi=300)
plt.show()

---
### 7.2. By-bird, by-lag view

In [None]:
# Identifier/metadata columns followed by the selected neural features
# (time-series summary statistics and per-region spike features).
selected_neural_ext_columns = ["song_id", "bird", "dph"] + [
    "CO_Embed2_Dist_tau_d_expfit_meandiff",
    "CO_FirstMin_ac",
    "CO_HistogramAMI_even_2_5",
    "CO_f1ecac",
    "CO_trev_1_num",
    "DN_HistogramMode_10",
    "DN_HistogramMode_5",
    "DN_OutlierInclude_n_001_mdrmd",
    "DN_OutlierInclude_p_001_mdrmd",
    "FC_LocalSimple_mean1_tauresrat",
    "FC_LocalSimple_mean3_stderr",
    "IN_AutoMutualInfoStats_40_gaussian_fmmi",
    "MD_hrv_classic_pnn40",
    "PD_PeriodicityWang_th0_01",
    "SB_BinaryStats_diff_longstretch0",
    "SB_BinaryStats_mean_longstretch1",
    "SB_MotifThree_quantile_hh",
    "SB_TransitionMatrix_3ac_sumdiagcov",
    "SC_FluctAnal_2_dfa_50_1_2_logi_prop_r1",
    "SC_FluctAnal_2_rsrangefit_50_1_logi_prop_r1",
    "SP_Summaries_welch_rect_area_5_1",
    "SP_Summaries_welch_rect_centroid",
    *(f"spikes_{region}" for region in ("AreaX", "LMAN", "Pallium", "Striatum")),
]


# Relative-change columns of the seven mean acoustic features that are correlated per bird.
target_acoustic_features = [
    "mean_" + name + "_abs_change_in_pct"
    for name in [
        "amplitude_rms",
        "frequency_modulation",
        "pitch_hz",
        "spectral_centroid",
        "spectral_rolloff",
        "wiener_entropy",
        "zero_crossing_rate",
    ]
]

In [None]:
# Neural data (metadata + selected features) split by time lag.
neural_data_lag_neg100 = full_neural_data.loc[full_neural_data["time_lag"].eq(-100), selected_neural_ext_columns]
neural_data_lag_neg50 = full_neural_data.loc[full_neural_data["time_lag"].eq(-50), selected_neural_ext_columns]
neural_data_lag0 = full_neural_data.loc[full_neural_data["time_lag"].eq(0), selected_neural_ext_columns]
neural_data_lag_pos50 = full_neural_data.loc[full_neural_data["time_lag"].eq(50), selected_neural_ext_columns]
neural_data_lag_pos100 = full_neural_data.loc[full_neural_data["time_lag"].eq(100), selected_neural_ext_columns]

#### 7.2.1. Lag -100

In [None]:
# Lag -100: per-bird correlation between the neural change summary statistic and the
# relative change of each target acoustic feature.
neural_feature_change_lagneg100 = abs_feature_change_from_baseline_day(
    neural_data_lag_neg100,
    feature_cols=neural_features,
    bird_col="bird",
    day_col="dph",
)

combined_feature_change_lagneg100 = pd.merge(
    neural_feature_change_lagneg100,
    acoustic_feature_change_per_daybird,
    on=["bird", "dph"],
    how="inner",
)
# Summary statistic: row-wise mean of all neural relative changes.
# NOTE(review): unlike section 7.1, baseline-day rows (presumably all-zero changes) are
# not dropped here, which can inflate per-bird correlations. Confirm this is intended.
combined_feature_change_lagneg100["neural_change_statistic"] = combined_feature_change_lagneg100.loc[
    :, [f"{f}_abs_change_in_pct" for f in neural_features]
].mean(axis=1)

corr_lagneg100_by_bird = corr_by_bird(
    combined_feature_change_lagneg100, target_acoustic_features, target_col="neural_change_statistic"
)

fig, ax = plt.subplots(figsize=[6, 8])
# Anchor the diverging colormap to [-1, 1]. Otherwise seaborn spans only the observed
# min..max, so white would not mean r = 0 and colours would differ between lags.
ax = sns.heatmap(corr_lagneg100_by_bird, annot=True, cmap="seismic", fmt=".2f", vmin=-1, vmax=1, ax=ax)
ax.set_title("Correlation neural activity change (lag=-100)")
fig.tight_layout()
fig.savefig("plots/kde_plots/correlation_neural_activity_change_by_bird_lag-100.png", dpi=300)
plt.show()

#### 7.2.2. Lag -50

In [None]:
# Lag -50: per-bird correlation between the neural change summary statistic and the
# relative change of each target acoustic feature.
neural_feature_change_lagneg50 = abs_feature_change_from_baseline_day(
    neural_data_lag_neg50,
    feature_cols=neural_features,
    bird_col="bird",
    day_col="dph",
)

combined_feature_change_lagneg50 = pd.merge(
    neural_feature_change_lagneg50,
    acoustic_feature_change_per_daybird,
    on=["bird", "dph"],
    how="inner",
)
# Summary statistic: row-wise mean of all neural relative changes.
# NOTE(review): unlike section 7.1, baseline-day rows (presumably all-zero changes) are
# not dropped here, which can inflate per-bird correlations. Confirm this is intended.
combined_feature_change_lagneg50["neural_change_statistic"] = combined_feature_change_lagneg50.loc[
    :, [f"{f}_abs_change_in_pct" for f in neural_features]
].mean(axis=1)

corr_lagneg50_by_bird = corr_by_bird(
    combined_feature_change_lagneg50, target_acoustic_features, target_col="neural_change_statistic"
)

fig, ax = plt.subplots(figsize=[6, 8])
# Anchor the diverging colormap to [-1, 1]. Otherwise seaborn spans only the observed
# min..max, so white would not mean r = 0 and colours would differ between lags.
ax = sns.heatmap(corr_lagneg50_by_bird, annot=True, cmap="seismic", fmt=".2f", vmin=-1, vmax=1, ax=ax)
ax.set_title("Correlation neural activity change (lag=-50)")
fig.tight_layout()
fig.savefig("plots/kde_plots/correlation_neural_activity_change_by_bird_lag-50.png", dpi=300)
plt.show()

#### 7.2.3. Lag 0

In [None]:
# Lag 0: per-bird correlation between the neural change summary statistic and the
# relative change of each target acoustic feature. The "lagneg0" names are kept
# because they are existing notebook variables.
neural_feature_change_lagneg0 = abs_feature_change_from_baseline_day(
    neural_data_lag0,
    feature_cols=neural_features,
    bird_col="bird",
    day_col="dph",
)

combined_feature_change_lagneg0 = pd.merge(
    neural_feature_change_lagneg0,
    acoustic_feature_change_per_daybird,
    on=["bird", "dph"],
    how="inner",
)
# Summary statistic: row-wise mean of all neural relative changes.
# NOTE(review): unlike section 7.1, baseline-day rows (presumably all-zero changes) are
# not dropped here, which can inflate per-bird correlations. Confirm this is intended.
combined_feature_change_lagneg0["neural_change_statistic"] = combined_feature_change_lagneg0.loc[
    :, [f"{f}_abs_change_in_pct" for f in neural_features]
].mean(axis=1)

corr_lagneg0_by_bird = corr_by_bird(
    combined_feature_change_lagneg0, target_acoustic_features, target_col="neural_change_statistic"
)

fig, ax = plt.subplots(figsize=[6, 8])
# Anchor the diverging colormap to [-1, 1]. Otherwise seaborn spans only the observed
# min..max, so white would not mean r = 0 and colours would differ between lags.
ax = sns.heatmap(corr_lagneg0_by_bird, annot=True, cmap="seismic", fmt=".2f", vmin=-1, vmax=1, ax=ax)
ax.set_title("Correlation neural activity change (lag=0)")
fig.tight_layout()
fig.savefig("plots/kde_plots/correlation_neural_activity_change_by_bird_lag0.png", dpi=300)
plt.show()

#### 7.2.4. Lag 50

In [None]:
# Lag 50: per-bird correlation between the neural change summary statistic and the
# relative change of each target acoustic feature.
neural_feature_change_lagpos50 = abs_feature_change_from_baseline_day(
    neural_data_lag_pos50,
    feature_cols=neural_features,
    bird_col="bird",
    day_col="dph",
)

combined_feature_change_lagpos50 = pd.merge(
    neural_feature_change_lagpos50,
    acoustic_feature_change_per_daybird,
    on=["bird", "dph"],
    how="inner",
)
# Summary statistic: row-wise mean of all neural relative changes.
# NOTE(review): unlike section 7.1, baseline-day rows (presumably all-zero changes) are
# not dropped here, which can inflate per-bird correlations. Confirm this is intended.
combined_feature_change_lagpos50["neural_change_statistic"] = combined_feature_change_lagpos50.loc[
    :, [f"{f}_abs_change_in_pct" for f in neural_features]
].mean(axis=1)

corr_lagpos50_by_bird = corr_by_bird(
    combined_feature_change_lagpos50, target_acoustic_features, target_col="neural_change_statistic"
)

fig, ax = plt.subplots(figsize=[6, 8])
# Anchor the diverging colormap to [-1, 1]. Otherwise seaborn spans only the observed
# min..max, so white would not mean r = 0 and colours would differ between lags.
ax = sns.heatmap(corr_lagpos50_by_bird, annot=True, cmap="seismic", fmt=".2f", vmin=-1, vmax=1, ax=ax)
ax.set_title("Correlation neural activity change (lag=50)")
fig.tight_layout()
fig.savefig("plots/kde_plots/correlation_neural_activity_change_by_bird_lag_50.png", dpi=300)
plt.show()

In [None]:
# Lag 100: per-bird correlation between the neural change summary statistic and the
# relative change of each target acoustic feature.
neural_feature_change_lagpos100 = abs_feature_change_from_baseline_day(
    neural_data_lag_pos100,
    feature_cols=neural_features,
    bird_col="bird",
    day_col="dph",
)

combined_feature_change_lagpos100 = pd.merge(
    neural_feature_change_lagpos100,
    acoustic_feature_change_per_daybird,
    on=["bird", "dph"],
    how="inner",
)
# Summary statistic: row-wise mean of all neural relative changes.
# NOTE(review): unlike section 7.1, baseline-day rows (presumably all-zero changes) are
# not dropped here, which can inflate per-bird correlations. Confirm this is intended.
combined_feature_change_lagpos100["neural_change_statistic"] = combined_feature_change_lagpos100.loc[
    :, [f"{f}_abs_change_in_pct" for f in neural_features]
].mean(axis=1)

corr_lagpos100_by_bird = corr_by_bird(
    combined_feature_change_lagpos100, target_acoustic_features, target_col="neural_change_statistic"
)

fig, ax = plt.subplots(figsize=[6, 8])
# Anchor the diverging colormap to [-1, 1]. Otherwise seaborn spans only the observed
# min..max, so white would not mean r = 0 and colours would differ between lags.
ax = sns.heatmap(corr_lagpos100_by_bird, annot=True, cmap="seismic", fmt=".2f", vmin=-1, vmax=1, ax=ax)
ax.set_title("Correlation neural activity change (lag=100)")
fig.tight_layout()
fig.savefig("plots/kde_plots/correlation_neural_activity_change_by_bird_lag_100.png", dpi=300)
plt.show()

---

## 8. Session information

In [None]:
# Display information about the Python session (loaded packages and their versions)
# so the analysis environment is documented alongside the results.
import session_info


session_info.show()