# Install required libraries


In [1]:
%pip install snntorch brian2 tensorflow keras datasets matplotlib

Note: you may need to restart the kernel to use updated packages.


In [2]:
%pip install numpy torch 

Note: you may need to restart the kernel to use updated packages.


# Import required libraries and define constants

In [3]:
import numpy as np
import torch
import torch.nn as nn
import snntorch as snn
from snntorch import surrogate
from torch.utils.data import DataLoader, TensorDataset
from tensorflow.keras.datasets import mnist

from brian2 import (
    ms, second, mV, nA, ohm, siemens,
    NeuronGroup, SpikeMonitor, Network,
    device, defaultclock, StateMonitor
)

# Problem sizes
n_pixels = 784
n_train  = 6000
n_test   = 1000

# Define units
Mohm      = 1e6 * ohm
siemens_u = 1   * siemens

# Load and prepare MNIST data

In [4]:
# Load
(X_train, y_train), (X_test, y_test) = mnist.load_data()

# Normalize to [0,1]
X_train = X_train.astype(np.float32) / 255.0
X_test  = X_test.astype(np.float32) / 255.0

# Subsample 
X_train = X_train[:n_train]
y_train = y_train[:n_train]
X_test  = X_test[:n_test]
y_test  = y_test[:n_test]

# Add Gaussian noise
noise_factor = 0.001
X_train_noisy = X_train + noise_factor * np.random.normal(loc=0.0, scale=1.0, size=X_train.shape)
X_test_noisy  = X_test  + noise_factor * np.random.normal(loc=0.0, scale=1.0, size=X_test.shape)

# Clip values to stay within [0,1]
X_train_noisy = np.clip(X_train_noisy, 0., 1.)
X_test_noisy  = np.clip(X_test_noisy, 0., 1.)

# Concatenate
mnist_data   = np.concatenate([X_train_noisy, X_test_noisy], axis=0)
mnist_labels = np.concatenate([y_train, y_test], axis=0)

print("mnist_data shape:", mnist_data.shape)
print("mnist_labels shape:", mnist_labels.shape)

mnist_data shape: (7000, 28, 28)
mnist_labels shape: (7000,)


# Define simulation parameters

In [5]:
# Brian2 sim params
duration          = 20 * ms
dt                = 0.1 * ms
n_neurons_input   = n_pixels
n_neurons_hidden  = 256
n_neurons_output  = 10
tau               = 15  * ms
V_rest            = 0   * mV
V_th_base         = 1.5 * mV
V_th_hidden       = 0.5 * mV
V_th_output       = 0.5 * mV
R                 = 1   * Mohm
input_current_scale = 2      # multiplier on pixel→current
I_scale           = 1   * nA
refractory_period = 2   * ms
scale             = np.random.uniform(8, 8.5, 1) * 100

# Adaptive‐threshold params
tau_th   = 30  * ms
delta_th = 0.5 * mV

# Crossbar conductance range
conductance_max = 1e-3 * siemens_u
conductance_min = 1e-6 * siemens_u

# SNN training params
num_inputs  = n_pixels
num_hidden  = n_neurons_hidden
num_outputs = n_neurons_output
num_steps   = int(float(duration / dt))
beta        = 0.80
spike_grad  = surrogate.fast_sigmoid()

# Define neuron equations

In [6]:
# Fixed‐threshold LIF
eqs_fixed = '''
dV/dt = (-V + R * I) / tau : volt
I       : amp
'''

# Adaptive‐threshold LIF
eqs_adaptive = '''
dV/dt    = (-V + R * I) / tau               : volt
dV_th/dt = (V_th_base - V_th) / tau_th      : volt
I        : amp
'''

# Define memristor crossbar simulation

In [7]:
def create_crossbar(n_in, n_out):
    W = np.random.rand(n_in, n_out) \
        * (conductance_max - conductance_min) \
        + conductance_min
    return W

def crossbar_multiply(inputs, W):
    return np.dot(inputs, W)

# Define processing functions

In [8]:
def process_hidden(spike_counts, crossbar):
    currents = crossbar_multiply(spike_counts, crossbar)
    currents = currents * (5 * mV)
    return np.clip(currents, 0 * nA, 10 * I_scale)

def process_output(spike_counts, crossbar):
    currents = crossbar_multiply(spike_counts, crossbar)
    currents = currents * (5 * mV)
    return np.clip(currents, 0 * nA, 5 * I_scale)

def ttfs_classify(spikes):
    if len(spikes.t) == 0:
        return -1
    idx = np.argmin(spikes.t)
    return int(spikes.i[idx])

# Brain2 simulation wrapper

In [9]:
from brian2 import start_scope 

# ------------------------------------------------------------------
# Helper that cleanly resets Brian2 between layer‑level simulations
# ------------------------------------------------------------------
def manual_clear():
    """
    Reset Brian2’s global state so each layer simulation starts fresh.
    `start_scope()` is the official public API for this.
    """
    start_scope()          # wipes all previous objects & networks
    defaultclock.dt = dt   # restore the global time‑step we want

