# Imports

In [2]:
import jax 
import numpy as np
import jax.numpy as jnp
import matplotlib.pyplot as plt
px = 1/plt.rcParams["figure.dpi"]

import os

# Fraction of the GPU memory that JAX preallocates at start-up.
# It has to be set before JAX runs its first computation (i.e. before the backend is initialised).
os.environ["XLA_PYTHON_CLIENT_MEM_FRACTION"]=".75"

import optax

# Data Preprocessing

In [3]:
# Load the MNIST arrays from .npy files (already flat and normalized).
# The pixel values are used further down as Bernoulli firing probabilities, so
# they are expected to lie in [0, 1] -- NOTE(review): confirm for these files.
# The training images are stored as four shards.
X_train_1 = jnp.load('../mnist_np/X_train_1.npy') 
X_train_2 = jnp.load('../mnist_np/X_train_2.npy') 
X_train_3 = jnp.load('../mnist_np/X_train_3.npy') 
X_train_4 = jnp.load('../mnist_np/X_train_4.npy') 
# Re-assemble the full training set by stacking the four shards along the sample axis.
X_train = jnp.concatenate([X_train_1, X_train_2, X_train_3, X_train_4], axis=0)
y_train = jnp.load('../mnist_np/y_train.npy')
X_test = jnp.load('../mnist_np/X_test.npy')
y_test = jnp.load('../mnist_np/y_test.npy')
# Sanity check: images are (n_samples, n_pixels), labels are (n_samples,).
print(X_train.shape, y_train.shape, X_test.shape, y_test.shape)

(60000, 784) (60000,) (10000, 784) (10000,)


In [4]:
def rate_encoding(key, X, sim_len=100, batch_size=2000):
    """Rate-encode intensities as Bernoulli spike trains.

    Every entry of ``X`` is used as the probability of emitting a spike at each
    of the ``sim_len`` time steps. The samples are encoded in chunks of
    ``batch_size`` rows and the chunks are concatenated along the sample axis.

    Args:
        key: JAX PRNG key.
        X: 2-D array ``(n_samples, n_features)`` of firing probabilities.
        sim_len: number of simulated time steps per sample.
        batch_size: number of samples encoded per chunk (the last chunk may be
            smaller).

    Returns:
        ``(key, spikes)`` where ``key`` is a fresh PRNG key that has not been
        used for sampling and ``spikes`` is a boolean array of shape
        ``(sim_len, n_samples, n_features)``.
    """

    def bernoulli_encoding(key, spike_probs, sim_len):
        # Sample with the subkey and hand back the other half of the split, so
        # a key that already produced random numbers is never reused.
        key, subkey = jax.random.split(key)
        shape = (sim_len, spike_probs.shape[0], spike_probs.shape[1])
        return key, jax.random.bernoulli(subkey, spike_probs, shape)

    print(f'Encoding the data in batches of {batch_size} (going above take more time)')
    n_samples = X.shape[0]
    X_encoded = []
    # Step over start indices so that a trailing partial batch is encoded too
    # instead of being silently dropped.
    for start in range(0, n_samples, batch_size):
        key, X_encoded_ = bernoulli_encoding(key, X[start:start + batch_size], sim_len)
        print(X_encoded_.shape)
        X_encoded.append(X_encoded_)

    if not X_encoded:
        # No samples at all: return an empty spike tensor rather than failing
        # inside jnp.concatenate.
        return key, jnp.zeros((sim_len, 0, X.shape[1]), dtype=bool)

    return key, jnp.concatenate(X_encoded, axis=1)

# Rate-encode the test images.
# The PRNG key is seeded once here; rate_encoding returns an updated key,
# which is stored back in `key` so later encoding calls continue from it.
key = jax.random.PRNGKey(9)
key, X_test_encoded = rate_encoding(key, X_test, sim_len=100)

Encoding the data in batches of 2000 (going above take more time)
(100, 2000, 784)
(100, 2000, 784)
(100, 2000, 784)
(100, 2000, 784)
(100, 2000, 784)


In [5]:
# Rate-encode the training images, continuing from the key returned by the
# previous encoding call (this cell therefore depends on the test-set cell above).
key, X_train_encoded = rate_encoding(key, X_train, sim_len=100)

