In [None]:
import torch, torch.nn as nn
from snntorch import spikegen
import numpy as np
from torch.utils.data import DataLoader
import random
import time

In [None]:
from TVLSI26.datsets.mnist_dataset import MNISTdata
from TVLSI26.configs.config import modelConstants
from TVLSI26.neuron_models.digLIF import digLIF, Square
from TVLSI26.ctt_weights.weight_variations import maskW, WeightDropoutLinear, apply_quant_noise, add_retention_noise

In [None]:
import os

# Fresh, time-derived seed for each kernel session. It is printed so that a run can be
# reproduced by hard-coding the printed value here. In this notebook it seeds torch only
# (`torch.manual_seed(seed)` in the training cell); `random` and numpy are not seeded.
seed = int(time.time() * 1000) ^ random.getrandbits(32)
print(f"seed: {seed}")

counter_global = 0
# Output directory. Set the RESULTS_PATH environment variable to override this
# machine-specific default instead of editing the notebook.
results_path = os.environ.get("RESULTS_PATH", "C:/Users/rezva/Documents/hardwareAwareLearning/results/")
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
dtype = torch.float

For this network, set the following constants in `modelConstants` (module `TVLSI26.configs.config`) before running the notebook:
- `population_code = False`
- `min_weight = 0.02`
- `max_weight = 0.8`
- `image_size = 28`
- `num_classes = 10`

In [None]:
batch_size = 64
mnist_dataset = MNISTdata(data_path=modelConstants.data_path, FULL_MNIST=True, image_size=modelConstants.image_size)

# Build each split once and reuse it for the loader and for the size report below, so the
# printed counts describe exactly the datasets that are iterated.
train_dataset = mnist_dataset.get_train_data()
test_dataset = mnist_dataset.get_test_data()

train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, drop_last=True)
# drop_last=True also on the test split: the final partial batch is skipped, so the reported
# accuracy covers len(test_dataset) // batch_size * batch_size images.
test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False, drop_last=True)

# Peek at one training batch (data_it: images, targets_it: labels) for quick inspection.
data = iter(train_loader)
data_it, targets_it = next(data)

for dataset_name, dataset in (("mnist_train", train_dataset), ("mnist_test", test_dataset)):
    print(f"Number of images in {dataset_name}: {len(dataset)}")

