In [1]:
from brian2 import *
import numpy as np
import logging
import warnings



from brian2 import prefs, set_device

# Tell Brian2 to use the Cython code generator:
prefs.codegen.target = 'cython'

# Optionally compile but keep Python interface:
set_device('runtime')  # default; compiles operations to .so but stays in Python process


# suppress overflow warnings
warnings.filterwarnings('ignore', category=RuntimeWarning)
numpy.seterr(over='ignore', under='ignore')
logging.getLogger('brian2').setLevel(logging.ERROR)

# ----------------------------------------------------------------------------
# Spike timing and derivative

start_scope()
defaultclock.dt = 0.001*ms

@implementation('numpy', discard_units=True)
@check_units(w=1, global_clock=1, layer=1, sum=1, spikes_received=1, result=1)
def spike_timing(w, global_clock, layer, sum, spikes_received):
    x = global_clock % 1
    return np.tanh(w * x)

@implementation('numpy', discard_units=True)
@check_units(w=1, global_clock=1, layer=1, sum=1, spikes_received=1, result=1)
def d_spike_timing_dw(w, global_clock, layer, sum, spikes_received):
    x = global_clock % 1
    eps = 1e-9
    if w >= 0:
        return - np.power(x, (1 - w), where=(x>0), out=np.zeros_like(x)) * np.log(x + eps)
    else:
        return - np.power((1 - x), (1 + w), where=(x<1), out=np.ones_like(x)) * np.log(1 - x + eps)

# ----------------------------------------------------------------------------
# Forward pass: 4->10->3 using two-stage mini_urd

def layer_forward(inputs, W, layer_idx):
    """
    inputs: array of spike times (ms) from previous layer (shape: n_in,)
    W: weight matrix shape (n_in+1, n_out)  ← note the extra bias row
    layer_idx: integer layer number
    returns: array of output spike times (ms)
    """
    # 1) augment inputs with bias spike @ t=0
    bias_time = 0.0
    aug_inputs = np.concatenate((inputs, [bias_time]))  # shape (n_in+1,)

    n_in_plus_bias, n_out = W.shape
    assert aug_inputs.size == n_in_plus_bias

    out_times = []
    for j in range(n_out):
        start_scope()
        defaultclock.dt = 0.001*ms

        # single post‐synaptic neuron
        G = NeuronGroup(1, '''
            v : 1
            sum : 1
            sr : 1
            scheduled_time : second
            global_clock : 1
        ''', threshold='v>1', reset='v=0', method='exact')

        # init
        G.v = G.sum = G.sr = 0
        G.global_clock = 0
        G.scheduled_time = 1e9*second

        # stim: now includes bias spike at t=0
        stim = SpikeGeneratorGroup(n_in_plus_bias,
                                   indices=list(range(n_in_plus_bias)),
                                   times=aug_inputs*ms)

        S = Synapses(stim, G, '''w:1
            layer:1''',
            on_pre='''
            sr += 1
            sum += spike_timing(w, global_clock, layer, sum, sr)
            scheduled_time = (sum/sr + layer)*ms
        ''')
        S.connect(True)
        S.w = W[:, j]
        S.layer = layer_idx

        G.run_regularly('''
            v = int(abs(t - scheduled_time) < 0.001*ms) * 1.2
            global_clock += 0.001
        ''', dt=0.001*ms)

        mon = SpikeMonitor(G)
        run(5*ms)

        ts = mon.spike_trains()[0]
        t0 = float(ts[0]/ms) if len(ts)>0 else float(5.0)
        out_times.append(t0)

    return np.array(out_times)


