[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/AI-Hypercomputer/maxtext/blob/main/src/MaxText/examples/sft_qwen3_demo.ipynb)

# Qwen3-0.6B Supervised Fine-Tuning (SFT) Demo


## Overview

This notebook runs a supervised fine-tuning (SFT) and evaluation workflow on [OpenAI's GSM8K dataset](https://huggingface.co/datasets/openai/gsm8k).
The goal is to demonstrate the end-to-end process of:
1. Pre-SFT Evaluation: calculating the model's baseline accuracy before training.
2. SFT Training: fine-tuning the model with MaxText and the Tunix SFT trainer.
3. Post-SFT Evaluation: rerunning the evaluation loop after training to measure the performance gain from SFT.

This notebook can run on the **public TPU v5e-1**.

## Prerequisites

### Change Runtime Type

**Instructions:**
1.  Navigate to the menu at the top of the screen.
2.  Click on **Runtime**.
3.  Select **Change runtime type** from the dropdown menu.
4.  Select **v5e-1 TPU** as the **Hardware accelerator**.
5. Click on **Save**.


### Get Your Hugging Face Token

To download the model checkpoint from the Hugging Face Hub, you need to authenticate with a personal access token.

**Follow these steps to get your token:**

1.  **Navigate to the Access Tokens page** in your Hugging Face account settings. You can go there directly by visiting this URL:
    *   [https://huggingface.co/settings/tokens](https://huggingface.co/settings/tokens)

2.  **Create a new token** by clicking the **"+ Create new token"** button.

3.  **Give your token a name** and assign it a **`read` role**. The `read` role is sufficient for downloading models.

4.  **Copy the generated token**. You will need to paste it in the next step.

**Follow these steps to store your token:**

1. On the left sidebar of your Colab window, click the key icon (the Secrets tab).

2. Click **"+ Add new secret"**.

3. Set the Name as **HF_TOKEN**.

4. Paste your token into the Value field.

5. Ensure the Notebook access toggle is turned On.

## Installation: MaxText & Other Dependencies

In [None]:
!git clone https://github.com/AI-Hypercomputer/maxtext.git
%cd /content/maxtext

# Install uv, a fast Python package installer
!pip install uv

# Install MaxText and its dependencies
!uv pip install -e .[tpu] --resolution=lowest

# Install vLLM
!VLLM_TARGET_DEVICE="tpu" pip install --no-cache-dir --pre \
    --index-url https://us-python.pkg.dev/cloud-tpu-images/maxtext-rl/simple/ \
    --extra-index-url https://pypi.org/simple/ \
    --extra-index-url https://us-python.pkg.dev/ml-oss-artifacts-published/jax/simple/ \
    --extra-index-url https://download.pytorch.org/whl/nightly/cpu \
    --find-links https://storage.googleapis.com/jax-releases/libtpu_releases.html \
    --find-links https://storage.googleapis.com/libtpu-wheels/index.html \
    --find-links https://storage.googleapis.com/libtpu-releases/index.html \
    --find-links https://storage.googleapis.com/jax-releases/jax_nightly_releases.html \
    --find-links https://storage.googleapis.com/jax-releases/jaxlib_nightly_releases.html \
    vllm==0.11.1rc1.dev292+g1b86bd8e1.tpu
!pip install --no-cache-dir --pre \
    --index-url https://us-python.pkg.dev/cloud-tpu-images/maxtext-rl/simple/ \
    --extra-index-url https://pypi.org/simple/ \
    --extra-index-url https://us-python.pkg.dev/ml-oss-artifacts-published/jax/simple/ \
    --find-links https://storage.googleapis.com/jax-releases/libtpu_releases.html \
    tpu-commons==0.1.2

### Restart Session
Restart the session so that the newly installed packages are loaded.

**Instructions:**
1.  Navigate to the menu at the top of the screen.
2.  Click on **Runtime**.
3.  Select **Restart session** from the dropdown menu.

You will be asked to confirm the action in a pop-up dialog. Click on **Yes**.

## Imports

In [None]:
import os
import transformers

import MaxText
from MaxText import pyconfig
from MaxText.examples.sft_train_and_evaluate import evaluate_model, get_test_dataset
from MaxText.integration.tunix.tunix_adapter import TunixMaxTextAdapter
from MaxText.sft import sft_trainer

from tunix.rl.rollout.vllm_rollout import VllmRollout

from datetime import datetime
from flax import nnx
from huggingface_hub import login
try:
    from google.colab import userdata
except ImportError:  # Not running in Colab; fall back to the HF_TOKEN environment variable
    pass

MAXTEXT_REPO_ROOT=os.path.dirname(MaxText.__file__)
print(f"MaxText installation path: {MAXTEXT_REPO_ROOT}")

In [None]:
try:
    HF_TOKEN = userdata.get("HF_TOKEN")
except Exception:
    HF_TOKEN=os.environ.get("HF_TOKEN")
if HF_TOKEN:
    login(token=HF_TOKEN)
    print("Authenticated with Hugging Face successfully!")
else:
    print("Authentication failed: HF_TOKEN is not set.")

## Model Configurations

In [None]:
MODEL_NAME = "qwen3-0.6b"
TOKENIZER_PATH = "Qwen/Qwen3-0.6B"
tokenizer = transformers.AutoTokenizer.from_pretrained(
  TOKENIZER_PATH,
  token=HF_TOKEN,
)

# This Colab downloads the checkpoint from Hugging Face and stores it at `MODEL_CHECKPOINT_PATH`
MODEL_CHECKPOINT_PATH = f"{MAXTEXT_REPO_ROOT}/qwen_checkpoint"

# Timestamp-based run name (%M is minutes; %m would be the month)
RUN_NAME = datetime.now().strftime("%Y-%m-%d-%H-%M-%S")

# Directory where the fine-tuned model checkpoint will be saved
BASE_OUTPUT_DIRECTORY = "/tmp/maxtext_qwen06"
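
In `strftime`, `%m` is the month and `%M` is the minute. A run-name pattern of `%H-%m-%S` puts the month where the minutes belong, so two runs started in the same hour at the same second-of-minute receive identical names. The snippet below uses arbitrary example timestamps to show the difference.

```python
from datetime import datetime

FIXED = "%Y-%m-%d-%H-%M-%S"  # %M = minute
BUGGY = "%Y-%m-%d-%H-%m-%S"  # %m = month, mistakenly placed in the minute slot

ts = datetime(2024, 5, 6, 7, 8, 9)  # arbitrary example timestamp
print(ts.strftime(FIXED))  # 2024-05-06-07-08-09
print(ts.strftime(BUGGY))  # 2024-05-06-07-05-09

# Two runs started 12 minutes apart, at the same second-of-minute.
a = datetime(2024, 5, 6, 10, 5, 30)
b = datetime(2024, 5, 6, 10, 17, 30)
print(a.strftime(BUGGY) == b.strftime(BUGGY))  # True: the run names collide
print(a.strftime(FIXED) == b.strftime(FIXED))  # False
```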

## Download the Qwen3-0.6B Checkpoint from Hugging Face and Convert It to MaxText Format

In [None]:
%%capture
!python3 -m MaxText.utils.ckpt_conversion.to_maxtext \
    $MAXTEXT_REPO_ROOT/configs/base.yml \
    model_name=$MODEL_NAME \
    base_output_directory=$MODEL_CHECKPOINT_PATH \
    hf_access_token=$HF_TOKEN \
    use_multimodal=false \
    scan_layers=true

In [None]:
print(f"Model checkpoint can be found at: {MODEL_CHECKPOINT_PATH}/0/items")
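
Because the conversion cell runs under `%%capture`, a failed download or conversion prints nothing. A minimal sanity check, assuming the converter writes parameters under `<base>/0/items` (the layout this notebook's paths use), is to confirm that directory exists and is non-empty before building the config. The helper names `checkpoint_items_path` and `checkpoint_ready` are hypothetical, not MaxText functions.

```python
import os
import tempfile


def checkpoint_items_path(base_dir: str, step: int = 0) -> str:
    """Expected parameters directory, assuming the `<base_dir>/<step>/items` layout."""
    return os.path.join(base_dir, str(step), "items")


def checkpoint_ready(base_dir: str, step: int = 0) -> bool:
    """True if the items directory exists and contains at least one entry."""
    path = checkpoint_items_path(base_dir, step)
    return os.path.isdir(path) and len(os.listdir(path)) > 0


# Demo on a throwaway directory (not the real checkpoint).
with tempfile.TemporaryDirectory() as tmp:
    print(checkpoint_ready(tmp))  # False: nothing converted yet
    os.makedirs(checkpoint_items_path(tmp))
    open(os.path.join(checkpoint_items_path(tmp), "placeholder"), "w").close()
    print(checkpoint_ready(tmp))  # True
```

In the notebook you would call `checkpoint_ready(MODEL_CHECKPOINT_PATH)` and, if it returns `False`, rerun the conversion cell without `%%capture` to see the error.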

## Dataset Configurations

In [None]:
DATASET_NAME = "openai/gsm8k"
TRAIN_DATA_SPLIT = "train"
TEST_DATA_SPLIT = "test"
HF_DATA_DIR = "main"
TRAIN_DATA_COLUMNS = ["question", "answer"]
CHAT_TEMPLATE_PATH = f"{MAXTEXT_REPO_ROOT}/examples/chat_templates/math_qa.json"
NUM_TEST_SAMPLES = 20 # Total number of samples to test
BATCH_SIZE = 1 # Number of test samples to process in a batch
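
`TRAIN_DATA_COLUMNS` names the prompt (`question`) and target (`answer`) columns. A common SFT setup computes the loss only on the answer tokens, so the model learns to produce solutions rather than to reproduce questions. The sketch below illustrates that idea only: it uses a toy whitespace tokenizer instead of the Qwen tokenizer, ignores the chat template in `math_qa.json`, and `build_sft_example` is a hypothetical helper, not part of MaxText. Whether MaxText masks prompt tokens depends on its SFT configuration.

```python
from typing import Dict, List, Tuple

PAD_ID = 0


def build_sft_example(
    question: str, answer: str, vocab: Dict[str, int], max_len: int
) -> Tuple[List[int], List[int]]:
    """Tokenize a (question, answer) pair and build a completion-only loss mask.

    Returns (token_ids, loss_mask), both of length max_len. loss_mask is 1 only
    on answer tokens, so prompt and padding tokens contribute no loss.
    """

    def encode(text: str) -> List[int]:
        # Toy tokenizer: assign a new id to each unseen whitespace-separated word.
        ids = []
        for word in text.split():
            if word not in vocab:
                vocab[word] = len(vocab) + 1  # id 0 is reserved for padding
            ids.append(vocab[word])
        return ids

    prompt_ids = encode(question)
    answer_ids = encode(answer)
    # Truncate to max_len, keeping tokens and mask aligned.
    tokens = (prompt_ids + answer_ids)[:max_len]
    mask = ([0] * len(prompt_ids) + [1] * len(answer_ids))[:max_len]
    # Right-pad both sequences; padding is masked out of the loss.
    pad = max_len - len(tokens)
    return tokens + [PAD_ID] * pad, mask + [0] * pad
```

Truncating the concatenated sequence (as `max_target_length=1024` does in the config) can cut off the end of the answer, which is where GSM8K places the final `#### <number>` line.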

## MaxText Configurations

In [None]:
%%capture
config = pyconfig.initialize([
    "",
    f"{MAXTEXT_REPO_ROOT}/configs/sft.yml",
    f"load_parameters_path={MODEL_CHECKPOINT_PATH}/0/items",
    f"model_name={MODEL_NAME}",
    f"hf_access_token={HF_TOKEN}",
    f"base_output_directory={BASE_OUTPUT_DIRECTORY}",
    f"run_name={RUN_NAME}",
    f"tokenizer_path={TOKENIZER_PATH}",
    f"hf_path={DATASET_NAME}",
    f"train_split={TRAIN_DATA_SPLIT}",
    f"hf_data_dir={HF_DATA_DIR}",
    f"train_data_columns={TRAIN_DATA_COLUMNS}",
    "steps=500",
    "per_device_batch_size=1",
    "max_target_length=1024",
    "learning_rate=3e-6",
    "weight_dtype=bfloat16",
    "dtype=bfloat16",
    f"chat_template_path={CHAT_TEMPLATE_PATH}",
])
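
The config above trains for `steps=500` with `per_device_batch_size=1`. Assuming the global batch is `per_device_batch_size` multiplied by the number of devices, a single-chip v5e-1 runtime sees 500 training sequences, a small fraction of GSM8K's 7,473 training problems. The hypothetical helper below spells out that arithmetic. It assumes one example per sequence; if sequence packing is enabled, one 1024-token sequence can hold several short examples and actual coverage would be higher.

```python
def training_budget(steps: int, per_device_batch_size: int, num_devices: int,
                    max_target_length: int, train_size: int) -> dict:
    """Rough upper bounds on what a run sees, assuming one example per sequence."""
    global_batch = per_device_batch_size * num_devices
    sequences = steps * global_batch
    return {
        "global_batch": global_batch,
        "sequences": sequences,
        "max_tokens": sequences * max_target_length,  # upper bound, ignores padding
        "epoch_fraction": sequences / train_size,
    }


run_budget = training_budget(steps=500, per_device_batch_size=1, num_devices=1,
                             max_target_length=1024, train_size=7473)
print(run_budget)
```

With these values the run covers roughly 6.7% of one epoch, so increasing `steps` is the most direct lever if you want the model to see more of the training set.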

## Initial Setup & Data Preparation

### Create Test Dataset

In [None]:
test_dataset = get_test_dataset(config, tokenizer)
test_dataset = test_dataset[:NUM_TEST_SAMPLES]
test_dataset = test_dataset.to_iter_dataset().batch(BATCH_SIZE, drop_remainder=True)
TOTAL_BATCHES = NUM_TEST_SAMPLES // BATCH_SIZE
print(f"Processing {NUM_TEST_SAMPLES} examples with a batch size of {BATCH_SIZE}. This will result in {TOTAL_BATCHES} total batches for the test run.")
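
With `drop_remainder=True`, a trailing partial batch is discarded, so only `(NUM_TEST_SAMPLES // BATCH_SIZE) * BATCH_SIZE` samples are evaluated. This has no effect with the defaults (20 samples, batch size 1) but matters if you change `BATCH_SIZE`. The pure-Python generator below mimics that behavior for illustration; it is not the dataset library's own implementation.

```python
from typing import Iterable, Iterator, List, TypeVar

T = TypeVar("T")


def batch_drop_remainder(items: Iterable[T], batch_size: int) -> Iterator[List[T]]:
    """Yield full batches of batch_size; discard a trailing partial batch."""
    batch: List[T] = []
    for item in items:
        batch.append(item)
        if len(batch) == batch_size:
            yield batch
            batch = []
    # Any leftover items in `batch` are intentionally dropped.


batches = list(batch_drop_remainder(range(20), 3))
print(len(batches), sum(len(b) for b in batches))  # 6 18
```

Here 20 samples with a batch size of 3 produce 6 batches, and the last 2 samples are never evaluated.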

### Create SFT Trainer State

In [None]:
trainer, mesh = sft_trainer.setup_trainer_state(config)

### Create vLLM Rollout

In [None]:
tunix_model = TunixMaxTextAdapter(trainer.model)
vllm_rollout = VllmRollout(
    model=tunix_model,
    tokenizer=tokenizer,
    cache_config_or_size=1280,
    mesh=mesh,
    model_version=TOKENIZER_PATH,
    hbm_utilization=0.8,
    init_with_random_weights=True,
    tpu_backend_type="jax",
)
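
To get a feel for the `hbm_utilization` and `cache_config_or_size` settings, the worked arithmetic below estimates KV-cache memory. It rests on several assumptions: that `hbm_utilization` is the fraction of accelerator memory vLLM may claim (analogous to vLLM's `gpu_memory_utilization`), that `1280` bounds the tokens per sequence, that a v5e chip has 16 GiB of HBM, and that Qwen3-0.6B has 28 layers, 8 KV heads, and a head dimension of 128 (verify these against the model's `config.json`).

```python
def kv_cache_bytes_per_token(num_layers: int, num_kv_heads: int, head_dim: int,
                             bytes_per_value: int = 2) -> int:
    """Bytes of K and V cache stored per token across all layers (bf16 = 2 bytes)."""
    return 2 * num_layers * num_kv_heads * head_dim * bytes_per_value  # 2 = K and V


# Assumed Qwen3-0.6B shape; check the model's config.json.
per_token = kv_cache_bytes_per_token(num_layers=28, num_kv_heads=8, head_dim=128)
per_sequence = per_token * 1280           # assumed per-sequence token budget
hbm_budget = int(0.8 * 16 * 1024**3)      # hbm_utilization x assumed 16 GiB HBM

print(per_token)                 # bytes per token
print(per_sequence / 2**20)      # MiB per full-length sequence
print(hbm_budget / 2**30)        # GiB available to vLLM
```

Under these assumptions each token costs 112 KiB of cache, a full 1280-token sequence about 140 MiB, and the memory budget is about 12.8 GiB, leaving ample room for a 0.6B-parameter model's bf16 weights.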

## Evaluation before SFT Training

In [None]:
print("Running Pre-SFT Evaluation...")
score = evaluate_model(test_dataset, vllm_rollout, debug=False)

In [None]:
print("========================= Score for PRE-SFT Evaluation =========================")
print(f"Percentage of test samples where the model produced the correct numerical answer: {score['correct']}%")
print(f"Percentage of test samples where the model's numerical answer was within 10% of the correct value: {score['partially_correct']}%")
print(f"Percentage of test samples where the model's output adheres to the expected structure: {score['correct_format']}%")
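
The first two metrics above can be sketched for GSM8K, whose reference answers end with a line of the form `#### <number>`. The functions below are a hypothetical illustration of exact-match and within-10% scoring, not the logic inside `evaluate_model`. In particular, the real evaluator may extract the model's answer using the output format defined by the chat template rather than taking the last number in the text, and the format-adherence metric is not modeled here.

```python
import re
from typing import Optional

# Integers or decimals, optionally negative, with thousands separators.
_NUMBER = re.compile(r"-?\d[\d,]*(?:\.\d+)?")


def extract_reference(answer: str) -> float:
    """GSM8K reference answers end with '#### <number>'."""
    return float(answer.split("####")[-1].strip().replace(",", ""))


def extract_prediction(text: str) -> Optional[float]:
    """Hypothetical rule: take the last number that appears in the model output."""
    matches = _NUMBER.findall(text)
    return float(matches[-1].replace(",", "")) if matches else None


def score_prediction(prediction: Optional[float], reference: float) -> dict:
    """Exact match, plus a relative-error check against a 10% tolerance."""
    if prediction is None:
        return {"correct": False, "partially_correct": False}
    correct = prediction == reference
    if reference == 0:
        partial = correct  # relative error is undefined for a zero reference
    else:
        partial = abs(prediction - reference) / abs(reference) <= 0.10
    return {"correct": correct, "partially_correct": partial}
```

Averaging these booleans over the test set and multiplying by 100 gives percentages comparable in spirit to `score['correct']` and `score['partially_correct']`.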

## SFT Training

In [None]:
print("Starting SFT Training...")
trainer = sft_trainer.train_model(config, trainer, mesh)
print("SFT Training Complete!")

## Evaluation after SFT Training

In [None]:
print("Running Post-SFT Evaluation...")
model = TunixMaxTextAdapter(trainer.model)
state = nnx.state(model)
vllm_rollout.update_params(state)
score = evaluate_model(test_dataset, vllm_rollout, debug=False)

In [None]:
print("========================= Score for POST-SFT Evaluation =========================")
print(f"Percentage of test samples where the model produced the correct numerical answer: {score['correct']}%")
print(f"Percentage of test samples where the model's numerical answer was within 10% of the correct value: {score['partially_correct']}%")
print(f"Percentage of test samples where the model's output adheres to the expected structure: {score['correct_format']}%")
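
Both evaluation cells store their result in `score`, so the pre-SFT numbers are overwritten once the post-SFT evaluation runs. One way to compare them is to copy the first result (for example, `pre_sft_score = dict(score)` right after the pre-SFT evaluation cell) and diff the two dictionaries by key, as in the hypothetical helpers below. Keep the sample size in mind: with `NUM_TEST_SAMPLES = 20`, each test sample moves a percentage metric by 5 points, so small differences may be noise.

```python
from typing import Dict

METRICS = ("correct", "partially_correct", "correct_format")


def score_gain(pre: Dict[str, float], post: Dict[str, float]) -> Dict[str, float]:
    """Percentage-point change per metric (post minus pre)."""
    return {name: post[name] - pre[name] for name in METRICS}


def print_gain(pre: Dict[str, float], post: Dict[str, float]) -> None:
    """Print each metric before and after SFT, with the signed change."""
    for name, delta in score_gain(pre, post).items():
        print(f"{name:>18}: {pre[name]:6.1f}% -> {post[name]:6.1f}% ({delta:+.1f} pts)")
```

In the notebook this would be called as `print_gain(pre_sft_score, score)` after the post-SFT evaluation cell.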