# Training a MLP using backpropagation

* **part 1** we train using `MLP`. 
* **part 2** we train using `Sequential` and `Linear`. 
* **part 3** we train using `Sequential`, `Linear` and `Conv2D`. 

In [1]:
import jax
from jax import numpy as jnp
from jax import grad, jit
from jax_forward import ReLU, LogSoftmax, MLP, Module, Sequential, Linear, Conv2D, MaxPool2D, Flatten
from time import time

jax.config.update('jax_platform_name', 'cpu')

## Part 1
We first define useful parameters that we are going to use later. Note that our model is a MLP imported from `net.py`

In [2]:
# hyperparameters
layer_sizes = [784, 512, 256, 10]
step_size = 0.01
num_epochs = 5
batch_size = 128
n_targets = 10
key = jax.random.PRNGKey(0)
function = ReLU()
log_softmax = LogSoftmax()

# to create a model is as simple as that
model = MLP(layer_sizes, function, key)
params = model.params # we need to extract parameters

Helper functions, trainloader and other stuff...

In [7]:
# some helper functions
def one_hot(x, k, dtype=jnp.float32):
  """Create a one-hot encoding of x of size k."""
  return jnp.array(x[:, None] == jnp.arange(k), dtype)
  
def accuracy(params, images, targets):
  target_class = jnp.argmax(targets, axis=1)
  predicted_class = jnp.argmax(model(images,params), axis=1)
  return jnp.mean(predicted_class == target_class)

def loss(params, images, targets):
  preds = model(images, params)
  preds = log_softmax(preds)
  return -jnp.mean(preds * targets)

#@jit
def update(params, x, y):
  grads = grad(loss)(params, x, y)
  return [(w - step_size * dw, b - step_size * db)
          for (w, b), (dw, db) in zip(params, grads)]


import numpy as np
from torch.utils import data
from torchvision.datasets import MNIST

def numpy_collate(batch):
  if isinstance(batch[0], np.ndarray):
    return np.stack(batch)
  elif isinstance(batch[0], (tuple,list)):
    transposed = zip(*batch)
    return [numpy_collate(samples) for samples in transposed]
  else:
    return np.array(batch)

class NumpyLoader(data.DataLoader):
  def __init__(self, dataset, batch_size=1,
                shuffle=False, sampler=None,
                batch_sampler=None, num_workers=0,
                pin_memory=False, drop_last=False,
                timeout=0, worker_init_fn=None):
    super(self.__class__, self).__init__(dataset,
        batch_size=batch_size,
        shuffle=shuffle,
        sampler=sampler,
        batch_sampler=batch_sampler,
        num_workers=num_workers,
        collate_fn=numpy_collate,
        pin_memory=pin_memory,
        drop_last=drop_last,
        timeout=timeout,
        worker_init_fn=worker_init_fn)

class FlattenAndCast(object):
  def __call__(self, pic):
    return np.ravel(np.array(pic, dtype=jnp.float32))


class Cast(object):
  def __call__(self, pic):
    return np.array(pic, dtype=jnp.float32)

# Define our dataset, using torch datasets
mnist_dataset = MNIST('/tmp/mnist/', download=True, transform=FlattenAndCast())
training_generator = NumpyLoader(mnist_dataset, batch_size=batch_size, num_workers=0)

# Get the full train dataset (for checking accuracy while training)
train_images = np.array(mnist_dataset.train_data).reshape(len(mnist_dataset.train_data), -1)
train_labels = one_hot(np.array(mnist_dataset.train_labels), n_targets)

# Get full test dataset
mnist_dataset_test = MNIST('/tmp/mnist/', download=True, train=False)
test_images = jnp.array(mnist_dataset_test.test_data.numpy().reshape(len(mnist_dataset_test.test_data), -1), dtype=jnp.float32)
test_labels = one_hot(np.array(mnist_dataset_test.test_labels), n_targets)

Here we train using backpropagation

In [4]:
import time

for epoch in range(num_epochs):
  start_time = time.time()
  for x, y in training_generator:
    y = one_hot(y, n_targets)
    params = update(params, x, y)
  epoch_time = time.time() - start_time

  train_acc = accuracy(params, train_images, train_labels)
  test_acc = accuracy(params, test_images, test_labels)
  print("Epoch {} in {:0.2f} sec".format(epoch, epoch_time))
  print("Training set accuracy {}".format(train_acc))
  print("Test set accuracy {}".format(test_acc))

Epoch 0 in 5.78 sec
Training set accuracy 0.9044666886329651
Test set accuracy 0.911899983882904
Epoch 1 in 4.90 sec
Training set accuracy 0.9273666739463806
Test set accuracy 0.9305999875068665
Epoch 2 in 5.48 sec
Training set accuracy 0.9414166808128357
Test set accuracy 0.9399999976158142
Epoch 3 in 5.29 sec
Training set accuracy 0.9496833682060242
Test set accuracy 0.9474999904632568
Epoch 4 in 4.96 sec
Training set accuracy 0.956416666507721
Test set accuracy 0.9529999494552612


## Part 2
We first define useful parameters that we are going to use later. Note that our model is a costum model built using `Sequential` and `Linear`.

