In [None]:
# Copyright 2025 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# Vertex AI Model Garden - TIMM

<table><tbody><tr>
  <td style="text-align: center">
    <a href="https://console.cloud.google.com/vertex-ai/colab/import/https:%2F%2Fraw.githubusercontent.com%2FGoogleCloudPlatform%2Fvertex-ai-samples%2Fmain%2Fnotebooks%2Fcommunity%2Fmodel_garden%2Fmodel_garden_pytorch_timm.ipynb">
      <img alt="Google Cloud Colab Enterprise logo" src="https://lh3.googleusercontent.com/JmcxdQi-qOpctIvWKgPtrzZdJJK-J3sWE1RsfjZNwshCFgE_9fULcNpuXYTilIR2hjwN" width="32px"><br> Run in Colab Enterprise
    </a>
  </td>
  <td style="text-align: center">
    <a href="https://github.com/GoogleCloudPlatform/vertex-ai-samples/blob/main/notebooks/community/model_garden/model_garden_pytorch_timm.ipynb">
      <img alt="GitHub logo" src="https://github.githubassets.com/assets/GitHub-Mark-ea2971cee799.png" width="32px"><br> View on GitHub
    </a>
  </td>
</tr></tbody></table>

## Overview

This notebook demonstrates running local inference using the [timm](https://github.com/rwightman/pytorch-image-models) library, finetuning the PyTorch [timm models](https://github.com/huggingface/pytorch-image-models#models), and deploying the models on [Vertex AI](https://cloud.google.com/vertex-ai).

### Objective

- Set up the environment.
- Run inference locally using the timm library.
- Create a custom training job on Vertex AI to train or finetune a model.
- Deploy the model on Vertex AI for online prediction.

### Costs

This tutorial uses billable components of Google Cloud:

* Vertex AI
* Cloud Storage

Learn about [Vertex AI pricing](https://cloud.google.com/vertex-ai/pricing), [Cloud Storage pricing](https://cloud.google.com/storage/pricing), and use the [Pricing Calculator](https://cloud.google.com/products/calculator/) to generate a cost estimate based on your projected usage.

## Before you begin

In [None]:
# @title Setup Google Cloud project

# @markdown 1. [Make sure that billing is enabled for your project](https://cloud.google.com/billing/docs/how-to/modify-project).

# @markdown 2. For finetuning, **[click here](https://console.cloud.google.com/iam-admin/quotas?location=us-central1&metric=aiplatform.googleapis.com%2Frestricted_image_training_nvidia_a100_80gb_gpus)** to check if your project already has the required 8 Nvidia A100 80 GB GPUs in the us-central1 region. If yes, then run this notebook in the us-central1 region. If you do not have 8 Nvidia A100 80 GB GPUs or have more GPU requirements than this, then schedule your job with Nvidia H100 GPUs via Dynamic Workload Scheduler using [these instructions](https://cloud.google.com/vertex-ai/docs/training/schedule-jobs-dws). For Dynamic Workload Scheduler, check the [us-central1](https://console.cloud.google.com/iam-admin/quotas?location=us-central1&metric=aiplatform.googleapis.com%2Fcustom_model_training_preemptible_nvidia_h100_gpus) or [europe-west4](https://console.cloud.google.com/iam-admin/quotas?location=europe-west4&metric=aiplatform.googleapis.com%2Fcustom_model_training_preemptible_nvidia_h100_gpus) quota for Nvidia H100 GPUs. If you do not have enough GPUs, then you can follow [these instructions](https://cloud.google.com/docs/quotas/view-manage#viewing_your_quota_console) to request quota.

# @markdown 3. For serving, **[click here](https://console.cloud.google.com/iam-admin/quotas?location=us-central1&metric=aiplatform.googleapis.com%2Fcustom_model_serving_nvidia_l4_gpus)** to check if your project already has the required 1 L4 GPU in the us-central1 region.  If yes, then run this notebook in the us-central1 region. If you need more L4 GPUs for your project, then you can follow [these instructions](https://cloud.google.com/docs/quotas/view-manage#viewing_your_quota_console) to request more. Alternatively, if you want to run predictions with A100 80GB or H100 GPUs, we recommend using the regions listed below. **NOTE:** Make sure you have associated quota in selected regions. Click the links to see your current quota for each GPU type: [Nvidia A100 80GB](https://console.cloud.google.com/iam-admin/quotas?metric=aiplatform.googleapis.com%2Fcustom_model_serving_nvidia_a100_80gb_gpus), [Nvidia H100 80GB](https://console.cloud.google.com/iam-admin/quotas?metric=aiplatform.googleapis.com%2Fcustom_model_serving_nvidia_h100_gpus).

# @markdown > | Machine Type | Accelerator Type | Recommended Regions |
# @markdown | ----------- | ----------- | ----------- |
# @markdown | a2-ultragpu-1g | 1 NVIDIA_A100_80GB | us-central1, us-east4, europe-west4, asia-southeast1 |
# @markdown | a3-highgpu-2g | 2 NVIDIA_H100_80GB | us-west1, asia-southeast1, europe-west4 |
# @markdown | a3-highgpu-4g | 4 NVIDIA_H100_80GB | us-west1, asia-southeast1, europe-west4 |
# @markdown | a3-highgpu-8g | 8 NVIDIA_H100_80GB | us-central1, europe-west4, us-west1, asia-southeast1 |

# @markdown 4. **[Optional]** [Create a Cloud Storage bucket](https://cloud.google.com/storage/docs/creating-buckets) for storing experiment outputs. Set the BUCKET_URI for the experiment environment. The specified Cloud Storage bucket (`BUCKET_URI`) should be located in the same region as where the notebook was launched. Note that a multi-region bucket (eg. "us") is not considered a match for a single region covered by the multi-region range (eg. "us-central1"). If not set, a unique GCS bucket will be created instead.

BUCKET_URI = "gs://"  # @param {type:"string"}

# @markdown 5. **[Optional]** Set region. If not set, the region will be set automatically according to Colab Enterprise environment.

REGION = ""  # @param {type:"string"}

# Install timm before the import block below, which imports it.
! pip3 install timm

import base64
import datetime
import importlib
import os
import urllib
import urllib.request
import uuid

import timm
import torch
from google.cloud import aiplatform
from PIL import Image
from timm.data import resolve_data_config
from timm.data.transforms_factory import create_transform

! git clone https://github.com/GoogleCloudPlatform/vertex-ai-samples.git

common_util = importlib.import_module(
    "vertex-ai-samples.community-content.vertex_model_garden.model_oss.notebook_util.common_util"
)

models, endpoints = {}, {}

# Get the default cloud project id.
PROJECT_ID = os.environ["GOOGLE_CLOUD_PROJECT"]

# Get the default region for launching jobs.
if not REGION:
    if not os.environ.get("GOOGLE_CLOUD_REGION"):
        raise ValueError(
            "REGION must be set. See"
            " https://cloud.google.com/vertex-ai/docs/general/locations for"
            " available cloud locations."
        )
    REGION = os.environ["GOOGLE_CLOUD_REGION"]

# Enable the Vertex AI API and Compute Engine API, if not already.
print("Enabling Vertex AI API and Compute Engine API.")
! gcloud services enable aiplatform.googleapis.com compute.googleapis.com

# Cloud Storage bucket for storing the experiment artifacts.
# A unique GCS bucket will be created for the purpose of this notebook. If you
# prefer using your own GCS bucket, change the value yourself below.
now = datetime.datetime.now().strftime("%Y%m%d%H%M%S")
BUCKET_NAME = "/".join(BUCKET_URI.split("/")[:3])

if BUCKET_URI is None or BUCKET_URI.strip() == "" or BUCKET_URI == "gs://":
    BUCKET_URI = f"gs://{PROJECT_ID}-tmp-{now}-{str(uuid.uuid4())[:4]}"
    BUCKET_NAME = "/".join(BUCKET_URI.split("/")[:3])
    ! gsutil mb -l {REGION} {BUCKET_URI}
else:
    assert BUCKET_URI.startswith("gs://"), "BUCKET_URI must start with `gs://`."
    shell_output = ! gsutil ls -Lb {BUCKET_NAME} | grep "Location constraint:" | sed "s/Location constraint://"
    bucket_region = shell_output[0].strip().lower()
    if bucket_region != REGION:
        raise ValueError(
            "Bucket region %s is different from notebook region %s"
            % (bucket_region, REGION)
        )
print(f"Using this GCS Bucket: {BUCKET_URI}")

STAGING_BUCKET = os.path.join(BUCKET_URI, "temporal")
MODEL_BUCKET = os.path.join(BUCKET_URI, "timm")


# Initialize Vertex AI API.
print("Initializing Vertex AI API.")
aiplatform.init(project=PROJECT_ID, location=REGION, staging_bucket=STAGING_BUCKET)

# Gets the default SERVICE_ACCOUNT.
shell_output = ! gcloud projects describe $PROJECT_ID
project_number = shell_output[-1].split(":")[1].strip().replace("'", "")
SERVICE_ACCOUNT = f"{project_number}-compute@developer.gserviceaccount.com"
print("Using this default Service Account:", SERVICE_ACCOUNT)


# Provision permissions to the SERVICE_ACCOUNT with the GCS bucket
! gsutil iam ch serviceAccount:{SERVICE_ACCOUNT}:roles/storage.admin $BUCKET_NAME

! gcloud config set project $PROJECT_ID
! gcloud projects add-iam-policy-binding --no-user-output-enabled {PROJECT_ID} --member=serviceAccount:{SERVICE_ACCOUNT} --role="roles/storage.admin"
! gcloud projects add-iam-policy-binding --no-user-output-enabled {PROJECT_ID} --member=serviceAccount:{SERVICE_ACCOUNT} --role="roles/aiplatform.user"

## Run local inference

This section runs local inference on an image using a selected pre-trained model.

In [None]:
# @title Load a pretrained model

# @markdown `MODEL_NAME`: The model you want to train and serve.

# A ViT model is the default example.
MODEL_NAME = "vit_tiny_patch16_224"  # @param ["vit_tiny_patch16_224", "beit_base_patch16_224", "deit3_small_patch16_224", "efficientnet_b2", "mobilenetv2_100", "resnet50", "resnest50d", "convnext_base", "cspdarknet53", "inception_v4"]

# Build the selected architecture with its pretrained weights, then switch it
# to evaluation mode. The bare call on the last line also displays the model.
model = timm.create_model(model_name=MODEL_NAME, pretrained=True)
model.eval()

In [None]:
# @title Preprocess the image

config = resolve_data_config({}, model=model)
transform = create_transform(**config)

# @markdown This example downloads a test image from GitHub Pytorch sample images folder.

SOURCE = "https://github.com/pytorch/hub/raw/master/images/dog.jpg"  # @param {type:"string"}

# @markdown You can upload and use your own images by changing `IMAGE_FILENAME`.

! wget $SOURCE -O test.jpg
IMAGE_FILENAME = "test.jpg"  # @param {type:"string"}
# @markdown  You can also copy over images stored in a GCS bucket by running the command - `! gsutil cp "gs://path/to/image" "test.jpg"`

img = Image.open(IMAGE_FILENAME).convert("RGB")
tensor = transform(img).unsqueeze(0)  # transform and add batch dimension
display(img)

In [None]:
# @title Get the model predictions

# @markdown This section gives the probability of the top 5 predictions.

# Inference only, so skip gradient tracking.
with torch.no_grad():
    out = model(tensor)
# `tensor` was built with a batch dimension of one image, so take the first
# (only) row of the output and normalize it into probabilities.
probabilities = torch.nn.functional.softmax(out[0], dim=0)
print(probabilities.shape)

# Download the class names used to label the predictions.
url, filename = (
    "https://raw.githubusercontent.com/pytorch/hub/master/imagenet_classes.txt",
    "imagenet_classes.txt",
)
urllib.request.urlretrieve(url, filename)
# Read back exactly the file that was just downloaded.
with open(filename) as f:
    categories = [line.strip() for line in f]

# Print top categories per image.
top5_prob, top5_catid = torch.topk(probabilities, 5)
for prob, catid in zip(top5_prob, top5_catid):
    print(categories[catid], prob.item())

## Run training job

In [None]:
# @title Training job

# @markdown This section runs a regular training job on Vertex AI.

# @markdown Before creating a training job, you need to prepare the dataset for training and evaluation.
# @markdown For example, you can use [ImageNet-1K](https://huggingface.co/datasets/imagenet-1k).

# @markdown If you want to create a hyperparameter tuning job instead, you can skip to the next section.

# The prebuilt training docker uri.
TRAIN_DOCKER_URI = (
    "us-docker.pkg.dev/vertex-ai/vertex-vision-model-garden-dockers/pytorch-timm-train"
)

# The path to data directory on Cloud Storage without gs:// prefix.
# In the form of: <bucket-name>/path-to-data

# @markdown - `GCS_DATA_DIR` - The GCS path of the directory which contains the training data and the evaluation data.
GCS_DATA_DIR = ""  # @param {type:"string"}

# Input and output path.
data_dir = f"/gcs/{GCS_DATA_DIR}"
# NOTE(review): this is a `gs://` URI, while the hyperparameter tuning cell
# passes `common_util.gcs_fuse_path(MODEL_BUCKET)` as the output location.
# Confirm that the training container accepts a `gs://` output path.
output_dir = os.path.join(MODEL_BUCKET, "output")

# Worker pool spec.
# Single node with multiple GPUs.
TRAINING_MACHINE_TYPE = "n1-highmem-32"
NUM_NODES = 1
TRAINING_ACCELERATOR_TYPE = "NVIDIA_TESLA_P100"  # @param {type:"string"}
TRAINING_ACCELERATOR_COUNT = 4

# Model specific config.
job_name = f"pytorch-{MODEL_NAME}"
batch_size = 32
epochs = 2

train_job = aiplatform.CustomContainerTrainingJob(
    display_name=job_name,
    container_uri=TRAIN_DOCKER_URI,
)

# The first three arguments look like distributed launcher options, followed by
# the training script and its own arguments. Presumably the container's
# entrypoint is such a launcher; confirm against the image.
train_job.run(
    args=[
        "--standalone",
        f"--nnodes={NUM_NODES}",
        f"--nproc_per_node={TRAINING_ACCELERATOR_COUNT}",
        "train.py",
        data_dir,
        f"--model={MODEL_NAME}",
        "--pretrained",
        f"--output={output_dir}",
        f"--batch-size={batch_size}",
        f"--epochs={epochs}",
    ],
    # Use this cell's own constant; the lowercase `num_nodes` is only defined
    # by the hyperparameter tuning cell further down.
    replica_count=NUM_NODES,
    machine_type=TRAINING_MACHINE_TYPE,
    accelerator_type=TRAINING_ACCELERATOR_TYPE,
    accelerator_count=TRAINING_ACCELERATOR_COUNT,
)

## Run hyperparameter tuning job

In [None]:
# @title Hyperparameter tuning job

# @markdown You can use a [hyperparameter tuning](https://cloud.google.com/vertex-ai/docs/training/hyperparameter-tuning-overview) job to find the best configuration of your hyperparameters.

# @markdown You can skip this section if you already trained a model in the previous section and do not want to tune the hyperparameters.

from google.cloud import aiplatform
from google.cloud.aiplatform import hyperparameter_tuning as hpt

# Where the trials read their data from and write their results to.
data_dir = f"/gcs/{GCS_DATA_DIR}"
output_dir = common_util.gcs_fuse_path(MODEL_BUCKET)

# Settings tied to the selected model.
job_name = f"pytorch-hp-{MODEL_NAME}"
batch_size = 32
epochs = 2

# Hardware used by every trial.
HPT_MACHINE_TYPE = "n1-highmem-16"
num_nodes = 1
HPT_ACCELERATOR_TYPE = "NVIDIA_TESLA_V100"  # @param {type:"string"}
HPT_ACCELERATOR_COUNT = 2

# Arguments handed to the training container.
trainer_args = [
    "--standalone",
    f"--nnodes={num_nodes}",
    f"--nproc_per_node={HPT_ACCELERATOR_COUNT}",
    "train.py",
    data_dir,
    f"--model={MODEL_NAME}",
    "--pretrained",
    f"--output={output_dir}",
    f"--batch-size={batch_size}",
    f"--epochs={epochs}",
]
machine_spec = {
    "machine_type": HPT_MACHINE_TYPE,
    "accelerator_type": HPT_ACCELERATOR_TYPE,
    "accelerator_count": HPT_ACCELERATOR_COUNT,
}
container_spec = {
    "image_uri": TRAIN_DOCKER_URI,
    "args": trainer_args,
}

# A single worker pool assembled from the pieces above.
worker_pool_specs = [
    {
        "machine_spec": machine_spec,
        "replica_count": num_nodes,
        "container_spec": container_spec,
    }
]

# What to optimize, over which search space, and how many trials to run.
metric_spec = {"top1_accuracy": "maximize"}
parameter_spec = {
    "lr": hpt.DoubleParameterSpec(min=0.001, max=0.05, scale="log"),
}
max_trial_count = 2
parallel_trial_count = 2

# Check the accelerator quota before launching anything.
common_util.check_quota(
    project_id=PROJECT_ID,
    region=REGION,
    accelerator_type=HPT_ACCELERATOR_TYPE,
    accelerator_count=HPT_ACCELERATOR_COUNT,
    is_for_training=True,
)


# Labels attached to the finetuning job.
tune_label = "publishers-timm-models-vit"
versioned_model_id = MODEL_NAME.lower().replace("_", "-")
labels = {
    "mg-source": "notebook",
    "mg-notebook-name": "model_garden_pytorch_timm.ipynb".split(".")[0],
    "mg-tune": tune_label,
    "versioned-mg-tune": f"{tune_label}-{versioned_model_id}",
}

# Wrap the custom job in a hyperparameter tuning job and launch it.
training_job = aiplatform.CustomJob(
    display_name=job_name,
    worker_pool_specs=worker_pool_specs,
    labels=labels,
)
hp_job = aiplatform.HyperparameterTuningJob(
    display_name=job_name,
    custom_job=training_job,
    metric_spec=metric_spec,
    parameter_spec=parameter_spec,
    max_trial_count=max_trial_count,
    parallel_trial_count=parallel_trial_count,
)

hp_job.run()

## Deploy model

In [None]:
# @title Deploy

# @markdown This section uploads the model to Model Registry and deploys it on an Endpoint resource.
# @markdown This step will take ~15 minutes to complete.

# @markdown The uploaded models and the endpoints can be managed in the [Model Registry](https://console.cloud.google.com/vertex-ai/models) and the [Endpoints](https://console.cloud.google.com/vertex-ai/endpoints) respectively.

# The prebuilt serving docker uri.
SERVE_DOCKER_URI = "us-docker.pkg.dev/vertex-ai-restricted/vertex-vision-model-garden-dockers/pytorch-timm-serve"
# The port number used by torchserve traffic.
SERVE_PORT = 7080


SERVING_MACHINE_TYPE = "n1-standard-8"
SERVING_ACCELERATOR_TYPE = "NVIDIA_TESLA_T4"  # @param {type:"string"}


# @markdown - `MODEL_PT_PATH : ` The Cloud Storage path which contains the model checkpoint file(.pth extension).
# @markdown - for e.g. - `gs://path_to_model_best.pth.tar`.
MODEL_PT_PATH = ""  # @param {type:"string"}

# @markdown - `INDEX_TO_NAME_FILE : ` [Optional] The Cloud Storage path to index_to_name.json, including gs:// prefix.
# @markdown - for e.g. - `gs://path_to_index_to_name.json`
INDEX_TO_NAME_FILE = ""  # @param {type:"string"}

# Environment variables passed to the serving container.
serving_env = {
    "MODEL_ID": "timm-mobilenetv2-100",
    "MODEL_NAME": MODEL_NAME,
    "MODEL_PT_PATH": MODEL_PT_PATH,
    "DEPLOY_SOURCE": "notebook",
}
# The index-to-name mapping is optional; only pass it on when it is set.
if INDEX_TO_NAME_FILE:
    serving_env["INDEX_TO_NAME_FILE"] = INDEX_TO_NAME_FILE

# Maps each selectable timm model name to the Model Garden publisher model id
# the upload is attributed to.
PUBLISHER_MODEL_IDS = {
    "vit_tiny_patch16_224": "vit-base-patch16-224",
    "beit_base_patch16_224": "beit-base-patch16-224",
    "deit3_small_patch16_224": "deit-base-patch16-224",
    "efficientnet_b2": "efficientnetv2-rw-s",
    "mobilenetv2_100": "mobilenetv2-100",
    "resnet50": "resnet-50",
    "resnest50d": "resnest50d",
    "convnext_base": "convnext-base",
    "cspdarknet53": "cspdarknet53",
    "inception_v4": "inception-v4",
}
if MODEL_NAME not in PUBLISHER_MODEL_IDS:
    raise ValueError(f"Unknown model: {MODEL_NAME}")
publisher_model_id = PUBLISHER_MODEL_IDS[MODEL_NAME]

# Upload model.
models["timm-model"] = aiplatform.Model.upload(
    display_name=MODEL_NAME,
    serving_container_image_uri=SERVE_DOCKER_URI,
    serving_container_ports=[SERVE_PORT],
    serving_container_predict_route="/predictions/timm_serving",
    serving_container_health_route="/ping",
    serving_container_environment_variables=serving_env,
    # Must be a plain string; a trailing comma here would turn it into a tuple.
    model_garden_source_model_name=f"publishers/timm/models/{publisher_model_id}",
)
# Or reuse a pre-uploaded model.
# models["timm-model"] = aiplatform.Model('projects/123456789/locations/us-central1/models/123456789@1')

# Create an endpoint.
endpoints["timm-endpoint"] = aiplatform.Endpoint.create(display_name="pytorch-timm-endpoint")
# Or reuse a pre-created endpoint.
# endpoints["timm-endpoint"] = aiplatform.Endpoint('projects/123456789/locations/us-central1/endpoints/123456789')

# Check quota.
common_util.check_quota(
    project_id=PROJECT_ID,
    region=REGION,
    accelerator_type=SERVING_ACCELERATOR_TYPE,
    accelerator_count=1,
    is_for_training=False,
)

# Deploy model to endpoint.
models["timm-model"].deploy(
    endpoint=endpoints["timm-endpoint"],
    machine_type=SERVING_MACHINE_TYPE,
    accelerator_type=SERVING_ACCELERATOR_TYPE,
    accelerator_count=1,
    traffic_percentage=100,
    service_account=SERVICE_ACCOUNT,
    system_labels={
        "NOTEBOOK_NAME": "model_garden_pytorch_timm.ipynb"
    },
)

In [None]:
# @title Predict

# You can get the deployed endpoint object by its resource name returned by Endpoint.create(). For example:
# endpoints["timm-endpoint"] = aiplatform.Endpoint('projects/816369962409/locations/us-central1/endpoints/8809168414485512192')

# @markdown Upload an image along with its filename below.
IMAGE_FILENAME = "test.jpg"  # @param {type:"string"}

# Alternatively, uncomment the following line to download a cat image for demonstration.
# ! wget http://images.cocodataset.org/val2017/000000039769.jpg -O test.jpg

# Read the raw image bytes and base64-encode them into text for the request.
with open(IMAGE_FILENAME, "rb") as image_file:
    raw_bytes = image_file.read()
image_b64 = base64.b64encode(raw_bytes).decode("utf-8")

# A single instance carrying the encoded image under `data.b64`.
instances = [{"data": {"b64": image_b64}}]

# Send the request to the deployed endpoint and show the response.
timm_endpoint = endpoints["timm-endpoint"]
prediction = timm_endpoint.predict(instances=instances)
print(prediction)

## Clean Up Resources

In [None]:
# @title Delete the job, model and endpoint

# Delete the training job and hyperparameter tuning job.
train_job.delete()
hp_job.delete()

# @markdown  Delete the experiment models and endpoints to recycle the resources
# @markdown  and avoid unnecessary continuous charges that may incur.

# Undeploy model and delete endpoint.
for endpoint in endpoints.values():
    endpoint.delete(force=True)

# Delete models.
for model in models.values():
    model.delete()

delete_bucket = False  # @param {type:"boolean"}
if delete_bucket:
    ! gsutil -m rm -r $BUCKET_NAME