Skip to content

Commit

Permalink
Merge #756
Browse files Browse the repository at this point in the history
756: Call the correct function for irfft(xs,d,dims) r=CarloLucibello a=wkearn

Fixes #755

Co-authored-by: William Kearney <William.Kearney.ctr@nrlssc.navy.mil>
  • Loading branch information
bors[bot] and wkearn committed Aug 8, 2020
2 parents 86d1dd5 + 876de6d commit 956575e
Show file tree
Hide file tree
Showing 2 changed files with 7 additions and 9 deletions.
8 changes: 4 additions & 4 deletions src/lib/array.jl
Original file line number Diff line number Diff line change
Expand Up @@ -931,16 +931,16 @@ end
end

@adjoint function irfft(xs, d, dims)
return AbstractFFTs.ifft(xs, dims), function(Δ)
return AbstractFFTs.irfft(xs, d, dims), function(Δ)
dims = collect(dims)
N = prod(collect(size(xs))[dims])
return (AbstractFFTs.rfft(Δ, dims)/N, nothing, nothing)
return (AbstractFFTs.rfft(real.(Δ), dims)/N, nothing, nothing)
end
end
@adjoint function brfft(xs, d, dims)
return AbstractFFTs.ifft(xs, dims), function(Δ)
return AbstractFFTs.brfft(xs, d, dims), function(Δ)
dims = collect(dims)
return (AbstractFFTs.rfft(Δ, dims), nothing, nothing)
return (AbstractFFTs.rfft(real.(Δ), dims), nothing, nothing)
end
end

Expand Down
8 changes: 3 additions & 5 deletions test/gradcheck.jl
Original file line number Diff line number Diff line change
Expand Up @@ -1453,8 +1453,7 @@ end

x = randn(Float64,16,16)
@test typeof(gradient(x->sum(abs2,ifft(fft(x,1),1)),x)[1]) == Array{Complex{Float64},2}
# This errors: something is the wrong size
#@test typeof(gradient(x->sum(abs2,irfft(rfft(x,1),16,1)),x)[1]) == Array{Float64,2}
@test typeof(gradient(x->sum(abs2,irfft(rfft(x,1),16,1)),x)[1]) == Array{Float64,2}

x = randn(Float32,16)
P = plan_fft(x)
Expand All @@ -1463,9 +1462,8 @@ end
@test typeof(gradient(x->sum(abs2,irfft(rfft(x),16)),x)[1]) == Array{Float32,1}

x = randn(Float32,16,16)
@test typeof(gradient(x->sum(abs2,ifft(fft(x,1),1)),x)[1]) == Array{Complex{Float32},2}
# This errors: something is the wrong size
#@test typeof(gradient(x->sum(abs2,irfft(rfft(x,1),16,1)),x)[1]) == Array{Float32,2}
@test typeof(gradient(x->sum(abs2,ifft(fft(x,1),1)),x)[1]) == Array{Complex{Float32},2}
@test typeof(gradient(x->sum(abs2,irfft(rfft(x,1),16,1)),x)[1]) == Array{Float32,2}
end

@testset "FillArrays" begin
Expand Down

0 comments on commit 956575e

Please sign in to comment.