# Train a JAX/Flax model on a Vertex AI pre-built container and use `jax2tf` to convert it to a SavedModel

This notebook runs on a TF2.5 [Vertex Notebook](https://cloud.google.com/vertex-ai/docs/general/notebooks).

In [1]:
import os
import time
from absl import flags

import tensorflow as tf
import tensorflow_datasets as tfds
from google.cloud import aiplatform
from google.cloud import aiplatform_v1
from jax.experimental.jax2tf.examples.mnist_lib import load_mnist

In [2]:
PROJECT_ID = !(gcloud config get-value core/project)
PROJECT_ID = PROJECT_ID[0]

REGION = "us-central1"
os.environ['REGION'] = REGION

BUCKET_NAME = "dsparing-sandbox-bucket"
os.environ['BUCKET_NAME'] = BUCKET_NAME
# Use a regional bucket in REGION that you have write access to.
# Create it if needed:
# !gsutil mb -l ${REGION} gs://${BUCKET_NAME}

TRAINING_APP_FOLDER = 'training_app'
os.environ['TRAINING_APP_FOLDER'] = TRAINING_APP_FOLDER

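MODEL_NAME = "jax_model"
MODEL_DIR = f"gs://{BUCKET_NAME}/models/{MODEL_NAME}"
os.environ['MODEL_DIR'] = MODEL_DIR
# Note: the training job below does not use MODEL_DIR. Vertex AI exports the
# model to AIP_MODEL_DIR, which it sets to <base_output_directory>/model.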

In [3]:
%%bash
mkdir -p $TRAINING_APP_FOLDER/trainer
touch $TRAINING_APP_FOLDER/trainer/__init__.py

There is no pre-built Vertex AI training container for TF2.5 yet, so the job below runs on the TF2.4 image. However, `jax.experimental.jax2tf.examples.saved_model_lib.convert_and_save_model` passes `jit_compile` to `tf.function`, and `jit_compile` is the TF2.5 name for the argument that TF2.4 still calls `experimental_compile`. We therefore use `./saved_model_lib_tf24.py`, a copy of that module whose only change is that it passes `experimental_compile` instead of `jit_compile`.
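If the copy ever needs to be regenerated (for example after upgrading jax), a plain text substitution over the upstream source is one option. The sketch below is hypothetical: `patch_for_tf24` is not part of jax, and it assumes the keyword only appears in keyword-argument form (`jit_compile=...`). We have not checked that this is how `saved_model_lib_tf24.py` was actually produced.

```python
import re


def patch_for_tf24(source: str) -> str:
    """Rename tf.function's `jit_compile=` keyword to its TF2.4 spelling.

    Hypothetical helper. Assumes `jit_compile` is only ever used as a
    keyword argument in the upstream module's source.
    """
    # \b avoids matching names such as `experimental_jit_compile`, and the
    # lookahead only rewrites occurrences followed by `=` (keyword usage).
    return re.sub(r"\bjit_compile(?=\s*=)", "experimental_compile", source)
```

For example, read the file at `saved_model_lib.__file__` (with `saved_model_lib` imported from `jax.experimental.jax2tf.examples`), pass its text through `patch_for_tf24`, and write the result to `saved_model_lib_tf24.py`.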

In [4]:
%%bash
cp saved_model_lib_tf24.py $TRAINING_APP_FOLDER/trainer/saved_model_lib.py

The TF2.4 container causes a second incompatibility: jax2tf in jax==0.2.14 calls `tensorflow.compiler.tf2xla.conv(... preferred_element_type)`, and the `preferred_element_type` argument exists only in TF2.5, not in TF2.4. We avoid this code path by calling `convert_and_save_model( ... enable_xla=False)`.
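Instead of hard-coding `enable_xla=False`, the trainer could choose the value from the TensorFlow version it actually runs on, so the same code keeps working once a TF2.5 container is available. A minimal sketch, assuming (as stated above) that TF2.5 is the first version with the needed `preferred_element_type` argument; the helper name is hypothetical.

```python
def tf_has_xla_conv_preferred_element_type(tf_version: str) -> bool:
    """Return True if this TF version should support jax2tf's XLA conv call.

    Hypothetical helper. Assumes the `preferred_element_type` argument of
    the tf2xla conv op first appeared in TF 2.5.
    """
    # Only major and minor matter; this also tolerates suffixes like "2.5.0-rc1".
    major, minor = (int(part) for part in tf_version.split(".")[:2])
    return (major, minor) >= (2, 5)
```

In `task.py` this would become `enable_xla=tf_has_xla_conv_preferred_element_type(tf.__version__)`.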

In [5]:
%%writefile {TRAINING_APP_FOLDER}/trainer/task.py
import argparse
import logging
import os

import tensorflow as tf
import tensorflow_datasets as tfds
from absl import flags
from jax.experimental.jax2tf.examples.mnist_lib import (
    load_mnist, FlaxMNIST
)
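# Upstream module: its tf.function call uses `jit_compile`, which needs TF2.5.
# from jax.experimental.jax2tf.examples.saved_model_lib import (
#     convert_and_save_model
# )

# TF2.4-compatible copy that passes `experimental_compile` instead.
from trainer.saved_model_lib import convert_and_save_model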

logger = logging.getLogger()
logger.setLevel(logging.INFO)

if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--output_dir",
        help="GCS location to export SavedModel",
        default=os.getenv("AIP_MODEL_DIR")
    )
    args = vars(parser.parse_args())

    # Parse absl flags with a dummy program name: load_mnist reads an absl
    # flag, and accessing absl flags before they are parsed raises an error.
    flags.FLAGS(['e'])

    train_batch_size = 128
    test_batch_size = 16

    flax_mnist = FlaxMNIST()

    train_ds = load_mnist(tfds.Split.TRAIN, train_batch_size)
    test_ds = load_mnist(tfds.Split.TEST, test_batch_size)

    predict_fn, params = flax_mnist.train(
        train_ds=train_ds,
        test_ds=test_ds,
        num_epochs=2
    )

    image, _ = next(iter(train_ds))
    input_signature = tf.TensorSpec.from_tensor(
        tf.expand_dims(image[0], axis=0)
    )

    convert_and_save_model(
        jax_fn=predict_fn,
        params=params,
        model_dir=args["output_dir"],
        input_signatures=[input_signature],
        enable_xla=False,  # TF2.4's tf2xla conv lacks preferred_element_type
    )

Writing training_app/trainer/task.py


Optional local test:
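A minimal sketch of such a local run, assuming `jax`, `flax` and a matching `jaxlib` are installed in the notebook environment. `local_test_invocation` is a hypothetical helper: rather than passing `--output_dir`, it sets `AIP_MODEL_DIR`, the environment variable that `task.py` falls back to and that Vertex AI sets to `<base_output_directory>/model` (which is why the job's SavedModel ends up under `jax/output/model`).

```python
import os
import sys
from typing import Dict, List, Tuple


def local_test_invocation(model_dir: str) -> Tuple[List[str], Dict[str, str]]:
    """Return (argv, env) that run trainer.task the way the Vertex AI job does.

    Hypothetical helper: the module reads its output location from
    AIP_MODEL_DIR when --output_dir is not given, so we set only the env var.
    """
    argv = [sys.executable, "-m", "trainer.task"]
    # Copy the current environment and add the Vertex AI output variable.
    env = dict(os.environ, AIP_MODEL_DIR=model_dir)
    return argv, env
```

Then run `subprocess.run(argv, env=env, cwd=TRAINING_APP_FOLDER, check=True)` with a local directory such as `/tmp/jax_model_local` (an arbitrary example path). This performs the full two-epoch training, so expect it to take a while.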

In principle, the high-level SDK classes [CustomTrainingJob](https://googleapis.dev/python/aiplatform/latest/aiplatform.html#google.cloud.aiplatform.CustomTrainingJob) or [CustomPythonPackageTrainingJob](https://googleapis.dev/python/aiplatform/latest/aiplatform.html#google.cloud.aiplatform.CustomPythonPackageTrainingJob) should work here, but they currently fail with an error. [CustomTrainingJob.run](https://googleapis.dev/python/aiplatform/latest/aiplatform.html#google.cloud.aiplatform.CustomTrainingJob.run) fails in the same way even in the [official notebook](https://github.com/GoogleCloudPlatform/ai-platform-samples/blob/master/ai-platform-unified/notebooks/official/custom/sdk-custom-image-classification-online.ipynb). So instead we package the trainer as a source distribution, upload it to GCS, and submit it with the lower-level `JobServiceClient`.

In [6]:
%%writefile {TRAINING_APP_FOLDER}/setup.py
from setuptools import find_packages
from setuptools import setup

REQUIRED_PACKAGES = ['flax', 'jax']

setup(
    name='jax_flax_trainer',
    version='0.1',
    install_requires=REQUIRED_PACKAGES,
    packages=find_packages(),
    include_package_data=True,
    description='JAX/Flax model training application.'
)


Writing training_app/setup.py


In [10]:
%%bash
cd $TRAINING_APP_FOLDER
python ./setup.py sdist --formats=gztar
cd ..

running sdist
running egg_info
writing requirements to jax_flax_trainer.egg-info/requires.txt
writing jax_flax_trainer.egg-info/PKG-INFO
writing top-level names to jax_flax_trainer.egg-info/top_level.txt
writing dependency_links to jax_flax_trainer.egg-info/dependency_links.txt
reading manifest file 'jax_flax_trainer.egg-info/SOURCES.txt'
writing manifest file 'jax_flax_trainer.egg-info/SOURCES.txt'
running check
creating jax_flax_trainer-0.1
creating jax_flax_trainer-0.1/jax_flax_trainer.egg-info
creating jax_flax_trainer-0.1/trainer
copying files to jax_flax_trainer-0.1...
copying setup.py -> jax_flax_trainer-0.1
copying jax_flax_trainer.egg-info/PKG-INFO -> jax_flax_trainer-0.1/jax_flax_trainer.egg-info
copying jax_flax_trainer.egg-info/SOURCES.txt -> jax_flax_trainer-0.1/jax_flax_trainer.egg-info
copying jax_flax_trainer.egg-info/dependency_links.txt -> jax_flax_trainer-0.1/jax_flax_trainer.egg-info
copying jax_flax_trainer.egg-info/requires.txt -> jax_flax_trainer-0.1/jax_flax_tra






In [11]:
%%bash
gsutil -q cp $TRAINING_APP_FOLDER/dist/jax_flax_trainer-0.1.tar.gz gs://$BUCKET_NAME/jax/

In [12]:
JOB_NAME = "jax_job"

# jaxlib wheel built for CUDA 11.0 and Python 3.7 (cp37); its Python tag must
# match the Python version in the training image.
JAXLIB_URI = "gs://jax-releases/cuda110/jaxlib-0.1.67+cuda110-cp37-none-manylinux2010_x86_64.whl"
PYTHON_PACKAGE_URIS = [f"gs://{BUCKET_NAME}/jax/jax_flax_trainer-0.1.tar.gz", JAXLIB_URI]
# Vertex AI machine to use for training
MACHINE_TYPE = "n1-standard-4"
REPLICA_COUNT = 1
PYTHON_MODULE = "trainer.task"

USE_GPU = True
if USE_GPU:
    TRAINING_IMAGE = 'us-docker.pkg.dev/vertex-ai/training/tf-gpu.2-4:latest'
    ACCELERATOR_TYPE = "NVIDIA_TESLA_T4"
    ACCELERATOR_COUNT = 1
else:
    TRAINING_IMAGE = 'us-docker.pkg.dev/vertex-ai/training/tf-cpu.2-4:latest'
    ACCELERATOR_TYPE = None
    ACCELERATOR_COUNT = None

api_endpoint: str = f"{REGION}-aiplatform.googleapis.com"

# The AI Platform services require regional API endpoints.
client_options = {"api_endpoint": api_endpoint}
# Initialize client that will be used to create and send requests.
# This client only needs to be created once, and can be reused for multiple
# requests.
client = aiplatform.gapic.JobServiceClient(client_options=client_options)
custom_job = {
    "display_name": JOB_NAME,
    "job_spec": {
        "worker_pool_specs": [
            {
                "machine_spec": {
                    "machine_type": MACHINE_TYPE,
                    "accelerator_type": ACCELERATOR_TYPE,
                    "accelerator_count": ACCELERATOR_COUNT,
                },
                "replica_count": REPLICA_COUNT,
                "python_package_spec": {
                    "executor_image_uri": TRAINING_IMAGE,
                    "package_uris": PYTHON_PACKAGE_URIS,
                    "python_module": PYTHON_MODULE,
                },
            }
        ],
        "base_output_directory": {
            "output_uri_prefix": f"gs://{BUCKET_NAME}/jax/output"
        },
    },
}
parent = f"projects/{PROJECT_ID}/locations/{REGION}"
response = client.create_custom_job(parent=parent, custom_job=custom_job)
print("response:", response)


response: name: "projects/654544512569/locations/us-central1/customJobs/3457942079122964480"
display_name: "jax_job"
job_spec {
  worker_pool_specs {
    machine_spec {
      machine_type: "n1-standard-4"
      accelerator_type: NVIDIA_TESLA_T4
      accelerator_count: 1
    }
    replica_count: 1
    disk_spec {
      boot_disk_type: "pd-ssd"
      boot_disk_size_gb: 100
    }
    python_package_spec {
      executor_image_uri: "us-docker.pkg.dev/vertex-ai/training/tf-gpu.2-4:latest"
      package_uris: "gs://dsparing-sandbox-bucket/jax/jax_flax_trainer-0.1.tar.gz"
      package_uris: "gs://jax-releases/cuda110/jaxlib-0.1.67+cuda110-cp37-none-manylinux2010_x86_64.whl"
      python_module: "trainer.task"
    }
  }
  base_output_directory {
    output_uri_prefix: "gs://dsparing-sandbox-bucket/jax/output"
  }
}
state: JOB_STATE_PENDING
create_time {
  seconds: 1624694068
  nanos: 708517000
}
update_time {
  seconds: 1624694068
  nanos: 708517000
}
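Only the GPU branch was run above. In the CPU branch, `accelerator_type` and `accelerator_count` are set to `None` and still placed in the request dict; we have not verified how the client treats `None` there. A more conservative option is to leave the keys out entirely. A minimal sketch with a hypothetical helper:

```python
from typing import Optional


def build_machine_spec(machine_type: str,
                       accelerator_type: Optional[str] = None,
                       accelerator_count: int = 0) -> dict:
    """Build a machine_spec dict, omitting accelerator fields on CPU-only jobs.

    Hypothetical helper; the resulting dict would go into
    custom_job["job_spec"]["worker_pool_specs"][0]["machine_spec"].
    """
    spec = {"machine_type": machine_type}
    # Only add accelerator fields when an accelerator is actually requested.
    if accelerator_type is not None and accelerator_count > 0:
        spec["accelerator_type"] = accelerator_type
        spec["accelerator_count"] = accelerator_count
    return spec
```

With `USE_GPU = False` this yields `{"machine_type": "n1-standard-4"}`, leaving the accelerator fields unset rather than explicitly `None`.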



In [13]:
while True:
    job_state = client.get_custom_job(name=response.name).state
    print(job_state)
    if job_state not in (
        aiplatform_v1.JobState.JOB_STATE_QUEUED,
        aiplatform_v1.JobState.JOB_STATE_PENDING,
        aiplatform_v1.JobState.JOB_STATE_RUNNING
    ):
        break
    time.sleep(30)

JobState.JOB_STATE_PENDING
JobState.JOB_STATE_PENDING
JobState.JOB_STATE_PENDING
JobState.JOB_STATE_PENDING
JobState.JOB_STATE_PENDING
JobState.JOB_STATE_PENDING
JobState.JOB_STATE_PENDING
JobState.JOB_STATE_PENDING
JobState.JOB_STATE_PENDING
JobState.JOB_STATE_PENDING
JobState.JOB_STATE_PENDING
JobState.JOB_STATE_PENDING
JobState.JOB_STATE_PENDING
JobState.JOB_STATE_PENDING
JobState.JOB_STATE_PENDING
JobState.JOB_STATE_PENDING
JobState.JOB_STATE_PENDING
JobState.JOB_STATE_PENDING
JobState.JOB_STATE_PENDING
JobState.JOB_STATE_PENDING
JobState.JOB_STATE_PENDING
JobState.JOB_STATE_PENDING
JobState.JOB_STATE_PENDING
JobState.JOB_STATE_PENDING
JobState.JOB_STATE_PENDING
JobState.JOB_STATE_PENDING
JobState.JOB_STATE_PENDING
JobState.JOB_STATE_PENDING
JobState.JOB_STATE_PENDING
JobState.JOB_STATE_RUNNING
JobState.JOB_STATE_RUNNING
JobState.JOB_STATE_SUCCEEDED
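The loop above polls forever if the job never leaves an active state. A minimal sketch of the same polling with an upper bound; the helper, its default timeout, and its parameters are our own choices, not part of the Vertex AI SDK.

```python
import time
from typing import Callable, Iterable


def wait_for_terminal_state(get_state: Callable[[], object],
                            active_states: Iterable[object],
                            poll_seconds: float = 30.0,
                            timeout_seconds: float = 3 * 60 * 60,
                            sleep: Callable[[float], None] = time.sleep):
    """Poll get_state() until it returns a state outside active_states.

    Hypothetical helper. Raises TimeoutError if the job is still active
    after timeout_seconds. `sleep` is injectable so the loop can be tested.
    """
    active = set(active_states)
    deadline = time.monotonic() + timeout_seconds
    while True:
        state = get_state()
        if state not in active:
            return state
        if time.monotonic() >= deadline:
            raise TimeoutError(f"job still {state} after {timeout_seconds}s")
        sleep(poll_seconds)
```

With the client from above, `get_state` would be `lambda: client.get_custom_job(name=response.name).state`, and `active_states` the three `aiplatform_v1.JobState` values (`JOB_STATE_QUEUED`, `JOB_STATE_PENDING`, `JOB_STATE_RUNNING`) used in the loop. Compare the returned value with `JOB_STATE_SUCCEEDED` before loading the model.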


### Check that the SavedModel can actually make predictions (optional)

In [14]:
%%bash
gsutil ls -l gs://$BUCKET_NAME/jax/output/model

         0  2021-06-21T06:06:58Z  gs://dsparing-sandbox-bucket/jax/output/model/
     44207  2021-06-26T08:09:13Z  gs://dsparing-sandbox-bucket/jax/output/model/saved_model.pb
                                 gs://dsparing-sandbox-bucket/jax/output/model/assets/
                                 gs://dsparing-sandbox-bucket/jax/output/model/variables/
TOTAL: 2 objects, 44207 bytes (43.17 KiB)


In [15]:
# Parse absl flags with a dummy program name; load_mnist reads an absl flag,
# and accessing absl flags before they are parsed raises an error.
flags.FLAGS(['e'])

image_to_predict, _ = next(iter(load_mnist(tfds.Split.TEST, 1)))

INFO:absl:Load dataset info from /home/jupyter/tensorflow_datasets/mnist/3.0.1
INFO:absl:Field info.citation from disk and from code do not match. Keeping the one from code.
INFO:absl:Field info.splits from disk and from code do not match. Keeping the one from code.
INFO:absl:Field info.module_name from disk and from code do not match. Keeping the one from code.
INFO:absl:Reusing dataset mnist (/home/jupyter/tensorflow_datasets/mnist/3.0.1)
INFO:absl:Constructing tf.data.Dataset mnist for split test, from /home/jupyter/tensorflow_datasets/mnist/3.0.1


In [16]:
loaded_model = tf.saved_model.load(f"gs://{BUCKET_NAME}/jax/output/model")
loaded_model.signatures["serving_default"](image_to_predict)

{'output_0': <tf.Tensor: shape=(1, 10), dtype=float32, numpy=
 array([[ -2.8519683, -22.318636 ,  -6.8104887, -13.863109 ,  -0.7156592,
          -3.3984666,  -1.8780065,  -7.833604 ,  -1.581222 ,  -2.815784 ]],
       dtype=float32)>}
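The output is a vector of 10 scores, one per digit. Their exponentials sum to roughly 1, which suggests the model returns log-probabilities (log-softmax). A minimal sketch for turning them into a predicted digit, using the values printed above; in the notebook you would pass `...["output_0"].numpy()[0]` instead. `top_prediction` is a hypothetical helper.

```python
import math
from typing import Sequence, Tuple


def top_prediction(log_probs: Sequence[float]) -> Tuple[int, float]:
    """Return (class index, probability) of the highest log-probability."""
    best = max(range(len(log_probs)), key=lambda i: log_probs[i])
    return best, math.exp(log_probs[best])


# Values copied from the output_0 tensor printed above.
output_0 = [-2.8519683, -22.318636, -6.8104887, -13.863109, -0.7156592,
            -3.3984666, -1.8780065, -7.833604, -1.581222, -2.815784]

digit, prob = top_prediction(output_0)
total = sum(math.exp(v) for v in output_0)  # close to 1 for log-softmax outputs
print(digit, round(prob, 3), round(total, 3))
```

Here the model predicts digit 4 with probability of about 0.49. The true label was discarded as `_` when loading the image; keeping it would let you check whether the prediction is correct.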