In [None]:
# Printed first, before any of the imports below are executed.
print("Starting program")
import os
from time import time as t

import matplotlib.pyplot as plt
import numpy as np
import torch
from torchvision import transforms
from tqdm import tqdm
from torch.profiler import profile, record_function, ProfilerActivity

from bindsnet.analysis.plotting import (
    plot_assignments,
    plot_input,
    plot_performance,
    plot_spikes,
    plot_voltages,
    plot_weights,
)
from bindsnet.datasets import MNIST
from bindsnet.encoding import PoissonEncoder
from bindsnet.evaluation import all_activity, assign_labels, proportion_weighting
from bindsnet.network.monitors import Monitor
from bindsnet.utils import get_square_assignments, get_square_weights
from bindsnet.learning import PostPre
from bindsnet.network import Network
from bindsnet.network.nodes import DiehlAndCookNodes, Input, LIFNodes
from bindsnet.network.topology import Connection, LocalConnection
from typing import Iterable, List, Optional, Sequence, Tuple, Union

seed = 0
n_epochs = 1
n_test = 10000
n_train = 60000
batch_size = 1
n_neurons = 100
padding = 0
time = 50
dt = 1.0
intensity = 128.0
progress_interval = 10
update_interval = 250
train = True
plot = False  # set to True to draw live voltage plots during training
gpu = True
n_classes = 10
n_workers = -1
exc = 22.5
inh = 120
theta_plus = 0.05


# Seed torch unconditionally: the CPU generator drives e.g. torch.rand for the
# initial input->excitatory weights even when training runs on the GPU.
# torch.manual_seed also seeds all CUDA devices; the explicit CUDA call is kept
# for clarity.
torch.manual_seed(seed)
if gpu and torch.cuda.is_available():
    torch.cuda.manual_seed_all(seed)
    device = torch.device("cuda")
else:
    # Fall back to the CPU and disable every GPU-specific code path below.
    device = "cpu"
    gpu = False

# Leave one core free, but never request fewer than one thread
# (os.cpu_count() may return 1, or None when it cannot be determined).
torch.set_num_threads(max(1, (os.cpu_count() or 1) - 1))
print("Running on Device = ", device)

# Determines number of workers to use
if n_workers == -1:
    n_workers = 0  # gpu * 4 * torch.cuda.device_count()

# Without training, accuracy is evaluated once over the whole test set.
if not train:
    update_interval = n_test

# Side length of the smallest square grid that holds all neurons (for plots).
n_sqrt = int(np.ceil(np.sqrt(n_neurons)))
start_intensity = intensity

