In [1]:
# Copyright 2021 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# Train JAX/Flax model on Vertex AI custom container and use `jax2tf` to convert to SavedModel

In [2]:
import os

import tensorflow as tf
import tensorflow_datasets as tfds
from absl import flags
from google.cloud import aiplatform
from jax.experimental.jax2tf.examples import mnist_lib

In [3]:
PROJECT_ID = !(gcloud config get-value project)
PROJECT_ID = PROJECT_ID[0]

REGION = "us-central1"

BUCKET_NAME = PROJECT_ID
# Use a regional bucket in the above region you have rights to.
# Create if needed:
# !gsutil mb -l $REGION gs://$BUCKET_NAME

USE_GPU = True

TRAINING_APP_FOLDER = 'training_app'

BASE_OUTPUT_DIR = f"gs://{BUCKET_NAME}"
MODEL_NAME = "jax_model_customcontainer"
MODEL_VERSION = 1

SERVING_BATCH_SIZE = 3

# Block TF from the GPU to let JAX use it all
tf.config.set_visible_devices([], 'GPU')

In [4]:
# Show the training script that the Dockerfile copies into the custom container.
!cat $TRAINING_APP_FOLDER/trainer/task.py

# Copyright 2021 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import argparse
import logging
import os

import tensorflow as tf
import tensorflow_datasets as tfds
from absl import flags
from jax.experimental.jax2tf.examples import mnist_lib, saved_model_lib

TRAIN_BATCH_SIZE = 128
TEST_BATCH_SIZE = 16
NUM_EPOCHS = 2

# Block TF from the GPU to let JAX use it all
tf.config.set_visible_devices([], "GPU")

logger = logging.getLogger()

# need to initialize flags somehow to avoid error

In [5]:
# Pick the Deep Learning Container base image and the name of our own training
# image according to whether we train on GPU or CPU.
BASE_TRAINING_IMAGE, TRAINING_IMAGE_NAME = (
    ("gcr.io/deeplearning-platform-release/tf2-gpu.2-5", "jax_vertex_training_gpu")
    if USE_GPU
    else ("gcr.io/deeplearning-platform-release/tf2-cpu.2-5", "jax_vertex_training_cpu")
)

