-
-
Notifications
You must be signed in to change notification settings - Fork 216
definitions for adjoints of FFTW functions #215
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
5d5446f
b4aeed8
82660e9
9e4cb70
70f9d11
15577d3
a56de20
d947c0c
13b698f
ac9bcc2
d235d9c
960c1f5
9505dbf
d7e1a88
cdb5615
286b2fb
80d5f14
e85895d
f556fc8
4787375
fab7e3a
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -1,4 +1,4 @@ | ||
| using FillArrays | ||
| using FillArrays, FFTW | ||
| using FillArrays: AbstractFill, getindex_value | ||
| using Base.Broadcast: broadcasted, broadcast_shape | ||
|
|
||
|
|
@@ -429,6 +429,53 @@ end | |
| @adjoint -(A::AbstractArray, B::AbstractArray) = A - B, Δ->(Δ, -Δ) | ||
| @adjoint -(A::AbstractArray) = -A, Δ->(-Δ,) | ||
|
|
||
| # FFTW | ||
| # =================== | ||
|
|
||
| # FFTW functions do not accept FillArrays, which some Zygote functionality relies on. | ||
| # Define methods for `Fill` that materialise the array with `collect` and then | ||
| # forward to the regular FFTW methods. | ||
| # NOTE(review): neither `FFTW.fft`/`FFTW.ifft` nor `Fill` appears to be owned by this | ||
| # module, so these methods look like type piracy — confirm this is acceptable. | ||
| for transform in (:fft, :ifft) | ||
| @eval FFTW.$transform(x::Fill, dims...) = FFTW.$transform(collect(x), dims...) | ||
| end | ||
|
|
||
|
|
||
| # the adjoint (vector-Jacobian product) of an FFT with respect to its input is the | ||
| # inverse FFT of the gradient with respect to its output, but with a different | ||
| # normalization factor | ||
| @adjoint function FFTW.fft(xs) | ||
| # Pullback: inverse FFT of the incoming gradient Δ, scaled by length(xs). | ||
| fft_pullback(Δ) = (length(xs) * FFTW.ifft(Δ),) | ||
| return FFTW.fft(xs), fft_pullback | ||
| end | ||
|
|
||
| @adjoint function FFTW.ifft(xs) | ||
| # Pullback: forward FFT of the incoming gradient Δ, scaled by 1/length(xs). | ||
| ifft_pullback(Δ) = (1 / length(xs) * FFTW.fft(Δ),) | ||
| return FFTW.ifft(xs), ifft_pullback | ||
| end | ||
|
|
||
| @adjoint function FFTW.fft(xs, dims) | ||
| # Adjoint of an FFT over selected dimensions: the pullback is the inverse FFT of the | ||
| # incoming gradient Δ over the same `dims`, scaled by the number of elements along | ||
| # the transformed dimensions. `dims` itself receives no gradient (`nothing`). | ||
| # The scale is computed once, outside the pullback, and `dims` is never reassigned, | ||
| # so the closure does not capture a rebound variable. | ||
| # dims can be int, array or tuple, | ||
|
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. The indentation is off here, just needs a quick fix. |
||
| # so iterate over it directly instead of converting it to a collection | ||
| # we need to multiply by all dimensions that we FFT over | ||
| N = prod(Int[size(xs, d) for d in dims]) | ||
| return FFTW.fft(xs, dims), function(Δ) | ||
| return (N * FFTW.ifft(Δ, dims), nothing) | ||
| end | ||
| end | ||
|
|
||
| @adjoint function FFTW.ifft(xs, dims) | ||
| # Adjoint of an inverse FFT over selected dimensions: the pullback is the forward FFT | ||
| # of the incoming gradient Δ over the same `dims`, divided by the number of elements | ||
| # along the transformed dimensions. `dims` itself receives no gradient (`nothing`). | ||
| # The scale is computed once, outside the pullback, and `dims` is never reassigned, | ||
| # so the closure does not capture a rebound variable. | ||
| # dims can be int, array or tuple, so iterate over it directly | ||
| # we need to divide by all dimensions that we FFT over | ||
| N = prod(Int[size(xs, d) for d in dims]) | ||
| return FFTW.ifft(xs, dims), function(Δ) | ||
| return (1/N * FFTW.fft(Δ, dims), nothing) | ||
| end | ||
| end | ||
|
|
||
| # FillArray functionality | ||
| # ======================= | ||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -1,4 +1,4 @@ | ||
| using Zygote, NNlib, Test, Random, LinearAlgebra, Statistics, FillArrays | ||
| using Zygote, NNlib, Test, Random, LinearAlgebra, Statistics, FillArrays, FFTW | ||
| using Zygote: gradient | ||
| using NNlib: conv, ∇conv_data, depthwiseconv | ||
| using Base.Broadcast: broadcast_shape | ||
|
|
@@ -563,17 +563,38 @@ using Zygote: Buffer | |
| end | ||
|
|
||
| @testset "FillArrays" begin | ||
| # Wrap gradcheck in @test so that its result is actually asserted. | ||
| @test gradcheck(x->sum(Fill(x[], (2, 2))), [0.1]) | ||
| # The gradient with respect to the size argument is expected to be `nothing`. | ||
| @test first(Zygote.gradient(sz->sum(Ones(sz)), 6)) === nothing | ||
| @test first(Zygote.gradient(sz->sum(Zeros(sz)), 6)) === nothing | ||
| end | ||
|
|
||
| @testset "AbstractArray Addition / Subtraction / Negation" begin | ||
| # Seeded RNG keeps the random inputs reproducible. | ||
| rng, M, N, P = MersenneTwister(123567), 3, 7, 11 | ||
| A, B = randn(rng, M, N, P), randn(rng, M, N, P) | ||
| # Wrap gradtest in @test so that each result is actually asserted. | ||
| @test gradtest(+, A, B) | ||
| @test gradtest(-, A, B) | ||
| @test gradtest(-, A) | ||
| end | ||
|
|
||
| @testset "FFTW" begin | ||
| x=[-0.353213 -0.789656 -0.270151; -0.95719 -1.27933 0.223982] | ||
| # gradient of ifft(fft) and of fft(ifft) must be 1; compare with ≈ because the | ||
| # round trip goes through floating-point arithmetic | ||
| @test gradient((x)->real(ifft(fft(x))[1]),x)[1][1] ≈ 1.0+0.0im | ||
| @test gradient((x)->real(fft(ifft(x))[1]),x)[1][1] ≈ 1.0+0.0im | ||
|
|
||
| # check ffts for individual dimensions | ||
| @test gradient((x)->sum(abs.(FFTW.fft(x))),x)[1] ≈ gradient((x)->sum(abs.(FFTW.fft(FFTW.fft(x,1),2))),x)[1] | ||
|
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. It would be a bit simpler to use
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I added 4 gradchecks as well now. I think the old test still make sense since they check whether the FFTs along individual dims work and produce the same results as when doing all dimensions in one go. However if you think that the old tests are mainly bloating gradcheck.jl I can also ommit them |
||
| @test gradient((x)->abs(sum((FFTW.fft(x)))),x)[1] ≈ gradient((x)->abs(sum(FFTW.fft(FFTW.fft(x,1),2))),x)[1] | ||
| @test gradient((x, dims)->sum(abs.(FFTW.fft(x,dims))),x,(1,2))[1] ≈ gradient((x)->sum(abs.(FFTW.fft(x))),x)[1] | ||
| @test gradient((x)->sum(abs.(FFTW.fft(x,(1,2)))),x)[1] ≈ gradient((x)->sum(abs.(FFTW.fft(FFTW.fft(x,1),2))),x)[1] | ||
| @test gradient((x, dims)->sum(abs.(FFTW.ifft(x,dims))),x,(1,2))[1] ≈ gradient((x)->sum(abs.(FFTW.ifft(x))),x)[1] | ||
| @test gradient((x)->sum(abs.(FFTW.ifft(x,(1,2)))),x)[1] ≈ gradient((x)->sum(abs.(FFTW.ifft(FFTW.ifft(x,1),2))),x)[1] | ||
|
|
||
| # numerical gradient checks for full and single-dimension transforms | ||
| @test gradcheck(x->sum(abs.(FFTW.fft(x))), x) | ||
| @test gradcheck(x->sum(abs.(FFTW.ifft(x))), x) | ||
| @test gradcheck(x->sum(abs.(FFTW.fft(x, 1))), x) | ||
| @test gradcheck(x->sum(abs.(FFTW.ifft(x, 1))), x) | ||
|
|
||
| end | ||
|
|
||
| @testset "FillArrays" begin | ||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Can't you just depend on AbstractFFTs.jl and write the dispatches using the high level functions?