In [None]:
from pathlib import Path
import numpy as np
from scipy.ndimage import gaussian_filter1d
import flyvision
from flyvision.utils.activity_utils import LayerActivity
from flygym.examples.vision_connectome_model import (
    RetinaMapper,
    RealTimeVisionNetworkView,
)

retina_mapper = RetinaMapper()
flyvis_to_flygym = retina_mapper.flyvis_to_flygym

# Load the flyvision network stored under results_dir/opticflow/000/0000,
# restoring its "best_chkpt" checkpoint.
vision_network_dir = flyvision.results_dir / "opticflow/000/0000"
vision_network_view = RealTimeVisionNetworkView(vision_network_dir)
vision_network = vision_network_view.init_network(chkpt="best_chkpt")

# Each file is named nn_hist_<speed>.npy; build {speed (°/s): file path},
# ordered by increasing speed.
moving_bars_dir = Path("./outputs/moving_bars_4deg")
paths = dict(
    sorted(
        (float(npy_file.stem.split("_")[-1]), npy_file)
        for npy_file in moving_bars_dir.glob("nn_hist_*.npy")
    )
)

assert paths, "Run response_to_moving_bars.py first to generate the data."


def get_activity(x, neuron_type, eye_idx=1, ommatidium_idx=360):
    """Extract one neuron type's activity at a single ommatidium.

    Parameters
    ----------
    x : array
        Network activity as loaded from an ``nn_hist_*.npy`` file. It is passed
        unchanged to ``LayerActivity`` (presumably time x neurons; TODO confirm).
    neuron_type : str
        Attribute name looked up on the ``LayerActivity`` object, e.g. ``"T4a"``.
    eye_idx : int, optional
        Index into the second-to-last axis of the FlyGym-mapped activity. The
        default ``1`` is what this notebook treats as the right eye.
    ommatidium_idx : int, optional
        Index into the last axis of the FlyGym-mapped activity. The default
        ``360`` is what this notebook treats as the central ommatidium.

    Returns
    -------
    array
        ``flyvis_to_flygym(activity)[..., eye_idx, ommatidium_idx]``. Any
        leading axes of the mapped activity are preserved.
    """
    layer_activity = LayerActivity(
        x,
        vision_network.connectome,
        keepref=True,
        use_central=False,
    )
    neuron_type_activity = getattr(layer_activity, neuron_type)
    # Remap from flyvision's layout to FlyGym's (..., eye, ommatidium) layout,
    # then pick a single ommatidium.
    flygym_activity = flyvis_to_flygym(neuron_type_activity)
    return flygym_activity[..., eye_idx, ommatidium_idx]


neuron_type = "T4a"

# Activity trace of `neuron_type` at the selected ommatidium, keyed by bar speed.
nn_activities = {
    speed: get_activity(np.load(npy_file), neuron_type)
    for speed, npy_file in paths.items()
}

In [None]:
# For each speed, find the sample where the Gaussian-smoothed trace
# (sigma = 2 samples) rises most steeply. Later cells take the trace minimum
# before this index and the trace maximum from this index onward.
split_indices = dict(
    (speed, gaussian_filter1d(trace, sigma=2, order=1).argmax())
    for speed, trace in nn_activities.items()
)

In [None]:
import matplotlib.pyplot as plt

# Figure-wide styling: Arial text, and pdf.fonttype 42 so saved PDFs embed
# TrueType (Type 42) fonts instead of Type 3.
plt.rcParams.update({"font.family": "Arial", "pdf.fonttype": 42})

In [None]:
# One panel per speed: the activity trace, plus a bracket under the x-axis
# spanning from the trace minimum before the split index (t0) to the trace
# maximum from the split index onward (t1), labeled with that duration.
n_cols = len(nn_activities)
# squeeze=False keeps the Axes in a 2-D array even when there is only one
# speed. A bare Axes (what squeeze=True returns for 1x1) is not iterable.
fig, axs = plt.subplots(
    1,
    n_cols,
    figsize=(n_cols * 1.2, 1),
    sharex=False,
    sharey=True,
    dpi=300,
    squeeze=False,
)
axs = axs[0]

vmax = 2  # symmetric y-limit shared by all panels (a.u.)

for j, (ax, (k, y)) in enumerate(zip(axs, nn_activities.items())):
    # Sample index -> seconds; the divisor implies 120 samples per second.
    t = np.arange(len(y)) / 120
    ax.plot(t, y, color="k")
    split_idx = split_indices[k]
    # Guard against an empty pre-split slice (split_idx == 0), where argmin
    # would raise; fall back to the first sample in that case.
    t0 = t[y[: max(split_idx, 1)].argmin()]
    t1 = t[y[split_idx:].argmax() + split_idx]
    dur = t1 - t0
    ax.set_xlim(t0 - dur * 2, t1 + dur * 2)

    title = int(k) if int(k) == k else k
    ax.set_title(f"{title}°/s")

    # Bracket and duration label drawn in x-data / y-axes coordinates so they
    # sit just below the plotting area regardless of the y-limits.
    ax.plot(
        [t0, t1],
        [0, 0],
        transform=ax.get_xaxis_transform(),
        color="k",
        clip_on=False,
    )
    ax.text(
        (t0 + t1) / 2,
        -0.05,
        f"{dur:.2f} s",
        color="k",
        ha="center",
        va="top",
        transform=ax.get_xaxis_transform(),
    )

    ax.set_xticks([])
    ax.set_ylim(-vmax, vmax)
    ax.set_yticks([-vmax, 0, vmax])

    for spine in ax.spines.values():
        spine.set_visible(False)

    # Only the first panel keeps a visible y-axis; the others hide tick marks.
    if j == 0:
        ax.spines["left"].set_visible(True)
        ax.spines["left"].set_bounds(-vmax, vmax)
    else:
        ax.yaxis.set_tick_params(size=0)

axs[0].set_ylabel(f"{neuron_type} activity (a.u.)", labelpad=1)

Path("./outputs").mkdir(exist_ok=True)
fig.savefig(f"outputs/{neuron_type}_activity.pdf", bbox_inches="tight")

In [None]:
# Peak response per speed: the maximum of each trace from its split index on.
tuning_curve = dict(
    (speed, nn_activities[speed][start:].max())
    for speed, start in split_indices.items()
)

In [None]:
# Tick labels: whole-number speeds without a trailing ".0", all others shown
# unrounded (a fixed one-decimal format would mislabel e.g. 0.25 as "0.2").
titles = [str(int(k)) if int(k) == k else str(k) for k in tuning_curve]
fig, ax = plt.subplots(1, 1, figsize=(3, 2), dpi=300)
ax.plot(
    list(tuning_curve.keys()),
    list(tuning_curve.values()),
    color="k",
    marker=".",
    markeredgewidth=0,
)
ax.set_xscale("log", base=2)
ax.set_xticks(list(tuning_curve.keys()))
ax.set_xticklabels(titles, rotation=45, ha="right", rotation_mode="anchor", va="top")
ax.tick_params(axis="both", pad=1)

ax.set_xlabel("Speed (°/s)")
ax.set_ylabel(f"{neuron_type} max. activity (a.u.)", labelpad=1)
# `vmax` comes from the per-speed trace cell; peaks above it are clipped.
ax.set_ylim(0, vmax)
ax.set_yticks([0, vmax])

for side in ["top", "right"]:
    ax.spines[side].set_visible(False)

Path("./outputs").mkdir(exist_ok=True)
fig.savefig(f"outputs/{neuron_type}_tuning_curve.pdf", bbox_inches="tight")