In [None]:
import torch
from torchvision import datasets, transforms
import numpy as np
import matplotlib.pyplot as plt
from brian2 import *
from tqdm import tqdm

In [None]:
# Load the MNIST training split as float tensors.
# ToTensor() scales pixel intensities to [0, 1]; poisson_encode later compares
# uniform random draws against these values, i.e. uses them as per-time-step
# spike probabilities.
mnist_transform = transforms.Compose([
    transforms.ToTensor(),
])
# download=True only fetches the files when they are missing under root, so the
# notebook also survives Restart & Run All on a fresh machine.
mnist_train = datasets.MNIST(root='./data', train=True, download=True, transform=mnist_transform)
# The test split is not used in this notebook; load it the same way with train=False if needed.


In [None]:
# encode images to spike trains
# using poisson encoding
def poisson_encode(images, duration):
    """Rate-encode pixel intensities as binary spike trains.

    Every pixel gets an independent Bernoulli draw at each of `duration` time
    steps (a discrete-time approximation of Poisson rate coding): a spike is
    emitted when a uniform [0, 1) sample falls below the pixel value.

    Args:
        images: tensor of intensities, expected in [0, 1]; the notebook passes
            MNIST batches shaped [B, 1, 28, 28].
        duration: number of time steps to generate.

    Returns:
        float tensor of 0/1 spikes with shape images.shape + (duration,),
        i.e. [B, 1, 28, 28, T] for MNIST batches.
    """
    thresholds = images[..., None]  # trailing time axis for broadcasting
    uniform_draws = torch.rand((*images.shape, duration))
    return (uniform_draws < thresholds).float()

In [None]:
# function for calculating entropy
def shannon_entropy(spike_train):
    """Binary Shannon entropy (bits) of the firing probability of each row.

    The firing probability of a row is its mean over axis 1 (the time axis in
    this notebook's usage, where rows are neurons). It is clipped to
    [1e-10, 1 - 1e-10] so log2 never sees 0, which makes a never- or
    always-firing row return a tiny positive value instead of exactly 0.

    Args:
        spike_train: 2-D array of 0/1 spikes.

    Returns:
        1-D array with one entropy value per row.
    """
    eps = 1e-10
    rate = np.clip(np.mean(spike_train, axis=1), eps, 1 - eps)
    return -(rate * np.log2(rate) + (1 - rate) * np.log2(1 - rate))

def entropy_rate(spike_train):
    """Binary entropy (bits) of the per-step state-change probability of each row.

    For each row, the fraction of consecutive time steps whose value changes
    (|x[t+1] - x[t]| averaged over axis 1) is treated as a probability and
    passed through the binary entropy function, with the same [1e-10, 1 - 1e-10]
    clipping as shannon_entropy. This is a transition-based proxy, not the
    entropy rate of a fitted stochastic process.

    Args:
        spike_train: 2-D array of 0/1 spikes (rows = neurons, axis 1 = time in
            this notebook). Needs at least 2 columns; with fewer, the mean over
            an empty axis yields NaN.

    Returns:
        1-D array with one value per row.
    """
    eps = 1e-10
    flips = np.abs(np.diff(spike_train, axis=1))
    q = np.clip(np.mean(flips, axis=1), eps, 1 - eps)
    return -(q * np.log2(q) + (1 - q) * np.log2(1 - q))

def _neg_plogp(p):
    """Return -p * log2(p) elementwise, using the convention 0 * log2(0) = 0."""
    p = np.asarray(p, dtype=float)
    # Feed log2 a 1 where p == 0; those terms are multiplied by 0 anyway.
    return -p * np.log2(np.where(p > 0, p, 1.0))


