# Imports

In [None]:
import os

# JAX reads this when it initializes its GPU backend (lazily, on first use), so it
# must be in the environment before then; setting it ahead of `import jax`
# guarantees that. It limits JAX's GPU memory preallocation to 50%.
os.environ["XLA_PYTHON_CLIENT_MEM_FRACTION"] = ".5"

import jax
import jax.numpy as jnp
import matplotlib.pyplot as plt
import numpy as np
import optax

# Size of one pixel in inches at the current figure DPI (usable as figsize=(w*px, h*px)).
px = 1/plt.rcParams["figure.dpi"]

# Data Preprocessing

In [None]:
# MNIST dataset, already flat and normalized.
# The training images are stored as four shards that are loaded in order ...
X_train_1, X_train_2, X_train_3, X_train_4 = (
    jnp.load(shard_path)
    for shard_path in (
        '../mnist_np/X_train_1.npy',
        '../mnist_np/X_train_2.npy',
        '../mnist_np/X_train_3.npy',
        '../mnist_np/X_train_4.npy',
    )
)
# ... and stacked along the sample axis into the full training set.
X_train = jnp.concatenate((X_train_1, X_train_2, X_train_3, X_train_4), axis=0)
y_train, X_test, y_test = (
    jnp.load(array_path)
    for array_path in (
        '../mnist_np/y_train.npy',
        '../mnist_np/X_test.npy',
        '../mnist_np/y_test.npy',
    )
)
print(X_train.shape, y_train.shape, X_test.shape, y_test.shape)


In [None]:
def rate_encoding(key, X, sim_len=100, batch_size=2000):
    """Rate-encode the values of ``X`` as Bernoulli spike trains.

    Every entry of ``X`` is used as the per-time-step spike probability (the
    data is expected to be normalized), and ``sim_len`` independent draws are
    taken for it. Samples are processed in chunks of ``batch_size`` to bound
    peak memory; the last chunk may be smaller, so no samples are dropped.

    Args:
        key: JAX PRNG key. It is split once per chunk; the subkey is consumed
            for sampling and the other half is carried forward, so no key is
            ever used twice.
        X: array whose leading axis indexes samples, e.g. (n_samples, n_features).
        sim_len: number of simulation time steps to draw.
        batch_size: number of samples encoded per chunk.

    Returns:
        (key, spikes): a fresh, unused PRNG key, and a boolean array of shape
        ``(sim_len,) + X.shape`` (time first, then samples).
    """
    print(f'Encoding the data in batches of {batch_size} (going above take more time)')
    n_samples = X.shape[0]
    chunks = []
    for start in range(0, n_samples, batch_size):
        key, subkey = jax.random.split(key)
        X_chunk = X[start:start + batch_size]
        # The probabilities broadcast over the new leading time axis: one draw per step.
        spikes = jax.random.bernoulli(subkey, X_chunk, (sim_len,) + X_chunk.shape)
        print(spikes.shape)
        chunks.append(spikes)

    if not chunks:
        # Empty input: return an empty spike array with the same layout.
        return key, jnp.zeros((sim_len,) + X.shape, dtype=bool)
    return key, jnp.concatenate(chunks, axis=1)

# Rate-encode X_test from a PRNG seeded with 9; the key handed back is carried
# into the X_train encoding in the next cell.
key, X_test_encoded = rate_encoding(jax.random.PRNGKey(9), X_test, sim_len=100)

In [None]:
# Rate-encode X_train, continuing from the PRNG key returned by the X_test
# encoding so the training set does not start from the test set's key.
key, X_train_encoded = rate_encoding(key, X_train, sim_len=100)

# Function Definitions

In [None]:
@jax.custom_jvp
def gr_than(x, thr):
    """Spike nonlinearity: 1.0 where ``x`` exceeds ``thr``, else 0.0 (float32).

    The step function has zero gradient almost everywhere, so its derivative
    is replaced by the surrogate registered in ``gr_jvp``.
    """
    fired = x > thr
    return fired.astype(jnp.float32)

