In [18]:
from brian2 import *
import numpy as np
import warnings

#warnings.filterwarnings("ignore", category=UserWarning, module='brian2.codegen.generators.base')

start_scope()

defaultclock.dt = 0.0001*ms  

# Custom timing function
@implementation('numpy', discard_units=True)
@check_units(w=1, global_clock=1, layer=1, result=1, sum=1, spikes_received=1)
def spike_timing(w, global_clock, layer, sum, spikes_received): 
    x = global_clock % 1
    if w >= 0:
        return (x ** (1 - w)) 
    else:
        return (1 - (1 - x) ** (1 + w)) 
    
@implementation('numpy', discard_units=True)
@check_units(layer=1, result=1, sum=1, spikes_received=1)
def math1(layer, sum, spikes_received): 
    return (sum/spikes_received )+ layer

# def run_Urd(inputs, weights_1, weights_2, weights_3):
#     '''4-10-3 SNN'''
#     # will add check of weights # so it all works
#     n_input = 4 
#     n_hidden = 10
#     n_output = 3
#     n_total = n_input + n_hidden + n_output

#     neurons = NeuronGroup(n_total, '''
#         v : 1
#         sum : 1
#         spikes_received : 1
#         scheduled_time : second
#         global_clock : 1
#     ''', threshold='v > 1', reset='v = 0', method='exact')
#     neurons.v = 0
#     neurons.scheduled_time = 1e9 * second
#     neurons.global_clock = 0.0
#     neurons.sum = 0.0
#     neurons.spikes_received = 0.0


#     indicess = [i for i in range(n_input)]
#     stim = SpikeGeneratorGroup(n_input, indices=indicess, times=(inputs*ms))

#     syn_input = Synapses(stim, neurons[0:n_input], '''
#         w : 1
#         layer : 1
#     ''', on_pre='''
#         spikes_received += 1
#         sum += spike_timing(w, global_clock, layer, spikes_received, sum)
#         scheduled_time = ((sum/spikes_received) + layer) * ms 
#     ''')
#     syn_input.connect(j='i')
#     syn_input.w = weights_1
#     syn_input.layer = 0

#     syn_hidden = Synapses(neurons[0:n_input], neurons[n_input:n_input+n_hidden], '''
#         w : 1
#         layer : 1
#     ''', on_pre='''
#         spikes_received += 1
#         sum += spike_timing(w, global_clock, layer, spikes_received, sum)
#         scheduled_time = ((sum/spikes_received) + layer) * ms 
#     ''')
#     for inp in range(n_input):
#         for hid in range(n_hidden):
#             syn_hidden.connect(i=inp, j=hid)

#     syn_hidden.w = weights_2
#     syn_hidden.layer = 1


#     syn_output = Synapses(
#         neurons[n_input:n_input+n_hidden], 
#         neurons[n_input+n_hidden:n_total], 
#         '''
#         w : 1
#         layer : 1
#         ''',
#         on_pre='''
#         spikes_received += 1
#         sum += spike_timing(w, global_clock, layer, spikes_received, sum)
#         scheduled_time = ((sum/spikes_received) + layer) * ms 
#         '''
#     )

#     for hid in range(n_hidden):
#         for out in range(n_output):
#             syn_output.connect(i=hid, j=out)

#     # Set weights in correct order
#     syn_output.w[:] = weights_3
#     syn_output.layer = 2

#     #print(syn_output.i[:], syn_output.j[:])
#     #weights_into_output_1 = weights_3[1::3]



#     neurons.run_regularly('''
#         v = int(abs(t - scheduled_time) < 0.0005*ms) * 1.2
#         global_clock += 0.001
#     ''', dt=0.001*ms)


#     spikemon = SpikeMonitor(neurons)
    
#     # neurons.v = 0
#     # neurons.scheduled_time = 1e9 * second
#     # neurons.global_clock = 0.0
#     # neurons.sum = 0.0
#     # neurons.spikes_received = 0.0

#     run(5*ms)

#     result = []

#     for i in range(n_total):
#         times = spikemon.spike_trains()[i]
#         if len(times) > 0:
#             result.append(round(times[0]/ms, 3))
#         else:
#             result.append(None)  # or some other placeholder like float('nan')
            
