I've been trying to train in BFloat16, and I've tracked down the primary cause of NaNs in the backward pass: the momentum parameters in the optimiser's state need to be stored at higher precision than BFloat16 offers.
For example, Adam's default beta values are (0.9, 0.999). The second value can't be represented in BFloat16: the nearest representable values are 0.99609375 and 1.0, so it gets rounded up to 1.0.
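To make the rounding concrete, here is a small Python sketch (not Optimisers.jl code) that rounds a float to BFloat16 by truncating a Float32 to its top 16 bits with round-to-nearest-even, which is the standard BFloat16 conversion:

```python
import struct

def to_bfloat16(x: float) -> float:
    """Round x to the nearest BFloat16 value (round-to-nearest-even),
    by keeping only the top 16 bits of its Float32 representation."""
    bits = struct.unpack("<I", struct.pack("<f", x))[0]
    # Round-to-nearest-even on the 16 low bits being dropped.
    rounding_bias = 0x7FFF + ((bits >> 16) & 1)
    bits = (bits + rounding_bias) & 0xFFFF0000
    return struct.unpack("<f", struct.pack("<I", bits))[0]

print(to_bfloat16(0.9))    # 0.8984375 -- a nearby representable value
print(to_bfloat16(0.999))  # 1.0 -- Adam's default beta2 collapses to 1.0
```

With only 7 explicit mantissa bits, everything between 0.99609375 and 1.0 rounds to one of those two endpoints, and 0.999 is closer to 1.0.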
I think the simplest solution is to store momentum in Float32, but convert to BFloat16 as part of the broadcasting:
```julia
function apply!(o::Adam, state, x::AbstractArray{T}, dx) where T
  η, β, ϵ = T(o.eta), o.beta, T(o.epsilon)
  mt, vt, βt = state

  @.. mt = T(β[1] * mt + (1 - β[1]) * dx)
  @.. vt = T(β[2] * vt + (1 - β[2]) * abs2(dx))
  dx′ = @lazy mt / (1 - βt[1]) / (sqrt(vt / (1 - βt[2])) + ϵ) * η

  return (mt, vt, βt .* β), dx′
end
```

Conversion to BFloat16 would also need to be avoided in the state init.
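To show why a β₂ that rounds to 1.0 produces NaNs, here is a minimal NumPy simulation (scalar gradients, hypothetical helper, not Optimisers.jl code) of Adam's bias-corrected second moment:

```python
import numpy as np

def adam_second_moment(beta2, grads):
    """Yield Adam's bias-corrected second moment v_t / (1 - beta2^t)
    for a stream of scalar gradients."""
    vt, beta2_t = np.float32(0.0), np.float32(1.0)
    for g in grads:
        vt = np.float32(beta2) * vt + np.float32(1 - beta2) * np.float32(g) ** 2
        beta2_t *= np.float32(beta2)
        yield vt / (np.float32(1.0) - beta2_t)  # bias correction

grads = [0.1, 0.2, 0.1]
print(list(adam_second_moment(0.999, grads)))  # finite values
print(list(adam_second_moment(1.0, grads)))    # all NaN
```

With β₂ = 1.0 the update becomes `vt = 1*vt + 0*g²`, so `vt` stays at its zero init, and the bias correction divides 0 by `1 - 1.0^t = 0`, giving 0/0 = NaN on every step.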
I will create a PR for this. It's unclear whether the same should be done for Float16 (it might be breaking), and other optimisers may well suffer from similar BFloat16 numerical issues.
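For comparison, Float16's 10 mantissa bits are enough to keep β₂ = 0.999 distinct from 1.0, which is one reason the Float16 case is less clear-cut. A quick NumPy check:

```python
import numpy as np

beta2 = np.float16(0.999)
print(beta2)                    # rounds to 0.99902..., not 1.0
print(beta2 == np.float16(1.0))  # False: one ulp of headroom remains
print(float(np.float16(1.0) - beta2))  # ~0.00098
```

So Float16 doesn't hit this particular failure mode with Adam's defaults, though its narrower exponent range could cause other problems (e.g. overflow in `abs2(dx)`).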