Making Testing More Automatic #213

willtebbutt opened this issue Sep 6, 2021 · 6 comments


willtebbutt commented Sep 6, 2021

In an ideal world, a user would test an rrule by writing something like the following, and have it work all of the time:

test_rrule(foo, args...)

By "work all of the time", I mean that the tests we want to run to determine the correctness of an rrule implementation always run successfully, provided that the function is something we know how to test (broadly speaking, the output is deterministic given the input), and for any input type that is either

  1. a primitive that we know about (Real, Array, etc), or
  2. a composite type.

It's important that this works automatically because we want people to be testing their code using CRTU, and people like to define new types (including new AbstractArrays) and new functions. Unfortunately, I don't believe it's possible to automate in all cases, but the way in which it fails (AFAICT) is very specific, and I think we can document it and make it easy to resolve for users.

Roughly speaking, the functionality that needs to always work in order to achieve this is

  1. to_vec
  2. to_vec_tangent (a new function)
  3. rand_tangent
  4. test_approx

to_vec, to_vec_tangent and rand_tangent can be made to "always work", but test_approx occasionally has a quirk that I don't believe we can automate.

The outcome is the following proposals:

  1. remove all (or at least most) to_vec implementations in favour of the generic to_vec implementation for isstructtype types, keeping only the necessary to_vec implementations for isprimitivetype types,
  2. introduce a to_vec_tangent (better name welcome) function, which is like to_vec, but the closure returned returns a tangent rather than a primal,
  3. add a function called remove_junk_data, or something similar, which applies to primals and returns another object containing only the bits of the primal relevant for defining isapprox. Whenever we test rules, we test the composition of remove_junk_data and the function being tested, rather than just the function. This enables us to define test_approx in a really generic manner.

I'll explain throughout this issue why I believe these are sensible proposals, and how they resolve things.

Additionally, while this proposal is independent from other proposed changes, it clearly favours a structural view of the world because I'm interested in automating things. See JuliaDiff/ChainRulesCore.jl#449 for a proposal for how we can do this without sacrificing usability, and how this leads to a precise definition for natural tangents.

I would be really interested to know if anyone thinks I've obviously missed something, or whether this sounds about right.

edit: I completely neglected constraint-related problems (eg. if the tangent provided to FiniteDifferences needs to represent a positive definite matrix for some reason). AFAICT the things discussed are essentially orthogonal to the constraint problems though.

edit2: note: undefined references are not fun. For example, perfectly well-defined Dict objects can contain undefined memory. I think this probably comes under the heading of "junk" data, but it seems to cause problems for to_vec as it's currently defined. I wonder whether it could be generalised?

Example 1: Diagonal size mismatch

Consider testing

f(x::Diagonal) = 5x

Let the input and the output (co)tangent be

x = Diagonal(randn(2))
dx = Tangent{typeof(x)}(diag=randn(2))


FiniteDifferences.j′vp(central_fdm(5, 1), f, dx, x)

produces the error:

ERROR: DimensionMismatch("second dimension of A, 4, does not match length of x, 2")
 [1] gemv!(y::Vector{Float64}, tA::Char, A::Matrix{Float64}, x::Vector{Float64}, α::Bool, β::Bool)
   @ LinearAlgebra /Users/julia/buildbot/worker/package_macos64/build/usr/share/julia/stdlib/v1.6/LinearAlgebra/src/matmul.jl:530
 [2] mul!
   @ /Users/julia/buildbot/worker/package_macos64/build/usr/share/julia/stdlib/v1.6/LinearAlgebra/src/matmul.jl:97 [inlined]
 [3] mul!
   @ /Users/julia/buildbot/worker/package_macos64/build/usr/share/julia/stdlib/v1.6/LinearAlgebra/src/matmul.jl:275 [inlined]
 [4] *(transA::Transpose{Float64, Matrix{Float64}}, x::Vector{Float64})
   @ LinearAlgebra /Users/julia/buildbot/worker/package_macos64/build/usr/share/julia/stdlib/v1.6/LinearAlgebra/src/matmul.jl:87
 [5] _j′vp(fdm::FiniteDifferences.AdaptedFiniteDifferenceMethod{5, 1, FiniteDifferences.UnadaptedFiniteDifferenceMethod{7, 5}}, f::Function, ȳ::Vector{Float64}, x::Vector{Float64})
   @ FiniteDifferences ~/.julia/packages/FiniteDifferences/W3rQO/src/grad.jl:80
 [6] j′vp(fdm::FiniteDifferences.AdaptedFiniteDifferenceMethod{5, 1, FiniteDifferences.UnadaptedFiniteDifferenceMethod{7, 5}}, f::Function, ȳ::Tangent{Diagonal{Float64, Vector{Float64}}, NamedTuple{(:diag,), Tuple{Vector{Float64}}}}, x::Diagonal{Float64, Vector{Float64}})
   @ FiniteDifferences ~/.julia/packages/FiniteDifferences/W3rQO/src/grad.jl:73
 [7] top-level scope
   @ REPL[24]:1
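A stdlib-only sketch of where the 4 and the 2 in this error plausibly come from, assuming (as the message suggests) that the output Diagonal is flattened densely while the Tangent cotangent is flattened through its diag field only. fd_jacobian is a hypothetical central-difference helper, not the FiniteDifferences API.

```julia
using LinearAlgebra

# Hypothetical helper: central-difference Jacobian of a vector-to-vector g at v.
function fd_jacobian(g, v::Vector{Float64}; h=1e-6)
    y = g(v)
    J = zeros(length(y), length(v))
    for j in eachindex(v)
        e = zeros(length(v))
        e[j] = h
        J[:, j] = (g(v + e) - g(v - e)) / (2h)
    end
    return J
end

f(x::Diagonal) = 5x
d = [1.0, 2.0]

# Output flattened densely (4 numbers) vs structurally (only the diag field).
J_dense = fd_jacobian(v -> vec(Matrix(f(Diagonal(v)))), d)   # 4×2
J_struct = fd_jacobian(v -> f(Diagonal(v)).diag, d)           # 2×2

ybar = [1.0, 1.0]   # the diag field of a structural cotangent

# Pairing a 4-row Jacobian with a 2-element cotangent reproduces the mismatch.
dense_fails = try
    transpose(J_dense) * ybar
    false
catch err
    err isa DimensionMismatch
end

# Flattening output and cotangent the same way gives a consistent result.
grad_struct = transpose(J_struct) * ybar   # ≈ [5.0, 5.0]
```

Flattening the output and its cotangent through the same structure removes the mismatch.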

Why is this example a problem?

Firstly, we presently require that rules accept either a natural or structural tangent. Due to the above, it's not currently possible to test functions which output a Diagonal with a Tangent tangent.

Secondly, there exist Diagonal matrices whose tangent cannot be represented by a Diagonal. Specifically, any Diagonal whose diag field doesn't provide a way to produce an AbstractVector as its tangent (i.e. for whatever reason, lacks a natural tangent). Consequently, in order for our testing facilities to handle any type, they must be able to work with structural tangents.

Finally, our current implementations special-case to_vec for lots of different arrays (Diagonal, Symmetric, etc). This is a problem in itself, and moreover we're never entirely sure what the right thing to do is when we encounter a new array.

How to fix this problem

Remove the specialised to_vec methods for Diagonal and other struct AbstractArrays (UpperTriangular, Symmetric, etc), and instead just rely on the generic to_vec operation for structs.

Doing this immediately means that we can to_vec anything that is either

  1. a primitive that we've defined to_vec on, or
  2. any struct or mutable struct.

This solution brings into focus a problem that we're currently solving on an ad-hoc basis in to_vec: "junk" data in e.g. the lower triangle of a Symmetric can wind up being used in approximate equality checks (and could in principle introduce non-determinism in an otherwise deterministic function, although I've yet to find an example of this in the wild), which makes no sense. We'll address this later.
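A minimal sketch of the generic, structural approach, written as a hypothetical structural_vec so it isn't confused with FiniteDifferences.to_vec. It only flattens (it builds no reconstruction closure) and needs nothing beyond LinearAlgebra. It shows both the benefit (a Diagonal flattens to its two diag entries) and the junk-data problem (the lower triangle of a Symmetric's data field is picked up).

```julia
using LinearAlgebra

# Hypothetical structural flattening: primitives become numbers, Arrays are
# flattened elementwise, and any other struct is flattened field by field.
structural_vec(x::Real) = [float(x)]
structural_vec(x::Complex) = [float(real(x)), float(imag(x))]
structural_vec(::Char) = Float64[]   # no differentiable content
structural_vec(x::Array) = mapreduce(structural_vec, vcat, x; init=Float64[])

function structural_vec(x::T) where {T}
    Base.isstructtype(T) || throw(ArgumentError("expected a struct type, got $T"))
    return mapreduce(n -> structural_vec(getfield(x, n)), vcat, fieldnames(T); init=Float64[])
end

dg = Diagonal([1.0, 2.0])
dg_flat = structural_vec(dg)   # only the diag field, not the 4 dense entries

s = Symmetric([1.0 2.0; 3.0 4.0])
s_flat = structural_vec(s)     # includes the junk entry s.data[2, 1] == 3.0
```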

Example 2: to_vec gives the wrong type sometimes

to_vec only knows about primals -- it knows nothing about tangents. This is because it was written when we knew nowhere near enough about tangents, in particular for arrays. The particular problem is the call to vec_to_x on this line in j′vp. It attempts to convert a "flat" vector representation of a cotangent into a primal. While this works fine in some cases (a surprisingly large number, given how much mileage we've gotten out of to_vec over the years), we know that it doesn't work for all types.

Once you've removed the to_vec implementations for the various concrete subtypes of AbstractArray, you'll find that the following returns a Diagonal rather than a Tangent:

FiniteDifferences.j′vp(central_fdm(5, 1), identity, dx, x)[1]


2×2 Diagonal{Float64, Vector{Float64}}:

Why is this a problem

Having the pullback for identity return anything other than whatever cotangent it is provided seems highly undesirable to me, so I'm going to assume that our rrule for identity does just that. If that is the case, then the cotangent returned by the pullback produced by that rrule will be a Tangent if the incoming cotangent is a Tangent, not a Diagonal. This means that our current implementation is incorrect. While this particular example seems reasonably benign, to my mind it's not correct. However, even if you believe it's correct, it's clearly only correct because Diagonal{Float64, Vector{Float64}}s happen to have nice natural tangents that happen to be produced by to_vec, rather than by design.

A more obviously incorrect / plainly uninterpretable example is a Symmetric -- the from_vec closure returned by to_vec(::Symmetric) will produce a Symmetric with an uplo field that is a Char. Since a Char isn't an appropriate tangent for a Char (it should be a NoTangent), this is plainly nonsensical if the goal is to obtain a tangent. If you wound up comparing between this representation of the tangent and a Tangent output from AD, you would need to compare a NoTangent with this Char, which under any sensible definition would fail (I can't imagine a world in which I would wish to reside in which NoTangent is considered equal to a Char).

How to fix this problem

Introduce another function, to_vec_tangent (a better name would be nice), which returns a closure that always returns an appropriate tangent representation (primitive for primitives, structural for composites). This would require roughly the same level of implementation effort as to_vec, and would mirror its structure almost entirely (specific methods for primitives, generic method for all isstructtype types).
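A stdlib-only sketch of that idea applied to the Symmetric case. The name tangent_vec is hypothetical, and a NamedTuple plus a local NoTangentStandIn stand in for ChainRulesCore's Tangent and NoTangent so the block runs without extra packages. The point is that the closure rebuilds a structural tangent whose uplo entry is a "no tangent" marker rather than a Char.

```julia
using LinearAlgebra

# Stand-in for ChainRulesCore.NoTangent so the sketch needs only stdlib.
struct NoTangentStandIn end

# Hypothetical to_vec_tangent-style function: returns a flat vector and a
# closure that rebuilds a *tangent* (a NamedTuple for structs), not a primal.
tangent_vec(x::Real) = ([float(x)], v -> v[1])
tangent_vec(::Char) = (Float64[], _ -> NoTangentStandIn())
tangent_vec(x::Array{<:Real}) = (vec(float.(x)), v -> reshape(collect(v), size(x)))

function tangent_vec(x::T) where {T}
    Base.isstructtype(T) || throw(ArgumentError("expected a struct type, got $T"))
    fnames = fieldnames(T)
    parts = map(n -> tangent_vec(getfield(x, n)), fnames)
    flat = vcat(Float64[], map(first, parts)...)
    function struct_tangent_from_vec(v)
        vals = Any[]
        offset = 0
        for (piece, back) in parts
            # Each field's closure gets its own contiguous slice of v.
            push!(vals, back(v[offset+1:offset+length(piece)]))
            offset += length(piece)
        end
        return NamedTuple{fnames}(Tuple(vals))
    end
    return flat, struct_tangent_from_vec
end

s = Symmetric([1.0 2.0; 3.0 4.0])
v, back = tangent_vec(s)
t = back(v)   # NamedTuple: t.data is a Matrix, t.uplo is NoTangentStandIn()
```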

Example 3: Propagation of Junk Data


x = Symmetric(randn(2, 2))
dx = Tangent{typeof(x)}(data=randn(2, 2))
FiniteDifferences.j′vp(central_fdm(5, 1), identity, dx, x)[1].data


2×2 Matrix{Float64}:
 -0.180472  -0.793039
  0.740994   0.900423

Note that the data field is the relevant bit of the output from j′vp here, because the consistent / correct interpretation of the thing that FiniteDifferences outputs is a Tangent, not a Symmetric, as discussed in the previous example. Observe that the lower triangle (element (2, 1)) will be used when test_approx is computed, because the generic definition of test_approx doesn't know about the specific semantics of Symmetric. Since the standard library makes no promises about the lower triangle of a Symmetric, it seems intuitive to me that we shouldn't have to worry about it in our gradient definitions. I'm happy to expand on this, but there's a good example here.

The solution I believe is best is to ensure that the gradient w.r.t. irrelevant elements is always 0 by always testing

x -> remove_junk_data(f(x))

rather than just f. The function remove_junk_data would be defined such that it doesn't propagate any junk data (data which isn't relevant for equality computations). The implementations that I have so far are things like:

using LinearAlgebra

remove_junk_data(x::Number) = x

remove_junk_data(x::StridedArray) = map(remove_junk_data, x)

remove_junk_data(x::Symmetric{T, <:StridedArray{T}}) where {T} = collect(x)

remove_junk_data(x::UpperTriangular{T, <:StridedArray{T}}) where {T} = collect(x)

remove_junk_data(x::LowerTriangular{T, <:StridedArray{T}}) where {T} = collect(x)

# Generic fallback: recurse into the value of every field, so composites are
# compared in full unless the type author opts out with a specific method.
function remove_junk_data(x::T) where {T}
    Base.isstructtype(T) || throw(ArgumentError("Expected a struct type"))
    return NamedTuple{fieldnames(T)}(map(name -> remove_junk_data(getfield(x, name)), fieldnames(T)))
end
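A small check of the claim that composing with remove_junk_data sends the gradient with respect to junk entries to zero. It reuses the Symmetric method from the list above; fd_entry and the weights W are hypothetical, chosen only to build a scalar test function.

```julia
using LinearAlgebra

# Mirrors the Symmetric method above: materialise the matrix so that only the
# entries Symmetric exposes through getindex remain.
remove_junk_data(x::Symmetric) = collect(x)

# Hypothetical helper: central-difference derivative of a scalar g w.r.t. X[i, j].
function fd_entry(g, X::Matrix{Float64}, i, j; h=1e-6)
    Xp = copy(X); Xp[i, j] += h
    Xm = copy(X); Xm[i, j] -= h
    return (g(Xp) - g(Xm)) / (2h)
end

W = [1.0 2.0; 3.0 4.0]       # fixed weights for a scalar test function
A = [1.0 5.0; -7.0 2.0]      # A[2, 1] is junk for an upper-triangular Symmetric
g(X) = sum(W .* remove_junk_data(Symmetric(X)))

d21 = fd_entry(g, A, 2, 1)   # exactly 0.0: the lower triangle is never read
d12 = fd_entry(g, A, 1, 2)   # ≈ W[1, 2] + W[2, 1] = 5.0
```

Because g never reads A[2, 1], both perturbed evaluations are bit-identical, so the finite-difference estimate for that entry is exactly zero rather than merely small.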

Another option one could consider is trying to define equality properly on Tangents. This isn't general though because e.g. the data field of a Tangent{Symmetric} might itself be a Tangent, which doesn't have a conception of its own lower triangle. The benefit of composing with remove_junk_data is that we get to operate on primal types, whose semantics everyone is familiar with (the data field of a Symmetric definitely does know about triangles because it's an AbstractArray and has getindex defined).

So we can instruct type authors (ourselves for stdlib types) that if their types have any data that's essentially "junk" they must define a method of remove_junk_data, and accept that we'll have to expend some extra computation internally to differentiate remove_junk_data when testing (this can probably be optimised away in most cases, since it's the identity function in most cases).

Note that the generic fallback for composites means that we'll get overly restrictive tests by default, and type-authors have to opt-in to say that some bits of their type aren't important. This seems like the desirable way around to me -- I'd rather have tests yelling at me when they ought not to be, than them to fail to yell at me when they should.


Assuming that this pans out, this is a win-win for developers and users.

Developers get simpler, more robust, more straightforward to understand code with fewer edge cases -- the edge cases that remain have clear semantics and it's clear why they're necessary.

Users benefit from more predictable and reliable infrastructure.

The issue with this proposal is that it requires structural tangents to actually be taken seriously by everyone. Again, see JuliaDiff/ChainRulesCore.jl#449 for a discussion of how to make this more straightforward for all involved.

Member Author

willtebbutt commented Sep 6, 2021

Another way to go about the equality problem is to avoid defining approximate equality on tangents entirely, and instead rely on

  1. being able to add a tangent to a primal
  2. checking whether two primals are approximately equal.

i.e. implement test_rrule along the lines of

dx_ad = _compute_ad_cotangent
dx_fd = _compute_fd_cotangent

x_ad = x + dx_ad
x_fd = x + dx_fd

test_approx(x_ad, x_fd)

While test_approx would need to be implemented for primals, we could hijack isapprox for many types with junk data (in particular any AbstractArrays containing junk data), and this might be more intuitive for type implementers than having to implement remove_junk_data. We could still utilise generic definitions for test_approx for arbitrary structs, because we're not bound to the isapprox definition of approximate equality in all cases.
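A stdlib-only sketch of the "add the tangent to the primal, then compare primals" route for a Symmetric. add_tangent is a hypothetical helper and the NamedTuple stands in for a Tangent{Symmetric}; the comparison relies only on Base's isapprox for arrays, which goes through getindex and so ignores the lower triangle of the data.

```julia
using LinearAlgebra

# Hypothetical: add a structural tangent (NamedTuple stand-in for
# Tangent{Symmetric}) to a Symmetric primal, field by field.
function add_tangent(x::Symmetric, dx::NamedTuple{(:data,)})
    uplo = x.uplo == 'U' ? :U : :L
    return Symmetric(parent(x) + dx.data, uplo)
end

x = Symmetric([1.0 2.0; 2.0 3.0])
dx_ad = (data = [0.1 0.2; 0.0 0.3],)
dx_fd = (data = [0.1 0.2; 9.9 0.3],)   # differs only in the junk lower triangle

tangents_match = dx_ad.data ≈ dx_fd.data                        # compares junk too
primals_match = add_tangent(x, dx_ad) ≈ add_tangent(x, dx_fd)   # ignores junk
```

Comparing the raw tangents fails because of the (2, 1) entry, while comparing the perturbed primals succeeds, since an upper Symmetric never reads that entry.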

edit: the issue with this proposal is, of course, that you can't always add a tangent to a primal, because of constraints.

Member Author

willtebbutt commented Sep 6, 2021

Sketch implementation for to_vec_tangent (untested):

using ChainRulesCore: Tangent, NoTangent

function to_vec_tangent(x::Real)
    Real_Tangent_from_vec(x_vec) = first(x_vec)
    return [x], Real_Tangent_from_vec
end

function to_vec_tangent(z::Complex)
    Complex_Tangent_from_vec(z_vec) = Complex(z_vec[1], z_vec[2])
    return [real(z), imag(z)], Complex_Tangent_from_vec
end

# Only real vectors are already "flat"; complex vectors go through the generic
# Vector method so that each element becomes [real, imag].
to_vec_tangent(x::Vector{<:Real}) = (x, identity)

function to_vec_tangent(x::Vector)
    x_vecs_and_backs = map(to_vec_tangent, x)
    x_vecs, backs = first.(x_vecs_and_backs), last.(x_vecs_and_backs)
    function Vector_Tangent_from_vec(x_vec)
        sz = cumsum(map(length, x_vecs))
        return [backs[n](x_vec[sz[n] - length(x_vecs[n]) + 1:sz[n]]) for n in eachindex(x)]
    end

    # handle empty x
    x_vec = isempty(x_vecs) ? eltype(eltype(x_vecs))[] : reduce(vcat, x_vecs)
    return x_vec, Vector_Tangent_from_vec
end

function to_vec_tangent(x::Array)
    x_vec, Tangent_from_vec = to_vec_tangent(vec(x))

    function Array_Tangent_from_vec(x_vec)
        return collect(reshape(Tangent_from_vec(x_vec), size(x)))
    end

    return x_vec, Array_Tangent_from_vec
end

# A Char carries no differentiable information, so its tangent is NoTangent().
to_vec_tangent(::Char) = (Bool[], _ -> NoTangent())

# Any struct ought to be interpretable as a Tangent, regardless of inner constructors etc.
function to_vec_tangent(x::T) where {T}
    Base.isstructtype(T) || throw(ArgumentError("Expected a struct type"))
    isempty(fieldnames(T)) && return (Bool[], _ -> NoTangent()) # Singleton types

    val_vecs_and_backs = map(name -> to_vec_tangent(getfield(x, name)), fieldnames(T))
    vals = first.(val_vecs_and_backs)
    backs = last.(val_vecs_and_backs)

    # Flatten the per-field vectors as a Vector, reusing the offset logic above.
    v, Tangents_from_vec = to_vec_tangent(collect(vals))
    function structtype_Tangent_from_vec(v::Vector{<:Real})
        val_vecs = Tangents_from_vec(v)
        tangents = map((b, w) -> b(w), backs, val_vecs)
        return Tangent{T}(; NamedTuple{fieldnames(T)}(Tuple(tangents))...)
    end
    return v, structtype_Tangent_from_vec
end

Seems like the generic method will be slightly simpler than to_vec, because there's no need to mess with Julia internals to reconstruct types that have restrictive constructors, but otherwise appears to me to be basically the same.


oxinabox commented Sep 6, 2021

I also would like to get rid of to_vec per JuliaDiff/FiniteDifferences.jl#97
by making FiniteDifferences stop using it and use ChainRules types instead.

It may be that we need to incorporate a new thing that helps teach it how to perturb an object through all its basis points, where each perturbation returns a ChainRules type.
But I am hopeful that we can work out a way that this is rare.
But such a way to control how it perturbs might also help with objects that must remain on a manifold that is not reflected in their type, e.g. ##152 and #99.
This is similar but not quite the same as to_vec_tangent, I think?

(From there I would like to be using norm on those to do things the inner product way #204 )

I suspect the easy fix for issues with test_approx right now is to not use it directly on tangents, but rather to add the primal to each tangent being compared and then compare them, no?

Member Author

I suspect the easy fix for issues with test_approx right now is to not use it directly on tangents, but rather add the primal to each tangent being compared and them compare them, no?

#213 (comment)

Also my thoughts. The only problem you get is the standard "can perturb out of the set of types that can be represented" problem. remove_junk_data doesn't have that problem.

I also would like to get rid of to_vec per JuliaDiff/FiniteDifferences.jl#97

Agreed -- this would be nice. However, my expectation is that it would be a non-trivial amount of work to make that work properly, while implementing to_vec_tangent is going to be almost trivial under the scheme proposed above. All I'm saying is that I would prefer that the above weren't considered blocked by what you're proposing because I think what I'm proposing could fix a decent chunk of code.

But also possibly that such a way to control how it perturbs might help with objects that must remain on manifold that is not reflected in their type. eg. ##152 and #99

Good point. I've edited my first comment to point out that I forgot about these kinds of problems. I agree they're important, but I think they're basically orthogonal to what I'm discussing here. Please correct me if I'm wrong.

Member Author

willtebbutt commented Sep 6, 2021

Fun story: I managed to get the default to_vec for structs to successfully to_vec a Dict. I did this by defining undefined / unassigned entries in arrays to be non-differentiable in general (which seems reasonable -- if such things had a type, we would make them non-differentiable). The same isn't true for Tangent{<:Dict}s, because we've made them primitives by defining the backing to be a Dict.
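A minimal sketch of what treating unassigned entries as non-differentiable could look like when flattening an array that contains #undef slots (as the storage backing a Dict can, per edit2 above). flatten_assigned is a hypothetical helper; it uses Base's isassigned to skip empty slots.

```julia
# Hypothetical: flatten an array that may contain #undef slots, treating
# unassigned slots as non-differentiable (they contribute nothing).
function flatten_assigned(x::Array{Vector{Float64}})
    out = Float64[]
    for i in eachindex(x)
        isassigned(x, i) && append!(out, x[i])   # skip #undef entries
    end
    return out
end

x = Vector{Vector{Float64}}(undef, 3)
x[1] = [1.0, 2.0]
x[3] = [3.0]
flat = flatten_assigned(x)   # x[2] is #undef and contributes nothing
```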

On balance, probably not worth worrying about undefined / unassigned entries unless we find more examples where they matter.


mzgubic commented Sep 28, 2021

One thing to point out about Example 1 is that the "reverse" problem is also there: a primal output that is a dense matrix will error with the same dimension mismatch if it receives a Diagonal tangent. That said, this can also be fixed with the tangent_to_vec, if it knows about the primal as you suggest in JuliaDiff/FiniteDifferences.jl#189 (comment)
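A stdlib-only sketch of the "reverse" case and of flattening a cotangent according to its primal. The signature tangent_to_vec(primal, dy) is a hypothetical reading of the suggestion, not an existing FiniteDifferences function.

```julia
using LinearAlgebra

# Hypothetical tangent_to_vec that flattens a cotangent according to the
# primal it is paired with, not according to the cotangent's own type.
tangent_to_vec(primal::Matrix{Float64}, dy::AbstractMatrix) = vec(Matrix(dy))

y = [1.0 2.0; 3.0 4.0]      # dense primal output
dy = Diagonal([1.0, 2.0])   # cotangent whose type differs from the primal's

struct_len = length(dy.diag)    # 2: what flattening the Diagonal itself gives
flat = tangent_to_vec(y, dy)    # same length as vec(y)
```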
