# Train JAX/Flax model on Vertex AI pre-built container and use `jax2tf` to convert to SavedModel

In [1]:
import os
import time

import tensorflow as tf
import tensorflow_datasets as tfds
from absl import flags
from google.cloud import aiplatform
from google.cloud import aiplatform_v1
from jax.experimental.jax2tf.examples.mnist_lib import load_mnist

In [2]:
PROJECT_ID = !(gcloud config get-value core/project)
PROJECT_ID = PROJECT_ID[0]

REGION = "us-central1"
os.environ['REGION'] = REGION

BUCKET_NAME = "dsparing-sandbox-bucket"
os.environ['BUCKET_NAME'] = BUCKET_NAME
# Use a regional bucket in the above region you have rights to.
# Create if needed:
# !gsutil mb -l ${REGION} gs://${BUCKET_NAME}

TRAINING_APP_FOLDER = 'training_app'
os.environ['TRAINING_APP_FOLDER'] = TRAINING_APP_FOLDER

MODEL_NAME = "jax_model_prebuilt"
MODELPACKAGE_DIR = f"gs://{BUCKET_NAME}/models/{MODEL_NAME}/package"
MODELPACKAGE_NAME = "jax_flax_trainer-0.1.tar.gz"
SAVEDMODEL_BASEDIR = f"gs://{BUCKET_NAME}/models/{MODEL_NAME}/output"
os.environ['MODELPACKAGE_DIR'] = MODELPACKAGE_DIR
os.environ['MODELPACKAGE_NAME'] = MODELPACKAGE_NAME
os.environ['SAVEDMODEL_BASEDIR'] = SAVEDMODEL_BASEDIR

# Block TF from the GPU to let JAX use it all
tf.config.set_visible_devices([], 'GPU')

There is no pre-built Vertex TF2.5 container yet, and `jax.experimental.jax2tf.examples.saved_model_lib.convert_and_save_model` passes the `jit_compile` argument to `tf.function`, which is the new TF2.5 name for `experimental_compile`. So we define a copy of that module, `trainer/saved_model_lib_tf2_4.py`, whose only difference is that it still uses the TF2.4 name `experimental_compile` instead of `jit_compile`.

For the same reason (no pre-built Vertex TF2.5 container yet), there is a second incompatibility: jax2tf in jax==0.2.14 uses `tensorflow.compiler.tf2xla.conv(... preferred_element_type)`, which exists only in TF2.5, not in TF2.4. We can avoid this, though, by calling `convert_and_save_model( ... enable_xla=False)`.

In [3]:
%%bash
# Print the source of the trainer.task_tf2_4 module from the training package
# folder so the reader can inspect it inline.
cat $TRAINING_APP_FOLDER/trainer/task_tf2_4.py

import argparse
import logging
import os

import tensorflow as tf
import tensorflow_datasets as tfds
from absl import flags
from jax.experimental.jax2tf.examples.mnist_lib import (
    load_mnist, FlaxMNIST
)
# from jax.experimental.jax2tf.examples.saved_model_lib import (
#     convert_and_save_model
# )

from trainer.saved_model_lib_tf2_4 import convert_and_save_model

TRAIN_BATCH_SIZE = 128
TEST_BATCH_SIZE = 16
NUM_EPOCHS = 2

# Block TF from the GPU to let JAX use it all
tf.config.set_visible_devices([], 'GPU')

logger = logging.getLogger()

# need to initialize flags somehow to avoid errors in load_mnist
flags.FLAGS(['e'])

flax_mnist = FlaxMNIST()

train_ds = load_mnist(tfds.Split.TRAIN, TRAIN_BATCH_SIZE)
test_ds = load_mnist(tfds.Split.TEST, TEST_BATCH_SIZE)

image, _ = next(iter(train_ds))
input_signature = tf.TensorSpec.from_tensor(
    tf.expand_dims(image[0], axis=0)
)


