In [None]:
from brian2 import *
import numpy as np
import logging
import warnings



from brian2 import prefs, set_device

# Tell Brian2 to use the Cython code generator:
prefs.codegen.target = 'cython'

# Optionally compile but keep Python interface:
set_device('runtime')  # default; compiles operations to .so but stays in Python process


# Silence numerical noise. NOTE: the filter below hides *every* RuntimeWarning
# (not only overflow), and Brian2 log records below ERROR are dropped as well.
warnings.filterwarnings('ignore', category=RuntimeWarning)
# Use the explicitly imported `np` alias: no explicit import in this cell binds
# a bare `numpy` name, so the old spelling only worked if the star import leaked it.
np.seterr(over='ignore', under='ignore')
logging.getLogger('brian2').setLevel(logging.ERROR)

# ----------------------------------------------------------------------------
# Spike timing and derivative

# Module-level defaults; layer_forward resets both again before every simulation.
start_scope()
defaultclock.dt = 0.001*ms

@implementation('numpy', discard_units=True)
@check_units(w=1, global_clock=1, layer=1, sum=1, spikes_received=1, result=1)
def spike_timing(w, global_clock, layer, sum, spikes_received):
    """
    Weight-shaped response to a spike arriving at clock value `global_clock`.

    With x = global_clock % 1:
      * w >= 0  ->  x ** (1 - w)
      * w <  0  ->  1 - (1 - x) ** (1 + w)
    The result is clipped to [0, 1].

    `layer`, `sum` and `spikes_received` are accepted but not used in the body.
    """
    phase = global_clock % 1
    if w >= 0:
        # `where` skips phase == 0, so the pre-filled 0 is returned there.
        curve = np.power(phase, 1 - w, where=(phase > 0), out=np.zeros_like(phase))
        return np.clip(curve, 0.0, 1.0)
    # `where` skips phase >= 1, leaving the pre-filled 1 (response 0) there.
    remainder = np.power(1 - phase, 1 + w, where=(phase < 1), out=np.ones_like(phase))
    return np.clip(1 - remainder, 0.0, 1.0)


@implementation('numpy', discard_units=True)
@check_units(w=1, global_clock=1, layer=1, sum=1, spikes_received=1, result=1)
def d_spike_timing_dw(w, global_clock, layer, sum, spikes_received):
    """
    Derivative with respect to `w` of the unclipped spike-timing curve.

    With x = global_clock % 1:
      * w >= 0  ->  -x ** (1 - w) * log(x + 1e-9)
      * w <  0  ->  -(1 - x) ** (1 + w) * log(1 - x + 1e-9)

    `layer`, `sum` and `spikes_received` are accepted but not used in the body.
    NOTE(review): the spike_timing defined alongside this function clips its
    output to [0, 1]; this derivative does not zero out in the clipped region.
    Confirm that is intended.
    """
    phase = global_clock % 1
    tiny = 1e-9  # keeps log() finite at the interval ends
    if w >= 0:
        magnitude = np.power(phase, 1 - w, where=(phase > 0), out=np.zeros_like(phase))
        log_base = np.log(phase + tiny)
    else:
        magnitude = np.power(1 - phase, 1 + w, where=(phase < 1), out=np.ones_like(phase))
        log_base = np.log(1 - phase + tiny)
    return -magnitude * log_base

# ----------------------------------------------------------------------------
# Forward pass: 4->10->3 using two-stage mini_urd

def layer_forward(inputs, W, layer_idx):
    """
    Simulate one layer and return each output neuron's scheduled spike time.

    inputs: array of spike times (ms) from previous layer (shape: n_in,)
    W: weight matrix shape (n_in+1, n_out)  ← note the extra bias row
    layer_idx: integer layer number
    returns: array (shape: n_out,) of scheduled output spike times (ms)

    Every output neuron is simulated on its own for 5 ms with dt = 0.001 ms.
    On each presynaptic spike the neuron increments `sr`, adds
    spike_timing(...) to `sum` and sets scheduled_time = (sum/sr + layer) ms.
    The value returned is the final `scheduled_time`, not a recorded spike,
    so no SpikeMonitor is attached.
    """
    # 1) augment inputs with bias spike @ t=0
    bias_time = 0.0
    aug_inputs = np.concatenate((inputs, [bias_time]))  # shape (n_in+1,)

    n_in_plus_bias, n_out = W.shape
    assert aug_inputs.size == n_in_plus_bias, (
        f"W has {n_in_plus_bias} rows but inputs+bias has {aug_inputs.size} entries")

    out_times = []
    for j in range(n_out):
        # fresh magic-network scope so objects from earlier iterations are ignored
        start_scope()
        defaultclock.dt = 0.001*ms

        # single post‐synaptic neuron
        G = NeuronGroup(1, '''
            v : 1
            sum : 1
            sr : 1
            scheduled_time : second
            global_clock : 1
        ''', threshold='v>1', reset='v=0', method='exact')

        # init; scheduled_time starts far beyond the simulated window
        G.v = G.sum = G.sr = 0
        G.global_clock = 0
        G.scheduled_time = 1e9*second

        # stim: one spike per input, including the bias spike at t=0
        stim = SpikeGeneratorGroup(n_in_plus_bias,
                                   indices=list(range(n_in_plus_bias)),
                                   times=aug_inputs*ms)

        S = Synapses(stim, G, '''w:1
            layer:1''',
            on_pre='''
            sr += 1
            sum += spike_timing(w, global_clock, layer, sum, sr)
            scheduled_time = (sum/sr + layer)*ms
        ''')
        S.connect(True)
        S.w = W[:, j]
        S.layer = layer_idx

        # every step: push v over threshold when t reaches scheduled_time and
        # advance global_clock by 0.001 per 0.001 ms step
        G.run_regularly('''
            v = int(abs(t - scheduled_time) < 0.001*ms) * 1.2
            global_clock += 0.001
        ''', dt=0.001*ms)

        run(5*ms)

        out_times.append(float(G.scheduled_time[0] / ms))

    return np.array(out_times)


# ----------------------------------------------------------------------------
# Training loop with backprop for 4-10-3
def train_snn_backprop(
    X, Y,                # lists of input arrays (4,) and target (3,)
    W1_init, W2_init,
    epochs=10, lr=0.1,
    max_grad=20.0, w_min=-20.0, w_max=20.0,
    non_target_time=2.05,
    λ=0.5                # non-target penalty weight
):
    """
    Train the two-layer spiking network with full-batch gradient descent.

    Each epoch runs every (xi, yi) pair through layer_forward twice (layer
    index 1 for the hidden layer, 2 for the output layer), averages the
    per-sample gradients, clips them element-wise to ±max_grad, smooths them
    with momentum (beta = 0.9), applies them with the single learning rate
    `lr`, and clamps the weights to [w_min, w_max].

    Per-sample loss, with target = argmax(yi):
        0.5 * (o[target] - yi[target])**2
      + 0.5 * λ * sum over j != target of (o[j] - non_target_time)**2

    X, Y: equally long, non-empty sequences of input / target arrays
    W1_init, W2_init: initial weights including the bias row (copied, not mutated)
    returns: (W1, W2) after the last epoch
    raises: ValueError if X is empty or X and Y differ in length
    """
    N = len(X)
    if N == 0:
        raise ValueError("X must contain at least one sample")
    if len(Y) != N:
        raise ValueError(f"X and Y must have the same length (got {N} and {len(Y)})")

    # copy over the weights given
    W1 = W1_init.copy()      # shape (5,10) including bias row
    W2 = W2_init.copy()      # shape (11,3) including bias row

    # Momentum buffers
    beta = 0.9
    vW1 = np.zeros_like(W1)
    vW2 = np.zeros_like(W2)

    layer1_idx, layer2_idx = 1, 2

    for ep in range(epochs):
        # Accumulators for this epoch's summed per-sample gradients
        acc_dW1 = np.zeros_like(W1)
        acc_dW2 = np.zeros_like(W2)
        epoch_loss = 0.0

        for xi, yi in zip(X, Y):
            # — Forward pass —
            h_times = layer_forward(xi, W1, layer1_idx)
            o_times = layer_forward(h_times, W2, layer2_idx)

            # — Separation loss —
            target_idx = np.argmax(yi)
            target_err = o_times[target_idx] - yi[target_idx]
            non_err = np.delete(o_times, target_idx) - non_target_time
            epoch_loss += 0.5 * target_err**2 + 0.5 * λ * np.sum(non_err**2)

            # — dL/d(o_times) —
            delta_o = λ * (o_times - non_target_time)
            delta_o[target_idx] = target_err

            # — Gradients for W2 —
            # d_spike_timing_dw branches on a scalar `w`, so it is evaluated
            # element by element; the matrix is computed once and reused below.
            aug_h = np.concatenate((h_times, [0.0]))  # bias spike at t=0
            dt_dw2 = np.array([
                [d_spike_timing_dw(W2[k, j], aug_h[k], layer2_idx, 0, 1)
                 for j in range(W2.shape[1])]
                for k in range(W2.shape[0])
            ], dtype=float)
            dW2 = dt_dw2 * delta_o            # dW2[k, j] = delta_o[j] * dt_dw2[k, j]

            # — Backprop into hidden & gradients for W1 —
            # NOTE(review): this propagates through d(spike_timing)/dw, whereas
            # the chain rule for dL/d(h_times) would use the derivative with
            # respect to the input time. Kept as in the original; confirm intent.
            delta_h = dt_dw2[:len(h_times)] @ delta_o   # bias row excluded

            aug_xi = np.concatenate((xi, [0.0]))
            dt_dw1 = np.array([
                [d_spike_timing_dw(W1[i, k], aug_xi[i], layer1_idx, 0, 1)
                 for k in range(W1.shape[1])]
                for i in range(W1.shape[0])
            ], dtype=float)
            dW1 = dt_dw1 * delta_h            # dW1[i, k] = delta_h[k] * dt_dw1[i, k]

            # — Accumulate —
            acc_dW1 += dW1
            acc_dW2 += dW2

        # — Average & clip gradients —
        acc_dW1 /= N
        acc_dW2 /= N
        g1 = np.clip(acc_dW1, -max_grad, max_grad)
        g2 = np.clip(acc_dW2, -max_grad, max_grad)

        # — Momentum updates —
        vW1 = beta * vW1 + (1 - beta) * g1
        vW2 = beta * vW2 + (1 - beta) * g2

        # — Apply weight updates & clamp —
        W1 = np.clip(W1 - lr * vW1, w_min, w_max)
        W2 = np.clip(W2 - lr * vW2, w_min, w_max)

        print(f"Epoch {ep+1}/{epochs} — avg loss={epoch_loss/N:.4f}")
        print(f"             ‖W1‖={np.linalg.norm(W1):.3f}, ‖W2‖={np.linalg.norm(W2):.3f}\n")

    return W1, W2

            


