Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add EnzymeRules #103

Draft
wants to merge 4 commits into
base: master
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 5 additions & 1 deletion Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -8,20 +8,24 @@ LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"

[weakdeps]
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
EnzymeCore = "f151be2c-9106-41f4-ab19-57ee4f262869"

[extensions]
AbstractFFTsChainRulesCoreExt = "ChainRulesCore"
AbstractFFTsEnzymeCoreExt = "EnzymeCore"

[compat]
ChainRulesCore = "1"
EnzymeCore = "0.3"
julia = "^1.0"

[extras]
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
ChainRulesTestUtils = "cdddcdb0-9152-4a09-a978-84456f9df70a"
Enzyme = "7da242da-08ed-463a-9acd-ee780be4f1d9"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
Unitful = "1986cc42-f94f-5a68-af5c-568840ba703d"

[targets]
test = ["ChainRulesCore", "ChainRulesTestUtils", "Random", "Test", "Unitful"]
test = ["ChainRulesCore", "ChainRulesTestUtils", "Enzyme", "Random", "Test", "Unitful"]
58 changes: 58 additions & 0 deletions ext/AbstractFFTsEnzymeCoreExt.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,58 @@
module AbstractFFTsEnzymeCoreExt

using AbstractFFTs
using AbstractFFTs.LinearAlgebra
using EnzymeCore
using EnzymeCore.EnzymeRules

######################
# Forward-mode rules #
######################

const DuplicatedOrBatchDuplicated{T} = Union{Duplicated{T},BatchDuplicated{T}}

# since FFTs are linear, implement all forward-model rules generically at a low-level

function EnzymeRules.forward(

Check warning on line 16 in ext/AbstractFFTsEnzymeCoreExt.jl

View check run for this annotation

Codecov / codecov/patch

ext/AbstractFFTsEnzymeCoreExt.jl#L16

Added line #L16 was not covered by tests
func::Const{typeof(mul!)},
RT::Type{<:Const},
y::DuplicatedOrBatchDuplicated{<:StridedArray{T}},
p::Const{<:AbstractFFTs.Plan{T}},
x::DuplicatedOrBatchDuplicated{<:StridedArray{T}},
) where {T}
Copy link

@GiggleLiu GiggleLiu Jul 14, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I wish the type T can be restricted to a finite set, e.g. BLAS number types, otherwise, it may produce incorrect gradients for user defined extensions. Generally speaking, I feel "generic" AD is not a good practise.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The pushforward of a linear operator is always itself. And so far as I know, every definition of an FFT is a linear operator. So I can see no reasons why this rule should be problematic for forward-mode.

Copy link

@GiggleLiu GiggleLiu Jul 14, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

For example, I may want to extended FFT with tropical numbers, which is not a real number. It is linear, but does not have an inverse. Then your rule would give me incorrect gradients without throwing an error. I have seen too many incorrect gradients in previous AD frameworks such as Zygote when handling complex numbers.

I agree it is good to have a generic backward routine there, but please constraint the interfaces to concrete types when porting it to an AD engine. It should not be so difficult for users to extend the list of supported types in the future. Defining fft rules on BLAS types would be good enough to cover most using cases. For those non-BLAS types, honestly we can not make any assumption for them. Julia community needs an AD engine with provable correctness, I think it is also one of the goals of Enzyme.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I may want to extended FFT with tropical numbers

Is this really an FFT per se? I would consider a DFT generalized to some other ring to be a different transform.

Copy link

@GiggleLiu GiggleLiu Jul 16, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I may want to extended FFT with tropical numbers

Is this really an FFT per se? I would consider a DFT generalized to some other ring to be a different transform.

Since Julia does not have a good trait system, I think it is in general impossible to restrict users to input what the functions are designed for. This is what I meant there lacks provable correctness.

It has been a big issue that none of the Julia libraries (except Enzyme) can provide reliable gradients. They claim too much on untested using cases, like complex numbers and tropical numbers. There has been a belief that "it is cool if the code works in cases that it is not expected to work". But no, untested rules are not reliable, they can break on any future change even it works now. Rules must be concrete and tested, they are easy to extend, but hard to debug.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

By that argument, no AD rules should be defined here anyways, since downstream a user could define a custom Plan that doesn't do any kind of FFT at all. Then even with BLAS number types and strides arrays, any rule we write here would be wrong.

The counterargument is that if a user adds a method of a function whose properties are well-documented, other code should be able to assume and depend on those properties when calling the method for arbitrary inputs.