#     return result




In [12]:
import numpy as np
import logging

logging.getLogger('brian2').setLevel(logging.ERROR)


import numpy as np
from copy import deepcopy

def compute_loss(predicted, desired):
    """
    MSE loss on the 3 output spike-times.
    predicted: full list of 17 times (None → we treat as large penalty)
    desired: array-like of length 3
    """
    # pull out last 3 spikes
    out_t = np.array([t if t is not None else 10.0 for t in predicted[-3:]])
    d = np.array(desired)
    return np.mean((out_t - d)**2)

def finite_difference_grads(inputs, w1, w2, w3, desired, eps=3):
    """
    Returns numerical gradients of loss wrt each weight array.
    """
    # baseline loss
    base_out = run_Urd(inputs, w1, w2, w3)
    L0 = compute_loss(base_out, desired)

    # grad for layer1
    gw1 = np.zeros_like(w1)
    for idx in np.ndindex(w1.shape):
        w1p = w1.copy()
        w1p[idx] += eps
        Lp = compute_loss(run_Urd(inputs, w1p, w2, w3), desired)
        gw1[idx] = (Lp - L0)/eps

    # grad for layer2
    gw2 = np.zeros_like(w2)
    for idx in np.ndindex(w2.shape):
        w2p = w2.copy()
        w2p[idx] += eps
        Lp = compute_loss(run_Urd(inputs, w1, w2p, w3), desired)
        gw2[idx] = (Lp - L0)/eps

    # grad for layer3
    gw3 = np.zeros_like(w3)
    for idx in np.ndindex(w3.shape):
        w3p = w3.copy()
        w3p[idx] += eps
        Lp = compute_loss(run_Urd(inputs, w1, w2, w3p), desired)
        gw3[idx] = (Lp - L0)/eps

    return gw1, gw2, gw3, L0

def backprop_snn_fd(inputs, w1, w2, w3,
                    desired=[2.1, 2.6, 2.9],
                    lr=0.1, eps=3):
    """
    One gradient‐step on the SNN via finite‐difference.
    Returns updated (w1, w2, w3) and the loss before update.
    """
    gw1, gw2, gw3, loss = finite_difference_grads(inputs, w1, w2, w3, desired, eps)
    # gradient descent
    w1 -= lr * gw1
    w2 -= lr * gw2
    w3 -= lr * gw3
    return w1, w2, w3, loss

# — example usage —
# random init
inputs    = np.random.uniform(0, 1, 4)
weights_1 = np.random.uniform(0, 1, 4)
weights_2 = np.random.uniform(0, 1, 40)
weights_3 = np.random.uniform(0, 1, 30)

print(weights_1, weights_2, weights_3)

desired = [2.1, 2.6, 2.9]
for epoch in range(1):
    w1, w2, w3, loss = backprop_snn_fd(
        inputs, weights_1, weights_2, weights_3,
        desired=desired, lr=0.5, eps=3
    )
    weights_1, weights_2, weights_3 = w1, w2, w3
    out = run_Urd(inputs, w1, w2, w3)[-3:]
    print(f"Epoch {epoch:2d}  Loss={loss:.4f}  Outputs={out}")

print(weights_1, weights_2, weights_3)

# After a few epochs you should see the 3 output times marching
# closer to [2.1, 2.6, 2.9].

[0.14345227 0.86066371 0.88725555 0.85814195] [0.44777422 0.76353581 0.48259652 0.3623195  0.99852955 0.99073865
 0.90408036 0.39032022 0.29293302 0.22087558 0.15663    0.32492888
 0.56082809 0.25731051 0.04950147 0.28284291 0.23587225 0.44035401
 0.39615996 0.35721367 0.89914801 0.44307031 0.63992119 0.52551554
 0.98827489 0.05136685 0.22214933 0.74473358 0.12725664 0.83123581
 0.98797027 0.54554983 0.49111424 0.37928825 0.2765512  0.63761703
 0.15451435 0.66322082 0.94605562 0.18985174] [0.22203751 0.43385482 0.91329917 0.06779644 0.40236334 0.81143013
 0.4022748  0.1185883  0.60056945 0.44718627 0.4049495  0.8794631
 0.85964283 0.19174956 0.60335014 0.80864427 0.46067967 0.14843805
 0.04586898 0.5629214  0.93889325 0.55204879 0.64219171 0.40427682
 0.87504725 0.92320919 0.09731941 0.36464712 0.61217678 0.59806614]
