In [None]:
import torch, torch.nn as nn, pytorch_lightning as pl
import snntorch.functional as SF
from sklearn.metrics import accuracy_score
from models import ConvConv
# ---------------- constants -------------
# Number of simulation steps: passed as ``num_steps`` to the spike encoder and
# used as the loop bound in ``SpikingConvConv.forward``.
T = 16
# Base handed to the ``SSPCBlock`` constructors as ``base_in`` / ``base_out``
# (the first block uses 2.0 as its input base instead).
Q = 1.3
# Starting value of each ``SSPCBlock`` threshold buffer ``Vth``.
TH_SHIFT = (Q + 1) / (2 * Q)

# -------- 1. Conv2d/Linear → SSPC wrapper --------
class SSPCBlock(nn.Module):
    """Fire-once spiking replica of a ``Conv2d`` or ``Linear`` layer.

    The wrapped layer's weights are rescaled by ``max_in / max_out`` and its
    bias by ``1 / max_out``; both are stored as frozen parameters.  The layer
    is applied as a dense map: a 4-D kernel is contracted against the *whole*
    input, so the input's ``(C, H, W)`` must equal the kernel's
    ``(in, kH, kW)``.  A 2-D weight is applied to the flattened input.  Either
    way the result has shape ``(batch, out)``.

    Every call to :meth:`forward` is one time step.  The block keeps running
    state in three buffers (``v``, ``fired``, ``Vth``); ``v`` and ``fired``
    grow a leading batch dimension on the first step.  Call
    :meth:`reset_state` before feeding an unrelated sequence or batch.

    Args:
        layer: Module exposing ``weight`` (2-D or 4-D) and ``bias`` (may be
            ``None``, in which case a zero bias is used).
        base_in: Factor applied to the membrane potential before each update.
        base_out: Together with ``base_in`` sets the per-step threshold
            factor ``base_in / base_out``.
        max_in: Scalar or per-input-channel normaliser (tensor or number).
        max_out: Scalar or per-output-channel normaliser (tensor or number).
        th_shift: Initial firing threshold.  Defaults to the module-level
            ``TH_SHIFT`` constant.
    """

    def __init__(self, layer, base_in, base_out, max_in, max_out, th_shift=None):
        super().__init__()
        if th_shift is None:
            th_shift = TH_SHIFT
        self.th_shift = float(th_shift)

        weight = layer.weight.data
        n_out = weight.size(0)
        # (1, 1) for 4-D conv kernels, () for 2-D linear weights, so the
        # scale broadcasts over the weight instead of adding dimensions to it.
        trailing = (1,) * (weight.dim() - 2)
        max_in = torch.as_tensor(max_in, dtype=weight.dtype, device=weight.device)
        max_out = torch.as_tensor(max_out, dtype=weight.dtype, device=weight.device)

        # rescale weights
        scale = max_in.reshape(1, -1, *trailing) / max_out.reshape(-1, 1, *trailing)
        w = weight * scale
        if layer.bias is None:
            b = torch.zeros(n_out, dtype=weight.dtype, device=weight.device)
        else:
            b = layer.bias.data / max_out.reshape(-1)

        self.W = nn.Parameter(w, requires_grad=False)
        self.b = nn.Parameter(b, requires_grad=False)
        self.base_in, self.base_out = base_in, base_out
        self.register_buffer("Vth", torch.full((n_out,), self.th_shift))
        self.register_buffer("fired", torch.zeros(n_out, dtype=torch.bool))
        self.register_buffer("v", torch.zeros(n_out))

    def reset_state(self):
        """Restore ``v``, ``fired`` and ``Vth`` to their initial values."""
        # ``Vth`` is only ever scaled in place, so it keeps the (out,) shape
        # and serves as the template for the other two buffers.
        self.v = torch.zeros_like(self.Vth)
        self.fired = torch.zeros_like(self.Vth, dtype=torch.bool)
        self.Vth = torch.full_like(self.Vth, self.th_shift)

    def forward(self, x_spk):
        """Advance one time step.

        Args:
            x_spk: Input of shape ``(batch, in, kH, kW)`` for a 4-D weight, or
                anything that flattens to ``(batch, in)`` for a 2-D weight.

        Returns:
            Float tensor of shape ``(batch, out)`` holding 1.0 where a neuron
            crossed its threshold for the first time on this step, else 0.0.
        """
        n_out = self.W.size(0)
        if self.fired.all():           # no more computation after first spike
            # Same shape and dtype as the regular return value below.
            return torch.zeros(x_spk.size(0), n_out, device=x_spk.device)

        if self.W.dim() == 2:
            current = x_spk.flatten(1) @ self.W.t()
        else:
            # Conv2d weights are (out, in, kH, kW): kH pairs with input H.
            current = torch.einsum('oihw,bihw->bo', self.W, x_spk)

        self.v = self.v * self.base_in + current + self.b
        s = (self.v >= self.Vth) & ~self.fired
        # Out-of-place: the (batch, out) mask cannot be OR-ed in place into
        # the initial (out,) buffer.
        self.fired = self.fired | s
        self.Vth *= self.base_in / self.base_out
        return s.float()

