diff --git a/base/compiler/abstractinterpretation.jl b/base/compiler/abstractinterpretation.jl index dad6afbdf6c73..2087ad96f27ce 100644 --- a/base/compiler/abstractinterpretation.jl +++ b/base/compiler/abstractinterpretation.jl @@ -59,7 +59,7 @@ function abstract_call_gf_by_type(interp::AbstractInterpreter, @nospecialize(f), # as we may want to concrete-evaluate this frame in cases when there are # no overlayed calls, try an additional effort now to check if this call # isn't overlayed rather than just handling it conservatively - matches = find_matching_methods(arginfo.argtypes, atype, method_table(interp), + matches = find_matching_methods(typeinf_lattice(interp), arginfo.argtypes, atype, method_table(interp), InferenceParams(interp).max_union_splitting, max_methods) if !isa(matches, FailedMethodMatch) nonoverlayed = matches.nonoverlayed @@ -75,7 +75,7 @@ function abstract_call_gf_by_type(interp::AbstractInterpreter, @nospecialize(f), end argtypes = arginfo.argtypes - matches = find_matching_methods(argtypes, atype, method_table(interp), + matches = find_matching_methods(typeinf_lattice(interp), argtypes, atype, method_table(interp), InferenceParams(interp).max_union_splitting, max_methods) if isa(matches, FailedMethodMatch) add_remark!(interp, sv, matches.reason) @@ -273,11 +273,12 @@ struct UnionSplitMethodMatches end any_ambig(m::UnionSplitMethodMatches) = any(any_ambig, m.info.matches) -function find_matching_methods(argtypes::Vector{Any}, @nospecialize(atype), method_table::MethodTableView, +function find_matching_methods(𝕃::AbstractLattice, + argtypes::Vector{Any}, @nospecialize(atype), method_table::MethodTableView, max_union_splitting::Int, max_methods::Int) # NOTE this is valid as far as any "constant" lattice element doesn't represent `Union` type - if 1 < unionsplitcost(argtypes) <= max_union_splitting - split_argtypes = switchtupleunion(argtypes) + if 1 < unionsplitcost(𝕃, argtypes) <= max_union_splitting + split_argtypes = switchtupleunion(𝕃, argtypes) infos = MethodMatchInfo[] applicable = Any[] applicable_argtypes = Vector{Any}[] # arrays like `argtypes`, including constants, for each match @@ -1496,7 +1497,7 @@ function abstract_apply(interp::AbstractInterpreter, argtypes::Vector{Any}, si:: end res = Union{} nargs = length(aargtypes) - splitunions = 1 < unionsplitcost(aargtypes) <= InferenceParams(interp).max_apply_union_enum + splitunions = 1 < unionsplitcost(typeinf_lattice(interp), aargtypes) <= InferenceParams(interp).max_apply_union_enum ctypes = [Any[aft]] infos = Vector{MaybeAbstractIterationInfo}[MaybeAbstractIterationInfo[]] effects = EFFECTS_TOTAL diff --git a/base/compiler/abstractlattice.jl b/base/compiler/abstractlattice.jl index f578ec8d6f60d..a84050816cb21 100644 --- a/base/compiler/abstractlattice.jl +++ b/base/compiler/abstractlattice.jl @@ -293,6 +293,10 @@ has_mustalias(𝕃::AbstractLattice) = has_mustalias(widenlattice(𝕃)) has_mustalias(::AnyMustAliasesLattice) = true has_mustalias(::JLTypeLattice) = false +has_extended_unionsplit(𝕃::AbstractLattice) = has_extended_unionsplit(widenlattice(𝕃)) +has_extended_unionsplit(::AnyMustAliasesLattice) = true +has_extended_unionsplit(::JLTypeLattice) = false + # Curried versions ⊑(lattice::AbstractLattice) = (@nospecialize(a), @nospecialize(b)) -> ⊑(lattice, a, b) ⊏(lattice::AbstractLattice) = (@nospecialize(a), @nospecialize(b)) -> ⊏(lattice, a, b) diff --git a/base/compiler/tfuncs.jl b/base/compiler/tfuncs.jl index 41da17c19d6d2..a89d9b89826b5 100644 --- a/base/compiler/tfuncs.jl +++ b/base/compiler/tfuncs.jl @@ -2542,7 +2542,7 @@ function abstract_applicable(interp::AbstractInterpreter, argtypes::Vector{Any}, isvarargtype(argtypes[2]) && return CallMeta(Bool, EFFECTS_UNKNOWN, NoCallInfo()) argtypes = argtypes[2:end] atype = argtypes_to_type(argtypes) - matches = find_matching_methods(argtypes, atype, method_table(interp), + matches = find_matching_methods(typeinf_lattice(interp), argtypes, atype, method_table(interp), InferenceParams(interp).max_union_splitting, max_methods) if isa(matches, FailedMethodMatch) rt = Bool # too many matches to analyze diff --git a/base/compiler/typelattice.jl b/base/compiler/typelattice.jl index 33d4d37e9c936..23f39d8b44f5e 100644 --- a/base/compiler/typelattice.jl +++ b/base/compiler/typelattice.jl @@ -120,6 +120,8 @@ end MustAlias(var::SlotNumber, @nospecialize(vartyp), fldidx::Int, @nospecialize(fldtyp)) = MustAlias(slot_id(var), vartyp, fldidx, fldtyp) +_uniontypes(x::MustAlias, ts) = _uniontypes(widenconst(x), ts) + """ alias::InterMustAlias diff --git a/base/compiler/typeutils.jl b/base/compiler/typeutils.jl index c94bc0ca2aa75..293ef5797888b 100644 --- a/base/compiler/typeutils.jl +++ b/base/compiler/typeutils.jl @@ -165,7 +165,7 @@ function typesubtract(@nospecialize(a), @nospecialize(b), max_union_splitting::I if ub isa DataType if a.name === ub.name === Tuple.name && length(a.parameters) == length(ub.parameters) - if 1 < unionsplitcost(a.parameters) <= max_union_splitting + if 1 < unionsplitcost(JLTypeLattice(), a.parameters) <= max_union_splitting ta = switchtupleunion(a) return typesubtract(Union{ta...}, b, 0) elseif b isa DataType @@ -227,12 +227,11 @@ end # or outside of the Tuple/Union nesting, though somewhat more expensive to be # outside than inside because the representation is larger (because and it # informs the callee whether any splitting is possible). -function unionsplitcost(argtypes::Union{SimpleVector,Vector{Any}}) +function unionsplitcost(𝕃::AbstractLattice, argtypes::Union{SimpleVector,Vector{Any}}) nu = 1 max = 2 for ti in argtypes - # TODO remove this to implement callsite refinement of MustAlias - if isa(ti, MustAlias) && isa(widenconst(ti), Union) + if has_extended_unionsplit(𝕃) && !isvarargtype(ti) ti = widenconst(ti) end if isa(ti, Union) @@ -252,12 +251,12 @@ end # and `Union{return...} == ty` function switchtupleunion(@nospecialize(ty)) tparams = (unwrap_unionall(ty)::DataType).parameters - return _switchtupleunion(Any[tparams...], length(tparams), [], ty) + return _switchtupleunion(JLTypeLattice(), Any[tparams...], length(tparams), [], ty) end -switchtupleunion(argtypes::Vector{Any}) = _switchtupleunion(argtypes, length(argtypes), [], nothing) +switchtupleunion(𝕃::AbstractLattice, argtypes::Vector{Any}) = _switchtupleunion(𝕃, argtypes, length(argtypes), [], nothing) -function _switchtupleunion(t::Vector{Any}, i::Int, tunion::Vector{Any}, @nospecialize(origt)) +function _switchtupleunion(𝕃::AbstractLattice, t::Vector{Any}, i::Int, tunion::Vector{Any}, @nospecialize(origt)) if i == 0 if origt === nothing push!(tunion, copy(t)) @@ -268,17 +267,20 @@ function _switchtupleunion(t::Vector{Any}, i::Int, tunion::Vector{Any}, @nospeci else origti = ti = t[i] # TODO remove this to implement callsite refinement of MustAlias - if isa(ti, MustAlias) && isa(widenconst(ti), Union) - ti = widenconst(ti) - end if isa(ti, Union) - for ty in uniontypes(ti::Union) + for ty in uniontypes(ti) + t[i] = ty + _switchtupleunion(𝕃, t, i - 1, tunion, origt) + end + t[i] = origti + elseif has_extended_unionsplit(𝕃) && !isa(ti, Const) && !isvarargtype(ti) && isa(widenconst(ti), Union) + for ty in uniontypes(ti) t[i] = ty - _switchtupleunion(t, i - 1, tunion, origt) + _switchtupleunion(𝕃, t, i - 1, tunion, origt) end t[i] = origti else - _switchtupleunion(t, i - 1, tunion, origt) + _switchtupleunion(𝕃, t, i - 1, tunion, origt) end end return tunion diff --git a/test/compiler/inference.jl b/test/compiler/inference.jl index e6be51542c205..8f8598c82bded 100644 --- a/test/compiler/inference.jl +++ b/test/compiler/inference.jl @@ -2944,11 +2944,11 @@ end # issue #28356 # unit test to make sure countunionsplit overflows gracefully # we don't care what number is returned as long as it's large -@test Core.Compiler.unionsplitcost(Any[Union{Int32, Int64} for i=1:80]) > 100000 -@test Core.Compiler.unionsplitcost(Any[Union{Int8, Int16, Int32, Int64}]) == 2 -@test Core.Compiler.unionsplitcost(Any[Union{Int8, Int16, Int32, Int64}, Union{Int8, Int16, Int32, Int64}, Int8]) == 8 -@test Core.Compiler.unionsplitcost(Any[Union{Int8, Int16, Int32, Int64}, Union{Int8, Int16, Int32}, Int8]) == 6 -@test Core.Compiler.unionsplitcost(Any[Union{Int8, Int16, Int32}, Union{Int8, Int16, Int32, Int64}, Int8]) == 6 +@test Core.Compiler.unionsplitcost(Core.Compiler.JLTypeLattice(), Any[Union{Int32, Int64} for i=1:80]) > 100000 +@test Core.Compiler.unionsplitcost(Core.Compiler.JLTypeLattice(), Any[Union{Int8, Int16, Int32, Int64}]) == 2 +@test Core.Compiler.unionsplitcost(Core.Compiler.JLTypeLattice(), Any[Union{Int8, Int16, Int32, Int64}, Union{Int8, Int16, Int32, Int64}, Int8]) == 8 +@test Core.Compiler.unionsplitcost(Core.Compiler.JLTypeLattice(), Any[Union{Int8, Int16, Int32, Int64}, Union{Int8, Int16, Int32}, Int8]) == 6 +@test Core.Compiler.unionsplitcost(Core.Compiler.JLTypeLattice(), Any[Union{Int8, Int16, Int32}, Union{Int8, Int16, Int32, Int64}, Int8]) == 6 # make sure compiler doesn't hang in union splitting @@ -3949,13 +3949,13 @@ end # argtypes let - tunion = Core.Compiler.switchtupleunion(Any[Union{Int32,Int64}, Core.Const(nothing)]) + tunion = Core.Compiler.switchtupleunion(Core.Compiler.ConstsLattice(), Any[Union{Int32,Int64}, Core.Const(nothing)]) @test length(tunion) == 2 @test Any[Int32, Core.Const(nothing)] in tunion @test Any[Int64, Core.Const(nothing)] in tunion end let - tunion = Core.Compiler.switchtupleunion(Any[Union{Int32,Int64}, Union{Float32,Float64}, Core.Const(nothing)]) + tunion = Core.Compiler.switchtupleunion(Core.Compiler.ConstsLattice(), Any[Union{Int32,Int64}, Union{Float32,Float64}, Core.Const(nothing)]) @test length(tunion) == 4 @test Any[Int32, Float32, Core.Const(nothing)] in tunion @test Any[Int32, Float64, Core.Const(nothing)] in tunion