In [1]:
# Copyright 2021 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# Train a JAX/Flax model on a Vertex AI pre-built container and use `jax2tf` to convert it to a SavedModel

In [2]:
import os
import time

import tensorflow as tf
import tensorflow_datasets as tfds
from absl import flags
from google.cloud import aiplatform
from google.cloud import aiplatform_v1
from jax.experimental.jax2tf.examples.mnist_lib import load_mnist

In [3]:
PROJECT_ID = !(gcloud config get-value core/project)
PROJECT_ID = PROJECT_ID[0]

REGION = "us-central1"
os.environ['REGION'] = REGION

BUCKET_NAME = PROJECT_ID
os.environ['BUCKET_NAME'] = BUCKET_NAME
# Use a regional bucket in the above region you have rights to.
# Create if needed:
# !gsutil mb -l ${REGION} gs://${BUCKET_NAME}

USE_GPU = True

TRAINING_APP_FOLDER = 'training_app'
os.environ['TRAINING_APP_FOLDER'] = TRAINING_APP_FOLDER

MODEL_NAME = "jax_model_prebuilt"
MODELPACKAGE_DIR = f"gs://{BUCKET_NAME}/models/{MODEL_NAME}/package"
MODELPACKAGE_NAME = "jax_flax_trainer-0.1.tar.gz"
SAVEDMODEL_BASEDIR = f"gs://{BUCKET_NAME}/models/{MODEL_NAME}/output"
os.environ['MODELPACKAGE_DIR'] = MODELPACKAGE_DIR
os.environ['MODELPACKAGE_NAME'] = MODELPACKAGE_NAME
os.environ['SAVEDMODEL_BASEDIR'] = SAVEDMODEL_BASEDIR

# Block TF from the GPU to let JAX use it all
tf.config.set_visible_devices([], 'GPU')

There is no pre-built Vertex AI TF2.5 training container yet. `jax.experimental.jax2tf.examples.saved_model_lib.convert_and_save_model` passes the `jit_compile` argument to `tf.function`, and `jit_compile` is the TF2.5 name for what TF2.4 calls `experimental_compile`. We therefore ship a copy of that module as `trainer/saved_model_lib_tf2_4.py` (inside the training app). Its only difference is that it passes `experimental_compile` instead of `jit_compile`, so it runs on TF2.4.

Running on TF2.4 causes one more incompatibility. jax2tf in jax==0.2.14 calls `tensorflow.compiler.tf2xla.conv(..., preferred_element_type)`, and that argument exists only in TF2.5, not in TF2.4. We avoid it by calling `convert_and_save_model(..., enable_xla=False)`.

In [4]:
# Show the TF2.4-compatible training entry point (module `trainer.task_tf2_4`) that runs both locally and on Vertex AI.
!cat $TRAINING_APP_FOLDER/trainer/task_tf2_4.py

# Copyright 2021 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse
import logging
import os

import tensorflow as tf
import tensorflow_datasets as tfds
from absl import flags
from jax.experimental.jax2tf.examples.mnist_lib import (
    load_mnist, FlaxMNIST
)
# from jax.experimental.jax2tf.examples.saved_model_lib import (
#     convert_and_save_model
# )

from trainer.saved_model_lib_tf2_4 import convert_and_save_model

TRAIN_BATCH_SIZE = 128
TEST_BATCH_SIZE = 16
NUM_EPO

## Test the training Python package locally

In [5]:
# Smoke-test the trainer locally before submitting it. Adding the training app
# folder to PYTHONPATH makes the `trainer` package importable. The local run
# writes to `localmodel` under SAVEDMODEL_BASEDIR, so it does not collide with
# the Vertex AI job's `model` output listed further below.
!export PYTHONPATH=${PYTHONPATH}:${PWD}/$TRAINING_APP_FOLDER && \
    python3 -m trainer.task_tf2_4 --output_dir=$SAVEDMODEL_BASEDIR/localmodel

