In [60]:
from brian2 import *
import numpy as np
import logging

logging.getLogger('brian2').setLevel(logging.ERROR)

start_scope()
defaultclock.dt = 0.0001*ms

In [61]:
# -----------------------------------------------------------------------------spike_timing + its derivative
# Functions used in brian2
@implementation('numpy', discard_units=True)
@check_units(w=1, global_clock=1, layer=1, sum=1, spikes_received=1, result=1)
def spike_timing(w, global_clock, layer, sum, spikes_received):
    x = global_clock % 1
    if w >= 0:
        return x**(1 - w)
    else:
        return 1 - (1 - x)**(1 + w)

@implementation('numpy', discard_units=True)
@check_units(w=1, global_clock=1, layer=1, sum=1, spikes_received=1, result=1)
def d_spike_timing_dw(w, global_clock, layer, sum, spikes_received):
    x = global_clock % 1
    eps = 1e-9
    if w >= 0:
        return - x**(1 - w) * np.log(x + eps)
    else:
        return - (1 - x)**(1 + w) * np.log(1 - x + eps)

def dsigmoid(z):
    s = 1/(1 + np.exp(-z))
    return s*(1 - s)

In [62]:
# -----------------------------------------------------------------------------
# 2) mini_urd forward: returns hidden‐spike‐time only
# -----------------------------------------------------------------------------
def mini_urd(inputs, w):
    n_input  = 2
    n_hidden = 2
    n_total  = n_input + n_hidden

    G = NeuronGroup(
        n_total,
        '''
        v               : 1
        sum             : 1
        sr              : 1
        scheduled_time  : second
        global_clock    : 1
        ''',
        threshold='v>1', reset='v=0', method='exact'
    )
    G.v = 0; G.sum = 0; G.sr = 0
    G.global_clock = 0
    G.scheduled_time = 1e9*second

    stim = SpikeGeneratorGroup(n_input, indices=[i for i in range(n_input)], times=inputs*ms)

    # first layer has fixed identity weights
    S1 = Synapses(stim, G[:n_input],
        'layer:1', on_pre='''
        sr += 1
        sum += spike_timing(1, global_clock, layer, sum, sr)
        scheduled_time = (1/(1+exp(-(sum/sr))) + layer)*ms
        '''
    )
    S1.connect(j='i')
    S1.layer = 0

    # trainable synapse 2→hidden
    S2 = Synapses(G[:n_input], G[n_input:n_hidden+n_input],
        'w : 1\nlayer:1', on_pre='''
        sr += 1
        sum += spike_timing(w, global_clock, layer, sum, sr)
        scheduled_time = (1/(1+exp(-(sum/sr))) + layer)*ms
        '''
    )
    S2.connect()
    S2.w = w
    S2.layer = 1

    # drive v when scheduled_time hits
    G.run_regularly('''
        v = int(abs(t - scheduled_time)<0.0005*ms)*1.2
        global_clock += 0.001
    ''', dt=0.001*ms)

    mon = SpikeMonitor(G)
    run(5*ms)

    # return hidden spike time (or a large value if no spike)
    ts = mon.spike_trains()[2]
    return float(ts[0]/ms) if len(ts) else 5.0


