Ability to specify different rules based on what combinations of inputs are actually being used #452

@oxinabox

It would be nice to have an extension to the ChainRules API that allowed different rules to be written, and dispatched to, depending on which combination of inputs the derivative is being taken with respect to
(or perhaps something analogous for forward mode).
Basically, partial derivative rules.
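
To make the idea concrete, here is a minimal sketch of what dispatching on the active inputs could look like for matrix multiplication. Every name in it (`Active`, `Inactive`, `partial_rrule`) is hypothetical: this is neither the API from the gist nor anything that exists in ChainRulesCore. It uses only Base, with `Inactive()` standing in for the role `NoTangent()` plays in ChainRulesCore.

```julia
# Hypothetical sketch: none of these names are part of ChainRulesCore.

# Marker returned for an input whose derivative was not requested.
struct Inactive end

# Activity pattern lifted into the type domain so rules can dispatch on it.
# `Active(true, false)` means: differentiate with respect to the first argument only.
struct Active{mask} end
Active(mask::Bool...) = Active{mask}()

# Both inputs active: compute both cotangents.
function partial_rrule(::Active{(true, true)}, ::typeof(*), A::AbstractMatrix, B::AbstractMatrix)
    pullback(dY) = (dY * B', A' * dY)
    return A * B, pullback
end

# Only A active: the product A' * dY is never formed.
function partial_rrule(::Active{(true, false)}, ::typeof(*), A::AbstractMatrix, B::AbstractMatrix)
    pullback(dY) = (dY * B', Inactive())
    return A * B, pullback
end

# Only B active: the product dY * B' is never formed.
function partial_rrule(::Active{(false, true)}, ::typeof(*), A::AbstractMatrix, B::AbstractMatrix)
    pullback(dY) = (Inactive(), A' * dY)
    return A * B, pullback
end

A = [1.0 2.0; 3.0 4.0]
B = [0.0 1.0; 1.0 0.0]
Y, pullback = partial_rrule(Active(true, false), *, A, B)
dA, dB = pullback(ones(2, 2))
```

An AD system with activity information would pick the `Active` pattern itself. A rule author would only need to write the specialised methods where skipping work actually pays off.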

Thunking is a simple approximation to this (with its own set of struggles).
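
As a rough illustration of why thunking approximates this, the sketch below wraps each cotangent in a lazy closure, so a caller that only needs one of them never pays for the other. `LazyTangent`, `force`, and `lazy_rrule` are hypothetical Base-only stand-ins for the role that `@thunk` and `unthunk` play in ChainRulesCore, and the counter exists only to make the skipped work visible.

```julia
# Hypothetical Base-only stand-ins for ChainRulesCore's `@thunk` / `unthunk`.
struct LazyTangent{F}
    compute::F
end
force(t::LazyTangent) = t.compute()

# Counts how many cotangents were actually computed.
const evaluations = Ref(0)

function lazy_rrule(::typeof(*), A::AbstractMatrix, B::AbstractMatrix)
    function pullback(dY)
        # Neither product is formed until the caller forces it.
        dA = LazyTangent(() -> (evaluations[] += 1; dY * B'))
        dB = LazyTangent(() -> (evaluations[] += 1; A' * dY))
        return dA, dB
    end
    return A * B, pullback
end

A = [1.0 2.0; 3.0 4.0]
B = [0.0 1.0; 1.0 0.0]
Y, pullback = lazy_rrule(*, A, B)
dA, dB = pullback(ones(2, 2))
dA_value = force(dA)  # only the cotangent of A is computed; dB stays unevaluated
```

Note that the rule itself still cannot tell which inputs are active: it always builds both closures, and anything computed eagerly in the primal pass is shared by both.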

@wsmoses requested this for Enzyme, though it is also relevant to every kind of operator-overloading-based AD (since only tracked types etc. will have derivatives taken with respect to them).
In contrast, it is useless for Zygote/Diffractor, as they do no activity analysis of any kind and transform absolutely all code that is run.

A bit of a sketch of what that API might look like is in https://gist.github.com/oxinabox/c6ad25c468b3108f8a799bda66c147f8/

This might also be useful for partial mutation support, since it is probably completely safe to have rules for things that mutate inputs that are not "active" on the derivative path? (cf. JuliaDiff/ChainRules.jl#521)
Though, as the main reason we don't do mutation is tied to Diffractor/Zygote not supporting it, that might be kind of moot unless they gained at least some basic activity analysis.

(NB: we may not initially implement this in ChainRulesCore. It might be better to make a little experimental extension package for it first.)


Labels: design (requires some design before changes are made)