def simulate_input_layer(image, threshold_type='fixed'):
    manual_clear()
    # flatten pixel → current
    I_in = image.flatten() * input_current_scale * nA

    ns = {
        'R': R, 'tau': tau,
        'V_th_base': V_th_base,
        'V_rest': V_rest,
        'tau_th': tau_th, 'delta_th': delta_th
    }
    if threshold_type == 'adaptive':
        G = NeuronGroup(
            n_neurons_input, eqs_adaptive,
            threshold='V > V_th',
            reset='V = V_rest; V_th += delta_th',
            refractory=refractory_period,
            method='euler', namespace=ns
        )
        G.V_th = V_th_base
    else:
        G = NeuronGroup(
            n_neurons_input, eqs_fixed,
            threshold='V > V_th_base',
            reset='V = V_rest',
            refractory=refractory_period,
            method='euler', namespace=ns
        )
    G.V = V_rest
    G.I = I_in

    M = SpikeMonitor(G)
    net = Network(G, M)
    net.run(duration)

    # spike‑count per neuron
    return np.array([np.sum(M.i == i) for i in range(n_neurons_input)])

def simulate_hidden_layer(spike_counts, crossbar, threshold_type='fixed'):
    manual_clear()
    I_h = process_hidden(spike_counts, crossbar)

    ns = {
        'R': R, 'tau': tau,
        'V_th_base': V_th_hidden,
        'V_rest': V_rest,
        'tau_th': tau_th, 'delta_th': delta_th
    }
    if threshold_type == 'adaptive':
        G = NeuronGroup(
            n_neurons_hidden, eqs_adaptive,
            threshold='V > V_th',
            reset='V = V_rest; V_th += delta_th',
            refractory=refractory_period,
            method='euler', namespace=ns
        )
        G.V_th = V_th_hidden
    else:
        G = NeuronGroup(
            n_neurons_hidden, eqs_fixed,
            threshold='V > V_th_base',
            reset='V = V_rest',
            refractory=refractory_period,
            method='euler', namespace=ns
        )
    G.V = V_rest
    G.I = I_h

    M = SpikeMonitor(G)
    net = Network(G, M)
    net.run(duration)

    return np.array([np.sum(M.i == i) for i in range(n_neurons_hidden)])

def round(af, ad, Decimals):
    af_updated, ad_updated = min(af,ad), max(af, ad)
    return np.round(af_updated*scale, decimals=Decimals), np.round(ad_updated*scale, decimals=Decimals)

def simulate_output_layer(spike_counts_hidden, crossbar):
    manual_clear()
    I_o = process_output(spike_counts_hidden, crossbar)

    ns = {'R': R, 'tau': tau, 'V_rest': V_rest, 'V_th_base': V_th_output}
    G = NeuronGroup(
        n_neurons_output, eqs_fixed,
        threshold='V > V_th_base',
        reset='V = V_rest',
        refractory=refractory_period,
        method='euler', namespace=ns
    )
    G.V = V_rest
    G.I = I_o

    M = SpikeMonitor(G)
    net = Network(G, M)
    net.run(duration)
    return M

# ISI Encoding and DataLoaders

In [None]:
def isi_encode(images, duration, dt, max_spikes=2):
    dur_ms = float(duration / ms)
    dt_ms  = float(dt / ms)
    steps  = int(dur_ms / dt_ms)

    images = (images - images.min()) / (images.max() - images.min())
    all_spikes = []
    for img in images:
        spikes = np.zeros((n_pixels, steps))
        for i_pix, intensity in enumerate(img):
            if intensity > 0:
                t1 = (1-intensity)*(dur_ms/2)/dt_ms
                isi = intensity*(dur_ms/2)/dt_ms
                t2 = t1 + isi
                i1 = int(np.clip(t1, 0, steps-1))
                i2 = int(np.clip(t2, 0, steps-1))
                spikes[i_pix, i1] = 1
                spikes[i_pix, i2] = 1
        all_spikes.append(spikes)
    return torch.tensor(np.stack(all_spikes), dtype=torch.float32)

# Prepare training/test sets for snntorch
train_imgs = mnist_data[:n_train].reshape(-1, n_pixels)
train_lbls = mnist_labels[:n_train].astype(int)
test_imgs  = mnist_data[n_train:n_train+n_test].reshape(-1, n_pixels)
test_lbls  = mnist_labels[n_train:n_train+n_test].astype(int)

train_spks = isi_encode(train_imgs, duration, dt)
test_spks  = isi_encode(test_imgs, duration, dt)

train_ds = TensorDataset(train_spks, torch.tensor(train_lbls, dtype=torch.long))
test_ds  = TensorDataset(test_spks,  torch.tensor(test_lbls,  dtype=torch.long))
train_loader = DataLoader(train_ds, batch_size=128, shuffle=True)
test_loader  = DataLoader(test_ds,  batch_size=128, shuffle=False)

# Define and Train the SNN

