In [10]:
from brian2 import *
import numpy as np
import logging
import warnings



from brian2 import prefs, set_device

# Tell Brian2 to use the Cython code generator:
prefs.codegen.target = 'cython'

# Optionally compile but keep Python interface:
set_device('runtime')  # default; compiles operations to .so but stays in Python process


# suppress overflow warnings
warnings.filterwarnings('ignore', category=RuntimeWarning)
numpy.seterr(over='ignore', under='ignore')
logging.getLogger('brian2').setLevel(logging.ERROR)

# ----------------------------------------------------------------------------
# Spike timing and derivative

start_scope()
defaultclock.dt = 0.001*ms

@implementation('numpy', discard_units=True)
@check_units(w=1, global_clock=1, layer=1, sum=1, spikes_received=1, result=1)
def spike_timing(w, global_clock, layer, sum, spikes_received):
    x = global_clock % 1
    return np.tanh(w * x)


@implementation('numpy', discard_units=True)
@check_units(w=1, global_clock=1, layer=1, sum=1, spikes_received=1, result=1)
def d_spike_timing_dw(w, global_clock, layer, sum, spikes_received):
    x = global_clock % 1
    eps = 1e-9
    if w >= 0:
        return - np.power(x, (1 - w), where=(x>0), out=np.zeros_like(x)) * np.log(x + eps)
    else:
        return - np.power((1 - x), (1 + w), where=(x<1), out=np.ones_like(x)) * np.log(1 - x + eps)

# ----------------------------------------------------------------------------
# Forward pass: 4->10->3 using two-stage mini_urd

def layer_forward(inputs, W, layer_idx):
    """
    inputs: array of spike times (ms) from previous layer (shape: n_in,)
    W: weight matrix shape (n_in+1, n_out)  ← note the extra bias row
    layer_idx: integer layer number
    returns: array of output spike times (ms)
    """
    # 1) augment inputs with bias spike @ t=0
    bias_time = 0.0
    aug_inputs = np.concatenate((inputs, [bias_time]))  # shape (n_in+1,)

    n_in_plus_bias, n_out = W.shape
    assert aug_inputs.size == n_in_plus_bias

    out_times = []
    for j in range(n_out):
        start_scope()
        defaultclock.dt = 0.001*ms

        # single post‐synaptic neuron
        G = NeuronGroup(1, '''
            v : 1
            sum : 1
            sr : 1
            scheduled_time : second
            global_clock : 1
        ''', threshold='v>1', reset='v=0', method='exact')

        # init
        G.v = G.sum = G.sr = 0
        G.global_clock = 0
        G.scheduled_time = 1e9*second

        # stim: now includes bias spike at t=0
        stim = SpikeGeneratorGroup(n_in_plus_bias,
                                   indices=list(range(n_in_plus_bias)),
                                   times=aug_inputs*ms)

        S = Synapses(stim, G, '''w:1
            layer:1''',
            on_pre='''
            sr += 1
            sum += spike_timing(w, global_clock, layer, sum, sr)
            scheduled_time = (sum/sr + layer)*ms
        ''')
        S.connect(True)
        S.w = W[:, j]
        S.layer = layer_idx

        G.run_regularly('''
            v = int(abs(t - scheduled_time) < 0.001*ms) * 1.2
            global_clock += 0.001
        ''', dt=0.001*ms)

        mon = SpikeMonitor(G)
        run(5*ms)

        ts = mon.spike_trains()[0]
        t0 = float(G.scheduled_time[0] / ms)
        out_times.append(t0)

    return np.array(out_times)


# ----------------------------------------------------------------------------
# Training loop with backprop for 4-10-3
def train_snn_backprop(
    X, Y,                # lists of input arrays (4,) and target (3,)
    W1_init, W2_init,
    epochs=10, lr=0.1,
    max_grad=20.0, w_min=-20.0, w_max=20.0,
    non_target_time=2.05,
    λ=0.5                # non-target penalty weight
):
    """
    Trains a 4→10→3 spiking network with:
      • batched gradient updates
      • boosted hidden-layer learning rate
      • separate gradient clipping per layer
      • classical momentum smoothing
    """
    # copy over the weights given 
    W1 = W1_init.copy()      # shape (5,10) including bias row
    W2 = W2_init.copy()      # shape (11,3) including bias row

    # Momentum buffers
    beta = 0.9
    vW1 = np.zeros_like(W1)  # changes how quckly it updates as we slow down or speed up   
    vW2 = np.zeros_like(W2)

    layer1_idx, layer2_idx = 1, 2 
    N = len(X)

    for ep in range(epochs):
        # Accumulators - stores the collecting gradients for each epoch
        acc_dW1 = np.zeros_like(W1)
        acc_dW2 = np.zeros_like(W2) 
        epoch_loss = 0.0

        for xi, yi in zip(X, Y): # iterates in pairs at same time x1 and y1 x2 & y2 ect.
            # — Forward pass —
            h_times = layer_forward(xi, W1, layer1_idx)
            #print(h_times)
            o_times = layer_forward(h_times, W2, layer2_idx)  # this is sending in input array and outputs array for each layer


            # — Separation loss — # calcs to loss for a single sample will not touch for now 
            target_idx = np.argmax(yi)
            L_target = 0.5 * (o_times[target_idx] - yi[target_idx])**2
            non_ids = [j for j in range(len(o_times)) if j != target_idx]
            L_non = 0.5 * λ * sum([(o_times[j] - non_target_time)**2 for j in non_ids])
            L = L_target + L_non
            epoch_loss += L

            # can come back to above to check ^^^^


            # — Gradients for W2 —
            delta_o = np.zeros_like(o_times)  # a 3 element array for the 3 outputs
            delta_o[target_idx] = (o_times[target_idx] - yi[target_idx])
            for j in non_ids:
                delta_o[j] = λ * (o_times[j] - non_target_time)

            # print("otimes ", o_times)
            # print("delta_o ",delta_o) 
        
            # otimes  [2.666 2.003 2.524]
            # delta_o  [-0.284  -0.0235  0.237 ]

            # otimes  [2.726 2.003 2.619]
            # delta_o  [ 0.338  -0.0235 -0.331 ]




            aug_h = np.concatenate((h_times, [0.0])) # just adds a 0 to end of list
            # print("h_times ", h_times)
            # print("aug_h ", aug_h)

            dW2 = np.zeros_like(W2)
            for k in range(W2.shape[0]):
                for j in range(W2.shape[1]):        # will need to check out spinking time_dw
                    dW2[k, j] = delta_o[j] * d_spike_timing_dw(
                        W2[k, j], aug_h[k], layer2_idx, 0, 1)
            # print("w2 ", W2)
            # print("dW2 ", dW2)

            # — Backprop into hidden & gradients for W1 —
            delta_h = np.zeros_like(h_times)
            for k in range(len(h_times)):
                for j in range(W2.shape[1]):
                    dt_dw_output = d_spike_timing_dw(W2[k, j], aug_h[k], layer2_idx, 0, 1)
                    delta_h[k] += delta_o[j] * dt_dw_output  

            aug_xi = np.concatenate((xi, [0.0]))
            dW1 = np.zeros_like(W1)
            for i in range(W1.shape[0]):
                for k in range(W1.shape[1]):
                    dW1[i, k] = delta_h[k] * d_spike_timing_dw(
                        W1[i, k], aug_xi[i], layer1_idx, 0, 1)

            # — Accumulate —
            acc_dW1 += dW1
            acc_dW2 += dW2



        # — Average & clip gradients —
        acc_dW1 /= N
        acc_dW2 /= N

        # Boost hidden-layer rate
        #lr1 = lr

        # Separate clipping thresholds
        g1 = np.clip(acc_dW1, -max_grad, max_grad)
        g2 = np.clip(acc_dW2, -max_grad, max_grad)

        # — Momentum updates —
        vW1 = beta * vW1 + (1 - beta) * g1
        vW2 = beta * vW2 + (1 - beta) * g2

        # — Apply weight updates & clamp —
        W1 = W1 - lr * vW1  
        W2 = W2 - lr * vW2


        print(f"Epoch {ep+1}/{epochs} — avg loss={epoch_loss/N:.4f}")
        print(f"             ‖W1‖={np.linalg.norm(W1):.3f}, ‖W2‖={np.linalg.norm(W2):.3f}\n")

    return W1, W2

            


if __name__ == "__main__":
    # example usage with fixed input/target pairs
    # 4 inputs per sample, constant across 8 samples
    x0 = np.array([0.9, 0.7, 0.3, 0.4])
    x1 = np.array([0.6, 0.7, 0.8, 0.9])
    X = [x0 if i % 2 == 0 else x1 for i in range(8)]
    # 3-targets (network outputs 3 values): use desired spike times [2.1, 2.0, 1.0]
    y0 = np.array([2.95, 2.0, 2.0])
    y1 = np.array([2.0, 2.0, 2.95])
    Y = [y0 if i % 2 == 0 else y1 for i in range(8)]
  

    
    # W1_0 = np.random.randn(4+1, 10) * 0.1  # +1 for bias
    # W2_0 = np.random.randn(10+1, 3) * 0.1  # +1 for bias 


    W1_0 = np.load("NN_W_1.npy")
    print(W1_0.shape)

    test  = np.load("W1.npy")
    print(test.shape)

    W2_0 = np.load("NN_W_2.npy")
    print(W2_0.shape)

    # train
    W1_tr, W2_tr = train_snn_backprop(X, Y, W1_0, W2_0,
                                      epochs=10, lr=0.1)
    print("Trained W1:", W1_tr)
    print("Trained W2:", W2_tr) 
    print("Hidden times for x0:", layer_forward(x0, W1_tr, 1))
    print("Hidden times for x1:", layer_forward(x1, W1_tr, 1))

    # # ── Now test on the same two patterns ──
    print("\n=== Test predictions ===")
    for xi, yi in zip(X, Y):
        # call layer_forward(positionally) rather than with layer1_idx=
        h_times = layer_forward(xi, W1_tr, 1)
        o_times = layer_forward(h_times, W2_tr, 2)

        pred_class = np.argmax(o_times)  
        true_class = np.argmax(yi)

        print(f"Input: {xi}")
        print(f" Spike times: {o_times}")
        print(f" Predicted class: {pred_class}, True class: {true_class}\n")

# np.save('W1.npy', W1_tr)
# np.save('W2.npy', W2_tr)
# print("weights saved")



(3, 2)
(5, 10)
(3, 2)


AssertionError: 

In [13]:
# Key fixes to your SNN code:
from brian2 import *
import numpy as np
import logging
import warnings



from brian2 import prefs, set_device

# Tell Brian2 to use the Cython code generator:
prefs.codegen.target = 'cython'

# Optionally compile but keep Python interface:
set_device('runtime')  # default; compiles operations to .so but stays in Python process


# suppress overflow warnings
warnings.filterwarnings('ignore', category=RuntimeWarning)
numpy.seterr(over='ignore', under='ignore')
logging.getLogger('brian2').setLevel(logging.ERROR)

# ----------------------------------------------------------------------------
# Spike timing and derivative

start_scope()
defaultclock.dt = 0.001*ms

# 1. Fix the derivative function to handle edge cases better
@implementation('numpy', discard_units=True)
@check_units(w=1, global_clock=1, layer=1, sum=1, spikes_received=1, result=1)
def d_spike_timing_dw(w, global_clock, layer, sum, spikes_received):
    x = global_clock % 1
    eps = 1e-6  # Increased epsilon
    
    # Add bounds checking to prevent extreme values
    x = np.clip(x, eps, 1-eps)
    
    if w >= 0:
        # For positive weights, add safeguards against log(0)
        log_term = np.log(np.maximum(x, eps))
        result = - np.power(x, (1 - w)) * log_term
    else:
        # For negative weights
        log_term = np.log(np.maximum(1 - x, eps))
        result = - np.power((1 - x), (1 + w)) * log_term
    
    # Clip gradients to prevent explosion
    return np.clip(result, -10.0, 10.0)

