Skip to content

Commit

Permalink
accumulate virtual stack trace on returning back to the parent frame,…
Browse files Browse the repository at this point in the history
… try to fix performance problem (#90)

accumulate virtual stack trace on returning back to the parent frame

The idea is to stop a frame chain traversal on report construction or
cached report restoring, but rather update reports "frame-by-frame".
Before exiting the local inference (or cache retrieval), we keeps
reports that should be updated in `interp.to_be_updated` and they will
be updated when returning back to the parent frame
(i.e. the next inter-procedural context).

This should eliminate lots of the previous frame chain traversal works,
and should give us some performance improvements, especially we're
analyzing deep frames.

Caveat:
For now I couldn't find a right way to handle mutual recursion
cycle.
As such JET doesn't always make the following condition hold, and thus
its stack trace could be wrong.
```julia
# in `_typeinf` or cache retrieval points
@Assert first(report.st).linfo === current_linfo
```
I'd like to leave this problem as a future work for now, and improve
performance for now.
  • Loading branch information
aviatesk committed Feb 18, 2021
1 parent 322b6e0 commit 7045d02
Show file tree
Hide file tree
Showing 7 changed files with 139 additions and 128 deletions.
1 change: 1 addition & 0 deletions src/JET.jl
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@ import .CC:
# abstractinterpretation.jl
abstract_call_gf_by_type,
abstract_call_method_with_const_args,
abstract_call_method,
abstract_eval_special_value,
abstract_eval_value,
abstract_eval_statement,
Expand Down
27 changes: 25 additions & 2 deletions src/abstractinterpretation.jl
Original file line number Diff line number Diff line change
Expand Up @@ -129,7 +129,7 @@ function abstract_call_gf_by_type(interp::$(JETInterpreter), @nospecialize(f), a
napplicable = length(applicable)
rettype = Bottom
edgecycle = false
edges = Any[]
edges = MethodInstance[]
nonbot = 0 # the index of the only non-Bottom inference result if > 0
seen = 0 # number of signatures actually inferred
istoplevel = sv.linfo.def isa Module
Expand Down Expand Up @@ -391,7 +391,10 @@ function abstract_call_method_with_const_args(interp::$(JETInterpreter), @nospec
result = inf_result.result
# if constant inference hits a cycle, just bail out
isa(result, InferenceState) && return Any
add_backedge!(inf_result.linfo, sv)
#=== abstract_call_method_with_const_args patch point 3 start ===#
add_backedge!(mi, sv)
$update_reports!(interp, sv)
#=== abstract_call_method_with_const_args patch point 3 end ===#
return result
end

Expand All @@ -401,6 +404,26 @@ end) # Core.eval(CC, quote
end # function overload_abstract_call_method_with_const_args!()
push_inithook!(overload_abstract_call_method_with_const_args!)

# works within inter-procedural context
function CC.abstract_call_method(interp::JETInterpreter, method::Method, @nospecialize(sig), sparams::SimpleVector, hardlimit::Bool, sv::InferenceState)
ret = @invoke abstract_call_method(interp::AbstractInterpreter, method::Method, sig, sparams::SimpleVector, hardlimit::Bool, sv::InferenceState)

update_reports!(interp, sv)

return ret
end

function update_reports!(interp::JETInterpreter, sv::InferenceState)
rs = interp.to_be_updated
if !isempty(rs)
vf = get_virtual_frame(sv)
for r in rs
pushfirst!(r.st, vf)
end
empty!(rs)
end
end

function CC.abstract_eval_special_value(interp::JETInterpreter, @nospecialize(e), vtypes::VarTable, sv::InferenceState)
ret = @invoke abstract_eval_special_value(interp::AbstractInterpreter, e, vtypes::VarTable, sv::InferenceState)

Expand Down
19 changes: 10 additions & 9 deletions src/abstractinterpreterinterface.jl
Original file line number Diff line number Diff line change
Expand Up @@ -24,8 +24,8 @@ mutable struct JETInterpreter <: AbstractInterpreter
# stashes `UncaughtExceptionReport`s that are not caught so far
uncaught_exceptions::Vector{UncaughtExceptionReport}

# stashes `NativeRemark`s
native_remarks::Vector{NativeRemark}
# keeps reports that should be updated when returning back the parent frame (i.e. the next time we get back to inter-procedural context)
to_be_updated::Set{InferenceErrorReport}

# toplevel profiling (skip inference on actually interpreted statements)
concretized::BitVector
Expand All @@ -43,7 +43,6 @@ mutable struct JETInterpreter <: AbstractInterpreter
id = gensym(:JETInterpreterID),
reports = InferenceErrorReport[],
uncaught_exceptions = UncaughtExceptionReport[],
native_remarks = NativeRemark[],
concretized = BitVector(),
jetconfigs...)
inf_params = gen_inf_params(; jetconfigs...)
Expand All @@ -57,7 +56,7 @@ mutable struct JETInterpreter <: AbstractInterpreter
id,
reports,
uncaught_exceptions,
native_remarks,
Set{InferenceErrorReport}(),
concretized,
analysis_params,
nothing,
Expand Down Expand Up @@ -101,11 +100,12 @@ CC.get_world_counter(interp::JETInterpreter) = get_world_counter(interp.native)
CC.lock_mi_inference(::JETInterpreter, ::MethodInstance) = nothing
CC.unlock_mi_inference(::JETInterpreter, ::MethodInstance) = nothing

function CC.add_remark!(interp::JETInterpreter, sv::InferenceState, s::String)
AnalysisParams(interp).filter_native_remarks && return
push!(interp.native_remarks, NativeRemark(interp, sv, s))
return
end
# function CC.add_remark!(interp::JETInterpreter, sv::InferenceState, s::String)
# AnalysisParams(interp).filter_native_remarks && return
# push!(interp.native_remarks, NativeRemark(interp, sv, s))
# return
# end
CC.add_remark!(interp::JETInterpreter, sv::InferenceState, s::String) = return

CC.may_optimize(interp::JETInterpreter) = true
CC.may_compress(interp::JETInterpreter) = false
Expand Down Expand Up @@ -140,6 +140,7 @@ function gen_opt_params()
)
end

