In [1]:
# Copyright 2021 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# Train JAX/Flax model on Vertex AI custom container and use `jax2tf` to convert to SavedModel

In [2]:
import os

import tensorflow as tf
import tensorflow_datasets as tfds
from absl import flags
from google.cloud import aiplatform
from jax.experimental.jax2tf.examples import mnist_lib

In [3]:
PROJECT_ID = !(gcloud config get-value core/project)
PROJECT_ID = PROJECT_ID[0]

REGION = "us-central1"
os.environ['REGION'] = REGION

BUCKET_NAME = PROJECT_ID
os.environ['BUCKET_NAME'] = BUCKET_NAME
# Use a regional bucket in the above region you have rights to.
# Create if needed:
# !gsutil mb -l ${REGION} gs://${BUCKET_NAME}

USE_GPU = True

TRAINING_APP_FOLDER = 'training_app'
os.environ['TRAINING_APP_FOLDER'] = TRAINING_APP_FOLDER

MODEL_BASE_PATH = f"gs://{BUCKET_NAME}/savedmodels"
MODEL_NAME = "jax_model_customcontainer"
MODEL_VERSION = 1
os.environ['MODEL_NAME'] = MODEL_NAME

# Block TF from the GPU to let JAX use it all
tf.config.set_visible_devices([], 'GPU')

In [4]:
!cat $TRAINING_APP_FOLDER/trainer/task.py

# Copyright 2021 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse
import logging
import os

import tensorflow as tf
import tensorflow_datasets as tfds
from absl import flags
from jax.experimental.jax2tf.examples import mnist_lib
from jax.experimental.jax2tf.examples import saved_model_lib 

TRAIN_BATCH_SIZE = 128
TEST_BATCH_SIZE = 16
NUM_EPOCHS = 2

# Block TF from the GPU to let JAX use it all
tf.config.set_visible_devices([], 'GPU')

logger = logging.getLogger()

# nee

