From 50ec6d3b0fdf8082702c65e039dec9f3b5c3906b Mon Sep 17 00:00:00 2001 From: Charlie Marsh Date: Tue, 21 Feb 2023 13:10:31 -0500 Subject: [PATCH] Use LibCST to fix chained assertions (#3087) --- .../fixtures/flake8_pytest_style/PT018.py | 16 +- .../flake8_pytest_style/rules/assertion.rs | 177 +++++++++++++----- ...es__flake8_pytest_style__tests__PT018.snap | 165 +++++++++------- 3 files changed, 237 insertions(+), 121 deletions(-) diff --git a/crates/ruff/resources/test/fixtures/flake8_pytest_style/PT018.py b/crates/ruff/resources/test/fixtures/flake8_pytest_style/PT018.py index e94c5d1ec79be..02dc99b2a7410 100644 --- a/crates/ruff/resources/test/fixtures/flake8_pytest_style/PT018.py +++ b/crates/ruff/resources/test/fixtures/flake8_pytest_style/PT018.py @@ -9,6 +9,7 @@ def test_ok(): assert something, "something message" assert something or something_else and something_third, "another message" + def test_error(): assert something and something_else assert something and something_else and something_third @@ -17,13 +18,24 @@ def test_error(): assert not something and something_else assert not (something or something_else) assert not (something or something_else or something_third) + assert something and something_else == """error + message + """ # recursive case - assert not (a or not (b or c)) - assert not (a or not (b and c)) # note that we only reduce once here + assert not (a or not (b or c)) # note that we only reduce once here + assert not (a or not (b and c)) # detected, but no autofix for messages assert something and something_else, "error message" assert not (something or something_else and something_third), "with message" # detected, but no autofix for mixed conditions (e.g. `a or b and c`) assert not (something or something_else and something_third) + # detected, but no autofix for parenthesized conditions + assert ( + something + and something_else + == """error +message +""" + ) diff --git a/crates/ruff/src/rules/flake8_pytest_style/rules/assertion.rs b/crates/ruff/src/rules/flake8_pytest_style/rules/assertion.rs index 6e96447f934b7..daef6dc0c5052 100644 --- a/crates/ruff/src/rules/flake8_pytest_style/rules/assertion.rs +++ b/crates/ruff/src/rules/flake8_pytest_style/rules/assertion.rs @@ -1,17 +1,26 @@ +use anyhow::bail; +use anyhow::Result; +use libcst_native::{ + Assert, BooleanOp, Codegen, CodegenState, CompoundStatement, Expression, + ParenthesizableWhitespace, ParenthesizedNode, SimpleStatementLine, SimpleWhitespace, + SmallStatement, Statement, Suite, TrailingWhitespace, UnaryOp, UnaryOperation, +}; use rustpython_parser::ast::{ - Boolop, Excepthandler, ExcepthandlerKind, Expr, ExprKind, Keyword, Stmt, StmtKind, Unaryop, + Boolop, Excepthandler, ExcepthandlerKind, Expr, ExprKind, Keyword, Location, Stmt, StmtKind, + Unaryop, }; use ruff_macros::{define_violation, derive_message_formats}; -use crate::ast::helpers::{create_expr, create_stmt, unparse_stmt}; +use crate::ast::helpers::unparse_stmt; use crate::ast::types::Range; -use crate::ast::visitor; use crate::ast::visitor::Visitor; +use crate::ast::{visitor, whitespace}; use crate::checkers::ast::Checker; +use crate::cst::matchers::match_module; use crate::fix::Fix; use crate::registry::Diagnostic; -use crate::source_code::Stylist; +use crate::source_code::{Locator, Stylist}; use crate::violation::{AutofixKind, Availability, Violation}; use super::helpers::is_falsy_constant; @@ -295,66 +304,130 @@ fn is_composite_condition(test: &Expr) -> CompositionKind { } /// Negate a condition, i.e., `a` => `not a` and `not a` => `a`. -fn negate(f: Expr) -> Expr { - match f.node { - ExprKind::UnaryOp { - op: Unaryop::Not, - operand, - } => *operand, - _ => create_expr(ExprKind::UnaryOp { - op: Unaryop::Not, - operand: Box::new(f), - }), +fn negate<'a>(expression: &Expression<'a>) -> Expression<'a> { + if let Expression::UnaryOperation(ref expression) = expression { + if matches!(expression.operator, UnaryOp::Not { .. }) { + return *expression.expression.clone(); + } } + Expression::UnaryOperation(Box::new(UnaryOperation { + operator: UnaryOp::Not { + whitespace_after: ParenthesizableWhitespace::SimpleWhitespace(SimpleWhitespace(" ")), + }, + expression: Box::new(expression.clone()), + lpar: vec![], + rpar: vec![], + })) } /// Replace composite condition `assert a == "hello" and b == "world"` with two statements /// `assert a == "hello"` and `assert b == "world"`. -fn fix_composite_condition(stylist: &Stylist, stmt: &Stmt, test: &Expr) -> Fix { - let mut conditions: Vec = vec![]; - match &test.node { - ExprKind::BoolOp { - op: Boolop::And, - values, - } => { - // Compound, so split. - conditions.extend(values.clone()); - } - ExprKind::UnaryOp { - op: Unaryop::Not, - operand, - } => { - match &operand.node { - ExprKind::BoolOp { - op: Boolop::Or, - values, - } => { - // Split via `not (a or b)` equals `not a and not b`. - conditions.extend(values.iter().map(|f| negate(f.clone()))); - } - _ => { - // Do not split. - conditions.push(*operand.clone()); +fn fix_composite_condition(stmt: &Stmt, locator: &Locator, stylist: &Stylist) -> Result { + // Infer the indentation of the outer block. + let Some(outer_indent) = whitespace::indentation(locator, stmt) else { + bail!("Unable to fix multiline statement"); + }; + + // Extract the module text. + let contents = locator.slice(&Range::new( + Location::new(stmt.location.row(), 0), + Location::new(stmt.end_location.unwrap().row() + 1, 0), + )); + + // "Embed" it in a function definition, to preserve indentation while retaining valid source + // code. (We'll strip the prefix later on.) + let module_text = format!("def f():{}{contents}", stylist.line_ending().as_str()); + + // Parse the CST. + let mut tree = match_module(&module_text)?; + + // Extract the assert statement. + let statements: &mut Vec = { + let [Statement::Compound(CompoundStatement::FunctionDef(embedding))] = &mut *tree.body else { + bail!("Expected statement to be embedded in a function definition") + }; + + let Suite::IndentedBlock(indented_block) = &mut embedding.body else { + bail!("Expected indented block") + }; + indented_block.indent = Some(outer_indent); + + &mut indented_block.body + }; + let [Statement::Simple(simple_statement_line)] = statements.as_mut_slice() else { + bail!("Expected one simple statement") + }; + let [SmallStatement::Assert(assert_statement)] = &mut *simple_statement_line.body else { + bail!("Expected simple statement to be an assert") + }; + + if !(assert_statement.test.lpar().is_empty() && assert_statement.test.rpar().is_empty()) { + bail!("Unable to split parenthesized condition"); + } + + // Extract the individual conditions. + let mut conditions: Vec = Vec::with_capacity(2); + match &assert_statement.test { + Expression::UnaryOperation(op) => { + if matches!(op.operator, UnaryOp::Not { .. }) { + if let Expression::BooleanOperation(op) = &*op.expression { + if matches!(op.operator, BooleanOp::Or { .. }) { + conditions.push(negate(&op.left)); + conditions.push(negate(&op.right)); + } else { + bail!("Expected assert statement to be a composite condition"); + } + } else { + bail!("Expected assert statement to be a composite condition"); } } } - _ => {} - }; + Expression::BooleanOperation(op) => { + if matches!(op.operator, BooleanOp::And { .. }) { + conditions.push(*op.left.clone()); + conditions.push(*op.right.clone()); + } else { + bail!("Expected assert statement to be a composite condition"); + } + } + _ => bail!("Expected assert statement to be a composite condition"), + } // For each condition, create an `assert condition` statement. - let mut content: Vec = Vec::with_capacity(conditions.len()); + statements.clear(); for condition in conditions { - content.push(unparse_stmt( - &create_stmt(StmtKind::Assert { - test: Box::new(condition.clone()), + statements.push(Statement::Simple(SimpleStatementLine { + body: vec![SmallStatement::Assert(Assert { + test: condition, msg: None, - }), - stylist, - )); + comma: None, + whitespace_after_assert: SimpleWhitespace(" "), + semicolon: None, + })], + leading_lines: Vec::default(), + trailing_whitespace: TrailingWhitespace::default(), + })); } - let content = content.join(stylist.line_ending().as_str()); - Fix::replacement(content, stmt.location, stmt.end_location.unwrap()) + let mut state = CodegenState { + default_newline: stylist.line_ending(), + default_indent: stylist.indentation(), + ..CodegenState::default() + }; + tree.codegen(&mut state); + + // Reconstruct and reformat the code. + let module_text = state.to_string(); + let contents = module_text + .strip_prefix(&format!("def f():{}", stylist.line_ending().as_str())) + .unwrap() + .to_string(); + + Ok(Fix::replacement( + contents, + Location::new(stmt.location.row(), 0), + Location::new(stmt.end_location.unwrap().row() + 1, 0), + )) } /// PT018 @@ -365,7 +438,9 @@ pub fn composite_condition(checker: &mut Checker, stmt: &Stmt, test: &Expr, msg: let mut diagnostic = Diagnostic::new(CompositeAssertion { fixable }, Range::from_located(stmt)); if fixable && checker.patch(diagnostic.kind.rule()) { - diagnostic.amend(fix_composite_condition(checker.stylist, stmt, test)); + if let Ok(fix) = fix_composite_condition(stmt, checker.locator, checker.stylist) { + diagnostic.amend(fix); + } } checker.diagnostics.push(diagnostic); } diff --git a/crates/ruff/src/rules/flake8_pytest_style/snapshots/ruff__rules__flake8_pytest_style__tests__PT018.snap b/crates/ruff/src/rules/flake8_pytest_style/snapshots/ruff__rules__flake8_pytest_style__tests__PT018.snap index 8ac810c35cfac..3cfde0ccf99c3 100644 --- a/crates/ruff/src/rules/flake8_pytest_style/snapshots/ruff__rules__flake8_pytest_style__tests__PT018.snap +++ b/crates/ruff/src/rules/flake8_pytest_style/snapshots/ruff__rules__flake8_pytest_style__tests__PT018.snap @@ -2,24 +2,6 @@ source: crates/ruff/src/rules/flake8_pytest_style/mod.rs expression: diagnostics --- -- kind: - CompositeAssertion: - fixable: true - location: - row: 13 - column: 4 - end_location: - row: 13 - column: 39 - fix: - content: "assert something\nassert something_else" - location: - row: 13 - column: 4 - end_location: - row: 13 - column: 39 - parent: ~ - kind: CompositeAssertion: fixable: true @@ -28,15 +10,15 @@ expression: diagnostics column: 4 end_location: row: 14 - column: 59 + column: 39 fix: - content: "assert something\nassert something_else\nassert something_third" + content: " assert something\n assert something_else\n" location: row: 14 - column: 4 + column: 0 end_location: - row: 14 - column: 59 + row: 15 + column: 0 parent: ~ - kind: CompositeAssertion: @@ -46,15 +28,15 @@ expression: diagnostics column: 4 end_location: row: 15 - column: 43 + column: 59 fix: - content: "assert something\nassert not something_else" + content: " assert something and something_else\n assert something_third\n" location: row: 15 - column: 4 + column: 0 end_location: - row: 15 - column: 43 + row: 16 + column: 0 parent: ~ - kind: CompositeAssertion: @@ -64,15 +46,15 @@ expression: diagnostics column: 4 end_location: row: 16 - column: 60 + column: 43 fix: - content: "assert something\nassert something_else or something_third" + content: " assert something\n assert not something_else\n" location: row: 16 - column: 4 + column: 0 end_location: - row: 16 - column: 60 + row: 17 + column: 0 parent: ~ - kind: CompositeAssertion: @@ -82,15 +64,15 @@ expression: diagnostics column: 4 end_location: row: 17 - column: 43 + column: 60 fix: - content: "assert not something\nassert something_else" + content: " assert something\n assert (something_else or something_third)\n" location: row: 17 - column: 4 + column: 0 end_location: - row: 17 - column: 43 + row: 18 + column: 0 parent: ~ - kind: CompositeAssertion: @@ -100,15 +82,15 @@ expression: diagnostics column: 4 end_location: row: 18 - column: 44 + column: 43 fix: - content: "assert not something\nassert not something_else" + content: " assert not something\n assert something_else\n" location: row: 18 - column: 4 + column: 0 end_location: - row: 18 - column: 44 + row: 19 + column: 0 parent: ~ - kind: CompositeAssertion: @@ -118,60 +100,96 @@ expression: diagnostics column: 4 end_location: row: 19 - column: 63 + column: 44 fix: - content: "assert not something\nassert not something_else\nassert not something_third" + content: " assert not something\n assert not something_else\n" location: row: 19 - column: 4 + column: 0 end_location: - row: 19 - column: 63 + row: 20 + column: 0 parent: ~ - kind: CompositeAssertion: fixable: true location: - row: 22 + row: 20 column: 4 end_location: - row: 22 - column: 34 + row: 20 + column: 63 fix: - content: "assert not a\nassert b or c" + content: " assert not something or something_else\n assert not something_third\n" location: - row: 22 - column: 4 + row: 20 + column: 0 end_location: - row: 22 - column: 34 + row: 21 + column: 0 parent: ~ - kind: CompositeAssertion: fixable: true location: - row: 23 + row: 21 column: 4 end_location: row: 23 - column: 35 + column: 7 fix: - content: "assert not a\nassert b and c" + content: " assert something\n assert something_else == \"\"\"error\n message\n \"\"\"\n" location: - row: 23 - column: 4 + row: 21 + column: 0 end_location: - row: 23 - column: 35 + row: 24 + column: 0 parent: ~ - kind: CompositeAssertion: - fixable: false + fixable: true location: row: 26 column: 4 end_location: row: 26 + column: 34 + fix: + content: " assert not a\n assert (b or c)\n" + location: + row: 26 + column: 0 + end_location: + row: 27 + column: 0 + parent: ~ +- kind: + CompositeAssertion: + fixable: true + location: + row: 27 + column: 4 + end_location: + row: 27 + column: 35 + fix: + content: " assert not a\n assert (b and c)\n" + location: + row: 27 + column: 0 + end_location: + row: 28 + column: 0 + parent: ~ +- kind: + CompositeAssertion: + fixable: false + location: + row: 30 + column: 4 + end_location: + row: 30 column: 56 fix: ~ parent: ~ @@ -179,10 +197,10 @@ expression: diagnostics CompositeAssertion: fixable: false location: - row: 27 + row: 31 column: 4 end_location: - row: 27 + row: 31 column: 80 fix: ~ parent: ~ @@ -190,11 +208,22 @@ expression: diagnostics CompositeAssertion: fixable: false location: - row: 29 + row: 33 column: 4 end_location: - row: 29 + row: 33 column: 64 fix: ~ parent: ~ +- kind: + CompositeAssertion: + fixable: true + location: + row: 35 + column: 4 + end_location: + row: 41 + column: 5 + fix: ~ + parent: ~