In [1]:
import torch
class LatencyEncoder:
    """Latency (time-to-first-spike) encoder for a single 28x28 grayscale image.

    Every strictly positive pixel emits exactly one spike; brighter pixels fire
    earlier. For a normalised intensity ``p`` the spike step is
    ``int((1 - p) * (time - 1))``, so ``p == 1`` fires at step 0. Zero-valued
    pixels never fire.
    """

    def __init__(self, time: int = 100):
        self.time = time  # total number of simulation time steps

    def __call__(self, image: torch.Tensor) -> torch.Tensor:
        """Encode ``image`` into a spike raster.

        Args:
            image: tensor of shape [1, 28, 28] or [28, 28], with values in [0, 1]
                or in [0, 255] (divided by 255 when its maximum exceeds 1).

        Returns:
            Float tensor of shape [time, 1, 28, 28] with 1.0 at each spike,
            allocated on torch's default device. Earlier documentation claimed
            [time, 1, 784], but the returned view has always been 28x28; call
            ``.view(time, 1, 784)`` on the result for the flat layout.
        """
        if image.ndim == 3:
            image = image.squeeze()

        if image.max() > 1:
            image = image / 255.0

        # Vectorised replacement for the former per-pixel Python loop.
        # The top-left 28x28 region is read (as the old range(28) loops did),
        # moved to the CPU so the index tensors are host tensors, and promoted
        # to float64 so the spike times equal the old
        # ``int((1.0 - pixel.item()) * (time - 1))`` bit for bit.
        pixels = image.detach().cpu()[:28, :28].reshape(-1).to(torch.float64)
        active = (pixels > 0).nonzero(as_tuple=True)[0]
        # ``.long()`` truncates toward zero, exactly like ``int()``.
        spike_times = ((1.0 - pixels[active]) * (self.time - 1)).long()

        spike_tensor = torch.zeros((self.time, 1, 784))
        spike_tensor[spike_times, 0, active] = 1.0
        return spike_tensor.view(self.time, 1, 28, 28)


In [16]:
def set_stdp_nu(conn, nu_plus, nu_minus):
    """Overwrite the learning-rate pair of ``conn.update_rule``.

    ``conn.update_rule.nu`` becomes the tuple ``(nu_plus, nu_minus)`` of 0-D
    tensors placed on the same device as ``conn.w``. Passing zeros freezes STDP.

    NOTE(review): for BindsNET's PostPre, ``nu[0]`` is presumably the
    pre-synaptic (depression) rate and ``nu[1]`` the post-synaptic
    (potentiation) rate, both expected positive. A negative ``nu_minus`` in
    slot 1 would turn potentiation into depression. Confirm against the
    installed BindsNET version.
    """
    device = conn.w.device
    plus, minus = (torch.tensor(rate, device=device) for rate in (nu_plus, nu_minus))
    conn.update_rule.nu = (plus, minus)

In [31]:
_rate_ema = None  # running EMA of per-neuron firing rates, shared by all calls (module-global)


def adapt_thresholds_ema(layer, spike_counts, T, target=1.5, alpha=0.9, k=0.02, vt_min=0.15, vt_max=1.2):
    """EMA-based homeostasis: nudge each neuron's threshold by its smoothed rate error.

    ``rate = spike_counts / max(1, T)`` is smoothed into the module-global
    ``_rate_ema`` (``ema = alpha * ema + (1 - alpha) * rate``). The threshold
    then moves by ``k * (ema - target)`` and is clamped to ``[vt_min, vt_max]``.
    Writes ``layer.v_thresh`` if the layer has it, otherwise ``layer.thresh``
    (updated in place and re-assigned).

    The EMA is restarted from the current rate on the first call. It is also
    restarted whenever its shape or device no longer matches ``spike_counts``.
    Previously a new population size reused the stale EMA and raised a
    broadcasting error. Assign ``_rate_ema = None`` to restart explicitly
    between runs.

    NOTE(review): when ``spike_counts`` counts 0/1 spikes per step over ``T``
    steps, ``rate`` is at most 1. Callers in this notebook pass target=1.5 or
    2.0, so ``ema - target`` is always negative and thresholds can only fall
    toward ``vt_min``. If ``target`` is meant as spikes per window, compare
    against ``target / T`` (or smooth raw counts). Confirm the intent.

    Args:
        layer: neuron layer exposing a ``v_thresh`` or ``thresh`` tensor.
        spike_counts: per-neuron spike counts for the last window.
        T: window length in steps.
        target: target value compared with the smoothed rate.
        alpha: EMA decay (weight of the previous EMA).
        k: threshold step size.
        vt_min, vt_max: clamp range for the thresholds.
    """
    global _rate_ema
    with torch.no_grad():
        rate = spike_counts / max(1, T)
        stale = (
            _rate_ema is None
            or _rate_ema.shape != rate.shape
            or _rate_ema.device != rate.device
        )
        if stale:
            _rate_ema = rate.clone()
        _rate_ema = alpha * _rate_ema + (1 - alpha) * rate
        attr = "v_thresh" if hasattr(layer, "v_thresh") else "thresh"
        vt = getattr(layer, attr)
        vt += k * (_rate_ema - target)
        vt.clamp_(vt_min, vt_max)
        setattr(layer, attr, vt)

