In [17]:
import numpy as np
import matplotlib.pyplot as plt
from brian2 import *
import os
import random
from entropy.entropy import *

In [18]:
def _load_split(split_dir):
    """
    Load every ``.npz`` sample found directly inside ``split_dir``.

    Args:
        split_dir (str): Directory containing one ``.npz`` file per sample, each
            holding a ``spike_matrix`` array and a scalar ``label``.

    Returns:
        list: ``(spike_matrix, label)`` tuples, in sorted filename order.
    """
    samples = []
    # os.listdir() returns entries in arbitrary order; sorting makes the
    # pre-shuffle order (and therefore the seeded shuffle) reproducible.
    for filename in sorted(os.listdir(split_dir)):
        if not filename.endswith(".npz"):
            continue
        # np.load on an .npz returns an NpzFile that keeps the archive open
        # until it is closed; the context manager releases the file handle.
        with np.load(os.path.join(split_dir, filename)) as data:
            spike_matrix = data['spike_matrix']
            label = data['label'].item()
        samples.append((spike_matrix, label))
    return samples


# Root of the spike dataset. Defaults to the original location; set the
# SPIKE_DATA_DIR environment variable to run the notebook on another machine.
data_root = os.environ.get(
    "SPIKE_DATA_DIR",
    "/Users/minhhieunguyen/Documents/Projects/Dissertation/Code/spike_data/sr1000",
)
train_path = os.path.join(data_root, "train")
test_path = os.path.join(data_root, "test")

# load train data
training_data = _load_split(train_path)

# load test data
test_data = _load_split(test_path)

# Shuffle with a dedicated, seeded generator so the presentation order is the
# same on every run without touching the global `random` state.
shuffle_rng = random.Random(42)
shuffle_rng.shuffle(training_data)
shuffle_rng.shuffle(test_data)

print(f"Loaded {len(training_data)} training samples and {len(test_data)} test samples.")

Loaded 160 training samples and 40 test samples.


In [19]:
def counts_from_new_spikes(mon, start_idx, n_outputs):
    """
    Count, per output neuron, the spikes a monitor recorded from ``start_idx`` onwards.

    The monitor accumulates spikes across every presented input, so the caller
    passes the number of spikes already seen to isolate the newest ones.

    Args:
        mon (SpikeMonitor): Monitor whose ``i`` attribute lists the neuron index of each spike.
        start_idx (int): Position in ``mon.i`` of the first spike to count.
        n_outputs (int): Number of output neurons (minimum length of the result).

    Returns:
        np.ndarray: Integer spike count per neuron index; all zeros when there
        are no spikes at or after ``start_idx``.
    """
    fresh = np.array(mon.i[start_idx:])
    # The explicit empty case avoids calling bincount on an empty slice.
    return np.bincount(fresh, minlength=n_outputs) if fresh.size else np.zeros(n_outputs, dtype=int)

def infer_mapping_mean(train_log_counts, n_output=2, labels_tuple=(0, 1)):
    """
    Assign each of the two output neurons to a label using mean training spike counts.

    Two candidate assignments are scored: one that gives the first label its
    most active neuron, and one that gives the second label its most active
    neuron; the remaining label always receives the other neuron (``1 - index``,
    so the assignment logic only handles two output neurons). The candidate
    with the larger summed mean count wins, ties going to the first.

    Args:
        train_log_counts (dict): Maps each label to a list of per-sample spike-count vectors.
        n_output (int): Number of output neurons; only used to size the all-zero
            row for a label that has no samples.
        labels_tuple (tuple): The labels, in the row order used for ``M``.

    Returns:
        tuple: ``(neuron_to_label, M)`` where ``neuron_to_label`` maps an output
        neuron index to its label and ``M`` holds the mean spike counts
        (rows are labels, columns are output neurons).
    """
    # One row of mean counts per label; a label without samples gets zeros.
    rows = []
    for lbl in labels_tuple:
        samples = train_log_counts[lbl]
        if len(samples) > 0:
            rows.append(np.mean(np.vstack(samples), axis=0))
        else:
            rows.append(np.zeros(n_output, dtype=int))
    M = np.vstack(rows)

    first_lbl = labels_tuple[0]
    second_lbl = labels_tuple[1]

    # Candidate A: first label takes its strongest neuron.
    best_for_first = int(np.argmax(M[0]))
    candidate_a = {first_lbl: best_for_first, second_lbl: 1 - best_for_first}
    score_a = M[0, candidate_a[first_lbl]] + M[1, candidate_a[second_lbl]]

    # Candidate B: second label takes its strongest neuron.
    best_for_second = int(np.argmax(M[1]))
    candidate_b = {second_lbl: best_for_second, first_lbl: 1 - best_for_second}
    score_b = M[1, candidate_b[second_lbl]] + M[0, candidate_b[first_lbl]]

    if score_a >= score_b:
        winner = candidate_a
    else:
        winner = candidate_b

    # Invert label -> neuron into neuron -> label for use at prediction time.
    neuron_to_label = {}
    for lbl, neuron in winner.items():
        neuron_to_label[neuron] = lbl

    return neuron_to_label, M

