Custom rule not detected if defined after call to autodiff #696

Closed
gaurav-arya opened this issue Apr 5, 2023 · 7 comments · Fixed by #702

Comments

@gaurav-arya
Member

gaurav-arya commented Apr 5, 2023

MWE on main:

julia> using Enzyme

julia> import .EnzymeRules: forward

julia> f(x) = x^2
f (generic function with 1 method)

julia> autodiff(Forward, f, Duplicated(1.0, 1.0))
(2.0,)

julia> forward(::Const{typeof(f)}, ::Type{<:DuplicatedNoNeed}, x::Duplicated) = 10+2*x.val*x.dval
forward (generic function with 1 method)

julia> forward(func::Const{typeof(f)}, ::Type{<:Duplicated}, x::Duplicated) = Duplicated(func.val(x.val), 10+2*x.val*x.dval)
forward (generic function with 2 methods)

julia> autodiff(Forward, f, Duplicated(1.0, 1.0)) # answer unchanged
(2.0,)

julia> f(x) = x^2 # redefine f and we get the correct answer
f (generic function with 1 method)

julia> autodiff(Forward, f, Duplicated(1.0, 1.0))
(12.0,)

The reason for the behaviour might be this line:

world = GPUCompiler.get_world(Core.Typeof(f.val), tt)
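For readers less familiar with world ages, here is a minimal plain-Julia sketch of how a call pinned to an earlier world cannot see a method defined later. It does not touch Enzyme or GPUCompiler; `describe` is a made-up function used only for illustration, and whether the line above is really the cause is still just a guess.

```julia
# Toy illustration of world age; `describe` is hypothetical, not part of Enzyme.
describe(x) = "generic"

# Snapshot the world age before the more specific method exists.
w_old = Base.get_world_counter()

# Defining a new method advances the global world counter.
describe(x::Int) = "int rule"
w_new = Base.get_world_counter()

# A normal top-level call runs in the latest world and sees the new method.
latest = describe(1)

# A call pinned to the old world only sees methods that existed back then.
pinned = Base.invoke_in_world(w_old, describe, 1)

println((w_new > w_old, latest, pinned))
```

If compilation looks methods up in a world that predates the rule definition, the rule would be invisible in the same way that `describe(x::Int)` is invisible to the pinned call.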

@vchuravy
Member

vchuravy commented Apr 5, 2023

We have no invalidation edge from forward/reverse/augmented_forward to the result of our compilation.
So the code is not invalidated when a new rule is being defined.
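A rough plain-Julia analogy of the missing invalidation edge, assuming nothing about Enzyme's internals: a cache that records once whether a rule exists and is never told when that answer changes. All names here (`toy_rule`, `TOY_CACHE`, `has_rule_cached`, `toy_derivative`) are hypothetical.

```julia
# Toy model of a compile cache with no invalidation; all names are hypothetical.
g(x) = x^2

function toy_rule end                      # rule hook with no methods yet
const TOY_CACHE = Dict{Any,Bool}()         # "does a rule exist?" per function type

# Look up rule availability once and remember the answer forever.
function has_rule_cached(f)
    get!(TOY_CACHE, typeof(f)) do
        hasmethod(toy_rule, Tuple{typeof(f),Float64})
    end
end

# Use the rule if the cache says there is one, else a central finite difference.
function toy_derivative(f, x::Float64; h = 1e-6)
    has_rule_cached(f) ? toy_rule(f, x) : (f(x + h) - f(x - h)) / (2h)
end

before = toy_derivative(g, 1.0)            # no rule defined yet
toy_rule(::typeof(g), x::Float64) = 10 + 2x
stale = toy_derivative(g, 1.0)             # cache still says "no rule"
empty!(TOY_CACHE)                          # manual invalidation
fresh = toy_derivative(g, 1.0)             # rule is now picked up

println((before, stale, fresh))
```

Nothing connects the definition of `toy_rule` to the cached entry, so the stale answer survives until the cache is cleared by hand. Redefining `f` in the MWE plays a similar role to `empty!` here.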

@gaurav-arya
Member Author

Is this something that you think is somewhat easy to fix? Happy to try to make a PR if you think so. The workaround is to redefine the original function each time the custom rule is modified, but it's easy to forget.

@vchuravy
Member

vchuravy commented Apr 5, 2023

Hm, not trivial: we need to add a fictitious edge to a method that potentially does not exist yet.

Somewhere in https://github.com/EnzymeAD/Enzyme.jl/blob/939f9b4086d62b07eca8107db9523a2f8fe043d3/src/compiler/interpreter.jl or

has_custom_rule = EnzymeRules.has_frule_from_sig(specTypes; world)

Maybe @aviatesk has some ideas. I suspect we will need to intercept abstract_gf_call or something similar and add an inference edge... But we don't necessarily know which version of the rule we will call, due to activity analysis.

We only know that around

llvmf = nested_codegen!(mode, mod, EnzymeRules.reverse, rev_TT, world)

We should definitely try to add edges from there... But I am unsure how to add a speculative edge for an undefined method.

@gaurav-arya
Member Author

> speculative edge for an undefined method

Definitely a bit out of my depth here, but: would backedges be easier if forward etc. had a default fallback of nothing? (So rule detection would change from method detection to checking if the method returns nothing). That's the situation for frule and friends: https://github.com/JuliaDiff/ChainRulesCore.jl/blob/79ba4ef03afdf5715b6fa0294e5accfe4e95c79b/src/rules.jl#L61
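To make the question concrete, here is a small plain-Julia sketch of the fallback idea. It is not EnzymeRules or ChainRulesCore code; `toy_frule` and `toy_forward` are hypothetical names, and the finite-difference branch just stands in for "differentiate without a rule".

```julia
# Sketch of the "fallback returns nothing" idea; names are hypothetical,
# not EnzymeRules API.
h(x) = x^2

toy_frule(f, x) = nothing                  # default fallback: no custom rule

function toy_forward(f, x::Float64; step = 1e-6)
    r = toy_frule(f, x)                    # ordinary call, tracked by Julia itself
    r === nothing ? (f(x + step) - f(x - step)) / (2step) : r
end

no_rule = toy_forward(h, 1.0)              # compiled against the fallback

# Adding a more specific method invalidates callers of the fallback.
toy_frule(::typeof(h), x) = 10 + 2x
with_rule = toy_forward(h, 1.0)

println((no_rule, with_rule))
```

Because the fallback method always exists, the caller has a real method to depend on, and Julia's usual method invalidation handles the later definition. Whether that carries over to Enzyme's own compilation cache is exactly the open question.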

@aviatesk
Contributor

aviatesk commented Apr 6, 2023

I believe I can fix this. How can we access the MethodInstance object representing the caller inside GPUCompiler.codegen? We need it to add appropriate backedges to it.

@wsmoses
Member

wsmoses commented Apr 7, 2023

for (mi, k) in meta.compiled
This contains every compiled MethodInstance (at least every one that we see).
