# Event-based training of NNs in Jax

In this notebook, we will explore how to write event based software to do gradient-based learning with neural networks (NN) in [jax](https://github.com/google/jax/).

In [15]:
import jax 
import jax.numpy as jnp
from jax import vmap, grad, jit
from jax.scipy.special import logsumexp
import matplotlib.pyplot as plt
from functools import partial
import dataclasses
from typing import Generic, Any, Union, TypeVar, Tuple
import tree_math
import numpy as np

from jax import config 
config.update("jax_debug_nans", True)
print("Using jax", jax.__version__)
print(jax.devices())

Using jax 0.5.1
[CpuDevice(id=0)]


In [19]:
a = jnp.array([[5, 4, 2, 3, 1],
              [4.8870957e-03, 6.7979717e-03, 5.1129041e-03, 3.2020281e-03, 4.9999999e-03]])

print(jnp.arange(a.shape[0])[:, None])
sort_idx = jnp.argsort(a, axis=-1)
print(sort_idx)
input = a[jnp.arange(a.shape[0])[:, None], sort_idx]
print(input)

print(input.devices())
print(input.sharding)
jax.debug.visualize_array_sharding(input)

[[0]
 [1]]
[[4 2 3 1 0]
 [3 0 4 2 1]]
[[1.0000000e+00 2.0000000e+00 3.0000000e+00 4.0000000e+00 5.0000000e+00]
 [3.2020281e-03 4.8870957e-03 4.9999999e-03 5.1129041e-03 6.7979717e-03]]
{CpuDevice(id=0)}
SingleDeviceSharding(device=CpuDevice(id=0), memory_kind=unpinned_host)


In [6]:
@dataclasses.dataclass
@tree_math.struct
class Neuron:
    value: jnp.ndarray
    

@dataclasses.dataclass
@tree_math.struct
class InputQueue:
    neuron: Neuron
    head: int = 0

    @property
    def is_empty(self) -> bool:
        return self.head == len(self.spikes.time)

    def peek(self) -> Neuron:
        return self.spikes[self.head]

    def pop(self) -> Neuron:
        spike = self.spikes[self.head]
        self.head += 1
        return spike

In [7]:
layer_sizes = [784, 512, 512, 10]
step_size = 0.01
num_epochs = 8
batch_size = 128
n_targets = 10

def random_layer_params(m, n, key, scale=1e-2):
  w_key, b_key = jax.random.split(key)
  return scale * jax.random.normal(w_key, (n, m)), scale * jax.random.normal(b_key, (n,))

# Initialize all layers for a fully-connected neural network with sizes "sizes"
def init_network_params(sizes, key):
  keys = jax.random.split(key, len(sizes))
  return [random_layer_params(m, n, k) for m, n, k in zip(sizes[:-1], sizes[1:], keys)]

params = init_network_params(layer_sizes, jax.random.key(0))
[print(len(w), len(b)) for w, b in params[:]]

512 512
512 512
10 10


[None, None, None]

In [8]:
def relu(x):
  return jnp.maximum(0, x)

def predict(params, image):
  # per-example predictions
  activations = image
  for w, b in params[:-1]:
    outputs = jnp.dot(w, activations) + b
    activations = relu(outputs)
    
  final_w, final_b = params[-1]
  logits = jnp.dot(final_w, activations) + final_b
  return logits - logsumexp(logits)

# This works on single examples
random_flattened_image = jax.random.normal(jax.random.key(1), (28 * 28,))
preds = predict(params, random_flattened_image)
print(preds.shape)

# Make a batched version of the `predict` function
batched_predict = vmap(predict, in_axes=(None, 0))

# `batched_predict` has the same call signature as `predict`
random_flattened_images = jax.random.normal(jax.random.key(1), (10, 28 * 28))
batched_preds = batched_predict(params, random_flattened_images)
print(batched_preds.shape)

(10,)
(10, 10)


In [9]:
def one_hot(x, k, dtype=jnp.float32):
  """Create a one-hot encoding of x of size k."""
  return jnp.array(x[:, None] == jnp.arange(k), dtype)

def accuracy(params, images, targets):
  target_class = jnp.argmax(targets, axis=1)
  predicted_class = jnp.argmax(batched_predict(params, images), axis=1)
  return jnp.mean(predicted_class == target_class)

def loss(params, images, targets):
  preds = batched_predict(params, images)
  return -jnp.mean(preds * targets)

@jit
def update(params, x, y):
  grads = grad(loss)(params, x, y)
  return [(w - step_size * dw, b - step_size * db)
          for (w, b), (dw, db) in zip(params, grads)]

In [10]:
!pip install torch torchvision



In [11]:
from jax.tree_util import tree_map
from torch.utils import data
from torchvision.datasets import MNIST

def numpy_collate(batch):
  return tree_map(np.asarray, data.default_collate(batch))

class NumpyLoader(data.DataLoader):
  def __init__(self, dataset, batch_size=1,
                shuffle=False, sampler=None,
                batch_sampler=None, num_workers=0,
                pin_memory=False, drop_last=False,
                timeout=0, worker_init_fn=None):
    super(self.__class__, self).__init__(dataset,
        batch_size=batch_size,
        shuffle=shuffle,
        sampler=sampler,
        batch_sampler=batch_sampler,
        num_workers=num_workers,
        collate_fn=numpy_collate,
        pin_memory=pin_memory,
        drop_last=drop_last,
        timeout=timeout,
        worker_init_fn=worker_init_fn)

class FlattenAndCast(object):
  def __call__(self, pic):
    return np.ravel(np.array(pic, dtype=jnp.float32))

# Define our dataset, using torch datasets
mnist_dataset = MNIST('/tmp/mnist/', download=True, transform=FlattenAndCast())
training_generator = NumpyLoader(mnist_dataset, batch_size=batch_size, num_workers=0)

In [13]:
# Get the full train dataset (for checking accuracy while training)
train_images = np.array(mnist_dataset.train_data).reshape(len(mnist_dataset.train_data), -1)
train_labels = one_hot(np.array(mnist_dataset.train_labels), n_targets)

# Get full test dataset
mnist_dataset_test = MNIST('/tmp/mnist/', download=True, train=False)
test_images = jnp.array(mnist_dataset_test.test_data.numpy().reshape(len(mnist_dataset_test.test_data), -1), dtype=jnp.float32)
test_labels = one_hot(np.array(mnist_dataset_test.test_labels), n_targets)

  train_images = np.array(mnist_dataset.train_data).reshape(len(mnist_dataset.train_data), -1)
  train_labels = one_hot(np.array(mnist_dataset.train_labels), n_targets)
  test_labels = one_hot(np.array(mnist_dataset_test.test_labels), n_targets)


In [14]:
import time

for epoch in range(num_epochs):
  start_time = time.time()
  for x, y in training_generator:
    y = one_hot(y, n_targets)
    params = update(params, x, y)
  epoch_time = time.time() - start_time

  train_acc = accuracy(params, train_images, train_labels)
  test_acc = accuracy(params, test_images, test_labels)
  print("Epoch {} in {:0.2f} sec".format(epoch, epoch_time))
  print("Training set accuracy {}".format(train_acc))
  print("Test set accuracy {}".format(test_acc))

Epoch 0 in 3.40 sec
Training set accuracy 0.9134333729743958
Test set accuracy 0.9165999889373779
Epoch 1 in 3.18 sec
Training set accuracy 0.9359000325202942
Test set accuracy 0.9353999495506287
Epoch 2 in 3.24 sec
Training set accuracy 0.9474666714668274
Test set accuracy 0.9460999965667725
Epoch 3 in 3.17 sec
Training set accuracy 0.9559000134468079
Test set accuracy 0.9527999758720398
Epoch 4 in 3.24 sec
Training set accuracy 0.962066650390625
Test set accuracy 0.9575999975204468
Epoch 5 in 3.32 sec
Training set accuracy 0.9663166999816895
Test set accuracy 0.9606999754905701
Epoch 6 in 3.06 sec
Training set accuracy 0.9700333476066589
Test set accuracy 0.9633999466896057
Epoch 7 in 3.03 sec
Training set accuracy 0.9731500148773193
Test set accuracy 0.9662999510765076
