# Train JAX/Flax model on Vertex AI pre-built container and use `jax2tf` to convert to SavedModel

Runs on a TF2.5 [Vertex Notebook](https://cloud.google.com/vertex-ai/docs/general/notebooks).

In [2]:
import os

In [3]:
# IPython shell capture yields the command's output lines; the project id is
# taken from the first line.
PROJECT_ID = !(gcloud config get-value core/project)
PROJECT_ID = PROJECT_ID[0]

# Each setting below is also exported through os.environ so that %%bash cells
# can read it as a shell variable.
REGION = "us-central1"
os.environ['REGION'] = REGION

# Use a regional bucket in the above region you have rights to.
# Create if needed:
# !gsutil mb -l ${REGION} gs://${BUCKET_NAME}
BUCKET_NAME = "dsparing-sandbox-bucket"
os.environ['BUCKET_NAME'] = BUCKET_NAME

# Local folder that holds the training package (setup.py + trainer/).
TRAINING_APP_FOLDER = 'training_app'
os.environ['TRAINING_APP_FOLDER'] = TRAINING_APP_FOLDER

# NOTE(review): MODEL_DIR is not referenced by the cells visible in this
# notebook (the job writes under gs://{BUCKET_NAME}/jax/output) -- confirm
# whether it is still needed.
MODEL_NAME = "jax_model"
MODEL_DIR = f"gs://{BUCKET_NAME}/models/{MODEL_NAME}"
os.environ['MODEL_DIR'] = MODEL_DIR

In [4]:
%%bash
# Create the `trainer` package skeleton; the empty __init__.py marks the
# directory as a Python package (setup.py discovers it via find_packages()).
mkdir -p $TRAINING_APP_FOLDER/trainer
touch $TRAINING_APP_FOLDER/trainer/__init__.py

There is no pre-built Vertex TF2.5 container yet, and `jax.experimental.jax2tf.examples.saved_model_lib.convert_and_save_model` passes the `jit_compile` argument to `tf.function`, which is the new TF2.5 name for `experimental_compile`. So we define a copy of that function whose only difference is that the argument is still called `experimental_compile`, as TF2.4 expects.

In [5]:
%%writefile {TRAINING_APP_FOLDER}/trainer/saved_model_lib.py

from typing import Any, Callable, Sequence, Optional, Union

from jax.experimental import jax2tf  # type: ignore[import]
import tensorflow as tf  # type: ignore[import]


def convert_and_save_model(
    jax_fn: Callable[[Any, Any], Any],
    params,
    model_dir: str,
    *,
    input_signatures: Sequence[tf.TensorSpec],
    polymorphic_shapes: Optional[Union[str, jax2tf.PolyShape]] = None,
    with_gradient: bool = False,
    enable_xla: bool = True,
    compile_model: bool = True,
    saved_model_options: Optional[tf.saved_model.SaveOptions] = None) -> None:
  """Convert a JAX function to TensorFlow and save it as a SavedModel.

  This is an example; for serious uses you will likely want to copy and
  expand it as needed (see the note at the top of the original jax2tf example
  module, `jax.experimental.jax2tf.examples.saved_model_lib`).

  Use this function if you have a trained ML model that has both a prediction
  function and trained parameters, which you want to save separately from the
  function graph as variables (e.g., to avoid limits on the size of the
  GraphDef, or to enable fine-tuning). If you don't have such parameters,
  you can still use this library function but probably don't need it
  (see jax2tf/README.md for some simple examples).

  In order to use this wrapper you must first convert your model to a function
  with two arguments: the parameters and the input on which you want to do
  inference. Both arguments may be np.ndarray or (nested)
  tuples/lists/dictionaries thereof.

  See the README.md for a discussion of how to prepare Flax and Haiku models.

  Args:
    jax_fn: a JAX function taking two arguments, the parameters and the inputs.
      Both arguments may be (nested) tuples/lists/dictionaries of np.ndarray.
    params: the parameters, to be used as first argument for `jax_fn`. These
      must be (nested) tuples/lists/dictionaries of np.ndarray, and will be
      saved as the variables of the SavedModel.
    model_dir: the directory where the model should be saved.
    input_signatures: the input signatures for the second argument of `jax_fn`
      (the input). A signature must be a `tensorflow.TensorSpec` instance, or a
      (nested) tuple/list/dictionary thereof with a structure matching the
      second argument of `jax_fn`. The first input_signature will be saved as
      the default serving signature. The additional signatures will be used
      only to ensure that the `jax_fn` is traced and converted to TF for the
      corresponding input shapes.
    polymorphic_shapes: if given then it will be used as the
      `polymorphic_shapes` argument to jax2tf.convert for the second parameter
      of `jax_fn`. In this case, a single `input_signatures` is supported, and
      should have `None` in the polymorphic dimensions.
    with_gradient: whether the SavedModel should support gradients. If True,
      then a custom gradient is saved. If False, then a
      tf.raw_ops.PreventGradient is saved to error if a gradient is attempted.
      (The original text adds: "At the moment due to a bug in SavedModel,
      custom gradients are not supported." NOTE(review): this may be stale,
      since the code below does request custom gradients through SaveOptions
      when `with_gradient` is True -- confirm.)
    enable_xla: whether the jax2tf converter is allowed to use TFXLA ops. If
      False, the conversion tries harder to use purely TF ops and raises an
      exception if it is not possible. (default: True)
    compile_model: forwarded to `tf.function` as `experimental_compile`, i.e.
      whether to use the TensorFlow JIT compiler on the SavedModel function.
      This is needed if the SavedModel will be used for TensorFlow serving.
    saved_model_options: options to pass to `tf.saved_model.save`. When
      `with_gradient` is True and an options object is supplied, its
      `experimental_custom_gradients` attribute is set to True in place.

  Raises:
    ValueError: if `input_signatures` is empty, or if `polymorphic_shapes` is
      given together with more than one input signature.
  """
  if not input_signatures:
    raise ValueError("At least one input_signature must be given")
  if polymorphic_shapes is not None:
    if len(input_signatures) > 1:
      raise ValueError("For shape-polymorphic conversion a single "
                       "input_signature is supported.")
  # The first entry of `polymorphic_shapes` corresponds to the parameters
  # argument (no polymorphism), the second to the inputs argument.
  tf_fn = jax2tf.convert(
    jax_fn,
    with_gradient=with_gradient,
    polymorphic_shapes=[None, polymorphic_shapes],
    enable_xla=enable_xla)

  # Create tf.Variables for the parameters. If you want more useful variable
  # names, you can use `tree.map_structure_with_path` from the `dm-tree` package
  param_vars = tf.nest.map_structure(
    # The variables are trainable only when `with_gradient` is requested;
    # otherwise they are marked non-trainable so that users of the SavedModel
    # will not try to fine tune them. (The original comment attributes this to
    # a SavedModel bug that prevents using tf.GradientTape on a function
    # converted with jax2tf and loaded from SavedModel.)
    lambda param: tf.Variable(param, trainable=with_gradient),
    params)
  # The lambda closes over `param_vars`, so the saved function takes only the
  # inputs. `experimental_compile` is the keyword used here in place of
  # `jit_compile`.
  tf_graph = tf.function(lambda inputs: tf_fn(param_vars, inputs),
                         autograph=False,
                         experimental_compile=compile_model)

  signatures = {}
  # This signature is needed for TensorFlow Serving use.
  signatures[tf.saved_model.DEFAULT_SERVING_SIGNATURE_DEF_KEY] = \
    tf_graph.get_concrete_function(input_signatures[0])
  for input_signature in input_signatures[1:]:
    # If there are more signatures, trace and cache a TF function for each one
    tf_graph.get_concrete_function(input_signature)
  wrapper = _ReusableSavedModelWrapper(tf_graph, param_vars)
  if with_gradient:
    # Ask SavedModel to serialize the custom gradient; a caller-supplied
    # options object is modified in place rather than copied.
    if not saved_model_options:
      saved_model_options = tf.saved_model.SaveOptions(experimental_custom_gradients=True)
    else:
      saved_model_options.experimental_custom_gradients = True
  tf.saved_model.save(wrapper, model_dir, signatures=signatures,
                      options=saved_model_options)


class _ReusableSavedModelWrapper(tf.train.Checkpoint):
  """Object passed to `tf.saved_model.save`, holding a function and its variables.

  Exposes the attributes of the interface described at
  https://www.tensorflow.org/hub/reusable_saved_models.
  """

  def __init__(self, tf_graph, param_vars):
    """Attach the function and its variables under the interface's names.

    Args:
      tf_graph: a tf.function taking one argument (the inputs), which can be
        tuples/lists/dictionaries of np.ndarray or tensors. The function may
        have references to the tf.Variables in `param_vars`.
      param_vars: the parameters, as tuples/lists/dictionaries of tf.Variable,
        to be saved as the variables of the SavedModel.
    """
    super().__init__()
    # Attributes required by https://www.tensorflow.org/hub/reusable_saved_models
    flat_vars = tf.nest.flatten(param_vars)
    self.variables = flat_vars
    self.trainable_variables = [var for var in flat_vars if var.trainable]
    # Regularization terms prescribed for users of the model would go here, as
    # @tf.functions with no inputs. None are defined, so the list is empty.
    self.regularization_losses = []
    self.__call__ = tf_graph

Overwriting training_app/trainer/saved_model_lib.py


There is no pre-built Vertex TF2.5 container yet, and jax2tf in jax==0.2.14 uses `tensorflow.compiler.tf2xla.conv(... preferred_element_type)`, which exists only in TF2.5 and not in TF2.4. We can avoid this, though, by calling `convert_and_save_model(..., enable_xla=False)`.

In [6]:
%%writefile {TRAINING_APP_FOLDER}/trainer/task.py
import argparse
import logging
import os
import tensorflow as tf
import tensorflow_datasets as tfds

from absl import flags
from jax.experimental.jax2tf.examples.mnist_lib import load_mnist, FlaxMNIST
# from jax.experimental.jax2tf.examples.saved_model_lib import convert_and_save_model
from trainer.saved_model_lib import convert_and_save_model

logger = logging.getLogger()
logger.setLevel(logging.INFO)

if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--output_dir",
        help="GCS location to export SavedModel",
        default=os.getenv("AIP_MODEL_DIR"),
    )
    args = vars(parser.parse_args())

    # Fail fast: without this check a missing destination is only discovered
    # after training has finished, when the export receives model_dir=None.
    # (AIP_MODEL_DIR is presumably provided by the Vertex AI training service
    # when the job defines a base output directory -- confirm.)
    if not args["output_dir"]:
        parser.error(
            "--output_dir must be given when AIP_MODEL_DIR is not set")

    # need to initialize flags somehow to avoid errors in load_mnist
    flags.FLAGS(['e'])

    train_batch_size = 128
    test_batch_size = 16

    flax_mnist = FlaxMNIST()

    train_ds = load_mnist(tfds.Split.TRAIN, train_batch_size)
    test_ds = load_mnist(tfds.Split.TEST, test_batch_size)

    predict_fn, params = flax_mnist.train(
        train_ds=train_ds,
        test_ds=test_ds,
        num_epochs=2
    )

    # Derive the serving signature from one training image, re-adding a
    # leading axis of size 1 so the signature describes a batch of one example.
    image, _ = next(iter(train_ds))
    input_signature = tf.TensorSpec.from_tensor(tf.expand_dims(image[0], axis=0))

    # enable_xla=False: per the notebook text, the XLA conversion path needs a
    # TF2.5-only op argument, while the training container runs TF2.4.
    convert_and_save_model(
        jax_fn=predict_fn,
        params=params,
        model_dir=args["output_dir"],
        input_signatures=[input_signature],
        enable_xla=False,
    )

