Skip to content

Commit

Permalink
inference: Model type propagation through exceptions
Browse files Browse the repository at this point in the history
Currently the type of a caught exception is always modeled as `Any`.
This isn't a huge problem, because exception-based control flow in Julia is generally
assumed to be somewhat slow, so the extra type imprecision of not
knowing the return type does not matter all that much. However,
there are a few situations where it matters. For example:

```
maybe_getindex(A, i) =
    try; A[i]; catch e; isa(e, BoundsError) && return nothing; rethrow(); end
```

At present, we cannot infer :nothrow for this method, even if `BoundsError`
is the only error type that `A[i]` can throw. This is particularly
noticeable, since we can now optimize away `:nothrow` exception frames
entirely (#51674). Note that this PR still does not make the above
example particularly efficient (at least interprocedurally), though
specialized codegen could be added on top of this to make that happen.
It does however improve the inference result.

A second major motivation of this change is that reasoning about
exception types is likely to be a major aspect of any future work
on interface checking (since interfaces imply the absence of
MethodErrors), so this PR lays the groundwork for appropriate modeling
of these error paths.

Note that this PR adds all the required plumbing, but does not yet have
a particularly precise model of error types for our builtins, bailing
to `Any` for any builtin not known to be `:nothrow`. This can be improved
in follow up PRs as required.
  • Loading branch information
Keno committed Oct 18, 2023
1 parent 86cbb60 commit 0397be1
Show file tree
Hide file tree
Showing 17 changed files with 293 additions and 180 deletions.
6 changes: 3 additions & 3 deletions base/boot.jl
Original file line number Diff line number Diff line change
Expand Up @@ -451,13 +451,13 @@ eval(Core, quote
end)

function CodeInstance(
mi::MethodInstance, @nospecialize(rettype), @nospecialize(inferred_const),
mi::MethodInstance, @nospecialize(rettype), @nospecialize(exctype), @nospecialize(inferred_const),
@nospecialize(inferred), const_flags::Int32, min_world::UInt, max_world::UInt,
ipo_effects::UInt32, effects::UInt32, @nospecialize(argescapes#=::Union{Nothing,Vector{ArgEscapeInfo}}=#),
relocatability::UInt8)
return ccall(:jl_new_codeinst, Ref{CodeInstance},
(Any, Any, Any, Any, Int32, UInt, UInt, UInt32, UInt32, Any, UInt8),
mi, rettype, inferred_const, inferred, const_flags, min_world, max_world,
(Any, Any, Any, Any, Any, Int32, UInt, UInt, UInt32, UInt32, Any, UInt8),
mi, rettype, exctype, inferred_const, inferred, const_flags, min_world, max_world,
ipo_effects, effects, argescapes,
relocatability)
end
Expand Down
258 changes: 162 additions & 96 deletions base/compiler/abstractinterpretation.jl

Large diffs are not rendered by default.

57 changes: 35 additions & 22 deletions base/compiler/inferencestate.jl
Original file line number Diff line number Diff line change
Expand Up @@ -198,6 +198,11 @@ to enable flow-sensitive analysis.
"""
const VarTable = Vector{VarState}

# Record describing one try/catch region. `compute_trycatch` pushes one frame
# per `:enter` statement it finds, and `handler_at` entries refer to frames by
# their index in that `handlers` vector.
mutable struct TryCatchFrame
# Exception type associated with this frame; constructed as `Bottom` by `compute_trycatch`.
# NOTE(review): presumably widened during inference as throwing statements are visited — confirm.
exct
# Statement index (pc) of the `:enter` expression that opens this frame.
const enter_idx
end

mutable struct InferenceState
#= information about this method instance =#
linfo::MethodInstance
Expand All @@ -213,7 +218,8 @@ mutable struct InferenceState
currbb::Int
currpc::Int
ip::BitSet#=TODO BoundedMinPrioritySet=# # current active instruction pointers
handler_at::Vector{Int} # current exception handler info
handlers::Vector{TryCatchFrame}
handler_at::Vector{Tuple{Int, Int}} # tuple of current (handler, exception stack) value at the pc
ssavalue_uses::Vector{BitSet} # ssavalue sparsity and restart info
# TODO: Could keep this sparsely by doing structural liveness analysis ahead of time.
bb_vartables::Vector{Union{Nothing,VarTable}} # nothing if not analyzed yet
Expand All @@ -234,6 +240,7 @@ mutable struct InferenceState
unreachable::BitSet # statements that were found to be statically unreachable
valid_worlds::WorldRange
bestguess #::Type
exc_bestguess
ipo_effects::Effects

#= flags =#
Expand Down Expand Up @@ -261,7 +268,7 @@ mutable struct InferenceState

currbb = currpc = 1
ip = BitSet(1) # TODO BitSetBoundedMinPrioritySet(1)
handler_at = compute_trycatch(code, BitSet())
handler_at, handlers = compute_trycatch(code, BitSet())
nssavalues = src.ssavaluetypes::Int
ssavalue_uses = find_ssavalue_uses(code, nssavalues)
nstmts = length(code)
Expand Down Expand Up @@ -291,6 +298,7 @@ mutable struct InferenceState

valid_worlds = WorldRange(src.min_world, src.max_world == typemax(UInt) ? get_world_counter() : src.max_world)
bestguess = Bottom
exc_bestguess = Bottom
ipo_effects = EFFECTS_TOTAL

insert_coverage = should_insert_coverage(mod, src)
Expand All @@ -312,9 +320,9 @@ mutable struct InferenceState

return new(
linfo, world, mod, sptypes, slottypes, src, cfg, method_info,
currbb, currpc, ip, handler_at, ssavalue_uses, bb_vartables, ssavaluetypes, stmt_edges, stmt_info,
currbb, currpc, ip, handlers, handler_at, ssavalue_uses, bb_vartables, ssavaluetypes, stmt_edges, stmt_info,
pclimitations, limitations, cycle_backedges, callers_in_cycle, dont_work_on_me, parent,
result, unreachable, valid_worlds, bestguess, ipo_effects,
result, unreachable, valid_worlds, bestguess, exc_bestguess, ipo_effects,
restrict_abstract_call_sites, cached, insert_coverage,
interp)
end
Expand Down Expand Up @@ -344,16 +352,19 @@ function compute_trycatch(code::Vector{Any}, ip::BitSet)
empty!(ip)
ip.offset = 0 # for _bits_findnext
push!(ip, n + 1)
handler_at = fill(0, n)
handler_at = fill((0, 0), n)
handlers = TryCatchFrame[]

# start from all :enter statements and record the location of the try
for pc = 1:n
stmt = code[pc]
if isexpr(stmt, :enter)
l = stmt.args[1]::Int
handler_at[pc + 1] = pc
push!(handlers, TryCatchFrame(Bottom, pc))
handler_id = length(handlers)
handler_at[pc + 1] = (handler_id, 0)
push!(ip, pc + 1)
handler_at[l] = pc
handler_at[l] = (handler_id, handler_id)
push!(ip, l)
end
end
Expand All @@ -366,25 +377,26 @@ function compute_trycatch(code::Vector{Any}, ip::BitSet)
while true # inner loop optimizes the common case where it can run straight from pc to pc + 1
pc´ = pc + 1 # next program-counter (after executing instruction)
delete!(ip, pc)
cur_hand = handler_at[pc]
@assert cur_hand != 0 "unbalanced try/catch"
cur_stacks = handler_at[pc]
@assert cur_stacks != (0, 0) "unbalanced try/catch"
stmt = code[pc]
if isa(stmt, GotoNode)
pc´ = stmt.label
elseif isa(stmt, GotoIfNot)
l = stmt.dest::Int
if handler_at[l] != cur_hand
@assert handler_at[l] == 0 "unbalanced try/catch"
handler_at[l] = cur_hand
if handler_at[l] != cur_stacks
@assert handler_at[l][1] == 0 || handler_at[l][1] == cur_stacks[1] "unbalanced try/catch"
handler_at[l] = cur_stacks
push!(ip, l)
end
elseif isa(stmt, ReturnNode)
@assert !isdefined(stmt, :val) "unbalanced try/catch"
@assert !isdefined(stmt, :val) || cur_stacks[1] == 0 "unbalanced try/catch"
break
elseif isa(stmt, Expr)
head = stmt.head
if head === :enter
cur_hand = pc
# Already set above
cur_stacks = (handler_at[pc´][1], cur_stacks[2])
elseif head === :leave
l = 0
for j = 1:length(stmt.args)
Expand All @@ -400,19 +412,20 @@ function compute_trycatch(code::Vector{Any}, ip::BitSet)
end
l += 1
end
cur_hand = cur_stacks[1]
for i = 1:l
cur_hand = handler_at[cur_hand]
cur_hand = handler_at[handlers[cur_hand].enter_idx][1]
end
cur_hand == 0 && break
cur_stacks = (cur_hand, cur_stacks[2])
cur_stacks == (0, 0) && break
elseif head === :pop_exception
cur_stacks = (cur_stacks[1], handler_at[(stmt.args[1]::SSAValue).id][2])
end
end

pc´ > n && break # can't proceed with the fast-path fall-through
if handler_at[pc´] != cur_hand
if handler_at[pc´] != 0
@assert false "unbalanced try/catch"
end
handler_at[pc´] = cur_hand
if handler_at[pc´] != cur_stacks
handler_at[pc´] = cur_stacks
elseif !in(pc´, ip)
break # already visited
end
Expand All @@ -421,7 +434,7 @@ function compute_trycatch(code::Vector{Any}, ip::BitSet)
end

@assert first(ip) == n + 1
return handler_at
return handler_at, handlers
end

# check if coverage mode is enabled
Expand Down
4 changes: 2 additions & 2 deletions base/compiler/ssair/irinterp.jl
Original file line number Diff line number Diff line change
Expand Up @@ -46,9 +46,9 @@ end

function abstract_call(interp::AbstractInterpreter, arginfo::ArgInfo, irsv::IRInterpretationState)
si = StmtInfo(true) # TODO better job here?
(; rt, effects, info) = abstract_call(interp, arginfo, si, irsv)
(; rt, exct, effects, info) = abstract_call(interp, arginfo, si, irsv)
irsv.ir.stmts[irsv.curridx][:info] = info
return RTEffects(rt, effects)
return RTEffects(rt, exct, effects)
end

function update_phi!(irsv::IRInterpretationState, from::Int, to::Int)
Expand Down
6 changes: 3 additions & 3 deletions base/compiler/ssair/slot2ssa.jl
Original file line number Diff line number Diff line change
Expand Up @@ -584,7 +584,7 @@ function construct_ssa!(ci::CodeInfo, ir::IRCode, sv::OptimizationState,
end

# Record the correct exception handler for all critical sections
handler_at = compute_trycatch(code, BitSet())
handler_at, handlers = compute_trycatch(code, BitSet())

phi_slots = Vector{Int}[Int[] for _ = 1:length(ir.cfg.blocks)]
live_slots = Vector{Int}[Int[] for _ = 1:length(ir.cfg.blocks)]
Expand Down Expand Up @@ -810,8 +810,8 @@ function construct_ssa!(ci::CodeInfo, ir::IRCode, sv::OptimizationState,
incoming_vals[id] = Pair{Any, Any}(thisval, thisdef)
has_pinode[id] = false
enter_idx = idx
while handler_at[enter_idx] != 0
enter_idx = handler_at[enter_idx]
while handler_at[enter_idx][1] != 0
(; enter_idx) = handlers[handler_at[enter_idx][1]]
leave_block = block_for_inst(cfg, code[enter_idx].args[1]::Int)
cidx = findfirst((; slot)::NewPhiCNode2->slot_id(slot)==id, new_phic_nodes[leave_block])
if cidx !== nothing
Expand Down
1 change: 1 addition & 0 deletions base/compiler/stmtinfo.jl
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ and any additional information (`call.info`) for a given generic call.
"""
struct CallMeta
rt::Any
exct::Any
effects::Effects
info::CallInfo
end
Expand Down
42 changes: 21 additions & 21 deletions base/compiler/tfuncs.jl
Original file line number Diff line number Diff line change
Expand Up @@ -1392,10 +1392,10 @@ end
function abstract_modifyfield!(interp::AbstractInterpreter, argtypes::Vector{Any}, si::StmtInfo, sv::AbsIntState)
nargs = length(argtypes)
if !isempty(argtypes) && isvarargtype(argtypes[nargs])
nargs - 1 <= 6 || return CallMeta(Bottom, EFFECTS_THROWS, NoCallInfo())
nargs > 3 || return CallMeta(Any, Effects(), NoCallInfo())
nargs - 1 <= 6 || return CallMeta(Bottom, Any, EFFECTS_THROWS, NoCallInfo())
nargs > 3 || return CallMeta(Any, Any, Effects(), NoCallInfo())
else
5 <= nargs <= 6 || return CallMeta(Bottom, EFFECTS_THROWS, NoCallInfo())
5 <= nargs <= 6 || return CallMeta(Bottom, Any, EFFECTS_THROWS, NoCallInfo())
end
𝕃ᵢ = typeinf_lattice(interp)
o = unwrapva(argtypes[2])
Expand All @@ -1417,7 +1417,7 @@ function abstract_modifyfield!(interp::AbstractInterpreter, argtypes::Vector{Any
end
info = ModifyFieldInfo(callinfo.info)
end
return CallMeta(RT, Effects(), info)
return CallMeta(RT, Any, Effects(), info)
end
@nospecs function replacefield!_tfunc(𝕃::AbstractLattice, o, f, x, v, success_order, failure_order)
return replacefield!_tfunc(𝕃, o, f, x, v)
Expand Down Expand Up @@ -2650,7 +2650,7 @@ end
# since abstract_call_gf_by_type is a very inaccurate model of _method and of typeinf_type,
# while this assumes that it is an absolutely precise and accurate and exact model of both
function return_type_tfunc(interp::AbstractInterpreter, argtypes::Vector{Any}, si::StmtInfo, sv::AbsIntState)
UNKNOWN = CallMeta(Type, EFFECTS_THROWS, NoCallInfo())
UNKNOWN = CallMeta(Type, Any, EFFECTS_THROWS, NoCallInfo())
if !(2 <= length(argtypes) <= 3)
return UNKNOWN
end
Expand Down Expand Up @@ -2679,7 +2679,7 @@ function return_type_tfunc(interp::AbstractInterpreter, argtypes::Vector{Any}, s
end

if contains_is(argtypes_vec, Union{})
return CallMeta(Const(Union{}), EFFECTS_TOTAL, NoCallInfo())
return CallMeta(Const(Union{}), Any, EFFECTS_TOTAL, NoCallInfo())
end

# Run the abstract_call without restricting abstract call
Expand All @@ -2697,12 +2697,12 @@ function return_type_tfunc(interp::AbstractInterpreter, argtypes::Vector{Any}, s
rt = widenslotwrapper(call.rt)
if isa(rt, Const)
# output was computed to be constant
return CallMeta(Const(typeof(rt.val)), EFFECTS_TOTAL, info)
return CallMeta(Const(typeof(rt.val)), Any, EFFECTS_TOTAL, info)
end
rt = widenconst(rt)
if rt === Bottom || (isconcretetype(rt) && !iskindtype(rt))
# output cannot be improved so it is known for certain
return CallMeta(Const(rt), EFFECTS_TOTAL, info)
return CallMeta(Const(rt), Union{}, EFFECTS_TOTAL, info)
elseif isa(sv, InferenceState) && !isempty(sv.pclimitations)
# conservatively express uncertainty of this result
# in two ways: both as being a subtype of this, and
Expand All @@ -2711,19 +2711,19 @@ function return_type_tfunc(interp::AbstractInterpreter, argtypes::Vector{Any}, s
elseif isa(tt, Const) || isconstType(tt)
# input arguments were known for certain
# XXX: this doesn't imply we know anything about rt
return CallMeta(Const(rt), EFFECTS_TOTAL, info)
return CallMeta(Const(rt), Union{}, EFFECTS_TOTAL, info)
elseif isType(rt)
return CallMeta(Type{rt}, EFFECTS_TOTAL, info)
return CallMeta(Type{rt}, Union{}, EFFECTS_TOTAL, info)
else
return CallMeta(Type{<:rt}, EFFECTS_TOTAL, info)
return CallMeta(Type{<:rt}, Union{}, EFFECTS_TOTAL, info)
end
end

# a simplified model of abstract_call_gf_by_type for applicable
function abstract_applicable(interp::AbstractInterpreter, argtypes::Vector{Any},
sv::AbsIntState, max_methods::Int)
length(argtypes) < 2 && return CallMeta(Bottom, EFFECTS_THROWS, NoCallInfo())
isvarargtype(argtypes[2]) && return CallMeta(Bool, EFFECTS_UNKNOWN, NoCallInfo())
length(argtypes) < 2 && return CallMeta(Bottom, Any, EFFECTS_THROWS, NoCallInfo())
isvarargtype(argtypes[2]) && return CallMeta(Bool, Any, EFFECTS_UNKNOWN, NoCallInfo())
argtypes = argtypes[2:end]
atype = argtypes_to_type(argtypes)
matches = find_matching_methods(typeinf_lattice(interp), argtypes, atype, method_table(interp),
Expand Down Expand Up @@ -2762,7 +2762,7 @@ function abstract_applicable(interp::AbstractInterpreter, argtypes::Vector{Any},
end
end
end
return CallMeta(rt, EFFECTS_TOTAL, NoCallInfo())
return CallMeta(rt, Union{}, EFFECTS_TOTAL, NoCallInfo())
end
add_tfunc(applicable, 1, INT_INF, @nospecs((𝕃::AbstractLattice, f, args...)->Bool), 40)

Expand All @@ -2771,26 +2771,26 @@ function _hasmethod_tfunc(interp::AbstractInterpreter, argtypes::Vector{Any}, sv
if length(argtypes) == 3 && !isvarargtype(argtypes[3])
ft′ = argtype_by_index(argtypes, 2)
ft = widenconst(ft′)
ft === Bottom && return CallMeta(Bool, EFFECTS_THROWS, NoCallInfo())
ft === Bottom && return CallMeta(Bool, Any, EFFECTS_THROWS, NoCallInfo())
typeidx = 3
elseif length(argtypes) == 2 && !isvarargtype(argtypes[2])
typeidx = 2
else
return CallMeta(Any, Effects(), NoCallInfo())
return CallMeta(Any, Any, Effects(), NoCallInfo())
end
(types, isexact, isconcrete, istype) = instanceof_tfunc(argtype_by_index(argtypes, typeidx), false)
isexact || return CallMeta(Bool, Effects(), NoCallInfo())
isexact || return CallMeta(Bool, Any, Effects(), NoCallInfo())
unwrapped = unwrap_unionall(types)
if types === Bottom || !(unwrapped isa DataType) || unwrapped.name !== Tuple.name
return CallMeta(Bool, EFFECTS_THROWS, NoCallInfo())
return CallMeta(Bool, Any, EFFECTS_THROWS, NoCallInfo())
end
if typeidx == 3
isdispatchelem(ft) || return CallMeta(Bool, Effects(), NoCallInfo()) # check that we might not have a subtype of `ft` at runtime, before doing supertype lookup below
isdispatchelem(ft) || return CallMeta(Bool, Any, Effects(), NoCallInfo()) # check that we might not have a subtype of `ft` at runtime, before doing supertype lookup below
types = rewrap_unionall(Tuple{ft, unwrapped.parameters...}, types)::Type
end
mt = ccall(:jl_method_table_for, Any, (Any,), types)
if !isa(mt, MethodTable)
return CallMeta(Bool, EFFECTS_THROWS, NoCallInfo())
return CallMeta(Bool, Any, EFFECTS_THROWS, NoCallInfo())
end
match, valid_worlds = findsup(types, method_table(interp))
update_valid_age!(sv, valid_worlds)
Expand All @@ -2802,7 +2802,7 @@ function _hasmethod_tfunc(interp::AbstractInterpreter, argtypes::Vector{Any}, sv
edge = specialize_method(match)::MethodInstance
add_invoke_backedge!(sv, types, edge)
end
return CallMeta(rt, EFFECTS_TOTAL, NoCallInfo())
return CallMeta(rt, Any, EFFECTS_TOTAL, NoCallInfo())
end

# N.B.: typename maps type equivalence classes to a single value
Expand Down

0 comments on commit 0397be1

Please sign in to comment.