<a href="https://colab.research.google.com/github/RandomAnass/Data-Analysis-Course/blob/main/JAX_Tutorial.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
# Colab-only line magic that selects the TensorFlow 2.x runtime (the recorded output
# below is "TensorFlow 2.x selected."). It is presumably not available in plain
# Jupyter/IPython, where this line would fail with a UsageError.
# NOTE(review): recent Colab images reportedly ship only TF 2.x and treat this magic
# as a no-op — confirm and consider removing this cell.
%tensorflow_version 2.x

TensorFlow 2.x selected.


In [None]:
!pip install --upgrade jax

Collecting jax
  Downloading jax-0.7.1-py3-none-any.whl.metadata (13 kB)
Collecting jaxlib<=0.7.1,>=0.7.1 (from jax)
  Downloading jaxlib-0.7.1-cp312-cp312-manylinux_2_27_x86_64.whl.metadata (1.3 kB)
Downloading jax-0.7.1-py3-none-any.whl (2.8 MB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m2.8/2.8 MB[0m [31m44.7 MB/s[0m eta [36m0:00:00[0m
[?25hDownloading jaxlib-0.7.1-cp312-cp312-manylinux_2_27_x86_64.whl (81.2 MB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m81.2/81.2 MB[0m [31m10.7 MB/s[0m eta [36m0:00:00[0m
[?25hInstalling collected packages: jaxlib, jax
  Attempting uninstall: jaxlib
    Found existing installation: jaxlib 0.5.3
    Uninstalling jaxlib-0.5.3:
      Successfully uninstalled jaxlib-0.5.3
  Attempting uninstall: jax
    Found existing installation: jax 0.5.3
    Uninstalling jax-0.5.3:
      Successfully uninstalled jax-0.5.3


# JAX 1. NumPy wrapper: `jax.numpy`

In [None]:
import numpy as np

x = np.ones((5000, 5000))
y = np.arange(5000)

%timeit z = np.sin(x) + np.cos(y)

1 loop, best of 3: 401 ms per loop


In [None]:
import jax.numpy as jnp
x = jnp.ones((5000, 5000))
y = jnp.arange(5000)

%timeit z = jnp.sin(x) + jnp.cos(y)

100 loops, best of 3: 2.15 ms per loop


# JAX 2. JIT compilation with `jax.jit`

In [None]:
from jax import jit
import tensorflow as tf

def fn(x, y):
  """Return sin(x) + cos(y), computed eagerly with plain NumPy.

  Serves as the uncompiled baseline for the jit / tf.function versions below.
  """
  return np.sin(x) + np.cos(y)

@jit
def fn_jit(x, y):
  """Return sin(x) + cos(y) using jax.numpy, compiled with `jax.jit`.

  Tracing and compilation happen on the first call for each new input shape/dtype;
  later calls with matching inputs reuse the compiled computation.
  """
  return jnp.sin(x) + jnp.cos(y)

@tf.function
def fn_tf2(x, y):
  """Return sin(x) + cos(y) with TensorFlow ops, traced into a graph by tf.function."""
  return tf.sin(x) + tf.cos(y)

In [None]:
x = np.ones((5000, 5000))
y = np.ones((5000, 5000))
%timeit fn(x, y)

1 loop, best of 3: 780 ms per loop


In [None]:
jx = jnp.ones((5000, 5000))
jy = jnp.ones((5000, 5000))
%timeit fn_jit(jx, jy)

100 loops, best of 3: 2.12 ms per loop


In [None]:
tx = tf.ones((5000, 5000))
ty = tf.ones((5000, 5000))
%timeit fn_tf2(tx, ty)

The slowest run took 4.55 times longer than the fastest. This could mean that an intermediate result is being cached.
1000 loops, best of 3: 3.36 ms per loop


# JAX 3. Automatic differentiation with `jax.grad`

In [None]:
from jax import grad

@jit
def simple_fun(x):
  """Compute sin(x) / x, extended continuously to the value 1 at x == 0.

  The naive expression jnp.sin(x) / x evaluates 0/0 at x == 0, so the value and every
  derivative there come out as NaN. jnp.sinc is the normalized sinc
  sin(pi * t) / (pi * t). JAX gives it custom derivative rules at 0, so it and its
  derivatives of all orders stay finite there. Rescaling the argument by 1/pi turns it
  back into sin(x) / x. For nonzero x the result can differ from the naive formula only
  by float rounding.

  Args:
    x: float scalar or array.

  Returns:
    Array with the same shape as `x`.
  """
  return jnp.sinc(x / jnp.pi)

In [None]:
# First derivative of simple_fun with respect to its only argument (argnums=0 is the
# default, spelled out for clarity).
grad_simple_fun = grad(simple_fun, argnums=0)

In [None]:
%timeit grad_simple_fun(1.0)

1000 loops, best of 3: 1.22 ms per loop


In [None]:
# Derivative of simple_fun at x = 0, 1, ..., 9, evaluated one point at a time; the
# displayed result is a Python list of scalar arrays.
x_range = jnp.arange(0, 10, dtype=jnp.float32)
list(map(grad_simple_fun, x_range))

[DeviceArray(nan, dtype=float32),
 DeviceArray(-0.30116874, dtype=float32),
 DeviceArray(-0.43539774, dtype=float32),
 DeviceArray(-0.3456775, dtype=float32),
 DeviceArray(-0.11611074, dtype=float32),
 DeviceArray(0.09508941, dtype=float32),
 DeviceArray(0.16778992, dtype=float32),
 DeviceArray(0.09429243, dtype=float32),
 DeviceArray(-0.03364623, dtype=float32),
 DeviceArray(-0.10632458, dtype=float32)]

In [None]:
# Second derivative: differentiate the first-derivative function once more
# (argnums=0, the default, spelled out at both levels).
grad_grad_simple_fun = grad(grad(simple_fun, argnums=0), argnums=0)

In [None]:
%timeit grad_grad_simple_fun(1.0)

The slowest run took 93.35 times longer than the fastest. This could mean that an intermediate result is being cached.
1 loop, best of 3: 3.19 ms per loop


In [None]:
# Second derivative of simple_fun evaluated at x = 1.0 (shown as the cell output).
grad_grad_simple_fun(1.0)

DeviceArray(-0.23913354, dtype=float32)

In [None]:
# Second derivative of simple_fun at x = 0, 1, ..., 9, one scalar call per point; the
# displayed result is a Python list of scalar arrays.
x_range = jnp.arange(0, 10, dtype=jnp.float32)
list(map(grad_grad_simple_fun, x_range))

[DeviceArray(nan, dtype=float32),
 DeviceArray(-0.23913354, dtype=float32),
 DeviceArray(-0.01925094, dtype=float32),
 DeviceArray(0.18341166, dtype=float32),
 DeviceArray(0.247256, dtype=float32),
 DeviceArray(0.1537491, dtype=float32),
 DeviceArray(-0.00936072, dtype=float32),
 DeviceArray(-0.12079593, dtype=float32),
 DeviceArray(-0.11525822, dtype=float32),
 DeviceArray(-0.02216326, dtype=float32)]

# Intro


Hands-on coverage of: JAX basics (`jit`/`grad`/`vmap`), pytrees, and PRNG best practices; autodiff theory and APIs (`jvp`/`vjp`, `jacfwd`/`jacrev`, custom rules); control flow; LAX primitives; performance (tracing, static arguments, buffer donation, rematerialization with `remat`, timing); vectorization and parallelism (`vmap`, `pmap`, an introduction to `pjit` and sharding); numerics and mixed precision; end-to-end models (MLP, CNN, RNN with `scan`, a compact Transformer block); optimization with Optax; saving and loading; debugging and profiling; and interoperability.