[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/AI-Hypercomputer/maxtext/blob/main/src/MaxText/examples/sft_llama3_demo.ipynb)

# Llama3.1-8B-Instruct Supervised Fine-Tuning (SFT) Demo


## Overview

This notebook runs on **TPU v5e-8** or **v5p-8**.

This notebook demonstrates how to perform Supervised Fine-Tuning (SFT) on Llama3.1-8B-Instruct using the Hugging Face ultrachat_200k dataset with MaxText and Tunix integration for efficient training.

## Dataset Overview

**Dataset Link:** https://huggingface.co/datasets/HuggingFaceH4/ultrachat_200k

### Dataset Information:
- **Name:** HuggingFaceH4/ultrachat_200k
- **Type:** Supervised Fine-Tuning dataset
- **Size:** ~200k conversations
- **Format:** Chat conversations with human-AI pairs
- **Splits:** train_sft, test_sft
- **Data columns:** ['messages']

### Dataset Structure:
Each example contains a `messages` field: a list of messages, each with:
- **role:** 'user' or 'assistant'
- **content:** The actual message text

### Example data format:
```json
{
  "messages": [
    {"role": "user", "content": "What is the capital of France?"},
    {"role": "assistant", "content": "The capital of France is Paris."}
  ]
}
```
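As a minimal sketch of how this structure can be consumed, the snippet below checks that an example follows the format above and flattens it into a readable transcript. The helper names `validate_example` and `to_transcript` are hypothetical and are not part of MaxText or the dataset tooling.

```python
ALLOWED_ROLES = {"user", "assistant"}

example = {
    "messages": [
        {"role": "user", "content": "What is the capital of France?"},
        {"role": "assistant", "content": "The capital of France is Paris."},
    ]
}


def validate_example(ex):
    """Return True if every message has an allowed role and string content."""
    messages = ex.get("messages")
    if not isinstance(messages, list) or not messages:
        return False
    return all(
        m.get("role") in ALLOWED_ROLES and isinstance(m.get("content"), str)
        for m in messages
    )


def to_transcript(ex):
    """Flatten a conversation into one 'role: content' line per message."""
    return "\n".join(f"{m['role']}: {m['content']}" for m in ex["messages"])


if validate_example(example):
    print(to_transcript(example))
```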


## Prerequisites

### Change Runtime Type

**Instructions:**
1.  Navigate to the menu at the top of the screen.
2.  Click on **Runtime**.
3.  Select **Change runtime type** from the dropdown menu.
4.  Select **v5e-8** or **v5p-8 TPU** as the **Hardware accelerator**.
5.  Click on **Save**.


### Get Your Hugging Face Token

To access the model checkpoint from the Hugging Face Hub, you need to authenticate with a personal access token.

**Follow these steps to get your token:**

1.  **Navigate to the Access Tokens page** in your Hugging Face account settings. You can go there directly by visiting this URL:
    *   [https://huggingface.co/settings/tokens](https://huggingface.co/settings/tokens)

2.  **Create a new token** by clicking the **"+ Create new token"** button.

3.  **Give your token a name** and assign it a **`read` role**. The `read` role is sufficient for downloading models.

4.  **Copy the generated token**. You will need to paste it in the next step.

**Store your token:**

Paste your token between the quotes in the cell below.

In [None]:
HF_TOKEN=""
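If you would rather not paste a secret directly into the notebook, one option is to fall back to an environment variable. The sketch below assumes the token is stored in a variable named `HF_TOKEN`; the helper `resolve_hf_token` is hypothetical and only illustrates the lookup order.

```python
import os


def resolve_hf_token(explicit_token=""):
    """Prefer a token typed into the notebook, else read the HF_TOKEN environment variable."""
    return explicit_token or os.environ.get("HF_TOKEN", "")


# Keeps a pasted token if there is one, otherwise uses the environment.
HF_TOKEN = resolve_hf_token("")
```

An empty string is returned when neither source provides a token, so a later `if HF_TOKEN:` check still behaves as expected.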

Install MaxText and its dependencies by following the instructions at:

https://maxtext.readthedocs.io/en/latest/tutorials/sft.html#install-dependencies

## Set up the MaxText environment

In [None]:
# Path to the MaxText source folder. `%cd` persists across cells; `!cd` only changes directory in a throwaway subshell.
%cd ~/maxtext/src/

In [None]:
import os

import jax
from huggingface_hub import login  # used below for Hugging Face authentication

import MaxText
from MaxText import pyconfig
from MaxText.sft.sft_trainer import train as sft_train

# Directory of the installed MaxText package; config files are resolved relative to it
MAXTEXT_REPO_ROOT = os.path.dirname(MaxText.__file__)
print(f"MaxText installation path: {MAXTEXT_REPO_ROOT}")

In [None]:
if HF_TOKEN:
    login(token=HF_TOKEN)
    print("Authenticated with Hugging Face")
else:
    print("Authentication skipped: Hugging Face token not set")

In [None]:
# Set the model name, checkpoint path, and output directory
MODEL_NAME = "llama3.1-8b"
# Case 1: You already have a `Llama3.1-8B-Instruct` model checkpoint.
#         Set `MODEL_CHECKPOINT_PATH` to its location (local or gs://).
# Case 2: You do not have the checkpoint. Leave `MODEL_CHECKPOINT_PATH` unchanged and
#         this notebook downloads it from HF and stores it at `{MAXTEXT_REPO_ROOT}/llama_checkpoint`.
MODEL_CHECKPOINT_PATH = f"{MAXTEXT_REPO_ROOT}/llama_checkpoint"

# This is the directory where the fine-tuned model will be saved
# You can change it to any path you want (local or gs://)
BASE_OUTPUT_DIRECTORY = "/tmp/out/maxtext_llama3_8b"


In [None]:
# Convert the HF model to the MaxText format (runs only if the checkpoint path does not exist locally)
if not os.path.exists(MODEL_CHECKPOINT_PATH):
    !python3 -m MaxText.utils.ckpt_conversion.to_maxtext \
        $MAXTEXT_REPO_ROOT/configs/base.yml \
        model_name=$MODEL_NAME \
        base_output_directory=$MODEL_CHECKPOINT_PATH \
        hf_access_token=$HF_TOKEN \
        use_multimodal=false \
        scan_layers=false
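Note that `os.path.exists` only inspects the local filesystem, so a `gs://` checkpoint path is reported as missing and the conversion above would run again. A minimal sketch of a guard is shown below. The helpers `is_gcs_path` and `needs_conversion` are hypothetical, and the sketch assumes that any `gs://` path you supply already holds a converted checkpoint.

```python
import os


def is_gcs_path(path):
    """Return True for Cloud Storage URIs such as gs://bucket/checkpoint."""
    return path.startswith("gs://")


def needs_conversion(path):
    """Return True only for a local checkpoint path that does not exist yet."""
    if is_gcs_path(path):
        # Assumption: a user-supplied bucket path is already populated,
        # because os.path.exists cannot see Cloud Storage objects.
        return False
    return not os.path.exists(path)


print(needs_conversion("gs://my-bucket/llama_checkpoint"))  # hypothetical bucket name
```

With this guard, the conversion cell could test `needs_conversion(MODEL_CHECKPOINT_PATH)` instead of calling `os.path.exists` directly.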

In [None]:
# Initialize JAX's distributed runtime if it has not been initialized already
if not jax.distributed.is_initialized():
  jax.distributed.initialize()
print(f"JAX version: {jax.__version__}")
print(f"JAX devices: {jax.devices()}")

In [None]:
# MaxText imports
try:
  import MaxText
  from MaxText import pyconfig
  from MaxText.sft.sft_trainer import train as sft_train

  MAXTEXT_AVAILABLE = True
  print("✓ MaxText imports successful")
except ImportError as e:
  print(f"⚠️ MaxText not available: {e}")
  MAXTEXT_AVAILABLE = False

In [None]:
# Configuration for SFT of Llama3.1-8B on TPU
if MAXTEXT_AVAILABLE:
  config_argv = [
      "",
      f"{MAXTEXT_REPO_ROOT}/configs/sft.yml",  # base SFT config
      f"load_parameters_path={MODEL_CHECKPOINT_PATH}/0/items/",  # load pre-trained weights; replace with your checkpoint path if it differs
      f"model_name={MODEL_NAME}",
      "steps=100",  # adjust for your training needs
      "per_device_batch_size=1",  # minimal to avoid OOM
      "max_target_length=1024",
      "learning_rate=2.0e-5",  # safe small LR
      "eval_steps=5",
      "weight_dtype=bfloat16",
      "dtype=bfloat16",
      "hf_path=HuggingFaceH4/ultrachat_200k",  # HuggingFace dataset
      f"hf_access_token={HF_TOKEN}",
      f"base_output_directory={BASE_OUTPUT_DIRECTORY}",
      "run_name=sft_llama3_8b_test",
      "tokenizer_path=meta-llama/Llama-3.1-8B-Instruct",  # Llama tokenizer
      "eval_interval=10",
      "profiler=xplane",
  ]

  # Initialize configuration using MaxText's pyconfig
  config = pyconfig.initialize(config_argv)

  print("✓ Fixed configuration loaded:")
  print(f"  - Model: {config.model_name}")
  print(f"  - Dataset: {config.hf_path}")
  print(f"  - Steps: {config.steps}")
  print(f"  - Use SFT: {config.use_sft}")
  print(f"  - Learning Rate: {config.learning_rate}")
else:
  print("MaxText not available - cannot load configuration")
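The list above follows a simple pattern: a placeholder for the program name, the config file, then `key=value` overrides. As a sketch, the overrides can be kept in a dictionary and rendered into that list form. The helper `build_config_argv` is hypothetical and not part of MaxText; it only produces the list of strings.

```python
def build_config_argv(config_file, overrides):
    """Render a config file path plus overrides as an argv-style list of strings."""
    argv = ["", config_file]  # the first entry stands in for the program name
    argv.extend(f"{key}={value}" for key, value in overrides.items())
    return argv


overrides = {
    "steps": 100,
    "per_device_batch_size": 1,
    "max_target_length": 1024,
    "learning_rate": 2.0e-5,
}
argv = build_config_argv("configs/sft.yml", overrides)
print(argv)
```

Keeping the overrides in a dictionary makes it easier to change a single value, such as `steps`, without editing a long list by hand.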

In [None]:
# Run training with the MaxText SFT trainer's train() function
if MAXTEXT_AVAILABLE:
  print("=" * 60)
  print("EXECUTING ACTUAL TRAINING")
  print("=" * 60)

  trainer, mesh = sft_train(config)

  print("Training complete!")
  print("Model saved at: ", BASE_OUTPUT_DIRECTORY)
else:
  print("MaxText not available - skipping training")
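To put the run size in perspective, the sketch below estimates an upper bound on the tokens processed with the settings used in this notebook. It assumes 8 devices (check `jax.device_count()` on your runtime), no gradient accumulation, and every sequence filled to `max_target_length`. The function name `tokens_per_step` is hypothetical.

```python
def tokens_per_step(per_device_batch_size, num_devices, max_target_length):
    """Upper bound on tokens per training step, assuming fully filled sequences."""
    return per_device_batch_size * num_devices * max_target_length


# Values taken from the configuration cell; num_devices=8 is an assumption.
per_step = tokens_per_step(per_device_batch_size=1, num_devices=8, max_target_length=1024)
total = per_step * 100  # steps=100

print(f"Tokens per step (upper bound): {per_step}")
print(f"Tokens over 100 steps (upper bound): {total}")
```

At this scale the run is a quick demonstration rather than a full fine-tune; increase `steps` for real training needs.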