
Numerical issues with BFloat16 #215

@AntonOresten

Description

I've been trying to train in BFloat16, and I just figured out the primary cause of NaNs in the backward pass: the momentum terms in the optimiser's state need to be stored with higher precision than BFloat16 offers.

For example, the default beta values for Adam are (0.9, 0.999). The second value can't be represented in BFloat16: the nearest representable values are 0.99609375 and 1.0, so it gets rounded to 1.0. With β₂ == 1, the factor (1 - β₂) is exactly zero, so the second moment never accumulates, and the bias-correction denominator 1 - β₂ᵗ is also zero, turning the update into 0/0 = NaN.
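To make the rounding concrete (a minimal Python sketch, not the Julia code in question): BFloat16 keeps float32's 8-bit exponent but only 8 mantissa bits, i.e. the top 16 bits of the float32 bit pattern. Simulating the standard round-to-nearest-even truncation shows exactly why 0.999 lands on 1.0:

```python
import struct

def to_bfloat16(x: float) -> float:
    # Encode x as float32, then round to the nearest BFloat16 by keeping
    # the top 16 bits of the bit pattern (round-to-nearest-even).
    bits = struct.unpack(">I", struct.pack(">f", x))[0]
    rounded = (bits + 0x7FFF + ((bits >> 16) & 1)) >> 16
    return struct.unpack(">f", struct.pack(">I", rounded << 16))[0]

print(to_bfloat16(0.9))    # 0.8984375 -- beta1 survives, just coarsely
print(to_bfloat16(0.996))  # 0.99609375 -- the largest BFloat16 below 1.0 near here
print(to_bfloat16(0.999))  # 1.0 -- beta2 rounds all the way up
```

The same truncation is what a `BFloat16(0.999)` conversion performs, which is why no amount of care inside `apply!` helps once the betas themselves have been narrowed.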

I think the simplest solution is to store the momenta in Float32, but convert to BFloat16 as part of the broadcast:

function apply!(o::Adam, state, x::AbstractArray{T}, dx) where T
    # β is deliberately left at its original (Float64) precision;
    # only η and ϵ are converted to the parameter eltype T.
    η, β, ϵ = T(o.eta), o.beta, T(o.epsilon)
    mt, vt, βt = state  # moments and β powers kept in higher precision

    # Accumulate at β's precision, then round the result back to T:
    @.. mt = T(β[1] * mt + (1 - β[1]) * dx)
    @.. vt = T(β[2] * vt + (1 - β[2]) * abs2(dx))
    dx′ = @lazy mt / (1 - βt[1]) / (sqrt(vt / (1 - βt[2])) + ϵ) * η

    return (mt, vt, βt .* β), dx′
end

Conversion to BFloat16 would also need to be avoided when the state is initialised.

I will create a PR for this. It's not clear whether the same should be done for Float16 (it might be breaking), and there may well be other optimisers that suffer from similar BFloat16 numerical issues.
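To trace the failure mode end to end (again a hypothetical Python sketch of Adam's second-moment bookkeeping, not Optimisers.jl code): once β₂ has been narrowed to 1.0, the second moment is frozen at its zero init and the bias-correction denominator collapses to zero on every step.

```python
beta2 = 1.0          # what BFloat16 stores in place of 0.999
vt, beta2_t = 0.0, 1.0
for _ in range(10):
    g = 0.1                                # some nonzero gradient
    vt = beta2 * vt + (1 - beta2) * g * g  # (1 - beta2) == 0, so vt stays 0.0
    beta2_t *= beta2                       # stays 1.0 forever

denom = 1.0 - beta2_t                      # bias-correction denominator: 0.0
print(vt, denom)  # 0.0 0.0
```

The bias-corrected second moment is then vt / denom = 0/0, which IEEE arithmetic (and Julia's broadcast) evaluates to NaN, and that NaN propagates through the rest of the update.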
