Skip to content

Commit

Permalink
Partial typing of imports.py (pylint-dev#6982)
Browse files Browse the repository at this point in the history
  • Loading branch information
DanielNoord committed Jul 13, 2022
1 parent 460a0c7 commit 65543fd
Show file tree
Hide file tree
Showing 3 changed files with 79 additions and 36 deletions.
4 changes: 2 additions & 2 deletions pylint/checkers/deprecated.py
Original file line number Diff line number Diff line change
Expand Up @@ -189,10 +189,10 @@ def deprecated_classes(self, module: str) -> Iterable[str]:
# pylint: disable=unused-argument
return ()

def check_deprecated_module(self, node: nodes.Import, mod_path: str) -> None:
def check_deprecated_module(self, node: nodes.Import, mod_path: str | None) -> None:
"""Checks if the module is deprecated."""
for mod_name in self.deprecated_modules():
if mod_path == mod_name or mod_path.startswith(mod_name + "."):
if mod_path == mod_name or mod_path and mod_path.startswith(mod_name + "."):
self.add_message("deprecated-module", node=node, args=mod_path)

def check_deprecated_method(self, node: nodes.Call, inferred: nodes.NodeNG) -> None:
Expand Down
108 changes: 75 additions & 33 deletions pylint/checkers/imports.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,10 +10,13 @@
import copy
import os
import sys
from collections import defaultdict
from collections.abc import Sequence
from typing import TYPE_CHECKING, Any

import astroid
from astroid import nodes
from astroid.nodes._base_nodes import ImportNode

from pylint.checkers import BaseChecker, DeprecatedMixin
from pylint.checkers.utils import (
Expand All @@ -28,6 +31,7 @@
from pylint.reporters.ureports.nodes import Paragraph, Section, VerbatimText
from pylint.typing import MessageDefinitionTuple
from pylint.utils import IsortDriver
from pylint.utils.linterstats import LinterStats

if TYPE_CHECKING:
from pylint.lint import PyLinter
Expand Down Expand Up @@ -69,19 +73,26 @@
}


def _qualified_names(modname):
def _qualified_names(modname: str | None) -> list[str]:
"""Split the names of the given module into subparts.
For example,
_qualified_names('pylint.checkers.ImportsChecker')
returns
['pylint', 'pylint.checkers', 'pylint.checkers.ImportsChecker']
"""
names = modname.split(".")
names = modname.split(".") if modname is not None else ""
return [".".join(names[0 : i + 1]) for i in range(len(names))]


def _get_first_import(node, context, name, base, level, alias):
def _get_first_import(
node: ImportNode,
context: nodes.LocalsDictNodeNG,
name: str,
base: str | None,
level: int | None,
alias: str | None,
) -> nodes.Import | nodes.ImportFrom | None:
"""Return the node where [base.]<name> is imported or None if not found."""
fullname = f"{base}.{name}" if base else name

Expand Down Expand Up @@ -116,7 +127,11 @@ def _get_first_import(node, context, name, base, level, alias):
return None


def _ignore_import_failure(node, modname, ignored_modules):
def _ignore_import_failure(
node: ImportNode,
modname: str | None,
ignored_modules: Sequence[str],
) -> bool:
for submodule in _qualified_names(modname):
if submodule in ignored_modules:
return True
Expand Down Expand Up @@ -186,7 +201,7 @@ def _dependencies_graph(filename: str, dep_info: dict[str, set[str]]) -> str:

def _make_graph(
filename: str, dep_info: dict[str, set[str]], sect: Section, gtype: str
):
) -> None:
"""Generate a dependencies graph and add some information about it in the
report's section.
"""
Expand Down Expand Up @@ -403,7 +418,7 @@ class ImportsChecker(DeprecatedMixin, BaseChecker):

def __init__(self, linter: PyLinter) -> None:
BaseChecker.__init__(self, linter)
self.import_graph: collections.defaultdict = collections.defaultdict(set)
self.import_graph: defaultdict[str, set[str]] = defaultdict(set)
self._imports_stack: list[tuple[Any, Any]] = []
self._first_non_import_node = None
self._module_pkg: dict[
Expand All @@ -415,14 +430,14 @@ def __init__(self, linter: PyLinter) -> None:
("RP0402", "Modules dependencies graph", self._report_dependencies_graph),
)

