<a href="https://colab.research.google.com/github/Singular-Brain/bindsnet/blob/master/lc_net.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

#Notebook setups

In [1]:
!pip install -q git+https://github.com/Singular-Brain/bindsnet

[K     |████████████████████████████████| 120 kB 8.4 MB/s 
[K     |████████████████████████████████| 76 kB 5.7 MB/s 
[K     |████████████████████████████████| 72 kB 1.2 MB/s 
[K     |████████████████████████████████| 280 kB 59.9 MB/s 
[K     |████████████████████████████████| 28.5 MB 46 kB/s 
[?25h  Building wheel for BindsNET (setup.py) ... [?25l[?25hdone
[31mERROR: pip's dependency resolver does not currently take into account all the packages that are installed. This behaviour is the source of the following dependency conflicts.
datascience 0.10.6 requires folium==0.2.1, but you have folium 0.8.3 which is incompatible.
albumentations 0.1.12 requires imgaug<0.2.7,>=0.2.5, but you have imgaug 0.2.9 which is incompatible.[0m


In [2]:
!wget https://data.deepai.org/mnist.zip
!mkdir -p ../data/MNIST/TorchvisionDatasetWrapper/raw
!unzip mnist.zip -d ../data/MNIST/TorchvisionDatasetWrapper/raw/

--2021-08-07 08:46:55--  https://data.deepai.org/mnist.zip
Resolving data.deepai.org (data.deepai.org)... 138.201.36.183
Connecting to data.deepai.org (data.deepai.org)|138.201.36.183|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 11597176 (11M) [application/x-zip-compressed]
Saving to: ‘mnist.zip’


2021-08-07 08:46:57 (8.23 MB/s) - ‘mnist.zip’ saved [11597176/11597176]

Archive:  mnist.zip
  inflating: ../data/MNIST/TorchvisionDatasetWrapper/raw/train-labels-idx1-ubyte.gz  
  inflating: ../data/MNIST/TorchvisionDatasetWrapper/raw/train-images-idx3-ubyte.gz  
  inflating: ../data/MNIST/TorchvisionDatasetWrapper/raw/t10k-images-idx3-ubyte.gz  
  inflating: ../data/MNIST/TorchvisionDatasetWrapper/raw/t10k-labels-idx1-ubyte.gz  


In [3]:
!git clone https://github.com/Singular-Brain/bindsnet/

Cloning into 'bindsnet'...
remote: Enumerating objects: 9354, done.[K
remote: Counting objects: 100% (314/314), done.[K
remote: Compressing objects: 100% (230/230), done.[K
remote: Total 9354 (delta 210), reused 151 (delta 83), pack-reused 9040[K
Receiving objects: 100% (9354/9354), 40.20 MiB | 31.64 MiB/s, done.
Resolving deltas: 100% (5936/5936), done.


In [4]:
from bindsnet.network.nodes import Nodes
import os
import copy
import time
import gzip, pickle
import torch
import random
import torchvision
import numpy as np
import argparse
import matplotlib.pyplot as plt
import collections
from bindsnet import manual_seed
from torchvision import transforms
from tqdm.notebook import tqdm

from bindsnet.datasets import MNIST
from bindsnet.encoding import PoissonEncoder
from bindsnet.network import Network
from bindsnet.network.nodes import Input, LIFNodes, AdaptiveLIFNodes
from bindsnet.network.topology import LocalConnection, Connection
from bindsnet.network.monitors import Monitor, AbstractMonitor, TensorBoardMonitor
from bindsnet.learning import PostPre, MSTDP, MSTDPET 
from bindsnet.learning.reward import DynamicDopamineInjection
from bindsnet.utils import get_square_assignments, get_square_weights
from bindsnet.evaluation import all_activity, proportion_weighting, assign_labels

from bindsnet.analysis.plotting import (
    plot_input,
    plot_assignments,
    plot_performance,
    plot_weights,
    plot_spikes,
    plot_voltages,
)

## Sets up Gpu use


In [5]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
gpu = True
train = True
def manual_seed(seed):
    random.seed(seed)
    torch.manual_seed(seed)
    np.random.seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)
            
torch.set_num_threads(os.cpu_count() - 1)
print("Running on Device = ", device)

if not train:
    update_interval = n_test
seed = 2045 # The Singularity is Near!
manual_seed(seed)

Running on Device =  cuda


## Set up hyper-parameters

In [13]:
seed = 2045 # The Singularity is Near!
manual_seed(seed)
# training hyperparameters

n_train = 1000
n_test = 200
n_val = 100
val_interval = 100
running_window_length = 100

time = 250
dt = 1
train = True
gpu = False
device_id = 0

# Dataset Hyperparameters
target_classes = (0,1)
if target_classes:
    # npz_file = np.load(f'mask_{"_".join([str(i) for i in target_classes])}.npz')
    npz_file = np.load('bindsnet/mask_0.npz') ##### KESAFAT KARI !!!
    mask, mask_test = torch.from_numpy(npz_file['arr_0']), torch.from_numpy(npz_file['arr_1'])
    n_classes = len(target_classes)
else:
    mask = None
    mask_test = None
    n_classes = 10
crop_size = 20
intensity = 64

# Network Architecture Hyperparameters 
n_neurons = 100
C = 50
K = 12
S = 4
theta_plus = 0.05   ## Adaptive LIF
inh_factor = 0.5
wmin = -1.0
wmax = 1.0
neuron_per_class = int(n_neurons/n_classes)

train_hparams = {
    'n_train' : n_train,
    'n_test' : n_test,
    'n_val' : n_val,
    'val_interval' : val_interval,
    'running_window_length': running_window_length,
}

data_hparams = { 
    'intensity': intensity,
    'crop_size': crop_size,
}

network_hparams = {
    'n_classes': n_classes,
    'dt' : dt,
    'time' : time, 
    'nu': 1e-2,
    'n_neurons' : n_neurons,
    'n_channels': C,
    'filter_size': K,
    'stride': S,
    'theta_plus': theta_plus,
    'inh_factor': inh_factor,
    'wmin': wmin,
    'wmax': wmax,
    'crop_size': crop_size,
    'neuron_per_class': neuron_per_class,
    'update_rule': MSTDPET,
    'online': False,
    'tc_trace': 20,
}

# reward Hyperparameters 
reward_hparams = {
    # 'dopaminergic_layer': 'output', 
    'n_labels': n_classes,
    'neuron_per_class': neuron_per_class,
    'dopamine_per_spike': 0.001,
    'dopamine_for_correct_pred': 0,
    'tc_reward': 20,
    'give_reward': True,
    'dopamine_base': 0.00,
    'variant': 'pure_per_spike'
}

# Design network

## Reward Monitor

In [7]:
class RewardMonitor(AbstractMonitor):
    # language=rst
    """
    Records state variables of interest.
    """

    def __init__(
        self,
        time: None,
        batch_size: int = 1,
        device: str = "cpu",
    ):
        # language=rst
        """
        Constructs a ``Monitor`` object.

        :param obj: An object to record state variables from during network simulation.
        :param state_vars: Iterable of strings indicating names of state variables to record.
        :param time: If not ``None``, pre-allocate memory for state variable recording.
        :param device: Allow the monitor to be on different device separate from Network device
        """
        super().__init__()

        self.time = time
        self.batch_size = batch_size
        self.device = device

        # if time is not specified the monitor variable accumulate the logs
        if self.time is None:
            self.device = "cpu"

        self.recording = []
        self.reset_state_variables()

    def get(self,) -> torch.Tensor:
        # language=rst
        """
        Return recording to user.

        :return: Tensor of shape ``[time, n_1, ..., n_k]``, where ``[n_1, ..., n_k]`` is the shape of the recorded state
        variable.
        Note, if time == `None`, get return the logs and empty the monitor variable

        """
        # return_logs = torch.as_tensor(self.recording)
        # if self.time is None:
        #     self.recording = []
        return self.recording

    def record(self, **kwargs) -> None:
        # language=rst
        """
        Appends the current value of the recorded state variables to the recording.
        """
        self.recording.append(kwargs["reward"])
        # remove the oldest element (first in the list)
        # if self.time is not None:
        #     self.recording.pop(0)

    def reset_state_variables(self) -> None:
        # language=rst
        """
        Resets recordings to empty ``List``s.
        """
        self.recording = []


## LCNET

In [17]:
compute_size = lambda inp_size, k, s: int((inp_size-k)/s) + 1


class LCNet(Network):
    def __init__(
        self,
        n_classes: int,
        neuron_per_class: int,
        n_channels:int,
        filter_size: int,
        stride: int,
        online: bool,
        time: int,
        reward_fn,
        dt: float = 1.0,
        crop_size:int = 20,
        update_rule = MSTDPET,
        nu = 1e-2,
        wmin: float = 0.0,
        wmax: float = 1.0,
        norm: float = 78.4,
        theta_plus: float = 0.05,
        tc_theta_decay: float = 1e7,
        tc_trace:int = 20,
        **kwargs,
    ) -> None:
        # language=rst
        """
        Constructor for class ``DiehlAndCook2015``.

        :param n_inpt: Number of input neurons. Matches the 1D size of the input data.
        :param n_neurons: Number of excitatory, inhibitory neurons.
        :param exc: Strength of synapse weights from excitatory to inhibitory layer.
        :param inh: Strength of synapse weights from inhibitory to excitatory layer.
        :param dt: Simulation time step.
        :param nu: Single or pair of learning rates for pre- and post-synaptic events,
            respectively.
        :param reduction: Method for reducing parameter updates along the minibatch
            dimension.
        :param wmin: Minimum allowed weight on input to excitatory synapses.
        :param wmax: Maximum allowed weight on input to excitatory synapses.
        :param norm: Input to excitatory layer connection weights normalization
            constant.
        :param theta_plus: On-spike increment of ``DiehlAndCookNodes`` membrane
            threshold potential.
        :param tc_theta_decay: Time constant of ``DiehlAndCookNodes`` threshold
            potential decay.
        :param inpt_shape: The dimensionality of the input layer.
        """
        super().__init__(dt=dt, reward_fn = reward_fn, online=online)

        self.n_classes = n_classes
        self.neuron_per_class = neuron_per_class
        self.dt = dt
        self.time = time
        ### nodes
        inp = Input(shape= [1,20,20], traces=True, tc_trace=tc_trace)
        self.add_layer(inp, name="input")
        main = LIFNodes(shape= [C, compute_size(crop_size, K, S), compute_size(crop_size, K, S)], traces=True, tc_trace=tc_trace)
        self.add_layer(main, name="main")
        ### connections 
        LC = LocalConnection(inp, main, K, S, C, nu = nu, update_rule = update_rule, wmin = wmin, wmax= wmax,)
        self.add_connection(LC, "input", "main")
        ### main to output
        for c in range(n_classes):
            self.add_layer(
                LIFNodes(n= neuron_per_class, traces=True, tc_trace=tc_trace),
                name=f"output_{c}",
            )

            self.add_connection(
                Connection(
                    main,
                    self.layers[f"output_{c}"],
                    nu = nu,
                ),
                "main",
                f"output_{c}",
            )

        for source in range(n_classes):
            for target in range(n_classes):
                if source == target:
                    continue
                self.add_connection(
                    Connection(
                        self.layers[f"output_{source}"],
                        self.layers[f"output_{target}"],
                        nu = nu,
                        wmin=-inh_factor,
                        wmax=0,
                        w= torch.ones(
                            self.layers[f"output_{source}"].n,
                            self.layers[f"output_{target}"].n,
                        ) * -inh_factor
                    ),
                    f"output_{source}",
                    f"output_{target}",
                )

        # Directs network to GPU
        if gpu:
            self.to("cuda")


    def fit(
        self,
        dataloader,
        val_loader,
        online_validate = True,
        n_train = 200,
        n_test = 100,
        n_val = 50,
        val_interval = 50,
        running_window_length = 50,
    ):

        # add Monitors
        main_monitor = Monitor(self.layers["main"], ["v"], time=self.time, device=device)
        reward_monitor = RewardMonitor(time =self.time)
        tensorboard = TensorBoardMonitor(self, time = self.time)
        self.add_monitor(main_monitor, name="main")
        self.add_monitor(reward_monitor, name="reward")
        self.add_monitor(tensorboard, name="tensorboard")

        manual_seed(seed)
        print("Begin training.\n")
        acc_hist = collections.deque([], running_window_length)

        output_layers = set([layer for layer in self.layers if layer.startswith('output')])
        self.output_spikes = {}
        for layer in output_layers:
            self.output_spikes[layer] = Monitor(self.layers[layer], state_vars=["s"], time=self.time)
            self.add_monitor(self.output_spikes[layer], name="%s_spikes" % layer)

        val_acc = 0.0

        reward_history = []
        pbar = tqdm(total=n_train)
        self.reset_state_variables()
        for (i, datum) in enumerate(dataloader):
            if i > n_train:
                break

            image = datum["encoded_image"]
            label = datum["label"]

            # Run the network on the input.
            if gpu:
                inputs = {"input": image.cuda().view(self.time, 1, 1, 20, 20)}
            else:
                inputs = {"input": image.view(self.time, 1, 1, 20, 20)}

            dopaminergic_layers = {name: layer for name, layer in self.layers.items() if name.startswith('output')}

            self.run(inputs=inputs, 
                    time=self.time, 
                    **reward_hparams, 
                    labels = label.int().item(),
                    dopaminergic_layers= dopaminergic_layers,
                    train=True)


            # Get voltage recording.
            main_voltage = main_monitor.get("v")
            reward_history.append(reward_monitor.get())
            tensorboard.update(step= i)

            # Add to spikes recording.
            spikes_record = torch.zeros(self.n_classes, self.time, self.neuron_per_class)
            for c in range(self.n_classes):
                spikes_record[c] = self.output_spikes[f"output_{c}"].get("s").squeeze(1)

            predicted_label = torch.argmax(spikes_record.sum(1).sum(1))

            if predicted_label == label:
                reward_hparams['dopamine_for_correct_pred'] = 0.5
                self.run(inputs=inputs, time=self.time, **reward_hparams, labels =  label.int().item(),dopaminergic_layers= dopaminergic_layers, train=True)
                acc_hist.append(1)
            else:
                # reward_hparams['dopamine_for_correct_pred'] = -0.01
                # self.run(inputs=inputs, time=time, **reward_hparams, labels = label, train=True)
                acc_hist.append(0)

            reward_hparams['dopamine_for_correct_pred'] = 0.0

            print("\routput", spikes_record.sum(1).sum(1), 'predicted_label:',
                predicted_label.item(), 'GT:', label.item(), ' Reward:',
                sum(reward_monitor.get()),
                end = '')

            if online_validate and i % val_interval == 0 and i!=0:
                val_acc = self.evaluate(val_loader, n_val)

            # if  i % val_interval == 0 and i!=0:
            #     fig = create_plot(self.output_spikes, reward_monitor.get(), label)
            #     tensorboard.writer.add_figure('reward', fig, i)
                
            acc = 100 * sum(acc_hist)/len(acc_hist)
            self.reset_state_variables()  # Reset state variables.
            
            pbar.set_description_str("Running accuracy: " + "{:.2f}".format(acc) + "%, " + "Current val accuracy: " + "{:.2f}".format(val_acc) + "%, ")
            pbar.update()

        result_metrics = {'train_acc': acc, 'val_acc': val_acc}
        tensorboard.writer.add_hparams(
            {**train_hparams, **data_hparams, **network_hparams, **reward_hparams},
            result_metrics
        )



    def evaluate(self, val_loader, n_val):
        acc_hist_val = collections.deque([], running_window_length)

        spikes_val = {}

        self.train = False
        for (i, datum) in enumerate(val_loader):
            if i > n_val:
                break

            image = datum["encoded_image"]
            label = datum["label"]

            # Run the network on the input.
            if gpu:
                inputs = {"input": image.cuda().view(self.time, 1, 1, 20, 20)}
            else:
                inputs = {"input": image.view(self.time, 1, 1, 20, 20)}
            self.run(inputs=inputs, time=self.time, **reward_hparams, labels=None, train=False)

            # Add to spikes recording.
            spikes_record = torch.zeros(self.n_classes, self.time, self.neuron_per_class)
            for c in range(self.n_classes):
                spikes_record[c] = self.output_spikes[f"output_{c}"].get("s").squeeze(1)

            predicted_label = torch.argmax(spikes_record.sum(1).sum(1))

            if predicted_label == label:
                acc_hist_val.append(1)
            else:
                acc_hist_val.append(0)
            

            self.reset_state_variables()  # Reset state variables.

        self.train = True
        val_acc = 100 * sum(acc_hist_val)/len(acc_hist_val)
        return val_acc

# Load Dataset

In [9]:
class ClassSelector(torch.utils.data.sampler.Sampler):
    """Select target classes from the dataset"""
    def __init__(self, target_classes, data_source, mask = None):
        if mask is not None:
            self.mask = mask
        else:
            self.mask = torch.tensor([1 if data_source[i]['label'] in target_classes else 0 for i in range(len(data_source))])
        self.data_source = data_source

    def __iter__(self):
        return iter([i.item() for i in torch.nonzero(self.mask)])

    def __len__(self):
        return len(self.data_source)

In [10]:
# Load MNIST data.
manual_seed(seed)
dataset = MNIST(
    PoissonEncoder(time=time, dt=dt,),
    None,
    root=os.path.join("..", "..", "data", "MNIST"),
    download=True,
    transform=transforms.Compose(
        [transforms.ToTensor(),
        transforms.Lambda(lambda x: x * intensity),
        transforms.CenterCrop(crop_size)]
    ),
)

# Create a dataloader to iterate and batch data
dataloader = torch.utils.data.DataLoader(dataset, batch_size=1,
                                         sampler = ClassSelector(
                                                target_classes = target_classes,
                                                data_source = dataset,
                                                mask = mask,
                                                ) if target_classes else None
                                         )

# Load test dataset
test_dataset = MNIST(
    PoissonEncoder(time=time, dt=dt),
    None,
    root=os.path.join("..", "..", "data", "MNIST"),
    download=True,
    train=False,
    transform=transforms.Compose(
        [transforms.ToTensor(),
        transforms.Lambda(lambda x: x * intensity),
        transforms.CenterCrop(crop_size)]
    ),
)

val_loader = torch.utils.data.DataLoader(test_dataset, batch_size=1,
                                         sampler = ClassSelector(
                                                target_classes = target_classes,
                                                data_source = test_dataset,
                                                mask = mask_test,
                                                ) if target_classes else None
                                         )

Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz
Using downloaded and verified file: ../../data/MNIST/TorchvisionDatasetWrapper/raw/train-images-idx3-ubyte.gz
Extracting ../../data/MNIST/TorchvisionDatasetWrapper/raw/train-images-idx3-ubyte.gz to ../../data/MNIST/TorchvisionDatasetWrapper/raw

Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz
Using downloaded and verified file: ../../data/MNIST/TorchvisionDatasetWrapper/raw/train-labels-idx1-ubyte.gz
Extracting ../../data/MNIST/TorchvisionDatasetWrapper/raw/train-labels-idx1-ubyte.gz to ../../data/MNIST/TorchvisionDatasetWrapper/raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz
Using downloaded and verified file: ../../data/MNIST/TorchvisionDatasetWrapper/raw/t10k-images-idx3-ubyte.gz
Extracting ../../data/MNIST/TorchvisionDatasetWrapper/raw/t10k-images-idx3-ubyte.gz to ../../data/MNIST/TorchvisionDatasetWrapper/raw

Downloading http://yann.lecun.com/exdb/mnist/t10k

# Train

In [11]:
def create_plot(spikes, reward, label):
    fig = plt.figure()
    ax = fig.gca()
    for i in range(spikes.shape[1]):
        spikes_sum = spikes.sum(-1)[:,i]
        spike_timepoints = np.where(spikes_sum)[0]
        spike_values = spikes_sum[spike_timepoints]
        if i == label:
            kwargs = {'s':10, 'marker' : '*', 'c' : 'r'}
        else:
            kwargs = {'s':5, 'marker': '*'}
        
        ax.scatter(spike_timepoints, spike_values, **kwargs)
    
    ax.plot(reward)
    return fig

## Variant 1 (2 passes with per prediction)

In [None]:
net = LCNet(**network_hparams, reward_fn = DynamicDopamineInjection)
net.fit(dataloader = dataloader, val_loader = val_loader, **train_hparams)

Begin training.



  0%|          | 0/1000 [00:00<?, ?it/s]

output tensor([770.,   0.]) predicted_label: 0 GT: 0  Reward: 125.0

## Variant 4 (per spike)

In [30]:
net = LCNet(**network_hparams, reward_fn = DynamicDopamineInjection)
net.fit(dataloader = dataloader, val_loader = val_loader, **train_hparams)

Begin training.



  0%|          | 0/1000 [00:00<?, ?it/s]

output tensor([722., 757.]) predicted_label: 1 GT: 0  Reward: tensor(14.0578)

KeyboardInterrupt: ignored

In [None]:
net

LCNet(
  (input): Input()
  (main): LIFNodes()
  (input_to_main): LocalConnection(
    (source): Input()
    (target): LIFNodes()
  )
  (output_0): LIFNodes()
  (main_to_output_0): Connection(
    (source): LIFNodes()
    (target): LIFNodes()
  )
  (output_1): LIFNodes()
  (main_to_output_1): Connection(
    (source): LIFNodes()
    (target): LIFNodes()
  )
  (output_0_to_output_1): Connection(
    (source): LIFNodes()
    (target): LIFNodes()
  )
  (output_1_to_output_0): Connection(
    (source): LIFNodes()
    (target): LIFNodes()
  )
)

In [None]:
net.connections

In [None]:
# Load the TensorBoard notebook extension
%load_ext tensorboard

In [None]:
%tensorboard --logdir runs

# Kernel 

In [None]:
from abc import ABC, abstractmethod
from typing import Union, Tuple, Optional, Sequence
from torch.nn.modules.utils import _pair

In [None]:
class AbstractKernel(ABC):
	r"""Base class for generating image filter kernels such as Gabor, DoG, etc. Each subclass should override :attr:`__call__` function.
	"""
	def __init__(self, kernel_size: Union[int, Tuple[int, int]]):
    """
    Instantiates a ``Filter Kernel`` object.

    :param kernel_size: Horizontal and vertical size of convolutional kernels.
    """
		self.kernel_size = _pair(kenel_size)

	def __call__(self):
		pass


In [None]:
class DoG_Kernel(AbstractKernel):
	r"""Generates DoG filter kernels.
	"""
	def __init__(self, 
                 kernel_size: Union[int, Tuple[int, int]], 
                 sigma1 : float,
                 sigma2 : float):
        """
		:param kernel_size: Horizontal and vertical size of DOG kernels.(If pass int, we consider it as a square filter) 
		:param sigma1 : The sigma parameter for the first Gaussian function.
		:param sigma2 : The sigma parameter for the second Gaussian function.
		"""
        super(DoG_Kernel, self).__init__(kernel_size)
		self.sigma1 = sigma1
		self.sigma2 = sigma2

	# returns a 2d tensor corresponding to the requested DoG filter
	def __call__(self):
		k = self.kernel_size//2
		x, y = np.mgrid[-k:k+1:1, -k:k+1:1]
		a = 1.0 / (2 * math.pi)
		prod = x*x + y*y
		f1 = (1/(self.sigma1*self.sigma1)) * np.exp(-0.5 * (1/(self.sigma1*self.sigma1)) * (prod))
		f2 = (1/(self.sigma2*self.sigma2)) * np.exp(-0.5 * (1/(self.sigma2*self.sigma2)) * (prod))
		dog = a * (f1-f2)
		dog_mean = np.mean(dog)
		dog = dog - dog_mean
		dog_max = np.max(dog)
		dog = dog / dog_max
		dog_tensor = torch.from_numpy(dog)
		return dog_tensor.float()