# Imports

In [3]:
import os

# Fraction of accelerator memory JAX pre-allocates (this applies to GPU backends).
# JAX reads it when its backend is initialised, so set it before jax is imported.
os.environ["XLA_PYTHON_CLIENT_MEM_FRACTION"] = ".75"

import jax
import jax.numpy as jnp
import matplotlib.pyplot as plt
import numpy as np
import optax

# Inches per pixel at the current figure DPI, for giving figure sizes in pixels.
px = 1/plt.rcParams["figure.dpi"]

# Data Preprocessing

In [23]:
# MNIST dataset, already flattened (784 features per image) and normalized.
MNIST_DIR = '../mnist_np'

# The training images are stored in 4 shards; load them in order and stack along the sample axis.
X_train = jnp.concatenate(
    [jnp.load(os.path.join(MNIST_DIR, f'X_train_{shard}.npy')) for shard in range(1, 5)],
    axis=0,
)
y_train = jnp.load(os.path.join(MNIST_DIR, 'y_train.npy'))
X_test = jnp.load(os.path.join(MNIST_DIR, 'X_test.npy'))
y_test = jnp.load(os.path.join(MNIST_DIR, 'y_test.npy'))

# Every image needs exactly one label, and train/test must share the feature layout;
# a missing shard or a mismatched file would otherwise only surface much later.
assert X_train.shape[0] == y_train.shape[0], (X_train.shape, y_train.shape)
assert X_test.shape[0] == y_test.shape[0], (X_test.shape, y_test.shape)
assert X_train.shape[1] == X_test.shape[1], (X_train.shape, X_test.shape)
print(X_train.shape, y_train.shape, X_test.shape, y_test.shape)

(60000, 784) (60000,) (10000, 784) (10000,)


In [24]:
def rate_encoding(key, X, sim_len=100, batch_size=2000, verbose=True):
    '''
    Rate-encode intensities as Bernoulli spike trains.

    Every value of X is used as the per-timestep spike probability, drawn
    independently for each of the sim_len timesteps.

    :param key: JAX PRNG key. It is split per batch; the half used for sampling is never reused.
    :param X: array of shape (num_samples, ...) with values in [0, 1].
    :param sim_len: number of simulation timesteps to generate.
    :param batch_size: number of samples encoded per sampling call, to bound peak memory.
    :param verbose: print the start message and each batch's encoded shape.

    :return: (new_key, X_encoded), where X_encoded has shape (sim_len, num_samples, ...)
             and is stacked time-major, i.e. samples lie along axis 1.
    '''
    if batch_size < 1:
        raise ValueError(f'batch_size must be a positive integer, got {batch_size}')

    def bernoulli_encoding(key, spike_probs, sim_len):
        # Sample with the subkey and return the other half, so the key that produced
        # these samples is never split again (JAX keys must not be reused).
        key, subkey = jax.random.split(key)
        return key, jax.random.bernoulli(subkey, spike_probs, (sim_len, *spike_probs.shape))

    if verbose:
        print(f'Encoding the data in batches of {batch_size} (going above take more time)')
    X_encoded = []
    # Step by batch_size so a final partial batch is encoded instead of silently dropped.
    for start in range(0, X.shape[0], batch_size):
        key, X_encoded_ = bernoulli_encoding(key, X[start:start + batch_size], sim_len=sim_len)
        if verbose:
            print(X_encoded_.shape)
        X_encoded.append(X_encoded_)

    # Empty input: concatenate() cannot take an empty list, so build the empty result directly.
    if not X_encoded:
        return key, jnp.zeros((sim_len, *X.shape), dtype=bool)
    return key, jnp.concatenate(X_encoded, axis=1)

# Rate-encode the test set, starting from a freshly seeded PRNG key (seed 9);
# the returned key is kept for the training-set encoding.
key, X_test_encoded = rate_encoding(jax.random.PRNGKey(9), X_test, sim_len=100)

Encoding the data in batches of 2000 (going above take more time)
(100, 2000, 784)
(100, 2000, 784)
(100, 2000, 784)
(100, 2000, 784)
(100, 2000, 784)


