# Introduction
Wanted a safe space where I could play with JAX.

# Imports

In [2]:
import jax

In [3]:
import jax
import jax.numpy as jnp

# Arrays

In [4]:
x = jnp.arange(5)
isinstance(x, jax.Array)

True

In [5]:
x.devices()
x.sharding

SingleDeviceSharding(device=CpuDevice(id=0), memory_kind=unpinned_host)

In [6]:
def selu(x, alpha=1.67, lambda_=1.05):
    return lambda_ * jnp.where(x>0, x, alpha * jnp.exp(x) - alpha)

selu_jit = jax.jit(selu)
print(selu_jit(1.0))

1.05


In [7]:
@jax.jit
def f(x):
    print(x)
    return x + 1

x = jnp.arange(4)
result = f(x)

Traced<ShapedArray(int32[4])>with<DynamicJaxprTrace(level=1/0)>


In [8]:
x = jnp.arange(5.0)
jax.make_jaxpr(selu)(x)

{ lambda ; a:f32[5]. let
    b:bool[5] = gt a 0.0
    c:f32[5] = exp a
    d:f32[5] = mul 1.6699999570846558 c
    e:f32[5] = sub d 1.6699999570846558
    f:f32[5] = pjit[
      name=_where
      jaxpr={ lambda ; g:bool[5] h:f32[5] i:f32[5]. let
          j:f32[5] = select_n g i h
        in (j,) }
    ] b a e
    k:f32[5] = mul 1.0499999523162842 f
  in (k,) }

In [9]:
params = [1, 2, (jnp.arange(3), jnp.ones(2))]

print(jax.tree.structure(params))
print(jax.tree.leaves(params))

PyTreeDef([*, *, (*, *)])
[1, 2, Array([0, 1, 2], dtype=int32), Array([1., 1.], dtype=float32)]


In [10]:
# Dictionary of parameters
params = {'n': 5, 'W': jnp.ones((2, 2)), 'b': jnp.zeros(2)}

print(jax.tree.structure(params))
print(jax.tree.leaves(params))

PyTreeDef({'W': *, 'b': *, 'n': *})
[Array([[1., 1.],
       [1., 1.]], dtype=float32), Array([0., 0.], dtype=float32), 5]


In [11]:
# Named tuple of parameters
from typing import NamedTuple

class Params(NamedTuple):
  a: int
  b: float

params = Params(1, 5.0)
print(jax.tree.structure(params))
print(jax.tree.leaves(params))

PyTreeDef(CustomNode(namedtuple[Params], [*, *]))
[1, 5.0]


# Quickstart

## `jax.jit()`

In [17]:
from jax import random

key = random.key(1701)
x = random.normal(key, (1_000_000,))
%timeit selu(x).block_until_ready()

5.76 ms ± 1.07 ms per loop (mean ± std. dev. of 7 runs, 100 loops each)


In [23]:
from jax import jit

selu_jit = jit(selu)
_ = selu_jit(x) # compiles on first call
%timeit selu_jit(x).block_until_ready()

835 µs ± 63 µs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)


## `jax.grad()`

In [15]:
from jax import grad

def sum_logistic(x):
    return jnp.sum(1.0 / (1.0 + jnp.exp(-x)))

x_small = jnp.arange(3.)
derivative_fn = grad(sum_logistic)
print(derivative_fn(x_small))

[0.25       0.19661197 0.10499357]


In [14]:
def first_finite_differences(f, x, eps=1E-3):
    return jnp.array([(f(x + eps * v) - f(x - eps * v)) / (2 * eps)
                      for v in jnp.eye(len(x))])

print(first_finite_differences(sum_logistic, x_small))

[0.24998187 0.1965761  0.10502338]


Thoughts:
- interesting that grad wraps around *functions*
- would be interesting to have a better idea of what is going on under the hood

In [16]:
from jax import jacobian
print(jacobian(jnp.exp)(x_small))

[[1.        0.        0.       ]
 [0.        2.7182817 0.       ]
 [0.        0.        7.389056 ]]


Thoughts: seems like magic that this can adapt to the tensor dimension

