Enzyme requests reverse rule on Const returns #1380

Closed
dominic-chang opened this issue Apr 1, 2024 · 6 comments

@dominic-chang
Contributor

I'm trying to define custom autodiff rules for some special functions. However, defining any custom autodiff rule causes Enzyme to request a rule for Const returns. Here's an MWE:

using Enzyme
import Enzyme.EnzymeRules: augmented_primal, reverse
using Enzyme.EnzymeRules

function augmented_primal(config::ConfigWidth{1}, func::Const{typeof(sqrt)}, ::Type{<:Active}, x::Active)
    println("In custom augmented primal rule.")
    if needs_primal(config)
        primal = func.val(x.val)
    else
        primal = nothing
    end

    if overwritten(config)[2]
        tape = copy(x.val)
    else
        tape = nothing
    end

    return AugmentedReturn(primal, nothing, tape)
end

function reverse(config::ConfigWidth{1}, ::Const{typeof(sqrt)}, dret::Active, tape, x::Active) 
    println("In custom reverse rule.")
    xval = overwritten(config)[2] ? tape : x.val
    dx = inv(2*sqrt(xval))' * dret.val
    return (dx, )
end

The following error then occurs when a Const return is requested in reverse mode:

In custom augmented primal rule.
ERROR: Enzyme execution failed.
Enzyme: No custom reverse rule was applicable for Tuple{ConfigWidth{1, false, false, (false, false)}, Const{typeof(sqrt)}, Type{Const{Float64}}, Nothing, Active{Float64}}

Stacktrace:
 [1] throwerr(cstr::Cstring)
   @ Enzyme.Compiler ~/.julia/packages/Enzyme/l4FS0/src/compiler.jl:1289
 [2] macro expansion
   @ ~/.julia/packages/Enzyme/l4FS0/src/compiler.jl:5378 [inlined]
 [3] enzyme_call
   @ ~/.julia/packages/Enzyme/l4FS0/src/compiler.jl:5056 [inlined]
 [4] CombinedAdjointThunk
   @ ~/.julia/packages/Enzyme/l4FS0/src/compiler.jl:4998 [inlined]
 [5] autodiff(::ReverseMode{false, FFIABI}, f::Const{typeof(sqrt)}, ::Type{Const}, args::Active{Float64})
   @ Enzyme ~/.julia/packages/Enzyme/l4FS0/src/Enzyme.jl:215
 [6] autodiff(::ReverseMode{false, FFIABI}, ::typeof(sqrt), ::Type, ::Active{Float64})
   @ Enzyme ~/.julia/packages/Enzyme/l4FS0/src/Enzyme.jl:224
 [7] top-level scope
   @ REPL[10]:1

@wsmoses
Member

wsmoses commented Apr 1, 2024

If a function can mutate its arguments in place, requesting a Const return is still well defined, because derivatives can flow through the mutated arguments. You also need to define a rule for the Const-return case (which, for this function, would do nothing since it is read-only).
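To illustrate why a Const return can still carry derivative information, here is a minimal sketch using a hypothetical mutating function, square_into! (not from this thread). It returns nothing, so its return is Const, yet the gradient still flows from the shadow of y back into the shadow of x through the in-place write.

```julia
using Enzyme

# Hypothetical mutating function: writes x[1]^2 into y[1] and returns nothing.
function square_into!(y, x)
    y[1] = x[1]^2
    return nothing
end

x, dx = [3.0], [0.0]   # input and its shadow (accumulates the gradient)
y, dy = [0.0], [1.0]   # output and its shadow (seeded with 1.0)

# The return is Const, but the call is still differentiable: Enzyme
# propagates the seed in dy back into dx through the write to y[1].
autodiff(Reverse, square_into!, Const, Duplicated(y, dy), Duplicated(x, dx))

println(dx[1])  # 6.0: d(x^2)/dx at x = 3.0, times the seed 1.0
```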

@dominic-chang
Contributor Author

Sorry, I forgot to mention that I did try doing that, but still received the same error.

function reverse(config, ::Const{typeof(sqrt)}, dret::Const, tape, x::Active) 
    println("In custom reverse rule.")
    return (zero(x.val), )
end
autodiff(Enzyme.Reverse, sqrt, Const, Active(0.5))
ERROR: Enzyme execution failed.
Enzyme: No custom augmented_primal rule was applicable for Tuple{ConfigWidth{1, false, false, (false, false)}, Const{typeof(sqrt)}, Type{Const{Float64}}, Active{Float64}}

@wsmoses
Member

wsmoses commented Apr 2, 2024

dret should be annotated as Type{<:Const}, not Const. An actual value isn't passed unless the return is Active.
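To illustrate the distinction, here is a small hypothetical helper, dret_kind (not part of Enzyme), that dispatches the same way a reverse rule's dret argument does. An Active return arrives as a value, while a Const return arrives only as a type, matching the Type{Const{Float64}} entry in the first error message above. This is a sketch of plain Julia dispatch, not of Enzyme internals.

```julia
using Enzyme

# Hypothetical helper (not part of Enzyme) that mirrors how the dret
# argument of a reverse rule is dispatched on.
dret_kind(dret::Active) = "Active return: shadow value $(dret.val) is passed"
dret_kind(::Type{<:Const}) = "Const return: only the type is passed"

println(dret_kind(Active(1.0)))     # matches an Active instance
println(dret_kind(Const{Float64}))  # matches the type Const{Float64} itself

# A Const *value* matches neither method, so an annotation like
# dret::Const cannot fire when Enzyme passes the type instead.
@assert !hasmethod(dret_kind, Tuple{Const{Float64}})
```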

@dominic-chang
Contributor Author

I receive the same error with this method signature:

function reverse(config::ConfigWidth{1}, ::Const{typeof(sqrt)}, dret::Type{<:Const}, tape, x::Active) 
    println("In custom reverse rule.")
    xval = overwritten(config)[2] ? tape : x.val
    return (zero(x.val), )
end

@wsmoses
Member

wsmoses commented Apr 2, 2024

You also need the corresponding augmented_primal rule (the error above says no applicable augmented_primal rule was found).
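Putting the advice in this thread together, a complete rule set covers both return activities, and each activity needs its own augmented_primal and reverse pair. The following is a minimal sketch, not a definitive implementation. It assumes a hypothetical mysqrt wrapper, so that rules for Base.sqrt are left alone. It also leaves config unannotated rather than writing ConfigWidth{1} as in the thread, because the config type's name may differ between Enzyme versions. Check the EnzymeRules documentation for your release.

```julia
using Enzyme
import Enzyme.EnzymeRules

# Hypothetical wrapper so the sketch does not override rules for Base.sqrt.
mysqrt(x) = sqrt(x)

# Active return: Enzyme passes the incoming derivative as dret::Active.
function EnzymeRules.augmented_primal(config, func::Const{typeof(mysqrt)},
                                      ::Type{<:Active}, x::Active)
    primal = EnzymeRules.needs_primal(config) ? func.val(x.val) : nothing
    # Save x only if it may be overwritten before the reverse pass.
    tape = EnzymeRules.overwritten(config)[2] ? copy(x.val) : nothing
    return EnzymeRules.AugmentedReturn(primal, nothing, tape)
end

function EnzymeRules.reverse(config, ::Const{typeof(mysqrt)},
                             dret::Active, tape, x::Active)
    xval = EnzymeRules.overwritten(config)[2] ? tape : x.val
    return (dret.val / (2 * sqrt(xval)),)
end

# Const return: only the type is passed, so both rules take ::Type{<:Const}.
function EnzymeRules.augmented_primal(config, func::Const{typeof(mysqrt)},
                                      ::Type{<:Const}, x::Active)
    primal = EnzymeRules.needs_primal(config) ? func.val(x.val) : nothing
    return EnzymeRules.AugmentedReturn(primal, nothing, nothing)
end

function EnzymeRules.reverse(config, ::Const{typeof(mysqrt)},
                             ::Type{<:Const}, tape, x::Active)
    # No derivative flows in through a Const return, so d/dx is zero.
    return (zero(x.val),)
end

autodiff(Reverse, mysqrt, Active, Active(4.0))  # dispatches to the Active-return pair
autodiff(Reverse, mysqrt, Const, Active(4.0))   # dispatches to the Const-return pair
```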

@dominic-chang
Contributor Author

Sorry, you're right. I missed that. This worked 😀
