Skip to content

Commit

Permalink
post-opt: use augmented post-domtree for visit_conditional_successors
Browse files Browse the repository at this point in the history
This commit fixes the first problem that was found while digging into
#53613. It turns out that the post-domtree constructed
from regular `IRCode` doesn't work for visiting conditional successors
for post-opt analysis in cases like:
```julia
julia> let code = Any[
               # block 1
               GotoIfNot(Argument(2), 3),
               # block 2
               ReturnNode(Argument(3)),
               # block 3 (we should visit this block)
               Expr(:call, throw, "potential throw"),
               ReturnNode(), # unreachable
           ]
           ir = make_ircode(code; slottypes=Any[Any,Bool,Bool])
           visited = BitSet()
           @test !Core.Compiler.visit_conditional_successors(CC.LazyPostDomtree(ir), ir, #=bb=#1) do succ::Int
               push!(visited, succ)
               return false
           end
           @test 2 βˆ‰ visited
           @test 3 ∈ visited
       end
Test Failed at REPL[14]:16
  Expression: 2 βˆ‰ visited
   Evaluated: 2 βˆ‰ BitSet([2])
```

This might mean that we need to fix on the `postdominates` end, but for
now, this commit tries to get around it by using the augmented post
domtree in `visit_conditional_successors`, while also enforcing the
augmented control flow graph (`construct_augmented_cfg`) to have a
single exit node really. Since the augmented post domtree is now
enforced to have a single return, we can keep using the current
`postdominates` to fix the issue.

