In [None]:
from glob import glob
import json
import pandas as pd

# Sort so the row order of the results table does not depend on the
# (arbitrary) order in which the filesystem lists the files.
files = sorted(glob('./output/experiment_different_surrogate_approximation/*.json'))

# One parsed JSON record per result file; the DataFrame is built once from the list.
all_dicts = []

for file in files:
    with open(file, "r", encoding="utf-8") as f:
        all_dicts.append(json.load(f))

df = pd.DataFrame(all_dicts)

df

In [None]:
from constants import TIME_STEPS
from different_surrogate_approximation_experiment import best_grid_search_model_atan, best_grid_search_model_spike_rate_escape

In [None]:
import torch
import copy

def load_surrogate_model(template, checkpoint_path):
    """Return a deep copy of ``template`` with weights loaded from ``checkpoint_path``.

    The template is deep-copied so the imported grid-search model object is left
    untouched. ``map_location="cpu"`` lets a checkpoint saved on a GPU be read on a
    CPU-only machine; ``load_state_dict`` then copies the tensors into the model's
    own parameters. ``weights_only=True`` restricts unpickling to tensors and plain
    containers instead of running arbitrary pickle code.
    """
    model = copy.deepcopy(template)
    state_dict = torch.load(checkpoint_path, map_location="cpu", weights_only=True)
    model.load_state_dict(state_dict)
    # The models are only used for inference below; disable any training-mode
    # behaviour (e.g. dropout) they may have.
    model.eval()
    return model


atan_model = load_surrogate_model(
    best_grid_search_model_atan,
    './models/experiment_different_surrogate_approximation/atan.pth',
)

spike_rate_escape_model = load_surrogate_model(
    best_grid_search_model_spike_rate_escape,
    './models/experiment_different_surrogate_approximation/spike_rate_escape.pth',
)

In [None]:
import torch
from util.utils import get_device
from torch.utils.data import DataLoader
from tonic import datasets, transforms

# Which sample of the first test batch to visualise.
selection_index = 2
device = get_device()

frame_transform = transforms.ToFrame(
    sensor_size=datasets.SHD.sensor_size,
    n_time_bins=TIME_STEPS
)

test_data = datasets.SHD("./data", transform=frame_transform, train=False)

test_data_loader = DataLoader(test_data, shuffle=False, batch_size=32)

# Only the first batch is needed: next(iter(...)) loads just that batch instead of
# materialising every batch of the test set in a list.
data, target = next(iter(test_data_loader))
assert 0 <= selection_index < len(target), "selection_index is outside the first batch"

# NOTE(review): assumes the collated frames squeeze to (batch, time, neurons), so the
# permute yields (time, batch, neurons) — confirm against the model's expected input.
data = data.to_dense().to(torch.float32).squeeze().permute(1, 0, 2).to(device)

x_selected = data[:, selection_index, :]
# Plain Python int so it formats as "3" rather than "tensor(3)" in titles.
y_selected = target[selection_index].item()


In [None]:
def get_spk_matrices(data, model, selection_index):
    """Run ``model`` on a batch and return the per-layer spike rasters of one sample.

    Args:
        data: batched model input; the sample is selected along dim 1
            (presumably (time, batch, neurons) — TODO confirm against the model).
        model: callable returning ``(spk_recs, _)``, where ``spk_recs`` is a sequence
            of per-layer recordings indexed like ``data``; its last entry is the
            output layer.
        selection_index: index of the sample within the batch.

    Returns:
        ``[input, *hidden_layers, output]``: the selected sample's slice of the input
        and of every recorded layer, detached from autograd.
    """
    # Inference only: no_grad avoids building an autograd graph for the whole batch.
    with torch.no_grad():
        spk_recs, _ = model(data)

    # The input raster is sliced from the ``data`` argument itself, so the result
    # always corresponds to the batch that was actually passed in.
    input_spk = data[:, selection_index, :].detach()
    layer_spks = [spk_rec[:, selection_index, :].detach() for spk_rec in spk_recs]

    return [input_spk, *layer_spks]

In [None]:
import matplotlib.pyplot as plt
import numpy as np

def plot_spike_raster(ax, spike_matrix, title, show_ylabel=False):
    """Draw a 2-D (time step, neuron) matrix as a spike raster on ``ax``.

    Every non-zero bin is drawn. Input frames produced by ``ToFrame`` hold event
    counts that can exceed 1, so an ``== 1`` test would silently drop busy bins;
    for 0/1 spike layers ``> 0`` is equivalent.
    """
    # detach().cpu() so plotting also works when the tensors live on a GPU.
    spike_matrix_np = spike_matrix.detach().cpu().numpy()
    times, neurons = np.nonzero(spike_matrix_np > 0)
    ax.scatter(times, neurons, s=1, color='black')
    ax.set_title(title)
    ax.set_xlabel("Time step")
    ax.set_ylim(-1, spike_matrix_np.shape[1])
    if show_ylabel:
        ax.set_ylabel("Neuron index")
    return ax


spike_matrices_atan = get_spk_matrices(data, atan_model, selection_index)
spike_matrices_rate_escape = get_spk_matrices(data, spike_rate_escape_model, selection_index)

# One figure row per surrogate, one column per layer (input, hidden..., output).
surrogate_rows = [
    ("Atan", spike_matrices_atan),
    ("Spike Rate Escape", spike_matrices_rate_escape),
]

fig, axes = plt.subplots(len(surrogate_rows), len(spike_matrices_atan), figsize=(20, 10))

fig.suptitle(f"Different Surrogates. Label {int(y_selected)}", fontsize=16)

for row, (surrogate_name, spike_matrices) in enumerate(surrogate_rows):
    for index, spike_matrix in enumerate(spike_matrices):
        plot_spike_raster(
            axes[row, index],
            spike_matrix,
            f"{surrogate_name} - Layer {index}",
            show_ylabel=(index == 0),
        )

# Leave room at the top for the suptitle.
fig.tight_layout(rect=(0, 0, 1, 0.95))
plt.show()