For this network configure the constants as follows:
- population_code = False
- min_weight = -0.386
- max_weight = 0.386
- image_size = 28
- num_classes = 10

In [None]:
import torch, torch.nn as nn
import torch.nn.functional as F
from snntorch import functional as SF
from snntorch import spikegen
from torchvision import datasets
import numpy as np
from torch.utils.data import TensorDataset, DataLoader
import random
import time

In [None]:
from TVLSI26.datsets.mnist_dataset import MNISTdata
from TVLSI26.configs.config import modelConstants
from TVLSI26.neuron_models.digLIF import digLIF, Square
from TVLSI26.ctt_weights.weight_variations import maskW, WeightDropoutLinear, apply_quant_noise, add_retention_noise

In [None]:
# Run-level configuration.
# The seed mixes the wall clock (in ms) with 32 random bits, so every run differs.
seed = random.getrandbits(32) ^ int(time.time() * 1000)
counter_global = 0
data_path = '/home/zmoham13/pydata'
# Use the GPU when PyTorch can see one, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
dtype = torch.float

# Network and Learning

In [None]:
batch_size = 64

# Dataset location and image resolution come from the project configuration.
mnist_dataset = MNISTdata(
    data_path=modelConstants.data_path,
    FULL_MNIST=True,
    image_size=modelConstants.image_size,
)
# Both loaders drop the last incomplete batch; only the training loader shuffles.
train_loader = DataLoader(
    mnist_dataset.get_train_data(),
    batch_size=batch_size,
    shuffle=True,
    drop_last=True,
)
test_loader = DataLoader(
    mnist_dataset.get_test_data(),
    batch_size=batch_size,
    shuffle=False,
    drop_last=True,
)

# Pull one training batch up front; the names stay available at module level.
data = iter(train_loader)
data_it, targets_it = next(data)

# Report the size of each split.
for dataset_name, dataset in (
    ("mnist_train", mnist_dataset.get_train_data()),
    ("mnist_test", mnist_dataset.get_test_data()),
):
    print(f"Number of images in {dataset_name}: {len(dataset)}")

