# Coding Schemes

#### Neural coding schemes convert input pixels into spikes that are transmitted to the excitatory neurons. In this notebook we describe and implement the Time-to-First-Spike (TTFS) coding scheme.

## TTFS

#### Time-to-first-spike coding was found to encode information for fast responses within a few milliseconds, such as responses to tactile stimuli, using only the first spikes. In this notebook we implement a fast and energy-efficient TTFS coding scheme, following Guo et al. (2021), that uses an exponentially decaying dynamic threshold to convert input pixels into first-spike patterns. The larger an input pixel is, the more information it carries and the earlier it emits a spike.

#### An exponential function is used to compute the threshold $P_{th}$, described by:

$$
P_{th}(t) = \theta_{0}\exp(-t/\tau_{th})
$$

#### where $\theta_0$ is a threshold constant set to 1 and $\tau_{th}$ is the time constant. A spike is generated when the input pixel exceeds the threshold, after which the input is inhibited from generating further spikes. In this scheme, each input pixel is translated into the exact timing of its first (and only) spike.
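#### Setting the pixel value $x$ equal to the decaying threshold and solving for $t$ gives the spike time in closed form (for the default $\theta_0 = 1$ this is exactly what `get_spike_times` computes below):

$$
x = \theta_{0}\exp(-t_{s}/\tau_{th}) \quad\Longrightarrow\quad t_{s} = -\tau_{th}\ln\left(\frac{x}{\theta_{0}}\right), \qquad 0 < x \le \theta_{0}
$$

#### For example, with $\theta_0 = 1$ and $\tau_{th} = 10$: a pixel at full intensity ($x = 1$) fires at $t_s = 0$, $x = 0.5$ fires at $t_s = 10\ln 2 \approx 6.93$, $x = 0.1$ fires at $t_s = 10\ln 10 \approx 23.03$, and $x = 0$ never crosses the threshold ($t_s \to \infty$).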

## Implementation

### Import Necessary Libraries

In [2]:
import torch
import torchvision
import torchvision.datasets as datasets
import plotly.express as px
import plotly.graph_objects as go

#### You can check the implementation in the cell below.

In [3]:
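class TTFS:
    """Time-to-first-spike encoder with an exponentially decaying threshold."""

    @staticmethod
    def normalize(image):
        # Scale 8-bit pixel intensities from [0, 255] to [0, 1].
        return image / 255.

    @staticmethod
    def process_spike_times(spike_times):
        # Zero-valued pixels give log(0) = -inf, i.e. an infinite spike time;
        # mark them with -1 so they never match any time step.
        inf_mask = torch.isinf(spike_times)
        spike_times[inf_mask] = -1
        # Discretize the continuous spike times to integer time steps.
        spike_times = torch.round(spike_times).long()
        return spike_times

    def __init__(self, tau, theta=1):
        self.theta = theta  # threshold constant theta_0
        self.tau = tau      # threshold time constant tau_th

    def get_spike_times(self, image):
        # A pixel x reaches P_th(t) = theta_0 * exp(-t / tau_th) at t = -tau_th * ln(x / theta_0).
        return -1 * self.tau * torch.log(image / self.theta)

    def get_spike_train(self, time_window, spike_times):
        height, width = spike_times.shape[0], spike_times.shape[1]
        spike_train = torch.zeros(size=(height, width, time_window))
        for iteration in range(time_window):
            # A pixel emits a 1 only at the step equal to its spike time.
            spike_train[:, :, iteration] = spike_times == iteration
        return spike_train

    def encode(self, image, time_window, normalize=False):
        if normalize:
            image = self.normalize(image)
        spike_times = self.get_spike_times(image)
        spike_times = self.process_spike_times(spike_times)
        return self.get_spike_train(time_window, spike_times)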
    

#### The `encode` method optionally normalizes the input image (when `normalize=True`), computes the spike times as described above, and converts them into a spike train over the given time window.
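#### The final step of `encode` builds the spike train with a Python loop over the time window. As an alternative sketch (not the notebook's code), the same tensor can be built in one shot with `torch.nn.functional.one_hot`, assuming integer spike times where -1, or any value outside $[0, T)$, means "no spike inside the window". The helper names `spike_train_one_hot` and `spike_train_loop` are hypothetical; the loop version mirrors `get_spike_train` for comparison.

```python
import torch
import torch.nn.functional as F


def spike_train_one_hot(spike_times, time_window):
    """Hypothetical loop-free variant of get_spike_train.

    spike_times: int64 tensor of shape (H, W); -1 (or any value outside
    [0, time_window)) means the pixel does not spike inside the window.
    Returns a float tensor of shape (H, W, time_window) with at most one 1 per pixel.
    """
    # Route out-of-window times to an extra "overflow" bin at index time_window.
    in_window = (spike_times >= 0) & (spike_times < time_window)
    index = torch.where(in_window, spike_times, torch.full_like(spike_times, time_window))
    # One-hot over time_window + 1 bins, then drop the overflow bin.
    return F.one_hot(index, num_classes=time_window + 1)[..., :time_window].float()


def spike_train_loop(spike_times, time_window):
    # Same logic as TTFS.get_spike_train, kept here for comparison.
    spike_train = torch.zeros(*spike_times.shape, time_window)
    for step in range(time_window):
        spike_train[..., step] = spike_times == step
    return spike_train


times = torch.tensor([[0, 3], [-1, 7]])  # 7 lies outside a 5-step window
fast = spike_train_one_hot(times, time_window=5)
print(torch.equal(fast, spike_train_loop(times, time_window=5)))  # True
print(fast.sum(dim=-1).tolist())  # [[1.0, 1.0], [0.0, 0.0]]
```

#### This removes the Python-level loop over time steps; both versions still materialize the full $(H_{input}, W_{input}, T)$ tensor.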
#### The spike times are initially continuous and need to be discretized, which is what the `process_spike_times` method does: it uses `torch.round` to convert them to integer time steps. Before rounding, it uses `torch.isinf` to handle zero-valued pixels, whose spike time is infinite because $\ln 0 = -\infty$; these are set to -1 so that they never spike. Finally, `get_spike_train` returns a binary (0/1) tensor of shape $(H_{input}, W_{input}, T)$, where $T$ is the time window. Pixels whose rounded spike time is $T$ or larger produce no spike inside the window.
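#### To see what rounding does, one can simulate the scheme directly in discrete time: at each step, compare every pixel against the decaying threshold and let it fire once, the first time it reaches the threshold. The sketch below uses a hypothetical helper `simulate_ttfs` (not part of the notebook) with $\tau_{th} = 10$ and $\theta_0 = 1$. A step-by-step simulation fires at the first integer step at or after the continuous crossing time, i.e. at $\lceil t_s \rceil$, whereas `process_spike_times` rounds to the nearest step, so the two can differ by one step.

```python
import math

import torch


def simulate_ttfs(pixels, tau, theta0=1.0, time_window=50):
    """Hypothetical helper: step a decaying threshold through discrete time.

    Returns, for every pixel, the first integer step t at which
    pixel >= theta0 * exp(-t / tau), or -1 if that never happens in the window.
    """
    spike_times = torch.full(pixels.shape, -1, dtype=torch.long)
    has_spiked = torch.zeros(pixels.shape, dtype=torch.bool)
    for t in range(time_window):
        threshold = theta0 * math.exp(-t / tau)
        # Fire only if the threshold is reached and the pixel has not fired yet
        # (after its first spike the input is inhibited).
        fires = (pixels >= threshold) & ~has_spiked
        spike_times[fires] = t
        has_spiked |= fires
    return spike_times


tau = 10.0
pixels = torch.tensor([0.9, 0.5, 0.1, 0.03, 0.0])
simulated = simulate_ttfs(pixels, tau)

# Closed-form crossing times of the nonzero pixels, discretized two ways.
continuous = -tau * torch.log(pixels[:4])
print(simulated.tolist())                       # [2, 7, 24, 36, -1]
print(torch.ceil(continuous).long().tolist())   # [2, 7, 24, 36]
print(torch.round(continuous).long().tolist())  # [1, 7, 23, 35]
```

#### Both conventions keep brighter pixels firing no later than dimmer ones; they differ only in how each continuous crossing time is snapped to a step.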

## Example

In [4]:
mnist_trainset = datasets.MNIST(root='./data', train=True, download=True, transform=torchvision.transforms.ToTensor())
img = mnist_trainset[0][0].reshape((28, 28))

In [5]:
px.imshow(img, binary_string=True)

In [6]:
ttfs = TTFS(tau=10.)
# ToTensor() already scales MNIST pixels to [0, 1], so normalize stays False.
spike_train = ttfs.encode(img, 50)
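#### Not every nonzero pixel ends up in this 50-step spike train. With $\theta_0 = 1$, a pixel $x$ is kept only if its rounded spike time $\mathrm{round}(-\tau_{th}\ln x)$ is at most $T - 1$, which means roughly $x > \exp(-(T - 0.5)/\tau_{th})$; for $\tau_{th} = 10$ and $T = 50$ this is about 0.00708. Since `ToTensor()` maps an 8-bit intensity $k$ to $k/255$, the sketch below (hypothetical helpers, not part of the notebook, mirroring its round-half-to-even rule with plain Python floats) lists which nonzero 8-bit levels would be dropped under these settings.

```python
import math


def min_encodable_pixel(tau, time_window):
    """Approximate smallest pixel whose rounded spike time stays inside the window.

    Assumes theta_0 = 1 and t = round(-tau * ln(x)); a spike is kept only if
    t <= time_window - 1, i.e. the continuous time is below time_window - 0.5.
    """
    return math.exp(-(time_window - 0.5) / tau)


def dropped_8bit_levels(tau, time_window):
    """Hypothetical helper: nonzero 8-bit intensities k (pixel = k / 255) that never spike."""
    dropped = []
    for k in range(1, 256):
        t = round(-tau * math.log(k / 255))  # Python's round() is also half-to-even
        if t >= time_window:
            dropped.append(k)
    return dropped


print(f"{min_encodable_pixel(10.0, 50):.5f}")  # 0.00708
print(dropped_8bit_levels(10.0, 50))           # [1]
```

#### Under these settings only the faintest level ($k = 1$, spike time $\approx 55.4$) falls outside the window; a longer window or a smaller $\tau_{th}$ would keep it.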

## 3D Scatter Plot

In [7]:
x, y, z = torch.where(spike_train == 1)
fig = go.Figure(data=[go.Scatter3d(x=x, y=y, z=z,
                                   mode='markers')])
fig.update_layout(
    width=1000,  # Figure width in pixels
    height=800  # Figure height in pixels
)
fig.show()

## Raster Plot

In [9]:
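# Flatten the (H, W, T) spike train to (H*W, T): one row per input neuron.
spike_train_2d = torch.reshape(spike_train, (-1, spike_train.shape[-1]))
neuron_idx, time_idx = torch.where(spike_train_2d == 1)
fig = px.scatter(x=time_idx.numpy(), y=neuron_idx.numpy(), title="Raster plot of TTFS output")
fig.update_layout(xaxis_title="iteration", yaxis_title="input neuron (flattened pixel index)")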

## Reference

### Guo, W., Fouda, M. E., Eltawil, A. M., & Salama, K. N. (2021). Neural Coding in Spiking Neural Networks: A Comparative Study for Robust Neuromorphic Systems. Frontiers in Neuroscience, 15. https://doi.org/10.3389/fnins.2021.638474