[<img src='https://github.com/jeshraghian/snntorch/blob/master/docs/_static/img/snntorch_alpha_w.png?raw=true' width="400">](https://github.com/jeshraghian/snntorch/)

# snnTorch - Population Coding in Spiking Neural Nets
## By Jason K. Eshraghian (www.jasoneshraghian.com)

<a href="https://colab.research.google.com/github/jeshraghian/snntorch/blob/master/examples/tutorial_pop.ipynb">
  <img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/>
</a>

[<img src='https://github.com/jeshraghian/snntorch/blob/master/docs/_static/img/GitHub-Mark-Light-120px-plus.png?raw=true' width="28">](https://github.com/jeshraghian/snntorch/) [<img src='https://github.com/jeshraghian/snntorch/blob/master/docs/_static/img/GitHub_Logo_White.png?raw=true' width="80">](https://github.com/jeshraghian/snntorch/)

For a comprehensive overview of how SNNs work and what is going on under the hood, [you might be interested in the snnTorch tutorial series available here.](https://snntorch.readthedocs.io/en/latest/tutorials/index.html)
The snnTorch tutorial series is based on the following paper. If you find these resources or code useful in your work, please consider citing the following source:

> <cite> [Jason K. Eshraghian, Max Ward, Emre Neftci, Xinxin Wang, Gregor Lenz, Girish Dwivedi, Mohammed Bennamoun, Doo Seok Jeong, and Wei D. Lu. "Training Spiking Neural Networks Using Lessons From Deep Learning". arXiv preprint arXiv:2109.12894, September 2021.](https://arxiv.org/abs/2109.12894) </cite>

# Introduction
It is thought that rate codes alone cannot be the dominant encoding mechanism in the primary cortex. One of several reasons is that the average neuronal firing rate is roughly $0.1-1$ Hz, which is far too slow to account for the reaction times of animals and humans.

But if we pool together multiple neurons and count their spikes together, then it becomes possible to measure a firing rate for a population of neurons in a very short window of time. Population coding adds some credibility to the plausibility of rate-encoding mechanisms.
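
To make the pooling argument concrete, the sketch below estimates a firing rate from a 10 ms window by counting spikes across a small population. The spike times, the window length and the helper name `population_rate` are illustrative assumptions and are not part of snnTorch.

```python
def population_rate(spike_times_per_neuron, window_start, window_end):
    """Mean firing rate per neuron (Hz), from spikes pooled over a population."""
    n_neurons = len(spike_times_per_neuron)
    duration = window_end - window_start
    pooled_count = sum(
        1
        for spikes in spike_times_per_neuron
        for t in spikes
        if window_start <= t < window_end
    )
    return pooled_count / (n_neurons * duration)


# Hypothetical spike times in seconds for 5 neurons.
spike_times = [[0.002], [], [0.004, 0.5], [0.009], []]

# One neuron in a 10 ms window contributes 0 or 1 spike: a very coarse estimate.
single = population_rate(spike_times[:1], 0.0, 0.010)   # about 100 Hz
# Pooling 5 neurons gives 3 spikes in the same window: a finer estimate.
pooled = population_rate(spike_times, 0.0, 0.010)       # about 60 Hz
print(single, pooled)
```

With a single neuron the estimate can only take values in steps of 100 Hz for this window; with five neurons the step shrinks to 20 Hz, which is the sense in which a population makes a rate measurable in a short time.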

<center>
<img src='https://github.com/jeshraghian/snntorch/blob/master/docs/_static/img/examples/tutorial_pop/pop.png?raw=true' width="300">
</center>


In this tutorial, you will:
* Learn how to train a population coded network. Instead of assigning one neuron per class, we will extend this to multiple neurons per class, and aggregate their spikes together.

If running in Google Colab:
* You may connect to GPU by checking `Runtime` > `Change runtime type` > `Hardware accelerator: GPU`
* Next, install the latest PyPI distribution of snnTorch by clicking into the following cell and pressing `Shift+Enter`.

In [1]:
!pip install snntorch --quiet


In [6]:
import torch, torch.nn as nn
import snntorch as snn

### Useful functions

In [7]:
import csv
import pprint  # used by extract_label when verbose=True
import numpy as np
import scipy.stats
from sklearn.model_selection import train_test_split

In [8]:
def extract_label(file_name, verbose=False):
    data = {}
    label = []
    with open(file_name, "r") as fin:
        reader = csv.reader(fin, delimiter=',')
        first = True
        for row in reader:
            lbl = row[2]
            if first or "TARGET" in lbl:
                first = False
                continue
            lbl = lbl.replace("TCGA-","")

            label.append(lbl)
            if lbl in data.keys():
                data[lbl] += 1 
            else:
                data[lbl] = 1
    if verbose:
        print(f"Number of classes in the dataset = {len(data)}")
        pprint.pprint(data, indent=4)

    return label

In [9]:
def create_dictionary(labels):
    dictionary = {}
    class_names = np.unique(labels)
    for i, name in enumerate(class_names):
        dictionary[name] = i
    return dictionary

In [10]:
def label_processing(labels):
    new_miRna_label = []
    dictionary = create_dictionary(labels)
    for i in labels:
        new_miRna_label.append(dictionary[i])
    return new_miRna_label
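
The two helpers above assign each class name its position in the sorted list of unique names. A minimal pure-Python sketch of the same mapping is shown below; `encode_labels` and the toy labels are hypothetical and only illustrate the idea.

```python
def encode_labels(labels):
    # Sorted unique class names, mirroring the ordering returned by np.unique.
    class_names = sorted(set(labels))
    name_to_index = {name: i for i, name in enumerate(class_names)}
    return [name_to_index[name] for name in labels], name_to_index


toy_labels = ["READ", "BRCA", "LUAD", "BRCA", "READ"]
encoded, mapping = encode_labels(toy_labels)
print(encoded)  # [2, 0, 1, 0, 2]
print(mapping)  # {'BRCA': 0, 'LUAD': 1, 'READ': 2}
```

Because the indices depend on the set of classes present, the mapping has to be rebuilt whenever the dataset is restricted to a subset of classes.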

### 1.2 Download Dataset

In [1]:
import os
mir_dataset = "https://drive.google.com/drive/folders/1oWWeord8YYvtxIo2Pq2peyx7xOI-1Tmb?usp=sharing"
if "MLinApp_course_data" not in os.listdir("./"):
  ! gdown $mir_dataset -O ./MLinApp_course_data --folder

'gdown' is not recognized as an internal or external command,
operable program or batch file.


In [11]:
# Remove the first row and the last column from the feature matrix
miR_label = extract_label("./MLinApp_course_data/tcga_mir_label.csv")
miR_data = np.genfromtxt('./MLinApp_course_data/tcga_mir_rpm.csv', delimiter=',')[1:,0:-1]

In [12]:
number_to_delete = abs(len(miR_label) - miR_data.shape[0])
miR_data = miR_data[number_to_delete:,:]
# Convert labels to numbers
num_miR_label = label_processing(miR_label)

In [13]:
# Z-score normalization
miR_data = scipy.stats.zscore(miR_data, axis=1)

assert np.isnan(miR_data).sum() == 0
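
With `axis=1`, the z-score is computed per row, so each sample is standardized across its own features rather than each feature across samples. The sketch below reproduces that computation for a single row using only the standard library. It assumes the population standard deviation (`ddof=0`), which is the default of `scipy.stats.zscore`; `zscore_row` and the toy values are hypothetical.

```python
import statistics


def zscore_row(row):
    mean = statistics.fmean(row)
    std = statistics.pstdev(row)  # population standard deviation (ddof=0)
    return [(x - mean) / std for x in row]


toy_row = [2.0, 4.0, 4.0, 4.0, 5.0, 5.0, 7.0, 9.0]  # mean 5, std 2
z = zscore_row(toy_row)
print(z)  # [-1.5, -0.5, -0.5, -0.5, 0.0, 0.0, 1.0, 2.0]
```

A row whose values are all identical has zero standard deviation and would produce a division by zero here (NaN in SciPy), which is what the `assert` on NaN values guards against.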

In [14]:
print(miR_data[0], np.min(miR_data))

[ 1.68703834  1.67910068  1.71667838 ... -0.05112508 -0.01854857
  3.38106288] -0.13941802539632334


In [37]:
# log2 normalization <Optional>

miR_data = miR_data + abs(np.min(miR_data)) + 0.001

miR_data = np.log2(miR_data)

In [38]:
# normalization between [0, 255] <Optional>
miR_data = (miR_data - np.min(miR_data)) / (np.max(miR_data) - np.min(miR_data)) * 255
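
The two optional steps above first shift the data so that every value is strictly positive, take the base-2 logarithm, and then rescale linearly to $[0, 255]$. A small stdlib-only sketch of the same sequence on a toy vector follows; `log2_then_rescale` and the input values are assumptions made for illustration.

```python
import math


def log2_then_rescale(values, eps=0.001, upper=255.0):
    # Shift so the smallest value becomes eps (> 0) when the minimum is negative.
    shift = abs(min(values)) + eps
    logged = [math.log2(v + shift) for v in values]
    lo, hi = min(logged), max(logged)
    # Min-max rescaling: the smallest value maps to 0, the largest to `upper`.
    return [(x - lo) / (hi - lo) * upper for x in logged]


toy = [-0.5, 0.0, 1.5, 3.5]
scaled = log2_then_rescale(toy)
print(scaled)
```

The logarithm compresses large values and stretches values close to the minimum, so the choice of `eps` strongly affects where the smallest entries land after rescaling.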

In [127]:
n_classes = np.unique(miR_label).size

print(n_classes)
print(miR_label)

print(num_miR_label)
print(miR_data)

33
['READ', 'READ', 'READ', 'READ', 'READ', ...

# DataLoading
Define variables for dataloading.

In [128]:
batch_size = 128
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
padded_data = False

Define a function to add padding to our data \<Optional\>

In [129]:
import math

def add_pad_data(data):
  miR_data = data
  c_int = math.ceil(np.sqrt(len(miR_data[0])))
  pad = c_int ** 2 - len(miR_data[0])
  pad_width = (0, pad)

  padded_miR_data = np.zeros((miR_data.shape[0], miR_data.shape[1] + pad_width[1]))

  for i in range(len(miR_data)):
    padded_miR_data[i] = np.pad(miR_data[i], pad_width, mode='constant')

  # reshape shape[1] into (c_int, c_int)

  dim = int(np.sqrt(len(padded_miR_data[0])))
  padded_miR_data = padded_miR_data.reshape((padded_miR_data.shape[0],1, dim, dim))

  return padded_miR_data
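
The padding function appends zeros to each feature vector until its length is a perfect square and then reshapes it into a single-channel square "image". The sketch below shows the same idea for one vector with plain Python lists; `pad_to_square` is a hypothetical helper, not the function used in the notebook.

```python
import math


def pad_to_square(features):
    # Smallest side length whose square can hold all features.
    side = math.isqrt(len(features))
    if side * side < len(features):
        side += 1
    padded = list(features) + [0.0] * (side * side - len(features))
    # Split the padded vector into `side` rows of `side` values each.
    return [padded[r * side:(r + 1) * side] for r in range(side)]


grid = pad_to_square([1.0, 2.0, 3.0, 4.0, 5.0])
print(grid)  # [[1.0, 2.0, 3.0], [4.0, 5.0, 0.0], [0.0, 0.0, 0.0]]
```

The spatial arrangement produced this way is arbitrary: neighbouring cells in the grid are simply neighbouring columns of the original table.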

## TODO: Generate subset based on top N most frequent labels \<Optional\>

From the dataset extract the 10 most frequent classes

In [130]:
# N = 10
N = 33
print(len(miR_data))
print(len(miR_label))

# Return the n most frequent elements of the list (most frequent first)
def elementi_piu_frequenti(lista, n):
    unici, conteggi = np.unique(lista, return_counts=True)
    indici_ordinati = np.argsort(conteggi)[::-1]
    elementi_frequenti = unici[indici_ordinati][:n]
    return list(elementi_frequenti)

top_label_set = elementi_piu_frequenti(miR_label, N)

11082
11082
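
Selecting the N most frequent classes and keeping only the matching samples can also be sketched with `collections.Counter`. The helper name and the toy labels below are hypothetical; note that ties may be ordered differently than with `np.argsort`.

```python
from collections import Counter


def most_frequent_labels(labels, n):
    # most_common sorts by count, most frequent first.
    return [label for label, _ in Counter(labels).most_common(n)]


toy_labels = ["BRCA", "READ", "BRCA", "LUAD", "BRCA", "READ"]
top2 = most_frequent_labels(toy_labels, 2)
top_set = set(top2)

# Indices of the samples whose label belongs to the selected classes.
keep = [i for i, label in enumerate(toy_labels) if label in top_set]
print(top2)  # ['BRCA', 'READ']
print(keep)  # [0, 1, 2, 4, 5]
```

The same index list can be used to filter both the feature matrix and the label list so that they stay aligned.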


In [131]:
# TODO: Write here the code for identifying the most frequent classes
# N = 10
N = 33
labels = np.unique(miR_label, return_counts=True)
lab = []
for i in range(len(labels[0])):
  lab.append((labels[0][i], labels[1][i]))
lab.sort(key=lambda i: i[1])
top_lab = lab[-N:]
new_dataset  =[]
new_label_set = []

for idx in range(len(miR_label)):
  if miR_label[idx] in top_label_set:
    new_dataset.append(miR_data[idx])
    new_label_set.append(miR_label[idx])
    
n_classes = N
miR_data = np.array(new_dataset)
num_miR_label = label_processing(np.array(new_label_set))

## TODO: Dimensionality analysis and reduction using Principal Component Analysis  \<Optional\>

---

Apply Principal Component Analysis (PCA) to `train_data`.

Keep only features that preserve 99% of the variance.

For further information, please look at the documentation available at https://scikit-learn.org/stable/modules/generated/sklearn.decomposition.PCA.html

In [132]:
from sklearn.decomposition import PCA

def apply_PCA(data):
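  # A float n_components keeps the components needed to explain that fraction of the variance
  pca = PCA(n_components=0.99, svd_solver="full")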
  return pca.fit_transform(data)
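
The 99% criterion can be read as: keep the smallest number of leading components whose cumulative explained-variance ratio reaches 0.99. The sketch below applies that rule to a list of ratios such as the `explained_variance_ratio_` attribute of a fitted scikit-learn `PCA`; the helper name and the toy ratios are assumptions for illustration.

```python
def components_for_variance(explained_variance_ratio, target=0.99):
    cumulative = 0.0
    for k, ratio in enumerate(explained_variance_ratio, start=1):
        cumulative += ratio
        if cumulative >= target:
            return k
    # Target not reached: keep every component.
    return len(explained_variance_ratio)


toy_ratios = [0.70, 0.20, 0.06, 0.035, 0.005]
k = components_for_variance(toy_ratios)
print(k)  # 4
```

To avoid leaking information from the validation set, the projection would normally be fitted on the training split only and then applied to the validation split.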



## Create DataLoader

In [133]:
# useful if you want to represent your data as images <Optional>

miR_data = add_pad_data(miR_data)
padded_data = True

In [134]:
train_data, val_data, train_label, val_label = train_test_split(miR_data, num_miR_label, test_size=0.20, random_state=42)

In [135]:
from torch.utils.data import TensorDataset, DataLoader
miR_train = torch.Tensor(train_data)
miR_train_label = torch.Tensor(train_label)
miR_dataset_train = TensorDataset(miR_train, miR_train_label)

miR_val = torch.Tensor(val_data)
miR_val_label = torch.Tensor(val_label)
miR_dataset_val = TensorDataset(miR_val, miR_val_label)

train_loader = DataLoader(miR_dataset_train, batch_size=batch_size, shuffle=True)
test_loader = DataLoader(miR_dataset_val, batch_size=batch_size)

# Define Network
Let's compare the performance of a pair of networks both with and without population coding.


Each group should try the assigned values.


In [145]:
#TODO: Select a random configuration of parameters
from snntorch import surrogate

# network parameters
if padded_data:
  num_inputs = train_data.shape[2] ** 2
else:
  num_inputs = train_data.shape[1]

num_hidden = 128 #GROUPS : A: [64], B: [128], C: [256], D: [512]
num_outputs = n_classes

# temporal dynamics
num_steps = 20 #GROUPS : A: [5], B: [10], C: [20], D: [50]

# spiking neuron parameters
beta = 0.7   # neuron decay rate  #GROUPS : A: [0.7], B: [0.8], C: [0.85], D: [0.9 - 1]
grad = surrogate.fast_sigmoid()
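
The decay rate `beta` controls how much of the membrane potential carries over from one time step to the next. The sketch below is a simplified leaky integrate-and-fire update with a constant input, a threshold of 1 and reset by subtraction. It is an illustration under those assumptions and does not reproduce the exact update order used inside `snn.Leaky`.

```python
def leaky_neuron(inputs, beta=0.7, threshold=1.0):
    mem = 0.0
    spikes = []
    for current in inputs:
        mem = beta * mem + current       # decay, then integrate the input
        spike = 1 if mem > threshold else 0
        mem -= spike * threshold         # reset by subtraction after a spike
        spikes.append(spike)
    return spikes


constant_input = [0.4] * 6
slow_decay = leaky_neuron(constant_input, beta=0.9)
fast_decay = leaky_neuron(constant_input, beta=0.7)
print(slow_decay)  # [0, 0, 1, 0, 0, 1]
print(fast_decay)  # [0, 0, 0, 1, 0, 0]
```

A larger `beta` retains more of the past input, so the neuron reaches threshold sooner and fires more often for the same drive.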

## Without population coding

In [146]:
first_layer_neuron = snn.Leaky(beta=beta, spike_grad=grad, init_hidden=True)
second_layer_neuron =  snn.Leaky(beta=beta, spike_grad=grad, init_hidden=True, output=True)

In [138]:
#TODO: change the neuron type, e.g. Lapicque https://snntorch.readthedocs.io/en/latest/snn.neurons_lapicque.html <Optional>
first_layer_neuron = snn.Lapicque(beta=beta, spike_grad=grad, init_hidden=True, threshold=.4)
second_layer_neuron =  snn.Lapicque(beta=beta, spike_grad=grad, init_hidden=True, output=True, threshold=.4)

In [77]:
#TODO: change the neuron type, e.g. RLeaky https://snntorch.readthedocs.io/en/latest/snn.neurons_rleaky.html <Optional>
first_layer_neuron = snn.RLeaky(beta=beta, spike_grad=grad, init_hidden=True, linear_features=num_hidden ,threshold=.4)
second_layer_neuron = snn.RLeaky(beta=beta, spike_grad=grad, init_hidden=True, output=True, linear_features=num_outputs, threshold=.1)

In [147]:
# standard network
net = nn.Sequential(nn.Flatten(),
                    nn.Linear(num_inputs, num_hidden),
                    first_layer_neuron,
                    nn.Linear(num_hidden, num_outputs),
                    second_layer_neuron
                    ).to(device)

## Next Step: define your own network

In [115]:
# TODO: define your own network (define any additional layers you need; you can use different neuron types in different layers)
# A full list of neuron types is available at https://snntorch.readthedocs.io/en/latest/snntorch.html#neuron-list
# Execute this step after reporting the results of the previous standard network

net = nn.Sequential(nn.Flatten(),
                    nn.Linear(num_inputs, num_hidden),
                    first_layer_neuron,
                    nn.Linear(num_hidden, num_hidden),
                    nn.Linear(num_hidden, num_outputs),
                    second_layer_neuron
                    ).to(device)

## With population coding


In [106]:
#TODO: Select a random configuration of parameters
neurons_per_classes = 25  #GROUPS : A: [25], B: [50], C: [75], D: [100]
pop_outputs = n_classes * neurons_per_classes
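
With population coding, the output layer has `n_classes * neurons_per_classes` neurons, and the spikes of each group are summed to give one score per class. The sketch below assumes that each class owns a contiguous, equally sized block of output neurons; the helper `pool_by_class` and the spike counts are hypothetical.

```python
def pool_by_class(spike_counts, num_classes):
    """Sum output spike counts over contiguous, equally sized groups."""
    group = len(spike_counts) // num_classes
    return [sum(spike_counts[c * group:(c + 1) * group]) for c in range(num_classes)]


# Hypothetical spike counts for 3 classes with 4 output neurons each.
counts = [0, 1, 0, 0,   3, 2, 4, 1,   1, 0, 2, 0]
pooled = pool_by_class(counts, num_classes=3)
predicted = max(range(len(pooled)), key=pooled.__getitem__)
print(pooled)     # [1, 10, 3]
print(predicted)  # 1
```

Because each class score is a sum over several neurons, it can take many more distinct values within the same number of time steps than a single output neuron can.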

In [107]:
first_layer_neuron =  snn.Leaky(beta=beta, spike_grad=grad, init_hidden=True)
second_layer_neuron =  snn.Leaky(beta=beta, spike_grad=grad, init_hidden=True, output=True)

In [66]:
first_layer_neuron =  snn.Lapicque(beta=beta, spike_grad=grad, init_hidden=True, threshold=.4)
second_layer_neuron =  snn.Lapicque(beta=beta, spike_grad=grad, init_hidden=True, output=True, threshold=.4)

In [67]:
pop_outputs = n_classes * 50
first_layer_neuron =  snn.RLeaky(beta=beta, spike_grad=grad, init_hidden=True, linear_features=num_hidden)
second_layer_neuron =  snn.RLeaky(beta=beta, spike_grad=grad, init_hidden=True, output=True, linear_features=pop_outputs)

In [108]:
# standard network with population coding

net_pop = nn.Sequential(nn.Flatten(),
                        nn.Linear(num_inputs, num_hidden),
                        first_layer_neuron,   
                        nn.Linear(num_hidden, pop_outputs),
                        second_layer_neuron                        
                        ).to(device)

## Next Step: Define your own network with population coding

In [119]:
# TODO: define your own network (define any additional layers you need; you can use different neuron types in different layers)
# A full list of neuron types is available at https://snntorch.readthedocs.io/en/latest/snntorch.html#neuron-list
# Execute this step after reporting the results of the previous standard network with population coding

net_pop = nn.Sequential(nn.Flatten(),
                        nn.Linear(num_inputs, num_hidden),
                        first_layer_neuron,   
                        nn.Linear(num_hidden, pop_outputs),
                        nn.Linear(pop_outputs, pop_outputs),
                        second_layer_neuron                        
                        ).to(device)

# Training
## Without population coding
Define the optimizer and loss function. Here, we use the MSE Count Loss, which counts up the total number of output spikes at the end of the simulation run. 

The correct class has a target firing probability of 100%, and incorrect classes are set to 0%. 
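
As a rough illustration of this loss, the sketch below builds the target spike counts for one sample (the correct class should fire at every step, all others never) and computes the mean squared error against the observed counts. It handles a single sample without population coding; `count_mse` and the numbers are assumptions and not the snnTorch implementation.

```python
def count_mse(spike_counts, target_class, num_steps, correct_rate=1.0, incorrect_rate=0.0):
    on_target = int(num_steps * correct_rate)     # desired count for the correct class
    off_target = int(num_steps * incorrect_rate)  # desired count for every other class
    targets = [
        on_target if c == target_class else off_target
        for c in range(len(spike_counts))
    ]
    squared_errors = [(s - t) ** 2 for s, t in zip(spike_counts, targets)]
    return sum(squared_errors) / len(squared_errors)


# Observed spike counts of 4 output neurons over 20 steps; the true class is 1.
loss = count_mse([2, 18, 0, 5], target_class=1, num_steps=20)
print(loss)  # 8.25
```

With a 100% target rate, the loss is zero only when the correct neuron fires at every step and all other neurons stay silent.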

In [148]:
#TODO: Select a random configuration of parameters
import snntorch.functional as SF

learning_rate = 1e-3 #GROUPS : A: [1e-3], B: [1.5e-3], C: [2e-3], D: [2.5e-3]

optimizer = torch.optim.Adam(net.parameters(), lr=learning_rate, betas=(0.9, 0.999))
loss_fn = SF.mse_count_loss(correct_rate=1.0, incorrect_rate=0.0)

We will also define a simple test accuracy function that predicts the correct class based on the neuron with the highest spike count. 

In [142]:
from snntorch import utils

def test_accuracy(data_loader, net, num_steps, population_code=False, num_classes=False):
  with torch.no_grad():
    total = 0
    acc = 0
    net.eval()

    for data, targets in data_loader:
      data = data.to(device)
      targets = targets.to(device)
      utils.reset(net)  # reset hidden states before each batch
      spk_rec, _ = net(data)  # single forward pass: [batch, outputs]

      # accuracy_rate expects [time, batch, outputs], so add a time dimension;
      # weight each batch accuracy by the number of samples in the batch
      if population_code:
        acc += SF.accuracy_rate(spk_rec.unsqueeze(0), targets, population_code=True, num_classes=num_classes) * data.size(0)
      else:
        acc += SF.accuracy_rate(spk_rec.unsqueeze(0), targets) * data.size(0)

      total += data.size(0)

  return acc/total
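
The prediction rule (pick the class with the highest spike count) and the aggregation of accuracy over batches of different sizes can be sketched without any framework. The helper names and the toy batches below are hypothetical.

```python
def predict(spike_counts):
    # Index of the output with the highest spike count.
    return max(range(len(spike_counts)), key=spike_counts.__getitem__)


def accuracy(batches):
    correct = total = 0
    for counts_batch, targets in batches:
        correct += sum(predict(c) == t for c, t in zip(counts_batch, targets))
        total += len(targets)
    return correct / total


toy_batches = [
    ([[5, 1, 0], [0, 2, 7], [1, 1, 4]], [0, 2, 1]),  # 2 of 3 correct
    ([[0, 9, 1]], [1]),                              # 1 of 1 correct
]
overall = accuracy(toy_batches)
print(overall)  # 0.75
```

Counting correct predictions over all samples gives 0.75 here, whereas averaging the two per-batch accuracies would give about 0.83, because the last, smaller batch would be over-weighted.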

Let's run the training loop.

In [149]:
from snntorch import backprop

num_epochs = 20

# training loop
for epoch in range(num_epochs):

    avg_loss = backprop.BPTT(net, train_loader, num_steps=num_steps,
                          optimizer=optimizer, criterion=loss_fn, time_var=False, device=device)
    
    print(f"Epoch: {epoch}")
    print(f"Test set accuracy: {test_accuracy(test_loader, net, num_steps)*100:.3f}%\n")

Epoch: 0
Test set accuracy: 12.006%

Epoch: 1
Test set accuracy: 14.707%

Epoch: 2
Test set accuracy: 23.192%

Epoch: 3
Test set accuracy: 19.676%

Epoch: 4
Test set accuracy: 23.577%

Epoch: 5
Test set accuracy: 23.886%

Epoch: 6
Test set accuracy: 25.850%

Epoch: 7
Test set accuracy: 26.105%

Epoch: 8
Test set accuracy: 26.029%

Epoch: 9
Test set accuracy: 27.109%

Epoch: 10
Test set accuracy: 26.838%

Epoch: 11
Test set accuracy: 27.760%

Epoch: 12
Test set accuracy: 32.859%

Epoch: 13
Test set accuracy: 30.142%

Epoch: 14
Test set accuracy: 29.810%

Epoch: 15
Test set accuracy: 32.414%

Epoch: 16
Test set accuracy: 33.841%

Epoch: 17
Test set accuracy: 32.322%

Epoch: 18
Test set accuracy: 33.538%

Epoch: 19
Test set accuracy: 34.161%



## With population coding

In [120]:
# TODO: Select a random configuration of parameters
learning_rate = 1e-3  #GROUPS : A: [1e-3], B: [1.5e-3], C: [2e-3], D: [2.5e-3]

loss_fn = SF.mse_count_loss(correct_rate=1.0, incorrect_rate=0.0, population_code=True, num_classes=n_classes)
optimizer = torch.optim.Adam(net_pop.parameters(), lr=learning_rate, betas=(0.9, 0.999))

In [121]:
num_epochs = 20

# training loop
for epoch in range(num_epochs):

    avg_loss = backprop.BPTT(net_pop, train_loader, num_steps=num_steps,
                            optimizer=optimizer, criterion=loss_fn, time_var=False, device=device)

    print(f"Epoch: {epoch}")
    print(f"Test set accuracy: {test_accuracy(test_loader, net_pop, num_steps, population_code=True, num_classes=n_classes)*100:.3f}%\n")
    

Epoch: 0
Test set accuracy: 46.502%

Epoch: 1
Test set accuracy: 53.417%

Epoch: 2
Test set accuracy: 58.058%

Epoch: 3
Test set accuracy: 62.248%

Epoch: 4
Test set accuracy: 61.093%

Epoch: 5
Test set accuracy: 64.750%

Epoch: 6
Test set accuracy: 66.532%

Epoch: 7
Test set accuracy: 66.408%

Epoch: 8
Test set accuracy: 68.801%

Epoch: 9
Test set accuracy: 67.690%

Epoch: 10
Test set accuracy: 68.110%

Epoch: 11
Test set accuracy: 70.236%

Epoch: 12
Test set accuracy: 70.314%

Epoch: 13
Test set accuracy: 71.908%

Epoch: 14
Test set accuracy: 70.878%

Epoch: 15
Test set accuracy: 71.017%

Epoch: 16
Test set accuracy: 73.659%

Epoch: 17
Test set accuracy: 71.127%

Epoch: 18
Test set accuracy: 71.862%

Epoch: 19
Test set accuracy: 70.783%



## TODO: Report here the parameters of your configurations

Consider that:

*   with 10 classes:
    *   we achieved ~51% accuracy with the standard network without population coding and ~72% accuracy with the standard population coding network
    *   we achieved ~56% accuracy with the custom network and ~70% accuracy with the custom population coding network


*   with 33 classes:
    *   we achieved ~15% accuracy with the standard network without population coding and ~30% accuracy with the standard population coding network
    *   we achieved ~29% accuracy with the custom network and ~53% accuracy with the custom population coding network



In [144]:
# 10 classes - without population coding
# num_hidden: B
# num_steps: C
# beta: C
# LR: A

# BEST ACCURACY = 67.375%



# 10 classes - with population coding
# num_hidden: B
# num_steps: C
# beta: A
# LR: A
# neurons_per_classes: A

# BEST ACCURACY = 74.504%



# 33 classes - without population coding
# num_hidden: B
# num_steps: C
# beta: A
# LR: A

# BEST ACCURACY = 34.161%



# 33 classes - with population coding
# /

# Conclusion
The performance boost from population coding may start to fade as the number of time steps increases. But it may also be preferable to increasing time steps as PyTorch is optimized for handling matrix-vector products, rather than sequential, step-by-step operations over time. 

* For a detailed tutorial of spiking neurons, neural nets, encoding, and training using neuromorphic datasets, check out the
[snnTorch tutorial series](https://snntorch.readthedocs.io/en/latest/tutorials/index.html).
* For more information on the features of snnTorch, check out the [documentation at this link](https://snntorch.readthedocs.io/en/latest/).
* If you have ideas, suggestions or would like to find ways to get involved, then [check out the snnTorch GitHub project here.](https://github.com/jeshraghian/snntorch)