8 changes: 6 additions & 2 deletions docs/index.md
@@ -54,8 +54,12 @@ Our goal is to provide a variety of models (dimension “a”) and techniques (d

Check out these getting started guides:

-* [SFT](https://github.com/AI-Hypercomputer/maxtext/blob/main/docs/tutorials/sft.md) (Supervised Fine Tuning)
-* [GRPO](https://github.com/AI-Hypercomputer/maxtext/blob/main/docs/tutorials/grpo.md) (Group Relative Policy Optimization)
+* Supervised Fine Tuning (SFT)
+  * [SFT on Single-Host TPUs](https://maxtext.readthedocs.io/en/latest/tutorials/sft.html)
+  * [SFT on Multi-Host TPUs](https://maxtext.readthedocs.io/en/latest/tutorials/sft_on_multi_host.html)
+* Group Relative Policy Optimization (GRPO)
+  * [GRPO on Single-Host TPUs](https://maxtext.readthedocs.io/en/latest/tutorials/grpo.html)
+  * [GRPO on Multi-Host TPUs](https://maxtext.readthedocs.io/en/latest/tutorials/grpo_with_pathways.html)

### Model library

1 change: 1 addition & 0 deletions docs/tutorials.md
@@ -41,6 +41,7 @@ You can also find other examples in the [MaxText repository](https://github.com/
tutorials/first_run.md
tutorials/pretraining.md
tutorials/full_finetuning.md
+tutorials/how_to_run_colabs.md
tutorials/grpo.md
tutorials/sft.md
tutorials/sft_on_multi_host.md
8 changes: 4 additions & 4 deletions docs/tutorials/how_to_run_colabs.md
@@ -10,7 +10,7 @@ This guide provides comprehensive instructions for setting up Jupyter Lab on TPU
- [Method 3: Colab + Local Jupyter Lab Hybrid](#method-3-colab--local-jupyter-lab-hybrid)
- [Available Examples](#available-examples)
- [Common Pitfalls & Debugging](#common-pitfalls--debugging)
-- [Support & Resources](#support--resources)
+- [Support & Resources](#support-and-resources)
- [Contributing](#contributing)

## Prerequisites
@@ -48,7 +48,7 @@ Upload notebooks or mount your GitHub repo
> **Note:** In Colab, the repo root will usually be `/content/maxtext`

**Example:**
-```python
+```bash
!git clone https://github.com/AI-Hypercomputer/maxtext.git
%cd maxtext
```
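After cloning, it can be worth confirming that the runtime actually exposes TPU devices to JAX before running anything heavy. A minimal check (this assumes a TPU runtime with JAX preinstalled in the Colab image):

```bash
# Should print TpuDevice entries; a CPU-only runtime prints [CpuDevice(id=0)]
!python3 -c "import jax; print(jax.devices())"
```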
@@ -122,8 +122,8 @@ Use the link for Jupyter Lab as a link for "Connect to a local runtime" in Colla

### Supervised Fine-Tuning (SFT)

-- **`sft_qwen3_demo.ipynb`** → Qwen3-0.6B with Hugging Face ultrachat_200k dataset
-- **`sft_llama3_demo.ipynb`** → Llama3.1-8B with Hugging Face ultrachat_200k dataset
+- **`sft_qwen3_demo.ipynb`** → Qwen3-0.6B SFT training and evaluation on [OpenAI's GSM8K dataset](https://huggingface.co/datasets/openai/gsm8k)
+- **`sft_llama3_demo.ipynb`** → Llama3.1-8B SFT training on [Hugging Face ultrachat_200k dataset](https://huggingface.co/datasets/HuggingFaceH4/ultrachat_200k)

### GRPO Training

122 changes: 70 additions & 52 deletions docs/tutorials/sft.md
@@ -14,81 +14,99 @@
limitations under the License.
-->

-# Try SFT
+# Supervised Fine-Tuning (SFT) on Single-Host TPUs
Supervised fine-tuning (SFT) further trains a pre-trained large language model on a labeled dataset to adapt it to specific tasks.

-This tutorial demonstrates step-by-step instructions for setting up the environment and then training the Llama3.1 8B model on the [HuggingFaceH4/ultrachat_200k](https://huggingface.co/datasets/HuggingFaceH4/ultrachat_200k) dataset using SFT. If you wish to use a different dataset, you can [update the dataset configurations](https://github.com/AI-Hypercomputer/maxtext/blob/main/src/MaxText/configs/sft.yml).
+This tutorial provides step-by-step instructions for setting up the environment and then training a model on a Hugging Face dataset using SFT.

We use [Tunix](https://github.com/google/tunix), a JAX-based library designed for post-training tasks, to perform SFT.

In this tutorial, we use a single-host TPU VM such as `v6e-8` or `v5p-8`. Let's get started!

-## Setup virtual environment
-
-### Create a Python3.12 virtual environment
-```sh
-bash tools/setup/setup.sh
-```
+## Install dependencies
+```sh
+# 1. Clone the repository
+git clone https://github.com/AI-Hypercomputer/maxtext.git
+cd maxtext
+
+# 2. Create virtual environment
+export VENV_NAME=<your virtual env name> # e.g., maxtext_venv
+pip install uv
+uv venv --python 3.12 --seed $VENV_NAME
+source $VENV_NAME/bin/activate
+
+# 3. Install dependencies in editable mode
+uv pip install -e .[tpu] --resolution=lowest
+install_maxtext_github_deps
+```
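As an optional sanity check before continuing (assuming the virtual environment above is active on the TPU VM), you can confirm that JAX sees the local TPU chips:

```sh
# Expect the "tpu" backend and the local device count, e.g. 8 on a v6e-8
python3 -c "import jax; print(jax.default_backend(), jax.device_count())"
```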

-### Activate virtual environment
-```
-# Replace with your virtual environment name if not using this default name
-venv_name="maxtext_venv"
-source ~/$venv_name/bin/activate
-```
-
-### Install MaxText dependencies
-```
-bash tools/setup/setup.sh
-```
+## Setup environment variables
+Set the following environment variables before running SFT.
+```sh
+# -- Model configuration --
+export PRE_TRAINED_MODEL=<model name> # e.g., 'llama3.1-8b'
+export PRE_TRAINED_MODEL_TOKENIZER=<tokenizer path> # e.g., 'meta-llama/Llama-3.1-8B-Instruct'
+export HF_TOKEN=<Hugging Face access token>
+
+# -- MaxText configuration --
+export BASE_OUTPUT_DIRECTORY=<output directory to store run logs> # e.g., gs://my-bucket/my-output-directory
+export RUN_NAME=<name for this run> # e.g., $(date +%Y-%m-%d-%H-%M-%S)
+export STEPS=<number of fine-tuning steps to run> # e.g., 1000
+export PER_DEVICE_BATCH_SIZE=<batch size per device> # e.g., 1
+
+# -- Dataset configuration --
+export DATASET_NAME=<Hugging Face dataset name> # e.g., HuggingFaceH4/ultrachat_200k
+export TRAIN_SPLIT=<data split for train> # e.g., train_sft
+export TRAIN_DATA_COLUMNS=<data columns to train on> # e.g., ['messages']
+```
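Note that the effective global batch size is `PER_DEVICE_BATCH_SIZE` multiplied by the number of TPU devices; for example, `PER_DEVICE_BATCH_SIZE=1` on a `v6e-8` (8 chips) yields a global batch size of 8.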

-## Run SFT
-There are two scenarios supported for running SFT:
-1. **Run SFT on Hugging Face checkpoint**
-Download the checkpoint directly from Hugging Face and fine-tune it using SFT.
-
-2. **Run SFT on MaxText checkpoint**
-Use a checkpoint generated by MaxText and fine-tune it using SFT.
-
-Choose the scenario that matches your workflow and follow the corresponding instructions below.
+## Get your model checkpoint
+This section explains how to prepare your model checkpoint for use with MaxText. You have two options: using an existing MaxText checkpoint or converting a Hugging Face checkpoint.

-### Run SFT on Hugging Face checkpoint
-* The script will first convert a Hugging Face checkpoint to a MaxText checkpoint.
-* It then runs SFT on this converted checkpoint.
-
-#### Setup environment variables
+### Option 1: Using an existing MaxText checkpoint
+If you already have a MaxText-compatible model checkpoint, simply set the following environment variable and move on to the next section.
+
+```sh
+export PRE_TRAINED_MODEL_CKPT_PATH=<gcs path for MaxText checkpoint> # e.g., gs://my-bucket/my-model-checkpoint/0/items
+```
-```
-export HF_TOKEN=<Hugging Face access token>
-
-export BASE_OUTPUT_DIRECTORY=<output directory to store run logs>
-
-export STEPS=<number of fine-tuning steps to run>
-
-export PER_DEVICE_BATCH_SIZE=1
-```
-
-Finally, run the [script](https://github.com/AI-Hypercomputer/maxtext/blob/main/end_to_end/tpu/llama3.1/8b/run_sft.sh):
-```
-bash ~/maxtext/end_to_end/tpu/llama3.1/8b/run_sft.sh
-```
-
-### Run SFT on MaxText checkpoint
-* The script directly runs SFT on MaxText checkpoint.
-
-#### Setup environment variables
+### Option 2: Converting a Hugging Face checkpoint
+If your model checkpoint is from Hugging Face, you need to run a conversion script to make it MaxText-compatible.
+
+1. **Set the Output Path:** First, define where the new MaxText checkpoint will be saved.
+
+```sh
+export PRE_TRAINED_MODEL_CKPT_PATH=${BASE_OUTPUT_DIRECTORY}/${RUN_NAME}/maxtext-checkpoint/0/items
+```
+
+2. **Run the Conversion Script:** Run the following command, which downloads the specified Hugging Face model and converts its weights into the MaxText format. The conversion script only supports official versions of models from Hugging Face. To see the specific models and versions currently supported for conversion, please refer to the `HF_IDS` dictionary in the MaxText utility file [here](https://github.com/AI-Hypercomputer/maxtext/blob/main/src/MaxText/utils/ckpt_conversion/utils/utils.py).
+
+```sh
+pip install torch # Ensure torch is installed for the conversion script
+
+python3 -m MaxText.utils.ckpt_conversion.to_maxtext src/MaxText/configs/base.yml \
+  model_name=${PRE_TRAINED_MODEL} \
+  hf_access_token=${HF_TOKEN} \
+  base_output_directory=${BASE_OUTPUT_DIRECTORY}/${RUN_NAME}/maxtext-checkpoint \
+  scan_layers=True
+```
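If you'd rather check the supported list from the command line than read the source, a one-liner like this should work (assuming `HF_IDS` remains an importable dictionary at the module path above):

```sh
# Print the model names the conversion utility currently knows about
python3 -c "from MaxText.utils.ckpt_conversion.utils.utils import HF_IDS; print(sorted(HF_IDS))"
```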
-```
-export HF_TOKEN=<Hugging Face access token>
-
-export BASE_OUTPUT_DIRECTORY=<output directory to store run logs>
-
-export STEPS=<number of fine-tuning steps to run>
-
-export PER_DEVICE_BATCH_SIZE=1
-
-export PRE_TRAINED_MODEL_CKPT_PATH=<gcs path for MaxText checkpoint>
-```
-
-Finally, run the [script](https://github.com/AI-Hypercomputer/maxtext/blob/main/end_to_end/tpu/llama3.1/8b/run_sft.sh):
-```
-bash ~/maxtext/end_to_end/tpu/llama3.1/8b/run_sft.sh
-```
+## Run SFT on Hugging Face Dataset
+Now you are ready to run SFT using the following command:
+
+```sh
+python3 -m MaxText.sft.sft_trainer src/MaxText/configs/sft.yml \
+  run_name=${RUN_NAME} \
+  base_output_directory=${BASE_OUTPUT_DIRECTORY} \
+  model_name=${PRE_TRAINED_MODEL} \
+  load_parameters_path=${PRE_TRAINED_MODEL_CKPT_PATH} \
+  hf_access_token=${HF_TOKEN} \
+  tokenizer_path=${PRE_TRAINED_MODEL_TOKENIZER} \
+  per_device_batch_size=${PER_DEVICE_BATCH_SIZE} \
+  steps=${STEPS} \
+  hf_path=${DATASET_NAME} \
+  train_split=${TRAIN_SPLIT} \
+  train_data_columns=${TRAIN_DATA_COLUMNS} \
+  profiler=xplane
+```
+Your fine-tuned model checkpoints will be saved here: `$BASE_OUTPUT_DIRECTORY/$RUN_NAME/checkpoints`.
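If your output directory is a GCS bucket, one quick way to confirm the run wrote checkpoints (assuming `gsutil` is installed and authenticated) is to list them:

```sh
# Each numbered subdirectory corresponds to a saved checkpoint step
gsutil ls ${BASE_OUTPUT_DIRECTORY}/${RUN_NAME}/checkpoints/
```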
79 changes: 65 additions & 14 deletions docs/tutorials/sft_on_multi_host.md
@@ -14,8 +14,14 @@
# limitations under the License.
-->

-# Supervised Fine-Tuning (SFT) with Deepseek-V3 model
-This guide provides step by step instructions to run SFT with Deepseek-V3 model on TPU v6e-256. Deepseek-V3 is a Mixture-of-Experts (MoE) language model with 671B parameters.
+# Supervised Fine-Tuning (SFT) on Multi-Host TPUs
+Supervised fine-tuning (SFT) further trains a pre-trained large language model on a labeled dataset to adapt it to specific tasks.
+
+This tutorial provides step-by-step instructions for setting up the multi-host TPU environment and then training a model on a Hugging Face dataset using SFT. In this tutorial, we use a multi-host TPU slice such as `v6e-256`.
+
+We use [Tunix](https://github.com/google/tunix), a JAX-based library designed for post-training tasks, to perform SFT.
+
+Let's get started!

## 1. Build and upload MaxText Docker image
This section guides you through cloning the MaxText repository, building a MaxText Docker image with dependencies, and uploading the image to your project's Artifact Registry.
@@ -28,7 +34,7 @@ cd maxtext

### 1.2. Build MaxText Docker image
```bash
-bash dependencies/scripts/docker_build_dependency_image.sh MODE=jax_ai_image BASEIMAGE=us-docker.pkg.dev/cloud-tpu-images/jax-ai-image/tpu:latest
+bash dependencies/scripts/docker_build_dependency_image.sh MODE=post-training
```
This creates a local Docker image named `maxtext_base_image`.

@@ -44,7 +50,7 @@ The `docker_upload_runner.sh` script uploads your Docker image to Artifact Regis
Install XPK by following the instructions in the [official documentation](https://github.com/AI-Hypercomputer/xpk?tab=readme-ov-file#installation-via-pip).

## 3. Create GKE cluster
-If you don't already have a GKE cluster with a `v6e-256` TPU slice available, create one by following the [XPK cluster creation guide](https://github.com/AI-Hypercomputer/xpk?tab=readme-ov-file#cluster-create).
+If you don't already have a GKE cluster, create one by following the [XPK cluster creation guide](https://github.com/AI-Hypercomputer/xpk?tab=readme-ov-file#cluster-create).

## 4. Environment configuration
```bash
@@ -54,20 +60,63 @@ export CLUSTER_NAME=<Name of GKE Cluster>
export ZONE=<GKE Cluster Zone>

# -- Workload Configuration --
export WORKLOAD_NAME="sft-$(date +%Y-%m-%d-%H-%M-%S)" # Or your desired workload name
export TPU_TYPE=v6e-256
export WORKLOAD_NAME=<Name of Workload> # e.g., $(date +%Y-%m-%d-%H-%M-%S)
export TPU_TYPE=<TPU Type> # e.g., v6e-256
export TPU_SLICE=1
export DOCKER_IMAGE="gcr.io/${PROJECT}/${DOCKER_IMAGE_NAME}"

# -- MaxText Configuration --
-export OUTPUT_PATH=<GCS Bucket Path for output/logs>
-export STEPS=100 # Number of fine-tuning steps to run
-export HF_TOKEN=<Hugging Face access token>
-export MODEL_CHECKPOINT_PATH=<GCS path to model checkpoint>
+export OUTPUT_PATH=<GCS Path for Output/Logs> # e.g., gs://my-bucket/my-output-directory
+export STEPS=<Fine-Tuning Steps> # e.g., 1000
+export HF_TOKEN=<Hugging Face Access Token>
+
+# -- Model Configuration --
+export MODEL_NAME=<Model Name> # e.g., deepseek3-671b
+export TOKENIZER_PATH=<Model Tokenizer> # e.g., deepseek-ai/DeepSeek-V3
+
+# -- Dataset configuration --
+export DATASET_NAME=<Hugging Face Dataset Name> # e.g., HuggingFaceH4/ultrachat_200k
+export TRAIN_SPLIT=<Data Split for Train> # e.g., train_sft
+export TRAIN_DATA_COLUMNS=<Data Columns to Train on> # e.g., ['messages']
```
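Note that `TPU_SLICE` is passed to `--num-slices` in the XPK commands below, so `TPU_SLICE=1` schedules the workload on a single slice of the chosen `TPU_TYPE`.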

+## 5. Get MaxText model checkpoint
+This section explains how to prepare your model checkpoint for use with MaxText. You have two options: using an existing MaxText checkpoint or converting a Hugging Face checkpoint.
+
+### Option 1: Using an existing MaxText checkpoint
+If you already have a MaxText-compatible model checkpoint, simply set the following environment variable and move on to the next section.
+
+```bash
+export MODEL_CHECKPOINT_PATH=<gcs path for MaxText checkpoint> # e.g., gs://my-bucket/my-model-checkpoint/0/items
+```

-## 5. Submit workload on GKE cluster
-This section provides the command to run SFT with Deepseek-v3 model on a v6e-256 GKE cluster.
+### Option 2: Converting a Hugging Face checkpoint
+If your model checkpoint is from Hugging Face, you need to run a conversion script to make it MaxText-compatible.

+1. **Set the Output Path:** First, define where the new MaxText checkpoint will be saved.
+
+```bash
+export MODEL_CHECKPOINT_PATH=${OUTPUT_PATH}/${WORKLOAD_NAME}/maxtext-checkpoint/0/items
+```
+
+2. **Run the Conversion Script:** Run the following command, which downloads the specified Hugging Face model and converts its weights into the MaxText format. The conversion script only supports official versions of models from Hugging Face. To see the specific models and versions currently supported for conversion, please refer to the `HF_IDS` dictionary in the MaxText utility file [here](https://github.com/AI-Hypercomputer/maxtext/blob/main/src/MaxText/utils/ckpt_conversion/utils/utils.py).
+
+```bash
+xpk workload create \
+  --cluster=${CLUSTER_NAME} \
+  --project=${PROJECT} \
+  --zone=${ZONE} \
+  --docker-image=${DOCKER_IMAGE} \
+  --workload=ckpt-${WORKLOAD_NAME} \
+  --tpu-type=${TPU_TYPE} \
+  --num-slices=${TPU_SLICE} \
+  --command "python3 -m MaxText.utils.ckpt_conversion.to_maxtext src/MaxText/configs/base.yml model_name=$MODEL_NAME hf_access_token=$HF_TOKEN base_output_directory=$OUTPUT_PATH/$WORKLOAD_NAME/maxtext-checkpoint scan_layers=True"
+```
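The conversion runs as its own XPK workload (named `ckpt-${WORKLOAD_NAME}`), so you may want to confirm it has completed before launching fine-tuning. A sketch using XPK's workload listing (flags as in the XPK documentation linked above):

```bash
# Show workloads on the cluster and their status
xpk workload list \
  --cluster=${CLUSTER_NAME} \
  --project=${PROJECT} \
  --zone=${ZONE}
```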

+## 6. Submit workload on GKE cluster
+This section provides the command to run SFT on a GKE cluster.
+
+### 6.1. SFT with Multi-Controller JAX (McJAX)
```bash
xpk workload create \
--cluster=${CLUSTER_NAME} \
@@ -77,7 +126,9 @@ xpk workload create \
--workload=${WORKLOAD_NAME} \
--tpu-type=${TPU_TYPE} \
--num-slices=${TPU_SLICE} \
--command "python3 -m MaxText.sft.sft_trainer src/MaxText/configs/sft.yml run_name=$WORKLOAD_NAME base_output_directory=$OUTPUT_PATH model_name=deepseek3-671b load_parameters_path=$MODEL_CHECKPOINT_PATH hf_access_token=$HF_TOKEN tokenizer_path=deepseek-ai/DeepSeek-V3 per_device_batch_size=1 steps=$STEPS profiler=xplane megablox=False sparse_matmul=False ici_expert_parallelism=16 ici_fsdp_parallelism=16 weight_dtype=bfloat16 dtype=bfloat16 remat_policy=full decoder_layer_input=offload sa_block_q=2048 sa_block_q_dkv=2048 sa_block_q_dq=2048 opt_type=sgd attention=flash capacity_factor=1.0 max_target_length=2048"
--command "python3 -m MaxText.sft.sft_trainer src/MaxText/configs/sft.yml run_name=$WORKLOAD_NAME base_output_directory=$OUTPUT_PATH model_name=$MODEL_NAME load_parameters_path=$MODEL_CHECKPOINT_PATH hf_access_token=$HF_TOKEN tokenizer_path=$TOKENIZER_PATH per_device_batch_size=1 steps=$STEPS profiler=xplane hf_path=$DATASET_NAME train_split=$TRAIN_SPLIT train_data_columns=$TRAIN_DATA_COLUMNS"
```
-Once the fine-tuning is completed, you can access your model checkpoint at `${OUTPUT_PATH}/${WORKLOAD_NAME}/checkpoints/${STEPS}/model_params`.
+Once the fine-tuning is completed, you can access your model checkpoints at `$OUTPUT_PATH/$WORKLOAD_NAME/checkpoints`.

+### 6.2. SFT with Pathways
+Pathways support is coming soon.