In [None]:
import os
from collections import defaultdict
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from scipy.stats import sem, ttest_rel
from brainnetwork import load_data, preprocess_data, preprocess_spike_data
from brainnetwork import classify_by_timepoints, FI_by_timepoints_v2, FI_by_neuron_count
from brainnetwork import construct_correlation_network, compute_network_metrics_by_class
from brainnetwork.visualization import *


In [None]:
# Root directory holding one sub-directory per mouse recording session.
# Override with the MICEDATA_DIR environment variable when running off the cluster;
# the default is the original shared-storage location.
base_dir = os.environ.get(
    "MICEDATA_DIR", "/beegfs_hdd/data/nfs_share/users/guiyun/nishome/Micedata/"
)
# Session sub-directories under base_dir to process (each is passed to process_mouse).
data_path = ["M21_1107", "M71_1024", "M77_1031", "M78_1017", "M91_1017"]

In [None]:
# Class-label code -> display name used in tables, legends and t-test output.
CLASS_NAMES = {
    1: "Convergent",
    2: "Divergent",
    3: "Random",
}

# Ordered class pairs compared in the paired t-tests.
CLASS_PAIRS = [(1, 2), (1, 3), (2, 3)]

# Network metric columns summarized and tested per class.
NETWORK_METRICS = [
    # size / connectivity
    "n_edges", "density", "mean_degree", "largest_component",
    # clustering / efficiency
    "avg_clustering", "global_efficiency", "local_efficiency", "transitivity",
    # filled from the per-class result itself rather than its "summary" dict
    "efficiency", "modularity",
]

def process_mouse(mouse_id):
    """Run the per-mouse pipeline and return Fisher-information and network results.

    Loads the default-type recording (preprocessed as spike data, used for the
    correlation network) and the fluorescence recording (used for Fisher
    information) from ``os.path.join(base_dir, mouse_id)``.

    Parameters
    ----------
    mouse_id : str
        Name of the session sub-directory under ``base_dir``.

    Returns
    -------
    dict
        Keys ``mouse_id``, ``fisher_mv`` (multivariate FI), ``fisher_uv``
        (univariate FI, mean-reduced), ``time_points`` (FI time axis) and
        ``nx_result`` (per-class network metrics, no bootstrap).

    Raises
    ------
    ValueError
        If the multivariate and univariate FI time axes differ.
    """
    mouse_dir = os.path.join(base_dir, mouse_id)
    print(f"Processing {mouse_id} -> {mouse_dir}")

    # Default load_data call, preprocessed as spike data -> network analysis.
    spk_traces, spk_pos, spk_edges, spk_stim = load_data(mouse_dir)
    spk_segments, spk_labels, spk_pos_kept = preprocess_spike_data(
        spk_traces, spk_pos, spk_edges, spk_stim
    )

    # Fluorescence recording -> Fisher-information curves.
    flo_traces, flo_pos, flo_edges, flo_stim = load_data(mouse_dir, data_type="fluorescence")
    flo_segments, flo_labels, _flo_pos_kept = preprocess_data(
        flo_traces, flo_pos, flo_edges, flo_stim
    )

    fi_multi, fi_times = FI_by_timepoints_v2(
        flo_segments, flo_labels, mode="multivariate", reduction=None
    )
    fi_uni, fi_times_uni = FI_by_timepoints_v2(
        flo_segments, flo_labels, mode="univariate", reduction="mean"
    )
    # Only one time axis is returned, so both FI variants must agree on it.
    if not np.allclose(fi_times, fi_times_uni):
        raise ValueError("Time axes do not match between FI variants.")

    network = compute_network_metrics_by_class(
        spk_segments, spk_labels, spk_pos_kept, do_bootstrap=False
    )
    return dict(
        mouse_id=mouse_id,
        fisher_mv=fi_multi,
        fisher_uv=fi_uni,
        time_points=fi_times,
        nx_result=network,
    )


