In [None]:

# --- ONE-CELL DASHBOARD (AppLayout pane size fix) ---
from pathlib import Path
import numpy as np, pandas as pd
import ipywidgets as W
from ipywidgets import AppLayout, Layout
from IPython.display import display
import matplotlib.pyplot as plt

# ---------------- Paths & Safe Loading ----------------
# Absolute form of the current working directory; the data files live in ./data.
HERE = Path().resolve()
DATA = HERE.joinpath("data")  # expected contents: meta.csv, tuning_curves.npz, psth.npz
PKG = f"{DATA}"  # legacy string alias kept for older code that referenced PKG

def _demo():
    """Build a reproducible synthetic dataset of 30 units.

    Used as the fallback when the real data files cannot be loaded.

    Returns
    -------
    meta : pandas.DataFrame
        30 rows with columns ``unit_id`` (1..30), ``layer`` (SG/G/IG),
        ``OSI`` in [0.1, 0.95] and ``selectivity`` in [0.4, 1.1].
    angles : numpy.ndarray
        Shape (36,): orientations 0..350 deg in 10-deg steps.
    tun_curves : numpy.ndarray
        Shape (30, 36): non-negative tuning curves. Each is a Gaussian bump
        (sigma 30 deg) around a random preferred angle, wrapped circularly.
    time : numpy.ndarray
        Shape (151,): time axis from -0.5 to 1.0 s.
    psth_curves : numpy.ndarray
        Shape (30, 151): non-negative PSTHs, each with a Gaussian bump
        centred at a random time in [0.2, 0.7] s.
    """
    n_units = 30
    angles = np.linspace(0, 350, 36)
    time = np.linspace(-0.5, 1.0, 151)
    rng = np.random.default_rng(0)  # fixed seed: the demo is identical on every run
    meta = pd.DataFrame({
        "unit_id": np.arange(1, n_units + 1),
        "layer": rng.choice(["SG", "G", "IG"], size=n_units),
        "OSI": np.round(rng.uniform(0.1, 0.95, size=n_units), 3),
        "selectivity": np.round(rng.uniform(0.4, 1.1, size=n_units), 3),
    })

    tuning = []
    for _ in range(n_units):
        peak = rng.choice(angles)
        # Signed circular distance to the peak, wrapped into [-180, 180).
        # The bump then falls off symmetrically on both sides and wraps
        # across 0/360 deg; a plain (x - peak) % 360 would put angles just
        # below the peak ~350 deg away and give a one-sided curve.
        dist = (angles - peak + 180) % 360 - 180
        curve = 5 + 15 * np.exp(-0.5 * (dist / 30) ** 2) + rng.normal(0, 0.5, size=len(angles))
        tuning.append(np.clip(curve, 0, None))
    tun_curves = np.stack(tuning, axis=0)

    psths = []
    for _ in range(n_units):
        mu = rng.uniform(0.2, 0.7)  # centre of the response bump on the time axis
        curve = 6 + 20 * np.exp(-0.5 * ((time - mu) / 0.12) ** 2) + rng.normal(0, 0.8, size=len(time))
        psths.append(np.clip(curve, 0, None))
    psth_curves = np.stack(psths, axis=0)
    return meta, angles, tun_curves, time, psth_curves

def _load():
    """Load metadata and curves from ``DATA``, falling back to ``_demo()``.

    Expected files under ``DATA``:
      * ``meta.csv``: one row per unit. Missing ``layer``, ``OSI``,
        ``selectivity`` or ``unit_id`` columns are filled with defaults.
      * ``tuning_curves.npz``: arrays ``angles`` (1-D) and ``curves``
        (2-D, one row per unit, ``len(angles)`` columns).
      * ``psth.npz``: arrays ``t`` (1-D) and ``curves`` (2-D, one row per
        unit, ``len(t)`` columns).

    ``OSI`` is coerced to numeric (unparseable -> 0.5) and clipped to [0, 1].
    ``selectivity`` is coerced likewise (-> 0.8) and clipped to [0, 1.2].

    Returns
    -------
    tuple
        ``(meta, angles, tun_curves, time, psth_curves)``. If any of the
        following happens, a warning explaining why is issued and
        ``_demo()``'s synthetic data is returned instead:
          * a file is missing or unreadable;
          * an array is absent or empty;
          * a curve array's second axis does not match its axis vector.
    """
    import warnings

    try:
        meta = pd.read_csv(DATA / "meta.csv")
        # An NpzFile keeps its file handle open until closed. Read the arrays
        # inside `with` so the handle is released; `.get` returns in-memory
        # arrays that remain valid after the file is closed.
        with np.load(DATA / "tuning_curves.npz") as tun:
            angles = tun.get("angles", None)
            tun_curves = tun.get("curves", None)
        with np.load(DATA / "psth.npz") as psth:
            time = psth.get("t", None)
            psth_curves = psth.get("curves", None)

        if angles is None or tun_curves is None or time is None or psth_curves is None:
            raise ValueError("Missing arrays in NPZ.")
        # render() plots curves[i] against the axis vector, so the shapes must
        # agree here rather than failing later inside a widget callback.
        if angles.ndim != 1 or tun_curves.ndim != 2 or tun_curves.shape[1] != angles.shape[0]:
            raise ValueError(
                f"Tuning curves {tun_curves.shape} do not match angles {angles.shape}.")
        if time.ndim != 1 or psth_curves.ndim != 2 or psth_curves.shape[1] != time.shape[0]:
            raise ValueError(
                f"PSTH curves {psth_curves.shape} do not match time axis {time.shape}.")
        if len(meta) == 0 or tun_curves.shape[0] == 0 or psth_curves.shape[0] == 0:
            raise ValueError("Empty arrays.")

        # Fill absent metadata columns; unit_id defaults to 1..N.
        for col, default in [("layer", "G"), ("OSI", 0.5), ("selectivity", 0.8), ("unit_id", None)]:
            if col not in meta.columns:
                meta[col] = default if default is not None else np.arange(1, len(meta) + 1)
        meta["OSI"] = pd.to_numeric(meta["OSI"], errors="coerce").fillna(0.5).clip(0, 1)
        meta["selectivity"] = pd.to_numeric(meta["selectivity"], errors="coerce").fillna(0.8).clip(0, 1.2)
        return meta, angles, tun_curves, time, psth_curves
    except Exception as exc:
        # Deliberately best-effort: any failure switches to the demo data,
        # but the user is told why instead of the swap being silent.
        warnings.warn(f"Could not load data from {DATA} ({exc!r}); using demo data instead.",
                      stacklevel=2)
        return _demo()

