In [None]:
# Install the spikingjelly fork whose LIFNode / IzhikevichNode accept `integrator` and `dt`.
# %pip (unlike !pip) installs into the environment of the running kernel.
# NOTE(review): unpinned install; pin a commit (.../spikingjelly@<sha>) for reproducible results.
%pip install --upgrade -q git+https://github.com/AlexIK3404/spikingjelly

  Preparing metadata (setup.py) ... [?25l[?25hdone
  Building wheel for spikingjelly (setup.py) ... [?25l[?25hdone


Путь к библиотеке - /usr/local/lib/python3.12/dist-packages/spikingjelly

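Для изменения метода интегрирования в форке внесены модификации в метод `neuronal_charge` классов `LIFNode` и `IzhikevichNode`: схема выбирается параметром `integrator` (`'euler'`/`'exp'` для `LIFNode`, `'euler'`/`'symplectic'` для `IzhikevichNode`), шаг интегрирования задаётся параметром `dt`.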

# 1. Импорт библиотек

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.utils.data import DataLoader

import numpy as np
import pandas as pd

from torchvision import datasets, transforms

from spikingjelly.activation_based import surrogate, functional, encoding
from spikingjelly.activation_based import neuron as neuron

import random
import time
import csv
import math
from statistics import mean, stdev
from dataclasses import dataclass

# 2. Установка сида

In [None]:
def set_seed(seed: int):
    """Seed every random generator used by the experiments and pin cuDNN to deterministic mode.

    Seeds, in order: Python's `random`, NumPy's global RNG, PyTorch's CPU generator and
    the PyTorch generators of all CUDA devices.
    """
    seeders = (
        random.seed,                 # stdlib RNG: random.random(), random.shuffle(), ...
        np.random.seed,              # NumPy global RNG: np.random.*
        torch.manual_seed,           # PyTorch CPU RNG: torch.rand, CPU weight init, ...
        torch.cuda.manual_seed_all,  # PyTorch RNG on every CUDA device
    )
    for seeder in seeders:
        seeder(seed)

    # deterministic=True asks cuDNN for deterministic algorithms where they exist;
    # benchmark=False turns off the autotuner, which may pick a non-deterministic kernel.
    # Better reproducibility, possibly at the cost of speed.
    cudnn = torch.backends.cudnn
    cudnn.deterministic = True
    cudnn.benchmark = False


# 3. Параметры эксперимента

In [None]:
# Use the GPU when torch can see one, otherwise fall back to the CPU.
DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'
print('DEVICE =', DEVICE)

# Simulation settings.
# T_bio = 5   # nominal simulated time (disabled); the step count would then be T = T_bio / dt
T, dt = 50, 0.1        # number of time steps, integration step

# Optimisation settings.
BATCH_SIZE, LR, EPOCHS = 128, 1e-3, 20
SEEDS = [0]            # seeds that results are averaged over


DEVICE = cuda


# 4. Датасет MNIST + Poisson-кодирование

In [None]:
# Images are converted to tensors with ToTensor only; no normalisation is applied, so the
# Poisson encoder receives the raw ToTensor output (TODO confirm the range the encoder expects).
transform = transforms.Compose([transforms.ToTensor()])

# Shared MNIST arguments; the two datasets differ only in the split.
mnist_kwargs = dict(root='./data', download=True, transform=transform)
train_set = datasets.MNIST(train=True, **mnist_kwargs)
test_set = datasets.MNIST(train=False, **mnist_kwargs)

# Training shuffles and drops the ragged last batch; evaluation keeps every test sample in a fixed order.
train_loader = DataLoader(train_set, batch_size=BATCH_SIZE, shuffle=True, drop_last=True)
test_loader = DataLoader(test_set, batch_size=BATCH_SIZE, shuffle=False, drop_last=False)

# Poisson spike encoder from the spikingjelly fork, shared by all models below.
# NOTE(review): dt presumably scales per-step spike probability; confirm against the fork.
encoder = encoding.PoissonEncoderWithDt(dt=dt, method='exact', seed=42)


100%|██████████| 9.91M/9.91M [00:00<00:00, 60.5MB/s]
100%|██████████| 28.9k/28.9k [00:00<00:00, 2.01MB/s]
100%|██████████| 1.65M/1.65M [00:00<00:00, 15.1MB/s]
100%|██████████| 4.54k/4.54k [00:00<00:00, 3.08MB/s]


# 5. SNN с нейроном Ижикевича

In [None]:
class IzhSNN(nn.Module):
    """Spiking network with Izhikevich hidden neurons and a max-over-time readout.

    Each time step:
      1. Poisson-encode the flattened image with the global `encoder`.
      2. Apply Linear(784 -> 256).
      3. Pass the result through `neuron.IzhikevichNode`.
      4. Apply Linear(256 -> 10).

    The result is the element-wise maximum of the per-step outputs.
    """

    def __init__(self, integrator='euler', dt=1.0, T=50):
        """
        Args:
            integrator: integration scheme forwarded to IzhikevichNode
                (the experiment grid uses 'euler' and 'symplectic').
            dt: integration step forwarded to the neuron.
            T: number of simulation steps per forward pass (must be >= 1).
        """
        super().__init__()

        # forward() previously ignored this argument and read the module-level T instead.
        self.T = int(T)
        if self.T < 1:
            raise ValueError(f"T must be >= 1, got {T}")

        self.fc1 = nn.Linear(28 * 28, 256, bias=False)

        # integrator and dt select the numerical scheme inside the neuron
        self.neuron1 = neuron.IzhikevichNode(
            tau=2.0,
            v_c=0.8,
            a0=1.0,
            v_threshold=1.0,
            v_reset=0.0,
            v_rest=-0.1,
            w_rest=0.0,
            tau_w=2.0,
            a=0.02,
            b=0.2,
            surrogate_function=surrogate.Sigmoid(),
            integrator=integrator,
            dt=dt
        )

        self.fc_out = nn.Linear(256, 10, bias=False)

    def forward(self, x):
        """Simulate self.T steps on a batch and return the max-over-time output.

        Args:
            x: image batch, flattened to (B, -1); expected [B, 1, 28, 28].
        Returns:
            Element-wise maximum over the T steps of the fc_out outputs.
        """
        out_max = None

        # the encoder works on the flattened image
        x_flat = x.view(x.size(0), -1)

        for _ in range(self.T):
            x_t = encoder(x_flat)            # [B, 28*28] spike frame
            h = self.fc1(x_t)
            s = self.neuron1(h)              # spikes
            out = self.fc_out(s)
            out_max = out if out_max is None else torch.maximum(out_max, out)

        # Bug fix: previously returned the last step's `out`, discarding the accumulated maximum.
        return out_max

In [None]:
class LIFSNN(nn.Module):
    """Spiking network with LIF hidden neurons and a max-over-time readout.

    Each time step:
      1. Poisson-encode the flattened image with the global `encoder`.
      2. Apply Linear(784 -> 256).
      3. Pass the result through `neuron.LIFNode`.
      4. Apply Linear(256 -> 10).

    The result is the element-wise maximum of the per-step outputs.
    """

    def __init__(self, integrator='euler', dt=1.0, T=50):
        """
        Args:
            integrator: integration scheme forwarded to LIFNode ('euler' or 'exp').
            dt: integration step forwarded to the neuron.
            T: number of simulation steps per forward pass (must be >= 1).
        """
        super().__init__()

        # forward() previously ignored this argument and read the module-level T instead.
        self.T = int(T)
        if self.T < 1:
            raise ValueError(f"T must be >= 1, got {T}")

        self.fc1 = nn.Linear(28 * 28, 256, bias=False)

        # LIF node; integrator selects the scheme ('euler' or 'exp')
        self.neuron1 = neuron.LIFNode(
            tau=2.0,
            decay_input=True,
            v_threshold=1.0,
            v_reset=0.0,
            surrogate_function=surrogate.Sigmoid(),
            integrator=integrator,
            dt=dt
        )

        self.fc_out = nn.Linear(256, 10, bias=False)

    def forward(self, x):
        """Simulate self.T steps on a batch and return the max-over-time output.

        Args:
            x: image batch, flattened to (B, -1); presumably [B, 1, 28, 28] MNIST.
        Returns:
            Element-wise maximum over the T steps of the fc_out outputs.
        """
        out_max = None
        x_flat = x.view(x.size(0), -1)

        for _ in range(self.T):
            x_t = encoder(x_flat)
            h = self.fc1(x_t)
            s = self.neuron1(h)
            out = self.fc_out(s)
            out_max = out if out_max is None else torch.maximum(out_max, out)

        # Bug fix: previously returned the last step's `out`, discarding the accumulated maximum.
        return out_max

# 6. Обучение

In [None]:
def train_one_epoch(model, loader, optimizer, criterion):
    """Train `model` for one pass over `loader`.

    Network state is reset with `functional.reset_net` before every batch.

    Returns:
        (mean batch loss, training accuracy in percent).
    """
    model.train()
    loss_sum, n_correct, n_seen = 0.0, 0, 0

    for inputs, targets in loader:
        inputs, targets = inputs.to(DEVICE), targets.to(DEVICE)

        optimizer.zero_grad()
        functional.reset_net(model)  # clear neuron state carried over from the previous batch

        logits = model(inputs)
        loss = criterion(logits, targets)
        loss.backward()
        optimizer.step()

        loss_sum += loss.item()
        n_correct += (logits.argmax(dim=1) == targets).sum().item()
        n_seen += targets.size(0)

    mean_loss = loss_sum / len(loader)
    accuracy = 100. * n_correct / n_seen
    return mean_loss, accuracy


# 7. Тестирование

In [None]:
@torch.no_grad()
def evaluate(model, loader):
    """Return the top-1 accuracy (percent) of `model` on `loader`, without gradient tracking.

    Network state is reset with `functional.reset_net` before every batch.
    """
    model.eval()
    hits = 0
    seen = 0

    for batch_x, batch_y in loader:
        batch_x = batch_x.to(DEVICE)
        batch_y = batch_y.to(DEVICE)

        functional.reset_net(model)
        predicted = model(batch_x).argmax(dim=1)

        hits += (predicted == batch_y).sum().item()
        seen += batch_y.size(0)

    return 100. * hits / seen


# 8. Запуск одного эксперимента (один seed)

In [None]:
def run_experiment(model_cls, integrator, EPOCHS, seed: int, dt: float, T: int):
    """Train one (model, integrator, seed) configuration and return its last-epoch test accuracy.

    Args:
        model_cls: model class accepting (integrator=..., dt=..., T=...).
        integrator: integrator name forwarded to the model.
        EPOCHS: number of training epochs (must be >= 1).
        seed: seed for set_seed and for reseeding the shared encoder.
        dt: integration step forwarded to the model. The encoder's dt is set by the caller.
        T: number of simulation steps per forward pass.

    Returns:
        Test accuracy (percent) after the final epoch. This is not the best epoch.

    Raises:
        ValueError: if EPOCHS < 1. Previously this case ended in an UnboundLocalError.
    """
    if EPOCHS < 1:
        raise ValueError(f"EPOCHS must be >= 1, got {EPOCHS}")

    print(f'\n=== Model {model_cls.__name__} integrator={integrator} seed={seed} dt={dt} ===')

    set_seed(seed)

    model = model_cls(integrator=integrator, dt=dt, T=T).to(DEVICE)
    optimizer = optim.Adam(model.parameters(), lr=LR)
    criterion = nn.CrossEntropyLoss()
    device = torch.device(DEVICE)

    t0 = time.time()
    for epoch in range(EPOCHS):
        # Train and test use different encoder seeds per epoch
        # (presumably giving distinct but repeatable spike trains).
        encoder.reset(seed=seed + epoch, device=device)
        train_loss, train_acc = train_one_epoch(
            model, train_loader, optimizer, criterion
        )
        encoder.reset(seed=seed + 10_000 + epoch, device=device)
        test_acc = evaluate(model, test_loader)

        elapsed = time.time() - t0
        print(
            f'Epoch {epoch:02d}: loss={train_loss:.4f}, train_acc={train_acc:.2f}%, test_acc={test_acc:.2f}%, elapsed={elapsed/60:.2f}min'
        )

    return test_acc

# 9. Запуск серии экспериментов

In [None]:
def run_grid_experiments(SEEDS, dt, T, EPOCHS=20, csv_file='results.csv'):
    """Run every (model, integrator) configuration over all seeds and save a summary CSV.

    Args:
        SEEDS: non-empty list of seeds; accuracies are averaged over them.
        dt: integration step passed to every model.
        T: number of simulation steps passed to every model.
        EPOCHS: training epochs per run.
        csv_file: output path. Pass a distinct name per (dt, T) setting, otherwise
            successive calls overwrite each other's results.

    Returns:
        One dict per configuration. Keys: model, integrator, seeds, accs, mean_acc,
        std_acc, dt, T.

    Raises:
        ValueError: if SEEDS is empty.
    """
    if not SEEDS:
        raise ValueError("SEEDS must contain at least one seed")

    results_table = []
    # configurations: (model class, integrators supported by its neuron)
    grid = [
        (IzhSNN, ['euler', 'symplectic']),
        (LIFSNN, ['euler', 'exp']),
    ]

    for model_cls, integrators in grid:
        for integrator in integrators:
            accs = [run_experiment(model_cls, integrator, EPOCHS, seed, dt, T) for seed in SEEDS]
            mean_acc = mean(accs)
            std_acc = stdev(accs) if len(accs) > 1 else 0.0
            print(f'--> {model_cls.__name__} integrator={integrator}: mean={mean_acc:.2f}% std={std_acc:.2f}% over seeds={SEEDS}')
            results_table.append({
                'model': model_cls.__name__,
                'integrator': integrator,
                'seeds': SEEDS,
                'accs': accs,
                'mean_acc': mean_acc,
                'std_acc': std_acc,
                # record the simulation settings so rows from different runs stay distinguishable
                'dt': dt,
                'T': T,
            })

    # New columns are appended at the end so position-based readers of the old layout keep working.
    columns = ['model', 'integrator', 'seeds', 'accs', 'mean_acc', 'std_acc', 'dt', 'T']
    with open(csv_file, mode='w', newline='', encoding='utf-8') as f:
        writer = csv.writer(f)
        writer.writerow(columns)
        writer.writerows([r[c] for c in columns] for r in results_table)

    print(f'\nAll experiments finished. Results saved to {csv_file}')
    return results_table

# 10. Эксперименты с фиксированным T_bio

In [None]:
if __name__ == "__main__":
    dt = 0.5
    # Fixed simulated time T_bio (the commented-out T_bio in the config cell):
    # the step count shrinks as dt grows, T = T_bio / dt, with at least one step.
    T_bio = 5
    n_steps = max(1, int(T_bio / dt))
    encoder.reset(dt=dt)
    # run_grid_experiments requires SEEDS, dt and T; the bare call raised TypeError.
    results = run_grid_experiments(SEEDS, dt, n_steps)
    # final summary
    for r in results:
        print(r)


=== Model IzhSNN integrator=euler seed=0 dt=0.5 ===
Epoch 00: loss=1.0698, train_acc=71.59%, test_acc=81.20%, elapsed=0.24min
Epoch 01: loss=0.5992, train_acc=82.49%, test_acc=84.74%, elapsed=0.47min
Epoch 02: loss=0.5316, train_acc=84.24%, test_acc=85.21%, elapsed=0.71min
Epoch 03: loss=0.4926, train_acc=85.37%, test_acc=86.50%, elapsed=0.95min
Epoch 04: loss=0.4565, train_acc=86.45%, test_acc=87.59%, elapsed=1.18min
Epoch 05: loss=0.4331, train_acc=87.23%, test_acc=87.66%, elapsed=1.41min
Epoch 06: loss=0.4149, train_acc=87.64%, test_acc=88.28%, elapsed=1.64min
Epoch 07: loss=0.4011, train_acc=88.17%, test_acc=88.85%, elapsed=1.88min
Epoch 08: loss=0.3840, train_acc=88.57%, test_acc=88.97%, elapsed=2.12min
Epoch 09: loss=0.3763, train_acc=88.80%, test_acc=89.61%, elapsed=2.35min
Epoch 10: loss=0.3649, train_acc=89.27%, test_acc=89.46%, elapsed=2.58min
Epoch 11: loss=0.3518, train_acc=89.60%, test_acc=90.06%, elapsed=2.82min
Epoch 12: loss=0.3442, train_acc=89.90%, test_acc=89.96%, e

In [None]:
if __name__ == "__main__":
    dt = 1.0
    # Fixed simulated time T_bio (the commented-out T_bio in the config cell):
    # the step count shrinks as dt grows, T = T_bio / dt, with at least one step.
    T_bio = 5
    n_steps = max(1, int(T_bio / dt))
    encoder.reset(dt=dt)
    # run_grid_experiments requires SEEDS, dt and T; the bare call raised TypeError.
    results = run_grid_experiments(SEEDS, dt, n_steps)
    # final summary
    for r in results:
        print(r)


=== Model IzhSNN integrator=euler seed=0 dt=1.0 ===
Epoch 00: loss=0.8681, train_acc=77.85%, test_acc=85.57%, elapsed=0.19min
Epoch 01: loss=0.4520, train_acc=86.99%, test_acc=88.30%, elapsed=0.37min
Epoch 02: loss=0.3831, train_acc=88.77%, test_acc=90.28%, elapsed=0.56min
Epoch 03: loss=0.3364, train_acc=89.94%, test_acc=90.64%, elapsed=0.74min
Epoch 04: loss=0.3051, train_acc=91.01%, test_acc=91.73%, elapsed=0.92min
Epoch 05: loss=0.2746, train_acc=91.83%, test_acc=92.11%, elapsed=1.11min
Epoch 06: loss=0.2583, train_acc=92.17%, test_acc=92.87%, elapsed=1.29min
Epoch 07: loss=0.2476, train_acc=92.51%, test_acc=92.94%, elapsed=1.48min
Epoch 08: loss=0.2333, train_acc=93.03%, test_acc=93.21%, elapsed=1.67min
Epoch 09: loss=0.2256, train_acc=93.22%, test_acc=93.43%, elapsed=1.85min
Epoch 10: loss=0.2174, train_acc=93.36%, test_acc=93.24%, elapsed=2.04min
Epoch 11: loss=0.2086, train_acc=93.61%, test_acc=93.83%, elapsed=2.22min
Epoch 12: loss=0.2020, train_acc=93.78%, test_acc=93.62%, e

In [None]:
if __name__ == "__main__":
    dt = 2.0
    # Fixed simulated time T_bio (the commented-out T_bio in the config cell):
    # the step count shrinks as dt grows, T = T_bio / dt, with at least one step.
    T_bio = 5
    n_steps = max(1, int(T_bio / dt))
    encoder.reset(dt=dt)
    # run_grid_experiments requires SEEDS, dt and T; the bare call raised TypeError.
    results = run_grid_experiments(SEEDS, dt, n_steps)
    # final summary
    for r in results:
        print(r)


=== Model IzhSNN integrator=euler seed=0 dt=2.0 ===
Epoch 00: loss=0.4597, train_acc=87.48%, test_acc=92.26%, elapsed=0.16min
Epoch 01: loss=0.2232, train_acc=93.35%, test_acc=93.96%, elapsed=0.31min
Epoch 02: loss=0.1732, train_acc=94.77%, test_acc=94.89%, elapsed=0.47min
Epoch 03: loss=0.1430, train_acc=95.71%, test_acc=95.75%, elapsed=0.64min
Epoch 04: loss=0.1284, train_acc=96.02%, test_acc=96.05%, elapsed=0.80min
Epoch 05: loss=0.1162, train_acc=96.36%, test_acc=95.99%, elapsed=0.95min
Epoch 06: loss=0.1056, train_acc=96.66%, test_acc=96.36%, elapsed=1.11min
Epoch 07: loss=0.0991, train_acc=96.84%, test_acc=96.21%, elapsed=1.27min
Epoch 08: loss=0.0929, train_acc=97.08%, test_acc=96.67%, elapsed=1.43min
Epoch 09: loss=0.0903, train_acc=97.11%, test_acc=96.62%, elapsed=1.60min
Epoch 10: loss=0.0833, train_acc=97.29%, test_acc=96.49%, elapsed=1.75min
Epoch 11: loss=0.0810, train_acc=97.39%, test_acc=96.69%, elapsed=1.91min
Epoch 12: loss=0.0783, train_acc=97.49%, test_acc=96.71%, e

In [None]:
if __name__ == "__main__":
    dt = 5.0
    # Fixed simulated time T_bio (the commented-out T_bio in the config cell):
    # the step count shrinks as dt grows, T = T_bio / dt, with at least one step.
    T_bio = 5
    n_steps = max(1, int(T_bio / dt))
    encoder.reset(dt=dt)
    # run_grid_experiments requires SEEDS, dt and T; the bare call raised TypeError.
    results = run_grid_experiments(SEEDS, dt, n_steps)
    # final summary
    for r in results:
        print(r)


=== Model IzhSNN integrator=euler seed=0 dt=5.0 ===
Epoch 00: loss=0.3494, train_acc=90.63%, test_acc=94.47%, elapsed=0.16min
Epoch 01: loss=0.1640, train_acc=95.26%, test_acc=95.58%, elapsed=0.31min
Epoch 02: loss=0.1254, train_acc=96.35%, test_acc=96.13%, elapsed=0.48min
Epoch 03: loss=0.1010, train_acc=97.05%, test_acc=96.50%, elapsed=0.62min
Epoch 04: loss=0.0866, train_acc=97.43%, test_acc=96.72%, elapsed=0.77min
Epoch 05: loss=0.0769, train_acc=97.74%, test_acc=97.01%, elapsed=0.93min
Epoch 06: loss=0.0661, train_acc=98.07%, test_acc=96.65%, elapsed=1.08min
Epoch 07: loss=0.0632, train_acc=98.14%, test_acc=97.08%, elapsed=1.23min
Epoch 08: loss=0.0561, train_acc=98.33%, test_acc=97.13%, elapsed=1.37min
Epoch 09: loss=0.0516, train_acc=98.44%, test_acc=97.22%, elapsed=1.53min
Epoch 10: loss=0.0486, train_acc=98.48%, test_acc=97.20%, elapsed=1.68min
Epoch 11: loss=0.0457, train_acc=98.62%, test_acc=97.03%, elapsed=1.82min
Epoch 12: loss=0.0420, train_acc=98.72%, test_acc=97.25%, e

# 11. Эксперименты с фиксированным T

In [None]:
if __name__ == "__main__":
    # Fixed step count T (from the config cell); only the integration step varies.
    dt = 0.05
    encoder.reset(dt=dt)
    results = run_grid_experiments(SEEDS=SEEDS, dt=dt, T=T, EPOCHS=10)
    # Final summary: one dict per (model, integrator) pair.
    for row in results:
        print(row)


=== Model IzhSNN integrator=euler seed=0 dt=0.05 ===
Epoch 00: loss=2.2431, train_acc=16.29%, test_acc=29.22%, elapsed=0.99min
Epoch 01: loss=1.8953, train_acc=33.75%, test_acc=37.43%, elapsed=1.79min
Epoch 02: loss=1.6694, train_acc=40.35%, test_acc=43.37%, elapsed=2.55min
Epoch 03: loss=1.5743, train_acc=43.94%, test_acc=45.74%, elapsed=3.33min
Epoch 04: loss=1.5166, train_acc=46.72%, test_acc=48.00%, elapsed=4.08min
Epoch 05: loss=1.4874, train_acc=48.06%, test_acc=48.33%, elapsed=4.84min
Epoch 06: loss=1.4686, train_acc=49.04%, test_acc=50.39%, elapsed=5.60min
Epoch 07: loss=1.4464, train_acc=49.77%, test_acc=50.66%, elapsed=6.36min
Epoch 08: loss=1.4350, train_acc=50.72%, test_acc=51.29%, elapsed=7.12min
Epoch 09: loss=1.4178, train_acc=51.69%, test_acc=52.39%, elapsed=7.88min
--> IzhSNN integrator=euler: mean=52.39% std=0.00% over seeds=[0]

=== Model IzhSNN integrator=symplectic seed=0 dt=0.05 ===
Epoch 00: loss=2.2437, train_acc=16.21%, test_acc=29.21%, elapsed=0.77min
Epoch 0

In [None]:
if __name__ == "__main__":
    # Fixed step count T (from the config cell); only the integration step varies.
    dt = 0.1
    encoder.reset(dt=dt)
    results = run_grid_experiments(SEEDS=SEEDS, dt=dt, T=T, EPOCHS=10)
    # Final summary: one dict per (model, integrator) pair.
    for row in results:
        print(row)


=== Model IzhSNN integrator=euler seed=0 dt=0.1 ===
Epoch 00: loss=1.9815, train_acc=35.13%, test_acc=49.66%, elapsed=0.74min
Epoch 01: loss=1.4290, train_acc=53.52%, test_acc=57.05%, elapsed=1.50min
Epoch 02: loss=1.2824, train_acc=57.93%, test_acc=60.48%, elapsed=2.25min
Epoch 03: loss=1.2124, train_acc=60.11%, test_acc=61.31%, elapsed=2.99min
Epoch 04: loss=1.1756, train_acc=61.69%, test_acc=63.07%, elapsed=3.74min
Epoch 05: loss=1.1467, train_acc=63.19%, test_acc=63.92%, elapsed=4.49min
Epoch 06: loss=1.1258, train_acc=63.94%, test_acc=64.95%, elapsed=5.24min
Epoch 07: loss=1.1052, train_acc=64.66%, test_acc=65.60%, elapsed=5.99min
Epoch 08: loss=1.0883, train_acc=65.24%, test_acc=66.42%, elapsed=6.74min
Epoch 09: loss=1.0659, train_acc=66.05%, test_acc=66.65%, elapsed=7.50min
--> IzhSNN integrator=euler: mean=66.65% std=0.00% over seeds=[0]

=== Model IzhSNN integrator=symplectic seed=0 dt=0.1 ===
Epoch 00: loss=1.9804, train_acc=35.03%, test_acc=49.47%, elapsed=0.75min
Epoch 01:

In [None]:
if __name__ == "__main__":
    # Fixed step count T (from the config cell); only the integration step varies.
    dt = 0.5
    encoder.reset(dt=dt)
    results = run_grid_experiments(SEEDS=SEEDS, dt=dt, T=T, EPOCHS=10)
    # Final summary: one dict per (model, integrator) pair.
    for row in results:
        print(row)


=== Model IzhSNN integrator=euler seed=0 dt=0.5 ===
Epoch 00: loss=1.0604, train_acc=71.06%, test_acc=81.31%, elapsed=0.76min
Epoch 01: loss=0.5971, train_acc=82.21%, test_acc=84.12%, elapsed=1.51min
Epoch 02: loss=0.5275, train_acc=84.36%, test_acc=85.73%, elapsed=2.26min
Epoch 03: loss=0.4867, train_acc=85.44%, test_acc=86.15%, elapsed=3.00min
Epoch 04: loss=0.4609, train_acc=86.43%, test_acc=87.52%, elapsed=3.75min
Epoch 05: loss=0.4348, train_acc=87.24%, test_acc=88.07%, elapsed=4.49min
Epoch 06: loss=0.4141, train_acc=87.89%, test_acc=88.32%, elapsed=5.23min
Epoch 07: loss=0.3984, train_acc=88.17%, test_acc=88.84%, elapsed=5.98min
Epoch 08: loss=0.3849, train_acc=88.71%, test_acc=88.98%, elapsed=6.72min
Epoch 09: loss=0.3734, train_acc=89.05%, test_acc=89.43%, elapsed=7.47min
--> IzhSNN integrator=euler: mean=89.43% std=0.00% over seeds=[0]

=== Model IzhSNN integrator=symplectic seed=0 dt=0.5 ===
Epoch 00: loss=1.0473, train_acc=71.40%, test_acc=81.75%, elapsed=0.74min
Epoch 01:

In [None]:
if __name__ == "__main__":
    # Fixed step count T (from the config cell); only the integration step varies.
    dt = 1.0
    encoder.reset(dt=dt)
    results = run_grid_experiments(SEEDS=SEEDS, dt=dt, T=T, EPOCHS=10)
    # Final summary: one dict per (model, integrator) pair.
    for row in results:
        print(row)


=== Model IzhSNN integrator=euler seed=0 dt=1.0 ===
Epoch 00: loss=0.7486, train_acc=79.79%, test_acc=87.07%, elapsed=0.74min
Epoch 01: loss=0.4156, train_acc=87.68%, test_acc=89.08%, elapsed=1.48min
Epoch 02: loss=0.3551, train_acc=89.43%, test_acc=90.48%, elapsed=2.22min
Epoch 03: loss=0.3144, train_acc=90.78%, test_acc=91.46%, elapsed=2.95min
Epoch 04: loss=0.2890, train_acc=91.42%, test_acc=92.36%, elapsed=3.69min
Epoch 05: loss=0.2715, train_acc=91.82%, test_acc=92.66%, elapsed=4.43min
Epoch 06: loss=0.2528, train_acc=92.40%, test_acc=92.61%, elapsed=5.16min
Epoch 07: loss=0.2445, train_acc=92.59%, test_acc=92.85%, elapsed=5.90min
Epoch 08: loss=0.2355, train_acc=92.90%, test_acc=92.83%, elapsed=6.64min
Epoch 09: loss=0.2256, train_acc=93.09%, test_acc=93.11%, elapsed=7.37min
--> IzhSNN integrator=euler: mean=93.11% std=0.00% over seeds=[0]

=== Model IzhSNN integrator=symplectic seed=0 dt=1.0 ===
Epoch 00: loss=0.7239, train_acc=80.62%, test_acc=88.01%, elapsed=0.74min
Epoch 01:

In [None]:
if __name__ == "__main__":
    # Fixed step count T (from the config cell); only the integration step varies.
    dt = 2.0
    encoder.reset(dt=dt)
    results = run_grid_experiments(SEEDS=SEEDS, dt=dt, T=T, EPOCHS=10)
    # Final summary: one dict per (model, integrator) pair.
    for row in results:
        print(row)


=== Model IzhSNN integrator=euler seed=0 dt=2.0 ===
Epoch 00: loss=0.5554, train_acc=84.60%, test_acc=89.51%, elapsed=0.74min
Epoch 01: loss=0.3593, train_acc=88.99%, test_acc=89.46%, elapsed=1.47min
Epoch 02: loss=0.3366, train_acc=89.52%, test_acc=90.75%, elapsed=2.21min
Epoch 03: loss=0.3133, train_acc=90.14%, test_acc=90.31%, elapsed=2.94min
Epoch 04: loss=0.3137, train_acc=90.18%, test_acc=90.26%, elapsed=3.68min
Epoch 05: loss=0.3101, train_acc=90.35%, test_acc=90.82%, elapsed=4.42min
Epoch 06: loss=0.2958, train_acc=90.83%, test_acc=90.98%, elapsed=5.15min
Epoch 07: loss=0.2985, train_acc=90.72%, test_acc=90.75%, elapsed=5.88min
Epoch 08: loss=0.2866, train_acc=91.11%, test_acc=90.98%, elapsed=6.62min
Epoch 09: loss=0.2809, train_acc=91.16%, test_acc=91.35%, elapsed=7.36min
--> IzhSNN integrator=euler: mean=91.35% std=0.00% over seeds=[0]

=== Model IzhSNN integrator=symplectic seed=0 dt=2.0 ===
Epoch 00: loss=0.5244, train_acc=85.75%, test_acc=90.19%, elapsed=0.74min
Epoch 01:

In [None]:
if __name__ == "__main__":
    # Fixed step count T (from the config cell); only the integration step varies.
    dt = 5.0
    encoder.reset(dt=dt)
    results = run_grid_experiments(SEEDS=SEEDS, dt=dt, T=T, EPOCHS=10)
    # Final summary: one dict per (model, integrator) pair.
    for row in results:
        print(row)


=== Model IzhSNN integrator=euler seed=0 dt=5.0 ===
Epoch 00: loss=2.3026, train_acc=9.87%, test_acc=9.80%, elapsed=0.74min
Epoch 01: loss=2.3026, train_acc=9.88%, test_acc=9.80%, elapsed=1.48min
Epoch 02: loss=2.3026, train_acc=9.87%, test_acc=9.80%, elapsed=2.22min
Epoch 03: loss=2.3026, train_acc=9.88%, test_acc=9.80%, elapsed=2.96min
Epoch 04: loss=2.3026, train_acc=9.87%, test_acc=9.80%, elapsed=3.70min
Epoch 05: loss=2.3026, train_acc=9.86%, test_acc=9.80%, elapsed=4.44min
Epoch 06: loss=2.3026, train_acc=9.86%, test_acc=9.80%, elapsed=5.17min
Epoch 07: loss=2.3026, train_acc=9.86%, test_acc=9.80%, elapsed=5.92min
Epoch 08: loss=2.3026, train_acc=9.87%, test_acc=9.80%, elapsed=6.65min
Epoch 09: loss=2.3026, train_acc=9.88%, test_acc=9.80%, elapsed=7.39min
--> IzhSNN integrator=euler: mean=9.80% std=0.00% over seeds=[0]

=== Model IzhSNN integrator=symplectic seed=0 dt=5.0 ===
Epoch 00: loss=1.4010, train_acc=61.78%, test_acc=74.36%, elapsed=0.74min
Epoch 01: loss=0.8083, train_a

# 12. Исследование режима T=50, dt=5.0

In [None]:
@dataclass
class SpikeStats:
    """Running counters of spiking activity, accumulated over time steps and batches.

    A "site" is one (sample, neuron) position at one time step. Counters are plain Python
    numbers, so instances can be combined across batches with `merge`.
    """
    total_spikes: float = 0.0     # sum of spike values (equals active_sites for binary 0/1 spikes)
    total_sites: int = 0          # total B*N over all steps and batches
    silent_sites: int = 0         # sites with s <= 0
    active_sites: int = 0         # sites with s > 0

    def update(self, s: torch.Tensor) -> None:
        """Accumulate counters from one spike tensor.

        Args:
            s: spike tensor. Scalars and 1-D tensors count as a single sample;
               higher ranks are flattened to (B, N) with B = s.size(0).
        """
        with torch.no_grad():
            if s.dim() <= 1:
                s = s.reshape(1, -1)
            else:
                # reshape, unlike view, also accepts non-contiguous tensors
                s = s.reshape(s.size(0), -1)

            B, N = s.shape
            sites = int(B * N)
            active = int((s > 0).sum().item())

            self.total_sites += sites
            self.active_sites += active
            self.silent_sites += sites - active
            # Sum the actual spike values in float64. For binary spikes this equals `active`;
            # for graded or multi-spike outputs it no longer undercounts.
            self.total_spikes += float(s.to(torch.float64).sum().item())

    def merge(self, other: "SpikeStats") -> "SpikeStats":
        """Add `other`'s counters into this instance in place and return self."""
        self.total_spikes += other.total_spikes
        self.total_sites += other.total_sites
        self.silent_sites += other.silent_sites
        self.active_sites += other.active_sites
        return self

    def summary(self):
        """Return mean spike rate and silent/active fractions (all 0.0 when nothing was recorded)."""
        denom = max(1, self.total_sites)
        active_frac = self.active_sites / denom
        silent_frac = self.silent_sites / denom
        mean_spike_rate = self.total_spikes / denom  # == active_frac for 0/1 spikes

        return {
            "mean_spike_rate": float(mean_spike_rate),
            "silent_frac": float(silent_frac),
            "active_frac": float(active_frac)
        }


In [None]:
class IzhSNN(nn.Module):
    """Izhikevich-neuron SNN (784 -> 256 -> 10) with a max-over-time readout.

    forward() can also return spike statistics of the hidden layer.
    """

    def __init__(self, integrator='euler', dt=1.0, T=50):
        super().__init__()
        self.T = int(T)  # simulation steps per forward pass

        # Layers are created in the original order (fc1, neuron1, fc_out) so weight init is unchanged.
        self.fc1 = nn.Linear(28 * 28, 256, bias=False)

        izh_params = dict(
            tau=2.0, v_c=0.8, a0=1.0,
            v_threshold=1.0, v_reset=0.0, v_rest=-0.1,
            w_rest=0.0, tau_w=2.0, a=0.02, b=0.2,
        )
        self.neuron1 = neuron.IzhikevichNode(
            **izh_params,
            surrogate_function=surrogate.Sigmoid(),
            integrator=integrator,
            dt=float(dt),
        )

        self.fc_out = nn.Linear(256, 10, bias=False)

    def forward(self, x, collect_stats=False):
        """Simulate self.T steps and return the element-wise max of the per-step outputs.

        When `collect_stats` is true, returns (output, SpikeStats of the hidden spikes).
        The global `encoder` produces the input spikes at every step.
        """
        flat = x.view(x.size(0), -1)
        stats = SpikeStats() if collect_stats else None
        logits_max = None

        for _step in range(self.T):
            spikes = self.neuron1(self.fc1(encoder(flat)))
            if collect_stats:
                stats.update(spikes)
            logits = self.fc_out(spikes)
            logits_max = logits if logits_max is None else torch.maximum(logits_max, logits)

        if collect_stats:
            return logits_max, stats
        return logits_max

In [None]:
class LIFSNN(nn.Module):
    """LIF-neuron SNN (784 -> 256 -> 10) with a max-over-time readout.

    forward() can also return spike statistics of the hidden layer.
    """

    def __init__(self, integrator='euler', dt=1.0, T=50):
        super().__init__()
        self.T = int(T)  # simulation steps per forward pass

        # Layers are created in the original order (fc1, neuron1, fc_out) so weight init is unchanged.
        self.fc1 = nn.Linear(28 * 28, 256, bias=False)

        lif_params = dict(tau=2.0, decay_input=True, v_threshold=1.0, v_reset=0.0)
        self.neuron1 = neuron.LIFNode(
            **lif_params,
            surrogate_function=surrogate.Sigmoid(),
            integrator=integrator,   # 'euler' or 'exp'
            dt=float(dt),
        )

        self.fc_out = nn.Linear(256, 10, bias=False)

    def forward(self, x, collect_stats=False):
        """Simulate self.T steps and return the element-wise max of the per-step outputs.

        When `collect_stats` is true, returns (output, SpikeStats of the hidden spikes).
        The global `encoder` produces the input spikes at every step.
        """
        flat = x.view(x.size(0), -1)
        stats = SpikeStats() if collect_stats else None
        logits_max = None

        for _step in range(self.T):
            spikes = self.neuron1(self.fc1(encoder(flat)))
            if collect_stats:
                stats.update(spikes)
            logits = self.fc_out(spikes)
            logits_max = logits if logits_max is None else torch.maximum(logits_max, logits)

        if collect_stats:
            return logits_max, stats
        return logits_max

In [None]:
@torch.no_grad()
def evaluate_with_stats(model, loader, max_batches=None):
    """Evaluate accuracy and aggregate hidden-layer spike statistics.

    Args:
        model: SNN whose forward accepts collect_stats=True and returns (output, SpikeStats).
        loader: evaluation DataLoader.
        max_batches: evaluate only the first `max_batches` batches; None means all of them.

    Returns:
        (accuracy in percent, SpikeStats.summary() over every evaluated batch).
    """
    model.eval()
    agg = SpikeStats()
    correct = 0
    total = 0

    for batch_idx, (images, labels) in enumerate(loader):
        if max_batches is not None and batch_idx >= max_batches:
            break

        images, labels = images.to(DEVICE), labels.to(DEVICE)
        functional.reset_net(model)

        logits, batch_stats = model(images, collect_stats=True)

        correct += (logits.argmax(dim=1) == labels).sum().item()
        total += labels.size(0)

        # fold this batch's counters into the running total
        for field in ("total_spikes", "total_sites", "silent_sites", "active_sites"):
            setattr(agg, field, getattr(agg, field) + getattr(batch_stats, field))

    return 100.0 * correct / max(1, total), agg.summary()

In [None]:
def run_dt5_suite(model_cls, integrators, seeds, T, dt, EPOCHS,
                  eval_batches=10, final_eval_batches=50):
    """Train `model_cls` for every (integrator, seed) pair and collect accuracy and spike statistics.

    Despite the name, dt and T are ordinary parameters, so any step size can be studied.

    Args:
        model_cls: model class taking (integrator=..., T=..., dt=...); its forward must
            support collect_stats=True.
        integrators: integrator names passed to the model.
        seeds: seeds; every (integrator, seed) pair trains a fresh model.
        T: simulation steps per forward pass.
        dt: integration step, applied to the shared encoder and to the model.
        EPOCHS: training epochs per run.
        eval_batches: test batches evaluated after each epoch (None = the whole test set).
        final_eval_batches: test batches used for the final row (None = the whole test set).
            The defaults (10 / 50) reproduce the previously hard-coded values.

    Returns:
        pandas DataFrame with one row per (integrator, seed).
    """
    rows = []

    for integ in integrators:
        for seed in seeds:
            print(f"\n=== {model_cls.__name__} integ={integ} seed={seed} dt={dt} T={T} ===")

            set_seed(seed)
            encoder.reset(dt=dt)

            model = model_cls(integrator=integ, T=T, dt=dt).to(DEVICE)
            optimizer = optim.Adam(model.parameters(), lr=LR)
            criterion = nn.CrossEntropyLoss()

            for epoch in range(EPOCHS):
                train_loss, train_acc = train_one_epoch(model, train_loader, optimizer, criterion)
                test_acc, test_stats = evaluate_with_stats(model, test_loader, max_batches=eval_batches)

                print(
                    f"Epoch {epoch:02d}: loss={train_loss:.4f}, "
                    f"train_acc={train_acc:.2f}%, test_acc={test_acc:.2f}%, "
                    f"spike_rate={test_stats['mean_spike_rate']:.4e}, "
                    f"silent={test_stats['silent_frac']:.3f}, active={test_stats['active_frac']:.3f}"
                )

            # NOTE: with the default final_eval_batches=50, final_acc covers only the first
            # 50 test batches, not the whole test set.
            final_acc, final_stats = evaluate_with_stats(model, test_loader, max_batches=final_eval_batches)

            rows.append({
                "model": model_cls.__name__,
                "integrator": integ,
                "seed": seed,
                "dt": float(dt),
                "T": int(T),
                "final_acc": float(final_acc),
                "final_mean_spike_rate": final_stats["mean_spike_rate"],
                "final_silent_frac": final_stats["silent_frac"],
                "final_active_frac": final_stats["active_frac"],
            })

    return pd.DataFrame(rows)


In [None]:
# dt = 5.0 study for Izhikevich networks: five seeds per integrator.
dt, T, EPOCHS = 5.0, 50, 10
SEEDS = list(range(5))

df_izh = run_dt5_suite(
    IzhSNN,
    seeds=SEEDS,
    integrators=["euler", "symplectic"],
    dt=dt,
    T=T,
    EPOCHS=EPOCHS,
)

display(df_izh)
df_izh.to_csv("dt5_stats_izh.csv", index=False)
print("Saved to dt5_stats_izh.csv")



=== IzhSNN integ=euler seed=0 dt=5.0 T=50 ===
Epoch 00: loss=2.3026, train_acc=9.88%, test_acc=8.52%, spike_rate=0.0000e+00, silent=1.000, active=0.000
Epoch 01: loss=2.3026, train_acc=9.88%, test_acc=8.52%, spike_rate=0.0000e+00, silent=1.000, active=0.000
Epoch 02: loss=2.3026, train_acc=9.87%, test_acc=8.52%, spike_rate=0.0000e+00, silent=1.000, active=0.000
Epoch 03: loss=2.3026, train_acc=9.88%, test_acc=8.52%, spike_rate=0.0000e+00, silent=1.000, active=0.000
Epoch 04: loss=2.3026, train_acc=9.87%, test_acc=8.52%, spike_rate=0.0000e+00, silent=1.000, active=0.000
Epoch 05: loss=2.3026, train_acc=9.86%, test_acc=8.52%, spike_rate=0.0000e+00, silent=1.000, active=0.000
Epoch 06: loss=2.3026, train_acc=9.86%, test_acc=8.52%, spike_rate=0.0000e+00, silent=1.000, active=0.000
Epoch 07: loss=2.3026, train_acc=9.86%, test_acc=8.52%, spike_rate=0.0000e+00, silent=1.000, active=0.000
Epoch 08: loss=2.3026, train_acc=9.87%, test_acc=8.52%, spike_rate=0.0000e+00, silent=1.000, active=0.000

Unnamed: 0,model,integrator,seed,dt,T,final_acc,final_mean_spike_rate,final_silent_frac,final_active_frac
0,IzhSNN,euler,0,5.0,50,9.46875,0.0,1.0,0.0
1,IzhSNN,euler,1,5.0,50,9.46875,0.0,1.0,0.0
2,IzhSNN,euler,2,5.0,50,9.46875,0.0,1.0,0.0
3,IzhSNN,euler,3,5.0,50,9.46875,0.0,1.0,0.0
4,IzhSNN,euler,4,5.0,50,9.46875,0.0,1.0,0.0
5,IzhSNN,symplectic,0,5.0,50,77.75,0.522346,0.477654,0.522346
6,IzhSNN,symplectic,1,5.0,50,80.453125,0.523586,0.476414,0.523586
7,IzhSNN,symplectic,2,5.0,50,78.015625,0.523141,0.476859,0.523141
8,IzhSNN,symplectic,3,5.0,50,77.25,0.528478,0.471522,0.528478
9,IzhSNN,symplectic,4,5.0,50,74.890625,0.525523,0.474477,0.525523


Saved to dt5_stats_izh.csv


In [None]:
# dt = 5.0 study for LIF networks: five seeds per integrator.
dt, T, EPOCHS = 5.0, 50, 10
SEEDS = list(range(5))

df_lif = run_dt5_suite(
    LIFSNN,
    seeds=SEEDS,
    integrators=["euler", "exp"],
    dt=dt,
    T=T,
    EPOCHS=EPOCHS,
)

display(df_lif)
df_lif.to_csv("dt5_stats_lif.csv", index=False)
print("Saved to dt5_stats_lif.csv")



=== LIFSNN integ=euler seed=0 dt=5.0 T=50 ===
Epoch 00: loss=1.5321, train_acc=60.08%, test_acc=63.28%, spike_rate=5.7740e-01, silent=0.423, active=0.577
Epoch 01: loss=1.2133, train_acc=67.02%, test_acc=63.91%, spike_rate=6.4782e-01, silent=0.352, active=0.648
Epoch 02: loss=1.1908, train_acc=66.04%, test_acc=65.70%, spike_rate=6.6820e-01, silent=0.332, active=0.668
Epoch 03: loss=1.0970, train_acc=68.36%, test_acc=62.81%, spike_rate=6.7853e-01, silent=0.321, active=0.679
Epoch 04: loss=1.0977, train_acc=66.54%, test_acc=65.16%, spike_rate=6.8982e-01, silent=0.310, active=0.690
Epoch 05: loss=1.0866, train_acc=65.93%, test_acc=62.73%, spike_rate=7.0480e-01, silent=0.295, active=0.705
Epoch 06: loss=1.0670, train_acc=66.43%, test_acc=64.77%, spike_rate=7.1216e-01, silent=0.288, active=0.712
Epoch 07: loss=1.0842, train_acc=65.57%, test_acc=61.95%, spike_rate=7.1762e-01, silent=0.282, active=0.718
Epoch 08: loss=1.0487, train_acc=66.72%, test_acc=63.12%, spike_rate=7.2656e-01, silent=0

Unnamed: 0,model,integrator,seed,dt,T,final_acc,final_mean_spike_rate,final_silent_frac,final_active_frac
0,LIFSNN,euler,0,5.0,50,67.359375,0.739574,0.260426,0.739574
1,LIFSNN,euler,1,5.0,50,65.796875,0.708221,0.291779,0.708221
2,LIFSNN,euler,2,5.0,50,69.765625,0.68868,0.31132,0.68868
3,LIFSNN,euler,3,5.0,50,73.984375,0.691897,0.308103,0.691897
4,LIFSNN,euler,4,5.0,50,65.15625,0.673109,0.326891,0.673109
5,LIFSNN,exp,0,5.0,50,97.0,0.430748,0.569252,0.430748
6,LIFSNN,exp,1,5.0,50,97.09375,0.441342,0.558658,0.441342
7,LIFSNN,exp,2,5.0,50,97.234375,0.423535,0.576465,0.423535
8,LIFSNN,exp,3,5.0,50,97.046875,0.446466,0.553534,0.446466
9,LIFSNN,exp,4,5.0,50,97.015625,0.454024,0.545976,0.454024


Saved to dt5_stats_lif.csv


# 13. Исследование режима T=50, dt=5.0 с усреднением по времени

In [None]:
class IzhSNN(nn.Module):
    """Single-hidden-layer SNN with Izhikevich-type hidden neurons and a linear readout.

    Every forward pass re-encodes the flattened input with the module-level
    ``encoder`` for ``T`` steps, runs ``fc1 -> IzhikevichNode -> fc_out`` and
    aggregates the per-step logits over time according to ``readout``.

    NOTE(review): ``forward`` does not reset the neuron state itself; callers are
    expected to reset it between batches (e.g. ``functional.reset_net``) — confirm
    that the train/eval helpers do so.

    Args:
        integrator: integration scheme forwarded to ``neuron.IzhikevichNode``
            (this notebook uses ``'euler'`` and ``'symplectic'``).
        dt: integration step forwarded to the neuron.
        T: number of simulation steps per forward pass; must be >= 1.
        readout: temporal aggregation of the per-step logits:
            ``'mean'`` (average), ``'sum'`` or ``'max'`` (element-wise maximum).

    Raises:
        ValueError: if ``T < 1`` or ``readout`` is not a supported mode
            (checked at construction, so a bad config fails before any simulation).
    """

    _READOUTS = ("mean", "sum", "max")

    def __init__(self, integrator: str = 'euler', dt: float = 1.0, T: int = 50, readout: str = "mean"):
        super().__init__()
        self.T = int(T)
        self.readout = str(readout)
        if self.T < 1:
            # With T == 0 the accumulator would stay None and forward would crash.
            raise ValueError(f"T must be >= 1, got T={self.T}")
        if self.readout not in self._READOUTS:
            raise ValueError(f"Unknown readout={self.readout}. Use 'mean', 'sum', or 'max'.")

        self.fc1 = nn.Linear(28 * 28, 256, bias=False)
        self.neuron1 = neuron.IzhikevichNode(
            tau=2.0,
            v_c=0.8,
            a0=1.0,
            v_threshold=1.0,
            v_reset=0.0,
            v_rest=-0.1,
            w_rest=0.0,
            tau_w=2.0,
            a=0.02,
            b=0.2,
            surrogate_function=surrogate.Sigmoid(),
            integrator=integrator,
            dt=float(dt),
        )
        self.fc_out = nn.Linear(256, 10, bias=False)

    def forward(self, x: torch.Tensor, collect_stats: bool = False):
        """Simulate ``T`` steps and return the time-aggregated logits.

        Args:
            x: input batch; dimension 0 is the batch, the rest is flattened to 784.
            collect_stats: if True, also return a ``SpikeStats`` of hidden spikes.

        Returns:
            ``logits`` of shape [B, 10], or ``(logits, stats)`` when ``collect_stats``.
        """
        # Re-check in case ``readout`` was changed after construction; fail before simulating.
        if self.readout not in self._READOUTS:
            raise ValueError(f"Unknown readout={self.readout}. Use 'mean', 'sum', or 'max'.")

        x = x.reshape(x.size(0), -1)
        stats = SpikeStats() if collect_stats else None
        out_accum = None

        for _ in range(self.T):
            x_t = encoder(x)  # encoder is defined at notebook level
            s = self.neuron1(self.fc1(x_t))

            if stats is not None:
                stats.update(s)

            out = self.fc_out(s)
            if out_accum is None:
                out_accum = out
            elif self.readout == "max":
                out_accum = torch.maximum(out_accum, out)
            else:  # 'sum' and 'mean' both accumulate a running sum
                out_accum = out_accum + out

        if self.readout == "mean":
            out_accum = out_accum / float(self.T)

        return (out_accum, stats) if collect_stats else out_accum

In [None]:
class LIFSNN(nn.Module):
    """Single-hidden-layer SNN with LIF hidden neurons and a linear readout.

    Every forward pass re-encodes the flattened input with the module-level
    ``encoder`` for ``T`` steps, runs ``fc1 -> LIFNode -> fc_out`` and aggregates
    the per-step logits over time according to ``readout``.

    NOTE(review): ``forward`` does not reset the neuron state itself; callers are
    expected to reset it between batches (e.g. ``functional.reset_net``) — confirm
    that the train/eval helpers do so.

    Args:
        integrator: integration scheme forwarded to ``neuron.LIFNode``
            (``'euler'`` or ``'exp'``).
        dt: integration step forwarded to the neuron.
        T: number of simulation steps per forward pass; must be >= 1.
        readout: temporal aggregation of the per-step logits:
            ``'mean'`` (average), ``'sum'`` or ``'max'`` (element-wise maximum).

    Raises:
        ValueError: if ``T < 1`` or ``readout`` is not a supported mode
            (checked at construction, so a bad config fails before any simulation).
    """

    _READOUTS = ("mean", "sum", "max")

    def __init__(self, integrator: str = 'euler', dt: float = 1.0, T: int = 50, readout: str = "mean"):
        super().__init__()
        self.T = int(T)
        self.readout = str(readout)
        if self.T < 1:
            # With T == 0 the accumulator would stay None and forward would crash.
            raise ValueError(f"T must be >= 1, got T={self.T}")
        if self.readout not in self._READOUTS:
            raise ValueError(f"Unknown readout={self.readout}. Use 'mean', 'sum', or 'max'.")

        self.fc1 = nn.Linear(28 * 28, 256, bias=False)
        self.neuron1 = neuron.LIFNode(
            tau=2.0,
            decay_input=True,
            v_threshold=1.0,
            v_reset=0.0,
            surrogate_function=surrogate.Sigmoid(),
            integrator=integrator,   # 'euler' or 'exp'
            dt=float(dt),
        )
        self.fc_out = nn.Linear(256, 10, bias=False)

    def forward(self, x: torch.Tensor, collect_stats: bool = False):
        """Simulate ``T`` steps and return the time-aggregated logits.

        Args:
            x: input batch; dimension 0 is the batch, the rest is flattened to 784.
            collect_stats: if True, also return a ``SpikeStats`` of hidden spikes.

        Returns:
            ``logits`` of shape [B, 10], or ``(logits, stats)`` when ``collect_stats``.
        """
        # Re-check in case ``readout`` was changed after construction; fail before simulating.
        if self.readout not in self._READOUTS:
            raise ValueError(f"Unknown readout={self.readout}. Use 'mean', 'sum', or 'max'.")

        x = x.reshape(x.size(0), -1)
        stats = SpikeStats() if collect_stats else None
        out_accum = None

        for _ in range(self.T):
            x_t = encoder(x)  # encoder is defined at notebook level
            s = self.neuron1(self.fc1(x_t))

            if stats is not None:
                stats.update(s)

            out = self.fc_out(s)
            if out_accum is None:
                out_accum = out
            elif self.readout == "max":
                out_accum = torch.maximum(out_accum, out)
            else:  # 'sum' and 'mean' both accumulate a running sum
                out_accum = out_accum + out

        if self.readout == "mean":
            out_accum = out_accum / float(self.T)

        return (out_accum, stats) if collect_stats else out_accum

In [None]:
def run_suite(model_cls, integrators, seeds, T, dt, EPOCHS, readout="mean",
              eval_batches=10, final_eval_batches=50, lr=None):
    """Train and evaluate ``model_cls`` for every (integrator, seed) pair.

    For each run: seed all RNGs, reset the notebook-level ``encoder`` with ``dt``,
    build ``model_cls(integrator=..., T=T, dt=dt, readout=readout)`` on ``DEVICE``,
    train with Adam + cross-entropy for ``EPOCHS`` epochs (printing a short test
    evaluation after every epoch), then run a final, larger test evaluation.

    Args:
        model_cls: SNN class accepting ``integrator``, ``T``, ``dt`` and ``readout``.
        integrators: integrator names to sweep.
        seeds: random seeds to sweep.
        T: simulation steps per forward pass.
        dt: integration step.
        EPOCHS: training epochs per run.
        readout: temporal readout mode forwarded to the model.
        eval_batches: test batches for the per-epoch evaluation (default 10).
        final_eval_batches: test batches for the final evaluation (default 50).
        lr: Adam learning rate; ``None`` uses the notebook-level ``LR``.

    Returns:
        pandas.DataFrame with one row per run (model, integrator, seed, dt, T,
        readout, final accuracy and final spike statistics).
    """
    lr = LR if lr is None else lr
    rows = []
    for integ in integrators:
        for seed in seeds:
            print(f"\n=== {model_cls.__name__} integ={integ} seed={seed} dt={dt} T={T} readout={readout} ===")
            set_seed(seed)
            encoder.reset(dt=dt)

            model = model_cls(integrator=integ, T=T, dt=dt, readout=readout).to(DEVICE)
            optimizer = optim.Adam(model.parameters(), lr=lr)
            criterion = nn.CrossEntropyLoss()

            for epoch in range(EPOCHS):
                train_loss, train_acc = train_one_epoch(model, train_loader, optimizer, criterion)
                # Cheap per-epoch check on a subset of the test set.
                test_acc, test_stats = evaluate_with_stats(model, test_loader, max_batches=eval_batches)

                print(
                    f"Epoch {epoch:02d}: loss={train_loss:.4f}, train_acc={train_acc:.2f}%, "
                    f"test_acc={test_acc:.2f}%, "
                    f"spike_rate={test_stats['mean_spike_rate']:.4e}, "
                    f"silent={test_stats['silent_frac']:.3f}, active={test_stats['active_frac']:.3f}"
                )

            final_acc, final_stats = evaluate_with_stats(model, test_loader, max_batches=final_eval_batches)
            rows.append({
                "model": model_cls.__name__,
                "integrator": integ,
                "seed": seed,
                "dt": float(dt),
                "T": int(T),
                "readout": readout,
                "final_acc": float(final_acc),
                "final_mean_spike_rate": final_stats["mean_spike_rate"],
                "final_silent_frac": final_stats["silent_frac"],
                "final_active_frac": final_stats["active_frac"],
            })

    return pd.DataFrame(rows)


# -----------------------------
# Separate per-model runners (Izhikevich / LIF)
# -----------------------------
def run_izh_dt5(dt=5.0, T=50, EPOCHS=10, SEEDS=(0,1,2,3,4), readout="mean"):
    """Izhikevich SNN sweep over the "euler" and "symplectic" integrators.

    ``dt`` defaults to 5.0 (hence the name) but may be overridden. Thin wrapper
    around ``run_suite``; returns its per-run results DataFrame.
    """
    seed_list = [seed for seed in SEEDS]
    return run_suite(
        IzhSNN,
        integrators=["euler", "symplectic"],
        seeds=seed_list,
        T=T,
        dt=dt,
        EPOCHS=EPOCHS,
        readout=readout,
    )

def run_lif_dt5(dt=5.0, T=50, EPOCHS=10, SEEDS=(0,1,2,3,4), readout="mean"):
    """LIF SNN sweep over the "euler" and "exp" integrators.

    ``dt`` defaults to 5.0 (hence the name) but may be overridden. Thin wrapper
    around ``run_suite``; returns its per-run results DataFrame.
    """
    seed_list = [seed for seed in SEEDS]
    return run_suite(
        LIFSNN,
        integrators=["euler", "exp"],
        seeds=seed_list,
        T=T,
        dt=dt,
        EPOCHS=EPOCHS,
        readout=readout,
    )

In [None]:
# Izhikevich sweep at dt=5.0, T=50 with mean-over-time readout.
# NOTE: this cell (re)assigns the notebook-level globals dt, T, EPOCHS and SEEDS.
dt = 5.0
T = 50
EPOCHS = 10
SEEDS = [0,1,2,3,4]

# One row per (integrator, seed); the table is shown and saved without the index column.
df_izh = run_izh_dt5(dt=dt, T=T, EPOCHS=EPOCHS, SEEDS=SEEDS, readout="mean")
display(df_izh)
df_izh.to_csv("dt5_izh_mean.csv", index=False)


=== IzhSNN integ=euler seed=0 dt=5.0 T=50 readout=mean ===
Epoch 00: loss=2.3026, train_acc=9.87%, test_acc=8.52%, spike_rate=0.0000e+00, silent=1.000, active=0.000
Epoch 01: loss=2.3026, train_acc=9.88%, test_acc=8.52%, spike_rate=0.0000e+00, silent=1.000, active=0.000
Epoch 02: loss=2.3026, train_acc=9.87%, test_acc=8.52%, spike_rate=0.0000e+00, silent=1.000, active=0.000
Epoch 03: loss=2.3026, train_acc=9.88%, test_acc=8.52%, spike_rate=0.0000e+00, silent=1.000, active=0.000
Epoch 04: loss=2.3026, train_acc=9.87%, test_acc=8.52%, spike_rate=0.0000e+00, silent=1.000, active=0.000
Epoch 05: loss=2.3026, train_acc=9.86%, test_acc=8.52%, spike_rate=0.0000e+00, silent=1.000, active=0.000
Epoch 06: loss=2.3026, train_acc=9.86%, test_acc=8.52%, spike_rate=0.0000e+00, silent=1.000, active=0.000
Epoch 07: loss=2.3026, train_acc=9.86%, test_acc=8.52%, spike_rate=0.0000e+00, silent=1.000, active=0.000
Epoch 08: loss=2.3026, train_acc=9.87%, test_acc=8.52%, spike_rate=0.0000e+00, silent=1.000,

Unnamed: 0,model,integrator,seed,dt,T,readout,final_acc,final_mean_spike_rate,final_silent_frac,final_active_frac
0,IzhSNN,euler,0,5.0,50,mean,9.46875,0.0,1.0,0.0
1,IzhSNN,euler,1,5.0,50,mean,9.46875,0.0,1.0,0.0
2,IzhSNN,euler,2,5.0,50,mean,9.46875,0.0,1.0,0.0
3,IzhSNN,euler,3,5.0,50,mean,9.46875,0.0,1.0,0.0
4,IzhSNN,euler,4,5.0,50,mean,9.46875,0.0,1.0,0.0
5,IzhSNN,symplectic,0,5.0,50,mean,33.125,0.526737,0.473263,0.526737
6,IzhSNN,symplectic,1,5.0,50,mean,39.921875,0.529574,0.470426,0.529574
7,IzhSNN,symplectic,2,5.0,50,mean,43.78125,0.530993,0.469007,0.530993
8,IzhSNN,symplectic,3,5.0,50,mean,41.203125,0.528481,0.471519,0.528481
9,IzhSNN,symplectic,4,5.0,50,mean,45.59375,0.530385,0.469615,0.530385


In [None]:
# LIF sweep at dt=5.0, T=50 with mean-over-time readout.
# NOTE: this cell (re)assigns the notebook-level globals dt, T, EPOCHS and SEEDS.
dt = 5.0
T = 50
EPOCHS = 10
SEEDS = [0,1,2,3,4]

# One row per (integrator, seed); the table is shown and saved without the index column.
df_lif = run_lif_dt5(dt=dt, T=T, EPOCHS=EPOCHS, SEEDS=SEEDS, readout="mean")
display(df_lif)
df_lif.to_csv("dt5_lif_mean.csv", index=False)


=== LIFSNN integ=euler seed=0 dt=5.0 T=50 readout=mean ===
Epoch 00: loss=1.1087, train_acc=73.24%, test_acc=77.34%, spike_rate=5.0704e-01, silent=0.493, active=0.507
Epoch 01: loss=0.7256, train_acc=80.81%, test_acc=80.47%, spike_rate=5.2418e-01, silent=0.476, active=0.524
Epoch 02: loss=0.6804, train_acc=81.21%, test_acc=79.06%, spike_rate=5.5563e-01, silent=0.444, active=0.556
Epoch 03: loss=0.6620, train_acc=81.08%, test_acc=78.67%, spike_rate=5.6913e-01, silent=0.431, active=0.569
Epoch 04: loss=0.6427, train_acc=81.44%, test_acc=77.03%, spike_rate=5.8843e-01, silent=0.412, active=0.588
Epoch 05: loss=0.7195, train_acc=79.30%, test_acc=72.19%, spike_rate=5.9255e-01, silent=0.407, active=0.593
Epoch 06: loss=0.7693, train_acc=76.03%, test_acc=74.22%, spike_rate=5.9983e-01, silent=0.400, active=0.600
Epoch 07: loss=0.7466, train_acc=77.68%, test_acc=76.80%, spike_rate=6.1008e-01, silent=0.390, active=0.610
Epoch 08: loss=0.7468, train_acc=76.33%, test_acc=73.12%, spike_rate=6.0356e

Unnamed: 0,model,integrator,seed,dt,T,readout,final_acc,final_mean_spike_rate,final_silent_frac,final_active_frac
0,LIFSNN,euler,0,5.0,50,mean,75.15625,0.620407,0.379593,0.620407
1,LIFSNN,euler,1,5.0,50,mean,77.546875,0.613187,0.386813,0.613187
2,LIFSNN,euler,2,5.0,50,mean,73.25,0.621442,0.378558,0.621442
3,LIFSNN,euler,3,5.0,50,mean,72.546875,0.623688,0.376312,0.623688
4,LIFSNN,euler,4,5.0,50,mean,72.71875,0.635236,0.364764,0.635236
5,LIFSNN,exp,0,5.0,50,mean,96.953125,0.431061,0.568939,0.431061
6,LIFSNN,exp,1,5.0,50,mean,97.203125,0.440122,0.559878,0.440122
7,LIFSNN,exp,2,5.0,50,mean,97.046875,0.422635,0.577365,0.422635
8,LIFSNN,exp,3,5.0,50,mean,97.328125,0.44002,0.55998,0.44002
9,LIFSNN,exp,4,5.0,50,mean,97.28125,0.452969,0.547031,0.452969


# Сеть с R-STDP (обучение выходных весов с подкреплением)

In [None]:
# rstdp_mnist_spikingjelly.py
# LIF-only SNN + spike-count argmax + R-STDP (on W_out) + optional teacher forcing on output current
# Integrator sweep: "euler" vs "exp"
#
# Requirements: spikingjelly, torch, torchvision

import os
import math
import random
from dataclasses import dataclass
from typing import Optional, Tuple, Dict, Any, List

import numpy as np
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from torchvision import datasets, transforms

from spikingjelly.activation_based import neuron, functional, surrogate


# -----------------------------
# Repro / device
# -----------------------------
def set_seed(seed: int):
    """Seed Python, NumPy and PyTorch (CPU and all CUDA devices) RNGs and ask
    cuDNN for deterministic kernels (autotuner off) for reproducible runs."""
    for seeder in (random.seed, np.random.seed, torch.manual_seed, torch.cuda.manual_seed_all):
        seeder(seed)

    cudnn = torch.backends.cudnn
    cudnn.deterministic = True
    cudnn.benchmark = False


# Target device for the R-STDP experiments: "cuda" when a GPU is available, else "cpu".
# Kept as a plain string because it is also compared with "cuda" to decide DataLoader pin_memory.
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"


# -----------------------------
# Simple Poisson encoder wrapper
# -----------------------------
class PoissonEncoderWrapper:
    """Bernoulli rate encoder: every element spikes with probability equal to its value.

    Inputs are expected to lie in [0, 1]. The ``dt`` given to :meth:`reset` is only
    recorded for API compatibility — it does not affect the firing probability, so
    the per-step input spike rate is the same for every integration step.
    """

    def __init__(self):
        self.reset()

    def reset(self, dt: float = 1.0):
        """Record ``dt`` (stored only; the encoder never reads it)."""
        self._dt = float(dt)

    def __call__(self, x: torch.Tensor) -> torch.Tensor:
        """Return a 0/1 tensor with the same shape and dtype as ``x``."""
        return torch.lt(torch.rand_like(x), x).to(dtype=x.dtype)


# Shared notebook-level encoder used by the models' forward passes.
encoder = PoissonEncoderWrapper()


# -----------------------------
# Spike stats
# -----------------------------
@dataclass
class SpikeStats:
    """Running spike counters over (sample, neuron, time-step) "sites".

    Each :meth:`update` call adds one time step for a 2-D [batch, neurons] spike
    tensor. A site is active when its value is > 0. Because silence is counted
    per site (not per neuron over the whole simulation), ``silent_frac`` is
    always ``1 - active_frac``, and ``total_spikes`` equals the number of active sites.
    """
    total_spikes: float = 0.0   # active sites, accumulated as float
    total_sites: int = 0        # all sites seen
    silent_sites: int = 0       # sites with value <= 0
    active_sites: int = 0       # sites with value > 0

    def update(self, s_2d: torch.Tensor):
        """Add one step of spikes; ``s_2d`` must be 2-D (raises ValueError otherwise)."""
        with torch.no_grad():
            rows, cols = s_2d.shape
            n_sites = int(rows * cols)
            n_active = int((s_2d > 0).sum().item())
            n_silent = n_sites - n_active
            self.total_spikes += float(n_active)
            self.total_sites += n_sites
            self.active_sites += n_active
            self.silent_sites += n_silent

    def merge_(self, other: "SpikeStats"):
        """Add ``other``'s counters into ``self`` in place."""
        for field_name in ("total_spikes", "total_sites", "silent_sites", "active_sites"):
            setattr(self, field_name, getattr(self, field_name) + getattr(other, field_name))

    def summary(self) -> Dict[str, float]:
        """Return per-site rates; all zeros when nothing has been recorded."""
        n = max(1, self.total_sites)
        return {
            "mean_spike_rate": float(self.total_spikes / n),
            "silent_frac": float(self.silent_sites / n),
            "active_frac": float(self.active_sites / n),
        }


def _spikes_to_2d(s: torch.Tensor) -> torch.Tensor:
    if s.dim() == 1:
        return s.unsqueeze(0)
    return s.view(s.size(0), -1)


# -----------------------------
# Model: LIF-only SNN
# -----------------------------
class LIF_RSTDP_SNN(nn.Module):
    """Two-layer LIF spiking network whose output weights are trained with R-STDP.

    Per time step: Bernoulli-encoded input -> ``fc_in`` (scaled by ``in_gain``) ->
    hidden ``LIFNode`` -> ``fc_out`` (scaled by ``out_gain``) -> output ``LIFNode``
    -> hard winner-take-all on the output spikes. Classification is the argmax of
    the accumulated output spike counts.

    Args:
        integrator: integration scheme for both LIF layers ('euler' or 'exp' here).
        dt: integration step for both LIF layers.
        T: simulation steps per forward pass.
        n_hidden: hidden layer width.
        n_classes: number of output neurons/classes.
        v_th: firing threshold of both LIF layers.
        tau: membrane time constant of both LIF layers.
        in_gain: multiplier on the input-to-hidden current.
        out_gain: multiplier on the hidden-to-output current.
        decay_input: forwarded to both ``LIFNode`` layers.
    """

    def __init__(
        self,
        integrator: str = "euler",
        dt: float = 0.1,
        T: int = 200,
        n_hidden: int = 256,
        n_classes: int = 10,
        v_th: float = 0.3,
        tau: float = 2.0,
        in_gain: float = 5.0,
        out_gain: float = 5.0,
        decay_input: bool = False,
    ):
        super().__init__()
        self.T = int(T)
        self.n_hidden = int(n_hidden)
        self.n_classes = int(n_classes)
        self.in_gain = float(in_gain)
        self.out_gain = float(out_gain)

        self.fc_in = nn.Linear(28 * 28, self.n_hidden, bias=False)
        self.neuron_h = neuron.LIFNode(
            tau=float(tau),
            decay_input=bool(decay_input),
            v_threshold=float(v_th),
            v_reset=0.0,
            surrogate_function=surrogate.Sigmoid(),  # instance (callable), not the class
            integrator=integrator,
            dt=float(dt),
        )

        self.fc_out = nn.Linear(self.n_hidden, self.n_classes, bias=False)
        self.neuron_o = neuron.LIFNode(
            tau=float(tau),
            decay_input=bool(decay_input),
            v_threshold=float(v_th),
            v_reset=0.0,
            surrogate_function=surrogate.Sigmoid(),
            integrator=integrator,
            dt=float(dt),
        )

        # Same scheme as nn.Linear's default init; kept explicit because it draws
        # from the RNG, so removing it would change seeded results.
        nn.init.kaiming_uniform_(self.fc_in.weight, a=math.sqrt(5))
        nn.init.kaiming_uniform_(self.fc_out.weight, a=math.sqrt(5))

    @staticmethod
    def hard_wta(s_o: torch.Tensor, o_cur: torch.Tensor) -> torch.Tensor:
        """Per-step hard winner-take-all over output spikes.

        For every sample in which at least one output neuron spiked, keep exactly
        one spike: on the *spiking* neuron with the largest input current. Samples
        with no spikes stay all-zero.

        Args:
            s_o: output spikes, shape [B, C].
            o_cur: output currents of the same step, shape [B, C].

        Returns:
            0/1 tensor of shape [B, C] with the dtype of ``s_o``.
        """
        fired = s_o > 0
        spike_any = fired.any(dim=1, keepdim=True)  # [B, 1]
        # Restrict the argmax to neurons that actually fired; an unrestricted argmax
        # could credit a spike to a non-spiking neuron with a larger current.
        winner = o_cur.masked_fill(~fired, float("-inf")).argmax(dim=1, keepdim=True)  # [B, 1]
        return torch.zeros_like(s_o).scatter_(1, winner, spike_any.to(s_o.dtype))

    @torch.no_grad()
    def forward_collect(
        self,
        x: torch.Tensor,
        y: Optional[torch.Tensor] = None,
        teach_current: float = 0.0,
        collect_stats: bool = True,
    ):
        """Simulate ``T`` steps without autograd and accumulate spike counts.

        Args:
            x: input batch; dim 0 is the batch, the rest is flattened to 784 features.
            y: labels; when given together with ``teach_current > 0`` the label
                neuron receives ``teach_current`` extra input current every step.
            teach_current: teacher-forcing current amplitude.
            collect_stats: whether to fill the returned ``SpikeStats``.

        Returns:
          h_count: [B,H] hidden spike counts
          o_count: [B,C] output spike counts (after WTA)
          stats_h, stats_o: SpikeStats for hidden and output separately
              (always returned; left empty when ``collect_stats`` is False)
          no_spike_out_frac: fraction of samples whose output never spiked
        """
        x = x.reshape(x.size(0), -1)
        B = x.size(0)

        stats_h = SpikeStats()
        stats_o = SpikeStats()

        h_count = torch.zeros((B, self.n_hidden), device=x.device, dtype=torch.float32)
        o_count = torch.zeros((B, self.n_classes), device=x.device, dtype=torch.float32)

        teach = (y is not None) and (teach_current > 0.0)
        if teach:
            rows = torch.arange(B, device=x.device)

        for _ in range(self.T):
            x_t = encoder(x)  # [B,784] 0/1, encoder defined at notebook level

            s_h = self.neuron_h(self.fc_in(x_t) * self.in_gain)  # [B,H]
            o_cur = self.fc_out(s_h) * self.out_gain               # [B,C], freshly allocated

            if teach:
                o_cur[rows, y] += float(teach_current)

            s_o = self.neuron_o(o_cur)  # [B,C]
            # WTA only rewrites the recorded spikes, after neuron_o has already
            # updated its own internal state for this step.
            s_o = self.hard_wta(s_o, o_cur)

            h_count += s_h.float()
            o_count += s_o.float()

            if collect_stats:
                stats_h.update(_spikes_to_2d(s_h))
                stats_o.update(_spikes_to_2d(s_o))

        no_spike_out_frac = float((o_count.sum(dim=1) == 0).float().mean().item())
        return h_count, o_count, stats_h, stats_o, no_spike_out_frac


# -----------------------------
# R-STDP update on W_out only
# -----------------------------
@torch.no_grad()
def rstdp_update_w_out(
    model: "LIF_RSTDP_SNN",
    h_count: torch.Tensor,   # [B,H]
    o_count: torch.Tensor,   # [B,C]
    y: torch.Tensor,         # [B]
    lr_wout: float,
    reward_mode: str = "pm1",    # "pm1" or "onehot"
    w_clip: Optional[float] = 5.0,
    l2_decay: float = 0.0,
    skip_silent: bool = True,
):
    """Reward-modulated Hebbian update of ``model.fc_out.weight`` (in place, no autograd).

    ``pre = h_count / T`` is the per-step hidden firing rate of each sample.

    * ``"pm1"``: each sample's winner is the argmax of its output spike counts and
      its reward is +1 if the winner equals the label, else -1. Each winner row moves
      by the mean over its samples of ``reward * pre``. With ``skip_silent=True``
      (default), samples whose output layer never spiked are excluded: their count
      row is all zeros, argmax tie-breaking returns class 0, and they would otherwise
      systematically reward/punish class 0.
    * ``"onehot"``: ``dW = mean_b (onehot(y_b) - softmax(o_count_b))^T pre_b``.

    Then optionally ``W *= 1 - l2_decay``, ``W += lr_wout * dW`` and optionally
    clamp ``W`` to ``[-w_clip, w_clip]``.

    Raises:
        ValueError: on an unknown ``reward_mode``.
    """
    W = model.fc_out.weight  # [C,H]
    dev, dtype = W.device, W.dtype
    n_classes = int(model.n_classes)

    pre = (h_count / max(1.0, float(model.T))).to(device=dev, dtype=dtype)  # [B,H]
    o_count = o_count.to(dev)
    y = y.to(dev)

    if reward_mode == "pm1":
        winner = o_count.argmax(dim=1)              # [B]
        r = (winner == y).to(dtype) * 2.0 - 1.0     # +1 correct / -1 wrong
        if skip_silent:
            keep = o_count.sum(dim=1) > 0
            winner, r, pre = winner[keep], r[keep], pre[keep]

        # Per-class mean of r_i * pre_i over the samples that class won.
        dW = torch.zeros_like(W)
        dW.index_add_(0, winner, r[:, None] * pre)
        counts = torch.bincount(winner, minlength=n_classes).clamp(min=1).to(dtype)
        dW /= counts[:, None]

    elif reward_mode == "onehot":
        y_oh = torch.nn.functional.one_hot(y.long(), num_classes=n_classes).to(dtype)
        prob = torch.softmax(o_count.float(), dim=1).to(dtype)
        err = y_oh - prob  # [B,C]
        # Batch mean of outer products err_b ⊗ pre_b.
        dW = (err.t() @ pre) / max(1, err.size(0))

    else:
        raise ValueError(f"Unknown reward_mode: {reward_mode}")

    if l2_decay > 0:
        W.mul_(1.0 - float(l2_decay))

    W.add_(float(lr_wout) * dW)

    if w_clip is not None:
        W.clamp_(-float(w_clip), float(w_clip))


# -----------------------------
# Metrics
# -----------------------------
@torch.no_grad()
def spikecount_pseudo_ce(o_count: torch.Tensor, y: torch.Tensor) -> float:
    """Cross-entropy that uses raw output spike counts as logits (monitoring only).

    Counts are not normalised, so values grow with T; it is not optimised directly.
    """
    ce = torch.nn.functional.cross_entropy(o_count, y)
    return float(ce.item())


@torch.no_grad()
def evaluate_spikecount(
    model: "LIF_RSTDP_SNN",
    loader: DataLoader,
    max_batches: Optional[int] = 50,
) -> Tuple[float, Dict[str, float], Dict[str, float], float]:
    """Evaluate spike-count classification without teacher forcing.

    Each batch gets a fresh network state (``functional.reset_net``); the prediction
    is the argmax of output spike counts.

    Args:
        model: network providing ``forward_collect``.
        loader: evaluation data.
        max_batches: cap on batches; ``None`` evaluates the whole loader.

    Returns:
        (accuracy in %, hidden SpikeStats summary, output SpikeStats summary,
         fraction of evaluated samples whose output layer never spiked)
    """
    model.eval()
    agg_h = SpikeStats()
    agg_o = SpikeStats()

    correct = 0
    total = 0
    no_spike_out_samples = 0.0

    for bi, (x, y) in enumerate(loader):
        if max_batches is not None and bi >= max_batches:
            break

        x = x.to(DEVICE, non_blocking=True)
        y = y.to(DEVICE, non_blocking=True)

        functional.reset_net(model)
        _, o_count, stats_h, stats_o, no_spike_out_frac = model.forward_collect(
            x, y=None, teach_current=0.0, collect_stats=True
        )

        n = y.numel()
        pred = o_count.argmax(dim=1)
        correct += (pred == y).sum().item()
        total += n

        agg_h.merge_(stats_h)
        agg_o.merge_(stats_o)
        # Weight by batch size so a short final batch does not skew the average.
        no_spike_out_samples += no_spike_out_frac * n

    acc = 100.0 * correct / max(1, total)
    return acc, agg_h.summary(), agg_o.summary(), no_spike_out_samples / max(1, total)


@torch.no_grad()
def train_one_epoch_rstdp(
    model: "LIF_RSTDP_SNN",
    loader: DataLoader,
    lr_wout: float,
    reward_mode: str = "pm1",
    teach_current: float = 0.0,
    max_batches: Optional[int] = 1000,
    w_clip: Optional[float] = 5.0,
    l2_decay: float = 0.0,
) -> Tuple[float, float, Dict[str, float], Dict[str, float], float]:
    """One gradient-free R-STDP epoch over ``loader`` (only ``fc_out`` is updated).

    For every batch: reset the network state, simulate with the teacher current
    injected into the label neuron, record metrics from that pass, then call
    ``rstdp_update_w_out``. Training metrics are therefore measured before each
    update and with teacher forcing active.

    Returns:
        (pseudo cross-entropy averaged per sample, train accuracy in %,
         hidden SpikeStats summary, output SpikeStats summary,
         fraction of samples whose output layer never spiked)
    """
    model.train()
    agg_h = SpikeStats()
    agg_o = SpikeStats()

    correct = 0
    total = 0
    loss_sum = 0.0
    no_spike_out_samples = 0.0

    for bi, (x, y) in enumerate(loader):
        if max_batches is not None and bi >= max_batches:
            break

        x = x.to(DEVICE, non_blocking=True)
        y = y.to(DEVICE, non_blocking=True)

        functional.reset_net(model)
        h_count, o_count, stats_h, stats_o, no_spike_out_frac = model.forward_collect(
            x, y=y, teach_current=teach_current, collect_stats=True
        )

        n = y.numel()
        # Per-batch means are weighted by batch size so uneven batches average correctly.
        loss_sum += spikecount_pseudo_ce(o_count, y) * n
        no_spike_out_samples += no_spike_out_frac * n

        pred = o_count.argmax(dim=1)
        correct += (pred == y).sum().item()
        total += n

        agg_h.merge_(stats_h)
        agg_o.merge_(stats_o)

        rstdp_update_w_out(
            model=model,
            h_count=h_count,
            o_count=o_count,
            y=y,
            lr_wout=lr_wout,
            reward_mode=reward_mode,
            w_clip=w_clip,
            l2_decay=l2_decay,
        )

    denom = max(1, total)
    train_acc = 100.0 * correct / denom
    pseudo_ce = loss_sum / denom
    return pseudo_ce, train_acc, agg_h.summary(), agg_o.summary(), no_spike_out_samples / denom


# -----------------------------
# Experiment runner
# -----------------------------
def run_rstdp_experiment(
    integrator: str,
    seed: int,
    dt: float,
    T: int,
    epochs: int,
    lr_wout: float,
    reward_mode: str,
    teach_i0: float,
    teach_decay: float,
    train_loader: DataLoader,
    test_loader: DataLoader,
    max_train_batches: Optional[int] = 1000,
    max_test_batches: Optional[int] = 50,
    model_kwargs: Optional[Dict[str, Any]] = None,
    w_clip: Optional[float] = 5.0,
    l2_decay: float = 0.0,
    debug: bool = True,
) -> Dict[str, Any]:
    """Train one ``LIF_RSTDP_SNN`` with R-STDP and report its final test metrics.

    Steps: seed the RNGs, reset the encoder, build the model on ``DEVICE``,
    optionally print a single debug forward pass, then train for ``epochs`` epochs.
    The teacher current starts at ``teach_i0`` and is multiplied by ``teach_decay``
    after every epoch; after each epoch a quick test evaluation is printed. Finally
    the whole test loader is evaluated.

    Args:
        w_clip: weight clamp forwarded to the R-STDP update (default 5.0; ``None`` disables it).
        l2_decay: multiplicative weight decay per update (default 0.0).
        debug: print the initial debug forward pass (default True). It draws from
            the RNGs, so toggling it changes seeded results.
        (other args are forwarded as their names indicate)

    Returns:
        dict with integrator, seed, dt, T, final accuracy, final hidden/output
        mean spike rates and the final no-output-spike fraction.
    """
    print(f"\n=== LIF R-STDP integ={integrator} seed={seed} dt={dt} T={T} ===")

    set_seed(seed)
    encoder.reset(dt=dt)

    model_kwargs = model_kwargs or {}
    model = LIF_RSTDP_SNN(integrator=integrator, dt=dt, T=T, **model_kwargs).to(DEVICE)

    if debug:
        # Sanity check of spiking activity on one (teacher-forced) training batch.
        x, y = next(iter(train_loader))
        x = x.to(DEVICE); y = y.to(DEVICE)
        functional.reset_net(model)
        h_count, o_count, sh, so, noz = model.forward_collect(x, y=y, teach_current=teach_i0, collect_stats=True)
        print(
            "debug:",
            "h_count mean", float(h_count.mean().item()),
            "o_count mean", float(o_count.mean().item()),
            "H", sh.summary(),
            "O", so.summary(),
            "no_spike_out", f"{noz:.3f}"
        )

    teach_i = float(teach_i0)

    for ep in range(epochs):
        pseudo_ce, train_acc, train_h, train_o, train_noz = train_one_epoch_rstdp(
            model=model,
            loader=train_loader,
            lr_wout=lr_wout,
            reward_mode=reward_mode,
            teach_current=teach_i,
            max_batches=max_train_batches,
            w_clip=w_clip,
            l2_decay=l2_decay,
        )

        test_acc, test_h, test_o, test_noz = evaluate_spikecount(
            model=model,
            loader=test_loader,
            max_batches=max_test_batches,
        )

        print(
            f"Epoch {ep:02d}: teachI={teach_i:.3f} pseudoCE={pseudo_ce:.4f} "
            f"train_acc={train_acc:.2f}% test_acc={test_acc:.2f}% | "
            f"H_rate={test_h['mean_spike_rate']:.4f} O_rate={test_o['mean_spike_rate']:.4f} "
            f"no_spike_out={test_noz:.3f}"
        )

        teach_i *= float(teach_decay)

    # Final evaluation over the full test loader.
    final_acc, final_h, final_o, final_noz = evaluate_spikecount(model, test_loader, max_batches=None)
    return {
        "integrator": integrator,
        "seed": seed,
        "dt": dt,
        "T": T,
        "final_acc": final_acc,
        "final_H_mean_spike_rate": final_h["mean_spike_rate"],
        "final_O_mean_spike_rate": final_o["mean_spike_rate"],
        "final_no_spike_out_frac": final_noz,
    }


In [None]:
def main(dt: float = 0.5, T: int = 200, epochs: int = 5, seeds=(0,)):
    """Compare the "euler" and "exp" LIF integrators under R-STDP training on MNIST.

    One ``run_rstdp_experiment`` call is made per (integrator, seed) with identical
    model hyper-parameters, so the integrator is the only varied factor.

    Args:
        dt: integration step for both LIF layers (this cell's default: 0.5).
        T: simulation steps per sample.
        epochs: R-STDP training epochs per run.
        seeds: iterable of random seeds.

    Returns:
        List of per-run result dicts (also printed).
    """
    print(f"DEVICE = {DEVICE}")

    # ---- experiment settings ----
    DT = float(dt)
    T = int(T)
    EPOCHS = int(epochs)
    SEEDS = list(seeds)
    INTEGRATORS = ["euler", "exp"]
    LR_WOUT = 5e-2
    REWARD_MODE = "pm1"

    # teacher forcing (output current injection)
    TEACH_I0 = 1.0      # initial current injected into the label neuron
    TEACH_DECAY = 0.95  # multiplicative per-epoch decay

    # speed: caps on batches per epoch / per intermediate evaluation
    MAX_TRAIN_BATCHES = 1000
    MAX_TEST_BATCHES = 50
    BATCH_TRAIN = 64
    BATCH_TEST = 256

    # data (root from $DATA, default ./data)
    tfm = transforms.Compose([transforms.ToTensor()])
    root = os.environ.get("DATA", "./data")
    train_ds = datasets.MNIST(root=root, train=True, download=True, transform=tfm)
    test_ds = datasets.MNIST(root=root, train=False, download=True, transform=tfm)

    train_loader = DataLoader(
        train_ds, batch_size=BATCH_TRAIN, shuffle=True,
        num_workers=0, pin_memory=(DEVICE == "cuda"), drop_last=True
    )
    test_loader = DataLoader(
        test_ds, batch_size=BATCH_TEST, shuffle=False,
        num_workers=0, pin_memory=(DEVICE == "cuda")
    )

    with torch.no_grad():
        x0, _ = next(iter(train_loader))
        # The Bernoulli encoder fires with p = pixel value, so the mean pixel
        # is the mean per-step input spike probability.
        print(f"Encoder mean spike prob (over pixels, per step) ~ {float(x0.mean().item()):.6f}")

    # model hyperparams (keep fixed across integrators for fairness)
    model_kwargs = dict(
        v_th=0.3,
        tau=2.0,
        in_gain=5.0,
        out_gain=5.0,
        decay_input=False,
    )

    results = []
    for integ in INTEGRATORS:
        for seed in SEEDS:
            res = run_rstdp_experiment(
                integrator=integ,
                seed=seed,
                dt=DT,
                T=T,
                epochs=EPOCHS,
                lr_wout=LR_WOUT,
                reward_mode=REWARD_MODE,
                teach_i0=TEACH_I0,
                teach_decay=TEACH_DECAY,
                train_loader=train_loader,
                test_loader=test_loader,
                max_train_batches=MAX_TRAIN_BATCHES,
                max_test_batches=MAX_TEST_BATCHES,
                model_kwargs=model_kwargs,
            )
            results.append(res)

    print("\nResults:")
    for r in results:
        print(r)
    return results


if __name__ == "__main__":
    main()


DEVICE = cuda
Encoder mean spike prob (over pixels, per step) ~ 0.126272

=== LIF R-STDP integ=euler seed=0 dt=0.5 T=200 ===
debug: h_count mean 57.05804443359375 o_count mean 19.9609375 H {'mean_spike_rate': 0.28529022216796873, 'silent_frac': 0.7147097778320313, 'active_frac': 0.28529022216796873} O {'mean_spike_rate': 0.0998046875, 'silent_frac': 0.9001953125, 'active_frac': 0.0998046875} no_spike_out 0.000
Epoch 00: teachI=1.000 pseudoCE=17.2660 train_acc=28.87% test_acc=29.54% | H_rate=0.2874 O_rate=0.0245 no_spike_out=0.623
Epoch 01: teachI=0.950 pseudoCE=14.6078 train_acc=29.13% test_acc=29.49% | H_rate=0.2874 O_rate=0.0244 no_spike_out=0.621
Epoch 02: teachI=0.902 pseudoCE=14.6735 train_acc=29.19% test_acc=29.48% | H_rate=0.2874 O_rate=0.0245 no_spike_out=0.616
Epoch 03: teachI=0.857 pseudoCE=14.4002 train_acc=29.16% test_acc=29.49% | H_rate=0.2874 O_rate=0.0235 no_spike_out=0.630
Epoch 04: teachI=0.815 pseudoCE=14.2216 train_acc=29.15% test_acc=29.50% | H_rate=0.2874 O_rate=0.

In [None]:
def main(dt: float = 0.1, T: int = 200, epochs: int = 5, seeds=(0,)):
    """Compare the "euler" and "exp" LIF integrators under R-STDP training on MNIST.

    One ``run_rstdp_experiment`` call is made per (integrator, seed) with identical
    model hyper-parameters, so the integrator is the only varied factor.

    Args:
        dt: integration step for both LIF layers (this cell's default: 0.1).
        T: simulation steps per sample.
        epochs: R-STDP training epochs per run.
        seeds: iterable of random seeds.

    Returns:
        List of per-run result dicts (also printed).
    """
    print(f"DEVICE = {DEVICE}")

    # ---- experiment settings ----
    DT = float(dt)
    T = int(T)
    EPOCHS = int(epochs)
    SEEDS = list(seeds)
    INTEGRATORS = ["euler", "exp"]
    LR_WOUT = 5e-2
    REWARD_MODE = "pm1"

    # teacher forcing (output current injection)
    TEACH_I0 = 1.0      # initial current injected into the label neuron
    TEACH_DECAY = 0.95  # multiplicative per-epoch decay

    # speed: caps on batches per epoch / per intermediate evaluation
    MAX_TRAIN_BATCHES = 1000
    MAX_TEST_BATCHES = 50
    BATCH_TRAIN = 64
    BATCH_TEST = 256

    # data (root from $DATA, default ./data)
    tfm = transforms.Compose([transforms.ToTensor()])
    root = os.environ.get("DATA", "./data")
    train_ds = datasets.MNIST(root=root, train=True, download=True, transform=tfm)
    test_ds = datasets.MNIST(root=root, train=False, download=True, transform=tfm)

    train_loader = DataLoader(
        train_ds, batch_size=BATCH_TRAIN, shuffle=True,
        num_workers=0, pin_memory=(DEVICE == "cuda"), drop_last=True
    )
    test_loader = DataLoader(
        test_ds, batch_size=BATCH_TEST, shuffle=False,
        num_workers=0, pin_memory=(DEVICE == "cuda")
    )

    with torch.no_grad():
        x0, _ = next(iter(train_loader))
        # The Bernoulli encoder fires with p = pixel value, so the mean pixel
        # is the mean per-step input spike probability.
        print(f"Encoder mean spike prob (over pixels, per step) ~ {float(x0.mean().item()):.6f}")

    # model hyperparams (keep fixed across integrators for fairness)
    model_kwargs = dict(
        v_th=0.3,
        tau=2.0,
        in_gain=5.0,
        out_gain=5.0,
        decay_input=False,
    )

    results = []
    for integ in INTEGRATORS:
        for seed in SEEDS:
            res = run_rstdp_experiment(
                integrator=integ,
                seed=seed,
                dt=DT,
                T=T,
                epochs=EPOCHS,
                lr_wout=LR_WOUT,
                reward_mode=REWARD_MODE,
                teach_i0=TEACH_I0,
                teach_decay=TEACH_DECAY,
                train_loader=train_loader,
                test_loader=test_loader,
                max_train_batches=MAX_TRAIN_BATCHES,
                max_test_batches=MAX_TEST_BATCHES,
                model_kwargs=model_kwargs,
            )
            results.append(res)

    print("\nResults:")
    for r in results:
        print(r)
    return results


if __name__ == "__main__":
    main()


DEVICE = cuda


100%|██████████| 1.65M/1.65M [00:01<00:00, 1.24MB/s]
100%|██████████| 4.54k/4.54k [00:00<00:00, 8.72MB/s]


Encoder mean spike prob (over pixels, per step) ~ 0.127302

=== LIF R-STDP integ=euler seed=0 dt=0.1 T=200 ===
debug: h_count mean 18.20306396484375 o_count mean 10.83750057220459 H {'mean_spike_rate': 0.09101531982421875, 'silent_frac': 0.9089846801757813, 'active_frac': 0.09101531982421875} O {'mean_spike_rate': 0.0541875, 'silent_frac': 0.9458125, 'active_frac': 0.0541875} no_spike_out 0.000
Epoch 00: teachI=1.000 pseudoCE=10.6410 train_acc=79.52% test_acc=78.86% | H_rate=0.0916 O_rate=0.0997 no_spike_out=0.000
Epoch 01: teachI=0.950 pseudoCE=11.0424 train_acc=78.65% test_acc=78.68% | H_rate=0.0916 O_rate=0.0997 no_spike_out=0.000
Epoch 02: teachI=0.902 pseudoCE=9.4640 train_acc=79.48% test_acc=79.24% | H_rate=0.0916 O_rate=0.0997 no_spike_out=0.000
Epoch 03: teachI=0.857 pseudoCE=8.2619 train_acc=80.62% test_acc=79.31% | H_rate=0.0916 O_rate=0.0997 no_spike_out=0.000
Epoch 04: teachI=0.815 pseudoCE=7.2264 train_acc=81.83% test_acc=78.80% | H_rate=0.0916 O_rate=0.0997 no_spike_out=0

In [None]:
def main(dt: float = 0.05, T: int = 200, epochs: int = 5, seeds=(0,)):
    """Compare the "euler" and "exp" LIF integrators under R-STDP training on MNIST.

    One ``run_rstdp_experiment`` call is made per (integrator, seed) with identical
    model hyper-parameters, so the integrator is the only varied factor.

    Args:
        dt: integration step for both LIF layers (this cell's default: 0.05).
        T: simulation steps per sample.
        epochs: R-STDP training epochs per run.
        seeds: iterable of random seeds.

    Returns:
        List of per-run result dicts (also printed).
    """
    print(f"DEVICE = {DEVICE}")

    # ---- experiment settings ----
    DT = float(dt)
    T = int(T)
    EPOCHS = int(epochs)
    SEEDS = list(seeds)
    INTEGRATORS = ["euler", "exp"]
    LR_WOUT = 5e-2
    REWARD_MODE = "pm1"

    # teacher forcing (output current injection)
    TEACH_I0 = 1.0      # initial current injected into the label neuron
    TEACH_DECAY = 0.95  # multiplicative per-epoch decay

    # speed: caps on batches per epoch / per intermediate evaluation
    MAX_TRAIN_BATCHES = 1000
    MAX_TEST_BATCHES = 50
    BATCH_TRAIN = 64
    BATCH_TEST = 256

    # data (root from $DATA, default ./data)
    tfm = transforms.Compose([transforms.ToTensor()])
    root = os.environ.get("DATA", "./data")
    train_ds = datasets.MNIST(root=root, train=True, download=True, transform=tfm)
    test_ds = datasets.MNIST(root=root, train=False, download=True, transform=tfm)

    train_loader = DataLoader(
        train_ds, batch_size=BATCH_TRAIN, shuffle=True,
        num_workers=0, pin_memory=(DEVICE == "cuda"), drop_last=True
    )
    test_loader = DataLoader(
        test_ds, batch_size=BATCH_TEST, shuffle=False,
        num_workers=0, pin_memory=(DEVICE == "cuda")
    )

    with torch.no_grad():
        x0, _ = next(iter(train_loader))
        # The Bernoulli encoder fires with p = pixel value, so the mean pixel
        # is the mean per-step input spike probability.
        print(f"Encoder mean spike prob (over pixels, per step) ~ {float(x0.mean().item()):.6f}")

    # model hyperparams (keep fixed across integrators for fairness)
    model_kwargs = dict(
        v_th=0.3,
        tau=2.0,
        in_gain=5.0,
        out_gain=5.0,
        decay_input=False,
    )

    results = []
    for integ in INTEGRATORS:
        for seed in SEEDS:
            res = run_rstdp_experiment(
                integrator=integ,
                seed=seed,
                dt=DT,
                T=T,
                epochs=EPOCHS,
                lr_wout=LR_WOUT,
                reward_mode=REWARD_MODE,
                teach_i0=TEACH_I0,
                teach_decay=TEACH_DECAY,
                train_loader=train_loader,
                test_loader=test_loader,
                max_train_batches=MAX_TRAIN_BATCHES,
                max_test_batches=MAX_TEST_BATCHES,
                model_kwargs=model_kwargs,
            )
            results.append(res)

    print("\nResults:")
    for r in results:
        print(r)
    return results


if __name__ == "__main__":
    main()


DEVICE = cuda
Encoder mean spike prob (over pixels, per step) ~ 0.126272

=== LIF R-STDP integ=euler seed=0 dt=0.05 T=200 ===
debug: h_count mean 9.7490234375 o_count mean 4.589062690734863 H {'mean_spike_rate': 0.0487451171875, 'silent_frac': 0.9512548828125, 'active_frac': 0.0487451171875} O {'mean_spike_rate': 0.0229453125, 'silent_frac': 0.9770546875, 'active_frac': 0.0229453125} no_spike_out 0.000
Epoch 00: teachI=1.000 pseudoCE=5.8131 train_acc=82.43% test_acc=78.43% | H_rate=0.0490 O_rate=0.0991 no_spike_out=0.000
Epoch 01: teachI=0.950 pseudoCE=7.2474 train_acc=79.40% test_acc=78.33% | H_rate=0.0490 O_rate=0.0991 no_spike_out=0.000
Epoch 02: teachI=0.902 pseudoCE=7.0326 train_acc=79.47% test_acc=78.34% | H_rate=0.0490 O_rate=0.0991 no_spike_out=0.000
Epoch 03: teachI=0.857 pseudoCE=6.1239 train_acc=80.61% test_acc=77.93% | H_rate=0.0490 O_rate=0.0991 no_spike_out=0.000
Epoch 04: teachI=0.815 pseudoCE=4.8941 train_acc=83.14% test_acc=78.03% | H_rate=0.0490 O_rate=0.0991 no_spike

In [None]:
def main(
    dt: float = 0.01,
    T: int = 200,
    epochs: int = 5,
    seeds=(0,),
    integrators=("euler", "exp"),
) -> list:
    """Run the LIF R-STDP MNIST sweep over integrators and seeds.

    Builds MNIST train/test loaders, prints a quick input-intensity sanity
    check, then calls ``run_rstdp_experiment`` once per (integrator, seed) pair
    with identical model hyperparameters so integrators are compared fairly.

    Args:
        dt: Integration step passed to ``run_rstdp_experiment``.
        T: Number of simulation steps per sample. NOTE(review): the logged spike
            counts scale with ``dt * T``, which suggests the simulated duration
            is ``dt * T``. Comparing several ``dt`` values at a fixed ``T``
            therefore also changes the total simulated time. Confirm this
            against ``run_rstdp_experiment``.
        epochs: Number of training epochs per run.
        seeds: Seeds to run for each integrator.
        integrators: Integration-scheme names understood by
            ``run_rstdp_experiment``.

    Returns:
        List of the values returned by ``run_rstdp_experiment``, in
        (integrator, seed) order. They are also printed.

    Relies on the notebook globals ``DEVICE`` and ``run_rstdp_experiment``,
    which are defined in earlier cells.
    """
    # `os` is not part of the notebook's import cell; import it here so the
    # cell works on a fresh kernel.
    import os

    print(f"DEVICE = {DEVICE}")

    # ---- experiment settings (not parameterized) ----
    LR_WOUT = 5e-2
    REWARD_MODE = "pm1"

    # teacher forcing (output current injection)
    TEACH_I0 = 1.0      # start injection
    TEACH_DECAY = 0.95  # per-epoch decay

    # speed
    MAX_TRAIN_BATCHES = 1000
    MAX_TEST_BATCHES = 50
    BATCH_TRAIN = 64
    BATCH_TEST = 256

    # data
    tfm = transforms.Compose([transforms.ToTensor()])
    root = os.environ.get("DATA", "./data")
    train_ds = datasets.MNIST(root=root, train=True, download=True, transform=tfm)
    test_ds = datasets.MNIST(root=root, train=False, download=True, transform=tfm)

    # Works whether DEVICE is a string ("cuda", "cuda:0") or a torch.device.
    use_pin_memory = str(DEVICE).startswith("cuda")

    train_loader = DataLoader(
        train_ds, batch_size=BATCH_TRAIN, shuffle=True,
        num_workers=0, pin_memory=use_pin_memory, drop_last=True
    )
    test_loader = DataLoader(
        test_ds, batch_size=BATCH_TEST, shuffle=False,
        num_workers=0, pin_memory=use_pin_memory
    )

    with torch.no_grad():
        # Mean raw pixel intensity of one training batch. This equals the
        # per-step spike probability only if the encoder fires with
        # probability = pixel value (assumed; the encoder is defined elsewhere).
        x0, _ = next(iter(train_loader))
        print(f"Encoder mean spike prob (over pixels, per step) ~ {float(x0.mean().item()):.6f}")

    # model hyperparams (keep fixed across integrators for fairness)
    model_kwargs = dict(
        v_th=0.3,
        tau=2.0,
        in_gain=5.0,
        out_gain=5.0,
        decay_input=False,
    )

    results = []
    for integ in integrators:
        for seed in seeds:
            res = run_rstdp_experiment(
                integrator=integ,
                seed=seed,
                dt=dt,
                T=T,
                epochs=epochs,
                lr_wout=LR_WOUT,
                reward_mode=REWARD_MODE,
                teach_i0=TEACH_I0,
                teach_decay=TEACH_DECAY,
                train_loader=train_loader,
                test_loader=test_loader,
                max_train_batches=MAX_TRAIN_BATCHES,
                max_test_batches=MAX_TEST_BATCHES,
                model_kwargs=model_kwargs,
            )
            results.append(res)

    print("\nResults:")
    for r in results:
        print(r)
    return results


if __name__ == "__main__":
    main()


DEVICE = cuda
Encoder mean spike prob (over pixels, per step) ~ 0.126272

=== LIF R-STDP integ=euler seed=0 dt=0.01 T=200 ===
debug: h_count mean 1.918701171875 o_count mean 0.5625 H {'mean_spike_rate': 0.009593505859375, 'silent_frac': 0.990406494140625, 'active_frac': 0.009593505859375} O {'mean_spike_rate': 0.0028125, 'silent_frac': 0.9971875, 'active_frac': 0.0028125} no_spike_out 0.000
Epoch 00: teachI=1.000 pseudoCE=0.3799 train_acc=95.68% test_acc=70.26% | H_rate=0.0097 O_rate=0.0684 no_spike_out=0.000
Epoch 01: teachI=0.950 pseudoCE=1.0719 train_acc=88.76% test_acc=73.35% | H_rate=0.0097 O_rate=0.0726 no_spike_out=0.000
Epoch 02: teachI=0.902 pseudoCE=1.4360 train_acc=85.62% test_acc=74.35% | H_rate=0.0097 O_rate=0.0745 no_spike_out=0.000
Epoch 03: teachI=0.857 pseudoCE=1.3956 train_acc=86.55% test_acc=73.79% | H_rate=0.0097 O_rate=0.0757 no_spike_out=0.000
Epoch 04: teachI=0.815 pseudoCE=1.1304 train_acc=88.75% test_acc=73.44% | H_rate=0.0097 O_rate=0.0761 no_spike_out=0.000



In [None]:
def main(
    dt: float = 0.005,
    T: int = 200,
    epochs: int = 5,
    seeds=(0,),
    integrators=("euler", "exp"),
) -> list:
    """Run the LIF R-STDP MNIST sweep over integrators and seeds.

    Builds MNIST train/test loaders, prints a quick input-intensity sanity
    check, then calls ``run_rstdp_experiment`` once per (integrator, seed) pair
    with identical model hyperparameters so integrators are compared fairly.

    Args:
        dt: Integration step passed to ``run_rstdp_experiment``.
        T: Number of simulation steps per sample. NOTE(review): the logged spike
            counts scale with ``dt * T``, which suggests the simulated duration
            is ``dt * T``. With dt=0.005 and T=200 that is a much shorter window
            than larger-dt runs at the same T, which confounds the dt
            comparison. Confirm this against ``run_rstdp_experiment``.
        epochs: Number of training epochs per run.
        seeds: Seeds to run for each integrator.
        integrators: Integration-scheme names understood by
            ``run_rstdp_experiment``.

    Returns:
        List of the values returned by ``run_rstdp_experiment``, in
        (integrator, seed) order. They are also printed.

    Relies on the notebook globals ``DEVICE`` and ``run_rstdp_experiment``,
    which are defined in earlier cells.
    """
    # `os` is not part of the notebook's import cell; import it here so the
    # cell works on a fresh kernel.
    import os

    print(f"DEVICE = {DEVICE}")

    # ---- experiment settings (not parameterized) ----
    LR_WOUT = 5e-2
    REWARD_MODE = "pm1"

    # teacher forcing (output current injection)
    TEACH_I0 = 1.0      # start injection
    TEACH_DECAY = 0.95  # per-epoch decay

    # speed
    MAX_TRAIN_BATCHES = 1000
    MAX_TEST_BATCHES = 50
    BATCH_TRAIN = 64
    BATCH_TEST = 256

    # data
    tfm = transforms.Compose([transforms.ToTensor()])
    root = os.environ.get("DATA", "./data")
    train_ds = datasets.MNIST(root=root, train=True, download=True, transform=tfm)
    test_ds = datasets.MNIST(root=root, train=False, download=True, transform=tfm)

    # Works whether DEVICE is a string ("cuda", "cuda:0") or a torch.device.
    use_pin_memory = str(DEVICE).startswith("cuda")

    train_loader = DataLoader(
        train_ds, batch_size=BATCH_TRAIN, shuffle=True,
        num_workers=0, pin_memory=use_pin_memory, drop_last=True
    )
    test_loader = DataLoader(
        test_ds, batch_size=BATCH_TEST, shuffle=False,
        num_workers=0, pin_memory=use_pin_memory
    )

    with torch.no_grad():
        # Mean raw pixel intensity of one training batch. This equals the
        # per-step spike probability only if the encoder fires with
        # probability = pixel value (assumed; the encoder is defined elsewhere).
        x0, _ = next(iter(train_loader))
        print(f"Encoder mean spike prob (over pixels, per step) ~ {float(x0.mean().item()):.6f}")

    # model hyperparams (keep fixed across integrators for fairness)
    model_kwargs = dict(
        v_th=0.3,
        tau=2.0,
        in_gain=5.0,
        out_gain=5.0,
        decay_input=False,
    )

    results = []
    for integ in integrators:
        for seed in seeds:
            res = run_rstdp_experiment(
                integrator=integ,
                seed=seed,
                dt=dt,
                T=T,
                epochs=epochs,
                lr_wout=LR_WOUT,
                reward_mode=REWARD_MODE,
                teach_i0=TEACH_I0,
                teach_decay=TEACH_DECAY,
                train_loader=train_loader,
                test_loader=test_loader,
                max_train_batches=MAX_TRAIN_BATCHES,
                max_test_batches=MAX_TEST_BATCHES,
                model_kwargs=model_kwargs,
            )
            results.append(res)

    print("\nResults:")
    for r in results:
        print(r)
    return results


if __name__ == "__main__":
    main()


DEVICE = cuda
Encoder mean spike prob (over pixels, per step) ~ 0.126272

=== LIF R-STDP integ=euler seed=0 dt=0.005 T=200 ===
debug: h_count mean 0.8660888671875 o_count mean 0.27656251192092896 H {'mean_spike_rate': 0.0043304443359375, 'silent_frac': 0.9956695556640625, 'active_frac': 0.0043304443359375} O {'mean_spike_rate': 0.0013828125, 'silent_frac': 0.9986171875, 'active_frac': 0.0013828125} no_spike_out 0.000
Epoch 00: teachI=1.000 pseudoCE=0.1216 train_acc=98.12% test_acc=44.67% | H_rate=0.0044 O_rate=0.0211 no_spike_out=0.000
Epoch 01: teachI=0.950 pseudoCE=0.3111 train_acc=92.22% test_acc=58.22% | H_rate=0.0044 O_rate=0.0325 no_spike_out=0.000
Epoch 02: teachI=0.902 pseudoCE=0.4613 train_acc=89.62% test_acc=63.75% | H_rate=0.0044 O_rate=0.0388 no_spike_out=0.000
Epoch 03: teachI=0.857 pseudoCE=0.6441 train_acc=86.91% test_acc=66.31% | H_rate=0.0044 O_rate=0.0426 no_spike_out=0.000
Epoch 04: teachI=0.815 pseudoCE=0.8026 train_acc=84.57% test_acc=67.95% | H_rate=0.0044 O_rate=

In [None]:
def main(
    dt: float = 0.005,
    T: int = 2000,
    epochs: int = 5,
    seeds=(0,),
    integrators=("euler", "exp"),
) -> list:
    """Run the LIF R-STDP MNIST sweep over integrators and seeds.

    Builds MNIST train/test loaders, prints a quick input-intensity sanity
    check, then calls ``run_rstdp_experiment`` once per (integrator, seed) pair
    with identical model hyperparameters so integrators are compared fairly.

    Args:
        dt: Integration step passed to ``run_rstdp_experiment``.
        T: Number of simulation steps per sample. NOTE(review): the logged spike
            counts scale with ``dt * T``, which suggests the simulated duration
            is ``dt * T``. The defaults here (0.005 * 2000 = 10.0) match the
            dt=0.05, T=200 run. Confirm this against ``run_rstdp_experiment``.
        epochs: Number of training epochs per run.
        seeds: Seeds to run for each integrator.
        integrators: Integration-scheme names understood by
            ``run_rstdp_experiment``.

    Returns:
        List of the values returned by ``run_rstdp_experiment``, in
        (integrator, seed) order. They are also printed.

    Relies on the notebook globals ``DEVICE`` and ``run_rstdp_experiment``,
    which are defined in earlier cells.
    """
    # `os` is not part of the notebook's import cell; import it here so the
    # cell works on a fresh kernel.
    import os

    print(f"DEVICE = {DEVICE}")

    # ---- experiment settings (not parameterized) ----
    LR_WOUT = 5e-2
    REWARD_MODE = "pm1"

    # teacher forcing (output current injection)
    TEACH_I0 = 1.0      # start injection
    TEACH_DECAY = 0.95  # per-epoch decay

    # speed
    MAX_TRAIN_BATCHES = 1000
    MAX_TEST_BATCHES = 50
    BATCH_TRAIN = 64
    BATCH_TEST = 256

    # data
    tfm = transforms.Compose([transforms.ToTensor()])
    root = os.environ.get("DATA", "./data")
    train_ds = datasets.MNIST(root=root, train=True, download=True, transform=tfm)
    test_ds = datasets.MNIST(root=root, train=False, download=True, transform=tfm)

    # Works whether DEVICE is a string ("cuda", "cuda:0") or a torch.device.
    use_pin_memory = str(DEVICE).startswith("cuda")

    train_loader = DataLoader(
        train_ds, batch_size=BATCH_TRAIN, shuffle=True,
        num_workers=0, pin_memory=use_pin_memory, drop_last=True
    )
    test_loader = DataLoader(
        test_ds, batch_size=BATCH_TEST, shuffle=False,
        num_workers=0, pin_memory=use_pin_memory
    )

    with torch.no_grad():
        # Mean raw pixel intensity of one training batch. This equals the
        # per-step spike probability only if the encoder fires with
        # probability = pixel value (assumed; the encoder is defined elsewhere).
        x0, _ = next(iter(train_loader))
        print(f"Encoder mean spike prob (over pixels, per step) ~ {float(x0.mean().item()):.6f}")

    # model hyperparams (keep fixed across integrators for fairness)
    model_kwargs = dict(
        v_th=0.3,
        tau=2.0,
        in_gain=5.0,
        out_gain=5.0,
        decay_input=False,
    )

    results = []
    for integ in integrators:
        for seed in seeds:
            res = run_rstdp_experiment(
                integrator=integ,
                seed=seed,
                dt=dt,
                T=T,
                epochs=epochs,
                lr_wout=LR_WOUT,
                reward_mode=REWARD_MODE,
                teach_i0=TEACH_I0,
                teach_decay=TEACH_DECAY,
                train_loader=train_loader,
                test_loader=test_loader,
                max_train_batches=MAX_TRAIN_BATCHES,
                max_test_batches=MAX_TEST_BATCHES,
                model_kwargs=model_kwargs,
            )
            results.append(res)

    print("\nResults:")
    for r in results:
        print(r)
    return results


if __name__ == "__main__":
    main()


DEVICE = cuda
Encoder mean spike prob (over pixels, per step) ~ 0.126272

=== LIF R-STDP integ=euler seed=0 dt=0.005 T=2000 ===
debug: h_count mean 10.512451171875 o_count mean 3.0140626430511475 H {'mean_spike_rate': 0.0052562255859375, 'silent_frac': 0.9947437744140625, 'active_frac': 0.0052562255859375} O {'mean_spike_rate': 0.00150703125, 'silent_frac': 0.99849296875, 'active_frac': 0.00150703125} no_spike_out 0.000
