In [1]:
# 🔧 Setup: imports & paths
import os
import sys

# Project root that contains the local `data`, `models` and `utils` packages.
# Overridable through the LSM_PROJECT_ROOT environment variable so the notebook
# also runs outside the original author's machine; the default keeps the old path.
PROJECT_ROOT = os.environ.get("LSM_PROJECT_ROOT", "/home/karl-/liquidstatemachines")
# Re-running this cell must not keep appending the same entry to sys.path.
if PROJECT_ROOT not in sys.path:
    sys.path.append(PROJECT_ROOT)

import torch
import tonic.transforms as transforms
from data.dataloader import load_filtered_shd_dataloader
from models.sffnn_batched import Net


In [2]:
# 🖥️ Configure the compute device: use the GPU when CUDA is available, else the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
print(f"Using device: {device}")


Using device: cuda


In [3]:
# 🔄 Prepare the test data.
# First downsample the events spatially by a factor of 0.5, then bin them into
# 250 time frames over a (350, 1, 1) sensor grid.
# NOTE(review): 350 presumably equals the original channel count halved — confirm.
downsample = transforms.Downsample(spatial_factor=0.5)
to_frame = transforms.ToFrame(sensor_size=(350, 1, 1), n_time_bins=250)
transform = transforms.Compose([downsample, to_frame])

# Test split, one sample per batch; label_range presumably keeps only labels 0-9 (verify in the loader).
test_dataloader = load_filtered_shd_dataloader(
    label_range=range(10),
    transform=transform,
    train=False,
    batch_size=1,
)


In [4]:
from snntorch.import_nir import import_from_nir
import nir

# Read the exported NIR graph from disk and convert it back into an snnTorch module.
nir_model_path = "../model_export/my_model.nir"
snntorch_network = import_from_nir(nir.read(nir_model_path))
nir_network = None  # placeholder, replaced below to keep the graph object available
nir_network = nir.read(nir_model_path)

replace rnn subgraph with nirgraph


In [5]:
# Build the target architecture that the NIR weights will be loaded into.
# `Net` is already imported in the setup cell, so it is not re-imported here.
# num_inputs=350 and num_steps=250 must match sensor_size/n_time_bins of the data transform.
model = Net(num_inputs=350, num_hidden=1000, num_outputs=10, num_steps=250, beta=0.9)

In [6]:
# Copy the imported network's parameters into a plain dict keyed by parameter name.
nir_state_dict = {name: tensor for name, tensor in snntorch_network.state_dict().items()}

In [7]:
# Load strictly: with strict=False, missing/unexpected keys would be silently ignored,
# leaving parts of `model` randomly initialised without any error being raised.
model.load_state_dict(nir_state_dict, strict=True)

<All keys matched successfully>

In [8]:
model.to(device)
model.eval()
# 🧪 Make a prediction on the first test batch only (a quick sanity check of the imported weights).
with torch.no_grad():
    for events, labels in test_dataloader:
        # Drop the singleton axis 2; the original note says [B, T, 1, 350] -> [B, T, 350].
        events = events.squeeze(2).to(device).float()
        labels = labels.to(device)
        print(events.shape)

        spk_rec, _ = model(events)
        # Sum spikes over dim=1 (presumably the time axis of Net's output) -> [B, num_outputs].
        spike_sums = spk_rec.sum(dim=1)
        pred = torch.argmax(spike_sums, dim=1)

        # .tolist() works for any batch size; .item() only works for single-element tensors.
        print(f"🔍 Predicted: {pred.tolist()}, ✅ Ground Truth: {labels.tolist()}")
        # Fraction of correct predictions in this batch (also robust to 0-d labels).
        accuracy = (pred == labels).float().mean().item()
        print(f"Accuracy:{accuracy}")
        break  # test only a single batch


torch.Size([1, 250, 350])
🔍 Predicted: tensor([4], device='cuda:0'), ✅ Ground Truth: tensor([4], device='cuda:0')
Accuracy:1.0


In [None]:
# Import only what is used: a star import could silently shadow notebook names (e.g. `device`).
from utils.metrics import print_full_dataloader_accuracy_batched

# Evaluate the imported model on the whole test dataloader.
model.to(device)
print_full_dataloader_accuracy_batched(model, test_dataloader)