diff --git a/crates/ruff_linter/resources/test/fixtures/perflint/PERF101.py b/crates/ruff_linter/resources/test/fixtures/perflint/PERF101.py index e6ae0b8f25d75..e624930ff2e04 100644 --- a/crates/ruff_linter/resources/test/fixtures/perflint/PERF101.py +++ b/crates/ruff_linter/resources/test/fixtures/perflint/PERF101.py @@ -36,35 +36,47 @@ ): # PERF101 pass -for i in list(foo_dict): # Ok +for i in list(foo_dict): # OK pass -for i in list(1): # Ok +for i in list(1): # OK pass -for i in list(foo_int): # Ok +for i in list(foo_int): # OK pass import itertools -for i in itertools.product(foo_int): # Ok +for i in itertools.product(foo_int): # OK pass -for i in list(foo_list): # Ok +for i in list(foo_list): # OK foo_list.append(i + 1) for i in list(foo_list): # PERF101 # Make sure we match the correct list other_list.append(i + 1) -for i in list(foo_tuple): # Ok +for i in list(foo_tuple): # OK foo_tuple.append(i + 1) -for i in list(foo_set): # Ok +for i in list(foo_set): # OK foo_set.append(i + 1) x, y, nested_tuple = (1, 2, (3, 4, 5)) for i in list(nested_tuple): # PERF101 pass + +for i in list(foo_list): # OK + if True: + foo_list.append(i + 1) + +for i in list(foo_list): # OK + if True: + foo_list[i] = i + 1 + +for i in list(foo_list): # OK + if True: + del foo_list[i + 1] diff --git a/crates/ruff_linter/src/rules/perflint/rules/unnecessary_list_cast.rs b/crates/ruff_linter/src/rules/perflint/rules/unnecessary_list_cast.rs index 7ff1d544b392b..d6676a0798bcf 100644 --- a/crates/ruff_linter/src/rules/perflint/rules/unnecessary_list_cast.rs +++ b/crates/ruff_linter/src/rules/perflint/rules/unnecessary_list_cast.rs @@ -1,5 +1,6 @@ use ruff_diagnostics::{AlwaysFixableViolation, Diagnostic, Edit, Fix}; use ruff_macros::{derive_message_formats, violation}; +use ruff_python_ast::statement_visitor::{walk_stmt, StatementVisitor}; use ruff_python_ast::{self as ast, Arguments, Expr, Stmt}; use ruff_python_semantic::analyze::typing::find_assigned_value; use ruff_text_size::TextRange; @@ -98,22 +99,25 @@ pub(crate) fn unnecessary_list_cast(checker: &mut Checker, iter: &Expr, body: &[ range: iterable_range, .. }) => { - // If the variable is being appended to, don't suggest removing the cast: - // - // ```python - // items = ["foo", "bar"] - // for item in list(items): - // items.append("baz") - // ``` - // - // Here, removing the `list()` cast would change the behavior of the code. - if body.iter().any(|stmt| match_append(stmt, id)) { - return; - } let Some(value) = find_assigned_value(id, checker.semantic()) else { return; }; if matches!(value, Expr::Tuple(_) | Expr::List(_) | Expr::Set(_)) { + // If the variable is being modified to, don't suggest removing the cast: + // + // ```python + // items = ["foo", "bar"] + // for item in list(items): + // items.append("baz") + // ``` + // + // Here, removing the `list()` cast would change the behavior of the code. + let mut visitor = MutationVisitor::new(id); + visitor.visit_body(body); + if visitor.is_mutated { + return; + } + let mut diagnostic = Diagnostic::new(UnnecessaryListCast, *list_range); diagnostic.set_fix(remove_cast(*list_range, *iterable_range)); checker.diagnostics.push(diagnostic); @@ -123,28 +127,6 @@ pub(crate) fn unnecessary_list_cast(checker: &mut Checker, iter: &Expr, body: &[ } } -/// Check if a statement is an `append` call to a given identifier. -/// -/// For example, `foo.append(bar)` would return `true` if `id` is `foo`. -fn match_append(stmt: &Stmt, id: &str) -> bool { - let Some(ast::StmtExpr { value, .. }) = stmt.as_expr_stmt() else { - return false; - }; - let Some(ast::ExprCall { func, .. }) = value.as_call_expr() else { - return false; - }; - let Some(ast::ExprAttribute { value, attr, .. }) = func.as_attribute_expr() else { - return false; - }; - if attr != "append" { - return false; - } - let Some(ast::ExprName { id: target_id, .. }) = value.as_name_expr() else { - return false; - }; - target_id == id -} - /// Generate a [`Fix`] to remove a `list` cast from an expression. fn remove_cast(list_range: TextRange, iterable_range: TextRange) -> Fix { Fix::safe_edits( @@ -152,3 +134,95 @@ fn remove_cast(list_range: TextRange, iterable_range: TextRange) -> Fix { [Edit::deletion(iterable_range.end(), list_range.end())], ) } + +/// A [`StatementVisitor`] that (conservatively) identifies mutations to a variable. +#[derive(Default)] +pub(crate) struct MutationVisitor<'a> { + pub(crate) target: &'a str, + pub(crate) is_mutated: bool, +} + +impl<'a> MutationVisitor<'a> { + pub(crate) fn new(target: &'a str) -> Self { + Self { + target, + is_mutated: false, + } + } +} + +impl<'a, 'b> StatementVisitor<'b> for MutationVisitor<'a> +where + 'b: 'a, +{ + fn visit_stmt(&mut self, stmt: &'b Stmt) { + if match_mutation(stmt, self.target) { + self.is_mutated = true; + } else { + walk_stmt(self, stmt); + } + } +} + +/// Check if a statement is (probably) a modification to the list assigned to the given identifier. +/// +/// For example, `foo.append(bar)` would return `true` if `id` is `foo`. +fn match_mutation(stmt: &Stmt, id: &str) -> bool { + match stmt { + // Ex) `foo.append(bar)` + Stmt::Expr(ast::StmtExpr { value, .. }) => { + let Some(ast::ExprCall { func, .. }) = value.as_call_expr() else { + return false; + }; + let Some(ast::ExprAttribute { value, attr, .. }) = func.as_attribute_expr() else { + return false; + }; + if !matches!( + attr.as_str(), + "append" | "insert" | "extend" | "remove" | "pop" | "clear" | "reverse" | "sort" + ) { + return false; + } + let Some(ast::ExprName { id: target_id, .. }) = value.as_name_expr() else { + return false; + }; + target_id == id + } + // Ex) `foo[0] = bar` + Stmt::Assign(ast::StmtAssign { targets, .. }) => targets.iter().any(|target| { + if let Some(ast::ExprSubscript { value: target, .. }) = target.as_subscript_expr() { + if let Some(ast::ExprName { id: target_id, .. }) = target.as_name_expr() { + return target_id == id; + } + } + false + }), + // Ex) `foo += bar` + Stmt::AugAssign(ast::StmtAugAssign { target, .. }) => { + if let Some(ast::ExprName { id: target_id, .. }) = target.as_name_expr() { + target_id == id + } else { + false + } + } + // Ex) `foo[0]: int = bar` + Stmt::AnnAssign(ast::StmtAnnAssign { target, .. }) => { + if let Some(ast::ExprSubscript { value: target, .. }) = target.as_subscript_expr() { + if let Some(ast::ExprName { id: target_id, .. }) = target.as_name_expr() { + return target_id == id; + } + } + false + } + // Ex) `del foo[0]` + Stmt::Delete(ast::StmtDelete { targets, .. }) => targets.iter().any(|target| { + if let Some(ast::ExprSubscript { value: target, .. }) = target.as_subscript_expr() { + if let Some(ast::ExprName { id: target_id, .. }) = target.as_name_expr() { + return target_id == id; + } + } + false + }), + _ => false, + } +} diff --git a/crates/ruff_linter/src/rules/perflint/snapshots/ruff_linter__rules__perflint__tests__PERF101_PERF101.py.snap b/crates/ruff_linter/src/rules/perflint/snapshots/ruff_linter__rules__perflint__tests__PERF101_PERF101.py.snap index 11dafc4dd2565..d41a00b33eac7 100644 --- a/crates/ruff_linter/src/rules/perflint/snapshots/ruff_linter__rules__perflint__tests__PERF101_PERF101.py.snap +++ b/crates/ruff_linter/src/rules/perflint/snapshots/ruff_linter__rules__perflint__tests__PERF101_PERF101.py.snap @@ -178,7 +178,7 @@ PERF101.py:34:10: PERF101 [*] Do not cast an iterable to `list` before iterating 34 |+for i in {1, 2, 3}: # PERF101 37 35 | pass 38 36 | -39 37 | for i in list(foo_dict): # Ok +39 37 | for i in list(foo_dict): # OK PERF101.py:57:10: PERF101 [*] Do not cast an iterable to `list` before iterating over it | @@ -192,7 +192,7 @@ PERF101.py:57:10: PERF101 [*] Do not cast an iterable to `list` before iterating = help: Remove `list()` cast ℹ Safe fix -54 54 | for i in list(foo_list): # Ok +54 54 | for i in list(foo_list): # OK 55 55 | foo_list.append(i + 1) 56 56 | 57 |-for i in list(foo_list): # PERF101 @@ -218,5 +218,7 @@ PERF101.py:69:10: PERF101 [*] Do not cast an iterable to `list` before iterating 69 |-for i in list(nested_tuple): # PERF101 69 |+for i in nested_tuple: # PERF101 70 70 | pass +71 71 | +72 72 | for i in list(foo_list): # OK diff --git a/crates/ruff_python_ast/src/helpers.rs b/crates/ruff_python_ast/src/helpers.rs index 579a29ffbc7dd..154be660d37ae 100644 --- a/crates/ruff_python_ast/src/helpers.rs +++ b/crates/ruff_python_ast/src/helpers.rs @@ -935,7 +935,7 @@ where } } -/// A [`StatementVisitor`] that collects all `return` statements in a function or method. +/// A [`Visitor`] that collects all `return` statements in a function or method. #[derive(Default)] pub struct ReturnStatementVisitor<'a> { pub returns: Vec<&'a ast::StmtReturn>,