class DiehlAndCook2015Standard(Network):
    # language=rst
    """
    Implements the spiking neural network architecture from `(Diehl & Cook 2015)
    <https://www.frontiersin.org/articles/10.3389/fncom.2015.00099/full>`_.
    Standard learning algorithm: ``PostPre`` learning on the input-to-excitatory
    connection, with configurable excitatory and inhibitory firing thresholds.

    Layers: ``"X"`` (input), ``"Ae"`` (excitatory ``DiehlAndCookNodes``) and
    ``"Ai"`` (inhibitory ``LIFNodes``).
    """

    def __init__(
        self,
        n_inpt: int,
        n_neurons: int = 100,
        exc: float = 22.5,
        inh: float = 17.5,
        dt: float = 1.0,
        nu: Optional[Union[float, Sequence[float]]] = (1e-4, 1e-2),
        reduction: Optional[callable] = None,
        wmin: float = 0.0,
        wmax: float = 1.0,
        norm: float = 78.4,
        theta_plus: float = 0.05,
        tc_theta_decay: float = 1e7,
        inpt_shape: Optional[Iterable[int]] = None,
        inh_thresh: float = -40.0,
        exc_thresh: float = -52.0,
    ) -> None:
        # language=rst
        """
        Constructor for class ``DiehlAndCook2015Standard``.

        :param n_inpt: Number of input neurons. Matches the 1D size of the input data.
        :param n_neurons: Number of excitatory, inhibitory neurons.
        :param exc: Strength of synapse weights from excitatory to inhibitory layer.
        :param inh: Strength of synapse weights from inhibitory to excitatory layer.
        :param dt: Simulation time step.
        :param nu: Single or pair of learning rates for pre- and post-synaptic events,
            respectively.
        :param reduction: Method for reducing parameter updates along the minibatch
            dimension.
        :param wmin: Minimum allowed weight on input to excitatory synapses.
        :param wmax: Maximum allowed weight on input to excitatory synapses.
        :param norm: Input to excitatory layer connection weights normalization
            constant.
        :param theta_plus: On-spike increment of ``DiehlAndCookNodes`` membrane
            threshold potential.
        :param tc_theta_decay: Time constant of ``DiehlAndCookNodes`` threshold
            potential decay.
        :param inpt_shape: The dimensionality of the input layer.
        :param inh_thresh: Firing threshold passed to the inhibitory ``LIFNodes``
            layer.
        :param exc_thresh: Firing threshold passed to the excitatory
            ``DiehlAndCookNodes`` layer.
        """
        super().__init__(dt=dt)

        self.n_inpt = n_inpt
        self.inpt_shape = inpt_shape
        self.n_neurons = n_neurons
        self.exc = exc
        self.inh = inh
        self.dt = dt

        # Layers
        input_layer = Input(
            n=self.n_inpt, shape=self.inpt_shape, traces=True, tc_trace=20.0
        )
        exc_layer = DiehlAndCookNodes(
            n=self.n_neurons,
            traces=True,
            rest=-65.0,
            reset=-60.0,
            thresh=exc_thresh,
            refrac=5,
            tc_decay=100.0,
            tc_trace=20.0,
            theta_plus=theta_plus,
            tc_theta_decay=tc_theta_decay,
        )
        inh_layer = LIFNodes(
            n=self.n_neurons,
            traces=False,
            rest=-60.0,
            reset=-45.0,
            thresh=inh_thresh,
            tc_decay=10.0,
            refrac=2,
            tc_trace=20.0,
        )

        # Connections
        # Plastic input -> excitatory weights, uniformly initialised in [0, 0.3).
        w = 0.3 * torch.rand(self.n_inpt, self.n_neurons)
        input_exc_conn = Connection(
            source=input_layer,
            target=exc_layer,
            w=w,
            update_rule=PostPre,
            nu=nu,
            reduction=reduction,
            wmin=wmin,
            wmax=wmax,
            norm=norm,
        )
        # One-to-one excitatory -> inhibitory connection of strength ``exc``.
        w = self.exc * torch.eye(self.n_neurons)
        exc_inh_conn = Connection(
            source=exc_layer, target=inh_layer, w=w, wmin=0, wmax=self.exc
        )
        # Each inhibitory neuron inhibits every excitatory neuron except its own
        # partner (zero diagonal), with weight ``-inh``.
        w = -self.inh * (
            torch.ones(self.n_neurons, self.n_neurons) - torch.eye(self.n_neurons)
        )
        inh_exc_conn = Connection(
            source=inh_layer, target=exc_layer, w=w, wmin=-self.inh, wmax=0
        )

        # Add to network
        self.add_layer(input_layer, name="X")
        self.add_layer(exc_layer, name="Ae")
        self.add_layer(inh_layer, name="Ai")
        self.add_connection(input_exc_conn, source="X", target="Ae")
        self.add_connection(exc_inh_conn, source="Ae", target="Ai")
        self.add_connection(inh_exc_conn, source="Ai", target="Ae")


# Build network.
network = DiehlAndCook2015Standard(
    n_inpt=784,
    # Must equal n_neurons: spike_record and the assignment tensors below are
    # sized with n_neurons.
    n_neurons=n_neurons,
    exc=exc,
    inh=inh,
    dt=dt,
    norm=78.4,
    theta_plus=theta_plus,
    inpt_shape=(1, 28, 28),
)

# Directs network to GPU
if gpu:
    network.to("cuda")

# Load MNIST data.
train_dataset = MNIST(
    PoissonEncoder(time=time, dt=dt),
    None,
    "cluster/home/thombruf/MNIST",
    download=True,
    train=True,
    transform=transforms.Compose(
        [transforms.ToTensor(), transforms.Lambda(lambda x: x * intensity)]
    ),
)

# Excitatory-layer spikes of the last `update_interval` samples:
# (samples, time steps, neurons).
spike_record = torch.zeros((update_interval, int(time / dt), n_neurons), device=device)

