# Training a Simple Neural Network, with tensorflow/datasets Data Loading
https://jax.readthedocs.io/en/latest/notebooks/neural_network_with_tfds_data.html

In [1]:
import jax.numpy as jnp
from jax import grad, jit, vmap
from jax import random

## Hyperparameters

In [2]:
# A helper function to randomly initialize weights and biases
# for a dense neural network layer

def random_layer_params(m, n, key, scale=1e-2):
    """Draw random initial weights and biases for one dense layer.

    Args:
        m: Number of inputs to the layer (second dimension of the weights).
        n: Number of outputs of the layer (first dimension of the weights
            and length of the bias vector).
        key: JAX PRNG key; split in two so that weights and biases are
            sampled from separate subkeys.
        scale: Factor multiplied into the standard-normal samples.

    Returns:
        A ``(weights, biases)`` tuple with shapes ``(n, m)`` and ``(n,)``.
    """
    weight_key, bias_key = random.split(key)
    weights = scale * random.normal(weight_key, (n, m))
    biases = scale * random.normal(bias_key, (n,))
    return weights, biases

# Initialize all layers for a fully-connected neural network with sizes "sizes"
def init_network_params(sizes, key):
    """Create initial parameters for every layer of a fully-connected network.

    Args:
        sizes: Sequence of layer widths, input width first; each consecutive
            pair ``(sizes[i], sizes[i + 1])`` defines one dense layer.
        key: JAX PRNG key, split into one subkey per entry of ``sizes``.

    Returns:
        A list holding one ``random_layer_params`` result per consecutive
        pair of widths.
    """
    layer_keys = random.split(key, len(sizes))
    layer_params = []
    # zip stops at its shortest input, so the final subkey is never consumed.
    for fan_in, fan_out, layer_key in zip(sizes[:-1], sizes[1:], layer_keys):
        layer_params.append(random_layer_params(fan_in, fan_out, layer_key))
    return layer_params

# Network shape: 784 inputs, two hidden layers of 512 units, 10 outputs.
layer_sizes = [784, 512, 512, 10]
n_targets = 10

# Settings not referenced in the cells shown here; presumably consumed by the
# training code later in the notebook.
step_size = 0.01
num_epochs = 10
batch_size = 128

# Initial weights and biases for every layer, seeded with PRNGKey(0).
params = init_network_params(layer_sizes, random.PRNGKey(0))

No GPU/TPU found, falling back to CPU. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)


## Auto-batching predictions

In [3]:
from jax.scipy.special import logsumexp

def relu(x):
    """Rectified linear unit: the element-wise maximum of zero and ``x``."""
    floor = 0
    return jnp.maximum(floor, x)

def predict(params, image):
    """Compute log-probabilities for a single example.

    Args:
        params: Sequence of ``(weights, biases)`` pairs, one per layer, in
            the order the layers are applied.
        image: Input for one example; it is multiplied by the first layer's
            weights with ``jnp.dot``.

    Returns:
        The last layer's logits minus their ``logsumexp``, i.e. the
        log-softmax of the logits.
    """
    hidden_layers, output_layer = params[:-1], params[-1]

    # Every layer except the last is an affine map followed by a ReLU.
    activations = image
    for weights, biases in hidden_layers:
        activations = relu(jnp.dot(weights, activations) + biases)

    # The last layer is affine only; normalise its logits in log space.
    final_w, final_b = output_layer
    logits = jnp.dot(final_w, activations) + final_b
    return logits - logsumexp(logits)

In [9]:
# predict handles one flattened 28x28 example at a time
image_key = random.PRNGKey(1)
random_flattened_image = random.normal(image_key, shape=(28 * 28,))
preds = predict(params, random_flattened_image)
print(preds.shape)

(10,)


In [10]:
# Passing a (10, 784) batch straight to the per-example predict fails
batch_key = random.PRNGKey(1)
random_flattened_images = random.normal(batch_key, shape=(10, 28 * 28))
try:
    preds = predict(params, random_flattened_images)
except TypeError:
    # The failure is the point of this cell, so report it instead of raising.
    print("Invalid shapes!")

Invalid shapes!


In [11]:
# Use vmap to lift the per-example predict to whole batches

# None: params are shared by every example; 0: map over axis 0 of the images
predict_in_axes = (None, 0)
batched_predict = vmap(predict, in_axes=predict_in_axes)

# Called exactly like predict, but the image argument carries a batch axis
batched_preds = batched_predict(params, random_flattened_images)
print(batched_preds.shape)

(10, 10)


At this point, we have all the ingredients we need to define our neural network and train it. We’ve built an auto-batched version of `predict`, which we should be able to use in a loss function. We should be able to use `grad` to take the derivative of the loss with respect to the neural network parameters. Last, we should be able to use `jit` to speed up everything.