In [6]:
from pathlib import Path
import torch, numpy as np, matplotlib.pyplot as plt

DATA     = Path.cwd() / "data"            # same folder as notebook
DAT_FILE = DATA / "spikes_run4.pt"        # made in 02_preprocess.ipynb
assert DAT_FILE.exists(), f"{DAT_FILE} missing – run preprocessing first"

d            = torch.load(DAT_FILE)
spike_tensors, labels = d["spikes"], d["labels"]

print("First tensor:", spike_tensors[0].shape)      # (320, 64)
print("Unique labels:", np.unique(labels))

First tensor: torch.Size([320, 64])
Unique labels: [0]


  d            = torch.load(DAT_FILE)


In [7]:
from torch.utils.data import Dataset, DataLoader

class SpikeDS(Dataset):
    def __init__(self, xs, ys):
        self.x = xs
        self.y = torch.tensor(ys, dtype=torch.long)
    def __len__(self):       return len(self.x)
    def __getitem__(self, i): return self.x[i], self.y[i]

def collate(batch):
    xs, ys = zip(*batch)           # xs tuple of [T,C]
    xs = torch.stack(xs, 0)        # [B, T, C]
    xs = xs.permute(1, 0, 2)       # [T, B, C]  (time-major)
    ys = torch.tensor(ys)
    return xs.float(), ys

dl = DataLoader(SpikeDS(spike_tensors, labels),
                batch_size=4, shuffle=True, collate_fn=collate)

xb, yb = next(iter(dl))
print("Batch xb:", xb.shape, " yb:", yb.shape)       # expect [320,4,64]  [4]
assert xb.shape[2] != 0, "Channel dimension collapsed!"

Batch xb: torch.Size([320, 4, 64])  yb: torch.Size([4])


In [8]:
import torch.nn as nn, snntorch as snn
from snntorch import surrogate

C       = xb.shape[2]        # 64
HIDDEN  = 128
N_CLASS = len(np.unique(labels)) or 2

class FC_SNN(nn.Module):
    def __init__(self):
        super().__init__()
        self.fc1  = nn.Linear(C, HIDDEN)
        self.lif1 = snn.Leaky(beta=0.9, spike_grad=surrogate.fast_sigmoid())
        self.fc2  = nn.Linear(HIDDEN, N_CLASS)
        self.lif2 = snn.Leaky(beta=0.9, spike_grad=surrogate.fast_sigmoid())

    def forward(self, x):                 # x [T, B, C]
        B = x.size(1)
        mem1 = torch.zeros(B, HIDDEN, device=x.device)
        mem2 = torch.zeros(B, N_CLASS, device=x.device)
        out  = 0
        for step in x:                    # step [B, C]
            mem1, spk1 = self.lif1(mem1, self.fc1(step))
            mem2, spk2 = self.lif2(mem2, self.fc2(spk1))
            out += spk2
        return out / x.size(0)            # logits [B, N_CLASS]

net = FC_SNN()

# ── one-step probe ─────────────────────────────────────────────────────
with torch.no_grad():
    probe = net(xb)                       # should run with no error
print("Probe OK — logits shape:", probe.shape)   # e.g. (4, 2)

Probe OK — logits shape: torch.Size([4, 1])


In [9]:
optimizer = torch.optim.Adam(net.parameters(), lr=1e-3)
criterion  = nn.CrossEntropyLoss()
EPOCHS = 5
for ep in range(EPOCHS):
    total = correct = loss_sum = 0
    for xb, yb in dl:
        optimizer.zero_grad()
        logits = net(xb)
        loss   = criterion(logits, yb)
        loss.backward(); optimizer.step()
        loss_sum += loss.item()*yb.size(0)
        correct  += (logits.argmax(1) == yb).sum().item()
        total    += yb.size(0)
    print(f"epoch {ep+1}: loss {loss_sum/total:.4f}  acc {correct/total:.2%}")

epoch 1: loss 0.0000  acc 100.00%
epoch 2: loss 0.0000  acc 100.00%
epoch 3: loss 0.0000  acc 100.00%
epoch 4: loss 0.0000  acc 100.00%
epoch 5: loss 0.0000  acc 100.00%


In [11]:
# save inside current project dir
EXP_DIR = Path.cwd() / "experiments"      # not parent/
EXP_DIR.mkdir(parents=True, exist_ok=True)

ckpt_file = EXP_DIR / "snn_run01.pt"
torch.save(net.state_dict(), ckpt_file)
print("Model saved to", ckpt_file)

Model saved to /Users/grantmckenzie/experiments/snn_run01.pt
