Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Allow external lattice elements to properly union split #49030

Merged
merged 2 commits into from Mar 18, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
13 changes: 7 additions & 6 deletions base/compiler/abstractinterpretation.jl
Expand Up @@ -59,7 +59,7 @@ function abstract_call_gf_by_type(interp::AbstractInterpreter, @nospecialize(f),
# as we may want to concrete-evaluate this frame in cases when there are
# no overlayed calls, try an additional effort now to check if this call
# isn't overlayed rather than just handling it conservatively
matches = find_matching_methods(arginfo.argtypes, atype, method_table(interp),
matches = find_matching_methods(typeinf_lattice(interp), arginfo.argtypes, atype, method_table(interp),
InferenceParams(interp).max_union_splitting, max_methods)
if !isa(matches, FailedMethodMatch)
nonoverlayed = matches.nonoverlayed
Expand All @@ -75,7 +75,7 @@ function abstract_call_gf_by_type(interp::AbstractInterpreter, @nospecialize(f),
end

argtypes = arginfo.argtypes
matches = find_matching_methods(argtypes, atype, method_table(interp),
matches = find_matching_methods(typeinf_lattice(interp), argtypes, atype, method_table(interp),
InferenceParams(interp).max_union_splitting, max_methods)
if isa(matches, FailedMethodMatch)
add_remark!(interp, sv, matches.reason)
Expand Down Expand Up @@ -273,11 +273,12 @@ struct UnionSplitMethodMatches
end
any_ambig(m::UnionSplitMethodMatches) = any(any_ambig, m.info.matches)

function find_matching_methods(argtypes::Vector{Any}, @nospecialize(atype), method_table::MethodTableView,
function find_matching_methods(𝕃::AbstractLattice,
argtypes::Vector{Any}, @nospecialize(atype), method_table::MethodTableView,
max_union_splitting::Int, max_methods::Int)
# NOTE this is valid as far as any "constant" lattice element doesn't represent `Union` type
if 1 < unionsplitcost(argtypes) <= max_union_splitting
split_argtypes = switchtupleunion(argtypes)
if 1 < unionsplitcost(𝕃, argtypes) <= max_union_splitting
split_argtypes = switchtupleunion(𝕃, argtypes)
infos = MethodMatchInfo[]
applicable = Any[]
applicable_argtypes = Vector{Any}[] # arrays like `argtypes`, including constants, for each match
Expand Down Expand Up @@ -1495,7 +1496,7 @@ function abstract_apply(interp::AbstractInterpreter, argtypes::Vector{Any}, si::
end
res = Union{}
nargs = length(aargtypes)
splitunions = 1 < unionsplitcost(aargtypes) <= InferenceParams(interp).max_apply_union_enum
splitunions = 1 < unionsplitcost(typeinf_lattice(interp), aargtypes) <= InferenceParams(interp).max_apply_union_enum
ctypes = [Any[aft]]
infos = Vector{MaybeAbstractIterationInfo}[MaybeAbstractIterationInfo[]]
effects = EFFECTS_TOTAL
Expand Down
4 changes: 4 additions & 0 deletions base/compiler/abstractlattice.jl
Expand Up @@ -293,6 +293,10 @@ has_mustalias(𝕃::AbstractLattice) = has_mustalias(widenlattice(𝕃))
has_mustalias(::AnyMustAliasesLattice) = true
has_mustalias(::JLTypeLattice) = false

has_extended_unionsplit(𝕃::AbstractLattice) = has_extended_unionsplit(widenlattice(𝕃))
has_extended_unionsplit(::AnyMustAliasesLattice) = true
has_extended_unionsplit(::JLTypeLattice) = false

# Curried versions
⊑(lattice::AbstractLattice) = (@nospecialize(a), @nospecialize(b)) -> ⊑(lattice, a, b)
⊏(lattice::AbstractLattice) = (@nospecialize(a), @nospecialize(b)) -> ⊏(lattice, a, b)
Expand Down
2 changes: 1 addition & 1 deletion base/compiler/tfuncs.jl
Expand Up @@ -2542,7 +2542,7 @@ function abstract_applicable(interp::AbstractInterpreter, argtypes::Vector{Any},
isvarargtype(argtypes[2]) && return CallMeta(Bool, EFFECTS_UNKNOWN, NoCallInfo())
argtypes = argtypes[2:end]
atype = argtypes_to_type(argtypes)
matches = find_matching_methods(argtypes, atype, method_table(interp),
matches = find_matching_methods(typeinf_lattice(interp), argtypes, atype, method_table(interp),
InferenceParams(interp).max_union_splitting, max_methods)
if isa(matches, FailedMethodMatch)
rt = Bool # too many matches to analyze
Expand Down
2 changes: 2 additions & 0 deletions base/compiler/typelattice.jl
Expand Up @@ -120,6 +120,8 @@ end
MustAlias(var::SlotNumber, @nospecialize(vartyp), fldidx::Int, @nospecialize(fldtyp)) =
MustAlias(slot_id(var), vartyp, fldidx, fldtyp)

_uniontypes(x::MustAlias, ts) = _uniontypes(widenconst(x), ts)

"""
alias::InterMustAlias

Expand Down
28 changes: 15 additions & 13 deletions base/compiler/typeutils.jl
Expand Up @@ -165,7 +165,7 @@ function typesubtract(@nospecialize(a), @nospecialize(b), max_union_splitting::I
if ub isa DataType
if a.name === ub.name === Tuple.name &&
length(a.parameters) == length(ub.parameters)
if 1 < unionsplitcost(a.parameters) <= max_union_splitting
if 1 < unionsplitcost(JLTypeLattice(), a.parameters) <= max_union_splitting
ta = switchtupleunion(a)
return typesubtract(Union{ta...}, b, 0)
elseif b isa DataType
Expand Down Expand Up @@ -227,12 +227,11 @@ end
# or outside of the Tuple/Union nesting, though somewhat more expensive to be
# outside than inside because the representation is larger (because and it
# informs the callee whether any splitting is possible).
function unionsplitcost(argtypes::Union{SimpleVector,Vector{Any}})
function unionsplitcost(𝕃::AbstractLattice, argtypes::Union{SimpleVector,Vector{Any}})
nu = 1
max = 2
for ti in argtypes
# TODO remove this to implement callsite refinement of MustAlias
if isa(ti, MustAlias) && isa(widenconst(ti), Union)
if has_extended_unionsplit(𝕃) && !isvarargtype(ti)
ti = widenconst(ti)
Copy link
Sponsor Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This widenconst would be wasteful in the native compilation pipeline since we don't enable MustAlias by default. Maybe we want to add a hook has_extended_unionsplit(𝕃::AbstractLattice) to configure this widening?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

sure, seems reasonable

end
if isa(ti, Union)
Expand All @@ -252,12 +251,12 @@ end
# and `Union{return...} == ty`
function switchtupleunion(@nospecialize(ty))
tparams = (unwrap_unionall(ty)::DataType).parameters
return _switchtupleunion(Any[tparams...], length(tparams), [], ty)
return _switchtupleunion(JLTypeLattice(), Any[tparams...], length(tparams), [], ty)
end

switchtupleunion(argtypes::Vector{Any}) = _switchtupleunion(argtypes, length(argtypes), [], nothing)
switchtupleunion(𝕃::AbstractLattice, argtypes::Vector{Any}) = _switchtupleunion(𝕃, argtypes, length(argtypes), [], nothing)

function _switchtupleunion(t::Vector{Any}, i::Int, tunion::Vector{Any}, @nospecialize(origt))
function _switchtupleunion(𝕃::AbstractLattice, t::Vector{Any}, i::Int, tunion::Vector{Any}, @nospecialize(origt))
if i == 0
if origt === nothing
push!(tunion, copy(t))
Expand All @@ -268,17 +267,20 @@ function _switchtupleunion(t::Vector{Any}, i::Int, tunion::Vector{Any}, @nospeci
else
origti = ti = t[i]
# TODO remove this to implement callsite refinement of MustAlias
if isa(ti, MustAlias) && isa(widenconst(ti), Union)
ti = widenconst(ti)
end
if isa(ti, Union)
for ty in uniontypes(ti::Union)
for ty in uniontypes(ti)
t[i] = ty
_switchtupleunion(𝕃, t, i - 1, tunion, origt)
end
t[i] = origti
elseif has_extended_unionsplit(𝕃) && !isa(ti, Const) && !isvarargtype(ti) && isa(widenconst(ti), Union)
for ty in uniontypes(ti)
t[i] = ty
_switchtupleunion(t, i - 1, tunion, origt)
_switchtupleunion(𝕃, t, i - 1, tunion, origt)
end
t[i] = origti
else
_switchtupleunion(t, i - 1, tunion, origt)
_switchtupleunion(𝕃, t, i - 1, tunion, origt)
end
end
return tunion
Expand Down
14 changes: 7 additions & 7 deletions test/compiler/inference.jl
Expand Up @@ -2944,11 +2944,11 @@ end
# issue #28356
# unit test to make sure countunionsplit overflows gracefully
# we don't care what number is returned as long as it's large
@test Core.Compiler.unionsplitcost(Any[Union{Int32, Int64} for i=1:80]) > 100000
@test Core.Compiler.unionsplitcost(Any[Union{Int8, Int16, Int32, Int64}]) == 2
@test Core.Compiler.unionsplitcost(Any[Union{Int8, Int16, Int32, Int64}, Union{Int8, Int16, Int32, Int64}, Int8]) == 8
@test Core.Compiler.unionsplitcost(Any[Union{Int8, Int16, Int32, Int64}, Union{Int8, Int16, Int32}, Int8]) == 6
@test Core.Compiler.unionsplitcost(Any[Union{Int8, Int16, Int32}, Union{Int8, Int16, Int32, Int64}, Int8]) == 6
@test Core.Compiler.unionsplitcost(Core.Compiler.JLTypeLattice(), Any[Union{Int32, Int64} for i=1:80]) > 100000
@test Core.Compiler.unionsplitcost(Core.Compiler.JLTypeLattice(), Any[Union{Int8, Int16, Int32, Int64}]) == 2
@test Core.Compiler.unionsplitcost(Core.Compiler.JLTypeLattice(), Any[Union{Int8, Int16, Int32, Int64}, Union{Int8, Int16, Int32, Int64}, Int8]) == 8
@test Core.Compiler.unionsplitcost(Core.Compiler.JLTypeLattice(), Any[Union{Int8, Int16, Int32, Int64}, Union{Int8, Int16, Int32}, Int8]) == 6
@test Core.Compiler.unionsplitcost(Core.Compiler.JLTypeLattice(), Any[Union{Int8, Int16, Int32}, Union{Int8, Int16, Int32, Int64}, Int8]) == 6

# make sure compiler doesn't hang in union splitting

Expand Down Expand Up @@ -3949,13 +3949,13 @@ end

# argtypes
let
tunion = Core.Compiler.switchtupleunion(Any[Union{Int32,Int64}, Core.Const(nothing)])
tunion = Core.Compiler.switchtupleunion(Core.Compiler.ConstsLattice(), Any[Union{Int32,Int64}, Core.Const(nothing)])
@test length(tunion) == 2
@test Any[Int32, Core.Const(nothing)] in tunion
@test Any[Int64, Core.Const(nothing)] in tunion
end
let
tunion = Core.Compiler.switchtupleunion(Any[Union{Int32,Int64}, Union{Float32,Float64}, Core.Const(nothing)])
tunion = Core.Compiler.switchtupleunion(Core.Compiler.ConstsLattice(), Any[Union{Int32,Int64}, Union{Float32,Float64}, Core.Const(nothing)])
@test length(tunion) == 4
@test Any[Int32, Float32, Core.Const(nothing)] in tunion
@test Any[Int32, Float64, Core.Const(nothing)] in tunion
Expand Down