def open(self):
def open(self) -> None:
"""Called before visiting project (i.e set of modules)."""
self.linter.stats.dependencies = {}
self.linter.stats = self.linter.stats
self.import_graph = collections.defaultdict(set)
self.import_graph = defaultdict(set)
self._module_pkg = {} # mapping of modules to the pkg they belong in
self._excluded_edges = collections.defaultdict(set)
self._ignored_modules = self.linter.config.ignored_modules
self._excluded_edges: defaultdict[str, set[str]] = defaultdict(set)
self._ignored_modules: Sequence[str] = self.linter.config.ignored_modules
# Build a mapping {'module': 'preferred-module'}
self.preferred_modules = dict(
module.split(":")
Expand All @@ -431,13 +446,13 @@ def open(self):
)
self._allow_any_import_level = set(self.linter.config.allow_any_import_level)

def _import_graph_without_ignored_edges(self):
def _import_graph_without_ignored_edges(self) -> defaultdict[str, set[str]]:
filtered_graph = copy.deepcopy(self.import_graph)
for node in filtered_graph:
filtered_graph[node].difference_update(self._excluded_edges[node])
return filtered_graph

def close(self):
def close(self) -> None:
"""Called before visiting project (i.e set of modules)."""
if self.linter.is_message_enabled("cyclic-import"):
graph = self._import_graph_without_ignored_edges()
Expand Down Expand Up @@ -536,7 +551,17 @@ def leave_module(self, node: nodes.Module) -> None:
self._imports_stack = []
self._first_non_import_node = None

def compute_first_non_import_node(self, node):
def compute_first_non_import_node(
self,
node: nodes.If
| nodes.Expr
| nodes.Comprehension
| nodes.IfExp
| nodes.Assign
| nodes.AssignAttr
| nodes.TryExcept
| nodes.TryFinally,
) -> None:
# if the node does not contain an import instruction, and if it is the
# first node of the module, keep a track of it (all the import positions
# of the module will be compared to the position of this first
Expand Down Expand Up @@ -576,7 +601,9 @@ def compute_first_non_import_node(self, node):
visit_ifexp
) = visit_comprehension = visit_expr = visit_if = compute_first_non_import_node

def visit_functiondef(self, node: nodes.FunctionDef) -> None:
def visit_functiondef(
self, node: nodes.FunctionDef | nodes.While | nodes.For | nodes.ClassDef
) -> None:
# If it is the first non import instruction of the module, record it.
if self._first_non_import_node:
return
Expand All @@ -598,7 +625,7 @@ def visit_functiondef(self, node: nodes.FunctionDef) -> None:

visit_classdef = visit_for = visit_while = visit_functiondef

def _check_misplaced_future(self, node):
def _check_misplaced_future(self, node: nodes.ImportFrom) -> None:
basename = node.modname
if basename == "__future__":
# check if this is the first non-docstring statement in the module
Expand All @@ -611,15 +638,15 @@ def _check_misplaced_future(self, node):
self.add_message("misplaced-future", node=node)
return

def _check_same_line_imports(self, node):
def _check_same_line_imports(self, node: nodes.ImportFrom) -> None:
# Detect duplicate imports on the same line.
names = (name for name, _ in node.names)
counter = collections.Counter(names)
for name, count in counter.items():
if count > 1:
self.add_message("reimported", node=node, args=(name, node.fromlineno))

