@@ -266,29 +266,44 @@ def _safe_shard(x, pspec):
266266 nnx .update (self .optimizer , optimizer_sharded_state )
267267
def _train_step(self, model, optimizer, inputs):
    """Overrides the main JIT block to natively handle ModelBundle module.

    Uses jax.value_and_grad with explicit split/merge to avoid nesting
    nnx.value_and_grad inside nnx.jit, which causes Flax NNX to assign
    conflicting outer_index values and raises:
      ValueError: The graph structure of a node added to cached_partial was
      mutated inside the transformation.

    Args:
      model: ModelBundle module exposing ``student_model`` and ``teacher_model``.
      optimizer: NNX optimizer updated in place with the student gradients.
      inputs: Raw batch passed through ``self.gen_model_input_fn``.

    Returns:
      ``(loss, aux, grad_norm)`` when the runner expects a gradient norm,
      otherwise ``(loss, aux)``.
    """
    batch = self.gen_model_input_fn(inputs)
    student = model.student_model
    teacher = model.teacher_model

    # Teacher inference runs outside of value_and_grad: the teacher is frozen
    # (stop_gradient), so its output is a constant with respect to the
    # student gradient computation.
    teacher_output = batch.get("teacher_output")
    if teacher_output is None and "teacher_output" not in batch:
        teacher_output = self.strategy.teacher_forward_fn(
            model=teacher,
            input_tokens=batch["input_tokens"],
            positions=batch["positions"],
            attention_mask=batch.get("attention_mask"),
            decoder_segment_ids=batch.get("decoder_segment_ids"),
            decoder_target_tokens=batch.get("targets", None),
            decoder_target_mask=batch.get("targets_segmentation", None),
            cache=None,
        )
    teacher_output = jax.tree.map(jax.lax.stop_gradient, teacher_output)

    # Split the student into differentiable params and the non-differentiable
    # remainder. The graphdef is captured outside jax.value_and_grad so the
    # graph structure stays stable across the transformation.
    student_graphdef, diff_params, rest = nnx.split(student, self.wrt_filter, ...)

    def _pure_loss(diff_params, rest):
        # Rebuild a local student from pure pytrees so the closure stays
        # side-effect free from JAX's point of view.
        local_student = nnx.merge(student_graphdef, diff_params, rest, copy=True)
        # NOTE(review): the two keyword args below marked (*) fall between
        # diff hunks in the source under review; they are reconstructed by
        # symmetry with the teacher call — verify against the full file.
        student_output = self.strategy.student_forward_fn(
            model=local_student,
            input_tokens=batch["input_tokens"],
            positions=batch["positions"],
            attention_mask=batch.get("attention_mask"),
            decoder_segment_ids=batch.get("decoder_segment_ids"),  # (*)
            decoder_target_tokens=batch.get("targets", None),  # (*)
            decoder_target_mask=batch.get("targets_segmentation", None),
            cache=None,
        )
        labels = self.strategy.create_labels(
            batch["targets"],
            targets_segmentation=batch.get("targets_segmentation", None),
        )
        loss, aux = self.strategy.compute_loss(student_output, teacher_output, labels)
        # Carry the updated non-param state (e.g. RNG counters) back out.
        _, _, new_rest = nnx.split(local_student, self.wrt_filter, ...)
        return loss, (aux, new_rest)

    value_and_grad = jax.value_and_grad(_pure_loss, argnums=0, has_aux=True)
    (loss, (aux, new_rest)), grads = value_and_grad(diff_params, rest)

    # Propagate the updated non-param state back into the live student, then
    # apply the gradient step.
    nnx.update(student, new_rest)
    optimizer.update(student, grads)

    if getattr(self, "_tunix_expects_grad_norm", True):
        return loss, aux, optax.global_norm(grads)
    return loss, aux
321333
322334 def _eval_step (self , model , inputs ):
323335 """Evaluation only needs the student."""
0 commit comments