# 2. Improve the spike timing function for better gradients
@implementation('numpy', discard_units=True)
@check_units(w=1, global_clock=1, layer=1, sum=1, spikes_received=1, result=1)
def spike_timing(w, global_clock, layer, sum, spikes_received):
    x = global_clock % 1
    eps = 1e-6
    x = np.clip(x, eps, 1-eps)  # Prevent saturation
    
    if w >= 0:
        result = np.power(x, (1 - w))
    else:
        result = 1 - np.power((1 - x), (1 + w))
    
    return np.clip(result, eps, 1-eps)  # Prevent complete saturation


def layer_forward(inputs, W, layer_idx):
    """
    inputs: array of spike times (ms) from previous layer (shape: n_in,)
    W: weight matrix shape (n_in+1, n_out)  ← note the extra bias row
    layer_idx: integer layer number
    returns: array of output spike times (ms)
    """
    # 1) augment inputs with bias spike @ t=0
    bias_time = 0.0
    aug_inputs = np.concatenate((inputs, [bias_time]))  # shape (n_in+1,)

    n_in_plus_bias, n_out = W.shape
    assert aug_inputs.size == n_in_plus_bias

    out_times = []
    for j in range(n_out):
        start_scope()
        defaultclock.dt = 0.001*ms

        # single post‐synaptic neuron
        G = NeuronGroup(1, '''
            v : 1
            sum : 1
            sr : 1
            scheduled_time : second
            global_clock : 1
        ''', threshold='v>1', reset='v=0', method='exact')

        # init
        G.v = G.sum = G.sr = 0
        G.global_clock = 0
        G.scheduled_time = 1e9*second

        # stim: now includes bias spike at t=0
        stim = SpikeGeneratorGroup(n_in_plus_bias,
                                   indices=list(range(n_in_plus_bias)),
                                   times=aug_inputs*ms)

        S = Synapses(stim, G, '''w:1
            layer:1''',
            on_pre='''
            sr += 1
            sum += spike_timing(w, global_clock, layer, sum, sr)
            scheduled_time = (sum/sr + layer)*ms
        ''')
        S.connect(True)
        S.w = W[:, j]
        S.layer = layer_idx

        G.run_regularly('''
            v = int(abs(t - scheduled_time) < 0.001*ms) * 1.2
            global_clock += 0.001
        ''', dt=0.001*ms)

        mon = SpikeMonitor(G)
        run(5*ms)

        ts = mon.spike_trains()[0]
        t0 = float(G.scheduled_time[0] / ms)
        out_times.append(t0)

    return np.array(out_times)


# 3. Modified training function with better hyperparameters
def train_snn_backprop_fixed(
    X, Y,
    W1_init, W2_init,
    epochs=50, lr=0.05,  # Reduced learning rate
    max_grad=5.0,        # Reduced gradient clipping
    w_min=-10.0, w_max=10.0,  # Reduced weight bounds
    non_target_time=2.5,  # Increased separation
    λ=1.0                # Increased penalty weight
):
    """
    Fixed version with better gradient handling and hyperparameters
    """
    W1 = W1_init.copy()
    W2 = W2_init.copy()

    # Adaptive learning rate
    beta = 0.9
    vW1 = np.zeros_like(W1)
    vW2 = np.zeros_like(W2)
    
    # Add learning rate decay
    lr_decay = 0.95
    current_lr = lr

    layer1_idx, layer2_idx = 1, 2
    N = len(X)

    for ep in range(epochs):
        acc_dW1 = np.zeros_like(W1)
        acc_dW2 = np.zeros_like(W2)
        epoch_loss = 0.0

        for xi, yi in zip(X, Y):
            # Forward pass
            h_times = layer_forward(xi, W1, layer1_idx)
            o_times = layer_forward(h_times, W2, layer2_idx)

            # Improved loss function with better class separation
            target_idx = np.argmax(yi)
            L_target = 0.5 * (o_times[target_idx] - yi[target_idx])**2
            
            # Encourage non-target outputs to be far from target
            non_ids = [j for j in range(len(o_times)) if j != target_idx]
            L_non = 0.5 * λ * sum([
                max(0, 0.2 - abs(o_times[j] - o_times[target_idx]))**2 
                for j in non_ids
            ])  # Margin loss for better separation
            
            L = L_target + L_non
            epoch_loss += L

            # Gradients for W2
            delta_o = np.zeros_like(o_times)
            delta_o[target_idx] = (o_times[target_idx] - yi[target_idx])
            
            for j in non_ids:
                if abs(o_times[j] - o_times[target_idx]) < 0.2:
                    delta_o[j] = λ * np.sign(o_times[j] - o_times[target_idx])

            aug_h = np.concatenate((h_times, [0.0]))
            
            dW2 = np.zeros_like(W2)
            for k in range(W2.shape[0]):
                for j in range(W2.shape[1]):
                    grad = d_spike_timing_dw(W2[k, j], aug_h[k], layer2_idx, 0, 1)
                    dW2[k, j] = delta_o[j] * grad

            # Backprop into hidden layer
            delta_h = np.zeros_like(h_times)
            for k in range(len(h_times)):
                for j in range(W2.shape[1]):
                    dt_dw_output = d_spike_timing_dw(W2[k, j], aug_h[k], layer2_idx, 0, 1)
                    delta_h[k] += delta_o[j] * dt_dw_output

            aug_xi = np.concatenate((xi, [0.0]))
            dW1 = np.zeros_like(W1)
            for i in range(W1.shape[0]):
                for k in range(W1.shape[1]):
                    grad = d_spike_timing_dw(W1[i, k], aug_xi[i], layer1_idx, 0, 1)
                    dW1[i, k] = delta_h[k] * grad

            acc_dW1 += dW1
            acc_dW2 += dW2

        # Average gradients
        acc_dW1 /= N
        acc_dW2 /= N

        # Check for vanishing gradients
        if np.mean(np.abs(acc_dW1)) < 1e-8 or np.mean(np.abs(acc_dW2)) < 1e-8:
            print(f"Warning: Vanishing gradients detected at epoch {ep+1}")
            current_lr *= 2  # Increase learning rate temporarily

        # Gradient clipping
        g1 = np.clip(acc_dW1, -max_grad, max_grad)
        g2 = np.clip(acc_dW2, -max_grad, max_grad)

        # Momentum updates
        vW1 = beta * vW1 + (1 - beta) * g1
        vW2 = beta * vW2 + (1 - beta) * g2

        # Weight updates with bounds
        W1 = np.clip(W1 - current_lr * vW1, w_min, w_max)
        W2 = np.clip(W2 - current_lr * vW2, w_min, w_max)

        # Learning rate decay
        if ep % 10 == 0:
            current_lr *= lr_decay

        print(f"Epoch {ep+1}/{epochs} — avg loss={epoch_loss/N:.4f}")
        print(f"  LR={current_lr:.4f}, ‖∇W1‖={np.linalg.norm(acc_dW1):.4f}, ‖∇W2‖={np.linalg.norm(acc_dW2):.4f}")
        print(f"  ‖W1‖={np.linalg.norm(W1):.3f}, ‖W2‖={np.linalg.norm(W2):.3f}\n")

    return W1, W2

# 4. Better weight initialization
def initialize_weights_better():
    """Initialize weights with better variance scaling"""
    # Xavier/Glorot initialization scaled for this problem
    W1_0 = np.random.randn(4+1, 10) * np.sqrt(2.0 / (4+10))
    W2_0 = np.random.randn(10+1, 3) * np.sqrt(2.0 / (10+3))
    return W1_0, W2_0

# 5. Add debugging function
def debug_gradients(W1, W2, X, Y):
    """Debug gradient flow"""
    print("=== Gradient Debug ===")
    xi, yi = X[0], Y[0]
    
    h_times = layer_forward(xi, W1, 1)
    o_times = layer_forward(h_times, W2, 2)
    
    print(f"Hidden times: {h_times}")
    print(f"Output times: {o_times}")
    
    # Check if outputs are too similar
    output_variance = np.var(o_times)
    print(f"Output variance: {output_variance:.6f}")
    
    if output_variance < 0.01:
        print("WARNING: Outputs are too similar - poor class separation")
    
    # Check weight ranges
    print(f"W1 range: [{W1.min():.3f}, {W1.max():.3f}]")
    print(f"W2 range: [{W2.min():.3f}, {W2.max():.3f}]")

# Usage with better initialization:
if __name__ == "__main__":
    # Your existing data
    x0 = np.array([0.9, 0.7, 0.3, 0.4])
    x1 = np.array([0.6, 0.7, 0.8, 0.9])
    X = [x0 if i % 2 == 0 else x1 for i in range(8)]
    
    y0 = np.array([1.5, 2.0, 2.5])  # Changed targets for better separation
    y1 = np.array([2.5, 2.0, 1.5])
    Y = [y0 if i % 2 == 0 else y1 for i in range(8)]
    
    # Use better initialization instead of loaded weights
    W1_0, W2_0 = initialize_weights_better()
    
    # Debug before training
    debug_gradients(W1_0, W2_0, X, Y)
    
    # Train with fixed function
    W1_tr, W2_tr = train_snn_backprop_fixed(X, Y, W1_0, W2_0, epochs=2, lr=0.2)

=== Gradient Debug ===
Hidden times: [1.45820835 1.58188232 1.53247909 1.42944979 1.4640665  1.49132949
 1.42066056 1.4655997  1.35691895 1.37716041]
Output times: [2.51084009 2.39558591 2.40642763]
Output variance: 0.002700
W1 range: [-0.935, 0.742]
W2 range: [-0.810, 0.702]
Epoch 1/2 — avg loss=0.0207
  LR=0.1900, ‖∇W1‖=1.0908, ‖∇W2‖=1.7358
  ‖W1‖=2.543, ‖W2‖=2.117

Epoch 2/2 — avg loss=0.0213
  LR=0.1900, ‖∇W1‖=1.0809, ‖∇W2‖=1.7324
  ‖W1‖=2.532, ‖W2‖=2.096



In [2]:
# Analysis of the real issues in your SNN code while preserving layer timing

# 1. The derivative function has numerical issues
@implementation('numpy', discard_units=True)
@check_units(w=1, global_clock=1, layer=1, sum=1, spikes_received=1, result=1)
def d_spike_timing_dw_fixed(w, global_clock, layer, sum, spikes_received):
    x = global_clock % 1
    eps = 1e-7  # Small epsilon to prevent log(0)
    
    # The issue: x can be very close to 0 or 1, making log(x) explode
    # Your logic is correct, but needs numerical stability
    
    if w >= 0:
        # Prevent log(0) by ensuring x > eps
        x_safe = np.maximum(x, eps)
        result = - np.power(x_safe, (1 - w)) * np.log(x_safe)
    else:
        # Prevent log(0) by ensuring (1-x) > eps  
        x_safe = np.minimum(x, 1 - eps)
        result = - np.power((1 - x_safe), (1 + w)) * np.log(1 - x_safe)
    
    # Clip extreme gradients but preserve sign and magnitude relationships
    return np.clip(result, -50.0, 50.0)

