NaN with custom mask for MultiHeadAttention #572

Open
mashu opened this issue Apr 3, 2024 · 3 comments

Comments

@mashu
Contributor

mashu commented Apr 3, 2024

Hi,

The background: in an encoder-decoder translation model from "Attention Is All You Need", I want to mask out the padding in the sentences passed to the encoder's MultiHeadAttention. However, I noticed that the `neginf` value the mask applies, computed from the logits' eltype (i.e. `-Inf`), can cause problems and lead to NaN.
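To make the failure mode concrete, here is a minimal sketch in plain Julia (no Flux or NNlib). `naive_softmax` and `masked_scores` are hypothetical helpers written for illustration, not NNlib's implementation; they only assume that masked scores are replaced by `typemin` of the element type, i.e. `-Inf32` for `Float32`. If every score in a column is masked, the stabilising step `x .- maximum(x)` computes `-Inf - (-Inf)`, which is `NaN`, and the whole column comes out `NaN`.

```julia
# Hypothetical helper: numerically stabilised softmax over one column of scores.
function naive_softmax(x::AbstractVector{T}) where {T<:AbstractFloat}
    e = exp.(x .- maximum(x))   # -Inf - (-Inf) is NaN when every entry is -Inf
    return e ./ sum(e)
end

# Hypothetical helper: keep scores where `keep` is true, otherwise use -Inf.
masked_scores(x::AbstractVector{T}, keep) where {T<:AbstractFloat} =
    ifelse.(keep, x, typemin(T))

scores = Float32[0.3, -1.2, 0.8, 0.1]

# Some keys survive: masked keys get probability 0, the rest sum to 1.
partly_masked = naive_softmax(masked_scores(scores, [true, true, false, false]))

# No key survives: every entry becomes NaN.
fully_masked = naive_softmax(masked_scores(scores, falses(4)))
```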

A minimal example is provided here: https://github.com/mashu/NaNTracker.jl
The result is:

caused by: DomainError with Float32[-7.7486f-7 4.76837f-7 … 3.57628f-7 3.57628f-7; -7.7486f-7 4.76837f-7 … 3.57628f-7 3.57628f-7; … ; -7.7486f-7 4.76837f-7 … 3.57628f-7 3.57628f-7; -7.7486f-7 4.76837f-7 … 3.57628f-7 3.57628f-7;;; 4.76837f-7 -6.55651f-7 … -4.76837f-7 3.57628f-7; 4.76837f-7 -6.55651f-7 … -4.76837f-7 3.57628f-7; … ; 4.76837f-7 -6.55651f-7 … -4.76837f-7 3.57628f-7; 4.76837f-7 -6.55651f-7 … -4.76837f-7 3.57628f-7;;; -7.7486f-7 1.78814f-7 … 5.36442f-7 -1.19209f-7; -7.7486f-7 1.78814f-7 … 5.36442f-7 -1.19209f-7; … ; -7.7486f-7 1.78814f-7 … 5.36442f-7 -1.19209f-7; -7.7486f-7 1.78814f-7 … 5.36442f-7 -1.19209f-7;;; … ;;; -8.9407f-7 5.96046f-7 … 0.0 3.57628f-7; -8.9407f-7 5.96046f-7 … 0.0 3.57628f-7; … ; -8.9407f-7 5.96046f-7 … 0.0 3.57628f-7; -8.9407f-7 5.96046f-7 … 0.0 3.57628f-7;;; 4.17233f-7 4.17233f-7 … -1.78814f-7 4.17233f-7; 4.17233f-7 4.17233f-7 … -1.78814f-7 4.17233f-7; … ; 4.17233f-7 4.17233f-7 … -1.78814f-7 4.17233f-7; 4.17233f-7 4.17233f-7 … -1.78814f-7 4.17233f-7;;; -2.98023f-7 6.55651f-7 … 3.57628f-7 -4.17233f-7; -2.38419f-7 6.55651f-7 … 3.57628f-7 -4.17233f-7; … ; -2.98023f-7 6.55651f-7 … 3.57628f-7 -4.17233f-7; -2.98023f-7 6.55651f-7 … 3.57628f-7 -4.17233f-7]:
NaN on gradient input for layer: KeyPath(:mha, :out_proj)
Stacktrace:
  [1] (::Main.NaNTracker.var"#pb_check#2"{DebugWrapper{}, Zygote.var"#ad_pullback#58"{}})(Δ::Array{Float32, 3})
    @ Main.NaNTracker ~/NaNTracker.jl/src/NaNTracker.jl:26
  [2] ZBack
    @ ~/.julia/packages/Zygote/jxHJc/src/compiler/chainrules.jl:211 [inlined]
  [3] #_#334
    @ ~/.julia/packages/Flux/Wz6D4/src/layers/attention.jl:129 [inlined]
  [4] (::Zygote.Pullback{Tuple{…}, Tuple{…}})(Δ::Tuple{Array{…}, Nothing})
    @ Zygote ~/.julia/packages/Zygote/jxHJc/src/compiler/interface2.jl:0
  [5] MultiHeadAttention
    @ ~/.julia/packages/Flux/Wz6D4/src/layers/attention.jl:120 [inlined]
  [6] (::Zygote.Pullback{Tuple{…}, Any})(Δ::Tuple{Array{…}, Nothing})
    @ Zygote ~/.julia/packages/Zygote/jxHJc/src/compiler/interface2.jl:0
  [7] MultiHeadAttention
    @ ~/.julia/packages/Flux/Wz6D4/src/layers/attention.jl:120 [inlined]
  [8] (::Zygote.Pullback{Tuple{…}, Tuple{…}})(Δ::Tuple{Array{…}, Nothing})
    @ Zygote ~/.julia/packages/Zygote/jxHJc/src/compiler/interface2.jl:0
  [9] #_#332
    @ ~/.julia/packages/Flux/Wz6D4/src/layers/attention.jl:115 [inlined]
 [10] (::Zygote.Pullback{Tuple{…}, Any})(Δ::Tuple{Array{…}, Nothing})
    @ Zygote ~/.julia/packages/Zygote/jxHJc/src/compiler/interface2.jl:0
 [11] MultiHeadAttention
    @ ~/.julia/packages/Flux/Wz6D4/src/layers/attention.jl:115 [inlined]
 [12] (::Zygote.Pullback{Tuple{…}, Tuple{…}})(Δ::Tuple{Array{…}, Nothing})
    @ Zygote ~/.julia/packages/Zygote/jxHJc/src/compiler/interface2.jl:0
 [13] #_#5
    @ ~/NaNTracker.jl/src/Example.jl:24 [inlined]
 [14] (::Zygote.Pullback{Tuple{…}, Tuple{…}})(Δ::FillArrays.Fill{Float32, 3, Tuple{…}})
    @ Zygote ~/.julia/packages/Zygote/jxHJc/src/compiler/interface2.jl:0
 [15] EncoderOnly
    @ ~/NaNTracker.jl/src/Example.jl:22 [inlined]
 [16] (::Zygote.Pullback{Tuple{…}, Any})(Δ::FillArrays.Fill{Float32, 3, Tuple{…}})
    @ Zygote ~/.julia/packages/Zygote/jxHJc/src/compiler/interface2.jl:0
 [17] #14
    @ ~/NaNTracker.jl/src/Example.jl:47 [inlined]
 [18] (::Zygote.Pullback{Tuple{…}, Tuple{…}})(Δ::Float32)
    @ Zygote ~/.julia/packages/Zygote/jxHJc/src/compiler/interface2.jl:0
 [19] (::Zygote.var"#75#76"{Zygote.Pullback{Tuple{}, Tuple{}}})(Δ::Float32)
    @ Zygote ~/.julia/packages/Zygote/jxHJc/src/compiler/interface.jl:91
 [20] withgradient(f::Function, args::EncoderOnly)
    @ Zygote ~/.julia/packages/Zygote/jxHJc/src/compiler/interface.jl:213
 [21] testit()
    @ Main ~/NaNTracker.jl/src/Example.jl:46
 [22] with_logging(::Function)
    @ Main.NaNTracker ~/NaNTracker.jl/src/NaNTracker.jl:36
 [23] top-level scope
    @ ~/NaNTracker.jl/src/Example.jl:50

I hope this isn't just a misunderstanding on my part of what the mask should look like, but to be honest, the Flux documentation could use a couple of examples for this use case beyond `make_causal_mask`.
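For the padding use case, a mask only has to broadcast to `(kv_len, q_len, nheads, batch_size)`, so it can have singleton query and head dimensions, i.e. size `(kv_len, 1, 1, batch_size)`. The sketch below builds one in plain Julia. `padding_mask` is a hypothetical helper, not a Flux/NNlib function, and it assumes NNlib's convention that `true` entries are attended to and `false` entries are replaced with `-Inf` (the same convention `make_causal_mask` follows).

```julia
# Hypothetical helper: key-padding mask for a (seq_len, batch_size) matrix of token ids.
# Assumes `true` = attend to this key, `false` = replace its score with -Inf.
function padding_mask(tokens::AbstractMatrix{<:Integer}, pad_id::Integer)
    keep = tokens .!= pad_id                  # (seq_len, batch_size)
    seq_len, batch_size = size(keep)
    # Singleton q_len and nheads dimensions, so the mask broadcasts
    # to (kv_len, q_len, nheads, batch_size).
    return reshape(keep, seq_len, 1, 1, batch_size)
end

tokens = Int32[5 7; 3 2; 1 9; 1 1]   # two sequences, padded with id 1 to length 4
mask = padding_mask(tokens, 1)
```

Under this convention, a mask built as `tokens .== pad_id` does the opposite: it keeps only the padding positions, and a sequence with no padding at all ends up with every key masked.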

@mashu
Contributor Author

mashu commented Apr 3, 2024

Just a thought: maybe these values need clamping in `apply_attn_mask`?
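One way to read the clamping suggestion, sketched in plain Julia: after masking with `-Inf`, clamp the scores to the finite range of the element type, so masked entries get the most negative finite value instead of `-Inf`. A fully masked column then has a uniform softmax rather than `NaN`, while partly masked columns still give (numerically) zero weight to the masked keys. `clamped_masked_scores` and `naive_softmax` are hypothetical helpers, not NNlib's `apply_attn_mask` or `softmax`.

```julia
# Hypothetical helper: numerically stabilised softmax over one column of scores.
function naive_softmax(x::AbstractVector{T}) where {T<:AbstractFloat}
    e = exp.(x .- maximum(x))
    return e ./ sum(e)
end

# Hypothetical helper: mask with -Inf, then clamp to the finite range of T,
# so masked entries become -floatmax(T) instead of -Inf.
function clamped_masked_scores(x::AbstractVector{T}, keep) where {T<:AbstractFloat}
    masked = ifelse.(keep, x, typemin(T))
    return clamp.(masked, -floatmax(T), floatmax(T))
end

scores = Float32[0.3, -1.2, 0.8, 0.1]

# No key survives: the column is now uniform instead of NaN.
fully_masked = naive_softmax(clamped_masked_scores(scores, falses(4)))

# Some keys survive: masked keys still get zero weight.
partly_masked = naive_softmax(clamped_masked_scores(scores, [true, true, false, false]))
```

The trade-off is that a fully masked query then attends uniformly to whatever it was supposed to ignore, which hides the problem instead of signalling that the mask removed every key.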

@mcabbott
Member

mcabbott commented Apr 4, 2024

I think the unusual thing here is that the mask (size 16×1×1×32) is constant for whole batches, and thus it's trying to set every value to -Inf before the softmax. That's not illegal according to help:

mask: Input array broadcastable to size (kv_len, q_len, nheads, batch_size). The mask
       is applied to the attention scores just before the softmax.

but it is unusual. Are you sure this is what you want, rather than e.g. running a smaller batch of randomly selected items?
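A quick way to check whether a given mask hits this case is to look for batch items in which some query has no key left after masking. `fully_masked_items` is a hypothetical diagnostic in plain Julia, not part of Flux; it assumes `true` means "attend" and a 4-D mask laid out as `(kv_len, q_len, nheads, batch_size)`, possibly with singleton dimensions.

```julia
# Hypothetical diagnostic: indices of batch items where at least one
# (query, head) slice has every key masked out, i.e. an all -Inf softmax input.
function fully_masked_items(mask::AbstractArray{Bool,4})
    any_kept = dropdims(any(mask; dims=1); dims=1)   # (q_len, nheads, batch_size)
    return findall(b -> !all(view(any_kept, :, :, b)), axes(mask, 4))
end

tokens = Int32[5 7; 3 2; 1 9; 1 4]   # item 1 is padded with id 1, item 2 is not

# Mask that is true only at padding positions: item 2 has nothing left to attend to.
pad_only = reshape(tokens .== 1, 4, 1, 1, 2)

# Mask that is true at real tokens: every item keeps at least one key.
non_pad = reshape(tokens .!= 1, 4, 1, 1, 2)
```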

Here is a slightly shorter reproducer, followed by a case with what I think is a more orthodox mask shape:

julia> struct MaskAttention{A<:MultiHeadAttention, M<:AbstractArray}
           att::A
           mask::M
       end

julia> (m::MaskAttention)(x::AbstractArray) = first(m.att(x; m.mask))

julia> Flux.@layer :expand MaskAttention

julia> x = map(f->rand(Int32.(2:10), rand(8:16)), 1:32);

julia> x = reduce(hcat, rpad.(x, maximum(length.(x)), 1))
16×32 Matrix{Int32}:
  7   7   7   6  9   5   2   8   9   7   3   4  …  9   2   6   4   8   2   9   9  4   5   3   3
  2   5   3  10  7   7   9   7   9   6   2   3  …   7   7   7  10   8   7   7   8  7  10  10   7
  6   3   5   5  3   4   2   4   9   9   2   7  …   9   4   8   4   9   6   3   9  5   4   2   3

julia> mask = permutedims(repeat((x .== 1), outer = [1, 1, 1, 1]), (1, 4, 3, 2))
16×1×1×32 BitArray{4}:
[:, :, 1, 1] =
 0
 0
 0

julia> model = MaskAttention(MultiHeadAttention(16), mask)
MaskAttention(
  MultiHeadAttention(16; nheads=8),     # 1_024 parameters
  Bool[0; 0; … ; 1; 1;;;; 0; 0; … ; 1; 1;;;; 0; 0; … ; 0; 1;;;; … ;;;; 0; 0; … ; 1; 1;;;; 0; 0; … ; 1; 1;;;; 0; 0; … ; 0; 1],  # 512 parameters
)                   # Total: 5 arrays, 1_536 parameters, 4.547 KiB.

julia> xx = randn32(16, 16, 32);

julia> model(xx) |> summary
"16×16×32 Array{Float32, 3}"

julia> model(xx) |> sum
NaN32

julia> findall(isnan, model(xx))
1280-element Vector{CartesianIndex{3}}:
 CartesianIndex(1, 1, 11)
 CartesianIndex(2, 1, 11)
 CartesianIndex(3, 1, 11)
 CartesianIndex(4, 1, 11)
 CartesianIndex(5, 1, 11)

julia> loss, grads = Flux.withgradient(model) do m
         sum(abs2, m(xx))
       end
(val = NaN32, grad = ((att = (nheads = nothing, q_proj = (weight = Float32[NaN NaN … NaN NaN; NaN NaN … NaN NaN; … ; NaN NaN … NaN NaN; NaN NaN … NaN NaN], bias = nothing, σ = nothing), k_proj = (weight = Float32[NaN NaN … NaN NaN; NaN NaN … NaN NaN; … ; NaN NaN … NaN NaN; NaN NaN … NaN NaN], bias = nothing, σ = nothing), v_proj = (weight = Float32[NaN NaN … NaN NaN; NaN NaN … NaN NaN; … ; NaN NaN … NaN NaN; NaN NaN … NaN NaN], bias = nothing, σ = nothing), attn_drop = nothing, out_proj = (weight = Float32[NaN NaN … NaN NaN; NaN NaN … NaN NaN; … ; NaN NaN … NaN NaN; NaN NaN … NaN NaN], bias = nothing, σ = nothing)), mask = nothing),))

julia> mask2 = rand(Bool, 16, 16);

julia> model2 = MaskAttention(MultiHeadAttention(16), mask2);

julia> model2(xx) |> sum
15.281105f0

julia> loss, grads = Flux.withgradient(model2) do m
         sum(abs2, m(xx))
       end
(val = 2290.0107f0, grad = ((att = (nheads = nothing, q_proj = (weight = Float32[-12.446103 -4.3954763 … -23.305235 3.0878963; 17.259354 -6.7767124 … 25.798717 6.8329597; … ; -22.414658 -17.097672 … 93.836235 -9.206725; 29.134241 13.785215 … -41.797527 -23.159046], bias = nothing, σ = nothing), k_proj = (weight = Float32[10.827733 18.880678 … -35.907875 23.034206; 8.75228 23.103594 … 13.421799 -14.956886; … ; -71.6004 -13.324369 … -35.224113 61.402447; 27.011333 124.142815 … -21.26666 -63.877186], bias = nothing, σ = nothing), v_proj = (weight = Float32[34.734734 -24.169163 … 102.391365 -69.705055; -1.5541999 37.55378 … -69.58273 39.782215; … ; 5.4440746 -176.59694 … 171.69466 161.58585; -139.37288 181.93517 … -213.38739 -58.174618], bias = nothing, σ = nothing), attn_drop = nothing, out_proj = (weight = Float32[-18.867039 0.40507406 … 92.7197 -36.512943; -30.138624 -14.419439 … -92.25858 -63.47702; … ; -89.09866 92.45964 … -212.48007 164.08275; 35.601555 -31.360823 … 128.91348 -104.323494], bias = nothing, σ = nothing)), mask = nothing),))

@mashu
Contributor Author

mashu commented Apr 4, 2024

I do want this mask to vary per batch, because the sequences sampled into each batch vary in length, and so does the padding. This minimal example uses a single batch just to show the issue. I tried clamping before the softmax, and the NaNs are gone.

The idea is to mask the padding tokens out of attention in the encoder. If that is unusual, am I doing something wrong? I have three different kinds of masks: a padding mask in the encoder (this MWE), a causal mask in the decoder, and a padding mask in the loss function, which affects only the target sequence in the decoder.
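The three masks described above can be sketched side by side in plain Julia. All helper names are hypothetical, and the sketch assumes NNlib's convention that `true` means "attend"; `causal_mask` mirrors what `make_causal_mask` documents (`m[i, j] == (i ≤ j)`), but it is not that function. Combining the decoder's causal and padding masks is an elementwise `.&`, which broadcasts `(len, len)` against `(len, 1, 1, batch_size)` to `(len, len, 1, batch_size)`.

```julia
using LinearAlgebra: triu

# Hypothetical: causal mask with m[i, j] == (i <= j), so query j attends to keys 1..j.
causal_mask(len::Integer) = triu(trues(len, len))

# Hypothetical: key-padding mask of size (kv_len, 1, 1, batch_size).
padding_mask(tokens::AbstractMatrix, pad_id) =
    reshape(tokens .!= pad_id, size(tokens, 1), 1, 1, size(tokens, 2))

# Hypothetical: mean of per-token losses over non-padding target positions only.
masked_mean(loss::AbstractMatrix, keep::AbstractMatrix{Bool}) =
    sum(loss .* keep) / max(sum(keep), 1)

src = Int32[5 7; 3 2; 1 9]    # source tokens (seq_len, batch_size), pad id 1
tgt = Int32[4 6; 1 8; 1 2]    # target tokens, pad id 1

enc_mask = padding_mask(src, 1)                     # encoder self-attention: hide source padding
dec_mask = causal_mask(3) .& padding_mask(tgt, 1)   # decoder self-attention: causal and padding
tgt_keep = tgt .!= 1                                # loss mask over target positions
loss = masked_mean(Float32[1 2; 3 4; 5 6], tgt_keep)
```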

Labels
None yet
Projects
None yet
Development

No branches or pull requests

2 participants