# Example 1: Converting a simple JAX function

This Colab notebook accompanies the [JAX on the Web with TensorFlow.js](https://blog.tensorflow.org/2022/08/JAX-on-the-Web-with-TensorFlow.js.html) blog post.

## Setup code

In [None]:
!pip install tensorflowjs -q

In [None]:
# General imports
import json
import numpy as np
import os
import glob
import string

from IPython.display import display, HTML, Javascript
import google.colab.html
import google.colab.output
import jax
import jax.numpy as jnp
import tensorflow as tf
import tensorflowjs as tfjs

In [None]:
# This is a helper function for running inference on a TensorFlow.js model
# in Colab directly.

_TFJS_SRC_URL = 'https://cdn.jsdelivr.net/npm/@tensorflow/tfjs@3.11.0'

# We have to make HTML resources persistent even if they are not explicitly passed
# to JavaScript. Otherwise, they will get garbage collected.
global_refs = {}

def get_tfjs_predict_fn(model_dir):
  """Return a function that runs the TensorFlow.js model stored in `model_dir`.

  Calling the returned function runs inference on the TF.js model in the
  browser.

  Args:
    model_dir: Directory containing the converted TF.js model (`model.json`
      and its weight shards).
  """
  model_path = os.path.join(model_dir, 'model.json')
  ref = google.colab.html.create_resource(filepath=model_path, route=model_path)

  def add_resource(path):
    with open(path, 'rb') as f:
      return google.colab.html.create_resource(content=f.read(), route=path)

  global global_refs
  global_refs.update({
      p: add_resource(p) for p in glob.glob(os.path.join(model_dir, '*.bin'))})

  def call_tfjs(x):
    """Calls the TFjs model in the browser and returns the output."""
    print(f'NOTE: Running TFJs inference for model in {model_dir}...')
    input_json = json.dumps(jax.tree_util.tree_map(lambda a: a.tolist(), x))
    display(HTML(f'<script src="{_TFJS_SRC_URL}"></script>'))
    display(Javascript(string.Template('''
      async function getOutput() {
        const model = await tf.loadGraphModel('$model_url');
        const x = tf.tensor(JSON.parse('$inputs'));
        let result = model.predict(x);
        console.log(result.shape);
        return [await result.data(), result.shape];
      }
      window.modelOutput = getOutput();
    ''').substitute(dict(model_url=ref.url, inputs=input_json))))

    output_dict, shape = google.colab.output.eval_js('modelOutput')
    return np.array([*output_dict.values()]).reshape(shape).astype(np.float32)
  return call_tfjs
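
The last two lines of `call_tfjs` depend on how a JavaScript `Float32Array` crosses into Python. `JSON.stringify` turns a typed array into an object keyed by index (`{"0": 0.5, "1": 2, ...}`), and the helper assumes `eval_js` delivers `result.data()` in that form. A minimal sketch of that decoding step that sorts by numeric key instead of relying on dictionary order; `decode_typed_array` is a hypothetical name, not part of Colab or TensorFlow.js:

```python
import json

import numpy as np


def decode_typed_array(output_dict, shape):
  """Rebuilds an ndarray from an index-keyed dict like {"0": 0.5, "1": 2}.

  Assumes the keys are the string indices "0", "1", ... that JSON.stringify
  produces for a Float32Array.
  """
  # Sort numerically so that "10" comes after "9", whatever the dict order.
  values = [output_dict[k] for k in sorted(output_dict, key=int)]
  return np.asarray(values, dtype=np.float32).reshape(shape)


# Simulated payload with the same layout as [await result.data(), result.shape].
output_dict, shape = json.loads('[{"0": 0.5, "1": 2, "2": 1, "3": 3}, [2, 2]]')
print(decode_typed_array(output_dict, shape))
```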

## Convert JAX --> TFjs

First, you’ll convert a few simple JAX functions using `tfjs.converters.convert_jax()`.

The following example uses a single parameter `weight` and implements a function `prod`, which multiplies the input by the parameter, broadcasting the shape-`(2,)` `weight` across every row of `xs`. (In a real example, `params` would contain all the weights of the modules used in the neural network.)
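
For a sense of what `params` looks like beyond a single weight, here is a hypothetical two-layer network whose parameters form a nested dict (a JAX pytree). The layer names `dense1`/`dense2`, the `kernel`/`bias` keys, and the `mlp` function are illustrative assumptions. The sketch uses NumPy so it runs anywhere; with `np` swapped for `jnp`, a function with this `(params, xs)` signature could be handed to `tfjs.converters.convert_jax` the same way as `prod`.

```python
import numpy as np


def init_params(rng, in_dim=2, hidden=4, out_dim=1):
  """Hypothetical MLP parameters stored as a nested dict (a pytree)."""
  return {
      'dense1': {'kernel': rng.normal(size=(in_dim, hidden)).astype(np.float32),
                 'bias': np.zeros(hidden, np.float32)},
      'dense2': {'kernel': rng.normal(size=(hidden, out_dim)).astype(np.float32),
                 'bias': np.zeros(out_dim, np.float32)},
  }


def mlp(params, xs):
  """Same (params, xs) calling convention as `prod`."""
  h = np.tanh(xs @ params['dense1']['kernel'] + params['dense1']['bias'])
  return h @ params['dense2']['kernel'] + params['dense2']['bias']


params_mlp = init_params(np.random.default_rng(0))
xs_demo = np.arange(6, dtype=np.float32).reshape((3, 2))
print(mlp(params_mlp, xs_demo).shape)  # (3, 1)
```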

In [None]:
def prod(params, xs):
  return params['weight'] * xs

In [None]:
params = {'weight': jnp.array([0.5, 1])}
xs = np.arange(6).reshape((3, 2))
jax_result = prod(params, xs)
print(jax_result)

[[0. 1.]
 [1. 3.]
 [2. 5.]]


In [None]:
model_dir = 'example1'
tfjs.converters.convert_jax(
    prod,
    params,
    input_signatures=[tf.TensorSpec((3, 2), tf.float32)],
    model_dir=model_dir)

# Verify the outputs have been written.
!ls -l $model_dir



Writing weight file example1/model.json...
total 8
-rw-r--r-- 1 root root    8 Aug 31 09:07 group1-shard1of1.bin
-rw-r--r-- 1 root root 1308 Aug 31 09:07 model.json
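
The 8-byte shard matches the two `float32` values in `weight` (2 × 4 bytes). A TF.js `model.json` describes its shards in a `weightsManifest` list, where each group gives the shard file `paths` and each weight's `name`, `shape`, and `dtype`. The sketch below reads weights back using that manifest, assuming unquantized little-endian `float32` weights (it ignores quantization metadata). `read_tfjs_weights` is a hypothetical helper, and the usage builds its own toy directory; the weight name `'weight'` there is made up and need not match what the converter writes.

```python
import json
import os
import tempfile

import numpy as np


def _read_bytes(path):
  with open(path, 'rb') as f:
    return f.read()


def read_tfjs_weights(model_dir):
  """Returns {weight_name: ndarray} for a TF.js model directory.

  Sketch only: assumes every weight is unquantized little-endian float32.
  """
  with open(os.path.join(model_dir, 'model.json')) as f:
    manifest = json.load(f)['weightsManifest']
  weights = {}
  for group in manifest:
    # The shards of one group are treated as one contiguous byte buffer.
    buf = b''.join(_read_bytes(os.path.join(model_dir, p))
                   for p in group['paths'])
    flat = np.frombuffer(buf, dtype='<f4')
    offset = 0
    for spec in group['weights']:
      if spec['dtype'] != 'float32':
        raise NotImplementedError('this sketch only reads float32 weights')
      size = int(np.prod(spec['shape']))
      weights[spec['name']] = flat[offset:offset + size].reshape(spec['shape'])
      offset += size
  return weights


# Toy directory mirroring the listing: one 8-byte shard holding two floats.
with tempfile.TemporaryDirectory() as toy_dir:
  np.array([0.5, 1.0], dtype='<f4').tofile(
      os.path.join(toy_dir, 'group1-shard1of1.bin'))
  toy_manifest = {'weightsManifest': [{
      'paths': ['group1-shard1of1.bin'],
      'weights': [{'name': 'weight', 'shape': [2], 'dtype': 'float32'}]}]}
  with open(os.path.join(toy_dir, 'model.json'), 'w') as f:
    json.dump(toy_manifest, f)
  print(read_tfjs_weights(toy_dir))
```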


In [None]:
tfjs.converters.convert_jax(
    prod,
    params,
    input_signatures=[tf.TensorSpec((3, 2), tf.float32)],
    model_dir=model_dir)
tfjs_predict_fn = get_tfjs_predict_fn(model_dir)
print(tfjs_predict_fn(xs))  # Same output as JAX.




Writing weight file example1/model.json...
NOTE: Running TFJs inference for model in example1...


<IPython.core.display.Javascript object>

[[0. 1.]
 [1. 3.]
 [2. 5.]]


Run inference in the browser and verify the results match those of JAX.

In [None]:
tfjs_predict_fn = get_tfjs_predict_fn(model_dir)
tfjs_result = tfjs_predict_fn(xs)
np.testing.assert_allclose(tfjs_result, jax_result)  # Tolerant float comparison.
print('TFjs result:', tfjs_result)

NOTE: Running TFJs inference for model in example1...


<IPython.core.display.Javascript object>

TFjs result: [[0. 1.]
 [1. 3.]
 [2. 5.]]


## Supporting Dynamic Shapes

The model we just converted does not support dynamic shapes: it only accepts inputs of shape `(3, 2)`, and running inference with any other shape, such as `(5, 2)`, raises an error.

In [None]:
try:
  tfjs_result = tfjs_predict_fn(np.ones((5, 2)))
except Exception as e:
  print('\nCAUGHT EXCEPTION:\n', e)

NOTE: Running TFJs inference for model in example1...


<IPython.core.display.Javascript object>


CAUGHT EXCEPTION:
 Error: The shape of dict['xs_0'] provided in model.execute(dict) must be [3,2], but was [5,2]


Dynamic shapes are supported as usual in TensorFlow by passing `None` for the dynamic dimensions in `input_signatures`.

You must also pass the `polymorphic_shapes` argument, which gives a name to each dynamic dimension.

Note that [polymorphism](https://github.com/google/jax/tree/main/jax/experimental/jax2tf#shape-polymorphic-conversion) is a term borrowed from type theory; here it means that the function works for a family of related shapes, such as different batch sizes. The dimension names let JAX check shapes inside the function, for example that an operation like `x + y` is valid for every possible batch size (see [here](https://github.com/google/jax/blob/main/jax/experimental/jax2tf/README.md#shape-polymorphic-conversion) for more documentation on this notation).
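
To make the two arguments concrete, here is a toy model (not jax2tf's actual implementation) of what each one expresses: a `None` in a `TensorSpec` shape accepts any size in that position, and a name in a `polymorphic_shapes` string such as `'(b, 2)'` stands for one unknown size that must be the same wherever the name appears. `matches_signature`, `parse_poly_shape`, and `bind_dims` are hypothetical helpers written only for this illustration.

```python
def matches_signature(sig_shape, shape):
  """True if `shape` fits a TensorSpec-style shape, where None is a wildcard."""
  return len(sig_shape) == len(shape) and all(
      s is None or s == n for s, n in zip(sig_shape, shape))


def parse_poly_shape(spec):
  """Parses a string like '(b, 2)' into ('b', 2)."""
  parts = [p.strip() for p in spec.strip().strip('()').split(',')]
  return tuple(int(p) if p.isdigit() else p for p in parts if p)


def bind_dims(poly_specs, shapes):
  """Binds every dimension name to one concrete size across all arguments."""
  env = {}
  for spec, shape in zip(poly_specs, shapes):
    dims = parse_poly_shape(spec)
    if len(dims) != len(shape):
      raise ValueError(f'rank mismatch: {spec} vs {shape}')
    for dim, size in zip(dims, shape):
      if isinstance(dim, int):
        if dim != size:
          raise ValueError(f'{spec}: expected {dim}, got {size}')
      elif env.setdefault(dim, size) != size:
        raise ValueError(f'{dim} bound to both {env[dim]} and {size}')
  return env


print(matches_signature((3, 2), (5, 2)))     # False: the static spec rejects (5, 2)
print(matches_signature((None, 2), (5, 2)))  # True
print(bind_dims(['(b, 2)', '(b, 2)'], [(5, 2), (5, 2)]))  # {'b': 5}
```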

In [None]:
model_dir = 'example2'

tfjs.converters.convert_jax(
    prod,
    params,
    input_signatures=[tf.TensorSpec((None, 2), tf.float32)],
    polymorphic_shapes=['(b, 2)'],
    model_dir=model_dir)

tfjs_result = get_tfjs_predict_fn(model_dir)(np.arange(10).reshape((5, 2)))
print('TFjs result:', tfjs_result)



Writing weight file example2/model.json...
NOTE: Running TFJs inference for model in example2...


<IPython.core.display.Javascript object>

TFjs result: [[0. 1.]
 [1. 3.]
 [2. 5.]
 [3. 7.]
 [4. 9.]]


In [None]:
get_tfjs_predict_fn(model_dir)(np.array([[1., 2.]]))

NOTE: Running TFJs inference for model in example2...


<IPython.core.display.Javascript object>

array([[0.5, 2. ]], dtype=float32)

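## Multiple Arguments and Shape Polymorphism

Below, a simple example shows how to convert a function with multiple arguments that have polymorphic shapes. Both arguments use the same dimension name `b`, because `x + y` only works when their first dimensions match: calling the JAX function with inputs whose first dimensions have different sizes (for example 3 and 5) raises a shape error.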
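
The shape error described above is ordinary broadcasting. A quick check using a NumPy stand-in for `prod_of_sum` (`prod_of_sum_np` and `params_np` are hypothetical names for this sketch): NumPy raises a `ValueError` when the first dimensions are 3 and 5, while JAX reports the same problem with its own `add got incompatible shapes for broadcasting` message. Note that a size-1 dimension still broadcasts.

```python
import numpy as np


def prod_of_sum_np(params, x, y):
  """NumPy stand-in for `prod_of_sum`, used only to show the shape rule."""
  return params['weight'] * (x + y)


params_np = {'weight': np.array([0.5, 1.0])}
same = prod_of_sum_np(params_np, np.ones((3, 2)), np.ones((3, 2)))
print(same.shape)  # (3, 2)

try:
  prod_of_sum_np(params_np, np.ones((3, 2)), np.ones((5, 2)))
except ValueError as e:
  print('shape error:', e)
```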

In [None]:
def prod_of_sum(params, x, y):
  return params['weight'] * (x + y)

jax_result = prod_of_sum(params, xs, xs)
model_dir = 'example3'

tfjs.converters.convert_jax(
    prod_of_sum,
    params,
    input_signatures=[tf.TensorSpec((None, 2), tf.float32),
                      tf.TensorSpec((None, 2), tf.float32)],
    polymorphic_shapes=['(b, 2)', '(b, 2)'],
    model_dir=model_dir)



Writing weight file example3/model.json...


The summation `x + y` inside the function `prod_of_sum` requires the first dimensions of both arrays to be equal. So if we give the first dimensions different names in `polymorphic_shapes` (`b` and `d`), conversion fails with a shape error. This is very helpful, since it catches such errors at conversion time rather than when the model runs.
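
Why does `'(b, 2)'` with `'(d, 2)'` fail? At conversion time the sizes are symbols, so the converter must decide whether `x + y` is valid for every possible value of them. A toy version of that decision, assuming each dimension is either an int or a symbol name: a pair of dimensions is accepted only if it is provably compatible, meaning equal or one of them is the constant 1. Two distinct names cannot be proven equal, so the pair is rejected. `symbolic_broadcast` is hypothetical and far simpler than JAX's actual rules.

```python
def symbolic_broadcast(shape_a, shape_b):
  """Toy broadcast of two same-rank shapes whose dims are ints or names.

  A pair of dims is accepted only when it is compatible for all symbol
  values: identical (same int or same name), or one side is the constant 1.
  """
  if len(shape_a) != len(shape_b):
    raise ValueError('this toy version handles equal ranks only')
  out = []
  for a, b in zip(shape_a, shape_b):
    if a == b:
      out.append(a)
    elif a == 1:
      out.append(b)
    elif b == 1:
      out.append(a)
    else:
      raise ValueError(f'incompatible shapes: {shape_a}, {shape_b}')
  return tuple(out)


print(symbolic_broadcast(('b', 2), ('b', 2)))  # ('b', 2)
try:
  symbolic_broadcast(('b', 2), ('d', 2))
except ValueError as e:
  print('rejected:', e)
```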

In [None]:
model_dir = './example4'

try:
  tfjs.converters.convert_jax(
      prod_of_sum,
      params,
      input_signatures=[tf.TensorSpec((None, 2), tf.float32),
                        tf.TensorSpec((None, 2), tf.float32)],
      polymorphic_shapes=['(b, 2)', '(d, 2)'],
      model_dir=model_dir)
except Exception as e:
  print('CAUGHT EXCEPTION:\n', e)

CAUGHT EXCEPTION:
 add got incompatible shapes for broadcasting: (b, 2), (d, 2).