In [None]:
class VGG5_SNN(nn.Module):
    """Convolutional spiking network: three conv + digLIF stages and a bias-free spiking readout.

    ``forward`` expects spikes shaped [T, B, 1, H, W]. The readout has 256*7*7 inputs and
    there are two 2x2 max-pools, so H = W = 28 (28 -> 14 -> 7). It returns the output
    spike train [T, B, num_classes] and ``None`` in place of a membrane record.
    """

    def __init__(self, num_classes=10):
        super().__init__()

        # --- Feature extraction ---
        self.conv1 = nn.Conv2d(1, 64, 3, padding=1, bias=False)
        self.lif1 = digLIF(beta=1.0, reset_mechanism="zero")

        self.conv2 = nn.Conv2d(64, 128, 3, padding=1, bias=False)
        self.lif2 = digLIF(beta=1.0, reset_mechanism="zero")

        self.pool1 = nn.MaxPool2d(2, 2)   # after layer 1

        self.conv3 = nn.Conv2d(128, 256, 3, padding=1, bias=False)
        self.lif3 = digLIF(beta=1.0, reset_mechanism="zero")

        self.pool2 = nn.MaxPool2d(2, 2)   # after layer 2

        self.fc = nn.Linear(256 * 7 * 7, num_classes, bias=False)
        self.lif_out = digLIF(beta=1.0, reset_mechanism="zero")

        self._init_weights()

    def _init_weights(self):
        """Draw all weights from N(mean, std) and clamp them to [min_weight, max_weight].

        mean is the midpoint of the weight range and std is one eighth of its width.
        """
        mean = (modelConstants.max_weight + modelConstants.min_weight) / 2.0
        std = (modelConstants.max_weight - modelConstants.min_weight) / 8.0

        for layer in [self.conv1, self.conv2, self.conv3, self.fc]:
            nn.init.normal_(layer.weight, mean=mean, std=std)
            with torch.no_grad():
                layer.weight.clamp_(modelConstants.min_weight, modelConstants.max_weight)

    def _noisy_weight(self, base_weight):
        """Return ``base_weight`` with the enabled CTT noise models applied, in training mode only.

        Not used by ``forward``: this model currently runs on the clean weights.
        """
        w = base_weight
        if self.training and modelConstants.training_noise_ctt:
            if modelConstants.retention_noise:
                w = add_retention_noise(w, modelConstants.std_ret_high, modelConstants.std_ret_low)
            if modelConstants.quantization_noise:
                w = apply_quant_noise(w)
        return w

    @staticmethod
    def _run_lif(lif, currents):
        """Feed ``currents[t]`` to ``lif`` for every step t, starting from a fresh membrane.

        Returns the spikes stacked along a new leading time axis.
        """
        mem = lif.init_leaky()
        spikes = []
        for cur_t in currents:            # iterates over dim 0 (time)
            spk, mem = lif(cur_t, mem)
            spikes.append(spk)
        return torch.stack(spikes)

    def forward(self, x):
        """x: [T, B, C, H, W] input spikes -> (output spikes [T, B, num_classes], None)."""
        T, B, C, H, W = x.shape
        # Fold time into the batch so each convolution runs once over all steps.
        x = x.reshape(T * B, C, H, W)

        # ========== Layer 1 ==========
        z1 = self.conv1(x)                                         # (T·B,64,H,W)
        spk1 = self._run_lif(self.lif1, z1.reshape(T, B, *z1.shape[1:]))
        p1 = self.pool1(spk1.flatten(0, 1))                        # (T·B,64,H/2,W/2)

        # ========== Layer 2 ==========
        z2 = self.conv2(p1)                                        # (T·B,128,H/2,W/2)
        spk2 = self._run_lif(self.lif2, z2.reshape(T, B, *z2.shape[1:]))
        p2 = self.pool2(spk2.flatten(0, 1))                        # (T·B,128,H/4,W/4)

        # ========== Layer 3 ==========
        z3 = self.conv3(p2)                                        # (T·B,256,H/4,W/4)
        spk3 = self._run_lif(self.lif3, z3.reshape(T, B, *z3.shape[1:]))

        # ========== FC Layer ==========
        flat = spk3.reshape(T, B, -1)                              # (T,B,256*7*7) for 28x28 input
        zF = torch.einsum("tbi,oi->tbo", flat, self.fc.weight)
        spk_out = self._run_lif(self.lif_out, zF)                  # (T,B,num_classes)
        return spk_out, None


