Add Mooncake extension for ArrayPartition cotangents#575
Conversation
Once both of these land: - SciML#1422 (df_iip/df_oop ArrayPartition cotangent unwrap) - SciML/RecursiveArrayTools.jl#575 (Mooncake increment_and_get_rdata! for ArrayPartition) this tutorial works end-to-end under `OPT.AutoMooncake(; config = Mooncake.Config(; friendly_tangents = true))`. Verified locally with both PRs applied (Lux + StatefulLuxLayer + SecondOrderODEProblem Adam loop trains). Drops the explanatory `!!! note` and adds the `import Mooncake`. Co-Authored-By: Chris Rackauckas <accounts@chrisrackauckas.com>
When an upstream ChainRules-based adjoint (e.g. SciMLSensitivity's
`_concrete_solve_adjoint` for an ODE whose state is an `ArrayPartition`,
such as the one produced by `SecondOrderODEProblem`) returns a
parameter / state cotangent as an `ArrayPartition`, Mooncake's
`@from_chainrules` / `@from_rrule` accumulator looks for an
`increment_and_get_rdata!` method matching
(FData{NamedTuple{(:x,), Tuple{Tuple{Vector, …}}}}, NoRData, ArrayPartition)
There isn't a default method registered for this combination, so the
call falls through to the generic error path:
ArgumentError: The fdata type Mooncake.FData{@NamedTuple{x::Tuple{Vector{Float32}, Vector{Float32}}}},
rdata type Mooncake.NoRData, and tangent type
RecursiveArrayTools.ArrayPartition{Float32, Tuple{Vector{Float32}, Vector{Float32}}}
combination is not supported with @from_chainrules or @from_rrule.
Add the missing dispatch via a new `RecursiveArrayToolsMooncakeExt`
weak-dep extension. An `ArrayPartition`'s only field is `x::Tuple` of
inner arrays, so the FData layout is `FData{@NamedTuple{x::Tuple{...}}}`
and the inner tuple positions line up with `t.x`. Walk the tuple
element-by-element and forward each leaf to the existing
`increment_and_get_rdata!` for the leaf's array type, which does the
actual in-place accumulation. Returns `Mooncake.NoRData()` to match the
no-rdata convention used by the equivalent ComponentArrays dispatch
(SciML/ComponentArrays.jl#350 / SciML#351).
Tested end-to-end against the SciMLSensitivity neural-ODE
`SecondOrderODEProblem` tutorial (via SciML/SciMLSensitivity.jl#1422,
which adds the matching `df_iip`/`df_oop` cotangent unwrap on the
SciMLSensitivity side): with both PRs applied, the Lux + `ArrayPartition`
training loop now runs under
`OPT.AutoMooncake(; config = Mooncake.Config(; friendly_tangents = true))`.
Co-Authored-By: Chris Rackauckas <accounts@chrisrackauckas.com>
The cherry-pick from the v3-backport branch left the Mooncake entry
under [deps] rather than [weakdeps] because the surrounding lines in
master's Project.toml differ from v3-backport's. Move Mooncake to the
weakdeps block so the extension loads via the normal weakdep trigger
instead of as a hard dependency, and add Mooncake to [extras] /
[targets.test] so the extension is actually exercised in CI.
Add test/mooncake.jl with a direct unit test for the new
`Mooncake.increment_and_get_rdata!(::FData{@NamedTuple{x::T}},
::NoRData, ::ArrayPartition{P, T})` dispatch: constructs a matching
FData and ArrayPartition, calls `increment_and_get_rdata!`, and checks
that (a) the in-place accumulation on each inner-array leaf is correct
and (b) the method returns `NoRData()`. Also exercises a three-way
Float32 ArrayPartition to cover a different eltype and arity. Register
the testset in runtests.jl under the Core group.
Co-Authored-By: Chris Rackauckas <accounts@chrisrackauckas.com>
Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
8ffc584 to
c547a94
Compare
|
Rebased onto current What changed The original branch was cut from Dropped all of the v3-backport ancestry and kept only the Mooncake commit (
Local test run ( Aqua (including |
Add Mooncake to [extras] and [targets.test] so the new
`RecursiveArrayToolsMooncakeExt` is actually loaded and exercised in
the test suite, and add test/mooncake.jl as a direct unit test for the
new `Mooncake.increment_and_get_rdata!(::FData{@NamedTuple{x::T}},
::NoRData, ::ArrayPartition{P, T})` dispatch: constructs a matching
FData and ArrayPartition, calls `increment_and_get_rdata!`, and checks
that (a) the in-place accumulation on each inner-array leaf is correct
and (b) the method returns `NoRData()`. Also exercises a three-way
Float32 ArrayPartition to cover a different eltype and arity. Register
the testset in runtests.jl under the Core group.
Backport of SciML#575 to v3-backport.
Co-Authored-By: Chris Rackauckas <accounts@chrisrackauckas.com>
Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
Summary
When an upstream ChainRules-based adjoint (e.g. SciMLSensitivity's `_concrete_solve_adjoint` for an ODE whose state is an `ArrayPartition`, such as the one produced by `SecondOrderODEProblem`) returns a parameter / state cotangent as an `ArrayPartition`, Mooncake's `@from_chainrules`/`@from_rrule` accumulator looks for an `increment_and_get_rdata!` method matching
```julia
(FData{NamedTuple{(:x,), Tuple{Tuple{Vector, …}}}}, NoRData, ArrayPartition)
```
There isn't a default method registered for this combination, so the call falls through to the generic error path:
```
ArgumentError: The fdata type Mooncake.FData{@NamedTuple{x::Tuple{Vector{Float32}, Vector{Float32}}}},
rdata type Mooncake.NoRData, and tangent type
RecursiveArrayTools.ArrayPartition{Float32, Tuple{Vector{Float32}, Vector{Float32}}}
combination is not supported with @from_chainrules or @from_rrule.
```
Fix
Add the missing dispatch via a new `RecursiveArrayToolsMooncakeExt` weak-dep extension. An `ArrayPartition`'s only field is `x::Tuple` of inner arrays, so the FData layout is `FData{@NamedTuple{x::Tuple{...}}}` and the inner tuple positions line up with `t.x`. Walk the tuple element-by-element and forward each leaf to the existing `increment_and_get_rdata!` for the leaf's array type, which does the actual in-place accumulation. Returns `Mooncake.NoRData()` to match the no-rdata convention used by the equivalent ComponentArrays dispatch (SciML/ComponentArrays.jl#350 / #351).
This is a tiny additive extension — it only adds a method for a previously-unsupported `(FData, NoRData, tangent)` combination, so existing Mooncake users of `RecursiveArrayTools` are unaffected.
Test plan
Related
🤖 Generated with Claude Code