In [None]:

# --- ONE-CELL DASHBOARD (overlay all filtered tuning curves) ---
from pathlib import Path
import io
import numpy as np, pandas as pd
import ipywidgets as W
from ipywidgets import Layout
from IPython.display import display
import matplotlib
matplotlib.use("agg")  # offscreen backend for JupyterLite/Voilà
import matplotlib.pyplot as plt

# ---------------- Paths & Safe Loading ----------------
# Input files are looked up relative to the notebook's working directory.
HERE = Path(".").resolve()
DATA = HERE.joinpath("data")  # expected contents: meta.csv, tuning_curves.npz, psth.npz

def _demo():
    # Fallback demo data (30 units, 36 angles, synthetic PSTHs)
    angles = np.arange(0, 360, 10, dtype=float)
    time   = np.linspace(0, 1.5, 151)
    rng = np.random.default_rng(0)
    meta = pd.DataFrame({
        "unit_id": np.arange(1, 31),
        "layer": rng.choice(["SG","G","IG"], size=30),
        "OSI": np.round(rng.uniform(0.1, 0.95, size=30), 3),
        "selectivity": np.round(rng.uniform(0.4, 1.1, size=30), 3),
    })
    tc = []
    for i in range(30):
        peak = rng.choice(angles)
        curve = 5 + 15*np.exp(-0.5*((angles-peak+360)%360/30)**2) + rng.normal(0,0.5,size=len(angles))
        tc.append(np.clip(curve, 0, None))
    tun_curves = np.stack(tc, axis=0)
    pc = []
    for i in range(30):
        mu = rng.uniform(0.2, 0.7)
        curve = 6 + 20*np.exp(-0.5*((time-mu)/0.12)**2) + rng.normal(0,0.8,size=len(time))
        pc.append(np.clip(curve, 0, None))
    psth_curves = np.stack(pc, axis=0)
    return meta, angles, tun_curves, time, psth_curves

def _load():
    """Load dashboard data from ``DATA``; fall back to ``_demo()`` on any problem.

    Reads ``meta.csv`` (per-unit metadata), ``tuning_curves.npz`` (keys
    ``angles`` and ``curves``) and ``psth.npz`` (keys ``t`` and ``curves``).
    Missing ``unit_id`` / ``layer`` columns are filled in. ``OSI`` and
    ``selectivity`` are coerced to numbers, NaNs are filled, and the values
    are clipped to [0, 1] and [0, 1.2] respectively.

    Returns
    -------
    tuple
        ``(meta, angles, tun_curves, time, psth_curves)``. If loading or
        validation fails, a warning with the reason is emitted and the
        output of ``_demo()`` is returned instead.
    """
    import warnings

    try:
        meta = pd.read_csv(DATA / "meta.csv")
        # NpzFile keeps the archive open; use it as a context manager so the
        # file handle is released once the arrays have been read.
        with np.load(DATA / "tuning_curves.npz") as tun:
            angles = np.asarray(tun["angles"])
            tun_curves = np.asarray(tun["curves"])
        with np.load(DATA / "psth.npz") as psth:
            time = np.asarray(psth["t"])
            psth_curves = np.asarray(psth["curves"])

        # Sanity checks. These raise explicitly; `assert` is stripped under -O.
        if len(meta) == 0:
            raise ValueError("meta.csv has no rows")
        if tun_curves.ndim != 2 or psth_curves.ndim != 2:
            raise ValueError("curve arrays must be 2-D (units x samples)")
        if tun_curves.shape[0] == 0 or psth_curves.shape[0] == 0:
            raise ValueError("curve arrays contain no units")
        if angles.ndim != 1 or angles.shape[0] != tun_curves.shape[1]:
            raise ValueError("length of 'angles' does not match tuning curve width")
        if time.ndim != 1 or time.shape[0] != psth_curves.shape[1]:
            raise ValueError("length of 't' does not match PSTH width")

        # Fill in missing metadata columns.
        if "unit_id" not in meta.columns:
            meta["unit_id"] = np.arange(1, len(meta) + 1)
        if "layer" not in meta.columns:
            meta["layer"] = "G"
        # Numeric columns: coerce, fill NaN, clip to the slider ranges.
        for col, lo, hi, fill in (("OSI", 0, 1, 0.5), ("selectivity", 0, 1.2, 0.8)):
            if col not in meta.columns:
                meta[col] = fill
            meta[col] = pd.to_numeric(meta[col], errors="coerce").fillna(fill).clip(lo, hi)
        return meta.reset_index(drop=True), angles, tun_curves, time, psth_curves
    except Exception as exc:  # deliberate best-effort boundary: any failure -> demo data
        warnings.warn(f"Falling back to demo due to: {exc!r}", stacklevel=2)
        return _demo()

