# Feed Forward SNN mit Batched Input für effizienteres Training

In [11]:
# dataloading 

import sys
import os
sys.path.append("/home/karl-/liquidstatemachines")
from data.dataloader import *
from data.dummy_event_based_classification_dataset import *
import tonic.transforms as transforms
from torch.utils.data import DataLoader
from tonic import MemoryCachedDataset
from utils.spike_plots import *

transform = transforms.ToFrame(
    sensor_size=tonic.datasets.SHD.sensor_size,  # = (700,),
    n_time_bins=250
    
)
train_dataset=DummySpikeDataset(2000,10000)
cached_train_dataset = MemoryCachedDataset(train_dataset, transform=transform)
train_dataloader=DataLoader(cached_train_dataset,batch_size=128)

test_dataset=DummySpikeDataset(2000,10000,21)
cached_test_dataset = MemoryCachedDataset(test_dataset, transform=transform)
test_dataloader=DataLoader(cached_test_dataset,batch_size=128)

events, labels = next(iter(train_dataloader))

In [12]:
import torch.nn as nn
from models.sffnn_batched import *

device = torch.device("cuda") if torch.cuda.is_available() else torch.device("mps") if torch.backends.mps.is_available() else torch.device("cpu")
ffn = Net(num_inputs=700,num_hidden=1000,num_outputs=2,num_steps=250,beta=0.65).to(device)

In [13]:
events, labels = next(iter(train_dataloader))  # z. B. events: [B, T, 1, 700]
events = events.squeeze(2).to(device).float()  # → [B, T, 700]
labels = labels.to(device)

spk_rec, mem_rec = ffn(events)  # spk_rec: [B, T, 2]
spike_sums = spk_rec.sum(dim=1)  # über Zeit → [B, 2]
pred_labels = torch.argmax(spike_sums, dim=1)  # [B]
loss_fn = nn.CrossEntropyLoss()
loss = loss_fn(spike_sums, labels)

In [14]:
optimizer = torch.optim.Adam(ffn.parameters(), lr=5e-4)

def train_one_epoch(net, dataloader, optimizer, loss_fn, device):
    net.train()
    total_loss = 0
    correct = 0
    total = 0

    for events, labels in dataloader:
        # events: [B, T, 1, 700] → [B, T, 700]
        events = events.reshape(events.size(0), events.size(1), events.size(3))  # [B, T, 700]

        events = events.to(device).float()
        labels = labels.to(device)

        optimizer.zero_grad()

        # Forward pass
        spk_rec, _ = net(events)              # spk_rec: [B, T, num_outputs]
        spike_sums = spk_rec.sum(dim=1)       # Rate Coding: [B, num_outputs]

        # Compute loss
        loss = loss_fn(spike_sums, labels)    # labels: [B]
        loss.backward()
        optimizer.step()

        total_loss += loss.item() * labels.size(0)  # multipliziere für Batch-Summen

        # Accuracy
        preds = torch.argmax(spike_sums, dim=1)     # [B]
        correct += (preds == labels).sum().item()
        total += labels.size(0)

    avg_loss = total_loss / total
    accuracy = correct / total
    print(f"Train Loss: {avg_loss:.4f}, Accuracy: {accuracy:.4f}")
    return avg_loss, accuracy


for epoch in range(1, 10 + 1):
    print(f"\nEpoch {epoch}")
    train_one_epoch(ffn, train_dataloader, optimizer, loss_fn, device)



Epoch 1


IndexError: Dimension out of range (expected to be in range of [-3, 2], but got 3)

In [None]:
ffn.eval()
correct = 0
total = 0
with torch.no_grad():
    for events, labels in test_dataloader:
        events = events.squeeze(2).to(device).float()
        labels = labels.to(device)

        spk_rec, _ = ffn(events)
        spike_sums = spk_rec.sum(dim=1)
        preds = torch.argmax(spike_sums, dim=1)
        correct += (preds == labels).sum().item()
        total += labels.size(0)

print(f"Test Accuracy: {correct / total:.4f}")


Test Accuracy: 0.4967


In [None]:
print(spk_rec.mean().item())  # z. B. ≈ 0.0?
print(spk_rec.sum(dim=1)[:5])  # [B, 2]


0.0
tensor([[0., 0.],
        [0., 0.],
        [0., 0.],
        [0., 0.],
        [0., 0.]], device='cuda:0')
