In [5]:
# Cell A  ───────────────────────────────────────────────────────────────
from pathlib import Path
import torch, numpy as np
from torch.utils.data import Dataset, DataLoader

# ---------- load the spikes ----------
dat  = torch.load(Path.cwd() / "data" / "spikes_run4.pt")
xs   = dat["spikes"]                       # list of [T,C]
ys   = torch.tensor(dat["labels"])         # dummy 0s

class DS(Dataset):
    def __len__(self):  return len(xs)
    def __getitem__(self,i):  return xs[i], ys[i]

def collate(batch):
    x, y = zip(*batch)                     # tuple of [T,C]
    x = torch.stack(x, dim=0)              # -> [B,T,C]
    x = x.permute(1,0,2).contiguous()      # -> [T,B,C]
    return x.float(), torch.tensor(y)

dl  = DataLoader(DS(), batch_size=2, shuffle=False, collate_fn=collate)
xb, yb = next(iter(dl))                    # grab one batch

print("xb shape (T,B,C) =", xb.shape)      # should be (320,2,64)
print("first time-step  =", xb[0].shape)   # should be (2,64)


xb shape (T,B,C) = torch.Size([320, 2, 64])
first time-step  = torch.Size([2, 64])


  dat  = torch.load(Path.cwd() / "data" / "spikes_run4.pt")


In [6]:
# Cell B  ───────────────────────────────────────────────────────────────
import torch.nn as nn, snntorch as snn
from snntorch import surrogate

C, HIDDEN, NCLASS = xb.shape[2], 128, 2    # expect C=64

class DebugNet(nn.Module):
    def __init__(self):
        super().__init__()
        self.fc1  = nn.Linear(C, HIDDEN)          # 64 → 128
        self.lif1 = snn.Leaky(beta=0.9, spike_grad=surrogate.fast_sigmoid())
        self.fc2  = nn.Linear(HIDDEN, NCLASS)     # 128 → 2
        self.lif2 = snn.Leaky(beta=0.9, spike_grad=surrogate.fast_sigmoid())
    def forward(self, x):                         # x [T,B,C]
        mem1 = self.lif1.init_leaky(); mem2 = self.lif2.init_leaky()
        for t, step in enumerate(x):              # step [B,C]
            print(f"t{t}  step   :", step.shape)   # (B,64)
            cur1 = self.fc1(step)
            print(f"t{t}  fc1 out:", cur1.shape)   # (B,128)
            mem1, spk1 = self.lif1(mem1, cur1)
            print(f"t{t}  spk1   :", spk1.shape)   # (B,128)
            cur2 = self.fc2(spk1)
            print(f"t{t}  fc2 out:", cur2.shape)   # (B,2)
            break                                 # only first timestep
        return cur2

net = DebugNet()
net(xb)        # just to trigger the prints


t0  step   : torch.Size([2, 64])
t0  fc1 out: torch.Size([2, 128])
t0  spk1   : torch.Size([0])


RuntimeError: mat1 and mat2 shapes cannot be multiplied (1x0 and 128x2)