# setup

In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.amp import autocast, GradScaler
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
import numpy as np

KeyboardInterrupt: 

In [None]:
min_weight, max_weight = -0.7, 0.7
num_weight_levels = 24

In [None]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# dataset

In [None]:
train_tf = transforms.Compose([
    transforms.RandomCrop(32, padding=4),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),  # gives [0,1]
    transforms.Normalize((0.4914, 0.4822, 0.4465),
                         (0.2023, 0.1994, 0.2010)),
])

test_tf = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.4914, 0.4822, 0.4465),
                         (0.2023, 0.1994, 0.2010)),
])

train_set = datasets.CIFAR10(root="./data", train=True, download=True, transform=train_tf)
test_set  = datasets.CIFAR10(root="./data", train=False, download=True, transform=test_tf)

batch_size = 128
train_loader = DataLoader(train_set, batch_size=batch_size, shuffle=True,
                          num_workers=2, pin_memory=True)
test_loader  = DataLoader(test_set, batch_size=batch_size, shuffle=False,
                          num_workers=2, pin_memory=True)

In [None]:
@torch.no_grad()
def estimate_input_scale(loader, device, batches=50, q=0.99, sample_per_batch=200_000):
    vals = []
    for i, (x, y) in enumerate(loader):
        if i >= batches:
            break
        x = x.to(device, non_blocking=True)

        v = x.abs().flatten()

        # NEW: cap how many elements you keep from each batch
        if v.numel() > sample_per_batch:
            idx = torch.randint(v.numel(), (sample_per_batch,), device=device)
            v = v[idx]

        vals.append(v)

    v = torch.cat(vals)
    return torch.quantile(v, q).item()

In [None]:
INPUT_SCALE_TRAIN = estimate_input_scale(train_loader, device=device, batches=50, q=0.99)

In [None]:
INPUT_SCALE_TEST = estimate_input_scale(test_loader, device=device, batches=50, q=0.99)

In [None]:
def diff_encode_func(x, input_scale=1.0, clamp01=False):
    """
    x: (B,C,H,W), can be negative (after Normalize)
    returns: (B,2*C,H,W) = [x_pos, x_neg] where x = x_pos - x_neg

    If clamp01=True -> enforce duty-cycle limit [0,1] (hardware-like) but introduces saturation.
    """
    x = x / float(input_scale)
    x_pos = torch.clamp(x, min=0.0)
    x_neg = torch.clamp(-x, min=0.0)
    if clamp01:
        x_pos = x_pos.clamp(0.0, 1.0)
        x_neg = x_neg.clamp(0.0, 1.0)
    return torch.cat([x_pos, x_neg], dim=1)

# Training

## Helper functions

In [None]:
@torch.no_grad()
def clamp_weights_(model: nn.Module, lo: float, hi: float, clamp_bias: bool = True):
    """
    Clamp weights (and optionally bias) in-place for any module that has
    Tensor attributes `weight` / `bias`. Leaves BatchNorm params alone.
    """
    for name, m in model.named_modules():
        # Skip BN entirely
        if isinstance(m, (nn.BatchNorm1d, nn.BatchNorm2d, nn.BatchNorm3d)):
            continue

        w = getattr(m, "weight", None)
        if torch.is_tensor(w):
            w.clamp_(lo, hi)

def maskW(w: torch.Tensor) -> torch.Tensor:
    w = torch.clamp(w, min=min_weight, max=max_weight)
    step = (max_weight - min_weight) / (num_weight_levels - 1)
    return ((w - min_weight) / step).round() * step + min_weight

def maskW_ste(w: torch.Tensor) -> torch.Tensor:
    # Forward: quantized weights
    w_q = maskW(w)
    # Backward: pretend quantization was identity (straight-through estimator)
    return w + (w_q - w).detach()

In [None]:
k: float = 1.38e-23   # Boltzmann constant (J/K)
q: float = 1.602176634e-19
T_r: float = 300
n: float = 1.5
k_mu: float = 1.5
k_vt: float = 1e-3 # Typical value for threshold voltage temperature dependence (in V/K)
Tr = 300

def thermal_voltage(T):
    return k * T / q

