In [60]:
from brian2 import *
import numpy as np
import logging

# Only surface Brian2 log records at ERROR level or above.
logging.getLogger('brian2').setLevel(logging.ERROR)

# Start a new Brian2 scope and set the default clock step used by later runs.
start_scope()
defaultclock.dt = 0.0001*ms

In [61]:
# -----------------------------------------------------------------------------spike_timing + its derivative
# Functions used in brian2
@implementation('numpy', discard_units=True)
@check_units(w=1, global_clock=1, layer=1, sum=1, spikes_received=1, result=1)
def spike_timing(w, global_clock, layer, sum, spikes_received):
    """Map the fractional part of ``global_clock`` through a weight-shaped curve.

    With ``x = global_clock % 1``:

    * ``w >= 0``  ->  ``x ** (1 - w)``
    * ``w <  0``  ->  ``1 - (1 - x) ** (1 + w)``

    Parameters
    ----------
    w : float or array_like
        Synaptic weight(s); selects the branch element-wise.
    global_clock : float or array_like
        Clock value(s); only the fractional part is used.
    layer, sum, spikes_received
        Accepted for signature compatibility; not used by the computation.

    Returns
    -------
    numpy scalar when every input is scalar, otherwise an ndarray with the
    broadcast shape of ``w`` and ``global_clock``.
    """
    w_arr = np.asarray(w, dtype=float)
    x = np.asarray(global_clock, dtype=float) % 1
    # Both branches are evaluated for every element and then selected with
    # np.where, so array-valued ``w`` no longer trips ``if w >= 0``.
    with np.errstate(divide='ignore', invalid='ignore', over='ignore'):
        excit = x ** (1 - w_arr)
        inhib = 1 - (1 - x) ** (1 + w_arr)
    result = np.where(w_arr >= 0, excit, inhib)
    # Unwrap 0-d results so scalar callers still get a plain numpy scalar.
    return result[()] if result.ndim == 0 else result

@implementation('numpy', discard_units=True)
@check_units(w=1, global_clock=1, layer=1, sum=1, spikes_received=1, result=1)
def d_spike_timing_dw(w, global_clock, layer, sum, spikes_received):
    """Derivative of ``spike_timing`` with respect to ``w``.

    With ``x = global_clock % 1`` and ``eps = 1e-9`` guarding ``log(0)``:

    * ``w >= 0``  ->  ``-x ** (1 - w) * log(x + eps)``
    * ``w <  0``  ->  ``-(1 - x) ** (1 + w) * log(1 - x + eps)``

    Parameters
    ----------
    w : float or array_like
        Synaptic weight(s); selects the branch element-wise.
    global_clock : float or array_like
        Clock value(s); only the fractional part is used.
    layer, sum, spikes_received
        Accepted for signature compatibility; not used by the computation.

    Returns
    -------
    numpy scalar when every input is scalar, otherwise an ndarray with the
    broadcast shape of ``w`` and ``global_clock``.
    """
    w_arr = np.asarray(w, dtype=float)
    x = np.asarray(global_clock, dtype=float) % 1
    eps = 1e-9
    # Evaluate both branches and select per element so that array-valued
    # ``w`` does not hit an ambiguous ``if w >= 0``.
    with np.errstate(divide='ignore', invalid='ignore', over='ignore'):
        d_pos = -(x ** (1 - w_arr)) * np.log(x + eps)
        d_neg = -((1 - x) ** (1 + w_arr)) * np.log(1 - x + eps)
    result = np.where(w_arr >= 0, d_pos, d_neg)
    # Unwrap 0-d results so scalar callers still get a plain numpy scalar.
    return result[()] if result.ndim == 0 else result

def dsigmoid(z):
    """Return the derivative of the logistic sigmoid evaluated at ``z``.

    Uses the identity sigma'(z) = sigma(z) * (1 - sigma(z)); works on scalars
    and numpy arrays alike.
    """
    denominator = 1 + np.exp(-z)
    sigma = 1 / denominator
    return (1 - sigma) * sigma

In [62]:
# -----------------------------------------------------------------------------
# 2) mini_urd forward: returns hidden‐spike‐time only
# -----------------------------------------------------------------------------
def mini_urd(inputs, w):
    """Simulate a 2-input / 2-hidden Brian2 network and return one hidden spike time.

    Parameters
    ----------
    inputs : array_like
        Stimulus spike times (one per input neuron), interpreted in ms.
    w : array_like
        Weights assigned to the input->hidden synapses (``S2.w = w``).

    Returns
    -------
    float
        First spike time (ms) of the first hidden neuron (index ``n_input``
        in the group), or ``5.0`` if it did not spike during the 5 ms run.
    """
    n_input  = 2
    n_hidden = 2
    n_total  = n_input + n_hidden

    # All Brian objects stay local to this frame so that the magic ``run``
    # below picks them up.
    G = NeuronGroup(
        n_total,
        '''
        v               : 1
        sum             : 1
        sr              : 1
        scheduled_time  : second
        global_clock    : 1
        ''',
        threshold='v>1', reset='v=0', method='exact'
    )
    G.v = 0; G.sum = 0; G.sr = 0
    G.global_clock = 0
    # Far-future sentinel: a neuron has no scheduled spike until it receives input.
    G.scheduled_time = 1e9*second

    stim = SpikeGeneratorGroup(n_input, indices=list(range(n_input)), times=inputs*ms)

    # first layer has fixed identity weights
    S1 = Synapses(stim, G[:n_input],
        'layer:1', on_pre='''
        sr += 1
        sum += spike_timing(1, global_clock, layer, sum, sr)
        scheduled_time = (1/(1+exp(-(sum/sr))) + layer)*ms
        '''
    )
    S1.connect(j='i')
    S1.layer = 0

    # trainable synapses: input-relay neurons -> hidden neurons
    S2 = Synapses(G[:n_input], G[n_input:n_total],
        'w : 1\nlayer:1', on_pre='''
        sr += 1
        sum += spike_timing(w, global_clock, layer, sum, sr)
        scheduled_time = (1/(1+exp(-(sum/sr))) + layer)*ms
        '''
    )
    S2.connect()
    S2.w = w
    S2.layer = 1

    # drive v above threshold when t is within 0.0005 ms of scheduled_time,
    # and advance the unitless clock that spike_timing reads
    G.run_regularly('''
        v = int(abs(t - scheduled_time)<0.0005*ms)*1.2
        global_clock += 0.001
    ''', dt=0.001*ms)

    mon = SpikeMonitor(G)
    run(5*ms)

    # Hidden neurons follow the input relays in G, so the first hidden neuron
    # sits at index n_input (previously a hard-coded 2).
    ts = mon.spike_trains()[n_input]
    return float(ts[0]/ms) if len(ts) else 5.0