@gr_than.defjvp
def gr_jvp(primals, tangents):
    """Surrogate JVP: d/dx ~ 1 / (1 + |x - thr|)^2.

    This is the derivative of the fast sigmoid x / (1 + |x|), centred at the
    threshold. The tangent of ``thr`` is ignored, i.e. the threshold is
    treated as a constant.
    """
    (x, thr), (x_dot, _unused_thr_dot) = primals, tangents
    spikes = gr_than(x, thr)
    return spikes, x_dot / (1.0 + jnp.abs(x - thr)) ** 2

# TODO 1: extend the LIF model to have two layers
# NOTE: nb of inputs and outputs are the same, the hidden layer dimension has to be defined
# NOTE: a recurrent layer is defined at the hidden layer
# NOTE: we want to return besides the state also the input current, V_mem of the hidden layer, V_mem and the output spikes of the output layer
def lif_forward(state, input_spikes):
    """One time step of the (still to be implemented) two-layer recurrent LIF net.

    Used as the step function of ``jax.lax.scan`` in ``mini_loss``: ``state``
    is the scan carry and ``input_spikes`` is one time slice of an encoded image.

    NOTE(review): exercise stub. ``I_in``, ``V_mem1``, ``V_mem2`` and
    ``out_spikes2`` are not defined in this function (nor anywhere else in
    this notebook as shown), so calling it raises an error.

    Returns (shape of the return statement below):
        carry: ``((6 slots), state[1])``. The six slots are placeholders for
            the updated network state. ``state[1]`` is passed through
            unchanged (presumably the static parameters; confirm against how
            ``mini_loss`` builds ``state``).
        per-step outputs: ``(I_in, V_mem1, V_mem2, out_spikes2)``, which
            ``scan`` stacks over time.
    """

    return ((None, None, None, None, None, None), state[1]), (I_in, V_mem1, V_mem2, out_spikes2)


In [None]:
def randomWeightInit(parent_key, scale, in_width, out_width):
    """Draw a dense weight matrix from a scaled standard normal distribution.

    Args:
        parent_key: JAX PRNG key. It is split and the first child key is used
            for sampling; the second child is unused (no bias is drawn).
            Splitting is kept so the sampled values stay reproducible.
        scale: multiplier applied to the N(0, 1) samples.
        in_width: number of input units (columns).
        out_width: number of output units (rows).

    Returns:
        Array of shape ``(out_width, in_width)``, so ``W @ x`` maps an
        ``in_width`` vector to an ``out_width`` vector.
    """
    weight_key, _ = jax.random.split(parent_key)
    return scale * jax.random.normal(weight_key, shape=(out_width, in_width))

In [None]:
# TODO 3: implement a multi-layer loss function
def mini_loss(params, static_params, img, lbl):
    """Loss and accuracy for a single encoded sample (exercise stub).

    Runs ``lif_forward`` over the leading axis of ``img`` with
    ``jax.lax.scan``. The recorded outputs are meant to be turned into a
    classification loss.

    Args:
        params: trainable weights (structure still to be defined, see TODO 2.1).
        static_params: ``(tau_mem, V_th, timestep)`` as built in the
            initialization cell.
        img: one sample. Via ``MNISTDataset.__getitem__`` and ``vmap`` over
            the batch axis this is presumably (sim_len, n_inputs); confirm.
            ``scan`` iterates over its first axis (time).
        lbl: the sample's class label (assumed to be an integer index).

    Returns:
        (loss, acc): both still ``None``; TODOs 3.1-3.3 are unfilled.

    NOTE(review): ``state`` is ``None`` below, so ``lif_forward`` cannot
    index ``state[1]`` and the scan fails until TODO 3.2 is done.
    """
    # TODO 3.1: try to infer these values from your parameters
    num_classes = None
    num_hidden = None

    # TODO 3.2: initialize the multi-layer state variable
    state = None

    # plot_values collects lif_forward's per-step outputs, stacked over time.
    state, plot_values = jax.lax.scan(lif_forward,state,img)

    # TODO 3.3: implement the loss calculation accordingly
    V_mem_data = None
    max_per_class = None
    prediction = None
    logits = None
    loss = None
    acc = None

    return loss, acc

