# Tutorial: Exporting StableHLO from JAX

[![Open in Colab](https://colab.research.google.com/assets/colab-badge.svg)][jax-tutorial-colab]
[![Open in Kaggle](https://kaggle.com/static/images/open-in-kaggle.svg)][jax-tutorial-kaggle]

JAX is a Python library for high-performance numerical computing. This tutorial shows how to export JAX and Flax (a JAX-powered neural network library) models to StableHLO, and directly to TensorFlow SavedModel.

## Tutorial Setup

### Install required dependencies

We use `jax` and `jaxlib` (JAX's support library with compiled binaries), along with `flax` and `transformers` for some models to export.
We also need to install `tensorflow` to work with SavedModel, and recommend using `tensorflow-cpu` or `tf-nightly` for this tutorial.

[jax-tutorial-colab]: https://colab.research.google.com/github/openxla/stablehlo/blob/main/docs/tutorials/jax-export.ipynb
[jax-tutorial-kaggle]: https://kaggle.com/kernels/welcome?src=https://github.com/openxla/stablehlo/blob/main/docs/tutorials/jax-export.ipynb

```python
!pip install -U jax jaxlib flax transformers tensorflow-cpu
```



```python
#@title Define `get_stablehlo_asm` to help with MLIR printing
from jax._src.interpreters import mlir as jax_mlir
from jax._src.lib.mlir import ir

# Returns a pretty-printed StableHLO module with large constants elided
def get_stablehlo_asm(module_str):
  # Parse with the same context that the `with` block activates
  with jax_mlir.make_ir_context() as ctx:
    stablehlo_module = ir.Module.parse(module_str, context=ctx)
    return stablehlo_module.operation.get_asm(large_elements_limit=20)

# Disable logging for better tutorial rendering
import logging
logging.disable(logging.WARNING)
```



_Note: This helper uses an internal JAX API that may break at any time. It only makes the printed output more readable and plays no functional role in export._
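If you prefer not to depend on internal APIs, you can use a rough alternative built only from public JAX APIs. `jax.jit(...).lower(...)` returns a `Lowered` object, and its `as_text()` method prints the lowered module. In recent JAX releases that module is StableHLO. The sketch below relies on that assumption. `public_stablehlo_text` is a hypothetical name, and its output may be formatted differently from `get_stablehlo_asm`.

```python
import jax
import jax.numpy as jnp
import numpy as np

# Hypothetical helper: lower `fn` for the given abstract inputs and return
# the module text, using only public JAX APIs.
def public_stablehlo_text(fn, *arg_specs):
  lowered = jax.jit(fn).lower(*arg_specs)
  return lowered.as_text()

def plus(x, y):
  return jnp.add(x, y)

scalar_i32 = jax.ShapeDtypeStruct((), np.int32)
text = public_stablehlo_text(plus, scalar_i32, scalar_i32)
print(text)
```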

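## Function transformations with Jaxpr

JAX traces Python functions into an intermediate representation called a jaxpr. `jax.make_jaxpr` shows the traced program for a given set of abstract input shapes.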

```python
import jax
import jax.numpy as jnp
import numpy as np

def plus(x, y):
  return jnp.add(x, y)

# Create abstract input shapes
inputs = (np.int32(1), np.int32(1),)
input_shapes = [jax.ShapeDtypeStruct(input.shape, input.dtype) for input in inputs]

jax.make_jaxpr(plus)(*input_shapes)
```

```
{ lambda ; a:i32[] b:i32[]. let c:i32[] = add a b in (c,) }
```
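The printed jaxpr is backed by a data structure that you can inspect from Python. The minimal sketch below assumes the `ClosedJaxpr.jaxpr.eqns` layout used by current JAX releases. It walks the equations of an illustrative function, `scaled_sine`, and lists the primitive applied at each step.

```python
import jax
import jax.numpy as jnp
import numpy as np

def scaled_sine(x, y):
  # Expected primitives: sin, multiply, add
  return jnp.sin(x) * y + 1.0

scalar_f32 = jax.ShapeDtypeStruct((), np.float32)
closed_jaxpr = jax.make_jaxpr(scaled_sine)(scalar_f32, scalar_f32)

# closed_jaxpr.jaxpr holds the program; each equation applies one primitive
primitive_names = [eqn.primitive.name for eqn in closed_jaxpr.jaxpr.eqns]
print(closed_jaxpr)
print(primitive_names)
```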

## Export JAX model to StableHLO using `jax.export`

In this section we'll export a basic JAX function and a Flax model to StableHLO.

The preferred API for export is [`jax.export`](https://jax.readthedocs.io/en/latest/jax.export.html#module-jax.export). To be exported to StableHLO, a function must first be JIT-transformed; specifically, it must be the result of `jax.jit`.

### Export basic JAX model to StableHLO

Let's start by exporting a basic `plus` function to StableHLO, using `np.int32` argument types to trace the function.

Export requires specifying shapes using `jax.ShapeDtypeStruct`, which can be constructed from NumPy values.
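Only the shape and dtype of a sample value matter, so you can also write a `jax.ShapeDtypeStruct` down directly. The sketch below builds the same image-batch spec both ways. It then uses `jax.eval_shape` to propagate the spec through an illustrative function, `channel_means`, without running any computation.

```python
import jax
import jax.numpy as jnp
import numpy as np

# From a NumPy sample value (only .shape and .dtype are used)
sample = np.zeros((1, 3, 224, 224), dtype=np.float32)
spec_from_value = jax.ShapeDtypeStruct(sample.shape, sample.dtype)

# Written down directly, with no sample value at all
spec_direct = jax.ShapeDtypeStruct((1, 3, 224, 224), jnp.float32)

def channel_means(images):
  # Average over the two spatial dimensions
  return jnp.mean(images, axis=(2, 3))

# eval_shape traces the function abstractly and returns the output spec
out_spec = jax.eval_shape(channel_means, spec_direct)
print(spec_from_value, spec_direct, out_spec)
```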

```python
import jax
from jax import export
import jax.numpy as jnp
import numpy as np

# Create a JIT-transformed function
@jax.jit
def plus(x, y):
  return jnp.add(x, y)

# Create abstract input shapes
inputs = (np.int32(1), np.int32(1),)
input_shapes = [jax.ShapeDtypeStruct(input.shape, input.dtype) for input in inputs]

# Export the function to StableHLO
stablehlo_add = export.export(plus)(*input_shapes).mlir_module()
print(get_stablehlo_asm(stablehlo_add))
```

```mlir
module @jit_plus attributes {jax.uses_shape_polymorphism = false, mhlo.num_partitions = 1 : i32, mhlo.num_replicas = 1 : i32} {
  func.func public @main(%arg0: tensor<i32>, %arg1: tensor<i32>) -> (tensor<i32> {jax.result_info = "result"}) {
    %0 = stablehlo.add %arg0, %arg1 : tensor<i32>
    return %0 : tensor<i32>
  }
}
```
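Besides `mlir_module()`, the `Exported` object returned by `export.export` does three other useful things. It records the abstract input and output types, it can be called from JAX, and it can be serialized for later use. The minimal sketch below assumes JAX 0.4.30 or newer, where `jax.export` is public, and that the current default backend is one of `exported.platforms`.

```python
import jax
import jax.numpy as jnp
import numpy as np
from jax import export

@jax.jit
def plus(x, y):
  return jnp.add(x, y)

scalar_i32 = jax.ShapeDtypeStruct((), np.int32)
exported = export.export(plus)(scalar_i32, scalar_i32)

print(exported.in_avals)   # abstract input types
print(exported.out_avals)  # abstract output types
print(exported.platforms)  # platforms the module was lowered for

# Execute the exported StableHLO on concrete inputs
result = exported.call(np.int32(2), np.int32(3))

# Round-trip through bytes, e.g. to store the export on disk
blob = exported.serialize()
restored = export.deserialize(blob)
restored_result = restored.call(np.int32(40), np.int32(2))
print(result, restored_result)
```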



### Export Hugging Face FlaxResNet18 to StableHLO

Now let's look at a simple model that appears in the wild, `resnet18`.

We'll export a `flax` model from the Hugging Face `transformers` ResNet page, [FlaxResNetModel](https://huggingface.co/docs/transformers/en/model_doc/resnet#transformers.FlaxResNetModel). These setup steps are copied from the Hugging Face documentation.

The documentation also states: _"Finally, this model supports inherent JAX features such as: **Just-In-Time (JIT) compilation** ..."_ which means it is perfect for export.

Similar to our very basic example, our steps for export are:

1. Instantiate a callable (model/function)
2. JIT-transform it with `jax.jit`
3. Specify shapes for export using `jax.ShapeDtypeStruct` on NumPy values
4. Use the JAX `export` API to get a StableHLO module
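You can sketch the four steps on a small stand-in model that needs no download. The `init_params` and `tiny_model` names and the single dense layer are illustrative assumptions, not part of Flax or `transformers`. As with the ResNet, the parameters are closed over, so they end up as constants in the exported module.

```python
import jax
import jax.numpy as jnp
import numpy as np
from jax import export

# Hypothetical parameters for a single dense layer (8 -> 10 features)
def init_params(key):
  w_key, b_key = jax.random.split(key)
  return {"w": jax.random.normal(w_key, (8, 10)),
          "b": jax.random.normal(b_key, (10,))}

params = init_params(jax.random.PRNGKey(0))

# 1. Instantiate a callable that closes over its parameters
def tiny_model(x):
  return jnp.tanh(x @ params["w"] + params["b"])

# 2. JIT-transform it
tiny_model_jit = jax.jit(tiny_model)

# 3. Specify the input shape from a NumPy value
sample = np.random.randn(1, 8).astype(np.float32)
input_spec = jax.ShapeDtypeStruct(sample.shape, sample.dtype)

# 4. Export to StableHLO
exported = export.export(tiny_model_jit)(input_spec)
print(exported.mlir_module()[:400])
```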

```python
from transformers import FlaxResNetModel
import jax
from jax import export
import numpy as np

# Construct a JIT-transformed Flax model with sample inputs
resnet18 = FlaxResNetModel.from_pretrained("microsoft/resnet-18", return_dict=False)
resnet18_jit = jax.jit(resnet18)
# Use float32 so the exported input type is tensor<1x3x224x224xf32>
sample_input = np.random.randn(1, 3, 224, 224).astype(np.float32)
input_shape = jax.ShapeDtypeStruct(sample_input.shape, sample_input.dtype)

# Export to StableHLO
stablehlo_resnet18_export = export.export(resnet18_jit)(input_shape)
resnet18_stablehlo = get_stablehlo_asm(stablehlo_resnet18_export.mlir_module())
print(resnet18_stablehlo[:600], "\n...\n", resnet18_stablehlo[-345:])
```


### Export with dynamic batch size

Now let's export that same model with a dynamic batch size!

In the first example, we used an input shape of `tensor<1x3x224x224xf32>`, specifying strict constraints on the input shape. If we want to defer the concrete shapes used in compilation until a later point, we can specify a `symbolic_shape`. In this example, we'll export using `tensor<?x3x224x224xf32>`.
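To see the difference in practice, the sketch below exports the same small function twice, once with a fixed batch size and once with a symbolic one. It then calls both exports with a batch of 3. `row_sums` is illustrative. The exact error raised for the mismatched call may vary between JAX versions, so the sketch catches a generic exception.

```python
import jax
import jax.numpy as jnp
import numpy as np
from jax import export

@jax.jit
def row_sums(x):
  return jnp.sum(x, axis=1)

# Static export: batch size fixed to 1
static_spec = jax.ShapeDtypeStruct((1, 4), np.float32)
static_exported = export.export(row_sums)(static_spec)

# Dynamic export: batch size left symbolic as "a"
dyn_spec = jax.ShapeDtypeStruct(export.symbolic_shape("a,4"), np.float32)
dyn_exported = export.export(row_sums)(dyn_spec)

batch = np.ones((3, 4), dtype=np.float32)

# The dynamic export accepts any batch size
dyn_result = dyn_exported.call(batch)
print(dyn_result.shape)

# The static export only accepts the batch size it was traced with
try:
  static_exported.call(batch)
  static_accepted = True
except Exception as err:  # exact exception type depends on the JAX version
  static_accepted = False
  print("static export rejected a batch of 3:", type(err).__name__)
```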

Symbolic shapes are specified using `export.symbolic_shape`, with letters naming symbolic integer dimensions. For example, a valid 2-D matrix multiplication could use the symbolic shapes `2,a` and `a,5`. Because both operands share the symbol `a`, any refined program is guaranteed to have compatible shapes. Symbolic dimension names are tracked by an `export.SymbolicScope`, which avoids unintentional name clashes.
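The sketch below implements the `2,a` by `a,5` matrix multiplication example. Both specs come from one `export.SymbolicScope`, so `a` names the same dimension in each. The `matmul` function is illustrative, and calling the export with concrete arrays refines `a` to 3.

```python
import jax
import jax.numpy as jnp
import numpy as np
from jax import export

@jax.jit
def matmul(x, y):
  return jnp.matmul(x, y)

# Both specs share one scope, so "a" is the same dimension in each
scope = export.SymbolicScope()
lhs_spec = jax.ShapeDtypeStruct(export.symbolic_shape("2,a", scope=scope), np.float32)
rhs_spec = jax.ShapeDtypeStruct(export.symbolic_shape("a,5", scope=scope), np.float32)

exported = export.export(matmul)(lhs_spec, rhs_spec)
print(exported.in_avals)
print(exported.out_avals)  # the contracted dimension "a" does not appear

# Refine "a" to 3 at call time
result = exported.call(np.ones((2, 3), np.float32), np.ones((3, 5), np.float32))
print(result)
```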

```python
# Construct dynamic sample inputs
dyn_scope = export.SymbolicScope()
dyn_input_shape = jax.ShapeDtypeStruct(export.symbolic_shape("a,3,224,224", scope=dyn_scope), np.float32)

# Export to StableHLO
dyn_resnet18_export = export.export(resnet18_jit)(dyn_input_shape)
dyn_resnet18_stablehlo = get_stablehlo_asm(dyn_resnet18_export.mlir_module())
print(dyn_resnet18_stablehlo[:1900], "\n...\n", dyn_resnet18_stablehlo[-1000:])
```

A few things to note in the exported StableHLO:

1. The exported program now has `tensor<?x3x224x224xf32>`. These input types can be refined in many ways: StableHLO has APIs to [refine shapes](https://github.com/openxla/stablehlo/blob/541db997e449dcfee8536043dfdd49bb13f9ed1a/stablehlo/transforms/Passes.td#L69-L99) and [canonicalize dynamic programs](https://github.com/openxla/stablehlo/blob/541db997e449dcfee8536043dfdd49bb13f9ed1a/stablehlo/transforms/Passes.td#L18-L28) to static programs. TensorFlow SavedModel execution also takes care of refinement which we'll see in the next example.
2. JAX generates guards to ensure the values of `a` are valid; in this case it checks that `a >= 1`. These guards can be removed at compile time once the shapes are refined.

## Export to TensorFlow SavedModel

It is common to export a StableHLO model to SavedModel for interoperability with existing compilation pipelines, existing TensorFlow tooling, or serving via [TensorFlow Serving](https://github.com/tensorflow/serving).

JAX makes it easy to pack StableHLO into a SavedModel, and load that SavedModel in the future. For this section, we'll be using our dynamic model from the previous section.

### Export to SavedModel using `jax2tf`

`jax.experimental.jax2tf` provides a simple API for packaging StableHLO into a format that can be stored in a SavedModel. It uses the `export` function under the hood, so the same `jit` requirements apply.

Full details on `jax2tf` can be found in the [README](https://github.com/google/jax/blob/main/jax/experimental/jax2tf/README.md#jax-and-tensorflow-interoperation-jax2tfcall_tf). For this example, we'll only need to know the `polymorphic_shapes` option to specify our dynamic batch dimension.
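As far as we know, recent JAX releases parse the `polymorphic_shapes` strings given to `jax2tf.convert` with the same shape-specification syntax as `export.symbolic_shape`, where the surrounding parentheses are optional. The JAX-only sketch below needs no TensorFlow. It parses the spec used in this section with and without parentheses, in one shared scope, and checks that the results agree.

```python
from jax import export

# One scope, so "a" refers to the same symbolic dimension in both shapes
scope = export.SymbolicScope()

# The spec as written for jax2tf.convert, and the same spec without parentheses
with_parens = export.symbolic_shape("(a,3,224,224)", scope=scope)
without_parens = export.symbolic_shape("a,3,224,224", scope=scope)
print(with_parens, without_parens)

# Leading symbolic batch dimension, then fixed 3x224x224 dimensions
same_shape = (with_parens[0] == without_parens[0]
              and tuple(with_parens[1:]) == tuple(without_parens[1:]) == (3, 224, 224))
```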

```python
from jax.experimental import jax2tf
import tensorflow as tf

exported_f = jax2tf.convert(resnet18, polymorphic_shapes=["(a,3,224,224)"])

# Copied from the jax2tf README.md > Usage: saved model
my_model = tf.Module()
my_model.f = tf.function(exported_f, autograph=False).get_concrete_function(tf.TensorSpec([None, 3, 224, 224], tf.float32))
tf.saved_model.save(my_model, '/tmp/resnet18_tf', options=tf.saved_model.SaveOptions(experimental_custom_gradients=True))

!ls /tmp/resnet18_tf
```

### Reload and call the SavedModel

Now we can load that SavedModel and compile using our `sample_input` from a previous example.

_Note: The restored model does *not* require JAX to run, only TensorFlow, which executes the embedded StableHLO through XLA._

```python
restored_model = tf.saved_model.load('/tmp/resnet18_tf')
restored_result = restored_model.f(tf.constant(sample_input, tf.float32))
print("Result shape:", restored_result[0].shape)
```

## Troubleshooting

### `jax.jit` issues

If a function can be JIT-compiled, it can be exported. Make sure `jax.jit` works first, or look for existing uses of `jax.jit` in the project (for example, [AlphaFold's `apply`](https://github.com/google-deepmind/alphafold/blob/dbe2a438ebfc6289f960292f15dbf421a05e563d/alphafold/model/model.py#L89) can be exported easily).
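A cheap way to check this condition is to lower the function with `jax.jit(...).lower` on abstract inputs before attempting an export, since lowering traces the function without running it. The sketch below uses a hypothetical `can_jit` helper. Its two toy functions show a common failure, a Python `if` on a traced value, and the `jnp.where` rewrite that fixes it.

```python
import jax
import jax.numpy as jnp
import numpy as np

def can_jit(fn, *arg_specs):
  """Hypothetical helper: True if fn traces and lowers under jax.jit."""
  try:
    jax.jit(fn).lower(*arg_specs)
    return True
  except jax.errors.ConcretizationTypeError as err:
    print(f"{fn.__name__}: {type(err).__name__}")
    return False

def relu_python_if(x):
  # Python control flow on a traced value cannot be traced
  if x > 0:
    return x
  return jnp.zeros_like(x)

def relu_where(x):
  # jnp.where keeps the branch inside the traced program
  return jnp.where(x > 0, x, jnp.zeros_like(x))

scalar_f32 = jax.ShapeDtypeStruct((), np.float32)
print(can_jit(relu_python_if, scalar_f32))
print(can_jit(relu_where, scalar_f32))
```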

See [JAX's JIT compilation documentation](https://jax.readthedocs.io/en/latest/jit-compilation.html) and [`jax.jit` API reference and examples](https://jax.readthedocs.io/en/latest/_autosummary/jax.jit.html) for troubleshooting JIT transformations. The most common issue is control flow, which can often be resolved with `static_argnums` / `static_argnames` as in the linked example.
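When a branch depends on a configuration flag rather than on array data, marking the flag static lets `jax.jit` specialize the program for each value. Without it, `jax.jit(activation)(x, True)` would trace `True` and fail at the `if`. The sketch below is illustrative: `activation` is a toy function, and the lowered text is assumed to be StableHLO, as in recent JAX releases.

```python
import jax
import jax.numpy as jnp
import numpy as np

def activation(x, use_relu):
  # `use_relu` is a Python bool, so this `if` needs a concrete value
  if use_relu:
    return jnp.maximum(x, 0.0)
  return jnp.tanh(x)

# Marking `use_relu` static makes jit compile and cache one program per value
activation_jit = jax.jit(activation, static_argnames="use_relu")

x = np.array([-1.0, 0.0, 2.0], dtype=np.float32)
relu_out = activation_jit(x, use_relu=True)
tanh_out = activation_jit(x, use_relu=False)

# The static value is fixed at lowering time, so only the chosen branch is lowered
spec = jax.ShapeDtypeStruct((3,), np.float32)
relu_text = activation_jit.lower(spec, use_relu=True).as_text()
print(relu_out, tanh_out)
```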

### Support tickets

You can open an issue on GitHub for further help. Please include a reproducible example using one of the APIs above; this helps get the issue resolved much faster.