# Neuron assignments and spike proportions.
n_classes = 10
assignments = -torch.ones(n_neurons, device=device)
proportions = torch.zeros((n_neurons, n_classes), device=device)
rates = torch.zeros((n_neurons, n_classes), device=device)

# Sequence of accuracy estimates (percent, one entry per update interval).
accuracy = {"all": [], "proportion": []}

# Voltage recording for excitatory and inhibitory layers.
exc_voltage_monitor = Monitor(
    network.layers["Ae"], ["v"], time=int(time / dt), device=device
)
inh_voltage_monitor = Monitor(
    network.layers["Ai"], ["v"], time=int(time / dt), device=device
)
network.add_monitor(exc_voltage_monitor, name="exc_voltage")
network.add_monitor(inh_voltage_monitor, name="inh_voltage")

# Set up monitors for spikes and voltages
spikes = {}
for layer in set(network.layers):
    spikes[layer] = Monitor(
        network.layers[layer], state_vars=["s"], time=int(time / dt), device=device
    )
    network.add_monitor(spikes[layer], name="%s_spikes" % layer)

voltages = {}
for layer in set(network.layers) - {"X"}:
    voltages[layer] = Monitor(
        network.layers[layer], state_vars=["v"], time=int(time / dt), device=device
    )
    network.add_monitor(voltages[layer], name="%s_voltages" % layer)

# Plot handles, created lazily by the plotting helpers.
inpt_ims, inpt_axes = None, None
spike_ims, spike_axes = None, None
weights_im = None
assigns_im = None
perf_ax = None
voltage_axes, voltage_ims = None, None

# Train the network.
print("\nBegin training.\n")

start = t()
for epoch in range(n_epochs):
    labels = []

    if epoch % progress_interval == 0:
        print("Progress: %d / %d (%.4f seconds)" % (epoch, n_epochs, t() - start))
        start = t()

    # Create a dataloader to iterate and batch data
    dataloader = torch.utils.data.DataLoader(
        train_dataset, batch_size=1, shuffle=True, num_workers=n_workers, pin_memory=gpu
    )

    for step, batch in enumerate(dataloader):
        # Train on exactly n_train samples (steps 0 .. n_train - 1).
        if step >= n_train:
            break
        # Get next input sample.
        inputs = {"X": batch["encoded_image"].view(int(time / dt), 1, 1, 28, 28)}
        if gpu:
            inputs = {k: v.cuda() for k, v in inputs.items()}

        if step % update_interval == 0 and step > 0:
            # Convert the array of labels into a tensor
            label_tensor = torch.tensor(labels, device=device)

            # Get network predictions.
            all_activity_pred = all_activity(
                spikes=spike_record, assignments=assignments, n_labels=n_classes
            )
            proportion_pred = proportion_weighting(
                spikes=spike_record,
                assignments=assignments,
                proportions=proportions,
                n_labels=n_classes,
            )

            # Compute network accuracy according to available classification strategies.
            accuracy["all"].append(
                100
                * torch.sum(label_tensor.long() == all_activity_pred).item()
                / len(label_tensor)
            )

            accuracy["proportion"].append(
                100
                * torch.sum(label_tensor.long() == proportion_pred).item()
                / len(label_tensor)
            )

            print(
                "\nAll activity accuracy: %.2f (last), %.2f (average), %.2f (best)"
                % (
                    accuracy["all"][-1],
                    np.mean(accuracy["all"]),
                    np.max(accuracy["all"]),
                )
            )
            print(
                "Proportion weighting accuracy: %.2f (last), %.2f (average), %.2f"
                " (best)\n"
                % (
                    accuracy["proportion"][-1],
                    np.mean(accuracy["proportion"]),
                    np.max(accuracy["proportion"]),
                )
            )

            # Assign labels to excitatory layer neurons.
            assignments, proportions, rates = assign_labels(
                spikes=spike_record,
                labels=label_tensor,
                n_labels=n_classes,
                rates=rates,
            )

            labels = []

        labels.append(batch["label"])

        # Run the network on the input.
        network.run(inputs=inputs, time=time)

        # Get voltage recording.
        exc_voltages = exc_voltage_monitor.get("v")
        inh_voltages = inh_voltage_monitor.get("v")

        # Add to spikes recording.
        spike_record[step % update_interval] = spikes["Ae"].get("s").squeeze()

        # Optionally plot various simulation information.
        if plot:
            # Separate name so the `voltages` dict of Monitors above survives.
            voltage_plot_data = {"Ae": exc_voltages, "Ai": inh_voltages}
            voltage_ims, voltage_axes = plot_voltages(
                voltage_plot_data, ims=voltage_ims, axes=voltage_axes, plot_type="line"
            )

            plt.pause(1e-8)

        network.reset_state_variables()  # Reset state variables.

