"Writing good rules" docs should comment on type constraints #155

Open
nickrobinson251 opened this issue Apr 28, 2020 · 7 comments
Labels
documentation Improvements or additions to documentation

Comments

@nickrobinson251
Contributor

nickrobinson251 commented Apr 28, 2020

The frule/rrule method signatures should have the same type constraints as the primal function's methods (JuliaDiff/ChainRules.jl#175 (comment)).

For example, if your package defines two methods for `foo`:

```julia
foo(::Number) = ...
foo(::AbstractMatrix) = ...
```

prefer defining

```julia
frule((_, ẋ), ::typeof(foo), x::Union{Number, AbstractMatrix}) = ...
rrule(::typeof(foo), x::Union{Number, AbstractMatrix}) = ...
```

to defining

```julia
frule((_, ẋ), ::typeof(foo), x) = ...
rrule(::typeof(foo), x) = ...
```
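A minimal sketch of what goes wrong with the unconstrained version, in plain Julia so it runs without ChainRulesCore. The bodies of `foo` and the two `toy_rrule_*` functions are hypothetical stand-ins (for the elided methods and for `rrule`, respectively); only the signatures mirror the example above. The unconstrained rule claims inputs such as a `Vector` that the primal has no method for, and then fails inside `foo`.

```julia
# Hypothetical primal with exactly two methods, as in the example above.
foo(x::Number) = 2x
foo(x::AbstractMatrix) = 2x

# Hypothetical stand-ins for rrule, each returning (primal output, pullback).
# Constrained exactly like the primal:
toy_rrule_constrained(::typeof(foo), x::Union{Number, AbstractMatrix}) = (foo(x), ybar -> 2ybar)
# Unconstrained: matches an argument of any type at all.
toy_rrule_unconstrained(::typeof(foo), x) = (foo(x), ybar -> 2ybar)

for x in (1.0, [1.0 2.0; 3.0 4.0], [1.0, 2.0], "abc")
    println(typeof(x), ": primal=", applicable(foo, x),
            " constrained=", applicable(toy_rrule_constrained, foo, x),
            " unconstrained=", applicable(toy_rrule_unconstrained, foo, x))
end

# The unconstrained rule is selected for a Vector, but then hits a MethodError in foo.
try
    toy_rrule_unconstrained(foo, [1.0, 2.0])
catch err
    println("unconstrained rule failed with ", typeof(err))
end
```

The constrained toy rule is applicable exactly where `foo` is, which is the property being proposed for the docs.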

@willtebbutt
Member

I definitely agree that the above is necessary, but I'm not sure that it's sufficient. For example, you probably don't want to be implementing *(A::AbstractMatrix, B::AbstractMatrix) since, if you do so, bad things happen to your performance if A or B aren't dense.
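One concrete symptom, as a minimal sketch: assume the textbook reverse-mode rule for `C = A * B` (cotangents `Abar = Cbar * B'` and `Bbar = A' * Cbar`), written once for all `AbstractMatrix` as a hypothetical helper `mul_pullback` (not the actual ChainRules.jl rule). With `Diagonal` inputs and a dense incoming cotangent, it returns dense n×n cotangents, i.e. O(n²) memory, although each input only has n free entries.

```julia
using LinearAlgebra

# Hypothetical generic pullback for C = A * B, written once for every AbstractMatrix.
# Given the output cotangent Cbar, return the cotangents of A and B.
mul_pullback(A::AbstractMatrix, B::AbstractMatrix, Cbar::AbstractMatrix) = (Cbar * B', A' * Cbar)

n = 100
A = Diagonal(rand(n))
B = Diagonal(rand(n))
Cbar = ones(n, n)  # dense incoming cotangent

Abar, Bbar = mul_pullback(A, B, Cbar)

# A stores n numbers, but the generic rule hands back a dense n×n Matrix for it.
println(typeof(A), " stores ", length(A.diag), " entries")
println(typeof(Abar), " stores ", length(Abar), " entries")

# For a Diagonal input only the diagonal of Abar corresponds to A's free entries.
Abar_structured = Diagonal(diag(Abar))
```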

@nickrobinson251
Contributor Author

Yeah, we might want to add a note about performance considerations, with that as an example.

@oxinabox
Member

oxinabox commented Apr 29, 2020

Yes,
I think the general statement is that type constraints on primal methods should match those on the chain rule methods.
If AbstractMatrix is good enough for the primal then it is good enough for the rrule/frule.
At least to a first approximation.

Sometimes you get to combine them via Union when the pullback/pushforward doesn't differ.
Sometimes you get to special-case them for performance.
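A minimal sketch of both options, with hypothetical names (`triple` for the primal, `toy_rrule` as a stand-in for `rrule`): one method covers `Union{Number, AbstractMatrix}` because the pullback is identical for every type in it, and a more specific `Diagonal` method is picked by dispatch so it can return a structured cotangent.

```julia
using LinearAlgebra

# Hypothetical primal: multiply by 3. Its derivative does not depend on the argument type.
triple(x::Union{Number, AbstractMatrix}) = 3x

# One rule for the whole Union, because the pullback is the same for all of it.
function toy_rrule(::typeof(triple), x::Union{Number, AbstractMatrix})
    triple_pullback(ybar) = 3ybar
    return triple(x), triple_pullback
end

# Special case for performance: Diagonal is more specific than the Union above,
# so dispatch selects this method, and the cotangent keeps only the diagonal entries.
function toy_rrule(::typeof(triple), x::Diagonal)
    triple_pullback(ybar) = Diagonal(3 .* diag(ybar))
    return triple(x), triple_pullback
end

_, back = toy_rrule(triple, Diagonal([1.0, 2.0]))
println(typeof(back(ones(2, 2))))
```

Because `Diagonal <: AbstractMatrix`, the special case needs no change to the Union method and introduces no method ambiguity.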

@willtebbutt
Member

willtebbutt commented Apr 29, 2020

> If AbstractMatrix is good enough for the primal then it is good enough for the rrule/frule.

I don't buy that. The issue is that you're saying "for no subtype of AbstractMatrix should you look inside this function and figure out how to differentiate it". This just isn't the behaviour that you want from an AD tool in general. It really should be able to automatically exploit, e.g., the fact that a matrix is Diagonal without needing a rule for it.

AFAICT, the defining difference between writing a normal Julia method and writing a rule is that, for a normal Julia method, a generic fallback is better than nothing. There's no other code, so what else could you hope to do? Conversely, for a rule, the generic fallback to "just do AD" is often exactly what you would have wanted, and custom rules can get in the way of this, causing significantly worse performance than the fallback.

I'm not saying that it's impossible to write normal Julia methods that specialise and cause bad performance, but it seems to be much less of a problem in practice.

A better approach would be a recommendation like "only write a rule if you're confident that it's the right way to compute the rule for every reasonably-implemented subtype". For example, it's probably reasonable to implement a rule for StridedMatrix in a lot of places where it's not reasonable to do so for AbstractMatrix, matrix-matrix multiplication being the canonical example.
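A minimal sketch of how that recommendation plays out under dispatch, with a hypothetical `toy_mul_rrule` standing in for a rule on `*`: constraining both arguments to `StridedMatrix` means dense `Matrix` inputs and strided views hit the rule, while `Diagonal` (and other structured types) simply don't match, leaving them to whatever the AD tool would otherwise do.

```julia
using LinearAlgebra

# Hypothetical rule for C = A * B, restricted to strided (dense-layout) matrices,
# returning (primal output, pullback) in the same shape as an rrule.
function toy_mul_rrule(::typeof(*), A::StridedMatrix, B::StridedMatrix)
    mul_pullback(Cbar) = (Cbar * B', A' * Cbar)  # cotangents for A and B
    return A * B, mul_pullback
end

dense = rand(3, 3)
sub = view(rand(4, 4), 1:3, 1:3)  # a view with unit ranges is still strided
diagonal = Diagonal(rand(3))

println(applicable(toy_mul_rrule, *, dense, dense))     # Matrix: rule applies
println(applicable(toy_mul_rrule, *, sub, dense))       # strided view: rule applies
println(applicable(toy_mul_rrule, *, diagonal, dense))  # Diagonal: rule does not apply
```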

The alternative is to require that you implement custom rules for all methods of * involving subtypes of AbstractMatrix that aren't dense / strided / similar.

Edit: In short, err on the side of caution when implementing rules. Overly specific rules are likely to cause fewer problems than overly general ones.

@nickrobinson251
Contributor Author

I agree with Will. But I also think (hope?) this worry is only significant for very generic functions (i.e. those that have many specialisations and may be extended by many packages), and maybe only for those that are performance-critical. So * and + are the extreme cases, not the representative ones. And so the worry probably applies mostly to a relatively small number of Base (and maybe LinearAlgebra) functions?

@oxinabox
Member

You make a good point Will, I retract my claim of the generality of my statement.

We do want something saying that rrule methods should be no more general than the methods they are written for,
and, separately, something saying what Will just explained.
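A minimal sketch of how the "no more general than the primal" guideline could be checked mechanically. It assumes single-argument methods without `where` clauses (so each `Method.sig` is a plain `Tuple{typeof(f), T}`); `foo`, `primal_arg_union`, and `no_more_general` are hypothetical, and this is not an existing ChainRulesCore utility.

```julia
# Hypothetical primal with two single-argument methods.
foo(x::Number) = 2x
foo(x::AbstractMatrix) = 2x

# Union of the argument types accepted by the primal's methods.
# Assumes every method signature has the form Tuple{typeof(f), T}.
function primal_arg_union(f)
    types = [m.sig.parameters[2] for m in methods(f)]
    return Union{types...}
end

# A rule's argument type is acceptable if it is no wider than what the primal accepts.
no_more_general(f, rule_arg_type::Type) = rule_arg_type <: primal_arg_union(f)

println(no_more_general(foo, Union{Number, AbstractMatrix}))  # same constraint as the primal
println(no_more_general(foo, StridedMatrix))                  # narrower than the primal
println(no_more_general(foo, Any))                            # unconstrained rule
```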

@nickrobinson251
Contributor Author

Xref JuliaDiff/ChainRules.jl#232

Projects
None yet
Development

No branches or pull requests

3 participants