In [None]:
# loss function
def loss_fn_vmap(params, static_params, img_b, lbl_b):
    """Mean loss and mean accuracy of ``mini_loss`` over a batch.

    ``mini_loss`` handles a single sample. ``jax.vmap`` maps it over axis 0
    of ``img_b`` and ``lbl_b``, while ``params`` and ``static_params`` are
    shared by every sample (``in_axes=None``).

    Returns:
        (mean_loss, mean_acc). The accuracy is the second element so callers
        can use ``jax.value_and_grad(..., has_aux=True)`` and differentiate
        only the loss.
    """
    per_sample = jax.vmap(mini_loss, in_axes=(None, None, 0, 0))
    local_loss, local_acc = per_sample(params, static_params, img_b, lbl_b)
    return local_loss.mean(), local_acc.mean()

## Dataset class

In [None]:
class MNISTDataset:
    """Map-style dataset over rate-encoded images stored time-major.

    ``X_images`` is indexed as ``[time, sample, feature]`` (samples on axis 1),
    which is the layout produced by ``rate_encoding``; ``Y_labels`` holds one
    label per sample.
    """

    def __init__(self, X_images, Y_labels):
        self.imgs, self.lbls = X_images, Y_labels

    def __len__(self):
        # Axis 0 is time, so the sample count lives on axis 1.
        n_samples = self.imgs.shape[1]
        return n_samples

    def __getitem__(self, idx):
        spike_train = self.imgs[:, idx, :]
        label = self.lbls[idx]
        return spike_train, label

In [None]:
# Wrap the encoded spike trains and their labels as datasets for the DataLoader.
train_dataset = MNISTDataset(X_train_encoded, y_train)
test_dataset = MNISTDataset(X_test_encoded, y_test)
# Sanity check: exactly one label per encoded sample (catches samples silently
# dropped during encoding).
assert len(train_dataset) == len(y_train), (len(train_dataset), len(y_train))
assert len(test_dataset) == len(y_test), (len(test_dataset), len(y_test))

## Dataloader

In [None]:
import torch
from torch.utils.data import DataLoader

batch_size=128

def custom_collate_fn(batch):
    """Collate a list of (image, label) pairs into two NumPy arrays.

    Returns:
        imgs: the per-sample images stacked on a new leading batch axis.
        labels: the labels as a NumPy array, in batch order.
    """
    images, labels = zip(*batch)
    return np.stack(images), np.array(labels)

# Seeded generator so the training shuffle order is reproducible across runs.
shuffle_generator = torch.Generator().manual_seed(9)
train_loader = DataLoader(train_dataset, batch_size, shuffle=True, collate_fn=custom_collate_fn, drop_last=True, generator=shuffle_generator)
# drop_last=False: keep the final partial batch so every test image is evaluated.
test_loader = DataLoader(test_dataset, batch_size, shuffle=False, collate_fn=custom_collate_fn, drop_last=False)

# Training SNN

## Initialization of loop

In [None]:
num_classes = 10
# TODO LATER: play around, try out different hidden layer dimensions and check out the resulting accuracy
# NOTE: this one works sufficiently well
num_hidden = 200

# LIF neuron constants that are not trained: membrane time constant, firing
# threshold and simulation time step (presumably in seconds; confirm).
# They are bundled as static_params for the loss functions.
tau_mem, V_th, timestep = 10e-3, 1.0, 1e-3
static_params = (tau_mem, V_th, timestep)

# TODO: initialize state variables of the LIF neuron for every time step,
# and the dynamic params.
# TODO 2: modify/add parameters necessary for a multi-layer SNN
seed = 9
parent_key = jax.random.PRNGKey(seed)

# TODO 2.1: initialize weights
# NOTE: weights can be added to a list/tuple of weights (called params for instance), and passed as an argument
# NOTE: weight scaling factor of 0.03 can be used
# NOTE: refer to the JAX tutorial on weight updates with JAX using Pytrees to see why we are using it this way
params = None
# Second name bound to the initial parameters.
params_init = params

In [None]:
# Optimizer: global-norm gradient clipping + Adam, with an exponentially decaying learning rate.

start_learning_rate = 1e-3
# NOTE(review): the training cell runs num_epochs = 2, but the schedule below is
# sized for n_epochs, so roughly twice n_updates_lr decay steps happen during
# training. Keep the two in sync if the exact decay matters.
n_epochs = 1
hp = static_params

