Skip to content

Commit

Permalink
B006 and B008: Cover additional test cases (#239)
Browse files Browse the repository at this point in the history
* B006 and B008: Cover additional test cases

* Add change log entry
* Account for inconsistent ast between python versions
* Use ast.literal_eval to simplify infinity float detection
  • Loading branch information
jpy-git committed Mar 23, 2022
1 parent ea0bd48 commit f9e0f77
Show file tree
Hide file tree
Showing 5 changed files with 193 additions and 100 deletions.
2 changes: 1 addition & 1 deletion .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ repos:
- id: isort

- repo: https://github.com/psf/black
rev: 21.10b0
rev: 22.1.0
hooks:
- id: black
args:
Expand Down
5 changes: 5 additions & 0 deletions README.rst
Original file line number Diff line number Diff line change
Expand Up @@ -279,6 +279,11 @@ MIT
Change Log
----------

<release-tbd>
~~~~~~~~~~

* B006 and B008: Detect function calls at any level of the default expression.

22.3.20
~~~~~~~~~~

Expand Down
143 changes: 84 additions & 59 deletions bugbear.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,8 @@
import builtins
import itertools
import logging
import math
import re
import sys
from collections import namedtuple
from contextlib import suppress
from functools import lru_cache, partial
Expand Down Expand Up @@ -354,13 +354,13 @@ def visit_Assert(self, node):

def visit_AsyncFunctionDef(self, node):
self.check_for_b902(node)
self.check_for_b006(node)
self.check_for_b006_and_b008(node)
self.generic_visit(node)

def visit_FunctionDef(self, node):
self.check_for_b901(node)
self.check_for_b902(node)
self.check_for_b006(node)
self.check_for_b006_and_b008(node)
self.check_for_b018(node)
self.check_for_b019(node)
self.check_for_b021(node)
Expand Down Expand Up @@ -390,23 +390,14 @@ def visit_With(self, node):
self.check_for_b022(node)
self.generic_visit(node)

def compose_call_path(self, node):
if isinstance(node, ast.Attribute):
yield from self.compose_call_path(node.value)
yield node.attr
elif isinstance(node, ast.Call):
yield from self.compose_call_path(node.func)
elif isinstance(node, ast.Name):
yield node.id

def check_for_b005(self, node):
if node.func.attr not in B005.methods:
return # method name doesn't match

if len(node.args) != 1 or not isinstance(node.args[0], ast.Str):
return # used arguments don't match the builtin strip

call_path = ".".join(self.compose_call_path(node.func.value))
call_path = ".".join(compose_call_path(node.func.value))
if call_path in B005.valid_paths:
return # path is exempt

Expand All @@ -419,48 +410,10 @@ def check_for_b005(self, node):

self.errors.append(B005(node.lineno, node.col_offset))

def check_for_b006(self, node):
for default in node.args.defaults + node.args.kw_defaults:
if isinstance(
default, (*B006.mutable_literals, *B006.mutable_comprehensions)
):
self.errors.append(B006(default.lineno, default.col_offset))
elif isinstance(default, ast.Call):
call_path = ".".join(self.compose_call_path(default.func))
if call_path in B006.mutable_calls:
self.errors.append(B006(default.lineno, default.col_offset))
elif (
call_path
not in B008.immutable_calls | self.b008_extend_immutable_calls
):
# Check if function call is actually a float infinity/NaN literal
if call_path == "float" and len(default.args) == 1:
float_arg = default.args[0]
if sys.version_info < (3, 8, 0):
# NOTE: pre-3.8, string literals are represented with ast.Str
if isinstance(float_arg, ast.Str):
str_val = float_arg.s
else:
str_val = ""
else:
# NOTE: post-3.8, string literals are represented with ast.Constant
if isinstance(float_arg, ast.Constant):
str_val = float_arg.value
if not isinstance(str_val, str):
str_val = ""
else:
str_val = ""

# NOTE: regex derived from documentation at:
# https://docs.python.org/3/library/functions.html#float
inf_nan_regex = r"^[+-]?(inf|infinity|nan)$"
re_result = re.search(inf_nan_regex, str_val.lower())
is_float_literal = re_result is not None
else:
is_float_literal = False

if not is_float_literal:
self.errors.append(B008(default.lineno, default.col_offset))
def check_for_b006_and_b008(self, node):
visitor = FuntionDefDefaultsVisitor(self.b008_extend_immutable_calls)
visitor.visit(node.args.defaults + node.args.kw_defaults)
self.errors.extend(visitor.errors)

def check_for_b007(self, node):
targets = NameFinder()
Expand Down Expand Up @@ -536,8 +489,7 @@ def check_for_b019(self, node):
# Preserve decorator order so we can get the lineno from the decorator node
# rather than the function node (this location definition changes in Python 3.8)
resolved_decorators = (
".".join(self.compose_call_path(decorator))
for decorator in node.decorator_list
".".join(compose_call_path(decorator)) for decorator in node.decorator_list
)
for idx, decorator in enumerate(resolved_decorators):
if decorator in {"classmethod", "staticmethod"}:
Expand Down Expand Up @@ -755,6 +707,16 @@ def check_for_b022(self, node):
self.errors.append(B022(node.lineno, node.col_offset))


def compose_call_path(node):
if isinstance(node, ast.Attribute):
yield from compose_call_path(node.value)
yield node.attr
elif isinstance(node, ast.Call):
yield from compose_call_path(node.func)
elif isinstance(node, ast.Name):
yield node.id


@attr.s
class NameFinder(ast.NodeVisitor):
"""Finds a name within a tree of nodes.
Expand All @@ -778,6 +740,69 @@ def visit(self, node):
return node


class FuntionDefDefaultsVisitor(ast.NodeVisitor):
def __init__(self, b008_extend_immutable_calls=None):
self.b008_extend_immutable_calls = b008_extend_immutable_calls or set()
for node in B006.mutable_literals + B006.mutable_comprehensions:
setattr(self, f"visit_{node}", self.visit_mutable_literal_or_comprehension)
self.errors = []
self.arg_depth = 0
super().__init__()

def visit_mutable_literal_or_comprehension(self, node):
# Flag B006 iff mutable literal/comprehension is not nested.
# We only flag these at the top level of the expression as we
# cannot easily guarantee that nested mutable structures are not
# made immutable by outer operations, so we prefer no false positives.
# e.g.
# >>> def this_is_fine(a=frozenset({"a", "b", "c"})): ...
#
# >>> def this_is_not_fine_but_hard_to_detect(a=(lambda x: x)([1, 2, 3]))
#
# We do still search for cases of B008 within mutable structures though.
if self.arg_depth == 1:
self.errors.append(B006(node.lineno, node.col_offset))
# Check for nested functions.
self.generic_visit(node)

def visit_Call(self, node):
call_path = ".".join(compose_call_path(node.func))
if call_path in B006.mutable_calls:
self.errors.append(B006(node.lineno, node.col_offset))
self.generic_visit(node)
return

if call_path in B008.immutable_calls | self.b008_extend_immutable_calls:
self.generic_visit(node)
return

# Check if function call is actually a float infinity/NaN literal
if call_path == "float" and len(node.args) == 1:
try:
value = float(ast.literal_eval(node.args[0]))
except Exception:
pass
else:
if math.isfinite(value):
self.errors.append(B008(node.lineno, node.col_offset))
else:
self.errors.append(B008(node.lineno, node.col_offset))

# Check for nested functions.
self.generic_visit(node)

def visit(self, node):
"""Like super-visit but supports iteration over lists."""
self.arg_depth += 1
if isinstance(node, list):
for elem in node:
if elem is not None:
super().visit(elem)
else:
super().visit(node)
self.arg_depth -= 1


class B020NameFinder(NameFinder):
"""Ignore names defined within the local scope of a comprehension."""

Expand Down Expand Up @@ -851,8 +876,8 @@ def visit_comprehension(self, node):
"between them."
)
)
B006.mutable_literals = (ast.Dict, ast.List, ast.Set)
B006.mutable_comprehensions = (ast.ListComp, ast.DictComp, ast.SetComp)
B006.mutable_literals = ("Dict", "List", "Set")
B006.mutable_comprehensions = ("ListComp", "DictComp", "SetComp")
B006.mutable_calls = {
"Counter",
"OrderedDict",
Expand Down
Loading

0 comments on commit f9e0f77

Please sign in to comment.