In [1]:
import jax
import jax.numpy as jnp

## Array creation

In [None]:
x = jnp.arange(10)  # the integers 0..9 as a JAX array
print(x)

In [None]:
x  # a bare last expression displays the array's repr instead of its str

This code will run much faster on a GPU.

In [None]:
# NOTE(review): with JAX's default 32-bit integers this dot product presumably
# overflows; that does not matter for timing, but confirm before using the value.
long_vector = jnp.arange(int(1e7))
# block_until_ready() waits for the asynchronously dispatched computation to
# finish, so %timeit measures the actual compute time and not just the dispatch.
%timeit jnp.dot(long_vector, long_vector).block_until_ready()

## Grad

In [None]:
def sum_of_squares(x):
  """Return the sum of the element-wise squares of x as a scalar."""
  squared = x ** 2
  return jnp.sum(squared)

In [None]:
# jax.grad returns a new function that evaluates the gradient of
# sum_of_squares with respect to its (first) argument.
sum_of_squares_dx = jax.grad(sum_of_squares)
# NOTE(review): this prints the type of the original function, not of
# sum_of_squares_dx — confirm which one was meant to be inspected.
print(type(sum_of_squares))
x = jnp.asarray([1.0, 2.0, 3.0, 4.0, 10.0])

print(sum_of_squares(x))  # the scalar value
print(sum_of_squares_dx(x))  # the gradient, one entry per element of x

In [None]:
def sum_squared_error(x, y):
  """Return the sum of squared element-wise differences between x and y."""
  residual = x - y
  return jnp.sum(residual ** 2)

# With the default argnums=0, jax.grad differentiates with respect to the
# first argument (x) only; y is treated as a constant.
sum_squared_error_dx = jax.grad(sum_squared_error)
y = jnp.asarray([1.1, 2.1, 3.1, 4.1, 10.1])

# Uses the notebook-level x defined in the previous section.
print(sum_squared_error_dx(x, y))

In [None]:
jax.grad(sum_squared_error, argnums=(0, 1))(x, y)  # Gradient w.r.t. both x and y; returns a tuple (d/dx, d/dy)

## Value and grad

Returns a tuple (value, grad)

In [None]:
# Evaluates the function and its gradient w.r.t. the first argument (x) in one call.
jax.value_and_grad(sum_squared_error)(x, y)

## Auxiliary data

In [None]:
def squared_error_with_aux(x, y):
  """Return the pair (sum squared error, element-wise residual x - y)."""
  loss = sum_squared_error(x, y)
  residual = x - y
  return loss, residual

# jax.grad only accepts scalar-output functions, so the (loss, aux) tuple
# returned here is rejected. The error is the point of this cell: catch and
# display it so the notebook still runs top to bottom.
try:
  jax.grad(squared_error_with_aux)(x, y)
except TypeError as err:
  print(f"TypeError: {err}")

In [None]:
jax.grad(squared_error_with_aux, has_aux=True)(x, y)
# has_aux=True tells jax.grad that the function returns a pair (out, aux):
# only `out` is differentiated, and the call returns (gradient, aux).


## Differences from NumPy

In [None]:
import numpy as np

In [None]:
x = np.array([1, 2, 3])  # a NumPy array; note this rebinds the notebook-level name x

In [None]:
def in_place_modify(x):
  """Overwrite the first element of x with 123, mutating x; returns None."""
  x[0] = 123

In [None]:
in_place_modify(x)  # item assignment inside the function changes x itself
x

In [None]:
# JAX arrays are immutable, so the item assignment inside in_place_modify
# raises TypeError once the input is a JAX array. The error is the point of
# this cell: catch and display it so the notebook still runs top to bottom.
try:
  in_place_modify(jnp.array(x))
except TypeError as err:
  print(f"TypeError: {err}")

In [None]:
def jax_in_place_modify(x):
  """Return a copy of x whose first element is 123; x itself is not modified."""
  updated = x.at[0].set(123)
  return updated

In [None]:
y = jnp.array([1, 2, 3])
jax_in_place_modify(y)  # displays the new array returned by x.at[0].set(123)

In [None]:
y  # The old array is untouched

## Your first JAX training loop

We will make a simple linear regression example

In [None]:
import matplotlib.pyplot as plt
import time

In [None]:
# Use a seeded generator so Restart & Run All reproduces the same synthetic
# dataset (and therefore the same fitted parameters and plots).
rng = np.random.default_rng(0)
xs = rng.normal(size=(100,))
noise = rng.normal(scale=0.2, size=(100,))
# Ground truth: slope 3, intercept -1, plus Gaussian noise with std 0.2.
ys = xs * 3 - 1 + noise

In [None]:
plt.scatter(xs, ys)  # quick look at the synthetic (x, y) pairs

In [None]:
def model(theta, x):
  """Linear model: unpack theta as (w, b) and return w * x + b for a batch x."""
  slope, intercept = theta
  prediction = slope * x + intercept
  return prediction

In [None]:
def loss_fn(theta, x, y):
  """Mean squared error between the model's predictions for x and the targets y."""
  residual = model(theta, x) - y
  return jnp.mean(residual ** 2)

We will add @jax.jit to make training much faster (see the next notebook for how it works).

In [None]:
@jax.jit
def update(theta, x, y, lr=0.05):
  """Take one gradient-descent step on loss_fn and return the new theta.

  The gradient is taken with respect to theta (the first argument of
  loss_fn) and scaled by the learning rate lr.
  """
  grads = jax.grad(loss_fn)(theta, x, y)
  return theta - lr * grads

In [None]:
theta = jnp.array([1., 1.])  # initial guess for (w, b)

# perf_counter is a monotonic, high-resolution clock meant for timing intervals.
start_time = time.perf_counter()
for _ in range(1000):  # 1000 epochs (each step uses the full dataset)
  theta = update(theta, xs, ys)
# JAX dispatches work asynchronously: wait for the final result so the
# measured time covers the actual computation, not just the dispatch.
theta.block_until_ready()
print("--- %s seconds ---" % (time.perf_counter() - start_time))

## Using @jax.jit in update() makes this loop more than 8 s faster

In [None]:
# The original data
plt.scatter(xs, ys)
# Plot the line produced by the model with the trained theta
plt.plot(xs, model(theta, xs))

In [None]:
# Unpack the learned parameters; the data was generated as ys = xs * 3 - 1 + noise,
# so values close to w = 3 and b = -1 indicate a good fit.
w, b = theta
print(f"w: {w:.2f}, b: {b:.2f}")