In [8]:
from brian2 import *
import numpy as np
import logging
import warnings



from brian2 import prefs, set_device

# Tell Brian2 to use the Cython code generator.
# NOTE(review): spike_timing / d_spike_timing_dw below register only a 'numpy'
# implementation -- confirm they are actually usable with the 'cython' target.
prefs.codegen.target = 'cython'

# Optionally compile but keep the Python interface:
# 'runtime' is the default device; compiles operations to .so but stays in
# the Python process.
set_device('runtime')

# Suppress overflow/underflow noise from the power/log expressions used by
# the spike-timing helpers, and keep Brian2's own logger quiet.
warnings.filterwarnings('ignore', category=RuntimeWarning)
# Use the explicitly imported alias `np`; the bare name `numpy` is only
# defined if a star import (or an earlier notebook cell) happens to leak it.
np.seterr(over='ignore', under='ignore')
logging.getLogger('brian2').setLevel(logging.ERROR)

# ----------------------------------------------------------------------------
# Spike timing and derivative

start_scope()
defaultclock.dt = 0.001*ms

@implementation('numpy', discard_units=True)
@check_units(w=1, global_clock=1, layer=1, sum=1, spikes_received=1, result=1)
def spike_timing(w, global_clock, layer, sum, spikes_received):
    """Map a weight and a clock phase to a dimensionless spike-timing value.

    With ``x = global_clock % 1`` the result is computed element-wise as

    * ``x ** (1 - w)``            where ``w >= 0`` (0 where ``x == 0``)
    * ``1 - (1 - x) ** (1 + w)``  where ``w < 0``

    Scalars and arrays are both accepted; ``w`` and ``global_clock`` are
    broadcast against each other, so the function also works when it is
    handed one entry per synapse (several synapses firing in the same step).

    Args:
        w: synaptic weight(s).
        global_clock: clock value(s); only the fractional part is used.
        layer: unused, kept so the signature matches the call sites.
        sum: unused, kept so the signature matches the call sites.
        spikes_received: unused, kept so the signature matches the call sites.

    Returns:
        ndarray with the broadcast shape of ``w`` and ``global_clock``
        (0-d for scalar inputs).
    """
    w, clock = np.broadcast_arrays(np.asarray(w, dtype=float),
                                   np.asarray(global_clock, dtype=float))
    x = clock % 1
    non_negative = w >= 0

    # The `where=` masks evaluate each power only for the elements of its own
    # branch (and never 0 ** negative); masked slots keep the `out` default.
    rising = np.power(x, 1 - w, where=non_negative & (x > 0),
                      out=np.zeros(x.shape))
    decay = np.power(1 - x, 1 + w, where=~non_negative & (x < 1),
                     out=np.ones(x.shape))
    return np.where(non_negative, rising, 1 - decay)

@implementation('numpy', discard_units=True)
@check_units(w=1, global_clock=1, layer=1, sum=1, spikes_received=1, result=1)
def d_spike_timing_dw(w, global_clock, layer, sum, spikes_received):
    """Derivative of the spike-timing value with respect to the weight ``w``.

    With ``x = global_clock % 1`` and a small ``eps`` guarding the logarithm,
    the result is computed element-wise as

    * ``-x ** (1 - w) * log(x + eps)``              where ``w >= 0``
    * ``-(1 - x) ** (1 + w) * log(1 - x + eps)``    where ``w < 0``

    Scalars and arrays are both accepted; ``w`` and ``global_clock`` are
    broadcast against each other.

    Args:
        w: synaptic weight(s).
        global_clock: clock value(s); only the fractional part is used.
        layer: unused, kept so the signature matches the call sites.
        sum: unused, kept so the signature matches the call sites.
        spikes_received: unused, kept so the signature matches the call sites.

    Returns:
        ndarray with the broadcast shape of ``w`` and ``global_clock``
        (0-d for scalar inputs).
    """
    eps = 1e-9  # keeps log() finite when x == 0
    w, clock = np.broadcast_arrays(np.asarray(w, dtype=float),
                                   np.asarray(global_clock, dtype=float))
    x = clock % 1
    non_negative = w >= 0

    # Each power is evaluated only for the elements of its own branch; masked
    # slots keep the `out` default and are discarded by np.where below.
    rising = np.power(x, 1 - w, where=non_negative & (x > 0),
                      out=np.zeros(x.shape))
    decay = np.power(1 - x, 1 + w, where=~non_negative & (x < 1),
                     out=np.ones(x.shape))
    return np.where(non_negative,
                    -rising * np.log(x + eps),
                    -decay * np.log(1 - x + eps))

# ----------------------------------------------------------------------------
# Forward pass: 4->10->3 using two-stage mini_urd

