# A simple neural network and cost function gradient via jax

[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/FlorianMarquardt/machine-learning-for-physicists/blob/master/2024/03_NN_CostFunctionGradientFromJax.ipynb)

Example code for the lecture series "Machine Learning for Physicists" by Florian Marquardt

Lecture 3

See https://machine-learning-for-physicists.org and the current course website linked there!

This notebook shows how to build a little network and evaluate the gradient of the cost function using jax, and how to apply one step of gradient descent.

MIT License.

In [18]:
import jax.numpy as jnp
from jax import grad

In [2]:
# a simple function, to show the principles of jax grad:
def f(x):
  return jnp.sum(x**2)

In [3]:
# use jax autodifferentiation to get gradient of f
# with respect to the input vector x
grad_f=grad(f)

In [4]:
# show what happens in a numerical example:
x=jnp.array([0.1,0.2,0.3])

print("value of f: ", f(x))
print("value of grad f:", grad_f(x))

value of f:  0.14000002
value of grad f: [0.2 0.4 0.6]
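Because f(x) = Σᵢ xᵢ², its gradient is simply 2x, so the output above can be checked directly. Below is a minimal sketch that compares `grad_f` with this analytic result and with a central finite difference. The step size `eps=1e-3` and the tolerance are arbitrary choices for this illustration.

```python
import jax.numpy as jnp
from jax import grad

def f(x):
  return jnp.sum(x**2)

grad_f = grad(f)
x = jnp.array([0.1, 0.2, 0.3])

# analytic gradient of sum(x**2) is 2*x
analytic = 2 * x

# central finite difference along each coordinate direction e (rows of the identity)
eps = 1e-3
numeric = jnp.array([(f(x + eps * e) - f(x - eps * e)) / (2 * eps)
                     for e in jnp.eye(len(x))])

print(jnp.allclose(grad_f(x), analytic))               # True
print(jnp.allclose(grad_f(x), numeric, atol=1e-3))     # True (float32 finite differences are approximate)
```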


In [5]:
# now define a simple neural network
# (arbitrary number of layers, with relu activation at every layer, including the output layer)
def network(parameters,x):
  """
  Evaluate network.

  parameters=[[weights1,biases1],[weights2,biases2],...]
  x=input vector
  """
  for weights,biases in parameters:
    # weights has shape (neurons_lower_layer,neurons_upper_layer),
    # biases has shape (neurons_upper_layer,)
    z=jnp.dot(x,weights)+biases
    x=(z>0)*z # relu activation
  return x
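The expression `(z>0)*z` is one way to write the ReLU function max(0, z). The boolean mask is promoted to a number and zeroes out the negative entries. The following small check confirms that it agrees with `jnp.maximum(z, 0.0)` and with `jax.nn.relu`, either of which could be used instead. The test values in `z` are arbitrary.

```python
import jax
import jax.numpy as jnp

z = jnp.array([-1.5, -0.2, 0.0, 0.3, 2.0])

relu_mask = (z > 0) * z          # form used in network()
relu_max = jnp.maximum(z, 0.0)   # elementwise max(0, z)
relu_lib = jax.nn.relu(z)        # built-in ReLU

# -0.0 compares equal to 0.0, so the masked negative entries match exactly
print(jnp.array_equal(relu_mask, relu_max), jnp.array_equal(relu_mask, relu_lib))
```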

In [6]:
# our network has structure 2 (input) -- 3 -- 1 (output)
weights1=jnp.array([[0.1,0.3,0.5],[-0.4,0.2,0.8]]) # shape (2,3)
biases1=jnp.array([0.1,-0.2,0.3]) # shape (3,)
weights2=jnp.array([[0.2],[0.7],[-0.5]]) # shape (3,1)
biases2=jnp.array([0.2]) # shape (1,)

params=[[weights1,biases1],[weights2,biases2]]
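Because `network` accepts an arbitrary number of layers, writing every array out by hand quickly becomes tedious. Here is a minimal sketch of a helper that builds the same nested `[[weights, biases], ...]` structure for any list of layer sizes. The name `init_params`, the Gaussian initialization, and the scale 0.5 are assumptions made for illustration and are not part of the original notebook.

```python
import jax
import jax.numpy as jnp

def init_params(key, layer_sizes, scale=0.5):
  """Hypothetical helper: random [[weights, biases], ...] for the given layer sizes."""
  params = []
  for n_in, n_out in zip(layer_sizes[:-1], layer_sizes[1:]):
    key, w_key, b_key = jax.random.split(key, 3)
    weights = scale * jax.random.normal(w_key, (n_in, n_out))  # shape (n_in, n_out)
    biases = scale * jax.random.normal(b_key, (n_out,))        # shape (n_out,)
    params.append([weights, biases])
  return params

# same 2 -- 3 -- 1 structure as the hand-written params
params = init_params(jax.random.PRNGKey(0), [2, 3, 1])
print([[w.shape, b.shape] for w, b in params])  # [[(2, 3), (3,)], [(3, 1), (1,)]]
```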

In [8]:
# apply network to test input
x=jnp.array([0.3,-0.5])
print(network(params,x))

[0.241]
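Working through the forward pass by hand for x = (0.3, -0.5) with the weights and biases defined above reproduces this number. Note that the ReLU switches off the second hidden neuron:

$$
\begin{aligned}
z^{(1)} &= x W^{(1)} + b^{(1)} = (0.23,\,-0.01,\,-0.25) + (0.1,\,-0.2,\,0.3) = (0.33,\,-0.21,\,0.05)\\
h &= \mathrm{relu}\big(z^{(1)}\big) = (0.33,\,0,\,0.05)\\
y &= \mathrm{relu}\big(h W^{(2)} + b^{(2)}\big) = \mathrm{relu}\big(0.33\cdot 0.2 + 0\cdot 0.7 + 0.05\cdot(-0.5) + 0.2\big) = 0.241
\end{aligned}
$$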


In [9]:
# define a cost function (here: quadratic deviation)
def cost(params,x,y_target):
  return jnp.sum( ( network(params,x) - y_target )**2 )
# note: to average over a batch, we would divide by the batch size jnp.shape(x)[0]
# (here we evaluate only a single sample, so there is no batch yet)
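As written, `network` already accepts a batch. For `x` of shape `(batch_size, 2)`, `jnp.dot(x, weights)` has shape `(batch_size, 3)` and the biases broadcast over the batch dimension. Below is a minimal sketch of the averaged cost mentioned in the note. The name `cost_batch` and the target shape `(batch_size, 1)` are assumptions for illustration.

```python
import jax.numpy as jnp

def network(parameters, x):
  for weights, biases in parameters:
    z = jnp.dot(x, weights) + biases
    x = (z > 0) * z  # relu activation
  return x

def cost_batch(params, x, y_target):
  """Hypothetical batch version: x has shape (batch_size, 2), y_target (batch_size, 1)."""
  batch_size = jnp.shape(x)[0]
  return jnp.sum((network(params, x) - y_target)**2) / batch_size

params = [[jnp.array([[0.1, 0.3, 0.5], [-0.4, 0.2, 0.8]]), jnp.array([0.1, -0.2, 0.3])],
          [jnp.array([[0.2], [0.7], [-0.5]]), jnp.array([0.2])]]

# a batch of two identical samples: the average equals the single-sample cost
x_batch = jnp.array([[0.3, -0.5], [0.3, -0.5]])
y_batch = jnp.array([[1.0], [1.0]])
print(cost_batch(params, x_batch, y_batch))  # approx 0.576081
```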

In [10]:
cost(params,x,1.0)

Array(0.576081, dtype=float32)

In [11]:
# now apply jax autodifferentiation to get the
# gradient of the cost function with respect to the params:
grad_cost=grad(cost,argnums=0) # argnums=0 means first argument, i.e. params
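A training step usually needs both the cost and its gradient. `jax.value_and_grad` takes the same `argnums` argument as `grad` and returns both from a single call. The sketch below redefines the network and cost from above so that it runs on its own.

```python
import jax
import jax.numpy as jnp
from jax import grad, value_and_grad

def network(parameters, x):
  for weights, biases in parameters:
    z = jnp.dot(x, weights) + biases
    x = (z > 0) * z  # relu activation
  return x

def cost(params, x, y_target):
  return jnp.sum((network(params, x) - y_target)**2)

params = [[jnp.array([[0.1, 0.3, 0.5], [-0.4, 0.2, 0.8]]), jnp.array([0.1, -0.2, 0.3])],
          [jnp.array([[0.2], [0.7], [-0.5]]), jnp.array([0.2])]]
x = jnp.array([0.3, -0.5])

# returns (cost value, gradient with respect to argument 0, i.e. params)
cost_and_grad = value_and_grad(cost, argnums=0)
value, gradient = cost_and_grad(params, x, 1.0)

print(value)  # same value as cost(params, x, 1.0)
same_grad = all(jnp.allclose(a, b) for a, b in zip(
    jax.tree_util.tree_leaves(gradient),
    jax.tree_util.tree_leaves(grad(cost, argnums=0)(params, x, 1.0))))
print(same_grad)  # True
```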

In [13]:
params

[[Array([[ 0.1,  0.3,  0.5],
         [-0.4,  0.2,  0.8]], dtype=float32),
  Array([ 0.1, -0.2,  0.3], dtype=float32)],
 [Array([[ 0.2],
         [ 0.7],
         [-0.5]], dtype=float32),
  Array([0.2], dtype=float32)]]

In [12]:
# calculate the gradient of the cost function,
# at the current values of the parameters 'params',
# for the given input vector x (and with y_target=1.0):
grad_cost(params,x,1.0)

[[Array([[-0.09108001, -0.        ,  0.22770001],
         [ 0.1518    ,  0.        , -0.3795    ]], dtype=float32),
  Array([-0.3036, -0.    ,  0.759 ], dtype=float32)],
 [Array([[-0.50094   ],
         [-0.        ],
         [-0.07590002]], dtype=float32),
  Array([-1.518], dtype=float32)]]

This output has exactly the same nested structure as params!
Each array holds the gradient of the cost function with respect to the corresponding parameter array, entry by entry.
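This structural statement can also be checked in code with `jax.tree_util`. `tree_structure` describes the nesting of lists and arrays, and `tree_map` can list the array shapes leaf by leaf. The sketch redefines the network, cost, and parameters so that it is self-contained.

```python
import jax.numpy as jnp
from jax import grad
from jax.tree_util import tree_map, tree_structure

def network(parameters, x):
  for weights, biases in parameters:
    z = jnp.dot(x, weights) + biases
    x = (z > 0) * z  # relu activation
  return x

def cost(params, x, y_target):
  return jnp.sum((network(params, x) - y_target)**2)

params = [[jnp.array([[0.1, 0.3, 0.5], [-0.4, 0.2, 0.8]]), jnp.array([0.1, -0.2, 0.3])],
          [jnp.array([[0.2], [0.7], [-0.5]]), jnp.array([0.2])]]
x = jnp.array([0.3, -0.5])

grad_value = grad(cost, argnums=0)(params, x, 1.0)

# same nesting (a list of [weights, biases] lists) as params
print(tree_structure(grad_value) == tree_structure(params))  # True
# shape of every gradient array, leaf by leaf
print(tree_map(lambda a: a.shape, grad_value))  # [[(2, 3), (3,)], [(3, 1), (1,)]]
```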

In [14]:
# store this whole nested list:
grad_value = grad_cost(params,x,1.0)
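The chain rule (backpropagation) for this two-layer network with quadratic cost explains where these numbers come from. The sketch below recomputes every gradient by hand and compares the results with jax's output. The intermediate names (`z1`, `h`, `delta2`, `delta1`) are introduced here only for illustration.

```python
import jax.numpy as jnp
from jax import grad
from jax.tree_util import tree_leaves

def network(parameters, x):
  for weights, biases in parameters:
    z = jnp.dot(x, weights) + biases
    x = (z > 0) * z  # relu activation
  return x

def cost(params, x, y_target):
  return jnp.sum((network(params, x) - y_target)**2)

params = [[jnp.array([[0.1, 0.3, 0.5], [-0.4, 0.2, 0.8]]), jnp.array([0.1, -0.2, 0.3])],
          [jnp.array([[0.2], [0.7], [-0.5]]), jnp.array([0.2])]]
x = jnp.array([0.3, -0.5])
y_target = 1.0
(W1, b1), (W2, b2) = params

# forward pass, keeping the intermediate values
z1 = jnp.dot(x, W1) + b1   # shape (3,)
h = (z1 > 0) * z1          # hidden-layer output
z2 = jnp.dot(h, W2) + b2   # shape (1,)
y = (z2 > 0) * z2          # network output

# backward pass: delta = dC/dz for each layer, using relu'(z) = 1 if z > 0 else 0
delta2 = 2 * (y - y_target) * (z2 > 0)   # shape (1,)
delta1 = jnp.dot(W2, delta2) * (z1 > 0)  # shape (3,)

manual = [[jnp.outer(x, delta1), delta1],  # dC/dW1, dC/db1
          [jnp.outer(h, delta2), delta2]]  # dC/dW2, dC/db2
auto = grad(cost, argnums=0)(params, x, y_target)

matches = all(jnp.allclose(m, a) for m, a in zip(tree_leaves(manual), tree_leaves(auto)))
print(delta2)   # dC/db2, approx [-1.518]
print(matches)  # True
```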

In [15]:
# now update all parameters according to the negative gradient!
learning_rate=0.1

new_params=[(weights-learning_rate*grad_weights,biases-learning_rate*grad_biases) for
            (weights,biases),(grad_weights,grad_biases) in zip(params,grad_value)]
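In formulas, this step replaces each parameter array θ (every weight matrix and every bias vector) by θ − η ∂C/∂θ, where η is `learning_rate`. For the output bias, with the gradient −1.518 computed above:

$$
b^{(2)} \leftarrow b^{(2)} - \eta\,\frac{\partial C}{\partial b^{(2)}} = 0.2 - 0.1\cdot(-1.518) = 0.3518
$$

This is the last entry in the printout of the new parameters (0.35180002 in float32).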

In [16]:
# print new parameters:
new_params

[(Array([[ 0.109108,  0.3     ,  0.47723 ],
         [-0.41518 ,  0.2     ,  0.83795 ]], dtype=float32),
  Array([ 0.13036   , -0.2       ,  0.22410001], dtype=float32)),
 (Array([[ 0.250094],
         [ 0.7     ],
         [-0.49241 ]], dtype=float32),
  Array([0.35180002], dtype=float32))]

In [17]:
# cost should have gone down!
cost(new_params,x,1.0)

Array(0.3085742, dtype=float32)

In [19]:
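# jax makes such updates more convenient: tree_map applies a function
# leaf by leaf to matching nested structures (lists, tuples, dicts) of arrays
from jax.tree_util import tree_map

new_params = tree_map(lambda p, g: p - learning_rate * g, params, grad_value)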

In [20]:
# same values as above (now as nested lists, like params, instead of tuples):
new_params

[[Array([[ 0.109108,  0.3     ,  0.47723 ],
         [-0.41518 ,  0.2     ,  0.83795 ]], dtype=float32),
  Array([ 0.13036   , -0.2       ,  0.22410001], dtype=float32)],
 [Array([[ 0.250094],
         [ 0.7     ],
         [-0.49241 ]], dtype=float32),
  Array([0.35180002], dtype=float32)]]
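The two results can also be compared in code instead of by eye. The list comprehension produced tuples, while `tree_map` kept the nested lists of `params`, so the two trees do not have identical structure. Comparing their flattened leaves with `tree_leaves` sidesteps this. The sketch below is self-contained, so it recomputes the gradient.

```python
import jax.numpy as jnp
from jax import grad
from jax.tree_util import tree_map, tree_leaves, tree_structure

def network(parameters, x):
  for weights, biases in parameters:
    z = jnp.dot(x, weights) + biases
    x = (z > 0) * z  # relu activation
  return x

def cost(params, x, y_target):
  return jnp.sum((network(params, x) - y_target)**2)

params = [[jnp.array([[0.1, 0.3, 0.5], [-0.4, 0.2, 0.8]]), jnp.array([0.1, -0.2, 0.3])],
          [jnp.array([[0.2], [0.7], [-0.5]]), jnp.array([0.2])]]
x = jnp.array([0.3, -0.5])
learning_rate = 0.1
grad_value = grad(cost, argnums=0)(params, x, 1.0)

# version 1: explicit list comprehension (yields a list of tuples)
new_params_loop = [(w - learning_rate * gw, b - learning_rate * gb)
                   for (w, b), (gw, gb) in zip(params, grad_value)]
# version 2: tree_map (keeps the nested-list structure of params)
new_params_tree = tree_map(lambda p, g: p - learning_rate * g, params, grad_value)

print(tree_structure(new_params_loop) == tree_structure(new_params_tree))  # False: tuples vs lists
same_values = all(jnp.allclose(a, b) for a, b in
                  zip(tree_leaves(new_params_loop), tree_leaves(new_params_tree)))
print(same_values)  # True
```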

Now we can train neural networks using jax and simple
stochastic gradient descent! Have fun!

(no flax or optax or anything else needed at this point!)
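Putting the pieces together, here is a minimal sketch of such a loop. For simplicity it repeats the update step on the single training example used above. For genuine stochastic gradient descent one would draw a fresh random sample or batch in each step, which is not shown here. The helper name `sgd_step`, the use of `jax.jit`, and the number of steps (50) are choices made for this illustration.

```python
import jax
import jax.numpy as jnp
from jax.tree_util import tree_map

def network(parameters, x):
  for weights, biases in parameters:
    z = jnp.dot(x, weights) + biases
    x = (z > 0) * z  # relu activation
  return x

def cost(params, x, y_target):
  return jnp.sum((network(params, x) - y_target)**2)

learning_rate = 0.1

@jax.jit
def sgd_step(params, x, y_target):
  """Hypothetical helper: one gradient-descent step; returns new params and the old cost."""
  value, grads = jax.value_and_grad(cost)(params, x, y_target)
  new_params = tree_map(lambda p, g: p - learning_rate * g, params, grads)
  return new_params, value

params = [[jnp.array([[0.1, 0.3, 0.5], [-0.4, 0.2, 0.8]]), jnp.array([0.1, -0.2, 0.3])],
          [jnp.array([[0.2], [0.7], [-0.5]]), jnp.array([0.2])]]
x = jnp.array([0.3, -0.5])

costs = []
for step in range(50):
  params, value = sgd_step(params, x, 1.0)
  costs.append(float(value))

# costs[0] and costs[1] match the two cost values computed above; later values shrink toward 0
print(costs[0], costs[1], costs[-1])
```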