definitions for adjoints of FFTW functions #215
updating to new master
added test that needs FillArrays functionality to pass
Thanks! Just a couple comments but this is great to have.
src/lib/fft.jl
Outdated
# FFTW functions do not work with FillArrays, which are needed
# for some functionality of Zygote. To make it work with FillArrays
# as well, overload the relevant functions
FFTW.fft(x::Fill) = FFTW.fft(collect(x))
You can use dims... to make this more compact.
Thanks for the tip, I have changed it to dims... now.
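A minimal sketch of the dims... suggestion, using stand-ins so that it runs in plain Julia. LazyFill and total are hypothetical names chosen for illustration (they play the roles of a lazy array type such as Fill and of a function with an optional trailing dims argument); they are not part of FillArrays, FFTW, or this PR.

```julia
# Hypothetical stand-in for a lazy array type that the target function
# does not accept directly.
struct LazyFill
    value::Float64
    dims::Tuple{Vararg{Int}}
end
Base.collect(x::LazyFill) = fill(x.value, x.dims)

# Hypothetical stand-in for a function with an optional `dims` argument.
total(x::AbstractArray) = sum(x)
total(x::AbstractArray, dims) = sum(x; dims=dims)

# A single varargs method forwards zero or more trailing arguments,
# so separate one- and two-argument overloads are not needed.
total(x::LazyFill, dims...) = total(collect(x), dims...)

@assert total(LazyFill(2.0, (2, 3))) == 12.0
@assert total(LazyFill(2.0, (2, 3)), 1) == [4.0 4.0 4.0]
```

The same pattern applied to the overload in the diff would give one method per function instead of one per arity.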
Project.toml
Outdated
@@ -16,7 +16,7 @@ Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
Requires = "ae029012-a4dd-5104-9daa-d747884805df"
SpecialFunctions = "276daf66-3868-5448-9aa4-cd146d93841b"
Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"

FFTW = "7a1cc6ca-52ef-59f5-83cd-3a7055c09341"
Please use the package manager to add FFTW, so that the formatting is right and so that the manifest gets updated too.
fixed now via package manager
src/Zygote.jl
Outdated
@@ -35,6 +35,7 @@ include("lib/nnlib.jl")
include("lib/broadcast.jl")
include("lib/forward.jl")
include("lib/utils.jl")
include("lib/fft.jl")
I think this file is short enough that it could just go in array.jl, after linear algebra.
I moved it to array.jl now
src/lib/fft.jl
Outdated
# if it is not a single int, convert to array so that we can use it
# for indexing
if typeof(dims) != Int
dims = collect(dims)
I suspect this is unnecessary, since numbers act like single-element collections anyway for things like prod.
The exception for Int was necessary because size(xs)[collect(1)] does not work. I have now changed it to collect(size(xs))[dims], which works, seems more elegant to me, and does not need different handling for ints.
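As an illustration of the reviewer's point that numbers behave like single-element collections, the scaling factor can also be computed by iterating over dims instead of indexing with it. This is only a sketch: fft_scale is a hypothetical helper name, not something defined in the PR or in Zygote.

```julia
# Hypothetical helper (not part of the PR): product of the array sizes along
# the transformed dimensions. Iterating over `dims` works the same way for an
# Int, a Vector, a Tuple, or a range, because a number iterates as a
# one-element collection.
fft_scale(xs, dims) = prod(size(xs, d) for d in dims)

xs = zeros(4, 5, 6)
@assert fft_scale(xs, 2) == 5          # single integer
@assert fft_scale(xs, [1, 3]) == 24    # vector of dimensions
@assert fft_scale(xs, (1, 2)) == 20    # tuple of dimensions
@assert fft_scale(xs, 1:3) == 120      # range covering all dimensions
```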
src/lib/fft.jl
Outdated
end

@adjoint function FFTW.fft(xs, dims)
# up to now only works when dims is a single integer
This comment isn't clear to me; you seem to be handling multiple dimensions below. If only single integers work, best to add a type restriction here.
This comment was from an old version and I forgot to remove it, sorry for the confusion. It does indeed work with multiple dimensions as well. I have removed the comment now.
@test gradient((x)->real(fft(ifft(x))[1]),x)[1][1] == 1.0+0.0im

# check ffts for individual dimensions
@test gradient((x)->sum(abs.(FFTW.fft(x))),x)[1] ≈ gradient((x)->sum(abs.(FFTW.fft(FFTW.fft(x,1),2))),x)[1]
It would be a bit simpler to use gradcheck here instead. That will check against numerical gradients, which is handy for correctness.
I added 4 gradchecks as well now. I think the old tests still make sense, since they check whether the FFTs along individual dims work and produce the same results as when doing all dimensions in one go. However, if you think that the old tests are mainly bloating gradcheck.jl, I can also omit them.
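For readers unfamiliar with the idea behind checking against numerical gradients, here is a rough sketch of a central-difference check in plain Julia. The helper numgrad is a hypothetical name written for this illustration; it is not the gradcheck utility from Zygote's test suite, and it only handles real vectors.

```julia
# Hypothetical helper: central-difference approximation of the gradient of a
# scalar-valued function `f` at a real vector `x`.
function numgrad(f, x::AbstractVector{<:Real}; h=1e-6)
    g = zeros(length(x))
    for i in eachindex(x)
        e = zeros(length(x))
        e[i] = h                                  # perturb one coordinate
        g[i] = (f(x .+ e) - f(x .- e)) / (2h)
    end
    return g
end

f(x) = sum(abs2, x)            # analytic gradient is 2x
x = [0.5, -1.0, 2.0]

# The analytic gradient should match the numerical one up to a small tolerance.
@assert isapprox(numgrad(f, x), 2 .* x; rtol=1e-6)
```

A check of this kind catches a wrong scaling factor or a missing conjugate in an adjoint, which a comparison between two code paths that share the same adjoint cannot.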
@adjoint function FFTW.fft(xs, dims)
return FFTW.fft(xs, dims), function(Δ)
# dims can be int, array or tuple,
The indentation is off here, just needs a quick fix.
src/lib/array.jl
Outdated
@adjoint function FFTW.fft(xs, dims)
return FFTW.fft(xs, dims), function(Δ)
# dims can be int, array or tuple,
nit: indent these few lines?
I (hopefully) fixed the indentation now.
Alright, great, this looks good. There's a merge conflict due to the manifest; it'd be great if you could resolve that, and even better if you can use a newer Julia version to minimise the number of changes while you're at it. Other than that I think we're good to go.
OK. I am very new both to Julia and to collaborating on GitHub, but I will do my best to resolve the conflicts.
I have now switched to Julia 1.1.1 and resolved the conflicts. However, for some reason my manifest has more entries than the current one in master (AbstractFFTs, Conda, JSON, Reexport, VersionParsing), in addition to FFTW, which I added via the package manager. Should I remove them?
No, that's ok, they are just dependencies of FFTW that got pulled in. Unfortunately the tests are now failing; looks like you just have a syntax error in the tests you added.
I have now fixed the tests, and they pass on my local installation. I also fixed the testset for hcat, which had failed as well, and made some fixes in other testsets (forgotten @test).
The Travis CI build has finished, but this does not seem to be reported to GitHub (when I check https://travis-ci.org/FluxML/Zygote.jl/builds/543743653?utm_source=github_status&utm_medium=notification all tests have passed, but here it is still pending). Is there any way to re-trigger the Travis CI build?
@@ -4,6 +4,7 @@ version = "0.3.1"

[deps]
DiffRules = "b552c78f-8df3-52c6-915a-8e097449b14b"
FFTW = "7a1cc6ca-52ef-59f5-83cd-3a7055c09341"
Can't you just depend on AbstractFFTs.jl and write the dispatches using the high level functions?
👍 Thanks!
But can't this be done with just AbstractFFTs?
Probably, but fine to do that as a follow up. The main issue is testing but we can have FFTW as a test dependency.
Adjoint definitions and accompanying tests for the most important functions of FFTW (fft and ifft).
Since deriving the adjoint of FFTs is a bit tricky, I double-checked the results by differentiating a simple DFT implementation that is differentiable with ReverseDiff; the results agree to machine precision. This is possible because the FFT is just a more efficient implementation of the DFT, which itself boils down to a simple matrix multiplication.
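The matrix-multiplication view can be sketched in plain Julia without FFTW or ReverseDiff. This is not the comparison code from the PR, only an illustration under stated assumptions: dft_matrix is a hypothetical helper, and the forward transform is taken to be unnormalized (the convention fft uses), so its adjoint is the conjugate transpose of the DFT matrix, which equals N times the inverse transform.

```julia
using LinearAlgebra

# Hypothetical helper: dense N×N matrix of the unnormalized forward DFT,
# F[j, k] = exp(-2πi (j-1)(k-1) / N).
dft_matrix(N) = [exp(-2π * im * (j - 1) * (k - 1) / N) for j in 1:N, k in 1:N]

N = 8
F = dft_matrix(N)
x = randn(ComplexF64, N)
Δ = randn(ComplexF64, N)   # incoming sensitivity, same shape as F * x

# The adjoint of x -> F * x is Δ -> F' * Δ (conjugate transpose).
# For the DFT matrix, F' equals N * inv(F), i.e. N times the inverse DFT,
# which is where the factor N in the adjoint definition comes from.
@assert F' * Δ ≈ N * (F \ Δ)

# Defining property of the adjoint: <F x, Δ> == <x, F' Δ>.
@assert dot(F * x, Δ) ≈ dot(x, F' * Δ)
```

For a transform over several dimensions the same argument applies per dimension, so the factor becomes the product of the sizes along the transformed dimensions.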
I add this comparison as comment here for reference: