In [20]:
import torch # type: ignore
import numpy as np # type: ignore
import torch.nn as nn # type: ignore
import matplotlib.pyplot as plt # type: ignore
from torch.utils.data import DataLoader # type: ignore
import torchvision.datasets as datasets # type: ignore
import torchvision.transforms as transforms # type: ignore

from snn_model import SCNN_CIFAR10_TTFS, TTFS_Encoder, TemporalWeightingDecoder

In [28]:
transform = transforms.Compose([transforms.ToTensor()])
train_dataset = datasets.CIFAR10(root='./data', train=True, transform=transform, download=False)
test_dataset  = datasets.CIFAR10(root='./data', train=False, transform=transform, download=False)
train_loader = DataLoader(train_dataset, batch_size=512, shuffle=True, num_workers=2)
test_loader  = DataLoader(test_dataset, batch_size=512, shuffle=True, num_workers=2)

device = 'cuda' if torch.cuda.is_available() else 'cpu'
net = SCNN_CIFAR10_TTFS(T=8).to('cpu')
state = torch.load("ttfs_based_scnn_model_weights.pth", map_location='cpu')
net.load_state_dict(state, strict=False)

  return torch._C._cuda_getDeviceCount() > 0


In [36]:
img, label = test_dataset[6]   # returns (tensor[C,32,32], label)
# Example indices: [14, 59, 90, 120, 123, 124, 134, 164, 165, 189]

In [38]:
# img, label = test_dataset[19]   # returns (tensor[C,32,32], label)
# Example indices: [14, 59, 90, 120, 123, 124, 134, 164, 165, 189]

# 0 → airplane
# 1 → automobile
# 2 → bird
# 3 → cat
# 4 → deer
# 5 → dog
# 6 → frog
# 7 → horse
# 8 → ship
# 9 → truck
print(f"Label: {label}")

# img_npp = img.permute(1, 2, 0).cpu().numpy()
# plt.imshow(img_npp, interpolation='nearest')
# plt.title("CIFAR-10 Image")
# plt.axis('off')
# plt.show()

img_unsq = img.unsqueeze(0)
spike_seq = TTFS_Encoder(T=8)(img_unsq)
B, T, C, H, W = spike_seq.shape
layer_outputs = {
        'conv1': [], 'neuron1': [],
        'conv2': [], 'neuron2': [],
        'fc1': [], 'neuron3': [],
        'fc2': [], 'neuron4': []
    }
# ---- Reset all spiking neurons ----
for m in net.modules():
    if hasattr(m, "reset"):
        m.reset()

# ---- Temporal forward pass (TTFS) ----
with torch.no_grad():
    for t in range(T):
        # shape of cur: [B, C, H, W]
        cur = spike_seq[:, t].to(device)

        # ----- Layer 1 -----
        x = net.conv1(cur)
        layer_outputs['conv1'].append(x.clone())

        x = net.pool1(x)
        x = net.neuron1(x)
        layer_outputs['neuron1'].append(x.clone())

        # ----- Layer 2 -----
        x = net.conv2(x)
        layer_outputs['conv2'].append(x.clone())

        x = net.pool2(x)
        x = net.neuron2(x)
        layer_outputs['neuron2'].append(x.clone())

        # ----- FC1 -----
        x = x.flatten(1)
        x = net.fc1(x)
        layer_outputs['fc1'].append(x.clone())

        x = net.neuron3(x)
        layer_outputs['neuron3'].append(x.clone())

        # ----- FC2 -----
        x = net.fc2(x)
        layer_outputs['fc2'].append(x.clone())

        x = net.neuron4(x)
        layer_outputs['neuron4'].append(x.clone())

for k in layer_outputs:
    layer_outputs[k] = torch.stack(layer_outputs[k], dim=1)

final_spikes = layer_outputs['neuron4']
final_spikes

Label: 1


