Skip to content

Commit

Permalink
convert fstring rewriter to a plugin
Browse files Browse the repository at this point in the history
  • Loading branch information
asottile committed Apr 8, 2022
1 parent 087c7e6 commit 51d5d11
Show file tree
Hide file tree
Showing 6 changed files with 229 additions and 236 deletions.
201 changes: 4 additions & 197 deletions pyupgrade/_main.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,16 +3,12 @@
import argparse
import ast
import re
import string
import sys
import tokenize
from typing import Match
from typing import Optional
from typing import Sequence
from typing import Tuple

from tokenize_rt import NON_CODING_TOKENS
from tokenize_rt import Offset
from tokenize_rt import parse_string_literal
from tokenize_rt import reversed_enumerate
from tokenize_rt import rfind_string_parts
Expand All @@ -22,68 +18,18 @@
from tokenize_rt import UNIMPORTANT_WS

from pyupgrade._ast_helpers import ast_parse
from pyupgrade._ast_helpers import ast_to_offset
from pyupgrade._ast_helpers import contains_await
from pyupgrade._ast_helpers import has_starargs
from pyupgrade._data import FUNCS
from pyupgrade._data import Settings
from pyupgrade._data import Version
from pyupgrade._data import visit
from pyupgrade._string_helpers import curly_escape
from pyupgrade._string_helpers import DotFormatPart
from pyupgrade._string_helpers import is_codec
from pyupgrade._string_helpers import NAMED_UNICODE_RE
from pyupgrade._string_helpers import parse_format
from pyupgrade._string_helpers import unparse_parsed_string
from pyupgrade._token_helpers import CLOSING
from pyupgrade._token_helpers import OPENING
from pyupgrade._token_helpers import parse_call_args
from pyupgrade._token_helpers import remove_brace

# One parsed piece of a format template, unpacked elsewhere in this file as
# (literal_text, field_name, format_spec, conversion).
DotFormatPart = Tuple[str, Optional[str], Optional[str], Optional[str]]

# AST node classes for lambdas and (async) function definitions.
FUNC_TYPES = (ast.Lambda, ast.FunctionDef, ast.AsyncFunctionDef)

# Bound `parse` method of a `string.Formatter` instance; `parse_format` runs
# it over the pieces of a template that are not named escape sequences.
_stdlib_parse_format = string.Formatter().parse


def parse_format(s: str) -> tuple[DotFormatPart, ...]:
    r"""Parse a ``str.format`` template, keeping named escapes as literal text.

    The template is split with ``NAMED_UNICODE_RE``; pieces that fully match
    it (named escape sequences such as ``\N{...}``) are kept verbatim as
    literal text so their braces are not mistaken for replacement fields.
    Every other piece is parsed with ``string.Formatter().parse``.

    :param s: the format template
    :returns: ``(literal_text, field_name, format_spec, conversion)`` tuples;
        always at least one entry
    :raises ValueError: propagated from ``string.Formatter().parse`` when a
        piece is not a valid format template
    """
    ret: list[DotFormatPart] = []

    for part in NAMED_UNICODE_RE.split(s):
        if NAMED_UNICODE_RE.fullmatch(part):
            # only fold the escape into the previous entry when that entry
            # is pure literal text: overwriting an entry that carries a
            # replacement field would silently drop the field
            if not ret or ret[-1][1:] != (None, None, None):
                ret.append((part, None, None, None))
            else:
                ret[-1] = (ret[-1][0] + part, None, None, None)
        else:
            first = True
            for tup in _stdlib_parse_format(part):
                if not first or not ret:
                    ret.append(tup)
                else:
                    # glue the first piece onto the preceding literal-only
                    # entry (the named escape that came right before it)
                    ret[-1] = (ret[-1][0] + tup[0], *tup[1:])
                first = False

    # an empty template still yields one (empty) literal part
    if not ret:
        ret.append((s, None, None, None))

    return tuple(ret)


def unparse_parsed_string(parsed: Sequence[DotFormatPart]) -> str:
    """Turn parsed format parts back into a format template string.

    :param parsed: ``(literal_text, field_name, format_spec, conversion)``
        tuples
    :returns: the concatenation of every part rendered as template source
    """
    def _convert_tup(tup: DotFormatPart) -> str:
        ret, field_name, format_spec, conversion = tup
        # literal text goes through `curly_escape` before being emitted
        ret = curly_escape(ret)
        # a part without a field is literal text only
        if field_name is not None:
            ret += '{' + field_name
            if conversion:
                ret += '!' + conversion
            if format_spec:
                ret += ':' + format_spec
            ret += '}'
        return ret

    return ''.join(_convert_tup(tup) for tup in parsed)


