diff --git a/Project.toml b/Project.toml
index 5e1316e..8079e7e 100644
--- a/Project.toml
+++ b/Project.toml
@@ -13,6 +13,7 @@ Reexport = "189a3867-3050-52da-a836-e630ba90ab69"
 [compat]
 AbstractFFTs = "0.5"
 Conda = "1"
+FFTW_jll = "3.3"
 Reexport = "0.2"
 julia = "1.3"
 
diff --git a/src/FFTW.jl b/src/FFTW.jl
index afd5d92..e30b8d8 100644
--- a/src/FFTW.jl
+++ b/src/FFTW.jl
@@ -3,6 +3,7 @@ module FFTW
 using LinearAlgebra, Reexport
 import Libdl
 @reexport using AbstractFFTs
+using Base.Threads
 
 import AbstractFFTs: Plan, ScaledPlan,
                      fft, ifft, bfft, fft!, ifft!, bfft!,
@@ -26,7 +27,7 @@ end
 const fftw_vendor = occursin("libmkl_rt", libfftw3) ? :mkl : :fftw
 
 # Use Julia partr threading backend if present
-@static if fftw_vendor == :fftw && isdefined(Threads, Symbol("@spawn"))
+@static if fftw_vendor == :fftw
     # callback function that FFTW uses to launch `num` parallel
     # tasks (FFTW/fftw3#175):
     function spawnloop(f::Ptr{Cvoid}, fdata::Ptr{Cvoid}, elsize::Csize_t, num::Cint, callback_data::Ptr{Cvoid})
@@ -49,19 +50,18 @@ function __init__()
     if stat == 0 || statf == 0
         error("could not initialize FFTW threads")
     end
     @static if fftw_vendor == :fftw
-        if Threads.nthreads() > 1 # number of Julia threads is set when Julia is launched
-            ccall((:fftw_make_planner_thread_safe, libfftw3), Cvoid, ())
-            ccall((:fftwf_make_planner_thread_safe, libfftw3f), Cvoid, ())
-        end
-        @static if isdefined(Threads, Symbol("@spawn"))
-            if Threads.nthreads() > 1 # partr will give us our threads
-                cspawnloop = @cfunction(spawnloop, Cvoid, (Ptr{Cvoid}, Ptr{Cvoid}, Csize_t, Cint, Ptr{Cvoid}))
-                ccall((:fftw_threads_set_callback, libfftw3), Cvoid, (Ptr{Cvoid}, Ptr{Cvoid}), cspawnloop, C_NULL)
-                ccall((:fftwf_threads_set_callback, libfftw3f), Cvoid, (Ptr{Cvoid}, Ptr{Cvoid}), cspawnloop, C_NULL)
-                set_num_threads(Threads.nthreads() * 4) # spawn more tasks than threads to help load-balancing
-            end
-        end
+        # nthreads() must be tested at run time, not inside `@static`: `@static`
+        # is evaluated when the module is (pre)compiled, whereas the number of
+        # Julia threads is only set when Julia is launched.
+        if nthreads() > 1
+            ccall((:fftw_make_planner_thread_safe, libfftw3), Cvoid, ())
+            ccall((:fftwf_make_planner_thread_safe, libfftw3f), Cvoid, ())
+            cspawnloop = @cfunction(spawnloop, Cvoid, (Ptr{Cvoid}, Ptr{Cvoid}, Csize_t, Cint, Ptr{Cvoid}))
+            ccall((:fftw_threads_set_callback, libfftw3), Cvoid, (Ptr{Cvoid}, Ptr{Cvoid}), cspawnloop, C_NULL)
+            ccall((:fftwf_threads_set_callback, libfftw3f), Cvoid, (Ptr{Cvoid}, Ptr{Cvoid}), cspawnloop, C_NULL)
+            set_num_threads(nthreads() * 4) # spawn more tasks than threads to help load-balancing
+        end
     end
 end
 