Skip to content

Commit

Permalink
Minor tweaks
Browse files Browse the repository at this point in the history
  • Loading branch information
charliermarsh committed Mar 21, 2024
1 parent 4a53b0f commit a4f8a30
Show file tree
Hide file tree
Showing 2 changed files with 93 additions and 60 deletions.
133 changes: 83 additions & 50 deletions crates/ruff_linter/src/rules/pylint/rules/nan_comparison.rs
@@ -1,7 +1,7 @@
use ruff_python_ast::{self as ast, Expr};

use ruff_diagnostics::{Diagnostic, Violation};
use ruff_macros::{derive_message_formats, violation};
use ruff_python_ast::{self as ast, Expr};
use ruff_python_semantic::SemanticModel;
use ruff_text_size::Ranged;

use crate::checkers::ast::Checker;
Expand All @@ -10,7 +10,14 @@ use crate::checkers::ast::Checker;
/// Checks for comparisons against NaN values.
///
/// ## Why is this bad?
/// Comparing against a NaN value will always return False even if both values are NaN.
/// Comparing against a NaN value can lead to unexpected results. For example,
/// `float("NaN") == float("NaN")` will return `False` and, in general,
/// `x == float("NaN")` will always return `False`, even if `x` is `NaN`.
///
/// To determine whether a value is `NaN`, use `math.isnan` or `np.isnan`
/// instead of comparing against `NaN` directly.
///
/// To
///
/// ## Example
/// ```python
Expand All @@ -28,77 +35,103 @@ use crate::checkers::ast::Checker;
///
#[violation]
pub struct NanComparison {
using_numpy: bool,
nan: Nan,
}

