From 039c9df3ef3a65e0bc8a52207d571c74cc2c80f4 Mon Sep 17 00:00:00 2001 From: MarcoGorelli Date: Thu, 8 Jul 2021 20:45:24 +0100 Subject: [PATCH] dont rewrite fstring if await in py36 --- pyupgrade/_main.py | 22 +++++++++++++++++----- tests/features/fstrings_test.py | 13 +++++++++++-- tests/features/typing_named_tuple_test.py | 4 ++-- tests/features/typing_typed_dict_test.py | 4 ++-- 4 files changed, 32 insertions(+), 11 deletions(-) diff --git a/pyupgrade/_main.py b/pyupgrade/_main.py index 6e8f7b5c..24bfbb67 100644 --- a/pyupgrade/_main.py +++ b/pyupgrade/_main.py @@ -520,13 +520,22 @@ def _format_params(call: ast.Call) -> Set[str]: return params +def _contains_await(node: ast.AST) -> bool: + for node_ in ast.walk(node): + if isinstance(node_, ast.Await): + return True + else: + return False + + class FindPy36Plus(ast.NodeVisitor): - def __init__(self) -> None: + def __init__(self, *, min_version: Version) -> None: self.fstrings: Dict[Offset, ast.Call] = {} self.named_tuples: Dict[Offset, ast.Call] = {} self.dict_typed_dicts: Dict[Offset, ast.Call] = {} self.kw_typed_dicts: Dict[Offset, ast.Call] = {} self._from_imports: Dict[str, Set[str]] = collections.defaultdict(set) + self.min_version = min_version def visit_ImportFrom(self, node: ast.ImportFrom) -> None: if node.level == 0 and node.module in {'typing', 'typing_extensions'}: @@ -591,7 +600,8 @@ def visit_Call(self, node: ast.Call) -> None: if not candidate: i += 1 else: - self.fstrings[ast_to_offset(node)] = node + if self.min_version >= (3, 7) or not _contains_await(node): + self.fstrings[ast_to_offset(node)] = node self.generic_visit(node) @@ -758,13 +768,13 @@ def _typed_class_replacement( return end, attrs -def _fix_py36_plus(contents_text: str) -> str: +def _fix_py36_plus(contents_text: str, *, min_version: Version) -> str: try: ast_obj = ast_parse(contents_text) except SyntaxError: return contents_text - visitor = FindPy36Plus() + visitor = FindPy36Plus(min_version=min_version) visitor.visit(ast_obj) if not any(( @@ -871,7 +881,9 @@ def _fix_file(filename: str, args: argparse.Namespace) -> int: ) contents_text = _fix_tokens(contents_text, min_version=args.min_version) if args.min_version >= (3, 6): - contents_text = _fix_py36_plus(contents_text) + contents_text = _fix_py36_plus( + contents_text, min_version=args.min_version, + ) if filename == '-': print(contents_text, end='') diff --git a/tests/features/fstrings_test.py b/tests/features/fstrings_test.py index 791c3a22..3fd54c37 100644 --- a/tests/features/fstrings_test.py +++ b/tests/features/fstrings_test.py @@ -34,10 +34,13 @@ r'''"{}".format(a['\\'])''', '"{}".format(a["b"])', "'{}'.format(a['b'])", + # await only becomes keyword in Python 3.7+ + "async def c(): return '{}'.format(await 3)", + "async def c(): return '{}'.format(1 + await 3)", ), ) def test_fix_fstrings_noop(s): - assert _fix_py36_plus(s) == s + assert _fix_py36_plus(s, min_version=(3, 6)) == s @pytest.mark.parametrize( @@ -60,4 +63,10 @@ def test_fix_fstrings_noop(s): ), ) def test_fix_fstrings(s, expected): - assert _fix_py36_plus(s) == expected + assert _fix_py36_plus(s, min_version=(3, 6)) == expected + + +def test_fix_fstrings_await_py37(): + s = "async def c(): return '{}'.format(await 1+foo())" + expected = "async def c(): return f'{await 1+foo()}'" + assert _fix_py36_plus(s, min_version=(3, 7)) == expected diff --git a/tests/features/typing_named_tuple_test.py b/tests/features/typing_named_tuple_test.py index 1b94a4a3..ac6fccfe 100644 --- a/tests/features/typing_named_tuple_test.py +++ b/tests/features/typing_named_tuple_test.py @@ -57,7 +57,7 @@ ), ) def test_typing_named_tuple_noop(s): - assert _fix_py36_plus(s) == s + assert _fix_py36_plus(s, min_version=(3, 6)) == s @pytest.mark.parametrize( @@ -171,4 +171,4 @@ def test_typing_named_tuple_noop(s): ), ) def test_fix_typing_named_tuple(s, expected): - assert _fix_py36_plus(s) == expected + assert _fix_py36_plus(s, min_version=(3, 6)) == expected diff --git a/tests/features/typing_typed_dict_test.py b/tests/features/typing_typed_dict_test.py index f3a67335..0ead5842 100644 --- a/tests/features/typing_typed_dict_test.py +++ b/tests/features/typing_typed_dict_test.py @@ -46,7 +46,7 @@ ), ) def test_typing_typed_dict_noop(s): - assert _fix_py36_plus(s) == s + assert _fix_py36_plus(s, min_version=(3, 6)) == s @pytest.mark.parametrize( @@ -137,4 +137,4 @@ def test_typing_typed_dict_noop(s): ), ) def test_typing_typed_dict(s, expected): - assert _fix_py36_plus(s) == expected + assert _fix_py36_plus(s, min_version=(3, 6)) == expected