meta, angles, tun_curves, time, psth_curves = _load()

# Align sizes if mismatch: cut meta and both curve arrays to the shortest of
# the three, so row i of meta always describes curve i in each array.
# Note: n_units keeps the row count of meta *before* any truncation.
n_units = len(meta)
if not (tun_curves.shape[0] == n_units == psth_curves.shape[0]):
    m = min(n_units, tun_curves.shape[0], psth_curves.shape[0])
    meta = meta.iloc[:m].reset_index(drop=True)
    tun_curves, psth_curves = tun_curves[:m], psth_curves[:m]

# ---------------- Widgets & Layout ----------------
# Layer options are the distinct non-null layer labels in `meta`, as sorted
# strings; if there are none, offer the three default layer names.
layers = sorted(pd.Series(meta["layer"]).dropna().astype(str).unique().tolist()) or ["SG", "G", "IG"]

# Observed extremes of each metric seed the range sliders. A non-finite
# result (nanmin/nanmax of an all-NaN column) is replaced by the slider bound.
osi_min = float(np.nanmin(meta["OSI"]))
osi_max = float(np.nanmax(meta["OSI"]))
sel_min = float(np.nanmin(meta["selectivity"]))
sel_max = float(np.nanmax(meta["selectivity"]))
osi_min = osi_min if np.isfinite(osi_min) else 0.0
osi_max = osi_max if np.isfinite(osi_max) else 1.0
sel_min = sel_min if np.isfinite(sel_min) else 0.0
sel_max = sel_max if np.isfinite(sel_max) else 1.2

# --- Controls (each widget gets its own Layout instance) ---
layer_w = W.SelectMultiple(
    options=layers,
    value=tuple(layers),  # every layer selected initially
    description="Layer",
    layout=Layout(width="260px", height="120px"),
)
# NOTE(review): the initial slider values are the raw data extremes, which are
# generally not multiples of `step`; confirm the frontend does not snap them
# inward (that would drop the extreme units once a handle is moved).
osi_w = W.FloatRangeSlider(
    description="OSI range",
    min=0,
    max=1,
    step=0.01,
    value=[max(0, osi_min), min(1, osi_max)],
    readout_format=".2f",
    layout=Layout(width="260px"),
)
sel_w = W.FloatRangeSlider(
    description="Sel range",
    min=0,
    max=1.2,
    step=0.01,
    value=[max(0, sel_min), min(1.2, sel_max)],
    readout_format=".2f",
    layout=Layout(width="260px"),
)
plot_w = W.ToggleButtons(
    options=["Tuning", "PSTH"],
    value="Tuning",
    description="Plot",
    layout=Layout(width="260px"),
)

# Position within the filtered unit list; render() clamps it to the match count.
max_idx = len(meta) - 1 if len(meta) > 0 else 0
idx_w = W.IntSlider(
    description="Unit index",
    min=0,
    max=max_idx,
    step=1,
    value=0,
    continuous_update=False,  # report the value on release, not while dragging
    layout=Layout(width="260px"),
)

# Fixed-width sidebar: flex "0 0 280px" means no growing and no shrinking.
controls = W.VBox(
    [layer_w, osi_w, sel_w, plot_w, idx_w],
    layout=Layout(width="280px", min_width="280px", flex="0 0 280px"),
)