In [24]:
from jax import jacfwd, jacrev
def hessian(fun):
    return jit(jacfwd(jacrev(fun)))
print(hessian(sum_logistic)(x_small))

[[-0.         -0.         -0.        ]
 [-0.         -0.09085776 -0.        ]
 [-0.         -0.         -0.07996249]]


In [28]:
# Lennie testing: this does also work with 'standard container'
def list_fn(myl: list):
    return jnp.exp(myl[0]) + jnp.exp(myl[1])

print(jacobian(list_fn)([jnp.arange(4.0), jnp.arange(2,6.0)]))

[Array([[ 1.       ,  0.       ,  0.       ,  0.       ],
       [ 0.       ,  2.7182817,  0.       ,  0.       ],
       [ 0.       ,  0.       ,  7.389056 ,  0.       ],
       [ 0.       ,  0.       ,  0.       , 20.085537 ]], dtype=float32), Array([[  7.389056,   0.      ,   0.      ,   0.      ],
       [  0.      ,  20.085537,   0.      ,   0.      ],
       [  0.      ,   0.      ,  54.598152,   0.      ],
       [  0.      ,   0.      ,   0.      , 148.41316 ]], dtype=float32)]


In [29]:
# working out what vjp does...
def f(x, y):
    return jnp.sin(x), jnp.cos(y)

primals_out, f_vjp = jax.vjp(f, 0.5, 1.0)
# evalute the function f at *primals and return value and also a vjp function
# this sends cotangent vector with shape of primals_out to tuple of cotangent vectors with shape of primals
# i.e. think of deriv of f: R^n -> R^m not as matrix of shape (m,n) but as a function R^m to a linear map R^n -> R

In [30]:
# to check previous thinking
def f(x, y, z):
    return jnp.sin(x), y + z

primals_out, f_vjp = jax.vjp(f, 0.5, 1.0, 2.0)


Notes:
- when print `f_vjp` it returns a Partial object
- subsections of this printed string include jaxprs (see above)

In [34]:
print(primals_out)
print(f_vjp)

(Array(0.47942555, dtype=float32, weak_type=True), Array(3., dtype=float32, weak_type=True))
Partial(_HashableCallableShim(functools.partial(<function _vjp_pullback_wrapper at 0x7f5d5de67eb0>, 'f', [ShapedArray(float32[], weak_type=True), ShapedArray(float32[], weak_type=True)], (PyTreeDef((*, *)), PyTreeDef((*, *, *))))), Partial(_HashableCallableShim(functools.partial(<function vjp.<locals>.unbound_vjp at 0x7f5d5cb293f0>, [(ShapedArray(float32[], weak_type=True), None), (ShapedArray(float32[], weak_type=True), None)], { lambda a:f32[]; b:f32[] c:f32[] d:f32[]. let
    e:f32[] = pjit[
      name=sin
      jaxpr={ lambda ; f:f32[] g:f32[]. let h:f32[] = mul f g in (h,) }
    ] b a
    i:f32[] = pjit[
      name=_add
      jaxpr={ lambda ; j:f32[] k:f32[]. let l:f32[] = add j k in (l,) }
    ] c d
  in (e, i) })), (Array(0.87758255, dtype=float32, weak_type=True),)))


In [36]:
f_vjp((1.0, 2.0))

(Array(0.87758255, dtype=float32, weak_type=True),
 Array(2., dtype=float32, weak_type=True),
 Array(2., dtype=float32, weak_type=True))

Saying: 'what is the derivative in this direction?'  

Question to self: why might that be useful?  
Think this is because it is what you need for backprop, and saves you having to compute a full Jacobian matrix (which would be big)  
I guess this helps reduce the amount of information that needs to be passed around

Resource: really nice explanation in https://jax.readthedocs.io/en/latest/notebooks/autodiff_cookbook.html

Optional extension: try out more of those ways to understand what is going on.

## `jax.vmap()`
Auto-vectorisation

In [37]:
random

<module 'jax.random' from '/store/DPMMS/ww347/ML_reproductions/jax_ecosystem/local_conda_env_jax/lib/python3.10/site-packages/jax/random.py'>

In [38]:
key1, key2 = random.split(key)
mat = random.normal(key1, (150, 100))
batched_x = random.normal(key2, (10, 100))