# ----------------------------------------------------------------------------
# Training loop with backprop for 4-10-3
def train_snn_backprop(
    X, Y,                # lists of input arrays (4,) and target (3,)
    W1_init, W2_init,
    epochs=10, lr=0.1,
    max_grad=20.0, w_min=-20.0, w_max=20.0,
    non_target_time=2.0,
    λ=0.5                # non-target penalty weight
):
    """
    Trains a 4→10→3 spiking network with:
      • batched gradient updates
      • boosted hidden-layer learning rate
      • separate gradient clipping per layer
      • classical momentum smoothing
    """
    # Initialize weights
    W1 = W1_init.copy()      # shape (5,10) including bias row
    W2 = W2_init.copy()      # shape (11,3) including bias row

    # Momentum buffers
    beta = 0.9
    vW1 = np.zeros_like(W1)
    vW2 = np.zeros_like(W2)

    layer1_idx, layer2_idx = 1, 2
    N = len(X)

    for ep in range(epochs):
        # Accumulators for this epoch
        acc_dW1 = np.zeros_like(W1)
        acc_dW2 = np.zeros_like(W2)
        epoch_loss = 0.0

        for xi, yi in zip(X, Y):
            # — Forward pass —
            h_times = layer_forward(xi, W1, layer1_idx)
            o_times = layer_forward(h_times, W2, layer2_idx)

            # — Separation loss —
            target_idx = np.argmax(yi)
            L_target = 0.5 * (o_times[target_idx] - yi[target_idx])**2
            non_ids = [j for j in range(len(o_times)) if j != target_idx]
            L_non = 0.5 * λ * sum([(o_times[j] - non_target_time)**2 for j in non_ids])
            L = L_target + L_non
            epoch_loss += L

            # — Gradients for W2 —
            delta_o = np.zeros_like(o_times)
            delta_o[target_idx] = (o_times[target_idx] - yi[target_idx])
            for j in non_ids:
                delta_o[j] = λ * (o_times[j] - non_target_time)

            aug_h = np.concatenate((h_times, [0.0]))
            dW2 = np.zeros_like(W2)
            for k in range(W2.shape[0]):
                for j in range(W2.shape[1]):
                    dW2[k, j] = delta_o[j] * d_spike_timing_dw(
                        W2[k, j], aug_h[k], layer2_idx, 0, 1)

            # — Backprop into hidden & gradients for W1 —
            delta_h = np.zeros_like(h_times)
            for k in range(len(h_times)):
                for j in range(W2.shape[1]):
                    dt_dw_output = d_spike_timing_dw(W2[k, j], aug_h[k], layer2_idx, 0, 1)
                    delta_h[k] += delta_o[j] * dt_dw_output  # Remove the W2[k,j] multiplication

            aug_xi = np.concatenate((xi, [0.0]))
            dW1 = np.zeros_like(W1)
            for i in range(W1.shape[0]):
                for k in range(W1.shape[1]):
                    dW1[i, k] = delta_h[k] * d_spike_timing_dw(
                        W1[i, k], aug_xi[i], layer1_idx, 0, 1)

            # — Accumulate —
            acc_dW1 += dW1
            acc_dW2 += dW2

        # — Average & clip gradients —
        acc_dW1 /= N
        acc_dW2 /= N

        # Boost hidden-layer rate
        lr1 = lr

        # Separate clipping thresholds
        g1 = np.clip(acc_dW1, -max_grad, max_grad)
        g2 = np.clip(acc_dW2, -max_grad,   max_grad)

        # — Momentum updates —
        vW1 = beta * vW1 + (1 - beta) * g1
        vW2 = beta * vW2 + (1 - beta) * g2

        # — Apply weight updates & clamp —
        W1 = W1 - lr * vW1
        W2 = W2 - lr * vW2

        print("weights1:", W1)
        print("o_times:", o_times)

        print(f"Epoch {ep+1}/{epochs} — avg loss={epoch_loss/N:.4f}")
        print(f"             ‖W1‖={np.linalg.norm(W1):.3f}, ‖W2‖={np.linalg.norm(W2):.3f}\n")

    return W1, W2



        


