In [1]:
# Copyright 2021 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# Train JAX/Flax model on Vertex AI pre-built container and use `jax2tf` to convert to SavedModel

In [2]:
import os
import time

import tensorflow as tf
import tensorflow_datasets as tfds
from absl import flags
from google.cloud import aiplatform
from google.cloud import aiplatform_v1
from jax.experimental.jax2tf.examples import mnist_lib

In [3]:
PROJECT_ID = !(gcloud config get-value project)
PROJECT_ID = PROJECT_ID[0]

REGION = "us-central1"

BUCKET_NAME = PROJECT_ID
# Use a regional bucket in the above region you have rights to.
# Create if needed:
# !gsutil mb -l $REGION gs://$BUCKET_NAME

USE_GPU = True

TRAINING_APP_FOLDER = 'training_app'

MODEL_NAME = "jax_model_prebuilt"

MODELPACKAGE_DIR = f"gs://{BUCKET_NAME}/trainers/{MODEL_NAME}"
MODELPACKAGE_NAME = "jax_flax_trainer-0.1.tar.gz"

BASE_OUTPUT_DIR = f"gs://{BUCKET_NAME}"
MODEL_VERSION = 1

SERVING_BATCH_SIZE = 3

# Block TF from the GPU to let JAX use it all
tf.config.set_visible_devices([], 'GPU')

There is no pre-built Vertex AI TF2.5 container yet, and the current JAX tooling assumes TF2.5 in two places:
- `jax.experimental.jax2tf.examples.saved_model_lib.convert_and_save_model` passes the `jit_compile` argument to `tf.function`, which is the new TF2.5 name for `experimental_compile`. So we define a copy of that module, `./saved_model_lib_tf2_4.py`, whose only difference is that it still uses the TF2.4 name `experimental_compile`.
- jax2tf in jax==0.2.14 uses `tensorflow.compiler.tf2xla.conv(... preferred_element_type)`, which is available only in TF2.5, not in TF2.4. We can avoid this by calling `convert_and_save_model( ... enable_xla=False)`.

To work around the above limitations, we'll use the slightly modified `task_tf2_4.py` version of `task.py`.

In [4]:
# Print the TF2.4 variant of the training entry point (the module that the
# local and Vertex AI training runs in this notebook execute).
!cat $TRAINING_APP_FOLDER/trainer/task_tf2_4.py

# Copyright 2021 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import argparse
import logging
import os

import tensorflow as tf
import tensorflow_datasets as tfds
from absl import flags
from jax.experimental.jax2tf.examples import mnist_lib

from trainer import saved_model_lib_tf2_4

TRAIN_BATCH_SIZE = 128
TEST_BATCH_SIZE = 16
NUM_EPOCHS = 2

# Block TF from the GPU to let JAX use it all
tf.config.set_visible_devices([], "GPU")

logger = logging.getLogger()

# need to initialize fl

## Test training Python package locally

In [5]:
os.environ["TRAINING_APP_FOLDER"] = TRAINING_APP_FOLDER
os.environ["BASE_OUTPUT_DIR"] = BASE_OUTPUT_DIR
os.environ["MODEL_NAME"] = MODEL_NAME
os.environ["MODEL_VERSION"] = str(MODEL_VERSION)

!export PYTHONPATH=${PYTHONPATH}:${PWD}/$TRAINING_APP_FOLDER && \
    python3 -m trainer.task_tf2_4 \
        --output_dir=$BASE_OUTPUT_DIR/model \
        --model_name="$MODEL_NAME"_local \
        --model_version=$MODEL_VERSION

INFO:absl:Starting the local TPU driver.
INFO:absl:Unable to initialize backend 'tpu_driver': Not found: Unable to find driver in registry given worker: local://
INFO:absl:Unable to initialize backend 'tpu': Invalid argument: TpuPlatform is not available.
INFO:root:mnist_flax: Epoch 0 in 5.81 sec
INFO:root:mnist_flax: Training set accuracy 88.60%
INFO:root:mnist_flax: Test set accuracy 89.03%
INFO:root:mnist_flax: Epoch 1 in 1.11 sec
INFO:root:mnist_flax: Training set accuracy 90.70%
INFO:root:mnist_flax: Test set accuracy 91.05%


In [6]:
# List (with sizes and timestamps) the files under the output path that was
# passed to the local training run.
!gsutil ls -l $BASE_OUTPUT_DIR/model/"$MODEL_NAME"_local/$MODEL_VERSION

         0  2021-06-29T23:39:16Z  gs://dsparing-sandbox/model/jax_model_prebuilt_local/1/
     50304  2021-07-01T14:34:20Z  gs://dsparing-sandbox/model/jax_model_prebuilt_local/1/saved_model.pb
                                 gs://dsparing-sandbox/model/jax_model_prebuilt_local/1/assets/
                                 gs://dsparing-sandbox/model/jax_model_prebuilt_local/1/variables/