In [39]:
# --- –±–µ–∑–æ–ø–∞—Å–Ω–æ–µ –ø—Ä–∏—Å–≤–∞–∏–≤–∞–Ω–∏–µ –≤ –±—É—Ñ–µ—Ä—ã/–∞—Ç—Ä–∏–±—É—Ç—ã –º–æ–¥—É–ª—è ---
def _set_param(module, name, value, prefer_scalar=False):
    """
    –°—Ç–∞–≤–∏—Ç value –≤ module.<name>, —É—á–∏—Ç—ã–≤–∞—è, —á—Ç–æ —ç—Ç–æ –º–æ–∂–µ—Ç –±—ã—Ç—å –∑–∞—Ä–µ–≥–∏—Å—Ç—Ä–∏—Ä–æ–≤–∞–Ω–Ω—ã–π –±—É—Ñ–µ—Ä (Tensor).
    prefer_scalar=True -> –¥–µ–ª–∞–µ—Ç 0-D —Ç–µ–Ω–∑–æ—Ä (–¥–ª—è refrac).
    """
    if not hasattr(module, name):
        return False
    cur = getattr(module, name)

    # –ï—Å–ª–∏ —ç—Ç–æ Tensor-–±—É—Ñ–µ—Ä
    if isinstance(cur, torch.Tensor):
        if prefer_scalar:
            # –ù–∞–º –Ω—É–∂–µ–Ω 0-D, –∏–Ω–∞—á–µ masked_fill_ —É–ø–∞–¥—ë—Ç (–¥–ª—è refrac)
            val = torch.tensor(float(value), device=cur.device)
            setattr(module, name, val)  # –∑–∞–º–µ–Ω–∏—Ç—å –±—É—Ñ–µ—Ä –Ω–∞ 0-D —Ç–µ–Ω–∑–æ—Ä
            return True

        # –ò–Ω–∞—á–µ –∑–∞–ø–æ–ª–Ω—è–µ–º –ø–æ —Ñ–æ—Ä–º–µ —Ç–µ–∫—É—â–µ–≥–æ –±—É—Ñ–µ—Ä–∞
        if torch.is_tensor(value):
            if value.numel() == 1 and cur.numel() > 1:
                cur.data.fill_(float(value))
            else:
                if value.shape != cur.shape:
                    value = value.view_as(cur)
                cur.data.copy_(value.to(cur.device, dtype=cur.dtype))
        else:
            cur.data.fill_(float(value))
        return True

    # –ù–µ –±—É—Ñ–µ—Ä ‚Äî –æ–±—ã—á–Ω—ã–π –∞—Ç—Ä–∏–±—É—Ç: —Å—Ç–∞–≤–∏–º –∫–∞–∫ –µ—Å—Ç—å
    setattr(module, name, value if not prefer_scalar else float(value))
    return True


def tune_lif_params(lif_layer, n_hidden, vt_mean=0.35, vt_jitter=0.02, tau_val=50.0, refrac_val=2.0):
    """Write threshold, membrane time constant, refractory period and reset onto a LIF layer.

    Each quantity goes through ``_set_param`` under the first attribute name
    the layer actually has, because names differ between library versions.

    Args:
        lif_layer: LIF population to configure (modified in place).
        n_hidden: number of neurons; sizes the threshold and tau vectors.
        vt_mean, vt_jitter: per-neuron thresholds drawn as
            ``vt_mean + vt_jitter * randn``, clipped to [0.05, 2.0]
            (consumes torch's global RNG).
        tau_val: membrane time constant written to ``tau_m`` or ``tau``.
        refrac_val: refractory period, stored as a scalar.

    NOTE(review): if ``lif_layer`` is a BindsNET ``LIFNodes``, its time-constant
    buffer is presumably named ``tc_decay`` (and the resting potential
    ``rest``). In that case neither ``tau_m`` nor ``tau`` exists and
    ``tau_val`` is silently ignored. Confirm against the installed version.
    """

    def _set_first(candidates, value, **kwargs):
        # Write ``value`` to the first candidate attribute present on the layer.
        for attr in candidates:
            if _set_param(lif_layer, attr, value, **kwargs):
                return

    with torch.no_grad():
        # Thresholds: a per-neuron Tensor (not a float) with slight jitter.
        thresholds = (vt_mean + vt_jitter * torch.randn(n_hidden)).clamp(0.05, 2.0)
        _set_first(("v_thresh", "thresh"), thresholds)

        # Membrane time constant (attribute name varies across versions).
        _set_first(("tau_m", "tau"), torch.full((n_hidden,), tau_val))

        # The refractory period must be a SCALAR (0-D tensor or float).
        _set_param(lif_layer, "refrac", refrac_val, prefer_scalar=True)

        # Reset value: also a scalar.
        _set_first(("v_reset", "reset"), 0.0, prefer_scalar=True)