In [None]:
class NNfullyConnected(nn.Module):
    """Fully connected spiking network: three synapse layers, each followed by digLIF neurons.

    Layer sizes are ``num_inputs -> h1 -> h2 -> num_outputs``. The synaptic
    weights live in ``fcN.linear.weight`` (no bias) and are initialised inside
    ``[modelConstants.min_weight, modelConstants.max_weight]``.
    """

    def __init__(self,
                 num_inputs=28*28,
                 h1=256,
                 h2=128,
                 num_outputs=10):
        super().__init__()

        # Stored so the architecture can be read back from an instance.
        self.num_inputs  = num_inputs
        self.h1          = h1
        self.h2          = h2
        self.num_outputs = num_outputs

        # num_inputs -> h1 -> h2 -> num_outputs (784 -> 256 -> 128 -> 10 with the defaults)
        self.fc1 = WeightDropoutLinear(self.num_inputs,  self.h1,          bias=False)
        self.fc2 = WeightDropoutLinear(self.h1,          self.h2,          bias=False)
        self.fc3 = WeightDropoutLinear(self.h2,          self.num_outputs, bias=False)

        self.lif1 = digLIF(beta=1.0, reset_mechanism="zero")
        self.lif2 = digLIF(beta=1.0, reset_mechanism="zero")
        self.lif3 = digLIF(beta=1.0, reset_mechanism="zero")

        self.initialize_weights()

    def initialize_weights(self):
        """Draw all synaptic weights from a normal distribution and clamp them to the allowed range.

        The mean is the centre of ``[min_weight, max_weight]`` and the standard
        deviation is one eighth of the range width.
        """
        mean = (modelConstants.max_weight + modelConstants.min_weight) / 2.0
        std  = (modelConstants.max_weight - modelConstants.min_weight) / 8.0

        for layer in [self.fc1, self.fc2, self.fc3]:
            nn.init.normal_(layer.linear.weight, mean=mean, std=std)
            with torch.no_grad():
                layer.linear.weight.clamp_(modelConstants.min_weight, modelConstants.max_weight)

    def _noisy_weight(self, base_weight):
        """Return the weight tensor to use for this forward pass.

        ``base_weight`` is returned unchanged unless the module is in training
        mode and ``modelConstants.training_noise_ctt`` is set. In that case
        retention noise and/or quantization noise are applied, each behind its
        own ``modelConstants`` flag.
        """
        w = base_weight
        if self.training and modelConstants.training_noise_ctt:
            if modelConstants.retention_noise:
                w = add_retention_noise(w, modelConstants.std_ret_high, modelConstants.std_ret_low)
            if modelConstants.quantization_noise:
                w = apply_quant_noise(w)
        return w

    def _synaptic_current(self, spikes, weight, noise_weight, inject_noise):
        """Return one layer's input current for a single time step.

        When ``inject_noise`` is true, a perturbation ``linear(spikes, noise_weight)``
        is added. The perturbation is computed without gradient tracking, so it
        acts as a constant during back-propagation.
        """
        current = Square.apply(spikes, weight)
        if inject_noise:
            with torch.no_grad():
                noise = nn.functional.linear(spikes, noise_weight)
            # The addition stays outside no_grad so gradients still reach `current`.
            current = current + noise
        return current

    def forward(self, x):
        """
        x: [T, B, 1, 28, 28] or [T, B, 28, 28] or [T, B, 784]

        returns:
            spk_rec_3: [T, B, 10]
            mem_rec_3: [T, B, 10]

        raises:
            ValueError: if ``x`` is not 3-, 4- or 5-dimensional.
        """

        T = x.shape[0]

        # Flatten everything after [T, B]. reshape (not view) also accepts
        # non-contiguous inputs such as permuted tensors.
        if x.dim() in (4, 5):     # [T, B, H, W] or [T, B, C, H, W]
            x_flat = x.reshape(T, x.shape[1], -1)
        elif x.dim() == 3:        # [T, B, F]
            x_flat = x
        else:
            raise ValueError(f"Unexpected input shape {x.shape}")

        # init membrane
        mem1 = self.lif1.init_leaky()
        mem2 = self.lif2.init_leaky()
        mem3 = self.lif3.init_leaky()

        # Sample weight noise for this forward. This happens on every call, so the
        # nW* attributes and the random stream behave as before; the samples only
        # affect the output when noise injection is active (see inject_noise).
        self.nW1 = torch.normal(torch.zeros_like(self.fc1.linear.weight), modelConstants.std_ctt)
        self.nW2 = torch.normal(torch.zeros_like(self.fc2.linear.weight), modelConstants.std_ctt)
        self.nW3 = torch.normal(torch.zeros_like(self.fc3.linear.weight), modelConstants.std_ctt)

        # effective weights
        noisy_w1 = self._noisy_weight(self.fc1.linear.weight)
        noisy_w2 = self._noisy_weight(self.fc2.linear.weight)
        noisy_w3 = self._noisy_weight(self.fc3.linear.weight)

        # Decided once per forward: the per-step noise matmuls are skipped
        # entirely when their result would be discarded.
        inject_noise = self.training and modelConstants.general_noise

        spk_rec_3, mem_rec_3 = [], []

        for step in range(T):
            spk0 = x_flat[step]                      # [B, num_inputs]

            # layer 1
            cur1 = self._synaptic_current(spk0, noisy_w1, self.nW1, inject_noise)
            spk1, mem1 = self.lif1(cur1, mem1)

            # layer 2
            cur2 = self._synaptic_current(spk1, noisy_w2, self.nW2, inject_noise)
            spk2, mem2 = self.lif2(cur2, mem2)

            # layer 3 (output)
            cur3 = self._synaptic_current(spk2, noisy_w3, self.nW3, inject_noise)
            spk3, mem3 = self.lif3(cur3, mem3)

            spk_rec_3.append(spk3)
            mem_rec_3.append(mem3)

        spk_rec_3 = torch.stack(spk_rec_3, dim=0)   # [T, B, num_outputs]
        mem_rec_3 = torch.stack(mem_rec_3, dim=0)   # [T, B, num_outputs]

        return spk_rec_3, mem_rec_3


