In [1]:
import numpy as np
import matplotlib.pyplot as plt
from brian2 import *
import os
import random
from entropy.entropy import *

In [2]:
training_data = []
test_data = []

# load train data
train_path = "/Users/minhhieunguyen/Documents/Projects/Dissertation/Code/spike_data/sr1000/train"
for filename in os.listdir(train_path):
    if filename.endswith(".npz"):
        data = np.load(os.path.join(train_path, filename))
        spike_matrix = data['spike_matrix']
        label = data['label'].item()
        training_data.append((spike_matrix, label))

# load test data
test_path = "/Users/minhhieunguyen/Documents/Projects/Dissertation/Code/spike_data/sr1000/test"
for filename in os.listdir(test_path):
    if filename.endswith(".npz"):
        data = np.load(os.path.join(test_path, filename))
        spike_matrix = data['spike_matrix']
        label = data['label'].item()
        test_data.append((spike_matrix, label))

random.shuffle(training_data)
random.shuffle(test_data)

print(f"Loaded {len(training_data)} training samples and {len(test_data)} test samples.")

Loaded 160 training samples and 40 test samples.


In [15]:
print(test_data[0])

(array([[0, 0, 0, ..., 0, 0, 0],
       [0, 0, 0, ..., 0, 0, 0],
       [0, 0, 0, ..., 0, 0, 0],
       ...,
       [0, 0, 0, ..., 0, 0, 0],
       [0, 0, 0, ..., 0, 0, 0],
       [0, 0, 0, ..., 0, 0, 0]], dtype=uint8), 0)


In [3]:
entropy_calculator = EntropyCalculator()
def mi_matrix(output_spikes, input_spikes):
    """
    Calculate the mutual information of the model.
    Calculation includes the MI for each output neuron and the overall MI of the model.
    
    Formula: I(S;R) = ∑p(s,r) log(p(s,r)/(p(s)p(r)))
    
    where:
        - S is the input spikes
        - R is the output spikes
        - p(s,r) is the joint probability of input and output spikes
        - p(s) is the marginal probability of input spikes
        - p(r) is the marginal probability of output spikes

    Args:
        output_spikes (list): Spikes from the output layer.
        input_spikes (list): Spikes from the input layer.

    Returns:
        np.ndarray: List of MI (neuron0, neuron1, overall).
    """
    output_spikes = np.array(output_spikes)
    input_spikes = np.array(input_spikes)
    
    neuron0 = 0
    neuron1 = 0
    model_mi = 0

    for input_spike in input_spikes:
        if not isinstance(input_spike, np.ndarray):
            raise ValueError("Input spikes must be a numpy array.")
        
        neuron0 += entropy_calculator.mutual_information(input_spike, output_spikes[0])
        neuron1 += entropy_calculator.mutual_information(input_spike, output_spikes[1])
        model_mi = neuron0 + neuron1

    return neuron0, neuron1, model_mi

def hc_matrix(output_spikes, input_spikes):
    """
    Calculate the conditional entropy of the model.
    Calculation includes the conditional entropy for each output neuron and the overall conditional entropy of the model.
    
    Formula: H(R|S) = -∑p(s)∑p(r|s) log(p(r|s))
    
    Args:
        output_spikes (list): Spikes from the output layer.
        input_spikes (list): Spikes from the input layer.

    Returns:
        np.ndarray: List of entropies (neuron0, neuron1, overall).
    """
    output_spikes = np.array(output_spikes)
    input_spikes = np.array(input_spikes)
    
    neuron0 = 0
    neuron1 = 0
    model_hc = 0
    
    for input_spike in input_spikes:
        if not isinstance(input_spike, np.ndarray):
            raise ValueError("Input spikes must be a numpy array.")
        
        neuron0 += entropy_calculator.conditional_entropy(output_spikes[0], input_spike)
        neuron1 += entropy_calculator.conditional_entropy(output_spikes[1], input_spike)
        model_hc = neuron0 + neuron1

    return neuron0, neuron1, model_hc

In [4]:
prefs.codegen.target = 'numpy'

In [16]:
dt_ms = 10
defaultclock.dt = dt_ms * ms
start_scope()

n_input = 32
n_output = 2
n_epoch = 20
refrac_ms = 5

# STDP multiplicative soft bound
taupre = taupost = 20 * ms
taum = 20
taue = 10
vth_init = 0.20
vth_min, vth_max = 0.12, 0.60