def _binary_entropies(X, Y):
    """Return per-row (H(X), H(Y), H(X, Y)) in bits for two binary spike trains.

    Probabilities are empirical frequencies over axis 1 (time, in this
    notebook's usage). Any nonzero value counts as a spike. X and Y are
    broadcast against each other.
    """
    x, y = np.broadcast_arrays(np.asarray(X, dtype=bool), np.asarray(Y, dtype=bool))
    p_x = np.mean(x, axis=1)
    p_y = np.mean(y, axis=1)
    # The joint entropy needs the full 2x2 distribution of (X, Y); the entropy
    # of the single event "X AND Y" is not H(X, Y).
    p11 = np.mean(x & y, axis=1)
    p10 = np.mean(x & ~y, axis=1)
    p01 = np.mean(~x & y, axis=1)
    p00 = np.mean(~x & ~y, axis=1)
    h_x = _neg_plogp(p_x) + _neg_plogp(1 - p_x)
    h_y = _neg_plogp(p_y) + _neg_plogp(1 - p_y)
    h_xy = _neg_plogp(p11) + _neg_plogp(p10) + _neg_plogp(p01) + _neg_plogp(p00)
    return h_x, h_y, h_xy


def conditional_entropy(X, Y):
    """Per-row conditional entropy H(X | Y) = H(X, Y) - H(Y), in bits.

    Args:
        X, Y: 2-D binary spike arrays (rows = neurons, axis 1 = time).

    Returns:
        1-D array, one value per row; 0 when X == Y, H(X) when Y is constant.
    """
    _, h_y, h_xy = _binary_entropies(X, Y)
    return h_xy - h_y


def mutual_information(X, Y):
    """Per-row mutual information I(X; Y) = H(X) + H(Y) - H(X, Y), in bits.

    Args:
        X, Y: 2-D binary spike arrays (rows = neurons, axis 1 = time).

    Returns:
        1-D array, one value per row; equals H(X) when X == Y and 0 when the
        empirical joint distribution factorizes.
    """
    h_x, h_y, h_xy = _binary_entropies(X, Y)
    return h_x + h_y - h_xy

In [None]:
# build the SNN model with STDP
# input layer with 784 neurons (28x28 pixels)
# hidden layer with 128 neurons
# output layer with 10 neurons (for digits 0-9)
# Brian2 synapse strings shared by both plastic projections (input -> hidden,
# hidden -> output). Pre/post traces decay with a 20 ms time constant and w is
# clipped to [0, 1] on every update.
_STDP_MODEL = '''
w : 1
dApre/dt = -Apre / (20*ms) : 1 (event-driven)
dApost/dt = -Apost / (20*ms) : 1 (event-driven)
'''
# NOTE(review): Apost is incremented by +0.01 exactly like Apre, so both
# pre-before-post and post-before-pre pairings increase w (potentiation only,
# no depression). Classic STDP uses a negative post-trace increment; confirm
# this Hebbian-only rule is intended.
_STDP_ON_PRE = '''
v_post += w
Apre += 0.01
w = clip(w + Apost, 0, 1)
'''
_STDP_ON_POST = '''
Apost += 0.01
w = clip(w + Apre, 0, 1)
'''


def _stdp_synapses(source, target):
    """Create all-to-all plastic synapses source -> target with initial w = 0.2 + 0.1*rand()."""
    syn = Synapses(source, target, _STDP_MODEL, on_pre=_STDP_ON_PRE, on_post=_STDP_ON_POST)
    syn.connect()
    syn.w = '0.2 + 0.1*rand()'
    return syn


def build_snn(input_spikes, duration):
    """Build a fresh input -> 128 hidden -> 10 output LIF network with STDP, run it, return output spikes.

    Args:
        input_spikes: 2-D numpy array of 0/1 spikes, one row per input neuron
            (the notebook passes 784 rows = 28x28 pixels). Column k is
            delivered as a spike at k ms.
        duration: Brian2 time quantity to simulate, e.g. ``100*ms``.

    Returns:
        numpy array of shape (10, ceil(duration / ms)) holding 1.0 where an
        output neuron spiked at least once in that 1 ms bin, else 0.0.

    Notes:
        start_scope() drops previously created Brian2 objects and the weights
        are re-randomized on every call, so STDP weight changes made while
        simulating one input do NOT carry over to the next call.
    """
    start_scope()

    # Rows are input neurons, columns are 1 ms time bins.
    n_input = input_spikes.shape[0]
    indices, times = np.where(input_spikes == 1)
    input_group = SpikeGeneratorGroup(n_input, indices, times * ms)

    # Leaky integrator (tau = 10 ms), threshold 1, reset to 0.
    eqs = 'dv/dt = -v / (10*ms) : 1'
    hidden = NeuronGroup(128, eqs, threshold='v>1', reset='v=0', method='exact')
    output = NeuronGroup(10, eqs, threshold='v>1', reset='v=0', method='exact')
    hidden.v = 0
    output.v = 0

    # Keep the synapses as locals: Brian2's magic run() collects the Brian
    # objects visible in the frame that calls it.
    S1 = _stdp_synapses(input_group, hidden)
    S2 = _stdp_synapses(hidden, output)

    mon = SpikeMonitor(output)
    run(duration)

    # Bin output spikes into 1 ms columns. ceil (with a small tolerance for
    # float round-off in duration/ms) guarantees every spike time t < duration
    # lands inside the matrix, where int() truncation could drop the last bin.
    n_bins = int(np.ceil(float(duration / ms) - 1e-9))
    spike_mat = np.zeros((10, n_bins))
    neuron_idx = np.asarray(mon.i[:], dtype=int)
    bin_idx = np.asarray(mon.t[:] / ms).astype(int)
    spike_mat[neuron_idx, bin_idx] = 1
    return spike_mat


