In [None]:
import os
import torch
import torch.nn as nn
from torch.optim import Optimizer
from torch.utils.data import DataLoader

from dataset import get_trial_dataloaders
from model import LiquidSpikeFormer
from loss import HybridSpikingLoss
from optimizer import get_optimizer, get_scheduler
from augmentation import Compose, NormalizeTimestamps, RandomTemporalCrop, RandomSpatialJitter, RandomPolarityFlip, AddEventNoise, ToBinnedTensor

# --- Configuration ---
ROOT_DIR = "/mnt/m2ssd/research project/SNN/dataset/DVS  Gesture dataset/DvsGesture"  # Path to DVS Gesture root
BATCH_SIZE = 16
NUM_WORKERS = 4
PIN_MEMORY = True
NUM_EPOCHS = 90
LEARNING_RATE = 1e-3
WEIGHT_DECAY = 1e-4
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
CHECKPOINT_DIR = "./checkpoints"
os.makedirs(CHECKPOINT_DIR, exist_ok=True)

# --- Transforms ---
spike_encoder = ...  # e.g., SpikeEncoder(num_bins=256, height=128, width=128)
transform = Compose([
    NormalizeTimestamps(),
    RandomTemporalCrop(0.8),
    RandomSpatialJitter(max_jitter=1, height=128, width=128),
    RandomPolarityFlip(flip_prob=0.05),
    AddEventNoise(spatial_sigma=0.5, temporal_sigma=0.01, height=128, width=128),
    ToBinnedTensor(encoder=spike_encoder)
])



# --- DataLoaders ---
train_loader, test_loader = get_trial_dataloaders(
    root_dir=ROOT_DIR,
    transform=transform,
    batch_size=BATCH_SIZE,
    num_workers=NUM_WORKERS,
    pin_memory=PIN_MEMORY
)

# --- Model, Loss, Optimizer, Scheduler ---
model = LiquidSpikeFormer(
    in_channels=spike_encoder.num_bins,
    embed_dim=128,
    nhead=4,
    num_classes=11,
    encoder_bins=spike_encoder.num_bins,
    height=128,
    width=128,
    poisson=False,
    learnable_bins=False,
    smooth_kernel_size=5,
    dropout=0.1
).to(DEVICE)

criterion = HybridSpikingLoss(lambda_s=1.0, lambda_m=0.5, lambda_t=0.5, lambda_a=0.1,
                              target_sparsity=0.1, threshold=0.5)
optimizer = get_optimizer(model, optimizer_name='AdamW', lr=LEARNING_RATE,
                          weight_decay=WEIGHT_DECAY)
scheduler = get_scheduler(optimizer,
                          scheduler_name='WarmupCosine',
                          total_steps=len(train_loader)*NUM_EPOCHS,
                          warmup_steps=500)

# --- Training & Evaluation Loop ---
best_acc = 0.0
for epoch in range(1, NUM_EPOCHS + 1):
    model.train()
    running_loss = 0.0
    for batch in train_loader:
        events = batch['events'].to(DEVICE)
        labels = batch['label'].to(DEVICE)

        optimizer.zero_grad()
        out = model(events)
        loss = criterion(
            out['logits'], labels,
            spikes=out['spikes'],
            membrane=out['membrane'],
            threshold_param=out['threshold']
        )
        loss.backward()
        optimizer.step()
        scheduler.step()

        running_loss += loss.item() * events.size(0)

    epoch_loss = running_loss / len(train_loader.dataset)

    model.eval()
    correct = 0
    total = 0
    test_loss = 0.0
    with torch.no_grad():
        for batch in test_loader:
            events = batch['events'].to(DEVICE)
            labels = batch['label'].to(DEVICE)
            out = model(events)
            loss = criterion(out['logits'], labels)
            test_loss += loss.item() * events.size(0)

            preds = out['logits'].argmax(dim=1)
            correct += (preds == labels).sum().item()
            total += labels.size(0)

    test_loss = test_loss / total
    test_acc = 100.0 * correct / total

    print(f"Epoch {epoch}/{NUM_EPOCHS} "
          f"Train Loss: {epoch_loss:.4f} "
          f"Test Loss: {test_loss:.4f} "
          f"Test Acc: {test_acc:.2f}%")

    if test_acc > best_acc:
        best_acc = test_acc
        ckpt_path = os.path.join(CHECKPOINT_DIR, f"best_model_epoch{epoch}.pth")
        torch.save({
            'epoch': epoch,
            'model_state': model.state_dict(),
            'optimizer_state': optimizer.state_dict(),
            'best_acc': best_acc
        }, ckpt_path)
        print(f"Saved new best checkpoint to {ckpt_path}")

print(f"Training complete. Best Acc: {best_acc:.2f}%")


ValueError: num_samples should be a positive integer value, but got num_samples=0

In [2]:
import os
print(os.listdir("/mnt/m2ssd/research project/SNN/dataset/DVS  Gesture dataset/DvsGesture"))


['errata.txt', 'gesture_mapping.csv', 'LICENSE.txt', 'README.txt', 'trials_to_test.txt', 'trials_to_train.txt', 'user01_fluorescent.aedat', 'user01_fluorescent_labels.csv', 'user01_fluorescent_led.aedat', 'user01_fluorescent_led_labels.csv', 'user01_lab.aedat', 'user01_lab_labels.csv', 'user01_led.aedat', 'user01_led_labels.csv', 'user01_natural.aedat', 'user01_natural_labels.csv', 'user02_fluorescent.aedat', 'user02_fluorescent_labels.csv', 'user02_fluorescent_led.aedat', 'user02_fluorescent_led_labels.csv', 'user02_lab.aedat', 'user11_natural_labels.csv', 'user12_fluorescent_led.aedat', 'user12_fluorescent_led_labels.csv', 'user12_led.aedat', 'user12_led_labels.csv', 'user13_fluorescent.aedat', 'user13_fluorescent_labels.csv', 'user13_fluorescent_led.aedat', 'user13_fluorescent_led_labels.csv', 'user13_lab.aedat', 'user13_lab_labels.csv', 'user13_led.aedat', 'user13_led_labels.csv', 'user13_natural.aedat', 'user14_fluorescent.aedat', 'user14_fluorescent_labels.csv', 'user14_fluoresce

In [3]:
import pandas as pd
df = pd.read_csv("/mnt/m2ssd/research project/SNN/dataset/DVS  Gesture dataset/DvsGesture/gesture_mapping.csv")
print(df.columns.tolist())


['action', 'label']


In [2]:
import pandas as pd

df = pd.read_csv("/mnt/m2ssd/research project/SNN/dataset/DVS  Gesture dataset/DvsGesture/user01_lab_labels.csv")
print(df.columns.tolist())


['class', 'startTime_usec', 'endTime_usec']
