In [1]:
import jax
import jax.numpy as jnp
import numpy as np

# Switch JAX's default float/int width from 32-bit to 64-bit for this session.
jax.config.update("jax_enable_x64", True)
# Root PRNG key (seed 0). Split it rather than reusing it for independent random draws.
key = jax.random.PRNGKey(0)

## JAX Basics: Array Operations, `lax.scan`, and Automatic Differentiation

### Basic Operations

In [2]:
# JAX random draws are a pure function of the key: calling uniform twice with the
# same key and shape would make B an exact copy of A. Draw each matrix from its
# own sub-key. New names are bound instead of rebinding `key`, so re-running this
# cell gives the same A and B.
key_a, key_b = jax.random.split(key)
A = jax.random.uniform(key_a, (1000, 1000))
B = jax.random.uniform(key_b, (1000, 1000))

In [3]:
# Matrix product C[i, k] = sum_j A[i, j] * B[j, k], written as an einsum contraction.
C = jnp.einsum(
    'ij, jk -> ik',
    A,
    B,
)

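Asynchronous dispatch: JAX can return from an operation such as `A @ B` as soon as the work is dispatched, before the result is ready. Timing the bare expression may therefore under-report the cost. Calling `.block_until_ready()` waits for the computation to finish, so it measures the full matmul.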

In [4]:
# Times the bare expression: JAX may return before the matmul result is computed,
# so this can under-report the cost of the computation itself.
%timeit A @ B

5.15 ms ± 849 ns per loop (mean ± std. dev. of 7 runs, 100 loops each)


In [5]:
# block_until_ready() waits for the result, so this times the full matmul.
%timeit (A @ B).block_until_ready()

5.43 ms ± 34.6 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)


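`jax.lax.scan`: repeatedly applies a step function while threading a carry through each step. Here the carry is a zero vector that is incremented 200 times, and the per-step outputs are stacked into `traj`.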

In [6]:
def scan_fn(x, _):
    """Scan step: add one to the carry and emit the new carry as this step's output.

    Args:
        x: current carry.
        _: per-step input (unused; the scan below is driven by ``length`` with ``xs=None``).

    Returns:
        ``(new_carry, output)``, both equal to ``x + 1``.
    """
    x_next = x + 1
    return x_next, x_next

# Use the step function defined above instead of an identical inline lambda.
# res is the final carry; traj stacks the 200 per-step outputs along a new leading axis.
res, traj = jax.lax.scan(scan_fn, jnp.zeros(2), None, length=200)
traj.shape

(200, 2)

### Automatic Differentiation

In [19]:
def mse_loss(x, x_hat):
    """Mean squared error: the mean of the elementwise squared difference ``x - x_hat``."""
    residual = x - x_hat
    return (residual ** 2).mean()

# Differentiate w.r.t. both arguments, so the result is a tuple (d/dx, d/dx_hat).
mse_grad_fn = jax.grad(mse_loss, argnums=(0, 1))
x = jnp.array([1.0])
x_hat = jnp.array([2.0])
mse = mse_loss(x, x_hat)
mse_grad = mse_grad_fn(x, x_hat)
mse_grad, x, x_hat

((Array([-2.], dtype=float64), Array([2.], dtype=float64)),
 Array([1.], dtype=float64),
 Array([2.], dtype=float64))

In [20]:
# Returns (loss, gradient). argnums=0 is the default, so only d(loss)/dx is returned, not d/dx_hat.
jax.value_and_grad(mse_loss, argnums=0)(x, x_hat)

(Array(1., dtype=float64), Array([-2.], dtype=float64))