# Orientation Tuning Dashboard
Interactive filters + **overlay of all units that pass the filter**.

- Use the controls to select layers, OSI range, and selectivity range.
- When the overlay is enabled, the tuning plot draws every passing unit as a thin line; the mean (±SEM) is drawn if enabled and at least two units pass the filter.
- The selected focus unit is always drawn as a thicker line, even when it does not pass the filter.


In [None]:
import numpy as np
import pandas as pd
import ipywidgets as W
import matplotlib.pyplot as plt

# --- Load unit metadata and tuning curves ---
# Both paths are relative, so they resolve against the kernel's working directory.
meta = pd.read_csv("data/meta.csv")
npz = np.load("data/tuning_curves.npz")
angles, curves = npz["angles"], npz["curves"]  # expected shapes: (A,) and (N, A)
N, A = curves.shape  # N units (rows) by A sampled angles (columns)
# One metadata row per tuning curve; row i of meta is paired with curves[i].
assert N == len(meta), "meta rows must match number of units in curves"

# --- Filter and display controls ---

# Layer multi-select; every layer found in meta starts out selected.
layers_widget = W.SelectMultiple(
    options=sorted(meta["layer"].unique().tolist()),
    value=tuple(sorted(meta["layer"].unique().tolist())),
    description="Layers",
    rows=3,
    layout=W.Layout(width="200px"),
)

# Range slider over the OSI column, initially spanning the full 0..1 range.
osi_widget = W.FloatRangeSlider(
    value=(0.0, 1.0),
    min=0.0,
    max=1.0,
    step=0.01,
    description="OSI",
    readout_format=".2f",
    layout=W.Layout(width="380px"),
)

# Range slider over the selectivity column, same bounds and step as the OSI slider.
sel_widget = W.FloatRangeSlider(
    value=(0.0, 1.0),
    min=0.0,
    max=1.0,
    step=0.01,
    description="Select.",
    readout_format=".2f",
    layout=W.Layout(width="380px"),
)

# Plot toggles and the per-unit normalization mode.
overlay_widget = W.Checkbox(value=True, description="Overlay all filtered units")
show_mean_widget = W.Checkbox(value=True, description="Show mean ± SEM")
normalize_widget = W.Dropdown(
    options=[
        ("Raw", "raw"),
        ("Z-score per unit", "zscore"),
        ("Peak normalize (max=1)", "peak"),
    ],
    value="raw",
    description="Normalize",
)

# Focus-unit picker: one entry per unit_id in meta, defaulting to the first row.
focus_widget = W.Dropdown(
    options=[(f"Unit {uid}", uid) for uid in meta["unit_id"].tolist()],
    value=meta["unit_id"].iloc[0],
    description="Focus unit",
)

# HTML readout that the plotting callback fills with the match summary.
status = W.HTML()


In [None]:
def _normalize(y, mode):
    """Return one tuning curve as a float array, rescaled according to *mode*.

    "zscore" subtracts the NaN-ignoring mean and divides by the NaN-ignoring
    standard deviation; "peak" divides by the largest absolute value.
    "raw" and any unrecognised mode return the float-converted input as is.
    A 1e-12 offset in each denominator keeps flat or all-zero curves finite.
    """
    values = np.asarray(y, float)

    if mode == "zscore":
        center = np.nanmean(values)
        spread = np.nanstd(values)
        return (values - center) / (spread + 1e-12)

    if mode == "peak":
        peak = np.nanmax(np.abs(values))
        return values / (peak + 1e-12)

    # "raw" and anything unknown fall through unchanged.
    return values

def get_filtered_indices(layers, osi_rng, sel_rng):
    """Return the positional row indices of ``meta`` that pass every filter.

    A row passes when its ``layer`` is one of *layers* and both its ``OSI``
    and ``selectivity`` fall inside the inclusive ``(low, high)`` bounds given
    by *osi_rng* and *sel_rng*. An empty layer selection matches nothing.
    """
    if not layers:
        return np.array([], dtype=int)

    osi_lo, osi_hi = osi_rng[0], osi_rng[1]
    sel_lo, sel_hi = sel_rng[0], sel_rng[1]

    in_layer = meta["layer"].isin(layers)
    in_osi = meta["OSI"].between(osi_lo, osi_hi)  # inclusive at both ends
    in_sel = meta["selectivity"].between(sel_lo, sel_hi)

    keep = in_layer & in_osi & in_sel
    return np.flatnonzero(keep.to_numpy())