if __name__ == "__main__":
    # example usage with fixed input/target pairs
    # 4 inputs per sample, constant across 8 samples
    x0 = np.array([0.9, 0.7, 0.3, 0.4])
    x1 = np.array([0.6, 0.7, 0.8, 0.9])
    X = [x0 if i % 2 == 0 else x1 for i in range(8)]
    # 3-targets (network outputs 3 values): use desired spike times [2.1, 2.0, 1.0]
    y0 = np.array([2.95, 2.05, 2.05])
    y1 = np.array([2.05, 2.05, 2.95])
    Y = [y0 if i % 2 == 0 else y1 for i in range(8)]
    # X= []
    # Y = []
    # for _ in range(10):
    #     X.append(x0 + np.random.randn(4)*0.02);  Y.append(y0)
    #     X.append(x1 + np.random.randn(4)*0.02);  Y.append(y1)
    

    
     #W1_0 = np.array([[0.21958991, 0.16223261, 0.02545166, 0.18849804, 0.09521701, 0.22744421, 0.05556097, 0.33130229, 0.03974721, 0.1968464],
    #             [0.41958955, 0.47541312, 0.22287581, 0.69627866, 0.83639384, 0.79597959, 0.15029805, 0.126486, 0.18285382, 0.07470098],
    #             [0.69559509, 0.41228614, 0.06028855, 0.51098037, 0.33730611, 1.17605488, 0.15405119, 0.28079173, 0.17365651, 0.23041775],
    #             [0.79721356, 0.82210554, 0.15028745, 1.09421856, 0.68280376, 1.07577422, 0.16962136, 0.23838796, 0.0735181, 0.1719861],
    #             [0.0631449, 0.10618091, 0.05791614, 0.0260418, -0.01797577, -0.1209534, 0.18702474, -0.01662061, -0.0683026, 0.05468931]])
        #W1_0 = [[ 0.21958991,  0.16223261,  0.02545166  0.18849804  0.09521701  0.22744421, 0.05556097  0.33130229  0.03974721  0.1968464 ], [ 0.41958955  0.47541312  0.22287581  0.69627866  0.83639384  0.79597959, 0.15029805  0.126486    0.18285382  0.07470098], [ 0.69559509  0.41228614  0.06028855  0.51098037  0.33730611  1.17605488, 0.15405119  0.28079173  0.17365651  0.23041775], [ 0.79721356  0.82210554  0.15028745  1.09421856  0.68280376  1.07577422, 0.16962136  0.23838796  0.0735181   0.1719861 ], [ 0.0631449   0.10618091  0.05791614  0.0260418  -0.01797577 -0.1209534, 0.18702474 -0.01662061 -0.0683026   0.05468931]]
        
    
    # W1_0 = np.random.randn(5, 10) * 0.1  # +1 for bias
    # W2_0 = np.random.randn(11, 3) * 0.1

    W1_0 = np.load('W1.npy')
    W2_0 = np.load('W2.npy')

    # W2_0 = np.array([[ 0.38843549, -1.10085101,  0.38776897],
    #             [ 0.35841177, -1.06985881,  0.40542767],
    #             [ 0.40732835, -0.63317637,  0.59202003],
    #             [ 0.32589618, -0.9691458,   0.48481597],
    #             [ 0.26060079, -0.85267046,  0.75057083],
    #             [ 0.31287437, -1.34806606,  0.44407244],
    #             [ 0.1996638,  -0.65507816,  0.27473825],
    #             [ 0.40169137, -0.9979045,   0.09884702],
    #             [ 0.37579471, -0.62826323,  0.58035222],
    #             [ 0.2191151,  -0.84980247,  0.14767976],
    #             [ 0.11199583, -0.16498428,  0.00871235]])
    
    # will print out last times so DO NOT run the same the same expermeent to have a differnt outcoem


    # train
    W1_tr, W2_tr = train_snn_backprop(X, Y, W1_0, W2_0,
                                      epochs=10, lr=0.1)
    print("Trained W1:", W1_tr)
    print("Trained W2:", W2_tr) 
    print("Hidden times for x0:", layer_forward(x0, W1_tr, 1))
    print("Hidden times for x1:", layer_forward(x1, W1_tr, 1))

    # ── Now test on the same two patterns ──
    print("\n=== Test predictions ===")
    for xi, yi in zip(X, Y):
        # call layer_forward(positionally) rather than with layer1_idx=
        h_times = layer_forward(xi, W1_tr, 1)
        o_times = layer_forward(h_times, W2_tr, 2)

        pred_class = np.argmax(o_times)  # changed to argmin WHY???
        true_class = np.argmax(yi)

        print(f"Input: {xi}")
        print(f" Spike times: {o_times}")
        print(f" Predicted class: {pred_class}, True class: {true_class}\n")

weights1: [[-0.22137975  0.00167169 -0.02890704 -0.07835327 -0.0818913  -0.04473316
  -0.07977414 -0.07876425 -0.11839853 -0.14203129]
 [-0.10793411 -0.1782907  -0.13217254 -0.096198   -0.18868302 -0.01842692
  -0.14197238 -0.13404006  0.15476579  0.04679777]
 [-0.23875195 -0.15475638  0.3718874  -0.19314677 -0.26335707  0.27412208
  -0.43937395  0.02040501 -0.22349327  0.28245337]
 [ 0.0537452   0.01703331  0.34149903  0.32587405  0.12686899 -0.05327587
   0.04133178 -0.3077049  -0.24370899  0.42517325]
 [-0.13117826  0.00646638 -0.09754206  0.00469225  0.11569048 -0.21067804
   0.06001542  0.10219516  0.0855197  -0.09519465]]
o_times: [2.204 1.595 2.172]
Epoch 1/10 — avg loss=0.3277
             ‖W1‖=1.305, ‖W2‖=3.934