print("Progress: %d / %d (%.4f seconds)" % (epoch + 1, n_epochs, t() - start))
print("Training complete.\n")



In [None]:
# Show the training dataset's string representation.
print(str(train_dataset))

In [None]:
from pathlib import Path

# Save the trained network into the user's home directory via Network.save.
network.save(str(Path.home() / "Standard_network_Diehl_and_Cook_100_neurons.pt"))

In [None]:

# Dump the learned weight matrix of each of the three connections,
# each preceded by a blank separator and a short description.
print("Trained weights:")
for description, connection_key in (
    ("Weights between input and excitatory layer:", ("X", "Ae")),
    ("Weights between excitatory and inhibitory layer:", ("Ae", "Ai")),
    ("Weights between inhibitory and excitatory layer:", ("Ai", "Ae")),
):
    print("\n")
    print(description)
    print(network.connections[connection_key].w)


In [None]:
# Export each learned weight matrix as plain text, one matrix row per line.
# Tensors are detached and moved to the CPU first: .numpy() rejects GPU tensors,
# and the network is moved to the GPU when `gpu` is True.
np.savetxt('conn_X_Ae.txt', network.connections[("X", "Ae")].w.detach().cpu().numpy())
np.savetxt('conn_Ae_Ai.txt', network.connections[("Ae", "Ai")].w.detach().cpu().numpy())
np.savetxt('conn_Ai_Ae.txt', network.connections[("Ai", "Ae")].w.detach().cpu().numpy())

In [None]:
# Below is code to extract neuron assignments to know which neurons correspond to which label

In [None]:
# Label assigned to each excitatory neuron by the last assign_labels call.
print("Assignments: ", assignments, sep=" ", end="\n")


In [None]:
# Per-neuron, per-class spike proportions from the last assign_labels call.
print("Proportions: ", proportions, sep=" ", end="\n")


In [None]:
# Per-neuron, per-class rates returned by the last assign_labels call.
print("Rates: ", rates, sep=" ", end="\n")

In [None]:
# Print the shape of each learned weight matrix. tuple(w.shape) prints the same
# "(rows, cols)" tuple as w.numpy().shape, but also works for GPU tensors.
print("Trained weights lengths:")
print("\n")
print("Length of weights between input and excitatory layer:")
print(tuple(network.connections[("X", "Ae")].w.shape))
print("\n")
print("Length of weights between excitatory and inhibitory layer:")
print(tuple(network.connections[("Ae", "Ai")].w.shape))
print("\n")
print("Length of weights between inhibitory and excitatory layer:")
print(tuple(network.connections[("Ai", "Ae")].w.shape))

In [None]:
# Extract the excitatory layer's base threshold ("thresh") and per-neuron
# adaptive offsets ("theta") so they can be set in hardware from a file.
# Tensors are detached and moved to the CPU first: .numpy() rejects GPU
# tensors, and the network is moved to the GPU when `gpu` is True.
thresh_np = network.layers["Ae"].thresh.detach().cpu().numpy()
theta_np = network.layers["Ae"].theta.detach().cpu().numpy()
print(thresh_np)
print(thresh_np.shape)
print(theta_np)
print(theta_np.shape)
for theta in theta_np:
    print(theta)


In [None]:
# Per-neuron threshold = the "Ae" layer's base threshold + that neuron's theta.
# The base threshold is read from the layer instead of hard-coding -52.0
# (the exc_thresh default). Assumes `thresh` is a single-element tensor;
# confirm if a per-neuron threshold is ever configured.
base_thresh = network.layers["Ae"].thresh.item()
for theta in network.layers["Ae"].theta.detach().cpu().numpy():
    new_thresh = base_thresh + theta
    print(new_thresh)


