In [1]:
import os
import torch
import torch.nn as nn
from torch.utils.data import DataLoader

from dataset import get_trial_dataloaders
from model import LiquidSpikeFormer, SpikeEncoder
from loss import HybridSpikingLoss
from optimizer import get_optimizer, get_scheduler
from augmentation import (
    Compose,
    NormalizeTimestamps,
    RandomTemporalCrop,
    RandomSpatialJitter,
    RandomPolarityFlip,
    AddEventNoise,
    ToBinnedTensor
)

# --- Configuration ---
# Dataset location. The DVS_GESTURE_ROOT environment variable overrides the
# default so the notebook is not tied to one machine's absolute path.
ROOT_DIR = os.environ.get(
    "DVS_GESTURE_ROOT",
    "/mnt/m2ssd/research project/SNN/dataset/DVS  Gesture dataset/DvsGesture",
)
BATCH_SIZE = 16
NUM_WORKERS = 12
PIN_MEMORY = True
NUM_EPOCHS = 90
LEARNING_RATE = 1e-3
WEIGHT_DECAY = 1e-4
WARMUP_STEPS = 500
SEED = 42
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
CHECKPOINT_DIR = "./checkpoints"
os.makedirs(CHECKPOINT_DIR, exist_ok=True)

# Encoder geometry, declared once: the same values are passed to the encoder,
# to two augmentations and to the model, and must stay in sync.
NUM_BINS = 256
HEIGHT = 128
WIDTH = 128
POISSON = False
LEARNABLE_BINS = False
SMOOTH_KERNEL_SIZE = 5
NUM_CLASSES = 11

# Seed torch's global RNG before any weights are created so that repeated
# Restart & Run All executions start from the same torch random state.
# NOTE(review): the augmentations may draw from other RNGs (random / numpy);
# seed those as well if they do - not visible from this notebook.
torch.manual_seed(SEED)

# --- Encoder ---
spike_encoder = SpikeEncoder(
    num_bins=NUM_BINS,
    height=HEIGHT,
    width=WIDTH,
    poisson=POISSON,
    learnable_bins=LEARNABLE_BINS,
    smooth_kernel_size=SMOOTH_KERNEL_SIZE
)

# --- Transforms ---
# Applied in list order; the final step hands the events to `spike_encoder`.
transform = Compose([
    NormalizeTimestamps(),
    RandomTemporalCrop(0.8),
    RandomSpatialJitter(max_jitter=1, height=HEIGHT, width=WIDTH),
    RandomPolarityFlip(flip_prob=0.05),
    AddEventNoise(spatial_sigma=0.5, temporal_sigma=0.01, height=HEIGHT, width=WIDTH),
    ToBinnedTensor(encoder=spike_encoder)
])

# --- DataLoaders ---
# NOTE(review): one transform containing random augmentations is passed to the
# call that builds BOTH loaders. Whether get_trial_dataloaders also applies it
# to the test split is not visible here - confirm evaluation is not augmented.
train_loader, test_loader = get_trial_dataloaders(
    root_dir=ROOT_DIR,
    transform=transform,
    batch_size=BATCH_SIZE,
    num_workers=NUM_WORKERS,
    pin_memory=PIN_MEMORY
)

# --- Model, Loss, Optimizer, Scheduler ---
model = LiquidSpikeFormer(
    in_channels=spike_encoder.num_bins,
    embed_dim=128,
    nhead=4,
    num_classes=NUM_CLASSES,
    encoder_bins=spike_encoder.num_bins,
    height=HEIGHT,
    width=WIDTH,
    poisson=POISSON,
    learnable_bins=LEARNABLE_BINS,
    smooth_kernel_size=SMOOTH_KERNEL_SIZE,
    dropout=0.1
).to(DEVICE)

criterion = HybridSpikingLoss(
    lambda_s=1.0,
    lambda_m=0.5,
    lambda_t=0.5,
    lambda_a=0.1,
    target_sparsity=0.1,
    threshold=0.5
)

