In [None]:
min_weight, max_weight = -0.7, 0.7
num_weight_levels = 24

# Setup and dataset

In [None]:
import math
import random
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torch, warnings
from torch.amp import autocast, GradScaler
from torch.cuda.amp import autocast, GradScaler
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
from torch.amp import autocast, GradScaler

In [None]:
mean = (0.5071, 0.4867, 0.4408)
std  = (0.2675, 0.2565, 0.2761)

train_tf = transforms.Compose([
    transforms.RandomCrop(32, padding=4),
    transforms.RandomHorizontalFlip(),
    transforms.RandAugment(num_ops=2, magnitude=9),
    transforms.ToTensor(),
    transforms.Normalize(mean, std),
    transforms.RandomErasing(p=0.25, scale=(0.02, 0.2), ratio=(0.3, 3.3), value=0),
])

test_tf = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(mean, std),
])

train_set = datasets.CIFAR100(root="./data", train=True, download=True, transform=train_tf)
test_set  = datasets.CIFAR100(root="./data", train=False, download=True, transform=test_tf)

batch_size = 16
train_loader = DataLoader(train_set, batch_size=batch_size, shuffle=True,
                          num_workers=2, pin_memory=True)
test_loader  = DataLoader(test_set, batch_size=batch_size, shuffle=False,
                          num_workers=2, pin_memory=True)

# Network

In [None]:
class CurrentConv2d(nn.Module):
    def __init__(self, in_ch, out_ch, k, stride=1, padding=0, bias=False, width_scale=1.0):
        super().__init__()
        self.conv = nn.Conv2d(in_ch, out_ch, k, stride=stride, padding=padding, bias=bias)
        self.width_scale = float(width_scale)

    def forward(self, x):
        # x interpreted as pulse-width signal scaling current
        return self.conv(x * self.width_scale)


class CurrentLinear(nn.Module):
    def __init__(self, in_f, out_f, bias=True, width_scale=1.0):
        super().__init__()
        self.fc = nn.Linear(in_f, out_f, bias=bias)
        self.width_scale = float(width_scale)

    def forward(self, x):
        return self.fc(x * self.width_scale)

In [None]:
class WRNBlock(nn.Module):
    def __init__(self, in_ch, out_ch, stride, dropout=0.0):
        super().__init__()
        self.bn1 = nn.BatchNorm2d(in_ch)
        self.conv1 = CurrentConv2d(in_ch, out_ch, 3, stride=stride, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(out_ch)
        self.conv2 = CurrentConv2d(out_ch, out_ch, 3, stride=1, padding=1, bias=False)
        self.dropout = float(dropout)

        self.shortcut = nn.Identity()
        if stride != 1 or in_ch != out_ch:
            self.shortcut = nn.Conv2d(in_ch, out_ch, kernel_size=1, stride=stride, bias=False)

    def forward(self, x):
        out = F.relu(self.bn1(x), inplace=True)
        out = self.conv1(out)
        out = F.relu(self.bn2(out), inplace=True)
        if self.dropout > 0:
            out = F.dropout(out, p=self.dropout, training=self.training)
        out = self.conv2(out)
        return out + self.shortcut(x)


class WideResNet(nn.Module):
    def __init__(self, depth=28, widen_factor=10, dropout=0.0, num_classes=100):
        super().__init__()
        assert (depth - 4) % 6 == 0
        n = (depth - 4) // 6
        k = widen_factor
        ch = [16, 16*k, 32*k, 64*k]

        self.conv1 = CurrentConv2d(3, ch[0], 3, stride=1, padding=1, bias=False)
        self.block1 = self._make_group(ch[0], ch[1], n, stride=1, dropout=dropout)
        self.block2 = self._make_group(ch[1], ch[2], n, stride=2, dropout=dropout)
        self.block3 = self._make_group(ch[2], ch[3], n, stride=2, dropout=dropout)
        self.bn = nn.BatchNorm2d(ch[3])
        self.pool = nn.AdaptiveAvgPool2d(1)
        self.fc = CurrentLinear(ch[3], num_classes, bias=True)

        # He init
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode="fan_out", nonlinearity="relu")

    def _make_group(self, in_ch, out_ch, n, stride, dropout):
        layers = [WRNBlock(in_ch, out_ch, stride=stride, dropout=dropout)]
        for _ in range(1, n):
            layers.append(WRNBlock(out_ch, out_ch, stride=1, dropout=dropout))
        return nn.Sequential(*layers)

    def forward(self, x):
        x = self.conv1(x)
        x = self.block1(x)
        x = self.block2(x)
        x = self.block3(x)
        x = F.relu(self.bn(x), inplace=True)
        x = self.pool(x).flatten(1)
        x = self.fc(x)
        return x

In [None]:
def maskW(w: torch.Tensor) -> torch.Tensor:
    w = torch.clamp(w, min=min_weight, max=max_weight)
    step = (max_weight - min_weight) / (num_weight_levels - 1)
    return ((w - min_weight) / step).round() * step + min_weight

def maskW_ste(w: torch.Tensor) -> torch.Tensor:
    # Forward: quantized weights
    w_q = maskW(w)
    # Backward: pretend quantization was identity (straight-through estimator)
    return w + (w_q - w).detach()

def one_hot(y, num_classes, device):
    return F.one_hot(y, num_classes=num_classes).float().to(device)

def mixup_cutmix(x, y, num_classes, mixup_alpha=0.2, cutmix_alpha=1.0, p=1.0):
    if random.random() > p:
        return x, one_hot(y, num_classes, x.device)

    use_cutmix = random.random() < 0.5
    if use_cutmix:
        lam = torch.distributions.Beta(cutmix_alpha, cutmix_alpha).sample().item()
        rand_index = torch.randperm(x.size(0), device=x.device)
        y_a = one_hot(y, num_classes, x.device)
        y_b = one_hot(y[rand_index], num_classes, x.device)

        bbx1, bby1, bbx2, bby2 = rand_bbox(x.size(), lam)
        x[:, :, bby1:bby2, bbx1:bbx2] = x[rand_index, :, bby1:bby2, bbx1:bbx2]
        lam = 1.0 - ((bbx2 - bbx1) * (bby2 - bby1) / (x.size(-1) * x.size(-2)))
        y_mix = lam * y_a + (1 - lam) * y_b
        return x, y_mix
    else:
        lam = torch.distributions.Beta(mixup_alpha, mixup_alpha).sample().item()
        rand_index = torch.randperm(x.size(0), device=x.device)
        x_mix = lam * x + (1 - lam) * x[rand_index]
        y_a = one_hot(y, num_classes, x.device)
        y_b = one_hot(y[rand_index], num_classes, x.device)
        y_mix = lam * y_a + (1 - lam) * y_b
        return x_mix, y_mix

