In [1]:
import jax
import jax.numpy as jnp

In [2]:
jax.__version__

'0.9.0'

In [3]:
jax.devices()

[CudaDevice(id=0)]

## Using jax.jit as transformation or annotation

In [4]:
def selu(x,
         alpha=1.6732632423543772848170429916717,
         scale=1.0507009873554804934193349852946):
  '''Scaled exponential linear unit activation function.'''
  return scale * jnp.where(x > 0, x, alpha * jnp.exp(x) - alpha)

In [5]:
x = jax.random.normal(jax.random.PRNGKey(42), (1_000_000,))

In [6]:
selu_jit = jax.jit(selu)

In [7]:
%timeit -n100 selu(x).block_until_ready()

The slowest run took 16.10 times longer than the fastest. This could mean that an intermediate result is being cached.
754 μs ± 991 μs per loop (mean ± std. dev. of 7 runs, 100 loops each)


In [8]:
%timeit -n100 selu_jit(x).block_until_ready()

The slowest run took 9.46 times longer than the fastest. This could mean that an intermediate result is being cached.
101 μs ± 127 μs per loop (mean ± std. dev. of 7 runs, 100 loops each)


In [9]:
@jax.jit
def selu(x,
         alpha=1.6732632423543772848170429916717,
         scale=1.0507009873554804934193349852946):
  '''Scaled exponential linear unit activation function.'''
  return scale * jnp.where(x > 0, x, alpha * jnp.exp(x) - alpha)

In [10]:
z = selu(x) # warmup

In [11]:
%timeit -n100 selu(x).block_until_ready()

49 μs ± 3.99 μs per loop (mean ± std. dev. of 7 runs, 100 loops each)


## Compiling and running on specific hardware

In [12]:
def selu(x,
         alpha=1.6732632423543772848170429916717,
         scale=1.0507009873554804934193349852946):
  '''Scaled exponential linear unit activation function.'''
  return scale * jnp.where(x > 0, x, alpha * jnp.exp(x) - alpha)

In [13]:
selu_jit_cpu = jax.jit(selu, backend='cpu')
selu_jit_gpu = jax.jit(selu, backend='gpu')

In [14]:
%timeit -n100 selu(x).block_until_ready()

188 μs ± 5.25 μs per loop (mean ± std. dev. of 7 runs, 100 loops each)


In [15]:
%timeit -n100 selu_jit_cpu(x).block_until_ready()

913 μs ± 194 μs per loop (mean ± std. dev. of 7 runs, 100 loops each)


In [16]:
%timeit -n100 selu_jit_gpu(x).block_until_ready()

The slowest run took 7.09 times longer than the fastest. This could mean that an intermediate result is being cached.
91.7 μs ± 100 μs per loop (mean ± std. dev. of 7 runs, 100 loops each)


In [17]:
x_cpu = jax.device_put(x, jax.devices('cpu')[0])
x_gpu = jax.device_put(x, jax.devices('gpu')[0])

In [18]:
%timeit -n100 selu(x_cpu).block_until_ready()

1.83 ms ± 361 μs per loop (mean ± std. dev. of 7 runs, 100 loops each)


In [19]:
%timeit -n100 selu(x_gpu).block_until_ready()

The slowest run took 5.43 times longer than the fastest. This could mean that an intermediate result is being cached.
461 μs ± 403 μs per loop (mean ± std. dev. of 7 runs, 100 loops each)


In [20]:
%timeit -n100 selu_jit_cpu(x_cpu).block_until_ready()

143 μs ± 65 μs per loop (mean ± std. dev. of 7 runs, 100 loops each)


In [21]:
%timeit -n100 selu_jit_gpu(x_gpu).block_until_ready()

51.6 μs ± 4.52 μs per loop (mean ± std. dev. of 7 runs, 100 loops each)


In [22]:
%timeit -n100 selu_jit_cpu(x_gpu).block_until_ready()

744 μs ± 163 μs per loop (mean ± std. dev. of 7 runs, 100 loops each)


In [23]:
%timeit -n100 selu_jit_gpu(x_cpu).block_until_ready()

1.36 ms ± 369 μs per loop (mean ± std. dev. of 7 runs, 100 loops each)


In [24]:
x_cpu.device

