## Problème 4 : Fonction Leaky ReLU ($\alpha = 0.01$)

#### Packages et imports

In [1]:
# pip install torch quantities sparse==0.11.0 > /dev/null

import numpy as np
import matplotlib.pyplot as plt
from sklearn import datasets, model_selection, utils
import torch
import quantities as units
from sparse import COO

#### Configuration

In [2]:
# Seed both NumPy and PyTorch so that every run is reproducible
np.random.seed(0)
torch.manual_seed(0)

# Run on the GPU whenever one is available, otherwise fall back to the CPU.
# If you don't have a CUDA enabled GPU, I recommend using Google Colab,
# available at https://colab.research.google.com. Create a new notebook
# and then go to Runtime -> Change runtime type -> Hardware accelerator -> GPU
# Colab gives you access to up to 12 free continuous hours of a fairly recent GPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


#### Préparation des données

In [3]:
# Fetch the MNIST dataset from OpenML (https://www.openml.org/d/554).
# `data_home` selects the download/cache directory; None keeps the
# scikit-learn default, ~/scikit_learn_data/
X, y = datasets.fetch_openml(
    'mnist_784',
    version=1,
    return_X_y=True,
    data_home=None,
    as_frame=False,
)
# X holds 70k samples of 28*28 features; y holds the 70k labels as strings
nb_of_samples, nb_of_features = X.shape

# Randomly permute samples and labels together
X, y = utils.shuffle(X, y)

# The labels come as strings: turn them into integers
y = np.asarray(y, dtype=int)
nb_of_ouputs = y.max() + 1  # One output per class (labels start at 0)