We write a `requirements.txt` and a `Dockerfile` that together define our custom container, based on a [Deep Learning Container image](https://cloud.google.com/deep-learning-containers/docs/choosing-container#container_images). The Dockerfile runs the `pip install` of the required packages, copies the model training code, and sets the entrypoint that launches our training.

In [6]:
# Show the Python packages that are pip-installed on top of the base image.
!cat $TRAINING_APP_FOLDER/requirements.txt

flax
jax[cuda111]  # needs pip to run with `-f https://storage.googleapis.com/jax-releases/jax_releases.html`


In [7]:
# Export the notebook settings that the following %%bash cell reads as shell variables.
os.environ.update(
    {
        "TRAINING_APP_FOLDER": TRAINING_APP_FOLDER,
        "BASE_TRAINING_IMAGE": BASE_TRAINING_IMAGE,
    }
)

In [8]:
%%bash
# Write the Dockerfile of the custom training container.
#
# Fail fast: without `set -u`, an unset TRAINING_APP_FOLDER would redirect the
# output to /Dockerfile and an unset BASE_TRAINING_IMAGE would produce an empty
# FROM line, both without any error.
set -euo pipefail

# The heredoc delimiter is deliberately left unquoted so that bash expands
# $BASE_TRAINING_IMAGE while writing the file.
cat > "$TRAINING_APP_FOLDER/Dockerfile" << EOF
FROM $BASE_TRAINING_IMAGE

COPY requirements.txt .
RUN python3 -m pip install -r requirements.txt \
    -f https://storage.googleapis.com/jax-releases/jax_releases.html

WORKDIR /app
COPY trainer/task.py .

ENTRYPOINT ["python", "task.py"]
EOF

In [9]:
# Container Registry URI under which the training image is built and later pulled.
TRAINING_IMAGE_URI = f"gcr.io/{PROJECT_ID}/{TRAINING_IMAGE_NAME}"

In [10]:
# Build the training image with Cloud Build from inside the training app folder
# and tag it as $TRAINING_IMAGE_URI; the build may take up to 30 minutes.
!cd $TRAINING_APP_FOLDER && \
    gcloud builds submit --tag $TRAINING_IMAGE_URI --timeout="30m"

Creating temporary tarball archive of 22 file(s) totalling 41.8 KiB before compression.
Uploading tarball of [.] to [gs://dsparing-sandbox_cloudbuild/source/1625151978.12105-0ef148ea8fa14c8487dfa9d05bf08444.tgz]
Created [https://cloudbuild.googleapis.com/v1/projects/dsparing-sandbox/locations/global/builds/2aacee55-6e6a-4ace-8137-ccc434dbb314].
Logs are available at [https://console.cloud.google.com/cloud-build/builds/2aacee55-6e6a-4ace-8137-ccc434dbb314?project=654544512569].
----------------------------- REMOTE BUILD OUTPUT ------------------------------
starting build "2aacee55-6e6a-4ace-8137-ccc434dbb314"

FETCHSOURCE
Fetching storage object: gs://dsparing-sandbox_cloudbuild/source/1625151978.12105-0ef148ea8fa14c8487dfa9d05bf08444.tgz#1625151978338636
Copying gs://dsparing-sandbox_cloudbuild/source/1625151978.12105-0ef148ea8fa14c8487dfa9d05bf08444.tgz#1625151978338636...
/ [1 files][ 13.4 KiB/ 13.4 KiB]                                                
Operation completed over 1 obje

## Test training container locally

In [11]:
!docker pull $TRAINING_IMAGE_URI 
!docker run \
    --name training_jax \
    --runtime nvidia \
    $TRAINING_IMAGE_URI \
    --output_dir=$BASE_OUTPUT_DIR/model \
    --model_name="$MODEL_NAME"_local \
    --model_version=$MODEL_VERSION

Using default tag: latest
latest: Pulling from dsparing-sandbox/jax_vertex_training_gpu

[1Bd2c87b75: Already exists 
[1B10be24e1: Already exists 
[1B7173dcfe: Already exists 
[1B8de7822d: Already exists 
[1B4ac0274d: Already exists 
[1Bb86d08de: Already exists 
[1B019dd5e8: Already exists 
[1B73e465ef: Already exists 
[1B630baacd: Already exists 
[1B86c72f57: Already exists 
[1B6fce16a1: Already exists 
[1Bc64e20d2: Already exists 
[1B12f3cce5: Already exists 
[1B6a369ea4: Already exists 
[1B2ea143ea: Already exists 
[1B5fa6733c: Already exists 
[1B4adad992: Already exists 
[1Bb56a4779: Already exists 
[1B7e5e0af5: Already exists 
[1Bd9bf08cb: Already exists 
[1B0834967b: Already exists 
[1Bfb29e345: Already exists 
[1Bec7e36f6: Already exists 
[1Bf0ba3fb3: Already exists 
[1B12e657e4: Already exists 
[1Bfad557e1: Already exists 
[1B293fd93e: Already exists 
[1B8ef0086e: Already exists 
[1Be8557bb1: Already exists 
[1Bdf11e45e: Already exists 
[1Bae24303

Once the above container run has finished, remove the container:

In [12]:
# Force-remove the local test container so that the name `training_jax` is free for the next run.
!docker rm -f training_jax

training_jax


## Create custom container for prediction

When we use `CustomContainerTrainingJob.run` below to submit the training job, we can specify a prediction container environment as well. If we do so, the model artifact will not only be saved to Cloud Storage, but will also be uploaded to Vertex AI, ready for batch prediction requests or for deployment to an endpoint for online prediction.

Therefore we specify a custom prediction container now. (If we didn't, we could still call `CustomContainerTrainingJob.run`, but without the `model_serving_container_*` arguments, and the model training job would finish at storing the artifact in Cloud Storage.)

We will simply use the default TensorFlow Serving container image. We still need to build this container, because a Container Registry or Artifact Registry container is expected, so in effect we copy this from Docker Hub.

In [13]:
# Local build folder and Container Registry location of the serving image.
SERVING_FOLDER = "serving"
SERVING_IMAGE_NAME = "tensorflow-serving"
SERVING_IMAGE_URI = f"gcr.io/{PROJECT_ID}/{SERVING_IMAGE_NAME}"

# Choose the upstream tensorflow/serving tag that matches the serving hardware.
USE_GPU_SERVING = False
TFSERVING_TAG = "latest-gpu" if USE_GPU_SERVING else "latest"

In [14]:
# Export the serving settings that the following %%bash cell reads as shell variables.
os.environ.update(
    {
        "SERVING_FOLDER": SERVING_FOLDER,
        "TFSERVING_TAG": TFSERVING_TAG,
    }
)

In [15]:
%%bash
# Write a one-line Dockerfile that re-publishes the upstream TensorFlow Serving
# image, so that it can be built into our own registry.
#
# Fail fast: without `set -u`, an unset SERVING_FOLDER would redirect the
# output to /Dockerfile and an unset TFSERVING_TAG would produce an invalid
# image reference, both without any error.
set -euo pipefail

mkdir -p "$SERVING_FOLDER"
# The heredoc delimiter is deliberately left unquoted so that bash expands
# $TFSERVING_TAG while writing the file.
cat > "$SERVING_FOLDER/Dockerfile" << EOF
FROM tensorflow/serving:$TFSERVING_TAG
EOF

We only build the container if it is not already available.

In [16]:
!cd $SERVING_FOLDER && docker pull $SERVING_IMAGE_URI || gcloud builds submit --tag $SERVING_IMAGE_URI

Using default tag: latest
latest: Pulling from dsparing-sandbox/tensorflow-serving
Digest: sha256:6651f4839e1124dbde75ee531825112af0a6b8ef082c88ab14ca53eb69a2e4bb
Status: Image is up to date for gcr.io/dsparing-sandbox/tensorflow-serving:latest
gcr.io/dsparing-sandbox/tensorflow-serving:latest


## Run custom training job with custom container on Vertex AI

The below CustomContainerTrainingJob.run in theory also uploads the Model to Vertex AI. We will ignore this and only use this method to store the SavedModel on GCS, because we cannot use a pre-built container

(Vertex AI Prediction supports pre-built containers, with no additional customization such as args ("Do not specify any other subfields of containerSpec", [source](https://cloud.google.com/vertex-ai/docs/predictions/pre-built-containers#using_a_pre-built_container)), and custom containers. As we need the `--xla_cpu_compilation_enabled` arg, we can only use custom containers.)

and because we don't have a custom prediction container on gcr.io yet. (If we had one, we could use that, but we can't directly use Docker Hub, where tensorflow/serving is hosted.)

In [17]:
JOB_NAME = "jax_customcontainer_training"

# Arguments handed to the serving container when the uploaded model is served.
serving_container_args = [
    "--xla_cpu_compilation_enabled=true",
    f"--model_name={MODEL_NAME}",
    f"--model_base_path=$(AIP_STORAGE_URI)/{MODEL_NAME}",
]

# Describe the training container together with the serving container that is
# attached to the resulting model.
training_job_config = dict(
    display_name=JOB_NAME,
    container_uri=TRAINING_IMAGE_URI,
    model_serving_container_image_uri=SERVING_IMAGE_URI,
    model_serving_container_predict_route=f"/v1/models/{MODEL_NAME}:predict",
    model_serving_container_health_route=f"/v1/models/{MODEL_NAME}",
    model_serving_container_args=serving_container_args,
    model_serving_container_ports=[8501],
    staging_bucket=f"gs://{BUCKET_NAME}",
)

job = aiplatform.CustomContainerTrainingJob(**training_job_config)

In [18]:
# Number of worker replicas used for the training job.
REPLICA_COUNT = 1

# Attach one T4 GPU when training on GPU. For CPU-only training use the SDK's
# own "no accelerator" values; the count is an integer parameter, so pass 0
# rather than None.
if USE_GPU:
    ACCELERATOR_TYPE = "NVIDIA_TESLA_T4"
    ACCELERATOR_COUNT = 1
else:
    ACCELERATOR_TYPE = "ACCELERATOR_TYPE_UNSPECIFIED"
    ACCELERATOR_COUNT = 0

# Submit the training job. `args` are forwarded to the training container's
# entrypoint; the returned object is the model registered under MODEL_NAME.
model = job.run(
    model_display_name=MODEL_NAME,
    base_output_dir=BASE_OUTPUT_DIR,
    args=[
        f"--model_name={MODEL_NAME}",
        f"--model_version={MODEL_VERSION}",
    ],
    replica_count=REPLICA_COUNT,
    accelerator_count=ACCELERATOR_COUNT,
    accelerator_type=ACCELERATOR_TYPE,
)

print(model.display_name, model.resource_name)

INFO:google.cloud.aiplatform.training_jobs:Training Output directory:
gs://dsparing-sandbox 
INFO:google.cloud.aiplatform.training_jobs:View Training:
https://console.cloud.google.com/ai/platform/locations/us-central1/training/1929603324328280064?project=654544512569
INFO:google.cloud.aiplatform.training_jobs:CustomContainerTrainingJob projects/654544512569/locations/us-central1/trainingPipelines/1929603324328280064 current state:
PipelineState.PIPELINE_STATE_RUNNING
INFO:google.cloud.aiplatform.training_jobs:View backing custom job:
https://console.cloud.google.com/ai/platform/locations/us-central1/training/5952514060221153280?project=654544512569
INFO:google.cloud.aiplatform.training_jobs:CustomContainerTrainingJob projects/654544512569/locations/us-central1/trainingPipelines/1929603324328280064 current state:
PipelineState.PIPELINE_STATE_RUNNING
INFO:google.cloud.aiplatform.training_jobs:CustomContainerTrainingJob projects/654544512569/locations/us-central1/trainingPipelines/1929603

## Local prediction with SavedModel

In [19]:
# List the SavedModel files stored in Cloud Storage for this model name and version.
!gsutil ls -l $BASE_OUTPUT_DIR/model/$MODEL_NAME/$MODEL_VERSION

         0  2021-06-30T04:57:38Z  gs://dsparing-sandbox/model/jax_model_customcontainer/1/
     59519  2021-07-01T15:30:58Z  gs://dsparing-sandbox/model/jax_model_customcontainer/1/saved_model.pb
                                 gs://dsparing-sandbox/model/jax_model_customcontainer/1/assets/
                                 gs://dsparing-sandbox/model/jax_model_customcontainer/1/variables/
TOTAL: 2 objects, 59519 bytes (58.12 KiB)


In [20]:
# need to initialize flags somehow to avoid errors in load_mnist
flags.FLAGS([""])

images_to_predict, _ = next(
    iter(mnist_lib.load_mnist(tfds.Split.TEST, batch_size=SERVING_BATCH_SIZE))
)

INFO:absl:Load dataset info from /home/jupyter/tensorflow_datasets/mnist/3.0.1
INFO:absl:Field info.citation from disk and from code do not match. Keeping the one from code.
INFO:absl:Field info.splits from disk and from code do not match. Keeping the one from code.
INFO:absl:Field info.module_name from disk and from code do not match. Keeping the one from code.
INFO:absl:Reusing dataset mnist (/home/jupyter/tensorflow_datasets/mnist/3.0.1)
INFO:absl:Constructing tf.data.Dataset mnist for split test, from /home/jupyter/tensorflow_datasets/mnist/3.0.1


In [21]:
# Load the exported SavedModel for this model version straight from Cloud
# Storage and run its default serving signature on the sample images.
saved_model_dir = os.path.join(BASE_OUTPUT_DIR, "model", MODEL_NAME, str(MODEL_VERSION))
loaded_model = tf.saved_model.load(saved_model_dir)
serving_fn = loaded_model.signatures["serving_default"]
serving_fn(images_to_predict)

{'output_0': <tf.Tensor: shape=(3, 10), dtype=float32, numpy=
 array([[-12.268999  , -22.472895  , -14.996583  , -13.254833  ,
          -0.04260941,  -7.159297  ,  -7.789949  ,  -7.701139  ,
          -6.578989  ,  -3.2525737 ],
        [ -7.5400867 , -20.041262  , -15.078344  ,  -9.657801  ,
          -7.344303  ,  -6.194578  , -16.317587  ,  -0.31555295,
          -7.76667   ,  -1.3208492 ],
        [-11.988442  ,  -6.592482  ,  -0.46542993,  -1.0133778 ,
         -13.975661  ,  -9.580245  ,  -5.0168834 , -11.966478  ,
          -6.8450675 , -13.723193  ]], dtype=float32)>}