Description
Supersedes #3146 (closed due to branch rename/history rewrite).
Adds src/MaxText/examples/sft_llama3_gpu.ipynb: a GPU-focused, end-to-end notebook for SFT of Llama 3.1-8B on NVIDIA GPUs (HF auth and gated-access note → HF-to-MaxText checkpoint conversion (on CPU) → SFT run → TensorBoard → inference sanity check).
Motivation: complements the existing notebooks and docs that emphasize TPU SFT flows (e.g. sft_llama3_demo.ipynb) with a clear NVIDIA GPU path.
Tests
Executed the notebook in the NGC container nvcr.io/nvidia/jax:26.01-maxtext-py3 on a cluster node with 8 NVIDIA H100 GPUs, CUDA 13.1, driver 580.105.08, JAX 0.8.1.dev20260217.
Verified:
Checkpoint conversion completes.
SFT runs for 100 steps and writes checkpoints/logs.
To reproduce, launch the container and then follow the instructions in the Jupyter notebook src/MaxText/examples/sft_llama3_gpu.ipynb.
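A typical launch looks roughly like this (a sketch only; mounts, ports, and scheduler flags will differ per cluster):

```bash
# Start the NGC container with all GPUs visible; mount the working directory
# so notebook outputs (checkpoints, TensorBoard logs) persist on the host.
docker run --gpus all -it --rm \
  -v "$PWD":/workspace -w /workspace \
  nvcr.io/nvidia/jax:26.01-maxtext-py3
# Inside the container, start Jupyter and open the notebook, e.g.:
#   jupyter lab --ip 0.0.0.0 --allow-root
```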
Two container-specific fixes (needed because the container ships an older MaxText version) that are not included in the notebook:
Fix 1. The code below includes a workaround for a known container issue where create_nnx_model defaults model_mode to None instead of "train". This is patched at runtime.
```python
# Fix for container bug: model_creation_utils.create_nnx_model defaults model_mode=None,
# but it should default to "train". Set the correct default.
from MaxText import model_creation_utils

model_creation_utils.create_nnx_model.__defaults__ = (None, None, "train", None)
```
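An optional, purely illustrative sanity check to confirm the patched default took effect:

```python
import inspect

# After the patch, the printed signature should show model_mode="train".
print(inspect.signature(model_creation_utils.create_nnx_model))
```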
Fix 2. The code below is a workaround for a known container issue where empty-string defaults for hf_train_files/hf_eval_files/hf_data_dir cause datasets.load_dataset to fail. These are patched to None at runtime.
```python
# Fix for container bug: empty string defaults for hf_train_files/hf_eval_files/hf_data_dir
# cause datasets.load_dataset to fail. Monkey-patch to convert empty strings to None.
# Guard against multiple applications to avoid recursion.
import datasets

if not hasattr(datasets, '_original_load_dataset'):
    datasets._original_load_dataset = datasets.load_dataset

def _patched_load_dataset(*args, **kwargs):
    for key in ['data_files', 'data_dir']:
        if key in kwargs and kwargs[key] == '':
            kwargs[key] = None
    return datasets._original_load_dataset(*args, **kwargs)

datasets.load_dataset = _patched_load_dataset
```
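Since the unpatched function is stashed on the module, the workaround can be reverted (e.g. after moving to a container with a newer MaxText):

```python
# Restore the original loader; only needed if the workaround was applied.
datasets.load_dataset = datasets._original_load_dataset
```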
Checklist
Before submitting this PR, please make sure (put X in square brackets):
I have performed a self-review of my code. For an optional AI review, add the gemini-review label.
I have necessary comments in my code, particularly in hard-to-understand areas.
I have run end-to-end tests and provided workload links above if applicable.
I will make corresponding changes to the doc if needed, including adding new documentation pages to the relevant Table of Contents (toctree directive) as explained in our documentation.
Ignoring the Jupyter notebook CI test failure w.r.t. getpass.
Gemini's explanation:
Because the PR is likely submitted from a fork, GitHub Actions automatically restricts access to repository secrets to prevent malicious code from exfiltrating them. As a result, the HF_TOKEN environment variable is empty. The code then falls back to getpass() to ask for interactive input. However, the CI job runs the notebook headlessly using papermill (source), which has no interactive input stream, so it crashes.
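For context, the notebook's auth cell follows the common env-var-then-prompt idiom, roughly like the sketch below (variable and prompt names illustrative):

```python
import os
from getpass import getpass

# Prefer the HF_TOKEN secret/environment variable; prompt only when it is unset.
# Under papermill there is no stdin, so the getpass() branch raises in CI.
hf_token = os.environ.get("HF_TOKEN") or getpass("Hugging Face token: ")
```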