# Run SFT on the Llama3.1-8B-Instruct model

This notebook demonstrates how to run Supervised Fine-Tuning (SFT) on Llama3.1-8B-Instruct with MaxText, using the HuggingFaceH4/ultrachat_200k dataset from Hugging Face and the Tunix integration for training.

## Dataset Overview
Source: https://huggingface.co/datasets/HuggingFaceH4/ultrachat_200k

**Dataset Information:**
- **Name**: HuggingFaceH4/ultrachat_200k
- **Type**: Supervised Fine-Tuning dataset
- **Size**: ~200k conversations
- **Format**: Multi-turn chat conversations between a user and an assistant
- **Splits**: train_sft, test_sft
- **Data columns**: ['messages']

**Dataset Structure:**
Each example contains a `messages` field, which is a list of messages. Each message has:
- `role`: 'user' or 'assistant'
- `content`: The actual message text

**Example data format:**
```json
{
  "messages": [
    {"role": "user", "content": "What is the capital of France?"},
    {"role": "assistant", "content": "The capital of France is Paris."}
  ]
}
```
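The structure above can be checked with a few lines of Python before training. The following is a minimal sketch: `validate_example` and `VALID_ROLES` are hypothetical helpers written for this illustration, not part of MaxText, Tunix, or the `datasets` library, and the sketch assumes the only roles are the two listed above.

```python
from typing import Any, Dict, List

# Roles described in the dataset structure above.
VALID_ROLES = {"user", "assistant"}


def validate_example(example: Dict[str, Any]) -> List[str]:
    """Return a list of problems found in one example (empty if none)."""
    messages = example.get("messages")
    if not isinstance(messages, list) or not messages:
        return ["'messages' must be a non-empty list"]

    problems = []
    for i, msg in enumerate(messages):
        if msg.get("role") not in VALID_ROLES:
            problems.append(f"message {i}: unexpected role {msg.get('role')!r}")
        if not isinstance(msg.get("content"), str):
            problems.append(f"message {i}: content must be a string")
    return problems


example = {
    "messages": [
        {"role": "user", "content": "What is the capital of France?"},
        {"role": "assistant", "content": "The capital of France is Paris."},
    ]
}
print(validate_example(example))  # → []
```

Running a check like this over a handful of rows is a cheap way to catch schema surprises before a long training job starts.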

## Key Features
- **MaxText Llama3.1-8B-Instruct model**
- **Tunix integration** for optimized training
- **UltraChat-200k dataset** from Hugging Face
- **Tokenizer**: meta-llama/Llama-3.1-8B-Instruct


## Prerequisites
- MaxText environment with all dependencies
- Tunix installation
- Hugging Face access token for downloading the dataset and the gated meta-llama/Llama-3.1-8B-Instruct tokenizer
- Sufficient compute resources (TPU/GPU)
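A quick way to see which prerequisites are missing is to ask Python whether each package can be found. This is a minimal sketch: `missing_modules` is a hypothetical helper, and the import name `tunix` is an assumption about how Tunix is installed, so adjust the list to match your environment.

```python
import importlib.util
from typing import Iterable, List


def missing_modules(names: Iterable[str]) -> List[str]:
    """Return the top-level module names that cannot be found on the Python path."""
    return [name for name in names if importlib.util.find_spec(name) is None]


# Assumed import names; "tunix" in particular is a guess at the package name.
REQUIRED = ["jax", "MaxText", "tunix", "huggingface_hub"]

missing = missing_modules(REQUIRED)
print("Missing modules:", missing or "none")
```

`find_spec` only locates a module without importing it, so the check is fast and has no side effects on accelerator initialization.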


In [None]:
### (Optional) Run this cell only if you have this notebook file and not the rest of the repository

# 1. Clone the MaxText repository (from AI-Hypercomputer)
!git clone https://github.com/AI-Hypercomputer/maxtext.git

# 2. Navigate into the cloned directory
%cd maxtext

In [None]:
### (Optional) Skip this cell if the dependencies are already installed

# 3. Ensure setup.sh is executable
!chmod +x setup.sh

# 4. Execute the setup script
!./setup.sh

# Pin the NumPy version
!pip install --force-reinstall numpy==2.1.2
# Install nest_asyncio
!pip install nest_asyncio

# Allow nested event loops; this fixes the "This event loop is already running" error in Colab
import nest_asyncio
nest_asyncio.apply()


In [None]:

import os
import sys
# Set the MaxText home directory. Change this to the directory where MaxText is cloned.
MAXTEXT_HOME = os.path.join(os.path.expanduser("~"), "maxtext")
print(f"MaxText home directory: {MAXTEXT_HOME}")
# Set the path to the Llama3.1-8B-Instruct checkpoint to load; gs://<bucket> paths are supported.
MODEL_CHECKPOINT_PATH = "path/to/scanned/checkpoint"

In [None]:
from pathlib import Path
from typing import Optional, Dict, Any

# Find the MaxText directory and change the working directory to it
current_dir = Path.cwd()
if current_dir.name == 'examples':
    # We're in the examples folder, so go up two levels
    maxtext_path = current_dir.parent.parent
else:
    # Otherwise, use the src/MaxText subfolder of MAXTEXT_HOME
    maxtext_path = Path(MAXTEXT_HOME) / 'src' / 'MaxText'

# Change the working directory to the resolved MaxText path
os.chdir(maxtext_path)
sys.path.insert(0, str(maxtext_path))

print(f"✓ Changed working directory to: {os.getcwd()}")
print(f"✓ MaxText project root: {maxtext_path}")
print(f"✓ Added to Python path: {maxtext_path}")

# Initialize the JAX distributed runtime if it is not already initialized
import jax
if not jax.distributed.is_initialized():
    jax.distributed.initialize()
print(f"JAX version: {jax.__version__}")
print(f"JAX devices: {jax.devices()}")


## Hugging Face Authentication Setup

If you encounter 401 Unauthorized errors when loading the dataset or tokenizer, you need to authenticate with Hugging Face. Set your token in the next code cell.
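As an alternative to pasting the token into the notebook, it can be read from the `HF_TOKEN` environment variable. The sketch below assumes you exported `HF_TOKEN` before starting the notebook kernel; `resolve_hf_token` is a hypothetical helper, not a `huggingface_hub` function.

```python
import os
from typing import Mapping, Optional

PLACEHOLDER = "hf_your_token_here"


def resolve_hf_token(env: Mapping[str, str] = os.environ) -> Optional[str]:
    """Return the token from HF_TOKEN, or None if it is unset or still the placeholder."""
    token = env.get("HF_TOKEN", "").strip()
    if not token or token == PLACEHOLDER:
        return None
    return token


token = resolve_hf_token()
if token is None:
    print("HF_TOKEN is not set; export it before starting the notebook.")
else:
    print("Found a Hugging Face token in the environment.")
```

The returned value can then be passed to `login(token=...)`, which keeps the secret out of the saved notebook.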


In [None]:
# Hugging Face Authentication Setup
from huggingface_hub import login

# Set your Hugging Face token here (avoid committing a real token to version control)
HF_TOKEN = "hf_your_token_here"  # Replace with your actual token
login(token=HF_TOKEN)

In [None]:
# MaxText imports 
try:
    from MaxText import pyconfig
    from MaxText.sft.sft_trainer import train as sft_train

    MAXTEXT_AVAILABLE = True
    print("✓ MaxText imports successful")
except ImportError as e:
    print(f"⚠️ MaxText not available: {e}")
    MAXTEXT_AVAILABLE = False


## Configuration Setup

## Notes
- By default, SFT trains on completions only (`sft_train_on_completion_only=True`).
- Set `sft_train_on_completion_only=False` to train on both prompts and completions.
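To make the effect of this flag concrete, the sketch below builds a per-token loss mask and averages the loss over the unmasked tokens. It is a conceptual illustration in plain Python, not MaxText's implementation: `build_loss_mask`, `masked_mean_loss`, and the toy token counts and loss values are all made up for this example.

```python
from typing import List, Sequence, Tuple


def build_loss_mask(
    segments: Sequence[Tuple[str, int]],
    train_on_completion_only: bool = True,
) -> List[int]:
    """Build a 0/1 mask from (role, num_tokens) segments.

    With train_on_completion_only=True, only assistant tokens contribute to the loss.
    """
    mask: List[int] = []
    for role, num_tokens in segments:
        keep = 1 if (role == "assistant" or not train_on_completion_only) else 0
        mask.extend([keep] * num_tokens)
    return mask


def masked_mean_loss(per_token_loss: Sequence[float], mask: Sequence[int]) -> float:
    """Average the per-token loss over positions where the mask is 1."""
    count = sum(mask)
    if count == 0:
        return 0.0
    return sum(loss * m for loss, m in zip(per_token_loss, mask)) / count


# Toy example: a 3-token prompt followed by a 2-token completion.
segments = [("user", 3), ("assistant", 2)]
per_token_loss = [4.0, 4.0, 4.0, 1.0, 3.0]

mask = build_loss_mask(segments)
print(mask)                                    # → [0, 0, 0, 1, 1]
print(masked_mean_loss(per_token_loss, mask))  # → 2.0
```

With the flag turned off, every position is kept, so the prompt tokens' loss is averaged in as well and the gradient also pushes the model to reproduce prompts.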

In [None]:
# Configuration setup
if MAXTEXT_AVAILABLE:
    # Build the argument list for MaxText's config system
    config_argv = [
        "",  # Placeholder for the program name (argv[0])
        f"{MAXTEXT_HOME}/src/MaxText/configs/sft.yml",   # SFT config
        f"load_parameters_path={MODEL_CHECKPOINT_PATH}",
        "model_name=llama3.1-8b",
        "steps=100",
        "per_device_batch_size=1",
        "max_target_length=1024",
        "learning_rate=2.0e-5",
        "eval_steps=5",
        "weight_dtype=bfloat16",
        "dtype=bfloat16",
        "hf_path=HuggingFaceH4/ultrachat_200k",
        f"hf_access_token={HF_TOKEN}",
        "base_output_directory=/tmp/maxtext_output",
        "run_name=sft_llama3_demo",
        "tokenizer_path=meta-llama/Llama-3.1-8B-Instruct",
        "eval_interval=10",
        "profiler=xplane",
    ]
    
    # Initialize configuration using MaxText's pyconfig
    config = pyconfig.initialize(config_argv)
    
    print("✓ Configuration loaded:")
    print(f"  - Model: {config.model_name}")
    print(f"  - Dataset: {config.hf_path}")
    print(f"  - Steps: {config.steps}")
    print(f"  - Use SFT: {config.use_sft}")
    print(f"  - Learning Rate: {config.learning_rate}")
else:
    print("MaxText not available - cannot load configuration")


## Execute Actual Training

Run the training using the MaxText SFT trainer's `train()` function.
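Before launching, it can help to estimate how much data the configured run will touch. The sketch below is rough arithmetic under stated assumptions: the global batch is taken to be `per_device_batch_size` times the number of devices, there is no gradient accumulation, and the device count of 8 is hypothetical. `token_budget` is an illustrative helper, not a MaxText function.

```python
from typing import Tuple


def token_budget(
    steps: int,
    per_device_batch_size: int,
    num_devices: int,
    max_target_length: int,
) -> Tuple[int, int]:
    """Return (sequences seen, upper bound on token positions processed)."""
    sequences = steps * per_device_batch_size * num_devices
    return sequences, sequences * max_target_length


# Values from the configuration in this notebook; num_devices=8 is an assumption.
sequences, tokens = token_budget(
    steps=100,
    per_device_batch_size=1,
    num_devices=8,
    max_target_length=1024,
)
print(sequences, tokens)  # → 800 819200
```

Under these assumptions the 100-step demo sees about 800 sequences, a small fraction of the ~200k conversations in the dataset, so it is a smoke test rather than a full fine-tuning pass.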


In [None]:
# Execute the training using the MaxText SFT trainer's train() function
if MAXTEXT_AVAILABLE:
    print("=" * 60)
    print("EXECUTING ACTUAL TRAINING")
    print("=" * 60)

    sft_train(config)

    print("\n✅ Training completed successfully!")
else:
    print("MaxText not available - cannot execute training")


## Summary

This notebook demonstrated the MaxText and Tunix integration for SFT: it set up the environment, configured Llama3.1-8B-Instruct with the UltraChat-200k dataset, and ran training with the MaxText SFT trainer.

The integration combines MaxText's high-performance LLM training with Tunix's training infrastructure, which makes it suitable for SFT on large datasets such as UltraChat-200k.
