diff --git a/Project.toml b/Project.toml index 8ae6220..734ab0a 100644 --- a/Project.toml +++ b/Project.toml @@ -6,6 +6,7 @@ version = "0.0.1" [deps] AbstractFFTs = "621f4979-c628-5d54-868e-fcf4e3e8185c" ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210" +FFTW = "7a1cc6ca-52ef-59f5-83cd-3a7055c09341" [compat] AbstractFFTs = "1" @@ -14,8 +15,8 @@ ForwardDiff = "0.10" julia = "1.6" [extras] +LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" -FFTW = "7a1cc6ca-52ef-59f5-83cd-3a7055c09341" [targets] -test = ["Test", "FFTW"] \ No newline at end of file +test = ["Test", "LinearAlgebra"] \ No newline at end of file diff --git a/src/FastTransformsForwardDiff.jl b/src/FastTransformsForwardDiff.jl index 3a47e42..25b9472 100644 --- a/src/FastTransformsForwardDiff.jl +++ b/src/FastTransformsForwardDiff.jl @@ -1,7 +1,9 @@ module FastTransformsForwardDiff -using ForwardDiff -import AbstractFFTs +using ForwardDiff, FFTW +using AbstractFFTs import ForwardDiff: value, partials, npartials, Dual, tagtype, derivative, jacobian, gradient +import AbstractFFTs: plan_fft, plan_ifft, plan_bfft, plan_rfft, plan_brfft, plan_irfft +import FFTW: r2r, r2r!, plan_r2r, mul!, Plan @inline tagtype(::Complex{T}) where T = tagtype(T) @inline tagtype(::Type{Complex{T}}) where T = tagtype(T) diff --git a/src/fft.jl b/src/fft.jl index 813d17e..262f762 100644 --- a/src/fft.jl +++ b/src/fft.jl @@ -1,3 +1,8 @@ +dual2array(x::Array{<:Dual{Tag,T}}) where {Tag,T} = reinterpret(reshape, T, x) +dual2array(x::Array{<:Complex{<:Dual{Tag, T}}}) where {Tag,T} = complex.(dual2array(real(x)), dual2array(imag(x))) +array2dual(DT::Type{<:Dual}, x::Array{T}) where T = reinterpret(reshape, DT, real(x)) +array2dual(DT::Type{<:Dual}, x::Array{<:Complex{T}}) where T = complex.(array2dual(DT, real(x)), array2dual(DT, imag(x))) + value(x::Complex{<:Dual}) = Complex(x.re.value, x.im.value) partials(x::Complex{<:Dual}, n::Int) = Complex(partials(x.re, n), partials(x.im, n)) @@ -12,70 +17,32 @@ AbstractFFTs.complexfloat(d::Dual{T,V,N}) where {T,V,N} = convert(Dual{T,float(V AbstractFFTs.realfloat(x::AbstractArray{<:Dual}) = AbstractFFTs.realfloat.(x) AbstractFFTs.realfloat(d::Dual{T,V,N}) where {T,V,N} = convert(Dual{T,float(V),N}, d) -for plan in [:plan_fft, :plan_ifft, :plan_bfft] +for plan in (:plan_fft, :plan_ifft, :plan_bfft, :plan_rfft) @eval begin - - AbstractFFTs.$plan(x::AbstractArray{<:Dual}, region=1:ndims(x)) = - AbstractFFTs.$plan(value.(x), region) - - AbstractFFTs.$plan(x::AbstractArray{<:Complex{<:Dual}}, region=1:ndims(x)) = - AbstractFFTs.$plan(value.(x), region) - + $plan(x::AbstractArray{<:Dual}, dims=1:ndims(x)) = $plan(dual2array(x), 1 .+ dims) + $plan(x::AbstractArray{<:Complex{<:Dual}}, dims=1:ndims(x)) = $plan(dual2array(x), 1 .+ dims) end end -# rfft only accepts real arrays -AbstractFFTs.plan_rfft(x::AbstractArray{<:Dual}, region=1:ndims(x)) = - AbstractFFTs.plan_rfft(value.(x), region) +plan_r2r(x::AbstractArray{<:Dual}, FLAG, dims=1:ndims(x)) = plan_r2r(dual2array(x), FLAG, 1 .+ dims) +plan_r2r(x::AbstractArray{<:Complex{<:Dual}}, FLAG, dims=1:ndims(x)) = plan_r2r(dual2array(x), FLAG, 1 .+ dims) -for plan in [:plan_irfft, :plan_brfft] # these take an extra argument, only when complex? +for plan in (:plan_irfft, :plan_brfft) # these take an extra argument, only when complex? @eval begin - - AbstractFFTs.$plan(x::AbstractArray{<:Dual}, region=1:ndims(x)) = - AbstractFFTs.$plan(value.(x), region) - - AbstractFFTs.$plan(x::AbstractArray{<:Complex{<:Dual}}, d::Integer, region=1:ndims(x)) = - AbstractFFTs.$plan(value.(x), d, region) - + $plan(x::AbstractArray{<:Dual}, dims=1:ndims(x)) = $plan(dual2array(x), 1 .+ dims) + $plan(x::AbstractArray{<:Complex{<:Dual}}, d::Integer, dims=1:ndims(x)) = $plan(dual2array(x), d, 1 .+ dims) end end -# for f in (:dct, :idct) -# pf = Symbol("plan_", f) -# @eval begin -# AbstractFFTs.$f(x::AbstractArray{<:Dual}) = $pf(x) * x -# AbstractFFTs.$f(x::AbstractArray{<:Dual}, region) = $pf(x, region) * x -# AbstractFFTs.$pf(x::AbstractArray{<:Dual}, region; kws...) = $pf(value.(x), region; kws...) -# AbstractFFTs.$pf(x::AbstractArray{<:Complex}, region; kws...) = $pf(value.(x), region; kws...) -# end -# end +r2r(x::AbstractArray{<:Dual}, kinds, region...) = plan_r2r(x, kinds, region...) * x +r2r(x::AbstractArray{<:Complex{<:Dual}}, kinds, region...) = plan_r2r(x, kinds, region...) * x -for P in [:Plan, :ScaledPlan] # need ScaledPlan to avoid ambiguities +for P in (:Plan, :ScaledPlan) # need ScaledPlan to avoid ambiguities @eval begin - - Base.:*(p::AbstractFFTs.$P, x::AbstractArray{<:Dual}) = - _apply_plan(p, x) - - Base.:*(p::AbstractFFTs.$P, x::AbstractArray{<:Complex{<:Dual}}) = - _apply_plan(p, x) - + Base.:*(p::AbstractFFTs.$P, x::AbstractArray{DT}) where DT<:Dual = array2dual(DT, p * dual2array(x)) + Base.:*(p::AbstractFFTs.$P, x::AbstractArray{<:Complex{DT}}) where DT<:Dual = array2dual(DT, p * dual2array(x)) end end -function _apply_plan(p::AbstractFFTs.Plan, x::AbstractArray) - xtil = p * value.(x) - dxtils = ntuple(npartials(eltype(x))) do n - p * partials.(x, n) - end - __apply_plan(tagtype(eltype(x)), xtil, dxtils) -end - -function __apply_plan(T, xtil, dxtils) - map(xtil, dxtils...) do val, parts... - Complex( - Dual{T}(real(val), map(real, parts)), - Dual{T}(imag(val), map(imag, parts)), - ) - end -end \ No newline at end of file +mul!(y::AbstractArray{<:Dual}, p::Plan, x::AbstractArray{<:Dual}) = copyto!(y, p*x) \ No newline at end of file diff --git a/test/runtests.jl b/test/runtests.jl index 537852d..9d1dce2 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -1,24 +1,36 @@ -using FastTransformsForwardDiff, FFTW, Test +using FastTransformsForwardDiff, FFTW, LinearAlgebra, Test using ForwardDiff: Dual, valtype, value, partials, derivative using AbstractFFTs: complexfloat, realfloat +@testset "complex dual" begin + x = Dual(1., 2., 3.) + im*Dual(4.,5.,6.) + @test value(x) == 1 + 4im + @test partials(x,1) == 2 + 5im + @test partials(x,2) == 3 + 6im +end @testset "fft and rfft" begin x1 = Dual.(1:4.0, 2:5, 3:6) @test value.(x1) == 1:4 @test partials.(x1, 1) == 2:5 + @test partials.(x1, 2) == 3:6 @test complexfloat(x1)[1] === complexfloat(x1[1]) === Dual(1.0, 2.0, 3.0) + 0im @test realfloat(x1)[1] === realfloat(x1[1]) === Dual(1.0, 2.0, 3.0) @test fft(x1, 1)[1] isa Complex{<:Dual} - @testset "$f" for f in [fft, ifft, rfft, bfft] + @testset "$f" for f in (fft, ifft, rfft, bfft) @test value.(f(x1)) == f(value.(x1)) @test partials.(f(x1), 1) == f(partials.(x1, 1)) + @test partials.(f(x1), 2) == f(partials.(x1, 2)) end + @test ifft(fft(x1)) == x1 + @test irfft(rfft(x1), length(x1)) ≈ x1 + @test brfft(rfft(x1), length(x1)) ≈ 4x1 + f = x -> real(fft([x; 0; 0])[1]) @test derivative(f,0.1) ≈ 1 @@ -33,4 +45,38 @@ using AbstractFFTs: complexfloat, realfloat # c = x -> dct([x; 0; 0])[1] # @test derivative(c,0.1) ≈ 1 + + @testset "matrix" begin + A = x1 * (1:10)' + @test value.(fft(A)) == fft(value.(A)) + @test partials.(fft(A), 1) == fft(partials.(A, 1)) + @test partials.(fft(A), 2) == fft(partials.(A, 2)) + + @test value.(fft(A, 1)) == fft(value.(A), 1) + @test partials.(fft(A, 1), 1) == fft(partials.(A, 1), 1) + @test partials.(fft(A, 1), 2) == fft(partials.(A, 2), 1) + + @test value.(fft(A, 2)) == fft(value.(A), 2) + @test partials.(fft(A, 2), 1) == fft(partials.(A, 1), 2) + @test partials.(fft(A, 2), 2) == fft(partials.(A, 2), 2) + end +end + +@testset "r2r" begin + x1 = Dual.(1:4.0, 2:5, 3:6) + t = FFTW.r2r(x1, FFTW.R2HC) + + @test value.(t) == FFTW.r2r(value.(x1), FFTW.R2HC) + @test partials.(t, 1) == FFTW.r2r(partials.(x1, 1), FFTW.R2HC) + @test partials.(t, 2) == FFTW.r2r(partials.(x1, 2), FFTW.R2HC) + + t = FFTW.r2r(x1 + 2im*x1, FFTW.R2HC) + @test value.(t) == FFTW.r2r(value.(x1 + 2im*x1), FFTW.R2HC) + @test partials.(t, 1) == FFTW.r2r(partials.(x1 + 2im*x1, 1), FFTW.R2HC) + @test partials.(t, 2) == FFTW.r2r(partials.(x1 + 2im*x1, 2), FFTW.R2HC) + + f = ω -> FFTW.r2r([ω; zeros(9)], FFTW.R2HC)[1] + @test derivative(f, 0.1) ≡ 1.0 + + @test mul!(similar(x1), FFTW.plan_r2r(x1, FFTW.R2HC), x1) == FFTW.r2r(x1, FFTW.R2HC) end \ No newline at end of file