Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
31 changes: 0 additions & 31 deletions src/maxtext/trainers/pre_train/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -356,31 +356,6 @@ def train_step(model, config, state_mesh_shardings, params_shardings, state, dat
else:
grads = raw_grads

# fp8 fix: sanitize NaN OWG (overwrite-with-gradient) stats before apply_gradients.
# Under FSDP, the fp8 output gradient amax can be NaN at step 0, which propagates into
# amax_history and corrupts future steps. Replace NaN OWG entries with the current state
# values (skip the amax update for that step) instead of letting NaN flow through.
# Also restore OWG values after apply_gradients to bypass optimizer corruption
# (Adam should not update fp8 scale/amax_history).
fp8_stats = dict(grads).get(maxtext_utils.OVERWRITE_WITH_GRADIENT, None)
if fp8_stats is not None:
if maxtext_utils.OVERWRITE_WITH_GRADIENT in state.params:
current_fp8 = state.params[maxtext_utils.OVERWRITE_WITH_GRADIENT]
fp8_stats = jax.tree_util.tree_map(
lambda new, cur: jnp.where(jnp.isnan(new), cur, new),
fp8_stats,
current_fp8,
)
else:
fp8_stats = jax.tree_util.tree_map(lambda x: jnp.nan_to_num(x, nan=0.0), fp8_stats)
grads = dict(grads)
grads[maxtext_utils.OVERWRITE_WITH_GRADIENT] = fp8_stats
# Zero out any remaining NaN in float gradients to prevent param corruption
grads = jax.tree_util.tree_map(
lambda x: jnp.nan_to_num(x, nan=0.0) if jnp.issubdtype(x.dtype, jnp.floating) else x,
grads,
)

if config.optimizer_memory_host_offload:
state = state.replace(
opt_state=jax.device_put(
Expand Down Expand Up @@ -420,12 +395,6 @@ def move(path, value):
else:
new_state = state.apply_gradients(grads=grads)

# fp8 fix: restore sanitized OWG values, bypassing any optimizer update to fp8 stats.
if fp8_stats is not None:
new_params = dict(new_state.params)
new_params[maxtext_utils.OVERWRITE_WITH_GRADIENT] = fp8_stats
new_state = new_state.replace(params=new_params)

# Apply updates for Auxiliary-Loss-Free load balancing for DeepSeek family
if config.routed_bias and config.routed_bias_update_rate > 0.0 and moe_bias_updates is not None:
target_path = ("params", "decoder", "moe_layers", "DeepSeekMoeBlock_0", "MoeBlock_0", "gate", "bias")
Expand Down
Loading