def Id_subthreshold(W, L, mu, Cox, Vth, VGS, VDS, T, n):
    Vt = thermal_voltage(T)
    Is0 = (W / L) * mu * Cox * (Vt**2) * np.exp(1.8)
    return Is0 * np.exp((VGS - Vth) / (n * Vt)) * (1 - np.exp(-VDS / Vt))

class SubthresholdPVTNoise(nn.Module):
    def __init__(self):
        super().__init__()
        """
        params: dict with
            W, L, mu, Cox, Vth, VGS, VDS, T, n
        sigmas: dict with
            W, L, mu, Cox, Vth, VDS (normal std)
            dT (uniform half-range)
        """
        #self.p = params
        self.W   = 44e-9
        self.L   = 22e-9
        self.mu  = 0.03
        self.Cox = 0.03
        self.Vth = 0.35
        self.VGS = 0.25
        self.VDS = 0.8
        self.T0  = 310.0
        self.n   = 1.5
        self.dId_dW_value = 0.0
        self.dId_dL_value = 0.0
        self.dId_dmu_value = 0.0
        self.dId_dCox_value = 0.0
        self.dId_dVth_value = 0.0
        self.dId_dVDS_value = 0.0
        self.dId_dT_value = 0.0
        self.s = {
            "W":   0.10 * self.W,
            "L":   0.10 * self.L,
            "mu":  0.10 * self.mu,
            "Cox": 0.10 * self.Cox,
            "Vth": 0.10 * self.Vth,
            "VDS": 0.10 * self.VDS,
            "T_min": 248,
            "T_max": 398,
        }

    def dId_dW(self):
        Vt = thermal_voltage(self.T0)
        return (self.mu * self.Cox * Vt**2 / self.L) * np.exp(1.8) * \
            np.exp((self.VGS - self.Vth) / (self.n * Vt)) * (1 - np.exp(-self.VDS / Vt))

    def dId_dmu(self):
        Vt = thermal_voltage(self.T0)
        return (self.W * self.Cox * Vt**2 / self.L) * np.exp(1.8) * \
            np.exp((self.VGS - self.Vth) / (self.n * Vt)) * (1 - np.exp(-self.VDS / Vt))

    def dId_dCox(self):
        Vt = thermal_voltage(self.T0)
        return (self.W * self.mu * Vt**2 / self.L) * np.exp(1.8) * \
            np.exp((self.VGS - self.Vth) / (self.n * Vt)) * (1 - np.exp(-self.VDS / Vt))

    def dId_dL(self):
        Vt = thermal_voltage(self.T0)
        return -(self.W * self.mu * self.Cox * Vt**2 / self.L**2) * np.exp(1.8) * \
                np.exp((self.VGS - self.Vth) / (self.n * Vt)) * (1 - np.exp(-self.VDS / Vt))
    
    def dId_dVDS(self):
        Vt = thermal_voltage(self.T0)
        Is0 = (self.W / self.L) * self.mu * self.Cox * (Vt**2) * np.exp(1.8)
        return Is0 * np.exp((self.VGS - self.Vth) / (self.n * Vt)) * (np.exp(-self.VDS / Vt) / Vt)
    
    def dId_dVth(self):
        Id = Id_subthreshold(self.W, self.L, self.mu, self.Cox, self.Vth, self.VGS, self.VDS, self.T0, self.n)
        Vt = thermal_voltage(self.T0)
        return (-1 / (self.n * Vt)) * Id * np.exp((self.VGS - self.Vth) / (self.n * Vt)) * \
            (1 - np.exp(-self.VDS / Vt))

    def dId_dT(self):
        Vt = thermal_voltage(self.T0)
        Id = Id_subthreshold(self.W, self.L, self.mu, self.Cox, self.Vth, self.VGS, self.VDS, self.T0, self.n)

        term1 = k_mu - (( k_vt * self.T0) / (self.n * Vt)) * Id
        term2 = (1 - np.exp(-self.VDS / Vt)) * Tr**(-k_mu) * self.T0**(k_mu - 1)
        term3 = np.exp((self.VGS - self.Vth * Tr - k_vt * (self.T0-Tr)) / (self.n * Vt))
        return Id * (term1 + term2 + term3)

    def update_params(self):
        self.dId_dW_value = self.dId_dW()
        self.dId_dL_value = self.dId_dL()
        self.dId_dmu_value = self.dId_dmu()
        self.dId_dCox_value = self.dId_dCox()
        self.dId_dVth_value = self.dId_dVth()
        self.dId_dVDS_value = self.dId_dVDS()
        self.dId_dT_value = self.dId_dT()

    def forward(self, Id0):
        """
        Id0: torch.Tensor (any shape)
        returns: noisy Id tensor (same shape)
        """
        device, dtype = Id0.device, Id0.dtype
        self.update_params()

        # ---- sample parameter variations ----
        dW   = torch.normal(0.0, self.s["W"],   size=Id0.shape, device=device, dtype=dtype)
        dL   = torch.normal(0.0, self.s["L"],   size=Id0.shape, device=device, dtype=dtype)
        dmu  = torch.normal(0.0, self.s["mu"],  size=Id0.shape, device=device, dtype=dtype)
        dCox = torch.normal(0.0, self.s["Cox"], size=Id0.shape, device=device, dtype=dtype)
        dVth = torch.normal(0.0, self.s["Vth"], size=Id0.shape, device=device, dtype=dtype)
        dVDS = torch.normal(0.0, self.s["VDS"], size=Id0.shape, device=device, dtype=dtype)
        dT   = torch.empty_like(Id0).uniform_(self.s["T_min"], self.s["T_max"])
        dT  = dT - self.T0

        # dW = 0
        # dL = 0
        # dmu = 0
        # dCox = 0    
        # dVth = 0
        # dVDS = 0
        # dT = 0

        # ---- analytical partial derivatives ----
        dId = (
            self.dId_dW_value           * dW   +
            self.dId_dL_value           * dL   +
            self.dId_dmu_value          * dmu  +
            self.dId_dCox_value         * dCox +
            self.dId_dVth_value         * dVth +
            self.dId_dVDS_value * dVDS +
            self.dId_dT_value * dT
        )

        return Id0 + dId


