In [None]:
# --- Setup & data loading (works in JupyterLite/Voici) ---
from pathlib import Path
import numpy as np, pandas as pd
import ipywidgets as W
import matplotlib.pyplot as plt

# Expected layout (assumption, not checked here):
# - the notebook runs from its own folder, e.g. /content
# - the input files live in the data/ sub-folder next to it
HERE = Path().resolve()
DATA = HERE.joinpath("data")

# String form of the data folder, kept for older code that still refers to PKG
PKG = str(DATA)

def must_exist(p: Path):
    """Return ``p`` unchanged when it exists on disk.

    Args:
        p: Path to check.

    Returns:
        The same ``Path`` object, so the call can be nested inside a loader.

    Raises:
        FileNotFoundError: if ``p`` does not exist.
    """
    if p.exists():
        return p
    raise FileNotFoundError(f"Required file not found: {p}")

# Try to load real data; if anything goes wrong, fall back to synthetic demo data.
try:
    meta = pd.read_csv(must_exist(DATA / "meta.csv"))
    tun  = np.load(must_exist(DATA / "tuning_curves.npz"))
    psth = np.load(must_exist(DATA / "psth.npz"))
except Exception as e:
    # Best-effort fallback (should not happen when files are present), but never
    # silent: synthetic data must not be mistaken for real recordings.
    print(f"[demo data] Could not load real data ({type(e).__name__}: {e}); "
          "using synthetic demo data instead.")

    n_units = 30
    angles = np.linspace(0, 350, 36)
    t = np.linspace(-0.5, 1.0, 151)

    # Fixed seed so the demo data are identical on every run
    rng = np.random.default_rng(0)
    meta = pd.DataFrame({
        "unit_id": np.arange(1, n_units + 1),
        "layer": rng.choice(["SG", "G", "IG"], size=n_units),
        "OSI": np.round(rng.uniform(0.1, 0.95, size=n_units), 3),
        "selectivity": np.round(rng.uniform(0.4, 1.1, size=n_units), 3),
    })

    # Tuning curves: baseline + Gaussian bump around a random preferred angle + noise
    tc = []
    for i in range(n_units):
        peak = rng.choice(angles)
        # Signed circular distance to the peak, wrapped into [-180, 180), so the
        # bump is symmetric and wraps correctly across the 350 -> 0 boundary.
        delta = (angles - peak + 180) % 360 - 180
        curve = 5 + 15*np.exp(-0.5*(delta/30)**2) + rng.normal(0, 0.5, size=len(angles))
        tc.append(np.clip(curve, 0, None))  # responses cannot be negative
    tc = np.stack(tc, axis=0)

    # PSTHs: baseline + Gaussian bump at a random latency + noise
    psth_curves = []
    for i in range(n_units):
        mu = rng.uniform(0.2, 0.7)
        curve = 6 + 20*np.exp(-0.5*((t-mu)/0.12)**2) + rng.normal(0, 0.8, size=len(t))
        psth_curves.append(np.clip(curve, 0, None))
    psth_curves = np.stack(psth_curves, axis=0)

    tun = {"angles": angles, "curves": tc}
    psth = {"t": t, "curves": psth_curves}

# Normalize structures whether loaded from npz or dict
angles = tun["angles"]
tun_curves = tun["curves"]
time = psth["t"]  # NOTE: this name would shadow the stdlib `time` module if it were imported
psth_curves = psth["curves"]

# Rows of `meta` are paired with curve rows by position when plotting, so the
# row counts have to agree; report a mismatch instead of plotting the wrong unit.
if not (len(meta) == len(tun_curves) == len(psth_curves)):
    print(f"[warning] Row-count mismatch: meta={len(meta)}, "
          f"tuning curves={len(tun_curves)}, PSTH curves={len(psth_curves)}.")

# Utility: nice figure
def newfig():
    """Open a fresh 7x4 inch, 110 dpi pyplot figure with a faint grid.

    The new figure is created through ``plt.figure``, so later ``plt.*`` calls
    draw on it. Nothing is returned.
    """
    fig = plt.figure(figsize=(7, 4), dpi=110)
    axes = fig.gca()
    axes.grid(True, alpha=0.25)


In [None]:
# --- Widgets ---
layers = sorted(meta["layer"].dropna().unique().tolist())
osi_min, osi_max = float(meta["OSI"].min()), float(meta["OSI"].max())
sel_min, sel_max = float(meta["selectivity"].min()), float(meta["selectivity"].max())

# Step shared by both range sliders
SLIDER_STEP = 0.01

def _range_slider_limits(data_lo, data_hi, default_lo, default_hi, step=SLIDER_STEP):
    """Compute limits and initial selection for a range slider.

    Args:
        data_lo, data_hi: smallest / largest value present in the data
            (NaN when the column is empty or all-missing).
        default_lo, default_hi: limits the slider should span at minimum.
        step: slider step; the initial selection is aligned to it.

    Returns:
        ``(slider_min, slider_max, value_lo, value_hi)``. The selection is the
        data range rounded *outward* to the step, so no unit is excluded by the
        initial setting. The slider limits are widened beyond the defaults when
        the data exceed them, so every unit stays reachable and the selection
        is never inverted.
    """
    if not (np.isfinite(data_lo) and np.isfinite(data_hi)):
        # No usable data: fall back to the full default range
        return default_lo, default_hi, default_lo, default_hi
    value_lo = round(float(np.floor(data_lo / step)) * step, 10)
    value_hi = round(float(np.ceil(data_hi / step)) * step, 10)
    return min(default_lo, value_lo), max(default_hi, value_hi), value_lo, value_hi

