Skip to content

Commit

Permalink
minor refactoring on find_method_matches (#53741)
Browse files Browse the repository at this point in the history
So that it can be tested in isolation easier.
  • Loading branch information
aviatesk committed Mar 19, 2024
1 parent 8e67f99 commit 8f76c69
Show file tree
Hide file tree
Showing 2 changed files with 71 additions and 67 deletions.
135 changes: 70 additions & 65 deletions base/compiler/abstractinterpretation.jl
Original file line number Diff line number Diff line change
Expand Up @@ -22,8 +22,7 @@ function abstract_call_gf_by_type(interp::AbstractInterpreter, @nospecialize(f),
end

argtypes = arginfo.argtypes
matches = find_matching_methods(𝕃ᵢ, argtypes, atype, method_table(interp),
InferenceParams(interp).max_union_splitting, max_methods)
matches = find_method_matches(interp, argtypes, atype; max_methods)
if isa(matches, FailedMethodMatch)
add_remark!(interp, sv, matches.reason)
return CallMeta(Any, Any, Effects(), NoCallInfo())
Expand Down Expand Up @@ -255,73 +254,79 @@ struct UnionSplitMethodMatches
end
any_ambig(m::UnionSplitMethodMatches) = any(any_ambig, m.info.matches)

function find_matching_methods(𝕃::AbstractLattice,
argtypes::Vector{Any}, @nospecialize(atype), method_table::MethodTableView,
max_union_splitting::Int, max_methods::Int)
# NOTE this is valid as far as any "constant" lattice element doesn't represent `Union` type
if 1 < unionsplitcost(𝕃, argtypes) <= max_union_splitting
split_argtypes = switchtupleunion(𝕃, argtypes)
infos = MethodMatchInfo[]
applicable = Any[]
applicable_argtypes = Vector{Any}[] # arrays like `argtypes`, including constants, for each match
valid_worlds = WorldRange()
mts = MethodTable[]
fullmatches = Bool[]
for i in 1:length(split_argtypes)
arg_n = split_argtypes[i]::Vector{Any}
sig_n = argtypes_to_type(arg_n)
mt = ccall(:jl_method_table_for, Any, (Any,), sig_n)
mt === nothing && return FailedMethodMatch("Could not identify method table for call")
mt = mt::MethodTable
matches = findall(sig_n, method_table; limit = max_methods)
if matches === nothing
return FailedMethodMatch("For one of the union split cases, too many methods matched")
end
push!(infos, MethodMatchInfo(matches))
for m in matches
push!(applicable, m)
push!(applicable_argtypes, arg_n)
end
valid_worlds = intersect(valid_worlds, matches.valid_worlds)
thisfullmatch = any(match::MethodMatch->match.fully_covers, matches)
found = false
for (i, mt′) in enumerate(mts)
if mt′ === mt
fullmatches[i] &= thisfullmatch
found = true
break
end
end
if !found
push!(mts, mt)
push!(fullmatches, thisfullmatch)
end
end
return UnionSplitMethodMatches(applicable,
applicable_argtypes,
UnionSplitInfo(infos),
valid_worlds,
mts,
fullmatches)
else
mt = ccall(:jl_method_table_for, Any, (Any,), atype)
if mt === nothing
return FailedMethodMatch("Could not identify method table for call")
end
function find_method_matches(interp::AbstractInterpreter, argtypes::Vector{Any}, @nospecialize(atype);
max_union_splitting::Int = InferenceParams(interp).max_union_splitting,
max_methods::Int = InferenceParams(interp).max_methods)
if is_union_split_eligible(typeinf_lattice(interp), argtypes, max_union_splitting)
return find_union_split_method_matches(interp, argtypes, atype, max_methods)
end
return find_simple_method_matches(interp, atype, max_methods)
end

# NOTE this is valid as far as any "constant" lattice element doesn't represent `Union` type
is_union_split_eligible(𝕃::AbstractLattice, argtypes::Vector{Any}, max_union_splitting::Int) =
1 < unionsplitcost(𝕃, argtypes) <= max_union_splitting

function find_union_split_method_matches(interp::AbstractInterpreter, argtypes::Vector{Any},
@nospecialize(atype), max_methods::Int)
split_argtypes = switchtupleunion(typeinf_lattice(interp), argtypes)
infos = MethodMatchInfo[]
applicable = Any[]
applicable_argtypes = Vector{Any}[] # arrays like `argtypes`, including constants, for each match
valid_worlds = WorldRange()
mts = MethodTable[]
fullmatches = Bool[]
for i in 1:length(split_argtypes)
arg_n = split_argtypes[i]::Vector{Any}
sig_n = argtypes_to_type(arg_n)
mt = ccall(:jl_method_table_for, Any, (Any,), sig_n)
mt === nothing && return FailedMethodMatch("Could not identify method table for call")
mt = mt::MethodTable
matches = findall(atype, method_table; limit = max_methods)
matches = findall(sig_n, method_table(interp); limit = max_methods)
if matches === nothing
# this means too many methods matched
# (assume this will always be true, so we don't compute / update valid age in this case)
return FailedMethodMatch("Too many methods matched")
return FailedMethodMatch("For one of the union split cases, too many methods matched")
end
push!(infos, MethodMatchInfo(matches))
for m in matches
push!(applicable, m)
push!(applicable_argtypes, arg_n)
end
valid_worlds = intersect(valid_worlds, matches.valid_worlds)
thisfullmatch = any(match::MethodMatch->match.fully_covers, matches)
found = false
for (i, mt′) in enumerate(mts)
if mt′ === mt
fullmatches[i] &= thisfullmatch
found = true
break
end
end
fullmatch = any(match::MethodMatch->match.fully_covers, matches)
return MethodMatches(matches.matches,
MethodMatchInfo(matches),
matches.valid_worlds,
mt,
fullmatch)
if !found
push!(mts, mt)
push!(fullmatches, thisfullmatch)
end
end
info = UnionSplitInfo(infos)
return UnionSplitMethodMatches(
applicable, applicable_argtypes, info, valid_worlds, mts, fullmatches)
end

function find_simple_method_matches(interp::AbstractInterpreter, @nospecialize(atype), max_methods::Int)
mt = ccall(:jl_method_table_for, Any, (Any,), atype)
if mt === nothing
return FailedMethodMatch("Could not identify method table for call")
end
mt = mt::MethodTable
matches = findall(atype, method_table(interp); limit = max_methods)
if matches === nothing
# this means too many methods matched
# (assume this will always be true, so we don't compute / update valid age in this case)
return FailedMethodMatch("Too many methods matched")
end
info = MethodMatchInfo(matches)
fullmatch = any(match::MethodMatch->match.fully_covers, matches)
return MethodMatches(
matches.matches, info, matches.valid_worlds, mt, fullmatch)
end

"""
Expand Down
3 changes: 1 addition & 2 deletions base/compiler/tfuncs.jl
Original file line number Diff line number Diff line change
Expand Up @@ -3019,8 +3019,7 @@ function abstract_applicable(interp::AbstractInterpreter, argtypes::Vector{Any},
isvarargtype(argtypes[2]) && return CallMeta(Bool, Any, EFFECTS_UNKNOWN, NoCallInfo())
argtypes = argtypes[2:end]
atype = argtypes_to_type(argtypes)
matches = find_matching_methods(typeinf_lattice(interp), argtypes, atype, method_table(interp),
InferenceParams(interp).max_union_splitting, max_methods)
matches = find_method_matches(interp, argtypes, atype; max_methods)
if isa(matches, FailedMethodMatch)
rt = Bool # too many matches to analyze
else
Expand Down

0 comments on commit 8f76c69

Please sign in to comment.