### Training Recurrent Spiking Neural Networks with JAX on TPU or multi-GPU Environments


**Author**: [Yigit Demirag](https://github.com/YigitDemirag/spikingTPU), ETH Zurich and University of Zurich, Switzerland
 
---

In [1]:
#@title 1. Import libraries

import os
if 'COLAB_TPU_ADDR' in os.environ:
    import jax.tools.colab_tpu
    jax.tools.colab_tpu.setup_tpu()

!pip install einops flax --quiet
import urllib.request
import gzip, shutil
import hashlib
import h5py
from six.moves.urllib.error import HTTPError 
from six.moves.urllib.error import URLError
from six.moves.urllib.request import urlretrieve
import time
from functools import partial
from einops import repeat, rearrange
import jax
import jax.numpy as jnp
import jax.random as random
from jax import vmap, pmap, jit, value_and_grad, local_device_count
from jax.lax import scan
from jax.nn import log_softmax
from jax.example_libraries import optimizers
from jax.tree_util import tree_map
import matplotlib.pyplot as plt
import numpy as np 
import tensorflow as tf
import tensorflow_datasets as tfds
from flax.jax_utils import prefetch_to_device

In [2]:
#@title 2. TFDS Data Pipeline for Spiking Heidelberg Digits (SHD)
"""
Taken from 
    - https://zenkelab.org/resources/spiking-heidelberg-datasets-shd/
    - https://github.com/google-research/vision_transformer/blob/master/vit_jax/input_pipeline.py
"""
def get_audio_dataset(cache_dir, cache_subdir):
    base_url = "https://zenkelab.org/datasets"
    response = urllib.request.urlopen("%s/md5sums.txt"%base_url)
    data = response.read() 
    lines = data.decode('utf-8').split("\n")
    file_hashes = { line.split()[1]:line.split()[0] \
                    for line in lines if len(line.split())==2 }
    files = [ "shd_train.h5.gz", "shd_test.h5.gz"]
        
    for fn in files:
        origin = "%s/%s"%(base_url,fn)
        hdf5_file_path = get_and_gunzip(origin, 
                                        fn, 
                                        md5hash=file_hashes[fn],
                                        cache_dir=cache_dir,
                                        cache_subdir=cache_subdir)
        print("Available at: %s"%hdf5_file_path)

def get_and_gunzip(origin, filename, md5hash=None, cache_dir=None, 
                   cache_subdir=None):
    gz_file_path = get_file(filename, origin, md5_hash=md5hash,
                            cache_dir=cache_dir, cache_subdir=cache_subdir)
    hdf5_file_path = gz_file_path[:-3]
    if not os.path.isfile(hdf5_file_path) or os.path.getctime(gz_file_path) > os.path.getctime(hdf5_file_path):
        print("Decompressing %s"%gz_file_path)
        with gzip.open(gz_file_path, 'r') as f_in, open(hdf5_file_path, 'wb') as f_out:
            shutil.copyfileobj(f_in, f_out)
    return hdf5_file_path

def validate_file(fpath, file_hash, algorithm='auto', chunk_size=65535):
    if (algorithm == 'sha256') or (algorithm == 'auto' and len(file_hash) == 64):
        hasher = 'sha256'
    else:
        hasher = 'md5'
    if str(_hash_file(fpath, hasher, chunk_size)) == str(file_hash):
        return True
    else:
        return False

def _hash_file(fpath, algorithm='sha256', chunk_size=65535):
    if (algorithm == 'sha256') or (algorithm == 'auto' and len(hash) == 64):
        hasher = hashlib.sha256()
    else:
        hasher = hashlib.md5()

    with open(fpath, 'rb') as fpath_file:
        for chunk in iter(lambda: fpath_file.read(chunk_size), b''):
            hasher.update(chunk)

    return hasher.hexdigest()

def get_file(fname,
             origin,
             md5_hash=None,
             file_hash=None,
             cache_subdir='datasets',
             hash_algorithm='auto',
             extract=False,
             archive_format='auto',
             cache_dir=None):
    if cache_dir is None:
        cache_dir = os.path.join(os.path.expanduser('~'), '.data-cache')
    if md5_hash is not None and file_hash is None:
        file_hash = md5_hash
        hash_algorithm = 'md5'
    datadir_base = os.path.expanduser(cache_dir)
    if not os.access(datadir_base, os.W_OK):
        datadir_base = os.path.join('/tmp', '.data-cache')
    datadir = os.path.join(datadir_base, cache_subdir)

    os.makedirs(cache_dir, exist_ok=True)
    os.makedirs(datadir, exist_ok=True)

    fpath = os.path.join(datadir, fname)

    download = False
    if os.path.exists(fpath):
        if file_hash is not None:
            if not validate_file(fpath, file_hash, algorithm=hash_algorithm):
                print('A local file was found, but it seems to be '
                      'incomplete or outdated because the ' + hash_algorithm +
                      ' file hash does not match the original value of ' + file_hash +
                      ' so we will re-download the data.')
                download = True
    else:
        download = True

    if download:
        print('Downloading data from', origin)

        error_msg = 'URL fetch failure on {}: {} -- {}'
        try:
            try:
                urlretrieve(origin, fpath)
            except HTTPError as e:
                raise Exception(error_msg.format(origin, e.code, e.msg))
            except URLError as e:
                raise Exception(error_msg.format(origin, e.errno, e.reason))
        except (Exception, KeyboardInterrupt) as e:
            if os.path.exists(fpath):
                os.remove(fpath)

    return fpath

def get_h5py_files():
    cache_dir = os.path.expanduser("/content/")
    cache_subdir = "audiospikes"
    get_audio_dataset(cache_dir, cache_subdir)
    train_shd_file = h5py.File(os.path.join(cache_dir, cache_subdir, 'shd_train.h5'), 'r') #r
    test_shd_file  = h5py.File(os.path.join(cache_dir, cache_subdir, 'shd_test.h5'), 'r')
    return train_shd_file, test_shd_file

def preprocess_h5py_files(h5py_file):
    nb_steps = 100
    nb_units = 700
    max_time = 1.4 
    num_samples = h5py_file['spikes']['times'].shape[0]

    firing_times = h5py_file['spikes']['times']
    units_fired  = h5py_file['spikes']['units']    
    labels       = h5py_file['labels']

    time_bins = np.linspace(0, max_time, num=nb_steps)
    input  = np.zeros((num_samples, nb_steps, nb_units), dtype=np.uint8)
    output = np.array(labels, dtype=np.uint8)

    for idx in range(num_samples):
        times = np.digitize(firing_times[idx], time_bins)
        units = units_fired[idx] 
        input[idx, times, units] = 1

    return input, output

num_devices = jax.local_device_count()

def shard(data):
    data['spikes'] = rearrange(data['spikes'], '(d b) t u -> d b t u', d=num_devices)
    data['labels'] = rearrange(data['labels'], '(d b) -> d b', d=num_devices)
    return data

def create_tf_dataset(input, output, batch_size, training):
    dataset = tf.data.Dataset.from_tensor_slices({'spikes':input, 
                                                  'labels':output})
    if training:
        dataset = dataset.shuffle(input.shape[0])

    dataset = dataset.cache()
    if training:
        dataset = dataset.batch(batch_size * num_devices, drop_remainder=True)
    else:
        dataset = dataset.batch(batch_size * num_devices, drop_remainder=False)

    dataset = dataset.map(shard, tf.data.AUTOTUNE)
    dataset = dataset.prefetch(tf.data.AUTOTUNE)
    return tfds.as_numpy(dataset)

def prefetch(dataset, n_prefetch):
    ds_iter = iter(dataset)
    ds_iter = map(lambda x: jax.tree_map(lambda t: np.asarray(memoryview(t)), x),
                  ds_iter)
    if n_prefetch:
        ds_iter = prefetch_to_device(ds_iter, n_prefetch)
    return ds_iter

In [3]:
#@title 3. Spiking Neural Network Models (with straight-through estimators)

@jax.custom_jvp
def gr_than(x, thr):
    return (x > thr).astype(jnp.float32)
 
@gr_than.defjvp
def gr_jvp(primals, tangents):
    x, thr = primals
    x_dot, y_dot = tangents
    primal_out = gr_than(x, thr)
    tangent_out = x_dot * 1 / (jnp.absolute(x-thr)+1)**2
    return primal_out, tangent_out

def lif_forward(state, x):
    ''' Leaky Integrate and Fire (LIF) neuron model
    '''
    inp_weight, rec_weight, out_weight = state[0]     # Static weights
    thr_rec, thr_out, alpha, kappa = state[1]         # Static neuron states
    v, z, vo, zo = state[2]                           # Dynamic neuron states

    v  = alpha * v + jnp.matmul(x, inp_weight) + jnp.matmul(z, rec_weight) - z * thr_rec
    z = gr_than(v, thr_rec)
    vo = kappa * vo + jnp.matmul(z, out_weight) - zo * thr_out
    zo = gr_than(vo, thr_out)

    return [[inp_weight, rec_weight, out_weight], [thr_rec, thr_out, alpha, kappa], [v, z, vo, zo]], [z, zo]

In [4]:
#@title 4. Training 

def train(key, batch_size, n_inp, n_rec, n_out, thr_rec, thr_out, tau_rec, 
          lr, tau_out, w_gain, n_epochs, num_prefetch):
    
    key, key_model = random.split(key, 2)
    n_devices = local_device_count()

    def net_step(net_params, x_t):
        ''' Single time step network inference (x_t => yhat_t)
        '''
        net_params, [z_rec, z_out] = lif_forward(net_params, x_t)
        return net_params, [z_rec, z_out]
    
    def predict(weights, X):
        _, net_const, net_dyn = param_initializer(key, n_inp, n_rec, n_out, thr_rec,
                                                  thr_out, tau_rec, tau_out, w_gain)
        _, [z_rec, z_out] = scan(net_step, [weights, net_const, net_dyn], X, length=100) 
        Yhat = log_softmax(jnp.sum(z_out.T, 1))
        return Yhat, [z_rec, z_out]

    v_predict = vmap(predict, in_axes=(None, 0))

    def loss(weight, X, Y):
        ''' Scan over time and return predictions
        '''
        Yhat, [z_rec, z_out] = v_predict(weight, X)
        Y = one_hot(Y, n_out)
        num_correct = jnp.sum(jnp.equal(jnp.argmax(Yhat, 1), jnp.argmax(Y, 1)))
        loss_ce = -jnp.mean(jnp.sum(Yhat * Y, axis=1, dtype=jnp.float32))
        return loss_ce, num_correct

    def param_initializer(key, n_inp, n_rec, n_out, thr_rec, thr_out, tau_rec, tau_out, w_gain):
        ''' Initialize parameters
        '''
        key_inp, key_rec, key_out, key = random.split(key, 4)
        alpha = jnp.exp(-1e-3/tau_rec) 
        kappa = jnp.exp(-1e-3/tau_out)

        inp_weight = random.normal(key_inp, [n_inp, n_rec]) * w_gain
        rec_weight = random.normal(key_rec, [n_rec, n_rec]) * w_gain
        out_weight = random.normal(key_out, [n_rec, n_out]) * w_gain

        neuron_dyn = [jnp.zeros(n_rec), jnp.zeros(n_rec), jnp.zeros(n_out), jnp.zeros(n_out)]
        net_params = [[inp_weight, rec_weight, out_weight], [thr_rec, thr_out, alpha, kappa], neuron_dyn]
        return net_params
 
    @partial(pmap, axis_name='num_devices', donate_argnums=(0))   
    def update(opt_state, X, Y):
        weight = get_params(opt_state)
        value, grads = value_and_grad(loss, has_aux=True)(weight, X, Y)
        grads = jax.lax.pmean(grads, axis_name='num_devices')
        L = jax.lax.pmean(value[0], axis_name='num_devices')
        tot_corr = jax.lax.psum(value[1], axis_name='num_devices')
        opt_state = opt_update(0, grads, opt_state)
        return opt_state, (L, tot_corr)

    def one_hot(x, n_class):
        return jnp.array(x[:, None] == jnp.arange(n_class), dtype=jnp.float32)

    @partial(pmap, axis_name='num_devices')   
    def total_correct(opt_state, X, Y):
        weight = get_params(opt_state)
        L, tot_corr = loss(weight, X, Y)
        p_tot_corr = jax.lax.psum(tot_corr, axis_name='num_devices')
        return p_tot_corr

    piecewise_lr = optimizers.piecewise_constant([1000], [lr, lr/10])
    opt_init, opt_update, get_params = optimizers.adam(step_size=piecewise_lr)
    weight, _, _ = param_initializer(key, n_inp, n_rec, n_out, thr_rec, \
                                     thr_out, tau_rec, tau_out, w_gain)
    opt_state = opt_init(weight)
    opt_state = jax.device_put_replicated(opt_state, jax.local_devices())

    # Preprocess data
    train_shd_file, test_shd_file = get_h5py_files()
    train_x, train_y = preprocess_h5py_files(train_shd_file)
    test_x, test_y = preprocess_h5py_files(test_shd_file)
    train_ds = create_tf_dataset(train_x, train_y, batch_size=batch_size, training=True)
    test_ds = create_tf_dataset(test_x, test_y, batch_size=batch_size, training=False)
    del train_x, train_y, test_x, test_y

    # Training loop
    train_loss = []; t = time.time(); 
    for epoch in range(n_epochs):
        acc = 0
        for batch_idx, b in enumerate(prefetch(train_ds, num_prefetch)):
            opt_state, (L, tot_correct) = update(opt_state, b['spikes'], b['labels'])
            train_loss.append(L)
            acc += tot_correct
        
        # Logs
        if epoch % 10 == 0:
            train_acc = 100*acc/(((batch_idx)+1)*batch_size*jax.device_count())
            print(f'Epoch: {epoch}/{n_epochs}' + 
                  f' - Loss: {jnp.mean(L):.2f}' +
                  f' - Training acc: {jnp.mean(train_acc):.2f}')
    
    t_end = time.time()
    print(f'Training completed in {(t_end-t):.2f} seconds ({(n_epochs/(t_end-t)):.2f} epoch/s)')

    # Testing loop
    test_acc_shd = 0; tot_inp = 0; tot_corr = 0
    for batch_idx, b in enumerate(test_ds):
        tot_inp  += b['spikes'].shape[0] * b['spikes'].shape[1]
        tot_corr += total_correct(opt_state, b['spikes'], b['labels'])[0]
    test_acc_shd = 100*tot_corr/tot_inp
    print(f'SHD Test Accuracy: {test_acc_shd:.1f}%')

    return train_loss, test_acc_shd, weight

In [10]:
# Hyperparameters 
seed = 42
batch_size = 1024
n_epochs = 250
lr = 2e-3
n_inp = 700
n_rec = 256
n_out = 20
thr_rec = 1
thr_out = 1
tau_rec = 20e-3
tau_out = 20e-3
w_gain = 1e-1
num_prefetch = 4

train_loss, test_acc_shd, weights = train(key=random.PRNGKey(seed), 
                                          batch_size=batch_size, 
                                          n_inp=n_inp,
                                          n_rec=n_rec,
                                          n_out=n_out,
                                          thr_rec=thr_rec,
                                          thr_out=thr_out,
                                          tau_rec=tau_rec,
                                          lr=lr,
                                          tau_out=tau_out,
                                          w_gain=w_gain,
                                          n_epochs=n_epochs,
                                          num_prefetch=num_prefetch)

Available at: /content/audiospikes/shd_train.h5
Available at: /content/audiospikes/shd_test.h5
Epoch: 0/250 - Loss: 17.07 - Training acc: 5.96
Epoch: 10/250 - Loss: 2.90 - Training acc: 5.65
Epoch: 20/250 - Loss: 2.42 - Training acc: 21.16
Epoch: 30/250 - Loss: 1.56 - Training acc: 46.11
Epoch: 40/250 - Loss: 1.03 - Training acc: 63.17
Epoch: 50/250 - Loss: 0.72 - Training acc: 73.54
Epoch: 60/250 - Loss: 0.52 - Training acc: 80.41
Epoch: 70/250 - Loss: 0.70 - Training acc: 77.75
Epoch: 80/250 - Loss: 0.42 - Training acc: 84.32
Epoch: 90/250 - Loss: 0.30 - Training acc: 88.67
Epoch: 100/250 - Loss: 0.28 - Training acc: 90.65
Epoch: 110/250 - Loss: 0.27 - Training acc: 90.54
Epoch: 120/250 - Loss: 0.19 - Training acc: 92.69
Epoch: 130/250 - Loss: 0.16 - Training acc: 94.46
Epoch: 140/250 - Loss: 0.13 - Training acc: 95.76
Epoch: 150/250 - Loss: 0.12 - Training acc: 95.44
Epoch: 160/250 - Loss: 0.13 - Training acc: 96.48
Epoch: 170/250 - Loss: 0.15 - Training acc: 96.61
Epoch: 180/250 - 