if __name__ == "__main__":
    # example usage with fixed input/target pairs
    # two 4-input patterns, alternating across 8 samples
    x0 = np.array([0.9, 0.7, 0.3, 0.4])
    x1 = np.array([0.6, 0.7, 0.8, 0.9])
    X = [x0 if i % 2 == 0 else x1 for i in range(8)]
    # desired output spike times (ms); the class is the output with the latest target
    y0 = np.array([2.95, 2.0, 2.0])
    y1 = np.array([2.0, 2.0, 2.95])
    Y = [y0 if i % 2 == 0 else y1 for i in range(8)]

    # W1_0 = np.random.randn(4+1, 10) * 0.1  # +1 for bias
    # W2_0 = np.random.randn(10+1, 3) * 0.1  # +1 for bias

    # NOTE: training resumes from the saved weights, and the same files are
    # overwritten at the end of this block.
    W1_0 = np.load("W1.npy")
    W2_0 = np.load("W2.npy")

    # train
    W1_tr, W2_tr = train_snn_backprop(X, Y, W1_0, W2_0,
                                      epochs=50, lr=0.1)
    print("Trained W1:", W1_tr)
    print("Trained W2:", W2_tr)
    print("Hidden times for x0:", layer_forward(x0, W1_tr, 1))
    print("Hidden times for x1:", layer_forward(x1, W1_tr, 1))

    # ── Now test on the same two patterns ──
    print("\n=== Test predictions ===")
    for xi, yi in zip(X, Y):
        h_times = layer_forward(xi, W1_tr, 1)
        o_times = layer_forward(h_times, W2_tr, 2)

        pred_class = np.argmax(o_times)
        true_class = np.argmax(yi)

        print(f"Input: {xi}")
        print(f" Spike times: {o_times}")
        print(f" Predicted class: {pred_class}, True class: {true_class}\n")

    # Saved inside the guard: W1_tr / W2_tr only exist when the block above ran.
    np.save('W1.npy', W1_tr)
    np.save('W2.npy', W2_tr)
    print("weights saved")



Epoch 1/50 — avg loss=0.1322
             ‖W1‖=1.307, ‖W2‖=3.945