# 2. The timing function also needs numerical stability
@implementation('numpy', discard_units=True)
@check_units(w=1, global_clock=1, layer=1, sum=1, spikes_received=1, result=1)
def spike_timing_fixed(w, global_clock, layer, sum, spikes_received):
    x = global_clock % 1
    eps = 1e-7
    
    if w >= 0:
        # Prevent x=0 case
        x_safe = np.maximum(x, eps)
        result = np.power(x_safe, (1 - w))
    else:
        # Prevent x=1 case
        x_safe = np.minimum(x, 1 - eps)
        result = 1 - np.power((1 - x_safe), (1 + w))
    
    return np.clip(result, 0.0, 1.0)

# 3. The main issue: Check what values are being passed to these functions
def debug_timing_inputs(W1, W2, X):
    """Debug what actual values are being passed to timing functions"""
    print("=== Debug Timing Function Inputs ===")
    
    xi = X[0]  # First input
    
    # Forward pass through layer 1
    h_times = layer_forward(xi, W1, 1)
    print(f"Hidden layer spike times: {h_times}")
    print(f"Hidden time range: [{h_times.min():.3f}, {h_times.max():.3f}]")
    
    # What gets passed to layer 2?
    aug_h = np.concatenate((h_times, [0.0]))  # bias at t=0
    print(f"Augmented hidden times: {aug_h}")
    
    # Check what global_clock values are used
    print(f"For layer 2, global_clock % 1 values will be: {aug_h % 1}")
    
    # These are the values that go into your timing functions!
    # If they're all very close to 0 or 1, gradients will vanish
    
    return h_times, aug_h

# 4. The real issue might be in your layer_forward function
def layer_forward_with_debug(inputs, W, layer_idx):
    """
    Debug version of layer_forward to see what's happening
    """
    print(f"\n--- Layer {layer_idx} Forward Pass ---")
    print(f"Inputs: {inputs}")
    
    bias_time = 0.0
    aug_inputs = np.concatenate((inputs, [bias_time]))
    print(f"Augmented inputs: {aug_inputs}")
    
    n_in_plus_bias, n_out = W.shape
    out_times = []
    
    for j in range(n_out):
        # Your simulation setup
        start_scope()
        defaultclock.dt = 0.001*ms
        
        G = NeuronGroup(1, '''
            v : 1
            sum : 1
            sr : 1
            scheduled_time : second
            global_clock : 1
        ''', threshold='v>1', reset='v=0', method='exact')
        
        G.v = G.sum = G.sr = 0
        G.global_clock = 0
        G.scheduled_time = 1e9*second
        
        stim = SpikeGeneratorGroup(n_in_plus_bias,
                                   indices=list(range(n_in_plus_bias)),
                                   times=aug_inputs*ms)
        
        S = Synapses(stim, G, '''w:1
            layer:1''',
            on_pre='''
            sr += 1
            sum += spike_timing(w, global_clock, layer, sum, sr)
            scheduled_time = (sum/sr + layer)*ms
        ''')
        S.connect(True)
        S.w = W[:, j]
        S.layer = layer_idx
        
        # The issue might be here - what values is global_clock taking?
        G.run_regularly('''
            v = int(abs(t - scheduled_time) < 0.001*ms) * 1.2
            global_clock += 0.001
        ''', dt=0.001*ms)
        
        mon = SpikeMonitor(G)
        run(5*ms)
        
        # Debug the actual computation
        final_sum = G.sum[0]
        final_sr = G.sr[0]
        scheduled_t = float(G.scheduled_time[0] / ms)
        
        print(f"  Neuron {j}: sum={final_sum:.4f}, sr={final_sr}, scheduled_time={scheduled_t:.4f}")
        
        # The real issue might be that sum/sr is always similar values
        # This would make all outputs similar regardless of weights
        
        out_times.append(scheduled_t)
    
    print(f"Layer {layer_idx} outputs: {out_times}")
    return np.array(out_times)

# 5. Check if the issue is in the weight update order
def analyze_weight_updates(X, Y, W1, W2):
    """
    Check if weight updates are consistent across samples
    """
    print("=== Weight Update Analysis ===")
    
    # Check gradients for first two samples
    for i, (xi, yi) in enumerate(zip(X[:2], Y[:2])):
        print(f"\nSample {i}: input={xi}, target={yi}")
        
        # Forward pass
        h_times = layer_forward(xi, W1, 1)
        o_times = layer_forward(h_times, W2, 2)
        
        print(f"  Hidden times: {h_times}")
        print(f"  Output times: {o_times}")
        
        # Check if hidden times are too similar
        h_variance = np.var(h_times)
        o_variance = np.var(o_times)
        
        print(f"  Hidden variance: {h_variance:.6f}")
        print(f"  Output variance: {o_variance:.6f}")
        
        if h_variance < 0.001:
            print("  WARNING: Hidden layer outputs are too similar!")
        if o_variance < 0.001:
            print("  WARNING: Output layer outputs are too similar!")

# 6. The corrected training function (keeping your timing structure)
def train_snn_backprop_corrected(
    X, Y,
    W1_init, W2_init,
    epochs=50, lr=0.05,  # Reduced learning rate
    max_grad=10.0,
    w_min=-10.0, w_max=10.0,
    non_target_time=2.05,  # Keep your original range
    λ=0.5
):
    """
    Corrected version that preserves your layer timing structure
    """
    W1 = W1_init.copy()
    W2 = W2_init.copy()
    
    # Add momentum
    beta = 0.9
    vW1 = np.zeros_like(W1)
    vW2 = np.zeros_like(W2)
    
    layer1_idx, layer2_idx = 1, 2
    N = len(X)
    
    for ep in range(epochs):
        acc_dW1 = np.zeros_like(W1)
        acc_dW2 = np.zeros_like(W2)
        epoch_loss = 0.0
        
        for xi, yi in zip(X, Y):
            # Forward pass (unchanged)
            h_times = layer_forward(xi, W1, layer1_idx)
            o_times = layer_forward(h_times, W2, layer2_idx)
            
            # Loss calculation (unchanged - your structure is correct)
            target_idx = np.argmax(yi)
            L_target = 0.5 * (o_times[target_idx] - yi[target_idx])**2
            non_ids = [j for j in range(len(o_times)) if j != target_idx]
            L_non = 0.5 * λ * sum([(o_times[j] - non_target_time)**2 for j in non_ids])
            L = L_target + L_non
            epoch_loss += L
            
            # Gradient computation (your logic is correct, just need numerical stability)
            delta_o = np.zeros_like(o_times)
            delta_o[target_idx] = (o_times[target_idx] - yi[target_idx])
            for j in non_ids:
                delta_o[j] = λ * (o_times[j] - non_target_time)
            
            # Rest of gradient computation is correct...
            # The issue is in the timing functions, not the structure
            
            # [Include your original gradient computation here with fixed timing functions]
            
        # Check for vanishing gradients
        grad_norm_1 = np.linalg.norm(acc_dW1)
        grad_norm_2 = np.linalg.norm(acc_dW2)
        
        if grad_norm_1 < 1e-10 or grad_norm_2 < 1e-10:
            print(f"  WARNING: Very small gradients - norm1={grad_norm_1:.2e}, norm2={grad_norm_2:.2e}")
        
        # [Rest of your training loop...]
        
        print(f"Epoch {ep+1}: loss={epoch_loss/N:.4f}, grad_norms=({grad_norm_1:.4f}, {grad_norm_2:.4f})")
    
    return W1, W2

# 7. Key insight: Your problem might be here
def check_spike_timing_behavior():
    """
    Check if your spike_timing function is producing reasonable gradients
    """
    print("=== Spike Timing Function Analysis ===")
    
    # Test with typical values that would occur in your network
    test_weights = np.array([-2.0, -1.0, -0.5, 0.0, 0.5, 1.0, 2.0])
    test_times = np.array([0.1, 0.3, 0.5, 0.7, 0.9])  # global_clock % 1 values
    
    print("Weight\tTime\tSpike_timing\tDerivative")
    for w in test_weights:
        for t in test_times:
            st = spike_timing(w, t, 1, 0, 1)
            dt = d_spike_timing_dw(w, t, 1, 0, 1)
            print(f"{w:.1f}\t{t:.1f}\t{st:.4f}\t\t{dt:.4f}")
    
    # Look for patterns where derivatives are always very small

# Usage:
if __name__ == "__main__":
    # Your original data (this is correct)
    x0 = np.array([0.9, 0.7, 0.3, 0.4])
    x1 = np.array([0.6, 0.7, 0.8, 0.9])
    X = [x0 if i % 2 == 0 else x1 for i in range(8)]
    
    # Your original targets (this is correct for your architecture)
    y0 = np.array([2.95, 2.0, 2.0])
    y1 = np.array([2.0, 2.0, 2.95])
    Y = [y0 if i % 2 == 0 else y1 for i in range(8)]
    
    # Load your existing weights
    W1_0 = np.load("W1.npy")
    W2_0 = np.load("W2.npy")
    
    # Run diagnostics
    check_spike_timing_behavior()
    debug_timing_inputs(W1_0, W2_0, X)
    analyze_weight_updates(X, Y, W1_0, W2_0)

