Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[isort]: support submodules in known_(first|third)_party config options #3768

Merged
merged 4 commits into from
Mar 29, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
import sys
import baz
from foo import bar, baz
from foo.bar import blah, blub
from foo.bar.baz import something
import foo
import foo.bar
import foo.bar.baz
Original file line number Diff line number Diff line change
Expand Up @@ -178,19 +178,15 @@ pub fn typing_only_runtime_import(
// Extract the module base and level from the full name.
// Ex) `foo.bar.baz` -> `foo`, `0`
// Ex) `.foo.bar.baz` -> `foo`, `1`
let module_base = full_name.split('.').next().unwrap();
let level = full_name.chars().take_while(|c| *c == '.').count();

// Categorize the import.
match categorize(
module_base,
full_name,
Some(&level),
&settings.src,
package,
&settings.isort.known_first_party,
&settings.isort.known_third_party,
&settings.isort.known_local_folder,
&settings.isort.extra_standard_library,
&settings.isort.known_modules,
settings.target_version,
) {
ImportType::LocalFolder | ImportType::FirstParty => Some(Diagnostic::new(
Expand Down
141 changes: 102 additions & 39 deletions crates/ruff/src/rules/isort/categorize.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
use std::collections::{BTreeMap, BTreeSet};
use std::fs;
use std::path::{Path, PathBuf};
use std::{fs, iter};

use log::debug;
use schemars::JsonSchema;
Expand Down Expand Up @@ -52,29 +52,21 @@ enum Reason<'a> {

#[allow(clippy::too_many_arguments)]
pub fn categorize(
module_base: &str,
module_name: &str,
level: Option<&usize>,
src: &[PathBuf],
package: Option<&Path>,
known_first_party: &BTreeSet<String>,
known_third_party: &BTreeSet<String>,
known_local_folder: &BTreeSet<String>,
extra_standard_library: &BTreeSet<String>,
known_modules: &KnownModules,
target_version: PythonVersion,
) -> ImportType {
let module_base = module_name.split('.').next().unwrap();
let (import_type, reason) = {
if level.map_or(false, |level| *level > 0) {
(ImportType::LocalFolder, Reason::NonZeroLevel)
} else if known_first_party.contains(module_base) {
(ImportType::FirstParty, Reason::KnownFirstParty)
} else if known_third_party.contains(module_base) {
(ImportType::ThirdParty, Reason::KnownThirdParty)
} else if known_local_folder.contains(module_base) {
(ImportType::LocalFolder, Reason::KnownLocalFolder)
} else if extra_standard_library.contains(module_base) {
(ImportType::StandardLibrary, Reason::ExtraStandardLibrary)
} else if module_base == "__future__" {
(ImportType::Future, Reason::Future)
} else if let Some((import_type, reason)) = known_modules.categorize(module_name) {
(import_type, reason)
} else if KNOWN_STANDARD_LIBRARY
.get(&target_version.as_tuple())
.unwrap()
Expand All @@ -91,7 +83,7 @@ pub fn categorize(
};
debug!(
"Categorized '{}' as {:?} ({:?})",
module_base, import_type, reason
module_name, import_type, reason
);
import_type
}
Expand Down Expand Up @@ -121,24 +113,18 @@ pub fn categorize_imports<'a>(
block: ImportBlock<'a>,
src: &[PathBuf],
package: Option<&Path>,
known_first_party: &BTreeSet<String>,
known_third_party: &BTreeSet<String>,
known_local_folder: &BTreeSet<String>,
extra_standard_library: &BTreeSet<String>,
known_modules: &KnownModules,
target_version: PythonVersion,
) -> BTreeMap<ImportType, ImportBlock<'a>> {
let mut block_by_type: BTreeMap<ImportType, ImportBlock> = BTreeMap::default();
// Categorize `StmtKind::Import`.
for (alias, comments) in block.import {
let import_type = categorize(
&alias.module_base(),
&alias.module_name(),
None,
src,
package,
known_first_party,
known_third_party,
known_local_folder,
extra_standard_library,
known_modules,
target_version,
);
block_by_type
Expand All @@ -150,14 +136,11 @@ pub fn categorize_imports<'a>(
// Categorize `StmtKind::ImportFrom` (without re-export).
for (import_from, aliases) in block.import_from {
let classification = categorize(
&import_from.module_base(),
&import_from.module_name(),
import_from.level,
src,
package,
known_first_party,
known_third_party,
known_local_folder,
extra_standard_library,
known_modules,
target_version,
);
block_by_type
Expand All @@ -169,14 +152,11 @@ pub fn categorize_imports<'a>(
// Categorize `StmtKind::ImportFrom` (with re-export).
for ((import_from, alias), aliases) in block.import_from_as {
let classification = categorize(
&import_from.module_base(),
&import_from.module_name(),
import_from.level,
src,
package,
known_first_party,
known_third_party,
known_local_folder,
extra_standard_library,
known_modules,
target_version,
);
block_by_type
Expand All @@ -188,14 +168,11 @@ pub fn categorize_imports<'a>(
// Categorize `StmtKind::ImportFrom` (with star).
for (import_from, comments) in block.import_from_star {
let classification = categorize(
&import_from.module_base(),
&import_from.module_name(),
import_from.level,
src,
package,
known_first_party,
known_third_party,
known_local_folder,
extra_standard_library,
known_modules,
target_version,
);
block_by_type
Expand All @@ -206,3 +183,89 @@ pub fn categorize_imports<'a>(
}
block_by_type
}

#[derive(Debug, Default, CacheKey)]
pub struct KnownModules {
/// A set of user-provided first-party modules.
pub first_party: BTreeSet<String>,
/// A set of user-provided third-party modules.
pub third_party: BTreeSet<String>,
/// A set of user-provided local folder modules.
pub local_folder: BTreeSet<String>,
/// A set of user-provided standard library modules.
pub standard_library: BTreeSet<String>,
/// Whether any of the known modules are submodules (e.g., `foo.bar`, as opposed to `foo`).
has_submodules: bool,
}

impl KnownModules {
pub fn new(
first_party: Vec<String>,
third_party: Vec<String>,
local_folder: Vec<String>,
standard_library: Vec<String>,
) -> Self {
let first_party = BTreeSet::from_iter(first_party);
let third_party = BTreeSet::from_iter(third_party);
let local_folder = BTreeSet::from_iter(local_folder);
let standard_library = BTreeSet::from_iter(standard_library);
let has_submodules = first_party
.iter()
.chain(third_party.iter())
.chain(local_folder.iter())
.chain(standard_library.iter())
.any(|m| m.contains('.'));
Self {
first_party,
third_party,
local_folder,
standard_library,
has_submodules,
}
}

/// Return the [`ImportType`] for a given module, if it's been categorized as a known module
/// by the user.
fn categorize(&self, module_name: &str) -> Option<(ImportType, Reason)> {
if self.has_submodules {
// Check all module prefixes from the longest to the shortest (e.g., given
// `foo.bar.baz`, check `foo.bar.baz`, then `foo.bar`, then `foo`, taking the first,
// most precise match).
for i in module_name
.match_indices('.')
.map(|(i, _)| i)
.chain(iter::once(module_name.len()))
.rev()
{
let submodule = &module_name[0..i];
if self.first_party.contains(submodule) {
return Some((ImportType::FirstParty, Reason::KnownFirstParty));
}
if self.third_party.contains(submodule) {
return Some((ImportType::ThirdParty, Reason::KnownThirdParty));
}
if self.local_folder.contains(submodule) {
return Some((ImportType::LocalFolder, Reason::KnownLocalFolder));
}
if self.standard_library.contains(submodule) {
return Some((ImportType::StandardLibrary, Reason::ExtraStandardLibrary));
}
}
None
} else {
// Happy path: no submodules, so we can check the module base and be done.
let module_base = module_name.split('.').next().unwrap();
if self.first_party.contains(module_base) {
Some((ImportType::FirstParty, Reason::KnownFirstParty))
} else if self.third_party.contains(module_base) {
Some((ImportType::ThirdParty, Reason::KnownThirdParty))
} else if self.local_folder.contains(module_base) {
Some((ImportType::LocalFolder, Reason::KnownLocalFolder))
} else if self.standard_library.contains(module_base) {
Some((ImportType::StandardLibrary, Reason::ExtraStandardLibrary))
} else {
None
}
}
}
}