In [None]:
torch.manual_seed(seed)
net = NNfullyConnected(28*28, 256, 256, 10).to(device)

# Replace the constructor's initialisation with a uniform draw over
# [min_weight, max_weight] for every synapse layer.
with torch.no_grad():
    for layer in (net.fc1, net.fc2, net.fc3):
        layer.linear.weight.copy_(
            torch.rand_like(layer.linear.weight)
            * (modelConstants.max_weight - modelConstants.min_weight)
            + modelConstants.min_weight
        )

# Scale each neuron layer's threshold by the configured threshold voltage.
for lif in (net.lif1, net.lif2, net.lif3):
    lif.threshold.data = (lif.threshold * modelConstants.Threshold_voltage).to(device)

# Adam with the learning rate halved every 5 scheduler steps.
optimizer = torch.optim.Adam(net.parameters(), lr=1e-3, betas=(0.9, 0.999))
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=5, gamma=0.5)

ce = nn.CrossEntropyLoss()

In [None]:
# Total number of scalar entries across all parameters of the network.
n_params = sum([param.numel() for param in net.parameters()])
print(f"Total parameters: {n_params}")

In [None]:
import time
# Re-seed PyTorch with a fresh value (wall clock in ms mixed with 32 random bits).
seed = random.getrandbits(32) ^ int(time.time() * 1000)
torch.manual_seed(seed)
counter = 0

def train_and_evaluate(loss_fn = ce, num_epochs = 5, optimizer = optimizer, scheduler=scheduler, train_loader=train_loader, test_loader=test_loader):
    """Train the module-level ``net`` and evaluate it on the test set after every epoch.

    Each batch is latency-encoded over ``modelConstants.num_steps`` time steps.
    The per-class output spike counts (summed over time) are used as logits,
    both for the loss and for the prediction. After every optimizer step the
    synaptic weights are clamped back into
    ``[modelConstants.min_weight, modelConstants.max_weight]``.

    Args:
        loss_fn: Callable ``loss_fn(logits, targets)`` returning a scalar tensor.
        num_epochs: Number of passes over ``train_loader``.
        optimizer: Optimizer stepped once per batch. It must hold the
            parameters of the module-level ``net``.
        scheduler: Learning-rate scheduler stepped once per epoch.
        train_loader: Iterable of ``(data, targets)`` training batches.
        test_loader: Iterable of ``(data, targets)`` evaluation batches.

    Returns:
        Test accuracy in percent after the last epoch, or ``0.0`` when nothing
        was evaluated (``num_epochs == 0`` or an empty ``test_loader``).
    """

    def _encode(batch):
        # Latency-code the batch, then cumulative-sum along the first axis.
        # Presumably this holds each input high from its spike time onwards
        # (confirm against spikegen.latency).
        return spikegen.latency(
            batch, num_steps=modelConstants.num_steps, threshold=0.01, clip=True,
            first_spike_time=0, linear=True, normalize=True
        ).cumsum(0)

    epoch_times = []
    # Bound up front so the final print/return are valid even if a loop body never runs.
    loss_val = None
    correct, total = 0, 0

    for epoch in range(num_epochs):
        # Evaluation below switches to eval mode, so training mode is restored once per epoch.
        net.train()

        start_time = time.time()
        for data, targets in train_loader:
            data = _encode(data.to(device))
            targets = targets.to(device)

            spk_rec, _ = net(data)

            logits = spk_rec.sum(dim=0)
            loss_val = loss_fn(logits, targets)

            optimizer.zero_grad()
            loss_val.backward()
            optimizer.step()

            # Keep the weights inside the allowed range after the update.
            with torch.no_grad():
                for layer in [net.fc1, net.fc2, net.fc3]:
                    layer.linear.weight.clamp_(modelConstants.min_weight, modelConstants.max_weight)

        end_time = time.time()
        epoch_times.append(end_time - start_time)
        print("epoch ", epoch, " loss: ", loss_val)
        scheduler.step()

        with torch.no_grad():
            net.eval()
            total = 0
            correct = 0

            for data, targets in test_loader:
                data = _encode(data.to(device))
                targets = targets.to(device)
                spk_rec, _ = net(data)

                logits = spk_rec.sum(dim=0)
                predicted = logits.argmax(dim=-1)
                correct += (predicted == targets).sum().item()

                total += targets.size(0)

            accuracy = 100 * correct / total if total else 0.0
            print(f"Total correctly classified test set images: {correct}/{total}")
            print(f"Test Set Accuracy: {accuracy:.2f}%\n")
            for param_group in optimizer.param_groups:
                current_lr = param_group['lr']
                print(f"Current learning rate: {current_lr}")
            # Optional per-epoch logging, disabled:
            #     with open("model_accuracy.txt", "a") as f:
            #         f.write(f"{epoch}: {100 * correct / total:.2f}%\n")

    print("average epoch time: ", np.mean(epoch_times) if epoch_times else float("nan"))
    return 100 * (correct / total) if total else 0.0

