Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

allow external absint to hold custom data in codeinst.inferred #53300

Merged
merged 2 commits into from
Feb 17, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 3 additions & 6 deletions base/compiler/types.jl
Original file line number Diff line number Diff line change
Expand Up @@ -373,20 +373,17 @@ end
function NativeInterpreter(world::UInt = get_world_counter();
inf_params::InferenceParams = InferenceParams(),
opt_params::OptimizationParams = OptimizationParams())
curr_max_world = get_world_counter()
# Sometimes the caller is lazy and passes typemax(UInt).
# we cap it to the current world age for correctness
if world == typemax(UInt)
world = get_world_counter()
world = curr_max_world
end

# If they didn't pass typemax(UInt) but passed something more subtly
# incorrect, fail out loudly.
@assert world <= get_world_counter()

@assert world <= curr_max_world
method_table = CachedMethodTable(InternalMethodTable(world))

inf_cache = Vector{InferenceResult}() # Initially empty cache

return NativeInterpreter(world, method_table, inf_cache, inf_params, opt_params)
end

Expand Down
7 changes: 5 additions & 2 deletions src/gf.c
Original file line number Diff line number Diff line change
Expand Up @@ -444,8 +444,11 @@ STATIC_INLINE jl_value_t *_jl_rettype_inferred(jl_value_t *owner, jl_method_inst
if (jl_atomic_load_relaxed(&codeinst->min_world) <= min_world &&
max_world <= jl_atomic_load_relaxed(&codeinst->max_world) &&
jl_egal(codeinst->owner, owner)) {
jl_value_t *code = jl_atomic_load_relaxed(&codeinst->inferred);
if (code && (code == jl_nothing || jl_ir_flag_inferred(code)))
jl_value_t *inferred = jl_atomic_load_relaxed(&codeinst->inferred);
if (inferred && ((inferred == jl_nothing) || (
// allow whatever code instance external abstract interpreter produced
// since `jl_ir_flag_inferred` is specific to the native interpreter
codeinst->owner != jl_nothing || jl_ir_flag_inferred(inferred))))
return (jl_value_t*)codeinst;
}
codeinst = jl_atomic_load_relaxed(&codeinst->next);
Expand Down
33 changes: 32 additions & 1 deletion test/compiler/AbstractInterpreter.jl
Original file line number Diff line number Diff line change
Expand Up @@ -399,7 +399,6 @@ end
Core.eval(Core.Compiler, quote f(;a=1) = a end)
@test_throws MethodError Core.Compiler.f(;b=2)


# Custom lookup function
# ======================

Expand Down Expand Up @@ -469,3 +468,35 @@ let # generate cache
@test occursin("j_sin_", s)
@test !occursin("j_cos_", s)
end

# custom inferred data
# ====================

@newinterp CustomDataInterp
struct CustomDataInterpToken end
CC.cache_owner(::CustomDataInterp) = CustomDataInterpToken()
struct CustomData
inferred
CustomData(@nospecialize inferred) = new(inferred)
end
function CC.transform_result_for_cache(interp::CustomDataInterp,
mi::Core.MethodInstance, valid_worlds::CC.WorldRange, result::CC.InferenceResult)
inferred_result = @invoke CC.transform_result_for_cache(interp::CC.AbstractInterpreter,
mi::Core.MethodInstance, valid_worlds::CC.WorldRange, result::CC.InferenceResult)
return CustomData(inferred_result)
end
function CC.inlining_policy(interp::CustomDataInterp, @nospecialize(src),
@nospecialize(info::CC.CallInfo), stmt_flag::UInt32)
if src isa CustomData
src = src.inferred
end
return @invoke CC.inlining_policy(interp::CC.AbstractInterpreter, src::Any,
info::CC.CallInfo, stmt_flag::UInt32)
end
let src = code_typed((Int,); interp=CustomDataInterp()) do x
return sin(x) + cos(x)
end |> only |> first
@test count(isinvoke(:sin), src.code) == 1
@test count(isinvoke(:cos), src.code) == 1
@test count(isinvoke(:+), src.code) == 0
end
76 changes: 76 additions & 0 deletions test/precompile.jl
Original file line number Diff line number Diff line change
Expand Up @@ -1776,6 +1776,82 @@ let newinterp_path = abspath("compiler/newinterp.jl")
@test found
end
end

write(joinpath(load_path, "CustomAbstractInterpreterCaching2.jl"), :(module CustomAbstractInterpreterCaching2
import SimpleModule: basic_caller, basic_callee

module Custom
const CC = Core.Compiler
include("$($newinterp_path)")
@newinterp PrecompileInterpreter
struct CustomData
inferred
CustomData(@nospecialize inferred) = new(inferred)
end
function CC.transform_result_for_cache(interp::PrecompileInterpreter,
mi::Core.MethodInstance, valid_worlds::CC.WorldRange, result::CC.InferenceResult)
inferred_result = @invoke CC.transform_result_for_cache(interp::CC.AbstractInterpreter,
mi::Core.MethodInstance, valid_worlds::CC.WorldRange, result::CC.InferenceResult)
return CustomData(inferred_result)
end
function CC.inlining_policy(interp::PrecompileInterpreter, @nospecialize(src),
@nospecialize(info::CC.CallInfo), stmt_flag::UInt32)
if src isa CustomData
src = src.inferred
end
return @invoke CC.inlining_policy(interp::CC.AbstractInterpreter, src::Any,
info::CC.CallInfo, stmt_flag::UInt32)
end
end

Base.return_types((Float64,)) do x
basic_caller(x)
end
Base.return_types((Float64,); interp=Custom.PrecompileInterpreter()) do x
basic_caller(x)
end
Base.return_types((Vector{Float64},)) do x
sum(x)
end
Base.return_types((Vector{Float64},); interp=Custom.PrecompileInterpreter()) do x
sum(x)
end
end) |> string)
Base.compilecache(Base.PkgId("CustomAbstractInterpreterCaching2"))
@eval let
using CustomAbstractInterpreterCaching2
cache_owner = Core.Compiler.cache_owner(
CustomAbstractInterpreterCaching2.Custom.PrecompileInterpreter())
let m = only(methods(CustomAbstractInterpreterCaching.basic_callee))
mi = only(Base.specializations(m))
ci = mi.cache
@test isdefined(ci, :next)
@test ci.owner === nothing
@test ci.max_world == typemax(UInt)
ci = ci.next
@test !isdefined(ci, :next)
@test ci.owner === cache_owner
@test ci.max_world == typemax(UInt)
end
let m = only(methods(sum, (Vector{Float64},)))
found = false
for mi = Base.specializations(m)
if mi isa Core.MethodInstance && mi.specTypes == Tuple{typeof(sum),Vector{Float64}}
ci = mi.cache
@test isdefined(ci, :next)
@test ci.owner === cache_owner
@test ci.max_world == typemax(UInt)
ci = ci.next
@test !isdefined(ci, :next)
@test ci.owner === nothing
@test ci.max_world == typemax(UInt)
found = true
break
end
end
@test found
end
end
end
end

Expand Down