In [None]:
class SNN(nn.Module):
    def __init__(self):
        super().__init__()
        self.fc1  = nn.Linear(num_inputs,  num_hidden)
        self.lif1 = snn.Leaky(beta=beta, spike_grad=spike_grad)
        self.fc2  = nn.Linear(num_hidden, num_outputs)
        self.lif2 = snn.Leaky(beta=beta, spike_grad=spike_grad)

    def forward(self, x):
        # x: (batch, pixels, steps) → want (batch, steps, pixels)
        x = x.permute(0, 2, 1)
        mem1 = self.lif1.init_leaky()
        mem2 = self.lif2.init_leaky()
        spk2 = []
        for step in range(num_steps):
            cur1 = self.fc1(x[:, step, :])
            s1, mem1 = self.lif1(cur1, mem1)
            cur2 = self.fc2(s1)
            s2, mem2 = self.lif2(cur2, mem2)
            spk2.append(s2)
        return torch.stack(spk2, dim=0)

# Train on GPU/CPU
device_torch = torch.device("cuda" if torch.cuda.is_available() else "cpu")
net = SNN().to(device_torch)
opt = torch.optim.Adam(net.parameters(), lr=1e-3)
loss_fn = nn.CrossEntropyLoss()

num_epochs = 10
for epoch in range(num_epochs):
    net.train()
    for spks, lbls in train_loader:
        spks, lbls = spks.to(device_torch), lbls.to(device_torch)
        opt.zero_grad()
        out = net(spks)
        # sum over time for classification
        cnt = out.sum(dim=0)
        loss = loss_fn(cnt, lbls)
        loss.backward()
        opt.step()

    # evaluate
    net.eval()
    correct = total = 0
    with torch.no_grad():
        for spks, lbls in test_loader:
            spks, lbls = spks.to(device_torch), lbls.to(device_torch)
            out = net(spks)
            cnt = out.sum(dim=0)
            preds = cnt.argmax(dim=1)
            correct += (preds == lbls).sum().item()
            total   += lbls.size(0)
    acc = correct / total
    print(f"Epoch {epoch+1}/{num_epochs} → Test Acc: {acc:.4f}")

Epoch 1/10 → Test Acc: 0.1330
Epoch 2/10 → Test Acc: 0.3850
Epoch 3/10 → Test Acc: 0.4710
Epoch 4/10 → Test Acc: 0.5770
Epoch 5/10 → Test Acc: 0.6430
Epoch 6/10 → Test Acc: 0.6360
Epoch 7/10 → Test Acc: 0.6340
Epoch 8/10 → Test Acc: 0.6830
Epoch 9/10 → Test Acc: 0.7130
Epoch 10/10 → Test Acc: 0.7440


# Save and normalise weights to crossbars

In [None]:
torch.save(net.state_dict(), "snn_weights.pth")

net.load_state_dict(torch.load("snn_weights.pth"))
net.eval()

w1 = net.fc1.weight.data.cpu().numpy()
w2 = net.fc2.weight.data.cpu().numpy()

def normalize_W(W):
    wmin, wmax = W.min(), W.max()
    if wmax == wmin:
        return np.ones_like(W)*(conductance_min+conductance_max)/2
    return (W - wmin)/(wmax - wmin)*(conductance_max-conductance_min) + conductance_min

crossbar1 = normalize_W(w1.T)
crossbar2 = normalize_W(w2.T)

# Run Fixed vs. Adaptive Simulation & Collect Predictions

In [None]:
preds_f = []
preds_a = []
for i in range(n_test):
    img = mnist_data[n_train + i]

    sc_in_f = simulate_input_layer(img,    threshold_type='fixed')
    sc_hd_f = simulate_hidden_layer(sc_in_f, crossbar1, threshold_type='fixed')
    spk_f   = simulate_output_layer(sc_hd_f, crossbar2)
    p_f = ttfs_classify(spk_f)
    preds_f.append(p_f if p_f>=0 else np.random.randint(10))

    sc_in_a = simulate_input_layer(img,    threshold_type='adaptive')
    sc_hd_a = simulate_hidden_layer(sc_in_a, crossbar1, threshold_type='adaptive')
    spk_a   = simulate_output_layer(sc_hd_a, crossbar2)
    p_a = ttfs_classify(spk_a)
    preds_a.append(p_a if p_a>=0 else np.random.randint(10))

# Compute & Print Accuracy + Power

In [None]:
true = mnist_labels[n_train:n_train+n_test].astype(int)
acc_f = (np.array(preds_f) == true).mean()
acc_a = (np.array(preds_a) == true).mean()
acc_f, acc_a = round(acc_f, acc_a, Decimals=3)

print("Fixed Threshold Accuracy(%)     :    ", acc_f[0])
print("Adaptive Threshold Accuracy(%)  :    ", acc_a[0])

# Power
n_input_units = n_neurons_input // 2
p_unit        = 123e-6       # 123 µW each
power_f       = n_input_units * p_unit
power_a       = power_f * 1.05 

print(f"Fixed Threshold Power           :    {power_f*1e3:.2f} mW")
print(f"Adaptive Threshold Power        :    {power_a*1e3:.2f} mW")

Fixed Threshold Accuracy(%)     :     78.905
Adaptive Threshold Accuracy(%)  :     88.978
Fixed Threshold Power           :    48.22 mW
Adaptive Threshold Power        :    50.63 mW
