<a href="https://colab.research.google.com/github/RandomAnass/Data-Analysis-Course/blob/main/JAX_Tutorial.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
%tensorflow_version 2.x

TensorFlow 2.x selected.


In [1]:
!pip install --upgrade jax

Collecting jax
  Downloading jax-0.7.1-py3-none-any.whl.metadata (13 kB)
Collecting jaxlib<=0.7.1,>=0.7.1 (from jax)
  Downloading jaxlib-0.7.1-cp312-cp312-manylinux_2_27_x86_64.whl.metadata (1.3 kB)
Downloading jax-0.7.1-py3-none-any.whl (2.8 MB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m2.8/2.8 MB[0m [31m27.8 MB/s[0m eta [36m0:00:00[0m
[?25hDownloading jaxlib-0.7.1-cp312-cp312-manylinux_2_27_x86_64.whl (81.2 MB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m81.2/81.2 MB[0m [31m8.4 MB/s[0m eta [36m0:00:00[0m
[?25hInstalling collected packages: jaxlib, jax
  Attempting uninstall: jaxlib
    Found existing installation: jaxlib 0.5.3
    Uninstalling jaxlib-0.5.3:
      Successfully uninstalled jaxlib-0.5.3
  Attempting uninstall: jax
    Found existing installation: jax 0.5.3
    Uninstalling jax-0.5.3:
      Successfully uninstalled jax-0.5.3
Successfully installed jax-0.7.1 jaxlib-0.7.1


In [4]:
!pip uninstall -y jax_cuda12_plugin

Found existing installation: jax-cuda12-plugin 0.5.3
Uninstalling jax-cuda12-plugin-0.5.3:
  Successfully uninstalled jax-cuda12-plugin-0.5.3


# JAX 1. Numpy Wrapper

In [5]:
import numpy as np

x = np.ones((5000, 5000))
y = np.arange(5000)

%timeit z = np.sin(x) + np.cos(y)

497 ms ± 199 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)


In [6]:
import jax.numpy as jnp
x = jnp.ones((5000, 5000))
y = jnp.arange(5000)

%timeit z = jnp.sin(x) + jnp.cos(y)

The slowest run took 11.54 times longer than the fastest. This could mean that an intermediate result is being cached.
94.4 µs ± 135 µs per loop (mean ± std. dev. of 7 runs, 1 loop each)


# JAX 2. JIT Compiler

In [7]:
from jax import jit
import tensorflow as tf

def fn(x, y):
  z = np.sin(x)
  w = np.cos(y)
  return z + w

@jit
def fn_jit(x, y):
  z = jnp.sin(x)
  w = jnp.cos(y)
  return z + w

@tf.function
def fn_tf2(x, y):
  z = tf.sin(x)
  w = tf.cos(y)
  return z + w

In [8]:
x = np.ones((5000, 5000))
y = np.ones((5000, 5000))
%timeit fn(x, y)

644 ms ± 12.2 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)


In [9]:
jx = jnp.ones((5000, 5000))
jy = jnp.ones((5000, 5000))
%timeit fn_jit(jx, jy)

287 ms ± 17.5 ms per loop (mean ± std. dev. of 7 runs, 100 loops each)


In [10]:
tx = tf.ones((5000, 5000))
ty = tf.ones((5000, 5000))
%timeit fn_tf2(tx, ty)

2.83 ms ± 3.09 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)


# JAX 3. grad

In [None]:
from jax import grad

@jit
def simple_fun(x):
  return jnp.sin(x) / x

In [None]:
grad_simple_fun = grad(simple_fun)

In [None]:
%timeit grad_simple_fun(1.0)

1000 loops, best of 3: 1.22 ms per loop


In [None]:
x_range = jnp.arange(10, dtype=jnp.float32)
[grad_simple_fun(xi) for xi in x_range]

[DeviceArray(nan, dtype=float32),
 DeviceArray(-0.30116874, dtype=float32),
 DeviceArray(-0.43539774, dtype=float32),
 DeviceArray(-0.3456775, dtype=float32),
 DeviceArray(-0.11611074, dtype=float32),
 DeviceArray(0.09508941, dtype=float32),
 DeviceArray(0.16778992, dtype=float32),
 DeviceArray(0.09429243, dtype=float32),
 DeviceArray(-0.03364623, dtype=float32),
 DeviceArray(-0.10632458, dtype=float32)]

In [None]:
grad_grad_simple_fun = grad(grad(simple_fun))

In [None]:
%timeit grad_grad_simple_fun(1.0)

The slowest run took 93.35 times longer than the fastest. This could mean that an intermediate result is being cached.
1 loop, best of 3: 3.19 ms per loop


In [None]:
grad_grad_simple_fun(1.0)

DeviceArray(-0.23913354, dtype=float32)

In [None]:
x_range = jnp.arange(10, dtype=jnp.float32)
[grad_grad_simple_fun(xi) for xi in x_range]

[DeviceArray(nan, dtype=float32),
 DeviceArray(-0.23913354, dtype=float32),
 DeviceArray(-0.01925094, dtype=float32),
 DeviceArray(0.18341166, dtype=float32),
 DeviceArray(0.247256, dtype=float32),
 DeviceArray(0.1537491, dtype=float32),
 DeviceArray(-0.00936072, dtype=float32),
 DeviceArray(-0.12079593, dtype=float32),
 DeviceArray(-0.11525822, dtype=float32),
 DeviceArray(-0.02216326, dtype=float32)]

# Intro


hands-on coverage of: JAX basics (jit/grad/vmap), pytrees, PRNG best practices; autodiff theory & APIs (jvp/vjp, jacfwd/jacrev, custom rules); control flow; LAX primitives; performance (tracing, static args, donation, remat, timing); vectorization & parallelism (vmap, pmap, intro to pjit & sharding); numerics & mixed precision; end-to-end models (MLP, CNN, RNN with scan, compact Transformer block); optimization with Optax; saving/loading; debugging/profiling; interop

In [6]:
#!pip uninstall -y jax_cuda12_plugin
!pip install -U "jax[cuda12]" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
#!pip uninstall jaxlib -y
#!pip install "jax[cuda12]==0.7.1" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html

Looking in links: https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
Collecting jax-cuda12-plugin<=0.7.1,>=0.7.1 (from jax-cuda12-plugin[with-cuda]<=0.7.1,>=0.7.1; extra == "cuda12"->jax[cuda12])
  Downloading jax_cuda12_plugin-0.7.1-cp312-cp312-manylinux_2_27_x86_64.whl.metadata (2.0 kB)
Collecting jax-cuda12-pjrt==0.7.1 (from jax-cuda12-plugin<=0.7.1,>=0.7.1->jax-cuda12-plugin[with-cuda]<=0.7.1,>=0.7.1; extra == "cuda12"->jax[cuda12])
  Downloading jax_cuda12_pjrt-0.7.1-py3-none-manylinux_2_27_x86_64.whl.metadata (579 bytes)
Collecting nvidia-cuda-nvcc-cu12>=12.6.85 (from jax-cuda12-plugin[with-cuda]<=0.7.1,>=0.7.1; extra == "cuda12"->jax[cuda12])
  Downloading nvidia_cuda_nvcc_cu12-12.9.86-py3-none-manylinux2010_x86_64.manylinux_2_12_x86_64.whl.metadata (1.7 kB)
Collecting nvidia-nvshmem-cu12>=3.2.5 (from jax-cuda12-plugin[with-cuda]<=0.7.1,>=0.7.1; extra == "cuda12"->jax[cuda12])
  Downloading nvidia_nvshmem_cu12-3.3.24-py3-none-manylinux2014_x86_64.manylinux_2_17_x

In [11]:
%pip install -U optax flax orbax-checkpoint matplotlib

Collecting flax
  Downloading flax-0.11.2-py3-none-any.whl.metadata (11 kB)
Collecting orbax-checkpoint
  Downloading orbax_checkpoint-0.11.24-py3-none-any.whl.metadata (2.3 kB)
Collecting matplotlib
  Downloading matplotlib-3.10.6-cp312-cp312-manylinux2014_x86_64.manylinux_2_17_x86_64.whl.metadata (11 kB)
Downloading flax-0.11.2-py3-none-any.whl (458 kB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m458.1/458.1 kB[0m [31m16.0 MB/s[0m eta [36m0:00:00[0m
[?25hDownloading orbax_checkpoint-0.11.24-py3-none-any.whl (529 kB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m529.3/529.3 kB[0m [31m31.8 MB/s[0m eta [36m0:00:00[0m
[?25hDownloading matplotlib-3.10.6-cp312-cp312-manylinux2014_x86_64.manylinux_2_17_x86_64.whl (8.7 MB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m8.7/8.7 MB[0m [31m111.3 MB/s[0m eta [36m0:00:00[0m
[?25hInstalling collected packages: matplotlib, orbax-checkpoint, flax
  Attempting uninstall: matplotlib
  

In [1]:
import os, sys, math, time, functools, itertools, pickle
from dataclasses import dataclass
from typing import Any, Tuple, Dict

import jax
import jax.numpy as jnp
from jax import jit, grad, value_and_grad, random, lax, vmap, jacrev, jacfwd
import numpy as np

In [2]:
print('JAX version:', jax.__version__)
print('Backend:', jax.default_backend())
print('Devices:', jax.devices())

JAX version: 0.7.1
Backend: gpu
Devices: [CudaDevice(id=0)]


In [3]:
key = random.PRNGKey(0)

# JAX basics

(arrays, no in-place updates, jit, grad, vmap)
Key ideas: functional transforms (jit, grad, vmap) act on functions; no in-place mutation (use .at[...]); first jit call compiles; timings should block with .block_until_ready().

In [8]:
x = jnp.ones((5, 5))
y = jnp.arange(25).reshape(5,5)

In [10]:
print(y)
print(x+y)

[[ 0  1  2  3  4]
 [ 5  6  7  8  9]
 [10 11 12 13 14]
 [15 16 17 18 19]
 [20 21 22 23 24]]
[[ 1.  2.  3.  4.  5.]
 [ 6.  7.  8.  9. 10.]
 [11. 12. 13. 14. 15.]
 [16. 17. 18. 19. 20.]
 [21. 22. 23. 24. 25.]]


In [None]:
# Arrays & broadcasting
x = jnp.arange(6).reshape(2, 3)
y = jnp.ones((3,))
print('x:\n', x)
print('x + y:\n', x + y)  # broadcasting

# No in-place mutation -> use .at
z = x.at[:, 1].set(42)
print('z (col 1 set to 42):\n', z)

# JIT compilation: first call compiles, subsequent calls are fast
@jit
def poly(a, b, x):
    return a * x**2 + b * x + 1.0

xv = jnp.linspace(-3, 3, 100_000)
a, b = 2.0, -1.0

t0 = time.time()
out = poly(a, b, xv).block_until_ready()  # compile + run
t1 = time.time()
out2 = poly(a, b, xv).block_until_ready() # run-only
t2 = time.time()
print(f'First call (compile+run): {(t1 - t0)*1e3:.2f} ms')
print(f'Second call (run only):   {(t2 - t1)*1e3:.2f} ms')

# Autodiff with grad
def loss_fn(w, x, y):
    pred = w[0] * x + w[1]
    return jnp.mean((pred - y)**2)

w = jnp.array([2.0, -1.0])
x = jnp.linspace(-1, 1, 1000)
y = 3.0 * x + 0.5  # true slope=3, intercept=0.5

dloss_dw = grad(loss_fn)(w, x, y)
print('Gradients:', dloss_dw)

# vmap to batch without Python loops
def square_plus_one(t):
    return t*t + 1

batched_square = vmap(square_plus_one)  # vectorize over leading axis
print('vmap:', batched_square(jnp.arange(5.)))


# 2) Pytrees (structuring parameters)

Pytree = nested structure of leaves (arrays/scalars). Many JAX APIs accept pytrees directly.

In [None]:
from jax import tree_util

params = {'W': jnp.ones((3, 3)), 'b': jnp.zeros((3,))}
print('Leaves:', tree_util.tree_leaves(params))

# Map over leaves
scaled = jax.tree_map(lambda x: 0.5 * x, params)
print('Scaled b:', scaled['b'])

# dataclass as pytree
@dataclass
class MLPParams:
    W1: jnp.ndarray
    b1: jnp.ndarray
    W2: jnp.ndarray
    b2: jnp.ndarray

p = MLPParams(jnp.ones((4, 8)), jnp.zeros((8,)), jnp.ones((8, 2)), jnp.zeros((2,)))
print('Num leaves:', len(tree_util.tree_leaves(p)))


# 3) PRNG best practices (functional randomness)

Split keys; treat keys as inputs/outputs; use fold_in(step) for reproducible per-step seeding.

In [None]:
key = random.PRNGKey(42)
key, k1, k2 = random.split(key, 3)
print('Normal:', random.normal(k1, (3,)))
print('Uniform:', random.uniform(k2, (3,)))

# fold_in for deterministic per-step randomness
base = random.PRNGKey(123)
for step in range(3):
    k = random.fold_in(base, step)
    print('step', step, '->', random.normal(k, ()).item())


# 4) Autodiff theory + APIs (jvp/vjp, jacfwd/jacrev, custom_vjp/jvp)

Forward-mode AD (jvp, jacfwd) is efficient when input dim ≤ output dim.

Reverse-mode AD (grad, vjp, jacrev) is efficient when output dim ≤ input dim.

You can nest them for higher-order derivatives.

Custom rules (custom_vjp / custom_jvp) = numerical stability or non-standard ops.

In [None]:
# jvp/vjp example
def f(u):
    return jnp.array([u[0]*jnp.sin(u[1]), u[0]*jnp.cos(u[1])])  # R^2 -> R^2

u0 = jnp.array([2.0, 0.3])
tangent = jnp.array([1.0, -0.2])

y, jvp_val = jax.jvp(f, (u0,), (tangent,))
print('f(u0)=', y)
print('JVP(f; u0, tangent)=', jvp_val)

y, pullback = jax.vjp(f, u0)
bar = jnp.array([1.0, 0.0])  # seed cotangent
vjp_val = pullback(bar)[0]
print('VJP(f; u0, bar)=', vjp_val)

# Jacobians
def g(u):
    return jnp.array([u[0]**2 + u[1], jnp.sin(u[0]*u[1])])

J_fwd = jacfwd(g)(u0)
J_rev = jacrev(g)(u0)
print('Jacobian (fwd):\n', J_fwd)
print('Jacobian (rev):\n', J_rev)

# custom_vjp for a stable log-sum-exp
@jax.custom_vjp
def stable_logsumexp(x):
    m = jnp.max(x)
    return m + jnp.log(jnp.sum(jnp.exp(x - m)))

def slse_fwd(x):
    y = stable_logsumexp(x)
    return y, (x, y)

def slse_bwd(res, g):
    x, y = res
    p = jnp.exp(x - y)  # softmax probabilities
    return (g * p,)

stable_logsumexp.defvjp(slse_fwd, slse_bwd)

xx = jnp.array([1000.0, 999.0, 998.0])
print('stable LSE:', stable_logsumexp(xx))
print('grad check:', grad(lambda t: stable_logsumexp(t))(xx))


# 5) Control flow (lax.cond, lax.scan, while_loop)

In [None]:
# cond
def piecewise(x):
    return lax.cond(x > 0, lambda t: t*t, lambda t: -t, x)

print(piecewise(2.), piecewise(-3.))

# scan: sum over a sequence (with state "carry")
def scan_body(carry, x):
    s = carry + x
    return s, s  # (new_carry, stacked_output)

xs = jnp.arange(5.0)
carry0 = 0.0
final_carry, s_hist = lax.scan(scan_body, carry0, xs)
print('final sum:', final_carry)
print('s_hist:', s_hist)

# while_loop: simple countdown
def cond_fun(state):
    i, acc = state
    return i > 0

def body_fun(state):
    i, acc = state
    return (i-1, acc + i)

init = (5, 0)
i_end, acc = lax.while_loop(cond_fun, body_fun, init)
print('while result:', i_end, acc)


# 6) LAX primitives, shape polymorphism, eval_shape

In [None]:
# eval_shape to probe shapes
def my_net(params, x):
    for W, b in params:
        x = jnp.tanh(x @ W + b)
    return x

params = [(random.normal(random.PRNGKey(0), (32, 64)), jnp.zeros((64,))),
          (random.normal(random.PRNGKey(1), (64, 10)), jnp.zeros((10,)))]
sig = jax.eval_shape(my_net, params, jax.ShapeDtypeStruct((8, 32), jnp.float32))
print(sig)

# A tiny LAX convolution (N,H,W,C) * (KH,KW,C,OC)
N,H,W,C = 1, 8, 8, 1
x = random.normal(random.PRNGKey(0), (N,H,W,C))
Wk = random.normal(random.PRNGKey(1), (3,3,C,4))
y = lax.conv_general_dilated(x, Wk, window_strides=(1,1), padding='SAME')
print('conv out shape:', y.shape)


# 7) Performance: tracing, static args, donation, timing, lower/HLO

Keep shapes/dtypes steady to avoid retracing.

Mark Python values as static when they change control-flow/structure.

Use buffer donation to reduce copies.

Always time with .block_until_ready().

Peek at compiler IR via lower(...).as_text().

In [None]:
# Static args (e.g., a small Python "mode" flag)
@jit
def reg_loss(w, x, y, penalty: str = 'l2', lam: float = 1e-2):
    logits = x @ w
    nll = jnp.mean(jnp.logaddexp(0.0, -y * logits))
    if penalty == 'l2':
        reg = 0.5 * lam * jnp.sum(w * w)
    elif penalty == 'l1':
        reg = lam * jnp.sum(jnp.abs(w))
    else:
        raise ValueError('unknown penalty')
    return nll + reg

w = jnp.zeros((10,))
x = random.normal(random.PRNGKey(0), (128, 10))
y = jnp.sign(random.normal(random.PRNGKey(1), (128,)))
print(reg_loss(w, x, y, 'l2', 1e-2))

# Donation example (donate_argnums=0 donates first arg "arr")
@jit(donate_argnums=(0,))
def in_placey_add(arr, val):
    return arr + val

arr = jnp.ones((1_000_000,))
t0 = time.time(); arr = in_placey_add(arr, 1).block_until_ready(); t1 = time.time()
print('donated add ms:', (t1-t0)*1e3)

# Inspect lowered IR/HLO (API may vary slightly across versions)
def foo_to_lower(a, b):
    return a @ b

print(
    jit(foo_to_lower).lower(jnp.ones((64,64), jnp.float32), jnp.ones((64,64), jnp.float32)).as_text()[:500]
)


#Vectorization & parallelism (vmap, pmap (legacy), intro to pjit & sharding)

vmap: auto-batch on one device.

pmap: SPMD across devices (still useful, esp. TPUs).

pjit + explicit sharding with Mesh/PartitionSpec (recommended for new multi-device work).

In [None]:
# vmap: pairwise dot without Python loops
def dot(a, b):  # [..., d] x [..., d] -> [...]
    return jnp.sum(a * b, axis=-1)

A = random.normal(random.PRNGKey(0), (64, 128))
B = random.normal(random.PRNGKey(1), (64, 128))
batched_dot = vmap(dot)
print('batched dot shape:', batched_dot(A, B).shape)

# pmap demo (will degenerate to single-device behavior if you only have 1 device)
@jax.pmap
def add_one(x):
    return x + 1

xs = jnp.arange(jax.device_count())
print('device_count =', jax.device_count())
print('pmap result:', add_one(xs))

# pjit-style sharding (pattern only; single-device becomes no-op)
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

devices = np.array(jax.devices())
mesh = Mesh(devices[:max(1, min(2, devices.size))], ('data',))
print('Mesh axes:', mesh.axis_names)


# 9) Numerics & mixed precision (stable softmax, gradient clipping)

Use numerically stable forms; consider mixed precision (fp16/bfloat16 for matmuls/conv, fp32 for accumulators).

In [None]:
def softmax(x, axis=-1):
    z = x - jnp.max(x, axis=axis, keepdims=True)
    num = jnp.exp(z)
    return num / jnp.sum(num, axis=axis, keepdims=True)

x = jnp.array([[1000., 0., -1000.]])
print('stable softmax:', softmax(x))

# Simple global-norm clip util
def clip_by_global_norm(tree, max_norm=1.0):
    gsq = sum([jnp.sum(jnp.square(g)) for g in jax.tree_util.tree_leaves(tree)])
    scale = jnp.minimum(1.0, max_norm / (jnp.sqrt(gsq) + 1e-8))
    return jax.tree_map(lambda g: g * scale, tree)


# 10) End-to-end MLP (from scratch)

Synthetic 2-class data, MLP, training step with jit.

In [None]:
import matplotlib.pyplot as plt

def make_toy_data(k, n_per_class=512, spread=0.7, seed=0):
    key = random.PRNGKey(seed)
    key1, key2 = random.split(key)
    c0 = random.normal(key1, (n_per_class, 2)) * spread + jnp.array([0., -k])
    c1 = random.normal(key2, (n_per_class, 2)) * spread + jnp.array([0., k])
    X = jnp.concatenate([c0, c1], axis=0)
    y = jnp.concatenate([-jnp.ones((n_per_class,)), jnp.ones((n_per_class,))], axis=0)
    return X, y

X, y = make_toy_data(2.5)

def init_mlp(key, sizes):
    keys = random.split(key, len(sizes)-1)
    params = []
    for k, (m, n) in zip(keys, zip(sizes[:-1], sizes[1:])):
        W = random.normal(k, (m, n)) / jnp.sqrt(m)
        b = jnp.zeros((n,))
        params.append((W, b))
    return params

def mlp_apply(params, x):
    for (W, b) in params[:-1]:
        x = jnp.tanh(x @ W + b)
    W, b = params[-1]
    return x @ W + b  # logits

def binary_loss(params, x, y):
    logits = mlp_apply(params, x).squeeze(-1)
    return jnp.mean(jnp.logaddexp(0.0, -y * logits))  # logistic loss

@jit
def train_step(params, x, y, lr=1e-2):
    loss, grads = value_and_grad(binary_loss)(params, x, y)
    # grads = clip_by_global_norm(grads, 1.0)  # optional
    params = jax.tree_map(lambda p, g: p - lr * g, params, grads)
    return params, loss

key = random.PRNGKey(0)
params = init_mlp(key, [2, 64, 64, 1])

losses = []
for step in range(400):
    params, l = train_step(params, X, y)
    losses.append(float(l))

plt.figure()
plt.plot(losses)
plt.title('MLP Training Loss')
plt.xlabel('step'); plt.ylabel('loss')
plt.show()


# 11) CNN (from scratch, synthetic images)

In [None]:
# Fake images (N, H, W, C)
N, H, W, C = 64, 28, 28, 1
key = random.PRNGKey(1)
imgs = random.normal(key, (N, H, W, C))
labels = jnp.where(jnp.mean(imgs, axis=(1,2,3)) > 0, 1., -1.)  # separable synthetic target

def init_cnn(key):
    k1, k2, k3 = random.split(key, 3)
    W1 = random.normal(k1, (3, 3, C, 8)) / jnp.sqrt(3*3*C)
    b1 = jnp.zeros((8,))
    W2 = random.normal(k2, (3, 3, 8, 16)) / jnp.sqrt(3*3*8)
    b2 = jnp.zeros((16,))
    W3 = random.normal(k3, (16*7*7, 1)) / jnp.sqrt(16*7*7)
    b3 = jnp.zeros((1,))
    return (W1,b1,W2,b2,W3,b3)

def cnn_apply(params, x):
    W1,b1,W2,b2,W3,b3 = params
    y = lax.conv_general_dilated(x, W1, (1,1), 'SAME')
    y = jnp.tanh(y + b1)
    y = lax.conv_general_dilated(y, W2, (2,2), 'SAME')  # stride 2
    y = jnp.tanh(y + b2)
    y = y.reshape((y.shape[0], -1))
    y = y @ W3 + b3
    return y.squeeze(-1)

@jit
def cnn_loss(params, x, y):
    logits = cnn_apply(params, x)
    return jnp.mean(jnp.logaddexp(0.0, -y * logits))

@jit
def cnn_step(params, x, y, lr=1e-2):
    l, g = value_and_grad(cnn_loss)(params, x, y)
    params = jax.tree_map(lambda p, gg: p - lr * gg, params, g)
    return params, l

params_cnn = init_cnn(random.PRNGKey(2))
for step in range(200):
    params_cnn, l = cnn_step(params_cnn, imgs, labels)
print('Final CNN loss:', float(l))


# 12) RNN via lax.scan (sequence modeling)

In [None]:
T, B, D_in, D_hid = 50, 32, 16, 32
key = random.PRNGKey(3)
xs = random.normal(key, (T, B, D_in))
ys = random.normal(key, (T, B, D_hid))

def init_rnn(key):
    k1,k2,k3 = random.split(key, 3)
    params = dict(
        Wx = random.normal(k1, (D_in, D_hid)) / jnp.sqrt(D_in),
        Wh = random.normal(k2, (D_hid, D_hid)) / jnp.sqrt(D_hid),
        b  = jnp.zeros((D_hid,))
    )
    return params

def rnn_step(params, h, x):
    h_new = jnp.tanh(x @ params['Wx'] + h @ params['Wh'] + params['b'])
    return h_new, h_new

def rnn_apply(params, x_seq, h0):
    hT, hs = lax.scan(lambda h, x: rnn_step(params, h, x), h0, x_seq)
    return hT, hs

params_rnn = init_rnn(key)
h0 = jnp.zeros((B, D_hid))
_, hs = rnn_apply(params_rnn, xs, h0)
print('RNN hidden seq shape:', hs.shape)  # (T, B, D_hid)


# 13) Compact Transformer encoder block (no Flax)

In [None]:
def layer_norm(x, eps=1e-5):
    m = jnp.mean(x, axis=-1, keepdims=True)
    v = jnp.mean((x - m)**2, axis=-1, keepdims=True)
    return (x - m) / jnp.sqrt(v + eps)

def init_transformer_block(key, d_model=128, n_heads=4, d_ff=256):
    kq, kk, kv, ko, k1, k2 = random.split(key, 6)
    params = dict(
        Wq = random.normal(kq, (d_model, d_model)) / jnp.sqrt(d_model),
        Wk = random.normal(kk, (d_model, d_model)) / jnp.sqrt(d_model),
        Wv = random.normal(kv, (d_model, d_model)) / jnp.sqrt(d_model),
        Wo = random.normal(ko, (d_model, d_model)) / jnp.sqrt(d_model),
        W1 = random.normal(k1, (d_model, d_ff)) / jnp.sqrt(d_model),
        W2 = random.normal(k2, (d_ff, d_model)) / jnp.sqrt(d_ff),
        b1 = jnp.zeros((d_ff,)),
        b2 = jnp.zeros((d_model,)),
    )
    head_dim = d_model // n_heads
    return params, head_dim, n_heads

def split_heads(x, n_heads):
    B,T,D = x.shape
    return x.reshape(B, T, n_heads, D//n_heads).transpose(0,2,1,3)  # (B,H,T,dh)

def combine_heads(x):
    B,H,T,dh = x.shape
    return x.transpose(0,2,1,3).reshape(B, T, H*dh)

def mhsa(params, x, head_dim, n_heads, mask=None):
    Q = x @ params['Wq']; K = x @ params['Wk']; V = x @ params['Wv']
    Qh, Kh, Vh = split_heads(Q, n_heads), split_heads(K, n_heads), split_heads(V, n_heads)
    scale = 1.0 / jnp.sqrt(head_dim)
    attn = jnp.einsum('bhtd,bhTd->bhtT', Qh, Kh) * scale  # (B,H,T,T)
    if mask is not None:
        attn = jnp.where(mask, attn, -1e9)
    probs = softmax(attn, axis=-1)
    out = jnp.einsum('bhtT,bhTd->bhtd', probs, Vh)
    out = combine_heads(out) @ params['Wo']
    return out

def transformer_block(params, x, head_dim, n_heads, mask=None):
    x = layer_norm(x + mhsa(params, x, head_dim, n_heads, mask))
    y = jnp.tanh(x @ params['W1'] + params['b1'])
    y = y @ params['W2'] + params['b2']
    x = layer_norm(x + y)
    return x

# Demo
B, T, D = 8, 32, 128
key = random.PRNGKey(0)
x = random.normal(key, (B, T, D))
params_tx, head_dim, n_heads = init_transformer_block(key, d_model=D, n_heads=4, d_ff=256)
y = transformer_block(params_tx, x, head_dim, n_heads)
print('Transformer out:', y.shape)


# 14) Optimization with Optax (adamw, schedules, clipping)

In [None]:
try:
    import optax

    def make_optimizer(lr=1e-3, wd=1e-4):
        return optax.chain(
            optax.clip_by_global_norm(1.0),
            optax.adamw(learning_rate=lr, weight_decay=wd)
        )

    @dataclass
    class TrainState:
        params: Any
        opt_state: Any

    def init_train_state(params, opt):
        return TrainState(params=params, opt_state=opt.init(params))

    def apply_updates(params, updates):
        return optax.apply_updates(params, updates)

    @jit
    def train_step_optax(state, x, y, opt):
        loss, grads = value_and_grad(binary_loss)(state.params, x, y)
        updates, new_opt_state = opt.update(grads, state.opt_state, state.params)
        new_params = apply_updates(state.params, updates)
        return TrainState(new_params, new_opt_state), loss

    # Demo on MLP data
    key = random.PRNGKey(0)
    params0 = init_mlp(key, [2, 64, 64, 1])
    opt = make_optimizer(1e-2, 1e-4)
    state = init_train_state(params0, opt)
    for step in range(200):
        state, l = train_step_optax(state, X, y, opt)
    print('Optax demo loss:', float(l))

except Exception as e:
    print('Optax not available:', e)


# Memory: rematerialization (remat/checkpoint), scans vs vmaps, donation

In [None]:
# Rematerialization: recompute activations during backward to save memory
from jax import checkpoint as remat

def big_chain(x, depth=10):
    for _ in range(depth):
        x = jnp.tanh(x @ x.T)
    return x

@jit
def f_no_remat(x): return big_chain(x)

@jit
def f_with_remat(x): return remat(big_chain)(x)

Xbig = random.normal(random.PRNGKey(0), (128, 128))
_ = f_no_remat(Xbig).block_until_ready()
_ = f_with_remat(Xbig).block_until_ready()
print('Remat demo complete (profile on your hardware for real gains).')


# Saving / loading params (simple & robust)

In [None]:
# 1) Pickle (simple)
def tree_save_pickle(path, pytree):
    with open(path, 'wb') as f:
        pickle.dump(pytree, f)

def tree_load_pickle(path):
    with open(path, 'rb') as f:
        return pickle.load(f)

tree_save_pickle('params_demo.pkl', params)
params_loaded = tree_load_pickle('params_demo.pkl')
print('pickle load ok:', jax.tree_util.tree_structure(params) == jax.tree_util.tree_structure(params_loaded))

# 2) NPZ with treedef (object array)
from jax import tree_util

def tree_save_npz(path, pytree):
    leaves, treedef = tree_util.tree_flatten(pytree)
    save_dict = {f'leaf_{i}': np.asarray(leaf) for i, leaf in enumerate(leaves)}
    save_dict['treedef'] = np.array([treedef], dtype=object)
    np.savez(path, **save_dict)

def tree_load_npz(path):
    with np.load(path, allow_pickle=True) as data:
        treedef = data['treedef'][0]
        leaf_keys = sorted([k for k in data.files if k.startswith('leaf_')],
                           key=lambda s: int(s.split('_')[1]))
        leaves = [data[k] for k in leaf_keys]
    return tree_util.tree_unflatten(treedef, leaves)

tree_save_npz('params_demo.npz', params)
params2 = tree_load_npz('params_demo.npz')
print('npz load ok:', jax.tree_util.tree_structure(params) == jax.tree_util.tree_structure(params2))


In [None]:
got it — here’s a self-contained, section-by-section JAX “from zero to expert” you can paste straight into a Jupyter notebook. each section has short notes + runnable code. if anything’s unclear, ping me and we’ll go deeper.

---

# JAX “From Zero to Expert”

> hands-on coverage of: JAX basics (jit/grad/vmap), pytrees, PRNG best practices; autodiff theory & APIs (jvp/vjp, jacfwd/jacrev, custom rules); control flow; LAX primitives; performance (tracing, static args, donation, remat, timing); vectorization & parallelism (vmap, pmap, intro to pjit & sharding); numerics & mixed precision; end-to-end models (MLP, CNN, RNN with scan, compact Transformer block); optimization with Optax; saving/loading; debugging/profiling; interop — plus a few advanced extras.

---

## 0) Setup (run first)

**Tip:** pick the correct JAX wheel for your hardware. CPU is simplest; CUDA wheels depend on your CUDA version.

```python
# If you're on CPU-only:
# %pip install -U "jax[cpu]"

# If you have CUDA (example for CUDA 12):
# %pip install -U "jax[cuda12]" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html

# Optional but recommended for sections below:
# %pip install -U optax flax orbax-checkpoint matplotlib

import os, sys, math, time, functools, itertools, pickle
from dataclasses import dataclass
from typing import Any, Tuple, Dict

import jax
import jax.numpy as jnp
from jax import jit, grad, value_and_grad, random, lax, vmap, jacrev, jacfwd
import numpy as np

print('JAX version:', jax.__version__)
print('Backend:', jax.default_backend())
print('Devices:', jax.devices())

# Optional helpful toggles:
from jax import config as jax_config
# jax_config.update('jax_enable_x64', True)    # Enable 64-bit if needed (slower on GPU)
# os.environ['JAX_DEBUG_NANS'] = 'true'        # Extra NaN checks inside compiled code
# os.environ['XLA_PYTHON_CLIENT_PREALLOCATE'] = 'false'  # Avoid preallocating full GPU memory

key = random.PRNGKey(0)
```

---

## 1) JAX basics (arrays, no in-place updates, jit, grad, vmap)

Key ideas: functional transforms (`jit`, `grad`, `vmap`) act on **functions**; no in-place mutation (use `.at[...]`); first `jit` call compiles; timings should block with `.block_until_ready()`.

```python
# Arrays & broadcasting
x = jnp.arange(6).reshape(2, 3)
y = jnp.ones((3,))
print('x:\n', x)
print('x + y:\n', x + y)  # broadcasting

# No in-place mutation -> use .at
z = x.at[:, 1].set(42)
print('z (col 1 set to 42):\n', z)

# JIT compilation: first call compiles, subsequent calls are fast
@jit
def poly(a, b, x):
    return a * x**2 + b * x + 1.0

xv = jnp.linspace(-3, 3, 100_000)
a, b = 2.0, -1.0

t0 = time.time()
out = poly(a, b, xv).block_until_ready()  # compile + run
t1 = time.time()
out2 = poly(a, b, xv).block_until_ready() # run-only
t2 = time.time()
print(f'First call (compile+run): {(t1 - t0)*1e3:.2f} ms')
print(f'Second call (run only):   {(t2 - t1)*1e3:.2f} ms')

# Autodiff with grad
def loss_fn(w, x, y):
    pred = w[0] * x + w[1]
    return jnp.mean((pred - y)**2)

w = jnp.array([2.0, -1.0])
x = jnp.linspace(-1, 1, 1000)
y = 3.0 * x + 0.5  # true slope=3, intercept=0.5

dloss_dw = grad(loss_fn)(w, x, y)
print('Gradients:', dloss_dw)

# vmap to batch without Python loops
def square_plus_one(t):
    return t*t + 1

batched_square = vmap(square_plus_one)  # vectorize over leading axis
print('vmap:', batched_square(jnp.arange(5.)))
```

---

## 2) Pytrees (structuring parameters)

Pytree = nested structure of leaves (arrays/scalars). Many JAX APIs accept pytrees directly.

```python
from jax import tree_util

params = {'W': jnp.ones((3, 3)), 'b': jnp.zeros((3,))}
print('Leaves:', tree_util.tree_leaves(params))

# Map over leaves
scaled = jax.tree_map(lambda x: 0.5 * x, params)
print('Scaled b:', scaled['b'])

# dataclass as pytree
@dataclass
class MLPParams:
    W1: jnp.ndarray
    b1: jnp.ndarray
    W2: jnp.ndarray
    b2: jnp.ndarray

p = MLPParams(jnp.ones((4, 8)), jnp.zeros((8,)), jnp.ones((8, 2)), jnp.zeros((2,)))
print('Num leaves:', len(tree_util.tree_leaves(p)))
```

---

## 3) PRNG best practices (functional randomness)

Split keys; treat keys as inputs/outputs; use `fold_in(step)` for reproducible per-step seeding.

```python
key = random.PRNGKey(42)
key, k1, k2 = random.split(key, 3)
print('Normal:', random.normal(k1, (3,)))
print('Uniform:', random.uniform(k2, (3,)))

# fold_in for deterministic per-step randomness
base = random.PRNGKey(123)
for step in range(3):
    k = random.fold_in(base, step)
    print('step', step, '->', random.normal(k, ()).item())
```

---

## 4) Autodiff theory + APIs (jvp/vjp, jacfwd/jacrev, custom\_vjp/jvp)

* Forward-mode AD (`jvp`, `jacfwd`) is efficient when input dim ≤ output dim.
* Reverse-mode AD (`grad`, `vjp`, `jacrev`) is efficient when output dim ≤ input dim.
* You can nest them for higher-order derivatives.
* Custom rules (`custom_vjp` / `custom_jvp`) = numerical stability or non-standard ops.

```python
# jvp/vjp example
def f(u):
    return jnp.array([u[0]*jnp.sin(u[1]), u[0]*jnp.cos(u[1])])  # R^2 -> R^2

u0 = jnp.array([2.0, 0.3])
tangent = jnp.array([1.0, -0.2])

y, jvp_val = jax.jvp(f, (u0,), (tangent,))
print('f(u0)=', y)
print('JVP(f; u0, tangent)=', jvp_val)

y, pullback = jax.vjp(f, u0)
bar = jnp.array([1.0, 0.0])  # seed cotangent
vjp_val = pullback(bar)[0]
print('VJP(f; u0, bar)=', vjp_val)

# Jacobians
def g(u):
    return jnp.array([u[0]**2 + u[1], jnp.sin(u[0]*u[1])])

J_fwd = jacfwd(g)(u0)
J_rev = jacrev(g)(u0)
print('Jacobian (fwd):\n', J_fwd)
print('Jacobian (rev):\n', J_rev)

# custom_vjp for a stable log-sum-exp
@jax.custom_vjp
def stable_logsumexp(x):
    m = jnp.max(x)
    return m + jnp.log(jnp.sum(jnp.exp(x - m)))

def slse_fwd(x):
    y = stable_logsumexp(x)
    return y, (x, y)

def slse_bwd(res, g):
    x, y = res
    p = jnp.exp(x - y)  # softmax probabilities
    return (g * p,)

stable_logsumexp.defvjp(slse_fwd, slse_bwd)

xx = jnp.array([1000.0, 999.0, 998.0])
print('stable LSE:', stable_logsumexp(xx))
print('grad check:', grad(lambda t: stable_logsumexp(t))(xx))
```

---

## 5) Control flow (lax.cond, lax.scan, while\_loop)

Prefer JAX control-flow primitives **inside jitted code**.

```python
# cond
def piecewise(x):
    return lax.cond(x > 0, lambda t: t*t, lambda t: -t, x)

print(piecewise(2.), piecewise(-3.))

# scan: sum over a sequence (with state "carry")
def scan_body(carry, x):
    s = carry + x
    return s, s  # (new_carry, stacked_output)

xs = jnp.arange(5.0)
carry0 = 0.0
final_carry, s_hist = lax.scan(scan_body, carry0, xs)
print('final sum:', final_carry)
print('s_hist:', s_hist)

# while_loop: simple countdown
def cond_fun(state):
    i, acc = state
    return i > 0

def body_fun(state):
    i, acc = state
    return (i-1, acc + i)

init = (5, 0)
i_end, acc = lax.while_loop(cond_fun, body_fun, init)
print('while result:', i_end, acc)
```

---

## 6) LAX primitives, shape polymorphism, eval\_shape

LAX exposes low-level ops (`lax.dot_general`, `lax.conv_general_dilated`). Use `jax.eval_shape` to inspect shapes/dtypes **without allocating**.

```python
# eval_shape to probe shapes
def my_net(params, x):
    for W, b in params:
        x = jnp.tanh(x @ W + b)
    return x

params = [(random.normal(random.PRNGKey(0), (32, 64)), jnp.zeros((64,))),
          (random.normal(random.PRNGKey(1), (64, 10)), jnp.zeros((10,)))]
sig = jax.eval_shape(my_net, params, jax.ShapeDtypeStruct((8, 32), jnp.float32))
print(sig)

# A tiny LAX convolution (N,H,W,C) * (KH,KW,C,OC)
N,H,W,C = 1, 8, 8, 1
x = random.normal(random.PRNGKey(0), (N,H,W,C))
Wk = random.normal(random.PRNGKey(1), (3,3,C,4))
y = lax.conv_general_dilated(x, Wk, window_strides=(1,1), padding='SAME')
print('conv out shape:', y.shape)
```

---

## 7) Performance: tracing, static args, donation, timing, lower/HLO

* Keep shapes/dtypes steady to avoid retracing.
* Mark Python values as **static** when they change control-flow/structure.
* Use **buffer donation** to reduce copies.
* Always time with `.block_until_ready()`.
* Peek at compiler IR via `lower(...).as_text()`.

```python
# Static args (e.g., a small Python "mode" flag)
@jit
def reg_loss(w, x, y, penalty: str = 'l2', lam: float = 1e-2):
    logits = x @ w
    nll = jnp.mean(jnp.logaddexp(0.0, -y * logits))
    if penalty == 'l2':
        reg = 0.5 * lam * jnp.sum(w * w)
    elif penalty == 'l1':
        reg = lam * jnp.sum(jnp.abs(w))
    else:
        raise ValueError('unknown penalty')
    return nll + reg

w = jnp.zeros((10,))
x = random.normal(random.PRNGKey(0), (128, 10))
y = jnp.sign(random.normal(random.PRNGKey(1), (128,)))
print(reg_loss(w, x, y, 'l2', 1e-2))

# Donation example (donate_argnums=0 donates first arg "arr")
@jit(donate_argnums=(0,))
def in_placey_add(arr, val):
    return arr + val

arr = jnp.ones((1_000_000,))
t0 = time.time(); arr = in_placey_add(arr, 1).block_until_ready(); t1 = time.time()
print('donated add ms:', (t1-t0)*1e3)

# Inspect lowered IR/HLO (API may vary slightly across versions)
def foo_to_lower(a, b):
    return a @ b

print(
    jit(foo_to_lower).lower(jnp.ones((64,64), jnp.float32), jnp.ones((64,64), jnp.float32)).as_text()[:500]
)
```

---

## 8) Vectorization & parallelism (vmap, pmap (legacy), intro to pjit & sharding)

* `vmap`: auto-batch on one device.
* `pmap`: SPMD across devices (still useful, esp. TPUs).
* `pjit` + **explicit sharding** with Mesh/PartitionSpec (recommended for new multi-device work).

```python
# vmap: pairwise dot without Python loops
def dot(a, b):  # [..., d] x [..., d] -> [...]
    return jnp.sum(a * b, axis=-1)

A = random.normal(random.PRNGKey(0), (64, 128))
B = random.normal(random.PRNGKey(1), (64, 128))
batched_dot = vmap(dot)
print('batched dot shape:', batched_dot(A, B).shape)

# pmap demo (will degenerate to single-device behavior if you only have 1 device)
@jax.pmap
def add_one(x):
    return x + 1

xs = jnp.arange(jax.device_count())
print('device_count =', jax.device_count())
print('pmap result:', add_one(xs))

# pjit-style sharding (pattern only; single-device becomes no-op)
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

devices = np.array(jax.devices())
mesh = Mesh(devices[:max(1, min(2, devices.size))], ('data',))
print('Mesh axes:', mesh.axis_names)
```

---

## 9) Numerics & mixed precision (stable softmax, gradient clipping)

Use numerically stable forms; consider mixed precision (fp16/bfloat16 for matmuls/conv, fp32 for accumulators).

```python
def softmax(x, axis=-1):
    z = x - jnp.max(x, axis=axis, keepdims=True)
    num = jnp.exp(z)
    return num / jnp.sum(num, axis=axis, keepdims=True)

x = jnp.array([[1000., 0., -1000.]])
print('stable softmax:', softmax(x))

# Simple global-norm clip util
def clip_by_global_norm(tree, max_norm=1.0):
    gsq = sum([jnp.sum(jnp.square(g)) for g in jax.tree_util.tree_leaves(tree)])
    scale = jnp.minimum(1.0, max_norm / (jnp.sqrt(gsq) + 1e-8))
    return jax.tree_map(lambda g: g * scale, tree)
```

---

## 10) End-to-end MLP (from scratch)

Synthetic 2-class data, MLP, training step with `jit`.

```python
import matplotlib.pyplot as plt

def make_toy_data(k, n_per_class=512, spread=0.7, seed=0):
    key = random.PRNGKey(seed)
    key1, key2 = random.split(key)
    c0 = random.normal(key1, (n_per_class, 2)) * spread + jnp.array([0., -k])
    c1 = random.normal(key2, (n_per_class, 2)) * spread + jnp.array([0., k])
    X = jnp.concatenate([c0, c1], axis=0)
    y = jnp.concatenate([-jnp.ones((n_per_class,)), jnp.ones((n_per_class,))], axis=0)
    return X, y

X, y = make_toy_data(2.5)

def init_mlp(key, sizes):
    keys = random.split(key, len(sizes)-1)
    params = []
    for k, (m, n) in zip(keys, zip(sizes[:-1], sizes[1:])):
        W = random.normal(k, (m, n)) / jnp.sqrt(m)
        b = jnp.zeros((n,))
        params.append((W, b))
    return params

def mlp_apply(params, x):
    for (W, b) in params[:-1]:
        x = jnp.tanh(x @ W + b)
    W, b = params[-1]
    return x @ W + b  # logits

def binary_loss(params, x, y):
    logits = mlp_apply(params, x).squeeze(-1)
    return jnp.mean(jnp.logaddexp(0.0, -y * logits))  # logistic loss

@jit
def train_step(params, x, y, lr=1e-2):
    loss, grads = value_and_grad(binary_loss)(params, x, y)
    # grads = clip_by_global_norm(grads, 1.0)  # optional
    params = jax.tree_map(lambda p, g: p - lr * g, params, grads)
    return params, loss

key = random.PRNGKey(0)
params = init_mlp(key, [2, 64, 64, 1])

losses = []
for step in range(400):
    params, l = train_step(params, X, y)
    losses.append(float(l))

plt.figure()
plt.plot(losses)
plt.title('MLP Training Loss')
plt.xlabel('step'); plt.ylabel('loss')
plt.show()
```

---

## 11) CNN (from scratch, synthetic images)

```python
# Fake images (N, H, W, C)
N, H, W, C = 64, 28, 28, 1
key = random.PRNGKey(1)
imgs = random.normal(key, (N, H, W, C))
labels = jnp.where(jnp.mean(imgs, axis=(1,2,3)) > 0, 1., -1.)  # separable synthetic target

def init_cnn(key):
    k1, k2, k3 = random.split(key, 3)
    W1 = random.normal(k1, (3, 3, C, 8)) / jnp.sqrt(3*3*C)
    b1 = jnp.zeros((8,))
    W2 = random.normal(k2, (3, 3, 8, 16)) / jnp.sqrt(3*3*8)
    b2 = jnp.zeros((16,))
    W3 = random.normal(k3, (16*7*7, 1)) / jnp.sqrt(16*7*7)
    b3 = jnp.zeros((1,))
    return (W1,b1,W2,b2,W3,b3)

def cnn_apply(params, x):
    W1,b1,W2,b2,W3,b3 = params
    y = lax.conv_general_dilated(x, W1, (1,1), 'SAME')
    y = jnp.tanh(y + b1)
    y = lax.conv_general_dilated(y, W2, (2,2), 'SAME')  # stride 2
    y = jnp.tanh(y + b2)
    y = y.reshape((y.shape[0], -1))
    y = y @ W3 + b3
    return y.squeeze(-1)

@jit
def cnn_loss(params, x, y):
    logits = cnn_apply(params, x)
    return jnp.mean(jnp.logaddexp(0.0, -y * logits))

@jit
def cnn_step(params, x, y, lr=1e-2):
    l, g = value_and_grad(cnn_loss)(params, x, y)
    params = jax.tree_map(lambda p, gg: p - lr * gg, params, g)
    return params, l

params_cnn = init_cnn(random.PRNGKey(2))
for step in range(200):
    params_cnn, l = cnn_step(params_cnn, imgs, labels)
print('Final CNN loss:', float(l))
```

---

## 12) RNN via `lax.scan` (sequence modeling)

```python
T, B, D_in, D_hid = 50, 32, 16, 32
key = random.PRNGKey(3)
xs = random.normal(key, (T, B, D_in))
ys = random.normal(key, (T, B, D_hid))

def init_rnn(key):
    k1,k2,k3 = random.split(key, 3)
    params = dict(
        Wx = random.normal(k1, (D_in, D_hid)) / jnp.sqrt(D_in),
        Wh = random.normal(k2, (D_hid, D_hid)) / jnp.sqrt(D_hid),
        b  = jnp.zeros((D_hid,))
    )
    return params

def rnn_step(params, h, x):
    h_new = jnp.tanh(x @ params['Wx'] + h @ params['Wh'] + params['b'])
    return h_new, h_new

def rnn_apply(params, x_seq, h0):
    hT, hs = lax.scan(lambda h, x: rnn_step(params, h, x), h0, x_seq)
    return hT, hs

params_rnn = init_rnn(key)
h0 = jnp.zeros((B, D_hid))
_, hs = rnn_apply(params_rnn, xs, h0)
print('RNN hidden seq shape:', hs.shape)  # (T, B, D_hid)
```

---

## 13) Compact Transformer encoder block (no Flax)

```python
def layer_norm(x, eps=1e-5):
    m = jnp.mean(x, axis=-1, keepdims=True)
    v = jnp.mean((x - m)**2, axis=-1, keepdims=True)
    return (x - m) / jnp.sqrt(v + eps)

def init_transformer_block(key, d_model=128, n_heads=4, d_ff=256):
    kq, kk, kv, ko, k1, k2 = random.split(key, 6)
    params = dict(
        Wq = random.normal(kq, (d_model, d_model)) / jnp.sqrt(d_model),
        Wk = random.normal(kk, (d_model, d_model)) / jnp.sqrt(d_model),
        Wv = random.normal(kv, (d_model, d_model)) / jnp.sqrt(d_model),
        Wo = random.normal(ko, (d_model, d_model)) / jnp.sqrt(d_model),
        W1 = random.normal(k1, (d_model, d_ff)) / jnp.sqrt(d_model),
        W2 = random.normal(k2, (d_ff, d_model)) / jnp.sqrt(d_ff),
        b1 = jnp.zeros((d_ff,)),
        b2 = jnp.zeros((d_model,)),
    )
    head_dim = d_model // n_heads
    return params, head_dim, n_heads

def split_heads(x, n_heads):
    B,T,D = x.shape
    return x.reshape(B, T, n_heads, D//n_heads).transpose(0,2,1,3)  # (B,H,T,dh)

def combine_heads(x):
    B,H,T,dh = x.shape
    return x.transpose(0,2,1,3).reshape(B, T, H*dh)

def mhsa(params, x, head_dim, n_heads, mask=None):
    Q = x @ params['Wq']; K = x @ params['Wk']; V = x @ params['Wv']
    Qh, Kh, Vh = split_heads(Q, n_heads), split_heads(K, n_heads), split_heads(V, n_heads)
    scale = 1.0 / jnp.sqrt(head_dim)
    attn = jnp.einsum('bhtd,bhTd->bhtT', Qh, Kh) * scale  # (B,H,T,T)
    if mask is not None:
        attn = jnp.where(mask, attn, -1e9)
    probs = softmax(attn, axis=-1)
    out = jnp.einsum('bhtT,bhTd->bhtd', probs, Vh)
    out = combine_heads(out) @ params['Wo']
    return out

def transformer_block(params, x, head_dim, n_heads, mask=None):
    x = layer_norm(x + mhsa(params, x, head_dim, n_heads, mask))
    y = jnp.tanh(x @ params['W1'] + params['b1'])
    y = y @ params['W2'] + params['b2']
    x = layer_norm(x + y)
    return x

# Demo
B, T, D = 8, 32, 128
key = random.PRNGKey(0)
x = random.normal(key, (B, T, D))
params_tx, head_dim, n_heads = init_transformer_block(key, d_model=D, n_heads=4, d_ff=256)
y = transformer_block(params_tx, x, head_dim, n_heads)
print('Transformer out:', y.shape)
```

---

## 14) Optimization with Optax (adamw, schedules, clipping)

```python
try:
    import optax

    def make_optimizer(lr=1e-3, wd=1e-4):
        return optax.chain(
            optax.clip_by_global_norm(1.0),
            optax.adamw(learning_rate=lr, weight_decay=wd)
        )

    @dataclass
    class TrainState:
        params: Any
        opt_state: Any

    def init_train_state(params, opt):
        return TrainState(params=params, opt_state=opt.init(params))

    def apply_updates(params, updates):
        return optax.apply_updates(params, updates)

    @jit
    def train_step_optax(state, x, y, opt):
        loss, grads = value_and_grad(binary_loss)(state.params, x, y)
        updates, new_opt_state = opt.update(grads, state.opt_state, state.params)
        new_params = apply_updates(state.params, updates)
        return TrainState(new_params, new_opt_state), loss

    # Demo on MLP data
    key = random.PRNGKey(0)
    params0 = init_mlp(key, [2, 64, 64, 1])
    opt = make_optimizer(1e-2, 1e-4)
    state = init_train_state(params0, opt)
    for step in range(200):
        state, l = train_step_optax(state, X, y, opt)
    print('Optax demo loss:', float(l))

except Exception as e:
    print('Optax not available:', e)
```

---

## 15) Memory: rematerialization (remat/checkpoint), scans vs vmaps, donation

```python
# Rematerialization: recompute activations during backward to save memory
from jax import checkpoint as remat

def big_chain(x, depth=10):
    for _ in range(depth):
        x = jnp.tanh(x @ x.T)
    return x

@jit
def f_no_remat(x): return big_chain(x)

@jit
def f_with_remat(x): return remat(big_chain)(x)

Xbig = random.normal(random.PRNGKey(0), (128, 128))
_ = f_no_remat(Xbig).block_until_ready()
_ = f_with_remat(Xbig).block_until_ready()
print('Remat demo complete (profile on your hardware for real gains).')
```

---

## 16) Saving / loading params (simple & robust)

Two simple approaches:

1. **Pickle** (easy & preserves structure; fine for experiments)
2. **NumPy `.npz` with metadata** (portable arrays; stores treedef in an object field)

```python
# 1) Pickle (simple)
def tree_save_pickle(path, pytree):
    with open(path, 'wb') as f:
        pickle.dump(pytree, f)

def tree_load_pickle(path):
    with open(path, 'rb') as f:
        return pickle.load(f)

tree_save_pickle('params_demo.pkl', params)
params_loaded = tree_load_pickle('params_demo.pkl')
print('pickle load ok:', jax.tree_util.tree_structure(params) == jax.tree_util.tree_structure(params_loaded))

# 2) NPZ with treedef (object array)
from jax import tree_util

def tree_save_npz(path, pytree):
    leaves, treedef = tree_util.tree_flatten(pytree)
    save_dict = {f'leaf_{i}': np.asarray(leaf) for i, leaf in enumerate(leaves)}
    save_dict['treedef'] = np.array([treedef], dtype=object)
    np.savez(path, **save_dict)

def tree_load_npz(path):
    with np.load(path, allow_pickle=True) as data:
        treedef = data['treedef'][0]
        leaf_keys = sorted([k for k in data.files if k.startswith('leaf_')],
                           key=lambda s: int(s.split('_')[1]))
        leaves = [data[k] for k in leaf_keys]
    return tree_util.tree_unflatten(treedef, leaves)

tree_save_npz('params_demo.npz', params)
params2 = tree_load_npz('params_demo.npz')
print('npz load ok:', jax.tree_util.tree_structure(params) == jax.tree_util.tree_structure(params2))
```

---

## 17) Debugging & profiling

* `jax.debug.print` works inside `jit`.
* `JAX_DEBUG_NANS=1` + `jax_enable_x64=True` for numerics hunts (slower).
* Benchmark with warmup + `.block_until_ready()`; prefer profilers for full runs.
* Keep input shapes/dtypes consistent to avoid retracing.

```python
@jit
def foo(x):
    jax.debug.print('Inside jit, mean(x) = {}', jnp.mean(x))
    return x + 1

_ = foo(jnp.arange(4.))
```

---

## 18) Interop: jax.numpy ↔ numpy, device transfers, jax2tf (note)

```python
arr = jnp.arange(5.)
host_arr = np.array(arr)  # transfers to host, becomes NumPy array
print(type(host_arr), host_arr)

# Explicit device <-> host
on_device = jax.device_put(host_arr)
back_to_host = jax.device_get(on_device)
print(type(back_to_host), back_to_host)
```

---

## 19) (Optional) Quick Flax/Linen taste: a tiny MLP

If you installed `flax`, here’s a minimal pattern.

```python
try:
    import flax.linen as nn
    import optax
    from flax.training import train_state

    class MLP(nn.Module):
        hidden: int = 64
        @nn.compact
        def __call__(self, x):
            x = nn.tanh(nn.Dense(self.hidden)(x))
            x = nn.tanh(nn.Dense(self.hidden)(x))
            x = nn.Dense(1)(x)
            return x.squeeze(-1)

    def loss_fn(params, model, x, y):
        logits = model.apply({'params': params}, x)
        return jnp.mean(jnp.logaddexp(0.0, -y * logits))

    class State(train_state.TrainState): pass

    model = MLP(64)
    key = random.PRNGKey(0)
    params_flax = model.init(key, X)['params']
    tx = optax.adamw(1e-2)
    state = State.create(apply_fn=model.apply, params=params_flax, tx=tx)

    @jit
    def train_step_flax(state, x, y):
        loss, grads = value_and_grad(loss_fn)(state.params, model, x, y)
        state = state.apply_gradients(grads=grads)
        return state, loss

    for step in range(200):
        state, l = train_step_flax(state, X, y)
    print('Flax demo loss:', float(l))
except Exception as e:
    print('Flax not available or import failed:', e)
```

---

## 20) (Optional) Inspect compilation cost & caching

You’ll often see “first call slow, rest fast”. This shows warmup vs. steady-state.

```python
@jit
def heavy_matmul(a, b):
    return a @ b + jnp.tanh(a @ b)

A = random.normal(random.PRNGKey(0), (2048, 2048), dtype=jnp.float32)
B = random.normal(random.PRNGKey(1), (2048, 2048), dtype=jnp.float32)

t0 = time.time(); _ = heavy_matmul(A, B).block_until_ready(); t1 = time.time()
t2 = time.time(); _ = heavy_matmul(A, B).block_until_ready(); t3 = time.time()
print(f'compile+run: {(t1-t0):.3f}s, run-only: {(t3-t2):.3f}s')
```

---

## 21) (Optional) Mixed precision quick pattern

Keep “master” weights in fp32; cast activations/weights on hot ops if desired.

```python
def mixed_precision_mlp_apply(params, x):
    # Example casting: compute in bfloat16 (works well on TPU; on GPUs prefer float16/TF32)
    x16 = x.astype(jnp.bfloat16)
    outs = []
    for (W, b) in params[:-1]:
        y = (x16 @ W.astype(jnp.bfloat16) + b.astype(jnp.bfloat16))
        x16 = jnp.tanh(y)
    W, b = params[-1]
    logits16 = x16 @ W.astype(jnp.bfloat16) + b.astype(jnp.bfloat16)
    return logits16.astype(jnp.float32)  # back to fp32 at boundaries
```

---

## 22) (Optional) Practical batching + `vmap` + `jit`

```python
def dataloader(X, y, batch=128, shuffle=True, seed=0):
    n = X.shape[0]
    idx = np.arange(n)
    rng = np.random.default_rng(seed)
    while True:
        if shuffle:
            rng.shuffle(idx)
        for i in range(0, n, batch):
            j = idx[i:i+batch]
            yield X[j], y[j]

@jit
def step(params, xb, yb, lr=1e-2):
    loss, grads = value_and_grad(binary_loss)(params, xb, yb)
    params = jax.tree_map(lambda p, g: p - lr * g, params, grads)
    return params, loss

gen = dataloader(X, y, batch=128, shuffle=True, seed=0)
params_b = init_mlp(random.PRNGKey(0), [2, 64, 64, 1])
for step_i in range(50):
    xb, yb = next(gen)
    params_b, l = step(params_b, xb, yb)
print('Minibatch training loss (last):', float(l))
```

---

## 23) (Optional) Common gotchas (quick checklist)

* Py control flow inside `jit` that depends on data → use `lax.cond/scan/while_loop`.
* Inconsistent shapes/dtypes across calls → retracing & recompilation.
* PRNG keys reused accidentally → non-random behavior.
* Mutating globals inside `jit` → doesn’t work; think functional.
* Timing without `.block_until_ready()` → misleading.

---

## 24) Cheat sheet

* **Core transforms**: `jit`, `grad`, `value_and_grad`, `vmap`, `pmap`, `pjit`
* **Autodiff**: `jvp`, `vjp`, `jacfwd`, `jacrev`, `custom_vjp`, `custom_jvp`
* **Control flow**: `lax.cond`, `lax.scan`, `lax.while_loop`
* **Performance**: donation (`donate_argnums`), remat (`jax.checkpoint`), `.block_until_ready()`, `lower(...).as_text()`
* **Random**: `PRNGKey`, `split`, `fold_in`
* **Pytrees**: `tree_map`, `tree_flatten`, `tree_unflatten`
* **Sharding**: `Mesh`, `PartitionSpec`, `NamedSharding`, `pjit`
* **Numerics**: stable softmax/logsumexp, gradient clipping, mixed precision
* **IO**: pickle (simple), npz (portable), Orbax (large jobs; if installed)

---

that’s the full scaffold. paste into a notebook cell-by-cell (top to bottom) and you’ll have a working playground. want me to extend any part (e.g., full pjit example across 2+ GPUs, dataset loaders, LR schedulers, gradient checkpointed Transformer stack, Orbax checkpointing, or TPU tips)?


In [None]:
sweet — here’s a clean, from-scratch **JAX CNN on MNIST** you can paste into your notebook. it loads MNIST (tfds first; falls back to Keras/Torch if needed), builds a small ConvNet with pure JAX/LAX, and trains/evaluates it. everything’s commented so you can tweak fast.

---

# CNN on MNIST (pure JAX/LAX, from scratch)

### 0) Installs (as needed) + imports

> If you don’t have any of these, uncomment the appropriate `%pip` lines and run once.

```python
# --- optional installs (pick one data source you prefer) ---
# %pip install -U tensorflow-datasets
# %pip install -U tensorflow-cpu          # only if you plan to use keras.datasets.mnist
# %pip install -U torch torchvision       # only if you plan to use torchvision

# --- JAX/plotting (if you skipped them earlier) ---
# %pip install -U "jax[cpu]"              # or the right cuda wheel for your GPU
# %pip install -U matplotlib optax        # optax optional (we fallback to SGD)

import time, math, pickle, os
from dataclasses import dataclass
from typing import Any, Tuple, Dict, Iterable, Optional

import numpy as np
import jax
import jax.numpy as jnp
import jax.nn as jnn
from jax import jit, value_and_grad, random, lax

import matplotlib.pyplot as plt

print("JAX:", jax.__version__, "backend:", jax.default_backend(), "devices:", jax.devices())
```

---

### 1) MNIST loader (tfds → keras → torchvision fallback)

```python
def load_mnist() -> Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]:
    """Returns (x_train, y_train, x_test, y_test) with shapes:
       - x_*: (N, 28, 28, 1) float32 in [0, 1]
       - y_*: (N,) int64 labels 0..9
    Tries tensorflow_datasets first, then keras, then torchvision.
    """
    # 1) tensorflow_datasets
    try:
        import tensorflow_datasets as tfds
        (x_train, y_train) = tfds.as_numpy(
            tfds.load("mnist", split="train", batch_size=-1, as_supervised=True)
        )
        (x_test, y_test) = tfds.as_numpy(
            tfds.load("mnist", split="test", batch_size=-1, as_supervised=True)
        )
        x_train = x_train.astype(np.float32) / 255.0
        x_test  = x_test.astype(np.float32) / 255.0
        if x_train.ndim == 3:
            x_train = x_train[..., None]
            x_test  = x_test[..., None]
        return x_train, y_train.astype(np.int64), x_test, y_test.astype(np.int64)
    except Exception as e:
        print("[load_mnist] tfds path failed:", e)

    # 2) keras.datasets.mnist
    try:
        from tensorflow.keras.datasets import mnist
        (x_train, y_train), (x_test, y_test) = mnist.load_data()
        x_train = (x_train.astype(np.float32) / 255.0)[..., None]
        x_test  = (x_test.astype(np.float32) / 255.0)[..., None]
        return x_train, y_train.astype(np.int64), x_test, y_test.astype(np.int64)
    except Exception as e:
        print("[load_mnist] keras path failed:", e)

    # 3) torchvision
    try:
        from torchvision import datasets, transforms
        import torch
        tfm = transforms.Compose([transforms.ToTensor()])  # returns (C,H,W) in [0,1]
        train_ds = datasets.MNIST(root="./data", train=True, transform=tfm, download=True)
        test_ds  = datasets.MNIST(root="./data", train=False, transform=tfm, download=True)
        def _to_numpy(ds):
            xs = []
            ys = []
            for img, y in ds:
                arr = img.numpy().transpose(1,2,0)  # (H,W,C)
                xs.append(arr)
                ys.append(int(y))
            return np.stack(xs, 0).astype(np.float32), np.array(ys, dtype=np.int64)
        x_train, y_train = _to_numpy(train_ds)
        x_test, y_test   = _to_numpy(test_ds)
        return x_train, y_train, x_test, y_test
    except Exception as e:
        print("[load_mnist] torchvision path failed:", e)
        raise RuntimeError("Could not load MNIST with tfds/keras/torchvision. Install one of them and retry.")

x_train, y_train, x_test, y_test = load_mnist()
print("Train:", x_train.shape, y_train.shape, "| Test:", x_test.shape, y_test.shape)
```

---

### 2) Dataloader (numpy → mini-batches)

```python
def dataloader(X: np.ndarray, y: np.ndarray, batch_size: int, shuffle: bool=True, seed: int=0) -> Iterable[Tuple[np.ndarray, np.ndarray]]:
    n = X.shape[0]
    idx = np.arange(n)
    rng = np.random.default_rng(seed)
    while True:
        if shuffle:
            rng.shuffle(idx)
        for i in range(0, n, batch_size):
            j = idx[i:i+batch_size]
            yield X[j], y[j]
```

---

### 3) Model: ConvNet with max-pooling → dense → logits (10 classes)

* Conv(3×3, 32) → ReLU → MaxPool(2×2, s=2)
* Conv(3×3, 64) → ReLU → MaxPool(2×2, s=2)
* Flatten → Dense(128) → ReLU → Dense(10)

Weights are He-initialized (good for ReLU).

```python
def he_init_std(fan_in: int) -> float:
    return math.sqrt(2.0 / fan_in)

def init_cnn_params(key: jax.Array):
    k1, k2, k3, k4 = random.split(key, 4)
    # Conv kernels: (KH, KW, Cin, Cout)
    W1 = random.normal(k1, (3, 3, 1, 32)) * he_init_std(3*3*1)
    b1 = jnp.zeros((32,))
    W2 = random.normal(k2, (3, 3, 32, 64)) * he_init_std(3*3*32)
    b2 = jnp.zeros((64,))
    # After two 2x2 pools: 28 -> 14 -> 7, channels = 64
    flat_dim = 7*7*64
    W3 = random.normal(k3, (flat_dim, 128)) * he_init_std(flat_dim)
    b3 = jnp.zeros((128,))
    W4 = random.normal(k4, (128, 10)) * (1.0 / math.sqrt(128.0))
    b4 = jnp.zeros((10,))
    return (W1,b1,W2,b2,W3,b3,W4,b4)

def max_pool_2x2(x: jax.Array) -> jax.Array:
    # x: (N,H,W,C)
    return lax.reduce_window(
        x,
        -jnp.inf,
        lax.max,
        window_dimensions=(1,2,2,1),
        window_strides=(1,2,2,1),
        padding="VALID"
    )

def cnn_apply(params, x):
    W1,b1,W2,b2,W3,b3,W4,b4 = params
    # Conv1 + ReLU + Pool
    y = lax.conv_general_dilated(x, W1, window_strides=(1,1), padding="SAME")
    y = jnn.relu(y + b1)
    y = max_pool_2x2(y)
    # Conv2 + ReLU + Pool
    y = lax.conv_general_dilated(y, W2, window_strides=(1,1), padding="SAME")
    y = jnn.relu(y + b2)
    y = max_pool_2x2(y)
    # Flatten
    y = y.reshape((y.shape[0], -1))
    # Dense -> ReLU -> Dense
    y = jnn.relu(y @ W3 + b3)
    logits = y @ W4 + b4
    return logits

def loss_and_metrics(params, xb, yb):
    logits = cnn_apply(params, xb)
    logp = jnn.log_softmax(logits, axis=-1)
    nll = -jnp.take_along_axis(logp, yb[:, None], axis=1).mean()
    preds = jnp.argmax(logits, axis=-1)
    acc = (preds == yb).mean()
    return nll, acc
```

---

### 4) Optimizer (Optax AdamW if available; otherwise plain SGD)

```python
try:
    import optax
    use_optax = True
    opt = optax.adamw(learning_rate=1e-3, weight_decay=1e-4)
    @dataclass
    class TrainState:
        params: Any
        opt_state: Any
    def init_state(params):
        return TrainState(params=params, opt_state=opt.init(params))
    @jit
    def train_step(state: "TrainState", xb: jax.Array, yb: jax.Array):
        (loss, acc), grads = value_and_grad(loss_and_metrics, has_aux=True)(state.params, xb, yb)
        updates, new_opt_state = opt.update(grads, state.opt_state, state.params)
        new_params = optax.apply_updates(state.params, updates)
        return TrainState(new_params, new_opt_state), loss, acc
except Exception as e:
    print("[optimizer] Optax not available, falling back to SGD. Reason:", e)
    use_optax = False
    lr = 1e-2
    @dataclass
    class TrainState:
        params: Any
    def init_state(params):
        return TrainState(params=params)
    @jit
    def train_step(state: "TrainState", xb: jax.Array, yb: jax.Array):
        (loss, acc), grads = value_and_grad(loss_and_metrics, has_aux=True)(state.params, xb, yb)
        new_params = jax.tree_map(lambda p, g: p - lr * g, state.params, grads)
        return TrainState(new_params), loss, acc
```

---

### 5) Evaluation helpers

```python
@jit
def eval_batch(params, xb, yb):
    return loss_and_metrics(params, xb, yb)

def evaluate(params, X, y, batch_size=1024):
    n = X.shape[0]
    losses = []
    accs = []
    for i in range(0, n, batch_size):
        xb = jnp.array(X[i:i+batch_size])
        yb = jnp.array(y[i:i+batch_size])
        l, a = eval_batch(params, xb, yb)
        losses.append(float(l))
        accs.append(float(a))
    return np.mean(losses), np.mean(accs)
```

---

### 6) Train loop

```python
# Hyperparams
batch_size = 128
epochs = 5
seed = 0

# Initialize params
key = random.PRNGKey(seed)
params = init_cnn_params(key)
state = init_state(params)

# Build dataloader
train_iter = dataloader(x_train, y_train, batch_size=batch_size, shuffle=True, seed=seed)

# Training
steps_per_epoch = math.ceil(x_train.shape[0] / batch_size)
hist = {"train_loss": [], "train_acc": [], "test_loss": [], "test_acc": []}

t0 = time.time()
for epoch in range(1, epochs+1):
    epoch_losses, epoch_accs = [], []
    for _ in range(steps_per_epoch):
        xb_np, yb_np = next(train_iter)
        xb = jnp.array(xb_np)  # moved to device automatically
        yb = jnp.array(yb_np)
        state, loss, acc = train_step(state, xb, yb)
        epoch_losses.append(float(loss))
        epoch_accs.append(float(acc))
    tr_loss, tr_acc = np.mean(epoch_losses), np.mean(epoch_accs)
    te_loss, te_acc = evaluate(state.params if hasattr(state, "params") else state, x_test, y_test)
    hist["train_loss"].append(tr_loss)
    hist["train_acc"].append(tr_acc)
    hist["test_loss"].append(te_loss)
    hist["test_acc"].append(te_acc)
    print(f"[epoch {epoch:02d}] train loss {tr_loss:.4f} acc {tr_acc*100:.2f}% | "
          f"test loss {te_loss:.4f} acc {te_acc*100:.2f}%")
t1 = time.time()
print(f"Total time: {t1 - t0:.2f}s")
```

---

### 7) Quick curves + sample predictions (optional)

```python
plt.figure(); plt.plot(hist["train_loss"], label="train"); plt.plot(hist["test_loss"], label="test"); plt.title("Loss"); plt.legend(); plt.show()
plt.figure(); plt.plot(hist["train_acc"], label="train"); plt.plot(hist["test_acc"], label="test"); plt.title("Accuracy"); plt.legend(); plt.show()

# Show a few predictions
def predict(params, x):
    logits = cnn_apply(params, x)
    return jnp.argmax(logits, axis=-1)

idx = np.random.default_rng(0).integers(0, x_test.shape[0], size=6)
imgs = jnp.array(x_test[idx])
preds = predict(state.params if hasattr(state, "params") else state, imgs)

plt.figure(figsize=(10,3))
for i, k in enumerate(idx):
    plt.subplot(1,6,i+1); plt.imshow(x_test[k, ..., 0], cmap="gray")
    plt.axis("off")
    plt.title(f"y={y_test[k]}\nŷ={int(preds[i])}")
plt.tight_layout(); plt.show()
```

---

### Notes & tips

* This should reach \~99% test acc in a handful of epochs on CPU/GPU (exact numbers vary).
* For faster convergence, bump epochs to 10–15, switch to AdamW (Optax path), or add a learning-rate schedule.
* To go bigger: add dropout, batch norm, or a third conv block; swap ReLU→GELU; try data augmentation (random shifts/crops).
* If you hit GPU OOM, reduce batch size.
* If you want a Flax/Linen version (module definitions, state handling, jit/pjit-ready), say the word and I’ll drop it in.