In [None]:
# Threshold sweep: train with each candidate threshold and track the best/worst result.
# Every tracker is initialised here so none of them can be unbound after the loop
# (e.g. an accuracy of exactly 0 or 100 never satisfies one of the comparisons).
best_accuracy = 0
best_params = None
best_weights = None
worst_accuracy = 100
worst_params = None

thr_values = [0.4]

hyperparams = [{"thr": thr} for thr in thr_values]
optimizer = torch.optim.Adam(net.parameters(), lr=1e-3, betas=(0.9, 0.999))
scheduler = torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(
    optimizer,
    T_0=2,          # first cycle length
    T_mult=1,       # cycle length stays constant
    eta_min=1e-5    # small floor, not zero
)

for params in hyperparams:
    V_th = params["thr"]
    # NOTE(review): only the first layer's threshold is swept; lif2/lif3 keep the
    # values set earlier. Confirm this is intended.
    # Created on `device` so it stays alongside the other thresholds.
    net.lif1.threshold.data = torch.tensor(V_th, device=device)
    accuracy = train_and_evaluate(loss_fn=ce, num_epochs = 30, optimizer = optimizer, scheduler=scheduler)
    print(f"Accuracy with params {params}: {accuracy}")

    if accuracy > best_accuracy:
        best_accuracy = accuracy
        best_params = params
        # detach() stores a plain snapshot instead of a tensor tied to the autograd graph.
        best_weights = net.fc1.linear.weight.detach().clone()
    if accuracy < worst_accuracy:
        worst_accuracy = accuracy
        worst_params = params

    print(f"Best parameters: {best_params} with accuracy: {best_accuracy}")

In [None]:
@torch.no_grad()
def clone_model_fp(model_fp: nn.Module) -> nn.Module:
    """
    Build a fresh NNfullyConnected with the layer sizes stored on ``model_fp``,
    place it on the same device as ``model_fp``'s parameters, load
    ``model_fp``'s state_dict into it (strict) and return it in eval mode.
    This is used instead of deepcopy().
    """
    target_device = next(model_fp.parameters()).device

    # The architecture is read back from the attributes NNfullyConnected stores in __init__.
    architecture = {
        "num_inputs": model_fp.num_inputs,
        "h1": model_fp.h1,
        "h2": model_fp.h2,
        "num_outputs": model_fp.num_outputs,
    }
    replica = NNfullyConnected(**architecture)
    replica = replica.to(target_device)

    replica.load_state_dict(model_fp.state_dict(), strict=True)
    return replica.eval()

# Quantization copy (no deepcopy)
def quantize_tensor_uniform(w: torch.Tensor, lo: float, hi: float, levels: int) -> torch.Tensor:
    """Snap ``w`` to the nearest of ``levels`` evenly spaced values spanning ``[lo, hi]``.

    Values outside the range are clamped to ``lo`` / ``hi`` first. The input
    tensor is not modified; a new tensor is returned.

    Args:
        w: Tensor to quantize.
        lo: Smallest representable value.
        hi: Largest representable value; must be greater than ``lo``.
        levels: Number of grid points including both end points; at least 2.

    Raises:
        ValueError: If ``levels < 2`` or ``hi`` is not greater than ``lo``.
            Previously these cases produced a ZeroDivisionError or NaN output.
    """
    if levels < 2:
        raise ValueError(f"levels must be at least 2, got {levels}")
    if not hi > lo:
        raise ValueError(f"hi must be greater than lo, got lo={lo}, hi={hi}")

    step = (hi - lo) / (levels - 1)
    clipped = w.clamp(lo, hi)
    # Index of the nearest grid point, mapped back to a value.
    return torch.round((clipped - lo) / step) * step + lo


