From e123cb279759ae5920a8a93d8b365a49305ebfbe Mon Sep 17 00:00:00 2001 From: Miles Cranmer Date: Tue, 27 Sep 2022 19:02:00 -0400 Subject: [PATCH] Check operations in `@turbo` automatically with `can_avx`; if failure, switch to `@inbounds @fastmath` (#431) * Create `safe` kwarg for `@turbo` macro Currently, this macro does nothing. * Run `can_avx` on each operator when checking loopset * Refactor `can_avx` test * Add test for `safe=true` option in `@turbo` * Remove debugging statement * Clean up preamble generation * Set `safe=false` for `@turbo` by default * Switch to more generic `can_turbo` function for safe `@turbo` * Remove `@turbo safe=true` tests from `can_avx.jl` * Create file to test `@turbo safe=true` and `can_turbo` * Compute `nargs` of instruction properly * Add missing `safe` kwarg in `vmaterialize!` * Also unpack `warncheckarg` and `safe` from UNROLL * Ensure warncheckarg and safe passed everywhere for consistency * Consistency in `UNROLL` name * Add packages required for testing to `[extras]` and `[targets]` * Add `safe` and `warncheckarg` throughout library * Remove edits to Project * Add missing imports in save `@turbo` tests * Fix call to `can_avx` * Remove nested `testset` Seems to be breaking imports. * Test that `can_avx` validates `exp` by itself * Add SpecialFunctions.jl to test * Clean up test set * Ping test * Ensure that function names in safe test are unique * Add `RetVec2Int` for julia <1.6 as `Returns()` Co-authored-by: Chris Elrod * Use `RetVec2Int()` instead of `Returns(Vec{2,Int})` Co-authored-by: Chris Elrod * push functions into prepre Co-authored-by: Chris Elrod --- Project.toml | 2 +- src/broadcast.jl | 14 ++++---- src/codegen/lower_threads.jl | 6 ++-- src/condense_loopset.jl | 69 ++++++++++++++++++++++++++++++++---- src/constructors.jl | 25 +++++++------ src/modeling/graphs.jl | 4 +-- src/reconstruct_loopset.jl | 4 +-- test/Project.toml | 1 + test/can_avx.jl | 24 ++++++------- test/grouptests.jl | 2 ++ test/safe_turbo.jl | 56 +++++++++++++++++++++++++++++ 11 files changed, 164 insertions(+), 43 deletions(-) create mode 100644 test/safe_turbo.jl diff --git a/Project.toml b/Project.toml index d938f0d1e..f8d6d7864 100644 --- a/Project.toml +++ b/Project.toml @@ -1,7 +1,7 @@ name = "LoopVectorization" uuid = "bdcacae8-1622-11e9-2a5c-532679323890" authors = ["Chris Elrod "] -version = "0.12.128" +version = "0.12.129" [deps] ArrayInterface = "4fba245c-0d91-5ea0-9b3e-6abc04ee57a9" diff --git a/src/broadcast.jl b/src/broadcast.jl index eb06474c3..19807a63f 100644 --- a/src/broadcast.jl +++ b/src/broadcast.jl @@ -548,7 +548,7 @@ end # we have an N dimensional loop. # need to construct the LoopSet ls = LoopSet(Mod) - inline, u₁, u₂, v, isbroadcast, _, rs, rc, cls, threads, warncheckarg = UNROLL + inline, u₁, u₂, v, isbroadcast, _, rs, rc, cls, threads, warncheckarg, safe = UNROLL set_hw!(ls, rs, rc, cls) ls.isbroadcast = isbroadcast # maybe set `false` in a DiffEq-like `@..` macro loopsyms = [gensym!(ls, "n") for _ ∈ 1:N] @@ -571,6 +571,7 @@ end v, threads % Int, warncheckarg, + safe, ) Expr(:block, Expr(:meta, :inline), sc, :dest) end @@ -584,7 +585,7 @@ end # we have an N dimensional loop. # need to construct the LoopSet ls = LoopSet(Mod) - inline, u₁, u₂, v, isbroadcast, _, rs, rc, cls, threads, warncheckarg = UNROLL + inline, u₁, u₂, v, isbroadcast, _, rs, rc, cls, threads, warncheckarg, safe = UNROLL set_hw!(ls, rs, rc, cls) ls.isbroadcast = isbroadcast # maybe set `false` in a DiffEq-like `@..` macro loopsyms = [gensym!(ls, "n") for _ ∈ 1:N] @@ -614,6 +615,7 @@ end v, threads % Int, warncheckarg, + safe, ), :dest′, ) @@ -626,7 +628,7 @@ end ::Val{UNROLL}, ::Val{dontbc} ) where {T<:NativeTypes,N,T2<:Number,Mod,UNROLL,dontbc} - inline, u₁, u₂, v, isbroadcast, W, rs, rc, cls, threads = UNROLL + inline, u₁, u₂, v, isbroadcast, W, rs, rc, cls, threads, warncheckarg, safe = UNROLL quote $(Expr(:meta, :inline)) arg = T(first(bc.args)) @@ -646,7 +648,7 @@ end ::Val{UNROLL}, ::Val{dontbc} ) where {T<:NativeTypes,N,A<:AbstractArray{T,N},T2<:Number,Mod,UNROLL,dontbc} - inline, u₁, u₂, v, isbroadcast, W, rs, rc, cls, threads = UNROLL + inline, u₁, u₂, v, isbroadcast, W, rs, rc, cls, threads, warncheckarg, safe = UNROLL quote $(Expr(:meta, :inline)) arg = T(first(bc.args)) @@ -660,8 +662,8 @@ end dest′ end end -@inline function vmaterialize!(dest, bc, ::Val{Mod}, ::Val{Unroll}) where {Mod,Unroll} - vmaterialize!(dest, bc, Val{Mod}(), Val{Unroll}(), Val(_dontbc(bc))) +@inline function vmaterialize!(dest, bc, ::Val{Mod}, ::Val{UNROLL}) where {Mod,UNROLL} + vmaterialize!(dest, bc, Val{Mod}(), Val{UNROLL}(), Val(_dontbc(bc))) end @inline function vmaterialize( diff --git a/src/codegen/lower_threads.jl b/src/codegen/lower_threads.jl index 873c01cc4..4152061e7 100644 --- a/src/codegen/lower_threads.jl +++ b/src/codegen/lower_threads.jl @@ -420,7 +420,7 @@ function thread_one_loops_expr( valid_thread_loop::Vector{Bool}, ntmax::UInt, c::Float64, - UNROLL::Tuple{Bool,Int8,Int8,Int8,Bool,Int,Int,Int,Int,UInt}, + UNROLL::Tuple{Bool,Int8,Int8,Int8,Bool,Int,Int,Int,Int,UInt,Int,Bool}, OPS::Expr, ARF::Expr, AM::Expr, @@ -615,7 +615,7 @@ function thread_two_loops_expr( valid_thread_loop::Vector{Bool}, ntmax::UInt, c::Float64, - UNROLL::Tuple{Bool,Int8,Int8,Int8,Bool,Int,Int,Int,Int,UInt}, + UNROLL::Tuple{Bool,Int8,Int8,Int8,Bool,Int,Int,Int,Int,UInt,Int,Bool}, OPS::Expr, ARF::Expr, AM::Expr, @@ -877,7 +877,7 @@ function valid_thread_loops(ls::LoopSet) end function avx_threads_expr( ls::LoopSet, - UNROLL::Tuple{Bool,Int8,Int8,Int8,Bool,Int,Int,Int,Int,UInt}, + UNROLL::Tuple{Bool,Int8,Int8,Int8,Bool,Int,Int,Int,Int,UInt,Int,Bool}, nt::UInt, OPS::Expr, ARF::Expr, diff --git a/src/condense_loopset.jl b/src/condense_loopset.jl index d18a3ad7e..18f5d04fe 100644 --- a/src/condense_loopset.jl +++ b/src/condense_loopset.jl @@ -558,9 +558,9 @@ end ::StaticInt{NT}, ::StaticInt{CLS}, ) where {CNFARG,W,RS,AR,CLS,NT} - inline, u₁, u₂, v, BROADCAST, thread = CNFARG + inline, u₁, u₂, v, BROADCAST, thread, warncheckarg, safe = CNFARG nt = min(thread % UInt, NT % UInt) - t = Expr(:tuple, inline, u₁, u₂, v, BROADCAST, W, RS, AR, CLS, nt) + t = Expr(:tuple, inline, u₁, u₂, v, BROADCAST, W, RS, AR, CLS, nt, warncheckarg, safe) length(CNFARG) == 7 && push!(t.args, CNFARG[7]) Expr(:call, Expr(:curly, :Val, t)) end @@ -605,6 +605,8 @@ function split_ifelse!( k::Int, inlineu₁u₂::Tuple{Bool,Int8,Int8,Int8}, thread::UInt, + warncheckarg::Int, + safe::Bool, debug::Bool, ) roots[k] = false @@ -662,6 +664,8 @@ function split_ifelse!( copy(extra_args), inlineu₁u₂, thread, + warncheckarg, + safe, debug, )) else @@ -673,6 +677,8 @@ function split_ifelse!( extra_args, inlineu₁u₂, thread, + warncheckarg, + safe, debug, )) end @@ -685,6 +691,8 @@ function generate_call( ls::LoopSet, inlineu₁u₂::Tuple{Bool,Int8,Int8,Int8}, thread::UInt, + warncheckarg::Int, + safe::Bool, debug::Bool, ) extra_args = Expr(:tuple) @@ -698,6 +706,8 @@ function generate_call( extra_args, inlineu₁u₂, thread, + warncheckarg, + safe, debug, ) end @@ -709,6 +719,8 @@ function generate_call_split( extra_args::Expr, inlineu₁u₂::Tuple{Bool,Int8,Int8,Int8}, thread::UInt, + warncheckarg::Int, + safe::Bool, debug::Bool, ) for (k, op) ∈ enumerate(operations(ls)) @@ -725,6 +737,8 @@ function generate_call_split( k, inlineu₁u₂, thread, + warncheckarg, + safe, debug, ) end @@ -737,6 +751,8 @@ function generate_call_split( extra_args, inlineu₁u₂, thread, + warncheckarg, + safe, debug, ) end @@ -750,6 +766,8 @@ function generate_call_types( extra_args::Expr, (inline, u₁, u₂, v)::Tuple{Bool,Int8,Int8,Int8}, thread::UInt, + warncheckarg::Int, + safe::Bool, debug::Bool, ) # good place to check for split @@ -782,7 +800,7 @@ function generate_call_types( loop_syms = tuple_expr(QuoteNode, ls.loopsymbols) func = debug ? lv(:_turbo_loopset_debug) : lv(:_turbo_!) lbarg = debug ? Expr(:call, :typeof, loop_bounds) : loop_bounds - configarg = (inline, u₁, u₂, v, ls.isbroadcast, thread) + configarg = (inline, u₁, u₂, v, ls.isbroadcast, thread, warncheckarg, safe) unroll_param_tup = Expr(:call, lv(:avx_config_val), :(Val{$configarg}()), VECTORWIDTHSYMBOL) q = Expr( @@ -884,6 +902,39 @@ function check_args_call(ls::LoopSet) end q end +struct RetVec2Int end +(::RetVec2Int)(_) = Vec{2,Int} +""" + can_turbo(f::Function, ::Val{NARGS}) + +Check whether a given function with a specified number of arguments +can be used inside a `@turbo` loop. +""" +function can_turbo(f::F, ::Val{NARGS})::Bool where {F,NARGS} + promoted_op = Base.promote_op(f, ntuple(RetVec2Int(), Val(NARGS))...) + return promoted_op !== Union{} +end + +""" + check_turbo_safe(ls::LoopSet) + +Returns an expression of the form `true && can_turbo(op1) && can_turbo(op2) && ...` +""" +function check_turbo_safe(ls::LoopSet) + q = Expr(:&&, true) + last = q + for op in operations(ls) + iscompute(op) || continue + c = callexpr(op.instruction) + nargs = length(parents(op)) + push!(c.args, Val(nargs)) + pushfirst!(c.args, can_turbo) + new_last = Expr(:&&, c) + push!(last.args, new_last) + last = new_last + end + q +end make_fast(q) = Expr(:macrocall, Symbol("@fastmath"), LineNumberNode(@__LINE__, Symbol(@__FILE__)), q) @@ -956,7 +1007,7 @@ function setup_call_final(ls::LoopSet, q::Expr) return ls.preamble end function setup_call_debug(ls::LoopSet) - generate_call(ls, (false, zero(Int8), zero(Int8), zero(Int8)), zero(UInt), true) + generate_call(ls, (false, zero(Int8), zero(Int8), zero(Int8)), zero(UInt), 1, true, true) end function setup_call( ls::LoopSet, @@ -969,6 +1020,7 @@ function setup_call( v::Int8, thread::Int, warncheckarg::Int, + safe::Bool, ) # We outline/inline at the macro level by creating/not creating an anonymous function. # The old API instead was based on inlining or not inline the generated function, but @@ -977,7 +1029,7 @@ function setup_call( # inlining the generated function into the loop preamble. lnns = extract_all_lnns(q) pushfirst!(lnns, source) - call = generate_call(ls, (inline, u₁, u₂, v), thread % UInt, false) + call = generate_call(ls, (inline, u₁, u₂, v), thread % UInt, 1, true, false) call = check_empty ? check_if_empty(ls, call) : call argfailure = make_crashy(make_fast(q)) if warncheckarg ≠ 0 @@ -986,7 +1038,12 @@ function setup_call( warncheckarg > 0 && push!(warning.args, :(maxlog = $warncheckarg)) argfailure = Expr(:block, warning, argfailure) end - pushprepreamble!(ls, Expr(:if, check_args_call(ls), call, argfailure)) + call_check = if safe + Expr(:&&, check_args_call(ls), check_turbo_safe(ls)) + else + check_args_call(ls) + end + pushprepreamble!(ls, Expr(:if, call_check, call, argfailure)) prepend_lnns!(ls.prepreamble, lnns) return ls.prepreamble end diff --git a/src/constructors.jl b/src/constructors.jl index bac498992..fbf2eccb6 100644 --- a/src/constructors.jl +++ b/src/constructors.jl @@ -52,12 +52,13 @@ function substitute_broadcast( v::Int8, threads::Int, warncheckarg::Int, + safe::Bool, ) ci = first(Meta.lower(LoopVectorization, q).args).code nargs = length(ci) - 1 ex = Expr(:block) syms = [gensym() for _ ∈ 1:nargs] - configarg = (inline, u₁, u₂, v, true, threads, warncheckarg) + configarg = (inline, u₁, u₂, v, true, threads, warncheckarg, safe) unroll_param_tup = Expr(:call, lv(:avx_config_val), :(Val{$configarg}()), staticexpr(0)) for n ∈ 1:nargs ciₙ = ci[n] @@ -102,6 +103,7 @@ function check_macro_kwarg( v::Int8, threads::Int, warncheckarg::Int, + safe::Bool, ) ((arg.head === :(=)) && (length(arg.args) == 2)) || throw(ArgumentError("macro kwarg should be of the form `argname = value`.")) @@ -132,6 +134,8 @@ function check_macro_kwarg( end elseif kw === :warn_check_args warncheckarg = convert(Int, value)::Int + elseif kw === :safe + safe = convert(Bool, value) else throw( ArgumentError( @@ -139,7 +143,7 @@ function check_macro_kwarg( ), ) end - inline, check_empty, u₁, u₂, v, threads, warncheckarg + inline, check_empty, u₁, u₂, v, threads, warncheckarg, safe end function process_args( args; @@ -150,12 +154,13 @@ function process_args( v::Int8 = zero(Int8), threads::Int = 1, warncheckarg::Int = 1, + safe::Bool = false, ) for arg ∈ args - inline, check_empty, u₁, u₂, v, threads, warncheckarg = - check_macro_kwarg(arg, inline, check_empty, u₁, u₂, v, threads, warncheckarg) + inline, check_empty, u₁, u₂, v, threads, warncheckarg, safe = + check_macro_kwarg(arg, inline, check_empty, u₁, u₂, v, threads, warncheckarg, safe) end - inline, check_empty, u₁, u₂, v, threads, warncheckarg + inline, check_empty, u₁, u₂, v, threads, warncheckarg, safe end # check if the body of loop is a block, if not convert it to a block issue#395 # and check if the range of loop is an enumerate, if it is replace it, issue#393 @@ -225,12 +230,12 @@ function turbo_macro(mod, src, q, args...) q = macroexpand(mod, q) if q.head === :for ls = LoopSet(q, mod) - inline, check_empty, u₁, u₂, v, threads, warncheckarg = process_args(args) - esc(setup_call(ls, q, src, inline, check_empty, u₁, u₂, v, threads, warncheckarg)) + inline, check_empty, u₁, u₂, v, threads, warncheckarg, safe = process_args(args) + esc(setup_call(ls, q, src, inline, check_empty, u₁, u₂, v, threads, warncheckarg, safe)) else - inline, check_empty, u₁, u₂, v, threads, warncheckarg = + inline, check_empty, u₁, u₂, v, threads, warncheckarg, safe = process_args(args, inline = true) - substitute_broadcast(q, Symbol(mod), inline, u₁, u₂, v, threads, warncheckarg) + substitute_broadcast(q, Symbol(mod), inline, u₁, u₂, v, threads, warncheckarg, safe) end end """ @@ -367,7 +372,7 @@ macro _turbo(arg, q) @assert q.head === :for q = macroexpand(__module__, q) inline, check_empty, u₁, u₂, v = - check_macro_kwarg(arg, false, false, zero(Int8), zero(Int8), zero(Int8), 1, 0) + check_macro_kwarg(arg, false, false, zero(Int8), zero(Int8), zero(Int8), 1, 0, true) ls = LoopSet(q, __module__) set_hw!(ls) def_outer_reduct_types!(ls) diff --git a/src/modeling/graphs.jl b/src/modeling/graphs.jl index 92e5aea42..4cd73de94 100644 --- a/src/modeling/graphs.jl +++ b/src/modeling/graphs.jl @@ -1283,7 +1283,7 @@ function instruction!(ls::LoopSet, x::Expr) instr ∈ keys(COST) && return Instruction(:LoopVectorization, instr) # end instr = gensym!(ls, "f") - pushpreamble!(ls, Expr(:(=), instr, x)) + pushprepreamble!(ls, Expr(:(=), instr, x)) Instruction(Symbol(""), instr) end instruction!(ls::LoopSet, x::Symbol) = instruction(x) @@ -1481,7 +1481,7 @@ function add_operation!( add_comparison!(ls, LHS_sym, RHS, elementbytes, position) else throw(LoopError("Expression not recognized.", RHS)) - end + end end function prepare_rhs_for_storage!( diff --git a/src/reconstruct_loopset.jl b/src/reconstruct_loopset.jl index fab8f53ed..224cf372d 100644 --- a/src/reconstruct_loopset.jl +++ b/src/reconstruct_loopset.jl @@ -874,7 +874,7 @@ function avx_loopset!( end function avx_body( ls::LoopSet, - UNROLL::Tuple{Bool,Int8,Int8,Int8,Bool,Int,Int,Int,Int,UInt}, + UNROLL::Tuple{Bool,Int8,Int8,Int8,Bool,Int,Int,Int,Int,UInt,Int,Bool}, ) inline, u₁, u₂, v, isbroadcast, W, rs, rc, cls, nt = UNROLL q = @@ -916,7 +916,7 @@ function _turbo_loopset( @nospecialize(LPSYMsv), LBsv::Core.SimpleVector, vargs::Core.SimpleVector, - UNROLL::Tuple{Bool,Int8,Int8,Int8,Bool,Int,Int,Int,Int,UInt}, + UNROLL::Tuple{Bool,Int8,Int8,Int8,Bool,Int,Int,Int,Int,UInt,Int,Bool}, ) nops = length(OPSsv) ÷ 3 instr = Instruction[Instruction(OPSsv[3i+1], OPSsv[3i+2]) for i ∈ 0:nops-1] diff --git a/test/Project.toml b/test/Project.toml index e7c69465d..e57e06bb3 100644 --- a/test/Project.toml +++ b/test/Project.toml @@ -8,6 +8,7 @@ Pkg = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f" Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" SnoopCompileCore = "e2b509da-e806-4183-be48-004708413034" SnoopPrecompile = "66db9d55-30c0-4569-8b51-7e840670fc0c" +SpecialFunctions = "276daf66-3868-5448-9aa4-cd146d93841b" Static = "aedffcd0-7271-4cad-89d0-dc628f76c6d3" StrideArraysCore = "7792a7ef-975c-4747-a70f-980b88e8d1da" Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" diff --git a/test/can_avx.jl b/test/can_avx.jl index b72d498fe..1b1289453 100644 --- a/test/can_avx.jl +++ b/test/can_avx.jl @@ -3,18 +3,16 @@ @testset "can_avx" begin - @test LoopVectorization.ArrayInterface.can_avx(log) - @test LoopVectorization.ArrayInterface.can_avx(log1p) - @test LoopVectorization.ArrayInterface.can_avx(exp) - @test LoopVectorization.ArrayInterface.can_avx(+) - @test LoopVectorization.ArrayInterface.can_avx(-) - @test LoopVectorization.ArrayInterface.can_avx(Base.FastMath.add_fast) - @test LoopVectorization.ArrayInterface.can_avx(/) - @test LoopVectorization.ArrayInterface.can_avx(sqrt) - @test LoopVectorization.ArrayInterface.can_avx(tanh_fast) - @test LoopVectorization.ArrayInterface.can_avx(sigmoid_fast) - @test LoopVectorization.ArrayInterface.can_avx(LoopVectorization.relu) - @test !LoopVectorization.ArrayInterface.can_avx(clenshaw) - @test !LoopVectorization.ArrayInterface.can_avx(println) + using LoopVectorization + + good_operators = [log, log1p, exp, +, -, Base.FastMath.add_fast, /, sqrt, tanh_fast, sigmoid_fast, LoopVectorization.relu] + bad_operators = [clenshaw, println] + + for op in good_operators + @test LoopVectorization.ArrayInterface.can_avx(op) + end + for op in bad_operators + @test !LoopVectorization.ArrayInterface.can_avx(op) + end end diff --git a/test/grouptests.jl b/test/grouptests.jl index 7dc105022..74c66b161 100644 --- a/test/grouptests.jl +++ b/test/grouptests.jl @@ -22,6 +22,8 @@ const START_TIME = time() @time include("can_avx.jl") + @time include("safe_turbo.jl") + @time include("fallback.jl") @time include("utils.jl") diff --git a/test/safe_turbo.jl b/test/safe_turbo.jl new file mode 100644 index 000000000..84868d224 --- /dev/null +++ b/test/safe_turbo.jl @@ -0,0 +1,56 @@ +using LoopVectorization +using Test +import SpecialFunctions + +_f1(a) = SpecialFunctions.gamma(a) +_f2(a) = exp(a) +_f3(a, b) = a + SpecialFunctions.gamma(b) +_f4(a, b) = a + exp(b) +_f5(a, b) = a + SpecialFunctions.gamma(b) +_f6(a, b) = a + SpecialFunctions.gamma(b) + +@testset "Safe @turbo" begin + + + # All methods, both `can_avx` and `can_turbo`, should recognize that + # `gamma` is not AVX-able + + @test !LoopVectorization.ArrayInterface.can_avx(SpecialFunctions.gamma) + @test !LoopVectorization.can_turbo(SpecialFunctions.gamma, Val(1)) + @test !LoopVectorization.can_turbo(_f1, Val(1)) + + # `can_avx` is not able to detect that a function `f` which is just + # `gamma` can be AVX'd, but `can_turbo` can: + + @test LoopVectorization.ArrayInterface.can_avx(exp) + @test !LoopVectorization.ArrayInterface.can_avx(_f2) + @test LoopVectorization.can_turbo(exp, Val(1)) + @test LoopVectorization.can_turbo(_f2, Val(1)) + + # Next, we test with multiple arguments: + @test !LoopVectorization.can_turbo(_f3, Val(2)) + @test LoopVectorization.can_turbo(_f4, Val(2)) + + x = Float32.(1.05:0.1:10) + y = Float32.(0.55:0.1:10.5) + z = similar(x) + truth = similar(x) + + LoopVectorization.@turbo safe=true for i in indices(x) + z[i] = SpecialFunctions.gamma(x[i]) + end + for i in indices(x) + truth[i] = SpecialFunctions.gamma(x[i]) + end + @test z ≈ truth + + LoopVectorization.@turbo safe=true for i in indices(x) + z[i] = _f5(x[i], y[i]) + end + for i in indices(x) + truth[i] = _f6(x[i], y[i]) + end + @test z ≈ truth + +end +