In [None]:
# Run this cell, if you are running the notebook on Google Colab:
!git clone https://github.com/CRC183-summer-school/school_2021.git

In [176]:
import numpy as np
# Everything we need from JAX
import jax
import jax.numpy as jnp
from jax import grad, jit, vmap, random

## Activation functions 

In [20]:
def relu(x):
    return jnp.maximum(0, x)

def sigmoid(x):
    return 1/(1+jnp.exp(-x))

## Neural network

In [178]:
# The key for our random number generator, we need this because JAX will distribute the calculation
key = random.PRNGKey(0)

def initialize_network_params(layer_sizes):
    # Get a random key for each layer
    keys = random.split(key, len(layer_sizes))
    
    parameters = []
    for i in range(len(layer_sizes)-1):
        parameters.append( [random.normal(keys[i], (layer_sizes[i+1], layer_sizes[i])), # The weight matrix
                            random.normal(keys[i], (layer_sizes[i+1],)) ]) # The bias

    return parameters

In [179]:
layer_sizes = [50, 10, 2]
neural_network_parameters = initialize_network_params(layer_sizes)

In [262]:
def predict_sample(params, sample):
    activations = sample
    
    # Feed through the network, layer by layer
    for weights, biases in params[:-1]:
        outputs = jnp.dot(weights, activations) + biases
        activations = sigmoid(outputs)

    # For the final layer, we don't want the relu
    final_weight, final_bias = params[-1]
    logits = jnp.dot(final_weight, activations) + final_bias
    
    # We want a softmax output
    explogits = jnp.exp(logits)
    return explogits/jnp.sum(explogits)

In [250]:
predict_sample(neural_network_parameters, np.zeros(50))

DeviceArray([0.11788198, 0.88211805], dtype=float32)

## The Loss function

In [251]:
def loss_for_sample(params, sample, target):
    pred = predict_sample(params, sample)
    return jnp.mean( (pred - target)**2 )

In [252]:
loss_for_sample(neural_network_parameters, np.zeros(50), np.array([1,0]))

DeviceArray(0.7781322, dtype=float32)

## The optimization

In [253]:
step_size = 1e-2
def update_for_sample(params, x, y):
    grads = grad(loss_for_sample)(params, x, y)
    return [(w - step_size * dw, b - step_size *db) for (w,b),(dw,db) in zip(params, grads)]

In [254]:
neural_network_parameters = update_for_sample(neural_network_parameters, np.zeros(50), np.array([1,0]))

## The training loop

In [255]:
layer_sizes = [50, 10, 2]
neural_network_parameters = initialize_network_params(layer_sizes)

print("Initial prediction: ")
print(predict_sample(neural_network_parameters, np.zeros(50)))

for epoch in range(500):
    # Update the neural network parameters, so that for input 'np.zeros(50)' the output becomes '[0,1]'
    neural_network_parameters = update_for_sample(neural_network_parameters, np.zeros(50), np.array([0,1]))
    
print("After training: ")
print(predict_sample(neural_network_parameters, np.zeros(50)))

Initial prediction: 
[0.49242866 0.5075713 ]
After training: 
[0.08945558 0.9105444 ]


### If we now want to predict for many samples, we **could** do this:

In [256]:
# Generate 10 random samples of size 50
all_samples = np.random.rand(10,50)
# For demonstration purposes, let's have the same target for all samples
target = np.array([1,0])

total_loss = 0
for sample in all_samples:
    total_loss += loss_for_sample(neural_network_parameters, sample, target)
    
print(total_loss)

7.9115486


#### But JAX can do this much faster! 

In [270]:
# Make a 'vectorized' version of predict_sample, using 'vmap'
# Here we tell JAX that the index to be vectorized over is the first (in_axis = 0), and that it 
#   should also put that vectorized result as the first index in the output
predict = vmap(predict_sample, in_axes=(None,0), out_axes=0)

# Now we (re-)define the loss function to use this new vectorized predict function
def loss(params, samples, targets):
    pred = predict(params, samples)
    return jnp.mean( (pred - targets)**2 )

# And the same for the update function
step_size = 1e-2
@jit
def update(params, x, y):
    grads = grad(loss)(params, x, y)
    return [(w - step_size * dw, b - step_size *db) for (w,b),(dw,db) in zip(params, grads)]

In [258]:
all_samples = np.random.rand(10,50)
targets = np.array([[0,1] for i in range(10)])

layer_sizes = [50, 10, 2]
neural_network_parameters = initialize_network_params(layer_sizes)

print("Initial prediction: ")
print(predict(neural_network_parameters, all_samples))

for epoch in range(500):
    neural_network_parameters = update(neural_network_parameters, all_samples, targets)
    
print("After training: ")
print(predict(neural_network_parameters, all_samples))

Initial prediction: 
[[0.26412225 0.73587775]
 [0.24111044 0.75888956]
 [0.13203025 0.86796975]
 [0.38874215 0.61125785]
 [0.396364   0.60363597]
 [0.18825619 0.81174374]
 [0.39966217 0.6003378 ]
 [0.07824132 0.9217587 ]
 [0.2613949  0.7386051 ]
 [0.26985535 0.7301446 ]]
After training: 
[[0.07209282 0.92790717]
 [0.05102606 0.9489739 ]
 [0.04807254 0.9519275 ]
 [0.08966338 0.9103367 ]
 [0.10765581 0.89234424]
 [0.05388276 0.9461173 ]
 [0.11235292 0.8876471 ]
 [0.01764767 0.9823523 ]
 [0.06969136 0.93030864]
 [0.06187989 0.9381201 ]]


# Improvements

Instead of defining the model parameters for this network ourselves, the module `Flax` is very useful for defining neural networks. Especially once you would like to build more complex ones (convolutional layers, for example), it is a useful module to have on top.

Additionally, `Flax` has the `flax.optim` module that conveniently has all kinds of optimizers. As an alternative, `jax.experimental` has a bunch of optimizers too!