Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Check operations in @turbo automatically with can_avx; if failure, switch to @inbounds @fastmath #431

Merged
merged 31 commits into from
Sep 27, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
31 commits
Select commit Hold shift + click to select a range
1f8cb06
Create `safe` kwarg for `@turbo` macro
MilesCranmer Sep 18, 2022
3e148f5
Run `can_avx` on each operator when checking loopset
MilesCranmer Sep 18, 2022
7a89027
Refactor `can_avx` test
MilesCranmer Sep 18, 2022
3585ec9
Add test for `safe=true` option in `@turbo`
MilesCranmer Sep 18, 2022
ec3f6a0
Remove debugging statement
MilesCranmer Sep 18, 2022
02919d8
Clean up preamble generation
MilesCranmer Sep 18, 2022
f60c1f5
Set `safe=false` for `@turbo` by default
MilesCranmer Sep 18, 2022
5115351
Switch to more generic `can_turbo` function for safe `@turbo`
MilesCranmer Sep 18, 2022
40c425a
Remove `@turbo safe=true` tests from `can_avx.jl`
MilesCranmer Sep 18, 2022
2dff297
Create file to test `@turbo safe=true` and `can_turbo`
MilesCranmer Sep 18, 2022
7136114
Compute `nargs` of instruction properly
MilesCranmer Sep 18, 2022
0df1606
Add missing `safe` kwarg in `vmaterialize!`
MilesCranmer Sep 18, 2022
066e349
Also unpack `warncheckarg` and `safe` from UNROLL
MilesCranmer Sep 18, 2022
b7b9470
Ensure warncheckarg and safe passed everywhere for consistency
MilesCranmer Sep 18, 2022
4c57fde
Consistency in `UNROLL` name
MilesCranmer Sep 18, 2022
bd2fc43
Add packages required for testing to `[extras]` and `[targets]`
MilesCranmer Sep 18, 2022
73f60ab
Add `safe` and `warncheckarg` throughout library
MilesCranmer Sep 18, 2022
e92949e
Merge branch 'main' into main
chriselrod Sep 19, 2022
3d399d0
Remove edits to Project
MilesCranmer Sep 19, 2022
181e10a
Add missing imports in safe `@turbo` tests
MilesCranmer Sep 19, 2022
809dbf2
Fix call to `can_avx`
MilesCranmer Sep 19, 2022
da44c74
Remove nested `testset`
MilesCranmer Sep 19, 2022
5ef2edc
Test that `can_avx` validates `exp` by itself
MilesCranmer Sep 19, 2022
2fddc43
Add SpecialFunctions.jl to test
MilesCranmer Sep 19, 2022
02a29be
Clean up test set
MilesCranmer Sep 19, 2022
cbed1d3
Ping test
MilesCranmer Sep 19, 2022
9568ba9
Ensure that function names in safe test are unique
MilesCranmer Sep 19, 2022
a93f1ad
Add `RetVec2Int` for julia <1.6 as `Returns()`
MilesCranmer Sep 19, 2022
e126032
Use `RetVec2Int()` instead of `Returns(Vec{2,Int})`
MilesCranmer Sep 19, 2022
ec4f41c
Merge branch 'main' into main
MilesCranmer Sep 25, 2022
4efdb90
push functions into prepre
chriselrod Sep 27, 2022
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
name = "LoopVectorization"
uuid = "bdcacae8-1622-11e9-2a5c-532679323890"
authors = ["Chris Elrod <elrodc@gmail.com>"]
version = "0.12.128"
version = "0.12.129"