def plot_tuning(layers, osi_rng, sel_rng, overlay_all, show_mean, normalize, focus_uid):
    """Draw the orientation-tuning plot for the current control values.

    Parameters
    ----------
    layers : sequence
        Layer labels to keep; an empty selection matches no units.
    osi_rng, sel_rng : (low, high)
        Inclusive bounds applied to the ``OSI`` and ``selectivity`` columns.
    overlay_all : bool
        Draw every matching unit as a thin, translucent line.
    show_mean : bool
        Draw the across-unit mean with a ±SEM band; needs at least two
        matching units.
    normalize : str
        Mode passed to ``_normalize`` for every curve that is drawn.
    focus_uid
        ``unit_id`` of the unit to highlight. It is drawn whenever it exists
        in ``meta``, even if it does not pass the filter.

    Returns
    -------
    None
        The figure is rendered with ``plt.show()`` and ``status.value`` is
        overwritten with a summary of the match.
    """
    idx = get_filtered_indices(layers, osi_rng, sel_rng)

    fig, ax = plt.subplots(figsize=(7.5, 4.5), dpi=140)

    if idx.size == 0:
        ax.text(0.5, 0.5, "No units match the filter", ha="center", va="center", transform=ax.transAxes)
        ax.set_axis_off()
        status.value = "<b>Matched units:</b> 0"
        plt.show()
        return

    want_mean = show_mean and idx.size >= 2

    # Normalize each matching curve once and share the stack between the
    # overlay and the mean, instead of recomputing it for each.
    Y = None
    if overlay_all or want_mean:
        Y = np.vstack([_normalize(curves[ii, :], normalize) for ii in idx])

    # Overlay: a 2-D y is drawn column by column, so one call plots every unit.
    if overlay_all:
        ax.plot(angles, Y.T, lw=0.8, alpha=0.25)

    # Focus unit: locate its row by position, because curves is indexed
    # positionally and meta's index labels need not equal row positions.
    focus_rows = np.flatnonzero(meta["unit_id"].to_numpy() == focus_uid)
    if focus_rows.size:
        y = _normalize(curves[focus_rows[0], :], normalize)
        ax.plot(angles, y, lw=2.2)  # thicker line for focus

    # Mean ± SEM
    if want_mean:
        m = np.nanmean(Y, axis=0)
        # nanstd ignores NaNs, so the SEM denominator must count only the
        # non-NaN samples at each angle (clamped to 1 to avoid dividing by 0).
        n_valid = np.count_nonzero(~np.isnan(Y), axis=0)
        s = np.nanstd(Y, axis=0, ddof=1) / np.sqrt(np.maximum(n_valid, 1))
        ax.plot(angles, m, lw=2.5)
        ax.fill_between(angles, m - s, m + s, alpha=0.15)

    ax.set_xlabel("Orientation (deg)")
    ax.set_ylabel("Response (a.u.)" if normalize == "raw" else f"Response ({normalize})")
    ax.set_title("Orientation tuning (filtered overlay)")
    ax.grid(True, alpha=0.25)
    # Layer labels may be non-strings (e.g. integers), so stringify before joining.
    layer_text = ", ".join(map(str, layers))
    status.value = f"<b>Matched units:</b> {idx.size} | Layers: {layer_text} | OSI: {osi_rng[0]:.2f}–{osi_rng[1]:.2f} | Select.: {sel_rng[0]:.2f}–{sel_rng[1]:.2f}"
    fig.tight_layout()
    plt.show()

# Bind each plot_tuning argument to the widget that supplies its value;
# the rendered figure is captured in the `out` output widget.
out = W.interactive_output(
    plot_tuning,
    dict(
        layers=layers_widget,
        osi_rng=osi_widget,
        sel_rng=sel_widget,
        overlay_all=overlay_widget,
        show_mean=show_mean_widget,
        normalize=normalize_widget,
        focus_uid=focus_widget,
    ),
)

# Two control columns: selection and toggles on the left, ranges and focus on the right.
controls_left = W.VBox(
    [layers_widget, overlay_widget, show_mean_widget, normalize_widget]
)
controls_right = W.VBox([osi_widget, sel_widget, focus_widget])

# The right column is wrapped in a Box so it can carry a 20px left margin.
ui = W.HBox(
    [
        controls_left,
        W.Box([controls_right], layout=W.Layout(margin="0 0 0 20px")),
    ]
)
display(ui, status, out)