def inty(s: str) -> bool:
try:
Expand Down Expand Up @@ -328,7 +274,7 @@ def _fix_format_literal(tokens: list[Token], end: int) -> None:
else:
return

parsed_parts.append(tuple(_remove_fmt(tup) for tup in parsed))
parsed_parts.append([_remove_fmt(tup) for tup in parsed])

for i, parsed in zip(parts, parsed_parts):
tokens[i] = tokens[i]._replace(src=unparse_parsed_string(parsed))
Expand Down Expand Up @@ -523,141 +469,6 @@ def _fix_tokens(contents_text: str, min_version: Version) -> str:
return tokens_to_src(tokens).lstrip()


def _format_params(call: ast.Call) -> set[str]:
    """Return the keys a ``.format(...)`` call offers to its template.

    :param call: the call node; must not contain ``**kwargs`` unpacking
    :returns: ``'0'``, ``'1'``, ... for each positional argument plus the
        name of every keyword argument
    """
    # positional arguments are addressed by their index
    params = {str(i) for i in range(len(call.args))}
    for kwd in call.keywords:
        # kwd.arg can't be None here because we exclude starargs
        assert kwd.arg is not None
        params.add(kwd.arg)
    return params


class FindPy36Plus(ast.NodeVisitor):
    """Collect ``'literal'.format(...)`` calls that can become f-strings.

    After visiting a tree, ``fstrings`` maps the offset of each candidate
    call to its ``ast.Call`` node.
    """

    def __init__(self, *, min_version: Version) -> None:
        self.fstrings: dict[Offset, ast.Call] = {}
        self.min_version = min_version

    def _parse(self, node: ast.Call) -> tuple[DotFormatPart, ...] | None:
        """Return the parsed template of a ``str`` literal ``.format`` call.

        Returns ``None`` when ``node`` is not a ``.format`` call on a string
        literal, uses star-args, or has a template that fails to parse.
        """
        if not (
                isinstance(node.func, ast.Attribute) and
                isinstance(node.func.value, ast.Str) and
                node.func.attr == 'format' and
                not has_starargs(node)
        ):
            return None

        try:
            return parse_format(node.func.value.s)
        except ValueError:
            return None

    def visit_Call(self, node: ast.Call) -> None:
        """Record ``node`` when every template field is safe to inline."""
        parsed = self._parse(node)
        if parsed is not None:
            params = _format_params(node)
            seen: set[str] = set()
            i = 0  # index of the next auto-numbered (`{}`) field
            for _, name, spec, _ in parsed:
                # timid: difficult to rewrite correctly
                if spec is not None and '{' in spec:
                    break
                if name is not None:
                    candidate, _, _ = name.partition('.')
                    # timid: could make the f-string longer
                    if candidate and candidate in seen:
                        break
                    # timid: bracketed
                    elif '[' in name:
                        break
                    seen.add(candidate)

                    key = candidate or str(i)
                    # their .format() call is broken currently
                    if key not in params:
                        break
                    if not candidate:
                        i += 1
            else:
                # for/else: only reached when no field caused a `break`
                if self.min_version >= (3, 7) or not contains_await(node):
                    self.fstrings[ast_to_offset(node)] = node

        self.generic_visit(node)


def _skip_unimportant_ws(tokens: list[Token], i: int) -> int:
    """Return the first index at or after ``i`` whose token is not whitespace.

    :param tokens: the token list to scan
    :param i: index to start scanning from
    :returns: index of the first token whose name is not ``UNIMPORTANT_WS``
    """
    while tokens[i].name == 'UNIMPORTANT_WS':
        i += 1
    return i


def _to_fstring(
    src: str, tokens: list[Token], args: list[tuple[int, int]],
) -> str:
    """Build the f-string literal that replaces ``src.format(...)``.

    :param src: source text of the string literal being formatted
    :param tokens: the token list containing the call
    :param args: ``(start, end)`` token index pairs, one per call argument
    :returns: source text of the equivalent f-string literal
    """
    # map each template key to the source text of the matching argument:
    # keyword arguments by name, positional arguments by index
    params = {}
    i = 0
    for start, end in args:
        start = _skip_unimportant_ws(tokens, start)
        if tokens[start].name == 'NAME':
            after = _skip_unimportant_ws(tokens, start + 1)
            if tokens[after].src == '=':  # keyword argument
                params[tokens[start].src] = tokens_to_src(
                    tokens[after + 1:end],
                ).strip()
                continue
        params[str(i)] = tokens_to_src(tokens[start:end]).strip()
        i += 1

    parts = []
    i = 0
    # a `u` prefix cannot be combined with `f` (`fu'...'` is a syntax
    # error), so drop it before prepending the `f`
    for s, name, spec, conv in parse_format('f' + src.lstrip('uU')):
        if name is not None:
            k, dot, rest = name.partition('.')
            name = ''.join((params[k or str(i)], dot, rest))
            if not k:  # named and auto params can be in different orders
                i += 1
        parts.append((s, name, spec, conv))
    return unparse_parsed_string(parts)