class MNIST_CNN_SNN(nn.Module):
    """Hybrid MNIST classifier: a non-spiking conv feature extractor and a 3-layer digLIF classifier.

    The conv stack (conv -> ReLU -> 2x2 max-pool, twice) maps every [1, 28, 28] time slice
    to 512 features. Those features drive three bias-free ``WeightDropoutLinear`` layers,
    whose currents are computed with ``Square.apply`` and fed to digLIF neurons. While
    training, flags in ``modelConstants`` enable weight-noise models (retention and
    quantization) and an additive Gaussian current-noise term.

    NOTE(review): the default ``num_outputs=100`` assumes 10 neurons per class, but the
    notebook's training loop applies CrossEntropyLoss directly over all outputs. With
    population_code=False, pass ``num_outputs=num_classes``.
    """

    def __init__(self,
                 num_outputs=100,   # 10 neurons per class
                 hidden1=512,
                 hidden2=256):

        super().__init__()

        self.conv1 = nn.Conv2d(1, 16, kernel_size=5, padding=0)
        self.conv2 = nn.Conv2d(16, 32, kernel_size=5, padding=0)

        # If batchnorm allowed (recommended):
        # self.bn1 = nn.BatchNorm2d(16)
        # self.bn2 = nn.BatchNorm2d(32)

        self.pool = nn.MaxPool2d(2)

        # After two conv + pool layers:
        # 28x28 → 24x24 → pool → 12x12
        # 12x12 → 8x8 → pool → 4x4
        # channels=32 → 32*4*4 = 512 features

        self.flat_features = 32 * 4 * 4

        # ----------------------
        # Fully-connected SNN
        # ----------------------
        self.fc1 = WeightDropoutLinear(self.flat_features, hidden1, bias=False)
        self.fc2 = WeightDropoutLinear(hidden1,        hidden2, bias=False)
        self.fc3 = WeightDropoutLinear(hidden2,        num_outputs, bias=False)

        self.lif1 = digLIF(beta=1.0, reset_mechanism="zero")
        self.lif2 = digLIF(beta=1.0, reset_mechanism="zero")
        self.lif3 = digLIF(beta=1.0, reset_mechanism="zero")

        self.initialize_weights()

    def initialize_weights(self):
        """Draw fc1-fc3 weights from N(mean, std) and clamp them to [min_weight, max_weight].

        The conv layers keep PyTorch's default initialisation.
        """
        mean = (modelConstants.max_weight + modelConstants.min_weight) / 2.0
        std  = (modelConstants.max_weight - modelConstants.min_weight) / 8.0

        for layer in [self.fc1, self.fc2, self.fc3]:
            nn.init.normal_(layer.linear.weight, mean=mean, std=std)
            with torch.no_grad():
                layer.linear.weight.clamp_(modelConstants.min_weight, modelConstants.max_weight)

    def _noisy_weight(self, w):
        """Return ``w`` with the enabled CTT noise models applied, in training mode only."""
        if self.training and modelConstants.training_noise_ctt:
            if modelConstants.retention_noise:
                w = add_retention_noise(w, modelConstants.std_ret_high, modelConstants.std_ret_low)
            if modelConstants.quantization_noise:
                w = apply_quant_noise(w)
        return w

    def _features(self, x):
        """Conv feature extractor: [N, 1, 28, 28] -> [N, 512]."""
        out = self.pool(torch.relu(self.conv1(x)))     # [N, 16, 12, 12]
        out = self.pool(torch.relu(self.conv2(out)))   # [N, 32, 4, 4]
        return out.reshape(out.shape[0], -1)           # [N, 512]

    def _synaptic_current(self, inp, w, noise_w, inject_noise):
        """Return ``Square.apply(inp, w)``, plus ``linear(inp, noise_w)`` when ``inject_noise``.

        The noise term is computed without autograd. It perturbs the forward value but adds
        no gradient path; the addition itself stays in the graph so ``cur`` remains trainable.
        """
        cur = Square.apply(inp, w)
        if inject_noise:
            with torch.no_grad():
                noise = nn.functional.linear(inp, noise_w)
            cur = cur + noise
        return cur

    def forward(self, x):
        """
        x: [T, B, 1, 28, 28]
        Output:
            spk_rec_3: [T, B, num_outputs]
            mem_rec_3: [T, B, num_outputs]
        """

        T, B = x.shape[0], x.shape[1]

        # ----------------------
        # Initialize SNN states
        # ----------------------
        mem1 = self.lif1.init_leaky()
        mem2 = self.lif2.init_leaky()
        mem3 = self.lif3.init_leaky()

        # Fresh Gaussian current-noise weights for this forward pass (kept as attributes)
        self.nW1 = torch.normal(torch.zeros_like(self.fc1.linear.weight), modelConstants.std_ctt)
        self.nW2 = torch.normal(torch.zeros_like(self.fc2.linear.weight), modelConstants.std_ctt)
        self.nW3 = torch.normal(torch.zeros_like(self.fc3.linear.weight), modelConstants.std_ctt)

        w1 = self._noisy_weight(self.fc1.linear.weight)
        w2 = self._noisy_weight(self.fc2.linear.weight)
        w3 = self._noisy_weight(self.fc3.linear.weight)

        # Decide once per pass whether the additive current noise is used. When it is not,
        # its per-step matmuls are skipped entirely instead of being computed and discarded.
        inject_noise = self.training and modelConstants.general_noise

        # The conv stack has no state across steps, so all T slices go through it in one
        # batched pass: [T, B, 1, 28, 28] -> [T, B, 512].
        feats = self._features(x.reshape(T * B, *x.shape[2:])).reshape(T, B, -1)

        spk_out, mem_out = [], []

        # ----------------------
        # Temporal loop (SNN classifier only)
        # ----------------------
        for step in range(T):
            cur1 = self._synaptic_current(feats[step], w1, self.nW1, inject_noise)
            spk1, mem1 = self.lif1(cur1, mem1)

            cur2 = self._synaptic_current(spk1, w2, self.nW2, inject_noise)
            spk2, mem2 = self.lif2(cur2, mem2)

            cur3 = self._synaptic_current(spk2, w3, self.nW3, inject_noise)
            spk3, mem3 = self.lif3(cur3, mem3)

            spk_out.append(spk3)
            mem_out.append(mem3)

        return torch.stack(spk_out, dim=0), torch.stack(mem_out, dim=0)

