# Train a JAX/Flax model in a Vertex AI custom container and use `jax2tf` to convert it to a SavedModel

This notebook runs on a TensorFlow 2.5 [Vertex Notebook](https://cloud.google.com/vertex-ai/docs/general/notebooks) instance.

In [1]:
import os
import time

from google.cloud import aiplatform
from google.cloud import aiplatform_v1

In [2]:
PROJECT_ID = !(gcloud config get-value core/project)
PROJECT_ID = PROJECT_ID[0]

REGION = "us-central1"
os.environ['REGION'] = REGION

BUCKET_NAME = "dsparing-sandbox-bucket"
os.environ['BUCKET_NAME'] = BUCKET_NAME
# Use a regional bucket in the above region you have rights to.
# Create if needed:
# !gsutil mb -l ${REGION} gs://${BUCKET_NAME}

TRAINING_APP_FOLDER = 'training_app'
os.environ['TRAINING_APP_FOLDER'] = TRAINING_APP_FOLDER

MODEL_NAME = "jax_model"
MODEL_DIR = f"gs://{BUCKET_NAME}/models/{MODEL_NAME}"
os.environ['MODEL_DIR'] = MODEL_DIR

In [3]:
%%bash
# Create the package layout for the training code: $TRAINING_APP_FOLDER/trainer/
mkdir --parents "${TRAINING_APP_FOLDER}/trainer"

In [4]:
%%writefile {TRAINING_APP_FOLDER}/trainer/task.py
import argparse
import logging
import os
import tensorflow as tf
import tensorflow_datasets as tfds

from absl import flags
from jax.experimental.jax2tf.examples.mnist_lib import (
    load_mnist, FlaxMNIST
)
from jax.experimental.jax2tf.examples.saved_model_lib import (
    convert_and_save_model
)

logger = logging.getLogger()
logger.setLevel(logging.INFO)

if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--output_dir",
        help="GCS location to export SavedModel",
        default=os.getenv("AIP_MODEL_DIR")
    )
    args = parser.parse_args().__dict__

    # need to initialize flags somehow to avoid errors in load_mnist
    flags.FLAGS(['e'])

    train_batch_size = 128
    test_batch_size = 16

    flax_mnist = FlaxMNIST()

    train_ds = load_mnist(tfds.Split.TRAIN, train_batch_size)
    test_ds = load_mnist(tfds.Split.TEST, test_batch_size)

    predict_fn, params = flax_mnist.train(
        train_ds=train_ds,
        test_ds=test_ds,
        num_epochs=2
    )

    image, _ = next(iter(train_ds))
    input_signature = tf.TensorSpec.from_tensor(
        tf.expand_dims(image[0], axis=0)
    )

    convert_and_save_model(
        jax_fn=predict_fn,
        params=params,
        model_dir=args["output_dir"],
        input_signatures=[input_signature],
    )

Overwriting training_app/trainer/task.py


