# Deploy a UL2 model on GPUs with NVIDIA Triton Inference Server using the NVIDIA FasterTransformer backend

## Define variables

In [None]:
PROJECT_ID = "[your-project-id]"  # @param {type:"string"}
import os

# Get your Google Cloud project ID using google.auth
import google.auth

# Only fall back to the project of the Application Default Credentials when no project was
# entered above, so a value set in the form is not silently discarded.
# google.auth.default() returns a (credentials, project_id) tuple; project_id may be None.
if PROJECT_ID in ("", None, "[your-project-id]"):
    _, PROJECT_ID = google.auth.default()
print("Project ID: ", PROJECT_ID)

# validate PROJECT_ID: stop here instead of failing later in less obvious ways
if PROJECT_ID == "" or PROJECT_ID is None or PROJECT_ID == "[your-project-id]":
    raise ValueError(
        f"Please set your project id before proceeding to next step. Currently it's set as {PROJECT_ID}"
    )

In [None]:
# Region passed to the gcloud commands below (Artifact Registry, Cloud Build, Vertex AI jobs/endpoints).
REGION = "us-central1"  # @param {type: "string"}

In [None]:
from datetime import datetime

# Run timestamp in local time, formatted as YYYYMMDDHHMMSS (14 digits).
TIMESTAMP = f"{datetime.now():%Y%m%d%H%M%S}"

In [None]:
# Cloud Storage bucket (gs:// URI) for SDK staging, job outputs and the converted model.
BUCKET_NAME = "gs://[your-bucket-name]"  # @param {type:"string"}

# Drop any trailing slash so paths built as f"{BUCKET_NAME}/..." do not contain "//".
BUCKET_NAME = BUCKET_NAME.rstrip("/")

if BUCKET_NAME in ("", "gs:", "gs://[your-bucket-name]") or not BUCKET_NAME.startswith("gs://"):
    print(
        f"Please set your bucket name before proceeding to next step. Currently it's set as {BUCKET_NAME}"
    )

## Set up an Artifact Registry repository

- Enable the Artifact Registry API service for your project.

In [None]:
# Enable the Artifact Registry API in PROJECT_ID (rather than whatever project gcloud defaults to).
! gcloud services enable artifactregistry.googleapis.com --project={PROJECT_ID}

- Create a private Docker repository to push the container images to

In [None]:
# Name of the Docker-format Artifact Registry repository created in the next cell.
DOCKER_ARTIFACT_REPO = "llms-on-vertex-ai"

In [None]:
# Create a Docker-format Artifact Registry repository in REGION.
# Re-running this cell reports an error if the repository already exists; that is safe to ignore.
! gcloud artifacts repositories create {DOCKER_ARTIFACT_REPO} \
    --project={PROJECT_ID} \
    --repository-format=docker \
    --location={REGION} \
    --description="Triton Docker repository"

# Verify that the repository was created.
! gcloud artifacts repositories list \
    --project={PROJECT_ID} \
    --location={REGION} \
    --filter="name~{DOCKER_ARTIFACT_REPO}"

In [None]:
# Register gcloud as Docker's credential helper for the regional Artifact Registry host
# (<REGION>-docker.pkg.dev), so docker push/pull to that host are authenticated.
! gcloud auth configure-docker {REGION}-docker.pkg.dev --quiet

- Configure authentication to the private repo (done by the `gcloud auth configure-docker` cell above)

Before you push or pull container images, Docker must be configured to use the gcloud command-line tool to authenticate requests to Artifact Registry for your region.

## Convert the JAX checkpoint to FasterTransformer

- [ ] Conversion steps
    - [ ] Download JAX checkpoint from GCS
    - [ ] Run conversion script to convert from JAX to FT
    - [ ] Validate conversion is running fine
    - [ ] Organize model repository as Triton's spec
    - [ ] Upload FT checkpoint to GCS
- [ ] Build container image to run conversion
    - [ ] Prepare the Dockerfile
- [ ] Configure infra to run conversion
    - [ ] Pick a compute option: GCE, Vertex AI Training Custom Job, GKE, Cloud Batch
    - [ ] Configure compute spec
- [ ] Run conversion

### JAX --> FT Conversion Script

In [None]:
%%writefile src/run-converter-jax-to-fastertransformer.sh
#!/bin/bash

