[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/AI-Hypercomputer/maxtext/blob/main/src/MaxText/examples/sft_qwen3_demo.ipynb)

# Qwen3-0.6B Supervised Fine-Tuning (SFT) Demo


## Overview

This notebook can run on the **public TPU v5e-1**.

This notebook demonstrates how to perform Supervised Fine-Tuning (SFT) on Qwen3-0.6B using the Hugging Face ultrachat_200k dataset with Tunix integration for efficient training.

## Dataset Overview

**Dataset Link:** https://huggingface.co/datasets/HuggingFaceH4/ultrachat_200k

### Dataset Information:
- **Name:** HuggingFaceH4/ultrachat_200k
- **Type:** Supervised Fine-Tuning dataset
- **Size:** ~200k conversations
- **Format:** Chat conversations with human-AI pairs
- **Splits:** train_sft, test_sft
- **Data columns:** ['messages']

### Dataset Structure:
Each example contains a 'messages' field with:
- **role:** 'user' or 'assistant'
- **content:** The actual message text

### Example data format:
```json
{
  "messages": [
    {"role": "user", "content": "What is the capital of France?"},
    {"role": "assistant", "content": "The capital of France is Paris."}
  ]
}
```
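
As a quick sanity check on the structure described above, a record can be validated before it is fed to training. This is a minimal sketch: `validate_conversation` and `ALLOWED_ROLES` are hypothetical names introduced for illustration, not part of MaxText or the dataset tooling, and the check assumes only the two roles listed above.

```python
from typing import Any, Dict

# Assumption: only the two roles named in the dataset description are valid.
ALLOWED_ROLES = {"user", "assistant"}


def validate_conversation(example: Dict[str, Any]) -> bool:
  """Return True if `example` has a non-empty 'messages' list of role/content dicts."""
  messages = example.get("messages")
  if not isinstance(messages, list) or not messages:
    return False
  for message in messages:
    if not isinstance(message, dict):
      return False
    if message.get("role") not in ALLOWED_ROLES:
      return False
    if not isinstance(message.get("content"), str):
      return False
  return True


sample = {
  "messages": [
    {"role": "user", "content": "What is the capital of France?"},
    {"role": "assistant", "content": "The capital of France is Paris."},
  ]
}
print(validate_conversation(sample))
```

A check like this is cheap to run over a handful of records and tends to surface schema surprises early, before a training job is launched.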


## Prerequisites

### Change Runtime Type

**Instructions:**
1.  Navigate to the menu at the top of the screen.
2.  Click on **Runtime**.
3.  Select **Change runtime type** from the dropdown menu.
4.  Select **v5e-1 TPU** as the **Hardware accelerator**.
5.  Click on **Save**.


### Get Your Hugging Face Token

To access the model checkpoint from the Hugging Face Hub, you need to authenticate with a personal access token.

**Follow these steps to get your token:**

1.  **Navigate to the Access Tokens page** in your Hugging Face account settings. You can go there directly by visiting this URL:
    *   [https://huggingface.co/settings/tokens](https://huggingface.co/settings/tokens)

2.  **Create a new token** by clicking the **"+ Create new token"** button.

3.  **Give your token a name** and assign it a **`read` role**. The `read` role is sufficient for downloading models.

4.  **Copy the generated token**. You will need to paste it in the next step.

**Follow these steps to store your token:**

1. On the left sidebar of your Colab window, click the key icon (the Secrets tab).

2. Click **"+ Add new secret"**.

3. Set the Name as **HF_TOKEN**.

4. Paste your token into the Value field.

5. Ensure the Notebook access toggle is turned On.
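
Outside Colab there is no Secrets tab, so the token has to come from somewhere else. The sketch below shows one possible approach, assuming the token is exported as an environment variable with the same name. `get_hf_token` is a hypothetical helper, not part of MaxText or `huggingface_hub`; inside Colab it still reads the secret through `google.colab.userdata`.

```python
import os
from typing import Optional


def get_hf_token(secret_name: str = "HF_TOKEN") -> Optional[str]:
  """Hypothetical helper: read the token from Colab Secrets, else from the environment."""
  try:
    # google.colab is only importable inside a Colab runtime.
    from google.colab import userdata
  except ImportError:
    # Assumption: outside Colab the token was exported as an environment variable.
    return os.environ.get(secret_name)
  # In Colab, this raises if the secret is missing or notebook access is off.
  return userdata.get(secret_name)


token = get_hf_token()
print("Token found" if token else "No token found")
```

Returning `None` rather than raising outside Colab leaves it to the caller to decide whether a missing token is fatal.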

In [None]:
# Install maxtext and dependencies
# 1. Install uv, a fast Python package installer
!pip install uv

# 2. Install MaxText and its dependencies
!uv pip install maxtext --resolution=lowest
!install_maxtext_github_deps

### Restart Session
Restart the session so that the newly installed packages are picked up.

**Instructions:**
1.  Navigate to the menu at the top of the screen.
2.  Click on **Runtime**.
3.  Select **Restart session** from the dropdown menu.

You will be asked to confirm the action in a pop-up dialog. Click on **Yes**.

## Set up the MaxText environment

In [None]:
# Standard library (imported outside the try block so it is always available)
import os

# MaxText
try:
  import MaxText
  from MaxText import pyconfig
  from MaxText.sft.sft_trainer import train as sft_train

  MAXTEXT_AVAILABLE = True
  print("✓ MaxText imports successful")
except ImportError as e:
  print(f"⚠️ MaxText not available: {e}")
  MAXTEXT_AVAILABLE = False

# Hugging Face authentication setup
from huggingface_hub import login

# Use Google Colab userdata to read the HF token
from google.colab import userdata

# Locate the installed MaxText package; later cells build config and checkpoint paths from it
if MAXTEXT_AVAILABLE:
  MAXTEXT_REPO_ROOT = os.path.dirname(MaxText.__file__)
  print(f"MaxText installation path: {MAXTEXT_REPO_ROOT}")

In [None]:
## Set the model, checkpoint path and output directory
MODEL_NAME = "qwen3-0.6b"
# If the checkpoint is not already present, this notebook downloads it from Hugging Face and stores the converted copy at `{MAXTEXT_REPO_ROOT}/qwen_checkpoint`
MODEL_CHECKPOINT_PATH = f"{MAXTEXT_REPO_ROOT}/qwen_checkpoint"

# This is the directory where the fine-tuned model will be saved
# You can change it to any path you want
BASE_OUTPUT_DIRECTORY = "/tmp/out/maxtext_qwen06"

# Read the Hugging Face token stored as a secret in Google Colab
HF_TOKEN = userdata.get("HF_TOKEN")
# Outside Colab (for example, a private JupyterLab), set the token directly instead:
# HF_TOKEN = "your_actual_token_here"
login(token=HF_TOKEN)

## Get model checkpoint

In [None]:
# Convert the Hugging Face checkpoint to the MaxText format (skipped if the output directory already exists)
if not os.path.exists(MODEL_CHECKPOINT_PATH):
  !python3 -m MaxText.utils.ckpt_conversion.to_maxtext \
        $MAXTEXT_REPO_ROOT/configs/base.yml \
        model_name=$MODEL_NAME \
        base_output_directory=$MODEL_CHECKPOINT_PATH \
        hf_access_token=$HF_TOKEN \
        use_multimodal=false \
        scan_layers=false
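
The conversion cell only checks that the checkpoint directory exists, so a conversion that failed partway could leave a directory behind and be skipped on the next run. A stricter check is sketched below. It assumes the `<checkpoint>/0/items` layout that the configuration cell in this notebook passes as `load_parameters_path`; `checkpoint_items_path` and `checkpoint_ready` are hypothetical helpers, not MaxText functions.

```python
import os
import tempfile


def checkpoint_items_path(checkpoint_root: str, step: int = 0) -> str:
  """Build `<root>/<step>/items`, the layout assumed from load_parameters_path."""
  return os.path.join(checkpoint_root, str(step), "items")


def checkpoint_ready(checkpoint_root: str, step: int = 0) -> bool:
  """Return True only if the items directory exists and contains at least one entry."""
  items = checkpoint_items_path(checkpoint_root, step)
  return os.path.isdir(items) and len(os.listdir(items)) > 0


# Demonstration on a throwaway directory rather than a real checkpoint.
with tempfile.TemporaryDirectory() as root:
  print(checkpoint_ready(root))  # directory exists but holds no converted items
  os.makedirs(checkpoint_items_path(root))
  open(os.path.join(checkpoint_items_path(root), "placeholder"), "w").close()
  print(checkpoint_ready(root))  # items directory now has content
```

In the notebook, `checkpoint_ready(MODEL_CHECKPOINT_PATH)` could stand in for the plain `os.path.exists` test, at the cost of tying the check to this directory layout.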

## Initialize jax and set model config

In [None]:
# Initialize JAX's distributed runtime if an earlier cell has not already done so
import jax

if not jax.distributed.is_initialized():
  jax.distributed.initialize()
print(f"JAX version: {jax.__version__}")
print(f"JAX devices: {jax.devices()}")

In [None]:
# Fixed configuration for Qwen3-0.6B on a small TPU
if MAXTEXT_AVAILABLE:
  config_argv = [
    "",
    f"{MAXTEXT_REPO_ROOT}/configs/sft.yml",  # base SFT config
    f"load_parameters_path={MODEL_CHECKPOINT_PATH}/0/items/",  # load pre-trained weights; replace with your own checkpoint path if needed
    f"model_name={MODEL_NAME}",
    "steps=20",  # very short run for testing
    "per_device_batch_size=1",  # minimal to avoid OOM
    "max_target_length=1024",
    "learning_rate=2.0e-5",  # safe small LR
    "eval_steps=5",
    "weight_dtype=bfloat16",
    "dtype=bfloat16",
    "hf_path=HuggingFaceH4/ultrachat_200k",  # Hugging Face dataset used for SFT
    f"hf_access_token={HF_TOKEN}",
    f"base_output_directory={BASE_OUTPUT_DIRECTORY}",
    "run_name=sft_qwen0.6b_test",
    "tokenizer_path=Qwen/Qwen3-0.6B",  # Qwen tokenizer
    "eval_interval=10",
    "profiler=xplane",
  ]

  # Initialize configuration using MaxText's pyconfig
  config = pyconfig.initialize(config_argv)

  print("✓ Fixed configuration loaded:")
  print(f"  - Model: {config.model_name}")
  print(f"  - Dataset: {config.hf_path}")
  print(f"  - Steps: {config.steps}")
  print(f"  - Use SFT: {config.use_sft}")
  print(f"  - Learning Rate: {config.learning_rate}")
else:
  print("MaxText not available - cannot load configuration")
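
The argument list in the cell above follows a simple pattern: an empty placeholder for the program name, the config file path, then `key=value` overrides. When experimenting with several settings, it can be convenient to build that list from a dictionary. The sketch below is one way to do so; `build_config_argv` is a hypothetical helper, not a MaxText API, and it only formats strings. Floats are best passed as pre-formatted strings (for example `"2.0e-5"`) so their text is not altered by Python's `str()`.

```python
from typing import Dict, List, Union

OverrideValue = Union[str, int, bool]


def build_config_argv(config_file: str, overrides: Dict[str, OverrideValue]) -> List[str]:
  """Return ["", config_file, "key=value", ...] in the dictionary's insertion order."""
  argv = ["", config_file]
  for key, value in overrides.items():
    if isinstance(value, bool):
      # Lowercase booleans to match the `true`/`false` style used in this notebook.
      value = str(value).lower()
    argv.append(f"{key}={value}")
  return argv


example_argv = build_config_argv(
  "configs/sft.yml",  # placeholder path for illustration
  {"steps": 20, "per_device_batch_size": 1, "learning_rate": "2.0e-5", "use_multimodal": False},
)
print(example_argv)
```

The resulting list has the same shape as `config_argv` and could be handed to `pyconfig.initialize` in the same way.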

## Train the model and save the tuned model

In [None]:
# Run training with the MaxText SFT trainer's train() function
if MAXTEXT_AVAILABLE:
  print("=" * 60)
  print("EXECUTING ACTUAL TRAINING")
  print("=" * 60)

  sft_train(config)

  # Report completion only when training actually ran
  print("Training complete!")
  print("Model saved at: ", BASE_OUTPUT_DIRECTORY)
else:
  print("MaxText not available - training skipped")