# Train JAX/Flax model on Vertex AI custom container and use `jax2tf` to convert to SavedModel

runs on TF2.5 [Vertex Notebook](https://cloud.google.com/vertex-ai/docs/general/notebooks)

In [2]:
import os
import time

from google.cloud import aiplatform
from google.cloud import aiplatform_v1

In [3]:
PROJECT_ID = !(gcloud config get-value core/project)
PROJECT_ID = PROJECT_ID[0]

REGION = "us-central1"
os.environ['REGION'] = REGION

BUCKET_NAME = "dsparing-sandbox-bucket"
os.environ['BUCKET_NAME'] = BUCKET_NAME
# Use a regional bucket in the above region you have rights to.
# Create if needed:
# !gsutil mb -l ${REGION} gs://${BUCKET_NAME}

TRAINING_APP_FOLDER = 'training_app'
os.environ['TRAINING_APP_FOLDER'] = TRAINING_APP_FOLDER

MODEL_NAME = "jax_model"
MODEL_DIR = f"gs://{BUCKET_NAME}/models/{MODEL_NAME}"
os.environ['MODEL_DIR'] = MODEL_DIR

In [4]:
%%bash
mkdir -p $TRAINING_APP_FOLDER/trainer

In [5]:
%%writefile {TRAINING_APP_FOLDER}/trainer/task.py
import argparse
import logging
import os
import tensorflow as tf
import tensorflow_datasets as tfds

from absl import flags
from jax.experimental.jax2tf.examples.mnist_lib import load_mnist, FlaxMNIST
from jax.experimental.jax2tf.examples.saved_model_lib import convert_and_save_model

logger = logging.getLogger()
logger.setLevel(logging.INFO)

if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--output_dir",
        help="GCS location to export SavedModel",
        default = os.getenv("AIP_MODEL_DIR")
    )
    args = parser.parse_args().__dict__

    
    flags.FLAGS(['e']) # need to initialize flags somehow to avoid errors in load_mnist

    train_batch_size = 128
    test_batch_size = 16


    flax_mnist = FlaxMNIST()

    train_ds = load_mnist(tfds.Split.TRAIN, train_batch_size)
    test_ds = load_mnist(tfds.Split.TEST, test_batch_size)

    predict_fn, params = flax_mnist.train(
        train_ds=train_ds,
        test_ds=test_ds,
        num_epochs=2
    )

    image, _ = next(iter(train_ds))
    input_signature = tf.TensorSpec.from_tensor(tf.expand_dims(image[0], axis=0))

    convert_and_save_model(
        jax_fn=predict_fn,
        params=params,
        model_dir=args["output_dir"],
        input_signatures=[input_signature],
    )

Overwriting training_app/trainer/task.py


