Enzyme fails with MultiHeadAttention layer #2448

mashu opened this issue May 16, 2024 · 13 comments

mashu commented May 16, 2024

I am attaching an MWE where Zygote (Flux's default AD) works fine but Enzyme fails to compile (@wsmoses).

using Enzyme
using Flux
using CUDA

# Build a zero-valued shadow for every numeric/array leaf of the model
_make_zero(x::Union{Number,AbstractArray}) = zero(x)
_make_zero(x) = x
make_zero(model) = fmap(_make_zero, model)

# Wrap numbers as Active and everything else as Duplicated, then run reverse-mode AD
function gradient_ez(f, x...)
    args = []
    for x in x
        if x isa Number
            push!(args, Active(x))
        else
            push!(args, Duplicated(x, make_zero(x)))
        end
    end
    ret = Enzyme.autodiff(ReverseWithPrimal, f, Active, args...)
    g = ntuple(i -> x[i] isa Number ? ret[1][i] : args[i].dval, length(x))
    return g
end

x = CUDA.rand(Float32, 64, 100, 512)
mha = MultiHeadAttention(64 => 64 => 64) |> gpu

# Failing
Δ = gradient_ez(mha) do m
    sum(first(m(x, x, x)))
end

# Working
Δ = Flux.gradient(mha) do m
    sum(first(m(x, x, x)))
end
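
As an aside, newer Enzyme releases provide a recursive Enzyme.make_zero; assuming it is available in the installed version, the hand-rolled make_zero helper above could be replaced with the built-in one, roughly:

# Sketch, assuming Enzyme.make_zero exists in the installed Enzyme version;
# it builds the shadow model without the custom fmap(_make_zero, model) helper.
Duplicated(mha, Enzyme.make_zero(mha))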

mashu commented May 16, 2024

The same code, run on CUDA:

@btime CUDA.@sync Flux.gradient(mha) do m
           sum(first(m(x, x, x)))
       end
 11.983 ms (2583 allocations: 137.55 KiB)

whereas

@btime CUDA.@sync gradient_ez(mha) do m
           sum(first(m($x, $x, $x)))
       end
       
 ....
   [2] EnzymeCreateAugmentedPrimal(logic::Enzyme.Logic, todiff::LLVM.Function, retType::Enzyme.API.CDIFFE_TYPE, constant_args::Vector{…}, TA::Enzyme.TypeAnalysis, returnUsed::Bool, shadowReturnUsed::Bool, typeInfo::Enzyme.FnTypeInfo, uncacheable_args::Vector{…}, forceAnonymousTape::Bool, width::Int64, atomicAdd::Bool)
    @ Enzyme.API ~/.julia/packages/Enzyme/srACB/src/api.jl:190
  [3] enzyme!(job::GPUCompiler.CompilerJob{…}, mod::LLVM.Module, primalf::LLVM.Function, TT::Type, mode::Enzyme.API.CDerivativeMode, width::Int64, parallel::Bool, actualRetType::Type, wrap::Bool, modifiedBetween::NTuple{…}, returnPrimal::Bool, expectedTapeType::Type, loweredArgs::Set{…}, boxedArgs::Set{…})
    @ Enzyme.Compiler ~/.julia/packages/Enzyme/srACB/src/compiler.jl:3141
  [4] codegen(output::Symbol, job::GPUCompiler.CompilerJob{…}; libraries::Bool, deferred_codegen::Bool, optimize::Bool, toplevel::Bool, strip::Bool, validate::Bool, only_entry::Bool, parent_job::Nothing)
    @ Enzyme.Compiler ~/.julia/packages/Enzyme/srACB/src/compiler.jl:5074
  [5] codegen
    @ ~/.julia/packages/Enzyme/srACB/src/compiler.jl:4481 [inlined]
  [6] _thunk(job::GPUCompiler.CompilerJob{Enzyme.Compiler.EnzymeTarget, Enzyme.Compiler.EnzymeCompilerParams}, postopt::Bool)
    @ Enzyme.Compiler ~/.julia/packages/Enzyme/srACB/src/compiler.jl:5771
  [7] _thunk
    @ ~/.julia/packages/Enzyme/srACB/src/compiler.jl:5771 [inlined]
  [8] cached_compilation
    @ ~/.julia/packages/Enzyme/srACB/src/compiler.jl:5809 [inlined]
  [9] (::Enzyme.Compiler.var"#560#561"{DataType, DataType, DataType, Enzyme.API.CDerivativeMode, NTuple{4, Bool}, Int64, Bool, Bool, UInt64, DataType})(ctx::LLVM.Context)
    @ Enzyme.Compiler ~/.julia/packages/Enzyme/srACB/src/compiler.jl:5875
 [10] JuliaContext(f::Enzyme.Compiler.var"#560#561"{DataType, DataType, DataType, Enzyme.API.CDerivativeMode, NTuple{}, Int64, Bool, Bool, UInt64, DataType}; kwargs::@Kwargs{})
    @ GPUCompiler ~/.julia/packages/GPUCompiler/kqxyC/src/driver.jl:52
 [11] JuliaContext(f::Function)
    @ GPUCompiler ~/.julia/packages/GPUCompiler/kqxyC/src/driver.jl:42
 [12] #s2027#559
    @ ~/.julia/packages/Enzyme/srACB/src/compiler.jl:5827 [inlined]

So it can be reproduced with the following packages on Julia 1.10.3:

  [6e4b80f9] BenchmarkTools v1.5.0
  [052768ef] CUDA v5.3.4
  [082447d4] ChainRules v1.66.0
  [d360d2e6] ChainRulesCore v1.23.0
  [7da242da] Enzyme v0.12.6
  [587475ba] Flux v0.14.15
  [e88e6eb3] Zygote v0.6.70
  [02a925ec] cuDNN v1.3.1

Thanks!


wsmoses commented May 16, 2024

@mashu can you post the whole log?


mashu commented May 16, 2024

I was convinced I had attached it earlier, but apparently I didn't, so here it is:
MWA.log.gz
The following code was run as:

julia --project=@. src/MWA.jl 2> MWA.log

using Enzyme
using Flux
using CUDA

_make_zero(x::Union{Number,AbstractArray}) = zero(x)
_make_zero(x) = x
make_zero(model) = fmap(_make_zero, model)

function gradient_ez(f, x...)
    args = []
    for x in x
        if x isa Number
            push!(args, Active(x))
        else
            push!(args, Duplicated(x, make_zero(x)))
        end
    end
    ret = Enzyme.autodiff(ReverseWithPrimal, f, Active, args...)
    g = ntuple(i -> x[i] isa Number ? ret[1][i] : args[i].dval, length(x))
    return g
end

x = CUDA.rand(Float32, 64, 100, 512)
mha = MultiHeadAttention(64 => 64 => 64) |> gpu

Flux.gradient(mha) do m
    sum(first(m(x, x, x)))
end

Δ = gradient_ez(mha) do m
    sum(first(m(x, x, x)))
end


wsmoses commented May 16, 2024

Also, does this work on CPU?


mashu commented May 16, 2024

@wsmoses Initially I got a compilation error with the CPU version too, but after moving to a separate project (the MWE) it only fails on the GPU. Having said that, I still can't figure out why it fails in my main project, since the packages are up to date and basically the same versions. But this GPU failure is at least reproducible.
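
For reference, a sketch of a CPU-only variant of the MWE (just dropping CUDA and the gpu move), reusing the gradient_ez helper defined above:

# CPU-only variant of the MWE: no CUDA, no `gpu` move
using Enzyme
using Flux

x = rand(Float32, 64, 100, 512)
mha = MultiHeadAttention(64 => 64 => 64)

Δ = gradient_ez(mha) do m
    sum(first(m(x, x, x)))
end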


wsmoses commented May 16, 2024

GPU support is in progress, so the report is super helpful, but the failure is also presently expected.

Maybe check the current versions of the packages in your project and see if something is forcing an older Enzyme?
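
For example, the resolved versions can be checked with Pkg (illustrative snippet, not from the original thread):

using Pkg
Pkg.status("Enzyme")                                   # direct dependency
Pkg.status("Enzyme_jll"; mode = Pkg.PKGMODE_MANIFEST)  # resolved JLL in the manifest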


mashu commented May 16, 2024

It's the same version of ⌅ [7cc45869] Enzyme_jll v0.0.109+0 in both the working and the non-working project, so it must be some indirect dependency that I can't figure out.
As for the GPU part, my impression is that the CPU paths in Flux are sometimes slow and not well optimized, probably because most people use the GPU paths for any real work.


wsmoses commented May 16, 2024

Ah, but what's your Enzyme version (rather than Enzyme_jll, which is a dependency)?


mashu commented May 16, 2024

Looks the same, v0.12.6.

Working MWE, ]st:

  [6e4b80f9] BenchmarkTools v1.5.0
  [052768ef] CUDA v5.3.4
  [082447d4] ChainRules v1.66.0
  [d360d2e6] ChainRulesCore v1.23.0
  [7da242da] Enzyme v0.12.6
  [587475ba] Flux v0.14.15
  [e88e6eb3] Zygote v0.6.70
  [02a925ec] cuDNN v1.3.1

Broken one, ]st:

  [6e4b80f9] BenchmarkTools v1.5.0
  [336ed68f] CSV v0.10.14
  [052768ef] CUDA v5.3.4
  [082447d4] ChainRules v1.66.0
  [d360d2e6] ChainRulesCore v1.23.0
  [a93c6f00] DataFrames v1.6.1
  [864edb3b] DataStructures v0.18.20
  [31c24e10] Distributions v0.25.108
  [7da242da] Enzyme v0.12.6
  [c2308a5c] FASTX v2.1.5
  [587475ba] Flux v0.14.15
  [41a02a25] Folds v0.2.10
  [033835bb] JLD2 v0.4.47
  [682c06a0] JSON v0.21.4
  [e6f89c97] LoggingExtras v1.0.3
  [12afc1b8] NeuralAttentionlib v0.2.13
  [0b1bfda6] OneHotArrays v0.2.5
  [3bd65402] Optimisers v0.3.3
  [d7d3b36b] ParameterSchedulers v0.4.1
  [92933f4c] ProgressMeter v1.10.0
  [2913bbd2] StatsBase v0.34.3
  [b8865327] UnicodePlots v3.6.4
  [02a925ec] cuDNN v1.3.1
  [56ddb016] Logging


mashu commented May 16, 2024

Also including a log with the error that happens on the CPU side in the broken project; not sure if that helps though.
CPU.log


wsmoses commented May 16, 2024

From the log, I think the simplest answer here is that we should just add a custom derivative for attention in NNlib. I assume there's already a ChainRules rule for it?

If so, you can try our macro for importing a ChainRules rule into Enzyme as a test, to see if anything else fails, while in the interim we look at writing a fast native rule (imported CR rules will be slower and come with caveats).

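For illustration (hypothetical: NNlib has no rrule for its attention function at the time of this thread, and the exact signature below is an assumption), Enzyme's ChainRules import macro would be used along these lines:

using Enzyme, NNlib

# Hypothetical sketch: if NNlib gained a ChainRules rrule for dot_product_attention,
# it could be imported as an Enzyme rule roughly like this.
Enzyme.@import_rrule(typeof(NNlib.dot_product_attention),
                     AbstractArray, AbstractArray, AbstractArray)
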

mashu commented May 16, 2024

@wsmoses Long story short, I wanted to use Enzyme because I often lack the skills to write an rrule, and there is none for MultiHeadAttention in NNlib. The longer answer is that I am currently using NeuralAttentionlib.jl (part of Transformers.jl), which has the customization of the layer I need and an rrule that makes that variant of MHA a couple of times faster on GPU. My hope was that Enzyme might do a better job than Zygote in terms of the performance of the code it produces when no rrule is provided.


wsmoses commented May 16, 2024

If you can wait a short bit (it's currently unregistered and there are a bunch of small things we should add), Reactant.jl is an execution engine (e.g. it does tons of fancy optimizations/kernel fusion) that is both Enzyme- and GPU-compatible out of the box, and it might be what you're looking for.

In the interim I'll push on GPU support for native Enzyme here too, but just throwing that out there in case it's helpful.

https://github.com/EnzymeAD/Reactant.jl