In [None]:
# Run the SNN on a few shuffled MNIST mini-batches and record entropy metrics
# of the output-layer spike trains, one log entry per image.
import itertools

import brian2

SEED = 0
BATCH_SIZE = 8
N_BATCHES = 10      # short demo run only
DURATION_MS = 100   # simulated time per image; one encoded time bin = 1 ms
VERBOSE = False     # True prints every output spike matrix (very long output)

# Seed every RNG involved so Restart & Run All reproduces the results:
# torch (Poisson encoding, DataLoader shuffling), numpy, and Brian2 (rand() in
# the weight initialization).
torch.manual_seed(SEED)
np.random.seed(SEED)
brian2.seed(SEED)

entropy_log = []

train_loader = torch.utils.data.DataLoader(mnist_train, batch_size=BATCH_SIZE, shuffle=True)

# islice stops after N_BATCHES without loading an extra batch.
batches = itertools.islice(train_loader, N_BATCHES)
for batch_idx, (images, labels) in tqdm(enumerate(batches), total=N_BATCHES):
    # [B, 1, 28, 28] -> [B, 1, 28, 28, T] -> [B, 784, T]. B is taken from the
    # data so a short final batch is handled as well.
    spike_batch = poisson_encode(images, duration=DURATION_MS)
    spike_batch = spike_batch.reshape(images.shape[0], -1, DURATION_MS).numpy()

    # Simulate each image of the batch in its own freshly built network.
    for i, input_spikes in enumerate(spike_batch):
        output_spikes = build_snn(input_spikes, duration=DURATION_MS * ms)
        if VERBOSE:
            print(f'Batch {batch_idx}, Image {i}, Output Spikes:\n{output_spikes}')

        # Per-neuron metrics, averaged over the output neurons.
        H = np.mean(shannon_entropy(output_spikes))
        Hr = np.mean(entropy_rate(output_spikes))
        # X == Y here, so H(X|Y) = 0 and I(X;Y) = H(X) by definition: a sanity
        # check of the estimators, not a measurement of information transfer.
        Hc = np.mean(conditional_entropy(output_spikes, output_spikes))
        MI = np.mean(mutual_information(output_spikes, output_spikes))

        entropy_log.append({
            'batch': batch_idx,
            'img': i,
            'shannon': H,
            'rate': Hr,
            'cond': Hc,
            'mi': MI
        })


In [None]:
# present the entropy results
import pandas as pd
import seaborn as sns

# Average the per-image metrics within each batch and plot one line per metric.
metric_cols = ['shannon', 'rate', 'cond', 'mi']
df = pd.DataFrame(entropy_log)
df_avg = df.groupby('batch').mean()

fig, ax = plt.subplots(figsize=(10, 5))
sns.lineplot(data=df_avg[metric_cols], ax=ax)
ax.set_title("Entropy Metrics Across Batches")
ax.set_xlabel("Batch")
ax.set_ylabel("Entropy Value")
ax.grid()
plt.show()


In [None]:
import brian2
import numpy

# Record the library versions this notebook was executed with.
for label, module in (("Brian2", brian2), ("NumPy", numpy)):
    print(f"{label}:", module.__version__)