Description
It would be nice to have an extension to the ChainRules API that allowed for different rules to be written and hit depending on what combination of inputs the derivative is being taken with respect to.
(or maybe a similar analogue for forward mode)
Basically partial derivative rules.
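To make the idea concrete, here is a rough sketch of a rule that dispatches on which arguments are active. This is not the API from the gist linked in this issue and not anything that exists in ChainRulesCore: `partial_rrule`, `Active`, `Inactive` and `NotRequested` are all hypothetical names made up for illustration, and the tangent for the function itself is left out for brevity.

```julia
# Hypothetical activity markers; NOT part of ChainRulesCore.
struct Active end
struct Inactive end

# Hypothetical placeholder for a derivative that was never requested.
struct NotRequested end

# Reverse-mode rules for matrix multiplication, one method per activity pattern.
# Each method returns the primal result and a pullback.
function partial_rrule(::typeof(*), ::Tuple{Active,Active},
                       A::AbstractMatrix, B::AbstractMatrix)
    pullback(dy) = (dy * B', A' * dy)          # both partials computed
    return A * B, pullback
end

function partial_rrule(::typeof(*), ::Tuple{Active,Inactive},
                       A::AbstractMatrix, B::AbstractMatrix)
    pullback(dy) = (dy * B', NotRequested())   # skip the work for B
    return A * B, pullback
end

function partial_rrule(::typeof(*), ::Tuple{Inactive,Active},
                       A::AbstractMatrix, B::AbstractMatrix)
    pullback(dy) = (NotRequested(), A' * dy)   # skip the work for A
    return A * B, pullback
end

A = [1.0 2.0; 3.0 4.0]
B = [0.0 1.0; 1.0 0.0]

# Only A is on the derivative path, so only dA is ever computed.
y, pb = partial_rrule(*, (Active(), Inactive()), A, B)
dA, dB = pb(ones(2, 2))
```

An operator-overloading AD could pick the method from which arguments are tracked types, and something like Enzyme could pick it from its own activity information.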
Thunking is a simple approximation to this (with its own set of struggles).
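For comparison, a minimal sketch of the thunking approach, using a hand-rolled stand-in rather than the real ChainRulesCore thunk machinery: `LazyTangent`, `force` and `mul_with_lazy_pullback` are hypothetical names used only here. There is still a single rule, and each partial is wrapped in a closure that only runs if the caller asks for it.

```julia
# Minimal stand-in for a thunk: wraps a zero-argument closure.
# Hypothetical type for illustration, not the ChainRulesCore implementation.
struct LazyTangent{F}
    f::F
end

# Run the deferred computation.
force(t::LazyTangent) = t.f()

# One rule for all activity patterns; every partial is deferred.
function mul_with_lazy_pullback(A, B)
    function pullback(dy)
        dA = LazyTangent(() -> dy * B')
        dB = LazyTangent(() -> A' * dy)
        return dA, dB
    end
    return A * B, pullback
end

A = [1.0 2.0; 3.0 4.0]
B = [0.0 1.0; 1.0 0.0]

y, pb = mul_with_lazy_pullback(A, B)
dA, dB = pb(ones(2, 2))

# Only this product is actually computed; dB is never forced.
dA_value = force(dA)
```

The rule itself cannot tell which partials will be wanted, so it cannot take a different code path up front. That is the gap the proposal above is about.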
@wsmoses requested this for Enzyme, though it is also relevant to every kind of operator-overloading-based AD (since only tracked types etc. will have derivatives taken with respect to them).
In contrast, it is useless for Zygote/Diffractor, as they do no kind of activity analysis and transform absolutely all code that is run.
A bit of a sketch of what that API might look like is in https://gist.github.com/oxinabox/c6ad25c468b3108f8a799bda66c147f8/
This might also be useful for partial mutation support, since it is probably completely safe to have rules for things that mutate inputs that are not "active" on the derivative path? (cf JuliaDiff/ChainRules.jl#521)
Though as the main reason we don't do mutation is tied to Diffractor/Zygote not supporting it, that might be kind of moot, unless they get at least some basic activity analysis.
(NB: we may not initially implement this in ChainRulesCore. It might be better to make a little experimental extension package for it first.)