Encoding the data in batches of 2000 (going above take more time)
(100, 2000, 784)
(100, 2000, 784)
(100, 2000, 784)
(100, 2000, 784)
(100, 2000, 784)
(100, 2000, 784)
(100, 2000, 784)
(100, 2000, 784)
(100, 2000, 784)
(100, 2000, 784)
(100, 2000, 784)
(100, 2000, 784)
(100, 2000, 784)
(100, 2000, 784)
(100, 2000, 784)
(100, 2000, 784)
(100, 2000, 784)
(100, 2000, 784)
(100, 2000, 784)
(100, 2000, 784)
(100, 2000, 784)
(100, 2000, 784)
(100, 2000, 784)
(100, 2000, 784)
(100, 2000, 784)
(100, 2000, 784)
(100, 2000, 784)
(100, 2000, 784)
(100, 2000, 784)
(100, 2000, 784)


# Function Definition

In [6]:
@jax.custom_jvp
def gr_than(x, thr):
    """Step function: 1.0 (float32) where ``x > thr``, 0.0 elsewhere."""
    above_threshold = x > thr
    return above_threshold.astype(jnp.float32)

@gr_than.defjvp
def gr_jvp(primals, tangents):
    """Custom JVP rule giving ``gr_than`` a smooth surrogate derivative.

    The forward value is the exact step function. The tangent with respect to
    ``x`` is scaled by ``1 / (|x - thr| + 1)**2``, which peaks at the threshold
    and decays away from it. The tangent of ``thr`` is not propagated.
    """
    x, thr = primals
    x_dot, _ = tangents
    spikes = gr_than(x, thr)
    distance_to_thr = jnp.abs(x - thr)
    denominator = (distance_to_thr + 1) ** 2
    return spikes, x_dot / denominator

def lif_forward(state, input_spikes):
    """Advance a layer of leaky integrate-and-fire neurons by one time step.

    Written as a ``jax.lax.scan`` body: it takes the carry and one slice of the
    input and returns the new carry plus the values recorded for this step.

    Args:
        state: ``((w, out_spikes, I_in, V_mem), (tau_mem, Vth, timestep))``.
            The first tuple is the evolving state, the second is passed through
            unchanged.
        input_spikes: input vector for this time step, multiplied by ``w``.

    Returns:
        ``(new_state, (I_in, V_mem, out_spikes))``.
    """
    neuron_state = state[0]
    constants = state[1]
    w, prev_spikes, _, prev_V = neuron_state
    tau_mem, Vth, timestep = constants

    # Input current is recomputed from scratch every step (the carried I_in is unused).
    I_in = jnp.dot(w, input_spikes)

    # Leaky integration, with a subtractive reset for neurons that fired last step.
    V_mem = (1 - timestep/tau_mem) * prev_V + I_in - prev_spikes * Vth
    # The membrane potential is clamped so it never goes negative.
    V_mem = jnp.maximum(0, V_mem)

    out_spikes = gr_than(V_mem, Vth)

    next_state = ((w, out_spikes, I_in, V_mem), constants)
    recorded = (I_in, V_mem, out_spikes)
    return next_state, recorded


In [7]:
def randomWeightInit(parent_key, scale, in_width, out_width):
    """Draw a Gaussian weight matrix for one fully connected layer.

    Args:
        parent_key: JAX PRNG key the weight key is derived from.
        scale: standard deviation of the weights (multiplies a standard normal).
        in_width: number of inputs (columns of the matrix).
        out_width: number of outputs (rows of the matrix).

    Returns:
        Array of shape ``(out_width, in_width)``. No bias is created.
    """
    # The split is kept, and its first half used, so that the weights produced
    # for a given parent_key are unchanged. The second half was reserved for a
    # bias that this function never creates.
    weight_key, _ = jax.random.split(parent_key)
    return scale * jax.random.normal(weight_key, shape=(out_width, in_width))

In [8]:
def mini_loss(W, static_params, img, lbl):
    """Cross-entropy loss and 0/1 accuracy of the LIF layer on a single sample.

    The layer is simulated over the leading axis of ``img`` with
    ``jax.lax.scan``. The maximum membrane potential reached by each output
    neuron is used as its class score.

    Args:
        W: weight matrix; ``W.shape[0]`` is the number of output neurons/classes.
        static_params: ``(tau_mem, Vth, timestep)`` forwarded to ``lif_forward``.
        img: input spike train, scanned over its first axis (one slice per time
            step).
        lbl: integer class label (array scalar).

    Returns:
        ``(loss, acc)``: the negative log-likelihood of the true class, and 1.0
        if the arg-max class equals ``lbl``, otherwise 0.0.
    """
    num_classes = W.shape[0]
    V_mem = jnp.zeros((num_classes,), dtype='float32')
    I_in = jnp.zeros((num_classes,), dtype='float32')
    out_spikes = jnp.zeros((num_classes,), dtype='float32')
    state = ((W, out_spikes, I_in, V_mem), static_params)

    # Only the recorded membrane potentials are needed for the loss.
    _, (_, V_mem_data, _) = jax.lax.scan(lif_forward, state, img)

    max_per_class = V_mem_data.max(axis=0)
    prediction = max_per_class.argmax()

    # log_softmax is used instead of log(softmax(.)): when the scores are far
    # apart, softmax underflows to 0 and the log becomes -inf (NaN gradients).
    log_probs = jax.nn.log_softmax(max_per_class)
    # A wide integer index avoids wrap-around for more than 255 classes.
    label_idx = jnp.asarray(lbl).astype(jnp.int32)
    loss = -log_probs[label_idx]

    # accuracy
    acc = jnp.where(prediction == lbl, 1.0, 0.0)

    return loss, acc

