# Supervised Fine-Tuning of Llama 3.1-8B on NVIDIA GPUs with JAX and MaxText

This tutorial walks you through supervised fine-tuning (SFT) of Llama 3.1-8B on NVIDIA GPUs using JAX and MaxText. You'll learn how to take a pretrained Llama checkpoint, convert it into MaxText's native format, configure an SFT training run, and verify the result with a quick inference test.

**What you'll do:**
1. Set up the environment and authenticate with Hugging Face
2. Download and convert the Llama 3.1-8B checkpoint to MaxText format
3. Configure and launch supervised fine-tuning on the UltraChat 200k dataset
4. Visualize training metrics with TensorBoard
5. Run a quick inference sanity check

## Preliminaries

### Make sure you have supported hardware

**Hardware requirements.** Full-parameter SFT of Llama 3.1-8B is memory-intensive due to optimizer state, activations, and sharded model parameters. We recommend a system with **8 NVIDIA GPUs with at least 80 GB of memory each** (e.g., A100-80GB, H100-80GB, or H200). This allows the model, optimizer state, and activations to be cleanly sharded across devices without aggressive memory tuning.
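
As a rough illustration of why this much memory is recommended, the sketch below estimates the size of the training state (weights, gradients, and optimizer state). The parameter count and the bytes-per-parameter figures are assumptions (bfloat16 weights and gradients, an AdamW-style optimizer with two float32 moments); the actual layout depends on the MaxText configuration, and activations are not included.

```python
# Back-of-the-envelope estimate of training-state memory (illustrative assumptions).
NUM_PARAMS = 8.03e9      # approximate Llama 3.1-8B parameter count
NUM_GPUS = 8

# Assumed bytes per parameter; the real layout depends on optimizer and config.
BYTES_WEIGHTS = 2        # bfloat16 weights
BYTES_GRADS = 2          # bfloat16 gradients
BYTES_ADAM_STATE = 8     # two float32 moments (AdamW-style optimizer)


def training_state_gib(num_params: float) -> float:
    """Memory for weights + gradients + optimizer state in GiB (activations excluded)."""
    bytes_per_param = BYTES_WEIGHTS + BYTES_GRADS + BYTES_ADAM_STATE
    return num_params * bytes_per_param / 2**30


total_gib = training_state_gib(NUM_PARAMS)
per_gpu_gib = total_gib / NUM_GPUS
print(f"Total training state: {total_gib:.0f} GiB")
print(f"Per GPU when fully sharded across {NUM_GPUS} GPUs: {per_gpu_gib:.1f} GiB")
```

Under these assumptions the training state alone is roughly 90 GiB, which exceeds a single 80 GB GPU before any activations are stored. Sharding it across eight devices leaves most of each GPU's memory for activations and temporary buffers.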

When running `nvidia-smi`, you should see eight or more visible GPUs, each reporting at least 80 GB of total memory, with recent drivers, CUDA 12.x+ support, and minimal memory usage before training starts.

In [None]:
!nvidia-smi
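
The same check can be scripted instead of read off by eye. The sketch below parses the CSV output of `nvidia-smi --query-gpu=name,memory.total --format=csv,noheader,nounits` and compares it against the recommendation above. The helper names and the 76,000 MiB threshold are assumptions for illustration (`nvidia-smi` reports MiB, and 80 GB is about 76,294 MiB).

```python
import subprocess

MIN_GPUS = 8
MIN_MEMORY_MIB = 76_000  # assumed threshold: 80 GB is about 76,294 MiB


def parse_gpu_memory(csv_text: str) -> list[tuple[str, int]]:
    """Parse 'name, memory.total' CSV rows (noheader, nounits) into (name, MiB) pairs."""
    gpus = []
    for line in csv_text.strip().splitlines():
        name, memory = line.rsplit(",", 1)
        gpus.append((name.strip(), int(memory.strip())))
    return gpus


def meets_requirements(gpus: list[tuple[str, int]]) -> bool:
    """True if there are enough GPUs and each one has enough total memory."""
    return len(gpus) >= MIN_GPUS and all(mem >= MIN_MEMORY_MIB for _, mem in gpus)


def query_gpus() -> list[tuple[str, int]]:
    """Query the local GPUs; requires nvidia-smi on the PATH."""
    out = subprocess.run(
        ["nvidia-smi", "--query-gpu=name,memory.total", "--format=csv,noheader,nounits"],
        capture_output=True, text=True, check=True,
    ).stdout
    return parse_gpu_memory(out)


# Demonstration with canned output; call query_gpus() on a real machine instead.
sample = "\n".join(["NVIDIA H100 80GB HBM3, 81559"] * 8)
print(meets_requirements(parse_gpu_memory(sample)))
```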

### Get your Hugging Face token

To access the model checkpoint from the Hugging Face Hub, you need to authenticate with a personal access token.

**Follow these steps to get your token:**

1.  **Navigate to the Access Tokens page** in your Hugging Face account settings. You can go there directly by visiting this URL: [https://huggingface.co/settings/tokens](https://huggingface.co/settings/tokens)

2.  **Create a new token** by clicking the **"+ Create new token"** button.

3.  **Give your token a name** and assign it a **`read` role**. The `read` role is sufficient for downloading models.

4.  **Copy the generated token**. You will need this in the later steps.

**Follow these steps to store your token (only if running on Google Colab):**

1. On the left sidebar of your Colab window, click the key icon (the Secrets tab).

2. Click **"+ Add new secret"**.

3. Set the Name as **HF_TOKEN**.

4. Paste your token into the Value field.

5. Ensure the Notebook access toggle is turned On.
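
If you stored the token as a Colab secret, it can be copied into the environment so that later cells find it. The sketch below is one possible way to do this: it tries Colab's `google.colab.userdata` module and falls back to an existing `HF_TOKEN` environment variable elsewhere. The helper name `load_hf_token` is hypothetical.

```python
import os
from typing import Optional


def load_hf_token() -> Optional[str]:
    """Return the HF token from Colab secrets if available, else from the environment."""
    try:
        from google.colab import userdata  # only importable inside Colab
    except ImportError:
        return os.environ.get("HF_TOKEN")
    try:
        token = userdata.get("HF_TOKEN")
    except Exception:
        # Secret missing or notebook access not granted; fall back to the environment.
        token = None
    return token or os.environ.get("HF_TOKEN")


token = load_hf_token()
if token:
    os.environ["HF_TOKEN"] = token
# Report only whether a token was found; never print the token itself.
print("HF_TOKEN found" if token else "HF_TOKEN not found")
```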

In [None]:
import os

from ipywidgets import Password, Button, HBox, Output
from IPython.display import display

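# Import both entry points so the fallback below is always defined
from huggingface_hub import HfApi, whoami

def _verify_token(token: str) -> str:
    """Return the Hub username for this token (raises if the token is invalid)."""
    try:
        return whoami(token=token).get("name", "unknown")
    except TypeError:
        # Older huggingface_hub releases: fall back to the HfApi client
        return HfApi(token=token).whoami().get("name", "unknown")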

token_box = Password(description="HF Token:", placeholder="paste your token here", layout={"width": "400px"})
save_btn = Button(description="Save", button_style="success")
out = Output()

def save_token(_):
    out.clear_output()
    with out:
        existing = os.environ.get("HF_TOKEN")
        entered = token_box.value.strip()
        if existing and not entered:
            user = _verify_token(existing)
            print(f"Using existing HF_TOKEN. Logged in as: {user}")
            return
        if not entered:
            print("No token entered.")
            return
        os.environ["HF_TOKEN"] = entered
        user = _verify_token(entered)
        print(f"Token saved. Logged in as: {user}")

save_btn.on_click(save_token)
display(HBox([token_box, save_btn]), out)

### Authenticate with Hugging Face

Verify that your Hugging Face token is set and valid by calling the Hub's `whoami` endpoint. 

In [None]:
# Prefer environment variable if already set
HF_TOKEN = os.environ.get("HF_TOKEN")

if HF_TOKEN:
    try:
        user = whoami(token=HF_TOKEN)["name"]
        print(f"Authenticated with Hugging Face as: {user} (via HF_TOKEN env)")
    except Exception as e:
        print("HF_TOKEN is set but authentication failed:", e)
else:
    raise RuntimeError(
        "HF_TOKEN is not set. Please create a Hugging Face access token "
        "and export it as an environment variable."
    )

### Acquire permission to use the gated model

Llama 3.1-8B is a gated model, so you must explicitly request access before it can be downloaded. Visit the [model page](https://huggingface.co/meta-llama/Llama-3.1-8B) on Hugging Face, log in with the same account linked to your access token, and click **Request access**. You'll need to agree to Meta's license terms; approval is usually granted quickly but is not automatic. Once approved, your Hugging Face token will authorize downloads transparently. If you skip this step, model downloads will fail even with a valid token.

## Get the model and convert it into MaxText format

### Import dependencies

Import the core libraries needed for this tutorial:
- **JAX**: High-performance ML framework with automatic differentiation and XLA compilation
- **MaxText**: Google's production-grade training stack for JAX, providing model architectures, checkpoint management, and the SFT training loop

The easiest way to get a working environment is the [NVIDIA NGC JAX container](https://catalog.ngc.nvidia.com/orgs/nvidia/containers/jax), which ships with JAX, CUDA, and MaxText preinstalled. To install the dependencies manually:

```bash
pip install 'jax[cuda13]' maxtext
```

The model conversion step additionally requires **PyTorch**; the CPU-only build is sufficient:

```bash
pip install torch --index-url https://download.pytorch.org/whl/cpu
```
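
Before running the import cell, it can help to confirm that the packages used in this tutorial are importable. The following is a minimal sketch using only the standard library; the module list is collected from the imports in this notebook, and the helper name `missing_modules` is hypothetical.

```python
import importlib.util

# Top-level modules imported somewhere in this tutorial
REQUIRED_MODULES = ["jax", "MaxText", "transformers", "huggingface_hub", "ipywidgets", "torch"]


def missing_modules(names):
    """Return the subset of module names that cannot be found in this environment."""
    return [name for name in names if importlib.util.find_spec(name) is None]


missing = missing_modules(REQUIRED_MODULES)
if missing:
    print("Missing packages:", ", ".join(missing))
else:
    print("All required packages are importable.")
```

This only checks that the modules can be located; it does not verify versions or GPU support.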

In [None]:
# Imports
from datetime import datetime
from pathlib import Path
import sys
import subprocess
import logging

import transformers

import jax
import jax.numpy as jnp
import MaxText
from MaxText import pyconfig
from MaxText.sft.sft_trainer import train as sft_train, setup_trainer_state

MAXTEXT_REPO_ROOT = os.path.dirname(MaxText.__file__)

print(f"JAX version: {jax.__version__}")
print(f"JAX devices: {jax.devices()}")
print(f"Number of available devices: {jax.local_device_count()}")
print(f"MaxText installation path: {MAXTEXT_REPO_ROOT}")

### Define model paths and run configuration

This block defines the core paths and identifiers used throughout the tutorial: the model name, tokenizer source, checkpoint location, and output directory. You can override `MODEL_CHECKPOINT_PATH` via an environment variable to point to an existing converted checkpoint and skip the conversion step.

In [None]:
MODEL_NAME = "llama3.1-8b"
TOKENIZER_PATH = "meta-llama/Llama-3.1-8B-Instruct"

WORKSPACE_DIR = Path(os.environ.get("WORKSPACE_DIR", "/workspace"))

# If set, use it; otherwise default to /workspace/llama_checkpoint
MODEL_CHECKPOINT_PATH = os.environ.get("MODEL_CHECKPOINT_PATH")
MODEL_CHECKPOINT_PATH = Path(MODEL_CHECKPOINT_PATH) if MODEL_CHECKPOINT_PATH else (WORKSPACE_DIR / "llama_checkpoint")

print(f"Model checkpoint directory: {MODEL_CHECKPOINT_PATH}")
print("Tip: set MODEL_CHECKPOINT_PATH to a local directory to reuse an existing converted checkpoint.")

BASE_OUTPUT_DIRECTORY = Path(os.environ.get("BASE_OUTPUT_DIRECTORY", str(WORKSPACE_DIR / "sft_llama3_output")))

### Download and convert the Llama 3.1 checkpoint

This block downloads the pretrained Llama 3.1-8B weights from Hugging Face and converts them into MaxText's native checkpoint format. If a converted checkpoint already exists at the target path, this step is skipped entirely.

The conversion runs in a CPU-only JAX context (`JAX_PLATFORMS=cpu`) to avoid unnecessary GPU memory allocation. 

In [None]:
ckpt_dir = Path(MODEL_CHECKPOINT_PATH)

def run_ckpt_conversion(
    *,
    maxtext_repo_root: str,
    model_name: str,
    output_dir: Path,
    hf_token: str,
    quiet: bool = True,
) -> None:
    env = os.environ.copy()

    # Conversion should not touch GPUs
    env["JAX_PLATFORMS"] = "cpu"

    # Reduce verbosity (JAX/XLA/TensorFlow C++ logging)
    env.setdefault("TF_CPP_MIN_LOG_LEVEL", "2")  # 0=all, 1=INFO off, 2=INFO+WARNING off, 3=only FATAL

    cmd = [
        sys.executable, "-m", "MaxText.utils.ckpt_conversion.to_maxtext",
        f"{maxtext_repo_root}/configs/base.yml",
        f"model_name={model_name}",
        f"base_output_directory={str(output_dir)}",
        f"hf_access_token={hf_token}",
        "use_multimodal=false",
        "scan_layers=true",
        "skip_jax_distributed_system=True",
    ]

    output_dir.parent.mkdir(parents=True, exist_ok=True)

    if quiet:
        # Capture logs; show only if something goes wrong
        result = subprocess.run(
            cmd,
            env=env,
            stdout=subprocess.PIPE,
            stderr=subprocess.PIPE,
            text=True,
        )
        if result.returncode != 0:
            print("Checkpoint conversion failed. Logs:\n")
            if result.stdout:
                print("----- stdout -----")
                print(result.stdout)
            if result.stderr:
                print("----- stderr -----")
                print(result.stderr)
            raise RuntimeError("Checkpoint conversion failed. See logs above.")
    else:
        # Verbose mode (streams logs)
        subprocess.run(cmd, env=env, check=True)

    print(f"Checkpoint successfully converted to MaxText format at: {output_dir}")

if ckpt_dir.exists():
    print(f"Converted checkpoint already exists at: {ckpt_dir}")
else:
    print(f"Converting checkpoint to MaxText format → {ckpt_dir}")
    run_ckpt_conversion(
        maxtext_repo_root=MAXTEXT_REPO_ROOT,
        model_name=MODEL_NAME,
        output_dir=ckpt_dir,
        hf_token=HF_TOKEN,
        quiet=True, 
    )

if not ckpt_dir.exists():
    raise RuntimeError("Model checkpoint conversion failed. See logs above.")
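
The existence check above passes even if an earlier conversion was interrupted and left a partial directory behind. A stricter check is sketched below. It assumes the converted checkpoint places its parameters under `<checkpoint>/0/items`, which is the path the training configuration later in this tutorial loads from; the helper name `find_checkpoint_items` is hypothetical.

```python
import tempfile
from pathlib import Path
from typing import Optional


def find_checkpoint_items(ckpt_dir: Path, step: int = 0) -> Optional[Path]:
    """Return <ckpt_dir>/<step>/items if it exists and is non-empty, else None."""
    items = Path(ckpt_dir) / str(step) / "items"
    if items.is_dir() and any(items.iterdir()):
        return items
    return None


# Demonstration on a temporary directory that mimics the assumed layout.
with tempfile.TemporaryDirectory() as tmp:
    root = Path(tmp) / "llama_checkpoint"
    root.mkdir()
    incomplete = find_checkpoint_items(root)   # directory exists, but no items yet
    (root / "0" / "items").mkdir(parents=True)
    (root / "0" / "items" / "placeholder").write_text("")
    complete = find_checkpoint_items(root)
    print(incomplete is None, complete is not None)
```

If the function returns `None` for your checkpoint directory, deleting the directory and rerunning the conversion cell is a reasonable next step.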

## Provide the training configuration

This block builds the full MaxText SFT training configuration by loading the base `sft.yml` config and applying runtime overrides for the model, dataset, hyperparameters, and output paths. Each run is tagged with a timestamp-based name to keep outputs isolated across experiments. Key settings:

- **Dataset:** [UltraChat 200k](https://huggingface.co/datasets/HuggingFaceH4/ultrachat_200k), a large instruction-style conversational dataset commonly used for SFT of chat models.
- **Training:** 100 steps, learning rate 2e-5, sequence length 1024, bfloat16 precision.
- **Checkpoint source:** The converted MaxText checkpoint from the previous step.
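
To put these settings in perspective, the arithmetic below estimates how much data the run processes. It uses `per_device_batch_size=1` from the configuration cell and assumes the recommended 8-GPU system with no gradient accumulation. The result is an upper bound on trained tokens, since padded positions do not contribute to the loss.

```python
NUM_DEVICES = 8            # assumed: the 8-GPU system recommended in this tutorial
PER_DEVICE_BATCH_SIZE = 1  # per_device_batch_size in the configuration cell
MAX_TARGET_LENGTH = 1024   # max_target_length
STEPS = 100                # steps
GRAD_ACCUM_STEPS = 1       # assumed: no gradient accumulation

global_batch = NUM_DEVICES * PER_DEVICE_BATCH_SIZE * GRAD_ACCUM_STEPS
tokens_per_step = global_batch * MAX_TARGET_LENGTH
total_tokens = tokens_per_step * STEPS

print(f"Sequences per step: {global_batch}")
print(f"Token positions per step: {tokens_per_step:,}")
print(f"Token positions over {STEPS} steps: {total_tokens:,}")
```

At 8,192 token positions per step, the 100-step run covers at most 819,200 token positions, a small fraction of UltraChat 200k. This is enough to exercise the pipeline end to end, not to train a converged model.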

To use your own dataset, ensure it follows a compatible schema and is accessible via the Hugging Face Hub or a local path. MaxText handles dataset loading, sharding, and batching automatically.
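
As an illustration of what a compatible schema might look like, the sketch below validates records shaped like UltraChat 200k, which stores each conversation in a `messages` column as a list of `{"role", "content"}` dictionaries. Treating this as the required schema is an assumption, and the helper `validate_chat_record` is hypothetical; check the MaxText SFT documentation for the exact columns your version expects.

```python
ALLOWED_ROLES = {"system", "user", "assistant"}


def validate_chat_record(record: dict) -> list[str]:
    """Return a list of problems found in one chat-style record (empty if none)."""
    messages = record.get("messages")
    if not isinstance(messages, list) or not messages:
        return ["'messages' must be a non-empty list"]
    problems = []
    for i, message in enumerate(messages):
        if not isinstance(message, dict):
            problems.append(f"message {i} is not a dict")
            continue
        if message.get("role") not in ALLOWED_ROLES:
            problems.append(f"message {i} has unexpected role {message.get('role')!r}")
        content = message.get("content")
        if not isinstance(content, str) or not content.strip():
            problems.append(f"message {i} has empty or non-string content")
    if not any(isinstance(m, dict) and m.get("role") == "assistant" for m in messages):
        problems.append("no assistant turn to train on")
    return problems


example = {"messages": [
    {"role": "user", "content": "What is the capital of France?"},
    {"role": "assistant", "content": "The capital of France is Paris."},
]}
print(validate_chat_record(example))
```

Running a check like this over a sample of your data before training catches malformed records early, when they are cheap to fix.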

In [None]:
ckpt_items_path = Path(MODEL_CHECKPOINT_PATH) / "0" / "items"

if not os.environ.get("HF_TOKEN"):
    raise RuntimeError("HF_TOKEN is not set. Export it before loading the SFT config.")

RUN_NAME = datetime.now().strftime("%Y-%m-%d-%H-%M-%S")

# Load configuration for SFT training
config_argv = [
    "",
    f"{MAXTEXT_REPO_ROOT}/configs/sft.yml",
    f"load_parameters_path={ckpt_items_path}",
    f"model_name={MODEL_NAME}",
    "steps=100",
    "per_device_batch_size=1",
    "max_target_length=1024",
    "learning_rate=2.0e-5",
    "weight_dtype=bfloat16",
    "dtype=bfloat16",
    "hf_path=HuggingFaceH4/ultrachat_200k",
    f"hf_access_token={HF_TOKEN}",
    f"base_output_directory={BASE_OUTPUT_DIRECTORY}",
    f"run_name={RUN_NAME}",
    f"tokenizer_path={TOKENIZER_PATH}",
    "hardware=gpu",
]

# Suppress the verbose per-parameter config dump (hundreds of INFO lines)
_pyconfig_logger = logging.getLogger("MaxText.pyconfig")
_prev_level = _pyconfig_logger.level
_pyconfig_logger.setLevel(logging.WARNING)

config = pyconfig.initialize(config_argv)

_pyconfig_logger.setLevel(_prev_level)

print("SFT configuration loaded:")
print(f"  Model: {config.model_name}")
print(f"  Training Steps: {config.steps}")
print(f"  Max sequence length: {config.max_target_length}")
print(f"  Output Directory: {config.base_output_directory}")

## Run the SFT training

This section launches the SFT training loop. It runs MaxText's `sft_train` with the configuration defined above, reports progress, and saves checkpoints to the output directory. On completion, it prints the checkpoint and TensorBoard log paths.

In [None]:
os.environ.setdefault("LIBTPU_INIT_ARGS", "")

print("=" * 60)
print(f"Starting SFT training (run_name={RUN_NAME})")
print("=" * 60)

try:
    result = sft_train(config)

    print("\n" + "=" * 60)
    print("Training completed successfully")
    print("=" * 60)
    print(f"Checkpoints written to: {config.checkpoint_dir}")
    if hasattr(config, "tensorboard_dir"):
        print(f"TensorBoard logs: {config.tensorboard_dir}")

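    # Keep the trainer and mesh when sft_train returns them as a pair;
    # the inference test at the end of this tutorial reuses both.
    if isinstance(result, tuple) and len(result) == 2:
        trainer, mesh = result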
except Exception as e:
    print("\n" + "=" * 60)
    print("Training failed")
    print("=" * 60)
    print(f"Error details: {e}")
    raise

## Visualize training metrics with TensorBoard

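To monitor training loss and other metrics, launch TensorBoard in a separate terminal. The `--logdir` below is the default `BASE_OUTPUT_DIRECTORY`; adjust it if you overrode `WORKSPACE_DIR` or `BASE_OUTPUT_DIRECTORY`.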

```bash
export PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION=python
tensorboard --logdir=/workspace/sft_llama3_output --host 0.0.0.0 --port 6006 --load_fast=false
```

Then open [http://127.0.0.1:6006/](http://127.0.0.1:6006/) in your browser. If TensorBoard runs on a remote machine, use that machine's address or forward port 6006 instead.
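
If TensorBoard shows no data, a quick way to confirm that logs were written is to search the output directory for event files, which TensorBoard names with the prefix `events.out.tfevents.`. The sketch below is a minimal example; the helper `find_event_files` and the directory layout in the demonstration are hypothetical.

```python
import tempfile
from pathlib import Path


def find_event_files(log_root: Path) -> list[Path]:
    """Return TensorBoard event files under log_root, newest first."""
    files = [p for p in Path(log_root).rglob("events.out.tfevents.*") if p.is_file()]
    return sorted(files, key=lambda p: p.stat().st_mtime, reverse=True)


# Demonstration on a temporary directory; point this at your output directory instead.
with tempfile.TemporaryDirectory() as tmp:
    run_dir = Path(tmp) / "example-run" / "tensorboard"   # hypothetical layout
    run_dir.mkdir(parents=True)
    (run_dir / "events.out.tfevents.0.host").write_text("")
    found_names = [p.name for p in find_event_files(Path(tmp))]
    print(found_names)
```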

## Test inference

Run a quick sanity check to verify that the fine-tuned model produces coherent output. The code below formats a prompt with the Llama 3.1 chat template, tokenizes it, and then runs greedy autoregressive generation for up to 10 tokens, stopping early if the model produces an EOS token. A sensible completion indicates that the model loaded correctly and that training did not break its predictions.

Note: this is naive autoregressive generation without KV-caching, so each step recomputes attention over the full sequence. For production use, consider a dedicated serving framework with KV-cache support.
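
To make the cost of recomputation concrete, the sketch below runs the same greedy loop on a toy stand-in for the model and counts how many token positions are processed. Everything here is illustrative: `greedy_decode` and `toy_model` are hypothetical names, and the toy model simply prefers the previous token ID plus one.

```python
from typing import Callable, List, Tuple


def greedy_decode(
    next_logits: Callable[[List[int]], List[float]],
    prompt_ids: List[int],
    max_new_tokens: int,
    eos_token_id: int,
) -> Tuple[List[int], int]:
    """Greedy decoding without a KV cache; returns (generated ids, positions processed)."""
    tokens = list(prompt_ids)
    generated: List[int] = []
    positions_processed = 0
    for _ in range(max_new_tokens):
        logits = next_logits(tokens)          # the model sees the whole sequence again
        positions_processed += len(tokens)
        next_id = max(range(len(logits)), key=logits.__getitem__)  # argmax
        generated.append(next_id)
        if next_id == eos_token_id:
            break
        tokens.append(next_id)
    return generated, positions_processed


VOCAB_SIZE = 8


def toy_model(tokens: List[int]) -> List[float]:
    """Hypothetical stand-in for a model: prefers (last token + 1) mod VOCAB_SIZE."""
    target = (tokens[-1] + 1) % VOCAB_SIZE
    return [1.0 if i == target else 0.0 for i in range(VOCAB_SIZE)]


ids, cost = greedy_decode(toy_model, prompt_ids=[1, 2, 3], max_new_tokens=10, eos_token_id=7)
print(ids, cost)
```

Here four tokens are generated and 18 positions are processed (3 + 4 + 5 + 6). With a KV cache, the same run would process the 3 prompt positions once and then one position per later step, 6 in total. The gap grows quadratically with the number of generated tokens, which is why the naive loop is only suitable for a short sanity check.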

In [None]:
# Load tokenizer
tokenizer = transformers.AutoTokenizer.from_pretrained(TOKENIZER_PATH, token=HF_TOKEN)

# Get model from trainer
model = trainer.model

# Format prompt using Llama chat template
prompt = "What is the capital of France?"
messages = [{"role": "user", "content": prompt}]
text = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)

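# Tokenize. The chat template already prepends the BOS token, so disable
# add_special_tokens to avoid a duplicate BOS at the start of the sequence.
tokens = jnp.array(tokenizer(text, add_special_tokens=False)["input_ids"])[None, :]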

# Greedy autoregressive generation
max_new_tokens = 10
generated_ids = []
eos_token_id = tokenizer.eos_token_id

for _ in range(max_new_tokens):
    seq_len = tokens.shape[1]
    positions = jnp.arange(seq_len)[None, :]
    attention_mask = jnp.tril(jnp.ones((seq_len, seq_len), dtype=jnp.bool_))[None, :]

    with mesh:
        output = model(tokens, positions, None, attention_mask)
        logits = output[0] if isinstance(output, tuple) else output

    next_token_id = int(jnp.argmax(logits[0, -1]))
    generated_ids.append(next_token_id)

    if next_token_id == eos_token_id:
        break

    tokens = jnp.concatenate([tokens, jnp.array([[next_token_id]])], axis=1)

# Decode all generated tokens
generated_text = tokenizer.decode(generated_ids, skip_special_tokens=True)

print(f"Prompt: {prompt}")
print(f"Generated ({len(generated_ids)} tokens): '{generated_text}'")