Taken to its logical conclusion, wouldn't your principle require that rules are never defined for abstract types, and further, that the type of every argument is concrete and known to the rule implementer?

Copy link

@GiggleLiu GiggleLiu Jul 16, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

wouldn't your principle require that rules are never defined for abstract types, and further, that the type of every argument is concrete and known to the rule implementer?

A big YES. I do not think many people need the backward rules for non-BLAS types. You may want to support e.g. double float that defined in DoubleFloat.jl. I would argue in these using cases, users can port the generic rule to the AD framework with little effort. The rule can be generic, but when porting it to the AD framework, it should be concrete.

We have to decide between support more data types and ensure the correctness. I really wish there can be a trait system that user can tell the compiler "this element type is a field", then users can use the rule with more confidence. Facts obvious to you, like "fft should work on field rather than other rings" may not be obvious to others.

The counterargument is that if a user adds a method of a function whose properties are well-documented, other code should be able to assume and depend on those properties when calling the method for arbitrary inputs.

To differentiate a long code, I will let the code fly and see where it falls. I will add new rules to the AD engine to keep it flying. It is not a problem for me if a rule does not exist. So when using a new element type, like complex number, symbolic type, finite field algebra or the Tropical number type as mentioned above, I will probably not check whether the property of each function is as documented.

Then even with BLAS number types and strides arrays, any rule we write here would be wrong.

A warning will be thrown when overloading an existing function. Also, pirating is not difficult to avoid.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

In any case, if we have ChainRules I think we should have the corresponding EnzymeRules.

If users make the questionable choice of overriding fft to compute an unrelated function, then it is up to them to override the EnzymeRules/ChainRules as well.

val = func.val(y.val, p.val, x.val)
if x isa Duplicated && y isa Duplicated
dval = func.val(y.dval, p.val, x.dval)
elseif x isa Duplicated && y isa Duplicated
dval = map(y.dval, x.dval) do dy, dx
return func.val(dy, p.val, dx)

Check warning on line 28 in ext/AbstractFFTsEnzymeCoreExt.jl

View check run for this annotation

Codecov / codecov/patch

ext/AbstractFFTsEnzymeCoreExt.jl#L23-L28

Added lines #L23 - L28 were not covered by tests
end
end
return nothing

Check warning on line 31 in ext/AbstractFFTsEnzymeCoreExt.jl

View check run for this annotation

Codecov / codecov/patch

ext/AbstractFFTsEnzymeCoreExt.jl#L31

Added line #L31 was not covered by tests
end

function EnzymeRules.forward(

Check warning on line 34 in ext/AbstractFFTsEnzymeCoreExt.jl

View check run for this annotation

Codecov / codecov/patch

ext/AbstractFFTsEnzymeCoreExt.jl#L34

Added line #L34 was not covered by tests
func::Const{typeof(*)},
RT::Type{
<:Union{Const,Duplicated,DuplicatedNoNeed,BatchDuplicated,BatchDuplicatedNoNeed}
},
p::Const{<:AbstractFFTs.Plan},
x::DuplicatedOrBatchDuplicated{<:StridedArray},
)
RT <: Const && return func.val(p.val, x.val)
if x isa Duplicated
dval = func.val(p.val, x.dval)
RT <: DuplicatedNoNeed && return dval
val = func.val(p.val, x.val)
RT <: Duplicated && return Duplicated(val, dval)

Check warning on line 47 in ext/AbstractFFTsEnzymeCoreExt.jl

View check run for this annotation

Codecov / codecov/patch

ext/AbstractFFTsEnzymeCoreExt.jl#L42-L47

Added lines #L42 - L47 were not covered by tests
else # x isa BatchDuplicated
dval = map(x.dval) do dx
return func.val(p.val, dx)

Check warning on line 50 in ext/AbstractFFTsEnzymeCoreExt.jl

View check run for this annotation

Codecov / codecov/patch

ext/AbstractFFTsEnzymeCoreExt.jl#L49-L50

Added lines #L49 - L50 were not covered by tests
end
RT <: BatchDuplicatedNoNeed && return dval
val = func.val(p.val, x.val)
RT <: BatchDuplicated && return BatchDuplicated(val, dval)

Check warning on line 54 in ext/AbstractFFTsEnzymeCoreExt.jl

View check run for this annotation

Codecov / codecov/patch

ext/AbstractFFTsEnzymeCoreExt.jl#L52-L54

Added lines #L52 - L54 were not covered by tests
end
end

end # module
Loading