Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 3 additions & 2 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ version = "0.0.1"
[deps]
AbstractFFTs = "621f4979-c628-5d54-868e-fcf4e3e8185c"
ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
FFTW = "7a1cc6ca-52ef-59f5-83cd-3a7055c09341"

[compat]
AbstractFFTs = "1"
Expand All @@ -14,8 +15,8 @@ ForwardDiff = "0.10"
julia = "1.6"

[extras]
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
FFTW = "7a1cc6ca-52ef-59f5-83cd-3a7055c09341"

[targets]
test = ["Test", "FFTW"]
test = ["Test", "LinearAlgebra"]
6 changes: 4 additions & 2 deletions src/FastTransformsForwardDiff.jl
Original file line number Diff line number Diff line change
@@ -1,7 +1,9 @@
module FastTransformsForwardDiff
using ForwardDiff
import AbstractFFTs
using ForwardDiff, FFTW
using AbstractFFTs
import ForwardDiff: value, partials, npartials, Dual, tagtype, derivative, jacobian, gradient
import AbstractFFTs: plan_fft, plan_ifft, plan_bfft, plan_rfft, plan_brfft, plan_irfft
import FFTW: r2r, r2r!, plan_r2r, mul!, Plan

@inline tagtype(::Complex{T}) where T = tagtype(T)
@inline tagtype(::Type{Complex{T}}) where T = tagtype(T)
Expand Down
71 changes: 19 additions & 52 deletions src/fft.jl
Original file line number Diff line number Diff line change
@@ -1,3 +1,8 @@
dual2array(x::Array{<:Dual{Tag,T}}) where {Tag,T} = reinterpret(reshape, T, x)
dual2array(x::Array{<:Complex{<:Dual{Tag, T}}}) where {Tag,T} = complex.(dual2array(real(x)), dual2array(imag(x)))
array2dual(DT::Type{<:Dual}, x::Array{T}) where T = reinterpret(reshape, DT, real(x))
array2dual(DT::Type{<:Dual}, x::Array{<:Complex{T}}) where T = complex.(array2dual(DT, real(x)), array2dual(DT, imag(x)))

value(x::Complex{<:Dual}) = Complex(x.re.value, x.im.value)