def aggregate_curve_dict(curve_store):
    """Collapse per-mouse curves into an across-mouse mean and SEM for each class pair.

    Parameters
    ----------
    curve_store : mapping
        Class pair -> list of 1-D curves (one per mouse), all the same length.

    Returns
    -------
    dict
        Class pair -> ``{"mean", "sem", "n_mice"}``. The mean and SEM ignore NaNs
        point-wise. ``n_mice`` is the number of stacked curves, including curves
        that are NaN at some time points.
    """
    def _summarize(curves):
        # One row per mouse, one column per time point.
        matrix = np.vstack([np.asarray(c) for c in curves])
        return {
            "mean": np.nanmean(matrix, axis=0),
            "sem": sem(matrix, axis=0, nan_policy="omit"),
            "n_mice": matrix.shape[0],
        }

    return {pair: _summarize(curves) for pair, curves in curve_store.items()}


def plot_average_fisher(curve_stats, time_points, title, ax=None):
    """Plot the across-mouse mean Fisher-information curve (with SEM band) per class pair.

    Parameters
    ----------
    curve_stats : mapping
        Class pair -> dict with ``"mean"`` and ``"sem"`` arrays aligned with
        ``time_points``.
    time_points : array-like
        Shared time axis, drawn on the x-axis.
    title : str
        Axes title.
    ax : matplotlib Axes, optional
        Axes to draw into. A new 10 x 4.5 inch figure is created when omitted.

    Returns
    -------
    matplotlib Axes
        The Axes that was drawn on, so callers can customise or save it.
    """
    if ax is None:
        _, ax = plt.subplots(figsize=(10, 4.5))
    for pair, stats_dict in sorted(curve_stats.items()):
        label = f"{CLASS_NAMES.get(pair[0], pair[0])} vs {CLASS_NAMES.get(pair[1], pair[1])}"
        mean_vals = stats_dict["mean"]
        err_vals = stats_dict["sem"]
        lines = ax.plot(time_points, mean_vals, linewidth=2.2, label=label)
        # Pin the band to its line's colour. fill_between otherwise picks its colour
        # from the axes' property cycle independently of the line.
        ax.fill_between(
            time_points,
            mean_vals - err_vals,
            mean_vals + err_vals,
            color=lines[0].get_color(),
            alpha=0.2,
        )
    ax.set_xlabel("Time (s relative to stimulus)")
    ax.set_ylabel("Fisher information")
    ax.set_title(title)
    ax.axvline(0, color="#bbbbbb", linestyle="--", linewidth=1)
    ax.grid(True, alpha=0.3)
    ax.legend(frameon=False)
    ax.figure.tight_layout()
    return ax


def flatten_network_summary(mouse_id, nx_result):
    """Flatten one mouse's per-class network results into a list of row dicts.

    Each row carries ``mouse_id``, ``class_label``, ``class_name``, and every key of
    the class's ``"summary"`` dict. ``efficiency`` and ``modularity`` are then read
    from the class entry itself (NaN when absent) and overwrite any same-named
    summary keys.

    Parameters
    ----------
    mouse_id : str
        Identifier stored on every row.
    nx_result : mapping
        Class label -> per-class info dict (as returned by
        ``compute_network_metrics_by_class``; structure assumed from usage here).

    Returns
    -------
    list of dict
        One row per class, in ``nx_result`` iteration order.
    """
    def _row_for(cls, info):
        summary = info.get("summary", {}).copy()
        row = {
            "mouse_id": mouse_id,
            "class_label": cls,
            "class_name": CLASS_NAMES.get(cls, str(cls)),
        }
        row.update(summary)
        row["efficiency"] = info.get("efficiency", np.nan)
        row["modularity"] = info.get("modularity", np.nan)
        return row

    return [_row_for(cls, info) for cls, info in nx_result.items()]