In principle we should be able to use [CustomContainerTrainingJob](https://googleapis.dev/python/aiplatform/latest/aiplatform.html#google.cloud.aiplatform.CustomContainerTrainingJob), but it currently gives an error; the similar [CustomTrainingJob.run](https://googleapis.dev/python/aiplatform/latest/aiplatform.html#google.cloud.aiplatform.CustomTrainingJob.run) also currently gives an error, even when using the [official notebook](https://github.com/GoogleCloudPlatform/ai-platform-samples/blob/master/ai-platform-unified/notebooks/official/custom/sdk-custom-image-classification-online.ipynb).

In [5]:
# Pick the (base image, training image name) pair matching the USE_GPU setting.
BASE_TRAINING_IMAGE, TRAINING_IMAGE_NAME = (
    ('gcr.io/deeplearning-platform-release/tf2-gpu.2-5', 'jax_vertex_training_gpu')
    if USE_GPU
    else ('gcr.io/deeplearning-platform-release/tf2-cpu.2-5', 'jax_vertex_training_cpu')
)

# Name of the prediction image built later in the notebook.
SERVING_IMAGE_NAME = 'jax_vertex_prediction'

# Exported so the %%bash cell that writes the Dockerfile can read it.
os.environ['BASE_TRAINING_IMAGE'] = BASE_TRAINING_IMAGE

We write a `requirements.txt` and a `Dockerfile` that defines our custom container based on a [Deep Learning Container image](https://cloud.google.com/deep-learning-containers/docs/choosing-container#container_images), including the `pip install` of the required packages, copy of the model training code, and the Entrypoint launching our training.

In [6]:
!cat $TRAINING_APP_FOLDER/requirements.txt

flax
jax[cuda111]  # needs pip to run with `-f https://storage.googleapis.com/jax-releases/jax_releases.html`


In [7]:
%%bash
# Write the training image's Dockerfile into the training app folder.
# The heredoc delimiter (EOF) is unquoted, so the shell expands
# $BASE_TRAINING_IMAGE while writing the file, and the backslash-newline in the
# RUN instruction is consumed by the shell (the written Dockerfile has that
# instruction on a single line).
cat > $TRAINING_APP_FOLDER/Dockerfile << EOF
FROM $BASE_TRAINING_IMAGE

COPY requirements.txt .
RUN python3 -m pip install -r requirements.txt \
    -f https://storage.googleapis.com/jax-releases/jax_releases.html 

WORKDIR /app
COPY trainer/task.py .

ENTRYPOINT ["python", "task.py"]
EOF

In [8]:
# Fully qualified Container Registry URI of the training image.
IMAGE_TAG = 'latest'
TRAINING_IMAGE_URI = f'gcr.io/{PROJECT_ID}/{TRAINING_IMAGE_NAME}:{IMAGE_TAG}'
os.environ['TRAINING_IMAGE_URI'] = TRAINING_IMAGE_URI

In [9]:
!cd $TRAINING_APP_FOLDER && \
    gcloud builds submit --tag $TRAINING_IMAGE_URI --timeout="30m"

## Test training container locally

In [10]:
!docker pull $TRAINING_IMAGE_URI 
!docker run \
    --name training_jax \
    --runtime nvidia \
    $TRAINING_IMAGE_URI \
    --output_dir=$MODEL_BASE_PATH/"$MODEL_NAME"_local \
    --model_version=$MODEL_VERSION

latest: Pulling from dsparing-sandbox/jax_vertex_training_gpu
Digest: sha256:66ee236080795e5f5a856feda1dbcf541e955d5d8ee8a5368478b05e81cf8ac6
Status: Image is up to date for gcr.io/dsparing-sandbox/jax_vertex_training_gpu:latest
gcr.io/dsparing-sandbox/jax_vertex_training_gpu:latest
2021-06-28 07:55:27.198710: I tensorflow/stream_executor/platform/default/dso_loader.cc:53] Successfully opened dynamic library libcudart.so.11.0
2021-06-28 07:55:31.007997: I tensorflow/stream_executor/platform/default/dso_loader.cc:53] Successfully opened dynamic library libcuda.so.1
2021-06-28 07:55:31.015923: I tensorflow/stream_executor/cuda/cuda_gpu_executor.cc:937] successful NUMA node read from SysFS had negative value (-1), but there must be at least one NUMA node, so returning NUMA node zero
2021-06-28 07:55:31.016621: I tensorflow/core/common_runtime/gpu/gpu_device.cc:1733] Found device 0 with properties: 
pciBusID: 0000:00:04.0 name: Tesla T4 computeCapability: 7.5
coreClock: 1.59GHz coreCount: 

Once the above container run has finished, remove the container:

In [11]:
# Force-remove the container named `training_jax`.
!docker rm -f training_jax

training_jax


## Create custom container for prediction

When we use `CustomContainerTrainingJob.run` below to submit the training job, we can also specify a prediction container environment. If we do so, the model artifact will not only be saved to Cloud Storage, but will also be uploaded to Vertex AI, ready for batch prediction requests or for deployment to an endpoint for online prediction.