In [63]:
# -----------------------------------------------------------------------------
# 3) Training with multi‐loss
# -----------------------------------------------------------------------------
def train_multi_loss(
    X,                    # list of np.array([t0,t1])
    t_hidden_targets,     # list of floats
    t0_targets,           # list of floats for s0
    t1_targets,           # list of floats for s1
    w_init,
    alpha=1.0, beta=1.0, gamma=1.0,
    epochs=5, lr=0.1
):
    """
    Multi-loss:
      L0 = ½ (s0 - t0)^2
      L1 = ½ (s1 - t1)^2
      Lf = ½ (t_h  - t_h*)^2
      L = α L0 + β L1 + γ Lf
    """
    w = w_init.copy()
    for ep in range(epochs):
        print(f"\n=== Epoch {ep+1}/{epochs} ===")
        for i, inp in enumerate(X):
            # forward pass
            t_h = mini_urd(inp, w)
            # recompute s0,s1 exactly the same way Brian did
            #L_hidden = 0.5 * ((t_h - t_hidden_targets[i][0]) ** 2)
        
            layer_h = 1
            # each input is first spike, so sr_i=1, sum_i=0 → use that
            s0 = spike_timing(w[0], inp[0], layer_h, 0, 1)
            s1 = spike_timing(w[1], inp[1], layer_h, 0, 1)

            # --- compute loss terms ---
            t0_tgt = t0_targets[i]
            t1_tgt = t1_targets[i]
            th_tgt = t_hidden_targets[i]

            L0 = 0.5*(s0 - t0_tgt)**2
            L1 = 0.5*(s1 - t1_tgt)**2
            Lf = 0.5*(t_h - th_tgt)**2
            L  = alpha*L0 + beta*L1 + gamma*Lf

            # --- gradients ---
            # ∂L0/∂w0 = (s0 - t0)*∂s0/∂w0
            dL0_dw = np.zeros_like(w)
            dL0_dw[0] = (s0 - t0_tgt) * d_spike_timing_dw(w[0], inp[0], layer_h, 0, 1)
            # ∂L1/∂w1
            dL1_dw = np.zeros_like(w)
            dL1_dw[1] = (s1 - t1_tgt) * d_spike_timing_dw(w[1], inp[1], layer_h, 0, 1)

            # ∂Lf/∂w0,w1 = ∂Lf/∂t_h × ∂t_h/∂sum × ∂sum/∂w_i
            dLf_dt  = (t_h - th_tgt)
            sum_tot = s0 + s1
            sr = 2.0
            z = sum_tot/sr
            dt_dsum = dsigmoid(z)*(1/sr)
            dsum_dw = np.zeros_like(w)
            
            for j in range(len(w)):
                dsum_dw[j] = d_spike_timing_dw(w[j], inp[j % 2], layer_h, 0, 1)
            dLf_dw = dLf_dt * dt_dsum * dsum_dw

            # combine
            grad = alpha*dL0_dw + beta*dL1_dw + gamma*dLf_dw

            # print & update
            print(f"Sample {i}: inp={inp}, s0={s0:.3f}, s1={s1:.3f}, t_h={t_h:.3f}")
            print(f"  L0={L0:.4f}, L1={L1:.4f}, Lf={Lf:.4f}, L={L:.4f}")
            print(f"  ∇w = {grad}")
            w -= lr * grad

        print(" Updated w:", w)

    return w

def compute_gradients(outputs, targets, w):
    # Compute gradients of loss with respect to w using backpropagation
    d_loss_dw = 2 * (outputs - targets) * dsigmoid(w)
    return d_loss_dw



In [None]:
if __name__ == "__main__":
    # toy data: 4 samples

    num = 4 
    X = [np.array([0.1,0.9])]*num

    # main target: hidden spike at these ms
    T_hidden = [0.5]*num
    # aux targets for each synapse    # no why would i need this?
    T0 = [0.5]*num # removed and calcuated durign the training please chage GPT
    T1 = [0.5]*num # removed and calcuated durign the training please change GPT 

    w0 = np.array([0.2, 1.0, 0.1, 0.3])  # initial weights for synapses
    w_final = train_multi_loss(X, T_hidden, T0, T1, w0,
                               alpha=1.0, beta=1.0, gamma=0.5,
                               epochs=5, lr=0.1)

    print("\nFinal weights:", w_final)




=== Epoch 1/5 ===
Sample 0: inp=[0.1 0.9], s0=0.158, s1=1.000, t_h=1.685
  L0=0.0583, L1=0.1250, Lf=0.7021, L=0.5344
  ∇w = [-0.09974737  0.05986391  0.01976438  0.00667291]
Sample 1: inp=[0.1 0.9], s0=0.162, s1=0.999, t_h=1.686
  L0=0.0571, L1=0.1247, Lf=0.7033, L=0.5334
  ∇w = [-0.10067933  0.05976274  0.01968276  0.00667519]