## Network

### Conv Layer

In [None]:
class QATConv2d_no_other_noise(nn.Conv2d):
    def forward(self, x):
        w = maskW_ste(self.weight) if self.training else maskW(self.weight)
        return F.conv2d(x, w, self.bias, self.stride, self.padding, self.dilation, self.groups)

class QATLinear_no_other_noise(nn.Linear):
    def forward(self, x):
        w = maskW_ste(self.weight) if self.training else maskW(self.weight)
        return F.linear(x, w, self.bias)

In [None]:
# class QATConv2d(nn.Conv2d):
#     def __init__(self, *args, noise=None, **kwargs):
#         super().__init__(*args, **kwargs)
#         # Register as a submodule so it moves to GPU + appears in state_dict
#         self.noise = noise if noise is not None else SubthresholdPVTNoise()
#         self.noise.update_params()

#     def forward(self, x):
#         w = maskW_ste(self.weight) if self.training else maskW(self.weight)

#         # Add noise in forward pass (typically only during training)
#         if self.training:
#             w = self.noise(w)   # calls SubthresholdPVTNoise.forward(w)

#         return F.conv2d(x, w, self.bias, self.stride, self.padding, self.dilation, self.groups)


# class QATLinear(nn.Linear):
#     def __init__(self, *args, noise=None, **kwargs):
#         super().__init__(*args, **kwargs)
#         self.noise = noise if noise is not None else SubthresholdPVTNoise()
#         self.noise.update_params()

#     def forward(self, x):
#         w = maskW_ste(self.weight) if self.training else maskW(self.weight)
#         if self.training:
#             w = self.noise(w)
#         return F.linear(x, w, self.bias)

### neuron

In [None]:
class SoftLIFRate(nn.Module):
    def __init__(self, theta=0.0, gain=1.0, tau_rc=0.02, tau_ref=0.002, eps=1e-4):
        super().__init__()
        self.theta = nn.Parameter(torch.tensor(float(theta)))
        self.gain  = nn.Parameter(torch.tensor(float(gain)))
        self.tau_rc  = float(tau_rc)
        self.tau_ref = float(tau_ref)
        self.eps = float(eps)

    def forward(self, z):
        # Force the critical math to FP32 to avoid AMP/FP16 inf->nan gradients
        z32 = z.float()
        u = (self.gain.float()) * (z32 - self.theta.float())

        # J = 1 + softplus(u) >= 1
        Jm1 = F.softplus(u)                       # this is (J-1)
        Jm1 = torch.clamp(Jm1, min=self.eps)      # prevents divide-by-0 / inf grads

        # log(1 + 1/(J-1))
        denom = self.tau_ref + self.tau_rc * torch.log1p(1.0 / Jm1)
        rate  = 1.0 / denom

        r = rate * self.tau_ref                   # normalize to [0,1]
        r = torch.clamp(r, 0.0, 1.0)

        return r.to(dtype=z.dtype)                # restore original dtype