In [20]:
# Single shared calculator used by mi_matrix / hc_matrix below and by the
# evaluation loop (EntropyCalculator presumably comes from the star-import of
# entropy.entropy; verify against that module).
entropy_calculator = EntropyCalculator()
def mi_matrix(output_spikes, input_spikes):
    """
    Accumulate mutual information between every input channel and each output neuron.

    For each row of ``input_spikes`` the shared ``entropy_calculator`` is asked
    for the mutual information with output neuron 0 and with output neuron 1;
    the values are summed over all input rows. The "overall" figure is the sum
    of the two per-neuron totals, not a joint estimate.

    Intended formula: I(S;R) = ∑p(s,r) log(p(s,r)/(p(s)p(r)))

    where:
        - S is the input spikes
        - R is the output spikes
        - p(s,r) is the joint probability of input and output spikes
        - p(s) is the marginal probability of input spikes
        - p(r) is the marginal probability of output spikes

    Args:
        output_spikes (list): Spikes from the output layer; rows 0 and 1 are used.
        input_spikes (list): Spikes from the input layer, one row per input channel.

    Returns:
        tuple: ``(neuron0, neuron1, overall)`` mutual-information totals.

    Raises:
        ValueError: If iterating ``input_spikes`` does not yield numpy arrays
            (for example when a 1-D sequence is passed).
    """
    response = np.array(output_spikes)
    stimulus = np.array(input_spikes)

    totals = [0, 0]
    for channel in stimulus:
        if isinstance(channel, np.ndarray):
            for k in (0, 1):
                totals[k] += entropy_calculator.mutual_information(channel, response[k])
        else:
            raise ValueError("Input spikes must be a numpy array.")

    mi_n0, mi_n1 = totals
    return mi_n0, mi_n1, mi_n0 + mi_n1

def hc_matrix(output_spikes, input_spikes):
    """
    Accumulate the conditional entropy of each output neuron given every input channel.

    For each row of ``input_spikes`` the shared ``entropy_calculator`` is asked
    for the conditional entropy of output neuron 0 and of output neuron 1 given
    that row; the values are summed over all input rows. The "overall" figure
    is the sum of the two per-neuron totals.

    Intended formula: H(R|S) = -∑p(s)∑p(r|s) log(p(r|s))

    Args:
        output_spikes (list): Spikes from the output layer; rows 0 and 1 are used.
        input_spikes (list): Spikes from the input layer, one row per input channel.

    Returns:
        tuple: ``(neuron0, neuron1, overall)`` conditional-entropy totals.

    Raises:
        ValueError: If iterating ``input_spikes`` does not yield numpy arrays
            (for example when a 1-D sequence is passed).
    """
    response = np.array(output_spikes)
    stimulus = np.array(input_spikes)

    hc_n0 = 0
    hc_n1 = 0
    for channel in stimulus:
        if not isinstance(channel, np.ndarray):
            raise ValueError("Input spikes must be a numpy array.")

        hc_n0 = hc_n0 + entropy_calculator.conditional_entropy(response[0], channel)
        hc_n1 = hc_n1 + entropy_calculator.conditional_entropy(response[1], channel)

    # The model-level value only depends on the final per-neuron totals.
    return hc_n0, hc_n1, hc_n0 + hc_n1

In [50]:
# Select Brian2's 'numpy' code-generation target for all subsequent simulations
# (set once here, before the network is built in the next cell).
prefs.codegen.target = 'numpy'