class NNfullyConnected(nn.Module):
    """Fully-connected 3-layer digLIF network (num_inputs -> h1 -> h2 -> num_outputs).

    Every layer is a bias-free ``WeightDropoutLinear`` whose current is computed with
    ``Square.apply``. While training, flags in ``modelConstants`` enable weight-noise models
    (retention and quantization) and an additive Gaussian current-noise term.
    """

    def __init__(self,
                 num_inputs=28*28,
                 h1=512,
                 h2=256,
                 num_outputs=10):
        super().__init__()

        self.num_inputs  = num_inputs
        self.h1          = h1
        self.h2          = h2
        self.num_outputs = num_outputs

        # 784 -> 512 -> 256 -> 10 with the default sizes
        self.fc1 = WeightDropoutLinear(self.num_inputs,  self.h1,          bias=False)
        self.fc2 = WeightDropoutLinear(self.h1,          self.h2,          bias=False)
        self.fc3 = WeightDropoutLinear(self.h2,          self.num_outputs, bias=False)

        self.lif1 = digLIF(beta=1.0, reset_mechanism="zero")
        self.lif2 = digLIF(beta=1.0, reset_mechanism="zero")
        self.lif3 = digLIF(beta=1.0, reset_mechanism="zero")

        self.initialize_weights()

    def initialize_weights(self):
        """Draw weights from N(mean, std), centred in [min_weight, max_weight], and clamp to that range."""
        mean = (modelConstants.max_weight + modelConstants.min_weight) / 2.0
        std  = (modelConstants.max_weight - modelConstants.min_weight) / 8.0

        for layer in [self.fc1, self.fc2, self.fc3]:
            nn.init.normal_(layer.linear.weight, mean=mean, std=std)
            with torch.no_grad():
                layer.linear.weight.clamp_(modelConstants.min_weight, modelConstants.max_weight)

    def _noisy_weight(self, base_weight):
        """Return ``base_weight`` with the enabled CTT noise models applied, in training mode only."""
        w = base_weight
        if self.training and modelConstants.training_noise_ctt:
            if modelConstants.retention_noise:
                w = add_retention_noise(w, modelConstants.std_ret_high, modelConstants.std_ret_low)
            if modelConstants.quantization_noise:
                w = apply_quant_noise(w)
        return w

    def _synaptic_current(self, inp, w, noise_w, inject_noise):
        """Return ``Square.apply(inp, w)``, plus ``linear(inp, noise_w)`` when ``inject_noise``.

        The noise term is computed without autograd. It perturbs the forward value but adds
        no gradient path; the addition itself stays in the graph so ``cur`` remains trainable.
        """
        cur = Square.apply(inp, w)
        if inject_noise:
            with torch.no_grad():
                noise = nn.functional.linear(inp, noise_w)
            cur = cur + noise
        return cur

    def forward(self, x):
        """
        x: [T, B, 1, 28, 28] or [T, B, 28, 28] or [T, B, 784]

        returns:
            spk_rec_3: [T, B, num_outputs]
            mem_rec_3: [T, B, num_outputs]

        raises:
            ValueError: if x is not 3-, 4- or 5-D, or its flattened feature count != num_inputs.
        """
        # Flatten everything after (T, B). reshape (unlike view) also accepts non-contiguous input.
        if x.dim() in (4, 5):          # [T, B, H, W] or [T, B, C, H, W]
            x_flat = x.reshape(x.shape[0], x.shape[1], -1)
        elif x.dim() == 3:             # [T, B, F]
            x_flat = x
        else:
            raise ValueError(f"Unexpected input shape {x.shape}")
        if x_flat.shape[-1] != self.num_inputs:
            raise ValueError(
                f"Expected {self.num_inputs} input features per step, got {x_flat.shape[-1]} "
                f"(input shape {tuple(x.shape)})"
            )

        T = x_flat.shape[0]

        # init membrane
        mem1 = self.lif1.init_leaky()
        mem2 = self.lif2.init_leaky()
        mem3 = self.lif3.init_leaky()

        # Fresh Gaussian current-noise weights for this forward pass (kept as attributes)
        self.nW1 = torch.normal(torch.zeros_like(self.fc1.linear.weight), modelConstants.std_ctt)
        self.nW2 = torch.normal(torch.zeros_like(self.fc2.linear.weight), modelConstants.std_ctt)
        self.nW3 = torch.normal(torch.zeros_like(self.fc3.linear.weight), modelConstants.std_ctt)

        # effective weights
        noisy_w1 = self._noisy_weight(self.fc1.linear.weight)
        noisy_w2 = self._noisy_weight(self.fc2.linear.weight)
        noisy_w3 = self._noisy_weight(self.fc3.linear.weight)

        # Decide once per pass whether the additive current noise is used. When it is not,
        # its per-step matmuls are skipped entirely instead of being computed and discarded.
        inject_noise = self.training and modelConstants.general_noise

        spk_rec_3, mem_rec_3 = [], []

        for step in range(T):
            spk0 = x_flat[step]                      # [B, num_inputs]

            spk1, mem1 = self.lif1(self._synaptic_current(spk0, noisy_w1, self.nW1, inject_noise), mem1)
            spk2, mem2 = self.lif2(self._synaptic_current(spk1, noisy_w2, self.nW2, inject_noise), mem2)
            spk3, mem3 = self.lif3(self._synaptic_current(spk2, noisy_w3, self.nW3, inject_noise), mem3)

            spk_rec_3.append(spk3)
            mem_rec_3.append(mem3)

        spk_rec_3 = torch.stack(spk_rec_3, dim=0)   # [T, B, num_outputs]
        mem_rec_3 = torch.stack(mem_rec_3, dim=0)   # [T, B, num_outputs]

        return spk_rec_3, mem_rec_3


