In [1]:
import sys

# `!which python` reports the first `python` found on the shell's PATH, which
# is not necessarily the interpreter this kernel is running on (and `which`
# is not available on every platform). `sys.executable` is the path of the
# interpreter that is actually executing this notebook.
python_executable = sys.executable
python_executable

/Users/tristantorchet/Desktop/Code/VSCode/LearningJAX/.venv/bin/python


# 1. Asynchronous Dispatch

In [2]:
import numpy as np
import jax.numpy as jnp
from jax import random
# Draw a 1000x1000 matrix of uniform samples from a fixed PRNG key.
x = random.uniform(random.key(0), shape=(1000, 1000))
# The expression below is the cell's output. Displaying it (i.e. evaluating
# `repr(result)` or `str(result)`) blocks until the value is ready.
jnp.dot(x, x) + 3.0

Array([[258.0197 , 249.64864, 257.13373, ..., 236.67952, 250.68939,
        241.36853],
       [265.65982, 256.28915, 262.18246, ..., 242.03183, 256.16754,
        252.44124],
       [262.38916, 255.72754, 261.2306 , ..., 240.83559, 255.41093,
        249.62468],
       ...,
       [259.1581 , 253.092  , 257.72174, ..., 242.23874, 250.72684,
        247.16638],
       [271.22662, 261.91205, 265.33398, ..., 248.26648, 262.0539 ,
        261.337  ],
       [257.16138, 254.7543 , 259.083  , ..., 241.59848, 248.62595,
        243.22354]], dtype=float32)

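If you run the following cells on a GPU, you will see a very short wall time (hundreds of µs) when you do not block on the result, and a much longer one (tens of ms) when you do: without blocking, `%time` only measures how long it takes to dispatch the computation, not how long it takes to finish it.

On CPU the difference is negligible, as the timings below show.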

In [3]:
# No explicit blocking: with asynchronous dispatch, this can time only the
# dispatch of the matrix product rather than the computation itself.
%time jnp.dot(x, x) 

CPU times: user 47.3 ms, sys: 8.07 ms, total: 55.4 ms
Wall time: 9.76 ms


Array([[255.01973, 246.64864, 254.13373, ..., 233.67952, 247.68939,
        238.36853],
       [262.65982, 253.28915, 259.18246, ..., 239.03183, 253.16756,
        249.44124],
       [259.38916, 252.72754, 258.2306 , ..., 237.83559, 252.41093,
        246.62468],
       ...,
       [256.1581 , 250.092  , 254.72174, ..., 239.23874, 247.72684,
        244.16638],
       [268.22662, 258.91205, 262.33398, ..., 245.26648, 259.0539 ,
        258.337  ],
       [254.16138, 251.7543 , 256.083  , ..., 238.59848, 245.62595,
        240.22354]], dtype=float32)

In [4]:
# Blocking variant 1: converting the result to a NumPy array has to wait for
# the computed values, so the timing includes the computation.
%time np.asarray(jnp.dot(x, x))  

CPU times: user 46.1 ms, sys: 2.5 ms, total: 48.6 ms
Wall time: 7.14 ms


array([[255.01973, 246.64864, 254.13373, ..., 233.67952, 247.68939,
        238.36853],
       [262.65982, 253.28915, 259.18246, ..., 239.03183, 253.16756,
        249.44124],
       [259.38916, 252.72754, 258.2306 , ..., 237.83559, 252.41093,
        246.62468],
       ...,
       [256.1581 , 250.092  , 254.72174, ..., 239.23874, 247.72684,
        244.16638],
       [268.22662, 258.91205, 262.33398, ..., 245.26648, 259.0539 ,
        258.337  ],
       [254.16138, 251.7543 , 256.083  , ..., 238.59848, 245.62595,
        240.22354]], dtype=float32)

In [5]:
# Blocking variant 2: `block_until_ready()` waits for the computation to
# finish and still returns a JAX `Array` (see the output below).
%time jnp.dot(x, x).block_until_ready()

CPU times: user 48.6 ms, sys: 6.96 ms, total: 55.5 ms
Wall time: 8.31 ms


Array([[255.01973, 246.64864, 254.13373, ..., 233.67952, 247.68939,
        238.36853],
       [262.65982, 253.28915, 259.18246, ..., 239.03183, 253.16756,
        249.44124],
       [259.38916, 252.72754, 258.2306 , ..., 237.83559, 252.41093,
        246.62468],
       ...,
       [256.1581 , 250.092  , 254.72174, ..., 239.23874, 247.72684,
        244.16638],
       [268.22662, 258.91205, 262.33398, ..., 245.26648, 259.0539 ,
        258.337  ],
       [254.16138, 251.7543 , 256.083  , ..., 238.59848, 245.62595,
        240.22354]], dtype=float32)

# 2. Tracing

In [6]:
from jax import jit
@jit
def f(x, y):
  """Return dot(x + 1, y + 1), printing the arguments and the result.

  The prints run while `jit` traces the function, so they show tracer
  objects instead of concrete values (see the cell output).
  """
  print("Running f():")
  print(f"  x = {x}")
  print(f"  y = {y}")
  x_plus_one = x + 1
  y_plus_one = y + 1
  result = jnp.dot(x_plus_one, y_plus_one)
  print(f"  result = {result}")
  return result

# Standard-normal NumPy inputs (no seed is set in the cells shown, so the
# values change from run to run). Note that this rebinds `x`.
x, y = np.random.randn(3, 4), np.random.randn(4)
f(x, y)

Running f():
  x = Traced<ShapedArray(float32[3,4])>with<DynamicJaxprTrace(level=1/0)>
  y = Traced<ShapedArray(float32[4])>with<DynamicJaxprTrace(level=1/0)>
  result = Traced<ShapedArray(float32[3])>with<DynamicJaxprTrace(level=1/0)>


Array([11.263997 ,  4.3654366,  2.8149018], dtype=float32)

In [7]:
from jax import make_jaxpr

def f(x, y):
  """Add 1 to each argument and return the dot product of the results."""
  x_plus_one = x + 1
  y_plus_one = y + 1
  return jnp.dot(x_plus_one, y_plus_one)

# Trace `f` with the current `x` and `y` and display the resulting jaxpr.
make_jaxpr(f)(x, y)

{ lambda ; a:f32[3,4] b:f32[4]. let
    c:f32[3,4] = add a 1.0
    d:f32[4] = add b 1.0
    e:f32[3] = dot_general[
      dimension_numbers=(([1], [0]), ([], []))
      preferred_element_type=float32
    ] c d
  in (e,) }

## 2.1 Static

### 2.1.1 Static parameters

In [None]:
from jax.errors import TracerBoolConversionError


@jit
def f(x, neg):
    """Return `-x` when `neg` is truthy, otherwise `x`.

    Under plain `jit`, `neg` is traced, so the Python conditional below cannot
    be evaluated and tracing fails with `TracerBoolConversionError`.
    """
    return -x if neg else x


# The failure is the point of this cell: catch and display it so that
# "Restart Kernel -> Run All" can continue past the demonstration.
try:
    f(1, True)
except TracerBoolConversionError as err:
    print(f"{type(err).__name__}: {err}")

TracerBoolConversionError: Attempted boolean conversion of traced array with shape bool[]..
The error occurred while tracing the function f at /var/folders/th/_3y3_jmx0mj98ffdxdh9fjt80000gn/T/ipykernel_6837/2422663986.py:1 for jit. This concrete value was not available in Python because it depends on the value of the argument neg.
See https://jax.readthedocs.io/en/latest/errors.html#jax.errors.TracerBoolConversionError

In [None]:
from functools import partial

@partial(jit, static_argnums=1)
def f(x, neg):
    """Return `-x` when `neg` is truthy, otherwise `x`.

    `neg` (positional argument 1) is declared static, so the Python
    conditional below sees a concrete value while tracing.
    """
    if neg:
        return -x
    return x

f(1, True)

Array(-1, dtype=int32, weak_type=True)

### 2.1.2 Static operations

Notice that although `x` is traced, `x.shape` is a static value. 

However, when we use `jnp.array` and `jnp.prod` on this static value, it becomes a traced value, at which point it cannot be used in a function like `reshape()`, which requires a static input <font color='red'>**(recall: array shapes must be static)**</font>.


A useful pattern is to use `numpy` for operations that should be static (i.e. done at compile time), and `jax.numpy` for operations that should be traced (i.e. compiled and executed at run time).

In [None]:
import jax.numpy as jnp
from jax import jit

@jit
def f(x):
    """Attempt to flatten `x` using a size computed with `jax.numpy`.

    `jnp.array(x.shape).prod()` is a traced value under `jit`, while `reshape`
    needs a concrete size, so tracing this function fails with a `TypeError`.
    """
    return x.reshape(jnp.array(x.shape).prod())


x = jnp.ones((2, 3))

# The failure is the point of this cell: catch and display it so that
# "Restart Kernel -> Run All" can continue past the demonstration.
try:
    f(x)
except TypeError as err:
    print(f"{type(err).__name__}: {err}")

TypeError: Shapes must be 1D sequences of concrete values of integer type, got [Traced<ShapedArray(int32[])>with<DynamicJaxprTrace(level=1/0)>].
If using `jit`, try using `static_argnums` or applying `jit` to smaller subfunctions.
The error occurred while tracing the function f at /var/folders/th/_3y3_jmx0mj98ffdxdh9fjt80000gn/T/ipykernel_6837/1983583872.py:4 for jit. This value became a tracer due to JAX operations on these lines:

  operation a:i32[2] = convert_element_type[new_dtype=int32 weak_type=False] b
    from line /var/folders/th/_3y3_jmx0mj98ffdxdh9fjt80000gn/T/ipykernel_6837/1983583872.py:6:19 (f)

In [None]:
@jit
def f(x):
    """Print `x`, its shape, and the shape's product computed with `jax.numpy`.

    The prints run while `jit` traces the function. Nothing is returned.
    """
    print(f"x = {x}")
    print(f"x.shape = {x.shape}")
    print(f"jnp.array(x.shape).prod() = {jnp.array(x.shape).prod()}")
    # The reshape is deliberately left commented out, because enabling it
    # raises the "Shapes must be 1D sequences of concrete values" TypeError:
    # return x.reshape(jnp.array(x.shape).prod())
    return None

# `x` is the (2, 3) array of ones created in the previous cell.
f(x)

x = Traced<ShapedArray(float32[2,3])>with<DynamicJaxprTrace(level=1/0)>
x.shape = (2, 3)
jnp.array(x.shape).prod() = Traced<ShapedArray(int32[])>with<DynamicJaxprTrace(level=1/0)>


In [12]:
from jax import jit
import jax.numpy as jnp
import numpy as np

@jit
def f(x):
    """Flatten `x` into a 1-D array.

    `x.shape` is static under `jit`, and computing its product with NumPy
    (rather than `jax.numpy`) keeps the size a concrete value that `reshape`
    accepts.
    """
    # `dtype=int` matters for 0-d input: `np.prod(())` would otherwise be the
    # float `1.0`, which is not a valid dimension.
    flat_size = np.prod(x.shape, dtype=int)
    return x.reshape((flat_size,))

# `x` is still the (2, 3) array of ones created earlier in this section.
f(x)

Array([1., 1., 1., 1., 1., 1.], dtype=float32)