In [60]:
# Network construction: 8 spike-generator inputs fully connected to 2 leaky
# output neurons through plastic (STDP-style) synapses.

# 1 ms simulation step; run_sample maps one spike-matrix column to one step.
defaultclock.dt = 1 * ms
# Begin a new Brian2 scope (presumably so objects from earlier executions of
# this cell are not picked up by collect() below — confirm in the Brian2 docs).
start_scope()

n_input = 8
n_output = 2
n_epoch = 5
sim_duration = 1000 * ms  # presentation time of a single sample
T = int(sim_duration / defaultclock.dt)  # number of simulation steps per sample

taupre = taupost = 50 * ms  # decay time constants of the Apre / Apost traces
taum = 20 * ms              # membrane time constant (dv/dt)
taue = 5 * ms               # decay time constant of the synaptic drive ge
gmax = 0.8                  # upper clip bound for synaptic weights
dApre = 3e-3                # increment of Apre on each presynaptic spike
dApost = dApre              # increment of Apost on each postsynaptic spike
# v_thresh = 0.2
is_inhibitated = False      # becomes True once the training loop has added lateral inhibition
wmin = 0.01                 # lower clip bound for synaptic weights

# Output neuron model: v relaxes towards ge with time constant taum and is
# additionally driven by I_ext (the label-dependent bias set in run_sample);
# ge decays with taue; v_th is a per-neuron threshold adjusted by homeostasis.
eqs = '''
        dv/dt = (ge - v) / taum + I_ext : 1
        dge/dt = -ge / taue : 1
        I_ext : Hz
        v_th : 1
        '''

# input layer — created without spikes; run_sample loads each sample via set_spikes
input_group = SpikeGeneratorGroup(
    n_input,
    [],
    [] * ms
)

# output layer — spikes when v exceeds the per-neuron v_th, then resets v to 0
output_group = NeuronGroup(
    n_output,
    eqs,
    threshold='v > v_th',
    reset='v = 0.0',
    refractory=5*ms,
    method='exact'
)
output_group.v = 0
output_group.v_th = 0.25
output_group.I_ext = 0*Hz

# synapses: input -> output with trace-based weight updates.
# on_pre:  add w to the postsynaptic ge, bump Apre, and depress w by
#          Apost * plastic * w.
# on_post: bump Apost and potentiate w by Apre * plastic * (gmax - w).
# Both updates clip w to [wmin, gmax]; `plastic` (0 or 1) gates learning and
# is toggled by run_sample.
syn = Synapses(
    input_group,
    output_group,
    '''
    w : 1
    plastic : 1
    dApre/dt = -Apre / taupre : 1 (event-driven)
    dApost/dt = -Apost / taupost : 1 (event-driven)
    ''',
    on_pre='''
    ge_post += w
    Apre += dApre
    w = clip(w - Apost * plastic * w, wmin, gmax)
    ''',
    on_post='''
    Apost += dApost
    w = clip(w + Apre * plastic * (gmax - w), wmin, gmax)
    '''
)

syn.connect()  # Full connectivity
seed(42)  # For reproducibility (seeded before the random weight initialisation below)
syn.w = '0.1 + 0.04 * rand()'  # initial weights drawn from [0.1, 0.14)
syn.plastic = 1

# Monitors record every spike for the whole session; run_sample selects the
# spikes of the current sample by time window.
output_mon = SpikeMonitor(output_group)
input_mon = SpikeMonitor(input_group)
# Explicit Network so the simulation can be advanced from inside run_sample
# and extended later with net.add(...).
net = Network(collect())

# One dict of metrics per epoch, appended by the training loop.
logs = []

def rate_homeostasis_update(counts,
                            output_group,
                            T,
                            eta=0.001,
                            r_target=50.0,
                            th_min=0.22,
                            th_max=0.38):
    """
    Nudge each output neuron's threshold towards a target activity level.

    Neurons that fired more than ``r_target`` get a higher threshold, neurons
    that fired less get a lower one; the result is clipped to
    ``[th_min, th_max]`` and written back to ``output_group.v_th``.

    Args:
        counts (array-like): Spike count per output neuron for the last sample.
            The counts are compared to ``r_target`` directly (they are not
            divided by the sample duration).
        output_group: Object exposing a per-neuron ``v_th`` that supports
            ``v_th[:]`` reads and whole-array assignment.
        T (int): Number of time steps per sample. Accepted for call
            compatibility; not used in the update.
        eta (float): Learning rate of the threshold update.
        r_target (float): Target spike count per sample.
        th_min (float): Lower bound for the updated thresholds.
        th_max (float): Upper bound for the updated thresholds.

    Returns:
        None. ``output_group.v_th`` is updated in place.
    """
    # asarray (rather than .astype) also accepts plain lists/tuples of counts.
    rates = np.asarray(counts, dtype=float)
    delta = eta * (rates - r_target)

    output_group.v_th = np.clip(output_group.v_th[:] + delta, th_min, th_max)

