In [None]:

from pathlib import Path
import numpy as np, pandas as pd
import ipywidgets as W
from ipywidgets import Layout
import matplotlib
matplotlib.use("agg")
import matplotlib.pyplot as plt
from io import BytesIO

# Absolute path of the current working directory at the time this cell runs.
HERE = Path().resolve()
# The input files (meta.csv, tuning_curves.npz, psth.npz) are read from this subdirectory.
DATA = HERE.joinpath("data")

def _load():
    """Read the unit metadata table and the tuning/PSTH arrays from ``DATA``.

    Returns:
        A 5-tuple ``(meta, angles, tuning, t, psth)`` where ``meta`` is the
        DataFrame parsed from ``meta.csv``, ``angles``/``tuning`` are the
        ``"angles"``/``"curves"`` arrays of ``tuning_curves.npz`` and
        ``t``/``psth`` are the ``"t"``/``"curves"`` arrays of ``psth.npz``.
    """
    meta = pd.read_csv(DATA / "meta.csv")
    # np.load on an .npz returns a lazy NpzFile that keeps the archive open.
    # Indexing it reads the array into memory, so pull out what we need and
    # let the context manager close the file handle.
    with np.load(DATA / "tuning_curves.npz") as tuning_npz:
        angles = tuning_npz["angles"]
        tuning = tuning_npz["curves"]
    with np.load(DATA / "psth.npz") as psth_npz:
        t = psth_npz["t"]
        psth = psth_npz["curves"]
    return meta, angles, tuning, t, psth

# Load everything once; the widgets and callbacks below read these globals.
meta, angles, tuning, t, psth = _load()

# Distinct layer labels in sorted order; all of them are selected initially.
layers = meta["layer"].unique().tolist()
layers.sort()
w_layers = W.SelectMultiple(
    options=layers,
    value=tuple(layers),
    description="Layers",
    layout=Layout(width="200px", height="100px"),
)

# Range sliders start at the observed min/max of each metric; the slider
# limits themselves are fixed to 0..1.
w_osi = W.FloatRangeSlider(
    value=[float(meta["OSI"].min()), float(meta["OSI"].max())],
    min=0,
    max=1,
    step=0.01,
    description="OSI",
    readout_format=".2f",
    layout=Layout(width="300px"),
)
w_sel = W.FloatRangeSlider(
    value=[float(meta["selectivity"].min()), float(meta["selectivity"].max())],
    min=0,
    max=1,
    step=0.01,
    description="Select.",
    readout_format=".2f",
    layout=Layout(width="300px"),
)

# Upper bound on how many matching units are drawn.
w_nmax = W.IntSlider(
    value=200,
    min=1,
    max=1000,
    step=1,
    description="Max units",
    layout=Layout(width="300px"),
)

# Plot-type toggles and the reset button.
w_show_tuning = W.Checkbox(value=True, description="Plot Tuning")
w_show_psth = W.Checkbox(value=False, description="Plot PSTH")
w_reset = W.Button(description="Reset", button_style="warning")

# Output area: the rendered PNG plus a one-line HTML summary of the selection.
img = W.Image(
    format="png",
    layout=Layout(width="100%", height="auto", border="1px solid #eee"),
)
status = W.HTML()

def _plot():
    """Render the currently selected units into ``img`` and update ``status``.

    Units are kept when their layer is among the selected layers and their
    OSI and selectivity fall inside the slider ranges (inclusive); at most
    ``w_nmax.value`` of them are drawn. If both checkboxes are ticked the
    tuning plot takes precedence; if neither is ticked a hint is shown.
    """
    osi_lo, osi_hi = w_osi.value
    sel_lo, sel_hi = w_sel.value

    mask = (
        meta["layer"].isin(w_layers.value)
        & meta["OSI"].between(osi_lo, osi_hi)
        & meta["selectivity"].between(sel_lo, sel_hi)
    )
    # Positional row indices of the matching units, truncated to the cap.
    idx = np.flatnonzero(mask.to_numpy())[: w_nmax.value]

    fig, ax = plt.subplots(figsize=(7, 4), dpi=150)
    try:
        if w_show_tuning.value:
            _draw_curves(ax, angles, tuning, idx, "Orientation (deg)", "Tuning curves")
        elif w_show_psth.value:
            _draw_curves(ax, t, psth, idx, "Time (s)", "PSTH")
        else:
            ax.text(0.5, 0.5, "Enable Plot Tuning or Plot PSTH",
                    ha="center", va="center", transform=ax.transAxes)
            ax.axis("off")

        fig.tight_layout()
        buf = BytesIO()
        fig.savefig(buf, format="png")
    finally:
        # Close even when drawing/saving fails, otherwise pyplot keeps the
        # figure registered and repeated failing callbacks accumulate figures.
        plt.close(fig)

    img.value = buf.getvalue()
    # Layer labels may be non-string (e.g. integers), so stringify before joining.
    layer_names = ", ".join(map(str, w_layers.value))
    status.value = (
        f"<b>Selected:</b> {len(idx)} / {len(meta)} &nbsp;"
        f"<b>Layers:</b> {layer_names} &nbsp;"
        f"<b>OSI:</b> {osi_lo:.2f}–{osi_hi:.2f} &nbsp;"
        f"<b>Select:</b> {sel_lo:.2f}–{sel_hi:.2f}"
    )


def _draw_curves(ax, x, curves, idx, xlabel, title):
    """Overlay ``curves[i]`` against ``x`` on ``ax`` for every ``i`` in ``idx``."""
    for i in idx:
        ax.plot(x, curves[i], alpha=0.25, lw=1)
    ax.set_xlabel(xlabel)
    ax.set_ylabel("Firing rate (a.u.)")
    ax.set_title(f"{title} (n={len(idx)})")
    ax.set_xlim(x.min(), x.max())

def _on_change(change):
    """Observer callback: redraw whenever a watched widget's value changes.

    The ``change`` payload is not inspected; the redraw reads the widgets'
    current values itself.
    """
    _plot()

def _on_reset(_):
    """Button callback: restore every control to its default and redraw once.

    Each assignment below would normally fire the ``_on_change`` observer and
    trigger a full re-render of an intermediate, half-reset state. The
    observers are detached for the duration of the reset and re-attached
    afterwards, so the plot is rendered exactly once, by the final ``_plot()``.
    """
    watched = (w_layers, w_osi, w_sel, w_nmax, w_show_tuning, w_show_psth)
    for widget in watched:
        widget.unobserve(_on_change, names="value")
    try:
        w_layers.value = tuple(layers)
        w_osi.value = [float(meta.OSI.min()), float(meta.OSI.max())]
        w_sel.value = [float(meta.selectivity.min()), float(meta.selectivity.max())]
        w_nmax.value = min(200, len(meta))
        w_show_tuning.value = True
        w_show_psth.value = False
    finally:
        # Always restore the observers, even if one of the assignments fails.
        for widget in watched:
            widget.observe(_on_change, names="value")
    _plot()

# Any change to a filter or display control triggers a redraw.
for w in (w_layers, w_osi, w_sel, w_nmax, w_show_tuning, w_show_psth):
    w.observe(_on_change, names="value")
w_reset.on_click(_on_reset)

# Layout: layer list beside the three sliders, then the toggles and reset
# button in a row; the image and status line go underneath the controls.
controls = W.VBox([
    W.HBox([w_layers, W.VBox([w_osi, w_sel, w_nmax])]),
    W.HBox([w_show_tuning, w_show_psth, w_reset]),
])
ui = W.VBox([controls, img, status])
display(ui)

# Initial render with the default widget values.
_plot()