partials(x::Complex{<:Dual}, n::Int) = Complex(partials(x.re, n), partials(x.im, n))
Expand All @@ -12,70 +17,32 @@ AbstractFFTs.complexfloat(d::Dual{T,V,N}) where {T,V,N} = convert(Dual{T,float(V
AbstractFFTs.realfloat(x::AbstractArray{<:Dual}) = AbstractFFTs.realfloat.(x)
AbstractFFTs.realfloat(d::Dual{T,V,N}) where {T,V,N} = convert(Dual{T,float(V),N}, d)

for plan in [:plan_fft, :plan_ifft, :plan_bfft]
for plan in (:plan_fft, :plan_ifft, :plan_bfft, :plan_rfft)
@eval begin

AbstractFFTs.$plan(x::AbstractArray{<:Dual}, region=1:ndims(x)) =
AbstractFFTs.$plan(value.(x), region)

AbstractFFTs.$plan(x::AbstractArray{<:Complex{<:Dual}}, region=1:ndims(x)) =
AbstractFFTs.$plan(value.(x), region)

$plan(x::AbstractArray{<:Dual}, dims=1:ndims(x)) = $plan(dual2array(x), 1 .+ dims)
$plan(x::AbstractArray{<:Complex{<:Dual}}, dims=1:ndims(x)) = $plan(dual2array(x), 1 .+ dims)
end
end

# rfft only accepts real arrays
AbstractFFTs.plan_rfft(x::AbstractArray{<:Dual}, region=1:ndims(x)) =
AbstractFFTs.plan_rfft(value.(x), region)
plan_r2r(x::AbstractArray{<:Dual}, FLAG, dims=1:ndims(x)) = plan_r2r(dual2array(x), FLAG, 1 .+ dims)
plan_r2r(x::AbstractArray{<:Complex{<:Dual}}, FLAG, dims=1:ndims(x)) = plan_r2r(dual2array(x), FLAG, 1 .+ dims)

for plan in [:plan_irfft, :plan_brfft] # these take an extra argument, only when complex?
for plan in (:plan_irfft, :plan_brfft) # these take an extra argument, only when complex?
@eval begin

AbstractFFTs.$plan(x::AbstractArray{<:Dual}, region=1:ndims(x)) =
AbstractFFTs.$plan(value.(x), region)

AbstractFFTs.$plan(x::AbstractArray{<:Complex{<:Dual}}, d::Integer, region=1:ndims(x)) =
AbstractFFTs.$plan(value.(x), d, region)

$plan(x::AbstractArray{<:Dual}, dims=1:ndims(x)) = $plan(dual2array(x), 1 .+ dims)
$plan(x::AbstractArray{<:Complex{<:Dual}}, d::Integer, dims=1:ndims(x)) = $plan(dual2array(x), d, 1 .+ dims)
end
end

# for f in (:dct, :idct)
# pf = Symbol("plan_", f)
# @eval begin
# AbstractFFTs.$f(x::AbstractArray{<:Dual}) = $pf(x) * x
# AbstractFFTs.$f(x::AbstractArray{<:Dual}, region) = $pf(x, region) * x
# AbstractFFTs.$pf(x::AbstractArray{<:Dual}, region; kws...) = $pf(value.(x), region; kws...)
# AbstractFFTs.$pf(x::AbstractArray{<:Complex}, region; kws...) = $pf(value.(x), region; kws...)
# end
# end
r2r(x::AbstractArray{<:Dual}, kinds, region...) = plan_r2r(x, kinds, region...) * x
r2r(x::AbstractArray{<:Complex{<:Dual}}, kinds, region...) = plan_r2r(x, kinds, region...) * x


for P in [:Plan, :ScaledPlan] # need ScaledPlan to avoid ambiguities
for P in (:Plan, :ScaledPlan) # need ScaledPlan to avoid ambiguities
@eval begin

Base.:*(p::AbstractFFTs.$P, x::AbstractArray{<:Dual}) =
_apply_plan(p, x)

Base.:*(p::AbstractFFTs.$P, x::AbstractArray{<:Complex{<:Dual}}) =
_apply_plan(p, x)

Base.:*(p::AbstractFFTs.$P, x::AbstractArray{DT}) where DT<:Dual = array2dual(DT, p * dual2array(x))
Base.:*(p::AbstractFFTs.$P, x::AbstractArray{<:Complex{DT}}) where DT<:Dual = array2dual(DT, p * dual2array(x))
end
end

function _apply_plan(p::AbstractFFTs.Plan, x::AbstractArray)
xtil = p * value.(x)
dxtils = ntuple(npartials(eltype(x))) do n
p * partials.(x, n)
end
__apply_plan(tagtype(eltype(x)), xtil, dxtils)
end

function __apply_plan(T, xtil, dxtils)
map(xtil, dxtils...) do val, parts...
Complex(
Dual{T}(real(val), map(real, parts)),
Dual{T}(imag(val), map(imag, parts)),
)
end
end
mul!(y::AbstractArray{<:Dual}, p::Plan, x::AbstractArray{<:Dual}) = copyto!(y, p*x)
50 changes: 48 additions & 2 deletions test/runtests.jl
Original file line number Diff line number Diff line change
@@ -1,24 +1,36 @@
using FastTransformsForwardDiff, FFTW, Test
using FastTransformsForwardDiff, FFTW, LinearAlgebra, Test
using ForwardDiff: Dual, valtype, value, partials, derivative
using AbstractFFTs: complexfloat, realfloat

@testset "complex dual" begin
x = Dual(1., 2., 3.) + im*Dual(4.,5.,6.)
@test value(x) == 1 + 4im
@test partials(x,1) == 2 + 5im
@test partials(x,2) == 3 + 6im
end

@testset "fft and rfft" begin
x1 = Dual.(1:4.0, 2:5, 3:6)

@test value.(x1) == 1:4
@test partials.(x1, 1) == 2:5
@test partials.(x1, 2) == 3:6

@test complexfloat(x1)[1] === complexfloat(x1[1]) === Dual(1.0, 2.0, 3.0) + 0im
@test realfloat(x1)[1] === realfloat(x1[1]) === Dual(1.0, 2.0, 3.0)

@test fft(x1, 1)[1] isa Complex{<:Dual}

@testset "$f" for f in [fft, ifft, rfft, bfft]
@testset "$f" for f in (fft, ifft, rfft, bfft)
@test value.(f(x1)) == f(value.(x1))
@test partials.(f(x1), 1) == f(partials.(x1, 1))
@test partials.(f(x1), 2) == f(partials.(x1, 2))
end

@test ifft(fft(x1)) == x1
@test irfft(rfft(x1), length(x1)) ≈ x1
@test brfft(rfft(x1), length(x1)) ≈ 4x1

f = x -> real(fft([x; 0; 0])[1])
@test derivative(f,0.1) ≈ 1

Expand All @@ -33,4 +45,38 @@ using AbstractFFTs: complexfloat, realfloat

# c = x -> dct([x; 0; 0])[1]
# @test derivative(c,0.1) ≈ 1

@testset "matrix" begin
A = x1 * (1:10)'
@test value.(fft(A)) == fft(value.(A))
@test partials.(fft(A), 1) == fft(partials.(A, 1))
@test partials.(fft(A), 2) == fft(partials.(A, 2))

@test value.(fft(A, 1)) == fft(value.(A), 1)
@test partials.(fft(A, 1), 1) == fft(partials.(A, 1), 1)
@test partials.(fft(A, 1), 2) == fft(partials.(A, 2), 1)

@test value.(fft(A, 2)) == fft(value.(A), 2)
@test partials.(fft(A, 2), 1) == fft(partials.(A, 1), 2)
@test partials.(fft(A, 2), 2) == fft(partials.(A, 2), 2)
end
end

@testset "r2r" begin
x1 = Dual.(1:4.0, 2:5, 3:6)
t = FFTW.r2r(x1, FFTW.R2HC)

@test value.(t) == FFTW.r2r(value.(x1), FFTW.R2HC)
@test partials.(t, 1) == FFTW.r2r(partials.(x1, 1), FFTW.R2HC)
@test partials.(t, 2) == FFTW.r2r(partials.(x1, 2), FFTW.R2HC)

t = FFTW.r2r(x1 + 2im*x1, FFTW.R2HC)
@test value.(t) == FFTW.r2r(value.(x1 + 2im*x1), FFTW.R2HC)
@test partials.(t, 1) == FFTW.r2r(partials.(x1 + 2im*x1, 1), FFTW.R2HC)
@test partials.(t, 2) == FFTW.r2r(partials.(x1 + 2im*x1, 2), FFTW.R2HC)

f = ω -> FFTW.r2r([ω; zeros(9)], FFTW.R2HC)[1]
@test derivative(f, 0.1) ≡ 1.0

@test mul!(similar(x1), FFTW.plan_r2r(x1, FFTW.R2HC), x1) == FFTW.r2r(x1, FFTW.R2HC)
end