In [16]:
# Batched loss: mini_loss is vectorised over the batch with jax.vmap.
def loss_fn_vmap(W, static_params, img_b, lbl_b):
    """Mean loss and mean accuracy of ``mini_loss`` over a batch.

    Args:
        W: weight matrix, shared by every sample in the batch.
        static_params: neuron constants, shared by every sample in the batch.
        img_b: batch of inputs, batch on axis 0.
        lbl_b: batch of labels, batch on axis 0.

    Returns:
        ``(mean_loss, mean_accuracy)`` as scalars, so the function can be used
        with ``jax.value_and_grad(..., has_aux=True)``.
    """
    # in_axes: W and static_params are broadcast (None); images and labels are
    # mapped over their leading axis.
    batched_loss = jax.vmap(mini_loss, in_axes=(None, None, 0, 0))
    local_loss, local_acc = batched_loss(W, static_params, img_b, lbl_b)
    return local_loss.mean(), local_acc.mean()

## Dataset class

In [10]:
class MNISTDataset:
    """Indexable dataset pairing encoded images with their labels.

    The image tensor is indexed along axis 1 for samples and the label array
    along axis 0.
    """

    def __init__(self, X_images, Y_labels):
        """Store the image tensor and the label array without copying them."""
        self.imgs, self.lbls = X_images, Y_labels

    def __len__(self):
        """Return the number of samples, i.e. the size of axis 1 of the images."""
        n_samples = self.imgs.shape[1]
        return n_samples

    def __getitem__(self, idx):
        """Return ``(image, label)`` for sample ``idx``."""
        sample = self.imgs[:, idx, :]
        label = self.lbls[idx]
        return sample, label

In [11]:
# Wrap the rate-encoded spike trains and their labels so that they can be
# indexed sample by sample (the encoded tensors keep the sample index on axis 1).
train_dataset = MNISTDataset(X_train_encoded, y_train)
test_dataset = MNISTDataset(X_test_encoded, y_test)

## Dataloader

In [12]:
from torch.utils.data import DataLoader

# mini-batch size used by both the train and the test DataLoader
batch_size=128

def custom_collate_fn(batch):
    """Collate a list of ``(image, label)`` samples into two NumPy arrays.

    Returns ``(imgs, labels)``: the images stacked along a new leading batch
    axis, and the labels gathered into one array.
    """
    # zip(*batch) turns [(img0, lbl0), (img1, lbl1), ...] into
    # [(img0, img1, ...), (lbl0, lbl1, ...)].
    columns = tuple(zip(*batch))

    label_array = np.array(columns[1])
    image_array = np.stack(columns[0])
    return image_array, label_array
 
# PyTorch DataLoaders are used only for batching/shuffling; custom_collate_fn
# returns NumPy arrays, so no torch tensors are produced.
# drop_last=True discards the final incomplete batch, which keeps every batch
# the same shape. NOTE(review): for the test loader this leaves the last
# 10000 % 128 = 16 test samples unevaluated -- confirm that this is intended.
train_loader = DataLoader(train_dataset, batch_size, shuffle=True, collate_fn=custom_collate_fn, drop_last=True)
test_loader = DataLoader(test_dataset, batch_size, shuffle=False, collate_fn=custom_collate_fn, drop_last=True)

# Training SNN

## Initialization of loop

In [13]:
# Number of output neurons / classes.
# NOTE(review): not used in this cell -- the weight shape below hard-codes 10.
num_classes = 10

# (The LIF state variables themselves are created per sample inside mini_loss.)

