
Add Mooncake extension for ArrayPartition cotangents #575

Merged
ChrisRackauckas merged 2 commits into SciML:master from
ChrisRackauckas-Claude:mooncake-arraypartition-rdata
Apr 12, 2026
@ChrisRackauckas-Claude
Contributor

Summary

When an upstream ChainRules-based adjoint (e.g. SciMLSensitivity's `_concrete_solve_adjoint` for an ODE whose state is an `ArrayPartition`, such as the one produced by `SecondOrderODEProblem`) returns a parameter / state cotangent as an `ArrayPartition`, Mooncake's `@from_chainrules`/`@from_rrule` accumulator looks for an `increment_and_get_rdata!` method matching

```julia
(FData{NamedTuple{(:x,), Tuple{Tuple{Vector, …}}}}, NoRData, ArrayPartition)
```

There isn't a default method registered for this combination, so the call falls through to the generic error path:

```
ArgumentError: The fdata type Mooncake.FData{@NamedTuple{x::Tuple{Vector{Float32}, Vector{Float32}}}},
rdata type Mooncake.NoRData, and tangent type
RecursiveArrayTools.ArrayPartition{Float32, Tuple{Vector{Float32}, Vector{Float32}}}
combination is not supported with @from_chainrules or @from_rrule.
```

Fix

Add the missing dispatch via a new `RecursiveArrayToolsMooncakeExt` weak-dep extension. An `ArrayPartition`'s only field is `x`, a `Tuple` of inner arrays, so the FData layout is `FData{@NamedTuple{x::Tuple{...}}}` and the inner tuple positions line up with `t.x`. The new method walks the tuple element by element and forwards each leaf to the existing `increment_and_get_rdata!` method for the leaf's array type, which does the actual in-place accumulation. It returns `Mooncake.NoRData()` to match the no-rdata convention used by the equivalent ComponentArrays dispatch (SciML/ComponentArrays.jl#350 / #351).

This is a tiny additive extension — it only adds a method for a previously-unsupported `(FData, NoRData, tangent)` combination, so existing Mooncake users of `RecursiveArrayTools` are unaffected.
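The dispatch pattern described above can be sketched in plain Julia. This is a minimal illustration, not the extension's source: `FData`, `NoRData`, and `ArrayPartition` below are simplified stand-in types defined locally (assumptions, so the snippet runs without Mooncake or RecursiveArrayTools installed), and the generic fallback only imitates the error path quoted in the Summary.

```julia
# Stand-in types (hypothetical): simplified mirrors of Mooncake.FData,
# Mooncake.NoRData and RecursiveArrayTools.ArrayPartition.
struct FData{T <: NamedTuple}
    data::T
end
struct NoRData end
struct ArrayPartition{P, T <: Tuple}
    x::T
end
ArrayPartition(xs::AbstractArray{P}...) where {P} = ArrayPartition{P, typeof(xs)}(xs)

# Generic fallback, standing in for the "combination is not supported" error.
function increment_and_get_rdata!(f, r, t)
    throw(ArgumentError(
        "unsupported combination: $(typeof(f)), $(typeof(r)), $(typeof(t))"))
end

# Leaf method: accumulate the cotangent into the array in place; no rdata.
function increment_and_get_rdata!(
        f::Array{P}, ::NoRData, t::Array{P}) where {P <: AbstractFloat}
    f .+= t
    return NoRData()
end

# The added dispatch: the same tuple type T appears in the fdata and in the
# tangent, so the leaves pair up position by position.
function increment_and_get_rdata!(
        f::FData{@NamedTuple{x::T}}, r::NoRData, t::ArrayPartition{P, T}
    ) where {P, T <: Tuple}
    foreach((fi, ti) -> increment_and_get_rdata!(fi, r, ti), f.data.x, t.x)
    return NoRData()
end

# Usage: the two inner arrays of `f` are updated in place.
f = FData((x = ([1.0, 2.0], [3.0]),))
t = ArrayPartition([0.5, 0.5], [1.0])
r = increment_and_get_rdata!(f, NoRData(), t)
```

Because the method signature ties `T` in the fdata to `T` in the tangent, a partition whose inner tuple type differs from the fdata's does not match and still reaches the fallback error rather than accumulating into mismatched buffers.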

Test plan

  • End-to-end against the SciMLSensitivity `SecondOrderODEProblem` tutorial (via SciML/SciMLSensitivity.jl#1422, "concrete_solve: unwrap ArrayPartition cotangents in df_iip/df_oop", which adds the matching `df_iip`/`df_oop` cotangent unwrap on the SciMLSensitivity side): with both PRs applied, the Lux + `ArrayPartition` training loop now runs under `OPT.AutoMooncake(; config = Mooncake.Config(; friendly_tangents = true))`
  • Happy to add a unit test if you'd like — let me know what shape (mirror of an existing extension's tests, against a synthetic FData/ArrayPartition pair)
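One possible shape for such a unit test, as a sketch only: it assumes Mooncake and a RecursiveArrayTools build that includes this extension are both installed, that `Mooncake.FData` can be constructed directly from a `NamedTuple`, and that Mooncake's leaf method for plain arrays accumulates in place. The values are made up for illustration.

```julia
using Test
using Mooncake
using RecursiveArrayTools: ArrayPartition

@testset "increment_and_get_rdata! for ArrayPartition" begin
    # fdata buffers that the cotangent should be accumulated into
    a, b = [1.0, 2.0], [3.0]
    f = Mooncake.FData((x = (a, b),))

    # cotangent in the form a ChainRules-based adjoint would return it
    t = ArrayPartition([0.5, 0.5], [1.0])

    r = Mooncake.increment_and_get_rdata!(f, Mooncake.NoRData(), t)

    @test r isa Mooncake.NoRData   # no-rdata convention
    @test a == [1.5, 2.5]          # first leaf accumulated in place
    @test b == [4.0]               # second leaf accumulated in place
end
```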

Related

🤖 Generated with Claude Code

ChrisRackauckas-Claude pushed a commit to ChrisRackauckas-Claude/SciMLSensitivity.jl that referenced this pull request Apr 11, 2026
Once both of these land:
  - SciML#1422 (df_iip/df_oop ArrayPartition cotangent unwrap)
  - SciML/RecursiveArrayTools.jl#575 (Mooncake increment_and_get_rdata! for ArrayPartition)
this tutorial works end-to-end under
`OPT.AutoMooncake(; config = Mooncake.Config(; friendly_tangents = true))`.
Verified locally with both PRs applied (Lux + StatefulLuxLayer +
SecondOrderODEProblem Adam loop trains).

Drops the explanatory `!!! note` and adds the `import Mooncake`.

Co-Authored-By: Chris Rackauckas <accounts@chrisrackauckas.com>
ChrisRackauckas and others added 2 commits April 11, 2026 19:04
When an upstream ChainRules-based adjoint (e.g. SciMLSensitivity's
`_concrete_solve_adjoint` for an ODE whose state is an `ArrayPartition`,
such as the one produced by `SecondOrderODEProblem`) returns a
parameter / state cotangent as an `ArrayPartition`, Mooncake's
`@from_chainrules` / `@from_rrule` accumulator looks for an
`increment_and_get_rdata!` method matching

    (FData{NamedTuple{(:x,), Tuple{Tuple{Vector, …}}}}, NoRData, ArrayPartition)

There isn't a default method registered for this combination, so the
call falls through to the generic error path:

    ArgumentError: The fdata type Mooncake.FData{@NamedTuple{x::Tuple{Vector{Float32}, Vector{Float32}}}},
    rdata type Mooncake.NoRData, and tangent type
    RecursiveArrayTools.ArrayPartition{Float32, Tuple{Vector{Float32}, Vector{Float32}}}
    combination is not supported with @from_chainrules or @from_rrule.

Add the missing dispatch via a new `RecursiveArrayToolsMooncakeExt`
weak-dep extension. An `ArrayPartition`'s only field is `x::Tuple` of
inner arrays, so the FData layout is `FData{@NamedTuple{x::Tuple{...}}}`
and the inner tuple positions line up with `t.x`. Walk the tuple
element-by-element and forward each leaf to the existing
`increment_and_get_rdata!` for the leaf's array type, which does the
actual in-place accumulation. Returns `Mooncake.NoRData()` to match the
no-rdata convention used by the equivalent ComponentArrays dispatch
(SciML/ComponentArrays.jl#350 / SciML#351).

Tested end-to-end against the SciMLSensitivity neural-ODE
`SecondOrderODEProblem` tutorial (via SciML/SciMLSensitivity.jl#1422,
which adds the matching `df_iip`/`df_oop` cotangent unwrap on the
SciMLSensitivity side): with both PRs applied, the Lux + `ArrayPartition`
training loop now runs under
`OPT.AutoMooncake(; config = Mooncake.Config(; friendly_tangents = true))`.

Co-Authored-By: Chris Rackauckas <accounts@chrisrackauckas.com>
The cherry-pick from the v3-backport branch left the Mooncake entry
under [deps] rather than [weakdeps] because the surrounding lines in
master's Project.toml differ from v3-backport's. Move Mooncake to the
weakdeps block so the extension loads via the normal weakdep trigger
instead of as a hard dependency, and add Mooncake to [extras] /
[targets.test] so the extension is actually exercised in CI.

Add test/mooncake.jl with a direct unit test for the new
`Mooncake.increment_and_get_rdata!(::FData{@NamedTuple{x::T}},
::NoRData, ::ArrayPartition{P, T})` dispatch: constructs a matching
FData and ArrayPartition, calls `increment_and_get_rdata!`, and checks
that (a) the in-place accumulation on each inner-array leaf is correct
and (b) the method returns `NoRData()`. Also exercises a three-way
Float32 ArrayPartition to cover a different eltype and arity. Register
the testset in runtests.jl under the Core group.

Co-Authored-By: Chris Rackauckas <accounts@chrisrackauckas.com>
Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
ChrisRackauckas-Claude force-pushed the mooncake-arraypartition-rdata branch from 8ffc584 to c547a94 April 11, 2026 23:17
@ChrisRackauckas-Claude
Contributor Author

Rebased onto current master (was CONFLICTING / DIRTY, now CLEAN / MERGEABLE).

What changed

The original branch was cut from v3-backport, not master, so it carried 15 unrelated ancestor commits (FastBroadcast Serial/Polyester v3-line fixes and the v3.52–v3.54 version bumps — all already merged on v3-backport via #556, #560, #564/#566, #568, #569). None of those belong on master; master already has its own equivalent Polyester extension via #565/#567/#571 and is on v4.0.1. Nothing else was supposed to be merged — just the Mooncake commit.

Dropped all of the v3-backport ancestry and kept only the Mooncake commit (8ffc584). While rebasing I found two things worth fixing in a follow-up commit:

  1. The Mooncake `Project.toml` patch landed in `[deps]` instead of `[weakdeps]`: master's `[deps]` block is shorter than v3-backport's and master already has CUDA in `[weakdeps]`, so the entry's alphabetical position shifted and the cherry-pick's context lines matched inside `[deps]`. Moved it to `[weakdeps]` so the extension stays weakly loaded.
  2. Mooncake wasn't in `[extras]` / `[targets.test]`, so the new extension was never exercised in CI. Added it, and added `test/mooncake.jl` as a direct unit test for `Mooncake.increment_and_get_rdata!(::FData{@NamedTuple{x::T}}, ::NoRData, ::ArrayPartition{P, T})`: it constructs a matching FData and `ArrayPartition`, calls the method, and asserts that the in-place accumulation on each inner-array leaf is correct and that the return value is `NoRData()`. It also covers a three-way Float32 partition to exercise a different eltype and arity. Registered under the Core test group in `runtests.jl`.
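The layout that the two fixes aim for can be checked mechanically. The snippet below is a rough sketch using Julia's `TOML` standard library: the embedded `Project.toml` is a trimmed, hypothetical fragment (the UUID is a placeholder, not Mooncake's real one, and the single-trigger `[extensions]` entry is an assumption), and `weakdep_is_tested` is an illustrative helper name, not something in the repository.

```julia
using TOML

# Hypothetical, trimmed-down Project.toml. The UUID is a placeholder and all
# unrelated entries are omitted.
project = TOML.parse("""
[weakdeps]
Mooncake = "00000000-0000-0000-0000-000000000000"

[extensions]
RecursiveArrayToolsMooncakeExt = "Mooncake"

[extras]
Mooncake = "00000000-0000-0000-0000-000000000000"

[targets]
test = ["Mooncake"]
""")

# Illustrative helper: true when `pkg` is only a weak dependency (not a hard
# one) and is still made available to the test target.
function weakdep_is_tested(project::AbstractDict, pkg::AbstractString)
    in_deps = haskey(get(project, "deps", Dict()), pkg)
    in_weakdeps = haskey(get(project, "weakdeps", Dict()), pkg)
    in_extras = haskey(get(project, "extras", Dict()), pkg)
    in_test = pkg in get(get(project, "targets", Dict()), "test", String[])
    return !in_deps && in_weakdeps && in_extras && in_test
end

ok = weakdep_is_tested(project, "Mooncake")
```

With the entry misplaced under `[deps]`, as after the original cherry-pick, the same check would return `false`.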

Local test run (Pkg.test on Julia 1.10.11):

Test Summary:  | Pass  Total  Time
Mooncake Tests |    8      8  1.0s
...
Testing RecursiveArrayTools tests passed

Aqua (including test_stale_deps) passes, full Core group green.

ChrisRackauckas merged commit a22d73a into SciML:master Apr 12, 2026
31 of 37 checks passed
ChrisRackauckas-Claude pushed a commit to ChrisRackauckas-Claude/RecursiveArrayTools.jl that referenced this pull request Apr 12, 2026
Add Mooncake to [extras] and [targets.test] so the new
`RecursiveArrayToolsMooncakeExt` is actually loaded and exercised in
the test suite, and add test/mooncake.jl as a direct unit test for the
new `Mooncake.increment_and_get_rdata!(::FData{@NamedTuple{x::T}},
::NoRData, ::ArrayPartition{P, T})` dispatch: constructs a matching
FData and ArrayPartition, calls `increment_and_get_rdata!`, and checks
that (a) the in-place accumulation on each inner-array leaf is correct
and (b) the method returns `NoRData()`. Also exercises a three-way
Float32 ArrayPartition to cover a different eltype and arity. Register
the testset in runtests.jl under the Core group.

Backport of SciML#575 to v3-backport.

Co-Authored-By: Chris Rackauckas <accounts@chrisrackauckas.com>
Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>