# Copyright 2022 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# Convert a JAX checkpoint to FasterTransformer, lay it out as a Triton model
# repository (ul2/config.pbtxt + ul2/1/<weights>) and upload it to Cloud Storage.
#
# Usage:
#   run-converter-jax-to-fastertransformer.sh GCS_JAX_CHECKPOINT GCS_FT_CHECKPOINT \
#       [TENSOR_PARALLELISM] [PIPELINE_PARALLELISM]
#
#   GCS_JAX_CHECKPOINT    gs:// URI of the JAX checkpoint directory to convert
#   GCS_FT_CHECKPOINT     gs:// directory the converted model repository is uploaded under
#   TENSOR_PARALLELISM    tensor-parallel size (1, 2, 4 or 8); default 1
#   PIPELINE_PARALLELISM  value for @@PIPELINE_PARA_SIZE@@ in the Triton config; default 1
#
# NOTE: the `#!/bin/bash` line must stay the first line of the written file so the
# script can be run directly as a container command.

# Set up a global error handler
err_handler() {
    echo "Error on line: $1"
    echo "Caused by: $2"
    echo "That returned exit status: $3"
    echo "Aborting..."
    exit "$3"
}

trap 'err_handler "$LINENO" "$BASH_COMMAND" "$?"' ERR

# Log a message prefixed with the current time (evaluated per call, not once at start-up).
log() {
    echo "[INFO] $(date "+%Y-%m-%d %H:%M:%S") $*"
}