Therefore we specify a custom prediction container now. (If we didn't, we could still call `CustomContainerTrainingJob.run`, but without the `model_serving_container_*` arguments, and the model training job would finish at storing the artifact in Cloud Storage.)

We will simply use the default TensorFlow Serving container image. We still need to build this container, because a Container Registry or Artifact Registry container is expected, so in effect we copy this from Docker Hub.

In [12]:
# Local folder that will hold the serving Dockerfile.
SERVING_FOLDER = 'serving'

# Choose the TensorFlow Serving image tag: GPU build or the default build.
USE_GPU_SERVING = False
TFSERVING_TAG = "latest-gpu" if USE_GPU_SERVING else "latest"

# Export both values for the %%bash cell that writes the serving Dockerfile.
os.environ.update(SERVING_FOLDER=SERVING_FOLDER, TFSERVING_TAG=TFSERVING_TAG)

In [13]:
%%bash
# Create the serving folder if needed and write a one-line Dockerfile that
# re-bases on the TensorFlow Serving image. The heredoc delimiter is unquoted,
# so $TFSERVING_TAG is expanded by the shell when the file is written.
mkdir -p $SERVING_FOLDER
cat > $SERVING_FOLDER/Dockerfile << EOF
FROM tensorflow/serving:$TFSERVING_TAG
EOF

In [14]:
# Fully qualified Container Registry URI of the serving image.
IMAGE_TAG = 'latest'
SERVING_IMAGE_URI = f'gcr.io/{PROJECT_ID}/{SERVING_IMAGE_NAME}:{IMAGE_TAG}'
os.environ['SERVING_IMAGE_URI'] = SERVING_IMAGE_URI

In [15]:
!cd serving && gcloud builds submit --tag $SERVING_IMAGE_URI

Creating temporary tarball archive of 1 file(s) totalling 31 bytes before compression.
Uploading tarball of [.] to [gs://dsparing-sandbox_cloudbuild/source/1624866962.12636-9a9c9c5ba19b421dbdcdd55bd1e58da9.tgz]
Created [https://cloudbuild.googleapis.com/v1/projects/dsparing-sandbox/locations/global/builds/77fedb4a-bb0e-48dc-b110-791a9a17277e].
Logs are available at [https://console.cloud.google.com/cloud-build/builds/77fedb4a-bb0e-48dc-b110-791a9a17277e?project=654544512569].
----------------------------- REMOTE BUILD OUTPUT ------------------------------
starting build "77fedb4a-bb0e-48dc-b110-791a9a17277e"

FETCHSOURCE
Fetching storage object: gs://dsparing-sandbox_cloudbuild/source/1624866962.12636-9a9c9c5ba19b421dbdcdd55bd1e58da9.tgz#1624866962353729
Copying gs://dsparing-sandbox_cloudbuild/source/1624866962.12636-9a9c9c5ba19b421dbdcdd55bd1e58da9.tgz#1624866962353729...
/ [1 files][  198.0 B/  198.0 B]                                                
Operation completed over 1 objec

## Run custom training job with custom container on Vertex AI

In theory, the `CustomContainerTrainingJob.run` call below also uploads the Model to Vertex AI. We will ignore this and only use this method to store the SavedModel on GCS, because we cannot use a pre-built container

(Vertex AI Prediction supports pre-built containers, with no additional customization such as args ("Do not specify any other subfields of containerSpec", [source](https://cloud.google.com/vertex-ai/docs/predictions/pre-built-containers#using_a_pre-built_container)), and custom containers. As we need the `--xla_cpu_compilation_enabled` arg, we can only use custom containers.)

and because we don't have a custom prediction container on gcr.io yet. (If we had one, we could use that, but we can't directly use Docker Hub, where `tensorflow/serving` is hosted.)

In [16]:
JOB_NAME = 'jax_customcontainer_training'

# Define the training job: it runs the custom training image and registers the
# serving image, routes and args for the resulting model.
# `project` and `location` are passed explicitly so that the PROJECT_ID and
# REGION settings configured at the top of the notebook actually take effect
# (the staging bucket is expected to live in that region).
job = aiplatform.CustomContainerTrainingJob(
    display_name=JOB_NAME,
    container_uri=TRAINING_IMAGE_URI,
    model_serving_container_image_uri=SERVING_IMAGE_URI,
    # Both routes address a served model named `model`.
    model_serving_container_predict_route="/v1/models/model:predict",
    model_serving_container_health_route="/v1/models/model",
    model_serving_container_args=[
        '--xla_cpu_compilation_enabled=true',
        # NOTE(review): $(AIP_STORAGE_URI) is presumably substituted by
        # Vertex AI when the serving container starts -- confirm.
        '--model_base_path=$(AIP_STORAGE_URI)',
    ],
    project=PROJECT_ID,
    location=REGION,
    staging_bucket=f"gs://{BUCKET_NAME}",
)

In [17]:
REPLICA_COUNT = 1

# Accelerator settings follow the USE_GPU switch.
if USE_GPU:
    ACCELERATOR_TYPE = "NVIDIA_TESLA_T4"
    ACCELERATOR_COUNT = 1
else:
    ACCELERATOR_TYPE = "ACCELERATOR_TYPE_UNSPECIFIED"
    # `accelerator_count` is an int parameter (default 0); pass 0, not None.
    ACCELERATOR_COUNT = 0

# Submit the training job; artifacts go under <MODEL_BASE_PATH>/<MODEL_NAME>.
model = job.run(
    model_display_name=MODEL_NAME,
    base_output_dir=os.path.join(MODEL_BASE_PATH, MODEL_NAME),
    replica_count=REPLICA_COUNT,
    accelerator_count=ACCELERATOR_COUNT,
    accelerator_type=ACCELERATOR_TYPE,
)


INFO:google.cloud.aiplatform.training_jobs:Training Output directory:
gs://dsparing-sandbox/savedmodels/jax_model_customcontainer 
INFO:google.cloud.aiplatform.training_jobs:View Training:
https://console.cloud.google.com/ai/platform/locations/us-central1/training/5787569723868708864?project=654544512569
INFO:google.cloud.aiplatform.training_jobs:View backing custom job:
https://console.cloud.google.com/ai/platform/locations/us-central1/training/66168609759559680?project=654544512569
INFO:google.cloud.aiplatform.training_jobs:CustomContainerTrainingJob projects/654544512569/locations/us-central1/trainingPipelines/5787569723868708864 current state:
PipelineState.PIPELINE_STATE_RUNNING
INFO:google.cloud.aiplatform.training_jobs:CustomContainerTrainingJob projects/654544512569/locations/us-central1/trainingPipelines/5787569723868708864 current state:
PipelineState.PIPELINE_STATE_RUNNING
INFO:google.cloud.aiplatform.training_jobs:CustomContainerTrainingJob projects/654544512569/locations/u

## Local prediction with SavedModel

In [18]:
!gsutil ls -l $MODEL_BASE_PATH/$MODEL_NAME/model/$MODEL_VERSION

         0  2021-06-28T08:10:22Z  gs://dsparing-sandbox/savedmodels/jax_model_customcontainer/model/1/
     53991  2021-06-28T08:10:24Z  gs://dsparing-sandbox/savedmodels/jax_model_customcontainer/model/1/saved_model.pb
                                 gs://dsparing-sandbox/savedmodels/jax_model_customcontainer/model/1/assets/
                                 gs://dsparing-sandbox/savedmodels/jax_model_customcontainer/model/1/variables/
TOTAL: 2 objects, 53991 bytes (52.73 KiB)


In [19]:
# Parse absl flags with an argv that holds only an empty program name;
# without initialized flags, `load_mnist` raises errors.
flags.FLAGS([''])

# Take the first single-image batch of the MNIST test split; the label is unused.
mnist_test_batches = mnist_lib.load_mnist(tfds.Split.TEST, batch_size=1)
image_to_predict, _ = next(iter(mnist_test_batches))

INFO:absl:Load dataset info from /home/jupyter/tensorflow_datasets/mnist/3.0.1
INFO:absl:Field info.citation from disk and from code do not match. Keeping the one from code.
INFO:absl:Field info.splits from disk and from code do not match. Keeping the one from code.
INFO:absl:Field info.module_name from disk and from code do not match. Keeping the one from code.
INFO:absl:Reusing dataset mnist (/home/jupyter/tensorflow_datasets/mnist/3.0.1)
INFO:absl:Constructing tf.data.Dataset mnist for split test, from /home/jupyter/tensorflow_datasets/mnist/3.0.1


In [20]:
# Load the exported SavedModel straight from Cloud Storage.
saved_model_dir = f"{MODEL_BASE_PATH}/{MODEL_NAME}/model/{MODEL_VERSION}"
loaded_model = tf.saved_model.load(saved_model_dir)

# Run the default serving signature on the sample image and display the result.
serving_fn = loaded_model.signatures["serving_default"]
serving_fn(image_to_predict)

{'output_0': <tf.Tensor: shape=(1, 10), dtype=float32, numpy=
 array([[-1.2814110e+01, -8.4742651e+00, -4.9168525e+00, -1.0399423e+01,
         -8.9546528e+00, -9.5392723e+00, -8.5732741e-03, -1.7909704e+01,
         -7.1681709e+00, -1.4106701e+01]], dtype=float32)>}