In [6]:
import jax
import jax.numpy as jnp

def f(x, y): return 2 * x + y
x, y = 3, 4

lowered = jax.jit(f).lower(x, y)

# Print lowered HLO
print(lowered.as_text())
# module @jit_f attributes {mhlo.num_partitions = 1 : i32, mhlo.num_replicas = 1 : i32} {
#   func.func public @main(%arg0: tensor<i32>, %arg1: tensor<i32>) -> (tensor<i32> {jax.result_info = ""}) {
#     %c = stablehlo.constant dense<2> : tensor<i32>
#     %0 = stablehlo.multiply %c, %arg0 : tensor<i32>
#     %1 = stablehlo.add %0, %arg1 : tensor<i32>
#     return %1 : tensor<i32>
#   }
# }

compiled = lowered.compile()

# Query for cost analysis, print FLOP estimate
compiled.cost_analysis()[0]['flops']
2.0

# Execute the compiled function!
print(compiled(x, y))
# Array(10, dtype=int32, weak_type=True)


module @jit_f attributes {mhlo.num_partitions = 1 : i32, mhlo.num_replicas = 1 : i32} {
  func.func public @main(%arg0: tensor<i32> {mhlo.layout_mode = "default", mhlo.sharding = "{replicated}"}, %arg1: tensor<i32> {mhlo.layout_mode = "default", mhlo.sharding = "{replicated}"}) -> (tensor<i32> {jax.result_info = "", mhlo.layout_mode = "default"}) {
    %0 = stablehlo.constant dense<2> : tensor<i32>
    %1 = stablehlo.multiply %0, %arg0 : tensor<i32>
    %2 = stablehlo.add %1, %arg1 : tensor<i32>
    return %2 : tensor<i32>
  }
}

10


In [7]:
# Lower against an abstract int32 scalar spec rather than concrete values;
# the resulting executable is then called with the concrete x and y.
i32_scalar = jax.ShapeDtypeStruct(shape=(), dtype=jnp.int32)
compiled_i32 = jax.jit(f).lower(i32_scalar, i32_scalar).compile()
compiled_i32(x, y)

Array(10, dtype=int32)

In [8]:
# An executable compiled for int32[] scalars rejects int32[3] arrays with a
# TypeError instead of recompiling. Catch and show it so Run All continues.
x_1d = y_1d = jnp.arange(3)
try:
    jax.jit(f).lower(i32_scalar, i32_scalar).compile()(x_1d, y_1d)
except TypeError as err:
    print(f"TypeError: {err}")

TypeError: Argument types differ from the types for which this computation was compiled. The mismatches are:
Argument 'x' compiled with int32[] and called with int32[3]
Argument 'y' compiled with int32[] and called with int32[3]

In [9]:
# Likewise, float32 inputs do not match the int32 types the executable was
# compiled for. Catch and show the TypeError so Run All continues.
x_f = y_f = jnp.float32(72.)
try:
    jax.jit(f).lower(i32_scalar, i32_scalar).compile()(x_f, y_f)
except TypeError as err:
    print(f"TypeError: {err}")

TypeError: Argument types differ from the types for which this computation was compiled. The mismatches are:
Argument 'x' compiled with int32[] and called with float32[]
Argument 'y' compiled with int32[] and called with float32[]

In [12]:
# With static_argnums=0 the *value* of the first argument (7) is baked into the
# lowered HLO (it appears as the constant 14 = 2 * 7). Only the second argument
# remains a runtime input, so the compiled executable is called with y alone.
lowered_with_x = jax.jit(f, static_argnums=0).lower(7, 8)
print(lowered_with_x.as_text())
compiled_with_x = lowered_with_x.compile()
compiled_with_x(5)


module @jit_f attributes {mhlo.num_partitions = 1 : i32, mhlo.num_replicas = 1 : i32} {
  func.func public @main(%arg0: tensor<i32> {mhlo.layout_mode = "default", mhlo.sharding = "{replicated}"}) -> (tensor<i32> {jax.result_info = "", mhlo.layout_mode = "default"}) {
    %0 = stablehlo.constant dense<14> : tensor<i32>
    %1 = stablehlo.add %0, %arg0 : tensor<i32>
    return %1 : tensor<i32>
  }
}



Array(19, dtype=int32, weak_type=True)

In [13]:
# A static argument is handed to `f` as-is while tracing, so an abstract
# ShapeDtypeStruct for the static `x` reaches `2 * x` and raises a TypeError
# (int * ShapeDtypeStruct). Catch and show it so Run All continues.
try:
    jax.jit(f, static_argnums=0).lower(i32_scalar, i32_scalar)
except TypeError as err:
    print(f"TypeError: {err}")

TypeError: unsupported operand type(s) for *: 'int' and 'ShapeDtypeStruct'

In [14]:
# Static x=10 is baked in at lowering time; y is described only by its
# abstract int32 spec and supplied as a concrete value at call time.
compiled_x10 = jax.jit(f, static_argnums=0).lower(10, i32_scalar).compile()
compiled_x10(5)

Array(25, dtype=int32)