def layer_forward(inputs, W, layer_idx):
    """
    Simulate one layer and return the first spike time of each output neuron.

    inputs:    spike times (ms) from the previous layer, shape (n_in,)
    W:         weight matrix, shape (n_in+1, n_out); the last row weights the
               bias spike
    layer_idx: integer layer number, stored on every synapse as `layer`
    returns:   array of output spike times (ms), shape (n_out,); a neuron
               that does not spike during the 5 ms run is reported as 5.0

    A separate single-neuron network is built and run for every output
    column of W.
    """
    # the bias is an extra input that spikes at t=0
    spike_times = np.concatenate((inputs, [0.0]))  # shape (n_in+1,)

    n_sources, n_targets = W.shape
    assert spike_times.size == n_sources

    first_spikes = []
    for col in range(n_targets):
        start_scope()
        defaultclock.dt = 0.001*ms

        # one post-synaptic neuron; it has no dynamics of its own and is
        # driven entirely by the synaptic updates and run_regularly below
        neuron = NeuronGroup(1, '''
            v : 1
            sum : 1
            sr : 1
            scheduled_time : second
            global_clock : 1
        ''', threshold='v>1', reset='v=0', method='exact')

        # initial state; scheduled_time starts far beyond the end of the run
        neuron.v = 0
        neuron.sum = 0
        neuron.sr = 0
        neuron.global_clock = 0
        neuron.scheduled_time = 1e9*second

        # every source (inputs + bias) emits exactly one spike
        source = SpikeGeneratorGroup(n_sources,
                                     indices=list(range(n_sources)),
                                     times=spike_times*ms)

        # each incoming spike adds its timing contribution and re-schedules
        # the output spike at (mean contribution + layer) ms
        syn = Synapses(source, neuron, '''w:1
            layer:1''',
            on_pre='''
            sr += 1
            sum += spike_timing(w, global_clock, layer, sum, sr)
            scheduled_time = (sum/sr + layer)*ms
        ''')
        syn.connect(True)
        syn.w = W[:, col]
        syn.layer = layer_idx

        # every step: push v over threshold when t reaches scheduled_time,
        # and advance the dimensionless clock read by spike_timing
        neuron.run_regularly('''
            v = int(abs(t - scheduled_time) < 0.001*ms) * 1.2
            global_clock += 0.001
        ''', dt=0.001*ms)

        recorder = SpikeMonitor(neuron)
        run(5*ms)

        train = recorder.spike_trains()[0]
        if len(train) > 0:
            first = float(train[0]/ms)
        else:
            first = float(5.0)  # no spike within the run
        first_spikes.append(first)

    return np.array(first_spikes)


