
Improvements to rules for norm #337

Merged: 22 commits, May 10, 2021
Conversation

@mcabbott (Member) commented Dec 26, 2020

This fixes FluxML/Zygote.jl#860 by relaxing the signature on the rule for norm. While there is no longer an explicit mention of Diagonal in the rule, the tests still check that this is correctly reproduced.

It should also fix gradient(norm, [1, 2]), which gave an InexactError.

It also adds a few more InplaceableThunks while I was at it.

Edit -- fixes FluxML/Zygote.jl#960, too.
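
For orientation, a minimal sketch of what such a relaxed rule looks like, written against current ChainRulesCore names. This is illustrative only, not the PR's exact diff; the real rule also covers p-norms and in-place accumulation:

using LinearAlgebra, ChainRulesCore

# Illustrative sketch of a relaxed 2-norm rule -- not the PR's exact code.
function ChainRulesCore.rrule(::typeof(norm), x::AbstractArray{<:Number})
    y = norm(x)
    function norm_pullback(Δy)
        # pinv(y) avoids dividing by zero when x is all zeros; broadcasting
        # lets the container (Diagonal, SArray, ...) specialise the result.
        return NoTangent(), x .* (real(unthunk(Δy)) * pinv(y))
    end
    return y, norm_pullback
end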

@codecov-io commented Dec 26, 2020

Codecov Report

Merging #337 (df02563) into master (ebc99f7) will increase coverage by 0.03%.
The diff coverage is 100.00%.


@@            Coverage Diff             @@
##           master     #337      +/-   ##
==========================================
+ Coverage   97.64%   97.67%   +0.03%     
==========================================
  Files          18       18              
  Lines        1018     1034      +16     
==========================================
+ Hits          994     1010      +16     
  Misses         24       24              
Impacted Files Coverage Δ
src/rulesets/LinearAlgebra/norm.jl 100.00% <100.00%> (ø)

Legend: Δ = absolute <relative> (impact), ø = not affected, ? = missing data

@simeonschaub (Member) left a comment:

This is great work, thanks! I have just some very minor suggestions, but overall LGTM. I think this might even fix some problems I am currently having with Zygote, but need to look into that some more.

@sethaxen (Member) commented:

We made the restriction of the rrule to StridedArrays and a few others because it (or, really, any rrule that broadcasts over a similar type to the primal) does the wrong thing for large classes of arrays. For example, it will not work for Hermitian matrices. Using this PR:

julia> using LinearAlgebra, ChainRules

julia> A, ȳ = randn(ComplexF64, 5, 5), randn();

julia> hermA, Hermitian_back = rrule(Hermitian, A, :U);

julia> y, norm_back = rrule(norm, hermA);

julia> unthunk(norm_back(ȳ)[2])
ERROR: ArgumentError: Cannot set a non-diagonal index in a Hermitian matrix

This happens because the pullback broadcasts into a type similar to A. Roughly, I would not expect the rule to work correctly for any matrix type that has non-sparsity constraints, and it may fail for some that only have sparsity constraints. So dense arrays, Diagonal, and AbstractTriangular are fine, but many other types may not be, and it's better to have no rule than one that silently does the wrong thing.

@nickrobinson251 added the "type constraints" label (Potentially raises a question about how tightly to constrain argument types for a rule. See #232) on Jan 1, 2021
@mcabbott (Member, Author) commented Jan 1, 2021

I would not describe an ArgumentError as being silent, but I agree it's a problem, thanks for the example.

This is precisely because of the work-around for NaN .* Diagonal(rand(3)) (in 1b7307f), so perhaps whatever solves that needs to be applied more narrowly, leaving straight broadcasting as the most generic fallback. I have an idea...
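
For concreteness, the instability is easy to reproduce (output from a recent Julia; display may differ slightly):

julia> using LinearAlgebra

julia> 2.0 .* Diagonal(rand(3)) isa Diagonal   # structure preserved
true

julia> NaN .* Diagonal(rand(3)) isa Diagonal   # NaN * 0 = NaN, so it densifies
false

julia> typeof(NaN .* Diagonal(rand(3)))
Matrix{Float64}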

I guess I'm not completely sold that "all is forbidden except what is permitted" is the right policy, compared to Julia's usual AbstractArray flexibility. The motivating issue above involves StaticArrays, which used to work, and which will behave well under the generic broadcast etc. of a fairly simple rule like this. (And less well under the similar(x) variant.) It seems very natural and very Julian that they would share code. How would these be made to work under a more restrictive policy?

@oxinabox (Member) commented Jan 2, 2021

The special problem is that there is a class of types for which, had an overly broad rule not been defined, the AD would have done the right thing; with it defined, the wrong thing happens.

@mcabbott (Member, Author) commented Jan 4, 2021

had an overly broad rule not been defined, the AD would have done the right thing; with it defined, the wrong thing happens.

I'm not sure I follow. Which rule? In FluxML/Zygote.jl#860, there is no rule for norm being used. Are you saying this is caused by some other rule?

julia> Zygote.ChainRules.rrule(norm, [1,2,3.0])
(3.7416573867739413, ChainRules.var"#norm_pullback#1776"{Vector{Float64}, Float64}([1.0, 2.0, 3.0], 3.7416573867739413))

julia> Zygote.ChainRules.rrule(norm, SA[1,2,3.0]) === nothing
true

julia> Zygote.gradient(norm, [1,2,3.0])
([0.2672612419124244, 0.5345224838248488, 0.8017837257372732],)

julia> Zygote.gradient(norm, SA[1,2,3.0])
Internal error: encountered unexpected error in runtime:
BoundsError(a=Array{Any, (2,)}[
  Core.Compiler.VarState(typ=Zygote.Pullback{Tuple{typeof(StaticArrays._norm), StaticArrays.Size{(3,)}, StaticArrays.SArray{Tuple{3}, Float64, 1, 3}}, Any}, undef=false),
...

julia> Zygote.gradient(norm, [1,2,3]) # this one is easy
ERROR: InexactError: Int64(0.2672612419124244)

(Without this PR!)

@oxinabox (Member) commented Jan 4, 2021

I'm not sure I follow. Which rule? In FluxML/Zygote.jl#860, there is no rule for norm being used. Are you saying this is caused by some other rule?

It is a general statement, not specifically related to this PR.

@mcabbott (Member, Author) commented Jan 4, 2021

Maybe this should be a separate message. Another problem with the dx = similar(x) type-stability fix is that it breaks second derivatives, even when everything is an Array. This turned up at FluxML/Zygote.jl#865 (comment), but a more minimal example is:

julia> Zygote.gradient(x -> sum(Zygote.gradient(norm, x)[1]), rand(3))
ERROR: Mutating arrays is not supported
Stacktrace:
  [1] error(s::String)
    @ Base ./error.jl:33
  [2] (::Zygote.var"#375#376")(#unused#::Nothing)
    @ Zygote ~/.julia/packages/Zygote/ywhiG/src/lib/array.jl:61
  [3] (::Zygote.var"#2258#back#377"{Zygote.var"#375#376"})(Δ::Nothing)
    @ Zygote ~/.julia/packages/ZygoteRules/OjfTt/src/adjoint.jl:59
...
  [7] Pullback
    @ ~/.julia/packages/ChainRules/sGOE6/src/rulesets/LinearAlgebra/norm.jl:216 [inlined]
  [8] (::typeof(∂(_norm2_back)))(Δ::FillArrays.Fill{Float64, 1, Tuple{Base.OneTo{Int64}}})
    @ Zygote ~/.julia/packages/Zygote/ywhiG/src/compiler/interface2.jl:0
  [9] Pullback
    @ ~/.julia/packages/ChainRules/sGOE6/src/rulesets/LinearAlgebra/norm.jl:57 [inlined]
 [10] (::typeof(∂(λ)))(Δ::Tuple{Nothing, FillArrays.Fill{Float64, 1, Tuple{Base.OneTo{Int64}}}})
    @ Zygote ~/.julia/packages/Zygote/ywhiG/src/compiler/interface2.jl:0

Does this package have opinions about second derivatives, and about tests for them? Seems tricky, and maybe not every rule can support them, but where easy it would be nice to have.

@sethaxen (Member) commented Jan 4, 2021

Does this package have opinions about second derivatives, and about tests for them? Seems tricky, and maybe not every rule can support them, but where easy it would be nice to have.

It really seems that we rarely think about being second-differentiable, which seems like an oversight.

@mcabbott (Member, Author) commented Jan 4, 2021

Latest commit 8ab56db is an idea for fixing the NaN .* Diagonal(rand(3)) problem, still using broadcasting rather than writing explicitly into similar(x) (and hence allowing for 2nd derivatives). It's not the prettiest thing ever... but see what you think?

julia> using ChainRules, LinearAlgebra

julia> ChainRules._norm1_back(Diagonal(rand(3)), 1.2, 3.4)
3×3 Diagonal{Float64, Vector{Float64}}:
 3.4   ⋅    ⋅ 
  ⋅   3.4   ⋅ 
  ⋅    ⋅   3.4

julia> ChainRules._norm1_back(Diagonal(rand(3)), 1.2, NaN)
3×3 Diagonal{Float64, Vector{Float64}}:
 NaN       ⋅      ⋅ 
    ⋅   NaN       ⋅ 
    ⋅      ⋅   NaN

julia> @code_warntype ChainRules._norm1_back(Diagonal(rand(3)), 1.2, NaN)
Variables
  #self#::Core.Const(ChainRules._norm1_back)
  x::Diagonal{Float64, Vector{Float64}}
  y::Float64
  Δy::Float64
  ∂x_data::Vector{Float64}

Body::Diagonal{Float64, Vector{Float64}}
1 ─ %1 = ChainRules.parent(x)::Vector{Float64}
│   %2 = Base.broadcasted(ChainRules.sign, %1)::Base.Broadcast.Broadcasted{Base.Broadcast.DefaultArrayStyle{1}, Nothing, typeof(sign), Tuple{Vector{Float64}}}
│   %3 = ChainRules.real(Δy)::Float64
│   %4 = Base.broadcasted(ChainRules.:*, %2, %3)::Base.Broadcast.Broadcasted{Base.Broadcast.DefaultArrayStyle{1}, Nothing, typeof(*), Tuple{Base.Broadcast.Broadcasted{Base.Broadcast.DefaultArrayStyle{1}, Nothing, typeof(sign), Tuple{Vector{Float64}}}, Float64}}
│        (∂x_data = Base.materialize(%4))
│   %6 = ChainRules.withsomezeros_rewrap(x, ∂x_data)::Diagonal{Float64, Vector{Float64}}
└──      return %6

Hermitian is handled by just broadcasting. And second derivatives may work:

julia> Zygote.gradient(norm, Hermitian(rand(3,3)))[1]
3×3 Matrix{Float64}:
 0.172849  0.319948   0.181923
 0.319948  0.566805   0.0347145
 0.181923  0.0347145  0.612797

julia> Zygote.gradient(x -> sum(Zygote.gradient(norm, x)[1]), rand(3))
([-0.30907017350951627, -0.27754572206132044, 1.6974578767958433],)

julia> Zygote.gradient(x -> sum(Zygote.gradient(norm, x)[1]), Diagonal(rand(3)))  # too much to ask?
ERROR: MethodError: no method matching +(::NamedTuple{(:diag,), Tuple{Vector{Float64}}}, ::Diagonal{Float64, Vector{Float64}})
  ...
Stacktrace:
  [1] accum(x::NamedTuple{(:diag,), Tuple{Vector{Float64}}}, y::Diagonal{Float64, Vector{Float64}})
    @ Zygote ~/.julia/packages/Zygote/ywhiG/src/lib/lib.jl:8
  [2] ZBack
    @ ~/.julia/packages/Zygote/ywhiG/src/compiler/chainrules.jl:77 [inlined]
  [3] Pullback
    @ ~/.julia/dev/ChainRules/src/rulesets/LinearAlgebra/norm.jl:54 [inlined]
  [4] (::typeof(∂(rrule)))(Δ::Tuple{Nothing, NamedTuple{(:x, :y), Tuple{NamedTuple{(:diag,), Tuple{V
...  

Edit -- Tests pass locally, but fail on CI, in tests of Adjoint etc. Something is failing to be conjugated? Haven't reproduced this locally yet.
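
For readers skimming: the mechanism in 8ab56db is roughly the following. This is a reconstruction from the discussion here, not a verbatim quote of the commit:

using LinearAlgebra

# Wrappers whose nonzero data is exactly parent(x), plus structural zeros.
WithSomeZeros{T} = Union{
    Diagonal{T},
    UpperTriangular{T},
    UnitUpperTriangular{T},
    LowerTriangular{T},
    UnitLowerTriangular{T},
}

# Rebuild the same kind of wrapper around new backing data, so the pullback
# can broadcast over parent(x) and rewrap -- type-stable even for NaN/Inf.
for S in (:Diagonal, :UpperTriangular, :UnitUpperTriangular,
          :LowerTriangular, :UnitLowerTriangular)
    @eval withsomezeros_rewrap(::$S, data) = $S(data)
end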

@oxinabox (Member) commented Jan 4, 2021

Supporting higher-order rules is on the roadmap for ChainRules v2.0.
It's not something we are worrying about right now. Burn that bridge when we get to it.

I have heard that it is a very common problem for Zygote nested AD not to work, due to mutation, because people deal with the lack of mutation support by mutating and hiding the result behind a rule.
The real long-term solution is that Zygote must support mutation; there really isn't any other reasonable option.
Practically, the second derivatives that are actually wanted are almost always best calculated via forward over forward or forward over reverse, and no forward AD has mutation trouble.
We can fix things when and if they come up, but I wouldn't spend too much time worrying about it until it does.
Not yet. Other fish to cross first.

MethodError: no method matching +(::NamedTuple{(:diag,), Tuple{Vector{Float64}}}, ::Diagonal{Float64, Vector{Float64}})

That one will be fixed once Zygote is switched over to using Composite internally, rather than NamedTuple.


∂x = Thunk() do
return if isempty(x) || p == 0
zero.(x) .* (zero(y) * zero(real(Δy)))
InplaceableThunk(
@thunk(zero.(x) .* (zero(y) * zero(real(Δy)))),
@oxinabox (Member) commented Jan 11, 2021

You know what is fun? This is broken for some arrays.
We might want to constrain the eltype to be Number.

Consider:

julia> norm([[[1]]], 2)
1.0

julia> zero.([[[1]]])
ERROR: MethodError: no method matching zero(::Type{Vector{Int64}})

because of JuliaLang/julia#38064
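
Constraining the eltype would make the rule simply not match such arrays; the dispatch check is just:

julia> [[[1]]] isa AbstractArray{<:Number}   # would no longer hit the rule
false

julia> [1.0, 2.0] isa AbstractArray{<:Number}
true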

(Member) replied:

Not blocking for this PR

@mcabbott (Member, Author) commented Apr 29, 2021

This entire file should probably be restricted to arrays of <:Number?

Done in be61a4e, but not 100% sure that's a good idea -- haven't tried to audit which rules would or would not work for arrays of arrays.

Inline review context:

    `λ .* Diagonal(rand(3))` gives a dense matrix when `λ == Inf`.
    But `withsomezeros_rewrap(x, λ .* parent(x))` is type-stable.
    """
    WithSomeZeros{T} = Union{
(Member) commented:

I would call these StructuredSparseArray

@mcabbott (Member, Author) replied:

You approve of the mechanism, #337 (comment)?

(Member) replied:

I am willing to give it a shot.
We can always change it later.
It's not going to lead to wrong behaviour AFAICT.

It seems unfortunate not to take advantage of the fact that we know where the zeros are, and we know that the pullback is going to map zeros to zeros, since it is linear, so we should be able to skip some work.
But I don't know that there is a generic API for our structurally sparse matrices to know whether an index will be zero.

@mcabbott (Member, Author) commented May 4, 2021

Maybe I misunderstand you, but both λ .* Diagonal(rand(3)) and this function do know where the zeros are, and do O(N) work. That's the only really sparse one.

For UpperTriangular, I haven't tried to time this against broadcasting... there could be trade-offs, maybe broadcasting skips half, but if so it needs lots of if statements. Frankly I doubt that anyone has ever called norm(::UpperTriangular) outside a test, though. So perhaps thinking about that can wait until this finds wider use where someone does need to care.

@mcabbott (Member, Author) commented May 4, 2021

It would also be good to fix this instability upstream. Can't we argue that the off-diagonal elements are a strong zero like false, and make NaN .* Diagonal(rand(3)) just work? Is there an issue?
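
For reference, the strong-zero behaviour of false that this refers to:

julia> false * NaN
0.0

julia> false * Inf
0.0

julia> 0.0 * NaN   # an ordinary zero is weak; NaN propagates
NaN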

(Member) replied:

I feel like all structural zeros should be strong, yes.
I was sure I had seen Julia displaying that behaviour on SparseCSC matrices, but I can't reproduce it right now.

@mcabbott force-pushed the norm branch 2 times, most recently from cae1531 to 15e55ee on April 29, 2021
@oxinabox (Member) commented Apr 29, 2021

I wonder if we should just add StaticArrays as a dependency (it is a super popular package), and then have a const PrimitiveArray{T,N} = Union{StridedArray{T,N}, StaticArray{T,N}}, or even const PrimitiveArray{T,N} = Union{Array{T,N}, StaticArray{T,N}}, and that will break the current impasse.

StaticArray (like Array) is a fundamental array type that doesn't wrap some other type, and isn't really better represented as a struct or tuple either.

GPUArrays are also primitive arrays in this sense, I think?
But maybe we can worry about that later.
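
A hedged sketch of that union (note StaticArray's real type parameters are {S<:Tuple, T, N}, so the spelling above needs a small tweak; this is illustrative, not part of any package):

using LinearAlgebra, StaticArrays

# Illustrative only: "fundamental" array types that don't wrap another array.
const PrimitiveArray{T,N} = Union{StridedArray{T,N}, StaticArray{<:Tuple,T,N}}

@assert rand(3) isa PrimitiveArray               # Array is strided
@assert SA[1.0, 2.0] isa PrimitiveArray          # SVector is a StaticArray
@assert !(Diagonal(rand(3)) isa PrimitiveArray)  # wrappers are excluded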

@mcabbott force-pushed the norm branch 3 times, most recently from ae62bc0 to 741b141 on April 29, 2021
@codecov-commenter commented Apr 29, 2021

Codecov Report

Merging #337 (87e4313) into master (987ee45) will increase coverage by 0.02%.
The diff coverage is 100.00%.


@@            Coverage Diff             @@
##           master     #337      +/-   ##
==========================================
+ Coverage   98.46%   98.49%   +0.02%     
==========================================
  Files          23       23              
  Lines        1893     1929      +36     
==========================================
+ Hits         1864     1900      +36     
  Misses         29       29              
Impacted Files Coverage Δ
src/rulesets/LinearAlgebra/norm.jl 100.00% <100.00%> (ø)
src/rulesets/LinearAlgebra/utils.jl 88.88% <100.00%> (+0.65%) ⬆️

Legend: Δ = absolute <relative> (impact), ø = not affected, ? = missing data

@mcabbott (Member, Author) commented May 4, 2021

Maybe this isn't a bad place to discuss the business of whether to write rules for AbstractArrays, or just Array. I think we may be talking past each other by generalising from experiences with different examples, without laying out what the assumptions are. I'm not sure the case for such restrictions has been made in any detail anywhere, but perhaps someone would care to make it here. The case against:

Consider the following arrays, all AbstractMatrix:

begin
    n = 100
    M = rand(n,n)
    U = UpperTriangular(M)
    D = Diagonal(rand(n))
    P = PermutedDimsArray(M, (2,1))
end;

The function norm treats all of these wrappers as unknown generic arrays: it calls generic_norm2, which is pure Julia and calls getindex, which then acts on the underlying Array (which, I think, is written in C). So in principle a rule for getindex(::Array, ::Int...) alone would allow reverse-mode AD to produce correct answers.

This is not true for norm(M), which calls BLAS.nrm2(M). So you are always going to need a rule for that. One extreme position would be to write only such rules, only for ccalls.

But another position is that providing high-level rules allows more efficient code. Here's the present behaviour -- I guess you can blame getindex for this but anyway, being 3, or maybe 5, orders of magnitude slower if you miss the rule seems hard to ignore:

julia> @btime gradient(LinearAlgebra.generic_norm2, $M);
  149.084 ms (736258 allocations: 1.51 GiB)   # Zygote
  163.940 ms (796266 allocations: 788.73 MiB) # ... with one-hot getindex
  76.235 ms (270010 allocations: 772.17 MiB)  # Tracker

julia> @btime gradient(norm, $M);
  13.666 μs (2 allocations: 78.20 KiB)    # Zygote -- 1000x faster
  138.709 μs (91 allocations: 627.80 KiB) # Tracker

julia> @btime norm($M);  # subtract the forward pass and it's 10^5 times 
  12.416 μs (0 allocations: 0 bytes)

julia> @btime LinearAlgebra.generic_norm2($M);
  27.958 μs (0 allocations: 0 bytes)

julia> @btime norm($U);
  28.833 μs (0 allocations: 0 bytes)

Note that norm(U) does not unwrap and call norm(U.data), because that would be wrong. Even norm(M') does not unwrap, that's an optimisation nobody got around to adding -- this is the vector norm of a matrix, perhaps rarely used. There are many fallback functions like this.

In fact I would argue that such fallbacks are the reason UpperTriangular <: AbstractMatrix, and so on. The wrapper's main function is to be a flag by which LinearAlgebra routines can signal to each other, tidier than 'U'::Char, but this does not need an AbstractMatrix. The supertype is an opt-in to all sorts of other useful behaviour -- it favours making stuff work, over only getting optimal routines. This is even stronger for Diagonal <: AbstractMatrix, which explicitly opts in to lots of O(n^2) behaviour, while still being mathematically correct. Using Diagonal rather than diagm means you will hit some faster routines, but Julia provides no guarantee. If you want to be completely sure, you should handle the Vector yourself (or define a StrictDiagonal type without getindex, perhaps).

Diagonal is pretty unusual in having different complexity. It is special-cased here (along with some friends) only to avoid a strange type-instability. Note again that the forward pass norm(D) is not specialised, and is quite a bit slower than the dense matrix:

julia> @btime gradient(norm, $D);
  32.041 μs (1 allocation: 896 bytes)       # Zygote, with this PR (or master)
  32.083 μs (4 allocations: 944 bytes)      # Zygote, without the special case
  124.333 μs (176 allocations: 487.84 KiB)  # Tracker, generic rule?

julia> @btime norm($D);  # forward pass dominates!
  32.000 μs (0 allocations: 0 bytes)

Ignoring this type instability issue, the generic rule already produces a Diagonal gradient, because the broadcasting it uses specialises. By being generic we are, almost accidentally, already doing better than the original function. If this failed, and the generic routine produced an O(n^2) rule where O(n) would have been possible, the asymptotic regime is still far away: n >> 10^5 is an 80GB matrix. And, adding a special case for Diagonal is also entirely OK -- especially if the original forward function has such a special case.

(I think we should also add an overall projection step which ensures the preservation of structured types such as Diagonal. Doing this would mean that, even if you hit a generic function at one step, it does not propagate backwards through other steps. I think there's broad agreement on this, and it's nearly orthogonal to the present discussion.)
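
A minimal sketch of that projection step (a hypothetical helper, named project here for illustration; ChainRulesCore later added a ProjectTo mechanism along these lines):

using LinearAlgebra

# Hypothetical: map a possibly-dense cotangent back onto the primal's
# structure, so one generic rule can't leak dense gradients downstream.
project(::Diagonal, dx::AbstractMatrix) = Diagonal(diag(dx))
project(::UpperTriangular, dx::AbstractMatrix) = UpperTriangular(dx)
project(::AbstractArray, dx::AbstractArray) = dx   # fallback: no structure

julia> project(Diagonal(ones(3)), ones(3, 3)) isa Diagonal
true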

There isn't a special case for SMatrix, nor CuArray, but since the gradient is just made of other array operations, and those all specialise, these work well. This seems to be precisely the kind of "unrelated packages working well together by dispatch" which is the advert for Julia.

julia> @btime gradient(norm, $(@SMatrix(rand(10,10)))); 
  126.949 ns (0 allocations: 0 bytes)  # Zygote -- 60x faster than Array

julia> @btime gradient(LinearAlgebra.generic_norm2, $(@SMatrix(rand(10,10))));
  601.334 μs (8955 allocations: 707.45 KiB) # Zygote -- 4700x slower

julia> @btime norm($(@SMatrix(rand(10,10)))); # uses its own routine
  38.642 ns (0 allocations: 0 bytes)

julia> @btime LinearAlgebra.generic_norm2($(@SMatrix(rand(10,10))));
  203.062 ns (0 allocations: 0 bytes)  # naive code is 5x slower, same on Array

At present Zygote can't differentiate StaticArrays.norm, so perhaps timing generic_norm2 here is unfair? But still, the 5x difference between the naive code and whatever special routine StaticArrays is using is precisely the sort of specialisation which (in other operations) the generic gradient can exploit. Without resorting to ccalls, there are lots of ways to write Julia array code which is much faster than the most naive loops. And for GPUArrays, the most naive loops won't work at all. If the rule encodes what we know to be true, mathematically, and acts at a high level, then it leaves room to re-use all these specialisations downstream.

The proposal to consider Union{Array,StaticArray} seems to me a strange one. There are no ccalls in there, it's just getindex on a Tuple ultimately. So the "extreme position" above would ignore it. What would the argument be to treat that at the array level, but leave U (and the adjoint of an SMatrix, etc.) to use scalar code? If this is considered a "fundamental array type", what's the definition, and does a FillArray, or a MappedArray, or a JuliennedArrays.Align, count too? I suppose the position I am arguing is that (say) a ReadonlyMappedArray{Float64,1,<:UnitRange} has opted in to being treated like a vector, at least as the fallback.

Some of my examples here could be evaded by making LinearAlgebra.norm un-wrap the wrappers defined in the same module. But there will never be a shortage of others, for instance P is a lazy wrapper defined outside LinearAlgebra, and there are many more defined in packages (such as OffsetArrays). Making AD work well with as many of them as possible seems like a worthwhile goal.

Being generic does introduce more ways to mess up -- it's obviously easier to test code which only accepts Matrix{Float64} than it is to test code which allows for (say) a structured matrix of unitful dual numbers. This is the cost of working in Julia, and the benefits are why we're all here. (We discussed elsewhere ideas for tooling to make such tests easier. Note that the present rule is so narrowly written that it fails for Matrix{Int}, missed by tests.)

For those proposing much more restricted rules, would you care to argue where you disagree with the above? For norm acting on U, D, P, or just on some other specific wrapper? Or for some other particular function -- perhaps there are classes of functions which ought to be treated differently, maybe lu is one? Can we find the dividing line?

(I think getindex is one, at least getindex(A, i::Int...), as that really does pass through wrappers in a simple way. But perhaps that's its own topic, as it's not a whole-array function like the others here.)

See also:

@willtebbutt (Member) commented:

You are picturing f(::Diagonal) being an AD-friendly structure-exploiting method, but what also happens all the time is that norm(::Matrix{<:HwFloat}) is a faster but AD-unfriendly implementation, and norm(::CuArray) is something else entirely. Is there an automatic way you could distinguish these?

I would be surprised if there were an automatic way to distinguish these situations.

I think we've narrowed the range of functions we're discussing down a bit now, which is good -- as I understand it, we're now just considering whether, for functions for which specialised methods exist for a particular type, it's better to default to utilising generic rrules, or to let AD have a go at differentiating the specialisation directly. I get the feeling that the former is what you want for the kinds of situations you're considering, and the latter is preferable in the situations that I care about.

So I think we're in the realm of trying to pick a tradeoff. Is this also your understanding?

In terms of giving us nicer options, I wonder whether it's worth enabling two kinds of rrules

  1. our current very aggressive ones that say "I'm sure that I'm the thing you always want to call, never do AD, regardless any specialisations which may exist", and
  2. ones that are less aggressive, as per the proposal above, which express "use me only if there's no more specific code for AD to have a crack at differentiating".

This would give us the option to make quite strong statements where appropriate, and weaker ones where we just want to ensure that the fallback behaviour is reasonable.

@mcabbott (Member, Author) commented May 7, 2021

Yes, I think that's a fair statement of what's going on.

If there's no automatic way, then some manual way to mark these distinctions is required.

However, I'm not sure that tagging the original abstract rule is enough, as you suggest with these two classes of rrules. We'd like the generic rule to be employed for norm(::CuArray) and norm(::StaticArray), but not for say norm(::FillArray), assuming that's sufficiently AD friendly to prefer no rule. So the same function has to be in both classes.

(And to make automation harder, the overload is one method down, norm2(a::FillArrays.AbstractFill). It's fast but whether it's AD-able I haven't actually checked. This structure is also very common, of course.)

The other manual way is to opt out. Writing rrule(::typeof(norm), ::AbstractFill) = nothing doesn't in fact work, but my misunderstanding here led me to see that it should be easy to make it work (see edited code 4 messages up). This still seems attractive to me, in that (1) if anyone is writing a specialisation for f and Fill then they certainly have both packages loaded, and (2) if they are doing so while thinking even a little bit about AD, then asking that they further add a line like ChainRulesCore.@no_rrule f(::AbstractFill) doesn't seem like an unreasonable burden.

By contrast, making people who don't care about AD have their specialisation norm(::StrideArray{Float32}) opt in (by adding a dependency) seems much harder to sell.
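
A sketch of the opt-out spelling being floated (hypothetical: as noted above, returning nothing did not actually work at the time, and ChainRulesCore later added an @opt_out macro for exactly this purpose):

using LinearAlgebra, ChainRulesCore, FillArrays

# Hypothetical semantics: the most specific method returns `nothing`,
# which the AD would read as "no rule here, differentiate norm itself".
ChainRulesCore.rrule(::typeof(norm), ::FillArrays.AbstractFill) = nothing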

@oxinabox (Member) left a comment:

While we are having useful conversations in this thread (that I have not yet caught up on), this code benchmarks well and passes tests.
Can we address my last few comments, and then I will merge it.
If we want to change things later, we can.
(Nice thing about working in code rather than wood: mistakes can be removed invisibly.)

Labels: type constraints (Potentially raises a question about how tightly to constrain argument types for a rule. See #232)

Successfully merging this pull request may close these issues: "Internal error from a BoundsError", "Operations on StaticArrays give runtime BoundsError".

9 participants