meta, angles, tun_curves, time, psth_curves = _load()

# Truncate every per-unit container to the shortest one, so that row i of
# `meta` always describes row i of both curve matrices.
m = min(meta.shape[0], len(tun_curves), len(psth_curves))
meta = meta.head(m).reset_index(drop=True)
tun_curves, psth_curves = tun_curves[:m], psth_curves[:m]

# ---------------- Widgets ----------------
_RANGE_STEP = 0.01  # resolution shared by the OSI / selectivity range sliders


def _outer_range(lo, hi, bound_lo, bound_hi, step=_RANGE_STEP):
    """Widen ``[lo, hi]`` outward onto the ``step`` grid, clamped to the bounds.

    The slider front-end snaps handle positions to multiples of ``step``.
    Rounding the data extremes down/up, instead of passing them off-grid,
    keeps every unit inside the initial selection even after that snap.
    """
    lo_grid = float(np.floor(lo / step) * step)
    hi_grid = float(np.ceil(hi / step) * step)
    # Floating-point round-off can land one ulp on the wrong side; step out.
    if lo_grid > lo:
        lo_grid -= step
    if hi_grid < hi:
        hi_grid += step
    return [max(bound_lo, lo_grid), min(bound_hi, hi_grid)]


layers = sorted(pd.unique(meta["layer"].astype(str)))
osi_min, osi_max = float(np.nanmin(meta["OSI"])), float(np.nanmax(meta["OSI"]))
sel_min, sel_max = float(np.nanmin(meta["selectivity"])), float(np.nanmax(meta["selectivity"]))

layer_w = W.SelectMultiple(description="Layers", options=layers, value=tuple(layers),
                           layout=Layout(width="260px", height="100px"))
osi_w   = W.FloatRangeSlider(description="OSI range", min=0, max=1, step=_RANGE_STEP,
                             value=_outer_range(osi_min, osi_max, 0, 1), readout_format=".2f",
                             layout=Layout(width="260px"))
sel_w   = W.FloatRangeSlider(description="Sel range", min=0, max=1.2, step=_RANGE_STEP,
                             value=_outer_range(sel_min, sel_max, 0, 1.2), readout_format=".2f",
                             layout=Layout(width="260px"))
plot_w  = W.ToggleButtons(options=["Tuning","PSTH"], value="Tuning", description="Plot",
                          layout=Layout(width="260px"))

# Position within the *filtered* unit list (render() re-bounds it on each redraw).
idx_w   = W.IntSlider(description="Unit index", min=0, max=max(len(meta)-1,0), step=1, value=0,
                      continuous_update=False, layout=Layout(width="260px"))

controls = W.VBox([layer_w, osi_w, sel_w, plot_w, idx_w],
                  layout=Layout(width="280px", min_width="280px"))

# ---- Output pane ----
# Plots are shown as static PNG bytes pushed into `img` on every redraw;
# on JupyterLite this is faster than an interactive matplotlib backend.
_img_layout = Layout(width="100%", height="640px", border="1px solid #eee")
img = W.Image(format='png', layout=_img_layout)
status = W.HTML(layout=Layout(margin="6px 0 0 6px"))
right = W.VBox(children=[img, status], layout=Layout(width="100%"))

# Controls on the left, plot + status on the right, top-aligned.
ui = W.HBox(children=[controls, right],
            layout=Layout(align_items="flex-start", width="100%"))
display(ui)

# ---------------- Helpers ----------------
def draw_png(fig):
    """Render *fig* to PNG bytes and close it.

    The figure is closed even if ``savefig`` raises, so repeated redraws
    never accumulate open figures in pyplot's figure registry.

    Parameters
    ----------
    fig : matplotlib.figure.Figure
        Figure to rasterise (150 dpi, tight bounding box).

    Returns
    -------
    bytes
        Encoded PNG image, as assigned to ``img.value``.
    """
    try:
        with io.BytesIO() as buf:
            fig.savefig(buf, format='png', bbox_inches='tight', dpi=150)
            return buf.getvalue()  # getvalue() ignores the stream position
    finally:
        plt.close(fig)

