# Train JAX/Flax model on Vertex AI custom container and use `jax2tf` to convert to SavedModel

In [1]:
import os
import time

from google.cloud import aiplatform
from google.cloud import aiplatform_v1

In [2]:
PROJECT_ID = !(gcloud config get-value core/project)
PROJECT_ID = PROJECT_ID[0]

REGION = "us-central1"
os.environ['REGION'] = REGION

BUCKET_NAME = "dsparing-sandbox-bucket"
os.environ['BUCKET_NAME'] = BUCKET_NAME
# Use a regional bucket in the above region you have rights to.
# Create if needed:
# !gsutil mb -l ${REGION} gs://${BUCKET_NAME}

TRAINING_APP_FOLDER = 'training_app'
os.environ['TRAINING_APP_FOLDER'] = TRAINING_APP_FOLDER

MODEL_NAME = "jax_model_customcontainer"
SAVEDMODEL_BASEDIR = f"gs://{BUCKET_NAME}/models/{MODEL_NAME}/output"
os.environ['SAVEDMODEL_BASEDIR'] = SAVEDMODEL_BASEDIR

In [3]:
%%bash
mkdir -p $TRAINING_APP_FOLDER/trainer

In [4]:
%%writefile {TRAINING_APP_FOLDER}/trainer/task.py
import argparse
import logging
import os
import tensorflow as tf
import tensorflow_datasets as tfds

from absl import flags
from jax.experimental.jax2tf.examples.mnist_lib import (
    load_mnist, FlaxMNIST
)
from jax.experimental.jax2tf.examples.saved_model_lib import (
    convert_and_save_model
)

logger = logging.getLogger()
logger.setLevel(logging.INFO)

if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--output_dir",
        help="GCS location to export SavedModel",
        default=os.getenv("AIP_MODEL_DIR")
    )
    args = parser.parse_args().__dict__

    # need to initialize flags somehow to avoid errors in load_mnist
    flags.FLAGS(['e'])

    train_batch_size = 128
    test_batch_size = 16

    flax_mnist = FlaxMNIST()

    train_ds = load_mnist(tfds.Split.TRAIN, train_batch_size)
    test_ds = load_mnist(tfds.Split.TEST, test_batch_size)

    predict_fn, params = flax_mnist.train(
        train_ds=train_ds,
        test_ds=test_ds,
        num_epochs=2,
    )

    image, _ = next(iter(train_ds))
    input_signature = tf.TensorSpec.from_tensor(
        tf.expand_dims(image[0], axis=0)
    )

    convert_and_save_model(
        jax_fn=predict_fn,
        params=params,
        model_dir=args["output_dir"],
        input_signatures=[input_signature],
    )

Overwriting training_app/trainer/task.py