@torch.no_grad()
def apply_quantization_copy(model_fp: nn.Module, lo: float, hi: float, levels: int, include_bias: bool = True):
    """
    Return a quantized copy of ``model_fp``; ``model_fp`` itself is not modified.

    The copy comes from clone_model_fp (no deepcopy). For fc1, fc2 and fc3 the
    weight, and the bias when ``include_bias`` is set and a bias exists, is
    replaced by its uniformly quantized version on ``[lo, hi]`` with
    ``levels`` grid points.
    """
    model_q = clone_model_fp(model_fp)

    def snap(tensor):
        return quantize_tensor_uniform(tensor, lo=lo, hi=hi, levels=levels)

    # The synaptic weights live in WeightDropoutLinear.linear.
    for wrapper in (model_q.fc1, model_q.fc2, model_q.fc3):
        dense = wrapper.linear
        dense.weight.copy_(snap(dense.weight))
        if not include_bias or dense.bias is None:
            continue
        dense.bias.copy_(snap(dense.bias))

    return model_q

# Noise-only: apply SubthresholdPVTNoise to weights
@torch.no_grad()
def apply_subthreshold_noise_copy(model_fp: nn.Module, noise_module: nn.Module):
    """
    Return a copy of ``model_fp`` whose fc1/fc2/fc3 weights were passed through
    ``noise_module`` and then clamped to
    ``[modelConstants.min_weight, modelConstants.max_weight]``.
    ``model_fp`` is left untouched; the copy is returned in eval mode.
    """
    model_n = clone_model_fp(model_fp)

    # Only the synaptic weights are perturbed.
    for wrapper in (model_n.fc1, model_n.fc2, model_n.fc3):
        weight = wrapper.linear.weight
        weight.copy_(noise_module(weight))
        # Stay inside the same weight range used during training.
        weight.clamp_(modelConstants.min_weight, modelConstants.max_weight)

    return model_n.eval()


# One-call: noisy weights inference (original net untouched)
@torch.no_grad()
def evaluate_with_noisy_weights(net_fp: nn.Module, test_loader, noise_module: nn.Module):
    """
    Evaluate a noisy-weight copy of ``net_fp`` on ``test_loader``.

    Returns:
        ``(loss, accuracy)`` with accuracy as a fraction in ``[0, 1]``.

    The ``infer_accuracy_snn`` defined in this file returns a single accuracy
    in percent and no loss, so unpacking two values from it raised a TypeError.
    A scalar result is therefore converted to a fraction and paired with a NaN
    loss. A ``(loss, accuracy)`` pair is passed through unchanged.
    """
    net_noisy = apply_subthreshold_noise_copy(net_fp, noise_module)
    result = infer_accuracy_snn(net_noisy, test_loader)

    if isinstance(result, (tuple, list)):
        noisy_loss, noisy_acc = result
        return noisy_loss, noisy_acc

    # Scalar percentage: no loss is available from the evaluation helper.
    return float("nan"), result / 100.0


# NOTE(review): SubthresholdPVTNoise is not imported in the visible cells; confirm where it is defined.
noise = SubthresholdPVTNoise().to(device)   # .to() assumes the class is an nn.Module
n_loss, n_acc = evaluate_with_noisy_weights(net_fp=net, test_loader=test_loader, noise_module=noise)
print(f"NOISY  test acc: {n_acc*100:.2f}% | loss: {n_loss:.4f}")


NOISY  test acc: 97.51% | loss: 0.0921


In [None]:
import copy
import torch.nn as nn

