/
_data.py
124 lines (96 loc) · 3.21 KB
/
_data.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
import ast
import collections
import importlib
import pkgutil
from typing import Callable
from typing import Dict
from typing import Iterable
from typing import List
from typing import NamedTuple
from typing import Set
from typing import Tuple
from typing import Type
from typing import TYPE_CHECKING
from typing import TypeVar

from tokenize_rt import Offset
from tokenize_rt import Token

from pyupgrade import _plugins
# ``typing.Protocol`` only exists on Python 3.8+, so the real class is
# imported for the type checker alone; at runtime plain ``object`` stands in
# as the base class for the structural interfaces declared in this module.
if not TYPE_CHECKING:
    Protocol = object
else:
    from typing import Protocol

# A version written as a tuple of ints, e.g. ``(2, 7)``.
Version = Tuple[int, ...]
class Settings(NamedTuple):
    """Immutable options bundle passed to ``visit`` and stored on ``State``.

    None of these fields are read in this module, so the per-field notes
    below are inferred from the names only; confirm against the plugins
    that consume them.
    """

    # Presumably the minimum Python version to target (e.g. ``(3, 6)``).
    min_version: Version = (2, 7)
    # Presumably: when True, leave ``%``-style string formatting untouched.
    keep_percent_format: bool = False
    # Presumably: when True, leave ``mock`` usages untouched.
    keep_mock: bool = False
    # Presumably: when True, leave runtime-evaluated typing constructs alone.
    keep_runtime_typing: bool = False
class State(NamedTuple):
    """Traversal context handed to every AST callback by ``visit``."""

    # The ``Settings`` object given to ``visit``.
    settings: Settings
    # Module name -> names brought in by ``from <module> import <name>``.
    # ``visit`` only records absolute imports of modules listed in
    # ``RECORD_FROM_IMPORTS`` and skips aliased (``as``) names. A single dict
    # is shared by every State derived during one ``visit`` call (``_replace``
    # copies the reference), so names recorded so far are visible to all
    # nodes visited afterwards.
    from_imports: Dict[str, Set[str]]
    # True for every node beneath an ``annotation`` or ``returns`` field.
    in_annotation: bool = False
# Type variable for a concrete ``ast.AST`` subclass; ``register`` uses it to
# tie a callback's node parameter to the node class it is registered for.
AST_T = TypeVar('AST_T', bound=ast.AST)
# Token-level rewrite callback, called with an ``int`` and a list of tokens
# and returning nothing.
# NOTE(review): presumably the int is the index of the token at the recorded
# offset and the list is edited in place — confirm against the caller.
TokenFunc = Callable[[int, List[Token]], None]
# AST-level callback: called by ``visit`` as ``(state, node, parent)`` and
# yielding ``(offset, token_func)`` pairs, which ``visit`` groups by offset.
ASTFunc = Callable[[State, AST_T, ast.AST], Iterable[Tuple[Offset, TokenFunc]]]
# Modules whose ``from <module> import <name>`` statements ``visit`` tracks:
# for absolute imports of these modules, each non-aliased imported name is
# added to ``State.from_imports[<module>]``.
RECORD_FROM_IMPORTS = frozenset({
    '__future__',
    'functools',
    'mmap',
    'select',
    'six',
    'socket',
    'sys',
    'typing',
})
# Registry: AST node class -> callbacks registered for it via ``@register``.
# Being a ``defaultdict(list)``, looking up an unregistered class yields
# (and stores) an empty list instead of raising.
FUNCS = collections.defaultdict(list)


def register(tp: Type[AST_T]) -> Callable[[ASTFunc[AST_T]], ASTFunc[AST_T]]:
    """Build a decorator that registers an AST callback for node class *tp*.

    The decorated function is appended to ``FUNCS[tp]`` and handed back
    unchanged, so it stays usable under its own name.
    """
    def decorator(func: ASTFunc[AST_T]) -> ASTFunc[AST_T]:
        callbacks = FUNCS[tp]
        callbacks.append(func)
        return func

    return decorator
class ASTCallbackMapping(Protocol):
    """Structural type for the ``funcs`` argument of ``visit``.

    Anything indexable by an AST node class that returns the list of
    callbacks for that class fits. Presumably ``FUNCS`` is what callers
    pass here; confirm at the call site.
    """

    def __getitem__(self, tp: Type[AST_T]) -> List[ASTFunc[AST_T]]:
        """Return the callbacks registered for AST node class *tp*."""
def visit(
        funcs: ASTCallbackMapping,
        tree: ast.Module,
        settings: Settings,
) -> Dict[Offset, List[TokenFunc]]:
    """Walk *tree* and collect token callbacks produced by AST callbacks.

    Nodes are visited depth-first in pre-order, with each node's children in
    ``_fields`` order, using an explicit stack rather than recursion. For
    every node, each callback in ``funcs[type(node)]`` is called as
    ``callback(state, node, parent)``; the root's parent is *tree* itself.
    The ``(offset, token_func)`` pairs it yields are appended, in order, to
    the list for that offset.

    After a node's callbacks run, an absolute ``from X import ...`` whose
    module is in ``RECORD_FROM_IMPORTS`` adds its non-aliased names to the
    shared ``state.from_imports[X]``. The import's own callbacks therefore
    do not see those names, but every node visited later does. Subtrees
    under an ``annotation`` or ``returns`` field receive a state with
    ``in_annotation=True``.

    Returns a ``defaultdict(list)`` mapping offset -> token callbacks.
    """
    callbacks = collections.defaultdict(list)
    root_state = State(
        settings=settings,
        from_imports=collections.defaultdict(set),
    )
    stack: List[Tuple[State, ast.AST, ast.AST]] = [(root_state, tree, tree)]

    while stack:
        state, node, parent = stack.pop()

        for ast_func in funcs[type(node)]:
            for offset, token_func in ast_func(state, node, parent):
                callbacks[offset].append(token_func)

        if (
                isinstance(node, ast.ImportFrom) and
                not node.level and
                node.module in RECORD_FROM_IMPORTS
        ):
            state.from_imports[node.module].update(
                alias.name for alias in node.names if not alias.asname
            )

        # Push children in reverse so they are popped in ``_fields`` order.
        for field_name in reversed(node._fields):
            child = getattr(node, field_name)
            child_state = (
                state._replace(in_annotation=True)
                if field_name in ('annotation', 'returns')
                else state
            )
            if isinstance(child, ast.AST):
                stack.append((child_state, child, node))
            elif isinstance(child, list):
                stack.extend(
                    (child_state, item, node)
                    for item in reversed(child)
                    if isinstance(item, ast.AST)
                )

    return callbacks
def _import_plugins() -> None:
    """Import every module found (recursively) under the ``_plugins`` package.

    Presumably this is what populates ``FUNCS``: plugin modules' ``@register``
    decorators run at import time. Confirm against the plugin modules.
    """
    # A package's ``__path__`` is an iterable of directory strings (not a
    # single ``str``); the ignore works around mypy not typing module
    # ``__path__``: https://github.com/python/mypy/issues/1422
    plugins_path: Iterable[str] = _plugins.__path__  # type: ignore
    mod_infos = pkgutil.walk_packages(plugins_path, f'{_plugins.__name__}.')
    for _, name, _ in mod_infos:
        # ``import_module`` is the documented replacement for the
        # ``__import__(name, fromlist=[...])`` trick for importing the leaf
        # module by dotted name.
        importlib.import_module(name)


_import_plugins()