From 8f6e945e71cd8478d6191b0bed77a830d1c24217 Mon Sep 17 00:00:00 2001 From: Keno Fischer Date: Sun, 12 Mar 2023 05:52:19 +0000 Subject: [PATCH 1/2] Allow external lattice elements to properly union split Currently `MustAlias` is the only lattice element that is allowed to widen to union types. However, there are others in external packages. Expand the support we have for this in order to allow union splitting of lattice elements. --- base/compiler/abstractinterpretation.jl | 13 ++++++------ base/compiler/abstractlattice.jl | 4 ++++ base/compiler/tfuncs.jl | 2 +- base/compiler/typelattice.jl | 2 ++ base/compiler/typeutils.jl | 28 +++++++++++++------------ test/compiler/inference.jl | 14 ++++++------- 6 files changed, 36 insertions(+), 27 deletions(-) diff --git a/base/compiler/abstractinterpretation.jl b/base/compiler/abstractinterpretation.jl index 0edba39607179..eb368cc1dbb59 100644 --- a/base/compiler/abstractinterpretation.jl +++ b/base/compiler/abstractinterpretation.jl @@ -59,7 +59,7 @@ function abstract_call_gf_by_type(interp::AbstractInterpreter, @nospecialize(f), # as we may want to concrete-evaluate this frame in cases when there are # no overlayed calls, try an additional effort now to check if this call # isn't overlayed rather than just handling it conservatively - matches = find_matching_methods(arginfo.argtypes, atype, method_table(interp), + matches = find_matching_methods(ipo_lattice(interp), arginfo.argtypes, atype, method_table(interp), InferenceParams(interp).max_union_splitting, max_methods) if !isa(matches, FailedMethodMatch) nonoverlayed = matches.nonoverlayed @@ -75,7 +75,7 @@ function abstract_call_gf_by_type(interp::AbstractInterpreter, @nospecialize(f), end argtypes = arginfo.argtypes - matches = find_matching_methods(argtypes, atype, method_table(interp), + matches = find_matching_methods(ipo_lattice(interp), argtypes, atype, method_table(interp), InferenceParams(interp).max_union_splitting, max_methods) if isa(matches, FailedMethodMatch) add_remark!(interp, sv, matches.reason) @@ -273,11 +273,12 @@ struct UnionSplitMethodMatches end any_ambig(m::UnionSplitMethodMatches) = any(any_ambig, m.info.matches) -function find_matching_methods(argtypes::Vector{Any}, @nospecialize(atype), method_table::MethodTableView, +function find_matching_methods(𝕃::AbstractLattice, + argtypes::Vector{Any}, @nospecialize(atype), method_table::MethodTableView, max_union_splitting::Int, max_methods::Int) # NOTE this is valid as far as any "constant" lattice element doesn't represent `Union` type - if 1 < unionsplitcost(argtypes) <= max_union_splitting - split_argtypes = switchtupleunion(argtypes) + if 1 < unionsplitcost(𝕃, argtypes) <= max_union_splitting + split_argtypes = switchtupleunion(𝕃, argtypes) infos = MethodMatchInfo[] applicable = Any[] applicable_argtypes = Vector{Any}[] # arrays like `argtypes`, including constants, for each match @@ -1495,7 +1496,7 @@ function abstract_apply(interp::AbstractInterpreter, argtypes::Vector{Any}, si:: end res = Union{} nargs = length(aargtypes) - splitunions = 1 < unionsplitcost(aargtypes) <= InferenceParams(interp).max_apply_union_enum + splitunions = 1 < unionsplitcost(typeinf_lattice(interp), aargtypes) <= InferenceParams(interp).max_apply_union_enum ctypes = [Any[aft]] infos = Vector{MaybeAbstractIterationInfo}[MaybeAbstractIterationInfo[]] effects = EFFECTS_TOTAL diff --git a/base/compiler/abstractlattice.jl b/base/compiler/abstractlattice.jl index f578ec8d6f60d..a84050816cb21 100644 --- a/base/compiler/abstractlattice.jl +++ b/base/compiler/abstractlattice.jl @@ -293,6 +293,10 @@ has_mustalias(𝕃::AbstractLattice) = has_mustalias(widenlattice(𝕃)) has_mustalias(::AnyMustAliasesLattice) = true has_mustalias(::JLTypeLattice) = false +has_extended_unionsplit(𝕃::AbstractLattice) = has_extended_unionsplit(widenlattice(𝕃)) +has_extended_unionsplit(::AnyMustAliasesLattice) = true +has_extended_unionsplit(::JLTypeLattice) = false + # Curried versions ⊑(lattice::AbstractLattice) = (@nospecialize(a), @nospecialize(b)) -> ⊑(lattice, a, b) ⊏(lattice::AbstractLattice) = (@nospecialize(a), @nospecialize(b)) -> ⊏(lattice, a, b) diff --git a/base/compiler/tfuncs.jl b/base/compiler/tfuncs.jl index 41da17c19d6d2..a99dd2d81b881 100644 --- a/base/compiler/tfuncs.jl +++ b/base/compiler/tfuncs.jl @@ -2542,7 +2542,7 @@ function abstract_applicable(interp::AbstractInterpreter, argtypes::Vector{Any}, isvarargtype(argtypes[2]) && return CallMeta(Bool, EFFECTS_UNKNOWN, NoCallInfo()) argtypes = argtypes[2:end] atype = argtypes_to_type(argtypes) - matches = find_matching_methods(argtypes, atype, method_table(interp), + matches = find_matching_methods(ipo_lattice(interp), argtypes, atype, method_table(interp), InferenceParams(interp).max_union_splitting, max_methods) if isa(matches, FailedMethodMatch) rt = Bool # too many matches to analyze diff --git a/base/compiler/typelattice.jl b/base/compiler/typelattice.jl index 33d4d37e9c936..23f39d8b44f5e 100644 --- a/base/compiler/typelattice.jl +++ b/base/compiler/typelattice.jl @@ -120,6 +120,8 @@ end MustAlias(var::SlotNumber, @nospecialize(vartyp), fldidx::Int, @nospecialize(fldtyp)) = MustAlias(slot_id(var), vartyp, fldidx, fldtyp) +_uniontypes(x::MustAlias, ts) = _uniontypes(widenconst(x), ts) + """ alias::InterMustAlias diff --git a/base/compiler/typeutils.jl b/base/compiler/typeutils.jl index c94bc0ca2aa75..293ef5797888b 100644 --- a/base/compiler/typeutils.jl +++ b/base/compiler/typeutils.jl @@ -165,7 +165,7 @@ function typesubtract(@nospecialize(a), @nospecialize(b), max_union_splitting::I if ub isa DataType if a.name === ub.name === Tuple.name && length(a.parameters) == length(ub.parameters) - if 1 < unionsplitcost(a.parameters) <= max_union_splitting + if 1 < unionsplitcost(JLTypeLattice(), a.parameters) <= max_union_splitting ta = switchtupleunion(a) return typesubtract(Union{ta...}, b, 0) elseif b isa DataType @@ -227,12 +227,11 @@ end # or outside of the Tuple/Union nesting, though somewhat more expensive to be # outside than inside because the representation is larger (because and it # informs the callee whether any splitting is possible). -function unionsplitcost(argtypes::Union{SimpleVector,Vector{Any}}) +function unionsplitcost(𝕃::AbstractLattice, argtypes::Union{SimpleVector,Vector{Any}}) nu = 1 max = 2 for ti in argtypes - # TODO remove this to implement callsite refinement of MustAlias - if isa(ti, MustAlias) && isa(widenconst(ti), Union) + if has_extended_unionsplit(𝕃) && !isvarargtype(ti) ti = widenconst(ti) end if isa(ti, Union) @@ -252,12 +251,12 @@ end # and `Union{return...} == ty` function switchtupleunion(@nospecialize(ty)) tparams = (unwrap_unionall(ty)::DataType).parameters - return _switchtupleunion(Any[tparams...], length(tparams), [], ty) + return _switchtupleunion(JLTypeLattice(), Any[tparams...], length(tparams), [], ty) end -switchtupleunion(argtypes::Vector{Any}) = _switchtupleunion(argtypes, length(argtypes), [], nothing) +switchtupleunion(𝕃::AbstractLattice, argtypes::Vector{Any}) = _switchtupleunion(𝕃, argtypes, length(argtypes), [], nothing) -function _switchtupleunion(t::Vector{Any}, i::Int, tunion::Vector{Any}, @nospecialize(origt)) +function _switchtupleunion(𝕃::AbstractLattice, t::Vector{Any}, i::Int, tunion::Vector{Any}, @nospecialize(origt)) if i == 0 if origt === nothing push!(tunion, copy(t)) @@ -268,17 +267,20 @@ function _switchtupleunion(t::Vector{Any}, i::Int, tunion::Vector{Any}, @nospeci else origti = ti = t[i] # TODO remove this to implement callsite refinement of MustAlias - if isa(ti, MustAlias) && isa(widenconst(ti), Union) - ti = widenconst(ti) - end if isa(ti, Union) - for ty in uniontypes(ti::Union) + for ty in uniontypes(ti) + t[i] = ty + _switchtupleunion(𝕃, t, i - 1, tunion, origt) + end + t[i] = origti + elseif has_extended_unionsplit(𝕃) && !isa(ti, Const) && !isvarargtype(ti) && isa(widenconst(ti), Union) + for ty in uniontypes(ti) t[i] = ty - _switchtupleunion(t, i - 1, tunion, origt) + _switchtupleunion(𝕃, t, i - 1, tunion, origt) end t[i] = origti else - _switchtupleunion(t, i - 1, tunion, origt) + _switchtupleunion(𝕃, t, i - 1, tunion, origt) end end return tunion diff --git a/test/compiler/inference.jl b/test/compiler/inference.jl index e6be51542c205..8f8598c82bded 100644 --- a/test/compiler/inference.jl +++ b/test/compiler/inference.jl @@ -2944,11 +2944,11 @@ end # issue #28356 # unit test to make sure countunionsplit overflows gracefully # we don't care what number is returned as long as it's large -@test Core.Compiler.unionsplitcost(Any[Union{Int32, Int64} for i=1:80]) > 100000 -@test Core.Compiler.unionsplitcost(Any[Union{Int8, Int16, Int32, Int64}]) == 2 -@test Core.Compiler.unionsplitcost(Any[Union{Int8, Int16, Int32, Int64}, Union{Int8, Int16, Int32, Int64}, Int8]) == 8 -@test Core.Compiler.unionsplitcost(Any[Union{Int8, Int16, Int32, Int64}, Union{Int8, Int16, Int32}, Int8]) == 6 -@test Core.Compiler.unionsplitcost(Any[Union{Int8, Int16, Int32}, Union{Int8, Int16, Int32, Int64}, Int8]) == 6 +@test Core.Compiler.unionsplitcost(Core.Compiler.JLTypeLattice(), Any[Union{Int32, Int64} for i=1:80]) > 100000 +@test Core.Compiler.unionsplitcost(Core.Compiler.JLTypeLattice(), Any[Union{Int8, Int16, Int32, Int64}]) == 2 +@test Core.Compiler.unionsplitcost(Core.Compiler.JLTypeLattice(), Any[Union{Int8, Int16, Int32, Int64}, Union{Int8, Int16, Int32, Int64}, Int8]) == 8 +@test Core.Compiler.unionsplitcost(Core.Compiler.JLTypeLattice(), Any[Union{Int8, Int16, Int32, Int64}, Union{Int8, Int16, Int32}, Int8]) == 6 +@test Core.Compiler.unionsplitcost(Core.Compiler.JLTypeLattice(), Any[Union{Int8, Int16, Int32}, Union{Int8, Int16, Int32, Int64}, Int8]) == 6 # make sure compiler doesn't hang in union splitting @@ -3949,13 +3949,13 @@ end # argtypes let - tunion = Core.Compiler.switchtupleunion(Any[Union{Int32,Int64}, Core.Const(nothing)]) + tunion = Core.Compiler.switchtupleunion(Core.Compiler.ConstsLattice(), Any[Union{Int32,Int64}, Core.Const(nothing)]) @test length(tunion) == 2 @test Any[Int32, Core.Const(nothing)] in tunion @test Any[Int64, Core.Const(nothing)] in tunion end let - tunion = Core.Compiler.switchtupleunion(Any[Union{Int32,Int64}, Union{Float32,Float64}, Core.Const(nothing)]) + tunion = Core.Compiler.switchtupleunion(Core.Compiler.ConstsLattice(), Any[Union{Int32,Int64}, Union{Float32,Float64}, Core.Const(nothing)]) @test length(tunion) == 4 @test Any[Int32, Float32, Core.Const(nothing)] in tunion @test Any[Int32, Float64, Core.Const(nothing)] in tunion From 6afd0e135a3ca4739ea05d3f6d10c139fa0a4406 Mon Sep 17 00:00:00 2001 From: Keno Fischer Date: Fri, 17 Mar 2023 13:32:14 -0400 Subject: [PATCH 2/2] Apply suggestions from code review Co-authored-by: Shuhei Kadowaki <40514306+aviatesk@users.noreply.github.com> --- base/compiler/abstractinterpretation.jl | 4 ++-- base/compiler/tfuncs.jl | 2 +- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/base/compiler/abstractinterpretation.jl b/base/compiler/abstractinterpretation.jl index eb368cc1dbb59..1c1c88167d072 100644 --- a/base/compiler/abstractinterpretation.jl +++ b/base/compiler/abstractinterpretation.jl @@ -59,7 +59,7 @@ function abstract_call_gf_by_type(interp::AbstractInterpreter, @nospecialize(f), # as we may want to concrete-evaluate this frame in cases when there are # no overlayed calls, try an additional effort now to check if this call # isn't overlayed rather than just handling it conservatively - matches = find_matching_methods(ipo_lattice(interp), arginfo.argtypes, atype, method_table(interp), + matches = find_matching_methods(typeinf_lattice(interp), arginfo.argtypes, atype, method_table(interp), InferenceParams(interp).max_union_splitting, max_methods) if !isa(matches, FailedMethodMatch) nonoverlayed = matches.nonoverlayed @@ -75,7 +75,7 @@ function abstract_call_gf_by_type(interp::AbstractInterpreter, @nospecialize(f), end argtypes = arginfo.argtypes - matches = find_matching_methods(ipo_lattice(interp), argtypes, atype, method_table(interp), + matches = find_matching_methods(typeinf_lattice(interp), argtypes, atype, method_table(interp), InferenceParams(interp).max_union_splitting, max_methods) if isa(matches, FailedMethodMatch) add_remark!(interp, sv, matches.reason) diff --git a/base/compiler/tfuncs.jl b/base/compiler/tfuncs.jl index a99dd2d81b881..a89d9b89826b5 100644 --- a/base/compiler/tfuncs.jl +++ b/base/compiler/tfuncs.jl @@ -2542,7 +2542,7 @@ function abstract_applicable(interp::AbstractInterpreter, argtypes::Vector{Any}, isvarargtype(argtypes[2]) && return CallMeta(Bool, EFFECTS_UNKNOWN, NoCallInfo()) argtypes = argtypes[2:end] atype = argtypes_to_type(argtypes) - matches = find_matching_methods(ipo_lattice(interp), argtypes, atype, method_table(interp), + matches = find_matching_methods(typeinf_lattice(interp), argtypes, atype, method_table(interp), InferenceParams(interp).max_union_splitting, max_methods) if isa(matches, FailedMethodMatch) rt = Bool # too many matches to analyze