eta_plus = 4e-3
eta_minus = 0.9 * eta_plus
wmin, gmax = 0.02, 0.6
w_init_mean = 0.18
w_init_spread = 0.06
is_inhibitated = False

# teacher forcing
bias_Hz = 12 * Hz
bias_dur_ms = 100

# homoeostasis + normalization
target_count = 12.0  # target firing rate in Hz
eta_th = 0.01
target_sum = 1.0  # target sum of incoming weights

eqs = '''
        dv/dt = (ge - v) / tau_m + I_ext : 1
        dge/dt = -ge / tau_e : 1
        I_ext : Hz
        v_th : 1
        tau_m : second (constant)
        tau_e : second (constant)
        learn : 1
        '''

# input layer
input_group = SpikeGeneratorGroup(
    n_input,
    [],
    [] * ms
)

# output layer
output_group = NeuronGroup(
    n_output,
    eqs,
    threshold='v > v_th',
    reset='v = 0.0',
    refractory=refrac_ms*ms,
    method='euler'
)
output_group.v = 0
output_group.ge = 0
output_group.I_ext = 0*Hz
output_group.v_th = vth_init
output_group.tau_m = taum*ms
output_group.tau_e = taue*ms

#synapse
syn = Synapses(
    input_group,
    output_group,
    '''
    w : 1
    plastic : boolean (shared)
    dApre/dt = -Apre / taupre : 1 (event-driven)
    dApost/dt = -Apost / taupost : 1 (event-driven)
    ''',
    on_pre=f'''
    ge_post += w
    Apre += 1
    w = clip(w + Apost * plastic * {eta_plus} * learn_post * ({gmax} - w), {wmin}, {gmax})
    ''',
    on_post=f'''
    Apost += 1
    w = clip(w - Apre * plastic * {eta_minus} * (1 - learn_post) * (w - {wmin}), {wmin}, {gmax})
    ''',
    method='euler'
)

syn.connect()  # Full connectivity
np.random.seed(42)  # For reproducibility
rand = np.random.uniform(-1.0, 1.0, size=len(syn.w[:]))
syn.w = np.clip(w_init_mean + w_init_spread * rand, wmin, gmax)
syn.plastic = True

inh = Synapses(
    output_group,
    output_group,
    '''
    beta: 1 (shared)
    ''',
    on_pre='v_post -= beta',
    method='euler'
)
inh.connect(condition='j != i')  # Full connectivity
inh.beta = 0.03

output_mon = SpikeMonitor(output_group)
net = Network(collect())

logs = []

def set_spikes_from_matrix(S, t_start, input_group):
    """
    Set spikes in the input group from a spike matrix.
    
    Args:
        S (np.ndarray): The spike matrix where rows are neurons and columns are time steps.
        t_start (float): The start time of the simulation.
        input_group (SpikeGeneratorGroup): The input group to set spikes for.
    """
    n_in, T = S.shape
    i, t = np.where(S == 1)
    if i.size == 0:
        input_group.set_spikes([], []*ms)
        return T
    times = t_start + (t * defaultclock.dt)
    input_group.set_spikes(i, times)
    return T

def run_one(S,
            label,
            net,
            input_group,
            output_group,
            output_mon,
            syn,
            enable_plasticity=True):
    syn.plastic = enable_plasticity
    output_group.v = 0
    output_group.ge = 0
    output_group.learn = 0
    
    S = np.asarray(S, dtype=np.uint8)
    
    T = S.shape[1]
    sim_duration = T * defaultclock.dt
    t0 = net.t
    set_spikes_from_matrix(S, t0, input_group)
    
    if enable_plasticity and (label is not None):
        inh.beta = 0.0
        if label == 0:
            output_group.I_ext = [bias_Hz, 0*Hz]
            output_group.learn = [1.0, 0.0]
        else:
            output_group.I_ext = [0*Hz, bias_Hz]
            output_group.learn = [0.0, 1.0]
        net.run(bias_dur_ms * ms, report=None)
        
        output_group.I_ext = 0*Hz
        output_group.learn = 0.0
        inh.beta = 0.03
        net.run(sim_duration - bias_dur_ms * ms, report=None)
    else:
        output_group.learn = 0.0
        output_group.I_ext = 0*Hz
        net.run(sim_duration, report=None)

    # get spike counts
    mask = (output_mon.t >= t0) & (output_mon.t < t0 + sim_duration)
    i_sel = np.asarray(output_mon.i[mask])

    counts = np.bincount(i_sel, minlength=n_output) if i_sel.size else np.zeros(n_output, dtype=int)
    
    t_steps = np.asarray(output_mon.t[mask] / defaultclock.dt)
    t0_step = float(t0 / defaultclock.dt)
    
    out_spike = np.zeros((n_output, T), dtype=np.uint8)
    
    if i_sel.size:
        bins = np.floor(t_steps - t0_step).astype(int)
        bins = np.clip(bins, 0, T - 1)
        out_spike[i_sel, bins] = 1
    return counts, out_spike

