# Training an MLP using forward and backward automatic differentiation

In the following tutorial we look at the main features of **VENI**. In particular, we see how to switch between forward-mode and backward (reverse-mode) automatic differentiation (AD). To illustrate this, we train a simple MLP.
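The two AD modes can be contrasted in plain JAX, independently of VENI. In forward mode, one Jacobian-vector product (JVP) pushes a single tangent direction through the function. That gives one directional derivative per pass. In reverse mode, one vector-Jacobian product (VJP, which is what `jax.grad` uses) returns the whole gradient of a scalar loss. The function `f` below is a toy example chosen only for illustration.

```python
import jax
import jax.numpy as jnp

# Toy scalar function f: R^3 -> R (illustrative only)
def f(w):
    return jnp.sum(jnp.tanh(w) ** 2)

w = jnp.array([0.5, -1.0, 2.0])
v = jnp.array([1.0, 0.0, 0.0])  # tangent: derivative along the first coordinate

# Forward mode: a single JVP returns f(w) and the directional derivative <grad f(w), v>
f_w, dir_deriv = jax.jvp(f, (w,), (v,))

# Reverse mode: a single VJP (wrapped by jax.grad) returns the full gradient
g = jax.grad(f)(w)

print(dir_deriv, g[0])  # the two numbers coincide
```

Recovering the full gradient in forward mode therefore needs one JVP per input dimension. Reverse mode needs a single backward pass, at the cost of storing intermediate activations.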

```python
from torchvision.datasets import MNIST
import numpy as np
import sys

# Make the `veni` package in the parent directory importable
sys.path.append('../')

import jax
from jax import grad
import jax.numpy as jnp
from veni import ReLU, Softmax, Sequential, Linear
from veni.module import Module
from veni.utils import NumpyLoader, one_hot
from veni.optim import grad_fwd
from veni.functiontools import CrossEntropy
```

We now define a simple MLP and the network loss.

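```python
class MLP(Module):
    def __init__(self):
        # 784 -> 1024 -> 10 network with a ReLU hidden layer and a softmax output
        self.layers = Sequential([
            Linear(28*28, 1024, jax.random.PRNGKey(111)),
            ReLU(),
            Linear(1024, 10, jax.random.PRNGKey(222)),
            Softmax()
        ])

        # Parameters are stored separately so they can be passed explicitly to forward()
        self.params = self.layers.generate_parameters()

    def forward(self, x, params):
        return self.layers(x, params)


model = MLP()
params = model.params

# Loss: cross-entropy between one-hot labels and predicted class probabilities
def loss(params, x, y):
    y_hat = model(x, params)
    return CrossEntropy(y, y_hat)


# Accuracy: fraction of samples whose arg-max prediction matches the integer label
def accuracy(y, y_hat):
    model_predictions = jnp.argmax(y_hat, axis=1)
    return jnp.mean(y == model_predictions)
```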
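The loss can be written out in pure JAX for reference. This is a minimal sketch that assumes `CrossEntropy(y, y_hat)` computes the batch mean of -Σ_k y_k log ŷ_k from one-hot labels and softmax probabilities. VENI's exact reduction and numerical safeguards may differ. `cross_entropy_sketch` is a hypothetical name.

```python
import jax.numpy as jnp

def cross_entropy_sketch(y_one_hot, y_prob, eps=1e-12):
    # Hypothetical stand-in for veni's CrossEntropy (assumed: mean over the batch).
    # eps guards against log(0) when a predicted probability underflows.
    per_sample = -jnp.sum(y_one_hot * jnp.log(y_prob + eps), axis=1)
    return jnp.mean(per_sample)

y = jnp.array([[0.0, 1.0], [1.0, 0.0]])   # one-hot labels
p = jnp.array([[0.2, 0.8], [0.6, 0.4]])   # softmax outputs
loss_value = cross_entropy_sketch(y, p)   # equals -(log 0.8 + log 0.6) / 2
```

When the last layer is a `Softmax`, taking the log of probabilities can underflow. A common alternative is to drop the final softmax and apply `jax.nn.log_softmax` to the logits.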


Let's download the MNIST dataset and wrap it in data loaders.

```python
# Flatten each 28x28 image into a 784-vector of float32 values in [0, 1]
class FlattenAndNormalize(object):
    def __call__(self, pic):
        return np.array(np.ravel(pic), dtype=np.float32) / 255

# Define our datasets, using torchvision datasets
mnist_train = MNIST('/tmp/mnist/', download=True,
                    transform=FlattenAndNormalize(), train=True)
training_generator = NumpyLoader(
    mnist_train, batch_size=24, num_workers=0)

mnist_test = MNIST('/tmp/mnist/', download=True,
                   transform=FlattenAndNormalize(), train=False)
testing_generator = NumpyLoader(
    mnist_test, batch_size=24, num_workers=0)
```
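The preprocessing and label encoding can be checked without downloading MNIST. The sketch below runs the dataset transform above on a fake image. It also shows what `one_hot(y, 10)` is assumed to produce: one row per label, with a single 1 at the label's index. `one_hot_sketch` is a hypothetical stand-in, not VENI's code. JAX also provides `jax.nn.one_hot` for the same purpose.

```python
import numpy as np

# Fake 28x28 uint8 "image" standing in for an MNIST digit
img = np.random.default_rng(0).integers(0, 256, size=(28, 28)).astype(np.uint8)

# Same preprocessing as the dataset transform: flatten, cast to float32, scale to [0, 1]
x = np.array(np.ravel(img), dtype=np.float32) / 255

def one_hot_sketch(y, num_classes):
    # Hypothetical stand-in for veni.utils.one_hot: row i has a single 1 at column y[i]
    y = np.asarray(y)
    return (y[:, None] == np.arange(num_classes)).astype(np.float32)

labels = np.array([3, 0, 9])
print(x.shape, x.dtype)          # (784,) float32
print(one_hot_sketch(labels, 10))
```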

We define an update function that computes the gradients with either forward-mode or backward (reverse-mode) AD and then applies the optimizer.

```python
def grad_bwd(params, x, y, loss, key):
    # Reverse-mode (backward) AD; `key` is unused here
    grads = grad(loss)(params, x, y)
    return grads


def update(params, x, y, loss, optimizer, key, grad_type='fwd'):
    if grad_type == 'fwd':
        # Forward-mode AD gradient computed by VENI's grad_fwd
        grads = grad_fwd(params, x, y, loss, 1)
    elif grad_type == 'bwd':
        grads = grad_bwd(params, x, y, loss, key)
    else:
        raise ValueError(f"Invalid grad_type, expected 'fwd' or 'bwd', got {grad_type}")

    return optimizer(params, grads)
```
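The source of `grad_fwd` is not shown here, so this tutorial does not say exactly how VENI computes its forward-mode gradient. One well-known way to train with forward mode only is the forward-gradient estimator described by Baydin et al. (2022, "Gradients without Backpropagation"). It samples a random tangent `v`, gets the directional derivative `d = <∇L, v>` from a single JVP, and uses `d * v` as an unbiased gradient estimate. The following is a minimal sketch of that technique, not VENI's implementation, and `forward_gradient` is a hypothetical name.

```python
import jax
import jax.numpy as jnp

def forward_gradient(loss_fn, params, key):
    # Hypothetical helper, not VENI's API. Works for any pytree of parameters,
    # e.g. a list of (w, b) tuples as used in this notebook.
    leaves, treedef = jax.tree_util.tree_flatten(params)
    keys = jax.random.split(key, len(leaves))
    # Random tangent v ~ N(0, I) with the same structure as params
    v = jax.tree_util.tree_unflatten(
        treedef,
        [jax.random.normal(k, leaf.shape, leaf.dtype) for k, leaf in zip(keys, leaves)],
    )
    # One forward pass with a JVP gives the directional derivative d = <grad L, v>
    _, d = jax.jvp(loss_fn, (params,), (v,))
    # d * v is an unbiased estimate of grad L, because E[v v^T] = I
    return jax.tree_util.tree_map(lambda t: d * t, v)

# Usage on a toy quadratic whose exact gradient at w = [1, 2] is [2, 4]
f = lambda w: jnp.sum(w ** 2)
w = jnp.array([1.0, 2.0])
keys = jax.random.split(jax.random.PRNGKey(0), 20_000)
estimates = jax.vmap(lambda k: forward_gradient(f, w, k))(keys)
print(estimates.mean(axis=0))   # approximately [2., 4.]
```

Each single estimate is noisy. Averaging many of them recovers the true gradient, and in training each SGD step simply uses one estimate.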


For this tutorial we use a vanilla SGD optimizer. VENI currently implements Adam and SGD optimizers. They are mainly intended for simple benchmarking and are not designed to be efficient or fast.
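Both update rules can be written for an arbitrary parameter pytree with `jax.tree_util.tree_map`. Here `sgd_update` applies the same rule as the notebook's optimizer, `w ← w - η·dw`. `adam_init`/`adam_update` follow the standard Adam algorithm. These are independent sketches with hypothetical names. They are not VENI's optimizer API, whose signatures may differ.

```python
import jax.numpy as jnp
from jax.tree_util import tree_map

def sgd_update(params, grads, eta=2e-4):
    # Same rule as the notebook's optimizer, written for any pytree of parameters
    return tree_map(lambda p, g: p - eta * g, params, grads)

def adam_init(params):
    zeros = tree_map(jnp.zeros_like, params)
    return {"m": zeros, "v": zeros, "t": 0}

def adam_update(params, grads, state, eta=1e-3, b1=0.9, b2=0.999, eps=1e-8):
    t = state["t"] + 1
    # Exponential moving averages of the gradient and of its square
    m = tree_map(lambda m_i, g: b1 * m_i + (1 - b1) * g, state["m"], grads)
    v = tree_map(lambda v_i, g: b2 * v_i + (1 - b2) * g ** 2, state["v"], grads)

    def step(p, m_i, v_i):
        m_hat = m_i / (1 - b1 ** t)   # bias correction
        v_hat = v_i / (1 - b2 ** t)
        return p - eta * m_hat / (jnp.sqrt(v_hat) + eps)

    return tree_map(step, params, m, v), {"m": m, "v": v, "t": t}

# Usage with the notebook's parameter layout: a list of (weight, bias) tuples
params = [(jnp.ones((2, 2)), jnp.zeros(2))]
grads = [(jnp.full((2, 2), 0.5), jnp.array([-1.0, 2.0]))]
state = adam_init(params)
params, state = adam_update(params, grads, state)
```

On the first Adam step the bias-corrected ratio `m_hat / sqrt(v_hat)` is about `sign(g)`. Each parameter therefore moves by roughly `eta` against its gradient, whatever the gradient's magnitude.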

Now we can train the network.

```python
# Vanilla SGD: w <- w - eta * dw, applied to every (weight, bias) pair
def optimizer(params, grads, eta=2e-4):
    return [(w - eta * dw, b - eta * db) for (w, b), (dw, db) in zip(params, grads)]

key = jax.random.PRNGKey(111)
for epoch in range(5):
    # Training: one parameter update per mini-batch
    for x, y in training_generator:
        key, subkey = jax.random.split(key)  # fresh PRNG key for every step
        one_hot_label = one_hot(y, 10)
        params = update(params, x, one_hot_label, loss, optimizer, subkey, grad_type='fwd')

    # Evaluation: batch accuracies weighted by batch size give the exact test accuracy
    acc = 0
    count = 0
    for x, y in testing_generator:
        y_hat = model(x, params)
        acc += accuracy(y, y_hat) * x.shape[0]
        count += x.shape[0]

    print(f"Epoch [{epoch}] / Accuracy [{acc / count}]")
```


```
Epoch [0] / Accuracy [0.541700005531311]
Epoch [1] / Accuracy [0.6349999904632568]
Epoch [2] / Accuracy [0.6413000226020813]
Epoch [3] / Accuracy [0.666100025177002]
Epoch [4] / Accuracy [0.651199996471405]
```
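The evaluation loop above weights each batch accuracy by `x.shape[0]` because the last test batch is smaller. The 10,000 MNIST test images in batches of 24 give 416 full batches plus a final batch of 16. A plain mean of the batch accuracies would give that last batch too much weight. The check below uses synthetic labels and predictions (roughly 60% correct, an arbitrary choice) to show that the weighted average equals the exact overall accuracy.

```python
import numpy as np

rng = np.random.default_rng(0)
n, batch_size = 10_000, 24
labels = rng.integers(0, 10, size=n)
# Synthetic predictions: correct with probability ~0.6, otherwise off by one class
preds = np.where(rng.random(n) < 0.6, labels, (labels + 1) % 10)

acc_sum, count, batch_accs = 0.0, 0, []
for start in range(0, n, batch_size):
    y, p = labels[start:start + batch_size], preds[start:start + batch_size]
    batch_acc = np.mean(y == p)
    batch_accs.append(batch_acc)
    acc_sum += batch_acc * len(y)   # weight by batch size, as in the evaluation loop
    count += len(y)

weighted = acc_sum / count
exact = np.mean(labels == preds)
unweighted = np.mean(batch_accs)    # slightly biased by the short final batch
print(len(batch_accs), weighted, exact, unweighted)
```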


You can now switch freely between forward and backward AD by changing the `grad_type` argument of `update`, or experiment with hybrid approaches.
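To see that the two modes compute the same exact gradient, here is a minimal sketch independent of VENI. It uses a tiny softmax-regression model with hypothetical sizes (4 inputs, 3 classes) and switches on a `mode` flag, as `update` does with `grad_type`. It calls `jax.jacfwd` for forward mode and `jax.grad` for reverse mode. Exact forward mode costs one JVP per parameter entry. For this tutorial's MLP that is 784·1024 + 1024 + 1024·10 + 10 = 814,090 parameters, assuming each `Linear` layer has a bias, so forward-mode training in practice tends to rely on cheaper estimators.

```python
import jax
import jax.numpy as jnp

def loss(params, x, y):
    # Softmax regression with cross-entropy on one-hot labels
    w, b = params
    logits = x @ w + b
    return -jnp.mean(jnp.sum(y * jax.nn.log_softmax(logits), axis=1))

def gradient(params, x, y, mode):
    if mode == "fwd":
        return jax.jacfwd(loss)(params, x, y)   # forward mode: one JVP per parameter entry
    if mode == "bwd":
        return jax.grad(loss)(params, x, y)     # reverse mode: one VJP for the whole gradient
    raise ValueError(f"expected 'fwd' or 'bwd', got {mode}")

key_w, key_x = jax.random.split(jax.random.PRNGKey(0))
params = (0.1 * jax.random.normal(key_w, (4, 3)), jnp.zeros(3))
x = jax.random.normal(key_x, (8, 4))
y = jax.nn.one_hot(jnp.arange(8) % 3, 3)

g_fwd = gradient(params, x, y, "fwd")
g_bwd = gradient(params, x, y, "bwd")
```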