Skip to content

Commit

Permalink
fix: Fixed a bug where imports would not check reexports for shortest…
Browse files Browse the repository at this point in the history
… path (#112)

Closes #82 

### Summary of Changes
Fixed the bug where imports would not use the shortest path from
existing reexports.

---------

Co-authored-by: megalinter-bot <129584137+megalinter-bot@users.noreply.github.com>
  • Loading branch information
Masara and megalinter-bot committed May 4, 2024
1 parent cb061ab commit 48c5367
Show file tree
Hide file tree
Showing 35 changed files with 473 additions and 197 deletions.
1 change: 0 additions & 1 deletion src/safeds_stubgen/api_analyzer/_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,6 @@ def __init__(self, distribution: str, package: str, version: str) -> None:
self.enum_instances: dict[str, EnumInstance] = {}
self.attributes_: dict[str, Attribute] = {}
self.parameters_: dict[str, Parameter] = {}

self.reexport_map: dict[str, set[Module]] = defaultdict(set)

def add_module(self, module: Module) -> None:
Expand Down
56 changes: 32 additions & 24 deletions src/safeds_stubgen/api_analyzer/_ast_visitor.py
Original file line number Diff line number Diff line change
Expand Up @@ -129,8 +129,7 @@ def leave_moduledef(self, _: mp_nodes.MypyFile) -> None:
self.api.add_module(module)

def enter_classdef(self, node: mp_nodes.ClassDef) -> None:
name = node.name
id_ = self._create_id_from_stack(name)
id_ = self._create_id_from_stack(node.name)

# Get docstring
docstring = self.docstring_parser.get_class_documentation(node)
Expand Down Expand Up @@ -202,7 +201,9 @@ def enter_classdef(self, node: mp_nodes.ClassDef) -> None:
superclasses.append(superclass_qname)

# Get reexported data
reexported_by = self._get_reexported_by(name)
reexported_by = self._get_reexported_by(node.fullname)
# Sort for snapshot tests
reexported_by.sort(key=lambda x: x.id)

# Get constructor docstring
definitions = get_classdef_definitions(node)
Expand All @@ -215,9 +216,9 @@ def enter_classdef(self, node: mp_nodes.ClassDef) -> None:
# Remember class, so we can later add methods
class_ = Class(
id=id_,
name=name,
name=node.name,
superclasses=superclasses,
is_public=self._is_public(node.name, name),
is_public=self._is_public(node.name, node.fullname),
docstring=docstring,
reexported_by=reexported_by,
constructor_fulldocstring=constructor_fulldocstring,
Expand Down Expand Up @@ -266,7 +267,9 @@ def enter_funcdef(self, node: mp_nodes.FuncDef) -> None:
results = self._parse_results(node, function_id, result_docstrings)

# Get reexported data
reexported_by = self._get_reexported_by(name)
reexported_by = self._get_reexported_by(node.fullname)
# Sort for snapshot tests
reexported_by.sort(key=lambda x: x.id)

# Create and add Function to stack
function = Function(
Expand Down Expand Up @@ -845,25 +848,25 @@ def _get_parameter_type_and_default_value(

# #### Reexport utilities

def _get_reexported_by(self, name: str) -> list[Module]:
# Get the uppermost module and the path to the current node
parents = []
parent = None
i = 1
while not isinstance(parent, Module):
parent = self.__declaration_stack[-i]
if isinstance(parent, list): # pragma: no cover
continue
parents.append(parent.name)
i += 1
path = [*list(reversed(parents)), name]
def _get_reexported_by(self, qname: str) -> list[Module]:
path = qname.split(".")

# Check if there is a reexport entry for each item in the path to the current module
reexported_by = set()
for i in range(len(path)):
reexport_name = ".".join(path[: i + 1])
if reexport_name in self.api.reexport_map:
for mod in self.api.reexport_map[reexport_name]:
reexport_name_forward = ".".join(path[: i + 1])
if reexport_name_forward in self.api.reexport_map:
for mod in self.api.reexport_map[reexport_name_forward]:
reexported_by.add(mod)

reexport_name_backward = ".".join(path[-i - 1 :])
if reexport_name_backward in self.api.reexport_map:
for mod in self.api.reexport_map[reexport_name_backward]:
reexported_by.add(mod)

reexport_name_backward_whitelist = f"{'.'.join(path[-2 - i:-1])}.*"
if reexport_name_backward_whitelist in self.api.reexport_map:
for mod in self.api.reexport_map[reexport_name_backward_whitelist]:
reexported_by.add(mod)

return list(reexported_by)
Expand All @@ -875,7 +878,7 @@ def _add_reexports(self, module: Module) -> None:

for wildcard_import in module.wildcard_imports:
name = wildcard_import.module_name
self.api.reexport_map[name].add(module)
self.api.reexport_map[f"{name}.*"].add(module)

# #### Misc. utilities
def mypy_type_to_abstract_type(
Expand Down Expand Up @@ -1130,7 +1133,12 @@ def _check_publicity_in_reexports(self, name: str, qname: str, parent: Module |
package_id = "/".join(module_qname.split(".")[:-1])

for reexported_key in self.api.reexport_map:
module_is_reexported = reexported_key in {module_name, module_qname}
module_is_reexported = reexported_key in {
module_name,
module_qname,
f"{module_name}.*",
f"{module_qname}.*",
}

# Check if the function/class/module is reexported
if reexported_key.endswith(name) or module_is_reexported:
Expand All @@ -1140,7 +1148,7 @@ def _check_publicity_in_reexports(self, name: str, qname: str, parent: Module |

# We have to check if it's the correct reexport with the ID
is_from_same_package = reexport_source.id == package_id
is_from_another_package = reexported_key in {qname, module_qname}
is_from_another_package = reexported_key.rstrip(".*") in {qname, module_qname}
if not is_from_same_package and not is_from_another_package:
continue

Expand Down

0 comments on commit 48c5367

Please sign in to comment.