@torch.no_grad()
def quantize_model_copy(model_fp: nn.Module,
                        lo: float,
                        hi: float,
                        levels: int,
                        include_bias: bool = True) -> nn.Module:
    """
    Return a quantized copy of model_fp. model_fp is NOT modified.

    The copy is made with clone_model_fp, moved to the module-level ``device``
    and put in eval mode. For every nn.Linear submodule the weight is clamped
    to [lo, hi] and replaced by ``maskW(weight)``. A bias, when present and
    ``include_bias`` is set, is clamped and rounded onto the uniform grid of
    ``levels`` points on [lo, hi].

    NOTE(review): ``levels`` only influences the bias path; the weight grid is
    whatever ``maskW`` implements. Confirm the two agree.
    """
    model_q = clone_model_fp(model_fp).to(device)
    model_q.eval()

    step = (hi - lo) / (levels - 1)

    for mod in model_q.modules():
        # Synapses are the nn.Linear modules wrapped by WeightDropoutLinear; skip everything else.
        if not isinstance(mod, nn.Linear):
            continue

        weight = mod.weight
        weight.clamp_(lo, hi)
        weight.copy_(maskW(weight))

        bias = mod.bias if include_bias else None
        if bias is None:
            continue
        bias.clamp_(lo, hi)
        bias.copy_(torch.round((bias - lo) / step) * step + lo)

    return model_q


@torch.no_grad()
def infer_accuracy_snn(model: nn.Module, loader, population_code: bool = False):
    """
    Return the classification accuracy of ``model`` on ``loader`` in percent.

    Inputs are latency-encoded over ``modelConstants.num_steps`` steps and
    cumulatively summed along the first axis. The prediction is the class
    with the highest output spike count summed over time.
    ``population_code`` is accepted but not used by this implementation.
    """
    model.eval()
    n_seen = 0
    n_right = 0

    for images, labels in loader:
        images = images.to(device)
        labels = labels.to(device)

        # Same preprocessing as the training loop.
        encoded = spikegen.latency(
            images,
            num_steps=modelConstants.num_steps,
            threshold=0.01,
            clip=True,
            first_spike_time=0,
            linear=True,
            normalize=True,
        )
        encoded = encoded.cumsum(0)

        spk_rec, _ = model(encoded)

        # Spike counts per class act as logits: [T, B, classes] -> [B, classes] -> [B]
        pred = spk_rec.sum(dim=0).argmax(dim=-1)
        n_right += (pred == labels).sum().item()
        n_seen += labels.size(0)

    return 100.0 * float(n_right) / float(n_seen)


# --- Build quantized copy + run inference ---
levels = modelConstants.num_w_levels
net_q = quantize_model_copy(
    net,
    lo=modelConstants.min_weight,
    hi=modelConstants.max_weight,
    levels=levels,
    include_bias=True,
)

# Accuracy of the original network and of its quantized copy on the same test set.
acc_fp = infer_accuracy_snn(net, test_loader)
acc_q  = infer_accuracy_snn(net_q, test_loader)

print(f"FP32 test acc:  {acc_fp:.2f}%")
print(f"QUANT test acc: {acc_q:.2f}% (levels={levels}, range=[{modelConstants.min_weight},{modelConstants.max_weight}])")

# Largest absolute difference between the original and quantized fc1 weights.
# This is the quantization error, so a non-zero value is expected.
with torch.no_grad():
    diff = torch.max(torch.abs(net.fc1.linear.weight - net_q.fc1.linear.weight)).item()
print(f"Sanity: max|fc1_fp - fc1_quant| = {diff:.6f}")

FP32 test acc:  97.49%
QUANT test acc: 97.64% (levels=24, range=[-0.386,0.386])
Sanity: max|fc1_fp - fc1_quant| = 0.016782


