From 3d212094d28ec6b94492564549bff39011ad3695 Mon Sep 17 00:00:00 2001 From: Lance Wang Date: Wed, 22 Apr 2026 18:06:57 +0000 Subject: [PATCH] Remove fp8 NaN sanitization from train_step Drop both the OWG bookkeeping correction and the blanket jnp.nan_to_num over float gradients. Both were introduced as drive-by additions in an NNX migration commit with no fp8 test, repro, or justification in the commit message. A/B on V6e-8, FSDP=8, 10 steps: - gpt3-52k fp8: bit-identical with or without the blocks. - llama2-7b fp8: NaN at step 2 either way; the blanket mask was previously hiding this as silent zeroed grads. Real fp8 + FSDP convergence issues should be fixed upstream in AQT / the fp8 backward pass, not masked in the trainer. Surfacing the NaN lets us actually investigate it. --- src/maxtext/trainers/pre_train/train.py | 31 ------------------------- 1 file changed, 31 deletions(-) diff --git a/src/maxtext/trainers/pre_train/train.py b/src/maxtext/trainers/pre_train/train.py index 720354fe4d..2a365ec97d 100644 --- a/src/maxtext/trainers/pre_train/train.py +++ b/src/maxtext/trainers/pre_train/train.py @@ -356,31 +356,6 @@ def train_step(model, config, state_mesh_shardings, params_shardings, state, dat else: grads = raw_grads - # fp8 fix: sanitize NaN OWG (overwrite-with-gradient) stats before apply_gradients. - # Under FSDP, the fp8 output gradient amax can be NaN at step 0, which propagates into - # amax_history and corrupts future steps. Replace NaN OWG entries with the current state - # values (skip the amax update for that step) instead of letting NaN flow through. - # Also restore OWG values after apply_gradients to bypass optimizer corruption - # (Adam should not update fp8 scale/amax_history). - fp8_stats = dict(grads).get(maxtext_utils.OVERWRITE_WITH_GRADIENT, None) - if fp8_stats is not None: - if maxtext_utils.OVERWRITE_WITH_GRADIENT in state.params: - current_fp8 = state.params[maxtext_utils.OVERWRITE_WITH_GRADIENT] - fp8_stats = jax.tree_util.tree_map( - lambda new, cur: jnp.where(jnp.isnan(new), cur, new), - fp8_stats, - current_fp8, - ) - else: - fp8_stats = jax.tree_util.tree_map(lambda x: jnp.nan_to_num(x, nan=0.0), fp8_stats) - grads = dict(grads) - grads[maxtext_utils.OVERWRITE_WITH_GRADIENT] = fp8_stats - # Zero out any remaining NaN in float gradients to prevent param corruption - grads = jax.tree_util.tree_map( - lambda x: jnp.nan_to_num(x, nan=0.0) if jnp.issubdtype(x.dtype, jnp.floating) else x, - grads, - ) - if config.optimizer_memory_host_offload: state = state.replace( opt_state=jax.device_put( @@ -420,12 +395,6 @@ def move(path, value): else: new_state = state.apply_gradients(grads=grads) - # fp8 fix: restore sanitized OWG values, bypassing any optimizer update to fp8 stats. - if fp8_stats is not None: - new_params = dict(new_state.params) - new_params[maxtext_utils.OVERWRITE_WITH_GRADIENT] = fp8_stats - new_state = new_state.replace(params=new_params) - # Apply updates for Auxiliary-Loss-Free load balancing for DeepSeek family if config.routed_bias and config.routed_bias_update_rate > 0.0 and moe_bias_updates is not None: target_path = ("params", "decoder", "moe_layers", "DeepSeekMoeBlock_0", "MoeBlock_0", "gate", "bias")