# 🧠 Spiking Neural Network (SNN) на базе BindsNET с обучением через STDP

В этом эксперименте реализована простая биологически правдоподобная спайковая нейросеть (SNN) для обработки изображений MNIST.

## 📌 Архитектура сети:
- **Входной слой (`Input`, 784 нейрона)** — по одному нейрону на каждый пиксель изображения 28×28.
- **Poisson-кодировщик** — преобразует яркость пикселей в вероятностные временные спайки (альтернатива — латентный кодировщик, выбирается через `cfg.encoder`: каждый ненулевой пиксель даёт один спайк, и чем ярче пиксель, тем раньше этот спайк).
- **Полносвязный слой (`Connection`)** — соединяет вход с выходом (матрица весов 784 × 100).
- **Выходной слой (`LIF`, 100 нейронов)** — Leaky Integrate-and-Fire нейроны с утечкой и порогом.
- **STDP (Spike-Timing Dependent Plasticity)** — обучение без градиентов; веса усиливаются, если вход активен до выходного спайка.

## 🔬 Что делается:
1. Загружается изображение MNIST (в `run_experiment` шаги 1–3 повторяются для каждого из `cfg.N` примеров, после прогрева без STDP).
2. Кодируется в Poisson-спайковый поток (или в латентный — в зависимости от `cfg.encoder`).
3. Пропускается через сеть:
   - `Input` получает входные спайки,
   - `LIF` нейроны активируются в зависимости от весов.
4. Сохраняются:
   - Спайковая активность `LIF`-нейронов до и после подачи входа.
   - Сумма спайков на входе (`Input`) — показывает, какие пиксели активны.
   - Веса одного выбранного `LIF`-нейрона — до и после STDP.

## 📈 Визуализация:
- График: сравнение спайковой активности нейронов до/после + входные спайки.
- График: изменение весов, ведущих к нейрону `LIF[42]` — видно, как STDP усиливает значимые связи.

## 🎯 Цель:
Показать, как SNN:
- преобразует изображение в поток спайков,
- активирует только специфичные нейроны,
- адаптирует веса на основе временных шаблонов (STDP), без использования обратного распространения ошибки.


In [40]:
# ====== SETUP (Colab/Local) ======
# !pip -q install bindsnet==0.2.8 torchvision==0.18.1 torch==2.3.1 --extra-index-url https://download.pytorch.org/whl/cu121

import os, itertools, random, csv, time as _ptime
from dataclasses import dataclass, asdict

import torch
import numpy as np
import matplotlib.pyplot as plt

# ====== Utils ======
def set_seed(seed=42):
    """Seed Python's ``random``, NumPy's global RNG and PyTorch's RNG with the same value."""
    for seeder in (random.seed, np.random.seed, torch.manual_seed):
        seeder(seed)

def to_2d(s):  # [T,N] or [T,1,N] -> [T,N]
    """Drop a singleton batch axis: [T, 1, N] -> [T, N].

    Any other tensor (2-D, or 3-D with batch size != 1) is returned unchanged.
    """
    has_single_batch = s.dim() == 3 and s.size(1) == 1
    if not has_single_batch:
        return s
    return s[:, 0, :]

# ====== Config ======
@dataclass
class Cfg:
    """Experiment configuration consumed by build_net, run_experiment and grid_run.

    NOTE(review): in the code visible in this notebook, ``inh_decay``,
    ``target_spikes``, ``eta_up``, ``eta_down``, ``w_clip_min``,
    ``w_clip_max``, ``w_col_target_norm``, ``wmin`` and ``wmax`` are not read
    by build_net or run_experiment. ``nu_plus`` / ``nu_minus`` only seed the
    PostPre rule before run_experiment overwrites the rates, so tuning them
    there has no effect.
    """
    # core
    time:   int = 200                 # simulation steps per sample
    n_hidden: int = 100               # number of LIF neurons
    nu_plus:  float = 0.02
    nu_minus: float = -0.02

    # inhibition / WTA
    inhib_strength: float = 0.3       # magnitude of LIF->LIF lateral inhibition in build_net
    inh_decay: float = 0.9
    top_k: int = 0                    # 0 = WTA off for diagnostics
    enable_inhibition_at_start: bool = False

    # encoder
    encoder: str = "latency"          # "latency" or "poisson" (see make_encoder)

    # homeostasis
    target_spikes: float = 2.0
    eta_up: float = 1.0
    eta_down: float = 0.5
    thresh_min: float = 0.2           # clamp range used by adapt_thresholds
    thresh_max: float = 2.0
    thresh_init: float = 0.5          # v_thresh initial (BindsNET positive scale)

    # weights
    w_clip_min: float = 0.0
    w_clip_max: float = 1.5
    w_col_target_norm: float = 20.0
    w_init_lo: float = 0.8            # uniform init range of the Input->LIF weights
    w_init_hi: float = 1.2
    wmin: float = 0.0
    wmax: float = 2.0

    # loop
    N: int = 200                      # samples in the main (STDP) loop
    log_every: int = 50
    seed: int = 42

    # Warm-up samples run with STDP disabled. run_experiment already reads this
    # via getattr(cfg, "warmup_N", 50). Declaring it makes it settable through
    # Cfg(...) and grid_run. It is appended last so positional construction is
    # unaffected.
    warmup_N: int = 50