def apply_matrix(v):
    return jnp.dot(mat, v)

In [39]:
def naively_batched_apply_matrix(v_batched):
    return jnp.stack([apply_matrix(v) for v in v_batched])

print('Naively batched')
%timeit naively_batched_apply_matrix(batched_x).block_until_ready()

Naively batched
1.08 ms ± 188 µs per loop (mean ± std. dev. of 7 runs, 1 loop each)


In [41]:
import numpy as np

@jit
def batched_apply_matrix(v_batched):
    return jnp.dot(v_batched, mat.T)

np.testing.assert_allclose(
    naively_batched_apply_matrix(batched_x),
    batched_apply_matrix(batched_x), 
    atol=1E-6, rtol=1E-3,
)
print('Manually batched')
%timeit batched_apply_matrix(batched_x).block_until_ready()

Manually batched
23.3 µs ± 1.01 µs per loop (mean ± std. dev. of 7 runs, 10,000 loops each)


In [45]:
from jax import vmap

@jit
def vmap_batched_apply_matrix(v_batched):
    return vmap(apply_matrix)(v_batched)

np.testing.assert_allclose(
    naively_batched_apply_matrix(batched_x),
    vmap_batched_apply_matrix(batched_x), 
    atol=1E-6, rtol=1E-3,
)
print('Auto-vectorized with vmap')
%timeit vmap_batched_apply_matrix(batched_x).block_until_ready()

Auto-vectorized with vmap
32 µs ± 1.13 µs per loop (mean ± std. dev. of 7 runs, 10,000 loops each)


Note: for some reason really did need to wind down the rtol to avoid an assertion error

Comments: this does seem nice and naively brings down execution times by a factor of 20.

# Just-in-time compilation
https://jax.readthedocs.io/en/latest/jit-compilation.html


In [46]:
global_list = []
def log2(x):
    global_list.append(x)
    ln_x = jnp.log(x)
    ln_2 = jnp.log(2.0)
    return ln_x / ln_2

print(jax.make_jaxpr(log2)(3.0))

{ lambda ; a:f32[]. let
    b:f32[] = log a
    c:f32[] = log 2.0
    d:f32[] = div b c
  in (d,) }


Note: nothing about side-effect.  
Feature not bug!

'When tracing, JAX wraps each argument by a tracer object. These tracers then record all JAX operations performed on them during the function call (which happens in regular Python). Then, JAX uses the tracer records to reconstruct the entire function. The output of that reconstruction is the jaxpr. Since the tracers do not record the Python side-effects, they do not appear in the jaxpr. However, the side-effects still happen during the trace itself.'

In [56]:
def log2_with_print(x):
    print("printed x:", x)
    ln_x = jnp.log(x)
    ln_2 = jnp.log(2.0)
    return ln_x / ln_2

jaxpr = jax.make_jaxpr(log2_with_print)(3.0)
print('\njaxpr:', jaxpr)


printed x: Traced<ShapedArray(float32[], weak_type=True)>with<DynamicJaxprTrace(level=1/0)>

jaxpr: { lambda ; a:f32[]. let
    b:f32[] = log a
    c:f32[] = log 2.0
    d:f32[] = div b c
  in (d,) }


so printing happens when the jaxpr is evaluated; but doesn't show up in the jaxpr

'A key thing to understand is that a jaxpr captures the function as executed on the parameters given to it. For example, if we have a Python conditional, the jaxpr will only know about the branch we take:'

In [57]:
def log2_if_rank_2(x):
    if x.ndim == 2:
        ln_x = jnp.log(x)
        ln_2 = jnp.log(2.0)
        return ln_x / ln_2
    else:
        return x

print('Rank 1 example')
print(jax.make_jaxpr(log2_if_rank_2)(jnp.array([3.0, 4.0])))

print('Rank 2 example')
print(jax.make_jaxpr(log2_if_rank_2)(jnp.array([[3.0, 4.0], [5.0, 6.0]])))

