# Train the Network
This notebook demonstrates how to train the `Network` class with an activation function from `jax.nn`. The example below uses `jax.nn.swish`; pass a different `jax.nn` activation to `Network` to compare.

In [6]:
from functools import partial

# Import Network and JAX libraries
import jax
import jax.numpy as jnp
# Show arrays with 4 decimal places and suppress scientific notation for small values.
jnp.set_printoptions(suppress=True, precision=4)

from loader import load_data, load_data_onehot
from network import Network, init_network_params, cross_entropy_loss, mse_loss

In [22]:
# Data: take the first (inputs, one-hot labels) pair from each iterator returned
# by loader.load_data_onehot. flatten=True is forwarded to the loader as-is.
train_iter, test_iter = load_data_onehot(flatten=True)
(X_train, Y_train), (X_test, Y_test) = next(train_iter), next(test_iter)

# Model: layer widths (presumably input -> hidden -> output).
# Author's note: this 784-30-10 layout gets 97% with mse loss.
layer_sizes = [784, 30, 10]
net = Network(layer_sizes, activation=jax.nn.swish, loss_fn=mse_loss)

# Fixed PRNG seed, so the key handed to the parameter initialiser is the same every run.
key = jax.random.PRNGKey(0)
init_params = init_network_params(layer_sizes, key)

In [19]:
# Untrained performance: mean per-example cross-entropy of the initial parameters
# on the test data. This is a diagnostic only; `net` itself is built with mse_loss.
# Disabled alternative, kept for reference:
# jnp.sum(jax.vmap(net.evaluate, in_axes=(None, 0, 0))(init_params, X_test, Y_test)) / X_test.shape[0]
batched_ce = jax.vmap(cross_entropy_loss, in_axes=(None, 0, 0, None))  # map over examples only
jnp.mean(batched_ce(init_params, X_test, Y_test, net))

Array(2.3408, dtype=float32)

In [20]:
# Reference value for an untrained model: the cross-entropy of a uniform guess
# over the 10 output classes, i.e. -log(1/10).
uniform_prob = 1 / 10
-jnp.log(uniform_prob)

Array(2.3026, dtype=float32, weak_type=True)

## Test differentiation

In [15]:
# Differentiate the same training example two ways so the results can be compared:
# once through net.backward and once through JAX autodiff of mse_loss with
# respect to the parameters (argnums=0).
sample_x, sample_y = X_train[2], Y_train[2]

h = net.backward(init_params, sample_x, sample_y)
hdiff, hval = h[0], h[1]  # index 0 is used as the gradients, index 1 as the value

loss_and_grad = jax.value_and_grad(mse_loss, argnums=0)
aval, adiff = loss_and_grad(init_params, sample_x, sample_y, net)

In [16]:
# Leaf-by-leaf check that the two gradient pytrees agree within allclose's default tolerances.
jax.tree.map(jnp.allclose, hdiff, adiff)

{'b': [Array(True, dtype=bool),
  Array(True, dtype=bool),
  Array(True, dtype=bool)],
 'w': [Array(True, dtype=bool),
  Array(True, dtype=bool),
  Array(True, dtype=bool)]}

## Train and test

In [None]:
# Run SGD starting from the initial parameters. The test split is passed in
# alongside the training split; the logged output reports loss plus train and
# test accuracy after each epoch.
sgd_options = {"batch_size": 128, "lr": 3, "epochs": 100}
params = net.sgd(init_params, X_train, Y_train, X_test, Y_test, **sgd_options)

Epoch 1, Loss: 0.2469, Train Acc: 92.55%, Test Acc: 92.77%
Epoch 2, Loss: 0.1801, Train Acc: 94.65%, Test Acc: 94.17%
Epoch 3, Loss: 0.1438, Train Acc: 95.80%, Test Acc: 95.23%
Epoch 4, Loss: 0.1175, Train Acc: 96.58%, Test Acc: 95.82%
Epoch 5, Loss: 0.0998, Train Acc: 97.14%, Test Acc: 96.08%
Epoch 6, Loss: 0.0898, Train Acc: 97.42%, Test Acc: 96.00%
Epoch 7, Loss: 0.0817, Train Acc: 97.65%, Test Acc: 96.23%
Epoch 8, Loss: 0.0742, Train Acc: 97.83%, Test Acc: 96.34%
Epoch 9, Loss: 0.0679, Train Acc: 98.02%, Test Acc: 96.40%
Epoch 10, Loss: 0.0626, Train Acc: 98.19%, Test Acc: 96.46%
Epoch 11, Loss: 0.0580, Train Acc: 98.30%, Test Acc: 96.55%
Epoch 12, Loss: 0.0539, Train Acc: 98.42%, Test Acc: 96.60%
Epoch 13, Loss: 0.0503, Train Acc: 98.53%, Test Acc: 96.61%
Epoch 14, Loss: 0.0472, Train Acc: 98.60%, Test Acc: 96.62%
Epoch 15, Loss: 0.0444, Train Acc: 98.69%, Test Acc: 96.65%
Epoch 16, Loss: 0.0418, Train Acc: 98.77%, Test Acc: 96.59%
Epoch 17, Loss: 0.0395, Train Acc: 98.87%, Test A

In [8]:
# sgd = partial(net.sgd, X=X_train, Y=Y_train, X_test=X_test, Y_test=Y_test, 
#               batch_size=128, lr=1.0)
# jit_sgd = partial(jax.jit, static_argnames=("epochs",))(sgd)
# n_epochs = 30

In [9]:
# params = sgd(init_params, epochs=n_epochs)

In [10]:
# params = jit_sgd(init_params, epochs=n_epochs)