# -------- 2. Spiking ConvConv -------------
class SpikingConvConv(pl.LightningModule):
    """Spiking replica of ``ConvConv`` built from ``SSPCBlock`` layers.

    Six layers of a freshly constructed ``ConvConv`` are wrapped in
    ``SSPCBlock`` using pre-computed activation maxima; a final analog
    ``nn.Linear`` maps the last block's spikes to class logits.

    Args:
        input_shape, num_labels, num_conv_filters, size: Forwarded unchanged
            to ``ConvConv``; ``num_labels`` also sets the output width.
        num_hops: Not used here beyond being recorded as a hyperparameter.
        learning_rate: Learning rate for the Adam optimizer.
        max_act_path: File read with ``torch.load``; must hold a mapping with
            keys ``in0``, ``out0`` .. ``out4`` and ``fc0``.
    """

    def __init__(self, input_shape, num_labels, num_conv_filters, size,
                 num_hops=10, learning_rate=1e-4, max_act_path='max_act.pt'):
        super().__init__()
        self.save_hyperparameters()

        # ---- original layers for calibration only -------
        self.orig = ConvConv(input_shape, num_labels, num_conv_filters, size)

        # ---- gather max_act statistics (offline) -----
        # The layers above were just built on the CPU; load the statistics
        # there too so the weight rescaling never mixes devices.
        max_act = torch.load(max_act_path, map_location='cpu')   # assume pre-computed

        # ---- build spiking replicas ----
        self.block1 = SSPCBlock(self.orig.conv, 2.0, Q,
                                max_act['in0'], max_act['out0'])
        self.block2 = SSPCBlock(self.orig.temp_conv[0], Q, Q,
                                max_act['out0'], max_act['out1'])
        self.block3 = SSPCBlock(self.orig.temp_conv[1], Q, Q,
                                max_act['out1'], max_act['out2'])
        self.block4 = SSPCBlock(self.orig.temp_conv[3], Q, Q,
                                max_act['out2'], max_act['out3'])
        self.block5 = SSPCBlock(self.orig.temp_conv[4], Q, Q,
                                max_act['out3'], max_act['out4'])
        self.fc1    = SSPCBlock(self.orig.fc[0], Q, Q,
                                max_act['out4'], max_act['fc0'])
        # last layer stays analog; its input width follows the wrapped layer
        # instead of a fixed 128
        self.fc2    = nn.Linear(self.orig.fc[0].weight.size(0), num_labels, bias=True)
        self.criterion = nn.CrossEntropyLoss()

    def _spiking_blocks(self):
        """Return every ``SSPCBlock`` in the order the forward pass uses them."""
        return (self.block1, self.block2, self.block3,
                self.block4, self.block5, self.fc1)

    def _reset_spiking_state(self):
        """Put each block's ``v``, ``fired`` and ``Vth`` buffers back to their initial values."""
        for block in self._spiking_blocks():
            # ``Vth`` is only scaled in place, so it still has the (out,)
            # shape the buffers were registered with.
            block.v = torch.zeros_like(block.Vth)
            block.fired = torch.zeros_like(block.Vth, dtype=torch.bool)
            block.Vth = torch.full_like(block.Vth, TH_SHIFT)

    @staticmethod
    def _as_conv_input(spk):
        """Lift ``(B, C)`` spikes to ``(B, C, 1, 1)``; pass 4-D input through."""
        if spk.dim() == 2:
            return spk[:, :, None, None]
        return spk

    # ---- forward over all phases -------
    def forward(self, x):
        """Run ``T`` time steps and return the time-averaged logits.

        Args:
            x: Input batch.  # assumes (B, seq, feat, 1) — TODO confirm

        Returns:
            Tensor of shape ``(B, num_labels)``: the mean over time of the
            analog read-out applied to the last block's spikes.
        """
        # Block state lives in buffers that persist across calls; without a
        # reset, spikes from the previous batch would silence this one.
        self._reset_spiking_state()

        # NOTE(review): could not confirm that ``snntorch.functional`` exposes
        # ``rate_to_binary``; snnTorch's documented rate encoder is
        # ``snntorch.spikegen.rate``, which returns time-first tensors.
        # Verify this call and the indexing below.
        spikes = SF.rate_to_binary(x, num_steps=T)             # (B,T,seq,feat,1)
        conv_blocks = (self.block1, self.block2, self.block3,
                       self.block4, self.block5)
        logits = 0
        for t in range(T):
            spk = spikes[:, t].permute(0, 3, 1, 2)            # (B,1,seq,feat)
            for block in conv_blocks:
                # Each block returns (B, C); restore (B,C,H,W) with H = W = 1
                # before the next one.
                # NOTE(review): this only contracts with 1x1 kernels in
                # temp_conv — confirm against ConvConv.
                spk = block(self._as_conv_input(spk))
            spk = self.fc1(spk.flatten(1))
            logits = logits + self.fc2(spk)                   # accumulate
        return logits / T                                     # mean firing score

    def training_step(self, batch, _):
        """Compute and log the cross-entropy loss for one ``(x, y)`` batch."""
        x, y = batch
        logit = self(x)
        loss = self.criterion(logit, y)
        self.log("loss", loss)
        return loss

    def configure_optimizers(self):
        """Return an Adam optimizer over this module's parameters."""
        return torch.optim.Adam(self.parameters(),
                                 lr=self.hparams.learning_rate)