However, this commit isn't enough to fix the NeuralNetworkReachability
segfault as reported in #53613, and we need to tackle the second issue
reported there too (#53613 (comment)).
  • Loading branch information
aviatesk committed Mar 8, 2024
1 parent 4dcf357 commit cd7508b
Show file tree
Hide file tree
Showing 4 changed files with 316 additions and 47 deletions.
81 changes: 45 additions & 36 deletions base/compiler/optimize.jl
Original file line number Diff line number Diff line change
Expand Up @@ -527,15 +527,53 @@ function any_stmt_may_throw(ir::IRCode, bb::Int)
return false
end

function visit_conditional_successors(callback, lazypostdomtree::LazyPostDomtree, ir::IRCode, bb::Int)
mutable struct LazyAugmentedDomtrees
const ir::IRCode
cfg::CFG
domtree::DomTree
postdomtree::PostDomTree
LazyAugmentedDomtrees(ir::IRCode) = new(ir)
end

function get!(lazyagdomtrees::LazyAugmentedDomtrees, sym::Symbol)
isdefined(lazyagdomtrees, sym) && return getfield(lazyagdomtrees, sym)
if sym === :cfg
return lazyagdomtrees.cfg = construct_augmented_cfg(lazyagdomtrees.ir)
elseif sym === :domtree
return lazyagdomtrees.domtree = construct_domtree(get!(lazyagdomtrees, :cfg))
elseif sym === :postdomtree
return lazyagdomtrees.postdomtree = construct_postdomtree(get!(lazyagdomtrees, :cfg))
else
error("invalid field access")
end
end

function construct_augmented_cfg(ir::IRCode)
cfg = copy(ir.cfg)
# Add a virtual basic block to represent the single exit
push!(cfg.blocks, BasicBlock(StmtRange(0:-1)))
for bb = 1:(length(cfg.blocks)-1)
terminator = ir[SSAValue(last(cfg.blocks[bb].stmts))][:stmt]
if terminator isa ReturnNode
cfg_insert_edge!(cfg, bb, length(cfg.blocks))
end
end
return cfg
end

visit_conditional_successors(callback, ir::IRCode, bb::Int) =
visit_conditional_successors(callback, construct_postdomtree(construct_augmented_cfg(ir)), ir, bb)
visit_conditional_successors(callback, lazyagdomtrees::LazyAugmentedDomtrees, ir::IRCode, bb::Int) =
visit_conditional_successors(callback, get!(lazyagdomtrees, :postdomtree), ir, bb)
function visit_conditional_successors(callback, postdomtree::PostDomTree, ir::IRCode, bb::Int)
visited = BitSet((bb,))
worklist = Int[bb]
while !isempty(worklist)
thisbb = popfirst!(worklist)
for succ in ir.cfg.blocks[thisbb].succs
succ in visited && continue
push!(visited, succ)
if postdominates(get!(lazypostdomtree), succ, bb)
if postdominates(postdomtree, succ, bb)
# this successor is not conditional, so no need to visit it further
continue
elseif callback(succ)
Expand All @@ -548,40 +586,12 @@ function visit_conditional_successors(callback, lazypostdomtree::LazyPostDomtree
return false
end

struct AugmentedDomtree
cfg::CFG
domtree::DomTree
end

mutable struct LazyAugmentedDomtree
const ir::IRCode
agdomtree::AugmentedDomtree
LazyAugmentedDomtree(ir::IRCode) = new(ir)
end

function get!(lazyagdomtree::LazyAugmentedDomtree)
isdefined(lazyagdomtree, :agdomtree) && return lazyagdomtree.agdomtree
ir = lazyagdomtree.ir
cfg = copy(ir.cfg)
# Add a virtual basic block to represent the exit
push!(cfg.blocks, BasicBlock(StmtRange(0:-1)))
for bb = 1:(length(cfg.blocks)-1)
terminator = ir[SSAValue(last(cfg.blocks[bb].stmts))][:stmt]
if isa(terminator, ReturnNode) && isdefined(terminator, :val)
cfg_insert_edge!(cfg, bb, length(cfg.blocks))
end
end
domtree = construct_domtree(cfg)
return lazyagdomtree.agdomtree = AugmentedDomtree(cfg, domtree)
end

mutable struct PostOptAnalysisState
const result::InferenceResult
const ir::IRCode
const inconsistent::BitSetBoundedMinPrioritySet
const tpdum::TwoPhaseDefUseMap
const lazypostdomtree::LazyPostDomtree
const lazyagdomtree::LazyAugmentedDomtree
const lazyagdomtrees::LazyAugmentedDomtrees
const ea_analysis_pending::Vector{Int}
all_retpaths_consistent::Bool
all_effect_free::Bool
Expand All @@ -592,9 +602,8 @@ mutable struct PostOptAnalysisState
function PostOptAnalysisState(result::InferenceResult, ir::IRCode)
inconsistent = BitSetBoundedMinPrioritySet(length(ir.stmts))
tpdum = TwoPhaseDefUseMap(length(ir.stmts))
lazypostdomtree = LazyPostDomtree(ir)
lazyagdomtree = LazyAugmentedDomtree(ir)
return new(result, ir, inconsistent, tpdum, lazypostdomtree, lazyagdomtree, Int[],
lazyagdomtrees = LazyAugmentedDomtrees(ir)
return new(result, ir, inconsistent, tpdum, lazyagdomtrees, Int[],
true, true, nothing, true, true, false)
end
end
Expand Down Expand Up @@ -834,13 +843,13 @@ function ((; sv)::ScanStmt)(inst::Instruction, lstmt::Int, bb::Int)
# inconsistent region.
if !sv.result.ipo_effects.terminates
sv.all_retpaths_consistent = false
elseif visit_conditional_successors(sv.lazypostdomtree, sv.ir, bb) do succ::Int
elseif visit_conditional_successors(sv.lazyagdomtrees, sv.ir, bb) do succ::Int
return any_stmt_may_throw(sv.ir, succ)
end
# check if this `GotoIfNot` leads to conditional throws, which taints consistency
sv.all_retpaths_consistent = false
else
(; cfg, domtree) = get!(sv.lazyagdomtree)
cfg, domtree = get!(sv.lazyagdomtrees, :cfg), get!(sv.lazyagdomtrees, :domtree)
for succ in iterated_dominance_frontier(cfg, BlockLiveness(sv.ir.cfg.blocks[bb].succs, nothing), domtree)
if succ == length(cfg.blocks)
# Phi node in the virtual exit -> We have a conditional
Expand Down
206 changes: 206 additions & 0 deletions finite-iterate-interp.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,206 @@
include("test/compiler/newinterp.jl")

@newinterp FiniteIterateInterpreter

using Core.Compiler:
AbstractLattice, BaseInferenceLattice, IPOResultLattice, InferenceLattice,
ConditionalsLattice, InterConditionalsLattice, PartialsLattice, ConstsLattice,
SimpleInferenceLattice,
widenlattice, is_valid_lattice_norec, typeinf_lattice, ipo_lattice, optimizer_lattice,
widenconst, tmeet, tmerge, βŠ‘, ⊏, abstract_eval_special_value, widenreturn

const CC = Core.Compiler

struct FiniteIterateLattice{L<:AbstractLattice} <: AbstractLattice
parent::L
end
CC.widenlattice(𝕃::FiniteIterateLattice) = 𝕃.parent
CC.is_valid_lattice_norec(::FiniteIterateLattice, @nospecialize(elm)) = _is_finite_lattice(elm)
_is_finite_lattice(@nospecialize t) = (
isa(t, FiniteIterate) || isa(t, FiniteState) || isa(t, TerminatingCondition))

CC.typeinf_lattice(::FiniteIterateInterpreter) =
InferenceLattice(ConditionalsLattice(PartialsLattice(FiniteIterateLattice(ConstsLattice()))))
CC.ipo_lattice(::FiniteIterateInterpreter) =
InferenceLattice(InterConditionalsLattice(PartialsLattice(ConstsLattice())))
CC.optimizer_lattice(::FiniteIterateInterpreter) =
FiniteIterateLattice(SimpleInferenceLattice.instance)

struct FiniteIterate
typ
itr
function FiniteIterate(@nospecialize(typ), @nospecialize(itr))
@assert !_is_finite_lattice(typ) "nested FiniteLattice"
return new(typ, itr)
end
end
struct FiniteState
typ
itr
function FiniteState(@nospecialize(typ), @nospecialize(itr))
@assert !_is_finite_lattice(typ) "nested FiniteLattice"
return new(typ, itr)
end
end
struct TerminatingCondition end
function CC.tmeet(𝕃::FiniteIterateLattice, @nospecialize(v), @nospecialize(t::Type))
if isa(v, FiniteIterate)
error("tmeet FiniteIterate")
v = v.typ
elseif isa(v, FiniteState)
error("tmeet FiniteState")
v = v.typ
elseif isa(v, TerminatingCondition)
error("tmeet TerminatingCondition")
if t === Bool
return TerminatingCondition()
end
return Bool
end
return tmeet(widenlattice(𝕃), v, t)
end
function CC.tmerge(𝕃::FiniteIterateLattice, @nospecialize(x), @nospecialize(y))
if isa(x, FiniteIterate)
if isa(y, FiniteIterate) && x.itr === y.itr
return FiniteIterate(tmerge(widenlattice(𝕃), x.typ, y.typ), x.itr)
end
x = x.typ
elseif isa(y, FiniteIterate)
y = y.typ
end
if isa(x, FiniteState)
if isa(y, FiniteState) && x.itr === y.itr
return FiniteState(tmerge(widenlattice(𝕃), x.typ, y.typ), x.itr)
end
x = x.typ
elseif isa(y, FiniteState)
y = y.typ
end
if isa(x, TerminatingCondition)
if isa(y, TerminatingCondition)
return TerminatingCondition()
end
x = Bool
elseif isa(y, TerminatingCondition)
y = Bool
end
return tmerge(widenlattice(𝕃), x, y)
end
function CC.:βŠ‘(𝕃::FiniteIterateLattice, @nospecialize(x), @nospecialize(y))
if isa(x, FiniteIterate)
if isa(y, FiniteIterate)
if x.itr === y.itr
return βŠ‘(widenlattice(𝕃), x.typ, y.typ)
end
return false
elseif isa(y, FiniteState)
return false
elseif isa(y, TerminatingCondition)
return false
end
x = x.typ
elseif isa(y, FiniteIterate)
return x === Union{}
end
if isa(x, FiniteState)
if isa(y, FiniteState)
if x.itr === y.itr
return βŠ‘(widenlattice(𝕃), x.typ, y.typ)
end
return false
elseif isa(y, TerminatingCondition)
return false
end
x = x.typ
elseif isa(y, FiniteState)
return x === Union{}
end
if isa(x, TerminatingCondition)
return x !== Union{}
elseif isa(y, TerminatingCondition)
return x === Union{}
end
return βŠ‘(widenlattice(𝕃), x, y)
end
CC.widenconst(fi::FiniteIterate) = widenconst(fi.typ)
CC.widenconst(fs::FiniteState) = widenconst(fs.typ)
CC.widenconst(::TerminatingCondition) = Bool
function widenfiniteiterate(@nospecialize x)
if isa(x, FiniteIterate)
return x.typ
elseif isa(x, FiniteState)
return x.typ
elseif isa(x, TerminatingCondition)
return Bool
end
return x
end
CC.widenreturn(𝕃::FiniteIterateLattice, @nospecialize(rt), info::CC.BestguessInfo) =
CC.widenreturn(widenlattice(𝕃), widenfiniteiterate(rt), info)

function CC.abstract_call_known(interp::FiniteIterateInterpreter, @nospecialize(f),
arginfo::CC.ArgInfo, si::CC.StmtInfo, sv::CC.AbsIntState,
max_methods::Int)
res = @invoke CC.abstract_call_known(interp::CC.AbstractInterpreter, f::Any,
arginfo::CC.ArgInfo, si::CC.StmtInfo, sv::CC.AbsIntState,
max_methods::Int)
(; fargs, argtypes) = arginfo
la = length(argtypes)
if la β‰₯ 2 && fargs !== nothing && CC.istopfunction(f, :iterate)
if res.rt !== Union{}
if argtypes[2] βŠ‘ Tuple
if la == 2
res = CC.CallMeta(FiniteIterate(res.rt, fargs[2]), res.exct, res.effects, res.info)
elseif la == 3
a3 = argtypes[3]
if a3 isa FiniteState && a3.itr === fargs[2]
res = CC.CallMeta(FiniteIterate(res.rt, a3.itr), res.exct, res.effects, res.info)
end
end
end
end
end
if la == 3 && CC.istopfunction(f, :(===))
if CC.widenconditional(res.rt) === Bool && argtypes[2] isa FiniteIterate
return CC.CallMeta(TerminatingCondition(), res.exct, res.effects, res.info)
end
end
return res
end
CC.@nospecs function CC._getfield_tfunc(𝕃::FiniteIterateLattice, s00, name, setfield::Bool)
if isa(s00, FiniteIterate)
if name isa Core.Const && name.val === 2
rt = CC._getfield_tfunc(widenlattice(𝕃), s00.typ, name, setfield)
if rt !== Union{}
return FiniteState(rt, s00.itr)
end
end
s00 = s00.typ
elseif isa(s00, FiniteState)
s00 = s00.typ
elseif isa(s00, TerminatingCondition)
return Union{}
end
name = widenfiniteiterate(name)
return CC._getfield_tfunc(widenlattice(𝕃), s00, name, setfield)
end
CC.@nospecs function CC.not_int_tfunc(𝕃::FiniteIterateLattice, x)
if isa(x, TerminatingCondition)
return TerminatingCondition()
end
return CC.not_int_tfunc(widenlattice(𝕃), x)
end
function CC.handle_control_backedge!(interp::FiniteIterateInterpreter, frame::CC.InferenceState,
from::Int, to::Int, @nospecialize(condt))
if condt === TerminatingCondition()
return nothing
end
@invoke CC.handle_control_backedge!(interp::CC.AbstractInterpreter, frame::CC.InferenceState,
from::Int, to::Int, condt::Any)
end

using Test
@test Base.infer_effects(; interp=FiniteIterateInterpreter()) do
for i = (1,2,3)
end
end |> CC.is_terminates
10 changes: 10 additions & 0 deletions test/compiler/effects.jl
Original file line number Diff line number Diff line change
Expand Up @@ -1387,3 +1387,13 @@ let; Base.Experimental.@force_compile; func52843(); end
# https://github.com/JuliaLang/julia/issues/53508
@test !Core.Compiler.is_consistent(Base.infer_effects(getindex, (UnitRange{Int},Int)))
@test !Core.Compiler.is_consistent(Base.infer_effects(getindex, (Base.OneTo{Int},Int)))

@noinline f53613() = @assert isdefined(@__MODULE__, :v53613)
g53613() = f53613()
@test !Core.Compiler.is_consistent(Base.infer_effects(f53613))
@test_broken !Core.Compiler.is_consistent(Base.infer_effects(g53613))
@test_throws AssertionError f53613()
@test_throws AssertionError g53613()
global v53613 = nothing
@test f53613() === nothing
@test g53613() === nothing
Loading

0 comments on commit cd7508b

Please sign in to comment.