# ====== Helpers: WTA, norm, thresholds, plots, metrics ======
def apply_wta(s, top_k=1):
    """Hard top-k winner-take-all over a spike tensor, applied in place.

    Neurons are ranked by their spike count summed over the first axis of
    ``to_2d(s)`` (time, for a [T, N] or [T, 1, N] raster). ``s`` is then
    zeroed and the ``top_k`` winners are marked as spiking at every position
    of that axis.

    Args:
        s: spike tensor [T, N] or [T, 1, N]; modified in place.
        top_k: number of winners (capped at the number of neurons).

    Returns:
        ``(False, None)`` if ``s`` has no spikes (``s`` is left untouched),
        otherwise ``(True, winners)`` with winner indices ordered by decreasing
        spike count.
    """
    # reshape(-1) instead of squeeze(): with a single neuron, squeeze() made the
    # counts 0-D, and topk/tolist then did not yield a list of indices.
    counts = to_2d(s).sum(0).float().reshape(-1)
    if counts.sum() == 0:
        return False, None
    _, idxs = torch.topk(counts, k=min(top_k, counts.numel()))
    winners = idxs.tolist()
    s.zero_()
    if s.dim() == 3:
        s[:, 0, idxs] = True
    else:
        s[:, idxs] = True
    return True, winners

def weight_soft_bound_and_colnorm(conn_w, w_clip_min, w_clip_max, target_norm):
    """Clamp weights to [w_clip_min, w_clip_max], then rescale each column to L1 norm ~target_norm.

    Works in place on ``conn_w.data``. Columns are reductions over dim 0 (one
    per target neuron, assuming ``conn_w`` is [n_source, n_target]). The 1e-6
    term keeps all-zero columns from dividing by zero.
    """
    with torch.no_grad():
        weights = conn_w.data
        weights.clamp_(w_clip_min, w_clip_max)
        l1_per_column = weights.norm(p=1, dim=0, keepdim=True) + 1e-6
        scale = target_norm / l1_per_column
        weights.mul_(scale)

def adapt_thresholds(layer, spike_counts, cfg: Cfg):
    """Step-wise homeostasis on ``layer.v_thresh`` (or ``layer.thresh``).

    Neurons with fewer than 1 spike get their threshold lowered by 0.05.
    Neurons with more than 3 spikes get it raised by 0.02. The result is
    clamped to ``[cfg.thresh_min, cfg.thresh_max]``, updated in place and
    re-assigned.
    """
    attr = "v_thresh" if hasattr(layer, "v_thresh") else "thresh"
    with torch.no_grad():
        thresholds = getattr(layer, attr)
        too_quiet = (spike_counts < 1.0).float()
        too_busy = (spike_counts > 3.0).float()
        thresholds -= 0.05 * too_quiet   # silent -> easier to fire
        thresholds += 0.02 * too_busy    # over-active -> harder to fire
        thresholds.clamp_(cfg.thresh_min, cfg.thresh_max)
        setattr(layer, attr, thresholds)

def spiking_metrics_window(lif_s, winners=None):
    """Summarise how the spikes of one simulation window are spread over neurons.

    Args:
        lif_s: spike raster [T, N] or [T, 1, N].
        winners: optional list of WTA winner indices for this window.

    Returns:
        dict with keys:
          - ``T``, ``N``, ``total_spikes``;
          - ``active``: neurons with at least one spike;
          - ``HHI``: sum of squared spike shares (1/N for an even spread, 1 when
            one neuron has every spike);
          - ``Gini``: 0 for an even spread, (N-1)/N when one neuron has every
            spike;
          - ``uniq_winners``.
        A silent window reports HHI = Gini = 1.0.
    """
    s = to_2d(lif_s).to(torch.bool)
    T, N = s.shape
    per_n = s.sum(0)
    tot = int(per_n.sum())
    active = int((per_n > 0).sum())
    if tot > 0:
        p = (per_n / tot).float().cpu().numpy()
        HHI = float((p**2).sum())
        # Gini for ascending shares x_1..x_n with cumulative sums c_i:
        #   G = (n + 1 - 2 * sum(c_i) / c_n) / n
        # The previous expression (sum(c_i)/c_n - (n+1)/2) / n equals -G/2,
        # i.e. it was negative and half the size of the Gini coefficient.
        ps = np.sort(p.astype(np.float64))
        cum = np.cumsum(ps)
        n = len(ps)
        Gini = float((n + 1 - 2 * cum.sum() / cum[-1]) / n)
    else:
        HHI, Gini = 1.0, 1.0
    uniq_winners = len(set(winners)) if winners else 0
    return dict(T=T, N=N, total_spikes=tot, active=active, HHI=HHI, Gini=Gini, uniq_winners=uniq_winners)

