Skip to content

Commit

Permalink
SSAIR: improve inlining performance with in-place IR-inflation (#45404)
Browse files Browse the repository at this point in the history
This commit improves the performance of a huge hot-spot within `inflate_ir`
by using the in-place version of it (`inflate_ir!`) and avoiding some
unnecessary allocations.
For `NativeInterpreter`, `CodeInfo`-IR passed to `inflate_ir` can come
from two ways:
1. from global cache: uncompressed from compressed format
2. from local cache: inferred `CodeInfo` as-is managed by `InferenceResult`

And in the case of 1, an uncompressed `CodeInfo` is an newly-allocated
object already and thus we can use the in-place version safely. And it
turns out that this helps us avoid many unnecessary allocations.
The original non-destructive `inflate_ir` remains there for testing or
interactive purpose.
  • Loading branch information
aviatesk authored and pchintalapudi committed May 25, 2022
1 parent 672602a commit f1fa94a
Show file tree
Hide file tree
Showing 2 changed files with 65 additions and 41 deletions.
27 changes: 13 additions & 14 deletions base/compiler/ssair/inlining.jl
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,8 @@ struct ResolvedInliningSpec
# Effects of the call statement
effects::Effects
end
ResolvedInliningSpec(ir::IRCode, effects::Effects) =
ResolvedInliningSpec(ir, linear_inline_eligible(ir), effects)

"""
Represents a callsite that our analysis has determined is legal to inline,
Expand Down Expand Up @@ -815,7 +817,7 @@ function resolve_todo(todo::InliningTodo, state::InliningState, flag::UInt8)
et !== nothing && push!(et, mi)
return ConstantCase(quoted(inferred_src.val))
else
src = inferred_src
src = inferred_src # ::Union{Nothing,CodeInfo} for NativeInterpreter
end
effects = match.ipo_effects
else
Expand All @@ -829,7 +831,7 @@ function resolve_todo(todo::InliningTodo, state::InliningState, flag::UInt8)
src = code.inferred
end
effects = decode_effects(code.ipo_purity_bits)
else
else # fallback pass for external AbstractInterpreter cache
effects = Effects()
src = code
end
Expand All @@ -843,13 +845,7 @@ function resolve_todo(todo::InliningTodo, state::InliningState, flag::UInt8)

src = inlining_policy(state.interp, src, flag, mi, argtypes)

if src === nothing
return compileable_specialization(et, match, effects)
end

if isa(src, IRCode)
src = copy(src)
end
src === nothing && return compileable_specialization(et, match, effects)

et !== nothing && push!(et, mi)
return InliningTodo(mi, src, effects)
Expand Down Expand Up @@ -913,16 +909,19 @@ function analyze_method!(match::MethodMatch, argtypes::Vector{Any},
end

function InliningTodo(mi::MethodInstance, ir::IRCode, effects::Effects)
return InliningTodo(mi, ResolvedInliningSpec(ir, linear_inline_eligible(ir), effects))
ir = copy(ir)
return InliningTodo(mi, ResolvedInliningSpec(ir, effects))
end

function InliningTodo(mi::MethodInstance, src::Union{CodeInfo, Array{UInt8, 1}}, effects::Effects)
function InliningTodo(mi::MethodInstance, src::Union{CodeInfo, Vector{UInt8}}, effects::Effects)
if !isa(src, CodeInfo)
src = ccall(:jl_uncompress_ir, Any, (Any, Ptr{Cvoid}, Any), mi.def, C_NULL, src::Vector{UInt8})::CodeInfo
else
src = copy(src)
end

@timeit "inline IR inflation" begin;
return InliningTodo(mi, inflate_ir(src, mi)::IRCode, effects)
@timeit "inline IR inflation" begin
ir = inflate_ir!(src, mi)::IRCode
return InliningTodo(mi, ResolvedInliningSpec(ir, effects))
end
end

Expand Down
79 changes: 52 additions & 27 deletions base/compiler/ssair/legacy.jl
Original file line number Diff line number Diff line change
@@ -1,17 +1,24 @@
# This file is a part of Julia. License is MIT: https://julialang.org/license

function inflate_ir(ci::CodeInfo, linfo::MethodInstance)
"""
inflate_ir!(ci::CodeInfo, linfo::MethodInstance) -> ir::IRCode
inflate_ir!(ci::CodeInfo, sptypes::Vector{Any}, argtypes::Vector{Any}) -> ir::IRCode
Inflates `ci::CodeInfo`-IR to `ir::IRCode`-format.
This should be used with caution as it is a in-place transformation where the fields of
the original `ci::CodeInfo` are modified.
"""
function inflate_ir!(ci::CodeInfo, linfo::MethodInstance)
sptypes = sptypes_from_meth_instance(linfo)
if ci.inferred
argtypes, _ = matching_cache_argtypes(linfo, nothing)
else
argtypes = Any[ Any for i = 1:length(ci.slotflags) ]
end
return inflate_ir(ci, sptypes, argtypes)
return inflate_ir!(ci, sptypes, argtypes)
end

function inflate_ir(ci::CodeInfo, sptypes::Vector{Any}, argtypes::Vector{Any})
code = copy_exprargs(ci.code) # TODO: this is a huge hot-spot
function inflate_ir!(ci::CodeInfo, sptypes::Vector{Any}, argtypes::Vector{Any})
code = ci.code
cfg = compute_basic_blocks(code)
for i = 1:length(code)
stmt = code[i]
Expand All @@ -22,49 +29,67 @@ function inflate_ir(ci::CodeInfo, sptypes::Vector{Any}, argtypes::Vector{Any})
code[i] = GotoIfNot(stmt.cond, block_for_inst(cfg, stmt.dest))
elseif isa(stmt, PhiNode)
code[i] = PhiNode(Int32[block_for_inst(cfg, Int(edge)) for edge in stmt.edges], stmt.values)
elseif isa(stmt, Expr) && stmt.head === :enter
elseif isexpr(stmt, :enter)
stmt.args[1] = block_for_inst(cfg, stmt.args[1]::Int)
code[i] = stmt
end
end
nstmts = length(code)
ssavaluetypes = let ssavaluetypes = ci.ssavaluetypes
ssavaluetypes isa Vector{Any} ? copy(ssavaluetypes) : Any[ Any for i = 1:(ssavaluetypes::Int) ]
ssavaluetypes = ci.ssavaluetypes
if !isa(ssavaluetypes, Vector{Any})
ssavaluetypes = Any[ Any for i = 1:ssavaluetypes::Int ]
end
info = Any[nothing for i = 1:nstmts]
stmts = InstructionStream(code, ssavaluetypes, info, ci.codelocs, ci.ssaflags)
linetable = ci.linetable
if !isa(linetable, Vector{LineInfoNode})
linetable = collect(LineInfoNode, linetable::Vector{Any})::Vector{LineInfoNode}
end
stmts = InstructionStream(code, ssavaluetypes, Any[nothing for i = 1:nstmts], copy(ci.codelocs), copy(ci.ssaflags))
ir = IRCode(stmts, cfg, collect(LineInfoNode, ci.linetable), argtypes, Expr[], sptypes)
return ir
meta = Expr[]
return IRCode(stmts, cfg, linetable, argtypes, meta, sptypes)
end

"""
inflate_ir(ci::CodeInfo, linfo::MethodInstance) -> ir::IRCode
inflate_ir(ci::CodeInfo, sptypes::Vector{Any}, argtypes::Vector{Any}) -> ir::IRCode
inflate_ir(ci::CodeInfo) -> ir::IRCode
Non-destructive version of `inflate_ir!`.
Mainly used for testing or interactive use.
"""
inflate_ir(ci::CodeInfo, linfo::MethodInstance) = inflate_ir!(copy(ci), linfo)
inflate_ir(ci::CodeInfo, sptypes::Vector{Any}, argtypes::Vector{Any}) = inflate_ir!(copy(ci), sptypes, argtypes)
inflate_ir(ci::CodeInfo) = inflate_ir(ci, Any[], Any[ Any for i = 1:length(ci.slotflags) ])

function replace_code_newstyle!(ci::CodeInfo, ir::IRCode, nargs::Int)
@assert isempty(ir.new_nodes)
# All but the first `nargs` slots will now be unused
resize!(ci.slotflags, nargs)
stmts = ir.stmts
ci.code, ci.ssavaluetypes, ci.codelocs, ci.ssaflags, ci.linetable =
stmts.inst, stmts.type, stmts.line, stmts.flag, ir.linetable
code = ci.code = stmts.inst
ssavaluetypes = ci.ssavaluetypes = stmts.type
codelocs = ci.codelocs = stmts.line
ssaflags = ci.ssaflags = stmts.flag
linetable = ci.linetable = ir.linetable
for metanode in ir.meta
push!(ci.code, metanode)
push!(ci.codelocs, 1)
push!(ci.ssavaluetypes::Vector{Any}, Any)
push!(ci.ssaflags, IR_FLAG_NULL)
push!(code, metanode)
push!(codelocs, 1)
push!(ssavaluetypes, Any)
push!(ssaflags, IR_FLAG_NULL)
end
# Translate BB Edges to statement edges
# (and undo normalization for now)
for i = 1:length(ci.code)
stmt = ci.code[i]
for i = 1:length(code)
stmt = code[i]
if isa(stmt, GotoNode)
stmt = GotoNode(first(ir.cfg.blocks[stmt.label].stmts))
code[i] = GotoNode(first(ir.cfg.blocks[stmt.label].stmts))
elseif isa(stmt, GotoIfNot)
stmt = GotoIfNot(stmt.cond, first(ir.cfg.blocks[stmt.dest].stmts))
code[i] = GotoIfNot(stmt.cond, first(ir.cfg.blocks[stmt.dest].stmts))
elseif isa(stmt, PhiNode)
stmt = PhiNode(Int32[last(ir.cfg.blocks[edge].stmts) for edge in stmt.edges], stmt.values)
elseif isa(stmt, Expr) && stmt.head === :enter
code[i] = PhiNode(Int32[last(ir.cfg.blocks[edge].stmts) for edge in stmt.edges], stmt.values)
elseif isexpr(stmt, :enter)
stmt.args[1] = first(ir.cfg.blocks[stmt.args[1]::Int].stmts)
code[i] = stmt
end
ci.code[i] = stmt
end
end

# used by some tests
inflate_ir(ci::CodeInfo) = inflate_ir(ci, Any[], Any[ Any for i = 1:length(ci.slotflags) ])

0 comments on commit f1fa94a

Please sign in to comment.