Overwriting training_app/trainer/task.py


Optional local test:

We should be able to use [CustomTrainingJob](https://googleapis.dev/python/aiplatform/latest/aiplatform.html#google.cloud.aiplatform.CustomTrainingJob) or [CustomPythonPackageTrainingJob](https://googleapis.dev/python/aiplatform/latest/aiplatform.html#google.cloud.aiplatform.CustomPythonPackageTrainingJob), but doing so currently gives an error. Similarly, [CustomTrainingJob.run](https://googleapis.dev/python/aiplatform/latest/aiplatform.html#google.cloud.aiplatform.CustomTrainingJob.run) currently fails even when using the [official notebook](https://github.com/GoogleCloudPlatform/ai-platform-samples/blob/master/ai-platform-unified/notebooks/official/custom/sdk-custom-image-classification-online.ipynb). We therefore package the trainer ourselves and submit the job through the lower-level `JobServiceClient` below.

In [7]:
%%writefile {TRAINING_APP_FOLDER}/setup.py
from setuptools import find_packages
from setuptools import setup

# jax is pinned to the release this notebook was written against (0.2.14) so
# that it stays compatible with the pinned jaxlib 0.1.67 wheel installed
# alongside this package by the training job. Keep the two in sync when
# upgrading. flax is left unpinned; pip picks a release compatible with this
# jax version.
REQUIRED_PACKAGES = ['flax', 'jax==0.2.14']

setup(
    name='jax_flax_trainer',
    version='0.1',
    install_requires=REQUIRED_PACKAGES,
    packages=find_packages(),
    include_package_data=True,
    description='JAX/FLAX model training application.',
)


Overwriting training_app/setup.py


In [8]:
%%bash
# Build a gzipped source distribution of the trainer package; the archive is
# written to $TRAINING_APP_FOLDER/dist/.
cd $TRAINING_APP_FOLDER
python ./setup.py sdist --formats=gztar
# NOTE(review): this `cd ..` has no lasting effect, because the shell started
# for a %%bash cell exits when the cell finishes.
cd ..

running sdist
running egg_info
writing jax_flax_trainer.egg-info/PKG-INFO
writing dependency_links to jax_flax_trainer.egg-info/dependency_links.txt
writing requirements to jax_flax_trainer.egg-info/requires.txt
writing top-level names to jax_flax_trainer.egg-info/top_level.txt
reading manifest file 'jax_flax_trainer.egg-info/SOURCES.txt'
writing manifest file 'jax_flax_trainer.egg-info/SOURCES.txt'
running check
creating jax_flax_trainer-0.1
creating jax_flax_trainer-0.1/jax_flax_trainer.egg-info
creating jax_flax_trainer-0.1/trainer
copying files to jax_flax_trainer-0.1...
copying setup.py -> jax_flax_trainer-0.1
copying jax_flax_trainer.egg-info/PKG-INFO -> jax_flax_trainer-0.1/jax_flax_trainer.egg-info
copying jax_flax_trainer.egg-info/SOURCES.txt -> jax_flax_trainer-0.1/jax_flax_trainer.egg-info
copying jax_flax_trainer.egg-info/dependency_links.txt -> jax_flax_trainer-0.1/jax_flax_trainer.egg-info
copying jax_flax_trainer.egg-info/requires.txt -> jax_flax_trainer-0.1/jax_flax_tra






In [9]:
%%bash
# Upload the sdist to the bucket location that the training job lists in its
# package URIs (gs://$BUCKET_NAME/jax/).
gsutil -q cp $TRAINING_APP_FOLDER/dist/jax_flax_trainer-0.1.tar.gz gs://$BUCKET_NAME/jax/

In [10]:
import time

from google.cloud import aiplatform
from google.cloud import aiplatform_v1

JOB_NAME = "jax_job"

# Python packages installed into the training container: the trainer sdist
# uploaded above plus a CUDA 11.0 build of jaxlib.
JAXLIB_URI = "gs://jax-releases/cuda110/jaxlib-0.1.67+cuda110-cp37-none-manylinux2010_x86_64.whl"
PYTHON_PACKAGE_URIS = [f"gs://{BUCKET_NAME}/jax/jax_flax_trainer-0.1.tar.gz", JAXLIB_URI]

# Vertex AI machine to use for training, and the module to run on it.
MACHINE_TYPE = "n1-standard-4"
REPLICA_COUNT = 1
PYTHON_MODULE = "trainer.task"

# Pick the TF2.4 executor image and accelerator settings for GPU or CPU runs.
USE_GPU = True
TRAINING_IMAGE = (
    'us-docker.pkg.dev/vertex-ai/training/tf-gpu.2-4:latest'
    if USE_GPU
    else 'us-docker.pkg.dev/vertex-ai/training/tf-cpu.2-4:latest'
)
ACCELERATOR_TYPE = "NVIDIA_TESLA_T4" if USE_GPU else None
ACCELERATOR_COUNT = 1 if USE_GPU else None

# The AI Platform services require regional API endpoints.
api_endpoint: str = f"{REGION}-aiplatform.googleapis.com"
client_options = {"api_endpoint": api_endpoint}

# One client is enough: it can be reused for every request in this notebook.
client = aiplatform.gapic.JobServiceClient(client_options=client_options)

# Assemble the job description from the inside out.
machine_spec = {
    "machine_type": MACHINE_TYPE,
    "accelerator_type": ACCELERATOR_TYPE,
    "accelerator_count": ACCELERATOR_COUNT,
}
python_package_spec = {
    "executor_image_uri": TRAINING_IMAGE,
    "package_uris": PYTHON_PACKAGE_URIS,
    "python_module": PYTHON_MODULE,
}
worker_pool_spec = {
    "machine_spec": machine_spec,
    "replica_count": REPLICA_COUNT,
    "python_package_spec": python_package_spec,
}
custom_job = {
    "display_name": JOB_NAME,
    "job_spec": {
        "worker_pool_specs": [worker_pool_spec],
        "base_output_directory": {
            "output_uri_prefix": f"gs://{BUCKET_NAME}/jax/output",
        },
    },
}

# Submit the job under this project and region.
parent = f"projects/{PROJECT_ID}/locations/{REGION}"
response = client.create_custom_job(parent=parent, custom_job=custom_job)
print("response:", response)


response: name: "projects/654544512569/locations/us-central1/customJobs/587618596302094336"
display_name: "jax_job"
job_spec {
  worker_pool_specs {
    machine_spec {
      machine_type: "n1-standard-4"
      accelerator_type: NVIDIA_TESLA_T4
      accelerator_count: 1
    }
    replica_count: 1
    disk_spec {
      boot_disk_type: "pd-ssd"
      boot_disk_size_gb: 100
    }
    python_package_spec {
      executor_image_uri: "us-docker.pkg.dev/vertex-ai/training/tf-gpu.2-4:latest"
      package_uris: "gs://dsparing-sandbox-bucket/jax/jax_flax_trainer-0.1.tar.gz"
      package_uris: "gs://jax-releases/cuda110/jaxlib-0.1.67+cuda110-cp37-none-manylinux2010_x86_64.whl"
      python_module: "trainer.task"
    }
  }
  base_output_directory {
    output_uri_prefix: "gs://dsparing-sandbox-bucket/jax/output"
  }
}
state: JOB_STATE_PENDING
create_time {
  seconds: 1624266105
  nanos: 296884000
}
update_time {
  seconds: 1624266105
  nanos: 296884000
}



In [11]:
# States in which we keep waiting; any other state ends the polling.
ACTIVE_JOB_STATES = (
    aiplatform_v1.JobState.JOB_STATE_QUEUED,
    aiplatform_v1.JobState.JOB_STATE_PENDING,
    aiplatform_v1.JobState.JOB_STATE_RUNNING,
)

# Poll the job every 30 seconds, printing each observed state, including the
# last one.
job_state = client.get_custom_job(name=response.name).state
print(job_state)
while job_state in ACTIVE_JOB_STATES:
    time.sleep(30)
    job_state = client.get_custom_job(name=response.name).state
    print(job_state)

JobState.JOB_STATE_PENDING
JobState.JOB_STATE_PENDING
JobState.JOB_STATE_PENDING
JobState.JOB_STATE_PENDING
JobState.JOB_STATE_PENDING
JobState.JOB_STATE_PENDING
JobState.JOB_STATE_PENDING
JobState.JOB_STATE_PENDING
JobState.JOB_STATE_PENDING
JobState.JOB_STATE_PENDING
JobState.JOB_STATE_PENDING
JobState.JOB_STATE_PENDING
JobState.JOB_STATE_PENDING
JobState.JOB_STATE_PENDING
JobState.JOB_STATE_PENDING
JobState.JOB_STATE_PENDING
JobState.JOB_STATE_PENDING
JobState.JOB_STATE_PENDING
JobState.JOB_STATE_PENDING
JobState.JOB_STATE_PENDING
JobState.JOB_STATE_PENDING
JobState.JOB_STATE_PENDING
JobState.JOB_STATE_RUNNING
JobState.JOB_STATE_RUNNING
JobState.JOB_STATE_SUCCEEDED


### Make sure we can actually predict with the SavedModel (optional)

In [None]:
%%bash
# Install jax in the notebook environment (a later cell imports load_mnist from
# jax.experimental.jax2tf.examples.mnist_lib), with the same CUDA 11.0 jaxlib
# build (0.1.67) that the training job installs.
pip3 install --user --upgrade jax jaxlib==0.1.67+cuda110 -f https://storage.googleapis.com/jax-releases/jax_releases.html

In [13]:
%%bash
# List the exported SavedModel files. The `model` folder sits under the job's
# base output directory (gs://$BUCKET_NAME/jax/output); presumably this is the
# AIP_MODEL_DIR the trainer wrote to -- confirm.
gsutil ls -l gs://$BUCKET_NAME/jax/output/model

         0  2021-06-21T06:06:58Z  gs://dsparing-sandbox-bucket/jax/output/model/
     44207  2021-06-21T09:13:07Z  gs://dsparing-sandbox-bucket/jax/output/model/saved_model.pb
                                 gs://dsparing-sandbox-bucket/jax/output/model/assets/
                                 gs://dsparing-sandbox-bucket/jax/output/model/variables/
TOTAL: 2 objects, 44207 bytes (43.17 KiB)


In [14]:
from jax.experimental.jax2tf.examples.mnist_lib import load_mnist
import tensorflow as tf
import tensorflow_datasets as tfds

from absl import flags

# need to initialize flags somehow to avoid errors in load_mnist; a dummy
# argv holding only a program name is enough
flags.FLAGS(['e'])

# Take the first batch of the test split with batch size 1: keep the image
# and discard the label.
single_example_ds = load_mnist(tfds.Split.TEST, 1)
image_to_predict, _ = next(iter(single_example_ds))

INFO:absl:Load dataset info from /home/jupyter/tensorflow_datasets/mnist/3.0.1
INFO:absl:Field info.citation from disk and from code do not match. Keeping the one from code.
INFO:absl:Field info.splits from disk and from code do not match. Keeping the one from code.
INFO:absl:Field info.module_name from disk and from code do not match. Keeping the one from code.
INFO:absl:Reusing dataset mnist (/home/jupyter/tensorflow_datasets/mnist/3.0.1)
INFO:absl:Constructing tf.data.Dataset mnist for split test, from /home/jupyter/tensorflow_datasets/mnist/3.0.1


In [15]:
# Load the exported SavedModel directly from GCS and run its default serving
# signature on the single test image. The final expression is left bare so the
# notebook displays the prediction.
saved_model_uri = f"gs://{BUCKET_NAME}/jax/output/model"
loaded_model = tf.saved_model.load(saved_model_uri)
serving_fn = loaded_model.signatures["serving_default"]
serving_fn(image_to_predict)

{'output_0': <tf.Tensor: shape=(1, 10), dtype=float32, numpy=
 array([[ -5.494674  , -12.381934  , -11.991706  ,  -8.824324  ,
          -7.54799   ,  -0.01416086,  -9.895304  ,  -8.375378  ,
          -4.7775846 ,  -7.468926  ]], dtype=float32)>}