diff --git a/.gitignore b/.gitignore index 02d41b4b..e4ab00e4 100644 --- a/.gitignore +++ b/.gitignore @@ -22,3 +22,6 @@ logs/* .pytest_cache/* .vscode/* data/* +/examples/mnist/*.pt +/examples/mnist/draft* + diff --git a/bindsnet/analysis/plotting.py b/bindsnet/analysis/plotting.py index 89007b08..d5e1d66c 100644 --- a/bindsnet/analysis/plotting.py +++ b/bindsnet/analysis/plotting.py @@ -842,3 +842,230 @@ def plot_voltages( plt.tight_layout() return ims, axes + +# I added this plot_traces which is completely based on voltage_plot, just changed the word voltage +def plot_traces( + traces: Dict[str, torch.Tensor], + ims: Optional[List[AxesImage]] = None, + axes: Optional[List[Axes]] = None, + time: Tuple[int, int] = None, + n_neurons: Optional[Dict[str, Tuple[int, int]]] = None, + cmap: Optional[str] = "jet", + plot_type: str = "color", + thresholds: Dict[str, torch.Tensor] = None, + figsize: Tuple[float, float] = (8.0, 4.5), +) -> Tuple[List[AxesImage], List[Axes]]: + # language=rst + """ + Plot traces for any group(s) of neurons. + + :param traces: Contains trace data by neuron layers. + :param ims: Used for re-drawing the plots. + :param axes: Used for re-drawing the plots. + :param time: Plot traces of neurons in given time range. Default is entire + simulation time. + :param n_neurons: Plot traces of neurons in given range of neurons. Default is all + neurons. + :param cmap: Matplotlib colormap to use. + :param figsize: Horizontal, vertical figure size in inches. + :param plot_type: The way how to draw graph. 'color' for pcolormesh, 'line' for + curved lines. + :param thresholds: Thresholds of the neurons in each layer. + :return: ``ims, axes``: Used for re-drawing the plots. + """ + n_subplots = len(traces.keys()) + + # for key in traces.keys(): + # traces[key] = traces[key].view(-1, traces[key].size(-1)) + traces = {k: v.view(v.size(0), -1) for (k, v) in traces.items()} + + if time is None: + for key in traces.keys(): + time = (0, traces[key].size(0)) + break + + if n_neurons is None: + n_neurons = {} + + for key, val in traces.items(): + if key not in n_neurons.keys(): + n_neurons[key] = (0, val.size(1)) + + if not ims: + fig, axes = plt.subplots(n_subplots, 1, figsize=figsize) + ims = [] + if n_subplots == 1: # Plotting only one image + for v in traces.items(): + if plot_type == "line": + ims.append( + axes.plot( + v[1] + .detach() + .clone() + .cpu() + .numpy()[ + time[0] : time[1], + n_neurons[v[0]][0] : n_neurons[v[0]][1], + ] + ) + ) + + if thresholds is not None and thresholds[v[0]].size() == torch.Size( + [] + ): + ims.append( + axes.axhline( + y=thresholds[v[0]].item(), c="r", linestyle="--" + ) + ) + else: + ims.append( + axes.pcolormesh( + v[1] + .cpu() + .numpy()[ + time[0] : time[1], + n_neurons[v[0]][0] : n_neurons[v[0]][1], + ] + .T, + cmap=cmap, + ) + ) + + args = (v[0], n_neurons[v[0]][0], n_neurons[v[0]][1], time[0], time[1]) + plt.title("%s traces for neurons (%d - %d) from t = %d to %d " % args) + plt.xlabel("Time (ms)") + + if plot_type == "line": + plt.ylabel("trace") + else: + plt.ylabel("Neuron index") + + axes.set_aspect("auto") + + else: # Plot each layer at a time + for i, v in enumerate(traces.items()): + if plot_type == "line": + ims.append( + axes[i].plot( + v[1] + .cpu() + .numpy()[ + time[0] : time[1], + n_neurons[v[0]][0] : n_neurons[v[0]][1], + ] + ) + ) + if thresholds is not None and thresholds[v[0]].size() == torch.Size( + [] + ): + ims.append( + axes[i].axhline( + y=thresholds[v[0]].item(), c="r", linestyle="--" + ) + ) + else: + ims.append( + 
axes[i].matshow( + v[1] + .cpu() + .numpy()[ + time[0] : time[1], + n_neurons[v[0]][0] : n_neurons[v[0]][1], + ] + .T, + cmap=cmap, + ) + ) + args = (v[0], n_neurons[v[0]][0], n_neurons[v[0]][1], time[0], time[1]) + axes[i].set_title( + "%s traces for neurons (%d - %d) from t = %d to %d " % args + ) + + for ax in axes: + ax.set_aspect("auto") + + if plot_type == "color": + plt.setp(axes, xlabel="Simulation time", ylabel="Neuron index") + elif plot_type == "line": + plt.setp(axes, xlabel="Simulation time", ylabel="trace") + + plt.tight_layout() + + else: + # Plotting figure given + if n_subplots == 1: # Plotting only one image + for v in traces.items(): + axes.clear() + if plot_type == "line": + axes.plot( + v[1] + .cpu() + .numpy()[ + time[0] : time[1], n_neurons[v[0]][0] : n_neurons[v[0]][1] + ] + ) + if thresholds is not None and thresholds[v[0]].size() == torch.Size( + [] + ): + axes.axhline(y=thresholds[v[0]].item(), c="r", linestyle="--") + else: + axes.matshow( + v[1] + .cpu() + .numpy()[ + time[0] : time[1], n_neurons[v[0]][0] : n_neurons[v[0]][1] + ] + .T, + cmap=cmap, + ) + args = (v[0], n_neurons[v[0]][0], n_neurons[v[0]][1], time[0], time[1]) + axes.set_title( + "%s traces for neurons (%d - %d) from t = %d to %d " % args + ) + axes.set_aspect("auto") + + else: + # Plot each layer at a time + for i, v in enumerate(traces.items()): + axes[i].clear() + if plot_type == "line": + axes[i].plot( + v[1] + .cpu() + .numpy()[ + time[0] : time[1], n_neurons[v[0]][0] : n_neurons[v[0]][1] + ] + ) + if thresholds is not None and thresholds[v[0]].size() == torch.Size( + [] + ): + axes[i].axhline( + y=thresholds[v[0]].item(), c="r", linestyle="--" + ) + else: + axes[i].matshow( + v[1] + .cpu() + .numpy()[ + time[0] : time[1], n_neurons[v[0]][0] : n_neurons[v[0]][1] + ] + .T, + cmap=cmap, + ) + args = (v[0], n_neurons[v[0]][0], n_neurons[v[0]][1], time[0], time[1]) + axes[i].set_title( + "%s traces for neurons (%d - %d) from t = %d to %d " % args + ) + + for ax in axes: + ax.set_aspect("auto") + + if plot_type == "color": + plt.setp(axes, xlabel="Simulation time", ylabel="Neuron index") + elif plot_type == "line": + plt.setp(axes, xlabel="Simulation time", ylabel="trace") + + plt.tight_layout() + + return ims, axes diff --git a/bindsnet/evaluation/evaluation.py b/bindsnet/evaluation/evaluation.py index 5271d762..83f15a70 100644 --- a/bindsnet/evaluation/evaluation.py +++ b/bindsnet/evaluation/evaluation.py @@ -118,7 +118,7 @@ def all_activity( n_assigns = torch.sum(assignments == i).float() if n_assigns > 0: - # Get indices of samples with this label. + # Get indices of samples with this label. # correcting : get the number of neurons with this label indices = torch.nonzero(assignments == i).view(-1) # Compute layer-wise firing rate for this label. diff --git a/bindsnet/learning/__init__.py b/bindsnet/learning/__init__.py index 5a733783..e9fb5a1a 100644 --- a/bindsnet/learning/__init__.py +++ b/bindsnet/learning/__init__.py @@ -5,6 +5,7 @@ LearningRule, NoOp, PostPre, + Bi_sigmoid, Rmax, WeightDependentPostPre, ) @@ -13,6 +14,7 @@ "LearningRule", "NoOp", "PostPre", + "Bi_sigmoid", "WeightDependentPostPre", "Hebbian", "MSTDP", diff --git a/bindsnet/learning/learning.py b/bindsnet/learning/learning.py index e2c171cd..ac6354bc 100644 --- a/bindsnet/learning/learning.py +++ b/bindsnet/learning/learning.py @@ -389,6 +389,11 @@ def _connection_update(self, **kwargs) -> None: """ Post-pre learning rule for ``Connection`` subclass of ``AbstractConnection`` class. 
+
+        self.source.s : pre-synaptic spikes, e.g. a 28*28 array of 0s and 1s; source_s is this array flattened to a 1D vector (784*1).
+        self.target.x : post-synaptic spike traces, e.g. 100 values (~1e-5); target_x is the reshaped version used in the update.
+        source: pre-synaptic (first) layer; target: post-synaptic (second) layer.
+        s: spike occurrences (0 or 1) for each neuron; x: exponentially decaying spike trace.
         """
         batch_size = self.source.batch_size
@@ -549,6 +554,80 @@ def _conv3d_connection_update(self, **kwargs) -> None:
         super().update()
 
+#===================================================
+class Bi_sigmoid(LearningRule):
+    # language=rst
+    """
+    Bi_sigmoid STDP rule involving only post-synaptic spiking activity. The weight
+    update is positive if the post-synaptic spike occurs shortly after the
+    pre-synaptic spike, and negative otherwise.
+    """
+
+    def __init__(
+        self,
+        connection: AbstractConnection,
+        nu: Optional[Union[float, Sequence[float], Sequence[torch.Tensor]]] = None,
+        reduction: Optional[callable] = None,
+        weight_decay: float = 0.0,
+        **kwargs,
+    ) -> None:
+        # language=rst
+        """
+        Constructor for ``Bi_sigmoid`` learning rule.
+
+        :param connection: An ``AbstractConnection`` object whose weights the
+            ``Bi_sigmoid`` learning rule will modify.
+        :param nu: Single or pair of learning rates for pre- and post-synaptic events.
+            It also accepts a pair of tensors to individualize the learning rate of
+            each neuron; in that case, their shape should match that of the connection
+            weights.
+        :param reduction: Method for reducing parameter updates along the batch
+            dimension.
+        :param weight_decay: Coefficient controlling rate of decay of the weights each iteration.
+        """
+        super().__init__(
+            connection=connection,
+            nu=nu,
+            reduction=reduction,
+            weight_decay=weight_decay,
+            **kwargs,
+        )
+
+        assert (
+            self.source.traces and self.target.traces
+        ), "Both pre- and post-synaptic nodes must record spike traces."
+
+        if isinstance(connection, (Connection, LocalConnection)):  # added: Bi_sigmoid only supports these two connection types
+            self.update = self._connection_update  # overrides the update rule defined in the base class
+        else:
+            raise NotImplementedError(
+                "This learning rule is not supported for this Connection type."
+            )
+
+    def _connection_update(self, **kwargs) -> None:
+        # language=rst
+        """
+        Bi_sigmoid learning rule for ``Connection`` subclass of ``AbstractConnection``
+        class.
+
+        self.source.s : pre-synaptic spikes, e.g. a 28*28 array of 0s and 1s.
+        self.source.x2 : bi-sigmoid trace of the pre-synaptic neurons (values in [-1, 1]); source_x2 is this trace flattened to a 1D vector (784*1).
+        self.target.s : post-synaptic spikes, e.g. 100 values of 0s and 1s; target_s is the flattened version.
+        source: pre-synaptic (first) layer; target: post-synaptic (second) layer.
+        s: spike occurrences (0 or 1) for each neuron; x2: bi-sigmoid trace.
+        This rule uses only the post-synaptic spikes (target_s) and the bi-sigmoid
+        trace of the pre-synaptic neurons (source_x2).
+        """
+        batch_size = self.source.batch_size
+
+        # Post-synaptic update.
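+        # The update is driven by post-synaptic spikes only: for every sample in the
+        # batch, torch.bmm(source_x2, target_s) forms the outer product of the
+        # pre-synaptic bi-sigmoid traces (n_source x 1) with the post-synaptic spikes
+        # scaled by nu[1] (1 x n_target). A weight w[i, j] therefore increases when
+        # post-synaptic neuron j spikes while the trace x2 of pre-synaptic neuron i is
+        # still positive (shortly after its spike) and decreases once x2 has gone
+        # negative (long pre-to-post delay).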
+        if self.nu[1].any():
+            target_s = (self.target.s.view(batch_size, -1).unsqueeze(1).float() * self.nu[1])  # post-synaptic spikes scaled by nu[1], e.g. 100 values of 0s and 1s
+            source_x2 = self.source.x2.view(batch_size, -1).unsqueeze(2)  # pre-synaptic bi-sigmoid traces, e.g. 784 values in [-1, 1]
+            self.connection.w += self.reduction(torch.bmm(source_x2, target_s), dim=0)
+            del source_x2, target_s
+
+        super().update()
+
+
+#===================================================
 class WeightDependentPostPre(LearningRule):
     # language=rst
diff --git a/bindsnet/models/__init__.py b/bindsnet/models/__init__.py
index b57121bc..8a13b8d8 100644
--- a/bindsnet/models/__init__.py
+++ b/bindsnet/models/__init__.py
@@ -1,5 +1,6 @@
 from bindsnet.models.models import (
     DiehlAndCook2015,
+    Salah_model,
     DiehlAndCook2015v2,
     IncreasingInhibitionNetwork,
     LocallyConnectedNetwork,
@@ -9,6 +10,7 @@
 __all__ = [
     "TwoLayerNetwork",
     "DiehlAndCook2015v2",
+    "Salah_model",
     "DiehlAndCook2015",
     "IncreasingInhibitionNetwork",
     "LocallyConnectedNetwork",
diff --git a/bindsnet/models/models.py b/bindsnet/models/models.py
index 50f27872..6a523dfc 100644
--- a/bindsnet/models/models.py
+++ b/bindsnet/models/models.py
@@ -5,7 +5,7 @@
 from scipy.spatial.distance import euclidean
 from torch.nn.modules.utils import _pair
 
-from bindsnet.learning import PostPre
+from bindsnet.learning import PostPre, Bi_sigmoid
 from bindsnet.network import Network
 from bindsnet.network.nodes import DiehlAndCookNodes, Input, LIFNodes
 from bindsnet.network.topology import Connection, LocalConnection
@@ -199,6 +199,128 @@ def __init__(
         self.add_connection(exc_inh_conn, source="Ae", target="Ai")
         self.add_connection(inh_exc_conn, source="Ai", target="Ae")
 
+#=============================================================================
+
+class Salah_model(Network):
+    # language=rst
+    """
+    Implements the same network architecture as `(Diehl & Cook 2015)`, with input,
+    excitatory and inhibitory neuron layers, but uses the hardware-friendly
+    bi-sigmoid learning rule (``Bi_sigmoid``). The bi-sigmoid rule describes the
+    learning behavior of MTJ-based synapses.
+    More details in DADDINOUOU & VATAJELU (2024).
+    """
+
+    def __init__(
+        self,
+        n_inpt: int,
+        n_neurons: int = 100,
+        exc: float = 22.5,
+        inh: float = 17.5,
+        dt: float = 1.0,
+        nu: Optional[Union[float, Sequence[float]]] = (1e-4, 1e-2),
+        reduction: Optional[callable] = None,
+        wmin: float = 0.0,
+        wmax: float = 1.0,
+        norm: float = 78.4,
+        theta_plus: float = 0.05,
+        tc_theta_decay: float = 1e7,
+        inpt_shape: Optional[Iterable[int]] = None,
+    ) -> None:
+        # language=rst
+        """
+        Constructor for class ``Salah_model``.
+
+        :param n_inpt: Number of input neurons. Matches the 1D size of the input data.
+        :param n_neurons: Number of excitatory, inhibitory neurons.
+        :param exc: Strength of synapse weights from excitatory to inhibitory layer.
+        :param inh: Strength of synapse weights from inhibitory to excitatory layer.
+        :param dt: Simulation time step.
+        :param nu: Single or pair of learning rates for pre- and post-synaptic events,
+            respectively.
+        :param reduction: Method for reducing parameter updates along the minibatch
+            dimension.
+        :param wmin: Minimum allowed weight on input to excitatory synapses.
+        :param wmax: Maximum allowed weight on input to excitatory synapses.
+        :param norm: Input to excitatory layer connection weights normalization
+            constant.
+        :param theta_plus: On-spike increment of ``DiehlAndCookNodes`` membrane
+            threshold potential.
+ :param tc_theta_decay: Time constant of ``DiehlAndCookNodes`` threshold + potential decay. + :param inpt_shape: The dimensionality of the input layer. + """ + super().__init__(dt=dt) + + self.n_inpt = n_inpt + self.inpt_shape = inpt_shape + self.n_neurons = n_neurons + self.exc = exc + self.inh = inh + self.dt = dt + + # Layers + input_layer = Input( + n=self.n_inpt, shape=self.inpt_shape, traces=True, tc_trace=20.0 + ) + exc_layer = DiehlAndCookNodes( + n=self.n_neurons, + traces=True, + rest= -65.0, + reset= -60.0, + thresh= -52.0, + refrac=5, + tc_decay=100.0, + tc_trace=20.0, + theta_plus=theta_plus, + tc_theta_decay=tc_theta_decay, + ) + inh_layer = LIFNodes( + n=self.n_neurons, + traces=False, + rest=-60.0, + reset=-45.0, + thresh=-40.0, + refrac=2, + tc_decay=10.0, + tc_trace=20.0, + ) + + # Connections + w = 0.3 * torch.rand(self.n_inpt, self.n_neurons) + input_exc_conn = Connection( # Modied_PostPre is only working for Connection and LocalConnection) + source=input_layer, + target=exc_layer, + w=w, + update_rule=Bi_sigmoid, + nu=nu, + reduction=reduction, + wmin=wmin, + wmax=wmax, + norm=norm, + ) + w = self.exc * torch.diag(torch.ones(self.n_neurons)) + exc_inh_conn = Connection( + source=exc_layer, target=inh_layer, w=w, wmin=0, wmax=self.exc + ) + w = -self.inh * ( + torch.ones(self.n_neurons, self.n_neurons) + - torch.diag(torch.ones(self.n_neurons)) + ) + inh_exc_conn = Connection( + source=inh_layer, target=exc_layer, w=w, wmin=-self.inh, wmax=0 + ) + + # Add to network + self.add_layer(input_layer, name="X") + self.add_layer(exc_layer, name="Ae") + self.add_layer(inh_layer, name="Ai") + self.add_connection(input_exc_conn, source="X", target="Ae") + self.add_connection(exc_inh_conn, source="Ae", target="Ai") + self.add_connection(inh_exc_conn, source="Ai", target="Ae") + +#============================================================================= class DiehlAndCook2015v2(Network): # language=rst diff --git a/bindsnet/network/nodes.py b/bindsnet/network/nodes.py index cf8b709c..16026d3d 100644 --- a/bindsnet/network/nodes.py +++ b/bindsnet/network/nodes.py @@ -2,7 +2,6 @@ from functools import reduce from operator import mul from typing import Iterable, Optional, Union - import torch @@ -67,6 +66,8 @@ def __init__( if self.traces: self.register_buffer("x", torch.Tensor()) # Firing traces. + self.register_buffer("t_", torch.Tensor()) # added delta_t is registred + self.register_buffer("x2", torch.Tensor()) # added x2 bisimoid trace is registred self.register_buffer( "tc_trace", torch.tensor(tc_trace) ) # Time constant of spike trace decay. @@ -93,19 +94,29 @@ def forward(self, x: torch.Tensor) -> None: :param x: Inputs to the layer. """ + def bisigmoid_trace(t_): #added + A = 7.95E+05; k0 = 0.474723045; x0 = 20.77893753; k1 = 0.757072031; x1 = 48.93860322 + result = (-A / (1 + torch.exp(-k0*(t_-x0))) - A / (1 + torch.exp(-k1*(t_-x1))) + A) / A + result.masked_fill_(t_ > 60, torch.tensor(0)) + return result + if self.traces: # Decay and set spike traces. self.x *= self.trace_decay + + self.t_ += torch.tensor(self.dt) #added + self.x2 = bisigmoid_trace(self.t_) #added if self.traces_additive: self.x += self.trace_scale * self.s.float() else: self.x.masked_fill_(self.s.bool(), self.trace_scale) + self.t_.masked_fill_(self.s.bool(), torch.tensor(0)) #added, reset t_ to 0 after spiking if self.sum_input: # Add current input to running sum. 
self.summed += x.float() - + def reset_state_variables(self) -> None: # language=rst """ @@ -115,6 +126,8 @@ def reset_state_variables(self) -> None: if self.traces: self.x.zero_() # Spike traces. + self.t_.zero_() # added + self.x2.zero_() # added if self.sum_input: self.summed.zero_() # Summed inputs. @@ -144,6 +157,8 @@ def set_batch_size(self, batch_size) -> None: if self.traces: self.x = torch.zeros(batch_size, *self.shape, device=self.x.device) + self.t_ = torch.zeros(batch_size, *self.shape, device=self.t_.device) #added + self.x2 = torch.zeros(batch_size, *self.shape, device=self.x2.device) #added if self.sum_input: self.summed = torch.zeros( @@ -207,7 +222,6 @@ def __init__( trace_scale=trace_scale, sum_input=sum_input, ) - def forward(self, x: torch.Tensor) -> None: # language=rst """ @@ -217,16 +231,16 @@ def forward(self, x: torch.Tensor) -> None: """ # Set spike occurrences to input values. self.s = x - super().forward(x) - + + def reset_state_variables(self) -> None: # language=rst """ Resets relevant state variables. """ super().reset_state_variables() - + class McCullochPitts(Nodes): # language=rst @@ -1083,9 +1097,9 @@ def forward(self, x: torch.Tensor) -> None: # Decrement refractory counters. self.refrac_count -= self.dt - + # Check for spiking neurons. - self.s = self.v >= self.thresh + self.theta + self.s = self.v >= self.thresh + self.theta # spikes whenever it exceds threshold # Refractoriness, voltage reset, and adaptive thresholds. self.refrac_count.masked_fill_(self.s, self.refrac) @@ -1107,7 +1121,7 @@ def forward(self, x: torch.Tensor) -> None: # Voltage clipping to lower bound. if self.lbound is not None: self.v.masked_fill_(self.v < self.lbound, self.lbound) - + super().forward(x) def reset_state_variables(self) -> None: @@ -1131,6 +1145,7 @@ def compute_decays(self, dt) -> None: self.theta_decay = torch.exp( -self.dt / self.tc_theta_decay ) # Adaptive threshold decay (per timestep). + def set_batch_size(self, batch_size) -> None: # language=rst diff --git a/examples/mnist/eth_mnist.py b/examples/mnist/eth_mnist.py index 74f18f9c..d0250305 100644 --- a/examples/mnist/eth_mnist.py +++ b/examples/mnist/eth_mnist.py @@ -42,10 +42,11 @@ parser.add_argument("--test", dest="train", action="store_false") parser.add_argument("--plot", dest="plot", action="store_true") parser.add_argument("--gpu", dest="gpu", action="store_true") -parser.set_defaults(plot=True, gpu=True) +parser.set_defaults(plot=False, gpu=False, train = True) args = parser.parse_args() +save_as = "eth_test" seed = args.seed n_neurons = args.n_neurons n_epochs = args.n_epochs @@ -278,6 +279,7 @@ print("Progress: %d / %d (%.4f seconds)" % (epoch + 1, n_epochs, t() - start)) print("Training complete.\n") +""" # Load MNIST data. test_dataset = MNIST( PoissonEncoder(time=time, dt=dt), @@ -296,7 +298,7 @@ # Record spikes during the simulation. spike_record = torch.zeros((1, int(time / dt), n_neurons), device=device) -# Train the network. +# Testing the network. 
print("\nBegin testing\n") network.train(mode=False) start = t() @@ -345,3 +347,5 @@ print("Progress: %d / %d (%.4f seconds)" % (epoch + 1, n_epochs, t() - start)) print("Testing complete.\n") + +""" diff --git a/examples/mnist/evaluate_plot.py b/examples/mnist/evaluate_plot.py new file mode 100644 index 00000000..d07eebae --- /dev/null +++ b/examples/mnist/evaluate_plot.py @@ -0,0 +1,214 @@ + +import argparse +import os +from time import time as t + +import matplotlib.pyplot as plt +import numpy as np +import torch +from torchvision import transforms +from tqdm import tqdm + +from bindsnet.analysis.plotting import ( + plot_assignments, + plot_input, + plot_performance, + plot_spikes, + plot_voltages, + plot_weights, + plot_traces, # added +) +from bindsnet.datasets import MNIST +from bindsnet.encoding import PoissonEncoder +from bindsnet.evaluation import all_activity, assign_labels, proportion_weighting +from bindsnet.models import Salah_model +from bindsnet.network.monitors import Monitor +from bindsnet.utils import get_square_assignments, get_square_weights +from bindsnet.network import Network, load + +parser = argparse.ArgumentParser() +parser.add_argument("--seed", type=int, default=0) +parser.add_argument("--n_neurons", type=int, default=1000) +parser.add_argument("--n_epochs", type=int, default=1) +parser.add_argument("--n_test", type=int, default=10000) +parser.add_argument("--n_train", type=int, default=60000) +parser.add_argument("--exc", type=float, default=22.5) +parser.add_argument("--inh", type=float, default=120) +parser.add_argument("--theta_plus", type=float, default=0.05) +parser.add_argument("--time", type=int, default=250) +parser.add_argument("--dt", type=int, default=1.0) +parser.add_argument("--intensity", type=float, default=128) +parser.add_argument("--progress_interval", type=int, default=10) +parser.add_argument("--update_interval", type=int, default=250) +parser.add_argument("--train", dest="train", action="store_true") +parser.add_argument("--test", dest="train", action="store_false") +parser.add_argument("--plot", dest="plot", action="store_true") +parser.add_argument("--gpu", dest="gpu", action="store_true") +parser.add_argument("--saved_as") +parser.set_defaults(plot=True, gpu=False, train="False", n_test=10000 ) +args = parser.parse_args() + +seed = args.seed +n_neurons = args.n_neurons +n_epochs = args.n_epochs +n_test = args.n_test +n_train = args.n_train +exc = args.exc +inh = args.inh +theta_plus = args.theta_plus +time = args.time +dt = args.dt +intensity = args.intensity +progress_interval = args.progress_interval +update_interval = args.update_interval +train = args.train +plot = args.plot +gpu = args.gpu +saved_as = args.saved_as # "exp_26" + +#================================================ +# Sets up Gpu use (not used) +device = torch.device("cuda" if torch.cuda.is_available() else "cpu") +if gpu and torch.cuda.is_available(): + torch.cuda.manual_seed_all(seed) +else: + torch.manual_seed(seed) + device = "cpu" + if gpu: + gpu = False + +torch.set_num_threads(os.cpu_count() - 1) +print("Running on Device = ", device) + +#================================================ + +n_classes = 10 +n_sqrt = int(np.ceil(np.sqrt(n_neurons))) +start_intensity = intensity + +# Load pre-trained network +network = load(f"net_{saved_as}.pt") +network.train(mode=False) + +if gpu: + network.to("cuda") + +# Load assignments, obtained while training +train_details = torch.load(f"./details_{saved_as}.pt", map_location=torch.device(device)) + +# Assign the variables 
from the loaded dictionary +assignments = train_details["assignments"] +proportions = train_details["proportions"] +rates = train_details["rates"] + +#================================================ +spikes = {} +for layer in set(network.layers): + spikes[layer] = Monitor( + network.layers[layer], state_vars=["s"], time=int(time / dt), device=device + ) + network.add_monitor(spikes[layer], name="%s_spikes" % layer) + +inpt_ims, inpt_axes = None, None +spike_ims, spike_axes = None, None +weights_im = None +assigns_im = None +perf_ax = None +voltage_axes, voltage_ims = None, None +trace_axes, trace_ims = None, None + +#================================================ +# Load MNIST data. +test_dataset = MNIST( + PoissonEncoder(time=time, dt=dt), + None, + root=os.path.join("..", "..", "data", "MNIST"), + download=True, + train=False, + transform=transforms.Compose( + [transforms.ToTensor(), transforms.Lambda(lambda x: x * intensity)] + ),) + +# Sequence of accuracy estimates. +accuracy = {"all": 0, "proportion": 0} + +# Record spikes during the simulation. +spike_record = torch.zeros((1, int(time / dt), n_neurons), device=device) + +#================================================================================== +#************************* Testing the network *********************************** + +print("\nBegin testing\n") +start = t() + +pbar = tqdm(total=n_test) +for step, batch in enumerate(test_dataset): + if step >= n_test: + break + # Get next input sample. + inputs = {"X": batch["encoded_image"].view(int(time / dt), 1, 1, 28, 28)} + if gpu: + inputs = {k: v.cuda() for k, v in inputs.items()} + + # Run the network on the input. + network.run(inputs=inputs, time=time) + + # Add to spikes recording. + spike_record[0] = spikes["Ae"].get("s").squeeze() + #spike_record[0] = spikes + + # Convert the array of labels into a tensor + label_tensor = torch.tensor(batch["label"], device=device) + + # Get network predictions. + all_activity_pred = all_activity( + spikes=spike_record, assignments=assignments, n_labels=n_classes + ) + + proportion_pred = proportion_weighting( + spikes=spike_record, + assignments=assignments, + proportions=proportions, + n_labels=n_classes, + ) + + # Compute network accuracy according to available classification strategies. + accuracy["all"] += float(torch.sum(label_tensor.long() == all_activity_pred).item()) + accuracy["proportion"] += float(torch.sum(label_tensor.long() == proportion_pred).item()) + + network.reset_state_variables() # Reset state variables. + pbar.set_description_str("Test progress: ") + pbar.update() + +print("\nAll activity accuracy: %.2f \n" % (100*accuracy["all"] / n_test)) +print("Proportion weighting accuracy: %.2f \n" % (100*accuracy["proportion"] / n_test)) +print("Testing complete after: %.4f seconds \n" % (t() - start)) + + + +''' +#================================================================================== +#********************** Ploting terained weights ******************************** +After training, If you want to plot the weights without testing the network, +uncomment this bloc and comment the precedent testing code. 
+ +print("#of evaluation steps: \n",len(train_details["train_accur"]["all"])) +print("training time: \n", train_details['train_time']) +#print("training accur: \n", train_details["train_accur"]) +print("average of last 10 accuracies (all): \n", np.mean(train_details["train_accur"]["all"][-10:])) + +#extract weights +input_exc_weights = network.connections[("X", "Ae")].w +square_weights = get_square_weights( input_exc_weights.view(784, n_neurons), n_sqrt, 28 ) +square_assignments = get_square_assignments(assignments, n_sqrt) +train_accur = train_details["train_accur"] +train_accur_prop = {"Accuracy": train_accur["proportion"]} # creat a dict to plot only "proportion" + + +#plot +weights_im = plot_weights(square_weights, im=None) +assigns_im = plot_assignments(square_assignments, im=None) +#perf_ax = plot_performance(train_accur, x_scale=update_interval, ax=None) # plot both accrucies +perf_ax = plot_performance(train_accur_prop, x_scale=update_interval, ax=None) # plot only proportion +plt.pause(300) +''' diff --git a/examples/mnist/salah_example.py b/examples/mnist/salah_example.py new file mode 100644 index 00000000..06db69c3 --- /dev/null +++ b/examples/mnist/salah_example.py @@ -0,0 +1,324 @@ +import argparse +import os +from time import time as t + +import matplotlib.pyplot as plt +import numpy as np +import torch +from torchvision import transforms +from tqdm import tqdm + +from bindsnet.analysis.plotting import ( + plot_assignments, + plot_input, + plot_performance, + plot_spikes, + plot_voltages, + plot_weights, + plot_traces, # added +) +from bindsnet.datasets import MNIST +from bindsnet.encoding import PoissonEncoder +from bindsnet.evaluation import all_activity, assign_labels, proportion_weighting +from bindsnet.models import Salah_model # import model +from bindsnet.network.monitors import Monitor +from bindsnet.utils import get_square_assignments, get_square_weights + +parser = argparse.ArgumentParser() +parser.add_argument("--seed", type=int, default=0) +parser.add_argument("--n_neurons", type=int, default=100) +parser.add_argument("--n_epochs", type=int, default=1) +parser.add_argument("--n_test", type=int, default=10000) +parser.add_argument("--n_train", type=int, default=60000) +parser.add_argument("--n_workers", type=int, default=-1) +parser.add_argument("--exc", type=float, default=22.5) +parser.add_argument("--inh", type=float, default=120) +parser.add_argument("--theta_plus", type=float, default=0.05) +parser.add_argument("--time", type=int, default=250) +parser.add_argument("--dt", type=int, default=1.0) +parser.add_argument("--intensity", type=float, default=128) +parser.add_argument("--progress_interval", type=int, default=10) +parser.add_argument("--update_interval", type=int, default=250) +# adding/not adding --train to CL makes args.train true/false +parser.add_argument("--train", dest="train", action="store_true") +parser.add_argument("--test", dest="train", action="store_false") +parser.add_argument("--plot", dest="plot", action="store_true") +parser.add_argument("--gpu", dest="gpu", action="store_true") +parser.add_argument("--save_as") +# But if none of the four is added, these are the default ones: +parser.set_defaults(plot=False, gpu=False, train="True") +args = parser.parse_args() + +print(args) +save_as = args.save_as + +seed = args.seed +n_neurons = args.n_neurons +n_epochs = args.n_epochs +n_test = args.n_test +n_train = args.n_train +n_workers = args.n_workers +exc = args.exc +inh = args.inh +theta_plus = args.theta_plus +time = args.time +dt = args.dt 
+intensity = args.intensity +progress_interval = args.progress_interval +update_interval = args.update_interval +train = args.train +plot = args.plot +gpu = args.gpu + + +# Sets up Gpu use +device = torch.device("cuda" if torch.cuda.is_available() else "cpu") +if gpu and torch.cuda.is_available(): + torch.cuda.manual_seed_all(seed) +else: + torch.manual_seed(seed) + device = "cpu" + if gpu: + gpu = False + +torch.set_num_threads(os.cpu_count() - 1) +print("Running on Device = ", device) + +# Determines number of workers to use +if n_workers == -1: + n_workers = 0 #gpu * 4 * torch.cuda.device_count() + +n_sqrt = int(np.ceil(np.sqrt(n_neurons))) +start_intensity = intensity + +# Build network. +network = Salah_model( + n_inpt=784, + n_neurons=n_neurons, + exc=exc, + inh=inh, + dt=dt, + norm=78.4, + theta_plus=theta_plus, + inpt_shape=(1, 28, 28), +) + +# Directs network to GPU +if gpu: + network.to("cuda") + +# Load MNIST data. +train_dataset = MNIST( + PoissonEncoder(time=time, dt=dt), + None, + root=os.path.join("..", "..", "data", "MNIST"), + download=True, + train=True, + transform=transforms.Compose( + [transforms.ToTensor(), transforms.Lambda(lambda x: x * intensity)] # intensity = 128 + ), +) + +# Record spikes during the simulation. +spike_record = torch.zeros((update_interval, int(time / dt), n_neurons), device=device) + +# Neuron assignments and spike proportions. +n_classes = 10 +assignments = -torch.ones(n_neurons, device=device) +proportions = torch.zeros((n_neurons, n_classes), device=device) +rates = torch.zeros((n_neurons, n_classes), device=device) + +# Sequence of accuracy estimates. +accuracy = {"all": [], "proportion": []} + +# Voltage recording for excitatory and inhibitory layers. +exc_voltage_monitor = Monitor( + network.layers["Ae"], ["v"], time=int(time / dt), device=device +) +inh_voltage_monitor = Monitor( + network.layers["Ai"], ["v"], time=int(time / dt), device=device +) +network.add_monitor(exc_voltage_monitor, name="exc_voltage") +network.add_monitor(inh_voltage_monitor, name="inh_voltage") + +#============================== +#added +# trace recording for input and excitatory layers. +inp_trace_monitor = Monitor( + network.layers["X"], ["x2"], time=int(time / dt), device=device +) +exc_trace_monitor = Monitor( + network.layers["Ae"], ["x2"], time=int(time / dt), device=device +) +network.add_monitor(inp_trace_monitor, name="inp_trace") +network.add_monitor(exc_trace_monitor, name="exc_trace") + +# Set up monitors for traces +traces = {} +for layer in set(network.layers) - {"Ai"}: + traces[layer] = Monitor( + network.layers[layer], state_vars=["x"], time=int(time / dt), device=device + ) + + network.add_monitor(traces[layer], name="%s_traces" % layer) + +#============================== + +# Set up monitors for spikes and voltages +spikes = {} +for layer in set(network.layers): + spikes[layer] = Monitor( + network.layers[layer], state_vars=["s"], time=int(time / dt), device=device + ) + network.add_monitor(spikes[layer], name="%s_spikes" % layer) + + +voltages = {} +for layer in set(network.layers) - {"X"}: + voltages[layer] = Monitor( + network.layers[layer], state_vars=["v"], time=int(time / dt), device=device + ) + network.add_monitor(voltages[layer], name="%s_voltages" % layer) + +inpt_ims, inpt_axes = None, None +spike_ims, spike_axes = None, None +weights_im = None +assigns_im = None +perf_ax = None +voltage_axes, voltage_ims = None, None +trace_axes, trace_ims = None, None + +# Train the network. 
+print("\nBegin training.\n") +start = t() +for epoch in range(n_epochs): + labels = [] + + if epoch % progress_interval == 0: + print("Progress: %d / %d (%.4f seconds)" % (epoch, n_epochs, t() - start)) + start = t() + + # Create a dataloader to iterate and batch data + dataloader = torch.utils.data.DataLoader( + train_dataset, batch_size=1, shuffle=True, num_workers=n_workers, pin_memory=gpu) + + for step, batch in enumerate(tqdm(dataloader)): + if step > n_train: + break + # Get next input sample. + inputs = {"X": batch["encoded_image"].view(int(time / dt), 1, 1, 28, 28)} + if gpu: + inputs = {k: v.cuda() for k, v in inputs.items()} + + if step % update_interval == 0 and step > 0: + # Convert the array of labels into a tensor + label_tensor = torch.tensor(labels, device=device) + + # Get network predictions. + all_activity_pred = all_activity( + spikes=spike_record, + assignments=assignments, + n_labels=n_classes, + ) + proportion_pred = proportion_weighting( + spikes=spike_record, + assignments=assignments, + proportions=proportions, + n_labels=n_classes, + ) + + # Compute network accuracy according to available classification strategies. + accuracy["all"].append( + 100 + * torch.sum(label_tensor.long() == all_activity_pred).item() + / len(label_tensor) + ) + accuracy["proportion"].append( + 100 + * torch.sum(label_tensor.long() == proportion_pred).item() + / len(label_tensor) + ) + + print( + "\nAll activity accuracy: %.2f (last), %.2f (average), %.2f (best)" + % ( + accuracy["all"][-1], + np.mean(accuracy["all"]), + np.max(accuracy["all"]), + ) + ) + print( + "Proportion weighting accuracy: %.2f (last), %.2f (average), %.2f" + " (best)\n" + % ( + accuracy["proportion"][-1], + np.mean(accuracy["proportion"]), + np.max(accuracy["proportion"]), + ) + ) + + # Assign labels to excitatory layer neurons. + assignments, proportions, rates = assign_labels( + spikes=spike_record, + labels=label_tensor, + n_labels=n_classes, + rates=rates, + ) + + labels = [] + + labels.append(batch["label"]) + + # Run the network on the input. + network.run(inputs=inputs, time=time) + + # Get voltage recording. + exc_voltages = exc_voltage_monitor.get("v") + inh_voltages = inh_voltage_monitor.get("v") + + #added + # Get trace recording. + inp_traces = inp_trace_monitor.get("x2") + exc_traces = exc_trace_monitor.get("x2") + + # Add to spikes recording. + spike_record[step % update_interval] = spikes["Ae"].get("s").squeeze() + + # Optionally plot various simulation information. 
+ if plot: + image = batch["image"].view(28, 28) + inpt = inputs["X"].view(time, 784).sum(0).view(28, 28) + input_exc_weights = network.connections[("X", "Ae")].w + square_weights = get_square_weights( + input_exc_weights.view(784, n_neurons), n_sqrt, 28 + ) + square_assignments = get_square_assignments(assignments, n_sqrt) + spikes_ = {layer: spikes[layer].get("s") for layer in spikes} + voltages = {"Ae": exc_voltages, "Ai": inh_voltages} + traces = {"X": inp_traces, "Ae": exc_traces} # added + inpt_axes, inpt_ims = plot_input( + image, inpt, label=batch["label"], axes=inpt_axes, ims=inpt_ims + ) + spike_ims, spike_axes = plot_spikes(spikes_, ims=spike_ims, axes=spike_axes) + weights_im = plot_weights(square_weights, im=weights_im) + assigns_im = plot_assignments(square_assignments, im=assigns_im) + perf_ax = plot_performance(accuracy, x_scale=update_interval, ax=perf_ax) + voltage_ims, voltage_axes = plot_voltages( + voltages, ims=voltage_ims, axes=voltage_axes, plot_type="line") + + #added + trace_ims, trace_axes = plot_traces( + traces, n_neurons = {"X": (250, 280)}, ims=trace_ims, axes=trace_axes, plot_type="line") + + plt.pause(1e-8) + + network.reset_state_variables() # Reset state variables. + +train_time = t()-start +network.save(f"./net_{save_as}.pt") # added +train_details = {"assignments": assignments, "proportions": proportions, "rates": rates,"train_accur":accuracy, "train_time": train_time, "spike_record": spike_record} +torch.save(train_details, f"./details_{save_as}.pt") + +print("Progress: %d / %d (%.4f seconds)" % (epoch + 1, n_epochs, train_time)) +print("Training complete.\n") +
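Note on the bi-sigmoid trace: the sign of the pre-synaptic x2 trace at the moment a post-synaptic neuron fires determines whether Bi_sigmoid potentiates or depresses the synapse. The snippet below is a minimal standalone sketch, not part of the patch: it re-evaluates the double-sigmoid expression from bisigmoid_trace in bindsnet/network/nodes.py (constants copied from that function; everything else is illustrative only) at a few pre-to-post delays, so the potentiation/depression windows and the 60 ms cut-off can be inspected without running a full simulation.

import torch

def bisigmoid_trace(t_: torch.Tensor) -> torch.Tensor:
    # Same constants as the bisigmoid_trace helper added to Nodes.forward above.
    A = 7.95e05
    k0, x0 = 0.474723045, 20.77893753
    k1, x1 = 0.757072031, 48.93860322
    result = (-A / (1 + torch.exp(-k0 * (t_ - x0))) - A / (1 + torch.exp(-k1 * (t_ - x1))) + A) / A
    # Delays longer than 60 ms no longer contribute to the weight update.
    return result.masked_fill(t_ > 60, 0.0)

# Trace value as a function of the delay since the last pre-synaptic spike:
# close to +1 for short delays (potentiation), negative for long delays
# (depression), and 0 beyond the 60 ms cut-off.
for delay in (1.0, 10.0, 25.0, 40.0, 55.0, 70.0):
    x2 = bisigmoid_trace(torch.tensor(delay)).item()
    print(f"delay {delay:5.1f} ms -> x2 = {x2:+.3f}")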