In [None]:
import pickle
from neuron_visualization.NeuronVisualizer import NeuronVisualizer
import pandas as pd

# Environment whose trained RNN and trajectory dataset are analysed in this notebook.
env_type = "triangle"

# The model checkpoint is selected by the first two characters of the environment name.
model_path = f"models/rnn_{env_type[0:2]}_model.pkl"
# NOTE: pickle.load can execute arbitrary code — only load checkpoints you trust.
with open(model_path, "rb") as model_file:
    model = pickle.load(model_file)
model.set_device("cpu")

# Wrap the model, then record its activations on the trajectory CSV for this environment.
visualizer = NeuronVisualizer(model)
data_path = f"data/{env_type}_1000traj_50steps.csv"
data = pd.read_csv(data_path)
visualizer.retrieve_activations(data, use_predicted=True)

In [None]:
from matplotlib import pyplot as plt
import numpy as np
import seaborn as sns
import scipy.cluster.hierarchy as sch
from scipy.spatial.distance import squareform
from sklearn.cluster import AgglomerativeClustering
from utils.plots import smart_subplots

# Group neurons by the cosine similarity of their spatial maps and plot each group.
maps = visualizer.get_spatial_maps(absolute=False)  # maps[i] is neuron i's spatial map
# assumed to be a square (n_neurons, n_neurons) similarity matrix — TODO confirm
cos_sims = visualizer.compute_cos_sim_on_maps(maps)

# Cosine distance, sanitised into a valid distance matrix. Floating-point error can
# produce tiny negatives, slight asymmetry, or a non-zero diagonal.
cos_dists = np.clip(1.0 - np.asarray(cos_sims, dtype=float), 0.0, None)
cos_dists = (cos_dists + cos_dists.T) / 2
np.fill_diagonal(cos_dists, 0.0)

# scipy's linkage() expects a *condensed* distance vector. A square matrix would be
# treated as raw observation vectors. 'average' linkage is valid for non-Euclidean
# distances (Ward is not) and matches the flat clustering below, so the dendrogram
# and the cluster labels describe the same tree.
linkage = sch.linkage(squareform(cos_dists, checks=False), method='average')
sns.clustermap(cos_sims, row_linkage=linkage, col_linkage=linkage, cmap='coolwarm', center=0)
plt.show()

# Cut the neurons into flat clusters using the same distances. metric='precomputed'
# makes sklearn interpret cos_dists as pairwise distances rather than feature vectors
# (scikit-learn >= 1.2; older releases name this parameter `affinity`).
num_clusters = 5
clustering = AgglomerativeClustering(n_clusters=num_clusters, metric='precomputed', linkage='average')
labels = clustering.fit_predict(cos_dists)

for cluster_id in range(num_clusters):
    cluster_indices = np.flatnonzero(labels == cluster_id)
    print(f"Cluster {cluster_id}: {cluster_indices}")
    fig, axes = smart_subplots(len(cluster_indices), 7)
    # Flatten in case smart_subplots returns a single Axes or a nested grid of Axes.
    axes = np.array(axes).ravel()
    for ax, neuron_idx in zip(axes, cluster_indices):
        ax.imshow(maps[neuron_idx], cmap='viridis', origin='lower')
        ax.set_title(f'Neuron {neuron_idx}')
        ax.set_xticks([])
        ax.set_yticks([])
    # Hide leftover grid slots so they do not render as empty framed axes.
    for ax in axes[len(cluster_indices):]:
        ax.axis('off')
    fig.suptitle(f'Spatial Maps for Cluster {cluster_id}')
    plt.show()

In [None]:
import matplotlib.pyplot as plt
import numpy as np
from pathlib import Path
from utils.plots import smart_subplots

# Plot every neuron's spatial map in a 10-column grid and save it for the write-up.
maps = visualizer.get_spatial_maps(absolute=False)
num_neurons = maps.shape[0]
# Size the grid from the data instead of a hard-coded 100, so no neuron is silently dropped.
fig, axes = smart_subplots(num_neurons, 10)
# Flatten in case smart_subplots returns a nested grid of Axes.
axes = np.array(axes).ravel()
for i, ax in enumerate(axes):
    if i < num_neurons:
        ax.imshow(maps[i], cmap='viridis', origin='lower')
    ax.axis('off')

fig_path = Path(f'tex/figures/rnn_{env_type}_neuron_spatial_maps.png')
# Create the output directory on a fresh checkout so savefig does not fail.
fig_path.parent.mkdir(parents=True, exist_ok=True)
fig.savefig(fig_path, bbox_inches='tight')
plt.show()