def filtered_indices():
    """Return the index labels of `meta` rows that pass all current filters.

    The layer filter is skipped when no layer is selected. The OSI and
    selectivity filters are inclusive at both ends, and NaN values never pass.
    """
    keep = np.ones(len(meta), dtype=bool)

    chosen_layers = layer_w.value
    if chosen_layers:
        keep &= meta["layer"].astype(str).isin(list(chosen_layers)).to_numpy()

    osi_lo, osi_hi = osi_w.value
    keep &= ((meta["OSI"] >= osi_lo) & (meta["OSI"] <= osi_hi)).to_numpy()

    sel_lo, sel_hi = sel_w.value
    keep &= ((meta["selectivity"] >= sel_lo) & (meta["selectivity"] <= sel_hi)).to_numpy()

    return meta.index.to_numpy()[keep]

# ---------------- Rendering ----------------
_render_busy = False  # True while render() is running (re-entrancy guard)


def _unit_label(row, global_idx):
    """Return the "unit_id=... (global idx ...) | layer=... | OSI=..." title part.

    ``unit_id`` is shown as an int when it converts cleanly. Non-numeric ids
    (e.g. strings or NaN) are shown as-is instead of aborting the redraw.
    """
    uid = row.get("unit_id", global_idx + 1)
    try:
        uid = int(uid)
    except (TypeError, ValueError):
        pass
    return (f"unit_id={uid} (global idx {global_idx}) | "
            f"layer={row['layer']} | OSI={float(row['OSI']):.2f}")


def _plot_tuning(idxs, unit_global_idx, label):
    """Overlay all filtered tuning curves faintly and highlight the selected one."""
    fig = plt.figure(figsize=(8, 4.8))
    # Context lines: every filtered unit in one call, uniform grey, so the
    # highlighted unit can never share a colour with a background line.
    plt.plot(angles, tun_curves[idxs].T, color="0.5", linewidth=0.8, alpha=0.25)
    plt.plot(angles, tun_curves[unit_global_idx], color="tab:red", linewidth=2.4)
    plt.title(f"Tuning — {len(idxs)} units overlaid | showing {label}")
    plt.xlabel("Orientation (deg)")
    plt.ylabel("Response (a.u.)")
    plt.xlim(float(np.min(angles)), float(np.max(angles)))
    return fig


def _plot_psth(unit_global_idx, label):
    """Plot the PSTH of the selected unit."""
    fig = plt.figure(figsize=(8, 4.8))
    plt.plot(time, psth_curves[unit_global_idx], linewidth=2.0)
    plt.title(f"PSTH — {label}")
    plt.xlabel("Time (s)")
    plt.ylabel("Rate (Hz)")
    return fig


def _render_now():
    """Re-bound the unit slider, draw the active plot and update the status line."""
    idxs = filtered_indices()
    n = len(idxs)
    # Keep the unit slider within the current filtered set.
    idx_w.max = max(n - 1, 0)

    if n == 0:
        fig = plt.figure(figsize=(8, 4))
        plt.text(0.5, 0.5, "No units match filters", ha='center', va='center', fontsize=14)
        plt.axis("off")
        img.value = draw_png(fig)
        status.value = "0 units match filters."
        return

    pos = min(max(idx_w.value, 0), n - 1)
    unit_global_idx = int(idxs[pos])
    row = meta.iloc[unit_global_idx]
    label = _unit_label(row, unit_global_idx)

    if plot_w.value == "Tuning":
        fig = _plot_tuning(idxs, unit_global_idx, label)
    else:
        fig = _plot_psth(unit_global_idx, label)

    fig.tight_layout()
    img.value = draw_png(fig)
    status.value = f"{n} units match filters. Showing {pos+1}/{n} (global idx {unit_global_idx})."


def render(*_):
    """Redraw the plot and status line from the current widget values.

    Serves both as the ``observe`` callback for every control (the change
    payload is ignored) and as the initial draw.
    """
    global _render_busy
    # Lowering idx_w.max below idx_w.value makes ipywidgets clamp the value,
    # which fires this callback again while the outer call is still drawing.
    # Skip that nested call: the outer call reads the clamped value and
    # completes the redraw once.
    if _render_busy:
        return
    _render_busy = True
    try:
        _render_now()
    finally:
        _render_busy = False

# Wire events: any change on any control triggers a full redraw.
_watched_controls = (layer_w, osi_w, sel_w, plot_w, idx_w)
for w in _watched_controls:
    w.observe(render, names="value")

# Draw once so the dashboard is populated before the first interaction.
render()