def main(output_dir):
    logger.setLevel(logging.INFO)
    predict_fn, params = flax_mnist.train(
      

## Test training Python package locally

In [4]:
%%bash
# Append the training package folder to PYTHONPATH so that `trainer` can be
# imported as a package from the current working directory.
export PYTHONPATH=${PYTHONPATH}:${PWD}/$TRAINING_APP_FOLDER
# Run the training module locally, passing a `localmodel` sub-path of the
# Cloud Storage output location as --output_dir.
python3 -m trainer.task_tf2_4 --output_dir=$SAVEDMODEL_BASEDIR/localmodel

2021-06-26 23:10:36.695574: W tensorflow/core/kernels/data/cache_dataset_ops.cc:768] The calling iterator did not fully read the dataset being cached. In order to avoid unexpected truncation of the dataset, the partially cached contents of the dataset  will be discarded. This can happen if you have an input pipeline similar to `dataset.cache().take(k).repeat()`. You should use `dataset.take(k).cache().repeat()` instead.
2021-06-26 23:10:36.695954: W tensorflow/core/kernels/data/cache_dataset_ops.cc:768] The calling iterator did not fully read the dataset being cached. In order to avoid unexpected truncation of the dataset, the partially cached contents of the dataset  will be discarded. This can happen if you have an input pipeline similar to `dataset.cache().take(k).repeat()`. You should use `dataset.take(k).cache().repeat()` instead.
INFO:absl:Starting the local TPU driver.
INFO:absl:Unable to initialize backend 'tpu_driver': Not found: Unable to find driver in registry given worker:

## Create a source distribution and upload to Cloud Storage

In [5]:
%%bash
# Print the setup.py of the training package so the reader can inspect it inline.
cat $TRAINING_APP_FOLDER/setup.py

from setuptools import find_packages
from setuptools import setup

REQUIRED_PACKAGES = ['flax', 'jax']

setup(
    name='jax_flax_trainer',
    version='0.1',
    install_requires=REQUIRED_PACKAGES,
    packages=find_packages(),
    include_package_data=True,
    description='JAX/FLAX model training application.'
)


In [6]:
%%bash
# Build a gzipped source distribution of the training package.
# `set -e` makes the cell fail as soon as a command fails, so a broken build
# (or a missing package folder) is reported instead of being masked by a
# later command that succeeds.
set -e
# The working directory of a %%bash cell is local to its own shell process,
# so there is no need to change back afterwards.
cd "$TRAINING_APP_FOLDER"
python ./setup.py sdist --formats=gztar

running sdist
running egg_info
writing jax_flax_trainer.egg-info/PKG-INFO
writing dependency_links to jax_flax_trainer.egg-info/dependency_links.txt
writing requirements to jax_flax_trainer.egg-info/requires.txt
writing top-level names to jax_flax_trainer.egg-info/top_level.txt
reading manifest file 'jax_flax_trainer.egg-info/SOURCES.txt'
writing manifest file 'jax_flax_trainer.egg-info/SOURCES.txt'
running check
creating jax_flax_trainer-0.1
creating jax_flax_trainer-0.1/jax_flax_trainer.egg-info
creating jax_flax_trainer-0.1/trainer
copying files to jax_flax_trainer-0.1...
copying setup.py -> jax_flax_trainer-0.1
copying jax_flax_trainer.egg-info/PKG-INFO -> jax_flax_trainer-0.1/jax_flax_trainer.egg-info
copying jax_flax_trainer.egg-info/SOURCES.txt -> jax_flax_trainer-0.1/jax_flax_trainer.egg-info
copying jax_flax_trainer.egg-info/dependency_links.txt -> jax_flax_trainer-0.1/jax_flax_trainer.egg-info
copying jax_flax_trainer.egg-info/requires.txt -> jax_flax_trainer-0.1/jax_flax_tra






In [7]:
%%bash
# Copy the packaged trainer archive from the local dist/ folder to the
# package location in Cloud Storage (-q: quiet mode).
gsutil -q cp $TRAINING_APP_FOLDER/dist/$MODELPACKAGE_NAME $MODELPACKAGE_DIR/

## Run custom training job with Python package on Vertex AI 