class SNNMeter:
    """Accumulate per-sample spike / synaptic-operation counts and WTA winner usage.

    ``report()`` turns the totals into per-sample averages plus a weighted
    energy proxy ``a*S_out + b*SynOps + c*V_updates``.
    """

    def __init__(self):
        self.reset()

    def reset(self):
        """Zero every counter and forget winner usage."""
        self.samples = 0
        self.S_out = 0
        self.S_in = 0
        self.SynOps = 0
        self.V_updates = 0
        self.usage_counts = {}

    def log_sample(self, lif_s, in_s, n_hidden, T, winners=None):
        """Add one window's output/input spike totals and optional winner indices."""
        out_spikes = int(to_2d(lif_s).sum().item())
        in_spikes = int(to_2d(in_s).sum().item())
        self.S_out += out_spikes
        self.S_in += in_spikes
        # Each input spike is counted as reaching all n_hidden neurons.
        self.SynOps += in_spikes * n_hidden
        # One membrane update per hidden neuron per time step.
        self.V_updates += n_hidden * T
        self.samples += 1
        for j in (winners or ()):
            self.usage_counts[j] = self.usage_counts.get(j, 0) + 1

    def report(self, a=1.0, b=0.05, c=0.005):
        """Return per-sample averages, the energy proxy and winner-usage statistics."""
        denom = max(1, self.samples)
        winner_hhi = 0.0
        if self.usage_counts:
            counts = np.fromiter(self.usage_counts.values(), dtype=float)
            shares = counts / counts.sum()
            winner_hhi = float(np.square(shares).sum())
        energy = (a*self.S_out + b*self.SynOps + c*self.V_updates)/denom
        return {
            "spikes_per_sample": self.S_out/denom,
            "synops_per_sample": self.SynOps/denom,
            "v_updates_per_sample": self.V_updates/denom,
            "energy_proxy_per_sample": energy,
            "winners_unique": len(self.usage_counts),
            "winner_HHI": winner_hhi,
        }

# ====== Build Net & Encoder ======
from bindsnet.network import Network
from bindsnet.network.nodes import Input, LIFNodes
from bindsnet.network.topology import Connection
from bindsnet.learning import PostPre
from torchvision import transforms
from bindsnet.datasets import MNIST
from bindsnet.network.monitors import Monitor

def build_net(cfg: Cfg):
    """Build the Input(784) -> LIF(cfg.n_hidden) network with STDP and lateral inhibition.

    Returns:
        (net, input_layer, lif_layer, connection, recurrent_inh, W_inh):
          - ``connection``: the plastic Input->LIF projection with a PostPre rule.
          - ``recurrent_inh``: the LIF->LIF inhibitory projection, zeroed unless
            ``cfg.enable_inhibition_at_start``.
          - ``W_inh``: the un-zeroed inhibition template (-cfg.inhib_strength
            off the diagonal, 0 on it).

    NOTE(review): the thresholds written by ``tune_lif_params`` (vt_mean=0.35
    with jitter) are overwritten a few lines later by a constant
    ``cfg.thresh_init`` vector, so the threshold part of that call has no
    lasting effect. Its ``randn`` draw still consumes the RNG.
    NOTE(review): for BindsNET's PostPre, ``nu`` is presumably (pre-synaptic /
    depression rate, post-synaptic / potentiation rate), both positive. A
    negative ``cfg.nu_minus`` in slot 1 would turn potentiation into
    depression. Confirm before relying on it.
    """
    net = Network()

    input_layer = Input(n=784, traces=True)
    lif_layer   = LIFNodes(n=cfg.n_hidden, traces=True)
    # Per-neuron thresholds / refractory / reset (see NOTE above about thresholds).
    tune_lif_params(lif_layer, cfg.n_hidden, vt_mean=0.35, vt_jitter=0.02, tau_val=50.0, refrac_val=2.0)

    net.add_layer(input_layer, name='Input')
    net.add_layer(lif_layer,   name='LIF')


    # Dense Input->LIF projection trained with PostPre STDP.
    connection = Connection(source=input_layer, target=lif_layer)
    connection.update_rule = PostPre(connection=connection,
                                 nu=(torch.tensor(cfg.nu_plus),
                                     torch.tensor(cfg.nu_minus)))
    net.add_connection(connection, source='Input', target='LIF')

    # Lateral inhibition (created, but optionally disabled at start)
    W_inh = torch.full((cfg.n_hidden, cfg.n_hidden), -cfg.inhib_strength)
    W_inh.fill_diagonal_(0.0)  # no self-inhibition
    recurrent_inh = Connection(source=lif_layer, target=lif_layer, w=W_inh.clone())
    net.add_connection(recurrent_inh, source='LIF', target='LIF')

    # Weights init (stronger to ignite)
    with torch.no_grad():
        connection.w.data.uniform_(cfg.w_init_lo, cfg.w_init_hi)

    # Thresholds: use v_thresh if available (overrides tune_lif_params' thresholds)
    th0 = torch.full((cfg.n_hidden,), cfg.thresh_init)
    if hasattr(lif_layer, "v_thresh"): lif_layer.v_thresh = th0.clone()
    else: lif_layer.thresh = th0.clone()

    # Optionally disable inhibition at start (for diagnostics); W_inh itself stays intact
    if not cfg.enable_inhibition_at_start:
        with torch.no_grad():
            recurrent_inh.w.zero_()

    return net, input_layer, lif_layer, connection, recurrent_inh, W_inh

