Changes
- Upgraded JAX to version 0.10.0 for pre-training and 0.10.1 for post-training.
- New vLLM-Powered Evaluation Framework: Introduced an eval framework for running lm-eval, evalchemy, and custom benchmarking against MaxText checkpoints. See the evaluation guide for details.
- Added support for pre-training new models:
- Direct Preference Optimization (DPO/ORPO) Support: Full support for DPO and ORPO alignment pipelines. See the DPO tutorial for details.
- Reinforcement Learning (RL) Recipe: Added a pre-configured RL recipe for Qwen3-30b-a3b.
- Iterative Quality Monitoring (RL): Added intermediate evaluation hooks to automatically run quality benchmarks during RL training (every `eval_interval` steps), optimized with a new `eval_batch_size` configuration knob.
- Developer Extensibility: Added a `dataset_processor_path` CLI knob for custom dataset integration, and refactored shared post-training hooks to simplify custom SFT, DPO, and RL workflow development.
- Generalized Learn-to-Init (LTI) for Distillation: Enhanced post-training distillation capabilities with generalized LTI support.
- Added support for recording elastic goodput events during training to track efficiency (PR #3901).
- Installation Updates: Updated the `[tpu-post-train]` installation command to require `UV_TORCH_BACKEND=cpu` (see Installation Guide).
- Zero1 AOT Compilation: Added zero1 support to Ahead-Of-Time (AOT) compilation in train compile, improving compilation capabilities for zero1 config.
- MoE Performance Optimization: Integrated ragged gather reduce into Mixture of Experts (MoE) layers in place of ragged scatter, with backward pass support, to optimize memory and performance.
- Added E2E scripts to run checkpoint conversion, pre-training, and post-training (SFT, RL) with the Gemma3-4B model.
- Bug Fixes and Usability Enhancements:
- Attention Masking Fix in RL: Fixed an issue in `TunixMaxTextAdapter` where queries at non-pad positions could attend to pad-position keys during training, which was corrupting log-probabilities and affecting GRPO training reward trajectories (PR #4016).
- JAX/NNX Gradient Mutation Fix: Refactored post-training loops (`train_distill`, `train_sft`, `train_rl`) to use `jax.value_and_grad` with explicit NNX state split/merge instead of nesting `nnx.value_and_grad` inside `nnx.jit` (PR #3652).
- Qwen3-MoE Checkpoint Conversion: Fixed checkpoint conversion issues for Qwen3-MoE models (PR #3868).
- Duplicate Configuration Failures Fix: Allowed identical config overrides and handled configuration exceptions cleanly (PR #3933).
- Documentation Improvements: Updated the Getting started guide and added new guides, including the evaluation framework guide and the DPO tutorial.
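To illustrate the kind of problem addressed by the attention masking fix in RL (PR #4016), here is a minimal sketch of an attention mask that keeps pad-position keys out of every query's view. It is not the `TunixMaxTextAdapter` code: the function names `build_attention_mask` and `masked_attention_weights`, the left-padded example, and the use of NumPy are all assumptions made for illustration.

```python
import numpy as np


def build_attention_mask(valid):
    """Hypothetical helper: build a [seq, seq] boolean attention mask.

    valid: bool array of shape [seq], True for real tokens, False for padding.
    Entry (q, k) is True when query q is allowed to attend to key k.
    """
    seq_len = valid.shape[0]
    # Causal part: a query may only look at itself and earlier positions.
    causal = np.tril(np.ones((seq_len, seq_len), dtype=bool))
    # Padding part: a key is only visible if it is a real token.
    key_is_real = valid[None, :]
    return causal & key_is_real


def masked_attention_weights(scores, mask):
    """Hypothetical helper: softmax over keys with disallowed entries suppressed."""
    masked = np.where(mask, scores, -1e9)
    # Subtract the row maximum for numerical stability before exponentiating.
    masked = masked - masked.max(axis=-1, keepdims=True)
    weights = np.exp(masked)
    return weights / weights.sum(axis=-1, keepdims=True)


# Left-padded sequence: two pad tokens followed by three real tokens.
valid = np.array([False, False, True, True, True])
scores = np.zeros((5, 5))  # uniform raw scores, purely for illustration

mask = build_attention_mask(valid)
weights = masked_attention_weights(scores, mask)

# Rows 2..4 are the real-token queries; columns 0..1 are the pad keys.
print(weights[2:, :2])
```

With a causal mask alone, the real-token queries in this left-padded example would spread attention over the two pad keys. Combining the causal mask with the key-validity mask drives those weights to zero. Rows belonging to pad queries have no visible key and fall back to a uniform distribution in this sketch; they are assumed to be excluded from the loss downstream.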
Deprecations
- Deleted legacy DPO implementation in favor of the integrated DPO trainer.
- Removed stack trace collection feature.