# A few more jax tricks (compile and batches)

[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/FlorianMarquardt/machine-learning-for-physicists/blob/master/2024/03_MoreJaxTricks.ipynb)

Example code for the lecture series "Machine Learning for Physicists" by Florian Marquardt

Lecture 3

See https://machine-learning-for-physicists.org and the current course website linked there!

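This notebook builds a little network, evaluates the gradient of the cost function using jax, and applies one step of gradient descent. It then shows how to speed this up with just-in-time compilation (jax "jit") and how to process whole batches of training samples (jax "vmap").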

MIT License.

In [31]:
import jax.numpy as jnp
from jax import grad, value_and_grad, jit, vmap

In [8]:
# now define a simple neural network
# (arbitrary number of layers, but relu activation at each layer)
def network(parameters,x):
  """
  Evaluate network.

  parameters=[[weights1,biases1],[weights2,biases2],...]
  x=input vector
  """
  for weights,biases in parameters:
    # weights has shape (neurons_lower_layer,neurons_upper_layer),
    # biases has shape (neurons_upper_layer,)
    z=jnp.dot(x,weights)+biases
    x=(z>0)*z # relu activation
  return x

In [9]:
# our network has structure 2 (input) -- 3 -- 1 (output)
weights1=jnp.array([[0.1,0.3,0.5],[-0.4,0.2,0.8]]) # shape (2,3)
biases1=jnp.array([0.1,-0.2,0.3]) # shape (3,)
weights2=jnp.array([[0.2],[0.7],[-0.5]]) # shape (3,1)
biases2=jnp.array([0.2]) # shape (1,)

params=[[weights1,biases1],[weights2,biases2]]

In [10]:
# apply network to test input
x=jnp.array([0.9,-0.5])
print(network(params,x))

[0.10300001]
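As a quick sanity check of this number, here is the same forward pass written out layer by layer for this input. The names z1, h and z2 are only illustrative; the parameters are the ones defined above, repeated so the cell runs on its own:

```python
import jax.numpy as jnp

# parameters and input from above
weights1 = jnp.array([[0.1, 0.3, 0.5], [-0.4, 0.2, 0.8]])  # shape (2,3)
biases1 = jnp.array([0.1, -0.2, 0.3])                      # shape (3,)
weights2 = jnp.array([[0.2], [0.7], [-0.5]])               # shape (3,1)
biases2 = jnp.array([0.2])                                 # shape (1,)
x = jnp.array([0.9, -0.5])

# hidden layer: x.W1 = [0.29, 0.17, 0.05], plus b1 gives z1 = [0.39, -0.03, 0.35]
z1 = jnp.dot(x, weights1) + biases1
# relu switches off the second hidden neuron: h = [0.39, 0.0, 0.35]
h = jnp.maximum(z1, 0.0)
# output: 0.39*0.2 + 0.0*0.7 + 0.35*(-0.5) + 0.2 = 0.103 (positive, so relu keeps it)
z2 = jnp.dot(h, weights2) + biases2
y = jnp.maximum(z2, 0.0)
print(y)
```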


In [14]:
# define a cost function (here: quadratic deviation)
def cost(params,x,y_target):
  return jnp.sum( ( network(params,x) - y_target )**2 )
# note: we would divide by the batch size jnp.shape(x)[0] if we wanted to average
# over a batch (but right now we do not use batches)

In [12]:
cost(params,x,1.0)

Array(0.804609, dtype=float32)

In [17]:
%%timeit
cost(params,x,1.0)

194 µs ± 25 µs per loop (mean ± std. dev. of 7 runs, 10000 loops each)


## Compile using jax "jit" (just-in-time compilation)

In [20]:
# define a cost function (here: quadratic deviation)
# compile it!
@jit
def compiled_cost(params,x,y_target):
  return jnp.sum( ( network(params,x) - y_target )**2 )
# note: we would divide by the batch size jnp.shape(x)[0] if we wanted to average
# over a batch (but right now we do not use batches)

In [21]:
%%timeit
compiled_cost(params,x,1.0)

10.8 µs ± 2.17 µs per loop (mean ± std. dev. of 7 runs, 100000 loops each)
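Two things are worth keeping in mind when timing jax code, illustrated in the minimal sketch below. First, jax dispatches work asynchronously, so a call can return before the result has actually been computed; calling .block_until_ready() on the result makes the measurement wait for it. Second, the very first call of a jitted function includes tracing and compilation, which is then cached and reused for later calls with the same argument shapes and dtypes. Network, parameters and input are repeated from above so that the cell is self-contained; the printed times depend on the machine:

```python
import time
import jax.numpy as jnp
from jax import jit

def network(parameters, x):
    # same network as above (relu at every layer)
    for weights, biases in parameters:
        z = jnp.dot(x, weights) + biases
        x = (z > 0) * z
    return x

@jit
def compiled_cost(params, x, y_target):
    return jnp.sum((network(params, x) - y_target) ** 2)

params = [[jnp.array([[0.1, 0.3, 0.5], [-0.4, 0.2, 0.8]]), jnp.array([0.1, -0.2, 0.3])],
          [jnp.array([[0.2], [0.7], [-0.5]]), jnp.array([0.2])]]
x = jnp.array([0.9, -0.5])

# first call: jax traces the Python function and compiles it
t0 = time.perf_counter()
first = compiled_cost(params, x, 1.0).block_until_ready()
t1 = time.perf_counter()
# second call with the same shapes/dtypes: reuses the cached compiled version
second = compiled_cost(params, x, 1.0).block_until_ready()
t2 = time.perf_counter()

print(f"first call (incl. compilation): {1e6 * (t1 - t0):.1f} µs")
print(f"second call: {1e6 * (t2 - t1):.1f} µs")
```

In a %%timeit cell, the same idea reads compiled_cost(params,x,1.0).block_until_ready().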


In [39]:
# now apply jax autodifferentiation to get the
# gradient of the cost function with respect to the params:
# note: it is enough to apply jit once, to the outermost (final) function!
grad_cost=jit(grad(cost,argnums=0)) # argnums=0 means first argument, i.e. params

In [40]:
%%timeit
grad_cost(params,x,1.0)

15.9 µs ± 2.69 µs per loop (mean ± std. dev. of 7 runs, 100000 loops each)
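jit and grad are function transformations that can be composed in either order. The sketch below (network, cost and params repeated from above so it runs on its own) checks that differentiating first and then compiling, as in grad_cost, gives the same gradient as differentiating an already-compiled function. Wrapping the outermost function in jit, as done above, is what lets the whole gradient computation be compiled as one program:

```python
import jax.numpy as jnp
from jax import grad, jit
from jax.tree_util import tree_map, tree_leaves

def network(parameters, x):
    # same network as above (relu at every layer)
    for weights, biases in parameters:
        z = jnp.dot(x, weights) + biases
        x = (z > 0) * z
    return x

def cost(params, x, y_target):
    return jnp.sum((network(params, x) - y_target) ** 2)

params = [[jnp.array([[0.1, 0.3, 0.5], [-0.4, 0.2, 0.8]]), jnp.array([0.1, -0.2, 0.3])],
          [jnp.array([[0.2], [0.7], [-0.5]]), jnp.array([0.2])]]
x = jnp.array([0.9, -0.5])

# order 1 (as in grad_cost above): differentiate, then compile everything
grad_then_jit = jit(grad(cost))
# order 2: differentiate a function that was already compiled
jit_then_grad = grad(jit(cost))

g1 = grad_then_jit(params, x, 1.0)
g2 = jit_then_grad(params, x, 1.0)

# compare the two nested lists array by array
matches = tree_leaves(tree_map(lambda a, b: bool(jnp.allclose(a, b)), g1, g2))
print(matches)
```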


In [28]:
# calculate the gradient of the cost function,
# at the current values of the parameters 'params',
# and for that given x input vector (and with y_target==1.0):
grad_value = grad_cost(params,x,1.0)

This output has exactly the same structure as params: a nested list of arrays with the same shapes.
Each entry is the gradient of the cost function with respect to the corresponding parameter!
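To make this concrete, a small sketch (network, cost and params repeated so it runs standalone) prints the shapes inside grad_value and compares one entry, the derivative with respect to the output bias biases2 = params[1][1], with a central finite difference. The step size eps=1e-3 is an arbitrary choice:

```python
import jax.numpy as jnp
from jax import grad, jit
from jax.tree_util import tree_map

def network(parameters, x):
    # same network as above (relu at every layer)
    for weights, biases in parameters:
        z = jnp.dot(x, weights) + biases
        x = (z > 0) * z
    return x

def cost(params, x, y_target):
    return jnp.sum((network(params, x) - y_target) ** 2)

params = [[jnp.array([[0.1, 0.3, 0.5], [-0.4, 0.2, 0.8]]), jnp.array([0.1, -0.2, 0.3])],
          [jnp.array([[0.2], [0.7], [-0.5]]), jnp.array([0.2])]]
x = jnp.array([0.9, -0.5])

grad_cost = jit(grad(cost, argnums=0))
grad_value = grad_cost(params, x, 1.0)

# same nested-list structure and array shapes as params
print(tree_map(lambda a: a.shape, grad_value))

# central finite difference with respect to the output bias only
eps = 1e-3
def cost_with_output_bias(b2):
    return cost([params[0], [params[1][0], b2]], x, 1.0)

b2 = params[1][1]
fd = (cost_with_output_bias(b2 + eps) - cost_with_output_bias(b2 - eps)) / (2 * eps)
# both should be close to 2*(0.103 - 1.0) = -1.794
print(grad_value[1][1], fd)
```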

In [59]:
# update parameters:
from jax.tree_util import tree_map

learning_rate=0.1
# p: one parameter array, g: its gradient (same shape); take one gradient-descent step
new_params = tree_map(lambda p,g: p - learning_rate * g, params, grad_value)

In [60]:
new_params

[[Array([[ 0.132292  ,  0.3       ,  0.41927   ],
         [-0.41794002,  0.2       ,  0.84485   ]], dtype=float32),
  Array([ 0.13588001, -0.2       ,  0.2103    ], dtype=float32)],
 [Array([[ 0.269966],
         [ 0.7     ],
         [-0.43721 ]], dtype=float32),
  Array([0.3794], dtype=float32)]]

Now we can train neural networks using jax and simple
stochastic gradient descent! Have fun!

(no flax or optax or anything else needed at this point!)
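For example, a minimal gradient-descent loop could look like the sketch below. It simply repeats the update from above 50 times on the single sample x with y_target=1.0; the number of steps and the helper name gradient_descent_step are illustrative choices. Genuine stochastic gradient descent would instead use a fresh, randomly chosen batch of training samples at every step:

```python
import jax.numpy as jnp
from jax import jit, value_and_grad
from jax.tree_util import tree_map

def network(parameters, x):
    # same network as above (relu at every layer)
    for weights, biases in parameters:
        z = jnp.dot(x, weights) + biases
        x = (z > 0) * z
    return x

def cost(params, x, y_target):
    return jnp.sum((network(params, x) - y_target) ** 2)

params = [[jnp.array([[0.1, 0.3, 0.5], [-0.4, 0.2, 0.8]]), jnp.array([0.1, -0.2, 0.3])],
          [jnp.array([[0.2], [0.7], [-0.5]]), jnp.array([0.2])]]
x = jnp.array([0.9, -0.5])

@jit
def gradient_descent_step(params, x, y_target, learning_rate):
    """Illustrative helper: one plain gradient-descent update."""
    cost_value, grads = value_and_grad(cost)(params, x, y_target)
    # subtract learning_rate * gradient from every parameter array
    new_params = tree_map(lambda p, g: p - learning_rate * g, params, grads)
    return new_params, cost_value

learning_rate = 0.1  # same value as in the single step above
losses = []
for step in range(50):
    # cost_value is the cost *before* this step's update
    params, cost_value = gradient_descent_step(params, x, 1.0, learning_rate)
    losses.append(float(cost_value))

print(losses[0], losses[1], losses[-1])
```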

In [41]:
# even better: get both value and gradient at the same time (more efficient, since the network is evaluated only once):
value_and_grad_cost=jit(value_and_grad(cost,argnums=0))

In [36]:
cost_value,grad_value = value_and_grad_cost(params,x,1.0)

In [37]:
cost_value

Array(0.804609, dtype=float32)
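value_and_grad(cost) returns the pair (cost(...), grad(cost)(...)) in a single call, reusing the forward pass that the gradient computation needs anyway. A short check, with network, cost and params repeated from above, that it agrees with computing the two separately:

```python
import jax.numpy as jnp
from jax import grad, jit, value_and_grad
from jax.tree_util import tree_map, tree_leaves

def network(parameters, x):
    # same network as above (relu at every layer)
    for weights, biases in parameters:
        z = jnp.dot(x, weights) + biases
        x = (z > 0) * z
    return x

def cost(params, x, y_target):
    return jnp.sum((network(params, x) - y_target) ** 2)

params = [[jnp.array([[0.1, 0.3, 0.5], [-0.4, 0.2, 0.8]]), jnp.array([0.1, -0.2, 0.3])],
          [jnp.array([[0.2], [0.7], [-0.5]]), jnp.array([0.2])]]
x = jnp.array([0.9, -0.5])

value_and_grad_cost = jit(value_and_grad(cost))
cost_value, grad_value = value_and_grad_cost(params, x, 1.0)

# the same two results computed separately (the forward pass then runs twice)
separate_value = cost(params, x, 1.0)
separate_grad = grad(cost)(params, x, 1.0)

matches = tree_leaves(tree_map(lambda a, b: bool(jnp.allclose(a, b)), grad_value, separate_grad))
print(cost_value, separate_value, all(matches))
```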

## Batch processing (using jax vmap)

Now let us also introduce batch processing! We want x to have shape (batchsize,num_input_neurons), but we do not want to rewrite our network or cost function!

In [62]:
# vmap does the trick!
# It produces a new version of any function, but now with
# a batch index (vectorized processing)!

# in_axes says which index is the batch index, for each
# of the arguments of the function. 'params' (the first argument)
# does not have any batch index, therefore we write 'None':
batched_cost=vmap(cost,in_axes=(None,0,0))

In [63]:
# Apply this to a batch of three 'training samples':
x=jnp.array([[0.3,0.4],[0.9,0.8],[0.1,0.2]]) # batch size of 3, shape (3,2)
y_target=jnp.array([[1.0],[0.8],[0.7]]) # shape (3,1)

batched_cost(params,x,y_target)

Array([1.        , 0.64000005, 0.48999998], dtype=float32)
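Conceptually, vmap(cost,in_axes=(None,0,0)) computes the same numbers as calling cost once per sample, slicing x and y_target along axis 0 and passing params unchanged each time; the difference is that vmap turns this into one vectorized computation instead of a Python loop. A small check with the batch above (the names x_batch and y_batch are used only to keep the cell self-contained):

```python
import jax.numpy as jnp
from jax import vmap

def network(parameters, x):
    # same network as above (relu at every layer)
    for weights, biases in parameters:
        z = jnp.dot(x, weights) + biases
        x = (z > 0) * z
    return x

def cost(params, x, y_target):
    return jnp.sum((network(params, x) - y_target) ** 2)

params = [[jnp.array([[0.1, 0.3, 0.5], [-0.4, 0.2, 0.8]]), jnp.array([0.1, -0.2, 0.3])],
          [jnp.array([[0.2], [0.7], [-0.5]]), jnp.array([0.2])]]

batched_cost = vmap(cost, in_axes=(None, 0, 0))

x_batch = jnp.array([[0.3, 0.4], [0.9, 0.8], [0.1, 0.2]])  # shape (3,2)
y_batch = jnp.array([[1.0], [0.8], [0.7]])                 # shape (3,1)

vmapped = batched_cost(params, x_batch, y_batch)
# explicit loop over the batch index: the per-sample calls that vmap vectorizes
looped = jnp.array([cost(params, x_batch[i], y_batch[i]) for i in range(x_batch.shape[0])])
print(vmapped)
print(looped)
```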

In [64]:
# we want to average the cost over the batch:
def average_cost(params,x,y_target):
  return jnp.average(batched_cost(params,x,y_target))

In [65]:
average_cost(params,x,y_target)

Array(0.71000004, dtype=float32)

In [66]:
batched_value_and_grad_cost=jit(value_and_grad(average_cost))

In [67]:
%%timeit
batched_value_and_grad_cost(params,x,y_target)

18.2 µs ± 4.86 µs per loop (mean ± std. dev. of 7 runs, 100000 loops each)
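One thing worth noticing about this particular batch: the three costs above are exactly y_target squared (1.0, 0.64, 0.49), i.e. the network outputs 0 for all three samples. The pre-activation of the output neuron is negative for every sample, so the relu at the output layer blocks the gradient and all entries of the gradient vanish; a gradient-descent step on this batch alone would not change the parameters. A quick check, repeating the definitions from above so the cell runs on its own:

```python
import jax.numpy as jnp
from jax import jit, value_and_grad, vmap
from jax.tree_util import tree_leaves

def network(parameters, x):
    # same network as above (relu at every layer)
    for weights, biases in parameters:
        z = jnp.dot(x, weights) + biases
        x = (z > 0) * z
    return x

def cost(params, x, y_target):
    return jnp.sum((network(params, x) - y_target) ** 2)

params = [[jnp.array([[0.1, 0.3, 0.5], [-0.4, 0.2, 0.8]]), jnp.array([0.1, -0.2, 0.3])],
          [jnp.array([[0.2], [0.7], [-0.5]]), jnp.array([0.2])]]

batched_cost = vmap(cost, in_axes=(None, 0, 0))

def average_cost(params, x, y_target):
    return jnp.average(batched_cost(params, x, y_target))

x = jnp.array([[0.3, 0.4], [0.9, 0.8], [0.1, 0.2]])  # shape (3,2)
y_target = jnp.array([[1.0], [0.8], [0.7]])          # shape (3,1)

batched_value_and_grad_cost = jit(value_and_grad(average_cost))
value, grads = batched_value_and_grad_cost(params, x, y_target)

# pre-activation of the output neuron for each sample (before the final relu)
(weights1, biases1), (weights2, biases2) = params
hidden = jnp.maximum(jnp.dot(x, weights1) + biases1, 0.0)
z_out = jnp.dot(hidden, weights2) + biases2
print(z_out.ravel())  # all three entries are negative

# largest absolute gradient entry in each parameter array
max_abs_grads = [float(jnp.max(jnp.abs(g))) for g in tree_leaves(grads)]
print(value, max_abs_grads)
```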