def summarize_network_metrics(df, metrics):
    """Compute the across-mouse mean and SEM of each network metric, per class.

    Parameters
    ----------
    df : pandas.DataFrame
        Long-format table with one row per (mouse, class) and a ``class_label``
        column plus one column per metric.
    metrics : iterable of str
        Metric columns to summarize. A metric missing from ``df`` yields NaN
        ``_mean``/``_sem`` columns instead of raising KeyError.

    Returns
    -------
    pandas.DataFrame
        One row per class, sorted by ``class_label``, with ``class_name`` and
        ``<metric>_mean`` / ``<metric>_sem`` columns. Empty (with ``class_label``
        and ``class_name`` columns) when ``df`` has no rows.
    """
    if df.empty or "class_label" not in df.columns:
        return pd.DataFrame(columns=["class_label", "class_name"])

    summary_rows = []
    for cls, group in df.groupby("class_label"):
        entry = {
            "class_label": cls,
            "class_name": CLASS_NAMES.get(cls, str(cls)),
        }
        for metric in metrics:
            if metric in group.columns:
                values = group[metric]
                entry[f"{metric}_mean"] = values.mean()
                entry[f"{metric}_sem"] = sem(values, nan_policy="omit")
            else:
                # Not collected for any mouse (e.g. a key the per-class summary lacked);
                # keep the column so the table shape is stable.
                entry[f"{metric}_mean"] = np.nan
                entry[f"{metric}_sem"] = np.nan
        summary_rows.append(entry)

    if not summary_rows:
        return pd.DataFrame(columns=["class_label", "class_name"])
    return pd.DataFrame(summary_rows).sort_values("class_label")


def plot_network_metric_bars(df, metric):
    """Draw a bar chart of the across-mouse mean (± SEM) of ``metric`` for each class.

    Values are first pivoted to one per (mouse, class). Classes with no data for
    this metric get no bar. Bars are ordered by class label and coloured from a
    fixed four-colour palette.

    Parameters
    ----------
    df : pandas.DataFrame
        Long-format table with ``mouse_id``, ``class_label`` and ``metric`` columns.
    metric : str
        Column to plot. It is also used, prettified, as the y-axis label.
    """
    per_mouse = df.pivot_table(index="mouse_id", columns="class_label", values=metric)
    palette = ["#1F77B4", "#D55E00", "#009E73", "#9467BD"]

    heights, yerrs, names = [], [], []
    for cls in sorted(per_mouse.columns):
        values = per_mouse[cls].dropna()
        if values.empty:
            # No mouse has a value for this class; omit its bar.
            continue
        heights.append(values.mean())
        yerrs.append(sem(values, nan_policy="omit"))
        names.append(CLASS_NAMES.get(cls, str(cls)))

    fig, ax = plt.subplots(figsize=(6.2, 4.2))
    ax.bar(names, heights, yerr=yerrs, capsize=4, color=palette[: len(names)])
    ax.set_ylabel(metric.replace("_", " ").title())
    ax.grid(True, axis="y", alpha=0.3)
    fig.tight_layout()


def paired_ttests(df, metrics, class_pairs=CLASS_PAIRS):
    """Run paired (within-mouse) t-tests between classes for each network metric.

    For every metric and class pair, only mice with values for both classes are
    used. Pairs with fewer than 2 such mice are skipped. No multiple-comparison
    correction is applied.

    Parameters
    ----------
    df : pandas.DataFrame
        Long-format table with ``mouse_id``, ``class_label`` and metric columns.
    metrics : iterable of str
        Metric columns to test. Metrics absent from ``df`` are skipped.
    class_pairs : iterable of (label, label)
        Class-label pairs to compare as ``class_a`` vs ``class_b``.

    Returns
    -------
    pandas.DataFrame
        Columns ``metric``, ``class_a``, ``class_b``, ``n_mice``, ``t_stat``,
        ``p_value``, ``mean_diff`` (mean of a - b), sorted by ascending ``p_value``.
        The frame is empty, but has these columns, when no test could be run.
    """
    columns = ["metric", "class_a", "class_b", "n_mice", "t_stat", "p_value", "mean_diff"]
    records = []
    for metric in metrics:
        if metric not in df.columns:
            # pivot_table would raise KeyError for a metric that was never collected.
            continue
        pivot = df.pivot_table(index="mouse_id", columns="class_label", values=metric)
        for cls_a, cls_b in class_pairs:
            if cls_a not in pivot.columns or cls_b not in pivot.columns:
                continue
            paired = pivot[[cls_a, cls_b]].dropna()
            if paired.shape[0] < 2:
                continue
            t_stat, p_value = ttest_rel(paired[cls_a], paired[cls_b])
            records.append(
                {
                    "metric": metric,
                    "class_a": CLASS_NAMES.get(cls_a, str(cls_a)),
                    "class_b": CLASS_NAMES.get(cls_b, str(cls_b)),
                    "n_mice": paired.shape[0],
                    "t_stat": float(t_stat),
                    "p_value": float(p_value),
                    "mean_diff": float((paired[cls_a] - paired[cls_b]).mean()),
                }
            )
    # Explicit columns: pd.DataFrame([]) has no "p_value" column, and sorting it
    # would raise KeyError (e.g. when fewer than two mice were processed).
    return pd.DataFrame(records, columns=columns).sort_values("p_value")


