In [1]:
# 🔧 Setup: imports & paths
import os
import sys

# Make the project-local packages imported below (data, models, ...) importable.
# The project root can be overridden with the LSM_PROJECT_ROOT environment variable,
# so the notebook is not tied to a single machine; the original path is the fallback.
# The membership check keeps this cell idempotent: re-running it does not append
# duplicate entries to sys.path.
PROJECT_ROOT = os.environ.get("LSM_PROJECT_ROOT", "/home/karl-/liquidstatemachines")
if PROJECT_ROOT not in sys.path:
    sys.path.append(PROJECT_ROOT)

import torch
import tonic.transforms as transforms
from data.dataloader import load_filtered_shd_dataloader
from models.sffnn_batched import Net


In [2]:
# 🖥️ Configure the device: run on the GPU when CUDA is available, otherwise on the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
print(f"Using device: {device}")


Using device: cuda


In [3]:
# 🔄 Prepare the test data.
# Events are spatially downsampled by a factor of 0.5 and then binned into dense
# frames. The first SENSOR_SIZE entry (350) has to match the network input width
# (fc1 in_features=350 in the model repr further down).
SENSOR_SIZE = (350, 1, 1)
N_TIME_BINS = 250

transform = transforms.Compose(
    [
        transforms.Downsample(spatial_factor=0.5),
        transforms.ToFrame(sensor_size=SENSOR_SIZE, n_time_bins=N_TIME_BINS),
    ]
)

# presumably keeps only samples whose label lies in range(10) — confirm in data.dataloader
test_dataloader = load_filtered_shd_dataloader(
    label_range=range(10),
    transform=transform,
    train=False,
    batch_size=1,
)


In [10]:
import nir
from snntorch.import_nir import import_from_nir

# 📦 Load the exported model: read the NIR graph from disk and rebuild it as an
# snnTorch network. The path is resolved relative to the notebook's working directory.
NIR_MODEL_PATH = "../model_export/my_model.nir"
nir_network = nir.read(NIR_MODEL_PATH)
snntorch_network = import_from_nir(nir_network)

replace rnn subgraph with nirgraph


In [8]:
# Move the imported network's parameters/buffers onto the selected device.
# nn.Module.to() returns the module itself, so the cell also displays its structure.
snntorch_network.to(device=device)

GraphExecutor(
  (fc1): Linear(in_features=350, out_features=1000, bias=True)
  (fc2): Linear(in_features=1000, out_features=10, bias=True)
  (input): Identity()
  (lif1): Leaky()
  (lif2): Leaky()
  (output): Identity()
)

In [6]:
# 🧪 Make a single prediction.
# The NIR-imported network (nirtorch GraphExecutor) evaluates the graph once per
# call and returns (output, state). The sequence must therefore be fed one time bin
# at a time, with the returned state passed back in on the next call. Passing the
# whole [B, T, C] tensor in a single call would treat T as an extra batch axis,
# so the LIF neurons would never integrate over time.
# NOTE(review): snnTorch neurons may also keep internal membrane buffers; confirm
# that no state leaks from one sample into the next.
with torch.no_grad():
    events, labels = next(iter(test_dataloader))  # test only the first sample
    events = events.squeeze(2).to(device).float()  # [1, T, 1, 350] → [1, T, 350]
    labels = labels.to(device)

    state = None  # no state from an earlier call is passed in for the first time bin
    spk_steps = []
    for t in range(events.shape[1]):
        spk_t, state = snntorch_network(events[:, t], state)  # [B, 350] → [B, num_outputs]
        spk_steps.append(spk_t)
    spk_rec = torch.stack(spk_steps, dim=1)  # → [B, T, num_outputs]

    spike_sums = spk_rec.sum(dim=1)  # total spikes per output neuron → [B, num_outputs]
    pred = torch.argmax(spike_sums, dim=1)

    print(f"🔍 Predicted: {pred.item()}, ✅ Ground Truth: {labels.item()}")

🔍 Predicted: 2, ✅ Ground Truth: 2


In [7]:
# Explicit import instead of `import *`, so it is clear where the helper comes from
# and the notebook namespace is not polluted.
from utils.metrics import print_full_dataloader_accuracy_batched

# NOTE(review): confirm whether this helper feeds the NIR GraphExecutor one time
# bin at a time (the executor returns (output, state) per call). If it passes whole
# [B, T, C] sequences in a single call, the reported accuracy does not reflect
# temporal integration.
print_full_dataloader_accuracy_batched(snntorch_network, test_dataloader)

Accuracy (Full Dataloader): 0.2753