weights1: [[-0.21493659  0.00433183 -0.02628947 -0.07719417 -0.07926001 -0.03877565
  -0.07552159 -0.07432399 -0.11569948 -0.13656268]
 [-0.10142247 -0.17330226 -0.12821799 -0.09413382 -0.18510016 -0.0098191
  -0.13636231 -0.12829904  0.15679013  0.05097681]
 [-0.23332

KeyboardInterrupt: 

In [10]:
from brian2 import *
import numpy as np
import logging, warnings
import random

# Brian2 config
prefs.codegen.target = 'cython'
set_device('runtime')
warnings.filterwarnings('ignore', category=RuntimeWarning)
np.seterr(over='ignore', under='ignore')
logging.getLogger('brian2').setLevel(logging.ERROR)

# ----------------------------------------------------------------------------
# Spike timing and derivative with smoother sigmoid
@implementation('numpy', discard_units=True)
@check_units(w=1, global_clock=1, layer=1, sum=1, spikes_received=1, result=1)
def spike_timing(w, global_clock, layer, sum, spikes_received):
    x = (global_clock % 1)
    z = 5.0 * (x - 0.5)  # Smoother sigmoid
    sigmoid_val = 1.0 / (1.0 + np.exp(-w * z))
    return sigmoid_val

@implementation('numpy', discard_units=True)
@check_units(w=1, global_clock=1, layer=1, sum=1, spikes_received=1, result=1)
def d_spike_timing_dw(w, global_clock, layer, sum, spikes_received):
    x = (global_clock % 1)
    z = 5.0 * (x - 0.5)
    sig = 1.0 / (1.0 + np.exp(-w * z))
    return sig * (1.0 - sig) * z

# ----------------------------------------------------------------------------
# Forward pass
def layer_forward(inputs, W, layer_idx):
    bias_time = 0.0
    aug_inputs = np.concatenate((inputs, [bias_time]))
    n_in_plus_bias, n_out = W.shape
    assert aug_inputs.size == n_in_plus_bias
    out_times = []

    for j in range(n_out):
        start_scope()
        defaultclock.dt = 0.001*ms

        G = NeuronGroup(1, '''
            v : 1
            sum : 1
            sr : 1
            scheduled_time : second
            global_clock : 1
        ''', threshold='v>1', reset='v=0', method='exact')

        G.global_clock = np.random.rand()
        G.v = G.sum = G.sr = 0
        G.scheduled_time = 1e9*second

        stim = SpikeGeneratorGroup(n_in_plus_bias,
                                   indices=list(range(n_in_plus_bias)),
                                   times=aug_inputs*ms)

        S = Synapses(stim, G, '''
                     w:1
                     layer:1''',
                     on_pre='''
                       sr += 1
                       sum += spike_timing(w, global_clock, layer, sum, sr)
                       scheduled_time = (layer + sum/sr)*ms
                     ''')
        S.connect(True)
        S.w = W[:, j]
        S.layer = layer_idx

        G.run_regularly('''
            v = int(abs(t - scheduled_time) < 0.001*ms) * 1.2
            global_clock += 0.001
        ''', dt=0.001*ms)

        mon = SpikeMonitor(G)
        run(5*ms)

        ts = mon.spike_trains()[0]
        if len(ts) > 0:
            t0 = float(ts[0]/ms)
            t0 = max(layer_idx, min(layer_idx + 1, t0))
        else:
            t0 = layer_idx + 0.5
        out_times.append(t0)

    return np.array(out_times)

# ----------------------------------------------------------------------------
# Training loop
def train_snn_backprop(X, Y, W1_init, W2_init, epochs=50, lr=0.5, max_grad=5.0, λ=0.5):
    W1, W2 = W1_init.copy(), W2_init.copy()
    beta = 0.95
    vW1 = np.zeros_like(W1)
    vW2 = np.zeros_like(W2)
    N = len(X)

    for ep in range(epochs):
        acc_dW1 = np.zeros_like(W1)
        acc_dW2 = np.zeros_like(W2)
        epoch_loss = 0.0

        for xi, yi in zip(X, Y):
            # Forward
            h_times = layer_forward(xi, W1, 1)
            o_times = layer_forward(h_times, W2, 2)

            # Debug
            print(f"Epoch {ep+1}: Hidden times range: [{min(h_times):.4f}, {max(h_times):.4f}], std: {np.std(h_times):.4f}")
            print(f"Output times: {o_times}, separation: {max(o_times) - min(o_times):.4f}")

            # Loss & output deltas
            t_idx = np.argmax(yi)
            delta_o = np.zeros_like(o_times)
            delta_o[t_idx] = (o_times[t_idx] - yi[t_idx])
            for j in range(len(o_times)):
                if j != t_idx:
                    delta_o[j] = λ * (o_times[j] - yi[j])

            # Accumulate loss with separation penalty
            L_t = 5.0 * (o_times[t_idx] - yi[t_idx])**2  # Scaled target loss
            L_n = 5.0 * λ * np.sum(np.fromiter(((o_times[j] - yi[j])**2 for j in range(len(o_times)) if j != t_idx), dtype=float))  # Scaled non-target loss
            separation = max(o_times) - min(o_times)
            desired_separation = 0.517
            L_sep = 5.0 * (desired_separation - separation)**2 if separation < desired_separation else 0.0  # Separation penalty
            epoch_loss += (L_t + L_n + L_sep)

            # Debug loss components
            print(f"Sample loss: L_t={L_t:.4f}, L_n={L_n:.4f}, L_sep={L_sep:.4f}, Total={L_t + L_n + L_sep:.4f}")

            # Gradients for W2
            aug_h = np.concatenate((h_times, [0.0]))
            dW2 = np.zeros_like(W2)
            for k in range(W2.shape[0]):
                for j in range(W2.shape[1]):
                    dW2[k, j] = delta_o[j] * d_spike_timing_dw(W2[k, j], aug_h[k], 2, 0, 1)

            # Backprop to hidden
            delta_h = np.zeros(len(h_times))
            for k in range(len(h_times)):
                for j in range(W2.shape[1]):
                    delta_h[k] += delta_o[j] * d_spike_timing_dw(W2[k, j], aug_h[k], 2, 0, 1)
            if np.sum(np.abs(delta_h)) > 0:
                delta_h /= np.sum(np.abs(delta_h))  # Normalize

            # Gradients for W1
            aug_x = np.concatenate((xi, [0.0]))
            dW1 = np.zeros_like(W1)
            for i in range(W1.shape[0]):
                for k in range(W1.shape[1]):
                    dW1[i, k] = delta_h[k] * d_spike_timing_dw(W1[i, k], aug_x[i], 1, 0, 1)

            # Debug gradients
            print(f"dW1 range: [{np.min(dW1):.4f}, {np.max(dW1):.4f}], std: {np.std(dW1):.4f}")
            print(f"dW2 range: [{np.min(dW2):.4f}, {np.max(dW2):.4f}], std: {np.std(dW2):.4f}")

            acc_dW1 += dW1
            acc_dW2 += dW2

        # Average & clip
        acc_dW1 /= N
        acc_dW2 /= N
        g1 = np.clip(acc_dW1, -max_grad, max_grad)
        g2 = np.clip(acc_dW2, -max_grad, max_grad)

        # Debug: gradient norms and accuracy
        print(f"Epoch {ep+1}: ||g1||={np.linalg.norm(g1):.4f}, ||g2||={np.linalg.norm(g2):.4f}, acc={np.mean([np.argmax(layer_forward(layer_forward(xi, W1, 1), W2, 2)) == np.argmax(yi) for xi, yi in zip(X, Y)]):.3f}")
        print(f"  avg loss={(epoch_loss/N):.4f}")

        # Momentum
        vW1 = beta * vW1 + (1 - beta) * g1
        vW2 = beta * vW2 + (1 - beta) * g2
        upd1, upd2 = vW1, vW2

        # Update & clamp
        W1 -= lr * upd1
        W2 -= lr * upd2
        W1 = np.clip(W1, -20.0, 20.0)
        W2 = np.clip(W2, -20.0, 20.0)

    return W1, W2

if __name__ == "__main__":
    random.seed(13)
    # Example usage with fixed input/target pairs
    x0 = np.array([0.9, 0.7, 0.3, 0.4])
    x1 = np.array([0.6, 0.7, 0.8, 0.9])
    X = [x0 if i % 3 == 0 else x1 if i % 3 == 1 else x0 for i in range(8)]
    # Updated targets to match diagnostics
    y0 = np.array([2.8194, 2.3026, 2.3026])  # Class 0
    y1 = np.array([2.3026, 2.8194, 2.3026])  # Class 1
    y2 = np.array([2.3026, 2.3026, 2.8194])  # Class 2
    Y = [y0 if i % 3 == 0 else y1 if i % 3 == 1 else y2 for i in range(8)]

    # Initialize weights with larger variance
    W1_0 = np.random.randn(5, 10) * 0.5
    W2_0 = np.random.randn(11, 3) * 0.5

    # Train
    W1_tr, W2_tr = train_snn_backprop(X, Y, W1_0, W2_0,
                                      epochs=50, lr=0.5, max_grad=5.0, λ=0.5)

    print("Final trained weights:")
    print("W1 shape:", W1_tr.shape)
    print("W2 shape:", W2_tr.shape)

    print("\nHidden layer outputs:")
    print("Hidden times for x0:", layer_forward(x0, W1_tr, 1))
    print("Hidden times for x1:", layer_forward(x1, W1_tr, 1))
    print("Output times for x0:", layer_forward(layer_forward(x0, W1_tr, 1), W2_tr, 2))

    # Test predictions
    print("\n=== Test predictions ===")
    for i, (xi, yi) in enumerate(zip(X[:8], Y[:8])):
        h_times = layer_forward(xi, W1_tr, 1)
        o_times = layer_forward(h_times, W2_tr, 2)

        pred_class = np.argmax(o_times)
        true_class = np.argmax(yi)

        print(f"Sample {i+1} (true class {true_class}):")
        print(f"  Input: {xi}")
        print(f"  Output spike times: {o_times}")
        print(f"  Target spike times: {yi}")
        print(f"  Predicted class: {pred_class}, Correct: {pred_class == true_class}")
        print(f"  Output separation: {max(o_times) - min(o_times):.4f}\n")

    # Overall accuracy
    print(f"Overall accuracy: {np.sum(np.fromiter((np.argmax(layer_forward(layer_forward(xi, W1_tr, 1), W2_tr, 2)) == np.argmax(yi) for xi, yi in zip(X, Y)), dtype=int)) / len(X):.3f}")

Epoch 1: Hidden times range: [1.3470, 1.6060], std: 0.0787
Output times: [2.541 2.497 2.496], separation: 0.0450
Sample loss: L_t=0.3875, L_n=0.1880, L_sep=1.1139, Total=1.6894
dW1 range: [-0.1089, 0.0965], std: 0.0391
dW2 range: [-0.0532, 0.1740], std: 0.0367
Epoch 1: Hidden times range: [1.4320, 1.6090], std: 0.0621
Output times: [2.527 2.521 2.511], separation: 0.0160
Sample loss: L_t=0.4452, L_n=0.2345, L_sep=1.2550, Total=1.9347
dW1 range: [-0.0872, 0.1170], std: 0.0413
dW2 range: [-0.0701, 0.1465], std: 0.0333
Epoch 1: Hidden times range: [1.4440, 1.5950], std: 0.0518
Output times: [2.462 2.506 2.517], separation: 0.0550
Sample loss: L_t=0.4572, L_n=0.1669, L_sep=1.0672, Total=1.6914
dW1 range: [-0.0902, 0.1055], std: 0.0379
dW2 range: [-0.0499, 0.1663], std: 0.0337
Epoch 1: Hidden times range: [1.4530, 1.6410], std: 0.0508
Output times: [2.454 2.491 2.478], separation: 0.0370
Sample loss: L_t=0.6676, L_n=0.1656, L_sep=1.1520, Total=1.9852
dW1 range: [-0.0780, 0.0772], std: 0.035

KeyboardInterrupt: 

In [8]:
# Improved training function with better loss function and target alignment
def train_snn_backprop_improved(
    X, Y,
    W1_init, W2_init,
    epochs=30, lr=0.2,
    max_grad=20.0, w_min=-20.0, w_max=20.0,
    λ=0.5, separation_margin=0.1
):
    W1, W2 = W1_init.copy(), W2_init.copy()
    beta = 0.9
    vW1 = np.zeros_like(W1)
    vW2 = np.zeros_like(W2)
    N = len(X)

    for ep in range(epochs):
        acc_dW1 = np.zeros_like(W1)
        acc_dW2 = np.zeros_like(W2)
        epoch_loss = 0.0
        correct_predictions = 0

        for xi, yi in zip(X, Y):
            # Forward
            h_times = layer_forward(xi, W1, 1)
            o_times = layer_forward(h_times, W2, 2)

            # Find target and predicted classes
            t_idx = np.argmax(yi)
            p_idx = np.argmax(o_times)
            
            # Improved loss function: encourage separation
            delta_o = np.zeros_like(o_times)
            
            # For the target class: encourage it to have the latest spike time
            target_time = np.max(o_times) + separation_margin
            delta_o[t_idx] = (o_times[t_idx] - target_time)
            
            # For non-target classes: encourage them to have earlier spike times
            non_target_time = np.min(o_times) - separation_margin
            for j in range(len(o_times)):
                if j != t_idx:
                    delta_o[j] = λ * (o_times[j] - non_target_time)

            # Compute loss
            loss_target = 0.5 * (o_times[t_idx] - target_time)**2
            loss_non_target = 0.5 * λ * sum((o_times[j] - non_target_time)**2 
                                          for j in range(len(o_times)) if j != t_idx)
            epoch_loss += (loss_target + loss_non_target)
            
            # Track accuracy
            if p_idx == t_idx:
                correct_predictions += 1

            # Gradients for W2 (output layer)
            aug_h = np.concatenate((h_times, [0.0]))  # Add bias
            dW2 = np.zeros_like(W2)
            
            # Store derivatives for backprop
            dspike_dw_hidden = np.zeros((len(aug_h), len(o_times)))
            
            for k in range(W2.shape[0]):  # For each hidden unit (+ bias)
                for j in range(W2.shape[1]):  # For each output unit
                    # Compute derivative of spike timing w.r.t. weight
                    deriv = d_spike_timing_dw(W2[k, j], aug_h[k], 2, 0, 1)
                    dW2[k, j] = delta_o[j] * deriv
                    # Store for backprop
                    dspike_dw_hidden[k, j] = deriv

            # Backprop to hidden layer
            delta_h = np.zeros_like(h_times)
            for k in range(len(h_times)):  # For each hidden unit
                delta_h[k] = 0.0
                for j in range(W2.shape[1]):  # Sum over all output units
                    # Gradient flows back through the weight W2[k,j]
                    delta_h[k] += delta_o[j] * W2[k, j] * dspike_dw_hidden[k, j]

            # Gradients for W1 (hidden layer)
            aug_x = np.concatenate((xi, [0.0]))  # Add bias
            dW1 = np.zeros_like(W1)
            for i in range(W1.shape[0]):  # For each input (+ bias)
                for k in range(W1.shape[1]):  # For each hidden unit
                    dW1[i, k] = delta_h[k] * d_spike_timing_dw(
                        W1[i, k], aug_x[i], 1, 0, 1)

            acc_dW1 += dW1
            acc_dW2 += dW2

        # Average & clip
        acc_dW1 /= N
        acc_dW2 /= N
        g1 = np.clip(acc_dW1, -max_grad, max_grad)
        g2 = np.clip(acc_dW2, -max_grad, max_grad)
        
        # Calculate accuracy
        accuracy = correct_predictions / N
        
        # Debug: gradient norms
        print(f"Epoch {ep+1}: ||g1||={np.linalg.norm(g1):.4f}, ||g2||={np.linalg.norm(g2):.4f}, acc={accuracy:.3f}")

        # Momentum (skip first 5 iters)
        if ep >= 5:
            vW1 = beta*vW1 + (1-beta)*g1
            vW2 = beta*vW2 + (1-beta)*g2
            upd1, upd2 = vW1, vW2
        else:
            upd1, upd2 = g1, g2

        # Update & clamp
        W1 -= lr * upd1
        W2 -= lr * upd2
        W1 = np.clip(W1, w_min, w_max)
        W2 = np.clip(W2, w_min, w_max)

        print(f"  avg loss={(epoch_loss/N):.4f}")
        
        # Print sample outputs to monitor progress
        if ep % 5 == 0:
            h_sample = layer_forward(X[0], W1, 1)
            o_sample = layer_forward(h_sample, W2, 2)
            print(f"  Sample output: {o_sample}, separation: {np.max(o_sample) - np.min(o_sample):.4f}")

    return W1, W2

# Function to create better targets based on actual output ranges
def create_adaptive_targets(X, W1, W2, num_classes=3):
    """Create targets that are achievable given the current network outputs"""
    
    # Get typical output range
    all_outputs = []
    for xi in X:
        h_times = layer_forward(xi, W1, 1)
        o_times = layer_forward(h_times, W2, 2)
        all_outputs.append(o_times)
    
    all_outputs = np.array(all_outputs)
    min_out = np.min(all_outputs)
    max_out = np.max(all_outputs)
    range_out = max_out - min_out
    
    # Create targets with good separation
    target_range = max(range_out * 1.5, 0.2)  # At least 0.2ms separation
    
    # For 3 classes: early, middle, late
    if num_classes == 3:
        early_time = min_out - target_range * 0.3
        middle_time = (min_out + max_out) / 2
        late_time = max_out + target_range * 0.3
        
        targets = {
            0: np.array([late_time, early_time, early_time]),    # Class 0: first output latest
            1: np.array([early_time, late_time, early_time]),    # Class 1: second output latest  
            2: np.array([early_time, early_time, late_time])     # Class 2: third output latest
        }
    
    return targets

# Enhanced main function with adaptive targets
def main_improved():
    np.random.seed(42)
    
    # Your data
    x0 = np.array([0.9, 0.7, 0.3, 0.4])
    x1 = np.array([0.6, 0.7, 0.8, 0.9])
    X = [x0 if i % 2 == 0 else x1 for i in range(8)]
    
    # Class labels (0 or 2)
    class_labels = [2 if i % 2 == 0 else 0 for i in range(8)]
    
    # Better weight initialization
    W1_0 = np.random.randn(5, 10) * 0.5  # Slightly larger initialization
    W2_0 = np.random.randn(11, 3) * 0.5

    print("=== Creating Adaptive Targets ===")
    target_templates = create_adaptive_targets(X, W1_0, W2_0)
    print("Target templates:")
    for class_idx, target in target_templates.items():
        print(f"  Class {class_idx}: {target}")
    
    # Create Y based on class labels
    Y = [target_templates[class_labels[i]] for i in range(len(X))]
    
    print("\n=== Initial Analysis ===")
    analyze_weights_and_outputs(X, Y, W1_0, W2_0)

    # Train with improved loss function
    W1_tr, W2_tr = train_snn_backprop_improved(
        X, Y, W1_0, W2_0,
        epochs=30, lr=0.25, max_grad=15.0, 
        separation_margin=0.05
    )
    
    print("\n=== Final Analysis ===")
    analyze_weights_and_outputs(X, Y, W1_tr, W2_tr)
    
    # Test predictions
    print("\n=== Test Predictions ===")
    all_correct = True
    for i, (xi, yi) in enumerate(zip(X, Y)):
        h_times = layer_forward(xi, W1_tr, 1)
        o_times = layer_forward(h_times, W2_tr, 2)

        pred_class = np.argmax(o_times)
        true_class = np.argmax(yi)
        is_correct = pred_class == true_class
        all_correct &= is_correct

        print(f"Sample {i+1} (true class {true_class}):")
        print(f"  Input: {xi}")
        print(f"  Output spike times: {o_times}")
        print(f"  Target spike times: {yi}")
        print(f"  Predicted class: {pred_class}, Correct: {is_correct}")
        print(f"  Output separation: {np.max(o_times) - np.min(o_times):.4f}")
        print()
    
    print(f"Overall accuracy: {np.sum(np.fromiter((np.argmax(layer_forward(layer_forward(xi, W1_tr, 1), W2_tr, 2)) == np.argmax(yi) for xi, yi in zip(X, Y)), dtype=int)) / len(X):.3f}")


    return W1_tr, W2_tr

if __name__ == "__main__":
    W1_final, W2_final = main_improved()

=== Creating Adaptive Targets ===
Target templates:
  Class 0: [2.8194 2.3026 2.3026]
  Class 1: [2.3026 2.8194 2.3026]
  Class 2: [2.3026 2.3026 2.8194]

=== Initial Analysis ===

=== Weight Analysis ===
W1 stats: mean=-0.1578, std=0.6470, range=[-1.3718, 1.2966]
W2 stats: mean=0.0633, std=0.6728, range=[-1.8338, 1.0953]

=== Hidden Layer Outputs ===
Input 1: [0.9 0.7 0.3 0.4]
Hidden times: [1.741 1.506 1.499 1.265 1.726 1.516 1.425 1.7   1.573 1.53 ]
Hidden range: [1.2650, 1.7410], std: 0.1391
Output times: [2.486 2.585 2.635]
Output range: [2.4860, 2.6350], std: 0.0619

Input 2: [0.6 0.7 0.8 0.9]
Hidden times: [1.399 1.573 1.545 1.247 1.561 1.381 1.538 1.433 1.719 1.476]
Hidden range: [1.2470, 1.7190], std: 0.1234
Output times: [2.491 2.389 2.48 ]
Output range: [2.3890, 2.4910], std: 0.0457



  loss_non_target = 0.5 * λ * sum((o_times[j] - non_target_time)**2


Epoch 1: ||g1||=0.0176, ||g2||=0.0656, acc=0.125
  avg loss=0.0211
  Sample output: [2.516 2.49  2.538], separation: 0.0480
Epoch 2: ||g1||=0.0123, ||g2||=0.0475, acc=0.250
  avg loss=0.0064
Epoch 3: ||g1||=0.0233, ||g2||=0.0657, acc=0.250
  avg loss=0.0178
Epoch 4: ||g1||=0.0145, ||g2||=0.0508, acc=0.625
  avg loss=0.0101
Epoch 5: ||g1||=0.0298, ||g2||=0.0687, acc=0.250
  avg loss=0.0204
Epoch 6: ||g1||=0.0172, ||g2||=0.0968, acc=0.250
  avg loss=0.0201
  Sample output: [2.53  2.512 2.558], separation: 0.0460
Epoch 7: ||g1||=0.0236, ||g2||=0.0720, acc=0.625
  avg loss=0.0144
Epoch 8: ||g1||=0.0335, ||g2||=0.0639, acc=0.125
  avg loss=0.0186
Epoch 9: ||g1||=0.0190, ||g2||=0.0677, acc=0.125
  avg loss=0.0138
Epoch 10: ||g1||=0.0258, ||g2||=0.0613, acc=0.500
  avg loss=0.0192
Epoch 11: ||g1||=0.0171, ||g2||=0.0479, acc=0.375
  avg loss=0.0117
  Sample output: [2.426 2.514 2.476], separation: 0.0880
Epoch 12: ||g1||=0.0199, ||g2||=0.0608, acc=0.125
  avg loss=0.0163
Epoch 13: ||g1||=0.019

  print(f"Overall accuracy: {sum(np.argmax(layer_forward(layer_forward(xi, W1_tr, 1), W2_tr, 2)) == np.argmax(yi) for xi, yi in zip(X, Y)) / len(X):.3f}")


Sample 8 (true class 0):
  Input: [0.6 0.7 0.8 0.9]
  Output spike times: [2.557 2.628 2.441]
  Target spike times: [2.8194 2.3026 2.3026]
  Predicted class: 1, Correct: False
  Output separation: 0.1870

Overall accuracy: 0.500
