Initial finite differencing testing #14
I can't speak to Jarrett's overarching vision, so hopefully he can chime in on those parts, but overall this looks really good to me. I was just thinking about integrating FDM for the tests the other day, as it's proved immensely useful for Nabla. As I said in some in-line comments, I think much of the machinery here could actually be moved to FDM.
I'm in total agreement with you here. Will open a PR to FDM. Would definitely simplify a lot of stuff here and various other bits of work, as you point out.
src/rules/linalg.jl
Outdated
@@ -31,7 +31,7 @@ end
 function rrule(::typeof(inv), x::AbstractArray)
     Ω = inv(x)
     m = @thunk(-Ω)
-    return Ω, Rule(ΔΩ -> m' * ΔΩ * Ω')
+    return Ω, Rule(ΔΩ -> extern(m)' * ΔΩ * Ω')
Ah good catch. IIRC, this was from back when `adjoint` was defined on `t::Thunk` to be something like `@thunk(extern(t)')`.
Perhaps we should just move the adjoint into the `@thunk`, e.g.:
m = @thunk(-Ω')
return Ω, Rule(ΔΩ -> m * ΔΩ * Ω')
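For reference, the adjoint in question can be sanity-checked numerically. This is only a sketch using plain closures and `LinearAlgebra`; `central_fdm` and `inv_pullback` are hypothetical helper names made up for illustration, not functions from ChainRules or FDM, and the step size and test matrices are arbitrary choices.

```julia
using LinearAlgebra

# Hypothetical helper (not from ChainRules or FDM): central-difference
# estimate of the directional derivative of f at x in direction v.
central_fdm(f, x, v; h=1e-6) = (f(x + h * v) - f(x - h * v)) / (2h)

# Pullback for Ω = inv(x) as discussed in this thread: xbar = -Ω' * ΔΩ * Ω'.
function inv_pullback(x)
    Ω = inv(x)
    return Ω, ΔΩ -> -Ω' * ΔΩ * Ω'
end

x  = [4.0 1.0; 2.0 3.0]      # well-conditioned test point
ΔΩ = [0.5 -1.0; 2.0 0.25]    # arbitrary cotangent seed
v  = [1.0 0.5; -0.3 2.0]     # arbitrary tangent direction

Ω, pullback = inv_pullback(x)
xbar = pullback(ΔΩ)

# Adjoint identity: <ΔΩ, J*v> should match <J'*ΔΩ, v>.
lhs = dot(ΔΩ, central_fdm(inv, x, v))
rhs = dot(xbar, v)
```

If the rule is right, `lhs` and `rhs` agree up to finite-differencing error, which is the kind of check this PR automates.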
Do we even need a `Thunk` here? What do we gain in this particular case?
Good point. The general reason to use a `Thunk` for closed-over computations is to avoid performing the computation if it ends up being unnecessary; e.g. if a `Zero` is passed in for the differential. However, that seems pretty unlikely for a unary rule.
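A toy sketch of that idea, in case it helps. `ToyZero`, `ToyThunk`, `force` and `toy_mul` are hypothetical stand-ins written for this comment; they are not the ChainRules types or their dispatch rules, just the smallest thing that shows a deferred computation being skipped when the incoming differential is a known zero.

```julia
# Hypothetical stand-ins for illustration only; these are NOT the ChainRules types.
struct ToyZero end                      # a differential known to be zero
struct ToyThunk{F}                      # a deferred computation
    f::F
end
force(t::ToyThunk) = t.f()              # run the deferred computation

# Multiplying by a known zero never forces the thunk.
toy_mul(t::ToyThunk, ::ToyZero) = ToyZero()
toy_mul(t::ToyThunk, Δ) = force(t) * Δ

calls = Ref(0)                          # counts how often the closure runs
Ω = [2.0 0.0; 0.0 4.0]
m = ToyThunk(() -> (calls[] += 1; -Ω'))

r_zero = toy_mul(m, ToyZero())              # closure is skipped
calls_after_zero = calls[]
r_dense = toy_mul(m, [1.0 0.0; 0.0 1.0])    # closure runs once
calls_after_dense = calls[]
```

The counter stays at zero after the `ToyZero` call and only increments once a dense differential arrives, which is the saving being described.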
Ohhh I see. `Zero` has higher precedence than `Thunk` in your monad-y dispatch list, meaning that if the stuff on the RHS of the `Thunk` is `Zero` then the `Thunk` is never materialised and we avoid the computation entirely. Have I understood this correctly? (I'm still getting to grips with this...)
src/rules.jl
Outdated
function accumulate!(Δ, rule::AbstractRule, args...)
    return materialize!(Δ, broadcastable(add(cast(Δ), rule(args...))))
end
accumulate!(Δ::Real, rule::AbstractRule, args...) = accumulate(Δ, rule, args...)
Seems reasonable. Could probably expand this to `Δ::Number`, even. Could we add a space between these two definitions? 🙂
I guess one place where this might yield unintuitive behavior is if it causes users to think that passing in a non-`materialize!`-able `Δ` always "just" works, e.g. if you call `accumulate!(::SArray, ...)` a user might expect it to fall back to `accumulate` instead of hitting the `setindex!` error it would currently hit when `materialize!(::SArray, ...)` is called. It's not immediately clear how to implement something that ensures that. However, I think we can justify/explain this fallback by saying that it only exists so that callers can handle things generically without checking for the scalar case. That way, it makes sense that numbers are special here, and avoids the impression that a similar fallback exists for immutable containers in general.
Downstream packages are allowed to special-case these methods as well, which makes it even more okay for us to do so. I'm writing docstrings for these methods today which will hopefully make things clearer.
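To make the "callers can handle things generically" argument concrete, here is a rough sketch. The names `toy_accumulate`, `toy_accumulate!` and `total` are hypothetical and only mimic the shape of the methods under review; they take a plain increment instead of a rule, and are not the ChainRules API.

```julia
# Hypothetical names for illustration; not the ChainRules API.
# Non-mutating accumulation: always returns a fresh value.
toy_accumulate(Δ, δ) = Δ .+ δ

# Mutating accumulation: updates Δ in place when it is a mutable array.
function toy_accumulate!(Δ::AbstractArray, δ)
    Δ .+= δ
    return Δ
end

# Scalars are immutable, so fall back to the non-mutating version.
toy_accumulate!(Δ::Number, δ) = toy_accumulate(Δ, δ)

# A generic caller that never checks for the scalar case, as long as it
# always uses the returned value.
total(Δ, δs) = foldl(toy_accumulate!, δs; init=Δ)

buf = zeros(2)
out = total(buf, [[1.0, 2.0], [3.0, 4.0]])   # accumulates into buf in place
s = total(0.0, [1.5, 2.5])                   # scalar path via the fallback
```

The array call reuses `buf`, while the scalar call works through the `Number` fallback. An immutable container such as an `SArray` would still hit the in-place method in this sketch, which matches the caveat above.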
> Could we add a space between these two definitions?
ofc :)
> Could probably expand this to `Δ::Number`
Sounds reasonable.
I'm on board with the rest of what you've suggested as well
This is awesome, thanks so much! Made some comments but everything here generally LGTM, agree with moving the relevant functionality to FDM.jl.
Yup, a forward-mode equivalent will be necessary at some point, though it could probably eventually be merged with the current
There are quite a few permutations in terms of features to test; for any given rule, we could presumably test both modes vs. all relevant differential types vs. all relevant rule types vs. a sample of possible input shapes vs. different levels of materialization, etc. In the long run, we'd want to cover as much of that space as possible in shared test harnesses, and leave the rest to ad hoc per-rule tests. As time goes on and more per-rule ad hoc tests are added, we can refactor whenever we find similar tests/common functionality to pull out into the shared harnesses.
I think these could be considered some of those ad hoc per-rule tests; seems like each test file could have its own
Anyway, this PR is already a huge improvement, thanks again 🙂
Are you happy for this particular PR to be merged before this happens?
Sounds reasonable to me.
Latest push requires this FDM.jl PR to be merged and a new version tagged before tests stand a chance of passing.
This can now use FDM 0.4.0.
Co-Authored-By: willtebbutt <wt0881@my.bristol.ac.uk>
…nRules.jl into wct/fdm-testing
@jrevels anything else that you want done before merging?
Looks good to me. Should definitely be squashed on merge.
Haha yes, for sure.
LGTM too 👍 thanks again! Gave you commit bit so feel free to merge after rebase (sorry for the conflict, I can add the note about
I see no merge conflict...
Huh, I guess GitHub served me an outdated version of the page or something? weird...
This is some initial work to set up systematic testing for `frule`s and `rrule`s. This initial work covers:
- `accumulate!` very slightly to accommodate scalars. @jrevels I could use your input here, this might be totally inconsistent with what you have in mind.
- `src` and `test` files.
- `frule`s and `rrule`s in `linalg` are now covered as a demonstration of the functionality.

There are a few things that need to be addressed before I go about adding extra finite-differencing test-coverage (in a later PR) / before this could be merged:
- `test/linalg.jl` to cover the entire codebase, what still wouldn't be covered? The `test_adjoint!` function was already here when I started, so I've kept that. Does it need extending? Do we need a forwards-mode equivalent?
- `runtests.jl` that I've wrapped in a `Misc Tests` testset. These would ideally go somewhere else. I'm not entirely sure where the optimal location is though and could do with some advice.
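On the forwards-mode question, a finite-difference check for a forward rule might look roughly like the sketch below. `inv_pushforward` is a hypothetical plain function written for illustration (the actual `frule` signature may differ), and the step size and inputs are arbitrary.

```julia
using LinearAlgebra

# Hypothetical forward-mode rule for inv, written as a plain function;
# the real frule signature in ChainRules may differ.
# For Ω = inv(x), a perturbation dx of x gives dΩ = -Ω * dx * Ω.
function inv_pushforward(x, dx)
    Ω = inv(x)
    return Ω, -Ω * dx * Ω
end

x  = [4.0 1.0; 2.0 3.0]     # well-conditioned test point
dx = [1.0 0.5; -0.3 2.0]    # arbitrary tangent direction
h  = 1e-6                   # finite-difference step (arbitrary choice)

Ω, dΩ = inv_pushforward(x, dx)

# Central-difference estimate of the same directional derivative.
dΩ_fd = (inv(x + h * dx) - inv(x - h * dx)) / (2h)
max_err = maximum(abs, dΩ - dΩ_fd)
```

Here the forward rule is compared directly against the finite-difference directional derivative, so no cotangent seed is needed, unlike the reverse-mode case.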