Trouble writing custom rule in Enzyme: AssertionError: !(overwritten[end]) #1242

Closed
GianlucaFuwa opened this issue Jan 18, 2024 · 10 comments · Fixed by #1258

@GianlucaFuwa

(Transferred issue from Julia discourse: https://discourse.julialang.org/t/trouble-writing-custom-rule-in-enzyme-assertionerror-overwritten-end/108939)

Hi,

I am currently trying to use Enzyme to get the derivative of a function that takes a multidimensional array of 3x3 matrices and outputs a scalar. The caveat is that these matrices are elements of the Lie-group of special unitary matrices SU(3) and therefore the derivatives w.r.t. each element in the array should be in the corresponding algebra (traceless anti-Hermitian matrices).
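For readers unfamiliar with the target space, the projection onto the algebra mentioned above can be sketched with plain `Matrix{ComplexF64}` and LinearAlgebra only. This is an illustration, not code from the issue. It uses the same formula as the `traceless_antihermitian` helper in the MWE further down, generalized to an n×n input.

```julia
using LinearAlgebra

# Project a complex n×n matrix onto the traceless anti-Hermitian matrices
# (for n = 3 this is su(3)): keep the anti-Hermitian part, then subtract
# (trace / n) times the identity. For n = 3 this equals the MWE's
# 0.5*(M - M') - 1/6*tr(M - M')*I.
function traceless_antihermitian(M::AbstractMatrix{<:Complex})
    A = (M - M') / 2                     # anti-Hermitian part: A' == -A
    return A - (tr(A) / size(M, 1)) * I  # tr(A) is imaginary, so this stays anti-Hermitian
end

M = ComplexF64[1+2im 0.5-1im 0.3im; 2 -1+1im 0.7-0.2im; 0.1+0.1im -2im 3+0.5im]
X = traceless_antihermitian(M)
@show abs(tr(X))     # should be ≈ 0 (traceless)
@show norm(X' + X)   # should be ≈ 0 (anti-Hermitian)
```

Applying the projection twice gives the same result, since X is already in the algebra.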

The function in question is:

using Enzyme
using Accessors
using LinearAlgebra
using StaticArrays
import .EnzymeRules: forward, reverse, augmented_primal
using .EnzymeRules

smove(s::CartesianIndex{4}, μ, steps, lim) = @set s[μ] = mod1(s[μ] + steps, lim)
@inline function plaquette(U, μ, ν, site)
    Nμ = size(U)[1+μ]
    Nν = size(U)[1+ν]
    siteμ⁺ = smove(site, μ, 1, Nμ)
    siteν⁺ = smove(site, ν, 1, Nν)
    return remultr(U[μ,site], U[ν,siteμ⁺], U[μ,siteν⁺], U[ν,site])
end

function plaquette_sum(U::Array{SMatrix{3,3,ComplexF64,9}, 5})
    p = 0.0

    for site in CartesianIndices(size(U)[2:end])
        for μ in 1:3
            for ν in μ+1:4
                p += plaquette(U, μ, ν, site)
            end
        end
    end

    return 6.0 * (6*prod(size(U)[2:end]) - 1/3*p)
end
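Reading the loop bounds off plaquette_sum gives the closed form below. It is only a restatement of the code above, with x + μ̂ taken periodically via mod1. The inner loops visit the 6 pairs (μ,ν) = (1,2), (1,3), (1,4), (2,3), (2,4), (3,4), which is why plaquette_sum returns 0 for the identity configuration used as test input.

```latex
\begin{aligned}
p &= \sum_{x}\;\sum_{1 \le \mu < \nu \le 4}
     \operatorname{Re}\operatorname{tr}\!\left[\,U_\mu(x)\,U_\nu(x+\hat\mu)\,U_\mu(x+\hat\nu)\,U_\nu(x)\,\right],\\
\texttt{plaquette\_sum}(U) &= 6\left(6V - \tfrac{1}{3}\,p\right),
\qquad V = \prod_{i=2}^{5}\operatorname{size}(U, i),\\[4pt]
U_\mu(x) = \mathbb{1}\ \ \forall\,x,\mu
&\;\Rightarrow\; p = V \cdot 6 \cdot \operatorname{tr}\mathbb{1} = 18V
\;\Rightarrow\; \texttt{plaquette\_sum}(U) = 6\,(6V - 6V) = 0.
\end{aligned}
```

For the 4×4×4×4×4 array used below, V = 4^4 = 256.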

I've written a reverse-mode rule for remultr (defined further below), directly following the example in the docs, and it works well enough (and without error) when differentiating remultr on its own. I also overloaded the gradient! function for plaquette_sum, which results in an error:

@inline function Base.circshift(shift::Integer, args::Vararg{T, N}) where {N,T}
    j = mod1(shift, N)
    ntuple(k -> args[k-j+ifelse(k>j,0,N)], Val(N))
end

function augmented_primal(config::ConfigWidth{1}, func::Const{typeof(remultr)},
    ::Type{<:Active}, args::Vararg{Active,N}) where {N}

    argvals = ntuple(i -> args[i].val, Val(N))
    if needs_primal(config)
        primal = func.val(argvals...)
    else
        primal = nothing
    end
    if overwritten(config)[3]
        tape = copy(argvals)
    else
        tape = nothing
    end

    return AugmentedReturn(primal, nothing, tape)
end

function EnzymeRules.reverse(config::ConfigWidth{1}, func::Const{typeof(remultr)},
    dret::Active, tape, args::Vararg{Active,N}) where {N}

    argvals = ntuple(i -> args[i].val, Val(N))
    dargs = ntuple(Val(N)) do i
        0.5traceless_antihermitian(*(circshift(i-1, argvals...)...))
    end
    return dargs
end

function Enzyme.gradient(::ReverseMode, f::typeof(remultr), args::Vararg{T,N}) where {N,T}
    annots = ntuple(i -> Active(args[i]), Val(N))
    der = autodiff(Reverse, f, Active, annots...)
    return der
end

function Enzyme.gradient!(::ReverseMode, dU::Array{SMatrix{3,3,ComplexF64,9},5},
    f::typeof(plaquette_sum), U::Array{SMatrix{3,3,ComplexF64,9},5})

    autodiff(Reverse, f, Active, DuplicatedNoNeed(U, dU))
    return nothing
end

# Some necessary definitions 
const zero3 = @SArray [
    0.0+0.0im 0.0+0.0im 0.0+0.0im
    0.0+0.0im 0.0+0.0im 0.0+0.0im
    0.0+0.0im 0.0+0.0im 0.0+0.0im
]

const eye3 = @SArray [
    1.0+0.0im 0.0+0.0im 0.0+0.0im
    0.0+0.0im 1.0+0.0im 0.0+0.0im
    0.0+0.0im 0.0+0.0im 1.0+0.0im
]

traceless_antihermitian(M::SMatrix{3,3,ComplexF64,9}) = 0.5*(M - M') - 1/6*tr(M - M')*eye3

U = Array{SMatrix{3, 3, ComplexF64, 9}, 5}(undef, 4, 4, 4, 4, 4); fill!(U, eye3); # Should be an array of special unitary matrices, but identities will do for now
dU = similar(U); fill!(dU, zero3);
Enzyme.gradient!(Reverse, dU, plaquette_sum, U) # -> AssertionError: !(overwritten[end])
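The reverse rule above relies on two facts. First, circshift(shift, args...) rotates the argument tuple. Second, the real trace of a product is invariant under cyclic rotation, so the derivative with respect to one factor is the product of the other factors in cyclic order.

The sketch below checks both facts with Base and LinearAlgebra only. `rotate_args` and `retr` are hypothetical names. `rotate_args` uses the same index arithmetic as the issue's circshift, renamed so no method is added to Base. It does not check the rule's particular ordering or its 0.5 prefactor.

```julia
using LinearAlgebra, Random

# Same index arithmetic as the issue's `Base.circshift(shift, args...)`, under a
# hypothetical name so no method is added to Base.
@inline function rotate_args(shift::Integer, args::Vararg{T,N}) where {N,T}
    j = mod1(shift, N)
    ntuple(k -> args[k - j + ifelse(k > j, 0, N)], Val(N))
end

# shift = 0 is the identity; shift = 1 moves the last argument to the front.
@show rotate_args(0, :a, :b, :c, :d)   # (:a, :b, :c, :d)
@show rotate_args(1, :a, :b, :c, :d)   # (:d, :a, :b, :c)

# Hypothetical stand-in for remultr on plain matrices.
retr(ms...) = real(tr(*(ms...)))

Random.seed!(1)
Us = ntuple(_ -> randn(ComplexF64, 3, 3), 4)

# Cyclic invariance: every rotated product has the same real trace.
@show retr(Us...) ≈ retr(rotate_args(1, Us...)...)

# Re tr(U1 U2 U3 U4) is linear in U2, and by cyclicity its derivative in
# direction H is Re tr(U3 U4 U1 H). A central difference confirms this.
H = randn(ComplexF64, 3, 3)
ε = 1e-3
fd = (retr(Us[1], Us[2] + ε * H, Us[3], Us[4]) -
      retr(Us[1], Us[2] - ε * H, Us[3], Us[4])) / (2ε)
analytic = real(tr(Us[3] * Us[4] * Us[1] * H))
@show isapprox(fd, analytic; rtol = 1e-8, atol = 1e-8)
```

Because the function is linear in each factor, the central difference has no truncation error. A fairly large ε therefore only reduces roundoff.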

Since writing the discourse post, I found that whether I get the error on gradient!(...) depends on how I define the remultr function.
When I use

remultr(args...) = real(tr(*(args...)))

I get no error, but if I use:

@generated function remultr(args::Vararg{T, N}) where {N,T}
    quote
        $(Expr(:meta, :inline))
        tmp = *(args...)
        real(tr(tmp))
    end
end

the error appears. The second, @generated definition exists because I wrote a custom matrix-multiplication routine using LoopVectorization.jl's @turbo in a package I want to put all of this into. Ultimately, I should be able to use a definition of remultr similar to the first one and be fine, but I still wanted to report this here in case it is useful. I hope this is enough to reproduce the error.
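The two remultr definitions above differ only in how the method body is produced, not in what it computes. The check below uses plain matrices. `remultr_plain` and `remultr_gen` are hypothetical names chosen so that both versions can coexist in one session.

```julia
using LinearAlgebra

# Plain varargs definition (the variant that did not hit the assertion).
remultr_plain(args...) = real(tr(*(args...)))

# @generated variant with the same body (the one that did hit it). The
# generator just returns this expression; it does not depend on N or T.
@generated function remultr_gen(args::Vararg{T,N}) where {N,T}
    quote
        $(Expr(:meta, :inline))
        tmp = *(args...)
        real(tr(tmp))
    end
end

Id3 = Matrix{ComplexF64}(I, 3, 3)
@show remultr_gen(Id3, Id3, Id3, Id3)   # 3.0, i.e. Re tr of the 3×3 identity

A = ComplexF64[1 2im 0; 0 1 3; 1im 0 1]
B = ComplexF64[0 1 0; 1im 0 0; 0 0 2]
@show remultr_plain(A, B, A, B) ≈ remultr_gen(A, B, A, B)
```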

P.S.: Here is the complete error message:

ERROR: LoadError: AssertionError: !(overwritten[end])
Stacktrace:
  [1] enzyme_custom_setup_args(B::LLVM.IRBuilder, orig::LLVM.CallInst, gutils::Enzyme.Compiler.GradientUtils, mi::Core.MethodInstance, RT::Type, reverse::Bool, isKWCall::Bool)
    @ Enzyme.Compiler C:\Users\gianl\.julia\packages\Enzyme\RiUxJ\src\compiler.jl:3865
  [2] enzyme_custom_common_rev(forward::Bool, B::LLVM.IRBuilder, orig::LLVM.CallInst, gutils::Enzyme.Compiler.GradientUtils, normalR::Ptr{Ptr{LLVM.API.LLVMOpaqueValue}}, shadowR::Ptr{Ptr{LLVM.API.LLVMOpaqueValue}}, tape::Nothing)
    @ Enzyme.Compiler C:\Users\gianl\.julia\packages\Enzyme\RiUxJ\src\compiler.jl:4223
  [3] enzyme_custom_augfwd
    @ C:\Users\gianl\.julia\packages\Enzyme\RiUxJ\src\compiler.jl:4582 [inlined]
  [4] (::Enzyme.Compiler.var"#192#193")(B::Ptr{LLVM.API.LLVMOpaqueBuilder}, OrigCI::Ptr{LLVM.API.LLVMOpaqueValue}, gutils::Ptr{Nothing}, normalR::Ptr{Ptr{LLVM.API.LLVMOpaqueValue}}, shadowR::Ptr{Ptr{LLVM.API.LLVMOpaqueValue}}, tapeR::Ptr{Ptr{LLVM.API.LLVMOpaqueValue}})
    @ Enzyme.Compiler C:\Users\gianl\.julia\packages\Enzyme\RiUxJ\src\compiler.jl:6449
  [5] EnzymeCreatePrimalAndGradient(logic::Enzyme.Logic, todiff::LLVM.Function, retType::Enzyme.API.CDIFFE_TYPE, constant_args::Vector{Enzyme.API.CDIFFE_TYPE}, TA::Enzyme.TypeAnalysis, returnValue::Bool, dretUsed::Bool, mode::Enzyme.API.CDerivativeMode, width::Int64, additionalArg::Ptr{Nothing}, forceAnonymousTape::Bool, typeInfo::Enzyme.FnTypeInfo, uncacheable_args::Vector{Bool}, augmented::Ptr{Nothing}, atomicAdd::Bool)
    @ Enzyme.API C:\Users\gianl\.julia\packages\Enzyme\RiUxJ\src\api.jl:128
  [6] enzyme!(job::GPUCompiler.CompilerJob{Enzyme.Compiler.EnzymeTarget, Enzyme.Compiler.EnzymeCompilerParams}, mod::LLVM.Module, primalf::LLVM.Function, TT::Type, mode::Enzyme.API.CDerivativeMode, width::Int64, parallel::Bool, actualRetType::Type, wrap::Bool, modifiedBetween::Tuple{Bool, Bool}, returnPrimal::Bool, jlrules::Vector{String}, expectedTapeType::Type)
    @ Enzyme.Compiler C:\Users\gianl\.julia\packages\Enzyme\RiUxJ\src\compiler.jl:7451
  [7] codegen(output::Symbol, job::GPUCompiler.CompilerJob{Enzyme.Compiler.EnzymeTarget, Enzyme.Compiler.EnzymeCompilerParams}; libraries::Bool, deferred_codegen::Bool, optimize::Bool, toplevel::Bool, ctx::LLVM.ThreadSafeContext, strip::Bool, validate::Bool, only_entry::Bool, parent_job::Nothing)
    @ Enzyme.Compiler C:\Users\gianl\.julia\packages\Enzyme\RiUxJ\src\compiler.jl:8984
  [8] codegen
    @ C:\Users\gianl\.julia\packages\Enzyme\RiUxJ\src\compiler.jl:8592 [inlined]
  [9] _thunk(job::GPUCompiler.CompilerJob{Enzyme.Compiler.EnzymeTarget, Enzyme.Compiler.EnzymeCompilerParams}, ctx::Nothing, postopt::Bool)
    @ Enzyme.Compiler C:\Users\gianl\.julia\packages\Enzyme\RiUxJ\src\compiler.jl:9518
 [10] _thunk
    @ C:\Users\gianl\.julia\packages\Enzyme\RiUxJ\src\compiler.jl:9515 [inlined]
 [11] cached_compilation
    @ C:\Users\gianl\.julia\packages\Enzyme\RiUxJ\src\compiler.jl:9553 [inlined]
 [12] #s291#456
    @ C:\Users\gianl\.julia\packages\Enzyme\RiUxJ\src\compiler.jl:9615 [inlined]
 [13] var"#s291#456"(FA::Any, A::Any, TT::Any, Mode::Any, ModifiedBetween::Any, width::Any, ReturnPrimal::Any, ShadowInit::Any, World::Any, ABI::Any, ::Any, #unused#::Type, #unused#::Type, #unused#::Type, tt::Any, #unused#::Type, #unused#::Type, #unused#::Type, #unused#::Type, #unused#::Type, #unused#::Any)
    @ Enzyme.Compiler .\none:0
 [14] (::Core.GeneratedFunctionStub)(::Any, ::Vararg{Any})
    @ Core .\boot.jl:602
 [15] autodiff
    @ C:\Users\gianl\.julia\packages\Enzyme\RiUxJ\src\Enzyme.jl:207 [inlined]
 [16] autodiff
    @ C:\Users\gianl\.julia\packages\Enzyme\RiUxJ\src\Enzyme.jl:222 [inlined]
 [17] gradient!(#unused#::ReverseMode{false, FFIABI}, dU::Array{SMatrix{3, 3, ComplexF64, 9}, 5}, f::typeof(plaquette_sum), U::Array{SMatrix{3, 3, ComplexF64, 9}, 5})
    @ Main C:\Users\gianl\.julia\dev\MetaQCD\test\mwe_enzyme.jl:110 
 [18] top-level scope
    @ C:\Users\gianl\.julia\dev\MetaQCD\test\mwe_enzyme.jl:116      
 [19] include(fname::String)
    @ Base.MainInclude .\client.jl:478
 [20] top-level scope
    @ REPL[15]:1
@wsmoses
Member

wsmoses commented Jan 18, 2024

What version of Enzyme are you on? Can you share the output of ] st and versioninfo()?

@GianlucaFuwa
Author

What version of Enzyme are you on? Can you share the output of ] st and versioninfo()?

(@v1.9) pkg> status Enzyme
Status `C:\Users\gianl\.julia\environments\v1.9\Project.toml`
⌃ [7da242da] Enzyme v0.11.4
Info Packages marked with ⌃ have new versions available and may be upgradable.
julia> versioninfo()
Julia Version 1.9.4
Commit 8e5136fa29 (2023-11-14 08:46 UTC) 
Build Info:
  Official https://julialang.org/ release
Platform Info:
  OS: Windows (x86_64-w64-mingw32)       
  CPU: 4 × Intel(R) Core(TM) i5-7600K CPU @ 3.80GHz
  WORD_SIZE: 64
  LIBM: libopenlibm
  LLVM: libLLVM-14.0.6 (ORCJIT, skylake)
  Threads: 1 on 4 virtual cores

@wsmoses
Member

wsmoses commented Jan 18, 2024

You're on a very old version of Enzyme; can you update?

@GianlucaFuwa
Author

GianlucaFuwa commented Jan 18, 2024

Yes, I will update and see if I still get an error.

Edit:
Sorry for the delay. My Julia environment has some compat issues that are holding Enzyme back from updating past v0.11.4. I will try to create a clean environment and report back as soon as possible.

@GianlucaFuwa
Author

I've now tested on version 0.11.12 and the error persists:

(@v1.9) pkg> status Enzyme
Status `C:\Users\gianl\.julia\environments\v1.9\Project.toml`
  [7da242da] Enzyme v0.11.12
ERROR: LoadError: AssertionError: !(overwritten[end])
Stacktrace:
  [1] enzyme_custom_setup_args(B::LLVM.IRBuilder, orig::LLVM.CallInst, gutils::Enzyme.Compiler.GradientUtils, mi::Core.MethodInstance, RT::Type, reverse::Bool, isKWCall::Bool)
    @ Enzyme.Compiler C:\Users\gianl\.julia\packages\Enzyme\Dd2LU\src\rules\customrules.jl:131
  [2] enzyme_custom_common_rev(forward::Bool, B::LLVM.IRBuilder, orig::LLVM.CallInst, gutils::Enzyme.Compiler.GradientUtils, normalR::Ptr{Ptr{LLVM.API.LLVMOpaqueValue}}, shadowR::Ptr{Ptr{LLVM.API.LLVMOpaqueValue}}, tape::Nothing)
    @ Enzyme.Compiler C:\Users\gianl\.julia\packages\Enzyme\Dd2LU\src\rules\customrules.jl:490
  [3] enzyme_custom_augfwd
    @ C:\Users\gianl\.julia\packages\Enzyme\Dd2LU\src\rules\customrules.jl:886 [inlined]
  [4] (::Enzyme.Compiler.var"#212#213")(B::Ptr{LLVM.API.LLVMOpaqueBuilder}, OrigCI::Ptr{LLVM.API.LLVMOpaqueValue}, gutils::Ptr{Nothing}, normalR::Ptr{Ptr{LLVM.API.LLVMOpaqueValue}}, shadowR::Ptr{Ptr{LLVM.API.LLVMOpaqueValue}}, tapeR::Ptr{Ptr{LLVM.API.LLVMOpaqueValue}})
    @ Enzyme.Compiler C:\Users\gianl\.julia\packages\Enzyme\Dd2LU\src\rules\llvmrules.jl:1139
  [5] EnzymeCreatePrimalAndGradient(logic::Enzyme.Logic, todiff::LLVM.Function, retType::Enzyme.API.CDIFFE_TYPE, constant_args::Vector{Enzyme.API.CDIFFE_TYPE}, TA::Enzyme.TypeAnalysis, returnValue::Bool, dretUsed::Bool, mode::Enzyme.API.CDerivativeMode, width::Int64, additionalArg::Ptr{Nothing}, forceAnonymousTape::Bool, typeInfo::Enzyme.FnTypeInfo, uncacheable_args::Vector{Bool}, augmented::Ptr{Nothing}, atomicAdd::Bool)
    @ Enzyme.API C:\Users\gianl\.julia\packages\Enzyme\Dd2LU\src\api.jl:141
  [6] enzyme!(job::GPUCompiler.CompilerJob{Enzyme.Compiler.EnzymeTarget, Enzyme.Compiler.EnzymeCompilerParams}, mod::LLVM.Module, primalf::LLVM.Function, TT::Type, mode::Enzyme.API.CDerivativeMode, width::Int64, parallel::Bool, actualRetType::Type, wrap::Bool, modifiedBetween::Tuple{Bool, Bool}, returnPrimal::Bool, jlrules::Vector{String}, expectedTapeType::Type, loweredArgs::Set{Int64}, boxedArgs::Set{Int64})
    @ Enzyme.Compiler C:\Users\gianl\.julia\packages\Enzyme\Dd2LU\src\compiler.jl:3124 
  [7] codegen(output::Symbol, job::GPUCompiler.CompilerJob{Enzyme.Compiler.EnzymeTarget, Enzyme.Compiler.EnzymeCompilerParams}; libraries::Bool, deferred_codegen::Bool, optimize::Bool, toplevel::Bool, strip::Bool, validate::Bool, only_entry::Bool, parent_job::Nothing)
    @ Enzyme.Compiler C:\Users\gianl\.julia\packages\Enzyme\Dd2LU\src\compiler.jl:4756 
  [8] codegen
    @ C:\Users\gianl\.julia\packages\Enzyme\Dd2LU\src\compiler.jl:4339 [inlined]       
  [9] _thunk(job::GPUCompiler.CompilerJob{Enzyme.Compiler.EnzymeTarget, Enzyme.Compiler.EnzymeCompilerParams}, postopt::Bool)
    @ Enzyme.Compiler C:\Users\gianl\.julia\packages\Enzyme\Dd2LU\src\compiler.jl:5351 
 [10] _thunk
    @ C:\Users\gianl\.julia\packages\Enzyme\Dd2LU\src\compiler.jl:5351 [inlined]       
 [11] cached_compilation
    @ C:\Users\gianl\.julia\packages\Enzyme\Dd2LU\src\compiler.jl:5385 [inlined]       
 [12] (::Enzyme.Compiler.var"#506#507"{DataType, UnionAll, DataType, Enzyme.API.CDerivativeMode, Tuple{Bool, Bool}, Int64, Bool, Bool, UInt64, DataType})(ctx::LLVM.Context)  
    @ Enzyme.Compiler C:\Users\gianl\.julia\packages\Enzyme\Dd2LU\src\compiler.jl:5451 
 [13] JuliaContext(f::Enzyme.Compiler.var"#506#507"{DataType, UnionAll, DataType, Enzyme.API.CDerivativeMode, Tuple{Bool, Bool}, Int64, Bool, Bool, UInt64, DataType})        
    @ GPUCompiler C:\Users\gianl\.julia\packages\GPUCompiler\YO8Uj\src\driver.jl:47    
 [14] #s1056#505
    @ C:\Users\gianl\.julia\packages\Enzyme\Dd2LU\src\compiler.jl:5403 [inlined]       
 [15] var"#s1056#505"(FA::Any, A::Any, TT::Any, Mode::Any, ModifiedBetween::Any, width::Any, ReturnPrimal::Any, ShadowInit::Any, World::Any, ABI::Any, ::Any, #unused#::Type, #unused#::Type, #unused#::Type, tt::Any, #unused#::Type, #unused#::Type, #unused#::Type, #unused#::Type, #unused#::Type, #unused#::Any)
    @ Enzyme.Compiler .\none:0
 [16] (::Core.GeneratedFunctionStub)(::Any, ::Vararg{Any})
    @ Core .\boot.jl:602
 [17] autodiff
    @ C:\Users\gianl\.julia\packages\Enzyme\Dd2LU\src\Enzyme.jl:209 [inlined]
 [18] autodiff
    @ C:\Users\gianl\.julia\packages\Enzyme\Dd2LU\src\Enzyme.jl:224 [inlined]
 [19] gradient!(#unused#::ReverseMode{false, FFIABI}, dU::Array{SMatrix{3, 3, ComplexF64, 9}, 5}, f::typeof(plaquette_sum), U::Array{SMatrix{3, 3, ComplexF64, 9}, 5})        
    @ Main C:\Users\gianl\.julia\dev\MetaQCD\clean_env\mwe_enzyme.jl:110
 [20] top-level scope
    @ C:\Users\gianl\.julia\dev\MetaQCD\clean_env\mwe_enzyme.jl:116
 [21] include(fname::String)
    @ Base.MainInclude .\client.jl:478
 [22] top-level scope
    @ REPL[1]:1

@wsmoses
Member

wsmoses commented Jan 27, 2024

So I believe that this error can likely be removed (and is in fact an extraneous assertion). However, I'm still investigating.

It would be really helpful to have a self-contained example (e.g. no external packages) if you're able to make a reproducer like that.

@GianlucaFuwa
Author

The Accessors and LinearAlgebra dependencies were easy to get rid of, but removing StaticArrays would be considerably more of a headache, so I haven't done that for now. If it turns out to be really necessary, I'm open to doing so.

using Enzyme
using StaticArrays
import .EnzymeRules: forward, reverse, augmented_primal
using .EnzymeRules

const zero3 = @SArray [
    0.0+0.0im 0.0+0.0im 0.0+0.0im
    0.0+0.0im 0.0+0.0im 0.0+0.0im
    0.0+0.0im 0.0+0.0im 0.0+0.0im
]

const eye3 = @SArray [
    1.0+0.0im 0.0+0.0im 0.0+0.0im
    0.0+0.0im 1.0+0.0im 0.0+0.0im
    0.0+0.0im 0.0+0.0im 1.0+0.0im
]

const unitvec = ((1, 0, 0, 0), (0, 1, 0, 0), (0, 0, 1, 0), (0, 0, 0, 1))

# U is filled with random special unitary matrices, but would make MWE too long
U = Array{SMatrix{3, 3, ComplexF64, 9}, 5}(undef, 4, 4, 4, 4, 4); fill!(U, eye3);
dU = similar(U); fill!(dU, zero3);

traceless_antihermitian(M::SMatrix{3,3,ComplexF64,9}) = 0.5*(M - M') - 1/6*tr(M - M')*eye3
tr(M::SMatrix{3,3,ComplexF64,9}) = M[1, 1] + M[2, 2] + M[3, 3]

@inline function smove(s::CartesianIndex{4}, μ, steps, lim)
    newI = mod1.(s.I .+ steps .* unitvec[μ], lim)
    return CartesianIndex(newI)
end

@generated function remultr(args::Vararg{T, N}) where {N,T}
    quote
        $(Expr(:meta, :inline))
        tmp = *(args...)
        real(tr(tmp))
    end
end

@inline function plaquette(U, μ, ν, site)
    Nμ = size(U)[1+μ]
    Nν = size(U)[1+ν]
    siteμ⁺ = smove(site, μ, 1, Nμ)
    siteν⁺ = smove(site, ν, 1, Nν)
    return remultr(U[μ,site], U[ν,siteμ⁺], U[μ,siteν⁺], U[ν,site])
end

function plaquette_sum(U::Array{SMatrix{3,3,ComplexF64,9}, 5})
    p = 0.0

    for site in CartesianIndices(size(U)[2:end])
        for μ in 1:3
            for ν in μ+1:4
                p += plaquette(U, μ, ν, site)
            end
        end
    end

    return 6.0 * (6*prod(size(U)[2:end]) - 1/3*p)
end

### RULES FOR REMULTR ###
@inline function Base.circshift(shift::Integer, args::Vararg{T, N}) where {N,T}
    j = mod1(shift, N)
    ntuple(k -> args[k-j+ifelse(k>j,0,N)], Val(N))
end

function augmented_primal(config::ConfigWidth{1}, func::Const{typeof(remultr)},
    ::Type{<:Active}, args::Vararg{Active,N}) where {N}

    argvals = ntuple(i -> args[i].val, Val(N))
    if needs_primal(config)
        primal = func.val(argvals...)
    else
        primal = nothing
    end
    if overwritten(config)[3]
        tape = copy(argvals)
    else
        tape = nothing
    end

    return AugmentedReturn(primal, nothing, tape)
end

function EnzymeRules.reverse(config::ConfigWidth{1}, func::Const{typeof(remultr)},
    dret::Active, tape, args::Vararg{Active,N}) where {N}

    argvals = ntuple(i -> args[i].val, Val(N))
    dargs = ntuple(Val(N)) do i
        0.5traceless_antihermitian(*(circshift(i-1, argvals...)...))
    end
    return dargs
end

function Enzyme.gradient(::ReverseMode, f::typeof(remultr), args::Vararg{T,N}) where {N,T}
    annots = ntuple(i -> Active(args[i]), Val(N))
    der = autodiff(Reverse, f, Active, annots...)
    return der
end

### RULES FOR PLAQUETTE_SUM ###
function Enzyme.gradient!(::ReverseMode, dU::Array{SMatrix{3,3,ComplexF64,9},5},
    f::typeof(plaquette_sum), U::Array{SMatrix{3,3,ComplexF64,9},5})

    autodiff(Reverse, f, Active, DuplicatedNoNeed(U, dU))
    return nothing
end

# matrices = [eye3 for _ in 1:4] # should be any random special unitary matrices
# dm = Enzyme.gradient(Reverse, remultr, matrices...) # works as wanted
Enzyme.gradient!(Reverse, dU, plaquette_sum, U) # AssertionError: !(overwritten[end])
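The Accessors-free smove in the MWE above replaces `@set s[μ] = mod1(s[μ] + steps, lim)` with a broadcast over a unit vector. The check below exercises its periodic wrap-around. The two definitions are copied verbatim from the MWE, and only Base is needed.

```julia
# Copied verbatim from the MWE above.
const unitvec = ((1, 0, 0, 0), (0, 1, 0, 0), (0, 0, 1, 0), (0, 0, 0, 1))

@inline function smove(s::CartesianIndex{4}, μ, steps, lim)
    newI = mod1.(s.I .+ steps .* unitvec[μ], lim)
    return CartesianIndex(newI)
end

site = CartesianIndex(4, 2, 3, 1)
@show smove(site, 1, 1, 4)    # CartesianIndex(1, 2, 3, 1): 5 wraps to 1
@show smove(site, 3, 1, 4)    # CartesianIndex(4, 2, 4, 1)
@show smove(site, 4, -1, 4)   # CartesianIndex(4, 2, 3, 4): 0 wraps to 4
```

There is one difference from the Accessors version. The broadcast applies mod1(x, lim) to every component, so the two versions agree only when all the other coordinates are ≤ lim. That holds on the hypercubic 4^4 lattice used here. On an anisotropic lattice, moving CartesianIndex(1, 8, 1, 1) along direction 1 with lim = 4 would also wrap the second coordinate to 4.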

I still get the error, but with a slightly different stack trace:

ERROR: AssertionError: !(overwritten[end])
Stacktrace:
  [1] enzyme_custom_setup_args(B::LLVM.IRBuilder, orig::LLVM.CallInst, gutils::Enzyme.Compiler.GradientUtils, mi::Core.MethodInstance, RT::Type, reverse::Bool, isKWCall::Bool)
    @ Enzyme.Compiler C:\Users\gianl\.julia\packages\Enzyme\Dd2LU\src\rules\customrules.jl:131
  [2] enzyme_custom_common_rev(forward::Bool, B::LLVM.IRBuilder, orig::LLVM.CallInst, gutils::Enzyme.Compiler.GradientUtils, normalR::Ptr{Ptr{LLVM.API.LLVMOpaqueValue}}, shadowR::Ptr{Ptr{LLVM.API.LLVMOpaqueValue}}, tape::Nothing)
    @ Enzyme.Compiler C:\Users\gianl\.julia\packages\Enzyme\Dd2LU\src\rules\customrules.jl:490
  [3] enzyme_custom_augfwd
    @ C:\Users\gianl\.julia\packages\Enzyme\Dd2LU\src\rules\customrules.jl:886 [inlined]
  [4] (::Enzyme.Compiler.var"#212#213")(B::Ptr{LLVM.API.LLVMOpaqueBuilder}, OrigCI::Ptr{LLVM.API.LLVMOpaqueValue}, gutils::Ptr{Nothing}, normalR::Ptr{Ptr{LLVM.API.LLVMOpaqueValue}}, shadowR::Ptr{Ptr{LLVM.API.LLVMOpaqueValue}}, tapeR::Ptr{Ptr{LLVM.API.LLVMOpaqueValue}})
    @ Enzyme.Compiler C:\Users\gianl\.julia\packages\Enzyme\Dd2LU\src\rules\llvmrules.jl:1139
  [5] EnzymeCreatePrimalAndGradient(logic::Enzyme.Logic, todiff::LLVM.Function, retType::Enzyme.API.CDIFFE_TYPE, constant_args::Vector{Enzyme.API.CDIFFE_TYPE}, TA::Enzyme.TypeAnalysis, returnValue::Bool, dretUsed::Bool, mode::Enzyme.API.CDerivativeMode, width::Int64, additionalArg::Ptr{Nothing}, forceAnonymousTape::Bool, typeInfo::Enzyme.FnTypeInfo, uncacheable_args::Vector{Bool}, augmented::Ptr{Nothing}, atomicAdd::Bool)
    @ Enzyme.API C:\Users\gianl\.julia\packages\Enzyme\Dd2LU\src\api.jl:141      
  [6] enzyme!(job::GPUCompiler.CompilerJob{Enzyme.Compiler.EnzymeTarget, Enzyme.Compiler.EnzymeCompilerParams}, mod::LLVM.Module, primalf::LLVM.Function, TT::Type, mode::Enzyme.API.CDerivativeMode, width::Int64, parallel::Bool, actualRetType::Type, wrap::Bool, modifiedBetween::Tuple{Bool, Bool}, returnPrimal::Bool, jlrules::Vector{String}, expectedTapeType::Type, loweredArgs::Set{Int64}, boxedArgs::Set{Int64})
    @ Enzyme.Compiler C:\Users\gianl\.julia\packages\Enzyme\Dd2LU\src\compiler.jl:3124
  [7] codegen(output::Symbol, job::GPUCompiler.CompilerJob{Enzyme.Compiler.EnzymeTarget, Enzyme.Compiler.EnzymeCompilerParams}; libraries::Bool, deferred_codegen::Bool, optimize::Bool, toplevel::Bool, strip::Bool, validate::Bool, only_entry::Bool, parent_job::Nothing)
    @ Enzyme.Compiler C:\Users\gianl\.julia\packages\Enzyme\Dd2LU\src\compiler.jl:4756
  [8] codegen
    @ C:\Users\gianl\.julia\packages\Enzyme\Dd2LU\src\compiler.jl:4339 [inlined] 
  [9] _thunk(job::GPUCompiler.CompilerJob{Enzyme.Compiler.EnzymeTarget, Enzyme.Compiler.EnzymeCompilerParams}, postopt::Bool)
    @ Enzyme.Compiler C:\Users\gianl\.julia\packages\Enzyme\Dd2LU\src\compiler.jl:5351
 [10] _thunk
    @ C:\Users\gianl\.julia\packages\Enzyme\Dd2LU\src\compiler.jl:5351 [inlined] 
 [11] cached_compilation
    @ C:\Users\gianl\.julia\packages\Enzyme\Dd2LU\src\compiler.jl:5385 [inlined] 
 [12] (::Enzyme.Compiler.var"#506#507"{DataType, UnionAll, DataType, Enzyme.API.CDerivativeMode, Tuple{Bool, Bool}, Int64, Bool, Bool, UInt64, DataType})(ctx::LLVM.Context)
    @ Enzyme.Compiler C:\Users\gianl\.julia\packages\Enzyme\Dd2LU\src\compiler.jl:5451
 [13] JuliaContext(f::Enzyme.Compiler.var"#506#507"{DataType, UnionAll, DataType, Enzyme.API.CDerivativeMode, Tuple{Bool, Bool}, Int64, Bool, Bool, UInt64, DataType})
    @ GPUCompiler C:\Users\gianl\.julia\packages\GPUCompiler\U36Ed\src\driver.jl:47
 [14] #s1056#505
    @ C:\Users\gianl\.julia\packages\Enzyme\Dd2LU\src\compiler.jl:5403 [inlined] 
 [15] var"#s1056#505"(FA::Any, A::Any, TT::Any, Mode::Any, ModifiedBetween::Any, width::Any, ReturnPrimal::Any, ShadowInit::Any, World::Any, ABI::Any, ::Any, #unused#::Type, #unused#::Type, #unused#::Type, tt::Any, #unused#::Type, #unused#::Type, #unused#::Type, #unused#::Type, #unused#::Type, #unused#::Any)
    @ Enzyme.Compiler .\none:0
 [16] (::Core.GeneratedFunctionStub)(::Any, ::Vararg{Any})
    @ Core .\boot.jl:602
 [17] autodiff
    @ C:\Users\gianl\.julia\packages\Enzyme\Dd2LU\src\Enzyme.jl:209 [inlined]    
 [18] autodiff
    @ C:\Users\gianl\.julia\packages\Enzyme\Dd2LU\src\Enzyme.jl:224 [inlined]    
 [19] gradient!(#unused#::ReverseMode{false, FFIABI}, dU::Array{SMatrix{3, 3, ComplexF64, 9}, 5}, f::typeof(plaquette_sum), U::Array{SMatrix{3, 3, ComplexF64, 9}, 5})
    @ Main c:\Users\gianl\.julia\dev\MetaQCD\clean_env\mwe_enzyme.jl:106
 [20] top-level scope
    @ c:\Users\gianl\.julia\dev\MetaQCD\clean_env\mwe_enzyme.jl:112

@wsmoses
Member

wsmoses commented Jan 28, 2024

using Enzyme
Enzyme.API.printall!(true)

import .EnzymeRules: forward, reverse, augmented_primal
using .EnzymeRules

function remultr(arg)
    real(arg)
end

function augmented_primal(config::ConfigWidth{1}, func::Const{typeof(remultr)},
    ::Type{<:Active}, args::Vararg{Active,N}) where {N}
    return AugmentedReturn(func.val(args[1].val), nothing, nothing)
end

function EnzymeRules.reverse(config::ConfigWidth{1}, func::Const{typeof(remultr)},
    dret::Active, tape, args::Vararg{Active,N}) where {N}

    dargs = ntuple(Val(N)) do i
        0.5
    end
    return dargs
end

# U is filled with random special unitary matrices, but would make MWE too long
U = Array{Complex{Float64}, 2}(undef, 4, 4)
dU = similar(U)

function plaquette_sum(U)
    p = 0.0

    for site in CartesianIndices(size(U)[2:end])
        p += remultr(@inbounds U[1,site]) # , U[1,site])
    end

    return p
end

@show collect(CartesianIndices(size(U)[2:end]))
autodiff(Reverse, plaquette_sum, Active, DuplicatedNoNeed(U, dU))

@wsmoses
Member

wsmoses commented Jan 29, 2024

Okay, I've deduced this issue is due to some lacking lifetime information.

The PR I linked above is a workaround for 1.10. The full actual solution is for JuliaLang/julia#53095 to be merged and backported to whatever LLVM version you're using.

@GianlucaFuwa
Author

Wow, I'm glad you were able to figure out the root cause despite my contrived example!