We should be able to use [CustomContainerTrainingJob](https://googleapis.dev/python/aiplatform/latest/aiplatform.html#google.cloud.aiplatform.CustomContainerTrainingJob), but it gives an error, see the similar [CustomTrainingJob.run](https://googleapis.dev/python/aiplatform/latest/aiplatform.html#google.cloud.aiplatform.CustomTrainingJob.run) currently giving an error even when using the [official notebook](https://github.com/GoogleCloudPlatform/ai-platform-samples/blob/master/ai-platform-unified/notebooks/official/custom/sdk-custom-image-classification-online.ipynb).

In [6]:
MACHINE_TYPE = 'n1-standard-4'
REPLICA_COUNT = 1
JOB_NAME = 'jax_customcontainer_training'

USE_GPU = True
if USE_GPU:
    TRAINING_IMAGE = 'gcr.io/deeplearning-platform-release/tf2-gpu.2-5'
    ACCELERATOR_TYPE = "NVIDIA_TESLA_T4"
    ACCELERATOR_COUNT = 1
    IMAGE_NAME='jax_vertex_image_gpu'
else:
    TRAINING_IMAGE = 'gcr.io/deeplearning-platform-release/tf2-cpu.2-5'
    ACCELERATOR_TYPE = None
    ACCELERATOR_COUNT = None
    IMAGE_NAME='jax_vertex_image_cpu'

os.environ['TRAINING_IMAGE'] = TRAINING_IMAGE

In [7]:
%%bash
echo > $TRAINING_APP_FOLDER/Dockerfile "FROM $TRAINING_IMAGE
RUN python -m pip install -U jax jaxlib==0.1.67+cuda110 -f https://storage.googleapis.com/jax-releases/jax_releases.html 
RUN python -m pip install -U flax 
WORKDIR /app
COPY trainer/task.py .

ENTRYPOINT [\"python\", \"task.py\"]"

In [8]:
IMAGE_TAG = 'latest'
IMAGE_URI = 'gcr.io/{}/{}:{}'.format(PROJECT_ID, IMAGE_NAME, IMAGE_TAG)
os.environ['IMAGE_URI'] = IMAGE_URI

In [9]:
%%bash
docker build -f $TRAINING_APP_FOLDER/Dockerfile --tag $IMAGE_URI $TRAINING_APP_FOLDER

Sending build context to Docker daemon  35.84kB
Step 1/6 : FROM gcr.io/deeplearning-platform-release/tf2-gpu.2-5
 ---> 950969e5619c
Step 2/6 : RUN python -m pip install -U jax jaxlib==0.1.67+cuda110 -f https://storage.googleapis.com/jax-releases/jax_releases.html
 ---> Using cache
 ---> 62b0407dec24
Step 3/6 : RUN python -m pip install -U flax
 ---> Using cache
 ---> 8424950476a1
Step 4/6 : WORKDIR /app
 ---> Using cache
 ---> fc9e144a502a
Step 5/6 : COPY trainer/task.py .
 ---> Using cache
 ---> 5a8d67bded7d
Step 6/6 : ENTRYPOINT ["python", "task.py"]
 ---> Using cache
 ---> 51dc3109ad96
Successfully built 51dc3109ad96
Successfully tagged gcr.io/dsparing-sandbox/jax_vertex_image_gpu:latest


In [10]:
IMAGE_URI

'gcr.io/dsparing-sandbox/jax_vertex_image_gpu:latest'

### Optional local test:

once the above container run finished:

### Push image to registry

In [None]:
%%bash
docker push $IMAGE_URI

### Submit training job

In [14]:
aiplatform.init(project=PROJECT_ID, location=REGION)

In [15]:
api_endpoint: str = f"{REGION}-aiplatform.googleapis.com"

# The AI Platform services require regional API endpoints.
client_options = {"api_endpoint": api_endpoint}
# Initialize client that will be used to create and send requests.
# This client only needs to be created once, and can be reused for multiple requests.
client = aiplatform.gapic.JobServiceClient(client_options=client_options)
custom_job = {
    "display_name": JOB_NAME,
    "job_spec": {
        "worker_pool_specs": [
            {
                "machine_spec": {
                    "machine_type": MACHINE_TYPE,
                    "accelerator_type": ACCELERATOR_TYPE,
                    "accelerator_count": ACCELERATOR_COUNT,
                },
                "replica_count": REPLICA_COUNT,
                "container_spec": {
                    "image_uri": IMAGE_URI,
                    "args": [
                        "--output_dir=" + MODEL_DIR,
                    ]
                },
            }
        ],
    },
}
parent = f"projects/{PROJECT_ID}/locations/{REGION}"
response = client.create_custom_job(parent=parent, custom_job=custom_job)
print("response:", response)


response: name: "projects/654544512569/locations/us-central1/customJobs/6638415801906888704"
display_name: "jax_customcontainer_training"
job_spec {
  worker_pool_specs {
    machine_spec {
      machine_type: "n1-standard-4"
      accelerator_type: NVIDIA_TESLA_T4
      accelerator_count: 1
    }
    replica_count: 1
    disk_spec {
      boot_disk_type: "pd-ssd"
      boot_disk_size_gb: 100
    }
    container_spec {
      image_uri: "gcr.io/dsparing-sandbox/jax_vertex_image_gpu:latest"
      args: "--output_dir=gs://dsparing-sandbox-bucket/models/jax_model"
    }
  }
}
state: JOB_STATE_PENDING
create_time {
  seconds: 1624287896
  nanos: 263811000
}
update_time {
  seconds: 1624287896
  nanos: 263811000
}



In [16]:
while True:
    job_state = client.get_custom_job(name=response.name).state
    print(job_state)
    if job_state not in (aiplatform_v1.JobState.JOB_STATE_QUEUED, aiplatform_v1.JobState.JOB_STATE_PENDING, aiplatform_v1.JobState.JOB_STATE_RUNNING):
        break
    time.sleep(30)

JobState.JOB_STATE_PENDING
JobState.JOB_STATE_PENDING
JobState.JOB_STATE_PENDING
JobState.JOB_STATE_PENDING
JobState.JOB_STATE_PENDING
JobState.JOB_STATE_PENDING
JobState.JOB_STATE_PENDING
JobState.JOB_STATE_PENDING
JobState.JOB_STATE_PENDING
JobState.JOB_STATE_PENDING
JobState.JOB_STATE_PENDING
JobState.JOB_STATE_PENDING
JobState.JOB_STATE_PENDING
JobState.JOB_STATE_PENDING
JobState.JOB_STATE_PENDING
JobState.JOB_STATE_PENDING
JobState.JOB_STATE_PENDING
JobState.JOB_STATE_PENDING
JobState.JOB_STATE_PENDING
JobState.JOB_STATE_PENDING
JobState.JOB_STATE_PENDING
JobState.JOB_STATE_RUNNING
JobState.JOB_STATE_SUCCEEDED


### Make sure we can actually predict with savedmodel (Optional)

In [None]:
%%bash
pip3 install --user --upgrade jax jaxlib==0.1.67+cuda110 -f https://storage.googleapis.com/jax-releases/jax_releases.html

In [18]:
%%bash
gsutil ls -l $MODEL_DIR

         0  2021-06-18T01:17:21Z  gs://dsparing-sandbox-bucket/models/jax_model/
     54003  2021-06-21T15:15:10Z  gs://dsparing-sandbox-bucket/models/jax_model/saved_model.pb
                                 gs://dsparing-sandbox-bucket/models/jax_model/assets/
                                 gs://dsparing-sandbox-bucket/models/jax_model/variables/
TOTAL: 2 objects, 54003 bytes (52.74 KiB)


In [19]:
from jax.experimental.jax2tf.examples.mnist_lib import load_mnist
import tensorflow as tf
import tensorflow_datasets as tfds

from absl import flags

flags.FLAGS(['e']) # need to initialize flags somehow to avoid errors in load_mnist

image_to_predict, _ = next(iter(load_mnist(tfds.Split.TEST, 1)))

INFO:absl:Load dataset info from /home/jupyter/tensorflow_datasets/mnist/3.0.1
INFO:absl:Field info.citation from disk and from code do not match. Keeping the one from code.
INFO:absl:Field info.splits from disk and from code do not match. Keeping the one from code.
INFO:absl:Field info.module_name from disk and from code do not match. Keeping the one from code.
INFO:absl:Reusing dataset mnist (/home/jupyter/tensorflow_datasets/mnist/3.0.1)
INFO:absl:Constructing tf.data.Dataset mnist for split test, from /home/jupyter/tensorflow_datasets/mnist/3.0.1


In [20]:
loaded_model = tf.saved_model.load(MODEL_DIR)
loaded_model.signatures["serving_default"](image_to_predict)

{'output_0': <tf.Tensor: shape=(1, 10), dtype=float32, numpy=
 array([[ -6.717919  , -14.699879  ,  -1.9437121 ,  -3.6636014 ,
          -7.9636226 ,  -9.026531  , -10.630737  ,  -0.21502149,
          -5.097106  ,  -4.083507  ]], dtype=float32)>}