# Pre-processing pipeline for arbitrary (non-MNIST) images: grayscale -> 28x28 -> float tensor in [0, 1].
# NOTE(review): the original comment listed PIL / np / tensor inputs; torchvision's ToTensor
# documents PIL Image / numpy.ndarray inputs only — confirm which input types actually work.
preprocess = transforms.Compose(
    transforms=[
        transforms.Grayscale(num_output_channels=1),
        transforms.Resize((28, 28), antialias=True),
        transforms.ToTensor(),  # -> [1, 28, 28] float in [0, 1]
        # Deliberately no Normalize(mean, std): the spike encoders need raw 0..1 intensities.
    ]
)

def make_encoder(encoder_type: str, T: int, rate_floor: float = 0.0):  # floor defaults to 0 (disabled)
    """Build a spike encoder for single-channel images.

    Args:
        encoder_type: "poisson" or "latency" (case-insensitive).
        T: number of simulation time steps per sample.
        rate_floor: minimum per-step spike probability for the Poisson encoder,
            applied after the rate scaling. 0 (the default) disables it. This
            parameter used to be accepted but ignored.

    Returns:
        ``(encode_fn, preprocess)``:
          - ``encode_fn`` maps an image tensor ([1, H, W] or [H, W], values in
            [0, 1]) to a float32 spike raster [T, 1, H*W]; for MNIST this is
            [T, 1, 784].
          - ``preprocess`` is the module-level torchvision pipeline for
            arbitrary images.
    """
    encoder_type = encoder_type.lower()
    assert encoder_type in ("poisson", "latency")

    def encode_poisson(img_tensor):
        # Rate coding: each pixel spikes independently at every step with
        # probability intensity * RATE_SCALE.
        rates = img_tensor.reshape(-1).clamp(0, 1)
        # Keep the stream sparse so the network is not flooded.
        RATE_SCALE = 0.2
        rates = rates * RATE_SCALE
        if rate_floor > 0:
            # Optional floor so even black pixels fire occasionally.
            rates = rates.clamp(min=rate_floor)
        rand = torch.rand((T, rates.numel()))
        return (rand < rates).float().view(T, 1, rates.numel())

    def encode_latency(img_tensor):
        # Time-to-first-spike coding: one spike per non-zero pixel at step
        # round((1 - p) * (T - 1)) (brighter -> earlier). When T >= 3, a ±1-step
        # jitter is added and the result clipped to the window.
        flat = img_tensor.squeeze(0).clamp(0, 1).detach().cpu().reshape(-1)
        spikes = torch.zeros((T, 1, flat.numel()), dtype=torch.float32)
        idx = (flat > 0).nonzero(as_tuple=True)[0]
        if idx.numel() == 0:
            return spikes
        # float64 + torch.round (round-half-to-even) reproduce
        # int(round(...)) on Python floats exactly.
        t = ((1.0 - flat[idx].to(torch.float64)) * (T - 1)).round().long()
        if T >= 3:
            # One vectorised draw instead of one randint call per pixel. Each
            # pixel's jitter is still uniform over {-1, 0, +1}; only the order
            # in which the RNG is consumed differs from the old loop.
            t = (t + torch.randint(-1, 2, (idx.numel(),))).clamp(0, T - 1)
        spikes[t, 0, idx] = 1.0
        return spikes

    return (encode_poisson if encoder_type == "poisson" else encode_latency), preprocess
        
def _to_2d(s):  # [T, B, N] -> [T, N]
    return s[:, 0, :] if s.dim()==3 else s
    