optimizer = get_optimizer(
    model, optimizer_name='AdamW',
    lr=LEARNING_RATE, weight_decay=WEIGHT_DECAY
)

# The schedule length is expressed in optimizer steps (batches), not epochs,
# so it is sized as batches-per-epoch * epochs.
scheduler = get_scheduler(
    optimizer,
    scheduler_name='WarmupCosine',
    total_steps=len(train_loader) * NUM_EPOCHS,
    warmup_steps=WARMUP_STEPS
)

# --- Training & Evaluation Loop ---
# One pass over train_loader followed by one pass over test_loader per epoch.
# A checkpoint is written whenever test accuracy strictly improves.
best_acc = 0.0
for epoch in range(1, NUM_EPOCHS + 1):
    model.train()
    running_loss = 0.0
    seen = 0
    for batch in train_loader:
        events = batch['events'].to(DEVICE)
        labels = batch['label'].to(DEVICE)

        optimizer.zero_grad()
        out = model(events)
        loss = criterion(
            out['logits'], labels,
            spikes=out['spikes'],
            membrane=out['membrane'],
            threshold_param=out['threshold']
        )
        loss.backward()
        optimizer.step()
        # Stepped per batch: the scheduler was sized in optimizer steps.
        scheduler.step()

        batch_size = events.size(0)
        running_loss += loss.item() * batch_size
        seen += batch_size

    # Normalise by the samples actually processed rather than
    # len(train_loader.dataset): the two differ whenever the loader does not
    # yield every sample (e.g. drop_last or a custom sampler).
    epoch_loss = running_loss / max(seen, 1)

    # Evaluation
    model.eval()
    correct, total, test_loss = 0, 0, 0.0
    with torch.no_grad():
        for batch in test_loader:
            events = batch['events'].to(DEVICE)
            labels = batch['label'].to(DEVICE)
            out = model(events)
            # Called without the spike/membrane/threshold arguments used in
            # training, so train and test loss are not directly comparable.
            # NOTE(review): presumably only the classification term remains - confirm.
            loss = criterion(out['logits'], labels)
            test_loss += loss.item() * events.size(0)
            preds = out['logits'].argmax(dim=1)
            correct += (preds == labels).sum().item()
            total += labels.size(0)

    # Guard the division so an empty test loader reports 0 instead of crashing.
    test_loss /= max(total, 1)
    test_acc = 100.0 * correct / max(total, 1)

    print(f"Epoch {epoch}/{NUM_EPOCHS} "
          f"Train Loss: {epoch_loss:.4f} "
          f"Test Loss: {test_loss:.4f} "
          f"Test Acc: {test_acc:.2f}%")

    if test_acc > best_acc:
        best_acc = test_acc
        ckpt_path = os.path.join(CHECKPOINT_DIR, f"best_model_epoch{epoch}.pth")
        checkpoint = {
            'epoch': epoch,
            'model_state': model.state_dict(),
            'optimizer_state': optimizer.state_dict(),
            'best_acc': best_acc
        }
        # Also persist the LR schedule position so a resumed run does not
        # restart warmup. Guarded because the scheduler comes from a project
        # factory whose return type is not visible here.
        if hasattr(scheduler, 'state_dict'):
            checkpoint['scheduler_state'] = scheduler.state_dict()
        torch.save(checkpoint, ckpt_path)
        print(f"✅ Saved new best checkpoint to {ckpt_path}")

print(f"\n🎉 Training complete. Best Accuracy: {best_acc:.2f}%")



📂 Loading dataset from: /mnt/m2ssd/research project/SNN/dataset/DVS  Gesture dataset/DvsGesture
✅ Found 1164 training samples, 276 test samples