def run_sample(spike_matrix,
               net,
               input_group,
               output_mon,
               n_output=2,
               sim_duration=1000 * ms,
               label=None,
               bias=5*Hz,
               t_bias=200*ms,
               enable_plasticity=True):
    """
    Present one sample to the network and return the output layer's response.

    The sample's spikes are shifted to the network's current time, the output
    neurons' ``v`` and ``ge`` are reset, and the network is advanced by
    ``sim_duration``. When plasticity is enabled and a label is given, the run
    is split into three phases: 200 ms without bias, ``t_bias`` with a
    label-dependent external current (``+bias`` on one neuron, ``-bias`` on the
    other), and the remainder without bias.

    Note: besides its arguments, this function reads and modifies the
    module-level ``syn`` (plasticity flag) and ``output_group`` (state reset and
    bias current).

    Args:
        spike_matrix (np.ndarray): Array indexed as (input neuron, time step);
            entries equal to 1 are spikes.
        net (Network): Network to advance.
        input_group (SpikeGeneratorGroup): Input layer that receives the spikes.
        output_mon (SpikeMonitor): Monitor recording the output layer.
        n_output (int): Number of output neurons.
        sim_duration (Quantity): Presentation time of the sample.
        label: Class of the sample. ``0`` biases neuron 0 up and neuron 1 down;
            any other non-None value does the opposite. ``None`` disables the bias.
        bias (Quantity): Magnitude of the supervised bias current.
        t_bias (Quantity): Duration of the biased phase. The final phase lasts
            ``sim_duration - 200 ms - t_bias``, so callers should keep that
            non-negative.
        enable_plasticity (bool): Whether synaptic weights may change.

    Returns:
        tuple: ``(counts, out_spike)`` — ``counts`` is the spike count per
        output neuron within this sample's time window and ``out_spike`` is an
        ``(n_output, T)`` 0/1 array with one column per simulation step.
    """
    syn.plastic = 1 if enable_plasticity else 0
    # Round rather than truncate: a float quotient such as 999.9999999 must
    # still yield 1000 steps.
    T = int(round(float(sim_duration / defaultclock.dt)))
    t_warmup = 200*ms

    # Time window occupied by this sample on the network's global clock.
    t_start = net.t
    t_stop = t_start + sim_duration

    # Load the sample's spikes, shifted to the current simulation time.
    idx, tt = np.where(spike_matrix == 1)
    if idx.size > 0:
        times = (tt * defaultclock.dt) + t_start
        input_group.set_spikes(idx, times)
    else:
        input_group.set_spikes([], []*ms)

    # Start every sample from the same neuronal state.
    output_group.v = 0
    output_group.ge = 0
    if enable_plasticity and (label is not None):
        # Phase 1: unbiased warm-up.
        output_group.I_ext = 0*Hz
        net.run(t_warmup)
        # Phase 2: push the neuron associated with the label, hold back the other.
        if label == 0:
            output_group.I_ext = [bias, -bias]
        else:
            output_group.I_ext = [-bias, bias]
        net.run(t_bias)
        # Phase 3: unbiased remainder of the sample.
        output_group.I_ext = 0*Hz
        net.run(sim_duration - t_warmup - t_bias, report=None)
    else:
        output_group.I_ext = 0*Hz
        net.run(sim_duration, report=None)

    # The monitor holds spikes of all samples so far; keep only this window.
    mask = (output_mon.t >= t_start) & (output_mon.t < t_stop)
    i_sel = np.asarray(output_mon.i[mask])
    t_sel = output_mon.t[mask]

    counts = np.bincount(i_sel, minlength=n_output) if i_sel.size > 0 else np.zeros(n_output, dtype=int)

    out_spike = np.zeros((n_output, T), dtype=int)
    if i_sel.size > 0:
        # Spike times lie on the simulation grid, but (t - t_start) / dt is a
        # float that can land just below the integer step (e.g. 2.9999999...).
        # Flooring would shift such spikes one bin early, so round to the
        # nearest step instead.
        rel_steps = np.asarray((t_sel - t_start) / defaultclock.dt, dtype=float)
        bins = np.clip(np.rint(rel_steps).astype(int), 0, T - 1)
        out_spike[i_sel, bins] = 1
    return counts, out_spike

