In [1]:
!pip install jax jaxlib



In [2]:
import jax
import jax.numpy as jnp
from jax import grad, jit, vmap, make_jaxpr
import timeit

In [3]:
def f(x1, x2):
    return jnp.log(x1) + x1 * x2 - jnp.sin(x2)

dy_dx1 = grad(f, argnums=0)
dy_dx2 = grad(f, argnums=1)

In [4]:
x1, x2 = 2.0, 5.0
print("f(x1, x2)  =", f(x1, x2))
print("∂f/∂x1     =", dy_dx1(x1, x2))
print("∂f/∂x2     =", dy_dx2(x1, x2))

f(x1, x2)  = 11.652072
∂f/∂x1     = 5.5
∂f/∂x2     = 1.7163378


In [5]:

print("JAXPR for ∂f/∂x1:")
print(make_jaxpr(dy_dx1)(x1, x2))

print("\nJAXPR for ∂f/∂x2:")
print(make_jaxpr(dy_dx2)(x1, x2))

JAXPR for ∂f/∂x1:
{ lambda ; a:f32[] b:f32[]. let
    c:f32[] = log a
    d:f32[] = mul a b
    e:f32[] = add c d
    f:f32[] = sin b
    _:f32[] = sub e f
    g:f32[] = mul 1.0 b
    h:f32[] = div 1.0 a
    i:f32[] = add_any g h
  in (i,) }

JAXPR for ∂f/∂x2:
{ lambda ; a:f32[] b:f32[]. let
    c:f32[] = log a
    d:f32[] = mul a b
    e:f32[] = add c d
    f:f32[] = sin b
    g:f32[] = cos b
    _:f32[] = sub e f
    h:f32[] = neg 1.0
    i:f32[] = mul h g
    j:f32[] = mul a 1.0
    k:f32[] = add_any i j
  in (k,) }


In [12]:

g1 = lambda x1, x2: (jit(f)(x1, x2), jit(dy_dx1)(x1, x2), jit(dy_dx2)(x1, x2))
g2 = jit(lambda x1, x2: (f(x1, x2), dy_dx1(x1, x2), dy_dx2(x1, x2)))

print("g1 timing (1000 runs):")
print(timeit.timeit(lambda: g1(2.0, 5.0), number=1000), "sec")

print("g2 timing (1000 runs):")
print(timeit.timeit(lambda: g2(2.0, 5.0), number=1000), "sec")

g1 timing (1000 runs):
0.42857839599992076 sec
g2 timing (1000 runs):
0.07183710800006793 sec


In [13]:
x1s = jnp.linspace(1.0, 10.0, 1000)
x2s = x1s + 1

# a) Batch both
batch1 = vmap(g2, in_axes=(0, 0))
out1 = batch1(x1s, x2s)

# b) Batch x1 only, fix x2
batch2 = vmap(g2, in_axes=(0, None))
out2 = batch2(x1s, 0.5)

print("First 5 results (batch1):", out1[0][:5])
print("First 5 results (batch2):", out2[0][:5])

First 5 results (batch1): [1.0907025 1.1305652 1.1705842 1.2107604 1.2510945]
First 5 results (batch2): [0.02057445 0.0340476  0.04744107 0.06075621 0.07399434]
