Skip to content

Commit

Permalink
Check operations in @turbo automatically with can_avx; if failure…
Browse files Browse the repository at this point in the history
…, switch to `@inbounds @fastmath` (#431)

* Create `safe` kwarg for `@turbo` macro

Currently, this macro does nothing.

* Run `can_avx` on each operator when checking loopset

* Refactor `can_avx` test

* Add test for `safe=true` option in `@turbo`

* Remove debugging statement

* Clean up preamble generation

* Set `safe=false` for `@turbo` by default

* Switch to more generic `can_turbo` function for safe `@turbo`

* Remove `@turbo safe=true` tests from `can_avx.jl`

* Create file to test `@turbo safe=true` and `can_turbo`

* Compute `nargs` of instruction properly

* Add missing `safe` kwarg in `vmaterialize!`

* Also unpack `warncheckarg` and `safe` from UNROLL

* Ensure warncheckarg and safe passed everywhere for consistency

* Consistency in `UNROLL` name

* Add packages required for testing to `[extras]` and `[targets]`

* Add `safe` and `warncheckarg` throughout library

* Remove edits to Project

* Add missing imports in safe `@turbo` tests

* Fix call to `can_avx`

* Remove nested `testset`

Seems to be breaking imports.

* Test that `can_avx` validates `exp` by itself

* Add SpecialFunctions.jl to test

* Clean up test set

* Ping test

* Ensure that function names in safe test are unique

* Add `RetVec2Int` for julia <1.6 as `Returns()`

Co-authored-by: Chris Elrod <elrodc@gmail.com>

* Use `RetVec2Int()` instead of `Returns(Vec{2,Int})`

Co-authored-by: Chris Elrod <elrodc@gmail.com>

* push functions into prepre

Co-authored-by: Chris Elrod <elrodc@gmail.com>
  • Loading branch information
MilesCranmer and chriselrod committed Sep 27, 2022
1 parent 1238fc8 commit e123cb2
Show file tree
Hide file tree
Showing 11 changed files with 164 additions and 43 deletions.
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
name = "LoopVectorization"
uuid = "bdcacae8-1622-11e9-2a5c-532679323890"
authors = ["Chris Elrod <elrodc@gmail.com>"]
version = "0.12.128"
version = "0.12.129"

[deps]
ArrayInterface = "4fba245c-0d91-5ea0-9b3e-6abc04ee57a9"
Expand Down
14 changes: 8 additions & 6 deletions src/broadcast.jl
Original file line number Diff line number Diff line change
Expand Up @@ -548,7 +548,7 @@ end
# we have an N dimensional loop.
# need to construct the LoopSet
ls = LoopSet(Mod)
inline, u₁, u₂, v, isbroadcast, _, rs, rc, cls, threads, warncheckarg = UNROLL
inline, u₁, u₂, v, isbroadcast, _, rs, rc, cls, threads, warncheckarg, safe = UNROLL
set_hw!(ls, rs, rc, cls)
ls.isbroadcast = isbroadcast # maybe set `false` in a DiffEq-like `@..` macro
loopsyms = [gensym!(ls, "n") for _ 1:N]
Expand All @@ -571,6 +571,7 @@ end
v,
threads % Int,
warncheckarg,
safe,
)
Expr(:block, Expr(:meta, :inline), sc, :dest)
end
Expand All @@ -584,7 +585,7 @@ end
# we have an N dimensional loop.
# need to construct the LoopSet
ls = LoopSet(Mod)
inline, u₁, u₂, v, isbroadcast, _, rs, rc, cls, threads, warncheckarg = UNROLL
inline, u₁, u₂, v, isbroadcast, _, rs, rc, cls, threads, warncheckarg, safe = UNROLL
set_hw!(ls, rs, rc, cls)
ls.isbroadcast = isbroadcast # maybe set `false` in a DiffEq-like `@..` macro
loopsyms = [gensym!(ls, "n") for _ 1:N]
Expand Down Expand Up @@ -614,6 +615,7 @@ end
v,
threads % Int,
warncheckarg,
safe,
),
:dest′,
)
Expand All @@ -626,7 +628,7 @@ end
::Val{UNROLL},
::Val{dontbc}
) where {T<:NativeTypes,N,T2<:Number,Mod,UNROLL,dontbc}
inline, u₁, u₂, v, isbroadcast, W, rs, rc, cls, threads = UNROLL
inline, u₁, u₂, v, isbroadcast, W, rs, rc, cls, threads, warncheckarg, safe = UNROLL
quote
$(Expr(:meta, :inline))
arg = T(first(bc.args))
Expand All @@ -646,7 +648,7 @@ end
::Val{UNROLL},
::Val{dontbc}
) where {T<:NativeTypes,N,A<:AbstractArray{T,N},T2<:Number,Mod,UNROLL,dontbc}
inline, u₁, u₂, v, isbroadcast, W, rs, rc, cls, threads = UNROLL
inline, u₁, u₂, v, isbroadcast, W, rs, rc, cls, threads, warncheckarg, safe = UNROLL
quote
$(Expr(:meta, :inline))
arg = T(first(bc.args))
Expand All @@ -660,8 +662,8 @@ end
dest′
end
end
@inline function vmaterialize!(dest, bc, ::Val{Mod}, ::Val{Unroll}) where {Mod,Unroll}
vmaterialize!(dest, bc, Val{Mod}(), Val{Unroll}(), Val(_dontbc(bc)))
@inline function vmaterialize!(dest, bc, ::Val{Mod}, ::Val{UNROLL}) where {Mod,UNROLL}
vmaterialize!(dest, bc, Val{Mod}(), Val{UNROLL}(), Val(_dontbc(bc)))
end