In [5]:
class myModel(Module):
    def __init__(self):
        self.layers = Sequential([
            Linear(28*28,512,key),
            function,
            Linear(512,256,key),
            function,
            Linear(256,10,key)
        ])
        self.params = self.layers.generate_parameters()

    def forward(self, data, params):
        return self.layers(data, params)


# we define the model as simple as that
model = myModel()
params = model.params

Let's train

In [6]:
for epoch in range(num_epochs):
  start_time = time.time()
  for x, y in training_generator:
    y = one_hot(y, n_targets)
    params = update(params, x, y)
  epoch_time = time.time() - start_time

  train_acc = accuracy(params, train_images, train_labels)
  test_acc = accuracy(params, test_images, test_labels)
  print("Epoch {} in {:0.2f} sec".format(epoch, epoch_time))
  print("Training set accuracy {}".format(train_acc))
  print("Test set accuracy {}".format(test_acc))

Epoch 0 in 6.04 sec
Training set accuracy 0.9094333648681641
Test set accuracy 0.9138000011444092
Epoch 1 in 5.02 sec
Training set accuracy 0.9317333698272705
Test set accuracy 0.9314999580383301
Epoch 2 in 4.85 sec
Training set accuracy 0.9435499906539917
Test set accuracy 0.9408999681472778


## Part 3
We first define useful parameters that we are going to use later. Note that our model is a costum model built using `Sequential`, `Linear` and `Conv2D` base on this [PyTorch tutorial](https://pytorch.org/tutorials/beginner/blitz/cifar10_tutorial.html).

In [4]:
from torchvision.datasets import CIFAR10
import torchvision.transforms as transforms
import torchvision

transform = transforms.Compose(
    [transforms.ToTensor(),
     transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)), Cast()])

cifar_dataset = CIFAR10('/tmp/cifar/', download=True,train=True,
                        transform=transform)
training_generator = NumpyLoader(
    cifar_dataset, batch_size=batch_size, num_workers=0)

# Get the full train dataset (for checking accuracy while training)
train_images = np.array(cifar_dataset.data)
train_labels = one_hot(np.array(cifar_dataset.targets), n_targets)

# Get full test dataset
cifar_dataset_test = CIFAR10('/tmp/cifar/', download=True, train=False,
                             transform=Cast())
test_images = jnp.array(cifar_dataset_test.data)
test_labels = one_hot(np.array(cifar_dataset.targets), n_targets)


Files already downloaded and verified
Files already downloaded and verified


Let's define the model

In [5]:
class myModel(Module):
    def __init__(self):
        self.layers = Sequential([
            Conv2D(3, 6, 5, 1, 0, jax.random.PRNGKey(0)),
            function,
            MaxPool2D(2, 2),
            Conv2D(6, 16, 5, 1, 0, jax.random.PRNGKey(1)),
            function,
            MaxPool2D(2, 2),
            Flatten(),
            Linear(16 * 5 * 5, 120, jax.random.PRNGKey(2)),
            function,
            Linear(120, 84, jax.random.PRNGKey(3)),
            function,
            Linear(84, 10, jax.random.PRNGKey(3))
        ])
        self.params = self.layers.generate_parameters()

    def forward(self, data, params):
        return self.layers(data, params)


# we define the model as simple as that
model = myModel()
params = model.params


In [8]:
for epoch in range(num_epochs):
  start_time = time()
  for x, y in training_generator:
    y = one_hot(y, n_targets)
    params = update(params, x, y)
  epoch_time = time() - start_time

  train_acc = accuracy(params, train_images, train_labels)
  test_acc = accuracy(params, test_images, test_labels)
  print("Epoch {} in {:0.2f} sec".format(epoch, epoch_time))
  print("Training set accuracy {}".format(train_acc))
  print("Test set accuracy {}".format(test_acc))

(Traced<ConcreteArray([[[[-4.48695989e-03 -7.11106742e-03 -8.55908077e-03 -4.69090883e-03
     6.87972875e-03]
   [ 3.79619142e-03  2.41013011e-03 -2.96275131e-03  1.24530196e-02
     3.84640484e-03]
   [-1.75121091e-02 -1.61187933e-03  5.70336916e-03  9.93743353e-03
     1.25568230e-02]
   [ 1.22481938e-02  5.47620980e-03  1.76855049e-03 -1.05922483e-02
    -9.52529069e-03]
   [-2.14272970e-03  7.26605626e-03 -4.15404001e-03 -1.46918548e-02
    -1.20740198e-03]]

  [[ 2.06230376e-02  4.40425985e-03  3.16884089e-03  4.46637487e-03
    -3.84080340e-03]
   [ 3.83325154e-03 -4.24770778e-03 -1.75783018e-04 -9.48355533e-03
    -5.74528612e-03]
   [ 1.09051736e-02  1.38865528e-03  2.97255628e-03  2.96859979e-03
    -7.41836103e-03]
   [-1.55019097e-03 -5.07828128e-03 -1.43137174e-02 -4.56925482e-03
     1.48569029e-02]
   [ 3.92207364e-03  8.27108975e-03  1.37943458e-02 -3.60407005e-03
    -1.04910368e-03]]

  [[-1.25242937e-02  1.29980557e-02 -3.20089384e-05 -5.15191071e-03
    -7.81618990e

TypeError: convolution requires lhs and rhs ndim to be equal, got 2 and 4.