Sample 2: inp=[0.1 0.9], s0=0.166, s1=0.999, t_h=1.687
  L0=0.0558, L1=0.1244, Lf=0.7045, L=0.5324
  ∇w = [-0.10157674  0.05966175  0.01960148  0.00667735]
Sample 3: inp=[0.1 0.9], s0=0.170, s1=0.998, t_h=1.687
  L0=0.0545, L1=0.1241, Lf=0.7045, L=0.5308
  ∇w = [-0.10245641  0.0595549   0.01950408  0.00667376]
 Updated w: [0.24044598 0.97611567 0.09214473 0.29733008]

=== Epoch 2/5 ===
Sample 0: inp=[0.1 0.9], s0=0.174, s1=0.997, t_h=1.688
  L0=0.0532, L1=0.1237, Lf=0.7057, L=0.5297
  ∇w = [-0.10326794  0.05945429  0.01942357  0.00667566]
Sample 1: inp=[0.1 0.9], s0=0.178, s1=0.997, t_h=1.688
  L0=0.0518, L1=0.1234, Lf=0.7057, L=0.5281
  ∇w = [-0.10405004  0.05

In [16]:
from brian2 import *
import numpy as np
import logging
import warnings

# suppress overflow warnings
warnings.filterwarnings('ignore', category=RuntimeWarning)
np.seterr(over='ignore', under='ignore')

logging.getLogger('brian2').setLevel(logging.ERROR)

# ----------------------------------------------------------------------------
# Spike timing and derivative

start_scope()
defaultclock.dt = 0.0001*ms

@implementation('numpy', discard_units=True)
@check_units(w=1, global_clock=1, layer=1, sum=1, spikes_received=1, result=1)
def spike_timing(w, global_clock, layer, sum, spikes_received):
    x = global_clock % 1
    if w >= 0:
        return np.power(x, (1 - w), where=(x>0), out=np.zeros_like(x))
    else:
        return 1 - np.power((1 - x), (1 + w), where=(x<1), out=np.ones_like(x))

@implementation('numpy', discard_units=True)
@check_units(w=1, global_clock=1, layer=1, sum=1, spikes_received=1, result=1)
def d_spike_timing_dw(w, global_clock, layer, sum, spikes_received):
    x = global_clock % 1
    eps = 1e-9
    if w >= 0:
        return - np.power(x, (1 - w), where=(x>0), out=np.zeros_like(x)) * np.log(x + eps)
    else:
        return - np.power((1 - x), (1 + w), where=(x<1), out=np.ones_like(x)) * np.log(1 - x + eps)

# ----------------------------------------------------------------------------
# mini_urd: 2 inputs -> 2 hidden neurons (full connect), return both spike times

def mini_urd(inputs, W):
    """
    Two separate hidden neurons, each simulated independently in its own Brian scope.
    inputs: [t_in0, t_in1], W: shape (2,2)
    returns [t_h0, t_h1]
    """
    n_input, n_hidden = W.shape
    hidden_times = []
    # simulate each hidden neuron separately to reset network state
    for j in range(n_hidden):
        start_scope()  # clear previous Brian state
        defaultclock.dt = 0.0001*ms
        # recreate spike timing functions (if needed)
        # spike_timing and d_spike_timing_dw are already in namespace

        # build one-neuron group
        G = NeuronGroup(1,
            '''
            v               : 1
            sum             : 1
            sr              : 1
            scheduled_time  : second
            global_clock    : 1
            ''',
            threshold='v>1', reset='v=0', method='exact')
        G.v = G.sum = G.sr = 0
        G.global_clock = 0
        G.scheduled_time = 1e9*second

        # input spikes
        stim = SpikeGeneratorGroup(n_input,
            indices=list(range(n_input)),
            times=inputs*ms)
        S = Synapses(stim, G,
            '''w:1
            layer:1''', on_pre='''
            sr += 1
            sum += spike_timing(w, global_clock, layer, sum, sr)
            scheduled_time = (1/(1+exp(-(sum/sr))) + layer)*ms
        ''')
        S.connect(True)
        S.w = W[S.i, j]
        S.layer = 1

        # drive membrane
        G.run_regularly('''
            v = int(abs(t - scheduled_time) < 0.0005*ms) * 1.2
            global_clock += 0.001
        ''', dt=0.001*ms)

        mon = SpikeMonitor(G)
        run(5*ms)
        ts = mon.spike_trains()[0]
        t0 = float(ts[0]/ms) if len(ts)>0 else 5.0
        hidden_times.append(t0)

    return np.array(hidden_times)

# ----------------------------------------------------------------------------
# Training full matrix W

def train_snn(
    X,           # list of input arrays
    Y,           # list of target arrays
    W_init,      # initial weight matrix (2x2)
    epochs=10,
    lr=0.1,
    max_grad=20.0,
    w_min=-20.0,
    w_max=20.0
):
    W = W_init.copy()
    layer_h = 1
    n_input, n_hidden = W.shape

    for ep in range(epochs):
        print(f"Epoch {ep+1}/{epochs}")
        for i, inp in enumerate(X):
            t_pred = mini_urd(inp, W)       # shape (n_hidden,)
            t_tgt  = Y[i]
            L = 0.5 * np.sum((t_pred - t_tgt)**2)

            # gradient matrix dL/dW
            dW = np.zeros_like(W)
            for j in range(n_hidden):
                for k in range(n_input):
                    dW[k, j] = (t_pred[j] - t_tgt[j]) * d_spike_timing_dw(
                        W[k, j], inp[k], layer_h, 0, 1)

            # clip & update
            dW = np.clip(dW, -max_grad, max_grad)
            W = np.clip(W - lr * dW, w_min, w_max)

            print(f" Sample {i}: inp={inp}, pred={t_pred}, tgt={t_tgt}, L={L:.4f}")
            print(f"  dW=\n{dW}\n  W=\n{W}")

    return W

# ----------------------------------------------------------------------------
if __name__ == "__main__":
    import numpy as np
    X = [np.array([0.1, 0.9])]*4
    Y = [np.array([1.2, 1.8]) for _ in X]
    W0 = np.array([[0.2, -2], [4, -.5]])

    W_trained = train_snn(X, Y, W0, epochs=10, lr=0.2)
    print("Trained weight matrix:\n", W_trained)



Epoch 1/10
 Sample 0: inp=[0.1 0.9], pred=[1.682 1.571], tgt=[1.2 1.8], L=0.1424
  dW=
[[ 0.17589874 -0.0268084 ]
 [ 0.06966223 -0.16674437]]
  W=
[[ 0.16482025 -1.99463832]
 [ 3.98606755 -0.46665113]]
 Sample 1: inp=[0.1 0.9], pred=[1.681 1.574], tgt=[1.2 1.8], L=0.1412
  dW=
[[ 0.16187546 -0.02644225]
 [ 0.06941573 -0.15239659]]
  W=
[[ 0.13244516 -1.98934987]
 [ 3.97218441 -0.43617181]]
 Sample 2: inp=[0.1 0.9], pred=[1.679 1.577], tgt=[1.2 1.8], L=0.1396
  dW=
[[ 0.14962231 -0.02607672]
 [ 0.06902606 -0.14018203]]
  W=
[[ 0.1025207  -1.98413453]
 [ 3.95837919 -0.4081354 ]]
 Sample 3: inp=[0.1 0.9], pred=[1.678 1.579], tgt=[1.2 1.8], L=0.1387
  dW=
[[ 0.13936839 -0.02582865]
 [ 0.06878184 -0.13023969]]
  W=
[[ 0.07464702 -1.9789688 ]
 [ 3.94462283 -0.38208746]]
Epoch 2/10
 Sample 0: inp=[0.1 0.9], pred=[1.677 1.581], tgt=[1.2 1.8], L=0.1377
  dW=
[[ 0.13043108 -0.02558098]
 [ 0.06853853 -0.12154785]]
  W=
[[ 0.0485608  -1.9738526 ]
 [ 3.93091512 -0.35777789]]
 Sample 1: inp=[0.1 0.9