# Scale the 8-bit pixel intensities into the range [0, 1[.
X = X / 2**8

  warn(


#### Conversion en décharges

In [4]:
# Convert the images to a spike train using TTFS (time-to-first-spike) encoding:
# every retained pixel emits exactly one spike, and its timing carries the intensity.
dt = 1*units.ms
duration_per_image = 100*units.ms
absolute_duration = int(duration_per_image / dt)  # Number of simulation timesteps per image

time_of_spike = (1 - X) * absolute_duration  # The brighter the pixel, the earlier the spike
# A spike time of 0 doubles as the "no spike" marker that np.nonzero filters out below.
# X was scaled to stay strictly below 1, so a genuine spike time is never exactly 0.
time_of_spike[X < .25] = 0  # "Remove" the spikes associated with darker pixels, which presumably carry less information

sample_id, neuron_idx = np.nonzero(time_of_spike)

# COO coordinates are array indices, so they have to be integers. The spike times are
# floats: truncate them explicitly to a timestep index instead of relying on an
# implicit float -> int conversion inside the sparse library.
spike_time_bin = time_of_spike[sample_id, neuron_idx].astype(np.intp)

# We use a sparse COO array to store the spikes for memory requirements
# You can use the spike_train variable as if it were a tensor of shape (nb_of_samples, nb_of_features, absolute_duration)
spike_train = COO((sample_id, neuron_idx, spike_time_bin),
                  np.ones_like(sample_id), shape=(nb_of_samples, nb_of_features, absolute_duration))


#### Séparation entraînement/test

In [5]:
# Train/test split: the first 85% of the (already shuffled) samples are used
# for training, the remaining 15% are held out for testing
nb_of_train_samples = int(nb_of_samples * 0.85)
train_indices, test_indices = np.split(np.arange(nb_of_samples), [nb_of_train_samples])

#### Création du réseau

In [6]:
# The network has three weight matrices: two hidden layers followed by one output layer.
# Each matrix is a trainable leaf tensor drawn from a normal distribution (mean 0, std 0.1).
nb_hidden = 128  # Width of each hidden layer

# Input -> hidden layer 1
w1 = torch.empty((nb_of_features, nb_hidden), dtype=torch.float32, device=device).requires_grad_()
torch.nn.init.normal_(w1, 0., .1)

# Hidden layer 1 -> hidden layer 2
w2 = torch.empty((nb_hidden, nb_hidden), dtype=torch.float32, device=device).requires_grad_()
torch.nn.init.normal_(w2, 0., .1)

# Hidden layer 2 -> output layer
w3 = torch.empty((nb_hidden, nb_of_ouputs), dtype=torch.float32, device=device).requires_grad_()
torch.nn.init.normal_(w3, 0., .1)


tensor([[-0.1144, -0.0802, -0.0531,  ...,  0.0470,  0.0470,  0.1064],
        [-0.1416,  0.1032, -0.0467,  ...,  0.0532,  0.1733, -0.2062],
        [ 0.0021, -0.0387, -0.0462,  ...,  0.0792,  0.0299, -0.0544],
        ...,
        [-0.1386,  0.0744, -0.0074,  ..., -0.0649, -0.1381,  0.1604],
        [ 0.0838,  0.0204, -0.0326,  ...,  0.1160, -0.1548,  0.0551],
        [-0.0816,  0.0538,  0.0429,  ..., -0.0057, -0.1306,  0.0350]],
       requires_grad=True)

In [7]:
class SpikeFunction(torch.autograd.Function):
    """Spike nonlinearity with a hand-written derivative.

    Subclassing ``torch.autograd.Function`` lets the forward pass compute one
    function while back-propagation uses a custom derivative.
    See this example for more details:
    https://pytorch.org/tutorials/beginner/examples_autograd/two_layer_net_custom_function.html
    """

    @staticmethod
    def forward(ctx, input):
        """Apply the Heaviside step to ``input`` (potential - threshold).

        Returns a tensor shaped like ``input`` holding 1.0 wherever the input
        is strictly positive (a spike is emitted) and 0.0 everywhere else.
        The input is saved on ``ctx`` for use in the backward pass.
        """
        ctx.save_for_backward(input)
        return torch.where(input > 0, torch.ones_like(input), torch.zeros_like(input))

    @staticmethod
    def backward(ctx, grad_output):
        """Back-propagate using the derivative of a leaky ReLU.

        ``grad_output`` is the gradient of the loss with respect to the output.
        It is scaled by 0.01 (alpha) where the saved input was strictly
        negative and by 1 everywhere else, including where the input is 0.
        """
        (saved_input,) = ctx.saved_tensors
        slope = torch.where(
            saved_input < 0,
            torch.full_like(saved_input, 0.01),  # alpha = 0.01 on the negative side
            torch.ones_like(saved_input),
        )
        return grad_output * slope

In [8]:
def run_spiking_layer(input_spike_train, layer_weights, tau_v=20*units.ms, tau_i=5*units.ms, v_threshold=1.0):
    """Simulate one layer with current-based LIF dynamics and return its output spikes.

    The number of timesteps and the device/dtype of the internal state are taken
    from the inputs themselves, so the function does not depend on notebook-level
    globals.

    Args:
        input_spike_train: tensor indexed as (sample, input neuron, timestep).
        layer_weights: tensor indexed as (input neuron, output neuron).
        tau_v: time constant of the membrane potential leak (a `quantities` time).
        tau_i: time constant of the membrane current leak (a `quantities` time).
        v_threshold: potential above which a neuron spikes and is reset to 0.

    Returns:
        Tensor indexed as (sample, output neuron, timestep) holding the spikes
        produced by SpikeFunction at every timestep.
    """
    # First, we multiply the input spike train by the weights of the current layer to get the current that will be added
    # We can calculate this beforehand because the weights are constant in the forward pass (no plasticity)
    # Equivalent to a matrix multiplication for tensors of dim > 2 using Einstein's notation
    input_current = torch.einsum("abc,bd->adc", input_spike_train, layer_weights)

    nb_of_batch_samples, nb_of_neurons, nb_of_timesteps = input_current.shape

    # The state tensors live on the same device (and have the same dtype) as the input current
    membrane_potential_at_t = input_current.new_zeros((nb_of_batch_samples, nb_of_neurons))
    membrane_current_at_t = input_current.new_zeros((nb_of_batch_samples, nb_of_neurons))

    if nb_of_timesteps == 0:
        # Nothing to simulate: torch.stack would raise on an empty list
        return input_current.new_zeros((nb_of_batch_samples, nb_of_neurons, 0))

    # Exact exponential decay factors over one timestep (the timestep is fixed to 1 ms here)
    const_a = 1*units.ms / tau_i
    alpha = float(np.exp(-(const_a.item())))

    const_b = 1*units.ms / tau_v
    beta = float(np.exp(-(const_b.item())))

    recorded_spikes = []  # Output spikes at each time t
    for t in range(nb_of_timesteps):  # For every timestep
        # Apply the leak
        membrane_potential_at_t = torch.mul(membrane_potential_at_t, beta)
        membrane_current_at_t = torch.mul(membrane_current_at_t, alpha)

        # Integrate the input current received at time t
        membrane_current_at_t += input_current[:, :, t]

        # Integrate the current into the membrane potential
        membrane_potential_at_t += membrane_current_at_t

        # Apply the non-differentiable spike function (the subtraction creates a new
        # tensor, so the in-place reset below does not touch what SpikeFunction saved)
        recorded_spikes_at_t = SpikeFunction.apply(membrane_potential_at_t - v_threshold)
        recorded_spikes.append(recorded_spikes_at_t)

        # Reset the spiked neurons
        membrane_potential_at_t[membrane_potential_at_t > v_threshold] = 0

    # Stack over the time axis (list -> tensor)
    return torch.stack(recorded_spikes, dim=2)


#### Entraînement

In [9]:
# Set-up training
nb_of_epochs = 4
batch_size = 256  # The backpropagation is done after every batch, but a batch here is also used for memory requirements
number_of_batches = len(train_indices) // batch_size

params = [w1, w2, w3]  # Trainable parameters
optimizer = torch.optim.Adam(params, lr=0.01, amsgrad=True)
loss_fn = torch.nn.MSELoss(reduction='mean')

for e in range(nb_of_epochs):
    epoch_loss = 0
    for i, batch in enumerate(np.array_split(train_indices, number_of_batches), start=1):
        # Select batch and convert to tensors (densify only one batch at a time to limit memory use)
        batch_spike_train = torch.as_tensor(spike_train[batch].todense(), dtype=torch.float, device=device)
        # Labels as a column vector, the index layout expected by scatter_ along dim 1
        batch_labels = torch.as_tensor(y[batch, np.newaxis], dtype=torch.long, device=device)

        # Here we create a target spike count (10 spikes for wrong label, 100 spikes for true label) in a one-hot fashion
        # This approach is seen in Shrestha & Orchard (2018) https://arxiv.org/pdf/1810.08646.pdf
        # Code available at https://github.com/bamsumit/slayerPytorch
        # The target has one column per output neuron, matching the width of the output layer
        min_spike_count = torch.full((len(batch), int(nb_of_ouputs)), 10.0, device=device, dtype=torch.float)
        target_output = min_spike_count.scatter_(1, batch_labels, 100.0)

        # Forward propagation
        layer_1_spikes = run_spiking_layer(batch_spike_train, w1)
        layer_2_spikes = run_spiking_layer(layer_1_spikes, w2)
        layer_3_spikes = run_spiking_layer(layer_2_spikes, w3)
        network_output = torch.sum(layer_3_spikes, 2)  # Count the spikes over time axis
        loss = loss_fn(network_output, target_output)

        # Backward propagation
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        epoch_loss += loss.item()

        print("Batch %i out of %i in Epoch %i -- loss : %.4f" %(i, number_of_batches, e+1, loss.item()))

    # Mean of the per-batch losses over the epoch
    print("Epoch %i -- loss : %.4f" %(e+1, epoch_loss / number_of_batches))

Batch 1 out of 232 in Epoch 1 -- loss : 1223.7490
Batch 2 out of 232 in Epoch 1 -- loss : 1056.6108
Batch 3 out of 232 in Epoch 1 -- loss : 1196.1494
Batch 4 out of 232 in Epoch 1 -- loss : 1027.8273
Batch 5 out of 232 in Epoch 1 -- loss : 937.3517
Batch 6 out of 232 in Epoch 1 -- loss : 855.2891
Batch 7 out of 232 in Epoch 1 -- loss : 800.9844
Batch 8 out of 232 in Epoch 1 -- loss : 783.6202
Batch 9 out of 232 in Epoch 1 -- loss : 777.7300
Batch 10 out of 232 in Epoch 1 -- loss : 752.6265
Batch 11 out of 232 in Epoch 1 -- loss : 739.7805
Batch 12 out of 232 in Epoch 1 -- loss : 712.1934
Batch 13 out of 232 in Epoch 1 -- loss : 736.2023
Batch 14 out of 232 in Epoch 1 -- loss : 654.7023
Batch 15 out of 232 in Epoch 1 -- loss : 675.5693
Batch 16 out of 232 in Epoch 1 -- loss : 660.1615
Batch 17 out of 232 in Epoch 1 -- loss : 579.2623
Batch 18 out of 232 in Epoch 1 -- loss : 626.0895
Batch 19 out of 232 in Epoch 1 -- loss : 581.3276
Batch 20 out of 232 in Epoch 1 -- loss : 564.6755
Batch

#### Test

In [10]:
# Test the accuracy of the model
correct_label_count = 0

# Evaluation needs no gradients: disabling autograd avoids building (and keeping in
# memory) a computation graph for every test batch
with torch.no_grad():
    # We only need to batchify the test set for memory requirements.
    # max(1, ...) keeps array_split valid when the test set is smaller than one batch.
    for batch in np.array_split(test_indices, max(1, len(test_indices) // batch_size)):
        test_spike_train = torch.as_tensor(spike_train[batch].todense(), dtype=torch.float, device=device)

        # Same forward propagation as in training
        layer_1_spikes = run_spiking_layer(test_spike_train, w1)
        layer_2_spikes = run_spiking_layer(layer_1_spikes, w2)
        layer_3_spikes = run_spiking_layer(layer_2_spikes, w3)
        network_output = torch.sum(layer_3_spikes, 2)  # Count the spikes over time axis

        # Do the prediction by selecting the output neuron with the most number of spikes
        _, am = torch.max(network_output, 1)
        correct_label_count += np.sum(am.detach().cpu().numpy() == y[batch])

print("Model accuracy on test set: %.3f" % (correct_label_count / len(test_indices)))

Model accuracy on test set: 0.908
