In [1]:
from functools import partial
import jax
import jax.numpy as jnp

In [2]:
# Record the CPU model so the %timeit result further down has hardware context.
!lscpu |grep 'Model name'

Model name:                      Intel(R) Xeon(R) CPU @ 2.20GHz


In [3]:
@jax.jit
def f(x):
    """Objective function f(x) = 2 x^2 cos(x) - 5 x, built from elementwise jnp ops."""
    quadratic_term = 2 * x * x
    return quadratic_term * jnp.cos(x) - 5 * x

In [4]:
@jax.jit
def df(x):
    """Hand-derived derivative of f: f'(x) = 4 x cos(x) - 2 x^2 sin(x) - 5."""
    cos_part = 4 * x * jnp.cos(x)
    sin_part = 2 * x * x * jnp.sin(x)
    return cos_part - sin_part - 5

In [5]:
@partial(jax.jit, static_argnames=['epochs'])
def gradient_descent(x, y, alpha: float, epochs: int):
    """Run `epochs` steps of 1-D gradient descent on `f`, appending to a history.

    Each step computes ``x_next = x_prev - alpha * df(x_prev)``, starting from
    the LAST entry of ``x``, and records ``f(x_next)``.

    Args:
        x: 1-D array of previously visited points; descent continues from ``x[-1]``.
        y: 1-D array of objective values for those points (only appended to).
        alpha: Step size. Passed as a traced argument, so trying a different
            learning rate reuses the compiled function.
        epochs: Number of steps. Static because it determines the output length.

    Returns:
        ``(x_hist, y_hist)``: ``x`` and ``y`` with ``epochs`` new points/values appended.
    """
    def step(x_prev, _):
        x_next = x_prev - alpha * df(x_prev)
        # Carry the new point forward; emit (point, value) for the history.
        return x_next, (x_next, f(x_next))

    # lax.scan traces the step once instead of unrolling `epochs` copies, and
    # stacks the outputs in one buffer instead of re-concatenating every step.
    # Starting from x[-1] (not x[0]) makes this correct for any history length.
    _, (new_x, new_y) = jax.lax.scan(step, x[-1], None, length=epochs)
    return jnp.concatenate([x, new_x]), jnp.concatenate([y, new_y])

In [6]:
curve = jnp.linspace(-5, 5, 100)
alpha = 0.05
epochs = 20
x = jnp.array([-1.])
y = jnp.array([f(x[0])])
%timeit gradient_descent(x=x, y=y, alpha=alpha, epochs=epochs)



The slowest run took 6.52 times longer than the fastest. This could mean that an intermediate result is being cached.
14.4 µs ± 14.3 µs per loop (mean ± std. dev. of 7 runs, 1 loop each)
