In [2]:
from brian2 import *
import torchvision
import torchvision.transforms as transforms

# ----------------------------
# Dataset (MNIST -> small subset for speed)
# ----------------------------
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Lambda(lambda x: x.view(-1)) # flatten
])
mnist = torchvision.datasets.MNIST(root="./data", train=True, download=True, transform=transform)
data, targets = mnist.data[:1000].float()/255.0, mnist.targets[:1000]  # subset

# ----------------------------
# Neuron model (LIF)
# ----------------------------
N_in, N_hidden, N_out = 784, 200, 10

tau = 10*ms
eqs = '''
dv/dt = (-v + I)/tau : 1 (unless refractory)
I : 1
'''

input_group = SpikeGeneratorGroup(N_in, [], []*ms)
hidden = NeuronGroup(N_hidden, eqs, threshold='v>1', reset='v=0', refractory=5*ms, method='euler')
output = NeuronGroup(N_out, eqs, threshold='v>1', reset='v=0', refractory=5*ms, method='euler')

# Synapses with eligibility traces
S_in = Synapses(input_group, hidden, model='''
                w : 1
                e_trace : 1
                dI/dt = -I/tau : 1 (clock-driven)
                ''',
                on_pre='''
                I_post += w
                e_trace += 1.0
                ''')
S_in.connect(p=0.1)
S_in.w = '0.2*rand()'

S_hid = Synapses(hidden, output, model='''
                w : 1
                e_trace : 1
                dI/dt = -I/tau : 1 (clock-driven)
                ''',
                on_pre='''
                I_post += w
                e_trace += 1.0
                ''')
S_hid.connect(p=0.1)
S_hid.w = '0.2*rand()'

# ----------------------------
# e-prop weight update rule
# ----------------------------
@network_operation(dt=10*ms)
def eprop_update():
    lr = 1e-3
    # simple local error: output spikes - target
    for i in range(N_out):
        error = (output.v[i] - 0.5)  # placeholder error term
        S_hid.w[i] -= lr * error * S_hid.e_trace[i]

# ----------------------------
# Run simulation for one sample
# ----------------------------
net = Network(collect())
net.add(eprop_update)
net.run(100*ms)


100.0%
100.0%
100.0%
100.0%


ValueError: The post-synaptic variable 'I' has the same name as a synaptic variable, rename the synaptic variable.(for example to 'I_syn') to avoid confusion