From 3eba64cedd94bc0c564bb30723762365f5c87f82 Mon Sep 17 00:00:00 2001 From: SurbhiJainUSC Date: Mon, 13 Oct 2025 21:23:29 +0000 Subject: [PATCH] Improve SFT documentation --- docs/index.md | 8 +- docs/tutorials.md | 1 + .../tutorials/how_to_run_colabs.md | 8 +- docs/tutorials/sft.md | 122 ++++++++++-------- docs/tutorials/sft_on_multi_host.md | 79 ++++++++++-- 5 files changed, 146 insertions(+), 72 deletions(-) rename src/MaxText/examples/README_how_to_run_examples.md => docs/tutorials/how_to_run_colabs.md (94%) diff --git a/docs/index.md b/docs/index.md index a77c8def3..74c170290 100644 --- a/docs/index.md +++ b/docs/index.md @@ -54,8 +54,12 @@ Our goal is to provide a variety of models (dimension “a”) and techniques (d Check out these getting started guides: -* [SFT](https://github.com/AI-Hypercomputer/maxtext/blob/main/docs/tutorials/sft.md) (Supervised Fine Tuning) -* [GRPO](https://github.com/AI-Hypercomputer/maxtext/blob/main/docs/tutorials/grpo.md) (Group Relative Policy Optimization) +* Supervised Fine Tuning (SFT) + * [SFT on Single-Host TPUs](https://maxtext.readthedocs.io/en/latest/tutorials/sft.html) + * [SFT on Multi-Host TPUs](https://maxtext.readthedocs.io/en/latest/tutorials/sft_on_multi_host.html) +* Group Relative Policy Optimization (GRPO) + * [GRPO on Single-Host TPUs](https://maxtext.readthedocs.io/en/latest/tutorials/grpo.html) + * [GRPO on Multi-Host TPUs](https://maxtext.readthedocs.io/en/latest/tutorials/grpo_with_pathways.html) ### Model library diff --git a/docs/tutorials.md b/docs/tutorials.md index 8b88913bd..8987382cb 100644 --- a/docs/tutorials.md +++ b/docs/tutorials.md @@ -41,6 +41,7 @@ You can also find other examples in the [MaxText repository](https://github.com/ tutorials/first_run.md tutorials/pretraining.md tutorials/full_finetuning.md +tutorials/how_to_run_colabs.md tutorials/grpo.md tutorials/sft.md tutorials/sft_on_multi_host.md diff --git a/src/MaxText/examples/README_how_to_run_examples.md b/docs/tutorials/how_to_run_colabs.md similarity index 94% rename from src/MaxText/examples/README_how_to_run_examples.md rename to docs/tutorials/how_to_run_colabs.md index 7dc5adb5f..4b12adda5 100644 --- a/src/MaxText/examples/README_how_to_run_examples.md +++ b/docs/tutorials/how_to_run_colabs.md @@ -10,7 +10,7 @@ This guide provides comprehensive instructions for setting up Jupyter Lab on TPU - [Method 3: Colab + Local Jupyter Lab Hybrid](#method-3-colab--local-jupyter-lab-hybrid) - [Available Examples](#available-examples) - [Common Pitfalls & Debugging](#common-pitfalls--debugging) -- [Support & Resources](#support--resources) +- [Support & Resources](#support-and-resources) - [Contributing](#contributing) ## Prerequisites @@ -48,7 +48,7 @@ Upload notebooks or mount your GitHub repo > **Note:** In Colab, the repo root will usually be `/content/maxtext` **Example:** -```python +```bash !git clone https://github.com/AI-Hypercomputer/maxtext.git %cd maxtext ``` @@ -122,8 +122,8 @@ Use the link for Jupyter Lab as a link for "Connect to a local runtime" in Colla ### Supervised Fine-Tuning (SFT) -- **`sft_qwen3_demo.ipynb`** → Qwen3-0.6B with Hugging Face ultrachat_200k dataset -- **`sft_llama3_demo.ipynb`** → Llama3.1-8B with Hugging Face ultrachat_200k dataset +- **`sft_qwen3_demo.ipynb`** → Qwen3-0.6B SFT training and evaluation on [OpenAI's GSM8K dataset](https://huggingface.co/datasets/openai/gsm8k) +- **`sft_llama3_demo.ipynb`** → Llama3.1-8B SFT training on [Hugging Face ultrachat_200k dataset](https://huggingface.co/datasets/HuggingFaceH4/ultrachat_200k) ### GRPO Training diff --git a/docs/tutorials/sft.md b/docs/tutorials/sft.md index e15941126..be2ca13b1 100644 --- a/docs/tutorials/sft.md +++ b/docs/tutorials/sft.md @@ -14,81 +14,99 @@ limitations under the License. --> -# Try SFT +# Supervised Fine-Tuning (SFT) on Single-Host TPUs Supervised fine-tuning (SFT) is a process where a pre-trained large language model is fine-tuned on a labeled dataset to adapt the model to perform better on specific tasks. -This tutorial demonstrates step-by-step instructions for setting up the environment and then training the Llama3.1 8B model on the [HuggingFaceH4/ultrachat_200k](https://huggingface.co/datasets/HuggingFaceH4/ultrachat_200k) dataset using SFT. If you wish to use a different dataset, you can [update the dataset configurations](https://github.com/AI-Hypercomputer/maxtext/blob/main/src/MaxText/configs/sft.yml). +This tutorial demonstrates step-by-step instructions for setting up the environment and then training the model on a Hugging Face dataset using SFT. We use [Tunix](https://github.com/google/tunix), a JAX-based library designed for post-training tasks, to perform SFT. In this tutorial we use a single host TPU VM such as `v6e-8/v5p-8`. Let's get started! -## Setup virtual environment - -### Create a Python3.12 virtual environment +## Install dependencies ```sh -bash tools/setup/setup.sh +# 1. Clone the repository +git clone https://github.com/AI-Hypercomputer/maxtext.git +cd maxtext + +# 2. Create virtual environment +export VENV_NAME= # e.g., maxtext_venv +pip install uv +uv venv --python 3.12 --seed $VENV_NAME +source $VENV_NAME/bin/activate + +# 3. Install dependencies in editable mode +uv pip install -e .[tpu] --resolution=lowest +install_maxtext_github_deps ``` -### Activate virtual environment -``` -# Replace with your virtual environment name if not using this default name -venv_name="maxtext_venv" -source ~/$venv_name/bin/activate -``` - -### Install MaxText dependencies -``` -bash tools/setup/setup.sh +## Setup environment variables +Set the following environment variables before running SFT. +```sh +# -- Model configuration -- +export PRE_TRAINED_MODEL= # e.g., 'llama3.1-8b' +export PRE_TRAINED_MODEL_TOKENIZER= # e.g., 'meta-llama/Llama-3.1-8B-Instruct' +export HF_TOKEN= + +# -- MaxText configuration -- +export BASE_OUTPUT_DIRECTORY= # e.g., gs://my-bucket/my-output-directory +export RUN_NAME= # e.g., $(date +%Y-%m-%d-%H-%M-%S) +export STEPS= # e.g., 1000 +export PER_DEVICE_BATCH_SIZE= # e.g., 1 + +# -- Dataset configuration -- +export DATASET_NAME= # e.g., HuggingFaceH4/ultrachat_200k +export TRAIN_SPLIT= # e.g., train_sft +export TRAIN_DATA_COLUMNS= # e.g., ['messages'] ``` -## Run SFT -There are two scenarios supported for running SFT: -1. **Run SFT on Hugging Face checkpoint** - Download the checkpoint directly from Hugging Face and fine-tune it using SFT. - -2. **Run SFT on MaxText checkpoint** - Use a checkpoint generated by MaxText and fine-tune it using SFT. - -Choose the scenario that matches your workflow and follow the corresponding instructions below. +## Get your model checkpoint +This section explains how to prepare your model checkpoint for use with MaxText. You have two options: using an existing MaxText checkpoint or converting a Hugging Face checkpoint. -### Run SFT on Hugging Face checkpoint -* The script will first convert a Hugging Face checkpoint to a MaxText checkpoint. -* It then runs SFT on this converted checkpoint. +### Option 1: Using an existing MaxText checkpoint +If you already have a MaxText-compatible model checkpoint, simply set the following environment variable and move on to the next section. -#### Setup environment variables +```sh +export PRE_TRAINED_MODEL_CKPT_PATH= # e.g., gs://my-bucket/my-model-checkpoint/0/items ``` - export HF_TOKEN= - export BASE_OUTPUT_DIRECTORY= +### Option 2: Converting a Hugging Face checkpoint +If your model checkpoint is from Hugging Face, you need to run a conversion script to make it MaxText-compatible. - export STEPS= +1. **Set the Output Path:** First, define where the new MaxText checkpoint will be saved. - export PER_DEVICE_BATCH_SIZE=1 +```sh +export PRE_TRAINED_MODEL_CKPT_PATH=${BASE_OUTPUT_DIRECTORY}/${RUN_NAME}/maxtext-checkpoint/0/items ``` -Finally, run the [script](https://github.com/AI-Hypercomputer/maxtext/blob/main/end_to_end/tpu/llama3.1/8b/run_sft.sh): -``` -bash ~/maxtext/end_to_end/tpu/llama3.1/8b/run_sft.sh -``` +2. **Run the Conversion Script:** Execute the following command that downloads the specified Hugging Face model and converts its weights into the MaxText format. The conversion script only supports official versions of models from Hugging Face. To see the specific models and versions currently supported for conversion, please refer to the `HF_IDS` dictionary in the MaxText utility file [here](https://github.com/AI-Hypercomputer/maxtext/blob/main/src/MaxText/utils/ckpt_conversion/utils/utils.py). -### Run SFT on MaxText checkpoint -* The script directly runs SFT on MaxText checkpoint. +```sh +pip install torch # Ensure torch is installed for the conversion script -#### Setup environment variables +python3 -m MaxText.utils.ckpt_conversion.to_maxtext src/MaxText/configs/base.yml \ + model_name=${PRE_TRAINED_MODEL} \ + hf_access_token=${HF_TOKEN} \ + base_output_directory=${BASE_OUTPUT_DIRECTORY}/${RUN_NAME}/maxtext-checkpoint \ + scan_layers=True ``` - export HF_TOKEN= - - export BASE_OUTPUT_DIRECTORY= - - export STEPS= - export PER_DEVICE_BATCH_SIZE=1 +## Run SFT on Hugging Face Dataset +Now you are ready to run SFT using the following command: - export PRE_TRAINED_MODEL_CKPT_PATH= -``` - -Finally, run the [script](https://github.com/AI-Hypercomputer/maxtext/blob/main/end_to_end/tpu/llama3.1/8b/run_sft.sh): -``` -bash ~/maxtext/end_to_end/tpu/llama3.1/8b/run_sft.sh +```sh +python3 -m MaxText.sft.sft_trainer src/MaxText/configs/sft.yml \ + run_name=${RUN_NAME} \ + base_output_directory=${BASE_OUTPUT_DIRECTORY} \ + model_name=${PRE_TRAINED_MODEL} \ + load_parameters_path=${PRE_TRAINED_MODEL_CKPT_PATH} \ + hf_access_token=${HF_TOKEN} \ + tokenizer_path=${PRE_TRAINED_MODEL_TOKENIZER} \ + per_device_batch_size=${PER_DEVICE_BATCH_SIZE} \ + steps=${STEPS} \ + hf_path=${DATASET_NAME} \ + train_split=${TRAIN_SPLIT} \ + train_data_columns=${TRAIN_DATA_COLUMNS} \ + profiler=xplane ``` +Your fine-tuned model checkpoints will be saved here: `$BASE_OUTPUT_DIRECTORY/$RUN_NAME/checkpoints`. diff --git a/docs/tutorials/sft_on_multi_host.md b/docs/tutorials/sft_on_multi_host.md index dbff94365..746ca1471 100644 --- a/docs/tutorials/sft_on_multi_host.md +++ b/docs/tutorials/sft_on_multi_host.md @@ -14,8 +14,14 @@ # limitations under the License. --> -# Supervised Fine-Tuning (SFT) with Deepseek-V3 model -This guide provides step by step instructions to run SFT with Deepseek-V3 model on TPU v6e-256. Deepseek-V3 is a Mixture-of-Experts (MoE) language model with 671B parameters. +# Supervised Fine-Tuning (SFT) on Multi-Host TPUs +Supervised fine-tuning (SFT) is a process where a pre-trained large language model is fine-tuned on a labeled dataset to adapt the model to perform better on specific tasks. + +This tutorial demonstrates step-by-step instructions for setting up the multi-host TPU environment and then training the model on the Hugging Face dataset using SFT. In this tutorial we use a multi-host TPU such as `v6e-256`. + +We use [Tunix](https://github.com/google/tunix), a JAX-based library designed for post-training tasks, to perform SFT. + +Let's get started! ## 1. Build and upload MaxText Docker image This section guides you through cloning the MaxText repository, building MaxText Docker image with dependencies, and uploading the docker image to your project's Artifact Registry. @@ -28,7 +34,7 @@ cd maxtext ### 1.2. Build MaxText Docker image ```bash -bash dependencies/scripts/docker_build_dependency_image.sh MODE=jax_ai_image BASEIMAGE=us-docker.pkg.dev/cloud-tpu-images/jax-ai-image/tpu:latest +bash dependencies/scripts/docker_build_dependency_image.sh MODE=post-training ``` This creates a local Docker image named `maxtext_base_image`. @@ -44,7 +50,7 @@ The `docker_upload_runner.sh` script uploads your Docker image to Artifact Regis Install XPK by following the instructions in the [official documentation](https://github.com/AI-Hypercomputer/xpk?tab=readme-ov-file#installation-via-pip). ## 3. Create GKE cluster -If you don't already have a GKE cluster with a `v6e-256` TPU slice available, create one by following the [XPK cluster creation guide](https://github.com/AI-Hypercomputer/xpk?tab=readme-ov-file#cluster-create). +If you don't already have a GKE cluster, create one by following the [XPK cluster creation guide](https://github.com/AI-Hypercomputer/xpk?tab=readme-ov-file#cluster-create). ## 4. Environment configuration ```bash @@ -54,20 +60,63 @@ export CLUSTER_NAME= export ZONE= # -- Workload Configuration -- -export WORKLOAD_NAME="sft-$(date +%Y-%m-%d-%H-%M-%S)" # Or your desired workload name -export TPU_TYPE=v6e-256 +export WORKLOAD_NAME= # e.g., $(date +%Y-%m-%d-%H-%M-%S) +export TPU_TYPE= # e.g., v6e-256 export TPU_SLICE=1 export DOCKER_IMAGE="gcr.io/${PROJECT}/${DOCKER_IMAGE_NAME}" # -- MaxText Configuration -- -export OUTPUT_PATH= -export STEPS=100 # Number of fine-tuning steps to run -export HF_TOKEN= -export MODEL_CHECKPOINT_PATH= +export OUTPUT_PATH= # e.g., gs://my-bucket/my-output-directory +export STEPS= # e.g., 1000 +export HF_TOKEN= + +# -- Model Configuration -- +export MODEL_NAME= # e.g., deepseek3-671b +export TOKENIZER_PATH= # e.g., deepseek-ai/DeepSeek-V3 + +# -- Dataset configuration -- +export DATASET_NAME= # e.g., HuggingFaceH4/ultrachat_200k +export TRAIN_SPLIT= # e.g., train_sft +export TRAIN_DATA_COLUMNS= # e.g., ['messages'] +``` + +## 5. Get MaxText model checkpoint +This section explains how to prepare your model checkpoint for use with MaxText. You have two options: using an existing MaxText checkpoint or converting a Hugging Face checkpoint. + +### Option 1: Using an existing MaxText checkpoint +If you already have a MaxText-compatible model checkpoint, simply set the following environment variable and move on to the next section. + +```bash +export MODEL_CHECKPOINT_PATH= # e.g., gs://my-bucket/my-model-checkpoint/0/items ``` -## 5. Submit workload on GKE cluster -This section provides the command to run SFT with Deepseek-v3 model on a v6e-256 GKE cluster. +### Option 2: Converting a Hugging Face checkpoint +If your model checkpoint is from Hugging Face, you need to run a conversion script to make it MaxText-compatible. + +1. **Set the Output Path:** First, define where the new MaxText checkpoint will be saved. + +```bash +export MODEL_CHECKPOINT_PATH=${OUTPUT_PATH}/${WORKLOAD_NAME}/maxtext-checkpoint/0/items +``` + +2. **Run the Conversion Script:** Execute the following command that downloads the specified Hugging Face model and converts its weights into the MaxText format. The conversion script only supports official versions of models from Hugging Face. To see the specific models and versions currently supported for conversion, please refer to the `HF_IDS` dictionary in the MaxText utility file [here](https://github.com/AI-Hypercomputer/maxtext/blob/main/src/MaxText/utils/ckpt_conversion/utils/utils.py). + +```bash +xpk workload create \ +--cluster=${CLUSTER_NAME} \ +--project=${PROJECT} \ +--zone=${ZONE} \ +--docker-image=${DOCKER_IMAGE} \ +--workload=ckpt-${WORKLOAD_NAME} \ +--tpu-type=${TPU_TYPE} \ +--num-slices=${TPU_SLICE} \ +--command "python3 -m MaxText.utils.ckpt_conversion.to_maxtext src/MaxText/configs/base.yml model_name=$MODEL_NAME hf_access_token=$HF_TOKEN base_output_directory=$OUTPUT_PATH/$WORKLOAD_NAME/maxtext-checkpoint scan_layers=True" +``` + +## 6. Submit workload on GKE cluster +This section provides the command to run SFT on a GKE cluster. + +### 6.1. SFT with Multi-Controller JAX (McJAX) ```bash xpk workload create \ --cluster=${CLUSTER_NAME} \ @@ -77,7 +126,9 @@ xpk workload create \ --workload=${WORKLOAD_NAME} \ --tpu-type=${TPU_TYPE} \ --num-slices=${TPU_SLICE} \ ---command "python3 -m MaxText.sft.sft_trainer src/MaxText/configs/sft.yml run_name=$WORKLOAD_NAME base_output_directory=$OUTPUT_PATH model_name=deepseek3-671b load_parameters_path=$MODEL_CHECKPOINT_PATH hf_access_token=$HF_TOKEN tokenizer_path=deepseek-ai/DeepSeek-V3 per_device_batch_size=1 steps=$STEPS profiler=xplane megablox=False sparse_matmul=False ici_expert_parallelism=16 ici_fsdp_parallelism=16 weight_dtype=bfloat16 dtype=bfloat16 remat_policy=full decoder_layer_input=offload sa_block_q=2048 sa_block_q_dkv=2048 sa_block_q_dq=2048 opt_type=sgd attention=flash capacity_factor=1.0 max_target_length=2048" +--command "python3 -m MaxText.sft.sft_trainer src/MaxText/configs/sft.yml run_name=$WORKLOAD_NAME base_output_directory=$OUTPUT_PATH model_name=$MODEL_NAME load_parameters_path=$MODEL_CHECKPOINT_PATH hf_access_token=$HF_TOKEN tokenizer_path=$TOKENIZER_PATH per_device_batch_size=1 steps=$STEPS profiler=xplane hf_path=$DATASET_NAME train_split=$TRAIN_SPLIT train_data_columns=$TRAIN_DATA_COLUMNS" ``` -Once the fine-tuning is completed, you can access your model checkpoint at `${OUTPUT_PATH}/${WORKLOAD_NAME}/checkpoints/${STEPS}/model_params`. +Once the fine-tuning is completed, you can access your model checkpoints at `$OUTPUT_PATH/$WORKLOAD_NAME/checkpoints`. +### 6.2. SFT with Pathways +Pathways support is coming soon.