Epoch 1/90 Train Loss: 2.8912 Test Loss: 2.3688 Test Acc: 16.67%
✅ Saved new best checkpoint to ./checkpoints/best_model_epoch1.pth
Epoch 2/90 Train Loss: 3.2765 Test Loss: 2.3748 Test Acc: 19.20%
✅ Saved new best checkpoint to ./checkpoints/best_model_epoch2.pth
Epoch 3/90 Train Loss: 3.3172 Test Loss: 2.3727 Test Acc: 18.12%
Epoch 4/90 Train Loss: 3.3438 Test Loss: 2.4021 Test Acc: 17.39%
Epoch 5/90 Train Loss: 3.3488 Test Loss: 2.3903 Test Acc: 16.67%
Epoch 6/90 Train Loss: 3.3526 Test Loss: 2.4382 Test Acc: 16.67%
Epoch 7/90 Train Loss: 3.3448 Test Loss: 2.4178 Test Acc: 17.03%
Epoch 8/90 Train Loss: 3.3367 Test Loss: 2.3971 Test Acc: 17.03%
Epoch 9/90 Train Loss: 3.3409 Test Loss: 2.4176 Test Acc: 17.39%
Epoch 10/90 Train Loss: 3.3272 Test Loss: 2.4299 Test Acc: 17.03%
Epoch 11/90 Train Loss: 3.3311 Test Loss: 2.4328 Test Acc: 16.67%
Epoch 12/90 Train Loss: 3.3352 Test Loss: 2.4462 Test Acc: 16.67%
Epoch 13/90 Train Loss: 3.3240 Test Loss: 2.4647 Test Acc: 16.67%
Epoch 14/90 Train

KeyboardInterrupt: 

In [3]:
import os
import pandas as pd
import glob
from sklearn.model_selection import train_test_split

def generate_train_test_split(root_dir, test_size=0.2, random_state=42):
    """Randomly split the recordings in ``root_dir`` into train and test lists.

    A recording is identified by its ``<name>_labels.csv`` file. The names
    (with the ``_labels.csv`` suffix removed) are shuffled and split with
    ``train_test_split`` and written, one per line, to ``train_gestures.csv``
    and ``test_gestures.csv`` inside ``root_dir``. Existing files with those
    names are overwritten.

    Args:
        root_dir: Directory holding the ``*_labels.csv`` files. Treated as a
            literal path: glob metacharacters in it are escaped.
        test_size: Forwarded to ``train_test_split``.
        random_state: Seed forwarded to ``train_test_split``; the default
            keeps the previously hard-coded value.

    Raises:
        ValueError: If ``root_dir`` contains no ``*_labels.csv`` file.

    NOTE(review): the split is drawn per recording file, independently of the
    ``trials_to_train.txt`` / ``trials_to_test.txt`` lists that sit in the same
    directory - confirm which split the data loaders are meant to use.
    """
    suffix = "_labels.csv"
    # Escape the directory so characters such as '[' in the path are not
    # interpreted as glob syntax (which would silently match nothing).
    pattern = os.path.join(glob.escape(root_dir), "*" + suffix)
    label_files = sorted(glob.glob(pattern))
    if not label_files:
        raise ValueError(f"No '*{suffix}' files found in {root_dir!r}")

    # Cut the suffix off the end only; str.replace would also remove any
    # earlier occurrence of the same text inside the file name.
    base_names = [os.path.basename(path)[:-len(suffix)] for path in label_files]

    train_files, test_files = train_test_split(
        base_names, test_size=test_size, random_state=random_state
    )

    pd.DataFrame(train_files).to_csv(os.path.join(root_dir, "train_gestures.csv"), index=False, header=False)
    pd.DataFrame(test_files).to_csv(os.path.join(root_dir, "test_gestures.csv"), index=False, header=False)

    print(f"✅ Split generated: {len(train_files)} train / {len(test_files)} test")

# Usage: write train_gestures.csv / test_gestures.csv into the dataset directory.
generate_train_test_split(
    root_dir="/mnt/m2ssd/research project/SNN/dataset/DVS  Gesture dataset/DvsGesture"
)