CpuDevice(id=0)

In [25]:
x_gpu.device

CudaDevice(id=0)

In [26]:
selu_jit_gpu

<PjitFunction of <function selu at 0x7f81a0bc2d40>>

In [27]:
selu_jit_cpu

<PjitFunction of <function selu at 0x7f81a0bc2d40>>

## Working with function arguments

In [28]:
def dist(order, x, y):
  print("Compiling")
  return jnp.power(jnp.sum(jnp.abs(x-y)**order), 1.0/order)

In [29]:
dist_jit = jax.jit(dist, static_argnums=0)

In [30]:
dist_jit(1, jnp.array([0.0, 0.0]), jnp.array([2.0, 2.0]))

Compiling


Array(4., dtype=float32)

In [31]:
dist_jit(2, jnp.array([0.0, 0.0]), jnp.array([2.0, 2.0]))

Compiling


Array(2.828427, dtype=float32)

In [32]:
dist_jit(1, jnp.array([10.0, 10.0]), jnp.array([2.0, 2.0]))

Array(16., dtype=float32)

In [33]:
from functools import partial

@partial(jax.jit, static_argnums=0)
def dist(order, x, y):
  return jnp.power(jnp.sum(jnp.abs(x-y)**order), 1.0/order)


In [34]:
dist

<PjitFunction of <function dist at 0x7f8141742ac0>>

In [35]:
dist(2, jnp.array([0.0, 0.0]), jnp.array([2.0, 2.0]))

Array(2.828427, dtype=float32)

In [36]:
def dense_layer(x, w, b, activation_func):
    return activation_func(x*w+b)

In [37]:
x = jnp.array([1.0, 2.0, 3.0])
w = jnp.ones((3,3))
b = jnp.ones(3)

In [38]:
dense_layer_jit = jax.jit(dense_layer)

dense_layer_jit(x, w, b, selu)

TypeError: Error interpreting argument to <function dense_layer at 0x7f81417434c0> as an abstract array. The problematic value is of type <class 'function'> and was passed to the function at path activation_func.
This typically means that a jit-wrapped function was called with a non-array argument, and this argument was not marked as static using the static_argnums or static_argnames parameters of jax.jit.

In [39]:
dense_layer_jit = jax.jit(dense_layer, static_argnums=3)

In [40]:
dense_layer_jit(x, w, b, selu)

Array([[2.101402, 3.152103, 4.202804],
       [2.101402, 3.152103, 4.202804],
       [2.101402, 3.152103, 4.202804]], dtype=float32)

## Pure functions again

In [41]:
global_state = 1

def impure_function(x):
  print(f'Side-effect: printing x={x}')
  y = x*global_state
  return y

In [42]:
impure_function_jit = jax.jit(impure_function)

In [43]:
impure_function_jit(10)

Side-effect: printing x=JitTracer(~int32[])


Array(10, dtype=int32, weak_type=True)

In [44]:
impure_function_jit(10)

Array(10, dtype=int32, weak_type=True)

In [45]:
impure_function_jit(11)

Array(11, dtype=int32, weak_type=True)

In [46]:
global_state = 2

In [47]:
impure_function_jit(10)

Array(10, dtype=int32, weak_type=True)

In [48]:
impure_function(10)

Side-effect: printing x=10


20

## Jaxpr

Getting jaxpr

In [49]:
def f1(x, y, z):
  return jnp.sum(x + y * z)

In [50]:
x = jnp.array([1.0, 1.0, 1.0])
y = jnp.ones((3,3))*2.0
z = jnp.array([2.0, 1.0, 0.0]).T

In [51]:
jax.make_jaxpr(f1)(x,y,z)

