In [None]:
# Copyright 2025 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# Vertex AI Model Garden MediaPipe with text classification

<table><tbody><tr>
  <td style="text-align: center">
    <a href="https://console.cloud.google.com/vertex-ai/colab/import/https:%2F%2Fraw.githubusercontent.com%2FGoogleCloudPlatform%2Fvertex-ai-samples%2Fmain%2Fnotebooks%2Fcommunity%2Fmodel_garden%2Fmodel_garden_mediapipe_text_classification.ipynb">
      <img alt="Google Cloud Colab Enterprise logo" src="https://lh3.googleusercontent.com/JmcxdQi-qOpctIvWKgPtrzZdJJK-J3sWE1RsfjZNwshCFgE_9fULcNpuXYTilIR2hjwN" width="32px"><br> Run in Colab Enterprise
    </a>
  </td>
  <td style="text-align: center">
    <a href="https://github.com/GoogleCloudPlatform/vertex-ai-samples/blob/main/notebooks/community/model_garden/model_garden_mediapipe_text_classification.ipynb">
      <img alt="GitHub logo" src="https://github.githubassets.com/assets/GitHub-Mark-ea2971cee799.png" width="32px"><br> View on GitHub
    </a>
  </td>
</tr></tbody></table>

## Overview

This notebook demonstrates how to use [MediaPipe Model Maker](https://developers.google.com/mediapipe/solutions/model_maker) to train an on-device text classification model in Vertex AI Model Garden.

### Objective

* Train new models
  * Convert input data to training formats
  * Create [custom jobs](https://cloud.google.com/vertex-ai/docs/training/create-custom-job) to train new models
  * Export models

* Clean up resources

### Costs

This tutorial uses billable components of Google Cloud:

* Vertex AI
* Cloud Storage

Learn about [Vertex AI pricing](https://cloud.google.com/vertex-ai/pricing), [Cloud Storage pricing](https://cloud.google.com/storage/pricing), and use the [Pricing Calculator](https://cloud.google.com/products/calculator/) to generate a cost estimate based on your projected usage.

## Before you begin

In [None]:
# @title Setup Google Cloud project

# @markdown 1. [Make sure that billing is enabled for your project](https://cloud.google.com/billing/docs/how-to/modify-project).

# @markdown 2. For finetuning, **[click here](https://console.cloud.google.com/iam-admin/quotas?location=us-central1&metric=aiplatform.googleapis.com%2Frestricted_image_training_nvidia_a100_80gb_gpus)** to check if your project already has the required 8 Nvidia A100 80 GB GPUs in the us-central1 region. If yes, then run this notebook in the us-central1 region. If you do not have 8 Nvidia A100 80 GPUs or have more GPU requirements than this, then schedule your job with Nvidia H100 GPUs via Dynamic Workload Scheduler using [these instructions](https://cloud.google.com/vertex-ai/docs/training/schedule-jobs-dws). For Dynamic Workload Scheduler, check the [us-central1](https://console.cloud.google.com/iam-admin/quotas?location=us-central1&metric=aiplatform.googleapis.com%2Fcustom_model_training_preemptible_nvidia_h100_gpus) or [europe-west4](https://console.cloud.google.com/iam-admin/quotas?location=europe-west4&metric=aiplatform.googleapis.com%2Fcustom_model_training_preemptible_nvidia_h100_gpus) quota for Nvidia H100 GPUs. If you do not have enough GPUs, then you can follow [these instructions](https://cloud.google.com/docs/quotas/view-manage#viewing_your_quota_console) to request quota.

# @markdown 3. For serving, **[click here](https://console.cloud.google.com/iam-admin/quotas?location=us-central1&metric=aiplatform.googleapis.com%2Fcustom_model_serving_nvidia_l4_gpus)** to check if your project already has the required 1 L4 GPU in the us-central1 region.  If yes, then run this notebook in the us-central1 region. If you need more L4 GPUs for your project, then you can follow [these instructions](https://cloud.google.com/docs/quotas/view-manage#viewing_your_quota_console) to request more. Alternatively, if you want to run predictions with A100 80GB or H100 GPUs, we recommend using the regions listed below. **NOTE:** Make sure you have associated quota in selected regions. Click the links to see your current quota for each GPU type: [Nvidia A100 80GB](https://console.cloud.google.com/iam-admin/quotas?metric=aiplatform.googleapis.com%2Fcustom_model_serving_nvidia_a100_80gb_gpus), [Nvidia H100 80GB](https://console.cloud.google.com/iam-admin/quotas?metric=aiplatform.googleapis.com%2Fcustom_model_serving_nvidia_h100_gpus).

# @markdown > | Machine Type | Accelerator Type | Recommended Regions |
# @markdown | ----------- | ----------- | ----------- |
# @markdown | a2-ultragpu-1g | 1 NVIDIA_A100_80GB | us-central1, us-east4, europe-west4, asia-southeast1, us-east4 |
# @markdown | a3-highgpu-2g | 2 NVIDIA_H100_80GB | us-west1, asia-southeast1, europe-west4 |
# @markdown | a3-highgpu-4g | 4 NVIDIA_H100_80GB | us-west1, asia-southeast1, europe-west4 |
# @markdown | a3-highgpu-8g | 8 NVIDIA_H100_80GB | us-central1, europe-west4, us-west1, asia-southeast1 |

# @markdown 4. **[Optional]** [Create a Cloud Storage bucket](https://cloud.google.com/storage/docs/creating-buckets) for storing experiment outputs. Set the BUCKET_URI for the experiment environment. The specified Cloud Storage bucket (`BUCKET_URI`) should be located in the same region as where the notebook was launched. Note that a multi-region bucket (eg. "us") is not considered a match for a single region covered by the multi-region range (eg. "us-central1"). If not set, a unique GCS bucket will be created instead.

BUCKET_URI = "gs://"  # @param {type:"string"}

# @markdown 5. **[Optional]** Set region. If not set, the region will be set automatically according to Colab Enterprise environment.

REGION = ""  # @param {type:"string"}

! git clone https://github.com/GoogleCloudPlatform/vertex-ai-samples.git

import datetime
import importlib
import json
import os
import uuid

from google.cloud import aiplatform

common_util = importlib.import_module(
    "vertex-ai-samples.community-content.vertex_model_garden.model_oss.notebook_util.common_util"
)


# Get the default cloud project id.
PROJECT_ID = os.environ["GOOGLE_CLOUD_PROJECT"]

# Get the default region for launching jobs.
if not REGION:
    if not os.environ.get("GOOGLE_CLOUD_REGION"):
        raise ValueError(
            "REGION must be set. See"
            " https://cloud.google.com/vertex-ai/docs/general/locations for"
            " available cloud locations."
        )
    REGION = os.environ["GOOGLE_CLOUD_REGION"]

# Enable the Vertex AI API and Compute Engine API, if not already.
print("Enabling Vertex AI API and Compute Engine API.")
! gcloud services enable aiplatform.googleapis.com compute.googleapis.com

# Cloud Storage bucket for storing the experiment artifacts.
# A unique GCS bucket will be created for the purpose of this notebook. If you
# prefer using your own GCS bucket, change the value yourself below.
now = datetime.datetime.now().strftime("%Y%m%d%H%M%S")
BUCKET_NAME = "/".join(BUCKET_URI.split("/")[:3])

if BUCKET_URI is None or BUCKET_URI.strip() == "" or BUCKET_URI == "gs://":
    BUCKET_URI = f"gs://{PROJECT_ID}-tmp-{now}-{str(uuid.uuid4())[:4]}"
    BUCKET_NAME = "/".join(BUCKET_URI.split("/")[:3])
    ! gsutil mb -l {REGION} {BUCKET_URI}
else:
    assert BUCKET_URI.startswith("gs://"), "BUCKET_URI must start with `gs://`."
    shell_output = ! gsutil ls -Lb {BUCKET_NAME} | grep "Location constraint:" | sed "s/Location constraint://"
    bucket_region = shell_output[0].strip().lower()
    if bucket_region != REGION:
        raise ValueError(
            "Bucket region %s is different from notebook region %s"
            % (bucket_region, REGION)
        )
print(f"Using this GCS Bucket: {BUCKET_URI}")

STAGING_BUCKET = os.path.join(BUCKET_URI, "temporal")
MODEL_BUCKET = os.path.join(BUCKET_URI, "mediapipe_text_classification")


# Initialize Vertex AI API.
print("Initializing Vertex AI API.")
aiplatform.init(project=PROJECT_ID, location=REGION, staging_bucket=STAGING_BUCKET)

# Gets the default SERVICE_ACCOUNT.
shell_output = ! gcloud projects describe $PROJECT_ID
project_number = shell_output[-1].split(":")[1].strip().replace("'", "")
SERVICE_ACCOUNT = f"{project_number}-compute@developer.gserviceaccount.com"
print("Using this default Service Account:", SERVICE_ACCOUNT)


# Provision permissions to the SERVICE_ACCOUNT with the GCS bucket
! gsutil iam ch serviceAccount:{SERVICE_ACCOUNT}:roles/storage.admin $BUCKET_NAME

! gcloud config set project $PROJECT_ID
! gcloud projects add-iam-policy-binding --no-user-output-enabled {PROJECT_ID} --member=serviceAccount:{SERVICE_ACCOUNT} --role="roles/storage.admin"
! gcloud projects add-iam-policy-binding --no-user-output-enabled {PROJECT_ID} --member=serviceAccount:{SERVICE_ACCOUNT} --role="roles/aiplatform.user"

REGION_PREFIX = REGION.split("-")[0]
assert REGION_PREFIX in (
    "us",
    "europe",
    "asia",
), f'{REGION} is not supported. It must be prefixed by "us", "asia", or "europe".'

## Train your customized models

In [None]:
# @title Set the dataset

# @markdown The following code block uses the [SST-2](https://nlp.stanford.edu/sentiment/index.html) (Stanford Sentiment Treebank) dataset which contains 67,349 movie reviews for training and 872 movie reviews for validation. The dataset has two classes: positive and negative movie reviews. Positive reviews are labeled with 1 and negative reviews with 0.

# @markdown The SST-2 dataset is stored as a TSV file. The only difference between the TSV and CSV formats is that TSV uses a tab `\t` character as its delimiter and CSV uses a comma `,`.

# GCS path of the training split; passed to the training container as
# `--training_data_path`.
training_data_path = "gs://mediapipe-tasks/text_classifier/SST-2/train.tsv"  # @param {type:"string"}

# GCS path of the validation split; passed to the training container as
# `--validation_data_path`.
validation_data_path = "gs://mediapipe-tasks/text_classifier/SST-2/dev.tsv"  # @param {type:"string"}

# The delimiter used in the dataset.
delimiter = "\t"  # @param {type:"string"}

# Character used to quote fields that contain special characters
# like the `delimiter`.
# NOTE(review): this is the same tab character as `delimiter`, presumably to
# effectively disable quote handling for SST-2. Confirm this is intended: recent
# Python `csv` releases may reject a dialect that uses one character for both
# roles, so verify against the training container's Python version.
quotechar = "\t"  # @param {type:"string"}

# Sequence of keys for the CSV columns (represented as a comma
# separated list). If empty, the first row of the CSV file is used
# as the keys
fieldnames = ""  # @param {type:"string"}

# Column name for the input text.
text_column = "sentence"  # @param {type:"string"}

# Column name for the labels.
label_column = "label"  # @param {type:"string"}

In [None]:
# @title Set fine-tuning options

# @markdown You can pick between different model architectures to further customize your training:
# @markdown *   Average Word Embedding Model
# @markdown *   BERT-classifier

# @markdown To set the model architecture and other training parameters, adjust the below values:

# `average_word_embedding` selects the Average Word Embedding Model and
# `mobilebert` selects the BERT-based classifier. The value is also used to
# build the `versioned-mg-tune` label of the training job.
model_architecture = "average_word_embedding"  # @param ["average_word_embedding", "mobilebert"]

# The learning rate to use for gradient descent-based
# optimizers. Defaults to 3e-5 for the BERT-based classifier
# and 0 for the average word-embedding classifier because
# it does not need such an optimizer.
# NOTE: the value below (0.0) is only appropriate for
# `average_word_embedding`; set a positive rate (e.g. 3e-5) when
# selecting `mobilebert`.
learning_rate: float = 0.0  # @param {type:"number"}

# Batch size for training. Defaults to 32 for the average
# word-embedding classifier and 48 for the BERT-based
# classifier. The value below (48) is the BERT-based default.
batch_size: int = 48  # @param {type:"number"}

# Number of training iterations over the dataset. Defaults
# to 10 for the average word-embedding classifier and 3
# for the BERT-based classifier. The value below (10) is the
# average word-embedding default.
epochs: int = 10  # @param {type:"slider", min:0, max:100, step:1}

# An integer that indicates the number of training steps per
# epoch. If set to 0, the training pipeline calculates the
# default steps per epoch as the training dataset size
# divided by batch size.
steps_per_epoch: int = 0  # @param {type:"number"}

# Controls whether the dataset is shuffled before training.
shuffle: bool = False  # @param {type:"boolean"}

# Length of the sequence to feed into the model.
seq_len: int = 256  # @param {type:"number"}

# Whether to convert all uppercase characters to lowercase
# during preprocessing.
do_lower_case: bool = True  # @param {type:"boolean"}

# The rate for dropout.
dropout_rate: float = 0.2  # @param {type:"number"}

# Dimension of the word embedding. Only used for the Average Word
# Embedding Model.
wordvec_dim: int = 16  # @param {type:"number"}

# Number of words to generate the vocabulary from data.
# Only used for the Average Word Embedding Model.
vocab_size: int = 10000  # @param {type:"number"}

In [None]:
# @title Run finetuning job

# @markdown With your training dataset and fine-tuning options prepared, you are ready to start the fine-tuning process. This process is resource intensive and can take a few minutes to a few hours depending on the model architecture and your available compute resources. On Vertex AI with GPU processing, the example fine-tuning below takes between 2-3 minutes to train an Average Word Embedding Model on the SST-2 dataset.

# @markdown To begin the fine-tuning process, use the following code:

# Output locations inside the staging bucket. The directories are handed to the
# training container; the export cell later copies EXPORTED_MODEL_OUTPUT_FILE.
EVALUATION_RESULT_OUTPUT_DIRECTORY = os.path.join(STAGING_BUCKET, "evaluation")
EVALUATION_RESULT_OUTPUT_FILE = os.path.join(
    EVALUATION_RESULT_OUTPUT_DIRECTORY, "evaluation.json"
)

EXPORTED_MODEL_OUTPUT_DIRECTORY = os.path.join(STAGING_BUCKET, "model")
EXPORTED_MODEL_OUTPUT_FILE = os.path.join(
    EXPORTED_MODEL_OUTPUT_DIRECTORY, "model.tflite"
)

model_export_path = EXPORTED_MODEL_OUTPUT_DIRECTORY
evaluation_result_path = EVALUATION_RESULT_OUTPUT_DIRECTORY

# How the training container should parse the CSV/TSV dataset files.
preprocessing_params = {
    "text_column": text_column,
    "label_column": label_column,
    "delimiter": delimiter,
    "quotechar": quotechar,
}
# Send explicit column names only when the user supplied them; otherwise the
# first row of the file serves as the header.
if fieldnames:
    preprocessing_params["fieldnames"] = [
        fieldname.strip() for fieldname in fieldnames.split(",")
    ]

# NOTE(review): the options cell ships learning_rate=0.0, which only suits
# `average_word_embedding`; set a positive rate when training `mobilebert`.
hparams = {
    "learning_rate": learning_rate,
    "batch_size": batch_size,
    "epochs": epochs,
    "shuffle": shuffle,
}
# 0 means "let the training pipeline derive steps per epoch", so the key is only
# sent when a non-zero value was chosen.
if steps_per_epoch:
    hparams["steps_per_epoch"] = steps_per_epoch

# Model options forwarded to the trainer. `seq_len` is included so the value set
# in the options cell is passed along (it was previously dropped), and each key
# appears exactly once.
model_options = {
    "seq_len": seq_len,
    "do_lower_case": do_lower_case,
    "dropout_rate": dropout_rate,
    "wordvec_dim": wordvec_dim,
    "vocab_size": vocab_size,
}

TRAINING_JOB_DISPLAY_NAME = "mediapipe_text_classifier_%s" % now
TRAINING_CONTAINER = f"{REGION_PREFIX}-docker.pkg.dev/vertex-ai/vertex-vision-model-garden-dockers/mediapipe-train"
TRAINING_MACHINE_TYPE = "n1-highmem-16"
TRAINING_ACCELERATOR_TYPE = "NVIDIA_TESLA_V100"
TRAINING_ACCELERATOR_COUNT = 2


worker_pool_specs = [
    {
        "machine_spec": {
            "machine_type": TRAINING_MACHINE_TYPE,
            "accelerator_type": TRAINING_ACCELERATOR_TYPE,
            "accelerator_count": TRAINING_ACCELERATOR_COUNT,
        },
        "replica_count": 1,
        "container_spec": {
            "image_uri": TRAINING_CONTAINER,
            "command": [],
            "args": [
                "--task_name=text_classifier",
                "--training_data_path=%s" % training_data_path,
                "--validation_data_path=%s" % validation_data_path,
                "--evaluation_result_path=%s" % evaluation_result_path,
                "--model_export_path=%s" % model_export_path,
                "--model_architecture=%s" % model_architecture,
                "--preprocessing_params=%s" % json.dumps(preprocessing_params),
                "--hparams=%s" % json.dumps(hparams),
                "--model_options=%s" % json.dumps(model_options),
            ],
        },
    }
]

# Check quota for exactly the accelerators requested by the single replica in
# worker_pool_specs, so the check cannot drift from the job definition.
common_util.check_quota(
    project_id=PROJECT_ID,
    region=REGION,
    accelerator_type=TRAINING_ACCELERATOR_TYPE,
    accelerator_count=TRAINING_ACCELERATOR_COUNT,
    is_for_training=True,
)


# Add labels for the finetuning job.
labels = {
    "mg-source": "notebook",
    "mg-notebook-name": "model_garden_mediapipe_text_classification.ipynb".split(".")[
        0
    ],
}

labels["mg-tune"] = "publishers-google-models-mediapipe"
versioned_model_id = model_architecture.lower().replace("_", "-")
labels["versioned-mg-tune"] = f"{labels['mg-tune']}-{versioned_model_id}"

training_job = aiplatform.CustomJob(
    display_name=TRAINING_JOB_DISPLAY_NAME,
    project=PROJECT_ID,
    worker_pool_specs=worker_pool_specs,
    staging_bucket=STAGING_BUCKET,
    labels=labels,
)

training_job.run()

In [None]:
# @title Export model

# @markdown After finetuning, you can save the Tensorflow Lite model, try it out in the [Text Classification](https://mediapipe-studio.webapps.google.com/demo/text_classifier) demo in MediaPipe Studio or integrate it with your on-device application by following the [Text classification task guide](https://developers.google.com/mediapipe/solutions/text/text_classifier). The exported model contains the generates required model metadata, as well as a classification label file.

! gsutil cp $EXPORTED_MODEL_OUTPUT_FILE text_classification_model.tflite

## Clean up

In [None]:
# @title Clean up training jobs and buckets
# @markdown Delete temporary GCS buckets.

delete_bucket = False  # @param {type:"boolean"}
if delete_bucket:
    ! gsutil -m rm -r $BUCKET_NAME

# Delete training data and jobs.
if training_job.list(filter=f'display_name="{TRAINING_JOB_DISPLAY_NAME}"'):
    training_job.delete()