Skip to content

Conversation

@gdalle
Copy link
Member

@gdalle gdalle commented Nov 15, 2025

Updated experiments with Reactant-accelerated derivatives.

@wsmoses is this still the right paradigm in your opinion? I may not implement every operator right away but I thought starting with a gradient made sense

Related:

Warning

Re-toggle tests once this is mergeable

@codecov
Copy link

codecov bot commented Nov 16, 2025

Codecov Report

❌ Patch coverage is 94.23077% with 3 lines in your changes missing coverage. Please review.
✅ Project coverage is 3.18%. Comparing base (bbc39fd) to head (c7e7598).

Files with missing lines Patch % Lines
...e/ext/DifferentiationInterfaceReactantExt/utils.jl 50.00% 3 Missing ⚠️

❗ There is a different number of reports uploaded between BASE (bbc39fd) and HEAD (c7e7598). Click for more details.

HEAD has 59 uploads less than BASE
Flag BASE (bbc39fd) HEAD (c7e7598)
DIT 12 0
DI 48 1
Additional details and impacted files
@@            Coverage Diff             @@
##             main    #918       +/-   ##
==========================================
- Coverage   98.10%   3.18%   -94.92%     
==========================================
  Files         133     101       -32     
  Lines        7971    5553     -2418     
==========================================
- Hits         7820     177     -7643     
- Misses        151    5376     +5225     
Flag Coverage Δ
DI 3.18% <94.23%> (-95.65%) ⬇️
DIT ?

Flags with carried forward coverage won't be shown. Click here to find out more.

☔ View full report in Codecov by Sentry.
📢 Have feedback on the report? Share it here.

🚀 New features to boost your workflow:
  • ❄️ Test Analytics: Detect flaky tests, report on failures, and find test suite problems.

@gdalle
Copy link
Member Author

gdalle commented Nov 16, 2025

@wsmoses does this look better to you now?
I'm not sure what we should do in terms of storage versus allocations. We can store xr (and even contextsr) during preparation and then copy to them at execution time instead of generating a new RArray, but that would require a copying method to be defined (which doesn't apply to all non-array objects).

DI.check_prep(f, prep, rebackend, x)
backend = rebackend.mode
(; xr, compiled_gradient) = prep
copyto!(xr, x)
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We should only do this if x is not a reactantarray

) where {F, C}
_sig = DI.signature(f, rebackend, x; strict)
backend = rebackend.mode
xr = to_reac(x)
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We shouldn't save anything as a prep argument if a reactant array, I would keep this as if reactant array then xr is nothing otherwise to_rarray(x)

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Sounds reasonable

DI.check_prep(f, prep, rebackend, x)
backend = rebackend.mode
(; xr, compiled_value_and_gradient) = prep
copyto!(xr, x)
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Same comment here

DI.check_prep(f, prep, rebackend, x)
backend = rebackend.mode
(; xr, gr, compiled_gradient!) = prep
copyto!(xr, x)
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Etc

@test check_inplace(backend)

