[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/AI-Hypercomputer/maxtext/blob/main/src/maxtext/examples/sft_llama3_demo.ipynb)

# Llama3.1-8B-Instruct Supervised Fine-Tuning (SFT) Demo


## Overview

This notebook demonstrates how to run Supervised Fine-Tuning (SFT) on Llama3.1-8B-Instruct with the Hugging Face `ultrachat_200k` dataset, using MaxText and its Tunix integration for efficient training.
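In SFT, each training example is a conversation, and a common practice is to compute the loss only on the assistant's replies. The sketch below illustrates that idea on a toy record shaped like an `ultrachat_200k` example (a `messages` list of `role`/`content` dicts). It uses whitespace splitting as a stand-in tokenizer and a hypothetical `build_sft_example` helper. It is a conceptual illustration only, not MaxText's or Tunix's implementation, which uses the real Llama 3.1 tokenizer and chat template.

```python
from typing import Dict, List, Tuple


# Hypothetical helper: whitespace "tokens" stand in for the real Llama 3.1 tokenizer.
def build_sft_example(messages: List[Dict[str, str]]) -> Tuple[List[str], List[int]]:
    """Flatten a chat into tokens plus a loss mask (1 = token contributes to the loss)."""
    tokens: List[str] = []
    loss_mask: List[int] = []
    for message in messages:
        # A role marker separates turns; in this sketch it is never a prediction target.
        header = [f"<{message['role']}>"]
        body = message["content"].split()
        tokens += header + body
        # Only assistant content is trained on; user/system text is context only.
        is_target = 1 if message["role"] == "assistant" else 0
        loss_mask += [0] * len(header) + [is_target] * len(body)
    return tokens, loss_mask


# Toy record (made up for illustration) in the messages/role/content shape.
example = {
    "messages": [
        {"role": "user", "content": "What is JAX?"},
        {"role": "assistant", "content": "A numerical computing library."},
    ]
}
tokens, mask = build_sft_example(example["messages"])
print(tokens)
print(mask)  # [0, 0, 0, 0, 0, 1, 1, 1, 1]
```

Masking the prompt keeps the model from being trained to reproduce user turns, so the gradient signal comes only from the responses it should learn to generate.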

This notebook can run on **TPU v6e-8** or **v5p-8**.

## Prerequisites

### Change Runtime Type (only if running on Google Colab)

**Instructions:**
1.  Navigate to the menu at the top of the screen.
2.  Click on **Runtime**.
3.  Select **Change runtime type** from the dropdown menu.
4.  Select **v6e-8 TPU** or **v5p-8 TPU** as the **Hardware accelerator**.
5.  Click on **Save**.

### Get Your Hugging Face Token

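To download the model checkpoint from the Hugging Face Hub, you need to authenticate with a personal access token. Llama 3.1 is a gated model, so your account must also have been granted access on the [meta-llama/Llama-3.1-8B-Instruct](https://huggingface.co/meta-llama/Llama-3.1-8B-Instruct) model page.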

**Follow these steps to get your token:**

1.  **Navigate to the Access Tokens page** in your Hugging Face account settings. You can go there directly by visiting this URL:
    *   [https://huggingface.co/settings/tokens](https://huggingface.co/settings/tokens)

2.  **Create a new token** by clicking the **"+ Create new token"** button.

3.  **Give your token a name** and assign it a **`read` role**. The `read` role is sufficient for downloading models.

4.  **Copy the generated token**. You will need it in later steps.

**Follow these steps to store your token (only if running on Google Colab):**

1. On the left sidebar of your Colab window, click the key icon (the Secrets tab).

2. Click **"+ Add new secret"**.

3. Set the Name as **HF_TOKEN**.

4. Paste your token into the Value field.

5. Ensure the **Notebook access** toggle is turned on.

In [None]:
try:
    from google.colab import userdata
    print("Running the notebook on Google Colab")
    IN_COLAB = True
except ImportError:
    print("Running the notebook outside Google Colab (e.g., VS Code or JupyterLab)")
    IN_COLAB = False

### Installation: MaxText & Other Dependencies

**⚠️ Note:** The installation in the following cell may take a few minutes to complete.

In [None]:
if IN_COLAB:
    !git clone https://github.com/AI-Hypercomputer/maxtext.git
    %cd /content/maxtext

    # Install uv, a fast Python package installer
    !pip install uv

    # Install MaxText and its dependencies
    !uv pip install -e .[tpu] --resolution=lowest
    !python3 -m MaxText.install_maxtext_extra_deps

### Restart Session (only if running on Google Colab)
Restart the session so that the newly installed packages are picked up.

**Instructions:**
1.  Navigate to the menu at the top of the screen.
2.  Click on **Runtime**.
3.  Select **Restart session** from the dropdown menu.

You will be asked to confirm the action in a pop-up dialog. Click on **Yes**.

## Environment Setup

In [None]:
import datetime
import os
from maxtext.configs import pyconfig
from maxtext.utils.globals import MAXTEXT_PKG_DIR
from maxtext.trainers.post_train.sft import train_sft
import jax
from huggingface_hub import login


print(f"MaxText installation path: {MAXTEXT_PKG_DIR}")

In [None]:
if not jax.distributed.is_initialized():
  jax.distributed.initialize()
print(f"JAX version: {jax.__version__}")
print(f"JAX devices: {jax.devices()}")

In [None]:
if IN_COLAB:
    HF_TOKEN = userdata.get("HF_TOKEN")
else:
    HF_TOKEN = os.environ.get("HF_TOKEN", "")

# If not found in the environment, prompt the user for input securely
# getpass function ensures the token is hidden while you type
if not HF_TOKEN:
    from getpass import getpass
    HF_TOKEN = getpass("Hugging Face token not found in environment. Please enter it here: ")

if HF_TOKEN:
    login(token=HF_TOKEN)
    print("Authenticated with Hugging Face successfully!")
else:
    print("Authentication failed: Hugging Face token is not set.")

## Model Configurations

In [None]:
MODEL_NAME = "llama3.1-8b"
TOKENIZER_PATH = "meta-llama/Llama-3.1-8B-Instruct"

# Set this to an existing converted MaxText checkpoint directory, or leave it empty
# to download the Hugging Face checkpoint and convert it to MaxText format.
MODEL_CHECKPOINT_PATH = ""
if not MODEL_CHECKPOINT_PATH:
    MODEL_CHECKPOINT_PATH = f"{MAXTEXT_PKG_DIR}/llama_checkpoint"
    print("Model checkpoint will be downloaded from Hugging Face and converted at:", MODEL_CHECKPOINT_PATH)
    print("Set MODEL_CHECKPOINT_PATH if you do not wish to download the checkpoint.")

BASE_OUTPUT_DIRECTORY = f"{MAXTEXT_PKG_DIR}/sft_llama3_output"
RUN_NAME = datetime.datetime.now().strftime("%Y-%m-%d-%H-%M-%S")

## Download and Convert the Llama3.1-8B Checkpoint from Hugging Face

In [None]:
if not os.path.exists(MODEL_CHECKPOINT_PATH):
    import subprocess
    import sys

    # Install torch for the conversion script
    print("Installing torch...")
    subprocess.run(
        [
            sys.executable, "-m", "pip", "install",
            "torch", "--index-url", "https://download.pytorch.org/whl/cpu"
        ],
        check=True
    )

    # Run the conversion on CPU (JAX_PLATFORMS=cpu) so the subprocess does not
    # try to claim the TPU, which this notebook's JAX process already holds.
    print("Converting checkpoint from Hugging Face...")
    env = os.environ.copy()
    env["JAX_PLATFORMS"] = "cpu"

    subprocess.run(
        [
            sys.executable,
            "-m", "maxtext.checkpoint_conversion.to_maxtext",
            f"{MAXTEXT_PKG_DIR}/configs/base.yml",
            f"model_name={MODEL_NAME}",
            f"base_output_directory={MODEL_CHECKPOINT_PATH}",
            f"hf_access_token={HF_TOKEN}",
            "use_multimodal=false",
            "scan_layers=true",
            "skip_jax_distributed_system=True",
        ],
        check=True,
        env=env
    )

if not os.path.exists(MODEL_CHECKPOINT_PATH):
    raise ValueError("Model checkpoint conversion failed. Check the logs above.")
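The training config loads parameters from `{MODEL_CHECKPOINT_PATH}/0/items`, so the converted checkpoint is expected to contain a step `0` directory with an `items` subdirectory. The check above only confirms that the top-level directory exists. A stricter check, sketched here with a hypothetical `resolve_load_parameters_path` helper, fails early if that subdirectory is missing, for example when `MODEL_CHECKPOINT_PATH` points at a raw Hugging Face download rather than a converted MaxText checkpoint.

```python
import os
import tempfile


def resolve_load_parameters_path(checkpoint_root: str, step: int = 0) -> str:
    """Return '<root>/<step>/items' if it exists, else raise with a helpful message.

    Hypothetical helper: assumes the '<step>/items' layout implied by the
    notebook's load_parameters_path setting.
    """
    items_dir = os.path.join(checkpoint_root, str(step), "items")
    if not os.path.isdir(items_dir):
        raise FileNotFoundError(
            f"Expected a converted MaxText checkpoint at {items_dir}. "
            "Re-run the conversion cell or fix MODEL_CHECKPOINT_PATH."
        )
    return items_dir


# Usage demo on a throwaway directory that mimics the expected layout.
with tempfile.TemporaryDirectory() as root:
    os.makedirs(os.path.join(root, "0", "items"))
    resolved = resolve_load_parameters_path(root)
    found_expected = resolved == os.path.join(root, "0", "items")
    try:
        resolve_load_parameters_path(root, step=100)  # no such step directory
        missing_detected = False
    except FileNotFoundError:
        missing_detected = True

print(found_expected, missing_detected)  # True True
```

In the notebook you could call `resolve_load_parameters_path(MODEL_CHECKPOINT_PATH)` right after conversion and pass the result as `load_parameters_path`, so a bad path surfaces before the (slow) training setup starts.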

## MaxText Configurations

In [None]:
# Load configuration for SFT training
config_argv = [
    "",  # placeholder for argv[0]; pyconfig treats argv[1] as the config file
    f"{MAXTEXT_PKG_DIR}/configs/post_train/sft.yml",
    f"load_parameters_path={MODEL_CHECKPOINT_PATH}/0/items",
    f"model_name={MODEL_NAME}",
    "steps=100",
    "per_device_batch_size=1",
    "max_target_length=1024",
    "learning_rate=2.0e-5",
    "weight_dtype=bfloat16",
    "dtype=bfloat16",
    "hf_path=HuggingFaceH4/ultrachat_200k",
    f"hf_access_token={HF_TOKEN}",
    f"base_output_directory={BASE_OUTPUT_DIRECTORY}",
    f"run_name={RUN_NAME}",
    f"tokenizer_path={TOKENIZER_PATH}",
    "profiler=xplane",
]

config = pyconfig.initialize(config_argv)

print("✓ SFT configuration loaded:")
print(f" Model: {config.model_name}")
print(f" Training Steps: {config.steps}")
print(f" Output Directory: {config.base_output_directory}")
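To get a feel for how much data this run processes, the sketch below estimates the global batch size and token budget from the settings above. It assumes the global batch is `per_device_batch_size` times the number of JAX devices (no gradient accumulation), and it treats `max_target_length` as an upper bound on tokens per sequence, since shorter conversations are padded or packed. Check `len(jax.devices())` on your runtime: a v6e-8 exposes 8 devices, while a v5p-8 typically exposes 4 (its 8 TensorCores form 4 megacore chips). The `token_budget` helper is hypothetical.

```python
def token_budget(per_device_batch_size: int, num_devices: int,
                 max_target_length: int, steps: int) -> dict:
    """Upper-bound data volume for a run, assuming no gradient accumulation."""
    global_batch = per_device_batch_size * num_devices
    tokens_per_step = global_batch * max_target_length  # upper bound per step
    return {
        "global_batch": global_batch,
        "tokens_per_step": tokens_per_step,
        "max_total_tokens": tokens_per_step * steps,
    }


# Values from config_argv: steps=100, per_device_batch_size=1, max_target_length=1024.
v6e = token_budget(per_device_batch_size=1, num_devices=8, max_target_length=1024, steps=100)
v5p = token_budget(per_device_batch_size=1, num_devices=4, max_target_length=1024, steps=100)
print(v6e)  # {'global_batch': 8, 'tokens_per_step': 8192, 'max_total_tokens': 819200}
print(v5p)  # {'global_batch': 4, 'tokens_per_step': 4096, 'max_total_tokens': 409600}
```

With `steps=100`, this demo touches only a tiny slice of `ultrachat_200k`; raise `steps` or `per_device_batch_size` (memory permitting) for a more meaningful fine-tune.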

## SFT Training

In [None]:
print("=" * 60)
print("🚀 Starting SFT Training...")
print("=" * 60)

try:
    trainer, mesh = train_sft.train(config)

    print("\n" + "=" * 60)
    print("✅ Training Completed Successfully!")
    print("=" * 60)
    print(f"📁 Checkpoints saved to: {config.checkpoint_dir}")
except Exception as e:
    print("\n" + "=" * 60)
    print("❌ Training Failed!")
    print("=" * 60)
    print(f"Error: {e}")
    raise  # re-raise so the full traceback is shown in the cell output
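After training, you may want to locate the most recent checkpoint under `config.checkpoint_dir`. The sketch below assumes Orbax-style numbered step subdirectories (the same convention behind the `0/items` path used to load the converted model) and a local filesystem path. The `latest_checkpoint_step` helper is hypothetical, and which steps actually get saved depends on the checkpointing settings in the config.

```python
import os
import tempfile
from typing import Optional


def latest_checkpoint_step(checkpoint_dir: str) -> Optional[int]:
    """Return the highest numeric step subdirectory, or None if there is none."""
    if not os.path.isdir(checkpoint_dir):
        return None
    steps = [
        int(name)
        for name in os.listdir(checkpoint_dir)
        if name.isdigit() and os.path.isdir(os.path.join(checkpoint_dir, name))
    ]
    return max(steps) if steps else None


# Usage demo on a throwaway directory with a few fake step folders.
with tempfile.TemporaryDirectory() as ckpt_dir:
    for step in ("0", "50", "99", "tmp"):  # "tmp" is ignored: not a step number
        os.makedirs(os.path.join(ckpt_dir, step))
    latest = latest_checkpoint_step(ckpt_dir)
    empty = latest_checkpoint_step(os.path.join(ckpt_dir, "missing"))

print(latest, empty)  # 99 None
```

Comparing steps as integers rather than strings matters once step numbers reach different digit counts (as strings, `"99"` sorts after `"100"`).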

## 📚 Learn More

- **CLI Usage**: https://maxtext.readthedocs.io/en/latest/tutorials/posttraining/sft.html
- **Configuration**: See `src/maxtext/configs/post_train/sft.yml` for all available options
- **Implementation**: See `src/maxtext/trainers/post_train/sft/train_sft.py` for the `train` function used above