In [25]:
# Rate-encode the training set, continuing from the key returned by the test-set encoding.
# NOTE: the result holds 100 x 60000 x 784 boolean spikes (~4.7 GB).
key, X_train_encoded = rate_encoding(key, X=X_train, sim_len=100)

Encoding the data in batches of 2000 (going above take more time)
(100, 2000, 784)
(100, 2000, 784)
(100, 2000, 784)
(100, 2000, 784)
(100, 2000, 784)
(100, 2000, 784)
(100, 2000, 784)
(100, 2000, 784)
(100, 2000, 784)
(100, 2000, 784)
(100, 2000, 784)
(100, 2000, 784)
(100, 2000, 784)
(100, 2000, 784)
(100, 2000, 784)
(100, 2000, 784)
(100, 2000, 784)
(100, 2000, 784)
(100, 2000, 784)
(100, 2000, 784)
(100, 2000, 784)
(100, 2000, 784)
(100, 2000, 784)
(100, 2000, 784)
(100, 2000, 784)
(100, 2000, 784)
(100, 2000, 784)
(100, 2000, 784)
(100, 2000, 784)
(100, 2000, 784)


# Function Definition

In [7]:
@jax.custom_jvp
def gr_than(x, thr):
    '''
    Spike nonlinearity: 1.0 where x exceeds thr, else 0.0 (as float32).

    The step has zero derivative almost everywhere, so the custom JVP rule
    below substitutes the surrogate derivative 1 / (|x - thr| + 1)**2.
    '''
    return (x > thr).astype(jnp.float32)

@gr_than.defjvp
def gr_jvp(primals, tangents):
    '''
    Surrogate-gradient JVP rule for gr_than.

    The step depends only on x - thr, so its surrogate derivative is
    +1 / (|x - thr| + 1)**2 with respect to x and the negative of that
    with respect to thr.

    :param primals: (x, thr)
    :param tangents: (x_dot, thr_dot)
    :return: (primal_out, tangent_out)
    '''
    x, thr = primals
    x_dot, thr_dot = tangents
    primal_out = gr_than(x, thr)
    surrogate = 1.0 / (jnp.absolute(x - thr) + 1) ** 2
    tangent_out = x_dot * surrogate
    # Include the threshold's contribution so gradients w.r.t. thr are not silently zero.
    # Integer thresholds carry float0 tangents, which cannot enter arithmetic; skip them.
    if getattr(thr_dot, 'dtype', None) != jax.dtypes.float0:
        tangent_out = tangent_out - thr_dot * surrogate
    return primal_out, tangent_out

# TODO 1: Implement the neuron model
# NOTE: The neuron model is a LIF (leaky integrate-and-fire) neuron; its inputs are a state object and input spikes.
#       The state object contains:
#       - state[0]: w, out_spikes, I_in, V_mem
#       - state[1]: tau_mem, V_th, timestep
#       It should return the updated state object and the output values (I_in, V_mem, out_spikes).
# NOTE: V_mem should be constrained to be positive.
def lif_forward(state, input_spikes):
    '''
    Advance a layer of LIF neurons by one timestep.

    :param state: pair whose first entry is the dynamic state
                  (w, out_spikes, I_in, V_mem) and whose second entry holds
                  the static parameters (tau_mem, V_th, timestep).
    :param input_spikes: input spikes for the current timestep.

    :return: once TODO 1.2/1.3 are done, (updated state, (I_in, V_mem, out_spikes));
             the placeholder currently returns None.
    '''
    w, out_spikes, I_in, V_mem = state[0]
    tau_mem, V_th, timestep = state[1]

    # TODO 1.2: implement the LIF neuron dynamics, using jnp functions where possible

    # TODO 1.3: return the updated state and the output values
    return None


