# About BindsNet image to spike encoders

This notebook gives a rough overview of the BindsNet API used to convert image datasets into spiking datasets. Its goal is to build confidence in developing new encoding methods for the input layers of Spiking Neural Networks.

## BindsNet Dataset Wrapper

To encode input pixels as temporal data, BindsNet uses a ```TorchvisionDatasetWrapper``` class that creates a **custom** PyTorch dataset (read more about custom datasets and dataloaders [here](https://pytorch.org/tutorials/beginner/data_loading_tutorial.html)).

The wrapper extends a torchvision dataset class (and therefore ```torch.utils.data.Dataset```) by adding two parameters, ```image_encoder``` and ```label_encoder```. These encoders are instances of ```bindsnet.encoding``` classes (PoissonEncoder, BernoulliEncoder, RankOrderEncoder, ...). They are called inside the dataset's ```__getitem__``` method, so each sample is encoded on the fly when it is accessed:


```python
...
def __getitem__(self, ind: int) -> Dict[str, torch.Tensor]:
    image, label = super().__getitem__(ind)

    output = {
        "image": image,
        "label": label,
        "encoded_image": self.image_encoder(image),
        "encoded_label": self.label_encoder(label),
    }

    return output
...

```
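The encode-on-access pattern above can be reproduced without BindsNET. The sketch below is a hypothetical minimal wrapper (`EncodedDataset` and `toy_encoder` are illustrative names, not BindsNET classes) that mimics what ```__getitem__``` does, assuming the base dataset returns `(image, label)` pairs and that a missing encoder behaves like the identity.

```python
from typing import Callable, Dict, Optional

import torch
from torch.utils.data import Dataset, TensorDataset


class EncodedDataset(Dataset):
    """Hypothetical wrapper mimicking the encode-on-access pattern of
    BindsNET's TorchvisionDatasetWrapper (not the actual BindsNET class)."""

    def __init__(
        self,
        base: Dataset,
        image_encoder: Optional[Callable[[torch.Tensor], torch.Tensor]] = None,
        label_encoder: Optional[Callable] = None,
    ):
        self.base = base
        # Missing encoders fall back to the identity function (assumption)
        self.image_encoder = image_encoder if image_encoder is not None else (lambda x: x)
        self.label_encoder = label_encoder if label_encoder is not None else (lambda x: x)

    def __len__(self) -> int:
        return len(self.base)

    def __getitem__(self, ind: int) -> Dict[str, torch.Tensor]:
        image, label = self.base[ind]
        # Encoding runs every time a sample is fetched, not once up front
        return {
            "image": image,
            "label": label,
            "encoded_image": self.image_encoder(image),
            "encoded_label": self.label_encoder(label),
        }


def toy_encoder(x: torch.Tensor) -> torch.Tensor:
    # Toy encoder: 100 time steps, each pixel spikes with probability equal to its value
    return (torch.rand(100, *x.shape) < x).byte()


images = torch.rand(4, 1, 28, 28)
labels = torch.tensor([0, 1, 0, 1])
dataset = EncodedDataset(TensorDataset(images, labels), image_encoder=toy_encoder)
sample = dataset[0]
```

Because encoding happens in ```__getitem__```, stochastic encoders produce a fresh spike train each time the same sample is drawn, which is usually what a training loop wants.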


The encoding functions share a similar signature, for example ```poisson```:

```python
def poisson(
    datum: torch.Tensor,
    time: Optional[int] = None,
    dt: float = 1.0,
    device="cpu",
    **kwargs,
) -> torch.Tensor:
```

Each function takes the input tensor (from ```__getitem__```), the ```time``` window of the encoding (e.g. 250 for 250 ms) and ```dt```, the simulation time step. The output has one slice per simulation step, i.e. shape ```(time/dt, *datum.shape)```. For example, a ```(1, image_size, image_size)``` image encoded with a 0.5 ms step over a 200 ms window gives a tensor of shape ```(400, 1, image_size, image_size)```.
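To make the shape convention concrete, here is a minimal sketch of an encoder that follows the same signature. `bernoulli_sketch` is a hypothetical function, not BindsNET's `bernoulli`; it assumes `datum` already holds per-step spike probabilities in [0, 1].

```python
from typing import Optional

import torch


def bernoulli_sketch(
    datum: torch.Tensor,
    time: Optional[int] = None,
    dt: float = 1.0,
    device="cpu",
    **kwargs,
) -> torch.Tensor:
    # Hypothetical encoder with the same signature; not BindsNET's bernoulli.
    # Assumes datum already holds spike probabilities per step in [0, 1].
    if time is None:
        raise ValueError("this sketch needs an explicit time window")
    steps = int(time / dt)  # number of simulation steps
    datum = datum.to(device).clamp(0, 1)
    # One independent draw per time step and pixel -> shape [steps, *datum.shape]
    return (torch.rand(steps, *datum.shape, device=device) < datum).byte()


spikes = bernoulli_sketch(torch.rand(1, 28, 28), time=200, dt=0.5)
print(spikes.shape)  # torch.Size([400, 1, 28, 28])
```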

> **_Existing encoders:_**
> The framework already provides encoders for Poisson rate encoding, time-to-first-spike encoding, rank-order encoding and Bernoulli encoding.


## Example of a new encoder class

### Burst Coding

In [burst coding](https://www.frontiersin.org/journals/neuroscience/articles/10.3389/fnins.2021.638474/full), pixels are normalized to the range [0, 1]. For an input pixel $P$, the **number** of spikes in a burst is $N_{s}(P) = \lceil N_{max}P \rceil$, where $N_{max}$ is the maximum number of spikes and $\lceil \cdot \rceil$ is the ceiling function. To create the spike train, the inter-spike interval (ISI) must also be computed:

$$
ISI(P) = 
\begin{cases} 
    \lceil - (T_{max} - T_{min}) P + T_{max} \rceil, & N_{s}(P) > 1 \\ 
    T_{max}, & \text{otherwise}
\end{cases}
$$

where $T_{max}$ and $T_{min}$ are the maximum and minimum intervals, respectively, so the ISI is confined to $[T_{min}, T_{max}]$. A brighter pixel produces a burst with more spikes and a shorter ISI. In the reference paper, the parameters are set within a biologically plausible range: $N_{max}$ is 5 spikes, chosen for optimal classification and computational performance; $T_{max}$ is the time window used to process one image; and $T_{min}$ is 2 ms.
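A quick way to check the two formulas is to evaluate them for a single pixel. The sketch below assumes $N_{max} = 5$, $T_{min} = 2$ ms and, purely for illustration, $T_{max} = 100$ ms; `burst_params` is a hypothetical helper name.

```python
import math


def burst_params(p: float, nmax: int = 5, tmin: int = 2, tmax: int = 100):
    """Return (N_s, ISI) for one normalized pixel p in [0, 1] (T_max = 100 ms is an assumption)."""
    n_s = math.ceil(nmax * p)  # N_s(P) = ceil(N_max * P)
    if n_s > 1:
        isi = math.ceil(-(tmax - tmin) * p + tmax)  # shorter interval for brighter pixels
    else:
        isi = tmax  # single spike (or none): interval defaults to T_max
    return n_s, isi


for p in (0.0, 0.1, 0.5, 1.0):
    print(p, burst_params(p))
```

For instance, a mid-gray pixel ($P = 0.5$) yields 3 spikes with a 51 ms ISI, while a saturated pixel ($P = 1$) yields 5 spikes with the minimum 2 ms ISI.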

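```python
import torch


def burst(
    datum: torch.Tensor,
    time: int,
    dt: float = 1.0,
    tmin: int = 2,
    nmax: int = 5,
    device="cpu",
    **kwargs,
) -> torch.Tensor:
    """
    Generates burst spike trains based on input intensity.

    :param datum: Tensor of shape ``[n_1, ..., n_k]`` with values in [0, 1].
    :param time: Length of the burst spike train per input variable (also used as T_max).
    :param dt: Simulation time step.
    :param tmin: Minimum inter-spike interval T_min.
    :param nmax: Maximum number of spikes per neuron N_max.
    :param device: Target device.
    :return: Tensor of shape ``[time / dt, n_1, ..., n_k]`` of burst spikes.
    """
    shape = datum.shape
    datum = datum.flatten().clamp(0, 1).to(device)
    tmax = time  # T_max is the full encoding window, as in the reference paper
    steps = int(time / dt)

    # Number of spikes per pixel: N_s(P) = ceil(N_max * P)
    n_spikes = torch.ceil(datum * nmax).long()

    # ISI = ceil(-(T_max - T_min) * P + T_max) if N_s > 1, else T_max
    isi = torch.ceil(-(tmax - tmin) * datum + tmax)
    isi = torch.where(n_spikes > 1, isi, torch.full_like(isi, tmax))
    isi_steps = (isi / dt).long().clamp(min=1)

    # Assumption: the burst starts at step 0; spike k fires at step k * ISI
    k = torch.arange(nmax, device=device).unsqueeze(1)  # [nmax, 1]
    spike_steps = k * isi_steps.unsqueeze(0)  # [nmax, n]
    valid = (k < n_spikes.unsqueeze(0)) & (spike_steps < steps)

    spikes = torch.zeros(steps, datum.numel(), dtype=torch.uint8, device=device)
    cols = torch.arange(datum.numel(), device=device).expand_as(spike_steps)
    spikes[spike_steps[valid], cols[valid]] = 1
    return spikes.view(steps, *shape)
```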
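To plug a new encoding function into the dataset wrapper, it has to be wrapped in an encoder object that stores `time`, `dt` and any extra keyword arguments, then encodes an image when called. BindsNET's own classes (e.g. `PoissonEncoder`) follow a similar idea, but the exact base-class API should be checked in `bindsnet.encoding`. The sketch below is a hypothetical, dependency-free version: `FunctionEncoder` and the stand-in `repeat_encoding` are illustrative names.

```python
from typing import Callable

import torch


class FunctionEncoder:
    """Hypothetical encoder class: store the encoding parameters once,
    then encode each image when the object is called."""

    def __init__(self, enc: Callable[..., torch.Tensor], time: int, dt: float = 1.0, **kwargs):
        self.enc = enc
        self.time = time
        self.dt = dt
        self.enc_kwargs = kwargs  # extra arguments forwarded to the encoding function

    def __call__(self, img: torch.Tensor) -> torch.Tensor:
        return self.enc(img, time=self.time, dt=self.dt, **self.enc_kwargs)


def repeat_encoding(datum: torch.Tensor, time: int, dt: float = 1.0, **kwargs) -> torch.Tensor:
    # Stand-in encoding function: copy the input at every time step
    steps = int(time / dt)
    return datum.unsqueeze(0).repeat(steps, *[1] * datum.dim())


encoder = FunctionEncoder(repeat_encoding, time=50, dt=1.0)
out = encoder(torch.rand(1, 28, 28))
```

With the `burst` function defined above, `FunctionEncoder(burst, time=100, dt=1.0, tmin=2, nmax=5)` would give an object that can be passed as `image_encoder`.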

### Poisson Encoder

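```python
import torch

dt = 1.0  # simulation time step (ms)
# Fall back to the CPU when no GPU is available
device = "cuda" if torch.cuda.is_available() else "cpu"

# Dummy image whose pixel values are interpreted as firing rates in Hz (max 128 Hz)
datum = torch.rand(1, 28, 28) * 128
shape, size = datum.shape, datum.numel()
datum = datum.flatten().to(device)

time = int(200 / dt)  # number of simulation steps in a 200 ms window

# Despite its name, `rate` holds the mean inter-spike interval (in steps): 1000 / (rate_hz * dt)
rate = torch.zeros(size, device=device)
rate[datum != 0] = 1 / datum[datum != 0] * (1000 / dt)

# Sample inter-spike intervals; active pixels may not have an interval of 0
dist = torch.distributions.Poisson(rate=rate, validate_args=False)
intervals = dist.sample(sample_shape=torch.Size([time + 1]))
intervals[:, datum != 0] += (intervals[:, datum != 0] == 0).float()

# Spike times are cumulative sums of intervals; times past the window go to dummy step 0
times = torch.cumsum(intervals, dim=0).long()
times[times >= time + 1] = 0

spikes = torch.zeros(time + 1, size, device=device).byte()
spikes[times, torch.arange(size, device=device)] = 1

# Drop the dummy step 0 and restore the image shape: [time, 1, 28, 28]
spikes = spikes[1:]
spikes = spikes.view(time, *shape)
```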
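The cell above builds the spike train from Poisson-distributed inter-spike intervals. A common alternative, swapped in here only for comparison and not part of the code above, approximates a Poisson process by an independent Bernoulli draw at each step with probability $r \cdot dt / 1000$ (rate $r$ in Hz, $dt$ in ms). The approximation is only reasonable while that probability stays well below 1; `poisson_bernoulli_approx` is a hypothetical name.

```python
import torch


def poisson_bernoulli_approx(rates_hz: torch.Tensor, time: int, dt: float = 1.0) -> torch.Tensor:
    """Approximate a Poisson spike train by an independent Bernoulli draw per step.

    rates_hz: firing rates in Hz; time: window in ms; dt: step in ms.
    """
    steps = int(time / dt)
    # Spike probability per step = rate (spikes/s) * dt (ms) / 1000
    p = (rates_hz * dt / 1000.0).clamp(0, 1)
    return (torch.rand(steps, *rates_hz.shape) < p).byte()


torch.manual_seed(0)
rates = torch.full((100,), 100.0)                     # 100 neurons firing at 100 Hz
spikes = poisson_bernoulli_approx(rates, time=1000)  # 1 s window
mean_count = spikes.float().sum(0).mean().item()     # should be close to 100 spikes
```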

### Time-to-first-spike

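```python
import torch

device = "cuda" if torch.cuda.is_available() else "cpu"
dt = 1.0
time = int(200 / dt)  # number of simulation steps (200 ms window)
sparsity = 0.5        # fraction of pixels allowed to spike

datum = torch.rand(1, 28, 28, device=device)
shape = list(datum.shape)

# Threshold at the (1 - sparsity) quantile: only the brightest pixels spike
quantile = torch.quantile(datum, 1 - sparsity)
s = torch.zeros([time, *shape], device=device)
# Every selected pixel fires once, at the first time step; all others stay silent
s[0] = (datum > quantile).float()
spikes = s.byte()
```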
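In the cell above, every selected pixel fires at the very first step, so spike timing carries no intensity information. A latency code in the stricter time-to-first-spike sense makes brighter pixels fire earlier. The sketch below assumes inputs normalized to [0, 1] and a linear intensity-to-latency mapping (both assumptions); `latency` is a hypothetical name.

```python
import torch


def latency(datum: torch.Tensor, time: int, dt: float = 1.0, device="cpu") -> torch.Tensor:
    # Each pixel fires exactly once; brighter pixels fire earlier.
    # Zero-intensity pixels fire at the last step (a design choice of this sketch).
    steps = int(time / dt)
    datum = datum.to(device).clamp(0, 1)
    # Linear mapping: t_fire = (1 - P) * (steps - 1), rounded to the nearest step
    t_fire = torch.round((1 - datum) * (steps - 1)).long()
    spikes = torch.zeros(steps, *datum.shape, dtype=torch.uint8, device=device)
    # Write a 1 at step t_fire for each pixel along the time dimension
    spikes.scatter_(0, t_fire.unsqueeze(0), 1)
    return spikes


spikes = latency(torch.tensor([0.0, 0.5, 1.0]), time=11)
```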

```python
import sys, os
sys.path.append(os.path.abspath('..'))
from src.utils.dataloaders import load_image_folder_dataloader

root = "/home/nvidia/datasets/vindr-mammo-pngs"
image_size = 128
intensity = 128
time = 100
dt = 1

t, v = load_image_folder_dataloader(root, 128, 8, time, dt, intensity, True)

# Available labels in dataset: {'abnormal': 0, 'normal': 1}
```