### ResNet blocks

In [None]:
class BasicBlock(nn.Module):
    expansion = 1
    def __init__(self, in_ch, out_ch, stride=1, act = None):
        super().__init__()
        self.conv1 = QATConv2d_no_other_noise(in_ch, out_ch, 3, stride=stride, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(out_ch)
        self.act1 = act if act is not None else SoftLIFRate()
        self.conv2 = QATConv2d_no_other_noise(out_ch, out_ch, 3, stride=1, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(out_ch)

        self.shortcut = nn.Identity()
        if stride != 1 or in_ch != out_ch:
            self.shortcut = nn.Sequential(
                nn.Conv2d(in_ch, out_ch, kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(out_ch),
            )

        self.act2 = act if act is not None else SoftLIFRate()

    def forward(self, x):
        # "Neuron type = ReLU": apply after summed currents + BN
        out = self.act1(self.bn1(self.conv1(x)))
        out = self.bn2(self.conv2(out))
        out = out + self.shortcut(x)
        out = self.act2(out)
        return out


class ResNetCIFAR(nn.Module):
    def __init__(self, num_classes=10, in_channels=64, diff_encode=False, layers=(2, 2, 2, 2), act_theta=0.0, act_gain=1.0, tau_rc=0.02, tau_ref=0.002):
        super().__init__()

        stem_in = 6 if diff_encode else 3
        self.in_ch = in_channels

        def make_act():
            return SoftLIFRate(theta=act_theta, gain=act_gain, tau_rc=tau_rc, tau_ref=tau_ref)

        self.act_out = make_act()
        
        self.stem = nn.Sequential(
            QATConv2d_no_other_noise(stem_in, in_channels, 3, stride=1, padding=1, bias=False),
            nn.BatchNorm2d(in_channels),
            make_act(),
        )

        b1, b2, b3, b4 = layers
        # ResNet18 layers: [2,2,2,2] blocks with channel sizes [in_channels,128,256,512]
        self.layer1 = self._make_layer(in_channels, b1, stride=1, make_act=make_act)
        self.layer2 = self._make_layer(128, b2, stride=2, make_act=make_act)
        self.layer3 = self._make_layer(256, b3, stride=2, make_act=make_act)
        self.layer4 = self._make_layer(512, b4, stride=2, make_act=make_act)
        self.pool = nn.AdaptiveAvgPool2d(1)
        self.fc = QATLinear_no_other_noise(512, num_classes, bias=True)

    def _make_layer(self, out_ch, num_blocks, stride, make_act):
        strides = [stride] + [1]*(num_blocks-1)
        blocks = []
        for s in strides:
            blocks.append(BasicBlock(self.in_ch, out_ch, stride=s, act=make_act()))
            self.in_ch = out_ch
        return nn.Sequential(*blocks)

    def forward(self, x):
        x = self.stem(x)
        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.layer4(x)
        x = self.pool(x).flatten(1)
        x = self.fc(x)
        return x

## learning

In [None]:
model = ResNetCIFAR(num_classes=10, in_channels=64, diff_encode=True, layers=(1, 1, 1, 2)).to(device)
criterion = nn.CrossEntropyLoss()

# SGD + momentum is the default strong baseline for CIFAR ResNets
optimizer = optim.SGD(model.parameters(), lr=0.1, momentum=0.9, weight_decay=5e-4, nesterov=True)
scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=200)

use_amp = (device.type == "cuda")
scaler = GradScaler(device=device, enabled=use_amp)

@torch.no_grad()
def evaluate(model, loader, diff_encode=True):
    model.eval()
    correct = 0
    total = 0
    loss_sum = 0.0
    for x, y in loader:
        x, y = x.to(device, non_blocking=True), y.to(device, non_blocking=True)
        if diff_encode:
            x = diff_encode_func(x, input_scale=INPUT_SCALE_TEST, clamp01=False)
        logits = model(x)
        loss = criterion(logits, y)
        loss_sum += float(loss) * x.size(0)
        pred = logits.argmax(dim=1)
        correct += int((pred == y).sum())
        total += y.numel()
    return loss_sum / total, correct / total

def train_one_epoch(model, loader, diff_encode=True):
    model.train()
    running_loss = 0.0
    correct = 0
    total = 0
    for x, y in loader:
        x, y = x.to(device, non_blocking=True), y.to(device, non_blocking=True)
        if diff_encode:
            x = diff_encode_func(x, input_scale=INPUT_SCALE_TRAIN, clamp01=False)
        optimizer.zero_grad(set_to_none=True)

        with autocast(device_type=device.type, enabled=use_amp):
            logits = model(x)
            loss = criterion(logits, y)

        scaler.scale(loss).backward()
        scaler.step(optimizer)
        clamp_weights_(model, min_weight, max_weight)
        scaler.update()

        running_loss += float(loss) * x.size(0)
        pred = logits.argmax(dim=1)
        correct += int((pred == y).sum())
        total += y.numel()

    return running_loss / total, correct / total

In [None]:
epochs = 200  # this is the usual CIFAR recipe; should exceed 80% well before the end
best_acc = 0.0

In [None]:
ckpt =  torch.load("checkpoints/cifar10/epoch_71_acc_0.9431.pt")
model.load_state_dict(ckpt["state_dict"], strict=True)
fp_loss, fp_acc = evaluate(model, test_loader, True)
print(f"FP32   test acc: {(fp_acc * 100):.2f}% | loss: {fp_loss:.4f}")

In [None]:
for ep in range(1, epochs + 1):
    tr_loss, tr_acc = train_one_epoch(model, train_loader)
    te_loss, te_acc = evaluate(model, test_loader)
    scheduler.step()

    best_acc = max(best_acc, te_acc)
    print(f"Epoch {ep:3d} | train loss {tr_loss:.4f} acc {tr_acc*100:5.2f}%"
            f" | test loss {te_loss:.4f} acc {te_acc*100:5.2f}% | best {best_acc*100:5.2f}%")
    torch.save({"state_dict": model.state_dict(), "test_acc": te_acc}, f"checkpoints/cifar10/epoch_{ep}_acc_{te_acc:.4f}.pt")

print("Done. Best test accuracy:", best_acc * 100, "%")

Epoch   1 | train loss 1.6795 acc 36.36% | test loss 2.0001 acc 37.47% | best 37.47%
Epoch   2 | train loss 1.1369 acc 59.05% | test loss 1.5113 acc 43.88% | best 43.88%
Epoch   3 | train loss 0.8740 acc 69.19% | test loss 1.3209 acc 56.11% | best 56.11%
Epoch   4 | train loss 0.7356 acc 74.43% | test loss 2.5614 acc 37.54% | best 56.11%
Epoch   5 | train loss 0.6690 acc 76.88% | test loss 0.9293 acc 68.08% | best 68.08%
Epoch   6 | train loss 0.6141 acc 78.57% | test loss 3.1037 acc 37.79% | best 68.08%
Epoch   7 | train loss 0.5840 acc 80.01% | test loss 1.3563 acc 57.27% | best 68.08%
Epoch   8 | train loss 0.5560 acc 80.89% | test loss 0.7251 acc 75.49% | best 75.49%
Epoch   9 | train loss 0.5337 acc 81.71% | test loss 0.9112 acc 71.68% | best 75.49%
Epoch  10 | train loss 0.5168 acc 82.24% | test loss 1.8989 acc 52.99% | best 75.49%
Epoch  11 | train loss 0.5016 acc 82.92% | test loss 0.5686 acc 80.72% | best 80.72%
Epoch  12 | train loss 0.4919 acc 83.09% | test loss 1.2428 acc 6

# eval

In [None]:
fp_loss, fp_acc = evaluate(model, test_loader, True)
print(f"FP32   test acc: {(fp_acc * 100):.2f}% | loss: {fp_loss:.4f}")

FP32   test acc: 93.76% | loss: 0.2177


In [None]:
# ckpt =  torch.load("")
# model_fp = ResNetCIFAR(num_classes=10, in_channels=64, diff_encode=True, layers=(1, 1, 1, 2)).to(device)
# model_fp.load_state_dict(ckpt["state_dict"], strict=True)
# fp_loss, fp_acc = evaluate(model_fp, test_loader, True)
# print(f"FP32   test acc: {(fp_acc * 100):.2f}% | loss: {fp_loss:.4f}")

In [None]:
def iter_named_weight_tensors(model: nn.Module, include_bias: bool = False):
    # Only "synaptic" layers
    for mod_name, mod in model.named_modules():
        if isinstance(mod, (nn.Conv2d, nn.Linear)):
            w = getattr(mod, "weight", None)
            if w is not None:
                yield f"{mod_name}.weight", w
            if include_bias:
                b = getattr(mod, "bias", None)
                if b is not None:
                    yield f"{mod_name}.bias", b

In [None]:
w_mins, w_maxs = [], []
names = []

with torch.no_grad():
    for name, w in iter_named_weight_tensors(model, include_bias=False):
        names.append(name)
        w_mins.append(w.detach().min().item())
        w_maxs.append(w.detach().max().item())

print("tensors scanned:", len(names))
min_w = float(min(w_mins))
max_w = float(max(w_maxs))

print(f"Quant bounds: min={min_w:.6g}, max={max_w:.6g}")

tensors scanned: 15
Quant bounds: min=-0.441908, max=0.504627


In [None]:
def apply_qnoise(w: torch.Tensor, noise) -> torch.Tensor:
    # noise.forward(w) should return same shape tensor
    w_noisy = noise(w)
    return torch.clamp(w_noisy, min=min_weight, max=max_weight)

In [None]:
def _functional_call(module: nn.Module, params: dict, buffers: dict, x: torch.Tensor):
    # Preferred (PyTorch 2.x): torch.func.functional_call
    from torch.func import functional_call

    state = {**params, **buffers}
    return functional_call(module, state, (x,))

class DiffModel(nn.Module):
    """
    Wraps a base model, but runs forward with "virtual" modified weights:
      mode = "fp" | "quant" | "qnoise"
    """
    def __init__(self, base: nn.Module, mode: str, noise=None, include_bias: bool = True):
        super().__init__()
        assert mode in {"fp", "quant", "qnoise"}
        self.base = base
        self.mode = mode
        self.noise = noise
        self.include_bias = include_bias

        # Cache target parameter names for Conv/Linear
        self.target_param_names = set(name for name, _ in iter_named_weight_tensors(base, include_bias=include_bias))

    def forward(self, x):
        # Grab current params/buffers each call (safe even if base changes)
        params = dict(self.base.named_parameters())
        buffers = dict(self.base.named_buffers())

        # Build overridden params dict
        new_params = {}
        for name, p in params.items():
            if name in self.target_param_names:
                w = p
                if self.mode == "quant":
                    w = maskW(w)
                if self.mode == "qnoise":
                    if self.noise is None:
                        raise ValueError("mode='qnoise' requires a noise object")
                    w = apply_qnoise(w, self.noise)
                new_params[name] = w
            else:
                new_params[name] = p

        return _functional_call(self.base, new_params, buffers, x)

In [None]:
te_loss, te_acc = evaluate(model, test_loader, True)

In [None]:
te_acc * 100

93.76

In [None]:
diff_fp = DiffModel(model, mode="fp").to(device).eval()
fp_loss, fp_acc= evaluate(diff_fp, test_loader, True)
print(f"FP32    test acc: {(fp_acc * 100):.2f}% | loss: {fp_loss:.4f}")

In [None]:
diff_q = DiffModel(model, mode="quant").to(device).eval()
q_loss, q_acc = evaluate(diff_q, test_loader, True)
print(f"QUANT   test acc: {(q_acc * 100):.2f}% | loss: {q_loss:.4f}")

QUANT   test acc: 70.74% | loss: 1.0473


In [None]:
noise = SubthresholdPVTNoise()
noise.update_params()
diff_qn = DiffModel(model, mode="qnoise", noise=noise).to(device).eval()
qn_loss, qn_acc = evaluate(diff_qn, test_loader, True)
print(f"Q+NOISE test acc: {(qn_acc * 100):.2f}% | loss: {qn_loss:.4f}")

Q+NOISE test acc: 10.03% | loss: 2.7797