impl Violation for NanComparison {
#[derive_message_formats]
fn message(&self) -> String {
let NanComparison { using_numpy } = self;
if *using_numpy {
format!("Comparing against a NaN value, consider using `np.isnan`")
} else {
format!("Comparing against a NaN value, consider using `math.isnan`")
let NanComparison { nan } = self;
match nan {
Nan::Math => format!("Comparing against a NaN value; use `math.isnan` instead"),
Nan::NumPy => format!("Comparing against a NaN value; use `np.isnan` instead"),
}
}
}

/// PLW0117
pub(crate) fn nan_comparison(checker: &mut Checker, left: &Expr, comparators: &[Expr]) {
for expr in std::iter::once(left).chain(comparators.iter()) {
if let Some(qualified_name) = checker.semantic().resolve_qualified_name(expr) {
match qualified_name.segments() {
["numpy", "nan" | "NAN" | "NaN"] => {
checker.diagnostics.push(Diagnostic::new(
NanComparison { nan: Nan::NumPy },
expr.range(),
));
}
["math", "nan"] => {
checker.diagnostics.push(Diagnostic::new(
NanComparison { nan: Nan::Math },
expr.range(),
));
}
_ => continue,
}
}

if is_nan_float(expr, checker.semantic()) {
checker.diagnostics.push(Diagnostic::new(
NanComparison { nan: Nan::Math },
expr.range(),
));
}
}
}

fn is_nan_float(expr: &Expr) -> bool {
#[derive(Debug, PartialEq, Eq)]
enum Nan {
/// `math.isnan`
Math,
/// `np.isnan`
NumPy,
}

impl std::fmt::Display for Nan {
fn fmt(&self, fmt: &mut std::fmt::Formatter) -> std::fmt::Result {
match self {
Nan::Math => fmt.write_str("math"),
Nan::NumPy => fmt.write_str("numpy"),
}
}
}

/// Returns `true` if the expression is a call to `float("NaN")`.
fn is_nan_float(expr: &Expr, semantic: &SemanticModel) -> bool {
let Expr::Call(call) = expr else {
return false;
};

let Expr::Name(ast::ExprName { id, .. }) = ast::helpers::map_subscript(call.func.as_ref())
else {
let Expr::Name(ast::ExprName { id, .. }) = call.func.as_ref() else {
return false;
};

if id.as_str() != "float" {
return false;
}

let Some(arg) = call.arguments.find_positional(0) else {
if !call.arguments.keywords.is_empty() {
return false;
};

if let Expr::StringLiteral(ast::ExprStringLiteral { value, .. }) = arg {
return value.to_str().to_lowercase() == "nan";
}

false
}
let [arg] = call.arguments.args.as_ref() else {
return false;
};

/// PLW0117
pub(crate) fn nan_comparison(checker: &mut Checker, left: &Expr, comparators: &[Expr]) {
for comparison_expr in std::iter::once(left).chain(comparators.iter()) {
if let Some(qualified_name) = checker.semantic().resolve_qualified_name(comparison_expr) {
let segments = qualified_name.segments();
match segments[0] {
"numpy" => {
if segments[1].to_lowercase() == "nan" {
checker.diagnostics.push(Diagnostic::new(
NanComparison { using_numpy: true },
comparison_expr.range(),
));
}
}
"math" => {
if segments[1] == "nan" {
checker.diagnostics.push(Diagnostic::new(
NanComparison { using_numpy: false },
comparison_expr.range(),
));
}
}
_ => continue,
}
}
let Expr::StringLiteral(ast::ExprStringLiteral { value, .. }) = arg else {
return false;
};

if is_nan_float(comparison_expr) {
checker.diagnostics.push(Diagnostic::new(
NanComparison { using_numpy: false },
comparison_expr.range(),
));
}
if !matches!(
value.to_str(),
"nan" | "NaN" | "NAN" | "Nan" | "nAn" | "naN" | "nAN" | "NAn"
) {
return false;
}

if !semantic.is_builtin("float") {
return false;
}

true
}
@@ -1,79 +1,79 @@
---
source: crates/ruff_linter/src/rules/pylint/mod.rs
---
nan_comparison.py:11:9: PLW0117 Comparing against a NaN value, consider using `math.isnan`
nan_comparison.py:11:9: PLW0117 Comparing against a NaN value; use `math.isnan` instead
|
10 | # PLW0117
11 | if x == float('nan'):
| ^^^^^^^^^^^^ PLW0117
12 | pass
|

nan_comparison.py:15:9: PLW0117 Comparing against a NaN value, consider using `math.isnan`
nan_comparison.py:15:9: PLW0117 Comparing against a NaN value; use `math.isnan` instead
|
14 | # PLW0117
15 | if x == float('NaN'):
| ^^^^^^^^^^^^ PLW0117
16 | pass
|

nan_comparison.py:19:9: PLW0117 Comparing against a NaN value, consider using `math.isnan`
nan_comparison.py:19:9: PLW0117 Comparing against a NaN value; use `math.isnan` instead
|
18 | # PLW0117
19 | if x == float('NAN'):
| ^^^^^^^^^^^^ PLW0117
20 | pass
|

nan_comparison.py:23:9: PLW0117 Comparing against a NaN value, consider using `math.isnan`
nan_comparison.py:23:9: PLW0117 Comparing against a NaN value; use `math.isnan` instead
|
22 | # PLW0117
23 | if x == float('Nan'):
| ^^^^^^^^^^^^ PLW0117
24 | pass
|

nan_comparison.py:27:9: PLW0117 Comparing against a NaN value, consider using `math.isnan`
nan_comparison.py:27:9: PLW0117 Comparing against a NaN value; use `math.isnan` instead
|
26 | # PLW0117
27 | if x == math.nan:
| ^^^^^^^^ PLW0117
28 | pass
|

nan_comparison.py:31:9: PLW0117 Comparing against a NaN value, consider using `math.isnan`
nan_comparison.py:31:9: PLW0117 Comparing against a NaN value; use `math.isnan` instead
|
30 | # PLW0117
31 | if x == bad_val:
| ^^^^^^^ PLW0117
32 | pass
|

nan_comparison.py:35:9: PLW0117 Comparing against a NaN value, consider using `np.isnan`
nan_comparison.py:35:9: PLW0117 Comparing against a NaN value; use `np.isnan` instead
|
34 | # PLW0117
35 | if y == np.NaN:
| ^^^^^^ PLW0117
36 | pass
|

nan_comparison.py:39:9: PLW0117 Comparing against a NaN value, consider using `np.isnan`
nan_comparison.py:39:9: PLW0117 Comparing against a NaN value; use `np.isnan` instead
|
38 | # PLW0117
39 | if y == np.NAN:
| ^^^^^^ PLW0117
40 | pass
|

nan_comparison.py:43:9: PLW0117 Comparing against a NaN value, consider using `np.isnan`
nan_comparison.py:43:9: PLW0117 Comparing against a NaN value; use `np.isnan` instead
|
42 | # PLW0117
43 | if y == np.nan:
| ^^^^^^ PLW0117
44 | pass
|

nan_comparison.py:47:9: PLW0117 Comparing against a NaN value, consider using `np.isnan`
nan_comparison.py:47:9: PLW0117 Comparing against a NaN value; use `np.isnan` instead
|
46 | # PLW0117
47 | if y == npy_nan:
Expand Down

0 comments on commit a4f8a30

Please sign in to comment.