def rand_bbox(size, lam):
    # size: (B, C, H, W)
    W = size[3]
    H = size[2]
    cut_rat = math.sqrt(1.0 - lam)
    cut_w = int(W * cut_rat)
    cut_h = int(H * cut_rat)
    cx = random.randint(0, W - 1)
    cy = random.randint(0, H - 1)
    bbx1 = max(cx - cut_w // 2, 0)
    bby1 = max(cy - cut_h // 2, 0)
    bbx2 = min(cx + cut_w // 2, W)
    bby2 = min(cy + cut_h // 2, H)
    return bbx1, bby1, bbx2, bby2

def soft_ce(logits, y_soft):
    logp = F.log_softmax(logits, dim=1)
    return -(y_soft * logp).sum(dim=1).mean()

In [None]:
k: float = 1.38e-23   # Boltzmann constant (J/K)
q: float = 1.602176634e-19
T_r: float = 300
n: float = 1.5
k_mu: float = 1.5
k_vt: float = 1e-3 # Typical value for threshold voltage temperature dependence (in V/K)
Tr = 300

def thermal_voltage(T):
    return k * T / q

def Id_subthreshold(W, L, mu, Cox, Vth, VGS, VDS, T, n):
    Vt = thermal_voltage(T)
    Is0 = (W / L) * mu * Cox * (Vt**2) * np.exp(1.8)
    return Is0 * np.exp((VGS - Vth) / (n * Vt)) * (1 - np.exp(-VDS / Vt))

class SubthresholdPVTNoise(nn.Module):
    def __init__(self):
        super().__init__()
        """
        params: dict with
            W, L, mu, Cox, Vth, VGS, VDS, T, n
        sigmas: dict with
            W, L, mu, Cox, Vth, VDS (normal std)
            dT (uniform half-range)
        """
        #self.p = params
        self.W   = 44e-9
        self.L   = 22e-9
        self.mu  = 0.03
        self.Cox = 0.03
        self.Vth = 0.35
        self.VGS = 0.25
        self.VDS = 0.8
        self.T0  = 310.0
        self.n   = 1.5
        self.dId_dW_value = 0.0
        self.dId_dL_value = 0.0
        self.dId_dmu_value = 0.0
        self.dId_dCox_value = 0.0
        self.dId_dVth_value = 0.0
        self.dId_dVDS_value = 0.0
        self.dId_dT_value = 0.0
        self.s = {
            "W":   0.10 * self.W,
            "L":   0.10 * self.L,
            "mu":  0.10 * self.mu,
            "Cox": 0.10 * self.Cox,
            "Vth": 0.10 * self.Vth,
            "VDS": 0.10 * self.VDS,
            "T_min": 248,
            "T_max": 398,
        }

    def dId_dW(self):
        Vt = thermal_voltage(self.T0)
        return (self.mu * self.Cox * Vt**2 / self.L) * np.exp(1.8) * \
            np.exp((self.VGS - self.Vth) / (self.n * Vt)) * (1 - np.exp(-self.VDS / Vt))

    def dId_dmu(self):
        Vt = thermal_voltage(self.T0)
        return (self.W * self.Cox * Vt**2 / self.L) * np.exp(1.8) * \
            np.exp((self.VGS - self.Vth) / (self.n * Vt)) * (1 - np.exp(-self.VDS / Vt))

    def dId_dCox(self):
        Vt = thermal_voltage(self.T0)
        return (self.W * self.mu * Vt**2 / self.L) * np.exp(1.8) * \
            np.exp((self.VGS - self.Vth) / (self.n * Vt)) * (1 - np.exp(-self.VDS / Vt))

    def dId_dL(self):
        Vt = thermal_voltage(self.T0)
        return -(self.W * self.mu * self.Cox * Vt**2 / self.L**2) * np.exp(1.8) * \
                np.exp((self.VGS - self.Vth) / (self.n * Vt)) * (1 - np.exp(-self.VDS / Vt))
    
    def dId_dVDS(self):
        Vt = thermal_voltage(self.T0)
        Is0 = (self.W / self.L) * self.mu * self.Cox * (Vt**2) * np.exp(1.8)
        return Is0 * np.exp((self.VGS - self.Vth) / (self.n * Vt)) * (np.exp(-self.VDS / Vt) / Vt)
    
    def dId_dVth(self):
        Id = Id_subthreshold(self.W, self.L, self.mu, self.Cox, self.Vth, self.VGS, self.VDS, self.T0, self.n)
        Vt = thermal_voltage(self.T0)
        return (-1 / (self.n * Vt)) * Id * np.exp((self.VGS - self.Vth) / (self.n * Vt)) * \
            (1 - np.exp(-self.VDS / Vt))

    def dId_dT(self):
        Vt = thermal_voltage(self.T0)
        Id = Id_subthreshold(self.W, self.L, self.mu, self.Cox, self.Vth, self.VGS, self.VDS, self.T0, self.n)

        term1 = k_mu - (( k_vt * self.T0) / (self.n * Vt)) * Id
        term2 = (1 - np.exp(-self.VDS / Vt)) * Tr**(-k_mu) * self.T0**(k_mu - 1)
        term3 = np.exp((self.VGS - self.Vth * Tr - k_vt * (self.T0-Tr)) / (self.n * Vt))
        return Id * (term1 + term2 + term3)

    def update_params(self):
        self.dId_dW_value = self.dId_dW()
        self.dId_dL_value = self.dId_dL()
        self.dId_dmu_value = self.dId_dmu()
        self.dId_dCox_value = self.dId_dCox()
        self.dId_dVth_value = self.dId_dVth()
        self.dId_dVDS_value = self.dId_dVDS()
        self.dId_dT_value = self.dId_dT()

    def forward(self, Id0):
        """
        Id0: torch.Tensor (any shape)
        returns: noisy Id tensor (same shape)
        """
        device, dtype = Id0.device, Id0.dtype
        self.update_params()

        # ---- sample parameter variations ----
        dW   = torch.normal(0.0, self.s["W"],   size=Id0.shape, device=device, dtype=dtype)
        dL   = torch.normal(0.0, self.s["L"],   size=Id0.shape, device=device, dtype=dtype)
        dmu  = torch.normal(0.0, self.s["mu"],  size=Id0.shape, device=device, dtype=dtype)
        dCox = torch.normal(0.0, self.s["Cox"], size=Id0.shape, device=device, dtype=dtype)
        dVth = torch.normal(0.0, self.s["Vth"], size=Id0.shape, device=device, dtype=dtype)
        dVDS = torch.normal(0.0, self.s["VDS"], size=Id0.shape, device=device, dtype=dtype)
        dT   = torch.empty_like(Id0).uniform_(self.s["T_min"], self.s["T_max"])
        dT  = dT - self.T0

        # dW = 0
        # dL = 0
        # dmu = 0
        # dCox = 0    
        # dVth = 0
        # dVDS = 0
        # dT = 0

        # ---- analytical partial derivatives ----
        dId = (
            self.dId_dW_value           * dW   +
            self.dId_dL_value           * dL   +
            self.dId_dmu_value          * dmu  +
            self.dId_dCox_value         * dCox +
            self.dId_dVth_value         * dVth +
            self.dId_dVDS_value * dVDS +
            self.dId_dT_value * dT
        )

        return Id0 + dId


In [None]:
class QATConv2d_no_other_noise(nn.Conv2d):
    def forward(self, x):
        w = maskW_ste(self.weight) if self.training else maskW(self.weight)
        return F.conv2d(x, w, self.bias, self.stride, self.padding, self.dilation, self.groups)

class QATLinear_no_other_noise(nn.Linear):
    def forward(self, x):
        w = maskW_ste(self.weight) if self.training else maskW(self.weight)
        return F.linear(x, w, self.bias)

In [None]:
class QATConv2d(nn.Conv2d):
    def __init__(self, *args, noise=None, **kwargs):
        super().__init__(*args, **kwargs)
        # Register as a submodule so it moves to GPU + appears in state_dict
        self.noise = noise if noise is not None else SubthresholdPVTNoise()
        self.noise.update_params()

    def forward(self, x):
        w = maskW_ste(self.weight) if self.training else maskW(self.weight)

        # Add noise in forward pass (typically only during training)
        if self.training:
            w = self.noise(w)   # calls SubthresholdPVTNoise.forward(w)

        return F.conv2d(x, w, self.bias, self.stride, self.padding, self.dilation, self.groups)


class QATLinear(nn.Linear):
    def __init__(self, *args, noise=None, **kwargs):
        super().__init__(*args, **kwargs)
        self.noise = noise if noise is not None else SubthresholdPVTNoise()
        self.noise.update_params()

    def forward(self, x):
        w = maskW_ste(self.weight) if self.training else maskW(self.weight)
        if self.training:
            w = self.noise(w)
        return F.linear(x, w, self.bias)


In [None]:
class SoftLIFRate(nn.Module):
    def __init__(self, theta=0.0, gain=1.0, tau_rc=0.02, tau_ref=0.002, eps=1e-4):
        super().__init__()
        self.theta = nn.Parameter(torch.tensor(float(theta)))
        self.gain  = nn.Parameter(torch.tensor(float(gain)))
        self.tau_rc  = float(tau_rc)
        self.tau_ref = float(tau_ref)
        self.eps = float(eps)

    def forward(self, z):
        # Force the critical math to FP32 to avoid AMP/FP16 inf->nan gradients
        z32 = z.float()
        u = (self.gain.float()) * (z32 - self.theta.float())

        # J = 1 + softplus(u) >= 1
        Jm1 = F.softplus(u)                       # this is (J-1)
        Jm1 = torch.clamp(Jm1, min=self.eps)      # prevents divide-by-0 / inf grads

        # log(1 + 1/(J-1))
        denom = self.tau_ref + self.tau_rc * torch.log1p(1.0 / Jm1)
        rate  = 1.0 / denom

        r = rate * self.tau_ref                   # normalize to [0,1]
        r = torch.clamp(r, 0.0, 1.0)

        return r.to(dtype=z.dtype)                # restore original dtype

In [None]:
class WRNBlock_Rate_noQuant(nn.Module):
    def __init__(self, in_ch, out_ch, stride, dropout=0.0, act=None):
        super().__init__()
        self.bn1 = nn.BatchNorm2d(in_ch)
        self.act1 = act if act is not None else SoftLIFRate()
        self.conv1 = CurrentConv2d(in_ch, out_ch, 3, stride=stride, padding=1, bias=False)

        self.bn2 = nn.BatchNorm2d(out_ch)
        self.act2 = act if act is not None else SoftLIFRate()
        self.conv2 = CurrentConv2d(out_ch, out_ch, 3, stride=1, padding=1, bias=False)

        self.dropout = float(dropout)

        self.shortcut = nn.Identity()
        if stride != 1 or in_ch != out_ch:
            self.shortcut = nn.Conv2d(in_ch, out_ch, kernel_size=1, stride=stride, bias=False)

    def forward(self, x):
        out = self.act1(self.bn1(x))
        out = self.conv1(out)

        out = self.act2(self.bn2(out))
        if self.dropout > 0:
            out = F.dropout(out, p=self.dropout, training=self.training)

        out = self.conv2(out)
        return out + self.shortcut(x)
class WideResNet_SoftLIFRate_noQuant(nn.Module):
    def __init__(self, depth=28, widen_factor=10, dropout=0.0, num_classes=100,
                 in_channels=3, act_theta=0.0, act_gain=1.0, tau_rc=0.02, tau_ref=0.002):
        super().__init__()
        assert (depth - 4) % 6 == 0
        n = (depth - 4) // 6
        k = widen_factor
        ch = [16, 16*k, 32*k, 64*k]

        # one shared config, but separate modules get created per block below
        def make_act():
            return SoftLIFRate(theta=act_theta, gain=act_gain, tau_rc=tau_rc, tau_ref=tau_ref)

        self.conv1 = CurrentConv2d(in_channels, ch[0], 3, stride=1, padding=1, bias=False)
        self.block1 = self._make_group(ch[0], ch[1], n, stride=1, dropout=dropout, make_act=make_act)
        self.block2 = self._make_group(ch[1], ch[2], n, stride=2, dropout=dropout, make_act=make_act)
        self.block3 = self._make_group(ch[2], ch[3], n, stride=2, dropout=dropout, make_act=make_act)

        self.bn = nn.BatchNorm2d(ch[3])
        self.act_out = make_act()
        self.pool = nn.AdaptiveAvgPool2d(1)
        self.fc = CurrentLinear(ch[3], num_classes, bias=True)

    def _make_group(self, in_ch, out_ch, n, stride, dropout, make_act):
        layers = [WRNBlock_Rate_noQuant(in_ch, out_ch, stride=stride, dropout=dropout, act=make_act())]
        for _ in range(1, n):
            layers.append(WRNBlock_Rate_noQuant(out_ch, out_ch, stride=1, dropout=dropout, act=make_act()))
        return nn.Sequential(*layers)

    def forward(self, x):
        x = self.conv1(x)
        x = self.block1(x)
        x = self.block2(x)
        x = self.block3(x)
        x = self.act_out(self.bn(x))
        x = self.pool(x).flatten(1)
        x = self.fc(x)
        return x

In [None]:
class WRNBlock_Rate(nn.Module):
    def __init__(self, in_ch, out_ch, stride, dropout=0.0, act=None):
        super().__init__()
        self.bn1 = nn.BatchNorm2d(in_ch)
        self.act1 = act if act is not None else SoftLIFRate()
        self.conv1 = QATConv2d(in_ch, out_ch, 3, stride=stride, padding=1, bias=False)

        self.bn2 = nn.BatchNorm2d(out_ch)
        self.act2 = act if act is not None else SoftLIFRate()
        self.conv2 = QATConv2d(out_ch, out_ch, 3, stride=1, padding=1, bias=False)
        self.dropout = float(dropout)

        self.shortcut = nn.Identity()
        if stride != 1 or in_ch != out_ch:
            self.shortcut = QATConv2d(in_ch, out_ch, kernel_size=1, stride=stride, bias=False)

    def forward(self, x):
        out = self.act1(self.bn1(x))
        out = self.conv1(out)

        out = self.act2(self.bn2(out))
        if self.dropout > 0:
            out = F.dropout(out, p=self.dropout, training=self.training)

        out = self.conv2(out)
        return out + self.shortcut(x)
class WideResNet_SoftLIFRate(nn.Module):
    def __init__(self, depth=16, widen_factor=10, dropout=0.0, num_classes=100,
                 in_channels=3, act_theta=0.0, act_gain=1.0, tau_rc=0.02, tau_ref=0.002):
        super().__init__()
        assert (depth - 4) % 6 == 0
        n = (depth - 4) // 6
        k = widen_factor
        ch = [16, 16*k, 32*k, 64*k]

        # one shared config, but separate modules get created per block below
        def make_act():
            return SoftLIFRate(theta=act_theta, gain=act_gain, tau_rc=tau_rc, tau_ref=tau_ref)

        self.conv1 = QATConv2d(in_channels, ch[0], 3, stride=1, padding=1, bias=False)
        self.block1 = self._make_group(ch[0], ch[1], n, stride=1, dropout=dropout, make_act=make_act)
        self.block2 = self._make_group(ch[1], ch[2], n, stride=2, dropout=dropout, make_act=make_act)
        self.block3 = self._make_group(ch[2], ch[3], n, stride=2, dropout=dropout, make_act=make_act)

        self.bn = nn.BatchNorm2d(ch[3])
        self.act_out = make_act()
        self.pool = nn.AdaptiveAvgPool2d(1)
        self.fc = QATLinear(ch[3], num_classes, bias=True)

    def _make_group(self, in_ch, out_ch, n, stride, dropout, make_act):
        layers = [WRNBlock_Rate(in_ch, out_ch, stride=stride, dropout=dropout, act=make_act())]
        for _ in range(1, n):
            layers.append(WRNBlock_Rate(out_ch, out_ch, stride=1, dropout=dropout, act=make_act()))
        return nn.Sequential(*layers)

    def forward(self, x):
        x = self.conv1(x)
        x = self.block1(x)
        x = self.block2(x)
        x = self.block3(x)
        x = self.act_out(self.bn(x))
        x = self.pool(x).flatten(1)
        x = self.fc(x)
        return x

In [None]:
@torch.no_grad()
def clamp_weights_(model, lo=min_weight, hi=max_weight):
    # Clamp only "synaptic" weights (Conv/Linear). Leave BN alone.
    for m in model.modules():
        if isinstance(m, (nn.Conv2d, nn.Linear)):
            if m.weight is not None:
                m.weight.clamp_(lo, hi)
            if getattr(m, "bias", None) is not None:
                m.bias.clamp_(lo, hi)

# Training

In [None]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

model = WideResNet(depth=16, widen_factor=10, dropout=0.0, num_classes=100).to(device)

# Label smoothing helps CIFAR-100; if using soft labels (mixup), weâ€™ll use soft CE anyway.
label_smoothing = 0.1
criterion_hard = nn.CrossEntropyLoss(label_smoothing=label_smoothing)

optimizer = optim.SGD(model.parameters(), lr=0.1, momentum=0.9, weight_decay=5e-4, nesterov=True)
epochs = 200
scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=epochs)

use_amp = (device == "cuda")
scaler = GradScaler(device, enabled=use_amp)
warnings.filterwarnings("ignore", category=FutureWarning, message="`torch.cuda.amp.*` is deprecated.*")
print("autocast module:", autocast.__module__)
print("GradScaler module:", GradScaler.__module__)

In [None]:
@torch.no_grad()
def evaluate(model, loader):
    model.eval()
    correct = 0
    total = 0
    loss_sum = 0.0
    for x, y in loader:
        x, y = x.to(device, non_blocking=True), y.to(device, non_blocking=True)
        logits = model(x)
        loss = criterion_hard(logits, y)
        loss_sum += float(loss) * x.size(0)
        pred = logits.argmax(dim=1)
        correct += int((pred == y).sum())
        total += y.numel()
    return loss_sum / total, correct / total

def train_one_epoch(model, loader, use_mix=True):
    model.train()
    running_loss = 0.0
    correct = 0
    total = 0

    for x, y in loader:
        x, y = x.to(device, non_blocking=True), y.to(device, non_blocking=True)
        optimizer.zero_grad(set_to_none=True)

        if use_mix:
            x, y_soft = mixup_cutmix(x, y, num_classes=100, mixup_alpha=0.2, cutmix_alpha=1.0, p=1.0)
        else:
            y_soft = None

        with autocast(enabled=use_amp, device_type=device.type):
            logits = model(x)
            if y_soft is None:
                loss = criterion_hard(logits, y)
            else:
                loss = soft_ce(logits, y_soft)

        scaler.scale(loss).backward()
        scaler.step(optimizer)
        clamp_weights_(model, -0.7, 0.7)
        scaler.update()

        running_loss += float(loss) * x.size(0)
        pred = logits.argmax(dim=1)
        correct += int((pred == y).sum())
        total += y.numel()
        print(f"\r  Batch train loss: {running_loss / total:.4f}", end="")

    return running_loss / total, correct / total


In [None]:
import os
os.makedirs("checkpoints", exist_ok=True)

In [None]:
weights_load = torch.load("checkpoints/bp4snn/digLif/epoch_049_acc_24.78.pt")
model.load_state_dict(weights_load["state_dict"], strict=False) 

In [None]:
best_acc = 0.0
for ep in range(1, epochs + 1):
    tr_loss, tr_acc = train_one_epoch(model, train_loader, use_mix=True)
    te_loss, te_acc = evaluate(model, test_loader)
    scheduler.step()
    best_acc = max(best_acc, te_acc)
    print(f"Epoch {ep:3d} | train loss {tr_loss:.4f} acc {tr_acc*100:5.2f}%"
            f" | test loss {te_loss:.4f} acc {te_acc*100:5.2f}%")

    if ep % 10 == 0 or ep == 1:
        print(f"Epoch {ep:3d} | train loss {tr_loss:.4f} acc {tr_acc*100:5.2f}%"
              f" | test loss {te_loss:.4f} acc {te_acc*100:5.2f}% | best {best_acc*100:5.2f}%")
    ckpt_path = f"checkpoints/snntorch/softLIF/{ep}_acc_{te_acc}.pt"
    torch.save({"state_dict": model.state_dict(), "test_acc": te_acc}, ckpt_path)
print("Done. Best test accuracy:", best_acc * 100, "%")

# Generating spiking

In [339]:
@torch.no_grad()
def collect_bn_z_stats(relu_model, loader, device, num_batches=50, sample_per_layer=200_000):
    relu_model.eval()
    samples = {}

    def make_hook(name):
        def hook(mod, inp, out):
            # out is BN output = z (pre-ReLU)
            z = out.detach().float().flatten()
            # random subsample to keep memory sane
            if z.numel() > sample_per_layer:
                idx = torch.randint(0, z.numel(), (sample_per_layer,), device=z.device)
                z = z[idx]
            z = z.cpu()
            samples.setdefault(name, []).append(z)
        return hook

    hooks = []
    for name, m in relu_model.named_modules():
        if isinstance(m, nn.BatchNorm2d):
            hooks.append(m.register_forward_hook(make_hook(name)))

    # drive a few batches through
    for bi, (x, _) in enumerate(loader):
        if bi >= num_batches:
            break
        x = x.to(device)
        _ = relu_model(x)

    for h in hooks:
        h.remove()

    # compute percentiles + init params
    stats = {}
    for name, chunks in samples.items():
        z = torch.cat(chunks, dim=0)
        q10, q50, q90 = torch.quantile(z, torch.tensor([0.1, 0.5, 0.9]))
        dz = (q90 - q10).item()
        theta = q50.item()
        gain = 4.4 / (dz + 1e-6)         # your rule
        gain = float(max(min(gain, 50.0), 0.1))  # clamp to avoid insanity
        stats[name] = dict(theta=theta, gain=gain, q10=q10.item(), q50=q50.item(), q90=q90.item())
    return stats


In [340]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
relu_model = WideResNet(depth=16, widen_factor=10, dropout=0.0, num_classes=100).to(device)
weights_relu = torch.load("checkpoints/snntorch/softLIF/17_acc_0.1884.pt")
relu_model.load_state_dict(weights_relu["state_dict"], strict=True)  # adapt to your checkpoint
relu_loss, relu_acc = evaluate(relu_model, test_loader)
print(f"FP32   test acc: {relu_acc*100:.2f}% | loss: {relu_loss:.4f}")
bn_stats = collect_bn_z_stats(relu_model, train_loader, device, num_batches=50)

  weights_relu = torch.load("checkpoints/snntorch/softLIF/17_acc_0.1884.pt")


FP32   test acc: 18.84% | loss: 4.2049


In [341]:
soft_model = WideResNet_SoftLIFRate(depth=16, widen_factor=10, dropout=0.0, num_classes=100).to(device)
soft_model.load_state_dict(relu_model.state_dict(), strict=False)  # activations are new

_IncompatibleKeys(missing_keys=['conv1.weight', 'block1.0.act1.theta', 'block1.0.act1.gain', 'block1.0.conv1.weight', 'block1.0.act2.theta', 'block1.0.act2.gain', 'block1.0.conv2.weight', 'block1.1.act1.theta', 'block1.1.act1.gain', 'block1.1.conv1.weight', 'block1.1.act2.theta', 'block1.1.act2.gain', 'block1.1.conv2.weight', 'block2.0.act1.theta', 'block2.0.act1.gain', 'block2.0.conv1.weight', 'block2.0.act2.theta', 'block2.0.act2.gain', 'block2.0.conv2.weight', 'block2.1.act1.theta', 'block2.1.act1.gain', 'block2.1.conv1.weight', 'block2.1.act2.theta', 'block2.1.act2.gain', 'block2.1.conv2.weight', 'block3.0.act1.theta', 'block3.0.act1.gain', 'block3.0.conv1.weight', 'block3.0.act2.theta', 'block3.0.act2.gain', 'block3.0.conv2.weight', 'block3.1.act1.theta', 'block3.1.act1.gain', 'block3.1.conv1.weight', 'block3.1.act2.theta', 'block3.1.act2.gain', 'block3.1.conv2.weight', 'act_out.theta', 'act_out.gain', 'fc.weight', 'fc.bias'], unexpected_keys=['conv1.conv.weight', 'block1.0.conv1.

In [342]:
def softlif_rate01(z, theta, gain, tau_rc=0.02, tau_ref=0.002, eps=1e-4):
    z = z.float()
    theta = torch.tensor(theta, device=z.device, dtype=torch.float32)
    gain  = torch.tensor(gain,  device=z.device, dtype=torch.float32)

    u = gain * (z - theta)
    Jm1 = F.softplus(u)                 # = J-1
    Jm1 = torch.clamp(Jm1, min=eps)
    denom = tau_ref + tau_rc * torch.log1p(1.0 / Jm1)
    rate = 1.0 / denom
    r = rate * tau_ref
    return torch.clamp(r, 0.0, 1.0)

@torch.no_grad()
def solve_gain_for_layer(z10, z90, theta,
                         tau_rc=0.02, tau_ref=0.002,
                         r_hi=0.90, g_min=0.05, g_max=20.0, iters=30):
    """
    Choose gain so that r(z90) ~= r_hi (keeps layer from saturating too early).
    Smaller gain = smoother (more linear), bigger gain = more gate-like.
    """
    z90_t = torch.tensor([z90], device="cpu")

    lo, hi = g_min, g_max
    for _ in range(iters):
        mid = 0.5 * (lo + hi)
        r90 = softlif_rate01(z90_t, theta, mid, tau_rc, tau_ref).item()
        # if too saturated at z90, gain is too large
        if r90 > r_hi:
            hi = mid
        else:
            lo = mid
    return 0.5 * (lo + hi)

@torch.no_grad()
def clamp_softlif_params(model, theta_min=-5.0, theta_max=5.0, gain_min=0.05, gain_max=20.0):
    for m in model.modules():
        if isinstance(m, SoftLIFRate):
            m.theta.data.clamp_(theta_min, theta_max)
            m.gain.data.clamp_(gain_min, gain_max)

In [343]:
def init_softlif_from_bn_stats(
    soft_model,
    bn_stats,
    *,
    theta_key_prefer=("q70", "q50", "theta"),  # try q70 -> q50 -> precomputed theta
    r_hi=0.90,                                 # keep z90 from saturating too hard
    tau_rc=0.02,
    tau_ref=0.002,
    gain_min=0.05,
    gain_max=20.0,
):
    """
    Initializes SoftLIF params from BN stats but in a way that avoids early saturation.

    Expected bn_stats[name] to contain at least q10/q90 and (ideally) q70 or q50.
    If your bn_stats already stores 'theta'/'gain', those are used only as fallback.
    Requires solve_gain_for_layer(...) to be defined in your notebook.
    """

    def pick_theta(d):
        for k in theta_key_prefer:
            if k in d:
                return float(d[k])
        raise KeyError(f"bn_stats entry missing theta candidates: {theta_key_prefer}")

    def pick_q10_q90(d):
        if "q10" in d and "q90" in d:
            return float(d["q10"]), float(d["q90"])
        # fallback if older dict uses different keys
        if "z10" in d and "z90" in d:
            return float(d["z10"]), float(d["z90"])
        raise KeyError("bn_stats entry must include q10/q90 (or z10/z90).")

    # map module object -> its qualified name
    mod2name = {m: n for n, m in soft_model.named_modules()}

    # initialize per-block activations from the BN that feeds them
    for m in soft_model.modules():
        if isinstance(m, WRNBlock):
            # bn1 -> act1
            n_bn1 = mod2name[m.bn1]
            if n_bn1 in bn_stats:
                d = bn_stats[n_bn1]
                q10, q90 = pick_q10_q90(d)
                theta = pick_theta(d)
                gain = solve_gain_for_layer(
                    z10=q10, z90=q90, theta=theta,
                    tau_rc=tau_rc, tau_ref=tau_ref,
                    r_hi=r_hi, g_min=gain_min, g_max=gain_max
                )
                m.act1.theta.data.fill_(theta)
                m.act1.gain.data.fill_(gain)

            # bn2 -> act2
            n_bn2 = mod2name[m.bn2]
            if n_bn2 in bn_stats:
                d = bn_stats[n_bn2]
                q10, q90 = pick_q10_q90(d)
                theta = pick_theta(d)
                gain = solve_gain_for_layer(
                    z10=q10, z90=q90, theta=theta,
                    tau_rc=tau_rc, tau_ref=tau_ref,
                    r_hi=r_hi, g_min=gain_min, g_max=gain_max
                )
                m.act2.theta.data.fill_(theta)
                m.act2.gain.data.fill_(gain)

    # final BN -> final activation
    n_final_bn = mod2name[soft_model.bn]
    if n_final_bn in bn_stats:
        d = bn_stats[n_final_bn]
        q10, q90 = pick_q10_q90(d)
        theta = pick_theta(d)
        gain = solve_gain_for_layer(
            z10=q10, z90=q90, theta=theta,
            tau_rc=tau_rc, tau_ref=tau_ref,
            r_hi=r_hi, g_min=gain_min, g_max=gain_max
        )
        soft_model.act_out.theta.data.fill_(theta)
        soft_model.act_out.gain.data.fill_(gain)

In [344]:
init_softlif_from_bn_stats(soft_model, bn_stats)

In [345]:
def set_bn_eval(m):
    # freezes BN running mean/var updates
    for mod in m.modules():
        if isinstance(mod, (nn.BatchNorm2d, nn.BatchNorm1d)):
            mod.eval()

def get_softlif_params(m):
    # theta/gain only
    params = []
    for mod in m.modules():
        if isinstance(mod, SoftLIFRate):
            params.append(mod.theta)
            params.append(mod.gain)
    return params

@torch.no_grad()
def eval_epoch(model, loader, device, criterion=None):
    model.eval()
    tot, correct, loss_sum = 0, 0, 0.0
    for x, y in loader:
        x = x.to(device, non_blocking=True)
        y = y.to(device, non_blocking=True)
        logits = model(x)
        if criterion is not None:
            loss_sum += criterion(logits, y).item() * y.size(0)
        pred = logits.argmax(1)
        correct += (pred == y).sum().item()
        tot += y.numel()
    acc = 100.0 * correct / tot
    loss = (loss_sum / tot) if (criterion is not None) else None
    return acc, loss

def train_epoch(model, loader, device, optimizer, criterion, freeze_bn=False, grad_clip=None):
    model.train()
    if freeze_bn:
        set_bn_eval(model)

    tot, correct, loss_sum = 0, 0, 0.0

    for x, y in loader:
        x = x.to(device, non_blocking=True)
        y = y.to(device, non_blocking=True)

        optimizer.zero_grad(set_to_none=True)
        logits = model(x)
        loss = criterion(logits, y)
        loss.backward()

        if grad_clip is not None:
            torch.nn.utils.clip_grad_norm_(model.parameters(), grad_clip)

        optimizer.step()

        loss_sum += loss.item() * y.size(0)
        pred = logits.argmax(1)
        correct += (pred == y).sum().item()
        tot += y.numel()

    return 100.0 * correct / tot, loss_sum / tot

train_loader = DataLoader(train_set, batch_size=8, shuffle=True,
                          num_workers=2, pin_memory=True)

theta_gain_only_epochs = 0
total_finetune_epochs  = 1   # includes the first 2 epochs
lr_theta_gain = 3e-4
lr_full = 2e-4
weight_decay = 5e-4
grad_clip = None  # e.g., 1.0 if unstable

criterion = nn.CrossEntropyLoss()

# ---- Phase 1: freeze all except theta/gain ----
def is_softlif_param(name):
    return (".theta" in name) or (".gain" in name)
for p in soft_model.parameters():
    p.requires_grad = False
for p in get_softlif_params(soft_model):
    p.requires_grad = True

opt1 = torch.optim.AdamW(
    get_softlif_params(soft_model),
    lr=lr_theta_gain,
    weight_decay=0.0
)

for epoch in range(1, theta_gain_only_epochs + 1):
    tr_acc, tr_loss = train_epoch(
        soft_model, train_loader, device, opt1, criterion,
        freeze_bn=True,            # IMPORTANT in phase 1
        grad_clip=grad_clip
    )
    te_acc, te_loss = eval_epoch(soft_model, test_loader, device, criterion)
    print(f"[Phase1 {epoch}/{theta_gain_only_epochs}] "
          f"train acc={tr_acc:.2f} loss={tr_loss:.4f} | "
          f"test acc={te_acc:.2f} loss={te_loss:.4f}")

# ---- Phase 2: unfreeze all, small LR ----
for p in soft_model.parameters():
    p.requires_grad = True

decay, no_decay = [], []
for n, p in soft_model.named_parameters():
    if not p.requires_grad:
        continue
    if is_softlif_param(n) or n.endswith(".bias") or "bn" in n.lower():
        no_decay.append(p)
    else:
        decay.append(p)

opt2 = torch.optim.SGD(
    [
        {"params": decay, "weight_decay": 5e-4},
        {"params": no_decay, "weight_decay": 0.0},
    ],
    lr=lr_full, momentum=0.9, nesterov=True
)

# optional scheduler (safe default)
sched2 = torch.optim.lr_scheduler.CosineAnnealingLR(
    opt2, T_max=max(1, total_finetune_epochs - theta_gain_only_epochs)
)

# for epoch in range(theta_gain_only_epochs + 1, total_finetune_epochs + 1):
#     tr_acc, tr_loss = train_epoch(
#         soft_model, train_loader, device, opt2, criterion,
#         freeze_bn=False,           # BN can adapt again in phase 2
#         grad_clip=grad_clip
#     )
#     te_acc, te_loss = eval_epoch(soft_model, test_loader, device, criterion)
#     sched2.step()
#     torch.save({"state_dict": soft_model.state_dict(), "test_acc": te_acc}, f"checkpoints/snntorch/softLIF/finetune_epoch_{epoch}.pt")
#     print(f"[Phase2 {epoch}/{total_finetune_epochs}] "
#           f"train acc={tr_acc:.2f} loss={tr_loss:.4f} | "
#           f"test acc={te_acc:.2f} loss={te_loss:.4f} | "
#           f"lr={opt2.param_groups[0]['lr']:.2e}")

In [346]:
# weights_load = torch.load("checkpoints/snntorch/softLIF/finetune_epoch_63.pt")
# soft_model.load_state_dict(weights_load["state_dict"], strict=False) 

In [347]:
# torch.save({"state_dict": soft_model.state_dict(), "test_acc": te_acc}, f"checkpoints/snntorch/softLIF/finetune_epoch_{epoch}.pt")

# Hardware-realistic inputs

## Setup

In [348]:
min_weight, max_weight = -0.7, 0.7
num_weight_levels = 24

In [349]:
def diff_encode(x, input_scale=1.0, clamp01=False):
    """
    x: (B,3,H,W), can be negative (after Normalize)
    returns: (B,6,H,W) = [x_pos, x_neg] where x = x_pos - x_neg

    If clamp01=True -> enforce duty-cycle limit [0,1] (hardware-like) but introduces saturation.
    """
    x = x / float(input_scale)
    x_pos = torch.clamp(x, min=0.0)
    x_neg = torch.clamp(-x, min=0.0)
    if clamp01:
        x_pos = x_pos.clamp(0.0, 1.0)
        x_neg = x_neg.clamp(0.0, 1.0)
    return torch.cat([x_pos, x_neg], dim=1)

In [350]:
@torch.no_grad()
def estimate_input_scale(loader, device="cuda", batches=50, q=0.99):
    vals = []
    for i, (x, _) in enumerate(loader):
        if i >= batches: break
        x = x.to(device)
        vals.append(x.abs().flatten())
    v = torch.cat(vals)
    return float(torch.quantile(v, torch.tensor(q, device=v.device)).item())

In [351]:
@torch.no_grad()
def convert_softlif_to_diff_in(soft_model_3ch, input_scale=1.0, device="cuda"):
    # 1) build 6-channel model
    diff_model = WideResNet_SoftLIFRate(
        depth=16, widen_factor=10, dropout=0.0, num_classes=100,
        in_channels=6
    ).to(device)

    # 2) load all weights EXCEPT conv1
    sd = soft_model_3ch.state_dict()
    sd = {k: v for k, v in sd.items() if k != "conv1.conv.weight"}  # drop mismatched tensor
    incompatible = diff_model.load_state_dict(sd, strict=False)

    # 3) set conv1 as [W, -W] and compensate scaling
    W3 = soft_model_3ch.conv1.conv.weight.data  # (out,3,k,k)
    W6 = torch.cat([W3, -W3], dim=1)            # (out,6,k,k)
    diff_model.conv1.conv.weight.data.copy_(W6 * float(input_scale))

    return diff_model, incompatible.missing_keys, incompatible.unexpected_keys

In [352]:
diff_model = WideResNet_SoftLIFRate(
    depth=16, widen_factor=10, dropout=0.0, num_classes=100,
    in_channels=6
).to(device)

In [353]:
INPUT_SCALE = estimate_input_scale(train_loader, device=device, batches=50, q=0.99)
# diff_model, missing, unexpected = convert_softlif_to_diff_in(soft_model, input_scale=INPUT_SCALE, device=device)
# print("missing:", missing)
# print("unexpected:", unexpected)

In [369]:
import torch.nn as nn
from torch.amp import autocast, GradScaler

@torch.no_grad()
def evaluate_diff_model(diff_model, loader, input_scale, clamp01, device="cuda"):
    diff_model.eval()
    correct = 0
    total = 0
    loss_sum = 0.0
    ce = nn.CrossEntropyLoss()

    for x, y in loader:
        x = x.to(device, non_blocking=True)
        y = y.to(device, non_blocking=True)

        x6 = diff_encode(x, input_scale=input_scale, clamp01=clamp01)
        logits = diff_model(x6)

        loss = ce(logits, y)
        loss_sum += float(loss.item()) * y.numel()

        pred = logits.argmax(dim=1)
        correct += int((pred == y).sum().item())
        total += int(y.numel())

    return 100.0 * correct / total, loss_sum / total


def train_diff_model(
    diff_model,
    train_loader,
    test_loader,
    input_scale,
    epochs=10,
    lr=1e-4,
    weight_decay=5e-4,
    clamp01=False,
    device="cuda",
    use_amp=True,
    grad_clip=0.0,
):
    diff_model.to(device)
    ce = nn.CrossEntropyLoss()

    # optimizer: AdamW is stable for fine-tuning
    opt = torch.optim.AdamW(diff_model.parameters(), lr=lr, weight_decay=weight_decay)

    # simple cosine schedule (optional but good)
    sched = torch.optim.lr_scheduler.CosineAnnealingLR(opt, T_max=epochs)

    scaler = GradScaler(device=device, enabled=use_amp)

    for epoch in range(1, epochs + 1):
        diff_model.train()
        correct = 0
        total = 0
        loss_sum = 0.0

        for x, y in train_loader:
            x = x.to(device, non_blocking=True)
            y = y.to(device, non_blocking=True)

            x6 = diff_encode(x, input_scale=input_scale, clamp01=clamp01)

            opt.zero_grad(set_to_none=True)

            with autocast(device_type=device.type, enabled=use_amp):
                logits = diff_model(x6)
                loss = ce(logits, y)

            scaler.scale(loss).backward()
            
            if grad_clip and grad_clip > 0:
                scaler.unscale_(opt)
                torch.nn.utils.clip_grad_norm_(diff_model.parameters(), grad_clip)

            scaler.step(opt)
            clamp_weights_(diff_model, min_weight, max_weight)
            scaler.update()

            loss_sum += float(loss.item()) * y.numel()
            pred = logits.argmax(dim=1)
            correct += int((pred == y).sum().item())
            total += int(y.numel())

        sched.step()
        opt.step()

        tr_acc = 100.0 * correct / total
        tr_loss = loss_sum / total
        te_acc, te_loss = evaluate_diff_model(diff_model, test_loader, input_scale, clamp01, device=device)

        torch.save({"state_dict": diff_model.state_dict(), "test_acc": te_acc}, f"checkpoints/QAT/epoch_{epoch}_acc_{te_acc:.4f}.pt")
        print(f"[{epoch:02d}/{epochs}] "
              f"train acc={tr_acc:.2f} loss={tr_loss:.4f} | "
              f"test acc={te_acc:.2f} loss={te_loss:.4f} | "
              f"lr={opt.param_groups[0]['lr']:.2e}")

    return diff_model


In [355]:
# # sanity: diff input shape
# x, y = next(iter(train_loader))
# x6 = diff_encode(x.to(device), INPUT_SCALE, clamp01=False)
# print("x:", x.shape, "x6:", x6.shape)  # should be (B,3,H,W) -> (B,6,H,W)

# # sanity: forward works
# diff_model.eval()
# with torch.no_grad():
#     logits = diff_model(x6)
# print("logits:", logits.shape)  # (B, num_classes)


In [356]:
# err = (x.to(device) / INPUT_SCALE - (x6[:, :3] - x6[:, 3:])).abs().max().item()
# print("max reconstruction error:", err)

In [357]:
# soft_model.eval()
# diff_model.eval()

# with torch.no_grad():
#     y_old = soft_model.conv1(x.to(device))
#     y_new = diff_model.conv1(x6)

# print("conv1 max abs diff:", (y_old - y_new).abs().max().item())
# print("conv1 mean abs diff:", (y_old - y_new).abs().mean().item())


In [358]:
# x, _ = next(iter(train_loader))
# x = x.to(device)
# x6 = diff_encode(x, input_scale=INPUT_SCALE, clamp01=False)
# print("x6 min/max (no clamp):", x6.min().item(), x6.max().item())
# x6c = diff_encode(x, input_scale=INPUT_SCALE, clamp01=True)
# print("x6 min/max (clamp01=True):", x6c.min().item(), x6c.max().item())


In [370]:
diff_model = train_diff_model(
    diff_model,
    train_loader=train_loader,
    test_loader=test_loader,
    input_scale=INPUT_SCALE,
    epochs=400,
    lr=1e-4,            # fine-tune small
    weight_decay=5e-4,
    clamp01=False,      # start here
    device=device,
    use_amp=True,
    grad_clip=1.0,
)

[01/400] train acc=7.51 loss=4.0393 | test acc=11.06 loss=3.7404 | lr=1.00e-04
[02/400] train acc=10.22 loss=3.8539 | test acc=13.91 loss=3.5578 | lr=1.00e-04
[03/400] train acc=13.05 loss=3.6620 | test acc=15.64 loss=3.4967 | lr=1.00e-04
[04/400] train acc=15.76 loss=3.4885 | test acc=21.50 loss=3.0911 | lr=1.00e-04
[05/400] train acc=19.10 loss=3.3047 | test acc=27.00 loss=2.7888 | lr=1.00e-04
[06/400] train acc=22.22 loss=3.1329 | test acc=28.17 loss=2.7471 | lr=9.99e-05
[07/400] train acc=25.62 loss=2.9584 | test acc=35.79 loss=2.4060 | lr=9.99e-05
[08/400] train acc=28.50 loss=2.8026 | test acc=38.63 loss=2.2624 | lr=9.99e-05
[09/400] train acc=31.45 loss=2.6787 | test acc=40.54 loss=2.1712 | lr=9.99e-05
[10/400] train acc=34.14 loss=2.5489 | test acc=43.93 loss=2.0580 | lr=9.98e-05
[11/400] train acc=36.61 loss=2.4338 | test acc=46.95 loss=1.9279 | lr=9.98e-05
[12/400] train acc=38.98 loss=2.3307 | test acc=48.46 loss=1.8619 | lr=9.98e-05
[13/400] train acc=41.24 loss=2.2237 | te

KeyboardInterrupt: 

In [None]:
# torch.save({"state_dict": diff_model.state_dict(), "test_acc": te_acc}, f"checkpoints/snntorch/diffEncoding/epoch_{epoch}_acc_{te_acc:.4f}.pt")

In [None]:
diff_model = train_diff_model(
    diff_model,
    train_loader=train_loader,
    test_loader=test_loader,
    input_scale=INPUT_SCALE,
    epochs=10,
    lr=5e-5,
    weight_decay=5e-4,
    clamp01=True,
    device=device,
    use_amp=True,
    grad_clip=1.0,
)


In [None]:
weights_load = torch.load("checkpoints/QAT/epoch_269_acc_77.2300.pt")
diff_model.load_state_dict(weights_load["state_dict"], strict=False)
fp_acc , fp_loss= evaluate_diff_model(diff_model, test_loader, INPUT_SCALE, False, device=device)
print(f"FP32   test acc: {fp_acc:.2f}% | loss: {fp_loss:.4f}")

## Noisy Inference

In [374]:
ckpt =  torch.load("checkpoints/QAT/epoch_269_acc_77.2300.pt")
model_fp = WideResNet_SoftLIFRate(
    depth=16, widen_factor=10, dropout=0.0, num_classes=100, in_channels=6
).to(device)
model_fp.load_state_dict(ckpt["state_dict"], strict=True)
INPUT_SCALE = estimate_input_scale(train_loader, device=device, batches=50, q=0.99)
fp_acc, fp_loss = evaluate_diff_model(model_fp, test_loader, INPUT_SCALE, False, device=device)
print(f"FP32   test acc: {fp_acc:.2f}% | loss: {fp_loss:.4f}")

  ckpt =  torch.load("checkpoints/QAT/epoch_269_acc_77.2300.pt")


FP32   test acc: 76.65% | loss: 1.4955


In [375]:
def iter_named_weight_tensors(m: nn.Module, include_bias: bool = True):
    for mod_name, mod in m.named_modules():
        if isinstance(mod, (nn.Conv2d, nn.Linear)):
            if getattr(mod, "weight", None) is not None:
                yield f"{mod_name}.weight", mod.weight
            if include_bias and getattr(mod, "bias", None) is not None:
                yield f"{mod_name}.bias", mod.bias
w_mins, w_maxs = [], []
with torch.no_grad():
    for _, w in iter_named_weight_tensors(model_fp, include_bias=True):
        w_mins.append(float(w.min().item()))
        w_maxs.append(float(w.max().item()))

min_weight = min(w_mins)
max_weight = max(w_maxs)
num_weight_levels = 24
if max_weight == min_weight:
    raise ValueError("All weights are identical; quantization step would be zero.")
print(f"Quant bounds: min_weight={min_weight:.6g}, max_weight={max_weight:.6g}, num_weight_levels={num_weight_levels}")

Quant bounds: min_weight=-0.7, max_weight=0.699776, num_weight_levels=24


In [None]:
def apply_qnoise(w: torch.Tensor, noise) -> torch.Tensor:
    # noise.forward(w) should return same shape tensor
    w_noisy = noise(w)
    return torch.clamp(w_noisy, min=min_weight, max=max_weight)

In [None]:
def _functional_call(module: nn.Module, params: dict, buffers: dict, x: torch.Tensor):
    # torch>=2.0: torch.func.functional_call
    try:
        from torch.func import functional_call as func_call
        return func_call(module, (params, buffers), (x,))
    except Exception:
        from torch.nn.utils.stateless import functional_call as stateless_call
        return stateless_call(module, params, (x,), buffers=buffers)

class DiffModel(nn.Module):
    """
    Wraps a base model, but runs forward with "virtual" modified weights:
      mode = "fp" | "quant" | "qnoise"
    """
    def __init__(self, base: nn.Module, mode: str, noise=None, include_bias: bool = True):
        super().__init__()
        assert mode in {"fp", "quant", "qnoise"}
        self.base = base
        self.mode = mode
        self.noise = noise
        self.include_bias = include_bias

        # Cache target parameter names for Conv/Linear
        self.target_param_names = set(name for name, _ in iter_named_weight_tensors(base, include_bias=include_bias))

    def forward(self, x):
        # Grab current params/buffers each call (safe even if base changes)
        params = dict(self.base.named_parameters())
        buffers = dict(self.base.named_buffers())

        # Build overridden params dict
        new_params = {}
        for name, p in params.items():
            if name in self.target_param_names:
                w = p
                if self.mode in {"quant", "qnoise"}:
                    w = maskW(w)
                if self.mode == "qnoise":
                    if self.noise is None:
                        raise ValueError("mode='qnoise' requires a noise object")
                    w = apply_qnoise(w, self.noise)
                new_params[name] = w
            else:
                new_params[name] = p

        return _functional_call(self.base, new_params, buffers, x)

In [None]:
diff_fp = DiffModel(model_fp, mode="fp").to(device).eval()
fp_acc, fp_loss = evaluate_diff_model(diff_fp, test_loader, INPUT_SCALE, False, device=device)
print(f"FP32    test acc: {fp_acc:.2f}% | loss: {fp_loss:.4f}")

In [None]:
# #num_weight_levels = 64
# diff_q = DiffModel(model_fp, mode="quant").to(device).eval()
# q_acc, q_loss = evaluate_diff_model(diff_q, test_loader, INPUT_SCALE, False, device=device)
# print(f"QUANT   test acc: {q_acc:.2f}% | loss: {q_loss:.4f}")

In [376]:
noise = SubthresholdPVTNoise()
noise.update_params()
diff_qn = DiffModel(model_fp, mode="qnoise", noise=noise).to(device).eval()
qn_acc, qn_loss = evaluate_diff_model(diff_qn, test_loader, INPUT_SCALE, False, device=device)
print(f"Q+NOISE test acc: {qn_acc:.2f}% | loss: {qn_loss:.4f}")


Q+NOISE test acc: 76.74% | loss: 1.4895


# Spiking Inference

In [None]:
import torch
import torch.nn.functional as F

class DeltaSigmaEncoder:
    """Deterministic rate-to-spikes: spike count over T closely matches rate*T even for small T."""
    def __init__(self):
        self.acc = None

    def reset(self, shape, device, dtype):
        self.acc = torch.zeros(shape, device=device, dtype=dtype)

    def step(self, rate01):
        # rate01 in [0,1]
        if self.acc is None or self.acc.shape != rate01.shape:
            self.reset(rate01.shape, rate01.device, rate01.dtype)
        self.acc += rate01
        spk = (self.acc >= 1.0).to(rate01.dtype)
        self.acc -= spk
        return spk

@torch.no_grad()
def eval_softlif_with_spike_outputs(
    soft_model,
    test_loader,
    T=12,                 # pick 6, 8, 12, 16
    temp=1.0,             # softmax temperature for spike-rate outputs
    device="cuda",
):
    """
    GUARANTEED same decision as soft_model if you use prob_rate (not noisy spikes) for argmax.
    Also returns a spike-rate representation of the output using deterministic spikes.
    """
    soft_model.eval()
    correct = 0
    total = 0

    enc = DeltaSigmaEncoder()

    for x, y in test_loader:
        x = x.to(device, non_blocking=True)
        y = y.to(device, non_blocking=True)

        logits = soft_model(x)  # EXACT trained computation
        probs = F.softmax(logits / temp, dim=1)  # in [0,1], sums to 1

        # --- output "spike rate" representation ---
        # We generate T-step spikes whose mean approximates probs (deterministic, low-variance)
        enc.reset(probs.shape, probs.device, probs.dtype)
        out_counts = torch.zeros_like(probs)

        for _ in range(T):
            out_spk = enc.step(probs)   # (B,C) spikes
            out_counts += out_spk

        out_rate = out_counts / float(T)  # "spike rate" per class

        # IMPORTANT:
        # If you want "guaranteed same as soft_model", use probs.argmax (or logits.argmax).
        # If you want "pure spiking decision", use out_rate.argmax (can differ for tiny T).
        pred = probs.argmax(dim=1)

        correct += (pred == y).sum().item()
        total += y.numel()

    return 100.0 * correct / total


In [None]:
acc = eval_softlif_with_spike_outputs(soft_model, test_loader, T=12, temp=1.0, device=device)
print("SoftLIF accuracy (guaranteed):", acc)


# Inference

In [None]:
# import os, glob, copy, re

# # ---- find and load the best checkpoint (highest stored test_acc) ----
# ckpt_dir = "checkpoints/snntorch/diffEncoding/epoch_59_acc_74.1100.pt"
# ckpt_paths = sorted(glob.glob(os.path.join(ckpt_dir, "*.pt")))
# if len(ckpt_paths) == 0:
#     raise FileNotFoundError(f"No checkpoints found in {ckpt_dir!r} (expected .pt files).")

# def _ckpt_score(path: str) -> float:
#     try:
#         obj = torch.load(path, map_location="cpu")
#         acc = obj.get("test_acc", None)
#         if isinstance(acc, (float, int)):
#             return float(acc)
#         if torch.is_tensor(acc):
#             return float(acc.item())
#     except Exception:
#         pass
#     # fallback: parse "..._acc_<float>.pt"
#     m = re.search(r"_acc_([0-9]*\.?[0-9]+)\.pt$", os.path.basename(path))
#     return float(m.group(1)) if m else -1.0

# best_ckpt_path = max(ckpt_paths, key=_ckpt_score)
# ckpt = torch.load(best_ckpt_path, map_location="cpu")

# print("Loaded checkpoint:", best_ckpt_path)
# print("Stored test_acc:", ckpt.get("test_acc", "N/A"))

In [None]:
best_ckpt_path = "checkpoints/snntorch/diffEncoding/epoch_182_acc_73.1100.pt"
ckpt = torch.load(best_ckpt_path, map_location="cpu")

In [None]:
import copy
model_fp =  WideResNet_SoftLIFRate(
        depth=16, widen_factor=10, dropout=0.0, num_classes=100,
        in_channels=6
    ).to(device)
model_fp.load_state_dict(ckpt["state_dict"], strict=True)
INPUT_SCALE = estimate_input_scale(train_loader, device=device, batches=50, q=0.99)
fp_loss, fp_acc = evaluate_diff_model(model_fp, test_loader, INPUT_SCALE, False, device=device)
print(f"FP32   test acc: {fp_acc:.2f}% | loss: {fp_loss:.4f}")

In [None]:
model_fp = WideResNet(depth=28, widen_factor=10, dropout=0.0, num_classes=100).to(device)
model_fp.load_state_dict(ckpt["state_dict"], strict=True)

fp_loss, fp_acc = evaluate(model_fp, test_loader)
print(f"FP32   test acc: {fp_acc*100:.2f}% | loss: {fp_loss:.4f}")

In [None]:
def _iter_weight_tensors(m: nn.Module):
    for mod in m.modules():
        if isinstance(mod, (nn.Conv2d, nn.Linear)):
            if getattr(mod, "weight", None) is not None:
                yield mod.weight
            if getattr(mod, "bias", None) is not None:
                yield mod.bias

In [None]:
w_mins, w_maxs = [], []
with torch.no_grad():
    for w in _iter_weight_tensors(model_fp):
        w_mins.append(float(w.min().item()))
        w_maxs.append(float(w.max().item()))

min_weight = min(w_mins)
max_weight = max(w_maxs)
num_w_levels = 64  # change if needed (e.g., 8, 32, 256)

if max_weight == min_weight:
    raise ValueError("All weights are identical; quantization step would be zero.")

print(f"Quant bounds: min_weight={min_weight:.6g}, max_weight={max_weight:.6g}, num_w_levels={num_w_levels}")

In [None]:
def maskW(w):
    w = torch.clamp(w, min=min_weight, max=max_weight)
    step = (max_weight - min_weight) / (num_w_levels - 1)
    return ((w - min_weight) / step).round() * step + min_weight

@torch.no_grad()
def apply_quantization_(m: nn.Module):
    for w in _iter_weight_tensors(m):
        w.copy_(maskW(w))

model_q = WideResNet(depth=28, widen_factor=10, dropout=0.0, num_classes=100).to(device)
model_q.load_state_dict(model_fp.state_dict(), strict=True)
apply_quantization_(model_q)

q_loss, q_acc = evaluate(model_q, test_loader)
print(f"QUANT  test acc: {q_acc*100:.2f}% | loss: {q_loss:.4f}")

In [None]:
import torch
import numpy as np

k: float = 1.38e-23   # Boltzmann constant (J/K)
q: float = 1.602176634e-19
T_r: float = 300
n: float = 1.5
k_mu: float = 1.5
k_vt: float = 1e-3 # Typical value for threshold voltage temperature dependence (in V/K)
Tr = 300

def thermal_voltage(T):
    return k * T / q

def Id_subthreshold(W, L, mu, Cox, Vth, VGS, VDS, T, n):
    Vt = thermal_voltage(T)
    Is0 = (W / L) * mu * Cox * (Vt**2) * np.exp(1.8)
    return Is0 * np.exp((VGS - Vth) / (n * Vt)) * (1 - np.exp(-VDS / Vt))

class SubthresholdPVTNoise:
    def __init__(self):
        """
        params: dict with
            W, L, mu, Cox, Vth, VGS, VDS, T, n
        sigmas: dict with
            W, L, mu, Cox, Vth, VDS (normal std)
            dT (uniform half-range)
        """
        #self.p = params
        self.W   = 44e-9
        self.L   = 22e-9
        self.mu  = 0.03
        self.Cox = 0.03
        self.Vth = 0.35
        self.VGS = 0.25
        self.VDS = 0.8
        self.T0  = 310
        self.n   = n
        self.dId_dW_value = 0
        self.dId_dL_value = 0
        self.dId_dmu_value = 0
        self.dId_dCox_value = 0
        self.dId_dVth_value = 0
        self.dId_dVDS_value = 0
        self.dId_dT_value = 0
        self.s = {
            "W":   0.10 * self.W,
            "L":   0.10 * self.L,
            "mu":  0.10 * self.mu,
            "Cox": 0.10 * self.Cox,
            "Vth": 0.10 * self.Vth,
            "VDS": 0.10 * self.VDS,
            "T_min": 248,
            "T_max": 398,
        }

    def dId_dW(self):
        Vt = thermal_voltage(self.T0)
        return (self.mu * self.Cox * Vt**2 / self.L) * np.exp(1.8) * \
            np.exp((self.VGS - self.Vth) / (self.n * Vt)) * (1 - np.exp(self.VDS / Vt))

    def dId_dmu(self):
        Vt = thermal_voltage(self.T0)
        return (self.W * self.Cox * Vt**2 / self.L) * np.exp(1.8) * \
            np.exp((self.VGS - self.Vth) / (self.n * Vt)) * (1 - np.exp(self.VDS / Vt))

    def dId_dCox(self):
        Vt = thermal_voltage(self.T0)
        return (self.W * self.mu * Vt**2 / self.L) * np.exp(1.8) * \
            np.exp((self.VGS - self.Vth) / (self.n * Vt)) * (1 - np.exp(self.VDS / Vt))

    def dId_dL(self):
        Vt = thermal_voltage(self.T0)
        return -(self.W * self.mu * self.Cox * Vt**2 / self.L**2) * np.exp(1.8) * \
                np.exp((self.VGS - self.Vth) / (self.n * Vt)) * (1 - np.exp(self.VDS / Vt))
    
    def dId_dVDS(self):
        Vt = thermal_voltage(self.T0)
        Is0 = (-1 / Vt) * (self.W / self.L) * self.mu * self.Cox * Vt**2 * np.exp(1.8)
        return Is0 * np.exp((self.VGS - self.Vth) / (self.n * Vt)) * \
            np.exp(self.VDS / Vt)
    
    def dId_dVth(self):
        Id = Id_subthreshold(self.W, self.L, self.mu, self.Cox, self.Vth, self.VGS, self.VDS, self.T0, self.n)
        Vt = thermal_voltage(self.T0)
        return (-1 / (self.n * Vt)) * Id * np.exp((self.VGS - self.Vth) / (self.n * Vt)) * \
            (1 - np.exp(self.VDS / Vt))

    def dId_dT(self):
        Vt = thermal_voltage(self.T0)
        Id = Id_subthreshold(self.W, self.L, self.mu, self.Cox, self.Vth, self.VGS, self.VDS, self.T0, self.n)

        term1 = k_mu - (( k_vt * self.T0) / (self.n * Vt)) * Id
        term2 = (1 - np.exp(self.VDS / Vt)) * Tr**(-k_mu) * self.T0**(k_mu - 1)
        term3 = np.exp((self.VGS - self.Vth * Tr - k_vt * (self.T0-Tr)) / (self.n * Vt))
        return Id * (term1 + term2 + term3)

    def update_params(self, **kwargs):
        self.p.update(kwargs)
        self.dId_dW_value = self.dId_dW()
        self.dId_dL_value = self.dId_dL()
        self.dId_dmu_value = self.dId_dmu()
        self.dId_dCox_value = self.dId_dCox()
        self.dId_dVth_value = self.dId_dVth()
        self.dId_dVDS_value = self.dId_dVDS()
        self.dId_dT_value = self.dId_dT()

    def forward(self, Id0):
        """
        Id0: torch.Tensor (any shape)
        returns: noisy Id tensor (same shape)
        """
        device, dtype = Id0.device, Id0.dtype

        # ---- sample parameter variations ----
        dW   = torch.normal(0.0, self.s["W"],   Id0.shape, device=device, dtype=dtype)
        dL   = torch.normal(0.0, self.s["L"],   Id0.shape, device=device, dtype=dtype)
        dmu  = torch.normal(0.0, self.s["mu"],  Id0.shape, device=device, dtype=dtype)
        dCox = torch.normal(0.0, self.s["Cox"], Id0.shape, device=device, dtype=dtype)
        dVth = torch.normal(0.0, self.s["Vth"], Id0.shape, device=device, dtype=dtype)
        dVDS = torch.normal(0.0, self.s["VDS"], Id0.shape, device=device, dtype=dtype)
        dT   = torch.empty_like(Id0).uniform_(self.s["T_min"], self.s["T_max"])

        # dW = 0
        # dL = 0
        # dmu = 0
        # dCox = 0    
        # dVth = 0
        # dVDS = 0
        # dT = 0

        # ---- analytical partial derivatives ----
        dId = (
            self.dId_dW_value           * dW   +
            self.dId_dL_value           * dL   +
            self.dId_dmu_value          * dmu  +
            self.dId_dCox_value         * dCox +
            self.dId_dVth_value         * dVth +
            self.dId_dVDS_value * dVDS +
            self.dId_dT_value * dT
        )

        return Id0 + dId


In [None]:
@torch.no_grad()
def apply_noise_(m: nn.Module, noise: SubthresholdPVTNoise):
    for w in _iter_weight_tensors(m):
        w_noisy = noise.forward(w)
        w.copy_(torch.clamp(w_noisy, min=min_weight, max=max_weight))

noise = SubthresholdPVTNoise()

model_qn = WideResNet(depth=28, widen_factor=10, dropout=0.0, num_classes=100).to(device)
model_qn.load_state_dict(model_fp.state_dict(), strict=True)
apply_noise_(model_qn, noise)

qn_loss, qn_acc = evaluate(model_qn, test_loader)
print(f"Q+NOISE test acc: {qn_acc*100:.2f}% | loss: {qn_loss:.4f}")

In [None]:
def add_retention_noise(model_weights, std_ret_high, std_ret_low, threshold = 0.06):
    """
    Adds noise to trained weights based on a threshold and clamps the weights within a given range.
    
    Args:
        model_weights (torch.Tensor): The tensor of trained weights.
        std_ret_high (float): The standard deviation for noise applied to weights >= threshold.
        std_ret_low (float): The standard deviation for noise applied to weights < threshold.
        threshold (float): The threshold to separate high and low noise application.
        min_weight (float): The minimum value to clamp weights to.
        max_weight (float): The maximum value to clamp weights to.
    
    Returns:
        torch.Tensor: The modified weights after noise addition and clamping.
    """
    # Create a mask for weights below the threshold
    mask_low = model_weights < threshold

    # Generate noise for weights below and above the threshold
    noise_low = torch.abs(torch.normal(0, std_ret_low, size=model_weights.size()))
    noise_high = torch.abs(torch.normal(0, std_ret_high, size=model_weights.size()))
    
    # Select the appropriate noise based on the mask
    ret_noise = torch.where(mask_low, noise_low, noise_high)
    
    return model_weights + ret_noise

def apply_quant_noise(model_weights, low0 = 0.24, high0 = 0.3):
    """
    Applies noise to the trained weights based on specified ranges, keeping them part of the computational graph.
    
    Args:
        model_weights (torch.Tensor): The trained weights to which noise will be applied.
        low0 (float): Lower bound for the first range of noise application.
        high0 (float): Upper bound for the first range of noise application.
        max_weight (float): Maximum weight for the second range of noise application.
    
    Returns:
        torch.Tensor: The updated weights with noise applied, part of the computational graph.
    """
    global max_weight

    # First range: [low0, high0]
    mask0 = (model_weights > low0) & (model_weights < high0)
    random_values0 = low0 + (high0 - low0) * torch.abs(torch.randn_like(model_weights))
    random_values0 = torch.where(mask0, random_values0, torch.zeros_like(model_weights))
    
    # Second range: [high0, max_weight + 0.01]
    low1, high1 = high0, max_weight + 0.01
    mask1 = (model_weights > low1) & (model_weights < high1)
    random_values1 = low1 + (high1 - low1) * torch.abs(torch.randn_like(model_weights))
    random_values1 = torch.where(mask1, random_values1, torch.zeros_like(model_weights))
    
    return model_weights + random_values0 + random_values1