@inline function vmaterialize(
Expand Down
6 changes: 3 additions & 3 deletions src/codegen/lower_threads.jl
Original file line number Diff line number Diff line change
Expand Up @@ -420,7 +420,7 @@ function thread_one_loops_expr(
valid_thread_loop::Vector{Bool},
ntmax::UInt,
c::Float64,
UNROLL::Tuple{Bool,Int8,Int8,Int8,Bool,Int,Int,Int,Int,UInt},
UNROLL::Tuple{Bool,Int8,Int8,Int8,Bool,Int,Int,Int,Int,UInt,Int,Bool},
OPS::Expr,
ARF::Expr,
AM::Expr,
Expand Down Expand Up @@ -615,7 +615,7 @@ function thread_two_loops_expr(
valid_thread_loop::Vector{Bool},
ntmax::UInt,
c::Float64,
UNROLL::Tuple{Bool,Int8,Int8,Int8,Bool,Int,Int,Int,Int,UInt},
UNROLL::Tuple{Bool,Int8,Int8,Int8,Bool,Int,Int,Int,Int,UInt,Int,Bool},
OPS::Expr,
ARF::Expr,
AM::Expr,
Expand Down Expand Up @@ -877,7 +877,7 @@ function valid_thread_loops(ls::LoopSet)
end
function avx_threads_expr(
ls::LoopSet,
UNROLL::Tuple{Bool,Int8,Int8,Int8,Bool,Int,Int,Int,Int,UInt},
UNROLL::Tuple{Bool,Int8,Int8,Int8,Bool,Int,Int,Int,Int,UInt,Int,Bool},
nt::UInt,
OPS::Expr,
ARF::Expr,
Expand Down
69 changes: 63 additions & 6 deletions src/condense_loopset.jl
Original file line number Diff line number Diff line change
Expand Up @@ -558,9 +558,9 @@ end
::StaticInt{NT},
::StaticInt{CLS},
) where {CNFARG,W,RS,AR,CLS,NT}
inline, u₁, u₂, v, BROADCAST, thread = CNFARG
inline, u₁, u₂, v, BROADCAST, thread, warncheckarg, safe = CNFARG
nt = min(thread % UInt, NT % UInt)
t = Expr(:tuple, inline, u₁, u₂, v, BROADCAST, W, RS, AR, CLS, nt)
t = Expr(:tuple, inline, u₁, u₂, v, BROADCAST, W, RS, AR, CLS, nt, warncheckarg, safe)
length(CNFARG) == 7 && push!(t.args, CNFARG[7])
Expr(:call, Expr(:curly, :Val, t))
end
Expand Down Expand Up @@ -605,6 +605,8 @@ function split_ifelse!(
k::Int,
inlineu₁u₂::Tuple{Bool,Int8,Int8,Int8},
thread::UInt,
warncheckarg::Int,
safe::Bool,
debug::Bool,
)
roots[k] = false
Expand Down Expand Up @@ -662,6 +664,8 @@ function split_ifelse!(
copy(extra_args),
inlineu₁u₂,
thread,
warncheckarg,
safe,
debug,
))
else
Expand All @@ -673,6 +677,8 @@ function split_ifelse!(
extra_args,
inlineu₁u₂,
thread,
warncheckarg,
safe,
debug,
))
end
Expand All @@ -685,6 +691,8 @@ function generate_call(
ls::LoopSet,
inlineu₁u₂::Tuple{Bool,Int8,Int8,Int8},
thread::UInt,
warncheckarg::Int,
safe::Bool,
debug::Bool,
)
extra_args = Expr(:tuple)
Expand All @@ -698,6 +706,8 @@ function generate_call(
extra_args,
inlineu₁u₂,
thread,
warncheckarg,
safe,
debug,
)
end
Expand All @@ -709,6 +719,8 @@ function generate_call_split(
extra_args::Expr,
inlineu₁u₂::Tuple{Bool,Int8,Int8,Int8},
thread::UInt,
warncheckarg::Int,
safe::Bool,
debug::Bool,
)
for (k, op) enumerate(operations(ls))
Expand All @@ -725,6 +737,8 @@ function generate_call_split(
k,
inlineu₁u₂,
thread,
warncheckarg,
safe,
debug,
)
end
Expand All @@ -737,6 +751,8 @@ function generate_call_split(
extra_args,
inlineu₁u₂,
thread,
warncheckarg,
safe,
debug,
)
end
Expand All @@ -750,6 +766,8 @@ function generate_call_types(
extra_args::Expr,
(inline, u₁, u₂, v)::Tuple{Bool,Int8,Int8,Int8},
thread::UInt,
warncheckarg::Int,
safe::Bool,
debug::Bool,
)
# good place to check for split
Expand Down Expand Up @@ -782,7 +800,7 @@ function generate_call_types(
loop_syms = tuple_expr(QuoteNode, ls.loopsymbols)
func = debug ? lv(:_turbo_loopset_debug) : lv(:_turbo_!)
lbarg = debug ? Expr(:call, :typeof, loop_bounds) : loop_bounds
configarg = (inline, u₁, u₂, v, ls.isbroadcast, thread)
configarg = (inline, u₁, u₂, v, ls.isbroadcast, thread, warncheckarg, safe)
unroll_param_tup =
Expr(:call, lv(:avx_config_val), :(Val{$configarg}()), VECTORWIDTHSYMBOL)
q = Expr(
Expand Down Expand Up @@ -884,6 +902,39 @@ function check_args_call(ls::LoopSet)
end
q
end
# Callable singleton that ignores its single argument and always yields the
# type `Vec{2,Int}`. It plays the role of `Returns(Vec{2,Int})` so that the
# code also works on Julia versions where `Returns` is unavailable.
struct RetVec2Int end
function (::RetVec2Int)(::Any)
    return Vec{2,Int}
end
"""
    can_turbo(f, ::Val{NARGS}) -> Bool

Report whether `f`, called with `NARGS` arguments, can be used inside a `@turbo` loop.

The check asks `Base.promote_op` for the result type of `f` applied to `NARGS`
arguments of type `Vec{2,Int}`. `false` is returned exactly when that result
type is `Union{}`; any other result type yields `true`.
"""
function can_turbo(f::F, ::Val{NARGS})::Bool where {F,NARGS}
    # One `Vec{2,Int}` type per argument position.
    argtypes = ntuple(RetVec2Int(), Val(NARGS))
    inferred = Base.promote_op(f, argtypes...)
    return !(inferred === Union{})
end

"""
    check_turbo_safe(ls::LoopSet)

Build a right-nested `&&` expression that starts with the literal `true` and
adds one `can_turbo(f, Val(nargs))` call for every compute operation of `ls`,
i.e. something shaped like `true && (can_turbo(f1, Val(n1)) && (can_turbo(f2, Val(n2)) && ...))`.

`nargs` is the number of parents of the operation. When `ls` has no compute
operations the result is just `Expr(:&&, true)`.
"""
function check_turbo_safe(ls::LoopSet)
    chain = Expr(:&&, true)
    tail = chain  # innermost `&&` node; the next check is appended here
    for op in operations(ls)
        if iscompute(op)
            check = callexpr(op.instruction)
            # Turn the instruction's call expression into
            # `can_turbo(<callee>, Val(nargs))`: the arity goes last and the
            # `can_turbo` function object itself goes first.
            push!(check.args, Val(length(parents(op))))
            pushfirst!(check.args, can_turbo)
            link = Expr(:&&, check)
            push!(tail.args, link)
            tail = link
        end
    end
    return chain
end

# Wrap `q` in a `@fastmath` macro-call expression. The attached LineNumberNode
# is built from `@__LINE__`/`@__FILE__`, so it refers to this definition's
# location in the source file rather than to wherever `q` came from.
function make_fast(q)
    here = LineNumberNode(@__LINE__, Symbol(@__FILE__))
    return Expr(:macrocall, Symbol("@fastmath"), here, q)
end
Expand Down Expand Up @@ -956,7 +1007,7 @@ function setup_call_final(ls::LoopSet, q::Expr)
return ls.preamble
end
function setup_call_debug(ls::LoopSet)
generate_call(ls, (false, zero(Int8), zero(Int8), zero(Int8)), zero(UInt), true)
generate_call(ls, (false, zero(Int8), zero(Int8), zero(Int8)), zero(UInt), 1, true, true)
end
function setup_call(
ls::LoopSet,
Expand All @@ -969,6 +1020,7 @@ function setup_call(
v::Int8,
thread::Int,
warncheckarg::Int,
safe::Bool,
)
# We outline/inline at the macro level by creating/not creating an anonymous function.
# The old API instead was based on inlining or not inline the generated function, but
Expand All @@ -977,7 +1029,7 @@ function setup_call(
# inlining the generated function into the loop preamble.
lnns = extract_all_lnns(q)
pushfirst!(lnns, source)
call = generate_call(ls, (inline, u₁, u₂, v), thread % UInt, false)
call = generate_call(ls, (inline, u₁, u₂, v), thread % UInt, 1, true, false)
call = check_empty ? check_if_empty(ls, call) : call
argfailure = make_crashy(make_fast(q))
if warncheckarg 0
Expand All @@ -986,7 +1038,12 @@ function setup_call(
warncheckarg > 0 && push!(warning.args, :(maxlog = $warncheckarg))
argfailure = Expr(:block, warning, argfailure)
end
pushprepreamble!(ls, Expr(:if, check_args_call(ls), call, argfailure))
call_check = if safe
Expr(:&&, check_args_call(ls), check_turbo_safe(ls))
else
check_args_call(ls)
end
pushprepreamble!(ls, Expr(:if, call_check, call, argfailure))
prepend_lnns!(ls.prepreamble, lnns)
return ls.prepreamble
end
25 changes: 15 additions & 10 deletions src/constructors.jl
Original file line number Diff line number Diff line change
Expand Up @@ -52,12 +52,13 @@ function substitute_broadcast(
v::Int8,
threads::Int,
warncheckarg::Int,
safe::Bool,
)
ci = first(Meta.lower(LoopVectorization, q).args).code
nargs = length(ci) - 1
ex = Expr(:block)
syms = [gensym() for _ 1:nargs]
configarg = (inline, u₁, u₂, v, true, threads, warncheckarg)
configarg = (inline, u₁, u₂, v, true, threads, warncheckarg, safe)
unroll_param_tup = Expr(:call, lv(:avx_config_val), :(Val{$configarg}()), staticexpr(0))
for n 1:nargs
ciₙ = ci[n]
Expand Down Expand Up @@ -102,6 +103,7 @@ function check_macro_kwarg(
v::Int8,
threads::Int,
warncheckarg::Int,
safe::Bool,
)
((arg.head === :(=)) && (length(arg.args) == 2)) ||
throw(ArgumentError("macro kwarg should be of the form `argname = value`."))
Expand Down Expand Up @@ -132,14 +134,16 @@ function check_macro_kwarg(
end
elseif kw === :warn_check_args
warncheckarg = convert(Int, value)::Int
elseif kw === :safe
safe = convert(Bool, value)
else
throw(
ArgumentError(
"Received unrecognized keyword argument $kw. Recognized arguments include:\n`inline`, `unroll`, `check_empty`, and `thread`.",
),
)
end
inline, check_empty, u₁, u₂, v, threads, warncheckarg
inline, check_empty, u₁, u₂, v, threads, warncheckarg, safe
end
function process_args(
args;
Expand All @@ -150,12 +154,13 @@ function process_args(
v::Int8 = zero(Int8),
threads::Int = 1,
warncheckarg::Int = 1,
safe::Bool = false,
)
for arg args
inline, check_empty, u₁, u₂, v, threads, warncheckarg =
check_macro_kwarg(arg, inline, check_empty, u₁, u₂, v, threads, warncheckarg)
inline, check_empty, u₁, u₂, v, threads, warncheckarg, safe =
check_macro_kwarg(arg, inline, check_empty, u₁, u₂, v, threads, warncheckarg, safe)
end
inline, check_empty, u₁, u₂, v, threads, warncheckarg
inline, check_empty, u₁, u₂, v, threads, warncheckarg, safe
end
# check if the body of loop is a block, if not convert it to a block issue#395
# and check if the range of loop is an enumerate, if it is replace it, issue#393
Expand Down Expand Up @@ -225,12 +230,12 @@ function turbo_macro(mod, src, q, args...)
q = macroexpand(mod, q)
if q.head === :for
ls = LoopSet(q, mod)
inline, check_empty, u₁, u₂, v, threads, warncheckarg = process_args(args)
esc(setup_call(ls, q, src, inline, check_empty, u₁, u₂, v, threads, warncheckarg))
inline, check_empty, u₁, u₂, v, threads, warncheckarg, safe = process_args(args)
esc(setup_call(ls, q, src, inline, check_empty, u₁, u₂, v, threads, warncheckarg, safe))
else
inline, check_empty, u₁, u₂, v, threads, warncheckarg =
inline, check_empty, u₁, u₂, v, threads, warncheckarg, safe =
process_args(args, inline = true)
substitute_broadcast(q, Symbol(mod), inline, u₁, u₂, v, threads, warncheckarg)
substitute_broadcast(q, Symbol(mod), inline, u₁, u₂, v, threads, warncheckarg, safe)
end
end
"""
Expand Down Expand Up @@ -367,7 +372,7 @@ macro _turbo(arg, q)
@assert q.head === :for
q = macroexpand(__module__, q)
inline, check_empty, u₁, u₂, v =
check_macro_kwarg(arg, false, false, zero(Int8), zero(Int8), zero(Int8), 1, 0)
check_macro_kwarg(arg, false, false, zero(Int8), zero(Int8), zero(Int8), 1, 0, true)
ls = LoopSet(q, __module__)
set_hw!(ls)
def_outer_reduct_types!(ls)
Expand Down
4 changes: 2 additions & 2 deletions src/modeling/graphs.jl
Original file line number Diff line number Diff line change
Expand Up @@ -1283,7 +1283,7 @@ function instruction!(ls::LoopSet, x::Expr)
instr keys(COST) && return Instruction(:LoopVectorization, instr)
# end
instr = gensym!(ls, "f")
pushpreamble!(ls, Expr(:(=), instr, x))
pushprepreamble!(ls, Expr(:(=), instr, x))
Instruction(Symbol(""), instr)
end
instruction!(ls::LoopSet, x::Symbol) = instruction(x)
Expand Down Expand Up @@ -1481,7 +1481,7 @@ function add_operation!(
add_comparison!(ls, LHS_sym, RHS, elementbytes, position)
else
throw(LoopError("Expression not recognized.", RHS))
end
end
end

function prepare_rhs_for_storage!(
Expand Down

2 comments on commit e123cb2

@chriselrod
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@JuliaRegistrator
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Registration pull request created: JuliaRegistries/General/69084

After the above pull request is merged, it is recommended that a tag is created on this repository for the registered package version.

This will be done automatically if the Julia TagBot GitHub Action is installed, or can be done manually through the github interface, or via:

git tag -a v0.12.129 -m "<description of version>" e123cb279759ae5920a8a93d8b365a49305ebfbe
git push origin v0.12.129

Please sign in to comment.