-
Notifications
You must be signed in to change notification settings - Fork 176
/
gather.py
160 lines (119 loc) · 4.6 KB
/
gather.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import inspect
from collections import defaultdict
from collections.abc import Sequence as ABCSequence
from dataclasses import dataclass, fields, replace
from typing import Dict, Iterator, List, Mapping, Sequence, Set, Type, Union
import libcst as cst
def _get_bases() -> Iterator[Type[cst.CSTNode]]:
    """
    Yield every attribute exported by ``libcst`` whose name begins with
    ``Base``.

    Selection is done purely by name prefix; no class or subclass check is
    performed here. The intent is to collect the abstract typeclasses that
    concrete nodes derive from, so that generated code can refer to those
    shared base types instead of the concrete nodes.
    """
    base_names = (name for name in dir(cst) if name.startswith("Base"))
    for base_name in base_names:
        yield getattr(cst, base_name)
# All "Base*" typeclasses exported by libcst, ordered alphabetically by class
# name so that anything derived from this list is deterministic.
typeclasses: Sequence[Type[cst.CSTNode]] = sorted(
    _get_bases(),
    key=lambda typeclass: typeclass.__name__,
)
def _get_nodes() -> Iterator[Type[cst.CSTNode]]:
    """
    Yield every class exported by ``libcst`` that derives from ``CSTNode``.

    Dunder attributes and ``CSTNode`` itself are skipped, as is any export
    that is not a class at all. "Base*" typeclasses that derive from
    ``CSTNode`` are still yielded.
    """
    for name in dir(cst):
        is_dunder = name.startswith("__") and name.endswith("__")
        if is_dunder or name == "CSTNode":
            continue
        candidate = getattr(cst, name)
        try:
            is_node = issubclass(candidate, cst.CSTNode)
        except TypeError:
            # issubclass() rejects non-class objects (functions, modules,
            # constants, ...); those are of no interest here.
            continue
        if is_node:
            yield candidate
# Every CSTNode subclass exported by libcst (CSTNode itself excluded), ordered
# alphabetically by class name.
all_libcst_nodes: Sequence[Type[cst.CSTNode]] = sorted(
    _get_nodes(),
    key=lambda node_type: node_type.__name__,
)
# For each exported node, the CSTNode-derived classes found in its MRO,
# ordered from the most generic ancestor down to the node itself.
node_to_bases: Dict[Type[cst.CSTNode], List[Type[cst.CSTNode]]] = {
    node_type: [
        ancestor
        for ancestor in reversed(inspect.getmro(node_type))
        if issubclass(ancestor, cst.CSTNode)
    ]
    for node_type in all_libcst_nodes
}
def _get_most_generic_base_for_node(node: Type[cst.CSTNode]) -> Type[cst.CSTNode]:
    """
    Return the most generic exported ancestor of ``node``.

    ``node_to_bases[node]`` is ordered most-generic-first, so the first entry
    that is itself a key of ``node_to_bases`` is the answer. Ancestors that
    are not exported are skipped because a user could not name them in a
    type hint.

    Raises ``KeyError`` if ``node`` is not a key of ``node_to_bases`` and
    ``IndexError`` if none of its ancestors is exported.
    """
    ancestry = node_to_bases[node]
    exported = [candidate for candidate in ancestry if candidate in node_to_bases]
    return exported[0]
# Maps each exported node to the most generic exported class it derives from.
# CSTNode itself is never a candidate because it is not part of
# all_libcst_nodes.
nodebases: Dict[Type[cst.CSTNode], Type[cst.CSTNode]] = {
    node_type: _get_most_generic_base_for_node(node_type)
    for node_type in all_libcst_nodes
}
@dataclass(frozen=True)
class Usage:
    """
    Immutable summary of the contexts in which a node type appears within
    the field annotations of libcst nodes. All flags start out False.
    """

    # Meant to be True when the node type appears in a Union that also
    # contains MaybeSentinel.
    maybe: bool = False
    # Meant to be True when the node type appears in a Union that also
    # admits None.
    optional: bool = False
    # True when the node type appears as a type argument of a Sequence.
    sequence: bool = False