# ----------------------------------------------------------------------------
# Training loop with backprop for 4-10-3
def train_snn_backprop(
    X, Y,                # lists of input arrays (4,) and target (3,)
    W1_init, W2_init,
    epochs=10, lr=0.1,
    max_grad=20.0, w_min=-20.0, w_max=20.0,
    non_target_time=2.0,
    λ=0.5                # non-target penalty weight
):
    """
    Train a two-layer spiking network (4→10→3 in the example below) with
    full-batch gradient descent.

    Per epoch the per-sample gradients are accumulated, averaged over the
    dataset, clipped element-wise to ±max_grad, smoothed with an exponential
    moving average (beta = 0.9) and applied with the same learning rate `lr`
    to both layers; the weights are then clamped to [w_min, w_max].

    Args:
        X: sequence of input spike-time arrays (ms).
        Y: sequence of target arrays; the target output is argmax(Y[i]) and
           is pulled towards Y[i][target]. All other outputs are pulled
           towards `non_target_time`.
        W1_init: initial input→hidden weights, bias row last. Not modified.
        W2_init: initial hidden→output weights, bias row last. Not modified.
        epochs: number of passes over the dataset.
        lr: learning rate used for both layers.
        max_grad: element-wise clipping bound for the averaged gradients.
        w_min, w_max: clamp range for the weights after every update.
        non_target_time: desired spike time (ms) of non-target outputs.
        λ: weight of the non-target term in the loss.

    Returns:
        (W1, W2): the trained weight matrices.

    Raises:
        ValueError: if X is empty or X and Y differ in length.
    """
    N = len(X)
    if N == 0:
        raise ValueError("train_snn_backprop needs at least one training sample")
    if len(Y) != N:
        raise ValueError(
            f"X and Y must have the same length (got {N} and {len(Y)})")

    # Work on copies so the caller's initial weights stay untouched
    W1 = W1_init.copy()
    W2 = W2_init.copy()

    # Momentum buffers
    beta = 0.9
    vW1 = np.zeros_like(W1)
    vW2 = np.zeros_like(W2)

    layer1_idx, layer2_idx = 1, 2

    for ep in range(epochs):
        # Accumulators for this epoch
        acc_dW1 = np.zeros_like(W1)
        acc_dW2 = np.zeros_like(W2)
        epoch_loss = 0.0

        for xi, yi in zip(X, Y):
            # — Forward pass —
            h_times = layer_forward(xi, W1, layer1_idx)
            o_times = layer_forward(h_times, W2, layer2_idx)

            # — Output errors —
            target_idx = np.argmax(yi)
            is_non_target = np.arange(len(o_times)) != target_idx
            non_target_err = o_times[is_non_target] - non_target_time

            delta_o = np.zeros_like(o_times)
            delta_o[target_idx] = o_times[target_idx] - yi[target_idx]
            delta_o[is_non_target] = λ * non_target_err

            # — Separation loss —
            L_target = 0.5 * delta_o[target_idx]**2
            L_non = 0.5 * λ * np.sum(non_target_err**2)
            epoch_loss += L_target + L_non

            # — Gradients for W2 —
            # The derivative is evaluated once per synapse with scalar
            # arguments and reused for the hidden deltas below.
            aug_h = np.concatenate((h_times, [0.0]))
            dW2 = np.zeros_like(W2)
            for k in range(W2.shape[0]):
                for j in range(W2.shape[1]):
                    dW2[k, j] = delta_o[j] * d_spike_timing_dw(
                        W2[k, j], aug_h[k], layer2_idx, 0, 1)

            # — Backprop into hidden —
            # delta_h[k] = Σ_j delta_o[j] * d_spike_timing_dw(W2[k, j], ...),
            # i.e. the row sums of dW2 without the bias row.
            # NOTE(review): this propagates with the derivative w.r.t. the
            # weight, not w.r.t. the hidden spike time — confirm intended.
            n_hidden = len(h_times)
            delta_h = np.zeros_like(h_times)
            for j in range(W2.shape[1]):
                delta_h += dW2[:n_hidden, j]

            # — Gradients for W1 —
            aug_xi = np.concatenate((xi, [0.0]))
            dW1 = np.zeros_like(W1)
            for i in range(W1.shape[0]):
                for k in range(W1.shape[1]):
                    dW1[i, k] = delta_h[k] * d_spike_timing_dw(
                        W1[i, k], aug_xi[i], layer1_idx, 0, 1)

            # — Accumulate —
            acc_dW1 += dW1
            acc_dW2 += dW2

        # — Average & clip gradients —
        g1 = np.clip(acc_dW1 / N, -max_grad, max_grad)
        g2 = np.clip(acc_dW2 / N, -max_grad, max_grad)

        # — Momentum updates —
        vW1 = beta * vW1 + (1 - beta) * g1
        vW2 = beta * vW2 + (1 - beta) * g2

        # — Apply weight updates & clamp (same rate for both layers) —
        W1 = np.clip(W1 - lr * vW1, w_min, w_max)
        W2 = np.clip(W2 - lr * vW2, w_min, w_max)

        print(f"Epoch {ep+1}/{epochs} — avg loss={epoch_loss/N:.4f}")
        print(f"             ‖W1‖={np.linalg.norm(W1):.3f}, ‖W2‖={np.linalg.norm(W2):.3f}\n")

    return W1, W2






            
            # # print changes before applying them
            # print("\nWeight changes for this sample:")
            # for i in range(W1.shape[0]):
            #     for k in range(W1.shape[1]):
            #         change = -lr * dW1[i,k]
            #         print(f"  W1[{i},{k}] change: {change:+.5f}")
            # for k in range(W2.shape[0]):
            #     for j in range(W2.shape[1]):
            #         change = -lr * dW2[k,j]
            #         print(f"  W2[{k},{j}] change: {change:+.5f}")
            