Ideally we would use [CustomContainerTrainingJob](https://googleapis.dev/python/aiplatform/latest/aiplatform.html#google.cloud.aiplatform.CustomContainerTrainingJob), but it currently fails with an error. The closely related [CustomTrainingJob.run](https://googleapis.dev/python/aiplatform/latest/aiplatform.html#google.cloud.aiplatform.CustomTrainingJob.run) fails in the same way, even in the [official notebook](https://github.com/GoogleCloudPlatform/ai-platform-samples/blob/master/ai-platform-unified/notebooks/official/custom/sdk-custom-image-classification-online.ipynb). The job is therefore submitted below with the lower-level `JobServiceClient` API.

In [5]:
MACHINE_TYPE = 'n1-standard-4'
REPLICA_COUNT = 1
JOB_NAME = 'jax_customcontainer_training'

# Toggle GPU training. It selects the Deep Learning Containers base image, the
# accelerator attached to the Vertex AI worker, and the name of the built image.
USE_GPU = True

# (base training image, accelerator type, accelerator count, built image name)
# CPU-only training attaches no accelerator, hence the None entries.
_DEVICE_SETTINGS = {
    True: (
        'gcr.io/deeplearning-platform-release/tf2-gpu.2-5',
        "NVIDIA_TESLA_T4",
        1,
        'jax_vertex_image_gpu',
    ),
    False: (
        'gcr.io/deeplearning-platform-release/tf2-cpu.2-5',
        None,
        None,
        'jax_vertex_image_cpu',
    ),
}
TRAINING_IMAGE, ACCELERATOR_TYPE, ACCELERATOR_COUNT, IMAGE_NAME = _DEVICE_SETTINGS[bool(USE_GPU)]

# Exported so the %%bash cell that writes the Dockerfile can read $TRAINING_IMAGE.
os.environ['TRAINING_IMAGE'] = TRAINING_IMAGE

In [6]:
%%writefile {TRAINING_APP_FOLDER}/requirements.txt
flax
# NOTE(review): the cuda111 extra presumably selects a CUDA 11.1 jaxlib build,
# i.e. it targets the GPU image (USE_GPU = True); this file is written the same
# way for the CPU image, which may need plain `jax` instead -- confirm.
jax[cuda111]  # needs pip to run with `-f https://storage.googleapis.com/jax-releases/jax_releases.html`

Writing training_app/requirements.txt


In [7]:
%%bash
# Write the training image's Dockerfile. The heredoc delimiter is unquoted, so
# the shell substitutes $TRAINING_IMAGE while writing the file.
# COPY paths are relative to the docker build context ($TRAINING_APP_FOLDER).
cat > "${TRAINING_APP_FOLDER}/Dockerfile" <<EOF
FROM $TRAINING_IMAGE
WORKDIR /app
# requirements.txt has to be inside the image before pip can install from it.
COPY requirements.txt .
RUN python3 -m pip install --no-cache-dir -r requirements.txt -f https://storage.googleapis.com/jax-releases/jax_releases.html
COPY trainer/task.py .

ENTRYPOINT ["python", "task.py"]
EOF

In [8]:
# Container Registry reference for the training image: gcr.io/<project>/<name>:<tag>
IMAGE_TAG = 'latest'
IMAGE_URI = f'gcr.io/{PROJECT_ID}/{IMAGE_NAME}:{IMAGE_TAG}'

# Exported for the %%bash docker build / run / push cells ($IMAGE_URI).
os.environ['IMAGE_URI'] = IMAGE_URI

In [None]:
%%bash
# Build the training image using $TRAINING_APP_FOLDER as the build context and
# tag it with $IMAGE_URI so it can be run locally and pushed.
docker build \
    --tag "${IMAGE_URI}" \
    --file "${TRAINING_APP_FOLDER}/Dockerfile" \
    "${TRAINING_APP_FOLDER}"

### Optional: test the container locally

In [None]:
%%bash
# Run the training container locally; it exports to the same $MODEL_DIR the
# Vertex AI job uses. Arguments after the image name are appended to the
# ENTRYPOINT (python task.py), whose argparse rejects flags it does not define.
# Docker options such as --runtime=nvidia (GPU access) must therefore go
# before the image name.
docker run --name training_jax "${IMAGE_URI}" --output_dir="${MODEL_DIR}"

Once the container run above has finished, remove the stopped container:

In [None]:
%%bash
# Remove the local test container; --force also removes it if it is still running.
docker container rm --force training_jax

### Push image to registry

In [None]:
%%bash
# Publish the image so the Vertex AI training job can pull it.
docker image push "${IMAGE_URI}"

### Submit training job

In [None]:
# Set the default project and region for the high-level aiplatform SDK.
aiplatform.init(location=REGION, project=PROJECT_ID)

In [None]:
# Vertex AI APIs are served from regional endpoints, so the job client must
# target the endpoint for REGION. One client can be reused for many requests.
api_endpoint: str = f"{REGION}-aiplatform.googleapis.com"
client_options = {"api_endpoint": api_endpoint}
client = aiplatform.gapic.JobServiceClient(client_options=client_options)

# A single worker pool running our custom container. For CPU-only training
# the accelerator fields are None.
machine_spec = dict(
    machine_type=MACHINE_TYPE,
    accelerator_type=ACCELERATOR_TYPE,
    accelerator_count=ACCELERATOR_COUNT,
)
container_spec = dict(
    image_uri=IMAGE_URI,
    args=["--output_dir=" + MODEL_DIR],
)
worker_pool_spec = dict(
    machine_spec=machine_spec,
    replica_count=REPLICA_COUNT,
    container_spec=container_spec,
)
custom_job = dict(
    display_name=JOB_NAME,
    job_spec=dict(worker_pool_specs=[worker_pool_spec]),
)

parent = f"projects/{PROJECT_ID}/locations/{REGION}"
response = client.create_custom_job(parent=parent, custom_job=custom_job)
print("response:", response)


In [None]:
# Block until the custom job reaches a terminal state, polling every
# POLL_INTERVAL_SECONDS. JOB_STATE_CANCELLING counts as still active: the job
# has not finished winding down yet.
POLL_INTERVAL_SECONDS = 30
ACTIVE_JOB_STATES = (
    aiplatform_v1.JobState.JOB_STATE_QUEUED,
    aiplatform_v1.JobState.JOB_STATE_PENDING,
    aiplatform_v1.JobState.JOB_STATE_RUNNING,
    aiplatform_v1.JobState.JOB_STATE_CANCELLING,
)

while True:
    job = client.get_custom_job(name=response.name)
    job_state = job.state
    print(job_state)
    if job_state not in ACTIVE_JOB_STATES:
        break
    time.sleep(POLL_INTERVAL_SECONDS)

# Stop here (e.g. during Run All) with the job's own error message rather than
# failing later on a missing SavedModel.
if job_state != aiplatform_v1.JobState.JOB_STATE_SUCCEEDED:
    raise RuntimeError(
        f"Custom job {response.name} finished in state {job_state!s}: "
        f"{job.error.message}"
    )

### (Optional) Check that the exported SavedModel can make predictions

In [None]:
%%bash
# List what the training job exported to $MODEL_DIR, with sizes and timestamps.
gsutil ls -l "${MODEL_DIR}"

In [None]:
from jax.experimental.jax2tf.examples.mnist_lib import load_mnist
import tensorflow as tf
import tensorflow_datasets as tfds

from absl import flags

# load_mnist errors unless absl flags have been parsed; parsing a dummy argv
# leaves every flag at its default value.
flags.FLAGS(['e'])

# Batches of one image: the exported serving signature was built from a
# batch of size 1 in task.py.
single_image_batches = load_mnist(tfds.Split.TEST, 1)
image_to_predict, _ = next(iter(single_image_batches))

In [None]:
# Load the SavedModel straight from GCS and run its default serving signature
# on the single test image; the cell displays the signature's output.
loaded_model = tf.saved_model.load(MODEL_DIR)
serving_fn = loaded_model.signatures["serving_default"]
serving_fn(image_to_predict)