In [1]:
!pip install -U jax jaxlib # install required libraries

Collecting jax
  Downloading jax-0.4.30-py3-none-any.whl (2.0 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m2.0/2.0 MB[0m [31m16.0 MB/s[0m eta [36m0:00:00[0m
Collecting jaxlib
  Downloading jaxlib-0.4.30-cp310-cp310-manylinux2014_x86_64.whl (79.6 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m79.6/79.6 MB[0m [31m7.3 MB/s[0m eta [36m0:00:00[0m
Installing collected packages: jaxlib, jax
  Attempting uninstall: jaxlib
    Found existing installation: jaxlib 0.4.26+cuda12.cudnn89
    Uninstalling jaxlib-0.4.26+cuda12.cudnn89:
      Successfully uninstalled jaxlib-0.4.26+cuda12.cudnn89
  Attempting uninstall: jax
    Found existing installation: jax 0.4.26
    Uninstalling jax-0.4.26:
      Successfully uninstalled jax-0.4.26
Successfully installed jax-0.4.30 jaxlib-0.4.30


In [32]:
#import libraries
import jax
import jax.numpy as jnp
from jax import grad
import numpy as np
import tqdm

Let us consider the function $f(x) = 2x^2 + 3x - 4$. If we differentiate this function, we get $f^\prime(x) = \frac{df}{dx} = 4x + 3$.

In [3]:
# implement f(x) = 2x^2 + 3x - 4, exactly as defined in the markdown above
def f(x):
  """Quadratic objective f(x) = 2x^2 + 3x - 4.

  Its derivative is f'(x) = 4x + 3, which is zero at x = -3/4 (the minimum).
  """
  return 2 * x * x + 3 * x - 4

# closed-form derivative of f, worked out by hand
def f_prime_analytical(x):
  """Return f'(x) = 4x + 3, the hand-derived derivative of f."""
  slope = 4 * x
  return slope + 3

# we can also approximate the derivative numerically with finite differences
def f_prime_numerical(x, eps=0.0000001, func=None):
  """Estimate the derivative of `func` at `x` with a central finite difference.

  Computes (func(x + eps) - func(x - eps)) / (2 * eps). Its truncation error
  is O(eps**2), whereas the one-sided (func(x + eps) - func(x)) / eps is only
  O(eps) accurate. For a quadratic such as f it is exact up to floating-point
  rounding.

  Args:
    x: point at which to evaluate the derivative.
    eps: step size; very small values amplify floating-point rounding error.
    func: single-argument function to differentiate. Defaults to the
      module-level `f`, looked up when this function is called.

  Returns:
    The finite-difference estimate of func'(x).
  """
  if func is None:
    func = f
  diff = func(x + eps) - func(x - eps)
  denom = 2 * eps
  return diff / denom

# let JAX build the derivative function of f for us (argnums=0: differentiate w.r.t. x)
f_prime_automatic = grad(f, argnums=0)

In [5]:
# evaluate the analytical, numerical and automatic derivatives at the same point
for derivative in (f_prime_analytical, f_prime_numerical, f_prime_automatic):
  print(derivative(10.0))

43.0
43.000000005122274
43.0


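As we can see above, all three methods give essentially the same result, but using JAX to automatically generate a function that computes the derivative was the easiest approach. This makes it easy to implement gradient descent, which repeatedly steps against the derivative, $x \leftarrow x - \eta \, f^\prime(x)$, where $\eta$ is the learning rate. Since $f^\prime(x) = 4x + 3 = 0$ at $x = -3/4$, the iterates should converge to $-0.75$.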

In [34]:
learning_rate = 0.01 # step size eta in x <- x - eta * f'(x)
num_iterations = 1000 # fixed number of update steps (there is no early stopping)
rng = np.random.default_rng(0) # seeded so Restart & Run All reproduces the result
current_guess = rng.random() # initial guess drawn uniformly from [0, 1)
print('Initial guess ', current_guess)

# Compile the gradient once with jax.jit; otherwise every step re-dispatches
# each primitive of grad(f) eagerly, which dominates the loop's run time.
f_prime_compiled = jax.jit(f_prime_automatic)

for _ in tqdm.tqdm(range(num_iterations)): # tqdm is used to help us monitor progress
  current_gradient = f_prime_compiled(current_guess)
  update = learning_rate * current_gradient
  current_guess = current_guess - update

print('')
# analytically, f'(x) = 4x + 3 = 0 at x = -0.75
print('Minimum at ', float(current_guess))

Initial guess  0.5967380372999491


100%|██████████| 1000/1000 [00:05<00:00, 172.54it/s]


-0.7499993





We can also easily compute gradients of multivariate functions. To illustrate, consider the function $f(x, y) = (x - 2)^2 + (y + 3)^2$. Its partial derivatives, $\frac{\partial f}{\partial x} = 2(x - 2)$ and $\frac{\partial f}{\partial y} = 2(y + 3)$, both vanish at $\left( 2, -3 \right)$, which is the function's minimum. Let us implement gradient descent to find this minimum.

In [28]:
def f(x, y):
  """Bowl-shaped objective (x - 2)^2 + (y + 3)^2, minimized at (2, -3)."""
  dx, dy = x - 2, y + 3
  return dx ** 2 + dy ** 2

f_grad_x = grad(f, argnums=0) # df/dx: x is positional argument 0 (we count from 0!)
f_grad_y = grad(f, argnums=1) # df/dy: y is positional argument 1
# both partial derivatives should vanish at the minimum (2, -3)
for partial in (f_grad_x, f_grad_y):
  print(partial(2.0, -3.0))

0.0
0.0


In [36]:
learning_rate = 0.1 # step size used for both coordinates
num_iterations = 1000 # fixed number of update steps (there is no early stopping)
rng = np.random.default_rng(0) # seeded so Restart & Run All reproduces the result
x_guess = rng.random() # initial guess for x, uniform in [0, 1)
y_guess = rng.random() # initial guess for y, uniform in [0, 1)
print('Initial guess ', (x_guess, y_guess))

@jax.jit
def gradient_xy(x, y):
  # Evaluate both partial derivatives inside one compiled function, so each
  # step is a single XLA call instead of two separate eager grad evaluations.
  return f_grad_x(x, y), f_grad_y(x, y)

for _ in tqdm.tqdm(range(num_iterations)):
  x_grad, y_grad = gradient_xy(x_guess, y_guess)
  x_guess = x_guess - learning_rate * x_grad
  y_guess = y_guess - learning_rate * y_grad

print()
print(f'Minimum at {(float(x_guess), float(y_guess))}')

Initial guess  (0.8675157920045229, 0.4547974086534248)


100%|██████████| 1000/1000 [00:08<00:00, 124.27it/s]


Minimum at (1.999999761581421, -2.999999523162842)





As an exercise, use gradient descent to minimize the following function:

$$
f(x, y) = 0.26(x^2 + y^2) - 0.48xy
$$

2.6666666666666665