Epoch  0  Loss=0.2828  Outputs=[2.948, 2.955, 2.962]
[0.13953711 0.86182532 0.88831144 0.85949145] [0.44743883 0.76336731 0.48216986 0.36189828 0.99909788 0.9911931
 0.903

In [25]:
from brian2 import *
import numpy as np

#------------------------------------------------------------------------------
# 2) Network creation (only once)
#------------------------------------------------------------------------------
def create_network(weights1, weights2, weights3, input_times):
    start_scope()
    defaultclock.dt = 0.1*ms

    n_input, n_hidden, n_output = 4, 10, 3
    n_total = n_input + n_hidden + n_output

    # 2.1 NeuronGroup
    neurons = NeuronGroup(
        n_total,
        '''
        v : 1
        sum : 1
        spikes_received : 1
        scheduled_time : second
        global_clock : 1
        ''',
        threshold='v>1',
        reset='v=0',
        method='exact'
    )
    neurons.v = 0
    neurons.sum = 0
    neurons.spikes_received = 0
    neurons.scheduled_time = 1e9*second
    neurons.global_clock = 0

    # 2.2 Input spike generator
    stim = SpikeGeneratorGroup(
        n_input,
        np.arange(n_input),
        input_times * ms
    )

    # shared model string (no semicolons!)
    syn_model = '''
    w : 1
    layer : 1
    '''

    on_pre_eq = '''
    spikes_received += 1
    sum += spike_timing(w, global_clock, layer, sum, spikes_received)
    scheduled_time = math1(layer, sum, spikes_received)*ms
    '''

    # 2.3 Input → Input
    syn_input = Synapses(
        stim, neurons[:n_input],
        model=syn_model,
        on_pre=on_pre_eq
    )
    syn_input.connect(condition='i==j')
    syn_input.w = weights1
    syn_input.layer = 0

    # 2.4 Input → Hidden
    syn_hidden = Synapses(
        neurons[:n_input], neurons[n_input:n_input+n_hidden],
        model=syn_model,
        on_pre=on_pre_eq
    )
    syn_hidden.connect(condition='True')  # full connectivity
    syn_hidden.w = weights2
    syn_hidden.layer = 1

    # 2.5 Hidden → Output
    syn_output = Synapses(
        neurons[n_input:n_input+n_hidden], neurons[n_input+n_hidden:],
        model=syn_model,
        on_pre=on_pre_eq
    )
    syn_output.connect(condition='True')
    syn_output.w = weights3
    syn_output.layer = 2

    # 2.6 Regularly check scheduled_time → fire v
    neurons.run_regularly('''
        v = int(abs(t - scheduled_time) < 0.0005*ms)*1.2
        global_clock += 0.001
    ''', dt=0.001*ms)

    # 2.7 Monitor all spikes
    spikemon = SpikeMonitor(neurons)

    return {
        'neurons': neurons,
        'stim': stim,
        'syn_input': syn_input,
        'syn_hidden': syn_hidden,
        'syn_output': syn_output,
        'spikemon': spikemon
    }

#------------------------------------------------------------------------------
# 3) Reset network state before each forward pass
#------------------------------------------------------------------------------
def reset_state(net, input_times):
    ng = net['neurons']
    ng.v = 0
    ng.sum = 0
    ng.spikes_received = 0
    ng.scheduled_time = 1e9*second
    ng.global_clock = 0

    net['stim'].set_spikes(np.arange(4), input_times*ms)
    # Remove and remake the monitor each epoch (inside reset_state or your loop)
    # net['spikemon'].detach()
    # net['spikemon'] = SpikeMonitor(net['neurons'])
    # net['spikemon'].active = True

#------------------------------------------------------------------------------
# 4) Offline STDP-style learning (pair-based)
#------------------------------------------------------------------------------
def offline_update(net, lr=0.005, tau_plus=20.0, tau_minus=20.0, wmin=-1, wmax=1):
    """Return updated weight arrays [w1, w2, w3]."""
    trains = net['spikemon'].spike_trains()
    new_ws = []

    for syn_name in ('syn_input', 'syn_hidden', 'syn_output'):
        S = net[syn_name]
        w = np.array(S.w[:], copy=True)
        pre = np.array(S.i[:], int)
        post = np.array(S.j[:], int)

        # for each synapse
        for idx in range(len(w)):
            t_pre  = trains[pre[idx]]   * 1e3  # ms
            t_post = trains[post[idx]]  * 1e3  # ms
            if not len(t_pre) or not len(t_post):
                continue
            # pair-based STDP
            for tp in t_pre:
                for tj in t_post:
                    dt = tj - tp
                    if 0 < dt < 100:
                        w[idx] += lr * np.exp(-dt/tau_plus)
                    elif -100 < dt < 0:
                        w[idx] -= lr * np.exp(dt/tau_minus)

        new_ws.append(np.clip(w, wmin, wmax))

    return new_ws

#------------------------------------------------------------------------------
# 5) Training loop
#------------------------------------------------------------------------------
if __name__ == '__main__':
    # initial random weights
    np.random.seed(42)
    w1 = np.random.uniform(-0.5, 0.5, size=4)
    w2 = np.random.uniform(-0.5, 0.5, size=4*10)
    w3 = np.random.uniform(-0.5, 0.5, size=10*3)

    # one sample: 4 spike-times in ms
    input_times = np.array([0.0, 0.4, 0.6, 0.9])

    net = create_network(w1, w2, w3, input_times)

    n_epochs = 20
    # assume net, w1, w2, w3, reset_state, offline_update are defined

    for epoch in range(n_epochs):
        reset_state(net, input_times)

        # 1) new monitor for this epoch
        spikemon = SpikeMonitor(net['neurons'])

        # 2) forward pass
        run(5*ms)

        # 3) grab spikes
        trains = spikemon.spike_trains()
        del spikemon  # toss the old monitor

        # 4) offline weight update
        w1, w2, w3 = offline_update(net, trains)

        # 5) reload into Brian
        net['syn_input'].w[:]  = w1
        net['syn_hidden'].w[:] = w2
        net['syn_output'].w[:] = w3

        print(f'Epoch {epoch+1}/{n_epochs} – Avg spikes: '
            f'{np.mean([len(t) for t in trains.values()]):.2f}')

    # final test
  # attach a new monitor just for this pass
    spikemon = SpikeMonitor(net['neurons'])
    run(5*ms)
    # pull out the trains and then delete the monitor
    trains = spikemon.spike_trains()
    del spikemon
    final = net['spikemon'].spike_trains()
    print("Final first-spike times (ms):",
          [round(tr[0]*1e3,2) if len(tr)>0 else None for tr in final.values()])

Epoch 1/20 – Avg spikes: 0.00
Epoch 2/20 – Avg spikes: 0.00
Epoch 3/20 – Avg spikes: 0.00
Epoch 4/20 – Avg spikes: 0.00
Epoch 5/20 – Avg spikes: 0.00
Epoch 6/20 – Avg spikes: 0.00
Epoch 7/20 – Avg spikes: 0.00
Epoch 8/20 – Avg spikes: 0.00
Epoch 9/20 – Avg spikes: 0.00
Epoch 10/20 – Avg spikes: 0.00
Epoch 11/20 – Avg spikes: 0.00
Epoch 12/20 – Avg spikes: 0.00
Epoch 13/20 – Avg spikes: 0.00
Epoch 14/20 – Avg spikes: 0.00
Epoch 15/20 – Avg spikes: 0.00
Epoch 16/20 – Avg spikes: 0.00
Epoch 17/20 – Avg spikes: 0.00
Epoch 18/20 – Avg spikes: 0.00
Epoch 19/20 – Avg spikes: 0.00
Epoch 20/20 – Avg spikes: 0.00
Final first-spike times (ms): [None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None]
