# Moving edge responses

This notebook introduces moving edge responses and the direction selectivity index (DSI). The DSI quantifies how strongly a cell prefers visual motion in one direction over others.

**Select GPU runtime**

To run the notebook on a GPU, select Menu -> Runtime -> Change runtime type -> GPU.

In [None]:
# @markdown **Check access to GPU**

try:
    import google.colab

    IN_COLAB = True
except ImportError:
    IN_COLAB = False

if IN_COLAB:
    import torch

    try:
        cuda_name = torch.cuda.get_device_name()
        print(f"Name of the assigned GPU / CUDA device: {cuda_name}")
    except RuntimeError:
        import warnings

        warnings.warn(
            "You have not selected Runtime Type: 'GPU' or Google could not assign you one. Please revisit the settings as described above or proceed on CPU (slow)."
        )

**Install Flyvis**

The notebook requires installing our package `flyvis`. After running the code block below, you may need to restart your session with Menu -> Runtime -> Restart session. Imports from `flyvis` should then succeed without issue.

In [None]:
%%capture
#@markdown **Install Flyvis**
# cell magics such as %%capture must be the first line of the cell
if IN_COLAB:
    !git clone https://github.com/flyvis/flyvis-dev.git
    %cd /content/flyvis-dev
    !pip install -e .

In [None]:
# basic imports
import matplotlib.pyplot as plt
import numpy as np
import torch

plt.rcParams['figure.dpi'] = 200

## Moving edge stimuli

To elicit moving edge responses and characterize the motion selectivity of neurons, experimenters show an ON or OFF edge moving in different directions. We generate and render these stimuli with the `MovingEdge` dataset.
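As a rough illustration of what such a stimulus looks like, the sketch below renders an ON or OFF edge on a small square pixel grid, whereas `MovingEdge` renders onto flyvision's hexagonal lattice. The helper `toy_moving_edge`, its units (pixels per frame), and its parameters are hypothetical and do not mirror the `MovingEdge` arguments; as in the dataset, the angle is the direction of motion, orthogonal to the edge.

```python
import numpy as np


def toy_moving_edge(n_frames, size=21, angle_deg=0.0, speed=1.0, start=-10.0, on=True):
    """Hypothetical helper: render a moving edge on a square pixel grid.

    Pixels whose position along the motion direction lies behind the edge
    front take the edge intensity; all other pixels keep the background.
    """
    coords = np.arange(size) - size // 2  # pixel coordinates centered on 0
    x, y = np.meshgrid(coords, coords)  # x: column coordinate, y: row coordinate
    theta = np.deg2rad(angle_deg)
    # position along the motion direction, rounded to suppress cos/sin noise
    proj = np.round(x * np.cos(theta) + y * np.sin(theta), 9)
    edge_value, background = (1.0, 0.0) if on else (0.0, 1.0)
    frames = np.empty((n_frames, size, size))
    for t in range(n_frames):
        front = start + speed * t  # edge front position (pixels) at frame t
        frames[t] = np.where(proj <= front, edge_value, background)
    return frames


on_edge = toy_moving_edge(21)  # ON edge sweeping rightward, one pixel per frame
off_edge = toy_moving_edge(21, on=False)  # the matching OFF edge
```

The rows of this toy grid grow downward, so `angle_deg=90` sweeps the edge from the top row to the bottom; flyvision's hexagonal coordinate convention may differ.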

In [None]:
# import dataset and visualization helper
from flyvision.datasets.moving_bar import MovingEdge
from flyvision.animations.hexscatter import HexScatter

In [None]:
# initialize the dataset
dataset = MovingEdge(
    offsets=[-10, 11],  # offset of the edge from the center, in units of LED size (2.25 degrees)
    intensities=[0, 1],  # edge intensity: 0 for OFF edges, 1 for ON edges
    speeds=[19],  # edge speed, in units of 5.8 degrees per second
    height=80,  # height of the moving edge, in units of LED size (2.25 degrees)
    post_pad_mode="continue",  # for the post-stimulus period, continue with the last frame of the stimulus
    t_pre=1.0,  # duration of the pre-stimulus period (s)
    t_post=1.0,  # duration of the post-stimulus period (s)
    dt=1 / 200,  # temporal resolution of the rendered video (s)
    angles=list(np.arange(0, 360, 30)),  # motion directions in degrees (orthogonal to the edge)
)

In [None]:
# view stimulus parameters: one row per sample, i.e., per combination of
# intensity, speed, and angle (here 2 x 1 x 12 = 24 samples)
dataset.arg_df

In [None]:
# visualize every 25th frame of a single sample (sample 3);
# its parameters are listed in dataset.arg_df.iloc[3]
animation = HexScatter(dataset[3][None, ::25, None], vmin=0, vmax=1)
animation.animate_in_notebook()

## Moving edge response

Now that we have generated the stimulus, we can use it to drive a trained connectome-constrained network.

In [None]:
from flyvision import results_dir
from flyvision.network import NetworkView

# models are already sorted by task error,
# so we take the best task-performing model, model 0000
network_view = NetworkView("opticflow/000/0000")

In [None]:
# rebuild network from checkpoint
network = network_view.init_network()

In [None]:
responses = np.concatenate(
    [
        ret[1]
        for ret in network.stimulus_response(
            stim_dataset=dataset,
            dt=dataset.dt,
            t_pre=0.0,
        )
    ],
    axis=0,
)

### MovingEdgeResponseView
We've now computed moving edge responses for all cells in the network. The `MovingEdgeResponseView` class provides fast and flexible analysis of, and operations on, these stored responses.

In [None]:
from flyvision.utils.nodes_edges_utils import CellTypeArray
from flyvision.analysis.moving_bar_responses import MovingEdgeResponseView

# extract cell responses in central column
central_responses = responses[:, :, network.connectome.central_cells_index[:]]
# wrap responses for easy access by cell type
responses_array = CellTypeArray(
    central_responses,
    cell_types=network.connectome.unique_cell_types[:].astype(str),
)

In [None]:
# initialize MovingEdgeResponseView
merv = MovingEdgeResponseView(
    arg_df=dataset.arg_df, responses=responses_array, config=dataset.config
)

### Response traces
We can plot single-cell response traces with `MovingEdgeResponseView.plot_traces()`. Here, we plot responses of T4c cells to edges with intensity 1 (ON edges).

In [None]:
merv.plot_traces(
    cell_type="T4c",
    groupby=["angle"],
    intensity=1,
    t_start=-0.5,
    t_end=1.0,
    plot_kwargs=dict(
        figsize=(2.4, 1.8),
        fontsize=6,
        color=[plt.cm.hsv(x) for x in np.arange(0, 1, 1 / 12)],
    ),
)
plt.show()

### Direction selectivity index (DSI)

The **Direction Selectivity Index (DSI)** quantifies a cell's preference for stimuli moving in a particular direction.

The DSI is derived from the following steps:
1. Obtain the neuron's peak responses to stimuli moving in different directions $\theta$ and at different speeds $S$.
2. Rectify these peak responses to ensure they are non-negative.
3. Compute the DSI using the equation:

$$
DSI_{t_i}(I) = \frac{1}{\lvert \mathcal{S} \rvert} \sum_{S \in \mathcal{S}} \frac{\left\lvert \sum_{\theta \in \Theta} r^{peak}_{t_{central}}(I, S, \theta) e^{i\theta} \right\rvert}{\max_{I' \in \mathcal{I}} \left\lvert \sum_{\theta \in \Theta} r^{peak}_{t_{central}}(I', S, \theta) \right\rvert}
$$

Where:
- $DSI_{t_i}(I)$ is the direction selectivity index of cell type $t_i$ at stimulus intensity $I$.
- $\mathcal{S}$ is the set of tested speeds, $\lvert \mathcal{S} \rvert$ is its size, and $\mathcal{I}$ is the set of tested intensities.
- $r^{peak}_{t_{central}}(I, S, \theta)$ is the rectified peak response of the central cell (in hexagonal space) of cell type $t_i$ for stimulus intensity $I$, speed $S$, and direction $\theta$.
- $\Theta$ is the set of all tested directions $\theta$.
- $e^{i\theta}$ weights each response by the complex exponential of its direction of motion, so the numerator is the magnitude of the vector sum of the responses.
- The denominator normalizes by the largest summed response across intensities. Because the responses are rectified, the magnitude of the vector sum never exceeds the plain sum, so DSI values range from 0 to 1.

A DSI of 0 indicates no directional preference, while a DSI close to 1 indicates strong selectivity for a single direction.
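A minimal NumPy sketch of the formula above, assuming the peak responses of one cell type's central cell have already been collected into an array of shape `(n_intensities, n_speeds, n_angles)`. The helper `dsi_sketch` is hypothetical and is not flyvision's `dsi()` implementation.

```python
import numpy as np


def dsi_sketch(peak_responses: np.ndarray, angles_deg: np.ndarray) -> np.ndarray:
    """Hypothetical helper: DSI per intensity from peak responses.

    `peak_responses` has shape (n_intensities, n_speeds, n_angles) and
    `angles_deg` holds the motion directions in degrees.
    """
    r = np.maximum(np.asarray(peak_responses, dtype=float), 0.0)  # rectify
    theta = np.deg2rad(np.asarray(angles_deg, dtype=float))
    # numerator: magnitude of the vector sum over directions, shape (I, S)
    vector_sum = np.abs(np.sum(r * np.exp(1j * theta), axis=-1))
    # denominator: largest summed response across intensities, shape (1, S)
    norm = np.sum(r, axis=-1).max(axis=0, keepdims=True)
    # divide only where the denominator is positive; a silent cell gets DSI 0
    ratio = np.divide(vector_sum, norm, out=np.zeros_like(vector_sum), where=norm > 0)
    return ratio.mean(axis=-1)  # average over speeds, shape (I,)


angles = np.arange(0, 360, 30)
peak = np.zeros((2, 1, 12))  # (intensities, speeds, angles)
peak[0, 0, 2] = 0.5  # OFF edge: response only at 60 degrees, half as strong
peak[1, 0, 2] = 1.0  # ON edge: response only at 60 degrees
dsis_toy = dsi_sketch(peak, angles)  # approximately [0.5, 1.0]
```

Because the denominator takes the maximum over intensities, a cell that responds to only one direction for both contrasts, but half as strongly to OFF edges, gets a DSI of 0.5 for OFF edges and 1.0 for ON edges.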

For the T4c cell plotted before, we can see that it preferentially responds to ON edges moving at an angle of 60 degrees, so we expect to see a large DSI. We can compute the DSI with `MovingEdgeResponseView.dsi()`.

In [None]:
# get DSI for T4c cell
dsi_T4c = merv.where_stim_args(intensity=1).cell_type("T4c").dsi()[:].squeeze()
print(f"T4c DSI: {dsi_T4c}")

We can also compute the preferred direction of the cell with `MovingEdgeResponseView.preferred_direction()` (this is the direction that the tuning lobe points towards). We would expect the preferred direction to be around 60 degrees based on the response traces.
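One way to define the direction a tuning lobe points toward, assumed here rather than taken from `preferred_direction()`, is the angle of the same complex vector sum that appears in the DSI numerator. A minimal sketch with the hypothetical helper `preferred_direction_sketch`:

```python
import numpy as np


def preferred_direction_sketch(peak_responses, angles_deg):
    """Hypothetical helper: direction (degrees in [0, 360)) of the vector sum
    of rectified peak responses over the last axis."""
    r = np.maximum(np.asarray(peak_responses, dtype=float), 0.0)  # rectify
    theta = np.deg2rad(np.asarray(angles_deg, dtype=float))
    vector_sum = np.sum(r * np.exp(1j * theta), axis=-1)
    # np.angle returns radians in (-pi, pi]; map to degrees in [0, 360)
    return np.mod(np.degrees(np.angle(vector_sum)), 360.0)


angles = np.arange(0, 360, 30)
lobe = np.zeros(12)
lobe[[1, 2, 3]] = [0.5, 1.0, 0.5]  # symmetric lobe centered on 60 degrees
pd_toy = preferred_direction_sketch(lobe, angles)  # approximately 60
```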

In [None]:
preferred_direction = (
    merv.where_stim_args(intensity=1)
    .cell_type("T4c")
    .preferred_direction()[:]
    .squeeze()
)
print(f"T4c preferred direction: {preferred_direction / np.pi * 180} degrees")

We can also inspect the direction selectivity of a cell type visually by plotting its angular tuning with `MovingEdgeResponseView.plot_angular_tuning()`. The plot clearly shows that T4c is tuned to stimuli moving at a 60-degree angle.

In [None]:
merv.plot_angular_tuning(cell_type="T4c", intensity=1)

### DSI and tuning curve correlation

With the `dsi()` function, we can also compute DSIs for every cell type at once. Since the motion selectivity of some cell types has been determined experimentally, we can compare our model to experimental findings by correlating the model DSIs of these known cell types with their experimentally determined motion selectivity.
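We have not confirmed how `dsi_correlation_to_known` encodes the experimental findings. As an assumption, the sketch below treats them as a binary label per (cell type, contrast) pair, 1 if the pair is known to be direction selective and 0 otherwise, and takes the Pearson correlation with the model DSIs. The helper name and the toy values are hypothetical.

```python
import numpy as np


def dsi_correlation_sketch(model_dsis, known_selective):
    """Hypothetical helper: Pearson correlation between model DSIs and binary
    labels (1 = known direction selective, 0 = known not selective)."""
    x = np.asarray(model_dsis, dtype=float).ravel()
    y = np.asarray(known_selective, dtype=float).ravel()
    return np.corrcoef(x, y)[0, 1]


# toy values for four (cell type, contrast) pairs, not model outputs
model_dsis = np.array([0.8, 0.7, 0.1, 0.05])
known_selective = np.array([1, 1, 0, 0])
corr_toy = dsi_correlation_sketch(model_dsis, known_selective)  # about 0.99
```

A correlation close to 1 means the model assigns high DSIs to the known direction-selective pairs and low DSIs to the rest.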

In [None]:
from flyvision.analysis.moving_bar_responses import dsi_correlation_to_known

# compute DSIs for all cell types
dsi_all = merv.dsi()
# get DSI values and corresponding cell type
off_dsis = dsi_all.where_stim_args(intensity=0)
on_dsis = dsi_all.where_stim_args(intensity=1)
dsis = np.stack([off_dsis[:], on_dsis[:]], axis=0)[:, :, 0]  # remove temporal dim
cell_types = dsi_all.responses.cell_types
# compute correlation
dsi_corr = dsi_correlation_to_known(dsis, cell_types, respect_contrast=True).squeeze()
print(f"DSI correlation = {dsi_corr}")

Further, the tuning curves of certain cell types have also been measured experimentally, so we can correlate our model cells' tuning with the measured tuning. T4c is known to be tuned to stimuli moving at 90 degrees; since the model T4c is tuned to a nearby direction (about 60 degrees), the correlation should be relatively high.
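Assuming the comparison is a Pearson correlation between two tuning curves sampled at the same angles (an assumption about what `tuning_curve_correlation_to_known` computes), a minimal sketch with toy cosine tuning curves shows why a 30-degree offset in preferred direction still yields a fairly high correlation. The helper `tuning_correlation_sketch` is hypothetical.

```python
import numpy as np


def tuning_correlation_sketch(model_tuning, measured_tuning):
    """Hypothetical helper: Pearson correlation between two tuning curves
    sampled at the same angles."""
    model_tuning = np.asarray(model_tuning, dtype=float)
    measured_tuning = np.asarray(measured_tuning, dtype=float)
    return np.corrcoef(model_tuning, measured_tuning)[0, 1]


angles = np.arange(0, 360, 30)
model = np.cos(np.deg2rad(angles - 60))  # toy curve peaking at 60 degrees
measured = np.cos(np.deg2rad(angles - 90))  # toy curve peaking at 90 degrees
r = tuning_correlation_sketch(model, measured)  # equals cos(30 degrees), about 0.87
```

For pure cosine curves sampled at evenly spaced angles, the correlation equals the cosine of the angle between the two peaks; real, rectified tuning curves will deviate from this.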

In [None]:
from flyvision.analysis.moving_bar_responses import tuning_curve_correlation_to_known

# compute tuning curves for all cell types
tuning_curve, (angles, intensities) = merv.tuning_curves()
# correlate model tuning curves with experimentally measured ones
tc_corr = tuning_curve_correlation_to_known(
    tuning=tuning_curve, angles=angles, intensities=intensities
)
print(f"T4c tuning curve correlation = {tc_corr['T4c']}")

In fact, tuning curves for all T4 and T5 cells have been measured, so we can compute the correlation for all 8 cell types.

In [None]:
import pprint

print(f"Tuning curve correlations: \n{pprint.pformat(tc_corr)}")

So, the model yields accurate predictions for all T4 and T5 cell types.

## Ensemble responses

Now we can compare motion selectivity properties across an ensemble of trained models. First, we again simulate the network responses, this time for the best-performing models in the ensemble.

In [None]:
from flyvision.ensemble import EnsembleView

ensemble = EnsembleView("opticflow/000")

In [None]:
central_cells_index = ensemble[ensemble.names[0]].connectome.central_cells_index[:]
# take only the top 20% of models (10 in this case)
with ensemble.ratio(best=0.20):
    responses = np.stack(
        [
            np.concatenate(
                [
                    ret[1][:, :, central_cells_index]
                    for ret in net.stimulus_response(
                        stim_dataset=dataset,
                        dt=dataset.dt,
                        t_pre=0.0,
                    )
                ],
                axis=0,
            )
            for net in ensemble.yield_networks()
        ],
        axis=0,
    )

We again use `MovingEdgeResponseView` to wrap the ensemble's responses to moving edges.

In [None]:
responses_array = CellTypeArray(
    responses,
    cell_types=network.connectome.unique_cell_types[:].astype(str),
)

In [None]:
# the first axis of `responses` now indexes models,
# so samples and time sit on axes 1 and 2
merv = MovingEdgeResponseView(
    arg_df=dataset.arg_df,
    responses=responses_array,
    config=dataset.config,
    stim_sample_dim=1,
    temporal_dim=2,
)

### Response traces

We can once again plot response traces for a single cell type. Because the network neuron activities are in arbitrary units, we subtract the initial value of each trace and rescale by the maximum absolute value before plotting. We plot only T4c responses to ON edges moving at a 90-degree angle.
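The baseline subtraction and rescaling can be sketched in plain NumPy, assuming the responses of one cell type form an array of shape `(n_models, n_samples, n_frames)`, so that one scale factor is computed per model across samples and time (matching `dims=(1, 2)` in the cell below). The helper `normalize_traces_sketch` and the random toy data are hypothetical.

```python
import numpy as np


def normalize_traces_sketch(traces):
    """Hypothetical helper: baseline-subtract and rescale traces of shape
    (n_models, n_samples, n_frames), one scale factor per model."""
    traces = np.asarray(traces, dtype=float)
    centered = traces - traces[:, :, [0]]  # subtract each trace's initial value
    # largest absolute deflection per model, across samples and time
    scale = np.abs(centered).max(axis=(1, 2), keepdims=True)
    return centered / np.where(scale > 0, scale, 1.0)  # flat traces stay at 0


rng = np.random.default_rng(0)
traces = rng.normal(size=(3, 24, 300))  # toy data: 3 models, 24 samples, 300 frames
normalized = normalize_traces_sketch(traces)
```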

In [None]:
centered = (
    merv.between_seconds(-0.5, 1.0)
    - merv.between_seconds(-0.5, 1.0).responses.array[:, :, [0]]
)
centered /= centered.abs().max(dims=(1, 2), keepdims=True)
centered.plot_traces(
    cell_type="T4c",
    angle=90,
    intensity=1,
    plot_kwargs=dict(figsize=(2.4, 1.8), fontsize=6),
)
plt.show()

Most networks predict the expected T4c response to this stimulus, but some networks in the ensemble are tuned differently.

### Direction selectivity index (DSI)

We can also compute direction selectivity indices for each network in the ensemble.

In [None]:
# get T4c DSIs for every network in the ensemble
dsi_t4c = merv.where_stim_args(intensity=1).cell_type("T4c").dsi()[:].squeeze().tolist()
print(f"T4c DSIs: {pprint.pformat(dsi_t4c)}")

Most networks in this group recover some direction selectivity for T4c. We can also plot the distribution of DSIs per cell type for both ON- and OFF-edge stimuli across the ensemble.

In [None]:
# compute DSIs for all cell types
dsi_all = merv.dsi()
# get DSI values and corresponding cell types
dsis = dsi_all.responses.array.squeeze()
cell_types = dsi_all.responses.cell_types

In [None]:
from flyvision.analysis.moving_bar_responses import plot_dsis

fig, ax = plot_dsis(
    dsis,
    cell_types,
    bold_output_type_labels=True,
    output_cell_types=ensemble[ensemble.names[0]]
    .connectome.output_cell_types[:]
    .astype(str),
    figsize=[10, 1.2],
    color_known_types=True,
    fontsize=6,
    scatter_best_index=0,
    scatter_best_color=plt.get_cmap("Blues")(1.0),
)
plt.show()

### DSI correlation

Lastly, we look at how well the ensemble's DSIs and tuning curves correlate with experimentally measured ones. This summarizes, for each model, how accurately it predicts known motion tuning.

In [None]:
off_dsis = dsi_all.where_stim_args(intensity=0)
on_dsis = dsi_all.where_stim_args(intensity=1)
dsis = np.stack([off_dsis[:], on_dsis[:]], axis=0)[:, :, 0]  # remove temporal dim
cell_types = dsi_all.responses.cell_types
dsi_corr = dsi_correlation_to_known(dsis, cell_types, respect_contrast=True)

tuning_curve, (angles, intensities) = merv.tuning_curves()
tc_corr = tuning_curve_correlation_to_known(
    tuning_curve, angles, intensities, aggregate_dims=3
)

t4_corr = np.median(
    [tc_corr[cell_type].squeeze() for cell_type in ["T4a", "T4b", "T4c", "T4d"]], axis=0
)
t5_corr = np.median(
    [tc_corr[cell_type].squeeze() for cell_type in ["T5a", "T5b", "T5c", "T5d"]], axis=0
)

In [None]:
from flyvision.plots.plots import violin_groups

fig, ax, *_ = violin_groups(
    np.stack([dsi_corr, t4_corr, t5_corr], axis=0)[:, None, :],
    ["DSI", "T4 tuning", "T5 tuning"],
    ylabel="correlation",
    figsize=(1.8, 1.5),
    ylim=(-1, 1),
    colors=[
        plt.get_cmap("Dark2")(0.125),
        plt.get_cmap("Dark2")(0),
        plt.get_cmap("Dark2")(0.25),
    ],
    color_by="experiments",
    scatter_edge_color="gray",
    scatter_radius=5,
    violin_alpha=0.8,
)

<!-- ... Models in general have very good match to known single-neuron tuning properties, with median correlation around $0.8$. -->