def _fix_py36_plus(contents_text: str, *, min_version: Version) -> str:
    """Rewrite eligible ``'...'.format(...)`` calls in a source as f-strings.

    :param contents_text: the source text to rewrite
    :param min_version: minimum target version, passed to ``FindPy36Plus``
    :returns: the rewritten source; ``contents_text`` unchanged when it does
        not parse, cannot be tokenized, or has no candidate calls
    """
    try:
        ast_obj = ast_parse(contents_text)
    except SyntaxError:
        return contents_text

    visitor = FindPy36Plus(min_version=min_version)
    visitor.visit(ast_obj)

    if not visitor.fstrings:
        return contents_text

    try:
        tokens = src_to_tokens(contents_text)
    except tokenize.TokenError:  # pragma: no cover (bpo-2180)
        return contents_text
    # walk backwards so deleting tokens does not shift pending indices
    for i, token in reversed_enumerate(tokens):
        if token.offset in visitor.fstrings:
            # the literal must be followed directly by `.`, `format`, `(`
            paren = i + 3
            if tokens_to_src(tokens[i + 1:paren + 1]) != '.format(':
                continue

            args, end = parse_call_args(tokens, paren)
            # if it spans more than one line, bail
            if tokens[end - 1].line != token.line:
                continue

            # backslashes and quotes in the arguments are not inlined
            args_src = tokens_to_src(tokens[paren:end])
            if '\\' in args_src or '"' in args_src or "'" in args_src:
                continue

            tokens[i] = token._replace(
                src=_to_fstring(token.src, tokens, args),
            )
            # drop the now-redundant `.format(...)` tokens
            del tokens[i + 1:end]

    return tokens_to_src(tokens)


def _fix_file(filename: str, args: argparse.Namespace) -> int:
if filename == '-':
contents_bytes = sys.stdin.buffer.read()
Expand All @@ -681,10 +492,6 @@ def _fix_file(filename: str, args: argparse.Namespace) -> int:
),
)
contents_text = _fix_tokens(contents_text, min_version=args.min_version)
if args.min_version >= (3, 6):
contents_text = _fix_py36_plus(
contents_text, min_version=args.min_version,
)

if filename == '-':
print(contents_text, end='')
Expand Down
133 changes: 133 additions & 0 deletions pyupgrade/_plugins/fstrings.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,133 @@
from __future__ import annotations

import ast
from typing import Iterable

from tokenize_rt import Offset
from tokenize_rt import Token
from tokenize_rt import tokens_to_src

from pyupgrade._ast_helpers import ast_to_offset
from pyupgrade._ast_helpers import contains_await
from pyupgrade._ast_helpers import has_starargs
from pyupgrade._data import register
from pyupgrade._data import State
from pyupgrade._data import TokenFunc
from pyupgrade._string_helpers import parse_format
from pyupgrade._string_helpers import unparse_parsed_string
from pyupgrade._token_helpers import parse_call_args


def _skip_unimportant_ws(tokens: list[Token], i: int) -> int:
    """Return the first index at or after ``i`` whose token is not whitespace.

    :param tokens: the token list to scan
    :param i: index to start scanning from
    :returns: index of the first token whose name is not ``UNIMPORTANT_WS``
    """
    while tokens[i].name == 'UNIMPORTANT_WS':
        i += 1
    return i