In [63]:
# -----------------------------------------------------------------------------
# 3) Training with multi‐loss
# -----------------------------------------------------------------------------
def train_multi_loss(
    X,                    # list of np.array([t0,t1])
    t_hidden_targets,     # list of floats
    t0_targets,           # list of floats for s0
    t1_targets,           # list of floats for s1
    w_init,
    alpha=1.0, beta=1.0, gamma=1.0,
    epochs=5, lr=0.1
):
    """
    Per-sample gradient descent on a weighted sum of three squared errors.

    Multi-loss:
      L0 = ½ (s0 - t0)^2
      L1 = ½ (s1 - t1)^2
      Lf = ½ (t_h  - t_h*)^2
      L = α L0 + β L1 + γ Lf

    ``s0``/``s1`` are ``spike_timing`` evaluated with ``w[0]``/``w[1]`` at the
    two input times; ``t_h`` is the hidden spike time returned by ``mini_urd``.

    Parameters
    ----------
    X : sequence of array_like
        Input spike-time pairs.
    t_hidden_targets, t0_targets, t1_targets : sequence of float
        Per-sample targets for ``t_h``, ``s0`` and ``s1``.
    w_init : array_like
        Initial weights. Copied and converted to float, so integer arrays and
        plain lists are accepted and the caller's object is never modified.
    alpha, beta, gamma : float
        Weights of L0, L1 and Lf in the total loss.
    epochs : int
        Number of passes over ``X``.
    lr : float
        Learning rate.

    Returns
    -------
    numpy.ndarray
        The trained float weight vector.
    """
    # Float copy: an integer w_init would otherwise make np.zeros_like(w)
    # integer (truncating gradients) and make ``w -= lr * grad`` raise.
    w = np.array(w_init, dtype=float)
    layer_h = 1

    for ep in range(epochs):
        print(f"\n=== Epoch {ep+1}/{epochs} ===")
        for i, inp in enumerate(X):
            # forward pass through the Brian2 simulation
            t_h = mini_urd(inp, w)

            # analytic per-synapse contributions, treating each input as the
            # first spike received (sum=0, spikes_received=1)
            s0 = spike_timing(w[0], inp[0], layer_h, 0, 1)
            s1 = spike_timing(w[1], inp[1], layer_h, 0, 1)

            # --- compute loss terms ---
            t0_tgt = t0_targets[i]
            t1_tgt = t1_targets[i]
            th_tgt = t_hidden_targets[i]

            L0 = 0.5*(s0 - t0_tgt)**2
            L1 = 0.5*(s1 - t1_tgt)**2
            Lf = 0.5*(t_h - th_tgt)**2
            L  = alpha*L0 + beta*L1 + gamma*Lf

            # --- gradients ---
            # ∂L0/∂w0 = (s0 - t0)*∂s0/∂w0
            dL0_dw = np.zeros_like(w)
            dL0_dw[0] = (s0 - t0_tgt) * d_spike_timing_dw(w[0], inp[0], layer_h, 0, 1)
            # ∂L1/∂w1
            dL1_dw = np.zeros_like(w)
            dL1_dw[1] = (s1 - t1_tgt) * d_spike_timing_dw(w[1], inp[1], layer_h, 0, 1)

            # ∂Lf/∂w = ∂Lf/∂t_h × ∂t_h/∂sum × ∂sum/∂w_i
            dLf_dt  = (t_h - th_tgt)
            sum_tot = s0 + s1
            sr = 2.0
            z = sum_tot/sr
            dt_dsum = dsigmoid(z)*(1/sr)
            dsum_dw = np.zeros_like(w)

            # NOTE(review): every entry of w gets a term here, paired with
            # inp[j % 2], while s0/s1 above only use w[0] and w[1]. Confirm this
            # pairing matches the synapse ordering used by mini_urd.
            for j in range(len(w)):
                dsum_dw[j] = d_spike_timing_dw(w[j], inp[j % 2], layer_h, 0, 1)
            dLf_dw = dLf_dt * dt_dsum * dsum_dw

            # combine
            grad = alpha*dL0_dw + beta*dL1_dw + gamma*dLf_dw

            # print & update
            print(f"Sample {i}: inp={inp}, s0={s0:.3f}, s1={s1:.3f}, t_h={t_h:.3f}")
            print(f"  L0={L0:.4f}, L1={L1:.4f}, Lf={Lf:.4f}, L={L:.4f}")
            print(f"  ∇w = {grad}")
            w -= lr * grad

        print(" Updated w:", w)

    return w

def compute_gradients(outputs, targets, w):
    """Return ``2 * (outputs - targets) * dsigmoid(w)``.

    This is the gradient of a squared-error loss chained through the sigmoid
    derivative evaluated at ``w``.
    """
    residual = outputs - targets
    return 2 * residual * dsigmoid(w)



In [None]:
if __name__ == "__main__":
    # Toy data: `num` identical samples (every list entry is the same array object).

    num = 4
    X = [np.array([0.1,0.9])]*num

    # Main target: desired hidden spike time (ms) for each sample.
    T_hidden = [0.5]*num
    # Auxiliary per-synapse targets for s0 and s1.
    # TODO: these are hard-coded placeholders; the original author's note asks
    # for them to be removed and calculated during training instead.
    T0 = [0.5]*num
    T1 = [0.5]*num

    w0 = np.array([0.2, 1.0, 0.1, 0.3])  # initial weights, one per input->hidden synapse
    w_final = train_multi_loss(X, T_hidden, T0, T1, w0,
                               alpha=1.0, beta=1.0, gamma=0.5,
                               epochs=5, lr=0.1)

    print("\nFinal weights:", w_final)




=== Epoch 1/5 ===
Sample 0: inp=[0.1 0.9], s0=0.158, s1=1.000, t_h=1.685
  L0=0.0583, L1=0.1250, Lf=0.7021, L=0.5344
  ∇w = [-0.09974737  0.05986391  0.01976438  0.00667291]
Sample 1: inp=[0.1 0.9], s0=0.162, s1=0.999, t_h=1.686
  L0=0.0571, L1=0.1247, Lf=0.7033, L=0.5334
  ∇w = [-0.10067933  0.05976274  0.01968276  0.00667519]