In [None]:
# One output neuron per class: train_and_evaluate applies CrossEntropyLoss directly to the
# per-neuron spike counts with targets 0..num_classes-1, and population_code is False for
# this network. (The class default of 100 outputs would add 90 neurons that can never be a target.)
net = MNIST_CNN_SNN(num_outputs=modelConstants.num_classes).to(device)

# Replace the normal-distribution init of the classifier with a uniform draw in
# [min_weight, max_weight).
with torch.no_grad():
    for layer in [net.fc1, net.fc2, net.fc3]:
        layer.linear.weight.copy_(
            torch.rand_like(layer.linear.weight) * (modelConstants.max_weight - modelConstants.min_weight) + modelConstants.min_weight
        )

# Scale every LIF threshold in place by the configured Threshold_voltage.
for lif in [net.lif1, net.lif2, net.lif3]:
    with torch.no_grad():
        lif.threshold.copy_((lif.threshold * modelConstants.Threshold_voltage).to(device))

optimizer = torch.optim.Adam(net.parameters(), lr=1e-3, betas=(0.9, 0.999))
scheduler = torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(
    optimizer,
    T_0=2,          # first cycle length (epochs; scheduler.step() is called once per epoch)
    T_mult=1,       # cycle length stays constant
    eta_min=1e-5    # small floor, not zero
)
ce = nn.CrossEntropyLoss()

In [None]:
torch.manual_seed(seed)
counter = 0


def _encode_latency(images):
    """Latency-encode ``images`` with snntorch and accumulate the spikes over time (dim 0).

    ``cumsum(0)`` turns the spike train into a running count. Latency coding emits at most
    one spike per input, so a pixel stays at 1 from its spike time onward. The training and
    test loops share this function so both use identical encoding.
    """
    spikes = spikegen.latency(
        images, num_steps=modelConstants.num_steps, threshold=0.01, clip=True,
        first_spike_time=0, linear=True, normalize=True,
    )
    return spikes.cumsum(0)