def _check_position(self, node):
def _check_position(self, node: ImportNode) -> None:
"""Check `node` import or importfrom node position is correct.
Send a message if `node` comes before another instruction
Expand All @@ -638,7 +665,11 @@ def _check_position(self, node):
"wrong-import-position", node.fromlineno, node
)

def _record_import(self, node, importedmodnode):
def _record_import(
self,
node: ImportNode,
importedmodnode: nodes.Module | None,
) -> None:
"""Record the package `node` imports from."""
if isinstance(node, nodes.ImportFrom):
importedname = node.modname
Expand Down Expand Up @@ -759,7 +790,9 @@ def _check_imports_order(self, _module_node):
)
return std_imports, external_imports, local_imports

def _get_imported_module(self, importnode, modname):
def _get_imported_module(
self, importnode: ImportNode, modname: str | None
) -> nodes.Module | None:
try:
return importnode.do_import_module(modname)
except astroid.TooManyLevelsError:
Expand Down Expand Up @@ -789,9 +822,7 @@ def _get_imported_module(self, importnode, modname):
raise astroid.AstroidError from e
return None

def _add_imported_module(
self, node: nodes.Import | nodes.ImportFrom, importedmodname: str
) -> None:
def _add_imported_module(self, node: ImportNode, importedmodname: str) -> None:
"""Notify an imported module, used to analyze dependencies."""
module_file = node.root().file
context_name = node.root().name
Expand Down Expand Up @@ -841,7 +872,7 @@ def _check_preferred_module(self, node, mod_path):
args=(self.preferred_modules[mod_path], mod_path),
)

def _check_import_as_rename(self, node: nodes.Import | nodes.ImportFrom) -> None:
def _check_import_as_rename(self, node: ImportNode) -> None:
names = node.names
for name in names:
if not all(name):
Expand All @@ -862,7 +893,12 @@ def _check_import_as_rename(self, node: nodes.Import | nodes.ImportFrom) -> None
args=(splitted_packages[0], import_name),
)

def _check_reimport(self, node, basename=None, level=None):
def _check_reimport(
self,
node: ImportNode,
basename: str | None = None,
level: int | None = None,
) -> None:
"""Check if the import is necessary (i.e. not already done)."""
if not self.linter.is_message_enabled("reimported"):
return
Expand All @@ -883,15 +919,19 @@ def _check_reimport(self, node, basename=None, level=None):
"reimported", node=node, args=(name, first.fromlineno)
)

def _report_external_dependencies(self, sect, _, _dummy):
def _report_external_dependencies(
self, sect: Section, _: LinterStats, _dummy: LinterStats | None
) -> None:
"""Return a verbatim layout for displaying dependencies."""
dep_info = _make_tree_defs(self._external_dependencies_info().items())
if not dep_info:
raise EmptyReportError()
tree_str = _repr_tree_defs(dep_info)
sect.append(VerbatimText(tree_str))

def _report_dependencies_graph(self, sect, _, _dummy):
def _report_dependencies_graph(
self, sect: Section, _: LinterStats, _dummy: LinterStats | None
) -> None:
"""Write dependencies as a dot (graphviz) file."""
dep_info = self.linter.stats.dependencies
if not dep_info or not (
Expand All @@ -910,9 +950,9 @@ def _report_dependencies_graph(self, sect, _, _dummy):
if filename:
_make_graph(filename, self._internal_dependencies_info(), sect, "internal ")

def _filter_dependencies_graph(self, internal):
def _filter_dependencies_graph(self, internal: bool) -> defaultdict[str, set[str]]:
"""Build the internal or the external dependency graph."""
graph = collections.defaultdict(set)
graph: defaultdict[str, set[str]] = defaultdict(set)
for importee, importers in self.linter.stats.dependencies.items():
for importer in importers:
package = self._module_pkg.get(importer, importer)
Expand All @@ -922,20 +962,22 @@ def _filter_dependencies_graph(self, internal):
return graph

@astroid.decorators.cached
def _external_dependencies_info(self):
def _external_dependencies_info(self) -> defaultdict[str, set[str]]:
"""Return cached external dependencies information or build and
cache them.
"""
return self._filter_dependencies_graph(internal=False)

@astroid.decorators.cached
def _internal_dependencies_info(self):
def _internal_dependencies_info(self) -> defaultdict[str, set[str]]:
"""Return cached internal dependencies information or build and
cache them.
"""
return self._filter_dependencies_graph(internal=True)

def _check_wildcard_imports(self, node, imported_module):
def _check_wildcard_imports(
self, node: nodes.ImportFrom, imported_module: nodes.Module | None
) -> None:
if node.root().package:
# Skip the check if in __init__.py issue #2026
return
Expand All @@ -945,14 +987,14 @@ def _check_wildcard_imports(self, node, imported_module):
if name == "*" and not wildcard_import_is_allowed:
self.add_message("wildcard-import", args=node.modname, node=node)

def _wildcard_import_is_allowed(self, imported_module):
def _wildcard_import_is_allowed(self, imported_module: nodes.Module | None) -> bool:
return (
self.linter.config.allow_wildcard_with_all
and imported_module is not None
and "__all__" in imported_module.locals
)

def _check_toplevel(self, node):
def _check_toplevel(self, node: ImportNode) -> None:
"""Check whether the import is made outside the module toplevel."""
# If the scope of the import is a module, then obviously it is
# not outside the module toplevel.
Expand Down
3 changes: 2 additions & 1 deletion pylint/checkers/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
from astroid import TooManyLevelsError, nodes
from astroid.context import InferenceContext
from astroid.exceptions import AstroidError
from astroid.nodes._base_nodes import ImportNode

if TYPE_CHECKING:
from pylint.checkers import BaseChecker
Expand Down Expand Up @@ -1651,7 +1652,7 @@ def get_subscript_const_value(node: nodes.Subscript) -> nodes.Const:
return inferred


def get_import_name(importnode: nodes.Import | nodes.ImportFrom, modname: str) -> str:
def get_import_name(importnode: ImportNode, modname: str | None) -> str | None:
"""Get a prepared module name from the given import node.
In the case of relative imports, this will return the
Expand Down

0 comments on commit 65543fd

Please sign in to comment.