Sample 2: inp=[0.1 0.9], s0=0.166, s1=0.999, t_h=1.687
  L0=0.0558, L1=0.1244, Lf=0.7045, L=0.5324
  ∇w = [-0.10157674  0.05966175  0.01960148  0.00667735]
Sample 3: inp=[0.1 0.9], s0=0.170, s1=0.998, t_h=1.687
  L0=0.0545, L1=0.1241, Lf=0.7045, L=0.5308
  ∇w = [-0.10245641  0.0595549   0.01950408  0.00667376]
 Updated w: [0.24044598 0.97611567 0.09214473 0.29733008]

=== Epoch 2/5 ===
Sample 0: inp=[0.1 0.9], s0=0.174, s1=0.997, t_h=1.688
  L0=0.0532, L1=0.1237, Lf=0.7057, L=0.5297
  ∇w = [-0.10326794  0.05945429  0.01942357  0.00667566]
Sample 1: inp=[0.1 0.9], s0=0.178, s1=0.997, t_h=1.688
  L0=0.0518, L1=0.1234, Lf=0.7057, L=0.5281
  ∇w = [-0.10405004  0.05

In [19]:
from brian2 import *
import numpy as np
import logging
import warnings

# Suppress every RuntimeWarning and numpy's overflow/underflow reports.
# Note that this hides all RuntimeWarnings process-wide, not just overflow ones.
warnings.filterwarnings('ignore', category=RuntimeWarning)
np.seterr(over='ignore', under='ignore')

# Only surface Brian2 log records at ERROR level or above.
logging.getLogger('brian2').setLevel(logging.ERROR)

# ----------------------------------------------------------------------------
# Spike timing and derivative

# Start a new Brian2 scope and set the default clock step used by later runs.
start_scope()
defaultclock.dt = 0.0001*ms

@implementation('numpy', discard_units=True)
@check_units(w=1, global_clock=1, layer=1, sum=1, spikes_received=1, result=1)
def spike_timing(w, global_clock, layer, sum, spikes_received):
    """Map the fractional part of ``global_clock`` through a weight-shaped curve.

    With ``x = global_clock % 1``:

    * ``w >= 0``  ->  ``x ** (1 - w)`` where ``x > 0``, and ``0`` where ``x == 0``
    * ``w <  0``  ->  ``1 - (1 - x) ** (1 + w)``

    Parameters
    ----------
    w : float or array_like
        Synaptic weight(s); selects the branch element-wise.
    global_clock : float or array_like
        Clock value(s); only the fractional part is used.
    layer, sum, spikes_received
        Accepted for signature compatibility; not used by the computation.

    Returns
    -------
    numpy scalar when every input is scalar, otherwise an ndarray with the
    broadcast shape of ``w`` and ``global_clock``.
    """
    w_arr = np.asarray(w, dtype=float)
    x = np.asarray(global_clock, dtype=float) % 1
    out_shape = np.broadcast(w_arr, x).shape
    # Masked power: entries with x == 0 keep the pre-filled 0 instead of
    # evaluating 0 ** (1 - w), which diverges for w > 1.
    excit = np.power(x, 1 - w_arr, where=(x > 0), out=np.zeros(out_shape))
    inhib = 1 - np.power(1 - x, 1 + w_arr)
    # Element-wise branch selection so array-valued ``w`` works.
    result = np.where(w_arr >= 0, excit, inhib)
    # Unwrap 0-d results so scalar callers still get a plain numpy scalar.
    return result[()] if result.ndim == 0 else result

@implementation('numpy', discard_units=True)
@check_units(w=1, global_clock=1, layer=1, sum=1, spikes_received=1, result=1)
def d_spike_timing_dw(w, global_clock, layer, sum, spikes_received):
    """Derivative of ``spike_timing`` with respect to ``w``.

    With ``x = global_clock % 1`` and ``eps = 1e-9`` guarding ``log(0)``:

    * ``w >= 0``  ->  ``-x ** (1 - w) * log(x + eps)`` (power taken as 0 at ``x == 0``)
    * ``w <  0``  ->  ``-(1 - x) ** (1 + w) * log(1 - x + eps)``

    Parameters
    ----------
    w : float or array_like
        Synaptic weight(s); selects the branch element-wise.
    global_clock : float or array_like
        Clock value(s); only the fractional part is used.
    layer, sum, spikes_received
        Accepted for signature compatibility; not used by the computation.

    Returns
    -------
    numpy scalar when every input is scalar, otherwise an ndarray with the
    broadcast shape of ``w`` and ``global_clock``.
    """
    w_arr = np.asarray(w, dtype=float)
    x = np.asarray(global_clock, dtype=float) % 1
    eps = 1e-9
    out_shape = np.broadcast(w_arr, x).shape
    # Masked power: entries with x == 0 keep the pre-filled 0.
    pow_pos = np.power(x, 1 - w_arr, where=(x > 0), out=np.zeros(out_shape))
    d_pos = -pow_pos * np.log(x + eps)
    d_neg = -np.power(1 - x, 1 + w_arr) * np.log(1 - x + eps)
    # Element-wise branch selection so array-valued ``w`` works.
    result = np.where(w_arr >= 0, d_pos, d_neg)
    # Unwrap 0-d results so scalar callers still get a plain numpy scalar.
    return result[()] if result.ndim == 0 else result

# ----------------------------------------------------------------------------
# mini_urd: 2 inputs -> 2 hidden neurons (full connect), return both spike times

def mini_urd(inputs, W):
    """
    Simulate each hidden neuron on its own and return every first spike time.

    One single-neuron Brian2 network is built and run per column of ``W``,
    each inside a fresh scope.

    inputs: input spike times, multiplied by ``ms`` (one per row of ``W``)
    W: weight matrix of shape (n_input, n_hidden); column j feeds hidden neuron j
    returns: np.array of length n_hidden with the first spike time (ms) of each
             hidden neuron, or 5.0 where the neuron did not spike in the 5 ms run
    """
    n_input, n_hidden = W.shape
    hidden_times = []
    # simulate each hidden neuron separately to reset network state
    for j in range(n_hidden):
        start_scope()  # clear previous Brian state
        defaultclock.dt = 0.0001*ms
        # spike_timing is resolved by name from the enclosing namespace when
        # the on_pre code below runs; nothing needs to be redefined here.

        # build one-neuron group
        G = NeuronGroup(1,
            '''
            v               : 1
            sum             : 1
            sr              : 1
            scheduled_time  : second
            global_clock    : 1
            ''',
            threshold='v>1', reset='v=0', method='exact')
        G.v = G.sum = G.sr = 0
        G.global_clock = 0
        # Far-future sentinel: no spike is scheduled until an input arrives.
        G.scheduled_time = 1e9*second

        # input spikes
        stim = SpikeGeneratorGroup(n_input,
            indices=list(range(n_input)),
            times=inputs*ms)
        # Each incoming spike bumps the received count, adds its spike_timing
        # contribution, and reschedules the output at (mean contribution + layer) ms.
        S = Synapses(stim, G,
            '''w:1
              layer:1''', on_pre='''
            sr += 1
            sum += spike_timing(w, global_clock, layer, sum, sr)
            scheduled_time = (sum/sr + layer)*ms
        ''')
        S.connect(True)
        # Column j of W, indexed by each synapse's presynaptic index.
        S.w = W[S.i, j]
        S.layer = 1

        # drive membrane: v becomes 1.2 (above the v>1 threshold) while t is
        # within 0.0005 ms of scheduled_time; also advance the unitless clock
        G.run_regularly('''
            v = int(abs(t - scheduled_time) < 0.0005*ms) * 1.2
            global_clock += 0.001
        ''', dt=0.001*ms)

        mon = SpikeMonitor(G)
        run(5*ms)
        ts = mon.spike_trains()[0]
        # 5.0 (the run length in ms) stands in for "no spike".
        t0 = float(ts[0]/ms) if len(ts)>0 else 5.0
        hidden_times.append(t0)

    return np.array(hidden_times)

# ----------------------------------------------------------------------------
# Training full matrix W

def train_snn(
    X,           # list of input arrays
    Y,           # list of target arrays
    W_init,      # initial weight matrix (2x2)
    epochs=10,
    lr=0.1,
    max_grad=20.0,
    w_min=-20.0,
    w_max=20.0
):
    """Train the single-layer weight matrix by per-sample gradient descent.

    For each sample the spike times predicted by ``mini_urd`` are compared with
    the targets under a half-squared-error loss, and every weight is updated
    with ``(t_pred[j] - t_tgt[j]) * d_spike_timing_dw(...)``.

    Parameters
    ----------
    X : sequence of array_like
        Input spike times per sample (length ``n_input`` each).
    Y : sequence of array_like
        Target hidden spike times per sample (length ``n_hidden`` each).
    W_init : array_like, shape (n_input, n_hidden)
        Initial weights. Copied and converted to float so that integer input
        does not truncate gradients; the caller's array is never modified.
    epochs : int
        Number of passes over ``X``.
    lr : float
        Learning rate.
    max_grad : float
        Gradients are clipped element-wise to ``[-max_grad, max_grad]``.
    w_min, w_max : float
        Weights are clipped element-wise to ``[w_min, w_max]`` after each step.

    Returns
    -------
    numpy.ndarray
        The trained float weight matrix.
    """
    # Float copy: with an integer W_init, np.zeros_like(W) would be integer
    # and the first gradients would be silently truncated on assignment.
    W = np.array(W_init, dtype=float)
    layer_h = 1
    n_input, n_hidden = W.shape

    for ep in range(epochs):
        print(f"Epoch {ep+1}/{epochs}")
        for i, inp in enumerate(X):
            t_pred = mini_urd(inp, W)       # shape (n_hidden,)
            t_tgt  = Y[i]
            L = 0.5 * np.sum((t_pred - t_tgt)**2)

            # gradient matrix dL/dW
            # NOTE(review): the forward pass schedules the spike at sum/sr, so
            # the exact derivative would carry an extra 1/sr factor; it is left
            # out here, which only rescales the effective learning rate.
            dW = np.zeros_like(W)
            for j in range(n_hidden):
                for k in range(n_input):
                    dW[k, j] = (t_pred[j] - t_tgt[j]) * d_spike_timing_dw(
                        W[k, j], inp[k], layer_h, 0, 1)

            # clip & update
            dW = np.clip(dW, -max_grad, max_grad)
            W = np.clip(W - lr * dW, w_min, w_max)

            print(f" Sample {i}: inp={inp}, pred={t_pred}, tgt={t_tgt}, L={L:.4f}")
            print(f"  dW=\n{dW}\n  W=\n{W}")

    return W

# ----------------------------------------------------------------------------
if __name__ == "__main__":
    import numpy as np
    # Four samples with the same input spike times and the same targets.
    X = [np.array([0.1, 0.9])]*4
    # Target spike times (ms) for the two hidden neurons.
    Y = [np.array([1.2, 1.8]) for _ in X]
    # Initial weights, shape (n_input, n_hidden) as unpacked by train_snn.
    W0 = np.array([[0.2, -0.5], [0.3, -.5]])

    W_trained = train_snn(X, Y, W0, epochs=30, lr=0.2)
    print("Trained weight matrix:\n", W_trained)



Epoch 1/30
 Sample 0: inp=[0.1 0.9], pred=[1.545 1.369], tgt=[1.2 1.8], L=0.1524
  dW=
[[ 0.12590262 -0.04308007]
 [ 0.033765   -0.31382892]]
  W=
[[ 0.17481948 -0.49138399]
 [ 0.293247   -0.43723422]]
 Sample 1: inp=[0.1 0.9], pred=[1.54 1.39], tgt=[1.2 1.8], L=0.1419
  dW=
[[ 0.11708846 -0.04094386]
 [ 0.03325198 -0.25836505]]
  W=
[[ 0.15140178 -0.48319521]
 [ 0.2865966  -0.38556121]]
 Sample 2: inp=[0.1 0.9], pred=[1.536 1.406], tgt=[1.2 1.8], L=0.1341
  dW=
[[ 0.1096369  -0.03931212]
 [ 0.03283777 -0.22043123]]
  W=
[[ 0.1294744  -0.47533279]
 [ 0.28002905 -0.34147496]]
 Sample 3: inp=[0.1 0.9], pred=[1.532 1.418], tgt=[1.2 1.8], L=0.1281
  dW=
[[ 0.10299785 -0.03808323]
 [ 0.0324244  -0.19308743]]
  W=
[[ 0.10887483 -0.46771615]
 [ 0.27354417 -0.30285748]]
Epoch 2/30
 Sample 0: inp=[0.1 0.9], pred=[1.528 1.428], tgt=[1.2 1.8], L=0.1230
  dW=
[[ 0.09704303 -0.03705654]
 [ 0.03201186 -0.17203472]]
  W=
[[ 0.08946623 -0.46030484]
 [ 0.2671418  -0.26845053]]
 Sample 1: inp=[0.1 0.9],

In [6]:
from brian2 import *
import numpy as np
import logging
import warnings

# Suppress every RuntimeWarning and numpy's overflow/underflow reports.
warnings.filterwarnings('ignore', category=RuntimeWarning)
# Use the explicitly imported ``np`` alias rather than a bare ``numpy`` name
# that is only in scope through the ``from brian2 import *`` star import.
np.seterr(over='ignore', under='ignore')
# Only surface Brian2 log records at ERROR level or above.
logging.getLogger('brian2').setLevel(logging.ERROR)

# ----------------------------------------------------------------------------
# Spike timing and derivative

# Start a new Brian2 scope and set the default clock step used by later runs.
start_scope()
defaultclock.dt = 0.0001*ms

@implementation('numpy', discard_units=True)
@check_units(w=1, global_clock=1, layer=1, sum=1, spikes_received=1, result=1)
def spike_timing(w, global_clock, layer, sum, spikes_received):
    """Map the fractional part of ``global_clock`` through a weight-shaped curve.

    With ``x = global_clock % 1``:

    * ``w >= 0``  ->  ``x ** (1 - w)`` where ``x > 0``, and ``0`` where ``x == 0``
    * ``w <  0``  ->  ``1 - (1 - x) ** (1 + w)``

    Parameters
    ----------
    w : float or array_like
        Synaptic weight(s); selects the branch element-wise.
    global_clock : float or array_like
        Clock value(s); only the fractional part is used.
    layer, sum, spikes_received
        Accepted for signature compatibility; not used by the computation.

    Returns
    -------
    numpy scalar when every input is scalar, otherwise an ndarray with the
    broadcast shape of ``w`` and ``global_clock``.
    """
    w_arr = np.asarray(w, dtype=float)
    x = np.asarray(global_clock, dtype=float) % 1
    out_shape = np.broadcast(w_arr, x).shape
    # Masked power: entries with x == 0 keep the pre-filled 0 instead of
    # evaluating 0 ** (1 - w), which diverges for w > 1.
    excit = np.power(x, 1 - w_arr, where=(x > 0), out=np.zeros(out_shape))
    inhib = 1 - np.power(1 - x, 1 + w_arr)
    # Element-wise branch selection so array-valued ``w`` works.
    result = np.where(w_arr >= 0, excit, inhib)
    # Unwrap 0-d results so scalar callers still get a plain numpy scalar.
    return result[()] if result.ndim == 0 else result

@implementation('numpy', discard_units=True)
@check_units(w=1, global_clock=1, layer=1, sum=1, spikes_received=1, result=1)
def d_spike_timing_dw(w, global_clock, layer, sum, spikes_received):
    """Derivative of ``spike_timing`` with respect to ``w``.

    With ``x = global_clock % 1`` and ``eps = 1e-9`` guarding ``log(0)``:

    * ``w >= 0``  ->  ``-x ** (1 - w) * log(x + eps)`` (power taken as 0 at ``x == 0``)
    * ``w <  0``  ->  ``-(1 - x) ** (1 + w) * log(1 - x + eps)``

    Parameters
    ----------
    w : float or array_like
        Synaptic weight(s); selects the branch element-wise.
    global_clock : float or array_like
        Clock value(s); only the fractional part is used.
    layer, sum, spikes_received
        Accepted for signature compatibility; not used by the computation.

    Returns
    -------
    numpy scalar when every input is scalar, otherwise an ndarray with the
    broadcast shape of ``w`` and ``global_clock``.
    """
    w_arr = np.asarray(w, dtype=float)
    x = np.asarray(global_clock, dtype=float) % 1
    eps = 1e-9
    out_shape = np.broadcast(w_arr, x).shape
    # Masked power: entries with x == 0 keep the pre-filled 0.
    pow_pos = np.power(x, 1 - w_arr, where=(x > 0), out=np.zeros(out_shape))
    d_pos = -pow_pos * np.log(x + eps)
    d_neg = -np.power(1 - x, 1 + w_arr) * np.log(1 - x + eps)
    # Element-wise branch selection so array-valued ``w`` works.
    result = np.where(w_arr >= 0, d_pos, d_neg)
    # Unwrap 0-d results so scalar callers still get a plain numpy scalar.
    return result[()] if result.ndim == 0 else result

# ----------------------------------------------------------------------------
# Forward pass: 4->10->3 using two-stage mini_urd

def layer_forward(inputs, W, layer_idx):
    """
    Run one layer forward, simulating each output neuron in its own Brian2 scope.

    inputs: array of spike times (ms) from previous layer (shape: n_in,)
    W: weight matrix shape (n_in,n_out)
    layer_idx: integer layer number; added (in ms) to the scheduled spike time
    returns: array of output spike times (ms); 5.0 marks a neuron that did not
             spike within the 5 ms run
    """
    n_in, n_out = W.shape
    out_times = []
    for j in range(n_out):
        # Fresh scope per output neuron so earlier objects are not re-run.
        start_scope()
        defaultclock.dt = 0.0001*ms
        # Neuron group
        G = NeuronGroup(1, '''
            v : 1
            sum : 1
            sr : 1
            scheduled_time : second
            global_clock : 1
        ''', threshold='v>1', reset='v=0', method='exact')
        G.v = G.sum = G.sr = 0
        G.global_clock = 0
        # Far-future sentinel: no spike is scheduled until an input arrives.
        G.scheduled_time = 1e9*second
        # Spike inputs
        stim = SpikeGeneratorGroup(n_in, indices=list(range(n_in)), times=inputs*ms)
        # Each incoming spike bumps the received count, adds its spike_timing
        # contribution, and reschedules the output at (mean contribution + layer) ms.
        S = Synapses(stim, G, '''w:1
            layer:1''', on_pre='''
            sr += 1
            sum += spike_timing(w, global_clock, layer, sum, sr)
            scheduled_time = (sum/sr + layer)*ms
        ''')
        S.connect(True)
        S.w = W[:, j]
        S.layer = layer_idx
        # v becomes 1.2 (above the v>1 threshold) while t is within 0.0005 ms
        # of scheduled_time; global_clock is the unitless clock spike_timing reads.
        G.run_regularly('''
            v = int(abs(t - scheduled_time) < 0.0005*ms) * 1.2
            global_clock += 0.001
        ''', dt=0.001*ms)
        mon = SpikeMonitor(G)
        run(5*ms)
        ts = mon.spike_trains()[0]
        # First spike time in ms, or the run length (5.0) when there was none.
        t0 = float(ts[0]/ms) if len(ts)>0 else float(5.0)
        out_times.append(t0)
    return np.array(out_times)

# ----------------------------------------------------------------------------
# Training loop with backprop for 4-10-3

def train_snn_backprop(
    X, Y,        # lists of input arrays (4,) and target (3,)
    W1_init, W2_init,
    epochs=10, lr=0.1, max_grad=20.0, w_min=-20.0, w_max=20.0
):
    """Train a two-layer spike-timing network with per-sample updates.

    Each sample is pushed through ``layer_forward`` twice (hidden, then
    output), the half-squared-error on the output spike times is computed, and
    both weight matrices receive a clipped gradient step.

    Parameters
    ----------
    X, Y : sequence of array_like
        Input spike times and target output spike times per sample.
    W1_init : array_like, shape (n_in, n_hidden)
    W2_init : array_like, shape (n_hidden, n_out)
        Initial weights. Copied and converted to float so integer input does
        not truncate gradients; the caller's arrays are never modified.
    epochs : int
        Number of passes over the data.
    lr : float
        Learning rate.
    max_grad : float
        Gradients are clipped element-wise to ``[-max_grad, max_grad]``.
    w_min, w_max : float
        Weights are clipped element-wise to ``[w_min, w_max]`` after each step.

    Returns
    -------
    (numpy.ndarray, numpy.ndarray)
        The trained ``W1`` and ``W2``.
    """
    # Float copies: integer inputs would make np.zeros_like(...) integer and
    # silently truncate the gradients written into it.
    W1 = np.array(W1_init, dtype=float)
    W2 = np.array(W2_init, dtype=float)
    layer1_idx = 1
    layer2_idx = 2
    n_in, n_hidden = W1.shape
    n_out = W2.shape[1]

    for ep in range(epochs):
        print(f"Epoch {ep+1}/{epochs}")
        for xi, yi in zip(X, Y):
            # forward
            h_times = layer_forward(xi, W1, layer1_idx)  # shape (n_hidden,)
            o_times = layer_forward(h_times, W2, layer2_idx)  # shape (n_out,)
            # loss
            L = 0.5 * np.sum((o_times - yi)**2)
            delta_o = (o_times - yi)  # shape (n_out,)

            # d(spike_timing)/dw for every W2 entry, computed once and reused
            # for both the W2 gradient and the hidden deltas.
            ds_dw2 = np.zeros_like(W2)
            for k in range(n_hidden):
                for j in range(n_out):
                    ds_dw2[k, j] = d_spike_timing_dw(
                        W2[k, j], h_times[k], layer2_idx, 0, 1)

            # dW2[k,j] = delta_o[j] * ds_dw2[k,j]
            dW2 = delta_o[np.newaxis, :] * ds_dw2

            # hidden deltas: sum over output neurons of delta_o * W2 * ds/dw.
            # NOTE(review): this propagates through the derivative w.r.t. the
            # weight, not w.r.t. the hidden spike time — confirm that this
            # surrogate is intended.
            delta_h = (delta_o[np.newaxis, :] * W2 * ds_dw2).sum(axis=1)

            # gradients for W1
            dW1 = np.zeros_like(W1)
            for i in range(n_in):
                for k in range(n_hidden):
                    dW1[i, k] = delta_h[k] * d_spike_timing_dw(
                        W1[i, k], xi[i], layer1_idx, 0, 1)

            # clip and update
            dW1 = np.clip(dW1, -max_grad, max_grad)
            dW2 = np.clip(dW2, -max_grad, max_grad)
            W1 = np.clip(W1 - lr * dW1, w_min, w_max)
            W2 = np.clip(W2 - lr * dW2, w_min, w_max)
            # logging
            print(f" Input: {xi}, Pred: {o_times}, Target: {yi}, Loss: {L:.4f}")
        print(f" End Epoch {ep+1}: W1 norm={np.linalg.norm(W1):.3f}, W2 norm={np.linalg.norm(W2):.3f}\n")
    return W1, W2

if __name__ == "__main__":
    # Example usage with fixed input/target pairs.
    # 8 samples, each with the same 4 input spike times.
    X = [np.array([0.6, 0.2, 0.4, 0.8]) for _ in range(8)]
    # 3 target output spike times per sample: [2.1, 2.9, 2.1].
    Y = [np.array([2.1, 2.9, 2.1]) for _ in range(8)]
    # Initialize weights: small random values (unseeded, so runs differ).
    W1_0 = np.random.randn(4, 10) * 0.1
    W2_0 = np.random.randn(10, 3) * 0.1
    # train
    W1_tr, W2_tr = train_snn_backprop(X, Y, W1_0, W2_0,
                                     epochs=10, lr=0.4)
    print("Trained W1:", W1_tr)
    print("Trained W2:", W2_tr)



Epoch 1/10
 Input: [0.6 0.2 0.4 0.8], Pred: [2.529 2.501 2.504], Target: [2.1 2.9 2.1], Loss: 0.2532
 Input: [0.6 0.2 0.4 0.8], Pred: [2.505 2.522 2.481], Target: [2.1 2.9 2.1], Loss: 0.2260
 Input: [0.6 0.2 0.4 0.8], Pred: [2.484 2.545 2.461], Target: [2.1 2.9 2.1], Loss: 0.2019
 Input: [0.6 0.2 0.4 0.8], Pred: [2.466 2.567 2.443], Target: [2.1 2.9 2.1], Loss: 0.1812
 Input: [0.6 0.2 0.4 0.8], Pred: [2.447 2.591 2.424], Target: [2.1 2.9 2.1], Loss: 0.1604
 Input: [0.6 0.2 0.4 0.8], Pred: [2.429 2.615 2.405], Target: [2.1 2.9 2.1], Loss: 0.1412
 Input: [0.6 0.2 0.4 0.8], Pred: [2.411 2.639 2.387], Target: [2.1 2.9 2.1], Loss: 0.1236
 Input: [0.6 0.2 0.4 0.8], Pred: [2.393 2.662 2.368], Target: [2.1 2.9 2.1], Loss: 0.1072
 End Epoch 1: W1 norm=0.879, W2 norm=2.271

Epoch 2/10
 Input: [0.6 0.2 0.4 0.8], Pred: [2.374 2.685 2.348], Target: [2.1 2.9 2.1], Loss: 0.0914
 Input: [0.6 0.2 0.4 0.8], Pred: [2.355 2.707 2.329], Target: [2.1 2.9 2.1], Loss: 0.0774
 Input: [0.6 0.2 0.4 0.8], Pred: [

KeyboardInterrupt: 

In [9]:
from brian2 import *
import numpy as np
import logging
import warnings

# Suppress every RuntimeWarning and numpy's overflow/underflow reports.
warnings.filterwarnings('ignore', category=RuntimeWarning)
# Use the explicitly imported ``np`` alias rather than a bare ``numpy`` name
# that is only in scope through the ``from brian2 import *`` star import.
np.seterr(over='ignore', under='ignore')
# Only surface Brian2 log records at ERROR level or above.
logging.getLogger('brian2').setLevel(logging.ERROR)

# ----------------------------------------------------------------------------
# Spike timing and derivative

# Start a new Brian2 scope and set the default clock step used by later runs.
start_scope()
defaultclock.dt = 0.0001*ms

@implementation('numpy', discard_units=True)
@check_units(w=1, global_clock=1, layer=1, sum=1, spikes_received=1, result=1)
def spike_timing(w, global_clock, layer, sum, spikes_received):
    """Map the fractional part of ``global_clock`` through a weight-shaped curve.

    With ``x = global_clock % 1``:

    * ``w >= 0``  ->  ``x ** (1 - w)`` where ``x > 0``, and ``0`` where ``x == 0``
    * ``w <  0``  ->  ``1 - (1 - x) ** (1 + w)``

    Parameters
    ----------
    w : float or array_like
        Synaptic weight(s); selects the branch element-wise.
    global_clock : float or array_like
        Clock value(s); only the fractional part is used.
    layer, sum, spikes_received
        Accepted for signature compatibility; not used by the computation.

    Returns
    -------
    numpy scalar when every input is scalar, otherwise an ndarray with the
    broadcast shape of ``w`` and ``global_clock``.
    """
    w_arr = np.asarray(w, dtype=float)
    x = np.asarray(global_clock, dtype=float) % 1
    out_shape = np.broadcast(w_arr, x).shape
    # Masked power: entries with x == 0 keep the pre-filled 0 instead of
    # evaluating 0 ** (1 - w), which diverges for w > 1.
    excit = np.power(x, 1 - w_arr, where=(x > 0), out=np.zeros(out_shape))
    inhib = 1 - np.power(1 - x, 1 + w_arr)
    # Element-wise branch selection so array-valued ``w`` works.
    result = np.where(w_arr >= 0, excit, inhib)
    # Unwrap 0-d results so scalar callers still get a plain numpy scalar.
    return result[()] if result.ndim == 0 else result

@implementation('numpy', discard_units=True)
@check_units(w=1, global_clock=1, layer=1, sum=1, spikes_received=1, result=1)
def d_spike_timing_dw(w, global_clock, layer, sum, spikes_received):
    """Derivative of ``spike_timing`` with respect to ``w``.

    With ``x = global_clock % 1`` and ``eps = 1e-9`` guarding ``log(0)``:

    * ``w >= 0``  ->  ``-x ** (1 - w) * log(x + eps)`` (power taken as 0 at ``x == 0``)
    * ``w <  0``  ->  ``-(1 - x) ** (1 + w) * log(1 - x + eps)``

    Parameters
    ----------
    w : float or array_like
        Synaptic weight(s); selects the branch element-wise.
    global_clock : float or array_like
        Clock value(s); only the fractional part is used.
    layer, sum, spikes_received
        Accepted for signature compatibility; not used by the computation.

    Returns
    -------
    numpy scalar when every input is scalar, otherwise an ndarray with the
    broadcast shape of ``w`` and ``global_clock``.
    """
    w_arr = np.asarray(w, dtype=float)
    x = np.asarray(global_clock, dtype=float) % 1
    eps = 1e-9
    out_shape = np.broadcast(w_arr, x).shape
    # Masked power: entries with x == 0 keep the pre-filled 0.
    pow_pos = np.power(x, 1 - w_arr, where=(x > 0), out=np.zeros(out_shape))
    d_pos = -pow_pos * np.log(x + eps)
    d_neg = -np.power(1 - x, 1 + w_arr) * np.log(1 - x + eps)
    # Element-wise branch selection so array-valued ``w`` works.
    result = np.where(w_arr >= 0, d_pos, d_neg)
    # Unwrap 0-d results so scalar callers still get a plain numpy scalar.
    return result[()] if result.ndim == 0 else result

# ----------------------------------------------------------------------------
# Forward pass for a single layer

def layer_forward(inputs, W, layer_idx):
    """
    Run one layer forward, simulating each output neuron in its own Brian2 scope.

    inputs: spike times from the previous layer, multiplied by ``ms`` (one per row of W)
    W: weight matrix of shape (n_in, n_out); column j feeds output neuron j
    layer_idx: layer number; added (in ms) to the scheduled spike time
    returns: np.array of first spike times (ms), one per output neuron; 5.0
             marks a neuron that did not spike within the 5 ms run
    """
    n_in, n_out = W.shape
    out_times = []
    for j in range(n_out):
        # Fresh scope per output neuron so earlier objects are not re-run.
        start_scope()
        defaultclock.dt = 0.0001*ms
        G = NeuronGroup(1, '''
            v : 1
            sum : 1
            sr : 1
            scheduled_time : second
            global_clock : 1
        ''', threshold='v>1', reset='v=0', method='exact')
        G.v = G.sum = G.sr = 0
        G.global_clock = 0
        # Far-future sentinel: no spike is scheduled until an input arrives.
        G.scheduled_time = 1e9*second
        stim = SpikeGeneratorGroup(n_in, indices=list(range(n_in)), times=inputs*ms)
        # Each incoming spike bumps the received count, adds its spike_timing
        # contribution, and reschedules the output at (mean contribution + layer) ms.
        S = Synapses(stim, G, '''w:1
                      layer:1''', on_pre='''
            sr += 1
            sum += spike_timing(w, global_clock, layer, sum, sr)
            scheduled_time = (sum/sr + layer)*ms
        ''')
        S.connect(True)
        S.w = W[:, j]
        S.layer = layer_idx
        # v becomes 1.2 (above the v>1 threshold) while t is within 0.0005 ms
        # of scheduled_time; global_clock is the unitless clock spike_timing reads.
        G.run_regularly('''
            v = int(abs(t - scheduled_time) < 0.0005*ms) * 1.2
            global_clock += 0.001
        ''', dt=0.001*ms)
        mon = SpikeMonitor(G)
        run(5*ms)
        ts = mon.spike_trains()[0]
        # First spike time in ms, or the run length (5.0) when there was none.
        t0 = float(ts[0]/ms) if len(ts)>0 else 5.0
        out_times.append(t0)
    return np.array(out_times)

# ----------------------------------------------------------------------------
# Training with mini-batch updates for 4-10-3 network

def train_snn_backprop(
    X, Y,
    W1_init, W2_init,
    batch_size=8,
    epochs=10,
    lr=0.1,
    max_grad=20.0,
    w_min=-20.0,
    w_max=20.0
):
    """Train a two-layer spike-timing network with shuffled mini-batch updates.

    Each epoch shuffles the sample order, accumulates per-sample gradients over
    every mini-batch, averages and clips them, and applies one clipped weight
    step per batch.

    Parameters
    ----------
    X, Y : sequence of array_like
        Input spike times and target output spike times per sample.
    W1_init : array_like, shape (n_in, n_hidden)
    W2_init : array_like, shape (n_hidden, n_out)
        Initial weights. Copied and converted to float; with integer input the
        accumulators would be integer and the in-place averaging would fail.
    batch_size : int
        Number of samples per update.
    epochs : int
        Number of passes over the data.
    lr : float
        Learning rate.
    max_grad : float
        Averaged gradients are clipped element-wise to ``[-max_grad, max_grad]``.
    w_min, w_max : float
        Weights are clipped element-wise to ``[w_min, w_max]`` after each step.

    Returns
    -------
    (numpy.ndarray, numpy.ndarray)
        The trained ``W1`` and ``W2``.
    """
    W1 = np.array(W1_init, dtype=float)
    W2 = np.array(W2_init, dtype=float)
    n_samples = len(X)
    layer1_idx, layer2_idx = 1, 2
    n_in, n_hidden = W1.shape
    n_out = W2.shape[1]

    for ep in range(epochs):
        print(f"Epoch {ep+1}/{epochs}")
        # shuffle indices
        idxs = np.random.permutation(n_samples)
        for start in range(0, n_samples, batch_size):
            batch_idxs = idxs[start:start+batch_size]
            # accumulate gradients
            acc_dW1 = np.zeros_like(W1)
            acc_dW2 = np.zeros_like(W2)
            for i in batch_idxs:
                xi, yi = X[i], Y[i]
                h_times = layer_forward(xi, W1, layer1_idx)
                o_times = layer_forward(h_times, W2, layer2_idx)
                delta_o = (o_times - yi)

                # d(spike_timing)/dw for every W2 entry, computed once and
                # reused for both the W2 gradient and the hidden deltas.
                ds_dw2 = np.zeros_like(W2)
                for k in range(n_hidden):
                    for j in range(n_out):
                        ds_dw2[k, j] = d_spike_timing_dw(
                            W2[k, j], h_times[k], layer2_idx, 0, 1)

                # grad W2
                acc_dW2 += delta_o[np.newaxis, :] * ds_dw2

                # hidden deltas: sum over outputs of delta_o * W2 * ds/dw.
                # NOTE(review): this propagates through the derivative w.r.t.
                # the weight, not w.r.t. the hidden spike time — confirm that
                # this surrogate is intended.
                delta_h = (delta_o[np.newaxis, :] * W2 * ds_dw2).sum(axis=1)

                # grad W1
                for a in range(n_in):
                    for b in range(n_hidden):
                        acc_dW1[a, b] += delta_h[b] * d_spike_timing_dw(
                            W1[a, b], xi[a], layer1_idx, 0, 1)
            # average and clip
            acc_dW1 /= len(batch_idxs)
            acc_dW2 /= len(batch_idxs)
            acc_dW1 = np.clip(acc_dW1, -max_grad, max_grad)
            acc_dW2 = np.clip(acc_dW2, -max_grad, max_grad)
            # update
            W1 = np.clip(W1 - lr * acc_dW1, w_min, w_max)
            W2 = np.clip(W2 - lr * acc_dW2, w_min, w_max)
        print(f" After Epoch {ep+1}: ||W1||={np.linalg.norm(W1):.3f}, ||W2||={np.linalg.norm(W2):.3f}")
    return W1, W2

# ----------------------------------------------------------------------------
if __name__ == "__main__":
    # Example with mini-batch updates: 16 identical samples, batches of 8.
    X = [np.array([0.6,0.2,0.4,0.8]) for _ in range(16)]
    # Target output spike times (ms), one per output neuron.
    Y = [np.array([2.1,2.1,2.9]) for _ in range(16)]
    # Small random initial weights (unseeded, so runs differ).
    W1_0 = np.random.randn(4,10)*0.1
    W2_0 = np.random.randn(10,3)*0.1
    W1_tr, W2_tr = train_snn_backprop(
        X, Y, W1_0, W2_0,
        batch_size=8, epochs=20, lr=0.05
    )
    print("Trained W1 norm:", np.linalg.norm(W1_tr))
    print("Trained W2 norm:", np.linalg.norm(W2_tr))


Epoch 1/20
 After Epoch 1: ||W1||=0.665, ||W2||=0.436
Epoch 2/20
 After Epoch 2: ||W1||=0.664, ||W2||=0.431
Epoch 3/20
 After Epoch 3: ||W1||=0.665, ||W2||=0.441
Epoch 4/20
 After Epoch 4: ||W1||=0.665, ||W2||=0.464
Epoch 5/20
 After Epoch 5: ||W1||=0.666, ||W2||=0.500
Epoch 6/20
 After Epoch 6: ||W1||=0.668, ||W2||=0.544
Epoch 7/20
 After Epoch 7: ||W1||=0.669, ||W2||=0.595
Epoch 8/20
 After Epoch 8: ||W1||=0.671, ||W2||=0.650
Epoch 9/20
 After Epoch 9: ||W1||=0.674, ||W2||=0.710
Epoch 10/20
 After Epoch 10: ||W1||=0.677, ||W2||=0.772
Epoch 11/20
 After Epoch 11: ||W1||=0.680, ||W2||=0.835
Epoch 12/20
 After Epoch 12: ||W1||=0.684, ||W2||=0.901
Epoch 13/20
 After Epoch 13: ||W1||=0.689, ||W2||=0.967
Epoch 14/20
 After Epoch 14: ||W1||=0.694, ||W2||=1.034
Epoch 15/20
 After Epoch 15: ||W1||=0.700, ||W2||=1.101
Epoch 16/20
 After Epoch 16: ||W1||=0.706, ||W2||=1.169
Epoch 17/20
 After Epoch 17: ||W1||=0.714, ||W2||=1.237
Epoch 18/20
 After Epoch 18: ||W1||=0.722, ||W2||=1.305
Epoch 19/2