# initialize dynamic params (the trainable weights)
# NOTE(review): the same seed value (9) is also used for the spike-encoding key
# earlier in the notebook; consider a distinct seed if independence matters.
seed = 9
parent_key = jax.random.PRNGKey(seed)
# Weight matrix of shape (10, 784), Gaussian with standard deviation 0.03.
W = randomWeightInit(parent_key, 0.03, 784, 10)
# Keep a reference to the initial weights (JAX arrays are immutable, so this is
# not affected by later updates that rebind W).
W_init = W

# initialize static params
# lif_forward multiplies the membrane potential by (1 - timestep/tau_mem) at
# every step, i.e. by 0.9 with the values below.
tau_mem = 10e-3
V_th = 1.0
timestep = 1e-3

# Order matters: lif_forward unpacks this tuple as (tau_mem, Vth, timestep).
static_params = (tau_mem, V_th, timestep)

In [14]:
# optimizer with scheduler

start_learning_rate = 1e-3
# NOTE(review): the training loop below sets its own num_epochs = 2, so the
# schedule here is sized for fewer updates than are actually performed --
# confirm which value is intended.
n_epochs = 1
# NOTE(review): `hp` and `w` are aliases that are not used in the cells visible
# here; presumably leftovers or used further down -- verify before removing.
hp = static_params
w = W

# Total number of optimizer updates the schedule is sized for.
n_batches = len(train_loader)
n_updates = n_epochs * n_batches
# Number of decay periods over the run: one period lasts transition_steps updates.
n_updates_lr = 15
transition_steps = np.floor(n_updates / n_updates_lr)
print(f'n_updates: {n_updates}, n_updates_lr: {n_updates_lr}, transition_steps: {transition_steps}')

# Exponential decay of the learning rate: the rate is multiplied by decay_rate
# for every transition_steps updates (see the optax.exponential_decay docs).
scheduler = optax.exponential_decay(
    init_value=start_learning_rate,
    transition_steps=transition_steps,
    decay_rate=0.99)

# Combining gradient transforms using `optax.chain` (applied in the listed order).
gradient_transform = optax.chain(
    optax.clip_by_global_norm(1.0),  # Clip the gradient by its global norm.
    optax.scale_by_adam(),  # Use the updates from adam.
    optax.scale_by_schedule(scheduler),  # Use the learning rate from the scheduler.
    # Scale updates by -1 since optax.apply_updates is additive and we want to descend on the loss.
    optax.scale(-1.0)
)

# Optimizer state (Adam moments, schedule step count) for the weight matrix.
opt_state = gradient_transform.init(W)

n_updates: 468, n_updates_lr: 15, transition_steps: 31.0


## Training Loop 

In [None]:
# NOTE(review): the learning-rate schedule above was sized with n_epochs = 1,
# while this loop runs num_epochs = 2 -- confirm which value is intended.
num_epochs = 2

# Build the ((loss, accuracy), gradient) function once and JIT-compile it,
# instead of re-creating the transform for every batch. The gradient is taken
# with respect to the first argument (W); accuracy is returned as auxiliary
# output. The train loader uses drop_last=True, so all batches have the same
# shape and the function is compiled only once.
loss_and_grad = jax.jit(jax.value_and_grad(loss_fn_vmap, has_aux=True))

for epoch in range(num_epochs):
    # One slot per batch; averaged at the end of the epoch.
    epoch_loss = jnp.zeros(len(train_loader))
    epoch_acc = jnp.zeros(len(train_loader))

    for batch_cnt, (img_batch, lbl_batch) in enumerate(train_loader):
        # Forward and backward pass on the mini-batch.
        (batch_loss, batch_acc), weight_grad = loss_and_grad(W, static_params, img_batch, lbl_batch)

        # Turn the raw gradient into an update (clip -> Adam -> schedule -> sign
        # flip) and apply it to the weights.
        updates, opt_state = gradient_transform.update(weight_grad, opt_state)
        W = optax.apply_updates(W, updates)

        # logging
        if batch_cnt % 25 == 0:
            print('   batch_cnt ', batch_cnt, ', b loss ', batch_loss, ', b accuracy ', batch_acc)

        # JAX arrays are immutable: .at[...].set(...) returns an updated copy.
        epoch_loss = epoch_loss.at[batch_cnt].set(batch_loss)
        epoch_acc = epoch_acc.at[batch_cnt].set(batch_acc)

    # Average the per-batch statistics over the epoch.
    epoch_loss = epoch_loss.mean()
    epoch_acc = epoch_acc.mean()

    print('')
    print('epoch ', epoch, ', e loss ', epoch_loss, ', e acc', epoch_acc)
    print('')

print('DONE')