tensor([[[0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
         [0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
         [0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
         [0., 0., 0., 0., 0., 1., 0., 0., 0., 0.],
         [0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
         [0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
         [0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
         [0., 0., 0., 0., 0., 0., 0., 0., 0., 0.]]])

In [32]:
decoder = TemporalWeightingDecoder(T=net.T, gamma=0.1, mode='linear').to(device)

In [18]:
correct = 0
total = 0

for images, labels in train_loader:
    images, labels = images.to('cpu'), labels.to('cpu')
    out_spikes = net(images)          # [B, T, C]
    logits = decoder(out_spikes)      # [B, C]
    preds = logits.argmax(dim=1)
    correct += (preds == labels).sum().item()
    total += labels.size(0)

print("Batch-wise inference accuracy:", 100 * correct / total)

Batch-wise inference accuracy: 93.844


In [33]:
T = 8
confidence_margin = 0.2

net.eval()
decoder.eval()
torch.set_grad_enabled(False)

stats = {
    "correct": 0,
    "wrong": 0,
    "early_correct": 0,
    "late_correct": 0,
    "early_times": []
}

for idx in range(1000):
    img, label = train_dataset[idx]
    img = img.unsqueeze(0).to(device)

    # ---- Reset neurons ----
    for m in net.modules():
        if hasattr(m, "reset"):
            m.reset()

    # ---- Forward (MATCHES TRAINING) ----
    out_spikes = net(img)        # [1, T, C]
    final_spikes = out_spikes[0] # [T, C]

    logits = decoder(out_spikes)      # [B, C]
    preds = logits.argmax(dim=1)
    pred = preds[0]

    # # ---- Progressive temporal decode ----
    # acc = torch.zeros_like(final_spikes)

    # decided = False
    # for t in range(T):
    #     acc[t] = final_spikes[t]
    #     logits = decoder(acc.unsqueeze(0))[0]

    #     top2 = torch.topk(logits, 2)
    #     gap = top2.values[0] - top2.values[1]

    #     if gap >= confidence_margin:
    #         pred = top2.indices[0].item()
    #         decided = True
    #         break

    # if not decided:
    #     logits = decoder(final_spikes.unsqueeze(0))[0]
    #     pred = logits.argmax().item()
    #     t = T - 1

    # ---- Stats ----
    if pred == label:
        stats["correct"] += 1
        # if t == 0:
        #     stats["early_correct"] += 1
        # else:
        #     stats["late_correct"] += 1
        #     stats["early_times"].append(t)
    else:
        stats["wrong"] += 1
        print("idx = ", idx, "pred = ", pred, "label = ", label)


idx =  2 pred =  tensor(9) label =  8
idx =  6 pred =  tensor(5) label =  1
idx =  7 pred =  tensor(0) label =  6
idx =  9 pred =  tensor(3) label =  1
idx =  10 pred =  tensor(3) label =  0
idx =  11 pred =  tensor(0) label =  9
idx =  14 pred =  tensor(2) label =  9
idx =  15 pred =  tensor(5) label =  8
idx =  17 pred =  tensor(3) label =  7
idx =  20 pred =  tensor(3) label =  7
idx =  22 pred =  tensor(3) label =  4
idx =  24 pred =  tensor(4) label =  5
idx =  25 pred =  tensor(6) label =  2
idx =  26 pred =  tensor(5) label =  4
idx =  27 pred =  tensor(2) label =  0
idx =  28 pred =  tensor(1) label =  9
idx =  31 pred =  tensor(2) label =  5
idx =  32 pred =  tensor(7) label =  4
idx =  33 pred =  tensor(3) label =  5
idx =  35 pred =  tensor(3) label =  2
idx =  36 pred =  tensor(2) label =  4
idx =  37 pred =  tensor(8) label =  1
idx =  39 pred =  tensor(3) label =  5
idx =  40 pred =  tensor(0) label =  4
idx =  41 pred =  tensor(4) label =  6
idx =  42 pred =  tensor(3) l

KeyboardInterrupt: 

In [11]:
total = 1000

print("\n===== INFERENCE SUMMARY =====")
print(f"Total samples     : {total}")
print(f"Correct           : {stats['correct']} ({100*stats['correct']/total:.2f}%)")
print(f"Wrong             : {stats['wrong']}")

print(f"\nEarly correct     : {stats['early_correct']}")
print(f"Late correct      : {stats['late_correct']}")

if stats["early_times"]:
    avg_t = sum(stats["early_times"]) / len(stats["early_times"])
    print(f"Avg early time    : {avg_t:.2f}")



===== INFERENCE SUMMARY =====
Total samples     : 1000
Correct           : 411 (41.10%)
Wrong             : 589

Early correct     : 0
Late correct      : 0


In [8]:
print(stats['wrong'])

418
