/
_data.py
125 lines (96 loc) · 3.16 KB
/
_data.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
from __future__ import annotations

import ast
import collections
import importlib
import pkgutil
from typing import Callable
from typing import Iterable
from typing import List
from typing import NamedTuple
from typing import Tuple
from typing import TYPE_CHECKING
from typing import TypeVar

from tokenize_rt import Offset
from tokenize_rt import Token

from pyupgrade import _plugins
# ``typing.Protocol`` is only imported for static type checkers; at runtime
# plain ``object`` stands in as the base class (presumably so the module
# still imports where ``typing`` lacks ``Protocol`` — confirm).
if not TYPE_CHECKING:
    Protocol = object
else:
    from typing import Protocol

# A version number as a tuple of ints, e.g. ``(3,)``.
Version = Tuple[int, ...]
class Settings(NamedTuple):
    """Options for one run; every ``State`` carries one as ``settings``.

    All fields have defaults, so ``Settings()`` is a complete configuration.
    The flags are read by code outside this module, so their exact effect
    is not visible here.
    """

    # Tuple of ints (see ``Version``); presumably the oldest Python version
    # the rewritten code must still support.
    min_version: Version = (3,)
    # Opt-out switches, all off by default; each presumably disables the
    # rewrite family its name refers to — confirm against the plugins.
    keep_percent_format: bool = False
    keep_mock: bool = False
    keep_runtime_typing: bool = False
class State(NamedTuple):
    """Immutable traversal context passed to every AST callback by ``visit``.

    ``visit`` derives child states with ``_replace``, which copies the
    reference to ``from_imports``; all states of one traversal therefore
    share a single mapping.
    """

    settings: Settings
    # module name -> names brought in via ``from <module> import <name>``
    from_imports: dict[str, set[str]]
    # True while walking an ``annotation`` / ``returns`` subtree
    in_annotation: bool = False
# Type variable standing for one concrete ``ast.AST`` subclass.
AST_T = TypeVar('AST_T', bound=ast.AST)

# Token-level callback: receives an int (presumably the index of the token
# at the callback's offset — confirm) and the list of tokens, returns None
# (so it presumably edits the list in place).
TokenFunc = Callable[
    [int, List[Token]],
    None,
]

# AST-level callback: receives (state, node, parent) and yields
# (Offset, TokenFunc) pairs, which ``visit`` groups by Offset.
ASTFunc = Callable[
    [State, AST_T, ast.AST],
    Iterable[Tuple[Offset, TokenFunc]],
]
# Modules whose absolute ``from <module> import <name>`` statements are
# tracked: ``visit`` adds each non-aliased imported name to
# ``state.from_imports[<module>]``.
RECORD_FROM_IMPORTS = frozenset({
    '__future__',
    'functools',
    'mmap',
    'select',
    'six',
    'six.moves',
    'socket',
    'subprocess',
    'sys',
    'typing',
    'typing_extensions',
})
# Registry of AST callbacks keyed by node class. As a defaultdict(list),
# looking up a class with nothing registered yields an empty list.
FUNCS = collections.defaultdict(list)


def register(tp: type[AST_T]) -> Callable[[ASTFunc[AST_T]], ASTFunc[AST_T]]:
    """Build a decorator that files a callback under node class ``tp``.

    Applying the decorator appends the function to ``FUNCS[tp]`` and hands
    the function back unchanged, so it stays usable under its own name.
    """
    def _add(func: ASTFunc[AST_T]) -> ASTFunc[AST_T]:
        # Look up at decoration time so merely calling register() does not
        # create an entry in FUNCS.
        registered = FUNCS[tp]
        registered.append(func)
        return func

    return _add
class ASTCallbackMapping(Protocol):
    """Structural type: lookup from an AST node class to its callbacks.

    Only ``__getitem__`` is required. ``visit`` indexes the mapping with
    ``type(node)`` for every node it reaches and iterates the result, so a
    class without callbacks should yield an empty list rather than raise
    (a ``collections.defaultdict(list)`` has that shape).
    """

    def __getitem__(self, tp: type[AST_T]) -> list[ASTFunc[AST_T]]:
        ...
def visit(
        funcs: ASTCallbackMapping,
        tree: ast.Module,
        settings: Settings,
) -> dict[Offset, list[TokenFunc]]:
    """Walk ``tree`` and collect the token callbacks produced by ``funcs``.

    Nodes are visited depth-first in pre-order using an explicit stack.
    Children are pushed in reverse so they pop in ``_fields`` / list order.
    For each node, every callback in ``funcs[type(node)]`` is called with
    ``(state, node, parent)``, and each ``(offset, token_func)`` pair it
    yields is appended under ``offset`` in the result. The module node is
    passed as its own parent.

    Absolute ``from <module> import <name>`` statements for modules in
    ``RECORD_FROM_IMPORTS`` add their non-aliased names to
    ``state.from_imports``. This happens after the ImportFrom node's own
    callbacks run. The mapping is shared by every state, so nodes visited
    later see those names. Subtrees reached through an ``annotation`` or
    ``returns`` field are visited with ``in_annotation=True``.

    :param funcs: lookup from AST node class to its callbacks.
    :param tree: parsed module to walk.
    :param settings: stored on every ``State`` handed to callbacks.
    :returns: offset -> token callbacks in discovery order (a
        ``collections.defaultdict(list)``).
    """
    root = State(
        settings=settings,
        from_imports=collections.defaultdict(set),
    )
    stack: list[tuple[State, ast.AST, ast.AST]] = [(root, tree, tree)]
    found = collections.defaultdict(list)

    while stack:
        state, node, parent = stack.pop()

        for ast_func in funcs[type(node)]:
            for offset, token_func in ast_func(state, node, parent):
                found[offset].append(token_func)

        if (
                isinstance(node, ast.ImportFrom) and
                not node.level and
                node.module in RECORD_FROM_IMPORTS
        ):
            names = state.from_imports[node.module]
            names.update(
                alias.name for alias in node.names if not alias.asname
            )

        for field in reversed(node._fields):
            value = getattr(node, field)
            if isinstance(value, ast.AST):
                children = [value]
            elif isinstance(value, list):
                children = value[::-1]
            else:
                continue

            if field in {'annotation', 'returns'}:
                child_state = state._replace(in_annotation=True)
            else:
                child_state = state

            stack.extend(
                (child_state, child, node)
                for child in children
                if isinstance(child, ast.AST)
            )

    return found
def _import_plugins() -> None:
    """Import every module and subpackage under ``pyupgrade._plugins``.

    The modules are imported only for their side effects. Presumably each
    plugin registers its AST callbacks via ``register`` when imported;
    confirm in the plugin modules. ``pkgutil.walk_packages`` recurses into
    subpackages.
    """
    prefix = f'{_plugins.__name__}.'
    for mod_info in pkgutil.walk_packages(_plugins.__path__, prefix):
        # ``import_module`` is the documented way to import by dotted name.
        # Unlike ``__import__`` with a dummy ``fromlist``, it does not also
        # probe packages for a nonexistent submodule.
        importlib.import_module(mod_info.name)


# Load every plugin as a side effect of importing this module.
_import_plugins()