In [None]:
import optuna
def optuna_objective(trial):
    """Optuna objective: build a network from the sampled hyperparameters, train it briefly, return test accuracy.

    Sampled parameters: hidden sizes ``h1``/``h2``, threshold scale
    ``thr_scale`` and the number of time steps ``num_steps``.

    Side effects: rebinds the module-level ``net``, ``optimizer``,
    ``scheduler`` and ``num_steps``, and writes the sampled value to
    ``modelConstants.num_steps``. After a trial the global state therefore
    describes that trial's network.

    Returns:
        Test accuracy in percent as returned by ``train_and_evaluate``.
    """
    global net, optimizer, scheduler, num_steps

    h1 = trial.suggest_int("h1", 256, 512, step=256)
    h2 = trial.suggest_int("h2", 128, 256, step=128)

    thr_scale = trial.suggest_float("thr_scale", 0.2, 0.6, step=0.1)
    num_steps = trial.suggest_int("num_steps", 10, 40, step=10)
    # train_and_evaluate reads the step count from modelConstants, not from the
    # global, so the sampled value has to be written there to take effect.
    modelConstants.num_steps = num_steps

    # number of epochs per trial (keep small for speed)
    num_epochs = 4

    net = NNfullyConnected(
        num_inputs=28*28,
        h1=h1,
        h2=h2,
        num_outputs=modelConstants.num_classes,
    ).to(device)

    # Same in-place pattern as the initial network setup.
    for lif in [net.lif1, net.lif2, net.lif3]:
        lif.threshold.data = (lif.threshold * thr_scale * modelConstants.Threshold_voltage).to(device)

    # optimizer + scheduler (learning rate is fixed; it is not part of the search space)
    optimizer = torch.optim.Adam(net.parameters(), lr=1e-3, betas=(0.9, 0.999))
    scheduler = torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(
        optimizer,
        T_0=2,          # first cycle length
        T_mult=1,       # cycle length stays constant
        eta_min=1e-5    # small floor, not zero
    )

    # NOTE(review): the noise flags in modelConstants (training_noise_ctt, general_noise,
    # retention_noise, quantization_noise) are expected to be off for this search.
    # They are not changed here.

    # ---------- Train & evaluate ----------
    # Cross-entropy on output spike counts.
    acc = train_and_evaluate(
        loss_fn=ce,
        num_epochs=num_epochs,
        optimizer=optimizer,
        scheduler=scheduler,
    )

    return acc


In [None]:
# Search for the hyperparameters that maximise the accuracy returned by optuna_objective.
study = optuna.create_study(direction="maximize")
study.optimize(func=optuna_objective, n_trials=20)  # start with 20; increase if useful

print("Best value (accuracy):", study.best_value)
print("Best params:", study.best_params)

In [None]:
# One line per trial: index, objective value, sampled parameters.
for trial in study.trials:
    print(trial.number, trial.value, trial.params)

In [None]:
import os
# Create the output directory if needed, then save the current network's state_dict.
weights_dir = "/home/zmoham13/model_weights"
os.makedirs(weights_dir, exist_ok=True)
torch.save(net.state_dict(), "/home/zmoham13/model_weights/fc_net_weights.pth")

In [None]:
# Retrain from scratch with the best hyperparameters found by the study.
best = study.best_params

h1 = best["h1"]
h2 = best["h2"]
# "lr" is not part of the search space (every trial trains with 1e-3), so
# indexing it raised KeyError. Fall back to the value the trials used.
lr = best.get("lr", 1e-3)
thr_scale = best["thr_scale"]
num_steps = best["num_steps"]
# train_and_evaluate reads the step count from modelConstants, so apply it there.
modelConstants.num_steps = num_steps

net = NNfullyConnected(
    num_inputs=28*28,
    h1=h1,
    h2=h2,
    num_outputs=modelConstants.num_classes,
).to(device)

# Same in-place pattern as the initial network setup.
for lif in [net.lif1, net.lif2, net.lif3]:
    lif.threshold.data = (lif.threshold * thr_scale * modelConstants.Threshold_voltage).to(device)

optimizer = torch.optim.Adam(net.parameters(), lr=lr, betas=(0.9, 0.999))
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=5, gamma=0.5)

final_acc = train_and_evaluate(
    loss_fn=ce,
    num_epochs=15,
    optimizer=optimizer,
    scheduler=scheduler,
)

print("Final accuracy with best params:", final_acc)