Rank 1 example
{ lambda ; a:f32[2]. let  in (a,) }
Rank 2 example
{ lambda ; a:f32[2,2]. let
    b:f32[2,2] = log a
    c:f32[] = log 2.0
    d:f32[] = convert_element_type[new_dtype=float32 weak_type=False] c
    e:f32[2,2] = div b d
  in (e,) }


## JIT compiling a function
Just repeated from jit section of quickstart

## Why can't we just JIT everything?

In [58]:
# Condition on value of x.

def f(x):
    if x > 0:
        return x
    else:
        return 2 * x
    
jax.jit(f)(10)

TracerBoolConversionError: Attempted boolean conversion of traced array with shape bool[].
The error occurred while tracing the function f at /tmp/ipykernel_2458417/2494647410.py:3 for jit. This concrete value was not available in Python because it depends on the value of the argument x.
See https://jax.readthedocs.io/en/latest/errors.html#jax.errors.TracerBoolConversionError

In [59]:
# While loop conditioned on x and n
def g(x, n):
    i = 0
    while i < n:
        i += 1
    return x + i

jax.jit(g)(10, 20)

TracerBoolConversionError: Attempted boolean conversion of traced array with shape bool[].
The error occurred while tracing the function g at /tmp/ipykernel_2458417/3570161443.py:2 for jit. This concrete value was not available in Python because it depends on the value of the argument n.
See https://jax.readthedocs.io/en/latest/errors.html#jax.errors.TracerBoolConversionError

Options:
- Can rewrite code to avoid conditionals on value
- Use [Control flow operators](https://jax.readthedocs.io/en/latest/jax.lax.html#lax-control-flow) like jax.lax.cond(). 

In [75]:
# While loop conditioned on x and n with a jitted body.

@jax.jit
def loop_body(prev_i):
    return prev_i + 1

def g_inner_jitted(x, n):
    i = 0
    while i < n:
        i = loop_body(i)
    return x + i

g_inner_jitted(10, 20)

Array(30, dtype=int32, weak_type=True)

In [76]:
# Exercise: implement same function but using jax.lax.cond
from jax.lax import while_loop

# my version
def g_lax(x, n):
    i = 0
    while_loop(cond_fun=lambda i: i < n,
               body_fun=lambda i: i + 1,
               init_val=i)
    return x + i

jax.jit(g_lax)(10, 20)

Array(10, dtype=int32, weak_type=True)

In [77]:
def g_lax_gpt4(x, n):
    def body_fun(i):
        return i + 1
    def cond_fun(i):
        return i < n
    i = jax.lax.while_loop(cond_fun, body_fun, 0)
    return x + i

Before adding in the block until ready I got
```
Jitted gpt4 version
10.9 µs ± 508 ns per loop (mean ± std. dev. of 7 runs, 100,000 loops each)
Jitted version
7.28 µs ± 700 ns per loop (mean ± std. dev. of 7 runs, 100,000 loops each)
Original version
652 ns ± 72.7 ns per loop (mean ± std. dev. of 7 runs, 1,000,000 loops each)
Inner jitted version
632 µs ± 9.04 µs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)
```

In [79]:
def g(x, n):
    i = 0
    while i < n:
        i += 1
    return x + i

In [81]:
gitted = jax.jit(g_lax)
gitted(10, 20)

gitted_gpt4 = jax.jit(g_lax_gpt4)
gitted_gpt4(10, 20)

print('Jitted gpt4 version')
%timeit gitted_gpt4(10, 20).block_until_ready()

print('Jitted version')
%timeit gitted(10, 20).block_until_ready()

print('Original version')
%timeit g(10, 20)

print('Inner jitted version')
%timeit g_inner_jitted(10, 20).block_until_ready()

Jitted gpt4 version
12.1 µs ± 1.08 µs per loop (mean ± std. dev. of 7 runs, 100,000 loops each)
Jitted version
7.86 µs ± 128 ns per loop (mean ± std. dev. of 7 runs, 100,000 loops each)
Original version
575 ns ± 8.62 ns per loop (mean ± std. dev. of 7 runs, 1,000,000 loops each)
Inner jitted version
642 µs ± 2.19 µs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)


Note: this is somewhat surprising: the original version is far faster than the other options!

Explanation from GPT4: the jit compilation happens

