Skip to content

Commit

Permalink
add mechanism for spoofing inference work-limiting heuristics (#24852)
Browse files Browse the repository at this point in the history
  • Loading branch information
jrevels committed Jan 12, 2018
1 parent daf1235 commit ec9e92e
Show file tree
Hide file tree
Showing 8 changed files with 125 additions and 14 deletions.
1 change: 1 addition & 0 deletions base/boot.jl
Original file line number Diff line number Diff line change
Expand Up @@ -456,6 +456,7 @@ struct GeneratedFunctionStub
spnames::Union{Nothing, Array{Any,1}}
line::Int
file::Symbol
expand_early::Bool
end

# invoke and wrap the results of @generated
Expand Down
45 changes: 37 additions & 8 deletions base/inference.jl
Original file line number Diff line number Diff line change
Expand Up @@ -371,10 +371,14 @@ function _validate(linfo::MethodInstance, src::CodeInfo, kind::String)
end

function get_staged(li::MethodInstance)
return ccall(:jl_code_for_staged, Any, (Any,), li)::CodeInfo
try
# user code might throw errors – ignore them
return ccall(:jl_code_for_staged, Any, (Any,), li)::CodeInfo
catch
return nothing
end
end


mutable struct OptimizationState
linfo::MethodInstance
vararg_type_container #::Type
Expand Down Expand Up @@ -472,12 +476,7 @@ end
function retrieve_code_info(linfo::MethodInstance)
m = linfo.def::Method
if isdefined(m, :generator)
try
# user code might throw errors – ignore them
c = get_staged(linfo)
catch
return nothing
end
return get_staged(linfo)
else
# TODO: post-inference see if we can swap back to the original arrays?
if isa(m.source, Array{UInt8,1})
Expand All @@ -489,6 +488,35 @@ function retrieve_code_info(linfo::MethodInstance)
return c
end

# TODO: Use these functions instead of directly manipulating
# the "actual" method for appropriate places in inference (see #24676)
function method_for_inference_heuristics(cinfo, default)
if isa(cinfo, CodeInfo)
# appropriate format for `sig` is svec(ftype, argtypes, world)
sig = cinfo.signature_for_inference_heuristics
if isa(sig, SimpleVector) && length(sig) == 3
methods = _methods(sig[1], sig[2], -1, sig[3])
if length(methods) == 1
_, _, m = methods[]
if isa(m, Method)
return m
end
end
end
end
return default
end

function method_for_inference_heuristics(method::Method, @nospecialize(sig), sparams, world)
if isdefined(method, :generator) && method.generator.expand_early
method_instance = code_for_method(method, sig, sparams, world, false)
if isa(method_instance, MethodInstance)
return method_for_inference_heuristics(get_staged(method_instance), method)
end
end
return method
end

@inline slot_id(s) = isa(s, SlotNumber) ? (s::SlotNumber).id : (s::TypedSlot).id # using a function to ensure we can infer this

# avoid cycle due to over-specializing `any` when used by inference
Expand Down Expand Up @@ -3396,6 +3424,7 @@ function typeinf_code(linfo::MethodInstance, optimize::Bool, cached::Bool,
method = linfo.def::Method
tree = ccall(:jl_new_code_info_uninit, Ref{CodeInfo}, ())
tree.code = Any[ Expr(:return, quoted(linfo.inferred_const)) ]
tree.signature_for_inference_heuristics = nothing
tree.slotnames = Any[ compiler_temp_sym for i = 1:method.nargs ]
tree.slotflags = UInt8[ 0 for i = 1:method.nargs ]
tree.slottypes = nothing
Expand Down
8 changes: 5 additions & 3 deletions src/jltypes.c
Original file line number Diff line number Diff line change
Expand Up @@ -2046,8 +2046,9 @@ void jl_init_types(void)
jl_code_info_type =
jl_new_datatype(jl_symbol("CodeInfo"), core,
jl_any_type, jl_emptysvec,
jl_perm_symsvec(9,
jl_perm_symsvec(10,
"code",
"signature_for_inference_heuristics",
"slottypes",
"ssavaluetypes",
"slotflags",
Expand All @@ -2056,17 +2057,18 @@ void jl_init_types(void)
"inlineable",
"propagate_inbounds",
"pure"),
jl_svec(9,
jl_svec(10,
jl_array_any_type,
jl_any_type,
jl_any_type,
jl_any_type,
jl_array_uint8_type,
jl_array_any_type,
jl_bool_type,
jl_bool_type,
jl_bool_type,
jl_bool_type),
0, 1, 9);
0, 1, 10);

jl_method_type =
jl_new_datatype(jl_symbol("Method"), core,
Expand Down
3 changes: 2 additions & 1 deletion src/julia-syntax.scm
Original file line number Diff line number Diff line change
Expand Up @@ -362,7 +362,8 @@
'nothing
(cons 'list (map car sparams)))
,(if (null? loc) 0 (cadr loc))
(inert ,(if (null? loc) 'none (caddr loc))))))))
(inert ,(if (null? loc) 'none (caddr loc)))
false)))))
(list gf))
'()))
(types (llist-types argl))
Expand Down
1 change: 1 addition & 0 deletions src/julia.h
Original file line number Diff line number Diff line change
Expand Up @@ -229,6 +229,7 @@ typedef struct _jl_llvm_functions_t {
// This type describes a single function body
typedef struct _jl_code_info_t {
jl_array_t *code; // Any array of statements
jl_value_t *signature_for_inference_heuristics; // optional method used during inference
jl_value_t *slottypes; // types of variable slots (or `nothing`)
jl_value_t *ssavaluetypes; // types of ssa values (or count of them)
jl_array_t *slotflags; // local var bit flags
Expand Down
6 changes: 4 additions & 2 deletions src/method.c
Original file line number Diff line number Diff line change
Expand Up @@ -187,6 +187,7 @@ static void jl_code_info_set_ast(jl_code_info_t *li, jl_expr_t *ast)
jl_array_del_end(meta, na - ins);
}
}
li->signature_for_inference_heuristics = jl_nothing;
jl_array_t *vinfo = (jl_array_t*)jl_exprarg(ast, 1);
jl_array_t *vis = (jl_array_t*)jl_array_ptr_ref(vinfo, 0);
size_t nslots = jl_array_len(vis);
Expand Down Expand Up @@ -255,6 +256,7 @@ JL_DLLEXPORT jl_code_info_t *jl_new_code_info_uninit(void)
(jl_code_info_t*)jl_gc_alloc(ptls, sizeof(jl_code_info_t),
jl_code_info_type);
src->code = NULL;
src->signature_for_inference_heuristics = NULL;
src->slotnames = NULL;
src->slotflags = NULL;
src->slottypes = NULL;
Expand Down Expand Up @@ -442,8 +444,8 @@ static void jl_method_set_source(jl_method_t *m, jl_code_info_t *src)
else if (jl_expr_nargs(st) == 2 && jl_exprarg(st, 0) == (jl_value_t*)generated_sym) {
m->generator = NULL;
jl_value_t *gexpr = jl_exprarg(st, 1);
if (jl_expr_nargs(gexpr) == 6) {
// expects (new (core GeneratedFunctionStub) funcname argnames sp line file)
if (jl_expr_nargs(gexpr) == 7) {
// expects (new (core GeneratedFunctionStub) funcname argnames sp line file expandearly)
jl_value_t *funcname = jl_exprarg(gexpr, 1);
assert(jl_is_symbol(funcname));
if (jl_get_global(m->module, (jl_sym_t*)funcname) != NULL) {
Expand Down
1 change: 1 addition & 0 deletions src/toplevel.c
Original file line number Diff line number Diff line change
Expand Up @@ -576,6 +576,7 @@ static jl_code_info_t *expr_to_code_info(jl_value_t *expr)
jl_gc_wb(src, src->slotflags);
src->ssavaluetypes = jl_box_long(0);
jl_gc_wb(src, src->ssavaluetypes);
src->signature_for_inference_heuristics = jl_nothing;

JL_GC_POP();
return src;
Expand Down
74 changes: 74 additions & 0 deletions test/inference.jl
Original file line number Diff line number Diff line change
Expand Up @@ -1317,3 +1317,77 @@ bar_22708(x) = f_22708(x)

@test bar_22708(1) == "x"

# mechanism for spoofing work-limiting heuristics and early generator expansion (#24852)
function _generated_stub(gen::Symbol, args::Vector{Any}, params::Vector{Any}, line, file, expand_early)
stub = Expr(:new, Core.GeneratedFunctionStub, gen, args, params, line, file, expand_early)
return Expr(:meta, :generated, stub)
end

f24852_kernel(x, y) = x * y

function f24852_kernel_cinfo(x, y)
sig, spvals, method = Base._methods_by_ftype(Tuple{typeof(f24852_kernel),x,y}, -1, typemax(UInt))[1]
code_info = Base.uncompressed_ast(method)
body = Expr(:block, code_info.code...)
Base.Core.Inference.substitute!(body, 0, Any[], sig, Any[spvals...], 0, :propagate)
return method, code_info
end

function f24852_gen_cinfo_uninflated(X, Y, f, x, y)
_, code_info = f24852_kernel_cinfo(x, y)
return code_info
end

function f24852_gen_cinfo_inflated(X, Y, f, x, y)
method, code_info = f24852_kernel_cinfo(x, y)
code_info.signature_for_inference_heuristics = Core.Inference.svec(f, (x, y), typemax(UInt))
return code_info
end

function f24852_gen_expr(X, Y, f, x, y)
return :(f24852_kernel(x::$X, y::$Y))
end

@eval begin
function f24852_late_expr(x::X, y::Y) where {X, Y}
$(_generated_stub(:f24852_gen_expr, Any[:f24852_late_expr, :x, :y],
Any[:X, :Y], @__LINE__, QuoteNode(Symbol(@__FILE__)), false))
end
function f24852_late_inflated(x::X, y::Y) where {X, Y}
$(_generated_stub(:f24852_gen_cinfo_inflated, Any[:f24852_late_inflated, :x, :y],
Any[:X, :Y], @__LINE__, QuoteNode(Symbol(@__FILE__)), false))
end
function f24852_late_uninflated(x::X, y::Y) where {X, Y}
$(_generated_stub(:f24852_gen_cinfo_uninflated, Any[:f24852_late_uninflated, :x, :y],
Any[:X, :Y], @__LINE__, QuoteNode(Symbol(@__FILE__)), false))
end
end

@eval begin
function f24852_early_expr(x::X, y::Y) where {X, Y}
$(_generated_stub(:f24852_gen_expr, Any[:f24852_early_expr, :x, :y],
Any[:X, :Y], @__LINE__, QuoteNode(Symbol(@__FILE__)), true))
end
function f24852_early_inflated(x::X, y::Y) where {X, Y}
$(_generated_stub(:f24852_gen_cinfo_inflated, Any[:f24852_early_inflated, :x, :y],
Any[:X, :Y], @__LINE__, QuoteNode(Symbol(@__FILE__)), true))
end
function f24852_early_uninflated(x::X, y::Y) where {X, Y}
$(_generated_stub(:f24852_gen_cinfo_uninflated, Any[:f24852_early_uninflated, :x, :y],
Any[:X, :Y], @__LINE__, QuoteNode(Symbol(@__FILE__)), true))
end
end

x, y = rand(), rand()
result = f24852_kernel(x, y)

@test result === f24852_late_expr(x, y)
@test result === f24852_late_uninflated(x, y)
@test result === f24852_late_inflated(x, y)

@test result === f24852_early_expr(x, y)
@test result === f24852_early_uninflated(x, y)
@test result === f24852_early_inflated(x, y)

# TODO: test that `expand_early = true` + inflated `signature_for_inference_heuristics`
# can be used to tighten up some inference result.

2 comments on commit ec9e92e

@nanosoldier
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Executing the daily benchmark build, I will reply here when finished:

@nanosoldier runbenchmarks(ALL, isdaily = true)

@nanosoldier
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Something went wrong when running your job:

NanosoldierError: failed to run benchmarks against primary commit: failed process: Process(`sudo cset shield -e su nanosoldier -- -c ./benchscript.sh`, ProcessExited(1)) [1]

Logs and partial data can be found here
cc @ararslan

Please sign in to comment.