Skip to content

Commit

Permalink
refactor the optimization flags set and utilities (#52998)
Browse files Browse the repository at this point in the history
  • Loading branch information
aviatesk committed Jan 22, 2024
1 parent 188cc93 commit 47d31ac
Show file tree
Hide file tree
Showing 8 changed files with 124 additions and 126 deletions.
2 changes: 1 addition & 1 deletion base/compiler/abstractinterpretation.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2722,7 +2722,7 @@ end
function abstract_eval_statement(interp::AbstractInterpreter, @nospecialize(e), vtypes::VarTable, sv::InferenceState)
if !isa(e, Expr)
if isa(e, PhiNode)
add_curr_ssaflag!(sv, IR_FLAG_EFFECT_FREE | IR_FLAG_NOTHROW)
add_curr_ssaflag!(sv, IR_FLAGS_REMOVABLE)
return RTEffects(abstract_eval_phi(interp, e, vtypes, sv), Union{}, EFFECTS_TOTAL)
end
(; rt, exct, effects) = abstract_eval_special_value(interp, e, vtypes, sv)
Expand Down
96 changes: 74 additions & 22 deletions base/compiler/optimize.jl
Original file line number Diff line number Diff line change
Expand Up @@ -23,32 +23,61 @@ const IR_FLAG_INBOUNDS = one(UInt32) << 0
const IR_FLAG_INLINE = one(UInt32) << 1
# This statement is marked as @noinline by user
const IR_FLAG_NOINLINE = one(UInt32) << 2
# This statement is on a code path that eventually `throw`s.
const IR_FLAG_THROW_BLOCK = one(UInt32) << 3
# This statement was proven :effect_free
const IR_FLAG_EFFECT_FREE = one(UInt32) << 4
# This statement was proven not to throw
const IR_FLAG_NOTHROW = one(UInt32) << 5
# This is :consistent
const IR_FLAG_CONSISTENT = one(UInt32) << 6
# An optimization pass has updated this statement in a way that may
# have exposed information that inference did not see. Re-running
# inference on this statement may be profitable.
const IR_FLAG_REFINED = one(UInt32) << 7
# This is :noub == ALWAYS_TRUE
const IR_FLAG_NOUB = one(UInt32) << 8

const IR_FLAG_REFINED = one(UInt32) << 4
# This statement is proven :consistent
const IR_FLAG_CONSISTENT = one(UInt32) << 5
# This statement is proven :effect_free
const IR_FLAG_EFFECT_FREE = one(UInt32) << 6
# This statement is proven :nothrow
const IR_FLAG_NOTHROW = one(UInt32) << 7
# This statement is proven :terminates
const IR_FLAG_TERMINATES = one(UInt32) << 8
# This statement is proven :noub
const IR_FLAG_NOUB = one(UInt32) << 9
# TODO: Both of these should eventually go away once
# This is :effect_free == EFFECT_FREE_IF_INACCESSIBLEMEMONLY
const IR_FLAG_EFIIMO = one(UInt32) << 9
# This is :inaccessiblememonly == INACCESSIBLEMEM_OR_ARGMEMONLY
const IR_FLAG_INACCESSIBLE_OR_ARGMEM = one(UInt32) << 10
# This statement is :effect_free == EFFECT_FREE_IF_INACCESSIBLEMEMONLY
const IR_FLAG_EFIIMO = one(UInt32) << 10
# This statement is :inaccessiblememonly == INACCESSIBLEMEM_OR_ARGMEMONLY
const IR_FLAG_INACCESSIBLEMEM_OR_ARGMEM = one(UInt32) << 11

const NUM_IR_FLAGS = 12 # sync with julia.h

const NUM_IR_FLAGS = 11 # sync with julia.h
const IR_FLAGS_EFFECTS =
IR_FLAG_CONSISTENT | IR_FLAG_EFFECT_FREE | IR_FLAG_NOTHROW | IR_FLAG_NOUB

const IR_FLAGS_EFFECTS = IR_FLAG_EFFECT_FREE | IR_FLAG_NOTHROW | IR_FLAG_CONSISTENT | IR_FLAG_NOUB
const IR_FLAGS_REMOVABLE = IR_FLAG_EFFECT_FREE | IR_FLAG_NOTHROW

const IR_FLAGS_NEEDS_EA = IR_FLAG_EFIIMO | IR_FLAG_INACCESSIBLEMEM_OR_ARGMEM

has_flag(curr::UInt32, flag::UInt32) = (curr & flag) == flag

function flags_for_effects(effects::Effects)
flags = zero(UInt32)
if is_consistent(effects)
flags |= IR_FLAG_CONSISTENT
end
if is_effect_free(effects)
flags |= IR_FLAG_EFFECT_FREE
elseif is_effect_free_if_inaccessiblememonly(effects)
flags |= IR_FLAG_EFIIMO
end
if is_nothrow(effects)
flags |= IR_FLAG_NOTHROW
end
if is_inaccessiblemem_or_argmemonly(effects)
flags |= IR_FLAG_INACCESSIBLEMEM_OR_ARGMEM
end
if is_noub(effects)
flags |= IR_FLAG_NOUB
end
return flags
end

const TOP_TUPLE = GlobalRef(Core, :tuple)

# This corresponds to the type of `CodeInfo`'s `inlining_cost` field
Expand Down Expand Up @@ -263,9 +292,9 @@ end

"""
stmt_effect_flags(stmt, rt, src::Union{IRCode,IncrementalCompact}) ->
(consistent::Bool, effect_free_and_nothrow::Bool, nothrow::Bool)
(consistent::Bool, removable::Bool, nothrow::Bool)
Returns a tuple of `(:consistent, :effect_free_and_nothrow, :nothrow)` flags for a given statement.
Returns a tuple of `(:consistent, :removable, :nothrow)` flags for a given statement.
"""
function stmt_effect_flags(𝕃ₒ::AbstractLattice, @nospecialize(stmt), @nospecialize(rt), src::Union{IRCode,IncrementalCompact})
# TODO: We're duplicating analysis from inference here.
Expand Down Expand Up @@ -309,7 +338,8 @@ function stmt_effect_flags(𝕃ₒ::AbstractLattice, @nospecialize(stmt), @nospe
consistent = is_consistent(effects)
effect_free = is_effect_free(effects)
nothrow = is_nothrow(effects)
return (consistent, effect_free & nothrow, nothrow)
removable = effect_free & nothrow
return (consistent, removable, nothrow)
elseif head === :new
return new_expr_effect_flags(𝕃ₒ, args, src)
elseif head === :foreigncall
Expand All @@ -319,7 +349,8 @@ function stmt_effect_flags(𝕃ₒ::AbstractLattice, @nospecialize(stmt), @nospe
consistent = is_consistent(effects)
effect_free = is_effect_free(effects)
nothrow = is_nothrow(effects)
return (consistent, effect_free & nothrow, nothrow)
removable = effect_free & nothrow
return (consistent, removable, nothrow)
elseif head === :new_opaque_closure
length(args) < 4 && return (false, false, false)
typ = argextype(args[1], src)
Expand All @@ -346,6 +377,29 @@ function stmt_effect_flags(𝕃ₒ::AbstractLattice, @nospecialize(stmt), @nospe
return (true, true, true)
end

function recompute_effects_flags(𝕃ₒ::AbstractLattice, @nospecialize(stmt), @nospecialize(rt),
src::Union{IRCode,IncrementalCompact})
flag = IR_FLAG_NULL
(consistent, removable, nothrow) = stmt_effect_flags(𝕃ₒ, stmt, rt, src)
if consistent
flag |= IR_FLAG_CONSISTENT
end
if removable
flag |= IR_FLAGS_REMOVABLE
elseif nothrow
flag |= IR_FLAG_NOTHROW
end
if !(isexpr(stmt, :call) || isexpr(stmt, :invoke))
# There is a bit of a subtle point here, which is that some non-call
# statements (e.g. PiNode) can be UB:, however, we consider it
# illegal to introduce such statements that actually cause UB (for any
# input). Ideally that'd be handled at insertion time (TODO), but for
# the time being just do that here.
flag |= IR_FLAG_NOUB
end
return flag
end

"""
argextype(x, src::Union{IRCode,IncrementalCompact}) -> t
argextype(x, src::CodeInfo, sptypes::Vector{VarState}) -> t
Expand Down Expand Up @@ -694,8 +748,6 @@ function is_conditional_noub(inst::Instruction, sv::PostOptAnalysisState)
return true
end

const IR_FLAGS_NEEDS_EA = IR_FLAG_EFIIMO | IR_FLAG_INACCESSIBLE_OR_ARGMEM

function scan_non_dataflow_flags!(inst::Instruction, sv::PostOptAnalysisState)
flag = inst[:flag]
# If we can prove that the argmem does not escape the current function, we can
Expand Down
12 changes: 6 additions & 6 deletions base/compiler/ssair/EscapeAnalysis/EscapeAnalysis.jl
Original file line number Diff line number Diff line change
Expand Up @@ -25,9 +25,9 @@ using ._TOP_MOD: # Base definitions
unwrap_unionall, !, !=, !==, &, *, +, -, :, <, <<, =>, >, |, , , , , , , ,
using Core.Compiler: # Core.Compiler specific definitions
Bottom, IRCode, IR_FLAG_NOTHROW, InferenceResult, SimpleInferenceLattice,
argextype, check_effect_free!, fieldcount_noerror, hasintersect, has_flag,
intrinsic_nothrow, is_meta_expr_head, isbitstype, isexpr, println, setfield!_nothrow,
singleton_type, try_compute_field, try_compute_fieldidx, widenconst, , AbstractLattice
argextype, fieldcount_noerror, hasintersect, has_flag, intrinsic_nothrow,
is_meta_expr_head, isbitstype, isexpr, println, setfield!_nothrow, singleton_type,
try_compute_field, try_compute_fieldidx, widenconst, , AbstractLattice

include(x) = _TOP_MOD.include(@__MODULE__, x)
if _TOP_MOD === Core.Compiler
Expand Down Expand Up @@ -597,12 +597,12 @@ struct LivenessChange <: Change
end
const Changes = Vector{Change}

struct AnalysisState{T, L <: AbstractLattice}
struct AnalysisState{GetEscapeCache, Lattice<:AbstractLattice}
ir::IRCode
estate::EscapeState
changes::Changes
𝕃ₒ::L
get_escape_cache::T
𝕃ₒ::Lattice
get_escape_cache::GetEscapeCache
end

"""
Expand Down
77 changes: 19 additions & 58 deletions base/compiler/ssair/inlining.jl
Original file line number Diff line number Diff line change
Expand Up @@ -366,7 +366,7 @@ function ir_prepare_inlining!(insert_node!::Inserter, inline_target::Union{IRCod
if !validate_sparams(mi.sparam_vals)
# N.B. This works on the caller-side argexprs, (i.e. before the va fixup below)
spvals_ssa = insert_node!(
effect_free_and_nothrow(NewInstruction(Expr(:call, Core._compute_sparams, def, argexprs...), SimpleVector, topline)))
removable_if_unused(NewInstruction(Expr(:call, Core._compute_sparams, def, argexprs...), SimpleVector, topline)))
end
if def.isva
nargs_def = Int(def.nargs::Int32)
Expand Down Expand Up @@ -425,7 +425,7 @@ function ir_inline_item!(compact::IncrementalCompact, idx::Int, argexprs::Vector
inline_compact.result[idx′][:type] =
argextype(val, isa(val, Argument) || isa(val, Expr) ? compact : inline_compact)
# Everything legal in value position is guaranteed to be effect free in stmt position
inline_compact.result[idx′][:flag] = IR_FLAG_EFFECT_FREE | IR_FLAG_NOTHROW
inline_compact.result[idx′][:flag] = IR_FLAGS_REMOVABLE
break
elseif isexpr(stmt′, :boundscheck)
adjust_boundscheck!(inline_compact, idx′, stmt′, boundscheck)
Expand Down Expand Up @@ -692,7 +692,7 @@ function batch_inline!(ir::IRCode, todo::Vector{Pair{Int,Any}}, propagate_inboun
for aidx in 1:length(argexprs)
aexpr = argexprs[aidx]
if isa(aexpr, Expr) || isa(aexpr, GlobalRef)
ninst = effect_free_and_nothrow(NewInstruction(aexpr, argextype(aexpr, compact), compact.result[idx][:line]))
ninst = removable_if_unused(NewInstruction(aexpr, argextype(aexpr, compact), compact.result[idx][:line]))
argexprs[aidx] = insert_node_here!(compact, ninst)
end
end
Expand Down Expand Up @@ -996,28 +996,6 @@ function retrieve_ir_for_inlining(::MethodInstance, ir::IRCode, preserve_local_s
return ir
end

function flags_for_effects(effects::Effects)
flags::UInt32 = 0
if is_consistent(effects)
flags |= IR_FLAG_CONSISTENT
end
if is_effect_free(effects)
flags |= IR_FLAG_EFFECT_FREE
elseif is_effect_free_if_inaccessiblememonly(effects)
flags |= IR_FLAG_EFIIMO
end
if is_inaccessiblemem_or_argmemonly(effects)
flags |= IR_FLAG_INACCESSIBLE_OR_ARGMEM
end
if is_nothrow(effects)
flags |= IR_FLAG_NOTHROW
end
if is_noub(effects)
flags |= IR_FLAG_NOUB
end
return flags
end

function handle_single_case!(todo::Vector{Pair{Int,Any}},
ir::IRCode, idx::Int, stmt::Expr, @nospecialize(case),
isinvoke::Bool = false)
Expand Down Expand Up @@ -1252,29 +1230,12 @@ end
# As a matter of convenience, this pass also computes effect-freenes.
# For primitives, we do that right here. For proper calls, we will
# discover this when we consult the caches.
function check_effect_free!(ir::IRCode, idx::Int, @nospecialize(stmt), @nospecialize(rt), state::InliningState)
return check_effect_free!(ir, idx, stmt, rt, optimizer_lattice(state.interp))
end
function check_effect_free!(ir::IRCode, idx::Int, @nospecialize(stmt), @nospecialize(rt), 𝕃ₒ::AbstractLattice)
(consistent, effect_free_and_nothrow, nothrow) = stmt_effect_flags(𝕃ₒ, stmt, rt, ir)
inst = ir.stmts[idx]
if consistent
add_flag!(inst, IR_FLAG_CONSISTENT)
end
if effect_free_and_nothrow
add_flag!(inst, IR_FLAG_EFFECT_FREE | IR_FLAG_NOTHROW)
elseif nothrow
add_flag!(inst, IR_FLAG_NOTHROW)
end
if !(isexpr(stmt, :call) || isexpr(stmt, :invoke))
# There is a bit of a subtle point here, which is that some non-call
# statements (e.g. PiNode) can be UB:, however, we consider it
# illegal to introduce such statements that actually cause UB (for any
# input). Ideally that'd be handled at insertion time (TODO), but for
# the time being just do that here.
add_flag!(inst, IR_FLAG_NOUB)
end
return effect_free_and_nothrow
add_inst_flag!(inst::Instruction, ir::IRCode, state::InliningState) =
add_inst_flag!(inst, ir, optimizer_lattice(state.interp))
function add_inst_flag!(inst::Instruction, ir::IRCode, 𝕃ₒ::AbstractLattice)
flags = recompute_effects_flags(𝕃ₒ, inst[:stmt], inst[:type], ir)
add_flag!(inst, flags)
return !iszero(flags & IR_FLAGS_REMOVABLE)
end

# Handles all analysis and inlining of intrinsics and builtins. In particular,
Expand All @@ -1283,11 +1244,11 @@ end
function process_simple!(todo::Vector{Pair{Int,Any}}, ir::IRCode, idx::Int, state::InliningState)
inst = ir[SSAValue(idx)]
stmt = inst[:stmt]
rt = inst[:type]
if !(stmt isa Expr)
check_effect_free!(ir, idx, stmt, rt, state)
add_inst_flag!(inst, ir, state)
return nothing
end
rt = inst[:type]
head = stmt.head
if head !== :call
if head === :splatnew
Expand All @@ -1299,7 +1260,7 @@ function process_simple!(todo::Vector{Pair{Int,Any}}, ir::IRCode, idx::Int, stat
sig === nothing && return nothing
return stmt, sig
end
check_effect_free!(ir, idx, stmt, rt, state)
add_inst_flag!(inst, ir, state)
return nothing
end

Expand All @@ -1317,7 +1278,7 @@ function process_simple!(todo::Vector{Pair{Int,Any}}, ir::IRCode, idx::Int, stat
return nothing
end

if check_effect_free!(ir, idx, stmt, rt, state)
if add_inst_flag!(inst, ir, state)
if sig.f === typeassert || (optimizer_lattice(state.interp), sig.ft, typeof(typeassert))
# typeassert is a no-op if effect free
inst[:stmt] = stmt.args[2]
Expand All @@ -1335,7 +1296,7 @@ function process_simple!(todo::Vector{Pair{Int,Any}}, ir::IRCode, idx::Int, stat
lateres = late_inline_special_case!(ir, idx, stmt, rt, sig, state)
if isa(lateres, SomeCase)
inst[:stmt] = lateres.val
check_effect_free!(ir, idx, lateres.val, rt, state)
add_inst_flag!(inst, ir, state)
return nothing
end

Expand Down Expand Up @@ -1683,7 +1644,7 @@ function inline_const_if_inlineable!(inst::Instruction)
inst[:stmt] = quoted(rt.val)
return true
end
add_flag!(inst, IR_FLAG_EFFECT_FREE | IR_FLAG_NOTHROW)
add_flag!(inst, IR_FLAGS_REMOVABLE)
return false
end

Expand Down Expand Up @@ -1808,7 +1769,7 @@ function late_inline_special_case!(
return SomeCase(quoted(type.val))
end
cmp_call = Expr(:call, GlobalRef(Core, :(===)), stmt.args[2], stmt.args[3])
cmp_call_ssa = insert_node!(ir, idx, effect_free_and_nothrow(NewInstruction(cmp_call, Bool)))
cmp_call_ssa = insert_node!(ir, idx, removable_if_unused(NewInstruction(cmp_call, Bool)))
not_call = Expr(:call, GlobalRef(Core.Intrinsics, :not_int), cmp_call_ssa)
return SomeCase(not_call)
elseif length(argtypes) == 3 && istopfunction(f, :(>:))
Expand Down Expand Up @@ -1853,13 +1814,13 @@ end

function insert_spval!(insert_node!::Inserter, spvals_ssa::SSAValue, spidx::Int, do_isdefined::Bool)
ret = insert_node!(
effect_free_and_nothrow(NewInstruction(Expr(:call, Core._svec_ref, spvals_ssa, spidx), Any)))
removable_if_unused(NewInstruction(Expr(:call, Core._svec_ref, spvals_ssa, spidx), Any)))
tcheck_not = nothing
if do_isdefined
tcheck = insert_node!(
effect_free_and_nothrow(NewInstruction(Expr(:call, Core.isa, ret, Core.TypeVar), Bool)))
removable_if_unused(NewInstruction(Expr(:call, Core.isa, ret, Core.TypeVar), Bool)))
tcheck_not = insert_node!(
effect_free_and_nothrow(NewInstruction(Expr(:call, not_int, tcheck), Bool)))
removable_if_unused(NewInstruction(Expr(:call, not_int, tcheck), Bool)))
end
return (ret, tcheck_not)
end
Expand Down
Loading

0 comments on commit 47d31ac

Please sign in to comment.