In [None]:
# Copyright 2021 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# Train JAX/Flax model on Vertex AI custom container and use `jax2tf` to convert to SavedModel

In [1]:
import os
import time

import tensorflow as tf
import tensorflow_datasets as tfds
from absl import flags
from google.cloud import aiplatform
from google.cloud import aiplatform_v1
from jax.experimental.jax2tf.examples.mnist_lib import load_mnist

In [2]:
PROJECT_ID = !(gcloud config get-value core/project)
PROJECT_ID = PROJECT_ID[0]

REGION = "us-central1"
os.environ['REGION'] = REGION

BUCKET_NAME = PROJECT_ID
os.environ['BUCKET_NAME'] = BUCKET_NAME
# Use a regional bucket in the above region you have rights to.
# Create if needed:
# !gsutil mb -l ${REGION} gs://${BUCKET_NAME}

USE_GPU = True

TRAINING_APP_FOLDER = 'training_app'
os.environ['TRAINING_APP_FOLDER'] = TRAINING_APP_FOLDER

MODEL_NAME = "jax_model_customcontainer"
SAVEDMODEL_BASEDIR = f"gs://{BUCKET_NAME}/models/{MODEL_NAME}/output"
os.environ['SAVEDMODEL_BASEDIR'] = SAVEDMODEL_BASEDIR

# Block TF from the GPU to let JAX use it all
tf.config.set_visible_devices([], 'GPU')

In [3]:
%%bash
# Print the training script that the Dockerfile copies into the image.
# The expansion is quoted so the path is never word-split or globbed.
cat "$TRAINING_APP_FOLDER/trainer/task.py"

import argparse
import logging
import os

import tensorflow as tf
import tensorflow_datasets as tfds
from absl import flags
from jax.experimental.jax2tf.examples.mnist_lib import (
    load_mnist, FlaxMNIST
)
from jax.experimental.jax2tf.examples.saved_model_lib import (
    convert_and_save_model
)

TRAIN_BATCH_SIZE = 128
TEST_BATCH_SIZE = 16
NUM_EPOCHS = 2

# Block TF from the GPU to let JAX use it all
tf.config.set_visible_devices([], 'GPU')

logger = logging.getLogger()

# need to initialize flags somehow to avoid errors in load_mnist
flags.FLAGS(['e'])

flax_mnist = FlaxMNIST()

train_ds = load_mnist(tfds.Split.TRAIN, TRAIN_BATCH_SIZE)
test_ds = load_mnist(tfds.Split.TEST, TEST_BATCH_SIZE)

image, _ = next(iter(train_ds))
input_signature = tf.TensorSpec.from_tensor(
    tf.expand_dims(image[0], axis=0)
)