=== Spike Timing Function Analysis ===
Weight	Time	Spike_timing	Derivative
-2.0	0.1	0.0000		0.1171
-2.0	0.3	0.0000		0.5095
-2.0	0.5	0.0000		1.3863
-2.0	0.7	0.0000		4.0132
-2.0	0.9	0.0000		10.0000
-1.0	0.1	0.0000		0.1054
-1.0	0.3	0.0000		0.3567
-1.0	0.5	0.0000		0.6931
-1.0	0.7	0.0000		1.2040
-1.0	0.9	0.0000		2.3026
-0.5	0.1	0.0513		0.1000
-0.5	0.3	0.1633		0.2984
-0.5	0.5	0.2929		0.4901
-0.5	0.7	0.4523		0.6594
-0.5	0.9	0.6838		0.7281
0.0	0.1	0.1000		0.2303
0.0	0.3	0.3000		0.3612
0.0	0.5	0.5000		0.3466
0.0	0.7	0.7000		0.2497
0.0	0.9	0.9000		0.0948
0.5	0.1	0.3162		0.7281
0.5	0.3	0.5477		0.6594
0.5	0.5	0.7071		0.4901
0.5	0.7	0.8367		0.2984
0.5	0.9	0.9487		0.1000
1.0	0.1	1.0000		2.3026
1.0	0.3	1.0000		1.2040
1.0	0.5	1.0000		0.6931
1.0	0.7	1.0000		0.3567
1.0	0.9	1.0000		0.1054
2.0	0.1	1.0000		10.0000
2.0	0.3	1.0000		4.0132
2.0	0.5	1.0000		1.3863
2.0	0.7	1.0000		0.5095
2.0	0.9	1.0000		0.1171
=== Debug Timing Function Inputs ===
Hidden layer spike times: [1.43008574 1.43948677 1.51182876 1.4671

In [None]:
# Analysis of Hidden Layer Collapse in Your SNN

# The Problem: Why all hidden neurons spike at similar times
"""
In your layer_forward function:

scheduled_time = (sum/sr + layer)*ms

For layer 1:
- sum/sr is computed from: sum += spike_timing(w, global_clock, layer, sum, sr)
- If all weights are small and similar, sum/sr will be similar for all neurons
- Result: scheduled_time ≈ (similar_value + 1)*ms for all hidden neurons

This creates a "hidden layer collapse" where all neurons do the same thing.
"""

# Let's trace through what happens:
def analyze_hidden_collapse():
    """
    Analyze why all hidden neurons produce similar spike times
    """
    print("=== Hidden Layer Collapse Analysis ===")
    
    # Your input and typical weights
    x0 = np.array([0.9, 0.7, 0.3, 0.4])
    bias = 0.0
    aug_x0 = np.concatenate([x0, [bias]])
    
    print(f"Input times: {aug_x0}")
    
    # Let's see what happens with your actual weights
    W1 = np.load("W1.npy")
    print(f"W1 first few columns:\n{W1[:, :3]}")
    
    # For each hidden neuron, calculate what sum/sr would be
    print("\nHidden neuron analysis:")
    for j in range(3):  # First 3 hidden neurons
        weights = W1[:, j]
        print(f"\nHidden neuron {j}:")
        print(f"  Weights: {weights}")
        
        # Simulate the sum computation
        total_sum = 0
        for i, (input_time, weight) in enumerate(zip(aug_x0, weights)):
            # What spike_timing returns for this input
            st_value = spike_timing(weight, input_time, 1, 0, 1)
            print(f"    Input {i}: time={input_time:.3f}, w={weight:.3f}, spike_timing={st_value:.4f}")
            total_sum += st_value
        
        # Final scheduled time
        sr = len(weights)  # Number of synapses
        scheduled_time = (total_sum/sr + 1)  # +1 for layer
        print(f"  Total sum: {total_sum:.4f}, sum/sr: {total_sum/sr:.4f}")
        print(f"  Scheduled time: {scheduled_time:.4f}")

# Solutions to fix hidden layer collapse:

# Solution 1: Better weight initialization with more diversity
def initialize_diverse_weights():
    """
    Initialize weights to encourage diversity in hidden layer
    """
    np.random.seed(42)  # For reproducibility
    
    # Create weights with intentional diversity
    W1 = np.random.randn(5, 10) * 0.5  # Larger initial variance
    
    # Add some structure to encourage different behaviors
    # Some neurons prefer early spikes, others late spikes
    for j in range(10):
        if j < 3:  # Early spike neurons
            W1[:, j] += 0.5  # Positive bias
        elif j >= 7:  # Late spike neurons  
            W1[:, j] -= 0.5  # Negative bias
        # Middle neurons stay random
    
    W2 = np.random.randn(11, 3) * 0.3
    
    return W1, W2

# Solution 2: Add explicit diversity regularization
def train_with_diversity_regularization(X, Y, W1_init, W2_init, epochs=50):
    """
    Training with explicit diversity loss to prevent collapse
    """
    W1 = W1_init.copy()
    W2 = W2_init.copy()
    
    # Momentum
    beta = 0.9
    vW1 = np.zeros_like(W1)
    vW2 = np.zeros_like(W2)
    
    lr = 0.1
    diversity_weight = 0.1  # Weight for diversity loss
    
    for ep in range(epochs):
        epoch_loss = 0.0
        acc_dW1 = np.zeros_like(W1)
        acc_dW2 = np.zeros_like(W2)
        
        for xi, yi in zip(X, Y):
            # Forward pass
            h_times = layer_forward(xi, W1, 1)
            o_times = layer_forward(h_times, W2, 2)
            
            # Regular classification loss
            target_idx = np.argmax(yi)
            L_target = 0.5 * (o_times[target_idx] - yi[target_idx])**2
            non_ids = [j for j in range(len(o_times)) if j != target_idx]
            L_non = 0.5 * 0.5 * sum([(o_times[j] - 2.05)**2 for j in non_ids])
            
            # DIVERSITY LOSS: Encourage hidden neurons to have different spike times
            h_variance = np.var(h_times)
            target_variance = 0.01  # We want at least this much variance
            L_diversity = diversity_weight * max(0, target_variance - h_variance)**2
            
            L = L_target + L_non + L_diversity
            epoch_loss += L
            
            # Regular gradients for classification
            delta_o = np.zeros_like(o_times)
            delta_o[target_idx] = (o_times[target_idx] - yi[target_idx])
            for j in non_ids:
                delta_o[j] = 0.5 * (o_times[j] - 2.05)
            
            # Diversity gradients for hidden layer
            if h_variance < target_variance:
                # Encourage diversity by pushing hidden times away from mean
                h_mean = np.mean(h_times)
                delta_h_diversity = 2 * diversity_weight * (target_variance - h_variance) * (h_times - h_mean) / len(h_times)
            else:
                delta_h_diversity = np.zeros_like(h_times)
            
            # Backprop through W2 (your original code)
            aug_h = np.concatenate((h_times, [0.0]))
            dW2 = np.zeros_like(W2)
            for k in range(W2.shape[0]):
                for j in range(W2.shape[1]):
                    dW2[k, j] = delta_o[j] * d_spike_timing_dw(W2[k, j], aug_h[k], 2, 0, 1)
            
            # Backprop into hidden layer (classification + diversity)
            delta_h = np.zeros_like(h_times)
            for k in range(len(h_times)):
                # Classification gradient
                for j in range(W2.shape[1]):
                    dt_dw_output = d_spike_timing_dw(W2[k, j], aug_h[k], 2, 0, 1)
                    delta_h[k] += delta_o[j] * dt_dw_output
                
                # Add diversity gradient
                delta_h[k] += delta_h_diversity[k]
            
            # Gradients for W1
            aug_xi = np.concatenate((xi, [0.0]))
            dW1 = np.zeros_like(W1)
            for i in range(W1.shape[0]):
                for k in range(W1.shape[1]):
                    dW1[i, k] = delta_h[k] * d_spike_timing_dw(W1[i, k], aug_xi[i], 1, 0, 1)
            
            acc_dW1 += dW1
            acc_dW2 += dW2
        
        # Update weights
        acc_dW1 /= len(X)
        acc_dW2 /= len(X)
        
        # Clip gradients
        acc_dW1 = np.clip(acc_dW1, -10, 10)
        acc_dW2 = np.clip(acc_dW2, -10, 10)
        
        # Momentum update
        vW1 = beta * vW1 + (1 - beta) * acc_dW1
        vW2 = beta * vW2 + (1 - beta) * acc_dW2
        
        # Apply updates
        W1 = np.clip(W1 - lr * vW1, -10, 10)
        W2 = np.clip(W2 - lr * vW2, -10, 10)
        
        # Monitor progress
        if ep % 5 == 0:
            # Check current diversity
            h_times_0 = layer_forward(X[0], W1, 1)
            h_times_1 = layer_forward(X[1], W1, 1)
            
            print(f"Epoch {ep+1}: loss={epoch_loss/len(X):.4f}")
            print(f"  Hidden variance: {np.var(h_times_0):.6f} (sample 0), {np.var(h_times_1):.6f} (sample 1)")
            print(f"  Hidden range: [{np.min(h_times_0):.3f}, {np.max(h_times_0):.3f}]")
    
    return W1, W2

# Solution 3: Architectural fix - Add noise or different activation patterns
def layer_forward_with_noise(inputs, W, layer_idx, noise_std=0.01):
    """
    Add small amount of noise to break symmetry
    """
    # Your original layer_forward code here, but add:
    # scheduled_time = (sum/sr + layer)*ms + noise
    
    # This is a band-aid solution - better to fix the root cause
    pass

# Solution 4: Different initialization strategy
def initialize_with_opposing_weights():
    """
    Initialize some weights to be strongly positive, others strongly negative
    """
    W1 = np.random.randn(5, 10) * 0.3
    
    # Make some neurons strongly favor early inputs
    W1[:2, :5] += 1.0  # First 2 inputs get positive boost for first 5 neurons
    
    # Make some neurons strongly favor late inputs  
    W1[2:4, 5:] += 1.0  # Last 2 inputs get positive boost for last 5 neurons
    
    W2 = np.random.randn(11, 3) * 0.3
    
    return W1, W2

# The key insight: Your architecture is sound, but you need initial weight diversity
def main_solution():
    """
    The main solution: Start with diverse weights and add diversity regularization
    """
    print("=== Solution: Diverse Initialization + Diversity Loss ===")
    
    # Your data
    x0 = np.array([0.9, 0.7, 0.3, 0.4])
    x1 = np.array([0.6, 0.7, 0.8, 0.9])
    X = [x0 if i % 2 == 0 else x1 for i in range(8)]
    
    y0 = np.array([2.95, 2.0, 2.0])
    y1 = np.array([2.0, 2.0, 2.95])
    Y = [y0 if i % 2 == 0 else y1 for i in range(8)]
    
    # Use diverse initialization instead of collapsed weights
    W1_init, W2_init = initialize_diverse_weights()
    
    # Check initial diversity
    h_times_0 = layer_forward(X[0], W1_init, 1)
    h_times_1 = layer_forward(X[1], W1_init, 1)
    
    print(f"Initial hidden diversity:")
    print(f"  Sample 0 variance: {np.var(h_times_0):.6f}")
    print(f"  Sample 1 variance: {np.var(h_times_1):.6f}")
    print(f"  Sample 0 range: [{np.min(h_times_0):.3f}, {np.max(h_times_0):.3f}]")
    print(f"  Sample 1 range: [{np.min(h_times_1):.3f}, {np.max(h_times_1):.3f}]")
    
    if np.var(h_times_0) > 0.005 and np.var(h_times_1) > 0.005:
        print("✓ Good initial diversity!")
        
        # Train with diversity regularization
        W1_final, W2_final = train_with_diversity_regularization(X, Y, W1_init, W2_init, epochs=30)
        
        # Test final performance
        print("\n=== Final Test ===")
        for i, (xi, yi) in enumerate(zip(X[:2], Y[:2])):
            h_times = layer_forward(xi, W1_final, 1)
            o_times = layer_forward(h_times, W2_final, 2)
            
            pred_class = np.argmax(o_times)
            true_class = np.argmax(yi)
            
            print(f"Sample {i}: pred={pred_class}, true={true_class}")
            print(f"  Hidden variance: {np.var(h_times):.6f}")
            print(f"  Output times: {o_times}")
    else:
        print("✗ Still collapsed - need different initialization")

if __name__ == "__main__":
    # Run the analysis
    analyze_hidden_collapse()
    
    # Try the solution
    main_solution()

=== Hidden Layer Collapse Analysis ===
Input times: [0.9 0.7 0.3 0.4 0. ]
W1 first few columns:
[[-2.24802890e-01 -5.02760943e-05 -3.02786849e-02]
 [-1.11390344e-01 -1.80939746e-01 -1.34243925e-01]
 [-2.41631112e-01 -1.56714608e-01  3.69612932e-01]
 [ 5.13877625e-02  1.55790838e-02  3.39575082e-01]
 [-1.31178265e-01  6.46637541e-03 -9.75420633e-02]]

Hidden neuron analysis:

Hidden neuron 0:
  Weights: [-0.22480289 -0.11139034 -0.24163111  0.05138776 -0.13117826]
    Input 0: time=0.900, w=-0.225, spike_timing=0.8322
    Input 1: time=0.700, w=-0.111, spike_timing=0.6569
    Input 2: time=0.300, w=-0.242, spike_timing=0.2370
    Input 3: time=0.400, w=0.051, spike_timing=0.4193
    Input 4: time=0.000, w=-0.131, spike_timing=0.0000
  Total sum: 2.1454, sum/sr: 0.4291
  Scheduled time: 1.4291

Hidden neuron 1:
  Weights: [-5.02760943e-05 -1.80939746e-01 -1.56714608e-01  1.55790838e-02
  6.46637541e-03]
    Input 0: time=0.900, w=-0.000, spike_timing=0.9000
    Input 1: time=0.700, w=-0.

In [12]:
from brian2 import *
import numpy as np
import logging
import warnings



from brian2 import prefs, set_device

# Tell Brian2 to use the Cython code generator:
prefs.codegen.target = 'cython'

# Optionally compile but keep Python interface:
set_device('runtime')  # default; compiles operations to .so but stays in Python process





# suppress overflow warnings
warnings.filterwarnings('ignore', category=RuntimeWarning)
numpy.seterr(over='ignore', under='ignore')
logging.getLogger('brian2').setLevel(logging.ERROR)

# ----------------------------------------------------------------------------
# Spike timing and derivative

start_scope()
defaultclock.dt = 0.001*ms

@implementation('numpy', discard_units=True)
@check_units(w=1, global_clock=1, layer=1, sum=1, spikes_received=1, result=1)
def spike_timing(w, global_clock, layer, sum, spikes_received):
    x = global_clock % 1
    return(np.tanh(x*w))

@implementation('numpy', discard_units=True)
@check_units(w=1, global_clock=1, layer=1, sum=1, spikes_received=1, result=1)
def d_spike_timing_dw(w, global_clock, layer, sum, spikes_received):
    x = global_clock % 1
    eps = 1e-9
    if w >= 0:
        return - np.power(x, (1 - w), where=(x>0), out=np.zeros_like(x)) * np.log(x + eps)
    else:
        return - np.power((1 - x), (1 + w), where=(x<1), out=np.ones_like(x)) * np.log(1 - x + eps)

# ----------------------------------------------------------------------------
# Forward pass: 4->10->3 using two-stage mini_urd

def layer_forward(inputs, W, layer_idx):
    """
    inputs: array of spike times (ms) from previous layer (shape: n_in,)
    W: weight matrix shape (n_in+1, n_out)  ← note the extra bias row
    layer_idx: integer layer number
    returns: array of output spike times (ms)
    """
    # 1) augment inputs with bias spike @ t=0
    bias_time = 0.0
    aug_inputs = np.concatenate((inputs, [bias_time]))  # shape (n_in+1,)

    n_in_plus_bias, n_out = W.shape
    assert aug_inputs.size == n_in_plus_bias

    out_times = []
    for j in range(n_out):
        start_scope()
        defaultclock.dt = 0.001*ms

        # single post‐synaptic neuron
        G = NeuronGroup(1, '''
            v : 1
            sum : 1
            sr : 1
            scheduled_time : second
            global_clock : 1
        ''', threshold='v>1', reset='v=0', method='exact')

        # init
        G.v = G.sum = G.sr = 0
        G.global_clock = 0
        G.scheduled_time = 1e9*second

        # stim: now includes bias spike at t=0
        stim = SpikeGeneratorGroup(n_in_plus_bias,
                                   indices=list(range(n_in_plus_bias)),
                                   times=aug_inputs*ms)

        S = Synapses(stim, G, '''w:1
            layer:1''',
            on_pre='''
            sr += 1
            sum += spike_timing(w, global_clock, layer, sum, sr)
            scheduled_time = (sum/sr + layer)*ms
        ''')
        S.connect(True)
        S.w = W[:, j]
        S.layer = layer_idx

        G.run_regularly('''
            v = int(abs(t - scheduled_time) < 0.001*ms) * 1.2
            global_clock += 0.001
        ''', dt=0.001*ms)

        mon = SpikeMonitor(G)
        run(5*ms)

        ts = mon.spike_trains()[0]
        t0 = float(ts[0]/ms) if len(ts)>0 else float(5.0)
        out_times.append(t0)

    return np.array(out_times)


# ----------------------------------------------------------------------------
# Training loop with backprop for 4-10-3
def train_snn_backprop(
    X, Y,                # lists of input arrays (4,) and target (3,)
    W1_init, W2_init,
    epochs=10, lr=0.1,
    max_grad=20.0, w_min=-20.0, w_max=20.0,
    non_target_time=2.05,
    λ=0.5                # non-target penalty weight
):
    """
    Trains a 4→10→3 spiking network with:
      • batched gradient updates
      • boosted hidden-layer learning rate
      • separate gradient clipping per layer
      • classical momentum smoothing
    """
    # copy over the weights given 
    W1 = W1_init.copy()      # shape (5,10) including bias row
    W2 = W2_init.copy()      # shape (11,3) including bias row

    # Momentum buffers
    beta = 0.9
    vW1 = np.zeros_like(W1)  # changes how quckly it updates as we slow down or speed up   
    vW2 = np.zeros_like(W2)

    layer1_idx, layer2_idx = 1, 2 
    N = len(X)

    for ep in range(epochs):
        # Accumulators - stores the collecting gradients for each epoch
        acc_dW1 = np.zeros_like(W1)
        acc_dW2 = np.zeros_like(W2) 
        epoch_loss = 0.0

        for xi, yi in zip(X, Y): # iterates in pairs at same time x1 and y1 x2 & y2 ect.
            # — Forward pass —
            h_times = layer_forward(xi, W1, layer1_idx)
            #print(h_times)
            o_times = layer_forward(h_times, W2, layer2_idx)  # this is sending in input array and outputs array for each layer


            # — Separation loss — # calcs to loss for a single sample will not touch for now 
            target_idx = np.argmax(yi)
            L_target = 0.5 * (o_times[target_idx] - yi[target_idx])**2
            non_ids = [j for j in range(len(o_times)) if j != target_idx]
            L_non = 0.5 * λ * sum([(o_times[j] - non_target_time)**2 for j in non_ids])
            L = L_target + L_non
            epoch_loss += L



            # — Gradients for W2 —
            delta_o = np.zeros_like(o_times)  # a 3 element array for the 3 outputs
            delta_o[target_idx] = (o_times[target_idx] - yi[target_idx])
            for j in non_ids:
                delta_o[j] = λ * (o_times[j] - non_target_time)

            print("otimes ", o_times)
            print("delta_o ",delta_o)




            aug_h = np.concatenate((h_times, [0.0]))
            dW2 = np.zeros_like(W2)
            for k in range(W2.shape[0]):
                for j in range(W2.shape[1]):
                    dW2[k, j] = delta_o[j] * d_spike_timing_dw(
                        W2[k, j], aug_h[k], layer2_idx, 0, 1)

            # — Backprop into hidden & gradients for W1 —
            delta_h = np.zeros_like(h_times)
            for k in range(len(h_times)):
                for j in range(W2.shape[1]):
                    dt_dw_output = d_spike_timing_dw(W2[k, j], aug_h[k], layer2_idx, 0, 1)
                    delta_h[k] += delta_o[j] * dt_dw_output  # Remove the W2[k,j] multiplication

            aug_xi = np.concatenate((xi, [0.0]))
            dW1 = np.zeros_like(W1)
            for i in range(W1.shape[0]):
                for k in range(W1.shape[1]):
                    dW1[i, k] = delta_h[k] * d_spike_timing_dw(
                        W1[i, k], aug_xi[i], layer1_idx, 0, 1)

            # — Accumulate —
            acc_dW1 += dW1
            acc_dW2 += dW2

        # — Average & clip gradients —
        acc_dW1 /= N
        acc_dW2 /= N

        # Boost hidden-layer rate
        lr1 = lr

        # Separate clipping thresholds
        g1 = np.clip(acc_dW1, -max_grad, max_grad)
        g2 = np.clip(acc_dW2, -max_grad,   max_grad)

        # — Momentum updates —
        vW1 = beta * vW1 + (1 - beta) * g1
        vW2 = beta * vW2 + (1 - beta) * g2

        # — Apply weight updates & clamp —
        W1 = W1 - lr * vW1
        W2 = W2 - lr * vW2

        print(f"Epoch {ep+1}/{epochs} — avg loss={epoch_loss/N:.4f}")
        print(f"             ‖W1‖={np.linalg.norm(W1):.3f}, ‖W2‖={np.linalg.norm(W2):.3f}\n")

    return W1, W2

            


if __name__ == "__main__":
    # example usage with fixed input/target pairs
    # 4 inputs per sample, constant across 8 samples
    x0 = np.array([0.9, 0.7, 0.3, 0.4])
    x1 = np.array([0.6, 0.7, 0.8, 0.9])
    X = [x0 if i % 2 == 0 else x1 for i in range(8)]
    # 3-targets (network outputs 3 values): use desired spike times [2.1, 2.0, 1.0]
    y0 = np.array([2.95, 2.0, 2.0])
    y1 = np.array([2.0, 2.0, 2.95])
    Y = [y0 if i % 2 == 0 else y1 for i in range(8)]
  

    
    # W1_0 = np.random.randn(4+1, 10) * 0.1  # +1 for bias
    # W2_0 = np.random.randn(10+1, 3) * 0.1  # +1 for bias 


    W1_0 = np.load("W1.npy")
    W2_0 = np.load("W2.npy")


    # train
    W1_tr, W2_tr = train_snn_backprop(X, Y, W1_0, W2_0,
                                      epochs=10, lr=0.2)
    print("Trained W1:", W1_tr)
    print("Trained W2:", W2_tr) 
    print("Hidden times for x0:", layer_forward(x0, W1_tr, 1))
    print("Hidden times for x1:", layer_forward(x1, W1_tr, 1))

    # # ── Now test on the same two patterns ──
    # print("\n=== Test predictions ===")
    # for xi, yi in zip(X, Y):
    #     # call layer_forward(positionally) rather than with layer1_idx=
    #     h_times = layer_forward(xi, W1_tr, 1)
    #     o_times = layer_forward(h_times, W2_tr, 2)

    #     pred_class = np.argmax(o_times)  
    #     true_class = np.argmax(yi)

    #     print(f"Input: {xi}")
    #     print(f" Spike times: {o_times}")
    #     print(f" Predicted class: {pred_class}, True class: {true_class}\n")

# np.save('W1.npy', W1_tr)
# np.save('W2.npy', W2_tr)
# print("weights saved")



otimes  [2.261 1.536 2.202]
delta_o  [-0.689 -0.257  0.076]
otimes  [2.204 1.595 2.172]
delta_o  [ 0.077  -0.2275 -0.778 ]


Exception ignored in: <bound method InstanceTrackerSet.remove of InstanceTrackerSet({<weakref at 0x0000029EE47F2480; to 'CythonCodeObject' at 0x0000029EE46FB050>, <weakref at 0x0000029EE47F3D30; to 'CythonCodeObject' at 0x0000029EE47F6050>, <weakref at 0x0000029EE437CDB0; dead>, <weakref at 0x0000029EE47E8680; to 'Clock' at 0x0000029EE4532090>, <weakref at 0x0000029EE437C860; dead>, <weakref at 0x0000029EE47E8AE0; to 'CodeRunner' at 0x0000029EE47F6110>, <weakref at 0x0000029EE4618590; to 'Synapses' at 0x0000029EE4614110>, <weakref at 0x0000029EE4558310; to 'CythonCodeObject' at 0x0000029EE48D7150>, <weakref at 0x0000029EE47EBC90; to 'CythonCodeObject' at 0x0000029EE4670190>, <weakref at 0x0000029EE47F2570; to 'Resetter' at 0x0000029EE3CF9250>, <weakref at 0x0000029EE437D800; dead>, <weakref at 0x0000029EE47F25C0; to 'StateUpdater' at 0x0000029EE46F72D0>, <weakref at 0x0000029EE48054E0; to 'SynapticPathway' at 0x0000029EE47E2310>, <weakref at 0x0000029EE437C720; to 'Clock' at 0x0000029E

otimes  [2.261 1.536 2.202]
delta_o  [-0.689 -0.257  0.076]
otimes  [2.204 1.595 2.172]
delta_o  [ 0.077  -0.2275 -0.778 ]
otimes  [2.261 1.536 2.202]
delta_o  [-0.689 -0.257  0.076]
otimes  [2.204 1.595 2.172]
delta_o  [ 0.077  -0.2275 -0.778 ]
otimes  [2.261 1.536 2.202]
delta_o  [-0.689 -0.257  0.076]
otimes  [2.204 1.595 2.172]
delta_o  [ 0.077  -0.2275 -0.778 ]
Epoch 1/1 — avg loss=0.3348
             ‖W1‖=1.306, ‖W2‖=3.945

Trained W1: [[-2.24802890e-01 -5.02760943e-05 -3.02786849e-02 -7.89465047e-02
  -8.32850990e-02 -4.77545338e-02 -8.20297270e-02 -8.11190851e-02
  -1.19828297e-01 -1.44893684e-01]
 [-1.11390344e-01 -1.80939746e-01 -1.34243925e-01 -9.72715411e-02
  -1.90581208e-01 -2.27950846e-02 -1.44948526e-01 -1.37085700e-01
   1.53694720e-01  4.46186585e-02]
 [-2.41631112e-01 -1.56714608e-01  3.69612932e-01 -1.93654213e-01
  -2.64931721e-01  2.68245103e-01 -4.42425639e-01  1.83656054e-02
  -2.24950602e-01  2.78443194e-01]
 [ 5.13877625e-02  1.55790838e-02  3.39575082e-01  3.

In [None]:
from brian2 import *
import numpy as np
import logging
import warnings



from brian2 import prefs, set_device

# Tell Brian2 to use the Cython code generator:
prefs.codegen.target = 'cython'

# Optionally compile but keep Python interface:
set_device('runtime')  # default; compiles operations to .so but stays in Python process





# suppress overflow warnings
warnings.filterwarnings('ignore', category=RuntimeWarning)
numpy.seterr(over='ignore', under='ignore')
logging.getLogger('brian2').setLevel(logging.ERROR)

# ----------------------------------------------------------------------------
# Spike timing and derivative

start_scope()
defaultclock.dt = 0.001*ms

@implementation('numpy', discard_units=True)
@check_units(w=1, global_clock=1, layer=1, sum=1, spikes_received=1, result=1)
def spike_timing(w, global_clock, layer, sum, spikes_received):
    x = global_clock % 1
    if w >= 0:
        return np.power(x, (1 - w), where=(x>0), out=np.zeros_like(x))
    else:
        return 1 - np.power((1 - x), (1 + w), where=(x<1), out=np.ones_like(x))

@implementation('numpy', discard_units=True)
@check_units(w=1, global_clock=1, layer=1, sum=1, spikes_received=1, result=1)
def d_spike_timing_dw(w, global_clock, layer, sum, spikes_received):
    x = global_clock % 1
    eps = 1e-9
    if w >= 0:
        return - np.power(x, (1 - w), where=(x>0), out=np.zeros_like(x)) * np.log(x + eps)
    else:
        return - np.power((1 - x), (1 + w), where=(x<1), out=np.ones_like(x)) * np.log(1 - x + eps)

# ----------------------------------------------------------------------------
# Forward pass: 4->10->3 using two-stage mini_urd

def layer_forward(inputs, W, layer_idx):
    """
    inputs: array of spike times (ms) from previous layer (shape: n_in,)
    W: weight matrix shape (n_in+1, n_out)  ← note the extra bias row
    layer_idx: integer layer number
    returns: array of output spike times (ms)
    """
    # 1) augment inputs with bias spike @ t=0
    bias_time = 0.0
    aug_inputs = np.concatenate((inputs, [bias_time]))  # shape (n_in+1,)

    n_in_plus_bias, n_out = W.shape
    assert aug_inputs.size == n_in_plus_bias

    out_times = []
    for j in range(n_out):
        start_scope()
        defaultclock.dt = 0.001*ms

        # single post‐synaptic neuron
        G = NeuronGroup(1, '''
            v : 1
            sum : 1
            sr : 1
            scheduled_time : second
            global_clock : 1
        ''', threshold='v>1', reset='v=0', method='exact')

        # init
        G.v = G.sum = G.sr = 0
        G.global_clock = 0
        G.scheduled_time = 1e9*second

        # stim: now includes bias spike at t=0
        stim = SpikeGeneratorGroup(n_in_plus_bias,
                                   indices=list(range(n_in_plus_bias)),
                                   times=aug_inputs*ms)

        S = Synapses(stim, G, '''w:1
            layer:1''',
            on_pre='''
            sr += 1
            sum += spike_timing(w, global_clock, layer, sum, sr)
            scheduled_time = (sum/sr + layer)*ms
        ''')
        S.connect(True)
        S.w = W[:, j]
        S.layer = layer_idx

        G.run_regularly('''
            v = int(abs(t - scheduled_time) < 0.001*ms) * 1.2
            global_clock += 0.001
        ''', dt=0.001*ms)

        mon = SpikeMonitor(G)
        run(5*ms)

        ts = mon.spike_trains()[0]
        t0 = float(ts[0]/ms) if len(ts)>0 else float(5.0)
        out_times.append(t0)

    return np.array(out_times)


# ----------------------------------------------------------------------------
# Training loop with backprop for 4-10-3
def train_snn_backprop(
    X, Y,                # lists of input arrays (4,) and target (3,)
    W1_init, W2_init,
    epochs=10, lr=0.1,
    max_grad=20.0, w_min=-20.0, w_max=20.0,
    non_target_time=2.05,
    λ=0.5                # non-target penalty weight
):
    """
    Trains a 4→10→3 spiking network with:
      • batched gradient updates
      • boosted hidden-layer learning rate
      • separate gradient clipping per layer
      • classical momentum smoothing
    """
    # copy over the weights given 
    W1 = W1_init.copy()      # shape (5,10) including bias row
    W2 = W2_init.copy()      # shape (11,3) including bias row

    # Momentum buffers
    beta = 0.9
    vW1 = np.zeros_like(W1)  # changes how quckly it updates as we slow down or speed up   
    vW2 = np.zeros_like(W2)

    layer1_idx, layer2_idx = 1, 2 
    N = len(X)

    for ep in range(epochs):
        # Accumulators - stores the collecting gradients for each epoch
        acc_dW1 = np.zeros_like(W1)
        acc_dW2 = np.zeros_like(W2) 
        epoch_loss = 0.0

        for xi, yi in zip(X, Y): # iterates in pairs at same time x1 and y1 x2 & y2 ect.
            # — Forward pass —
            h_times = layer_forward(xi, W1, layer1_idx)
            #print(h_times)
            o_times = layer_forward(h_times, W2, layer2_idx)  # this is sending in input array and outputs array for each layer


            # — Separation loss — # calcs to loss for a single sample will not touch for now 
            target_idx = np.argmax(yi)
            L_target = 0.5 * (o_times[target_idx] - yi[target_idx])**2
            non_ids = [j for j in range(len(o_times)) if j != target_idx]
            L_non = 0.5 * λ * sum([(o_times[j] - non_target_time)**2 for j in non_ids])
            L = L_target + L_non
            epoch_loss += L



            # — Gradients for W2 —
            delta_o = np.zeros_like(o_times)  # a 3 element array for the 3 outputs
            delta_o[target_idx] = (o_times[target_idx] - yi[target_idx])
            for j in non_ids:
                delta_o[j] = λ * (o_times[j] - non_target_time)

            print("otimes ", o_times)
            print("delta_o ",delta_o)




            aug_h = np.concatenate((h_times, [0.0]))
            dW2 = np.zeros_like(W2)
            for k in range(W2.shape[0]):
                for j in range(W2.shape[1]):
                    dW2[k, j] = delta_o[j] * d_spike_timing_dw(
                        W2[k, j], aug_h[k], layer2_idx, 0, 1)

            # — Backprop into hidden & gradients for W1 —
            delta_h = np.zeros_like(h_times)
            for k in range(len(h_times)):
                for j in range(W2.shape[1]):
                    dt_dw_output = d_spike_timing_dw(W2[k, j], aug_h[k], layer2_idx, 0, 1)
                    delta_h[k] += delta_o[j] * dt_dw_output  # Remove the W2[k,j] multiplication

            aug_xi = np.concatenate((xi, [0.0]))
            dW1 = np.zeros_like(W1)
            for i in range(W1.shape[0]):
                for k in range(W1.shape[1]):
                    dW1[i, k] = delta_h[k] * d_spike_timing_dw(
                        W1[i, k], aug_xi[i], layer1_idx, 0, 1)

            # — Accumulate —
            acc_dW1 += dW1
            acc_dW2 += dW2

        # — Average & clip gradients —
        acc_dW1 /= N
        acc_dW2 /= N

        # Boost hidden-layer rate
        lr1 = lr

        # Separate clipping thresholds
        g1 = np.clip(acc_dW1, -max_grad, max_grad)
        g2 = np.clip(acc_dW2, -max_grad,   max_grad)

        # — Momentum updates —
        vW1 = beta * vW1 + (1 - beta) * g1
        vW2 = beta * vW2 + (1 - beta) * g2

        # — Apply weight updates & clamp —
        W1 = np.clip(W1 - lr1 * vW1, w_min, w_max)
        W2 = np.clip(W2 - lr  * vW2, w_min, w_max)

        print(f"Epoch {ep+1}/{epochs} — avg loss={epoch_loss/N:.4f}")
        print(f"             ‖W1‖={np.linalg.norm(W1):.3f}, ‖W2‖={np.linalg.norm(W2):.3f}\n")

    return W1, W2

            


if __name__ == "__main__":
    # example usage with fixed input/target pairs
    # 4 inputs per sample, constant across 8 samples
    x0 = np.array([0.9, 0.7, 0.3, 0.4])
    x1 = np.array([0.6, 0.7, 0.8, 0.9])
    X = [x0 if i % 2 == 0 else x1 for i in range(8)]
    # 3-targets (network outputs 3 values): use desired spike times [2.1, 2.0, 1.0]
    y0 = np.array([2.95, 2.0, 2.0])
    y1 = np.array([2.0, 2.0, 2.95])
    Y = [y0 if i % 2 == 0 else y1 for i in range(8)]
  

    
    # W1_0 = np.random.randn(4+1, 10) * 0.1  # +1 for bias
    # W2_0 = np.random.randn(10+1, 3) * 0.1  # +1 for bias 


    W1_0 = np.load("W1.npy")
    W2_0 = np.load("W2.npy")


    # train
    W1_tr, W2_tr = train_snn_backprop(X, Y, W1_0, W2_0,
                                      epochs=1, lr=0.0)
    print("Trained W1:", W1_tr)
    print("Trained W2:", W2_tr) 
    print("Hidden times for x0:", layer_forward(x0, W1_tr, 1))
    print("Hidden times for x1:", layer_forward(x1, W1_tr, 1))

    # # ── Now test on the same two patterns ──
    # print("\n=== Test predictions ===")
    # for xi, yi in zip(X, Y):
    #     # call layer_forward(positionally) rather than with layer1_idx=
    #     h_times = layer_forward(xi, W1_tr, 1)
    #     o_times = layer_forward(h_times, W2_tr, 2)

    #     pred_class = np.argmax(o_times)  
    #     true_class = np.argmax(yi)

    #     print(f"Input: {xi}")
    #     print(f" Spike times: {o_times}")
    #     print(f" Predicted class: {pred_class}, True class: {true_class}\n")

# np.save('W1.npy', W1_tr)
# np.save('W2.npy', W2_tr)
# print("weights saved")



otimes  [2.666 2.003 2.524]
delta_o  [-0.284  -0.0235  0.237 ]
otimes  [2.726 2.003 2.619]
delta_o  [ 0.338  -0.0235 -0.331 ]
otimes  [2.666 2.003 2.524]
delta_o  [-0.284  -0.0235  0.237 ]
otimes  [2.726 2.003 2.619]
delta_o  [ 0.338  -0.0235 -0.331 ]


KeyboardInterrupt: 

In [None]:
from brian2 import *
import numpy as np
import logging
import warnings



from brian2 import prefs, set_device

# Tell Brian2 to use the Cython code generator:
prefs.codegen.target = 'cython'

# Optionally compile but keep Python interface:
set_device('runtime')  # default; compiles operations to .so but stays in Python process





# suppress overflow warnings
warnings.filterwarnings('ignore', category=RuntimeWarning)
numpy.seterr(over='ignore', under='ignore')
logging.getLogger('brian2').setLevel(logging.ERROR)

# ----------------------------------------------------------------------------
# Spike timing and derivative

start_scope()
defaultclock.dt = 0.001*ms

@implementation('numpy', discard_units=True)
@check_units(w=1, global_clock=1, layer=1, sum=1, spikes_received=1, result=1)
def spike_timing(w, global_clock, layer, sum, spikes_received):
    x = global_clock % 1
    if w >= 0:
        return np.power(x, (1 - w), where=(x>0), out=np.zeros_like(x))
    else:
        return 1 - np.power((1 - x), (1 + w), where=(x<1), out=np.ones_like(x))

@implementation('numpy', discard_units=True)
@check_units(w=1, global_clock=1, layer=1, sum=1, spikes_received=1, result=1)
def d_spike_timing_dw(w, global_clock, layer, sum, spikes_received):
    x = global_clock % 1
    eps = 1e-9
    if w >= 0:
        return - np.power(x, (1 - w), where=(x>0), out=np.zeros_like(x)) * np.log(x + eps)
    else:
        return - np.power((1 - x), (1 + w), where=(x<1), out=np.ones_like(x)) * np.log(1 - x + eps)

# ----------------------------------------------------------------------------
# Forward pass: 4->10->3 using two-stage mini_urd

def layer_forward(inputs, W, layer_idx):
    """
    inputs: array of spike times (ms) from previous layer (shape: n_in,)
    W: weight matrix shape (n_in+1, n_out)  ← note the extra bias row
    layer_idx: integer layer number
    returns: array of output spike times (ms)
    """
    # 1) augment inputs with bias spike @ t=0
    bias_time = 0.0
    aug_inputs = np.concatenate((inputs, [bias_time]))  # shape (n_in+1,)

    n_in_plus_bias, n_out = W.shape
    assert aug_inputs.size == n_in_plus_bias

    out_times = []
    for j in range(n_out):
        start_scope()
        defaultclock.dt = 0.001*ms

        # single post‐synaptic neuron
        G = NeuronGroup(1, '''
            v : 1
            sum : 1
            sr : 1
            scheduled_time : second
            global_clock : 1
        ''', threshold='v>1', reset='v=0', method='exact')

        # init
        G.v = G.sum = G.sr = 0
        G.global_clock = 0
        G.scheduled_time = 1e9*second

        # stim: now includes bias spike at t=0
        stim = SpikeGeneratorGroup(n_in_plus_bias,
                                   indices=list(range(n_in_plus_bias)),
                                   times=aug_inputs*ms)

        S = Synapses(stim, G, '''w:1
            layer:1''',
            on_pre='''
            sr += 1
            sum += spike_timing(w, global_clock, layer, sum, sr)
            scheduled_time = (sum/sr + layer)*ms
        ''')
        S.connect(True)
        S.w = W[:, j]
        S.layer = layer_idx

        G.run_regularly('''
            v = int(abs(t - scheduled_time) < 0.001*ms) * 1.2
            global_clock += 0.001
        ''', dt=0.001*ms)

        mon = SpikeMonitor(G)
        run(5*ms)

        ts = mon.spike_trains()[0]
        t0 = float(ts[0]/ms) if len(ts)>0 else float(5.0)
        out_times.append(t0)

    return np.array(out_times)


# ----------------------------------------------------------------------------
# Training loop with backprop for 4-10-3
def train_snn_backprop(
    X, Y,                # lists of input arrays (4,) and target (3,)
    W1_init, W2_init,
    epochs=10, lr=0.1,
    max_grad=20.0, w_min=-20.0, w_max=20.0,
    non_target_time=2.05,
    λ=0.5                # non-target penalty weight
):
    """
    Trains a 4→10→3 spiking network with:
      • batched gradient updates
      • boosted hidden-layer learning rate
      • separate gradient clipping per layer
      • classical momentum smoothing
    """
    # copy over the weights given 
    W1 = W1_init.copy()      # shape (5,10) including bias row
    W2 = W2_init.copy()      # shape (11,3) including bias row

    # Momentum buffers
    beta = 0.9
    vW1 = np.zeros_like(W1)  # changes how quckly it updates as we slow down or speed up   
    vW2 = np.zeros_like(W2)

    layer1_idx, layer2_idx = 1, 2 
    N = len(X)

    for ep in range(epochs):
        # Accumulators - stores the collecting gradients for each epoch
        acc_dW1 = np.zeros_like(W1)
        acc_dW2 = np.zeros_like(W2) 
        epoch_loss = 0.0

        for xi, yi in zip(X, Y): # iterates in pairs at same time x1 and y1 x2 & y2 ect.
            # — Forward pass —
            h_times = layer_forward(xi, W1, layer1_idx)
            #print(h_times)
            o_times = layer_forward(h_times, W2, layer2_idx)  # this is sending in input array and outputs array for each layer


            # — Separation loss — # calcs to loss for a single sample will not touch for now 
            target_idx = np.argmax(yi)
            L_target = 0.5 * (o_times[target_idx] - yi[target_idx])**2
            non_ids = [j for j in range(len(o_times)) if j != target_idx]
            L_non = 0.5 * λ * sum([(o_times[j] - non_target_time)**2 for j in non_ids])
            L = L_target + L_non
            epoch_loss += L



            # — Gradients for W2 —
            delta_o = np.zeros_like(o_times)  # a 3 element array for the 3 outputs
            delta_o[target_idx] = (o_times[target_idx] - yi[target_idx])
            for j in non_ids:
                delta_o[j] = λ * (o_times[j] - non_target_time)

            print("otimes ", o_times)
            print("delta_o ",delta_o)




            aug_h = np.concatenate((h_times, [0.0]))
            dW2 = np.zeros_like(W2)
            for k in range(W2.shape[0]):
                for j in range(W2.shape[1]):
                    dW2[k, j] = delta_o[j] * d_spike_timing_dw(
                        W2[k, j], aug_h[k], layer2_idx, 0, 1)

            # — Backprop into hidden & gradients for W1 —
            delta_h = np.zeros_like(h_times)
            for k in range(len(h_times)):
                for j in range(W2.shape[1]):
                    dt_dw_output = d_spike_timing_dw(W2[k, j], aug_h[k], layer2_idx, 0, 1)
                    delta_h[k] += delta_o[j] * dt_dw_output  # Remove the W2[k,j] multiplication

            aug_xi = np.concatenate((xi, [0.0]))
            dW1 = np.zeros_like(W1)
            for i in range(W1.shape[0]):
                for k in range(W1.shape[1]):
                    dW1[i, k] = delta_h[k] * d_spike_timing_dw(
                        W1[i, k], aug_xi[i], layer1_idx, 0, 1)

            # — Accumulate —
            acc_dW1 += dW1
            acc_dW2 += dW2

        # — Average & clip gradients —
        acc_dW1 /= N
        acc_dW2 /= N

        # Boost hidden-layer rate
        lr1 = lr

        # Separate clipping thresholds
        g1 = np.clip(acc_dW1, -max_grad, max_grad)
        g2 = np.clip(acc_dW2, -max_grad,   max_grad)

        # — Momentum updates —
        vW1 = beta * vW1 + (1 - beta) * g1
        vW2 = beta * vW2 + (1 - beta) * g2

        # — Apply weight updates & clamp —
        W1 = np.clip(W1 - lr1 * vW1, w_min, w_max)
        W2 = np.clip(W2 - lr  * vW2, w_min, w_max)

        print(f"Epoch {ep+1}/{epochs} — avg loss={epoch_loss/N:.4f}")
        print(f"             ‖W1‖={np.linalg.norm(W1):.3f}, ‖W2‖={np.linalg.norm(W2):.3f}\n")

    return W1, W2

            


if __name__ == "__main__":
    # example usage with fixed input/target pairs
    # 4 inputs per sample, constant across 8 samples
    x0 = np.array([0.9, 0.7, 0.3, 0.4])
    x1 = np.array([0.6, 0.7, 0.8, 0.9])
    X = [x0 if i % 2 == 0 else x1 for i in range(8)]
    # 3-targets (network outputs 3 values): use desired spike times [2.1, 2.0, 1.0]
    y0 = np.array([2.95, 2.0, 2.0])
    y1 = np.array([2.0, 2.0, 2.95])
    Y = [y0 if i % 2 == 0 else y1 for i in range(8)]
  

    
    # W1_0 = np.random.randn(4+1, 10) * 0.1  # +1 for bias
    # W2_0 = np.random.randn(10+1, 3) * 0.1  # +1 for bias 


    W1_0 = np.load("W1.npy")
    W2_0 = np.load("W2.npy")


    # train
    W1_tr, W2_tr = train_snn_backprop(X, Y, W1_0, W2_0,
                                      epochs=1, lr=0.0)
    print("Trained W1:", W1_tr)
    print("Trained W2:", W2_tr) 
    print("Hidden times for x0:", layer_forward(x0, W1_tr, 1))
    print("Hidden times for x1:", layer_forward(x1, W1_tr, 1))

    # # ── Now test on the same two patterns ──
    # print("\n=== Test predictions ===")
    # for xi, yi in zip(X, Y):
    #     # call layer_forward(positionally) rather than with layer1_idx=
    #     h_times = layer_forward(xi, W1_tr, 1)
    #     o_times = layer_forward(h_times, W2_tr, 2)

    #     pred_class = np.argmax(o_times)  
    #     true_class = np.argmax(yi)

    #     print(f"Input: {xi}")
    #     print(f" Spike times: {o_times}")
    #     print(f" Predicted class: {pred_class}, True class: {true_class}\n")

# np.save('W1.npy', W1_tr)
# np.save('W2.npy', W2_tr)
# print("weights saved")



otimes  [2.666 2.003 2.524]
delta_o  [-0.284  -0.0235  0.237 ]
otimes  [2.726 2.003 2.619]
delta_o  [ 0.338  -0.0235 -0.331 ]
otimes  [2.666 2.003 2.524]
delta_o  [-0.284  -0.0235  0.237 ]
otimes  [2.726 2.003 2.619]
delta_o  [ 0.338  -0.0235 -0.331 ]


KeyboardInterrupt: 

In [None]:
from brian2 import *
import numpy as np
import logging
import warnings



from brian2 import prefs, set_device

# Tell Brian2 to use the Cython code generator:
prefs.codegen.target = 'cython'

# Optionally compile but keep Python interface:
set_device('runtime')  # default; compiles operations to .so but stays in Python process





# suppress overflow warnings
warnings.filterwarnings('ignore', category=RuntimeWarning)
numpy.seterr(over='ignore', under='ignore')
logging.getLogger('brian2').setLevel(logging.ERROR)

# ----------------------------------------------------------------------------
# Spike timing and derivative

start_scope()
defaultclock.dt = 0.001*ms

@implementation('numpy', discard_units=True)
@check_units(w=1, global_clock=1, layer=1, sum=1, spikes_received=1, result=1)
def spike_timing(w, global_clock, layer, sum, spikes_received):
    x = global_clock % 1
    if w >= 0:
        return np.power(x, (1 - w), where=(x>0), out=np.zeros_like(x))
    else:
        return 1 - np.power((1 - x), (1 + w), where=(x<1), out=np.ones_like(x))

@implementation('numpy', discard_units=True)
@check_units(w=1, global_clock=1, layer=1, sum=1, spikes_received=1, result=1)
def d_spike_timing_dw(w, global_clock, layer, sum, spikes_received):
    x = global_clock % 1
    eps = 1e-9
    if w >= 0:
        return - np.power(x, (1 - w), where=(x>0), out=np.zeros_like(x)) * np.log(x + eps)
    else:
        return - np.power((1 - x), (1 + w), where=(x<1), out=np.ones_like(x)) * np.log(1 - x + eps)

# ----------------------------------------------------------------------------
# Forward pass: 4->10->3 using two-stage mini_urd

def layer_forward(inputs, W, layer_idx):
    """
    inputs: array of spike times (ms) from previous layer (shape: n_in,)
    W: weight matrix shape (n_in+1, n_out)  ← note the extra bias row
    layer_idx: integer layer number
    returns: array of output spike times (ms)
    """
    # 1) augment inputs with bias spike @ t=0
    bias_time = 0.0
    aug_inputs = np.concatenate((inputs, [bias_time]))  # shape (n_in+1,)

    n_in_plus_bias, n_out = W.shape
    assert aug_inputs.size == n_in_plus_bias

    out_times = []
    for j in range(n_out):
        start_scope()
        defaultclock.dt = 0.001*ms

        # single post‐synaptic neuron
        G = NeuronGroup(1, '''
            v : 1
            sum : 1
            sr : 1
            scheduled_time : second
            global_clock : 1
        ''', threshold='v>1', reset='v=0', method='exact')

        # init
        G.v = G.sum = G.sr = 0
        G.global_clock = 0
        G.scheduled_time = 1e9*second

        # stim: now includes bias spike at t=0
        stim = SpikeGeneratorGroup(n_in_plus_bias,
                                   indices=list(range(n_in_plus_bias)),
                                   times=aug_inputs*ms)

        S = Synapses(stim, G, '''w:1
            layer:1''',
            on_pre='''
            sr += 1
            sum += spike_timing(w, global_clock, layer, sum, sr)
            scheduled_time = (sum/sr + layer)*ms
        ''')
        S.connect(True)
        S.w = W[:, j]
        S.layer = layer_idx

        G.run_regularly('''
            v = int(abs(t - scheduled_time) < 0.001*ms) * 1.2
            global_clock += 0.001
        ''', dt=0.001*ms)

        mon = SpikeMonitor(G)
        run(5*ms)

        ts = mon.spike_trains()[0]
        t0 = float(ts[0]/ms) if len(ts)>0 else float(5.0)
        out_times.append(t0)

    return np.array(out_times)


# ----------------------------------------------------------------------------
# Training loop with backprop for 4-10-3
def train_snn_backprop(
    X, Y,                # lists of input arrays (4,) and target (3,)
    W1_init, W2_init,
    epochs=10, lr=0.1,
    max_grad=20.0, w_min=-20.0, w_max=20.0,
    non_target_time=2.05,
    λ=0.5                # non-target penalty weight
):
    """
    Trains a 4→10→3 spiking network with:
      • batched gradient updates
      • boosted hidden-layer learning rate
      • separate gradient clipping per layer
      • classical momentum smoothing
    """
    # copy over the weights given 
    W1 = W1_init.copy()      # shape (5,10) including bias row
    W2 = W2_init.copy()      # shape (11,3) including bias row

    # Momentum buffers
    beta = 0.9
    vW1 = np.zeros_like(W1)  # changes how quckly it updates as we slow down or speed up   
    vW2 = np.zeros_like(W2)

    layer1_idx, layer2_idx = 1, 2 
    N = len(X)

    for ep in range(epochs):
        # Accumulators - stores the collecting gradients for each epoch
        acc_dW1 = np.zeros_like(W1)
        acc_dW2 = np.zeros_like(W2) 
        epoch_loss = 0.0

        for xi, yi in zip(X, Y): # iterates in pairs at same time x1 and y1 x2 & y2 ect.
            # — Forward pass —
            h_times = layer_forward(xi, W1, layer1_idx)
            #print(h_times)
            o_times = layer_forward(h_times, W2, layer2_idx)  # this is sending in input array and outputs array for each layer


            # — Separation loss — # calcs to loss for a single sample will not touch for now 
            target_idx = np.argmax(yi)
            L_target = 0.5 * (o_times[target_idx] - yi[target_idx])**2
            non_ids = [j for j in range(len(o_times)) if j != target_idx]
            L_non = 0.5 * λ * sum([(o_times[j] - non_target_time)**2 for j in non_ids])
            L = L_target + L_non
            epoch_loss += L



            # — Gradients for W2 —
            delta_o = np.zeros_like(o_times)  # a 3 element array for the 3 outputs
            delta_o[target_idx] = (o_times[target_idx] - yi[target_idx])
            for j in non_ids:
                delta_o[j] = λ * (o_times[j] - non_target_time)

            print("otimes ", o_times)
            print("delta_o ",delta_o)




            aug_h = np.concatenate((h_times, [0.0]))
            dW2 = np.zeros_like(W2)
            for k in range(W2.shape[0]):
                for j in range(W2.shape[1]):
                    dW2[k, j] = delta_o[j] * d_spike_timing_dw(
                        W2[k, j], aug_h[k], layer2_idx, 0, 1)

            # — Backprop into hidden & gradients for W1 —
            delta_h = np.zeros_like(h_times)
            for k in range(len(h_times)):
                for j in range(W2.shape[1]):
                    dt_dw_output = d_spike_timing_dw(W2[k, j], aug_h[k], layer2_idx, 0, 1)
                    delta_h[k] += delta_o[j] * dt_dw_output  # Remove the W2[k,j] multiplication

            aug_xi = np.concatenate((xi, [0.0]))
            dW1 = np.zeros_like(W1)
            for i in range(W1.shape[0]):
                for k in range(W1.shape[1]):
                    dW1[i, k] = delta_h[k] * d_spike_timing_dw(
                        W1[i, k], aug_xi[i], layer1_idx, 0, 1)

            # — Accumulate —
            acc_dW1 += dW1
            acc_dW2 += dW2

        # — Average & clip gradients —
        acc_dW1 /= N
        acc_dW2 /= N

        # Boost hidden-layer rate
        lr1 = lr

        # Separate clipping thresholds
        g1 = np.clip(acc_dW1, -max_grad, max_grad)
        g2 = np.clip(acc_dW2, -max_grad,   max_grad)

        # — Momentum updates —
        vW1 = beta * vW1 + (1 - beta) * g1
        vW2 = beta * vW2 + (1 - beta) * g2

        # — Apply weight updates & clamp —
        W1 = W1 - lr * vW1
        W2 = W2 - lr * vW2


        print(f"Epoch {ep+1}/{epochs} — avg loss={epoch_loss/N:.4f}")
        print(f"             ‖W1‖={np.linalg.norm(W1):.3f}, ‖W2‖={np.linalg.norm(W2):.3f}\n")

    return W1, W2

            


if __name__ == "__main__":
    # example usage with fixed input/target pairs
    # 4 inputs per sample, constant across 8 samples
    x0 = np.array([0.9, 0.7, 0.3, 0.4])
    x1 = np.array([0.6, 0.7, 0.8, 0.9])
    X = [x0 if i % 2 == 0 else x1 for i in range(8)]
    # 3-targets (network outputs 3 values): use desired spike times [2.1, 2.0, 1.0]
    y0 = np.array([2.95, 2.0, 2.0])
    y1 = np.array([2.0, 2.0, 2.95])
    Y = [y0 if i % 2 == 0 else y1 for i in range(8)]
  

    
    # W1_0 = np.random.randn(4+1, 10) * 0.1  # +1 for bias
    # W2_0 = np.random.randn(10+1, 3) * 0.1  # +1 for bias 


    W1_0 = np.load("W1.npy")
    W2_0 = np.load("W2.npy")


    # train
    W1_tr, W2_tr = train_snn_backprop(X, Y, W1_0, W2_0,
                                      epochs=1, lr=0.0)
    print("Trained W1:", W1_tr)
    print("Trained W2:", W2_tr) 
    print("Hidden times for x0:", layer_forward(x0, W1_tr, 1))
    print("Hidden times for x1:", layer_forward(x1, W1_tr, 1))

    # # ── Now test on the same two patterns ──
    # print("\n=== Test predictions ===")
    # for xi, yi in zip(X, Y):
    #     # call layer_forward(positionally) rather than with layer1_idx=
    #     h_times = layer_forward(xi, W1_tr, 1)
    #     o_times = layer_forward(h_times, W2_tr, 2)

    #     pred_class = np.argmax(o_times)  
    #     true_class = np.argmax(yi)

    #     print(f"Input: {xi}")
    #     print(f" Spike times: {o_times}")
    #     print(f" Predicted class: {pred_class}, True class: {true_class}\n")

# np.save('W1.npy', W1_tr)
# np.save('W2.npy', W2_tr)
# print("weights saved")



otimes  [2.666 2.003 2.524]
delta_o  [-0.284  -0.0235  0.237 ]
otimes  [2.726 2.003 2.619]
delta_o  [ 0.338  -0.0235 -0.331 ]
otimes  [2.666 2.003 2.524]
delta_o  [-0.284  -0.0235  0.237 ]
otimes  [2.726 2.003 2.619]
delta_o  [ 0.338  -0.0235 -0.331 ]


KeyboardInterrupt: 

In [None]:
#testing on Iris
from sklearn import datasets
from sklearn.model_selection import train_test_split


if __name__ == "__main__":
    # Load Iris dataset and scale features to [0.05,0.95]
    iris = datasets.load_iris()
    X = iris.data
    y = iris.target
    y_encoded = np.zeros((y.size, y.max()+1))
    y_encoded[np.arange(y.size), y] = 1
    
    def scale_features(x):
        mn, mx = x.min(axis=0), x.max(axis=0)
        x_norm = (x - mn) / (mx - mn)
        return x_norm * 0.9 + 0.05

    X_scaled = scale_features(X)
    X_train, X_test, y_train, y_test = train_test_split(
        X_scaled, y_encoded, test_size=0.2,
        random_state=42, stratify=y
    )


    print("\n=== Test predictions ===")
    for x_test, y_test in zip(X, Y):
        # call layer_forward(positionally) rather than with layer1_idx=
        h_times = layer_forward(x_test, W1_tr, 1)
        o_times = layer_forward(h_times, W2_tr, 2)

        pred_class = np.argmax(o_times)  
        true_class = np.argmax(y_test)

        print(f"Input: {xi}")
        print(f" Spike times: {o_times}")
        print(f" Predicted class: {pred_class}, True class: {true_class}\n")



=== Test predictions ===
Input: [0.6 0.7 0.8 0.9]
 Spike times: [2.228 5.    5.   ]
 Predicted class: 1, True class: 0

Input: [0.6 0.7 0.8 0.9]
 Spike times: [2.228 5.    5.   ]
 Predicted class: 1, True class: 2

Input: [0.6 0.7 0.8 0.9]
 Spike times: [2.359 5.    5.   ]
 Predicted class: 1, True class: 0

Input: [0.6 0.7 0.8 0.9]
 Spike times: [2.267 5.    5.   ]
 Predicted class: 1, True class: 2

Input: [0.6 0.7 0.8 0.9]
 Spike times: [2.228 5.    5.   ]
 Predicted class: 1, True class: 0

Input: [0.6 0.7 0.8 0.9]
 Spike times: [2.305 1.152 1.397]
 Predicted class: 0, True class: 2



KeyboardInterrupt: 