
maxtext-v0.2.3



SurbhiJainUSC released this 12 Jun 22:17

Changes

  • Upgraded JAX to version 0.10.0 for pre-training and 0.10.1 for post-training.
  • New vLLM-Powered Evaluation Framework: Introduced an evaluation framework for running lm-eval, evalchemy, and custom benchmarks against MaxText checkpoints. See the evaluation guide for details.
  • Added support for new models:
    • Qwen3.5: Pre-training support for Qwen3.5 35B and 397B.
    • Qwen3-Omni: Support for multimodal SFT (PR #3863).
  • Direct Preference Optimization (DPO) and Odds Ratio Preference Optimization (ORPO) Support: Full support for DPO and ORPO alignment pipelines. See the DPO tutorial for details.
  • Reinforcement Learning (RL) Recipe: Added a pre-configured RL recipe for Qwen3-30B-A3B.
  • Iterative Quality Monitoring (RL): Added intermediate evaluation hooks that automatically run quality benchmarks every eval_interval steps during RL training, with a new eval_batch_size configuration knob that controls the evaluation batch size.
  • Developer Extensibility: Added dataset_processor_path CLI knob for custom dataset integration, and refactored shared post-training hooks to simplify custom SFT, DPO, and RL workflow development.
  • Generalized Learn-to-Init (LTI) for Distillation: Enhanced post-training distillation capabilities with generalized LTI support.
  • Added support for recording elastic goodput events during training to track efficiency (PR #3901).
  • Installation Updates: The [tpu-post-train] installation command now requires UV_TORCH_BACKEND=cpu (see the Installation Guide).
  • Zero1 AOT Compilation: Added zero1 support to Ahead-of-Time (AOT) compilation in train_compile, so zero1 configurations can now be compiled ahead of time.
  • MoE Performance Optimization: Integrated ragged gather-reduce into Mixture of Experts (MoE) layers, replacing ragged scatter to reduce memory use and improve performance; the new path also supports the backward pass.
  • Added end-to-end (E2E) scripts that run checkpoint conversion, pre-training, and post-training (SFT, RL) with the Gemma3-4B model.
  • Bug Fixes and Usability Enhancements:
    • Attention Masking Fix in RL: Fixed an issue in TunixMaxTextAdapter where queries at non-pad positions could attend to pad-position keys during training, which corrupted log-probabilities and affected GRPO training reward trajectories (PR #4016).
    • JAX/NNX Gradient Mutation Fix: Refactored post-training loops (train_distill, train_sft, train_rl) to use jax.value_and_grad with explicit NNX state split/merge instead of nesting nnx.value_and_grad inside nnx.jit (PR #3652).
    • Qwen3-MoE Checkpoint Conversion: Fixed checkpoint conversion issues for Qwen3-MoE models (PR #3868).
    • Duplicate Configuration Override Fix: Identical config overrides are now allowed, and configuration exceptions are handled cleanly (PR #3933).
  • Documentation Improvements: Updated the Getting Started guide and added a guide for the evaluation framework and a DPO tutorial.
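For readers new to preference optimization, the DPO/ORPO entry above refers to alignment pipelines built around objectives such as the DPO loss. Below is a minimal, framework-free sketch of the standard DPO loss for a single preference pair, assuming sequence-level log-probabilities have already been computed for the policy and a frozen reference model. The function names and the beta default of 0.1 are illustrative choices, not MaxText's API or defaults. ORPO uses a different, reference-free objective and is not shown.

```python
import math


def log_sigmoid(x: float) -> float:
    # Numerically stable log(sigmoid(x)) for large positive or negative x.
    if x >= 0:
        return -math.log1p(math.exp(-x))
    return x - math.log1p(math.exp(x))


def dpo_loss(policy_chosen_logp: float,
             policy_rejected_logp: float,
             ref_chosen_logp: float,
             ref_rejected_logp: float,
             beta: float = 0.1) -> float:
    """DPO loss for one (chosen, rejected) pair of sequence log-probs."""
    # Log-ratio of policy to reference for each response.
    chosen_ratio = policy_chosen_logp - ref_chosen_logp
    rejected_ratio = policy_rejected_logp - ref_rejected_logp
    # Penalize the policy unless it raises the chosen response's ratio
    # above the rejected response's ratio.
    return -log_sigmoid(beta * (chosen_ratio - rejected_ratio))


# When the policy equals the reference, the margin is 0 and the loss is log(2).
print(round(dpo_loss(-12.0, -15.0, -12.0, -15.0), 4))  # 0.6931
```

Raising the policy's log-probability of the chosen response relative to the reference pushes the loss below log(2); favoring the rejected response pushes it above.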
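The RL quality-monitoring entry describes evaluation that runs every eval_interval steps, with eval_batch_size governing how the evaluation set is batched. The scheduling logic can be sketched in plain Python as follows. train_with_periodic_eval, train_step, and score_batch are hypothetical stand-ins rather than MaxText's actual hooks, and averaging per-prompt scores is an assumption about how a benchmark result might be aggregated.

```python
from typing import Callable, Iterator, Sequence


def batches(items: Sequence[str], batch_size: int) -> Iterator[Sequence[str]]:
    # Yield consecutive slices of at most batch_size items.
    for start in range(0, len(items), batch_size):
        yield items[start:start + batch_size]


def train_with_periodic_eval(
    num_steps: int,
    eval_interval: int,
    eval_batch_size: int,
    eval_prompts: Sequence[str],
    train_step: Callable[[int], None],
    score_batch: Callable[[Sequence[str]], list[float]],
) -> dict[int, float]:
    """Run train_step every step and a full evaluation every eval_interval steps."""
    history: dict[int, float] = {}
    for step in range(1, num_steps + 1):
        train_step(step)
        if step % eval_interval == 0:
            # Score the evaluation set in chunks of eval_batch_size prompts.
            scores: list[float] = []
            for batch in batches(eval_prompts, eval_batch_size):
                scores.extend(score_batch(batch))
            history[step] = sum(scores) / len(scores)
    return history


# Usage with stub training and scoring functions.
seen_batch_sizes: list[int] = []


def fake_score(batch: Sequence[str]) -> list[float]:
    seen_batch_sizes.append(len(batch))
    return [float(len(prompt)) for prompt in batch]


history = train_with_periodic_eval(
    num_steps=10,
    eval_interval=5,
    eval_batch_size=2,
    eval_prompts=["a", "bb", "ccc", "dd", "e"],
    train_step=lambda step: None,
    score_batch=fake_score,
)
print(history)  # {5: 1.8, 10: 1.8}
```

Collecting per-prompt scores before averaging keeps a smaller final batch from skewing the result, which would happen if batch means were averaged instead.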
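Assuming zero1 in the AOT compilation entry refers to ZeRO stage 1, which shards optimizer state across data-parallel replicas instead of replicating it, the per-device optimizer-state footprint can be estimated with simple arithmetic. This sketch assumes an Adam-style optimizer with two fp32 moment tensors per parameter and ignores parameters, gradients, and any master weights. The function name and the 4B-parameter, 8-replica setup are hypothetical.

```python
def adam_state_bytes_per_device(num_params: int, data_parallel: int,
                                zero1: bool, bytes_per_moment: int = 4) -> float:
    """Per-device bytes for Adam's first and second moments."""
    # Two moment tensors (m and v), each holding one value per parameter.
    total = 2 * bytes_per_moment * num_params
    # ZeRO-1: each data-parallel replica keeps only its shard of the state.
    return total / data_parallel if zero1 else float(total)


params = 4_000_000_000  # a hypothetical 4B-parameter model
for sharded in (False, True):
    gb = adam_state_bytes_per_device(params, data_parallel=8, zero1=sharded) / 1e9
    print(f"zero1={sharded}: {gb:.1f} GB per device")
# zero1=False: 32.0 GB per device
# zero1=True: 4.0 GB per device
```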
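The MoE entry replaces a ragged scatter with a ragged gather-reduce when combining expert outputs back into token order. The toy sketch below is written in plain Python rather than JAX and is not MaxText's kernel. It shows why the two approaches are interchangeable for the forward combine. Expert outputs sit in an expert-sorted, ragged buffer. Scatter-add has each buffer row add itself into its token's output row, while gather-reduce has each token pull its own top-k rows and sum them. The variable names and the tiny top-2 routing are illustrative assumptions.

```python
Vector = list[float]


def scatter_add_combine(expert_out: list[Vector], slot_token: list[int],
                        slot_weight: list[float], num_tokens: int) -> list[Vector]:
    # Each buffer row adds its weighted output into its token's row;
    # several rows may write to the same token.
    hidden = len(expert_out[0])
    out = [[0.0] * hidden for _ in range(num_tokens)]
    for row, token, weight in zip(expert_out, slot_token, slot_weight):
        for d in range(hidden):
            out[token][d] += weight * row[d]
    return out


def gather_reduce_combine(expert_out: list[Vector], token_slots: list[list[int]],
                          token_weights: list[list[float]]) -> list[Vector]:
    # Each token reads its own k rows and reduces them; every output row
    # is written exactly once.
    hidden = len(expert_out[0])
    out = []
    for slots, weights in zip(token_slots, token_weights):
        acc = [0.0] * hidden
        for slot, weight in zip(slots, weights):
            for d in range(hidden):
                acc[d] += weight * expert_out[slot][d]
        out.append(acc)
    return out


# 3 tokens, top-2 routing over 3 experts. Experts 0, 1, 2 received 3, 2, 1
# tokens respectively, so the expert-sorted buffer is ragged.
expert_out = [[1.0, 0.0], [2.0, 0.0], [3.0, 0.0],  # expert 0: tokens 0, 1, 2
              [0.0, 4.0], [0.0, 8.0],              # expert 1: tokens 0, 1
              [5.0, 5.0]]                          # expert 2: token 2
token_slots = [[0, 3], [1, 4], [2, 5]]
token_weights = [[0.5, 0.5], [0.25, 0.75], [0.75, 0.25]]

# Inverse mapping (buffer slot -> token, router weight) for the scatter form.
slot_token = [0, 1, 2, 0, 1, 2]
slot_weight = [0.5, 0.25, 0.75, 0.5, 0.75, 0.25]

gathered = gather_reduce_combine(expert_out, token_slots, token_weights)
scattered = scatter_add_combine(expert_out, slot_token, slot_weight, num_tokens=3)
print(gathered)  # [[0.5, 2.0], [0.5, 6.0], [3.5, 1.25]]
```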
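The RL attention-masking fix concerns queries at real (non-pad) positions attending to padding keys. A minimal sketch of the intended mask follows. It assumes left-padded sequences and a boolean is_pad flag per position; the names are illustrative and do not reflect TunixMaxTextAdapter's internals. Under this mask, a key is visible only if it is both causal and not padding.

```python
def causal_mask_without_pad_keys(is_pad: list[bool]) -> list[list[bool]]:
    """mask[q][k] is True when query position q may attend to key position k."""
    n = len(is_pad)
    # Causal (k <= q) AND the key must be a real token, never padding.
    return [[k <= q and not is_pad[k] for k in range(n)] for q in range(n)]


# Two left-padding positions followed by three real tokens.
is_pad = [True, True, False, False, False]
mask = causal_mask_without_pad_keys(is_pad)
for q, row in enumerate(mask):
    print(q, "".join("x" if allowed else "." for allowed in row))
```

A plain causal mask would let query 2 see keys 0 and 1, leaking padding into real tokens' log-probabilities. In this sketch the pad query rows are fully masked, so their outputs should be ignored downstream; production attention code usually also guards against a softmax over an all-masked row.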
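The configuration fix allows the same override to be supplied more than once when the values are identical. The following hypothetical sketch expresses that rule and does not reproduce MaxText's config code: repeated keys are accepted if their values agree, and conflicting values raise a single descriptive ValueError. The function name and config keys are illustrative.

```python
def apply_overrides(base: dict[str, object],
                    overrides: list[tuple[str, object]]) -> dict[str, object]:
    """Apply (key, value) overrides, tolerating identical duplicates."""
    merged = dict(base)
    seen: dict[str, object] = {}
    for key, value in overrides:
        # A repeated key is only an error if it tries to set a different value.
        if key in seen and seen[key] != value:
            raise ValueError(
                f"Conflicting overrides for {key!r}: {seen[key]!r} vs {value!r}"
            )
        seen[key] = value
        merged[key] = value
    return merged


config = apply_overrides({"steps": 100, "run_name": "base"},
                         [("steps", 10), ("steps", 10)])
print(config)  # {'steps': 10, 'run_name': 'base'}

error_message = ""
try:
    apply_overrides({}, [("steps", 10), ("steps", 20)])
except ValueError as err:
    error_message = str(err)
print(error_message)  # Conflicting overrides for 'steps': 10 vs 20
```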

Deprecations