We should be able to use [CustomContainerTrainingJob](https://googleapis.dev/python/aiplatform/latest/aiplatform.html#google.cloud.aiplatform.CustomContainerTrainingJob), but it gives an error, see the similar [CustomTrainingJob.run](https://googleapis.dev/python/aiplatform/latest/aiplatform.html#google.cloud.aiplatform.CustomTrainingJob.run) currently giving an error even when using the [official notebook](https://github.com/GoogleCloudPlatform/ai-platform-samples/blob/master/ai-platform-unified/notebooks/official/custom/sdk-custom-image-classification-online.ipynb).

In [5]:
MACHINE_TYPE = 'n1-standard-4'
REPLICA_COUNT = 1
JOB_NAME = 'jax_customcontainer_training'

USE_GPU = True
if USE_GPU:
    TRAINING_IMAGE = 'gcr.io/deeplearning-platform-release/tf2-gpu.2-5'
    ACCELERATOR_TYPE = "NVIDIA_TESLA_T4"
    ACCELERATOR_COUNT = 1
    IMAGE_NAME = 'jax_vertex_image_gpu'
else:
    TRAINING_IMAGE = 'gcr.io/deeplearning-platform-release/tf2-cpu.2-5'
    ACCELERATOR_TYPE = None
    ACCELERATOR_COUNT = None
    IMAGE_NAME = 'jax_vertex_image_cpu'

os.environ['TRAINING_IMAGE'] = TRAINING_IMAGE

In [6]:
%%writefile {TRAINING_APP_FOLDER}/requirements.txt
flax
jax[cuda111]  # needs pip to run with `-f https://storage.googleapis.com/jax-releases/jax_releases.html`

Overwriting training_app/requirements.txt


In [7]:
%%bash
echo > $TRAINING_APP_FOLDER/Dockerfile "FROM $TRAINING_IMAGE

COPY requirements.txt ./
RUN python3 -m pip install --no-cache-dir -r requirements.txt -f https://storage.googleapis.com/jax-releases/jax_releases.html 

WORKDIR /app
COPY trainer/task.py .

ENTRYPOINT [\"python\", \"task.py\"]"

In [8]:
IMAGE_TAG = 'latest'
IMAGE_URI = 'gcr.io/{}/{}:{}'.format(PROJECT_ID, IMAGE_NAME, IMAGE_TAG)
os.environ['IMAGE_URI'] = IMAGE_URI

In [9]:
%%bash
docker build -f $TRAINING_APP_FOLDER/Dockerfile \
--tag $IMAGE_URI $TRAINING_APP_FOLDER

Sending build context to Docker daemon  37.38kB
Step 1/6 : FROM gcr.io/deeplearning-platform-release/tf2-gpu.2-5
 ---> 0f998c784cd6
Step 2/6 : COPY requirements.txt ./
 ---> Using cache
 ---> ffa2eb0f16a3
Step 3/6 : RUN python3 -m pip install --no-cache-dir -r requirements.txt -f https://storage.googleapis.com/jax-releases/jax_releases.html
 ---> Using cache
 ---> 0403d96c9d9d
Step 4/6 : WORKDIR /app
 ---> Using cache
 ---> d734275a2c36
Step 5/6 : COPY trainer/task.py .
 ---> Using cache
 ---> 679ee46226c9
Step 6/6 : ENTRYPOINT ["python", "task.py"]
 ---> Using cache
 ---> 5d45e431f3c4
Successfully built 5d45e431f3c4
Successfully tagged gcr.io/dsparing-sandbox/jax_vertex_image_gpu:latest


### Optional local test:

In [10]:
!docker run --name training_jax --runtime nvidia $IMAGE_URI --output_dir=$SAVEDMODEL_BASEDIR/localmodel

2021-06-26 18:37:27.582794: I tensorflow/stream_executor/platform/default/dso_loader.cc:53] Successfully opened dynamic library libcudart.so.11.0
INFO:absl:Load pre-computed DatasetInfo (eg: splits, num examples,...) from GCS: mnist/3.0.1
INFO:absl:Load dataset info from /tmp/tmpyl1i15qytfds
INFO:absl:Field info.citation from disk and from code do not match. Keeping the one from code.
INFO:absl:Field info.splits from disk and from code do not match. Keeping the one from code.
INFO:absl:Field info.module_name from disk and from code do not match. Keeping the one from code.
INFO:absl:Generating dataset mnist (/root/tensorflow_datasets/mnist/3.0.1)
INFO:absl:Dataset mnist is hosted on GCS. It will automatically be downloaded to your
local data directory. If you'd instead prefer to read directly from our public
GCS bucket (recommended if you're running on GCP), you can instead pass
`try_gcs=True` to `tfds.load` or set `data_dir=gs://tfds-data/datasets`.

Dl Completed...: 100%|██████████| 4

once the above container run finished:

In [11]:
%%bash
docker rm -f training_jax

training_jax


### Push image to registry

In [12]:
%%bash
docker push $IMAGE_URI

The push refers to repository [gcr.io/dsparing-sandbox/jax_vertex_image_gpu]
a88fb86cb67d: Preparing
9d471c683180: Preparing
2a77eb46b89c: Preparing
2d3dc1ac256c: Preparing
03dc40ebdbd6: Preparing
f5d954d2bd94: Preparing
db34a056d495: Preparing
f9b1cb8c2687: Preparing
e68443de6bca: Preparing
88bb87a4088d: Preparing
29bf522d97b4: Preparing
d96c519f0898: Preparing
ff0a6aeeabc0: Preparing
8a7cebfdebb3: Preparing
61546b863e43: Preparing
d2b843fb2f7a: Preparing
36c9a9d68143: Preparing
730e84ac5c5d: Preparing
a25ae1798c0c: Preparing
37a19de06c8b: Preparing
27c459f353b4: Preparing
25d03c11e857: Preparing
d5585264beff: Preparing
e39414beba01: Preparing
ff6af85bc8aa: Preparing
98cedd6c9734: Preparing
574aa732d388: Preparing
686978e3bf48: Preparing
80088b120579: Preparing
79a187e0621d: Preparing
72657ad6008c: Preparing
22d4dd8ed907: Preparing
d7e7872b888e: Preparing
5f08512fd434: Preparing
c7bb31fc0e08: Preparing
50858308da3d: Preparing
f5d954d2bd94: Waiting
db34a056d495: Waiting
f9b1cb8c2687: W

### Submit training job

In [13]:
aiplatform.init(project=PROJECT_ID, location=REGION)

In [14]:
api_endpoint: str = f"{REGION}-aiplatform.googleapis.com"

# The AI Platform services require regional API endpoints.
client_options = {"api_endpoint": api_endpoint}
# Initialize client that will be used to create and send requests.
# This client only needs to be created once, and can be reused for multiple
# requests.
client = aiplatform.gapic.JobServiceClient(client_options=client_options)
custom_job = {
    "display_name": JOB_NAME,
    "job_spec": {
        "worker_pool_specs": [
            {
                "machine_spec": {
                    "machine_type": MACHINE_TYPE,
                    "accelerator_type": ACCELERATOR_TYPE,
                    "accelerator_count": ACCELERATOR_COUNT,
                },
                "replica_count": REPLICA_COUNT,
                "container_spec": {
                    "image_uri": IMAGE_URI,
                },
            }
        ],
        "base_output_directory": {
            "output_uri_prefix": SAVEDMODEL_BASEDIR
        },
    },
}
parent = f"projects/{PROJECT_ID}/locations/{REGION}"
response = client.create_custom_job(parent=parent, custom_job=custom_job)
print("response:", response)


response: name: "projects/654544512569/locations/us-central1/customJobs/4949618718201085952"
display_name: "jax_customcontainer_training"
job_spec {
  worker_pool_specs {
    machine_spec {
      machine_type: "n1-standard-4"
      accelerator_type: NVIDIA_TESLA_T4
      accelerator_count: 1
    }
    replica_count: 1
    disk_spec {
      boot_disk_type: "pd-ssd"
      boot_disk_size_gb: 100
    }
    container_spec {
      image_uri: "gcr.io/dsparing-sandbox/jax_vertex_image_gpu:latest"
    }
  }
  base_output_directory {
    output_uri_prefix: "gs://dsparing-sandbox-bucket/models/jax_model_customcontainer/output"
  }
}
state: JOB_STATE_PENDING
create_time {
  seconds: 1624732707
  nanos: 800796000
}
update_time {
  seconds: 1624732707
  nanos: 800796000
}



In [15]:
while True:
    job_state = client.get_custom_job(name=response.name).state
    print(job_state)
    if job_state not in (
        aiplatform_v1.JobState.JOB_STATE_QUEUED,
        aiplatform_v1.JobState.JOB_STATE_PENDING,
        aiplatform_v1.JobState.JOB_STATE_RUNNING
    ):
        break
    time.sleep(30)

JobState.JOB_STATE_RUNNING
JobState.JOB_STATE_PENDING
JobState.JOB_STATE_PENDING
JobState.JOB_STATE_PENDING
JobState.JOB_STATE_PENDING
JobState.JOB_STATE_PENDING
JobState.JOB_STATE_PENDING
JobState.JOB_STATE_PENDING
JobState.JOB_STATE_PENDING
JobState.JOB_STATE_PENDING
JobState.JOB_STATE_PENDING
JobState.JOB_STATE_PENDING
JobState.JOB_STATE_PENDING
JobState.JOB_STATE_PENDING
JobState.JOB_STATE_PENDING
JobState.JOB_STATE_PENDING
JobState.JOB_STATE_PENDING
JobState.JOB_STATE_PENDING
JobState.JOB_STATE_PENDING
JobState.JOB_STATE_PENDING
JobState.JOB_STATE_PENDING
JobState.JOB_STATE_PENDING
JobState.JOB_STATE_PENDING
JobState.JOB_STATE_PENDING
JobState.JOB_STATE_PENDING
JobState.JOB_STATE_PENDING
JobState.JOB_STATE_PENDING
JobState.JOB_STATE_PENDING
JobState.JOB_STATE_RUNNING
JobState.JOB_STATE_RUNNING
JobState.JOB_STATE_SUCCEEDED


### Make sure we can actually predict with savedmodel (Optional)

In [16]:
%%bash
gsutil ls -l $SAVEDMODEL_BASEDIR/model

         0  2021-06-26T17:50:35Z  gs://dsparing-sandbox-bucket/models/jax_model_customcontainer/output/model/
     54003  2021-06-26T18:52:26Z  gs://dsparing-sandbox-bucket/models/jax_model_customcontainer/output/model/saved_model.pb
                                 gs://dsparing-sandbox-bucket/models/jax_model_customcontainer/output/model/assets/
                                 gs://dsparing-sandbox-bucket/models/jax_model_customcontainer/output/model/variables/
TOTAL: 2 objects, 54003 bytes (52.74 KiB)


In [17]:
from jax.experimental.jax2tf.examples.mnist_lib import load_mnist
import tensorflow as tf
import tensorflow_datasets as tfds

from absl import flags

# need to initialize flags somehow to avoid errors in load_mnist
flags.FLAGS(['e'])

image_to_predict, _ = next(iter(load_mnist(tfds.Split.TEST, 1)))

INFO:absl:Load dataset info from /home/jupyter/tensorflow_datasets/mnist/3.0.1
INFO:absl:Field info.citation from disk and from code do not match. Keeping the one from code.
INFO:absl:Field info.splits from disk and from code do not match. Keeping the one from code.
INFO:absl:Field info.module_name from disk and from code do not match. Keeping the one from code.
INFO:absl:Reusing dataset mnist (/home/jupyter/tensorflow_datasets/mnist/3.0.1)
INFO:absl:Constructing tf.data.Dataset mnist for split test, from /home/jupyter/tensorflow_datasets/mnist/3.0.1


In [18]:
loaded_model = tf.saved_model.load(f"{SAVEDMODEL_BASEDIR}/model")
loaded_model.signatures["serving_default"](image_to_predict)

{'output_0': <tf.Tensor: shape=(1, 10), dtype=float32, numpy=
 array([[-9.033175 , -4.895049 , -2.108976 , -0.5636912, -3.901619 ,
         -2.4996982, -2.2915943, -6.8160973, -2.3445058, -6.4918838]],
       dtype=float32)>}