Skip to content

Commit 79d5a3a

Browse files
committed
Improved builtin inference for tuple, set, frozenset, list and dict
We were properly inferring these callables *only* if they had consts as values, but that is not the case most of the time. Instead we try to infer the values that their arguments can be and use them instead of assuming Const nodes all the time. Close pylint-dev/pylint#2841
1 parent 44bbb98 commit 79d5a3a

File tree

4 files changed

+57
-18
lines changed

4 files changed

+57
-18
lines changed

ChangeLog

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,15 @@ What's New in astroid 2.3.0?
66
============================
77
Release Date: TBA
88

9+
* Improved builtin inference for ``tuple``, ``set``, ``frozenset``, ``list`` and ``dict``
10+
11+
We were properly inferring these callables *only* if they had consts as
12+
values, but that is not the case most of the time. Instead we try to infer
13+
the values that their arguments can be and use them instead of assuming
14+
Const nodes all the time.
15+
16+
Close PyCQA/pylint#2841
17+
918
* The last except handler wins when inferring variables bound in an except handler.
1019

1120
Close PyCQA/pylint#2777

astroid/brain/brain_builtin_inference.py

Lines changed: 21 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -147,7 +147,7 @@ def _transform_wrapper(node, context=None):
147147
)
148148

149149

150-
def _generic_inference(node, context, node_type, transform):
150+
def _container_generic_inference(node, context, node_type, transform):
151151
args = node.args
152152
if not args:
153153
return node_type()
@@ -169,14 +169,17 @@ def _generic_inference(node, context, node_type, transform):
169169
return transformed
170170

171171

172-
def _generic_transform(arg, klass, iterables, build_elts):
172+
def _container_generic_transform(arg, klass, iterables, build_elts):
173173
if isinstance(arg, klass):
174174
return arg
175175
elif isinstance(arg, iterables):
176-
if not all(isinstance(elt, nodes.Const) for elt in arg.elts):
177-
raise UseInferenceDefault()
178-
elts = [elt.value for elt in arg.elts]
176+
if all(isinstance(elt, nodes.Const) for elt in arg.elts):
177+
elts = [elt.value for elt in arg.elts]
178+
else:
179+
# TODO: Does not handle deduplication for sets.
180+
elts = filter(None, map(helpers.safe_infer, arg.elts))
179181
elif isinstance(arg, nodes.Dict):
182+
# Dicts need to have consts as strings already.
180183
if not all(isinstance(elt[0], nodes.Const) for elt in arg.items):
181184
raise UseInferenceDefault()
182185
elts = [item[0].value for item in arg.items]
@@ -186,20 +189,25 @@ def _generic_transform(arg, klass, iterables, build_elts):
186189
elts = arg.value
187190
else:
188191
return
189-
return klass.from_constants(elts=build_elts(elts))
192+
return klass.from_elements(elts=build_elts(elts))
190193

191194

192-
def _infer_builtin(node, context, klass=None, iterables=None, build_elts=None):
195+
def _infer_builtin_container(
196+
node, context, klass=None, iterables=None, build_elts=None
197+
):
193198
transform_func = partial(
194-
_generic_transform, klass=klass, iterables=iterables, build_elts=build_elts
199+
_container_generic_transform,
200+
klass=klass,
201+
iterables=iterables,
202+
build_elts=build_elts,
195203
)
196204

197-
return _generic_inference(node, context, klass, transform_func)
205+
return _container_generic_inference(node, context, klass, transform_func)
198206

199207