_osi_lo, _osi_hi, _osi_v0, _osi_v1 = _range_slider_limits(osi_min, osi_max, 0, 1)
_sel_lo, _sel_hi, _sel_v0, _sel_v1 = _range_slider_limits(sel_min, sel_max, 0, 1.2)

layer_w = W.SelectMultiple(options=layers, value=tuple(layers),
                           description="Layer", layout=W.Layout(width="220px", height="120px"))
osi_w   = W.FloatRangeSlider(description="OSI range", min=_osi_lo, max=_osi_hi, step=SLIDER_STEP,
                             value=[_osi_v0, _osi_v1], readout_format=".2f")
sel_w   = W.FloatRangeSlider(description="Sel range", min=_sel_lo, max=_sel_hi, step=SLIDER_STEP,
                             value=[_sel_v0, _sel_v1], readout_format=".2f")
plot_w  = W.ToggleButtons(options=["Tuning", "PSTH"], value="Tuning", description="Plot")
# max is kept >= min so that an empty `meta` table does not make the slider invalid
idx_w   = W.IntSlider(description="Unit index", min=0, max=max(len(meta) - 1, 0), step=1, value=0,
                      continuous_update=False)

# NOTE: info_out is created here but is not part of the `ui` layout below.
info_out = W.HTML()
plot_out = W.Output()

# Controls stacked on the left, plot output on the right
controls = W.VBox([layer_w, osi_w, sel_w, plot_w, idx_w])
ui = W.HBox([controls, plot_out], layout=W.Layout(align_items="flex-start"))
display(ui)


In [None]:
# --- Render logic ---
def filtered_indices():
    """Select the units that pass the current layer / OSI / selectivity filters.

    Reads the module-level ``meta`` table and the ``layer_w``, ``osi_w`` and
    ``sel_w`` widgets. The layer filter is skipped when no layer is selected;
    both range filters are inclusive at each end. Rows whose filtered value is
    missing never match.

    Returns:
        ``(positions, rows)``: ``positions`` is a NumPy integer array holding
        the 0-based row positions of the matching units within ``meta``;
        ``rows`` is the matching subset of ``meta`` in the same order.
    """
    keep = np.ones(len(meta), dtype=bool)
    if layer_w.value:
        keep &= meta["layer"].isin(list(layer_w.value)).to_numpy(dtype=bool, na_value=False)
    for column, (lo, hi) in (("OSI", osi_w.value), ("selectivity", sel_w.value)):
        values = meta[column]
        keep &= ((values >= lo) & (values <= hi)).to_numpy(dtype=bool, na_value=False)
    # Return row positions, not index labels: callers index `meta.iloc[...]` and
    # the curve arrays with them, which is only right for positional values.
    # For the default RangeIndex the two are identical.
    return np.flatnonzero(keep), meta.loc[keep]

def render(*_):
    """Redraw the output area for the current widget state.

    Used as an observer callback, so any positional arguments (the change
    notification) are accepted and ignored. Everything is drawn inside
    ``plot_out``: either a "no match" message, or the selected unit's tuning
    curve / PSTH followed by a one-line summary of the filter result.
    """
    matches, subset = filtered_indices()
    n_match = len(subset)
    with plot_out:
        plot_out.clear_output(wait=True)
        if n_match == 0:
            display(W.HTML("<b>No units match the filters.</b>"))
            return

        # The slider spans every unit, so clamp it to the last filtered entry
        pos = idx_w.value if idx_w.value < n_match else n_match - 1
        unit_global_idx = matches[pos]
        unit_row = meta.iloc[unit_global_idx]

        # Anything other than "Tuning" is drawn as a PSTH
        show_tuning = plot_w.value == "Tuning"
        newfig()
        if show_tuning:
            plt.plot(angles, tun_curves[unit_global_idx], marker="o")
        else:
            plt.plot(time, psth_curves[unit_global_idx])
        kind = "Tuning" if show_tuning else "PSTH"
        plt.title(f"{kind} — unit {int(unit_row['unit_id'])} | layer={unit_row['layer']} | OSI={unit_row['OSI']:.2f}")
        x_label, y_label = (
            ("Orientation (deg)", "Response (a.u.)") if show_tuning else ("Time (s)", "Rate (Hz)")
        )
        plt.xlabel(x_label)
        plt.ylabel(y_label)
        plt.tight_layout()
        plt.show()

        # Summary line shown under the figure
        display(W.HTML(
            f"<b>{n_match}</b> units match filters. Showing {pos+1}/{n_match} (global index {unit_global_idx})."
        ))

# Wire events: a change to any control's value triggers a redraw
watched_widgets = [layer_w, osi_w, sel_w, plot_w, idx_w]
for widget in watched_widgets:
    widget.observe(render, names="value")

# Draw once up front so the output area is populated before any interaction
render()
