In [None]:
# Copyright 2023 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# Vertex AI Model Garden - JAX Vision Transformer

<table align="left">
  <td>
    <a href="https://colab.research.google.com/github/GoogleCloudPlatform/vertex-ai-samples/blob/main/notebooks/community/model_garden/model_garden_jax_vision_transformer.ipynb">
      <img src="https://cloud.google.com/ml-engine/images/colab-logo-32px.png" alt="Colab logo"> Run in Colab
    </a>
  </td>

  <td>
    <a href="https://github.com/GoogleCloudPlatform/vertex-ai-samples/blob/main/notebooks/community/model_garden/model_garden_jax_vision_transformer.ipynb">
      <img src="https://cloud.google.com/ml-engine/images/github-logo-32px.png" alt="GitHub logo">
      View on GitHub
    </a>
  </td>
  <td>
    <a href="https://console.cloud.google.com/vertex-ai/notebooks/deploy-notebook?download_url=https://raw.githubusercontent.com/GoogleCloudPlatform/vertex-ai-samples/main/notebooks/community/model_garden/model_garden_jax_vision_transformer.ipynb">
      <img src="https://lh3.googleusercontent.com/UiNooY4LUgW_oTvpsNhPpQzsstV5W8F7rYgxgGBD85cWJoLmrOzhVs_ksK_vgx40SHs7jCqkTkCk=e14-rj-sc0xffffff-h130-w32" alt="Vertex AI logo">
Open in Vertex AI Workbench
    </a>
  </td>
</table>

**_NOTE_**: This notebook has been tested in the following environment:

* Python version = 3.9

## Overview

This notebook demonstrates fine-tuning a [JAX ViT-B16 model](https://github.com/google-research/vision_transformer#available-vit-models) for an image classification task on a GPU and deploying it to Vertex AI for online prediction.

Learn more about [Generative AI Support in Vertex AI](https://cloud.google.com/blog/products/ai-machine-learning/vertex-ai-model-garden-and-generative-ai-studio).

### Objective

In this tutorial, you learn how to fine-tune, deploy, and run predictions with a pretrained JAX Vision Transformer (ViT) model on Vertex AI.

This tutorial uses the following Google Cloud ML services and resources:

- Vertex AI Model Garden
- Vertex AI Training
- Vertex AI Model Registry
- Vertex AI Online Prediction

The steps performed are:

- Fine-tune a JAX Vision Transformer based model.
- Upload the model to [Model Registry](https://cloud.google.com/vertex-ai/docs/model-registry/introduction).
- Deploy the model to an [endpoint](https://cloud.google.com/vertex-ai/docs/predictions/using-private-endpoints).
- Run online predictions for image classification.


### Dataset

This notebook uses the [tf_flowers dataset](https://www.tensorflow.org/datasets/catalog/tf_flowers) and includes a section that shows how to download and prepare it. You can follow a similar process to use your own custom dataset.

### Costs

This tutorial uses billable components of Google Cloud:

* Vertex AI
* Cloud Storage

Learn about [Vertex AI pricing](https://cloud.google.com/vertex-ai/pricing) and [Cloud Storage pricing](https://cloud.google.com/storage/pricing), and use the [Pricing Calculator](https://cloud.google.com/products/calculator/) to generate a cost estimate based on your projected usage.

## Installation

Install the following packages required to execute this notebook.

In [None]:
# Install the packages.
! pip3 install --upgrade google-cloud-aiplatform

### Colab only

In [None]:
# Colab only: uncomment and run the lines below to restart the kernel after installs so that your environment can access the new packages.
# import IPython

# app = IPython.Application.instance()
# app.kernel.do_shutdown(True)

## Before you begin

### Set up your Google Cloud project

**The following steps are required, regardless of your notebook environment.**

1. [Select or create a Google Cloud project](https://console.cloud.google.com/cloud-resource-manager). When you first create an account, you get a $300 free credit towards your compute/storage costs.

1. [Make sure that billing is enabled for your project](https://cloud.google.com/billing/docs/how-to/modify-project).

1. [Enable the Vertex AI API and Compute Engine API](https://console.cloud.google.com/flows/enableapi?apiid=aiplatform.googleapis.com,compute_component).

1. If you are running this notebook locally, you need to install the [Cloud SDK](https://cloud.google.com/sdk).


#### Set your project ID

**If you don't know your project ID**, try the following:
* Run `gcloud config list`.
* Run `gcloud projects list`.
* See the support page: [Locate the project ID](https://support.google.com/googleapi/answer/7014113)

In [None]:
PROJECT_ID = "[your-project-id]"  # @param {type:"string"}

# Set the project id
! gcloud config set project {PROJECT_ID}

#### Region

You can also change the `REGION` variable used by Vertex AI. Learn more about [Vertex AI regions](https://cloud.google.com/vertex-ai/docs/general/locations).

In [None]:
REGION = "us-central1"  # @param {type: "string"}

### Authenticate your Google Cloud account

Depending on your Jupyter environment, you may have to manually authenticate. Follow the relevant instructions below.

**1. Vertex AI Workbench**
* Do nothing as you are already authenticated.

**2. Local JupyterLab instance, uncomment and run:**

In [None]:
# ! gcloud auth login

**3. Colab, uncomment and run:**

In [None]:
# from google.colab import auth
# auth.authenticate_user()

**4. Service account or other**
* See how to grant Cloud Storage permissions to your service account at https://cloud.google.com/storage/docs/gsutil/commands/iam#ch-examples.

### Create a Cloud Storage bucket

Create a storage bucket to store intermediate artifacts such as datasets.

In [None]:
BUCKET_URI = f"gs://your-bucket-name-{PROJECT_ID}-unique"  # @param {type:"string"}
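Bucket names must follow Cloud Storage naming rules, and the default `PROJECT_ID` placeholder (`[your-project-id]`) produces an invalid name. The sketch below is a partial, hypothetical pre-flight check (`looks_like_valid_bucket_uri` is not part of any library): it only covers names without dots (3 to 63 characters; lowercase letters, digits, dashes, and underscores; starting and ending with a letter or digit) and skips other rules, such as the restriction on names starting with `goog`. `gsutil mb` remains the final authority.

```python
import re

# Partial check of Cloud Storage bucket naming rules for names WITHOUT dots:
# 3-63 characters, only lowercase letters, digits, dashes and underscores,
# starting and ending with a letter or digit.
_SIMPLE_BUCKET_NAME = re.compile(r"^[a-z0-9][a-z0-9_-]{1,61}[a-z0-9]$")


def looks_like_valid_bucket_uri(bucket_uri: str) -> bool:
    """Returns True if gs://<name> has a name that passes the partial checks above."""
    prefix = "gs://"
    if not bucket_uri.startswith(prefix):
        return False
    name = bucket_uri[len(prefix):].rstrip("/")
    if "." in name:
        # Dotted names follow extra rules (domain verification, longer limits);
        # this sketch does not attempt to validate them.
        return False
    return bool(_SIMPLE_BUCKET_NAME.match(name))


print(looks_like_valid_bucket_uri("gs://your-bucket-name-my-project-unique"))
```

In the notebook you would call `looks_like_valid_bucket_uri(BUCKET_URI)` before running `gsutil mb`.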

**Only if your bucket doesn't already exist**: Run the following cell to create your Cloud Storage bucket.

In [None]:
! gsutil mb -l {REGION} -p {PROJECT_ID} {BUCKET_URI}

### Import libraries

In [None]:
import base64
import glob
import os
import random
import shutil
from datetime import datetime
from io import BytesIO

import numpy as np
from google.cloud import aiplatform
from google.protobuf import json_format
from google.protobuf.struct_pb2 import Value
from PIL import Image

### Initialize Vertex AI SDK for Python

Initialize the Vertex AI SDK for Python for your project.

In [None]:
staging_bucket = os.path.join(BUCKET_URI, "jax_vit_staging")
aiplatform.init(project=PROJECT_ID, location=REGION, staging_bucket=staging_bucket)

### Define constants

In [None]:
# The pre-built training docker image.
TRAIN_DOCKER_URI = "us-docker.pkg.dev/vertex-ai-restricted/vertex-vision-model-garden-dockers/jax-vit-train-gpu"
# The pre-built TF SavedModel conversion docker image.
MODEL_CONVERSION_DOCKER_URI = "us-docker.pkg.dev/vertex-ai-restricted/vertex-vision-model-garden-dockers/jax-vit-model-conversion"
# The pre-built prediction docker image.
OPTIMIZED_TF_RUNTIME_IMAGE_URI = (
    "us-docker.pkg.dev/vertex-ai-restricted/prediction/tf_opt-gpu.nightly:latest"
)

### Define common functions

This section defines functions for:

- Splitting the [tf_flowers dataset](https://www.tensorflow.org/datasets/catalog/tf_flowers) images into `train` and `test` folders.
- Converting a Cloud Storage path such as `gs://bucket-name` to the Cloud Storage FUSE path format such as `/gcs/bucket-name`.
- Encoding a local image file to a string for prediction input.
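Vertex AI custom training mounts Cloud Storage buckets under `/gcs/` using Cloud Storage FUSE, which is why the training arguments pass paths through `gcs_fuse_path`. The sketch below restates that mapping and adds a hypothetical inverse, `to_gcs_uri`, for converting FUSE paths back to `gs://` URIs; both names are illustrative and not part of the Vertex AI SDK.

```python
def to_fuse_path(uri: str) -> str:
    """Maps gs://bucket/object to /gcs/bucket/object; returns other paths unchanged."""
    uri = uri.strip()
    if uri.startswith("gs://"):
        return "/gcs/" + uri[len("gs://"):]
    return uri


def to_gcs_uri(path: str) -> str:
    """Hypothetical inverse: maps /gcs/bucket/object back to gs://bucket/object."""
    path = path.strip()
    if path.startswith("/gcs/"):
        return "gs://" + path[len("/gcs/"):]
    return path


for p in ["gs://my-bucket/flower_dataset", "./flower_photos"]:
    print(f"{p} -> {to_fuse_path(p)}")
```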

In [None]:
def split(base_dir, test_ratio=0.1):
    """Randomly splits base_dir/{class}/*.jpg into base_dir/{train,test}/{class}/."""
    paths = glob.glob(f"{base_dir}/*/*.jpg")
    random.shuffle(paths)
    counts = dict(test=0, train=0)
    for i, path in enumerate(paths):
        # The first `test_ratio` fraction of the shuffled paths goes to "test".
        subset = "test" if i < test_ratio * len(paths) else "train"
        *_, class_name, basename = path.split("/")
        dst = f"{base_dir}/{subset}/{class_name}/{basename}"
        os.makedirs(os.path.dirname(dst), exist_ok=True)
        shutil.move(path, dst)
        counts[subset] += 1
    print(f'Moved {counts["train"]:,} train and {counts["test"]:,} test images.')


def gcs_fuse_path(path: str) -> str:
    """Converts a gs:// URI to its /gcs/ (Cloud Storage FUSE) path; returns other paths unchanged."""
    path = path.strip()
    if path.startswith("gs://"):
        return "/gcs/" + path[5:]
    return path


def load_bytes_from_local_image(local_image_path, new_width=-1):
    """Returns encoded image string for prediction input."""
    image = Image.open(local_image_path)
    if new_width <= 0:
        new_image = image
    else:
        width, height = image.size
        print(f"Original input image size: {width} x {height}")
        new_height = int(height * new_width / width)
        print(f"New input image size: {new_width} x {new_height}")
        new_image = image.resize((new_width, new_height))
    buffered = BytesIO()
    new_image.save(buffered, format="JPEG")
    encoded_string = base64.b64encode(buffered.getvalue()).decode("utf-8")
    return encoded_string

### Prepare dataset

If you are not using [TensorFlow Datasets](https://www.tensorflow.org/datasets/catalog/overview#all_datasets), you need to prepare your dataset and store it on Cloud Storage. The following example shows how to do this for the [tf_flowers dataset](https://www.tensorflow.org/datasets/catalog/tf_flowers). If you are using TensorFlow Datasets, pass the dataset name, such as `tf_flowers`, to the `--config.dataset` flag and skip this section.

In [None]:
local_flower_data_directory = "./flower_photos"  # @param {type:"string"}
FLOWER_DATA_GCS_PATH = os.path.join(BUCKET_URI, "flower_dataset")
# The flower dataset has 5 classes.
NUM_CLASSES = 5
# NOTE: For custom dataset, the training code picks the class names
# from the folder structure and then sorts them to create a mapping
# from class-index to class-name. This is why the mapping below
# looks different from default `tf_flowers` documentation.
LABEL_IDX_TO_STR = {
    0: "daisy",
    1: "dandelion",
    2: "roses",
    3: "sunflowers",
    4: "tulips",
}
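The comment above notes that, for a custom dataset, the training code sorts the class folder names to build the index-to-name mapping. Assuming a plain lexicographic sort (as Python's `sorted` does), the sketch below derives the mapping from the class folders instead of hard-coding it; `label_mapping_from_folders` is a hypothetical helper, and you should still confirm that the container's ordering matches. Once the dataset has been split, you could call it on `os.path.join(local_flower_data_directory, "train")`.

```python
import os
import tempfile


def label_mapping_from_folders(train_dir: str) -> dict:
    """Builds {class_index: class_name} by sorting the class sub-folder names."""
    class_names = sorted(
        name
        for name in os.listdir(train_dir)
        if os.path.isdir(os.path.join(train_dir, name))
    )
    return dict(enumerate(class_names))


# Demo on a throwaway directory that mimics flower_photos/train/.
with tempfile.TemporaryDirectory() as tmp:
    for name in ["tulips", "roses", "daisy", "sunflowers", "dandelion"]:
        os.makedirs(os.path.join(tmp, name))
    mapping = label_mapping_from_folders(tmp)
    print(mapping)
```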

In [None]:
# Download flower data to a local directory.
! rm -rf $local_flower_data_directory;
! (cd "./" && curl https://storage.googleapis.com/download.tensorflow.org/example_images/flower_photos.tgz | tar xz)

In [None]:
# The downloaded dataset is laid out as
# flower_photos/{class_name}/{filename}.jpg
# so first split it into a "train" (90%) and a "test" (10%) set:
# flower_photos/train/{class_name}/{filename}.jpg
# flower_photos/test/{class_name}/{filename}.jpg

split(local_flower_data_directory)
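`split` shuffles all images together rather than per class, so each class ends up with roughly, not exactly, 10% of its images in `test`. A quick per-class count, sketched below with the hypothetical helper `count_images_per_split`, makes any imbalance visible before uploading. The demo runs on a synthetic directory; on the real data you would call `count_images_per_split(local_flower_data_directory)`.

```python
import glob
import os
import tempfile
from collections import Counter


def count_images_per_split(base_dir: str) -> dict:
    """Returns {subset: Counter({class_name: n_images})} for base_dir/{train,test}/*/*.jpg."""
    counts = {}
    for subset in ("train", "test"):
        paths = glob.glob(os.path.join(base_dir, subset, "*", "*.jpg"))
        counts[subset] = Counter(os.path.basename(os.path.dirname(p)) for p in paths)
    return counts


# Demo on a synthetic layout: 9 train images and 1 test image for one class.
with tempfile.TemporaryDirectory() as tmp:
    for subset, n in (("train", 9), ("test", 1)):
        class_dir = os.path.join(tmp, subset, "tulips")
        os.makedirs(class_dir)
        for i in range(n):
            open(os.path.join(class_dir, f"{i}.jpg"), "wb").close()
    demo_counts = count_images_per_split(tmp)
    print(demo_counts)
```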

In [None]:
# Copy the flower data from the local directory to Cloud Storage.
# This step takes about 2 minutes.
! gsutil -m cp -R $local_flower_data_directory/train/* $FLOWER_DATA_GCS_PATH/train/
! gsutil -m cp -R $local_flower_data_directory/test/* $FLOWER_DATA_GCS_PATH/test/

## Finetune with JAX Vision Transformer

Create and run a training job with the Model Garden JAX Vision Transformer training Docker image using the Vertex AI SDK. Training uses one NVIDIA V100 GPU and runs for about 10 minutes once the job starts.
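To get a feel for how much training 100 steps buys, it helps to convert steps into epochs and to look at the learning-rate schedule. tf_flowers contains 3,670 images, so the 90/10 split leaves about 3,303 for training, and 100 steps at batch size 128 cover roughly 12,800 / 3,303 ≈ 3.9 epochs. The schedule in the sketch assumes linear warmup over `warmup_steps` followed by cosine decay to zero, which we assume matches the vit_jax default; treat it as an illustration and check the container's config before relying on exact values. Both function names are hypothetical.

```python
import math


def epochs_covered(total_steps: int, batch_size: int, num_train_images: int) -> float:
    """Approximate number of passes over the training set."""
    return total_steps * batch_size / num_train_images


def warmup_cosine_lr(step: int, base_lr: float, warmup_steps: int, total_steps: int) -> float:
    """Linear warmup followed by cosine decay to zero (assumed schedule shape)."""
    progress = (step - warmup_steps) / float(total_steps - warmup_steps)
    progress = min(max(progress, 0.0), 1.0)
    lr = base_lr * 0.5 * (1.0 + math.cos(math.pi * progress))
    if warmup_steps:
        lr *= min(1.0, step / warmup_steps)
    return lr


# Values from the training arguments: 100 steps, batch 128, warmup 10, base_lr 0.01.
# 3,303 is an assumed train count (about 90% of tf_flowers' 3,670 images).
print(f"epochs ~= {epochs_covered(100, 128, 3303):.2f}")
for step in (0, 5, 10, 55, 100):
    print(step, round(warmup_cosine_lr(step, 0.01, 10, 100), 5))
```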

In [None]:
# Set up training docker arguments.

TIMESTAMP = datetime.now().strftime("%Y%m%d_%H%M%S")
JOB_NAME = "jax_vision_transformer" + TIMESTAMP

finetuning_workdir = os.path.join(BUCKET_URI, JOB_NAME)
pre_trained_dir = "gs://vit_models/imagenet21k"
docker_args_list = [
    "--config",
    "vit_jax/configs/vit.py:b16",
    "--config.dataset",
    f"{gcs_fuse_path(FLOWER_DATA_GCS_PATH)}",
    "--config.pp.train",
    "train",
    "--config.pp.test",
    "test",
    "--config.pretrained_dir",
    f"{gcs_fuse_path(pre_trained_dir)}",
    "--config.batch",
    "128",
    "--config.batch_eval",
    "128",
    "--config.base_lr",
    "0.01",
    "--config.shuffle_buffer",
    "1000",
    "--config.total_steps",
    "100",
    "--config.warmup_steps",
    "10",
    "--config.pp.crop",
    "224",
    "--workdir",
    f"{gcs_fuse_path(finetuning_workdir)}",
]
print(docker_args_list)
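The argument list above is a flat sequence of flag/value pairs, which is easy to get out of order when editing. A hypothetical helper such as `flags_to_args` lets you keep the settings in a dict and flatten them at the end; it relies on dicts preserving insertion order (guaranteed since Python 3.7) and converts every value with `str`.

```python
def flags_to_args(flags: dict) -> list:
    """Flattens {"--flag": value} into ["--flag", "value", ...], preserving insertion order."""
    args = []
    for flag, value in flags.items():
        args.extend([flag, str(value)])
    return args


example_args = flags_to_args(
    {
        "--config": "vit_jax/configs/vit.py:b16",
        "--config.batch": 128,
        "--config.base_lr": 0.01,
        "--config.total_steps": 100,
    }
)
print(example_args)
```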

In [None]:
# Create and run the training job.
# Click on the generated link in the output under "View backing custom job:" to see your run in the Cloud Console.
NUM_GPU = 1
container_uri = TRAIN_DOCKER_URI
job = aiplatform.CustomContainerTrainingJob(
    display_name=JOB_NAME,
    container_uri=container_uri,
)
model = job.run(
    args=docker_args_list,
    base_output_dir=f"{finetuning_workdir}",
    replica_count=1,
    machine_type="n1-standard-4",
    accelerator_type="NVIDIA_TESLA_V100",
    accelerator_count=NUM_GPU,
)

## Convert JAX Vision Transformer model to TF SavedModel

Convert the previously fine-tuned JAX model to a TF SavedModel for online prediction.

In [None]:
# Set up model conversion docker arguments.
# Note: Many of the arguments below are similar to the training job
# such as the model name and train and test data related parameters.

jax_checkpoint_dir = finetuning_workdir

TIMESTAMP = datetime.now().strftime("%Y%m%d_%H%M%S")
JOB_NAME = "jax_model_conversion" + TIMESTAMP
saved_model_dir = os.path.join(BUCKET_URI, "jax2tf_" + TIMESTAMP)

docker_args_list = [
    "--config",
    "vit_jax/configs/vit.py:b16",
    "--num_classes",
    f"{NUM_CLASSES}",
    "--saved_model_dir",
    f"{saved_model_dir}",
    "--jax_checkpoint_dir",
    f"{jax_checkpoint_dir}",
    "--config.pretrained_dir",
    f"{pre_trained_dir}",
    "--config.dataset",
    f"{gcs_fuse_path(FLOWER_DATA_GCS_PATH)}",
    "--config.pp.train",
    "train",
    "--config.pp.test",
    "test",
    "--config.pp.crop",
    "224",
]
print(docker_args_list)
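The conversion job must describe the same model as the training job, for example the same `--config` variant and `--config.pp.crop`. The cell above overwrites `docker_args_list`, so to compare the two you would need to keep a copy of the training list under another name. The hypothetical helpers below pair up flags and report shared flags whose values differ; with this notebook's arguments, expect `--config.pretrained_dir` to show up, because training passes the `/gcs/` FUSE path while conversion passes the `gs://` URI.

```python
def args_to_flags(args: list) -> dict:
    """Pairs up ["--flag", "value", ...] into {"--flag": "value"}."""
    if len(args) % 2:
        raise ValueError("Expected an even number of items (flag/value pairs).")
    return dict(zip(args[0::2], args[1::2]))


def mismatched_shared_flags(train_args: list, convert_args: list) -> dict:
    """Returns {flag: (train_value, convert_value)} for flags set differently in both lists."""
    train, convert = args_to_flags(train_args), args_to_flags(convert_args)
    return {
        flag: (train[flag], convert[flag])
        for flag in train.keys() & convert.keys()
        if train[flag] != convert[flag]
    }


train_demo = ["--config", "vit_jax/configs/vit.py:b16", "--config.pp.crop", "224"]
convert_demo = ["--config", "vit_jax/configs/vit.py:b16", "--config.pp.crop", "384"]
print(mismatched_shared_flags(train_demo, convert_demo))
```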

In [None]:
# Create and run the model conversion job.
# Click on the generated link in the output under "View backing custom job:" to see your run in the Cloud Console.
container_uri = MODEL_CONVERSION_DOCKER_URI
job = aiplatform.CustomContainerTrainingJob(
    display_name=JOB_NAME,
    container_uri=container_uri,
)
model_conversion_workdir = os.path.join(BUCKET_URI, JOB_NAME)
model = job.run(
    args=docker_args_list,
    base_output_dir=f"{model_conversion_workdir}",
    replica_count=1,
    machine_type="n1-standard-4",
)

## Run online prediction

Upload the converted TF SavedModel to Vertex AI Model Registry and deploy it to an endpoint for online prediction. This step takes about 15 minutes to finish.

In [None]:
jax_vit_model = aiplatform.Model.upload(
    display_name="jax_vit",
    artifact_uri=saved_model_dir,
    serving_container_image_uri=OPTIMIZED_TF_RUNTIME_IMAGE_URI,
    serving_container_args=[],
    location=REGION,
)

jax_vit_endpoint = jax_vit_model.deploy(
    deployed_model_display_name="jax_vit_deployed",
    traffic_split={"0": 100},
    machine_type="n1-standard-4",
    accelerator_type="NVIDIA_TESLA_V100",
    accelerator_count=1,
    min_replica_count=1,
    max_replica_count=1,
)

Load a local test image file, encode it into a string, send it to the endpoint for prediction, and then generate the final class label from the predicted class probabilities.
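Each prediction instance wraps the base64-encoded image under a `b64` key, the JSON convention TensorFlow Serving style endpoints use for binary data, nested inside `bytes_inputs`, the input name this notebook's exported model is sent. The stdlib-only sketch below builds the same structure from raw file bytes, skipping the resize and JPEG re-encoding that `load_bytes_from_local_image` performs; `build_instance_from_bytes` is a hypothetical name, and the placeholder bytes stand in for a real image file.

```python
import base64
import json


def build_instance_from_bytes(image_bytes: bytes) -> dict:
    """Wraps raw image bytes as {"bytes_inputs": {"b64": <base64 string>}}."""
    return {"bytes_inputs": {"b64": base64.b64encode(image_bytes).decode("utf-8")}}


# Placeholder bytes stand in for the contents of a JPEG file.
fake_image_bytes = b"\xff\xd8\xff\xe0 not a real jpeg"
instance = build_instance_from_bytes(fake_image_bytes)
print(json.dumps({"instances": [instance]})[:80])
```

On real data you would read the bytes with `open(local_test_image_path, "rb").read()`.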

In [None]:
test_directory = os.path.join(local_flower_data_directory, "test/tulips")
local_test_image_path = os.path.join(test_directory, os.listdir(test_directory)[0])
print(local_test_image_path)
instances_list = [
    {
        "bytes_inputs": {
            "b64": load_bytes_from_local_image(local_test_image_path, new_width=240)
        }
    }
]
instances = [json_format.ParseDict(s, Value()) for s in instances_list]
results = jax_vit_endpoint.predict(instances=instances)
logits = results.predictions[0]
predicted_label = LABEL_IDX_TO_STR[int(np.argmax(logits))]
print("predicted_label: ", predicted_label)
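The cell above takes the `argmax` of `results.predictions[0]`, which the notebook names `logits`. Assuming those are unnormalized scores, a softmax turns them into class probabilities, and sorting them gives the top-k labels; if the model already returns probabilities, skip the softmax (the ranking is the same either way, since softmax preserves order). The sketch uses made-up scores and hypothetical helper names; on real output you would pass `results.predictions[0]` and `LABEL_IDX_TO_STR`.

```python
import math


def softmax(logits):
    """Numerically stable softmax over a list of floats."""
    m = max(logits)
    exps = [math.exp(x - m) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]


def top_k_labels(logits, idx_to_label, k=3):
    """Returns the k most probable (label, probability) pairs."""
    probs = softmax(logits)
    order = sorted(range(len(probs)), key=lambda i: probs[i], reverse=True)
    return [(idx_to_label[i], probs[i]) for i in order[:k]]


demo_labels = {0: "daisy", 1: "dandelion", 2: "roses", 3: "sunflowers", 4: "tulips"}
demo_logits = [0.1, -1.2, 0.4, -0.3, 2.5]  # made-up scores, not real model output
print(top_k_labels(demo_logits, demo_labels, k=2))
```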

## Cleaning up

To clean up all Google Cloud resources used in this project, you can [delete the Google Cloud project](https://cloud.google.com/resource-manager/docs/creating-managing-projects#shutting_down_projects) you used for the tutorial.

Otherwise, you can delete the individual resources you created in this tutorial:

In [None]:
# Delete endpoint resource.
jax_vit_endpoint.delete(force=True)

# Delete model resource.
jax_vit_model.delete()

# Delete the Cloud Storage bucket and ALL of its contents, including any
# objects that existed before this tutorial. Set to True to delete it.
delete_bucket = False
if delete_bucket or os.getenv("IS_TESTING"):
    ! gsutil -m rm -r $BUCKET_URI