200208
# pylint: disable=invalid-name
201209
infer_tuple = partial(
202-
_infer_builtin,
210+
_infer_builtin_container,
203211
klass=nodes.Tuple,
204212
iterables=(
205213
nodes.List,
@@ -213,7 +221,7 @@ def _infer_builtin(node, context, klass=None, iterables=None, build_elts=None):
213221
)
214222

215223
infer_list = partial(
216-
_infer_builtin,
224+
_infer_builtin_container,
217225
klass=nodes.List,
218226
iterables=(
219227
nodes.Tuple,
@@ -227,14 +235,14 @@ def _infer_builtin(node, context, klass=None, iterables=None, build_elts=None):
227235
)
228236

229237
infer_set = partial(
230-
_infer_builtin,
238+
_infer_builtin_container,
231239
klass=nodes.Set,
232240
iterables=(nodes.List, nodes.Tuple, objects.FrozenSet, objects.DictKeys),
233241
build_elts=set,
234242
)
235243

236244
infer_frozenset = partial(
237-
_infer_builtin,
245+
_infer_builtin_container,
238246
klass=objects.FrozenSet,
239247
iterables=(nodes.List, nodes.Tuple, nodes.Set, objects.FrozenSet, objects.DictKeys),
240248
build_elts=frozenset,

astroid/node_classes.py

Lines changed: 11 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -48,6 +48,10 @@
4848
MANAGER = manager.AstroidManager()
4949

5050

51+
def _is_const(value):
52+
return isinstance(value, tuple(CONST_CLS))
53+
54+
5155
@decorators.raise_if_nothing_inferred
5256
def unpack_infer(stmt, context=None):
5357
"""recursively generate nodes inferred by the given statement.
@@ -1017,7 +1021,7 @@ def postinit(self, elts):
10171021
self.elts = elts
10181022

10191023
@classmethod
1020-
def from_constants(cls, elts=None):
1024+
def from_elements(cls, elts=None):
10211025
"""Create a node of this type from the given list of elements.
10221026
10231027
:param elts: The list of elements that the node should contain.
@@ -1030,7 +1034,7 @@ def from_constants(cls, elts=None):
10301034
if elts is None:
10311035
node.elts = []
10321036
else:
1033-
node.elts = [const_factory(e) for e in elts]
1037+
node.elts = [const_factory(e) if _is_const(e) else e for e in elts]
10341038
return node
10351039

10361040
def itered(self):
@@ -2728,7 +2732,7 @@ def postinit(self, items):
27282732
self.items = items
27292733

27302734
@classmethod
2731-
def from_constants(cls, items=None):
2735+
def from_elements(cls, items=None):
27322736
"""Create a :class:`Dict` of constants from a live dictionary.
27332737
27342738
:param items: The items to store in the node.
@@ -2742,7 +2746,10 @@ def from_constants(cls, items=None):
27422746
node.items = []
27432747
else:
27442748
node.items = [
2745-
(const_factory(k), const_factory(v)) for k, v in items.items()
2749+
(const_factory(k), const_factory(v) if _is_const(v) else v)
2750+
for k, v in items.items()
2751+
# The keys need to be constants
2752+
if _is_const(k)
27462753
]
27472754
return node
27482755

astroid/tests/unittest_inference.py

Lines changed: 16 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1040,7 +1040,7 @@ def test_binary_op_int_sub(self):
10401040

10411041
def test_binary_op_float_div(self):
10421042
ast = builder.string_build("a = 1 / 2.", __name__, __file__)
1043-
self._test_const_inferred(ast["a"], 1 / 2.)
1043+
self._test_const_inferred(ast["a"], 1 / 2.0)
10441044

10451045
def test_binary_op_str_mul(self):
10461046
ast = builder.string_build('a = "*" * 40', __name__, __file__)
@@ -5149,5 +5149,20 @@ def test_exception_lookup_name_bound_in_except_handler():
51495149
assert inferred_exc.value == 2
51505150

51515151

5152+
def test_builtin_inference_list_of_exceptions():
5153+
node = extract_node(
5154+
"""
5155+
tuple([ValueError, TypeError])
5156+
"""
5157+
)
5158+
inferred = next(node.infer())
5159+
assert isinstance(inferred, nodes.Tuple)
5160+
assert len(inferred.elts) == 2
5161+
assert isinstance(inferred.elts[0], nodes.ClassDef)
5162+
assert inferred.elts[0].name == "ValueError"
5163+
assert isinstance(inferred.elts[1], nodes.ClassDef)
5164+
assert inferred.elts[1].name == "TypeError"
5165+
5166+
51525167
if __name__ == "__main__":
51535168
unittest.main()

0 commit comments

Comments
 (0)