Exception ignored in: <bound method InstanceTrackerSet.remove of InstanceTrackerSet({<weakref at 0x000001FFCA4274C0; to 'SynapticPathway' at 0x000001FFCA24A010>, <weakref at 0x000001FFC9E29260; dead>, <weakref at 0x000001FF9A456570; dead>, <weakref at 0x000001FFC9E4BC90; dead>, <weakref at 0x000001FFC9F9DC10; dead>, <weakref at 0x000001FFC9F870B0; dead>, <weakref at 0x000001FFC9E28A40; dead>, <weakref at 0x000001FFCA016B60; to 'MagicNetwork' at 0x000001FFCA00AAD0>, <weakref at 0x000001FFC9E4BFB0; dead>, <weakref at 0x000001FFCA3F7740; to 'SpikeGeneratorGroup' at 0x000001FF9A4AE390>, <weakref at 0x000001FFC9E28D60; dead>, <weakref at 0x000001FFCA3C4310; dead>, <weakref at 0x000001FFCA0CC9A0; to 'Clock' at 0x000001FF9A47F490>, <weakref at 0x000001FFC9F3A8E0; dead>, <weakref at 0x000001FFCA3F7010; to 'Thresholder' at 0x000001FFCA263490>, <weakref at 0x000001FFCA07E930; to 'Clock' at 0x000001FFC90164D0>, <weakref at 0x000001FFCA3F79C0; to 'Resetter' at 0x000001FFC9325C90>, <weakref at 0x00

Epoch 2/50 — avg loss=0.1321
             ‖W1‖=1.307, ‖W2‖=3.945



In [None]:
# Key fixes to your SNN code:
from brian2 import *
import numpy as np
import logging
import warnings



from brian2 import prefs, set_device

# Tell Brian2 to use the Cython code generator:
prefs.codegen.target = 'cython'

# Optionally compile but keep Python interface:
set_device('runtime')  # default; compiles operations to .so but stays in Python process


# Silence numerical noise. NOTE: the filter below hides *every* RuntimeWarning
# (not only overflow), and Brian2 log records below ERROR are dropped as well.
warnings.filterwarnings('ignore', category=RuntimeWarning)
# Use the explicitly imported `np` alias: no explicit import in this cell binds
# a bare `numpy` name, so the old spelling only worked if the star import leaked it.
np.seterr(over='ignore', under='ignore')
logging.getLogger('brian2').setLevel(logging.ERROR)

# ----------------------------------------------------------------------------
# Spike timing and derivative

# Module-level defaults; layer_forward resets both again before every simulation.
start_scope()
defaultclock.dt = 0.001*ms

# 1. Fix the derivative function to handle edge cases better
@implementation('numpy', discard_units=True)
@check_units(w=1, global_clock=1, layer=1, sum=1, spikes_received=1, result=1)
def d_spike_timing_dw(w, global_clock, layer, sum, spikes_received):
    """
    Bounded derivative of the spike-timing curve with respect to `w`.

    x = global_clock % 1 is first clipped to [1e-6, 1 - 1e-6]. Then
      * w >= 0  ->  -x ** (1 - w) * log(x)
      * w <  0  ->  -(1 - x) ** (1 + w) * log(1 - x)
    and the result is clipped to [-10, 10].

    `layer`, `sum` and `spikes_received` are accepted but not used in the body.
    """
    tiny = 1e-6
    # keep the phase strictly inside (0, 1) so neither log nor power blows up
    phase = np.clip(global_clock % 1, tiny, 1-tiny)

    # both branches share one formula; only the base and the exponent differ
    if w >= 0:
        base, exponent = phase, 1 - w
    else:
        base, exponent = 1 - phase, 1 + w

    slope = - np.power(base, exponent) * np.log(np.maximum(base, tiny))

    # bound the gradient magnitude
    return np.clip(slope, -10.0, 10.0)

# 2. Improve the spike timing function for better gradients
@implementation('numpy', discard_units=True)
@check_units(w=1, global_clock=1, layer=1, sum=1, spikes_received=1, result=1)
def spike_timing(w, global_clock, layer, sum, spikes_received):
    """
    Weight-shaped response to a spike, kept strictly inside (0, 1).

    x = global_clock % 1 is clipped to [1e-6, 1 - 1e-6]. Then
      * w >= 0  ->  x ** (1 - w)
      * w <  0  ->  1 - (1 - x) ** (1 + w)
    and the result is clipped to the same [1e-6, 1 - 1e-6] interval.

    `layer`, `sum` and `spikes_received` are accepted but not used in the body.
    """
    tiny = 1e-6
    lo, hi = tiny, 1-tiny
    phase = np.clip(global_clock % 1, lo, hi)  # avoid the saturated end points

    if w >= 0:
        return np.clip(np.power(phase, (1 - w)), lo, hi)
    return np.clip(1 - np.power((1 - phase), (1 + w)), lo, hi)


def layer_forward(inputs, W, layer_idx):
    """
    Simulate one layer and return each output neuron's scheduled spike time.

    inputs: array of spike times (ms) from previous layer (shape: n_in,)
    W: weight matrix shape (n_in+1, n_out)  ← note the extra bias row
    layer_idx: integer layer number
    returns: array (shape: n_out,) of scheduled output spike times (ms)

    Every output neuron is simulated on its own for 5 ms with dt = 0.001 ms.
    On each presynaptic spike the neuron increments `sr`, adds
    spike_timing(...) to `sum` and sets scheduled_time = (sum/sr + layer) ms.
    The value returned is the final `scheduled_time`, not a recorded spike,
    so no SpikeMonitor is attached.
    """
    # 1) augment inputs with bias spike @ t=0
    bias_time = 0.0
    aug_inputs = np.concatenate((inputs, [bias_time]))  # shape (n_in+1,)

    n_in_plus_bias, n_out = W.shape
    assert aug_inputs.size == n_in_plus_bias, (
        f"W has {n_in_plus_bias} rows but inputs+bias has {aug_inputs.size} entries")

    out_times = []
    for j in range(n_out):
        # fresh magic-network scope so objects from earlier iterations are ignored
        start_scope()
        defaultclock.dt = 0.001*ms

        # single post‐synaptic neuron
        G = NeuronGroup(1, '''
            v : 1
            sum : 1
            sr : 1
            scheduled_time : second
            global_clock : 1
        ''', threshold='v>1', reset='v=0', method='exact')

        # init; scheduled_time starts far beyond the simulated window
        G.v = G.sum = G.sr = 0
        G.global_clock = 0
        G.scheduled_time = 1e9*second

        # stim: one spike per input, including the bias spike at t=0
        stim = SpikeGeneratorGroup(n_in_plus_bias,
                                   indices=list(range(n_in_plus_bias)),
                                   times=aug_inputs*ms)

        S = Synapses(stim, G, '''w:1
            layer:1''',
            on_pre='''
            sr += 1
            sum += spike_timing(w, global_clock, layer, sum, sr)
            scheduled_time = (sum/sr + layer)*ms
        ''')
        S.connect(True)
        S.w = W[:, j]
        S.layer = layer_idx

        # every step: push v over threshold when t reaches scheduled_time and
        # advance global_clock by 0.001 per 0.001 ms step
        G.run_regularly('''
            v = int(abs(t - scheduled_time) < 0.001*ms) * 1.2
            global_clock += 0.001
        ''', dt=0.001*ms)

        run(5*ms)

        out_times.append(float(G.scheduled_time[0] / ms))

    return np.array(out_times)


# 3. Modified training function with better hyperparameters
def train_snn_backprop_fixed(
    X, Y,
    W1_init, W2_init,
    epochs=50, lr=0.05,  # Reduced learning rate
    max_grad=5.0,        # Reduced gradient clipping
    w_min=-10.0, w_max=10.0,  # Reduced weight bounds
    non_target_time=2.5,  # unused by the margin loss; kept for signature compatibility
    λ=1.0                # Increased penalty weight
):
    """
    Full-batch training with a margin-based separation loss.

    Per-sample loss, with t = argmax(yi) and a fixed margin of 0.2 ms:
        0.5 * (o[t] - yi[t])**2
      + 0.5 * λ * sum over j != t of max(0, margin - |o[j] - o[t]|)**2

    The gradient of that loss with respect to the output times is
        dL/do[j] = -λ * (margin - |d_j|) * sign(d_j)      for j != t, |d_j| < margin
        dL/do[t] = (o[t] - yi[t]) + λ * Σ_j (margin - |d_j|) * sign(d_j)
    with d_j = o[j] - o[t]; both terms are used below.

    Gradients are averaged over the batch, clipped element-wise to ±max_grad,
    smoothed with momentum (beta = 0.9) and applied with a learning rate that
    is multiplied by 0.95 after every 10th epoch. Weights are clamped to
    [w_min, w_max].

    X, Y: equally long, non-empty sequences of input / target arrays
    W1_init, W2_init: initial weights including the bias row (copied, not mutated)
    returns: (W1, W2) after the last epoch
    raises: ValueError if X is empty or X and Y differ in length
    """
    N = len(X)
    if N == 0:
        raise ValueError("X must contain at least one sample")
    if len(Y) != N:
        raise ValueError(f"X and Y must have the same length (got {N} and {len(Y)})")

    W1 = W1_init.copy()
    W2 = W2_init.copy()

    # Momentum buffers
    beta = 0.9
    vW1 = np.zeros_like(W1)
    vW2 = np.zeros_like(W2)

    # Learning-rate schedule
    lr_decay = 0.95
    current_lr = lr

    margin = 0.2  # minimum desired |o[j] - o[target]| in ms

    layer1_idx, layer2_idx = 1, 2

    for ep in range(epochs):
        acc_dW1 = np.zeros_like(W1)
        acc_dW2 = np.zeros_like(W2)
        epoch_loss = 0.0

        for xi, yi in zip(X, Y):
            # Forward pass
            h_times = layer_forward(xi, W1, layer1_idx)
            o_times = layer_forward(h_times, W2, layer2_idx)

            # Loss
            target_idx = np.argmax(yi)
            target_err = o_times[target_idx] - yi[target_idx]

            gap = o_times - o_times[target_idx]                  # d_j
            shortfall = np.maximum(0.0, margin - np.abs(gap))    # margin violation
            shortfall[target_idx] = 0.0                          # target is not its own competitor
            epoch_loss += 0.5 * target_err**2 + 0.5 * λ * np.sum(shortfall**2)

            # dL/d(o_times): competitors are pushed *away* from the target, and
            # the target receives the opposite of the summed push.
            push = λ * shortfall * np.sign(gap)
            delta_o = -push
            delta_o[target_idx] = target_err + np.sum(push)

            # Gradients for W2. d_spike_timing_dw branches on a scalar `w`, so it
            # is evaluated element by element; the matrix is reused below.
            aug_h = np.concatenate((h_times, [0.0]))  # bias spike at t=0
            dt_dw2 = np.array([
                [d_spike_timing_dw(W2[k, j], aug_h[k], layer2_idx, 0, 1)
                 for j in range(W2.shape[1])]
                for k in range(W2.shape[0])
            ], dtype=float)
            dW2 = dt_dw2 * delta_o            # dW2[k, j] = delta_o[j] * dt_dw2[k, j]

            # Backprop into hidden layer
            # NOTE(review): this propagates through d(spike_timing)/dw, whereas
            # the chain rule for dL/d(h_times) would use the derivative with
            # respect to the input time. Kept as in the original; confirm intent.
            delta_h = dt_dw2[:len(h_times)] @ delta_o   # bias row excluded

            aug_xi = np.concatenate((xi, [0.0]))
            dt_dw1 = np.array([
                [d_spike_timing_dw(W1[i, k], aug_xi[i], layer1_idx, 0, 1)
                 for k in range(W1.shape[1])]
                for i in range(W1.shape[0])
            ], dtype=float)
            dW1 = dt_dw1 * delta_h            # dW1[i, k] = delta_h[k] * dt_dw1[i, k]

            acc_dW1 += dW1
            acc_dW2 += dW2

        # Average gradients
        acc_dW1 /= N
        acc_dW2 /= N

        # Check for vanishing gradients
        if np.mean(np.abs(acc_dW1)) < 1e-8 or np.mean(np.abs(acc_dW2)) < 1e-8:
            print(f"Warning: Vanishing gradients detected at epoch {ep+1}")
            # The doubled rate persists; only the decay below ever lowers it again.
            current_lr *= 2

        # Gradient clipping
        g1 = np.clip(acc_dW1, -max_grad, max_grad)
        g2 = np.clip(acc_dW2, -max_grad, max_grad)

        # Momentum updates
        vW1 = beta * vW1 + (1 - beta) * g1
        vW2 = beta * vW2 + (1 - beta) * g2

        # Weight updates with bounds
        W1 = np.clip(W1 - current_lr * vW1, w_min, w_max)
        W2 = np.clip(W2 - current_lr * vW2, w_min, w_max)

        # Learning rate decay after every 10th completed epoch
        # (`ep % 10 == 0` also fired right after the very first epoch).
        if (ep + 1) % 10 == 0:
            current_lr *= lr_decay

        print(f"Epoch {ep+1}/{epochs} — avg loss={epoch_loss/N:.4f}")
        print(f"  LR={current_lr:.4f}, ‖∇W1‖={np.linalg.norm(acc_dW1):.4f}, ‖∇W2‖={np.linalg.norm(acc_dW2):.4f}")
        print(f"  ‖W1‖={np.linalg.norm(W1):.3f}, ‖W2‖={np.linalg.norm(W2):.3f}\n")

    return W1, W2

# 4. Better weight initialization
def initialize_weights_better(seed=None):
    """
    Draw Glorot-style normal initial weights for the 4→10→3 network.

    Each matrix is standard-normal noise scaled by sqrt(2 / (fan_in + fan_out)),
    where fan_in counts the real inputs (the extra bias row is not counted).

    seed: optional integer. When given, the draw is reproducible and leaves
          NumPy's global random state untouched. When None (the default), the
          global `np.random` state is used.
    returns: (W1_0, W2_0) with shapes (5, 10) and (11, 3), bias row included
    """
    rng = np.random if seed is None else np.random.RandomState(seed)
    W1_0 = rng.randn(4+1, 10) * np.sqrt(2.0 / (4+10))
    W2_0 = rng.randn(10+1, 3) * np.sqrt(2.0 / (10+3))
    return W1_0, W2_0

# 5. Add debugging function
def debug_gradients(W1, W2, X, Y):
    """
    Print a forward-pass diagnostic for the first sample.

    Runs X[0] through both layers and reports the hidden and output spike
    times, the variance of the output times (with a warning below 0.01) and
    the min/max of each weight matrix. Y[0] is read but not otherwise used.
    """
    print("=== Gradient Debug ===")
    first_input, _first_target = X[0], Y[0]

    hidden = layer_forward(first_input, W1, 1)
    outputs = layer_forward(hidden, W2, 2)
    print(f"Hidden times: {hidden}")
    print(f"Output times: {outputs}")

    # near-identical output times mean the classes are barely separated
    spread = np.var(outputs)
    print(f"Output variance: {spread:.6f}")
    too_similar = spread < 0.01
    if too_similar:
        print("WARNING: Outputs are too similar - poor class separation")

    # weight ranges
    w1_lo, w1_hi = W1.min(), W1.max()
    w2_lo, w2_hi = W2.min(), W2.max()
    print(f"W1 range: [{w1_lo:.3f}, {w1_hi:.3f}]")
    print(f"W2 range: [{w2_lo:.3f}, {w2_hi:.3f}]")

# Usage with better initialization:
if __name__ == "__main__":
    # Two input patterns, alternated to build an 8-sample batch
    x0 = np.array([0.9, 0.7, 0.3, 0.4])
    x1 = np.array([0.6, 0.7, 0.8, 0.9])
    X = [x0, x1] * 4

    # Matching target spike times (ms), alternated the same way
    y0 = np.array([1.5, 2.0, 2.5])
    y1 = np.array([2.5, 2.0, 1.5])
    Y = [y0, y1] * 4

    # Fresh random weights rather than weights loaded from disk
    W1_0, W2_0 = initialize_weights_better()

    # Forward-pass diagnostics on the untrained network
    debug_gradients(W1_0, W2_0, X, Y)

    # Train with the margin-loss variant
    W1_tr, W2_tr = train_snn_backprop_fixed(
        X, Y, W1_0, W2_0,
        epochs=10, lr=0.1,
    )

=== Gradient Debug ===
Hidden times: [1.5337956  1.44094346 1.43766649 1.44222776 1.44966497 1.47865081
 1.49270243 1.47443207 1.46767445 1.45221082]
Output times: [2.47580381 2.5067406  2.42906183]
Output variance: 0.001020
W1 range: [-0.979, 0.690]
W2 range: [-0.539, 0.946]
Epoch 1/50 — avg loss=0.0267
  LR=0.0950, ‖∇W1‖=1.0933, ‖∇W2‖=1.7862
  ‖W1‖=2.333, ‖W2‖=2.231



Exception ignored in: <bound method InstanceTrackerSet.remove of InstanceTrackerSet({<weakref at 0x00000184734BB790; dead>, <weakref at 0x00000184737947C0; dead>, <weakref at 0x0000018473756F70; to 'StateUpdater' at 0x000001846FD09910>, <weakref at 0x0000018473754E50; dead>, <weakref at 0x0000018472E8D530; dead>, <weakref at 0x0000018472DF9850; to 'MagicNetwork' at 0x0000018472DEBB50>, <weakref at 0x0000018472BD6020; dead>, <weakref at 0x0000018473754680; to 'SpikeGeneratorGroup' at 0x0000018472E6BBD0>, <weakref at 0x00000184737943B0; dead>, <weakref at 0x0000018472E8C2C0; to 'SpikeMonitor' at 0x0000018470F3F450>, <weakref at 0x0000018472E8F8D0; to 'Clock' at 0x000001847371F490>, <weakref at 0x0000018472E8DB70; to 'SynapticPathway' at 0x000001847357EC90>, <weakref at 0x00000184437842C0; to 'CodeRunner' at 0x0000018472E97CD0>, <weakref at 0x0000018473551940; to 'NeuronGroup' at 0x000001847242FCD0>, <weakref at 0x0000018473756390; to 'Thresholder' at 0x000001847356B510>, <weakref at 0x00

Epoch 2/50 — avg loss=0.0272
  LR=0.0950, ‖∇W1‖=1.0929, ‖∇W2‖=1.7843
  ‖W1‖=2.332, ‖W2‖=2.221



Exception ignored in: <bound method InstanceTrackerSet.remove of InstanceTrackerSet({<weakref at 0x0000018472E9D080; dead>, <weakref at 0x0000018473794900; to 'Resetter' at 0x0000018472446010>, <weakref at 0x0000018472E8E660; to 'SynapticPathway' at 0x0000018472C71850>, <weakref at 0x00000184737566B0; dead>, <weakref at 0x0000018472BD5350; to 'Clock' at 0x000001847371F050>, <weakref at 0x00000184734BBC90; to 'CodeRunner' at 0x000001847274F990>, <weakref at 0x0000018473796340; to 'SpikeGeneratorGroup' at 0x00000184437D57D0>, <weakref at 0x0000018473754E50; dead>, <weakref at 0x00000184737945E0; to 'StateUpdater' at 0x00000184724D8F90>, <weakref at 0x0000018472E8CE00; dead>, <weakref at 0x0000018473551170; dead>, <weakref at 0x00000184735522F0; dead>, <weakref at 0x0000018472DF9850; to 'MagicNetwork' at 0x0000018472DEBB50>, <weakref at 0x0000018472E8CEF0; to 'Clock' at 0x00000184736E2B50>, <weakref at 0x0000018472E9CE50; to 'Synapses' at 0x000001847356EC10>, <weakref at 0x00000184720BA93

In [None]:
from brian2 import *
import numpy as np
import logging
import warnings



from brian2 import prefs, set_device

# Tell Brian2 to use the Cython code generator:
prefs.codegen.target = 'cython'

# Optionally compile but keep Python interface:
set_device('runtime')  # default; compiles operations to .so but stays in Python process


# Silence numerical noise. NOTE: the filter below hides *every* RuntimeWarning
# (not only overflow), and Brian2 log records below ERROR are dropped as well.
warnings.filterwarnings('ignore', category=RuntimeWarning)
# Use the explicitly imported `np` alias: no explicit import in this cell binds
# a bare `numpy` name, so the old spelling only worked if the star import leaked it.
np.seterr(over='ignore', under='ignore')
logging.getLogger('brian2').setLevel(logging.ERROR)

# ----------------------------------------------------------------------------
# Spike timing and derivative

# Module-level defaults; layer_forward resets both again before every simulation.
start_scope()
defaultclock.dt = 0.001*ms

@implementation('numpy', discard_units=True)
@check_units(w=1, global_clock=1, layer=1, sum=1, spikes_received=1, result=1)
def spike_timing(w, global_clock, layer, sum, spikes_received):
    """
    Weight-shaped response to a spike arriving at clock value `global_clock`.

    With x = global_clock % 1:
      * w >= 0  ->  x ** (1 - w)            (0 where x == 0)
      * w <  0  ->  1 - (1 - x) ** (1 + w)
    The result is not clipped.

    `layer`, `sum` and `spikes_received` are accepted but not used in the body.
    """
    phase = global_clock % 1
    if w >= 0:
        rising = np.zeros_like(phase)   # value kept wherever `where` is False
        np.power(phase, 1 - w, where=(phase > 0), out=rising)
        return rising
    falling = np.ones_like(phase)       # value kept wherever `where` is False
    np.power(1 - phase, 1 + w, where=(phase < 1), out=falling)
    return 1 - falling

@implementation('numpy', discard_units=True)
@check_units(w=1, global_clock=1, layer=1, sum=1, spikes_received=1, result=1)
def d_spike_timing_dw(w, global_clock, layer, sum, spikes_received):
    """
    Derivative of the spike-timing curve with respect to `w`.

    With x = global_clock % 1:
      * w >= 0  ->  -x ** (1 - w) * log(x + 1e-9)
      * w <  0  ->  -(1 - x) ** (1 + w) * log(1 - x + 1e-9)

    `layer`, `sum` and `spikes_received` are accepted but not used in the body.
    """
    phase = global_clock % 1
    tiny = 1e-9  # keeps log() finite at the interval ends
    if w >= 0:
        magnitude = np.power(phase, 1 - w, where=(phase > 0), out=np.zeros_like(phase))
        log_base = np.log(phase + tiny)
    else:
        magnitude = np.power(1 - phase, 1 + w, where=(phase < 1), out=np.ones_like(phase))
        log_base = np.log(1 - phase + tiny)
    return -magnitude * log_base

# ----------------------------------------------------------------------------
# Forward pass: 4->10->3 using two-stage mini_urd

def layer_forward(inputs, W, layer_idx):
    """
    Simulate one layer and report when each output neuron first fired.

    inputs: spike times (ms) of the previous layer, shape (n_in,)
    W: weights, shape (n_in+1, n_out); the last row belongs to the bias spike
    layer_idx: integer layer number, added to the scheduled firing time
    returns: array (n_out,) of first recorded spike times in ms, or 5.0 for a
             neuron that did not fire during the 5 ms run
    """
    # Brian2 model code, shared by every per-neuron simulation below
    neuron_model = '''
            v : 1
            sum : 1
            sr : 1
            scheduled_time : second
            global_clock : 1
        '''
    synapse_model = '''w:1
            layer:1'''
    on_spike = '''
            sr += 1
            sum += spike_timing(w, global_clock, layer, sum, sr)
            scheduled_time = (sum/sr + layer)*ms
        '''
    tick_code = '''
            v = int(abs(t - scheduled_time) < 0.001*ms) * 1.2
            global_clock += 0.001
        '''

    # append the bias spike at t = 0
    spike_times_in = np.concatenate((inputs, [0.0]))  # shape (n_in+1,)

    n_rows, n_cols = W.shape
    assert spike_times_in.size == n_rows

    fired = []
    for col in range(n_cols):
        start_scope()
        defaultclock.dt = 0.001*ms

        # one post-synaptic neuron per simulation
        G = NeuronGroup(1, neuron_model, threshold='v>1', reset='v=0', method='exact')
        G.v = 0
        G.sum = 0
        G.sr = 0
        G.global_clock = 0
        G.scheduled_time = 1e9*second

        # one spike per input channel, bias included
        stim = SpikeGeneratorGroup(n_rows,
                                   indices=list(range(n_rows)),
                                   times=spike_times_in*ms)

        S = Synapses(stim, G, synapse_model, on_pre=on_spike)
        S.connect(True)
        S.w = W[:, col]
        S.layer = layer_idx

        G.run_regularly(tick_code, dt=0.001*ms)

        mon = SpikeMonitor(G)
        run(5*ms)

        train = mon.spike_trains()[0]
        if len(train) > 0:
            first_spike = float(train[0]/ms)
        else:
            first_spike = float(5.0)  # no spike within the simulated window
        fired.append(first_spike)

    return np.array(fired)


# ----------------------------------------------------------------------------
# Training loop with backprop for 4-10-3
def train_snn_backprop(
    X, Y,                # lists of input arrays (4,) and target (3,)
    W1_init, W2_init,
    epochs=10, lr=0.1,
    max_grad=20.0, w_min=-20.0, w_max=20.0,
    non_target_time=2.05,
    λ=0.5                # non-target penalty weight
):
    """
    Train the two-layer spiking network with full-batch gradient descent.

    Each epoch runs every (xi, yi) pair through layer_forward twice (layer
    index 1 for the hidden layer, 2 for the output layer), averages the
    per-sample gradients, clips them element-wise to ±max_grad, smooths them
    with momentum (beta = 0.9), applies them with the single learning rate
    `lr`, and clamps the weights to [w_min, w_max]. The output times and
    output deltas of every sample are printed for debugging.

    Per-sample loss, with target = argmax(yi):
        0.5 * (o[target] - yi[target])**2
      + 0.5 * λ * sum over j != target of (o[j] - non_target_time)**2

    X, Y: equally long, non-empty sequences of input / target arrays
    W1_init, W2_init: initial weights including the bias row (copied, not mutated)
    returns: (W1, W2) after the last epoch
    raises: ValueError if X is empty or X and Y differ in length
    """
    N = len(X)
    if N == 0:
        raise ValueError("X must contain at least one sample")
    if len(Y) != N:
        raise ValueError(f"X and Y must have the same length (got {N} and {len(Y)})")

    # copy over the weights given
    W1 = W1_init.copy()      # shape (5,10) including bias row
    W2 = W2_init.copy()      # shape (11,3) including bias row

    # Momentum buffers
    beta = 0.9
    vW1 = np.zeros_like(W1)
    vW2 = np.zeros_like(W2)

    layer1_idx, layer2_idx = 1, 2

    for ep in range(epochs):
        # Accumulators for this epoch's summed per-sample gradients
        acc_dW1 = np.zeros_like(W1)
        acc_dW2 = np.zeros_like(W2)
        epoch_loss = 0.0

        for xi, yi in zip(X, Y):
            # — Forward pass —
            h_times = layer_forward(xi, W1, layer1_idx)
            o_times = layer_forward(h_times, W2, layer2_idx)

            # — Separation loss —
            target_idx = np.argmax(yi)
            target_err = o_times[target_idx] - yi[target_idx]
            non_err = np.delete(o_times, target_idx) - non_target_time
            epoch_loss += 0.5 * target_err**2 + 0.5 * λ * np.sum(non_err**2)

            # — dL/d(o_times) —
            delta_o = λ * (o_times - non_target_time)
            delta_o[target_idx] = target_err

            print("otimes ", o_times)
            print("delta_o ",delta_o)

            # — Gradients for W2 —
            # d_spike_timing_dw branches on a scalar `w`, so it is evaluated
            # element by element; the matrix is computed once and reused below.
            aug_h = np.concatenate((h_times, [0.0]))  # bias spike at t=0
            dt_dw2 = np.array([
                [d_spike_timing_dw(W2[k, j], aug_h[k], layer2_idx, 0, 1)
                 for j in range(W2.shape[1])]
                for k in range(W2.shape[0])
            ], dtype=float)
            dW2 = dt_dw2 * delta_o            # dW2[k, j] = delta_o[j] * dt_dw2[k, j]

            # — Backprop into hidden & gradients for W1 —
            # NOTE(review): this propagates through d(spike_timing)/dw, whereas
            # the chain rule for dL/d(h_times) would use the derivative with
            # respect to the input time. Kept as in the original; confirm intent.
            delta_h = dt_dw2[:len(h_times)] @ delta_o   # bias row excluded

            aug_xi = np.concatenate((xi, [0.0]))
            dt_dw1 = np.array([
                [d_spike_timing_dw(W1[i, k], aug_xi[i], layer1_idx, 0, 1)
                 for k in range(W1.shape[1])]
                for i in range(W1.shape[0])
            ], dtype=float)
            dW1 = dt_dw1 * delta_h            # dW1[i, k] = delta_h[k] * dt_dw1[i, k]

            # — Accumulate —
            acc_dW1 += dW1
            acc_dW2 += dW2

        # — Average & clip gradients —
        acc_dW1 /= N
        acc_dW2 /= N
        g1 = np.clip(acc_dW1, -max_grad, max_grad)
        g2 = np.clip(acc_dW2, -max_grad, max_grad)

        # — Momentum updates —
        vW1 = beta * vW1 + (1 - beta) * g1
        vW2 = beta * vW2 + (1 - beta) * g2

        # — Apply weight updates & clamp (both layers share `lr`) —
        W1 = np.clip(W1 - lr * vW1, w_min, w_max)
        W2 = np.clip(W2 - lr * vW2, w_min, w_max)

        print(f"Epoch {ep+1}/{epochs} — avg loss={epoch_loss/N:.4f}")
        print(f"             ‖W1‖={np.linalg.norm(W1):.3f}, ‖W2‖={np.linalg.norm(W2):.3f}\n")

    return W1, W2

            


if __name__ == "__main__":
    # Two 4-input patterns, alternated to build an 8-sample batch
    x0 = np.array([0.9, 0.7, 0.3, 0.4])
    x1 = np.array([0.6, 0.7, 0.8, 0.9])
    X = [x0, x1] * 4

    # Desired output spike times (ms), alternated the same way
    y0 = np.array([2.95, 2.0, 2.0])
    y1 = np.array([2.0, 2.0, 2.95])
    Y = [y0, y1] * 4

    # W1_0 = np.random.randn(4+1, 10) * 0.1  # +1 for bias
    # W2_0 = np.random.randn(10+1, 3) * 0.1  # +1 for bias

    # Start from the weights stored on disk
    W1_0 = np.load("W1.npy")
    W2_0 = np.load("W2.npy")

    # One epoch with lr=0.0: the weights are left unchanged, so this run only
    # exercises the forward pass and the debug prints inside the training loop.
    W1_tr, W2_tr = train_snn_backprop(
        X, Y, W1_0, W2_0,
        epochs=1, lr=0.0,
    )

    print("Trained W1:", W1_tr)
    print("Trained W2:", W2_tr)
    print("Hidden times for x0:", layer_forward(x0, W1_tr, 1))
    print("Hidden times for x1:", layer_forward(x1, W1_tr, 1))

    # # ── Now test on the same two patterns ──
    # print("\n=== Test predictions ===")
    # for xi, yi in zip(X, Y):
    #     # call layer_forward(positionally) rather than with layer1_idx=
    #     h_times = layer_forward(xi, W1_tr, 1)
    #     o_times = layer_forward(h_times, W2_tr, 2)

    #     pred_class = np.argmax(o_times)  
    #     true_class = np.argmax(yi)

    #     print(f"Input: {xi}")
    #     print(f" Spike times: {o_times}")
    #     print(f" Predicted class: {pred_class}, True class: {true_class}\n")

# np.save('W1.npy', W1_tr)
# np.save('W2.npy', W2_tr)
# print("weights saved")



otimes  [2.666 2.003 2.524]
delta_o  [-0.284  -0.0235  0.237 ]
otimes  [2.726 2.003 2.619]
delta_o  [ 0.338  -0.0235 -0.331 ]
otimes  [2.666 2.003 2.524]
delta_o  [-0.284  -0.0235  0.237 ]
otimes  [2.726 2.003 2.619]
delta_o  [ 0.338  -0.0235 -0.331 ]


KeyboardInterrupt: 

In [None]:
from brian2 import *
import numpy as np
import logging
import warnings



from brian2 import prefs, set_device

# Select Brian2's Cython code-generation target.
prefs.codegen.target = 'cython'

# Stay on the runtime device (noted in the original as the default), so the
# simulation is driven from this Python process.
set_device('runtime')

# Silence numeric noise: every RuntimeWarning is ignored, NumPy no longer
# reports floating-point overflow/underflow, and Brian2 logs below ERROR are hidden.
# NOTE(review): the RuntimeWarning filter is broad and can also hide real problems.
warnings.filterwarnings('ignore', category=RuntimeWarning)
# Use the explicitly imported `np` alias: the bare name `numpy` is never
# imported in this cell and was only reachable through `from brian2 import *`.
np.seterr(over='ignore', under='ignore')
logging.getLogger('brian2').setLevel(logging.ERROR)

# ----------------------------------------------------------------------------
# Spike timing and derivative

# Begin a fresh Brian2 scope and use a 0.001 ms default clock step.
start_scope()
defaultclock.dt = 0.001*ms

@implementation('numpy', discard_units=True)
@check_units(w=1, global_clock=1, layer=1, sum=1, spikes_received=1, result=1)
def spike_timing(w, global_clock, layer, sum, spikes_received):
    """
    Map the arrival phase of an input spike to its weighted contribution.

    w: synaptic weight (dimensionless scalar; the sign selects the branch)
    global_clock: running clock value; only its fractional part
        (global_clock % 1) is used
    layer, sum, spikes_received: part of the call signature, not used here
    returns: contribution limited to the range [0, 1]
    """
    x = global_clock % 1
    if w >= 0:
        # x**(1 - w); where x == 0 the 0 pre-filled in `out` is kept
        result = np.power(x, (1 - w), where=(x>0), out=np.zeros_like(x))
    else:
        # 1 - (1 - x)**(1 + w); where x >= 1 the 1 pre-filled in `out` is kept
        result = 1 - np.power((1 - x), (1 + w), where=(x<1), out=np.ones_like(x))
    # For |w| > 1 the exponent is negative and the raw value leaves [0, 1]
    # (e.g. w=3, x=0.5 gives 4; w=-3, x=0.5 gives -3). Limit it to the unit
    # interval, as the definition of this function in the notebook's first
    # cell already does.
    return np.clip(result, 0.0, 1.0)

@implementation('numpy', discard_units=True)
@check_units(w=1, global_clock=1, layer=1, sum=1, spikes_received=1, result=1)
def d_spike_timing_dw(w, global_clock, layer, sum, spikes_received):
    """
    Derivative with respect to w of the two spike-timing branches
    x**(1 - w) (w >= 0) and 1 - (1 - x)**(1 + w) (otherwise),
    where x = global_clock % 1.

    layer, sum, spikes_received: part of the call signature, not used here
    returns: -x**(1 - w) * log(x + 1e-9) or -(1 - x)**(1 + w) * log(1 - x + 1e-9)
    """
    phase = global_clock % 1
    tiny = 1e-9  # keeps the logarithm finite at the ends of the interval
    if not (w >= 0):
        remaining = 1 - phase
        decay = np.power(remaining, (1 + w),
                         where=(phase < 1), out=np.ones_like(phase))
        return -decay * np.log(remaining + tiny)
    rise = np.power(phase, (1 - w),
                    where=(phase > 0), out=np.zeros_like(phase))
    return -rise * np.log(phase + tiny)

# ----------------------------------------------------------------------------
# Forward pass: 4->10->3 using two-stage mini_urd

def layer_forward(inputs, W, layer_idx):
    """
    Compute the output spike times of one layer by simulating each output
    neuron in its own short Brian2 run.

    inputs: array of spike times (ms) from previous layer (shape: n_in,)
    W: weight matrix shape (n_in+1, n_out); the last row holds the weights of
       the bias spike that is appended to the inputs at t=0
    layer_idx: integer layer number, added (in ms) to the scheduled output time
    returns: array of output spike times (ms), one per column of W; an output
             that records no spike during the 5 ms run is reported as 5.0
    """
    # 1) augment inputs with bias spike @ t=0
    bias_time = 0.0
    aug_inputs = np.concatenate((inputs, [bias_time]))  # shape (n_in+1,)

    n_in_plus_bias, n_out = W.shape
    # The weight matrix must have one row per input plus the bias row.
    assert aug_inputs.size == n_in_plus_bias

    out_times = []
    # 2) one independent simulation per output neuron (column j of W)
    for j in range(n_out):
        start_scope()
        defaultclock.dt = 0.001*ms

        # Single post-synaptic neuron. It has no differential equations; its
        # variables are only changed by the synaptic on_pre code and by the
        # run_regularly block below.
        G = NeuronGroup(1, '''
            v : 1
            sum : 1
            sr : 1
            scheduled_time : second
            global_clock : 1
        ''', threshold='v>1', reset='v=0', method='exact')

        # Initial state. scheduled_time starts at 1e9 s, far beyond the 5 ms
        # run, so v stays 0 until a presynaptic spike sets scheduled_time.
        G.v = G.sum = G.sr = 0
        G.global_clock = 0
        G.scheduled_time = 1e9*second

        # Stimulus: source i emits one spike at aug_inputs[i] ms (the last
        # source is the bias spike at t=0).
        stim = SpikeGeneratorGroup(n_in_plus_bias,
                                   indices=list(range(n_in_plus_bias)),
                                   times=aug_inputs*ms)

        # On every presynaptic spike: count it (sr), add its spike_timing
        # contribution to sum, and reschedule the output spike to the mean
        # contribution so far plus the layer index, in ms.
        S = Synapses(stim, G, '''w:1
            layer:1''',
            on_pre='''
            sr += 1
            sum += spike_timing(w, global_clock, layer, sum, sr)
            scheduled_time = (sum/sr + layer)*ms
        ''')
        S.connect(True)
        S.w = W[:, j]
        S.layer = layer_idx

        # Every 0.001 ms: set v to 1.2 (above the v>1 threshold) when t is
        # within 0.001 ms of scheduled_time and to 0 otherwise, then advance
        # global_clock by 0.001 so that it counts elapsed time in ms.
        G.run_regularly('''
            v = int(abs(t - scheduled_time) < 0.001*ms) * 1.2
            global_clock += 0.001
        ''', dt=0.001*ms)

        mon = SpikeMonitor(G)
        run(5*ms)

        # First recorded spike in ms, or 5.0 (the run length) if none occurred.
        ts = mon.spike_trains()[0]
        t0 = float(ts[0]/ms) if len(ts)>0 else float(5.0)
        out_times.append(t0)

    return np.array(out_times)


# ----------------------------------------------------------------------------
# Training loop with backprop for 4-10-3
def train_snn_backprop(
    X, Y,                # lists of input arrays (4,) and target (3,)
    W1_init, W2_init,
    epochs=10, lr=0.1,
    max_grad=20.0, w_min=-20.0, w_max=20.0,
    non_target_time=2.05,
    λ=0.5                # non-target penalty weight
):
    """
    Train a two-layer spiking network (4→10→3 in the example usage) on
    desired output spike times.

    One weight update is made per epoch from the gradient averaged over all
    samples, using:
      • element-wise gradient clipping to [-max_grad, max_grad]
      • an exponential moving average of the gradient (beta = 0.9) as momentum
      • clamping of the updated weights to [w_min, w_max]
    Both layers currently use the same learning rate `lr`.

    X, Y: equal-length, non-empty sequences of input spike-time arrays and
        target arrays; the target output of a sample is the argmax of its
        target array
    W1_init, W2_init: initial weight matrices including the bias row; they
        are copied, not modified
    epochs: number of passes over the data (one update per pass)
    lr: learning rate applied to both layers
    max_grad: bound for element-wise gradient clipping
    w_min, w_max: bounds the weights are clamped to after each update
    non_target_time: spike time (ms) that non-target outputs are pulled towards
    λ: weight of the non-target penalty
    returns: (W1, W2), the updated weight matrices
    raises: ValueError if X and Y differ in length or are empty
    """
    N = len(X)
    # zip() below would silently drop unmatched samples while N is still used
    # for averaging, and N == 0 would divide by zero; reject both up front.
    if N != len(Y):
        raise ValueError(f"X and Y must have the same length (got {N} and {len(Y)})")
    if N == 0:
        raise ValueError("X and Y must contain at least one sample")

    # copy over the weights given
    W1 = W1_init.copy()      # shape (5,10) including bias row
    W2 = W2_init.copy()      # shape (11,3) including bias row

    # Momentum buffers (running averages of the clipped gradients)
    beta = 0.9
    vW1 = np.zeros_like(W1)
    vW2 = np.zeros_like(W2)

    layer1_idx, layer2_idx = 1, 2

    for ep in range(epochs):
        # Accumulators - collect the per-sample gradients of this epoch
        acc_dW1 = np.zeros_like(W1)
        acc_dW2 = np.zeros_like(W2)
        epoch_loss = 0.0

        for xi, yi in zip(X, Y):
            # — Forward pass —
            h_times = layer_forward(xi, W1, layer1_idx)
            o_times = layer_forward(h_times, W2, layer2_idx)

            # — Separation loss —
            # Squared error to the desired time for the target output plus a
            # λ-weighted squared error to non_target_time for the others.
            target_idx = np.argmax(yi)
            L_target = 0.5 * (o_times[target_idx] - yi[target_idx])**2
            non_ids = [j for j in range(len(o_times)) if j != target_idx]
            L_non = 0.5 * λ * sum([(o_times[j] - non_target_time)**2 for j in non_ids])
            L = L_target + L_non
            epoch_loss += L

            # — Gradients for W2 —
            # delta_o is the derivative of the loss w.r.t. each output time.
            delta_o = np.zeros_like(o_times)
            delta_o[target_idx] = (o_times[target_idx] - yi[target_idx])
            for j in non_ids:
                delta_o[j] = λ * (o_times[j] - non_target_time)

            print("otimes ", o_times)
            print("delta_o ",delta_o)

            # Evaluate the output-layer derivative once per weight; it is
            # needed both for dW2 and for the hidden-layer error below.
            aug_h = np.concatenate((h_times, [0.0]))
            dt_dw2 = np.empty(W2.shape)
            for k in range(W2.shape[0]):
                for j in range(W2.shape[1]):
                    dt_dw2[k, j] = d_spike_timing_dw(
                        W2[k, j], aug_h[k], layer2_idx, 0, 1)

            dW2 = np.zeros_like(W2)
            dW2[...] = delta_o[np.newaxis, :] * dt_dw2

            # — Backprop into hidden & gradients for W1 —
            # NOTE(review): the hidden error reuses the derivative with
            # respect to the weight; a chain rule through the hidden spike
            # time would normally use the derivative with respect to the
            # input time. Left unchanged here - confirm the intent.
            n_hidden = len(h_times)   # excludes the bias row of W2
            delta_h = np.zeros_like(h_times)
            for j in range(W2.shape[1]):
                delta_h += delta_o[j] * dt_dw2[:n_hidden, j]

            aug_xi = np.concatenate((xi, [0.0]))
            dW1 = np.zeros_like(W1)
            for i in range(W1.shape[0]):
                for k in range(W1.shape[1]):
                    dW1[i, k] = delta_h[k] * d_spike_timing_dw(
                        W1[i, k], aug_xi[i], layer1_idx, 0, 1)

            # — Accumulate —
            acc_dW1 += dW1
            acc_dW2 += dW2

        # — Average & clip gradients —
        acc_dW1 /= N
        acc_dW2 /= N

        # Hidden layer uses the same rate as the output layer (no boost applied)
        lr1 = lr

        # Element-wise clipping, same bound for both layers
        g1 = np.clip(acc_dW1, -max_grad, max_grad)
        g2 = np.clip(acc_dW2, -max_grad, max_grad)

        # — Momentum updates —
        vW1 = beta * vW1 + (1 - beta) * g1
        vW2 = beta * vW2 + (1 - beta) * g2

        # — Apply weight updates & clamp —
        W1 = np.clip(W1 - lr1 * vW1, w_min, w_max)
        W2 = np.clip(W2 - lr  * vW2, w_min, w_max)

        print(f"Epoch {ep+1}/{epochs} — avg loss={epoch_loss/N:.4f}")
        print(f"             ‖W1‖={np.linalg.norm(W1):.3f}, ‖W2‖={np.linalg.norm(W2):.3f}\n")

    return W1, W2

            


if __name__ == "__main__":
    # Toy data set: two fixed 4-element input patterns, alternated
    # x0, x1, x0, ... to give 8 samples.
    x0 = np.array([0.9, 0.7, 0.3, 0.4])
    x1 = np.array([0.6, 0.7, 0.8, 0.9])
    X = [x0, x1] * 4

    # Matching 3-element targets: the target output of each pattern gets the
    # largest desired spike time (2.95), the other outputs get 2.0.
    y0 = np.array([2.95, 2.0, 2.0])
    y1 = np.array([2.0, 2.0, 2.95])
    Y = [y0, y1] * 4

    # Start from weights saved on disk. (Alternative kept for reference: small
    # random weights, np.random.randn(4+1, 10) * 0.1 and
    # np.random.randn(10+1, 3) * 0.1, where the +1 is the bias row.)
    W1_0, W2_0 = np.load("W1.npy"), np.load("W2.npy")

    # Single epoch with lr=0.0: the per-sample diagnostics are printed while
    # the weight update is multiplied by a zero learning rate.
    W1_tr, W2_tr = train_snn_backprop(
        X, Y, W1_0, W2_0, epochs=1, lr=0.0
    )
    print("Trained W1:", W1_tr)
    print("Trained W2:", W2_tr)
    print("Hidden times for x0:", layer_forward(x0, W1_tr, 1))
    print("Hidden times for x1:", layer_forward(x1, W1_tr, 1))

    # # ── Now test on the same two patterns ──
    # print("\n=== Test predictions ===")
    # for xi, yi in zip(X, Y):
    #     # call layer_forward(positionally) rather than with layer1_idx=
    #     h_times = layer_forward(xi, W1_tr, 1)
    #     o_times = layer_forward(h_times, W2_tr, 2)

    #     pred_class = np.argmax(o_times)  
    #     true_class = np.argmax(yi)

    #     print(f"Input: {xi}")
    #     print(f" Spike times: {o_times}")
    #     print(f" Predicted class: {pred_class}, True class: {true_class}\n")

# np.save('W1.npy', W1_tr)
# np.save('W2.npy', W2_tr)
# print("weights saved")



otimes  [2.666 2.003 2.524]
delta_o  [-0.284  -0.0235  0.237 ]
otimes  [2.726 2.003 2.619]
delta_o  [ 0.338  -0.0235 -0.331 ]
otimes  [2.666 2.003 2.524]
delta_o  [-0.284  -0.0235  0.237 ]
otimes  [2.726 2.003 2.619]
delta_o  [ 0.338  -0.0235 -0.331 ]


KeyboardInterrupt: 

In [None]:
from brian2 import *
import numpy as np
import logging
import warnings



from brian2 import prefs, set_device

# Select Brian2's Cython code-generation target.
prefs.codegen.target = 'cython'

# Stay on the runtime device (noted in the original as the default), so the
# simulation is driven from this Python process.
set_device('runtime')

# Silence numeric noise: every RuntimeWarning is ignored, NumPy no longer
# reports floating-point overflow/underflow, and Brian2 logs below ERROR are hidden.
# NOTE(review): the RuntimeWarning filter is broad and can also hide real problems.
warnings.filterwarnings('ignore', category=RuntimeWarning)
# Use the explicitly imported `np` alias: the bare name `numpy` is never
# imported in this cell and was only reachable through `from brian2 import *`.
np.seterr(over='ignore', under='ignore')
logging.getLogger('brian2').setLevel(logging.ERROR)

# ----------------------------------------------------------------------------
# Spike timing and derivative

# Begin a fresh Brian2 scope and use a 0.001 ms default clock step.
start_scope()
defaultclock.dt = 0.001*ms

@implementation('numpy', discard_units=True)
@check_units(w=1, global_clock=1, layer=1, sum=1, spikes_received=1, result=1)
def spike_timing(w, global_clock, layer, sum, spikes_received):
    """
    Map the arrival phase of an input spike to its weighted contribution.

    w: synaptic weight (dimensionless scalar; the sign selects the branch)
    global_clock: running clock value; only its fractional part
        (global_clock % 1) is used
    layer, sum, spikes_received: part of the call signature, not used here
    returns: contribution limited to the range [0, 1]
    """
    x = global_clock % 1
    if w >= 0:
        # x**(1 - w); where x == 0 the 0 pre-filled in `out` is kept
        result = np.power(x, (1 - w), where=(x>0), out=np.zeros_like(x))
    else:
        # 1 - (1 - x)**(1 + w); where x >= 1 the 1 pre-filled in `out` is kept
        result = 1 - np.power((1 - x), (1 + w), where=(x<1), out=np.ones_like(x))
    # For |w| > 1 the exponent is negative and the raw value leaves [0, 1]
    # (e.g. w=3, x=0.5 gives 4; w=-3, x=0.5 gives -3). Limit it to the unit
    # interval, as the definition of this function in the notebook's first
    # cell already does.
    return np.clip(result, 0.0, 1.0)

@implementation('numpy', discard_units=True)
@check_units(w=1, global_clock=1, layer=1, sum=1, spikes_received=1, result=1)
def d_spike_timing_dw(w, global_clock, layer, sum, spikes_received):
    """
    Derivative with respect to w of the two spike-timing branches
    x**(1 - w) (w >= 0) and 1 - (1 - x)**(1 + w) (otherwise),
    where x = global_clock % 1.

    layer, sum, spikes_received: part of the call signature, not used here
    returns: -x**(1 - w) * log(x + 1e-9) or -(1 - x)**(1 + w) * log(1 - x + 1e-9)
    """
    phase = global_clock % 1
    tiny = 1e-9  # keeps the logarithm finite at the ends of the interval
    if not (w >= 0):
        remaining = 1 - phase
        decay = np.power(remaining, (1 + w),
                         where=(phase < 1), out=np.ones_like(phase))
        return -decay * np.log(remaining + tiny)
    rise = np.power(phase, (1 - w),
                    where=(phase > 0), out=np.zeros_like(phase))
    return -rise * np.log(phase + tiny)

# ----------------------------------------------------------------------------
# Forward pass: 4->10->3 using two-stage mini_urd

def layer_forward(inputs, W, layer_idx):
    """
    Compute the output spike times of one layer by simulating each output
    neuron in its own short Brian2 run.

    inputs: array of spike times (ms) from previous layer (shape: n_in,)
    W: weight matrix shape (n_in+1, n_out); the last row holds the weights of
       the bias spike that is appended to the inputs at t=0
    layer_idx: integer layer number, added (in ms) to the scheduled output time
    returns: array of output spike times (ms), one per column of W; an output
             that records no spike during the 5 ms run is reported as 5.0
    """
    # 1) augment inputs with bias spike @ t=0
    bias_time = 0.0
    aug_inputs = np.concatenate((inputs, [bias_time]))  # shape (n_in+1,)

    n_in_plus_bias, n_out = W.shape
    # The weight matrix must have one row per input plus the bias row.
    assert aug_inputs.size == n_in_plus_bias

    out_times = []
    # 2) one independent simulation per output neuron (column j of W)
    for j in range(n_out):
        start_scope()
        defaultclock.dt = 0.001*ms

        # Single post-synaptic neuron. It has no differential equations; its
        # variables are only changed by the synaptic on_pre code and by the
        # run_regularly block below.
        G = NeuronGroup(1, '''
            v : 1
            sum : 1
            sr : 1
            scheduled_time : second
            global_clock : 1
        ''', threshold='v>1', reset='v=0', method='exact')

        # Initial state. scheduled_time starts at 1e9 s, far beyond the 5 ms
        # run, so v stays 0 until a presynaptic spike sets scheduled_time.
        G.v = G.sum = G.sr = 0
        G.global_clock = 0
        G.scheduled_time = 1e9*second

        # Stimulus: source i emits one spike at aug_inputs[i] ms (the last
        # source is the bias spike at t=0).
        stim = SpikeGeneratorGroup(n_in_plus_bias,
                                   indices=list(range(n_in_plus_bias)),
                                   times=aug_inputs*ms)

        # On every presynaptic spike: count it (sr), add its spike_timing
        # contribution to sum, and reschedule the output spike to the mean
        # contribution so far plus the layer index, in ms.
        S = Synapses(stim, G, '''w:1
            layer:1''',
            on_pre='''
            sr += 1
            sum += spike_timing(w, global_clock, layer, sum, sr)
            scheduled_time = (sum/sr + layer)*ms
        ''')
        S.connect(True)
        S.w = W[:, j]
        S.layer = layer_idx

        # Every 0.001 ms: set v to 1.2 (above the v>1 threshold) when t is
        # within 0.001 ms of scheduled_time and to 0 otherwise, then advance
        # global_clock by 0.001 so that it counts elapsed time in ms.
        G.run_regularly('''
            v = int(abs(t - scheduled_time) < 0.001*ms) * 1.2
            global_clock += 0.001
        ''', dt=0.001*ms)

        mon = SpikeMonitor(G)
        run(5*ms)

        # First recorded spike in ms, or 5.0 (the run length) if none occurred.
        ts = mon.spike_trains()[0]
        t0 = float(ts[0]/ms) if len(ts)>0 else float(5.0)
        out_times.append(t0)

    return np.array(out_times)


# ----------------------------------------------------------------------------
# Training loop with backprop for 4-10-3
def train_snn_backprop(
    X, Y,                # lists of input arrays (4,) and target (3,)
    W1_init, W2_init,
    epochs=10, lr=0.1,
    max_grad=20.0, w_min=-20.0, w_max=20.0,
    non_target_time=2.05,
    λ=0.5                # non-target penalty weight
):
    """
    Train a two-layer spiking network (4→10→3 in the example usage) on
    desired output spike times.

    One weight update is made per epoch from the gradient averaged over all
    samples, using:
      • element-wise gradient clipping to [-max_grad, max_grad]
      • an exponential moving average of the gradient (beta = 0.9) as momentum
      • clamping of the updated weights to [w_min, w_max]
    Both layers currently use the same learning rate `lr`.

    X, Y: equal-length, non-empty sequences of input spike-time arrays and
        target arrays; the target output of a sample is the argmax of its
        target array
    W1_init, W2_init: initial weight matrices including the bias row; they
        are copied, not modified
    epochs: number of passes over the data (one update per pass)
    lr: learning rate applied to both layers
    max_grad: bound for element-wise gradient clipping
    w_min, w_max: bounds the weights are clamped to after each update
    non_target_time: spike time (ms) that non-target outputs are pulled towards
    λ: weight of the non-target penalty
    returns: (W1, W2), the updated weight matrices
    raises: ValueError if X and Y differ in length or are empty
    """
    N = len(X)
    # zip() below would silently drop unmatched samples while N is still used
    # for averaging, and N == 0 would divide by zero; reject both up front.
    if N != len(Y):
        raise ValueError(f"X and Y must have the same length (got {N} and {len(Y)})")
    if N == 0:
        raise ValueError("X and Y must contain at least one sample")

    # copy over the weights given
    W1 = W1_init.copy()      # shape (5,10) including bias row
    W2 = W2_init.copy()      # shape (11,3) including bias row

    # Momentum buffers (running averages of the clipped gradients)
    beta = 0.9
    vW1 = np.zeros_like(W1)
    vW2 = np.zeros_like(W2)

    layer1_idx, layer2_idx = 1, 2

    for ep in range(epochs):
        # Accumulators - collect the per-sample gradients of this epoch
        acc_dW1 = np.zeros_like(W1)
        acc_dW2 = np.zeros_like(W2)
        epoch_loss = 0.0

        for xi, yi in zip(X, Y):
            # — Forward pass —
            h_times = layer_forward(xi, W1, layer1_idx)
            o_times = layer_forward(h_times, W2, layer2_idx)

            # — Separation loss —
            # Squared error to the desired time for the target output plus a
            # λ-weighted squared error to non_target_time for the others.
            target_idx = np.argmax(yi)
            L_target = 0.5 * (o_times[target_idx] - yi[target_idx])**2
            non_ids = [j for j in range(len(o_times)) if j != target_idx]
            L_non = 0.5 * λ * sum([(o_times[j] - non_target_time)**2 for j in non_ids])
            L = L_target + L_non
            epoch_loss += L

            # — Gradients for W2 —
            # delta_o is the derivative of the loss w.r.t. each output time.
            delta_o = np.zeros_like(o_times)
            delta_o[target_idx] = (o_times[target_idx] - yi[target_idx])
            for j in non_ids:
                delta_o[j] = λ * (o_times[j] - non_target_time)

            print("otimes ", o_times)
            print("delta_o ",delta_o)

            # Evaluate the output-layer derivative once per weight; it is
            # needed both for dW2 and for the hidden-layer error below.
            aug_h = np.concatenate((h_times, [0.0]))
            dt_dw2 = np.empty(W2.shape)
            for k in range(W2.shape[0]):
                for j in range(W2.shape[1]):
                    dt_dw2[k, j] = d_spike_timing_dw(
                        W2[k, j], aug_h[k], layer2_idx, 0, 1)

            dW2 = np.zeros_like(W2)
            dW2[...] = delta_o[np.newaxis, :] * dt_dw2

            # — Backprop into hidden & gradients for W1 —
            # NOTE(review): the hidden error reuses the derivative with
            # respect to the weight; a chain rule through the hidden spike
            # time would normally use the derivative with respect to the
            # input time. Left unchanged here - confirm the intent.
            n_hidden = len(h_times)   # excludes the bias row of W2
            delta_h = np.zeros_like(h_times)
            for j in range(W2.shape[1]):
                delta_h += delta_o[j] * dt_dw2[:n_hidden, j]

            aug_xi = np.concatenate((xi, [0.0]))
            dW1 = np.zeros_like(W1)
            for i in range(W1.shape[0]):
                for k in range(W1.shape[1]):
                    dW1[i, k] = delta_h[k] * d_spike_timing_dw(
                        W1[i, k], aug_xi[i], layer1_idx, 0, 1)

            # — Accumulate —
            acc_dW1 += dW1
            acc_dW2 += dW2

        # — Average & clip gradients —
        acc_dW1 /= N
        acc_dW2 /= N

        # Hidden layer uses the same rate as the output layer (no boost applied)
        lr1 = lr

        # Element-wise clipping, same bound for both layers
        g1 = np.clip(acc_dW1, -max_grad, max_grad)
        g2 = np.clip(acc_dW2, -max_grad, max_grad)

        # — Momentum updates —
        vW1 = beta * vW1 + (1 - beta) * g1
        vW2 = beta * vW2 + (1 - beta) * g2

        # — Apply weight updates & clamp —
        W1 = np.clip(W1 - lr1 * vW1, w_min, w_max)
        W2 = np.clip(W2 - lr  * vW2, w_min, w_max)

        print(f"Epoch {ep+1}/{epochs} — avg loss={epoch_loss/N:.4f}")
        print(f"             ‖W1‖={np.linalg.norm(W1):.3f}, ‖W2‖={np.linalg.norm(W2):.3f}\n")

    return W1, W2

            


if __name__ == "__main__":
    # Toy data set: two fixed 4-element input patterns, alternated
    # x0, x1, x0, ... to give 8 samples.
    x0 = np.array([0.9, 0.7, 0.3, 0.4])
    x1 = np.array([0.6, 0.7, 0.8, 0.9])
    X = [x0, x1] * 4

    # Matching 3-element targets: the target output of each pattern gets the
    # largest desired spike time (2.95), the other outputs get 2.0.
    y0 = np.array([2.95, 2.0, 2.0])
    y1 = np.array([2.0, 2.0, 2.95])
    Y = [y0, y1] * 4

    # Start from weights saved on disk. (Alternative kept for reference: small
    # random weights, np.random.randn(4+1, 10) * 0.1 and
    # np.random.randn(10+1, 3) * 0.1, where the +1 is the bias row.)
    W1_0, W2_0 = np.load("W1.npy"), np.load("W2.npy")

    # Single epoch with lr=0.0: the per-sample diagnostics are printed while
    # the weight update is multiplied by a zero learning rate.
    W1_tr, W2_tr = train_snn_backprop(
        X, Y, W1_0, W2_0, epochs=1, lr=0.0
    )
    print("Trained W1:", W1_tr)
    print("Trained W2:", W2_tr)
    print("Hidden times for x0:", layer_forward(x0, W1_tr, 1))
    print("Hidden times for x1:", layer_forward(x1, W1_tr, 1))

    # # ── Now test on the same two patterns ──
    # print("\n=== Test predictions ===")
    # for xi, yi in zip(X, Y):
    #     # call layer_forward(positionally) rather than with layer1_idx=
    #     h_times = layer_forward(xi, W1_tr, 1)
    #     o_times = layer_forward(h_times, W2_tr, 2)

    #     pred_class = np.argmax(o_times)  
    #     true_class = np.argmax(yi)

    #     print(f"Input: {xi}")
    #     print(f" Spike times: {o_times}")
    #     print(f" Predicted class: {pred_class}, True class: {true_class}\n")

# np.save('W1.npy', W1_tr)
# np.save('W2.npy', W2_tr)
# print("weights saved")



otimes  [2.666 2.003 2.524]
delta_o  [-0.284  -0.0235  0.237 ]
otimes  [2.726 2.003 2.619]
delta_o  [ 0.338  -0.0235 -0.331 ]
otimes  [2.666 2.003 2.524]
delta_o  [-0.284  -0.0235  0.237 ]
otimes  [2.726 2.003 2.619]
delta_o  [ 0.338  -0.0235 -0.331 ]


KeyboardInterrupt: 

In [None]:
from brian2 import *
import numpy as np
import logging
import warnings



from brian2 import prefs, set_device

# Select Brian2's Cython code-generation target.
prefs.codegen.target = 'cython'

# Stay on the runtime device (noted in the original as the default), so the
# simulation is driven from this Python process.
set_device('runtime')

# Silence numeric noise: every RuntimeWarning is ignored, NumPy no longer
# reports floating-point overflow/underflow, and Brian2 logs below ERROR are hidden.
# NOTE(review): the RuntimeWarning filter is broad and can also hide real problems.
warnings.filterwarnings('ignore', category=RuntimeWarning)
# Use the explicitly imported `np` alias: the bare name `numpy` is never
# imported in this cell and was only reachable through `from brian2 import *`.
np.seterr(over='ignore', under='ignore')
logging.getLogger('brian2').setLevel(logging.ERROR)

# ----------------------------------------------------------------------------
# Spike timing and derivative

# Begin a fresh Brian2 scope and use a 0.001 ms default clock step.
start_scope()
defaultclock.dt = 0.001*ms

@implementation('numpy', discard_units=True)
@check_units(w=1, global_clock=1, layer=1, sum=1, spikes_received=1, result=1)
def spike_timing(w, global_clock, layer, sum, spikes_received):
    """
    Map the arrival phase of an input spike to its weighted contribution.

    w: synaptic weight (dimensionless scalar; the sign selects the branch)
    global_clock: running clock value; only its fractional part
        (global_clock % 1) is used
    layer, sum, spikes_received: part of the call signature, not used here
    returns: contribution limited to the range [0, 1]
    """
    x = global_clock % 1
    if w >= 0:
        # x**(1 - w); where x == 0 the 0 pre-filled in `out` is kept
        result = np.power(x, (1 - w), where=(x>0), out=np.zeros_like(x))
    else:
        # 1 - (1 - x)**(1 + w); where x >= 1 the 1 pre-filled in `out` is kept
        result = 1 - np.power((1 - x), (1 + w), where=(x<1), out=np.ones_like(x))
    # For |w| > 1 the exponent is negative and the raw value leaves [0, 1]
    # (e.g. w=3, x=0.5 gives 4; w=-3, x=0.5 gives -3). Limit it to the unit
    # interval, as the definition of this function in the notebook's first
    # cell already does.
    return np.clip(result, 0.0, 1.0)

@implementation('numpy', discard_units=True)
@check_units(w=1, global_clock=1, layer=1, sum=1, spikes_received=1, result=1)
def d_spike_timing_dw(w, global_clock, layer, sum, spikes_received):
    """
    Derivative with respect to w of the two spike-timing branches
    x**(1 - w) (w >= 0) and 1 - (1 - x)**(1 + w) (otherwise),
    where x = global_clock % 1.

    layer, sum, spikes_received: part of the call signature, not used here
    returns: -x**(1 - w) * log(x + 1e-9) or -(1 - x)**(1 + w) * log(1 - x + 1e-9)
    """
    phase = global_clock % 1
    tiny = 1e-9  # keeps the logarithm finite at the ends of the interval
    if not (w >= 0):
        remaining = 1 - phase
        decay = np.power(remaining, (1 + w),
                         where=(phase < 1), out=np.ones_like(phase))
        return -decay * np.log(remaining + tiny)
    rise = np.power(phase, (1 - w),
                    where=(phase > 0), out=np.zeros_like(phase))
    return -rise * np.log(phase + tiny)

# ----------------------------------------------------------------------------
# Forward pass: 4->10->3 using two-stage mini_urd

def layer_forward(inputs, W, layer_idx):
    """
    Compute the output spike times of one layer by simulating each output
    neuron in its own short Brian2 run.

    inputs: array of spike times (ms) from previous layer (shape: n_in,)
    W: weight matrix shape (n_in+1, n_out); the last row holds the weights of
       the bias spike that is appended to the inputs at t=0
    layer_idx: integer layer number, added (in ms) to the scheduled output time
    returns: array of output spike times (ms), one per column of W; an output
             that records no spike during the 5 ms run is reported as 5.0
    """
    # 1) augment inputs with bias spike @ t=0
    bias_time = 0.0
    aug_inputs = np.concatenate((inputs, [bias_time]))  # shape (n_in+1,)

    n_in_plus_bias, n_out = W.shape
    # The weight matrix must have one row per input plus the bias row.
    assert aug_inputs.size == n_in_plus_bias

    out_times = []
    # 2) one independent simulation per output neuron (column j of W)
    for j in range(n_out):
        start_scope()
        defaultclock.dt = 0.001*ms

        # Single post-synaptic neuron. It has no differential equations; its
        # variables are only changed by the synaptic on_pre code and by the
        # run_regularly block below.
        G = NeuronGroup(1, '''
            v : 1
            sum : 1
            sr : 1
            scheduled_time : second
            global_clock : 1
        ''', threshold='v>1', reset='v=0', method='exact')

        # Initial state. scheduled_time starts at 1e9 s, far beyond the 5 ms
        # run, so v stays 0 until a presynaptic spike sets scheduled_time.
        G.v = G.sum = G.sr = 0
        G.global_clock = 0
        G.scheduled_time = 1e9*second

        # Stimulus: source i emits one spike at aug_inputs[i] ms (the last
        # source is the bias spike at t=0).
        stim = SpikeGeneratorGroup(n_in_plus_bias,
                                   indices=list(range(n_in_plus_bias)),
                                   times=aug_inputs*ms)

        # On every presynaptic spike: count it (sr), add its spike_timing
        # contribution to sum, and reschedule the output spike to the mean
        # contribution so far plus the layer index, in ms.
        S = Synapses(stim, G, '''w:1
            layer:1''',
            on_pre='''
            sr += 1
            sum += spike_timing(w, global_clock, layer, sum, sr)
            scheduled_time = (sum/sr + layer)*ms
        ''')
        S.connect(True)
        S.w = W[:, j]
        S.layer = layer_idx

        # Every 0.001 ms: set v to 1.2 (above the v>1 threshold) when t is
        # within 0.001 ms of scheduled_time and to 0 otherwise, then advance
        # global_clock by 0.001 so that it counts elapsed time in ms.
        G.run_regularly('''
            v = int(abs(t - scheduled_time) < 0.001*ms) * 1.2
            global_clock += 0.001
        ''', dt=0.001*ms)

        mon = SpikeMonitor(G)
        run(5*ms)

        # First recorded spike in ms, or 5.0 (the run length) if none occurred.
        ts = mon.spike_trains()[0]
        t0 = float(ts[0]/ms) if len(ts)>0 else float(5.0)
        out_times.append(t0)

    return np.array(out_times)


# ----------------------------------------------------------------------------
# Training loop with backprop for 4-10-3
def train_snn_backprop(
    X, Y,                # lists of input arrays (4,) and target (3,)
    W1_init, W2_init,
    epochs=10, lr=0.1,
    max_grad=20.0, w_min=-20.0, w_max=20.0,
    non_target_time=2.05,
    λ=0.5                # non-target penalty weight
):
    """
    Train a two-layer spiking network (4→10→3 in the example usage) on
    desired output spike times.

    One weight update is made per epoch from the gradient averaged over all
    samples, using:
      • element-wise gradient clipping to [-max_grad, max_grad]
      • an exponential moving average of the gradient (beta = 0.9) as momentum
      • clamping of the updated weights to [w_min, w_max]
    Both layers currently use the same learning rate `lr`.

    X, Y: equal-length, non-empty sequences of input spike-time arrays and
        target arrays; the target output of a sample is the argmax of its
        target array
    W1_init, W2_init: initial weight matrices including the bias row; they
        are copied, not modified
    epochs: number of passes over the data (one update per pass)
    lr: learning rate applied to both layers
    max_grad: bound for element-wise gradient clipping
    w_min, w_max: bounds the weights are clamped to after each update
    non_target_time: spike time (ms) that non-target outputs are pulled towards
    λ: weight of the non-target penalty
    returns: (W1, W2), the updated weight matrices
    raises: ValueError if X and Y differ in length or are empty
    """
    N = len(X)
    # zip() below would silently drop unmatched samples while N is still used
    # for averaging, and N == 0 would divide by zero; reject both up front.
    if N != len(Y):
        raise ValueError(f"X and Y must have the same length (got {N} and {len(Y)})")
    if N == 0:
        raise ValueError("X and Y must contain at least one sample")

    # copy over the weights given
    W1 = W1_init.copy()      # shape (5,10) including bias row
    W2 = W2_init.copy()      # shape (11,3) including bias row

    # Momentum buffers (running averages of the clipped gradients)
    beta = 0.9
    vW1 = np.zeros_like(W1)
    vW2 = np.zeros_like(W2)

    layer1_idx, layer2_idx = 1, 2

    for ep in range(epochs):
        # Accumulators - collect the per-sample gradients of this epoch
        acc_dW1 = np.zeros_like(W1)
        acc_dW2 = np.zeros_like(W2)
        epoch_loss = 0.0

        for xi, yi in zip(X, Y):
            # — Forward pass —
            h_times = layer_forward(xi, W1, layer1_idx)
            o_times = layer_forward(h_times, W2, layer2_idx)

            # — Separation loss —
            # Squared error to the desired time for the target output plus a
            # λ-weighted squared error to non_target_time for the others.
            target_idx = np.argmax(yi)
            L_target = 0.5 * (o_times[target_idx] - yi[target_idx])**2
            non_ids = [j for j in range(len(o_times)) if j != target_idx]
            L_non = 0.5 * λ * sum([(o_times[j] - non_target_time)**2 for j in non_ids])
            L = L_target + L_non
            epoch_loss += L

            # — Gradients for W2 —
            # delta_o is the derivative of the loss w.r.t. each output time.
            delta_o = np.zeros_like(o_times)
            delta_o[target_idx] = (o_times[target_idx] - yi[target_idx])
            for j in non_ids:
                delta_o[j] = λ * (o_times[j] - non_target_time)

            print("otimes ", o_times)
            print("delta_o ",delta_o)

            # Evaluate the output-layer derivative once per weight; it is
            # needed both for dW2 and for the hidden-layer error below.
            aug_h = np.concatenate((h_times, [0.0]))
            dt_dw2 = np.empty(W2.shape)
            for k in range(W2.shape[0]):
                for j in range(W2.shape[1]):
                    dt_dw2[k, j] = d_spike_timing_dw(
                        W2[k, j], aug_h[k], layer2_idx, 0, 1)

            dW2 = np.zeros_like(W2)
            dW2[...] = delta_o[np.newaxis, :] * dt_dw2

            # — Backprop into hidden & gradients for W1 —
            # NOTE(review): the hidden error reuses the derivative with
            # respect to the weight; a chain rule through the hidden spike
            # time would normally use the derivative with respect to the
            # input time. Left unchanged here - confirm the intent.
            n_hidden = len(h_times)   # excludes the bias row of W2
            delta_h = np.zeros_like(h_times)
            for j in range(W2.shape[1]):
                delta_h += delta_o[j] * dt_dw2[:n_hidden, j]

            aug_xi = np.concatenate((xi, [0.0]))
            dW1 = np.zeros_like(W1)
            for i in range(W1.shape[0]):
                for k in range(W1.shape[1]):
                    dW1[i, k] = delta_h[k] * d_spike_timing_dw(
                        W1[i, k], aug_xi[i], layer1_idx, 0, 1)

            # — Accumulate —
            acc_dW1 += dW1
            acc_dW2 += dW2

        # — Average & clip gradients —
        acc_dW1 /= N
        acc_dW2 /= N

        # Hidden layer uses the same rate as the output layer (no boost applied)
        lr1 = lr

        # Element-wise clipping, same bound for both layers
        g1 = np.clip(acc_dW1, -max_grad, max_grad)
        g2 = np.clip(acc_dW2, -max_grad, max_grad)

        # — Momentum updates —
        vW1 = beta * vW1 + (1 - beta) * g1
        vW2 = beta * vW2 + (1 - beta) * g2

        # — Apply weight updates & clamp —
        W1 = np.clip(W1 - lr1 * vW1, w_min, w_max)
        W2 = np.clip(W2 - lr  * vW2, w_min, w_max)

        print(f"Epoch {ep+1}/{epochs} — avg loss={epoch_loss/N:.4f}")
        print(f"             ‖W1‖={np.linalg.norm(W1):.3f}, ‖W2‖={np.linalg.norm(W2):.3f}\n")

    return W1, W2

            


if __name__ == "__main__":
    # Toy data set: two fixed 4-element input patterns, alternated
    # x0, x1, x0, ... to give 8 samples.
    x0 = np.array([0.9, 0.7, 0.3, 0.4])
    x1 = np.array([0.6, 0.7, 0.8, 0.9])
    X = [x0, x1] * 4

    # Matching 3-element targets: the target output of each pattern gets the
    # largest desired spike time (2.95), the other outputs get 2.0.
    y0 = np.array([2.95, 2.0, 2.0])
    y1 = np.array([2.0, 2.0, 2.95])
    Y = [y0, y1] * 4

    # Start from weights saved on disk. (Alternative kept for reference: small
    # random weights, np.random.randn(4+1, 10) * 0.1 and
    # np.random.randn(10+1, 3) * 0.1, where the +1 is the bias row.)
    W1_0, W2_0 = np.load("W1.npy"), np.load("W2.npy")

    # Single epoch with lr=0.0: the per-sample diagnostics are printed while
    # the weight update is multiplied by a zero learning rate.
    W1_tr, W2_tr = train_snn_backprop(
        X, Y, W1_0, W2_0, epochs=1, lr=0.0
    )
    print("Trained W1:", W1_tr)
    print("Trained W2:", W2_tr)
    print("Hidden times for x0:", layer_forward(x0, W1_tr, 1))
    print("Hidden times for x1:", layer_forward(x1, W1_tr, 1))

    # # ── Now test on the same two patterns ──
    # print("\n=== Test predictions ===")
    # for xi, yi in zip(X, Y):
    #     # call layer_forward(positionally) rather than with layer1_idx=
    #     h_times = layer_forward(xi, W1_tr, 1)
    #     o_times = layer_forward(h_times, W2_tr, 2)

    #     pred_class = np.argmax(o_times)  
    #     true_class = np.argmax(yi)

    #     print(f"Input: {xi}")
    #     print(f" Spike times: {o_times}")
    #     print(f" Predicted class: {pred_class}, True class: {true_class}\n")

# np.save('W1.npy', W1_tr)
# np.save('W2.npy', W2_tr)
# print("weights saved")



otimes  [2.666 2.003 2.524]
delta_o  [-0.284  -0.0235  0.237 ]
otimes  [2.726 2.003 2.619]
delta_o  [ 0.338  -0.0235 -0.331 ]
otimes  [2.666 2.003 2.524]
delta_o  [-0.284  -0.0235  0.237 ]
otimes  [2.726 2.003 2.619]
delta_o  [ 0.338  -0.0235 -0.331 ]


KeyboardInterrupt: 

In [None]:
#testing on Iris
from sklearn import datasets
from sklearn.model_selection import train_test_split


if __name__ == "__main__":
    # Load Iris and one-hot encode the integer class labels.
    iris = datasets.load_iris()
    X = iris.data
    y = iris.target
    y_encoded = np.zeros((y.size, y.max()+1))
    y_encoded[np.arange(y.size), y] = 1

    def scale_features(x):
        """Min-max scale every column of x into the range [0.05, 0.95]."""
        mn, mx = x.min(axis=0), x.max(axis=0)
        # A constant column has mx == mn; use a span of 1 there so it maps to
        # 0.05 instead of dividing by zero.
        span = np.where(mx > mn, mx - mn, 1.0)
        x_norm = (x - mn) / span
        return x_norm * 0.9 + 0.05

    X_scaled = scale_features(X)
    X_train, X_test, y_train, y_test = train_test_split(
        X_scaled, y_encoded, test_size=0.2,
        random_state=42, stratify=y
    )

    print("\n=== Test predictions ===")
    # Evaluate on the scaled held-out split. The loop previously iterated over
    # zip(X, Y), i.e. the raw unscaled Iris rows paired with the toy targets
    # left over from the training cell, and overwrote y_test with a single row.
    # W1_tr / W2_tr must already exist from the training cell.
    for x_sample, y_sample in zip(X_test, y_test):
        h_times = layer_forward(x_sample, W1_tr, 1)
        o_times = layer_forward(h_times, W2_tr, 2)

        # The predicted class is the output with the largest spike time.
        # NOTE(review): layer_forward reports an output that never fires as
        # 5.0, so silent outputs win this argmax - confirm that is intended.
        pred_class = np.argmax(o_times)
        true_class = np.argmax(y_sample)

        # Print the sample that was actually evaluated (was a stale `xi`).
        print(f"Input: {x_sample}")
        print(f" Spike times: {o_times}")
        print(f" Predicted class: {pred_class}, True class: {true_class}\n")



=== Test predictions ===
Input: [0.6 0.7 0.8 0.9]
 Spike times: [2.228 5.    5.   ]
 Predicted class: 1, True class: 0

Input: [0.6 0.7 0.8 0.9]
 Spike times: [2.228 5.    5.   ]
 Predicted class: 1, True class: 2

Input: [0.6 0.7 0.8 0.9]
 Spike times: [2.359 5.    5.   ]
 Predicted class: 1, True class: 0

Input: [0.6 0.7 0.8 0.9]
 Spike times: [2.267 5.    5.   ]
 Predicted class: 1, True class: 2

Input: [0.6 0.7 0.8 0.9]
 Spike times: [2.228 5.    5.   ]
 Predicted class: 1, True class: 0

Input: [0.6 0.7 0.8 0.9]
 Spike times: [2.305 1.152 1.397]
 Predicted class: 0, True class: 2



KeyboardInterrupt: 