def update_homeostasis(counts, output_group):
    delta = eta_th * (counts.astype(float) - target_count)
    v_new = np.clip(output_group.v_th[:] + delta, vth_min, vth_max)
    output_group.v_th[:] = v_new

def renorm_incoming(syn,
                   n_output=n_output,
                   w_sum=target_sum):
    j_all = syn.j[:]
    for j in range(n_output):
        idx = np.where(j_all == j)[0]
        if idx.size == 0:
            continue
        s = float(np.sum(syn.w[idx]))
        if s > 0:
            syn.w[idx] = syn.w[idx] * (w_sum / s)
            
def print_incoming_sum(tag=""):
    sums = np.bincount(syn.j[:], weights=syn.w[:], minlength=n_output)
    print(f"{tag} Incoming weights sum: {np.round(sums, 3)}")

def infer_mapping_mean(train_counts_dict):
    labels = sorted(train_counts_dict.keys())
    M = []
    for lb in labels:
        X = np.vstack(train_counts_dict[lb]) if len(train_counts_dict[lb]) else np.zeros((1, n_output))
        M.append(np.mean(X, axis=0))
    M = np.vstack(M)
    
    mapping = {}
    for j in range(n_output):
        idx = int(np.argmax(M[:, j]))
        mapping[j] = labels[idx]
    return mapping, M

# training and testing
random.seed(42)  # For reproducibility
np.random.seed(42)  # For reproducibility

prev_w = syn.w[:].copy()
renorm_incoming(syn, w_sum=2.5)

