In [1]:
%tensorflow_version 2.x

Colab only includes TensorFlow 2.x; %tensorflow_version has no effect.


# JAX 1. Numpy Wrapper

In [2]:
import numpy as np

x = np.ones((5000, 5000))
y = np.arange(5000)
print(x)
print(y)

%timeit z = np.sin(x) + np.cos(y)

[[1. 1. 1. ... 1. 1. 1.]
 [1. 1. 1. ... 1. 1. 1.]
 [1. 1. 1. ... 1. 1. 1.]
 ...
 [1. 1. 1. ... 1. 1. 1.]
 [1. 1. 1. ... 1. 1. 1.]
 [1. 1. 1. ... 1. 1. 1.]]
[   0    1    2 ... 4997 4998 4999]
The slowest run took 4.32 times longer than the fastest. This could mean that an intermediate result is being cached.
654 ms ± 365 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)


In [3]:
import jax.numpy as jnp
x = jnp.ones((5000, 5000))
y = jnp.arange(5000)

%timeit z = jnp.sin(x) + jnp.cos(y)

171 µs ± 76.4 µs per loop (mean ± std. dev. of 7 runs, 1 loop each)


# JAX 2. JIT Compiler

In [4]:
from jax import jit
import tensorflow as tf

def fn(x, y):
  z = np.sin(x)
  w = np.cos(y)
  return z + w

@jit
def fn_jit(x, y):
  z = jnp.sin(x)
  w = jnp.cos(y)
  return z + w

@tf.function
def fn_tf2(x, y):
  z = tf.sin(x)
  w = tf.cos(y)
  return z + w

In [5]:
x = np.ones((5000, 5000))
y = np.ones((5000, 5000))
%timeit fn(x, y)

651 ms ± 12 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)


In [6]:
jx = jnp.ones((5000, 5000))
jy = jnp.ones((5000, 5000))
%timeit fn_jit(jx, jy)

1.23 ms ± 327 ns per loop (mean ± std. dev. of 7 runs, 1000 loops each)


In [7]:
tx = tf.ones((5000, 5000))
ty = tf.ones((5000, 5000))
%timeit fn_tf2(tx, ty)

2.84 ms ± 1.04 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)


# JAX 3. grad

In [8]:
from jax import grad

@jit
def simple_fun(x):
  return jnp.sin(x) / x

In [9]:
grad_simple_fun = grad(simple_fun)

In [10]:
%timeit grad_simple_fun(1.0)

1.95 ms ± 233 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)


In [11]:
x_range = jnp.arange(10, dtype=jnp.float32)
[grad_simple_fun(xi) for xi in x_range]

[Array(nan, dtype=float32),
 Array(-0.30116874, dtype=float32),
 Array(-0.43539777, dtype=float32),
 Array(-0.3456775, dtype=float32),
 Array(-0.11611076, dtype=float32),
 Array(0.09508941, dtype=float32),
 Array(0.16778992, dtype=float32),
 Array(0.09429243, dtype=float32),
 Array(-0.03364623, dtype=float32),
 Array(-0.10632458, dtype=float32)]

In [12]:
grad_grad_simple_fun = grad(grad(simple_fun))

In [13]:
%timeit grad_grad_simple_fun(1.0)

4.35 ms ± 235 µs per loop (mean ± std. dev. of 7 runs, 1 loop each)


In [14]:
grad_grad_simple_fun(1.0)

Array(-0.23913354, dtype=float32, weak_type=True)

In [15]:
x_range = jnp.arange(10, dtype=jnp.float32)
[grad_grad_simple_fun(xi) for xi in x_range]

[Array(nan, dtype=float32),
 Array(-0.23913354, dtype=float32),
 Array(-0.01925096, dtype=float32),
 Array(0.18341166, dtype=float32),
 Array(0.247256, dtype=float32),
 Array(0.1537491, dtype=float32),
 Array(-0.00936072, dtype=float32),
 Array(-0.12079593, dtype=float32),
 Array(-0.11525822, dtype=float32),
 Array(-0.02216326, dtype=float32)]