In [None]:
# While loop conditioned on x and n
def g(x, n):
    i = 0
    while i < n:
        i += 1
    return x + i

# While loop conditioned on x and n with a jitted body.
@jax.jit
def loop_body(prev_i):
    return prev_i + 1

def g_inner_jitted(x, n):
    i = 0
    while i < n:
        i = loop_body(i)
    return x + i

# My idea: use jax.lax
from jax.lax import while_loop

def g_lax(x, n):
    i = 0
    while_loop(cond_fun=lambda i: i < n,
               body_fun=lambda i: i + 1,
               init_val=i)
    return x + i

gitted = jax.jit(g_lax)
gitted(10, 20)

print('Jitted version')
%timeit gitted(10, 20)

print('Original version')
%timeit g(10, 20)

print('Inner jitted version')
%timeit g_inner_jitted(10, 20)

In [73]:
import jax
import jax.numpy as jnp
from jax import jit
from jax.lax import while_loop

# Initial matrix A and power n
A = jnp.array(jax.random.normal(jax.random.PRNGKey(0), (20, 20)))
power = 20

# Original version using a simple loop
def matrix_power(A, n):
    result = A
    for _ in range(1, n):
        result = result @ A
    return result

# Inner jitted version with jitted multiplication
@jit
def multiply_matrices(B, A):
    return B @ A

def matrix_power_inner_jitted(A, n):
    result = A
    for _ in range(1, n):
        result = multiply_matrices(result, A)
    return result

# JAX while_loop version
def matrix_power_lax(A, n):
    def body_fun(carry):
        result, A, i = carry
        return (result @ A, A, i + 1)
    
    def cond_fun(carry):
        _, _, i = carry
        return i < n

    final_result, _, _ = while_loop(cond_fun, body_fun, (A, A, 1))
    return final_result

# Jitting the while_loop version
matrix_power_lax_jitted = jit(matrix_power_lax)

# Timing and comparing the versions
import timeit

print('Original version')
print(timeit.timeit('matrix_power(A, power)', globals=globals(), number=10000))

print('Inner jitted version')
print(timeit.timeit('matrix_power_inner_jitted(A, power)', globals=globals(), number=10000))

print('JAX while_loop jitted version')
print(timeit.timeit('matrix_power_lax_jitted(A, power)', globals=globals(), number=10000))


Original version
2.3959092549048364
Inner jitted version
2.0549116819165647
JAX while_loop jitted version
0.410714193014428


Interestingly enough there does not seem to be a massive difference here.  
But at least the ordering is now as one might expect / hope (the more complicated version is in fact better).

## Making arguments static

Can tell JAX to re-compile the function for each new value of the specified static input.  
(Good if limited set of static values). 

To do this when using jit as a decorator can use functools.partial

## JIT and caching

In [74]:
from functools import partial

def unjitted_loop_body(prev_i):
  return prev_i + 1

def g_inner_jitted_partial(x, n):
  i = 0
  while i < n:
    # Don't do this! each time the partial returns
    # a function with different hash
    i = jax.jit(partial(unjitted_loop_body))(i)
  return x + i

def g_inner_jitted_lambda(x, n):
  i = 0
  while i < n:
    # Don't do this!, lambda will also return
    # a function with a different hash
    i = jax.jit(lambda x: unjitted_loop_body(x))(i)
  return x + i

def g_inner_jitted_normal(x, n):
  i = 0
  while i < n:
    # this is OK, since JAX can find the
    # cached, compiled function
    i = jax.jit(unjitted_loop_body)(i)
  return x + i

print("jit called in a loop with partials:")
%timeit g_inner_jitted_partial(10, 20).block_until_ready()

print("jit called in a loop with lambdas:")
%timeit g_inner_jitted_lambda(10, 20).block_until_ready()

print("jit called in a loop with caching:")
%timeit g_inner_jitted_normal(10, 20).block_until_ready()

jit called in a loop with partials:
311 ms ± 18.9 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
jit called in a loop with lambdas:
290 ms ± 2.72 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
jit called in a loop with caching:
2.56 ms ± 24.8 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)


Note that this uses the block_until_ready  
Let's see if that tweaks the answer in prev code!