Commit 2fbc8c2
NNX post-train fixes: unpack MultimodalInput for NNX decoder; support scalar LR in adam_pax
- models.py: the NNX Transformer was passing `multimodal_input=MultimodalInput(...)` to NNXDecoder, which expects individual keyword args (image_embeddings, image_masks, audio_embeddings, audio_masks, bidirectional_mask). Unpack the object at the call site.
- optimizers.py: adam_pax called `learning_rate_fn(count)` unconditionally, failing when `optax.inject_hyperparams` passes a pre-evaluated scalar instead of a callable schedule. Add a `callable()` guard to handle both cases.
1 parent 3f34221 commit 2fbc8c2

7 files changed

Lines changed: 224 additions & 61 deletions


src/maxtext/models/models.py

Lines changed: 5 additions & 1 deletion
@@ -520,7 +520,11 @@ def __call__(
         previous_chunk=previous_chunk,
         slot=slot,
         page_state=page_state,
-        multimodal_input=multimodal_input,
+        image_embeddings=multimodal_input.image_embeddings if multimodal_input is not None else None,
+        image_masks=multimodal_input.image_masks if multimodal_input is not None else None,
+        audio_embeddings=multimodal_input.audio_embeddings if multimodal_input is not None else None,
+        audio_masks=multimodal_input.audio_masks if multimodal_input is not None else None,
+        bidirectional_mask=multimodal_input.bidirectional_mask if multimodal_input is not None else None,
         kv_caches=kv_caches,
         attention_metadata=attention_metadata,
         deepstack_visual_embeds=deepstack_visual_embeds,
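
Note: a minimal sketch of the interface mismatch this hunk fixes. The field names and the keyword-only decoder signature follow the commit message; the dataclass and the decoder stub below are illustrative, not MaxText's actual definitions.

from dataclasses import dataclass
from typing import Any

@dataclass
class MultimodalInput:
  image_embeddings: Any = None
  image_masks: Any = None
  audio_embeddings: Any = None
  audio_masks: Any = None
  bidirectional_mask: Any = None

def decoder(*, image_embeddings=None, image_masks=None, audio_embeddings=None,
            audio_masks=None, bidirectional_mask=None, **kwargs):
  """Stand-in for NNXDecoder: takes individual keyword args, not a wrapper object."""

mm = MultimodalInput(image_embeddings=[[0.1, 0.2]])
# decoder(multimodal_input=mm)   # TypeError: unexpected keyword argument
decoder(**vars(mm))              # unpacking the dataclass fields works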

src/maxtext/optimizers/optimizers.py

Lines changed: 3 additions & 1 deletion
@@ -336,7 +336,9 @@ def _update_momentum(update, mu, nu):
     else:
       updates = jax.tree_util.tree_map(lambda x, v: x + weight_decay * v, updates, params)

-    step_size = -1.0 * learning_rate_fn(count)
+    # learning_rate_fn may be a callable schedule or a scalar (e.g. when wrapped
+    # by optax.inject_hyperparams, it is passed as a pre-evaluated scalar).
+    step_size = -1.0 * (learning_rate_fn(count) if callable(learning_rate_fn) else learning_rate_fn)
     # Finally, fold in step size.
     updates = jax.tree_util.tree_map(lambda x: step_size * x, updates)
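
Note: the guard is needed because optax.inject_hyperparams evaluates schedules itself and hands the wrapped optimizer a scalar rather than the schedule. A small sketch of that behavior with stock optax (illustrative values):

import jax.numpy as jnp
import optax

schedule = optax.linear_schedule(init_value=1e-3, end_value=1e-4, transition_steps=100)

# Used directly, the optimizer factory receives the schedule object and must
# call it per step: lr = learning_rate_fn(count).
tx_plain = optax.adam(learning_rate=schedule)

# Wrapped, inject_hyperparams evaluates the schedule each step and passes the
# resulting scalar into the factory, so `learning_rate` is no longer callable.
tx_injected = optax.inject_hyperparams(optax.adam)(learning_rate=schedule)

state = tx_injected.init({"w": jnp.ones(3)})
print(state.hyperparams["learning_rate"])  # a jnp scalar evaluated from the schedule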

src/maxtext/trainers/post_train/distillation/train_distill.py

Lines changed: 45 additions & 33 deletions
@@ -266,29 +266,44 @@ def _safe_shard(x, pspec):
     nnx.update(self.optimizer, optimizer_sharded_state)

   def _train_step(self, model, optimizer, inputs):
-    """Overrides the main JIT block to natively handle ModelBundle module."""
+    """Overrides the main JIT block to natively handle ModelBundle module.

+    Uses jax.value_and_grad with explicit split/merge to avoid nesting
+    nnx.value_and_grad inside nnx.jit, which causes Flax NNX to assign
+    conflicting outer_index values and raises:
+      ValueError: The graph structure of a node added to cached_partial was
+      mutated inside the transformation.
+    """
     batch = self.gen_model_input_fn(inputs)
+    student = model.student_model
+    teacher = model.teacher_model
+
+    # Run teacher inference outside of value_and_grad.
+    # The teacher is frozen (stop_gradient), so its output is a constant
+    # from the perspective of the student gradient computation.
+    if "teacher_output" in batch:
+      teacher_output = batch["teacher_output"]
+    else:
+      teacher_output = self.strategy.teacher_forward_fn(
+          model=teacher,
+          input_tokens=batch["input_tokens"],
+          positions=batch["positions"],
+          attention_mask=batch.get("attention_mask"),
+          decoder_segment_ids=batch.get("decoder_segment_ids"),
+          decoder_target_tokens=batch.get("targets", None),
+          decoder_target_mask=batch.get("targets_segmentation", None),
+          cache=None,
+      )
+    teacher_output = jax.tree.map(jax.lax.stop_gradient, teacher_output)

-    def loss_wrapper(student, teacher, batch):
-      if "teacher_output" in batch:
-        teacher_output = batch["teacher_output"]
-      else:
-        teacher_output = self.strategy.teacher_forward_fn(
-            model=teacher,
-            input_tokens=batch["input_tokens"],
-            positions=batch["positions"],
-            attention_mask=batch.get("attention_mask"),
-            decoder_segment_ids=batch.get("decoder_segment_ids"),
-            decoder_target_tokens=batch.get("targets", None),
-            decoder_target_mask=batch.get("targets_segmentation", None),
-            cache=None,
-        )
-
-      teacher_output = jax.tree.map(jax.lax.stop_gradient, teacher_output)
+    # Split student into differentiable params and non-differentiable rest.
+    # Capture graphdef outside of jax.value_and_grad for stable graph tracking.
+    student_graphdef, diff_params, rest = nnx.split(student, self.wrt_filter, ...)

+    def loss_wrapper_pure(diff_params, rest):
+      local_student = nnx.merge(student_graphdef, diff_params, rest, copy=True)
       student_output = self.strategy.student_forward_fn(
-          model=student,
+          model=local_student,
           input_tokens=batch["input_tokens"],
           positions=batch["positions"],
           attention_mask=batch.get("attention_mask"),
@@ -297,27 +312,24 @@ def loss_wrapper(student, teacher, batch):
           decoder_target_mask=batch.get("targets_segmentation", None),
           cache=None,
       )
-      # we should apply a mask for labels to disable segment-separator tokens
       labels = self.strategy.create_labels(batch["targets"], targets_segmentation=batch.get("targets_segmentation", None))
-      return self.strategy.compute_loss(student_output, teacher_output, labels)
-
-    # Because student is the 0th argument, argnums=0 guarantees
-    # we only compute gradients for the student.
-    grad_fn = nnx.value_and_grad(
-        loss_wrapper,
-        argnums=nnx.DiffState(0, self.wrt_filter),
-        has_aux=True,
-    )
+      loss, aux = self.strategy.compute_loss(student_output, teacher_output, labels)
+      # Capture updated non-param state (e.g. RNG counters) from local_student.
+      _, _, new_rest = nnx.split(local_student, self.wrt_filter, ...)
+      return loss, (aux, new_rest)

-    out, grads = grad_fn(model.student_model, model.teacher_model, batch)
+    grad_fn = jax.value_and_grad(loss_wrapper_pure, argnums=0, has_aux=True)
+    (loss, (aux, new_rest)), grads = grad_fn(diff_params, rest)

-    tunix_expects_grad_norm = getattr(self, "_tunix_expects_grad_norm", True)
+    # Propagate updated non-param state back to student.
+    nnx.update(student, new_rest)

-    optimizer.update(model.student_model, grads)
+    optimizer.update(student, grads)

+    tunix_expects_grad_norm = getattr(self, "_tunix_expects_grad_norm", True)
     if tunix_expects_grad_norm:
-      return out[0], out[1], optax.global_norm(grads)
-    return out[0], out[1]
+      return loss, aux, optax.global_norm(grads)
+    return loss, aux

   def _eval_step(self, model, inputs):
     """Evaluation only needs the student."""

src/maxtext/trainers/post_train/rl/train_rl.py

Lines changed: 38 additions & 0 deletions
@@ -56,6 +56,42 @@
 import pathwaysutils
 import tensorflow_datasets as tfds

+# JAX 0.9+ changed with_sharding_constraint to assert (not reshard) when all
+# mesh axes are Explicit. tpu_inference still expects resharding semantics.
+# Patch: try the original (works for Auto axes); on AssertionError (Explicit
+# mesh) fall back to jax.sharding.reshard.
+_orig_wsc = jax.lax.with_sharding_constraint
+
+
+def _compat_wsc(x, shardings):
+  try:
+    return _orig_wsc(x, shardings)
+  except AssertionError:
+    return jax.sharding.reshard(x, shardings)
+
+
+jax.lax.with_sharding_constraint = _compat_wsc
+
+# tpu_inference JaxEinsum defaults param_dtype=float32, so tpu_inference model weights
+# initialize as float32. During weight sync, tunix._apply_dtype_cast then upcasts the
+# incoming bfloat16 MaxText weights → float32 to match the target. This leaves v_proj
+# as float32 while k_proj output appears bfloat16 (due to k_norm dtype promotion),
+# causing a dtype mismatch in the ragged paged attention kernel.
+# Fix: skip bfloat16→float32 upcasts during weight sync so synced weights stay bfloat16.
+import jax.numpy as _jnp
+import tunix.generate.utils as _tunix_utils
+
+_orig_apply_dtype_cast = _tunix_utils._apply_dtype_cast  # pylint: disable=protected-access
+
+
+def _no_bf16_to_f32_cast(val, tgt_dtype, src_key):
+  if hasattr(val, "dtype") and val.dtype == _jnp.bfloat16 and tgt_dtype == _jnp.float32:
+    return val  # keep bfloat16; tpu_inference model dtype is bfloat16 despite float32 init
+  return _orig_apply_dtype_cast(val, tgt_dtype, src_key)
+
+
+_tunix_utils._apply_dtype_cast = _no_bf16_to_f32_cast  # pylint: disable=protected-access
+
 from absl import app
 from absl import logging as absl_logging
 from etils import epath
@@ -543,6 +579,8 @@ def create_rl_components(
         "hf_overrides": trainer_config.vllm_hf_overrides,
         "enable_expert_parallel": sampler_config.rollout_expert_parallelism > 1,
         "enable_prefix_caching": True,  # Enable prefix caching to speed up generation for long prompts
+        # Ensures vLLM model initializes with correct dtype (not float32 default)
+        "dtype": trainer_config.weight_dtype,
     },
     rollout_vllm_sampling_kwargs={
         "stop": trainer_config.stop_strings,

src/maxtext/trainers/post_train/sft/train_sft.py

Lines changed: 66 additions & 2 deletions
@@ -35,14 +35,15 @@
   eval_interval=-1 steps=10 profiler=xplane weight_dtype=bfloat16
 """

-from typing import Sequence
+from typing import Any, Sequence

 from absl import app
 import os
 import jax
 import optax
 import pathwaysutils

+from flax import nnx
 from flax.linen import partitioning as nn_partitioning

 from orbax import checkpoint as ocp
@@ -68,6 +69,69 @@
 from maxtext.utils import model_creation_utils


+class MaxTextPeftTrainer(peft_trainer.PeftTrainer):
+  """MaxText-specific PeftTrainer that avoids nested NNX transformations.
+
+  Tunix's default PeftTrainer._train_step creates nnx.value_and_grad inside
+  nnx.jit. This nesting causes Flax NNX to assign conflicting outer_index
+  values to graph nodes, resulting in:
+    ValueError: The graph structure of a node added to cached_partial was
+    mutated inside the transformation.
+
+  This subclass overrides create_train_step_fn to use jax.value_and_grad
+  with an explicit split/merge pattern (matching MaxText's pre-training NNX
+  train_step), which avoids the nested NNX transformation issue entirely.
+  """
+
+  def create_train_step_fn(self):
+    """Creates a train step using jax.value_and_grad with explicit NNX split/merge."""
+    loss_fn_ref = self.loss_fn
+    has_aux = self._has_aux
+    gen_fn = self.gen_model_input_fn
+    is_lora_enabled = self._lora_enabled
+    wrt = nnx.LoRAParam if is_lora_enabled else nnx.Param
+
+    # Capture the graphdef once outside of JIT so that split/merge inside
+    # jax.value_and_grad can use a stable (non-traced) structural descriptor.
+    graphdef, _, _ = nnx.split(self.model, wrt, ...)
+
+    def train_step(model: nnx.Module, optimizer: nnx.Optimizer, inputs: Any):
+      inputs = gen_fn(inputs)
+
+      # Split model into differentiable params and non-differentiable rest.
+      # Using jax.value_and_grad (not nnx.value_and_grad) avoids nesting NNX
+      # transforms inside nnx.jit, which would corrupt outer_index tracking.
+      _, diff_params, rest = nnx.split(model, wrt, ...)
+
+      def loss_wrapper(diff_params, rest, **inputs_kw):
+        local_model = nnx.merge(graphdef, diff_params, rest, copy=True)
+        out = loss_fn_ref(local_model, **inputs_kw)
+        # Capture updated non-param state (e.g. RNG counters) from local_model.
+        _, _, new_rest = nnx.split(local_model, wrt, ...)
+        if has_aux:
+          loss, aux = out
+          return loss, (aux, new_rest)
+        else:
+          return out, (None, new_rest)
+
+      grad_fn = jax.value_and_grad(loss_wrapper, argnums=0, has_aux=True)
+      (out_val, (aux, new_rest)), grads = grad_fn(diff_params, rest, **inputs)
+
+      # Propagate updated non-param state (RNG counters, etc.) back to model.
+      nnx.update(model, new_rest)
+
+      # Apply optimizer update. grads has the same nnx.State(wrt) structure
+      # as diff_params, which is compatible with optimizer.update.
+      optimizer.update(model, grads)
+
+      if has_aux:
+        return out_val, aux
+      else:
+        return out_val, None
+
+    return train_step
+
+
 def get_tunix_config(mt_config):
   """Gets the Tunix training configurations from the MaxText config.

@@ -161,7 +225,7 @@ def setup_trainer_state(mt_config, goodput_recorder=None):
   training_hooks = hooks.SFTTrainingHooks(mt_config, mesh, learning_rate_schedule, goodput_recorder)
   data_hooks = hooks.SFTDataHooks(mt_config, mesh, goodput_recorder)

-  trainer = peft_trainer.PeftTrainer(model, optimizer, tunix_config)
+  trainer = MaxTextPeftTrainer(model, optimizer, tunix_config)
   trainer.with_training_hooks(training_hooks)
   trainer.with_data_hooks(data_hooks)
   trainer = use_maxtext_loss_function(trainer, mt_config)
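
Note: the `wrt` filter above decides which leaves land in `diff_params` and therefore receive gradients. A small sketch of the filter semantics (toy module and shapes are illustrative):

import jax.numpy as jnp
from flax import nnx

class Adapter(nnx.Module):
  def __init__(self):
    self.base = nnx.Param(jnp.ones((4, 4)))         # stays frozen under LoRA
    self.lora_a = nnx.LoRAParam(jnp.zeros((4, 2)))  # trainable adapter weight

m = Adapter()
# With wrt=nnx.LoRAParam, only adapter weights are split out for gradients;
# everything else, including base Params, rides along in `rest`.
graphdef, diff_params, rest = nnx.split(m, nnx.LoRAParam, ...)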

src/maxtext/utils/model_creation_utils.py

Lines changed: 7 additions & 0 deletions
@@ -361,6 +361,13 @@ def create_nnx_model(config, mesh=None, devices=None, model_mode=MODEL_MODE_TRAI
   # Get the structure of checkpoint in `config.load_parameters_path`
   metadata = ckptr.metadata(config.load_parameters_path)

+  if metadata is None or metadata.item_metadata is None:
+    raise ValueError(
+        f"Cannot read checkpoint metadata from '{config.load_parameters_path}'. "
+        "The checkpoint directory may be empty or the save did not complete "
+        "(missing _CHECKPOINT_METADATA). Ensure the checkpoint save finished successfully."
+    )
+
   is_nnx_checkpoint = True
   if (
       "params" in metadata.item_metadata.tree.keys()