plot_out = W.Output(
    layout=Layout(width="100%", min_height="640px", overflow="auto",
                  border="1px solid #eee", padding="6px"),
)
status_out = W.HTML(layout=Layout(margin="6px 0 0 0"))

# IMPORTANT: AppLayout pane sizes must be px/fr/% strings (no 'auto'); '1fr'
# lets the center row and column take the remaining space.
app = AppLayout(
    header=None,
    left_sidebar=controls,
    center=plot_out,
    right_sidebar=None,
    footer=status_out,
    pane_widths=["300px", "1fr", "0px"],   # left, center, right
    pane_heights=["0px", "1fr", "40px"],   # header, center, footer
)
display(app)

# ---------------- Rendering ----------------
def filtered_indices():
    """Select the units of ``meta`` that pass the current widget filters.

    Filters are inclusive ranges on ``OSI`` and ``selectivity`` plus
    membership of ``layer`` (compared as strings) in the selected layers.
    An empty layer selection applies no layer filter at all. NaN metric
    values never match a range.

    Returns
    -------
    tuple
        ``(indices, frame)``: the matching rows of ``meta`` as a new
        DataFrame, and that frame's index as a NumPy array.
    """
    osi_lo, osi_hi = osi_w.value
    sel_lo, sel_hi = sel_w.value
    keep = meta["OSI"].between(osi_lo, osi_hi) & meta["selectivity"].between(sel_lo, sel_hi)
    if layer_w.value:
        keep &= meta["layer"].astype(str).isin(list(layer_w.value))
    subset = meta[keep]
    return subset.index.to_numpy(), subset

def newfig():
    """Create and return a fresh figure for the next plot.

    Closes *every* open pyplot figure first (not only the previous dashboard
    plot). Then creates an 8.5 x 5.0 inch, 110-dpi figure whose current axes
    show a faint grid.
    """
    plt.close("all")
    figure = plt.figure(figsize=(8.5, 5.0), dpi=110)
    figure.gca().grid(True, alpha=0.25)
    return figure

# Re-entrancy guard for render(): True while render() itself adjusts idx_w.max.
_render_state = {"syncing": False}


def _unit_label(value):
    """Format a unit id for the plot title without ever raising.

    Values that parse as integral numbers (ints, integral floats, numeric
    strings) are shown without a decimal part. Anything else (text ids, NaN,
    non-integral numbers) is shown verbatim.
    """
    try:
        as_float = float(value)
    except (TypeError, ValueError):
        return str(value)
    return str(int(as_float)) if as_float.is_integer() else str(value)


def render(*_):
    """Redraw the plot and the status line for the current widget state.

    Registered as the ``observe`` callback of every control, so it accepts
    and ignores the change payload.

    The unit-index slider's ``max`` is kept equal to (number of matching
    units - 1), so every slider position maps to a distinct unit. Before
    this, positions beyond the filtered count all showed the last match.
    """
    if _render_state["syncing"]:
        # Nested call fired by our own idx_w.max update below; the outer
        # call finishes the redraw with the clamped value.
        return

    idxs, df = filtered_indices()
    n = len(df)

    # Lowering `max` below the current value makes ipywidgets clamp `value`,
    # which fires the value observer (this function) synchronously.
    new_max = max(n - 1, 0)
    if idx_w.max != new_max:
        _render_state["syncing"] = True
        try:
            idx_w.max = new_max
        finally:
            _render_state["syncing"] = False

    with plot_out:
        plot_out.clear_output(wait=True)
        if n == 0:
            display(W.HTML("<b>No units match the filters.</b>"))
            status_out.value = "No matching units."
            return
        pos = int(np.clip(idx_w.value, 0, n - 1))
        # Index labels are used as row positions; this assumes meta has a
        # default RangeIndex (true for read_csv / _demo / reset_index).
        unit_global_idx = int(idxs[pos])
        row = meta.iloc[unit_global_idx]
        header = (
            f"unit {_unit_label(row.get('unit_id', unit_global_idx + 1))}"
            f" | layer={row['layer']} | OSI={float(row['OSI']):.2f}"
        )
        newfig()
        if plot_w.value == "Tuning":
            plt.plot(angles, tun_curves[unit_global_idx], marker="o")
            plt.title(f"Tuning — {header}")
            plt.xlabel("Orientation (deg)"); plt.ylabel("Response (a.u.)")
        else:
            plt.plot(time, psth_curves[unit_global_idx])
            plt.title(f"PSTH — {header}")
            plt.xlabel("Time (s)"); plt.ylabel("Rate (Hz)")
        plt.tight_layout(); plt.show()
    status_out.value = f"<span>{n} units match filters. Showing {pos+1}/{n} (global idx {unit_global_idx}).</span>"

# Redraw whenever any control's value changes (render ignores the change
# payload), then draw the initial view once.
for widget in (layer_w, osi_w, sel_w, plot_w, idx_w):
    widget.observe(render, names="value")
render()