# ====== One Experiment ======
def run_experiment(cfg: Cfg, verbose=True):
    """Run one unsupervised STDP experiment on the first ``cfg.N`` MNIST samples.

    Phases:
      1. Warm-up on ``cfg.warmup_N`` samples (50 if the config has no such
         attribute) with STDP disabled. Only EMA threshold homeostasis runs.
      2. Main loop over ``cfg.N`` samples:
           - small fixed STDP rates;
           - optional top-k WTA winner selection (``cfg.top_k > 0``);
           - homeostasis;
           - a weight clamp to [0, 1];
           - energy and winner-usage accounting via ``SNNMeter``.

    Args:
        cfg: experiment configuration.
        verbose: print progress every ``cfg.log_every`` samples and the
            checkpoint diagnostics at samples 0, 50 and 150.

    Returns:
        dict: ``asdict(cfg)`` merged with ``SNNMeter.report()``.
    """
    global _rate_ema
    set_seed(cfg.seed)
    # adapt_thresholds_ema keeps its EMA in the module-global ``_rate_ema``;
    # clear it so a previous run (e.g. inside grid_run) cannot leak state here.
    _rate_ema = None

    # Dataset (MNIST through ToTensor is already [1, 28, 28] in [0, 1]).
    transform = transforms.Compose([transforms.ToTensor()])
    dataset = MNIST(root='./data', train=True, download=True, transform=transform)

    # Network + monitors.
    net, input_layer, lif_layer, connection, recurrent_inh, _ = build_net(cfg)
    with torch.no_grad():
        # Soft lateral inhibition: -0.2 off the diagonal.
        # NOTE(review): this overrides cfg.inhib_strength and
        # cfg.enable_inhibition_at_start, so those settings (and grid_run's
        # inhib_strength axis) currently have no effect.
        recurrent_inh.w.fill_(0.0)
        recurrent_inh.w -= 0.2 * (1 - torch.eye(cfg.n_hidden))
    lif_mon = Monitor(lif_layer, state_vars=("s",), time=cfg.time)
    inp_mon = Monitor(input_layer, state_vars=("s",), time=cfg.time)
    net.add_monitor(lif_mon, name="lif_mon")
    net.add_monitor(inp_mon, name="inp_mon")

    def _reset_all():
        # Clear network state and monitor recordings between samples.
        net.reset_state_variables()
        lif_mon.reset_state_variables()
        inp_mon.reset_state_variables()

    # Encoder.
    encoder, _ = make_encoder(cfg.encoder, cfg.time)

    # --------- WARM-UP (no STDP) ---------
    warmup = getattr(cfg, "warmup_N", 50)
    if warmup > 0:
        set_stdp_nu(connection, 0.0, 0.0)  # STDP frozen during warm-up
        for wi in range(min(warmup, len(dataset))):
            spike_input = encoder(dataset[wi]["image"])
            net.run(inputs={"Input": spike_input}, time=cfg.time)

            # Threshold adaptation already during warm-up.
            spike_counts = to_2d(lif_mon.get("s")).sum(0).float().squeeze()
            adapt_thresholds_ema(lif_layer, spike_counts, cfg.time, target=2.0)

            _reset_all()

    # --------- MAIN LOOP ---------
    meter = SNNMeter()
    # Gentle STDP after warm-up (previously 1e-3 / -5e-4).
    # NOTE(review): fixed rates; cfg.nu_plus / cfg.nu_minus are not used here.
    set_stdp_nu(connection, 2e-4, -1e-4)

    for i in range(cfg.N):
        image = dataset[i]["image"]
        spike_input = encoder(image)
        net.run(inputs={"Input": spike_input}, time=cfg.time)

        # Full raster over the window.
        lif_s_full = lif_mon.get("s")   # [T,B,N]
        in_s_full  = inp_mon.get("s")   # [T,B,784]

        # Checkpoint diagnostics. Previously these printed even with
        # verbose=False, and both window sums were printed twice.
        if verbose and i in (0, 50, 150):
            print("INPUT spikes sum (window):", int(_to_2d(in_s_full).sum().item()))
            print("LIF   spikes sum (window):",   int(_to_2d(lif_s_full).sum().item()))
            vt = (lif_layer.v_thresh if hasattr(lif_layer,'v_thresh') else lif_layer.thresh)
            print("v_thresh mean¬±std:", float(vt.mean()), float(vt.std()))
            print("w[min,max]:", float(connection.w.min()), float(connection.w.max()))

        # Top-k WTA (if enabled). apply_wta ranks neurons by spikes summed over
        # its first axis, so it gets the window raster [T, 1, N]. It receives a
        # copy because it rewrites its input in place, and homeostasis below
        # needs the raw spikes. It used to receive lif_layer.s (presumably only
        # the last step's spikes), and silent results skipped the whole sample,
        # including homeostasis and the meter.
        winners = None
        if cfg.top_k and cfg.top_k > 0:
            _, winners = apply_wta(lif_s_full.clone(), top_k=cfg.top_k)
            # A silent window yields winners=None but is still processed below.

        # Window metrics.
        m = spiking_metrics_window(lif_s_full, winners)
        if verbose and ((i+1) % cfg.log_every == 0 or i == 0):
            print(f"[{i+1}] total={m['total_spikes']} active={m['active']}/{m['N']} HHI={m['HHI']:.3f}")

        # Homeostasis on the raw (pre-WTA) spikes.
        spike_counts = to_2d(lif_s_full).sum(0).float()
        adapt_thresholds_ema(lif_layer, spike_counts, cfg.time, target=1.5)

        # Weight clamp (no aggressive per-step column normalisation).
        # NOTE(review): hard-coded [0, 1]; cfg.w_clip_min / w_clip_max are not used.
        with torch.no_grad():
            connection.w.clamp_(0.0, 1.0)

        # Energy / bookkeeping.
        meter.log_sample(lif_s_full, in_s_full, cfg.n_hidden, cfg.time, winners=winners)

        _reset_all()

    rpt = meter.report()
    if getattr(meter, "samples", 0) == 0:
        print("!! meter: no samples logged ‚Äî –ø—Ä–æ–≤–µ—Ä—å –ø–æ—Ä—è–¥–æ–∫ log_sample()/reset() –∏ continue –≤ —Ü–∏–∫–ª–µ")
    return {**asdict(cfg), **rpt}