✅ Split generated: 97 train / 25 test


In [1]:
import os
import pandas as pd

# Dataset directory to scan. The DVS_GESTURE_ROOT environment variable
# overrides the default so the check can run on machines with a different layout.
ROOT_DIR = os.environ.get(
    "DVS_GESTURE_ROOT",
    "/mnt/m2ssd/research project/SNN/dataset/DVS  Gesture dataset/DvsGesture",
)

def check_dataset_integrity(root_dir):
    """Print a report on the recordings and label files found in ``root_dir``.

    The report lists how many ``.aedat`` and ``*_labels.csv`` entries exist,
    whether ``gesture_mapping.csv``, ``trials_to_train.txt`` and
    ``trials_to_test.txt`` are present, and then, for every ``.aedat`` file
    in sorted order, whether its ``<name>_labels.csv`` exists and can be read
    by pandas with the columns ``class``, ``startTime_usec`` and
    ``endTime_usec`` (compared after stripping whitespace from the header).

    Args:
        root_dir: Directory to scan (not recursive).

    Returns:
        None. All findings are written to stdout.
    """
    required_columns = {"class", "startTime_usec", "endTime_usec"}

    print(f"\n📁 Scanning directory: {root_dir}\n")

    entries = os.listdir(root_dir)

    def mark(filename):
        # Tick if the entry exists directly inside root_dir, cross otherwise.
        return '✅' if filename in entries else '❌'

    recordings = sorted(name for name in entries if name.endswith(".aedat"))
    label_csv_count = sum(1 for name in entries if name.endswith("_labels.csv"))

    print(f"📦 Found {len(recordings)} AEDAT files")
    print(f"🗂️  Found {label_csv_count} label CSVs")
    print(f"📋 Found mapping file: {mark('gesture_mapping.csv')}")
    print(f"📋 Found train/test splits: {mark('trials_to_train.txt')} / {mark('trials_to_test.txt')}")

    print("\n🔍 Verifying AEDAT + label pairs...\n")
    unlabeled = []   # (aedat name, expected label csv name) pairs
    malformed = []   # label csv names that are unreadable or lack required columns

    for recording in recordings:
        csv_name = recording.replace(".aedat", "_labels.csv")
        print(f"🧪 {recording}", end=" → ")

        if csv_name not in entries:
            print("❌ Missing label file:", csv_name)
            unlabeled.append((recording, csv_name))
            continue

        print("✅ Found label file:", csv_name)
        # Check CSV content; any failure while reading counts as a bad file.
        try:
            frame = pd.read_csv(os.path.join(root_dir, csv_name))
            absent = required_columns - set(frame.columns.str.strip().tolist())
            if absent:
                print(f"   ⚠️  Columns mismatch: {frame.columns.tolist()}")
                malformed.append(csv_name)
            else:
                print(f"   📑 Columns OK: {frame.columns.tolist()}")
        except Exception as err:
            print(f"   ❌ Failed to read CSV: {err}")
            malformed.append(csv_name)

    print("\n✅ Scan complete.")
    if unlabeled:
        print(f"🚫 Missing label CSVs for {len(unlabeled)} AEDAT files.")
    if malformed:
        print(f"⚠️ Found {len(malformed)} CSVs with incorrect format.")
    if not (unlabeled or malformed):
        print("🎉 All files and formats look good!")

# Scan the configured dataset directory and print the integrity report.
check_dataset_integrity(root_dir=ROOT_DIR)



📁 Scanning directory: /mnt/m2ssd/research project/SNN/dataset/DVS  Gesture dataset/DvsGesture

📦 Found 122 AEDAT files
🗂️  Found 122 label CSVs
📋 Found mapping file: ✅
📋 Found train/test splits: ✅ / ✅

🔍 Verifying AEDAT + label pairs...

