definitions for adjoints of FFTW functions #215

Merged
merged 21 commits into from Jul 8, 2019
@sipposip sipposip commented Jun 4, 2019

Adjoint definitions and accompanying tests for the most important functions of FFTW (fft and ifft).

Since deriving the adjoint of FFTs is a bit tricky, I double-checked the results by differentiating a simple DFT implementation that ReverseDiff can differentiate, and the results agree to machine precision. This is possible because an FFT is just a more efficient implementation of the DFT, which itself boils down to a simple matrix multiplication.

I am adding this comparison as a comment here for reference:

using ReverseDiff
using FFTW
using Test
using Zygote
using FillArrays: Fill


function dft(x::AbstractArray)
    # discrete Fourier transform (unnormalized, naive O(N^2) sum)
    N = size(x)[1]
    out = Array{Any}(undef,N)
    for k in 0:N-1
        out[k+1] = sum([x[n+1]*exp(-2*im*π*k*n/N) for n in 0:N-1])
    end
    return out
end

function idft(x::AbstractArray)
    # inverse discrete Fourier transform (normalized by 1/N)
    N = size(x)[1]
    out = Array{Any}(undef,N)
    for n in 0:N-1
        out[n+1] = 1/N*sum([x[k+1]*exp(2*im*π*k*n/N) for k in 0:N-1])
    end
    return out
end


x = randn(100)
@test dft(x) ≈ FFTW.fft(x)
@test idft(x) ≈ FFTW.ifft(x)


# FFTW functions do not work with FillArray. To make it work with FillArrays
# as well, overload the functions
FFTW.fft(x::Fill) = FFTW.fft(collect(x))
FFTW.ifft(x::Fill) = FFTW.ifft(collect(x))

# the gradient of an FFT with respect to its input is the inverse FFT of the
# gradient with respect to its output, scaled by N (and vice versa for ifft).
Zygote.@adjoint function FFTW.fft(xs)
    return FFTW.fft(xs), function(Δ)
        N = length(Δ)
        return (N * FFTW.ifft(Δ),)
    end
end

Zygote.@adjoint function FFTW.ifft(xs)
    return FFTW.ifft(xs), function(Δ)
        N = length(Δ)
        return (1/N* FFTW.fft(Δ),)
    end
end


@test real(Zygote.gradient((x)->sum(abs.(FFTW.fft(x))),x)[1]) ≈ ReverseDiff.gradient((x)->sum(abs.(dft(x))),x)
@test real(Zygote.gradient((x)->sum(abs.(FFTW.ifft(x))),x)[1]) ≈ ReverseDiff.gradient((x)->sum(abs.(idft(x))),x)
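
For reference, a short sketch of why the two pullbacks take this form. It assumes the unnormalized forward convention used by dft above; the symbols F (the N×N DFT matrix), Δ (the incoming cotangent) and the barred quantities are notation introduced here, not taken from the PR.

```latex
\begin{aligned}
y &= F x, \qquad F_{kn} = e^{-2\pi i\,kn/N}, \qquad F^{H} F = N I \\
\operatorname{ifft}(y) &= \tfrac{1}{N} F^{H} y \\
\bar{x} &= F^{H} \Delta = N \cdot \operatorname{ifft}(\Delta)
  && \text{(pullback of fft)} \\
\bar{y} &= \left(\tfrac{1}{N} F^{H}\right)^{H} \Delta = \tfrac{1}{N} F \Delta
  = \tfrac{1}{N} \operatorname{fft}(\Delta)
  && \text{(pullback of ifft)}
\end{aligned}
```

The pullback of a linear map is its conjugate transpose, which is where the factors N and 1/N in the two adjoint definitions come from.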

Member

@MikeInnes MikeInnes left a comment

Thanks! Just a couple comments but this is great to have.

src/lib/fft.jl Outdated
# FFTW functions do not work with FillArrays, which are needed
# for some functionality of Zygote. To make it work with FillArrays
# as well, overload the relevant functions
FFTW.fft(x::Fill) = FFTW.fft(collect(x))
Member

You can use dims... to make this more compact.

Contributor Author

thanks for the tip, I changed to dims... now

Project.toml Outdated
@@ -16,7 +16,7 @@ Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
Requires = "ae029012-a4dd-5104-9daa-d747884805df"
SpecialFunctions = "276daf66-3868-5448-9aa4-cd146d93841b"
Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"

FFTW = "7a1cc6ca-52ef-59f5-83cd-3a7055c09341"
Member

Please use the package manager to add FFTW, so that the formatting is right and so that the manifest gets updated too.

Contributor Author

fixed now via package manager

src/Zygote.jl Outdated
@@ -35,6 +35,7 @@ include("lib/nnlib.jl")
include("lib/broadcast.jl")
include("lib/forward.jl")
include("lib/utils.jl")
include("lib/fft.jl")
Member

I think this file is short enough that it could just go in array.jl, after linear algebra.

Contributor Author

I moved it to array.jl now

src/lib/fft.jl Outdated
# if it is not a single int, convert to array so that we can use it
# for indexing
if typeof(dims) != Int
dims = collect(dims)
Member

I suspect this is unnecessary, since numbers act like single-element collections anyway for things like prod.

Contributor Author

The special case for Int was necessary because size(xs)[collect(1)] does not work. I have now changed it to collect(size(xs))[dims], which works, seems more elegant to me, and needs no separate handling for Ints.
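
To illustrate the point about numbers behaving like single-element collections, here is a minimal sketch using only Base. The helper name `fft_scale` is hypothetical and this is not the code from the PR; it only shows one way the scaling factor N could be computed uniformly for an integer, tuple, vector or range of dimensions.

```julia
# Hypothetical helper (not from the PR): the number of elements an FFT over
# `dims` runs across, i.e. the factor N in the adjoint.
# Iterating over an Int yields that Int once, so no special case is needed.
fft_scale(xs, dims) = prod(size(xs, d) for d in dims)

xs = zeros(4, 5, 6)

println(fft_scale(xs, 2))       # single dimension given as an Int   -> 5
println(fft_scale(xs, (1, 3)))  # dimensions given as a Tuple        -> 24
println(fft_scale(xs, [1, 3]))  # dimensions given as a Vector       -> 24
println(fft_scale(xs, 1:3))     # all dimensions given as a range    -> 120
```

Using `size(xs, d)` per dimension avoids indexing the size tuple with a collection, which is the step that needed the special case.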

src/lib/fft.jl Outdated
end

@adjoint function FFTW.fft(xs, dims)
# up to now only works when dims is a single integer
Member

This comment isn't clear to me; you seem to be handling multiple dimensions below. If only single integers work, best to add a type restriction here.

Contributor Author

this comment was from an old version and I forgot to remove it, sorry for the confusion. It indeed works also with multiple dimensions. I removed the comment now

@test gradient((x)->real(fft(ifft(x))[1]),x)[1][1] == 1.0+0.0im

# check ffts for individual dimensions
@test gradient((x)->sum(abs.(FFTW.fft(x))),x)[1] ≈ gradient((x)->sum(abs.(FFTW.fft(FFTW.fft(x,1),2))),x)[1]
Member

It would be a bit simpler to use gradcheck here instead. That will check against numerical gradients, which is handy for correctness.

Contributor Author

I have added 4 gradchecks as well now. I think the old tests still make sense, since they check whether the FFTs along individual dims work and produce the same results as doing all dimensions in one go. However, if you think the old tests are mainly bloating gradcheck.jl, I can also omit them.
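
As a rough sketch of what a check against numerical gradients looks like for this rule, the snippet below compares the adjoint formula with central finite differences. It is not Zygote's gradcheck: it uses a dense DFT matrix in place of FFTW, and the helpers `dft_matrix`, `analytic_gradient` and `numeric_gradient` are hypothetical names introduced only for illustration.

```julia
using LinearAlgebra

# Dense DFT matrix, same sign convention as the unnormalized forward transform.
dft_matrix(N) = [exp(-2im * π * k * n / N) for k in 0:N-1, n in 0:N-1]

# Loss used in the tests above: sum of the magnitudes of the spectrum.
loss(F, x) = sum(abs.(F * x))

# Analytic gradient via the adjoint rule: the cotangent of sum(abs.(y)) is
# Δ = y ./ abs.(y), and it is pulled back through y = F * x with F'.
function analytic_gradient(F, x)
    y = F * x
    Δ = y ./ abs.(y)
    return real.(F' * Δ)
end

# Central finite differences, one coordinate at a time (assumed step size h).
function numeric_gradient(f, x; h = 1e-6)
    g = similar(x)
    for i in eachindex(x)
        e = zeros(length(x))
        e[i] = h
        g[i] = (f(x + e) - f(x - e)) / (2h)
    end
    return g
end

N = 8
F = dft_matrix(N)
x = collect(1.0:N) .^ 2 ./ N          # deterministic real test input

g_analytic = analytic_gradient(F, x)
g_numeric = numeric_gradient(z -> loss(F, z), x)

println(isapprox(g_analytic, g_numeric; rtol = 1e-5))
```

Since F' equals N times the inverse DFT matrix, `F' * Δ` here plays the role of `N * FFTW.ifft(Δ)` in the adjoint definition.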


@adjoint function FFTW.fft(xs, dims)
return FFTW.fft(xs, dims), function(Δ)
# dims can be int, array or tuple,
Member

The indentation is off here, just needs a quick fix.

src/lib/array.jl Outdated

@adjoint function FFTW.fft(xs, dims)
return FFTW.fft(xs, dims), function(Δ)
# dims can be int, array or tuple,

nit: indent these few lines?

Contributor Author

I (hopefully) fixed the indentation now.

@MikeInnes
Member

Alright, great, this looks good.

There's a merge conflict due to the manifest; it'd be great if you could resolve that, and even better if you can use a newer Julia version to minimise the number of changes while you're at it. Other than that I think we're good to go.

@sipposip
Contributor Author

sipposip commented Jun 7, 2019

OK. I am very new both to Julia and to collaborating on GitHub, but I will do my best to resolve the conflicts.

@sipposip
Contributor Author

sipposip commented Jun 7, 2019

I have now switched to Julia 1.1.1 and resolved the conflicts. However, for some reason my manifest has more entries than the current one in master (AbstractFFTs, Conda, JSON, Reexport, VersionParsing), in addition to FFTW, which I added via the package manager. Should I remove them?

@MikeInnes
Member

No, that's ok, they are just dependencies of FFTW that got pulled in.

Unfortunately the tests are now failing; looks like you just have a syntax error in the tests you added.

@sipposip
Contributor Author

I now fixed the tests, and they pass on my local installation. I also fixed the testset for hcat which had failed as well, and some other fixes in other testsets (forgotten @test)

@sipposip
Contributor Author

The Travis CI build has finished, but this does not seem to have been reported to GitHub (when I check https://travis-ci.org/FluxML/Zygote.jl/builds/543743653?utm_source=github_status&utm_medium=notification all tests have passed, but here the status is still pending). Is there any way to re-trigger the Travis CI build?

@@ -4,6 +4,7 @@ version = "0.3.1"

[deps]
DiffRules = "b552c78f-8df3-52c6-915a-8e097449b14b"
FFTW = "7a1cc6ca-52ef-59f5-83cd-3a7055c09341"
Member

Can't you just depend on AbstractFFTs.jl and write the dispatches using the high level functions?

@MikeInnes
Member

👍 Thanks!

@MikeInnes MikeInnes merged commit c08fca6 into FluxML:master Jul 8, 2019
@ChrisRackauckas
Member

But can't this be done with just AbstractFFTs?

@MikeInnes
Member

Probably, but fine to do that as a follow up. The main issue is testing but we can have FFTW as a test dependency.

4 participants