In [None]:

# Evaluate the trained network on the MNIST test set, and record the encoded
# input spike trains of the first `n_input_trains_to_record` test samples.

# Load MNIST data.
test_dataset = MNIST(
    PoissonEncoder(time=time, dt=dt),
    None,
    "cluster/home/thombruf/MNIST",
    download=True,
    train=False,
    transform=transforms.Compose(
        [transforms.ToTensor(), transforms.Lambda(lambda x: x * intensity)]
    ),
)

# Voltage recording for excitatory and inhibitory layers.
exc_voltage_monitor = Monitor(
    network.layers["Ae"], ["v"], time=int(time / dt), device=device
)
inh_voltage_monitor = Monitor(
    network.layers["Ai"], ["v"], time=int(time / dt), device=device
)

exc_spikes_monitor = Monitor(
    network.layers["Ae"], state_vars=["s"], time=int(time / dt), device=device
)

network.add_monitor(exc_voltage_monitor, name="exc_voltage")
network.add_monitor(inh_voltage_monitor, name="inh_voltage")
network.add_monitor(exc_spikes_monitor, name="exc_spikes")


# Number of correctly classified test samples per classification strategy.
accuracy = {"all": 0, "proportion": 0}

# Excitatory spikes of the current sample: (1, time steps, n_neurons).
spike_record = torch.zeros((1, int(time / dt), n_neurons), device=device)
# Running per-neuron spike counts over the test set. Kept on the same device
# as spike_record so the in-place add below never mixes CPU and GPU tensors.
spikes_sum = torch.zeros((1, n_neurons), device=device)

# Test the network.
print("\nBegin testing\n")
network.train(mode=False)
start = t()

# Number of leading test samples whose input spike trains are recorded.
n_input_trains_to_record = 10000
input_spike_trains_list = []
n_evaluated = 0

for step, batch in enumerate(test_dataset):
    if step >= n_test:
        break
    # Get next input sample.
    inputs = {"X": batch["encoded_image"].view(int(time / dt), 1, 1, 28, 28)}

    if step < n_input_trains_to_record:
        spike_trains = batch["encoded_image"]  # Access the encoded spike trains
        # recorded_spikes[pixel][time_step]: one row per pixel (28*28 = 784 for
        # MNIST), one column per time step, i.e. the transpose of the
        # (time steps, pixels) view of the encoded image.
        recorded_spikes = spike_trains.reshape(spike_trains.shape[0], -1).t().tolist()
        input_spike_trains_list.append(recorded_spikes)

    if gpu:
        inputs = {name: v.cuda() for name, v in inputs.items()}

    # Run the network on the input.
    network.run(inputs=inputs, time=time)

    # Get voltage and spike recordings (kept for inspection after the loop).
    exc_voltages = exc_voltage_monitor.get("v")
    inh_voltages = inh_voltage_monitor.get("v")
    exc_spikes = exc_spikes_monitor.get("s")

    # Add to spikes recording.
    spike_record[0] = exc_spikes.squeeze()

    spikes_sum += spike_record.sum(1)

    # as_tensor accepts both plain ints and tensors as the label.
    label_tensor = torch.as_tensor(batch["label"], device=device)

    # Get network predictions.
    all_activity_pred = all_activity(
        spikes=spike_record, assignments=assignments, n_labels=n_classes
    )
    proportion_pred = proportion_weighting(
        spikes=spike_record,
        assignments=assignments,
        proportions=proportions,
        n_labels=n_classes,
    )

    # Compute network accuracy according to available classification strategies.
    accuracy["all"] += float(torch.sum(label_tensor.long() == all_activity_pred).item())
    accuracy["proportion"] += float(
        torch.sum(label_tensor.long() == proportion_pred).item()
    )
    n_evaluated += 1

    network.reset_state_variables()  # Reset state variables.

# Divide by the number of samples actually evaluated (guarding against zero),
# which equals n_test whenever the dataset has at least n_test samples.
print("\nAll activity accuracy: %.2f" % (accuracy["all"] / max(n_evaluated, 1)))
print(
    "Proportion weighting accuracy: %.2f \n"
    % (accuracy["proportion"] / max(n_evaluated, 1))
)

print("Testing time: %.4f seconds" % (t() - start))
print("Testing complete.\n")