def normalize_incoming(syn, n_output=2, w_target=1.2):
    """
    Rescale the incoming weights of each output neuron so they sum to ``w_target``.

    Args:
        syn: Synapse object exposing ``j`` (postsynaptic index per synapse) and
            ``w`` (weight per synapse), both readable with ``[:]``; ``w`` must
            accept assignment through an integer index array.
        n_output (int): Number of postsynaptic neurons to process (indices ``0 .. n_output - 1``).
        w_target (float): Desired sum of incoming weights per neuron.

    Returns:
        None. ``syn.w`` is modified in place. Neurons without incoming synapses,
        or whose weights sum to at most 1e-12, are left unchanged.
    """
    targets = syn.j[:]
    weights = syn.w[:]
    for post in range(n_output):
        incoming = np.flatnonzero(targets == post)
        if incoming.size > 0:
            total = float(np.sum(weights[incoming]))
            # Skip (near-)zero sums to avoid dividing by zero.
            if total > 1e-12:
                syn.w[incoming] = weights[incoming] * w_target / total

# Supervised bias schedule: the bias current starts at bias0 and is multiplied
# by bias_decay for every completed epoch.
bias0 = 30*Hz
bias_decay = 0.9

for epoch in range(n_epoch):
    print(f'Epoch {epoch + 1}/{n_epoch}')
    log = {}
    log['epoch'] = epoch + 1

    # From the second epoch on, add lateral inhibition (only once): every
    # output spike lowers ge of the other output neuron(s) by inh_w.
    if epoch >= 1 and not is_inhibitated:
        inh_w = 0.02
        inh = Synapses(
            output_group,
            output_group,
            on_pre='ge_post -= inh_w'
        )
        inh.connect(condition='j != i')  # all-to-all, excluding self-connections
        net.add(inh)
        is_inhibitated = True

    # Training: present every sample with plasticity and the label bias
    # enabled, adapting the thresholds after each sample.
    train_log_counts = {0: [], 1: []}
    curr_bias = bias0 * (bias_decay ** epoch)
    for spike_matrix, label in training_data:
        counts, _ = run_sample(spike_matrix,
                               net=net,
                               input_group=input_group,
                               output_mon=output_mon,
                               n_output=n_output,
                               sim_duration=sim_duration,
                               label=label,
                               bias=curr_bias,
                               t_bias=150*ms,
                               enable_plasticity=True)
        train_log_counts[label].append(counts)
        rate_homeostasis_update(
            counts,
            output_group,
            T
        )
    # Renormalise incoming weights once per epoch, after all training samples.
    normalize_incoming(syn, n_output=n_output, w_target=0.9)

    print("w mean/min/max:",
      float(np.mean(syn.w[:])),
      float(np.min(syn.w[:])),
      float(np.max(syn.w[:])))
    print("v_th:", output_group.v_th[:])

    # Derive the neuron-to-label mapping from this epoch's training responses.
    neuron_to_label, M = infer_mapping_mean(train_log_counts, n_output=n_output, labels_tuple=(0, 1))
    log['mapping'] = neuron_to_label.copy()

    print(f"Neuron to label mapping: {neuron_to_label}")

    # Testing: plasticity and bias are off; metrics are summed over samples
    # and averaged afterwards.
    correct = 0
    total = 0

    neuron0_mi = 0
    neuron1_mi = 0
    neuron0_hc = 0
    neuron1_hc = 0
    neuron0_entropy = 0
    neuron1_entropy = 0
    neuron0_entropy_rate = 0
    neuron1_entropy_rate = 0
    model_mi = 0
    model_hc = 0
    model_entropy = 0
    model_entropy_rate = 0

    neuron0_spikecount = []
    neuron1_spikecount = []

    for spikes, label in test_data:
        counts, out_spike = run_sample(spikes,
                                       net=net,
                                       input_group=input_group,
                                       output_mon=output_mon,
                                       n_output=n_output,
                                       sim_duration=sim_duration,
                                       enable_plasticity=False)

        # Predict the label of the most active output neuron (ties resolve to
        # the lowest neuron index via argmax).
        pred_neuron = int(np.argmax(counts))
        pred_label = neuron_to_label.get(pred_neuron, 0)
        correct += int(pred_label == label)
        total += 1

        neuron0_spikecount.append(counts[0])
        neuron1_spikecount.append(counts[1])

        neuron0_entropy += entropy_calculator.shannon_entropy(out_spike[0])
        neuron1_entropy += entropy_calculator.shannon_entropy(out_spike[1])
        model_entropy += entropy_calculator.shannon_entropy(out_spike)

        neuron0_entropy_rate += entropy_calculator.entropy_rate(out_spike[0])
        neuron1_entropy_rate += entropy_calculator.entropy_rate(out_spike[1])
        model_entropy_rate += entropy_calculator.entropy_rate(out_spike)

        _neuron0_hc, _neuron1_hc, _model_hc = hc_matrix(out_spike, spikes)
        neuron0_hc += _neuron0_hc
        neuron1_hc += _neuron1_hc
        model_hc += _model_hc

        _neuron0_mi, _neuron1_mi, _model_mi = mi_matrix(out_spike, spikes)
        neuron0_mi += _neuron0_mi
        neuron1_mi += _neuron1_mi
        model_mi += _model_mi

    # Same guard for accuracy and for every average below, so an empty test
    # set yields zeros instead of raising ZeroDivisionError.
    n_test = max(total, 1)

    accuracy = correct / n_test
    log['accuracy'] = accuracy
    print(f"Test accuracy: {accuracy:.2f}")

    log['shannon_entropy'] = model_entropy / n_test
    log['entropy_rate'] = model_entropy_rate / n_test
    log['conditional_entropy'] = model_hc / n_test
    log['mutual_information'] = model_mi / n_test

    log['neuron0'] = {
        'shannon_entropy': neuron0_entropy / n_test,
        'entropy_rate': neuron0_entropy_rate / n_test,
        'conditional_entropy': neuron0_hc / n_test,
        'mutual_information': neuron0_mi / n_test,
        'spike_count': neuron0_spikecount
    }

    log['neuron1'] = {
        'shannon_entropy': neuron1_entropy / n_test,
        'entropy_rate': neuron1_entropy_rate / n_test,
        'conditional_entropy': neuron1_hc / n_test,
        'mutual_information': neuron1_mi / n_test,
        'spike_count': neuron1_spikecount
    }

    print(log)

    logs.append(log)