for epoch in range(n_epoch):
    print(f'Epoch {epoch + 1}/{n_epoch}')
    log = {}
    log['epoch'] = epoch + 1
    
    # Training
    train_log_counts = {0: [], 1: []}
    epoch_counts = np.zeros(n_output, dtype=float)
    total_seconds = 0.0
    for S, label in training_data:
        counts, _ = run_one(S,
                            label,
                            net=net,
                            input_group=input_group,
                            output_group=output_group,
                            output_mon=output_mon,
                            syn=syn,
                            enable_plasticity=True)
        train_log_counts[label].append(counts)
        epoch_counts += counts
        total_seconds += (S.shape[1] * float(defaultclock.dt / second))
        # update_homeostasis(counts, output_group)
    mean_rate = epoch_counts / (total_seconds / n_output)
    delta = eta_th * (mean_rate - target_count)
    output_group.v_th[:] = np.clip(output_group.v_th[:] + delta, vth_min, vth_max)
    renorm_incoming(syn, w_sum=2.5)
    print_incoming_sum("END EPOCH")
        
    dw = float(np.mean(np.abs(syn.w[:] - prev_w)))
    prev_w = syn.w[:].copy()
    print(f"w mean/min/max: {np.mean(prev_w):.3f} {np.min(prev_w):.3f} {np.max(prev_w):.3f} | (dw={dw:.3e})")
    print("v_th:", output_group.v_th[:])

    # deprive mapping neuron-to-label
    neuron_to_label, M = infer_mapping_mean(train_log_counts)
    #log['mapping'] = neuron_to_label.copy()
    print(f"Neuron to label mapping: {neuron_to_label}")
    print(f"Mean spike (rows= labels, columns=output neurons):\n{np.round(M, 2)}")
    
    # Testing
    correct = 0
    total = 0
    
    neuron0_mi = 0
    neuron1_mi = 0
    neuron0_hc = 0
    neuron1_hc = 0
    neuron0_entropy = 0
    neuron1_entropy = 0
    neuron0_entropy_rate = 0
    neuron1_entropy_rate = 0
    model_mi = 0
    model_hc = 0
    model_entropy = 0
    model_entropy_rate = 0
    
    neuron0_spikecount = []
    neuron1_spikecount = []

    for spikes, label in test_data:
        counts, out_spike = run_one(spikes,
                                    label,
                                    net=net,
                                    input_group=input_group,
                                    output_group=output_group,
                                    output_mon=output_mon,
                                    syn=syn,
                                    enable_plasticity=False)

        pred_neuron = int(np.argmax(counts))
        pred_label = neuron_to_label.get(pred_neuron, 0)
        correct += int(pred_label == label)
        total += 1
        
        neuron0_spikecount.append(counts[0])
        neuron1_spikecount.append(counts[1])
        
        # print(out_spike.shape)
        # print(spikes.shape)

        neuron0_entropy += entropy_calculator.shannon_entropy(out_spike[0])
        neuron1_entropy += entropy_calculator.shannon_entropy(out_spike[1])
        model_entropy += entropy_calculator.shannon_entropy(out_spike)
        
        neuron0_entropy_rate += entropy_calculator.entropy_rate(out_spike[0])
        neuron1_entropy_rate += entropy_calculator.entropy_rate(out_spike[1])
        model_entropy_rate += entropy_calculator.entropy_rate(out_spike)

        _neuron0_hc, _neuron1_hc, _model_hc = hc_matrix(out_spike, spikes)
        neuron0_hc += _neuron0_hc
        neuron1_hc += _neuron1_hc
        model_hc += _model_hc
        
        _neuron0_mi, _neuron1_mi, _model_mi = mi_matrix(out_spike, spikes)
        neuron0_mi += _neuron0_mi
        neuron1_mi += _neuron1_mi
        model_mi += _model_mi

    accuracy = correct / max(total, 1)
    log['accuracy'] = accuracy
    print(f"Test accuracy: {accuracy:.2f}")
    
    log['shannon_entropy'] = model_entropy / len(test_data)
    log['entropy_rate'] = model_entropy_rate / len(test_data)
    log['conditional_entropy'] = model_hc / len(test_data)
    log['mutual_information'] = model_mi / len(test_data)
    
    log['neuron0'] = {
        'shannon_entropy': neuron0_entropy / len(test_data),
        'entropy_rate': neuron0_entropy_rate / len(test_data),
        'conditional_entropy': neuron0_hc / len(test_data),
        'mutual_information': neuron0_mi / len(test_data),
        'spike_count': neuron0_spikecount
    }
    
    log['neuron1'] = {
        'shannon_entropy': neuron1_entropy / len(test_data),
        'entropy_rate': neuron1_entropy_rate / len(test_data),
        'conditional_entropy': neuron1_hc / len(test_data),
        'mutual_information': neuron1_mi / len(test_data),
        'spike_count': neuron1_spikecount
    }
    
    print(log)
    
    logs.append(log)

Epoch 1/20
END EPOCH Incoming weights sum: [2.5 2.5]
w mean/min/max: 0.078 0.062 0.095 | (dw=9.793e-02)
v_th: [0.33475 0.35575]
Neuron to label mapping: {0: 0, 1: 1}
Mean spike (rows= labels, columns=output neurons):
[[13.88 10.89]
 [11.6  16.69]]
Test accuracy: 0.50
{'epoch': 1, 'accuracy': 0.5, 'shannon_entropy': 0.5217066750131049, 'entropy_rate': 0.3115704006576322, 'conditional_entropy': 31.674331071470583, 'mutual_information': 1.7148961293681382, 'neuron0': {'shannon_entropy': 0.5392537406031499, 'entropy_rate': 0.31946514582306873, 'conditional_entropy': 16.386018008588685, 'mutual_information': 0.870101690712113, 'spike_count': [8, 12, 10, 11, 9, 19, 10, 7, 15, 22, 39, 12, 9, 6, 13, 12, 12, 14, 21, 8, 9, 5, 9, 25, 5, 14, 14, 9, 15, 6, 8, 12, 30, 21, 6, 11, 8, 10, 37, 20]}, 'neuron1': {'shannon_entropy': 0.5041596094230595, 'entropy_rate': 0.3036756554921958, 'conditional_entropy': 15.288313062881889, 'mutual_information': 0.8447944386560247, 'spike_count': [6, 10, 6, 9, 9, 15,

KeyboardInterrupt: 