In [None]:
# Run this cell, if you are running the notebook on Google Colab:
!git clone https://github.com/CRC183-summer-school/school_2021.git

In [176]:
# Everything we need from JAX
import jax
import jax.numpy as jnp
from jax import grad, jit, vmap, random

## Activation functions 

In [20]:
def relu(x):
    return jnp.maximum(0, x)

def sigmoid(x):
    return 1/(1+jnp.exp(-x))

## Neural network

In [178]:
# The key for our random number generator, we need this because JAX will distribute the calculation
key = random.PRNGKey(0)

def initialize_network_params(layer_sizes):
    # Get a random key for each layer
    keys = random.split(key, len(layer_sizes))
    
    parameters = []
    for i in range(len(layer_sizes)-1):
        parameters.append( [random.normal(keys[i], (layer_sizes[i+1], layer_sizes[i])), # The weight matrix
                            random.normal(keys[i], (layer_sizes[i+1],)) ]) # The bias

    return parameters

In [179]:
layer_sizes = [50, 10, 2]
neural_network_parameters = initialize_network_params(layer_sizes)

In [182]:
def predict_sample(params, sample):
    activations = sample
    
    # Feed through the network, layer by layer
    for weights, biases in params:
        outputs = jnp.dot(weights, activations) + biases
        activations = sigmoid(outputs)
        
    return activations

In [183]:
predict_sample(neural_network_parameters, np.zeros(50))

DeviceArray([0.7025618, 0.7088518], dtype=float32)

## The Loss function

In [184]:
def loss_for_sample(params, sample, target):
    pred = predict_sample(params, sample)
    return jnp.mean( (pred - target)**2 )

In [185]:
loss_for_sample(neural_network_parameters, np.zeros(50), np.array([1,0]))

DeviceArray(0.2954702, dtype=float32)

## The optimization

In [186]:
step_size = 1e-2
def update_for_sample(params, x, y):
    grads = grad(loss_for_sample)(params, x, y)
    return [(w - step_size * dw, b - step_size *db) for (w,b),(dw,db) in zip(params, grads)]

In [187]:
neural_network_parameters = update_for_sample(neural_network_parameters, np.zeros(50), np.array([1,0]))

## The training loop

In [192]:
layer_sizes = [50, 10, 2]
neural_network_parameters = initialize_network_params(layer_sizes)

print("Initial prediction: ")
print(predict_sample(neural_network_parameters, np.zeros(50)))

for epoch in range(500):
    # Update the neural network parameters, so that for input 'np.zeros(50)' the output becomes '[0,1]'
    neural_network_parameters = update_for_sample(neural_network_parameters, np.zeros(50), np.array([0,1]))
    
print("After training: ")
print(predict_sample(neural_network_parameters, np.zeros(50)))

Initial prediction: 
[0.7025618 0.7088518]
After training: 
[0.2398114 0.8432371]


### If we now want to predict for many samples, we **could** do this:

In [189]:
# Generate 10 random samples of size 50
all_samples = np.random.rand(10,50)
# For demonstration purposes, let's have the same target for all samples
target = np.array([1,0])

total_loss = 0
for sample in all_samples:
    total_loss += loss_for_sample(neural_network_parameters, sample, target)
    
print(total_loss)

5.765141


#### But JAX can do this much faster! 

In [190]:
# Make a 'vectorized' version of predict_sample, using 'vmap'
# Here we tell JAX that the index to be vectorized over is the first (in_axis = 0), and that it 
#   should also put that vectorized result as the first index in the output
predict = vmap(predict_sample, in_axes=(None,0), out_axes=0)

# Now we (re-)define the loss function to use this new vectorized predict function
def loss(params, samples, targets):
    pred = predict(params, samples)
    return jnp.mean( (pred - targets)**2 )

# And the same for the update function
step_size = 1e-2
#@jit
def update(params, x, y):
    grads = grad(loss)(params, x, y)
    return [(w - step_size * dw, b - step_size *db) for (w,b),(dw,db) in zip(params, grads)]

In [193]:
all_samples = np.random.rand(10,50)
targets = np.array([[0,1] for i in range(10)])

layer_sizes = [50, 10, 2]
neural_network_parameters = initialize_network_params(layer_sizes)

print("Initial prediction: ")
print(predict(neural_network_parameters, all_samples))

for epoch in range(500):
    neural_network_parameters = update(neural_network_parameters, all_samples, targets)
    
print("After training: ")
print(predict(neural_network_parameters, all_samples))

Initial prediction: 
[[0.7331598  0.9061516 ]
 [0.8618607  0.79652697]
 [0.62596804 0.81051296]
 [0.8639208  0.88325244]
 [0.7379976  0.9311364 ]
 [0.5626366  0.7117155 ]
 [0.8775297  0.8445624 ]
 [0.85660243 0.82235706]
 [0.86117375 0.8429931 ]
 [0.85192144 0.91011536]]
After training: 
[[0.14645067 0.9381032 ]
 [0.30671868 0.8825794 ]
 [0.12380596 0.8912052 ]
 [0.22043103 0.9405608 ]
 [0.16109806 0.95768106]
 [0.10426618 0.8038362 ]
 [0.31655475 0.9157887 ]
 [0.23097222 0.90807724]
 [0.24089806 0.9044562 ]
 [0.25836465 0.94957983]]


# Improvements

Instead of defining the model parameters for this network ourselves, the module `Flax` is very useful for defining neural networks. Especially once you would like to build more complex ones (convolutional layers, for example), it is a useful module to have on top.