Skip to content

Commit

Permalink
Allow external lattice elements to properly union split (JuliaLang#49030
Browse files Browse the repository at this point in the history
)

Currently `MustAlias` is the only lattice element that is allowed
to widen to union types. However, there are others in external
packages. Expand the support we have for this in order to allow
union splitting of lattice elements.

Co-authored-by: Shuhei Kadowaki <40514306+aviatesk@users.noreply.github.com>
  • Loading branch information
2 people authored and Xnartharax committed Apr 13, 2023
1 parent a4af449 commit a078649
Show file tree
Hide file tree
Showing 6 changed files with 36 additions and 27 deletions.
13 changes: 7 additions & 6 deletions base/compiler/abstractinterpretation.jl
Expand Up @@ -59,7 +59,7 @@ function abstract_call_gf_by_type(interp::AbstractInterpreter, @nospecialize(f),
# as we may want to concrete-evaluate this frame in cases when there are
# no overlayed calls, try an additional effort now to check if this call
# isn't overlayed rather than just handling it conservatively
matches = find_matching_methods(arginfo.argtypes, atype, method_table(interp),
matches = find_matching_methods(typeinf_lattice(interp), arginfo.argtypes, atype, method_table(interp),
InferenceParams(interp).max_union_splitting, max_methods)
if !isa(matches, FailedMethodMatch)
nonoverlayed = matches.nonoverlayed
Expand All @@ -75,7 +75,7 @@ function abstract_call_gf_by_type(interp::AbstractInterpreter, @nospecialize(f),
end

argtypes = arginfo.argtypes
matches = find_matching_methods(argtypes, atype, method_table(interp),
matches = find_matching_methods(typeinf_lattice(interp), argtypes, atype, method_table(interp),
InferenceParams(interp).max_union_splitting, max_methods)
if isa(matches, FailedMethodMatch)
add_remark!(interp, sv, matches.reason)
Expand Down Expand Up @@ -273,11 +273,12 @@ struct UnionSplitMethodMatches
end
any_ambig(m::UnionSplitMethodMatches) = any(any_ambig, m.info.matches)

function find_matching_methods(argtypes::Vector{Any}, @nospecialize(atype), method_table::MethodTableView,
function find_matching_methods(𝕃::AbstractLattice,
argtypes::Vector{Any}, @nospecialize(atype), method_table::MethodTableView,
max_union_splitting::Int, max_methods::Int)
# NOTE this is valid as far as any "constant" lattice element doesn't represent `Union` type
if 1 < unionsplitcost(argtypes) <= max_union_splitting
split_argtypes = switchtupleunion(argtypes)
if 1 < unionsplitcost(𝕃, argtypes) <= max_union_splitting
split_argtypes = switchtupleunion(𝕃, argtypes)
infos = MethodMatchInfo[]
applicable = Any[]
applicable_argtypes = Vector{Any}[] # arrays like `argtypes`, including constants, for each match
Expand Down Expand Up @@ -1496,7 +1497,7 @@ function abstract_apply(interp::AbstractInterpreter, argtypes::Vector{Any}, si::
end
res = Union{}
nargs = length(aargtypes)
splitunions = 1 < unionsplitcost(aargtypes) <= InferenceParams(interp).max_apply_union_enum
splitunions = 1 < unionsplitcost(typeinf_lattice(interp), aargtypes) <= InferenceParams(interp).max_apply_union_enum
ctypes = [Any[aft]]
infos = Vector{MaybeAbstractIterationInfo}[MaybeAbstractIterationInfo[]]
effects = EFFECTS_TOTAL
Expand Down
4 changes: 4 additions & 0 deletions base/compiler/abstractlattice.jl
Expand Up @@ -293,6 +293,10 @@ has_mustalias(𝕃::AbstractLattice) = has_mustalias(widenlattice(𝕃))
has_mustalias(::AnyMustAliasesLattice) = true
has_mustalias(::JLTypeLattice) = false

has_extended_unionsplit(𝕃::AbstractLattice) = has_extended_unionsplit(widenlattice(𝕃))
has_extended_unionsplit(::AnyMustAliasesLattice) = true
has_extended_unionsplit(::JLTypeLattice) = false

# Curried versions
(lattice::AbstractLattice) = (@nospecialize(a), @nospecialize(b)) -> (lattice, a, b)
(lattice::AbstractLattice) = (@nospecialize(a), @nospecialize(b)) -> (lattice, a, b)
Expand Down
2 changes: 1 addition & 1 deletion base/compiler/tfuncs.jl
Expand Up @@ -2542,7 +2542,7 @@ function abstract_applicable(interp::AbstractInterpreter, argtypes::Vector{Any},
isvarargtype(argtypes[2]) && return CallMeta(Bool, EFFECTS_UNKNOWN, NoCallInfo())
argtypes = argtypes[2:end]
atype = argtypes_to_type(argtypes)
matches = find_matching_methods(argtypes, atype, method_table(interp),
matches = find_matching_methods(typeinf_lattice(interp), argtypes, atype, method_table(interp),
InferenceParams(interp).max_union_splitting, max_methods)
if isa(matches, FailedMethodMatch)
rt = Bool # too many matches to analyze
Expand Down
2 changes: 2 additions & 0 deletions base/compiler/typelattice.jl
Expand Up @@ -120,6 +120,8 @@ end
MustAlias(var::SlotNumber, @nospecialize(vartyp), fldidx::Int, @nospecialize(fldtyp)) =
MustAlias(slot_id(var), vartyp, fldidx, fldtyp)

_uniontypes(x::MustAlias, ts) = _uniontypes(widenconst(x), ts)

"""
alias::InterMustAlias
Expand Down
28 changes: 15 additions & 13 deletions base/compiler/typeutils.jl
Expand Up @@ -165,7 +165,7 @@ function typesubtract(@nospecialize(a), @nospecialize(b), max_union_splitting::I
if ub isa DataType
if a.name === ub.name === Tuple.name &&
length(a.parameters) == length(ub.parameters)
if 1 < unionsplitcost(a.parameters) <= max_union_splitting
if 1 < unionsplitcost(JLTypeLattice(), a.parameters) <= max_union_splitting
ta = switchtupleunion(a)
return typesubtract(Union{ta...}, b, 0)
elseif b isa DataType
Expand Down Expand Up @@ -227,12 +227,11 @@ end
# or outside of the Tuple/Union nesting, though somewhat more expensive to be
# outside than inside because the representation is larger (because and it
# informs the callee whether any splitting is possible).
function unionsplitcost(argtypes::Union{SimpleVector,Vector{Any}})
function unionsplitcost(𝕃::AbstractLattice, argtypes::Union{SimpleVector,Vector{Any}})
nu = 1
max = 2
for ti in argtypes
# TODO remove this to implement callsite refinement of MustAlias
if isa(ti, MustAlias) && isa(widenconst(ti), Union)
if has_extended_unionsplit(𝕃) && !isvarargtype(ti)
ti = widenconst(ti)
end
if isa(ti, Union)
Expand All @@ -252,12 +251,12 @@ end
# and `Union{return...} == ty`
function switchtupleunion(@nospecialize(ty))
tparams = (unwrap_unionall(ty)::DataType).parameters
return _switchtupleunion(Any[tparams...], length(tparams), [], ty)
return _switchtupleunion(JLTypeLattice(), Any[tparams...], length(tparams), [], ty)
end

switchtupleunion(argtypes::Vector{Any}) = _switchtupleunion(argtypes, length(argtypes), [], nothing)
switchtupleunion(𝕃::AbstractLattice, argtypes::Vector{Any}) = _switchtupleunion(𝕃, argtypes, length(argtypes), [], nothing)

function _switchtupleunion(t::Vector{Any}, i::Int, tunion::Vector{Any}, @nospecialize(origt))
function _switchtupleunion(𝕃::AbstractLattice, t::Vector{Any}, i::Int, tunion::Vector{Any}, @nospecialize(origt))
if i == 0
if origt === nothing
push!(tunion, copy(t))
Expand All @@ -268,17 +267,20 @@ function _switchtupleunion(t::Vector{Any}, i::Int, tunion::Vector{Any}, @nospeci
else
origti = ti = t[i]
# TODO remove this to implement callsite refinement of MustAlias
if isa(ti, MustAlias) && isa(widenconst(ti), Union)
ti = widenconst(ti)
end
if isa(ti, Union)
for ty in uniontypes(ti::Union)
for ty in uniontypes(ti)
t[i] = ty
_switchtupleunion(𝕃, t, i - 1, tunion, origt)
end
t[i] = origti
elseif has_extended_unionsplit(𝕃) && !isa(ti, Const) && !isvarargtype(ti) && isa(widenconst(ti), Union)
for ty in uniontypes(ti)
t[i] = ty
_switchtupleunion(t, i - 1, tunion, origt)
_switchtupleunion(𝕃, t, i - 1, tunion, origt)
end
t[i] = origti
else
_switchtupleunion(t, i - 1, tunion, origt)
_switchtupleunion(𝕃, t, i - 1, tunion, origt)
end
end
return tunion
Expand Down
14 changes: 7 additions & 7 deletions test/compiler/inference.jl
Expand Up @@ -2944,11 +2944,11 @@ end
# issue #28356
# unit test to make sure countunionsplit overflows gracefully
# we don't care what number is returned as long as it's large
@test Core.Compiler.unionsplitcost(Any[Union{Int32, Int64} for i=1:80]) > 100000
@test Core.Compiler.unionsplitcost(Any[Union{Int8, Int16, Int32, Int64}]) == 2
@test Core.Compiler.unionsplitcost(Any[Union{Int8, Int16, Int32, Int64}, Union{Int8, Int16, Int32, Int64}, Int8]) == 8
@test Core.Compiler.unionsplitcost(Any[Union{Int8, Int16, Int32, Int64}, Union{Int8, Int16, Int32}, Int8]) == 6
@test Core.Compiler.unionsplitcost(Any[Union{Int8, Int16, Int32}, Union{Int8, Int16, Int32, Int64}, Int8]) == 6
@test Core.Compiler.unionsplitcost(Core.Compiler.JLTypeLattice(), Any[Union{Int32, Int64} for i=1:80]) > 100000
@test Core.Compiler.unionsplitcost(Core.Compiler.JLTypeLattice(), Any[Union{Int8, Int16, Int32, Int64}]) == 2
@test Core.Compiler.unionsplitcost(Core.Compiler.JLTypeLattice(), Any[Union{Int8, Int16, Int32, Int64}, Union{Int8, Int16, Int32, Int64}, Int8]) == 8
@test Core.Compiler.unionsplitcost(Core.Compiler.JLTypeLattice(), Any[Union{Int8, Int16, Int32, Int64}, Union{Int8, Int16, Int32}, Int8]) == 6
@test Core.Compiler.unionsplitcost(Core.Compiler.JLTypeLattice(), Any[Union{Int8, Int16, Int32}, Union{Int8, Int16, Int32, Int64}, Int8]) == 6

# make sure compiler doesn't hang in union splitting

Expand Down Expand Up @@ -3949,13 +3949,13 @@ end

# argtypes
let
tunion = Core.Compiler.switchtupleunion(Any[Union{Int32,Int64}, Core.Const(nothing)])
tunion = Core.Compiler.switchtupleunion(Core.Compiler.ConstsLattice(), Any[Union{Int32,Int64}, Core.Const(nothing)])
@test length(tunion) == 2
@test Any[Int32, Core.Const(nothing)] in tunion
@test Any[Int64, Core.Const(nothing)] in tunion
end
let
tunion = Core.Compiler.switchtupleunion(Any[Union{Int32,Int64}, Union{Float32,Float64}, Core.Const(nothing)])
tunion = Core.Compiler.switchtupleunion(Core.Compiler.ConstsLattice(), Any[Union{Int32,Int64}, Union{Float32,Float64}, Core.Const(nothing)])
@test length(tunion) == 4
@test Any[Int32, Float32, Core.Const(nothing)] in tunion
@test Any[Int32, Float64, Core.Const(nothing)] in tunion
Expand Down

0 comments on commit a078649

Please sign in to comment.