Commit 3d21209

Remove fp8 NaN sanitization from train_step
Drop both the OWG bookkeeping correction and the blanket jnp.nan_to_num over float gradients. Both were introduced as drive-by additions in an NNX migration commit with no fp8 test, repro, or justification in the commit message.

A/B on V6e-8, FSDP=8, 10 steps:

- gpt3-52k fp8: bit-identical with or without the blocks.
- llama2-7b fp8: NaN at step 2 either way; the blanket mask was previously hiding this as silent zeroed grads.

Real fp8 + FSDP convergence issues should be fixed upstream in AQT / the fp8 backward pass, not masked in the trainer. Surfacing the NaN lets us actually investigate it.
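For reference, a minimal sketch (toy gradient pytree, not this repo's code) of what the removed blanket mask does: jnp.nan_to_num over floating-point leaves turns a NaN gradient into 0.0, so a diverging run keeps stepping on silently zeroed updates instead of failing.

```python
import jax
import jax.numpy as jnp

# Toy gradient pytree with a NaN leaf, standing in for a corrupted fp8 backward pass.
grads = {"w": jnp.array([1.0, jnp.nan]), "step": jnp.array(3, dtype=jnp.int32)}

# The removed blanket mask: zero NaNs in floating-point leaves, pass others through.
masked = jax.tree_util.tree_map(
    lambda x: jnp.nan_to_num(x, nan=0.0) if jnp.issubdtype(x.dtype, jnp.floating) else x,
    grads,
)
# masked["w"] is now [1.0, 0.0]: the NaN is gone, and so is the signal that
# something upstream produced it.
```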
1 parent 172d0f1 commit 3d21209

1 file changed

Lines changed: 0 additions & 31 deletions

src/maxtext/trainers/pre_train/train.py

@@ -356,31 +356,6 @@ def train_step(model, config, state_mesh_shardings, params_shardings, state, dat
   else:
     grads = raw_grads
 
-  # fp8 fix: sanitize NaN OWG (overwrite-with-gradient) stats before apply_gradients.
-  # Under FSDP, the fp8 output gradient amax can be NaN at step 0, which propagates into
-  # amax_history and corrupts future steps. Replace NaN OWG entries with the current state
-  # values (skip the amax update for that step) instead of letting NaN flow through.
-  # Also restore OWG values after apply_gradients to bypass optimizer corruption
-  # (Adam should not update fp8 scale/amax_history).
-  fp8_stats = dict(grads).get(maxtext_utils.OVERWRITE_WITH_GRADIENT, None)
-  if fp8_stats is not None:
-    if maxtext_utils.OVERWRITE_WITH_GRADIENT in state.params:
-      current_fp8 = state.params[maxtext_utils.OVERWRITE_WITH_GRADIENT]
-      fp8_stats = jax.tree_util.tree_map(
-          lambda new, cur: jnp.where(jnp.isnan(new), cur, new),
-          fp8_stats,
-          current_fp8,
-      )
-    else:
-      fp8_stats = jax.tree_util.tree_map(lambda x: jnp.nan_to_num(x, nan=0.0), fp8_stats)
-    grads = dict(grads)
-    grads[maxtext_utils.OVERWRITE_WITH_GRADIENT] = fp8_stats
-  # Zero out any remaining NaN in float gradients to prevent param corruption
-  grads = jax.tree_util.tree_map(
-      lambda x: jnp.nan_to_num(x, nan=0.0) if jnp.issubdtype(x.dtype, jnp.floating) else x,
-      grads,
-  )
-
   if config.optimizer_memory_host_offload:
     state = state.replace(
       opt_state=jax.device_put(
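The removed OWG repair pattern in miniature, using hypothetical stats trees (the names `new_stats`/`cur_stats` are illustrative, not from this repo): a two-tree tree_map keeps the current value wherever this step's value is NaN, i.e. it skips the amax update for that step rather than propagating NaN into amax_history.

```python
import jax
import jax.numpy as jnp

# Hypothetical fp8 OWG stats: this step's values (may contain NaN under FSDP)
# and the values currently held in state.params.
new_stats = {"amax_history": jnp.array([jnp.nan, 2.0]), "scale": jnp.array(1.5)}
cur_stats = {"amax_history": jnp.array([0.5, 1.0]), "scale": jnp.array(1.0)}

# Per-entry repair: fall back to the current value wherever the new one is NaN.
sanitized = jax.tree_util.tree_map(
    lambda new, cur: jnp.where(jnp.isnan(new), cur, new),
    new_stats,
    cur_stats,
)
# Only the NaN slot falls back; finite entries take the new values.
```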
@@ -420,12 +395,6 @@ def move(path, value):
   else:
     new_state = state.apply_gradients(grads=grads)
 
-  # fp8 fix: restore sanitized OWG values, bypassing any optimizer update to fp8 stats.
-  if fp8_stats is not None:
-    new_params = dict(new_state.params)
-    new_params[maxtext_utils.OVERWRITE_WITH_GRADIENT] = fp8_stats
-    new_state = new_state.replace(params=new_params)
-
   # Apply updates for Auxiliary-Loss-Free load balancing for DeepSeek family
   if config.routed_bias and config.routed_bias_update_rate > 0.0 and moe_bias_updates is not None:
     target_path = ("params", "decoder", "moe_layers", "DeepSeekMoeBlock_0", "MoeBlock_0", "gate", "bias")
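If, per the commit message, the goal is to surface non-finite gradients rather than mask them, one hypothetical host-side check (an assumption, not code from this repo; it cannot run inside a jitted train step as written) would be:

```python
import jax
import jax.numpy as jnp

def assert_finite_grads(grads):
  """Fail loudly if any floating-point gradient leaf contains NaN or Inf."""
  for path, leaf in jax.tree_util.tree_leaves_with_path(grads):
    if jnp.issubdtype(leaf.dtype, jnp.floating) and not bool(jnp.all(jnp.isfinite(leaf))):
      raise FloatingPointError(f"non-finite gradient at {jax.tree_util.keystr(path)}")

# Finite gradients pass the check; a NaN leaf would raise with the offending path.
assert_finite_grads({"w": jnp.array([1.0, 2.0])})
```

For a check that works inside jit, `jax.experimental.checkify` or `jax.debug.callback` would be the usual routes; the loop above is the simplest eager-mode form.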
