# Imports

In [64]:
import jax 
import numpy as np
import jax.numpy as jnp
import matplotlib.pyplot as plt
px = 1/plt.rcParams["figure.dpi"]

import os

# set how much memory in your local CPU/GPU will be pre-allocated for JAX
os.environ["XLA_PYTHON_CLIENT_MEM_FRACTION"]=".75"

import optax

# Data Preprocessing

In [65]:
# MNIST dataset, already flat and normalized
X_train_1 = jnp.load('../mnist_np/X_train_1.npy') 
X_train_2 = jnp.load('../mnist_np/X_train_2.npy') 
X_train_3 = jnp.load('../mnist_np/X_train_3.npy') 
X_train_4 = jnp.load('../mnist_np/X_train_4.npy') 
# create X_train out of 4 X_trains
X_train = jnp.concatenate([X_train_1, X_train_2, X_train_3, X_train_4], axis=0)
y_train = jnp.load('../mnist_np/y_train.npy')
X_test = jnp.load('../mnist_np/X_test.npy')
y_test = jnp.load('../mnist_np/y_test.npy')
print(X_train.shape, y_train.shape, X_test.shape, y_test.shape)

(60000, 784) (60000,) (10000, 784) (10000,)


In [66]:
def rate_encoding(key, X, sim_len=100):
    
    def bernoulli_encoding(key, spike_trains, sim_len):
        key, subkey = jax.random.split(key)
        return key, jax.random.bernoulli(key, spike_trains, (sim_len, spike_trains.shape[0], spike_trains.shape[1]))
    
    print('Encoding the data in batches of 2000 (going above take more time)')
    X_encoded = []
    batch_size = 2000
    for i in range(X.shape[0]//batch_size):
        key, X_encoded_ = bernoulli_encoding(key, X[i*batch_size:(i+1)*batch_size], sim_len=100)
        print(X_encoded_.shape)
        X_encoded.append(X_encoded_)

    return key, jnp.concatenate(X_encoded, axis=1)

# do rate encoding on X_test
key = jax.random.PRNGKey(9)
key, X_test_encoded = rate_encoding(key, X_test, sim_len=100)

Encoding the data in batches of 2000 (going above take more time)
(100, 2000, 784)
(100, 2000, 784)
(100, 2000, 784)
(100, 2000, 784)
(100, 2000, 784)


In [67]:
# do rate encoding on Xtrain
key, X_train_encoded = rate_encoding(key, X_train, sim_len=100)

Encoding the data in batches of 2000 (going above take more time)
(100, 2000, 784)
(100, 2000, 784)
(100, 2000, 784)
(100, 2000, 784)
(100, 2000, 784)
(100, 2000, 784)
(100, 2000, 784)
(100, 2000, 784)
(100, 2000, 784)
(100, 2000, 784)
(100, 2000, 784)
(100, 2000, 784)
(100, 2000, 784)
(100, 2000, 784)
(100, 2000, 784)
(100, 2000, 784)
(100, 2000, 784)
(100, 2000, 784)
(100, 2000, 784)
(100, 2000, 784)
(100, 2000, 784)
(100, 2000, 784)
(100, 2000, 784)
(100, 2000, 784)
(100, 2000, 784)
(100, 2000, 784)
(100, 2000, 784)
(100, 2000, 784)
(100, 2000, 784)
(100, 2000, 784)


# Function Definition

In [69]:
@jax.custom_jvp
def gr_than(x, thr):
    return (x > thr).astype(jnp.float32)

@gr_than.defjvp
def gr_jvp(primals, tangents):
    x, thr = primals
    x_dot, thr_dot = tangents
    primal_out = gr_than(x, thr)
    tangent_out = x_dot * 1 / (jnp.absolute(x-thr)+1)**2
    return primal_out, tangent_out

# TODO 1: Implement the neuron model
# NOTE : The neuron model is a LIF neuron model, its input are a state object and input spikes
#        The state object contains the following:
#        - state[0]: w, out_spikes, I_in, V_mem
#        - state[1]: tau_mem, V_th, timestep
#        The neuron model should return the updated state object and the output values (I_in, V_mem, out_spikes)
# NOTE: V_mem should be constrained to be positive
def lif_forward(state, input_spikes):
    '''
    TODO 1.1: complete description
    :param state: 
        - state[0]: w, out_spikes, I_in, V_mem
            - w: learnable weights of the layer, shape = (784, layer_dim), dtype = float32
						- out_spike: output of the LIF neuron, shape = (784,), dtype = int
						- I_in: input current into the LIF neuron, shape = (784,), dtype = float32
						- V_mem: (non-neg) membrane potential of the LIF neuron, shape = (784,), dtype = float32
				- state[1]: tau_mem, V_th, timestep
						- tau_mem: time constant of the neuron model, shape = (1,1), dtype = (float32)
						- V_th: threshold membrane potential, shape = (1,1), dtype = float32
						- timestep: time delta, shape = (1,1), dtype = float32
    :return: 
		- state:
			- state[0]: w. out_spikes, I_in, V_mem
					updated values for w, out_spikes depending on I_in, V_mem and input_spikes
			- state[1]: tau_mem, V_th, timestep
					auxiliary, fixed variables of the neuron model
		- output: printable output values
			- I_in
			- V_mem
			- out_spikes
					
    '''
    w, out_spikes, I_in, V_mem = state[0]
    tau_mem, Vth, timestep = state[1]
    # TODO 1.2: Implement the LIF neuron model using jnp functions where possible
    I_in = jnp.dot(w, input_spikes)
    V_mem = (1 - timestep/tau_mem) * V_mem + I_in - out_spikes * Vth
    # constraining V_mem to be non-negative
    V_mem = jnp.maximum(0, V_mem)
    out_spikes = gr_than(V_mem, Vth)

    # TODO 1.3: return the updated state and the output values
    return ((w, out_spikes, I_in, V_mem), state[1]), (I_in, V_mem, out_spikes)


In [70]:
def randomWeightInit(parent_key, scale, in_width, out_width):
    # TODO 2: implement the random weight initialization function using rax.random.split() and keys
    in_width = in_width
    out_width = out_width
    weight_key, bias_key = jax.random.split(parent_key)
    W = scale*jax.random.normal(weight_key, shape=(out_width, in_width)) # random init of [weights, biases] tuple for each layer
    return W

In [71]:
# TODO 3: implement loss function for a single image
# cross entropy loss of a batch of input, returns loss and accuracy, also calculates prediction
def loss_fn_per_image(W, static_params, img, lbl):
    # TODO 3.1: define V_mem, I_in, out_spikes, state
    num_classes = W.shape[0]
    V_mem = jnp.zeros((num_classes,), dtype='float32')
    I_in = jnp.zeros((num_classes,), dtype='float32')
    out_spikes = jnp.zeros((num_classes,), dtype='float32')

    state = ((W, out_spikes, I_in, V_mem), static_params)
    
    # # TODO 3.2: define V_mem_data using jnp.zeros()
    # V_mem_data = jnp.zeros((img.shape[0], num_classes), dtype='float32')
    # # TODO 3.3: iterate through the timesteps and call lif_forward() for each timestep, save V_mem_data
    # for i in range(img.shape[0]):
    #     state, plot_values = lif_forward(state, img[i])
    #     # TODO 3.4: use JAX's .at[].set() to update V_mem_data
    #     V_mem_data.at[i, :].set(plot_values[1])

    # TODO LATER: replace for-loop with jax.lax.scan() to speed up
    state, plot_values = jax.lax.scan(lif_forward,state,img)   
    V_mem_data = plot_values[1]


    # TODO 3.5: calculate prediction using V_mem_data
    # we define prediction for MNIST to be the highest occurring membrane voltage across all timesteps
    max_per_class = V_mem_data.max(axis=0)
    prediction = max_per_class.argmax()

    # TODO 3.6: define logits and loss, use softmax() function for loss
    logits = jax.nn.softmax(max_per_class)
    loss = -jnp.mean(jnp.log(logits[prediction]))

    # TODO 3.7: define accuracy, and use jnp.where() for it
    acc = jnp.where(prediction == lbl, 1.0, 0.0)

    return loss, acc

In [72]:
# TODO 4: implement the training function
def train_one_epoch(W, static_params, input):
    # TODO 4.1: unpack input to img_batch, lbl_batch, (input comes in as a tuple of (img_batch, lbl_batch))
    img_batch, lbl_batch = input
    
    # TODO 4.2: iterate through all batches in the dataset, call loss_fn_per_image() for each image in the batch, calculate loss and accuracy
    # batch_loss = []
    # batch_acc = []

    # for i in range(img_batch.shape[0]):
    #     # TODO 4.2.1: call loss_fn_per_image() for each image in the batch
    #     loss, acc = loss_fn_per_image(W, static_params, img_batch[i], lbl_batch[i])
    #     # TODO 4.2.2: append loss, acc to batched_loss, batched_acc
    #     batch_loss.append(loss)
    #     batch_acc.append(acc)
    
    # TODO LATER: replace for-loop with jax.vmap for speed-up
    batch_loss, batch_acc = jax.vmap(loss_fn_per_image, in_axes=(None, None, 0, 0))(W, static_params, img_batch, lbl_batch)
    
    
    # TODO 4.3: stack the batched loss and accuracy
    batch_loss = jnp.stack(batch_loss)
    batch_acc = jnp.stack(batch_acc)

    # TODO 4.4: calculate the mean loss and accuracy for the batch
    epoch_loss = jnp.mean(batch_loss)
    epoch_acc = jnp.mean(batch_acc)
    
    return epoch_loss, epoch_acc

## Dataset class

In [73]:
class MNISTDataset():
    def __init__(self, X_images,Y_labels):
        self.imgs = X_images
        self.lbls = Y_labels
    def __len__(self): # return length of dataset, i.e. nb of MNIST pictures
        return self.imgs.shape[1]

    def __getitem__(self, idx):
        return self.imgs[:,idx,:], self.lbls[idx]

In [74]:
# test functionality of MNISTDataset() class 
train_dataset = MNISTDataset(X_train_encoded, y_train)
test_dataset = MNISTDataset(X_test_encoded, y_test)

## Dataloader

In [75]:
from torch.utils.data import DataLoader

batch_size=128

def custom_collate_fn(batch):
    transposed_data = list(zip(*batch)) # *() unpacks data !

    labels= np.array(transposed_data[1])
    imgs = np.stack(transposed_data[0])

    return imgs, labels

train_loader = DataLoader(train_dataset, batch_size, shuffle=True, collate_fn=custom_collate_fn, drop_last=True)
test_loader = DataLoader(test_dataset, batch_size, shuffle=False, collate_fn=custom_collate_fn, drop_last=True)

# Training SNN

## Initialization of loop

In [76]:
# TODO 5: Implement the forward loop: 
# 5.1 Initialize all necessary state variables and parameters
# 5.2 Iterate over the training data for one epoch (for now)


num_classes = 10

# TODO 5.1: initialize state variables of LIF neuron for every time step

# initialize dynamic params
seed = 9
parent_key = jax.random.PRNGKey(seed)
# TODO 5.1.1: initialize weights, think about dimensions, scale with 0.03
W = randomWeightInit(parent_key, 0.03, 784, 10) 
tau_mem = 10e-3
V_th = 1.0
timestep = 1e-3

# TODO 5.2: define static_params
static_params = (tau_mem, V_th, timestep)

## Training Loop 

In [77]:
num_epochs = 1

for epoch in range(num_epochs):
    # TODO 5.3: use train_loader to iterate over the dataset
    for i, data in enumerate(train_loader):
        # TODO 5.3.1: get the data from the DataLoader, save in img_batch, lbl_batch
        img_batch, lbl_batch = data
        loss, acc = train_one_epoch(W, static_params, (img_batch, lbl_batch))
        if i % 10 == 0:
            print(f'Epoch: {epoch}, Batch: {i}, Loss: {loss}, Accuracy: {acc}')
        # stop after 50 batches (non-JAX version will take a while!)
        if i == 50:
            break

print('DONE')

Epoch: 0, Batch: 0, Loss: 1.615514874458313, Accuracy: 0.0546875
Epoch: 0, Batch: 10, Loss: 1.638960599899292, Accuracy: 0.09375
Epoch: 0, Batch: 20, Loss: 1.6376497745513916, Accuracy: 0.109375
Epoch: 0, Batch: 30, Loss: 1.62709641456604, Accuracy: 0.109375
Epoch: 0, Batch: 40, Loss: 1.645958662033081, Accuracy: 0.0625
Epoch: 0, Batch: 50, Loss: 1.6312494277954102, Accuracy: 0.09375
DONE
