In [None]:
%load_ext autoreload
%autoreload 2

# Cluster analysis based on naturalistic stimuli responses

This notebook illustrates how to cluster the models of an ensemble after nonlinear dimensionality reduction on their predicted responses to naturalistic stimuli. This can be done for any cell type. Here we provide a detailed example focusing on clustering based on T4c responses.

**Select GPU runtime**

To run the notebook on a GPU select Menu -> Runtime -> Change runtime type -> GPU.

In [None]:
# @markdown **Check access to GPU**

try:
    import google.colab

    IN_COLAB = True
except ImportError:
    IN_COLAB = False

if IN_COLAB:
    import torch

    try:
        cuda_name = torch.cuda.get_device_name()
        print(f"Name of the assigned GPU / CUDA device: {cuda_name}")
    except RuntimeError:
        import warnings

        warnings.warn(
            "You have not selected Runtime Type: 'GPU' or Google could not assign you one. Please revisit the settings as described above or proceed on CPU (slow)."
        )

**Install Flyvis**

The notebook requires installing our package `flyvis`. You may need to restart your session after running the code block below with Menu -> Runtime -> Restart session. Then, imports from `flyvis` should succeed without issue.

In [None]:
if IN_COLAB:
    # @markdown **Install Flyvis**
    # quiet flags keep the install output short
    !git clone -q https://github.com/flyvis/flyvis-dev.git
    %cd /content/flyvis-dev
    !pip install -q -e .

In [None]:
# basic imports
import matplotlib.pyplot as plt
import numpy as np
import torch

plt.rcParams['figure.dpi'] = 200

# Naturalistic stimuli dataset (Sintel)
We load the dataset with our custom augmentations. The dataset contains movie sequences from the publicly available computer-animated movie Sintel rendered to the hexagonal lattice structure of the fly eye. For a more detailed introduction to the dataset class and parameters see the notebook on the optic flow task.

In [None]:
import flyvision
from flyvision.datasets.sintel import AugmentedSintel
import numpy as np

In [None]:
dt = 1 / 200  # can be changed for other temporal resolutions
dataset = AugmentedSintel(tasks=["lum"], dt=dt, temporal_split=True)

In [None]:
# view stimulus parameters
dataset.arg_df

In [None]:
sequence = dataset[0]["lum"]

In [None]:
# one sequence contains 80 frames with 721 hexals each
sequence.shape

In [None]:
animation = flyvision.animations.HexScatter(sequence[None], vmin=0, vmax=1)
animation.animate_in_notebook(frames=np.arange(5))

# Ensemble responses to a single sequence
We compute the responses of all models in the stored ensemble to the first sequence of the augmented Sintel dataset.

In [None]:
from flyvision import results_dir

In [None]:
# We load the ensemble trained on the optic flow task
ensemble = flyvision.ensemble.EnsembleView(results_dir / "flow/0000")

`ensemble.simulate` provides an efficient method to return responses of all networks within the ensemble.

In [None]:
responses = np.array(list(ensemble.simulate(sequence[None], dataset.dt, fade_in=True)))

In [None]:
responses.shape

`CentralActivity` is an interface to the response tensor of 45k cells that allows dict- and attribute-style access to the responses of the central cells of the different cell types.

In [None]:
from flyvision.utils.activity_utils import CentralActivity

central_responses = CentralActivity(responses, ensemble[0].connectome, keepref=True)

We visualize the central T4c responses for the whole ensemble.

In [None]:
cell_type = "T4c"

In [None]:
n_frames = sequence.shape[0]
time = np.arange(n_frames) * dataset.dt  # one time stamp per frame, independent of float rounding

In [None]:
colors = ensemble.task_error().colors

In [None]:
fig, ax = flyvision.plots.plt_utils.init_plot([2, 2], fontsize=5)
for model_id, response in enumerate(central_responses[cell_type]):
    r = response.squeeze()
    ax.plot(
        time,
        (r - r[0]) / np.abs(r).max(),
        c=colors[model_id],
        zorder=len(ensemble) - model_id,
    )
ax.set_xlabel("time in s", fontsize=5)
ax.set_ylabel("response (a.u.)", fontsize=5)
ax.set_title(f"{cell_type} responses across the ensemble", fontsize=5)

We see that the predictions for T4c vary across the models of the ensemble. Our goal is to understand the underlying structure in those variations.
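The normalization applied to each trace in the plot above can be isolated in a few lines. The following is a minimal sketch with NumPy on a made-up trace; `center_and_scale` is a hypothetical helper name used only for illustration and is not part of `flyvis`.

```python
import numpy as np


def center_and_scale(r):
    """Shift a trace to start at zero and divide by the peak absolute value of the raw trace."""
    r = np.asarray(r, dtype=float)
    return (r - r[0]) / np.abs(r).max()


# made-up voltage trace, not model output
trace = np.array([2.0, 3.0, 4.0, 1.0])
normalized = center_and_scale(trace)
```

With this scaling, all traces start at zero and share a comparable amplitude range, so differences in shape rather than in absolute voltage dominate the plot.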

## Nonlinear dimensionality reduction (UMAP) and Gaussian Mixtures

In [None]:
from matplotlib.pyplot import subplot_mosaic

from flyvision.analysis.clustering import EnsembleEmbedding, get_cluster_to_indices

In [None]:
# specify parameters for the UMAP embedding

embedding_kwargs = {
    "min_dist": 0.105,
    "spread": 9.0,
    "n_neighbors": 5,
    "random_state": 42,
    "n_epochs": 1500,
}

For illustration, we compute the UMAP embedding of the ensemble from each model's T4c response to this single sequence.

In [None]:
central_responses[:].shape

In [None]:
embedding = EnsembleEmbedding(central_responses)
t4c_embedding = embedding("T4c", embedding_kwargs=embedding_kwargs)

In [None]:
task_error = ensemble.task_error()

In [None]:
fig, ax = t4c_embedding.plot(colors=task_error.colors)

Each scatter point in this 2D embedding represents one model, that is, one of the time series plotted above.
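To make the data flow concrete, the sketch below maps a matrix of traces with shape `(n_models, n_timepoints)` to one 2D point per model. It deliberately swaps in a linear PCA (via SVD) as a stand-in for UMAP so that it runs with NumPy alone; the traces are random numbers and all names are illustrative assumptions, not the `flyvis` implementation.

```python
import numpy as np

rng = np.random.default_rng(0)
n_models, n_timepoints = 6, 80

# random stand-in for one response trace per model
traces = rng.normal(size=(n_models, n_timepoints))

# PCA stand-in for UMAP: center each time point across models,
# then project onto the two leading right-singular vectors
centered_traces = traces - traces.mean(axis=0, keepdims=True)
u, s, vt = np.linalg.svd(centered_traces, full_matrices=False)
points_2d = centered_traces @ vt[:2].T  # shape (n_models, 2)
```

UMAP differs in that it is nonlinear and preserves local neighborhoods, but the input and output shapes follow the same pattern: one row per model in, one 2D point per model out.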

To label the clusters, we fit Gaussian mixtures to this embedding over a range of component counts (`range_n_clusters` in the cell below). We keep the number of components that minimizes the Bayesian Information Criterion (BIC).
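The BIC trades goodness of fit against model size: $\mathrm{BIC} = k \ln n - 2 \ln \hat{L}$, with $k$ free parameters, $n$ samples, and maximized likelihood $\hat{L}$. The sketch below shows the selection step on made-up log-likelihood values. It assumes full covariance matrices when counting parameters, which may differ from the settings used by `flyvis`; all function names are hypothetical.

```python
import math


def gmm_n_parameters(n_components, n_dims):
    """Free parameters of a Gaussian mixture with full covariances (an assumption)."""
    weights = n_components - 1
    means = n_components * n_dims
    covariances = n_components * n_dims * (n_dims + 1) // 2
    return weights + means + covariances


def bic(log_likelihood, n_parameters, n_samples):
    """Bayesian Information Criterion; lower is better."""
    return n_parameters * math.log(n_samples) - 2.0 * log_likelihood


# made-up total log-likelihoods per number of components, for illustration only
log_likelihoods = {1: -260.0, 2: -215.0, 3: -190.0, 4: -185.0, 5: -182.0}
n_samples = 50  # hypothetical number of embedded points

scores = {
    k: bic(ll, gmm_n_parameters(k, n_dims=2), n_samples)
    for k, ll in log_likelihoods.items()
}
best_k = min(scores, key=scores.get)
```

In this toy example the likelihood keeps improving with more components, but beyond three components the gain no longer outweighs the parameter penalty, so `best_k` is 3.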

In [None]:
# specify parameters for the Gaussian mixture

gm_kwargs = {
    "range_n_clusters": [1, 2, 3, 4, 5],
    "n_init": 100,
    "max_iter": 1000,
    "random_state": 42,
    "tol": 0.001,
}

In [None]:
gm_clustering = t4c_embedding.cluster.gaussian_mixture(**gm_kwargs)

In [None]:
embeddingplot = gm_clustering.plot(
    task_error=task_error.values, colors=task_error.colors
)

We can use the labels to disambiguate the time series that we plotted above. We expect the labels to group similar time series together and to separate dissimilar ones.
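Grouping by label amounts to building a mapping from each cluster label to the indices of the models assigned to it. The following is a minimal NumPy sketch of that idea on made-up labels; it does not reproduce the masking or the ordering by task error that `get_cluster_to_indices` may apply, and `label_to_indices` is a hypothetical name.

```python
import numpy as np

# made-up cluster labels, one per model
labels = np.array([0, 2, 1, 0, 2, 2])

# map each cluster label to the indices of its member models
label_to_indices = {
    int(label): np.flatnonzero(labels == label) for label in np.unique(labels)
}
```

Iterating over `label_to_indices.items()` then yields one `(cluster_id, model_ids)` pair per cluster, the same access pattern the plotting cells below rely on.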

In [None]:
cluster_to_indices = get_cluster_to_indices(
    embeddingplot.cluster.embedding.mask,
    embeddingplot.cluster.labels,
    ensemble.task_error(),
)

In [None]:
cluster_colors = {}
CMAPS = ["Blues_r", "Reds_r", "Greens_r", "Oranges_r", "Purples_r"]

for cluster_id in cluster_to_indices:
    cluster_colors[cluster_id] = ensemble.task_error(cmap=CMAPS[cluster_id]).colors

In [None]:
fig, ax = flyvision.plots.plt_utils.init_plot([2, 2], fontsize=5)
for cluster_id, model_ids in cluster_to_indices.items():
    for model_id, response in zip(
        model_ids, central_responses[cell_type][np.array(model_ids)]
    ):
        r = response.squeeze()
        ax.plot(
            time, (r - r[0]) / np.abs(r).max(), c=cluster_colors[cluster_id][model_id]
        )

ax.set_xlabel("time in s", fontsize=5)
ax.set_ylabel("response (a.u.)", fontsize=5)
ax.set_title(f"{cell_type} responses across the ensemble", fontsize=5)
ylim = ax.get_ylim()

In [None]:
fig, axes, _ = flyvision.plots.plt_utils.get_axis_grid(
    cluster_to_indices, fontsize=5, figsize=[5, 4], wspace=0.3, hspace=0.5
)
for cluster_id, model_ids in cluster_to_indices.items():
    ax = axes[cluster_id]
    for model_id, response in zip(
        model_ids, central_responses[cell_type][np.array(model_ids)]
    ):
        r = response.squeeze()
        ax.plot(
            time, (r - r[0]) / np.abs(r).max(), c=cluster_colors[cluster_id][model_id]
        )
    ax.set_xlabel("time in s", fontsize=5)
    ax.set_ylabel("response (a.u.)", fontsize=5)

The clustering has led us to 5 qualitatively distinct predictions from the ensemble for this cell and sequence. This is a first lead for an underlying structure in these predictions. We will get an even better estimate once we use more sequences for the clustering.

# Clustering based on the ensemble responses to the whole dataset

In [None]:
from flyvision.utils.activity_utils import StimulusResponseIndexer
from flyvision.utils.activity_utils import CellTypeArray

Because this analysis is costly, we randomly select a subset of the 2268 sequences in the dataset to illustrate how it scales (set `indices` to `None` to compute all responses). To include only the best fraction of models, lower the `best` argument of `ensemble.ratio`. Skip ahead to the last section to load the precomputed clusterings.
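Sampling without replacement guarantees that no sequence is simulated twice. The sketch below shows the same idea with NumPy's `Generator` API instead of the global seed used in the cell that follows; it is an alternative formulation, and the indices it draws differ from those of `np.random.seed(42)` with `np.random.choice`.

```python
import numpy as np

n_sequences = 2268  # dataset size stated above
n_subset = 64

# a local generator keeps the draw reproducible without touching global state
rng = np.random.default_rng(42)
subset = rng.choice(n_sequences, size=n_subset, replace=False)
```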

In [None]:
np.random.seed(42)
indices = np.random.choice(np.arange(len(dataset)), replace=False, size=64)


with ensemble.ratio(best=1.0):
    responses = np.stack(
        list(
            ensemble.simulate_from_dataset(
                dataset,
                dt=1 / 200,
                indices=indices,
                batch_size=4,
                central_cell_only=True,
            )
        )
    )

In [None]:
arg_df = dataset.arg_df.loc[indices].reset_index(drop=True)
sri = StimulusResponseIndexer(
    arg_df,
    CellTypeArray(responses, ensemble[0].connectome),
    dt=dataset.dt,
    t_pre=0,
    temporal_dim=2,
    stim_sample_dim=1,
)

centered = sri - sri.responses.array[:, :, [0]]
centered /= sri.abs().max(dims=(1, 2), keepdims=True)

centered.plot_traces("T4c", plot_kwargs=dict(legend=[]))

Plotting the centered T4c responses across all models and stimuli again reveals a lot of structure. At this scale, however, there is too much data to disambiguate the traces by eye. Differences in responses to simple stimuli would also be easier to interpret than differences in responses to naturalistic stimuli.

That is why we again first reduce these traces to 2D with a nonlinear dimensionality reduction and then cluster the embedding to understand the structure in the dataset. For each model, the dimensionality reduction treats the traces from the individual movie sequences as one long concatenated trace to embed. Afterwards, we interpret the differences between these clusters in responses to simple stimuli.
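Treating all sequences as one long trace corresponds to a reshape of the response array. The sketch below assumes an array of shape `(n_models, n_sequences, n_frames)` for a single cell type, filled with random numbers; the shapes and names are illustrative and the actual `flyvis` internals may differ.

```python
import numpy as np

n_models, n_sequences, n_frames = 3, 4, 5
rng = np.random.default_rng(0)

# random stand-in for the responses of one cell type
toy_responses = rng.normal(size=(n_models, n_sequences, n_frames))

# concatenate the sequences along time: one long feature vector per model
flattened = toy_responses.reshape(n_models, n_sequences * n_frames)
```

Each row of `flattened` is then one sample for the embedding, so adding sequences increases the number of features per model but not the number of points in the embedding.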

In [None]:
central_responses = CentralActivity(responses, ensemble[0].connectome, keepref=True)

In [None]:
embedding = EnsembleEmbedding(central_responses)
t4c_embedding = embedding("T4c", embedding_kwargs=embedding_kwargs)

In [None]:
# specify parameters for the Gaussian mixture
gm_kwargs = {
    "range_n_clusters": [2, 3, 4, 5],
    "n_init": 100,
    "max_iter": 1000,
    "random_state": 42,
    "tol": 0.001,
}

In [None]:
with ensemble.ratio(best=1.0):
    task_error = ensemble.task_error()

In [None]:
embeddingplot = t4c_embedding.cluster.gaussian_mixture(**gm_kwargs).plot(
    task_error=task_error.values, colors=task_error.colors
)

This clustering already looks very close to the result in the paper. Note, though, that it is still based on a small subset of the Sintel dataset.

# Using the clustering to discover tuning predictions in responses to simple stimuli

We expect that the clustering based on naturalistic stimuli will also disambiguate the different tuning predictions from different models for simple stimuli.

In [None]:
cluster_to_indices = get_cluster_to_indices(
    embeddingplot.cluster.embedding.mask,
    embeddingplot.cluster.labels,
    ensemble.task_error(),
)

In [None]:
# define different colormaps for clusters
cluster_colors = {}
CMAPS = ["Blues_r", "Reds_r", "Greens_r", "Oranges_r", "Purples_r"]

for cluster_id in cluster_to_indices:
    cluster_colors[cluster_id] = ensemble.task_error(cmap=CMAPS[cluster_id]).colors

## Clustered voltage responses to moving edges

In [None]:
from flyvision.datasets.moving_bar import MovingEdge

In [None]:
mer_dataset = MovingEdge(
    offsets=[-10, 11],  # offset of bar from center in 1 * radians(2.25) led size
    intensities=[0, 1],  # intensity of bar
    speeds=[19],  # speed of bar in 1 * radians(5.8) / s
    height=80,  # height of moving bar in 1 * radians(2.25) led size
    post_pad_mode="continue",  # for post-stimulus period, continue with the last frame of the stimulus
    t_pre=1.0,  # duration of pre-stimulus period
    t_post=1.0,  # duration of post-stimulus period
    dt=1 / 200,  # temporal resolution of rendered video
    angles=list(np.arange(0, 360, 30)),  # motion direction (orthogonal to edge)
)

In [None]:
central_cells_index = ensemble[0].connectome.central_cells_index[:]
with ensemble.ratio(best=1.0):  # best=1.0 keeps all models; e.g. best=0.2 would keep only the top 20%
    mer = np.stack(
        list(
            ensemble.simulate_from_dataset(
                mer_dataset,
                dt=mer_dataset.dt,
                batch_size=4,
                central_cell_only=True,
            )
        )
    )

In [None]:
responses_array = CellTypeArray(
    mer,
    cell_types=ensemble[0].connectome.unique_cell_types[:].astype(str),
)

In [None]:
from flyvision.analysis.moving_bar_responses import MovingEdgeResponseView

In [None]:
merv = MovingEdgeResponseView(
    arg_df=mer_dataset.arg_df,
    responses=responses_array,
    config=mer_dataset.config,
    stim_sample_dim=1,
    temporal_dim=2,
)

In [None]:
centered = (
    merv.between_seconds(-0.5, 1.0)
    - merv.between_seconds(-0.5, 1.0).responses.array[:, :, [0]]
)
centered /= centered.abs().max(dims=(1, 2), keepdims=True)
centered.plot_traces(
    cell_type="T4c",
    angle=90,
    intensity=1,
    plot_kwargs=dict(figsize=(2.4, 1.8), fontsize=6),
)
plt.show()

In [None]:
fig, axes, _ = flyvision.plots.plt_utils.get_axis_grid(
    cluster_to_indices, fontsize=5, figsize=[5, 4], wspace=0.3, hspace=0.5
)
for cluster_id, model_ids in cluster_to_indices.items():
    ax = axes[cluster_id]
    centered[model_ids, :, :].plot_traces(
        cell_type="T4c",
        angle=90,
        intensity=1,
        plot_kwargs=dict(
            figsize=(2.4, 1.8),
            fontsize=6,
            fig=fig,
            ax=ax,
            title=f"cluster {cluster_id}",
            color=cluster_colors[cluster_id][model_ids]
        ),
    )
    ax.set_xlabel("time in s", fontsize=5)
    ax.set_ylabel("response (a.u.)", fontsize=5)

## Clustered peak voltage responses to moving edges

In [None]:
merv.plot_angular_tuning(cell_type="T4c", intensity=1)

In [None]:
fig, axes, _ = flyvision.plots.plt_utils.get_axis_grid(
    cluster_to_indices,
    fontsize=5,
    figsize=[5, 4],
    wspace=0.3,
    hspace=0.5,
    projection="polar",
)
for cluster_id, model_ids in cluster_to_indices.items():
    ax = axes[cluster_id]
    merv[model_ids, :, :].plot_angular_tuning(
        cell_type="T4c", intensity=1, colors=cluster_colors[cluster_id][model_ids],
        fig=fig, ax=ax
    )

As we can see here, the models quite nicely predict clustered neural responses. We discovered all of these clusters simply by using UMAP and Gaussian Mixtures.

# Load precomputed UMAP and clustering

Recording and embedding all responses is computationally expensive. For that reason, and for consistency with the paper, we also show how to load the precomputed embeddings and clusterings from the paper.

In [None]:
cell_type = "T4c"
clustering = ensemble.clustering(cell_type)

In [None]:
task_error = ensemble.task_error()

In [None]:
embeddingplot = clustering.plot(task_error=task_error.values,
                                colors=task_error.colors)

With this embedding and clustering one can proceed in the same way as above to plot the tunings.