In [None]:
mouse_results, failed_mice = [], []

for mouse_id in data_path:
    try:
        result = process_mouse(mouse_id)
    except Exception as exc:
        # Best effort: log and record the failure, then continue with the other mice.
        print(f"[!] Failed to process {mouse_id}: {exc}")
        failed_mice.append({"mouse_id": mouse_id, "error": repr(exc)})
    else:
        mouse_results.append(result)

print(f"Processed {len(mouse_results)} mice (failed: {len(failed_mice)}).")
failed_mice


In [None]:
fisher_mv_store = defaultdict(list)
fisher_uv_store = defaultdict(list)
network_rows = []
time_axis = None

for result in mouse_results:
    current_axis = result["time_points"]
    if time_axis is None:
        # The first successful mouse defines the shared FI time axis.
        time_axis = current_axis
    elif not np.allclose(time_axis, current_axis):
        raise ValueError("Mismatched FI time axes across mice.")
    # Append each class pair's curve to the matching store (multivariate, then univariate).
    for store, key in ((fisher_mv_store, "fisher_mv"), (fisher_uv_store, "fisher_uv")):
        for pair, curve in result[key].items():
            store[pair].append(np.asarray(curve))
    network_rows.extend(flatten_network_summary(result["mouse_id"], result["nx_result"]))

print(f"Collected {len(fisher_mv_store)} Fisher curve pairs and {len(network_rows)} network entries.")


In [None]:
# Reset both results so a rerun never leaves stale stats from an earlier execution.
fisher_mv_stats, fisher_uv_stats = {}, {}

if time_axis is not None and fisher_mv_store:
    fisher_mv_stats = aggregate_curve_dict(fisher_mv_store)
    plot_average_fisher(fisher_mv_stats, time_axis, "Multivariate Fisher information (mean ± SEM)")
else:
    print("Multivariate Fisher information curves are not available.")

if time_axis is not None and fisher_uv_store:
    fisher_uv_stats = aggregate_curve_dict(fisher_uv_store)
    plot_average_fisher(fisher_uv_stats, time_axis, "Univariate Fisher information (mean ± SEM)")
else:
    print("Univariate Fisher information curves are not available.")


In [None]:
# pd.DataFrame([]) is empty, so downstream `.empty` checks behave as before.
network_df = pd.DataFrame(network_rows)
if network_df.empty:
    print("No network metrics were collected.")
# Jupyter only renders a cell's last top-level expression, not one nested in an `if`.
network_df


In [None]:
if network_df.empty:
    network_summary_df = pd.DataFrame()
    print("Summary table is empty because network_df is empty.")
else:
    network_summary_df = summarize_network_metrics(network_df, NETWORK_METRICS)
# Last top-level expression, so Jupyter actually renders the table.
network_summary_df


In [None]:
if network_df.empty:
    print("Skip plotting network comparisons because network_df is empty.")
else:
    # A representative subset of NETWORK_METRICS, one bar chart each.
    for metric in ("avg_clustering", "global_efficiency", "modularity"):
        plot_network_metric_bars(network_df, metric)


In [None]:
if network_df.empty:
    network_ttest_df = pd.DataFrame()
    print("No paired t-tests computed because network_df is empty.")
else:
    network_ttest_df = paired_ttests(network_df, NETWORK_METRICS)
# Last top-level expression, so Jupyter actually renders the results table.
network_ttest_df