[deps]
ArrayInterface = "4fba245c-0d91-5ea0-9b3e-6abc04ee57a9"
Expand Down
14 changes: 8 additions & 6 deletions src/broadcast.jl
Original file line number Diff line number Diff line change
Expand Up @@ -548,7 +548,7 @@ end
# we have an N dimensional loop.
# need to construct the LoopSet
ls = LoopSet(Mod)
inline, u₁, u₂, v, isbroadcast, _, rs, rc, cls, threads, warncheckarg = UNROLL
inline, u₁, u₂, v, isbroadcast, _, rs, rc, cls, threads, warncheckarg, safe = UNROLL
set_hw!(ls, rs, rc, cls)
ls.isbroadcast = isbroadcast # maybe set `false` in a DiffEq-like `@..` macro
loopsyms = [gensym!(ls, "n") for _ ∈ 1:N]
Expand All @@ -571,6 +571,7 @@ end
v,
threads % Int,
warncheckarg,
safe,
)
Expr(:block, Expr(:meta, :inline), sc, :dest)
end
Expand All @@ -584,7 +585,7 @@ end
# we have an N dimensional loop.
# need to construct the LoopSet
ls = LoopSet(Mod)
inline, u₁, u₂, v, isbroadcast, _, rs, rc, cls, threads, warncheckarg = UNROLL
inline, u₁, u₂, v, isbroadcast, _, rs, rc, cls, threads, warncheckarg, safe = UNROLL
set_hw!(ls, rs, rc, cls)
ls.isbroadcast = isbroadcast # maybe set `false` in a DiffEq-like `@..` macro
loopsyms = [gensym!(ls, "n") for _ ∈ 1:N]
Expand Down Expand Up @@ -614,6 +615,7 @@ end
v,
threads % Int,
warncheckarg,
safe,
),
:dest′,
)
Expand All @@ -626,7 +628,7 @@ end
::Val{UNROLL},
::Val{dontbc}
) where {T<:NativeTypes,N,T2<:Number,Mod,UNROLL,dontbc}
inline, u₁, u₂, v, isbroadcast, W, rs, rc, cls, threads = UNROLL
inline, u₁, u₂, v, isbroadcast, W, rs, rc, cls, threads, warncheckarg, safe = UNROLL
quote
$(Expr(:meta, :inline))
arg = T(first(bc.args))
Expand All @@ -646,7 +648,7 @@ end
::Val{UNROLL},
::Val{dontbc}
) where {T<:NativeTypes,N,A<:AbstractArray{T,N},T2<:Number,Mod,UNROLL,dontbc}
inline, u₁, u₂, v, isbroadcast, W, rs, rc, cls, threads = UNROLL
inline, u₁, u₂, v, isbroadcast, W, rs, rc, cls, threads, warncheckarg, safe = UNROLL
quote
$(Expr(:meta, :inline))
arg = T(first(bc.args))
Expand All @@ -660,8 +662,8 @@ end
dest′
end
end
@inline function vmaterialize!(dest, bc, ::Val{Mod}, ::Val{Unroll}) where {Mod,Unroll}
vmaterialize!(dest, bc, Val{Mod}(), Val{Unroll}(), Val(_dontbc(bc)))
@inline function vmaterialize!(dest, bc, ::Val{Mod}, ::Val{UNROLL}) where {Mod,UNROLL}
vmaterialize!(dest, bc, Val{Mod}(), Val{UNROLL}(), Val(_dontbc(bc)))
end