We should be able to use [CustomTrainingJob](https://googleapis.dev/python/aiplatform/latest/aiplatform.html#google.cloud.aiplatform.CustomTrainingJob) or [CustomPythonPackageTrainingJob](https://googleapis.dev/python/aiplatform/latest/aiplatform.html#google.cloud.aiplatform.CustomPythonPackageTrainingJob) from the high-level API, but doing so currently gives an error (the similar [CustomTrainingJob.run](https://googleapis.dev/python/aiplatform/latest/aiplatform.html#google.cloud.aiplatform.CustomTrainingJob.run) also currently gives an error, even when using the [official notebook](https://github.com/GoogleCloudPlatform/ai-platform-samples/blob/master/ai-platform-unified/notebooks/official/custom/sdk-custom-image-classification-online.ipynb)). We therefore use [create_custom_job](https://googleapis.dev/python/aiplatform/latest/aiplatform_v1/job_service.html?#google.cloud.aiplatform_v1.services.job_service.JobServiceClient.create_custom_job) from the low-level API instead.

In [8]:
JOB_NAME = "jax_prebuilt_training"

# Packages handed to the job as `package_uris`: the trainer sdist uploaded to
# Cloud Storage plus a jaxlib wheel (a cuda110 / cp37 build, per its file name).
JAXLIB_URI = (
    "gs://jax-releases/cuda110/"
    "jaxlib-0.1.67+cuda110-cp37-none-manylinux2010_x86_64.whl"
)
PYTHON_PACKAGE_URIS = [f"{MODELPACKAGE_DIR}/{MODELPACKAGE_NAME}", JAXLIB_URI]
PYTHON_MODULE = "trainer.task_tf2_4"

# Vertex AI machines to use for training
MACHINE_TYPE = "n1-standard-4"
REPLICA_COUNT = 1

# Toggle between the GPU and CPU flavours of the pre-built TF 2.4 training
# image; the accelerator settings are only filled in for the GPU flavour.
USE_GPU = True
TRAINING_IMAGE = (
    "us-docker.pkg.dev/vertex-ai/training/tf-gpu.2-4:latest"
    if USE_GPU
    else "us-docker.pkg.dev/vertex-ai/training/tf-cpu.2-4:latest"
)
ACCELERATOR_TYPE = "NVIDIA_TESLA_T4" if USE_GPU else None
ACCELERATOR_COUNT = 1 if USE_GPU else None

# The AI Platform services require regional API endpoints.
api_endpoint: str = f"{REGION}-aiplatform.googleapis.com"
client_options = {"api_endpoint": api_endpoint}

# One client is enough: it can be reused for every request made below.
client = aiplatform.gapic.JobServiceClient(client_options=client_options)

# Assemble the job description from the inside out.
machine_spec = {
    "machine_type": MACHINE_TYPE,
    "accelerator_type": ACCELERATOR_TYPE,
    "accelerator_count": ACCELERATOR_COUNT,
}
python_package_spec = {
    "executor_image_uri": TRAINING_IMAGE,
    "package_uris": PYTHON_PACKAGE_URIS,
    "python_module": PYTHON_MODULE,
}
worker_pool_spec = {
    "machine_spec": machine_spec,
    "replica_count": REPLICA_COUNT,
    "python_package_spec": python_package_spec,
}
custom_job = {
    "display_name": JOB_NAME,
    "job_spec": {
        "worker_pool_specs": [worker_pool_spec],
        "base_output_directory": {"output_uri_prefix": SAVEDMODEL_BASEDIR},
    },
}

# Submit the job under this project and region, then show what came back.
parent = f"projects/{PROJECT_ID}/locations/{REGION}"
response = client.create_custom_job(parent=parent, custom_job=custom_job)
print("response:", response)


response: name: "projects/654544512569/locations/us-central1/customJobs/3380114248062468096"
display_name: "jax_prebuilt_training"
job_spec {
  worker_pool_specs {
    machine_spec {
      machine_type: "n1-standard-4"
      accelerator_type: NVIDIA_TESLA_T4
      accelerator_count: 1
    }
    replica_count: 1
    disk_spec {
      boot_disk_type: "pd-ssd"
      boot_disk_size_gb: 100
    }
    python_package_spec {
      executor_image_uri: "us-docker.pkg.dev/vertex-ai/training/tf-gpu.2-4:latest"
      package_uris: "gs://dsparing-sandbox-bucket/models/jax_model_prebuilt/package/jax_flax_trainer-0.1.tar.gz"
      package_uris: "gs://jax-releases/cuda110/jaxlib-0.1.67+cuda110-cp37-none-manylinux2010_x86_64.whl"
      python_module: "trainer.task_tf2_4"
    }
  }
  base_output_directory {
    output_uri_prefix: "gs://dsparing-sandbox-bucket/models/jax_model_prebuilt/output"
  }
}
state: JOB_STATE_PENDING
create_time {
  seconds: 1624749073
  nanos: 578971000
}
update_time {
  seconds: 

In [9]:
# States in which the job has not finished yet and should be polled again.
ACTIVE_JOB_STATES = (
    aiplatform_v1.JobState.JOB_STATE_QUEUED,
    aiplatform_v1.JobState.JOB_STATE_PENDING,
    aiplatform_v1.JobState.JOB_STATE_RUNNING,
)
# Seconds to wait between two status requests.
POLL_INTERVAL_SECONDS = 30

# Poll the job submitted above, printing each observed state, until it
# leaves the active states.
while True:
    job_state = client.get_custom_job(name=response.name).state
    print(job_state)
    if job_state not in ACTIVE_JOB_STATES:
        break
    time.sleep(POLL_INTERVAL_SECONDS)

# Stop the notebook here unless training actually succeeded; otherwise the
# cells below would go on to read model output that this job did not produce.
if job_state != aiplatform_v1.JobState.JOB_STATE_SUCCEEDED:
    raise RuntimeError(
        f"Custom job {response.name} ended in state {job_state!r}; "
        "expected JOB_STATE_SUCCEEDED."
    )

JobState.JOB_STATE_RUNNING
JobState.JOB_STATE_PENDING
JobState.JOB_STATE_PENDING
JobState.JOB_STATE_PENDING
JobState.JOB_STATE_PENDING
JobState.JOB_STATE_PENDING
JobState.JOB_STATE_PENDING
JobState.JOB_STATE_PENDING
JobState.JOB_STATE_PENDING
JobState.JOB_STATE_PENDING
JobState.JOB_STATE_PENDING
JobState.JOB_STATE_PENDING
JobState.JOB_STATE_PENDING
JobState.JOB_STATE_PENDING
JobState.JOB_STATE_PENDING
JobState.JOB_STATE_PENDING
JobState.JOB_STATE_PENDING
JobState.JOB_STATE_PENDING
JobState.JOB_STATE_PENDING
JobState.JOB_STATE_RUNNING
JobState.JOB_STATE_RUNNING
JobState.JOB_STATE_SUCCEEDED


## Local prediction with SavedModel

In [10]:
%%bash
# List (long format) what is stored under the `model` sub-path of the
# job's base output directory.
gsutil ls -l $SAVEDMODEL_BASEDIR/model

         0  2021-06-26T17:43:46Z  gs://dsparing-sandbox-bucket/models/jax_model_prebuilt/output/model/
     44195  2021-06-26T23:21:05Z  gs://dsparing-sandbox-bucket/models/jax_model_prebuilt/output/model/saved_model.pb
                                 gs://dsparing-sandbox-bucket/models/jax_model_prebuilt/output/model/assets/
                                 gs://dsparing-sandbox-bucket/models/jax_model_prebuilt/output/model/variables/
TOTAL: 2 objects, 44195 bytes (43.16 KiB)


In [11]:
# absl flags have to be parsed once before load_mnist is called, otherwise it
# raises errors; a dummy program name is all that is needed.
flags.FLAGS(['e'])

# Take a single test example (batch of one) to feed to the SavedModel;
# the label is not needed.
test_batches = load_mnist(tfds.Split.TEST, batch_size=1)
image_to_predict, _ = next(iter(test_batches))

INFO:absl:Load dataset info from /home/jupyter/tensorflow_datasets/mnist/3.0.1
INFO:absl:Field info.citation from disk and from code do not match. Keeping the one from code.
INFO:absl:Field info.splits from disk and from code do not match. Keeping the one from code.
INFO:absl:Field info.module_name from disk and from code do not match. Keeping the one from code.
INFO:absl:Reusing dataset mnist (/home/jupyter/tensorflow_datasets/mnist/3.0.1)
INFO:absl:Constructing tf.data.Dataset mnist for split test, from /home/jupyter/tensorflow_datasets/mnist/3.0.1


In [12]:
# Load the SavedModel from the job's `model` output location and look up
# its default serving signature.
loaded_model = tf.saved_model.load(f"{SAVEDMODEL_BASEDIR}/model")
serving_fn = loaded_model.signatures["serving_default"]

# Run the single test image through the signature; the result is displayed
# as the cell output.
serving_fn(image_to_predict)

{'output_0': <tf.Tensor: shape=(1, 10), dtype=float32, numpy=
 array([[-11.065625  , -21.81225   , -18.527025  , -11.501289  ,
          -7.305455  ,  -8.047893  , -17.828339  ,  -0.03009149,
          -9.536808  ,  -3.5559657 ]], dtype=float32)>}