# ====== Grid Runner (compact) ======
def grid_run(base: Cfg):
    """Run ``run_experiment`` over a small hyper-parameter grid and save the results as CSV.

    Every grid key must be a ``Cfg`` field; each combination is applied on top
    of ``asdict(base)``. The encoder axis previously used a ``use_latency`` key.
    That key is not a ``Cfg`` field, so ``Cfg(**...)`` raised ``TypeError`` on
    the very first run. The axis now varies the real ``encoder`` field.

    NOTE(review): run_experiment currently hard-codes the lateral inhibition
    strength, so the ``inhib_strength`` axis does not change the results.

    Args:
        base: configuration supplying every non-grid field.

    Returns:
        list[dict]: one result dict per run, also written to
        ``out/snn_energy_accuracy_grid.csv``. (Previously returned None.)
    """
    grid = {
        "inhib_strength": [0.3, 0.5],
        "top_k": [0, 3],
        "time": [200, 300],
        "encoder": ["poisson", "latency"],
    }
    summary_keys = ["spikes_per_sample", "winners_unique", "winner_HHI", "energy_proxy_per_sample"]
    keys, vals = zip(*grid.items())
    base_fields = asdict(base)
    results = []
    t0 = _ptime.time()
    for combo in itertools.product(*vals):
        cfg = Cfg(**{**base_fields, **dict(zip(keys, combo))})
        print(">>> run:", {k: getattr(cfg, k) for k in keys})
        res = run_experiment(cfg, verbose=False)
        print({k: res[k] for k in summary_keys})
        results.append(res)

    if not results:
        return results

    os.makedirs("out", exist_ok=True)
    csv_path = os.path.join("out", "snn_energy_accuracy_grid.csv")
    with open(csv_path, "w", newline="", encoding="utf-8") as f:
        writer = csv.DictWriter(f, fieldnames=list(results[0].keys()))
        writer.writeheader()
        writer.writerows(results)
    print(f"Saved: {csv_path} | runs={len(results)} | elapsed={_ptime.time()-t0:.1f}s")

    # Quick "Pareto-ish" ranking. It sorts lexicographically by energy, then by
    # winner concentration (HHI), then prefers more distinct winners. It is not
    # a true Pareto front.
    best = sorted(results, key=lambda r: (r["energy_proxy_per_sample"], r["winner_HHI"], -r["winners_unique"]))[:5]
    print("\nTop-5 Pareto-ish:")
    for r in best:
        print({k: r[k] for k in [*keys, *summary_keys]})
    return results


In [41]:
# NOTE(review): run_experiment re-sets the STDP rates via set_stdp_nu (0/0 during
# warm-up, then 2e-4/-1e-4) and installs -0.2 lateral inhibition itself, so the
# nu_plus / nu_minus / enable_inhibition_at_start values below do not reach its main loop.
cfg = Cfg(
    time=300,
    n_hidden=100,
    encoder="latency",                 # time-to-first-spike coding ("poisson" is the alternative)
    top_k=0,                           # WTA off for diagnostics
    enable_inhibition_at_start=False,  # inhibition to be enabled later
    nu_plus=0.002,
    nu_minus=-0.001,
)
res = run_experiment(cfg, verbose=True)
print("\nSUMMARY:", {key: res[key] for key in ("spikes_per_sample", "winners_unique", "winner_HHI", "energy_proxy_per_sample")})


RuntimeError: shape '[]' is invalid for input of size 100