def main(output_dir):
    logger.setLevel(logging.INFO)
    predict_fn, params = flax_mnist.train(
        train_ds=train_ds,
        test_ds=test_ds,
        num_epochs=NUM_EPO

Ideally we would use [CustomContainerTrainingJob](https://googleapis.dev/python/aiplatform/latest/aiplatform.html#google.cloud.aiplatform.CustomContainerTrainingJob), but it currently fails with an error. The similar [CustomTrainingJob.run](https://googleapis.dev/python/aiplatform/latest/aiplatform.html#google.cloud.aiplatform.CustomTrainingJob.run) also currently gives an error, even when run from the [official notebook](https://github.com/GoogleCloudPlatform/ai-platform-samples/blob/master/ai-platform-unified/notebooks/official/custom/sdk-custom-image-classification-online.ipynb).

In [4]:
# Choose the Deep Learning Container base image and the name of the image we
# build on top of it, depending on whether training runs on GPU or CPU.
TRAINING_IMAGE, IMAGE_NAME = (
    ('gcr.io/deeplearning-platform-release/tf2-gpu.2-5', 'jax_vertex_image_gpu')
    if USE_GPU
    else ('gcr.io/deeplearning-platform-release/tf2-cpu.2-5', 'jax_vertex_image_cpu')
)

# Exported so the %%bash cell that writes the Dockerfile can read it.
os.environ['TRAINING_IMAGE'] = TRAINING_IMAGE

We write a `requirements.txt` and a `Dockerfile`. The `Dockerfile` defines our custom container based on a [Deep Learning Container image](https://cloud.google.com/deep-learning-containers/docs/choosing-container#container_images): it `pip install`s the required packages, copies the model training code, and sets the entrypoint that launches our training.

In [5]:
%%bash
# Print the requirements file kept alongside the training app.
# The expansion is quoted so the path is never word-split or globbed.
cat "$TRAINING_APP_FOLDER/requirements.txt"

flax
jax[cuda111]  # needs pip to run with `-f https://storage.googleapis.com/jax-releases/jax_releases.html`


In [6]:
%%bash
# Generate the Dockerfile for the custom training container.
# Because the heredoc delimiter is unquoted, the shell substitutes
# $TRAINING_IMAGE with the exported base image before the file is written.
# Note: the jaxlib wheel is pinned to the cuda110 build whichever base image
# is selected.
# The redirect target is quoted so the path is never word-split.
cat > "$TRAINING_APP_FOLDER/Dockerfile" << EOF
FROM $TRAINING_IMAGE

RUN python -m pip install -U jax jaxlib==0.1.67+cuda110 -f https://storage.googleapis.com/jax-releases/jax_releases.html 
RUN python -m pip install -U flax 

WORKDIR /app
COPY trainer/task.py .

ENTRYPOINT ["python", "task.py"]
EOF

In [7]:
# Fully qualified Container Registry URI: gcr.io/<project>/<image>:<tag>.
IMAGE_TAG = 'latest'
IMAGE_URI = 'gcr.io/{}/{}:{}'.format(
    PROJECT_ID,
    IMAGE_NAME,
    IMAGE_TAG,
)

# Exported for the shell cells that work with the image.
os.environ['IMAGE_URI'] = IMAGE_URI

In [8]:
!docker build -f $TRAINING_APP_FOLDER/Dockerfile \
--tag $IMAGE_URI $TRAINING_APP_FOLDER

Sending build context to Docker daemon  43.01kB
Step 1/6 : FROM gcr.io/deeplearning-platform-release/tf2-gpu.2-5
latest: Pulling from deeplearning-platform-release/tf2-gpu.2-5

[1Bd2c87b75: Pulling fs layer 
[1B10be24e1: Pulling fs layer 
[1B7173dcfe: Pulling fs layer 
[1B8de7822d: Pulling fs layer 
[1B4ac0274d: Pulling fs layer 
[1Bb86d08de: Pulling fs layer 
[1B019dd5e8: Pulling fs layer 
[1B73e465ef: Pulling fs layer 
[1B630baacd: Pulling fs layer 
[2B630baacd: Waiting fs layer 
[1B6fce16a1: Pulling fs layer 
[1Bc64e20d2: Pulling fs layer 
[1B12f3cce5: Pulling fs layer 
[1B6a369ea4: Pulling fs layer 
[1B2ea143ea: Pulling fs layer 
[1B5fa6733c: Pulling fs layer 
[1B4adad992: Pulling fs layer 
[1Bb56a4779: Pulling fs layer 
[1B7e5e0af5: Pulling fs layer 
[1Bd9bf08cb: Pulling fs layer 
[12B6c72f57: Waiting fs layer 
[1Bfb29e345: Pulling fs layer 
[1Bec7e36f6: Pulling fs layer 
[1Bf0ba3fb3: Pulling fs layer 
[1B12e657e4: Pulling fs layer 
[1Bfad557e1: Pulling f

## Test training container locally

In [9]:
!docker run --name training_jax --runtime nvidia $IMAGE_URI --output_dir=$SAVEDMODEL_BASEDIR/localmodel

2021-06-26 23:07:26.373098: I tensorflow/stream_executor/platform/default/dso_loader.cc:53] Successfully opened dynamic library libcudart.so.11.0
2021-06-26 23:07:29.869490: I tensorflow/stream_executor/platform/default/dso_loader.cc:53] Successfully opened dynamic library libcuda.so.1
2021-06-26 23:07:29.877583: I tensorflow/stream_executor/cuda/cuda_gpu_executor.cc:937] successful NUMA node read from SysFS had negative value (-1), but there must be at least one NUMA node, so returning NUMA node zero
2021-06-26 23:07:29.878335: I tensorflow/core/common_runtime/gpu/gpu_device.cc:1733] Found device 0 with properties: 
pciBusID: 0000:00:04.0 name: Tesla T4 computeCapability: 7.5
coreClock: 1.59GHz coreCount: 40 deviceMemorySize: 14.75GiB deviceMemoryBandwidth: 298.08GiB/s
2021-06-26 23:07:29.878394: I tensorflow/stream_executor/platform/default/dso_loader.cc:53] Successfully opened dynamic library libcudart.so.11.0
2021-06-26 23:07:30.124762: I tensorflow/stream_executor/platform/default

Once the container run above has finished, remove the container:

In [10]:
%%bash
# Force-remove the named container left behind by the local test run.
docker rm --force training_jax

training_jax


## Push image to Container Registry

In [11]:
%%bash
# Push the training image to Container Registry.
# The expansion is quoted so the URI is never word-split or globbed.
docker push "$IMAGE_URI"

The push refers to repository [gcr.io/dsparing-sandbox/jax_vertex_image_gpu]
4812223ad77a: Preparing
a4d92d39c789: Preparing
86c9e14fe193: Preparing
fda9b92a8d8e: Preparing
03dc40ebdbd6: Preparing
f5d954d2bd94: Preparing
db34a056d495: Preparing
f9b1cb8c2687: Preparing
e68443de6bca: Preparing
f5d954d2bd94: Waiting
db34a056d495: Waiting
f9b1cb8c2687: Waiting
88bb87a4088d: Preparing
29bf522d97b4: Preparing
d96c519f0898: Preparing
ff0a6aeeabc0: Preparing
29bf522d97b4: Waiting
e68443de6bca: Waiting
88bb87a4088d: Waiting
d96c519f0898: Waiting
ff0a6aeeabc0: Waiting
8a7cebfdebb3: Preparing
61546b863e43: Preparing
d2b843fb2f7a: Preparing
36c9a9d68143: Preparing
730e84ac5c5d: Preparing
a25ae1798c0c: Preparing
37a19de06c8b: Preparing
27c459f353b4: Preparing
25d03c11e857: Preparing
d5585264beff: Preparing
8a7cebfdebb3: Waiting
e39414beba01: Preparing
36c9a9d68143: Waiting
ff6af85bc8aa: Preparing
61546b863e43: Waiting
98cedd6c9734: Preparing
730e84ac5c5d: Waiting
d2b843fb2f7a: Waiting
574aa732d388:

## Run custom training job with custom container on Vertex AI

In [None]:
# Container part of the worker pool spec: just the image to run.
container_spec = dict(image_uri=IMAGE_URI)

In [None]:
MACHINE_TYPE = 'n1-standard-4'
REPLICA_COUNT = 1

# One T4 when training on GPU; otherwise both accelerator settings are None.
ACCELERATOR_TYPE, ACCELERATOR_COUNT = (
    ("NVIDIA_TESLA_T4", 1) if USE_GPU else (None, None)
)

# Single worker pool: machine shape, replica count, and the container to run.
worker_pool_spec = dict(
    machine_spec=dict(
        machine_type=MACHINE_TYPE,
        accelerator_type=ACCELERATOR_TYPE,
        accelerator_count=ACCELERATOR_COUNT,
    ),
    replica_count=REPLICA_COUNT,
    container_spec=container_spec,
)

In [None]:
JOB_NAME = 'jax_customcontainer_training'

# Custom job definition: a display name plus a job spec that holds the worker
# pool and the base output directory on GCS.
custom_job = dict(
    display_name=JOB_NAME,
    job_spec=dict(
        worker_pool_specs=[worker_pool_spec],
        base_output_directory=dict(output_uri_prefix=SAVEDMODEL_BASEDIR),
    ),
)

In [None]:
# The job service client is pointed at the regional API endpoint.
api_endpoint: str = f"{REGION}-aiplatform.googleapis.com"
client_options = dict(api_endpoint=api_endpoint)
client = aiplatform.gapic.JobServiceClient(
    client_options=client_options,
)

# Resource name of the project/location under which the job is created.
parent = f"projects/{PROJECT_ID}/locations/{REGION}"

In [12]:
# Submit the custom job and print the returned job resource.
response = client.create_custom_job(
    parent=parent,
    custom_job=custom_job,
)
print("response:", response)

response: name: "projects/654544512569/locations/us-central1/customJobs/3364351649366671360"
display_name: "jax_customcontainer_training"
job_spec {
  worker_pool_specs {
    machine_spec {
      machine_type: "n1-standard-4"
      accelerator_type: NVIDIA_TESLA_T4
      accelerator_count: 1
    }
    replica_count: 1
    disk_spec {
      boot_disk_type: "pd-ssd"
      boot_disk_size_gb: 100
    }
    container_spec {
      image_uri: "gcr.io/dsparing-sandbox/jax_vertex_image_gpu:latest"
    }
  }
  base_output_directory {
    output_uri_prefix: "gs://dsparing-sandbox-bucket/models/jax_model_customcontainer/output"
  }
}
state: JOB_STATE_PENDING
create_time {
  seconds: 1624748937
  nanos: 833874000
}
update_time {
  seconds: 1624748937
  nanos: 833874000
}



In [13]:
# States in which the job has not yet reached a final outcome. CANCELLING is
# included because it is a transitional state, not a final one.
IN_PROGRESS_JOB_STATES = (
    aiplatform_v1.JobState.JOB_STATE_QUEUED,
    aiplatform_v1.JobState.JOB_STATE_PENDING,
    aiplatform_v1.JobState.JOB_STATE_RUNNING,
    aiplatform_v1.JobState.JOB_STATE_CANCELLING,
)

# Poll the job every 30 seconds, printing each observed state, until it
# leaves the in-progress states.
while True:
    job_state = client.get_custom_job(name=response.name).state
    print(job_state)
    if job_state not in IN_PROGRESS_JOB_STATES:
        break
    time.sleep(30)

# Fail loudly if the job did not succeed, so that the cells below do not go on
# to use a missing or stale SavedModel.
if job_state != aiplatform_v1.JobState.JOB_STATE_SUCCEEDED:
    raise RuntimeError(
        f"Custom job {response.name} ended in state {job_state.name}"
    )

JobState.JOB_STATE_RUNNING
JobState.JOB_STATE_PENDING
JobState.JOB_STATE_PENDING
JobState.JOB_STATE_PENDING
JobState.JOB_STATE_PENDING
JobState.JOB_STATE_PENDING
JobState.JOB_STATE_PENDING
JobState.JOB_STATE_PENDING
JobState.JOB_STATE_PENDING
JobState.JOB_STATE_PENDING
JobState.JOB_STATE_PENDING
JobState.JOB_STATE_PENDING
JobState.JOB_STATE_PENDING
JobState.JOB_STATE_PENDING
JobState.JOB_STATE_PENDING
JobState.JOB_STATE_PENDING
JobState.JOB_STATE_PENDING
JobState.JOB_STATE_PENDING
JobState.JOB_STATE_PENDING
JobState.JOB_STATE_PENDING
JobState.JOB_STATE_RUNNING
JobState.JOB_STATE_RUNNING
JobState.JOB_STATE_SUCCEEDED


## Local prediction with SavedModel

In [14]:
%%bash
# List the SavedModel files under the job's output directory.
# The expansion is quoted so the URI is never word-split or globbed.
gsutil ls -l "$SAVEDMODEL_BASEDIR/model"

         0  2021-06-26T17:50:35Z  gs://dsparing-sandbox-bucket/models/jax_model_customcontainer/output/model/
     53991  2021-06-26T23:19:02Z  gs://dsparing-sandbox-bucket/models/jax_model_customcontainer/output/model/saved_model.pb
                                 gs://dsparing-sandbox-bucket/models/jax_model_customcontainer/output/model/assets/
                                 gs://dsparing-sandbox-bucket/models/jax_model_customcontainer/output/model/variables/
TOTAL: 2 objects, 53991 bytes (52.73 KiB)


In [15]:
# absl flags have to be initialized before load_mnist is called, otherwise it
# errors; parsing a dummy argv is enough.
flags.FLAGS(['e'])

# Take the first batch of the test split (batch size 1) and keep only its
# first element, the image.
image_to_predict, _ = next(
    iter(
        load_mnist(tfds.Split.TEST, batch_size=1)
    )
)

INFO:absl:Load dataset info from /home/jupyter/tensorflow_datasets/mnist/3.0.1
INFO:absl:Field info.citation from disk and from code do not match. Keeping the one from code.
INFO:absl:Field info.splits from disk and from code do not match. Keeping the one from code.
INFO:absl:Field info.module_name from disk and from code do not match. Keeping the one from code.
INFO:absl:Reusing dataset mnist (/home/jupyter/tensorflow_datasets/mnist/3.0.1)
INFO:absl:Constructing tf.data.Dataset mnist for split test, from /home/jupyter/tensorflow_datasets/mnist/3.0.1


In [16]:
# Load the SavedModel from the job's output directory and run its default
# serving signature on the test image; the result is displayed as cell output.
loaded_model = tf.saved_model.load(f"{SAVEDMODEL_BASEDIR}/model")
serving_fn = loaded_model.signatures["serving_default"]
serving_fn(image_to_predict)

{'output_0': <tf.Tensor: shape=(1, 10), dtype=float32, numpy=
 array([[-10.123418  ,  -6.362619  ,  -6.8052564 ,  -3.753019  ,
          -0.33722907,  -4.726612  ,  -6.3873215 ,  -4.1726475 ,
          -3.9679353 ,  -1.536783  ]], dtype=float32)>}