In [None]:
def randomWeightInit(parent_key, scale, in_width, out_width):
    '''
    Draw a randomly initialised weight matrix for a layer.

    :param parent_key: JAX PRNG key from which the sampling key is derived.
    :param scale: scaling factor for the weights (the initialisation cell asks for 0.03).
    :param in_width: number of layer inputs.
    :param out_width: number of layer outputs. NOTE(review): loss_fn_per_image reads
                      W.shape[0] as the class count, which suggests an (out_width, in_width)
                      layout. Confirm when implementing.
    :return: placeholder None until TODO 2 is implemented.
    '''
    # TODO 2: implement the random weight initialization with jax.random.split() and keys
    weights = None
    return weights

In [None]:
# TODO 3: implement loss function for a single image
# Cross-entropy loss and accuracy for ONE encoded image (batching happens in train_one_epoch).
def loss_fn_per_image(W, static_params, img, lbl):
    '''
    Run the LIF layer over one image and score its prediction.

    The prediction is the class whose membrane voltage reaches the highest
    value at any timestep.

    :param W: weight matrix; W.shape[0] is taken as the number of classes.
    :param static_params: fixed LIF parameters, to be combined into the neuron state.
    :param img: encoded image, stepped through one timestep at a time along its leading axis.
    :param lbl: label of the image (presumably an integer class index).
    :return: (loss, acc)
    '''
    num_classes = W.shape[0]
    # TODO 3.1: define V_mem, I_in, out_spikes, state
    V_mem = None
    I_in = None
    out_spikes = None
    state = None

    # TODO 3.2: define V_mem_data using jnp.zeros()
    V_mem_data = None

    # TODO 3.3: step through the timesteps, calling lif_forward() once per step, and record V_mem
    for step in range(img.shape[0]):
        state, plot_values = lif_forward(state, img[step])
        # TODO 3.4: store this step's V_mem in V_mem_data with JAX's .at[].set()

    # TODO LATER: replace the Python for-loop with jax.lax.scan() for speed-up

    # TODO 3.5: compute the prediction from V_mem_data; for MNIST it is the class
    #           with the highest membrane voltage across all timesteps
    max_per_class = None
    prediction = None

    # TODO 3.6: define logits and the cross-entropy loss (built with softmax())
    logits = None
    loss = None

    # TODO 3.7: define accuracy using jnp.where()
    acc = None

    return loss, acc

In [None]:
# TODO 4: implement the training function
def train_one_epoch(state, input):
    '''
    Compute mean loss and accuracy over one batch of images.

    Despite its name, the training loop calls this once per batch.

    :param state: model state. NOTE(review): presumably (W, static_params), the names
                  used by the commented-out vmap call below; confirm.
    :param input: tuple (img_batch, lbl_batch).
    :return: (epoch_acc, epoch_loss); accuracy comes first.
    '''
    # TODO 4.1: unpack input into img_batch, lbl_batch (input is a tuple (img_batch, lbl_batch))
    img_batch, lbl_batch = None, None

    # TODO 4.2: loop over the images in the batch, calling loss_fn_per_image() for each one
    batch_loss, batch_acc = [], []

    for idx in range(img_batch.shape[0]):
        # TODO 4.2.1: call loss_fn_per_image() for this image
        loss, acc = None, None
        # TODO 4.2.2: append loss, acc to batch_loss, batch_acc

    # TODO LATER: replace the for-loop with jax.vmap for speed-up, e.g.
    # batch_loss, batch_acc = jax.vmap(loss_fn_per_image, in_axes=(None, None, 0, 0))(W, static_params, img_batch, lbl_batch)

    # TODO 4.3: stack the per-image losses and accuracies
    batch_loss = None
    batch_acc = None

    # TODO 4.4: average loss and accuracy over the batch
    epoch_loss = None
    epoch_acc = None

    return epoch_acc, epoch_loss

## Dataset class

