Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Optimize away try/catch blocks that are known not to trigger #51674

Merged
merged 2 commits into from
Oct 15, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
41 changes: 21 additions & 20 deletions base/compiler/abstractinterpretation.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2981,6 +2981,20 @@ function update_bestguess!(interp::AbstractInterpreter, frame::InferenceState,
end
end

function propagate_to_error_handler!(frame::InferenceState, currpc::Int, W::BitSet, 𝕃ᵢ::AbstractLattice, currstate::VarTable)
# If this statement potentially threw, propagate the currstate to the
# exception handler, BEFORE applying any state changes.
cur_hand = frame.handler_at[currpc]
if cur_hand != 0
enter = frame.src.code[cur_hand]::Expr
l = enter.args[1]::Int
exceptbb = block_for_inst(frame.cfg, l)
if update_bbstate!(𝕃ᵢ, frame, exceptbb, currstate)
push!(W, exceptbb)
end
end
end

# make as much progress on `frame` as possible (without handling cycles)
function typeinf_local(interp::AbstractInterpreter, frame::InferenceState)
@assert !is_inferred(frame)
Expand Down Expand Up @@ -3037,6 +3051,7 @@ function typeinf_local(interp::AbstractInterpreter, frame::InferenceState)
if nothrow
add_curr_ssaflag!(frame, IR_FLAG_NOTHROW)
else
propagate_to_error_handler!(frame, currpc, W, 𝕃ᵢ, currstate)
merge_effects!(interp, frame, EFFECTS_THROWS)
end

Expand Down Expand Up @@ -3107,12 +3122,9 @@ function typeinf_local(interp::AbstractInterpreter, frame::InferenceState)
ssavaluetypes[frame.currpc] = Any
@goto find_next_bb
elseif isexpr(stmt, :enter)
# Propagate entry info to exception handler
l = stmt.args[1]::Int
catchbb = block_for_inst(frame.cfg, l)
if update_bbstate!(𝕃ᵢ, frame, catchbb, currstate)
push!(W, catchbb)
end
ssavaluetypes[currpc] = Any
@goto fallthrough
elseif isexpr(stmt, :leave)
ssavaluetypes[currpc] = Any
@goto fallthrough
end
Expand All @@ -3121,26 +3133,15 @@ function typeinf_local(interp::AbstractInterpreter, frame::InferenceState)
# Process non control-flow statements
(; changes, type) = abstract_eval_basic_statement(interp,
stmt, currstate, frame)
if (get_curr_ssaflag(frame) & IR_FLAG_NOTHROW) != IR_FLAG_NOTHROW
propagate_to_error_handler!(frame, currpc, W, 𝕃ᵢ, currstate)
end
if type === Bottom
ssavaluetypes[currpc] = Bottom
@goto find_next_bb
end
if changes !== nothing
stoverwrite1!(currstate, changes)
let cur_hand = frame.handler_at[currpc], l, enter
while cur_hand != 0
enter = frame.src.code[cur_hand]::Expr
l = enter.args[1]::Int
exceptbb = block_for_inst(frame.cfg, l)
# propagate new type info to exception handler
# the handling for Expr(:enter) propagates all changes from before the try/catch
# so this only needs to propagate any changes
if stupdate1!(𝕃ᵢ, states[exceptbb]::VarTable, changes)
Keno marked this conversation as resolved.
Show resolved Hide resolved
push!(W, exceptbb)
end
cur_hand = frame.handler_at[cur_hand]
end
end
end
if type === nothing
ssavaluetypes[currpc] = Any
Expand Down
55 changes: 34 additions & 21 deletions base/compiler/optimize.jl
Original file line number Diff line number Diff line change
Expand Up @@ -839,27 +839,35 @@ function convert_to_ircode(ci::CodeInfo, sv::OptimizationState)
code = copy_exprargs(ci.code)
for i = 1:length(code)
expr = code[i]
if !(i in sv.unreachable) && isa(expr, GotoIfNot)
# Replace this live GotoIfNot with:
# - no-op if :nothrow and the branch target is unreachable
# - cond if :nothrow and both targets are unreachable
# - typeassert if must-throw
block = block_for_inst(sv.cfg, i)
if ssavaluetypes[i] === Bottom
destblock = block_for_inst(sv.cfg, expr.dest)
cfg_delete_edge!(sv.cfg, block, block + 1)
((block + 1) != destblock) && cfg_delete_edge!(sv.cfg, block, destblock)
expr = Expr(:call, Core.typeassert, expr.cond, Bool)
elseif i + 1 in sv.unreachable
@assert (ci.ssaflags[i] & IR_FLAG_NOTHROW) != 0
cfg_delete_edge!(sv.cfg, block, block + 1)
expr = GotoNode(expr.dest)
elseif expr.dest in sv.unreachable
@assert (ci.ssaflags[i] & IR_FLAG_NOTHROW) != 0
cfg_delete_edge!(sv.cfg, block, block_for_inst(sv.cfg, expr.dest))
expr = nothing
if !(i in sv.unreachable)
if isa(expr, GotoIfNot)
# Replace this live GotoIfNot with:
# - no-op if :nothrow and the branch target is unreachable
# - cond if :nothrow and both targets are unreachable
# - typeassert if must-throw
block = block_for_inst(sv.cfg, i)
if ssavaluetypes[i] === Bottom
destblock = block_for_inst(sv.cfg, expr.dest)
cfg_delete_edge!(sv.cfg, block, block + 1)
((block + 1) != destblock) && cfg_delete_edge!(sv.cfg, block, destblock)
expr = Expr(:call, Core.typeassert, expr.cond, Bool)
elseif i + 1 in sv.unreachable
@assert (ci.ssaflags[i] & IR_FLAG_NOTHROW) != 0
cfg_delete_edge!(sv.cfg, block, block + 1)
expr = GotoNode(expr.dest)
elseif expr.dest in sv.unreachable
@assert (ci.ssaflags[i] & IR_FLAG_NOTHROW) != 0
cfg_delete_edge!(sv.cfg, block, block_for_inst(sv.cfg, expr.dest))
expr = nothing
end
code[i] = expr
elseif isexpr(expr, :enter)
catchdest = expr.args[1]::Int
if catchdest in sv.unreachable
cfg_delete_edge!(sv.cfg, block_for_inst(sv.cfg, i), block_for_inst(sv.cfg, catchdest))
code[i] = nothing
end
end
code[i] = expr
end
end

Expand Down Expand Up @@ -1239,7 +1247,12 @@ function renumber_ir_elements!(body::Vector{Any}, ssachangemap::Vector{Int}, lab
end
if el.head === :enter
tgt = el.args[1]::Int
el.args[1] = tgt + labelchangemap[tgt]
was_deleted = labelchangemap[tgt] == typemin(Int)
if was_deleted
body[i] = nothing
else
el.args[1] = tgt + labelchangemap[tgt]
end
elseif !is_meta_expr_head(el.head)
args = el.args
for i = 1:length(args)
Expand Down
15 changes: 15 additions & 0 deletions base/compiler/ssair/ir.jl
Original file line number Diff line number Diff line change
Expand Up @@ -1364,6 +1364,21 @@ function process_node!(compact::IncrementalCompact, result_idx::Int, inst::Instr
ssa_rename[idx] = nothing
return result_idx
end
elseif isexpr(stmt, :leave)
let i = 1
while i <= length(stmt.args)
if stmt.args[i] === nothing
deleteat!(stmt.args, i)
else
i += 1
end
end
end
if isempty(stmt.args)
# This :leave is dead
ssa_rename[idx] = nothing
return result_idx
end
end
typ = inst[:type]
if isa(typ, Const) && is_inlineable_constant(typ.val)
Expand Down
3 changes: 2 additions & 1 deletion base/compiler/ssair/slot2ssa.jl
Original file line number Diff line number Diff line change
Expand Up @@ -875,7 +875,8 @@ function construct_ssa!(ci::CodeInfo, ir::IRCode, sv::OptimizationState,
new_code[idx] = GotoIfNot(stmt.cond, new_dest)
end
elseif isexpr(stmt, :enter)
new_code[idx] = Expr(:enter, block_for_inst(cfg, stmt.args[1]::Int))
except_bb = block_for_inst(cfg, stmt.args[1]::Int)
new_code[idx] = Expr(:enter, except_bb)
ssavalmap[idx] = SSAValue(idx) # Slot to store token for pop_exception
elseif isexpr(stmt, :leave) || isexpr(stmt, :(=)) || isa(stmt, ReturnNode) ||
isexpr(stmt, :meta) || isa(stmt, NewvarNode)
Expand Down
18 changes: 0 additions & 18 deletions base/compiler/typelattice.jl
Original file line number Diff line number Diff line change
Expand Up @@ -759,24 +759,6 @@ function stupdate!(lattice::AbstractLattice, state::VarTable, changes::VarTable)
return changed
end

function stupdate1!(lattice::AbstractLattice, state::VarTable, change::StateUpdate)
changeid = slot_id(change.var)
for i = 1:length(state)
invalidated = invalidate_slotwrapper(state[i], changeid, change.conditional)
if invalidated !== nothing
state[i] = invalidated
end
end
# and update the type of it
newtype = change.vtype
oldtype = state[changeid]
if schanged(lattice, newtype, oldtype)
state[changeid] = smerge(lattice, oldtype, newtype)
return true
end
return false
end

function stoverwrite!(state::VarTable, newstate::VarTable)
for i = 1:length(state)
state[i] = newstate[i]
Expand Down
11 changes: 8 additions & 3 deletions src/codegen.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -5180,14 +5180,19 @@ static void emit_ssaval_assign(jl_codectx_t &ctx, ssize_t ssaidx_0based, jl_valu
ctx.ssavalue_assigned[ssaidx_0based] = true;
}

static void emit_varinfo_assign(jl_codectx_t &ctx, jl_varinfo_t &vi, jl_cgval_t rval_info, jl_value_t *l=NULL)
static void emit_varinfo_assign(jl_codectx_t &ctx, jl_varinfo_t &vi, jl_cgval_t rval_info, jl_value_t *l=NULL, bool allow_mismatch=false)
{
if (!vi.used || vi.value.typ == jl_bottom_type)
return;

// convert rval-type to lval-type
jl_value_t *slot_type = vi.value.typ;
rval_info = convert_julia_type(ctx, rval_info, slot_type);
// If allow_mismatch is set, type mismatches will not result in traps.
// This is used for upsilon nodes, where the destination can have a narrower
// type than the store, if inference determines that the store is never read.
Value *dummy = NULL;
Value **skip = allow_mismatch ? &dummy : NULL;
rval_info = convert_julia_type(ctx, rval_info, slot_type, skip);
if (rval_info.typ == jl_bottom_type)
return;

Expand Down Expand Up @@ -5284,7 +5289,7 @@ static void emit_upsilonnode(jl_codectx_t &ctx, ssize_t phic, jl_value_t *val)
// was unreachable and dead
val = NULL;
else
emit_varinfo_assign(ctx, vi, rval_info);
emit_varinfo_assign(ctx, vi, rval_info, NULL, true);
}
if (!val) {
if (vi.boxroot) {
Expand Down
69 changes: 69 additions & 0 deletions test/compiler/inference.jl
Original file line number Diff line number Diff line change
Expand Up @@ -5284,3 +5284,72 @@ end
@test only(Base.return_types((x,f) -> getfield(x, f), (An51317, Symbol))) === Int
@test only(Base.return_types(x -> getfield(x, :b), (A51317,))) === Union{}
@test only(Base.return_types(x -> getfield(x, :b), (An51317,))) === Union{}

# Don't visit the catch block for empty try/catch
function completely_dead_try_catch()
try
catch
return 2.0
end
return 1
end
@test Base.return_types(completely_dead_try_catch) |> only === Int
@test fully_eliminated(completely_dead_try_catch)

function nothrow_try_catch()
try
1+1
catch
return 2.0
end
return 1
end
@test Base.return_types(nothrow_try_catch) |> only === Int
@test fully_eliminated(nothrow_try_catch)

may_error(b) = Base.inferencebarrier(b) && error()
function phic_type1()
a = 1
try
may_error(false)
a = 1.0
catch
return a
end
return 2
end
@test Base.return_types(phic_type1) |> only === Int
@test phic_type1() === 2

function phic_type2()
a = 1
try
may_error(false)
a = 1.0
may_error(false)
catch
return a
end
return 2
end
@test Base.return_types(phic_type2) |> only === Union{Int, Float64}
@test phic_type2() === 2

function phic_type3()
a = 1
try
may_error(false)
a = 1.0
may_error(false)
if Base.inferencebarrier(false)
a = Ref(1)
elseif Base.inferencebarrier(false)
a = nothing
end
catch
return a
end
return 2
end
@test Base.return_types(phic_type3) |> only === Union{Int, Float64}
@test phic_type3() === 2