Skip to content

Commit

Permalink
Implementing a way to force certain modules/packaged to be fully impo…
Browse files Browse the repository at this point in the history
…rted to avoid naming collisions (such as datetime.date)
  • Loading branch information
KaylaHood committed Nov 13, 2022
1 parent 5161681 commit 5c7ffc2
Showing 1 changed file with 46 additions and 21 deletions.
67 changes: 46 additions & 21 deletions generate_stub.py
Expand Up @@ -2,23 +2,30 @@
import inspect
import pathlib
import re
from typing import Any, Dict, List, Tuple
from typing import Any, Dict, List, Optional, Set, Tuple
from faker.config import AVAILABLE_LOCALES, PROVIDERS
from faker import Factory

import faker.proxy

BUILTIN_MODULES_TO_IGNORE = ["builtins"]
GENERIC_MANGLE_TYPES_TO_IGNORE = ["builtin_function_or_method", "mappingproxy"]
MODULES_TO_FULLY_QUALIFY = ["datetime"]


def get_module_and_member(cls, locale = None) -> Tuple[str, str]:
imports: Dict[str, Optional[Set[str]]] = defaultdict(lambda: None)
imports["typing"] = {"TypeVar"}
imports["enum"] = {"Enum"}


def get_module_and_member_to_import(cls, locale = None) -> Tuple[str, str]:
cls_name = getattr(cls, '__name__', getattr(cls, '_name', str(cls)))
module, member = cls.__module__, cls_name
if cls_name is None:
qualified_type = re.findall(r'([a-zA-Z_0-9]+)\.([a-zA-Z_0-9]+)', str(cls))
if len(qualified_type) > 0:
if qualified_type[0][1] not in imports[qualified_type[0][0]]:
if imports[qualified_type[0][0]] is None \
or qualified_type[0][1] not in imports[qualified_type[0][0]]:
module, member = qualified_type[0]
else:
unqualified_type = re.findall(r'[^\.a-zA-Z0-9_]([A-Z][a-zA-Z0-9_]+)[^\.a-zA-Z0-9_]', ' ' + str(cls) + ' ')
Expand All @@ -28,8 +35,11 @@ def get_module_and_member(cls, locale = None) -> Tuple[str, str]:
cls_str = cls_str.split("'")[1]
if locale is not None:
cls_str = cls_str.replace('.'+locale, '')
if unqualified_type[0] not in imports[cls_str]:

if imports[cls_str] is None or unqualified_type[0] not in imports[cls_str]:
module, member = cls_str, unqualified_type[0]
if module in MODULES_TO_FULLY_QUALIFY:
member = None
return module, member


Expand Down Expand Up @@ -90,10 +100,6 @@ def get_member_functions_and_variables(cls: object, include_mangled: bool = Fals
for var_name_to_remove in overlapping_var_names:
mbr_funcs_and_vars.vars.pop(var_name_to_remove, None)

imports = defaultdict(set)
imports["typing"] = {"TypeVar"}
imports["enum"] = {"Enum"}

# list of tuples. First elem of tuple is the signature string,
# second is the comment string,
# third is a boolean which is True if the comment precedes the signature
Expand All @@ -103,11 +109,14 @@ def get_member_functions_and_variables(cls: object, include_mangled: bool = Fals
for func_name, func_value in mbr_funcs_and_vars.funcs.items():
sig = inspect.signature(func_value)
ret_annot_module = getattr(sig.return_annotation, "__module__", None)
if (not sig.return_annotation in [None, inspect.Signature.empty, prov_cls.__name__]
and not ret_annot_module in [None, *BUILTIN_MODULES_TO_IGNORE]):
module, member = get_module_and_member(sig.return_annotation, locale)
if module is not None and member is not None:
imports[module].add(member)
if (sig.return_annotation not in [None, inspect.Signature.empty, prov_cls.__name__]
and ret_annot_module not in [None, *BUILTIN_MODULES_TO_IGNORE]):
module, member = get_module_and_member_to_import(sig.return_annotation, locale)
if module is not None:
if imports[module] is None:
imports[module] = set() if member is None else {member}
elif member is not None:
imports[module].add(member)

new_parms = []
for key, parm_val in sig.parameters.items():
Expand All @@ -116,29 +125,39 @@ def get_member_functions_and_variables(cls: object, include_mangled: bool = Fals
new_parm = parm_val.replace(default=...)
if (new_parm.annotation is not inspect.Parameter.empty
and not new_parm.annotation.__module__ in BUILTIN_MODULES_TO_IGNORE):
module, member = get_module_and_member(new_parm.annotation, locale)
if module is not None and member is not None:
imports[module].add(member)
module, member = get_module_and_member_to_import(new_parm.annotation, locale)
if module is not None:
if imports[module] is None:
imports[module] = set() if member is None else {member}
elif member is not None:
imports[module].add(member)
new_parms.append(new_parm)

sig = sig.replace(parameters=new_parms)
sig_str = str(sig).replace("Ellipsis", "...").replace("NoneType", "None").replace("~", "")
for module in imports.keys():
if module in MODULES_TO_FULLY_QUALIFY:
continue
sig_str = sig_str.replace(f"{module}.", "")

comment = inspect.getdoc(func_value)
signatures_with_comments.append((f"def {func_name}{sig_str}: ...", None if comment == "" else comment, False))
for var_name, var_value in mbr_funcs_and_vars.vars.items():
new_modules = []
type_module = getattr(type(var_value), "__module__", None)
if type_module is not None and not type_module in BUILTIN_MODULES_TO_IGNORE:
module, member = get_module_and_member(type(var_value), locale)
if module is not None and member is not None:
imports[module].add(member)
if type_module is not None and type_module not in BUILTIN_MODULES_TO_IGNORE:
module, member = get_module_and_member_to_import(type(var_value), locale)
if module is not None:
if imports[module] is None:
imports[module] = set() if member is None else {member}
elif member is not None:
imports[module].add(member)
new_modules.append(module)

type_str = type(var_value).__name__.replace("Ellipsis", "...").replace("NoneType", "None").replace("~", "")
for module in new_modules:
if module in MODULES_TO_FULLY_QUALIFY:
continue
type_str = type_str.replace(f"{module}.", "")

comment = inspect.getcomments(var_value)
Expand All @@ -155,7 +174,13 @@ def get_member_functions_and_variables(cls: object, include_mangled: bool = Fals
else:
signatures_with_comments_as_str.append(sig)

imports_block = "\n".join([f"from {module} import {', '.join(names)}" for module, names in imports.items()])
def get_import_str(module: str, members: Optional[Set[str]]) -> str:
if members is None or len(members) == 0:
return f"import {module}"
else:
return f"from {module} import {', '.join(members)}"

imports_block = "\n".join([get_import_str(module, names) for module, names in imports.items()])
member_signatures_block = " " + "\n ".join([sig.replace("\n", "\n ") for sig in signatures_with_comments_as_str])

body = \
Expand Down

0 comments on commit 5c7ffc2

Please sign in to comment.