# Per-node usage flags; every node starts with all flags unset and receives
# updated copies as field annotations are analysed.
nodeuses: Dict[Type[cst.CSTNode], Usage] = {
    node_type: Usage() for node_type in all_libcst_nodes
}
def _is_maybe(typeobj: object) -> bool:
    """
    Return True when ``typeobj`` is ``cst.MaybeSentinel`` or a subclass of
    it, and False for anything else, including objects that are not classes.
    """
    try:
        # pyre-ignore We wrap this in a TypeError check so this is safe
        is_sentinel = issubclass(typeobj, cst.MaybeSentinel)
    except TypeError:
        # issubclass() refuses non-class arguments.
        is_sentinel = False
    return is_sentinel
def _get_origin(typeobj: object) -> object:
    """
    Return the ``__origin__`` attribute of ``typeobj``, or None when the
    object has no such attribute (i.e. it is not a union or sequence type).
    """
    return getattr(typeobj, "__origin__", None)
def _get_args(typeobj: object) -> List[object]:
    """
    Return the ``__args__`` attribute of ``typeobj`` unchanged, or a new
    empty list when the object has no such attribute (i.e. it is not a union
    or sequence type).
    """
    return getattr(typeobj, "__args__", [])
def _is_sequence(typeobj: object) -> bool:
    """
    Return True when the origin of ``typeobj`` is ``typing.Sequence`` or
    ``collections.abc.Sequence``.
    """
    origin = _get_origin(typeobj)
    if origin is Sequence:
        return True
    return origin is ABCSequence
def _is_union(typeobj: object) -> bool:
    """
    Return True when the origin of ``typeobj`` is ``typing.Union``.
    """
    origin = _get_origin(typeobj)
    return origin is Union
def _calc_node_usage(typeobj: object) -> None:
    """
    Record in ``nodeuses`` how libcst node types are used inside the field
    annotation ``typeobj``.

    For a ``Union``, every member that is a known libcst node is flagged
    ``maybe`` when the union also contains a ``MaybeSentinel`` type and
    ``optional`` when the union also admits ``None``. For a ``Sequence``,
    every type argument that is a known libcst node is flagged ``sequence``.
    Type arguments that are not known nodes are analysed recursively so that
    nested annotations are covered. Flags are only ever switched on, never
    off. Any other kind of annotation is ignored.
    """
    args = _get_args(typeobj)
    if _is_union(typeobj):
        has_maybe = any(_is_maybe(arg) for arg in args)
        # typing normalizes a ``None`` union member to the ``NoneType``
        # class, so the member has to be compared against the class itself;
        # an isinstance() check only matches the ``None`` instance and would
        # never fire for ``Optional[X]`` / ``Union[X, None]``.
        none_type = type(None)
        has_none = any(arg is None or arg is none_type for arg in args)
        for arg in args:
            if arg in all_libcst_nodes:
                current = nodeuses[arg]
                nodeuses[arg] = replace(
                    current,
                    maybe=current.maybe or has_maybe,
                    optional=current.optional or has_none,
                )
            else:
                _calc_node_usage(arg)
    elif _is_sequence(typeobj):
        for arg in args:
            if arg in all_libcst_nodes:
                nodeuses[arg] = replace(nodeuses[arg], sequence=True)
            else:
                _calc_node_usage(arg)
# Analyse the annotation of every dataclass field on every node, except the
# "_metadata" field.
for node in all_libcst_nodes:
    for field in fields(node):
        if field.name != "_metadata":
            _calc_node_usage(field.type)
# Module name -> class names to import from it: each node whose name does
# not start with "Base", together with its most generic base.
imports: Mapping[str, Set[str]] = defaultdict(set)
for node, base in nodebases.items():
    if not node.__name__.startswith("Base"):
        for x in (node, base):
            imports[x.__module__].add(x.__name__)