diff --git a/crates/ruff_linter/resources/test/fixtures/flake8_annotations/auto_return_type.py b/crates/ruff_linter/resources/test/fixtures/flake8_annotations/auto_return_type.py new file mode 100644 index 0000000000000..b7d58a6f94302 --- /dev/null +++ b/crates/ruff_linter/resources/test/fixtures/flake8_annotations/auto_return_type.py @@ -0,0 +1,32 @@ +def func(): + return 1 + + +def func(): + return 1.5 + + +def func(x: int): + if x > 0: + return 1 + else: + return 1.5 + + +def func(): + return True + + +def func(x: int): + if x > 0: + return None + else: + return + + +def func(x: int): + return 1 or 2.5 if x > 0 else 1.5 or "str" + + +def func(x: int): + return 1 + 2.5 if x > 0 else 1.5 or "str" diff --git a/crates/ruff_linter/src/rules/flake8_annotations/helpers.rs b/crates/ruff_linter/src/rules/flake8_annotations/helpers.rs index 42d13735645dc..318ebf053fac9 100644 --- a/crates/ruff_linter/src/rules/flake8_annotations/helpers.rs +++ b/crates/ruff_linter/src/rules/flake8_annotations/helpers.rs @@ -1,5 +1,14 @@ +use itertools::Itertools; + +use ruff_python_ast::helpers::{pep_604_union, ReturnStatementVisitor}; +use ruff_python_ast::visitor::Visitor; +use ruff_python_ast::{self as ast, Expr, ExprContext}; +use ruff_python_semantic::analyze::type_inference::{NumberLike, PythonType, ResolvedPythonType}; use ruff_python_semantic::analyze::visibility; use ruff_python_semantic::{Definition, SemanticModel}; +use ruff_text_size::TextRange; + +use crate::settings::types::PythonVersion; /// Return the name of the function, if it's overloaded. pub(crate) fn overloaded_name(definition: &Definition, semantic: &SemanticModel) -> Option { @@ -27,3 +36,81 @@ pub(crate) fn is_overload_impl( function.name.as_str() == overloaded_name } } + +/// Given a function, guess its return type. +pub(crate) fn auto_return_type( + function: &ast::StmtFunctionDef, + target_version: PythonVersion, +) -> Option { + // Collect all the `return` statements. + let returns = { + let mut visitor = ReturnStatementVisitor::default(); + visitor.visit_body(&function.body); + if visitor.is_generator { + return None; + } + visitor.returns + }; + + // Determine the return type of the first `return` statement. + let (return_statement, returns) = returns.split_first()?; + let mut return_type = return_statement.value.as_deref().map_or( + ResolvedPythonType::Atom(PythonType::None), + ResolvedPythonType::from, + ); + + // Merge the return types of the remaining `return` statements. + for return_statement in returns { + return_type = return_type.union(return_statement.value.as_deref().map_or( + ResolvedPythonType::Atom(PythonType::None), + ResolvedPythonType::from, + )); + } + + match return_type { + ResolvedPythonType::Atom(python_type) => type_expr(python_type), + ResolvedPythonType::Union(python_types) if target_version >= PythonVersion::Py310 => { + // Aggregate all the individual types (e.g., `int`, `float`). + let names = python_types + .iter() + .sorted_unstable() + .filter_map(|python_type| type_expr(*python_type)) + .collect::>(); + + // Wrap in a bitwise union (e.g., `int | float`). + Some(pep_604_union(&names)) + } + ResolvedPythonType::Union(_) => None, + ResolvedPythonType::Unknown => None, + ResolvedPythonType::TypeError => None, + } +} + +/// Given a [`PythonType`], return an [`Expr`] that resolves to that type. +fn type_expr(python_type: PythonType) -> Option { + fn name(name: &str) -> Expr { + Expr::Name(ast::ExprName { + id: name.into(), + range: TextRange::default(), + ctx: ExprContext::Load, + }) + } + + match python_type { + PythonType::String => Some(name("str")), + PythonType::Bytes => Some(name("bytes")), + PythonType::Number(number) => match number { + NumberLike::Integer => Some(name("int")), + NumberLike::Float => Some(name("float")), + NumberLike::Complex => Some(name("complex")), + NumberLike::Bool => Some(name("bool")), + }, + PythonType::None => Some(name("None")), + PythonType::Ellipsis => None, + PythonType::Dict => None, + PythonType::List => None, + PythonType::Set => None, + PythonType::Tuple => None, + PythonType::Generator => None, + } +} diff --git a/crates/ruff_linter/src/rules/flake8_annotations/mod.rs b/crates/ruff_linter/src/rules/flake8_annotations/mod.rs index 19f948c37991d..b7dbf012035aa 100644 --- a/crates/ruff_linter/src/rules/flake8_annotations/mod.rs +++ b/crates/ruff_linter/src/rules/flake8_annotations/mod.rs @@ -110,6 +110,24 @@ mod tests { Ok(()) } + #[test] + fn auto_return_type() -> Result<()> { + let diagnostics = test_path( + Path::new("flake8_annotations/auto_return_type.py"), + &LinterSettings { + ..LinterSettings::for_rules(vec![ + Rule::MissingReturnTypeUndocumentedPublicFunction, + Rule::MissingReturnTypePrivateFunction, + Rule::MissingReturnTypeSpecialMethod, + Rule::MissingReturnTypeStaticMethod, + Rule::MissingReturnTypeClassMethod, + ]) + }, + )?; + assert_messages!(diagnostics); + Ok(()) + } + #[test] fn suppress_none_returning() -> Result<()> { let diagnostics = test_path( diff --git a/crates/ruff_linter/src/rules/flake8_annotations/rules/definition.rs b/crates/ruff_linter/src/rules/flake8_annotations/rules/definition.rs index 6f8321a055b5e..3fe862c01991d 100644 --- a/crates/ruff_linter/src/rules/flake8_annotations/rules/definition.rs +++ b/crates/ruff_linter/src/rules/flake8_annotations/rules/definition.rs @@ -1,8 +1,8 @@ -use ruff_diagnostics::{AlwaysFixableViolation, Diagnostic, Edit, Fix, Violation}; +use ruff_diagnostics::{Diagnostic, Edit, Fix, FixAvailability, Violation}; use ruff_macros::{derive_message_formats, violation}; use ruff_python_ast::helpers::ReturnStatementVisitor; use ruff_python_ast::identifier::Identifier; -use ruff_python_ast::statement_visitor::StatementVisitor; +use ruff_python_ast::visitor::Visitor; use ruff_python_ast::{self as ast, Expr, ParameterWithDefault, Stmt}; use ruff_python_parser::typing::parse_type_annotation; use ruff_python_semantic::analyze::visibility; @@ -12,6 +12,7 @@ use ruff_text_size::Ranged; use crate::checkers::ast::Checker; use crate::registry::Rule; +use crate::rules::flake8_annotations::helpers::auto_return_type; use crate::rules::ruff::typing::type_hint_resolves_to_any; /// ## What it does @@ -41,7 +42,7 @@ pub struct MissingTypeFunctionArgument { impl Violation for MissingTypeFunctionArgument { #[derive_message_formats] fn message(&self) -> String { - let MissingTypeFunctionArgument { name } = self; + let Self { name } = self; format!("Missing type annotation for function argument `{name}`") } } @@ -73,7 +74,7 @@ pub struct MissingTypeArgs { impl Violation for MissingTypeArgs { #[derive_message_formats] fn message(&self) -> String { - let MissingTypeArgs { name } = self; + let Self { name } = self; format!("Missing type annotation for `*{name}`") } } @@ -105,7 +106,7 @@ pub struct MissingTypeKwargs { impl Violation for MissingTypeKwargs { #[derive_message_formats] fn message(&self) -> String { - let MissingTypeKwargs { name } = self; + let Self { name } = self; format!("Missing type annotation for `**{name}`") } } @@ -142,7 +143,7 @@ pub struct MissingTypeSelf { impl Violation for MissingTypeSelf { #[derive_message_formats] fn message(&self) -> String { - let MissingTypeSelf { name } = self; + let Self { name } = self; format!("Missing type annotation for `{name}` in method") } } @@ -181,7 +182,7 @@ pub struct MissingTypeCls { impl Violation for MissingTypeCls { #[derive_message_formats] fn message(&self) -> String { - let MissingTypeCls { name } = self; + let Self { name } = self; format!("Missing type annotation for `{name}` in classmethod") } } @@ -208,14 +209,26 @@ impl Violation for MissingTypeCls { #[violation] pub struct MissingReturnTypeUndocumentedPublicFunction { name: String, + annotation: Option, } impl Violation for MissingReturnTypeUndocumentedPublicFunction { + const FIX_AVAILABILITY: FixAvailability = FixAvailability::Sometimes; + #[derive_message_formats] fn message(&self) -> String { - let MissingReturnTypeUndocumentedPublicFunction { name } = self; + let Self { name, .. } = self; format!("Missing return type annotation for public function `{name}`") } + + fn fix_title(&self) -> Option { + let Self { annotation, .. } = self; + if let Some(annotation) = annotation { + Some(format!("Add return type annotation: `{annotation}`")) + } else { + Some(format!("Add return type annotation")) + } + } } /// ## What it does @@ -240,14 +253,26 @@ impl Violation for MissingReturnTypeUndocumentedPublicFunction { #[violation] pub struct MissingReturnTypePrivateFunction { name: String, + annotation: Option, } impl Violation for MissingReturnTypePrivateFunction { + const FIX_AVAILABILITY: FixAvailability = FixAvailability::Sometimes; + #[derive_message_formats] fn message(&self) -> String { - let MissingReturnTypePrivateFunction { name } = self; + let Self { name, .. } = self; format!("Missing return type annotation for private function `{name}`") } + + fn fix_title(&self) -> Option { + let Self { annotation, .. } = self; + if let Some(annotation) = annotation { + Some(format!("Add return type annotation: `{annotation}`")) + } else { + Some(format!("Add return type annotation")) + } + } } /// ## What it does @@ -285,17 +310,25 @@ impl Violation for MissingReturnTypePrivateFunction { #[violation] pub struct MissingReturnTypeSpecialMethod { name: String, + annotation: Option, } -impl AlwaysFixableViolation for MissingReturnTypeSpecialMethod { +impl Violation for MissingReturnTypeSpecialMethod { + const FIX_AVAILABILITY: FixAvailability = FixAvailability::Sometimes; + #[derive_message_formats] fn message(&self) -> String { - let MissingReturnTypeSpecialMethod { name } = self; + let Self { name, .. } = self; format!("Missing return type annotation for special method `{name}`") } - fn fix_title(&self) -> String { - "Add `None` return type".to_string() + fn fix_title(&self) -> Option { + let Self { annotation, .. } = self; + if let Some(annotation) = annotation { + Some(format!("Add return type annotation: `{annotation}`")) + } else { + Some(format!("Add return type annotation")) + } } } @@ -325,14 +358,26 @@ impl AlwaysFixableViolation for MissingReturnTypeSpecialMethod { #[violation] pub struct MissingReturnTypeStaticMethod { name: String, + annotation: Option, } impl Violation for MissingReturnTypeStaticMethod { + const FIX_AVAILABILITY: FixAvailability = FixAvailability::Sometimes; + #[derive_message_formats] fn message(&self) -> String { - let MissingReturnTypeStaticMethod { name } = self; + let Self { name, .. } = self; format!("Missing return type annotation for staticmethod `{name}`") } + + fn fix_title(&self) -> Option { + let Self { annotation, .. } = self; + if let Some(annotation) = annotation { + Some(format!("Add return type annotation: `{annotation}`")) + } else { + Some(format!("Add return type annotation")) + } + } } /// ## What it does @@ -361,14 +406,26 @@ impl Violation for MissingReturnTypeStaticMethod { #[violation] pub struct MissingReturnTypeClassMethod { name: String, + annotation: Option, } impl Violation for MissingReturnTypeClassMethod { + const FIX_AVAILABILITY: FixAvailability = FixAvailability::Sometimes; + #[derive_message_formats] fn message(&self) -> String { - let MissingReturnTypeClassMethod { name } = self; + let Self { name, .. } = self; format!("Missing return type annotation for classmethod `{name}`") } + + fn fix_title(&self) -> Option { + let Self { annotation, .. } = self; + if let Some(annotation) = annotation { + Some(format!("Add return type annotation: `{annotation}`")) + } else { + Some(format!("Add return type annotation")) + } + } } /// ## What it does @@ -421,7 +478,7 @@ pub struct AnyType { impl Violation for AnyType { #[derive_message_formats] fn message(&self) -> String { - let AnyType { name } = self; + let Self { name } = self; format!("Dynamically typed expressions (typing.Any) are disallowed in `{name}`") } } @@ -673,21 +730,41 @@ pub(crate) fn definition( ) { if is_method && visibility::is_classmethod(decorator_list, checker.semantic()) { if checker.enabled(Rule::MissingReturnTypeClassMethod) { - diagnostics.push(Diagnostic::new( + let return_type = auto_return_type(function, checker.settings.target_version) + .map(|return_type| checker.generator().expr(&return_type)); + let mut diagnostic = Diagnostic::new( MissingReturnTypeClassMethod { name: name.to_string(), + annotation: return_type.clone(), }, function.identifier(), - )); + ); + if let Some(return_type) = return_type { + diagnostic.set_fix(Fix::unsafe_edit(Edit::insertion( + format!(" -> {return_type}"), + function.parameters.range().end(), + ))); + } + diagnostics.push(diagnostic); } } else if is_method && visibility::is_staticmethod(decorator_list, checker.semantic()) { if checker.enabled(Rule::MissingReturnTypeStaticMethod) { - diagnostics.push(Diagnostic::new( + let return_type = auto_return_type(function, checker.settings.target_version) + .map(|return_type| checker.generator().expr(&return_type)); + let mut diagnostic = Diagnostic::new( MissingReturnTypeStaticMethod { name: name.to_string(), + annotation: return_type.clone(), }, function.identifier(), - )); + ); + if let Some(return_type) = return_type { + diagnostic.set_fix(Fix::unsafe_edit(Edit::insertion( + format!(" -> {return_type}"), + function.parameters.range().end(), + ))); + } + diagnostics.push(diagnostic); } } else if is_method && visibility::is_init(name) { // Allow omission of return annotation in `__init__` functions, as long as at @@ -697,6 +774,7 @@ pub(crate) fn definition( let mut diagnostic = Diagnostic::new( MissingReturnTypeSpecialMethod { name: name.to_string(), + annotation: Some("None".to_string()), }, function.identifier(), ); @@ -709,13 +787,15 @@ pub(crate) fn definition( } } else if is_method && visibility::is_magic(name) { if checker.enabled(Rule::MissingReturnTypeSpecialMethod) { + let return_type = simple_magic_return_type(name); let mut diagnostic = Diagnostic::new( MissingReturnTypeSpecialMethod { name: name.to_string(), + annotation: return_type.map(ToString::to_string), }, function.identifier(), ); - if let Some(return_type) = simple_magic_return_type(name) { + if let Some(return_type) = return_type { diagnostic.set_fix(Fix::unsafe_edit(Edit::insertion( format!(" -> {return_type}"), function.parameters.range().end(), @@ -727,22 +807,44 @@ pub(crate) fn definition( match visibility { visibility::Visibility::Public => { if checker.enabled(Rule::MissingReturnTypeUndocumentedPublicFunction) { - diagnostics.push(Diagnostic::new( + let return_type = + auto_return_type(function, checker.settings.target_version) + .map(|return_type| checker.generator().expr(&return_type)); + let mut diagnostic = Diagnostic::new( MissingReturnTypeUndocumentedPublicFunction { name: name.to_string(), + annotation: return_type.clone(), }, function.identifier(), - )); + ); + if let Some(return_type) = return_type { + diagnostic.set_fix(Fix::unsafe_edit(Edit::insertion( + format!(" -> {return_type}"), + function.parameters.range().end(), + ))); + } + diagnostics.push(diagnostic); } } visibility::Visibility::Private => { if checker.enabled(Rule::MissingReturnTypePrivateFunction) { - diagnostics.push(Diagnostic::new( + let return_type = + auto_return_type(function, checker.settings.target_version) + .map(|return_type| checker.generator().expr(&return_type)); + let mut diagnostic = Diagnostic::new( MissingReturnTypePrivateFunction { name: name.to_string(), + annotation: return_type.clone(), }, function.identifier(), - )); + ); + if let Some(return_type) = return_type { + diagnostic.set_fix(Fix::unsafe_edit(Edit::insertion( + format!(" -> {return_type}"), + function.parameters.range().end(), + ))); + } + diagnostics.push(diagnostic); } } } diff --git a/crates/ruff_linter/src/rules/flake8_annotations/snapshots/ruff_linter__rules__flake8_annotations__tests__allow_overload.snap b/crates/ruff_linter/src/rules/flake8_annotations/snapshots/ruff_linter__rules__flake8_annotations__tests__allow_overload.snap index b120a666f616b..8db1454b2141f 100644 --- a/crates/ruff_linter/src/rules/flake8_annotations/snapshots/ruff_linter__rules__flake8_annotations__tests__allow_overload.snap +++ b/crates/ruff_linter/src/rules/flake8_annotations/snapshots/ruff_linter__rules__flake8_annotations__tests__allow_overload.snap @@ -8,5 +8,6 @@ allow_overload.py:29:9: ANN201 Missing return type annotation for public functio | ^^^ ANN201 30 | return i | + = help: Add return type annotation diff --git a/crates/ruff_linter/src/rules/flake8_annotations/snapshots/ruff_linter__rules__flake8_annotations__tests__auto_return_type.snap b/crates/ruff_linter/src/rules/flake8_annotations/snapshots/ruff_linter__rules__flake8_annotations__tests__auto_return_type.snap new file mode 100644 index 0000000000000..108a1483004c7 --- /dev/null +++ b/crates/ruff_linter/src/rules/flake8_annotations/snapshots/ruff_linter__rules__flake8_annotations__tests__auto_return_type.snap @@ -0,0 +1,127 @@ +--- +source: crates/ruff_linter/src/rules/flake8_annotations/mod.rs +--- +auto_return_type.py:1:5: ANN201 [*] Missing return type annotation for public function `func` + | +1 | def func(): + | ^^^^ ANN201 +2 | return 1 + | + = help: Add return type annotation: `int` + +ℹ Unsafe fix +1 |-def func(): + 1 |+def func() -> int: +2 2 | return 1 +3 3 | +4 4 | + +auto_return_type.py:5:5: ANN201 [*] Missing return type annotation for public function `func` + | +5 | def func(): + | ^^^^ ANN201 +6 | return 1.5 + | + = help: Add return type annotation: `float` + +ℹ Unsafe fix +2 2 | return 1 +3 3 | +4 4 | +5 |-def func(): + 5 |+def func() -> float: +6 6 | return 1.5 +7 7 | +8 8 | + +auto_return_type.py:9:5: ANN201 [*] Missing return type annotation for public function `func` + | + 9 | def func(x: int): + | ^^^^ ANN201 +10 | if x > 0: +11 | return 1 + | + = help: Add return type annotation: `float` + +ℹ Unsafe fix +6 6 | return 1.5 +7 7 | +8 8 | +9 |-def func(x: int): + 9 |+def func(x: int) -> float: +10 10 | if x > 0: +11 11 | return 1 +12 12 | else: + +auto_return_type.py:16:5: ANN201 [*] Missing return type annotation for public function `func` + | +16 | def func(): + | ^^^^ ANN201 +17 | return True + | + = help: Add return type annotation: `bool` + +ℹ Unsafe fix +13 13 | return 1.5 +14 14 | +15 15 | +16 |-def func(): + 16 |+def func() -> bool: +17 17 | return True +18 18 | +19 19 | + +auto_return_type.py:20:5: ANN201 [*] Missing return type annotation for public function `func` + | +20 | def func(x: int): + | ^^^^ ANN201 +21 | if x > 0: +22 | return None + | + = help: Add return type annotation: `None` + +ℹ Unsafe fix +17 17 | return True +18 18 | +19 19 | +20 |-def func(x: int): + 20 |+def func(x: int) -> None: +21 21 | if x > 0: +22 22 | return None +23 23 | else: + +auto_return_type.py:27:5: ANN201 [*] Missing return type annotation for public function `func` + | +27 | def func(x: int): + | ^^^^ ANN201 +28 | return 1 or 2.5 if x > 0 else 1.5 or "str" + | + = help: Add return type annotation: `str | float` + +ℹ Unsafe fix +24 24 | return +25 25 | +26 26 | +27 |-def func(x: int): + 27 |+def func(x: int) -> str | float: +28 28 | return 1 or 2.5 if x > 0 else 1.5 or "str" +29 29 | +30 30 | + +auto_return_type.py:31:5: ANN201 [*] Missing return type annotation for public function `func` + | +31 | def func(x: int): + | ^^^^ ANN201 +32 | return 1 + 2.5 if x > 0 else 1.5 or "str" + | + = help: Add return type annotation: `str | float` + +ℹ Unsafe fix +28 28 | return 1 or 2.5 if x > 0 else 1.5 or "str" +29 29 | +30 30 | +31 |-def func(x: int): + 31 |+def func(x: int) -> str | float: +32 32 | return 1 + 2.5 if x > 0 else 1.5 or "str" + + diff --git a/crates/ruff_linter/src/rules/flake8_annotations/snapshots/ruff_linter__rules__flake8_annotations__tests__defaults.snap b/crates/ruff_linter/src/rules/flake8_annotations/snapshots/ruff_linter__rules__flake8_annotations__tests__defaults.snap index e54d5ec02a2af..8665fb92ce192 100644 --- a/crates/ruff_linter/src/rules/flake8_annotations/snapshots/ruff_linter__rules__flake8_annotations__tests__defaults.snap +++ b/crates/ruff_linter/src/rules/flake8_annotations/snapshots/ruff_linter__rules__flake8_annotations__tests__defaults.snap @@ -8,6 +8,7 @@ annotation_presence.py:5:5: ANN201 Missing return type annotation for public fun | ^^^ ANN201 6 | pass | + = help: Add return type annotation annotation_presence.py:5:9: ANN001 Missing type annotation for function argument `a` | @@ -32,6 +33,7 @@ annotation_presence.py:10:5: ANN201 Missing return type annotation for public fu | ^^^ ANN201 11 | pass | + = help: Add return type annotation annotation_presence.py:10:17: ANN001 Missing type annotation for function argument `b` | @@ -56,6 +58,7 @@ annotation_presence.py:20:5: ANN201 Missing return type annotation for public fu | ^^^ ANN201 21 | pass | + = help: Add return type annotation annotation_presence.py:25:5: ANN201 Missing return type annotation for public function `foo` | @@ -64,6 +67,7 @@ annotation_presence.py:25:5: ANN201 Missing return type annotation for public fu | ^^^ ANN201 26 | pass | + = help: Add return type annotation annotation_presence.py:45:12: ANN401 Dynamically typed expressions (typing.Any) are disallowed in `a` | @@ -250,7 +254,7 @@ annotation_presence.py:159:9: ANN204 [*] Missing return type annotation for spec | ^^^^^^^^ ANN204 160 | ... | - = help: Add `None` return type + = help: Add return type annotation: `None` ℹ Unsafe fix 156 156 | @@ -270,7 +274,7 @@ annotation_presence.py:165:9: ANN204 [*] Missing return type annotation for spec | ^^^^^^^^ ANN204 166 | print(f"{self.attr=}") | - = help: Add `None` return type + = help: Add return type annotation: `None` ℹ Unsafe fix 162 162 | diff --git a/crates/ruff_linter/src/rules/flake8_annotations/snapshots/ruff_linter__rules__flake8_annotations__tests__ignore_fully_untyped.snap b/crates/ruff_linter/src/rules/flake8_annotations/snapshots/ruff_linter__rules__flake8_annotations__tests__ignore_fully_untyped.snap index 70d0e78d8d196..a93908081cec1 100644 --- a/crates/ruff_linter/src/rules/flake8_annotations/snapshots/ruff_linter__rules__flake8_annotations__tests__ignore_fully_untyped.snap +++ b/crates/ruff_linter/src/rules/flake8_annotations/snapshots/ruff_linter__rules__flake8_annotations__tests__ignore_fully_untyped.snap @@ -7,6 +7,7 @@ ignore_fully_untyped.py:24:5: ANN201 Missing return type annotation for public f | ^^^^^^^^^^^^^^^^^^^^^^^ ANN201 25 | pass | + = help: Add return type annotation ignore_fully_untyped.py:24:37: ANN001 Missing type annotation for function argument `b` | @@ -28,6 +29,7 @@ ignore_fully_untyped.py:32:5: ANN201 Missing return type annotation for public f | ^^^^^^^^^^^^^^^^^^^^^^^ ANN201 33 | pass | + = help: Add return type annotation ignore_fully_untyped.py:43:9: ANN201 Missing return type annotation for public function `error_typed_self` | @@ -37,5 +39,6 @@ ignore_fully_untyped.py:43:9: ANN201 Missing return type annotation for public f | ^^^^^^^^^^^^^^^^ ANN201 44 | pass | + = help: Add return type annotation diff --git a/crates/ruff_linter/src/rules/flake8_annotations/snapshots/ruff_linter__rules__flake8_annotations__tests__mypy_init_return.snap b/crates/ruff_linter/src/rules/flake8_annotations/snapshots/ruff_linter__rules__flake8_annotations__tests__mypy_init_return.snap index 0f3a069545876..2ce0c94524eb4 100644 --- a/crates/ruff_linter/src/rules/flake8_annotations/snapshots/ruff_linter__rules__flake8_annotations__tests__mypy_init_return.snap +++ b/crates/ruff_linter/src/rules/flake8_annotations/snapshots/ruff_linter__rules__flake8_annotations__tests__mypy_init_return.snap @@ -9,7 +9,7 @@ mypy_init_return.py:5:9: ANN204 [*] Missing return type annotation for special m | ^^^^^^^^ ANN204 6 | ... | - = help: Add `None` return type + = help: Add return type annotation: `None` ℹ Unsafe fix 2 2 | @@ -29,7 +29,7 @@ mypy_init_return.py:11:9: ANN204 [*] Missing return type annotation for special | ^^^^^^^^ ANN204 12 | ... | - = help: Add `None` return type + = help: Add return type annotation: `None` ℹ Unsafe fix 8 8 | @@ -48,6 +48,7 @@ mypy_init_return.py:40:5: ANN202 Missing return type annotation for private func | ^^^^^^^^ ANN202 41 | ... | + = help: Add return type annotation mypy_init_return.py:47:9: ANN204 [*] Missing return type annotation for special method `__init__` | @@ -57,7 +58,7 @@ mypy_init_return.py:47:9: ANN204 [*] Missing return type annotation for special | ^^^^^^^^ ANN204 48 | ... | - = help: Add `None` return type + = help: Add return type annotation: `None` ℹ Unsafe fix 44 44 | # Error – used to be ok for a moment since the mere presence diff --git a/crates/ruff_linter/src/rules/flake8_annotations/snapshots/ruff_linter__rules__flake8_annotations__tests__simple_magic_methods.snap b/crates/ruff_linter/src/rules/flake8_annotations/snapshots/ruff_linter__rules__flake8_annotations__tests__simple_magic_methods.snap index a2fca17824024..29f85e2af8423 100644 --- a/crates/ruff_linter/src/rules/flake8_annotations/snapshots/ruff_linter__rules__flake8_annotations__tests__simple_magic_methods.snap +++ b/crates/ruff_linter/src/rules/flake8_annotations/snapshots/ruff_linter__rules__flake8_annotations__tests__simple_magic_methods.snap @@ -8,7 +8,7 @@ simple_magic_methods.py:2:9: ANN204 [*] Missing return type annotation for speci | ^^^^^^^ ANN204 3 | ... | - = help: Add `None` return type + = help: Add return type annotation: `str` ℹ Unsafe fix 1 1 | class Foo: @@ -26,7 +26,7 @@ simple_magic_methods.py:5:9: ANN204 [*] Missing return type annotation for speci | ^^^^^^^^ ANN204 6 | ... | - = help: Add `None` return type + = help: Add return type annotation: `str` ℹ Unsafe fix 2 2 | def __str__(self): @@ -46,7 +46,7 @@ simple_magic_methods.py:8:9: ANN204 [*] Missing return type annotation for speci | ^^^^^^^ ANN204 9 | ... | - = help: Add `None` return type + = help: Add return type annotation: `int` ℹ Unsafe fix 5 5 | def __repr__(self): @@ -66,7 +66,7 @@ simple_magic_methods.py:11:9: ANN204 [*] Missing return type annotation for spec | ^^^^^^^^^^^^^^^ ANN204 12 | ... | - = help: Add `None` return type + = help: Add return type annotation: `int` ℹ Unsafe fix 8 8 | def __len__(self): @@ -86,7 +86,7 @@ simple_magic_methods.py:14:9: ANN204 [*] Missing return type annotation for spec | ^^^^^^^^ ANN204 15 | ... | - = help: Add `None` return type + = help: Add return type annotation: `None` ℹ Unsafe fix 11 11 | def __length_hint__(self): @@ -106,7 +106,7 @@ simple_magic_methods.py:17:9: ANN204 [*] Missing return type annotation for spec | ^^^^^^^ ANN204 18 | ... | - = help: Add `None` return type + = help: Add return type annotation: `None` ℹ Unsafe fix 14 14 | def __init__(self): @@ -126,7 +126,7 @@ simple_magic_methods.py:20:9: ANN204 [*] Missing return type annotation for spec | ^^^^^^^^ ANN204 21 | ... | - = help: Add `None` return type + = help: Add return type annotation: `bool` ℹ Unsafe fix 17 17 | def __del__(self): @@ -146,7 +146,7 @@ simple_magic_methods.py:23:9: ANN204 [*] Missing return type annotation for spec | ^^^^^^^^^ ANN204 24 | ... | - = help: Add `None` return type + = help: Add return type annotation: `bytes` ℹ Unsafe fix 20 20 | def __bool__(self): @@ -166,7 +166,7 @@ simple_magic_methods.py:26:9: ANN204 [*] Missing return type annotation for spec | ^^^^^^^^^^ ANN204 27 | ... | - = help: Add `None` return type + = help: Add return type annotation: `str` ℹ Unsafe fix 23 23 | def __bytes__(self): @@ -186,7 +186,7 @@ simple_magic_methods.py:29:9: ANN204 [*] Missing return type annotation for spec | ^^^^^^^^^^^^ ANN204 30 | ... | - = help: Add `None` return type + = help: Add return type annotation: `bool` ℹ Unsafe fix 26 26 | def __format__(self, format_spec): @@ -206,7 +206,7 @@ simple_magic_methods.py:32:9: ANN204 [*] Missing return type annotation for spec | ^^^^^^^^^^^ ANN204 33 | ... | - = help: Add `None` return type + = help: Add return type annotation: `complex` ℹ Unsafe fix 29 29 | def __contains__(self, item): @@ -226,7 +226,7 @@ simple_magic_methods.py:35:9: ANN204 [*] Missing return type annotation for spec | ^^^^^^^ ANN204 36 | ... | - = help: Add `None` return type + = help: Add return type annotation: `int` ℹ Unsafe fix 32 32 | def __complex__(self): @@ -246,7 +246,7 @@ simple_magic_methods.py:38:9: ANN204 [*] Missing return type annotation for spec | ^^^^^^^^^ ANN204 39 | ... | - = help: Add `None` return type + = help: Add return type annotation: `float` ℹ Unsafe fix 35 35 | def __int__(self): @@ -266,7 +266,7 @@ simple_magic_methods.py:41:9: ANN204 [*] Missing return type annotation for spec | ^^^^^^^^^ ANN204 42 | ... | - = help: Add `None` return type + = help: Add return type annotation: `int` ℹ Unsafe fix 38 38 | def __float__(self): diff --git a/crates/ruff_linter/src/rules/flake8_annotations/snapshots/ruff_linter__rules__flake8_annotations__tests__suppress_none_returning.snap b/crates/ruff_linter/src/rules/flake8_annotations/snapshots/ruff_linter__rules__flake8_annotations__tests__suppress_none_returning.snap index 2b5ddaa81e2eb..07d875aa50649 100644 --- a/crates/ruff_linter/src/rules/flake8_annotations/snapshots/ruff_linter__rules__flake8_annotations__tests__suppress_none_returning.snap +++ b/crates/ruff_linter/src/rules/flake8_annotations/snapshots/ruff_linter__rules__flake8_annotations__tests__suppress_none_returning.snap @@ -1,15 +1,26 @@ --- source: crates/ruff_linter/src/rules/flake8_annotations/mod.rs --- -suppress_none_returning.py:45:5: ANN201 Missing return type annotation for public function `foo` +suppress_none_returning.py:45:5: ANN201 [*] Missing return type annotation for public function `foo` | 44 | # Error 45 | def foo(): | ^^^ ANN201 46 | return True | + = help: Add return type annotation: `bool` -suppress_none_returning.py:50:5: ANN201 Missing return type annotation for public function `foo` +ℹ Unsafe fix +42 42 | +43 43 | +44 44 | # Error +45 |-def foo(): + 45 |+def foo() -> bool: +46 46 | return True +47 47 | +48 48 | + +suppress_none_returning.py:50:5: ANN201 [*] Missing return type annotation for public function `foo` | 49 | # Error 50 | def foo(): @@ -17,6 +28,17 @@ suppress_none_returning.py:50:5: ANN201 Missing return type annotation for publi 51 | a = 2 + 2 52 | if a == 4: | + = help: Add return type annotation: `bool | None` + +ℹ Unsafe fix +47 47 | +48 48 | +49 49 | # Error +50 |-def foo(): + 50 |+def foo() -> bool | None: +51 51 | a = 2 + 2 +52 52 | if a == 4: +53 53 | return True suppress_none_returning.py:59:9: ANN001 Missing type annotation for function argument `a` | diff --git a/crates/ruff_linter/src/rules/pylint/rules/invalid_str_return.rs b/crates/ruff_linter/src/rules/pylint/rules/invalid_str_return.rs index 60b79c01cdd93..00764c4de338b 100644 --- a/crates/ruff_linter/src/rules/pylint/rules/invalid_str_return.rs +++ b/crates/ruff_linter/src/rules/pylint/rules/invalid_str_return.rs @@ -1,8 +1,8 @@ -use ruff_python_ast::Stmt; - use ruff_diagnostics::{Diagnostic, Violation}; use ruff_macros::{derive_message_formats, violation}; -use ruff_python_ast::{helpers::ReturnStatementVisitor, statement_visitor::StatementVisitor}; +use ruff_python_ast::helpers::ReturnStatementVisitor; +use ruff_python_ast::visitor::Visitor; +use ruff_python_ast::Stmt; use ruff_python_semantic::analyze::type_inference::{PythonType, ResolvedPythonType}; use ruff_text_size::Ranged; diff --git a/crates/ruff_linter/src/rules/pylint/rules/too_many_return_statements.rs b/crates/ruff_linter/src/rules/pylint/rules/too_many_return_statements.rs index 8fc7fe16e6134..dc84eac2f406f 100644 --- a/crates/ruff_linter/src/rules/pylint/rules/too_many_return_statements.rs +++ b/crates/ruff_linter/src/rules/pylint/rules/too_many_return_statements.rs @@ -4,7 +4,7 @@ use ruff_diagnostics::{Diagnostic, Violation}; use ruff_macros::{derive_message_formats, violation}; use ruff_python_ast::helpers::ReturnStatementVisitor; use ruff_python_ast::identifier::Identifier; -use ruff_python_ast::statement_visitor::StatementVisitor; +use ruff_python_ast::visitor::Visitor; /// ## What it does /// Checks for functions or methods with too many return statements. diff --git a/crates/ruff_linter/src/rules/pylint/rules/useless_return.rs b/crates/ruff_linter/src/rules/pylint/rules/useless_return.rs index f9ba04041ccfa..6d51bbd4e0d5a 100644 --- a/crates/ruff_linter/src/rules/pylint/rules/useless_return.rs +++ b/crates/ruff_linter/src/rules/pylint/rules/useless_return.rs @@ -1,9 +1,8 @@ -use ruff_python_ast::{self as ast, Expr, Stmt}; - use ruff_diagnostics::{AlwaysFixableViolation, Diagnostic, Fix}; use ruff_macros::{derive_message_formats, violation}; use ruff_python_ast::helpers::ReturnStatementVisitor; -use ruff_python_ast::statement_visitor::StatementVisitor; +use ruff_python_ast::visitor::Visitor; +use ruff_python_ast::{self as ast, Expr, Stmt}; use ruff_text_size::Ranged; use crate::checkers::ast::Checker; @@ -60,9 +59,11 @@ pub(crate) fn useless_return( // Skip empty functions. return; }; - if !last_stmt.is_return_stmt() { + + // Verify that the last statement is a return statement. + let Stmt::Return(ast::StmtReturn { value, range: _ }) = &last_stmt else { return; - } + }; // Skip functions that consist of a single return statement. if body.len() == 1 { @@ -78,11 +79,6 @@ pub(crate) fn useless_return( } } - // Verify that the last statement is a return statement. - let Stmt::Return(ast::StmtReturn { value, range: _ }) = &last_stmt else { - return; - }; - // Verify that the return statement is either bare or returns `None`. if !value .as_ref() diff --git a/crates/ruff_linter/src/rules/pyupgrade/rules/use_pep604_annotation.rs b/crates/ruff_linter/src/rules/pyupgrade/rules/use_pep604_annotation.rs index 1823e116d7fec..cb8a9600dccfa 100644 --- a/crates/ruff_linter/src/rules/pyupgrade/rules/use_pep604_annotation.rs +++ b/crates/ruff_linter/src/rules/pyupgrade/rules/use_pep604_annotation.rs @@ -1,8 +1,9 @@ use ruff_diagnostics::{Diagnostic, Edit, Fix, FixAvailability, Violation}; use ruff_macros::{derive_message_formats, violation}; -use ruff_python_ast::{self as ast, Expr, ExprContext, Operator}; +use ruff_python_ast::helpers::{pep_604_optional, pep_604_union}; +use ruff_python_ast::{self as ast, Expr}; use ruff_python_semantic::analyze::typing::Pep604Operator; -use ruff_text_size::{Ranged, TextRange}; +use ruff_text_size::Ranged; use crate::checkers::ast::Checker; use crate::fix::edits::pad; @@ -80,7 +81,7 @@ pub(crate) fn use_pep604_annotation( _ => { diagnostic.set_fix(Fix::unsafe_edit(Edit::range_replacement( pad( - checker.generator().expr(&optional(slice)), + checker.generator().expr(&pep_604_optional(slice)), expr.range(), checker.locator(), ), @@ -101,7 +102,7 @@ pub(crate) fn use_pep604_annotation( Expr::Tuple(ast::ExprTuple { elts, .. }) => { diagnostic.set_fix(Fix::unsafe_edit(Edit::range_replacement( pad( - checker.generator().expr(&union(elts)), + checker.generator().expr(&pep_604_union(elts)), expr.range(), checker.locator(), ), @@ -126,36 +127,6 @@ pub(crate) fn use_pep604_annotation( } } -/// Format the expression as a PEP 604-style optional. -fn optional(expr: &Expr) -> Expr { - ast::ExprBinOp { - left: Box::new(expr.clone()), - op: Operator::BitOr, - right: Box::new(Expr::NoneLiteral(ast::ExprNoneLiteral::default())), - range: TextRange::default(), - } - .into() -} - -/// Format the expressions as a PEP 604-style union. -fn union(elts: &[Expr]) -> Expr { - match elts { - [] => Expr::Tuple(ast::ExprTuple { - elts: vec![], - ctx: ExprContext::Load, - range: TextRange::default(), - }), - [Expr::Tuple(ast::ExprTuple { elts, .. })] => union(elts), - [elt] => elt.clone(), - [rest @ .., elt] => Expr::BinOp(ast::ExprBinOp { - left: Box::new(union(rest)), - op: Operator::BitOr, - right: Box::new(union(&[elt.clone()])), - range: TextRange::default(), - }), - } -} - /// Returns `true` if the expression is valid for use in a bitwise union (e.g., `X | Y`). Returns /// `false` for lambdas, yield expressions, and other expressions that are invalid in such a /// context. diff --git a/crates/ruff_linter/src/rules/pyupgrade/rules/use_pep604_isinstance.rs b/crates/ruff_linter/src/rules/pyupgrade/rules/use_pep604_isinstance.rs index 550987f839932..eb8f3dd3090ff 100644 --- a/crates/ruff_linter/src/rules/pyupgrade/rules/use_pep604_isinstance.rs +++ b/crates/ruff_linter/src/rules/pyupgrade/rules/use_pep604_isinstance.rs @@ -2,8 +2,9 @@ use std::fmt; use ruff_diagnostics::{AlwaysFixableViolation, Diagnostic, Edit, Fix}; use ruff_macros::{derive_message_formats, violation}; -use ruff_python_ast::{self as ast, Expr, Operator}; -use ruff_text_size::{Ranged, TextRange}; +use ruff_python_ast::helpers::pep_604_union; +use ruff_python_ast::{self as ast, Expr}; +use ruff_text_size::Ranged; use crate::checkers::ast::Checker; @@ -79,19 +80,6 @@ impl AlwaysFixableViolation for NonPEP604Isinstance { } } -fn union(elts: &[Expr]) -> Expr { - if elts.len() == 1 { - elts[0].clone() - } else { - Expr::BinOp(ast::ExprBinOp { - left: Box::new(union(&elts[..elts.len() - 1])), - op: Operator::BitOr, - right: Box::new(elts[elts.len() - 1].clone()), - range: TextRange::default(), - }) - } -} - /// UP038 pub(crate) fn use_pep604_isinstance( checker: &mut Checker, @@ -120,7 +108,7 @@ pub(crate) fn use_pep604_isinstance( let mut diagnostic = Diagnostic::new(NonPEP604Isinstance { kind }, expr.range()); diagnostic.set_fix(Fix::unsafe_edit(Edit::range_replacement( - checker.generator().expr(&union(elts)), + checker.generator().expr(&pep_604_union(elts)), types.range(), ))); checker.diagnostics.push(diagnostic); diff --git a/crates/ruff_python_ast/src/helpers.rs b/crates/ruff_python_ast/src/helpers.rs index 2e9a13451c9f8..24b63669c1e21 100644 --- a/crates/ruff_python_ast/src/helpers.rs +++ b/crates/ruff_python_ast/src/helpers.rs @@ -1,20 +1,21 @@ use std::borrow::Cow; use std::path::Path; -use ruff_python_trivia::CommentRanges; -use ruff_source_file::Locator; use smallvec::SmallVec; +use ruff_python_trivia::CommentRanges; +use ruff_source_file::Locator; use ruff_text_size::{Ranged, TextRange}; use crate::call_path::CallPath; use crate::parenthesize::parenthesized_range; -use crate::statement_visitor::{walk_body, walk_stmt, StatementVisitor}; +use crate::statement_visitor::StatementVisitor; use crate::visitor::Visitor; -use crate::AnyNodeRef; use crate::{ - self as ast, Arguments, CmpOp, ExceptHandler, Expr, MatchCase, Pattern, Stmt, TypeParam, + self as ast, Arguments, CmpOp, ExceptHandler, Expr, MatchCase, Operator, Pattern, Stmt, + TypeParam, }; +use crate::{AnyNodeRef, ExprContext}; /// Return `true` if the `Stmt` is a compound statement (as opposed to a simple statement). pub const fn is_compound_statement(stmt: &Stmt) -> bool { @@ -882,9 +883,10 @@ pub fn resolve_imported_module_path<'a>( #[derive(Default)] pub struct ReturnStatementVisitor<'a> { pub returns: Vec<&'a ast::StmtReturn>, + pub is_generator: bool, } -impl<'a, 'b> StatementVisitor<'b> for ReturnStatementVisitor<'a> +impl<'a, 'b> Visitor<'b> for ReturnStatementVisitor<'a> where 'b: 'a, { @@ -894,7 +896,15 @@ where // Don't recurse. } Stmt::Return(stmt) => self.returns.push(stmt), - _ => walk_stmt(self, stmt), + _ => crate::visitor::walk_stmt(self, stmt), + } + } + + fn visit_expr(&mut self, expr: &'b Expr) { + if let Expr::Yield(_) | Expr::YieldFrom(_) = expr { + self.is_generator = true; + } else { + crate::visitor::walk_expr(self, expr); } } } @@ -925,7 +935,7 @@ where elif_else_clauses, .. }) => { - walk_body(self, body); + crate::statement_visitor::walk_body(self, body); for clause in elif_else_clauses { self.visit_elif_else_clause(clause); } @@ -933,11 +943,11 @@ where Stmt::While(ast::StmtWhile { body, .. }) | Stmt::With(ast::StmtWith { body, .. }) | Stmt::For(ast::StmtFor { body, .. }) => { - walk_body(self, body); + crate::statement_visitor::walk_body(self, body); } Stmt::Match(ast::StmtMatch { cases, .. }) => { for case in cases { - walk_body(self, &case.body); + crate::statement_visitor::walk_body(self, &case.body); } } _ => {} @@ -1248,6 +1258,36 @@ pub fn generate_comparison( contents } +/// Format the expression as a PEP 604-style optional. +pub fn pep_604_optional(expr: &Expr) -> Expr { + ast::ExprBinOp { + left: Box::new(expr.clone()), + op: Operator::BitOr, + right: Box::new(Expr::NoneLiteral(ast::ExprNoneLiteral::default())), + range: TextRange::default(), + } + .into() +} + +/// Format the expressions as a PEP 604-style union. +pub fn pep_604_union(elts: &[Expr]) -> Expr { + match elts { + [] => Expr::Tuple(ast::ExprTuple { + elts: vec![], + ctx: ExprContext::Load, + range: TextRange::default(), + }), + [Expr::Tuple(ast::ExprTuple { elts, .. })] => pep_604_union(elts), + [elt] => elt.clone(), + [rest @ .., elt] => Expr::BinOp(ast::ExprBinOp { + left: Box::new(pep_604_union(rest)), + op: Operator::BitOr, + right: Box::new(pep_604_union(&[elt.clone()])), + range: TextRange::default(), + }), + } +} + #[cfg(test)] mod tests { use std::borrow::Cow; diff --git a/crates/ruff_python_semantic/src/analyze/type_inference.rs b/crates/ruff_python_semantic/src/analyze/type_inference.rs index 1ef88f57c087d..f5261cd683664 100644 --- a/crates/ruff_python_semantic/src/analyze/type_inference.rs +++ b/crates/ruff_python_semantic/src/analyze/type_inference.rs @@ -24,22 +24,41 @@ impl ResolvedPythonType { (Self::TypeError, _) | (_, Self::TypeError) => Self::TypeError, (Self::Unknown, _) | (_, Self::Unknown) => Self::Unknown, (Self::Atom(a), Self::Atom(b)) => { - if a == b { + if a.is_subtype_of(b) { + Self::Atom(b) + } else if b.is_subtype_of(a) { Self::Atom(a) } else { Self::Union(FxHashSet::from_iter([a, b])) } } (Self::Atom(a), Self::Union(mut b)) => { - b.insert(a); + // If `a` is a subtype of any of the types in `b`, then `a` is + // redundant. + if !b.iter().any(|b_element| a.is_subtype_of(*b_element)) { + b.insert(a); + } Self::Union(b) } (Self::Union(mut a), Self::Atom(b)) => { - a.insert(b); + // If `b` is a subtype of any of the types in `a`, then `b` is + // redundant. + if !a.iter().any(|a_element| b.is_subtype_of(*a_element)) { + a.insert(b); + } Self::Union(a) } (Self::Union(mut a), Self::Union(b)) => { - a.extend(b); + for b_element in b { + // If `b_element` is a subtype of any of the types in `a`, then + // `b_element` is redundant. + if !a + .iter() + .any(|a_element| b_element.is_subtype_of(*a_element)) + { + a.insert(b_element); + } + } Self::Union(a) } } @@ -321,7 +340,7 @@ impl From<&Expr> for ResolvedPythonType { /// such as strings, integers, floats, and containers. It cannot infer the /// types of variables or expressions that are not statically known from /// individual AST nodes alone. -#[derive(Debug, Copy, Clone, PartialEq, Eq, Hash)] +#[derive(Debug, Copy, Clone, PartialEq, Eq, PartialOrd, Ord, Hash)] pub enum PythonType { /// A string literal, such as `"hello"`. String, @@ -345,8 +364,48 @@ pub enum PythonType { Generator, } +impl PythonType { + /// Returns `true` if `self` is a subtype of `other`. + fn is_subtype_of(self, other: Self) -> bool { + match (self, other) { + (PythonType::String, PythonType::String) => true, + (PythonType::Bytes, PythonType::Bytes) => true, + (PythonType::None, PythonType::None) => true, + (PythonType::Ellipsis, PythonType::Ellipsis) => true, + // The Numeric Tower (https://peps.python.org/pep-3141/) + (PythonType::Number(NumberLike::Bool), PythonType::Number(NumberLike::Bool)) => true, + (PythonType::Number(NumberLike::Integer), PythonType::Number(NumberLike::Integer)) => { + true + } + (PythonType::Number(NumberLike::Float), PythonType::Number(NumberLike::Float)) => true, + (PythonType::Number(NumberLike::Complex), PythonType::Number(NumberLike::Complex)) => { + true + } + (PythonType::Number(NumberLike::Bool), PythonType::Number(NumberLike::Integer)) => true, + (PythonType::Number(NumberLike::Bool), PythonType::Number(NumberLike::Float)) => true, + (PythonType::Number(NumberLike::Bool), PythonType::Number(NumberLike::Complex)) => true, + (PythonType::Number(NumberLike::Integer), PythonType::Number(NumberLike::Float)) => { + true + } + (PythonType::Number(NumberLike::Integer), PythonType::Number(NumberLike::Complex)) => { + true + } + (PythonType::Number(NumberLike::Float), PythonType::Number(NumberLike::Complex)) => { + true + } + // This simple type hierarchy doesn't support generics. + (PythonType::Dict, PythonType::Dict) => true, + (PythonType::List, PythonType::List) => true, + (PythonType::Set, PythonType::Set) => true, + (PythonType::Tuple, PythonType::Tuple) => true, + (PythonType::Generator, PythonType::Generator) => true, + _ => false, + } + } +} + /// A numeric type, or a type that can be trivially coerced to a numeric type. -#[derive(Debug, Copy, Clone, PartialEq, Eq, Hash)] +#[derive(Debug, Copy, Clone, PartialEq, Eq, PartialOrd, Ord, Hash)] pub enum NumberLike { /// An integer literal, such as `1` or `0x1`. Integer, @@ -372,8 +431,6 @@ impl NumberLike { #[cfg(test)] mod tests { - use rustc_hash::FxHashSet; - use ruff_python_ast::Expr; use ruff_python_parser::parse_expression; @@ -410,10 +467,7 @@ mod tests { ); assert_eq!( ResolvedPythonType::from(&parse("1 and True")), - ResolvedPythonType::Union(FxHashSet::from_iter([ - PythonType::Number(NumberLike::Integer), - PythonType::Number(NumberLike::Bool) - ])) + ResolvedPythonType::Atom(PythonType::Number(NumberLike::Integer)) ); // Binary operators. @@ -475,17 +529,11 @@ mod tests { ); assert_eq!( ResolvedPythonType::from(&parse("1 if True else 2.0")), - ResolvedPythonType::Union(FxHashSet::from_iter([ - PythonType::Number(NumberLike::Integer), - PythonType::Number(NumberLike::Float) - ])) + ResolvedPythonType::Atom(PythonType::Number(NumberLike::Float)) ); assert_eq!( ResolvedPythonType::from(&parse("1 if True else False")), - ResolvedPythonType::Union(FxHashSet::from_iter([ - PythonType::Number(NumberLike::Integer), - PythonType::Number(NumberLike::Bool) - ])) + ResolvedPythonType::Atom(PythonType::Number(NumberLike::Integer)) ); } }