{ [34;1mlambda [39;22m; a[35m:f32[3][39m b[35m:f32[3,3][39m c[35m:f32[3][39m. [34;1mlet
    [39;22md[35m:f32[1,3][39m = broadcast_in_dim[
      broadcast_dimensions=(1,)
      shape=(1, 3)
      sharding=None
    ] c
    e[35m:f32[3,3][39m = mul b d
    f[35m:f32[1,3][39m = broadcast_in_dim[
      broadcast_dimensions=(1,)
      shape=(1, 3)
      sharding=None
    ] a
    g[35m:f32[3,3][39m = add f e
    h[35m:f32[][39m = reduce_sum[axes=(0, 1) out_sharding=None] g
  [34;1min [39;22m(h,) }

In [52]:
f1_jaxpr = jax.make_jaxpr(f1)(x,y,z)

In [53]:
type(f1_jaxpr)

jax._src.core.ClosedJaxpr

In [54]:
f1_jaxpr.jaxpr

{ [34;1mlambda [39;22m; a[35m:f32[3][39m b[35m:f32[3,3][39m c[35m:f32[3][39m. [34;1mlet
    [39;22md[35m:f32[1,3][39m = broadcast_in_dim[
      broadcast_dimensions=(1,)
      shape=(1, 3)
      sharding=None
    ] c
    e[35m:f32[3,3][39m = mul b d
    f[35m:f32[1,3][39m = broadcast_in_dim[
      broadcast_dimensions=(1,)
      shape=(1, 3)
      sharding=None
    ] a
    g[35m:f32[3,3][39m = add f e
    h[35m:f32[][39m = reduce_sum[axes=(0, 1) out_sharding=None] g
  [34;1min [39;22m(h,) }

In [55]:
f1_jaxpr.consts

[]

Jaxpr for a function with side-effect

In [56]:
def f2(x, y):
  print(f'x={x}, y={y}, z={z}')
  return jnp.sum(x + y * z)

In [57]:
f2_jaxpr = jax.make_jaxpr(f2)(x,y)

x=JitTracer(float32[3]), y=JitTracer(float32[3,3]), z=[2. 1. 0.]


In [58]:
type(f2_jaxpr.jaxpr)

jax._src.core.Jaxpr

In [59]:
f2_jaxpr.jaxpr

{ [34;1mlambda [39;22ma[35m:f32[3][39m; b[35m:f32[3][39m c[35m:f32[3,3][39m. [34;1mlet
    [39;22md[35m:f32[1,3][39m = broadcast_in_dim[
      broadcast_dimensions=(1,)
      shape=(1, 3)
      sharding=None
    ] a
    e[35m:f32[3,3][39m = mul c d
    f[35m:f32[1,3][39m = broadcast_in_dim[
      broadcast_dimensions=(1,)
      shape=(1, 3)
      sharding=None
    ] b
    g[35m:f32[3,3][39m = add f e
    h[35m:f32[][39m = reduce_sum[axes=(0, 1) out_sharding=None] g
  [34;1min [39;22m(h,) }

In [60]:
f2_jaxpr.jaxpr.constvars

[Var(id=140190085133056):float32[3]]

In [61]:
f2_jaxpr.jaxpr.invars

[Var(id=140190085288512):float32[3], Var(id=140190085291904):float32[3,3]]

In [62]:
f2_jaxpr.jaxpr.outvars

[Var(id=140190085293440):float32[]]

In [63]:
f2_jaxpr.jaxpr.eqns

[a[35m:f32[1,3][39m = broadcast_in_dim[
   broadcast_dimensions=(1,)
   shape=(1, 3)
   sharding=None
 ] b,
 a[35m:f32[3,3][39m = mul b c,
 a[35m:f32[1,3][39m = broadcast_in_dim[
   broadcast_dimensions=(1,)
   shape=(1, 3)
   sharding=None
 ] b,
 a[35m:f32[3,3][39m = add b c,
 a[35m:f32[][39m = reduce_sum[axes=(0, 1) out_sharding=None] b]

In [64]:
f2_jaxpr.jaxpr.effects

set()

In [65]:
type(f2_jaxpr.consts)

list

In [66]:
f2_jaxpr.consts

[Array([2., 1., 0.], dtype=float32)]

In [67]:
type(jax.make_jaxpr(f2))

function

In [68]:
jax.grad(f2)(x,y)

x=GradTracer(primal=[1. 1. 1.], typeof(tangent)=f32[3]), y=[[2. 2. 2.]
 [2. 2. 2.]
 [2. 2. 2.]], z=[2. 1. 0.]


Array([3., 3., 3.], dtype=float32)

Tracing with control structures

In [69]:
def f3(x):
  y = x
  for i in range(5):
    y += i
  return y

In [70]:
jax.make_jaxpr(f3)(0)

{ [34;1mlambda [39;22m; a[35m:i32[][39m. [34;1mlet
    [39;22mb[35m:i32[][39m = add a 0:i32[]
    c[35m:i32[][39m = add b 1:i32[]
    d[35m:i32[][39m = add c 2:i32[]
    e[35m:i32[][39m = add d 3:i32[]
    f[35m:i32[][39m = add e 4:i32[]
  [34;1min [39;22m(f,) }

In [71]:
jax.jit(f3)(0)

Array(10, dtype=int32, weak_type=True)

In [72]:
def f4(x):
  y = 0
  for i in range(x.shape[0]):
    y += x[i]
  return y

In [73]:
jax.make_jaxpr(f4)(jnp.array([1.0, 2.0, 3.0]))

{ [34;1mlambda [39;22m; a[35m:f32[3][39m. [34;1mlet
    [39;22mb[35m:f32[1][39m = slice[limit_indices=(1,) start_indices=(0,) strides=None] a
    c[35m:f32[][39m = squeeze[dimensions=(0,)] b
    d[35m:f32[][39m = add 0.0:f32[] c
    e[35m:f32[1][39m = slice[limit_indices=(2,) start_indices=(1,) strides=None] a
    f[35m:f32[][39m = squeeze[dimensions=(0,)] e
    g[35m:f32[][39m = add d f
    h[35m:f32[1][39m = slice[limit_indices=(3,) start_indices=(2,) strides=None] a
    i[35m:f32[][39m = squeeze[dimensions=(0,)] h
    j[35m:f32[][39m = add g i
  [34;1min [39;22m(j,) }

In [74]:
jax.jit(f4)(jnp.array([1.0, 2.0, 3.0]))

Array(6., dtype=float32)

Dependance on a parameter value

In [75]:
def f5(x):
  y = 0
  for i in range(x):
    y += i
  return y

In [76]:
f5(5)

10

In [77]:
jax.make_jaxpr(f5)(5)

TracerIntegerConversionError: The __index__() method was called on traced array with shape int32[]
The error occurred while tracing the function f5 at /tmp/ipykernel_2378932/4135095833.py:1 for jit. This concrete value was not available in Python because it depends on the value of the argument x.
See https://docs.jax.dev/en/latest/errors.html#jax.errors.TracerIntegerConversionError

In [78]:
jax.jit(f5)(5)

TracerIntegerConversionError: The __index__() method was called on traced array with shape int32[]
The error occurred while tracing the function f5 at /tmp/ipykernel_2378932/4135095833.py:1 for jit. This concrete value was not available in Python because it depends on the value of the argument x.
See https://docs.jax.dev/en/latest/errors.html#jax.errors.TracerIntegerConversionError

In [79]:
def relu(x):
  #print(x)
  if x > 0:
    return x
  return 0.0

In [80]:
relu(10.0)

10.0

In [81]:
jax.make_jaxpr(relu)(10.0)

TracerBoolConversionError: Attempted boolean conversion of traced array with shape bool[].
The error occurred while tracing the function relu at /tmp/ipykernel_2378932/1355355841.py:1 for jit. This concrete value was not available in Python because it depends on the value of the argument x.
See https://docs.jax.dev/en/latest/errors.html#jax.errors.TracerBoolConversionError

Using static parameters to overcome this

In [82]:
jax.make_jaxpr(f5, static_argnums=0)(5)

{ [34;1mlambda [39;22m; . [34;1mlet[39;22m  [34;1min [39;22m(10:i32[],) }

In [83]:
jax.jit(f5, static_argnums=0)(5)

Array(10, dtype=int32, weak_type=True)

In [84]:
jax.make_jaxpr(relu, static_argnums=0)(12.3)

{ [34;1mlambda [39;22m; . [34;1mlet[39;22m  [34;1min [39;22m(12.3:f32[],) }

In [85]:
jax.jit(relu, static_argnums=0)(12.3)

Array(12.3, dtype=float32, weak_type=True)

Rewriting with structured control flow primitives

In [86]:
def f5(x):
  return jax.lax.fori_loop(0, x, lambda i,v: v+i, 0)

In [87]:
f5(5)

Array(10, dtype=int32, weak_type=True)

In [88]:
jax.make_jaxpr(f5)(5)

{ [34;1mlambda [39;22m; a[35m:i32[][39m. [34;1mlet
    [39;22m_[35m:i32[][39m _[35m:i32[][39m b[35m:i32[][39m = while[
      body_jaxpr={ [34;1mlambda [39;22m; c[35m:i32[][39m d[35m:i32[][39m e[35m:i32[][39m. [34;1mlet
          [39;22mf[35m:i32[][39m = add c 1:i32[]
          g[35m:i32[][39m = add e c
        [34;1min [39;22m(f, d, g) }
      body_nconsts=0
      cond_jaxpr={ [34;1mlambda [39;22m; h[35m:i32[][39m i[35m:i32[][39m j[35m:i32[][39m. [34;1mlet
          [39;22mk[35m:bool[][39m = lt h i
        [34;1min [39;22m(k,) }
      cond_nconsts=0
    ] 0:i32[] a 0:i32[]
  [34;1min [39;22m(b,) }

In [89]:
jax.jit(f5)(5)

Array(10, dtype=int32, weak_type=True)

In [90]:
def relu(x):
  return jax.lax.cond(x>0, lambda x: x, lambda x: 0.0, x)

In [91]:
relu(12.3)

Array(12.3, dtype=float32, weak_type=True)

In [92]:
relu(-12.3)

Array(0., dtype=float32, weak_type=True)

In [93]:
jax.make_jaxpr(relu)(12.3)

{ [34;1mlambda [39;22m; a[35m:f32[][39m. [34;1mlet
    [39;22mb[35m:bool[][39m = gt a 0.0:f32[]
    c[35m:i32[][39m = convert_element_type[new_dtype=int32 weak_type=False] b
    d[35m:f32[][39m = cond[
      branches=(
        { [34;1mlambda [39;22m; e[35m:f32[][39m. [34;1mlet[39;22m  [34;1min [39;22m(0.0:f32[],) }
        { [34;1mlambda [39;22m; f[35m:f32[][39m. [34;1mlet[39;22m  [34;1min [39;22m(f,) }
      )
    ] c a
  [34;1min [39;22m(d,) }

In [94]:
jax.jit(relu)(12.3)

Array(12.3, dtype=float32, weak_type=True)

## XLA

In [95]:
def f(x, y, z):
  return jnp.sum(x + y * z)

In [96]:
x = jnp.array([1.0, 1.0, 1.0])
y = jnp.ones((3,3))*2.0
z = jnp.array([2.0, 1.0, 0.0]).T

In [97]:
x

Array([1., 1., 1.], dtype=float32)

In [98]:
y

Array([[2., 2., 2.],
       [2., 2., 2.],
       [2., 2., 2.]], dtype=float32)

In [99]:
z

Array([2., 1., 0.], dtype=float32)

In [100]:
f(x, y, z)

Array(27., dtype=float32)

In [101]:
f_jitted = jax.jit(f)

In [102]:
f_lowered = f_jitted.lower(x,y,z)
print(f_lowered.as_text())

module @jit_f attributes {mhlo.num_partitions = 1 : i32, mhlo.num_replicas = 1 : i32} {
  func.func public @main(%arg0: tensor<3xf32>, %arg1: tensor<3x3xf32>, %arg2: tensor<3xf32>) -> (tensor<f32> {jax.result_info = "result"}) {
    %0 = stablehlo.broadcast_in_dim %arg2, dims = [1] : (tensor<3xf32>) -> tensor<1x3xf32>
    %1 = stablehlo.broadcast_in_dim %0, dims = [0, 1] : (tensor<1x3xf32>) -> tensor<3x3xf32>
    %2 = stablehlo.multiply %arg1, %1 : tensor<3x3xf32>
    %3 = stablehlo.broadcast_in_dim %arg0, dims = [1] : (tensor<3xf32>) -> tensor<1x3xf32>
    %4 = stablehlo.broadcast_in_dim %3, dims = [0, 1] : (tensor<1x3xf32>) -> tensor<3x3xf32>
    %5 = stablehlo.add %4, %2 : tensor<3x3xf32>
    %cst = stablehlo.constant dense<0.000000e+00> : tensor<f32>
    %6 = stablehlo.reduce(%5 init: %cst) applies stablehlo.add across dimensions = [0, 1] : (tensor<3x3xf32>, tensor<f32>) -> tensor<f32>
    return %6 : tensor<f32>
  }
}



In [103]:
f_compiled = f_lowered.compile()

In [104]:
f_compiled

<jax._src.stages.Compiled at 0x7f7fa0348d10>

In [105]:
print(f_compiled.as_text())

HloModule jit_f, is_scheduled=true, entry_computation_layout={(f32[3]{0}, f32[3,3]{1,0}, f32[3]{0})->f32[]}, allow_spmd_sharding_propagation_to_parameters={true,true,true}, allow_spmd_sharding_propagation_to_output={true}, frontend_attributes={fingerprint_before_lhs="dce56488409c87cbe64cac12e477699d"}

FileNames
1 "/home/icute/repos/personal_website/.venv/lib/python3.13/site-packages/ipykernel/ipkernel.py"
2 "/home/icute/repos/personal_website/.venv/lib/python3.13/site-packages/ipykernel/zmqshell.py"
3 "/home/icute/repos/personal_website/.venv/lib/python3.13/site-packages/IPython/core/interactiveshell.py"
4 "/home/icute/repos/personal_website/.venv/lib/python3.13/site-packages/IPython/core/async_helpers.py"
5 "/tmp/ipykernel_2378932/1458392737.py"
6 "/tmp/ipykernel_2378932/3103132191.py"

FunctionNames
1 "IPythonKernel.do_execute"
2 "ZMQInteractiveShell.run_cell"
3 "InteractiveShell.run_cell"
4 "InteractiveShell._run_cell"
5 "_pseudo_sync_runner"
6 "InteractiveShell.run_cell_async"
7 "

MLIR

https://github.com/tensorflow/tensorflow/tree/master/tensorflow/compiler/xla/mlir

https://github.com/tensorflow/tensorflow/tree/master/tensorflow/compiler/xla/mlir_hlo



## JIT Limitations

Long representation

In [106]:
def cumulative_sum(x):
  acc = 0.0
  y = []
  for i in range(x.shape[0]):
    acc += x[i]
    y.append(acc)
  return y

In [107]:
cumulative_sum(jnp.array([1.0, 1.0, 5.0, 2.0]))

[Array(1., dtype=float32),
 Array(2., dtype=float32),
 Array(7., dtype=float32),
 Array(9., dtype=float32)]

In [108]:
j = jax.make_jaxpr(cumulative_sum)(jnp.ones(10000))

In [109]:
len(j.jaxpr.eqns)

30000

In [110]:
j

{ [34;1mlambda [39;22m; a[35m:f32[10000][39m. [34;1mlet
    [39;22mb[35m:f32[1][39m = slice[limit_indices=(1,) start_indices=(0,) strides=None] a
    c[35m:f32[][39m = squeeze[dimensions=(0,)] b
    d[35m:f32[][39m = add 0.0:f32[] c
    e[35m:f32[1][39m = slice[limit_indices=(2,) start_indices=(1,) strides=None] a
    f[35m:f32[][39m = squeeze[dimensions=(0,)] e
    g[35m:f32[][39m = add d f
    h[35m:f32[1][39m = slice[limit_indices=(3,) start_indices=(2,) strides=None] a
    i[35m:f32[][39m = squeeze[dimensions=(0,)] h
    j[35m:f32[][39m = add g i
    k[35m:f32[1][39m = slice[limit_indices=(4,) start_indices=(3,) strides=None] a
    l[35m:f32[][39m = squeeze[dimensions=(0,)] k
    m[35m:f32[][39m = add j l
    n[35m:f32[1][39m = slice[limit_indices=(5,) start_indices=(4,) strides=None] a
    o[35m:f32[][39m = squeeze[dimensions=(0,)] n
    p[35m:f32[][39m = add m o
    q[35m:f32[1][39m = slice[limit_indices=(6,) start_indices=(5,) strides=None] 

In [111]:
%time cs = jax.jit(cumulative_sum)(jnp.ones(10000))

CPU times: user 30 s, sys: 386 ms, total: 30.4 s
Wall time: 30.6 s


In [112]:
def cumulative_sum_fast(x):
  result, array = jax.lax.scan(lambda carry, elem: (carry+elem, carry+elem), 0.0, x)
  return array

In [113]:
cumulative_sum_fast(jnp.array([1.0, 1.0, 5.0, 2.0]))

Array([1., 2., 7., 9.], dtype=float32)

In [114]:
j = jax.make_jaxpr(cumulative_sum_fast)(jnp.ones(10000))

In [115]:
len(j.jaxpr.eqns)

1

In [116]:
j

{ [34;1mlambda [39;22m; a[35m:f32[10000][39m. [34;1mlet
    [39;22m_[35m:f32[][39m b[35m:f32[10000][39m = scan[
      _split_transpose=False
      jaxpr={ [34;1mlambda [39;22m; c[35m:f32[][39m d[35m:f32[][39m. [34;1mlet
          [39;22me[35m:f32[][39m = convert_element_type[new_dtype=float32 weak_type=False] c
          f[35m:f32[][39m = add e d
          g[35m:f32[][39m = convert_element_type[new_dtype=float32 weak_type=False] c
          h[35m:f32[][39m = add g d
        [34;1min [39;22m(f, h) }
      length=10000
      linear=(False, False)
      num_carry=1
      num_consts=0
      reverse=False
      unroll=1
    ] 0.0:f32[] a
  [34;1min [39;22m(b,) }

In [117]:
%time cs = jax.jit(cumulative_sum_fast)(jnp.ones(10000))

CPU times: user 238 ms, sys: 131 ms, total: 369 ms
Wall time: 378 ms


Class methods

In [118]:
class ScaleClass:
  def __init__(self, scale: jnp.array):
    self.scale = scale

  @jax.jit
  def apply(self, x: jnp.array):
    return self.scale * x

In [119]:
scale_double = ScaleClass(2)

In [120]:
scale_double.apply(10)

TypeError: Error interpreting argument to <function ScaleClass.apply at 0x7f7fa036aca0> as an abstract array. The problematic value is of type <class '__main__.ScaleClass'> and was passed to the function at path self.
This typically means that a jit-wrapped function was called with a non-array argument, and this argument was not marked as static using the static_argnums or static_argnames parameters of jax.jit.

In [121]:
from functools import partial

class ScaleClass:
  def __init__(self, scale: jnp.array):
    self.scale = scale

  def apply(self, x: jnp.array):
    return _apply_helper(self.scale, x)

@partial(jax.jit, static_argnums=0)
def _apply_helper(scale, x):
  return scale*x

In [122]:
scale_double = ScaleClass(2)


In [123]:
scale_double.apply(10)

Array(20, dtype=int32, weak_type=True)

## AOT compilation

In [124]:
def selu(x,
         alpha=1.6732632423543772848170429916717,
         scale=1.0507009873554804934193349852946):
  '''Scaled exponential linear unit activation function.'''
  print('Function run')
  return scale * jnp.where(x > 0, x, alpha * jnp.exp(x) - alpha)

In [125]:
selu_jit = jax.jit(selu)

In [126]:
selu_aot = jax.jit(selu).lower(1.0).compile()

Function run


In [127]:
selu_jit(17.8)

Array(18.702477, dtype=float32, weak_type=True)

In [128]:
selu_aot(17.8)

Array(18.702477, dtype=float32, weak_type=True)

In [129]:
selu_jit(17)

Function run


Array(17.861917, dtype=float32, weak_type=True)

In [130]:
selu_aot(17)

TypeError: Argument types differ from the types for which this computation was compiled. Perhaps you are calling the compiled executable with a different enable_x64 mode than when it was AOT compiled? The mismatches are:
Argument 'x' compiled with float32[] and called with int32[]

In [131]:
selu_jit_batched = jax.vmap(selu_jit)

In [132]:
selu_aot_batched = jax.vmap(selu_aot)

In [133]:
selu_jit_batched(jnp.array([42.0, 78.0, -12.3]))

Function run


Array([44.129444 , 81.95468  , -1.7580913], dtype=float32)

In [134]:
selu_aot_batched(jnp.array([42.0, 78.0, -12.3]))

TypeError: Cannot apply JAX transformations to a function lowered and compiled for a particular signature. Detected argument of Tracer type <class 'jax._src.interpreters.batching.BatchTracer'>.