Epoch 1/5
w mean/min/max: 0.1125 0.10899363186605823 0.11409520472741518
v_th: [0.38  0.379]
Neuron to label mapping: {0: 0, 1: 1}
Test accuracy: 0.53
{'epoch': 1, 'mapping': {0: 0, 1: 1}, 'accuracy': 0.525, 'shannon_entropy': 0.10764026010123746, 'entropy_rate': 0.10643325550141494, 'conditional_entropy': 1.6811883371947842, 'mutual_information': 0.04105582442501558, 'neuron0': {'shannon_entropy': 0.10725700697009306, 'entropy_rate': 0.10603410220654523, 'conditional_entropy': 0.8380331995030511, 'mutual_information': 0.020022856257693643, 'spike_count': [11, 28, 14, 33, 13, 21, 17, 10, 10, 9, 1, 10, 10, 15, 16, 16, 13, 16, 8, 17, 7, 11, 12, 16, 11, 10, 16, 11, 9, 15, 7, 20, 21, 35, 11, 10, 37, 15, 12, 9]}, 'neuron1': {'shannon_entropy': 0.10802351323238188, 'entropy_rate': 0.10683240879628467, 'conditional_entropy': 0.843155137691733, 'mutual_information': 0.021032968167321957, 'spike_count': [11, 28, 14, 34, 13, 21, 17, 10, 10, 9, 1, 10, 11, 15, 16, 17, 13, 16, 8, 17, 8, 11, 13, 16,