In [15]:
class MNISTDataset():
    '''
    Map-style dataset pairing time-major encoded images with their labels.

    X_images is indexed as X_images[:, idx, :], i.e. shape (sim_len, num_samples, features),
    so samples lie along axis 1; Y_labels is indexed along axis 0.
    '''
    def __init__(self, X_images, Y_labels):
        # Fail fast on misaligned inputs rather than pairing wrong labels or
        # hitting an IndexError partway through an epoch.
        if X_images.shape[1] != len(Y_labels):
            raise ValueError(
                f'X_images holds {X_images.shape[1]} samples along axis 1 '
                f'but Y_labels has {len(Y_labels)} labels'
            )
        self.imgs = X_images
        self.lbls = Y_labels

    def __len__(self):
        # Number of MNIST pictures: samples lie along axis 1 of the time-major images.
        return self.imgs.shape[1]

    def __getitem__(self, idx):
        # Full spike train (all timesteps) of sample idx, plus its label.
        return self.imgs[:, idx, :], self.lbls[idx]

In [17]:
# Build the train/test datasets and sanity-check MNISTDataset on them.
train_dataset = MNISTDataset(X_train_encoded, y_train)
test_dataset = MNISTDataset(X_test_encoded, y_test)

# One sample per label, and each item is a (sim_len, num_pixels) spike train.
assert len(train_dataset) == y_train.shape[0]
assert len(test_dataset) == y_test.shape[0]
assert train_dataset[0][0].shape == (X_train_encoded.shape[0], X_train_encoded.shape[2])

## Dataloader

In [19]:
from torch.utils.data import DataLoader

batch_size = 128  # images per minibatch, shared by the train and test loaders

def custom_collate_fn(batch):
    '''
    Collate a list of (img, lbl) dataset items into numpy arrays.

    :param batch: sequence of (img, lbl) pairs, as returned by MNISTDataset.__getitem__.
    :return: (imgs, labels): images stacked along a new leading batch axis, labels as an array.
    '''
    # zip(*batch) transposes [(img, lbl), ...] into ((img, ...), (lbl, ...)).
    columns = tuple(zip(*batch))
    labels = np.array(columns[1])
    return np.stack(columns[0]), labels

# Training: shuffle, and drop the ragged final batch so every step sees exactly batch_size images.
train_loader = DataLoader(train_dataset, batch_size, shuffle=True, collate_fn=custom_collate_fn, drop_last=True)
# Evaluation: keep the final partial batch so every test image is scored
# (with batch_size=128, drop_last=True would silently skip 10000 % 128 = 16 images).
test_loader = DataLoader(test_dataset, batch_size, shuffle=False, collate_fn=custom_collate_fn, drop_last=False)

# Training SNN

## Initialization of loop

In [20]:
# TODO 5: Implement the forward loop:
# 5.1 Initialize all necessary state variables and parameters
# 5.2 Iterate over the training data for one epoch (for now)


num_classes = 10
num_inputs = X_train.shape[1]  # flattened MNIST pixels per image

# TODO 5.1: initialize state variables of LIF neuron for every time step

# initialize dynamic params
seed = 9
parent_key = jax.random.PRNGKey(seed)
# TODO 5.1.1: initialize weights, think about dimensions, scale with 0.03
# Arguments follow randomWeightInit(parent_key, scale, in_width, out_width).
W = randomWeightInit(parent_key, 0.03, num_inputs, num_classes)
# LIF parameters. NOTE(review): units presumably seconds (10 ms membrane constant, 1 ms step); confirm.
tau_mem = 10e-3
V_th = 1.0
timestep = 1e-3

# TODO 5.1.2: define static_params
static_params = None

Weight: 0.02396264672279358


## Training Loop 

In [None]:
num_epochs = 1
# Number of batches processed per epoch before stopping early (the non-JAX version is slow).
max_batches = 50

for epoch in range(num_epochs):
    # TODO 5.2: use train_loader to iterate over the dataset
    # zip() pulls from range() first, so it stops after exactly max_batches batches
    # without fetching an extra batch from the loader.
    for i, data in zip(range(max_batches), train_loader):
        # TODO 5.2.1: get the data from the DataLoader, save in img_batch, lbl_batch
        img_batch, lbl_batch = None, None
        # TODO 5.2.2: call train_one_epoch() for each batch
        acc, loss = None, None
        if i % 10 == 0:
            print(f'Epoch: {epoch}, Batch: {i}, Loss: {loss}, Accuracy: {acc}')

print('DONE')