diff --git a/crates/ruff/resources/test/fixtures/pylint/sys_exit_alias_1.py b/crates/ruff/resources/test/fixtures/pylint/sys_exit_alias_1.py index 006b7909f20344..af245b8d47de3a 100644 --- a/crates/ruff/resources/test/fixtures/pylint/sys_exit_alias_1.py +++ b/crates/ruff/resources/test/fixtures/pylint/sys_exit_alias_1.py @@ -9,4 +9,8 @@ def main(): quit(1) -sys.exit(2) +def main(): + sys = 1 + + exit(1) + quit(1) diff --git a/crates/ruff/src/autofix/helpers.rs b/crates/ruff/src/autofix/helpers.rs index a46553d4d218b6..119d25824e7f63 100644 --- a/crates/ruff/src/autofix/helpers.rs +++ b/crates/ruff/src/autofix/helpers.rs @@ -7,6 +7,7 @@ use rustpython_parser::ast::{ExcepthandlerKind, Expr, Keyword, Location, Stmt, S use rustpython_parser::{lexer, Mode, Tok}; use ruff_diagnostics::Edit; +use ruff_python_ast::context::Context; use ruff_python_ast::helpers; use ruff_python_ast::helpers::to_absolute; use ruff_python_ast::newlines::NewlineWithTrailingNewline; @@ -14,6 +15,8 @@ use ruff_python_ast::source_code::{Indexer, Locator, Stylist}; use crate::cst::helpers::compose_module_path; use crate::cst::matchers::match_module; +use crate::imports::importer::Importer; +use crate::imports::{AnyImport, Import}; /// Determine if a body contains only a single statement, taking into account /// deleted. @@ -444,6 +447,66 @@ pub fn remove_argument( } } +/// Generate an [`Edit`] to reference the given symbol. Returns the [`Edit`] necessary to make the +/// symbol available in the current scope along with the bound name of the symbol. +/// +/// For example, assuming `module` is `"functools"` and `member` is `"lru_cache"`, this function +/// could return an [`Edit`] to add `import functools` to the top of the file, alongside with the +/// name on which the `lru_cache` symbol would be made available (`"functools.lru_cache"`). +/// +/// Attempts to reuse existing imports when possible. +pub fn get_or_import_symbol( + module: &str, + member: &str, + context: &Context, + importer: &Importer, + locator: &Locator, +) -> Result<(Edit, String)> { + if let Some((source, binding)) = context.resolve_binding(module, member) { + // If the symbol is already available in the current scope, use it, and add a no-nop edit to + // force conflicts with any other fixes that might try to remove the import. + let import_edit = Edit::replacement( + locator.slice(source).to_string(), + source.location, + source.end_location.unwrap(), + ); + Ok((import_edit, binding)) + } else { + if let Some(stmt) = importer.get_import_from(module) { + // Case 1: `from functools import lru_cache` is in scope, and we're trying to reference + // `functools.cache`; thus, we add `cache` to the import, and return `"cache"` as the + // bound name. + if context + .find_binding(member) + .map_or(true, |binding| binding.kind.is_builtin()) + { + let import_edit = importer.add_member(stmt, member)?; + Ok((import_edit, member.to_string())) + } else { + bail!( + "Unable to insert `{}` into scope due to name conflict", + member + ) + } + } else { + // Case 2: No `functools` import is in scope; thus, we add `import functools`, and + // return `"functools.lru_cache"` as the bound name. + if context + .find_binding(module) + .map_or(true, |binding| binding.kind.is_builtin()) + { + let import_edit = importer.add_import(&AnyImport::Import(Import::module(module))); + Ok((import_edit, format!("{module}.{member}"))) + } else { + bail!( + "Unable to insert `{}` into scope due to name conflict", + module + ) + } + } + } +} + #[cfg(test)] mod tests { use anyhow::Result; diff --git a/crates/ruff/src/checkers/ast/mod.rs b/crates/ruff/src/checkers/ast/mod.rs index cc46c4a798004a..d9c035d58d6b4d 100644 --- a/crates/ruff/src/checkers/ast/mod.rs +++ b/crates/ruff/src/checkers/ast/mod.rs @@ -38,6 +38,7 @@ use crate::docstrings::definition::{ transition_scope, Definition, DefinitionKind, Docstring, Documentable, }; use crate::fs::relativize_path; +use crate::imports::importer::Importer; use crate::registry::{AsRule, Rule}; use crate::rules::{ flake8_2020, flake8_annotations, flake8_bandit, flake8_blind_except, flake8_boolean_trap, @@ -69,6 +70,7 @@ pub struct Checker<'a> { pub locator: &'a Locator<'a>, pub stylist: &'a Stylist<'a>, pub indexer: &'a Indexer, + pub importer: Importer<'a>, // Stateful fields. pub ctx: Context<'a>, pub deferred: Deferred<'a>, @@ -91,6 +93,7 @@ impl<'a> Checker<'a> { locator: &'a Locator, style: &'a Stylist, indexer: &'a Indexer, + importer: Importer<'a>, ) -> Checker<'a> { Checker { settings, @@ -104,6 +107,7 @@ impl<'a> Checker<'a> { locator, stylist: style, indexer, + importer, ctx: Context::new(&settings.typing_modules, path, module_path), deferred: Deferred::default(), diagnostics: vec![], @@ -188,6 +192,20 @@ where } } } + + // Track each top-level import, to guide import insertions. + if matches!( + &stmt.node, + StmtKind::Import { .. } | StmtKind::ImportFrom { .. } + ) { + let scope_index = self.ctx.scope_id(); + if scope_index.is_global() { + if self.ctx.current_stmt_parent().is_none() { + self.importer.visit_import(stmt); + } + } + } + // Pre-visit. match &stmt.node { StmtKind::Global { names } => { @@ -5387,6 +5405,7 @@ pub fn check_ast( locator, stylist, indexer, + Importer::new(python_ast, locator, stylist), ); checker.bind_builtins(); diff --git a/crates/ruff/src/cst/matchers.rs b/crates/ruff/src/cst/matchers.rs index 9a5ea83a38b9cc..fbebdca0a82c85 100644 --- a/crates/ruff/src/cst/matchers.rs +++ b/crates/ruff/src/cst/matchers.rs @@ -1,7 +1,7 @@ use anyhow::{bail, Result}; use libcst_native::{ - Attribute, Call, Comparison, Dict, Expr, Expression, Import, ImportFrom, Module, SimpleString, - SmallStatement, Statement, + Attribute, Call, Comparison, Dict, Expr, Expression, Import, ImportAlias, ImportFrom, + ImportNames, Module, SimpleString, SmallStatement, Statement, }; pub fn match_module(module_text: &str) -> Result { @@ -54,6 +54,16 @@ pub fn match_import_from<'a, 'b>(module: &'a mut Module<'b>) -> Result<&'a mut I } } +pub fn match_aliases<'a, 'b>( + import_from: &'a mut ImportFrom<'b>, +) -> Result<&'a mut Vec>> { + if let ImportNames::Aliases(aliases) = &mut import_from.names { + Ok(aliases) + } else { + bail!("Expected ImportNames::Aliases") + } +} + pub fn match_call<'a, 'b>(expression: &'a mut Expression<'b>) -> Result<&'a mut Call<'b>> { if let Expression::Call(call) = expression { Ok(call) diff --git a/crates/ruff/src/imports/importer.rs b/crates/ruff/src/imports/importer.rs new file mode 100644 index 00000000000000..04cef1474ce856 --- /dev/null +++ b/crates/ruff/src/imports/importer.rs @@ -0,0 +1,217 @@ +//! Add and modify import statements to make symbols available during code generation. + +use anyhow::Result; +use libcst_native::{Codegen, CodegenState, ImportAlias, Name, NameOrAttribute}; +use rustc_hash::FxHashMap; +use rustpython_parser::ast::{Location, Stmt, StmtKind, Suite}; +use rustpython_parser::{lexer, Mode, Tok}; + +use ruff_diagnostics::Edit; +use ruff_python_ast::helpers::is_docstring_stmt; +use ruff_python_ast::source_code::{Locator, Stylist}; + +use crate::cst::matchers::{match_aliases, match_import_from, match_module}; +use crate::imports::AnyImport; + +pub struct Importer<'a> { + python_ast: &'a Suite, + locator: &'a Locator<'a>, + stylist: &'a Stylist<'a>, + /// A map from module name to top-level `StmtKind::ImportFrom` statements. + import_from_map: FxHashMap<&'a str, &'a Stmt>, + /// The last top-level import statement. + trailing_import: Option<&'a Stmt>, +} + +impl<'a> Importer<'a> { + pub fn new(python_ast: &'a Suite, locator: &'a Locator<'a>, stylist: &'a Stylist<'a>) -> Self { + Self { + python_ast, + locator, + stylist, + import_from_map: FxHashMap::default(), + trailing_import: None, + } + } + + /// Visit a top-level import statement. + pub fn visit_import(&mut self, import: &'a Stmt) { + // Store a reference to the import statement in the appropriate map. + match &import.node { + StmtKind::Import { .. } => { + // Nothing to do here, we don't extend top-level `import` statements at all. + } + StmtKind::ImportFrom { module, level, .. } => { + // Store a reverse-map from module name to `import ... from` statement. + if level.map_or(true, |level| level == 0) { + if let Some(module) = module { + self.import_from_map.insert(module.as_str(), import); + } + } + } + _ => { + unreachable!("Expected StmtKind::Import | StmtKind::ImportFrom"); + } + } + + // Store a reference to the last top-level import statement. + self.trailing_import = Some(import); + } + + /// Add an import statement to import the given module. + /// + /// If there are no existing imports, the new import will be added at the top + /// of the file. Otherwise, it will be added after the most recent top-level + /// import statement. + pub fn add_import(&self, import: &AnyImport) -> Edit { + let required_import = import.to_string(); + if let Some(stmt) = self.trailing_import { + // Insert after the last top-level import. + let (prefix, location, suffix) = + end_of_statement_insertion(stmt, self.locator, self.stylist); + let content = format!("{prefix}{required_import}{suffix}"); + Edit::insertion(content, location) + } else { + // Insert at the top of the file. + let (prefix, location, suffix) = + top_of_file_insertion(self.python_ast, self.locator, self.stylist); + let content = format!("{prefix}{required_import}{suffix}"); + Edit::insertion(content, location) + } + } + + /// Return the top-level [`Stmt`] that imports the given module using `StmtKind::ImportFrom`. + /// if it exists. + pub fn get_import_from(&self, module: &str) -> Option<&Stmt> { + self.import_from_map.get(module).copied() + } + + /// Add the given member to an existing `StmtKind::ImportFrom` statement. + pub fn add_member(&self, stmt: &Stmt, member: &str) -> Result { + let mut tree = match_module(self.locator.slice(stmt))?; + let import_from = match_import_from(&mut tree)?; + let aliases = match_aliases(import_from)?; + aliases.push(ImportAlias { + name: NameOrAttribute::N(Box::new(Name { + value: member, + lpar: vec![], + rpar: vec![], + })), + asname: None, + comma: aliases.last().and_then(|alias| alias.comma.clone()), + }); + let mut state = CodegenState { + default_newline: self.stylist.line_ending(), + default_indent: self.stylist.indentation(), + ..CodegenState::default() + }; + tree.codegen(&mut state); + Ok(Edit::replacement( + state.to_string(), + stmt.location, + stmt.end_location.unwrap(), + )) + } +} + +/// Find the end of the last docstring. +fn match_docstring_end(body: &[Stmt]) -> Option { + let mut iter = body.iter(); + let Some(mut stmt) = iter.next() else { + return None; + }; + if !is_docstring_stmt(stmt) { + return None; + } + for next in iter { + if !is_docstring_stmt(next) { + break; + } + stmt = next; + } + Some(stmt.end_location.unwrap()) +} + +/// Find the location at which a "top-of-file" import should be inserted, +/// along with a prefix and suffix to use for the insertion. +/// +/// For example, given the following code: +/// +/// ```python +/// """Hello, world!""" +/// +/// import os +/// ``` +/// +/// The location returned will be the start of the `import os` statement, +/// along with a trailing newline suffix. +pub fn end_of_statement_insertion( + stmt: &Stmt, + locator: &Locator, + stylist: &Stylist, +) -> (&'static str, Location, &'static str) { + let location = stmt.end_location.unwrap(); + let mut tokens = lexer::lex_located(locator.skip(location), Mode::Module, location).flatten(); + if let Some((.., Tok::Semi, end)) = tokens.next() { + // If the first token after the docstring is a semicolon, insert after the semicolon as an + // inline statement; + (" ", end, ";") + } else { + // Otherwise, insert on the next line. + ( + "", + Location::new(location.row() + 1, 0), + stylist.line_ending().as_str(), + ) + } +} + +/// Find the location at which a "top-of-file" import should be inserted, +/// along with a prefix and suffix to use for the insertion. +/// +/// For example, given the following code: +/// +/// ```python +/// """Hello, world!""" +/// +/// import os +/// ``` +/// +/// The location returned will be the start of the `import os` statement, +/// along with a trailing newline suffix. +pub fn top_of_file_insertion( + body: &[Stmt], + locator: &Locator, + stylist: &Stylist, +) -> (&'static str, Location, &'static str) { + // Skip over any docstrings. + let mut location = if let Some(location) = match_docstring_end(body) { + let mut tokens = lexer::lex_located(locator.skip(location), Mode::Module, location) + .flatten() + .peekable(); + + // If the first token after the docstring is a semicolon, insert after the semicolon as an + // inline statement; + if let Some((.., Tok::Semi, end)) = tokens.peek() { + return (" ", *end, ";"); + } + + // Otherwise, advance to the next row. + Location::new(location.row() + 1, 0) + } else { + Location::default() + }; + + // Skip over any comments and empty lines. + for (.., tok, end) in + lexer::lex_located(locator.skip(location), Mode::Module, location).flatten() + { + if matches!(tok, Tok::Comment(..) | Tok::Newline) { + location = Location::new(end.row() + 1, 0); + } else { + break; + } + } + + return ("", location, stylist.line_ending().as_str()); +} diff --git a/crates/ruff/src/imports/mod.rs b/crates/ruff/src/imports/mod.rs new file mode 100644 index 00000000000000..3fd6b93d59fe4a --- /dev/null +++ b/crates/ruff/src/imports/mod.rs @@ -0,0 +1,71 @@ +use std::fmt; + +pub mod importer; + +#[derive(Debug, PartialOrd, Ord, PartialEq, Eq, Hash)] +struct Alias<'a> { + name: &'a str, + as_name: Option<&'a str>, +} + +#[derive(Debug, PartialOrd, Ord, PartialEq, Eq, Hash)] +pub struct ImportFrom<'a> { + module: Option<&'a str>, + name: Alias<'a>, + level: Option<&'a usize>, +} + +#[derive(Debug, PartialOrd, Ord, PartialEq, Eq, Hash)] +pub struct Import<'a> { + name: Alias<'a>, +} + +impl<'a> Import<'a> { + pub fn module(name: &'a str) -> Self { + Self { + name: Alias { + name, + as_name: None, + }, + } + } +} + +#[derive(Debug, PartialOrd, Ord, PartialEq, Eq, Hash)] +pub enum AnyImport<'a> { + Import(Import<'a>), + ImportFrom(ImportFrom<'a>), +} + +impl fmt::Display for ImportFrom<'_> { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + write!(f, "from ")?; + if let Some(level) = self.level { + write!(f, "{}", ".".repeat(*level))?; + } + if let Some(module) = self.module { + write!(f, "{module}")?; + } + write!(f, " import {}", self.name.name)?; + Ok(()) + } +} + +impl fmt::Display for Import<'_> { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + write!(f, "import {}", self.name.name)?; + if let Some(as_name) = self.name.as_name { + write!(f, " as {as_name}")?; + } + Ok(()) + } +} + +impl fmt::Display for AnyImport<'_> { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + match self { + AnyImport::Import(import) => write!(f, "{import}"), + AnyImport::ImportFrom(import_from) => write!(f, "{import_from}"), + } + } +} diff --git a/crates/ruff/src/lib.rs b/crates/ruff/src/lib.rs index 7736b7a7ebd2e5..7052ccbb5125d1 100644 --- a/crates/ruff/src/lib.rs +++ b/crates/ruff/src/lib.rs @@ -19,6 +19,7 @@ mod doc_lines; mod docstrings; pub mod flake8_to_ruff; pub mod fs; +mod imports; pub mod jupyter; mod lex; pub mod linter; diff --git a/crates/ruff/src/rules/pylint/rules/sys_exit_alias.rs b/crates/ruff/src/rules/pylint/rules/sys_exit_alias.rs index ae024c737521d5..b0db6067a4fe84 100644 --- a/crates/ruff/src/rules/pylint/rules/sys_exit_alias.rs +++ b/crates/ruff/src/rules/pylint/rules/sys_exit_alias.rs @@ -1,9 +1,10 @@ use rustpython_parser::ast::{Expr, ExprKind}; -use ruff_diagnostics::{AutofixKind, Diagnostic, Edit, Violation}; +use ruff_diagnostics::{AutofixKind, Diagnostic, Edit, Fix, Violation}; use ruff_macros::{derive_message_formats, violation}; use ruff_python_ast::types::Range; +use crate::autofix::helpers::get_or_import_symbol; use crate::checkers::ast::Checker; use crate::registry::AsRule; @@ -45,13 +46,18 @@ pub fn sys_exit_alias(checker: &mut Checker, func: &Expr) { Range::from(func), ); if checker.patch(diagnostic.kind.rule()) { - if let Some(binding) = checker.ctx.resolve_binding("sys", "exit") { - diagnostic.set_fix(Edit::replacement( - binding, - func.location, - func.end_location.unwrap(), - )); - } + diagnostic.try_set_fix(|| { + let (import_edit, binding) = get_or_import_symbol( + "sys", + "exit", + &checker.ctx, + &checker.importer, + checker.locator, + )?; + let reference_edit = + Edit::replacement(binding, func.location, func.end_location.unwrap()); + Ok(Fix::from_iter([import_edit, reference_edit])) + }); } checker.diagnostics.push(diagnostic); } diff --git a/crates/ruff/src/rules/pylint/snapshots/ruff__rules__pylint__tests__PLR1722_sys_exit_alias_0.py.snap b/crates/ruff/src/rules/pylint/snapshots/ruff__rules__pylint__tests__PLR1722_sys_exit_alias_0.py.snap index 1ac50fdd49fd5a..c0ef38362472d9 100644 --- a/crates/ruff/src/rules/pylint/snapshots/ruff__rules__pylint__tests__PLR1722_sys_exit_alias_0.py.snap +++ b/crates/ruff/src/rules/pylint/snapshots/ruff__rules__pylint__tests__PLR1722_sys_exit_alias_0.py.snap @@ -14,7 +14,21 @@ expression: diagnostics row: 1 column: 4 fix: - edits: [] + edits: + - content: "import sys\n" + location: + row: 1 + column: 0 + end_location: + row: 1 + column: 0 + - content: sys.exit + location: + row: 1 + column: 0 + end_location: + row: 1 + column: 4 parent: ~ - kind: name: SysExitAlias @@ -28,7 +42,21 @@ expression: diagnostics row: 2 column: 4 fix: - edits: [] + edits: + - content: "import sys\n" + location: + row: 1 + column: 0 + end_location: + row: 1 + column: 0 + - content: sys.exit + location: + row: 2 + column: 0 + end_location: + row: 2 + column: 4 parent: ~ - kind: name: SysExitAlias @@ -42,7 +70,21 @@ expression: diagnostics row: 6 column: 8 fix: - edits: [] + edits: + - content: "import sys\n" + location: + row: 1 + column: 0 + end_location: + row: 1 + column: 0 + - content: sys.exit + location: + row: 6 + column: 4 + end_location: + row: 6 + column: 8 parent: ~ - kind: name: SysExitAlias @@ -56,6 +98,20 @@ expression: diagnostics row: 7 column: 8 fix: - edits: [] + edits: + - content: "import sys\n" + location: + row: 1 + column: 0 + end_location: + row: 1 + column: 0 + - content: sys.exit + location: + row: 7 + column: 4 + end_location: + row: 7 + column: 8 parent: ~ diff --git a/crates/ruff/src/rules/pylint/snapshots/ruff__rules__pylint__tests__PLR1722_sys_exit_alias_1.py.snap b/crates/ruff/src/rules/pylint/snapshots/ruff__rules__pylint__tests__PLR1722_sys_exit_alias_1.py.snap index 9194663f6cb258..6bc29cbca87b5a 100644 --- a/crates/ruff/src/rules/pylint/snapshots/ruff__rules__pylint__tests__PLR1722_sys_exit_alias_1.py.snap +++ b/crates/ruff/src/rules/pylint/snapshots/ruff__rules__pylint__tests__PLR1722_sys_exit_alias_1.py.snap @@ -15,6 +15,13 @@ expression: diagnostics column: 4 fix: edits: + - content: import sys + location: + row: 1 + column: 0 + end_location: + row: 1 + column: 10 - content: sys.exit location: row: 3 @@ -36,6 +43,13 @@ expression: diagnostics column: 4 fix: edits: + - content: import sys + location: + row: 1 + column: 0 + end_location: + row: 1 + column: 10 - content: sys.exit location: row: 4 @@ -57,6 +71,13 @@ expression: diagnostics column: 8 fix: edits: + - content: import sys + location: + row: 1 + column: 0 + end_location: + row: 1 + column: 10 - content: sys.exit location: row: 8 @@ -78,6 +99,13 @@ expression: diagnostics column: 8 fix: edits: + - content: import sys + location: + row: 1 + column: 0 + end_location: + row: 1 + column: 10 - content: sys.exit location: row: 9 @@ -86,4 +114,32 @@ expression: diagnostics row: 9 column: 8 parent: ~ +- kind: + name: SysExitAlias + body: "Use `sys.exit()` instead of `exit`" + suggestion: "Replace `exit` with `sys.exit()`" + fixable: true + location: + row: 15 + column: 4 + end_location: + row: 15 + column: 8 + fix: + edits: [] + parent: ~ +- kind: + name: SysExitAlias + body: "Use `sys.exit()` instead of `quit`" + suggestion: "Replace `quit` with `sys.exit()`" + fixable: true + location: + row: 16 + column: 4 + end_location: + row: 16 + column: 8 + fix: + edits: [] + parent: ~ diff --git a/crates/ruff/src/rules/pylint/snapshots/ruff__rules__pylint__tests__PLR1722_sys_exit_alias_2.py.snap b/crates/ruff/src/rules/pylint/snapshots/ruff__rules__pylint__tests__PLR1722_sys_exit_alias_2.py.snap index aaae64b9832658..9de290bf6f9406 100644 --- a/crates/ruff/src/rules/pylint/snapshots/ruff__rules__pylint__tests__PLR1722_sys_exit_alias_2.py.snap +++ b/crates/ruff/src/rules/pylint/snapshots/ruff__rules__pylint__tests__PLR1722_sys_exit_alias_2.py.snap @@ -15,6 +15,13 @@ expression: diagnostics column: 4 fix: edits: + - content: import sys as sys2 + location: + row: 1 + column: 0 + end_location: + row: 1 + column: 18 - content: sys2.exit location: row: 3 @@ -36,6 +43,13 @@ expression: diagnostics column: 4 fix: edits: + - content: import sys as sys2 + location: + row: 1 + column: 0 + end_location: + row: 1 + column: 18 - content: sys2.exit location: row: 4 @@ -57,6 +71,13 @@ expression: diagnostics column: 8 fix: edits: + - content: import sys as sys2 + location: + row: 1 + column: 0 + end_location: + row: 1 + column: 18 - content: sys2.exit location: row: 8 @@ -78,6 +99,13 @@ expression: diagnostics column: 8 fix: edits: + - content: import sys as sys2 + location: + row: 1 + column: 0 + end_location: + row: 1 + column: 18 - content: sys2.exit location: row: 9 diff --git a/crates/ruff/src/rules/pylint/snapshots/ruff__rules__pylint__tests__PLR1722_sys_exit_alias_3.py.snap b/crates/ruff/src/rules/pylint/snapshots/ruff__rules__pylint__tests__PLR1722_sys_exit_alias_3.py.snap index 52d2f6d5a4f9c7..801a625462e482 100644 --- a/crates/ruff/src/rules/pylint/snapshots/ruff__rules__pylint__tests__PLR1722_sys_exit_alias_3.py.snap +++ b/crates/ruff/src/rules/pylint/snapshots/ruff__rules__pylint__tests__PLR1722_sys_exit_alias_3.py.snap @@ -15,6 +15,13 @@ expression: diagnostics column: 4 fix: edits: + - content: from sys import exit + location: + row: 1 + column: 0 + end_location: + row: 1 + column: 20 - content: exit location: row: 4 @@ -36,6 +43,13 @@ expression: diagnostics column: 8 fix: edits: + - content: from sys import exit + location: + row: 1 + column: 0 + end_location: + row: 1 + column: 20 - content: exit location: row: 9 diff --git a/crates/ruff/src/rules/pylint/snapshots/ruff__rules__pylint__tests__PLR1722_sys_exit_alias_4.py.snap b/crates/ruff/src/rules/pylint/snapshots/ruff__rules__pylint__tests__PLR1722_sys_exit_alias_4.py.snap index 73ad483c429046..3c39d1915ed355 100644 --- a/crates/ruff/src/rules/pylint/snapshots/ruff__rules__pylint__tests__PLR1722_sys_exit_alias_4.py.snap +++ b/crates/ruff/src/rules/pylint/snapshots/ruff__rules__pylint__tests__PLR1722_sys_exit_alias_4.py.snap @@ -15,6 +15,13 @@ expression: diagnostics column: 4 fix: edits: + - content: from sys import exit as exit2 + location: + row: 1 + column: 0 + end_location: + row: 1 + column: 29 - content: exit2 location: row: 3 @@ -36,6 +43,13 @@ expression: diagnostics column: 4 fix: edits: + - content: from sys import exit as exit2 + location: + row: 1 + column: 0 + end_location: + row: 1 + column: 29 - content: exit2 location: row: 4 @@ -57,6 +71,13 @@ expression: diagnostics column: 8 fix: edits: + - content: from sys import exit as exit2 + location: + row: 1 + column: 0 + end_location: + row: 1 + column: 29 - content: exit2 location: row: 8 @@ -78,6 +99,13 @@ expression: diagnostics column: 8 fix: edits: + - content: from sys import exit as exit2 + location: + row: 1 + column: 0 + end_location: + row: 1 + column: 29 - content: exit2 location: row: 9 diff --git a/crates/ruff/src/rules/pylint/snapshots/ruff__rules__pylint__tests__PLR1722_sys_exit_alias_6.py.snap b/crates/ruff/src/rules/pylint/snapshots/ruff__rules__pylint__tests__PLR1722_sys_exit_alias_6.py.snap index 1f55fd9ccf21c7..4d1e19362f690f 100644 --- a/crates/ruff/src/rules/pylint/snapshots/ruff__rules__pylint__tests__PLR1722_sys_exit_alias_6.py.snap +++ b/crates/ruff/src/rules/pylint/snapshots/ruff__rules__pylint__tests__PLR1722_sys_exit_alias_6.py.snap @@ -14,7 +14,21 @@ expression: diagnostics row: 1 column: 4 fix: - edits: [] + edits: + - content: "import sys\n" + location: + row: 1 + column: 0 + end_location: + row: 1 + column: 0 + - content: sys.exit + location: + row: 1 + column: 0 + end_location: + row: 1 + column: 4 parent: ~ - kind: name: SysExitAlias @@ -28,6 +42,20 @@ expression: diagnostics row: 2 column: 4 fix: - edits: [] + edits: + - content: "import sys\n" + location: + row: 1 + column: 0 + end_location: + row: 1 + column: 0 + - content: sys.exit + location: + row: 2 + column: 0 + end_location: + row: 2 + column: 4 parent: ~ diff --git a/crates/ruff/src/rules/pyupgrade/rules/lru_cache_with_maxsize_none.rs b/crates/ruff/src/rules/pyupgrade/rules/lru_cache_with_maxsize_none.rs index db81ad64012593..db512c3c53f595 100644 --- a/crates/ruff/src/rules/pyupgrade/rules/lru_cache_with_maxsize_none.rs +++ b/crates/ruff/src/rules/pyupgrade/rules/lru_cache_with_maxsize_none.rs @@ -1,9 +1,10 @@ use rustpython_parser::ast::{Constant, Expr, ExprKind, KeywordData}; -use ruff_diagnostics::{AlwaysAutofixableViolation, Diagnostic, Edit}; +use ruff_diagnostics::{AlwaysAutofixableViolation, Diagnostic, Edit, Fix}; use ruff_macros::{derive_message_formats, violation}; use ruff_python_ast::types::Range; +use crate::autofix::helpers::get_or_import_symbol; use crate::checkers::ast::Checker; use crate::registry::AsRule; @@ -57,13 +58,18 @@ pub fn lru_cache_with_maxsize_none(checker: &mut Checker, decorator_list: &[Expr Range::new(func.end_location.unwrap(), expr.end_location.unwrap()), ); if checker.patch(diagnostic.kind.rule()) { - if let Some(binding) = checker.ctx.resolve_binding("functools", "cache") { - diagnostic.set_fix(Edit::replacement( - binding, - expr.location, - expr.end_location.unwrap(), - )); - } + diagnostic.try_set_fix(|| { + let (import_edit, binding) = get_or_import_symbol( + "functools", + "cache", + &checker.ctx, + &checker.importer, + checker.locator, + )?; + let reference_edit = + Edit::replacement(binding, expr.location, expr.end_location.unwrap()); + Ok(Fix::from_iter([import_edit, reference_edit])) + }); } checker.diagnostics.push(diagnostic); } diff --git a/crates/ruff/src/rules/pyupgrade/snapshots/ruff__rules__pyupgrade__tests__UP033_0.py.snap b/crates/ruff/src/rules/pyupgrade/snapshots/ruff__rules__pyupgrade__tests__UP033_0.py.snap index 6e0425c16e5df1..97069eddb00bab 100644 --- a/crates/ruff/src/rules/pyupgrade/snapshots/ruff__rules__pyupgrade__tests__UP033_0.py.snap +++ b/crates/ruff/src/rules/pyupgrade/snapshots/ruff__rules__pyupgrade__tests__UP033_0.py.snap @@ -15,6 +15,13 @@ expression: diagnostics column: 34 fix: edits: + - content: import functools + location: + row: 1 + column: 0 + end_location: + row: 1 + column: 16 - content: functools.cache location: row: 4 @@ -36,6 +43,13 @@ expression: diagnostics column: 34 fix: edits: + - content: import functools + location: + row: 1 + column: 0 + end_location: + row: 1 + column: 16 - content: functools.cache location: row: 10 @@ -57,6 +71,13 @@ expression: diagnostics column: 34 fix: edits: + - content: import functools + location: + row: 1 + column: 0 + end_location: + row: 1 + column: 16 - content: functools.cache location: row: 15 diff --git a/crates/ruff/src/rules/pyupgrade/snapshots/ruff__rules__pyupgrade__tests__UP033_1.py.snap b/crates/ruff/src/rules/pyupgrade/snapshots/ruff__rules__pyupgrade__tests__UP033_1.py.snap index 41b265ae328ea9..1eee466d682d81 100644 --- a/crates/ruff/src/rules/pyupgrade/snapshots/ruff__rules__pyupgrade__tests__UP033_1.py.snap +++ b/crates/ruff/src/rules/pyupgrade/snapshots/ruff__rules__pyupgrade__tests__UP033_1.py.snap @@ -14,7 +14,21 @@ expression: diagnostics row: 4 column: 24 fix: - edits: [] + edits: + - content: "from functools import lru_cache, cache" + location: + row: 1 + column: 0 + end_location: + row: 1 + column: 31 + - content: cache + location: + row: 4 + column: 1 + end_location: + row: 4 + column: 24 parent: ~ - kind: name: LRUCacheWithMaxsizeNone @@ -28,7 +42,21 @@ expression: diagnostics row: 10 column: 24 fix: - edits: [] + edits: + - content: "from functools import lru_cache, cache" + location: + row: 1 + column: 0 + end_location: + row: 1 + column: 31 + - content: cache + location: + row: 10 + column: 1 + end_location: + row: 10 + column: 24 parent: ~ - kind: name: LRUCacheWithMaxsizeNone @@ -42,6 +70,20 @@ expression: diagnostics row: 15 column: 24 fix: - edits: [] + edits: + - content: "from functools import lru_cache, cache" + location: + row: 1 + column: 0 + end_location: + row: 1 + column: 31 + - content: cache + location: + row: 15 + column: 1 + end_location: + row: 15 + column: 24 parent: ~ diff --git a/crates/ruff_python_ast/src/context.rs b/crates/ruff_python_ast/src/context.rs index daee5269866865..c49a18026908cc 100644 --- a/crates/ruff_python_ast/src/context.rs +++ b/crates/ruff_python_ast/src/context.rs @@ -217,10 +217,11 @@ impl<'a> Context<'a> { /// ``` /// /// ...then `resolve_binding("sys", "version_info")` will return `Some("python_version")`. - pub fn resolve_binding(&self, module: &str, member: &str) -> Option { + pub fn resolve_binding(&self, module: &str, member: &str) -> Option<(&Stmt, String)> { self.scopes().enumerate().find_map(|(scope_index, scope)| { scope.binding_ids().find_map(|binding_index| { - match &self.bindings[*binding_index].kind { + let binding = &self.bindings[*binding_index]; + match &binding.kind { // Ex) Given `module="sys"` and `object="exit"`: // `import sys` -> `sys.exit` // `import sys as sys2` -> `sys2.exit` @@ -232,7 +233,10 @@ impl<'a> Context<'a> { .take(scope_index) .all(|scope| scope.get(name).is_none()) { - return Some(format!("{name}.{member}")); + return Some(( + binding.source.as_ref().unwrap().into(), + format!("{name}.{member}"), + )); } } } @@ -248,7 +252,10 @@ impl<'a> Context<'a> { .take(scope_index) .all(|scope| scope.get(name).is_none()) { - return Some((*name).to_string()); + return Some(( + binding.source.as_ref().unwrap().into(), + (*name).to_string(), + )); } } } @@ -263,7 +270,10 @@ impl<'a> Context<'a> { .take(scope_index) .all(|scope| scope.get(name).is_none()) { - return Some(format!("{name}.{member}")); + return Some(( + binding.source.as_ref().unwrap().into(), + format!("{name}.{member}"), + )); } } }