@inline function vmaterialize(
Expand Down
6 changes: 3 additions & 3 deletions src/codegen/lower_threads.jl
Original file line number Diff line number Diff line change
Expand Up @@ -420,7 +420,7 @@ function thread_one_loops_expr(
valid_thread_loop::Vector{Bool},
ntmax::UInt,
c::Float64,
UNROLL::Tuple{Bool,Int8,Int8,Int8,Bool,Int,Int,Int,Int,UInt},
UNROLL::Tuple{Bool,Int8,Int8,Int8,Bool,Int,Int,Int,Int,UInt,Int,Bool},
OPS::Expr,
ARF::Expr,
AM::Expr,
Expand Down Expand Up @@ -615,7 +615,7 @@ function thread_two_loops_expr(
valid_thread_loop::Vector{Bool},
ntmax::UInt,
c::Float64,
UNROLL::Tuple{Bool,Int8,Int8,Int8,Bool,Int,Int,Int,Int,UInt},
UNROLL::Tuple{Bool,Int8,Int8,Int8,Bool,Int,Int,Int,Int,UInt,Int,Bool},
OPS::Expr,
ARF::Expr,
AM::Expr,
Expand Down Expand Up @@ -877,7 +877,7 @@ function valid_thread_loops(ls::LoopSet)
end
function avx_threads_expr(
ls::LoopSet,
UNROLL::Tuple{Bool,Int8,Int8,Int8,Bool,Int,Int,Int,Int,UInt},
UNROLL::Tuple{Bool,Int8,Int8,Int8,Bool,Int,Int,Int,Int,UInt,Int,Bool},
nt::UInt,
OPS::Expr,
ARF::Expr,
Expand Down
69 changes: 63 additions & 6 deletions src/condense_loopset.jl
Original file line number Diff line number Diff line change
Expand Up @@ -558,9 +558,9 @@ end
::StaticInt{NT},
::StaticInt{CLS},
) where {CNFARG,W,RS,AR,CLS,NT}
inline, u₁, u₂, v, BROADCAST, thread = CNFARG
inline, u₁, u₂, v, BROADCAST, thread, warncheckarg, safe = CNFARG
nt = min(thread % UInt, NT % UInt)
t = Expr(:tuple, inline, u₁, u₂, v, BROADCAST, W, RS, AR, CLS, nt)
t = Expr(:tuple, inline, u₁, u₂, v, BROADCAST, W, RS, AR, CLS, nt, warncheckarg, safe)
length(CNFARG) == 7 && push!(t.args, CNFARG[7])
Expr(:call, Expr(:curly, :Val, t))
end
Expand Down Expand Up @@ -605,6 +605,8 @@ function split_ifelse!(
k::Int,
inlineu₁u₂::Tuple{Bool,Int8,Int8,Int8},
thread::UInt,
warncheckarg::Int,
safe::Bool,
debug::Bool,
)
roots[k] = false
Expand Down Expand Up @@ -662,6 +664,8 @@ function split_ifelse!(
copy(extra_args),
inlineu₁u₂,
thread,
warncheckarg,
safe,
debug,
))
else
Expand All @@ -673,6 +677,8 @@ function split_ifelse!(
extra_args,
inlineu₁u₂,
thread,
warncheckarg,
safe,
debug,
))
end
Expand All @@ -685,6 +691,8 @@ function generate_call(
ls::LoopSet,
inlineu₁u₂::Tuple{Bool,Int8,Int8,Int8},
thread::UInt,
warncheckarg::Int,
safe::Bool,
debug::Bool,
)
extra_args = Expr(:tuple)
Expand All @@ -698,6 +706,8 @@ function generate_call(
extra_args,
inlineu₁u₂,
thread,
warncheckarg,
safe,
debug,
)
end
Expand All @@ -709,6 +719,8 @@ function generate_call_split(
extra_args::Expr,
inlineu₁u₂::Tuple{Bool,Int8,Int8,Int8},
thread::UInt,
warncheckarg::Int,
safe::Bool,
debug::Bool,
)
for (k, op) ∈ enumerate(operations(ls))
Expand All @@ -725,6 +737,8 @@ function generate_call_split(
k,
inlineu₁u₂,
thread,
warncheckarg,
safe,
debug,
)
end
Expand All @@ -737,6 +751,8 @@ function generate_call_split(
extra_args,
inlineu₁u₂,
thread,
warncheckarg,
safe,
debug,
)
end
Expand All @@ -750,6 +766,8 @@ function generate_call_types(
extra_args::Expr,
(inline, u₁, u₂, v)::Tuple{Bool,Int8,Int8,Int8},
thread::UInt,
warncheckarg::Int,
safe::Bool,
debug::Bool,
)
# good place to check for split
Expand Down Expand Up @@ -782,7 +800,7 @@ function generate_call_types(
loop_syms = tuple_expr(QuoteNode, ls.loopsymbols)
func = debug ? lv(:_turbo_loopset_debug) : lv(:_turbo_!)
lbarg = debug ? Expr(:call, :typeof, loop_bounds) : loop_bounds
configarg = (inline, u₁, u₂, v, ls.isbroadcast, thread)
configarg = (inline, u₁, u₂, v, ls.isbroadcast, thread, warncheckarg, safe)
unroll_param_tup =
Expr(:call, lv(:avx_config_val), :(Val{$configarg}()), VECTORWIDTHSYMBOL)
q = Expr(
Expand Down Expand Up @@ -884,6 +902,39 @@ function check_args_call(ls::LoopSet)
end
q
end
# Callable singleton that ignores its argument and always yields the type
# `Vec{2,Int}`. Used as the generator passed to `ntuple` in `can_turbo`.
# NOTE(review): appears to stand in for `Returns(Vec{2,Int})` on Julia versions
# where `Returns` is unavailable — confirm against the supported Julia range.
struct RetVec2Int end

function (::RetVec2Int)(_)
    return Vec{2,Int}
end
"""
can_turbo(f::Function, ::Val{NARGS})

Check whether a given function with a specified number of arguments
can be used inside a `@turbo` loop.
"""
function can_turbo(f::F, ::Val{NARGS})::Bool where {F,NARGS}
    # Build a tuple of NARGS probe argument types, one per position, each being
    # whatever `RetVec2Int()` returns for that index.
    probe_types = ntuple(RetVec2Int(), Val(NARGS))
    # `f` is considered usable unless the promoted result type is `Union{}`.
    return Base.promote_op(f, probe_types...) !== Union{}
end

"""
check_turbo_safe(ls::LoopSet)

Returns an expression of the form `true && can_turbo(op1) && can_turbo(op2) && ...`
"""
function check_turbo_safe(ls::LoopSet)
    # The result is a right-nested chain: `true && (c1 && (c2 && (...)))`,
    # where each `cK` is the `can_turbo` check built for one compute operation.
    root = Expr(:&&, true)
    tail = root
    for op in operations(ls)
        if iscompute(op)
            # Start from the call expression for this operation's instruction, then
            # turn it into a `can_turbo` check: prepend `can_turbo` and append
            # `Val(nargs)`, with `nargs` taken from the operation's parent count.
            check = callexpr(op.instruction)
            push!(check.args, Val(length(parents(op))))
            pushfirst!(check.args, can_turbo)
            # Hang the new link off the current tail and advance the tail to it.
            link = Expr(:&&, check)
            push!(tail.args, link)
            tail = link
        end
    end
    return root
end

# Wrap the expression `q` in a `@fastmath` macro-call expression.
# The attached `LineNumberNode` is built from `@__LINE__` / `@__FILE__`, so it
# records the location of this definition rather than the location of `q`.
make_fast(q) =
Expr(:macrocall, Symbol("@fastmath"), LineNumberNode(@__LINE__, Symbol(@__FILE__)), q)
Expand Down Expand Up @@ -956,7 +1007,7 @@ function setup_call_final(ls::LoopSet, q::Expr)
return ls.preamble
end
function setup_call_debug(ls::LoopSet)
generate_call(ls, (false, zero(Int8), zero(Int8), zero(Int8)), zero(UInt), true)
generate_call(ls, (false, zero(Int8), zero(Int8), zero(Int8)), zero(UInt), 1, true, true)
end
function setup_call(
ls::LoopSet,
Expand All @@ -969,6 +1020,7 @@ function setup_call(
v::Int8,
thread::Int,
warncheckarg::Int,
safe::Bool,
)
# We outline/inline at the macro level by creating/not creating an anonymous function.
# The old API instead was based on inlining or not inline the generated function, but
Expand All @@ -977,7 +1029,7 @@ function setup_call(
# inlining the generated function into the loop preamble.
lnns = extract_all_lnns(q)
pushfirst!(lnns, source)
call = generate_call(ls, (inline, u₁, u₂, v), thread % UInt, false)
call = generate_call(ls, (inline, u₁, u₂, v), thread % UInt, 1, true, false)
call = check_empty ? check_if_empty(ls, call) : call
argfailure = make_crashy(make_fast(q))
if warncheckarg ≠ 0
Expand All @@ -986,7 +1038,12 @@ function setup_call(
warncheckarg > 0 && push!(warning.args, :(maxlog = $warncheckarg))
argfailure = Expr(:block, warning, argfailure)
end
pushprepreamble!(ls, Expr(:if, check_args_call(ls), call, argfailure))
call_check = if safe
Expr(:&&, check_args_call(ls), check_turbo_safe(ls))
else
check_args_call(ls)
end
pushprepreamble!(ls, Expr(:if, call_check, call, argfailure))
prepend_lnns!(ls.prepreamble, lnns)
return ls.prepreamble
end
25 changes: 15 additions & 10 deletions src/constructors.jl
Original file line number Diff line number Diff line change
Expand Up @@ -52,12 +52,13 @@ function substitute_broadcast(
v::Int8,
threads::Int,
warncheckarg::Int,
safe::Bool,
)
ci = first(Meta.lower(LoopVectorization, q).args).code
nargs = length(ci) - 1
ex = Expr(:block)
syms = [gensym() for _ ∈ 1:nargs]
configarg = (inline, u₁, u₂, v, true, threads, warncheckarg)
configarg = (inline, u₁, u₂, v, true, threads, warncheckarg, safe)
unroll_param_tup = Expr(:call, lv(:avx_config_val), :(Val{$configarg}()), staticexpr(0))
for n ∈ 1:nargs
ciₙ = ci[n]
Expand Down Expand Up @@ -102,6 +103,7 @@ function check_macro_kwarg(
v::Int8,
threads::Int,
warncheckarg::Int,
safe::Bool,
)
((arg.head === :(=)) && (length(arg.args) == 2)) ||
throw(ArgumentError("macro kwarg should be of the form `argname = value`."))
Expand Down Expand Up @@ -132,14 +134,16 @@ function check_macro_kwarg(
end
elseif kw === :warn_check_args
warncheckarg = convert(Int, value)::Int
elseif kw === :safe
safe = convert(Bool, value)
else
throw(
ArgumentError(
"Received unrecognized keyword argument $kw. Recognized arguments include:\n`inline`, `unroll`, `check_empty`, and `thread`.",
),
)
end
inline, check_empty, u₁, u₂, v, threads, warncheckarg
inline, check_empty, u₁, u₂, v, threads, warncheckarg, safe
end
function process_args(
args;
Expand All @@ -150,12 +154,13 @@ function process_args(
v::Int8 = zero(Int8),
threads::Int = 1,
warncheckarg::Int = 1,
safe::Bool = false,
)
for arg ∈ args
inline, check_empty, u₁, u₂, v, threads, warncheckarg =
check_macro_kwarg(arg, inline, check_empty, u₁, u₂, v, threads, warncheckarg)
inline, check_empty, u₁, u₂, v, threads, warncheckarg, safe =
check_macro_kwarg(arg, inline, check_empty, u₁, u₂, v, threads, warncheckarg, safe)
end
inline, check_empty, u₁, u₂, v, threads, warncheckarg
inline, check_empty, u₁, u₂, v, threads, warncheckarg, safe
end
# check if the body of loop is a block, if not convert it to a block issue#395
# and check if the range of loop is an enumerate, if it is replace it, issue#393
Expand Down Expand Up @@ -225,12 +230,12 @@ function turbo_macro(mod, src, q, args...)
q = macroexpand(mod, q)
if q.head === :for
ls = LoopSet(q, mod)
inline, check_empty, u₁, u₂, v, threads, warncheckarg = process_args(args)
esc(setup_call(ls, q, src, inline, check_empty, u₁, u₂, v, threads, warncheckarg))
inline, check_empty, u₁, u₂, v, threads, warncheckarg, safe = process_args(args)
esc(setup_call(ls, q, src, inline, check_empty, u₁, u₂, v, threads, warncheckarg, safe))
else
inline, check_empty, u₁, u₂, v, threads, warncheckarg =
inline, check_empty, u₁, u₂, v, threads, warncheckarg, safe =
process_args(args, inline = true)
substitute_broadcast(q, Symbol(mod), inline, u₁, u₂, v, threads, warncheckarg)
substitute_broadcast(q, Symbol(mod), inline, u₁, u₂, v, threads, warncheckarg, safe)
end
end
"""
Expand Down Expand Up @@ -367,7 +372,7 @@ macro _turbo(arg, q)
@assert q.head === :for
q = macroexpand(__module__, q)
inline, check_empty, u₁, u₂, v =
check_macro_kwarg(arg, false, false, zero(Int8), zero(Int8), zero(Int8), 1, 0)
check_macro_kwarg(arg, false, false, zero(Int8), zero(Int8), zero(Int8), 1, 0, true)
ls = LoopSet(q, __module__)
set_hw!(ls)
def_outer_reduct_types!(ls)
Expand Down
4 changes: 2 additions & 2 deletions src/modeling/graphs.jl
Original file line number Diff line number Diff line change
Expand Up @@ -1283,7 +1283,7 @@ function instruction!(ls::LoopSet, x::Expr)
instr ∈ keys(COST) && return Instruction(:LoopVectorization, instr)
# end
instr = gensym!(ls, "f")
pushpreamble!(ls, Expr(:(=), instr, x))
pushprepreamble!(ls, Expr(:(=), instr, x))
Instruction(Symbol(""), instr)
end
instruction!(ls::LoopSet, x::Symbol) = instruction(x)
Expand Down Expand Up @@ -1481,7 +1481,7 @@ function add_operation!(
add_comparison!(ls, LHS_sym, RHS, elementbytes, position)
else
throw(LoopError("Expression not recognized.", RHS))
end
end
end

function prepare_rhs_for_storage!(
Expand Down