
Add SFT LLaMA3 GPU demo notebook#3285

Merged
copybara-service[bot] merged 1 commit into AI-Hypercomputer:main from katjasrz:sft_llama3_gpu
Mar 6, 2026

Conversation

@katjasrz katjasrz (Collaborator) commented Mar 2, 2026

Description

Supersedes #3146 (closed due to branch rename/history rewrite).

  • Adds src/MaxText/examples/sft_llama3_gpu.ipynb: a GPU-focused, end-to-end notebook for SFT of Llama 3.1-8B on NVIDIA GPUs (HF auth → gated-access note → HF-to-MaxText checkpoint conversion (CPU) → SFT run → TensorBoard → inference sanity check).
  • Motivation: complements existing notebook/docs that emphasize TPU SFT flows (e.g. sft_llama3_demo.ipynb) with a clear NVIDIA GPU path.
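
For reference, the HF auth step typically reduces to something like the sketch below (notebook_login is an assumption about the cell's contents; the committed notebook may differ). Llama 3.1-8B is a gated repo, so the account behind the token must have accepted Meta's license on the model page.

# Hedged sketch of the HF auth step (the notebook's actual cell may differ).
from huggingface_hub import notebook_login

notebook_login()  # interactive token prompt inside Jupyter; the token needs gated access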

Tests

Executed the notebook in the NGC container nvcr.io/nvidia/jax:26.01-maxtext-py3 on a cluster node with 8 NVIDIA H100 GPUs, CUDA 13.1, driver 580.105.08, JAX 0.8.1.dev20260217.

Verified:

  • Checkpoint conversion completes.
  • SFT runs for 100 steps and writes checkpoints/logs.
  • TensorBoard logs are created.
  • Inference sanity check runs and produces output.
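
A quick spot-check for the last three items (BASE_OUTPUT_DIR is an assumption; substitute the base output directory configured for the run):

# Hedged spot-check: count TensorBoard event files and checkpoint items
# under the run's output directory (the path below is hypothetical).
import glob
import os

BASE_OUTPUT_DIR = '/workspace/output'  # hypothetical; match the run config
events = glob.glob(os.path.join(BASE_OUTPUT_DIR, '**', 'events.out.tfevents.*'), recursive=True)
ckpts = glob.glob(os.path.join(BASE_OUTPUT_DIR, '**', 'checkpoints', '*'), recursive=True)
print(f'{len(events)} TensorBoard event file(s), {len(ckpts)} checkpoint item(s)')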

To reproduce, execute as follows:

cd <YOUR_ROOT_DIRECTORY>

# clone this repo
git clone https://github.com/katjasrz/maxtext.git

export ROOT_DIR=$(pwd)
export PROJECT_DIR=$ROOT_DIR/maxtext/src/maxtext/examples
export HF_CACHE_DIR=$ROOT_DIR/huggingface

docker run -it --rm --ipc=host \
  --gpus=all \
  -p 8889:8889 \
  -p 6006:6006 \
  --shm-size=16g \
  --ulimit memlock=-1 \
  -v "$PROJECT_DIR":/workspace \
  -v "$HF_CACHE_DIR":/hf_cache \
  -e HF_HOME=/hf_cache \
  -e LOCAL_UID=$(id -u) \
  -e LOCAL_GID=$(id -g) \
  nvcr.io/nvidia/jax:26.01-maxtext-py3 \
  bash -lc 'set -e
    groupadd -g $LOCAL_GID hostgrp 2>/dev/null || true
    useradd -u $LOCAL_UID -g $LOCAL_GID -M -d /workspace hostusr 2>/dev/null || true
    
    python3 -m pip install --upgrade pip
    pip install jupyterlab ipywidgets
    pip install -U git+https://github.com/google/tunix
    pip install torch --index-url https://download.pytorch.org/whl/cpu
    
    su hostusr -c "cd /workspace && HOME=/workspace HF_HOME=/hf_cache \
      jupyter lab --ip=0.0.0.0 --port=8889 --no-browser"'

Then follow the instructions in the Jupyter notebook src/maxtext/examples/sft_llama3_gpu.ipynb.

Two container-specific fixes, needed because the container ships an older maxtext version, are excluded from the notebook itself:

Fix 1. The code below includes a workaround for a known container issue where create_nnx_model defaults model_mode to None instead of "train". This is patched at runtime.

# Fix for container bug: model_creation_utils.create_nnx_model defaults model_mode=None
# but it should default to "train". Set the correct default.
from MaxText import model_creation_utils
model_creation_utils.create_nnx_model.__defaults__ = (None, None, "train", None)
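
One way to confirm the patch took effect (a sanity check; it assumes create_nnx_model keeps four trailing defaulted parameters, since __defaults__ maps onto the last defaulted parameters in order):

# Sanity check (assumption: four trailing defaulted parameters); the printed
# signature should now show model_mode="train".
import inspect

print(inspect.signature(model_creation_utils.create_nnx_model))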

Fix 2. The code below is a workaround for a known container issue where empty-string defaults for hf_train_files/hf_eval_files/hf_data_dir cause datasets.load_dataset to fail. These are patched to None at runtime.

# Fix for container bug: empty string defaults for hf_train_files/hf_eval_files/hf_data_dir
# cause datasets.load_dataset to fail. Monkey-patch to convert empty strings to None.
# Guard against multiple applications to avoid recursion.
import datasets
if not hasattr(datasets, '_original_load_dataset'):
    datasets._original_load_dataset = datasets.load_dataset

    def _patched_load_dataset(*args, **kwargs):
        for key in ['data_files', 'data_dir']:
            if key in kwargs and kwargs[key] == '':
                kwargs[key] = None
        return datasets._original_load_dataset(*args, **kwargs)

    datasets.load_dataset = _patched_load_dataset
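
A quick, purely illustrative self-test of the wrapper: temporarily stub the saved original so nothing is downloaded, then confirm that empty-string kwargs come through as None.

# Illustrative self-test: stub the saved original so no dataset is fetched.
_saved = datasets._original_load_dataset
datasets._original_load_dataset = lambda *args, **kwargs: kwargs
try:
    out = datasets.load_dataset('dummy', data_files='', data_dir='')
    assert out == {'data_files': None, 'data_dir': None}
finally:
    datasets._original_load_dataset = _saved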

Checklist

Before submitting this PR, please make sure (put X in square brackets):

  • I have performed a self-review of my code. For an optional AI review, add the gemini-review label.
  • I have necessary comments in my code, particularly in hard-to-understand areas.
  • I have run end-to-end tests and provided workload links above if applicable.
  • I will make corresponding changes to the doc if needed, including adding new documentation pages to the relevant Table of Contents (toctree directive) as explained in our documentation.

@A9isha A9isha (Collaborator) left a comment


Thank you so much!

There are a few failing tests which would need to be fixed, of course.

@codecov codecov (Bot) commented Mar 5, 2026

Codecov Report

✅ All modified and coverable lines are covered by tests.


@A9isha A9isha (Collaborator) commented Mar 5, 2026

Ignoring the Jupyter notebook test failure w.r.t. getpass().

Gemini's explanation:

Because the PR is likely submitted from a fork, GitHub Actions automatically restricts access to repository secrets to prevent malicious code from exfiltrating them. As a result, the HF_TOKEN environment variable is empty. The code then falls back to getpass() to ask for interactive input. However, the CI job runs the notebook headlessly using papermill (source), which has no interactive input stream, so it crashes.
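
The failure mode corresponds to a fallback pattern roughly like this (a sketch of what the notebook presumably does, not its exact code):

# Sketch of the fallback described above (not the notebook's exact code).
import os
from getpass import getpass

hf_token = os.environ.get('HF_TOKEN')
if not hf_token:
    # papermill runs the notebook without an interactive stdin, so this
    # call raises and the headless CI job crashes here.
    hf_token = getpass('Enter your Hugging Face token: ')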

@igorts-git igorts-git mentioned this pull request Mar 5, 2026
@copybara-service copybara-service Bot merged commit 7656eb8 into AI-Hypercomputer:main Mar 6, 2026
51 of 61 checks passed