Adding EnzymeRules #99

Open
sethaxen opened this issue Apr 19, 2023 · 6 comments · May be fixed by #103

Comments

@sethaxen
Contributor

Similar to the ChainRulesCore support, we could use EnzymeCore.EnzymeRules to define Forward and Reverse mode rules for Enzyme in an extension.

EnzymeCore requires at least Julia v1.6. Making it a hard dependency on Julia versions older than v1.9 (as is done with ChainRulesCore) would then only be possible if AbstractFFTs adds a Julia v1.6 lower bound. But since EnzymeCore's sole dependency (Adapt) depends on Requires, it may make more sense to load it conditionally on pre-v1.9 using Requires.jl: https://pkgdocs.julialang.org/dev/creating-packages/#Requires.jl

Unlike the ChainRulesCore support, it probably only makes sense to define rules for StridedArray inputs, to avoid doing the wrong thing for sparse or structured arrays.

@ChrisRackauckas
Member

I don't think anyone would be against it. Make it an extension package using v1.9 extensions and there's no dep added here. It's more about getting it done.

@wsmoses

wsmoses commented Apr 19, 2023

Enzyme rules require Julia v1.7 or above. cc @vchuravy

EnzymeCore is designed to be a lightweight, flexible package, like ChainRulesCore, for importing and adding rules.

@sethaxen
Contributor Author

Okay, I'm happy to tackle this once I understand some confusing behavior of complex array rules (see EnzymeAD/Enzyme.jl#744 and EnzymeAD/Enzyme.jl#739 (comment)).

Unlike the ChainRules support, I would restrict Enzyme rules to StridedArray inputs. That is what FFTW promotes everything to, and it avoids doing the wrong thing for structurally sparse array inputs.

sethaxen linked a pull request on May 22, 2023 that will close this issue
@sethaxen
Contributor Author

A few observations after starting work on this:

  • Unlike ChainRules, we should not define rules for fft, ifft, bfft, etc., since this would cover up the primal functions or any custom methods and do more work than necessary.
  • We cannot define rules for *(p::Plan, x::StridedVector), since we have no way of knowing whether p was constructed as an in-place plan, and the rule differs depending on whether it was.
  • We cannot define reverse-mode rules for mul!(y::AbstractArray, plan::Plan, x::AbstractArray), since we don't know what the plan is and therefore don't know how to normalize.
  • It might be safe to define forward-mode rules for mul!(y::AbstractArray, plan::Plan, x::AbstractArray). As long as every possible FFT is linear, the pushforward is the same as the primal, and we know that the plan must be out-of-place.
  • Most of the rules probably need to live in the package that implements the interface, e.g. FFTW.
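To make the linearity and normalization points above concrete, here is a minimal sketch in plain Python (not Julia, and not the EnzymeRules API). It uses a naive O(n²) DFT written for this example; the helper names `dft`, `fft`, `bfft`, `ifft`, and `inner` are illustrative stand-ins, not AbstractFFTs or FFTW functions. It checks that the pushforward of a linear transform is the transform applied to the tangent, and that the pullback carries a normalization that depends on which transform was applied.

```python
import cmath

def dft(x, sign, scale=1.0):
    """Naive O(n^2) DFT. sign=-1 is the forward transform, sign=+1 the backward one."""
    n = len(x)
    return [
        scale * sum(x[j] * cmath.exp(sign * 2j * cmath.pi * j * k / n) for j in range(n))
        for k in range(n)
    ]

def fft(x):
    return dft(x, -1)                 # unnormalized forward transform

def bfft(x):
    return dft(x, +1)                 # unnormalized backward transform

def ifft(x):
    return dft(x, +1, 1.0 / len(x))   # normalized inverse transform

def inner(a, b):
    """Complex inner product <a, b> = sum(conj(a_i) * b_i)."""
    return sum(ai.conjugate() * bi for ai, bi in zip(a, b))

# Arbitrary test data chosen for this sketch.
x = [1 + 2j, -0.5j, 3.0 + 0j, 0.25 - 1j]
dx = [0.5 + 0j, 1 - 1j, -2j, 0.75 + 0.5j]     # tangent
ybar = [2 - 1j, 0.5 + 0.5j, -1 + 0j, 1j]      # cotangent
n = len(x)

# Forward mode: for a linear map, the pushforward is the map applied to the tangent.
h = 1e-3
shifted = fft([xi + h * di for xi, di in zip(x, dx)])
finite_diff = [(a - b) / h for a, b in zip(shifted, fft(x))]
forward_err = max(abs(a - b) for a, b in zip(finite_diff, fft(dx)))

# Reverse mode: the pullback is the adjoint, and its scaling depends on the transform.
# Adjoint of fft is bfft; adjoint of ifft is fft divided by n.
fft_adjoint_err = abs(inner(ybar, fft(dx)) - inner(bfft(ybar), dx))
ifft_adjoint_err = abs(inner(ybar, ifft(dx)) - inner([v / n for v in fft(ybar)], dx))

print(forward_err, fft_adjoint_err, ifft_adjoint_err)
```

All three errors should be at round-off level. The forward check needs no knowledge of the scaling, whereas the two adjoints differ by a factor of n. That is the information a reverse-mode rule is missing when the plan is opaque.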

@wsmoses

wsmoses commented May 23, 2023

Why would defining rules for fft/etc cause more work to be done?

@sethaxen
Contributor Author

> unlike ChainRules, we should not define rules for fft, ifft, bfft etc, since this will cover up the primal functions or any custom methods and do more work than is necessary

Let me clarify. Let f be fft, fft!, etc. Assume no one in some other package pirates f for an array type defined in the standard library. Then it is safe for us to define forward- and reverse-mode rules for f for all array types defined in the standard library (or maybe even just StridedArray). To avoid doing extra work when f is not in-place, we can use the same plan for the primal and the tangent. Because FFTW works by promoting any array to a StridedArray, this would cover all cases covered by ChainRules when FFTW is the backend.

For AbstractArray inputs, we cannot define reverse-mode rules for f, because the input might be structured; these rules need to be defined in the backend package.

> Why would defining rules for fft/etc cause more work to be done?

We can also define forward-mode rules for f for AbstractArray inputs. But since a package might overload f for a custom array type with a method that does not use a plan, the rule would have to call the primal function rather than construct a plan itself. If the primal does use a plan, then this constructs at least one unnecessary plan and does more work than necessary. In the batched case, it could construct many more plans than are needed.
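The plan-count argument can be sketched with a toy model, again in plain Python rather than Julia. Everything here is hypothetical and written for this example: `ToyPlan` stands in for an FFT plan and counts how often it is constructed, `toy_fft` is a primal that builds a fresh plan on every call, and the two rule functions are not Enzyme rules, only an illustration of the two strategies discussed above.

```python
import cmath

class ToyPlan:
    """Hypothetical stand-in for an FFT plan: precomputes the DFT matrix for length n."""
    built = 0  # number of plans constructed so far

    def __init__(self, n):
        ToyPlan.built += 1
        self.matrix = [
            [cmath.exp(-2j * cmath.pi * j * k / n) for j in range(n)]
            for k in range(n)
        ]

    def apply(self, x):
        return [sum(w * xj for w, xj in zip(row, x)) for row in self.matrix]

def toy_fft(x):
    """Primal that constructs a new plan on every call."""
    return ToyPlan(len(x)).apply(x)

def forward_rule_generic(x, dxs):
    """Calls the primal for the value and for each tangent: one plan per call."""
    return toy_fft(x), [toy_fft(dx) for dx in dxs]

def forward_rule_shared_plan(x, dxs):
    """Builds one plan and reuses it for the value and every tangent."""
    plan = ToyPlan(len(x))
    return plan.apply(x), [plan.apply(dx) for dx in dxs]

x = [1 + 0j, 2j, -1 + 1j, 0.5 + 0j]
dxs = [[1 + 0j, 0j, 0j, 0j], [0j, 1 + 0j, 0j, 0j], [0j, 0j, 1j, 0j]]  # batch of 3 tangents

ToyPlan.built = 0
y_generic, dys_generic = forward_rule_generic(x, dxs)
generic_plans = ToyPlan.built

ToyPlan.built = 0
y_shared, dys_shared = forward_rule_shared_plan(x, dxs)
shared_plans = ToyPlan.built

print(generic_plans, shared_plans)  # 4 1
```

With a batch of three tangents, the generic strategy builds four plans and the shared-plan strategy builds one, while both return the same values. The shared-plan version is only valid when the rule knows the primal is plan-based, which is the assumption that fails for arbitrary AbstractArray inputs.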
