Add EnzymeRules #103

Draft: wants to merge 4 commits into master

sethaxen
Contributor

Will fix #99

@sethaxen
Contributor Author

For some reason, I can't seem to get the extension to work. Package precompilation fails with the error:

ERROR: The following 1 direct dependency failed to precompile:

AbstractFFTs [621f4979-c628-5d54-868e-fcf4e3e8185c]

Failed to precompile AbstractFFTs [621f4979-c628-5d54-868e-fcf4e3e8185c] to "/home/runner/.julia/compiled/v1.9/AbstractFFTs/jl_mYHZQL".
ERROR: LoadError: ArgumentError: Package AbstractFFTs does not have LinearAlgebra in its dependencies:
- You may have a partially installed environment. Try `Pkg.instantiate()`
  to ensure all packages in the environment are installed.
- Or, if you have AbstractFFTs checked out for development and have
  added LinearAlgebra as a dependency but haven't updated your primary
  environment's manifest file, try `Pkg.resolve()`.
- Otherwise you may need to report an issue with AbstractFFTs

although LinearAlgebra is clearly listed as both a dep and a weak dep.

Stranger still, if I activate the project, Pkg.status() reports an empty project, whereas if I remove this extension, it lists the dependencies:

julia> using Pkg; Pkg.activate(".")
  Activating project at `~/projects/AbstractFFTs.jl`

julia> Pkg.status()
Project AbstractFFTs v1.3.1
Status `~/projects/AbstractFFTs.jl/Project.toml` (empty project)

shell> head ./Project.toml
name = "AbstractFFTs"
uuid = "621f4979-c628-5d54-868e-fcf4e3e8185c"
version = "1.3.1"

[deps]
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"

[weakdeps]
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"

@KristofferC I've never had this problem with my extensions before. Do you know what could cause this?

@sethaxen
Contributor Author

Never mind: it seems extensions cannot have weak deps that are also deps. In this case, the dep needs to be loaded within the extension through the main package; see, e.g., JuliaStats/LogExpFunctions.jl#63.
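A minimal, self-contained sketch of that pattern, with stand-in modules instead of real packages (MainPkg and MainPkgExt are hypothetical names). The extension module does not declare LinearAlgebra as a weak dependency; it imports the LinearAlgebra binding that the main package already loaded. In the actual extension the equivalent would presumably be `using AbstractFFTs: LinearAlgebra`, assuming AbstractFFTs loads LinearAlgebra at top level.

```julia
# Stand-in for the main package (hypothetical name): LinearAlgebra is a
# regular dependency, loaded at top level.
module MainPkg
using LinearAlgebra
end

# Stand-in for the package extension (hypothetical name). Instead of listing
# LinearAlgebra as a weak dependency, it reuses the binding that the parent
# package already brought into scope.
module MainPkgExt
using ..MainPkg: LinearAlgebra

# Something the extension might need from LinearAlgebra.
inner(a, b) = LinearAlgebra.dot(a, b)
end
```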

@codecov

codecov bot commented May 22, 2023

Codecov Report

Patch coverage has no change and project coverage change: -8.48 ⚠️

Comparison is base (a25656d) 87.08% compared to head (859abf0) 78.60%.

Additional details and impacted files
@@            Coverage Diff             @@
##           master     #103      +/-   ##
==========================================
- Coverage   87.08%   78.60%   -8.48%     
==========================================
  Files           3        4       +1     
  Lines         209      229      +20     
==========================================
- Hits          182      180       -2     
- Misses         27       49      +22     
Impacted Files                      Coverage Δ
ext/AbstractFFTsEnzymeCoreExt.jl    0.00% <0.00%> (ø)

... and 1 file with indirect coverage changes


@sethaxen
Contributor Author

If #67 is merged, we could add rules for *(::Plan, ::StridedArray), so long as the plan is Const. (If it is not Const, the rule would also need to support the plan being in-place, which we can't do.)
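To show why a Const plan keeps such a rule simple, here is a plain-Julia sketch of the math only, not of the EnzymeRules API. ToyDFTPlan, fwd_tangent, and rev_accumulate! are hypothetical names: ToyDFTPlan stands in for an out-of-place plan backed by a dense DFT matrix, and its adjoint plays the role of the adjoint plans that #67 would add. The forward tangent is p * dx, the reverse step accumulates p' * ybar into the input cotangent, and the two are consistent when dot(ybar, p * dx) equals dot(p' * ybar, dx).

```julia
using LinearAlgebra

# Hypothetical stand-in for a Const, out-of-place FFT plan: a dense DFT matrix
# with fft's sign convention, F[k+1, j+1] = exp(-2πi * j * k / n).
struct ToyDFTPlan
    F::Matrix{ComplexF64}
end
ToyDFTPlan(n::Integer) = ToyDFTPlan([cispi(-2 * j * k / n) for k in 0:n-1, j in 0:n-1])

Base.:*(p::ToyDFTPlan, x::AbstractVector) = p.F * x
# Adjoint "plan": the conjugate transpose of the transform.
Base.adjoint(p::ToyDFTPlan) = ToyDFTPlan(collect(p.F'))

# Forward mode: y = p * x with p Const, so the output tangent is p * dx.
fwd_tangent(p::ToyDFTPlan, dx) = p * dx

# Reverse mode: accumulate the input cotangent from the output cotangent ybar.
function rev_accumulate!(xbar, p::ToyDFTPlan, ybar)
    xbar .+= p' * ybar
    return xbar
end
```

Because the plan is Const, neither rule has to produce a derivative with respect to the plan itself, which is exactly what an in-place or non-Const plan would complicate.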

y::DuplicatedOrBatchDuplicated{<:StridedArray{T}},
p::Const{<:AbstractFFTs.Plan{T}},
x::DuplicatedOrBatchDuplicated{<:StridedArray{T}},
) where {T}

GiggleLiu commented Jul 14, 2023


I wish the type T could be restricted to a finite set, e.g. BLAS number types; otherwise, it may produce incorrect gradients for user-defined extensions. Generally speaking, I feel "generic" AD is not good practice.

Contributor Author


The pushforward of a linear operator is always the operator itself, and as far as I know, every definition of an FFT is a linear operator. So I see no reason why this rule should be problematic for forward mode.
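A quick numerical check of this claim, as a plain-Julia sketch: for a linear map, the finite-difference directional derivative at any point x along a direction v equals the map applied to v, independent of x. naive_dft and fd_directional are hypothetical helpers; the DFT is written out directly (O(n^2)) so the example has no FFTW dependency.

```julia
# Hypothetical O(n^2) DFT using fft's sign convention (1-based Vector input):
# y[k+1] = sum over j of x[j+1] * exp(-2πi * j * k / n), for k = 0, ..., n-1.
function naive_dft(x::Vector)
    n = length(x)
    return [sum(x[j+1] * cispi(-2 * j * k / n) for j in 0:n-1) for k in 0:n-1]
end

# Forward-difference approximation of the directional derivative of f at x along v.
fd_directional(f, x, v; h=1e-6) = (f(x .+ h .* v) .- f(x)) ./ h
```

For any x, `fd_directional(naive_dft, x, v)` agrees with `naive_dft(v)` up to rounding, which is the forward-mode rule "tangent out = transform of tangent in".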


GiggleLiu commented Jul 14, 2023


For example, I may want to extend FFT with tropical numbers, which are not real numbers. It is linear, but does not have an inverse. Then your rule would give me incorrect gradients without throwing an error. I have seen too many incorrect gradients from earlier AD frameworks such as Zygote when handling complex numbers.

I agree it is good to have a generic backward routine there, but please constrain the interfaces to concrete types when porting it to an AD engine. It should not be difficult for users to extend the list of supported types in the future. Defining fft rules on BLAS types would be good enough to cover most use cases. For non-BLAS types, honestly, we cannot make any assumptions. The Julia community needs an AD engine with provable correctness, and I think that is also one of the goals of Enzyme.

Member


> I may want to extend FFT with tropical numbers

Is this really an FFT per se? I would consider a DFT generalized to some other ring to be a different transform.


GiggleLiu commented Jul 16, 2023


>> I may want to extend FFT with tropical numbers

> Is this really an FFT per se? I would consider a DFT generalized to some other ring to be a different transform.

Since Julia does not have a good trait system, I think it is in general impossible to restrict users to the inputs that functions are designed for. This is what I meant by a lack of provable correctness.

It has been a big issue that none of the Julia libraries (except Enzyme) can provide reliable gradients. They claim too much about untested use cases, like complex numbers and tropical numbers. There has been a belief that "it is cool if the code works in cases where it is not expected to work". But no: untested rules are not reliable; they can break with any future change even if they work now. Rules must be concrete and tested; they are easy to extend but hard to debug.

Contributor Author


By that argument, no AD rules should be defined here anyway, since a downstream user could define a custom Plan that doesn't do any kind of FFT at all. Then even with BLAS number types and strided arrays, any rule we write here would be wrong.

The counterargument is that if a user adds a method of a function whose properties are well-documented, other code should be able to assume and depend on those properties when calling the method for arbitrary inputs.

Taken to its logical conclusion, wouldn't your principle require that rules are never defined for abstract types, and further, that the type of every argument is concrete and known to the rule implementer?


GiggleLiu commented Jul 16, 2023


> wouldn't your principle require that rules are never defined for abstract types, and further, that the type of every argument is concrete and known to the rule implementer?

A big YES. I do not think many people need backward rules for non-BLAS types. You may want to support, e.g., the DoubleFloat types defined in DoubleFloats.jl. I would argue that in these use cases, users can port the generic rule to the AD framework with little effort. The rule can be generic, but when it is ported to the AD framework, it should be concrete.

We have to choose between supporting more data types and ensuring correctness. I really wish there were a trait system through which users could tell the compiler "this element type is a field"; then users could use the rule with more confidence. Facts that are obvious to you, like "fft should work on fields rather than other rings", may not be obvious to others.
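Julia has no built-in trait system, but the widely used "Holy traits" dispatch idiom can approximate the opt-in described here. The sketch below is hypothetical: FieldTrait, IsField, NotField, fft_rule_supported, and MyFloat are illustrative names, not part of AbstractFFTs or Enzyme. BLAS element types are declared to be fields, every other type defaults to "not known to be a field", and the author of a new number type can opt in with a single method.

```julia
using LinearAlgebra: BlasFloat

# Hypothetical trait values describing an element type's algebraic structure.
abstract type FieldTrait end
struct IsField <: FieldTrait end
struct NotField <: FieldTrait end   # "not known to be a field"

# Default: make no assumptions about arbitrary element types.
FieldTrait(::Type) = NotField()
# Built-in opt-in for the BLAS element types.
FieldTrait(::Type{<:BlasFloat}) = IsField()

# Gate a (hypothetical) rule on the trait of the array's element type.
fft_rule_supported(x::AbstractArray) = fft_rule_supported(FieldTrait(eltype(x)), x)
fft_rule_supported(::IsField, x) = true
fft_rule_supported(::NotField, x) = false

# A user-defined number type whose author knows it behaves like a field
# can opt in explicitly.
struct MyFloat <: Number
    v::Float64
end
FieldTrait(::Type{MyFloat}) = IsField()
```

The trait only records a promise made by whoever defines the method; it cannot verify the promise, which is the limitation this thread is debating.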

> The counterargument is that if a user adds a method of a function whose properties are well-documented, other code should be able to assume and depend on those properties when calling the method for arbitrary inputs.

To differentiate a long program, I let the code fly and see where it falls, and I add new rules to the AD engine to keep it flying. It is not a problem for me if a rule does not exist. So when using a new element type, such as complex numbers, symbolic types, finite-field algebra, or the tropical number type mentioned above, I will probably not check whether each function's properties are as documented.

> Then even with BLAS number types and strided arrays, any rule we write here would be wrong.

A warning will be issued when overloading an existing function. Also, type piracy is not difficult to avoid.

Member


In any case, if we have ChainRules I think we should have the corresponding EnzymeRules.

If users make the questionable choice of overriding fft to compute an unrelated function, then it is up to them to override the EnzymeRules/ChainRules as well.

@sethaxen
Contributor Author

sethaxen commented Aug 26, 2023

I've paused work on this until EnzymeTestUtils (EnzymeAD/Enzyme.jl#782) is registered, which will make testing these rules reliably much more straightforward.

@sethaxen
Contributor Author

Coming back to this, I think Enzyme rules should only be defined here abstractly for cases where we know they will not break downstream code that Enzyme would otherwise have handled fine. So I agree with the following restrictions:

  • Restrict eltypes to BLAS types
  • Restrict array types to StridedArrays
  • Only have rules for fft, fft!, and the other variants. In general, we cannot tell whether a plan is in-place. If we can catch cases where the plan is Const (i.e. Enzyme has inferred that it is not used to carry any derivative information) without breaking fallbacks, then great; otherwise, we don't define the rule.
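The first two restrictions can be expressed directly through dispatch. In this hypothetical sketch (has_enzyme_rule is an illustrative name, not an AbstractFFTs or EnzymeCore function), only strided arrays whose element type is in LinearAlgebra.BlasFloat, i.e. Union{Float32, Float64, ComplexF32, ComplexF64}, match the "rule" method. Everything else hits the fallback, which stands for "no rule defined", so Enzyme would differentiate the generic code itself.

```julia
using LinearAlgebra: BlasFloat

# Matches only the cases the proposal allows: strided arrays with BLAS eltypes.
has_enzyme_rule(::StridedArray{<:BlasFloat}) = true
# Everything else: no rule, so Enzyme would fall back to differentiating
# the generic implementation.
has_enzyme_rule(::AbstractArray) = false
```

A strided view such as `view(rand(Float32, 8), 1:2:8)` qualifies, while a `BigFloat` array or a range like `1.0:3.0` (not a `StridedArray`, despite its `Float64` eltype) does not.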

These rules are considerably stricter than the ChainRules, and for good reason. By convention, ChainRules are often defined to cover up indexing and mutating code to help Zygote and Diffractor, but this comes at the cost of doing the wrong thing for many types, hence the ProjectTo mechanism. Enzyme, on the other hand, can in principle handle many more types well, so we want to avoid writing rules that do the wrong thing in any case where Enzyme would have worked fine without a rule.

Rules for * with Plans can be defined in packages like FFTW, where the plan type tells us whether the plan is in-place.
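A hypothetical sketch of what "the type informs the in-placeness" can look like: if in-placeness is a type parameter of the plan, a rule for `p * x` can be restricted to out-of-place plans by dispatch alone. ToyPlan, is_inplace, and has_mul_rule are illustrative names; this is not FFTW's actual plan type or API.

```julia
# Hypothetical plan type with in-placeness encoded as a Bool type parameter.
struct ToyPlan{inplace} end

is_inplace(::ToyPlan{inplace}) where {inplace} = inplace

# A rule for `p * x` would only be defined for out-of-place plans;
# in-place plans fall through to "no rule".
has_mul_rule(::ToyPlan{false}) = true
has_mul_rule(::ToyPlan) = false
```

The generic AbstractFFTs.Plan type carries no such guarantee, which is why the rule is left to packages whose concrete plan types do.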

Development: successfully merging this pull request may close these issues: Adding EnzymeRules
3 participants