TOTAL: 2 objects, 50304 bytes (49.12 KiB)


Ideally, we would use [CustomTrainingJob](https://googleapis.dev/python/aiplatform/latest/aiplatform.html#google.cloud.aiplatform.CustomTrainingJob) (which reads a local Python package) or [CustomPythonPackageTrainingJob](https://googleapis.dev/python/aiplatform/latest/aiplatform.html#google.cloud.aiplatform.CustomPythonPackageTrainingJob) from the high-level API, which would not only train the model on Vertex AI but also upload it as a Vertex AI model. However, we need to pass `jaxlib` as a second Python package and the high-level API expects only one, so we use [create_custom_job](https://googleapis.dev/python/aiplatform/latest/aiplatform_v1/job_service.html?#google.cloud.aiplatform_v1.services.job_service.JobServiceClient.create_custom_job) from the low-level API instead. For this, we first need to create a source distribution and upload it to Cloud Storage.

## Create a source distribution and upload to Cloud Storage

In [7]:
# Print the packaging script of the training application; it is used to
# build the source distribution in the next step.
!cat $TRAINING_APP_FOLDER/setup.py

# Copyright 2021 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from setuptools import find_packages, setup

REQUIRED_PACKAGES = ["flax", "jax"]

setup(
    name="jax_flax_trainer",
    version="0.1",
    install_requires=REQUIRED_PACKAGES,
    packages=find_packages(),
    include_package_data=True,
    description="JAX/FLAX model training application.",
)


In [8]:
# Build a gzipped-tar source distribution of the training application,
# running setup.py from inside the application folder.
!cd $TRAINING_APP_FOLDER && python ./setup.py sdist --formats=gztar

running sdist
running egg_info
writing jax_flax_trainer.egg-info/PKG-INFO
writing dependency_links to jax_flax_trainer.egg-info/dependency_links.txt
writing requirements to jax_flax_trainer.egg-info/requires.txt
writing top-level names to jax_flax_trainer.egg-info/top_level.txt
reading manifest file 'jax_flax_trainer.egg-info/SOURCES.txt'
writing manifest file 'jax_flax_trainer.egg-info/SOURCES.txt'

running check


creating jax_flax_trainer-0.1
creating jax_flax_trainer-0.1/jax_flax_trainer.egg-info
creating jax_flax_trainer-0.1/trainer
copying files to jax_flax_trainer-0.1...
copying setup.py -> jax_flax_trainer-0.1
copying jax_flax_trainer.egg-info/PKG-INFO -> jax_flax_trainer-0.1/jax_flax_trainer.egg-info
copying jax_flax_trainer.egg-info/SOURCES.txt -> jax_flax_trainer-0.1/jax_flax_trainer.egg-info
copying jax_flax_trainer.egg-info/dependency_links.txt -> jax_flax_trainer-0.1/jax_flax_trainer.egg-info
copying jax_flax_trainer.egg-info/requires.txt -> jax_flax_trainer-0.1/jax_flax_

In [9]:
# Copy the source distribution from the local dist/ folder to the package
# directory in Cloud Storage (-q suppresses gsutil's progress output).
!gsutil -q cp $TRAINING_APP_FOLDER/dist/$MODELPACKAGE_NAME $MODELPACKAGE_DIR/

## Run custom training job with Python package on Vertex AI 

In [10]:
# Python packages handed to the training job: the trainer source
# distribution uploaded above, plus a jaxlib wheel.
JAXLIB_URI = (
    "gs://jax-releases/cuda110/"
    "jaxlib-0.1.67+cuda110-cp37-none-manylinux2010_x86_64.whl"
)
PYTHON_MODULE = "trainer.task_tf2_4"
PYTHON_PACKAGE_URIS = [f"{MODELPACKAGE_DIR}/{MODELPACKAGE_NAME}", JAXLIB_URI]

# Choose the TF2.4 executor image that matches the hardware setting.
TRAINING_IMAGE = (
    "us-docker.pkg.dev/vertex-ai/training/tf-gpu.2-4:latest"
    if USE_GPU
    else "us-docker.pkg.dev/vertex-ai/training/tf-cpu.2-4:latest"
)

# Module to run, the image to run it in, and its command-line arguments.
python_package_spec = dict(
    executor_image_uri=TRAINING_IMAGE,
    package_uris=PYTHON_PACKAGE_URIS,
    python_module=PYTHON_MODULE,
    args=[
        f"--model_name={MODEL_NAME}",
        f"--model_version={MODEL_VERSION}",
    ],
)

In [11]:
# Vertex AI machines to use for training
MACHINE_TYPE = "n1-standard-4"
REPLICA_COUNT = 1

# One T4 for GPU runs; otherwise both accelerator settings stay None.
ACCELERATOR_TYPE, ACCELERATOR_COUNT = (
    ("NVIDIA_TESLA_T4", 1) if USE_GPU else (None, None)
)

# A single worker pool: machine shape, replica count, and the package to run.
worker_pool_spec = dict(
    machine_spec=dict(
        machine_type=MACHINE_TYPE,
        accelerator_type=ACCELERATOR_TYPE,
        accelerator_count=ACCELERATOR_COUNT,
    ),
    replica_count=REPLICA_COUNT,
    python_package_spec=python_package_spec,
)

In [12]:
JOB_NAME = "jax_prebuilt_training"

# Assemble the job definition from the inside out: output location first,
# then the job spec with its single worker pool, then the named job.
base_output_directory = {"output_uri_prefix": BASE_OUTPUT_DIR}
job_spec = {
    "worker_pool_specs": [worker_pool_spec],
    "base_output_directory": base_output_directory,
}
custom_job = {"display_name": JOB_NAME, "job_spec": job_spec}

In [13]:
# Resource path under which the custom job will be created.
parent = f"projects/{PROJECT_ID}/locations/{REGION}"

# Point the job service client at the endpoint of the chosen region.
api_endpoint: str = f"{REGION}-aiplatform.googleapis.com"
client_options = dict(api_endpoint=api_endpoint)
client = aiplatform.gapic.JobServiceClient(client_options=client_options)

In [14]:
# Submit the job definition. The returned resource carries the job name
# that is used afterwards to poll the job state.
response = client.create_custom_job(custom_job=custom_job, parent=parent)
print("response:", response)

response: name: "projects/654544512569/locations/us-central1/customJobs/8507110580102889472"
display_name: "jax_prebuilt_training"
job_spec {
  worker_pool_specs {
    machine_spec {
      machine_type: "n1-standard-4"
      accelerator_type: NVIDIA_TESLA_T4
      accelerator_count: 1
    }
    replica_count: 1
    disk_spec {
      boot_disk_type: "pd-ssd"
      boot_disk_size_gb: 100
    }
    python_package_spec {
      executor_image_uri: "us-docker.pkg.dev/vertex-ai/training/tf-gpu.2-4:latest"
      package_uris: "gs://dsparing-sandbox/trainers/jax_model_prebuilt/jax_flax_trainer-0.1.tar.gz"
      package_uris: "gs://jax-releases/cuda110/jaxlib-0.1.67+cuda110-cp37-none-manylinux2010_x86_64.whl"
      python_module: "trainer.task_tf2_4"
      args: "--model_name=jax_model_prebuilt"
      args: "--model_version=1"
    }
  }
  base_output_directory {
    output_uri_prefix: "gs://dsparing-sandbox"
  }
}
state: JOB_STATE_PENDING
create_time {
  seconds: 1625150066
  nanos: 223148000
}


In [15]:
# States in which the job has not finished yet; keep polling while it is in
# any of them.
ACTIVE_JOB_STATES = (
    aiplatform_v1.JobState.JOB_STATE_QUEUED,
    aiplatform_v1.JobState.JOB_STATE_PENDING,
    aiplatform_v1.JobState.JOB_STATE_RUNNING,
)
POLL_INTERVAL_SECONDS = 30

# Poll until the job leaves the active states. The state is printed only
# when it changes, so a long wait does not flood the cell output.
job_state = None
while True:
    previous_state = job_state
    job_state = client.get_custom_job(name=response.name).state
    if job_state != previous_state:
        print(job_state)
    if job_state not in ACTIVE_JOB_STATES:
        break
    time.sleep(POLL_INTERVAL_SECONDS)

# Stop here unless training succeeded: the following cells load the model
# from Cloud Storage and would otherwise silently pick up a stale model
# left behind by an earlier run.
if job_state != aiplatform_v1.JobState.JOB_STATE_SUCCEEDED:
    raise RuntimeError(
        f"Custom job {response.name} ended in state {job_state!r}; "
        "not continuing with a possibly stale model."
    )

JobState.JOB_STATE_PENDING
JobState.JOB_STATE_PENDING
JobState.JOB_STATE_PENDING
JobState.JOB_STATE_PENDING
JobState.JOB_STATE_PENDING
JobState.JOB_STATE_PENDING
JobState.JOB_STATE_PENDING
JobState.JOB_STATE_PENDING
JobState.JOB_STATE_PENDING
JobState.JOB_STATE_PENDING
JobState.JOB_STATE_PENDING
JobState.JOB_STATE_PENDING
JobState.JOB_STATE_PENDING
JobState.JOB_STATE_PENDING
JobState.JOB_STATE_PENDING
JobState.JOB_STATE_PENDING
JobState.JOB_STATE_PENDING
JobState.JOB_STATE_PENDING
JobState.JOB_STATE_PENDING
JobState.JOB_STATE_PENDING
JobState.JOB_STATE_PENDING
JobState.JOB_STATE_PENDING
JobState.JOB_STATE_PENDING
JobState.JOB_STATE_PENDING
JobState.JOB_STATE_PENDING
JobState.JOB_STATE_PENDING
JobState.JOB_STATE_PENDING
JobState.JOB_STATE_PENDING
JobState.JOB_STATE_PENDING
JobState.JOB_STATE_PENDING
JobState.JOB_STATE_PENDING
JobState.JOB_STATE_PENDING
JobState.JOB_STATE_RUNNING
JobState.JOB_STATE_RUNNING
JobState.JOB_STATE_SUCCEEDED


## Local prediction with SavedModel

Check the timestamps below to verify that the model was newly written:

In [16]:
# List (with sizes and timestamps) the files under the model path that is
# loaded for prediction further below.
!gsutil ls -l $BASE_OUTPUT_DIR/model/$MODEL_NAME/$MODEL_VERSION

         0  2021-06-30T04:49:45Z  gs://dsparing-sandbox/model/jax_model_prebuilt/1/
     49753  2021-07-01T14:50:32Z  gs://dsparing-sandbox/model/jax_model_prebuilt/1/saved_model.pb
                                 gs://dsparing-sandbox/model/jax_model_prebuilt/1/assets/
                                 gs://dsparing-sandbox/model/jax_model_prebuilt/1/variables/
TOTAL: 2 objects, 49753 bytes (48.59 KiB)


In [17]:
# absl flags have to be parsed before load_mnist is called, otherwise it
# raises errors; parse an argv that holds only a placeholder program name.
flags.FLAGS([""])

# Take the first batch of test images; the labels are not needed here.
test_batches = mnist_lib.load_mnist(
    tfds.Split.TEST, batch_size=SERVING_BATCH_SIZE
)
images_to_predict, _ = next(iter(test_batches))

INFO:absl:Load dataset info from /home/jupyter/tensorflow_datasets/mnist/3.0.1
INFO:absl:Field info.citation from disk and from code do not match. Keeping the one from code.
INFO:absl:Field info.splits from disk and from code do not match. Keeping the one from code.
INFO:absl:Field info.module_name from disk and from code do not match. Keeping the one from code.
INFO:absl:Reusing dataset mnist (/home/jupyter/tensorflow_datasets/mnist/3.0.1)
INFO:absl:Constructing tf.data.Dataset mnist for split test, from /home/jupyter/tensorflow_datasets/mnist/3.0.1


In [18]:
# Load the SavedModel from Cloud Storage and run its default serving
# signature on the sample batch; the result is shown as the cell output.
saved_model_dir = f"{BASE_OUTPUT_DIR}/model/{MODEL_NAME}/{MODEL_VERSION}"
loaded_model = tf.saved_model.load(saved_model_dir)
serving_fn = loaded_model.signatures["serving_default"]
serving_fn(images_to_predict)

{'output_0': <tf.Tensor: shape=(3, 10), dtype=float32, numpy=
 array([[-1.0069501e+01, -2.3111656e+01, -1.1957015e+01, -9.8489275e+00,
         -1.8814577e+01, -1.1980792e+01, -2.1702503e+01, -4.6480820e-04,
         -1.0392562e+01, -8.0278521e+00],
        [-5.2159538e+00, -8.7340212e+00, -4.7790408e+00, -4.5998144e+00,
         -4.3836346e+00, -3.2001929e+00, -6.3972487e+00, -7.1097474e+00,
         -1.1168523e-01, -3.6531754e+00],
        [-1.7394794e+01, -1.6356308e+01, -1.0075261e-03, -7.8558688e+00,
         -2.2520752e+01, -1.2708941e+01, -1.6958504e+01, -2.1539352e+01,
         -7.3916960e+00, -2.0681179e+01]], dtype=float32)>}