Skip to content

Commit

Permalink
switch to using the internal code cache
Browse files Browse the repository at this point in the history
Leverages JuliaLang/julia#52233 to use the internal code cache that
comes with the inherent invalidation support.

Still requires:
- JuliaLang/julia#53300 (or JuliaLang/julia#53219)
- JuliaLang/julia#53318
  • Loading branch information
aviatesk committed Feb 13, 2024
1 parent 2c9555a commit 3b53d55
Show file tree
Hide file tree
Showing 2 changed files with 75 additions and 24 deletions.
14 changes: 7 additions & 7 deletions src/abstractinterpret/abstractanalyzer.jl
Original file line number Diff line number Diff line change
Expand Up @@ -130,7 +130,7 @@ mutable struct AnalyzerState

# the temporal stash to keep track of the context of caller inference/optimization and
# the caller itself, to which reconstructed cached reports will be appended
cache_target::Union{Nothing,Pair{Symbol,InferenceResult}}
cache_target::(@static VERSION ≥ v"1.11.0-DEV.1552" ? Nothing : Union{Nothing,Pair{Symbol,InferenceResult}})

## abstract toplevel execution ##

Expand Down Expand Up @@ -417,12 +417,12 @@ struct AnalysisCache
end
AnalysisCache() = AnalysisCache(IdDict{MethodInstance,CodeInstance}())

Base.haskey(analysis_cache::AnalysisCache, mi::MethodInstance) = haskey(analysis_cache.cache, mi)
Base.get(analysis_cache::AnalysisCache, mi::MethodInstance, default) = get(analysis_cache.cache, mi, default)
Base.getindex(analysis_cache::AnalysisCache, mi::MethodInstance) = getindex(analysis_cache.cache, mi)
Base.setindex!(analysis_cache::AnalysisCache, ci::CodeInstance, mi::MethodInstance) = setindex!(analysis_cache.cache, ci, mi)
Base.delete!(analysis_cache::AnalysisCache, mi::MethodInstance) = delete!(analysis_cache.cache, mi)
Base.show(io::IO, analysis_cache::AnalysisCache) = print(io, typeof(analysis_cache), "(", length(analysis_cache.cache), " entries)")
# Base.haskey(analysis_cache::AnalysisCache, mi::MethodInstance) = haskey(analysis_cache.cache, mi)
# Base.get(analysis_cache::AnalysisCache, mi::MethodInstance, default) = get(analysis_cache.cache, mi, default)
# Base.getindex(analysis_cache::AnalysisCache, mi::MethodInstance) = getindex(analysis_cache.cache, mi)
# Base.setindex!(analysis_cache::AnalysisCache, ci::CodeInstance, mi::MethodInstance) = setindex!(analysis_cache.cache, ci, mi)
# Base.delete!(analysis_cache::AnalysisCache, mi::MethodInstance) = delete!(analysis_cache.cache, mi)
# Base.show(io::IO, analysis_cache::AnalysisCache) = print(io, typeof(analysis_cache), "(", length(analysis_cache.cache), " entries)")

