# Feed Forward SNN mit Batched Input für effizienteres Training

In [1]:
# dataloading 

import sys
import os
sys.path.append("/home/karl-/liquidstatemachines")
from data.dataloader import *
from data.dummy_event_based_classification_dataset import *
import tonic.transforms as transforms
from torch.utils.data import DataLoader
from tonic import MemoryCachedDataset
from utils.spike_plots import *

transform = transforms.ToFrame(
    sensor_size=tonic.datasets.SHD.sensor_size,  # = (700,),
    n_time_bins=250
    
)
train_dataset=DummySpikeDataset(2000,10000)
cached_train_dataset = MemoryCachedDataset(train_dataset, transform=transform)
train_dataloader=DataLoader(cached_train_dataset,batch_size=128)

test_dataset=DummySpikeDataset(2000,10000,21)
cached_test_dataset = MemoryCachedDataset(test_dataset, transform=transform)
test_dataloader=DataLoader(cached_test_dataset,batch_size=128)

events, labels = next(iter(train_dataloader))

In [2]:
import torch.nn as nn
from models.sffnn_batched import *

device = torch.device("cuda") if torch.cuda.is_available() else torch.device("mps") if torch.backends.mps.is_available() else torch.device("cpu")
ffn = Net(num_inputs=700,num_hidden=1000,num_outputs=2,num_steps=250,beta=0.95).to(device)

In [3]:
events, labels = next(iter(train_dataloader))  # z. B. events: [B, T, 1, 700]
print(f"Original events shape: {events.shape}")

events = events.squeeze(2).to(device).float()  # → [B, T, 700]
print(f"After squeeze shape: {events.shape}")
labels = labels.to(device)
print(f"Labels: {labels[:10]}")
print(f"Unique labels: {torch.unique(labels)}")
spk_rec, mem_rec = ffn(events)  # spk_rec: [B, T, 2]
spike_sums = spk_rec.sum(dim=1)  # über Zeit → [B, 2]
pred_labels = torch.argmax(spike_sums, dim=1)  # [B]
loss_fn = nn.CrossEntropyLoss()
loss = loss_fn(spike_sums, labels)

Original events shape: torch.Size([128, 250, 1, 700])
After squeeze shape: torch.Size([128, 250, 700])
Labels: tensor([0, 1, 0, 0, 1, 1, 0, 0, 0, 0], device='cuda:0')
Unique labels: tensor([0, 1], device='cuda:0')


In [4]:
optimizer = torch.optim.Adam(ffn.parameters(), lr=5e-4)

def train_one_epoch(net, dataloader, optimizer, loss_fn, device):
    net.train()
    total_loss = 0
    correct = 0
    total = 0

    for events, labels in dataloader:
        # events: [B, T, 1, 700] → [B, T, 700]
        events = events.squeeze(2).to(device).float()  # Entferne reshape, nutze squeeze
        
        labels = labels.to(device)

        optimizer.zero_grad()

        # Forward pass
        spk_rec, _ = net(events)              # spk_rec: [B, T, num_outputs]
        spike_sums = spk_rec.sum(dim=1)       # Rate Coding: [B, num_outputs]
        print(f"Spike mean: {spk_rec.mean().item()}")
        print(f"Spike max: {spk_rec.max().item()}")
        print(f"Spike sum per sample: {spk_rec.sum(dim=1)[:5]}")
        print(f"Memory mean: {mem_rec.mean().item()}")
        # Compute loss
        loss = loss_fn(spike_sums, labels)    # labels: [B]
        loss.backward()
        optimizer.step()

        total_loss += loss.item() * labels.size(0)

        # Accuracy
        preds = torch.argmax(spike_sums, dim=1)     # [B]
        correct += (preds == labels).sum().item()
        total += labels.size(0)

    avg_loss = total_loss / total
    accuracy = correct / total
    print(f"Train Loss: {avg_loss:.4f}, Accuracy: {accuracy:.4f}")
    return avg_loss, accuracy


for epoch in range(1, 10 + 1):
    print(f"\nEpoch {epoch}")
    train_one_epoch(ffn, train_dataloader, optimizer, loss_fn, device)



Epoch 1
Spike mean: 0.00015625001105945557
Spike max: 1.0
Spike sum per sample: tensor([[0., 0.],
        [0., 0.],
        [0., 0.],
        [0., 0.],
        [0., 0.]], device='cuda:0', grad_fn=<SliceBackward0>)
Memory mean: -0.2765536904335022
Spike mean: 0.053578127175569534
Spike max: 1.0
Spike sum per sample: tensor([[27.,  0.],
        [27.,  0.],
        [27.,  0.],
        [27.,  0.],
        [27.,  0.]], device='cuda:0', grad_fn=<SliceBackward0>)
Memory mean: -0.2765536904335022
Spike mean: 0.018765626475214958
Spike max: 1.0
Spike sum per sample: tensor([[8., 0.],
        [9., 0.],
        [9., 0.],
        [9., 0.],
        [9., 0.]], device='cuda:0', grad_fn=<SliceBackward0>)
Memory mean: -0.2765536904335022
Spike mean: 0.0
Spike max: 0.0
Spike sum per sample: tensor([[0., 0.],
        [0., 0.],
        [0., 0.],
        [0., 0.],
        [0., 0.]], device='cuda:0', grad_fn=<SliceBackward0>)
Memory mean: -0.2765536904335022
Spike mean: 0.0
Spike max: 0.0
Spike sum per sam

In [5]:
ffn.eval()
correct = 0
total = 0
with torch.no_grad():
    for events, labels in test_dataloader:
        events = events.squeeze(2).to(device).float()
        labels = labels.to(device)

        spk_rec, _ = ffn(events)
        spike_sums = spk_rec.sum(dim=1)
        preds = torch.argmax(spike_sums, dim=1)
        correct += (preds == labels).sum().item()
        total += labels.size(0)

print(f"Test Accuracy: {correct / total:.4f}")


Test Accuracy: 1.0000


In [6]:
print(spk_rec.mean().item())  # z. B. ≈ 0.0?
print(spk_rec.sum(dim=1)[:5])  # [B, 2]


0.11187499761581421
tensor([[27., 32.],
        [29., 24.],
        [29., 23.],
        [29., 23.],
        [28., 32.]], device='cuda:0')
