Drop both the OWG bookkeeping correction and the blanket
jnp.nan_to_num over float gradients. Both were introduced as
drive-by additions in an NNX migration commit with no fp8 test,
repro, or justification in the commit message.
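For reference, a minimal sketch of the kind of blanket masking being dropped (the loss and shapes here are hypothetical, chosen only to produce a NaN gradient): jnp.nan_to_num applied across float gradients turns a diverging backward pass into a silent zero-gradient step.

```python
import jax
import jax.numpy as jnp

# Hypothetical toy loss; grad wrt w is just x, so a NaN in x
# surfaces as a NaN gradient, standing in for an fp8 backward
# pass that diverges.
def loss(w, x):
    return jnp.sum(w * x)

w = jnp.array([1.0, 1.0])
x = jnp.array([2.0, jnp.nan])
grads = jax.grad(loss)(w, x)  # contains a NaN

# The dropped pattern: blanket-mask every float gradient leaf.
# The NaN becomes 0.0, so the optimizer sees a benign no-op
# update instead of a divergence signal.
masked = jax.tree_util.tree_map(
    lambda g: jnp.nan_to_num(g) if jnp.issubdtype(g.dtype, jnp.floating) else g,
    grads,
)
```

Removing the mask lets the NaN propagate to the loss/grad-norm metrics, which is what makes the llama2-7b failure below visible.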
A/B on V6e-8, FSDP=8, 10 steps:
- gpt3-52k fp8: bit-identical with or without the blocks.
- llama2-7b fp8: NaN at step 2 either way; the blanket mask was
  previously hiding this by silently zeroing the grads.
Real fp8 + FSDP convergence issues should be fixed upstream in
AQT / the fp8 backward pass, not masked in the trainer. Surfacing
the NaN lets us actually investigate it.