Skip to content

Commit

Permalink
Removing variables that overlap with functions of same name (such as …
Browse files Browse the repository at this point in the history
…country_code) so that the function signature is included in the stub
  • Loading branch information
KaylaHood committed Nov 13, 2022
1 parent b63e226 commit d555b36
Showing 1 changed file with 59 additions and 19 deletions.
78 changes: 59 additions & 19 deletions generate_stub.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,10 +30,40 @@ def get_module_and_member(cls, locale = None) -> Tuple[str, str]:
return module, member


def get_members(cls: object, include_mangled: bool = False) -> List[Tuple[str, Tuple[type, Any]]]:
members = [(name, (cls, value)) for (name, value) in inspect.getmembers(cls)
seen_funcs = set()
seen_vars = set()


class UniqueMemberFunctionsAndVariables:
def __init__(self, cls: type, funcs: Dict[str, Any], vars: Dict[str, Any]):
global seen_funcs, seen_vars
self.cls = cls
self.funcs = funcs
for func_name in seen_funcs:
self.funcs.pop(func_name, None)
seen_funcs = seen_funcs.union(self.funcs.keys())

self.vars = vars
for var_name in seen_vars:
self.vars.pop(var_name, None)
seen_vars = seen_vars.union(self.vars.keys())


def get_member_functions_and_variables(cls: object, include_mangled: bool = False) \
-> UniqueMemberFunctionsAndVariables:
members = [(name, value) for (name, value) in inspect.getmembers(cls)
if ((include_mangled and name.startswith("__")) or not name.startswith("_"))]
return members
funcs: Dict[str, Any] = {}
vars: Dict[str, Any] = {}
for (name, value) in members:
attr = getattr(cls, name, None)
if attr is not None and (inspect.isfunction(attr) or inspect.ismethod(attr)):
funcs[name] = value
else:
vars[name] = value

return UniqueMemberFunctionsAndVariables(cls, funcs, vars)


classes_and_locales_to_use_for_stub: List[Tuple[object, str]] = []
for locale in AVAILABLE_LOCALES:
Expand All @@ -43,30 +73,35 @@ def get_members(cls: object, include_mangled: bool = False) -> List[Tuple[str, T
prov_cls, _, _ = Factory._find_provider_class(provider, locale)
classes_and_locales_to_use_for_stub.append((prov_cls, locale))

unique_members = {mbr[0]: (*mbr[1], locale)
for cls, locale in classes_and_locales_to_use_for_stub
for mbr in get_members(cls)}
all_members: List[Tuple[UniqueMemberFunctionsAndVariables, str]] = \
[(get_member_functions_and_variables(cls), locale) for cls, locale in classes_and_locales_to_use_for_stub] \
+ [(get_member_functions_and_variables(faker.Faker, include_mangled=True), None)]

unique_members.update({mbr[0]: (*mbr[1], None) for mbr in get_members(faker.Faker, include_mangled=True)})
# Use the accumulated seen_funcs and seen_vars to remove all variables that have the same name as a function somewhere
overlapping_var_names = seen_vars.intersection(seen_funcs)
for mbr_funcs_and_vars, _ in all_members:
for var_name_to_remove in overlapping_var_names:
mbr_funcs_and_vars.vars.pop(var_name_to_remove, None)

imports = defaultdict(set)
imports["typing"] = {"TypeVar"}
imports["enum"] = {"Enum"}

# list of tuples. First elem of tuple is the signature string,
# second is the comment string,
# third is a boolean which is True if the comment precedes the signature
signatures_with_comments: List[Tuple[str, str, bool]] = []
for name, (prov_cls, value, locale) in unique_members.items():
attr = getattr(prov_cls, name, None)
if attr is not None and (inspect.isfunction(attr) or inspect.ismethod(attr)):
sig = inspect.signature(value)
comment = inspect.getdoc(value)

for mbr_funcs_and_vars, locale in all_members:
for func_name, func_value in mbr_funcs_and_vars.funcs.items():
sig = inspect.signature(func_value)
ret_annot_module = getattr(sig.return_annotation, "__module__", None)
if (not sig.return_annotation in [None, inspect.Signature.empty, prov_cls.__name__]
and not ret_annot_module in [None, "builtins"]):
module, member = get_module_and_member(sig.return_annotation, locale)
if module is not None and member is not None:
imports[module].add(member)

new_parms = []
for key, parm_val in sig.parameters.items():
new_parm = parm_val
Expand All @@ -78,24 +113,29 @@ def get_members(cls: object, include_mangled: bool = False) -> List[Tuple[str, T
if module is not None and member is not None:
imports[module].add(member)
new_parms.append(new_parm)

sig = sig.replace(parameters=new_parms)
sig_str = str(sig).replace("Ellipsis", "...").replace("NoneType", "None").replace("~", "")
for module in imports.keys():
sig_str = sig_str.replace(f"{module}.", "")
signatures_with_comments.append((f"def {name}{sig_str}: ...", None if comment == "" else comment, False))
else:

comment = inspect.getdoc(func_value)
signatures_with_comments.append((f"def {func_name}{sig_str}: ...", None if comment == "" else comment, False))
for var_name, var_value in mbr_funcs_and_vars.vars.items():
new_modules = []
type_module = getattr(type(value), "__module__", None)
comment = inspect.getcomments(value)
type_module = getattr(type(var_value), "__module__", None)
if type_module is not None and type_module != "builtins":
module, member = get_module_and_member(type(value), locale)
module, member = get_module_and_member(type(var_value), locale)
if module is not None and member is not None:
imports[module].add(member)
new_modules.append(module)
type_str = type(value).__name__.replace("Ellipsis", "...").replace("NoneType", "None").replace("~", "")

type_str = type(var_value).__name__.replace("Ellipsis", "...").replace("NoneType", "None").replace("~", "")
for module in new_modules:
type_str = type_str.replace(f"{module}.", "")
signatures_with_comments.append((f"{name}: {type_str}", comment, True))

comment = inspect.getcomments(var_value)
signatures_with_comments.append((f"{var_name}: {type_str}", comment, True))

signatures_with_comments_as_str = []
for sig, comment, is_preceding_comment in signatures_with_comments:
Expand Down

0 comments on commit d555b36

Please sign in to comment.