From 876de6d793f0d27d5b2b245992cbb6b625fb9ad7 Mon Sep 17 00:00:00 2001 From: William Kearney Date: Mon, 3 Aug 2020 10:42:41 -0500 Subject: [PATCH] Call the correct function for irfft(xs,d,dims) Fixes #755 --- src/lib/array.jl | 8 ++++---- test/gradcheck.jl | 8 +++----- 2 files changed, 7 insertions(+), 9 deletions(-) diff --git a/src/lib/array.jl b/src/lib/array.jl index 563e5076f..48b1c134f 100644 --- a/src/lib/array.jl +++ b/src/lib/array.jl @@ -931,16 +931,16 @@ end end @adjoint function irfft(xs, d, dims) - return AbstractFFTs.ifft(xs, dims), function(Δ) + return AbstractFFTs.irfft(xs, d, dims), function(Δ) dims = collect(dims) N = prod(collect(size(xs))[dims]) - return (AbstractFFTs.rfft(Δ, dims)/N, nothing, nothing) + return (AbstractFFTs.rfft(real.(Δ), dims)/N, nothing, nothing) end end @adjoint function brfft(xs, d, dims) - return AbstractFFTs.ifft(xs, dims), function(Δ) + return AbstractFFTs.brfft(xs, d, dims), function(Δ) dims = collect(dims) - return (AbstractFFTs.rfft(Δ, dims), nothing, nothing) + return (AbstractFFTs.rfft(real.(Δ), dims), nothing, nothing) end end diff --git a/test/gradcheck.jl b/test/gradcheck.jl index b07837214..e120293d7 100644 --- a/test/gradcheck.jl +++ b/test/gradcheck.jl @@ -1453,8 +1453,7 @@ end x = randn(Float64,16,16) @test typeof(gradient(x->sum(abs2,ifft(fft(x,1),1)),x)[1]) == Array{Complex{Float64},2} - # This errors: something is the wrong size - #@test typeof(gradient(x->sum(abs2,irfft(rfft(x,1),16,1)),x)[1]) == Array{Float64,2} + @test typeof(gradient(x->sum(abs2,irfft(rfft(x,1),16,1)),x)[1]) == Array{Float64,2} x = randn(Float32,16) P = plan_fft(x) @@ -1463,9 +1462,8 @@ end @test typeof(gradient(x->sum(abs2,irfft(rfft(x),16)),x)[1]) == Array{Float32,1} x = randn(Float32,16,16) - @test typeof(gradient(x->sum(abs2,ifft(fft(x,1),1)),x)[1]) == Array{Complex{Float32},2} - # This errors: something is the wrong size - #@test typeof(gradient(x->sum(abs2,irfft(rfft(x,1),16,1)),x)[1]) == Array{Float32,2} + @test typeof(gradient(x->sum(abs2,ifft(fft(x,1),1)),x)[1]) == Array{Complex{Float32},2} + @test typeof(gradient(x->sum(abs2,irfft(rfft(x,1),16,1)),x)[1]) == Array{Float32,2} end @testset "FillArrays" begin