feat: add gradient with AutoReactant #918
base: main
Conversation
DifferentiationInterface/ext/DifferentiationInterfaceReactantExt/onearg.jl
(two outdated review threads, resolved)
Codecov Report

```
@@            Coverage Diff            @@
##             main     #918      +/-  ##
=========================================
- Coverage   98.10%    3.18%   -94.92%
=========================================
  Files         133      101       -32
  Lines        7971     5553     -2418
=========================================
- Hits         7820      177     -7643
- Misses        151     5376     +5225
```
@wsmoses does this look better to you now?
```julia
DI.check_prep(f, prep, rebackend, x)
backend = rebackend.mode
(; xr, compiled_gradient) = prep
copyto!(xr, x)
```
We should only do this if `x` is not a Reactant array.
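A minimal pure-Julia sketch of the suggested guard (the `is_reactant_array` flag stands in for a type check such as `x isa Reactant.ConcreteRArray`, which is an assumption about Reactant's type names, not DI's actual code):

```julia
# Only copy the input into the pre-converted buffer when it is not
# already a Reactant array; otherwise reuse it directly.
function input_buffer!(xr, x; is_reactant_array::Bool=false)
    if is_reactant_array
        return x          # no copy: the input is already usable as-is
    else
        copyto!(xr, x)    # fill the pre-allocated buffer from plain data
        return xr
    end
end

buf = zeros(3)
input_buffer!(buf, [1.0, 2.0, 3.0])  # copies: buf == [1.0, 2.0, 3.0]
```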
```julia
) where {F, C}
_sig = DI.signature(f, rebackend, x; strict)
backend = rebackend.mode
xr = to_reac(x)
```
We shouldn't save anything as a prep argument if `x` is a Reactant array. I would keep this as: if `x` is a Reactant array then `xr` is `nothing`, otherwise `to_rarray(x)`.
Sounds reasonable
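The rule agreed on above can be sketched in plain Julia (self-contained: `convert_like_rarray` is a stand-in for `Reactant.to_rarray`, and the Bool flag stands in for a Reactant-array type check; neither is DI's actual API):

```julia
# Placeholder for Reactant.to_rarray: here just a copy.
convert_like_rarray(x) = copy(x)

# Store nothing in the prep when the input is already a Reactant array,
# otherwise store a converted buffer.
prep_xr(x; is_reactant_array::Bool=false) =
    is_reactant_array ? nothing : convert_like_rarray(x)

prep_xr([1.0, 2.0])                          # returns a converted buffer
prep_xr([1.0, 2.0]; is_reactant_array=true)  # returns nothing
```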
```julia
DI.check_prep(f, prep, rebackend, x)
backend = rebackend.mode
(; xr, compiled_value_and_gradient) = prep
copyto!(xr, x)
```
Same comment here
```julia
DI.check_prep(f, prep, rebackend, x)
backend = rebackend.mode
(; xr, gr, compiled_gradient!) = prep
copyto!(xr, x)
```
Etc
```julia
@test check_inplace(backend)

test_differentiation(
    backend, DifferentiationInterfaceTest.default_scenarios(;
```
Can you add a test that the prep contains no data except the compiled function if compiled for a Reactant array?
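A hypothetical shape for the requested test (struct and field names here are illustrative, not DI's actual internals): when the input is already a Reactant array, every non-function field of the prep should be `nothing`.

```julia
# Illustrative prep: holds an optional input buffer and a compiled function.
struct SketchPrep{F}
    xr::Union{Nothing, Vector{Float64}}
    compiled_gradient::F
end

prep = SketchPrep(nothing, identity)

# Collect every field that is not a function and check it carries no data.
non_fn_fields = [getfield(prep, name) for name in fieldnames(typeof(prep))
                 if !(getfield(prep, name) isa Function)]
@assert all(isnothing, non_fn_fields)  # no input data duplicated in the prep
```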
Modulo some comments being addressed above, this looks reasonable to me. However, note that there may be a potential mismatch in expectations between what prepare gradient defines and what Reactant compilation defines. See https://enzymead.github.io/Reactant.jl/dev/tutorials/partial-evaluation. Currently, any data inside a constant or cache will be baked into the compiled function and will not be re-read in later evaluations. Enzyme.jl in particular does not have any such constraint (as it will always re-run with live data, since its prep is nothing). Something like compiled ReverseDiff probably does bake in the assumption from compilation. So this is a question of what the semantics of prep are. If the non-differentiated data is the same between preparation and evaluation, there is no difference between the two.
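The baked-in behavior described above can be mimicked in plain Julia, without Reactant: capturing a copy of the context at preparation time plays the role of compiling it in as a constant, so later changes to the original context are invisible to the "compiled" function.

```julia
# "Preparation" freezes the context by copying it into a closure,
# analogous to baking data into a compiled trace.
function prepare_scaled_sum(context::Vector{Float64})
    frozen = copy(context)          # snapshot taken at preparation time
    return x -> sum(x .* frozen)    # the "compiled" function
end

ctx = [2.0, 2.0]
g = prepare_scaled_sum(ctx)
ctx .= 10.0        # mutate the context after preparation...
g([1.0, 1.0])      # ...the frozen values are still used: returns 4.0
```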
The semantics of
I thought converting the
No, your to_reac function does not achieve that. If you have a context of Tuple{Int, Int}, this will not be converted unless you do so explicitly. However, conversely, most of the time you actually want to partially evaluate integers (e.g. for sizes/bounds/etc.). I think the more reasonable setup here is not to to_rarray the context, and instead add a warning similar to the one from ReverseDiff:
Being forced to keep the same context values makes preparation pretty much useless. I think I'd rather have us trace everything in the context, even if it means a slowdown in some cases. Will it lead to actual errors?
Or alternatively, we could restrict the kinds of contexts we allow here.
Yes, unnecessarily tracing objects can lead to compilation errors in code that would otherwise compile fine.
And it doesn't make it useless, it just means that the user is responsible for performing the to_rarray themselves for things that may change.
Can you give an example so that I can wrap my mind around this?
That would be a Reactant-specific workaround, which doesn't fit other DI-supported backends. The whole point of DI is to enable easy backend switching, so I'd love to find a solution that doesn't expect users to wrap some of the arguments in Reactant-specific types when they want to switch to
Besides, the problem is not specific to contexts. Maybe DI could expose a function like
And I don't think this is terribly Reactant-specific. The same core issue equally applies to compiled ReverseDiff (where the context will equally be baked in). Just because Reactant also has a way to circumvent the problem in special cases shouldn't mean it is treated differently here.
Bumping this. @gdalle, are you okay with not tracing the contexts?
Not really. Many use cases of DI that I can think of require changing contexts, and Reactant has to be relevant for these cases too. Re-compiling the derivatives for each context change is very impractical. On the other hand, asking users to pass

I don't have a lot of bandwidth these days, but I think the right solution might be to expose something like
In order to be differentiated, the data must be a Reactant array. If we assume that DI only officially supports array inputs, this is fine.
"array" as in " |
By the way, this Reactant issue prevented me from testing
Updated experiments with Reactant-accelerated derivatives.
@wsmoses is this still the right paradigm in your opinion? I may not implement every operator right away, but I thought starting with a gradient made sense.
Related:
Warning: re-toggle tests once this is mergeable.