test_differentiation(
backend, DifferentiationInterfaceTest.default_scenarios(;
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can you add a test that the prep contains no data except the compiled fn if compiled for a reactant array

@wsmoses
Copy link

wsmoses commented Nov 16, 2025

Modulo some comments being addressed above this looks reasonable to me.

However note that there may be a potential mismatch in expectations behind what prepare gradient defines and what reactant compile defines.

See https://enzymead.github.io/Reactant.jl/dev/tutorials/partial-evaluation

Currently any data inside a constant or cache will be baked into the compiled function and will not be re read in later evaluation.

Enzyme.jl in particular does not have any such constraint (as it will always re run with live data as prep is nothing).

Something like reversediff compiled probably does bake in the assumption from compilation.

So this is a question of what is the semantics of prep.

If the non differentiated data is the same between prep and evaluation there is no difference between the two

@gdalle
Copy link
Member Author

gdalle commented Nov 16, 2025

So this is a question of what is the semantics of prep. If the non differentiated data is the same between prep and evaluation there is no difference between the two

The semantics of prep: differentiated and non-differentiated data are free to change between preparation and execution, as long as they keep the same types and sizes. See here for details.

Currently any data inside a constant or cache will be baked into the compiled function and will not be re read in later evaluation.

I thought converting the contexts into reactant arrays inside contextr would allow them to be traced? If that's true, then they won't be partially evaluated into the compiled function, which means the semantics of prep are respected?

@wsmoses
Copy link

wsmoses commented Nov 16, 2025

No your to_reac function does not achieve that. If you have a context of Tuple{Int, Int} this will not be converted unless you do to_rarray(context; track_numbers=Number).

However, concurrently, most of the time you actually want to partially evaluate integers in (e.g. for sizes/bounds/etc).

I think the more reasonable setup here is to not to_rarray the context, and instead add a similar warning to the one from reversediff:

These rules hold for the majority of backends, but there are some exceptions. The most important exception is [ReverseDiff.jl](https://github.com/JuliaDiff/ReverseDiff.jl) and its taping mechanism, which is sensitive to control flow inside the function.

@gdalle
Copy link
Member Author

gdalle commented Nov 16, 2025

Being forced to keep the same context values makes preparation pretty much useless. I think I'd rather have us trace everything in the context, even if it means a slowdown in some cases. Will it lead to actual errors?

@gdalle
Copy link
Member Author

gdalle commented Nov 16, 2025

Or alternately we could restrict the kind of contexts we allow here

@wsmoses
Copy link

wsmoses commented Nov 16, 2025

Yes unnecessarily tracing objects can lead to errors that would fail to compile otherwise

@wsmoses
Copy link

wsmoses commented Nov 16, 2025

and it doesn't make it useless, it just means that the user is responsible for performing the to_rarray themselves for things that may change

@gdalle
Copy link
Member Author

gdalle commented Nov 17, 2025

Yes unnecessarily tracing objects can lead to errors that would fail to compile otherwise

Can you give an example so that I wrap my mind around this?

and it doesn't make it useless, it just means that the user is responsible for performing the to_rarray themselves for things that may change

That would be a Reactant-specific workaround, which doesn't fit other DI-supported backends. The whole point of DI is to enable easy backend switch, so I'd love to find a solution that doesn't expect users to wrap some of the arguments in Reactant-specific types when they want to switch to AutoReactant.

@gdalle
Copy link
Member Author

gdalle commented Nov 17, 2025

Besides, the problem is not specific to contexts: x itself can contain integers we don't necessarily want to track. And a preparation that can only be reused if nothing at all changes will only ever be used once anyway, so it is pointless.

Maybe DI could expose a function like trace(a, backend) or translate(a, backend) which takes care of populating values in the correct way for differentiation / Reactant compilation? I already use such a function internally anyway, especially in ForwardDiff and other operator overloading-based backends.

@wsmoses
Copy link

wsmoses commented Nov 17, 2025

julia> using Reactant; x = Reactant.to_rarray(ones(10)); s = Reactant.ConcreteRNumber(2); e = Reactant.ConcreteRNumber(5);

julia> f(x, s, e) = x[s:e]
f (generic function with 1 method)

julia> @jit f(x, s, e)
ERROR: TypeError: non-boolean (Reactant.TracedRNumber{Bool}) used in boolean context
Stacktrace:
  [1] getindex_linear
    @ ~/git/Reactant.jl/src/Indexing.jl:340 [inlined]
  [2] (::Nothing)(none::typeof(Reactant.TracedIndexing.getindex_linear), none::Reactant.TracedRArray{Float64, 1}, none::Reactant.TracedUnitRange{Reactant.TracedRNumber{Int64}})
    @ Reactant ./<missing>:0
  [3] getindex_linear
    @ ~/git/Reactant.jl/src/Indexing.jl:339 [inlined]
  [4] call_with_reactant(::Reactant.MustThrowError, ::typeof(Reactant.TracedIndexing.getindex_linear), ::Reactant.TracedRArray{…}, ::Reactant.TracedUnitRange{…})
    @ Reactant ~/git/Reactant.jl/src/utils.jl:0
  [5] getindex
    @ ~/git/Reactant.jl/src/Indexing.jl:75 [inlined]
  [6] (::Nothing)(none::typeof(getindex), none::Reactant.TracedRArray{Float64, 1}, none::Reactant.TracedUnitRange{Reactant.TracedRNumber{Int64}})
    @ Reactant ./<missing>:0
  [7] getindex
    @ ~/git/Reactant.jl/src/Indexing.jl:75 [inlined]
  [8] call_with_reactant(::Reactant.MustThrowError, ::typeof(getindex), ::Reactant.TracedRArray{Float64, 1}, ::Reactant.TracedUnitRange{Reactant.TracedRNumber{Int64}})
    @ Reactant ~/git/Reactant.jl/src/utils.jl:0
  [9] f
    @ ./REPL[5]:1 [inlined]
 [10] (::Nothing)(none::typeof(f), none::Reactant.TracedRArray{Float64, 1}, none::Reactant.TracedRNumber{Int64}, none::Reactant.TracedRNumber{Int64})
    @ Reactant ./<missing>:0
 [11] TracedUnitRange
    @ ~/git/Reactant.jl/src/Types.jl:108 [inlined]
 [12] TracedUnitRange
    @ ~/git/Reactant.jl/src/TracedRange.jl:124 [inlined]
 [13] Colon
    @ ~/git/Reactant.jl/src/TracedRange.jl:181 [inlined]
 [14] f
    @ ./REPL[5]:1 [inlined]
 [15] call_with_reactant(::typeof(f), ::Reactant.TracedRArray{Float64, 1}, ::Reactant.TracedRNumber{Int64}, ::Reactant.TracedRNumber{Int64})
    @ Reactant ~/git/Reactant.jl/src/utils.jl:0
 [16] make_mlir_fn(f::typeof(f), args::Tuple{…}, kwargs::@NamedTuple{}, name::String, concretein::Bool; toscalar::Bool, return_dialect::Symbol, args_in_result::Symbol, construct_function_without_args::Bool, do_transpose::Bool, input_shardings::Nothing, output_shardings::Nothing, runtime::Val{…}, verify_arg_names::Nothing, argprefix::Symbol, resprefix::Symbol, resargprefix::Symbol, num_replicas::Int64, optimize_then_pad::Bool)
    @ Reactant.TracedUtils ~/git/Reactant.jl/src/TracedUtils.jl:345
 [17] make_mlir_fn
    @ ~/git/Reactant.jl/src/TracedUtils.jl:275 [inlined]
 [18] compile_mlir!(mod::Reactant.MLIR.IR.Module, f::typeof(f), args::Tuple{…}, compile_options::CompileOptions, callcache::Dict{…}, sdycache::Dict{…}, sdygroupidcache::Tuple{…}; fn_kwargs::@NamedTuple{}, backend::String, runtime::Val{…}, legalize_stablehlo_to_mhlo::Bool, client::Reactant.XLA.PJRT.Client, kwargs::@Kwargs{})
    @ Reactant.Compiler ~/git/Reactant.jl/src/Compiler.jl:1608
 [19] compile_mlir!
    @ ~/git/Reactant.jl/src/Compiler.jl:1570 [inlined]
 [20] compile_xla(f::Function, args::Tuple{…}; before_xla_optimizations::Bool, client::Nothing, serializable::Bool, kwargs::@Kwargs{…})
    @ Reactant.Compiler ~/git/Reactant.jl/src/Compiler.jl:3516
 [21] compile_xla
    @ ~/git/Reactant.jl/src/Compiler.jl:3488 [inlined]
 [22] compile(f::Function, args::Tuple{…}; kwargs::@Kwargs{…})
    @ Reactant.Compiler ~/git/Reactant.jl/src/Compiler.jl:3592
 [23] top-level scope
    @ ~/git/Reactant.jl/src/Compiler.jl:2661
Some type information was truncated. Use `show(err)` to see complete types.

julia> @jit f(x, 2, 5)
4-element ConcretePJRTArray{Float64,1}:
 1.0
 1.0
 1.0
 1.0

@wsmoses
Copy link

wsmoses commented Nov 17, 2025

and I don't think this is terribly reactant-specific. The same core issue here equally applies to reversediff compiled [where the context will equally be baked it]. Just because reactant also has a way to circumvent the problem in special cases shouldn't mean it is treated differently here.

@wsmoses
Copy link

wsmoses commented Nov 18, 2025

and I don't think this is terribly reactant-specific. The same core issue here equally applies to reversediff compiled [where the context will equally be baked it]. Just because reactant also has a way to circumvent the problem in special cases shouldn't mean it is treated differently here.

bumping this @gdalle are you okay not to trace the contexts?

@gdalle
Copy link
Member Author

gdalle commented Nov 18, 2025

Not really. Many use cases of DI that I can think of require changing contexts, and Reactant has to be relevant for these cases too. Re-compiling the derivatives for each context changes is very impractical. On the other hand, asking users to pass RArrays instead of their normal arguments might make other backends fail, so the code on which DI runs is no longer fully generic, and I want to avoid that too. It would be like asking ForwardDiff users to pass arrays of Dual numbers.
Besides, that problem is not specific to contexts: what is stopping x itself (the active argument) from containing scalars that users may or may not want to trace?

I don't have a lot of bandwith these days, but I think the right solution might be to expose something like DI.to_reactant, telling users that the function will be called on every argument before Reactant compilation. That way, if they want to enforce specific tracing behavior, they can wrap their argument in a custom type and overload to_reactant, but it doesn't force them to

@wsmoses
Copy link

wsmoses commented Nov 18, 2025

In order to be differentiated the data must be a reactant array. If we assume that DI only officially supports array inputs this is fine

@gdalle
Copy link
Member Author

gdalle commented Nov 18, 2025

In order to be differentiated the data must be a reactant array. If we assume that DI only officially supports array inputs this is fine

"array" as in "RArray only" or as in "any nested struct of RArrays?

@gdalle
Copy link
Member Author

gdalle commented Nov 18, 2025

By the way, this Reactant issue prevented me from testing DI.Cache here, if you happen to have a quick fix lying around. I can also modify the DI EnzymeExt source code if this is expected behavior

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

3 participants