if __name__ == "__main__":
    # Example usage with fixed input/target pairs:
    # two 4-input patterns, alternated to give 8 samples.
    x0 = np.array([0.9, 0.7, 0.3, 0.4])
    x1 = np.array([0.6, 0.7, 0.8, 0.9])
    X = [x0, x1] * 4

    # Targets (the network has 3 outputs): the target class is the entry with
    # the largest value (2.95); the other entries are 2.05.
    y0 = np.array([2.95, 2.05, 2.05])
    y1 = np.array([2.05, 2.05, 2.95])
    Y = [y0, y1] * 4

    # Alternative: jittered copies of the two patterns.
    # X= []
    # Y = []
    # for _ in range(10):
    #     X.append(x0 + np.random.randn(4)*0.02);  Y.append(y0)
    #     X.append(x1 + np.random.randn(4)*0.02);  Y.append(y1)

    # Earlier hard-coded initial weights, kept for reference:
    # W1_0 = np.array([[0.21958991, 0.16223261, 0.02545166, 0.18849804, 0.09521701, 0.22744421, 0.05556097, 0.33130229, 0.03974721, 0.1968464],
    #             [0.41958955, 0.47541312, 0.22287581, 0.69627866, 0.83639384, 0.79597959, 0.15029805, 0.126486, 0.18285382, 0.07470098],
    #             [0.69559509, 0.41228614, 0.06028855, 0.51098037, 0.33730611, 1.17605488, 0.15405119, 0.28079173, 0.17365651, 0.23041775],
    #             [0.79721356, 0.82210554, 0.15028745, 1.09421856, 0.68280376, 1.07577422, 0.16962136, 0.23838796, 0.0735181, 0.1719861],
    #             [0.0631449, 0.10618091, 0.05791614, 0.0260418, -0.01797577, -0.1209534, 0.18702474, -0.01662061, -0.0683026, 0.05468931]])
    # (random alternative: np.random.randn(4+1, 10)*0.1   # +1 for bias)
    #
    # W2_0 = np.array([[ 0.38843549, -1.10085101,  0.38776897],
    #             [ 0.35841177, -1.06985881,  0.40542767],
    #             [ 0.40732835, -0.63317637,  0.59202003],
    #             [ 0.32589618, -0.9691458,   0.48481597],
    #             [ 0.26060079, -0.85267046,  0.75057083],
    #             [ 0.31287437, -1.34806606,  0.44407244],
    #             [ 0.1996638,  -0.65507816,  0.27473825],
    #             [ 0.40169137, -0.9979045,   0.09884702],
    #             [ 0.37579471, -0.62826323,  0.58035222],
    #             [ 0.2191151,  -0.84980247,  0.14767976],
    #             [ 0.11199583, -0.16498428,  0.00871235]])

    # Initial weights and biases are loaded from disk.
    w1 = np.load("w1.npy")
    b1 = np.load("b1.npy").reshape(1, -1)
    w2 = np.load("w2.npy")
    b2 = np.load("b2.npy").reshape(1, -1)

    # Append each bias as the last row of its weight matrix.
    W1_0 = np.vstack([w1, b1])
    W2_0 = np.vstack([w2, b2])

    # NOTE: the final spike times are printed below; do not re-run the same
    # experiment expecting a different outcome.

    # train
    W1_tr, W2_tr = train_snn_backprop(X, Y, W1_0, W2_0,
                                      epochs=10, lr=0.1)
    print("Trained W1:", W1_tr)
    print("Trained W2:", W2_tr) 
    print("Hidden times for x0:", layer_forward(x0, W1_tr, 1))
    print("Hidden times for x1:", layer_forward(x1, W1_tr, 1))

    # ── Now test on the training samples (the two patterns, repeated) ──
    print("\n=== Test predictions ===")
    for xi, yi in zip(X, Y):
        # the layer index is passed positionally
        h_times = layer_forward(xi, W1_tr, 1)
        o_times = layer_forward(h_times, W2_tr, 2)

        # argmax on both sides: the targets mark the true class with the
        # largest value, so the prediction is the output that spikes last
        pred_class = np.argmax(o_times)
        true_class = np.argmax(yi)

        print(f"Input: {xi}")
        print(f" Spike times: {o_times}")
        print(f" Predicted class: {pred_class}, True class: {true_class}\n")




Epoch 1/10 — avg loss=2.2882
             ‖W1‖=9.166, ‖W2‖=9.157

Epoch 2/10 — avg loss=2.2877
             ‖W1‖=9.342, ‖W2‖=9.557

Epoch 3/10 — avg loss=1.2062
             ‖W1‖=9.647, ‖W2‖=9.949

Epoch 4/10 — avg loss=1.1838
             ‖W1‖=10.164, ‖W2‖=10.250

Epoch 5/10 — avg loss=0.1218
             ‖W1‖=10.706, ‖W2‖=10.539

Epoch 6/10 — avg loss=1.1477
             ‖W1‖=11.334, ‖W2‖=10.897

Epoch 7/10 — avg loss=0.1275
             ‖W1‖=11.967, ‖W2‖=11.237

Epoch 8/10 — avg loss=1.1472
             ‖W1‖=12.690, ‖W2‖=11.640

Epoch 9/10 — avg loss=0.1362
             ‖W1‖=13.393, ‖W2‖=12.016

Epoch 10/10 — avg loss=0.1400
             ‖W1‖=14.060, ‖W2‖=12.366

Trained W1: [[-5.47237496 -0.30621458 -1.39157777 -1.3210122   2.02990935  0.13222637
  -5.66102315  1.87108098  0.97082094 -0.32550164]
 [-3.03721634 -0.25017633 -0.288491   -0.34870021 -0.86124509 -0.28085861
  -4.13226868  0.19064628  0.40280951  0.20816894]
 [-0.29862368  0.11241383  0.78464379  1.01685423 -5.64681193 -