Skip to content

Commit

Permalink
inference: Model type propagation through exceptions (#51754)
Browse files Browse the repository at this point in the history
Currently the type of a caught exception is always modeled as `Any`.
This isn't a huge problem, because control flow in Julia is generally
assumed to be somewhat slow, so the extra type imprecision of not
knowing the return type does not matter all that much. However, there
are a few situations where it matters. For example:

```
maybe_getindex(A, i) =
    try; A[i]; catch e; isa(e, BoundsError) && return nothing; rethrow(); end
```

At present, we cannot infer :nothrow for this method, even if that is
the only error type that `A[i]` can throw. This is particularly
noticable, since we can now optimize away `:nothrow` exception frames
entirely (#51674). Note that this PR still does not make the above
example particularly efficient (at least interprocedurally), though
specialized codegen could be added on top of this to make that happen.
It does however improve the inference result.

A second major motivation of this change is that reasoning about
exception types is likely to be a major aspect of any future work on
interface checking (since interfaces imply the absence of MethodErrors),
so this PR lays the groundwork for appropriate modeling of these error
paths.

Note that this PR adds all the required plumbing, but does not yet have
a particularly precise model of error types for our builtins, bailing to
`Any` for any builtin not known to be `:nothrow`. This can be improved
in follow up PRs as required.
  • Loading branch information
Keno committed Nov 19, 2023
1 parent f5d189f commit c8ca350
Show file tree
Hide file tree
Showing 26 changed files with 517 additions and 244 deletions.
6 changes: 3 additions & 3 deletions base/boot.jl
Original file line number Diff line number Diff line change
Expand Up @@ -479,13 +479,13 @@ eval(Core, quote
end)

function CodeInstance(
mi::MethodInstance, @nospecialize(rettype), @nospecialize(inferred_const),
mi::MethodInstance, @nospecialize(rettype), @nospecialize(exctype), @nospecialize(inferred_const),
@nospecialize(inferred), const_flags::Int32, min_world::UInt, max_world::UInt,
ipo_effects::UInt32, effects::UInt32, @nospecialize(analysis_results),
relocatability::UInt8)
return ccall(:jl_new_codeinst, Ref{CodeInstance},
(Any, Any, Any, Any, Int32, UInt, UInt, UInt32, UInt32, Any, UInt8),
mi, rettype, inferred_const, inferred, const_flags, min_world, max_world,
(Any, Any, Any, Any, Any, Int32, UInt, UInt, UInt32, UInt32, Any, UInt8),
mi, rettype, exctype, inferred_const, inferred, const_flags, min_world, max_world,
ipo_effects, effects, analysis_results,
relocatability)
end
Expand Down
309 changes: 202 additions & 107 deletions base/compiler/abstractinterpretation.jl

Large diffs are not rendered by default.

62 changes: 62 additions & 0 deletions base/compiler/effects.jl
Original file line number Diff line number Diff line change
Expand Up @@ -171,6 +171,68 @@ function Effects(effects::Effects = _EFFECTS_UNKNOWN;
nonoverlayed)
end

function is_better_effects(new::Effects, old::Effects)
any_improved = false
if new.consistent == ALWAYS_TRUE
any_improved |= old.consistent != ALWAYS_TRUE
else
if !iszero(new.consistent & CONSISTENT_IF_NOTRETURNED)
old.consistent == ALWAYS_TRUE && return false
any_improved |= iszero(old.consistent & CONSISTENT_IF_NOTRETURNED)
elseif !iszero(new.consistent & CONSISTENT_IF_INACCESSIBLEMEMONLY)
old.consistent == ALWAYS_TRUE && return false
any_improved |= iszero(old.consistent & CONSISTENT_IF_INACCESSIBLEMEMONLY)
else
return false
end
end
if new.effect_free == ALWAYS_TRUE
any_improved |= old.consistent != ALWAYS_TRUE
elseif new.effect_free == EFFECT_FREE_IF_INACCESSIBLEMEMONLY
old.effect_free == ALWAYS_TRUE && return false
any_improved |= old.effect_free != EFFECT_FREE_IF_INACCESSIBLEMEMONLY
elseif new.effect_free != old.effect_free
return false
end
if new.nothrow
any_improved |= !old.nothrow
elseif new.nothrow != old.nothrow
return false
end
if new.terminates
any_improved |= !old.terminates
elseif new.terminates != old.terminates
return false
end
if new.notaskstate
any_improved |= !old.notaskstate
elseif new.notaskstate != old.notaskstate
return false
end
if new.inaccessiblememonly == ALWAYS_TRUE
any_improved |= old.inaccessiblememonly != ALWAYS_TRUE
elseif new.inaccessiblememonly == INACCESSIBLEMEM_OR_ARGMEMONLY
old.inaccessiblememonly == ALWAYS_TRUE && return false
any_improved |= old.inaccessiblememonly != INACCESSIBLEMEM_OR_ARGMEMONLY
elseif new.inaccessiblememonly != old.inaccessiblememonly
return false
end
if new.noub == ALWAYS_TRUE
any_improved |= old.noub != ALWAYS_TRUE
elseif new.noub == NOUB_IF_NOINBOUNDS
old.noub == ALWAYS_TRUE && return false
any_improved |= old.noub != NOUB_IF_NOINBOUNDS
elseif new.noub != old.noub
return false
end
if new.nonoverlayed
any_improved |= !old.nonoverlayed
elseif new.nonoverlayed != old.nonoverlayed
return false
end
return any_improved
end

function merge_effects(old::Effects, new::Effects)
return Effects(
merge_effectbits(old.consistent, new.consistent),
Expand Down
59 changes: 37 additions & 22 deletions base/compiler/inferencestate.jl
Original file line number Diff line number Diff line change
Expand Up @@ -203,6 +203,12 @@ const CACHE_MODE_GLOBAL = 0x01 << 0 # cached globally, optimization allowed
const CACHE_MODE_LOCAL = 0x01 << 1 # cached locally, optimization allowed
const CACHE_MODE_VOLATILE = 0x01 << 2 # not cached, optimization allowed

mutable struct TryCatchFrame
exct
const enter_idx::Int
TryCatchFrame(@nospecialize(exct), enter_idx::Int) = new(exct, enter_idx)
end

mutable struct InferenceState
#= information about this method instance =#
linfo::MethodInstance
Expand All @@ -218,7 +224,8 @@ mutable struct InferenceState
currbb::Int
currpc::Int
ip::BitSet#=TODO BoundedMinPrioritySet=# # current active instruction pointers
handler_at::Vector{Int} # current exception handler info
handlers::Vector{TryCatchFrame}
handler_at::Vector{Tuple{Int, Int}} # tuple of current (handler, exception stack) value at the pc
ssavalue_uses::Vector{BitSet} # ssavalue sparsity and restart info
# TODO: Could keep this sparsely by doing structural liveness analysis ahead of time.
bb_vartables::Vector{Union{Nothing,VarTable}} # nothing if not analyzed yet
Expand All @@ -239,6 +246,7 @@ mutable struct InferenceState
unreachable::BitSet # statements that were found to be statically unreachable
valid_worlds::WorldRange
bestguess #::Type
exc_bestguess
ipo_effects::Effects

#= flags =#
Expand Down Expand Up @@ -266,7 +274,7 @@ mutable struct InferenceState

currbb = currpc = 1
ip = BitSet(1) # TODO BitSetBoundedMinPrioritySet(1)
handler_at = compute_trycatch(code, BitSet())
handler_at, handlers = compute_trycatch(code, BitSet())
nssavalues = src.ssavaluetypes::Int
ssavalue_uses = find_ssavalue_uses(code, nssavalues)
nstmts = length(code)
Expand Down Expand Up @@ -296,6 +304,7 @@ mutable struct InferenceState

valid_worlds = WorldRange(src.min_world, src.max_world == typemax(UInt) ? get_world_counter() : src.max_world)
bestguess = Bottom
exc_bestguess = Bottom
ipo_effects = EFFECTS_TOTAL

insert_coverage = should_insert_coverage(mod, src)
Expand All @@ -315,9 +324,9 @@ mutable struct InferenceState

return new(
linfo, world, mod, sptypes, slottypes, src, cfg, method_info,
currbb, currpc, ip, handler_at, ssavalue_uses, bb_vartables, ssavaluetypes, stmt_edges, stmt_info,
currbb, currpc, ip, handlers, handler_at, ssavalue_uses, bb_vartables, ssavaluetypes, stmt_edges, stmt_info,
pclimitations, limitations, cycle_backedges, callers_in_cycle, dont_work_on_me, parent,
result, unreachable, valid_worlds, bestguess, ipo_effects,
result, unreachable, valid_worlds, bestguess, exc_bestguess, ipo_effects,
restrict_abstract_call_sites, cache_mode, insert_coverage,
interp)
end
Expand Down Expand Up @@ -347,16 +356,19 @@ function compute_trycatch(code::Vector{Any}, ip::BitSet)
empty!(ip)
ip.offset = 0 # for _bits_findnext
push!(ip, n + 1)
handler_at = fill(0, n)
handler_at = fill((0, 0), n)
handlers = TryCatchFrame[]

# start from all :enter statements and record the location of the try
for pc = 1:n
stmt = code[pc]
if isexpr(stmt, :enter)
l = stmt.args[1]::Int
handler_at[pc + 1] = pc
push!(handlers, TryCatchFrame(Bottom, pc))
handler_id = length(handlers)
handler_at[pc + 1] = (handler_id, 0)
push!(ip, pc + 1)
handler_at[l] = pc
handler_at[l] = (handler_id, handler_id)
push!(ip, l)
end
end
Expand All @@ -369,25 +381,26 @@ function compute_trycatch(code::Vector{Any}, ip::BitSet)
while true # inner loop optimizes the common case where it can run straight from pc to pc + 1
pc´ = pc + 1 # next program-counter (after executing instruction)
delete!(ip, pc)
cur_hand = handler_at[pc]
@assert cur_hand != 0 "unbalanced try/catch"
cur_stacks = handler_at[pc]
@assert cur_stacks != (0, 0) "unbalanced try/catch"
stmt = code[pc]
if isa(stmt, GotoNode)
pc´ = stmt.label
elseif isa(stmt, GotoIfNot)
l = stmt.dest::Int
if handler_at[l] != cur_hand
@assert handler_at[l] == 0 "unbalanced try/catch"
handler_at[l] = cur_hand
if handler_at[l] != cur_stacks
@assert handler_at[l][1] == 0 || handler_at[l][1] == cur_stacks[1] "unbalanced try/catch"
handler_at[l] = cur_stacks
push!(ip, l)
end
elseif isa(stmt, ReturnNode)
@assert !isdefined(stmt, :val) "unbalanced try/catch"
@assert !isdefined(stmt, :val) || cur_stacks[1] == 0 "unbalanced try/catch"
break
elseif isa(stmt, Expr)
head = stmt.head
if head === :enter
cur_hand = pc
# Already set above
cur_stacks = (handler_at[pc´][1], cur_stacks[2])
elseif head === :leave
l = 0
for j = 1:length(stmt.args)
Expand All @@ -403,19 +416,21 @@ function compute_trycatch(code::Vector{Any}, ip::BitSet)
end
l += 1
end
cur_hand = cur_stacks[1]
for i = 1:l
cur_hand = handler_at[cur_hand]
cur_hand = handler_at[handlers[cur_hand].enter_idx][1]
end
cur_hand == 0 && break
cur_stacks = (cur_hand, cur_stacks[2])
cur_stacks == (0, 0) && break
elseif head === :pop_exception
cur_stacks = (cur_stacks[1], handler_at[(stmt.args[1]::SSAValue).id][2])
cur_stacks == (0, 0) && break
end
end

pc´ > n && break # can't proceed with the fast-path fall-through
if handler_at[pc´] != cur_hand
if handler_at[pc´] != 0
@assert false "unbalanced try/catch"
end
handler_at[pc´] = cur_hand
if handler_at[pc´] != cur_stacks
handler_at[pc´] = cur_stacks
elseif !in(pc´, ip)
break # already visited
end
Expand All @@ -424,7 +439,7 @@ function compute_trycatch(code::Vector{Any}, ip::BitSet)
end

@assert first(ip) == n + 1
return handler_at
return handler_at, handlers
end

# check if coverage mode is enabled
Expand Down
6 changes: 4 additions & 2 deletions base/compiler/optimize.jl
Original file line number Diff line number Diff line change
Expand Up @@ -925,8 +925,10 @@ function run_passes_ipo_safe(
# @timeit "verify 2" verify_ir(ir)
@pass "compact 2" ir = compact!(ir)
@pass "SROA" ir = sroa_pass!(ir, sv.inlining)
@pass "ADCE" ir = adce_pass!(ir, sv.inlining)
@pass "compact 3" ir = compact!(ir, true)
@pass "ADCE" (ir, made_changes) = adce_pass!(ir, sv.inlining)
if made_changes
@pass "compact 3" ir = compact!(ir, true)
end
if JLOptions().debug_level == 2
@timeit "verify 3" (verify_ir(ir, true, false, optimizer_lattice(sv.inlining.interp)); verify_linetable(ir.linetable))
end
Expand Down
36 changes: 27 additions & 9 deletions base/compiler/ssair/ir.jl
Original file line number Diff line number Diff line change
Expand Up @@ -587,6 +587,7 @@ struct CFGTransformState
result_bbs::Vector{BasicBlock}
bb_rename_pred::Vector{Int}
bb_rename_succ::Vector{Int}
domtree::Union{Nothing, DomTree}
end

# N.B.: Takes ownership of the CFG array
Expand Down Expand Up @@ -622,11 +623,14 @@ function CFGTransformState!(blocks::Vector{BasicBlock}, allow_cfg_transforms::Bo
let blocks = blocks, bb_rename = bb_rename
result_bbs = BasicBlock[blocks[i] for i = 1:length(blocks) if bb_rename[i] != -1]
end
# TODO: This could be done by just renaming the domtree
domtree = construct_domtree(result_bbs)
else
bb_rename = Vector{Int}()
result_bbs = blocks
domtree = nothing
end
return CFGTransformState(allow_cfg_transforms, allow_cfg_transforms, result_bbs, bb_rename, bb_rename)
return CFGTransformState(allow_cfg_transforms, allow_cfg_transforms, result_bbs, bb_rename, bb_rename, domtree)
end

mutable struct IncrementalCompact
Expand Down Expand Up @@ -681,7 +685,7 @@ mutable struct IncrementalCompact
bb_rename = Vector{Int}()
pending_nodes = NewNodeStream()
pending_perm = Int[]
return new(code, parent.result, CFGTransformState(false, false, parent.cfg_transform.result_bbs, bb_rename, bb_rename),
return new(code, parent.result, CFGTransformState(false, false, parent.cfg_transform.result_bbs, bb_rename, bb_rename, nothing),
ssa_rename, parent.used_ssas,
parent.late_fixup, perm, 1,
parent.new_new_nodes, parent.new_new_used_ssas, pending_nodes, pending_perm,
Expand Down Expand Up @@ -942,6 +946,14 @@ function insert_node_here!(compact::IncrementalCompact, newinst::NewInstruction,
return inst
end

function delete_inst_here!(compact)
# Delete the statement, update refcounts etc
compact[SSAValue(compact.result_idx-1)] = nothing
# Pretend that we never compacted this statement in the first place
compact.result_idx -= 1
return nothing
end

function getindex(view::TypesView, v::OldSSAValue)
id = v.id
ir = view.ir.ir
Expand Down Expand Up @@ -1222,19 +1234,25 @@ end

# N.B.: from and to are non-renamed indices
function kill_edge!(compact::IncrementalCompact, active_bb::Int, from::Int, to::Int)
# Note: We recursively kill as many edges as are obviously dead. However, this
# may leave dead loops in the IR. We kill these later in a CFG cleanup pass (or
# worstcase during codegen).
(; bb_rename_pred, bb_rename_succ, result_bbs) = compact.cfg_transform
# Note: We recursively kill as many edges as are obviously dead.
(; bb_rename_pred, bb_rename_succ, result_bbs, domtree) = compact.cfg_transform
preds = result_bbs[bb_rename_succ[to]].preds
succs = result_bbs[bb_rename_pred[from]].succs
deleteat!(preds, findfirst(x::Int->x==bb_rename_pred[from], preds)::Int)
deleteat!(succs, findfirst(x::Int->x==bb_rename_succ[to], succs)::Int)
if domtree !== nothing
domtree_delete_edge!(domtree, result_bbs, bb_rename_pred[from], bb_rename_succ[to])
end
# Check if the block is now dead
if length(preds) == 0
for succ in copy(result_bbs[bb_rename_succ[to]].succs)
kill_edge!(compact, active_bb, to, findfirst(x::Int->x==succ, bb_rename_pred)::Int)
if length(preds) == 0 || (domtree !== nothing && bb_unreachable(domtree, bb_rename_succ[to]))
to_succs = result_bbs[bb_rename_succ[to]].succs
for succ in copy(to_succs)
new_succ = findfirst(x::Int->x==succ, bb_rename_pred)
new_succ === nothing && continue
kill_edge!(compact, active_bb, to, new_succ)
end
empty!(preds)
empty!(to_succs)
if to < active_bb
# Kill all statements in the block
stmts = result_bbs[bb_rename_succ[to]].stmts
Expand Down

2 comments on commit c8ca350

@aviatesk
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@nanosoldier runbenchmarks("inference", vs="@e75dd479ee38907183cb940e572440411ea1c448")

@nanosoldier
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Your benchmark job has completed - possible performance regressions were detected. A full report can be found here.

Please sign in to comment.