n_batches = len(train_loader)
n_updates = n_epochs * n_batches
n_updates_lr = 15
# Integer number of optimizer steps per decay step, never below 1: optax's
# exponential_decay falls back to a constant schedule when transition_steps <= 0
# (e.g. for a small loader with fewer than n_updates_lr batches).
transition_steps = max(1, n_updates // n_updates_lr)
print(f'n_updates: {n_updates}, n_updates_lr: {n_updates_lr}, transition_steps: {transition_steps}')

# lr(step) = start_learning_rate * 0.99 ** (step / transition_steps)
scheduler = optax.exponential_decay(
    init_value=start_learning_rate,
    transition_steps=transition_steps,
    decay_rate=0.99)

# Combining gradient transforms using `optax.chain`.
gradient_transform = optax.chain(
    optax.clip_by_global_norm(1.0),  # Clip the gradient by its global norm.
    optax.scale_by_adam(),  # Use the updates from adam.
    optax.scale_by_schedule(scheduler),  # Use the learning rate from the scheduler.
    # Scale updates by -1 since optax.apply_updates is additive and we want to descend on the loss.
    optax.scale(-1.0)
)

opt_state = gradient_transform.init(params)

## Training Loop 

In [None]:
num_epochs = 2
# Only epochs whose mean training accuracy exceeds this value are saved as "best".
best_acc = 0.5
# Reset on every run of this cell so a stale result from an earlier run is never reused.
params_final = None

for epoch in range(num_epochs):
    batch_losses = []
    batch_accs = []
    # One full pass over the training set, updating params after every batch.
    for batch_cnt, (img_batch, lbl_batch) in enumerate(train_loader):
        (batch_loss, batch_acc), weight_grad = jax.value_and_grad(loss_fn_vmap, has_aux=True)(params, static_params, img_batch, lbl_batch)
        updates, opt_state = gradient_transform.update(weight_grad, opt_state)
        params = optax.apply_updates(params, updates)

        # logging
        if batch_cnt % 25 == 0:
            print('   batch_cnt ', batch_cnt, ', b loss ', batch_loss, ', b accuracy ', batch_acc)

        # Plain list appends: O(1) per batch instead of copying a JAX array with .at[].set.
        batch_losses.append(batch_loss)
        batch_accs.append(batch_acc)

    # Mean over batches (nan for an empty loader, as with a zero-length array).
    epoch_loss = jnp.asarray(batch_losses).mean()
    epoch_acc = jnp.asarray(batch_accs).mean()

    print('')
    print('epoch ', epoch, ', e loss ', epoch_loss, ', e acc', epoch_acc)
    print('')

    # save best performing weights (per epoch)
    if epoch_acc > best_acc:
        # optax.apply_updates returns a new pytree each step (params is rebound,
        # not mutated), so keeping a reference snapshots the current weights.
        params_final = params
        best_acc = epoch_acc
        print('params saved')
        print('')

if params_final is None:
    # No epoch beat the best_acc threshold: fall back to the last trained
    # weights so the test cell still has parameters to evaluate.
    params_final = params
    print('no epoch exceeded best_acc; using the last trained params')

print('DONE')

# Testing SNN on test data

In [None]:
# Evaluate params_final on the whole test set (inference only, no parameter updates).
batch_losses = []
batch_accs = []
batch_sizes = []
for batch_cnt, (img_batch, lbl_batch) in enumerate(test_loader):
    batch_loss, batch_acc = loss_fn_vmap(params_final, static_params, img_batch, lbl_batch)

    # logging
    if batch_cnt % 25 == 0:
        print('   batch_cnt ', batch_cnt, ', b loss ', batch_loss, ', b accuracy ', batch_acc)

    batch_losses.append(batch_loss)
    batch_accs.append(batch_acc)
    batch_sizes.append(len(lbl_batch))

# Weight each batch mean by its batch size so a smaller final batch does not
# skew the test-set averages.
test_loss = np.average(batch_losses, weights=batch_sizes)
test_acc = np.average(batch_accs, weights=batch_sizes)

print('')
print('test loss ', test_loss, ', test acc', test_acc)
print('')

print('DONE')