In [None]:
# Inspect the excitatory spike recording of the last test sample: its length
# along the first axis, its full shape, and the element at index [0, 0].
print(len(exc_spikes))
print(exc_spikes.shape)
print(exc_spikes[0, 0])

I added this here so that I can analyze the last image that was sent in — useful for simply observing the behaviour of the top-level SNN.

In [None]:
def store_input_spikes_for_one_image(input_spike_trains_list, output_file_path):
    """Write recorded input spike trains to a text file.

    Despite the name, every image in ``input_spike_trains_list`` is written.
    Each image is a sequence of per-pixel spike trains, and each pixel train is
    written on its own line as ``[v0,v1,...]``. Images follow one another with
    no separator line.

    :param input_spike_trains_list: Iterable of images; each image is an
        iterable of per-pixel spike-train sequences.
    :param output_file_path: Destination path; an existing file is overwritten.

    File-system errors (``OSError``) are reported on stdout instead of raised,
    so a notebook run can continue. Any other error propagates.
    """
    try:
        with open(output_file_path, "w") as f:
            for recorded_spikes in input_spike_trains_list:
                f.writelines(
                    "[" + ",".join(map(str, pixel_spike_train)) + "]\n"
                    for pixel_spike_train in recorded_spikes
                )
    except OSError as err:
        print("Could not save input spike trains to file:", err)

In [None]:
# Write every recorded test-set input spike train to a single text file.
store_input_spikes_for_one_image(input_spike_trains_list, "input_spikes_full_dataset.txt")

In [None]:
# Print the voltage recordings of the last test sample. The testing loop fetched
# them with Monitor.get("v") before resetting the network state. The monitors
# themselves are read through .get(var) rather than by iteration/indexing;
# BindsNET's Monitor does not appear to support `for key in monitor` or
# `monitor[key]` (verify against the installed BindsNET version).
print(exc_voltages)
print(inh_voltages)

    

In [None]:
# Print the excitatory-layer voltages of the last test sample, one time step at
# a time. The loop variable is `time_step`, not `t`, so the `t` timer imported
# in the first cell (used by the training and testing cells) is not shadowed.
for time_step, voltages_at_step in enumerate(exc_voltages):
    print("Time step: ", time_step)
    print(voltages_at_step)

In [None]:
# For each excitatory neuron, print its adaptive offset theta and then the sum
# of the layer's base threshold and theta. theta is moved to the CPU first
# because .numpy() rejects GPU tensors. The base threshold is read from the layer
# (assumes `thresh` is a single-element tensor) instead of hard-coding -52.0.
base_thresh = network.layers["Ae"].thresh.item()
for theta in network.layers["Ae"].theta.detach().cpu().numpy():
    print(theta)
    print(theta + base_thresh)


In [None]:
# Show which encoder the test dataset applies to its images.
print(str(test_dataset.image_encoder))

In [None]:
# Show the `enc` attribute of the test dataset's image encoder.
print(str(test_dataset.image_encoder.enc))

In [None]:
# Inspect the encoded spike trains of the first `n_samples_to_inspect` test
# samples. For each sample this prints:
#   - the spike count of each time step that has spikes, and their total,
#   - the per-pixel spike-train matrix,
#   - for every pixel that spiked, its spike count and its first spike time.
n_samples_to_inspect = 1  # Change this number to inspect more samples
for sample_index in range(min(n_samples_to_inspect, len(test_dataset))):
    data = test_dataset[sample_index]  # Get the sample
    spike_trains = data["encoded_image"]  # Access the encoded spike trains
    n_steps = spike_trains.shape[0]
    # (time steps, pixels) view of the encoded image.
    flat_trains = spike_trains.reshape(n_steps, -1)

    # recorded_spikes[pixel][time_step]: one row per pixel, one column per step.
    recorded_spikes = flat_trains.t().tolist()

    total_sum = 0
    for step_spikes in flat_trains:
        step_count = step_spikes.sum().item()
        if step_count > 0:
            print(step_count)
            total_sum += step_count
    print(total_sum)
    print(len(recorded_spikes))
    print(len(recorded_spikes[0]) if recorded_spikes else 0)
    print(recorded_spikes)
    for index, pixel_train in enumerate(recorded_spikes):
        n_pixel_spikes = sum(pixel_train)
        if n_pixel_spikes:
            print("The number of spikes at this pixel ", index, " was: ", n_pixel_spikes)
            # argmax of a 0/1 train is the index of its FIRST spike.
            print("This spike happened at the time step: ", np.argmax(np.array(pixel_train)))