def train_and_evaluate(loss_fn = ce, num_epochs = 5, optimizer = optimizer, scheduler=scheduler, train_loader=train_loader, test_loader=test_loader):
    """Train the module-level ``net`` in place and report test accuracy after every epoch.

    For each training batch, the output spikes are summed over time to form per-class
    logits for ``loss_fn``. ``optimizer`` is then stepped, and the fc1-fc3 weights are
    clamped to [0, max_weight]. After each epoch, ``scheduler`` is stepped once, and test
    accuracy and the current learning rate are printed.

    Note: the default arguments are bound when this cell runs. Callers that rebuild ``net``
    must pass an ``optimizer``/``scheduler`` built over the new net's parameters.

    Returns:
        Test-set accuracy in percent after the final epoch.

    Raises:
        ValueError: if ``num_epochs`` < 1 (there would be no accuracy to return).
    """
    if num_epochs < 1:
        raise ValueError(f"num_epochs must be >= 1, got {num_epochs}")

    fc_layers = (net.fc1, net.fc2, net.fc3)
    epoch_times = []
    accuracy = 0.0

    for _ in range(num_epochs):
        net.train()
        start_time = time.time()
        loss_sum = torch.zeros((), dtype=dtype, device=device)
        num_batches = 0

        for data, targets in train_loader:
            data = _encode_latency(data.to(device))
            targets = targets.to(device)

            spk_rec, _ = net(data)
            # Spike count per output neuron over all time steps serves as the class logit.
            logits = spk_rec.sum(dim=0)
            loss_val = loss_fn(logits, targets)

            optimizer.zero_grad()
            loss_val.backward()
            optimizer.step()

            # Keep trained weights inside the device range.
            # NOTE(review): the lower bound here is 0, whereas the model initialisers clamp
            # to min_weight. Confirm which bound the hardware requires.
            with torch.no_grad():
                for layer in fc_layers:
                    layer.linear.weight.clamp_(0, modelConstants.max_weight)

            # Accumulate on-device to avoid a host sync on every batch.
            loss_sum += loss_val.detach()
            num_batches += 1

        epoch_times.append(time.time() - start_time)
        # Mean training loss over the epoch (previously only the last batch's loss was printed).
        print("epoch loss: ", (loss_sum / max(num_batches, 1)).item())
        scheduler.step()

        net.eval()
        total = 0
        correct = 0
        with torch.no_grad():
            for data, targets in test_loader:
                data = _encode_latency(data.to(device))
                targets = targets.to(device)
                spk_rec, _ = net(data)

                predicted = spk_rec.sum(dim=0).argmax(dim=-1)
                correct += (predicted == targets).sum().item()
                total += targets.size(0)

        accuracy = 100 * (correct / total)
        print(f"Total correctly classified test set images: {correct}/{total}")
        print(f"Test Set Accuracy: {accuracy:.2f}%\n")
        for param_group in optimizer.param_groups:
            print(f"Current learning rate: {param_group['lr']}")

    print("average epoch time: ", np.mean(epoch_times))
    return accuracy

### Epochs

In [None]:
best_accuracy = 0
best_params = None
best_weights = None
worst_accuracy = 100
worst_params = None

# Threshold sweep for the first spiking layer.
# NOTE(review): only lif1's threshold changes; lif2/lif3 keep the values set in the setup
# cell. The same `net`/optimizer keep training across values, so with more than one entry
# each run starts from the previous run's weights.
thr_values = [0.3]
hyperparams = [{"thr": thr} for thr in thr_values]

for params in hyperparams:
    V_th = params["thr"]
    # Write the threshold in place so it keeps its device, dtype and shape. Replacing
    # `.data` with `torch.tensor(V_th)` would put it on the CPU.
    with torch.no_grad():
        net.lif1.threshold.fill_(V_th)
    accuracy = train_and_evaluate(loss_fn=ce, num_epochs = 30, optimizer = optimizer, scheduler=scheduler)
    print(f"Accuracy with params {params}: {accuracy}")

    if accuracy > best_accuracy:
        best_accuracy = accuracy
        best_params = params
        # Detached snapshot of the best run's fc1 weights (not tied to the autograd graph)
        best_weights = net.fc1.linear.weight.detach().clone()
    if accuracy < worst_accuracy:
        worst_accuracy = accuracy
        worst_params = params

    print(f"Best parameters: {best_params} with accuracy: {best_accuracy}")

### Optuna hyperparameter search

In [None]:
import optuna