# TODO configurable analysis, e.g. ignore user-specified modules and such
@jetconfigurable function gen_analysis_params(; filter_native_remarks::Bool = true,
)
return AnalysisParams(filter_native_remarks)
Expand Down
13 changes: 9 additions & 4 deletions src/jetcache.jl
Original file line number Diff line number Diff line change
Expand Up @@ -35,9 +35,11 @@ function CC.get(wvc::WorldView{JETCache}, mi::MethodInstance, default)
global_cache = get(JET_GLOBAL_CACHE, mi, nothing)
if isa(global_cache, Vector{InferenceErrorReportCache})
interp = wvc.cache.interp
caller = interp.current_frame::InferenceState
for cached in global_cache
restore_cached_report!(cached, interp, caller)
restored = restore_cached_report!(cached, interp)
push!(interp.to_be_updated, restored) # should be updated in `abstract_call` (after exiting `typeinf_edge`)
# # TODO make this hold
# @assert first(cached.st).linfo === mi "invalid global restoring"
end
end
end
Expand Down Expand Up @@ -103,12 +105,15 @@ function CC.cache_lookup(linfo::MethodInstance, given_argtypes::Vector{Any}, cac
sv = interp.current_frame::InferenceState
if !isa(inf_result.result, InferenceState)
# corresponds to report throw away logic in `_typeinf(interp::JETInterpreter, frame::InferenceState)`
filter!(r->!is_lineage(r.lineage, sv, inf_result.linfo), interp.reports)
filter!(!is_from_same_frame(sv.linfo, linfo), interp.reports)

local_cache = get(interp.cache, given_argtypes, nothing)
if isa(local_cache, Vector{InferenceErrorReportCache})
for cached in local_cache
restore_cached_report!(cached, interp, sv)
restored = restore_cached_report!(cached, interp)
push!(interp.to_be_updated, restored) # should be updated in `abstract_call_method_with_const_args`
# # TODO make this hold
# @assert first(cached.st).linfo === linfo "invalid local restoring"
end
end
end
Expand Down
107 changes: 23 additions & 84 deletions src/reports.jl
Original file line number Diff line number Diff line change
Expand Up @@ -62,8 +62,6 @@ function Base.getproperty(er::InferenceErrorReport, sym::Symbol)
getfield(er, sym)::String
elseif sym === :sig
getfield(er, sym)::Vector{Any}
elseif sym === :lineage
getfield(er, sym)::Lineage
elseif sym === :lin # only needed for ExceptionReport
getfield(er, sym)::LineInfoNode
else
Expand Down Expand Up @@ -92,94 +90,47 @@ end
# "from entry call site to error point"
const VirtualStackTrace = Vector{VirtualFrame}

@withmixedhash struct LineageKey
file::Symbol
line::Int
linfo::MethodInstance
end
const Lineage = Set{LineageKey}

get_lineage_key(frame::InferenceState) = LineageKey(get_file_line(frame)..., frame.linfo)
get_lineage_key(vf::VirtualFrame) = LineageKey(vf.file, vf.line, vf.linfo)

function is_lineage(lineage::Lineage, parent::InferenceState, linfo::MethodInstance)
# check if current `linfo` is in `lineage`
# NOTE: we can't use `get_lineage_key` for this `linfo`, just because we don't analyze
# on cached frames and thus no appropriate lineage key (i.e. program counter) exists
for lk in lineage
lk.linfo === linfo && return is_lineage(lineage, parent)
end
return false
end
function is_lineage(lineage::Lineage, frame::InferenceState)
get_lineage_key(frame) in lineage || return false
return is_lineage(lineage, frame.parent)
end
is_lineage(::Lineage, ::Nothing) = true

# `ViewedVirtualStackTrace` is for `InferenceErrorReportCache` and only keeps a part of the
# stack trace of the original `InferenceErrorReport`, in the order of "from cached frame to error point"
const ViewedVirtualStackTrace = typeof(view(VirtualStackTrace(), 1:0))

struct InferenceErrorReportCache
T::Type{<:InferenceErrorReport}
st::ViewedVirtualStackTrace
st::VirtualStackTrace
msg::String
sig::Vector{Any}
spec_args::NTuple{N,Any} where N
end

function cache_report!(report::T, linfo, cache) where {T<:InferenceErrorReport}
st = report.st
i = findfirst(vf->vf.linfo===linfo, st)
# sometimes `linfo` can't be found within the `report.st` chain; e.g. frames for inner
# constructor methods doesn't seem to be tracked in the `(frame::InferenceState).parent`
# chain so that there is no `MethodInstance` within `report.st` for such a frame;
# XXX: reports from these frames might need to be cached as well rather than just giving up
isnothing(i) && return
st = view(st, i:length(st))
function cache_report!(report::T, cache) where {T<:InferenceErrorReport}
st = copy(report.st)
new = InferenceErrorReportCache(T, st, report.msg, report.sig, spec_args(report))
push!(cache, new)
end

function restore_cached_report!(cache::InferenceErrorReportCache,
interp#=::JETInterpreter=#,
caller::InferenceState,
)
report = restore_cached_report(cache, caller)
push!(isa(report, UncaughtExceptionReport) ? interp.uncaught_exceptions : interp.reports, report)
return
report = restore_cached_report(cache)
if isa(report, UncaughtExceptionReport)
stash_uncaught_exception!(interp, report)
else
report!(interp, report)
end
return report
end

function restore_cached_report(cache::InferenceErrorReportCache,
caller::InferenceState,
)
function restore_cached_report(cache::InferenceErrorReportCache)
T = cache.T
msg = cache.msg
sig = cache.sig
st = collect(cache.st)
spec_args = cache.spec_args
lineage = Lineage(get_lineage_key(vf) for vf in st)

prewalk_inf_frame(caller) do frame::InferenceState
linfo = frame.linfo
vf = get_virtual_frame(frame)
pushfirst!(st, vf)
push!(lineage, get_lineage_key(vf))
end

return T(st, msg, sig, lineage, spec_args)
st = copy(cache.st)
return T(st, cache.msg, cache.sig, cache.spec_args)::InferenceErrorReport
end

@withmixedhash struct IdentityKey
T::Type{<:InferenceErrorReport}
sig::Vector{Any}
entry_frame::VirtualFrame
# entry_frame::VirtualFrame
error_frame::VirtualFrame
end

get_identity_key(report::T) where {T<:InferenceErrorReport} =
IdentityKey(T, report.sig, first(report.st), last(report.st))
IdentityKey(T, report.sig, #=first(report.st),=# last(report.st))

macro reportdef(ex, kwargs...)
T = esc(first(ex.args))
Expand Down Expand Up @@ -212,37 +163,26 @@ macro reportdef(ex, kwargs...)

msg = get_msg(#= T, interp, sv, ... =# $(args′...))
sig = get_sig(#= T, interp, sv, ... =# $(args′...))
st = VirtualFrame[]
lineage = Lineage()

$(track_from_frame && :(let
$(if track_from_frame quote
# when report is constructed _after_ the inference on `sv` has been done,
# collect location information from `sv.linfo` and start traversal from `sv.parent`
linfo = sv.linfo
vf = get_virtual_frame(linfo)
push!(st, vf)
push!(lineage, get_lineage_key(vf))
sv = sv.parent
end))

prewalk_inf_frame(sv) do frame::InferenceState
vf = get_virtual_frame(frame)
pushfirst!(st, vf)
push!(lineage, get_lineage_key(vf))
end
# collect location information from `sv.linfo`
st = VirtualFrame[get_virtual_frame(sv.linfo)]
end else quote
st = VirtualFrame[get_virtual_frame(sv)]
end end)

return new(st, msg, sig, lineage, $(spec_args′...))
return new(st, msg, sig, $(spec_args′...))
end))

spec_types = extract_type_decls.(spec_args)

cache_constructor_sig = :($(T)(st::VirtualStackTrace,
msg::AbstractString,
sig::AbstractVector,
lineage::Lineage,
@nospecialize(spec_args),
))
cache_constructor_call = :(new(st, msg, sig, lineage))
cache_constructor_call = :(new(st, msg, sig))
for (i, spec_type) in enumerate(spec_types)
push!(cache_constructor_call.args,
:($(esc(:spec_args))[$(i)]::$(spec_type)), # `esc` is needed because `@nospecialize` escapes its argument anyway
Expand All @@ -268,7 +208,6 @@ macro reportdef(ex, kwargs...)
st::VirtualStackTrace
msg::String
sig::Vector{Any}
lineage::Lineage
$(spec_args...)

# constructor from abstract interpretation process by `JETInterpreter`
Expand Down

0 comments on commit 7045d02

Please sign in to comment.