Help writing complex rules #744

Closed
sethaxen opened this issue Apr 19, 2023 · 3 comments

@sethaxen
Collaborator

While writing #739, I ran into some difficulties defining rules for functions with complex inputs and outputs. Here's a simple example:

foo(x::Complex) = 2x

function EnzymeRules.augmented_primal(
    config::EnzymeRules.ConfigWidth{1},
    func::Const{typeof(foo)},
    ::Type{<:Duplicated},
    x::Duplicated{<:Complex},
)
    println("In custom augmented primal rule.")
    # Compute primal
    r = func.val(x.val)
    if EnzymeRules.needs_primal(config)
        primal = r
    else
        primal = nothing
    end
    if EnzymeRules.needs_shadow(config)
        shadow = zero(r)
    else
        shadow = nothing
    end
    tape = nothing
    return EnzymeRules.AugmentedReturn(primal, shadow, tape)
end

function EnzymeRules.reverse(
    config::EnzymeRules.ConfigWidth{1},
    func::Const{typeof(foo)},
    dret::Duplicated{<:Complex},
    tape,
    y::Duplicated{<:Complex},
)
    println("In custom reverse rule.")
    return ()
end

When I execute this rule, I get the following stacktrace:

julia> autodiff(Reverse, foo, Active, Active(1.0+3im))
ERROR: AssertionError: value_type(normalV) == value_type(orig)
Stacktrace:
  [1] enzyme_custom_common_rev(forward::Bool, B::Ptr{LLVM.API.LLVMOpaqueBuilder}, OrigCI::Ptr{LLVM.API.LLVMOpaqueValue}, gutils::Ptr{Nothing}, normalR::Ptr{Ptr{LLVM.API.LLVMOpaqueValue}}, shadowR::Ptr{Ptr{LLVM.API.LLVMOpaqueValue}}, tape::Nothing)
    @ Enzyme.Compiler ~/.julia/packages/Enzyme/zSGqM/src/compiler.jl:4038
  [2] enzyme_custom_augfwd(B::Ptr{LLVM.API.LLVMOpaqueBuilder}, OrigCI::Ptr{LLVM.API.LLVMOpaqueValue}, gutils::Ptr{Nothing}, normalR::Ptr{Ptr{LLVM.API.LLVMOpaqueValue}}, shadowR::Ptr{Ptr{LLVM.API.LLVMOpaqueValue}}, tapeR::Ptr{Ptr{LLVM.API.LLVMOpaqueValue}})
    @ Enzyme.Compiler ~/.julia/packages/Enzyme/zSGqM/src/compiler.jl:4104
  [3] EnzymeCreatePrimalAndGradient(logic::Enzyme.Logic, todiff::LLVM.Function, retType::Enzyme.API.CDIFFE_TYPE, constant_args::Vector{Enzyme.API.CDIFFE_TYPE}, TA::Enzyme.TypeAnalysis, returnValue::Bool, dretUsed::Bool, mode::Enzyme.API.CDerivativeMode, width::Int64, additionalArg::Ptr{Nothing}, forceAnonymousTape::Bool, typeInfo::Enzyme.FnTypeInfo, uncacheable_args::Vector{Bool}, augmented::Ptr{Nothing}, atomicAdd::Bool)
    @ Enzyme.API ~/.julia/packages/Enzyme/zSGqM/src/api.jl:124
  [4] enzyme!(job::GPUCompiler.CompilerJob{Enzyme.Compiler.EnzymeTarget, Enzyme.Compiler.EnzymeCompilerParams}, mod::LLVM.Module, primalf::LLVM.Function, TT::Type, mode::Enzyme.API.CDerivativeMode, width::Int64, parallel::Bool, actualRetType::Type, wrap::Bool, modifiedBetween::Tuple{Bool, Bool}, returnPrimal::Bool, jlrules::Vector{String}, expectedTapeType::Type)
    @ Enzyme.Compiler ~/.julia/packages/Enzyme/zSGqM/src/compiler.jl:6698
  [5] codegen(output::Symbol, job::GPUCompiler.CompilerJob{Enzyme.Compiler.EnzymeTarget, Enzyme.Compiler.EnzymeCompilerParams}; libraries::Bool, deferred_codegen::Bool, optimize::Bool, ctx::LLVM.ThreadSafeContext, strip::Bool, validate::Bool, only_entry::Bool, parent_job::Nothing)
    @ Enzyme.Compiler ~/.julia/packages/Enzyme/zSGqM/src/compiler.jl:7939
  [6] _thunk(job::GPUCompiler.CompilerJob{Enzyme.Compiler.EnzymeTarget, Enzyme.Compiler.EnzymeCompilerParams}, ctx::Nothing, postopt::Bool)
    @ Enzyme.Compiler ~/.julia/packages/Enzyme/zSGqM/src/compiler.jl:8452
  [7] _thunk
    @ ~/.julia/packages/Enzyme/zSGqM/src/compiler.jl:8449 [inlined]
  [8] cached_compilation
    @ ~/.julia/packages/Enzyme/zSGqM/src/compiler.jl:8487 [inlined]
  [9] #s286#173
    @ ~/.julia/packages/Enzyme/zSGqM/src/compiler.jl:8545 [inlined]
 [10] var"#s286#173"(FA::Any, A::Any, TT::Any, Mode::Any, ModifiedBetween::Any, width::Any, ReturnPrimal::Any, ShadowInit::Any, World::Any, ::Any, ::Any, ::Any, ::Any, tt::Any, ::Any, ::Any, ::Any, ::Any, ::Any)
    @ Enzyme.Compiler ./none:0
 [11] (::Core.GeneratedFunctionStub)(::Any, ::Vararg{Any})
    @ Core ./boot.jl:602
 [12] thunk
    @ ~/.julia/packages/Enzyme/zSGqM/src/compiler.jl:8504 [inlined]
 [13] autodiff(#unused#::EnzymeCore.ReverseMode{false}, f::Const{typeof(foo)}, #unused#::Type{Active}, args::Active{ComplexF64})
    @ Enzyme ~/.julia/packages/Enzyme/zSGqM/src/Enzyme.jl:199
 [14] autodiff(::EnzymeCore.ReverseMode{false}, ::typeof(foo), ::Type, ::Active{ComplexF64})
    @ Enzyme ~/.julia/packages/Enzyme/zSGqM/src/Enzyme.jl:214
 [15] top-level scope
    @ REPL[18]:1

Two things surprised me. First, Enzyme seems to insist on Duplicated annotations for complex scalars: if I specify Active for the inputs, as in the autodiff call above, they are replaced with Duplicated. Second, if I return shadow = nothing, Enzyme complains that it expects the shadow to be a ComplexF64, but if I make it a ComplexF64, I get the error above. How can I fix these rules so that they work?

@wsmoses
Member

wsmoses commented Apr 22, 2023

This will be fixed by #754 but will require a jll bump.

@wsmoses
Member

wsmoses commented Apr 22, 2023

Fixed by #754

@wsmoses wsmoses closed this as completed Apr 22, 2023
@sethaxen
Collaborator Author

Thanks! For completeness, this now works:

julia> using Enzyme

julia> foo(x::Complex) = 2x;

julia> function EnzymeRules.augmented_primal(
           config::EnzymeRules.ConfigWidth{1},
           func::Const{typeof(foo)},
           ::Type{<:Active},
           x::Active{<:Complex},
       )
           println("In custom augmented primal rule.")
           # Compute primal
           r = func.val(x.val)
           if EnzymeRules.needs_primal(config)
               primal = r
           else
               primal = nothing
           end
           if EnzymeRules.needs_shadow(config)
               shadow = zero(r)
           else
               shadow = nothing
           end
           tape = nothing
           return EnzymeRules.AugmentedReturn(primal, shadow, tape)
       end

julia> function EnzymeRules.reverse(
           config::EnzymeRules.ConfigWidth{1},
           func::Const{typeof(foo)},
           dret::Active{<:Complex},
           tape,
           y::Active{<:Complex},
       )
           println("In custom reverse rule.")
           return (2*dret.val,)
       end

julia> autodiff(Reverse, foo, Active, Active(1.0+3im))
In custom augmented primal rule.
In custom reverse rule.
((2.0 + 0.0im,),)
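
As a sanity check on the printed result, here is the arithmetic read off from the rule. It assumes that the tuple returned by the custom reverse rule is passed through to the output unchanged. That is inferred from the output shown here, not from Enzyme's documentation.

```latex
\begin{aligned}
\text{returned gradient} &= 2 \cdot \texttt{dret.val} = 2.0 + 0.0i \\
\Rightarrow\quad \texttt{dret.val} &= 1.0 + 0.0i
\end{aligned}
```

Under that assumption, the reverse rule received a seed of 1.0 + 0.0im for the Active complex return value, and the result is just that seed scaled by the factor 2 from foo(x) = 2x.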
