Skip to content

Commit

Permalink
Simplify logic that checks for Julia multithreading
Browse files Browse the repository at this point in the history
  • Loading branch information
ararslan committed Nov 27, 2019
1 parent f95a9da commit d060039
Showing 1 changed file with 9 additions and 14 deletions.
23 changes: 9 additions & 14 deletions src/FFTW.jl
Expand Up @@ -3,6 +3,7 @@ module FFTW
using LinearAlgebra, Reexport
import Libdl
@reexport using AbstractFFTs
using Base.Threads

import AbstractFFTs: Plan, ScaledPlan,
fft, ifft, bfft, fft!, ifft!, bfft!,
Expand All @@ -26,7 +27,7 @@ end
const fftw_vendor = occursin("libmkl_rt", libfftw3) ? :mkl : :fftw

# Use Julia partr threading backend if present
@static if fftw_vendor == :fftw && isdefined(Threads, Symbol("@spawn"))
@static if fftw_vendor == :fftw
# callback function that FFTW uses to launch `num` parallel
# tasks (FFTW/fftw3#175):
function spawnloop(f::Ptr{Cvoid}, fdata::Ptr{Cvoid}, elsize::Csize_t, num::Cint, callback_data::Ptr{Cvoid})
Expand All @@ -49,19 +50,13 @@ function __init__()
if stat == 0 || statf == 0
error("could not initialize FFTW threads")
end
@static if fftw_vendor == :fftw
if Threads.nthreads() > 1 # number of Julia threads is set when Julia is launched
ccall((:fftw_make_planner_thread_safe, libfftw3), Cvoid, ())
ccall((:fftwf_make_planner_thread_safe, libfftw3f), Cvoid, ())
end
@static if isdefined(Threads, Symbol("@spawn"))
if Threads.nthreads() > 1 # partr will give us our threads
cspawnloop = @cfunction(spawnloop, Cvoid, (Ptr{Cvoid}, Ptr{Cvoid}, Csize_t, Cint, Ptr{Cvoid}))
ccall((:fftw_threads_set_callback, libfftw3), Cvoid, (Ptr{Cvoid}, Ptr{Cvoid}), cspawnloop, C_NULL)
ccall((:fftwf_threads_set_callback, libfftw3f), Cvoid, (Ptr{Cvoid}, Ptr{Cvoid}), cspawnloop, C_NULL)
set_num_threads(Threads.nthreads() * 4) # spawn more tasks than threads to help load-balancing
end
end
@static if fftw_vendor == :fftw && nthreads() > 1
ccall((:fftw_make_planner_thread_safe, libfftw3), Cvoid, ())
ccall((:fftwf_make_planner_thread_safe, libfftw3f), Cvoid, ())
cspawnloop = @cfunction(spawnloop, Cvoid, (Ptr{Cvoid}, Ptr{Cvoid}, Csize_t, Cint, Ptr{Cvoid}))
ccall((:fftw_threads_set_callback, libfftw3), Cvoid, (Ptr{Cvoid}, Ptr{Cvoid}), cspawnloop, C_NULL)
ccall((:fftwf_threads_set_callback, libfftw3f), Cvoid, (Ptr{Cvoid}, Ptr{Cvoid}), cspawnloop, C_NULL)
set_num_threads(nthreads() * 4) # spawn more tasks than threads to help load-balancing
end
end

Expand Down

0 comments on commit d060039

Please sign in to comment.