"""
AnalysisCache(analyzer::AbstractAnalyzer) -> analysis_cache::AnalysisCache
Expand Down
85 changes: 68 additions & 17 deletions src/abstractinterpret/typeinfer.jl
Original file line number Diff line number Diff line change
Expand Up @@ -26,11 +26,15 @@ end
function CC.const_prop_call(analyzer::AbstractAnalyzer,
mi::MethodInstance, result::MethodCallResult, arginfo::ArgInfo, sv::InferenceState,
concrete_eval_result::Union{Nothing,CC.ConstCallResults})
@static if VERSION < v"1.11.0-DEV.1552"
set_cache_target!(analyzer, :const_prop_call => sv.result)
end
const_result = @invoke CC.const_prop_call(analyzer::AbstractInterpreter,
mi::MethodInstance, result::MethodCallResult, arginfo::ArgInfo, sv::InferenceState,
concrete_eval_result::Union{Nothing,CC.ConstCallResults})
@static if VERSION < v"1.11.0-DEV.1552"
@assert get_cache_target(analyzer) === nothing "invalid JET analysis state"
end
if const_result !== nothing
# successful constant prop', we need to update reports
collect_callee_reports!(analyzer, sv)
Expand Down Expand Up @@ -150,6 +154,26 @@ end
# global
# ------

@static if VERSION v"1.11.0-DEV.1552"

CC.cache_owner(analyzer::AbstractAnalyzer) = AnalysisCache(analyzer)

function CC.return_cached_result(analyzer::AbstractAnalyzer, codeinst::CodeInstance, caller::InferenceState)
# cache hit, now we need to append cached reports associated with this `MethodInstance`
inferred = @atomic :monotonic codeinst.inferred
for cached in (inferred::CachedAnalysisResult).reports
restored = add_cached_report!(analyzer, caller.result, cached)
@static if JET_DEV_MODE
actual, expected = first(restored.vst).linfo, codeinst.def
@assert actual === expected "invalid global cache restoration, expected $expected but got $actual"
end
stash_report!(analyzer, restored) # should be updated in `abstract_call` (after exiting `typeinf_edge`)
end
return @invoke CC.return_cached_result(analyzer::AbstractInterpreter, codeinst::CodeInstance, caller::InferenceState)
end

else # if VERSION ≥ v"1.11.0-DEV.1552"

function CC.code_cache(analyzer::AbstractAnalyzer)
view = AbstractAnalyzerView(analyzer)
worlds = WorldRange(get_inference_world(analyzer))
Expand Down Expand Up @@ -208,21 +232,6 @@ function CC.getindex(wvc::WorldView{<:AbstractAnalyzerView}, mi::MethodInstance)
return codeinst::CodeInstance
end

function CC.transform_result_for_cache(analyzer::AbstractAnalyzer,
linfo::MethodInstance, valid_worlds::WorldRange, result::InferenceResult)
cache = InferenceErrorReport[]
for report in get_any_reports(analyzer, result)
@static if JET_DEV_MODE
actual, expected = first(report.vst).linfo, linfo
@assert actual === expected "invalid global caching detected, expected $expected but got $actual"
end
cache_report!(cache, report)
end
inferred_result = @invoke transform_result_for_cache(analyzer::AbstractInterpreter,
linfo::MethodInstance, valid_worlds::WorldRange, result::InferenceResult)
return CachedAnalysisResult(inferred_result, cache)
end

function CC.setindex!(wvc::WorldView{<:AbstractAnalyzerView}, codeinst::CodeInstance, mi::MethodInstance)
analysis_cache = AnalysisCache(wvc)
add_jet_callback!(mi, analysis_cache)
Expand All @@ -241,7 +250,7 @@ end
function (callback::JETCallback)(replaced::MethodInstance, max_world::UInt32)
delete!(callback.analysis_cache, replaced)
end
else
else # if VERSION ≥ v"1.11.0-DEV.798"
function add_jet_callback!(mi::MethodInstance, analysis_cache::AnalysisCache)
callback = JETCallback(analysis_cache)
if !isdefined(mi, :callbacks)
Expand All @@ -268,11 +277,36 @@ function (callback::JETCallback)(replaced::MethodInstance, max_world::UInt32,
end
return nothing
end
end
end # if VERSION ≥ v"1.11.0-DEV.798"

end # if VERSION ≥ v"1.11.0-DEV.1552"

# local
# -----

@static if VERSION v"1.11.0-DEV.1552"

CC.get_inference_cache(analyzer::AbstractAnalyzer) = get_inf_cache(analyzer)

function CC.return_cached_result(analyzer::AbstractAnalyzer, inf_result::InferenceResult, caller::InferenceState)
# as the analyzer uses the reports that are cached by the abstract-interpretation
# with the extended lattice elements, here we should throw-away the error reports
# that are collected during the previous non-constant abstract-interpretation
# (see the `CC.typeinf(::AbstractAnalyzer, ::InferenceState)` overload)
filter_lineages!(analyzer, caller.result, inf_result.linfo)
for cached in get_cached_reports(analyzer, inf_result)
restored = add_cached_report!(analyzer, caller.result, cached)
@static if JET_DEV_MODE
actual, expected = first(restored.vst).linfo, inf_result.linfo
@assert actual === expected "invalid local cache restoration, expected $expected but got $actual"
end
stash_report!(analyzer, restored) # should be updated in `abstract_call_method_with_const_args`
end
return @invoke CC.return_cached_result(analyzer::AbstractInterpreter, inf_result::InferenceResult, caller::InferenceState)
end

else # if VERSION ≥ v"1.11.0-DEV.1552"

CC.get_inference_cache(analyzer::AbstractAnalyzer) = AbstractAnalyzerView(analyzer)

function CC.cache_lookup(𝕃ᵢ::CC.AbstractLattice, mi::MethodInstance, given_argtypes::Argtypes, view::AbstractAnalyzerView)
Expand Down Expand Up @@ -316,6 +350,8 @@ end

CC.push!(view::AbstractAnalyzerView, inf_result::InferenceResult) = CC.push!(get_inf_cache(view.analyzer), inf_result)

end # if VERSION ≥ v"1.11.0-DEV.1552"

# main driver
# ===========

Expand Down Expand Up @@ -539,6 +575,21 @@ function CC.cache_result!(analyzer::AbstractAnalyzer, caller::InferenceResult)
@invoke CC.cache_result!(analyzer::AbstractInterpreter, caller::InferenceResult)
end

function CC.transform_result_for_cache(analyzer::AbstractAnalyzer,
linfo::MethodInstance, valid_worlds::WorldRange, result::InferenceResult)
cache = InferenceErrorReport[]
for report in get_any_reports(analyzer, result)
@static if JET_DEV_MODE
actual, expected = first(report.vst).linfo, linfo
@assert actual === expected "invalid global caching detected, expected $expected but got $actual"
end
cache_report!(cache, report)
end
inferred_result = @invoke transform_result_for_cache(analyzer::AbstractInterpreter,
linfo::MethodInstance, valid_worlds::WorldRange, result::InferenceResult)
return CachedAnalysisResult(inferred_result, cache)
end

# top-level bridge
# ================

Expand Down

0 comments on commit 3b53d55

Please sign in to comment.