2021-06-27 06:47:03.972535: W tensorflow/core/kernels/data/cache_dataset_ops.cc:768] The calling iterator did not fully read the dataset being cached. In order to avoid unexpected truncation of the dataset, the partially cached contents of the dataset  will be discarded. This can happen if you have an input pipeline similar to `dataset.cache().take(k).repeat()`. You should use `dataset.take(k).cache().repeat()` instead.
2021-06-27 06:47:03.972740: W tensorflow/core/kernels/data/cache_dataset_ops.cc:768] The calling iterator did not fully read the dataset being cached. In order to avoid unexpected truncation of the dataset, the partially cached contents of the dataset  will be discarded. This can happen if you have an input pipeline similar to `dataset.cache().take(k).repeat()`. You should use `dataset.take(k).cache().repeat()` instead.
INFO:absl:Starting the local TPU driver.
INFO:absl:Unable to initialize backend 'tpu_driver': Not found: Unable to find driver in registry given worker:

Ideally we would use [CustomTrainingJob](https://googleapis.dev/python/aiplatform/latest/aiplatform.html#google.cloud.aiplatform.CustomTrainingJob) (which reads a local Python package) or [CustomPythonPackageTrainingJob](https://googleapis.dev/python/aiplatform/latest/aiplatform.html#google.cloud.aiplatform.CustomPythonPackageTrainingJob) from the high-level API. These not only train the model on Vertex AI but also upload it as a Vertex AI model. However, we need to pass `jaxlib` as a second Python package, and the high-level API accepts only one. We therefore use [create_custom_job](https://googleapis.dev/python/aiplatform/latest/aiplatform_v1/job_service.html?#google.cloud.aiplatform_v1.services.job_service.JobServiceClient.create_custom_job) from the low-level API instead. For this, we first need to create a source distribution and upload it to Cloud Storage.

## Create a source distribution and upload to Cloud Storage

In [6]:
# Show the package metadata; its name and version determine the sdist file name (MODELPACKAGE_NAME).
!cat $TRAINING_APP_FOLDER/setup.py

# Copyright 2021 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from setuptools import find_packages
from setuptools import setup

REQUIRED_PACKAGES = ['flax', 'jax']

setup(
    name='jax_flax_trainer',
    version='0.1',
    install_requires=REQUIRED_PACKAGES,
    packages=find_packages(),
    include_package_data=True,
    description='JAX/FLAX model training application.'
)


In [7]:
# Build the source distribution into $TRAINING_APP_FOLDER/dist/. MODELPACKAGE_NAME
# must equal "<name>-<version>.tar.gz" from setup.py for the upload below to find it.
!cd $TRAINING_APP_FOLDER && python ./setup.py sdist --formats=gztar

running sdist
running egg_info
writing jax_flax_trainer.egg-info/PKG-INFO
writing dependency_links to jax_flax_trainer.egg-info/dependency_links.txt
writing requirements to jax_flax_trainer.egg-info/requires.txt
writing top-level names to jax_flax_trainer.egg-info/top_level.txt
reading manifest file 'jax_flax_trainer.egg-info/SOURCES.txt'
writing manifest file 'jax_flax_trainer.egg-info/SOURCES.txt'

running check


creating jax_flax_trainer-0.1
creating jax_flax_trainer-0.1/jax_flax_trainer.egg-info
creating jax_flax_trainer-0.1/trainer
copying files to jax_flax_trainer-0.1...
copying setup.py -> jax_flax_trainer-0.1
copying jax_flax_trainer.egg-info/PKG-INFO -> jax_flax_trainer-0.1/jax_flax_trainer.egg-info
copying jax_flax_trainer.egg-info/SOURCES.txt -> jax_flax_trainer-0.1/jax_flax_trainer.egg-info
copying jax_flax_trainer.egg-info/dependency_links.txt -> jax_flax_trainer-0.1/jax_flax_trainer.egg-info
copying jax_flax_trainer.egg-info/requires.txt -> jax_flax_trainer-0.1/jax_flax_

In [8]:
# Upload the sdist to Cloud Storage; the custom job installs it from MODELPACKAGE_DIR.
!gsutil -q cp $TRAINING_APP_FOLDER/dist/$MODELPACKAGE_NAME $MODELPACKAGE_DIR/

## Run a custom training job with the Python package on Vertex AI

In [9]:
# jaxlib is installed from a prebuilt wheel as a second package next to our sdist.
# NOTE(review): the wheel is tagged cp37 and built for CUDA 11.0, so this assumes
# the training image runs Python 3.7. It is used for the CPU image too. jaxlib
# must also be compatible with the (unpinned) jax that setup.py installs. Confirm.
JAXLIB_URI = (
    "gs://jax-releases/cuda110/"
    "jaxlib-0.1.67+cuda110-cp37-none-manylinux2010_x86_64.whl"
)
PYTHON_PACKAGE_URIS = [f"{MODELPACKAGE_DIR}/{MODELPACKAGE_NAME}", JAXLIB_URI]
PYTHON_MODULE = "trainer.task_tf2_4"

# Vertex AI pre-built TF 2.4 training image; only the gpu/cpu flavour differs.
TRAINING_IMAGE = (
    "us-docker.pkg.dev/vertex-ai/training/tf-"
    + ("gpu" if USE_GPU else "cpu")
    + ".2-4:latest"
)

python_package_spec = dict(
    executor_image_uri=TRAINING_IMAGE,
    package_uris=PYTHON_PACKAGE_URIS,
    python_module=PYTHON_MODULE,
)

In [10]:
# Vertex AI machines to use for training
MACHINE_TYPE = "n1-standard-4"
REPLICA_COUNT = 1

if USE_GPU:
    ACCELERATOR_TYPE = "NVIDIA_TESLA_T4"
    ACCELERATOR_COUNT = 1
else:
    ACCELERATOR_TYPE = None
    ACCELERATOR_COUNT = None

worker_pool_spec = {
    "machine_spec": {
        "machine_type": MACHINE_TYPE,
        "accelerator_type": ACCELERATOR_TYPE,
        "accelerator_count": ACCELERATOR_COUNT,
    },
    "replica_count": REPLICA_COUNT,
    "python_package_spec": python_package_spec,
}

In [11]:
JOB_NAME = "jax_prebuilt_training"

# One worker pool running our Python package; all job outputs are rooted at
# SAVEDMODEL_BASEDIR (the exported model is listed from its `model` subfolder below).
custom_job = dict(
    display_name=JOB_NAME,
    job_spec=dict(
        worker_pool_specs=[worker_pool_spec],
        base_output_directory=dict(output_uri_prefix=SAVEDMODEL_BASEDIR),
    ),
)

In [12]:
# The job is created in REGION, so talk to that region's Vertex AI endpoint.
api_endpoint: str = f"{REGION}-aiplatform.googleapis.com"
client_options = dict(api_endpoint=api_endpoint)

# Resource path under which the custom job is created.
parent = f"projects/{PROJECT_ID}/locations/{REGION}"

client = aiplatform.gapic.JobServiceClient(client_options=client_options)

In [13]:
# Submit the job; `response.name` is the job resource name polled in the next cell.
response = client.create_custom_job(custom_job=custom_job, parent=parent)
print("response:", response)

response: name: "projects/654544512569/locations/us-central1/customJobs/5056579209351135232"
display_name: "jax_prebuilt_training"
job_spec {
  worker_pool_specs {
    machine_spec {
      machine_type: "n1-standard-4"
      accelerator_type: NVIDIA_TESLA_T4
      accelerator_count: 1
    }
    replica_count: 1
    disk_spec {
      boot_disk_type: "pd-ssd"
      boot_disk_size_gb: 100
    }
    python_package_spec {
      executor_image_uri: "us-docker.pkg.dev/vertex-ai/training/tf-gpu.2-4:latest"
      package_uris: "gs://dsparing-sandbox/models/jax_model_prebuilt/package/jax_flax_trainer-0.1.tar.gz"
      package_uris: "gs://jax-releases/cuda110/jaxlib-0.1.67+cuda110-cp37-none-manylinux2010_x86_64.whl"
      python_module: "trainer.task_tf2_4"
    }
  }
  base_output_directory {
    output_uri_prefix: "gs://dsparing-sandbox/models/jax_model_prebuilt/output"
  }
}
state: JOB_STATE_PENDING
create_time {
  seconds: 1624776457
  nanos: 306704000
}
update_time {
  seconds: 1624776457
  n

In [14]:
# Poll the job until it reaches a final state, then fail loudly unless it
# succeeded. This stops the cells below from loading a model that was never written.
POLL_INTERVAL_SECONDS = 30

# States from which the job can still move on. JOB_STATE_CANCELLING is not
# final (the job later becomes SUCCEEDED, FAILED or CANCELLED), so keep waiting.
NON_TERMINAL_STATES = (
    aiplatform_v1.JobState.JOB_STATE_QUEUED,
    aiplatform_v1.JobState.JOB_STATE_PENDING,
    aiplatform_v1.JobState.JOB_STATE_RUNNING,
    aiplatform_v1.JobState.JOB_STATE_CANCELLING,
)

while True:
    job = client.get_custom_job(name=response.name)
    job_state = job.state
    print(job_state)
    if job_state not in NON_TERMINAL_STATES:
        break
    time.sleep(POLL_INTERVAL_SECONDS)

if job_state != aiplatform_v1.JobState.JOB_STATE_SUCCEEDED:
    # Surface the job's error status instead of continuing silently.
    raise RuntimeError(
        f"Custom job {response.name} ended in state {job_state.name}: {job.error}"
    )

JobState.JOB_STATE_RUNNING
JobState.JOB_STATE_PENDING
JobState.JOB_STATE_PENDING
JobState.JOB_STATE_PENDING
JobState.JOB_STATE_PENDING
JobState.JOB_STATE_PENDING
JobState.JOB_STATE_PENDING
JobState.JOB_STATE_RUNNING
JobState.JOB_STATE_SUCCEEDED


## Local prediction with SavedModel

In [15]:
# List the SavedModel files the training job wrote under its base output directory.
!gsutil ls -l $SAVEDMODEL_BASEDIR/model

         0  2021-06-27T06:32:36Z  gs://dsparing-sandbox/models/jax_model_prebuilt/output/model/
     44195  2021-06-27T06:51:17Z  gs://dsparing-sandbox/models/jax_model_prebuilt/output/model/saved_model.pb
                                 gs://dsparing-sandbox/models/jax_model_prebuilt/output/model/assets/
                                 gs://dsparing-sandbox/models/jax_model_prebuilt/output/model/variables/
TOTAL: 2 objects, 44195 bytes (43.16 KiB)


In [16]:
# need to initialize flags somehow to avoid errors in load_mnist
flags.FLAGS(['e'])

image_to_predict, _ = next(iter(load_mnist(tfds.Split.TEST, batch_size=1)))

INFO:absl:Load dataset info from /home/jupyter/tensorflow_datasets/mnist/3.0.1
INFO:absl:Field info.citation from disk and from code do not match. Keeping the one from code.
INFO:absl:Field info.splits from disk and from code do not match. Keeping the one from code.
INFO:absl:Field info.module_name from disk and from code do not match. Keeping the one from code.
INFO:absl:Reusing dataset mnist (/home/jupyter/tensorflow_datasets/mnist/3.0.1)
INFO:absl:Constructing tf.data.Dataset mnist for split test, from /home/jupyter/tensorflow_datasets/mnist/3.0.1


In [17]:
# Load the SavedModel written by the Vertex AI job directly from Cloud Storage
# and run its default serving signature on the single test image.
loaded_model = tf.saved_model.load(f"{SAVEDMODEL_BASEDIR}/model")
serving_fn = loaded_model.signatures["serving_default"]
serving_fn(image_to_predict)

{'output_0': <tf.Tensor: shape=(1, 10), dtype=float32, numpy=
 array([[-6.8632755, -4.290162 , -2.6556697, -2.2902765, -2.8904533,
         -2.83985  , -5.353467 , -4.57983  , -0.7665547, -1.5133249]],
       dtype=float32)>}