Skip to content

Commit 9650510

Browse files
authored
inference: model partially initialized structs with PartialStruct (#55297)
There is still room for improvement in the accuracy of `getfield` and `isdefined` for structs with uninitialized fields. This commit aims to enhance the accuracy of struct field defined-ness by propagating such structs as `PartialStruct` in cases where fields that might be uninitialized are confirmed to be defined. Specifically, the improvements are made in the following situations:

1. when a `:new` expression receives more arguments than the minimum number of initialized fields.
2. when new information about the initialized fields of `x` can be obtained in the `then` branch of `if isdefined(x, :f)`.

Combined with the existing optimizations, these improvements enable DCE in scenarios such as:

```julia
julia> @noinline broadcast_noescape1(a) = (broadcast(identity, a); nothing);

julia> @allocated broadcast_noescape1(Ref("x"))
16 # master
0 # this PR
```

One important point to note is that, as revealed in #48999, fields and globals can revert to `undef` during precompilation. This commit does not affect globals. Furthermore, even for fields, the refinements made by 1. and 2. are propagated along with data-flow, and field defined-ness information is only used when fields are confirmed to be initialized. Therefore, the same issues as #48999 are not introduced by this commit.
1 parent d4bd540 commit 9650510

File tree

8 files changed

+393
-126
lines changed

8 files changed

+393
-126
lines changed

base/compiler/abstractinterpretation.jl

Lines changed: 58 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -2006,33 +2006,64 @@ function abstract_call_builtin(interp::AbstractInterpreter, f::Builtin, (; fargs
20062006
return Conditional(aty.slot, thentype, elsetype)
20072007
end
20082008
elseif f === isdefined
2009-
uty = argtypes[2]
20102009
a = ssa_def_slot(fargs[2], sv)
2011-
if isa(uty, Union) && isa(a, SlotNumber)
2012-
fld = argtypes[3]
2013-
thentype = Bottom
2014-
elsetype = Bottom
2015-
for ty in uniontypes(uty)
2016-
cnd = isdefined_tfunc(𝕃ᵢ, ty, fld)
2017-
if isa(cnd, Const)
2018-
if cnd.val::Bool
2019-
thentype = thentype ⊔ ty
2010+
if isa(a, SlotNumber)
2011+
argtype2 = argtypes[2]
2012+
if isa(argtype2, Union)
2013+
fld = argtypes[3]
2014+
thentype = Bottom
2015+
elsetype = Bottom
2016+
for ty in uniontypes(argtype2)
2017+
cnd = isdefined_tfunc(𝕃ᵢ, ty, fld)
2018+
if isa(cnd, Const)
2019+
if cnd.val::Bool
2020+
thentype = thentype ⊔ ty
2021+
else
2022+
elsetype = elsetype ⊔ ty
2023+
end
20202024
else
2025+
thentype = thentype ⊔ ty
20212026
elsetype = elsetype ⊔ ty
20222027
end
2023-
else
2024-
thentype = thentype ⊔ ty
2025-
elsetype = elsetype ⊔ ty
2028+
end
2029+
return Conditional(a, thentype, elsetype)
2030+
else
2031+
thentype = form_partially_defined_struct(argtype2, argtypes[3])
2032+
if thentype !== nothing
2033+
elsetype = argtype2
2034+
if rt === Const(false)
2035+
thentype = Bottom
2036+
elseif rt === Const(true)
2037+
elsetype = Bottom
2038+
end
2039+
return Conditional(a, thentype, elsetype)
20262040
end
20272041
end
2028-
return Conditional(a, thentype, elsetype)
20292042
end
20302043
end
20312044
end
20322045
@assert !isa(rt, TypeVar) "unhandled TypeVar"
20332046
return rt
20342047
end
20352048

2049+
function form_partially_defined_struct(@nospecialize(obj), @nospecialize(name))
2050+
obj isa Const && return nothing # nothing to refine
2051+
name isa Const || return nothing
2052+
objt0 = widenconst(obj)
2053+
objt = unwrap_unionall(objt0)
2054+
objt isa DataType || return nothing
2055+
isabstracttype(objt) && return nothing
2056+
fldidx = try_compute_fieldidx(objt, name.val)
2057+
fldidx === nothing && return nothing
2058+
nminfld = datatype_min_ninitialized(objt)
2059+
if ismutabletype(objt)
2060+
fldidx == nminfld+1 || return nothing
2061+
else
2062+
fldidx > nminfld || return nothing
2063+
end
2064+
return PartialStruct(objt0, Any[fieldtype(objt0, i) for i = 1:fldidx])
2065+
end
2066+
20362067
function abstract_call_unionall(interp::AbstractInterpreter, argtypes::Vector{Any}, call::CallMeta)
20372068
na = length(argtypes)
20382069
if isvarargtype(argtypes[end])
@@ -2573,20 +2604,18 @@ function abstract_eval_new(interp::AbstractInterpreter, e::Expr, vtypes::Union{V
25732604
end
25742605
ats[i] = at
25752606
end
2576-
# For now, don't allow:
2577-
# - Const/PartialStruct of mutables (but still allow PartialStruct of mutables
2578-
# with `const` fields if anything refined)
2579-
# - partially initialized Const/PartialStruct
2580-
if fcount == nargs
2581-
if consistent === ALWAYS_TRUE && allconst
2582-
argvals = Vector{Any}(undef, nargs)
2583-
for j in 1:nargs
2584-
argvals[j] = (ats[j]::Const).val
2585-
end
2586-
rt = Const(ccall(:jl_new_structv, Any, (Any, Ptr{Cvoid}, UInt32), rt, argvals, nargs))
2587-
elseif anyrefine
2588-
rt = PartialStruct(rt, ats)
2607+
if fcount == nargs && consistent === ALWAYS_TRUE && allconst
2608+
argvals = Vector{Any}(undef, nargs)
2609+
for j in 1:nargs
2610+
argvals[j] = (ats[j]::Const).val
25892611
end
2612+
rt = Const(ccall(:jl_new_structv, Any, (Any, Ptr{Cvoid}, UInt32), rt, argvals, nargs))
2613+
elseif anyrefine || nargs > datatype_min_ninitialized(rt)
2614+
# propagate partially initialized struct as `PartialStruct` when:
2615+
# - any refinement information is available (`anyrefine`), or when
2616+
# - `nargs` is greater than `n_initialized` derived from the struct type
2617+
# information alone
2618+
rt = PartialStruct(rt, ats)
25902619
end
25912620
else
25922621
rt = refine_partial_type(rt)
@@ -3094,7 +3123,8 @@ end
30943123
@nospecializeinfer function widenreturn_partials(𝕃ᵢ::PartialsLattice, @nospecialize(rt), info::BestguessInfo)
30953124
if isa(rt, PartialStruct)
30963125
fields = copy(rt.fields)
3097-
local anyrefine = false
3126+
anyrefine = !isvarargtype(rt.fields[end]) &&
3127+
length(rt.fields) > datatype_min_ninitialized(unwrap_unionall(rt.typ))
30983128
𝕃 = typeinf_lattice(info.interp)
30993129
⊏ = strictpartialorder(𝕃)
31003130
for i in 1:length(fields)

base/compiler/ssair/passes.jl

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1166,7 +1166,12 @@ struct IntermediaryCollector <: WalkerCallback
11661166
intermediaries::SPCSet
11671167
end
11681168
function (walker_callback::IntermediaryCollector)(@nospecialize(def), @nospecialize(defssa::AnySSAValue))
1169-
isa(def, Expr) || push!(walker_callback.intermediaries, defssa.id)
1169+
if !(def isa Expr)
1170+
push!(walker_callback.intermediaries, defssa.id)
1171+
if def isa PiNode
1172+
return LiftedValue(def.val)
1173+
end
1174+
end
11701175
return nothing
11711176
end
11721177

base/compiler/tfuncs.jl

Lines changed: 30 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -419,23 +419,29 @@ end
419419
else
420420
return Bottom
421421
end
422-
if 1 <= idx <= datatype_min_ninitialized(a1)
422+
if 1 ≤ idx ≤ datatype_min_ninitialized(a1)
423423
return Const(true)
424424
elseif a1.name === _NAMEDTUPLE_NAME
425425
if isconcretetype(a1)
426426
return Const(false)
427427
else
428428
ns = a1.parameters[1]
429429
if isa(ns, Tuple)
430-
return Const(1 <= idx <= length(ns))
430+
return Const(1 ≤ idx ≤ length(ns))
431431
end
432432
end
433-
elseif idx <= 0 || (!isvatuple(a1) && idx > fieldcount(a1))
433+
elseif idx ≤ 0 || (!isvatuple(a1) && idx > fieldcount(a1))
434434
return Const(false)
435435
elseif isa(arg1, Const)
436436
if !ismutabletype(a1) || isconst(a1, idx)
437437
return Const(isdefined(arg1.val, idx))
438438
end
439+
elseif isa(arg1, PartialStruct)
440+
if !isvarargtype(arg1.fields[end])
441+
if 1 ≤ idx ≤ length(arg1.fields)
442+
return Const(true)
443+
end
444+
end
439445
elseif !isvatuple(a1)
440446
fieldT = fieldtype(a1, idx)
441447
if isa(fieldT, DataType) && isbitstype(fieldT)
@@ -989,27 +995,39 @@ end
989995
⊑ = partialorder(𝕃)
990996

991997
# If we have s00 being a const, we can potentially refine our type-based analysis above
992-
if isa(s00, Const) || isconstType(s00)
993-
if !isa(s00, Const)
994-
sv = (s00::DataType).parameters[1]
995-
else
998+
if isa(s00, Const) || isconstType(s00) || isa(s00, PartialStruct)
999+
if isa(s00, Const)
9961000
sv = s00.val
1001+
sty = typeof(sv)
1002+
nflds = nfields(sv)
1003+
ismod = sv isa Module
1004+
elseif isa(s00, PartialStruct)
1005+
sty = unwrap_unionall(s00.typ)
1006+
nflds = fieldcount_noerror(sty)
1007+
ismod = false
1008+
else
1009+
sv = (s00::DataType).parameters[1]
1010+
sty = typeof(sv)
1011+
nflds = nfields(sv)
1012+
ismod = sv isa Module
9971013
end
9981014
if isa(name, Const)
9991015
nval = name.val
10001016
if !isa(nval, Symbol)
1001-
isa(sv, Module) && return false
1017+
ismod && return false
10021018
isa(nval, Int) || return false
10031019
end
10041020
return isdefined_tfunc(𝕃, s00, name) === Const(true)
10051021
end
1006-
boundscheck && return false
1022+
10071023
# If bounds checking is disabled and all fields are assigned,
10081024
# we may assume that we don't throw
1009-
isa(sv, Module) && return false
1025+
@assert !boundscheck
1026+
ismod && return false
10101027
name ⊑ Int || name ⊑ Symbol || return false
1011-
typeof(sv).name.n_uninitialized == 0 && return true
1012-
for i = (datatype_min_ninitialized(typeof(sv)) + 1):nfields(sv)
1028+
sty.name.n_uninitialized == 0 && return true
1029+
nflds === nothing && return false
1030+
for i = (datatype_min_ninitialized(sty)+1):nflds
10131031
isdefined_tfunc(𝕃, s00, Const(i)) === Const(true) || return false
10141032
end
10151033
return true

base/compiler/typelattice.jl

Lines changed: 60 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -6,17 +6,42 @@
66

77
# N.B.: Const/PartialStruct/InterConditional are defined in Core, to allow them to be used
88
# inside the global code cache.
9-
#
10-
# # The type of a value might be constant
11-
# struct Const
12-
# val
13-
# end
14-
#
15-
# struct PartialStruct
16-
# typ
17-
# fields::Vector{Any} # elements are other type lattice members
18-
# end
9+
1910
import Core: Const, PartialStruct
11+
12+
"""
13+
struct Const
14+
val
15+
end
16+
17+
The type representing a constant value.
18+
"""
19+
:(Const)
20+
21+
"""
22+
struct PartialStruct
23+
typ
24+
fields::Vector{Any} # elements are other type lattice members
25+
end
26+
27+
This extended lattice element is introduced when we have information about an object's
28+
fields beyond what can be obtained from the object type. E.g. it represents a tuple where
29+
some elements are known to be constants or a struct whose `Any`-typed field is initialized
30+
with `Int` values.
31+
32+
- `typ` indicates the type of the object
33+
- `fields` holds the lattice elements corresponding to each field of the object
34+
35+
If `typ` is a struct, `fields` represents the fields of the struct that are guaranteed to be
36+
initialized. For instance, if the length of `fields` of `PartialStruct` representing a
37+
struct with 4 fields is 3, the 4th field may not be initialized. If the length is 4, all
38+
fields are guaranteed to be initialized.
39+
40+
If `typ` is a tuple, the last element of `fields` may be `Vararg`. In this case, it is
41+
guaranteed that the number of elements in the tuple is at least `length(fields)-1`, but the
42+
exact number of elements is unknown.
43+
"""
44+
:(PartialStruct)
2045
function PartialStruct(@nospecialize(typ), fields::Vector{Any})
2146
for i = 1:length(fields)
2247
assert_nested_slotwrapper(fields[i])
@@ -57,23 +82,20 @@ end
5782
Conditional(var::SlotNumber, @nospecialize(thentype), @nospecialize(elsetype)) =
5883
Conditional(slot_id(var), thentype, elsetype)
5984

85+
import Core: InterConditional
6086
"""
61-
cnd::InterConditional
87+
struct InterConditional
88+
slot::Int
89+
thentype
90+
elsetype
91+
end
6292
6393
Similar to `Conditional`, but conveys inter-procedural constraints imposed on call arguments.
6494
This is separate from `Conditional` to catch logic errors: the lattice element name is `InterConditional`
6595
while processing a call, then `Conditional` everywhere else. Thus `InterConditional` does not appear in
6696
`CompilerTypes`—these type's usages are disjoint—though we define the lattice for `InterConditional`.
6797
"""
6898
:(InterConditional)
69-
import Core: InterConditional
70-
# struct InterConditional
71-
# slot::Int
72-
# thentype
73-
# elsetype
74-
# InterConditional(slot::Int, @nospecialize(thentype), @nospecialize(elsetype)) =
75-
# new(slot, thentype, elsetype)
76-
# end
7799
InterConditional(var::SlotNumber, @nospecialize(thentype), @nospecialize(elsetype)) =
78100
InterConditional(slot_id(var), thentype, elsetype)
79101

@@ -447,8 +469,13 @@ end
447469
@nospecializeinfer function ⊑(lattice::PartialsLattice, @nospecialize(a), @nospecialize(b))
448470
if isa(a, PartialStruct)
449471
if isa(b, PartialStruct)
450-
if !(length(a.fields) == length(b.fields) && a.typ <: b.typ)
451-
return false
472+
a.typ <: b.typ || return false
473+
if length(a.fields) ≠ length(b.fields)
474+
if !(isvarargtype(a.fields[end]) || isvarargtype(b.fields[end]))
475+
length(a.fields) ≥ length(b.fields) || return false
476+
else
477+
return false
478+
end
452479
end
453480
for i in 1:length(b.fields)
454481
af = a.fields[i]
@@ -471,19 +498,25 @@ end
471498
return isa(b, Type) && a.typ <: b
472499
elseif isa(b, PartialStruct)
473500
if isa(a, Const)
474-
nf = nfields(a.val)
475-
nf == length(b.fields) || return false
476501
widea = widenconst(a)::DataType
477502
wideb = widenconst(b)
478503
wideb′ = unwrap_unionall(wideb)::DataType
479504
widea.name === wideb′.name || return false
480-
# We can skip the subtype check if b is a Tuple, since in that
481-
# case, the ⊑ of the elements is sufficient.
482-
if wideb′.name !== Tuple.name && !(widea <: wideb)
483-
return false
505+
if wideb′.name === Tuple.name
506+
# We can skip the subtype check if b is a Tuple, since in that
507+
# case, the ⊑ of the elements is sufficient.
508+
# But for tuple comparisons, we need their lengths to be the same for now.
509+
# TODO improve accuracy for cases when `b` contains vararg element
510+
nfields(a.val) == length(b.fields) || return false
511+
else
512+
widea <: wideb || return false
513+
# for structs we need to check that `a` has more information than `b` that may be partially initialized
514+
n_initialized(a) ≥ length(b.fields) || return false
484515
end
516+
nf = nfields(a.val)
485517
for i in 1:nf
486518
isdefined(a.val, i) || continue # since ∀ T Union{} ⊑ T
519+
i > length(b.fields) && break # `a` has more information than `b` that is partially initialized struct
487520
bfᵢ = b.fields[i]
488521
if i == nf
489522
bfᵢ = unwrapva(bfᵢ)

0 commit comments

Comments
 (0)