def optuna_objective(trial):
    """Optuna objective: build and train a fresh ``NNfullyConnected`` and return its test accuracy.

    Samples hidden sizes, learning rate, threshold scale, number of encoding time steps and
    number of epochs from ``trial``. Rebinds the module-level ``net``, ``optimizer`` and
    ``scheduler``, because ``train_and_evaluate`` trains the global ``net``.

    NOTE(review): accuracy is measured on ``test_loader`` (train_and_evaluate's default), so
    hyperparameters are selected on the test set. A held-out validation split would keep
    the final test number unbiased.
    """
    global net, optimizer, scheduler, num_steps

    # ---------- Hyperparameter search space ----------
    # hidden sizes (must fit your hardware limits)
    h1 = trial.suggest_int("h1", 512, 2048, step=256)
    h2 = trial.suggest_int("h2", 256, 2048, step=256)

    # learning rate (log-scale)
    lr = trial.suggest_float("lr", 1e-4, 3e-3, log=True)

    # threshold scaling (relative to your current Threshold_voltage)
    thr_scale = trial.suggest_float("thr_scale", 0.2, 1.0)

    # number of timesteps for latency encoding
    num_steps = trial.suggest_int("num_steps", 10, 40, step=10)

    # number of epochs per trial (keep small for speed)
    num_epochs = trial.suggest_int("num_epochs", 3, 7)

    net = NNfullyConnected(
        num_inputs=28*28,
        h1=h1,
        h2=h2,
        num_outputs=modelConstants.num_classes,
    ).to(device)

    # Scale thresholds in place, like the setup cell. Rebinding `lif.threshold` to a new
    # tensor fails if digLIF registers the threshold as an nn.Parameter.
    with torch.no_grad():
        for lif in [net.lif1, net.lif2, net.lif3]:
            lif.threshold.copy_(lif.threshold * thr_scale * modelConstants.Threshold_voltage)

    optimizer = torch.optim.Adam(net.parameters(), lr=lr, betas=(0.9, 0.999))
    scheduler = torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(
        optimizer,
        T_0=2,          # first cycle length
        T_mult=1,       # cycle length stays constant
        eta_min=1e-5    # small floor, not zero
    )

    # train_and_evaluate encodes inputs with modelConstants.num_steps, so the sampled value
    # must be written there to take effect. The previous value is restored afterwards so the
    # trial leaves the global config unchanged.
    previous_num_steps = modelConstants.num_steps
    modelConstants.num_steps = num_steps
    try:
        acc = train_and_evaluate(
            loss_fn=ce,
            num_epochs=num_epochs,
            optimizer=optimizer,
            scheduler=scheduler,
        )
    finally:
        modelConstants.num_steps = previous_num_steps

    # Optuna will maximize this
    return acc


In [None]:
# Maximise test accuracy; 20 trials is a starting budget, so raise n_trials if results keep improving.
study = optuna.create_study(direction="maximize")
study.optimize(optuna_objective, n_trials=20)

top_value = study.best_value
top_params = study.best_params
print("Best value (accuracy):", top_value)
print("Best params:", top_params)

In [None]:
# One line per trial: trial number, objective value (test accuracy) and sampled parameters.
for finished_trial in study.trials:
    print(finished_trial.number, finished_trial.value, finished_trial.params)

In [None]:
best = study.best_params

h1 = best["h1"]
h2 = best["h2"]
lr = best["lr"]
thr_scale = best["thr_scale"]
num_steps = best["num_steps"]

net = NNfullyConnected(
    num_inputs=28*28,
    h1=h1,
    h2=h2,
    num_outputs=modelConstants.num_classes,
).to(device)

# Scale thresholds in place, like the setup cell. Rebinding `lif.threshold` to a new tensor
# fails if digLIF registers the threshold as an nn.Parameter.
with torch.no_grad():
    for lif in [net.lif1, net.lif2, net.lif3]:
        lif.threshold.copy_(lif.threshold * thr_scale * modelConstants.Threshold_voltage)

# train_and_evaluate encodes inputs with modelConstants.num_steps, so apply the tuned value.
# It stays set so that later evaluation of this net uses the same encoding.
modelConstants.num_steps = num_steps

# A fresh optimizer and scheduler over *this* net's parameters, using the tuned learning rate.
# The leftover optimizer from the last Optuna trial is bound to that trial's network, so
# passing it here would leave the new net untrained.
optimizer = torch.optim.Adam(net.parameters(), lr=lr, betas=(0.9, 0.999))
scheduler = torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(
    optimizer,
    T_0=2,          # first cycle length
    T_mult=1,       # cycle length stays constant
    eta_min=1e-5    # small floor, not zero
)

final_acc = train_and_evaluate(
    loss_fn=ce,
    num_epochs=15,
    optimizer=optimizer,
    scheduler=scheduler,
)

print("Final accuracy with best params:", final_acc)