In [None]:
# Load MNIST data.
test_dataset = MNIST(
    PoissonEncoder(time=50, dt=dt),
    None,
    "cluster/home/thombruf/MNIST",
    download=True,
    train=False,
    transform=transforms.Compose(
        [transforms.ToTensor(), transforms.Lambda(lambda x: x * intensity)]
    ),
)


In [None]:
# Record the encoded spike trains of test samples 0..k (inclusive) into
# `input_spike_trains_list`, printing each label and every (pixel, time step)
# at which a spike occurs.
k = 0  # index of the last sample to record, e.g. 100 records samples 0..100
input_spike_trains_list = []
for i in range(min(k + 1, len(test_dataset))):
    data = test_dataset[i]  # Get the i-th sample
    spike_trains = data['encoded_image']  # Access the encoded spike trains
    label = data['label']
    print("Label for this image: ", label)

    n_steps = spike_trains.shape[0]
    # (time steps, pixels) view; pixels = 28*28 = 784 for MNIST.
    flat_trains = spike_trains.reshape(n_steps, -1)
    if n_steps:
        print(flat_trains[0])

    # torch.nonzero returns indices in row-major order, so spikes are reported
    # by time step, then by pixel index within a step.
    for time_step, pixel_index in torch.nonzero(flat_trains).tolist():
        print("Index in flat image with spike: ", pixel_index)
        print("Timestep: ", time_step)

    # recorded_spikes[pixel][time_step]: one row per pixel, one column per
    # time step (width follows the encoder's number of steps).
    recorded_spikes = flat_trains.t().tolist()
    input_spike_trains_list.append(recorded_spikes)

    total_spike_count_for_this_image = sum(sum(train) for train in recorded_spikes)

In [None]:
def store_input_spikes_for_one_image(input_spike_trains_list, output_file_path):
    """Write recorded input spike trains to a text file.

    Despite the name, every image in ``input_spike_trains_list`` is written.
    Each image is a sequence of per-pixel spike trains, and each pixel train is
    written on its own line as ``[v0,v1,...]``. Images follow one another with
    no separator line.

    :param input_spike_trains_list: Iterable of images; each image is an
        iterable of per-pixel spike-train sequences.
    :param output_file_path: Destination path; an existing file is overwritten.

    File-system errors (``OSError``) are reported on stdout instead of raised,
    so a notebook run can continue. Any other error propagates.
    """
    try:
        with open(output_file_path, "w") as f:
            for recorded_spikes in input_spike_trains_list:
                f.writelines(
                    "[" + ",".join(map(str, pixel_spike_train)) + "]\n"
                    for pixel_spike_train in recorded_spikes
                )
    except OSError as err:
        print("Could not save input spike trains to file:", err)

In [None]:
# Write the spike trains recorded by the sample-recording cell to a text file.
store_input_spikes_for_one_image(input_spike_trains_list, "input_spike_trains_dense.txt")

In [None]:
# Hard-coded per-neuron label assignments (100 entries, one per excitatory
# neuron). Print how many neurons carry each label, from label 9 down to 0.
new_assignments = [9, 2, 7, 6, 9, 4, 0, 5, 8, 4, 9, 8, 0, 2, 2, 8, 9, 8, 3, 5, 0, 4, 6, 9,
        0, 7, 9, 0, 3, 8, 2, 2, 6, 9, 6, 9, 8, 8, 0, 2, 6, 2, 8, 8, 3, 8, 3, 3,
        4, 9, 0, 0, 4, 0, 5, 8, 3, 0, 4, 5, 6, 0, 3, 0, 7, 6, 3, 1, 7, 0, 6, 8,
        3, 4, 4, 3, 3, 9, 5, 8, 3, 6, 1, 3, 5, 8, 1, 9, 6, 8, 7, 5, 3, 2, 0, 1,
        7, 7, 5, 8]
for digit in reversed(range(10)):
    print(new_assignments.count(digit))