def _to_fstring(
    src: str, tokens: list[Token], args: list[tuple[int, int]],
) -> str:
    """Build the f-string literal that replaces ``src.format(...)``.

    :param src: source text of the string literal being formatted
    :param tokens: the token list containing the call
    :param args: ``(start, end)`` token index pairs, one per call argument
    :returns: source text of the equivalent f-string literal
    """
    # map each template key to the source text of the matching argument:
    # keyword arguments by name, positional arguments by index
    params = {}
    i = 0
    for start, end in args:
        start = _skip_unimportant_ws(tokens, start)
        if tokens[start].name == 'NAME':
            after = _skip_unimportant_ws(tokens, start + 1)
            if tokens[after].src == '=':  # keyword argument
                params[tokens[start].src] = tokens_to_src(
                    tokens[after + 1:end],
                ).strip()
                continue
        params[str(i)] = tokens_to_src(tokens[start:end]).strip()
        i += 1

    parts = []
    i = 0
    # a `u` prefix cannot be combined with `f` (`fu'...'` is a syntax
    # error), so drop it before prepending the `f`
    for s, name, spec, conv in parse_format('f' + src.lstrip('uU')):
        if name is not None:
            k, dot, rest = name.partition('.')
            name = ''.join((params[k or str(i)], dot, rest))
            if not k:  # named and auto params can be in different orders
                i += 1
        parts.append((s, name, spec, conv))
    return unparse_parsed_string(parts)


def _fix_fstring(i: int, tokens: list[Token]) -> None:
    """Rewrite the ``.format(...)`` call on the string at ``tokens[i]`` in place.

    Leaves ``tokens`` untouched when the literal is not immediately followed
    by ``.format(``, when the call spans more than one line, or when the
    argument source contains a backslash or a quote character.

    :param i: index of the string literal token
    :param tokens: the token list, mutated in place
    """
    token = tokens[i]

    # the literal must be followed directly by `.`, `format`, `(`
    paren = i + 3
    if tokens_to_src(tokens[i + 1:paren + 1]) != '.format(':
        return

    args, end = parse_call_args(tokens, paren)
    # if it spans more than one line, bail
    if tokens[end - 1].line != token.line:
        return

    # backslashes and quotes in the arguments are not inlined
    args_src = tokens_to_src(tokens[paren:end])
    if '\\' in args_src or '"' in args_src or "'" in args_src:
        return

    tokens[i] = token._replace(src=_to_fstring(token.src, tokens, args))
    # drop the now-redundant `.format(...)` tokens
    del tokens[i + 1:end]


def _format_params(call: ast.Call) -> set[str]:
    """Return the keys a ``.format(...)`` call offers to its template.

    :param call: the call node; must not contain ``**kwargs`` unpacking
    :returns: ``'0'``, ``'1'``, ... for each positional argument plus the
        name of every keyword argument
    """
    # positional arguments are addressed by their index
    params = {str(i) for i in range(len(call.args))}
    for kwd in call.keywords:
        # kwd.arg can't be None here because we exclude starargs
        assert kwd.arg is not None
        params.add(kwd.arg)
    return params


@register(ast.Call)
def visit_Call(
        state: State,
        node: ast.Call,
        parent: ast.AST,
) -> Iterable[tuple[Offset, TokenFunc]]:
    """Yield a rewrite for ``'literal'.format(...)`` calls safe to inline.

    Nothing is yielded when the target version is below 3.6, when ``node`` is
    not a star-arg-free ``.format`` call on a string literal, when the
    template fails to parse, or when any field trips one of the "timid"
    checks below.

    :param state: plugin state; only ``state.settings.min_version`` is read
    :param node: the call being visited
    :param parent: the parent node (unused here)
    :returns: at most one ``(offset, _fix_fstring)`` pair
    """
    if state.settings.min_version < (3, 6):
        return

    if (
            isinstance(node.func, ast.Attribute) and
            isinstance(node.func.value, ast.Str) and
            node.func.attr == 'format' and
            not has_starargs(node)
    ):
        try:
            parsed = parse_format(node.func.value.s)
        except ValueError:
            return

        params = _format_params(node)
        seen = set()
        i = 0  # index of the next auto-numbered (`{}`) field
        for _, name, spec, _ in parsed:
            # timid: difficult to rewrite correctly
            if spec is not None and '{' in spec:
                break
            if name is not None:
                candidate, _, _ = name.partition('.')
                # timid: could make the f-string longer
                if candidate and candidate in seen:
                    break
                # timid: bracketed
                elif '[' in name:
                    break
                seen.add(candidate)

                key = candidate or str(i)
                # their .format() call is broken currently
                if key not in params:
                    break
                if not candidate:
                    i += 1
        else:
            # for/else: only reached when no field caused a `break`;
            # calls containing `await` are only rewritten for 3.7+
            if (
                    state.settings.min_version >= (3, 7) or
                    not contains_await(node)
            ):
                yield ast_to_offset(node), _fix_fstring

0 comments on commit 51d5d11

Please sign in to comment.