# Training a MLP using backpropagation and forward differentiation

In **part 1** we train using backpropagation, while in **part 2** we train using forward ad.

In [1]:
import jax.numpy as jnp
import jax
jax.config.update('jax_platform_name', 'cpu')
import os, sys
sys.path.append(os.path.join(os.getcwd(),'jax_forward'))
from net import MLP
from functional import relu, log_softmax
from jax import grad, jit

## Part 1
We first define useful parameters that we are going to use later. Note that our model is a MLP imported from `net.py`

In [2]:
# hyperparameters
layer_sizes = [784, 512, 256, 10]
step_size = 0.01
num_epochs = 1
batch_size = 128
n_targets = 10
key = jax.random.PRNGKey(0)

# to create a model is as simple as that
model = MLP(layer_sizes, relu, key)
params = model.params # we need to extract parameters

Helper functions, trainloader and other stuff...

In [3]:
# some helper functions
def one_hot(x, k, dtype=jnp.float32):
  """Create a one-hot encoding of x of size k."""
  return jnp.array(x[:, None] == jnp.arange(k), dtype)
  
def accuracy(params, images, targets):
  target_class = jnp.argmax(targets, axis=1)
  predicted_class = jnp.argmax(model(params, images), axis=1)
  return jnp.mean(predicted_class == target_class)

def loss(params, images, targets):
  preds = model(params, images)
  preds = log_softmax(preds)
  return -jnp.mean(preds * targets)

@jit
def update(params, x, y):
  grads = grad(loss)(params, x, y)
  return [(w - step_size * dw, b - step_size * db)
          for (w, b), (dw, db) in zip(params, grads)]


import numpy as np
from torch.utils import data
from torchvision.datasets import MNIST

def numpy_collate(batch):
  if isinstance(batch[0], np.ndarray):
    return np.stack(batch)
  elif isinstance(batch[0], (tuple,list)):
    transposed = zip(*batch)
    return [numpy_collate(samples) for samples in transposed]
  else:
    return np.array(batch)

class NumpyLoader(data.DataLoader):
  def __init__(self, dataset, batch_size=1,
                shuffle=False, sampler=None,
                batch_sampler=None, num_workers=0,
                pin_memory=False, drop_last=False,
                timeout=0, worker_init_fn=None):
    super(self.__class__, self).__init__(dataset,
        batch_size=batch_size,
        shuffle=shuffle,
        sampler=sampler,
        batch_sampler=batch_sampler,
        num_workers=num_workers,
        collate_fn=numpy_collate,
        timeout=timeout,
        worker_init_fn=worker_init_fn)

class FlattenAndCast(object):
  def __call__(self, pic):
    return np.ravel(np.array(pic, dtype=jnp.float32))

# Define our dataset, using torch datasets
mnist_dataset = MNIST('/tmp/mnist/', download=True, transform=FlattenAndCast())
training_generator = NumpyLoader(mnist_dataset, batch_size=batch_size, num_workers=0)

# Get the full train dataset (for checking accuracy while training)
train_images = np.array(mnist_dataset.train_data).reshape(len(mnist_dataset.train_data), -1)
train_labels = one_hot(np.array(mnist_dataset.train_labels), n_targets)

# Get full test dataset
mnist_dataset_test = MNIST('/tmp/mnist/', download=True, train=False)
test_images = jnp.array(mnist_dataset_test.test_data.numpy().reshape(len(mnist_dataset_test.test_data), -1), dtype=jnp.float32)
test_labels = one_hot(np.array(mnist_dataset_test.test_labels), n_targets)



Here we train using backpropagation

In [4]:
import time

for epoch in range(num_epochs):
  start_time = time.time()
  for x, y in training_generator:
    y = one_hot(y, n_targets)
    params = update(params, x, y)
  epoch_time = time.time() - start_time

  train_acc = accuracy(params, train_images, train_labels)
  test_acc = accuracy(params, test_images, test_labels)
  print("Epoch {} in {:0.2f} sec".format(epoch, epoch_time))
  print("Training set accuracy {}".format(train_acc))
  print("Test set accuracy {}".format(test_acc))

Epoch 0 in 7.63 sec
Training set accuracy 0.9044666886329651
Test set accuracy 0.911899983882904


## Part 2

Now lets's train using forward AD. Note that we need to change a bit the function `update`

In [5]:
from jax import jvp
model = MLP(layer_sizes,relu,key)

def get_vector(key, params):
  v_shaped = []
  for w,b in params:
    key, subkey = jax.random.split(key)
    v_w = jax.random.normal(key, shape=w.shape)
    v_b = jax.random.normal(subkey, shape=b.shape)
    v_shaped.append((v_w, v_b))
  return v_shaped

@jit
def update(params, key, x, y):
  v = get_vector(key, params)
  _, dd = jvp(lambda params: loss(params, x, y), (params,), (v,))
  step = step_size * dd
  return [(w - step * dw, b - step * db)
          for (w, b), (dw, db) in zip(params, v)]

Let's train again

In [6]:
step_size = 2e-5

for epoch in range(num_epochs):
  start_time = time.time()
  for x, y in training_generator:
    key, _ = jax.random.split(key)
    y = one_hot(y, n_targets)
    params = update(params, key, x, y)
  epoch_time = time.time() - start_time

  train_acc = accuracy(params, train_images, train_labels)
  test_acc = accuracy(params, test_images, test_labels)
  print("Epoch {} in {:0.2f} sec".format(epoch, epoch_time))
  print("Training set accuracy {}".format(train_acc))
  print("Test set accuracy {}".format(test_acc))


Epoch 0 in 15.56 sec
Training set accuracy 0.9088000059127808
Test set accuracy 0.9150999784469604


In [7]:
#test on custom module class

In [34]:
from net import *
from module import *
from function import *

class myModel(Module):
    def __init__(self):
        self.layers = Sequential([
            Linear(28*28,512,key),
            ReLU(),
            Linear(512,256,key),
            ReLU(),
            Linear(256,10,key)

        ])
        self.params = self.layers.generate_parameters()

    def forward(self, params, data):
        return self.layers(params,data)


In [35]:
model = myModel()

In [67]:
step_size = 2e-3
params = model.params

for epoch in range(5):
  start_time = time.time()
  for x, y in training_generator:
    key, _ = jax.random.split(key)
    y = one_hot(y, n_targets)
    params = update(params, key, x, y)
  epoch_time = time.time() - start_time

  train_acc = accuracy(params, train_images, train_labels)
  test_acc = accuracy(params, test_images, test_labels)
  print("Epoch {} in {:0.2f} sec".format(epoch, epoch_time))
  print("Training set accuracy {}".format(train_acc))
  print("Test set accuracy {}".format(test_acc))

Epoch 0 in 10.23 sec
Training set accuracy 0.3422499895095825
Test set accuracy 0.33889999985694885
Epoch 1 in 9.94 sec
Training set accuracy 0.5627999901771545
Test set accuracy 0.5673999786376953
Epoch 2 in 13.68 sec
Training set accuracy 0.6536666750907898
Test set accuracy 0.6548999547958374
Epoch 3 in 9.81 sec
Training set accuracy 0.7006833553314209
Test set accuracy 0.708899974822998
Epoch 4 in 9.34 sec
Training set accuracy 0.7205833196640015
Test set accuracy 0.7286999821662903