if [[ $# -lt 2 ]]; then
    echo "Usage: $0 GCS_JAX_CHECKPOINT GCS_FT_CHECKPOINT [TENSOR_PARALLELISM] [PIPELINE_PARALLELISM]" >&2
    exit 1
fi

echo "NVIDIA Driver version"
nvidia-smi

# Set variables
GCS_JAX_CHECKPOINT=$1
echo "GCS_JAX_CHECKPOINT = ${GCS_JAX_CHECKPOINT}"

# Normalize to exactly one trailing slash so sub-paths can be appended below.
GCS_FT_CHECKPOINT="${2%/}/"
echo "GCS_FT_CHECKPOINT = ${GCS_FT_CHECKPOINT}"

TENSOR_PARALLELISM=${3:-1}
echo "TENSOR_PARALLELISM = ${TENSOR_PARALLELISM}"

# Pipeline parallelism is independent of tensor parallelism; it must not reuse TENSOR_PARALLELISM.
PIPELINE_PARALLELISM=${4:-1}
echo "PIPELINE_PARALLELISM = ${PIPELINE_PARALLELISM}"

# Copy JAX checkpoint to local directory
LOCAL_JAX_CHECKPOINT="/models/$(basename "$GCS_JAX_CHECKPOINT")"
mkdir -p "$LOCAL_JAX_CHECKPOINT"
log "Copying JAX checkpoint from ${GCS_JAX_CHECKPOINT} to local ${LOCAL_JAX_CHECKPOINT}"
SECONDS=0
gcloud storage cp --quiet --recursive "$GCS_JAX_CHECKPOINT" /models/
log "Completed copying JAX checkpoint locally in ${SECONDS}s"

# Creating local directories for saving FasterTransformer model
LOCAL_FT_CHECKPOINT="/models/ul2-ft"
mkdir -p "$LOCAL_FT_CHECKPOINT"

# Run JAX to FasterTransformer.
# `cd` is a separate statement: inside an `a && b` list a failing `cd` does not fire the ERR trap.
log "Converting JAX checkpoint to FasterTransformer"
SECONDS=0
cd /FasterTransformer/build
python3 ../examples/pytorch/t5/utils/jax_t5_ckpt_convert.py \
    "$LOCAL_JAX_CHECKPOINT" \
    "$LOCAL_FT_CHECKPOINT" \
    --tensor-parallelism "$TENSOR_PARALLELISM"
log "Completed converting JAX checkpoint to FasterTransformer in ${SECONDS}s"

# Weights directory written by the converter.
# NOTE(review): assumes the converter names it "<TENSOR_PARALLELISM>-gpu" (i.e. "1-gpu" for the
# default of 1, as before); confirm for larger tensor-parallel sizes.
FT_WEIGHTS_DIR="${TENSOR_PARALLELISM}-gpu"

# Organize model repository for Triton serving
log "Organizing model repository for serving"
MODEL_DIR="${LOCAL_FT_CHECKPOINT}/ul2"
mkdir -p "${MODEL_DIR}/1"
mv "${LOCAL_FT_CHECKPOINT}/${FT_WEIGHTS_DIR}" "${MODEL_DIR}/1/"

# `gcloud storage cp --recursive DIR DEST/` below uploads DIR itself, so the repository
# lands under ${GCS_FT_CHECKPOINT}<basename of LOCAL_FT_CHECKPOINT>/ (i.e. .../ul2-ft/).
GCS_MODEL_REPOSITORY="${GCS_FT_CHECKPOINT}$(basename "$LOCAL_FT_CHECKPOINT")"

# Format Triton config for UL2
CONFIG_FILE="${MODEL_DIR}/config.pbtxt"
cp /triton/config.pbtxt "$CONFIG_FILE"
sed -i \
    -e "s!@@MODEL_CHECKPOINT_PATH@@!${GCS_MODEL_REPOSITORY}/ul2/1/${FT_WEIGHTS_DIR}!g" \
    -e "s!@@TENSOR_PARA_SIZE@@!${TENSOR_PARALLELISM}!g" \
    -e "s!@@PIPELINE_PARA_SIZE@@!${PIPELINE_PARALLELISM}!g" \
    "$CONFIG_FILE"

# Upload FasterTransformer checkpoint to Cloud Storage bucket
log "Copying FasterTransformer model from local ${LOCAL_FT_CHECKPOINT} to ${GCS_FT_CHECKPOINT}"
SECONDS=0
gcloud storage cp --quiet --recursive "$LOCAL_FT_CHECKPOINT" "$GCS_FT_CHECKPOINT"
log "Completed copying FasterTransformer model to Cloud Storage (${GCS_MODEL_REPOSITORY}) in ${SECONDS}s"

### Build container image to run conversion process

In [None]:
%%writefile src/Dockerfile.jax-to-fastertransformer

# Copyright 2022 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

FROM nvcr.io/nvidia/pytorch:22.07-py3

# Build-time knobs (override with --build-arg):
#   SM      GPU compute capability FasterTransformer is compiled for: 80 for A100, 70 for V100
#   FT_REF  FasterTransformer git branch or tag to build (pin a tag for reproducible images)
ARG SM=80
ARG FT_REF=main

# Install gcloud SDK.
# Refresh the package index first and pass -y: `docker build` cannot answer apt's confirmation prompt.
RUN apt-get update -y && \
    apt-get install -y apt-transport-https ca-certificates gnupg
RUN echo "deb [signed-by=/usr/share/keyrings/cloud.google.gpg] http://packages.cloud.google.com/apt cloud-sdk main" | tee -a /etc/apt/sources.list.d/google-cloud-sdk.list && \
    curl https://packages.cloud.google.com/apt/doc/apt-key.gpg | apt-key --keyring /usr/share/keyrings/cloud.google.gpg add - && \
    apt-get update -y && \
    apt-get install -y google-cloud-cli

# Clone FasterTransformer repo
WORKDIR /
RUN git clone --branch=${FT_REF} https://github.com/NVIDIA/FasterTransformer.git

# Build FasterTransformer for the GPU architecture selected by SM
WORKDIR /FasterTransformer
RUN mkdir build && \
    cd build && \
    cmake -DSM=${SM} -DCMAKE_BUILD_TYPE=Release -DBUILD_PYT=ON -DBUILD_MULTI_GPU=ON .. && \
    make -j12

# Install other required packages
RUN pip install -r /FasterTransformer/examples/pytorch/t5/requirement.txt
RUN pip install transformers==4.20.1 zarr
RUN pip install --upgrade "jax[cuda]" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html

# Copy conversion script
COPY ul2-inference/converter/jax_t5_ckpt_convert.py ul2-inference/converter/ul2_config.template /FasterTransformer/examples/pytorch/t5/utils/
COPY run-converter-jax-to-fastertransformer.sh /run-converter-jax-to-fastertransformer.sh
# Make the script executable so it can be used directly as the container command,
# not only through `/bin/bash <script>`.
RUN chmod +x /run-converter-jax-to-fastertransformer.sh

# Copy Triton related config
RUN mkdir -p /triton
COPY ul2-inference/triton/config.pbtxt /triton/config.pbtxt

- Define the image name and the registry path (image URI) that the image will be tagged with and pushed to

In [None]:
# JAX to FasterTransformer converter container image name
JAX_TO_FT_IMAGE_NAME = "jax-to-fastertransformer"
# Push to the Artifact Registry repository created above, whose host Docker is authenticated
# for via `gcloud auth configure-docker` (a gcr.io path would bypass that repository entirely).
JAX_TO_FT_IMAGE_URI = f"{REGION}-docker.pkg.dev/{PROJECT_ID}/{DOCKER_ARTIFACT_REPO}/{JAX_TO_FT_IMAGE_NAME}"

In [None]:
# Display the converter image URI that Cloud Build tags and pushes.
JAX_TO_FT_IMAGE_URI

- Create Cloud Build configuration file

In [None]:
%%writefile src/cloudbuild.yaml

# Copyright 2021 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# Single step: docker build using Dockerfile.<_DOCKERNAME> from the _FILE_LOCATION
# build context, tagged as _IMAGE_URI; the `images` entry makes Cloud Build push it.
steps:
  - name: 'gcr.io/cloud-builders/docker'
    args:
      - 'build'
      - '-t'
      - '$_IMAGE_URI'
      - '$_FILE_LOCATION'
      - '-f'
      - '$_FILE_LOCATION/Dockerfile.$_DOCKERNAME'
images:
  - '$_IMAGE_URI'

In [None]:
# Inspect the build context: the Dockerfile, converter script and cloudbuild.yaml written above.
! ls -ltr ./src/

In [None]:
# Local directory holding the Dockerfile, the converter script and the files the Dockerfile COPYs.
FILE_LOCATION = './src'

# Upload only FILE_LOCATION as the build source (not the whole notebook working directory,
# which may contain large local checkpoints); inside the build it is the workspace root,
# hence _FILE_LOCATION=. for the docker build context.
! gcloud builds submit {FILE_LOCATION} \
      --project={PROJECT_ID} \
      --region={REGION} \
      --config={FILE_LOCATION}/cloudbuild.yaml \
      --substitutions=_DOCKERNAME={JAX_TO_FT_IMAGE_NAME},_IMAGE_URI={JAX_TO_FT_IMAGE_URI},_FILE_LOCATION=. \
      --timeout=2h \
      --machine-type=e2-highcpu-32 \
      --quiet

- Alternatively, use a local Docker daemon instead of Cloud Build to build and push the container image. The image URI is used by the Vertex AI custom job that runs the conversion.

In [None]:
# Alternative to Cloud Build: build and push the image with a local Docker daemon.
# The build context is ./src, where the Dockerfile and the files it COPYs were written.
# ! docker build -t {JAX_TO_FT_IMAGE_URI} ./src -f ./src/Dockerfile.jax-to-fastertransformer
# ! docker push {JAX_TO_FT_IMAGE_URI}

In [None]:
# Alternative: run the converter with Docker on a GCE VM with GPUs (shell command, not a notebook cell).
# Replace <IMAGE_URI>, <GCS_JAX_CHECKPOINT> and <GCS_FT_OUTPUT_DIR> with your own values.
# docker run -ti \
#   --gpus all \
#   --shm-size 5g \
#   -p 9999:9999 \
#   -v $PWD:$PWD \
#   -v /home/jupyter/models:/models \
#   -w $PWD \
#   --name ft-converter \
#   <IMAGE_URI> \
#   /bin/bash /run-converter-jax-to-fastertransformer.sh \
#     <GCS_JAX_CHECKPOINT> \
#     <GCS_FT_OUTPUT_DIR>

### Configure a Vertex AI Training CustomJob to run the JAX --> FT conversion

In [None]:
%%bash -s $BUCKET_NAME $JAX_TO_FT_IMAGE_URI

# Write the Vertex AI CustomJob spec for the JAX -> FasterTransformer conversion.
# This variant uses the script itself as the container command, so inside the image the
# script must be executable and start with a shebang line.
# NOTE: the next cell writes the same file, so its spec replaces this one when both are run.

BUCKET_NAME=$1
JAX_TO_FT_IMAGE_URI=$2

# Input JAX checkpoint, and the directory under which the converted model repository is uploaded
# (derived from BUCKET_NAME instead of a hard-coded bucket).
GCS_JAX_CHECKPOINT="gs://se-checkpoints/ul2-xsum"
GCS_FT_CHECKPOINT="${BUCKET_NAME}/llm/models/ul2/ft/"

cat << EOF > ./src/config.jax-to-fastertransformer.yaml

baseOutputDirectory:
    outputUriPrefix: ${BUCKET_NAME}/llm/jobs/ul2-jax-to-f/$(date "+%Y%m%d-%H%M%S")/
workerPoolSpecs:
  -
    machineSpec:
      machineType: a2-highgpu-1g
      acceleratorType: NVIDIA_TESLA_A100
      acceleratorCount: 1
    replicaCount: 1
    diskSpec:
      bootDiskType: pd-ssd
      bootDiskSizeGb: 500
    containerSpec:
      imageUri: ${JAX_TO_FT_IMAGE_URI}
      command:
      - /run-converter-jax-to-fastertransformer.sh
      args:
      - "${GCS_JAX_CHECKPOINT}"
      - "${GCS_FT_CHECKPOINT}"
EOF

In [None]:
%%bash -s $BUCKET_NAME $JAX_TO_FT_IMAGE_URI

# Write the Vertex AI CustomJob spec for the JAX -> FasterTransformer conversion.
# The script is launched through /bin/bash, so it does not need the executable bit in the image.
# This overwrites the spec written by the previous cell.

BUCKET_NAME=$1
JAX_TO_FT_IMAGE_URI=$2

# Input JAX checkpoint, and the directory under which the converted model repository is uploaded
# (derived from BUCKET_NAME instead of a hard-coded bucket).
GCS_JAX_CHECKPOINT="gs://se-checkpoints/ul2-xsum"
GCS_FT_CHECKPOINT="${BUCKET_NAME}/llm/models/ul2/ft/"

cat << EOF > ./src/config.jax-to-fastertransformer.yaml

baseOutputDirectory:
    outputUriPrefix: ${BUCKET_NAME}/llm/jobs/ul2-jax-to-f/$(date "+%Y%m%d-%H%M%S")/
workerPoolSpecs:
  -
    machineSpec:
      machineType: a2-highgpu-1g
      acceleratorType: NVIDIA_TESLA_A100
      acceleratorCount: 1
    replicaCount: 1
    diskSpec:
      bootDiskType: pd-ssd
      bootDiskSizeGb: 500
    containerSpec:
      imageUri: ${JAX_TO_FT_IMAGE_URI}
      command:
      - /bin/bash
      - /run-converter-jax-to-fastertransformer.sh
      - "${GCS_JAX_CHECKPOINT}"
      - "${GCS_FT_CHECKPOINT}"

EOF

In [None]:
# Review the generated CustomJob spec before submitting it.
! cat ./src/config.jax-to-fastertransformer.yaml

### Run the conversion as a Vertex AI Training CustomJob

In [None]:
# Submit the conversion as a Vertex AI CustomJob using the spec file written above.
! gcloud beta ai custom-jobs create \
  --project={PROJECT_ID} \
  --region={REGION} \
  --display-name=llm-ul2-jax-to-ft-conversion \
  --config=./src/config.jax-to-fastertransformer.yaml

## Deploying FasterTransformer Checkpoint on Vertex AI Prediction Endpoints

In [None]:
# Vertex AI SDK for Python, used below for Model Registry upload, endpoints and deployment.
from google.cloud import aiplatform as aip

In [None]:
# Initialize the Vertex AI SDK. Pass REGION explicitly so the model and endpoint created through
# the SDK live in the same region the gcloud commands in this notebook use.
aip.init(project=PROJECT_ID, location=REGION, staging_bucket=BUCKET_NAME)

In [None]:
# Triton model repository in Cloud Storage: the conversion script uploads its local `ul2-ft`
# directory under the GCS_FT_CHECKPOINT directory passed to the job (.../llm/models/ul2/ft/).
MODEL_ARTIFACTS_REPOSITORY = f"{BUCKET_NAME}/llm/models/ul2/ft/ul2-ft"

MODEL_NAME = "llms-ul2-xsum-inference"
MODEL_DISPLAY_NAME = f"triton-{MODEL_NAME}"
ENDPOINT_DISPLAY_NAME = f"endpoint-{MODEL_NAME}"

# NGC Triton inference image (requires allow listing); presumably the base of the serving image.
NGC_TRITON_IMAGE_URI = "nvcr.io/ea-bignlp/bignlp-inference:22.08-py3"

# prediction container image name; the repository path reuses DOCKER_ARTIFACT_REPO instead of
# repeating its value, so renaming the repository updates this URI too.
IMAGE_NAME = "nemo-bignlp-triton-inference"
IMAGE_URI = f"gcr.io/{PROJECT_ID}/{DOCKER_ARTIFACT_REPO}/{IMAGE_NAME}"

In [None]:
# Echo the resolved serving configuration.
for _label, _value in (
    ("MODEL_DISPLAY_NAME", MODEL_DISPLAY_NAME),
    ("IMAGE_URI", IMAGE_URI),
    ("MODEL_ARTIFACTS_REPOSITORY", MODEL_ARTIFACTS_REPOSITORY),
):
    print(f"{_label} = {_value}")

- Upload the FasterTransformer model to Vertex AI Model Registry

In [None]:
# Optional: resource name of an existing Model Registry model to upload this model as a new
# version of, e.g. "projects/<PROJECT_NUMBER>/locations/<REGION>/models/<MODEL_ID>".
# Leave as None to register a new model (a hard-coded resource from another project fails for other users).
PARENT_MODEL = None

In [None]:
# Arguments for the serving container: point Triton at the GCS model repository and enable
# strict config checking plus verbose/error logging. artifact_uri is intentionally not set;
# the repository location is passed through --model-repository instead.
triton_serving_args = [
    f'--model-repository={MODEL_ARTIFACTS_REPOSITORY}',
    '--strict-model-config=true',
    '--log-verbose=99',
    '--log-error=1',
]

model = aip.Model.upload(
    display_name=MODEL_DISPLAY_NAME,
    parent_model=PARENT_MODEL,
    serving_container_image_uri=IMAGE_URI,
    serving_container_args=triton_serving_args,
    sync=True,
)

model.resource_name

In [None]:
# Create a new Vertex AI endpoint (with no deployed models yet) to deploy the model to.
endpoint = aip.Endpoint.create(display_name=ENDPOINT_DISPLAY_NAME)

In [None]:
# Optionally re-attach to an existing model / endpoint (e.g. after a kernel restart) instead of
# using the objects created above. Set full resource names such as
#   "projects/<PROJECT_NUMBER>/locations/<REGION>/models/<MODEL_ID>@<VERSION>"
#   "projects/<PROJECT_NUMBER>/locations/<REGION>/endpoints/<ENDPOINT_ID>"
# Leaving them as None keeps `model` and `endpoint` from the previous cells.
EXISTING_MODEL_RESOURCE_NAME = None
EXISTING_ENDPOINT_RESOURCE_NAME = None

if EXISTING_MODEL_RESOURCE_NAME:
    model = aip.Model(EXISTING_MODEL_RESOURCE_NAME)
if EXISTING_ENDPOINT_RESOURCE_NAME:
    endpoint = aip.Endpoint(EXISTING_ENDPOINT_RESOURCE_NAME)

- Create [custom service account](https://cloud.google.com/vertex-ai/docs/general/custom-service-account) to access model repository

In [None]:
# List existing service accounts in the active gcloud project.
! gcloud iam service-accounts list

In [None]:
# create role
! gcloud iam roles create storage_buckets_viewer \
    --project=$PROJECT_ID \
    --title="storage.buckets.get" \
    --description="Storage object reader" \
    --permissions="storage.buckets.get"

In [None]:
CUSTOM_SA_NAME = "vertex-ai-llm-sa"

# Create the service account in PROJECT_ID explicitly: the following cells build its email as
# <CUSTOM_SA_NAME>@<PROJECT_ID>.iam.gserviceaccount.com, so it must live in that project.
! gcloud iam service-accounts create {CUSTOM_SA_NAME} \
    --project={PROJECT_ID} \
    --description="Custom service account to attach to Vertex AI resources used for LLMs" \
    --display-name={CUSTOM_SA_NAME}

In [None]:
# Grant the custom storage_buckets_viewer role to the custom service account, project-wide and
# without an IAM condition.
! gcloud projects add-iam-policy-binding {PROJECT_ID} \
    --member="serviceAccount:{CUSTOM_SA_NAME}@{PROJECT_ID}.iam.gserviceaccount.com" \
    --role="projects/{PROJECT_ID}/roles/storage_buckets_viewer" \
    --condition=None

In [None]:
# Look up the numeric project number of PROJECT_ID with an exact `describe`, rather than a
# fuzzy `projects list --filter=<gcloud default project>`, and keep it in Python.
_project_number_out = !gcloud projects describe {PROJECT_ID} --format="value(projectNumber)"
PROJECT_NUMBER = _project_number_out[0].strip() if _project_number_out else ""
if not PROJECT_NUMBER.isdigit():
    raise RuntimeError(f"Could not look up the project number of {PROJECT_ID}: {list(_project_number_out)}")
PROJECT_NUMBER

In [None]:
# Let the Vertex AI service agent (service-<PROJECT_NUMBER>@gcp-sa-aiplatform.iam.gserviceaccount.com)
# administer the custom service account; the binding is set on that service account only.
# Values are computed in Python so the shell command contains no mixed IPython/shell `$` expansion
# and no hard-coded project.
_project_number_out = !gcloud projects describe {PROJECT_ID} --format="value(projectNumber)"
_project_number = _project_number_out[0].strip() if _project_number_out else ""
if not _project_number.isdigit():
    raise RuntimeError(f"Could not look up the project number of {PROJECT_ID}: {list(_project_number_out)}")

AI_PLATFORM_SERVICE_AGENT = f"service-{_project_number}@gcp-sa-aiplatform.iam.gserviceaccount.com"
CUSTOM_SA_EMAIL = f"{CUSTOM_SA_NAME}@{PROJECT_ID}.iam.gserviceaccount.com"

! gcloud iam service-accounts add-iam-policy-binding {CUSTOM_SA_EMAIL} \
    --project={PROJECT_ID} \
    --role=roles/iam.serviceAccountAdmin \
    --member="serviceAccount:{AI_PLATFORM_SERVICE_AGENT}"

- Deploy model to endpoint

In [None]:
%%time

# Deploy with gcloud (alpha), enabling access/container logging and running as the custom
# service account. IDs come from the `model` / `endpoint` objects instead of hard-coded values.
# The SDK `model.deploy(...)` cell below deploys the same model to the same endpoint; run only
# one of the two, or the endpoint ends up with two deployed models.
! gcloud alpha ai endpoints deploy-model {endpoint.name} \
    --project={PROJECT_ID} \
    --region={REGION} \
    --model={model.name} \
    --display-name={MODEL_DISPLAY_NAME} \
    --machine-type=a2-highgpu-1g \
    --accelerator=count=1,type=nvidia-tesla-a100 \
    --enable-access-logging \
    --enable-container-logging \
    --service-account={CUSTOM_SA_NAME}@{PROJECT_ID}.iam.gserviceaccount.com

In [None]:
# Deployment settings
traffic_percentage = 100
machine_type = "a2-highgpu-1g"
accelerator_type = "NVIDIA_TESLA_A100"
accelerator_count = 1
min_replica_count = 1
max_replica_count = 2

# Run the serving container as the custom service account configured in the IAM cells above,
# matching the gcloud deployment; without it Vertex AI uses its default service identity.
# NOTE: the gcloud cell above deploys the same model to the same endpoint; run only one of them.
custom_sa_email = f"{CUSTOM_SA_NAME}@{PROJECT_ID}.iam.gserviceaccount.com"

model.deploy(
    endpoint=endpoint,
    deployed_model_display_name=MODEL_DISPLAY_NAME,
    machine_type=machine_type,
    min_replica_count=min_replica_count,
    max_replica_count=max_replica_count,
    traffic_percentage=traffic_percentage,
    accelerator_type=accelerator_type,
    accelerator_count=accelerator_count,
    service_account=custom_sa_email,
    sync=True,
)

endpoint.name