[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/AI-Hypercomputer/maxtext/blob/main/src/MaxText/examples/multimodal_gemma3_demo.ipynb)

# Gemma3 Multimodal Inference/Training Demo

## Overview

This notebook demonstrates MaxText's multimodal features, using Gemma3-4B as an example:
- Convert a Hugging Face checkpoint to an Orbax checkpoint.
- Run decoding on a single image input.
- Apply SFT to the converted checkpoint on the ChartQA dataset.

Given the relatively small size of Gemma3-4B, you can run this colab on a v4-8, v5p-8, or v6e-4 TPU VM. However, we recommend using [XPK](https://github.com/AI-Hypercomputer/maxtext/blob/64d6d9b425e78dde94c37a82bb13ba5606e74b1b/docs/guides/run_maxtext_via_xpk.md) to schedule a training workload on a TPU cluster for better performance.

### Get Your Hugging Face Token

To access the model checkpoint from the Hugging Face Hub, you need to authenticate with a personal access token.

**Follow these steps to get your token:**

1.  **Navigate to the Access Tokens page** in your Hugging Face account settings. You can go there directly by visiting this URL:
    *   [https://huggingface.co/settings/tokens](https://huggingface.co/settings/tokens)

2.  **Create a new token** by clicking the **"+ Create new token"** button.

3.  **Give your token a name** and assign it a **`read` role**. The `read` role is sufficient for downloading models.

4.  **Copy the generated token**. You will paste it into the `HF_TOKEN` variable in the setup cell below.

In [None]:
# Install MaxText and its dependencies
# 1. Install uv, a fast Python package installer
!pip install uv

# 2. Install MaxText and its dependencies
!uv pip install maxtext --resolution=lowest
!install_maxtext_github_deps

In [None]:
import os
import MaxText

# Get the root directory of the installed MaxText package
MAXTEXT_REPO_ROOT=os.path.dirname(MaxText.__file__)

# Define model name
MODEL_NAME="gemma3-4b"

# Use either a GCS path or a local path for the model checkpoint
MODEL_CHECKPOINT_PATH = f"gs://your-gcs-bucket/{MODEL_NAME}"

# Replace with your actual Hugging Face token
HF_TOKEN = "your_huggingface_token_here"
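
Both `MODEL_CHECKPOINT_PATH` and `HF_TOKEN` in the cell above are placeholders, and the later cells will not work until they are replaced. A minimal sketch of a pre-flight check follows. `find_placeholder_problems` is a hypothetical helper, not part of MaxText, and the placeholder strings it looks for are simply the ones used in this notebook.

```python
def find_placeholder_problems(hf_token: str, checkpoint_path: str) -> list[str]:
    """Return human-readable problems for values that still look like placeholders."""
    problems = []
    # The token is unusable if it is empty or still the notebook's default string.
    if not hf_token or hf_token == "your_huggingface_token_here":
        problems.append("HF_TOKEN is empty or still the placeholder value")
    # The default path in this notebook uses the bucket name "your-gcs-bucket".
    if "your-gcs-bucket" in checkpoint_path:
        problems.append("MODEL_CHECKPOINT_PATH still points at the placeholder bucket")
    return problems


# Example with the untouched defaults from the cell above.
issues = find_placeholder_problems(
    "your_huggingface_token_here", "gs://your-gcs-bucket/gemma3-4b"
)
for issue in issues:
    print(issue)
```

In the notebook itself you could call `find_placeholder_problems(HF_TOKEN, MODEL_CHECKPOINT_PATH)` and stop if the returned list is not empty.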

## Convert Checkpoint from Hugging Face

In [None]:
!python3 -m MaxText.utils.ckpt_conversion.to_maxtext \
    $MAXTEXT_REPO_ROOT/configs/base.yml \
    model_name=$MODEL_NAME \
    hf_access_token=$HF_TOKEN \
    base_output_directory=$MODEL_CHECKPOINT_PATH \
    use_multimodal=true \
    scan_layers=false
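
The later cells load parameters from `$MODEL_CHECKPOINT_PATH/0/items`, which suggests the converted checkpoint is written as step `0` under the output directory. Here is a small sketch of building that path in one place so the decode and SFT cells stay in sync. `converted_params_path` is a hypothetical helper, and the `<step>/items` layout is an assumption taken from those cells rather than from the converter's documentation.

```python
def converted_params_path(base_output_directory: str, step: int = 0) -> str:
    """Build the parameter path used by later cells: <base>/<step>/items."""
    # Strip trailing slashes so "gs://bucket/model/" does not produce "//".
    return f"{base_output_directory.rstrip('/')}/{step}/items"


print(converted_params_path("gs://your-gcs-bucket/gemma3-4b"))
```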

## Decode on One Image

In [None]:
!python -m MaxText.decode \
    $MAXTEXT_REPO_ROOT/configs/base.yml \
    model_name=$MODEL_NAME \
    tokenizer_path=assets/tokenizer.gemma3 \
    load_parameters_path=$MODEL_CHECKPOINT_PATH/0/items \
    per_device_batch_size=1 \
    run_name=ht_test \
    max_prefill_predict_length=272 \
    max_target_length=300 \
    steps=1 \
    async_checkpointing=false \
    scan_layers=false \
    use_multimodal=true \
    prompt='Describe image <start_of_image>' \
    image_path=$MAXTEXT_REPO_ROOT/test_assets/test_image.jpg \
    attention='dot_product'
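
In the command above, `max_prefill_predict_length` bounds the prompt (text plus image tokens) and `max_target_length` bounds the full sequence, so the room left for generated tokens is roughly their difference. The sketch below works through that arithmetic under this reading of the two flags. `generation_budget` is a hypothetical helper, not a MaxText function.

```python
def generation_budget(max_prefill_predict_length: int, max_target_length: int) -> int:
    """Tokens available for generation once the prefill window is reserved."""
    if max_target_length <= max_prefill_predict_length:
        raise ValueError("max_target_length must exceed max_prefill_predict_length")
    return max_target_length - max_prefill_predict_length


# Values from the decode command above: 300 - 272.
print(generation_budget(272, 300))  # 28
```

If the description comes back truncated, raising `max_target_length` while keeping the prefill length fixed is the first knob to try.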

## Supervised Finetuning (SFT)

Running the cell below triggers a 10-step SFT run on your TPU VM (v4-8, v5p-8, or v6e-4). However, we recommend using [XPK](https://github.com/AI-Hypercomputer/maxtext/blob/64d6d9b425e78dde94c37a82bb13ba5606e74b1b/docs/guides/run_maxtext_via_xpk.md) to schedule a training workload on a TPU cluster for better performance. After SFT completes, the resulting checkpoint is saved under `BASE_OUTPUT_DIRECTORY`.

In [None]:
# Define SFT output directory
BASE_OUTPUT_DIRECTORY=f"gs://your-gcs-bucket/{MODEL_NAME}-sft"
PRE_TRAINED_MODEL_TOKENIZER="google/gemma-3-4b-it"
WORKLOAD_NAME=f"{MODEL_NAME}-chartqa-sft"
STEPS=10
PER_DEVICE_BATCH_SIZE=1

!python -m MaxText.sft_trainer \
    $MAXTEXT_REPO_ROOT/configs/sft-vision-chartqa.yml \
    run_name=$WORKLOAD_NAME \
    model_name=$MODEL_NAME \
    tokenizer_path=$PRE_TRAINED_MODEL_TOKENIZER \
    hf_access_token=$HF_TOKEN \
    load_parameters_path=$MODEL_CHECKPOINT_PATH/0/items \
    base_output_directory=$BASE_OUTPUT_DIRECTORY \
    per_device_batch_size=$PER_DEVICE_BATCH_SIZE \
    steps=$STEPS \
    max_prefill_predict_length=1024 \
    max_target_length=2048 \
    checkpoint_period=1000 \
    scan_layers=False \
    async_checkpointing=True \
    enable_checkpointing=True \
    attention=dot_product \
    max_num_images_per_example=1 \
    dataset_type=hf \
    profiler=xplane
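
To get a feel for how much data this short run touches, multiply the per-device batch size by the device count and the number of steps. The sketch below assumes no gradient accumulation and takes the device count as an input. `examples_seen` is a hypothetical helper, and the device count of 4 is only an illustrative value.

```python
def examples_seen(steps: int, per_device_batch_size: int, num_devices: int) -> int:
    """Total training examples consumed, assuming no gradient accumulation."""
    global_batch_size = per_device_batch_size * num_devices
    return steps * global_batch_size


# STEPS and PER_DEVICE_BATCH_SIZE from the cell above; 4 devices is an assumption.
print(examples_seen(steps=10, per_device_batch_size=1, num_devices=4))  # 40
```

On the TPU VM, `num_devices` can be read from `jax.device_count()` instead of being hard-coded.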