🧪 user01_fluorescent.aedat → ✅ Found label file: user01_fluorescent_labels.csv
   📑 Columns OK: ['class', 'startTime_usec', 'endTime_usec']
🧪 user01_fluorescent_led.aedat → ✅ Found label file: user01_fluorescent_led_labels.csv
   📑 Columns OK: ['class', 'startTime_usec', 'endTime_usec']
🧪 user01_lab.aedat → ✅ Found label file: user01_lab_labels.csv
   📑 Columns OK: ['class', 'startTime_usec', 'endTime_usec']
🧪 user01_led.aedat → ✅ Found label file: user01_led_labels.csv
   📑 Columns OK: ['class', 'startTime_usec', 'endTime_usec']
🧪 user01_natural.aedat → ✅ Found label file: user01_natural_labels.csv
   📑 Columns OK: ['class', 'startTime_usec', 'endTime_usec']
🧪 user02_fluorescent.aedat → ✅ Found label file: user02_fluorescent_labels.csv
   📑 Columns O

In [2]:
import os
# Dump the raw (unsorted) directory listing of the dataset folder.
dataset_dir = "/mnt/m2ssd/research project/SNN/dataset/DVS  Gesture dataset/DvsGesture"
print(os.listdir(dataset_dir))


['errata.txt', 'gesture_mapping.csv', 'LICENSE.txt', 'README.txt', 'trials_to_test.txt', 'trials_to_train.txt', 'user01_fluorescent.aedat', 'user01_fluorescent_labels.csv', 'user01_fluorescent_led.aedat', 'user01_fluorescent_led_labels.csv', 'user01_lab.aedat', 'user01_lab_labels.csv', 'user01_led.aedat', 'user01_led_labels.csv', 'user01_natural.aedat', 'user01_natural_labels.csv', 'user02_fluorescent.aedat', 'user02_fluorescent_labels.csv', 'user02_fluorescent_led.aedat', 'user02_fluorescent_led_labels.csv', 'user02_lab.aedat', 'user11_natural_labels.csv', 'user12_fluorescent_led.aedat', 'user12_fluorescent_led_labels.csv', 'user12_led.aedat', 'user12_led_labels.csv', 'user13_fluorescent.aedat', 'user13_fluorescent_labels.csv', 'user13_fluorescent_led.aedat', 'user13_fluorescent_led_labels.csv', 'user13_lab.aedat', 'user13_lab_labels.csv', 'user13_led.aedat', 'user13_led_labels.csv', 'user13_natural.aedat', 'user14_fluorescent.aedat', 'user14_fluorescent_labels.csv', 'user14_fluoresce

In [3]:
import pandas as pd
# Show the header of the gesture-name / label mapping table.
mapping_csv = "/mnt/m2ssd/research project/SNN/dataset/DVS  Gesture dataset/DvsGesture/gesture_mapping.csv"
df = pd.read_csv(mapping_csv)
column_names = df.columns.tolist()
print(column_names)


['action', 'label']


In [2]:
import pandas as pd

# Show the header of one per-recording label file.
labels_csv = "/mnt/m2ssd/research project/SNN/dataset/DVS  Gesture dataset/DvsGesture/user01_lab_labels.csv"
df = pd.read_csv(labels_csv)
column_names = df.columns.tolist()
print(column_names)


['class', 'startTime_usec', 'endTime_usec']


In [2]:
# Preview the first five entries of the official training-trial list.
with open("/mnt/m2ssd/research project/SNN/dataset/DVS  Gesture dataset/DvsGesture/trials_to_train.txt") as f:
    # zip() stops after five lines, or earlier if the file is shorter, so no
    # empty strings are printed past end-of-file.
    for _, line in zip(range(5), f):
        # Each line keeps its trailing newline; drop it so print() does not
        # double-space the output.
        print(line.rstrip("\n"))


user01_fluorescent.aedat

user01_fluorescent_led.aedat

user01_lab.aedat

user01_led.aedat

user01_natural.aedat

