From ee47848782a47d8261d773f446e2b7021e7a7fb5 Mon Sep 17 00:00:00 2001 From: Shunsuke Shibayama Date: Sun, 13 Apr 2025 14:20:20 +0900 Subject: [PATCH 001/105] [red-knot] infer function's return type --- .../resources/mdtest/function/return_type.md | 37 ++++++++++++ crates/ty_python_semantic/src/types.rs | 29 +++++++++- .../ty_python_semantic/src/types/call/bind.rs | 7 ++- crates/ty_python_semantic/src/types/infer.rs | 58 +++++++++++++++---- 4 files changed, 117 insertions(+), 14 deletions(-) diff --git a/crates/ty_python_semantic/resources/mdtest/function/return_type.md b/crates/ty_python_semantic/resources/mdtest/function/return_type.md index c3589b2c626d8..b315f1f40bbbc 100644 --- a/crates/ty_python_semantic/resources/mdtest/function/return_type.md +++ b/crates/ty_python_semantic/resources/mdtest/function/return_type.md @@ -182,6 +182,43 @@ def f(cond: bool) -> int: return 2 ``` +## Inferred return type + +If a function's return type is not annotated, it is inferred. The inferred type is the union of all +possible return types. + +```py +def f(): + return 1 + +reveal_type(f()) # revealed: Literal[1] + +def g(cond: bool): + if cond: + return 1 + else: + return "a" + +reveal_type(g(True)) # revealed: Literal[1, "a"] + +# This function implicitly returns `None`. +def h(x: int, y: str): + if x > 10: + return x + elif x > 5: + return y + +reveal_type(h(1, "a")) # revealed: int | str | None + +def generator(): + yield 1 + yield 2 + return None + +# TODO: Should be `Generator[Literal[1, 2], Any, None]` +reveal_type(generator()) # revealed: None +``` + ## Invalid return type diff --git a/crates/ty_python_semantic/src/types.rs b/crates/ty_python_semantic/src/types.rs index 9dc4b4cc6d366..8d0d86f630788 100644 --- a/crates/ty_python_semantic/src/types.rs +++ b/crates/ty_python_semantic/src/types.rs @@ -4208,6 +4208,14 @@ impl<'db> Type<'db> { } } + /// Returns the inferred return type of `self` if it is a function literal. + fn inferred_return_type(self, db: &'db dyn Db) -> Option> { + match self { + Type::FunctionLiteral(function_type) => Some(function_type.inferred_return_type(db)), + _ => None, + } + } + /// Calls `self`. Returns a [`CallError`] if `self` is (always or possibly) not callable, or if /// the arguments are not compatible with the formal parameters. /// @@ -4220,7 +4228,9 @@ impl<'db> Type<'db> { argument_types: &CallArgumentTypes<'_, 'db>, ) -> Result, CallError<'db>> { let signatures = self.signatures(db); - Bindings::match_parameters(signatures, argument_types).check_types(db, argument_types) + let inferred_return_ty = || self.inferred_return_type(db).unwrap_or(Type::unknown()); + Bindings::match_parameters(signatures, inferred_return_ty, argument_types) + .check_types(db, argument_types) } /// Look up a dunder method on the meta-type of `self` and call it. @@ -4258,8 +4268,14 @@ impl<'db> Type<'db> { { Symbol::Type(dunder_callable, boundness) => { let signatures = dunder_callable.signatures(db); - let bindings = Bindings::match_parameters(signatures, argument_types) - .check_types(db, argument_types)?; + let inferred_return_ty = || { + dunder_callable + .inferred_return_type(db) + .unwrap_or(Type::unknown()) + }; + let bindings = + Bindings::match_parameters(signatures, inferred_return_ty, argument_types) + .check_types(db, argument_types)?; if boundness == Boundness::PossiblyUnbound { return Err(CallDunderError::PossiblyUnbound(Box::new(bindings))); } @@ -6785,6 +6801,13 @@ impl<'db> FunctionType<'db> { signature } + /// Infers this function scope's types and returns the inferred return type. + fn inferred_return_type(self, db: &'db dyn Db) -> Type<'db> { + let scope = self.body_scope(db); + let inference = infer_scope_types(db, scope); + inference.inferred_return_type(db) + } + pub(crate) fn is_known(self, db: &'db dyn Db, known_function: KnownFunction) -> bool { self.known(db) == Some(known_function) } diff --git a/crates/ty_python_semantic/src/types/call/bind.rs b/crates/ty_python_semantic/src/types/call/bind.rs index d20c5e2faaf2d..99916eb58c6fe 100644 --- a/crates/ty_python_semantic/src/types/call/bind.rs +++ b/crates/ty_python_semantic/src/types/call/bind.rs @@ -57,6 +57,7 @@ impl<'db> Bindings<'db> { /// verify that each argument type is assignable to the corresponding parameter type. pub(crate) fn match_parameters( signatures: Signatures<'db>, + inferred_return_ty: impl Fn() -> Type<'db> + Copy, arguments: &CallArguments<'_>, ) -> Self { let mut argument_forms = vec![None; arguments.len()]; @@ -69,6 +70,7 @@ impl<'db> Bindings<'db> { arguments, &mut argument_forms, &mut conflicting_forms, + inferred_return_ty, ) }) .collect(); @@ -946,6 +948,7 @@ impl<'db> CallableBinding<'db> { arguments: &CallArguments<'_>, argument_forms: &mut [Option], conflicting_forms: &mut [bool], + inferred_return_ty: impl Fn() -> Type<'db> + Copy, ) -> Self { // If this callable is a bound method, prepend the self instance onto the arguments list // before checking. @@ -965,6 +968,7 @@ impl<'db> CallableBinding<'db> { arguments.as_ref(), argument_forms, conflicting_forms, + inferred_return_ty, ) }) .collect(); @@ -1148,6 +1152,7 @@ impl<'db> Binding<'db> { arguments: &CallArguments<'_>, argument_forms: &mut [Option], conflicting_forms: &mut [bool], + inferred_return_ty: impl Fn() -> Type<'db>, ) -> Self { let parameters = signature.parameters(); // The parameter that each argument is matched with. @@ -1264,7 +1269,7 @@ impl<'db> Binding<'db> { } Self { - return_ty: signature.return_ty.unwrap_or(Type::unknown()), + return_ty: signature.return_ty.unwrap_or_else(inferred_return_ty), specialization: None, inherited_specialization: None, argument_parameters: argument_parameters.into_boxed_slice(), diff --git a/crates/ty_python_semantic/src/types/infer.rs b/crates/ty_python_semantic/src/types/infer.rs index 43c384c84f43a..073d4271d4b9a 100644 --- a/crates/ty_python_semantic/src/types/infer.rs +++ b/crates/ty_python_semantic/src/types/infer.rs @@ -58,7 +58,7 @@ use crate::semantic_index::narrowing_constraints::ConstraintKey; use crate::semantic_index::symbol::{ FileScopeId, NodeWithScopeKind, NodeWithScopeRef, ScopeId, ScopeKind, ScopedSymbolId, }; -use crate::semantic_index::{semantic_index, EagerSnapshotResult, SemanticIndex}; +use crate::semantic_index::{semantic_index, use_def_map, EagerSnapshotResult, SemanticIndex}; use crate::symbol::{ builtins_module_scope, builtins_symbol, explicit_global_symbol, module_type_implicit_global_symbol, symbol, symbol_from_bindings, symbol_from_declarations, @@ -394,6 +394,11 @@ pub(crate) struct TypeInference<'db> { /// The scope this region is part of. scope: ScopeId<'db>, + /// The returned types of this region (if this is a function body). + /// + /// These are stored in `Vec` to delay the creation of the union type as long as possible. + return_types: Vec>, + /// The fallback type for missing expressions/bindings/declarations. /// /// This is used only when constructing a cycle-recovery `TypeInference`. @@ -409,6 +414,7 @@ impl<'db> TypeInference<'db> { deferred: FxHashSet::default(), diagnostics: TypeCheckDiagnostics::default(), scope, + return_types: vec![], cycle_fallback_type: None, } } @@ -421,6 +427,7 @@ impl<'db> TypeInference<'db> { deferred: FxHashSet::default(), diagnostics: TypeCheckDiagnostics::default(), scope, + return_types: vec![], cycle_fallback_type: Some(cycle_fallback_type), } } @@ -468,12 +475,27 @@ impl<'db> TypeInference<'db> { &self.diagnostics } + /// Returns the inferred return type of this function body (union of all possible return types), + /// or `None` if the region is not a function body. + pub(crate) fn inferred_return_type(&self, db: &'db dyn Db) -> Type<'db> { + let mut union = UnionBuilder::new(db); + for ty in &self.return_types { + union = union.add(*ty); + } + let use_def = use_def_map(db, self.scope); + if use_def.can_implicit_return(db) { + union = union.add(Type::none(db)); + } + union.build() + } + fn shrink_to_fit(&mut self) { self.expressions.shrink_to_fit(); self.bindings.shrink_to_fit(); self.declarations.shrink_to_fit(); self.diagnostics.shrink_to_fit(); self.deferred.shrink_to_fit(); + self.return_types.shrink_to_fit(); } } @@ -4761,7 +4783,12 @@ impl<'db> TypeInferenceBuilder<'db> { } let signatures = callable_type.signatures(self.db()); - let bindings = Bindings::match_parameters(signatures, &call_arguments); + let inferred_return_ty = || { + callable_type + .inferred_return_type(self.db()) + .unwrap_or(Type::unknown()) + }; + let bindings = Bindings::match_parameters(signatures, inferred_return_ty, &call_arguments); let call_argument_types = self.infer_argument_types(arguments, call_arguments, &bindings.argument_forms); @@ -6940,15 +6967,21 @@ impl<'db> TypeInferenceBuilder<'db> { value_ty, generic_context.signature(self.db()), )); - let bindings = match Bindings::match_parameters(signatures, &call_argument_types) - .check_types(self.db(), &call_argument_types) - { - Ok(bindings) => bindings, - Err(CallError(_, bindings)) => { - bindings.report_diagnostics(&self.context, subscript.into()); - return Type::unknown(); - } + let inferred_return_ty = || { + value_ty + .inferred_return_type(self.db()) + .unwrap_or(Type::unknown()) }; + let bindings = + match Bindings::match_parameters(signatures, inferred_return_ty, &call_argument_types) + .check_types(self.db(), &call_argument_types) + { + Ok(bindings) => bindings, + Err(CallError(_, bindings)) => { + bindings.report_diagnostics(&self.context, subscript.into()); + return Type::unknown(); + } + }; let callable = bindings .into_iter() .next() @@ -7371,6 +7404,11 @@ impl<'db> TypeInferenceBuilder<'db> { self.infer_region(); self.types.diagnostics = self.context.finish(); self.types.shrink_to_fit(); + self.types.return_types = self + .return_types_and_ranges + .into_iter() + .map(|ty_range| ty_range.ty) + .collect(); self.types } } From dafae0068e556256d55264fe4ab230a8c5a87e68 Mon Sep 17 00:00:00 2001 From: Shunsuke Shibayama Date: Fri, 9 May 2025 02:20:17 +0900 Subject: [PATCH 002/105] Do not infer the return type of a recursive function --- crates/ty_python_semantic/src/types.rs | 58 +++++++++++++---- .../ty_python_semantic/src/types/call/bind.rs | 14 +++- crates/ty_python_semantic/src/types/class.rs | 2 +- crates/ty_python_semantic/src/types/infer.rs | 64 ++++++++++++++----- 4 files changed, 104 insertions(+), 34 deletions(-) diff --git a/crates/ty_python_semantic/src/types.rs b/crates/ty_python_semantic/src/types.rs index 8d0d86f630788..517c9e3e73201 100644 --- a/crates/ty_python_semantic/src/types.rs +++ b/crates/ty_python_semantic/src/types.rs @@ -2659,7 +2659,11 @@ impl<'db> Type<'db> { if let Symbol::Type(descr_get, descr_get_boundness) = descr_get { let return_ty = descr_get - .try_call(db, &CallArgumentTypes::positional([self, instance, owner])) + .try_call( + db, + &CallArgumentTypes::positional([self, instance, owner]), + None, + ) .map(|bindings| { if descr_get_boundness == Boundness::Bound { bindings.return_type(db) @@ -3077,6 +3081,7 @@ impl<'db> Type<'db> { CallArgumentTypes::positional([Type::StringLiteral( StringLiteralType::new(db, Box::from(name.as_str())), )]), + None, ) .map(|outcome| Symbol::bound(outcome.return_type(db))) // TODO: Handle call errors here. @@ -3195,7 +3200,7 @@ impl<'db> Type<'db> { // runtime there is a fallback to `__len__`, since `__bool__` takes precedence // and a subclass could add a `__bool__` method. - match self.try_call_dunder(db, "__bool__", CallArgumentTypes::none()) { + match self.try_call_dunder(db, "__bool__", CallArgumentTypes::none(), None) { Ok(outcome) => { let return_type = outcome.return_type(db); if !return_type.is_assignable_to(db, KnownClass::Bool.to_instance(db)) { @@ -3386,7 +3391,7 @@ impl<'db> Type<'db> { return usize_len.try_into().ok().map(Type::IntLiteral); } - let return_ty = match self.try_call_dunder(db, "__len__", CallArgumentTypes::none()) { + let return_ty = match self.try_call_dunder(db, "__len__", CallArgumentTypes::none(), None) { Ok(bindings) => bindings.return_type(db), Err(CallDunderError::PossiblyUnbound(bindings)) => bindings.return_type(db), @@ -4209,9 +4214,20 @@ impl<'db> Type<'db> { } /// Returns the inferred return type of `self` if it is a function literal. - fn inferred_return_type(self, db: &'db dyn Db) -> Option> { + fn inferred_return_type( + self, + db: &'db dyn Db, + call_scope: Option, + ) -> Option> { match self { - Type::FunctionLiteral(function_type) => Some(function_type.inferred_return_type(db)), + Type::FunctionLiteral(function_type) + if call_scope.is_some_and(|call_scope| { + !function_type.file(db).is_stub(db) + && call_scope != function_type.body_scope(db) + }) => + { + Some(function_type.inferred_return_type(db)) + } _ => None, } } @@ -4226,9 +4242,13 @@ impl<'db> Type<'db> { self, db: &'db dyn Db, argument_types: &CallArgumentTypes<'_, 'db>, + call_scope: Option, ) -> Result, CallError<'db>> { let signatures = self.signatures(db); - let inferred_return_ty = || self.inferred_return_type(db).unwrap_or(Type::unknown()); + let inferred_return_ty = || { + self.inferred_return_type(db, call_scope) + .unwrap_or(Type::unknown()) + }; Bindings::match_parameters(signatures, inferred_return_ty, argument_types) .check_types(db, argument_types) } @@ -4242,12 +4262,14 @@ impl<'db> Type<'db> { db: &'db dyn Db, name: &str, mut argument_types: CallArgumentTypes<'_, 'db>, + call_scope: Option, ) -> Result, CallDunderError<'db>> { self.try_call_dunder_with_policy( db, name, &mut argument_types, MemberLookupPolicy::NO_INSTANCE_FALLBACK, + call_scope, ) } @@ -4261,6 +4283,7 @@ impl<'db> Type<'db> { name: &str, argument_types: &mut CallArgumentTypes<'_, 'db>, policy: MemberLookupPolicy, + call_scope: Option, ) -> Result, CallDunderError<'db>> { match self .member_lookup_with_policy(db, name.into(), policy) @@ -4270,7 +4293,7 @@ impl<'db> Type<'db> { let signatures = dunder_callable.signatures(db); let inferred_return_ty = || { dunder_callable - .inferred_return_type(db) + .inferred_return_type(db, call_scope) .unwrap_or(Type::unknown()) }; let bindings = @@ -4318,18 +4341,19 @@ impl<'db> Type<'db> { db, "__getitem__", CallArgumentTypes::positional([KnownClass::Int.to_instance(db)]), + None, ) .map(|dunder_getitem_outcome| dunder_getitem_outcome.return_type(db)) }; let try_call_dunder_next_on_iterator = |iterator: Type<'db>| { iterator - .try_call_dunder(db, "__next__", CallArgumentTypes::none()) + .try_call_dunder(db, "__next__", CallArgumentTypes::none(), None) .map(|dunder_next_outcome| dunder_next_outcome.return_type(db)) }; let dunder_iter_result = self - .try_call_dunder(db, "__iter__", CallArgumentTypes::none()) + .try_call_dunder(db, "__iter__", CallArgumentTypes::none(), None) .map(|dunder_iter_outcome| dunder_iter_outcome.return_type(db)); match dunder_iter_result { @@ -4413,11 +4437,12 @@ impl<'db> Type<'db> { /// pass /// ``` fn try_enter(self, db: &'db dyn Db) -> Result, ContextManagerError<'db>> { - let enter = self.try_call_dunder(db, "__enter__", CallArgumentTypes::none()); + let enter = self.try_call_dunder(db, "__enter__", CallArgumentTypes::none(), None); let exit = self.try_call_dunder( db, "__exit__", CallArgumentTypes::positional([Type::none(db), Type::none(db), Type::none(db)]), + None, ); // TODO: Make use of Protocols when we support it (the manager be assignable to `contextlib.AbstractContextManager`). @@ -4453,6 +4478,7 @@ impl<'db> Type<'db> { self, db: &'db dyn Db, argument_types: CallArgumentTypes<'_, 'db>, + call_scope: Option, ) -> Result, ConstructorCallError<'db>> { debug_assert!(matches!( self, @@ -4515,8 +4541,11 @@ impl<'db> Type<'db> { let new_call_outcome = new_method.and_then(|new_method| { match new_method.symbol.try_call_dunder_get(db, self_type) { Symbol::Type(new_method, boundness) => { - let result = - new_method.try_call(db, argument_types.with_self(Some(self_type)).as_ref()); + let result = new_method.try_call( + db, + argument_types.with_self(Some(self_type)).as_ref(), + call_scope, + ); if boundness == Boundness::PossiblyUnbound { Some(Err(DunderNewCallError::PossiblyUnbound(result.err()))) } else { @@ -4544,7 +4573,7 @@ impl<'db> Type<'db> { .symbol .is_unbound() { - Some(init_ty.try_call_dunder(db, "__init__", argument_types)) + Some(init_ty.try_call_dunder(db, "__init__", argument_types, call_scope)) } else { None }; @@ -5911,7 +5940,7 @@ impl<'db> IterationError<'db> { Self::IterCallError(_, dunder_iter_bindings) => dunder_iter_bindings .return_type(db) - .try_call_dunder(db, "__next__", CallArgumentTypes::none()) + .try_call_dunder(db, "__next__", CallArgumentTypes::none(), None) .map(|dunder_next_outcome| Some(dunder_next_outcome.return_type(db))) .unwrap_or_else(|dunder_next_call_error| dunder_next_call_error.return_type(db)), @@ -6802,6 +6831,7 @@ impl<'db> FunctionType<'db> { } /// Infers this function scope's types and returns the inferred return type. + /// Do not call this method within a function scope itself, i.e. check that the function is not recursive. fn inferred_return_type(self, db: &'db dyn Db) -> Type<'db> { let scope = self.body_scope(db); let inference = infer_scope_types(db, scope); diff --git a/crates/ty_python_semantic/src/types/call/bind.rs b/crates/ty_python_semantic/src/types/call/bind.rs index 99916eb58c6fe..4e76333ac2aa2 100644 --- a/crates/ty_python_semantic/src/types/call/bind.rs +++ b/crates/ty_python_semantic/src/types/call/bind.rs @@ -362,7 +362,11 @@ impl<'db> Bindings<'db> { [Some(Type::PropertyInstance(property)), Some(instance), ..] => { if let Some(getter) = property.getter(db) { if let Ok(return_ty) = getter - .try_call(db, &CallArgumentTypes::positional([*instance])) + .try_call( + db, + &CallArgumentTypes::positional([*instance]), + None, + ) .map(|binding| binding.return_type(db)) { overload.set_return_type(return_ty); @@ -391,7 +395,11 @@ impl<'db> Bindings<'db> { [Some(instance), ..] => { if let Some(getter) = property.getter(db) { if let Ok(return_ty) = getter - .try_call(db, &CallArgumentTypes::positional([*instance])) + .try_call( + db, + &CallArgumentTypes::positional([*instance]), + None, + ) .map(|binding| binding.return_type(db)) { overload.set_return_type(return_ty); @@ -420,6 +428,7 @@ impl<'db> Bindings<'db> { if let Err(_call_error) = setter.try_call( db, &CallArgumentTypes::positional([*instance, *value]), + None, ) { overload.errors.push(BindingError::InternalCallError( "calling the setter failed", @@ -439,6 +448,7 @@ impl<'db> Bindings<'db> { if let Err(_call_error) = setter.try_call( db, &CallArgumentTypes::positional([*instance, *value]), + None, ) { overload.errors.push(BindingError::InternalCallError( "calling the setter failed", diff --git a/crates/ty_python_semantic/src/types/class.rs b/crates/ty_python_semantic/src/types/class.rs index e99844a9d8a9b..b176d4ddc8ff9 100644 --- a/crates/ty_python_semantic/src/types/class.rs +++ b/crates/ty_python_semantic/src/types/class.rs @@ -878,7 +878,7 @@ impl<'db> ClassLiteral<'db> { // TODO: Other keyword arguments? let arguments = CallArgumentTypes::positional([name, bases, namespace]); - let return_ty_result = match metaclass.try_call(db, &arguments) { + let return_ty_result = match metaclass.try_call(db, &arguments, None) { Ok(bindings) => Ok(bindings.return_type(db)), Err(CallError(CallErrorKind::NotCallable, bindings)) => Err(MetaclassError { diff --git a/crates/ty_python_semantic/src/types/infer.rs b/crates/ty_python_semantic/src/types/infer.rs index 073d4271d4b9a..6750595be83bc 100644 --- a/crates/ty_python_semantic/src/types/infer.rs +++ b/crates/ty_python_semantic/src/types/infer.rs @@ -1964,7 +1964,11 @@ impl<'db> TypeInferenceBuilder<'db> { for (decorator_ty, decorator_node) in decorator_types_and_nodes.iter().rev() { inferred_ty = match decorator_ty - .try_call(self.db(), &CallArgumentTypes::positional([inferred_ty])) + .try_call( + self.db(), + &CallArgumentTypes::positional([inferred_ty]), + Some(self.scope()), + ) .map(|bindings| bindings.return_type(self.db())) { Ok(return_ty) => return_ty, @@ -3001,6 +3005,7 @@ impl<'db> TypeInferenceBuilder<'db> { object_ty, value_ty, ]), + Some(self.scope()), ) .is_ok(); @@ -3084,6 +3089,7 @@ impl<'db> TypeInferenceBuilder<'db> { ]), MemberLookupPolicy::NO_INSTANCE_FALLBACK | MemberLookupPolicy::MRO_NO_OBJECT_FALLBACK, + Some(self.scope()), ); match result { @@ -3142,6 +3148,7 @@ impl<'db> TypeInferenceBuilder<'db> { object_ty, value_ty, ]), + Some(self.scope()), ) .is_ok(); @@ -3544,15 +3551,6 @@ impl<'db> TypeInferenceBuilder<'db> { )); }; - // Fall back to non-augmented binary operator inference. - let mut binary_return_ty = || { - self.infer_binary_expression_type(assignment.into(), false, target_type, value_type, op) - .unwrap_or_else(|| { - report_unsupported_augmented_op(&mut self.context); - Type::unknown() - }) - }; - match target_type { Type::Union(union) => union.map(db, |&elem_type| { self.infer_augmented_op(assignment, elem_type, value_type) @@ -3562,8 +3560,24 @@ impl<'db> TypeInferenceBuilder<'db> { db, op.in_place_dunder(), CallArgumentTypes::positional([value_type]), + Some(self.scope()), ); + // Fall back to non-augmented binary operator inference. + let mut binary_return_ty = || { + self.infer_binary_expression_type( + assignment.into(), + false, + target_type, + value_type, + op, + ) + .unwrap_or_else(|| { + report_unsupported_augmented_op(&mut self.context); + Type::unknown() + }) + }; + match call { Ok(outcome) => outcome.return_type(db), Err(CallDunderError::MethodNotAvailable) => binary_return_ty(), @@ -4774,7 +4788,7 @@ impl<'db> TypeInferenceBuilder<'db> { self.infer_argument_types(arguments, call_arguments, &argument_forms); return callable_type - .try_call_constructor(self.db(), call_argument_types) + .try_call_constructor(self.db(), call_argument_types, Some(self.scope())) .unwrap_or_else(|err| { err.report_diagnostic(&self.context, callable_type, call_expression.into()); err.return_type() @@ -4785,7 +4799,7 @@ impl<'db> TypeInferenceBuilder<'db> { let signatures = callable_type.signatures(self.db()); let inferred_return_ty = || { callable_type - .inferred_return_type(self.db()) + .inferred_return_type(self.db(), Some(self.scope())) .unwrap_or(Type::unknown()) }; let bindings = Bindings::match_parameters(signatures, inferred_return_ty, &call_arguments); @@ -5724,6 +5738,7 @@ impl<'db> TypeInferenceBuilder<'db> { self.db(), unary_dunder_method, CallArgumentTypes::none(), + Some(self.scope()), ) { Ok(outcome) => outcome.return_type(self.db()), Err(e) => { @@ -6071,6 +6086,7 @@ impl<'db> TypeInferenceBuilder<'db> { self.db(), reflected_dunder, CallArgumentTypes::positional([left_ty]), + Some(self.scope()), ) .map(|outcome| outcome.return_type(self.db())) .or_else(|_| { @@ -6079,6 +6095,7 @@ impl<'db> TypeInferenceBuilder<'db> { self.db(), op.dunder(), CallArgumentTypes::positional([right_ty]), + Some(self.scope()), ) .map(|outcome| outcome.return_type(self.db())) }) @@ -6091,6 +6108,7 @@ impl<'db> TypeInferenceBuilder<'db> { self.db(), op.dunder(), CallArgumentTypes::positional([right_ty]), + Some(self.scope()), ) .map(|outcome| outcome.return_type(self.db())) .ok(); @@ -6104,6 +6122,7 @@ impl<'db> TypeInferenceBuilder<'db> { self.db(), op.reflected_dunder(), CallArgumentTypes::positional([left_ty]), + Some(self.scope()), ) .map(|outcome| outcome.return_type(self.db())) .ok() @@ -6743,9 +6762,14 @@ impl<'db> TypeInferenceBuilder<'db> { // The following resource has details about the rich comparison algorithm: // https://snarky.ca/unravelling-rich-comparison-operators/ let call_dunder = |op: RichCompareOperator, left: Type<'db>, right: Type<'db>| { - left.try_call_dunder(db, op.dunder(), CallArgumentTypes::positional([right])) - .map(|outcome| outcome.return_type(db)) - .ok() + left.try_call_dunder( + db, + op.dunder(), + CallArgumentTypes::positional([right]), + Some(self.scope()), + ) + .map(|outcome| outcome.return_type(db)) + .ok() }; // The reflected dunder has priority if the right-hand side is a strict subclass of the left-hand side. @@ -6789,7 +6813,11 @@ impl<'db> TypeInferenceBuilder<'db> { Symbol::Type(contains_dunder, Boundness::Bound) => { // If `__contains__` is available, it is used directly for the membership test. contains_dunder - .try_call(db, &CallArgumentTypes::positional([right, left])) + .try_call( + db, + &CallArgumentTypes::positional([right, left]), + Some(self.scope()), + ) .map(|bindings| bindings.return_type(db)) .ok() } @@ -6969,7 +6997,7 @@ impl<'db> TypeInferenceBuilder<'db> { )); let inferred_return_ty = || { value_ty - .inferred_return_type(self.db()) + .inferred_return_type(self.db(), Some(self.scope())) .unwrap_or(Type::unknown()) }; let bindings = @@ -7177,6 +7205,7 @@ impl<'db> TypeInferenceBuilder<'db> { self.db(), "__getitem__", CallArgumentTypes::positional([slice_ty]), + Some(self.scope()), ) { Ok(outcome) => return outcome.return_type(self.db()), Err(err @ CallDunderError::PossiblyUnbound { .. }) => { @@ -7243,6 +7272,7 @@ impl<'db> TypeInferenceBuilder<'db> { match ty.try_call( self.db(), &CallArgumentTypes::positional([value_ty, slice_ty]), + Some(self.scope()), ) { Ok(bindings) => return bindings.return_type(self.db()), Err(CallError(_, bindings)) => { From a8a35d5db320a4286629d588f052d5bafa85511a Mon Sep 17 00:00:00 2001 From: Shunsuke Shibayama Date: Fri, 9 May 2025 02:50:20 +0900 Subject: [PATCH 003/105] The return type of an ambiguously typed method is `Unknown | ` --- .../resources/mdtest/function/return_type.md | 106 ++++++++++++++++++ crates/ty_python_semantic/src/types.rs | 103 ++++++++++++++++- crates/ty_python_semantic/src/types/infer.rs | 17 ++- .../src/types/signatures.rs | 22 +++- 4 files changed, 240 insertions(+), 8 deletions(-) diff --git a/crates/ty_python_semantic/resources/mdtest/function/return_type.md b/crates/ty_python_semantic/resources/mdtest/function/return_type.md index b315f1f40bbbc..3e4e388fd8079 100644 --- a/crates/ty_python_semantic/resources/mdtest/function/return_type.md +++ b/crates/ty_python_semantic/resources/mdtest/function/return_type.md @@ -184,6 +184,8 @@ def f(cond: bool) -> int: ## Inferred return type +### Free function + If a function's return type is not annotated, it is inferred. The inferred type is the union of all possible return types. @@ -219,6 +221,110 @@ def generator(): reveal_type(generator()) # revealed: None ``` +### Class method + +If a method's return type is not annotated, it is also inferred, but the inferred type is a union of +all possible return types and `Unknown`. This is because a method of a class may be overridden by +its subtypes. For example, if the return type of a method is inferred to be `int`, the type the +coder really intended might be `int | None`, in which case it would be impossible for the overridden +method to return `None`. + +```py +class C: + def f(self): + return 1 + +class D(C): + def f(self): + return None + +reveal_type(C().f()) # revealed: Literal[1] | Unknown +reveal_type(D().f()) # revealed: None | Literal[1] | Unknown +``` + +However, in the following cases, `Unknown` is not included in the inferred return type because there +is no ambiguity in the subclass. + +- The class or the method is marked as `final`. + +```py +from typing import final + +@final +class C: + def f(self): + return 1 + +class D: + @final + def f(self): + return "a" + +reveal_type(C().f()) # revealed: Literal[1] +reveal_type(D().f()) # revealed: Literal["a"] +``` + +- The method overrides the methods of the base classes, and the return types of the base class + methods are known (In this case, the return type of the method is the intersection of the return + types of the methods in the base classes). + +```toml +[environment] +python-version = "3.12" +``` + +```py +from typing import Literal + +class C: + def f(self) -> int: + return 1 + + def g[T](self, x: T) -> T: + return x + +class D(C): + def f(self): + return 2 + + def g(self, x: int): + return 2 + +class E(D): + def f(self): + return 3 + +reveal_type(C().f()) # revealed: int +reveal_type(D().f()) # revealed: int +reveal_type(E().f()) # revealed: int +reveal_type(C().g(1)) # revealed: Literal[1] +reveal_type(D().g(1)) # revealed: int + +class F: + def f(self) -> Literal[1, 2]: + return 2 + +class G: + def f(self) -> Literal[2, 3]: + return 2 + +class H(F, G): + def f(self): + raise NotImplementedError + +reveal_type(H().f()) # revealed: Literal[2] + +class C2[T]: + def f(self, x: T) -> T: + return x + +class D2(C2[int]): + def f(self, x: int): + return x + +reveal_type(D2().f(1)) # revealed: int +``` + ## Invalid return type diff --git a/crates/ty_python_semantic/src/types.rs b/crates/ty_python_semantic/src/types.rs index 517c9e3e73201..41f133194f040 100644 --- a/crates/ty_python_semantic/src/types.rs +++ b/crates/ty_python_semantic/src/types.rs @@ -4221,13 +4221,20 @@ impl<'db> Type<'db> { ) -> Option> { match self { Type::FunctionLiteral(function_type) - if call_scope.is_some_and(|call_scope| { - !function_type.file(db).is_stub(db) - && call_scope != function_type.body_scope(db) - }) => + if !function_type.file(db).is_stub(db) + && call_scope + .is_some_and(|call_scope| call_scope != function_type.body_scope(db)) => { Some(function_type.inferred_return_type(db)) } + Type::BoundMethod(method_type) + if !method_type.function(db).file(db).is_stub(db) + && call_scope.is_some_and(|call_scope| { + call_scope != method_type.function(db).body_scope(db) + }) => + { + Some(method_type.inferred_return_type(db)) + } _ => None, } } @@ -6835,7 +6842,7 @@ impl<'db> FunctionType<'db> { fn inferred_return_type(self, db: &'db dyn Db) -> Type<'db> { let scope = self.body_scope(db); let inference = infer_scope_types(db, scope); - inference.inferred_return_type(db) + inference.inferred_return_type(db, None) } pub(crate) fn is_known(self, db: &'db dyn Db, known_function: KnownFunction) -> bool { @@ -7169,6 +7176,92 @@ impl<'db> BoundMethodType<'db> { .map(signatures::Signature::bind_self), )) } + + /// Infers this method scope's types and returns the inferred return type. + pub(crate) fn inferred_return_type(self, db: &'db dyn Db) -> Type<'db> { + let scope = self.function(db).body_scope(db); + let inference = infer_scope_types(db, scope); + inference.inferred_return_type(db, Some(self)) + } + + pub(crate) fn is_final(self, db: &'db dyn Db) -> bool { + if self + .function(db) + .has_known_decorator(db, FunctionDecorators::FINAL) + { + return true; + } + let definition_scope = self.function(db).definition(db).scope(db); + let index = semantic_index(db, definition_scope.file(db)); + let class_definition = + index.expect_single_definition(definition_scope.node(db).expect_class()); + let class_ty = binding_type(db, class_definition).expect_class_literal(); + class_ty + .known_function_decorators(db) + .any(|deco| deco == KnownFunction::Final) + } + + /// Returns the compatible return type for this method -- the intersection of the return types of the base class methods. + pub(crate) fn compatible_return_type( + self, + db: &'db dyn Db, + call_scope: Option, + ) -> Option> { + let definition_scope = self.function(db).definition(db).scope(db); + let index = semantic_index(db, definition_scope.file(db)); + let class_definition = + index.expect_single_definition(definition_scope.node(db).expect_class()); + let class = binding_type(db, class_definition).expect_class_type(db); + let (class_lit, _) = class.class_literal(db); + let name = self.function(db).name(db); + + let mut found = false; + let mut intersection = IntersectionBuilder::new(db); + for base in class_lit.explicit_bases(db) { + if let Some(base_function_ty) = base + .member(db, name) + .into_lookup_result() + .ok() + .and_then(|ty| ty.inner_type().into_function_literal()) + { + if let FunctionSignature::Single(base_signature) = base_function_ty.signature(db) { + if let Some(return_ty) = base_signature.return_ty.or_else(|| { + let base_method_ty = + base_function_ty.into_bound_method_type(db, Type::instance(db, class)); + base_method_ty.inferred_return_type(db, call_scope) + }) { + if let Type::TypeVar(return_typevar) = return_ty { + if let FunctionSignature::Single(signature) = + self.function(db).signature(db) + { + if let Ok(specialization) = + base_signature.specialize_with(db, signature) + { + if let Some(return_ty) = + specialization.type_mapping().get(db, return_typevar) + { + found = true; + intersection = intersection.add_positive(return_ty); + } + } + } + } else { + found = true; + intersection = intersection.add_positive(return_ty); + } + } + } else { + // TODO: overloaded methods + } + } + } + + if found { + Some(intersection.build()) + } else { + None + } + } } /// This type represents the set of all callable objects with a certain, possibly overloaded, diff --git a/crates/ty_python_semantic/src/types/infer.rs b/crates/ty_python_semantic/src/types/infer.rs index 6750595be83bc..a8e14da594b79 100644 --- a/crates/ty_python_semantic/src/types/infer.rs +++ b/crates/ty_python_semantic/src/types/infer.rs @@ -113,7 +113,7 @@ use super::string_annotation::{ parse_string_annotation, BYTE_STRING_TYPE_ANNOTATION, FSTRING_TYPE_ANNOTATION, }; use super::subclass_of::SubclassOfInner; -use super::{BoundSuperError, BoundSuperType, ClassBase}; +use super::{BoundMethodType, BoundSuperError, BoundSuperType, ClassBase}; /// Infer all types for a [`ScopeId`], including all definitions and expressions in that scope. /// Use when checking a scope, or needing to provide a type for an arbitrary expression in the @@ -477,7 +477,13 @@ impl<'db> TypeInference<'db> { /// Returns the inferred return type of this function body (union of all possible return types), /// or `None` if the region is not a function body. - pub(crate) fn inferred_return_type(&self, db: &'db dyn Db) -> Type<'db> { + /// In the case of methods, the return type of the superclass method is further unioned. + /// If there is no superclass method and this method is not `final`, it will be unioned with `Unknown`. + pub(crate) fn inferred_return_type( + &self, + db: &'db dyn Db, + method_ty: Option>, + ) -> Type<'db> { let mut union = UnionBuilder::new(db); for ty in &self.return_types { union = union.add(*ty); @@ -486,6 +492,13 @@ impl<'db> TypeInference<'db> { if use_def.can_implicit_return(db) { union = union.add(Type::none(db)); } + if let Some(method_ty) = method_ty { + if let Some(return_ty) = method_ty.compatible_return_type(db, Some(self.scope)) { + union = union.add(return_ty); + } else if !method_ty.is_final(db) { + union = union.add(Type::unknown()); + } + } union.build() } diff --git a/crates/ty_python_semantic/src/types/signatures.rs b/crates/ty_python_semantic/src/types/signatures.rs index 7cbead0b751c7..a6c4ae375a3ed 100644 --- a/crates/ty_python_semantic/src/types/signatures.rs +++ b/crates/ty_python_semantic/src/types/signatures.rs @@ -17,7 +17,9 @@ use smallvec::{smallvec, SmallVec}; use super::{definition_expression_type, DynamicType, Type}; use crate::semantic_index::definition::Definition; -use crate::types::generics::{GenericContext, Specialization, TypeMapping}; +use crate::types::generics::{ + GenericContext, Specialization, SpecializationBuilder, SpecializationError, TypeMapping, +}; use crate::types::{todo_type, ClassLiteral, TypeVarInstance}; use crate::{Db, FxOrderSet}; use ruff_python_ast::{self as ast, name::Name}; @@ -332,6 +334,24 @@ impl<'db> Signature<'db> { } } + pub(crate) fn specialize_with( + &self, + db: &'db dyn Db, + signature: &Signature<'db>, + ) -> Result, SpecializationError<'db>> { + debug_assert!( + self.generic_context.is_some(), + "Cannot specialize a signature without a generic context" + ); + let mut specialization = SpecializationBuilder::new(db); + for (self_param, param) in self.parameters().iter().zip(signature.parameters().iter()) { + if let Some((self_ty, ty)) = self_param.annotated_type().zip(param.annotated_type()) { + specialization.infer(self_ty, ty)?; + } + } + Ok(specialization.build(self.generic_context.unwrap())) + } + pub(crate) fn find_legacy_typevars( &self, db: &'db dyn Db, From 391ee913b1b8c1a9809cff1970cbdb415848a835 Mon Sep 17 00:00:00 2001 From: Shunsuke Shibayama Date: Sun, 11 May 2025 21:29:21 +0900 Subject: [PATCH 004/105] infer recursive function's return type --- .../resources/mdtest/function/return_type.md | 29 +++++ crates/ty_python_semantic/src/types.rs | 107 +++++++++--------- .../ty_python_semantic/src/types/call/bind.rs | 14 +-- crates/ty_python_semantic/src/types/class.rs | 2 +- crates/ty_python_semantic/src/types/infer.rs | 66 +++-------- 5 files changed, 103 insertions(+), 115 deletions(-) diff --git a/crates/ty_python_semantic/resources/mdtest/function/return_type.md b/crates/ty_python_semantic/resources/mdtest/function/return_type.md index 3e4e388fd8079..75eb883a45d39 100644 --- a/crates/ty_python_semantic/resources/mdtest/function/return_type.md +++ b/crates/ty_python_semantic/resources/mdtest/function/return_type.md @@ -221,6 +221,35 @@ def generator(): reveal_type(generator()) # revealed: None ``` +The return type of a recursive function is also inferred. + +```py +def fibonacci(n: int): + if n == 0: + return 0 + elif n == 1: + return 1 + else: + return fibonacci(n - 1) + fibonacci(n - 2) + +reveal_type(fibonacci(5)) # revealed: int + +def even(n: int): + if n == 0: + return True + else: + return odd(n - 1) + +def odd(n: int): + if n == 0: + return False + else: + return even(n - 1) + +reveal_type(even(1)) # revealed: bool +reveal_type(odd(1)) # revealed: bool +``` + ### Class method If a method's return type is not annotated, it is also inferred, but the inferred type is a union of diff --git a/crates/ty_python_semantic/src/types.rs b/crates/ty_python_semantic/src/types.rs index 41f133194f040..9e0c3c77cfb0b 100644 --- a/crates/ty_python_semantic/src/types.rs +++ b/crates/ty_python_semantic/src/types.rs @@ -81,6 +81,38 @@ mod definition; #[cfg(test)] mod property_tests; +fn function_return_type_cycle_recover<'db>( + _db: &'db dyn Db, + _value: &Type<'db>, + _count: u32, + _self: FunctionType<'db>, +) -> salsa::CycleRecoveryAction> { + salsa::CycleRecoveryAction::Iterate +} + +fn function_return_type_cycle_initial<'db>( + _db: &'db dyn Db, + _self: FunctionType<'db>, +) -> Type<'db> { + Type::Never +} + +fn method_return_type_cycle_recover<'db>( + _db: &'db dyn Db, + _value: &Type<'db>, + _count: u32, + _self: BoundMethodType<'db>, +) -> salsa::CycleRecoveryAction> { + salsa::CycleRecoveryAction::Iterate +} + +fn method_return_type_cycle_initial<'db>( + _db: &'db dyn Db, + _self: BoundMethodType<'db>, +) -> Type<'db> { + Type::Never +} + #[salsa::tracked(returns(ref))] pub fn check_types(db: &dyn Db, file: File) -> TypeCheckDiagnostics { let _span = tracing::trace_span!("check_types", ?file).entered(); @@ -2659,11 +2691,7 @@ impl<'db> Type<'db> { if let Symbol::Type(descr_get, descr_get_boundness) = descr_get { let return_ty = descr_get - .try_call( - db, - &CallArgumentTypes::positional([self, instance, owner]), - None, - ) + .try_call(db, &CallArgumentTypes::positional([self, instance, owner])) .map(|bindings| { if descr_get_boundness == Boundness::Bound { bindings.return_type(db) @@ -3081,7 +3109,6 @@ impl<'db> Type<'db> { CallArgumentTypes::positional([Type::StringLiteral( StringLiteralType::new(db, Box::from(name.as_str())), )]), - None, ) .map(|outcome| Symbol::bound(outcome.return_type(db))) // TODO: Handle call errors here. @@ -3200,7 +3227,7 @@ impl<'db> Type<'db> { // runtime there is a fallback to `__len__`, since `__bool__` takes precedence // and a subclass could add a `__bool__` method. - match self.try_call_dunder(db, "__bool__", CallArgumentTypes::none(), None) { + match self.try_call_dunder(db, "__bool__", CallArgumentTypes::none()) { Ok(outcome) => { let return_type = outcome.return_type(db); if !return_type.is_assignable_to(db, KnownClass::Bool.to_instance(db)) { @@ -3391,7 +3418,7 @@ impl<'db> Type<'db> { return usize_len.try_into().ok().map(Type::IntLiteral); } - let return_ty = match self.try_call_dunder(db, "__len__", CallArgumentTypes::none(), None) { + let return_ty = match self.try_call_dunder(db, "__len__", CallArgumentTypes::none()) { Ok(bindings) => bindings.return_type(db), Err(CallDunderError::PossiblyUnbound(bindings)) => bindings.return_type(db), @@ -4213,26 +4240,13 @@ impl<'db> Type<'db> { } } - /// Returns the inferred return type of `self` if it is a function literal. - fn inferred_return_type( - self, - db: &'db dyn Db, - call_scope: Option, - ) -> Option> { + /// Returns the inferred return type of `self` if it is a function literal / bounded method. + fn inferred_return_type(self, db: &'db dyn Db) -> Option> { match self { - Type::FunctionLiteral(function_type) - if !function_type.file(db).is_stub(db) - && call_scope - .is_some_and(|call_scope| call_scope != function_type.body_scope(db)) => - { + Type::FunctionLiteral(function_type) if !function_type.file(db).is_stub(db) => { Some(function_type.inferred_return_type(db)) } - Type::BoundMethod(method_type) - if !method_type.function(db).file(db).is_stub(db) - && call_scope.is_some_and(|call_scope| { - call_scope != method_type.function(db).body_scope(db) - }) => - { + Type::BoundMethod(method_type) if !method_type.function(db).file(db).is_stub(db) => { Some(method_type.inferred_return_type(db)) } _ => None, @@ -4249,13 +4263,9 @@ impl<'db> Type<'db> { self, db: &'db dyn Db, argument_types: &CallArgumentTypes<'_, 'db>, - call_scope: Option, ) -> Result, CallError<'db>> { let signatures = self.signatures(db); - let inferred_return_ty = || { - self.inferred_return_type(db, call_scope) - .unwrap_or(Type::unknown()) - }; + let inferred_return_ty = || self.inferred_return_type(db).unwrap_or(Type::unknown()); Bindings::match_parameters(signatures, inferred_return_ty, argument_types) .check_types(db, argument_types) } @@ -4269,14 +4279,12 @@ impl<'db> Type<'db> { db: &'db dyn Db, name: &str, mut argument_types: CallArgumentTypes<'_, 'db>, - call_scope: Option, ) -> Result, CallDunderError<'db>> { self.try_call_dunder_with_policy( db, name, &mut argument_types, MemberLookupPolicy::NO_INSTANCE_FALLBACK, - call_scope, ) } @@ -4290,7 +4298,6 @@ impl<'db> Type<'db> { name: &str, argument_types: &mut CallArgumentTypes<'_, 'db>, policy: MemberLookupPolicy, - call_scope: Option, ) -> Result, CallDunderError<'db>> { match self .member_lookup_with_policy(db, name.into(), policy) @@ -4300,7 +4307,7 @@ impl<'db> Type<'db> { let signatures = dunder_callable.signatures(db); let inferred_return_ty = || { dunder_callable - .inferred_return_type(db, call_scope) + .inferred_return_type(db) .unwrap_or(Type::unknown()) }; let bindings = @@ -4348,19 +4355,18 @@ impl<'db> Type<'db> { db, "__getitem__", CallArgumentTypes::positional([KnownClass::Int.to_instance(db)]), - None, ) .map(|dunder_getitem_outcome| dunder_getitem_outcome.return_type(db)) }; let try_call_dunder_next_on_iterator = |iterator: Type<'db>| { iterator - .try_call_dunder(db, "__next__", CallArgumentTypes::none(), None) + .try_call_dunder(db, "__next__", CallArgumentTypes::none()) .map(|dunder_next_outcome| dunder_next_outcome.return_type(db)) }; let dunder_iter_result = self - .try_call_dunder(db, "__iter__", CallArgumentTypes::none(), None) + .try_call_dunder(db, "__iter__", CallArgumentTypes::none()) .map(|dunder_iter_outcome| dunder_iter_outcome.return_type(db)); match dunder_iter_result { @@ -4444,12 +4450,11 @@ impl<'db> Type<'db> { /// pass /// ``` fn try_enter(self, db: &'db dyn Db) -> Result, ContextManagerError<'db>> { - let enter = self.try_call_dunder(db, "__enter__", CallArgumentTypes::none(), None); + let enter = self.try_call_dunder(db, "__enter__", CallArgumentTypes::none()); let exit = self.try_call_dunder( db, "__exit__", CallArgumentTypes::positional([Type::none(db), Type::none(db), Type::none(db)]), - None, ); // TODO: Make use of Protocols when we support it (the manager be assignable to `contextlib.AbstractContextManager`). @@ -4485,7 +4490,6 @@ impl<'db> Type<'db> { self, db: &'db dyn Db, argument_types: CallArgumentTypes<'_, 'db>, - call_scope: Option, ) -> Result, ConstructorCallError<'db>> { debug_assert!(matches!( self, @@ -4548,11 +4552,8 @@ impl<'db> Type<'db> { let new_call_outcome = new_method.and_then(|new_method| { match new_method.symbol.try_call_dunder_get(db, self_type) { Symbol::Type(new_method, boundness) => { - let result = new_method.try_call( - db, - argument_types.with_self(Some(self_type)).as_ref(), - call_scope, - ); + let result = + new_method.try_call(db, argument_types.with_self(Some(self_type)).as_ref()); if boundness == Boundness::PossiblyUnbound { Some(Err(DunderNewCallError::PossiblyUnbound(result.err()))) } else { @@ -4580,7 +4581,7 @@ impl<'db> Type<'db> { .symbol .is_unbound() { - Some(init_ty.try_call_dunder(db, "__init__", argument_types, call_scope)) + Some(init_ty.try_call_dunder(db, "__init__", argument_types)) } else { None }; @@ -5947,7 +5948,7 @@ impl<'db> IterationError<'db> { Self::IterCallError(_, dunder_iter_bindings) => dunder_iter_bindings .return_type(db) - .try_call_dunder(db, "__next__", CallArgumentTypes::none(), None) + .try_call_dunder(db, "__next__", CallArgumentTypes::none()) .map(|dunder_next_outcome| Some(dunder_next_outcome.return_type(db))) .unwrap_or_else(|dunder_next_call_error| dunder_next_call_error.return_type(db)), @@ -6838,7 +6839,7 @@ impl<'db> FunctionType<'db> { } /// Infers this function scope's types and returns the inferred return type. - /// Do not call this method within a function scope itself, i.e. check that the function is not recursive. + #[salsa::tracked(cycle_fn=function_return_type_cycle_recover, cycle_initial=function_return_type_cycle_initial)] fn inferred_return_type(self, db: &'db dyn Db) -> Type<'db> { let scope = self.body_scope(db); let inference = infer_scope_types(db, scope); @@ -7166,6 +7167,7 @@ pub struct BoundMethodType<'db> { self_instance: Type<'db>, } +#[salsa::tracked] impl<'db> BoundMethodType<'db> { pub(crate) fn into_callable_type(self, db: &'db dyn Db) -> Type<'db> { Type::Callable(CallableType::from_overloads( @@ -7178,6 +7180,7 @@ impl<'db> BoundMethodType<'db> { } /// Infers this method scope's types and returns the inferred return type. + #[salsa::tracked(cycle_fn=method_return_type_cycle_recover, cycle_initial=method_return_type_cycle_initial)] pub(crate) fn inferred_return_type(self, db: &'db dyn Db) -> Type<'db> { let scope = self.function(db).body_scope(db); let inference = infer_scope_types(db, scope); @@ -7202,11 +7205,7 @@ impl<'db> BoundMethodType<'db> { } /// Returns the compatible return type for this method -- the intersection of the return types of the base class methods. - pub(crate) fn compatible_return_type( - self, - db: &'db dyn Db, - call_scope: Option, - ) -> Option> { + pub(crate) fn compatible_return_type(self, db: &'db dyn Db) -> Option> { let definition_scope = self.function(db).definition(db).scope(db); let index = semantic_index(db, definition_scope.file(db)); let class_definition = @@ -7228,7 +7227,7 @@ impl<'db> BoundMethodType<'db> { if let Some(return_ty) = base_signature.return_ty.or_else(|| { let base_method_ty = base_function_ty.into_bound_method_type(db, Type::instance(db, class)); - base_method_ty.inferred_return_type(db, call_scope) + base_method_ty.inferred_return_type(db) }) { if let Type::TypeVar(return_typevar) = return_ty { if let FunctionSignature::Single(signature) = diff --git a/crates/ty_python_semantic/src/types/call/bind.rs b/crates/ty_python_semantic/src/types/call/bind.rs index 4e76333ac2aa2..99916eb58c6fe 100644 --- a/crates/ty_python_semantic/src/types/call/bind.rs +++ b/crates/ty_python_semantic/src/types/call/bind.rs @@ -362,11 +362,7 @@ impl<'db> Bindings<'db> { [Some(Type::PropertyInstance(property)), Some(instance), ..] => { if let Some(getter) = property.getter(db) { if let Ok(return_ty) = getter - .try_call( - db, - &CallArgumentTypes::positional([*instance]), - None, - ) + .try_call(db, &CallArgumentTypes::positional([*instance])) .map(|binding| binding.return_type(db)) { overload.set_return_type(return_ty); @@ -395,11 +391,7 @@ impl<'db> Bindings<'db> { [Some(instance), ..] => { if let Some(getter) = property.getter(db) { if let Ok(return_ty) = getter - .try_call( - db, - &CallArgumentTypes::positional([*instance]), - None, - ) + .try_call(db, &CallArgumentTypes::positional([*instance])) .map(|binding| binding.return_type(db)) { overload.set_return_type(return_ty); @@ -428,7 +420,6 @@ impl<'db> Bindings<'db> { if let Err(_call_error) = setter.try_call( db, &CallArgumentTypes::positional([*instance, *value]), - None, ) { overload.errors.push(BindingError::InternalCallError( "calling the setter failed", @@ -448,7 +439,6 @@ impl<'db> Bindings<'db> { if let Err(_call_error) = setter.try_call( db, &CallArgumentTypes::positional([*instance, *value]), - None, ) { overload.errors.push(BindingError::InternalCallError( "calling the setter failed", diff --git a/crates/ty_python_semantic/src/types/class.rs b/crates/ty_python_semantic/src/types/class.rs index b176d4ddc8ff9..e99844a9d8a9b 100644 --- a/crates/ty_python_semantic/src/types/class.rs +++ b/crates/ty_python_semantic/src/types/class.rs @@ -878,7 +878,7 @@ impl<'db> ClassLiteral<'db> { // TODO: Other keyword arguments? let arguments = CallArgumentTypes::positional([name, bases, namespace]); - let return_ty_result = match metaclass.try_call(db, &arguments, None) { + let return_ty_result = match metaclass.try_call(db, &arguments) { Ok(bindings) => Ok(bindings.return_type(db)), Err(CallError(CallErrorKind::NotCallable, bindings)) => Err(MetaclassError { diff --git a/crates/ty_python_semantic/src/types/infer.rs b/crates/ty_python_semantic/src/types/infer.rs index a8e14da594b79..9604a481ca5b4 100644 --- a/crates/ty_python_semantic/src/types/infer.rs +++ b/crates/ty_python_semantic/src/types/infer.rs @@ -493,7 +493,7 @@ impl<'db> TypeInference<'db> { union = union.add(Type::none(db)); } if let Some(method_ty) = method_ty { - if let Some(return_ty) = method_ty.compatible_return_type(db, Some(self.scope)) { + if let Some(return_ty) = method_ty.compatible_return_type(db) { union = union.add(return_ty); } else if !method_ty.is_final(db) { union = union.add(Type::unknown()); @@ -1977,11 +1977,7 @@ impl<'db> TypeInferenceBuilder<'db> { for (decorator_ty, decorator_node) in decorator_types_and_nodes.iter().rev() { inferred_ty = match decorator_ty - .try_call( - self.db(), - &CallArgumentTypes::positional([inferred_ty]), - Some(self.scope()), - ) + .try_call(self.db(), &CallArgumentTypes::positional([inferred_ty])) .map(|bindings| bindings.return_type(self.db())) { Ok(return_ty) => return_ty, @@ -3018,7 +3014,6 @@ impl<'db> TypeInferenceBuilder<'db> { object_ty, value_ty, ]), - Some(self.scope()), ) .is_ok(); @@ -3102,7 +3097,6 @@ impl<'db> TypeInferenceBuilder<'db> { ]), MemberLookupPolicy::NO_INSTANCE_FALLBACK | MemberLookupPolicy::MRO_NO_OBJECT_FALLBACK, - Some(self.scope()), ); match result { @@ -3161,7 +3155,6 @@ impl<'db> TypeInferenceBuilder<'db> { object_ty, value_ty, ]), - Some(self.scope()), ) .is_ok(); @@ -3564,6 +3557,15 @@ impl<'db> TypeInferenceBuilder<'db> { )); }; + // Fall back to non-augmented binary operator inference. + let mut binary_return_ty = || { + self.infer_binary_expression_type(assignment.into(), false, target_type, value_type, op) + .unwrap_or_else(|| { + report_unsupported_augmented_op(&mut self.context); + Type::unknown() + }) + }; + match target_type { Type::Union(union) => union.map(db, |&elem_type| { self.infer_augmented_op(assignment, elem_type, value_type) @@ -3573,24 +3575,8 @@ impl<'db> TypeInferenceBuilder<'db> { db, op.in_place_dunder(), CallArgumentTypes::positional([value_type]), - Some(self.scope()), ); - // Fall back to non-augmented binary operator inference. - let mut binary_return_ty = || { - self.infer_binary_expression_type( - assignment.into(), - false, - target_type, - value_type, - op, - ) - .unwrap_or_else(|| { - report_unsupported_augmented_op(&mut self.context); - Type::unknown() - }) - }; - match call { Ok(outcome) => outcome.return_type(db), Err(CallDunderError::MethodNotAvailable) => binary_return_ty(), @@ -4801,7 +4787,7 @@ impl<'db> TypeInferenceBuilder<'db> { self.infer_argument_types(arguments, call_arguments, &argument_forms); return callable_type - .try_call_constructor(self.db(), call_argument_types, Some(self.scope())) + .try_call_constructor(self.db(), call_argument_types) .unwrap_or_else(|err| { err.report_diagnostic(&self.context, callable_type, call_expression.into()); err.return_type() @@ -4812,7 +4798,7 @@ impl<'db> TypeInferenceBuilder<'db> { let signatures = callable_type.signatures(self.db()); let inferred_return_ty = || { callable_type - .inferred_return_type(self.db(), Some(self.scope())) + .inferred_return_type(self.db()) .unwrap_or(Type::unknown()) }; let bindings = Bindings::match_parameters(signatures, inferred_return_ty, &call_arguments); @@ -5751,7 +5737,6 @@ impl<'db> TypeInferenceBuilder<'db> { self.db(), unary_dunder_method, CallArgumentTypes::none(), - Some(self.scope()), ) { Ok(outcome) => outcome.return_type(self.db()), Err(e) => { @@ -6099,7 +6084,6 @@ impl<'db> TypeInferenceBuilder<'db> { self.db(), reflected_dunder, CallArgumentTypes::positional([left_ty]), - Some(self.scope()), ) .map(|outcome| outcome.return_type(self.db())) .or_else(|_| { @@ -6108,7 +6092,6 @@ impl<'db> TypeInferenceBuilder<'db> { self.db(), op.dunder(), CallArgumentTypes::positional([right_ty]), - Some(self.scope()), ) .map(|outcome| outcome.return_type(self.db())) }) @@ -6121,7 +6104,6 @@ impl<'db> TypeInferenceBuilder<'db> { self.db(), op.dunder(), CallArgumentTypes::positional([right_ty]), - Some(self.scope()), ) .map(|outcome| outcome.return_type(self.db())) .ok(); @@ -6135,7 +6117,6 @@ impl<'db> TypeInferenceBuilder<'db> { self.db(), op.reflected_dunder(), CallArgumentTypes::positional([left_ty]), - Some(self.scope()), ) .map(|outcome| outcome.return_type(self.db())) .ok() @@ -6775,14 +6756,9 @@ impl<'db> TypeInferenceBuilder<'db> { // The following resource has details about the rich comparison algorithm: // https://snarky.ca/unravelling-rich-comparison-operators/ let call_dunder = |op: RichCompareOperator, left: Type<'db>, right: Type<'db>| { - left.try_call_dunder( - db, - op.dunder(), - CallArgumentTypes::positional([right]), - Some(self.scope()), - ) - .map(|outcome| outcome.return_type(db)) - .ok() + left.try_call_dunder(db, op.dunder(), CallArgumentTypes::positional([right])) + .map(|outcome| outcome.return_type(db)) + .ok() }; // The reflected dunder has priority if the right-hand side is a strict subclass of the left-hand side. @@ -6826,11 +6802,7 @@ impl<'db> TypeInferenceBuilder<'db> { Symbol::Type(contains_dunder, Boundness::Bound) => { // If `__contains__` is available, it is used directly for the membership test. contains_dunder - .try_call( - db, - &CallArgumentTypes::positional([right, left]), - Some(self.scope()), - ) + .try_call(db, &CallArgumentTypes::positional([right, left])) .map(|bindings| bindings.return_type(db)) .ok() } @@ -7010,7 +6982,7 @@ impl<'db> TypeInferenceBuilder<'db> { )); let inferred_return_ty = || { value_ty - .inferred_return_type(self.db(), Some(self.scope())) + .inferred_return_type(self.db()) .unwrap_or(Type::unknown()) }; let bindings = @@ -7218,7 +7190,6 @@ impl<'db> TypeInferenceBuilder<'db> { self.db(), "__getitem__", CallArgumentTypes::positional([slice_ty]), - Some(self.scope()), ) { Ok(outcome) => return outcome.return_type(self.db()), Err(err @ CallDunderError::PossiblyUnbound { .. }) => { @@ -7285,7 +7256,6 @@ impl<'db> TypeInferenceBuilder<'db> { match ty.try_call( self.db(), &CallArgumentTypes::positional([value_ty, slice_ty]), - Some(self.scope()), ) { Ok(bindings) => return bindings.return_type(self.db()), Err(CallError(_, bindings)) => { From 8745f3ac8252f3b98b8fd8273298b8ae0c3cb091 Mon Sep 17 00:00:00 2001 From: Shunsuke Shibayama <45118249+mtshiba@users.noreply.github.com> Date: Thu, 29 May 2025 00:08:24 +0900 Subject: [PATCH 005/105] Apply suggestions from code review Co-authored-by: Carl Meyer --- .../resources/mdtest/function/return_type.md | 3 +++ crates/ty_python_semantic/src/types.rs | 2 +- 2 files changed, 4 insertions(+), 1 deletion(-) diff --git a/crates/ty_python_semantic/resources/mdtest/function/return_type.md b/crates/ty_python_semantic/resources/mdtest/function/return_type.md index 51cb7fc7c318a..d667b9352485e 100644 --- a/crates/ty_python_semantic/resources/mdtest/function/return_type.md +++ b/crates/ty_python_semantic/resources/mdtest/function/return_type.md @@ -341,6 +341,9 @@ class H(F, G): def f(self): raise NotImplementedError +# We can only reveal `Literal[2]` here, not `Never`, because of potential +# subclasses of `H`, which are bound by the annotated return types of +# `F.f` and `G.f`, but are not bound by our inference on `H.f`. reveal_type(H().f()) # revealed: Literal[2] class C2[T]: diff --git a/crates/ty_python_semantic/src/types.rs b/crates/ty_python_semantic/src/types.rs index 015fc27ec7ffa..d930989ba7a19 100644 --- a/crates/ty_python_semantic/src/types.rs +++ b/crates/ty_python_semantic/src/types.rs @@ -4425,7 +4425,7 @@ impl<'db> Type<'db> { } } - /// Returns the inferred return type of `self` if it is a function literal / bounded method. + /// Returns the inferred return type of `self` if it is a function literal / bound method. fn inferred_return_type(self, db: &'db dyn Db) -> Option> { match self { Type::FunctionLiteral(function_type) if !function_type.file(db).is_stub(db) => { From 18836cb356acd5cc51283235cbdc964e8b0b0250 Mon Sep 17 00:00:00 2001 From: Shunsuke Shibayama Date: Thu, 29 May 2025 15:03:39 +0900 Subject: [PATCH 006/105] Update types.rs --- crates/ty_python_semantic/src/types.rs | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/crates/ty_python_semantic/src/types.rs b/crates/ty_python_semantic/src/types.rs index d930989ba7a19..4d30867b766a7 100644 --- a/crates/ty_python_semantic/src/types.rs +++ b/crates/ty_python_semantic/src/types.rs @@ -4428,10 +4428,14 @@ impl<'db> Type<'db> { /// Returns the inferred return type of `self` if it is a function literal / bound method. fn inferred_return_type(self, db: &'db dyn Db) -> Option> { match self { - Type::FunctionLiteral(function_type) if !function_type.file(db).is_stub(db) => { + Type::FunctionLiteral(function_type) + if !function_type.file(db).is_stub(db.upcast()) => + { Some(function_type.inferred_return_type(db)) } - Type::BoundMethod(method_type) if !method_type.function(db).file(db).is_stub(db) => { + Type::BoundMethod(method_type) + if !method_type.function(db).file(db).is_stub(db.upcast()) => + { Some(method_type.inferred_return_type(db)) } _ => None, From 2c0cbb46d5782f924dac31c052d984d8529687a9 Mon Sep 17 00:00:00 2001 From: Shunsuke Shibayama Date: Mon, 9 Jun 2025 01:52:01 +0900 Subject: [PATCH 007/105] fix according to the review comments --- .../resources/mdtest/function/return_type.md | 10 +++++++++- crates/ty_python_semantic/src/types/infer.rs | 10 +++++++--- 2 files changed, 16 insertions(+), 4 deletions(-) diff --git a/crates/ty_python_semantic/resources/mdtest/function/return_type.md b/crates/ty_python_semantic/resources/mdtest/function/return_type.md index ebd716d9c2d56..0227d01480ee6 100644 --- a/crates/ty_python_semantic/resources/mdtest/function/return_type.md +++ b/crates/ty_python_semantic/resources/mdtest/function/return_type.md @@ -315,7 +315,8 @@ class C: class D(C): def f(self): return 2 - + # TODO: This should be an invalid-override error. + # If the override is invalid, the type of the method should be that of the base class method. def g(self, x: int): return 2 @@ -327,6 +328,7 @@ reveal_type(C().f()) # revealed: int reveal_type(D().f()) # revealed: int reveal_type(E().f()) # revealed: int reveal_type(C().g(1)) # revealed: Literal[1] +# TODO: should be `Literal[1]` reveal_type(D().g(1)) # revealed: int class F: @@ -341,10 +343,16 @@ class H(F, G): def f(self): raise NotImplementedError +class I(F, G): + @final + def f(self): + raise NotImplementedError + # We can only reveal `Literal[2]` here, not `Never`, because of potential # subclasses of `H`, which are bound by the annotated return types of # `F.f` and `G.f`, but are not bound by our inference on `H.f`. reveal_type(H().f()) # revealed: Literal[2] +reveal_type(I().f()) # revealed: Never class C2[T]: def f(self, x: T) -> T: diff --git a/crates/ty_python_semantic/src/types/infer.rs b/crates/ty_python_semantic/src/types/infer.rs index 10faacb84b814..ee05fcafd2b10 100644 --- a/crates/ty_python_semantic/src/types/infer.rs +++ b/crates/ty_python_semantic/src/types/infer.rs @@ -535,10 +535,14 @@ impl<'db> TypeInference<'db> { union = union.add(Type::none(db)); } if let Some(method_ty) = method_ty { - if let Some(return_ty) = method_ty.compatible_return_type(db) { + // If the method is not final and the typing is implicit, the inferred return type should be unioned with `Unknown`. + // If any method in a base class does not have an annotated return type, `compatible_return_type` will include `Unknown`. + // On the other hand, if the return types of all methods in the base classes are annotated, there is no need to include `Unknown`. + if !method_ty.is_final(db) { + let return_ty = method_ty + .compatible_return_type(db) + .unwrap_or(Type::unknown()); union = union.add(return_ty); - } else if !method_ty.is_final(db) { - union = union.add(Type::unknown()); } } union.build() From 4d611d9926e5df26e5097c00c7880a7a32124161 Mon Sep 17 00:00:00 2001 From: Shunsuke Shibayama Date: Sat, 21 Jun 2025 15:35:45 +0900 Subject: [PATCH 008/105] refactor: `inferred_return_ty` -> `infer_return_type` --- crates/ty_python_semantic/src/types.rs | 22 +++++++++---------- .../ty_python_semantic/src/types/call/bind.rs | 12 +++++----- .../ty_python_semantic/src/types/function.rs | 4 ++-- crates/ty_python_semantic/src/types/infer.rs | 14 ++++++------ 4 files changed, 26 insertions(+), 26 deletions(-) diff --git a/crates/ty_python_semantic/src/types.rs b/crates/ty_python_semantic/src/types.rs index 5b0dd501ef925..b23e637e1ce03 100644 --- a/crates/ty_python_semantic/src/types.rs +++ b/crates/ty_python_semantic/src/types.rs @@ -4383,17 +4383,17 @@ impl<'db> Type<'db> { } /// Returns the inferred return type of `self` if it is a function literal / bound method. - fn inferred_return_type(self, db: &'db dyn Db) -> Option> { + fn infer_return_type(self, db: &'db dyn Db) -> Option> { match self { Type::FunctionLiteral(function_type) if !function_type.file(db).is_stub(db.upcast()) => { - Some(function_type.inferred_return_type(db)) + Some(function_type.infer_return_type(db)) } Type::BoundMethod(method_type) if !method_type.function(db).file(db).is_stub(db.upcast()) => { - Some(method_type.inferred_return_type(db)) + Some(method_type.infer_return_type(db)) } _ => None, } @@ -4410,10 +4410,10 @@ impl<'db> Type<'db> { db: &'db dyn Db, argument_types: &CallArgumentTypes<'_, 'db>, ) -> Result, CallError<'db>> { - let inferred_return_ty = || self.inferred_return_type(db).unwrap_or(Type::unknown()); + let infer_return_type = || self.infer_return_type(db).unwrap_or(Type::unknown()); self.bindings(db) - .match_parameters(argument_types, inferred_return_ty) + .match_parameters(argument_types, infer_return_type) .check_types(db, argument_types) } @@ -4460,14 +4460,14 @@ impl<'db> Type<'db> { .place { Place::Type(dunder_callable, boundness) => { - let inferred_return_ty = || { + let infer_return_type = || { dunder_callable - .inferred_return_type(db) + .infer_return_type(db) .unwrap_or(Type::unknown()) }; let bindings = dunder_callable .bindings(db) - .match_parameters(argument_types, inferred_return_ty) + .match_parameters(argument_types, infer_return_type) .check_types(db, argument_types)?; if boundness == Boundness::PossiblyUnbound { return Err(CallDunderError::PossiblyUnbound(Box::new(bindings))); @@ -7161,14 +7161,14 @@ impl<'db> BoundMethodType<'db> { /// Infers this method scope's types and returns the inferred return type. #[salsa::tracked(cycle_fn=method_return_type_cycle_recover, cycle_initial=method_return_type_cycle_initial)] - pub(crate) fn inferred_return_type(self, db: &'db dyn Db) -> Type<'db> { + pub(crate) fn infer_return_type(self, db: &'db dyn Db) -> Type<'db> { let scope = self .function(db) .literal(db) .last_definition(db) .body_scope(db); let inference = infer_scope_types(db, scope); - inference.inferred_return_type(db, Some(self)) + inference.infer_return_type(db, Some(self)) } pub(crate) fn is_final(self, db: &'db dyn Db) -> bool { @@ -7213,7 +7213,7 @@ impl<'db> BoundMethodType<'db> { if let Some(return_ty) = base_signature.return_ty.or_else(|| { let base_method_ty = base_function_ty.into_bound_method_type(db, Type::instance(db, class)); - base_method_ty.inferred_return_type(db) + base_method_ty.infer_return_type(db) }) { if let Type::TypeVar(return_typevar) = return_ty { if let [signature] = diff --git a/crates/ty_python_semantic/src/types/call/bind.rs b/crates/ty_python_semantic/src/types/call/bind.rs index 321f8665fd4f2..4f0eb0a0b1736 100644 --- a/crates/ty_python_semantic/src/types/call/bind.rs +++ b/crates/ty_python_semantic/src/types/call/bind.rs @@ -103,7 +103,7 @@ impl<'db> Bindings<'db> { pub(crate) fn match_parameters( mut self, arguments: &CallArguments<'_>, - inferred_return_ty: impl Fn() -> Type<'db>, + infer_return_type: impl Fn() -> Type<'db>, ) -> Self { let mut argument_forms = vec![None; arguments.len()]; let mut conflicting_forms = vec![false; arguments.len()]; @@ -112,7 +112,7 @@ impl<'db> Bindings<'db> { arguments, &mut argument_forms, &mut conflicting_forms, - &inferred_return_ty, + &infer_return_type, ); } self.argument_forms = argument_forms.into(); @@ -1181,7 +1181,7 @@ impl<'db> CallableBinding<'db> { arguments: &CallArguments<'_>, argument_forms: &mut [Option], conflicting_forms: &mut [bool], - inferred_return_ty: impl Fn() -> Type<'db>, + infer_return_type: impl Fn() -> Type<'db>, ) { // If this callable is a bound method, prepend the self instance onto the arguments list // before checking. @@ -1192,7 +1192,7 @@ impl<'db> CallableBinding<'db> { arguments.as_ref(), argument_forms, conflicting_forms, - &inferred_return_ty, + &infer_return_type, ); } } @@ -1825,7 +1825,7 @@ impl<'db> Binding<'db> { arguments: &CallArguments<'_>, argument_forms: &mut [Option], conflicting_forms: &mut [bool], - inferred_return_ty: impl Fn() -> Type<'db>, + infer_return_type: impl Fn() -> Type<'db>, ) { let parameters = self.signature.parameters(); // The parameter that each argument is matched with. @@ -1934,7 +1934,7 @@ impl<'db> Binding<'db> { }); } - self.return_ty = self.signature.return_ty.unwrap_or_else(inferred_return_ty); + self.return_ty = self.signature.return_ty.unwrap_or_else(infer_return_type); self.argument_parameters = argument_parameters.into_boxed_slice(); self.parameter_tys = vec![None; parameters.len()].into_boxed_slice(); } diff --git a/crates/ty_python_semantic/src/types/function.rs b/crates/ty_python_semantic/src/types/function.rs index 14758474fc97c..9de2c6177e1e8 100644 --- a/crates/ty_python_semantic/src/types/function.rs +++ b/crates/ty_python_semantic/src/types/function.rs @@ -812,10 +812,10 @@ impl<'db> FunctionType<'db> { /// Infers this function scope's types and returns the inferred return type. #[salsa::tracked(cycle_fn=function_return_type_cycle_recover, cycle_initial=function_return_type_cycle_initial)] - pub(crate) fn inferred_return_type(self, db: &'db dyn Db) -> Type<'db> { + pub(crate) fn infer_return_type(self, db: &'db dyn Db) -> Type<'db> { let scope = self.literal(db).last_definition(db).body_scope(db); let inference = infer_scope_types(db, scope); - inference.inferred_return_type(db, None) + inference.infer_return_type(db, None) } } diff --git a/crates/ty_python_semantic/src/types/infer.rs b/crates/ty_python_semantic/src/types/infer.rs index 08f5c9d829027..2dbe0e359193c 100644 --- a/crates/ty_python_semantic/src/types/infer.rs +++ b/crates/ty_python_semantic/src/types/infer.rs @@ -529,7 +529,7 @@ impl<'db> TypeInference<'db> { /// or `None` if the region is not a function body. /// In the case of methods, the return type of the superclass method is further unioned. /// If there is no superclass method and this method is not `final`, it will be unioned with `Unknown`. - pub(crate) fn inferred_return_type( + pub(crate) fn infer_return_type( &self, db: &'db dyn Db, method_ty: Option>, @@ -5417,14 +5417,14 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> { } } - let inferred_return_ty = || { + let infer_return_type = || { callable_type - .inferred_return_type(self.db()) + .infer_return_type(self.db()) .unwrap_or(Type::unknown()) }; let bindings = callable_type .bindings(self.db()) - .match_parameters(&call_arguments, inferred_return_ty); + .match_parameters(&call_arguments, infer_return_type); let call_argument_types = self.infer_argument_types(arguments, call_arguments, &bindings.argument_forms); @@ -8146,14 +8146,14 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> { _ => CallArgumentTypes::positional([self.infer_type_expression(slice_node)]), }; - let inferred_return_ty = || { + let infer_return_type = || { value_ty - .inferred_return_type(self.db()) + .infer_return_type(self.db()) .unwrap_or(Type::unknown()) }; let binding = Binding::single(value_ty, generic_context.signature(self.db())); let bindings = match Bindings::from(binding) - .match_parameters(&call_argument_types, inferred_return_ty) + .match_parameters(&call_argument_types, infer_return_type) .check_types(self.db(), &call_argument_types) { Ok(bindings) => bindings, From 94201b627874c06c268c03923dbfd81576b3868f Mon Sep 17 00:00:00 2001 From: Shunsuke Shibayama Date: Sat, 21 Jun 2025 16:49:31 +0900 Subject: [PATCH 009/105] Infer return types of (non-final) methods according to MRO --- .../resources/mdtest/function/return_type.md | 10 ++-- crates/ty_python_semantic/src/types.rs | 60 +++++-------------- crates/ty_python_semantic/src/types/infer.rs | 6 +- .../src/types/signatures.rs | 22 +------ 4 files changed, 24 insertions(+), 74 deletions(-) diff --git a/crates/ty_python_semantic/resources/mdtest/function/return_type.md b/crates/ty_python_semantic/resources/mdtest/function/return_type.md index 0227d01480ee6..aef84f65263a5 100644 --- a/crates/ty_python_semantic/resources/mdtest/function/return_type.md +++ b/crates/ty_python_semantic/resources/mdtest/function/return_type.md @@ -329,7 +329,7 @@ reveal_type(D().f()) # revealed: int reveal_type(E().f()) # revealed: int reveal_type(C().g(1)) # revealed: Literal[1] # TODO: should be `Literal[1]` -reveal_type(D().g(1)) # revealed: int +reveal_type(D().g(1)) # revealed: Literal[2] | T class F: def f(self) -> Literal[1, 2]: @@ -340,18 +340,18 @@ class G: return 2 class H(F, G): + # TODO: should be an invalid-override error def f(self): raise NotImplementedError class I(F, G): + # TODO: should be an invalid-override error @final def f(self): raise NotImplementedError -# We can only reveal `Literal[2]` here, not `Never`, because of potential -# subclasses of `H`, which are bound by the annotated return types of -# `F.f` and `G.f`, but are not bound by our inference on `H.f`. -reveal_type(H().f()) # revealed: Literal[2] +# We use a return type of `F.f` according to the MRO. +reveal_type(H().f()) # revealed: Literal[1, 2] reveal_type(I().f()) # revealed: Never class C2[T]: diff --git a/crates/ty_python_semantic/src/types.rs b/crates/ty_python_semantic/src/types.rs index b23e637e1ce03..a436361bcd335 100644 --- a/crates/ty_python_semantic/src/types.rs +++ b/crates/ty_python_semantic/src/types.rs @@ -7189,59 +7189,31 @@ impl<'db> BoundMethodType<'db> { .any(|deco| deco == KnownFunction::Final) } - /// Returns the compatible return type for this method -- the intersection of the return types of the base class methods. - pub(crate) fn compatible_return_type(self, db: &'db dyn Db) -> Option> { + pub(crate) fn base_return_type(self, db: &'db dyn Db) -> Option> { let definition_scope = self.function(db).definition(db).scope(db); let index = semantic_index(db, definition_scope.file(db)); let module = parsed_module(db.upcast(), definition_scope.file(db)).load(db.upcast()); let class_definition = index.expect_single_definition(definition_scope.node(db).expect_class(&module)); let class = binding_type(db, class_definition).expect_class_type(db); - let (class_lit, _) = class.class_literal(db); let name = self.function(db).name(db); - let mut found = false; - let mut intersection = IntersectionBuilder::new(db); - for base in class_lit.explicit_bases(db) { - if let Some(base_function_ty) = base - .member(db, name) - .into_lookup_result() - .ok() - .and_then(|ty| ty.inner_type().into_function_literal()) - { - if let [base_signature] = base_function_ty.signature(db).overloads.as_slice() { - if let Some(return_ty) = base_signature.return_ty.or_else(|| { - let base_method_ty = - base_function_ty.into_bound_method_type(db, Type::instance(db, class)); - base_method_ty.infer_return_type(db) - }) { - if let Type::TypeVar(return_typevar) = return_ty { - if let [signature] = - self.function(db).signature(db).overloads.as_slice() - { - if let Ok(specialization) = - base_signature.specialize_with(db, signature) - { - if let Some(return_ty) = specialization.get(db, return_typevar) - { - found = true; - intersection = intersection.add_positive(return_ty); - } - } - } - } else { - found = true; - intersection = intersection.add_positive(return_ty); - } - } - } else { - // TODO: overloaded methods - } + let base = class + .iter_mro(db) + .nth(1) + .and_then(class_base::ClassBase::into_class)?; + let base_member = base.class_member(db, name, MemberLookupPolicy::default()); + if let Place::Type(Type::FunctionLiteral(base_func), _) = base_member.place { + if let [signature] = base_func.signature(db).overloads.as_slice() { + signature.return_ty.or_else(|| { + let base_method_ty = + base_func.into_bound_method_type(db, Type::instance(db, class)); + base_method_ty.infer_return_type(db) + }) + } else { + // TODO: Handle overloaded base methods. + None } - } - - if found { - Some(intersection.build()) } else { None } diff --git a/crates/ty_python_semantic/src/types/infer.rs b/crates/ty_python_semantic/src/types/infer.rs index 2dbe0e359193c..6c763a4667c49 100644 --- a/crates/ty_python_semantic/src/types/infer.rs +++ b/crates/ty_python_semantic/src/types/infer.rs @@ -544,12 +544,10 @@ impl<'db> TypeInference<'db> { } if let Some(method_ty) = method_ty { // If the method is not final and the typing is implicit, the inferred return type should be unioned with `Unknown`. - // If any method in a base class does not have an annotated return type, `compatible_return_type` will include `Unknown`. + // If any method in a base class does not have an annotated return type, `base_return_type` will include `Unknown`. // On the other hand, if the return types of all methods in the base classes are annotated, there is no need to include `Unknown`. if !method_ty.is_final(db) { - let return_ty = method_ty - .compatible_return_type(db) - .unwrap_or(Type::unknown()); + let return_ty = method_ty.base_return_type(db).unwrap_or(Type::unknown()); union = union.add(return_ty); } } diff --git a/crates/ty_python_semantic/src/types/signatures.rs b/crates/ty_python_semantic/src/types/signatures.rs index 0746cdbeaf233..bcaeb6be1204c 100644 --- a/crates/ty_python_semantic/src/types/signatures.rs +++ b/crates/ty_python_semantic/src/types/signatures.rs @@ -17,9 +17,7 @@ use smallvec::{SmallVec, smallvec}; use super::{DynamicType, Type, TypeVarVariance, definition_expression_type}; use crate::semantic_index::definition::Definition; -use crate::types::generics::{ - GenericContext, Specialization, SpecializationBuilder, SpecializationError, -}; +use crate::types::generics::GenericContext; use crate::types::{ClassLiteral, TypeMapping, TypeRelation, TypeVarInstance, todo_type}; use crate::{Db, FxOrderSet}; use ruff_python_ast::{self as ast, name::Name}; @@ -405,24 +403,6 @@ impl<'db> Signature<'db> { } } - pub(crate) fn specialize_with( - &self, - db: &'db dyn Db, - signature: &Signature<'db>, - ) -> Result, SpecializationError<'db>> { - debug_assert!( - self.generic_context.is_some(), - "Cannot specialize a signature without a generic context" - ); - let mut specialization = SpecializationBuilder::new(db); - for (self_param, param) in self.parameters().iter().zip(signature.parameters().iter()) { - if let Some((self_ty, ty)) = self_param.annotated_type().zip(param.annotated_type()) { - specialization.infer(self_ty, ty)?; - } - } - Ok(specialization.build(self.generic_context.unwrap())) - } - pub(crate) fn find_legacy_typevars( &self, db: &'db dyn Db, From 7f7c2304b521b4866eef75861bd0abad5492bc35 Mon Sep 17 00:00:00 2001 From: Shunsuke Shibayama Date: Sat, 21 Jun 2025 20:18:10 +0900 Subject: [PATCH 010/105] Use `callable_type` to get the return type instead of passing a closure --- .../resources/mdtest/call/union.md | 6 ++--- .../src/semantic_index/use_def.rs | 1 - crates/ty_python_semantic/src/types.rs | 11 ++------ .../ty_python_semantic/src/types/call/bind.rs | 26 +++++++------------ crates/ty_python_semantic/src/types/infer.rs | 14 ++-------- 5 files changed, 17 insertions(+), 41 deletions(-) diff --git a/crates/ty_python_semantic/resources/mdtest/call/union.md b/crates/ty_python_semantic/resources/mdtest/call/union.md index 5edbdb29819b6..3f1b2d1a03446 100644 --- a/crates/ty_python_semantic/resources/mdtest/call/union.md +++ b/crates/ty_python_semantic/resources/mdtest/call/union.md @@ -111,7 +111,7 @@ def _(flag: bool): # error: [call-non-callable] "Object of type `Literal["This is a string literal"]` is not callable" x = f(3) - reveal_type(x) # revealed: Unknown + reveal_type(x) # revealed: None | Unknown ``` ## Union of binding errors @@ -128,7 +128,7 @@ def _(flag: bool): # error: [too-many-positional-arguments] "Too many positional arguments to function `f1`: expected 0, got 1" # error: [too-many-positional-arguments] "Too many positional arguments to function `f2`: expected 0, got 1" x = f(3) - reveal_type(x) # revealed: Unknown + reveal_type(x) # revealed: None ``` ## One not-callable, one wrong argument @@ -146,7 +146,7 @@ def _(flag: bool): # error: [too-many-positional-arguments] "Too many positional arguments to function `f1`: expected 0, got 1" # error: [call-non-callable] "Object of type `C` is not callable" x = f(3) - reveal_type(x) # revealed: Unknown + reveal_type(x) # revealed: None | Unknown ``` ## Union including a special-cased function diff --git a/crates/ty_python_semantic/src/semantic_index/use_def.rs b/crates/ty_python_semantic/src/semantic_index/use_def.rs index 8ac7f91811984..1e01bed62771a 100644 --- a/crates/ty_python_semantic/src/semantic_index/use_def.rs +++ b/crates/ty_python_semantic/src/semantic_index/use_def.rs @@ -455,7 +455,6 @@ impl<'db> UseDefMap<'db> { .map(|place_id| (place_id, self.public_bindings(place_id))) } - /// This function is intended to be called only once inside `TypeInferenceBuilder::infer_function_body`. pub(crate) fn can_implicitly_return_none(&self, db: &dyn crate::Db) -> bool { !self .reachability_constraints diff --git a/crates/ty_python_semantic/src/types.rs b/crates/ty_python_semantic/src/types.rs index a436361bcd335..87e8eb4226c82 100644 --- a/crates/ty_python_semantic/src/types.rs +++ b/crates/ty_python_semantic/src/types.rs @@ -4410,10 +4410,8 @@ impl<'db> Type<'db> { db: &'db dyn Db, argument_types: &CallArgumentTypes<'_, 'db>, ) -> Result, CallError<'db>> { - let infer_return_type = || self.infer_return_type(db).unwrap_or(Type::unknown()); - self.bindings(db) - .match_parameters(argument_types, infer_return_type) + .match_parameters(db, argument_types) .check_types(db, argument_types) } @@ -4460,14 +4458,9 @@ impl<'db> Type<'db> { .place { Place::Type(dunder_callable, boundness) => { - let infer_return_type = || { - dunder_callable - .infer_return_type(db) - .unwrap_or(Type::unknown()) - }; let bindings = dunder_callable .bindings(db) - .match_parameters(argument_types, infer_return_type) + .match_parameters(db, argument_types) .check_types(db, argument_types)?; if boundness == Boundness::PossiblyUnbound { return Err(CallDunderError::PossiblyUnbound(Box::new(bindings))); diff --git a/crates/ty_python_semantic/src/types/call/bind.rs b/crates/ty_python_semantic/src/types/call/bind.rs index 4f0eb0a0b1736..6e83fee6ca6b6 100644 --- a/crates/ty_python_semantic/src/types/call/bind.rs +++ b/crates/ty_python_semantic/src/types/call/bind.rs @@ -102,18 +102,13 @@ impl<'db> Bindings<'db> { /// verify that each argument type is assignable to the corresponding parameter type. pub(crate) fn match_parameters( mut self, + db: &'db dyn Db, arguments: &CallArguments<'_>, - infer_return_type: impl Fn() -> Type<'db>, ) -> Self { let mut argument_forms = vec![None; arguments.len()]; let mut conflicting_forms = vec![false; arguments.len()]; for binding in &mut self.elements { - binding.match_parameters( - arguments, - &mut argument_forms, - &mut conflicting_forms, - &infer_return_type, - ); + binding.match_parameters(db, arguments, &mut argument_forms, &mut conflicting_forms); } self.argument_forms = argument_forms.into(); self.conflicting_forms = conflicting_forms.into(); @@ -1178,22 +1173,17 @@ impl<'db> CallableBinding<'db> { fn match_parameters( &mut self, + db: &'db dyn Db, arguments: &CallArguments<'_>, argument_forms: &mut [Option], conflicting_forms: &mut [bool], - infer_return_type: impl Fn() -> Type<'db>, ) { // If this callable is a bound method, prepend the self instance onto the arguments list // before checking. let arguments = arguments.with_self(self.bound_type); for overload in &mut self.overloads { - overload.match_parameters( - arguments.as_ref(), - argument_forms, - conflicting_forms, - &infer_return_type, - ); + overload.match_parameters(db, arguments.as_ref(), argument_forms, conflicting_forms); } } @@ -1822,10 +1812,10 @@ impl<'db> Binding<'db> { fn match_parameters( &mut self, + db: &'db dyn Db, arguments: &CallArguments<'_>, argument_forms: &mut [Option], conflicting_forms: &mut [bool], - infer_return_type: impl Fn() -> Type<'db>, ) { let parameters = self.signature.parameters(); // The parameter that each argument is matched with. @@ -1934,7 +1924,11 @@ impl<'db> Binding<'db> { }); } - self.return_ty = self.signature.return_ty.unwrap_or_else(infer_return_type); + self.return_ty = self.signature.return_ty.unwrap_or_else(|| { + self.callable_type + .infer_return_type(db) + .unwrap_or(Type::unknown()) + }); self.argument_parameters = argument_parameters.into_boxed_slice(); self.parameter_tys = vec![None; parameters.len()].into_boxed_slice(); } diff --git a/crates/ty_python_semantic/src/types/infer.rs b/crates/ty_python_semantic/src/types/infer.rs index 6c763a4667c49..6f43ee9a62506 100644 --- a/crates/ty_python_semantic/src/types/infer.rs +++ b/crates/ty_python_semantic/src/types/infer.rs @@ -5415,14 +5415,9 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> { } } - let infer_return_type = || { - callable_type - .infer_return_type(self.db()) - .unwrap_or(Type::unknown()) - }; let bindings = callable_type .bindings(self.db()) - .match_parameters(&call_arguments, infer_return_type); + .match_parameters(self.db(), &call_arguments); let call_argument_types = self.infer_argument_types(arguments, call_arguments, &bindings.argument_forms); @@ -8144,14 +8139,9 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> { _ => CallArgumentTypes::positional([self.infer_type_expression(slice_node)]), }; - let infer_return_type = || { - value_ty - .infer_return_type(self.db()) - .unwrap_or(Type::unknown()) - }; let binding = Binding::single(value_ty, generic_context.signature(self.db())); let bindings = match Bindings::from(binding) - .match_parameters(&call_argument_types, infer_return_type) + .match_parameters(self.db(), &call_argument_types) .check_types(self.db(), &call_argument_types) { Ok(bindings) => bindings, From 2c534c6875054d008544b3638468483051c22697 Mon Sep 17 00:00:00 2001 From: Shunsuke Shibayama Date: Sun, 22 Jun 2025 12:29:13 +0900 Subject: [PATCH 011/105] Update types.rs --- crates/ty_python_semantic/src/types.rs | 14 +++++++++----- 1 file changed, 9 insertions(+), 5 deletions(-) diff --git a/crates/ty_python_semantic/src/types.rs b/crates/ty_python_semantic/src/types.rs index 87e8eb4226c82..48b3589d37ad7 100644 --- a/crates/ty_python_semantic/src/types.rs +++ b/crates/ty_python_semantic/src/types.rs @@ -7174,9 +7174,13 @@ impl<'db> BoundMethodType<'db> { let definition_scope = self.function(db).definition(db).scope(db); let index = semantic_index(db, definition_scope.file(db)); let module = parsed_module(db.upcast(), definition_scope.file(db)).load(db.upcast()); - let class_definition = - index.expect_single_definition(definition_scope.node(db).expect_class(&module)); - let class_ty = binding_type(db, class_definition).expect_class_literal(); + let Some(class_node) = definition_scope.node(db).as_class(&module) else { + return false; + }; + let class_definition = index.expect_single_definition(class_node); + let Some(class_ty) = binding_type(db, class_definition).into_class_literal() else { + return false; + }; class_ty .known_function_decorators(db) .any(|deco| deco == KnownFunction::Final) @@ -7187,8 +7191,8 @@ impl<'db> BoundMethodType<'db> { let index = semantic_index(db, definition_scope.file(db)); let module = parsed_module(db.upcast(), definition_scope.file(db)).load(db.upcast()); let class_definition = - index.expect_single_definition(definition_scope.node(db).expect_class(&module)); - let class = binding_type(db, class_definition).expect_class_type(db); + index.expect_single_definition(definition_scope.node(db).as_class(&module)?); + let class = binding_type(db, class_definition).to_class_type(db)?; let name = self.function(db).name(db); let base = class From 60f1b0f856e7b2e24906bbda57e3794b5edaf913 Mon Sep 17 00:00:00 2001 From: Shunsuke Shibayama Date: Sat, 5 Jul 2025 12:31:46 +0900 Subject: [PATCH 012/105] Refactor: TypeAndRange -> Returnee TypeAndRange::ty: Type -> Returnee::expression: Option --- crates/ty_python_semantic/src/types/infer.rs | 96 +++++++++++--------- 1 file changed, 52 insertions(+), 44 deletions(-) diff --git a/crates/ty_python_semantic/src/types/infer.rs b/crates/ty_python_semantic/src/types/infer.rs index 6f43ee9a62506..f62dfecf5af5d 100644 --- a/crates/ty_python_semantic/src/types/infer.rs +++ b/crates/ty_python_semantic/src/types/infer.rs @@ -416,8 +416,8 @@ impl<'db> InferenceRegion<'db> { } #[derive(Debug, Clone, Copy, Eq, PartialEq)] -struct TypeAndRange<'db> { - ty: Type<'db>, +struct Returnee { + expression: Option, range: TextRange, } @@ -442,10 +442,10 @@ pub(crate) struct TypeInference<'db> { /// The scope this region is part of. scope: ScopeId<'db>, - /// The returned types of this region (if this is a function body). + /// The returnees of this region (if this is a function body). /// /// These are stored in `Vec` to delay the creation of the union type as long as possible. - return_types: Vec>, + returnees: Vec>, /// The fallback type for missing expressions/bindings/declarations. /// @@ -462,7 +462,7 @@ impl<'db> TypeInference<'db> { deferred: FxHashSet::default(), diagnostics: TypeCheckDiagnostics::default(), scope, - return_types: vec![], + returnees: vec![], cycle_fallback_type: None, } } @@ -475,7 +475,7 @@ impl<'db> TypeInference<'db> { deferred: FxHashSet::default(), diagnostics: TypeCheckDiagnostics::default(), scope, - return_types: vec![], + returnees: vec![], cycle_fallback_type: Some(cycle_fallback_type), } } @@ -535,8 +535,11 @@ impl<'db> TypeInference<'db> { method_ty: Option>, ) -> Type<'db> { let mut union = UnionBuilder::new(db); - for ty in &self.return_types { - union = union.add(*ty); + for returnee in &self.returnees { + let ty = returnee.map_or(Type::none(db), |expression| { + self.expression_type(expression) + }); + union = union.add(ty); } let use_def = use_def_map(db, self.scope); if use_def.can_implicitly_return_none(db) { @@ -560,7 +563,7 @@ impl<'db> TypeInference<'db> { self.declarations.shrink_to_fit(); self.diagnostics.shrink_to_fit(); self.deferred.shrink_to_fit(); - self.return_types.shrink_to_fit(); + self.returnees.shrink_to_fit(); } } @@ -637,8 +640,8 @@ pub(super) struct TypeInferenceBuilder<'db, 'ast> { /// The type inference results types: TypeInference<'db>, - /// The returned types and their corresponding ranges of the region, if it is a function body. - return_types_and_ranges: Vec>, + /// The returnees and their corresponding ranges of the region, if it is a function body. + returnees: Vec, /// A set of functions that have been defined **and** called in this region. /// @@ -697,7 +700,7 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> { context: InferContext::new(db, scope, module), index, region, - return_types_and_ranges: vec![], + returnees: vec![], called_functions: FxHashSet::default(), deferred_state: DeferredExpressionState::None, types: TypeInference::empty(scope), @@ -1870,9 +1873,8 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> { ); } - fn record_return_type(&mut self, ty: Type<'db>, range: TextRange) { - self.return_types_and_ranges - .push(TypeAndRange { ty, range }); + fn record_returnee(&mut self, expression: Option, range: TextRange) { + self.returnees.push(Returnee { expression, range }); } fn infer_module(&mut self, module: &ast::ModModule) { @@ -2032,8 +2034,7 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> { } } - let has_empty_body = - self.return_types_and_ranges.is_empty() && is_stub_suite(&function.body); + let has_empty_body = self.returnees.is_empty() && is_stub_suite(&function.body); let mut enclosing_class_context = None; @@ -2086,35 +2087,41 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> { return; } - for invalid in self - .return_types_and_ranges + for (invalid_ty, range) in self + .returnees .iter() .copied() - .filter_map(|ty_range| match ty_range.ty { - // We skip `is_assignable_to` checks for `NotImplemented`, - // so we remove it beforehand. - Type::Union(union) => Some(TypeAndRange { - ty: union.filter(self.db(), |ty| !ty.is_notimplemented(self.db())), - range: ty_range.range, - }), - ty if ty.is_notimplemented(self.db()) => None, - _ => Some(ty_range), + .filter_map(|returnee| { + match returnee + .expression + .map_or(Type::none(self.db()), |expression| { + self.types.expression_type(expression) + }) { + // We skip `is_assignable_to` checks for `NotImplemented`, + // so we remove it beforehand. + Type::Union(union) => Some(( + union.filter(self.db(), |ty| !ty.is_notimplemented(self.db())), + returnee.range, + )), + ty if ty.is_notimplemented(self.db()) => None, + ty => Some((ty, returnee.range)), + } }) - .filter(|ty_range| !ty_range.ty.is_assignable_to(self.db(), expected_ty)) + .filter(|(ty, _)| !ty.is_assignable_to(self.db(), expected_ty)) { report_invalid_return_type( &self.context, - invalid.range, + range, returns.range(), declared_ty, - invalid.ty, + invalid_ty, ); } let use_def = self.index.use_def_map(scope_id); if use_def.can_implicitly_return_none(self.db()) && !Type::none(self.db()).is_assignable_to(self.db(), expected_ty) { - let no_return = self.return_types_and_ranges.is_empty(); + let no_return = self.returnees.is_empty(); report_implicit_return_type( &self.context, returns.range(), @@ -4531,15 +4538,16 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> { } fn infer_return_statement(&mut self, ret: &ast::StmtReturn) { - if let Some(ty) = self.infer_optional_expression(ret.value.as_deref()) { - let range = ret - .value - .as_ref() - .map_or(ret.range(), |value| value.range()); - self.record_return_type(ty, range); - } else { - self.record_return_type(Type::none(self.db()), ret.range()); - } + self.infer_optional_expression(ret.value.as_deref()); + let range = ret + .value + .as_ref() + .map_or(ret.range(), |value| value.range()); + let expression = ret + .value + .as_ref() + .map(|expr| expr.scoped_expression_id(self.db(), self.scope())); + self.record_returnee(expression, range); } fn infer_delete_statement(&mut self, delete: &ast::StmtDelete) { @@ -8588,10 +8596,10 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> { self.infer_region(); self.types.diagnostics = self.context.finish(); self.types.shrink_to_fit(); - self.types.return_types = self - .return_types_and_ranges + self.types.returnees = self + .returnees .into_iter() - .map(|ty_range| ty_range.ty) + .map(|returnee| returnee.expression) .collect(); self.types } From c0f218e898ab0ff0cb1a184a6f9b8bcbff4e6408 Mon Sep 17 00:00:00 2001 From: Shunsuke Shibayama Date: Mon, 7 Jul 2025 20:14:53 +0900 Subject: [PATCH 013/105] Add `CallStack` to prevent divergence in type inference for recursive functions --- crates/ruff_graph/src/db.rs | 10 +++- crates/ty_ide/src/db.rs | 8 ++- crates/ty_project/src/db.rs | 17 +++++- .../resources/mdtest/function/return_type.md | 44 +++++++++++++-- crates/ty_python_semantic/src/db.rs | 49 ++++++++++++++++- crates/ty_python_semantic/src/lib.rs | 2 +- crates/ty_python_semantic/src/types.rs | 54 ++++++++----------- .../ty_python_semantic/src/types/function.rs | 30 ++++------- crates/ty_python_semantic/src/types/infer.rs | 42 +++++++++++++++ crates/ty_python_semantic/src/types/narrow.rs | 35 ++++++------ crates/ty_python_semantic/tests/corpus.rs | 10 +++- crates/ty_test/src/db.rs | 8 ++- fuzz/fuzz_targets/ty_check_invalid_syntax.rs | 8 ++- 13 files changed, 231 insertions(+), 86 deletions(-) diff --git a/crates/ruff_graph/src/db.rs b/crates/ruff_graph/src/db.rs index 48ee62a348624..62b7afa6d6675 100644 --- a/crates/ruff_graph/src/db.rs +++ b/crates/ruff_graph/src/db.rs @@ -9,8 +9,9 @@ use ruff_db::vendored::{VendoredFileSystem, VendoredFileSystemBuilder}; use ruff_python_ast::PythonVersion; use ty_python_semantic::lint::{LintRegistry, RuleSelection}; use ty_python_semantic::{ - Db, Program, ProgramSettings, PythonEnvironment, PythonPlatform, PythonVersionSource, - PythonVersionWithSource, SearchPathSettings, SysPrefixPathOrigin, default_lint_registry, + CallStack, Db, Program, ProgramSettings, PythonEnvironment, PythonPlatform, + PythonVersionSource, PythonVersionWithSource, SearchPathSettings, SysPrefixPathOrigin, + default_lint_registry, }; static EMPTY_VENDORED: std::sync::LazyLock = std::sync::LazyLock::new(|| { @@ -26,6 +27,7 @@ pub struct ModuleDb { files: Files, system: OsSystem, rule_selection: Arc, + call_stack: CallStack, } impl ModuleDb { @@ -98,6 +100,10 @@ impl Db for ModuleDb { fn lint_registry(&self) -> &LintRegistry { default_lint_registry() } + + fn call_stack(&self) -> &CallStack { + &self.call_stack + } } #[salsa::db] diff --git a/crates/ty_ide/src/db.rs b/crates/ty_ide/src/db.rs index 3d9a9c564eb66..68f49970557e2 100644 --- a/crates/ty_ide/src/db.rs +++ b/crates/ty_ide/src/db.rs @@ -13,7 +13,7 @@ pub(crate) mod tests { use ruff_db::system::{DbWithTestSystem, System, TestSystem}; use ruff_db::vendored::VendoredFileSystem; use ty_python_semantic::lint::{LintRegistry, RuleSelection}; - use ty_python_semantic::{Db as SemanticDb, Program, default_lint_registry}; + use ty_python_semantic::{CallStack, Db as SemanticDb, Program, default_lint_registry}; type Events = Arc>>; @@ -26,6 +26,7 @@ pub(crate) mod tests { vendored: VendoredFileSystem, events: Events, rule_selection: Arc, + call_stack: CallStack, } #[expect(dead_code)] @@ -46,6 +47,7 @@ pub(crate) mod tests { events, files: Files::default(), rule_selection: Arc::new(RuleSelection::from_registry(default_lint_registry())), + call_stack: CallStack::default(), } } @@ -107,6 +109,10 @@ pub(crate) mod tests { fn lint_registry(&self) -> &LintRegistry { default_lint_registry() } + + fn call_stack(&self) -> &CallStack { + &self.call_stack + } } #[salsa::db] diff --git a/crates/ty_project/src/db.rs b/crates/ty_project/src/db.rs index cda81a9192904..45e8b2ebfc9b6 100644 --- a/crates/ty_project/src/db.rs +++ b/crates/ty_project/src/db.rs @@ -15,7 +15,7 @@ use salsa::Event; use salsa::plumbing::ZalsaDatabase; use ty_ide::Db as IdeDb; use ty_python_semantic::lint::{LintRegistry, RuleSelection}; -use ty_python_semantic::{Db as SemanticDb, Program}; +use ty_python_semantic::{CallStack, Db as SemanticDb, Program}; mod changes; @@ -39,6 +39,8 @@ pub struct ProjectDatabase { // However, for this to work it's important that the `storage` is dropped AFTER any `Arc` that // we try to mutably borrow using `Arc::get_mut` (like `system`). storage: salsa::Storage, + + call_stack: CallStack, } impl ProjectDatabase { @@ -63,6 +65,7 @@ impl ProjectDatabase { }), files: Files::default(), system: Arc::new(system), + call_stack: CallStack::default(), }; // TODO: Use the `program_settings` to compute the key for the database's persistent @@ -404,6 +407,10 @@ impl SemanticDb for ProjectDatabase { fn lint_registry(&self) -> &LintRegistry { &DEFAULT_LINT_REGISTRY } + + fn call_stack(&self) -> &CallStack { + &self.call_stack + } } #[salsa::db] @@ -462,7 +469,7 @@ pub(crate) mod tests { use ty_python_semantic::lint::{LintRegistry, RuleSelection}; use crate::DEFAULT_LINT_REGISTRY; - use crate::db::Db; + use crate::db::{CallStack, Db}; use crate::{Project, ProjectMetadata}; type Events = Arc>>; @@ -476,6 +483,7 @@ pub(crate) mod tests { system: TestSystem, vendored: VendoredFileSystem, project: Option, + call_stack: CallStack, } impl TestDb { @@ -494,6 +502,7 @@ pub(crate) mod tests { files: Files::default(), events, project: None, + call_stack: CallStack::default(), }; let project = Project::from_metadata(&db, project).unwrap(); @@ -553,6 +562,10 @@ pub(crate) mod tests { fn lint_registry(&self) -> &LintRegistry { &DEFAULT_LINT_REGISTRY } + + fn call_stack(&self) -> &CallStack { + &self.call_stack + } } #[salsa::db] diff --git a/crates/ty_python_semantic/resources/mdtest/function/return_type.md b/crates/ty_python_semantic/resources/mdtest/function/return_type.md index aef84f65263a5..df9c178dc1312 100644 --- a/crates/ty_python_semantic/resources/mdtest/function/return_type.md +++ b/crates/ty_python_semantic/resources/mdtest/function/return_type.md @@ -221,7 +221,8 @@ def generator(): reveal_type(generator()) # revealed: None ``` -The return type of a recursive function is also inferred. +The return type of a recursive function is also inferred. When the return type inference would +diverge, it is truncated and replaced with the type `Unknown`. ```py def fibonacci(n: int): @@ -232,7 +233,8 @@ def fibonacci(n: int): else: return fibonacci(n - 1) + fibonacci(n - 2) -reveal_type(fibonacci(5)) # revealed: int +# TODO: it may be better to infer this as `int` if we can +reveal_type(fibonacci(5)) # revealed: Literal[0, 1] | Unknown def even(n: int): if n == 0: @@ -246,8 +248,42 @@ def odd(n: int): else: return even(n - 1) -reveal_type(even(1)) # revealed: bool -reveal_type(odd(1)) # revealed: bool +# TODO: it may be better to infer these as `bool` if we can +reveal_type(even(1)) # revealed: bool | Unknown +reveal_type(odd(1)) # revealed: bool | Unknown + +def repeat_a(n: int): + if n <= 0: + return "" + else: + return repeat_a(n - 1) + "a" + +# TODO: it may be better to infer this as `str` if we can +reveal_type(repeat_a(3)) # revealed: Literal[""] | Unknown + +def divergent(value): + if type(value) is tuple: + return (divergent(value[0]),) + else: + return None + +# tuple[tuple[tuple[...] | None] | None] | None => tuple[Unknown] | None +reveal_type(divergent((1,))) # revealed: tuple[Unknown] | None + +def nested_scope(): + def inner(): + return nested_scope() + return inner() + +reveal_type(nested_scope()) # revealed: Unknown + +def eager_nested_scope(): + class A: + x = eager_nested_scope() + + return A.x + +reveal_type(eager_nested_scope()) # revealed: Unknown ``` ### Class method diff --git a/crates/ty_python_semantic/src/db.rs b/crates/ty_python_semantic/src/db.rs index 815c37653cbcf..ea19b16202a79 100644 --- a/crates/ty_python_semantic/src/db.rs +++ b/crates/ty_python_semantic/src/db.rs @@ -1,6 +1,45 @@ +use std::hash::{Hash, Hasher}; +use std::sync::{Arc, Mutex}; + use crate::lint::{LintRegistry, RuleSelection}; +use crate::semantic_index::place::FileScopeId; use ruff_db::Db as SourceDb; use ruff_db::files::File; +use rustc_hash::FxHasher; + +/// A stack of the currently inferred function scopes. +/// Used to monitor type inference for recursive functions to ensure they do not diverge. +/// This call stack is currently only used to infer functions with unspecified return types ​​and does not faithfully represent the actual call stack. +#[derive(Debug, Default, Clone)] +pub struct CallStack(Arc>>); + +impl CallStack { + pub fn new() -> Self { + CallStack(Arc::new(Mutex::new(Vec::new()))) + } + + pub fn push(&self, file: File, scope: FileScopeId) { + self.0.lock().unwrap().push((file, scope)); + } + + pub fn pop(&self) -> Option<(File, FileScopeId)> { + self.0.lock().unwrap().pop() + } + + pub fn contains(&self, file: File, scope: FileScopeId) -> bool { + self.0 + .lock() + .unwrap() + .iter() + .any(|(f, s)| f == &file && s == &scope) + } + + pub fn hash_value(&self) -> u64 { + let mut hasher = FxHasher::default(); + self.0.lock().unwrap().hash(&mut hasher); + hasher.finish() + } +} /// Database giving access to semantic information about a Python program. #[salsa::db] @@ -11,6 +50,8 @@ pub trait Db: SourceDb { fn rule_selection(&self, file: File) -> &RuleSelection; fn lint_registry(&self) -> &LintRegistry; + + fn call_stack(&self) -> &CallStack; } #[cfg(test)] @@ -23,7 +64,7 @@ pub(crate) mod tests { default_lint_registry, }; - use super::Db; + use super::{CallStack, Db}; use crate::lint::{LintRegistry, RuleSelection}; use anyhow::Context; use ruff_db::Db as SourceDb; @@ -45,6 +86,7 @@ pub(crate) mod tests { vendored: VendoredFileSystem, events: Events, rule_selection: Arc, + call_stack: CallStack, } impl TestDb { @@ -64,6 +106,7 @@ pub(crate) mod tests { events, files: Files::default(), rule_selection: Arc::new(RuleSelection::from_registry(default_lint_registry())), + call_stack: CallStack::default(), } } @@ -125,6 +168,10 @@ pub(crate) mod tests { fn lint_registry(&self) -> &LintRegistry { default_lint_registry() } + + fn call_stack(&self) -> &CallStack { + &self.call_stack + } } #[salsa::db] diff --git a/crates/ty_python_semantic/src/lib.rs b/crates/ty_python_semantic/src/lib.rs index fe8305e77ee15..64d262303a559 100644 --- a/crates/ty_python_semantic/src/lib.rs +++ b/crates/ty_python_semantic/src/lib.rs @@ -4,7 +4,7 @@ use rustc_hash::FxHasher; use crate::lint::{LintRegistry, LintRegistryBuilder}; use crate::suppression::{INVALID_IGNORE_COMMENT, UNKNOWN_RULE, UNUSED_IGNORE_COMMENT}; -pub use db::Db; +pub use db::{CallStack, Db}; pub use module_name::ModuleName; pub use module_resolver::{ KnownModule, Module, SearchPathValidationError, SearchPaths, resolve_module, diff --git a/crates/ty_python_semantic/src/types.rs b/crates/ty_python_semantic/src/types.rs index 1714d6ff10cb7..ca90f2a9e96be 100644 --- a/crates/ty_python_semantic/src/types.rs +++ b/crates/ty_python_semantic/src/types.rs @@ -47,7 +47,7 @@ use crate::types::generics::{ walk_partial_specialization, walk_specialization, }; pub use crate::types::ide_support::all_members; -use crate::types::infer::infer_unpack_types; +use crate::types::infer::{infer_scope_types_with_call_stack, infer_unpack_types}; use crate::types::mro::{Mro, MroError, MroIterator}; pub(crate) use crate::types::narrow::infer_narrowing_constraint; use crate::types::signatures::{Parameter, ParameterForm, Parameters, walk_signature}; @@ -88,22 +88,6 @@ mod definition; #[cfg(test)] mod property_tests; -fn method_return_type_cycle_recover<'db>( - _db: &'db dyn Db, - _value: &Type<'db>, - _count: u32, - _self: BoundMethodType<'db>, -) -> salsa::CycleRecoveryAction> { - salsa::CycleRecoveryAction::Iterate -} - -fn method_return_type_cycle_initial<'db>( - _db: &'db dyn Db, - _self: BoundMethodType<'db>, -) -> Type<'db> { - Type::Never -} - #[salsa::tracked(returns(ref), heap_size=get_size2::GetSize::get_heap_size)] pub fn check_types(db: &dyn Db, file: File) -> TypeCheckDiagnostics { let _span = tracing::trace_span!("check_types", ?file).entered(); @@ -7211,17 +7195,30 @@ impl<'db> BoundMethodType<'db> { } /// Infers this method scope's types and returns the inferred return type. - #[salsa::tracked(cycle_fn=method_return_type_cycle_recover, cycle_initial=method_return_type_cycle_initial)] pub(crate) fn infer_return_type(self, db: &'db dyn Db) -> Type<'db> { let scope = self .function(db) .literal(db) .last_definition(db) .body_scope(db); - let inference = infer_scope_types(db, scope); + if db + .call_stack() + .contains(scope.file(db), scope.file_scope_id(db)) + { + return Type::unknown(); + } + let inference = infer_scope_types_with_call_stack(db, scope, db.call_stack().hash_value()); inference.infer_return_type(db, Some(self)) } + #[salsa::tracked] + fn class_definition(self, db: &'db dyn Db) -> Option> { + let definition_scope = self.function(db).definition(db).scope(db); + let index = semantic_index(db, definition_scope.file(db)); + let module = parsed_module(db, definition_scope.file(db)).load(db); + Some(index.expect_single_definition(definition_scope.node(db).as_class(&module)?)) + } + pub(crate) fn is_final(self, db: &'db dyn Db) -> bool { if self .function(db) @@ -7229,14 +7226,10 @@ impl<'db> BoundMethodType<'db> { { return true; } - let definition_scope = self.function(db).definition(db).scope(db); - let index = semantic_index(db, definition_scope.file(db)); - let module = parsed_module(db, definition_scope.file(db)).load(db); - let Some(class_node) = definition_scope.node(db).as_class(&module) else { - return false; - }; - let class_definition = index.expect_single_definition(class_node); - let Some(class_ty) = binding_type(db, class_definition).into_class_literal() else { + let Some(class_ty) = self + .class_definition(db) + .and_then(|class| binding_type(db, class).into_class_literal()) + else { return false; }; class_ty @@ -7245,12 +7238,7 @@ impl<'db> BoundMethodType<'db> { } pub(crate) fn base_return_type(self, db: &'db dyn Db) -> Option> { - let definition_scope = self.function(db).definition(db).scope(db); - let index = semantic_index(db, definition_scope.file(db)); - let module = parsed_module(db, definition_scope.file(db)).load(db); - let class_definition = - index.expect_single_definition(definition_scope.node(db).as_class(&module)?); - let class = binding_type(db, class_definition).to_class_type(db)?; + let class = binding_type(db, self.class_definition(db)?).to_class_type(db)?; let name = self.function(db).name(db); let base = class diff --git a/crates/ty_python_semantic/src/types/function.rs b/crates/ty_python_semantic/src/types/function.rs index 0396ecbb25c98..f0e4869a8b2e7 100644 --- a/crates/ty_python_semantic/src/types/function.rs +++ b/crates/ty_python_semantic/src/types/function.rs @@ -70,32 +70,17 @@ use crate::types::diagnostic::{ report_bad_argument_to_get_protocol_members, report_runtime_check_against_non_runtime_checkable_protocol, }; -use crate::types::generics::{GenericContext, walk_generic_context}; +use crate::types::generics::GenericContext; +use crate::types::infer::infer_scope_types_with_call_stack; use crate::types::narrow::ClassInfoConstraintFunction; use crate::types::signatures::{CallableSignature, Signature}; use crate::types::visitor::any_over_type; use crate::types::{ BoundMethodType, CallableType, DynamicType, KnownClass, Type, TypeMapping, TypeRelation, - TypeTransformer, TypeVarInstance, infer_scope_types, walk_type_mapping, + TypeTransformer, TypeVarInstance, walk_generic_context, walk_type_mapping, }; use crate::{Db, FxOrderSet, ModuleName, resolve_module}; -fn function_return_type_cycle_recover<'db>( - _db: &'db dyn Db, - _value: &Type<'db>, - _count: u32, - _self: FunctionType<'db>, -) -> salsa::CycleRecoveryAction> { - salsa::CycleRecoveryAction::Iterate -} - -fn function_return_type_cycle_initial<'db>( - _db: &'db dyn Db, - _self: FunctionType<'db>, -) -> Type<'db> { - Type::Never -} - /// A collection of useful spans for annotating functions. /// /// This can be retrieved via `FunctionType::spans` or @@ -875,10 +860,15 @@ impl<'db> FunctionType<'db> { } /// Infers this function scope's types and returns the inferred return type. - #[salsa::tracked(cycle_fn=function_return_type_cycle_recover, cycle_initial=function_return_type_cycle_initial)] pub(crate) fn infer_return_type(self, db: &'db dyn Db) -> Type<'db> { let scope = self.literal(db).last_definition(db).body_scope(db); - let inference = infer_scope_types(db, scope); + if db + .call_stack() + .contains(scope.file(db), scope.file_scope_id(db)) + { + return Type::unknown(); + } + let inference = infer_scope_types_with_call_stack(db, scope, db.call_stack().hash_value()); inference.infer_return_type(db, None) } } diff --git a/crates/ty_python_semantic/src/types/infer.rs b/crates/ty_python_semantic/src/types/infer.rs index bcd662e1cac4d..ba5968c6153af 100644 --- a/crates/ty_python_semantic/src/types/infer.rs +++ b/crates/ty_python_semantic/src/types/infer.rs @@ -142,6 +142,30 @@ pub(crate) fn infer_scope_types<'db>(db: &'db dyn Db, scope: ScopeId<'db>) -> Ty TypeInferenceBuilder::new(db, InferenceRegion::Scope(scope), index, &module).finish() } +/// Pushes a function scope onto the call stack and then performs scope analysis. It is used to infer the return value of a function. +/// Return values ​​of recursive functions are inferred by explicitly detecting recursion and returning `Unknown` instead of salsa's fixpoint iteration. +#[salsa::tracked(returns(ref), cycle_fn=scope_with_call_stack_cycle_recover, cycle_initial=scope_with_call_stack_cycle_initial)] +pub(crate) fn infer_scope_types_with_call_stack<'db>( + db: &'db dyn Db, + scope: ScopeId<'db>, + _call_stack_hash: u64, +) -> TypeInference<'db> { + let file = scope.file(db); + + let module = parsed_module(db, file).load(db); + + // Using the index here is fine because the code below depends on the AST anyway. + // The isolation of the query is by the return inferred types. + let index = semantic_index(db, file); + + db.call_stack().push(file, scope.file_scope_id(db)); + let inference = + TypeInferenceBuilder::new(db, InferenceRegion::Scope(scope), index, &module).finish(); + db.call_stack().pop(); + + inference +} + fn scope_cycle_recover<'db>( _db: &'db dyn Db, _value: &TypeInference<'db>, @@ -155,6 +179,24 @@ fn scope_cycle_initial<'db>(_db: &'db dyn Db, scope: ScopeId<'db>) -> TypeInfere TypeInference::cycle_fallback(scope, Type::Never) } +fn scope_with_call_stack_cycle_recover<'db>( + _db: &'db dyn Db, + _value: &TypeInference<'db>, + _count: u32, + _scope: ScopeId<'db>, + _call_stack_hash: u64, +) -> salsa::CycleRecoveryAction> { + salsa::CycleRecoveryAction::Iterate +} + +fn scope_with_call_stack_cycle_initial<'db>( + _db: &'db dyn Db, + scope: ScopeId<'db>, + _call_stack_hash: u64, +) -> TypeInference<'db> { + TypeInference::cycle_fallback(scope, Type::Never) +} + /// Infer all types for a [`Definition`] (including sub-expressions). /// Use when resolving a place use or public type of a place. #[salsa::tracked(returns(ref), cycle_fn=definition_cycle_recover, cycle_initial=definition_cycle_initial, heap_size=get_size2::GetSize::get_heap_size)] diff --git a/crates/ty_python_semantic/src/types/narrow.rs b/crates/ty_python_semantic/src/types/narrow.rs index 20a912c72d153..912a0e3b4bb75 100644 --- a/crates/ty_python_semantic/src/types/narrow.rs +++ b/crates/ty_python_semantic/src/types/narrow.rs @@ -19,6 +19,7 @@ use itertools::Itertools; use ruff_python_ast as ast; use ruff_python_ast::{BoolOp, ExprBoolOp}; use rustc_hash::FxHashMap; +use std::cell::LazyCell; use std::collections::hash_map::Entry; use super::UnionType; @@ -691,31 +692,28 @@ impl<'db, 'ast> NarrowingConstraintsBuilder<'db, 'ast> { // and that requires cross-symbol constraints, which we don't support yet. return None; } - - let inference = infer_expression_types(self.db, expression); + // Performance optimization: deferring type inference for an expression until it is actually needed. + let inference = LazyCell::new(|| infer_expression_types(self.db, expression)); let comparator_tuples = std::iter::once(&**left) .chain(comparators) .tuple_windows::<(&ruff_python_ast::Expr, &ruff_python_ast::Expr)>(); let mut constraints = NarrowingConstraints::default(); - let mut last_rhs_ty: Option = None; - for (op, (left, right)) in std::iter::zip(&**ops, comparator_tuples) { - let lhs_ty = last_rhs_ty.unwrap_or_else(|| inference.expression_type(left)); - let rhs_ty = inference.expression_type(right); - last_rhs_ty = Some(rhs_ty); - match left { ast::Expr::Name(_) | ast::Expr::Attribute(_) | ast::Expr::Subscript(_) | ast::Expr::Named(_) => { - if let Some(left) = place_expr(left) { + if let Some(left_place) = place_expr(left) { let op = if is_positive { *op } else { op.negate() }; + let lhs_ty = inference.expression_type(left); + let rhs_ty = inference.expression_type(right); + if let Some(ty) = self.evaluate_expr_compare_op(lhs_ty, rhs_ty, op) { - let place = self.expect_place(&left); + let place = self.expect_place(&left_place); constraints.insert(place, ty); } } @@ -732,14 +730,6 @@ impl<'db, 'ast> NarrowingConstraintsBuilder<'db, 'ast> { node_index: _, }, }) if keywords.is_empty() => { - let rhs_class = match rhs_ty { - Type::ClassLiteral(class) => class, - Type::GenericAlias(alias) => alias.origin(self.db), - _ => { - continue; - } - }; - let target = match &**args { [first] => match place_expr(first) { Some(target) => target, @@ -758,6 +748,15 @@ impl<'db, 'ast> NarrowingConstraintsBuilder<'db, 'ast> { continue; } + let rhs_ty = inference.expression_type(right); + let rhs_class = match rhs_ty { + Type::ClassLiteral(class) => class, + Type::GenericAlias(alias) => alias.origin(self.db), + _ => { + continue; + } + }; + let callable_type = inference.expression_type(&**callable); if callable_type diff --git a/crates/ty_python_semantic/tests/corpus.rs b/crates/ty_python_semantic/tests/corpus.rs index 799e76dd11ab4..4a127d73758b9 100644 --- a/crates/ty_python_semantic/tests/corpus.rs +++ b/crates/ty_python_semantic/tests/corpus.rs @@ -8,8 +8,8 @@ use ruff_python_ast::PythonVersion; use ty_python_semantic::lint::{LintRegistry, RuleSelection}; use ty_python_semantic::pull_types::pull_types; use ty_python_semantic::{ - Program, ProgramSettings, PythonPlatform, PythonVersionSource, PythonVersionWithSource, - SearchPathSettings, default_lint_registry, + CallStack, Program, ProgramSettings, PythonPlatform, PythonVersionSource, + PythonVersionWithSource, SearchPathSettings, default_lint_registry, }; use test_case::test_case; @@ -182,6 +182,7 @@ pub struct CorpusDb { rule_selection: RuleSelection, system: TestSystem, vendored: VendoredFileSystem, + call_stack: CallStack, } impl CorpusDb { @@ -193,6 +194,7 @@ impl CorpusDb { vendored: ty_vendored::file_system().clone(), rule_selection: RuleSelection::from_registry(default_lint_registry()), files: Files::default(), + call_stack: CallStack::default(), }; Program::from_settings( @@ -255,6 +257,10 @@ impl ty_python_semantic::Db for CorpusDb { fn lint_registry(&self) -> &LintRegistry { default_lint_registry() } + + fn call_stack(&self) -> &CallStack { + &self.call_stack + } } #[salsa::db] diff --git a/crates/ty_test/src/db.rs b/crates/ty_test/src/db.rs index 76fff50e4eb65..e9bbd33988f1a 100644 --- a/crates/ty_test/src/db.rs +++ b/crates/ty_test/src/db.rs @@ -12,7 +12,7 @@ use std::borrow::Cow; use std::sync::Arc; use tempfile::TempDir; use ty_python_semantic::lint::{LintRegistry, RuleSelection}; -use ty_python_semantic::{Db as SemanticDb, Program, default_lint_registry}; +use ty_python_semantic::{CallStack, Db as SemanticDb, Program, default_lint_registry}; #[salsa::db] #[derive(Clone)] @@ -22,6 +22,7 @@ pub(crate) struct Db { system: MdtestSystem, vendored: VendoredFileSystem, rule_selection: Arc, + call_stack: CallStack, } impl Db { @@ -38,6 +39,7 @@ impl Db { vendored: ty_vendored::file_system().clone(), files: Files::default(), rule_selection: Arc::new(rule_selection), + call_stack: CallStack::default(), } } @@ -88,6 +90,10 @@ impl SemanticDb for Db { fn lint_registry(&self) -> &LintRegistry { default_lint_registry() } + + fn call_stack(&self) -> &CallStack { + &self.call_stack + } } #[salsa::db] diff --git a/fuzz/fuzz_targets/ty_check_invalid_syntax.rs b/fuzz/fuzz_targets/ty_check_invalid_syntax.rs index 4dd62cd0c53bd..3b5870cac7544 100644 --- a/fuzz/fuzz_targets/ty_check_invalid_syntax.rs +++ b/fuzz/fuzz_targets/ty_check_invalid_syntax.rs @@ -18,7 +18,7 @@ use ruff_python_parser::{Mode, ParseOptions, parse_unchecked}; use ty_python_semantic::lint::LintRegistry; use ty_python_semantic::types::check_types; use ty_python_semantic::{ - Db as SemanticDb, Program, ProgramSettings, PythonPlatform, PythonVersionWithSource, + CallStack, Db as SemanticDb, Program, ProgramSettings, PythonPlatform, PythonVersionWithSource, SearchPathSettings, default_lint_registry, lint::RuleSelection, }; @@ -33,6 +33,7 @@ struct TestDb { system: TestSystem, vendored: VendoredFileSystem, rule_selection: Arc, + call_stack: CallStack, } impl TestDb { @@ -47,6 +48,7 @@ impl TestDb { vendored: ty_vendored::file_system().clone(), files: Files::default(), rule_selection: RuleSelection::from_registry(default_lint_registry()).into(), + call_stack: CallStack::default(), } } } @@ -93,6 +95,10 @@ impl SemanticDb for TestDb { fn lint_registry(&self) -> &LintRegistry { default_lint_registry() } + + fn call_stack(&self) -> &CallStack { + &self.call_stack + } } #[salsa::db] From 65bdbf9d507115263c72fbeaffff1ab3b15d9f12 Mon Sep 17 00:00:00 2001 From: Shunsuke Shibayama Date: Mon, 7 Jul 2025 21:16:48 +0900 Subject: [PATCH 014/105] move `scipy`, `sympy` to `bad.txt` --- crates/ty_python_semantic/resources/primer/bad.txt | 2 ++ crates/ty_python_semantic/resources/primer/good.txt | 2 -- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/crates/ty_python_semantic/resources/primer/bad.txt b/crates/ty_python_semantic/resources/primer/bad.txt index 213e5c4dca52d..175e0c6a8ad7b 100644 --- a/crates/ty_python_semantic/resources/primer/bad.txt +++ b/crates/ty_python_semantic/resources/primer/bad.txt @@ -15,8 +15,10 @@ pip # vendors packaging, see above pylint # cycle panics (self-recursive type alias) pyodide # too many cycle iterations scikit-build-core # too many cycle iterations +scipy # probabilistic stack overflow (multi threaded) setuptools # vendors packaging, see above spack # slow, success, but mypy-primer hangs processing the output spark # too many iterations steam.py # hangs (single threaded) +sympy # stack overflow (multi threaded) xarray # too many iterations diff --git a/crates/ty_python_semantic/resources/primer/good.txt b/crates/ty_python_semantic/resources/primer/good.txt index 9a036f3c128df..b3575d2b6a6f7 100644 --- a/crates/ty_python_semantic/resources/primer/good.txt +++ b/crates/ty_python_semantic/resources/primer/good.txt @@ -98,7 +98,6 @@ rotki schema_salad schemathesis scikit-learn -scipy scrapy sockeye speedrun.com_global_scoreboard_webapp @@ -109,7 +108,6 @@ stone strawberry streamlit svcs -sympy tornado trio twine From f17a650361afffe7fcd9f54d64ecccc7ac1c8307 Mon Sep 17 00:00:00 2001 From: Shunsuke Shibayama Date: Mon, 7 Jul 2025 23:07:51 +0900 Subject: [PATCH 015/105] make `TupleType::to_class_type` a tracked function --- crates/ty_python_semantic/src/types/tuple.rs | 2 ++ 1 file changed, 2 insertions(+) diff --git a/crates/ty_python_semantic/src/types/tuple.rs b/crates/ty_python_semantic/src/types/tuple.rs index 58dcb0debbffb..7b0897b320935 100644 --- a/crates/ty_python_semantic/src/types/tuple.rs +++ b/crates/ty_python_semantic/src/types/tuple.rs @@ -110,6 +110,7 @@ impl<'db> Type<'db> { } } +#[salsa::tracked] impl<'db> TupleType<'db> { pub(crate) fn new(db: &'db dyn Db, tuple_key: T) -> Option where @@ -168,6 +169,7 @@ impl<'db> TupleType<'db> { Type::tuple(TupleType::new(db, TupleSpec::homogeneous(element))) } + #[salsa::tracked] pub(crate) fn to_class_type(self, db: &'db dyn Db) -> Option> { KnownClass::Tuple .try_to_class_literal(db) From 4f0ce11bf5c1f2d51cac5d5a622da16796f197fd Mon Sep 17 00:00:00 2001 From: Shunsuke Shibayama Date: Mon, 7 Jul 2025 23:43:48 +0900 Subject: [PATCH 016/105] increase `max_diagnostics` in `sympy` --- crates/ruff_benchmark/benches/ty_walltime.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/crates/ruff_benchmark/benches/ty_walltime.rs b/crates/ruff_benchmark/benches/ty_walltime.rs index e1a5db9ebe96f..174cb0065ce4d 100644 --- a/crates/ruff_benchmark/benches/ty_walltime.rs +++ b/crates/ruff_benchmark/benches/ty_walltime.rs @@ -199,7 +199,7 @@ static SYMPY: std::sync::LazyLock> = std::sync::LazyLock::new max_dep_date: "2025-06-17", python_version: PythonVersion::PY312, }, - 13000, + 75000, ) }); From 9bc45f81d5bcab395aa3ea5b6beea7e8a56cc0ad Mon Sep 17 00:00:00 2001 From: Shunsuke Shibayama Date: Sat, 2 Aug 2025 03:01:17 +0900 Subject: [PATCH 017/105] move `scipy` to `good.txt` and `manticore` to `bad.txt` --- crates/ty_python_semantic/resources/primer/bad.txt | 2 +- crates/ty_python_semantic/resources/primer/good.txt | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/crates/ty_python_semantic/resources/primer/bad.txt b/crates/ty_python_semantic/resources/primer/bad.txt index 0b6106a8023b6..cc1b5d44ca4ed 100644 --- a/crates/ty_python_semantic/resources/primer/bad.txt +++ b/crates/ty_python_semantic/resources/primer/bad.txt @@ -7,6 +7,7 @@ cpython # too many cycle iterations hydpy # too many iterations ibis # too many iterations jax # too many iterations +manticore # too many iterations mypy # too many iterations (self-recursive type alias) packaging # too many iterations pandas # slow (9s) @@ -17,7 +18,6 @@ pip # vendors packaging, see above pylint # cycle panics (self-recursive type alias) pyodide # too many cycle iterations scikit-build-core # too many cycle iterations -scipy # probabilistic stack overflow (multi threaded) setuptools # vendors packaging, see above spack # slow, success, but mypy-primer hangs processing the output spark # too many iterations diff --git a/crates/ty_python_semantic/resources/primer/good.txt b/crates/ty_python_semantic/resources/primer/good.txt index 368345dc712f8..2e035f2a4e21c 100644 --- a/crates/ty_python_semantic/resources/primer/good.txt +++ b/crates/ty_python_semantic/resources/primer/good.txt @@ -53,7 +53,6 @@ jinja koda-validate kopf kornia -manticore materialize meson mitmproxy @@ -98,6 +97,7 @@ rotki schema_salad schemathesis scikit-learn +scipy scrapy sockeye speedrun.com_global_scoreboard_webapp From 1e1993f1bc7a0c9a4894c639b33d2b482d1bbad5 Mon Sep 17 00:00:00 2001 From: Shunsuke Shibayama Date: Sat, 2 Aug 2025 03:40:58 +0900 Subject: [PATCH 018/105] move `ScopeInference::returnees` to `ScopeInferenceExtra` --- .../resources/primer/bad.txt | 1 - .../resources/primer/good.txt | 1 + .../src/semantic_index/scope.rs | 8 ++++ crates/ty_python_semantic/src/types/infer.rs | 44 +++++++++++-------- 4 files changed, 34 insertions(+), 20 deletions(-) diff --git a/crates/ty_python_semantic/resources/primer/bad.txt b/crates/ty_python_semantic/resources/primer/bad.txt index cc1b5d44ca4ed..9d839ff046a45 100644 --- a/crates/ty_python_semantic/resources/primer/bad.txt +++ b/crates/ty_python_semantic/resources/primer/bad.txt @@ -7,7 +7,6 @@ cpython # too many cycle iterations hydpy # too many iterations ibis # too many iterations jax # too many iterations -manticore # too many iterations mypy # too many iterations (self-recursive type alias) packaging # too many iterations pandas # slow (9s) diff --git a/crates/ty_python_semantic/resources/primer/good.txt b/crates/ty_python_semantic/resources/primer/good.txt index 2e035f2a4e21c..650040ce10a37 100644 --- a/crates/ty_python_semantic/resources/primer/good.txt +++ b/crates/ty_python_semantic/resources/primer/good.txt @@ -53,6 +53,7 @@ jinja koda-validate kopf kornia +manticore materialize meson mitmproxy diff --git a/crates/ty_python_semantic/src/semantic_index/scope.rs b/crates/ty_python_semantic/src/semantic_index/scope.rs index 89581d9194d92..0cb65f5001f80 100644 --- a/crates/ty_python_semantic/src/semantic_index/scope.rs +++ b/crates/ty_python_semantic/src/semantic_index/scope.rs @@ -29,6 +29,10 @@ impl<'db> ScopeId<'db> { self.node(db).scope_kind().is_function_like() } + pub(crate) fn is_function_or_lambda(self, db: &'db dyn Db) -> bool { + self.node(db).scope_kind().is_function_or_lambda() + } + pub(crate) fn is_type_parameter(self, db: &'db dyn Db) -> bool { self.node(db).scope_kind().is_type_parameter() } @@ -247,6 +251,10 @@ impl ScopeKind { ) } + pub(crate) const fn is_function_or_lambda(self) -> bool { + matches!(self, ScopeKind::Function | ScopeKind::Lambda) + } + pub(crate) const fn is_class(self) -> bool { matches!(self, ScopeKind::Class) } diff --git a/crates/ty_python_semantic/src/types/infer.rs b/crates/ty_python_semantic/src/types/infer.rs index 1672942fbf608..3b2d0125327a0 100644 --- a/crates/ty_python_semantic/src/types/infer.rs +++ b/crates/ty_python_semantic/src/types/infer.rs @@ -469,11 +469,6 @@ pub(crate) struct ScopeInference<'db> { /// The extra data that is only present for few inference regions. extra: Option>, - - /// The returnees of this region (if this is a function body). - /// - /// These are stored in `Vec` to delay the creation of the union type as long as possible. - returnees: Vec>, } #[derive(Debug, Eq, PartialEq, get_size2::GetSize, salsa::Update, Default)] @@ -485,6 +480,11 @@ struct ScopeInferenceExtra { /// The diagnostics for this region. diagnostics: TypeCheckDiagnostics, + + /// The returnees of this region (if this is a function body). + /// + /// These are stored in `Vec` to delay the creation of the union type as long as possible. + returnees: Vec>, } impl<'db> ScopeInference<'db> { @@ -495,7 +495,6 @@ impl<'db> ScopeInference<'db> { cycle_fallback: true, ..ScopeInferenceExtra::default() })), - returnees: Vec::new(), expressions: FxHashMap::default(), } } @@ -539,7 +538,12 @@ impl<'db> ScopeInference<'db> { method_ty: Option>, ) -> Type<'db> { let mut union = UnionBuilder::new(db); - for returnee in &self.returnees { + let Some(extra) = &self.extra else { + unreachable!( + "infer_return_type should only be called on a function body scope inference" + ); + }; + for returnee in &extra.returnees { let ty = returnee.map_or(Type::none(db), |expression| { self.expression_type(expression) }); @@ -9228,6 +9232,7 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> { pub(super) fn finish_scope(mut self) -> ScopeInference<'db> { self.infer_region(); + let db = self.db(); let Self { context, @@ -9251,25 +9256,26 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> { let diagnostics = context.finish(); - let extra = (!diagnostics.is_empty() || cycle_fallback).then(|| { - Box::new(ScopeInferenceExtra { - cycle_fallback, - diagnostics, - }) - }); + let extra = (!diagnostics.is_empty() || cycle_fallback || scope.is_function_or_lambda(db)) + .then(|| { + let returnees = returnees + .into_iter() + .map(|returnee| returnee.expression) + .collect(); - expressions.shrink_to_fit(); + Box::new(ScopeInferenceExtra { + cycle_fallback, + diagnostics, + returnees, + }) + }); - let returnees = returnees - .into_iter() - .map(|returnee| returnee.expression) - .collect(); + expressions.shrink_to_fit(); ScopeInference { scope, expressions, extra, - returnees, } } } From e275162dde672e125b538c4a93b6b2ca38ccffcf Mon Sep 17 00:00:00 2001 From: Shunsuke Shibayama Date: Sat, 2 Aug 2025 23:04:02 +0900 Subject: [PATCH 019/105] Update type.md --- .../resources/mdtest/narrow/type.md | 13 ++++++++++--- 1 file changed, 10 insertions(+), 3 deletions(-) diff --git a/crates/ty_python_semantic/resources/mdtest/narrow/type.md b/crates/ty_python_semantic/resources/mdtest/narrow/type.md index fc4d784ccdbb8..eef9bfd379e09 100644 --- a/crates/ty_python_semantic/resources/mdtest/narrow/type.md +++ b/crates/ty_python_semantic/resources/mdtest/narrow/type.md @@ -56,13 +56,20 @@ def _(x: A | B): ## No special narrowing for custom `type` callable +`stub.pyi`: + +```pyi +from ty_extensions import TypeOf + +def type(x: object) -> TypeOf[int]: ... +``` + ```py +from stub import type + class A: ... class B: ... -def type(x): - return int - def _(x: A | B): if type(x) is A: reveal_type(x) # revealed: Never From 46d21df55eedafcee4feb4b889efd36d6a5f20d1 Mon Sep 17 00:00:00 2001 From: Shunsuke Shibayama Date: Tue, 5 Aug 2025 18:04:28 +0900 Subject: [PATCH 020/105] Avoid unspecialized type variables appearing in the return type --- .../resources/mdtest/function/return_type.md | 24 +++++++++++++++++-- crates/ty_python_semantic/src/types.rs | 18 +++++++++----- crates/ty_python_semantic/src/types/infer.rs | 15 ++++++++++-- 3 files changed, 47 insertions(+), 10 deletions(-) diff --git a/crates/ty_python_semantic/resources/mdtest/function/return_type.md b/crates/ty_python_semantic/resources/mdtest/function/return_type.md index cc6b117a29a57..b7bd07c85024f 100644 --- a/crates/ty_python_semantic/resources/mdtest/function/return_type.md +++ b/crates/ty_python_semantic/resources/mdtest/function/return_type.md @@ -422,6 +422,12 @@ class C: def g[T](self, x: T) -> T: return x + def h[T: int](self, x: T) -> T: + return x + + def i[T: int](self, x: T) -> list[T]: + return [x] + class D(C): def f(self): return 2 @@ -429,6 +435,15 @@ class D(C): # If the override is invalid, the type of the method should be that of the base class method. def g(self, x: int): return 2 + # A strict application of the Liskov Substitution Principle would consider + # this an invalid override because it violates the guarantee that the method returns + # the same type as its input type (any type smaller than int), + # but neither mypy nor pyright will throw an error for this. + def h(self, x: int): + return 2 + + def i(self, x: int): + return [2] class E(D): def f(self): @@ -438,8 +453,13 @@ reveal_type(C().f()) # revealed: int reveal_type(D().f()) # revealed: int reveal_type(E().f()) # revealed: int reveal_type(C().g(1)) # revealed: Literal[1] -# TODO: should be `Literal[1]` -reveal_type(D().g(1)) # revealed: Literal[2] | T@g +reveal_type(D().g(1)) # revealed: Literal[2] | Unknown +reveal_type(C().h(1)) # revealed: Literal[1] +reveal_type(D().h(1)) # revealed: Literal[2] | Unknown +reveal_type(C().h(True)) # revealed: Literal[True] +reveal_type(D().h(True)) # revealed: Literal[2] | Unknown +reveal_type(C().i(1)) # revealed: list[Literal[1]] +reveal_type(D().i(1)) # revealed: list[Unknown] class F: def f(self) -> Literal[1, 2]: diff --git a/crates/ty_python_semantic/src/types.rs b/crates/ty_python_semantic/src/types.rs index 60bc28796abd2..cc1988a4c9ba4 100644 --- a/crates/ty_python_semantic/src/types.rs +++ b/crates/ty_python_semantic/src/types.rs @@ -7837,7 +7837,10 @@ impl<'db> BoundMethodType<'db> { .any(|deco| deco == KnownFunction::Final) } - pub(crate) fn base_return_type(self, db: &'db dyn Db) -> Option> { + pub(crate) fn base_signature_and_return_type( + self, + db: &'db dyn Db, + ) -> Option<(Signature<'db>, Option>)> { let class = binding_type(db, self.class_definition(db)?).to_class_type(db)?; let name = self.function(db).name(db); @@ -7848,11 +7851,14 @@ impl<'db> BoundMethodType<'db> { let base_member = base.class_member(db, name, MemberLookupPolicy::default()); if let Place::Type(Type::FunctionLiteral(base_func), _) = base_member.place { if let [signature] = base_func.signature(db).overloads.as_slice() { - signature.return_ty.or_else(|| { - let base_method_ty = - base_func.into_bound_method_type(db, Type::instance(db, class)); - base_method_ty.infer_return_type(db) - }) + Some(( + signature.clone(), + signature.return_ty.or_else(|| { + let base_method_ty = + base_func.into_bound_method_type(db, Type::instance(db, class)); + base_method_ty.infer_return_type(db) + }), + )) } else { // TODO: Handle overloaded base methods. None diff --git a/crates/ty_python_semantic/src/types/infer.rs b/crates/ty_python_semantic/src/types/infer.rs index 3b2d0125327a0..260465f0d0257 100644 --- a/crates/ty_python_semantic/src/types/infer.rs +++ b/crates/ty_python_semantic/src/types/infer.rs @@ -558,8 +558,19 @@ impl<'db> ScopeInference<'db> { // If any method in a base class does not have an annotated return type, `base_return_type` will include `Unknown`. // On the other hand, if the return types of all methods in the base classes are annotated, there is no need to include `Unknown`. if !method_ty.is_final(db) { - let return_ty = method_ty.base_return_type(db).unwrap_or(Type::unknown()); - union = union.add(return_ty); + let (signature, return_ty) = method_ty + .base_signature_and_return_type(db) + .unwrap_or((Signature::unknown(), None)); + let return_ty = return_ty.unwrap_or(Type::unknown()); + if let Some(generic_context) = signature.generic_context.as_ref() { + // If the return type of the base method contains a type variable, replace it with `Unknown` to avoid dangling type variables. + union = union.add( + return_ty + .apply_specialization(db, generic_context.unknown_specialization(db)), + ); + } else { + union = union.add(return_ty); + } } } union.build() From 3701c730229fd7ddbafb8ba8839be9a62260edf6 Mon Sep 17 00:00:00 2001 From: Shunsuke Shibayama Date: Wed, 6 Aug 2025 13:40:29 +0900 Subject: [PATCH 021/105] WIP: remove `CallStack` --- crates/ruff_graph/src/db.rs | 10 +--- crates/ty_project/src/db.rs | 17 +------ crates/ty_python_semantic/src/db.rs | 49 +------------------ crates/ty_python_semantic/src/lib.rs | 2 +- crates/ty_python_semantic/src/types.rs | 10 +--- .../ty_python_semantic/src/types/function.rs | 12 ++--- crates/ty_python_semantic/src/types/infer.rs | 42 ---------------- crates/ty_python_semantic/tests/corpus.rs | 10 +--- crates/ty_test/src/db.rs | 8 +-- 9 files changed, 14 insertions(+), 146 deletions(-) diff --git a/crates/ruff_graph/src/db.rs b/crates/ruff_graph/src/db.rs index 6615e51eabe3a..cbca766de6199 100644 --- a/crates/ruff_graph/src/db.rs +++ b/crates/ruff_graph/src/db.rs @@ -9,9 +9,8 @@ use ruff_db::vendored::{VendoredFileSystem, VendoredFileSystemBuilder}; use ruff_python_ast::PythonVersion; use ty_python_semantic::lint::{LintRegistry, RuleSelection}; use ty_python_semantic::{ - CallStack, Db, Program, ProgramSettings, PythonEnvironment, PythonPlatform, - PythonVersionSource, PythonVersionWithSource, SearchPathSettings, SysPrefixPathOrigin, - default_lint_registry, + Db, Program, ProgramSettings, PythonEnvironment, PythonPlatform, PythonVersionSource, + PythonVersionWithSource, SearchPathSettings, SysPrefixPathOrigin, default_lint_registry, }; static EMPTY_VENDORED: std::sync::LazyLock = std::sync::LazyLock::new(|| { @@ -27,7 +26,6 @@ pub struct ModuleDb { files: Files, system: OsSystem, rule_selection: Arc, - call_stack: CallStack, } impl ModuleDb { @@ -100,10 +98,6 @@ impl Db for ModuleDb { fn lint_registry(&self) -> &LintRegistry { default_lint_registry() } - - fn call_stack(&self) -> &CallStack { - &self.call_stack - } } #[salsa::db] diff --git a/crates/ty_project/src/db.rs b/crates/ty_project/src/db.rs index 1deab46abea8d..d4d9a12e406e3 100644 --- a/crates/ty_project/src/db.rs +++ b/crates/ty_project/src/db.rs @@ -15,7 +15,7 @@ use ruff_db::vendored::VendoredFileSystem; use salsa::plumbing::ZalsaDatabase; use salsa::{Event, Setter}; use ty_python_semantic::lint::{LintRegistry, RuleSelection}; -use ty_python_semantic::{CallStack, Db as SemanticDb, Program}; +use ty_python_semantic::{Db as SemanticDb, Program}; mod changes; @@ -39,8 +39,6 @@ pub struct ProjectDatabase { // However, for this to work it's important that the `storage` is dropped AFTER any `Arc` that // we try to mutably borrow using `Arc::get_mut` (like `system`). storage: salsa::Storage, - - call_stack: CallStack, } impl ProjectDatabase { @@ -65,7 +63,6 @@ impl ProjectDatabase { }), files: Files::default(), system: Arc::new(system), - call_stack: CallStack::default(), }; // TODO: Use the `program_settings` to compute the key for the database's persistent @@ -418,10 +415,6 @@ impl SemanticDb for ProjectDatabase { fn lint_registry(&self) -> &LintRegistry { &DEFAULT_LINT_REGISTRY } - - fn call_stack(&self) -> &CallStack { - &self.call_stack - } } #[salsa::db] @@ -480,7 +473,7 @@ pub(crate) mod tests { use ty_python_semantic::lint::{LintRegistry, RuleSelection}; use crate::DEFAULT_LINT_REGISTRY; - use crate::db::{CallStack, Db}; + use crate::db::Db; use crate::{Project, ProjectMetadata}; type Events = Arc>>; @@ -494,7 +487,6 @@ pub(crate) mod tests { system: TestSystem, vendored: VendoredFileSystem, project: Option, - call_stack: CallStack, } impl TestDb { @@ -513,7 +505,6 @@ pub(crate) mod tests { files: Files::default(), events, project: None, - call_stack: CallStack::default(), }; let project = Project::from_metadata(&db, project).unwrap(); @@ -573,10 +564,6 @@ pub(crate) mod tests { fn lint_registry(&self) -> &LintRegistry { &DEFAULT_LINT_REGISTRY } - - fn call_stack(&self) -> &CallStack { - &self.call_stack - } } #[salsa::db] diff --git a/crates/ty_python_semantic/src/db.rs b/crates/ty_python_semantic/src/db.rs index c511557a27805..645929235a7b3 100644 --- a/crates/ty_python_semantic/src/db.rs +++ b/crates/ty_python_semantic/src/db.rs @@ -1,45 +1,6 @@ -use std::hash::{Hash, Hasher}; -use std::sync::{Arc, Mutex}; - use crate::lint::{LintRegistry, RuleSelection}; -use crate::semantic_index::scope::FileScopeId; use ruff_db::Db as SourceDb; use ruff_db::files::File; -use rustc_hash::FxHasher; - -/// A stack of the currently inferred function scopes. -/// Used to monitor type inference for recursive functions to ensure they do not diverge. -/// This call stack is currently only used to infer functions with unspecified return types ​​and does not faithfully represent the actual call stack. -#[derive(Debug, Default, Clone)] -pub struct CallStack(Arc>>); - -impl CallStack { - pub fn new() -> Self { - CallStack(Arc::new(Mutex::new(Vec::new()))) - } - - pub fn push(&self, file: File, scope: FileScopeId) { - self.0.lock().unwrap().push((file, scope)); - } - - pub fn pop(&self) -> Option<(File, FileScopeId)> { - self.0.lock().unwrap().pop() - } - - pub fn contains(&self, file: File, scope: FileScopeId) -> bool { - self.0 - .lock() - .unwrap() - .iter() - .any(|(f, s)| f == &file && s == &scope) - } - - pub fn hash_value(&self) -> u64 { - let mut hasher = FxHasher::default(); - self.0.lock().unwrap().hash(&mut hasher); - hasher.finish() - } -} /// Database giving access to semantic information about a Python program. #[salsa::db] @@ -51,8 +12,6 @@ pub trait Db: SourceDb { fn rule_selection(&self, file: File) -> &RuleSelection; fn lint_registry(&self) -> &LintRegistry; - - fn call_stack(&self) -> &CallStack; } #[cfg(test)] @@ -65,7 +24,7 @@ pub(crate) mod tests { default_lint_registry, }; - use super::{CallStack, Db}; + use super::Db; use crate::lint::{LintRegistry, RuleSelection}; use anyhow::Context; use ruff_db::Db as SourceDb; @@ -87,7 +46,6 @@ pub(crate) mod tests { vendored: VendoredFileSystem, events: Events, rule_selection: Arc, - call_stack: CallStack, } impl TestDb { @@ -107,7 +65,6 @@ pub(crate) mod tests { events, files: Files::default(), rule_selection: Arc::new(RuleSelection::from_registry(default_lint_registry())), - call_stack: CallStack::default(), } } @@ -169,10 +126,6 @@ pub(crate) mod tests { fn lint_registry(&self) -> &LintRegistry { default_lint_registry() } - - fn call_stack(&self) -> &CallStack { - &self.call_stack - } } #[salsa::db] diff --git a/crates/ty_python_semantic/src/lib.rs b/crates/ty_python_semantic/src/lib.rs index ac8391e76ccc2..c185942247bd0 100644 --- a/crates/ty_python_semantic/src/lib.rs +++ b/crates/ty_python_semantic/src/lib.rs @@ -4,7 +4,7 @@ use rustc_hash::FxHasher; use crate::lint::{LintRegistry, LintRegistryBuilder}; use crate::suppression::{INVALID_IGNORE_COMMENT, UNKNOWN_RULE, UNUSED_IGNORE_COMMENT}; -pub use db::{CallStack, Db}; +pub use db::Db; pub use module_name::ModuleName; pub use module_resolver::{ Module, SearchPathValidationError, SearchPaths, resolve_module, resolve_real_module, diff --git a/crates/ty_python_semantic/src/types.rs b/crates/ty_python_semantic/src/types.rs index cc1988a4c9ba4..0212b4b00b806 100644 --- a/crates/ty_python_semantic/src/types.rs +++ b/crates/ty_python_semantic/src/types.rs @@ -54,7 +54,7 @@ pub use crate::types::ide_support::{ definitions_for_attribute, definitions_for_imported_symbol, definitions_for_keyword_argument, definitions_for_name, }; -use crate::types::infer::{infer_scope_types_with_call_stack, infer_unpack_types}; +use crate::types::infer::infer_unpack_types; use crate::types::mro::{Mro, MroError, MroIterator}; pub(crate) use crate::types::narrow::infer_narrowing_constraint; use crate::types::signatures::{Parameter, ParameterForm, Parameters, walk_signature}; @@ -7801,13 +7801,7 @@ impl<'db> BoundMethodType<'db> { .literal(db) .last_definition(db) .body_scope(db); - if db - .call_stack() - .contains(scope.file(db), scope.file_scope_id(db)) - { - return Type::unknown(); - } - let inference = infer_scope_types_with_call_stack(db, scope, db.call_stack().hash_value()); + let inference = infer_scope_types(db, scope); inference.infer_return_type(db, Some(self)) } diff --git a/crates/ty_python_semantic/src/types/function.rs b/crates/ty_python_semantic/src/types/function.rs index b71020ddb7acd..01c82204d1d1c 100644 --- a/crates/ty_python_semantic/src/types/function.rs +++ b/crates/ty_python_semantic/src/types/function.rs @@ -72,14 +72,14 @@ use crate::types::diagnostic::{ report_runtime_check_against_non_runtime_checkable_protocol, }; use crate::types::generics::GenericContext; -use crate::types::infer::infer_scope_types_with_call_stack; use crate::types::narrow::ClassInfoConstraintFunction; use crate::types::signatures::{CallableSignature, Signature}; use crate::types::visitor::any_over_type; use crate::types::{ BoundMethodType, CallableType, ClassBase, ClassLiteral, ClassType, DeprecatedInstance, DynamicType, KnownClass, Truthiness, Type, TypeMapping, TypeRelation, TypeTransformer, - TypeVarInstance, UnionBuilder, all_members, walk_generic_context, walk_type_mapping, + TypeVarInstance, UnionBuilder, all_members, infer_scope_types, walk_generic_context, + walk_type_mapping, }; use crate::{Db, FxOrderSet, ModuleName, resolve_module}; @@ -891,13 +891,7 @@ impl<'db> FunctionType<'db> { /// Infers this function scope's types and returns the inferred return type. pub(crate) fn infer_return_type(self, db: &'db dyn Db) -> Type<'db> { let scope = self.literal(db).last_definition(db).body_scope(db); - if db - .call_stack() - .contains(scope.file(db), scope.file_scope_id(db)) - { - return Type::unknown(); - } - let inference = infer_scope_types_with_call_stack(db, scope, db.call_stack().hash_value()); + let inference = infer_scope_types(db, scope); inference.infer_return_type(db, None) } } diff --git a/crates/ty_python_semantic/src/types/infer.rs b/crates/ty_python_semantic/src/types/infer.rs index 260465f0d0257..caacf1a9d88d3 100644 --- a/crates/ty_python_semantic/src/types/infer.rs +++ b/crates/ty_python_semantic/src/types/infer.rs @@ -146,30 +146,6 @@ pub(crate) fn infer_scope_types<'db>(db: &'db dyn Db, scope: ScopeId<'db>) -> Sc TypeInferenceBuilder::new(db, InferenceRegion::Scope(scope), index, &module).finish_scope() } -/// Pushes a function scope onto the call stack and then performs scope analysis. It is used to infer the return value of a function. -/// Return values ​​of recursive functions are inferred by explicitly detecting recursion and returning `Unknown` instead of salsa's fixpoint iteration. -#[salsa::tracked(returns(ref), cycle_fn=scope_with_call_stack_cycle_recover, cycle_initial=scope_with_call_stack_cycle_initial)] -pub(crate) fn infer_scope_types_with_call_stack<'db>( - db: &'db dyn Db, - scope: ScopeId<'db>, - _call_stack_hash: u64, -) -> ScopeInference<'db> { - let file = scope.file(db); - - let module = parsed_module(db, file).load(db); - - // Using the index here is fine because the code below depends on the AST anyway. - // The isolation of the query is by the return inferred types. - let index = semantic_index(db, file); - - db.call_stack().push(file, scope.file_scope_id(db)); - let inference = - TypeInferenceBuilder::new(db, InferenceRegion::Scope(scope), index, &module).finish_scope(); - db.call_stack().pop(); - - inference -} - fn scope_cycle_recover<'db>( _db: &'db dyn Db, _value: &ScopeInference<'db>, @@ -183,24 +159,6 @@ fn scope_cycle_initial<'db>(_db: &'db dyn Db, scope: ScopeId<'db>) -> ScopeInfer ScopeInference::cycle_fallback(scope) } -fn scope_with_call_stack_cycle_recover<'db>( - _db: &'db dyn Db, - _value: &ScopeInference<'db>, - _count: u32, - _scope: ScopeId<'db>, - _call_stack_hash: u64, -) -> salsa::CycleRecoveryAction> { - salsa::CycleRecoveryAction::Iterate -} - -fn scope_with_call_stack_cycle_initial<'db>( - _db: &'db dyn Db, - scope: ScopeId<'db>, - _call_stack_hash: u64, -) -> ScopeInference<'db> { - ScopeInference::cycle_fallback(scope) -} - /// Infer all types for a [`Definition`] (including sub-expressions). /// Use when resolving a place use or public type of a place. #[salsa::tracked(returns(ref), cycle_fn=definition_cycle_recover, cycle_initial=definition_cycle_initial, heap_size=get_size2::heap_size)] diff --git a/crates/ty_python_semantic/tests/corpus.rs b/crates/ty_python_semantic/tests/corpus.rs index 090b4c9e98b42..83ad7ae1ffdae 100644 --- a/crates/ty_python_semantic/tests/corpus.rs +++ b/crates/ty_python_semantic/tests/corpus.rs @@ -8,8 +8,8 @@ use ruff_python_ast::PythonVersion; use ty_python_semantic::lint::{LintRegistry, RuleSelection}; use ty_python_semantic::pull_types::pull_types; use ty_python_semantic::{ - CallStack, Program, ProgramSettings, PythonPlatform, PythonVersionSource, - PythonVersionWithSource, SearchPathSettings, default_lint_registry, + Program, ProgramSettings, PythonPlatform, PythonVersionSource, PythonVersionWithSource, + SearchPathSettings, default_lint_registry, }; use test_case::test_case; @@ -182,7 +182,6 @@ pub struct CorpusDb { rule_selection: RuleSelection, system: TestSystem, vendored: VendoredFileSystem, - call_stack: CallStack, } impl CorpusDb { @@ -194,7 +193,6 @@ impl CorpusDb { vendored: ty_vendored::file_system().clone(), rule_selection: RuleSelection::from_registry(default_lint_registry()), files: Files::default(), - call_stack: CallStack::default(), }; Program::from_settings( @@ -257,10 +255,6 @@ impl ty_python_semantic::Db for CorpusDb { fn lint_registry(&self) -> &LintRegistry { default_lint_registry() } - - fn call_stack(&self) -> &CallStack { - &self.call_stack - } } #[salsa::db] diff --git a/crates/ty_test/src/db.rs b/crates/ty_test/src/db.rs index 205d7d3f542d9..1a47745e1d3b2 100644 --- a/crates/ty_test/src/db.rs +++ b/crates/ty_test/src/db.rs @@ -12,7 +12,7 @@ use std::borrow::Cow; use std::sync::Arc; use tempfile::TempDir; use ty_python_semantic::lint::{LintRegistry, RuleSelection}; -use ty_python_semantic::{CallStack, Db as SemanticDb, Program, default_lint_registry}; +use ty_python_semantic::{Db as SemanticDb, Program, default_lint_registry}; #[salsa::db] #[derive(Clone)] @@ -22,7 +22,6 @@ pub(crate) struct Db { system: MdtestSystem, vendored: VendoredFileSystem, rule_selection: Arc, - call_stack: CallStack, } impl Db { @@ -39,7 +38,6 @@ impl Db { vendored: ty_vendored::file_system().clone(), files: Files::default(), rule_selection: Arc::new(rule_selection), - call_stack: CallStack::default(), } } @@ -90,10 +88,6 @@ impl SemanticDb for Db { fn lint_registry(&self) -> &LintRegistry { default_lint_registry() } - - fn call_stack(&self) -> &CallStack { - &self.call_stack - } } #[salsa::db] From e4246f53a1637d8866c39077c914dc82e537f141 Mon Sep 17 00:00:00 2001 From: Shunsuke Shibayama Date: Wed, 6 Aug 2025 20:33:21 +0900 Subject: [PATCH 022/105] add `Divergent` type --- .../resources/mdtest/function/return_type.md | 21 +-- crates/ty_python_semantic/src/types.rs | 126 +++++++++++++++++- .../ty_python_semantic/src/types/builder.rs | 33 ++++- .../src/types/class_base.rs | 1 + .../ty_python_semantic/src/types/function.rs | 14 ++ crates/ty_python_semantic/src/types/infer.rs | 38 +++++- .../ty_python_semantic/src/types/instance.rs | 7 + .../src/types/protocol_class.rs | 2 +- .../src/types/type_ordering.rs | 3 + 9 files changed, 226 insertions(+), 19 deletions(-) diff --git a/crates/ty_python_semantic/resources/mdtest/function/return_type.md b/crates/ty_python_semantic/resources/mdtest/function/return_type.md index b7bd07c85024f..3f3a433ae72ae 100644 --- a/crates/ty_python_semantic/resources/mdtest/function/return_type.md +++ b/crates/ty_python_semantic/resources/mdtest/function/return_type.md @@ -307,8 +307,7 @@ def fibonacci(n: int): else: return fibonacci(n - 1) + fibonacci(n - 2) -# TODO: it may be better to infer this as `int` if we can -reveal_type(fibonacci(5)) # revealed: Literal[0, 1] | Unknown +reveal_type(fibonacci(5)) # revealed: int def even(n: int): if n == 0: @@ -322,9 +321,8 @@ def odd(n: int): else: return even(n - 1) -# TODO: it may be better to infer these as `bool` if we can -reveal_type(even(1)) # revealed: bool | Unknown -reveal_type(odd(1)) # revealed: bool | Unknown +reveal_type(even(1)) # revealed: bool +reveal_type(odd(1)) # revealed: bool def repeat_a(n: int): if n <= 0: @@ -332,8 +330,7 @@ def repeat_a(n: int): else: return repeat_a(n - 1) + "a" -# TODO: it may be better to infer this as `str` if we can -reveal_type(repeat_a(3)) # revealed: Literal[""] | Unknown +reveal_type(repeat_a(3)) # revealed: str def divergent(value): if type(value) is tuple: @@ -342,14 +339,20 @@ def divergent(value): return None # tuple[tuple[tuple[...] | None] | None] | None => tuple[Unknown] | None -reveal_type(divergent((1,))) # revealed: tuple[Unknown] | None +reveal_type(divergent((1,))) # revealed: Divergent | None + +def call_divergent(x: int): + return (divergent((1, 2, 3)), x) + +# TODO: it would be better to reveal `tuple[Divergent | None, int]` +reveal_type(call_divergent(1)) # revealed: Divergent def nested_scope(): def inner(): return nested_scope() return inner() -reveal_type(nested_scope()) # revealed: Unknown +reveal_type(nested_scope()) # revealed: Never def eager_nested_scope(): class A: diff --git a/crates/ty_python_semantic/src/types.rs b/crates/ty_python_semantic/src/types.rs index 0212b4b00b806..053384b348ab7 100644 --- a/crates/ty_python_semantic/src/types.rs +++ b/crates/ty_python_semantic/src/types.rs @@ -97,6 +97,34 @@ mod definition; #[cfg(test)] mod property_tests; +fn return_type_cycle_recover<'db>( + _db: &'db dyn Db, + _value: &Type<'db>, + _count: u32, + _self: BoundMethodType<'db>, +) -> salsa::CycleRecoveryAction> { + salsa::CycleRecoveryAction::Iterate +} + +fn return_type_cycle_initial<'db>(_db: &'db dyn Db, _self: BoundMethodType<'db>) -> Type<'db> { + Type::Dynamic(DynamicType::Divergent) +} + +#[allow(clippy::trivially_copy_pass_by_ref)] +fn has_divergent_type_cycle_recover<'db>( + _db: &'db dyn Db, + _value: &bool, + _count: u32, + _self: Type<'db>, + _unit: (), +) -> salsa::CycleRecoveryAction { + salsa::CycleRecoveryAction::Iterate +} + +fn has_divergent_type_cycle_initial<'db>(_db: &'db dyn Db, _slf: Type<'db>, _unit: ()) -> bool { + false +} + pub fn check_types(db: &dyn Db, file: File) -> Vec { let _span = tracing::trace_span!("check_types", ?file).entered(); @@ -6007,6 +6035,90 @@ impl<'db> Type<'db> { _ => None, } } + + pub(super) fn has_divergent_type(self, db: &'db dyn Db) -> bool { + self._has_divergent_type(db, ()) + } + + #[allow(clippy::used_underscore_binding)] + #[salsa::tracked(cycle_fn=has_divergent_type_cycle_recover, cycle_initial=has_divergent_type_cycle_initial, heap_size=get_size2::heap_size)] + fn _has_divergent_type(self, db: &'db dyn Db, _unit: ()) -> bool { + match self { + Type::Dynamic(DynamicType::Divergent) => true, + Type::Union(union) => union.iter(db).any(|ty| ty.has_divergent_type(db)), + Type::Intersection(intersection) => { + intersection + .positive(db) + .iter() + .any(|ty| ty.has_divergent_type(db)) + || intersection + .negative(db) + .iter() + .any(|ty| ty.has_divergent_type(db)) + } + Type::Tuple(tuple) => tuple + .tuple(db) + .all_elements() + .any(|ty| ty.has_divergent_type(db)), + Type::GenericAlias(alias) => alias + .specialization(db) + .types(db) + .iter() + .any(|ty| ty.has_divergent_type(db)), + Type::NominalInstance(instance) => match instance.class { + ClassType::Generic(alias) => alias + .specialization(db) + .types(db) + .iter() + .any(|ty| ty.has_divergent_type(db)), + ClassType::NonGeneric(_) => false, + }, + Type::Callable(callable) => callable.signatures(db).iter().any(|sig| { + sig.parameters().iter().any(|param| { + param + .annotated_type() + .is_some_and(|ty| ty.has_divergent_type(db)) + }) || sig.return_ty.iter().any(|ty| ty.has_divergent_type(db)) + }), + Type::ProtocolInstance(protocol) => protocol.has_divergent_type(db), + Type::PropertyInstance(property) => { + property + .setter(db) + .is_some_and(|setter| setter.has_divergent_type(db)) + || property + .getter(db) + .is_some_and(|getter| getter.has_divergent_type(db)) + } + Type::TypeIs(type_is) => type_is.return_type(db).has_divergent_type(db), + Type::SubclassOf(subclass_of) => match subclass_of.subclass_of() { + SubclassOfInner::Dynamic(DynamicType::Divergent) => true, + SubclassOfInner::Dynamic(_) => false, + SubclassOfInner::Class(class) => class.metaclass(db).has_divergent_type(db), + }, + Type::Never + | Type::AlwaysTruthy + | Type::AlwaysFalsy + | Type::WrapperDescriptor(_) + | Type::MethodWrapper(_) + | Type::DataclassDecorator(_) + | Type::DataclassTransformer(_) + | Type::ModuleLiteral(_) + | Type::ClassLiteral(_) + | Type::IntLiteral(_) + | Type::BooleanLiteral(_) + | Type::LiteralString + | Type::StringLiteral(_) + | Type::BytesLiteral(_) + | Type::EnumLiteral(_) + | Type::BoundSuper(_) + | Type::SpecialForm(_) + | Type::KnownInstance(_) + | Type::TypeVar(_) + | Type::FunctionLiteral(_) + | Type::BoundMethod(_) + | Type::Dynamic(_) => false, + } + } } impl<'db> From<&Type<'db>> for Type<'db> { @@ -6266,12 +6378,20 @@ pub enum DynamicType { /// A special Todo-variant for classes inheriting from `TypedDict`. /// A temporary variant to avoid false positives while we wait for full support. TodoTypedDict, + /// A type that is determined to be divergent during type inference for a recursive function. + /// This type must never be eliminated by reduction + /// (e.g. `Divergent` is assignable to `@Todo`, but `@Todo | Divergent` must not be reducted to `@Todo`). + /// Otherwise, type inference cannot converge properly. + Divergent, } impl DynamicType { - #[expect(clippy::unused_self)] fn normalized(self) -> Self { - Self::Any + if matches!(self, Self::Divergent) { + self + } else { + Self::Any + } } } @@ -6304,6 +6424,7 @@ impl std::fmt::Display for DynamicType { f.write_str("@Todo") } } + DynamicType::Divergent => f.write_str("Divergent"), } } } @@ -7795,6 +7916,7 @@ impl<'db> BoundMethodType<'db> { } /// Infers this method scope's types and returns the inferred return type. + #[salsa::tracked(cycle_fn=return_type_cycle_recover, cycle_initial=return_type_cycle_initial, heap_size=get_size2::heap_size)] pub(crate) fn infer_return_type(self, db: &'db dyn Db) -> Type<'db> { let scope = self .function(db) diff --git a/crates/ty_python_semantic/src/types/builder.rs b/crates/ty_python_semantic/src/types/builder.rs index bf160570fb94c..020d4f13c9f69 100644 --- a/crates/ty_python_semantic/src/types/builder.rs +++ b/crates/ty_python_semantic/src/types/builder.rs @@ -39,7 +39,7 @@ use crate::types::enums::{enum_member_literals, enum_metadata}; use crate::types::{ - BytesLiteralType, IntersectionType, KnownClass, StringLiteralType, Type, + BytesLiteralType, DynamicType, IntersectionType, KnownClass, StringLiteralType, Type, TypeVarBoundOrConstraints, UnionType, }; use crate::{Db, FxOrderSet}; @@ -202,7 +202,7 @@ enum ReduceResult<'db> { // TODO increase this once we extend `UnionElement` throughout all union/intersection // representations, so that we can make large unions of literals fast in all operations. -const MAX_UNION_LITERALS: usize = 200; +const MAX_UNION_LITERALS: usize = 199; pub(crate) struct UnionBuilder<'db> { elements: Vec>, @@ -418,6 +418,15 @@ impl<'db> UnionBuilder<'db> { ty if ty.is_object(self.db) => { self.collapse_to_object(); } + Type::Dynamic(DynamicType::Divergent) => { + if !self + .elements + .iter() + .any(|elem| matches!(elem, UnionElement::Type(elem) if ty == *elem)) + { + self.elements.push(UnionElement::Type(ty)); + } + } _ => { let bool_pair = if let Type::BooleanLiteral(b) = ty { Some(Type::BooleanLiteral(!b)) @@ -428,6 +437,16 @@ impl<'db> UnionBuilder<'db> { let mut to_remove = SmallVec::<[usize; 2]>::new(); let ty_negated = ty.negate(self.db); + if ty.has_divergent_type(self.db) + && !self + .elements + .iter() + .any(|elem| matches!(elem, UnionElement::Type(elem) if ty == *elem)) + { + self.elements.push(UnionElement::Type(ty)); + return; + } + for (index, element) in self.elements.iter_mut().enumerate() { let element_type = match element.try_reduce(self.db, ty) { ReduceResult::KeepIf(keep) => { @@ -758,7 +777,9 @@ impl<'db> InnerIntersectionBuilder<'db> { self.add_positive(db, Type::LiteralString); self.add_negative(db, Type::string_literal(db, "")); } - + Type::Dynamic(DynamicType::Divergent) => { + self.positive.insert(new_positive); + } _ => { let known_instance = new_positive .into_nominal_instance() @@ -827,6 +848,9 @@ impl<'db> InnerIntersectionBuilder<'db> { let mut to_remove = SmallVec::<[usize; 1]>::new(); for (index, existing_positive) in self.positive.iter().enumerate() { + if new_positive.has_divergent_type(db) { + break; + } // S & T = S if S <: T if existing_positive.is_subtype_of(db, new_positive) || existing_positive.is_equivalent_to(db, new_positive) @@ -850,6 +874,9 @@ impl<'db> InnerIntersectionBuilder<'db> { let mut to_remove = SmallVec::<[usize; 1]>::new(); for (index, existing_negative) in self.negative.iter().enumerate() { + if new_positive.has_divergent_type(db) { + break; + } // S & ~T = Never if S <: T if new_positive.is_subtype_of(db, *existing_negative) { *self = Self::default(); diff --git a/crates/ty_python_semantic/src/types/class_base.rs b/crates/ty_python_semantic/src/types/class_base.rs index 8a359e3a15ec4..8ec3d000d2272 100644 --- a/crates/ty_python_semantic/src/types/class_base.rs +++ b/crates/ty_python_semantic/src/types/class_base.rs @@ -54,6 +54,7 @@ impl<'db> ClassBase<'db> { | DynamicType::TodoTypeAlias | DynamicType::TodoTypedDict, ) => "@Todo", + ClassBase::Dynamic(DynamicType::Divergent) => "Divergent", ClassBase::Protocol => "Protocol", ClassBase::Generic => "Generic", } diff --git a/crates/ty_python_semantic/src/types/function.rs b/crates/ty_python_semantic/src/types/function.rs index 01c82204d1d1c..0f757b84368fd 100644 --- a/crates/ty_python_semantic/src/types/function.rs +++ b/crates/ty_python_semantic/src/types/function.rs @@ -83,6 +83,19 @@ use crate::types::{ }; use crate::{Db, FxOrderSet, ModuleName, resolve_module}; +fn return_type_cycle_recover<'db>( + _db: &'db dyn Db, + _value: &Type<'db>, + _count: u32, + _self: FunctionType<'db>, +) -> salsa::CycleRecoveryAction> { + salsa::CycleRecoveryAction::Iterate +} + +fn return_type_cycle_initial<'db>(_db: &'db dyn Db, _self: FunctionType<'db>) -> Type<'db> { + Type::Dynamic(DynamicType::Divergent) +} + /// A collection of useful spans for annotating functions. /// /// This can be retrieved via `FunctionType::spans` or @@ -889,6 +902,7 @@ impl<'db> FunctionType<'db> { } /// Infers this function scope's types and returns the inferred return type. + #[salsa::tracked(cycle_fn=return_type_cycle_recover, cycle_initial=return_type_cycle_initial, heap_size=get_size2::heap_size)] pub(crate) fn infer_return_type(self, db: &'db dyn Db) -> Type<'db> { let scope = self.literal(db).last_definition(db).body_scope(db); let inference = infer_scope_types(db, scope); diff --git a/crates/ty_python_semantic/src/types/infer.rs b/crates/ty_python_semantic/src/types/infer.rs index caacf1a9d88d3..c8e2016d2415f 100644 --- a/crates/ty_python_semantic/src/types/infer.rs +++ b/crates/ty_python_semantic/src/types/infer.rs @@ -81,7 +81,7 @@ use crate::semantic_index::definition::{ }; use crate::semantic_index::expression::{Expression, ExpressionKind}; use crate::semantic_index::narrowing_constraints::ConstraintKey; -use crate::semantic_index::place::{PlaceExpr, PlaceExprRef}; +use crate::semantic_index::place::{PlaceExpr, PlaceExprRef, ScopedPlaceId}; use crate::semantic_index::scope::{ FileScopeId, NodeWithScopeKind, NodeWithScopeRef, ScopeId, ScopeKind, }; @@ -159,6 +159,18 @@ fn scope_cycle_initial<'db>(_db: &'db dyn Db, scope: ScopeId<'db>) -> ScopeInfer ScopeInference::cycle_fallback(scope) } +#[salsa::tracked] +fn function_place<'db>(db: &'db dyn Db, scope: ScopeId<'db>) -> Option { + if let NodeWithScopeKind::Function(func) = scope.node(db) { + let file = scope.file(db); + let index = semantic_index(db, file); + let module = parsed_module(db, file).load(db); + Some(index.expect_single_definition(func.node(&module)).place(db)) + } else { + None + } +} + /// Infer all types for a [`Definition`] (including sub-expressions). /// Use when resolving a place use or public type of a place. #[salsa::tracked(returns(ref), cycle_fn=definition_cycle_recover, cycle_initial=definition_cycle_initial, heap_size=get_size2::heap_size)] @@ -501,11 +513,23 @@ impl<'db> ScopeInference<'db> { "infer_return_type should only be called on a function body scope inference" ); }; + let div = Type::Dynamic(DynamicType::Divergent); for returnee in &extra.returnees { let ty = returnee.map_or(Type::none(db), |expression| { self.expression_type(expression) }); - union = union.add(ty); + // `Divergent` appearing in a union does not mean true divergence, so it can be removed. + if ty == div { + continue; + } else if ty.has_divergent_type(db) { + if let Type::Union(union_ty) = ty { + union = union.add(union_ty.filter(db, |ty| **ty != div)); + } else { + union = union.add(div); + } + } else { + union = union.add(ty); + } } let use_def = use_def_map(db, self.scope); if use_def.can_implicitly_return_none(db) { @@ -7317,6 +7341,8 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> { // Non-todo Anys take precedence over Todos (as if we fix this `Todo` in the future, // the result would then become Any or Unknown, respectively). + (div @ Type::Dynamic(DynamicType::Divergent), _, _) + | (_, div @ Type::Dynamic(DynamicType::Divergent), _) => Some(div), (any @ Type::Dynamic(DynamicType::Any), _, _) | (_, any @ Type::Dynamic(DynamicType::Any), _) => Some(any), @@ -8208,7 +8234,9 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> { ).expect("infer_binary_type_comparison should never return None for `CmpOp::Eq`"); match eq_result { - todo @ Type::Dynamic(DynamicType::Todo(_)) => return Ok(todo), + todo @ Type::Dynamic( + DynamicType::Todo(_) | DynamicType::Divergent, + ) => return Ok(todo), // It's okay to ignore errors here because Python doesn't call `__bool__` // for different union variants. Instead, this is just for us to // evaluate a possibly truthy value to `false` or `true`. @@ -8236,7 +8264,9 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> { ); Ok(match eq_result { - todo @ Type::Dynamic(DynamicType::Todo(_)) => todo, + todo @ Type::Dynamic(DynamicType::Todo(_) | DynamicType::Divergent) => { + todo + } // It's okay to ignore errors here because Python doesn't call `__bool__` // for `is` and `is not` comparisons. This is an implementation detail // for how we determine the truthiness of a type. diff --git a/crates/ty_python_semantic/src/types/instance.rs b/crates/ty_python_semantic/src/types/instance.rs index 3959b041e166f..cb4c8447aa7a2 100644 --- a/crates/ty_python_semantic/src/types/instance.rs +++ b/crates/ty_python_semantic/src/types/instance.rs @@ -347,6 +347,13 @@ impl<'db> ProtocolInstanceType<'db> { pub(super) fn interface(self, db: &'db dyn Db) -> ProtocolInterface<'db> { self.inner.interface(db) } + + pub(super) fn has_divergent_type(self, db: &'db dyn Db) -> bool { + self.inner + .interface(db) + .members(db) + .any(|member| member.ty().has_divergent_type(db)) + } } /// An enumeration of the two kinds of protocol types: those that originate from a class diff --git a/crates/ty_python_semantic/src/types/protocol_class.rs b/crates/ty_python_semantic/src/types/protocol_class.rs index 6525be81819a5..49c61f41bea20 100644 --- a/crates/ty_python_semantic/src/types/protocol_class.rs +++ b/crates/ty_python_semantic/src/types/protocol_class.rs @@ -413,7 +413,7 @@ impl<'a, 'db> ProtocolMember<'a, 'db> { self.qualifiers } - fn ty(&self) -> Type<'db> { + pub(super) fn ty(&self) -> Type<'db> { match &self.kind { ProtocolMemberKind::Method(callable) => Type::Callable(*callable), ProtocolMemberKind::Property(property) => Type::PropertyInstance(*property), diff --git a/crates/ty_python_semantic/src/types/type_ordering.rs b/crates/ty_python_semantic/src/types/type_ordering.rs index 58b62af78947f..14dcf608ffce1 100644 --- a/crates/ty_python_semantic/src/types/type_ordering.rs +++ b/crates/ty_python_semantic/src/types/type_ordering.rs @@ -260,6 +260,9 @@ fn dynamic_elements_ordering(left: DynamicType, right: DynamicType) -> Ordering (DynamicType::TodoTypedDict, _) => Ordering::Less, (_, DynamicType::TodoTypedDict) => Ordering::Greater, + + (DynamicType::Divergent, _) => Ordering::Less, + (_, DynamicType::Divergent) => Ordering::Greater, } } From 86f3c2d15f36cbc4b94cdfb1d0dc573d9094ba8c Mon Sep 17 00:00:00 2001 From: Shunsuke Shibayama Date: Thu, 7 Aug 2025 01:55:17 +0900 Subject: [PATCH 023/105] divergence check --- .../resources/corpus/divergent.py | 36 ++++++ crates/ty_python_semantic/src/types.rs | 50 ++++---- crates/ty_python_semantic/src/types/class.rs | 7 ++ .../ty_python_semantic/src/types/generics.rs | 7 ++ crates/ty_python_semantic/src/types/infer.rs | 113 +++++++++++++----- .../src/types/signatures.rs | 20 ++++ crates/ty_python_semantic/src/types/tuple.rs | 6 + 7 files changed, 181 insertions(+), 58 deletions(-) create mode 100644 crates/ty_python_semantic/resources/corpus/divergent.py diff --git a/crates/ty_python_semantic/resources/corpus/divergent.py b/crates/ty_python_semantic/resources/corpus/divergent.py new file mode 100644 index 0000000000000..e120ae391d5f8 --- /dev/null +++ b/crates/ty_python_semantic/resources/corpus/divergent.py @@ -0,0 +1,36 @@ +def f(cond: bool): + if cond: + result = () + result += (f(cond),) + return result + + return None + +reveal_type(f(True)) + +def f(cond: bool): + if cond: + result = () + result += (f(cond),) + return result + + return None + +def f(cond: bool): + result = None + if cond: + result = () + result += (f(cond),) + + return result + +reveal_type(f(True)) + +def f(cond: bool): + result = None + if cond: + result = [f(cond) for _ in range(1)] + + return result + +reveal_type(f(True)) \ No newline at end of file diff --git a/crates/ty_python_semantic/src/types.rs b/crates/ty_python_semantic/src/types.rs index 053384b348ab7..836887dad20b2 100644 --- a/crates/ty_python_semantic/src/types.rs +++ b/crates/ty_python_semantic/src/types.rs @@ -54,7 +54,7 @@ pub use crate::types::ide_support::{ definitions_for_attribute, definitions_for_imported_symbol, definitions_for_keyword_argument, definitions_for_name, }; -use crate::types::infer::infer_unpack_types; +use crate::types::infer::{divergence_safe_todo, infer_unpack_types}; use crate::types::mro::{Mro, MroError, MroIterator}; pub(crate) use crate::types::narrow::infer_narrowing_constraint; use crate::types::signatures::{Parameter, ParameterForm, Parameters, walk_signature}; @@ -4742,8 +4742,10 @@ impl<'db> Type<'db> { match self { Type::Tuple(tuple_type) => return Ok(Cow::Borrowed(tuple_type.tuple(db))), Type::GenericAlias(alias) if alias.origin(db).is_tuple(db) => { - return Ok(Cow::Owned(TupleSpec::homogeneous(todo_type!( - "*tuple[] annotations" + return Ok(Cow::Owned(TupleSpec::homogeneous(divergence_safe_todo( + db, + "*tuple[] annotations", + [self], )))); } Type::StringLiteral(string_literal_ty) => { @@ -5177,7 +5179,11 @@ impl<'db> Type<'db> { typevar.kind(db), ))) } - Type::Intersection(_) => Some(todo_type!("Type::Intersection.to_instance")), + Type::Intersection(_) => Some(divergence_safe_todo( + db, + "Type::Intersection.to_instance", + [*self], + )), Type::BooleanLiteral(_) | Type::BytesLiteral(_) | Type::EnumLiteral(_) @@ -5561,10 +5567,10 @@ impl<'db> Type<'db> { .unwrap_or(SubclassOfInner::unknown()), ), }, - Type::StringLiteral(_) | Type::LiteralString => KnownClass::Str.to_class_literal(db), Type::Dynamic(dynamic) => SubclassOfType::from(db, SubclassOfInner::Dynamic(*dynamic)), // TODO intersections + // TODO divergence safety Type::Intersection(_) => SubclassOfType::from( db, SubclassOfInner::try_from_type(db, todo_type!("Intersection meta-type")) @@ -6056,30 +6062,10 @@ impl<'db> Type<'db> { .iter() .any(|ty| ty.has_divergent_type(db)) } - Type::Tuple(tuple) => tuple - .tuple(db) - .all_elements() - .any(|ty| ty.has_divergent_type(db)), - Type::GenericAlias(alias) => alias - .specialization(db) - .types(db) - .iter() - .any(|ty| ty.has_divergent_type(db)), - Type::NominalInstance(instance) => match instance.class { - ClassType::Generic(alias) => alias - .specialization(db) - .types(db) - .iter() - .any(|ty| ty.has_divergent_type(db)), - ClassType::NonGeneric(_) => false, - }, - Type::Callable(callable) => callable.signatures(db).iter().any(|sig| { - sig.parameters().iter().any(|param| { - param - .annotated_type() - .is_some_and(|ty| ty.has_divergent_type(db)) - }) || sig.return_ty.iter().any(|ty| ty.has_divergent_type(db)) - }), + Type::Tuple(tuple) => tuple.has_divergent_type(db), + Type::GenericAlias(alias) => alias.specialization(db).has_divergent_type(db), + Type::NominalInstance(instance) => instance.class.has_divergent_type(db), + Type::Callable(callable) => callable.has_divergent_type(db), Type::ProtocolInstance(protocol) => protocol.has_divergent_type(db), Type::PropertyInstance(property) => { property @@ -6093,7 +6079,7 @@ impl<'db> Type<'db> { Type::SubclassOf(subclass_of) => match subclass_of.subclass_of() { SubclassOfInner::Dynamic(DynamicType::Divergent) => true, SubclassOfInner::Dynamic(_) => false, - SubclassOfInner::Class(class) => class.metaclass(db).has_divergent_type(db), + SubclassOfInner::Class(class) => class.has_divergent_type(db), }, Type::Never | Type::AlwaysTruthy @@ -8149,6 +8135,10 @@ impl<'db> CallableType<'db> { .signatures(db) .is_equivalent_to(db, other.signatures(db)) } + + fn has_divergent_type(self, db: &'db dyn Db) -> bool { + self.signatures(db).has_divergent_type(db) + } } /// Represents a specific instance of `types.MethodWrapperType` diff --git a/crates/ty_python_semantic/src/types/class.rs b/crates/ty_python_semantic/src/types/class.rs index f3b963943414c..b670d59926bc4 100644 --- a/crates/ty_python_semantic/src/types/class.rs +++ b/crates/ty_python_semantic/src/types/class.rs @@ -1014,6 +1014,13 @@ impl<'db> ClassType<'db> { } } } + + pub(super) fn has_divergent_type(self, db: &'db dyn Db) -> bool { + match self { + ClassType::NonGeneric(_) => false, + ClassType::Generic(generic) => generic.specialization(db).has_divergent_type(db), + } + } } impl<'db> From> for ClassType<'db> { diff --git a/crates/ty_python_semantic/src/types/generics.rs b/crates/ty_python_semantic/src/types/generics.rs index 56f5b2e99bbaf..1c10eecbba8f0 100644 --- a/crates/ty_python_semantic/src/types/generics.rs +++ b/crates/ty_python_semantic/src/types/generics.rs @@ -633,6 +633,13 @@ impl<'db> Specialization<'db> { ty.find_legacy_typevars(db, typevars); } } + + pub(crate) fn has_divergent_type(self, db: &'db dyn Db) -> bool { + self.types(db).iter().any(|ty| ty.has_divergent_type(db)) + || self + .tuple_inner(db) + .is_some_and(|tuple| tuple.has_divergent_type(db)) + } } /// A mapping between type variables and types. diff --git a/crates/ty_python_semantic/src/types/infer.rs b/crates/ty_python_semantic/src/types/infer.rs index c8e2016d2415f..b61aa98c7f56e 100644 --- a/crates/ty_python_semantic/src/types/infer.rs +++ b/crates/ty_python_semantic/src/types/infer.rs @@ -129,6 +129,34 @@ use crate::util::diagnostics::format_enumeration; use crate::util::subscript::{PyIndex, PySlice}; use crate::{Db, FxOrderSet, Program}; +pub(crate) fn divergence_safe_todo<'db>( + db: &'db dyn Db, + msg: &'static str, + types: impl IntoIterator>, +) -> Type<'db> { + let _ = msg; + let mut builder = IntersectionBuilder::new(db).add_positive(todo_type!(msg)); + for ty in types { + if ty.has_divergent_type(db) { + builder = builder.add_positive(ty); + } + } + builder.build() +} + +fn divergence_safe_unknown<'db>( + db: &'db dyn Db, + types: impl IntoIterator>, +) -> Type<'db> { + let mut builder = IntersectionBuilder::new(db).add_positive(Type::unknown()); + for ty in types { + if ty.has_divergent_type(db) { + builder = builder.add_positive(ty); + } + } + builder.build() +} + /// Infer all types for a [`ScopeId`], including all definitions and expressions in that scope. /// Use when checking a scope, or needing to provide a type for an arbitrary expression in the /// scope. @@ -523,7 +551,12 @@ impl<'db> ScopeInference<'db> { continue; } else if ty.has_divergent_type(db) { if let Type::Union(union_ty) = ty { - union = union.add(union_ty.filter(db, |ty| **ty != div)); + let union_ty = union_ty.filter(db, |ty| **ty != div); + if union_ty.has_divergent_type(db) { + union = union.add(div); + } else { + union = union.add(union_ty); + } } else { union = union.add(div); } @@ -531,6 +564,9 @@ impl<'db> ScopeInference<'db> { union = union.add(ty); } } + if self.is_cycle_callback() { + union = union.add(div); + } let use_def = use_def_map(db, self.scope); if use_def.can_implicitly_return_none(db) { union = union.add(Type::none(db)); @@ -3364,7 +3400,7 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> { if let Some(node) = node { report_invalid_exception_caught(&self.context, node, element); } - Type::unknown() + divergence_safe_unknown(self.db(), [element]) }, ); } @@ -3395,7 +3431,7 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> { if let Some(node) = node { report_invalid_exception_caught(&self.context, node, node_ty); } - Type::unknown() + divergence_safe_unknown(self.db(), [node_ty]) }; if is_star { @@ -4686,7 +4722,7 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> { self.infer_binary_expression_type(assignment.into(), false, target_type, value_type, op) .unwrap_or_else(|| { report_unsupported_augmented_op(&mut self.context); - Type::unknown() + divergence_safe_unknown(self.db(), [target_type, value_type]) }) }; @@ -7293,7 +7329,7 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> { add_inferred_python_version_hint_to_diagnostic(db, &mut diag, "resolving types"); } } - Type::unknown() + divergence_safe_unknown(db, [left_ty, right_ty]) }) } @@ -7804,7 +7840,7 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> { | ast::CmpOp::Is | ast::CmpOp::IsNot => KnownClass::Bool.to_instance(builder.db()), // Other operators can return arbitrary types - _ => Type::unknown(), + _ => divergence_safe_unknown(builder.db(), [left_ty, right_ty]), } }); @@ -8438,7 +8474,10 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> { // TODO: Consider comparing the prefixes of the tuples, since that could give a comparison // result regardless of how long the variable-length tuple is. let (TupleSpec::Fixed(left), TupleSpec::Fixed(right)) = (left, right) else { - return Ok(Type::unknown()); + return Ok(divergence_safe_unknown( + self.db(), + left.all_elements().chain(right.all_elements()).copied(), + )); }; let left_iter = left.elements().copied(); @@ -8676,9 +8715,11 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> { // but we need to make sure we avoid emitting a diagnostic if one positive element has a `__getitem__` // method but another does not. This means `infer_subscript_expression_types` // needs to return a `Result` rather than eagerly emitting diagnostics. - (Type::Intersection(_), _) => { - Some(todo_type!("Subscript expressions on intersections")) - } + (Type::Intersection(_), _) => Some(divergence_safe_todo( + db, + "Subscript expressions on intersections", + [value_ty, slice_ty], + )), // Ex) Given `("a", "b", "c", "d")[1]`, return `"b"` (Type::Tuple(tuple_ty), Type::IntLiteral(i64_int)) => { @@ -8693,7 +8734,7 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> { tuple.len().display_minimum(), i64_int, ); - Type::unknown() + divergence_safe_unknown(db, [value_ty]) }) }) } @@ -8704,14 +8745,18 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> { .slice_literal(db) .map(|SliceLiteral { start, stop, step }| { let TupleSpec::Fixed(tuple) = tuple_ty.tuple(db) else { - return todo_type!("slice into variable-length tuple"); + return divergence_safe_todo( + db, + "slice into variable-length tuple", + [value_ty, slice_ty], + ); }; if let Ok(new_elements) = tuple.py_slice(db, start, stop, step) { TupleType::from_elements(db, new_elements) } else { report_slice_step_size_zero(context, value_node.into()); - Type::unknown() + divergence_safe_unknown(self.db(), [value_ty, slice_ty]) } }) } @@ -8752,7 +8797,7 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> { Type::string_literal(db, &literal) } else { report_slice_step_size_zero(context, value_node.into()); - Type::unknown() + divergence_safe_unknown(self.db(), [slice_ty]) } }), @@ -8791,7 +8836,7 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> { Type::bytes_literal(db, &new_bytes) } else { report_slice_step_size_zero(context, value_node.into()); - Type::unknown() + divergence_safe_unknown(self.db(), [slice_ty]) } }), @@ -8817,9 +8862,9 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> { .map(|context| { Type::KnownInstance(KnownInstanceType::SubscriptedProtocol(context)) }) - .unwrap_or_else(Type::unknown), + .unwrap_or_else(|| divergence_safe_unknown(self.db(), [slice_ty])), // TODO: emit a diagnostic - TupleSpec::Variable(_) => Type::unknown(), + TupleSpec::Variable(_) => divergence_safe_unknown(self.db(), [slice_ty]), }) } @@ -8830,12 +8875,16 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> { LegacyGenericBase::Protocol, ) .map(|context| Type::KnownInstance(KnownInstanceType::SubscriptedProtocol(context))) - .unwrap_or_else(Type::unknown), + .unwrap_or_else(|| divergence_safe_unknown(self.db(), [slice_ty])), ), (Type::KnownInstance(KnownInstanceType::SubscriptedProtocol(_)), _) => { // TODO: emit a diagnostic - Some(todo_type!("doubly-specialized typing.Protocol")) + Some(divergence_safe_todo( + db, + "doubly-specialized typing.Protocol", + [value_ty, slice_ty], + )) } (Type::SpecialForm(SpecialFormType::Generic), Type::Tuple(typevars)) => { @@ -8849,9 +8898,9 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> { .map(|context| { Type::KnownInstance(KnownInstanceType::SubscriptedGeneric(context)) }) - .unwrap_or_else(Type::unknown), + .unwrap_or_else(|| divergence_safe_unknown(self.db(), [slice_ty])), // TODO: emit a diagnostic - TupleSpec::Variable(_) => Type::unknown(), + TupleSpec::Variable(_) => divergence_safe_unknown(self.db(), [slice_ty]), }) } @@ -8862,22 +8911,30 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> { LegacyGenericBase::Generic, ) .map(|context| Type::KnownInstance(KnownInstanceType::SubscriptedGeneric(context))) - .unwrap_or_else(Type::unknown), + .unwrap_or_else(|| divergence_safe_unknown(self.db(), [slice_ty])), ), (Type::KnownInstance(KnownInstanceType::SubscriptedGeneric(_)), _) => { // TODO: emit a diagnostic - Some(todo_type!("doubly-specialized typing.Generic")) + Some(divergence_safe_todo( + db, + "doubly-specialized typing.Generic", + [value_ty, slice_ty], + )) } - (Type::SpecialForm(special_form), _) if special_form.class().is_special_form() => { - Some(todo_type!("Inference of subscript on special form")) - } + (Type::SpecialForm(special_form), _) if special_form.class().is_special_form() => Some( + divergence_safe_todo(db, "Inference of subscript on special form", [slice_ty]), + ), (Type::KnownInstance(known_instance), _) if known_instance.class().is_special_form() => { - Some(todo_type!("Inference of subscript on special form")) + Some(divergence_safe_todo( + db, + "Inference of subscript on special form", + [slice_ty], + )) } _ => None, @@ -9023,7 +9080,7 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> { } } - Type::unknown() + divergence_safe_unknown(self.db(), [value_ty, slice_ty]) } fn legacy_generic_class_context( diff --git a/crates/ty_python_semantic/src/types/signatures.rs b/crates/ty_python_semantic/src/types/signatures.rs index eef3ac6781f0c..4d264056ef130 100644 --- a/crates/ty_python_semantic/src/types/signatures.rs +++ b/crates/ty_python_semantic/src/types/signatures.rs @@ -201,6 +201,12 @@ impl<'db> CallableSignature<'db> { } } } + + pub(super) fn has_divergent_type(&self, db: &'db dyn Db) -> bool { + self.overloads + .iter() + .any(|signature| signature.has_divergent_type(db)) + } } impl<'a, 'db> IntoIterator for &'a CallableSignature<'db> { @@ -928,6 +934,12 @@ impl<'db> Signature<'db> { pub(crate) fn with_definition(self, definition: Option>) -> Self { Self { definition, ..self } } + + fn has_divergent_type(&self, db: &'db dyn Db) -> bool { + self.return_ty + .is_some_and(|return_ty| return_ty.has_divergent_type(db)) + || self.parameters.has_divergent_type(db) + } } // Manual implementations of PartialEq, Eq, and Hash that exclude the definition field @@ -1226,6 +1238,14 @@ impl<'db> Parameters<'db> { .enumerate() .rfind(|(_, parameter)| parameter.is_keyword_variadic()) } + + fn has_divergent_type(&self, db: &'db dyn Db) -> bool { + self.iter().any(|parameter| { + parameter + .annotated_type() + .is_some_and(|ty| ty.has_divergent_type(db)) + }) + } } impl<'db, 'a> IntoIterator for &'a Parameters<'db> { diff --git a/crates/ty_python_semantic/src/types/tuple.rs b/crates/ty_python_semantic/src/types/tuple.rs index 50c8254e06456..5a666dfd55499 100644 --- a/crates/ty_python_semantic/src/types/tuple.rs +++ b/crates/ty_python_semantic/src/types/tuple.rs @@ -289,6 +289,12 @@ impl<'db> TupleType<'db> { pub(crate) fn truthiness(self, db: &'db dyn Db) -> Truthiness { self.tuple(db).truthiness() } + + pub(super) fn has_divergent_type(self, db: &'db dyn Db) -> bool { + self.tuple(db) + .all_elements() + .any(|ty| ty.has_divergent_type(db)) + } } /// A tuple spec describes the contents of a tuple type, which might be fixed- or variable-length. From 38fcc4154227462c48bb9ebefc7a40909dd33341 Mon Sep 17 00:00:00 2001 From: Shunsuke Shibayama Date: Thu, 7 Aug 2025 03:29:01 +0900 Subject: [PATCH 024/105] Update ty_check_invalid_syntax.rs --- fuzz/fuzz_targets/ty_check_invalid_syntax.rs | 8 +------- 1 file changed, 1 insertion(+), 7 deletions(-) diff --git a/fuzz/fuzz_targets/ty_check_invalid_syntax.rs b/fuzz/fuzz_targets/ty_check_invalid_syntax.rs index ed6737b539c0c..9a6c6eb1f6b5c 100644 --- a/fuzz/fuzz_targets/ty_check_invalid_syntax.rs +++ b/fuzz/fuzz_targets/ty_check_invalid_syntax.rs @@ -18,7 +18,7 @@ use ruff_python_parser::{Mode, ParseOptions, parse_unchecked}; use ty_python_semantic::lint::LintRegistry; use ty_python_semantic::types::check_types; use ty_python_semantic::{ - CallStack, Db as SemanticDb, Program, ProgramSettings, PythonPlatform, PythonVersionWithSource, + Db as SemanticDb, Program, ProgramSettings, PythonPlatform, PythonVersionWithSource, SearchPathSettings, default_lint_registry, lint::RuleSelection, }; @@ -33,7 +33,6 @@ struct TestDb { system: TestSystem, vendored: VendoredFileSystem, rule_selection: Arc, - call_stack: CallStack, } impl TestDb { @@ -48,7 +47,6 @@ impl TestDb { vendored: ty_vendored::file_system().clone(), files: Files::default(), rule_selection: RuleSelection::from_registry(default_lint_registry()).into(), - call_stack: CallStack::default(), } } } @@ -95,10 +93,6 @@ impl SemanticDb for TestDb { fn lint_registry(&self) -> &LintRegistry { default_lint_registry() } - - fn call_stack(&self) -> &CallStack { - &self.call_stack - } } #[salsa::db] From 893dd2b404179e8e01bd5350795a5cd1d035be1b Mon Sep 17 00:00:00 2001 From: Shunsuke Shibayama Date: Fri, 29 Aug 2025 17:14:29 +0900 Subject: [PATCH 025/105] follow changes in the main branch --- .../resources/mdtest/annotations/invalid.md | 4 +++- .../diagnostics/semantic_syntax_errors.md | 5 ++++- .../resources/mdtest/function/return_type.md | 2 +- crates/ty_python_semantic/src/types/infer.rs | 20 ++++++++++++++++++- crates/ty_python_semantic/src/types/narrow.rs | 8 ++------ 5 files changed, 29 insertions(+), 10 deletions(-) diff --git a/crates/ty_python_semantic/resources/mdtest/annotations/invalid.md b/crates/ty_python_semantic/resources/mdtest/annotations/invalid.md index 91d55f7352684..c4b475da1b2ac 100644 --- a/crates/ty_python_semantic/resources/mdtest/annotations/invalid.md +++ b/crates/ty_python_semantic/resources/mdtest/annotations/invalid.md @@ -74,7 +74,9 @@ def _( def bar() -> None: return None -async def baz(): ... +async def baz() -> int: + return 42 + async def outer(): # avoid unrelated syntax errors on yield, yield from, and await def _( a: 1, # error: [invalid-type-form] "Int literals are not allowed in this context in a type expression" diff --git a/crates/ty_python_semantic/resources/mdtest/diagnostics/semantic_syntax_errors.md b/crates/ty_python_semantic/resources/mdtest/diagnostics/semantic_syntax_errors.md index bb722884f32f2..48c31d78300a0 100644 --- a/crates/ty_python_semantic/resources/mdtest/diagnostics/semantic_syntax_errors.md +++ b/crates/ty_python_semantic/resources/mdtest/diagnostics/semantic_syntax_errors.md @@ -131,7 +131,8 @@ match obj: ```py class C: - def __await__(self): ... + def __await__(self): + yield # error: [invalid-syntax] "`return` statement outside of a function" return @@ -147,6 +148,8 @@ yield from [] await C() def f(): + # TODO: no error, C is awaitable + # error: [invalid-await] "`C` is not awaitable" # error: [invalid-syntax] "`await` outside of an asynchronous function" await C() ``` diff --git a/crates/ty_python_semantic/resources/mdtest/function/return_type.md b/crates/ty_python_semantic/resources/mdtest/function/return_type.md index 1712d82b26ac8..fd1626e3a16f9 100644 --- a/crates/ty_python_semantic/resources/mdtest/function/return_type.md +++ b/crates/ty_python_semantic/resources/mdtest/function/return_type.md @@ -462,7 +462,7 @@ reveal_type(D().h(1)) # revealed: Literal[2] | Unknown reveal_type(C().h(True)) # revealed: Literal[True] reveal_type(D().h(True)) # revealed: Literal[2] | Unknown reveal_type(C().i(1)) # revealed: list[Literal[1]] -reveal_type(D().i(1)) # revealed: list[Unknown] +reveal_type(D().i(1)) # revealed: list[@Todo(list literal element type)] class F: def f(self) -> Literal[1, 2]: diff --git a/crates/ty_python_semantic/src/types/infer.rs b/crates/ty_python_semantic/src/types/infer.rs index 2e93c2a31e10e..2c80ad223e3d4 100644 --- a/crates/ty_python_semantic/src/types/infer.rs +++ b/crates/ty_python_semantic/src/types/infer.rs @@ -639,7 +639,25 @@ impl<'db> ScopeInference<'db> { } } } - union.build() + + let module = parsed_module(db, self.scope.file(db)).load(db); + if self + .scope + .node(db) + .as_function(&module) + .is_some_and(|func| { + let index = semantic_index(db, self.scope.file(db)); + let is_generator = self.scope.file_scope_id(db).is_generator_function(index); + + func.is_async && !is_generator + }) + { + // TODO: yield/await type inference + KnownClass::CoroutineType + .to_specialized_instance(db, [Type::any(), Type::any(), union.build()]) + } else { + union.build() + } } } diff --git a/crates/ty_python_semantic/src/types/narrow.rs b/crates/ty_python_semantic/src/types/narrow.rs index 160e2ef19353f..98596b2d84721 100644 --- a/crates/ty_python_semantic/src/types/narrow.rs +++ b/crates/ty_python_semantic/src/types/narrow.rs @@ -762,12 +762,8 @@ impl<'db, 'ast> NarrowingConstraintsBuilder<'db, 'ast> { }; let rhs_ty = inference.expression_type(right); - let rhs_class = match rhs_ty { - Type::ClassLiteral(class) => class, - Type::GenericAlias(alias) => alias.origin(self.db), - _ => { - continue; - } + let Type::ClassLiteral(rhs_class) = rhs_ty else { + continue; }; // `else`-branch narrowing for `if type(x) is Y` can only be done From d1088545a08aeb57b67ec1e3a7f5141159efefa5 Mon Sep 17 00:00:00 2001 From: Shunsuke Shibayama Date: Sat, 30 Aug 2025 12:24:50 +0900 Subject: [PATCH 026/105] Update infer.rs --- crates/ty_python_semantic/src/types/infer.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/crates/ty_python_semantic/src/types/infer.rs b/crates/ty_python_semantic/src/types/infer.rs index f247450d5d809..3ddf2d1e1309a 100644 --- a/crates/ty_python_semantic/src/types/infer.rs +++ b/crates/ty_python_semantic/src/types/infer.rs @@ -6426,7 +6426,7 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> { // Special handling for `TypedDict` method calls if let ast::Expr::Attribute(ast::ExprAttribute { value, attr, .. }) = func.as_ref() { - let value_type = self.expression_type(value); + let value_type = self.expression_type(value.as_ref()); if let Type::TypedDict(typed_dict_ty) = value_type { if matches!(attr.id.as_str(), "pop" | "setdefault") && !arguments.args.is_empty() { // Validate the key argument for `TypedDict` methods From 83cd6ae991d5f361ef2a9195a8bc56a7525ce6d8 Mon Sep 17 00:00:00 2001 From: Shunsuke Shibayama Date: Sat, 30 Aug 2025 12:58:05 +0900 Subject: [PATCH 027/105] fix for fuzzer-reported panic --- crates/ty_python_semantic/src/types/class.rs | 15 ++++++++++++++- 1 file changed, 14 insertions(+), 1 deletion(-) diff --git a/crates/ty_python_semantic/src/types/class.rs b/crates/ty_python_semantic/src/types/class.rs index 47ea15ba83acb..b6d50dcb2d705 100644 --- a/crates/ty_python_semantic/src/types/class.rs +++ b/crates/ty_python_semantic/src/types/class.rs @@ -175,6 +175,19 @@ fn try_metaclass_cycle_initial<'db>( }) } +fn into_callable_cycle_recover<'db>( + _db: &'db dyn Db, + _value: &Type<'db>, + _count: u32, + _self: ClassType<'db>, +) -> salsa::CycleRecoveryAction> { + salsa::CycleRecoveryAction::Iterate +} + +fn into_callable_cycle_initial<'db>(_db: &'db dyn Db, _self: ClassType<'db>) -> Type<'db> { + Type::Never +} + /// A category of classes with code generation capabilities (with synthesized methods). #[derive(Clone, Copy, Debug, PartialEq, salsa::Update, get_size2::GetSize)] pub(crate) enum CodeGeneratorKind { @@ -1062,7 +1075,7 @@ impl<'db> ClassType<'db> { /// Return a callable type (or union of callable types) that represents the callable /// constructor signature of this class. - #[salsa::tracked(heap_size=ruff_memory_usage::heap_size)] + #[salsa::tracked(cycle_fn=into_callable_cycle_recover, cycle_initial=into_callable_cycle_initial, heap_size=ruff_memory_usage::heap_size)] pub(super) fn into_callable(self, db: &'db dyn Db) -> Type<'db> { let self_ty = Type::from(self); let metaclass_dunder_call_function_symbol = self_ty From c7eeffc19e10ec3f4d617a592c676369f79622b6 Mon Sep 17 00:00:00 2001 From: Shunsuke Shibayama Date: Sun, 31 Aug 2025 19:08:51 +0900 Subject: [PATCH 028/105] use `CycleDetector` with `has_divergent_type()` --- crates/ty_python_semantic/src/types.rs | 134 +++++++++++------- crates/ty_python_semantic/src/types/class.rs | 20 ++- .../ty_python_semantic/src/types/generics.rs | 21 ++- .../ty_python_semantic/src/types/instance.rs | 14 +- .../src/types/signatures.rs | 31 ++-- .../src/types/subclass_of.rs | 18 ++- crates/ty_python_semantic/src/types/tuple.rs | 10 +- 7 files changed, 163 insertions(+), 85 deletions(-) diff --git a/crates/ty_python_semantic/src/types.rs b/crates/ty_python_semantic/src/types.rs index 9e95e01160bcc..37e1167c39aea 100644 --- a/crates/ty_python_semantic/src/types.rs +++ b/crates/ty_python_semantic/src/types.rs @@ -118,21 +118,6 @@ fn return_type_cycle_initial<'db>(_db: &'db dyn Db, _self: BoundMethodType<'db>) Type::Dynamic(DynamicType::Divergent) } -#[allow(clippy::trivially_copy_pass_by_ref)] -fn has_divergent_type_cycle_recover<'db>( - _db: &'db dyn Db, - _value: &bool, - _count: u32, - _self: Type<'db>, - _unit: (), -) -> salsa::CycleRecoveryAction { - salsa::CycleRecoveryAction::Iterate -} - -fn has_divergent_type_cycle_initial<'db>(_db: &'db dyn Db, _slf: Type<'db>, _unit: ()) -> bool { - false -} - pub fn check_types(db: &dyn Db, file: File) -> Vec { let _span = tracing::trace_span!("check_types", ?file).entered(); @@ -227,6 +212,9 @@ pub(crate) struct FindLegacyTypeVars; pub(crate) type NormalizedVisitor<'db> = TypeTransformer<'db, Normalized>; pub(crate) struct Normalized; +pub(crate) type HasDivergentTypeVisitor<'db> = CycleDetector, bool>; +pub(crate) struct HasDivergentType; + /// How a generic type has been specialized. /// /// This matters only if there is at least one invariant type parameter. @@ -564,6 +552,18 @@ impl<'db> PropertyInstanceType<'db> { .map(|ty| ty.materialize(db, materialization_kind)), ) } + + fn has_divergent_type_impl( + self, + db: &'db dyn Db, + visitor: &HasDivergentTypeVisitor<'db>, + ) -> bool { + self.setter(db) + .is_some_and(|setter| setter.has_divergent_type_impl(db, visitor)) + || self + .getter(db) + .is_some_and(|getter| getter.has_divergent_type_impl(db, visitor)) + } } bitflags! { @@ -6572,47 +6572,51 @@ impl<'db> Type<'db> { } pub(super) fn has_divergent_type(self, db: &'db dyn Db) -> bool { - self._has_divergent_type(db, ()) + let visitor = HasDivergentTypeVisitor::new(false); + self.has_divergent_type_impl(db, &visitor) } - #[allow(clippy::used_underscore_binding)] - #[salsa::tracked(cycle_fn=has_divergent_type_cycle_recover, cycle_initial=has_divergent_type_cycle_initial, heap_size=get_size2::heap_size)] - fn _has_divergent_type(self, db: &'db dyn Db, _unit: ()) -> bool { + fn has_divergent_type_impl( + self, + db: &'db dyn Db, + visitor: &HasDivergentTypeVisitor<'db>, + ) -> bool { match self { Type::Dynamic(DynamicType::Divergent) => true, - Type::Union(union) => union - .elements(db) - .iter() - .any(|ty| ty.has_divergent_type(db)), + Type::Union(union) => { + visitor.visit(self, || union.has_divergent_type_impl(db, visitor)) + } Type::Intersection(intersection) => { - intersection - .positive(db) - .iter() - .any(|ty| ty.has_divergent_type(db)) - || intersection - .negative(db) - .iter() - .any(|ty| ty.has_divergent_type(db)) + visitor.visit(self, || intersection.has_divergent_type_impl(db, visitor)) + } + Type::GenericAlias(alias) => visitor.visit(self, || { + alias + .specialization(db) + .has_divergent_type_impl(db, visitor) + }), + Type::NominalInstance(instance) => visitor.visit(self, || { + instance.class(db).has_divergent_type_impl(db, visitor) + }), + Type::Callable(callable) => { + visitor.visit(self, || callable.has_divergent_type_impl(db, visitor)) + } + Type::ProtocolInstance(protocol) => { + visitor.visit(self, || protocol.has_divergent_type_impl(db, visitor)) } - Type::GenericAlias(alias) => alias.specialization(db).has_divergent_type(db), - Type::NominalInstance(instance) => instance.class(db).has_divergent_type(db), - Type::Callable(callable) => callable.has_divergent_type(db), - Type::ProtocolInstance(protocol) => protocol.has_divergent_type(db), Type::PropertyInstance(property) => { - property - .setter(db) - .is_some_and(|setter| setter.has_divergent_type(db)) - || property - .getter(db) - .is_some_and(|getter| getter.has_divergent_type(db)) - } - Type::TypeIs(type_is) => type_is.return_type(db).has_divergent_type(db), - Type::SubclassOf(subclass_of) => match subclass_of.subclass_of() { - SubclassOfInner::Dynamic(DynamicType::Divergent) => true, - SubclassOfInner::Dynamic(_) => false, - SubclassOfInner::Class(class) => class.has_divergent_type(db), - }, - Type::TypedDict(typed_dict) => typed_dict.defining_class().has_divergent_type(db), + visitor.visit(self, || property.has_divergent_type_impl(db, visitor)) + } + Type::TypeIs(type_is) => visitor.visit(self, || { + type_is.return_type(db).has_divergent_type_impl(db, visitor) + }), + Type::SubclassOf(subclass_of) => { + visitor.visit(self, || subclass_of.has_divergent_type_impl(db, visitor)) + } + Type::TypedDict(typed_dict) => visitor.visit(self, || { + typed_dict + .defining_class() + .has_divergent_type_impl(db, visitor) + }), Type::Never | Type::AlwaysTruthy | Type::AlwaysFalsy @@ -9245,8 +9249,12 @@ impl<'db> CallableType<'db> { }) } - fn has_divergent_type(self, db: &'db dyn Db) -> bool { - self.signatures(db).has_divergent_type(db) + fn has_divergent_type_impl( + self, + db: &'db dyn Db, + visitor: &HasDivergentTypeVisitor<'db>, + ) -> bool { + self.signatures(db).has_divergent_type_impl(db, visitor) } } @@ -9919,6 +9927,16 @@ impl<'db> UnionType<'db> { C::from_bool(db, sorted_self == other.normalized(db)) } + + fn has_divergent_type_impl( + self, + db: &'db dyn Db, + visitor: &HasDivergentTypeVisitor<'db>, + ) -> bool { + self.elements(db) + .iter() + .any(|ty| ty.has_divergent_type_impl(db, visitor)) + } } #[salsa::interned(debug, heap_size=IntersectionType::heap_size)] @@ -10134,6 +10152,20 @@ impl<'db> IntersectionType<'db> { ruff_memory_usage::order_set_heap_size(positive) + ruff_memory_usage::order_set_heap_size(negative) } + + fn has_divergent_type_impl( + self, + db: &'db dyn Db, + visitor: &HasDivergentTypeVisitor<'db>, + ) -> bool { + self.positive(db) + .iter() + .any(|ty| ty.has_divergent_type_impl(db, visitor)) + || self + .negative(db) + .iter() + .any(|ty| ty.has_divergent_type_impl(db, visitor)) + } } /// # Ordering diff --git a/crates/ty_python_semantic/src/types/class.rs b/crates/ty_python_semantic/src/types/class.rs index b6d50dcb2d705..c9a94893538c4 100644 --- a/crates/ty_python_semantic/src/types/class.rs +++ b/crates/ty_python_semantic/src/types/class.rs @@ -30,11 +30,11 @@ use crate::types::tuple::{TupleSpec, TupleType}; use crate::types::typed_dict::typed_dict_params_from_class_def; use crate::types::{ ApplyTypeMappingVisitor, Binding, BoundSuperError, BoundSuperType, CallableType, - DataclassParams, DeprecatedInstance, FindLegacyTypeVarsVisitor, HasRelationToVisitor, - IsEquivalentVisitor, KnownInstanceType, ManualPEP695TypeAliasType, MaterializationKind, - NormalizedVisitor, PropertyInstanceType, StringLiteralType, TypeAliasType, TypeMapping, - TypeRelation, TypeVarBoundOrConstraints, TypeVarInstance, TypeVarKind, TypedDictParams, - UnionBuilder, VarianceInferable, declaration_type, infer_definition_types, + DataclassParams, DeprecatedInstance, FindLegacyTypeVarsVisitor, HasDivergentTypeVisitor, + HasRelationToVisitor, IsEquivalentVisitor, KnownInstanceType, ManualPEP695TypeAliasType, + MaterializationKind, NormalizedVisitor, PropertyInstanceType, StringLiteralType, TypeAliasType, + TypeMapping, TypeRelation, TypeVarBoundOrConstraints, TypeVarInstance, TypeVarKind, + TypedDictParams, UnionBuilder, VarianceInferable, declaration_type, infer_definition_types, }; use crate::{ Db, FxIndexMap, FxOrderSet, Program, @@ -1226,10 +1226,16 @@ impl<'db> ClassType<'db> { } } - pub(super) fn has_divergent_type(self, db: &'db dyn Db) -> bool { + pub(super) fn has_divergent_type_impl( + self, + db: &'db dyn Db, + visitor: &HasDivergentTypeVisitor<'db>, + ) -> bool { match self { ClassType::NonGeneric(_) => false, - ClassType::Generic(generic) => generic.specialization(db).has_divergent_type(db), + ClassType::Generic(generic) => generic + .specialization(db) + .has_divergent_type_impl(db, visitor), } } diff --git a/crates/ty_python_semantic/src/types/generics.rs b/crates/ty_python_semantic/src/types/generics.rs index 26c778b9d127f..04b6f655e9288 100644 --- a/crates/ty_python_semantic/src/types/generics.rs +++ b/crates/ty_python_semantic/src/types/generics.rs @@ -15,10 +15,11 @@ use crate::types::instance::{Protocol, ProtocolInstanceType}; use crate::types::signatures::{Parameter, Parameters, Signature}; use crate::types::tuple::{TupleSpec, TupleType, walk_tuple_type}; use crate::types::{ - ApplyTypeMappingVisitor, BoundTypeVarInstance, FindLegacyTypeVarsVisitor, HasRelationToVisitor, - IsEquivalentVisitor, KnownClass, KnownInstanceType, MaterializationKind, NormalizedVisitor, - Type, TypeMapping, TypeRelation, TypeVarBoundOrConstraints, TypeVarInstance, TypeVarVariance, - UnionType, binding_type, declaration_type, + ApplyTypeMappingVisitor, BoundTypeVarInstance, FindLegacyTypeVarsVisitor, + HasDivergentTypeVisitor, HasRelationToVisitor, IsEquivalentVisitor, KnownClass, + KnownInstanceType, MaterializationKind, NormalizedVisitor, Type, TypeMapping, TypeRelation, + TypeVarBoundOrConstraints, TypeVarInstance, TypeVarVariance, UnionType, binding_type, + declaration_type, }; use crate::{Db, FxOrderSet}; @@ -873,11 +874,17 @@ impl<'db> Specialization<'db> { // look in `self.tuple`. } - pub(crate) fn has_divergent_type(self, db: &'db dyn Db) -> bool { - self.types(db).iter().any(|ty| ty.has_divergent_type(db)) + pub(crate) fn has_divergent_type_impl( + self, + db: &'db dyn Db, + visitor: &HasDivergentTypeVisitor<'db>, + ) -> bool { + self.types(db) + .iter() + .any(|ty| ty.has_divergent_type_impl(db, visitor)) || self .tuple_inner(db) - .is_some_and(|tuple| tuple.has_divergent_type(db)) + .is_some_and(|tuple| tuple.has_divergent_type_impl(db, visitor)) } } diff --git a/crates/ty_python_semantic/src/types/instance.rs b/crates/ty_python_semantic/src/types/instance.rs index 42ad9a7a7b664..c5e4067c1f5ac 100644 --- a/crates/ty_python_semantic/src/types/instance.rs +++ b/crates/ty_python_semantic/src/types/instance.rs @@ -12,9 +12,9 @@ use crate::types::enums::is_single_member_enum; use crate::types::protocol_class::walk_protocol_interface; use crate::types::tuple::{TupleSpec, TupleType}; use crate::types::{ - ApplyTypeMappingVisitor, ClassBase, FindLegacyTypeVarsVisitor, HasRelationToVisitor, - IsDisjointVisitor, IsEquivalentVisitor, MaterializationKind, NormalizedVisitor, TypeMapping, - TypeRelation, VarianceInferable, + ApplyTypeMappingVisitor, ClassBase, FindLegacyTypeVarsVisitor, HasDivergentTypeVisitor, + HasRelationToVisitor, IsDisjointVisitor, IsEquivalentVisitor, MaterializationKind, + NormalizedVisitor, TypeMapping, TypeRelation, VarianceInferable, }; use crate::{Db, FxOrderSet}; @@ -639,11 +639,15 @@ impl<'db> ProtocolInstanceType<'db> { self.inner.interface(db) } - pub(super) fn has_divergent_type(self, db: &'db dyn Db) -> bool { + pub(super) fn has_divergent_type_impl( + self, + db: &'db dyn Db, + visitor: &HasDivergentTypeVisitor<'db>, + ) -> bool { self.inner .interface(db) .members(db) - .any(|member| member.ty().has_divergent_type(db)) + .any(|member| member.ty().has_divergent_type_impl(db, visitor)) } } diff --git a/crates/ty_python_semantic/src/types/signatures.rs b/crates/ty_python_semantic/src/types/signatures.rs index 2b41f0079cdda..758f4c4f8dea4 100644 --- a/crates/ty_python_semantic/src/types/signatures.rs +++ b/crates/ty_python_semantic/src/types/signatures.rs @@ -21,8 +21,9 @@ use crate::types::constraints::{ConstraintSet, Constraints, IteratorConstraintsE use crate::types::generics::{GenericContext, walk_generic_context}; use crate::types::{ ApplyTypeMappingVisitor, BindingContext, BoundTypeVarInstance, FindLegacyTypeVarsVisitor, - HasRelationToVisitor, IsEquivalentVisitor, KnownClass, MaterializationKind, NormalizedVisitor, - TypeMapping, TypeRelation, VarianceInferable, todo_type, + HasDivergentTypeVisitor, HasRelationToVisitor, IsEquivalentVisitor, KnownClass, + MaterializationKind, NormalizedVisitor, TypeMapping, TypeRelation, VarianceInferable, + todo_type, }; use crate::{Db, FxOrderSet}; use ruff_python_ast::{self as ast, name::Name}; @@ -235,10 +236,14 @@ impl<'db> CallableSignature<'db> { } } - pub(super) fn has_divergent_type(&self, db: &'db dyn Db) -> bool { + pub(super) fn has_divergent_type_impl( + &self, + db: &'db dyn Db, + visitor: &HasDivergentTypeVisitor<'db>, + ) -> bool { self.overloads .iter() - .any(|signature| signature.has_divergent_type(db)) + .any(|signature| signature.has_divergent_type_impl(db, visitor)) } } @@ -1023,10 +1028,14 @@ impl<'db> Signature<'db> { Self { definition, ..self } } - fn has_divergent_type(&self, db: &'db dyn Db) -> bool { + fn has_divergent_type_impl( + &self, + db: &'db dyn Db, + visitor: &HasDivergentTypeVisitor<'db>, + ) -> bool { self.return_ty - .is_some_and(|return_ty| return_ty.has_divergent_type(db)) - || self.parameters.has_divergent_type(db) + .is_some_and(|return_ty| return_ty.has_divergent_type_impl(db, visitor)) + || self.parameters.has_divergent_type_impl(db, visitor) } } @@ -1332,11 +1341,15 @@ impl<'db> Parameters<'db> { .rfind(|(_, parameter)| parameter.is_keyword_variadic()) } - fn has_divergent_type(&self, db: &'db dyn Db) -> bool { + fn has_divergent_type_impl( + &self, + db: &'db dyn Db, + visitor: &HasDivergentTypeVisitor<'db>, + ) -> bool { self.iter().any(|parameter| { parameter .annotated_type() - .is_some_and(|ty| ty.has_divergent_type(db)) + .is_some_and(|ty| ty.has_divergent_type_impl(db, visitor)) }) } } diff --git a/crates/ty_python_semantic/src/types/subclass_of.rs b/crates/ty_python_semantic/src/types/subclass_of.rs index 43f01dee2201e..33bafcca8dd4d 100644 --- a/crates/ty_python_semantic/src/types/subclass_of.rs +++ b/crates/ty_python_semantic/src/types/subclass_of.rs @@ -4,9 +4,9 @@ use crate::types::constraints::Constraints; use crate::types::variance::VarianceInferable; use crate::types::{ ApplyTypeMappingVisitor, BoundTypeVarInstance, ClassType, DynamicType, - FindLegacyTypeVarsVisitor, HasRelationToVisitor, IsDisjointVisitor, KnownClass, - MaterializationKind, MemberLookupPolicy, NormalizedVisitor, SpecialFormType, Type, TypeMapping, - TypeRelation, + FindLegacyTypeVarsVisitor, HasDivergentTypeVisitor, HasRelationToVisitor, IsDisjointVisitor, + KnownClass, MaterializationKind, MemberLookupPolicy, NormalizedVisitor, SpecialFormType, Type, + TypeMapping, TypeRelation, }; use crate::{Db, FxOrderSet}; @@ -201,6 +201,18 @@ impl<'db> SubclassOfType<'db> { .into_class() .is_some_and(|class| class.class_literal(db).0.is_typed_dict(db)) } + + pub(super) fn has_divergent_type_impl( + self, + db: &'db dyn Db, + visitor: &HasDivergentTypeVisitor<'db>, + ) -> bool { + match self.subclass_of { + SubclassOfInner::Dynamic(DynamicType::Divergent) => true, + SubclassOfInner::Dynamic(_) => false, + SubclassOfInner::Class(class) => class.has_divergent_type_impl(db, visitor), + } + } } impl<'db> VarianceInferable<'db> for SubclassOfType<'db> { diff --git a/crates/ty_python_semantic/src/types/tuple.rs b/crates/ty_python_semantic/src/types/tuple.rs index f3ecbbe8195b1..15fa14d0b48f1 100644 --- a/crates/ty_python_semantic/src/types/tuple.rs +++ b/crates/ty_python_semantic/src/types/tuple.rs @@ -22,7 +22,6 @@ use std::hash::Hash; use itertools::{Either, EitherOrBoth, Itertools}; use crate::semantic_index::definition::Definition; -use crate::types::Truthiness; use crate::types::class::{ClassType, KnownClass}; use crate::types::constraints::{Constraints, IteratorConstraintsExtension}; use crate::types::{ @@ -30,6 +29,7 @@ use crate::types::{ IsDisjointVisitor, IsEquivalentVisitor, MaterializationKind, NormalizedVisitor, Type, TypeMapping, TypeRelation, UnionBuilder, UnionType, }; +use crate::types::{HasDivergentTypeVisitor, Truthiness}; use crate::util::subscript::{Nth, OutOfBoundsError, PyIndex, PySlice, StepSizeZeroError}; use crate::{Db, FxOrderSet, Program}; @@ -286,10 +286,14 @@ impl<'db> TupleType<'db> { self.tuple(db).is_single_valued(db) } - pub(super) fn has_divergent_type(self, db: &'db dyn Db) -> bool { + pub(super) fn has_divergent_type_impl( + self, + db: &'db dyn Db, + visitor: &HasDivergentTypeVisitor<'db>, + ) -> bool { self.tuple(db) .all_elements() - .any(|ty| ty.has_divergent_type(db)) + .any(|ty| ty.has_divergent_type_impl(db, visitor)) } } From f981c8487646dd53a89297621e3611433874b9e9 Mon Sep 17 00:00:00 2001 From: Shunsuke Shibayama Date: Mon, 1 Sep 2025 17:24:56 +0900 Subject: [PATCH 029/105] the inferred return type should be monotonically widened in fixed-point iteration --- .../resources/mdtest/function/return_type.md | 37 ++++++-- .../src/semantic_index/scope.rs | 15 +++- crates/ty_python_semantic/src/types.rs | 2 +- .../ty_python_semantic/src/types/builder.rs | 2 +- .../ty_python_semantic/src/types/function.rs | 2 +- crates/ty_python_semantic/src/types/infer.rs | 90 +++++++++---------- 6 files changed, 86 insertions(+), 62 deletions(-) diff --git a/crates/ty_python_semantic/resources/mdtest/function/return_type.md b/crates/ty_python_semantic/resources/mdtest/function/return_type.md index 2f8efb02d7b28..b26c1eb0ef2b6 100644 --- a/crates/ty_python_semantic/resources/mdtest/function/return_type.md +++ b/crates/ty_python_semantic/resources/mdtest/function/return_type.md @@ -284,7 +284,7 @@ def h(x: int, y: str): elif x > 5: return y -reveal_type(h(1, "a")) # revealed: int | str | None +reveal_type(h(1, "a")) # revealed: int | None | str def generator(): yield 1 @@ -339,7 +339,7 @@ def divergent(value): return None # tuple[tuple[tuple[...] | None] | None] | None => tuple[Unknown] | None -reveal_type(divergent((1,))) # revealed: Divergent | None +reveal_type(divergent((1,))) # revealed: None | Divergent def call_divergent(x: int): return (divergent((1, 2, 3)), x) @@ -360,7 +360,26 @@ def eager_nested_scope(): return A.x -reveal_type(eager_nested_scope()) # revealed: Unknown +reveal_type(eager_nested_scope()) # revealed: Any + +class C: + def flip(self) -> "D": + return D() + +class D(C): + def flip(self) -> "C": + return C() + +def c_or_d(n: int): + if n == 0: + return D() + else: + return c_or_d(n - 1).flip() + +# In fixed-point iteration of the return type inference, the return type is monotonically widened. +# For example, once the return type of `c_or_d` is determined to be `C`, +# it will never be determined to be a subtype `D` in the subsequent iterations. +reveal_type(c_or_d(1)) # revealed: C ``` ### Class method @@ -380,8 +399,8 @@ class D(C): def f(self): return None -reveal_type(C().f()) # revealed: Literal[1] | Unknown -reveal_type(D().f()) # revealed: None | Literal[1] | Unknown +reveal_type(C().f()) # revealed: Literal[1] | Any +reveal_type(D().f()) # revealed: Literal[1] | None | Any ``` However, in the following cases, `Unknown` is not included in the inferred return type because there @@ -456,13 +475,13 @@ reveal_type(C().f()) # revealed: int reveal_type(D().f()) # revealed: int reveal_type(E().f()) # revealed: int reveal_type(C().g(1)) # revealed: Literal[1] -reveal_type(D().g(1)) # revealed: Literal[2] | Unknown +reveal_type(D().g(1)) # revealed: Literal[2] | Any reveal_type(C().h(1)) # revealed: Literal[1] -reveal_type(D().h(1)) # revealed: Literal[2] | Unknown +reveal_type(D().h(1)) # revealed: Literal[2] | Any reveal_type(C().h(True)) # revealed: Literal[True] -reveal_type(D().h(True)) # revealed: Literal[2] | Unknown +reveal_type(D().h(True)) # revealed: Literal[2] | Any reveal_type(C().i(1)) # revealed: list[Literal[1]] -reveal_type(D().i(1)) # revealed: list[@Todo(list literal element type)] +reveal_type(D().i(1)) # revealed: list[Any] class F: def f(self) -> Literal[1, 2]: diff --git a/crates/ty_python_semantic/src/semantic_index/scope.rs b/crates/ty_python_semantic/src/semantic_index/scope.rs index 5341c1f7e7348..5cf6c992d2c1b 100644 --- a/crates/ty_python_semantic/src/semantic_index/scope.rs +++ b/crates/ty_python_semantic/src/semantic_index/scope.rs @@ -1,6 +1,9 @@ use std::ops::Range; -use ruff_db::{files::File, parsed::ParsedModuleRef}; +use ruff_db::{ + files::File, + parsed::{ParsedModuleRef, parsed_module}, +}; use ruff_index::newtype_index; use ruff_python_ast as ast; @@ -70,6 +73,16 @@ impl<'db> ScopeId<'db> { NodeWithScopeKind::GeneratorExpression(_) => "", } } + + pub(crate) fn is_async_function(self, db: &'db dyn Db) -> bool { + let module = parsed_module(db, self.file(db)).load(db); + self.node(db).as_function(&module).is_some_and(|func| { + let index = semantic_index(db, self.file(db)); + let is_generator = self.file_scope_id(db).is_generator_function(index); + + func.is_async && !is_generator + }) + } } /// ID that uniquely identifies a scope inside of a module. diff --git a/crates/ty_python_semantic/src/types.rs b/crates/ty_python_semantic/src/types.rs index 37e1167c39aea..b37ecace7f3c9 100644 --- a/crates/ty_python_semantic/src/types.rs +++ b/crates/ty_python_semantic/src/types.rs @@ -8984,7 +8984,7 @@ impl<'db> BoundMethodType<'db> { .last_definition(db) .body_scope(db); let inference = infer_scope_types(db, scope); - inference.infer_return_type(db, Some(self)) + inference.infer_return_type(db, Type::BoundMethod(self)) } #[salsa::tracked] diff --git a/crates/ty_python_semantic/src/types/builder.rs b/crates/ty_python_semantic/src/types/builder.rs index 2ea753129610d..943055da2ad71 100644 --- a/crates/ty_python_semantic/src/types/builder.rs +++ b/crates/ty_python_semantic/src/types/builder.rs @@ -206,7 +206,7 @@ enum ReduceResult<'db> { // // For now (until we solve https://github.com/astral-sh/ty/issues/957), keep this number // below 200, which is the salsa fixpoint iteration limit. -const MAX_UNION_LITERALS: usize = 199; +const MAX_UNION_LITERALS: usize = 198; pub(crate) struct UnionBuilder<'db> { elements: Vec>, diff --git a/crates/ty_python_semantic/src/types/function.rs b/crates/ty_python_semantic/src/types/function.rs index e3235e3f8f9e3..8ecd89505bdff 100644 --- a/crates/ty_python_semantic/src/types/function.rs +++ b/crates/ty_python_semantic/src/types/function.rs @@ -960,7 +960,7 @@ impl<'db> FunctionType<'db> { pub(crate) fn infer_return_type(self, db: &'db dyn Db) -> Type<'db> { let scope = self.literal(db).last_definition(db).body_scope(db); let inference = infer_scope_types(db, scope); - inference.infer_return_type(db, None) + inference.infer_return_type(db, Type::FunctionLiteral(self)) } } diff --git a/crates/ty_python_semantic/src/types/infer.rs b/crates/ty_python_semantic/src/types/infer.rs index 3ddf2d1e1309a..c5a2479c7157d 100644 --- a/crates/ty_python_semantic/src/types/infer.rs +++ b/crates/ty_python_semantic/src/types/infer.rs @@ -127,10 +127,10 @@ use crate::types::typed_dict::{ }; use crate::types::unpacker::{UnpackResult, Unpacker}; use crate::types::{ - BoundMethodType, CallDunderError, CallableType, ClassLiteral, ClassType, DataclassParams, - DynamicType, IntersectionBuilder, IntersectionType, KnownClass, KnownInstanceType, - LintDiagnosticGuard, MemberLookupPolicy, MetaclassCandidate, PEP695TypeAliasType, Parameter, - ParameterForm, Parameters, SpecialFormType, SubclassOfType, Truthiness, Type, TypeAliasType, + CallDunderError, CallableType, ClassLiteral, ClassType, DataclassParams, DynamicType, + IntersectionBuilder, IntersectionType, KnownClass, KnownInstanceType, LintDiagnosticGuard, + MemberLookupPolicy, MetaclassCandidate, PEP695TypeAliasType, Parameter, ParameterForm, + Parameters, SpecialFormType, SubclassOfType, Truthiness, Type, TypeAliasType, TypeAndQualifiers, TypeIsType, TypeQualifiers, TypeVarBoundOrConstraintsEvaluation, TypeVarDefaultEvaluation, TypeVarInstance, TypeVarKind, UnionBuilder, UnionType, binding_type, todo_type, @@ -580,48 +580,57 @@ impl<'db> ScopeInference<'db> { /// or `None` if the region is not a function body. /// In the case of methods, the return type of the superclass method is further unioned. /// If there is no superclass method and this method is not `final`, it will be unioned with `Unknown`. - pub(crate) fn infer_return_type( - &self, - db: &'db dyn Db, - method_ty: Option>, - ) -> Type<'db> { + pub(crate) fn infer_return_type(&self, db: &'db dyn Db, callee_ty: Type<'db>) -> Type<'db> { + // TODO: async function type inference + if self.scope.is_async_function(db) { + return Type::unknown(); + } + let mut union = UnionBuilder::new(db); - let Some(extra) = &self.extra else { - unreachable!( - "infer_return_type should only be called on a function body scope inference" - ); - }; let div = Type::Dynamic(DynamicType::Divergent); - for returnee in &extra.returnees { - let ty = returnee.map_or(Type::none(db), |expression| { - self.expression_type(expression) - }); - // `Divergent` appearing in a union does not mean true divergence, so it can be removed. + if self.is_cycle_callback() { + union = union.add(div); + } + let mut union_add = |ty: Type<'db>| { + let temp = std::mem::replace(&mut union, UnionBuilder::new(db)); if ty == div { - continue; + // `Divergent` appearing in a union does not mean true divergence, so it can be removed. } else if ty.has_divergent_type(db) { if let Type::Union(union_ty) = ty { let union_ty = union_ty.filter(db, |ty| **ty != div); if union_ty.has_divergent_type(db) { - union = union.add(div); + union = temp.add(div); } else { - union = union.add(union_ty); + union = temp.add(union_ty); } } else { - union = union.add(div); + union = temp.add(div); } } else { - union = union.add(ty); + union = temp.add(ty); } - } - if self.is_cycle_callback() { - union = union.add(div); + }; + let previous_type = callee_ty.infer_return_type(db).unwrap(); + // In fixed-point iteration of return type inference, the return type must be monotonically widened and not "oscillate". + // Here, monotonicity is guaranteed by pre-unioning the type of the previous iteration into the current result. + union_add(previous_type); + + let Some(extra) = &self.extra else { + unreachable!( + "infer_return_type should only be called on a function body scope inference" + ); + }; + for returnee in &extra.returnees { + let ty = returnee.map_or(Type::none(db), |expression| { + self.expression_type(expression) + }); + union_add(ty); } let use_def = use_def_map(db, self.scope); if use_def.can_implicitly_return_none(db) { - union = union.add(Type::none(db)); + union_add(Type::none(db)); } - if let Some(method_ty) = method_ty { + if let Type::BoundMethod(method_ty) = callee_ty { // If the method is not final and the typing is implicit, the inferred return type should be unioned with `Unknown`. // If any method in a base class does not have an annotated return type, `base_return_type` will include `Unknown`. // On the other hand, if the return types of all methods in the base classes are annotated, there is no need to include `Unknown`. @@ -631,34 +640,17 @@ impl<'db> ScopeInference<'db> { .unwrap_or((Signature::unknown(), Type::unknown())); if let Some(generic_context) = signature.generic_context.as_ref() { // If the return type of the base method contains a type variable, replace it with `Unknown` to avoid dangling type variables. - union = union.add( + union_add( return_ty .apply_specialization(db, generic_context.unknown_specialization(db)), ); } else { - union = union.add(return_ty); + union_add(return_ty); } } } - let module = parsed_module(db, self.scope.file(db)).load(db); - if self - .scope - .node(db) - .as_function(&module) - .is_some_and(|func| { - let index = semantic_index(db, self.scope.file(db)); - let is_generator = self.scope.file_scope_id(db).is_generator_function(index); - - func.is_async && !is_generator - }) - { - // TODO: yield/await type inference - KnownClass::CoroutineType - .to_specialized_instance(db, [Type::any(), Type::any(), union.build()]) - } else { - union.build() - } + union.build().normalized(db) } } From fc3ae6f4bc0a424116dde337bcf4fd93212a71cc Mon Sep 17 00:00:00 2001 From: Shunsuke Shibayama Date: Mon, 1 Sep 2025 17:36:03 +0900 Subject: [PATCH 030/105] Update types.rs --- crates/ty_python_semantic/src/types.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/crates/ty_python_semantic/src/types.rs b/crates/ty_python_semantic/src/types.rs index 8644ca13e698b..1ba92962a6586 100644 --- a/crates/ty_python_semantic/src/types.rs +++ b/crates/ty_python_semantic/src/types.rs @@ -46,7 +46,7 @@ use crate::types::diagnostic::{INVALID_AWAIT, INVALID_TYPE_FORM, UNSUPPORTED_BOO pub use crate::types::display::DisplaySettings; use crate::types::enums::{enum_metadata, is_single_member_enum}; use crate::types::function::{ - DataclassTransformerParams, FunctionSpans, FunctionType, KnownFunction, + DataclassTransformerParams, FunctionDecorators, FunctionSpans, FunctionType, KnownFunction, }; use crate::types::generics::{ GenericContext, PartialSpecialization, Specialization, bind_typevar, walk_generic_context, From 3c0034f1a89dcd6acbdf52ac5d4165e8b4e23032 Mon Sep 17 00:00:00 2001 From: Shunsuke Shibayama Date: Mon, 1 Sep 2025 19:52:47 +0900 Subject: [PATCH 031/105] fix `union_add` bug --- .../resources/mdtest/function/return_type.md | 9 +++++++++ crates/ty_python_semantic/src/types/infer.rs | 9 ++++----- 2 files changed, 13 insertions(+), 5 deletions(-) diff --git a/crates/ty_python_semantic/resources/mdtest/function/return_type.md b/crates/ty_python_semantic/resources/mdtest/function/return_type.md index b26c1eb0ef2b6..6651319788969 100644 --- a/crates/ty_python_semantic/resources/mdtest/function/return_type.md +++ b/crates/ty_python_semantic/resources/mdtest/function/return_type.md @@ -347,6 +347,15 @@ def call_divergent(x: int): # TODO: it would be better to reveal `tuple[Divergent | None, int]` reveal_type(call_divergent(1)) # revealed: Divergent +def get_non_empty(node): + for child in node.children: + node = get_non_empty(child) + if node is not None: + return node + return None + +reveal_type(get_non_empty(None)) # revealed: None | Divergent + def nested_scope(): def inner(): return nested_scope() diff --git a/crates/ty_python_semantic/src/types/infer.rs b/crates/ty_python_semantic/src/types/infer.rs index c5a2479c7157d..38e45801b3dfb 100644 --- a/crates/ty_python_semantic/src/types/infer.rs +++ b/crates/ty_python_semantic/src/types/infer.rs @@ -592,22 +592,21 @@ impl<'db> ScopeInference<'db> { union = union.add(div); } let mut union_add = |ty: Type<'db>| { - let temp = std::mem::replace(&mut union, UnionBuilder::new(db)); if ty == div { // `Divergent` appearing in a union does not mean true divergence, so it can be removed. } else if ty.has_divergent_type(db) { if let Type::Union(union_ty) = ty { let union_ty = union_ty.filter(db, |ty| **ty != div); if union_ty.has_divergent_type(db) { - union = temp.add(div); + union = std::mem::replace(&mut union, UnionBuilder::new(db)).add(div); } else { - union = temp.add(union_ty); + union = std::mem::replace(&mut union, UnionBuilder::new(db)).add(union_ty); } } else { - union = temp.add(div); + union = std::mem::replace(&mut union, UnionBuilder::new(db)).add(div); } } else { - union = temp.add(ty); + union = std::mem::replace(&mut union, UnionBuilder::new(db)).add(ty); } }; let previous_type = callee_ty.infer_return_type(db).unwrap(); From 0bc75d47755d0e8f336ad91d5286b39f41533277 Mon Sep 17 00:00:00 2001 From: Shunsuke Shibayama Date: Mon, 1 Sep 2025 20:08:46 +0900 Subject: [PATCH 032/105] increase max number of `pandas` diagnostics --- crates/ruff_benchmark/benches/ty_walltime.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/crates/ruff_benchmark/benches/ty_walltime.rs b/crates/ruff_benchmark/benches/ty_walltime.rs index 491915f1c0da2..1aaaa992c7f05 100644 --- a/crates/ruff_benchmark/benches/ty_walltime.rs +++ b/crates/ruff_benchmark/benches/ty_walltime.rs @@ -164,7 +164,7 @@ static PANDAS: std::sync::LazyLock> = std::sync::LazyLock::ne max_dep_date: "2025-06-17", python_version: PythonVersion::PY312, }, - 3000, + 3300, ) }); From 90ec05ec1afa999a6cfb47cac0d42ac79577cbea Mon Sep 17 00:00:00 2001 From: Shunsuke Shibayama Date: Mon, 1 Sep 2025 20:44:16 +0900 Subject: [PATCH 033/105] No type normalization in `infer_return_type` --- .../resources/mdtest/function/return_type.md | 16 ++++++++-------- crates/ty_python_semantic/src/types/infer.rs | 2 +- 2 files changed, 9 insertions(+), 9 deletions(-) diff --git a/crates/ty_python_semantic/resources/mdtest/function/return_type.md b/crates/ty_python_semantic/resources/mdtest/function/return_type.md index 6651319788969..674deff962300 100644 --- a/crates/ty_python_semantic/resources/mdtest/function/return_type.md +++ b/crates/ty_python_semantic/resources/mdtest/function/return_type.md @@ -284,7 +284,7 @@ def h(x: int, y: str): elif x > 5: return y -reveal_type(h(1, "a")) # revealed: int | None | str +reveal_type(h(1, "a")) # revealed: int | str | None def generator(): yield 1 @@ -369,7 +369,7 @@ def eager_nested_scope(): return A.x -reveal_type(eager_nested_scope()) # revealed: Any +reveal_type(eager_nested_scope()) # revealed: Unknown class C: def flip(self) -> "D": @@ -408,8 +408,8 @@ class D(C): def f(self): return None -reveal_type(C().f()) # revealed: Literal[1] | Any -reveal_type(D().f()) # revealed: Literal[1] | None | Any +reveal_type(C().f()) # revealed: Literal[1] | Unknown +reveal_type(D().f()) # revealed: None | Literal[1] | Unknown ``` However, in the following cases, `Unknown` is not included in the inferred return type because there @@ -484,13 +484,13 @@ reveal_type(C().f()) # revealed: int reveal_type(D().f()) # revealed: int reveal_type(E().f()) # revealed: int reveal_type(C().g(1)) # revealed: Literal[1] -reveal_type(D().g(1)) # revealed: Literal[2] | Any +reveal_type(D().g(1)) # revealed: Literal[2] | Unknown reveal_type(C().h(1)) # revealed: Literal[1] -reveal_type(D().h(1)) # revealed: Literal[2] | Any +reveal_type(D().h(1)) # revealed: Literal[2] | Unknown reveal_type(C().h(True)) # revealed: Literal[True] -reveal_type(D().h(True)) # revealed: Literal[2] | Any +reveal_type(D().h(True)) # revealed: Literal[2] | Unknown reveal_type(C().i(1)) # revealed: list[Literal[1]] -reveal_type(D().i(1)) # revealed: list[Any] +reveal_type(D().i(1)) # revealed: list[@Todo(list literal element type)] class F: def f(self) -> Literal[1, 2]: diff --git a/crates/ty_python_semantic/src/types/infer.rs b/crates/ty_python_semantic/src/types/infer.rs index 38e45801b3dfb..3930724712143 100644 --- a/crates/ty_python_semantic/src/types/infer.rs +++ b/crates/ty_python_semantic/src/types/infer.rs @@ -649,7 +649,7 @@ impl<'db> ScopeInference<'db> { } } - union.build().normalized(db) + union.build() } } From 6581f3b59381e293a25d8a9aff5e1aa113fa3bd7 Mon Sep 17 00:00:00 2001 From: Shunsuke Shibayama Date: Mon, 1 Sep 2025 22:56:24 +0900 Subject: [PATCH 034/105] move `sympy` to `good.txt` --- crates/ty_python_semantic/resources/primer/bad.txt | 1 - crates/ty_python_semantic/resources/primer/good.txt | 1 + 2 files changed, 1 insertion(+), 1 deletion(-) diff --git a/crates/ty_python_semantic/resources/primer/bad.txt b/crates/ty_python_semantic/resources/primer/bad.txt index c598405f9fcb6..d9da06edd3a50 100644 --- a/crates/ty_python_semantic/resources/primer/bad.txt +++ b/crates/ty_python_semantic/resources/primer/bad.txt @@ -18,5 +18,4 @@ setuptools # vendors packaging, see above spack # slow, success, but mypy-primer hangs processing the output spark # too many iterations steam.py # hangs (single threaded) -sympy # stack overflow (multi threaded) streamlit # too many iterations (uses packaging) diff --git a/crates/ty_python_semantic/resources/primer/good.txt b/crates/ty_python_semantic/resources/primer/good.txt index cc3b235caf928..2c1694edd992a 100644 --- a/crates/ty_python_semantic/resources/primer/good.txt +++ b/crates/ty_python_semantic/resources/primer/good.txt @@ -110,6 +110,7 @@ static-frame stone strawberry svcs +sympy tornado trio twine From f6418a60dceefa4e719ea17cb5d8177914d109e3 Mon Sep 17 00:00:00 2001 From: Shunsuke Shibayama Date: Tue, 2 Sep 2025 00:30:47 +0900 Subject: [PATCH 035/105] The return type of a generator function is `Unknown` for now --- .../mdtest/diagnostics/semantic_syntax_errors.md | 2 -- .../resources/mdtest/function/return_type.md | 2 +- .../ty_python_semantic/src/semantic_index/scope.rs | 14 ++++++++------ crates/ty_python_semantic/src/types/infer.rs | 5 +++-- 4 files changed, 12 insertions(+), 11 deletions(-) diff --git a/crates/ty_python_semantic/resources/mdtest/diagnostics/semantic_syntax_errors.md b/crates/ty_python_semantic/resources/mdtest/diagnostics/semantic_syntax_errors.md index 48c31d78300a0..a77efb3b96ac8 100644 --- a/crates/ty_python_semantic/resources/mdtest/diagnostics/semantic_syntax_errors.md +++ b/crates/ty_python_semantic/resources/mdtest/diagnostics/semantic_syntax_errors.md @@ -148,8 +148,6 @@ yield from [] await C() def f(): - # TODO: no error, C is awaitable - # error: [invalid-await] "`C` is not awaitable" # error: [invalid-syntax] "`await` outside of an asynchronous function" await C() ``` diff --git a/crates/ty_python_semantic/resources/mdtest/function/return_type.md b/crates/ty_python_semantic/resources/mdtest/function/return_type.md index 674deff962300..636127f6c6a9b 100644 --- a/crates/ty_python_semantic/resources/mdtest/function/return_type.md +++ b/crates/ty_python_semantic/resources/mdtest/function/return_type.md @@ -292,7 +292,7 @@ def generator(): return None # TODO: Should be `Generator[Literal[1, 2], Any, None]` -reveal_type(generator()) # revealed: None +reveal_type(generator()) # revealed: Unknown ``` The return type of a recursive function is also inferred. When the return type inference would diff --git a/crates/ty_python_semantic/src/semantic_index/scope.rs b/crates/ty_python_semantic/src/semantic_index/scope.rs index 5cf6c992d2c1b..a323174e3ea61 100644 --- a/crates/ty_python_semantic/src/semantic_index/scope.rs +++ b/crates/ty_python_semantic/src/semantic_index/scope.rs @@ -74,14 +74,16 @@ impl<'db> ScopeId<'db> { } } - pub(crate) fn is_async_function(self, db: &'db dyn Db) -> bool { + pub(crate) fn is_coroutine_function(self, db: &'db dyn Db) -> bool { let module = parsed_module(db, self.file(db)).load(db); - self.node(db).as_function(&module).is_some_and(|func| { - let index = semantic_index(db, self.file(db)); - let is_generator = self.file_scope_id(db).is_generator_function(index); + self.node(db) + .as_function(&module) + .is_some_and(|func| func.is_async && !self.is_generator_function(db)) + } - func.is_async && !is_generator - }) + pub(crate) fn is_generator_function(self, db: &'db dyn Db) -> bool { + let index = semantic_index(db, self.file(db)); + self.file_scope_id(db).is_generator_function(index) } } diff --git a/crates/ty_python_semantic/src/types/infer.rs b/crates/ty_python_semantic/src/types/infer.rs index 3930724712143..24a51964b10c1 100644 --- a/crates/ty_python_semantic/src/types/infer.rs +++ b/crates/ty_python_semantic/src/types/infer.rs @@ -581,8 +581,9 @@ impl<'db> ScopeInference<'db> { /// In the case of methods, the return type of the superclass method is further unioned. /// If there is no superclass method and this method is not `final`, it will be unioned with `Unknown`. pub(crate) fn infer_return_type(&self, db: &'db dyn Db, callee_ty: Type<'db>) -> Type<'db> { - // TODO: async function type inference - if self.scope.is_async_function(db) { + // TODO: coroutine function type inference + // TODO: generator function type inference + if self.scope.is_coroutine_function(db) || self.scope.is_generator_function(db) { return Type::unknown(); } From acb911857cb4031a5483e7fcfd6a9f25ace8fe8e Mon Sep 17 00:00:00 2001 From: Shunsuke Shibayama Date: Tue, 2 Sep 2025 10:52:06 +0900 Subject: [PATCH 036/105] Update types.rs --- crates/ty_python_semantic/src/types.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/crates/ty_python_semantic/src/types.rs b/crates/ty_python_semantic/src/types.rs index 1ba92962a6586..b69a1dabbbaa3 100644 --- a/crates/ty_python_semantic/src/types.rs +++ b/crates/ty_python_semantic/src/types.rs @@ -8984,7 +8984,7 @@ impl<'db> BoundMethodType<'db> { inference.infer_return_type(db, Type::BoundMethod(self)) } - #[salsa::tracked] + #[salsa::tracked(heap_size=ruff_memory_usage::heap_size)] fn class_definition(self, db: &'db dyn Db) -> Option> { let definition_scope = self.function(db).definition(db).scope(db); let index = semantic_index(db, definition_scope.file(db)); From 9b1c373f5dc421d4c9c5bf9e5aab2af35d2ce772 Mon Sep 17 00:00:00 2001 From: Shunsuke Shibayama Date: Tue, 2 Sep 2025 13:20:20 +0900 Subject: [PATCH 037/105] add doc comments --- crates/ty_python_semantic/src/types/infer.rs | 16 ++++++++++++++++ 1 file changed, 16 insertions(+) diff --git a/crates/ty_python_semantic/src/types/infer.rs b/crates/ty_python_semantic/src/types/infer.rs index 24a51964b10c1..f3a8d91d21b6d 100644 --- a/crates/ty_python_semantic/src/types/infer.rs +++ b/crates/ty_python_semantic/src/types/infer.rs @@ -592,6 +592,22 @@ impl<'db> ScopeInference<'db> { if self.is_cycle_callback() { union = union.add(div); } + // Here, we use the dynamic type `Divergent` to detect divergent type inference and ensure that we obtain finite results. + // For example, consider the following recursive function: + // ```py + // def div(n: int): + // if n == 0: + // return None + // else: + // return (div(n-1),) + // ``` + // If we try to infer the return type of this function naively, we will get `tuple[tuple[tuple[...] | None] | None] | None`, which never converges. + // So, when we detect a cycle, we set the cycle initial type to `Divergent`. Then the type obtained in the first cycle is `tuple[Divergent] | None`. + // Next, if there is a type containing `Divergent`, we replace it with the `Divergent` type itself. + // All types containing `Divergent` are flattened in the next cycle, resulting in a convergence of the return type in finite cycles. + // 0th: Divergent + // 1st: tuple[Divergent] | None => Divergent | None + // 2nd: tuple[Divergent | None] | None => Divergent | None let mut union_add = |ty: Type<'db>| { if ty == div { // `Divergent` appearing in a union does not mean true divergence, so it can be removed. From 2d1b8a311f2491cbe922e591be3e6d33a687bea1 Mon Sep 17 00:00:00 2001 From: Shunsuke Shibayama Date: Tue, 2 Sep 2025 16:54:33 +0900 Subject: [PATCH 038/105] Add unit tests for the `Divergent` type --- crates/ty_python_semantic/src/types.rs | 68 +++++++++++++++++++ .../ty_python_semantic/src/types/builder.rs | 24 ++++++- 2 files changed, 90 insertions(+), 2 deletions(-) diff --git a/crates/ty_python_semantic/src/types.rs b/crates/ty_python_semantic/src/types.rs index b69a1dabbbaa3..5a1735e088268 100644 --- a/crates/ty_python_semantic/src/types.rs +++ b/crates/ty_python_semantic/src/types.rs @@ -10799,4 +10799,72 @@ pub(crate) mod tests { .is_todo() ); } + + #[test] + fn divergent_type() { + let db = setup_db(); + + let div = Type::Dynamic(DynamicType::Divergent); + + let union = UnionType::from_elements(&db, [Type::unknown(), div]); + assert_eq!(union.display(&db).to_string(), "Unknown | Divergent"); + + let union = UnionType::from_elements(&db, [div, Type::unknown()]); + assert_eq!(union.display(&db).to_string(), "Divergent"); + + let union = UnionType::from_elements(&db, [div, KnownClass::Object.to_instance(&db)]); + assert_eq!(union.display(&db).to_string(), "object | Divergent"); + + let union = UnionType::from_elements(&db, [KnownClass::Object.to_instance(&db), div]); + assert_eq!(union.display(&db).to_string(), "object | Divergent"); + + let union = UnionType::from_elements( + &db, + [ + KnownClass::Object.to_instance(&db), + KnownClass::List.to_specialized_instance(&db, [div]), + ], + ); + assert_eq!(union.display(&db).to_string(), "object | list[Divergent]"); + + let union = UnionType::from_elements( + &db, + [ + KnownClass::Object.to_instance(&db), + KnownClass::List.to_specialized_instance(&db, [div]), + KnownClass::Int.to_instance(&db), + ], + ); + assert_eq!(union.display(&db).to_string(), "object | list[Divergent]"); + + let intersection = IntersectionBuilder::new(&db) + .add_positive(Type::Never) + .add_positive(div) + .build(); + assert_eq!(intersection.display(&db).to_string(), "Never & Divergent"); + + let intersection = IntersectionBuilder::new(&db) + .add_positive(div) + .add_positive(Type::Never) + .build(); + assert_eq!(intersection.display(&db).to_string(), "Divergent & Never"); + + let intersection = IntersectionBuilder::new(&db) + .add_positive(KnownClass::List.to_specialized_instance(&db, [div])) + .add_positive(Type::Never) + .build(); + assert_eq!( + intersection.display(&db).to_string(), + "list[Divergent] & Never" + ); + + let intersection = IntersectionBuilder::new(&db) + .add_positive(KnownClass::Int.to_instance(&db)) + .add_positive(KnownClass::List.to_specialized_instance(&db, [div])) + .build(); + assert_eq!( + intersection.display(&db).to_string(), + "int & list[Divergent]" + ); + } } diff --git a/crates/ty_python_semantic/src/types/builder.rs b/crates/ty_python_semantic/src/types/builder.rs index 943055da2ad71..39f2d652730a6 100644 --- a/crates/ty_python_semantic/src/types/builder.rs +++ b/crates/ty_python_semantic/src/types/builder.rs @@ -241,9 +241,22 @@ impl<'db> UnionBuilder<'db> { /// Collapse the union to a single type: `object`. fn collapse_to_object(&mut self) { + let divergent = self.elements.iter().find_map(|elem| { + let UnionElement::Type(ty) = elem else { + return None; + }; + if ty.has_divergent_type(self.db) { + Some(*ty) + } else { + None + } + }); self.elements.clear(); self.elements .push(UnionElement::Type(Type::object(self.db))); + if let Some(divergent) = divergent { + self.elements.push(UnionElement::Type(divergent)); + } } /// Adds a type to this union. @@ -964,11 +977,18 @@ impl<'db> InnerIntersectionBuilder<'db> { return; } // same rule, reverse order - if new_positive.is_subtype_of(db, *existing_positive) { + if new_positive.is_subtype_of(db, *existing_positive) + && !existing_positive.has_divergent_type(db) + { to_remove.push(index); } // A & B = Never if A and B are disjoint - if new_positive.is_disjoint_from(db, *existing_positive) { + if new_positive.is_disjoint_from(db, *existing_positive) + && self + .positive + .iter() + .all(|existing_positive| !existing_positive.has_divergent_type(db)) + { *self = Self::default(); self.positive.insert(Type::Never); return; From a67b4349e7dddf83efee63229a8c32fbbca70ff5 Mon Sep 17 00:00:00 2001 From: Shunsuke Shibayama Date: Tue, 2 Sep 2025 19:58:12 +0900 Subject: [PATCH 039/105] remove `divergence_safe_{todo, unknown}` --- crates/ty_python_semantic/src/types.rs | 14 +--- crates/ty_python_semantic/src/types/infer.rs | 85 +++++--------------- 2 files changed, 23 insertions(+), 76 deletions(-) diff --git a/crates/ty_python_semantic/src/types.rs b/crates/ty_python_semantic/src/types.rs index 5a1735e088268..abe9f62617b74 100644 --- a/crates/ty_python_semantic/src/types.rs +++ b/crates/ty_python_semantic/src/types.rs @@ -57,7 +57,7 @@ pub use crate::types::ide_support::{ definitions_for_attribute, definitions_for_imported_symbol, definitions_for_keyword_argument, definitions_for_name, find_active_signature_from_details, inlay_hint_function_argument_details, }; -use crate::types::infer::{divergence_safe_todo, infer_unpack_types}; +use crate::types::infer::infer_unpack_types; use crate::types::mro::{Mro, MroError, MroIterator}; pub(crate) use crate::types::narrow::infer_narrowing_constraint; use crate::types::signatures::{Parameter, ParameterForm, Parameters, walk_signature}; @@ -5037,11 +5037,7 @@ impl<'db> Type<'db> { let special_case = match self { Type::NominalInstance(nominal) => nominal.tuple_spec(db), Type::GenericAlias(alias) if alias.origin(db).is_tuple(db) => { - Some(Cow::Owned(TupleSpec::homogeneous(divergence_safe_todo( - db, - "*tuple[] annotations", - [self], - )))) + Some(Cow::Owned(TupleSpec::homogeneous(todo_type!("*tuple[] annotations")))) } Type::StringLiteral(string_literal_ty) => { let string_literal = string_literal_ty.value(db); @@ -5554,11 +5550,7 @@ impl<'db> Type<'db> { } Type::TypeVar(bound_typevar) => Some(Type::TypeVar(bound_typevar.to_instance(db)?)), Type::TypeAlias(alias) => alias.value_type(db).to_instance(db), - Type::Intersection(_) => Some(divergence_safe_todo( - db, - "Type::Intersection.to_instance", - [self], - )), + Type::Intersection(_) => Some(todo_type!("Type::Intersection.to_instance")), Type::BooleanLiteral(_) | Type::BytesLiteral(_) | Type::EnumLiteral(_) diff --git a/crates/ty_python_semantic/src/types/infer.rs b/crates/ty_python_semantic/src/types/infer.rs index f3a8d91d21b6d..f6f85e13cc544 100644 --- a/crates/ty_python_semantic/src/types/infer.rs +++ b/crates/ty_python_semantic/src/types/infer.rs @@ -140,34 +140,6 @@ use crate::util::diagnostics::format_enumeration; use crate::util::subscript::{PyIndex, PySlice}; use crate::{Db, FxOrderSet, Program}; -pub(crate) fn divergence_safe_todo<'db>( - db: &'db dyn Db, - msg: &'static str, - types: impl IntoIterator>, -) -> Type<'db> { - let _ = msg; - let mut builder = IntersectionBuilder::new(db).add_positive(todo_type!(msg)); - for ty in types { - if ty.has_divergent_type(db) { - builder = builder.add_positive(ty); - } - } - builder.build() -} - -fn divergence_safe_unknown<'db>( - db: &'db dyn Db, - types: impl IntoIterator>, -) -> Type<'db> { - let mut builder = IntersectionBuilder::new(db).add_positive(Type::unknown()); - for ty in types { - if ty.has_divergent_type(db) { - builder = builder.add_positive(ty); - } - } - builder.build() -} - /// Infer all types for a [`ScopeId`], including all definitions and expressions in that scope. /// Use when checking a scope, or needing to provide a type for an arbitrary expression in the /// scope. @@ -3570,7 +3542,7 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> { if let Some(node) = node { report_invalid_exception_caught(&self.context, node, element); } - divergence_safe_unknown(self.db(), [element]) + Type::unknown() }, ); } @@ -3616,7 +3588,7 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> { if let Some(node) = node { report_invalid_exception_caught(&self.context, node, node_ty); } - divergence_safe_unknown(self.db(), [node_ty]) + Type::unknown() }; if is_star { @@ -4970,7 +4942,7 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> { self.infer_binary_expression_type(assignment.into(), false, target_type, value_type, op) .unwrap_or_else(|| { report_unsupported_augmented_op(&mut self.context); - divergence_safe_unknown(self.db(), [target_type, value_type]) + Type::unknown() }) }; @@ -7653,7 +7625,7 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> { add_inferred_python_version_hint_to_diagnostic(db, &mut diag, "resolving types"); } } - divergence_safe_unknown(db, [left_ty, right_ty]) + Type::unknown() }) } @@ -8182,7 +8154,7 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> { | ast::CmpOp::Is | ast::CmpOp::IsNot => KnownClass::Bool.to_instance(builder.db()), // Other operators can return arbitrary types - _ => divergence_safe_unknown(builder.db(), [left_ty, right_ty]), + _ => Type::unknown(), } }); @@ -8831,10 +8803,7 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> { // TODO: Consider comparing the prefixes of the tuples, since that could give a comparison // result regardless of how long the variable-length tuple is. let (TupleSpec::Fixed(left), TupleSpec::Fixed(right)) = (left, right) else { - return Ok(divergence_safe_unknown( - self.db(), - left.all_elements().chain(right.all_elements()).copied(), - )); + return Ok(Type::unknown()); }; let left_iter = left.elements().copied(); @@ -9065,11 +9034,9 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> { // but we need to make sure we avoid emitting a diagnostic if one positive element has a `__getitem__` // method but another does not. This means `infer_subscript_expression_types` // needs to return a `Result` rather than eagerly emitting diagnostics. - (Type::Intersection(_), _) => Some(divergence_safe_todo( - db, - "Subscript expressions on intersections", - [value_ty, slice_ty], - )), + (Type::Intersection(_), _) => { + Some(todo_type!("Subscript expressions on intersections")) + } // Ex) Given `("a", "b", "c", "d")[1]`, return `"b"` (Type::NominalInstance(nominal), Type::IntLiteral(i64_int)) => nominal @@ -9085,7 +9052,7 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> { tuple.len().display_minimum(), i64_int, ); - divergence_safe_unknown(db, [value_ty]) + Type::unknown() }) }), @@ -9103,7 +9070,7 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> { Type::heterogeneous_tuple(db, new_elements) } else { report_slice_step_size_zero(context, value_node.into()); - divergence_safe_unknown(self.db(), [value_ty, slice_ty]) + Type::unknown() } } TupleSpec::Variable(_) => { @@ -9144,7 +9111,7 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> { Type::string_literal(db, &literal) } else { report_slice_step_size_zero(context, value_node.into()); - divergence_safe_unknown(self.db(), [slice_ty]) + Type::unknown() } }), @@ -9180,7 +9147,7 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> { Type::bytes_literal(db, &new_bytes) } else { report_slice_step_size_zero(context, value_node.into()); - divergence_safe_unknown(self.db(), [slice_ty]) + Type::unknown() } }), @@ -9217,11 +9184,7 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> { (Type::KnownInstance(KnownInstanceType::SubscriptedProtocol(_)), _) => { // TODO: emit a diagnostic - Some(divergence_safe_todo( - db, - "doubly-specialized typing.Protocol", - [value_ty, slice_ty], - )) + Some(todo_type!("doubly-specialized typing.Protocol")) } (Type::SpecialForm(SpecialFormType::Generic), typevars) => Some( @@ -9234,28 +9197,20 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> { (Type::KnownInstance(KnownInstanceType::SubscriptedGeneric(_)), _) => { // TODO: emit a diagnostic - Some(divergence_safe_todo( - db, - "doubly-specialized typing.Generic", - [value_ty, slice_ty], - )) + Some(todo_type!("doubly-specialized typing.Generic")) } (Type::SpecialForm(SpecialFormType::Unpack), _) => { Some(Type::Dynamic(DynamicType::TodoUnpack)) } - (Type::SpecialForm(special_form), _) if special_form.class().is_special_form() => Some( - divergence_safe_todo(db, "Inference of subscript on special form", [slice_ty]), - ), + (Type::SpecialForm(special_form), _) if special_form.class().is_special_form() => { + Some(todo_type!("Inference of subscript on special form")) + } (Type::KnownInstance(known_instance), _) if known_instance.class().is_special_form() => { - Some(divergence_safe_todo( - db, - "Inference of subscript on special form", - [slice_ty], - )) + Some(todo_type!("Inference of subscript on special form")) } _ => None, @@ -9416,7 +9371,7 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> { } } - divergence_safe_unknown(self.db(), [value_ty, slice_ty]) + Type::unknown() } fn legacy_generic_class_context( From bf4c1f1fce80dcc67acc7ca98c033a9f6d3cea30 Mon Sep 17 00:00:00 2001 From: Shunsuke Shibayama Date: Tue, 2 Sep 2025 20:18:36 +0900 Subject: [PATCH 040/105] remove `function_place` --- crates/ty_python_semantic/src/types.rs | 1 - crates/ty_python_semantic/src/types/infer.rs | 14 +------------- 2 files changed, 1 insertion(+), 14 deletions(-) diff --git a/crates/ty_python_semantic/src/types.rs b/crates/ty_python_semantic/src/types.rs index abe9f62617b74..9a1db3bf068df 100644 --- a/crates/ty_python_semantic/src/types.rs +++ b/crates/ty_python_semantic/src/types.rs @@ -5959,7 +5959,6 @@ impl<'db> Type<'db> { Type::StringLiteral(_) | Type::LiteralString => KnownClass::Str.to_class_literal(db), Type::Dynamic(dynamic) => SubclassOfType::from(db, SubclassOfInner::Dynamic(dynamic)), // TODO intersections - // TODO divergence safety Type::Intersection(_) => SubclassOfType::from( db, SubclassOfInner::try_from_type(db, todo_type!("Intersection meta-type")) diff --git a/crates/ty_python_semantic/src/types/infer.rs b/crates/ty_python_semantic/src/types/infer.rs index f6f85e13cc544..713d27564a24f 100644 --- a/crates/ty_python_semantic/src/types/infer.rs +++ b/crates/ty_python_semantic/src/types/infer.rs @@ -83,7 +83,7 @@ use crate::semantic_index::definition::{ }; use crate::semantic_index::expression::{Expression, ExpressionKind}; use crate::semantic_index::narrowing_constraints::ConstraintKey; -use crate::semantic_index::place::{PlaceExpr, PlaceExprRef, ScopedPlaceId}; +use crate::semantic_index::place::{PlaceExpr, PlaceExprRef}; use crate::semantic_index::scope::{ FileScopeId, NodeWithScopeKind, NodeWithScopeRef, ScopeId, ScopeKind, }; @@ -170,18 +170,6 @@ fn scope_cycle_initial<'db>(_db: &'db dyn Db, scope: ScopeId<'db>) -> ScopeInfer ScopeInference::cycle_fallback(scope) } -#[salsa::tracked] -fn function_place<'db>(db: &'db dyn Db, scope: ScopeId<'db>) -> Option { - if let NodeWithScopeKind::Function(func) = scope.node(db) { - let file = scope.file(db); - let index = semantic_index(db, file); - let module = parsed_module(db, file).load(db); - Some(index.expect_single_definition(func.node(&module)).place(db)) - } else { - None - } -} - /// Infer all types for a [`Definition`] (including sub-expressions). /// Use when resolving a place use or public type of a place. #[salsa::tracked(returns(ref), cycle_fn=definition_cycle_recover, cycle_initial=definition_cycle_initial, heap_size=ruff_memory_usage::heap_size)] From bc6793cafa0b87b01120a012bdb3249047a2c9f8 Mon Sep 17 00:00:00 2001 From: Shunsuke Shibayama Date: Wed, 3 Sep 2025 01:09:57 +0900 Subject: [PATCH 041/105] make it clear that type inference for lambda expressions is not yet performed --- .../resources/mdtest/function/return_type.md | 5 +++++ crates/ty_python_semantic/src/semantic_index/scope.rs | 8 ++------ crates/ty_python_semantic/src/types/infer.rs | 2 +- 3 files changed, 8 insertions(+), 7 deletions(-) diff --git a/crates/ty_python_semantic/resources/mdtest/function/return_type.md b/crates/ty_python_semantic/resources/mdtest/function/return_type.md index 636127f6c6a9b..e22e9f407bb6f 100644 --- a/crates/ty_python_semantic/resources/mdtest/function/return_type.md +++ b/crates/ty_python_semantic/resources/mdtest/function/return_type.md @@ -286,6 +286,11 @@ def h(x: int, y: str): reveal_type(h(1, "a")) # revealed: int | str | None +lambda_func = lambda: 1 +# TODO: lambda function type inference +# Should be `Literal[1]` +reveal_type(lambda_func()) # revealed: Unknown + def generator(): yield 1 yield 2 diff --git a/crates/ty_python_semantic/src/semantic_index/scope.rs b/crates/ty_python_semantic/src/semantic_index/scope.rs index a323174e3ea61..8e4c7d2e6cafc 100644 --- a/crates/ty_python_semantic/src/semantic_index/scope.rs +++ b/crates/ty_python_semantic/src/semantic_index/scope.rs @@ -32,8 +32,8 @@ impl<'db> ScopeId<'db> { self.node(db).scope_kind().is_function_like() } - pub(crate) fn is_function_or_lambda(self, db: &'db dyn Db) -> bool { - self.node(db).scope_kind().is_function_or_lambda() + pub(crate) fn is_non_lambda_function(self, db: &'db dyn Db) -> bool { + self.node(db).scope_kind().is_non_lambda_function() } pub(crate) fn is_annotation(self, db: &'db dyn Db) -> bool { @@ -266,10 +266,6 @@ impl ScopeKind { ) } - pub(crate) const fn is_function_or_lambda(self) -> bool { - matches!(self, ScopeKind::Function | ScopeKind::Lambda) - } - pub(crate) const fn is_class(self) -> bool { matches!(self, ScopeKind::Class) } diff --git a/crates/ty_python_semantic/src/types/infer.rs b/crates/ty_python_semantic/src/types/infer.rs index 713d27564a24f..b748ab05fc03c 100644 --- a/crates/ty_python_semantic/src/types/infer.rs +++ b/crates/ty_python_semantic/src/types/infer.rs @@ -9636,7 +9636,7 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> { let diagnostics = context.finish(); - let extra = (!diagnostics.is_empty() || cycle_fallback || scope.is_function_or_lambda(db)) + let extra = (!diagnostics.is_empty() || cycle_fallback || scope.is_non_lambda_function(db)) .then(|| { let returnees = returnees .into_iter() From 480e64a82bcd415369de5de16305cbe45af44bf5 Mon Sep 17 00:00:00 2001 From: Shunsuke Shibayama Date: Wed, 3 Sep 2025 01:16:11 +0900 Subject: [PATCH 042/105] revert an unnecessary change --- crates/ty_python_semantic/src/types/builder.rs | 2 +- crates/ty_python_semantic/src/types/class.rs | 15 +-------------- 2 files changed, 2 insertions(+), 15 deletions(-) diff --git a/crates/ty_python_semantic/src/types/builder.rs b/crates/ty_python_semantic/src/types/builder.rs index 39f2d652730a6..8d2ddd60ada7e 100644 --- a/crates/ty_python_semantic/src/types/builder.rs +++ b/crates/ty_python_semantic/src/types/builder.rs @@ -206,7 +206,7 @@ enum ReduceResult<'db> { // // For now (until we solve https://github.com/astral-sh/ty/issues/957), keep this number // below 200, which is the salsa fixpoint iteration limit. -const MAX_UNION_LITERALS: usize = 198; +const MAX_UNION_LITERALS: usize = 199; pub(crate) struct UnionBuilder<'db> { elements: Vec>, diff --git a/crates/ty_python_semantic/src/types/class.rs b/crates/ty_python_semantic/src/types/class.rs index e50edbb6792b5..834690633203d 100644 --- a/crates/ty_python_semantic/src/types/class.rs +++ b/crates/ty_python_semantic/src/types/class.rs @@ -175,19 +175,6 @@ fn try_metaclass_cycle_initial<'db>( }) } -fn into_callable_cycle_recover<'db>( - _db: &'db dyn Db, - _value: &Type<'db>, - _count: u32, - _self: ClassType<'db>, -) -> salsa::CycleRecoveryAction> { - salsa::CycleRecoveryAction::Iterate -} - -fn into_callable_cycle_initial<'db>(_db: &'db dyn Db, _self: ClassType<'db>) -> Type<'db> { - Type::Never -} - /// A category of classes with code generation capabilities (with synthesized methods). #[derive(Clone, Copy, Debug, PartialEq, salsa::Update, get_size2::GetSize)] pub(crate) enum CodeGeneratorKind { @@ -1075,7 +1062,7 @@ impl<'db> ClassType<'db> { /// Return a callable type (or union of callable types) that represents the callable /// constructor signature of this class. - #[salsa::tracked(cycle_fn=into_callable_cycle_recover, cycle_initial=into_callable_cycle_initial, heap_size=ruff_memory_usage::heap_size)] + #[salsa::tracked(heap_size=ruff_memory_usage::heap_size)] pub(super) fn into_callable(self, db: &'db dyn Db) -> Type<'db> { let self_ty = Type::from(self); let metaclass_dunder_call_function_symbol = self_ty From 5f38dcb3b07cbf51025851aceb05dc12bb0a1df6 Mon Sep 17 00:00:00 2001 From: Shunsuke Shibayama Date: Wed, 3 Sep 2025 01:32:56 +0900 Subject: [PATCH 043/105] simply call the tracked function instead of using `LazyCell` --- crates/ty_python_semantic/src/types/narrow.rs | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/crates/ty_python_semantic/src/types/narrow.rs b/crates/ty_python_semantic/src/types/narrow.rs index 98596b2d84721..f7b06a7a829fb 100644 --- a/crates/ty_python_semantic/src/types/narrow.rs +++ b/crates/ty_python_semantic/src/types/narrow.rs @@ -22,7 +22,6 @@ use itertools::Itertools; use ruff_python_ast as ast; use ruff_python_ast::{BoolOp, ExprBoolOp}; use rustc_hash::FxHashMap; -use std::cell::LazyCell; use std::collections::hash_map::Entry; use super::UnionType; @@ -709,8 +708,6 @@ impl<'db, 'ast> NarrowingConstraintsBuilder<'db, 'ast> { // and that requires cross-symbol constraints, which we don't support yet. return None; } - // Performance optimization: deferring type inference for an expression until it is actually needed. - let inference = LazyCell::new(|| infer_expression_types(self.db, expression)); let comparator_tuples = std::iter::once(&**left) .chain(comparators) @@ -726,6 +723,7 @@ impl<'db, 'ast> NarrowingConstraintsBuilder<'db, 'ast> { if let Some(left_place) = place_expr(left) { let op = if is_positive { *op } else { op.negate() }; + let inference = infer_expression_types(self.db, expression); let lhs_ty = inference.expression_type(left); let rhs_ty = inference.expression_type(right); @@ -761,6 +759,7 @@ impl<'db, 'ast> NarrowingConstraintsBuilder<'db, 'ast> { op == &ast::CmpOp::IsNot }; + let inference = infer_expression_types(self.db, expression); let rhs_ty = inference.expression_type(right); let Type::ClassLiteral(rhs_class) = rhs_ty else { continue; From 7326e47d7bed3d28b6f809e46a1ff26952230071 Mon Sep 17 00:00:00 2001 From: Shunsuke Shibayama Date: Wed, 3 Sep 2025 01:48:46 +0900 Subject: [PATCH 044/105] Revert "revert an unnecessary change" This reverts commit 480e64a82bcd415369de5de16305cbe45af44bf5. --- .../resources/corpus/cycle_into_callable.py | 17 +++++++++++++++++ crates/ty_python_semantic/src/types/class.rs | 15 ++++++++++++++- 2 files changed, 31 insertions(+), 1 deletion(-) create mode 100644 crates/ty_python_semantic/resources/corpus/cycle_into_callable.py diff --git a/crates/ty_python_semantic/resources/corpus/cycle_into_callable.py b/crates/ty_python_semantic/resources/corpus/cycle_into_callable.py new file mode 100644 index 0000000000000..ce4cd6a795d02 --- /dev/null +++ b/crates/ty_python_semantic/resources/corpus/cycle_into_callable.py @@ -0,0 +1,17 @@ +# Regression test for https://github.com/astral-sh/ruff/issues/17371 +# panicked in commit d1088545a08aeb57b67ec1e3a7f5141159efefa5 +# error message: +# dependency graph cycle when querying ClassType < 'db >::into_callable_(Id(1c00)) + +try: + class foo[T: bar](object): + pass + bar = foo +except Exception: + bar = lambda: 0 +def bar(): + pass + +@bar() +class bar: + pass diff --git a/crates/ty_python_semantic/src/types/class.rs b/crates/ty_python_semantic/src/types/class.rs index 834690633203d..be32fe8453937 100644 --- a/crates/ty_python_semantic/src/types/class.rs +++ b/crates/ty_python_semantic/src/types/class.rs @@ -175,6 +175,19 @@ fn try_metaclass_cycle_initial<'db>( }) } +fn into_callable_cycle_recover<'db>( + _db: &'db dyn Db, + _value: &Type<'db>, + _count: u32, + _self: ClassType<'db>, +) -> salsa::CycleRecoveryAction> { + salsa::CycleRecoveryAction::Iterate +} + +fn into_callable_cycle_initial<'db>(_db: &'db dyn Db, _self: ClassType<'db>) -> Type<'db> { + Type::Never +} + /// A category of classes with code generation capabilities (with synthesized methods). #[derive(Clone, Copy, Debug, PartialEq, salsa::Update, get_size2::GetSize)] pub(crate) enum CodeGeneratorKind { @@ -1062,7 +1075,7 @@ impl<'db> ClassType<'db> { /// Return a callable type (or union of callable types) that represents the callable /// constructor signature of this class. - #[salsa::tracked(heap_size=ruff_memory_usage::heap_size)] + #[salsa::tracked(cycle_fn=into_callable_cycle_recover, cycle_initial=into_callable_cycle_initial, heap_size=ruff_memory_usage::heap_size)] pub(super) fn into_callable(self, db: &'db dyn Db) -> Type<'db> { let self_ty = Type::from(self); let metaclass_dunder_call_function_symbol = self_ty From 9739daabf828a0f16523647132dc87776272f3f9 Mon Sep 17 00:00:00 2001 From: Shunsuke Shibayama Date: Wed, 3 Sep 2025 11:39:17 +0900 Subject: [PATCH 045/105] remove `ScopeInference::scope` --- crates/ty_python_semantic/src/types/infer.rs | 26 ++++++++++++-------- 1 file changed, 16 insertions(+), 10 deletions(-) diff --git a/crates/ty_python_semantic/src/types/infer.rs b/crates/ty_python_semantic/src/types/infer.rs index b748ab05fc03c..8f598f219bd86 100644 --- a/crates/ty_python_semantic/src/types/infer.rs +++ b/crates/ty_python_semantic/src/types/infer.rs @@ -470,8 +470,6 @@ struct Returnee { /// The inferred types for a scope region. #[derive(Debug, Eq, PartialEq, salsa::Update, get_size2::GetSize)] pub(crate) struct ScopeInference<'db> { - scope: ScopeId<'db>, - /// The types of every expression in this region. expressions: FxHashMap>, @@ -497,8 +495,9 @@ struct ScopeInferenceExtra { impl<'db> ScopeInference<'db> { fn cycle_fallback(scope: ScopeId<'db>) -> Self { + let _ = scope; + Self { - scope, extra: Some(Box::new(ScopeInferenceExtra { cycle_fallback: true, ..ScopeInferenceExtra::default() @@ -541,9 +540,20 @@ impl<'db> ScopeInference<'db> { /// In the case of methods, the return type of the superclass method is further unioned. /// If there is no superclass method and this method is not `final`, it will be unioned with `Unknown`. pub(crate) fn infer_return_type(&self, db: &'db dyn Db, callee_ty: Type<'db>) -> Type<'db> { + let scope = match callee_ty { + Type::FunctionLiteral(function) => { + function.literal(db).last_definition(db).body_scope(db) + } + Type::BoundMethod(method) => method + .function(db) + .literal(db) + .last_definition(db) + .body_scope(db), + _ => return Type::none(db), + }; // TODO: coroutine function type inference // TODO: generator function type inference - if self.scope.is_coroutine_function(db) || self.scope.is_generator_function(db) { + if scope.is_coroutine_function(db) || scope.is_generator_function(db) { return Type::unknown(); } @@ -602,7 +612,7 @@ impl<'db> ScopeInference<'db> { }); union_add(ty); } - let use_def = use_def_map(db, self.scope); + let use_def = use_def_map(db, scope); if use_def.can_implicitly_return_none(db) { union_add(Type::none(db)); } @@ -9652,11 +9662,7 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> { expressions.shrink_to_fit(); - ScopeInference { - scope, - expressions, - extra, - } + ScopeInference { expressions, extra } } } From 866e56787d88563d48d2b6ea48c5efb6b4e5b7fd Mon Sep 17 00:00:00 2001 From: Shunsuke Shibayama Date: Wed, 3 Sep 2025 12:30:40 +0900 Subject: [PATCH 046/105] revert changes in `narrow.rs` --- crates/ty_python_semantic/src/types/narrow.rs | 26 ++++++++++--------- 1 file changed, 14 insertions(+), 12 deletions(-) diff --git a/crates/ty_python_semantic/src/types/narrow.rs b/crates/ty_python_semantic/src/types/narrow.rs index f7b06a7a829fb..46a0b5a8f5b1e 100644 --- a/crates/ty_python_semantic/src/types/narrow.rs +++ b/crates/ty_python_semantic/src/types/narrow.rs @@ -709,26 +709,30 @@ impl<'db, 'ast> NarrowingConstraintsBuilder<'db, 'ast> { return None; } + let inference = infer_expression_types(self.db, expression); + let comparator_tuples = std::iter::once(&**left) .chain(comparators) .tuple_windows::<(&ruff_python_ast::Expr, &ruff_python_ast::Expr)>(); let mut constraints = NarrowingConstraints::default(); + let mut last_rhs_ty: Option = None; + for (op, (left, right)) in std::iter::zip(&**ops, comparator_tuples) { + let lhs_ty = last_rhs_ty.unwrap_or_else(|| inference.expression_type(left)); + let rhs_ty = inference.expression_type(right); + last_rhs_ty = Some(rhs_ty); + match left { ast::Expr::Name(_) | ast::Expr::Attribute(_) | ast::Expr::Subscript(_) | ast::Expr::Named(_) => { - if let Some(left_place) = place_expr(left) { + if let Some(left) = place_expr(left) { let op = if is_positive { *op } else { op.negate() }; - let inference = infer_expression_types(self.db, expression); - let lhs_ty = inference.expression_type(left); - let rhs_ty = inference.expression_type(right); - if let Some(ty) = self.evaluate_expr_compare_op(lhs_ty, rhs_ty, op) { - let place = self.expect_place(&left_place); + let place = self.expect_place(&left); constraints.insert(place, ty); } } @@ -745,6 +749,10 @@ impl<'db, 'ast> NarrowingConstraintsBuilder<'db, 'ast> { node_index: _, }, }) if keywords.is_empty() => { + let Type::ClassLiteral(rhs_class) = rhs_ty else { + continue; + }; + let target = match &**args { [first] => match place_expr(first) { Some(target) => target, @@ -759,12 +767,6 @@ impl<'db, 'ast> NarrowingConstraintsBuilder<'db, 'ast> { op == &ast::CmpOp::IsNot }; - let inference = infer_expression_types(self.db, expression); - let rhs_ty = inference.expression_type(right); - let Type::ClassLiteral(rhs_class) = rhs_ty else { - continue; - }; - // `else`-branch narrowing for `if type(x) is Y` can only be done // if `Y` is a final class if !rhs_class.is_final(self.db) && !is_positive { From a1bd75815d5750755e87bc2c690d0a3f62855068 Mon Sep 17 00:00:00 2001 From: Shunsuke Shibayama Date: Wed, 3 Sep 2025 16:35:25 +0900 Subject: [PATCH 047/105] `ScopeInference::cycle_fallback() == Divergent` --- crates/ty_python_semantic/src/types/infer.rs | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/crates/ty_python_semantic/src/types/infer.rs b/crates/ty_python_semantic/src/types/infer.rs index 8f598f219bd86..a1d4268239589 100644 --- a/crates/ty_python_semantic/src/types/infer.rs +++ b/crates/ty_python_semantic/src/types/infer.rs @@ -532,7 +532,8 @@ impl<'db> ScopeInference<'db> { } fn fallback_type(&self) -> Option> { - self.is_cycle_callback().then_some(Type::Never) + self.is_cycle_callback() + .then_some(Type::Dynamic(DynamicType::Divergent)) } /// Returns the inferred return type of this function body (union of all possible return types), @@ -559,8 +560,8 @@ impl<'db> ScopeInference<'db> { let mut union = UnionBuilder::new(db); let div = Type::Dynamic(DynamicType::Divergent); - if self.is_cycle_callback() { - union = union.add(div); + if let Some(fallback_type) = self.fallback_type() { + union = union.add(fallback_type); } // Here, we use the dynamic type `Divergent` to detect divergent type inference and ensure that we obtain finite results. // For example, consider the following recursive function: From 3c9703d1267e81f30e1bb2d9cf4d137e6d33c728 Mon Sep 17 00:00:00 2001 From: Shunsuke Shibayama Date: Wed, 3 Sep 2025 19:16:19 +0900 Subject: [PATCH 048/105] remove special casing of `Divergent` in union/intersection reduction --- crates/ty_python_semantic/src/types.rs | 68 ------------------- .../ty_python_semantic/src/types/builder.rs | 55 ++------------- 2 files changed, 4 insertions(+), 119 deletions(-) diff --git a/crates/ty_python_semantic/src/types.rs b/crates/ty_python_semantic/src/types.rs index 9a1db3bf068df..a3104019514d6 100644 --- a/crates/ty_python_semantic/src/types.rs +++ b/crates/ty_python_semantic/src/types.rs @@ -10790,72 +10790,4 @@ pub(crate) mod tests { .is_todo() ); } - - #[test] - fn divergent_type() { - let db = setup_db(); - - let div = Type::Dynamic(DynamicType::Divergent); - - let union = UnionType::from_elements(&db, [Type::unknown(), div]); - assert_eq!(union.display(&db).to_string(), "Unknown | Divergent"); - - let union = UnionType::from_elements(&db, [div, Type::unknown()]); - assert_eq!(union.display(&db).to_string(), "Divergent"); - - let union = UnionType::from_elements(&db, [div, KnownClass::Object.to_instance(&db)]); - assert_eq!(union.display(&db).to_string(), "object | Divergent"); - - let union = UnionType::from_elements(&db, [KnownClass::Object.to_instance(&db), div]); - assert_eq!(union.display(&db).to_string(), "object | Divergent"); - - let union = UnionType::from_elements( - &db, - [ - KnownClass::Object.to_instance(&db), - KnownClass::List.to_specialized_instance(&db, [div]), - ], - ); - assert_eq!(union.display(&db).to_string(), "object | list[Divergent]"); - - let union = UnionType::from_elements( - &db, - [ - KnownClass::Object.to_instance(&db), - KnownClass::List.to_specialized_instance(&db, [div]), - KnownClass::Int.to_instance(&db), - ], - ); - assert_eq!(union.display(&db).to_string(), "object | list[Divergent]"); - - let intersection = IntersectionBuilder::new(&db) - .add_positive(Type::Never) - .add_positive(div) - .build(); - assert_eq!(intersection.display(&db).to_string(), "Never & Divergent"); - - let intersection = IntersectionBuilder::new(&db) - .add_positive(div) - .add_positive(Type::Never) - .build(); - assert_eq!(intersection.display(&db).to_string(), "Divergent & Never"); - - let intersection = IntersectionBuilder::new(&db) - .add_positive(KnownClass::List.to_specialized_instance(&db, [div])) - .add_positive(Type::Never) - .build(); - assert_eq!( - intersection.display(&db).to_string(), - "list[Divergent] & Never" - ); - - let intersection = IntersectionBuilder::new(&db) - .add_positive(KnownClass::Int.to_instance(&db)) - .add_positive(KnownClass::List.to_specialized_instance(&db, [div])) - .build(); - assert_eq!( - intersection.display(&db).to_string(), - "int & list[Divergent]" - ); - } } diff --git a/crates/ty_python_semantic/src/types/builder.rs b/crates/ty_python_semantic/src/types/builder.rs index 8d2ddd60ada7e..1bdb85fa5021d 100644 --- a/crates/ty_python_semantic/src/types/builder.rs +++ b/crates/ty_python_semantic/src/types/builder.rs @@ -40,7 +40,7 @@ use crate::types::enums::{enum_member_literals, enum_metadata}; use crate::types::type_ordering::union_or_intersection_elements_ordering; use crate::types::{ - BytesLiteralType, DynamicType, IntersectionType, KnownClass, StringLiteralType, Type, + BytesLiteralType, IntersectionType, KnownClass, StringLiteralType, Type, TypeVarBoundOrConstraints, UnionType, }; use crate::{Db, FxOrderSet}; @@ -241,22 +241,9 @@ impl<'db> UnionBuilder<'db> { /// Collapse the union to a single type: `object`. fn collapse_to_object(&mut self) { - let divergent = self.elements.iter().find_map(|elem| { - let UnionElement::Type(ty) = elem else { - return None; - }; - if ty.has_divergent_type(self.db) { - Some(*ty) - } else { - None - } - }); self.elements.clear(); self.elements .push(UnionElement::Type(Type::object(self.db))); - if let Some(divergent) = divergent { - self.elements.push(UnionElement::Type(divergent)); - } } /// Adds a type to this union. @@ -465,15 +452,6 @@ impl<'db> UnionBuilder<'db> { ty if ty.is_object(self.db) => { self.collapse_to_object(); } - Type::Dynamic(DynamicType::Divergent) => { - if !self - .elements - .iter() - .any(|elem| matches!(elem, UnionElement::Type(elem) if ty == *elem)) - { - self.elements.push(UnionElement::Type(ty)); - } - } _ => { let bool_pair = if let Type::BooleanLiteral(b) = ty { Some(Type::BooleanLiteral(!b)) @@ -493,16 +471,6 @@ impl<'db> UnionBuilder<'db> { Type::Never // won't be used }; - if ty.has_divergent_type(self.db) - && !self - .elements - .iter() - .any(|elem| matches!(elem, UnionElement::Type(elem) if ty == *elem)) - { - self.elements.push(UnionElement::Type(ty)); - return; - } - for (index, element) in self.elements.iter_mut().enumerate() { let element_type = match element.try_reduce(self.db, ty) { ReduceResult::KeepIf(keep) => { @@ -896,9 +864,7 @@ impl<'db> InnerIntersectionBuilder<'db> { self.add_positive(db, Type::LiteralString); self.add_negative(db, Type::string_literal(db, "")); } - Type::Dynamic(DynamicType::Divergent) => { - self.positive.insert(new_positive); - } + _ => { let known_instance = new_positive .into_nominal_instance() @@ -967,9 +933,6 @@ impl<'db> InnerIntersectionBuilder<'db> { let mut to_remove = SmallVec::<[usize; 1]>::new(); for (index, existing_positive) in self.positive.iter().enumerate() { - if new_positive.has_divergent_type(db) { - break; - } // S & T = S if S <: T if existing_positive.is_subtype_of(db, new_positive) || existing_positive.is_equivalent_to(db, new_positive) @@ -977,18 +940,11 @@ impl<'db> InnerIntersectionBuilder<'db> { return; } // same rule, reverse order - if new_positive.is_subtype_of(db, *existing_positive) - && !existing_positive.has_divergent_type(db) - { + if new_positive.is_subtype_of(db, *existing_positive) { to_remove.push(index); } // A & B = Never if A and B are disjoint - if new_positive.is_disjoint_from(db, *existing_positive) - && self - .positive - .iter() - .all(|existing_positive| !existing_positive.has_divergent_type(db)) - { + if new_positive.is_disjoint_from(db, *existing_positive) { *self = Self::default(); self.positive.insert(Type::Never); return; @@ -1000,9 +956,6 @@ impl<'db> InnerIntersectionBuilder<'db> { let mut to_remove = SmallVec::<[usize; 1]>::new(); for (index, existing_negative) in self.negative.iter().enumerate() { - if new_positive.has_divergent_type(db) { - break; - } // S & ~T = Never if S <: T if new_positive.is_subtype_of(db, *existing_negative) { *self = Self::default(); From 7e418a75ee71c3627178756bd7df82f1e674f5b8 Mon Sep 17 00:00:00 2001 From: Shunsuke Shibayama Date: Wed, 3 Sep 2025 20:17:55 +0900 Subject: [PATCH 049/105] Revert "remove special casing of `Divergent` in union/intersection reduction" This reverts commit 3c9703d1267e81f30e1bb2d9cf4d137e6d33c728. --- .../resources/corpus/divergent.py | 24 ++++++- crates/ty_python_semantic/src/types.rs | 68 +++++++++++++++++++ .../ty_python_semantic/src/types/builder.rs | 55 +++++++++++++-- 3 files changed, 142 insertions(+), 5 deletions(-) diff --git a/crates/ty_python_semantic/resources/corpus/divergent.py b/crates/ty_python_semantic/resources/corpus/divergent.py index e120ae391d5f8..228060149a686 100644 --- a/crates/ty_python_semantic/resources/corpus/divergent.py +++ b/crates/ty_python_semantic/resources/corpus/divergent.py @@ -33,4 +33,26 @@ def f(cond: bool): return result -reveal_type(f(True)) \ No newline at end of file +reveal_type(f(True)) + +class Foo: + def value(self): + return 1 + +def unwrap(value): + if isinstance(value, Foo): + foo = value + return foo.value() + elif type(value) is tuple: + length = len(value) + if length == 0: + return () + elif length == 1: + return (unwrap(value[0]),) + else: + result = [] + for item in value: + result.append(unwrap(item)) + return tuple(result) + else: + raise TypeError() diff --git a/crates/ty_python_semantic/src/types.rs b/crates/ty_python_semantic/src/types.rs index a3104019514d6..9a1db3bf068df 100644 --- a/crates/ty_python_semantic/src/types.rs +++ b/crates/ty_python_semantic/src/types.rs @@ -10790,4 +10790,72 @@ pub(crate) mod tests { .is_todo() ); } + + #[test] + fn divergent_type() { + let db = setup_db(); + + let div = Type::Dynamic(DynamicType::Divergent); + + let union = UnionType::from_elements(&db, [Type::unknown(), div]); + assert_eq!(union.display(&db).to_string(), "Unknown | Divergent"); + + let union = UnionType::from_elements(&db, [div, Type::unknown()]); + assert_eq!(union.display(&db).to_string(), "Divergent"); + + let union = UnionType::from_elements(&db, [div, KnownClass::Object.to_instance(&db)]); + assert_eq!(union.display(&db).to_string(), "object | Divergent"); + + let union = UnionType::from_elements(&db, [KnownClass::Object.to_instance(&db), div]); + assert_eq!(union.display(&db).to_string(), "object | Divergent"); + + let union = UnionType::from_elements( + &db, + [ + KnownClass::Object.to_instance(&db), + KnownClass::List.to_specialized_instance(&db, [div]), + ], + ); + assert_eq!(union.display(&db).to_string(), "object | list[Divergent]"); + + let union = UnionType::from_elements( + &db, + [ + KnownClass::Object.to_instance(&db), + KnownClass::List.to_specialized_instance(&db, [div]), + KnownClass::Int.to_instance(&db), + ], + ); + assert_eq!(union.display(&db).to_string(), "object | list[Divergent]"); + + let intersection = IntersectionBuilder::new(&db) + .add_positive(Type::Never) + .add_positive(div) + .build(); + assert_eq!(intersection.display(&db).to_string(), "Never & Divergent"); + + let intersection = IntersectionBuilder::new(&db) + .add_positive(div) + .add_positive(Type::Never) + .build(); + assert_eq!(intersection.display(&db).to_string(), "Divergent & Never"); + + let intersection = IntersectionBuilder::new(&db) + .add_positive(KnownClass::List.to_specialized_instance(&db, [div])) + .add_positive(Type::Never) + .build(); + assert_eq!( + intersection.display(&db).to_string(), + "list[Divergent] & Never" + ); + + let intersection = IntersectionBuilder::new(&db) + .add_positive(KnownClass::Int.to_instance(&db)) + .add_positive(KnownClass::List.to_specialized_instance(&db, [div])) + .build(); + assert_eq!( + intersection.display(&db).to_string(), + "int & list[Divergent]" + ); + } } diff --git a/crates/ty_python_semantic/src/types/builder.rs b/crates/ty_python_semantic/src/types/builder.rs index 1bdb85fa5021d..8d2ddd60ada7e 100644 --- a/crates/ty_python_semantic/src/types/builder.rs +++ b/crates/ty_python_semantic/src/types/builder.rs @@ -40,7 +40,7 @@ use crate::types::enums::{enum_member_literals, enum_metadata}; use crate::types::type_ordering::union_or_intersection_elements_ordering; use crate::types::{ - BytesLiteralType, IntersectionType, KnownClass, StringLiteralType, Type, + BytesLiteralType, DynamicType, IntersectionType, KnownClass, StringLiteralType, Type, TypeVarBoundOrConstraints, UnionType, }; use crate::{Db, FxOrderSet}; @@ -241,9 +241,22 @@ impl<'db> UnionBuilder<'db> { /// Collapse the union to a single type: `object`. fn collapse_to_object(&mut self) { + let divergent = self.elements.iter().find_map(|elem| { + let UnionElement::Type(ty) = elem else { + return None; + }; + if ty.has_divergent_type(self.db) { + Some(*ty) + } else { + None + } + }); self.elements.clear(); self.elements .push(UnionElement::Type(Type::object(self.db))); + if let Some(divergent) = divergent { + self.elements.push(UnionElement::Type(divergent)); + } } /// Adds a type to this union. @@ -452,6 +465,15 @@ impl<'db> UnionBuilder<'db> { ty if ty.is_object(self.db) => { self.collapse_to_object(); } + Type::Dynamic(DynamicType::Divergent) => { + if !self + .elements + .iter() + .any(|elem| matches!(elem, UnionElement::Type(elem) if ty == *elem)) + { + self.elements.push(UnionElement::Type(ty)); + } + } _ => { let bool_pair = if let Type::BooleanLiteral(b) = ty { Some(Type::BooleanLiteral(!b)) @@ -471,6 +493,16 @@ impl<'db> UnionBuilder<'db> { Type::Never // won't be used }; + if ty.has_divergent_type(self.db) + && !self + .elements + .iter() + .any(|elem| matches!(elem, UnionElement::Type(elem) if ty == *elem)) + { + self.elements.push(UnionElement::Type(ty)); + return; + } + for (index, element) in self.elements.iter_mut().enumerate() { let element_type = match element.try_reduce(self.db, ty) { ReduceResult::KeepIf(keep) => { @@ -864,7 +896,9 @@ impl<'db> InnerIntersectionBuilder<'db> { self.add_positive(db, Type::LiteralString); self.add_negative(db, Type::string_literal(db, "")); } - + Type::Dynamic(DynamicType::Divergent) => { + self.positive.insert(new_positive); + } _ => { let known_instance = new_positive .into_nominal_instance() @@ -933,6 +967,9 @@ impl<'db> InnerIntersectionBuilder<'db> { let mut to_remove = SmallVec::<[usize; 1]>::new(); for (index, existing_positive) in self.positive.iter().enumerate() { + if new_positive.has_divergent_type(db) { + break; + } // S & T = S if S <: T if existing_positive.is_subtype_of(db, new_positive) || existing_positive.is_equivalent_to(db, new_positive) @@ -940,11 +977,18 @@ impl<'db> InnerIntersectionBuilder<'db> { return; } // same rule, reverse order - if new_positive.is_subtype_of(db, *existing_positive) { + if new_positive.is_subtype_of(db, *existing_positive) + && !existing_positive.has_divergent_type(db) + { to_remove.push(index); } // A & B = Never if A and B are disjoint - if new_positive.is_disjoint_from(db, *existing_positive) { + if new_positive.is_disjoint_from(db, *existing_positive) + && self + .positive + .iter() + .all(|existing_positive| !existing_positive.has_divergent_type(db)) + { *self = Self::default(); self.positive.insert(Type::Never); return; @@ -956,6 +1000,9 @@ impl<'db> InnerIntersectionBuilder<'db> { let mut to_remove = SmallVec::<[usize; 1]>::new(); for (index, existing_negative) in self.negative.iter().enumerate() { + if new_positive.has_divergent_type(db) { + break; + } // S & ~T = Never if S <: T if new_positive.is_subtype_of(db, *existing_negative) { *self = Self::default(); From aaa66f3e4db8cd45c49d2e4d6981d81ff5d4418c Mon Sep 17 00:00:00 2001 From: Shunsuke Shibayama Date: Thu, 4 Sep 2025 01:44:10 +0900 Subject: [PATCH 050/105] `Divergent` is not equivalent to other dynamic types --- .../resources/mdtest/function/return_type.md | 9 +++ crates/ty_python_semantic/src/types.rs | 64 +++++++------------ .../ty_python_semantic/src/types/builder.rs | 55 ++-------------- 3 files changed, 35 insertions(+), 93 deletions(-) diff --git a/crates/ty_python_semantic/resources/mdtest/function/return_type.md b/crates/ty_python_semantic/resources/mdtest/function/return_type.md index e22e9f407bb6f..50c15f1ed056b 100644 --- a/crates/ty_python_semantic/resources/mdtest/function/return_type.md +++ b/crates/ty_python_semantic/resources/mdtest/function/return_type.md @@ -352,6 +352,15 @@ def call_divergent(x: int): # TODO: it would be better to reveal `tuple[Divergent | None, int]` reveal_type(call_divergent(1)) # revealed: Divergent +def tuple_obj(cond: bool): + if cond: + x = object() + else: + x = tuple_obj(cond) + return (x,) + +reveal_type(tuple_obj(True)) # revealed: tuple[object] + def get_non_empty(node): for child in node.children: node = get_non_empty(child) diff --git a/crates/ty_python_semantic/src/types.rs b/crates/ty_python_semantic/src/types.rs index 9a1db3bf068df..460fb65f6ee72 100644 --- a/crates/ty_python_semantic/src/types.rs +++ b/crates/ty_python_semantic/src/types.rs @@ -1905,6 +1905,10 @@ impl<'db> Type<'db> { } match (self, other) { + // The `Divergent` type is a special type that is not equivalent to other kinds of dynamic types, + // which prevents `Divergent` from being eliminated during union reduction. + (Type::Dynamic(_), Type::Dynamic(DynamicType::Divergent)) + | (Type::Dynamic(DynamicType::Divergent), Type::Dynamic(_)) => C::unsatisfiable(db), (Type::Dynamic(_), Type::Dynamic(_)) => C::always_satisfiable(db), (Type::SubclassOf(first), Type::SubclassOf(second)) => { @@ -10797,65 +10801,41 @@ pub(crate) mod tests { let div = Type::Dynamic(DynamicType::Divergent); + // The `Divergent` type must not be eliminated in union with other dynamic types, + // as this would prevent detection of divergent type inference using `Divergent`. let union = UnionType::from_elements(&db, [Type::unknown(), div]); assert_eq!(union.display(&db).to_string(), "Unknown | Divergent"); let union = UnionType::from_elements(&db, [div, Type::unknown()]); - assert_eq!(union.display(&db).to_string(), "Divergent"); + assert_eq!(union.display(&db).to_string(), "Divergent | Unknown"); + let union = UnionType::from_elements(&db, [div, Type::unknown(), todo_type!("1")]); + assert_eq!(union.display(&db).to_string(), "Divergent | Unknown"); + + assert!(div.is_equivalent_to(&db, div)); + assert!(!div.is_equivalent_to(&db, Type::unknown())); + assert!(!Type::unknown().is_equivalent_to(&db, div)); + + // The `object` type has a good convergence property, that is, its union with all other types is `object`. + // (e.g. `object | tuple[Divergent] == object`, `object | tuple[object] == object`) + // So we can safely eliminate `Divergent`. let union = UnionType::from_elements(&db, [div, KnownClass::Object.to_instance(&db)]); - assert_eq!(union.display(&db).to_string(), "object | Divergent"); + assert_eq!(union.display(&db).to_string(), "object"); let union = UnionType::from_elements(&db, [KnownClass::Object.to_instance(&db), div]); - assert_eq!(union.display(&db).to_string(), "object | Divergent"); - - let union = UnionType::from_elements( - &db, - [ - KnownClass::Object.to_instance(&db), - KnownClass::List.to_specialized_instance(&db, [div]), - ], - ); - assert_eq!(union.display(&db).to_string(), "object | list[Divergent]"); - - let union = UnionType::from_elements( - &db, - [ - KnownClass::Object.to_instance(&db), - KnownClass::List.to_specialized_instance(&db, [div]), - KnownClass::Int.to_instance(&db), - ], - ); - assert_eq!(union.display(&db).to_string(), "object | list[Divergent]"); + assert_eq!(union.display(&db).to_string(), "object"); + // The same can be said about intersections for the `Never` type. let intersection = IntersectionBuilder::new(&db) .add_positive(Type::Never) .add_positive(div) .build(); - assert_eq!(intersection.display(&db).to_string(), "Never & Divergent"); + assert_eq!(intersection.display(&db).to_string(), "Never"); let intersection = IntersectionBuilder::new(&db) .add_positive(div) .add_positive(Type::Never) .build(); - assert_eq!(intersection.display(&db).to_string(), "Divergent & Never"); - - let intersection = IntersectionBuilder::new(&db) - .add_positive(KnownClass::List.to_specialized_instance(&db, [div])) - .add_positive(Type::Never) - .build(); - assert_eq!( - intersection.display(&db).to_string(), - "list[Divergent] & Never" - ); - - let intersection = IntersectionBuilder::new(&db) - .add_positive(KnownClass::Int.to_instance(&db)) - .add_positive(KnownClass::List.to_specialized_instance(&db, [div])) - .build(); - assert_eq!( - intersection.display(&db).to_string(), - "int & list[Divergent]" - ); + assert_eq!(intersection.display(&db).to_string(), "Never"); } } diff --git a/crates/ty_python_semantic/src/types/builder.rs b/crates/ty_python_semantic/src/types/builder.rs index 8d2ddd60ada7e..1bdb85fa5021d 100644 --- a/crates/ty_python_semantic/src/types/builder.rs +++ b/crates/ty_python_semantic/src/types/builder.rs @@ -40,7 +40,7 @@ use crate::types::enums::{enum_member_literals, enum_metadata}; use crate::types::type_ordering::union_or_intersection_elements_ordering; use crate::types::{ - BytesLiteralType, DynamicType, IntersectionType, KnownClass, StringLiteralType, Type, + BytesLiteralType, IntersectionType, KnownClass, StringLiteralType, Type, TypeVarBoundOrConstraints, UnionType, }; use crate::{Db, FxOrderSet}; @@ -241,22 +241,9 @@ impl<'db> UnionBuilder<'db> { /// Collapse the union to a single type: `object`. fn collapse_to_object(&mut self) { - let divergent = self.elements.iter().find_map(|elem| { - let UnionElement::Type(ty) = elem else { - return None; - }; - if ty.has_divergent_type(self.db) { - Some(*ty) - } else { - None - } - }); self.elements.clear(); self.elements .push(UnionElement::Type(Type::object(self.db))); - if let Some(divergent) = divergent { - self.elements.push(UnionElement::Type(divergent)); - } } /// Adds a type to this union. @@ -465,15 +452,6 @@ impl<'db> UnionBuilder<'db> { ty if ty.is_object(self.db) => { self.collapse_to_object(); } - Type::Dynamic(DynamicType::Divergent) => { - if !self - .elements - .iter() - .any(|elem| matches!(elem, UnionElement::Type(elem) if ty == *elem)) - { - self.elements.push(UnionElement::Type(ty)); - } - } _ => { let bool_pair = if let Type::BooleanLiteral(b) = ty { Some(Type::BooleanLiteral(!b)) @@ -493,16 +471,6 @@ impl<'db> UnionBuilder<'db> { Type::Never // won't be used }; - if ty.has_divergent_type(self.db) - && !self - .elements - .iter() - .any(|elem| matches!(elem, UnionElement::Type(elem) if ty == *elem)) - { - self.elements.push(UnionElement::Type(ty)); - return; - } - for (index, element) in self.elements.iter_mut().enumerate() { let element_type = match element.try_reduce(self.db, ty) { ReduceResult::KeepIf(keep) => { @@ -896,9 +864,7 @@ impl<'db> InnerIntersectionBuilder<'db> { self.add_positive(db, Type::LiteralString); self.add_negative(db, Type::string_literal(db, "")); } - Type::Dynamic(DynamicType::Divergent) => { - self.positive.insert(new_positive); - } + _ => { let known_instance = new_positive .into_nominal_instance() @@ -967,9 +933,6 @@ impl<'db> InnerIntersectionBuilder<'db> { let mut to_remove = SmallVec::<[usize; 1]>::new(); for (index, existing_positive) in self.positive.iter().enumerate() { - if new_positive.has_divergent_type(db) { - break; - } // S & T = S if S <: T if existing_positive.is_subtype_of(db, new_positive) || existing_positive.is_equivalent_to(db, new_positive) @@ -977,18 +940,11 @@ impl<'db> InnerIntersectionBuilder<'db> { return; } // same rule, reverse order - if new_positive.is_subtype_of(db, *existing_positive) - && !existing_positive.has_divergent_type(db) - { + if new_positive.is_subtype_of(db, *existing_positive) { to_remove.push(index); } // A & B = Never if A and B are disjoint - if new_positive.is_disjoint_from(db, *existing_positive) - && self - .positive - .iter() - .all(|existing_positive| !existing_positive.has_divergent_type(db)) - { + if new_positive.is_disjoint_from(db, *existing_positive) { *self = Self::default(); self.positive.insert(Type::Never); return; @@ -1000,9 +956,6 @@ impl<'db> InnerIntersectionBuilder<'db> { let mut to_remove = SmallVec::<[usize; 1]>::new(); for (index, existing_negative) in self.negative.iter().enumerate() { - if new_positive.has_divergent_type(db) { - break; - } // S & ~T = Never if S <: T if new_positive.is_subtype_of(db, *existing_negative) { *self = Self::default(); From e3b896df046afbea1674c302617f44ac2ea9cda8 Mon Sep 17 00:00:00 2001 From: Shunsuke Shibayama Date: Mon, 8 Sep 2025 11:06:55 +0900 Subject: [PATCH 051/105] add `DivergentType` to provide info about where the divergence occurs --- .../resources/mdtest/function/return_type.md | 3 +- .../ty_python_semantic/src/semantic_model.rs | 2 +- crates/ty_python_semantic/src/types.rs | 104 +++++++++++------- crates/ty_python_semantic/src/types/class.rs | 3 +- .../src/types/class_base.rs | 2 +- .../ty_python_semantic/src/types/function.rs | 19 +++- .../ty_python_semantic/src/types/generics.rs | 5 +- crates/ty_python_semantic/src/types/infer.rs | 75 +++++++------ .../ty_python_semantic/src/types/instance.rs | 3 +- .../src/types/signatures.rs | 11 +- .../src/types/subclass_of.rs | 5 +- crates/ty_python_semantic/src/types/tuple.rs | 3 +- .../src/types/type_ordering.rs | 8 +- 13 files changed, 148 insertions(+), 95 deletions(-) diff --git a/crates/ty_python_semantic/resources/mdtest/function/return_type.md b/crates/ty_python_semantic/resources/mdtest/function/return_type.md index 50c15f1ed056b..d7e4e6db27d35 100644 --- a/crates/ty_python_semantic/resources/mdtest/function/return_type.md +++ b/crates/ty_python_semantic/resources/mdtest/function/return_type.md @@ -349,8 +349,7 @@ reveal_type(divergent((1,))) # revealed: None | Divergent def call_divergent(x: int): return (divergent((1, 2, 3)), x) -# TODO: it would be better to reveal `tuple[Divergent | None, int]` -reveal_type(call_divergent(1)) # revealed: Divergent +reveal_type(call_divergent(1)) # revealed: tuple[None | Divergent, int] def tuple_obj(cond: bool): if cond: diff --git a/crates/ty_python_semantic/src/semantic_model.rs b/crates/ty_python_semantic/src/semantic_model.rs index fc6bf60c4bbba..467dcbf62c2bc 100644 --- a/crates/ty_python_semantic/src/semantic_model.rs +++ b/crates/ty_python_semantic/src/semantic_model.rs @@ -423,7 +423,7 @@ impl HasType for ast::ExprRef<'_> { let file_scope = index.expression_scope_id(self); let scope = file_scope.to_scope_id(model.db, model.file); - infer_scope_types(model.db, scope).expression_type(*self) + infer_scope_types(model.db, scope).expression_type(model.db, *self) } } diff --git a/crates/ty_python_semantic/src/types.rs b/crates/ty_python_semantic/src/types.rs index 0ea04b513784f..b2d10fdeac98c 100644 --- a/crates/ty_python_semantic/src/types.rs +++ b/crates/ty_python_semantic/src/types.rs @@ -34,7 +34,7 @@ use crate::place::{Boundness, Place, PlaceAndQualifiers, imported_symbol}; use crate::semantic_index::definition::Definition; use crate::semantic_index::place::ScopedPlaceId; use crate::semantic_index::scope::ScopeId; -use crate::semantic_index::{imported_modules, place_table, semantic_index}; +use crate::semantic_index::{FileScopeId, imported_modules, place_table, semantic_index}; use crate::suppression::check_suppressions; use crate::types::call::{Binding, Bindings, CallArguments, CallableBinding}; pub(crate) use crate::types::class_base::ClassBase; @@ -114,8 +114,16 @@ fn return_type_cycle_recover<'db>( salsa::CycleRecoveryAction::Iterate } -fn return_type_cycle_initial<'db>(_db: &'db dyn Db, _self: BoundMethodType<'db>) -> Type<'db> { - Type::Dynamic(DynamicType::Divergent) +fn return_type_cycle_initial<'db>(db: &'db dyn Db, method: BoundMethodType<'db>) -> Type<'db> { + Type::Dynamic(DynamicType::Divergent(DivergentType { + file: method.function(db).file(db), + file_scope: method + .function(db) + .literal(db) + .last_definition(db) + .body_scope(db) + .file_scope_id(db), + })) } pub fn check_types(db: &dyn Db, file: File) -> Vec { @@ -186,7 +194,7 @@ fn definition_expression_type<'db>( } } else { // expression is in a type-params sub-scope - infer_scope_types(db, scope).expression_type(expression) + infer_scope_types(db, scope).expression_type(db, expression) } } @@ -556,13 +564,14 @@ impl<'db> PropertyInstanceType<'db> { fn has_divergent_type_impl( self, db: &'db dyn Db, + div: Type<'db>, visitor: &HasDivergentTypeVisitor<'db>, ) -> bool { self.setter(db) - .is_some_and(|setter| setter.has_divergent_type_impl(db, visitor)) + .is_some_and(|setter| setter.has_divergent_type_impl(db, div, visitor)) || self .getter(db) - .is_some_and(|getter| getter.has_divergent_type_impl(db, visitor)) + .is_some_and(|getter| getter.has_divergent_type_impl(db, div, visitor)) } } @@ -1917,8 +1926,8 @@ impl<'db> Type<'db> { match (self, other) { // The `Divergent` type is a special type that is not equivalent to other kinds of dynamic types, // which prevents `Divergent` from being eliminated during union reduction. - (Type::Dynamic(_), Type::Dynamic(DynamicType::Divergent)) - | (Type::Dynamic(DynamicType::Divergent), Type::Dynamic(_)) => C::unsatisfiable(db), + (Type::Dynamic(_), Type::Dynamic(DynamicType::Divergent(_))) + | (Type::Dynamic(DynamicType::Divergent(_)), Type::Dynamic(_)) => C::unsatisfiable(db), (Type::Dynamic(_), Type::Dynamic(_)) => C::always_satisfiable(db), (Type::SubclassOf(first), Type::SubclassOf(second)) => { @@ -6576,51 +6585,54 @@ impl<'db> Type<'db> { } } - pub(super) fn has_divergent_type(self, db: &'db dyn Db) -> bool { + pub(super) fn has_divergent_type(self, db: &'db dyn Db, div: Type<'db>) -> bool { let visitor = HasDivergentTypeVisitor::new(false); - self.has_divergent_type_impl(db, &visitor) + self.has_divergent_type_impl(db, div, &visitor) } fn has_divergent_type_impl( self, db: &'db dyn Db, + div: Type<'db>, visitor: &HasDivergentTypeVisitor<'db>, ) -> bool { match self { - Type::Dynamic(DynamicType::Divergent) => true, + Type::Dynamic(DynamicType::Divergent(_)) => self == div, Type::Union(union) => { - visitor.visit(self, || union.has_divergent_type_impl(db, visitor)) - } - Type::Intersection(intersection) => { - visitor.visit(self, || intersection.has_divergent_type_impl(db, visitor)) + visitor.visit(self, || union.has_divergent_type_impl(db, div, visitor)) } + Type::Intersection(intersection) => visitor.visit(self, || { + intersection.has_divergent_type_impl(db, div, visitor) + }), Type::GenericAlias(alias) => visitor.visit(self, || { alias .specialization(db) - .has_divergent_type_impl(db, visitor) + .has_divergent_type_impl(db, div, visitor) }), Type::NominalInstance(instance) => visitor.visit(self, || { - instance.class(db).has_divergent_type_impl(db, visitor) + instance.class(db).has_divergent_type_impl(db, div, visitor) }), Type::Callable(callable) => { - visitor.visit(self, || callable.has_divergent_type_impl(db, visitor)) + visitor.visit(self, || callable.has_divergent_type_impl(db, div, visitor)) } Type::ProtocolInstance(protocol) => { - visitor.visit(self, || protocol.has_divergent_type_impl(db, visitor)) + visitor.visit(self, || protocol.has_divergent_type_impl(db, div, visitor)) } Type::PropertyInstance(property) => { - visitor.visit(self, || property.has_divergent_type_impl(db, visitor)) + visitor.visit(self, || property.has_divergent_type_impl(db, div, visitor)) } Type::TypeIs(type_is) => visitor.visit(self, || { - type_is.return_type(db).has_divergent_type_impl(db, visitor) + type_is + .return_type(db) + .has_divergent_type_impl(db, div, visitor) + }), + Type::SubclassOf(subclass_of) => visitor.visit(self, || { + subclass_of.has_divergent_type_impl(db, div, visitor) }), - Type::SubclassOf(subclass_of) => { - visitor.visit(self, || subclass_of.has_divergent_type_impl(db, visitor)) - } Type::TypedDict(typed_dict) => visitor.visit(self, || { typed_dict .defining_class() - .has_divergent_type_impl(db, visitor) + .has_divergent_type_impl(db, div, visitor) }), Type::Never | Type::AlwaysTruthy @@ -7012,6 +7024,19 @@ impl<'db> KnownInstanceType<'db> { } } +/// A type that is determined to be divergent during type inference for a recursive function. +/// This type must never be eliminated by dynamic type reduction +/// (e.g. `Divergent` is assignable to `@Todo`, but `@Todo | Divergent` must not be reducted to `@Todo`). +/// Otherwise, type inference cannot converge properly. +/// For detailed properties of this type, see the unit test at the end of the file. +#[derive(Copy, Clone, Debug, Eq, Hash, PartialEq, get_size2::GetSize)] +pub struct DivergentType { + /// The source file where this divergence was detected. + file: File, + /// The file scope where this divergence was detected. + file_scope: FileScopeId, +} + #[derive(Copy, Clone, Debug, Eq, Hash, PartialEq, get_size2::GetSize)] pub enum DynamicType { /// An explicitly annotated `typing.Any` @@ -7037,15 +7062,12 @@ pub enum DynamicType { /// A special Todo-variant for `Unpack[Ts]`, so that we can treat it specially in `Generic[Unpack[Ts]]` TodoUnpack, /// A type that is determined to be divergent during type inference for a recursive function. - /// This type must never be eliminated by reduction - /// (e.g. `Divergent` is assignable to `@Todo`, but `@Todo | Divergent` must not be reducted to `@Todo`). - /// Otherwise, type inference cannot converge properly. - Divergent, + Divergent(DivergentType), } impl DynamicType { fn normalized(self) -> Self { - if matches!(self, Self::Divergent) { + if matches!(self, Self::Divergent(_)) { self } else { Self::Any @@ -7082,7 +7104,7 @@ impl std::fmt::Display for DynamicType { f.write_str("@Todo") } } - DynamicType::Divergent => f.write_str("Divergent"), + DynamicType::Divergent(_) => f.write_str("Divergent"), } } } @@ -9254,9 +9276,11 @@ impl<'db> CallableType<'db> { fn has_divergent_type_impl( self, db: &'db dyn Db, + div: Type<'db>, visitor: &HasDivergentTypeVisitor<'db>, ) -> bool { - self.signatures(db).has_divergent_type_impl(db, visitor) + self.signatures(db) + .has_divergent_type_impl(db, div, visitor) } } @@ -9933,11 +9957,12 @@ impl<'db> UnionType<'db> { fn has_divergent_type_impl( self, db: &'db dyn Db, + div: Type<'db>, visitor: &HasDivergentTypeVisitor<'db>, ) -> bool { self.elements(db) .iter() - .any(|ty| ty.has_divergent_type_impl(db, visitor)) + .any(|ty| ty.has_divergent_type_impl(db, div, visitor)) } } @@ -10158,15 +10183,16 @@ impl<'db> IntersectionType<'db> { fn has_divergent_type_impl( self, db: &'db dyn Db, + div: Type<'db>, visitor: &HasDivergentTypeVisitor<'db>, ) -> bool { self.positive(db) .iter() - .any(|ty| ty.has_divergent_type_impl(db, visitor)) + .any(|ty| ty.has_divergent_type_impl(db, div, visitor)) || self .negative(db) .iter() - .any(|ty| ty.has_divergent_type_impl(db, visitor)) + .any(|ty| ty.has_divergent_type_impl(db, div, visitor)) } } @@ -10799,9 +10825,13 @@ pub(crate) mod tests { #[test] fn divergent_type() { - let db = setup_db(); + let mut db = setup_db(); + + db.write_dedented("src/foo.py", "").unwrap(); + let file = system_path_to_file(&db, "src/foo.py").unwrap(); + let file_scope = FileScopeId::global(); - let div = Type::Dynamic(DynamicType::Divergent); + let div = Type::Dynamic(DynamicType::Divergent(DivergentType { file, file_scope })); // The `Divergent` type must not be eliminated in union with other dynamic types, // as this would prevent detection of divergent type inference using `Divergent`. diff --git a/crates/ty_python_semantic/src/types/class.rs b/crates/ty_python_semantic/src/types/class.rs index be32fe8453937..91520c4c76f62 100644 --- a/crates/ty_python_semantic/src/types/class.rs +++ b/crates/ty_python_semantic/src/types/class.rs @@ -1229,13 +1229,14 @@ impl<'db> ClassType<'db> { pub(super) fn has_divergent_type_impl( self, db: &'db dyn Db, + div: Type<'db>, visitor: &HasDivergentTypeVisitor<'db>, ) -> bool { match self { ClassType::NonGeneric(_) => false, ClassType::Generic(generic) => generic .specialization(db) - .has_divergent_type_impl(db, visitor), + .has_divergent_type_impl(db, div, visitor), } } diff --git a/crates/ty_python_semantic/src/types/class_base.rs b/crates/ty_python_semantic/src/types/class_base.rs index 8591740cdd964..2ebfd1f358d19 100644 --- a/crates/ty_python_semantic/src/types/class_base.rs +++ b/crates/ty_python_semantic/src/types/class_base.rs @@ -54,7 +54,7 @@ impl<'db> ClassBase<'db> { | DynamicType::TodoTypeAlias | DynamicType::TodoUnpack, ) => "@Todo", - ClassBase::Dynamic(DynamicType::Divergent) => "Divergent", + ClassBase::Dynamic(DynamicType::Divergent(_)) => "Divergent", ClassBase::Protocol => "Protocol", ClassBase::Generic => "Generic", ClassBase::TypedDict => "TypedDict", diff --git a/crates/ty_python_semantic/src/types/function.rs b/crates/ty_python_semantic/src/types/function.rs index 6562a681597d6..1d1ee68c6d77d 100644 --- a/crates/ty_python_semantic/src/types/function.rs +++ b/crates/ty_python_semantic/src/types/function.rs @@ -78,10 +78,10 @@ use crate::types::signatures::{CallableSignature, Signature}; use crate::types::visitor::any_over_type; use crate::types::{ BoundMethodType, BoundTypeVarInstance, CallableType, ClassBase, ClassLiteral, ClassType, - DeprecatedInstance, DynamicType, FindLegacyTypeVarsVisitor, HasRelationToVisitor, - IsEquivalentVisitor, KnownClass, NormalizedVisitor, SpecialFormType, Truthiness, Type, - TypeMapping, TypeRelation, UnionBuilder, all_members, infer_scope_types, walk_generic_context, - walk_type_mapping, + DeprecatedInstance, DivergentType, DynamicType, FindLegacyTypeVarsVisitor, + HasRelationToVisitor, IsEquivalentVisitor, KnownClass, NormalizedVisitor, SpecialFormType, + Truthiness, Type, TypeMapping, TypeRelation, UnionBuilder, all_members, infer_scope_types, + walk_generic_context, walk_type_mapping, }; use crate::{Db, FxOrderSet, ModuleName, resolve_module}; @@ -94,8 +94,15 @@ fn return_type_cycle_recover<'db>( salsa::CycleRecoveryAction::Iterate } -fn return_type_cycle_initial<'db>(_db: &'db dyn Db, _self: FunctionType<'db>) -> Type<'db> { - Type::Dynamic(DynamicType::Divergent) +fn return_type_cycle_initial<'db>(db: &'db dyn Db, function: FunctionType<'db>) -> Type<'db> { + Type::Dynamic(DynamicType::Divergent(DivergentType { + file: function.file(db), + file_scope: function + .literal(db) + .last_definition(db) + .body_scope(db) + .file_scope_id(db), + })) } /// A collection of useful spans for annotating functions. diff --git a/crates/ty_python_semantic/src/types/generics.rs b/crates/ty_python_semantic/src/types/generics.rs index 04b6f655e9288..fb62807f432e7 100644 --- a/crates/ty_python_semantic/src/types/generics.rs +++ b/crates/ty_python_semantic/src/types/generics.rs @@ -877,14 +877,15 @@ impl<'db> Specialization<'db> { pub(crate) fn has_divergent_type_impl( self, db: &'db dyn Db, + div: Type<'db>, visitor: &HasDivergentTypeVisitor<'db>, ) -> bool { self.types(db) .iter() - .any(|ty| ty.has_divergent_type_impl(db, visitor)) + .any(|ty| ty.has_divergent_type_impl(db, div, visitor)) || self .tuple_inner(db) - .is_some_and(|tuple| tuple.has_divergent_type_impl(db, visitor)) + .is_some_and(|tuple| tuple.has_divergent_type_impl(db, div, visitor)) } } diff --git a/crates/ty_python_semantic/src/types/infer.rs b/crates/ty_python_semantic/src/types/infer.rs index d13f34a2dabd8..ed8d9f7086f9a 100644 --- a/crates/ty_python_semantic/src/types/infer.rs +++ b/crates/ty_python_semantic/src/types/infer.rs @@ -127,10 +127,10 @@ use crate::types::typed_dict::{ }; use crate::types::unpacker::{UnpackResult, Unpacker}; use crate::types::{ - CallDunderError, CallableType, ClassLiteral, ClassType, DataclassParams, DynamicType, - IntersectionBuilder, IntersectionType, KnownClass, KnownInstanceType, LintDiagnosticGuard, - MemberLookupPolicy, MetaclassCandidate, PEP695TypeAliasType, Parameter, ParameterForm, - Parameters, SpecialFormType, SubclassOfType, Truthiness, Type, TypeAliasType, + CallDunderError, CallableType, ClassLiteral, ClassType, DataclassParams, DivergentType, + DynamicType, IntersectionBuilder, IntersectionType, KnownClass, KnownInstanceType, + LintDiagnosticGuard, MemberLookupPolicy, MetaclassCandidate, PEP695TypeAliasType, Parameter, + ParameterForm, Parameters, SpecialFormType, SubclassOfType, Truthiness, Type, TypeAliasType, TypeAndQualifiers, TypeIsType, TypeQualifiers, TypeVarBoundOrConstraintsEvaluation, TypeVarDefaultEvaluation, TypeVarInstance, TypeVarKind, UnionBuilder, UnionType, binding_type, todo_type, @@ -475,6 +475,8 @@ pub(crate) struct ScopeInference<'db> { /// The extra data that is only present for few inference regions. extra: Option>, + + scope: ScopeId<'db>, } #[derive(Debug, Eq, PartialEq, get_size2::GetSize, salsa::Update, Default)] @@ -495,14 +497,13 @@ struct ScopeInferenceExtra { impl<'db> ScopeInference<'db> { fn cycle_fallback(scope: ScopeId<'db>) -> Self { - let _ = scope; - Self { extra: Some(Box::new(ScopeInferenceExtra { cycle_fallback: true, ..ScopeInferenceExtra::default() })), expressions: FxHashMap::default(), + scope, } } @@ -510,19 +511,24 @@ impl<'db> ScopeInference<'db> { self.extra.as_deref().map(|extra| &extra.diagnostics) } - pub(crate) fn expression_type(&self, expression: impl Into) -> Type<'db> { - self.try_expression_type(expression) + pub(crate) fn expression_type( + &self, + db: &'db dyn Db, + expression: impl Into, + ) -> Type<'db> { + self.try_expression_type(db, expression) .unwrap_or_else(Type::unknown) } pub(crate) fn try_expression_type( &self, + db: &'db dyn Db, expression: impl Into, ) -> Option> { self.expressions .get(&expression.into()) .copied() - .or_else(|| self.fallback_type()) + .or_else(|| self.fallback_type(db)) } fn is_cycle_callback(&self) -> bool { @@ -531,9 +537,12 @@ impl<'db> ScopeInference<'db> { .is_some_and(|extra| extra.cycle_fallback) } - fn fallback_type(&self) -> Option> { + fn fallback_type(&self, db: &'db dyn Db) -> Option> { self.is_cycle_callback() - .then_some(Type::Dynamic(DynamicType::Divergent)) + .then_some(Type::Dynamic(DynamicType::Divergent(DivergentType { + file: self.scope.file(db), + file_scope: self.scope.file_scope_id(db), + }))) } /// Returns the inferred return type of this function body (union of all possible return types), @@ -541,26 +550,18 @@ impl<'db> ScopeInference<'db> { /// In the case of methods, the return type of the superclass method is further unioned. /// If there is no superclass method and this method is not `final`, it will be unioned with `Unknown`. pub(crate) fn infer_return_type(&self, db: &'db dyn Db, callee_ty: Type<'db>) -> Type<'db> { - let scope = match callee_ty { - Type::FunctionLiteral(function) => { - function.literal(db).last_definition(db).body_scope(db) - } - Type::BoundMethod(method) => method - .function(db) - .literal(db) - .last_definition(db) - .body_scope(db), - _ => return Type::none(db), - }; // TODO: coroutine function type inference // TODO: generator function type inference - if scope.is_coroutine_function(db) || scope.is_generator_function(db) { + if self.scope.is_coroutine_function(db) || self.scope.is_generator_function(db) { return Type::unknown(); } let mut union = UnionBuilder::new(db); - let div = Type::Dynamic(DynamicType::Divergent); - if let Some(fallback_type) = self.fallback_type() { + let div = Type::Dynamic(DynamicType::Divergent(DivergentType { + file: self.scope.file(db), + file_scope: self.scope.file_scope_id(db), + })); + if let Some(fallback_type) = self.fallback_type(db) { union = union.add(fallback_type); } // Here, we use the dynamic type `Divergent` to detect divergent type inference and ensure that we obtain finite results. @@ -582,10 +583,10 @@ impl<'db> ScopeInference<'db> { let mut union_add = |ty: Type<'db>| { if ty == div { // `Divergent` appearing in a union does not mean true divergence, so it can be removed. - } else if ty.has_divergent_type(db) { + } else if ty.has_divergent_type(db, div) { if let Type::Union(union_ty) = ty { let union_ty = union_ty.filter(db, |ty| **ty != div); - if union_ty.has_divergent_type(db) { + if union_ty.has_divergent_type(db, div) { union = std::mem::replace(&mut union, UnionBuilder::new(db)).add(div); } else { union = std::mem::replace(&mut union, UnionBuilder::new(db)).add(union_ty); @@ -609,11 +610,11 @@ impl<'db> ScopeInference<'db> { }; for returnee in &extra.returnees { let ty = returnee.map_or(Type::none(db), |expression| { - self.expression_type(expression) + self.expression_type(db, expression) }); union_add(ty); } - let use_def = use_def_map(db, scope); + let use_def = use_def_map(db, self.scope); if use_def.can_implicitly_return_none(db) { union_add(Type::none(db)); } @@ -1176,7 +1177,7 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> { InferenceRegion::Scope(scope) if scope == expr_scope => { self.expression_type(expression) } - _ => infer_scope_types(self.db(), expr_scope).expression_type(expression), + _ => infer_scope_types(self.db(), expr_scope).expression_type(self.db(), expression), } } @@ -7688,8 +7689,8 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> { // Non-todo Anys take precedence over Todos (as if we fix this `Todo` in the future, // the result would then become Any or Unknown, respectively). - (div @ Type::Dynamic(DynamicType::Divergent), _, _) - | (_, div @ Type::Dynamic(DynamicType::Divergent), _) => Some(div), + (div @ Type::Dynamic(DynamicType::Divergent(_)), _, _) + | (_, div @ Type::Dynamic(DynamicType::Divergent(_)), _) => Some(div), (any @ Type::Dynamic(DynamicType::Any), _, _) | (_, any @ Type::Dynamic(DynamicType::Any), _) => Some(any), @@ -8626,7 +8627,7 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> { ).expect("infer_binary_type_comparison should never return None for `CmpOp::Eq`"); match eq_result { - todo @ Type::Dynamic(DynamicType::Todo(_) | DynamicType::Divergent) => return Ok(todo), + todo @ Type::Dynamic(DynamicType::Todo(_) | DynamicType::Divergent(_)) => return Ok(todo), // It's okay to ignore errors here because Python doesn't call `__bool__` // for different union variants. Instead, this is just for us to // evaluate a possibly truthy value to `false` or `true`. @@ -8654,7 +8655,7 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> { ); Ok(match eq_result { - todo @ Type::Dynamic(DynamicType::Todo(_) | DynamicType::Divergent) => todo, + todo @ Type::Dynamic(DynamicType::Todo(_) | DynamicType::Divergent(_)) => todo, // It's okay to ignore errors here because Python doesn't call `__bool__` // for `is` and `is not` comparisons. This is an implementation detail // for how we determine the truthiness of a type. @@ -9663,7 +9664,11 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> { expressions.shrink_to_fit(); - ScopeInference { expressions, extra } + ScopeInference { + expressions, + extra, + scope, + } } } diff --git a/crates/ty_python_semantic/src/types/instance.rs b/crates/ty_python_semantic/src/types/instance.rs index c5e4067c1f5ac..49c042286def1 100644 --- a/crates/ty_python_semantic/src/types/instance.rs +++ b/crates/ty_python_semantic/src/types/instance.rs @@ -642,12 +642,13 @@ impl<'db> ProtocolInstanceType<'db> { pub(super) fn has_divergent_type_impl( self, db: &'db dyn Db, + div: Type<'db>, visitor: &HasDivergentTypeVisitor<'db>, ) -> bool { self.inner .interface(db) .members(db) - .any(|member| member.ty().has_divergent_type_impl(db, visitor)) + .any(|member| member.ty().has_divergent_type_impl(db, div, visitor)) } } diff --git a/crates/ty_python_semantic/src/types/signatures.rs b/crates/ty_python_semantic/src/types/signatures.rs index 758f4c4f8dea4..edc3bece02534 100644 --- a/crates/ty_python_semantic/src/types/signatures.rs +++ b/crates/ty_python_semantic/src/types/signatures.rs @@ -239,11 +239,12 @@ impl<'db> CallableSignature<'db> { pub(super) fn has_divergent_type_impl( &self, db: &'db dyn Db, + div: Type<'db>, visitor: &HasDivergentTypeVisitor<'db>, ) -> bool { self.overloads .iter() - .any(|signature| signature.has_divergent_type_impl(db, visitor)) + .any(|signature| signature.has_divergent_type_impl(db, div, visitor)) } } @@ -1031,11 +1032,12 @@ impl<'db> Signature<'db> { fn has_divergent_type_impl( &self, db: &'db dyn Db, + div: Type<'db>, visitor: &HasDivergentTypeVisitor<'db>, ) -> bool { self.return_ty - .is_some_and(|return_ty| return_ty.has_divergent_type_impl(db, visitor)) - || self.parameters.has_divergent_type_impl(db, visitor) + .is_some_and(|return_ty| return_ty.has_divergent_type_impl(db, div, visitor)) + || self.parameters.has_divergent_type_impl(db, div, visitor) } } @@ -1344,12 +1346,13 @@ impl<'db> Parameters<'db> { fn has_divergent_type_impl( &self, db: &'db dyn Db, + div: Type<'db>, visitor: &HasDivergentTypeVisitor<'db>, ) -> bool { self.iter().any(|parameter| { parameter .annotated_type() - .is_some_and(|ty| ty.has_divergent_type_impl(db, visitor)) + .is_some_and(|ty| ty.has_divergent_type_impl(db, div, visitor)) }) } } diff --git a/crates/ty_python_semantic/src/types/subclass_of.rs b/crates/ty_python_semantic/src/types/subclass_of.rs index 33bafcca8dd4d..38b97bf056406 100644 --- a/crates/ty_python_semantic/src/types/subclass_of.rs +++ b/crates/ty_python_semantic/src/types/subclass_of.rs @@ -205,12 +205,13 @@ impl<'db> SubclassOfType<'db> { pub(super) fn has_divergent_type_impl( self, db: &'db dyn Db, + div: Type<'db>, visitor: &HasDivergentTypeVisitor<'db>, ) -> bool { match self.subclass_of { - SubclassOfInner::Dynamic(DynamicType::Divergent) => true, + SubclassOfInner::Dynamic(d @ DynamicType::Divergent(_)) => Type::Dynamic(d) == div, SubclassOfInner::Dynamic(_) => false, - SubclassOfInner::Class(class) => class.has_divergent_type_impl(db, visitor), + SubclassOfInner::Class(class) => class.has_divergent_type_impl(db, div, visitor), } } } diff --git a/crates/ty_python_semantic/src/types/tuple.rs b/crates/ty_python_semantic/src/types/tuple.rs index 15fa14d0b48f1..59a1a7720d0cf 100644 --- a/crates/ty_python_semantic/src/types/tuple.rs +++ b/crates/ty_python_semantic/src/types/tuple.rs @@ -289,11 +289,12 @@ impl<'db> TupleType<'db> { pub(super) fn has_divergent_type_impl( self, db: &'db dyn Db, + div: Type<'db>, visitor: &HasDivergentTypeVisitor<'db>, ) -> bool { self.tuple(db) .all_elements() - .any(|ty| ty.has_divergent_type_impl(db, visitor)) + .any(|ty| ty.has_divergent_type_impl(db, div, visitor)) } } diff --git a/crates/ty_python_semantic/src/types/type_ordering.rs b/crates/ty_python_semantic/src/types/type_ordering.rs index 0b137cfc53629..7be4fb51650cc 100644 --- a/crates/ty_python_semantic/src/types/type_ordering.rs +++ b/crates/ty_python_semantic/src/types/type_ordering.rs @@ -276,8 +276,12 @@ fn dynamic_elements_ordering(left: DynamicType, right: DynamicType) -> Ordering (DynamicType::TodoTypeAlias, _) => Ordering::Less, (_, DynamicType::TodoTypeAlias) => Ordering::Greater, - (DynamicType::Divergent, _) => Ordering::Less, - (_, DynamicType::Divergent) => Ordering::Greater, + (DynamicType::Divergent(left), DynamicType::Divergent(right)) => left + .file + .cmp(&right.file) + .then_with(|| left.file_scope.index().cmp(&right.file_scope.index())), + (DynamicType::Divergent(_), _) => Ordering::Less, + (_, DynamicType::Divergent(_)) => Ordering::Greater, } } From e1a46546a34599f02482976d5cc252a366d30097 Mon Sep 17 00:00:00 2001 From: Shunsuke Shibayama Date: Mon, 8 Sep 2025 12:20:53 +0900 Subject: [PATCH 052/105] `DivergentType` has a `ScopeId` --- .../src/semantic_index/scope.rs | 1 + .../ty_python_semantic/src/semantic_model.rs | 2 +- crates/ty_python_semantic/src/types.rs | 44 +++++++++---------- .../src/types/class_base.rs | 2 +- .../ty_python_semantic/src/types/function.rs | 7 +-- crates/ty_python_semantic/src/types/infer.rs | 27 ++++-------- .../src/types/subclass_of.rs | 8 ++-- .../src/types/type_ordering.rs | 7 ++- 8 files changed, 41 insertions(+), 57 deletions(-) diff --git a/crates/ty_python_semantic/src/semantic_index/scope.rs b/crates/ty_python_semantic/src/semantic_index/scope.rs index 09804eb72ab4d..423aae345560c 100644 --- a/crates/ty_python_semantic/src/semantic_index/scope.rs +++ b/crates/ty_python_semantic/src/semantic_index/scope.rs @@ -18,6 +18,7 @@ use crate::{ /// A cross-module identifier of a scope that can be used as a salsa query parameter. #[salsa::tracked(debug, heap_size=ruff_memory_usage::heap_size)] +#[derive(Ord, PartialOrd)] pub struct ScopeId<'db> { pub file: File, diff --git a/crates/ty_python_semantic/src/semantic_model.rs b/crates/ty_python_semantic/src/semantic_model.rs index 467dcbf62c2bc..fc6bf60c4bbba 100644 --- a/crates/ty_python_semantic/src/semantic_model.rs +++ b/crates/ty_python_semantic/src/semantic_model.rs @@ -423,7 +423,7 @@ impl HasType for ast::ExprRef<'_> { let file_scope = index.expression_scope_id(self); let scope = file_scope.to_scope_id(model.db, model.file); - infer_scope_types(model.db, scope).expression_type(model.db, *self) + infer_scope_types(model.db, scope).expression_type(*self) } } diff --git a/crates/ty_python_semantic/src/types.rs b/crates/ty_python_semantic/src/types.rs index 541d368d36431..770584cbe79a4 100644 --- a/crates/ty_python_semantic/src/types.rs +++ b/crates/ty_python_semantic/src/types.rs @@ -34,7 +34,7 @@ use crate::place::{Boundness, Place, PlaceAndQualifiers, imported_symbol}; use crate::semantic_index::definition::Definition; use crate::semantic_index::place::ScopedPlaceId; use crate::semantic_index::scope::ScopeId; -use crate::semantic_index::{FileScopeId, imported_modules, place_table, semantic_index}; +use crate::semantic_index::{imported_modules, place_table, semantic_index}; use crate::suppression::check_suppressions; use crate::types::call::{Binding, Bindings, CallArguments, CallableBinding}; pub(crate) use crate::types::class_base::ClassBase; @@ -116,13 +116,11 @@ fn return_type_cycle_recover<'db>( fn return_type_cycle_initial<'db>(db: &'db dyn Db, method: BoundMethodType<'db>) -> Type<'db> { Type::Dynamic(DynamicType::Divergent(DivergentType { - file: method.function(db).file(db), - file_scope: method + scope: method .function(db) .literal(db) .last_definition(db) - .body_scope(db) - .file_scope_id(db), + .body_scope(db), })) } @@ -194,7 +192,7 @@ fn definition_expression_type<'db>( } } else { // expression is in a type-params sub-scope - infer_scope_types(db, scope).expression_type(db, expression) + infer_scope_types(db, scope).expression_type(expression) } } @@ -633,7 +631,7 @@ impl From for DataclassParams { #[derive(Copy, Clone, Debug, PartialEq, Eq, Hash, salsa::Update, get_size2::GetSize)] pub enum Type<'db> { /// The dynamic type: a statically unknown set of values - Dynamic(DynamicType), + Dynamic(DynamicType<'db>), /// The empty set of values Never, /// A specific function object @@ -881,14 +879,14 @@ impl<'db> Type<'db> { } } - pub(crate) const fn into_dynamic(self) -> Option { + pub(crate) const fn into_dynamic(self) -> Option> { match self { Type::Dynamic(dynamic_type) => Some(dynamic_type), _ => None, } } - pub(crate) const fn expect_dynamic(self) -> DynamicType { + pub(crate) const fn expect_dynamic(self) -> DynamicType<'db> { self.into_dynamic() .expect("Expected a Type::Dynamic variant") } @@ -7029,16 +7027,14 @@ impl<'db> KnownInstanceType<'db> { /// (e.g. `Divergent` is assignable to `@Todo`, but `@Todo | Divergent` must not be reducted to `@Todo`). /// Otherwise, type inference cannot converge properly. /// For detailed properties of this type, see the unit test at the end of the file. -#[derive(Copy, Clone, Debug, Eq, Hash, PartialEq, get_size2::GetSize)] -pub struct DivergentType { - /// The source file where this divergence was detected. - file: File, - /// The file scope where this divergence was detected. - file_scope: FileScopeId, +#[derive(Copy, Clone, Debug, Eq, Hash, PartialEq, salsa::Update, get_size2::GetSize)] +pub struct DivergentType<'db> { + /// The scope where this divergence was detected. + scope: ScopeId<'db>, } -#[derive(Copy, Clone, Debug, Eq, Hash, PartialEq, get_size2::GetSize)] -pub enum DynamicType { +#[derive(Copy, Clone, Debug, Eq, Hash, PartialEq, salsa::Update, get_size2::GetSize)] +pub enum DynamicType<'db> { /// An explicitly annotated `typing.Any` Any, /// An unannotated value, or a dynamic type resulting from an error @@ -7062,10 +7058,10 @@ pub enum DynamicType { /// A special Todo-variant for `Unpack[Ts]`, so that we can treat it specially in `Generic[Unpack[Ts]]` TodoUnpack, /// A type that is determined to be divergent during type inference for a recursive function. - Divergent(DivergentType), + Divergent(DivergentType<'db>), } -impl DynamicType { +impl DynamicType<'_> { fn normalized(self) -> Self { if matches!(self, Self::Divergent(_)) { self @@ -7075,7 +7071,7 @@ impl DynamicType { } } -impl std::fmt::Display for DynamicType { +impl std::fmt::Display for DynamicType<'_> { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { match self { DynamicType::Any => f.write_str("Any"), @@ -10322,7 +10318,7 @@ impl BoundSuperError<'_> { #[derive(Debug, Copy, Clone, Hash, PartialEq, Eq, get_size2::GetSize)] pub enum SuperOwnerKind<'db> { - Dynamic(DynamicType), + Dynamic(DynamicType<'db>), Class(ClassType<'db>), Instance(NominalInstanceType<'db>), } @@ -10688,6 +10684,7 @@ pub(crate) mod tests { use super::*; use crate::db::tests::{TestDbBuilder, setup_db}; use crate::place::{global_symbol, typing_extensions_symbol, typing_symbol}; + use crate::semantic_index::FileScopeId; use ruff_db::files::system_path_to_file; use ruff_db::parsed::parsed_module; use ruff_db::system::DbWithWritableSystem as _; @@ -10840,9 +10837,10 @@ pub(crate) mod tests { db.write_dedented("src/foo.py", "").unwrap(); let file = system_path_to_file(&db, "src/foo.py").unwrap(); - let file_scope = FileScopeId::global(); + let file_scope_id = FileScopeId::global(); + let scope = file_scope_id.to_scope_id(&db, file); - let div = Type::Dynamic(DynamicType::Divergent(DivergentType { file, file_scope })); + let div = Type::Dynamic(DynamicType::Divergent(DivergentType { scope })); // The `Divergent` type must not be eliminated in union with other dynamic types, // as this would prevent detection of divergent type inference using `Divergent`. diff --git a/crates/ty_python_semantic/src/types/class_base.rs b/crates/ty_python_semantic/src/types/class_base.rs index 23863a4c4eb47..c93aafb472c61 100644 --- a/crates/ty_python_semantic/src/types/class_base.rs +++ b/crates/ty_python_semantic/src/types/class_base.rs @@ -18,7 +18,7 @@ use crate::types::{ /// automatically construct the default specialization for that class. #[derive(Debug, Copy, Clone, Hash, PartialEq, Eq, salsa::Update, get_size2::GetSize)] pub enum ClassBase<'db> { - Dynamic(DynamicType), + Dynamic(DynamicType<'db>), Class(ClassType<'db>), /// Although `Protocol` is not a class in typeshed's stubs, it is at runtime, /// and can appear in the MRO of a class. diff --git a/crates/ty_python_semantic/src/types/function.rs b/crates/ty_python_semantic/src/types/function.rs index f1ff96a048017..d82256778fd02 100644 --- a/crates/ty_python_semantic/src/types/function.rs +++ b/crates/ty_python_semantic/src/types/function.rs @@ -96,12 +96,7 @@ fn return_type_cycle_recover<'db>( fn return_type_cycle_initial<'db>(db: &'db dyn Db, function: FunctionType<'db>) -> Type<'db> { Type::Dynamic(DynamicType::Divergent(DivergentType { - file: function.file(db), - file_scope: function - .literal(db) - .last_definition(db) - .body_scope(db) - .file_scope_id(db), + scope: function.literal(db).last_definition(db).body_scope(db), })) } diff --git a/crates/ty_python_semantic/src/types/infer.rs b/crates/ty_python_semantic/src/types/infer.rs index e196d1af4f290..d800d715590dd 100644 --- a/crates/ty_python_semantic/src/types/infer.rs +++ b/crates/ty_python_semantic/src/types/infer.rs @@ -511,24 +511,19 @@ impl<'db> ScopeInference<'db> { self.extra.as_deref().map(|extra| &extra.diagnostics) } - pub(crate) fn expression_type( - &self, - db: &'db dyn Db, - expression: impl Into, - ) -> Type<'db> { - self.try_expression_type(db, expression) + pub(crate) fn expression_type(&self, expression: impl Into) -> Type<'db> { + self.try_expression_type(expression) .unwrap_or_else(Type::unknown) } pub(crate) fn try_expression_type( &self, - db: &'db dyn Db, expression: impl Into, ) -> Option> { self.expressions .get(&expression.into()) .copied() - .or_else(|| self.fallback_type(db)) + .or_else(|| self.fallback_type()) } fn is_cycle_callback(&self) -> bool { @@ -537,11 +532,10 @@ impl<'db> ScopeInference<'db> { .is_some_and(|extra| extra.cycle_fallback) } - fn fallback_type(&self, db: &'db dyn Db) -> Option> { + fn fallback_type(&self) -> Option> { self.is_cycle_callback() .then_some(Type::Dynamic(DynamicType::Divergent(DivergentType { - file: self.scope.file(db), - file_scope: self.scope.file_scope_id(db), + scope: self.scope, }))) } @@ -557,11 +551,8 @@ impl<'db> ScopeInference<'db> { } let mut union = UnionBuilder::new(db); - let div = Type::Dynamic(DynamicType::Divergent(DivergentType { - file: self.scope.file(db), - file_scope: self.scope.file_scope_id(db), - })); - if let Some(fallback_type) = self.fallback_type(db) { + let div = Type::Dynamic(DynamicType::Divergent(DivergentType { scope: self.scope })); + if let Some(fallback_type) = self.fallback_type() { union = union.add(fallback_type); } // Here, we use the dynamic type `Divergent` to detect divergent type inference and ensure that we obtain finite results. @@ -610,7 +601,7 @@ impl<'db> ScopeInference<'db> { }; for returnee in &extra.returnees { let ty = returnee.map_or(Type::none(db), |expression| { - self.expression_type(db, expression) + self.expression_type(expression) }); union_add(ty); } @@ -1177,7 +1168,7 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> { InferenceRegion::Scope(scope) if scope == expr_scope => { self.expression_type(expression) } - _ => infer_scope_types(self.db(), expr_scope).expression_type(self.db(), expression), + _ => infer_scope_types(self.db(), expr_scope).expression_type(expression), } } diff --git a/crates/ty_python_semantic/src/types/subclass_of.rs b/crates/ty_python_semantic/src/types/subclass_of.rs index fa2d64fcab4ac..1328ea0b856a8 100644 --- a/crates/ty_python_semantic/src/types/subclass_of.rs +++ b/crates/ty_python_semantic/src/types/subclass_of.rs @@ -234,7 +234,7 @@ impl<'db> VarianceInferable<'db> for SubclassOfType<'db> { #[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, salsa::Update, get_size2::GetSize)] pub(crate) enum SubclassOfInner<'db> { Class(ClassType<'db>), - Dynamic(DynamicType), + Dynamic(DynamicType<'db>), } impl<'db> SubclassOfInner<'db> { @@ -253,7 +253,7 @@ impl<'db> SubclassOfInner<'db> { } } - pub(crate) const fn into_dynamic(self) -> Option { + pub(crate) const fn into_dynamic(self) -> Option> { match self { Self::Class(_) => None, Self::Dynamic(dynamic) => Some(dynamic), @@ -284,8 +284,8 @@ impl<'db> From> for SubclassOfInner<'db> { } } -impl From for SubclassOfInner<'_> { - fn from(value: DynamicType) -> Self { +impl<'db> From> for SubclassOfInner<'db> { + fn from(value: DynamicType<'db>) -> Self { SubclassOfInner::Dynamic(value) } } diff --git a/crates/ty_python_semantic/src/types/type_ordering.rs b/crates/ty_python_semantic/src/types/type_ordering.rs index 7be4fb51650cc..e18cfb452e398 100644 --- a/crates/ty_python_semantic/src/types/type_ordering.rs +++ b/crates/ty_python_semantic/src/types/type_ordering.rs @@ -276,10 +276,9 @@ fn dynamic_elements_ordering(left: DynamicType, right: DynamicType) -> Ordering (DynamicType::TodoTypeAlias, _) => Ordering::Less, (_, DynamicType::TodoTypeAlias) => Ordering::Greater, - (DynamicType::Divergent(left), DynamicType::Divergent(right)) => left - .file - .cmp(&right.file) - .then_with(|| left.file_scope.index().cmp(&right.file_scope.index())), + (DynamicType::Divergent(left), DynamicType::Divergent(right)) => { + left.scope.cmp(&right.scope) + } (DynamicType::Divergent(_), _) => Ordering::Less, (_, DynamicType::Divergent(_)) => Ordering::Greater, } From 2197085f476cf71530ac5267d82d16f524c22951 Mon Sep 17 00:00:00 2001 From: Shunsuke Shibayama Date: Mon, 8 Sep 2025 16:24:32 +0900 Subject: [PATCH 053/105] refactor --- crates/ty_python_semantic/src/types.rs | 27 +++++++++++--------- crates/ty_python_semantic/src/types/infer.rs | 13 +--------- 2 files changed, 16 insertions(+), 24 deletions(-) diff --git a/crates/ty_python_semantic/src/types.rs b/crates/ty_python_semantic/src/types.rs index 770584cbe79a4..16cf8a0e05841 100644 --- a/crates/ty_python_semantic/src/types.rs +++ b/crates/ty_python_semantic/src/types.rs @@ -9054,10 +9054,7 @@ impl<'db> BoundMethodType<'db> { .any(|deco| deco == KnownFunction::Final) } - pub(crate) fn base_signature_and_return_type( - self, - db: &'db dyn Db, - ) -> Option<(Signature<'db>, Type<'db>)> { + pub(super) fn base_return_type(self, db: &'db dyn Db) -> Option> { let class = binding_type(db, self.class_definition(db)?).to_class_type(db)?; let name = self.function(db).name(db); @@ -9068,14 +9065,20 @@ impl<'db> BoundMethodType<'db> { let base_member = base.class_member(db, name, MemberLookupPolicy::default()); if let Place::Type(Type::FunctionLiteral(base_func), _) = base_member.place { if let [signature] = base_func.signature(db).overloads.as_slice() { - Some(( - signature.clone(), - signature.return_ty.unwrap_or_else(|| { - let base_method_ty = - base_func.into_bound_method_type(db, Type::instance(db, class)); - base_method_ty.infer_return_type(db) - }), - )) + let unspecialized_return_ty = signature.return_ty.unwrap_or_else(|| { + let base_method_ty = + base_func.into_bound_method_type(db, Type::instance(db, class)); + base_method_ty.infer_return_type(db) + }); + if let Some(generic_context) = signature.generic_context.as_ref() { + // If the return type of the base method contains a type variable, replace it with `Unknown` to avoid dangling type variables. + Some( + unspecialized_return_ty + .apply_specialization(db, generic_context.unknown_specialization(db)), + ) + } else { + Some(unspecialized_return_ty) + } } else { // TODO: Handle overloaded base methods. None diff --git a/crates/ty_python_semantic/src/types/infer.rs b/crates/ty_python_semantic/src/types/infer.rs index d800d715590dd..84b051d1856f9 100644 --- a/crates/ty_python_semantic/src/types/infer.rs +++ b/crates/ty_python_semantic/src/types/infer.rs @@ -614,18 +614,7 @@ impl<'db> ScopeInference<'db> { // If any method in a base class does not have an annotated return type, `base_return_type` will include `Unknown`. // On the other hand, if the return types of all methods in the base classes are annotated, there is no need to include `Unknown`. if !method_ty.is_final(db) { - let (signature, return_ty) = method_ty - .base_signature_and_return_type(db) - .unwrap_or((Signature::unknown(), Type::unknown())); - if let Some(generic_context) = signature.generic_context.as_ref() { - // If the return type of the base method contains a type variable, replace it with `Unknown` to avoid dangling type variables. - union_add( - return_ty - .apply_specialization(db, generic_context.unknown_specialization(db)), - ); - } else { - union_add(return_ty); - } + union_add(method_ty.base_return_type(db).unwrap_or(Type::unknown())); } } From 8da9e9b5c9681be15a5ba926b7eb22ad66e4447f Mon Sep 17 00:00:00 2001 From: Shunsuke Shibayama Date: Mon, 8 Sep 2025 16:50:28 +0900 Subject: [PATCH 054/105] clarify that the inferred return type is not exposed in the signature --- .../ty_python_semantic/resources/mdtest/function/return_type.md | 2 ++ 1 file changed, 2 insertions(+) diff --git a/crates/ty_python_semantic/resources/mdtest/function/return_type.md b/crates/ty_python_semantic/resources/mdtest/function/return_type.md index d7e4e6db27d35..6c3bcaf99c5c6 100644 --- a/crates/ty_python_semantic/resources/mdtest/function/return_type.md +++ b/crates/ty_python_semantic/resources/mdtest/function/return_type.md @@ -268,6 +268,8 @@ def f(): return 1 reveal_type(f()) # revealed: Literal[1] +# The inferred return type is not exposed in the signature. +reveal_type(f) # revealed: def f() -> Unknown def g(cond: bool): if cond: From f6cd44b8b9e984a56b1f31588671b6492b159929 Mon Sep 17 00:00:00 2001 From: Shunsuke Shibayama Date: Tue, 9 Sep 2025 12:44:47 +0900 Subject: [PATCH 055/105] use `any_over_type` in `has_divergent_type` --- crates/ty_python_semantic/src/types.rs | 130 +----------------- crates/ty_python_semantic/src/types/class.rs | 24 +--- .../ty_python_semantic/src/types/generics.rs | 23 +--- .../ty_python_semantic/src/types/instance.rs | 18 +-- .../src/types/signatures.rs | 40 +----- .../src/types/subclass_of.rs | 19 +-- crates/ty_python_semantic/src/types/tuple.rs | 13 +- 7 files changed, 23 insertions(+), 244 deletions(-) diff --git a/crates/ty_python_semantic/src/types.rs b/crates/ty_python_semantic/src/types.rs index 16cf8a0e05841..0475508dc7572 100644 --- a/crates/ty_python_semantic/src/types.rs +++ b/crates/ty_python_semantic/src/types.rs @@ -64,6 +64,7 @@ use crate::types::signatures::{Parameter, ParameterForm, Parameters, walk_signat use crate::types::tuple::TupleSpec; pub(crate) use crate::types::typed_dict::{TypedDictParams, TypedDictType, walk_typed_dict_type}; use crate::types::variance::{TypeVarVariance, VarianceInferable}; +use crate::types::visitor::any_over_type; use crate::unpack::EvaluationMode; pub use crate::util::diagnostics::add_inferred_python_version_hint_to_diagnostic; use crate::{Db, FxOrderSet, Module, Program}; @@ -219,9 +220,6 @@ pub(crate) struct FindLegacyTypeVars; pub(crate) type NormalizedVisitor<'db> = TypeTransformer<'db, Normalized>; pub(crate) struct Normalized; -pub(crate) type HasDivergentTypeVisitor<'db> = CycleDetector, bool>; -pub(crate) struct HasDivergentType; - /// How a generic type has been specialized. /// /// This matters only if there is at least one invariant type parameter. @@ -549,19 +547,6 @@ impl<'db> PropertyInstanceType<'db> { ty.find_legacy_typevars_impl(db, binding_context, typevars, visitor); } } - - fn has_divergent_type_impl( - self, - db: &'db dyn Db, - div: Type<'db>, - visitor: &HasDivergentTypeVisitor<'db>, - ) -> bool { - self.setter(db) - .is_some_and(|setter| setter.has_divergent_type_impl(db, div, visitor)) - || self - .getter(db) - .is_some_and(|getter| getter.has_divergent_type_impl(db, div, visitor)) - } } bitflags! { @@ -6566,79 +6551,10 @@ impl<'db> Type<'db> { } pub(super) fn has_divergent_type(self, db: &'db dyn Db, div: Type<'db>) -> bool { - let visitor = HasDivergentTypeVisitor::new(false); - self.has_divergent_type_impl(db, div, &visitor) - } - - fn has_divergent_type_impl( - self, - db: &'db dyn Db, - div: Type<'db>, - visitor: &HasDivergentTypeVisitor<'db>, - ) -> bool { - match self { - Type::Dynamic(DynamicType::Divergent(_)) => self == div, - Type::Union(union) => { - visitor.visit(self, || union.has_divergent_type_impl(db, div, visitor)) - } - Type::Intersection(intersection) => visitor.visit(self, || { - intersection.has_divergent_type_impl(db, div, visitor) - }), - Type::GenericAlias(alias) => visitor.visit(self, || { - alias - .specialization(db) - .has_divergent_type_impl(db, div, visitor) - }), - Type::NominalInstance(instance) => visitor.visit(self, || { - instance.class(db).has_divergent_type_impl(db, div, visitor) - }), - Type::Callable(callable) => { - visitor.visit(self, || callable.has_divergent_type_impl(db, div, visitor)) - } - Type::ProtocolInstance(protocol) => { - visitor.visit(self, || protocol.has_divergent_type_impl(db, div, visitor)) - } - Type::PropertyInstance(property) => { - visitor.visit(self, || property.has_divergent_type_impl(db, div, visitor)) - } - Type::TypeIs(type_is) => visitor.visit(self, || { - type_is - .return_type(db) - .has_divergent_type_impl(db, div, visitor) - }), - Type::SubclassOf(subclass_of) => visitor.visit(self, || { - subclass_of.has_divergent_type_impl(db, div, visitor) - }), - Type::TypedDict(typed_dict) => visitor.visit(self, || { - typed_dict - .defining_class() - .has_divergent_type_impl(db, div, visitor) - }), - Type::Never - | Type::AlwaysTruthy - | Type::AlwaysFalsy - | Type::WrapperDescriptor(_) - | Type::MethodWrapper(_) - | Type::DataclassDecorator(_) - | Type::DataclassTransformer(_) - | Type::ModuleLiteral(_) - | Type::ClassLiteral(_) - | Type::IntLiteral(_) - | Type::BooleanLiteral(_) - | Type::LiteralString - | Type::StringLiteral(_) - | Type::BytesLiteral(_) - | Type::EnumLiteral(_) - | Type::BoundSuper(_) - | Type::SpecialForm(_) - | Type::KnownInstance(_) - | Type::NonInferableTypeVar(_) - | Type::TypeVar(_) - | Type::FunctionLiteral(_) - | Type::BoundMethod(_) - | Type::Dynamic(_) - | Type::TypeAlias(_) => false, - } + any_over_type(db, self, &|ty| match ty { + Type::Dynamic(DynamicType::Divergent(_)) => ty == div, + _ => false, + }) } } @@ -9284,16 +9200,6 @@ impl<'db> CallableType<'db> { .is_equivalent_to_impl(db, other.signatures(db), visitor) }) } - - fn has_divergent_type_impl( - self, - db: &'db dyn Db, - div: Type<'db>, - visitor: &HasDivergentTypeVisitor<'db>, - ) -> bool { - self.signatures(db) - .has_divergent_type_impl(db, div, visitor) - } } /// Represents a specific instance of `types.MethodWrapperType` @@ -9963,17 +9869,6 @@ impl<'db> UnionType<'db> { C::from_bool(db, sorted_self == other.normalized(db)) } - - fn has_divergent_type_impl( - self, - db: &'db dyn Db, - div: Type<'db>, - visitor: &HasDivergentTypeVisitor<'db>, - ) -> bool { - self.elements(db) - .iter() - .any(|ty| ty.has_divergent_type_impl(db, div, visitor)) - } } #[salsa::interned(debug, heap_size=IntersectionType::heap_size)] @@ -10189,21 +10084,6 @@ impl<'db> IntersectionType<'db> { ruff_memory_usage::order_set_heap_size(positive) + ruff_memory_usage::order_set_heap_size(negative) } - - fn has_divergent_type_impl( - self, - db: &'db dyn Db, - div: Type<'db>, - visitor: &HasDivergentTypeVisitor<'db>, - ) -> bool { - self.positive(db) - .iter() - .any(|ty| ty.has_divergent_type_impl(db, div, visitor)) - || self - .negative(db) - .iter() - .any(|ty| ty.has_divergent_type_impl(db, div, visitor)) - } } /// # Ordering diff --git a/crates/ty_python_semantic/src/types/class.rs b/crates/ty_python_semantic/src/types/class.rs index 108771d6d9480..6f790ea157aaf 100644 --- a/crates/ty_python_semantic/src/types/class.rs +++ b/crates/ty_python_semantic/src/types/class.rs @@ -26,11 +26,11 @@ use crate::types::tuple::{TupleSpec, TupleType}; use crate::types::typed_dict::typed_dict_params_from_class_def; use crate::types::{ ApplyTypeMappingVisitor, Binding, BoundSuperError, BoundSuperType, CallableType, - DataclassParams, DeprecatedInstance, FindLegacyTypeVarsVisitor, HasDivergentTypeVisitor, - HasRelationToVisitor, IsEquivalentVisitor, KnownInstanceType, ManualPEP695TypeAliasType, - MaterializationKind, NormalizedVisitor, PropertyInstanceType, StringLiteralType, TypeAliasType, - TypeMapping, TypeRelation, TypeVarBoundOrConstraints, TypeVarInstance, TypeVarKind, - TypedDictParams, UnionBuilder, VarianceInferable, declaration_type, infer_definition_types, + DataclassParams, DeprecatedInstance, FindLegacyTypeVarsVisitor, HasRelationToVisitor, + IsEquivalentVisitor, KnownInstanceType, ManualPEP695TypeAliasType, MaterializationKind, + NormalizedVisitor, PropertyInstanceType, StringLiteralType, TypeAliasType, TypeMapping, + TypeRelation, TypeVarBoundOrConstraints, TypeVarInstance, TypeVarKind, TypedDictParams, + UnionBuilder, VarianceInferable, declaration_type, infer_definition_types, }; use crate::{ Db, FxIndexMap, FxOrderSet, Program, @@ -1208,20 +1208,6 @@ impl<'db> ClassType<'db> { } } - pub(super) fn has_divergent_type_impl( - self, - db: &'db dyn Db, - div: Type<'db>, - visitor: &HasDivergentTypeVisitor<'db>, - ) -> bool { - match self { - ClassType::NonGeneric(_) => false, - ClassType::Generic(generic) => generic - .specialization(db) - .has_divergent_type_impl(db, div, visitor), - } - } - pub(super) fn is_protocol(self, db: &'db dyn Db) -> bool { self.class_literal(db).0.is_protocol(db) } diff --git a/crates/ty_python_semantic/src/types/generics.rs b/crates/ty_python_semantic/src/types/generics.rs index 2d613042115cf..bfb248202bb02 100644 --- a/crates/ty_python_semantic/src/types/generics.rs +++ b/crates/ty_python_semantic/src/types/generics.rs @@ -15,11 +15,10 @@ use crate::types::instance::{Protocol, ProtocolInstanceType}; use crate::types::signatures::{Parameter, Parameters, Signature}; use crate::types::tuple::{TupleSpec, TupleType, walk_tuple_type}; use crate::types::{ - ApplyTypeMappingVisitor, BoundTypeVarInstance, FindLegacyTypeVarsVisitor, - HasDivergentTypeVisitor, HasRelationToVisitor, IsEquivalentVisitor, KnownClass, - KnownInstanceType, MaterializationKind, NormalizedVisitor, Type, TypeMapping, TypeRelation, - TypeVarBoundOrConstraints, TypeVarInstance, TypeVarVariance, UnionType, binding_type, - declaration_type, + ApplyTypeMappingVisitor, BoundTypeVarInstance, FindLegacyTypeVarsVisitor, HasRelationToVisitor, + IsEquivalentVisitor, KnownClass, KnownInstanceType, MaterializationKind, NormalizedVisitor, + Type, TypeMapping, TypeRelation, TypeVarBoundOrConstraints, TypeVarInstance, TypeVarVariance, + UnionType, binding_type, declaration_type, }; use crate::{Db, FxOrderSet}; @@ -907,20 +906,6 @@ impl<'db> Specialization<'db> { // A tuple's specialization will include all of its element types, so we don't need to also // look in `self.tuple`. } - - pub(crate) fn has_divergent_type_impl( - self, - db: &'db dyn Db, - div: Type<'db>, - visitor: &HasDivergentTypeVisitor<'db>, - ) -> bool { - self.types(db) - .iter() - .any(|ty| ty.has_divergent_type_impl(db, div, visitor)) - || self - .tuple_inner(db) - .is_some_and(|tuple| tuple.has_divergent_type_impl(db, div, visitor)) - } } /// A mapping between type variables and types. diff --git a/crates/ty_python_semantic/src/types/instance.rs b/crates/ty_python_semantic/src/types/instance.rs index 3ce081bf5adc3..8714c824aaae5 100644 --- a/crates/ty_python_semantic/src/types/instance.rs +++ b/crates/ty_python_semantic/src/types/instance.rs @@ -12,9 +12,9 @@ use crate::types::enums::is_single_member_enum; use crate::types::protocol_class::walk_protocol_interface; use crate::types::tuple::{TupleSpec, TupleType}; use crate::types::{ - ApplyTypeMappingVisitor, ClassBase, FindLegacyTypeVarsVisitor, HasDivergentTypeVisitor, - HasRelationToVisitor, IsDisjointVisitor, IsEquivalentVisitor, NormalizedVisitor, TypeMapping, - TypeRelation, VarianceInferable, + ApplyTypeMappingVisitor, ClassBase, FindLegacyTypeVarsVisitor, HasRelationToVisitor, + IsDisjointVisitor, IsEquivalentVisitor, NormalizedVisitor, TypeMapping, TypeRelation, + VarianceInferable, }; use crate::{Db, FxOrderSet}; @@ -609,18 +609,6 @@ impl<'db> ProtocolInstanceType<'db> { pub(super) fn interface(self, db: &'db dyn Db) -> ProtocolInterface<'db> { self.inner.interface(db) } - - pub(super) fn has_divergent_type_impl( - self, - db: &'db dyn Db, - div: Type<'db>, - visitor: &HasDivergentTypeVisitor<'db>, - ) -> bool { - self.inner - .interface(db) - .members(db) - .any(|member| member.ty().has_divergent_type_impl(db, div, visitor)) - } } impl<'db> VarianceInferable<'db> for ProtocolInstanceType<'db> { diff --git a/crates/ty_python_semantic/src/types/signatures.rs b/crates/ty_python_semantic/src/types/signatures.rs index 2f3fbd8aabbdf..ffd11bf77fe0f 100644 --- a/crates/ty_python_semantic/src/types/signatures.rs +++ b/crates/ty_python_semantic/src/types/signatures.rs @@ -21,9 +21,8 @@ use crate::types::constraints::{ConstraintSet, Constraints, IteratorConstraintsE use crate::types::generics::{GenericContext, walk_generic_context}; use crate::types::{ ApplyTypeMappingVisitor, BindingContext, BoundTypeVarInstance, FindLegacyTypeVarsVisitor, - HasDivergentTypeVisitor, HasRelationToVisitor, IsEquivalentVisitor, KnownClass, - MaterializationKind, NormalizedVisitor, TypeMapping, TypeRelation, VarianceInferable, - todo_type, + HasRelationToVisitor, IsEquivalentVisitor, KnownClass, MaterializationKind, NormalizedVisitor, + TypeMapping, TypeRelation, VarianceInferable, todo_type, }; use crate::{Db, FxOrderSet}; use ruff_python_ast::{self as ast, name::Name}; @@ -223,17 +222,6 @@ impl<'db> CallableSignature<'db> { } } } - - pub(super) fn has_divergent_type_impl( - &self, - db: &'db dyn Db, - div: Type<'db>, - visitor: &HasDivergentTypeVisitor<'db>, - ) -> bool { - self.overloads - .iter() - .any(|signature| signature.has_divergent_type_impl(db, div, visitor)) - } } impl<'a, 'db> IntoIterator for &'a CallableSignature<'db> { @@ -1012,17 +1000,6 @@ impl<'db> Signature<'db> { pub(crate) fn with_definition(self, definition: Option>) -> Self { Self { definition, ..self } } - - fn has_divergent_type_impl( - &self, - db: &'db dyn Db, - div: Type<'db>, - visitor: &HasDivergentTypeVisitor<'db>, - ) -> bool { - self.return_ty - .is_some_and(|return_ty| return_ty.has_divergent_type_impl(db, div, visitor)) - || self.parameters.has_divergent_type_impl(db, div, visitor) - } } impl<'db> VarianceInferable<'db> for &Signature<'db> { @@ -1358,19 +1335,6 @@ impl<'db> Parameters<'db> { .enumerate() .rfind(|(_, parameter)| parameter.is_keyword_variadic()) } - - fn has_divergent_type_impl( - &self, - db: &'db dyn Db, - div: Type<'db>, - visitor: &HasDivergentTypeVisitor<'db>, - ) -> bool { - self.iter().any(|parameter| { - parameter - .annotated_type() - .is_some_and(|ty| ty.has_divergent_type_impl(db, div, visitor)) - }) - } } impl<'db, 'a> IntoIterator for &'a Parameters<'db> { diff --git a/crates/ty_python_semantic/src/types/subclass_of.rs b/crates/ty_python_semantic/src/types/subclass_of.rs index 1328ea0b856a8..35ecbac7ee5aa 100644 --- a/crates/ty_python_semantic/src/types/subclass_of.rs +++ b/crates/ty_python_semantic/src/types/subclass_of.rs @@ -4,9 +4,9 @@ use crate::types::constraints::Constraints; use crate::types::variance::VarianceInferable; use crate::types::{ ApplyTypeMappingVisitor, BoundTypeVarInstance, ClassType, DynamicType, - FindLegacyTypeVarsVisitor, HasDivergentTypeVisitor, HasRelationToVisitor, IsDisjointVisitor, - KnownClass, MaterializationKind, MemberLookupPolicy, NormalizedVisitor, SpecialFormType, Type, - TypeMapping, TypeRelation, + FindLegacyTypeVarsVisitor, HasRelationToVisitor, IsDisjointVisitor, KnownClass, + MaterializationKind, MemberLookupPolicy, NormalizedVisitor, SpecialFormType, Type, TypeMapping, + TypeRelation, }; use crate::{Db, FxOrderSet}; @@ -193,19 +193,6 @@ impl<'db> SubclassOfType<'db> { .into_class() .is_some_and(|class| class.class_literal(db).0.is_typed_dict(db)) } - - pub(super) fn has_divergent_type_impl( - self, - db: &'db dyn Db, - div: Type<'db>, - visitor: &HasDivergentTypeVisitor<'db>, - ) -> bool { - match self.subclass_of { - SubclassOfInner::Dynamic(d @ DynamicType::Divergent(_)) => Type::Dynamic(d) == div, - SubclassOfInner::Dynamic(_) => false, - SubclassOfInner::Class(class) => class.has_divergent_type_impl(db, div, visitor), - } - } } impl<'db> VarianceInferable<'db> for SubclassOfType<'db> { diff --git a/crates/ty_python_semantic/src/types/tuple.rs b/crates/ty_python_semantic/src/types/tuple.rs index bfe240664b597..cb947d2695a03 100644 --- a/crates/ty_python_semantic/src/types/tuple.rs +++ b/crates/ty_python_semantic/src/types/tuple.rs @@ -22,6 +22,7 @@ use std::hash::Hash; use itertools::{Either, EitherOrBoth, Itertools}; use crate::semantic_index::definition::Definition; +use crate::types::Truthiness; use crate::types::class::{ClassType, KnownClass}; use crate::types::constraints::{Constraints, IteratorConstraintsExtension}; use crate::types::{ @@ -29,7 +30,6 @@ use crate::types::{ IsDisjointVisitor, IsEquivalentVisitor, NormalizedVisitor, Type, TypeMapping, TypeRelation, UnionBuilder, UnionType, }; -use crate::types::{HasDivergentTypeVisitor, Truthiness}; use crate::util::subscript::{Nth, OutOfBoundsError, PyIndex, PySlice, StepSizeZeroError}; use crate::{Db, FxOrderSet, Program}; @@ -277,17 +277,6 @@ impl<'db> TupleType<'db> { pub(crate) fn is_single_valued(self, db: &'db dyn Db) -> bool { self.tuple(db).is_single_valued(db) } - - pub(super) fn has_divergent_type_impl( - self, - db: &'db dyn Db, - div: Type<'db>, - visitor: &HasDivergentTypeVisitor<'db>, - ) -> bool { - self.tuple(db) - .all_elements() - .any(|ty| ty.has_divergent_type_impl(db, div, visitor)) - } } fn to_class_type_cycle_recover<'db>( From 164af9e78ee7e07b0b5bd5f048fe8f3d068b5c11 Mon Sep 17 00:00:00 2001 From: Shunsuke Shibayama Date: Tue, 9 Sep 2025 13:11:32 +0900 Subject: [PATCH 056/105] remove unnecessary changes --- crates/ty_python_semantic/src/types/infer.rs | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/crates/ty_python_semantic/src/types/infer.rs b/crates/ty_python_semantic/src/types/infer.rs index 90d5ed704ffc5..d632afa825f2f 100644 --- a/crates/ty_python_semantic/src/types/infer.rs +++ b/crates/ty_python_semantic/src/types/infer.rs @@ -8579,7 +8579,7 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> { ).expect("infer_binary_type_comparison should never return None for `CmpOp::Eq`"); match eq_result { - todo @ Type::Dynamic(DynamicType::Todo(_) | DynamicType::Divergent(_)) => return Ok(todo), + todo @ Type::Dynamic(DynamicType::Todo(_)) => return Ok(todo), // It's okay to ignore errors here because Python doesn't call `__bool__` // for different union variants. Instead, this is just for us to // evaluate a possibly truthy value to `false` or `true`. @@ -8607,7 +8607,7 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> { ); Ok(match eq_result { - todo @ Type::Dynamic(DynamicType::Todo(_) | DynamicType::Divergent(_)) => todo, + todo @ Type::Dynamic(DynamicType::Todo(_)) => todo, // It's okay to ignore errors here because Python doesn't call `__bool__` // for `is` and `is not` comparisons. This is an implementation detail // for how we determine the truthiness of a type. From 4f68d19e4e213720891a3089b44f64e958060da3 Mon Sep 17 00:00:00 2001 From: Shunsuke Shibayama Date: Tue, 9 Sep 2025 20:45:05 +0900 Subject: [PATCH 057/105] Update return_type.md --- .../ty_python_semantic/resources/mdtest/function/return_type.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/crates/ty_python_semantic/resources/mdtest/function/return_type.md b/crates/ty_python_semantic/resources/mdtest/function/return_type.md index 6c3bcaf99c5c6..0d9007907914f 100644 --- a/crates/ty_python_semantic/resources/mdtest/function/return_type.md +++ b/crates/ty_python_semantic/resources/mdtest/function/return_type.md @@ -268,7 +268,7 @@ def f(): return 1 reveal_type(f()) # revealed: Literal[1] -# The inferred return type is not exposed in the signature. +# TODO: should be `def f() -> Literal[1]` reveal_type(f) # revealed: def f() -> Unknown def g(cond: bool): From fca09372204e7a21d670989163e0ea4a7e238b39 Mon Sep 17 00:00:00 2001 From: Shunsuke Shibayama Date: Tue, 9 Sep 2025 21:13:26 +0900 Subject: [PATCH 058/105] remove unnecessary type annotations --- .../diagnostics/semantic_syntax_errors.md | 15 +---- ...246_-_Python_3.10_(96aa8ec77d46553d).snap" | 63 +++++++++---------- 2 files changed, 33 insertions(+), 45 deletions(-) diff --git a/crates/ty_python_semantic/resources/mdtest/diagnostics/semantic_syntax_errors.md b/crates/ty_python_semantic/resources/mdtest/diagnostics/semantic_syntax_errors.md index a77efb3b96ac8..efed211d92187 100644 --- a/crates/ty_python_semantic/resources/mdtest/diagnostics/semantic_syntax_errors.md +++ b/crates/ty_python_semantic/resources/mdtest/diagnostics/semantic_syntax_errors.md @@ -15,10 +15,7 @@ python-version = "3.10" ``` ```py -from ty_extensions import Unknown - -# TODO: async generator type inference support -async def elements(n) -> Unknown: +async def elements(n): yield n async def f(): @@ -57,10 +54,7 @@ python-version = "3.11" ``` ```py -from ty_extensions import Unknown - -# TODO: async generator type inference support -async def elements(n) -> Unknown: +async def elements(n): yield n async def f(): @@ -335,10 +329,7 @@ def _(): This error includes `await`, `async for`, `async with`, and `async` comprehensions. ```py -from ty_extensions import Unknown - -# TODO: async generator type inference support -async def elements(n) -> Unknown: +async def elements(n): yield n def _(): diff --git "a/crates/ty_python_semantic/resources/mdtest/snapshots/semantic_syntax_erro\342\200\246_-_Semantic_syntax_erro\342\200\246_-_`async`_comprehensio\342\200\246_-_Python_3.10_(96aa8ec77d46553d).snap" "b/crates/ty_python_semantic/resources/mdtest/snapshots/semantic_syntax_erro\342\200\246_-_Semantic_syntax_erro\342\200\246_-_`async`_comprehensio\342\200\246_-_Python_3.10_(96aa8ec77d46553d).snap" index b590dd0f90da6..f25e7b1bac843 100644 --- "a/crates/ty_python_semantic/resources/mdtest/snapshots/semantic_syntax_erro\342\200\246_-_Semantic_syntax_erro\342\200\246_-_`async`_comprehensio\342\200\246_-_Python_3.10_(96aa8ec77d46553d).snap" +++ "b/crates/ty_python_semantic/resources/mdtest/snapshots/semantic_syntax_erro\342\200\246_-_Semantic_syntax_erro\342\200\246_-_`async`_comprehensio\342\200\246_-_Python_3.10_(96aa8ec77d46553d).snap" @@ -12,53 +12,50 @@ mdtest path: crates/ty_python_semantic/resources/mdtest/diagnostics/semantic_syn ## mdtest_snippet.py ``` - 1 | from ty_extensions import Unknown - 2 | - 3 | # TODO: async generator type inference support - 4 | async def elements(n) -> Unknown: - 5 | yield n - 6 | - 7 | async def f(): - 8 | # error: 19 [invalid-syntax] "cannot use an asynchronous comprehension inside of a synchronous comprehension on Python 3.10 (syntax was added in 3.11)" - 9 | return {n: [x async for x in elements(n)] for n in range(3)} -10 | async def test(): -11 | # error: [not-iterable] "Object of type `range` is not async-iterable" -12 | return [[x async for x in elements(n)] async for n in range(3)] + 1 | async def elements(n): + 2 | yield n + 3 | + 4 | async def f(): + 5 | # error: 19 [invalid-syntax] "cannot use an asynchronous comprehension inside of a synchronous comprehension on Python 3.10 (syntax was added in 3.11)" + 6 | return {n: [x async for x in elements(n)] for n in range(3)} + 7 | async def test(): + 8 | # error: [not-iterable] "Object of type `range` is not async-iterable" + 9 | return [[x async for x in elements(n)] async for n in range(3)] +10 | async def f(): +11 | [x for x in [1]] and [x async for x in elements(1)] +12 | 13 | async def f(): -14 | [x for x in [1]] and [x async for x in elements(1)] -15 | -16 | async def f(): -17 | def g(): -18 | pass -19 | [x async for x in elements(1)] +14 | def g(): +15 | pass +16 | [x async for x in elements(1)] ``` # Diagnostics ``` error[invalid-syntax] - --> src/mdtest_snippet.py:9:19 - | - 7 | async def f(): - 8 | # error: 19 [invalid-syntax] "cannot use an asynchronous comprehension inside of a synchronous comprehension on Python 3.10 (synta… - 9 | return {n: [x async for x in elements(n)] for n in range(3)} - | ^^^^^^^^^^^^^^^^^^^^^^^^^^ cannot use an asynchronous comprehension inside of a synchronous comprehension on Python 3.10 (syntax was added in 3.11) -10 | async def test(): -11 | # error: [not-iterable] "Object of type `range` is not async-iterable" - | + --> src/mdtest_snippet.py:6:19 + | +4 | async def f(): +5 | # error: 19 [invalid-syntax] "cannot use an asynchronous comprehension inside of a synchronous comprehension on Python 3.10 (syntax… +6 | return {n: [x async for x in elements(n)] for n in range(3)} + | ^^^^^^^^^^^^^^^^^^^^^^^^^^ cannot use an asynchronous comprehension inside of a synchronous comprehension on Python 3.10 (syntax was added in 3.11) +7 | async def test(): +8 | # error: [not-iterable] "Object of type `range` is not async-iterable" + | ``` ``` error[not-iterable]: Object of type `range` is not async-iterable - --> src/mdtest_snippet.py:12:59 + --> src/mdtest_snippet.py:9:59 | -10 | async def test(): -11 | # error: [not-iterable] "Object of type `range` is not async-iterable" -12 | return [[x async for x in elements(n)] async for n in range(3)] + 7 | async def test(): + 8 | # error: [not-iterable] "Object of type `range` is not async-iterable" + 9 | return [[x async for x in elements(n)] async for n in range(3)] | ^^^^^^^^ -13 | async def f(): -14 | [x for x in [1]] and [x async for x in elements(1)] +10 | async def f(): +11 | [x for x in [1]] and [x async for x in elements(1)] | info: It has no `__aiter__` method info: rule `not-iterable` is enabled by default From fe252a045b43ba48b6b09cbdf1bdead7cd0024e4 Mon Sep 17 00:00:00 2001 From: Shunsuke Shibayama <45118249+mtshiba@users.noreply.github.com> Date: Tue, 9 Sep 2025 21:15:45 +0900 Subject: [PATCH 059/105] Update crates/ty_python_semantic/resources/mdtest/function/return_type.md Co-authored-by: Carl Meyer --- .../ty_python_semantic/resources/mdtest/function/return_type.md | 1 - 1 file changed, 1 deletion(-) diff --git a/crates/ty_python_semantic/resources/mdtest/function/return_type.md b/crates/ty_python_semantic/resources/mdtest/function/return_type.md index 0d9007907914f..41e658705bfeb 100644 --- a/crates/ty_python_semantic/resources/mdtest/function/return_type.md +++ b/crates/ty_python_semantic/resources/mdtest/function/return_type.md @@ -478,7 +478,6 @@ class D(C): def f(self): return 2 # TODO: This should be an invalid-override error. - # If the override is invalid, the type of the method should be that of the base class method. def g(self, x: int): return 2 # A strict application of the Liskov Substitution Principle would consider From c76c0f88b46e9e7abe1f19151ebe51b69aa32229 Mon Sep 17 00:00:00 2001 From: Shunsuke Shibayama Date: Tue, 9 Sep 2025 21:53:13 +0900 Subject: [PATCH 060/105] Update return_type.md --- .../resources/mdtest/function/return_type.md | 12 ++++++++++++ 1 file changed, 12 insertions(+) diff --git a/crates/ty_python_semantic/resources/mdtest/function/return_type.md b/crates/ty_python_semantic/resources/mdtest/function/return_type.md index 41e658705bfeb..66bd11588a8d3 100644 --- a/crates/ty_python_semantic/resources/mdtest/function/return_type.md +++ b/crates/ty_python_semantic/resources/mdtest/function/return_type.md @@ -300,6 +300,18 @@ def generator(): # TODO: Should be `Generator[Literal[1, 2], Any, None]` reveal_type(generator()) # revealed: Unknown + +async def async_generator(): + yield + +# TODO: Should be `AsyncGenerator[None, Any]` +reveal_type(async_generator()) # revealed: Unknown + +async def coroutine(): + return + +# TODO: Should be `CoroutineType[Any, Any, None]` +reveal_type(coroutine()) # revealed: Unknown ``` The return type of a recursive function is also inferred. When the return type inference would From 312d97af871d9e30adab9f98237cde06717eb88c Mon Sep 17 00:00:00 2001 From: Shunsuke Shibayama Date: Tue, 9 Sep 2025 22:06:37 +0900 Subject: [PATCH 061/105] Update return_type.md --- .../ty_python_semantic/resources/mdtest/function/return_type.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/crates/ty_python_semantic/resources/mdtest/function/return_type.md b/crates/ty_python_semantic/resources/mdtest/function/return_type.md index 66bd11588a8d3..31844db069421 100644 --- a/crates/ty_python_semantic/resources/mdtest/function/return_type.md +++ b/crates/ty_python_semantic/resources/mdtest/function/return_type.md @@ -315,7 +315,7 @@ reveal_type(coroutine()) # revealed: Unknown ``` The return type of a recursive function is also inferred. When the return type inference would -diverge, it is truncated and replaced with the type `Unknown`. +diverge, it is truncated and replaced with the special dynamic type `Divergent`. ```py def fibonacci(n: int): From 69953ff67d57a15500511ac9c53a6a53271e9d21 Mon Sep 17 00:00:00 2001 From: Shunsuke Shibayama Date: Fri, 12 Sep 2025 17:38:09 +0900 Subject: [PATCH 062/105] normalize recursive types --- .../resources/mdtest/function/return_type.md | 26 ++++- crates/ty_python_semantic/src/types.rs | 97 +++++++++++++++++-- .../src/types/class_base.rs | 2 +- crates/ty_python_semantic/src/types/cyclic.rs | 25 ++++- crates/ty_python_semantic/src/types/infer.rs | 44 ++++----- .../src/types/subclass_of.rs | 2 +- 6 files changed, 156 insertions(+), 40 deletions(-) diff --git a/crates/ty_python_semantic/resources/mdtest/function/return_type.md b/crates/ty_python_semantic/resources/mdtest/function/return_type.md index 31844db069421..59b20d1dafa6e 100644 --- a/crates/ty_python_semantic/resources/mdtest/function/return_type.md +++ b/crates/ty_python_semantic/resources/mdtest/function/return_type.md @@ -317,6 +317,11 @@ reveal_type(coroutine()) # revealed: Unknown The return type of a recursive function is also inferred. When the return type inference would diverge, it is truncated and replaced with the special dynamic type `Divergent`. +```toml +[environment] +python-version = "3.12" +``` + ```py def fibonacci(n: int): if n == 0: @@ -357,13 +362,26 @@ def divergent(value): else: return None -# tuple[tuple[tuple[...] | None] | None] | None => tuple[Unknown] | None -reveal_type(divergent((1,))) # revealed: None | Divergent +# tuple[tuple[tuple[...] | None] | None] | None => tuple[Divergent] | None +reveal_type(divergent((1,))) # revealed: tuple[Divergent] | None def call_divergent(x: int): return (divergent((1, 2, 3)), x) -reveal_type(call_divergent(1)) # revealed: tuple[None | Divergent, int] +reveal_type(call_divergent(1)) # revealed: tuple[tuple[Divergent] | None, int] + +def list1[T](x: T) -> list[T]: + return [x] + +def divergent2(value): + if type(value) is tuple: + return (divergent2(value[0]),) + elif type(value) is list: + return list1(divergent2(value[0])) + else: + return None + +reveal_type(divergent2((1,))) # revealed: tuple[Divergent] | list[Divergent] | None def tuple_obj(cond: bool): if cond: @@ -381,7 +399,7 @@ def get_non_empty(node): return node return None -reveal_type(get_non_empty(None)) # revealed: None | Divergent +reveal_type(get_non_empty(None)) # revealed: (Divergent & ~None) | None def nested_scope(): def inner(): diff --git a/crates/ty_python_semantic/src/types.rs b/crates/ty_python_semantic/src/types.rs index 23832b3300969..d02b960b21e8c 100644 --- a/crates/ty_python_semantic/src/types.rs +++ b/crates/ty_python_semantic/src/types.rs @@ -222,10 +222,56 @@ pub(crate) type TryBoolVisitor<'db> = CycleDetector, Result>>; pub(crate) struct TryBool; +#[derive(Default)] +pub(crate) enum NormalizationKind<'db> { + #[default] + Normal, + Recursive(Type<'db>), +} + +impl NormalizationKind<'_> { + pub(crate) fn is_recursive(&self) -> bool { + matches!(self, Self::Recursive(_)) + } +} + /// A [`TypeTransformer`] that is used in `normalized` methods. -pub(crate) type NormalizedVisitor<'db> = TypeTransformer<'db, Normalized>; +#[derive(Default)] +pub(crate) struct NormalizedVisitor<'db> { + transformer: TypeTransformer<'db, Normalized>, + /// If this is [`NormalizationKind::Recursive`], calling [`Type::normalized_impl`] will normalize the recursive type. + /// A recursive type here means a type that contains a `Divergent` type. + /// Normalizing recursive types allows recursive type inference for divergent functions to converge. + kind: NormalizationKind<'db>, +} pub(crate) struct Normalized; +impl<'db> NormalizedVisitor<'db> { + fn recursive(self, div: Type<'db>) -> Self { + debug_assert!(matches!(div, Type::Dynamic(DynamicType::Divergent(_)))); + Self { + transformer: self.transformer, + kind: NormalizationKind::Recursive(div), + } + } + + fn visit(&self, item: Type<'db>, func: impl FnOnce() -> Type<'db>) -> Type<'db> { + self.transformer.visit(item, func) + } + + fn visit_no_shift(&self, item: Type<'db>, func: impl FnOnce() -> Type<'db>) -> Type<'db> { + self.transformer.visit_no_shift(item, func) + } + + fn level(&self) -> usize { + self.transformer.level() + } + + fn is_recursive(&self) -> bool { + self.kind.is_recursive() + } +} + /// How a generic type has been specialized. /// /// This matters only if there is at least one invariant type parameter. @@ -1166,8 +1212,21 @@ impl<'db> Type<'db> { #[must_use] pub(crate) fn normalized_impl(self, db: &'db dyn Db, visitor: &NormalizedVisitor<'db>) -> Self { + if let NormalizationKind::Recursive(div) = visitor.kind { + if visitor.level() == 0 && self == div { + // int | Divergent = int | (int | (int | ...)) = int + return Type::Never; + } else if visitor.level() >= 1 && self.has_divergent_type(db, div) { + // G[G[Divergent]] = G[Divergent] + return div; + } + } match self { - Type::Union(union) => visitor.visit(self, || union.normalized_impl(db, visitor)), + Type::Union(union) => { + // As explained above, `Divergent` in a union type does not mean true divergence, + // so we normalize the type while keeping the nesting level the same. + visitor.visit_no_shift(self, || union.normalized_impl(db, visitor)) + } Type::Intersection(intersection) => visitor.visit(self, || { Type::Intersection(intersection.normalized_impl(db, visitor)) }), @@ -1213,7 +1272,7 @@ impl<'db> Type<'db> { Type::TypeIs(type_is) => visitor.visit(self, || { type_is.with_type(db, type_is.return_type(db).normalized_impl(db, visitor)) }), - Type::Dynamic(dynamic) => Type::Dynamic(dynamic.normalized()), + Type::Dynamic(dynamic) => Type::Dynamic(dynamic.normalized(visitor.is_recursive())), Type::EnumLiteral(enum_literal) if is_single_member_enum(db, enum_literal.enum_class(db)) => { @@ -7022,7 +7081,10 @@ pub enum DynamicType<'db> { } impl DynamicType<'_> { - fn normalized(self) -> Self { + fn normalized(self, is_recursive: bool) -> Self { + if is_recursive { + return self; + } if matches!(self, Self::Divergent(_)) { self } else { @@ -10208,7 +10270,7 @@ impl<'db> UnionType<'db> { .map(|ty| ty.normalized_impl(db, visitor)) .fold( UnionBuilder::new(db) - .order_elements(true) + .order_elements(!visitor.is_recursive()) .unpack_aliases(true), UnionBuilder::add, ) @@ -10580,7 +10642,9 @@ pub enum SuperOwnerKind<'db> { impl<'db> SuperOwnerKind<'db> { fn normalized_impl(self, db: &'db dyn Db, visitor: &NormalizedVisitor<'db>) -> Self { match self { - SuperOwnerKind::Dynamic(dynamic) => SuperOwnerKind::Dynamic(dynamic.normalized()), + SuperOwnerKind::Dynamic(dynamic) => { + SuperOwnerKind::Dynamic(dynamic.normalized(visitor.is_recursive())) + } SuperOwnerKind::Class(class) => { SuperOwnerKind::Class(class.normalized_impl(db, visitor)) } @@ -11120,6 +11184,27 @@ pub(crate) mod tests { let union = UnionType::from_elements(&db, [KnownClass::Object.to_instance(&db), div]); assert_eq!(union.display(&db).to_string(), "object"); + let visitor = NormalizedVisitor::default().recursive(div); + let recursice = UnionType::from_elements( + &db, + [ + KnownClass::List.to_specialized_instance(&db, [div]), + Type::none(&db), + ], + ); + let nested_rec = KnownClass::List.to_specialized_instance(&db, [recursice]); + assert_eq!( + nested_rec.display(&db).to_string(), + "list[list[Divergent] | None]" + ); + let normalized = nested_rec.normalized_impl(&db, &visitor); + assert_eq!(normalized.display(&db).to_string(), "list[Divergent]"); + + let union = UnionType::from_elements(&db, [div, KnownClass::Int.to_instance(&db)]); + assert_eq!(union.display(&db).to_string(), "Divergent | int"); + let normalized = union.normalized_impl(&db, &visitor); + assert_eq!(normalized.display(&db).to_string(), "int"); + // The same can be said about intersections for the `Never` type. let intersection = IntersectionBuilder::new(&db) .add_positive(Type::Never) diff --git a/crates/ty_python_semantic/src/types/class_base.rs b/crates/ty_python_semantic/src/types/class_base.rs index 1306651c5c9fc..92acf23b9e1df 100644 --- a/crates/ty_python_semantic/src/types/class_base.rs +++ b/crates/ty_python_semantic/src/types/class_base.rs @@ -37,7 +37,7 @@ impl<'db> ClassBase<'db> { pub(crate) fn normalized_impl(self, db: &'db dyn Db, visitor: &NormalizedVisitor<'db>) -> Self { match self { - Self::Dynamic(dynamic) => Self::Dynamic(dynamic.normalized()), + Self::Dynamic(dynamic) => Self::Dynamic(dynamic.normalized(visitor.is_recursive())), Self::Class(class) => Self::Class(class.normalized_impl(db, visitor)), Self::Protocol | Self::Generic | Self::TypedDict => self, } diff --git a/crates/ty_python_semantic/src/types/cyclic.rs b/crates/ty_python_semantic/src/types/cyclic.rs index 2d006a977df88..60798be0606a7 100644 --- a/crates/ty_python_semantic/src/types/cyclic.rs +++ b/crates/ty_python_semantic/src/types/cyclic.rs @@ -58,6 +58,10 @@ pub(crate) struct CycleDetector { /// sort-of defeat the point of a cache if we did!) cache: RefCell>, + /// The nesting level of the `visit` method. + /// This is necessary for normalizing recursive types. + level: RefCell, + fallback: R, _tag: PhantomData, @@ -68,12 +72,13 @@ impl CycleDetector { CycleDetector { seen: RefCell::new(FxIndexSet::default()), cache: RefCell::new(FxHashMap::default()), + level: RefCell::new(0), fallback, _tag: PhantomData, } } - pub(crate) fn visit(&self, item: T, func: impl FnOnce() -> R) -> R { + fn visit_impl(&self, shift_level: bool, item: T, func: impl FnOnce() -> R) -> R { if let Some(val) = self.cache.borrow().get(&item) { return val.clone(); } @@ -83,12 +88,30 @@ impl CycleDetector { return self.fallback.clone(); } + if shift_level { + *self.level.borrow_mut() += 1; + } let ret = func(); + if shift_level { + *self.level.borrow_mut() -= 1; + } self.seen.borrow_mut().pop(); self.cache.borrow_mut().insert(item, ret.clone()); ret } + + pub(crate) fn visit(&self, item: T, func: impl FnOnce() -> R) -> R { + self.visit_impl(true, item, func) + } + + pub(crate) fn visit_no_shift(&self, item: T, func: impl FnOnce() -> R) -> R { + self.visit_impl(false, item, func) + } + + pub(crate) fn level(&self) -> usize { + *self.level.borrow() + } } impl Default for CycleDetector { diff --git a/crates/ty_python_semantic/src/types/infer.rs b/crates/ty_python_semantic/src/types/infer.rs index f3e310bd338c5..abaede6dea898 100644 --- a/crates/ty_python_semantic/src/types/infer.rs +++ b/crates/ty_python_semantic/src/types/infer.rs @@ -51,7 +51,8 @@ use crate::semantic_index::{SemanticIndex, semantic_index, use_def_map}; use crate::types::diagnostic::TypeCheckDiagnostics; use crate::types::unpacker::{UnpackResult, Unpacker}; use crate::types::{ - ClassLiteral, DivergentType, DynamicType, Truthiness, Type, TypeAndQualifiers, UnionBuilder, + ClassLiteral, DivergentType, DynamicType, NormalizedVisitor, Truthiness, Type, + TypeAndQualifiers, UnionBuilder, }; use crate::unpack::Unpack; use builder::TypeInferenceBuilder; @@ -468,6 +469,7 @@ impl<'db> ScopeInference<'db> { if let Some(fallback_type) = self.fallback_type() { union = union.add(fallback_type); } + let visitor = NormalizedVisitor::default().recursive(div); // Here, we use the dynamic type `Divergent` to detect divergent type inference and ensure that we obtain finite results. // For example, consider the following recursive function: // ```py @@ -479,33 +481,16 @@ impl<'db> ScopeInference<'db> { // ``` // If we try to infer the return type of this function naively, we will get `tuple[tuple[tuple[...] | None] | None] | None`, which never converges. // So, when we detect a cycle, we set the cycle initial type to `Divergent`. Then the type obtained in the first cycle is `tuple[Divergent] | None`. - // Next, if there is a type containing `Divergent`, we replace it with the `Divergent` type itself. - // All types containing `Divergent` are flattened in the next cycle, resulting in a convergence of the return type in finite cycles. + // Let's call such a type containing `Divergent` a "recursive type". + // Next, if there is a type containing a recursive type (let's call this a nested recursive type), we replace the inner recursive type with the `Divergent` type. + // All recursive types are flattened in the next cycle, resulting in a convergence of the return type in finite cycles. // 0th: Divergent - // 1st: tuple[Divergent] | None => Divergent | None - // 2nd: tuple[Divergent | None] | None => Divergent | None - let mut union_add = |ty: Type<'db>| { - if ty == div { - // `Divergent` appearing in a union does not mean true divergence, so it can be removed. - } else if ty.has_divergent_type(db, div) { - if let Type::Union(union_ty) = ty { - let union_ty = union_ty.filter(db, |ty| **ty != div); - if union_ty.has_divergent_type(db, div) { - union = std::mem::replace(&mut union, UnionBuilder::new(db)).add(div); - } else { - union = std::mem::replace(&mut union, UnionBuilder::new(db)).add(union_ty); - } - } else { - union = std::mem::replace(&mut union, UnionBuilder::new(db)).add(div); - } - } else { - union = std::mem::replace(&mut union, UnionBuilder::new(db)).add(ty); - } - }; + // 1st: tuple[Divergent] | None + // 2nd: tuple[tuple[Divergent] | None] | None => tuple[Divergent] | None let previous_type = callee_ty.infer_return_type(db).unwrap(); // In fixed-point iteration of return type inference, the return type must be monotonically widened and not "oscillate". // Here, monotonicity is guaranteed by pre-unioning the type of the previous iteration into the current result. - union_add(previous_type); + union = union.add(previous_type.normalized_impl(db, &visitor)); let Some(extra) = &self.extra else { unreachable!( @@ -516,18 +501,23 @@ impl<'db> ScopeInference<'db> { let ty = returnee.map_or(Type::none(db), |expression| { self.expression_type(expression) }); - union_add(ty); + union = union.add(ty.normalized_impl(db, &visitor)); } let use_def = use_def_map(db, self.scope); if use_def.can_implicitly_return_none(db) { - union_add(Type::none(db)); + union = union.add(Type::none(db)); } if let Type::BoundMethod(method_ty) = callee_ty { // If the method is not final and the typing is implicit, the inferred return type should be unioned with `Unknown`. // If any method in a base class does not have an annotated return type, `base_return_type` will include `Unknown`. // On the other hand, if the return types of all methods in the base classes are annotated, there is no need to include `Unknown`. if !method_ty.is_final(db) { - union_add(method_ty.base_return_type(db).unwrap_or(Type::unknown())); + union = union.add( + method_ty + .base_return_type(db) + .unwrap_or(Type::unknown()) + .normalized_impl(db, &visitor), + ); } } diff --git a/crates/ty_python_semantic/src/types/subclass_of.rs b/crates/ty_python_semantic/src/types/subclass_of.rs index 35ecbac7ee5aa..f7d8dd4891517 100644 --- a/crates/ty_python_semantic/src/types/subclass_of.rs +++ b/crates/ty_python_semantic/src/types/subclass_of.rs @@ -250,7 +250,7 @@ impl<'db> SubclassOfInner<'db> { pub(crate) fn normalized_impl(self, db: &'db dyn Db, visitor: &NormalizedVisitor<'db>) -> Self { match self { Self::Class(class) => Self::Class(class.normalized_impl(db, visitor)), - Self::Dynamic(dynamic) => Self::Dynamic(dynamic.normalized()), + Self::Dynamic(dynamic) => Self::Dynamic(dynamic.normalized(visitor.is_recursive())), } } From b750358bc7394f48c6f7c76c327df6d202b0b585 Mon Sep 17 00:00:00 2001 From: Shunsuke Shibayama Date: Fri, 12 Sep 2025 20:48:22 +0900 Subject: [PATCH 063/105] remove `CycleRecovery` and set the cycle initial value to `Divergent` --- .../resources/mdtest/attributes.md | 3 +- .../resources/mdtest/cycle.md | 6 +- crates/ty_python_semantic/src/types/infer.rs | 102 +++++------------- .../src/types/infer/builder.rs | 83 +++++++------- 4 files changed, 70 insertions(+), 124 deletions(-) diff --git a/crates/ty_python_semantic/resources/mdtest/attributes.md b/crates/ty_python_semantic/resources/mdtest/attributes.md index 1ef154ee3600d..43b1b568c4dbb 100644 --- a/crates/ty_python_semantic/resources/mdtest/attributes.md +++ b/crates/ty_python_semantic/resources/mdtest/attributes.md @@ -2223,7 +2223,8 @@ class C: def copy(self, other: "C"): self.x = other.x -reveal_type(C().x) # revealed: Unknown | Literal[1] +# TODO: Should be `Unknown | Literal[1]` +reveal_type(C().x) # revealed: Unknown | Literal[1] | Divergent ``` If the only assignment to a name is cyclic, we just infer `Unknown` for that attribute: diff --git a/crates/ty_python_semantic/resources/mdtest/cycle.md b/crates/ty_python_semantic/resources/mdtest/cycle.md index 0bd3b5b2a6df4..da12590ff07cf 100644 --- a/crates/ty_python_semantic/resources/mdtest/cycle.md +++ b/crates/ty_python_semantic/resources/mdtest/cycle.md @@ -28,6 +28,8 @@ class Point: self.x, self.y = other.x, other.y p = Point() -reveal_type(p.x) # revealed: Unknown | int -reveal_type(p.y) # revealed: Unknown | int +# TODO: should be `Unknown | int` +reveal_type(p.x) # revealed: Unknown | int | Divergent +# TODO: should be `Unknown | int` +reveal_type(p.y) # revealed: Unknown | int | Divergent ``` diff --git a/crates/ty_python_semantic/src/types/infer.rs b/crates/ty_python_semantic/src/types/infer.rs index bd9d9463e99ac..de3552fb453dd 100644 --- a/crates/ty_python_semantic/src/types/infer.rs +++ b/crates/ty_python_semantic/src/types/infer.rs @@ -210,22 +210,13 @@ fn infer_expression_types_impl<'db>( .finish_expression() } -/// How many fixpoint iterations to allow before falling back to Divergent type. -const ITERATIONS_BEFORE_FALLBACK: u32 = 10; - fn expression_cycle_recover<'db>( - db: &'db dyn Db, + _db: &'db dyn Db, _value: &ExpressionInference<'db>, - count: u32, - input: InferExpression<'db>, + _count: u32, + _input: InferExpression<'db>, ) -> salsa::CycleRecoveryAction> { - if count == ITERATIONS_BEFORE_FALLBACK { - salsa::CycleRecoveryAction::Fallback(ExpressionInference::cycle_fallback( - input.expression(db).scope(db), - )) - } else { - salsa::CycleRecoveryAction::Iterate - } + salsa::CycleRecoveryAction::Iterate } fn expression_cycle_initial<'db>( @@ -480,35 +471,6 @@ impl<'db> InferenceRegion<'db> { } } -#[derive(Debug, Clone, Copy, Eq, PartialEq, get_size2::GetSize, salsa::Update)] -enum CycleRecovery<'db> { - /// An initial-value for fixpoint iteration; all types are `Type::Never`. - Initial, - /// A divergence-fallback value for fixpoint iteration; all types are `Divergent`. - Divergent(ScopeId<'db>), -} - -impl<'db> CycleRecovery<'db> { - fn merge(self, other: Option>) -> Self { - if let Some(other) = other { - match (self, other) { - // It's important here that we keep the scope of `self` if merging two `Divergent`. - (Self::Divergent(scope), _) | (_, Self::Divergent(scope)) => Self::Divergent(scope), - _ => Self::Initial, - } - } else { - self - } - } - - fn fallback_type(self) -> Type<'db> { - match self { - Self::Initial => Type::Never, - Self::Divergent(scope) => Type::divergent(scope), - } - } -} - /// The inferred types for a scope region. #[derive(Debug, Eq, PartialEq, salsa::Update, get_size2::GetSize)] pub(crate) struct ScopeInference<'db> { @@ -516,15 +478,15 @@ pub(crate) struct ScopeInference<'db> { expressions: FxHashMap>, /// The extra data that is only present for few inference regions. - extra: Option>>, + extra: Option>, scope: ScopeId<'db>, } #[derive(Debug, Eq, PartialEq, get_size2::GetSize, salsa::Update, Default)] -struct ScopeInferenceExtra<'db> { - /// Is this a cycle-recovery inference result, and if so, what kind? - cycle_recovery: Option>, +struct ScopeInferenceExtra { + /// Whether to use the fallback type for missing expressions/bindings/declarations or recursive type inference. + cycle_fallback: bool, /// The diagnostics for this region. diagnostics: TypeCheckDiagnostics, @@ -539,7 +501,7 @@ impl<'db> ScopeInference<'db> { fn cycle_initial(scope: ScopeId<'db>) -> Self { Self { extra: Some(Box::new(ScopeInferenceExtra { - cycle_recovery: Some(CycleRecovery::Initial), + cycle_fallback: true, ..ScopeInferenceExtra::default() })), expressions: FxHashMap::default(), @@ -566,15 +528,10 @@ impl<'db> ScopeInference<'db> { .or_else(|| self.fallback_type()) } - fn is_cycle_callback(&self) -> bool { + fn fallback_type(&self) -> Option> { self.extra .as_ref() - .and_then(|extra| extra.cycle_recovery) - .is_some() - } - - fn fallback_type(&self) -> Option> { - self.is_cycle_callback() + .is_some_and(|extra| extra.cycle_fallback) .then_some(Type::Dynamic(DynamicType::Divergent(DivergentType { scope: self.scope, }))) @@ -681,8 +638,8 @@ pub(crate) struct DefinitionInference<'db> { #[derive(Debug, Eq, PartialEq, get_size2::GetSize, salsa::Update, Default)] struct DefinitionInferenceExtra<'db> { - /// Is this a cycle-recovery inference result, and if so, what kind? - cycle_recovery: Option>, + /// Whether to use the fallback type for missing expressions/bindings/declarations or recursive type inference. + cycle_fallback: bool, /// The definitions that are deferred. deferred: Box<[Definition<'db>]>, @@ -705,7 +662,7 @@ impl<'db> DefinitionInference<'db> { #[cfg(debug_assertions)] scope, extra: Some(Box::new(DefinitionInferenceExtra { - cycle_recovery: Some(CycleRecovery::Initial), + cycle_fallback: true, ..DefinitionInferenceExtra::default() })), } @@ -777,7 +734,10 @@ impl<'db> DefinitionInference<'db> { fn fallback_type(&self) -> Option> { self.extra .as_ref() - .and_then(|extra| extra.cycle_recovery.map(CycleRecovery::fallback_type)) + .is_some_and(|extra| extra.cycle_fallback) + .then_some(Type::Dynamic(DynamicType::Divergent(DivergentType { + scope: self.scope, + }))) } pub(crate) fn undecorated_type(&self) -> Option> { @@ -809,8 +769,8 @@ struct ExpressionInferenceExtra<'db> { /// The diagnostics for this region. diagnostics: TypeCheckDiagnostics, - /// Is this a cycle recovery inference result, and if so, what kind? - cycle_recovery: Option>, + /// Whether to use the fallback type for missing expressions/bindings/declarations or recursive type inference. + cycle_fallback: bool, /// `true` if all places in this expression are definitely bound all_definitely_bound: bool, @@ -821,22 +781,7 @@ impl<'db> ExpressionInference<'db> { let _ = scope; Self { extra: Some(Box::new(ExpressionInferenceExtra { - cycle_recovery: Some(CycleRecovery::Initial), - all_definitely_bound: true, - ..ExpressionInferenceExtra::default() - })), - expressions: FxHashMap::default(), - - #[cfg(debug_assertions)] - scope, - } - } - - fn cycle_fallback(scope: ScopeId<'db>) -> Self { - let _ = scope; - Self { - extra: Some(Box::new(ExpressionInferenceExtra { - cycle_recovery: Some(CycleRecovery::Divergent(scope)), + cycle_fallback: true, all_definitely_bound: true, ..ExpressionInferenceExtra::default() })), @@ -865,7 +810,10 @@ impl<'db> ExpressionInference<'db> { fn fallback_type(&self) -> Option> { self.extra .as_ref() - .and_then(|extra| extra.cycle_recovery.map(CycleRecovery::fallback_type)) + .is_some_and(|extra| extra.cycle_fallback) + .then_some(Type::Dynamic(DynamicType::Divergent(DivergentType { + scope: self.scope, + }))) } /// Returns true if all places in this expression are definitely bound. diff --git a/crates/ty_python_semantic/src/types/infer/builder.rs b/crates/ty_python_semantic/src/types/infer/builder.rs index f1f2e9c2841cc..71e987dce0b63 100644 --- a/crates/ty_python_semantic/src/types/infer/builder.rs +++ b/crates/ty_python_semantic/src/types/infer/builder.rs @@ -9,10 +9,10 @@ use ruff_text_size::{Ranged, TextRange}; use rustc_hash::{FxHashMap, FxHashSet}; use super::{ - CycleRecovery, DefinitionInference, DefinitionInferenceExtra, ExpressionInference, - ExpressionInferenceExtra, InferenceRegion, ScopeInference, ScopeInferenceExtra, - infer_deferred_types, infer_definition_types, infer_expression_types, - infer_same_file_expression_type, infer_scope_types, infer_unpack_types, + DefinitionInference, DefinitionInferenceExtra, ExpressionInference, ExpressionInferenceExtra, + InferenceRegion, ScopeInference, ScopeInferenceExtra, infer_deferred_types, + infer_definition_types, infer_expression_types, infer_same_file_expression_type, + infer_scope_types, infer_unpack_types, }; use crate::module_name::{ModuleName, ModuleNameResolutionError}; use crate::module_resolver::{ @@ -85,12 +85,13 @@ use crate::types::typed_dict::{ }; use crate::types::visitor::any_over_type; use crate::types::{ - CallDunderError, CallableType, ClassLiteral, ClassType, DataclassParams, DynamicType, - IntersectionBuilder, IntersectionType, KnownClass, KnownInstanceType, MemberLookupPolicy, - MetaclassCandidate, PEP695TypeAliasType, Parameter, ParameterForm, Parameters, SpecialFormType, - SubclassOfType, TrackedConstraintSet, Truthiness, Type, TypeAliasType, TypeAndQualifiers, - TypeContext, TypeQualifiers, TypeVarBoundOrConstraintsEvaluation, TypeVarDefaultEvaluation, - TypeVarInstance, TypeVarKind, UnionBuilder, UnionType, binding_type, todo_type, + CallDunderError, CallableType, ClassLiteral, ClassType, DataclassParams, DivergentType, + DynamicType, IntersectionBuilder, IntersectionType, KnownClass, KnownInstanceType, + MemberLookupPolicy, MetaclassCandidate, PEP695TypeAliasType, Parameter, ParameterForm, + Parameters, SpecialFormType, SubclassOfType, TrackedConstraintSet, Truthiness, Type, + TypeAliasType, TypeAndQualifiers, TypeContext, TypeQualifiers, + TypeVarBoundOrConstraintsEvaluation, TypeVarDefaultEvaluation, TypeVarInstance, TypeVarKind, + UnionBuilder, UnionType, binding_type, todo_type, }; use crate::types::{ClassBase, add_inferred_python_version_hint_to_diagnostic}; use crate::unpack::{EvaluationMode, UnpackPosition}; @@ -252,8 +253,8 @@ pub(super) struct TypeInferenceBuilder<'db, 'ast> { /// For function definitions, the undecorated type of the function. undecorated_type: Option>, - /// Did we merge in a sub-region with a cycle-recovery fallback, and if so, what kind? - cycle_recovery: Option>, + /// Whether to use the fallback type for missing expressions/bindings/declarations or recursive type inference. + cycle_fallback: bool, /// `true` if all places in this expression are definitely bound all_definitely_bound: bool, @@ -289,20 +290,16 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> { typevar_binding_context: None, deferred: VecSet::default(), undecorated_type: None, - cycle_recovery: None, + cycle_fallback: false, all_definitely_bound: true, } } - fn extend_cycle_recovery(&mut self, other_recovery: Option>) { - match &mut self.cycle_recovery { - Some(recovery) => *recovery = recovery.merge(other_recovery), - recovery @ None => *recovery = other_recovery, - } - } - fn fallback_type(&self) -> Option> { - self.cycle_recovery.map(CycleRecovery::fallback_type) + self.cycle_fallback + .then_some(Type::Dynamic(DynamicType::Divergent(DivergentType { + scope: self.scope, + }))) } fn extend_definition(&mut self, inference: &DefinitionInference<'db>) { @@ -317,7 +314,7 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> { } if let Some(extra) = &inference.extra { - self.extend_cycle_recovery(extra.cycle_recovery); + self.cycle_fallback |= extra.cycle_fallback; self.context.extend(&extra.diagnostics); self.deferred.extend(extra.deferred.iter().copied()); } @@ -335,7 +332,7 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> { if let Some(extra) = &inference.extra { self.context.extend(&extra.diagnostics); - self.extend_cycle_recovery(extra.cycle_recovery); + self.cycle_fallback |= extra.cycle_fallback; if !matches!(self.region, InferenceRegion::Scope(..)) { self.bindings.extend(extra.bindings.iter().copied()); @@ -8881,7 +8878,7 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> { bindings, declarations, deferred, - cycle_recovery, + cycle_fallback, all_definitely_bound, // Ignored; only relevant to definition regions @@ -8909,7 +8906,7 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> { ); let extra = - (cycle_recovery.is_some() || !bindings.is_empty() || !diagnostics.is_empty() || !all_definitely_bound).then(|| { + (cycle_fallback || !bindings.is_empty() || !diagnostics.is_empty() || !all_definitely_bound).then(|| { if bindings.len() > 20 { tracing::debug!( "Inferred expression region `{:?}` contains {} bindings. Lookups by linear scan might be slow.", @@ -8921,7 +8918,7 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> { Box::new(ExpressionInferenceExtra { bindings: bindings.into_boxed_slice(), diagnostics, - cycle_recovery, + cycle_fallback, all_definitely_bound, }) }); @@ -8946,7 +8943,7 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> { bindings, declarations, deferred, - cycle_recovery, + cycle_fallback, undecorated_type, all_definitely_bound: _, // builder only state @@ -8962,12 +8959,12 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> { let diagnostics = context.finish(); let extra = (!diagnostics.is_empty() - || cycle_recovery.is_some() + || cycle_fallback || undecorated_type.is_some() || !deferred.is_empty()) .then(|| { Box::new(DefinitionInferenceExtra { - cycle_recovery, + cycle_fallback, deferred: deferred.into_boxed_slice(), diagnostics, undecorated_type, @@ -9010,7 +9007,7 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> { context, mut expressions, scope, - cycle_recovery, + cycle_fallback, // Ignored, because scope types are never extended into other scopes. deferred: _, @@ -9032,20 +9029,18 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> { let diagnostics = context.finish(); - let extra = (!diagnostics.is_empty() - || cycle_recovery.is_some() - || scope.is_non_lambda_function(db)) - .then(|| { - let returnees = returnees - .into_iter() - .map(|returnee| returnee.expression) - .collect(); - Box::new(ScopeInferenceExtra { - cycle_recovery, - diagnostics, - returnees, - }) - }); + let extra = (!diagnostics.is_empty() || cycle_fallback || scope.is_non_lambda_function(db)) + .then(|| { + let returnees = returnees + .into_iter() + .map(|returnee| returnee.expression) + .collect(); + Box::new(ScopeInferenceExtra { + cycle_fallback, + diagnostics, + returnees, + }) + }); expressions.shrink_to_fit(); From 8a19b05d74fb094075d65bb65d784bfc3f54a5cb Mon Sep 17 00:00:00 2001 From: Shunsuke Shibayama Date: Fri, 12 Sep 2025 20:55:51 +0900 Subject: [PATCH 064/105] don't sort types during recursive type normalization --- crates/ty_python_semantic/src/types.rs | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/crates/ty_python_semantic/src/types.rs b/crates/ty_python_semantic/src/types.rs index 3fb1f94cccefd..6722f081c3d19 100644 --- a/crates/ty_python_semantic/src/types.rs +++ b/crates/ty_python_semantic/src/types.rs @@ -10365,7 +10365,9 @@ impl<'db> IntersectionType<'db> { .map(|ty| ty.normalized_impl(db, visitor)) .collect(); - elements.sort_unstable_by(|l, r| union_or_intersection_elements_ordering(db, l, r)); + if !visitor.is_recursive() { + elements.sort_unstable_by(|l, r| union_or_intersection_elements_ordering(db, l, r)); + } elements } From 35ffbfc649b976880e847bd59deb13d8cd8f1896 Mon Sep 17 00:00:00 2001 From: Shunsuke Shibayama Date: Fri, 12 Sep 2025 21:08:01 +0900 Subject: [PATCH 065/105] remove the `debug_assertions` attributes --- crates/ty_python_semantic/src/types/infer.rs | 5 ----- crates/ty_python_semantic/src/types/infer/builder.rs | 4 ---- 2 files changed, 9 deletions(-) diff --git a/crates/ty_python_semantic/src/types/infer.rs b/crates/ty_python_semantic/src/types/infer.rs index de3552fb453dd..51cbb18bbac09 100644 --- a/crates/ty_python_semantic/src/types/infer.rs +++ b/crates/ty_python_semantic/src/types/infer.rs @@ -616,7 +616,6 @@ pub(crate) struct DefinitionInference<'db> { expressions: FxHashMap>, /// The scope this region is part of. - #[cfg(debug_assertions)] scope: ScopeId<'db>, /// The types of every binding in this region. @@ -659,7 +658,6 @@ impl<'db> DefinitionInference<'db> { expressions: FxHashMap::default(), bindings: Box::default(), declarations: Box::default(), - #[cfg(debug_assertions)] scope, extra: Some(Box::new(DefinitionInferenceExtra { cycle_fallback: true, @@ -754,7 +752,6 @@ pub(crate) struct ExpressionInference<'db> { extra: Option>>, /// The scope this region is part of. - #[cfg(debug_assertions)] scope: ScopeId<'db>, } @@ -778,7 +775,6 @@ struct ExpressionInferenceExtra<'db> { impl<'db> ExpressionInference<'db> { fn cycle_initial(scope: ScopeId<'db>) -> Self { - let _ = scope; Self { extra: Some(Box::new(ExpressionInferenceExtra { cycle_fallback: true, @@ -787,7 +783,6 @@ impl<'db> ExpressionInference<'db> { })), expressions: FxHashMap::default(), - #[cfg(debug_assertions)] scope, } } diff --git a/crates/ty_python_semantic/src/types/infer/builder.rs b/crates/ty_python_semantic/src/types/infer/builder.rs index 71e987dce0b63..afa26699b6299 100644 --- a/crates/ty_python_semantic/src/types/infer/builder.rs +++ b/crates/ty_python_semantic/src/types/infer/builder.rs @@ -303,7 +303,6 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> { } fn extend_definition(&mut self, inference: &DefinitionInference<'db>) { - #[cfg(debug_assertions)] assert_eq!(self.scope, inference.scope); self.expressions.extend(inference.expressions.iter()); @@ -321,7 +320,6 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> { } fn extend_expression(&mut self, inference: &ExpressionInference<'db>) { - #[cfg(debug_assertions)] assert_eq!(self.scope, inference.scope); self.extend_expression_unchecked(inference); @@ -8928,7 +8926,6 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> { ExpressionInference { expressions, extra, - #[cfg(debug_assertions)] scope, } } @@ -8991,7 +8988,6 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> { DefinitionInference { expressions, - #[cfg(debug_assertions)] scope, bindings: bindings.into_boxed_slice(), declarations: declarations.into_boxed_slice(), From 23e0cfef59e255f1507a8735c1f1e90bc891c01a Mon Sep 17 00:00:00 2001 From: Shunsuke Shibayama Date: Sat, 13 Sep 2025 04:31:17 +0900 Subject: [PATCH 066/105] do nothing except normalization of recursive types when `visitor.is_recursive()` --- crates/ty_python_semantic/src/types.rs | 8 +++++++- crates/ty_python_semantic/src/types/instance.rs | 3 +++ crates/ty_python_semantic/src/types/signatures.rs | 10 +++++++++- 3 files changed, 19 insertions(+), 2 deletions(-) diff --git a/crates/ty_python_semantic/src/types.rs b/crates/ty_python_semantic/src/types.rs index 6722f081c3d19..3e677f70f4fb6 100644 --- a/crates/ty_python_semantic/src/types.rs +++ b/crates/ty_python_semantic/src/types.rs @@ -1286,7 +1286,10 @@ impl<'db> Type<'db> { // TODO: Normalize TypedDicts self } - Type::TypeAlias(alias) => alias.value_type(db).normalized_impl(db, visitor), + Type::TypeAlias(alias) if !visitor.is_recursive() => { + alias.value_type(db).normalized_impl(db, visitor) + } + Type::TypeAlias(_) => self, Type::LiteralString | Type::AlwaysFalsy | Type::AlwaysTruthy @@ -7566,6 +7569,9 @@ impl<'db> TypeVarInstance<'db> { } pub(crate) fn normalized_impl(self, db: &'db dyn Db, visitor: &NormalizedVisitor<'db>) -> Self { + if visitor.is_recursive() { + return self; + } Self::new( db, self.name(db), diff --git a/crates/ty_python_semantic/src/types/instance.rs b/crates/ty_python_semantic/src/types/instance.rs index 003e4ff459906..12d49e28697da 100644 --- a/crates/ty_python_semantic/src/types/instance.rs +++ b/crates/ty_python_semantic/src/types/instance.rs @@ -624,6 +624,9 @@ impl<'db> ProtocolInstanceType<'db> { db: &'db dyn Db, visitor: &NormalizedVisitor<'db>, ) -> Type<'db> { + if visitor.is_recursive() { + return Type::ProtocolInstance(self); + } if self.is_equivalent_to_object(db) { return Type::object(); } diff --git a/crates/ty_python_semantic/src/types/signatures.rs b/crates/ty_python_semantic/src/types/signatures.rs index 115b6b46debde..31f7e1e64c080 100644 --- a/crates/ty_python_semantic/src/types/signatures.rs +++ b/crates/ty_python_semantic/src/types/signatures.rs @@ -433,7 +433,7 @@ impl<'db> Signature<'db> { .map(|ctx| ctx.normalized_impl(db, visitor)), // Discard the definition when normalizing, so that two equivalent signatures // with different `Definition`s share the same Salsa ID when normalized - definition: None, + definition: visitor.is_recursive().then_some(self.definition).flatten(), parameters: self .parameters .iter() @@ -1489,6 +1489,14 @@ impl<'db> Parameter<'db> { form, } = self; + if visitor.is_recursive() { + return Self { + annotated_type: annotated_type.map(|ty| ty.normalized_impl(db, visitor)), + kind: kind.clone(), + form: *form, + }; + } + // Ensure unions and intersections are ordered in the annotated type (if there is one). // Ensure that a parameter without an annotation is treated equivalently to a parameter // with a dynamic type as its annotation. (We must use `Any` here as all dynamic types From 3d74f1996d0165c5ec162fff4b1b25139fe30067 Mon Sep 17 00:00:00 2001 From: Shunsuke Shibayama Date: Sat, 13 Sep 2025 11:09:14 +0900 Subject: [PATCH 067/105] set the cycle initial value of `BoundMethod::into_callable_type, ClassType::into_callable` to `CallableType::bottom()` --- crates/ty_python_semantic/src/types.rs | 20 ++++++++++++++++++-- crates/ty_python_semantic/src/types/class.rs | 4 ++-- 2 files changed, 20 insertions(+), 4 deletions(-) diff --git a/crates/ty_python_semantic/src/types.rs b/crates/ty_python_semantic/src/types.rs index 3e677f70f4fb6..6473d1bc43334 100644 --- a/crates/ty_python_semantic/src/types.rs +++ b/crates/ty_python_semantic/src/types.rs @@ -126,6 +126,23 @@ fn return_type_cycle_initial<'db>(db: &'db dyn Db, method: BoundMethodType<'db>) })) } +#[allow(clippy::trivially_copy_pass_by_ref)] +fn into_callable_cycle_recover<'db>( + _db: &'db dyn Db, + _value: &CallableType<'db>, + _count: u32, + _self: BoundMethodType<'db>, +) -> salsa::CycleRecoveryAction> { + salsa::CycleRecoveryAction::Iterate +} + +fn into_callable_type_cycle_initial<'db>( + db: &'db dyn Db, + _method: BoundMethodType<'db>, +) -> CallableType<'db> { + CallableType::new(db, CallableSignature::single(Signature::bottom()), false) +} + pub fn check_types(db: &dyn Db, file: File) -> Vec { let _span = tracing::trace_span!("check_types", ?file).entered(); @@ -9048,7 +9065,7 @@ impl<'db> BoundMethodType<'db> { self_instance } - #[salsa::tracked(heap_size=ruff_memory_usage::heap_size)] + #[salsa::tracked(cycle_fn=into_callable_cycle_recover, cycle_initial=into_callable_type_cycle_initial, heap_size=ruff_memory_usage::heap_size)] pub(crate) fn into_callable_type(self, db: &'db dyn Db) -> CallableType<'db> { let function = self.function(db); let self_instance = self.typing_self_type(db); @@ -9257,7 +9274,6 @@ impl<'db> CallableType<'db> { /// /// Specifically, this represents a callable type with a single signature: /// `(*args: object, **kwargs: object) -> Never`. - #[cfg(test)] pub(crate) fn bottom(db: &'db dyn Db) -> Type<'db> { Self::single(db, Signature::bottom()) } diff --git a/crates/ty_python_semantic/src/types/class.rs b/crates/ty_python_semantic/src/types/class.rs index 9ef002ef0f889..60077e8ca6ef4 100644 --- a/crates/ty_python_semantic/src/types/class.rs +++ b/crates/ty_python_semantic/src/types/class.rs @@ -180,8 +180,8 @@ fn into_callable_cycle_recover<'db>( salsa::CycleRecoveryAction::Iterate } -fn into_callable_cycle_initial<'db>(_db: &'db dyn Db, _self: ClassType<'db>) -> Type<'db> { - Type::Never +fn into_callable_cycle_initial<'db>(db: &'db dyn Db, _self: ClassType<'db>) -> Type<'db> { + CallableType::bottom(db) } /// A category of classes with code generation capabilities (with synthesized methods). From 3e2c2c4a07a079779041da9ad5f7ebcd864a5165 Mon Sep 17 00:00:00 2001 From: Shunsuke Shibayama Date: Sun, 14 Sep 2025 03:14:45 +0900 Subject: [PATCH 068/105] =?UTF-8?q?`DivergentType`=20has=20the=20kind=20of?= =?UTF-8?q?=20divergence,=20and=20the=20cycle=20initial=20values=20?= =?UTF-8?q?=E2=80=8B=E2=80=8Bof=20various=20tracked=20functions=20are=20se?= =?UTF-8?q?t=20to=20`DivergentType`?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .../resources/mdtest/attributes.md | 3 +- .../resources/mdtest/cycle.md | 6 +- .../ty_python_semantic/src/semantic_index.rs | 27 +- .../src/semantic_index/symbol.rs | 6 +- crates/ty_python_semantic/src/types.rs | 120 +++++++-- crates/ty_python_semantic/src/types/class.rs | 99 +++++--- .../ty_python_semantic/src/types/function.rs | 14 +- crates/ty_python_semantic/src/types/infer.rs | 238 +++++++++++++----- .../src/types/infer/builder.rs | 82 +++--- .../src/types/infer/tests.rs | 27 +- .../src/types/type_ordering.rs | 16 +- .../ty_python_semantic/src/types/unpacker.rs | 21 +- 12 files changed, 481 insertions(+), 178 deletions(-) diff --git a/crates/ty_python_semantic/resources/mdtest/attributes.md b/crates/ty_python_semantic/resources/mdtest/attributes.md index 43b1b568c4dbb..1ef154ee3600d 100644 --- a/crates/ty_python_semantic/resources/mdtest/attributes.md +++ b/crates/ty_python_semantic/resources/mdtest/attributes.md @@ -2223,8 +2223,7 @@ class C: def copy(self, other: "C"): self.x = other.x -# TODO: Should be `Unknown | Literal[1]` -reveal_type(C().x) # revealed: Unknown | Literal[1] | Divergent +reveal_type(C().x) # revealed: Unknown | Literal[1] ``` If the only assignment to a name is cyclic, we just infer `Unknown` for that attribute: diff --git a/crates/ty_python_semantic/resources/mdtest/cycle.md b/crates/ty_python_semantic/resources/mdtest/cycle.md index da12590ff07cf..0bd3b5b2a6df4 100644 --- a/crates/ty_python_semantic/resources/mdtest/cycle.md +++ b/crates/ty_python_semantic/resources/mdtest/cycle.md @@ -28,8 +28,6 @@ class Point: self.x, self.y = other.x, other.y p = Point() -# TODO: should be `Unknown | int` -reveal_type(p.x) # revealed: Unknown | int | Divergent -# TODO: should be `Unknown | int` -reveal_type(p.y) # revealed: Unknown | int | Divergent +reveal_type(p.x) # revealed: Unknown | int +reveal_type(p.y) # revealed: Unknown | int ``` diff --git a/crates/ty_python_semantic/src/semantic_index.rs b/crates/ty_python_semantic/src/semantic_index.rs index fc57d0b2f6b44..986e514b9277b 100644 --- a/crates/ty_python_semantic/src/semantic_index.rs +++ b/crates/ty_python_semantic/src/semantic_index.rs @@ -25,7 +25,9 @@ pub use crate::semantic_index::scope::FileScopeId; use crate::semantic_index::scope::{ NodeWithScopeKey, NodeWithScopeRef, Scope, ScopeId, ScopeKind, ScopeLaziness, }; -use crate::semantic_index::symbol::ScopedSymbolId; +use crate::semantic_index::symbol::{ + ImplicitAttributeTable, ImplicitAttributeTableBuilder, ScopedSymbolId, Symbol, +}; use crate::semantic_index::use_def::{EnclosingSnapshotKey, ScopedEnclosingSnapshotId, UseDefMap}; use crate::semantic_model::HasTrackedScope; @@ -73,6 +75,29 @@ pub(crate) fn place_table<'db>(db: &'db dyn Db, scope: ScopeId<'db>) -> Arc( + db: &'db dyn Db, + class_body_scope: ScopeId<'db>, +) -> ImplicitAttributeTable { + let mut table = ImplicitAttributeTableBuilder::default(); + + let file = class_body_scope.file(db); + let index = semantic_index(db, file); + + for function_scope_id in attribute_scopes(db, class_body_scope) { + let function_place_table = index.place_table(function_scope_id); + for member in function_place_table.members() { + if let Some(attr) = member.as_instance_attribute() { + let symbol = Symbol::new(attr.into()); + table.add(symbol); + } + } + } + + table.build() +} + /// Returns the set of modules that are imported anywhere in `file`. /// /// This set only considers `import` statements, not `from...import` statements, because: diff --git a/crates/ty_python_semantic/src/semantic_index/symbol.rs b/crates/ty_python_semantic/src/semantic_index/symbol.rs index 400165485ed46..c0120eb10a5c2 100644 --- a/crates/ty_python_semantic/src/semantic_index/symbol.rs +++ b/crates/ty_python_semantic/src/semantic_index/symbol.rs @@ -149,7 +149,7 @@ impl Symbol { /// /// Allows lookup by name and a symbol's ID. #[derive(Default, get_size2::GetSize)] -pub(super) struct SymbolTable { +pub(crate) struct SymbolTable { symbols: IndexVec, /// Map from symbol name to its ID. @@ -268,3 +268,7 @@ impl DerefMut for SymbolTableBuilder { &mut self.table } } + +pub(crate) type ScopedImplicitAttributeId = ScopedSymbolId; +pub(crate) type ImplicitAttributeTable = SymbolTable; +pub(super) type ImplicitAttributeTableBuilder = SymbolTableBuilder; diff --git a/crates/ty_python_semantic/src/types.rs b/crates/ty_python_semantic/src/types.rs index 9fe3db05edbb4..ac3a0e227782a 100644 --- a/crates/ty_python_semantic/src/types.rs +++ b/crates/ty_python_semantic/src/types.rs @@ -34,7 +34,10 @@ use crate::place::{Boundness, Place, PlaceAndQualifiers, imported_symbol}; use crate::semantic_index::definition::Definition; use crate::semantic_index::place::ScopedPlaceId; use crate::semantic_index::scope::ScopeId; -use crate::semantic_index::{imported_modules, place_table, semantic_index}; +use crate::semantic_index::symbol::ScopedImplicitAttributeId; +use crate::semantic_index::{ + implicit_attribute_table, imported_modules, place_table, semantic_index, +}; use crate::suppression::check_suppressions; use crate::types::call::{Binding, Bindings, CallArguments, CallableBinding}; pub(crate) use crate::types::class_base::ClassBase; @@ -58,7 +61,7 @@ pub use crate::types::ide_support::{ definitions_for_attribute, definitions_for_imported_symbol, definitions_for_keyword_argument, definitions_for_name, find_active_signature_from_details, inlay_hint_function_argument_details, }; -use crate::types::infer::infer_unpack_types; +use crate::types::infer::infer_function_scope_types; use crate::types::mro::{Mro, MroError, MroIterator}; pub(crate) use crate::types::narrow::infer_narrowing_constraint; use crate::types::signatures::{Parameter, ParameterForm, Parameters, walk_signature}; @@ -117,13 +120,14 @@ fn return_type_cycle_recover<'db>( } fn return_type_cycle_initial<'db>(db: &'db dyn Db, method: BoundMethodType<'db>) -> Type<'db> { - Type::Dynamic(DynamicType::Divergent(DivergentType { - scope: method + Type::divergent(DivergentType::function_return_type( + db, + method .function(db) .literal(db) .last_definition(db) .body_scope(db), - })) + )) } pub fn check_types(db: &dyn Db, file: File) -> Vec { @@ -420,13 +424,32 @@ fn member_lookup_cycle_recover<'db>( salsa::CycleRecoveryAction::Iterate } +#[allow(clippy::needless_pass_by_value)] fn member_lookup_cycle_initial<'db>( - _db: &'db dyn Db, - _self: Type<'db>, - _name: Name, + db: &'db dyn Db, + self_ty: Type<'db>, + name: Name, _policy: MemberLookupPolicy, ) -> PlaceAndQualifiers<'db> { - Place::bound(Type::Never).into() + let class_body_scope = match self_ty { + Type::NominalInstance(instance) => instance.class_literal(db).body_scope(db), + Type::GenericAlias(alias) => alias.origin(db).body_scope(db), + Type::ClassLiteral(class) => class.body_scope(db), + _ => { + return Place::bound(Type::Never).into(); + } + }; + let implicit_attribute_table = implicit_attribute_table(db, class_body_scope); + let initial = if let Some(implicit_attribute) = implicit_attribute_table.symbol_id(&name) { + Type::divergent(DivergentType::implicit_attribute( + db, + class_body_scope, + implicit_attribute, + )) + } else { + Type::Never + }; + Place::bound(initial).into() } fn class_lookup_cycle_recover<'db>( @@ -440,13 +463,32 @@ fn class_lookup_cycle_recover<'db>( salsa::CycleRecoveryAction::Iterate } +#[allow(clippy::needless_pass_by_value)] fn class_lookup_cycle_initial<'db>( - _db: &'db dyn Db, - _self: Type<'db>, - _name: Name, + db: &'db dyn Db, + self_ty: Type<'db>, + name: Name, _policy: MemberLookupPolicy, ) -> PlaceAndQualifiers<'db> { - Place::bound(Type::Never).into() + let class_body_scope = match self_ty { + Type::NominalInstance(instance) => instance.class_literal(db).body_scope(db), + Type::GenericAlias(alias) => alias.origin(db).body_scope(db), + Type::ClassLiteral(class) => class.body_scope(db), + _ => { + return Place::bound(Type::Never).into(); + } + }; + let implicit_attribute_table = implicit_attribute_table(db, class_body_scope); + let initial = if let Some(implicit_attribute) = implicit_attribute_table.symbol_id(&name) { + Type::divergent(DivergentType::implicit_attribute( + db, + class_body_scope, + implicit_attribute, + )) + } else { + Type::Never + }; + Place::bound(initial).into() } #[allow(clippy::trivially_copy_pass_by_ref)] @@ -838,8 +880,8 @@ impl<'db> Type<'db> { Self::Dynamic(DynamicType::Unknown) } - pub(crate) fn divergent(scope: ScopeId<'db>) -> Self { - Self::Dynamic(DynamicType::Divergent(DivergentType { scope })) + pub(crate) fn divergent(divergent: DivergentType<'db>) -> Self { + Self::Dynamic(DynamicType::Divergent(divergent)) } pub const fn is_unknown(&self) -> bool { @@ -7036,15 +7078,56 @@ impl<'db> KnownInstanceType<'db> { } } +#[derive(Copy, Clone, Debug, Eq, Hash, PartialEq, salsa::Update, get_size2::GetSize)] +pub enum DivergenceKind { + /// Divergence in function return type inference. + FunctionReturnType, + /// Divergence in implicit attribute type inference. + ImplicitAttribute(ScopedImplicitAttributeId), + /// Unknown divergence that we have not yet handled. + Todo, +} + +impl DivergenceKind { + pub fn is_todo(self) -> bool { + matches!(self, Self::Todo) + } +} + /// A type that is determined to be divergent during recursive type inference. /// This type must never be eliminated by dynamic type reduction /// (e.g. `Divergent` is assignable to `@Todo`, but `@Todo | Divergent` must not be reducted to `@Todo`). /// Otherwise, type inference cannot converge properly. /// For detailed properties of this type, see the unit test at the end of the file. -#[derive(Copy, Clone, Debug, Eq, Hash, PartialEq, salsa::Update, get_size2::GetSize)] +#[salsa::interned(debug, heap_size=ruff_memory_usage::heap_size)] pub struct DivergentType<'db> { /// The scope where this divergence was detected. + /// * The function scope in case of function return type inference + /// * The class body scope in case of implicit attribute type inference scope: ScopeId<'db>, + /// The kind of divergence. + kind: DivergenceKind, +} + +// The Salsa heap is tracked separately. +impl get_size2::GetSize for DivergentType<'_> {} + +impl<'db> DivergentType<'db> { + pub(crate) fn function_return_type(db: &'db dyn Db, scope: ScopeId<'db>) -> Self { + Self::new(db, scope, DivergenceKind::FunctionReturnType) + } + + pub(crate) fn implicit_attribute( + db: &'db dyn Db, + scope: ScopeId<'db>, + attribute: ScopedImplicitAttributeId, + ) -> Self { + Self::new(db, scope, DivergenceKind::ImplicitAttribute(attribute)) + } + + pub(crate) fn todo(db: &'db dyn Db, scope: ScopeId<'db>) -> Self { + Self::new(db, scope, DivergenceKind::Todo) + } } #[derive(Copy, Clone, Debug, Eq, Hash, PartialEq, salsa::Update, get_size2::GetSize)] @@ -9084,7 +9167,8 @@ impl<'db> BoundMethodType<'db> { .literal(db) .last_definition(db) .body_scope(db); - let inference = infer_scope_types(db, scope); + let inference = + infer_function_scope_types(db, scope, DivergentType::function_return_type(db, scope)); inference.infer_return_type(db, Type::BoundMethod(self)) } @@ -11113,7 +11197,7 @@ pub(crate) mod tests { let file_scope_id = FileScopeId::global(); let scope = file_scope_id.to_scope_id(&db, file); - let div = Type::Dynamic(DynamicType::Divergent(DivergentType { scope })); + let div = Type::divergent(DivergentType::function_return_type(&db, scope)); // The `Divergent` type must not be eliminated in union with other dynamic types, // as this would prevent detection of divergent type inference using `Divergent`. diff --git a/crates/ty_python_semantic/src/types/class.rs b/crates/ty_python_semantic/src/types/class.rs index 4b9753313a16f..6fca8e8ee2aea 100644 --- a/crates/ty_python_semantic/src/types/class.rs +++ b/crates/ty_python_semantic/src/types/class.rs @@ -4,7 +4,7 @@ use super::TypeVarVariance; use super::{ BoundTypeVarInstance, IntersectionBuilder, MemberLookupPolicy, Mro, MroError, MroIterator, SpecialFormType, SubclassOfType, Truthiness, Type, TypeQualifiers, class_base::ClassBase, - function::FunctionType, infer_expression_type, infer_unpack_types, + function::FunctionType, }; use crate::FxOrderMap; use crate::module_resolver::KnownModule; @@ -13,6 +13,7 @@ use crate::semantic_index::scope::NodeWithScopeKind; use crate::semantic_index::symbol::Symbol; use crate::semantic_index::{ DeclarationWithConstraint, SemanticIndex, attribute_declarations, attribute_scopes, + implicit_attribute_table, }; use crate::types::constraints::{ConstraintSet, IteratorConstraintsExtension}; use crate::types::context::InferContext; @@ -20,17 +21,21 @@ use crate::types::diagnostic::{INVALID_LEGACY_TYPE_VARIABLE, INVALID_TYPE_ALIAS_ use crate::types::enums::enum_metadata; use crate::types::function::{DataclassTransformerParams, KnownFunction}; use crate::types::generics::{GenericContext, Specialization, walk_specialization}; -use crate::types::infer::nearest_enclosing_class; +use crate::types::infer::{ + infer_implicit_attribute_expression_type, infer_unpack_implicit_attribute_types, + nearest_enclosing_class, +}; use crate::types::signatures::{CallableSignature, Parameter, Parameters, Signature}; use crate::types::tuple::{TupleSpec, TupleType}; use crate::types::typed_dict::typed_dict_params_from_class_def; use crate::types::{ ApplyTypeMappingVisitor, Binding, BoundSuperError, BoundSuperType, CallableType, - DataclassParams, DeprecatedInstance, FindLegacyTypeVarsVisitor, HasRelationToVisitor, - IsEquivalentVisitor, KnownInstanceType, ManualPEP695TypeAliasType, MaterializationKind, - NormalizedVisitor, PropertyInstanceType, StringLiteralType, TypeAliasType, TypeContext, - TypeMapping, TypeRelation, TypeVarBoundOrConstraints, TypeVarInstance, TypeVarKind, - TypedDictParams, UnionBuilder, VarianceInferable, declaration_type, infer_definition_types, + DataclassParams, DeprecatedInstance, DivergentType, FindLegacyTypeVarsVisitor, + HasRelationToVisitor, IsEquivalentVisitor, KnownInstanceType, ManualPEP695TypeAliasType, + MaterializationKind, NormalizedVisitor, PropertyInstanceType, StringLiteralType, TypeAliasType, + TypeContext, TypeMapping, TypeRelation, TypeVarBoundOrConstraints, TypeVarInstance, + TypeVarKind, TypedDictParams, UnionBuilder, VarianceInferable, declaration_type, + infer_definition_types, }; use crate::{ Db, FxIndexMap, FxOrderSet, Program, @@ -105,13 +110,23 @@ fn implicit_attribute_recover<'db>( salsa::CycleRecoveryAction::Iterate } +#[allow(clippy::needless_pass_by_value)] fn implicit_attribute_initial<'db>( - _db: &'db dyn Db, - _class_body_scope: ScopeId<'db>, - _name: String, + db: &'db dyn Db, + class_body_scope: ScopeId<'db>, + name: String, _target_method_decorator: MethodDecorator, ) -> PlaceAndQualifiers<'db> { - Place::Unbound.into() + let implicit_attribute_table = implicit_attribute_table(db, class_body_scope); + let implicit_attr_id = implicit_attribute_table + .symbol_id(&name) + .expect("Implicit attribute must exist in the table"); + Place::bound(Type::divergent(DivergentType::implicit_attribute( + db, + class_body_scope, + implicit_attr_id, + ))) + .into() } fn try_mro_cycle_recover<'db>( @@ -2890,6 +2905,14 @@ impl<'db> ClassLiteral<'db> { let index = semantic_index(db, file); let class_map = use_def_map(db, class_body_scope); let class_table = place_table(db, class_body_scope); + let implicit_attr_table = implicit_attribute_table(db, class_body_scope); + let div = if let Some(implicit_attr) = implicit_attr_table.symbol_id(&name) { + DivergentType::implicit_attribute(db, class_body_scope, implicit_attr) + } else { + // Implicit attributes should either not exist or have a type declaration, and should not diverge. + DivergentType::todo(db, class_body_scope) + }; + let visitor = NormalizedVisitor::default().recursive(Type::divergent(div)); let is_valid_scope = |method_scope: ScopeId<'db>| { if let Some(method_def) = method_scope.node(db).as_function() { @@ -2942,12 +2965,14 @@ impl<'db> ClassLiteral<'db> { // `self.SOME_CONSTANT: Final = 1`, infer the type from the value // on the right-hand side. - let inferred_ty = infer_expression_type( + let inferred_ty = infer_implicit_attribute_expression_type( db, index.expression(value), TypeContext::default(), + div, ); - return Place::bound(inferred_ty).with_qualifiers(all_qualifiers); + return Place::bound(inferred_ty.normalized_impl(db, &visitor)) + .with_qualifiers(all_qualifiers); } // If there is no right-hand side, just record that we saw a `Final` qualifier @@ -3020,24 +3045,28 @@ impl<'db> ClassLiteral<'db> { // (.., self.name, ..) = // [.., self.name, ..] = - let unpacked = infer_unpack_types(db, unpack); + let unpacked = + infer_unpack_implicit_attribute_types(db, unpack, div); let inferred_ty = unpacked.expression_type(assign.target(&module)); - union_of_inferred_types = union_of_inferred_types.add(inferred_ty); + union_of_inferred_types = union_of_inferred_types + .add(inferred_ty.normalized_impl(db, &visitor)); } TargetKind::Single => { // We found an un-annotated attribute assignment of the form: // // self.name = - let inferred_ty = infer_expression_type( + let inferred_ty = infer_implicit_attribute_expression_type( db, index.expression(assign.value(&module)), TypeContext::default(), + div, ); - union_of_inferred_types = union_of_inferred_types.add(inferred_ty); + union_of_inferred_types = union_of_inferred_types + .add(inferred_ty.normalized_impl(db, &visitor)); } } } @@ -3048,27 +3077,31 @@ impl<'db> ClassLiteral<'db> { // // for .., self.name, .. in : - let unpacked = infer_unpack_types(db, unpack); + let unpacked = + infer_unpack_implicit_attribute_types(db, unpack, div); let inferred_ty = unpacked.expression_type(for_stmt.target(&module)); - union_of_inferred_types = union_of_inferred_types.add(inferred_ty); + union_of_inferred_types = union_of_inferred_types + .add(inferred_ty.normalized_impl(db, &visitor)); } TargetKind::Single => { // We found an attribute assignment like: // // for self.name in : - let iterable_ty = infer_expression_type( + let iterable_ty = infer_implicit_attribute_expression_type( db, index.expression(for_stmt.iterable(&module)), TypeContext::default(), + div, ); // TODO: Potential diagnostics resulting from the iterable are currently not reported. let inferred_ty = iterable_ty.iterate(db).homogeneous_element_type(db); - union_of_inferred_types = union_of_inferred_types.add(inferred_ty); + union_of_inferred_types = union_of_inferred_types + .add(inferred_ty.normalized_impl(db, &visitor)); } } } @@ -3079,21 +3112,24 @@ impl<'db> ClassLiteral<'db> { // // with as .., self.name, ..: - let unpacked = infer_unpack_types(db, unpack); + let unpacked = + infer_unpack_implicit_attribute_types(db, unpack, div); let inferred_ty = unpacked.expression_type(with_item.target(&module)); - union_of_inferred_types = union_of_inferred_types.add(inferred_ty); + union_of_inferred_types = union_of_inferred_types + .add(inferred_ty.normalized_impl(db, &visitor)); } TargetKind::Single => { // We found an attribute assignment like: // // with as self.name: - let context_ty = infer_expression_type( + let context_ty = infer_implicit_attribute_expression_type( db, index.expression(with_item.context_expr(&module)), TypeContext::default(), + div, ); let inferred_ty = if with_item.is_async() { context_ty.aenter(db) @@ -3101,7 +3137,8 @@ impl<'db> ClassLiteral<'db> { context_ty.enter(db) }; - union_of_inferred_types = union_of_inferred_types.add(inferred_ty); + union_of_inferred_types = union_of_inferred_types + .add(inferred_ty.normalized_impl(db, &visitor)); } } } @@ -3112,28 +3149,32 @@ impl<'db> ClassLiteral<'db> { // // [... for .., self.name, .. in ] - let unpacked = infer_unpack_types(db, unpack); + let unpacked = + infer_unpack_implicit_attribute_types(db, unpack, div); let inferred_ty = unpacked.expression_type(comprehension.target(&module)); - union_of_inferred_types = union_of_inferred_types.add(inferred_ty); + union_of_inferred_types = union_of_inferred_types + .add(inferred_ty.normalized_impl(db, &visitor)); } TargetKind::Single => { // We found an attribute assignment like: // // [... for self.name in ] - let iterable_ty = infer_expression_type( + let iterable_ty = infer_implicit_attribute_expression_type( db, index.expression(comprehension.iterable(&module)), TypeContext::default(), + div, ); // TODO: Potential diagnostics resulting from the iterable are currently not reported. let inferred_ty = iterable_ty.iterate(db).homogeneous_element_type(db); - union_of_inferred_types = union_of_inferred_types.add(inferred_ty); + union_of_inferred_types = union_of_inferred_types + .add(inferred_ty.normalized_impl(db, &visitor)); } } } diff --git a/crates/ty_python_semantic/src/types/function.rs b/crates/ty_python_semantic/src/types/function.rs index 6dc067efe6860..084b84a572c4b 100644 --- a/crates/ty_python_semantic/src/types/function.rs +++ b/crates/ty_python_semantic/src/types/function.rs @@ -73,6 +73,7 @@ use crate::types::diagnostic::{ report_runtime_check_against_non_runtime_checkable_protocol, }; use crate::types::generics::GenericContext; +use crate::types::infer::infer_function_scope_types; use crate::types::narrow::ClassInfoConstraintFunction; use crate::types::signatures::{CallableSignature, Signature}; use crate::types::visitor::any_over_type; @@ -81,8 +82,7 @@ use crate::types::{ DeprecatedInstance, DivergentType, DynamicType, FindLegacyTypeVarsVisitor, HasRelationToVisitor, IsEquivalentVisitor, KnownClass, KnownInstanceType, NormalizedVisitor, SpecialFormType, TrackedConstraintSet, Truthiness, Type, TypeMapping, TypeRelation, - UnionBuilder, all_members, binding_type, infer_scope_types, todo_type, walk_generic_context, - walk_type_mapping, + UnionBuilder, all_members, binding_type, todo_type, walk_generic_context, walk_type_mapping, }; use crate::{Db, FxOrderSet, ModuleName, resolve_module}; @@ -96,9 +96,10 @@ fn return_type_cycle_recover<'db>( } fn return_type_cycle_initial<'db>(db: &'db dyn Db, function: FunctionType<'db>) -> Type<'db> { - Type::Dynamic(DynamicType::Divergent(DivergentType { - scope: function.literal(db).last_definition(db).body_scope(db), - })) + Type::Dynamic(DynamicType::Divergent(DivergentType::function_return_type( + db, + function.literal(db).last_definition(db).body_scope(db), + ))) } /// A collection of useful spans for annotating functions. @@ -1045,7 +1046,8 @@ impl<'db> FunctionType<'db> { #[salsa::tracked(cycle_fn=return_type_cycle_recover, cycle_initial=return_type_cycle_initial, heap_size=get_size2::heap_size)] pub(crate) fn infer_return_type(self, db: &'db dyn Db) -> Type<'db> { let scope = self.literal(db).last_definition(db).body_scope(db); - let inference = infer_scope_types(db, scope); + let inference = + infer_function_scope_types(db, scope, DivergentType::function_return_type(db, scope)); inference.infer_return_type(db, Type::FunctionLiteral(self)) } } diff --git a/crates/ty_python_semantic/src/types/infer.rs b/crates/ty_python_semantic/src/types/infer.rs index 51cbb18bbac09..b6a77da822714 100644 --- a/crates/ty_python_semantic/src/types/infer.rs +++ b/crates/ty_python_semantic/src/types/infer.rs @@ -61,11 +61,30 @@ mod builder; #[cfg(test)] mod tests; +pub(crate) fn infer_scope_types<'db>( + db: &'db dyn Db, + scope: ScopeId<'db>, +) -> &'db ScopeInference<'db> { + infer_scope_types_impl(db, scope, DivergentType::todo(db, scope)) +} + +pub(crate) fn infer_function_scope_types<'db>( + db: &'db dyn Db, + scope: ScopeId<'db>, + divergent: DivergentType<'db>, +) -> &'db ScopeInference<'db> { + infer_scope_types_impl(db, scope, divergent) +} + /// Infer all types for a [`ScopeId`], including all definitions and expressions in that scope. /// Use when checking a scope, or needing to provide a type for an arbitrary expression in the /// scope. #[salsa::tracked(returns(ref), cycle_fn=scope_cycle_recover, cycle_initial=scope_cycle_initial, heap_size=ruff_memory_usage::heap_size)] -pub(crate) fn infer_scope_types<'db>(db: &'db dyn Db, scope: ScopeId<'db>) -> ScopeInference<'db> { +fn infer_scope_types_impl<'db>( + db: &'db dyn Db, + scope: ScopeId<'db>, + _divergent: DivergentType<'db>, +) -> ScopeInference<'db> { let file = scope.file(db); let _span = tracing::trace_span!("infer_scope_types", scope=?scope.as_id(), ?file).entered(); @@ -83,20 +102,37 @@ fn scope_cycle_recover<'db>( _value: &ScopeInference<'db>, _count: u32, _scope: ScopeId<'db>, + _divergent: DivergentType<'db>, ) -> salsa::CycleRecoveryAction> { salsa::CycleRecoveryAction::Iterate } -fn scope_cycle_initial<'db>(_db: &'db dyn Db, scope: ScopeId<'db>) -> ScopeInference<'db> { - ScopeInference::cycle_initial(scope) +fn scope_cycle_initial<'db>( + _db: &'db dyn Db, + scope: ScopeId<'db>, + divergent: DivergentType<'db>, +) -> ScopeInference<'db> { + ScopeInference::cycle_initial(divergent, scope) +} + +pub(crate) fn infer_definition_types<'db>( + db: &'db dyn Db, + definition: Definition<'db>, +) -> &'db DefinitionInference<'db> { + infer_definition_types_impl( + db, + definition, + DivergentType::todo(db, definition.scope(db)), + ) } /// Infer all types for a [`Definition`] (including sub-expressions). /// Use when resolving a place use or public type of a place. #[salsa::tracked(returns(ref), cycle_fn=definition_cycle_recover, cycle_initial=definition_cycle_initial, heap_size=ruff_memory_usage::heap_size)] -pub(crate) fn infer_definition_types<'db>( +fn infer_definition_types_impl<'db>( db: &'db dyn Db, definition: Definition<'db>, + _divergent: DivergentType<'db>, ) -> DefinitionInference<'db> { let file = definition.file(db); let module = parsed_module(db, file).load(db); @@ -118,6 +154,7 @@ fn definition_cycle_recover<'db>( _value: &DefinitionInference<'db>, _count: u32, _definition: Definition<'db>, + _divergent: DivergentType<'db>, ) -> salsa::CycleRecoveryAction> { salsa::CycleRecoveryAction::Iterate } @@ -125,8 +162,9 @@ fn definition_cycle_recover<'db>( fn definition_cycle_initial<'db>( db: &'db dyn Db, definition: Definition<'db>, + divergent: DivergentType<'db>, ) -> DefinitionInference<'db> { - DefinitionInference::cycle_initial(definition.scope(db)) + DefinitionInference::cycle_initial(definition.scope(db), divergent) } /// Infer types for all deferred type expressions in a [`Definition`]. @@ -167,7 +205,10 @@ fn deferred_cycle_initial<'db>( db: &'db dyn Db, definition: Definition<'db>, ) -> DefinitionInference<'db> { - DefinitionInference::cycle_initial(definition.scope(db)) + DefinitionInference::cycle_initial( + definition.scope(db), + DivergentType::todo(db, definition.scope(db)), + ) } /// Infer all types for an [`Expression`] (including sub-expressions). @@ -179,11 +220,19 @@ pub(crate) fn infer_expression_types<'db>( expression: Expression<'db>, tcx: TypeContext<'db>, ) -> &'db ExpressionInference<'db> { - infer_expression_types_impl(db, InferExpression::new(db, expression, tcx)) + infer_expression_types_impl( + db, + InferExpression::new( + db, + expression, + tcx, + DivergentType::todo(db, expression.scope(db)), + ), + ) } #[salsa::tracked(returns(ref), cycle_fn=expression_cycle_recover, cycle_initial=expression_cycle_initial, heap_size=ruff_memory_usage::heap_size)] -fn infer_expression_types_impl<'db>( +pub(super) fn infer_expression_types_impl<'db>( db: &'db dyn Db, input: InferExpression<'db>, ) -> ExpressionInference<'db> { @@ -223,7 +272,7 @@ fn expression_cycle_initial<'db>( db: &'db dyn Db, input: InferExpression<'db>, ) -> ExpressionInference<'db> { - ExpressionInference::cycle_initial(input.expression(db).scope(db)) + ExpressionInference::cycle_initial(input.expression(db).scope(db), input.divergent(db)) } /// Infers the type of an `expression` that is guaranteed to be in the same file as the calling query. @@ -253,7 +302,24 @@ pub(crate) fn infer_expression_type<'db>( expression: Expression<'db>, tcx: TypeContext<'db>, ) -> Type<'db> { - infer_expression_type_impl(db, InferExpression::new(db, expression, tcx)) + infer_expression_type_impl( + db, + InferExpression::new( + db, + expression, + tcx, + DivergentType::todo(db, expression.scope(db)), + ), + ) +} + +pub(crate) fn infer_implicit_attribute_expression_type<'db>( + db: &'db dyn Db, + expression: Expression<'db>, + tcx: TypeContext<'db>, + divergent: DivergentType<'db>, +) -> Type<'db> { + infer_expression_type_impl(db, InferExpression::new(db, expression, tcx, divergent)) } #[salsa::tracked(cycle_fn=single_expression_cycle_recover, cycle_initial=single_expression_cycle_initial, heap_size=ruff_memory_usage::heap_size)] @@ -275,11 +341,8 @@ fn single_expression_cycle_recover<'db>( salsa::CycleRecoveryAction::Iterate } -fn single_expression_cycle_initial<'db>( - _db: &'db dyn Db, - _input: InferExpression<'db>, -) -> Type<'db> { - Type::Never +fn single_expression_cycle_initial<'db>(db: &'db dyn Db, input: InferExpression<'db>) -> Type<'db> { + Type::Dynamic(DynamicType::Divergent(input.divergent(db))) } /// An `Expression` with an optional `TypeContext`. @@ -287,29 +350,41 @@ fn single_expression_cycle_initial<'db>( /// This is a Salsa supertype used as the input to `infer_expression_types` to avoid /// interning an `ExpressionWithContext` unnecessarily when no type context is provided. #[derive(Debug, Clone, Copy, Eq, Hash, PartialEq, salsa::Supertype, salsa::Update)] -enum InferExpression<'db> { - Bare(Expression<'db>), - WithContext(ExpressionWithContext<'db>), +pub(super) enum InferExpression<'db> { + WithDivergent(ExpressionWithDivergent<'db>), + WithContextAndDivergent(ExpressionWithContextAndDivergent<'db>), } impl<'db> InferExpression<'db> { - fn new( + pub(super) fn new( db: &'db dyn Db, expression: Expression<'db>, tcx: TypeContext<'db>, + divergent: DivergentType<'db>, ) -> InferExpression<'db> { if tcx.annotation.is_some() { - InferExpression::WithContext(ExpressionWithContext::new(db, expression, tcx)) + InferExpression::WithContextAndDivergent(ExpressionWithContextAndDivergent::new( + db, expression, tcx, divergent, + )) } else { // Drop the empty `TypeContext` to avoid the interning cost. - InferExpression::Bare(expression) + InferExpression::WithDivergent(ExpressionWithDivergent::new(db, expression, divergent)) } } + #[cfg(test)] + pub(super) fn with_divergent( + db: &'db dyn Db, + expression: Expression<'db>, + divergent: DivergentType<'db>, + ) -> InferExpression<'db> { + InferExpression::WithDivergent(ExpressionWithDivergent::new(db, expression, divergent)) + } + fn expression(self, db: &'db dyn Db) -> Expression<'db> { match self { - InferExpression::Bare(expression) => expression, - InferExpression::WithContext(expression_with_context) => { + InferExpression::WithDivergent(bare) => bare.expression(db), + InferExpression::WithContextAndDivergent(expression_with_context) => { expression_with_context.expression(db) } } @@ -317,19 +392,36 @@ impl<'db> InferExpression<'db> { fn tcx(self, db: &'db dyn Db) -> TypeContext<'db> { match self { - InferExpression::Bare(_) => TypeContext::default(), - InferExpression::WithContext(expression_with_context) => { + InferExpression::WithDivergent(_) => TypeContext::default(), + InferExpression::WithContextAndDivergent(expression_with_context) => { expression_with_context.tcx(db) } } } + + fn divergent(self, db: &'db dyn Db) -> DivergentType<'db> { + match self { + InferExpression::WithDivergent(bare) => bare.divergent(db), + InferExpression::WithContextAndDivergent(expression_with_context) => { + expression_with_context.divergent(db) + } + } + } +} + +/// An `Expression` with a `DivergentType`. +#[salsa::interned(debug, heap_size=ruff_memory_usage::heap_size)] +pub(super) struct ExpressionWithDivergent<'db> { + expression: Expression<'db>, + divergent: DivergentType<'db>, } -/// An `Expression` with a `TypeContext`. +/// An `Expression` with a `TypeContext` and a `DivergentType`. #[salsa::interned(debug, heap_size=ruff_memory_usage::heap_size)] -struct ExpressionWithContext<'db> { +pub(super) struct ExpressionWithContextAndDivergent<'db> { expression: Expression<'db>, tcx: TypeContext<'db>, + divergent: DivergentType<'db>, } /// The type context for a given expression, namely the type annotation @@ -359,7 +451,14 @@ pub(crate) fn static_expression_truthiness<'db>( db: &'db dyn Db, expression: Expression<'db>, ) -> Truthiness { - let inference = infer_expression_types_impl(db, InferExpression::Bare(expression)); + let inference = infer_expression_types_impl( + db, + InferExpression::WithDivergent(ExpressionWithDivergent::new( + db, + expression, + DivergentType::todo(db, expression.scope(db)), + )), + ); if !inference.all_places_definitely_bound() { return Truthiness::Ambiguous; @@ -389,6 +488,21 @@ fn static_expression_truthiness_cycle_initial<'db>( Truthiness::Ambiguous } +pub(super) fn infer_unpack_implicit_attribute_types<'db>( + db: &'db dyn Db, + unpack: Unpack<'db>, + divergent: DivergentType<'db>, +) -> &'db UnpackResult<'db> { + infer_unpack_types_impl(db, unpack, divergent) +} + +pub(super) fn infer_unpack_types<'db>( + db: &'db dyn Db, + unpack: Unpack<'db>, +) -> &'db UnpackResult<'db> { + infer_unpack_types_impl(db, unpack, DivergentType::todo(db, unpack.target_scope(db))) +} + /// Infer the types for an [`Unpack`] operation. /// /// This infers the expression type and performs structural match against the target expression @@ -396,14 +510,18 @@ fn static_expression_truthiness_cycle_initial<'db>( /// type of the variables involved in this unpacking along with any violations that are detected /// during this unpacking. #[salsa::tracked(returns(ref), cycle_fn=unpack_cycle_recover, cycle_initial=unpack_cycle_initial, heap_size=ruff_memory_usage::heap_size)] -pub(super) fn infer_unpack_types<'db>(db: &'db dyn Db, unpack: Unpack<'db>) -> UnpackResult<'db> { +fn infer_unpack_types_impl<'db>( + db: &'db dyn Db, + unpack: Unpack<'db>, + divergent: DivergentType<'db>, +) -> UnpackResult<'db> { let file = unpack.file(db); let module = parsed_module(db, file).load(db); let _span = tracing::trace_span!("infer_unpack_types", range=?unpack.range(db, &module), ?file) .entered(); let mut unpacker = Unpacker::new(db, unpack.target_scope(db), &module); - unpacker.unpack(unpack.target(db, &module), unpack.value(db)); + unpacker.unpack(unpack.target(db, &module), unpack.value(db), divergent); unpacker.finish() } @@ -412,12 +530,17 @@ fn unpack_cycle_recover<'db>( _value: &UnpackResult<'db>, _count: u32, _unpack: Unpack<'db>, + _divergent: DivergentType<'db>, ) -> salsa::CycleRecoveryAction> { salsa::CycleRecoveryAction::Iterate } -fn unpack_cycle_initial<'db>(_db: &'db dyn Db, _unpack: Unpack<'db>) -> UnpackResult<'db> { - UnpackResult::cycle_initial(Type::Never) +fn unpack_cycle_initial<'db>( + _db: &'db dyn Db, + _unpack: Unpack<'db>, + divergent: DivergentType<'db>, +) -> UnpackResult<'db> { + UnpackResult::cycle_initial(Type::divergent(divergent)) } /// Returns the type of the nearest enclosing class for the given scope. @@ -478,15 +601,15 @@ pub(crate) struct ScopeInference<'db> { expressions: FxHashMap>, /// The extra data that is only present for few inference regions. - extra: Option>, + extra: Option>>, scope: ScopeId<'db>, } #[derive(Debug, Eq, PartialEq, get_size2::GetSize, salsa::Update, Default)] -struct ScopeInferenceExtra { - /// Whether to use the fallback type for missing expressions/bindings/declarations or recursive type inference. - cycle_fallback: bool, +struct ScopeInferenceExtra<'db> { + /// The fallback type for missing expressions/bindings/declarations or recursive type inference. + cycle_recovery: Option>, /// The diagnostics for this region. diagnostics: TypeCheckDiagnostics, @@ -498,10 +621,10 @@ struct ScopeInferenceExtra { } impl<'db> ScopeInference<'db> { - fn cycle_initial(scope: ScopeId<'db>) -> Self { + fn cycle_initial(divergent: DivergentType<'db>, scope: ScopeId<'db>) -> Self { Self { extra: Some(Box::new(ScopeInferenceExtra { - cycle_fallback: true, + cycle_recovery: Some(divergent), ..ScopeInferenceExtra::default() })), expressions: FxHashMap::default(), @@ -531,10 +654,7 @@ impl<'db> ScopeInference<'db> { fn fallback_type(&self) -> Option> { self.extra .as_ref() - .is_some_and(|extra| extra.cycle_fallback) - .then_some(Type::Dynamic(DynamicType::Divergent(DivergentType { - scope: self.scope, - }))) + .and_then(|extra| extra.cycle_recovery.map(Type::divergent)) } /// Returns the inferred return type of this function body (union of all possible return types), @@ -549,7 +669,7 @@ impl<'db> ScopeInference<'db> { } let mut union = UnionBuilder::new(db); - let div = Type::Dynamic(DynamicType::Divergent(DivergentType { scope: self.scope })); + let div = Type::divergent(DivergentType::function_return_type(db, self.scope)); if let Some(fallback_type) = self.fallback_type() { union = union.add(fallback_type); } @@ -616,6 +736,7 @@ pub(crate) struct DefinitionInference<'db> { expressions: FxHashMap>, /// The scope this region is part of. + #[cfg(debug_assertions)] scope: ScopeId<'db>, /// The types of every binding in this region. @@ -637,8 +758,8 @@ pub(crate) struct DefinitionInference<'db> { #[derive(Debug, Eq, PartialEq, get_size2::GetSize, salsa::Update, Default)] struct DefinitionInferenceExtra<'db> { - /// Whether to use the fallback type for missing expressions/bindings/declarations or recursive type inference. - cycle_fallback: bool, + /// The fallback type for missing expressions/bindings/declarations or recursive type inference. + cycle_recovery: Option>, /// The definitions that are deferred. deferred: Box<[Definition<'db>]>, @@ -651,16 +772,17 @@ struct DefinitionInferenceExtra<'db> { } impl<'db> DefinitionInference<'db> { - fn cycle_initial(scope: ScopeId<'db>) -> Self { + fn cycle_initial(scope: ScopeId<'db>, divergent: DivergentType<'db>) -> Self { let _ = scope; Self { expressions: FxHashMap::default(), bindings: Box::default(), declarations: Box::default(), + #[cfg(debug_assertions)] scope, extra: Some(Box::new(DefinitionInferenceExtra { - cycle_fallback: true, + cycle_recovery: Some(divergent), ..DefinitionInferenceExtra::default() })), } @@ -732,10 +854,7 @@ impl<'db> DefinitionInference<'db> { fn fallback_type(&self) -> Option> { self.extra .as_ref() - .is_some_and(|extra| extra.cycle_fallback) - .then_some(Type::Dynamic(DynamicType::Divergent(DivergentType { - scope: self.scope, - }))) + .and_then(|extra| extra.cycle_recovery.map(Type::divergent)) } pub(crate) fn undecorated_type(&self) -> Option> { @@ -752,6 +871,7 @@ pub(crate) struct ExpressionInference<'db> { extra: Option>>, /// The scope this region is part of. + #[cfg(debug_assertions)] scope: ScopeId<'db>, } @@ -766,23 +886,24 @@ struct ExpressionInferenceExtra<'db> { /// The diagnostics for this region. diagnostics: TypeCheckDiagnostics, - /// Whether to use the fallback type for missing expressions/bindings/declarations or recursive type inference. - cycle_fallback: bool, + /// The fallback type for missing expressions/bindings/declarations or recursive type inference. + cycle_recovery: Option>, /// `true` if all places in this expression are definitely bound all_definitely_bound: bool, } impl<'db> ExpressionInference<'db> { - fn cycle_initial(scope: ScopeId<'db>) -> Self { + fn cycle_initial(scope: ScopeId<'db>, divergent: DivergentType<'db>) -> Self { + let _ = scope; Self { extra: Some(Box::new(ExpressionInferenceExtra { - cycle_fallback: true, + cycle_recovery: Some(divergent), all_definitely_bound: true, ..ExpressionInferenceExtra::default() })), expressions: FxHashMap::default(), - + #[cfg(debug_assertions)] scope, } } @@ -805,10 +926,7 @@ impl<'db> ExpressionInference<'db> { fn fallback_type(&self) -> Option> { self.extra .as_ref() - .is_some_and(|extra| extra.cycle_fallback) - .then_some(Type::Dynamic(DynamicType::Divergent(DivergentType { - scope: self.scope, - }))) + .and_then(|extra| extra.cycle_recovery.map(Type::divergent)) } /// Returns true if all places in this expression are definitely bound. diff --git a/crates/ty_python_semantic/src/types/infer/builder.rs b/crates/ty_python_semantic/src/types/infer/builder.rs index afa26699b6299..bff834fb29099 100644 --- a/crates/ty_python_semantic/src/types/infer/builder.rs +++ b/crates/ty_python_semantic/src/types/infer/builder.rs @@ -253,8 +253,8 @@ pub(super) struct TypeInferenceBuilder<'db, 'ast> { /// For function definitions, the undecorated type of the function. undecorated_type: Option>, - /// Whether to use the fallback type for missing expressions/bindings/declarations or recursive type inference. - cycle_fallback: bool, + /// The fallback type for missing expressions/bindings/declarations or recursive type inference. + cycle_recovery: Option>, /// `true` if all places in this expression are definitely bound all_definitely_bound: bool, @@ -290,16 +290,30 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> { typevar_binding_context: None, deferred: VecSet::default(), undecorated_type: None, - cycle_fallback: false, + cycle_recovery: None, all_definitely_bound: true, } } fn fallback_type(&self) -> Option> { - self.cycle_fallback - .then_some(Type::Dynamic(DynamicType::Divergent(DivergentType { - scope: self.scope, - }))) + self.cycle_recovery.map(Type::divergent) + } + + fn merge_cycle_recovery(&mut self, other: Option>) { + match (self.cycle_recovery, other) { + (None, _) | (Some(_), None) => { + self.cycle_recovery = self.cycle_recovery.or(other); + } + (Some(self_), Some(other)) => { + if self_ == other { + // OK, do nothing + } else if self_.kind(self.db()).is_todo() { + self.cycle_recovery = Some(other); + } else { + panic!("Cannot merge divergent types"); + } + } + } } fn extend_definition(&mut self, inference: &DefinitionInference<'db>) { @@ -313,7 +327,7 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> { } if let Some(extra) = &inference.extra { - self.cycle_fallback |= extra.cycle_fallback; + self.merge_cycle_recovery(extra.cycle_recovery); self.context.extend(&extra.diagnostics); self.deferred.extend(extra.deferred.iter().copied()); } @@ -330,7 +344,7 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> { if let Some(extra) = &inference.extra { self.context.extend(&extra.diagnostics); - self.cycle_fallback |= extra.cycle_fallback; + self.merge_cycle_recovery(extra.cycle_recovery); if !matches!(self.region, InferenceRegion::Scope(..)) { self.bindings.extend(extra.bindings.iter().copied()); @@ -5239,15 +5253,9 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> { } = tuple; let db = self.db(); - let divergent = Type::divergent(self.scope()); let element_types = elts.iter().map(|element| { // TODO: Use the type context for more precise inference. - let element_type = self.infer_expression(element, TypeContext::default()); - if element_type.has_divergent_type(self.db(), divergent) { - divergent - } else { - element_type - } + self.infer_expression(element, TypeContext::default()) }); Type::heterogeneous_tuple(db, element_types) @@ -8876,7 +8884,7 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> { bindings, declarations, deferred, - cycle_fallback, + cycle_recovery, all_definitely_bound, // Ignored; only relevant to definition regions @@ -8904,7 +8912,7 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> { ); let extra = - (cycle_fallback || !bindings.is_empty() || !diagnostics.is_empty() || !all_definitely_bound).then(|| { + (cycle_recovery.is_some() || !bindings.is_empty() || !diagnostics.is_empty() || !all_definitely_bound).then(|| { if bindings.len() > 20 { tracing::debug!( "Inferred expression region `{:?}` contains {} bindings. Lookups by linear scan might be slow.", @@ -8916,7 +8924,7 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> { Box::new(ExpressionInferenceExtra { bindings: bindings.into_boxed_slice(), diagnostics, - cycle_fallback, + cycle_recovery, all_definitely_bound, }) }); @@ -8926,6 +8934,7 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> { ExpressionInference { expressions, extra, + #[cfg(debug_assertions)] scope, } } @@ -8940,7 +8949,7 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> { bindings, declarations, deferred, - cycle_fallback, + cycle_recovery, undecorated_type, all_definitely_bound: _, // builder only state @@ -8956,12 +8965,12 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> { let diagnostics = context.finish(); let extra = (!diagnostics.is_empty() - || cycle_fallback + || cycle_recovery.is_some() || undecorated_type.is_some() || !deferred.is_empty()) .then(|| { Box::new(DefinitionInferenceExtra { - cycle_fallback, + cycle_recovery, deferred: deferred.into_boxed_slice(), diagnostics, undecorated_type, @@ -8988,6 +8997,7 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> { DefinitionInference { expressions, + #[cfg(debug_assertions)] scope, bindings: bindings.into_boxed_slice(), declarations: declarations.into_boxed_slice(), @@ -9003,7 +9013,7 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> { context, mut expressions, scope, - cycle_fallback, + cycle_recovery, // Ignored, because scope types are never extended into other scopes. deferred: _, @@ -9025,18 +9035,20 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> { let diagnostics = context.finish(); - let extra = (!diagnostics.is_empty() || cycle_fallback || scope.is_non_lambda_function(db)) - .then(|| { - let returnees = returnees - .into_iter() - .map(|returnee| returnee.expression) - .collect(); - Box::new(ScopeInferenceExtra { - cycle_fallback, - diagnostics, - returnees, - }) - }); + let extra = (!diagnostics.is_empty() + || cycle_recovery.is_some() + || scope.is_non_lambda_function(db)) + .then(|| { + let returnees = returnees + .into_iter() + .map(|returnee| returnee.expression) + .collect(); + Box::new(ScopeInferenceExtra { + cycle_recovery, + diagnostics, + returnees, + }) + }); expressions.shrink_to_fit(); diff --git a/crates/ty_python_semantic/src/types/infer/tests.rs b/crates/ty_python_semantic/src/types/infer/tests.rs index ec4a0d46c8e0f..8aa437ea1caad 100644 --- a/crates/ty_python_semantic/src/types/infer/tests.rs +++ b/crates/ty_python_semantic/src/types/infer/tests.rs @@ -442,10 +442,11 @@ fn dependency_implicit_instance_attribute() -> anyhow::Result<()> { assert_eq!(attr_ty.display(&db).to_string(), "Unknown | str | None"); db.take_salsa_events() }; + let div = DivergentType::todo(&db, global_scope(&db, file_main)); assert_function_query_was_run( &db, infer_expression_types_impl, - InferExpression::Bare(x_rhs_expression(&db)), + InferExpression::with_divergent(&db, x_rhs_expression(&db), div), &events, ); @@ -466,11 +467,11 @@ fn dependency_implicit_instance_attribute() -> anyhow::Result<()> { assert_eq!(attr_ty.display(&db).to_string(), "Unknown | str | None"); db.take_salsa_events() }; - + let div = DivergentType::todo(&db, global_scope(&db, file_main)); assert_function_query_was_not_run( &db, infer_expression_types_impl, - InferExpression::Bare(x_rhs_expression(&db)), + InferExpression::with_divergent(&db, x_rhs_expression(&db), div), &events, ); @@ -534,10 +535,11 @@ fn dependency_own_instance_member() -> anyhow::Result<()> { assert_eq!(attr_ty.display(&db).to_string(), "Unknown | str | None"); db.take_salsa_events() }; + let div = DivergentType::todo(&db, global_scope(&db, file_main)); assert_function_query_was_run( &db, infer_expression_types_impl, - InferExpression::Bare(x_rhs_expression(&db)), + InferExpression::with_divergent(&db, x_rhs_expression(&db), div), &events, ); @@ -560,11 +562,11 @@ fn dependency_own_instance_member() -> anyhow::Result<()> { assert_eq!(attr_ty.display(&db).to_string(), "Unknown | str | None"); db.take_salsa_events() }; - + let div = DivergentType::todo(&db, global_scope(&db, file_main)); assert_function_query_was_not_run( &db, infer_expression_types_impl, - InferExpression::Bare(x_rhs_expression(&db)), + InferExpression::with_divergent(&db, x_rhs_expression(&db), div), &events, ); @@ -631,10 +633,11 @@ fn dependency_implicit_class_member() -> anyhow::Result<()> { assert_eq!(attr_ty.display(&db).to_string(), "Unknown | str"); db.take_salsa_events() }; + let div = DivergentType::todo(&db, global_scope(&db, file_main)); assert_function_query_was_run( &db, infer_expression_types_impl, - InferExpression::Bare(x_rhs_expression(&db)), + InferExpression::with_divergent(&db, x_rhs_expression(&db), div), &events, ); @@ -659,11 +662,11 @@ fn dependency_implicit_class_member() -> anyhow::Result<()> { assert_eq!(attr_ty.display(&db).to_string(), "Unknown | str"); db.take_salsa_events() }; - + let div = DivergentType::todo(&db, global_scope(&db, file_main)); assert_function_query_was_not_run( &db, infer_expression_types_impl, - InferExpression::Bare(x_rhs_expression(&db)), + InferExpression::with_divergent(&db, x_rhs_expression(&db), div), &events, ); @@ -705,10 +708,11 @@ fn call_type_doesnt_rerun_when_only_callee_changed() -> anyhow::Result<()> { let call = &*module.syntax().body[1].as_assign_stmt().unwrap().value; let foo_call = semantic_index(&db, bar).expression(call); + let div = DivergentType::todo(&db, global_scope(&db, bar)); assert_function_query_was_run( &db, infer_expression_types_impl, - InferExpression::Bare(foo_call), + InferExpression::with_divergent(&db, foo_call, div), &events, ); @@ -736,10 +740,11 @@ fn call_type_doesnt_rerun_when_only_callee_changed() -> anyhow::Result<()> { let call = &*module.syntax().body[1].as_assign_stmt().unwrap().value; let foo_call = semantic_index(&db, bar).expression(call); + let div = DivergentType::todo(&db, global_scope(&db, bar)); assert_function_query_was_not_run( &db, infer_expression_types_impl, - InferExpression::Bare(foo_call), + InferExpression::with_divergent(&db, foo_call, div), &events, ); diff --git a/crates/ty_python_semantic/src/types/type_ordering.rs b/crates/ty_python_semantic/src/types/type_ordering.rs index 3a0f1cd251113..790700748e2b7 100644 --- a/crates/ty_python_semantic/src/types/type_ordering.rs +++ b/crates/ty_python_semantic/src/types/type_ordering.rs @@ -118,7 +118,7 @@ pub(super) fn union_or_intersection_elements_ordering<'db>( (SubclassOfInner::Class(_), _) => Ordering::Less, (_, SubclassOfInner::Class(_)) => Ordering::Greater, (SubclassOfInner::Dynamic(left), SubclassOfInner::Dynamic(right)) => { - dynamic_elements_ordering(left, right) + dynamic_elements_ordering(db, left, right) } } } @@ -172,7 +172,7 @@ pub(super) fn union_or_intersection_elements_ordering<'db>( (_, ClassBase::TypedDict) => Ordering::Greater, (ClassBase::Dynamic(left), ClassBase::Dynamic(right)) => { - dynamic_elements_ordering(left, right) + dynamic_elements_ordering(db, left, right) } }) .then_with(|| match (left.owner(db), right.owner(db)) { @@ -185,7 +185,7 @@ pub(super) fn union_or_intersection_elements_ordering<'db>( (SuperOwnerKind::Instance(_), _) => Ordering::Less, (_, SuperOwnerKind::Instance(_)) => Ordering::Greater, (SuperOwnerKind::Dynamic(left), SuperOwnerKind::Dynamic(right)) => { - dynamic_elements_ordering(left, right) + dynamic_elements_ordering(db, left, right) } }) } @@ -204,7 +204,7 @@ pub(super) fn union_or_intersection_elements_ordering<'db>( (Type::PropertyInstance(_), _) => Ordering::Less, (_, Type::PropertyInstance(_)) => Ordering::Greater, - (Type::Dynamic(left), Type::Dynamic(right)) => dynamic_elements_ordering(*left, *right), + (Type::Dynamic(left), Type::Dynamic(right)) => dynamic_elements_ordering(db, *left, *right), (Type::Dynamic(_), _) => Ordering::Less, (_, Type::Dynamic(_)) => Ordering::Greater, @@ -253,7 +253,11 @@ pub(super) fn union_or_intersection_elements_ordering<'db>( } /// Determine a canonical order for two instances of [`DynamicType`]. -fn dynamic_elements_ordering(left: DynamicType, right: DynamicType) -> Ordering { +fn dynamic_elements_ordering<'db>( + db: &'db dyn Db, + left: DynamicType<'db>, + right: DynamicType<'db>, +) -> Ordering { match (left, right) { (DynamicType::Any, _) => Ordering::Less, (_, DynamicType::Any) => Ordering::Greater, @@ -277,7 +281,7 @@ fn dynamic_elements_ordering(left: DynamicType, right: DynamicType) -> Ordering (_, DynamicType::TodoTypeAlias) => Ordering::Greater, (DynamicType::Divergent(left), DynamicType::Divergent(right)) => { - left.scope.cmp(&right.scope) + left.scope(db).cmp(&right.scope(db)) } (DynamicType::Divergent(_), _) => Ordering::Less, (_, DynamicType::Divergent(_)) => Ordering::Greater, diff --git a/crates/ty_python_semantic/src/types/unpacker.rs b/crates/ty_python_semantic/src/types/unpacker.rs index b448053295f82..7d7efe7112e1f 100644 --- a/crates/ty_python_semantic/src/types/unpacker.rs +++ b/crates/ty_python_semantic/src/types/unpacker.rs @@ -8,8 +8,9 @@ use ruff_python_ast::{self as ast, AnyNodeRef}; use crate::Db; use crate::semantic_index::ast_ids::node_key::ExpressionNodeKey; use crate::semantic_index::scope::ScopeId; +use crate::types::infer::{InferExpression, infer_expression_types_impl}; use crate::types::tuple::{ResizeTupleError, Tuple, TupleLength, TupleSpec, TupleUnpacker}; -use crate::types::{Type, TypeCheckDiagnostics, TypeContext, infer_expression_types}; +use crate::types::{DivergentType, Type, TypeCheckDiagnostics, TypeContext}; use crate::unpack::{UnpackKind, UnpackValue}; use super::context::InferContext; @@ -42,15 +43,25 @@ impl<'db, 'ast> Unpacker<'db, 'ast> { } /// Unpack the value to the target expression. - pub(crate) fn unpack(&mut self, target: &ast::Expr, value: UnpackValue<'db>) { + pub(crate) fn unpack( + &mut self, + target: &ast::Expr, + value: UnpackValue<'db>, + divergent: DivergentType<'db>, + ) { debug_assert!( matches!(target, ast::Expr::List(_) | ast::Expr::Tuple(_)), "Unpacking target must be a list or tuple expression" ); - let value_type = - infer_expression_types(self.db(), value.expression(), TypeContext::default()) - .expression_type(value.expression().node_ref(self.db(), self.module())); + let input = InferExpression::new( + self.db(), + value.expression(), + TypeContext::default(), + divergent, + ); + let value_type = infer_expression_types_impl(self.db(), input) + .expression_type(value.expression().node_ref(self.db(), self.module())); let value_type = match value.kind() { UnpackKind::Assign => { From 74af681bb5fb5b501f5742d4034169ce5f86fecc Mon Sep 17 00:00:00 2001 From: Shunsuke Shibayama Date: Sun, 14 Sep 2025 13:53:54 +0900 Subject: [PATCH 069/105] union the previous cycle type in `implicit_attribute` --- .../resources/mdtest/attributes.md | 17 +++++ crates/ty_python_semantic/src/types.rs | 70 +++++++++++++++---- crates/ty_python_semantic/src/types/class.rs | 21 ++++-- .../src/types/infer/builder.rs | 2 + .../src/types/infer/tests.rs | 27 ++++++- 5 files changed, 119 insertions(+), 18 deletions(-) diff --git a/crates/ty_python_semantic/resources/mdtest/attributes.md b/crates/ty_python_semantic/resources/mdtest/attributes.md index 1ef154ee3600d..0c25ee00ad85c 100644 --- a/crates/ty_python_semantic/resources/mdtest/attributes.md +++ b/crates/ty_python_semantic/resources/mdtest/attributes.md @@ -2281,6 +2281,23 @@ class B: reveal_type(B().x) # revealed: Unknown | Literal[1] reveal_type(A().x) # revealed: Unknown | Literal[1] + +class Base: + def flip(self) -> "Sub": + return Sub() + +class Sub(Base): + def flip(self) -> "Base": + return Base() + +class C2: + def __init__(self, x: Sub): + self.x = x + + def replace_with(self, other: "C2"): + self.x = other.x.flip() + +reveal_type(C2(Sub()).x) # revealed: Unknown | Base ``` This case additionally tests our union/intersection simplification logic: diff --git a/crates/ty_python_semantic/src/types.rs b/crates/ty_python_semantic/src/types.rs index ac3a0e227782a..7cf97487d654b 100644 --- a/crates/ty_python_semantic/src/types.rs +++ b/crates/ty_python_semantic/src/types.rs @@ -3025,7 +3025,7 @@ impl<'db> Type<'db> { policy: MemberLookupPolicy, ) -> PlaceAndQualifiers<'db> { tracing::trace!("class_member: {}.{}", self.display(db), name); - match self { + let result = match self { Type::Union(union) => union.map_with_boundness_and_qualifiers(db, |elem| { elem.class_member_with_policy(db, name.clone(), policy) }), @@ -3043,7 +3043,28 @@ impl<'db> Type<'db> { .expect( "`Type::find_name_in_mro()` should return `Some()` when called on a meta-type", ), - } + }; + result.map_type(|ty| { + // In fixed-point iteration of type inference, the member type must be monotonically widened and not "oscillate". + // Here, monotonicity is guaranteed by pre-unioning the type of the previous iteration into the current result. + let previous_cycle_value = self.class_member_with_policy(db, name.clone(), policy); + + let ty = if let Some(previous_ty) = previous_cycle_value.place.ignore_possibly_unbound() + { + UnionType::from_elements(db, [ty, previous_ty]) + } else { + ty + }; + + if let Place::Type(div @ Type::Dynamic(DynamicType::Divergent(_)), _) = + class_lookup_cycle_initial(db, self, name, policy).place + { + let visitor = NormalizedVisitor::default().recursive(div); + ty.normalized_impl(db, &visitor) + } else { + ty + } + }) } /// This function roughly corresponds to looking up an attribute in the `__dict__` of an object. @@ -3497,7 +3518,7 @@ impl<'db> Type<'db> { let name_str = name.as_str(); - match self { + let result = match self { Type::Union(union) => union.map_with_boundness_and_qualifiers(db, |elem| { elem.member_lookup_with_policy(db, name_str.into(), policy) }), @@ -3565,20 +3586,20 @@ impl<'db> Type<'db> { // If an attribute is not available on the bound method object, // it will be looked up on the underlying function object: Type::FunctionLiteral(bound_method.function(db)) - .member_lookup_with_policy(db, name, policy) + .member_lookup_with_policy(db, name.clone(), policy) }) } }, Type::KnownBoundMethod(method) => method .class() .to_instance(db) - .member_lookup_with_policy(db, name, policy), + .member_lookup_with_policy(db, name.clone(), policy), Type::WrapperDescriptor(_) => KnownClass::WrapperDescriptorType .to_instance(db) - .member_lookup_with_policy(db, name, policy), + .member_lookup_with_policy(db, name.clone(), policy), Type::DataclassDecorator(_) => KnownClass::FunctionType .to_instance(db) - .member_lookup_with_policy(db, name, policy), + .member_lookup_with_policy(db, name.clone(), policy), Type::Callable(_) | Type::DataclassTransformer(_) if name_str == "__call__" => { Place::bound(self).into() @@ -3586,10 +3607,10 @@ impl<'db> Type<'db> { Type::Callable(callable) if callable.is_function_like(db) => KnownClass::FunctionType .to_instance(db) - .member_lookup_with_policy(db, name, policy), + .member_lookup_with_policy(db, name.clone(), policy), Type::Callable(_) | Type::DataclassTransformer(_) => { - Type::object().member_lookup_with_policy(db, name, policy) + Type::object().member_lookup_with_policy(db, name.clone(), policy) } Type::NominalInstance(instance) @@ -3647,9 +3668,11 @@ impl<'db> Type<'db> { policy, ), - Type::TypeAlias(alias) => alias - .value_type(db) - .member_lookup_with_policy(db, name, policy), + Type::TypeAlias(alias) => { + alias + .value_type(db) + .member_lookup_with_policy(db, name.clone(), policy) + } Type::NominalInstance(..) | Type::ProtocolInstance(..) @@ -3811,7 +3834,28 @@ impl<'db> Type<'db> { .try_call_dunder_get_on_attribute(db, owner_attr.clone()) .unwrap_or(owner_attr) } - } + }; + result.map_type(|ty| { + // In fixed-point iteration of type inference, the member type must be monotonically widened and not "oscillate". + // Here, monotonicity is guaranteed by pre-unioning the type of the previous iteration into the current result. + let previous_cycle_value = self.member_lookup_with_policy(db, name.clone(), policy); + + let ty = if let Some(previous_ty) = previous_cycle_value.place.ignore_possibly_unbound() + { + UnionType::from_elements(db, [ty, previous_ty]) + } else { + ty + }; + + if let Place::Type(div @ Type::Dynamic(DynamicType::Divergent(_)), _) = + member_lookup_cycle_initial(db, self, name, policy).place + { + let visotor = NormalizedVisitor::default().recursive(div); + ty.normalized_impl(db, &visotor) + } else { + ty + } + }) } /// Resolves the boolean value of the type and falls back to [`Truthiness::Ambiguous`] if the type doesn't implement `__bool__` correctly. diff --git a/crates/ty_python_semantic/src/types/class.rs b/crates/ty_python_semantic/src/types/class.rs index 6fca8e8ee2aea..94c8cc7eb6b5a 100644 --- a/crates/ty_python_semantic/src/types/class.rs +++ b/crates/ty_python_semantic/src/types/class.rs @@ -118,9 +118,9 @@ fn implicit_attribute_initial<'db>( _target_method_decorator: MethodDecorator, ) -> PlaceAndQualifiers<'db> { let implicit_attribute_table = implicit_attribute_table(db, class_body_scope); - let implicit_attr_id = implicit_attribute_table - .symbol_id(&name) - .expect("Implicit attribute must exist in the table"); + let Some(implicit_attr_id) = implicit_attribute_table.symbol_id(&name) else { + return Place::bound(Type::divergent(DivergentType::todo(db, class_body_scope))).into(); + }; Place::bound(Type::divergent(DivergentType::implicit_attribute( db, class_body_scope, @@ -2885,7 +2885,7 @@ impl<'db> ClassLiteral<'db> { cycle_initial=implicit_attribute_initial, heap_size=ruff_memory_usage::heap_size, )] - fn implicit_attribute_inner( + pub(super) fn implicit_attribute_inner( db: &'db dyn Db, class_body_scope: ScopeId<'db>, name: String, @@ -2987,6 +2987,19 @@ impl<'db> ClassLiteral<'db> { if !qualifiers.contains(TypeQualifiers::FINAL) { union_of_inferred_types = union_of_inferred_types.add(Type::unknown()); } + if let Place::Type(previous_cycle_type, _) = Self::implicit_attribute_inner( + db, + class_body_scope, + name.clone(), + target_method_decorator, + ) + .place + { + // In fixed-point iteration of type inference, the attribute type must be monotonically widened and not "oscillate". + // Here, monotonicity is guaranteed by pre-unioning the type of the previous iteration into the current result. + union_of_inferred_types = + union_of_inferred_types.add(previous_cycle_type.normalized_impl(db, &visitor)); + } for (attribute_assignments, method_scope_id) in attribute_assignments(db, class_body_scope, &name) diff --git a/crates/ty_python_semantic/src/types/infer/builder.rs b/crates/ty_python_semantic/src/types/infer/builder.rs index bff834fb29099..6254a2b0f8b65 100644 --- a/crates/ty_python_semantic/src/types/infer/builder.rs +++ b/crates/ty_python_semantic/src/types/infer/builder.rs @@ -317,6 +317,7 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> { } fn extend_definition(&mut self, inference: &DefinitionInference<'db>) { + #[cfg(debug_assertions)] assert_eq!(self.scope, inference.scope); self.expressions.extend(inference.expressions.iter()); @@ -334,6 +335,7 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> { } fn extend_expression(&mut self, inference: &ExpressionInference<'db>) { + #[cfg(debug_assertions)] assert_eq!(self.scope, inference.scope); self.extend_expression_unchecked(inference); diff --git a/crates/ty_python_semantic/src/types/infer/tests.rs b/crates/ty_python_semantic/src/types/infer/tests.rs index 8aa437ea1caad..8d235264d9b15 100644 --- a/crates/ty_python_semantic/src/types/infer/tests.rs +++ b/crates/ty_python_semantic/src/types/infer/tests.rs @@ -5,14 +5,33 @@ use crate::place::{ConsideredDefinitions, Place, global_symbol}; use crate::semantic_index::definition::Definition; use crate::semantic_index::scope::FileScopeId; use crate::semantic_index::{global_scope, place_table, semantic_index, use_def_map}; -use crate::types::{KnownClass, KnownInstanceType, UnionType, check_types}; +use crate::types::function::FunctionType; +use crate::types::{BoundMethodType, KnownClass, KnownInstanceType, UnionType, check_types}; use ruff_db::diagnostic::Diagnostic; use ruff_db::files::{File, system_path_to_file}; use ruff_db::system::DbWithWritableSystem as _; use ruff_db::testing::{assert_function_query_was_not_run, assert_function_query_was_run}; +use salsa::Database; use super::*; +fn __() { + let _ = &Type::member_lookup_with_policy; + let _ = &Type::class_member_with_policy; + let _ = &FunctionType::infer_return_type; + let _ = &BoundMethodType::infer_return_type; + let _ = &ClassLiteral::implicit_attribute_inner; +} +/// These queries refer to a value ​​from the previous cycle to ensure convergence. +/// Therefore, even when convergence is apparent, they will cycle at least once. +const QUERIES_USE_PREVIOUS_CYCLE_VALUE: [&str; 5] = [ + "Type < 'db >::member_lookup_with_policy_", + "Type < 'db >::class_member_with_policy_", + "FunctionType < 'db >::infer_return_type_", + "BoundMethodType < 'db >::infer_return_type_", + "ClassLiteral < 'db >::implicit_attribute_inner_", +]; + #[track_caller] fn get_symbol<'db>( db: &'db TestDb, @@ -273,6 +292,12 @@ fn unbound_symbol_no_reachability_constraint_check() { .iter() .filter_map(|event| { if let salsa::EventKind::WillIterateCycle { database_key, .. } = event.kind { + if QUERIES_USE_PREVIOUS_CYCLE_VALUE.contains( + &db.ingredient_debug_name(database_key.ingredient_index()) + .as_ref(), + ) { + return None; + } Some(format!("{database_key:?}")) } else { None From 2a730605727e5c8b69875b115682318c6d93c25e Mon Sep 17 00:00:00 2001 From: Shunsuke Shibayama Date: Sun, 14 Sep 2025 16:40:29 +0900 Subject: [PATCH 070/105] Add `PossiblyRecursive{Scope, Definition, Expression}` --- crates/ty_python_semantic/src/types/infer.rs | 117 +++++++++++-------- 1 file changed, 70 insertions(+), 47 deletions(-) diff --git a/crates/ty_python_semantic/src/types/infer.rs b/crates/ty_python_semantic/src/types/infer.rs index b6a77da822714..020d40cd1de52 100644 --- a/crates/ty_python_semantic/src/types/infer.rs +++ b/crates/ty_python_semantic/src/types/infer.rs @@ -31,7 +31,7 @@ //! //! Many of our type inference Salsa queries implement cycle recovery via fixed-point iteration. In //! general, they initiate fixed-point iteration by returning an `Inference` type that returns -//! `Type::Never` for all expressions, bindings, and declarations, and then they continue iterating +//! the `Divergent` type for all expressions, bindings, and declarations, and then they continue iterating //! the query cycle until a fixed-point is reached. Salsa has a built-in fixed limit on the number //! of iterations, so if we fail to converge, Salsa will eventually panic. (This should of course //! be considered a bug.) @@ -61,11 +61,24 @@ mod builder; #[cfg(test)] mod tests; +/// A scope that may be recursive. +#[salsa::interned(debug, heap_size=ruff_memory_usage::heap_size)] +pub(super) struct PossiblyRecursiveScope<'db> { + scope: ScopeId<'db>, + divergent: DivergentType<'db>, +} + +/// Infer all types for a [`ScopeId`], including all definitions and expressions in that scope. +/// Use when checking a scope, or needing to provide a type for an arbitrary expression in the +/// scope. pub(crate) fn infer_scope_types<'db>( db: &'db dyn Db, scope: ScopeId<'db>, ) -> &'db ScopeInference<'db> { - infer_scope_types_impl(db, scope, DivergentType::todo(db, scope)) + infer_scope_types_impl( + db, + PossiblyRecursiveScope::new(db, scope, DivergentType::todo(db, scope)), + ) } pub(crate) fn infer_function_scope_types<'db>( @@ -73,19 +86,15 @@ pub(crate) fn infer_function_scope_types<'db>( scope: ScopeId<'db>, divergent: DivergentType<'db>, ) -> &'db ScopeInference<'db> { - infer_scope_types_impl(db, scope, divergent) + infer_scope_types_impl(db, PossiblyRecursiveScope::new(db, scope, divergent)) } -/// Infer all types for a [`ScopeId`], including all definitions and expressions in that scope. -/// Use when checking a scope, or needing to provide a type for an arbitrary expression in the -/// scope. #[salsa::tracked(returns(ref), cycle_fn=scope_cycle_recover, cycle_initial=scope_cycle_initial, heap_size=ruff_memory_usage::heap_size)] fn infer_scope_types_impl<'db>( db: &'db dyn Db, - scope: ScopeId<'db>, - _divergent: DivergentType<'db>, + scope: PossiblyRecursiveScope<'db>, ) -> ScopeInference<'db> { - let file = scope.file(db); + let file = scope.scope(db).file(db); let _span = tracing::trace_span!("infer_scope_types", scope=?scope.as_id(), ?file).entered(); let module = parsed_module(db, file).load(db); @@ -94,77 +103,91 @@ fn infer_scope_types_impl<'db>( // The isolation of the query is by the return inferred types. let index = semantic_index(db, file); - TypeInferenceBuilder::new(db, InferenceRegion::Scope(scope), index, &module).finish_scope() + TypeInferenceBuilder::new(db, InferenceRegion::Scope(scope.scope(db)), index, &module) + .finish_scope() } fn scope_cycle_recover<'db>( _db: &'db dyn Db, _value: &ScopeInference<'db>, _count: u32, - _scope: ScopeId<'db>, - _divergent: DivergentType<'db>, + _scope: PossiblyRecursiveScope<'db>, ) -> salsa::CycleRecoveryAction> { salsa::CycleRecoveryAction::Iterate } fn scope_cycle_initial<'db>( - _db: &'db dyn Db, - scope: ScopeId<'db>, - divergent: DivergentType<'db>, + db: &'db dyn Db, + scope: PossiblyRecursiveScope<'db>, ) -> ScopeInference<'db> { - ScopeInference::cycle_initial(divergent, scope) + ScopeInference::cycle_initial(scope.divergent(db), scope.scope(db)) +} + +/// A [`Definition`] that may be recursive. +#[salsa::interned(debug, heap_size=ruff_memory_usage::heap_size)] +pub(super) struct PossiblyRecursiveDefinition<'db> { + definition: Definition<'db>, + divergent: DivergentType<'db>, } +/// Infer all types for a [`Definition`] (including sub-expressions). +/// Use when resolving a place use or public type of a place. pub(crate) fn infer_definition_types<'db>( db: &'db dyn Db, definition: Definition<'db>, ) -> &'db DefinitionInference<'db> { infer_definition_types_impl( db, - definition, - DivergentType::todo(db, definition.scope(db)), + PossiblyRecursiveDefinition::new( + db, + definition, + DivergentType::todo(db, definition.scope(db)), + ), ) } -/// Infer all types for a [`Definition`] (including sub-expressions). -/// Use when resolving a place use or public type of a place. #[salsa::tracked(returns(ref), cycle_fn=definition_cycle_recover, cycle_initial=definition_cycle_initial, heap_size=ruff_memory_usage::heap_size)] fn infer_definition_types_impl<'db>( db: &'db dyn Db, - definition: Definition<'db>, - _divergent: DivergentType<'db>, + definition: PossiblyRecursiveDefinition<'db>, ) -> DefinitionInference<'db> { - let file = definition.file(db); + let file = definition.definition(db).file(db); let module = parsed_module(db, file).load(db); let _span = tracing::trace_span!( "infer_definition_types", - range = ?definition.kind(db).target_range(&module), + range = ?definition.definition(db).kind(db).target_range(&module), ?file ) .entered(); let index = semantic_index(db, file); - TypeInferenceBuilder::new(db, InferenceRegion::Definition(definition), index, &module) - .finish_definition() + TypeInferenceBuilder::new( + db, + InferenceRegion::Definition(definition.definition(db)), + index, + &module, + ) + .finish_definition() } fn definition_cycle_recover<'db>( _db: &'db dyn Db, _value: &DefinitionInference<'db>, _count: u32, - _definition: Definition<'db>, - _divergent: DivergentType<'db>, + _definition: PossiblyRecursiveDefinition<'db>, ) -> salsa::CycleRecoveryAction> { salsa::CycleRecoveryAction::Iterate } fn definition_cycle_initial<'db>( db: &'db dyn Db, - definition: Definition<'db>, - divergent: DivergentType<'db>, + definition: PossiblyRecursiveDefinition<'db>, ) -> DefinitionInference<'db> { - DefinitionInference::cycle_initial(definition.scope(db), divergent) + DefinitionInference::cycle_initial( + definition.definition(db).scope(db), + definition.divergent(db), + ) } /// Infer types for all deferred type expressions in a [`Definition`]. @@ -351,8 +374,8 @@ fn single_expression_cycle_initial<'db>(db: &'db dyn Db, input: InferExpression< /// interning an `ExpressionWithContext` unnecessarily when no type context is provided. #[derive(Debug, Clone, Copy, Eq, Hash, PartialEq, salsa::Supertype, salsa::Update)] pub(super) enum InferExpression<'db> { - WithDivergent(ExpressionWithDivergent<'db>), - WithContextAndDivergent(ExpressionWithContextAndDivergent<'db>), + Bare(PossiblyRecursiveExpression<'db>), + WithContext(PossiblyRecursiveExpressionWithContext<'db>), } impl<'db> InferExpression<'db> { @@ -363,12 +386,12 @@ impl<'db> InferExpression<'db> { divergent: DivergentType<'db>, ) -> InferExpression<'db> { if tcx.annotation.is_some() { - InferExpression::WithContextAndDivergent(ExpressionWithContextAndDivergent::new( + InferExpression::WithContext(PossiblyRecursiveExpressionWithContext::new( db, expression, tcx, divergent, )) } else { // Drop the empty `TypeContext` to avoid the interning cost. - InferExpression::WithDivergent(ExpressionWithDivergent::new(db, expression, divergent)) + InferExpression::Bare(PossiblyRecursiveExpression::new(db, expression, divergent)) } } @@ -378,13 +401,13 @@ impl<'db> InferExpression<'db> { expression: Expression<'db>, divergent: DivergentType<'db>, ) -> InferExpression<'db> { - InferExpression::WithDivergent(ExpressionWithDivergent::new(db, expression, divergent)) + InferExpression::Bare(PossiblyRecursiveExpression::new(db, expression, divergent)) } fn expression(self, db: &'db dyn Db) -> Expression<'db> { match self { - InferExpression::WithDivergent(bare) => bare.expression(db), - InferExpression::WithContextAndDivergent(expression_with_context) => { + InferExpression::Bare(bare) => bare.expression(db), + InferExpression::WithContext(expression_with_context) => { expression_with_context.expression(db) } } @@ -392,8 +415,8 @@ impl<'db> InferExpression<'db> { fn tcx(self, db: &'db dyn Db) -> TypeContext<'db> { match self { - InferExpression::WithDivergent(_) => TypeContext::default(), - InferExpression::WithContextAndDivergent(expression_with_context) => { + InferExpression::Bare(_) => TypeContext::default(), + InferExpression::WithContext(expression_with_context) => { expression_with_context.tcx(db) } } @@ -401,24 +424,24 @@ impl<'db> InferExpression<'db> { fn divergent(self, db: &'db dyn Db) -> DivergentType<'db> { match self { - InferExpression::WithDivergent(bare) => bare.divergent(db), - InferExpression::WithContextAndDivergent(expression_with_context) => { + InferExpression::Bare(bare) => bare.divergent(db), + InferExpression::WithContext(expression_with_context) => { expression_with_context.divergent(db) } } } } -/// An `Expression` with a `DivergentType`. +/// An [`Expression`] that may be recursive. #[salsa::interned(debug, heap_size=ruff_memory_usage::heap_size)] -pub(super) struct ExpressionWithDivergent<'db> { +pub(super) struct PossiblyRecursiveExpression<'db> { expression: Expression<'db>, divergent: DivergentType<'db>, } -/// An `Expression` with a `TypeContext` and a `DivergentType`. +/// An [`Expression`] with a [`TypeContext`], that may be recursive. #[salsa::interned(debug, heap_size=ruff_memory_usage::heap_size)] -pub(super) struct ExpressionWithContextAndDivergent<'db> { +pub(super) struct PossiblyRecursiveExpressionWithContext<'db> { expression: Expression<'db>, tcx: TypeContext<'db>, divergent: DivergentType<'db>, @@ -453,7 +476,7 @@ pub(crate) fn static_expression_truthiness<'db>( ) -> Truthiness { let inference = infer_expression_types_impl( db, - InferExpression::WithDivergent(ExpressionWithDivergent::new( + InferExpression::Bare(PossiblyRecursiveExpression::new( db, expression, DivergentType::todo(db, expression.scope(db)), From 360229625e2c4a77f331788c3b6ebc74c94db42d Mon Sep 17 00:00:00 2001 From: Shunsuke Shibayama Date: Sun, 14 Sep 2025 17:39:40 +0900 Subject: [PATCH 071/105] experiment: `DivergentType::scope: ScopeId -> file: File, file_scope: FileScopeId` --- crates/ty_python_semantic/src/types.rs | 29 +++++++++++++++---- .../src/types/type_ordering.rs | 6 +++- 2 files changed, 29 insertions(+), 6 deletions(-) diff --git a/crates/ty_python_semantic/src/types.rs b/crates/ty_python_semantic/src/types.rs index 7cf97487d654b..af9598c8c219e 100644 --- a/crates/ty_python_semantic/src/types.rs +++ b/crates/ty_python_semantic/src/types.rs @@ -36,7 +36,7 @@ use crate::semantic_index::place::ScopedPlaceId; use crate::semantic_index::scope::ScopeId; use crate::semantic_index::symbol::ScopedImplicitAttributeId; use crate::semantic_index::{ - implicit_attribute_table, imported_modules, place_table, semantic_index, + FileScopeId, implicit_attribute_table, imported_modules, place_table, semantic_index, }; use crate::suppression::check_suppressions; use crate::types::call::{Binding, Bindings, CallArguments, CallableBinding}; @@ -7145,10 +7145,14 @@ impl DivergenceKind { /// For detailed properties of this type, see the unit test at the end of the file. #[salsa::interned(debug, heap_size=ruff_memory_usage::heap_size)] pub struct DivergentType<'db> { + /// The file in which the divergence occurs. + file: File, + /// The scope where this divergence was detected. /// * The function scope in case of function return type inference /// * The class body scope in case of implicit attribute type inference - scope: ScopeId<'db>, + file_scope: FileScopeId, + /// The kind of divergence. kind: DivergenceKind, } @@ -7158,7 +7162,12 @@ impl get_size2::GetSize for DivergentType<'_> {} impl<'db> DivergentType<'db> { pub(crate) fn function_return_type(db: &'db dyn Db, scope: ScopeId<'db>) -> Self { - Self::new(db, scope, DivergenceKind::FunctionReturnType) + Self::new( + db, + scope.file(db), + scope.file_scope_id(db), + DivergenceKind::FunctionReturnType, + ) } pub(crate) fn implicit_attribute( @@ -7166,11 +7175,21 @@ impl<'db> DivergentType<'db> { scope: ScopeId<'db>, attribute: ScopedImplicitAttributeId, ) -> Self { - Self::new(db, scope, DivergenceKind::ImplicitAttribute(attribute)) + Self::new( + db, + scope.file(db), + scope.file_scope_id(db), + DivergenceKind::ImplicitAttribute(attribute), + ) } pub(crate) fn todo(db: &'db dyn Db, scope: ScopeId<'db>) -> Self { - Self::new(db, scope, DivergenceKind::Todo) + Self::new( + db, + scope.file(db), + scope.file_scope_id(db), + DivergenceKind::Todo, + ) } } diff --git a/crates/ty_python_semantic/src/types/type_ordering.rs b/crates/ty_python_semantic/src/types/type_ordering.rs index 790700748e2b7..981307eeb7b2f 100644 --- a/crates/ty_python_semantic/src/types/type_ordering.rs +++ b/crates/ty_python_semantic/src/types/type_ordering.rs @@ -281,7 +281,11 @@ fn dynamic_elements_ordering<'db>( (_, DynamicType::TodoTypeAlias) => Ordering::Greater, (DynamicType::Divergent(left), DynamicType::Divergent(right)) => { - left.scope(db).cmp(&right.scope(db)) + left.file(db).cmp(&right.file(db)).then_with(|| { + left.file_scope(db) + .index() + .cmp(&right.file_scope(db).index()) + }) } (DynamicType::Divergent(_), _) => Ordering::Less, (_, DynamicType::Divergent(_)) => Ordering::Greater, From 48aea8a92af6d058ffef63cae412df7061961a28 Mon Sep 17 00:00:00 2001 From: Shunsuke Shibayama Date: Sun, 14 Sep 2025 17:54:15 +0900 Subject: [PATCH 072/105] Revert "experiment: `DivergentType::scope: ScopeId -> file: File, file_scope: FileScopeId`" This reverts commit 360229625e2c4a77f331788c3b6ebc74c94db42d. --- crates/ty_python_semantic/src/types.rs | 29 ++++--------------- .../src/types/type_ordering.rs | 6 +--- 2 files changed, 6 insertions(+), 29 deletions(-) diff --git a/crates/ty_python_semantic/src/types.rs b/crates/ty_python_semantic/src/types.rs index af9598c8c219e..7cf97487d654b 100644 --- a/crates/ty_python_semantic/src/types.rs +++ b/crates/ty_python_semantic/src/types.rs @@ -36,7 +36,7 @@ use crate::semantic_index::place::ScopedPlaceId; use crate::semantic_index::scope::ScopeId; use crate::semantic_index::symbol::ScopedImplicitAttributeId; use crate::semantic_index::{ - FileScopeId, implicit_attribute_table, imported_modules, place_table, semantic_index, + implicit_attribute_table, imported_modules, place_table, semantic_index, }; use crate::suppression::check_suppressions; use crate::types::call::{Binding, Bindings, CallArguments, CallableBinding}; @@ -7145,14 +7145,10 @@ impl DivergenceKind { /// For detailed properties of this type, see the unit test at the end of the file. #[salsa::interned(debug, heap_size=ruff_memory_usage::heap_size)] pub struct DivergentType<'db> { - /// The file in which the divergence occurs. - file: File, - /// The scope where this divergence was detected. /// * The function scope in case of function return type inference /// * The class body scope in case of implicit attribute type inference - file_scope: FileScopeId, - + scope: ScopeId<'db>, /// The kind of divergence. kind: DivergenceKind, } @@ -7162,12 +7158,7 @@ impl get_size2::GetSize for DivergentType<'_> {} impl<'db> DivergentType<'db> { pub(crate) fn function_return_type(db: &'db dyn Db, scope: ScopeId<'db>) -> Self { - Self::new( - db, - scope.file(db), - scope.file_scope_id(db), - DivergenceKind::FunctionReturnType, - ) + Self::new(db, scope, DivergenceKind::FunctionReturnType) } pub(crate) fn implicit_attribute( @@ -7175,21 +7166,11 @@ impl<'db> DivergentType<'db> { scope: ScopeId<'db>, attribute: ScopedImplicitAttributeId, ) -> Self { - Self::new( - db, - scope.file(db), - scope.file_scope_id(db), - DivergenceKind::ImplicitAttribute(attribute), - ) + Self::new(db, scope, DivergenceKind::ImplicitAttribute(attribute)) } pub(crate) fn todo(db: &'db dyn Db, scope: ScopeId<'db>) -> Self { - Self::new( - db, - scope.file(db), - scope.file_scope_id(db), - DivergenceKind::Todo, - ) + Self::new(db, scope, DivergenceKind::Todo) } } diff --git a/crates/ty_python_semantic/src/types/type_ordering.rs b/crates/ty_python_semantic/src/types/type_ordering.rs index 981307eeb7b2f..790700748e2b7 100644 --- a/crates/ty_python_semantic/src/types/type_ordering.rs +++ b/crates/ty_python_semantic/src/types/type_ordering.rs @@ -281,11 +281,7 @@ fn dynamic_elements_ordering<'db>( (_, DynamicType::TodoTypeAlias) => Ordering::Greater, (DynamicType::Divergent(left), DynamicType::Divergent(right)) => { - left.file(db).cmp(&right.file(db)).then_with(|| { - left.file_scope(db) - .index() - .cmp(&right.file_scope(db).index()) - }) + left.scope(db).cmp(&right.scope(db)) } (DynamicType::Divergent(_), _) => Ordering::Less, (_, DynamicType::Divergent(_)) => Ordering::Greater, From b7fa61bd4775c598092fe8399ff84b3f1c7662a2 Mon Sep 17 00:00:00 2001 From: Shunsuke Shibayama Date: Mon, 15 Sep 2025 15:09:24 +0900 Subject: [PATCH 073/105] proper cycle recovery handling for `infer_expression_types` --- .../resources/corpus/divergent.py | 11 ++ crates/ty_python_semantic/src/dunder_all.rs | 14 +- crates/ty_python_semantic/src/types.rs | 13 +- crates/ty_python_semantic/src/types/class.rs | 8 +- .../ty_python_semantic/src/types/context.rs | 4 + crates/ty_python_semantic/src/types/infer.rs | 109 ++++++++------ .../src/types/infer/builder.rs | 134 +++++++++++++++++- .../src/types/infer/tests.rs | 69 ++++++--- crates/ty_python_semantic/src/types/narrow.rs | 40 ++++-- .../ty_python_semantic/src/types/unpacker.rs | 14 +- 10 files changed, 319 insertions(+), 97 deletions(-) diff --git a/crates/ty_python_semantic/resources/corpus/divergent.py b/crates/ty_python_semantic/resources/corpus/divergent.py index 228060149a686..bc3c0e1fff66d 100644 --- a/crates/ty_python_semantic/resources/corpus/divergent.py +++ b/crates/ty_python_semantic/resources/corpus/divergent.py @@ -56,3 +56,14 @@ def unwrap(value): return tuple(result) else: raise TypeError() + +def descent(x: int, y: int): + if x > y: + y, x = descent(y, x) + return x, y + if x == 1: + return (1, 0) + if y == 1: + return (0, 1) + else: + return descent(x-1, y-1) diff --git a/crates/ty_python_semantic/src/dunder_all.rs b/crates/ty_python_semantic/src/dunder_all.rs index 10eab9321a919..20e73260d49fd 100644 --- a/crates/ty_python_semantic/src/dunder_all.rs +++ b/crates/ty_python_semantic/src/dunder_all.rs @@ -7,7 +7,7 @@ use ruff_python_ast::statement_visitor::{StatementVisitor, walk_stmt}; use ruff_python_ast::{self as ast}; use crate::semantic_index::{SemanticIndex, semantic_index}; -use crate::types::{Truthiness, Type, TypeContext, infer_expression_types}; +use crate::types::{DivergentType, Truthiness, Type, TypeContext, infer_expression_types}; use crate::{Db, ModuleName, resolve_module}; #[allow(clippy::ref_option)] @@ -182,8 +182,16 @@ impl<'db> DunderAllNamesCollector<'db> { /// /// This function panics if `expr` was not marked as a standalone expression during semantic indexing. fn standalone_expression_type(&self, expr: &ast::Expr) -> Type<'db> { - infer_expression_types(self.db, self.index.expression(expr), TypeContext::default()) - .expression_type(expr) + // QUESTION: Could there be a case for divergence here? + let cycle_recovery = + DivergentType::should_not_diverge(self.db, self.index.expression(expr).scope(self.db)); + infer_expression_types( + self.db, + self.index.expression(expr), + TypeContext::default(), + cycle_recovery, + ) + .expression_type(expr) } /// Evaluate the given expression and return its truthiness. diff --git a/crates/ty_python_semantic/src/types.rs b/crates/ty_python_semantic/src/types.rs index 7cf97487d654b..21cffd6a942f5 100644 --- a/crates/ty_python_semantic/src/types.rs +++ b/crates/ty_python_semantic/src/types.rs @@ -7128,13 +7128,14 @@ pub enum DivergenceKind { FunctionReturnType, /// Divergence in implicit attribute type inference. ImplicitAttribute(ScopedImplicitAttributeId), - /// Unknown divergence that we have not yet handled. - Todo, + /// Type inference should not diverge, + /// or it should be of a kind that we have not yet handled. + ShouldNotDiverge, } impl DivergenceKind { - pub fn is_todo(self) -> bool { - matches!(self, Self::Todo) + pub fn should_not_diverge(self) -> bool { + matches!(self, Self::ShouldNotDiverge) } } @@ -7169,8 +7170,8 @@ impl<'db> DivergentType<'db> { Self::new(db, scope, DivergenceKind::ImplicitAttribute(attribute)) } - pub(crate) fn todo(db: &'db dyn Db, scope: ScopeId<'db>) -> Self { - Self::new(db, scope, DivergenceKind::Todo) + pub(crate) fn should_not_diverge(db: &'db dyn Db, scope: ScopeId<'db>) -> Self { + Self::new(db, scope, DivergenceKind::ShouldNotDiverge) } } diff --git a/crates/ty_python_semantic/src/types/class.rs b/crates/ty_python_semantic/src/types/class.rs index 94c8cc7eb6b5a..25193e77a31b3 100644 --- a/crates/ty_python_semantic/src/types/class.rs +++ b/crates/ty_python_semantic/src/types/class.rs @@ -119,7 +119,11 @@ fn implicit_attribute_initial<'db>( ) -> PlaceAndQualifiers<'db> { let implicit_attribute_table = implicit_attribute_table(db, class_body_scope); let Some(implicit_attr_id) = implicit_attribute_table.symbol_id(&name) else { - return Place::bound(Type::divergent(DivergentType::todo(db, class_body_scope))).into(); + return Place::bound(Type::divergent(DivergentType::should_not_diverge( + db, + class_body_scope, + ))) + .into(); }; Place::bound(Type::divergent(DivergentType::implicit_attribute( db, @@ -2910,7 +2914,7 @@ impl<'db> ClassLiteral<'db> { DivergentType::implicit_attribute(db, class_body_scope, implicit_attr) } else { // Implicit attributes should either not exist or have a type declaration, and should not diverge. - DivergentType::todo(db, class_body_scope) + DivergentType::should_not_diverge(db, class_body_scope) }; let visitor = NormalizedVisitor::default().recursive(Type::divergent(div)); diff --git a/crates/ty_python_semantic/src/types/context.rs b/crates/ty_python_semantic/src/types/context.rs index d13d99a75ed76..41644b3860f44 100644 --- a/crates/ty_python_semantic/src/types/context.rs +++ b/crates/ty_python_semantic/src/types/context.rs @@ -98,6 +98,10 @@ impl<'db, 'ast> InferContext<'db, 'ast> { self.diagnostics.get_mut().extend(other); } + pub(crate) fn take_diagnostics(&mut self) -> TypeCheckDiagnostics { + self.diagnostics.take() + } + /// Optionally return a builder for a lint diagnostic guard. /// /// If the current context believes a diagnostic should be reported for diff --git a/crates/ty_python_semantic/src/types/infer.rs b/crates/ty_python_semantic/src/types/infer.rs index 020d40cd1de52..0e513c2069b2a 100644 --- a/crates/ty_python_semantic/src/types/infer.rs +++ b/crates/ty_python_semantic/src/types/infer.rs @@ -65,7 +65,7 @@ mod tests; #[salsa::interned(debug, heap_size=ruff_memory_usage::heap_size)] pub(super) struct PossiblyRecursiveScope<'db> { scope: ScopeId<'db>, - divergent: DivergentType<'db>, + cycle_recovery: DivergentType<'db>, } /// Infer all types for a [`ScopeId`], including all definitions and expressions in that scope. @@ -77,16 +77,16 @@ pub(crate) fn infer_scope_types<'db>( ) -> &'db ScopeInference<'db> { infer_scope_types_impl( db, - PossiblyRecursiveScope::new(db, scope, DivergentType::todo(db, scope)), + PossiblyRecursiveScope::new(db, scope, DivergentType::should_not_diverge(db, scope)), ) } pub(crate) fn infer_function_scope_types<'db>( db: &'db dyn Db, scope: ScopeId<'db>, - divergent: DivergentType<'db>, + cycle_recovery: DivergentType<'db>, ) -> &'db ScopeInference<'db> { - infer_scope_types_impl(db, PossiblyRecursiveScope::new(db, scope, divergent)) + infer_scope_types_impl(db, PossiblyRecursiveScope::new(db, scope, cycle_recovery)) } #[salsa::tracked(returns(ref), cycle_fn=scope_cycle_recover, cycle_initial=scope_cycle_initial, heap_size=ruff_memory_usage::heap_size)] @@ -120,14 +120,14 @@ fn scope_cycle_initial<'db>( db: &'db dyn Db, scope: PossiblyRecursiveScope<'db>, ) -> ScopeInference<'db> { - ScopeInference::cycle_initial(scope.divergent(db), scope.scope(db)) + ScopeInference::cycle_initial(scope.cycle_recovery(db), scope.scope(db)) } /// A [`Definition`] that may be recursive. #[salsa::interned(debug, heap_size=ruff_memory_usage::heap_size)] pub(super) struct PossiblyRecursiveDefinition<'db> { definition: Definition<'db>, - divergent: DivergentType<'db>, + cycle_recovery: DivergentType<'db>, } /// Infer all types for a [`Definition`] (including sub-expressions). @@ -141,7 +141,7 @@ pub(crate) fn infer_definition_types<'db>( PossiblyRecursiveDefinition::new( db, definition, - DivergentType::todo(db, definition.scope(db)), + DivergentType::should_not_diverge(db, definition.scope(db)), ), ) } @@ -186,7 +186,7 @@ fn definition_cycle_initial<'db>( ) -> DefinitionInference<'db> { DefinitionInference::cycle_initial( definition.definition(db).scope(db), - definition.divergent(db), + definition.cycle_recovery(db), ) } @@ -230,7 +230,7 @@ fn deferred_cycle_initial<'db>( ) -> DefinitionInference<'db> { DefinitionInference::cycle_initial( definition.scope(db), - DivergentType::todo(db, definition.scope(db)), + DivergentType::should_not_diverge(db, definition.scope(db)), ) } @@ -242,15 +242,11 @@ pub(crate) fn infer_expression_types<'db>( db: &'db dyn Db, expression: Expression<'db>, tcx: TypeContext<'db>, + cycle_recovery: DivergentType<'db>, ) -> &'db ExpressionInference<'db> { infer_expression_types_impl( db, - InferExpression::new( - db, - expression, - tcx, - DivergentType::todo(db, expression.scope(db)), - ), + InferExpression::new(db, expression, tcx, cycle_recovery), ) } @@ -295,7 +291,7 @@ fn expression_cycle_initial<'db>( db: &'db dyn Db, input: InferExpression<'db>, ) -> ExpressionInference<'db> { - ExpressionInference::cycle_initial(input.expression(db).scope(db), input.divergent(db)) + ExpressionInference::cycle_initial(input.expression(db).scope(db), input.cycle_recovery(db)) } /// Infers the type of an `expression` that is guaranteed to be in the same file as the calling query. @@ -307,9 +303,10 @@ pub(super) fn infer_same_file_expression_type<'db>( db: &'db dyn Db, expression: Expression<'db>, tcx: TypeContext<'db>, + cycle_recovery: DivergentType<'db>, parsed: &ParsedModuleRef, ) -> Type<'db> { - let inference = infer_expression_types(db, expression, tcx); + let inference = infer_expression_types(db, expression, tcx, cycle_recovery); inference.expression_type(expression.node_ref(db, parsed)) } @@ -331,7 +328,7 @@ pub(crate) fn infer_expression_type<'db>( db, expression, tcx, - DivergentType::todo(db, expression.scope(db)), + DivergentType::should_not_diverge(db, expression.scope(db)), ), ) } @@ -340,9 +337,12 @@ pub(crate) fn infer_implicit_attribute_expression_type<'db>( db: &'db dyn Db, expression: Expression<'db>, tcx: TypeContext<'db>, - divergent: DivergentType<'db>, + cycle_recovery: DivergentType<'db>, ) -> Type<'db> { - infer_expression_type_impl(db, InferExpression::new(db, expression, tcx, divergent)) + infer_expression_type_impl( + db, + InferExpression::new(db, expression, tcx, cycle_recovery), + ) } #[salsa::tracked(cycle_fn=single_expression_cycle_recover, cycle_initial=single_expression_cycle_initial, heap_size=ruff_memory_usage::heap_size)] @@ -365,7 +365,7 @@ fn single_expression_cycle_recover<'db>( } fn single_expression_cycle_initial<'db>(db: &'db dyn Db, input: InferExpression<'db>) -> Type<'db> { - Type::Dynamic(DynamicType::Divergent(input.divergent(db))) + Type::Dynamic(DynamicType::Divergent(input.cycle_recovery(db))) } /// An `Expression` with an optional `TypeContext`. @@ -383,25 +383,36 @@ impl<'db> InferExpression<'db> { db: &'db dyn Db, expression: Expression<'db>, tcx: TypeContext<'db>, - divergent: DivergentType<'db>, + cycle_recovery: DivergentType<'db>, ) -> InferExpression<'db> { if tcx.annotation.is_some() { InferExpression::WithContext(PossiblyRecursiveExpressionWithContext::new( - db, expression, tcx, divergent, + db, + expression, + tcx, + cycle_recovery, )) } else { // Drop the empty `TypeContext` to avoid the interning cost. - InferExpression::Bare(PossiblyRecursiveExpression::new(db, expression, divergent)) + InferExpression::Bare(PossiblyRecursiveExpression::new( + db, + expression, + cycle_recovery, + )) } } #[cfg(test)] - pub(super) fn with_divergent( + pub(super) fn bare( db: &'db dyn Db, expression: Expression<'db>, - divergent: DivergentType<'db>, + cycle_recovery: DivergentType<'db>, ) -> InferExpression<'db> { - InferExpression::Bare(PossiblyRecursiveExpression::new(db, expression, divergent)) + InferExpression::Bare(PossiblyRecursiveExpression::new( + db, + expression, + cycle_recovery, + )) } fn expression(self, db: &'db dyn Db) -> Expression<'db> { @@ -422,11 +433,11 @@ impl<'db> InferExpression<'db> { } } - fn divergent(self, db: &'db dyn Db) -> DivergentType<'db> { + fn cycle_recovery(self, db: &'db dyn Db) -> DivergentType<'db> { match self { - InferExpression::Bare(bare) => bare.divergent(db), + InferExpression::Bare(bare) => bare.cycle_recovery(db), InferExpression::WithContext(expression_with_context) => { - expression_with_context.divergent(db) + expression_with_context.cycle_recovery(db) } } } @@ -436,7 +447,7 @@ impl<'db> InferExpression<'db> { #[salsa::interned(debug, heap_size=ruff_memory_usage::heap_size)] pub(super) struct PossiblyRecursiveExpression<'db> { expression: Expression<'db>, - divergent: DivergentType<'db>, + cycle_recovery: DivergentType<'db>, } /// An [`Expression`] with a [`TypeContext`], that may be recursive. @@ -444,7 +455,7 @@ pub(super) struct PossiblyRecursiveExpression<'db> { pub(super) struct PossiblyRecursiveExpressionWithContext<'db> { expression: Expression<'db>, tcx: TypeContext<'db>, - divergent: DivergentType<'db>, + cycle_recovery: DivergentType<'db>, } /// The type context for a given expression, namely the type annotation @@ -479,7 +490,7 @@ pub(crate) fn static_expression_truthiness<'db>( InferExpression::Bare(PossiblyRecursiveExpression::new( db, expression, - DivergentType::todo(db, expression.scope(db)), + DivergentType::should_not_diverge(db, expression.scope(db)), )), ); @@ -514,16 +525,20 @@ fn static_expression_truthiness_cycle_initial<'db>( pub(super) fn infer_unpack_implicit_attribute_types<'db>( db: &'db dyn Db, unpack: Unpack<'db>, - divergent: DivergentType<'db>, + cycle_recovery: DivergentType<'db>, ) -> &'db UnpackResult<'db> { - infer_unpack_types_impl(db, unpack, divergent) + infer_unpack_types_impl(db, unpack, cycle_recovery) } pub(super) fn infer_unpack_types<'db>( db: &'db dyn Db, unpack: Unpack<'db>, ) -> &'db UnpackResult<'db> { - infer_unpack_types_impl(db, unpack, DivergentType::todo(db, unpack.target_scope(db))) + infer_unpack_types_impl( + db, + unpack, + DivergentType::should_not_diverge(db, unpack.target_scope(db)), + ) } /// Infer the types for an [`Unpack`] operation. @@ -536,7 +551,7 @@ pub(super) fn infer_unpack_types<'db>( fn infer_unpack_types_impl<'db>( db: &'db dyn Db, unpack: Unpack<'db>, - divergent: DivergentType<'db>, + cycle_recovery: DivergentType<'db>, ) -> UnpackResult<'db> { let file = unpack.file(db); let module = parsed_module(db, file).load(db); @@ -544,7 +559,7 @@ fn infer_unpack_types_impl<'db>( .entered(); let mut unpacker = Unpacker::new(db, unpack.target_scope(db), &module); - unpacker.unpack(unpack.target(db, &module), unpack.value(db), divergent); + unpacker.unpack(unpack.target(db, &module), unpack.value(db), cycle_recovery); unpacker.finish() } @@ -553,7 +568,7 @@ fn unpack_cycle_recover<'db>( _value: &UnpackResult<'db>, _count: u32, _unpack: Unpack<'db>, - _divergent: DivergentType<'db>, + _cycle_recovery: DivergentType<'db>, ) -> salsa::CycleRecoveryAction> { salsa::CycleRecoveryAction::Iterate } @@ -561,9 +576,9 @@ fn unpack_cycle_recover<'db>( fn unpack_cycle_initial<'db>( _db: &'db dyn Db, _unpack: Unpack<'db>, - divergent: DivergentType<'db>, + cycle_recovery: DivergentType<'db>, ) -> UnpackResult<'db> { - UnpackResult::cycle_initial(Type::divergent(divergent)) + UnpackResult::cycle_initial(Type::divergent(cycle_recovery)) } /// Returns the type of the nearest enclosing class for the given scope. @@ -644,10 +659,10 @@ struct ScopeInferenceExtra<'db> { } impl<'db> ScopeInference<'db> { - fn cycle_initial(divergent: DivergentType<'db>, scope: ScopeId<'db>) -> Self { + fn cycle_initial(cycle_recovery: DivergentType<'db>, scope: ScopeId<'db>) -> Self { Self { extra: Some(Box::new(ScopeInferenceExtra { - cycle_recovery: Some(divergent), + cycle_recovery: Some(cycle_recovery), ..ScopeInferenceExtra::default() })), expressions: FxHashMap::default(), @@ -795,7 +810,7 @@ struct DefinitionInferenceExtra<'db> { } impl<'db> DefinitionInference<'db> { - fn cycle_initial(scope: ScopeId<'db>, divergent: DivergentType<'db>) -> Self { + fn cycle_initial(scope: ScopeId<'db>, cycle_recovery: DivergentType<'db>) -> Self { let _ = scope; Self { @@ -805,7 +820,7 @@ impl<'db> DefinitionInference<'db> { #[cfg(debug_assertions)] scope, extra: Some(Box::new(DefinitionInferenceExtra { - cycle_recovery: Some(divergent), + cycle_recovery: Some(cycle_recovery), ..DefinitionInferenceExtra::default() })), } @@ -917,11 +932,11 @@ struct ExpressionInferenceExtra<'db> { } impl<'db> ExpressionInference<'db> { - fn cycle_initial(scope: ScopeId<'db>, divergent: DivergentType<'db>) -> Self { + fn cycle_initial(scope: ScopeId<'db>, cycle_recovery: DivergentType<'db>) -> Self { let _ = scope; Self { extra: Some(Box::new(ExpressionInferenceExtra { - cycle_recovery: Some(divergent), + cycle_recovery: Some(cycle_recovery), all_definitely_bound: true, ..ExpressionInferenceExtra::default() })), diff --git a/crates/ty_python_semantic/src/types/infer/builder.rs b/crates/ty_python_semantic/src/types/infer/builder.rs index 6254a2b0f8b65..62235db235792 100644 --- a/crates/ty_python_semantic/src/types/infer/builder.rs +++ b/crates/ty_python_semantic/src/types/infer/builder.rs @@ -40,7 +40,8 @@ use crate::semantic_index::scope::{ }; use crate::semantic_index::symbol::{ScopedSymbolId, Symbol}; use crate::semantic_index::{ - ApplicableConstraints, EnclosingSnapshotResult, SemanticIndex, place_table, + ApplicableConstraints, EnclosingSnapshotResult, SemanticIndex, implicit_attribute_table, + place_table, }; use crate::types::call::{Binding, Bindings, CallArguments, CallError, CallErrorKind}; use crate::types::class::{CodeGeneratorKind, FieldKind, MetaclassErrorKind, MethodDecorator}; @@ -89,7 +90,7 @@ use crate::types::{ DynamicType, IntersectionBuilder, IntersectionType, KnownClass, KnownInstanceType, MemberLookupPolicy, MetaclassCandidate, PEP695TypeAliasType, Parameter, ParameterForm, Parameters, SpecialFormType, SubclassOfType, TrackedConstraintSet, Truthiness, Type, - TypeAliasType, TypeAndQualifiers, TypeContext, TypeQualifiers, + TypeAliasType, TypeAndQualifiers, TypeCheckDiagnostics, TypeContext, TypeQualifiers, TypeVarBoundOrConstraintsEvaluation, TypeVarDefaultEvaluation, TypeVarInstance, TypeVarKind, UnionBuilder, UnionType, binding_type, todo_type, }; @@ -260,6 +261,18 @@ pub(super) struct TypeInferenceBuilder<'db, 'ast> { all_definitely_bound: bool, } +/// Normally, double checking is not allowed in a [`TypeInferenceBuilder`], +/// but it is sometimes necessary for expression inference to detect divergence. +/// Explicit double checking can be achieved by taking a snapshot of the state before the check +/// and then reverting the state using the snapshot after the check. +struct TypeInferenceSnapshot { + diagnostics: TypeCheckDiagnostics, + expression_keys: FxHashSet, + length_of_bindings: usize, + length_of_declarations: usize, + length_of_deferred: usize, +} + impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> { /// How big a string do we build before bailing? /// @@ -299,6 +312,84 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> { self.cycle_recovery.map(Type::divergent) } + fn snapshot(&mut self) -> TypeInferenceSnapshot { + TypeInferenceSnapshot { + diagnostics: self.context.take_diagnostics(), + expression_keys: self.expressions.keys().copied().collect(), + length_of_bindings: self.bindings.len(), + length_of_declarations: self.declarations.len(), + length_of_deferred: self.deferred.len(), + } + } + + fn restore(&mut self, snapshot: &TypeInferenceSnapshot) { + self.context.take_diagnostics(); + self.context.extend(&snapshot.diagnostics); + self.expressions + .retain(|k, _| snapshot.expression_keys.contains(k)); + self.bindings.truncate(snapshot.length_of_bindings); + self.declarations.truncate(snapshot.length_of_declarations); + self.deferred.truncate(snapshot.length_of_deferred); + } + + /// If the inference in this expression diverges, what kind of divergence is possible? + fn expression_cycle_recovery(&mut self, expression: &ast::Expr) -> DivergentType<'db> { + let db = self.db(); + match expression { + ast::Expr::Call(call) => { + let snapshot = self.snapshot(); + let callable_ty = + self.try_expression_type(call.func.as_ref()) + .unwrap_or_else(|| { + self.infer_maybe_standalone_expression( + &call.func, + TypeContext::default(), + ) + }); + self.restore(&snapshot); + match callable_ty { + Type::FunctionLiteral(func) => DivergentType::function_return_type( + db, + func.literal(db).last_definition(db).body_scope(db), + ), + Type::BoundMethod(method) => DivergentType::function_return_type( + db, + method + .function(db) + .literal(db) + .last_definition(db) + .body_scope(db), + ), + _ => DivergentType::should_not_diverge(db, self.scope()), + } + } + ast::Expr::Attribute(attr) => { + let snapshot = self.snapshot(); + let value_ty = self + .try_expression_type(attr.value.as_ref()) + .unwrap_or_else(|| { + self.infer_maybe_standalone_expression(&attr.value, TypeContext::default()) + }); + self.restore(&snapshot); + let body_scope = match value_ty { + Type::NominalInstance(instance) => instance.class_literal(db).body_scope(db), + Type::ClassLiteral(class) => class.body_scope(db), + Type::GenericAlias(generic) => generic.origin(db).body_scope(db), + _ => { + return DivergentType::should_not_diverge(db, self.scope()); + } + }; + let implicit_attribute_table = implicit_attribute_table(db, body_scope); + if let Some(attribute) = implicit_attribute_table.symbol_id(&attr.attr) { + DivergentType::implicit_attribute(db, body_scope, attribute) + } else { + DivergentType::should_not_diverge(db, self.scope()) + } + } + _ => DivergentType::should_not_diverge(db, self.scope()), + } + } + fn merge_cycle_recovery(&mut self, other: Option>) { match (self.cycle_recovery, other) { (None, _) | (Some(_), None) => { @@ -307,7 +398,7 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> { (Some(self_), Some(other)) => { if self_ == other { // OK, do nothing - } else if self_.kind(self.db()).is_todo() { + } else if self_.kind(self.db()).should_not_diverge() { self.cycle_recovery = Some(other); } else { panic!("Cannot merge divergent types"); @@ -5028,7 +5119,10 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> { standalone_expression: Expression<'db>, tcx: TypeContext<'db>, ) -> Type<'db> { - let types = infer_expression_types(self.db(), standalone_expression, tcx); + let cycle_recovery = self + .cycle_recovery + .unwrap_or_else(|| self.expression_cycle_recovery(expression)); + let types = infer_expression_types(self.db(), standalone_expression, tcx, cycle_recovery); self.extend_expression(types); // Instead of calling `self.expression_type(expr)` after extending here, we get @@ -5471,10 +5565,19 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> { // but only if the target is a name. We should report a diagnostic here if the target isn't a name: // `[... for a.x in not_iterable] if is_first { + // QUESTION: Could there be a case for divergence here? + let cycle_recovery = + builder + .cycle_recovery + .unwrap_or(DivergentType::should_not_diverge( + builder.db(), + builder.scope(), + )); infer_same_file_expression_type( builder.db(), builder.index.expression(iter_expr), TypeContext::default(), + cycle_recovery, builder.module(), ) } else { @@ -5498,7 +5601,16 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> { let mut infer_iterable_type = || { let expression = self.index.expression(iterable); - let result = infer_expression_types(self.db(), expression, TypeContext::default()); + // QUESTION: Could there be a case for divergence here? + let cycle_recovery = self + .cycle_recovery + .unwrap_or(DivergentType::should_not_diverge(self.db(), self.scope())); + let result = infer_expression_types( + self.db(), + expression, + TypeContext::default(), + cycle_recovery, + ); // Two things are different if it's the first comprehension: // (1) We must lookup the `ScopedExpressionId` of the iterable expression in the outer scope, @@ -9330,6 +9442,10 @@ where fn into_boxed_slice(self) -> Box<[(K, V)]> { self.0.into_boxed_slice() } + + fn truncate(&mut self, len: usize) { + self.0.truncate(len); + } } impl Extend<(K, V)> for VecMap @@ -9373,6 +9489,14 @@ impl VecSet { fn into_boxed_slice(self) -> Box<[V]> { self.0.into_boxed_slice() } + + fn truncate(&mut self, len: usize) { + self.0.truncate(len); + } + + fn len(&self) -> usize { + self.0.len() + } } impl VecSet diff --git a/crates/ty_python_semantic/src/types/infer/tests.rs b/crates/ty_python_semantic/src/types/infer/tests.rs index 8d235264d9b15..e2a359b67f850 100644 --- a/crates/ty_python_semantic/src/types/infer/tests.rs +++ b/crates/ty_python_semantic/src/types/infer/tests.rs @@ -3,8 +3,10 @@ use crate::db::tests::{TestDb, setup_db}; use crate::place::symbol; use crate::place::{ConsideredDefinitions, Place, global_symbol}; use crate::semantic_index::definition::Definition; -use crate::semantic_index::scope::FileScopeId; -use crate::semantic_index::{global_scope, place_table, semantic_index, use_def_map}; +use crate::semantic_index::scope::{FileScopeId, ScopeKind}; +use crate::semantic_index::{ + global_scope, implicit_attribute_table, place_table, semantic_index, use_def_map, +}; use crate::types::function::FunctionType; use crate::types::{BoundMethodType, KnownClass, KnownInstanceType, UnionType, check_types}; use ruff_db::diagnostic::Diagnostic; @@ -57,6 +59,24 @@ fn get_symbol<'db>( symbol(db, scope, symbol_name, ConsideredDefinitions::EndOfScope).place } +#[track_caller] +fn get_scope<'db>( + db: &'db TestDb, + file: File, + name: &str, + kind: ScopeKind, +) -> Option> { + let module = parsed_module(db, file).load(db); + let index = semantic_index(db, file); + for (child_scope, _) in index.child_scopes(FileScopeId::global()) { + let scope = child_scope.to_scope_id(db, file); + if scope.name(db, &module) == name && scope.scope(db).kind() == kind { + return Some(scope); + } + } + None +} + #[track_caller] fn assert_diagnostic_messages(diagnostics: &[Diagnostic], expected: &[&str]) { let messages: Vec<&str> = diagnostics @@ -467,11 +487,15 @@ fn dependency_implicit_instance_attribute() -> anyhow::Result<()> { assert_eq!(attr_ty.display(&db).to_string(), "Unknown | str | None"); db.take_salsa_events() }; - let div = DivergentType::todo(&db, global_scope(&db, file_main)); + let file_mod = system_path_to_file(&db, "/src/mod.py").unwrap(); + let class_body_scope = get_scope(&db, file_mod, "C", ScopeKind::Class).unwrap(); + let implicit_attribute_table = implicit_attribute_table(&db, class_body_scope); + let attribute = implicit_attribute_table.symbol_id("attr").unwrap(); + let cycle_recovery = DivergentType::implicit_attribute(&db, class_body_scope, attribute); assert_function_query_was_run( &db, infer_expression_types_impl, - InferExpression::with_divergent(&db, x_rhs_expression(&db), div), + InferExpression::bare(&db, x_rhs_expression(&db), cycle_recovery), &events, ); @@ -492,11 +516,12 @@ fn dependency_implicit_instance_attribute() -> anyhow::Result<()> { assert_eq!(attr_ty.display(&db).to_string(), "Unknown | str | None"); db.take_salsa_events() }; - let div = DivergentType::todo(&db, global_scope(&db, file_main)); + let class_body_scope = get_scope(&db, file_mod, "C", ScopeKind::Class).unwrap(); + let cycle_recovery = DivergentType::implicit_attribute(&db, class_body_scope, attribute); assert_function_query_was_not_run( &db, infer_expression_types_impl, - InferExpression::with_divergent(&db, x_rhs_expression(&db), div), + InferExpression::bare(&db, x_rhs_expression(&db), cycle_recovery), &events, ); @@ -560,11 +585,11 @@ fn dependency_own_instance_member() -> anyhow::Result<()> { assert_eq!(attr_ty.display(&db).to_string(), "Unknown | str | None"); db.take_salsa_events() }; - let div = DivergentType::todo(&db, global_scope(&db, file_main)); + let cycle_recovery = DivergentType::should_not_diverge(&db, global_scope(&db, file_main)); assert_function_query_was_run( &db, infer_expression_types_impl, - InferExpression::with_divergent(&db, x_rhs_expression(&db), div), + InferExpression::bare(&db, x_rhs_expression(&db), cycle_recovery), &events, ); @@ -587,11 +612,11 @@ fn dependency_own_instance_member() -> anyhow::Result<()> { assert_eq!(attr_ty.display(&db).to_string(), "Unknown | str | None"); db.take_salsa_events() }; - let div = DivergentType::todo(&db, global_scope(&db, file_main)); + let cycle_recovery = DivergentType::should_not_diverge(&db, global_scope(&db, file_main)); assert_function_query_was_not_run( &db, infer_expression_types_impl, - InferExpression::with_divergent(&db, x_rhs_expression(&db), div), + InferExpression::bare(&db, x_rhs_expression(&db), cycle_recovery), &events, ); @@ -658,11 +683,15 @@ fn dependency_implicit_class_member() -> anyhow::Result<()> { assert_eq!(attr_ty.display(&db).to_string(), "Unknown | str"); db.take_salsa_events() }; - let div = DivergentType::todo(&db, global_scope(&db, file_main)); + let file_mod = system_path_to_file(&db, "/src/mod.py").unwrap(); + let class_body_scope = get_scope(&db, file_mod, "C", ScopeKind::Class).unwrap(); + let implicit_attribute_table = implicit_attribute_table(&db, class_body_scope); + let attribute = implicit_attribute_table.symbol_id("class_attr").unwrap(); + let cycle_recovery = DivergentType::implicit_attribute(&db, class_body_scope, attribute); assert_function_query_was_run( &db, infer_expression_types_impl, - InferExpression::with_divergent(&db, x_rhs_expression(&db), div), + InferExpression::bare(&db, x_rhs_expression(&db), cycle_recovery), &events, ); @@ -687,11 +716,12 @@ fn dependency_implicit_class_member() -> anyhow::Result<()> { assert_eq!(attr_ty.display(&db).to_string(), "Unknown | str"); db.take_salsa_events() }; - let div = DivergentType::todo(&db, global_scope(&db, file_main)); + let class_body_scope = get_scope(&db, file_mod, "C", ScopeKind::Class).unwrap(); + let cycle_recovery = DivergentType::implicit_attribute(&db, class_body_scope, attribute); assert_function_query_was_not_run( &db, infer_expression_types_impl, - InferExpression::with_divergent(&db, x_rhs_expression(&db), div), + InferExpression::bare(&db, x_rhs_expression(&db), cycle_recovery), &events, ); @@ -733,11 +763,13 @@ fn call_type_doesnt_rerun_when_only_callee_changed() -> anyhow::Result<()> { let call = &*module.syntax().body[1].as_assign_stmt().unwrap().value; let foo_call = semantic_index(&db, bar).expression(call); - let div = DivergentType::todo(&db, global_scope(&db, bar)); + let foo = system_path_to_file(&db, "src/foo.py")?; + let function_scope = get_scope(&db, foo, "foo", ScopeKind::Function).unwrap(); + let cycle_recovery = DivergentType::function_return_type(&db, function_scope); assert_function_query_was_run( &db, infer_expression_types_impl, - InferExpression::with_divergent(&db, foo_call, div), + InferExpression::bare(&db, foo_call, cycle_recovery), &events, ); @@ -765,11 +797,12 @@ fn call_type_doesnt_rerun_when_only_callee_changed() -> anyhow::Result<()> { let call = &*module.syntax().body[1].as_assign_stmt().unwrap().value; let foo_call = semantic_index(&db, bar).expression(call); - let div = DivergentType::todo(&db, global_scope(&db, bar)); + let function_scope = get_scope(&db, foo, "foo", ScopeKind::Function).unwrap(); + let cycle_recovery = DivergentType::function_return_type(&db, function_scope); assert_function_query_was_not_run( &db, infer_expression_types_impl, - InferExpression::with_divergent(&db, foo_call, div), + InferExpression::bare(&db, foo_call, cycle_recovery), &events, ); diff --git a/crates/ty_python_semantic/src/types/narrow.rs b/crates/ty_python_semantic/src/types/narrow.rs index 2fb6157acb904..388d1f82302ee 100644 --- a/crates/ty_python_semantic/src/types/narrow.rs +++ b/crates/ty_python_semantic/src/types/narrow.rs @@ -11,8 +11,9 @@ use crate::types::enums::{enum_member_literals, enum_metadata}; use crate::types::function::KnownFunction; use crate::types::infer::infer_same_file_expression_type; use crate::types::{ - ClassLiteral, ClassType, IntersectionBuilder, KnownClass, SubclassOfInner, SubclassOfType, - Truthiness, Type, TypeContext, TypeVarBoundOrConstraints, UnionBuilder, infer_expression_types, + ClassLiteral, ClassType, DivergentType, IntersectionBuilder, KnownClass, SubclassOfInner, + SubclassOfType, Truthiness, Type, TypeContext, TypeVarBoundOrConstraints, UnionBuilder, + infer_expression_types, }; use ruff_db::parsed::{ParsedModuleRef, parsed_module}; @@ -773,7 +774,9 @@ impl<'db, 'ast> NarrowingConstraintsBuilder<'db, 'ast> { return None; } - let inference = infer_expression_types(self.db, expression, TypeContext::default()); + let cycle_recovery = DivergentType::should_not_diverge(self.db, expression.scope(self.db)); + let inference = + infer_expression_types(self.db, expression, TypeContext::default(), cycle_recovery); let comparator_tuples = std::iter::once(&**left) .chain(comparators) @@ -863,7 +866,10 @@ impl<'db, 'ast> NarrowingConstraintsBuilder<'db, 'ast> { expression: Expression<'db>, is_positive: bool, ) -> Option> { - let inference = infer_expression_types(self.db, expression, TypeContext::default()); + // QUESTION: Could there be a case for divergence here? + let cycle_recovery = DivergentType::should_not_diverge(self.db, expression.scope(self.db)); + let inference = + infer_expression_types(self.db, expression, TypeContext::default(), cycle_recovery); let callable_ty = inference.expression_type(&*expr_call.func); @@ -983,8 +989,15 @@ impl<'db, 'ast> NarrowingConstraintsBuilder<'db, 'ast> { let subject = place_expr(subject.node_ref(self.db, self.module))?; let place = self.expect_place(&subject); - let ty = infer_same_file_expression_type(self.db, cls, TypeContext::default(), self.module) - .to_instance(self.db)?; + let cycle_recovery = DivergentType::should_not_diverge(self.db, cls.scope(self.db)); + let ty = infer_same_file_expression_type( + self.db, + cls, + TypeContext::default(), + cycle_recovery, + self.module, + ) + .to_instance(self.db)?; Some(NarrowingConstraints::from_iter([(place, ty)])) } @@ -997,8 +1010,15 @@ impl<'db, 'ast> NarrowingConstraintsBuilder<'db, 'ast> { let subject = place_expr(subject.node_ref(self.db, self.module))?; let place = self.expect_place(&subject); - let ty = - infer_same_file_expression_type(self.db, value, TypeContext::default(), self.module); + // QUESTION: Could there be a case for divergence here? + let cycle_recovery = DivergentType::should_not_diverge(self.db, value.scope(self.db)); + let ty = infer_same_file_expression_type( + self.db, + value, + TypeContext::default(), + cycle_recovery, + self.module, + ); Some(NarrowingConstraints::from_iter([(place, ty)])) } @@ -1027,7 +1047,9 @@ impl<'db, 'ast> NarrowingConstraintsBuilder<'db, 'ast> { expression: Expression<'db>, is_positive: bool, ) -> Option> { - let inference = infer_expression_types(self.db, expression, TypeContext::default()); + let cycle_recovery = DivergentType::should_not_diverge(self.db, expression.scope(self.db)); + let inference = + infer_expression_types(self.db, expression, TypeContext::default(), cycle_recovery); let mut sub_constraints = expr_bool_op .values .iter() diff --git a/crates/ty_python_semantic/src/types/unpacker.rs b/crates/ty_python_semantic/src/types/unpacker.rs index 7d7efe7112e1f..e0c39c96557b7 100644 --- a/crates/ty_python_semantic/src/types/unpacker.rs +++ b/crates/ty_python_semantic/src/types/unpacker.rs @@ -47,7 +47,7 @@ impl<'db, 'ast> Unpacker<'db, 'ast> { &mut self, target: &ast::Expr, value: UnpackValue<'db>, - divergent: DivergentType<'db>, + cycle_recovery: DivergentType<'db>, ) { debug_assert!( matches!(target, ast::Expr::List(_) | ast::Expr::Tuple(_)), @@ -58,7 +58,7 @@ impl<'db, 'ast> Unpacker<'db, 'ast> { self.db(), value.expression(), TypeContext::default(), - divergent, + cycle_recovery, ); let value_type = infer_expression_types_impl(self.db(), input) .expression_type(value.expression().node_ref(self.db(), self.module())); @@ -189,7 +189,7 @@ impl<'db, 'ast> Unpacker<'db, 'ast> { UnpackResult { diagnostics: self.context.finish(), targets: self.targets, - cycle_fallback_type: None, + cycle_recovery: None, } } } @@ -202,7 +202,7 @@ pub(crate) struct UnpackResult<'db> { /// The fallback type for missing expressions. /// /// This is used only when constructing a cycle-recovery `UnpackResult`. - cycle_fallback_type: Option>, + cycle_recovery: Option>, } impl<'db> UnpackResult<'db> { @@ -228,7 +228,7 @@ impl<'db> UnpackResult<'db> { self.targets .get(&expr.into()) .copied() - .or(self.cycle_fallback_type) + .or(self.cycle_recovery) } /// Returns the diagnostics in this unpacking assignment. @@ -236,11 +236,11 @@ impl<'db> UnpackResult<'db> { &self.diagnostics } - pub(crate) fn cycle_initial(cycle_fallback_type: Type<'db>) -> Self { + pub(crate) fn cycle_initial(cycle_recovery: Type<'db>) -> Self { Self { targets: FxHashMap::default(), diagnostics: TypeCheckDiagnostics::default(), - cycle_fallback_type: Some(cycle_fallback_type), + cycle_recovery: Some(cycle_recovery), } } } From bea05a54a960156f5c29cf60481a78ef7b6fe758 Mon Sep 17 00:00:00 2001 From: Shunsuke Shibayama Date: Tue, 16 Sep 2025 20:52:29 +0900 Subject: [PATCH 074/105] use `DivergentType` to protect tracked functions that may fall into divergent type inference Many problematic examples from astral-sh/ty#256 are now resolved. --- .../resources/mdtest/cycle.md | 18 +- crates/ty_python_semantic/src/dunder_all.rs | 14 +- .../ty_python_semantic/src/semantic_index.rs | 27 +- .../src/semantic_index/definition.rs | 1 + .../src/semantic_index/expression.rs | 1 + .../src/semantic_index/symbol.rs | 4 - crates/ty_python_semantic/src/types.rs | 597 +++++++++++++----- crates/ty_python_semantic/src/types/class.rs | 117 ++-- .../src/types/class_base.rs | 16 +- .../ty_python_semantic/src/types/context.rs | 4 - .../ty_python_semantic/src/types/function.rs | 35 +- .../ty_python_semantic/src/types/generics.rs | 56 ++ crates/ty_python_semantic/src/types/infer.rs | 326 +++------- .../src/types/infer/builder.rs | 264 ++++---- .../src/types/infer/tests.rs | 74 +-- .../ty_python_semantic/src/types/instance.rs | 27 +- crates/ty_python_semantic/src/types/narrow.rs | 40 +- .../src/types/signatures.rs | 87 ++- .../src/types/subclass_of.rs | 25 +- crates/ty_python_semantic/src/types/tuple.rs | 46 ++ .../src/types/type_ordering.rs | 99 ++- .../ty_python_semantic/src/types/unpacker.rs | 38 +- crates/ty_python_semantic/src/unpack.rs | 1 + crates/ty_python_semantic/tests/corpus.rs | 3 - 24 files changed, 1171 insertions(+), 749 deletions(-) diff --git a/crates/ty_python_semantic/resources/mdtest/cycle.md b/crates/ty_python_semantic/resources/mdtest/cycle.md index 0bd3b5b2a6df4..ffd107aa0ce4f 100644 --- a/crates/ty_python_semantic/resources/mdtest/cycle.md +++ b/crates/ty_python_semantic/resources/mdtest/cycle.md @@ -28,6 +28,20 @@ class Point: self.x, self.y = other.x, other.y p = Point() -reveal_type(p.x) # revealed: Unknown | int -reveal_type(p.y) # revealed: Unknown | int +# TODO: should be `Unknown | int` +reveal_type(p.x) # revealed: Unknown | int | Divergent +# TODO: should be `Unknown | int` +reveal_type(p.y) # revealed: Unknown | int | Divergent +``` + +## Self-referential bare type alias + +```py +A = list["A" | None] + +def f(x: A): + # TODO: should be `list[A | None]`? + reveal_type(x) # revealed: list[Divergent] + # TODO: should be `A | None`? + reveal_type(x[0]) # revealed: Divergent ``` diff --git a/crates/ty_python_semantic/src/dunder_all.rs b/crates/ty_python_semantic/src/dunder_all.rs index 20e73260d49fd..10eab9321a919 100644 --- a/crates/ty_python_semantic/src/dunder_all.rs +++ b/crates/ty_python_semantic/src/dunder_all.rs @@ -7,7 +7,7 @@ use ruff_python_ast::statement_visitor::{StatementVisitor, walk_stmt}; use ruff_python_ast::{self as ast}; use crate::semantic_index::{SemanticIndex, semantic_index}; -use crate::types::{DivergentType, Truthiness, Type, TypeContext, infer_expression_types}; +use crate::types::{Truthiness, Type, TypeContext, infer_expression_types}; use crate::{Db, ModuleName, resolve_module}; #[allow(clippy::ref_option)] @@ -182,16 +182,8 @@ impl<'db> DunderAllNamesCollector<'db> { /// /// This function panics if `expr` was not marked as a standalone expression during semantic indexing. fn standalone_expression_type(&self, expr: &ast::Expr) -> Type<'db> { - // QUESTION: Could there be a case for divergence here? - let cycle_recovery = - DivergentType::should_not_diverge(self.db, self.index.expression(expr).scope(self.db)); - infer_expression_types( - self.db, - self.index.expression(expr), - TypeContext::default(), - cycle_recovery, - ) - .expression_type(expr) + infer_expression_types(self.db, self.index.expression(expr), TypeContext::default()) + .expression_type(expr) } /// Evaluate the given expression and return its truthiness. diff --git a/crates/ty_python_semantic/src/semantic_index.rs b/crates/ty_python_semantic/src/semantic_index.rs index 986e514b9277b..fc57d0b2f6b44 100644 --- a/crates/ty_python_semantic/src/semantic_index.rs +++ b/crates/ty_python_semantic/src/semantic_index.rs @@ -25,9 +25,7 @@ pub use crate::semantic_index::scope::FileScopeId; use crate::semantic_index::scope::{ NodeWithScopeKey, NodeWithScopeRef, Scope, ScopeId, ScopeKind, ScopeLaziness, }; -use crate::semantic_index::symbol::{ - ImplicitAttributeTable, ImplicitAttributeTableBuilder, ScopedSymbolId, Symbol, -}; +use crate::semantic_index::symbol::ScopedSymbolId; use crate::semantic_index::use_def::{EnclosingSnapshotKey, ScopedEnclosingSnapshotId, UseDefMap}; use crate::semantic_model::HasTrackedScope; @@ -75,29 +73,6 @@ pub(crate) fn place_table<'db>(db: &'db dyn Db, scope: ScopeId<'db>) -> Arc( - db: &'db dyn Db, - class_body_scope: ScopeId<'db>, -) -> ImplicitAttributeTable { - let mut table = ImplicitAttributeTableBuilder::default(); - - let file = class_body_scope.file(db); - let index = semantic_index(db, file); - - for function_scope_id in attribute_scopes(db, class_body_scope) { - let function_place_table = index.place_table(function_scope_id); - for member in function_place_table.members() { - if let Some(attr) = member.as_instance_attribute() { - let symbol = Symbol::new(attr.into()); - table.add(symbol); - } - } - } - - table.build() -} - /// Returns the set of modules that are imported anywhere in `file`. /// /// This set only considers `import` statements, not `from...import` statements, because: diff --git a/crates/ty_python_semantic/src/semantic_index/definition.rs b/crates/ty_python_semantic/src/semantic_index/definition.rs index b06390fa8479e..46065649aae4e 100644 --- a/crates/ty_python_semantic/src/semantic_index/definition.rs +++ b/crates/ty_python_semantic/src/semantic_index/definition.rs @@ -23,6 +23,7 @@ use crate::unpack::{Unpack, UnpackPosition}; /// before this `Definition`. However, the ID can be considered stable and it is okay to use /// `Definition` in cross-module` salsa queries or as a field on other salsa tracked structs. #[salsa::tracked(debug, heap_size=ruff_memory_usage::heap_size)] +#[derive(PartialOrd, Ord)] pub struct Definition<'db> { /// The file in which the definition occurs. pub file: File, diff --git a/crates/ty_python_semantic/src/semantic_index/expression.rs b/crates/ty_python_semantic/src/semantic_index/expression.rs index 3f6f159d179f9..4c598153ddf48 100644 --- a/crates/ty_python_semantic/src/semantic_index/expression.rs +++ b/crates/ty_python_semantic/src/semantic_index/expression.rs @@ -32,6 +32,7 @@ pub(crate) enum ExpressionKind { /// * a field of a type that is a return type of a cross-module query /// * an argument of a cross-module query #[salsa::tracked(debug, heap_size=ruff_memory_usage::heap_size)] +#[derive(PartialOrd, Ord)] pub(crate) struct Expression<'db> { /// The file in which the expression occurs. pub(crate) file: File, diff --git a/crates/ty_python_semantic/src/semantic_index/symbol.rs b/crates/ty_python_semantic/src/semantic_index/symbol.rs index c0120eb10a5c2..a0f3c33b090dd 100644 --- a/crates/ty_python_semantic/src/semantic_index/symbol.rs +++ b/crates/ty_python_semantic/src/semantic_index/symbol.rs @@ -268,7 +268,3 @@ impl DerefMut for SymbolTableBuilder { &mut self.table } } - -pub(crate) type ScopedImplicitAttributeId = ScopedSymbolId; -pub(crate) type ImplicitAttributeTable = SymbolTable; -pub(super) type ImplicitAttributeTableBuilder = SymbolTableBuilder; diff --git a/crates/ty_python_semantic/src/types.rs b/crates/ty_python_semantic/src/types.rs index 21cffd6a942f5..fe33502a1f297 100644 --- a/crates/ty_python_semantic/src/types.rs +++ b/crates/ty_python_semantic/src/types.rs @@ -34,12 +34,10 @@ use crate::place::{Boundness, Place, PlaceAndQualifiers, imported_symbol}; use crate::semantic_index::definition::Definition; use crate::semantic_index::place::ScopedPlaceId; use crate::semantic_index::scope::ScopeId; -use crate::semantic_index::symbol::ScopedImplicitAttributeId; -use crate::semantic_index::{ - implicit_attribute_table, imported_modules, place_table, semantic_index, -}; +use crate::semantic_index::{imported_modules, place_table, semantic_index}; use crate::suppression::check_suppressions; use crate::types::call::{Binding, Bindings, CallArguments, CallableBinding}; +use crate::types::class::MethodDecorator; pub(crate) use crate::types::class_base::ClassBase; use crate::types::constraints::{ ConstraintSet, IteratorConstraintsExtension, OptionConstraintsExtension, @@ -61,7 +59,7 @@ pub use crate::types::ide_support::{ definitions_for_attribute, definitions_for_imported_symbol, definitions_for_keyword_argument, definitions_for_name, find_active_signature_from_details, inlay_hint_function_argument_details, }; -use crate::types::infer::infer_function_scope_types; +use crate::types::infer::InferExpression; use crate::types::mro::{Mro, MroError, MroIterator}; pub(crate) use crate::types::narrow::infer_narrowing_constraint; use crate::types::signatures::{Parameter, ParameterForm, Parameters, walk_signature}; @@ -69,7 +67,7 @@ use crate::types::tuple::TupleSpec; pub(crate) use crate::types::typed_dict::{TypedDictParams, TypedDictType, walk_typed_dict_type}; use crate::types::variance::{TypeVarVariance, VarianceInferable}; use crate::types::visitor::any_over_type; -use crate::unpack::EvaluationMode; +use crate::unpack::{EvaluationMode, Unpack}; pub use crate::util::diagnostics::add_inferred_python_version_hint_to_diagnostic; use crate::{Db, FxOrderSet, Module, Program}; pub(crate) use class::{ClassLiteral, ClassType, GenericAlias, KnownClass}; @@ -120,13 +118,15 @@ fn return_type_cycle_recover<'db>( } fn return_type_cycle_initial<'db>(db: &'db dyn Db, method: BoundMethodType<'db>) -> Type<'db> { - Type::divergent(DivergentType::function_return_type( + Type::divergent(DivergentType::new( db, - method - .function(db) - .literal(db) - .last_definition(db) - .body_scope(db), + DivergenceKind::InferReturnType( + method + .function(db) + .literal(db) + .last_definition(db) + .body_scope(db), + ), )) } @@ -241,19 +241,13 @@ pub(crate) type TryBoolVisitor<'db> = CycleDetector, Result>>; pub(crate) struct TryBool; -#[derive(Default)] +#[derive(Default, Copy, Clone, Debug)] pub(crate) enum NormalizationKind<'db> { #[default] Normal, Recursive(Type<'db>), } -impl NormalizationKind<'_> { - pub(crate) fn is_recursive(&self) -> bool { - matches!(self, Self::Recursive(_)) - } -} - /// A [`TypeTransformer`] that is used in `normalized` methods. #[derive(Default)] pub(crate) struct NormalizedVisitor<'db> { @@ -285,10 +279,6 @@ impl<'db> NormalizedVisitor<'db> { fn level(&self) -> usize { self.transformer.level() } - - fn is_recursive(&self) -> bool { - self.kind.is_recursive() - } } /// How a generic type has been specialized. @@ -350,7 +340,7 @@ enum InstanceFallbackShadowsNonDataDescriptor { } bitflags! { - #[derive(Clone, Debug, Copy, PartialEq, Eq, Hash)] + #[derive(Clone, Debug, Copy, PartialEq, Eq, Hash, PartialOrd, Ord)] pub(crate) struct MemberLookupPolicy: u8 { /// Dunder methods are looked up on the meta-type of a type without potentially falling /// back on attributes on the type itself. For example, when implicitly invoked on an @@ -413,6 +403,8 @@ impl Default for MemberLookupPolicy { } } +impl get_size2::GetSize for MemberLookupPolicy {} + fn member_lookup_cycle_recover<'db>( _db: &'db dyn Db, _value: &PlaceAndQualifiers<'db>, @@ -427,29 +419,19 @@ fn member_lookup_cycle_recover<'db>( #[allow(clippy::needless_pass_by_value)] fn member_lookup_cycle_initial<'db>( db: &'db dyn Db, - self_ty: Type<'db>, + self_type: Type<'db>, name: Name, - _policy: MemberLookupPolicy, + policy: MemberLookupPolicy, ) -> PlaceAndQualifiers<'db> { - let class_body_scope = match self_ty { - Type::NominalInstance(instance) => instance.class_literal(db).body_scope(db), - Type::GenericAlias(alias) => alias.origin(db).body_scope(db), - Type::ClassLiteral(class) => class.body_scope(db), - _ => { - return Place::bound(Type::Never).into(); - } - }; - let implicit_attribute_table = implicit_attribute_table(db, class_body_scope); - let initial = if let Some(implicit_attribute) = implicit_attribute_table.symbol_id(&name) { - Type::divergent(DivergentType::implicit_attribute( - db, - class_body_scope, - implicit_attribute, - )) - } else { - Type::Never - }; - Place::bound(initial).into() + Place::bound(Type::divergent(DivergentType::new( + db, + DivergenceKind::MemberLookupWithPolicy { + self_type, + name, + policy, + }, + ))) + .into() } fn class_lookup_cycle_recover<'db>( @@ -466,29 +448,19 @@ fn class_lookup_cycle_recover<'db>( #[allow(clippy::needless_pass_by_value)] fn class_lookup_cycle_initial<'db>( db: &'db dyn Db, - self_ty: Type<'db>, + self_type: Type<'db>, name: Name, - _policy: MemberLookupPolicy, + policy: MemberLookupPolicy, ) -> PlaceAndQualifiers<'db> { - let class_body_scope = match self_ty { - Type::NominalInstance(instance) => instance.class_literal(db).body_scope(db), - Type::GenericAlias(alias) => alias.origin(db).body_scope(db), - Type::ClassLiteral(class) => class.body_scope(db), - _ => { - return Place::bound(Type::Never).into(); - } - }; - let implicit_attribute_table = implicit_attribute_table(db, class_body_scope); - let initial = if let Some(implicit_attribute) = implicit_attribute_table.symbol_id(&name) { - Type::divergent(DivergentType::implicit_attribute( - db, - class_body_scope, - implicit_attribute, - )) - } else { - Type::Never - }; - Place::bound(initial).into() + Place::bound(Type::divergent(DivergentType::new( + db, + DivergenceKind::ClassLookupWithPolicy { + self_type, + name, + policy, + }, + ))) + .into() } #[allow(clippy::trivially_copy_pass_by_ref)] @@ -642,6 +614,16 @@ impl<'db> PropertyInstanceType<'db> { ) } + fn recursive_type_normalized(self, db: &'db dyn Db, visitor: &NormalizedVisitor<'db>) -> Self { + Self::new( + db, + self.getter(db) + .map(|ty| ty.recursive_type_normalized(db, visitor)), + self.setter(db) + .map(|ty| ty.recursive_type_normalized(db, visitor)), + ) + } + fn find_legacy_typevars_impl( self, db: &'db dyn Db, @@ -1257,21 +1239,8 @@ impl<'db> Type<'db> { #[must_use] pub(crate) fn normalized_impl(self, db: &'db dyn Db, visitor: &NormalizedVisitor<'db>) -> Self { - if let NormalizationKind::Recursive(div) = visitor.kind { - if visitor.level() == 0 && self == div { - // int | Divergent = int | (int | (int | ...)) = int - return Type::Never; - } else if visitor.level() >= 1 && self.has_divergent_type(db, div) { - // G[G[Divergent]] = G[Divergent] - return div; - } - } match self { - Type::Union(union) => { - // As explained above, `Divergent` in a union type does not mean true divergence, - // so we normalize the type while keeping the nesting level the same. - visitor.visit_no_shift(self, || union.normalized_impl(db, visitor)) - } + Type::Union(union) => visitor.visit(self, || union.normalized_impl(db, visitor)), Type::Intersection(intersection) => visitor.visit(self, || { Type::Intersection(intersection.normalized_impl(db, visitor)) }), @@ -1317,7 +1286,7 @@ impl<'db> Type<'db> { Type::TypeIs(type_is) => visitor.visit(self, || { type_is.with_type(db, type_is.return_type(db).normalized_impl(db, visitor)) }), - Type::Dynamic(dynamic) => Type::Dynamic(dynamic.normalized(visitor.is_recursive())), + Type::Dynamic(dynamic) => Type::Dynamic(dynamic.normalized_impl(visitor.kind)), Type::EnumLiteral(enum_literal) if is_single_member_enum(db, enum_literal.enum_class(db)) => { @@ -1328,8 +1297,102 @@ impl<'db> Type<'db> { // TODO: Normalize TypedDicts self } - Type::TypeAlias(alias) if !visitor.is_recursive() => { - alias.value_type(db).normalized_impl(db, visitor) + Type::TypeAlias(alias) => alias.value_type(db).normalized_impl(db, visitor), + Type::LiteralString + | Type::AlwaysFalsy + | Type::AlwaysTruthy + | Type::BooleanLiteral(_) + | Type::BytesLiteral(_) + | Type::EnumLiteral(_) + | Type::StringLiteral(_) + | Type::Never + | Type::WrapperDescriptor(_) + | Type::DataclassDecorator(_) + | Type::DataclassTransformer(_) + | Type::ModuleLiteral(_) + | Type::ClassLiteral(_) + | Type::SpecialForm(_) + | Type::IntLiteral(_) => self, + } + } + + #[must_use] + pub(crate) fn recursive_type_normalized( + self, + db: &'db dyn Db, + visitor: &NormalizedVisitor<'db>, + ) -> Self { + if let NormalizationKind::Recursive(div) = visitor.kind { + if visitor.level() == 0 && self == div { + // int | Divergent = int | (int | (int | ...)) = int + return Type::Never; + } else if visitor.level() >= 1 && self.has_divergent_type(db, div) { + // G[G[Divergent]] = G[Divergent] + return div; + } + } + match self { + Type::Union(union) => { + // As explained above, `Divergent` in a union type does not mean true divergence, + // so we normalize the type while keeping the nesting level the same. + visitor.visit_no_shift(self, || union.recursive_type_normalized(db, visitor)) + } + Type::Intersection(intersection) => visitor.visit(self, || { + Type::Intersection(intersection.recursive_type_normalized(db, visitor)) + }), + Type::Callable(callable) => visitor.visit(self, || { + Type::Callable(callable.recursive_type_normalized(db, visitor)) + }), + Type::ProtocolInstance(protocol) => visitor.visit(self, || { + Type::ProtocolInstance(protocol.recursive_type_normalized(db, visitor)) + }), + Type::NominalInstance(instance) => visitor.visit(self, || { + Type::NominalInstance(instance.recursive_type_normalized(db, visitor)) + }), + Type::FunctionLiteral(function) => visitor.visit(self, || { + Type::FunctionLiteral(function.recursive_type_normalized(db, visitor)) + }), + Type::PropertyInstance(property) => visitor.visit(self, || { + Type::PropertyInstance(property.recursive_type_normalized(db, visitor)) + }), + Type::KnownBoundMethod(method_kind) => visitor.visit(self, || { + Type::KnownBoundMethod(method_kind.recursive_type_normalized(db, visitor)) + }), + Type::BoundMethod(method) => visitor.visit(self, || { + Type::BoundMethod(method.recursive_type_normalized(db, visitor)) + }), + Type::BoundSuper(bound_super) => visitor.visit(self, || { + Type::BoundSuper(bound_super.recursive_type_normalized(db, visitor)) + }), + Type::GenericAlias(generic) => visitor.visit(self, || { + Type::GenericAlias(generic.recursive_type_normalized(db, visitor)) + }), + Type::SubclassOf(subclass_of) => visitor.visit(self, || { + Type::SubclassOf(subclass_of.recursive_type_normalized(db, visitor)) + }), + Type::TypeVar(bound_typevar) => visitor.visit(self, || { + Type::TypeVar(bound_typevar.recursive_type_normalized(db, visitor)) + }), + Type::NonInferableTypeVar(bound_typevar) => visitor.visit(self, || { + Type::NonInferableTypeVar(bound_typevar.recursive_type_normalized(db, visitor)) + }), + Type::KnownInstance(known_instance) => visitor.visit(self, || { + Type::KnownInstance(known_instance.recursive_type_normalized(db, visitor)) + }), + Type::TypeIs(type_is) => visitor.visit(self, || { + type_is.with_type( + db, + type_is + .return_type(db) + .recursive_type_normalized(db, visitor), + ) + }), + Type::Dynamic(dynamic) => { + Type::Dynamic(dynamic.recursive_type_normalized(visitor.kind)) + } + Type::TypedDict(_) => { + // TODO: Normalize TypedDicts + self } Type::TypeAlias(_) => self, Type::LiteralString @@ -3060,7 +3123,7 @@ impl<'db> Type<'db> { class_lookup_cycle_initial(db, self, name, policy).place { let visitor = NormalizedVisitor::default().recursive(div); - ty.normalized_impl(db, &visitor) + ty.recursive_type_normalized(db, &visitor) } else { ty } @@ -3851,7 +3914,7 @@ impl<'db> Type<'db> { member_lookup_cycle_initial(db, self, name, policy).place { let visotor = NormalizedVisitor::default().recursive(div); - ty.normalized_impl(db, &visotor) + ty.recursive_type_normalized(db, &visotor) } else { ty } @@ -6913,6 +6976,30 @@ impl<'db> TypeMapping<'_, 'db> { } } } + + fn recursive_type_normalized(&self, db: &'db dyn Db, visitor: &NormalizedVisitor<'db>) -> Self { + match self { + TypeMapping::Specialization(specialization) => { + TypeMapping::Specialization(specialization.recursive_type_normalized(db, visitor)) + } + TypeMapping::PartialSpecialization(partial) => { + TypeMapping::PartialSpecialization(partial.recursive_type_normalized(db, visitor)) + } + TypeMapping::PromoteLiterals => TypeMapping::PromoteLiterals, + TypeMapping::BindLegacyTypevars(binding_context) => { + TypeMapping::BindLegacyTypevars(*binding_context) + } + TypeMapping::BindSelf(self_type) => { + TypeMapping::BindSelf(self_type.recursive_type_normalized(db, visitor)) + } + TypeMapping::MarkTypeVarsInferable(binding_context) => { + TypeMapping::MarkTypeVarsInferable(*binding_context) + } + TypeMapping::Materialize(materialization_kind) => { + TypeMapping::Materialize(*materialization_kind) + } + } + } } /// A Salsa-tracked constraint set. This is only needed to have something appropriately small to @@ -7027,6 +7114,30 @@ impl<'db> KnownInstanceType<'db> { } } + fn recursive_type_normalized(self, db: &'db dyn Db, visitor: &NormalizedVisitor<'db>) -> Self { + match self { + Self::SubscriptedProtocol(context) => { + Self::SubscriptedProtocol(context.recursive_type_normalized(db, visitor)) + } + Self::SubscriptedGeneric(context) => { + Self::SubscriptedGeneric(context.recursive_type_normalized(db, visitor)) + } + Self::TypeVar(typevar) => Self::TypeVar(typevar.recursive_type_normalized(db, visitor)), + Self::TypeAliasType(type_alias) => { + Self::TypeAliasType(type_alias.recursive_type_normalized(db, visitor)) + } + Self::Deprecated(deprecated) => { + // Nothing to normalize + Self::Deprecated(deprecated) + } + Self::Field(field) => Self::Field(field.recursive_type_normalized(db, visitor)), + Self::ConstraintSet(set) => { + // Nothing to normalize + Self::ConstraintSet(set) + } + } + } + fn class(self, db: &'db dyn Db) -> KnownClass { match self { Self::SubscriptedProtocol(_) | Self::SubscriptedGeneric(_) => KnownClass::SpecialForm, @@ -7122,22 +7233,42 @@ impl<'db> KnownInstanceType<'db> { } } -#[derive(Copy, Clone, Debug, Eq, Hash, PartialEq, salsa::Update, get_size2::GetSize)] -pub enum DivergenceKind { - /// Divergence in function return type inference. - FunctionReturnType, - /// Divergence in implicit attribute type inference. - ImplicitAttribute(ScopedImplicitAttributeId), - /// Type inference should not diverge, - /// or it should be of a kind that we have not yet handled. - ShouldNotDiverge, +#[allow(private_interfaces)] +#[derive(Clone, Debug, Eq, Hash, PartialEq, salsa::Update, get_size2::GetSize)] +pub enum DivergenceKind<'db> { + /// Divergence from `{FunctionLiteral, BoundMethodType}::infer_return_type`. + InferReturnType(ScopeId<'db>), + /// Divergence from `ClassLiteral::implicit_attribute_inner`. + ImplicitAttribute { + class_body_scope: ScopeId<'db>, + name: String, + target_method_decorator: MethodDecorator, + }, + /// Divergence from `Type::member_lookup_with_policy`. + MemberLookupWithPolicy { + self_type: Type<'db>, + name: Name, + policy: MemberLookupPolicy, + }, + /// Divergence from `Type::class_lookup_with_policy`. + ClassLookupWithPolicy { + self_type: Type<'db>, + name: Name, + policy: MemberLookupPolicy, + }, + /// Divergence from `infer_expression_type_impl`. + InferExpression(InferExpression<'db>), + /// Divergence from `infer_expression_types_impl`. + InferExpressionTypes(InferExpression<'db>), + /// Divergence from `infer_definition_types`. + InferDefinitionTypes(Definition<'db>), + /// Divergence from `infer_scope_types`. + InferScopeTypes(ScopeId<'db>), + /// Divergence from `infer_unpack_types`. + InferUnpackTypes(Unpack<'db>), } -impl DivergenceKind { - pub fn should_not_diverge(self) -> bool { - matches!(self, Self::ShouldNotDiverge) - } -} +pub(crate) type CycleRecoveryType<'db> = Type<'db>; /// A type that is determined to be divergent during recursive type inference. /// This type must never be eliminated by dynamic type reduction @@ -7146,35 +7277,14 @@ impl DivergenceKind { /// For detailed properties of this type, see the unit test at the end of the file. #[salsa::interned(debug, heap_size=ruff_memory_usage::heap_size)] pub struct DivergentType<'db> { - /// The scope where this divergence was detected. - /// * The function scope in case of function return type inference - /// * The class body scope in case of implicit attribute type inference - scope: ScopeId<'db>, /// The kind of divergence. - kind: DivergenceKind, + #[returns(ref)] + kind: DivergenceKind<'db>, } // The Salsa heap is tracked separately. impl get_size2::GetSize for DivergentType<'_> {} -impl<'db> DivergentType<'db> { - pub(crate) fn function_return_type(db: &'db dyn Db, scope: ScopeId<'db>) -> Self { - Self::new(db, scope, DivergenceKind::FunctionReturnType) - } - - pub(crate) fn implicit_attribute( - db: &'db dyn Db, - scope: ScopeId<'db>, - attribute: ScopedImplicitAttributeId, - ) -> Self { - Self::new(db, scope, DivergenceKind::ImplicitAttribute(attribute)) - } - - pub(crate) fn should_not_diverge(db: &'db dyn Db, scope: ScopeId<'db>) -> Self { - Self::new(db, scope, DivergenceKind::ShouldNotDiverge) - } -} - #[derive(Copy, Clone, Debug, Eq, Hash, PartialEq, salsa::Update, get_size2::GetSize)] pub enum DynamicType<'db> { /// An explicitly annotated `typing.Any` @@ -7203,17 +7313,23 @@ pub enum DynamicType<'db> { Divergent(DivergentType<'db>), } -impl DynamicType<'_> { - fn normalized(self, is_recursive: bool) -> Self { - if is_recursive { - return self; - } - if matches!(self, Self::Divergent(_)) { - self - } else { - Self::Any +impl<'db> DynamicType<'db> { + fn normalized_impl(self, kind: NormalizationKind<'db>) -> Self { + match kind { + NormalizationKind::Recursive(_) => self, + NormalizationKind::Normal => { + if matches!(self, Self::Divergent(_)) { + self + } else { + Self::Any + } + } } } + + fn recursive_type_normalized(self, _kind: NormalizationKind<'db>) -> Self { + self + } } impl std::fmt::Display for DynamicType<'_> { @@ -7548,6 +7664,15 @@ impl<'db> FieldInstance<'db> { self.kw_only(db), ) } + + fn recursive_type_normalized(self, db: &'db dyn Db, visitor: &NormalizedVisitor<'db>) -> Self { + FieldInstance::new( + db, + self.default_type(db).recursive_type_normalized(db, visitor), + self.init(db), + self.kw_only(db), + ) + } } /// Whether this typevar was created via the legacy `TypeVar` constructor, using PEP 695 syntax, @@ -7690,9 +7815,6 @@ impl<'db> TypeVarInstance<'db> { } pub(crate) fn normalized_impl(self, db: &'db dyn Db, visitor: &NormalizedVisitor<'db>) -> Self { - if visitor.is_recursive() { - return self; - } Self::new( db, self.name(db), @@ -7720,6 +7842,14 @@ impl<'db> TypeVarInstance<'db> { ) } + fn recursive_type_normalized( + self, + _db: &'db dyn Db, + _visitor: &NormalizedVisitor<'db>, + ) -> Self { + self + } + fn materialize_impl( self, db: &'db dyn Db, @@ -7794,7 +7924,7 @@ impl<'db> TypeVarInstance<'db> { Some(TypeVarBoundOrConstraints::UpperBound(ty)) } - #[salsa::tracked] + #[salsa::tracked(cycle_fn=lazy_constraint_cycle_recover, cycle_initial=lazy_constraint_cycle_initial)] fn lazy_constraints(self, db: &'db dyn Db) -> Option> { let definition = self.definition(db)?; let module = parsed_module(db, definition.file(db)).load(db); @@ -7804,7 +7934,7 @@ impl<'db> TypeVarInstance<'db> { Some(TypeVarBoundOrConstraints::Constraints(ty)) } - #[salsa::tracked] + #[salsa::tracked(cycle_fn=lazy_default_cycle_recover, cycle_initial=lazy_default_cycle_initial)] fn lazy_default(self, db: &'db dyn Db) -> Option> { let definition = self.definition(db)?; let module = parsed_module(db, definition.file(db)).load(db); @@ -7834,6 +7964,40 @@ fn lazy_bound_cycle_initial<'db>( None } +#[allow(clippy::ref_option)] +fn lazy_constraint_cycle_recover<'db>( + _db: &'db dyn Db, + _value: &Option>, + _count: u32, + _self: TypeVarInstance<'db>, +) -> salsa::CycleRecoveryAction>> { + salsa::CycleRecoveryAction::Iterate +} + +fn lazy_constraint_cycle_initial<'db>( + _db: &'db dyn Db, + _self: TypeVarInstance<'db>, +) -> Option> { + None +} + +#[allow(clippy::ref_option)] +fn lazy_default_cycle_recover<'db>( + _db: &'db dyn Db, + _value: &Option>, + _count: u32, + _self: TypeVarInstance<'db>, +) -> salsa::CycleRecoveryAction>> { + salsa::CycleRecoveryAction::Iterate +} + +fn lazy_default_cycle_initial<'db>( + _db: &'db dyn Db, + _self: TypeVarInstance<'db>, +) -> Option> { + None +} + /// Where a type variable is bound and usable. #[derive(Clone, Copy, Debug, Hash, PartialEq, Eq, salsa::Update, get_size2::GetSize)] pub enum BindingContext<'db> { @@ -7960,6 +8124,14 @@ impl<'db> BoundTypeVarInstance<'db> { ) } + fn recursive_type_normalized(self, db: &'db dyn Db, visitor: &NormalizedVisitor<'db>) -> Self { + Self::new( + db, + self.typevar(db).recursive_type_normalized(db, visitor), + self.binding_context(db), + ) + } + fn materialize_impl( self, db: &'db dyn Db, @@ -9212,8 +9384,7 @@ impl<'db> BoundMethodType<'db> { .literal(db) .last_definition(db) .body_scope(db); - let inference = - infer_function_scope_types(db, scope, DivergentType::function_return_type(db, scope)); + let inference = infer_scope_types(db, scope); inference.infer_return_type(db, Type::BoundMethod(self)) } @@ -9284,6 +9455,15 @@ impl<'db> BoundMethodType<'db> { ) } + fn recursive_type_normalized(self, db: &'db dyn Db, visitor: &NormalizedVisitor<'db>) -> Self { + Self::new( + db, + self.function(db).recursive_type_normalized(db, visitor), + self.self_instance(db) + .recursive_type_normalized(db, visitor), + ) + } + fn has_relation_to_impl( self, db: &'db dyn Db, @@ -9411,6 +9591,14 @@ impl<'db> CallableType<'db> { ) } + fn recursive_type_normalized(self, db: &'db dyn Db, visitor: &NormalizedVisitor<'db>) -> Self { + CallableType::new( + db, + self.signatures(db).recursive_type_normalized(db, visitor), + self.is_function_like(db), + ) + } + fn apply_type_mapping_impl<'a>( self, db: &'db dyn Db, @@ -9631,6 +9819,32 @@ impl<'db> KnownBoundMethodType<'db> { } } + fn recursive_type_normalized(self, db: &'db dyn Db, visitor: &NormalizedVisitor<'db>) -> Self { + match self { + KnownBoundMethodType::FunctionTypeDunderGet(function) => { + KnownBoundMethodType::FunctionTypeDunderGet( + function.recursive_type_normalized(db, visitor), + ) + } + KnownBoundMethodType::FunctionTypeDunderCall(function) => { + KnownBoundMethodType::FunctionTypeDunderCall( + function.recursive_type_normalized(db, visitor), + ) + } + KnownBoundMethodType::PropertyDunderGet(property) => { + KnownBoundMethodType::PropertyDunderGet( + property.recursive_type_normalized(db, visitor), + ) + } + KnownBoundMethodType::PropertyDunderSet(property) => { + KnownBoundMethodType::PropertyDunderSet( + property.recursive_type_normalized(db, visitor), + ) + } + KnownBoundMethodType::StrStartswith(_) => self, + } + } + /// Return the [`KnownClass`] that inhabitants of this type are instances of at runtime fn class(self) -> KnownClass { match self { @@ -10034,6 +10248,14 @@ impl<'db> PEP695TypeAliasType<'db> { fn normalized_impl(self, _db: &'db dyn Db, _visitor: &NormalizedVisitor<'db>) -> Self { self } + + fn recursive_type_normalized( + self, + _db: &'db dyn Db, + _visitor: &NormalizedVisitor<'db>, + ) -> Self { + self + } } #[allow(clippy::ref_option, clippy::trivially_copy_pass_by_ref)] @@ -10100,6 +10322,15 @@ impl<'db> ManualPEP695TypeAliasType<'db> { self.value(db).normalized_impl(db, visitor), ) } + + fn recursive_type_normalized(self, db: &'db dyn Db, visitor: &NormalizedVisitor<'db>) -> Self { + Self::new( + db, + self.name(db), + self.definition(db), + self.value(db).recursive_type_normalized(db, visitor), + ) + } } #[derive( @@ -10139,6 +10370,17 @@ impl<'db> TypeAliasType<'db> { } } + fn recursive_type_normalized(self, db: &'db dyn Db, visitor: &NormalizedVisitor<'db>) -> Self { + match self { + TypeAliasType::PEP695(type_alias) => { + TypeAliasType::PEP695(type_alias.recursive_type_normalized(db, visitor)) + } + TypeAliasType::ManualPEP695(type_alias) => { + TypeAliasType::ManualPEP695(type_alias.recursive_type_normalized(db, visitor)) + } + } + } + pub(crate) fn name(self, db: &'db dyn Db) -> &'db str { match self { TypeAliasType::PEP695(type_alias) => type_alias.name(db), @@ -10413,13 +10655,30 @@ impl<'db> UnionType<'db> { .map(|ty| ty.normalized_impl(db, visitor)) .fold( UnionBuilder::new(db) - .order_elements(!visitor.is_recursive()) + .order_elements(true) .unpack_aliases(true), UnionBuilder::add, ) .build() } + fn recursive_type_normalized( + self, + db: &'db dyn Db, + visitor: &NormalizedVisitor<'db>, + ) -> Type<'db> { + self.elements(db) + .iter() + .map(|ty| ty.recursive_type_normalized(db, visitor)) + .fold( + UnionBuilder::new(db) + .order_elements(false) + .unpack_aliases(false), + UnionBuilder::add, + ) + .build() + } + pub(crate) fn is_equivalent_to_impl( self, db: &'db dyn Db, @@ -10509,9 +10768,7 @@ impl<'db> IntersectionType<'db> { .map(|ty| ty.normalized_impl(db, visitor)) .collect(); - if !visitor.is_recursive() { - elements.sort_unstable_by(|l, r| union_or_intersection_elements_ordering(db, l, r)); - } + elements.sort_unstable_by(|l, r| union_or_intersection_elements_ordering(db, l, r)); elements } @@ -10522,6 +10779,29 @@ impl<'db> IntersectionType<'db> { ) } + pub(crate) fn recursive_type_normalized( + self, + db: &'db dyn Db, + visitor: &NormalizedVisitor<'db>, + ) -> Self { + fn normalized_set<'db>( + db: &'db dyn Db, + elements: &FxOrderSet>, + visitor: &NormalizedVisitor<'db>, + ) -> FxOrderSet> { + elements + .iter() + .map(|ty| ty.recursive_type_normalized(db, visitor)) + .collect() + } + + IntersectionType::new( + db, + normalized_set(db, self.positive(db), visitor), + normalized_set(db, self.negative(db), visitor), + ) + } + /// Return `true` if `self` represents exactly the same set of possible runtime objects as `other` pub(crate) fn is_equivalent_to_impl( self, @@ -10788,7 +11068,7 @@ impl<'db> SuperOwnerKind<'db> { fn normalized_impl(self, db: &'db dyn Db, visitor: &NormalizedVisitor<'db>) -> Self { match self { SuperOwnerKind::Dynamic(dynamic) => { - SuperOwnerKind::Dynamic(dynamic.normalized(visitor.is_recursive())) + SuperOwnerKind::Dynamic(dynamic.normalized_impl(visitor.kind)) } SuperOwnerKind::Class(class) => { SuperOwnerKind::Class(class.normalized_impl(db, visitor)) @@ -10801,6 +11081,20 @@ impl<'db> SuperOwnerKind<'db> { } } + fn recursive_type_normalized(self, db: &'db dyn Db, visitor: &NormalizedVisitor<'db>) -> Self { + match self { + SuperOwnerKind::Dynamic(dynamic) => { + SuperOwnerKind::Dynamic(dynamic.recursive_type_normalized(visitor.kind)) + } + SuperOwnerKind::Class(class) => { + SuperOwnerKind::Class(class.recursive_type_normalized(db, visitor)) + } + SuperOwnerKind::Instance(instance) => { + SuperOwnerKind::Instance(instance.recursive_type_normalized(db, visitor)) + } + } + } + fn iter_mro(self, db: &'db dyn Db) -> impl Iterator> { match self { SuperOwnerKind::Dynamic(dynamic) => { @@ -11075,6 +11369,14 @@ impl<'db> BoundSuperType<'db> { self.owner(db).normalized_impl(db, visitor), ) } + + fn recursive_type_normalized(self, db: &'db dyn Db, visitor: &NormalizedVisitor<'db>) -> Self { + Self::new( + db, + self.pivot_class(db).recursive_type_normalized(db, visitor), + self.owner(db).recursive_type_normalized(db, visitor), + ) + } } #[salsa::interned(debug, heap_size=ruff_memory_usage::heap_size)] @@ -11242,7 +11544,10 @@ pub(crate) mod tests { let file_scope_id = FileScopeId::global(); let scope = file_scope_id.to_scope_id(&db, file); - let div = Type::divergent(DivergentType::function_return_type(&db, scope)); + let div = Type::divergent(DivergentType::new( + &db, + DivergenceKind::InferReturnType(scope), + )); // The `Divergent` type must not be eliminated in union with other dynamic types, // as this would prevent detection of divergent type inference using `Divergent`. @@ -11281,12 +11586,12 @@ pub(crate) mod tests { nested_rec.display(&db).to_string(), "list[list[Divergent] | None]" ); - let normalized = nested_rec.normalized_impl(&db, &visitor); + let normalized = nested_rec.recursive_type_normalized(&db, &visitor); assert_eq!(normalized.display(&db).to_string(), "list[Divergent]"); let union = UnionType::from_elements(&db, [div, KnownClass::Int.to_instance(&db)]); assert_eq!(union.display(&db).to_string(), "Divergent | int"); - let normalized = union.normalized_impl(&db, &visitor); + let normalized = union.recursive_type_normalized(&db, &visitor); assert_eq!(normalized.display(&db).to_string(), "int"); // The same can be said about intersections for the `Never` type. diff --git a/crates/ty_python_semantic/src/types/class.rs b/crates/ty_python_semantic/src/types/class.rs index 25193e77a31b3..ffe2dcdd887df 100644 --- a/crates/ty_python_semantic/src/types/class.rs +++ b/crates/ty_python_semantic/src/types/class.rs @@ -13,7 +13,6 @@ use crate::semantic_index::scope::NodeWithScopeKind; use crate::semantic_index::symbol::Symbol; use crate::semantic_index::{ DeclarationWithConstraint, SemanticIndex, attribute_declarations, attribute_scopes, - implicit_attribute_table, }; use crate::types::constraints::{ConstraintSet, IteratorConstraintsExtension}; use crate::types::context::InferContext; @@ -21,16 +20,13 @@ use crate::types::diagnostic::{INVALID_LEGACY_TYPE_VARIABLE, INVALID_TYPE_ALIAS_ use crate::types::enums::enum_metadata; use crate::types::function::{DataclassTransformerParams, KnownFunction}; use crate::types::generics::{GenericContext, Specialization, walk_specialization}; -use crate::types::infer::{ - infer_implicit_attribute_expression_type, infer_unpack_implicit_attribute_types, - nearest_enclosing_class, -}; +use crate::types::infer::{infer_expression_type, infer_unpack_types, nearest_enclosing_class}; use crate::types::signatures::{CallableSignature, Parameter, Parameters, Signature}; use crate::types::tuple::{TupleSpec, TupleType}; use crate::types::typed_dict::typed_dict_params_from_class_def; use crate::types::{ ApplyTypeMappingVisitor, Binding, BoundSuperError, BoundSuperType, CallableType, - DataclassParams, DeprecatedInstance, DivergentType, FindLegacyTypeVarsVisitor, + DataclassParams, DeprecatedInstance, DivergenceKind, DivergentType, FindLegacyTypeVarsVisitor, HasRelationToVisitor, IsEquivalentVisitor, KnownInstanceType, ManualPEP695TypeAliasType, MaterializationKind, NormalizedVisitor, PropertyInstanceType, StringLiteralType, TypeAliasType, TypeContext, TypeMapping, TypeRelation, TypeVarBoundOrConstraints, TypeVarInstance, @@ -115,20 +111,15 @@ fn implicit_attribute_initial<'db>( db: &'db dyn Db, class_body_scope: ScopeId<'db>, name: String, - _target_method_decorator: MethodDecorator, + target_method_decorator: MethodDecorator, ) -> PlaceAndQualifiers<'db> { - let implicit_attribute_table = implicit_attribute_table(db, class_body_scope); - let Some(implicit_attr_id) = implicit_attribute_table.symbol_id(&name) else { - return Place::bound(Type::divergent(DivergentType::should_not_diverge( - db, - class_body_scope, - ))) - .into(); - }; - Place::bound(Type::divergent(DivergentType::implicit_attribute( + Place::bound(Type::divergent(DivergentType::new( db, - class_body_scope, - implicit_attr_id, + DivergenceKind::ImplicitAttribute { + class_body_scope, + name, + target_method_decorator, + }, ))) .into() } @@ -300,6 +291,19 @@ impl<'db> GenericAlias<'db> { ) } + pub(super) fn recursive_type_normalized( + self, + db: &'db dyn Db, + visitor: &NormalizedVisitor<'db>, + ) -> Self { + Self::new( + db, + self.origin(db), + self.specialization(db) + .recursive_type_normalized(db, visitor), + ) + } + pub(crate) fn definition(self, db: &'db dyn Db) -> Definition<'db> { self.origin(db).definition(db) } @@ -425,6 +429,17 @@ impl<'db> ClassType<'db> { } } + pub(super) fn recursive_type_normalized( + self, + db: &'db dyn Db, + visitor: &NormalizedVisitor<'db>, + ) -> Self { + match self { + Self::NonGeneric(_) => self, + Self::Generic(generic) => Self::Generic(generic.recursive_type_normalized(db, visitor)), + } + } + pub(super) fn has_pep_695_type_params(self, db: &'db dyn Db) -> bool { match self { Self::NonGeneric(class) => class.has_pep_695_type_params(db), @@ -1265,7 +1280,7 @@ impl<'db> VarianceInferable<'db> for ClassType<'db> { /// A filter that describes which methods are considered when looking for implicit attribute assignments /// in [`ClassLiteral::implicit_attribute`]. -#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)] +#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, PartialOrd, Ord, get_size2::GetSize)] pub(super) enum MethodDecorator { None, ClassMethod, @@ -2909,13 +2924,14 @@ impl<'db> ClassLiteral<'db> { let index = semantic_index(db, file); let class_map = use_def_map(db, class_body_scope); let class_table = place_table(db, class_body_scope); - let implicit_attr_table = implicit_attribute_table(db, class_body_scope); - let div = if let Some(implicit_attr) = implicit_attr_table.symbol_id(&name) { - DivergentType::implicit_attribute(db, class_body_scope, implicit_attr) - } else { - // Implicit attributes should either not exist or have a type declaration, and should not diverge. - DivergentType::should_not_diverge(db, class_body_scope) - }; + let div = DivergentType::new( + db, + DivergenceKind::ImplicitAttribute { + class_body_scope, + name: name.clone(), + target_method_decorator, + }, + ); let visitor = NormalizedVisitor::default().recursive(Type::divergent(div)); let is_valid_scope = |method_scope: ScopeId<'db>| { @@ -2969,13 +2985,12 @@ impl<'db> ClassLiteral<'db> { // `self.SOME_CONSTANT: Final = 1`, infer the type from the value // on the right-hand side. - let inferred_ty = infer_implicit_attribute_expression_type( + let inferred_ty = infer_expression_type( db, index.expression(value), TypeContext::default(), - div, ); - return Place::bound(inferred_ty.normalized_impl(db, &visitor)) + return Place::bound(inferred_ty.recursive_type_normalized(db, &visitor)) .with_qualifiers(all_qualifiers); } @@ -3001,8 +3016,8 @@ impl<'db> ClassLiteral<'db> { { // In fixed-point iteration of type inference, the attribute type must be monotonically widened and not "oscillate". // Here, monotonicity is guaranteed by pre-unioning the type of the previous iteration into the current result. - union_of_inferred_types = - union_of_inferred_types.add(previous_cycle_type.normalized_impl(db, &visitor)); + union_of_inferred_types = union_of_inferred_types + .add(previous_cycle_type.recursive_type_normalized(db, &visitor)); } for (attribute_assignments, method_scope_id) in @@ -3062,28 +3077,26 @@ impl<'db> ClassLiteral<'db> { // (.., self.name, ..) = // [.., self.name, ..] = - let unpacked = - infer_unpack_implicit_attribute_types(db, unpack, div); + let unpacked = infer_unpack_types(db, unpack); let inferred_ty = unpacked.expression_type(assign.target(&module)); union_of_inferred_types = union_of_inferred_types - .add(inferred_ty.normalized_impl(db, &visitor)); + .add(inferred_ty.recursive_type_normalized(db, &visitor)); } TargetKind::Single => { // We found an un-annotated attribute assignment of the form: // // self.name = - let inferred_ty = infer_implicit_attribute_expression_type( + let inferred_ty = infer_expression_type( db, index.expression(assign.value(&module)), TypeContext::default(), - div, ); union_of_inferred_types = union_of_inferred_types - .add(inferred_ty.normalized_impl(db, &visitor)); + .add(inferred_ty.recursive_type_normalized(db, &visitor)); } } } @@ -3094,31 +3107,29 @@ impl<'db> ClassLiteral<'db> { // // for .., self.name, .. in : - let unpacked = - infer_unpack_implicit_attribute_types(db, unpack, div); + let unpacked = infer_unpack_types(db, unpack); let inferred_ty = unpacked.expression_type(for_stmt.target(&module)); union_of_inferred_types = union_of_inferred_types - .add(inferred_ty.normalized_impl(db, &visitor)); + .add(inferred_ty.recursive_type_normalized(db, &visitor)); } TargetKind::Single => { // We found an attribute assignment like: // // for self.name in : - let iterable_ty = infer_implicit_attribute_expression_type( + let iterable_ty = infer_expression_type( db, index.expression(for_stmt.iterable(&module)), TypeContext::default(), - div, ); // TODO: Potential diagnostics resulting from the iterable are currently not reported. let inferred_ty = iterable_ty.iterate(db).homogeneous_element_type(db); union_of_inferred_types = union_of_inferred_types - .add(inferred_ty.normalized_impl(db, &visitor)); + .add(inferred_ty.recursive_type_normalized(db, &visitor)); } } } @@ -3129,24 +3140,22 @@ impl<'db> ClassLiteral<'db> { // // with as .., self.name, ..: - let unpacked = - infer_unpack_implicit_attribute_types(db, unpack, div); + let unpacked = infer_unpack_types(db, unpack); let inferred_ty = unpacked.expression_type(with_item.target(&module)); union_of_inferred_types = union_of_inferred_types - .add(inferred_ty.normalized_impl(db, &visitor)); + .add(inferred_ty.recursive_type_normalized(db, &visitor)); } TargetKind::Single => { // We found an attribute assignment like: // // with as self.name: - let context_ty = infer_implicit_attribute_expression_type( + let context_ty = infer_expression_type( db, index.expression(with_item.context_expr(&module)), TypeContext::default(), - div, ); let inferred_ty = if with_item.is_async() { context_ty.aenter(db) @@ -3155,7 +3164,7 @@ impl<'db> ClassLiteral<'db> { }; union_of_inferred_types = union_of_inferred_types - .add(inferred_ty.normalized_impl(db, &visitor)); + .add(inferred_ty.recursive_type_normalized(db, &visitor)); } } } @@ -3166,32 +3175,30 @@ impl<'db> ClassLiteral<'db> { // // [... for .., self.name, .. in ] - let unpacked = - infer_unpack_implicit_attribute_types(db, unpack, div); + let unpacked = infer_unpack_types(db, unpack); let inferred_ty = unpacked.expression_type(comprehension.target(&module)); union_of_inferred_types = union_of_inferred_types - .add(inferred_ty.normalized_impl(db, &visitor)); + .add(inferred_ty.recursive_type_normalized(db, &visitor)); } TargetKind::Single => { // We found an attribute assignment like: // // [... for self.name in ] - let iterable_ty = infer_implicit_attribute_expression_type( + let iterable_ty = infer_expression_type( db, index.expression(comprehension.iterable(&module)), TypeContext::default(), - div, ); // TODO: Potential diagnostics resulting from the iterable are currently not reported. let inferred_ty = iterable_ty.iterate(db).homogeneous_element_type(db); union_of_inferred_types = union_of_inferred_types - .add(inferred_ty.normalized_impl(db, &visitor)); + .add(inferred_ty.recursive_type_normalized(db, &visitor)); } } } diff --git a/crates/ty_python_semantic/src/types/class_base.rs b/crates/ty_python_semantic/src/types/class_base.rs index 94ae00c1cdd49..7d0fbed325895 100644 --- a/crates/ty_python_semantic/src/types/class_base.rs +++ b/crates/ty_python_semantic/src/types/class_base.rs @@ -37,12 +37,26 @@ impl<'db> ClassBase<'db> { pub(crate) fn normalized_impl(self, db: &'db dyn Db, visitor: &NormalizedVisitor<'db>) -> Self { match self { - Self::Dynamic(dynamic) => Self::Dynamic(dynamic.normalized(visitor.is_recursive())), + Self::Dynamic(dynamic) => Self::Dynamic(dynamic.normalized_impl(visitor.kind)), Self::Class(class) => Self::Class(class.normalized_impl(db, visitor)), Self::Protocol | Self::Generic | Self::TypedDict => self, } } + pub(super) fn recursive_type_normalized( + self, + db: &'db dyn Db, + visitor: &NormalizedVisitor<'db>, + ) -> Self { + match self { + Self::Dynamic(dynamic) => { + Self::Dynamic(dynamic.recursive_type_normalized(visitor.kind)) + } + Self::Class(class) => Self::Class(class.recursive_type_normalized(db, visitor)), + Self::Protocol | Self::Generic | Self::TypedDict => self, + } + } + pub(crate) fn name(self, db: &'db dyn Db) -> &'db str { match self { ClassBase::Class(class) => class.name(db), diff --git a/crates/ty_python_semantic/src/types/context.rs b/crates/ty_python_semantic/src/types/context.rs index 41644b3860f44..d13d99a75ed76 100644 --- a/crates/ty_python_semantic/src/types/context.rs +++ b/crates/ty_python_semantic/src/types/context.rs @@ -98,10 +98,6 @@ impl<'db, 'ast> InferContext<'db, 'ast> { self.diagnostics.get_mut().extend(other); } - pub(crate) fn take_diagnostics(&mut self) -> TypeCheckDiagnostics { - self.diagnostics.take() - } - /// Optionally return a builder for a lint diagnostic guard. /// /// If the current context believes a diagnostic should be reported for diff --git a/crates/ty_python_semantic/src/types/function.rs b/crates/ty_python_semantic/src/types/function.rs index 084b84a572c4b..b1d20d42aef7d 100644 --- a/crates/ty_python_semantic/src/types/function.rs +++ b/crates/ty_python_semantic/src/types/function.rs @@ -73,13 +73,13 @@ use crate::types::diagnostic::{ report_runtime_check_against_non_runtime_checkable_protocol, }; use crate::types::generics::GenericContext; -use crate::types::infer::infer_function_scope_types; +use crate::types::infer::infer_scope_types; use crate::types::narrow::ClassInfoConstraintFunction; use crate::types::signatures::{CallableSignature, Signature}; use crate::types::visitor::any_over_type; use crate::types::{ BoundMethodType, BoundTypeVarInstance, CallableType, ClassBase, ClassLiteral, ClassType, - DeprecatedInstance, DivergentType, DynamicType, FindLegacyTypeVarsVisitor, + DeprecatedInstance, DivergenceKind, DivergentType, DynamicType, FindLegacyTypeVarsVisitor, HasRelationToVisitor, IsEquivalentVisitor, KnownClass, KnownInstanceType, NormalizedVisitor, SpecialFormType, TrackedConstraintSet, Truthiness, Type, TypeMapping, TypeRelation, UnionBuilder, all_members, binding_type, todo_type, walk_generic_context, walk_type_mapping, @@ -96,9 +96,9 @@ fn return_type_cycle_recover<'db>( } fn return_type_cycle_initial<'db>(db: &'db dyn Db, function: FunctionType<'db>) -> Type<'db> { - Type::Dynamic(DynamicType::Divergent(DivergentType::function_return_type( + Type::Dynamic(DynamicType::Divergent(DivergentType::new( db, - function.literal(db).last_definition(db).body_scope(db), + DivergenceKind::InferReturnType(function.literal(db).last_definition(db).body_scope(db)), ))) } @@ -699,6 +699,13 @@ impl<'db> FunctionLiteral<'db> { .map(|ctx| ctx.normalized_impl(db, visitor)); Self::new(db, self.last_definition(db), context) } + + fn recursive_type_normalized(self, db: &'db dyn Db, visitor: &NormalizedVisitor<'db>) -> Self { + let context = self + .inherited_generic_context(db) + .map(|ctx| ctx.recursive_type_normalized(db, visitor)); + Self::new(db, self.last_definition(db), context) + } } /// Represents a function type, which might be a non-generic function, or a specialization of a @@ -1042,12 +1049,28 @@ impl<'db> FunctionType<'db> { Self::new(db, self.literal(db).normalized_impl(db, visitor), mappings) } + pub(super) fn recursive_type_normalized( + self, + db: &'db dyn Db, + visitor: &NormalizedVisitor<'db>, + ) -> Self { + let mappings: Box<_> = self + .type_mappings(db) + .iter() + .map(|mapping| mapping.recursive_type_normalized(db, visitor)) + .collect(); + Self::new( + db, + self.literal(db).recursive_type_normalized(db, visitor), + mappings, + ) + } + /// Infers this function scope's types and returns the inferred return type. #[salsa::tracked(cycle_fn=return_type_cycle_recover, cycle_initial=return_type_cycle_initial, heap_size=get_size2::heap_size)] pub(crate) fn infer_return_type(self, db: &'db dyn Db) -> Type<'db> { let scope = self.literal(db).last_definition(db).body_scope(db); - let inference = - infer_function_scope_types(db, scope, DivergentType::function_return_type(db, scope)); + let inference = infer_scope_types(db, scope); inference.infer_return_type(db, Type::FunctionLiteral(self)) } } diff --git a/crates/ty_python_semantic/src/types/generics.rs b/crates/ty_python_semantic/src/types/generics.rs index 9b87da7fd951d..3207ad64afe38 100644 --- a/crates/ty_python_semantic/src/types/generics.rs +++ b/crates/ty_python_semantic/src/types/generics.rs @@ -399,6 +399,19 @@ impl<'db> GenericContext<'db> { Self::from_typevar_instances(db, variables) } + pub(super) fn recursive_type_normalized( + self, + db: &'db dyn Db, + visitor: &NormalizedVisitor<'db>, + ) -> Self { + let variables = self + .variables(db) + .iter() + .map(|bound_typevar| bound_typevar.recursive_type_normalized(db, visitor)); + + Self::from_typevar_instances(db, variables) + } + fn heap_size((variables,): &(FxOrderSet>,)) -> usize { ruff_memory_usage::order_set_heap_size(variables) } @@ -729,6 +742,31 @@ impl<'db> Specialization<'db> { ) } + pub(super) fn recursive_type_normalized( + self, + db: &'db dyn Db, + visitor: &NormalizedVisitor<'db>, + ) -> Self { + let types: Box<[_]> = self + .types(db) + .iter() + .map(|ty| ty.recursive_type_normalized(db, visitor)) + .collect(); + let tuple_inner = self + .tuple_inner(db) + .map(|tuple| tuple.recursive_type_normalized(db, visitor)); + let context = self + .generic_context(db) + .recursive_type_normalized(db, visitor); + Self::new( + db, + context, + types, + self.materialization_kind(db), + tuple_inner, + ) + } + pub(super) fn materialize_impl( self, db: &'db dyn Db, @@ -988,6 +1026,24 @@ impl<'db> PartialSpecialization<'_, 'db> { types, } } + + pub(super) fn recursive_type_normalized( + &self, + db: &'db dyn Db, + visitor: &NormalizedVisitor<'db>, + ) -> PartialSpecialization<'db, 'db> { + let generic_context = self.generic_context.recursive_type_normalized(db, visitor); + let types: Cow<_> = self + .types + .iter() + .map(|ty| ty.recursive_type_normalized(db, visitor)) + .collect(); + + PartialSpecialization { + generic_context, + types, + } + } } /// Performs type inference between parameter annotations and argument types, producing a diff --git a/crates/ty_python_semantic/src/types/infer.rs b/crates/ty_python_semantic/src/types/infer.rs index 0e513c2069b2a..29ed1a02dcc09 100644 --- a/crates/ty_python_semantic/src/types/infer.rs +++ b/crates/ty_python_semantic/src/types/infer.rs @@ -51,8 +51,8 @@ use crate::semantic_index::{SemanticIndex, semantic_index, use_def_map}; use crate::types::diagnostic::TypeCheckDiagnostics; use crate::types::unpacker::{UnpackResult, Unpacker}; use crate::types::{ - ClassLiteral, DivergentType, DynamicType, NormalizedVisitor, Truthiness, Type, - TypeAndQualifiers, UnionBuilder, + ClassLiteral, CycleRecoveryType, DivergenceKind, DivergentType, NormalizedVisitor, Truthiness, + Type, TypeAndQualifiers, UnionBuilder, UnionType, }; use crate::unpack::Unpack; use builder::TypeInferenceBuilder; @@ -61,40 +61,12 @@ mod builder; #[cfg(test)] mod tests; -/// A scope that may be recursive. -#[salsa::interned(debug, heap_size=ruff_memory_usage::heap_size)] -pub(super) struct PossiblyRecursiveScope<'db> { - scope: ScopeId<'db>, - cycle_recovery: DivergentType<'db>, -} - /// Infer all types for a [`ScopeId`], including all definitions and expressions in that scope. /// Use when checking a scope, or needing to provide a type for an arbitrary expression in the /// scope. -pub(crate) fn infer_scope_types<'db>( - db: &'db dyn Db, - scope: ScopeId<'db>, -) -> &'db ScopeInference<'db> { - infer_scope_types_impl( - db, - PossiblyRecursiveScope::new(db, scope, DivergentType::should_not_diverge(db, scope)), - ) -} - -pub(crate) fn infer_function_scope_types<'db>( - db: &'db dyn Db, - scope: ScopeId<'db>, - cycle_recovery: DivergentType<'db>, -) -> &'db ScopeInference<'db> { - infer_scope_types_impl(db, PossiblyRecursiveScope::new(db, scope, cycle_recovery)) -} - #[salsa::tracked(returns(ref), cycle_fn=scope_cycle_recover, cycle_initial=scope_cycle_initial, heap_size=ruff_memory_usage::heap_size)] -fn infer_scope_types_impl<'db>( - db: &'db dyn Db, - scope: PossiblyRecursiveScope<'db>, -) -> ScopeInference<'db> { - let file = scope.scope(db).file(db); +pub(crate) fn infer_scope_types<'db>(db: &'db dyn Db, scope: ScopeId<'db>) -> ScopeInference<'db> { + let file = scope.file(db); let _span = tracing::trace_span!("infer_scope_types", scope=?scope.as_id(), ?file).entered(); let module = parsed_module(db, file).load(db); @@ -103,90 +75,69 @@ fn infer_scope_types_impl<'db>( // The isolation of the query is by the return inferred types. let index = semantic_index(db, file); - TypeInferenceBuilder::new(db, InferenceRegion::Scope(scope.scope(db)), index, &module) - .finish_scope() + TypeInferenceBuilder::new(db, InferenceRegion::Scope(scope), index, &module).finish_scope() } fn scope_cycle_recover<'db>( _db: &'db dyn Db, _value: &ScopeInference<'db>, _count: u32, - _scope: PossiblyRecursiveScope<'db>, + _scope: ScopeId<'db>, ) -> salsa::CycleRecoveryAction> { salsa::CycleRecoveryAction::Iterate } -fn scope_cycle_initial<'db>( - db: &'db dyn Db, - scope: PossiblyRecursiveScope<'db>, -) -> ScopeInference<'db> { - ScopeInference::cycle_initial(scope.cycle_recovery(db), scope.scope(db)) -} - -/// A [`Definition`] that may be recursive. -#[salsa::interned(debug, heap_size=ruff_memory_usage::heap_size)] -pub(super) struct PossiblyRecursiveDefinition<'db> { - definition: Definition<'db>, - cycle_recovery: DivergentType<'db>, +fn scope_cycle_initial<'db>(db: &'db dyn Db, scope: ScopeId<'db>) -> ScopeInference<'db> { + ScopeInference::cycle_initial( + Type::divergent(DivergentType::new( + db, + DivergenceKind::InferScopeTypes(scope), + )), + scope, + ) } /// Infer all types for a [`Definition`] (including sub-expressions). /// Use when resolving a place use or public type of a place. +#[salsa::tracked(returns(ref), cycle_fn=definition_cycle_recover, cycle_initial=definition_cycle_initial, heap_size=ruff_memory_usage::heap_size)] pub(crate) fn infer_definition_types<'db>( db: &'db dyn Db, definition: Definition<'db>, -) -> &'db DefinitionInference<'db> { - infer_definition_types_impl( - db, - PossiblyRecursiveDefinition::new( - db, - definition, - DivergentType::should_not_diverge(db, definition.scope(db)), - ), - ) -} - -#[salsa::tracked(returns(ref), cycle_fn=definition_cycle_recover, cycle_initial=definition_cycle_initial, heap_size=ruff_memory_usage::heap_size)] -fn infer_definition_types_impl<'db>( - db: &'db dyn Db, - definition: PossiblyRecursiveDefinition<'db>, ) -> DefinitionInference<'db> { - let file = definition.definition(db).file(db); + let file = definition.file(db); let module = parsed_module(db, file).load(db); let _span = tracing::trace_span!( "infer_definition_types", - range = ?definition.definition(db).kind(db).target_range(&module), + range = ?definition.kind(db).target_range(&module), ?file ) .entered(); let index = semantic_index(db, file); - TypeInferenceBuilder::new( - db, - InferenceRegion::Definition(definition.definition(db)), - index, - &module, - ) - .finish_definition() + TypeInferenceBuilder::new(db, InferenceRegion::Definition(definition), index, &module) + .finish_definition() } fn definition_cycle_recover<'db>( _db: &'db dyn Db, _value: &DefinitionInference<'db>, _count: u32, - _definition: PossiblyRecursiveDefinition<'db>, + _definition: Definition<'db>, ) -> salsa::CycleRecoveryAction> { salsa::CycleRecoveryAction::Iterate } fn definition_cycle_initial<'db>( db: &'db dyn Db, - definition: PossiblyRecursiveDefinition<'db>, + definition: Definition<'db>, ) -> DefinitionInference<'db> { DefinitionInference::cycle_initial( - definition.definition(db).scope(db), - definition.cycle_recovery(db), + definition.scope(db), + Type::divergent(DivergentType::new( + db, + DivergenceKind::InferDefinitionTypes(definition), + )), ) } @@ -228,10 +179,7 @@ fn deferred_cycle_initial<'db>( db: &'db dyn Db, definition: Definition<'db>, ) -> DefinitionInference<'db> { - DefinitionInference::cycle_initial( - definition.scope(db), - DivergentType::should_not_diverge(db, definition.scope(db)), - ) + DefinitionInference::cycle_initial(definition.scope(db), Type::Never) } /// Infer all types for an [`Expression`] (including sub-expressions). @@ -242,12 +190,8 @@ pub(crate) fn infer_expression_types<'db>( db: &'db dyn Db, expression: Expression<'db>, tcx: TypeContext<'db>, - cycle_recovery: DivergentType<'db>, ) -> &'db ExpressionInference<'db> { - infer_expression_types_impl( - db, - InferExpression::new(db, expression, tcx, cycle_recovery), - ) + infer_expression_types_impl(db, InferExpression::new(db, expression, tcx)) } #[salsa::tracked(returns(ref), cycle_fn=expression_cycle_recover, cycle_initial=expression_cycle_initial, heap_size=ruff_memory_usage::heap_size)] @@ -275,7 +219,7 @@ pub(super) fn infer_expression_types_impl<'db>( index, &module, ) - .finish_expression() + .finish_expression(input) } fn expression_cycle_recover<'db>( @@ -291,7 +235,11 @@ fn expression_cycle_initial<'db>( db: &'db dyn Db, input: InferExpression<'db>, ) -> ExpressionInference<'db> { - ExpressionInference::cycle_initial(input.expression(db).scope(db), input.cycle_recovery(db)) + let cycle_recovery = Type::divergent(DivergentType::new( + db, + DivergenceKind::InferExpressionTypes(input), + )); + ExpressionInference::cycle_initial(input.expression(db).scope(db), cycle_recovery) } /// Infers the type of an `expression` that is guaranteed to be in the same file as the calling query. @@ -303,10 +251,9 @@ pub(super) fn infer_same_file_expression_type<'db>( db: &'db dyn Db, expression: Expression<'db>, tcx: TypeContext<'db>, - cycle_recovery: DivergentType<'db>, parsed: &ParsedModuleRef, ) -> Type<'db> { - let inference = infer_expression_types(db, expression, tcx, cycle_recovery); + let inference = infer_expression_types(db, expression, tcx); inference.expression_type(expression.node_ref(db, parsed)) } @@ -322,27 +269,7 @@ pub(crate) fn infer_expression_type<'db>( expression: Expression<'db>, tcx: TypeContext<'db>, ) -> Type<'db> { - infer_expression_type_impl( - db, - InferExpression::new( - db, - expression, - tcx, - DivergentType::should_not_diverge(db, expression.scope(db)), - ), - ) -} - -pub(crate) fn infer_implicit_attribute_expression_type<'db>( - db: &'db dyn Db, - expression: Expression<'db>, - tcx: TypeContext<'db>, - cycle_recovery: DivergentType<'db>, -) -> Type<'db> { - infer_expression_type_impl( - db, - InferExpression::new(db, expression, tcx, cycle_recovery), - ) + infer_expression_type_impl(db, InferExpression::new(db, expression, tcx)) } #[salsa::tracked(cycle_fn=single_expression_cycle_recover, cycle_initial=single_expression_cycle_initial, heap_size=ruff_memory_usage::heap_size)] @@ -352,7 +279,16 @@ fn infer_expression_type_impl<'db>(db: &'db dyn Db, input: InferExpression<'db>) // It's okay to call the "same file" version here because we're inside a salsa query. let inference = infer_expression_types_impl(db, input); - inference.expression_type(input.expression(db).node_ref(db, &module)) + + let div = Type::divergent(DivergentType::new( + db, + DivergenceKind::InferExpression(input), + )); + let previous_cycle_value = infer_expression_type_impl(db, input); + let visitor = NormalizedVisitor::default().recursive(div); + let result_ty = inference.expression_type(input.expression(db).node_ref(db, &module)); + UnionType::from_elements(db, [result_ty, previous_cycle_value]) + .recursive_type_normalized(db, &visitor) } fn single_expression_cycle_recover<'db>( @@ -365,17 +301,32 @@ fn single_expression_cycle_recover<'db>( } fn single_expression_cycle_initial<'db>(db: &'db dyn Db, input: InferExpression<'db>) -> Type<'db> { - Type::Dynamic(DynamicType::Divergent(input.cycle_recovery(db))) + Type::divergent(DivergentType::new( + db, + DivergenceKind::InferExpression(input), + )) } /// An `Expression` with an optional `TypeContext`. /// /// This is a Salsa supertype used as the input to `infer_expression_types` to avoid /// interning an `ExpressionWithContext` unnecessarily when no type context is provided. -#[derive(Debug, Clone, Copy, Eq, Hash, PartialEq, salsa::Supertype, salsa::Update)] +#[derive( + Debug, + Clone, + Copy, + Eq, + Hash, + PartialEq, + PartialOrd, + Ord, + salsa::Supertype, + salsa::Update, + get_size2::GetSize, +)] pub(super) enum InferExpression<'db> { - Bare(PossiblyRecursiveExpression<'db>), - WithContext(PossiblyRecursiveExpressionWithContext<'db>), + Bare(Expression<'db>), + WithContext(ExpressionWithContext<'db>), } impl<'db> InferExpression<'db> { @@ -383,41 +334,18 @@ impl<'db> InferExpression<'db> { db: &'db dyn Db, expression: Expression<'db>, tcx: TypeContext<'db>, - cycle_recovery: DivergentType<'db>, ) -> InferExpression<'db> { if tcx.annotation.is_some() { - InferExpression::WithContext(PossiblyRecursiveExpressionWithContext::new( - db, - expression, - tcx, - cycle_recovery, - )) + InferExpression::WithContext(ExpressionWithContext::new(db, expression, tcx)) } else { // Drop the empty `TypeContext` to avoid the interning cost. - InferExpression::Bare(PossiblyRecursiveExpression::new( - db, - expression, - cycle_recovery, - )) + InferExpression::Bare(expression) } } - #[cfg(test)] - pub(super) fn bare( - db: &'db dyn Db, - expression: Expression<'db>, - cycle_recovery: DivergentType<'db>, - ) -> InferExpression<'db> { - InferExpression::Bare(PossiblyRecursiveExpression::new( - db, - expression, - cycle_recovery, - )) - } - fn expression(self, db: &'db dyn Db) -> Expression<'db> { match self { - InferExpression::Bare(bare) => bare.expression(db), + InferExpression::Bare(bare) => bare, InferExpression::WithContext(expression_with_context) => { expression_with_context.expression(db) } @@ -432,32 +360,19 @@ impl<'db> InferExpression<'db> { } } } - - fn cycle_recovery(self, db: &'db dyn Db) -> DivergentType<'db> { - match self { - InferExpression::Bare(bare) => bare.cycle_recovery(db), - InferExpression::WithContext(expression_with_context) => { - expression_with_context.cycle_recovery(db) - } - } - } -} - -/// An [`Expression`] that may be recursive. -#[salsa::interned(debug, heap_size=ruff_memory_usage::heap_size)] -pub(super) struct PossiblyRecursiveExpression<'db> { - expression: Expression<'db>, - cycle_recovery: DivergentType<'db>, } -/// An [`Expression`] with a [`TypeContext`], that may be recursive. +/// An [`Expression`] with a [`TypeContext`]. #[salsa::interned(debug, heap_size=ruff_memory_usage::heap_size)] -pub(super) struct PossiblyRecursiveExpressionWithContext<'db> { +#[derive(PartialOrd, Ord)] +pub(super) struct ExpressionWithContext<'db> { expression: Expression<'db>, tcx: TypeContext<'db>, - cycle_recovery: DivergentType<'db>, } +/// The Salsa heap is tracked separately. +impl get_size2::GetSize for ExpressionWithContext<'_> {} + /// The type context for a given expression, namely the type annotation /// in an annotated assignment. /// @@ -485,14 +400,7 @@ pub(crate) fn static_expression_truthiness<'db>( db: &'db dyn Db, expression: Expression<'db>, ) -> Truthiness { - let inference = infer_expression_types_impl( - db, - InferExpression::Bare(PossiblyRecursiveExpression::new( - db, - expression, - DivergentType::should_not_diverge(db, expression.scope(db)), - )), - ); + let inference = infer_expression_types_impl(db, InferExpression::Bare(expression)); if !inference.all_places_definitely_bound() { return Truthiness::Ambiguous; @@ -501,7 +409,6 @@ pub(crate) fn static_expression_truthiness<'db>( let file = expression.file(db); let module = parsed_module(db, file).load(db); let node = expression.node_ref(db, &module); - inference.expression_type(node).bool(db) } @@ -522,25 +429,6 @@ fn static_expression_truthiness_cycle_initial<'db>( Truthiness::Ambiguous } -pub(super) fn infer_unpack_implicit_attribute_types<'db>( - db: &'db dyn Db, - unpack: Unpack<'db>, - cycle_recovery: DivergentType<'db>, -) -> &'db UnpackResult<'db> { - infer_unpack_types_impl(db, unpack, cycle_recovery) -} - -pub(super) fn infer_unpack_types<'db>( - db: &'db dyn Db, - unpack: Unpack<'db>, -) -> &'db UnpackResult<'db> { - infer_unpack_types_impl( - db, - unpack, - DivergentType::should_not_diverge(db, unpack.target_scope(db)), - ) -} - /// Infer the types for an [`Unpack`] operation. /// /// This infers the expression type and performs structural match against the target expression @@ -548,19 +436,15 @@ pub(super) fn infer_unpack_types<'db>( /// type of the variables involved in this unpacking along with any violations that are detected /// during this unpacking. #[salsa::tracked(returns(ref), cycle_fn=unpack_cycle_recover, cycle_initial=unpack_cycle_initial, heap_size=ruff_memory_usage::heap_size)] -fn infer_unpack_types_impl<'db>( - db: &'db dyn Db, - unpack: Unpack<'db>, - cycle_recovery: DivergentType<'db>, -) -> UnpackResult<'db> { +pub(super) fn infer_unpack_types<'db>(db: &'db dyn Db, unpack: Unpack<'db>) -> UnpackResult<'db> { let file = unpack.file(db); let module = parsed_module(db, file).load(db); let _span = tracing::trace_span!("infer_unpack_types", range=?unpack.range(db, &module), ?file) .entered(); let mut unpacker = Unpacker::new(db, unpack.target_scope(db), &module); - unpacker.unpack(unpack.target(db, &module), unpack.value(db), cycle_recovery); - unpacker.finish() + unpacker.unpack(unpack.target(db, &module), unpack.value(db)); + unpacker.finish(unpack) } fn unpack_cycle_recover<'db>( @@ -568,17 +452,16 @@ fn unpack_cycle_recover<'db>( _value: &UnpackResult<'db>, _count: u32, _unpack: Unpack<'db>, - _cycle_recovery: DivergentType<'db>, ) -> salsa::CycleRecoveryAction> { salsa::CycleRecoveryAction::Iterate } -fn unpack_cycle_initial<'db>( - _db: &'db dyn Db, - _unpack: Unpack<'db>, - cycle_recovery: DivergentType<'db>, -) -> UnpackResult<'db> { - UnpackResult::cycle_initial(Type::divergent(cycle_recovery)) +fn unpack_cycle_initial<'db>(db: &'db dyn Db, unpack: Unpack<'db>) -> UnpackResult<'db> { + let cycle_recovery = Type::divergent(DivergentType::new( + db, + DivergenceKind::InferUnpackTypes(unpack), + )); + UnpackResult::cycle_initial(cycle_recovery) } /// Returns the type of the nearest enclosing class for the given scope. @@ -647,7 +530,7 @@ pub(crate) struct ScopeInference<'db> { #[derive(Debug, Eq, PartialEq, get_size2::GetSize, salsa::Update, Default)] struct ScopeInferenceExtra<'db> { /// The fallback type for missing expressions/bindings/declarations or recursive type inference. - cycle_recovery: Option>, + cycle_recovery: Option>, /// The diagnostics for this region. diagnostics: TypeCheckDiagnostics, @@ -659,7 +542,7 @@ struct ScopeInferenceExtra<'db> { } impl<'db> ScopeInference<'db> { - fn cycle_initial(cycle_recovery: DivergentType<'db>, scope: ScopeId<'db>) -> Self { + fn cycle_initial(cycle_recovery: CycleRecoveryType<'db>, scope: ScopeId<'db>) -> Self { Self { extra: Some(Box::new(ScopeInferenceExtra { cycle_recovery: Some(cycle_recovery), @@ -690,9 +573,7 @@ impl<'db> ScopeInference<'db> { } fn fallback_type(&self) -> Option> { - self.extra - .as_ref() - .and_then(|extra| extra.cycle_recovery.map(Type::divergent)) + self.extra.as_ref().and_then(|extra| extra.cycle_recovery) } /// Returns the inferred return type of this function body (union of all possible return types), @@ -707,9 +588,12 @@ impl<'db> ScopeInference<'db> { } let mut union = UnionBuilder::new(db); - let div = Type::divergent(DivergentType::function_return_type(db, self.scope)); - if let Some(fallback_type) = self.fallback_type() { - union = union.add(fallback_type); + let div = Type::divergent(DivergentType::new( + db, + DivergenceKind::InferReturnType(self.scope), + )); + if let Some(fall_back) = self.fallback_type() { + union = union.add(fall_back); } let visitor = NormalizedVisitor::default().recursive(div); // Here, we use the dynamic type `Divergent` to detect divergent type inference and ensure that we obtain finite results. @@ -729,10 +613,10 @@ impl<'db> ScopeInference<'db> { // 0th: Divergent // 1st: tuple[Divergent] | None // 2nd: tuple[tuple[Divergent] | None] | None => tuple[Divergent] | None - let previous_type = callee_ty.infer_return_type(db).unwrap(); + let previous_cycle_value = callee_ty.infer_return_type(db).unwrap(); // In fixed-point iteration of return type inference, the return type must be monotonically widened and not "oscillate". // Here, monotonicity is guaranteed by pre-unioning the type of the previous iteration into the current result. - union = union.add(previous_type.normalized_impl(db, &visitor)); + union = union.add(previous_cycle_value.recursive_type_normalized(db, &visitor)); let Some(extra) = &self.extra else { unreachable!( @@ -743,7 +627,7 @@ impl<'db> ScopeInference<'db> { let ty = returnee.map_or(Type::none(db), |expression| { self.expression_type(expression) }); - union = union.add(ty.normalized_impl(db, &visitor)); + union = union.add(ty.recursive_type_normalized(db, &visitor)); } let use_def = use_def_map(db, self.scope); if use_def.can_implicitly_return_none(db) { @@ -758,7 +642,7 @@ impl<'db> ScopeInference<'db> { method_ty .base_return_type(db) .unwrap_or(Type::unknown()) - .normalized_impl(db, &visitor), + .recursive_type_normalized(db, &visitor), ); } } @@ -797,7 +681,7 @@ pub(crate) struct DefinitionInference<'db> { #[derive(Debug, Eq, PartialEq, get_size2::GetSize, salsa::Update, Default)] struct DefinitionInferenceExtra<'db> { /// The fallback type for missing expressions/bindings/declarations or recursive type inference. - cycle_recovery: Option>, + cycle_recovery: Option>, /// The definitions that are deferred. deferred: Box<[Definition<'db>]>, @@ -810,7 +694,7 @@ struct DefinitionInferenceExtra<'db> { } impl<'db> DefinitionInference<'db> { - fn cycle_initial(scope: ScopeId<'db>, cycle_recovery: DivergentType<'db>) -> Self { + fn cycle_initial(scope: ScopeId<'db>, cycle_recovery: CycleRecoveryType<'db>) -> Self { let _ = scope; Self { @@ -890,9 +774,7 @@ impl<'db> DefinitionInference<'db> { } fn fallback_type(&self) -> Option> { - self.extra - .as_ref() - .and_then(|extra| extra.cycle_recovery.map(Type::divergent)) + self.extra.as_ref().and_then(|extra| extra.cycle_recovery) } pub(crate) fn undecorated_type(&self) -> Option> { @@ -925,14 +807,14 @@ struct ExpressionInferenceExtra<'db> { diagnostics: TypeCheckDiagnostics, /// The fallback type for missing expressions/bindings/declarations or recursive type inference. - cycle_recovery: Option>, + cycle_recovery: Option>, /// `true` if all places in this expression are definitely bound all_definitely_bound: bool, } impl<'db> ExpressionInference<'db> { - fn cycle_initial(scope: ScopeId<'db>, cycle_recovery: DivergentType<'db>) -> Self { + fn cycle_initial(scope: ScopeId<'db>, cycle_recovery: CycleRecoveryType<'db>) -> Self { let _ = scope; Self { extra: Some(Box::new(ExpressionInferenceExtra { @@ -962,9 +844,7 @@ impl<'db> ExpressionInference<'db> { } fn fallback_type(&self) -> Option> { - self.extra - .as_ref() - .and_then(|extra| extra.cycle_recovery.map(Type::divergent)) + self.extra.as_ref().and_then(|extra| extra.cycle_recovery) } /// Returns true if all places in this expression are definitely bound. diff --git a/crates/ty_python_semantic/src/types/infer/builder.rs b/crates/ty_python_semantic/src/types/infer/builder.rs index 62235db235792..48c68c4fc6d96 100644 --- a/crates/ty_python_semantic/src/types/infer/builder.rs +++ b/crates/ty_python_semantic/src/types/infer/builder.rs @@ -40,8 +40,7 @@ use crate::semantic_index::scope::{ }; use crate::semantic_index::symbol::{ScopedSymbolId, Symbol}; use crate::semantic_index::{ - ApplicableConstraints, EnclosingSnapshotResult, SemanticIndex, implicit_attribute_table, - place_table, + ApplicableConstraints, EnclosingSnapshotResult, SemanticIndex, place_table, }; use crate::types::call::{Binding, Bindings, CallArguments, CallError, CallErrorKind}; use crate::types::class::{CodeGeneratorKind, FieldKind, MetaclassErrorKind, MethodDecorator}; @@ -75,6 +74,7 @@ use crate::types::function::{ }; use crate::types::generics::LegacyGenericBase; use crate::types::generics::{GenericContext, bind_typevar}; +use crate::types::infer::infer_expression_types_impl; use crate::types::instance::SliceLiteral; use crate::types::mro::MroErrorKind; use crate::types::signatures::Signature; @@ -86,13 +86,13 @@ use crate::types::typed_dict::{ }; use crate::types::visitor::any_over_type; use crate::types::{ - CallDunderError, CallableType, ClassLiteral, ClassType, DataclassParams, DivergentType, - DynamicType, IntersectionBuilder, IntersectionType, KnownClass, KnownInstanceType, - MemberLookupPolicy, MetaclassCandidate, PEP695TypeAliasType, Parameter, ParameterForm, - Parameters, SpecialFormType, SubclassOfType, TrackedConstraintSet, Truthiness, Type, - TypeAliasType, TypeAndQualifiers, TypeCheckDiagnostics, TypeContext, TypeQualifiers, - TypeVarBoundOrConstraintsEvaluation, TypeVarDefaultEvaluation, TypeVarInstance, TypeVarKind, - UnionBuilder, UnionType, binding_type, todo_type, + CallDunderError, CallableType, ClassLiteral, ClassType, CycleRecoveryType, DataclassParams, + DivergenceKind, DivergentType, DynamicType, InferExpression, IntersectionBuilder, + IntersectionType, KnownClass, KnownInstanceType, MemberLookupPolicy, MetaclassCandidate, + NormalizedVisitor, PEP695TypeAliasType, Parameter, ParameterForm, Parameters, SpecialFormType, + SubclassOfType, TrackedConstraintSet, Truthiness, Type, TypeAliasType, TypeAndQualifiers, + TypeContext, TypeQualifiers, TypeVarBoundOrConstraintsEvaluation, TypeVarDefaultEvaluation, + TypeVarInstance, TypeVarKind, UnionBuilder, UnionType, binding_type, todo_type, }; use crate::types::{ClassBase, add_inferred_python_version_hint_to_diagnostic}; use crate::unpack::{EvaluationMode, UnpackPosition}; @@ -255,24 +255,12 @@ pub(super) struct TypeInferenceBuilder<'db, 'ast> { undecorated_type: Option>, /// The fallback type for missing expressions/bindings/declarations or recursive type inference. - cycle_recovery: Option>, + cycle_recovery: Option>, /// `true` if all places in this expression are definitely bound all_definitely_bound: bool, } -/// Normally, double checking is not allowed in a [`TypeInferenceBuilder`], -/// but it is sometimes necessary for expression inference to detect divergence. -/// Explicit double checking can be achieved by taking a snapshot of the state before the check -/// and then reverting the state using the snapshot after the check. -struct TypeInferenceSnapshot { - diagnostics: TypeCheckDiagnostics, - expression_keys: FxHashSet, - length_of_bindings: usize, - length_of_declarations: usize, - length_of_deferred: usize, -} - impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> { /// How big a string do we build before bailing? /// @@ -309,99 +297,18 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> { } fn fallback_type(&self) -> Option> { - self.cycle_recovery.map(Type::divergent) - } - - fn snapshot(&mut self) -> TypeInferenceSnapshot { - TypeInferenceSnapshot { - diagnostics: self.context.take_diagnostics(), - expression_keys: self.expressions.keys().copied().collect(), - length_of_bindings: self.bindings.len(), - length_of_declarations: self.declarations.len(), - length_of_deferred: self.deferred.len(), - } + self.cycle_recovery } - fn restore(&mut self, snapshot: &TypeInferenceSnapshot) { - self.context.take_diagnostics(); - self.context.extend(&snapshot.diagnostics); - self.expressions - .retain(|k, _| snapshot.expression_keys.contains(k)); - self.bindings.truncate(snapshot.length_of_bindings); - self.declarations.truncate(snapshot.length_of_declarations); - self.deferred.truncate(snapshot.length_of_deferred); - } - - /// If the inference in this expression diverges, what kind of divergence is possible? - fn expression_cycle_recovery(&mut self, expression: &ast::Expr) -> DivergentType<'db> { - let db = self.db(); - match expression { - ast::Expr::Call(call) => { - let snapshot = self.snapshot(); - let callable_ty = - self.try_expression_type(call.func.as_ref()) - .unwrap_or_else(|| { - self.infer_maybe_standalone_expression( - &call.func, - TypeContext::default(), - ) - }); - self.restore(&snapshot); - match callable_ty { - Type::FunctionLiteral(func) => DivergentType::function_return_type( - db, - func.literal(db).last_definition(db).body_scope(db), - ), - Type::BoundMethod(method) => DivergentType::function_return_type( - db, - method - .function(db) - .literal(db) - .last_definition(db) - .body_scope(db), - ), - _ => DivergentType::should_not_diverge(db, self.scope()), + fn merge_cycle_recovery(&mut self, other: Option>) { + if let Some(other) = other { + match self.cycle_recovery { + Some(existing) => { + self.cycle_recovery = + Some(UnionType::from_elements(self.db(), [existing, other])); } - } - ast::Expr::Attribute(attr) => { - let snapshot = self.snapshot(); - let value_ty = self - .try_expression_type(attr.value.as_ref()) - .unwrap_or_else(|| { - self.infer_maybe_standalone_expression(&attr.value, TypeContext::default()) - }); - self.restore(&snapshot); - let body_scope = match value_ty { - Type::NominalInstance(instance) => instance.class_literal(db).body_scope(db), - Type::ClassLiteral(class) => class.body_scope(db), - Type::GenericAlias(generic) => generic.origin(db).body_scope(db), - _ => { - return DivergentType::should_not_diverge(db, self.scope()); - } - }; - let implicit_attribute_table = implicit_attribute_table(db, body_scope); - if let Some(attribute) = implicit_attribute_table.symbol_id(&attr.attr) { - DivergentType::implicit_attribute(db, body_scope, attribute) - } else { - DivergentType::should_not_diverge(db, self.scope()) - } - } - _ => DivergentType::should_not_diverge(db, self.scope()), - } - } - - fn merge_cycle_recovery(&mut self, other: Option>) { - match (self.cycle_recovery, other) { - (None, _) | (Some(_), None) => { - self.cycle_recovery = self.cycle_recovery.or(other); - } - (Some(self_), Some(other)) => { - if self_ == other { - // OK, do nothing - } else if self_.kind(self.db()).should_not_diverge() { + None => { self.cycle_recovery = Some(other); - } else { - panic!("Cannot merge divergent types"); } } } @@ -3996,7 +3903,8 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> { ) => { self.store_expression_type(target, assigned_ty.unwrap_or(Type::unknown())); - let object_ty = self.infer_expression(object, TypeContext::default()); + let object_ty = + self.infer_maybe_standalone_expression(object, TypeContext::default()); if let Some(assigned_ty) = assigned_ty { self.validate_attribute_assignment( @@ -5119,10 +5027,7 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> { standalone_expression: Expression<'db>, tcx: TypeContext<'db>, ) -> Type<'db> { - let cycle_recovery = self - .cycle_recovery - .unwrap_or_else(|| self.expression_cycle_recovery(expression)); - let types = infer_expression_types(self.db(), standalone_expression, tcx, cycle_recovery); + let types = infer_expression_types(self.db(), standalone_expression, tcx); self.extend_expression(types); // Instead of calling `self.expression_type(expr)` after extending here, we get @@ -5565,19 +5470,10 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> { // but only if the target is a name. We should report a diagnostic here if the target isn't a name: // `[... for a.x in not_iterable] if is_first { - // QUESTION: Could there be a case for divergence here? - let cycle_recovery = - builder - .cycle_recovery - .unwrap_or(DivergentType::should_not_diverge( - builder.db(), - builder.scope(), - )); infer_same_file_expression_type( builder.db(), builder.index.expression(iter_expr), TypeContext::default(), - cycle_recovery, builder.module(), ) } else { @@ -5601,16 +5497,7 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> { let mut infer_iterable_type = || { let expression = self.index.expression(iterable); - // QUESTION: Could there be a case for divergence here? - let cycle_recovery = self - .cycle_recovery - .unwrap_or(DivergentType::should_not_diverge(self.db(), self.scope())); - let result = infer_expression_types( - self.db(), - expression, - TypeContext::default(), - cycle_recovery, - ); + let result = infer_expression_types(self.db(), expression, TypeContext::default()); // Two things are different if it's the first comprehension: // (1) We must lookup the `ScopedExpressionId` of the iterable expression in the outer scope, @@ -6866,7 +6753,7 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> { match ctx { ExprContext::Load => self.infer_attribute_load(attribute), ExprContext::Store => { - self.infer_expression(value, TypeContext::default()); + self.infer_maybe_standalone_expression(value, TypeContext::default()); Type::Never } ExprContext::Del => { @@ -6874,7 +6761,7 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> { Type::Never } ExprContext::Invalid => { - self.infer_expression(value, TypeContext::default()); + self.infer_maybe_standalone_expression(value, TypeContext::default()); Type::unknown() } } @@ -8988,14 +8875,18 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> { } } - pub(super) fn finish_expression(mut self) -> ExpressionInference<'db> { + pub(super) fn finish_expression( + mut self, + input: InferExpression<'db>, + ) -> ExpressionInference<'db> { self.infer_region(); + let db = self.db(); let Self { context, mut expressions, scope, - bindings, + mut bindings, declarations, deferred, cycle_recovery, @@ -9025,6 +8916,12 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> { "Expression region can't have deferred types" ); + let div = Type::divergent(DivergentType::new( + db, + DivergenceKind::InferExpressionTypes(input), + )); + let visitor = NormalizedVisitor::default().recursive(div); + let previous_cycle_value = infer_expression_types_impl(db, input); let extra = (cycle_recovery.is_some() || !bindings.is_empty() || !diagnostics.is_empty() || !all_definitely_bound).then(|| { if bindings.len() > 20 { @@ -9034,6 +8931,17 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> { bindings.len() ); } + for (binding, binding_ty) in bindings.iter_mut() { + if let Some((_, previous_binding)) = previous_cycle_value.extra.as_deref() + .and_then(|extra| extra.bindings.iter().find(|(previous_binding, _)| previous_binding == binding)) { + *binding_ty = UnionType::from_elements( + db, + [*binding_ty, *previous_binding], + ).recursive_type_normalized(db, &visitor); + } else { + *binding_ty = binding_ty.recursive_type_normalized(db, &visitor); + } + } Box::new(ExpressionInferenceExtra { bindings: bindings.into_boxed_slice(), @@ -9044,6 +8952,11 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> { }); expressions.shrink_to_fit(); + for (expr, ty) in &mut expressions { + let previous_ty = previous_cycle_value.expression_type(*expr); + *ty = UnionType::from_elements(db, [*ty, previous_ty]) + .recursive_type_normalized(db, &visitor); + } ExpressionInference { expressions, @@ -9055,26 +8968,33 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> { pub(super) fn finish_definition(mut self) -> DefinitionInference<'db> { self.infer_region(); + let db = self.db(); let Self { context, mut expressions, scope, - bindings, - declarations, + mut bindings, + mut declarations, deferred, cycle_recovery, undecorated_type, all_definitely_bound: _, // builder only state typevar_binding_context: _, - deferred_state: _, + deferred_state, called_functions: _, index: _, - region: _, + region, returnees: _, } = self; + let (InferenceRegion::Definition(definition) | InferenceRegion::Deferred(definition)) = + region + else { + panic!("expected definition/deferred region"); + }; + let _ = scope; let diagnostics = context.finish(); @@ -9108,6 +9028,32 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> { } expressions.shrink_to_fit(); + if !matches!(region, InferenceRegion::Deferred(_)) && !deferred_state.is_deferred() { + let div = Type::divergent(DivergentType::new( + db, + DivergenceKind::InferDefinitionTypes(definition), + )); + let visitor = NormalizedVisitor::default().recursive(div); + let previous_cycle_value = infer_definition_types(db, definition); + for (expr, ty) in &mut expressions { + let previous_ty = previous_cycle_value.expression_type(*expr); + *ty = UnionType::from_elements(db, [*ty, previous_ty]) + .recursive_type_normalized(db, &visitor); + } + + for (binding, binding_ty) in bindings.iter_mut() { + let previous_ty = previous_cycle_value.binding_type(*binding); + *binding_ty = UnionType::from_elements(db, [*binding_ty, previous_ty]) + .recursive_type_normalized(db, &visitor); + } + for (declaration, TypeAndQualifiers { inner, .. }) in declarations.iter_mut() { + let previous_ty = previous_cycle_value + .declaration_type(*declaration) + .inner_type(); + *inner = UnionType::from_elements(db, [*inner, previous_ty]) + .recursive_type_normalized(db, &visitor); + } + } DefinitionInference { expressions, @@ -9165,6 +9111,22 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> { }); expressions.shrink_to_fit(); + if let NodeWithScopeKind::TypeAlias(_) = scope.node(db) { + // Don't perform recursive type normalization on type aliases + } else { + let div = Type::divergent(DivergentType::new( + db, + DivergenceKind::InferScopeTypes(scope), + )); + let visitor = NormalizedVisitor::default().recursive(div); + let previous_cycle_value = infer_scope_types(db, scope); + + for (expr, ty) in &mut expressions { + let previous_ty = previous_cycle_value.expression_type(*expr); + *ty = UnionType::from_elements(db, [*ty, previous_ty]) + .recursive_type_normalized(db, &visitor); + } + } ScopeInference { expressions, @@ -9430,6 +9392,10 @@ where self.0.iter().map(|(k, v)| (k, v)) } + fn iter_mut(&mut self) -> impl ExactSizeIterator { + self.0.iter_mut().map(|(k, v)| (&*k, v)) + } + fn insert(&mut self, key: K, value: V) { debug_assert!( !self.0.iter().any(|(existing, _)| existing == &key), @@ -9442,10 +9408,6 @@ where fn into_boxed_slice(self) -> Box<[(K, V)]> { self.0.into_boxed_slice() } - - fn truncate(&mut self, len: usize) { - self.0.truncate(len); - } } impl Extend<(K, V)> for VecMap @@ -9489,14 +9451,6 @@ impl VecSet { fn into_boxed_slice(self) -> Box<[V]> { self.0.into_boxed_slice() } - - fn truncate(&mut self, len: usize) { - self.0.truncate(len); - } - - fn len(&self) -> usize { - self.0.len() - } } impl VecSet diff --git a/crates/ty_python_semantic/src/types/infer/tests.rs b/crates/ty_python_semantic/src/types/infer/tests.rs index e2a359b67f850..6eef61de10ef9 100644 --- a/crates/ty_python_semantic/src/types/infer/tests.rs +++ b/crates/ty_python_semantic/src/types/infer/tests.rs @@ -3,10 +3,8 @@ use crate::db::tests::{TestDb, setup_db}; use crate::place::symbol; use crate::place::{ConsideredDefinitions, Place, global_symbol}; use crate::semantic_index::definition::Definition; -use crate::semantic_index::scope::{FileScopeId, ScopeKind}; -use crate::semantic_index::{ - global_scope, implicit_attribute_table, place_table, semantic_index, use_def_map, -}; +use crate::semantic_index::scope::FileScopeId; +use crate::semantic_index::{global_scope, place_table, semantic_index, use_def_map}; use crate::types::function::FunctionType; use crate::types::{BoundMethodType, KnownClass, KnownInstanceType, UnionType, check_types}; use ruff_db::diagnostic::Diagnostic; @@ -23,15 +21,26 @@ fn __() { let _ = &FunctionType::infer_return_type; let _ = &BoundMethodType::infer_return_type; let _ = &ClassLiteral::implicit_attribute_inner; + let _ = &infer_expression_type_impl; + let _ = &infer_expression_types_impl; + let _ = &infer_definition_types; + let _ = &infer_scope_types; + let _ = &infer_unpack_types; } /// These queries refer to a value ​​from the previous cycle to ensure convergence. /// Therefore, even when convergence is apparent, they will cycle at least once. -const QUERIES_USE_PREVIOUS_CYCLE_VALUE: [&str; 5] = [ +/// TODO: Is it possible to use the salsa API to get the value from the previous cycle (without doing anything if called for the first time)? +const QUERIES_USE_PREVIOUS_CYCLE_VALUE: [&str; 10] = [ "Type < 'db >::member_lookup_with_policy_", "Type < 'db >::class_member_with_policy_", "FunctionType < 'db >::infer_return_type_", "BoundMethodType < 'db >::infer_return_type_", "ClassLiteral < 'db >::implicit_attribute_inner_", + "infer_expression_type_impl", + "infer_expression_types_impl", + "infer_definition_types", + "infer_scope_types", + "infer_unpack_types", ]; #[track_caller] @@ -59,24 +68,6 @@ fn get_symbol<'db>( symbol(db, scope, symbol_name, ConsideredDefinitions::EndOfScope).place } -#[track_caller] -fn get_scope<'db>( - db: &'db TestDb, - file: File, - name: &str, - kind: ScopeKind, -) -> Option> { - let module = parsed_module(db, file).load(db); - let index = semantic_index(db, file); - for (child_scope, _) in index.child_scopes(FileScopeId::global()) { - let scope = child_scope.to_scope_id(db, file); - if scope.name(db, &module) == name && scope.scope(db).kind() == kind { - return Some(scope); - } - } - None -} - #[track_caller] fn assert_diagnostic_messages(diagnostics: &[Diagnostic], expected: &[&str]) { let messages: Vec<&str> = diagnostics @@ -487,15 +478,10 @@ fn dependency_implicit_instance_attribute() -> anyhow::Result<()> { assert_eq!(attr_ty.display(&db).to_string(), "Unknown | str | None"); db.take_salsa_events() }; - let file_mod = system_path_to_file(&db, "/src/mod.py").unwrap(); - let class_body_scope = get_scope(&db, file_mod, "C", ScopeKind::Class).unwrap(); - let implicit_attribute_table = implicit_attribute_table(&db, class_body_scope); - let attribute = implicit_attribute_table.symbol_id("attr").unwrap(); - let cycle_recovery = DivergentType::implicit_attribute(&db, class_body_scope, attribute); assert_function_query_was_run( &db, infer_expression_types_impl, - InferExpression::bare(&db, x_rhs_expression(&db), cycle_recovery), + InferExpression::Bare(x_rhs_expression(&db)), &events, ); @@ -516,12 +502,10 @@ fn dependency_implicit_instance_attribute() -> anyhow::Result<()> { assert_eq!(attr_ty.display(&db).to_string(), "Unknown | str | None"); db.take_salsa_events() }; - let class_body_scope = get_scope(&db, file_mod, "C", ScopeKind::Class).unwrap(); - let cycle_recovery = DivergentType::implicit_attribute(&db, class_body_scope, attribute); assert_function_query_was_not_run( &db, infer_expression_types_impl, - InferExpression::bare(&db, x_rhs_expression(&db), cycle_recovery), + InferExpression::Bare(x_rhs_expression(&db)), &events, ); @@ -585,11 +569,10 @@ fn dependency_own_instance_member() -> anyhow::Result<()> { assert_eq!(attr_ty.display(&db).to_string(), "Unknown | str | None"); db.take_salsa_events() }; - let cycle_recovery = DivergentType::should_not_diverge(&db, global_scope(&db, file_main)); assert_function_query_was_run( &db, infer_expression_types_impl, - InferExpression::bare(&db, x_rhs_expression(&db), cycle_recovery), + InferExpression::Bare(x_rhs_expression(&db)), &events, ); @@ -612,11 +595,10 @@ fn dependency_own_instance_member() -> anyhow::Result<()> { assert_eq!(attr_ty.display(&db).to_string(), "Unknown | str | None"); db.take_salsa_events() }; - let cycle_recovery = DivergentType::should_not_diverge(&db, global_scope(&db, file_main)); assert_function_query_was_not_run( &db, infer_expression_types_impl, - InferExpression::bare(&db, x_rhs_expression(&db), cycle_recovery), + InferExpression::Bare(x_rhs_expression(&db)), &events, ); @@ -683,15 +665,10 @@ fn dependency_implicit_class_member() -> anyhow::Result<()> { assert_eq!(attr_ty.display(&db).to_string(), "Unknown | str"); db.take_salsa_events() }; - let file_mod = system_path_to_file(&db, "/src/mod.py").unwrap(); - let class_body_scope = get_scope(&db, file_mod, "C", ScopeKind::Class).unwrap(); - let implicit_attribute_table = implicit_attribute_table(&db, class_body_scope); - let attribute = implicit_attribute_table.symbol_id("class_attr").unwrap(); - let cycle_recovery = DivergentType::implicit_attribute(&db, class_body_scope, attribute); assert_function_query_was_run( &db, infer_expression_types_impl, - InferExpression::bare(&db, x_rhs_expression(&db), cycle_recovery), + InferExpression::Bare(x_rhs_expression(&db)), &events, ); @@ -716,12 +693,10 @@ fn dependency_implicit_class_member() -> anyhow::Result<()> { assert_eq!(attr_ty.display(&db).to_string(), "Unknown | str"); db.take_salsa_events() }; - let class_body_scope = get_scope(&db, file_mod, "C", ScopeKind::Class).unwrap(); - let cycle_recovery = DivergentType::implicit_attribute(&db, class_body_scope, attribute); assert_function_query_was_not_run( &db, infer_expression_types_impl, - InferExpression::bare(&db, x_rhs_expression(&db), cycle_recovery), + InferExpression::Bare(x_rhs_expression(&db)), &events, ); @@ -763,13 +738,10 @@ fn call_type_doesnt_rerun_when_only_callee_changed() -> anyhow::Result<()> { let call = &*module.syntax().body[1].as_assign_stmt().unwrap().value; let foo_call = semantic_index(&db, bar).expression(call); - let foo = system_path_to_file(&db, "src/foo.py")?; - let function_scope = get_scope(&db, foo, "foo", ScopeKind::Function).unwrap(); - let cycle_recovery = DivergentType::function_return_type(&db, function_scope); assert_function_query_was_run( &db, infer_expression_types_impl, - InferExpression::bare(&db, foo_call, cycle_recovery), + InferExpression::Bare(foo_call), &events, ); @@ -797,12 +769,10 @@ fn call_type_doesnt_rerun_when_only_callee_changed() -> anyhow::Result<()> { let call = &*module.syntax().body[1].as_assign_stmt().unwrap().value; let foo_call = semantic_index(&db, bar).expression(call); - let function_scope = get_scope(&db, foo, "foo", ScopeKind::Function).unwrap(); - let cycle_recovery = DivergentType::function_return_type(&db, function_scope); assert_function_query_was_not_run( &db, infer_expression_types_impl, - InferExpression::bare(&db, foo_call, cycle_recovery), + InferExpression::Bare(foo_call), &events, ); diff --git a/crates/ty_python_semantic/src/types/instance.rs b/crates/ty_python_semantic/src/types/instance.rs index 12d49e28697da..135e55ef100b2 100644 --- a/crates/ty_python_semantic/src/types/instance.rs +++ b/crates/ty_python_semantic/src/types/instance.rs @@ -336,6 +336,22 @@ impl<'db> NominalInstanceType<'db> { } } + pub(super) fn recursive_type_normalized( + self, + db: &'db dyn Db, + visitor: &NormalizedVisitor<'db>, + ) -> Self { + match self.0 { + NominalInstanceInner::ExactTuple(tuple) => Self(NominalInstanceInner::ExactTuple( + tuple.recursive_type_normalized(db, visitor), + )), + NominalInstanceInner::NonTuple(class) => Self(NominalInstanceInner::NonTuple( + class.recursive_type_normalized(db, visitor), + )), + NominalInstanceInner::Object => Self(NominalInstanceInner::Object), + } + } + pub(super) fn has_relation_to_impl( self, db: &'db dyn Db, @@ -624,9 +640,6 @@ impl<'db> ProtocolInstanceType<'db> { db: &'db dyn Db, visitor: &NormalizedVisitor<'db>, ) -> Type<'db> { - if visitor.is_recursive() { - return Type::ProtocolInstance(self); - } if self.is_equivalent_to_object(db) { return Type::object(); } @@ -638,6 +651,14 @@ impl<'db> ProtocolInstanceType<'db> { } } + pub(super) fn recursive_type_normalized( + self, + _db: &'db dyn Db, + _visitor: &NormalizedVisitor<'db>, + ) -> Self { + self + } + /// Return `true` if this protocol type is equivalent to the protocol `other`. /// /// TODO: consider the types of the members as well as their existence diff --git a/crates/ty_python_semantic/src/types/narrow.rs b/crates/ty_python_semantic/src/types/narrow.rs index 388d1f82302ee..2fb6157acb904 100644 --- a/crates/ty_python_semantic/src/types/narrow.rs +++ b/crates/ty_python_semantic/src/types/narrow.rs @@ -11,9 +11,8 @@ use crate::types::enums::{enum_member_literals, enum_metadata}; use crate::types::function::KnownFunction; use crate::types::infer::infer_same_file_expression_type; use crate::types::{ - ClassLiteral, ClassType, DivergentType, IntersectionBuilder, KnownClass, SubclassOfInner, - SubclassOfType, Truthiness, Type, TypeContext, TypeVarBoundOrConstraints, UnionBuilder, - infer_expression_types, + ClassLiteral, ClassType, IntersectionBuilder, KnownClass, SubclassOfInner, SubclassOfType, + Truthiness, Type, TypeContext, TypeVarBoundOrConstraints, UnionBuilder, infer_expression_types, }; use ruff_db::parsed::{ParsedModuleRef, parsed_module}; @@ -774,9 +773,7 @@ impl<'db, 'ast> NarrowingConstraintsBuilder<'db, 'ast> { return None; } - let cycle_recovery = DivergentType::should_not_diverge(self.db, expression.scope(self.db)); - let inference = - infer_expression_types(self.db, expression, TypeContext::default(), cycle_recovery); + let inference = infer_expression_types(self.db, expression, TypeContext::default()); let comparator_tuples = std::iter::once(&**left) .chain(comparators) @@ -866,10 +863,7 @@ impl<'db, 'ast> NarrowingConstraintsBuilder<'db, 'ast> { expression: Expression<'db>, is_positive: bool, ) -> Option> { - // QUESTION: Could there be a case for divergence here? - let cycle_recovery = DivergentType::should_not_diverge(self.db, expression.scope(self.db)); - let inference = - infer_expression_types(self.db, expression, TypeContext::default(), cycle_recovery); + let inference = infer_expression_types(self.db, expression, TypeContext::default()); let callable_ty = inference.expression_type(&*expr_call.func); @@ -989,15 +983,8 @@ impl<'db, 'ast> NarrowingConstraintsBuilder<'db, 'ast> { let subject = place_expr(subject.node_ref(self.db, self.module))?; let place = self.expect_place(&subject); - let cycle_recovery = DivergentType::should_not_diverge(self.db, cls.scope(self.db)); - let ty = infer_same_file_expression_type( - self.db, - cls, - TypeContext::default(), - cycle_recovery, - self.module, - ) - .to_instance(self.db)?; + let ty = infer_same_file_expression_type(self.db, cls, TypeContext::default(), self.module) + .to_instance(self.db)?; Some(NarrowingConstraints::from_iter([(place, ty)])) } @@ -1010,15 +997,8 @@ impl<'db, 'ast> NarrowingConstraintsBuilder<'db, 'ast> { let subject = place_expr(subject.node_ref(self.db, self.module))?; let place = self.expect_place(&subject); - // QUESTION: Could there be a case for divergence here? - let cycle_recovery = DivergentType::should_not_diverge(self.db, value.scope(self.db)); - let ty = infer_same_file_expression_type( - self.db, - value, - TypeContext::default(), - cycle_recovery, - self.module, - ); + let ty = + infer_same_file_expression_type(self.db, value, TypeContext::default(), self.module); Some(NarrowingConstraints::from_iter([(place, ty)])) } @@ -1047,9 +1027,7 @@ impl<'db, 'ast> NarrowingConstraintsBuilder<'db, 'ast> { expression: Expression<'db>, is_positive: bool, ) -> Option> { - let cycle_recovery = DivergentType::should_not_diverge(self.db, expression.scope(self.db)); - let inference = - infer_expression_types(self.db, expression, TypeContext::default(), cycle_recovery); + let inference = infer_expression_types(self.db, expression, TypeContext::default()); let mut sub_constraints = expr_bool_op .values .iter() diff --git a/crates/ty_python_semantic/src/types/signatures.rs b/crates/ty_python_semantic/src/types/signatures.rs index c070294a6df74..f461b34e551a3 100644 --- a/crates/ty_python_semantic/src/types/signatures.rs +++ b/crates/ty_python_semantic/src/types/signatures.rs @@ -74,6 +74,18 @@ impl<'db> CallableSignature<'db> { ) } + pub(super) fn recursive_type_normalized( + &self, + db: &'db dyn Db, + visitor: &NormalizedVisitor<'db>, + ) -> Self { + Self::from_overloads( + self.overloads + .iter() + .map(|signature| signature.recursive_type_normalized(db, visitor)), + ) + } + pub(crate) fn apply_type_mapping_impl<'a>( &self, db: &'db dyn Db, @@ -437,7 +449,7 @@ impl<'db> Signature<'db> { .map(|ctx| ctx.normalized_impl(db, visitor)), // Discard the definition when normalizing, so that two equivalent signatures // with different `Definition`s share the same Salsa ID when normalized - definition: visitor.is_recursive().then_some(self.definition).flatten(), + definition: None, parameters: self .parameters .iter() @@ -449,6 +461,30 @@ impl<'db> Signature<'db> { } } + pub(super) fn recursive_type_normalized( + &self, + db: &'db dyn Db, + visitor: &NormalizedVisitor<'db>, + ) -> Self { + Self { + generic_context: self + .generic_context + .map(|ctx| ctx.recursive_type_normalized(db, visitor)), + inherited_generic_context: self + .inherited_generic_context + .map(|ctx| ctx.recursive_type_normalized(db, visitor)), + definition: self.definition, + parameters: self + .parameters + .iter() + .map(|param| param.recursive_type_normalized(db, visitor)) + .collect(), + return_ty: self + .return_ty + .map(|return_ty| return_ty.recursive_type_normalized(db, visitor)), + } + } + pub(crate) fn apply_type_mapping<'a>( &self, db: &'db dyn Db, @@ -1493,14 +1529,6 @@ impl<'db> Parameter<'db> { form, } = self; - if visitor.is_recursive() { - return Self { - annotated_type: annotated_type.map(|ty| ty.normalized_impl(db, visitor)), - kind: kind.clone(), - form: *form, - }; - } - // Ensure unions and intersections are ordered in the annotated type (if there is one). // Ensure that a parameter without an annotation is treated equivalently to a parameter // with a dynamic type as its annotation. (We must use `Any` here as all dynamic types @@ -1545,6 +1573,47 @@ impl<'db> Parameter<'db> { } } + pub(super) fn recursive_type_normalized( + &self, + db: &'db dyn Db, + visitor: &NormalizedVisitor<'db>, + ) -> Self { + let Parameter { + annotated_type, + kind, + form, + } = self; + + let annotated_type = annotated_type.map(|ty| ty.recursive_type_normalized(db, visitor)); + + let kind = match kind { + ParameterKind::PositionalOnly { name, default_type } => ParameterKind::PositionalOnly { + name: name.clone(), + default_type: default_type.map(|ty| ty.recursive_type_normalized(db, visitor)), + }, + ParameterKind::PositionalOrKeyword { name, default_type } => { + ParameterKind::PositionalOrKeyword { + name: name.clone(), + default_type: default_type.map(|ty| ty.recursive_type_normalized(db, visitor)), + } + } + ParameterKind::KeywordOnly { name, default_type } => ParameterKind::KeywordOnly { + name: name.clone(), + default_type: default_type.map(|ty| ty.recursive_type_normalized(db, visitor)), + }, + ParameterKind::Variadic { name } => ParameterKind::Variadic { name: name.clone() }, + ParameterKind::KeywordVariadic { name } => { + ParameterKind::KeywordVariadic { name: name.clone() } + } + }; + + Self { + annotated_type, + kind, + form: *form, + } + } + fn from_node_and_kind( db: &'db dyn Db, definition: Definition<'db>, diff --git a/crates/ty_python_semantic/src/types/subclass_of.rs b/crates/ty_python_semantic/src/types/subclass_of.rs index b7c619ae58127..ba075766627f3 100644 --- a/crates/ty_python_semantic/src/types/subclass_of.rs +++ b/crates/ty_python_semantic/src/types/subclass_of.rs @@ -181,6 +181,16 @@ impl<'db> SubclassOfType<'db> { } } + pub(super) fn recursive_type_normalized( + self, + db: &'db dyn Db, + visitor: &NormalizedVisitor<'db>, + ) -> Self { + Self { + subclass_of: self.subclass_of.recursive_type_normalized(db, visitor), + } + } + pub(crate) fn to_instance(self, db: &'db dyn Db) -> Type<'db> { match self.subclass_of { SubclassOfInner::Class(class) => Type::instance(db, class), @@ -250,7 +260,20 @@ impl<'db> SubclassOfInner<'db> { pub(crate) fn normalized_impl(self, db: &'db dyn Db, visitor: &NormalizedVisitor<'db>) -> Self { match self { Self::Class(class) => Self::Class(class.normalized_impl(db, visitor)), - Self::Dynamic(dynamic) => Self::Dynamic(dynamic.normalized(visitor.is_recursive())), + Self::Dynamic(dynamic) => Self::Dynamic(dynamic.normalized_impl(visitor.kind)), + } + } + + pub(super) fn recursive_type_normalized( + self, + db: &'db dyn Db, + visitor: &NormalizedVisitor<'db>, + ) -> Self { + match self { + Self::Class(class) => Self::Class(class.recursive_type_normalized(db, visitor)), + Self::Dynamic(dynamic) => { + Self::Dynamic(dynamic.recursive_type_normalized(visitor.kind)) + } } } diff --git a/crates/ty_python_semantic/src/types/tuple.rs b/crates/ty_python_semantic/src/types/tuple.rs index 5b76defefd5aa..91a415552f245 100644 --- a/crates/ty_python_semantic/src/types/tuple.rs +++ b/crates/ty_python_semantic/src/types/tuple.rs @@ -228,6 +228,14 @@ impl<'db> TupleType<'db> { TupleType::new(db, &self.tuple(db).normalized_impl(db, visitor)) } + pub(super) fn recursive_type_normalized( + self, + db: &'db dyn Db, + visitor: &NormalizedVisitor<'db>, + ) -> Self { + Self::new_internal(db, self.tuple(db).recursive_type_normalized(db, visitor)) + } + pub(crate) fn apply_type_mapping_impl<'a>( self, db: &'db dyn Db, @@ -386,6 +394,14 @@ impl<'db> FixedLengthTuple> { Self::from_elements(self.0.iter().map(|ty| ty.normalized_impl(db, visitor))) } + fn recursive_type_normalized(&self, db: &'db dyn Db, visitor: &NormalizedVisitor<'db>) -> Self { + Self::from_elements( + self.0 + .iter() + .map(|ty| ty.recursive_type_normalized(db, visitor)), + ) + } + fn apply_type_mapping_impl<'a>( &self, db: &'db dyn Db, @@ -699,6 +715,25 @@ impl<'db> VariableLengthTuple> { }) } + fn recursive_type_normalized(&self, db: &'db dyn Db, visitor: &NormalizedVisitor<'db>) -> Self { + let prefix = self + .prefix + .iter() + .map(|ty| ty.recursive_type_normalized(db, visitor)) + .collect::>(); + let suffix = self + .suffix + .iter() + .map(|ty| ty.recursive_type_normalized(db, visitor)) + .collect::>(); + let variable = self.variable.recursive_type_normalized(db, visitor); + Self { + prefix, + variable, + suffix, + } + } + fn apply_type_mapping_impl<'a>( &self, db: &'db dyn Db, @@ -1040,6 +1075,17 @@ impl<'db> Tuple> { } } + pub(super) fn recursive_type_normalized( + &self, + db: &'db dyn Db, + visitor: &NormalizedVisitor<'db>, + ) -> Self { + match self { + Tuple::Fixed(tuple) => Tuple::Fixed(tuple.recursive_type_normalized(db, visitor)), + Tuple::Variable(tuple) => Tuple::Variable(tuple.recursive_type_normalized(db, visitor)), + } + } + pub(crate) fn apply_type_mapping_impl<'a>( &self, db: &'db dyn Db, diff --git a/crates/ty_python_semantic/src/types/type_ordering.rs b/crates/ty_python_semantic/src/types/type_ordering.rs index 790700748e2b7..d505a4a5dd5b6 100644 --- a/crates/ty_python_semantic/src/types/type_ordering.rs +++ b/crates/ty_python_semantic/src/types/type_ordering.rs @@ -1,6 +1,7 @@ use std::cmp::Ordering; use crate::db::Db; +use crate::types::{DivergenceKind, DivergentType}; use super::{ DynamicType, SuperOwnerKind, TodoType, Type, TypeIsType, class_base::ClassBase, @@ -281,7 +282,7 @@ fn dynamic_elements_ordering<'db>( (_, DynamicType::TodoTypeAlias) => Ordering::Greater, (DynamicType::Divergent(left), DynamicType::Divergent(right)) => { - left.scope(db).cmp(&right.scope(db)) + divergent_ordering(db, left, right) } (DynamicType::Divergent(_), _) => Ordering::Less, (_, DynamicType::Divergent(_)) => Ordering::Greater, @@ -309,3 +310,99 @@ fn typeis_ordering(db: &dyn Db, left: TypeIsType, right: TypeIsType) -> Ordering }, } } + +fn divergent_ordering<'db>( + db: &'db dyn Db, + left: DivergentType<'db>, + right: DivergentType<'db>, +) -> Ordering { + match (left.kind(db), right.kind(db)) { + (DivergenceKind::InferReturnType(left), DivergenceKind::InferReturnType(right)) => { + left.cmp(right) + } + (DivergenceKind::InferReturnType(_), _) => Ordering::Less, + (_, DivergenceKind::InferReturnType(_)) => Ordering::Greater, + + ( + DivergenceKind::ImplicitAttribute { + class_body_scope: left_scope, + name: left_name, + target_method_decorator: left_deco, + }, + DivergenceKind::ImplicitAttribute { + class_body_scope: right_scope, + name: right_name, + target_method_decorator: right_deco, + }, + ) => left_scope + .cmp(right_scope) + .then_with(|| left_name.cmp(right_name)) + .then_with(|| left_deco.cmp(right_deco)), + (DivergenceKind::ImplicitAttribute { .. }, _) => Ordering::Less, + (_, DivergenceKind::ImplicitAttribute { .. }) => Ordering::Greater, + + ( + DivergenceKind::MemberLookupWithPolicy { + self_type: left_ty, + name: left_name, + policy: left_policy, + }, + DivergenceKind::MemberLookupWithPolicy { + self_type: right_ty, + name: right_name, + policy: right_policy, + }, + ) => union_or_intersection_elements_ordering(db, left_ty, right_ty) + .then_with(|| left_name.cmp(right_name)) + .then_with(|| left_policy.cmp(right_policy)), + (DivergenceKind::MemberLookupWithPolicy { .. }, _) => Ordering::Less, + (_, DivergenceKind::MemberLookupWithPolicy { .. }) => Ordering::Greater, + + ( + DivergenceKind::ClassLookupWithPolicy { + self_type: left_ty, + name: left_name, + policy: left_policy, + }, + DivergenceKind::ClassLookupWithPolicy { + self_type: right_ty, + name: right_name, + policy: right_policy, + }, + ) => union_or_intersection_elements_ordering(db, left_ty, right_ty) + .then_with(|| left_name.cmp(right_name)) + .then_with(|| left_policy.cmp(right_policy)), + (DivergenceKind::ClassLookupWithPolicy { .. }, _) => Ordering::Less, + (_, DivergenceKind::ClassLookupWithPolicy { .. }) => Ordering::Greater, + + (DivergenceKind::InferExpression(left), DivergenceKind::InferExpression(right)) => { + left.cmp(right) + } + (DivergenceKind::InferExpression(_), _) => Ordering::Less, + (_, DivergenceKind::InferExpression(_)) => Ordering::Greater, + + ( + DivergenceKind::InferExpressionTypes(left), + DivergenceKind::InferExpressionTypes(right), + ) => left.cmp(right), + (DivergenceKind::InferExpressionTypes(_), _) => Ordering::Less, + (_, DivergenceKind::InferExpressionTypes(_)) => Ordering::Greater, + + ( + DivergenceKind::InferDefinitionTypes(left), + DivergenceKind::InferDefinitionTypes(right), + ) => left.cmp(right), + (DivergenceKind::InferDefinitionTypes(_), _) => Ordering::Less, + (_, DivergenceKind::InferDefinitionTypes(_)) => Ordering::Greater, + + (DivergenceKind::InferScopeTypes(left), DivergenceKind::InferScopeTypes(right)) => { + left.cmp(right) + } + (DivergenceKind::InferScopeTypes(_), _) => Ordering::Less, + (_, DivergenceKind::InferScopeTypes(_)) => Ordering::Greater, + + (DivergenceKind::InferUnpackTypes(left), DivergenceKind::InferUnpackTypes(right)) => { + left.cmp(right) + } + } +} diff --git a/crates/ty_python_semantic/src/types/unpacker.rs b/crates/ty_python_semantic/src/types/unpacker.rs index e0c39c96557b7..e7bc145b8d6fb 100644 --- a/crates/ty_python_semantic/src/types/unpacker.rs +++ b/crates/ty_python_semantic/src/types/unpacker.rs @@ -8,10 +8,13 @@ use ruff_python_ast::{self as ast, AnyNodeRef}; use crate::Db; use crate::semantic_index::ast_ids::node_key::ExpressionNodeKey; use crate::semantic_index::scope::ScopeId; -use crate::types::infer::{InferExpression, infer_expression_types_impl}; +use crate::types::infer::{InferExpression, infer_expression_types_impl, infer_unpack_types}; use crate::types::tuple::{ResizeTupleError, Tuple, TupleLength, TupleSpec, TupleUnpacker}; -use crate::types::{DivergentType, Type, TypeCheckDiagnostics, TypeContext}; -use crate::unpack::{UnpackKind, UnpackValue}; +use crate::types::{ + DivergenceKind, DivergentType, NormalizedVisitor, Type, TypeCheckDiagnostics, TypeContext, + UnionType, +}; +use crate::unpack::{Unpack, UnpackKind, UnpackValue}; use super::context::InferContext; use super::diagnostic::INVALID_ASSIGNMENT; @@ -43,23 +46,13 @@ impl<'db, 'ast> Unpacker<'db, 'ast> { } /// Unpack the value to the target expression. - pub(crate) fn unpack( - &mut self, - target: &ast::Expr, - value: UnpackValue<'db>, - cycle_recovery: DivergentType<'db>, - ) { + pub(crate) fn unpack(&mut self, target: &ast::Expr, value: UnpackValue<'db>) { debug_assert!( matches!(target, ast::Expr::List(_) | ast::Expr::Tuple(_)), "Unpacking target must be a list or tuple expression" ); - let input = InferExpression::new( - self.db(), - value.expression(), - TypeContext::default(), - cycle_recovery, - ); + let input = InferExpression::new(self.db(), value.expression(), TypeContext::default()); let value_type = infer_expression_types_impl(self.db(), input) .expression_type(value.expression().node_ref(self.db(), self.module())); @@ -184,8 +177,21 @@ impl<'db, 'ast> Unpacker<'db, 'ast> { } } - pub(crate) fn finish(mut self) -> UnpackResult<'db> { + pub(crate) fn finish(mut self, unpack: Unpack<'db>) -> UnpackResult<'db> { + let db = self.db(); self.targets.shrink_to_fit(); + let div = Type::divergent(DivergentType::new( + db, + DivergenceKind::InferUnpackTypes(unpack), + )); + let previous_cycle_value = infer_unpack_types(db, unpack); + let visitor = NormalizedVisitor::default().recursive(div); + for (expr, ty) in &mut self.targets { + let previous_ty = previous_cycle_value.expression_type(*expr); + *ty = UnionType::from_elements(db, [*ty, previous_ty]) + .recursive_type_normalized(db, &visitor); + } + UnpackResult { diagnostics: self.context.finish(), targets: self.targets, diff --git a/crates/ty_python_semantic/src/unpack.rs b/crates/ty_python_semantic/src/unpack.rs index cb07f2570a725..9da45ea3e1dbf 100644 --- a/crates/ty_python_semantic/src/unpack.rs +++ b/crates/ty_python_semantic/src/unpack.rs @@ -27,6 +27,7 @@ use crate::semantic_index::scope::{FileScopeId, ScopeId}; /// * a field of a type that is a return type of a cross-module query /// * an argument of a cross-module query #[salsa::tracked(debug, heap_size=ruff_memory_usage::heap_size)] +#[derive(PartialOrd, Ord)] pub(crate) struct Unpack<'db> { pub(crate) file: File, diff --git a/crates/ty_python_semantic/tests/corpus.rs b/crates/ty_python_semantic/tests/corpus.rs index 83ad7ae1ffdae..d91da4b91d9de 100644 --- a/crates/ty_python_semantic/tests/corpus.rs +++ b/crates/ty_python_semantic/tests/corpus.rs @@ -169,9 +169,6 @@ fn run_corpus_tests(pattern: &str) -> anyhow::Result<()> { /// Whether or not the .py/.pyi version of this file is expected to fail #[rustfmt::skip] const KNOWN_FAILURES: &[(&str, bool, bool)] = &[ - // Fails with too-many-cycle-iterations due to a self-referential - // type alias, see https://github.com/astral-sh/ty/issues/256 - ("crates/ruff_linter/resources/test/fixtures/pyflakes/F401_34.py", true, true), ]; #[salsa::db] From a26d6c635e658ab5822bb18631b55abd28001a94 Mon Sep 17 00:00:00 2001 From: Shunsuke Shibayama Date: Tue, 16 Sep 2025 23:51:57 +0900 Subject: [PATCH 075/105] Revert "use `any_over_type` in `has_divergent_type`" This reverts commit f6cd44b8b9e984a56b1f31588671b6492b159929. --- crates/ty_python_semantic/src/types.rs | 131 +++++++++++++++++- crates/ty_python_semantic/src/types/class.rs | 24 +++- .../ty_python_semantic/src/types/generics.rs | 23 ++- .../ty_python_semantic/src/types/instance.rs | 16 ++- .../src/types/signatures.rs | 40 +++++- .../src/types/subclass_of.rs | 19 ++- crates/ty_python_semantic/src/types/tuple.rs | 13 +- 7 files changed, 244 insertions(+), 22 deletions(-) diff --git a/crates/ty_python_semantic/src/types.rs b/crates/ty_python_semantic/src/types.rs index c7999c2987b83..8194ccff75eb3 100644 --- a/crates/ty_python_semantic/src/types.rs +++ b/crates/ty_python_semantic/src/types.rs @@ -66,7 +66,6 @@ use crate::types::signatures::{Parameter, ParameterForm, Parameters, walk_signat use crate::types::tuple::TupleSpec; pub(crate) use crate::types::typed_dict::{TypedDictParams, TypedDictType, walk_typed_dict_type}; use crate::types::variance::{TypeVarVariance, VarianceInferable}; -use crate::types::visitor::any_over_type; use crate::unpack::{EvaluationMode, Unpack}; pub use crate::util::diagnostics::add_inferred_python_version_hint_to_diagnostic; use crate::{Db, FxOrderSet, Module, Program}; @@ -281,6 +280,10 @@ impl<'db> NormalizedVisitor<'db> { } } +/// A [`CycleDetector`] that is used in `has_divergent_type` methods. +pub(crate) type HasDivergentTypeVisitor<'db> = CycleDetector, bool>; +pub(crate) struct HasDivergentType; + /// How a generic type has been specialized. /// /// This matters only if there is at least one invariant type parameter. @@ -677,6 +680,19 @@ impl<'db> PropertyInstanceType<'db> { getter_equivalence.and(db, setter_equivalence) } + + fn has_divergent_type_impl( + self, + db: &'db dyn Db, + div: Type<'db>, + visitor: &HasDivergentTypeVisitor<'db>, + ) -> bool { + self.setter(db) + .is_some_and(|setter| setter.has_divergent_type_impl(db, div, visitor)) + || self + .getter(db) + .is_some_and(|getter| getter.has_divergent_type_impl(db, div, visitor)) + } } bitflags! { @@ -6766,10 +6782,79 @@ impl<'db> Type<'db> { } pub(super) fn has_divergent_type(self, db: &'db dyn Db, div: Type<'db>) -> bool { - any_over_type(db, self, &|ty| match ty { - Type::Dynamic(DynamicType::Divergent(_)) => ty == div, - _ => false, - }) + let visitor = HasDivergentTypeVisitor::new(false); + self.has_divergent_type_impl(db, div, &visitor) + } + + fn has_divergent_type_impl( + self, + db: &'db dyn Db, + div: Type<'db>, + visitor: &HasDivergentTypeVisitor<'db>, + ) -> bool { + match self { + Type::Dynamic(DynamicType::Divergent(_)) => self == div, + Type::Union(union) => { + visitor.visit(self, || union.has_divergent_type_impl(db, div, visitor)) + } + Type::Intersection(intersection) => visitor.visit(self, || { + intersection.has_divergent_type_impl(db, div, visitor) + }), + Type::GenericAlias(alias) => visitor.visit(self, || { + alias + .specialization(db) + .has_divergent_type_impl(db, div, visitor) + }), + Type::NominalInstance(instance) => visitor.visit(self, || { + instance.class(db).has_divergent_type_impl(db, div, visitor) + }), + Type::Callable(callable) => { + visitor.visit(self, || callable.has_divergent_type_impl(db, div, visitor)) + } + Type::ProtocolInstance(protocol) => { + visitor.visit(self, || protocol.has_divergent_type_impl(db, div, visitor)) + } + Type::PropertyInstance(property) => { + visitor.visit(self, || property.has_divergent_type_impl(db, div, visitor)) + } + Type::TypeIs(type_is) => visitor.visit(self, || { + type_is + .return_type(db) + .has_divergent_type_impl(db, div, visitor) + }), + Type::SubclassOf(subclass_of) => visitor.visit(self, || { + subclass_of.has_divergent_type_impl(db, div, visitor) + }), + Type::TypedDict(typed_dict) => visitor.visit(self, || { + typed_dict + .defining_class() + .has_divergent_type_impl(db, div, visitor) + }), + Type::Never + | Type::AlwaysTruthy + | Type::AlwaysFalsy + | Type::WrapperDescriptor(_) + | Type::DataclassDecorator(_) + | Type::DataclassTransformer(_) + | Type::ModuleLiteral(_) + | Type::ClassLiteral(_) + | Type::IntLiteral(_) + | Type::BooleanLiteral(_) + | Type::LiteralString + | Type::StringLiteral(_) + | Type::BytesLiteral(_) + | Type::EnumLiteral(_) + | Type::BoundSuper(_) + | Type::SpecialForm(_) + | Type::KnownInstance(_) + | Type::NonInferableTypeVar(_) + | Type::TypeVar(_) + | Type::FunctionLiteral(_) + | Type::KnownBoundMethod(_) + | Type::BoundMethod(_) + | Type::Dynamic(_) + | Type::TypeAlias(_) => false, + } } } @@ -9749,6 +9834,16 @@ impl<'db> CallableType<'db> { .is_equivalent_to_impl(db, other.signatures(db), visitor) }) } + + fn has_divergent_type_impl( + self, + db: &'db dyn Db, + div: Type<'db>, + visitor: &HasDivergentTypeVisitor<'db>, + ) -> bool { + self.signatures(db) + .has_divergent_type_impl(db, div, visitor) + } } /// Represents a specific instance of a bound method type for a builtin class. @@ -10794,6 +10889,17 @@ impl<'db> UnionType<'db> { ConstraintSet::from(sorted_self == other.normalized(db)) } + + fn has_divergent_type_impl( + self, + db: &'db dyn Db, + div: Type<'db>, + visitor: &HasDivergentTypeVisitor<'db>, + ) -> bool { + self.elements(db) + .iter() + .any(|ty| ty.has_divergent_type_impl(db, div, visitor)) + } } #[salsa::interned(debug, heap_size=IntersectionType::heap_size)] @@ -11032,6 +11138,21 @@ impl<'db> IntersectionType<'db> { ruff_memory_usage::order_set_heap_size(positive) + ruff_memory_usage::order_set_heap_size(negative) } + + fn has_divergent_type_impl( + self, + db: &'db dyn Db, + div: Type<'db>, + visitor: &HasDivergentTypeVisitor<'db>, + ) -> bool { + self.positive(db) + .iter() + .any(|ty| ty.has_divergent_type_impl(db, div, visitor)) + || self + .negative(db) + .iter() + .any(|ty| ty.has_divergent_type_impl(db, div, visitor)) + } } /// # Ordering diff --git a/crates/ty_python_semantic/src/types/class.rs b/crates/ty_python_semantic/src/types/class.rs index 49d0c4629df07..336c9c0d5c0df 100644 --- a/crates/ty_python_semantic/src/types/class.rs +++ b/crates/ty_python_semantic/src/types/class.rs @@ -27,11 +27,11 @@ use crate::types::typed_dict::typed_dict_params_from_class_def; use crate::types::{ ApplyTypeMappingVisitor, Binding, BoundSuperError, BoundSuperType, CallableType, DataclassParams, DeprecatedInstance, DivergenceKind, DivergentType, FindLegacyTypeVarsVisitor, - HasRelationToVisitor, IsEquivalentVisitor, KnownInstanceType, ManualPEP695TypeAliasType, - MaterializationKind, NormalizedVisitor, PropertyInstanceType, StringLiteralType, TypeAliasType, - TypeContext, TypeMapping, TypeRelation, TypeVarBoundOrConstraints, TypeVarInstance, - TypeVarKind, TypedDictParams, UnionBuilder, VarianceInferable, declaration_type, - determine_upper_bound, infer_definition_types, + HasDivergentTypeVisitor, HasRelationToVisitor, IsEquivalentVisitor, KnownInstanceType, + ManualPEP695TypeAliasType, MaterializationKind, NormalizedVisitor, PropertyInstanceType, + StringLiteralType, TypeAliasType, TypeContext, TypeMapping, TypeRelation, + TypeVarBoundOrConstraints, TypeVarInstance, TypeVarKind, TypedDictParams, UnionBuilder, + VarianceInferable, declaration_type, determine_upper_bound, infer_definition_types, }; use crate::{ Db, FxIndexMap, FxOrderSet, Program, @@ -1245,6 +1245,20 @@ impl<'db> ClassType<'db> { } } + pub(super) fn has_divergent_type_impl( + self, + db: &'db dyn Db, + div: Type<'db>, + visitor: &HasDivergentTypeVisitor<'db>, + ) -> bool { + match self { + ClassType::NonGeneric(_) => false, + ClassType::Generic(generic) => generic + .specialization(db) + .has_divergent_type_impl(db, div, visitor), + } + } + pub(super) fn is_protocol(self, db: &'db dyn Db) -> bool { self.class_literal(db).0.is_protocol(db) } diff --git a/crates/ty_python_semantic/src/types/generics.rs b/crates/ty_python_semantic/src/types/generics.rs index 3207ad64afe38..195fde5f6783a 100644 --- a/crates/ty_python_semantic/src/types/generics.rs +++ b/crates/ty_python_semantic/src/types/generics.rs @@ -16,10 +16,11 @@ use crate::types::instance::{Protocol, ProtocolInstanceType}; use crate::types::signatures::{Parameter, Parameters, Signature}; use crate::types::tuple::{TupleSpec, TupleType, walk_tuple_type}; use crate::types::{ - ApplyTypeMappingVisitor, BoundTypeVarInstance, FindLegacyTypeVarsVisitor, HasRelationToVisitor, - IsEquivalentVisitor, KnownClass, KnownInstanceType, MaterializationKind, NormalizedVisitor, - Type, TypeMapping, TypeRelation, TypeVarBoundOrConstraints, TypeVarInstance, TypeVarVariance, - UnionType, binding_type, declaration_type, + ApplyTypeMappingVisitor, BoundTypeVarInstance, FindLegacyTypeVarsVisitor, + HasDivergentTypeVisitor, HasRelationToVisitor, IsEquivalentVisitor, KnownClass, + KnownInstanceType, MaterializationKind, NormalizedVisitor, Type, TypeMapping, TypeRelation, + TypeVarBoundOrConstraints, TypeVarInstance, TypeVarVariance, UnionType, binding_type, + declaration_type, }; use crate::{Db, FxOrderSet}; @@ -964,6 +965,20 @@ impl<'db> Specialization<'db> { // A tuple's specialization will include all of its element types, so we don't need to also // look in `self.tuple`. } + + pub(crate) fn has_divergent_type_impl( + self, + db: &'db dyn Db, + div: Type<'db>, + visitor: &HasDivergentTypeVisitor<'db>, + ) -> bool { + self.types(db) + .iter() + .any(|ty| ty.has_divergent_type_impl(db, div, visitor)) + || self + .tuple_inner(db) + .is_some_and(|tuple| tuple.has_divergent_type_impl(db, div, visitor)) + } } /// A mapping between type variables and types. diff --git a/crates/ty_python_semantic/src/types/instance.rs b/crates/ty_python_semantic/src/types/instance.rs index 135e55ef100b2..81b5b2b1105a7 100644 --- a/crates/ty_python_semantic/src/types/instance.rs +++ b/crates/ty_python_semantic/src/types/instance.rs @@ -13,8 +13,8 @@ use crate::types::protocol_class::walk_protocol_interface; use crate::types::tuple::{TupleSpec, TupleType}; use crate::types::{ ApplyTypeMappingVisitor, ClassBase, ClassLiteral, FindLegacyTypeVarsVisitor, - HasRelationToVisitor, IsDisjointVisitor, IsEquivalentVisitor, NormalizedVisitor, TypeMapping, - TypeRelation, VarianceInferable, + HasDivergentTypeVisitor, HasRelationToVisitor, IsDisjointVisitor, IsEquivalentVisitor, + NormalizedVisitor, TypeMapping, TypeRelation, VarianceInferable, }; use crate::{Db, FxOrderSet}; @@ -735,6 +735,18 @@ impl<'db> ProtocolInstanceType<'db> { pub(super) fn interface(self, db: &'db dyn Db) -> ProtocolInterface<'db> { self.inner.interface(db) } + + pub(super) fn has_divergent_type_impl( + self, + db: &'db dyn Db, + div: Type<'db>, + visitor: &HasDivergentTypeVisitor<'db>, + ) -> bool { + self.inner + .interface(db) + .members(db) + .any(|member| member.ty().has_divergent_type_impl(db, div, visitor)) + } } impl<'db> VarianceInferable<'db> for ProtocolInstanceType<'db> { diff --git a/crates/ty_python_semantic/src/types/signatures.rs b/crates/ty_python_semantic/src/types/signatures.rs index 919b574bb19a8..d449872c4b9d1 100644 --- a/crates/ty_python_semantic/src/types/signatures.rs +++ b/crates/ty_python_semantic/src/types/signatures.rs @@ -21,8 +21,9 @@ use crate::types::constraints::{ConstraintSet, IteratorConstraintsExtension}; use crate::types::generics::{GenericContext, walk_generic_context}; use crate::types::{ ApplyTypeMappingVisitor, BindingContext, BoundTypeVarInstance, FindLegacyTypeVarsVisitor, - HasRelationToVisitor, IsEquivalentVisitor, KnownClass, MaterializationKind, NormalizedVisitor, - TypeMapping, TypeRelation, VarianceInferable, todo_type, + HasDivergentTypeVisitor, HasRelationToVisitor, IsEquivalentVisitor, KnownClass, + MaterializationKind, NormalizedVisitor, TypeMapping, TypeRelation, VarianceInferable, + todo_type, }; use crate::{Db, FxOrderSet}; use ruff_python_ast::{self as ast, name::Name}; @@ -237,6 +238,17 @@ impl<'db> CallableSignature<'db> { } } } + + pub(super) fn has_divergent_type_impl( + &self, + db: &'db dyn Db, + div: Type<'db>, + visitor: &HasDivergentTypeVisitor<'db>, + ) -> bool { + self.overloads + .iter() + .any(|signature| signature.has_divergent_type_impl(db, div, visitor)) + } } impl<'a, 'db> IntoIterator for &'a CallableSignature<'db> { @@ -1060,6 +1072,17 @@ impl<'db> Signature<'db> { pub(crate) fn with_definition(self, definition: Option>) -> Self { Self { definition, ..self } } + + fn has_divergent_type_impl( + &self, + db: &'db dyn Db, + div: Type<'db>, + visitor: &HasDivergentTypeVisitor<'db>, + ) -> bool { + self.return_ty + .is_some_and(|return_ty| return_ty.has_divergent_type_impl(db, div, visitor)) + || self.parameters.has_divergent_type_impl(db, div, visitor) + } } impl<'db> VarianceInferable<'db> for &Signature<'db> { @@ -1395,6 +1418,19 @@ impl<'db> Parameters<'db> { .enumerate() .rfind(|(_, parameter)| parameter.is_keyword_variadic()) } + + fn has_divergent_type_impl( + &self, + db: &'db dyn Db, + div: Type<'db>, + visitor: &HasDivergentTypeVisitor<'db>, + ) -> bool { + self.iter().any(|parameter| { + parameter + .annotated_type() + .is_some_and(|ty| ty.has_divergent_type_impl(db, div, visitor)) + }) + } } impl<'db, 'a> IntoIterator for &'a Parameters<'db> { diff --git a/crates/ty_python_semantic/src/types/subclass_of.rs b/crates/ty_python_semantic/src/types/subclass_of.rs index ba075766627f3..73bacfbe4d99d 100644 --- a/crates/ty_python_semantic/src/types/subclass_of.rs +++ b/crates/ty_python_semantic/src/types/subclass_of.rs @@ -4,9 +4,9 @@ use crate::types::constraints::ConstraintSet; use crate::types::variance::VarianceInferable; use crate::types::{ ApplyTypeMappingVisitor, BoundTypeVarInstance, ClassType, DynamicType, - FindLegacyTypeVarsVisitor, HasRelationToVisitor, IsDisjointVisitor, KnownClass, - MaterializationKind, MemberLookupPolicy, NormalizedVisitor, SpecialFormType, Type, TypeMapping, - TypeRelation, + FindLegacyTypeVarsVisitor, HasDivergentTypeVisitor, HasRelationToVisitor, IsDisjointVisitor, + KnownClass, MaterializationKind, MemberLookupPolicy, NormalizedVisitor, SpecialFormType, Type, + TypeMapping, TypeRelation, }; use crate::{Db, FxOrderSet}; @@ -203,6 +203,19 @@ impl<'db> SubclassOfType<'db> { .into_class() .is_some_and(|class| class.class_literal(db).0.is_typed_dict(db)) } + + pub(super) fn has_divergent_type_impl( + self, + db: &'db dyn Db, + div: Type<'db>, + visitor: &HasDivergentTypeVisitor<'db>, + ) -> bool { + match self.subclass_of { + SubclassOfInner::Dynamic(d @ DynamicType::Divergent(_)) => Type::Dynamic(d) == div, + SubclassOfInner::Dynamic(_) => false, + SubclassOfInner::Class(class) => class.has_divergent_type_impl(db, div, visitor), + } + } } impl<'db> VarianceInferable<'db> for SubclassOfType<'db> { diff --git a/crates/ty_python_semantic/src/types/tuple.rs b/crates/ty_python_semantic/src/types/tuple.rs index 91a415552f245..6b651a4b7fe8a 100644 --- a/crates/ty_python_semantic/src/types/tuple.rs +++ b/crates/ty_python_semantic/src/types/tuple.rs @@ -22,7 +22,6 @@ use std::hash::Hash; use itertools::{Either, EitherOrBoth, Itertools}; use crate::semantic_index::definition::Definition; -use crate::types::Truthiness; use crate::types::class::{ClassType, KnownClass}; use crate::types::constraints::{ConstraintSet, IteratorConstraintsExtension}; use crate::types::{ @@ -30,6 +29,7 @@ use crate::types::{ IsDisjointVisitor, IsEquivalentVisitor, NormalizedVisitor, Type, TypeMapping, TypeRelation, UnionBuilder, UnionType, }; +use crate::types::{HasDivergentTypeVisitor, Truthiness}; use crate::util::subscript::{Nth, OutOfBoundsError, PyIndex, PySlice, StepSizeZeroError}; use crate::{Db, FxOrderSet, Program}; @@ -285,6 +285,17 @@ impl<'db> TupleType<'db> { pub(crate) fn is_single_valued(self, db: &'db dyn Db) -> bool { self.tuple(db).is_single_valued(db) } + + pub(super) fn has_divergent_type_impl( + self, + db: &'db dyn Db, + div: Type<'db>, + visitor: &HasDivergentTypeVisitor<'db>, + ) -> bool { + self.tuple(db) + .all_elements() + .any(|ty| ty.has_divergent_type_impl(db, div, visitor)) + } } fn to_class_type_cycle_recover<'db>( From c210b75f5df3c272a1f2e05bf0505d6a0b45c0d7 Mon Sep 17 00:00:00 2001 From: Shunsuke Shibayama Date: Wed, 17 Sep 2025 00:38:32 +0900 Subject: [PATCH 076/105] separate `NormalizedVisitor` and `RecursiveTypeNormalizedVisitor` --- crates/ty_python_semantic/src/types.rs | 197 +++++++++++------- crates/ty_python_semantic/src/types/class.rs | 13 +- .../src/types/class_base.rs | 12 +- .../ty_python_semantic/src/types/function.rs | 13 +- .../ty_python_semantic/src/types/generics.rs | 12 +- crates/ty_python_semantic/src/types/infer.rs | 8 +- .../src/types/infer/builder.rs | 15 +- .../ty_python_semantic/src/types/instance.rs | 7 +- .../src/types/signatures.rs | 10 +- .../src/types/subclass_of.rs | 14 +- crates/ty_python_semantic/src/types/tuple.rs | 20 +- .../ty_python_semantic/src/types/unpacker.rs | 6 +- 12 files changed, 187 insertions(+), 140 deletions(-) diff --git a/crates/ty_python_semantic/src/types.rs b/crates/ty_python_semantic/src/types.rs index 8194ccff75eb3..2571bb6fde117 100644 --- a/crates/ty_python_semantic/src/types.rs +++ b/crates/ty_python_semantic/src/types.rs @@ -240,30 +240,24 @@ pub(crate) type TryBoolVisitor<'db> = CycleDetector, Result>>; pub(crate) struct TryBool; -#[derive(Default, Copy, Clone, Debug)] -pub(crate) enum NormalizationKind<'db> { - #[default] - Normal, - Recursive(Type<'db>), -} +pub(crate) type NormalizedVisitor<'db> = TypeTransformer<'db, Normalized>; +pub(crate) struct Normalized; -/// A [`TypeTransformer`] that is used in `normalized` methods. -#[derive(Default)] -pub(crate) struct NormalizedVisitor<'db> { +/// A [`TypeTransformer`] that is used in `recursive_type_normalized` methods. +/// Calling [`Type::recursive_type_normalized`] will normalize the recursive type. +/// A recursive type here means a type that contains a `Divergent` type. +/// Normalizing recursive types allows recursive type inference for divergent functions to converge. +pub(crate) struct RecursiveTypeNormalizedVisitor<'db> { transformer: TypeTransformer<'db, Normalized>, - /// If this is [`NormalizationKind::Recursive`], calling [`Type::normalized_impl`] will normalize the recursive type. - /// A recursive type here means a type that contains a `Divergent` type. - /// Normalizing recursive types allows recursive type inference for divergent functions to converge. - kind: NormalizationKind<'db>, + div: Type<'db>, } -pub(crate) struct Normalized; -impl<'db> NormalizedVisitor<'db> { - fn recursive(self, div: Type<'db>) -> Self { +impl<'db> RecursiveTypeNormalizedVisitor<'db> { + fn new(div: Type<'db>) -> Self { debug_assert!(matches!(div, Type::Dynamic(DynamicType::Divergent(_)))); Self { - transformer: self.transformer, - kind: NormalizationKind::Recursive(div), + transformer: NormalizedVisitor::default(), + div, } } @@ -617,7 +611,11 @@ impl<'db> PropertyInstanceType<'db> { ) } - fn recursive_type_normalized(self, db: &'db dyn Db, visitor: &NormalizedVisitor<'db>) -> Self { + fn recursive_type_normalized( + self, + db: &'db dyn Db, + visitor: &RecursiveTypeNormalizedVisitor<'db>, + ) -> Self { Self::new( db, self.getter(db) @@ -1302,7 +1300,7 @@ impl<'db> Type<'db> { Type::TypeIs(type_is) => visitor.visit(self, || { type_is.with_type(db, type_is.return_type(db).normalized_impl(db, visitor)) }), - Type::Dynamic(dynamic) => Type::Dynamic(dynamic.normalized_impl(visitor.kind)), + Type::Dynamic(dynamic) => Type::Dynamic(dynamic.normalized_impl()), Type::EnumLiteral(enum_literal) if is_single_member_enum(db, enum_literal.enum_class(db)) => { @@ -1336,16 +1334,14 @@ impl<'db> Type<'db> { pub(crate) fn recursive_type_normalized( self, db: &'db dyn Db, - visitor: &NormalizedVisitor<'db>, + visitor: &RecursiveTypeNormalizedVisitor<'db>, ) -> Self { - if let NormalizationKind::Recursive(div) = visitor.kind { - if visitor.level() == 0 && self == div { - // int | Divergent = int | (int | (int | ...)) = int - return Type::Never; - } else if visitor.level() >= 1 && self.has_divergent_type(db, div) { - // G[G[Divergent]] = G[Divergent] - return div; - } + if visitor.level() == 0 && self == visitor.div { + // int | Divergent = int | (int | (int | ...)) = int + return Type::Never; + } else if visitor.level() >= 1 && self.has_divergent_type(db, visitor.div) { + // G[G[Divergent]] = G[Divergent] + return visitor.div; } match self { Type::Union(union) => { @@ -1403,9 +1399,7 @@ impl<'db> Type<'db> { .recursive_type_normalized(db, visitor), ) }), - Type::Dynamic(dynamic) => { - Type::Dynamic(dynamic.recursive_type_normalized(visitor.kind)) - } + Type::Dynamic(dynamic) => Type::Dynamic(dynamic.recursive_type_normalized()), Type::TypedDict(_) => { // TODO: Normalize TypedDicts self @@ -3135,14 +3129,16 @@ impl<'db> Type<'db> { ty }; - if let Place::Type(div @ Type::Dynamic(DynamicType::Divergent(_)), _) = - class_lookup_cycle_initial(db, self, name, policy).place - { - let visitor = NormalizedVisitor::default().recursive(div); - ty.recursive_type_normalized(db, &visitor) - } else { - ty - } + let div = Type::divergent(DivergentType::new( + db, + DivergenceKind::ClassLookupWithPolicy { + self_type: self, + name, + policy, + }, + )); + let visitor = RecursiveTypeNormalizedVisitor::new(div); + ty.recursive_type_normalized(db, &visitor) }) } @@ -3926,14 +3922,16 @@ impl<'db> Type<'db> { ty }; - if let Place::Type(div @ Type::Dynamic(DynamicType::Divergent(_)), _) = - member_lookup_cycle_initial(db, self, name, policy).place - { - let visotor = NormalizedVisitor::default().recursive(div); - ty.recursive_type_normalized(db, &visotor) - } else { - ty - } + let div = Type::divergent(DivergentType::new( + db, + DivergenceKind::MemberLookupWithPolicy { + self_type: self, + name, + policy, + }, + )); + let visotor = RecursiveTypeNormalizedVisitor::new(div); + ty.recursive_type_normalized(db, &visotor) }) } @@ -7090,7 +7088,11 @@ impl<'db> TypeMapping<'_, 'db> { } } - fn recursive_type_normalized(&self, db: &'db dyn Db, visitor: &NormalizedVisitor<'db>) -> Self { + fn recursive_type_normalized( + &self, + db: &'db dyn Db, + visitor: &RecursiveTypeNormalizedVisitor<'db>, + ) -> Self { match self { TypeMapping::Specialization(specialization) => { TypeMapping::Specialization(specialization.recursive_type_normalized(db, visitor)) @@ -7268,7 +7270,11 @@ impl<'db> KnownInstanceType<'db> { } } - fn recursive_type_normalized(self, db: &'db dyn Db, visitor: &NormalizedVisitor<'db>) -> Self { + fn recursive_type_normalized( + self, + db: &'db dyn Db, + visitor: &RecursiveTypeNormalizedVisitor<'db>, + ) -> Self { match self { Self::SubscriptedProtocol(context) => { Self::SubscriptedProtocol(context.recursive_type_normalized(db, visitor)) @@ -7467,21 +7473,16 @@ pub enum DynamicType<'db> { Divergent(DivergentType<'db>), } -impl<'db> DynamicType<'db> { - fn normalized_impl(self, kind: NormalizationKind<'db>) -> Self { - match kind { - NormalizationKind::Recursive(_) => self, - NormalizationKind::Normal => { - if matches!(self, Self::Divergent(_)) { - self - } else { - Self::Any - } - } +impl DynamicType<'_> { + fn normalized_impl(self) -> Self { + if matches!(self, Self::Divergent(_)) { + self + } else { + Self::Any } } - fn recursive_type_normalized(self, _kind: NormalizationKind<'db>) -> Self { + fn recursive_type_normalized(self) -> Self { self } } @@ -7819,7 +7820,11 @@ impl<'db> FieldInstance<'db> { ) } - fn recursive_type_normalized(self, db: &'db dyn Db, visitor: &NormalizedVisitor<'db>) -> Self { + fn recursive_type_normalized( + self, + db: &'db dyn Db, + visitor: &RecursiveTypeNormalizedVisitor<'db>, + ) -> Self { FieldInstance::new( db, self.default_type(db).recursive_type_normalized(db, visitor), @@ -7999,7 +8004,7 @@ impl<'db> TypeVarInstance<'db> { fn recursive_type_normalized( self, _db: &'db dyn Db, - _visitor: &NormalizedVisitor<'db>, + _visitor: &RecursiveTypeNormalizedVisitor<'db>, ) -> Self { self } @@ -8299,7 +8304,11 @@ impl<'db> BoundTypeVarInstance<'db> { ) } - fn recursive_type_normalized(self, db: &'db dyn Db, visitor: &NormalizedVisitor<'db>) -> Self { + fn recursive_type_normalized( + self, + db: &'db dyn Db, + visitor: &RecursiveTypeNormalizedVisitor<'db>, + ) -> Self { Self::new( db, self.typevar(db).recursive_type_normalized(db, visitor), @@ -9630,7 +9639,11 @@ impl<'db> BoundMethodType<'db> { ) } - fn recursive_type_normalized(self, db: &'db dyn Db, visitor: &NormalizedVisitor<'db>) -> Self { + fn recursive_type_normalized( + self, + db: &'db dyn Db, + visitor: &RecursiveTypeNormalizedVisitor<'db>, + ) -> Self { Self::new( db, self.function(db).recursive_type_normalized(db, visitor), @@ -9766,7 +9779,11 @@ impl<'db> CallableType<'db> { ) } - fn recursive_type_normalized(self, db: &'db dyn Db, visitor: &NormalizedVisitor<'db>) -> Self { + fn recursive_type_normalized( + self, + db: &'db dyn Db, + visitor: &RecursiveTypeNormalizedVisitor<'db>, + ) -> Self { CallableType::new( db, self.signatures(db).recursive_type_normalized(db, visitor), @@ -10004,7 +10021,11 @@ impl<'db> KnownBoundMethodType<'db> { } } - fn recursive_type_normalized(self, db: &'db dyn Db, visitor: &NormalizedVisitor<'db>) -> Self { + fn recursive_type_normalized( + self, + db: &'db dyn Db, + visitor: &RecursiveTypeNormalizedVisitor<'db>, + ) -> Self { match self { KnownBoundMethodType::FunctionTypeDunderGet(function) => { KnownBoundMethodType::FunctionTypeDunderGet( @@ -10437,7 +10458,7 @@ impl<'db> PEP695TypeAliasType<'db> { fn recursive_type_normalized( self, _db: &'db dyn Db, - _visitor: &NormalizedVisitor<'db>, + _visitor: &RecursiveTypeNormalizedVisitor<'db>, ) -> Self { self } @@ -10508,7 +10529,11 @@ impl<'db> ManualPEP695TypeAliasType<'db> { ) } - fn recursive_type_normalized(self, db: &'db dyn Db, visitor: &NormalizedVisitor<'db>) -> Self { + fn recursive_type_normalized( + self, + db: &'db dyn Db, + visitor: &RecursiveTypeNormalizedVisitor<'db>, + ) -> Self { Self::new( db, self.name(db), @@ -10555,7 +10580,11 @@ impl<'db> TypeAliasType<'db> { } } - fn recursive_type_normalized(self, db: &'db dyn Db, visitor: &NormalizedVisitor<'db>) -> Self { + fn recursive_type_normalized( + self, + db: &'db dyn Db, + visitor: &RecursiveTypeNormalizedVisitor<'db>, + ) -> Self { match self { TypeAliasType::PEP695(type_alias) => { TypeAliasType::PEP695(type_alias.recursive_type_normalized(db, visitor)) @@ -10850,7 +10879,7 @@ impl<'db> UnionType<'db> { fn recursive_type_normalized( self, db: &'db dyn Db, - visitor: &NormalizedVisitor<'db>, + visitor: &RecursiveTypeNormalizedVisitor<'db>, ) -> Type<'db> { self.elements(db) .iter() @@ -10978,12 +11007,12 @@ impl<'db> IntersectionType<'db> { pub(crate) fn recursive_type_normalized( self, db: &'db dyn Db, - visitor: &NormalizedVisitor<'db>, + visitor: &RecursiveTypeNormalizedVisitor<'db>, ) -> Self { fn normalized_set<'db>( db: &'db dyn Db, elements: &FxOrderSet>, - visitor: &NormalizedVisitor<'db>, + visitor: &RecursiveTypeNormalizedVisitor<'db>, ) -> FxOrderSet> { elements .iter() @@ -11278,9 +11307,7 @@ pub enum SuperOwnerKind<'db> { impl<'db> SuperOwnerKind<'db> { fn normalized_impl(self, db: &'db dyn Db, visitor: &NormalizedVisitor<'db>) -> Self { match self { - SuperOwnerKind::Dynamic(dynamic) => { - SuperOwnerKind::Dynamic(dynamic.normalized_impl(visitor.kind)) - } + SuperOwnerKind::Dynamic(dynamic) => SuperOwnerKind::Dynamic(dynamic.normalized_impl()), SuperOwnerKind::Class(class) => { SuperOwnerKind::Class(class.normalized_impl(db, visitor)) } @@ -11292,10 +11319,14 @@ impl<'db> SuperOwnerKind<'db> { } } - fn recursive_type_normalized(self, db: &'db dyn Db, visitor: &NormalizedVisitor<'db>) -> Self { + fn recursive_type_normalized( + self, + db: &'db dyn Db, + visitor: &RecursiveTypeNormalizedVisitor<'db>, + ) -> Self { match self { SuperOwnerKind::Dynamic(dynamic) => { - SuperOwnerKind::Dynamic(dynamic.recursive_type_normalized(visitor.kind)) + SuperOwnerKind::Dynamic(dynamic.recursive_type_normalized()) } SuperOwnerKind::Class(class) => { SuperOwnerKind::Class(class.recursive_type_normalized(db, visitor)) @@ -11581,7 +11612,11 @@ impl<'db> BoundSuperType<'db> { ) } - fn recursive_type_normalized(self, db: &'db dyn Db, visitor: &NormalizedVisitor<'db>) -> Self { + fn recursive_type_normalized( + self, + db: &'db dyn Db, + visitor: &RecursiveTypeNormalizedVisitor<'db>, + ) -> Self { Self::new( db, self.pivot_class(db).recursive_type_normalized(db, visitor), @@ -11802,7 +11837,7 @@ pub(crate) mod tests { let union = UnionType::from_elements(&db, [KnownClass::Object.to_instance(&db), div]); assert_eq!(union.display(&db).to_string(), "object"); - let visitor = NormalizedVisitor::default().recursive(div); + let visitor = RecursiveTypeNormalizedVisitor::new(div); let recursice = UnionType::from_elements( &db, [ diff --git a/crates/ty_python_semantic/src/types/class.rs b/crates/ty_python_semantic/src/types/class.rs index 336c9c0d5c0df..36097054412d0 100644 --- a/crates/ty_python_semantic/src/types/class.rs +++ b/crates/ty_python_semantic/src/types/class.rs @@ -29,9 +29,10 @@ use crate::types::{ DataclassParams, DeprecatedInstance, DivergenceKind, DivergentType, FindLegacyTypeVarsVisitor, HasDivergentTypeVisitor, HasRelationToVisitor, IsEquivalentVisitor, KnownInstanceType, ManualPEP695TypeAliasType, MaterializationKind, NormalizedVisitor, PropertyInstanceType, - StringLiteralType, TypeAliasType, TypeContext, TypeMapping, TypeRelation, - TypeVarBoundOrConstraints, TypeVarInstance, TypeVarKind, TypedDictParams, UnionBuilder, - VarianceInferable, declaration_type, determine_upper_bound, infer_definition_types, + RecursiveTypeNormalizedVisitor, StringLiteralType, TypeAliasType, TypeContext, TypeMapping, + TypeRelation, TypeVarBoundOrConstraints, TypeVarInstance, TypeVarKind, TypedDictParams, + UnionBuilder, VarianceInferable, declaration_type, determine_upper_bound, + infer_definition_types, }; use crate::{ Db, FxIndexMap, FxOrderSet, Program, @@ -294,7 +295,7 @@ impl<'db> GenericAlias<'db> { pub(super) fn recursive_type_normalized( self, db: &'db dyn Db, - visitor: &NormalizedVisitor<'db>, + visitor: &RecursiveTypeNormalizedVisitor<'db>, ) -> Self { Self::new( db, @@ -432,7 +433,7 @@ impl<'db> ClassType<'db> { pub(super) fn recursive_type_normalized( self, db: &'db dyn Db, - visitor: &NormalizedVisitor<'db>, + visitor: &RecursiveTypeNormalizedVisitor<'db>, ) -> Self { match self { Self::NonGeneric(_) => self, @@ -2988,7 +2989,7 @@ impl<'db> ClassLiteral<'db> { target_method_decorator, }, ); - let visitor = NormalizedVisitor::default().recursive(Type::divergent(div)); + let visitor = RecursiveTypeNormalizedVisitor::new(Type::divergent(div)); let is_valid_scope = |method_scope: ScopeId<'db>| { if let Some(method_def) = method_scope.node(db).as_function() { diff --git a/crates/ty_python_semantic/src/types/class_base.rs b/crates/ty_python_semantic/src/types/class_base.rs index 48dc8f6e5a6ae..65426b7351289 100644 --- a/crates/ty_python_semantic/src/types/class_base.rs +++ b/crates/ty_python_semantic/src/types/class_base.rs @@ -4,8 +4,8 @@ use crate::types::generics::Specialization; use crate::types::tuple::TupleType; use crate::types::{ ApplyTypeMappingVisitor, ClassLiteral, ClassType, DynamicType, KnownClass, KnownInstanceType, - MaterializationKind, MroError, MroIterator, NormalizedVisitor, SpecialFormType, Type, - TypeMapping, todo_type, + MaterializationKind, MroError, MroIterator, NormalizedVisitor, RecursiveTypeNormalizedVisitor, + SpecialFormType, Type, TypeMapping, todo_type, }; /// Enumeration of the possible kinds of types we allow in class bases. @@ -37,7 +37,7 @@ impl<'db> ClassBase<'db> { pub(crate) fn normalized_impl(self, db: &'db dyn Db, visitor: &NormalizedVisitor<'db>) -> Self { match self { - Self::Dynamic(dynamic) => Self::Dynamic(dynamic.normalized_impl(visitor.kind)), + Self::Dynamic(dynamic) => Self::Dynamic(dynamic.normalized_impl()), Self::Class(class) => Self::Class(class.normalized_impl(db, visitor)), Self::Protocol | Self::Generic | Self::TypedDict => self, } @@ -46,12 +46,10 @@ impl<'db> ClassBase<'db> { pub(super) fn recursive_type_normalized( self, db: &'db dyn Db, - visitor: &NormalizedVisitor<'db>, + visitor: &RecursiveTypeNormalizedVisitor<'db>, ) -> Self { match self { - Self::Dynamic(dynamic) => { - Self::Dynamic(dynamic.recursive_type_normalized(visitor.kind)) - } + Self::Dynamic(dynamic) => Self::Dynamic(dynamic.recursive_type_normalized()), Self::Class(class) => Self::Class(class.recursive_type_normalized(db, visitor)), Self::Protocol | Self::Generic | Self::TypedDict => self, } diff --git a/crates/ty_python_semantic/src/types/function.rs b/crates/ty_python_semantic/src/types/function.rs index b1d20d42aef7d..90f5a78d011f0 100644 --- a/crates/ty_python_semantic/src/types/function.rs +++ b/crates/ty_python_semantic/src/types/function.rs @@ -81,8 +81,9 @@ use crate::types::{ BoundMethodType, BoundTypeVarInstance, CallableType, ClassBase, ClassLiteral, ClassType, DeprecatedInstance, DivergenceKind, DivergentType, DynamicType, FindLegacyTypeVarsVisitor, HasRelationToVisitor, IsEquivalentVisitor, KnownClass, KnownInstanceType, NormalizedVisitor, - SpecialFormType, TrackedConstraintSet, Truthiness, Type, TypeMapping, TypeRelation, - UnionBuilder, all_members, binding_type, todo_type, walk_generic_context, walk_type_mapping, + RecursiveTypeNormalizedVisitor, SpecialFormType, TrackedConstraintSet, Truthiness, Type, + TypeMapping, TypeRelation, UnionBuilder, all_members, binding_type, todo_type, + walk_generic_context, walk_type_mapping, }; use crate::{Db, FxOrderSet, ModuleName, resolve_module}; @@ -700,7 +701,11 @@ impl<'db> FunctionLiteral<'db> { Self::new(db, self.last_definition(db), context) } - fn recursive_type_normalized(self, db: &'db dyn Db, visitor: &NormalizedVisitor<'db>) -> Self { + fn recursive_type_normalized( + self, + db: &'db dyn Db, + visitor: &RecursiveTypeNormalizedVisitor<'db>, + ) -> Self { let context = self .inherited_generic_context(db) .map(|ctx| ctx.recursive_type_normalized(db, visitor)); @@ -1052,7 +1057,7 @@ impl<'db> FunctionType<'db> { pub(super) fn recursive_type_normalized( self, db: &'db dyn Db, - visitor: &NormalizedVisitor<'db>, + visitor: &RecursiveTypeNormalizedVisitor<'db>, ) -> Self { let mappings: Box<_> = self .type_mappings(db) diff --git a/crates/ty_python_semantic/src/types/generics.rs b/crates/ty_python_semantic/src/types/generics.rs index 195fde5f6783a..365e46a5f75cb 100644 --- a/crates/ty_python_semantic/src/types/generics.rs +++ b/crates/ty_python_semantic/src/types/generics.rs @@ -18,9 +18,9 @@ use crate::types::tuple::{TupleSpec, TupleType, walk_tuple_type}; use crate::types::{ ApplyTypeMappingVisitor, BoundTypeVarInstance, FindLegacyTypeVarsVisitor, HasDivergentTypeVisitor, HasRelationToVisitor, IsEquivalentVisitor, KnownClass, - KnownInstanceType, MaterializationKind, NormalizedVisitor, Type, TypeMapping, TypeRelation, - TypeVarBoundOrConstraints, TypeVarInstance, TypeVarVariance, UnionType, binding_type, - declaration_type, + KnownInstanceType, MaterializationKind, NormalizedVisitor, RecursiveTypeNormalizedVisitor, + Type, TypeMapping, TypeRelation, TypeVarBoundOrConstraints, TypeVarInstance, TypeVarVariance, + UnionType, binding_type, declaration_type, }; use crate::{Db, FxOrderSet}; @@ -403,7 +403,7 @@ impl<'db> GenericContext<'db> { pub(super) fn recursive_type_normalized( self, db: &'db dyn Db, - visitor: &NormalizedVisitor<'db>, + visitor: &RecursiveTypeNormalizedVisitor<'db>, ) -> Self { let variables = self .variables(db) @@ -746,7 +746,7 @@ impl<'db> Specialization<'db> { pub(super) fn recursive_type_normalized( self, db: &'db dyn Db, - visitor: &NormalizedVisitor<'db>, + visitor: &RecursiveTypeNormalizedVisitor<'db>, ) -> Self { let types: Box<[_]> = self .types(db) @@ -1045,7 +1045,7 @@ impl<'db> PartialSpecialization<'_, 'db> { pub(super) fn recursive_type_normalized( &self, db: &'db dyn Db, - visitor: &NormalizedVisitor<'db>, + visitor: &RecursiveTypeNormalizedVisitor<'db>, ) -> PartialSpecialization<'db, 'db> { let generic_context = self.generic_context.recursive_type_normalized(db, visitor); let types: Cow<_> = self diff --git a/crates/ty_python_semantic/src/types/infer.rs b/crates/ty_python_semantic/src/types/infer.rs index 29ed1a02dcc09..c62ba9013c91f 100644 --- a/crates/ty_python_semantic/src/types/infer.rs +++ b/crates/ty_python_semantic/src/types/infer.rs @@ -51,8 +51,8 @@ use crate::semantic_index::{SemanticIndex, semantic_index, use_def_map}; use crate::types::diagnostic::TypeCheckDiagnostics; use crate::types::unpacker::{UnpackResult, Unpacker}; use crate::types::{ - ClassLiteral, CycleRecoveryType, DivergenceKind, DivergentType, NormalizedVisitor, Truthiness, - Type, TypeAndQualifiers, UnionBuilder, UnionType, + ClassLiteral, CycleRecoveryType, DivergenceKind, DivergentType, RecursiveTypeNormalizedVisitor, + Truthiness, Type, TypeAndQualifiers, UnionBuilder, UnionType, }; use crate::unpack::Unpack; use builder::TypeInferenceBuilder; @@ -285,7 +285,7 @@ fn infer_expression_type_impl<'db>(db: &'db dyn Db, input: InferExpression<'db>) DivergenceKind::InferExpression(input), )); let previous_cycle_value = infer_expression_type_impl(db, input); - let visitor = NormalizedVisitor::default().recursive(div); + let visitor = RecursiveTypeNormalizedVisitor::new(div); let result_ty = inference.expression_type(input.expression(db).node_ref(db, &module)); UnionType::from_elements(db, [result_ty, previous_cycle_value]) .recursive_type_normalized(db, &visitor) @@ -595,7 +595,7 @@ impl<'db> ScopeInference<'db> { if let Some(fall_back) = self.fallback_type() { union = union.add(fall_back); } - let visitor = NormalizedVisitor::default().recursive(div); + let visitor = RecursiveTypeNormalizedVisitor::new(div); // Here, we use the dynamic type `Divergent` to detect divergent type inference and ensure that we obtain finite results. // For example, consider the following recursive function: // ```py diff --git a/crates/ty_python_semantic/src/types/infer/builder.rs b/crates/ty_python_semantic/src/types/infer/builder.rs index 48c68c4fc6d96..7aa3d19517160 100644 --- a/crates/ty_python_semantic/src/types/infer/builder.rs +++ b/crates/ty_python_semantic/src/types/infer/builder.rs @@ -89,10 +89,11 @@ use crate::types::{ CallDunderError, CallableType, ClassLiteral, ClassType, CycleRecoveryType, DataclassParams, DivergenceKind, DivergentType, DynamicType, InferExpression, IntersectionBuilder, IntersectionType, KnownClass, KnownInstanceType, MemberLookupPolicy, MetaclassCandidate, - NormalizedVisitor, PEP695TypeAliasType, Parameter, ParameterForm, Parameters, SpecialFormType, - SubclassOfType, TrackedConstraintSet, Truthiness, Type, TypeAliasType, TypeAndQualifiers, - TypeContext, TypeQualifiers, TypeVarBoundOrConstraintsEvaluation, TypeVarDefaultEvaluation, - TypeVarInstance, TypeVarKind, UnionBuilder, UnionType, binding_type, todo_type, + PEP695TypeAliasType, Parameter, ParameterForm, Parameters, RecursiveTypeNormalizedVisitor, + SpecialFormType, SubclassOfType, TrackedConstraintSet, Truthiness, Type, TypeAliasType, + TypeAndQualifiers, TypeContext, TypeQualifiers, TypeVarBoundOrConstraintsEvaluation, + TypeVarDefaultEvaluation, TypeVarInstance, TypeVarKind, UnionBuilder, UnionType, binding_type, + todo_type, }; use crate::types::{ClassBase, add_inferred_python_version_hint_to_diagnostic}; use crate::unpack::{EvaluationMode, UnpackPosition}; @@ -8920,7 +8921,7 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> { db, DivergenceKind::InferExpressionTypes(input), )); - let visitor = NormalizedVisitor::default().recursive(div); + let visitor = RecursiveTypeNormalizedVisitor::new(div); let previous_cycle_value = infer_expression_types_impl(db, input); let extra = (cycle_recovery.is_some() || !bindings.is_empty() || !diagnostics.is_empty() || !all_definitely_bound).then(|| { @@ -9033,7 +9034,7 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> { db, DivergenceKind::InferDefinitionTypes(definition), )); - let visitor = NormalizedVisitor::default().recursive(div); + let visitor = RecursiveTypeNormalizedVisitor::new(div); let previous_cycle_value = infer_definition_types(db, definition); for (expr, ty) in &mut expressions { let previous_ty = previous_cycle_value.expression_type(*expr); @@ -9118,7 +9119,7 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> { db, DivergenceKind::InferScopeTypes(scope), )); - let visitor = NormalizedVisitor::default().recursive(div); + let visitor = RecursiveTypeNormalizedVisitor::new(div); let previous_cycle_value = infer_scope_types(db, scope); for (expr, ty) in &mut expressions { diff --git a/crates/ty_python_semantic/src/types/instance.rs b/crates/ty_python_semantic/src/types/instance.rs index 81b5b2b1105a7..abb8aba6421de 100644 --- a/crates/ty_python_semantic/src/types/instance.rs +++ b/crates/ty_python_semantic/src/types/instance.rs @@ -14,7 +14,8 @@ use crate::types::tuple::{TupleSpec, TupleType}; use crate::types::{ ApplyTypeMappingVisitor, ClassBase, ClassLiteral, FindLegacyTypeVarsVisitor, HasDivergentTypeVisitor, HasRelationToVisitor, IsDisjointVisitor, IsEquivalentVisitor, - NormalizedVisitor, TypeMapping, TypeRelation, VarianceInferable, + NormalizedVisitor, RecursiveTypeNormalizedVisitor, TypeMapping, TypeRelation, + VarianceInferable, }; use crate::{Db, FxOrderSet}; @@ -339,7 +340,7 @@ impl<'db> NominalInstanceType<'db> { pub(super) fn recursive_type_normalized( self, db: &'db dyn Db, - visitor: &NormalizedVisitor<'db>, + visitor: &RecursiveTypeNormalizedVisitor<'db>, ) -> Self { match self.0 { NominalInstanceInner::ExactTuple(tuple) => Self(NominalInstanceInner::ExactTuple( @@ -654,7 +655,7 @@ impl<'db> ProtocolInstanceType<'db> { pub(super) fn recursive_type_normalized( self, _db: &'db dyn Db, - _visitor: &NormalizedVisitor<'db>, + _visitor: &RecursiveTypeNormalizedVisitor<'db>, ) -> Self { self } diff --git a/crates/ty_python_semantic/src/types/signatures.rs b/crates/ty_python_semantic/src/types/signatures.rs index d449872c4b9d1..e758eb4d1142c 100644 --- a/crates/ty_python_semantic/src/types/signatures.rs +++ b/crates/ty_python_semantic/src/types/signatures.rs @@ -22,8 +22,8 @@ use crate::types::generics::{GenericContext, walk_generic_context}; use crate::types::{ ApplyTypeMappingVisitor, BindingContext, BoundTypeVarInstance, FindLegacyTypeVarsVisitor, HasDivergentTypeVisitor, HasRelationToVisitor, IsEquivalentVisitor, KnownClass, - MaterializationKind, NormalizedVisitor, TypeMapping, TypeRelation, VarianceInferable, - todo_type, + MaterializationKind, NormalizedVisitor, RecursiveTypeNormalizedVisitor, TypeMapping, + TypeRelation, VarianceInferable, todo_type, }; use crate::{Db, FxOrderSet}; use ruff_python_ast::{self as ast, name::Name}; @@ -78,7 +78,7 @@ impl<'db> CallableSignature<'db> { pub(super) fn recursive_type_normalized( &self, db: &'db dyn Db, - visitor: &NormalizedVisitor<'db>, + visitor: &RecursiveTypeNormalizedVisitor<'db>, ) -> Self { Self::from_overloads( self.overloads @@ -476,7 +476,7 @@ impl<'db> Signature<'db> { pub(super) fn recursive_type_normalized( &self, db: &'db dyn Db, - visitor: &NormalizedVisitor<'db>, + visitor: &RecursiveTypeNormalizedVisitor<'db>, ) -> Self { Self { generic_context: self @@ -1614,7 +1614,7 @@ impl<'db> Parameter<'db> { pub(super) fn recursive_type_normalized( &self, db: &'db dyn Db, - visitor: &NormalizedVisitor<'db>, + visitor: &RecursiveTypeNormalizedVisitor<'db>, ) -> Self { let Parameter { annotated_type, diff --git a/crates/ty_python_semantic/src/types/subclass_of.rs b/crates/ty_python_semantic/src/types/subclass_of.rs index 73bacfbe4d99d..94704aa6455b6 100644 --- a/crates/ty_python_semantic/src/types/subclass_of.rs +++ b/crates/ty_python_semantic/src/types/subclass_of.rs @@ -5,8 +5,8 @@ use crate::types::variance::VarianceInferable; use crate::types::{ ApplyTypeMappingVisitor, BoundTypeVarInstance, ClassType, DynamicType, FindLegacyTypeVarsVisitor, HasDivergentTypeVisitor, HasRelationToVisitor, IsDisjointVisitor, - KnownClass, MaterializationKind, MemberLookupPolicy, NormalizedVisitor, SpecialFormType, Type, - TypeMapping, TypeRelation, + KnownClass, MaterializationKind, MemberLookupPolicy, NormalizedVisitor, + RecursiveTypeNormalizedVisitor, SpecialFormType, Type, TypeMapping, TypeRelation, }; use crate::{Db, FxOrderSet}; @@ -184,7 +184,7 @@ impl<'db> SubclassOfType<'db> { pub(super) fn recursive_type_normalized( self, db: &'db dyn Db, - visitor: &NormalizedVisitor<'db>, + visitor: &RecursiveTypeNormalizedVisitor<'db>, ) -> Self { Self { subclass_of: self.subclass_of.recursive_type_normalized(db, visitor), @@ -273,20 +273,18 @@ impl<'db> SubclassOfInner<'db> { pub(crate) fn normalized_impl(self, db: &'db dyn Db, visitor: &NormalizedVisitor<'db>) -> Self { match self { Self::Class(class) => Self::Class(class.normalized_impl(db, visitor)), - Self::Dynamic(dynamic) => Self::Dynamic(dynamic.normalized_impl(visitor.kind)), + Self::Dynamic(dynamic) => Self::Dynamic(dynamic.normalized_impl()), } } pub(super) fn recursive_type_normalized( self, db: &'db dyn Db, - visitor: &NormalizedVisitor<'db>, + visitor: &RecursiveTypeNormalizedVisitor<'db>, ) -> Self { match self { Self::Class(class) => Self::Class(class.recursive_type_normalized(db, visitor)), - Self::Dynamic(dynamic) => { - Self::Dynamic(dynamic.recursive_type_normalized(visitor.kind)) - } + Self::Dynamic(dynamic) => Self::Dynamic(dynamic.recursive_type_normalized()), } } diff --git a/crates/ty_python_semantic/src/types/tuple.rs b/crates/ty_python_semantic/src/types/tuple.rs index 6b651a4b7fe8a..6103d0f521d24 100644 --- a/crates/ty_python_semantic/src/types/tuple.rs +++ b/crates/ty_python_semantic/src/types/tuple.rs @@ -26,8 +26,8 @@ use crate::types::class::{ClassType, KnownClass}; use crate::types::constraints::{ConstraintSet, IteratorConstraintsExtension}; use crate::types::{ ApplyTypeMappingVisitor, BoundTypeVarInstance, FindLegacyTypeVarsVisitor, HasRelationToVisitor, - IsDisjointVisitor, IsEquivalentVisitor, NormalizedVisitor, Type, TypeMapping, TypeRelation, - UnionBuilder, UnionType, + IsDisjointVisitor, IsEquivalentVisitor, NormalizedVisitor, RecursiveTypeNormalizedVisitor, + Type, TypeMapping, TypeRelation, UnionBuilder, UnionType, }; use crate::types::{HasDivergentTypeVisitor, Truthiness}; use crate::util::subscript::{Nth, OutOfBoundsError, PyIndex, PySlice, StepSizeZeroError}; @@ -231,7 +231,7 @@ impl<'db> TupleType<'db> { pub(super) fn recursive_type_normalized( self, db: &'db dyn Db, - visitor: &NormalizedVisitor<'db>, + visitor: &RecursiveTypeNormalizedVisitor<'db>, ) -> Self { Self::new_internal(db, self.tuple(db).recursive_type_normalized(db, visitor)) } @@ -405,7 +405,11 @@ impl<'db> FixedLengthTuple> { Self::from_elements(self.0.iter().map(|ty| ty.normalized_impl(db, visitor))) } - fn recursive_type_normalized(&self, db: &'db dyn Db, visitor: &NormalizedVisitor<'db>) -> Self { + fn recursive_type_normalized( + &self, + db: &'db dyn Db, + visitor: &RecursiveTypeNormalizedVisitor<'db>, + ) -> Self { Self::from_elements( self.0 .iter() @@ -726,7 +730,11 @@ impl<'db> VariableLengthTuple> { }) } - fn recursive_type_normalized(&self, db: &'db dyn Db, visitor: &NormalizedVisitor<'db>) -> Self { + fn recursive_type_normalized( + &self, + db: &'db dyn Db, + visitor: &RecursiveTypeNormalizedVisitor<'db>, + ) -> Self { let prefix = self .prefix .iter() @@ -1089,7 +1097,7 @@ impl<'db> Tuple> { pub(super) fn recursive_type_normalized( &self, db: &'db dyn Db, - visitor: &NormalizedVisitor<'db>, + visitor: &RecursiveTypeNormalizedVisitor<'db>, ) -> Self { match self { Tuple::Fixed(tuple) => Tuple::Fixed(tuple.recursive_type_normalized(db, visitor)), diff --git a/crates/ty_python_semantic/src/types/unpacker.rs b/crates/ty_python_semantic/src/types/unpacker.rs index e7bc145b8d6fb..1735620d05831 100644 --- a/crates/ty_python_semantic/src/types/unpacker.rs +++ b/crates/ty_python_semantic/src/types/unpacker.rs @@ -11,8 +11,8 @@ use crate::semantic_index::scope::ScopeId; use crate::types::infer::{InferExpression, infer_expression_types_impl, infer_unpack_types}; use crate::types::tuple::{ResizeTupleError, Tuple, TupleLength, TupleSpec, TupleUnpacker}; use crate::types::{ - DivergenceKind, DivergentType, NormalizedVisitor, Type, TypeCheckDiagnostics, TypeContext, - UnionType, + DivergenceKind, DivergentType, RecursiveTypeNormalizedVisitor, Type, TypeCheckDiagnostics, + TypeContext, UnionType, }; use crate::unpack::{Unpack, UnpackKind, UnpackValue}; @@ -185,7 +185,7 @@ impl<'db, 'ast> Unpacker<'db, 'ast> { DivergenceKind::InferUnpackTypes(unpack), )); let previous_cycle_value = infer_unpack_types(db, unpack); - let visitor = NormalizedVisitor::default().recursive(div); + let visitor = RecursiveTypeNormalizedVisitor::new(div); for (expr, ty) in &mut self.targets { let previous_ty = previous_cycle_value.expression_type(*expr); *ty = UnionType::from_elements(db, [*ty, previous_ty]) From f40a29ec89a0f7127f7470202fe4350eb01f81b4 Mon Sep 17 00:00:00 2001 From: Shunsuke Shibayama Date: Wed, 17 Sep 2025 11:30:36 +0900 Subject: [PATCH 077/105] refactor --- crates/ty_python_semantic/src/semantic_index/symbol.rs | 2 +- crates/ty_python_semantic/src/types.rs | 6 +++--- crates/ty_python_semantic/src/types/class_base.rs | 2 +- crates/ty_python_semantic/src/types/infer/builder.rs | 6 +++--- crates/ty_python_semantic/src/types/subclass_of.rs | 2 +- 5 files changed, 9 insertions(+), 9 deletions(-) diff --git a/crates/ty_python_semantic/src/semantic_index/symbol.rs b/crates/ty_python_semantic/src/semantic_index/symbol.rs index a0f3c33b090dd..400165485ed46 100644 --- a/crates/ty_python_semantic/src/semantic_index/symbol.rs +++ b/crates/ty_python_semantic/src/semantic_index/symbol.rs @@ -149,7 +149,7 @@ impl Symbol { /// /// Allows lookup by name and a symbol's ID. #[derive(Default, get_size2::GetSize)] -pub(crate) struct SymbolTable { +pub(super) struct SymbolTable { symbols: IndexVec, /// Map from symbol name to its ID. diff --git a/crates/ty_python_semantic/src/types.rs b/crates/ty_python_semantic/src/types.rs index 2571bb6fde117..75098b0d6ceb8 100644 --- a/crates/ty_python_semantic/src/types.rs +++ b/crates/ty_python_semantic/src/types.rs @@ -1300,7 +1300,7 @@ impl<'db> Type<'db> { Type::TypeIs(type_is) => visitor.visit(self, || { type_is.with_type(db, type_is.return_type(db).normalized_impl(db, visitor)) }), - Type::Dynamic(dynamic) => Type::Dynamic(dynamic.normalized_impl()), + Type::Dynamic(dynamic) => Type::Dynamic(dynamic.normalized()), Type::EnumLiteral(enum_literal) if is_single_member_enum(db, enum_literal.enum_class(db)) => { @@ -7474,7 +7474,7 @@ pub enum DynamicType<'db> { } impl DynamicType<'_> { - fn normalized_impl(self) -> Self { + fn normalized(self) -> Self { if matches!(self, Self::Divergent(_)) { self } else { @@ -11307,7 +11307,7 @@ pub enum SuperOwnerKind<'db> { impl<'db> SuperOwnerKind<'db> { fn normalized_impl(self, db: &'db dyn Db, visitor: &NormalizedVisitor<'db>) -> Self { match self { - SuperOwnerKind::Dynamic(dynamic) => SuperOwnerKind::Dynamic(dynamic.normalized_impl()), + SuperOwnerKind::Dynamic(dynamic) => SuperOwnerKind::Dynamic(dynamic.normalized()), SuperOwnerKind::Class(class) => { SuperOwnerKind::Class(class.normalized_impl(db, visitor)) } diff --git a/crates/ty_python_semantic/src/types/class_base.rs b/crates/ty_python_semantic/src/types/class_base.rs index 65426b7351289..5f82cae00a4c7 100644 --- a/crates/ty_python_semantic/src/types/class_base.rs +++ b/crates/ty_python_semantic/src/types/class_base.rs @@ -37,7 +37,7 @@ impl<'db> ClassBase<'db> { pub(crate) fn normalized_impl(self, db: &'db dyn Db, visitor: &NormalizedVisitor<'db>) -> Self { match self { - Self::Dynamic(dynamic) => Self::Dynamic(dynamic.normalized_impl()), + Self::Dynamic(dynamic) => Self::Dynamic(dynamic.normalized()), Self::Class(class) => Self::Class(class.normalized_impl(db, visitor)), Self::Protocol | Self::Generic | Self::TypedDict => self, } diff --git a/crates/ty_python_semantic/src/types/infer/builder.rs b/crates/ty_python_semantic/src/types/infer/builder.rs index 7aa3d19517160..8b75fa901645d 100644 --- a/crates/ty_python_semantic/src/types/infer/builder.rs +++ b/crates/ty_python_semantic/src/types/infer/builder.rs @@ -301,7 +301,7 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> { self.cycle_recovery } - fn merge_cycle_recovery(&mut self, other: Option>) { + fn extend_cycle_recovery(&mut self, other: Option>) { if let Some(other) = other { match self.cycle_recovery { Some(existing) => { @@ -327,7 +327,7 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> { } if let Some(extra) = &inference.extra { - self.merge_cycle_recovery(extra.cycle_recovery); + self.extend_cycle_recovery(extra.cycle_recovery); self.context.extend(&extra.diagnostics); self.deferred.extend(extra.deferred.iter().copied()); } @@ -345,7 +345,7 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> { if let Some(extra) = &inference.extra { self.context.extend(&extra.diagnostics); - self.merge_cycle_recovery(extra.cycle_recovery); + self.extend_cycle_recovery(extra.cycle_recovery); if !matches!(self.region, InferenceRegion::Scope(..)) { self.bindings.extend(extra.bindings.iter().copied()); diff --git a/crates/ty_python_semantic/src/types/subclass_of.rs b/crates/ty_python_semantic/src/types/subclass_of.rs index 94704aa6455b6..d7aa20c4f16ee 100644 --- a/crates/ty_python_semantic/src/types/subclass_of.rs +++ b/crates/ty_python_semantic/src/types/subclass_of.rs @@ -273,7 +273,7 @@ impl<'db> SubclassOfInner<'db> { pub(crate) fn normalized_impl(self, db: &'db dyn Db, visitor: &NormalizedVisitor<'db>) -> Self { match self { Self::Class(class) => Self::Class(class.normalized_impl(db, visitor)), - Self::Dynamic(dynamic) => Self::Dynamic(dynamic.normalized_impl()), + Self::Dynamic(dynamic) => Self::Dynamic(dynamic.normalized()), } } From 61816b4655e1395e8c2cce0fb3bfa733e9246cc4 Mon Sep 17 00:00:00 2001 From: Shunsuke Shibayama Date: Wed, 17 Sep 2025 11:36:33 +0900 Subject: [PATCH 078/105] remove unnecessary code --- .../src/semantic_index/definition.rs | 1 - .../src/semantic_index/expression.rs | 1 - crates/ty_python_semantic/src/types.rs | 1 + crates/ty_python_semantic/src/types/infer.rs | 12 +- .../src/types/type_ordering.rs | 115 +----------------- 5 files changed, 8 insertions(+), 122 deletions(-) diff --git a/crates/ty_python_semantic/src/semantic_index/definition.rs b/crates/ty_python_semantic/src/semantic_index/definition.rs index 46065649aae4e..b06390fa8479e 100644 --- a/crates/ty_python_semantic/src/semantic_index/definition.rs +++ b/crates/ty_python_semantic/src/semantic_index/definition.rs @@ -23,7 +23,6 @@ use crate::unpack::{Unpack, UnpackPosition}; /// before this `Definition`. However, the ID can be considered stable and it is okay to use /// `Definition` in cross-module` salsa queries or as a field on other salsa tracked structs. #[salsa::tracked(debug, heap_size=ruff_memory_usage::heap_size)] -#[derive(PartialOrd, Ord)] pub struct Definition<'db> { /// The file in which the definition occurs. pub file: File, diff --git a/crates/ty_python_semantic/src/semantic_index/expression.rs b/crates/ty_python_semantic/src/semantic_index/expression.rs index 4c598153ddf48..3f6f159d179f9 100644 --- a/crates/ty_python_semantic/src/semantic_index/expression.rs +++ b/crates/ty_python_semantic/src/semantic_index/expression.rs @@ -32,7 +32,6 @@ pub(crate) enum ExpressionKind { /// * a field of a type that is a return type of a cross-module query /// * an argument of a cross-module query #[salsa::tracked(debug, heap_size=ruff_memory_usage::heap_size)] -#[derive(PartialOrd, Ord)] pub(crate) struct Expression<'db> { /// The file in which the expression occurs. pub(crate) file: File, diff --git a/crates/ty_python_semantic/src/types.rs b/crates/ty_python_semantic/src/types.rs index 75098b0d6ceb8..61e6eb9524a4b 100644 --- a/crates/ty_python_semantic/src/types.rs +++ b/crates/ty_python_semantic/src/types.rs @@ -7436,6 +7436,7 @@ pub(crate) type CycleRecoveryType<'db> = Type<'db>; /// Otherwise, type inference cannot converge properly. /// For detailed properties of this type, see the unit test at the end of the file. #[salsa::interned(debug, heap_size=ruff_memory_usage::heap_size)] +#[derive(PartialOrd, Ord)] pub struct DivergentType<'db> { /// The kind of divergence. #[returns(ref)] diff --git a/crates/ty_python_semantic/src/types/infer.rs b/crates/ty_python_semantic/src/types/infer.rs index c62ba9013c91f..c6148e8aa937e 100644 --- a/crates/ty_python_semantic/src/types/infer.rs +++ b/crates/ty_python_semantic/src/types/infer.rs @@ -312,17 +312,7 @@ fn single_expression_cycle_initial<'db>(db: &'db dyn Db, input: InferExpression< /// This is a Salsa supertype used as the input to `infer_expression_types` to avoid /// interning an `ExpressionWithContext` unnecessarily when no type context is provided. #[derive( - Debug, - Clone, - Copy, - Eq, - Hash, - PartialEq, - PartialOrd, - Ord, - salsa::Supertype, - salsa::Update, - get_size2::GetSize, + Debug, Clone, Copy, Eq, Hash, PartialEq, salsa::Supertype, salsa::Update, get_size2::GetSize, )] pub(super) enum InferExpression<'db> { Bare(Expression<'db>), diff --git a/crates/ty_python_semantic/src/types/type_ordering.rs b/crates/ty_python_semantic/src/types/type_ordering.rs index d505a4a5dd5b6..e2c4b3fb22e16 100644 --- a/crates/ty_python_semantic/src/types/type_ordering.rs +++ b/crates/ty_python_semantic/src/types/type_ordering.rs @@ -1,7 +1,6 @@ use std::cmp::Ordering; use crate::db::Db; -use crate::types::{DivergenceKind, DivergentType}; use super::{ DynamicType, SuperOwnerKind, TodoType, Type, TypeIsType, class_base::ClassBase, @@ -119,7 +118,7 @@ pub(super) fn union_or_intersection_elements_ordering<'db>( (SubclassOfInner::Class(_), _) => Ordering::Less, (_, SubclassOfInner::Class(_)) => Ordering::Greater, (SubclassOfInner::Dynamic(left), SubclassOfInner::Dynamic(right)) => { - dynamic_elements_ordering(db, left, right) + dynamic_elements_ordering(left, right) } } } @@ -173,7 +172,7 @@ pub(super) fn union_or_intersection_elements_ordering<'db>( (_, ClassBase::TypedDict) => Ordering::Greater, (ClassBase::Dynamic(left), ClassBase::Dynamic(right)) => { - dynamic_elements_ordering(db, left, right) + dynamic_elements_ordering(left, right) } }) .then_with(|| match (left.owner(db), right.owner(db)) { @@ -186,7 +185,7 @@ pub(super) fn union_or_intersection_elements_ordering<'db>( (SuperOwnerKind::Instance(_), _) => Ordering::Less, (_, SuperOwnerKind::Instance(_)) => Ordering::Greater, (SuperOwnerKind::Dynamic(left), SuperOwnerKind::Dynamic(right)) => { - dynamic_elements_ordering(db, left, right) + dynamic_elements_ordering(left, right) } }) } @@ -205,7 +204,7 @@ pub(super) fn union_or_intersection_elements_ordering<'db>( (Type::PropertyInstance(_), _) => Ordering::Less, (_, Type::PropertyInstance(_)) => Ordering::Greater, - (Type::Dynamic(left), Type::Dynamic(right)) => dynamic_elements_ordering(db, *left, *right), + (Type::Dynamic(left), Type::Dynamic(right)) => dynamic_elements_ordering(*left, *right), (Type::Dynamic(_), _) => Ordering::Less, (_, Type::Dynamic(_)) => Ordering::Greater, @@ -254,11 +253,7 @@ pub(super) fn union_or_intersection_elements_ordering<'db>( } /// Determine a canonical order for two instances of [`DynamicType`]. -fn dynamic_elements_ordering<'db>( - db: &'db dyn Db, - left: DynamicType<'db>, - right: DynamicType<'db>, -) -> Ordering { +fn dynamic_elements_ordering<'db>(left: DynamicType<'db>, right: DynamicType<'db>) -> Ordering { match (left, right) { (DynamicType::Any, _) => Ordering::Less, (_, DynamicType::Any) => Ordering::Greater, @@ -281,9 +276,7 @@ fn dynamic_elements_ordering<'db>( (DynamicType::TodoTypeAlias, _) => Ordering::Less, (_, DynamicType::TodoTypeAlias) => Ordering::Greater, - (DynamicType::Divergent(left), DynamicType::Divergent(right)) => { - divergent_ordering(db, left, right) - } + (DynamicType::Divergent(left), DynamicType::Divergent(right)) => left.cmp(&right), (DynamicType::Divergent(_), _) => Ordering::Less, (_, DynamicType::Divergent(_)) => Ordering::Greater, } @@ -310,99 +303,3 @@ fn typeis_ordering(db: &dyn Db, left: TypeIsType, right: TypeIsType) -> Ordering }, } } - -fn divergent_ordering<'db>( - db: &'db dyn Db, - left: DivergentType<'db>, - right: DivergentType<'db>, -) -> Ordering { - match (left.kind(db), right.kind(db)) { - (DivergenceKind::InferReturnType(left), DivergenceKind::InferReturnType(right)) => { - left.cmp(right) - } - (DivergenceKind::InferReturnType(_), _) => Ordering::Less, - (_, DivergenceKind::InferReturnType(_)) => Ordering::Greater, - - ( - DivergenceKind::ImplicitAttribute { - class_body_scope: left_scope, - name: left_name, - target_method_decorator: left_deco, - }, - DivergenceKind::ImplicitAttribute { - class_body_scope: right_scope, - name: right_name, - target_method_decorator: right_deco, - }, - ) => left_scope - .cmp(right_scope) - .then_with(|| left_name.cmp(right_name)) - .then_with(|| left_deco.cmp(right_deco)), - (DivergenceKind::ImplicitAttribute { .. }, _) => Ordering::Less, - (_, DivergenceKind::ImplicitAttribute { .. }) => Ordering::Greater, - - ( - DivergenceKind::MemberLookupWithPolicy { - self_type: left_ty, - name: left_name, - policy: left_policy, - }, - DivergenceKind::MemberLookupWithPolicy { - self_type: right_ty, - name: right_name, - policy: right_policy, - }, - ) => union_or_intersection_elements_ordering(db, left_ty, right_ty) - .then_with(|| left_name.cmp(right_name)) - .then_with(|| left_policy.cmp(right_policy)), - (DivergenceKind::MemberLookupWithPolicy { .. }, _) => Ordering::Less, - (_, DivergenceKind::MemberLookupWithPolicy { .. }) => Ordering::Greater, - - ( - DivergenceKind::ClassLookupWithPolicy { - self_type: left_ty, - name: left_name, - policy: left_policy, - }, - DivergenceKind::ClassLookupWithPolicy { - self_type: right_ty, - name: right_name, - policy: right_policy, - }, - ) => union_or_intersection_elements_ordering(db, left_ty, right_ty) - .then_with(|| left_name.cmp(right_name)) - .then_with(|| left_policy.cmp(right_policy)), - (DivergenceKind::ClassLookupWithPolicy { .. }, _) => Ordering::Less, - (_, DivergenceKind::ClassLookupWithPolicy { .. }) => Ordering::Greater, - - (DivergenceKind::InferExpression(left), DivergenceKind::InferExpression(right)) => { - left.cmp(right) - } - (DivergenceKind::InferExpression(_), _) => Ordering::Less, - (_, DivergenceKind::InferExpression(_)) => Ordering::Greater, - - ( - DivergenceKind::InferExpressionTypes(left), - DivergenceKind::InferExpressionTypes(right), - ) => left.cmp(right), - (DivergenceKind::InferExpressionTypes(_), _) => Ordering::Less, - (_, DivergenceKind::InferExpressionTypes(_)) => Ordering::Greater, - - ( - DivergenceKind::InferDefinitionTypes(left), - DivergenceKind::InferDefinitionTypes(right), - ) => left.cmp(right), - (DivergenceKind::InferDefinitionTypes(_), _) => Ordering::Less, - (_, DivergenceKind::InferDefinitionTypes(_)) => Ordering::Greater, - - (DivergenceKind::InferScopeTypes(left), DivergenceKind::InferScopeTypes(right)) => { - left.cmp(right) - } - (DivergenceKind::InferScopeTypes(_), _) => Ordering::Less, - (_, DivergenceKind::InferScopeTypes(_)) => Ordering::Greater, - - (DivergenceKind::InferUnpackTypes(left), DivergenceKind::InferUnpackTypes(right)) => { - left.cmp(right) - } - } -} From 8593f11170298855490196e7593eed44b6cc6f39 Mon Sep 17 00:00:00 2001 From: Shunsuke Shibayama Date: Wed, 17 Sep 2025 12:16:02 +0900 Subject: [PATCH 079/105] prevent unintended `Divergent` types from appearing in the inference results --- crates/ty_python_semantic/resources/mdtest/cycle.md | 6 ++---- crates/ty_python_semantic/src/types/infer.rs | 2 +- crates/ty_python_semantic/src/types/unpacker.rs | 11 +++++++++-- 3 files changed, 12 insertions(+), 7 deletions(-) diff --git a/crates/ty_python_semantic/resources/mdtest/cycle.md b/crates/ty_python_semantic/resources/mdtest/cycle.md index ffd107aa0ce4f..0606f14e21caa 100644 --- a/crates/ty_python_semantic/resources/mdtest/cycle.md +++ b/crates/ty_python_semantic/resources/mdtest/cycle.md @@ -28,10 +28,8 @@ class Point: self.x, self.y = other.x, other.y p = Point() -# TODO: should be `Unknown | int` -reveal_type(p.x) # revealed: Unknown | int | Divergent -# TODO: should be `Unknown | int` -reveal_type(p.y) # revealed: Unknown | int | Divergent +reveal_type(p.x) # revealed: Unknown | int +reveal_type(p.y) # revealed: Unknown | int ``` ## Self-referential bare type alias diff --git a/crates/ty_python_semantic/src/types/infer.rs b/crates/ty_python_semantic/src/types/infer.rs index c6148e8aa937e..018104ab03a10 100644 --- a/crates/ty_python_semantic/src/types/infer.rs +++ b/crates/ty_python_semantic/src/types/infer.rs @@ -833,7 +833,7 @@ impl<'db> ExpressionInference<'db> { .unwrap_or_else(Type::unknown) } - fn fallback_type(&self) -> Option> { + pub(super) fn fallback_type(&self) -> Option> { self.extra.as_ref().and_then(|extra| extra.cycle_recovery) } diff --git a/crates/ty_python_semantic/src/types/unpacker.rs b/crates/ty_python_semantic/src/types/unpacker.rs index 1735620d05831..95bfcf89c2bc0 100644 --- a/crates/ty_python_semantic/src/types/unpacker.rs +++ b/crates/ty_python_semantic/src/types/unpacker.rs @@ -53,8 +53,15 @@ impl<'db, 'ast> Unpacker<'db, 'ast> { ); let input = InferExpression::new(self.db(), value.expression(), TypeContext::default()); - let value_type = infer_expression_types_impl(self.db(), input) - .expression_type(value.expression().node_ref(self.db(), self.module())); + let inference = infer_expression_types_impl(self.db(), input); + let value_type = if let Some(cycle_recovery) = inference.fallback_type() { + let visitor = RecursiveTypeNormalizedVisitor::new(cycle_recovery); + inference + .expression_type(value.expression().node_ref(self.db(), self.module())) + .recursive_type_normalized(self.db(), &visitor) + } else { + inference.expression_type(value.expression().node_ref(self.db(), self.module())) + }; let value_type = match value.kind() { UnpackKind::Assign => { From 6b9c889c127dc23d4e1c882697b6f936f83ec029 Mon Sep 17 00:00:00 2001 From: Shunsuke Shibayama Date: Wed, 17 Sep 2025 16:43:24 +0900 Subject: [PATCH 080/105] fix functions that use `{Scope, Expression, Definition}Inference` to be cycle-safe --- crates/ty_python_semantic/src/dunder_all.rs | 16 ++--- .../ty_python_semantic/src/semantic_model.rs | 4 +- crates/ty_python_semantic/src/types.rs | 52 ++++++++++++----- crates/ty_python_semantic/src/types/class.rs | 6 +- .../ty_python_semantic/src/types/generics.rs | 6 +- crates/ty_python_semantic/src/types/infer.rs | 58 +++++++++++++++---- .../src/types/infer/builder.rs | 32 ++++++---- .../ty_python_semantic/src/types/unpacker.rs | 2 +- 8 files changed, 120 insertions(+), 56 deletions(-) diff --git a/crates/ty_python_semantic/src/dunder_all.rs b/crates/ty_python_semantic/src/dunder_all.rs index 10eab9321a919..83d8ec6a3418b 100644 --- a/crates/ty_python_semantic/src/dunder_all.rs +++ b/crates/ty_python_semantic/src/dunder_all.rs @@ -6,8 +6,8 @@ use ruff_python_ast::name::Name; use ruff_python_ast::statement_visitor::{StatementVisitor, walk_stmt}; use ruff_python_ast::{self as ast}; -use crate::semantic_index::{SemanticIndex, semantic_index}; -use crate::types::{Truthiness, Type, TypeContext, infer_expression_types}; +use crate::semantic_index::global_scope; +use crate::types::{Truthiness, Type, infer_scope_expression_type}; use crate::{Db, ModuleName, resolve_module}; #[allow(clippy::ref_option)] @@ -31,8 +31,7 @@ pub(crate) fn dunder_all_names(db: &dyn Db, file: File) -> Option { db: &'db dyn Db, file: File, - /// The semantic index for the module. - index: &'db SemanticIndex<'db>, - /// The origin of the `__all__` variable in the current module, [`None`] if it is not defined. origin: Option, @@ -57,11 +53,10 @@ struct DunderAllNamesCollector<'db> { } impl<'db> DunderAllNamesCollector<'db> { - fn new(db: &'db dyn Db, file: File, index: &'db SemanticIndex<'db>) -> Self { + fn new(db: &'db dyn Db, file: File) -> Self { Self { db, file, - index, origin: None, invalid: false, names: FxHashSet::default(), @@ -182,8 +177,7 @@ impl<'db> DunderAllNamesCollector<'db> { /// /// This function panics if `expr` was not marked as a standalone expression during semantic indexing. fn standalone_expression_type(&self, expr: &ast::Expr) -> Type<'db> { - infer_expression_types(self.db, self.index.expression(expr), TypeContext::default()) - .expression_type(expr) + infer_scope_expression_type(self.db, global_scope(self.db, self.file), expr) } /// Evaluate the given expression and return its truthiness. diff --git a/crates/ty_python_semantic/src/semantic_model.rs b/crates/ty_python_semantic/src/semantic_model.rs index 489929c9ba626..6ed95dba03074 100644 --- a/crates/ty_python_semantic/src/semantic_model.rs +++ b/crates/ty_python_semantic/src/semantic_model.rs @@ -11,7 +11,7 @@ use crate::semantic_index::definition::Definition; use crate::semantic_index::scope::FileScopeId; use crate::semantic_index::semantic_index; use crate::types::ide_support::all_declarations_and_bindings; -use crate::types::{CycleDetector, Type, binding_type, infer_scope_types}; +use crate::types::{CycleDetector, Type, binding_type, infer_scope_expression_type}; pub struct SemanticModel<'db> { db: &'db dyn Db, @@ -437,7 +437,7 @@ impl HasType for ast::ExprRef<'_> { let file_scope = index.expression_scope_id(self); let scope = file_scope.to_scope_id(model.db, model.file); - infer_scope_types(model.db, scope).expression_type(*self) + infer_scope_expression_type(model.db, scope, *self) } } diff --git a/crates/ty_python_semantic/src/types.rs b/crates/ty_python_semantic/src/types.rs index 61e6eb9524a4b..5ea3704acf1e7 100644 --- a/crates/ty_python_semantic/src/types.rs +++ b/crates/ty_python_semantic/src/types.rs @@ -23,9 +23,10 @@ pub(crate) use self::cyclic::{CycleDetector, PairVisitor, TypeTransformer}; pub use self::diagnostic::TypeCheckDiagnostics; pub(crate) use self::diagnostic::register_lints; pub(crate) use self::infer::{ - TypeContext, infer_deferred_types, infer_definition_types, infer_expression_type, - infer_expression_types, infer_scope_types, static_expression_truthiness, + TypeContext, infer_deferred_types, infer_expression_type, infer_scope_expression_type, + static_expression_truthiness, }; +use self::infer::{infer_definition_types, infer_expression_types, infer_scope_types}; pub(crate) use self::signatures::{CallableSignature, Signature}; pub(crate) use self::subclass_of::{SubclassOfInner, SubclassOfType}; use crate::module_name::ModuleName; @@ -160,7 +161,11 @@ pub fn check_types(db: &dyn Db, file: File) -> Vec { /// Infer the type of a binding. pub(crate) fn binding_type<'db>(db: &'db dyn Db, definition: Definition<'db>) -> Type<'db> { let inference = infer_definition_types(db, definition); - inference.binding_type(definition) + if let Some(cycle_recovery) = inference.cycle_recovery() { + UnionType::from_elements(db, [inference.binding_type(definition), cycle_recovery]) + } else { + inference.binding_type(definition) + } } /// Infer the type of a declaration. @@ -169,7 +174,28 @@ pub(crate) fn declaration_type<'db>( definition: Definition<'db>, ) -> TypeAndQualifiers<'db> { let inference = infer_definition_types(db, definition); - inference.declaration_type(definition) + if let Some(cycle_recovery) = inference.cycle_recovery() { + let decl_ty = inference.declaration_type(definition); + let union = UnionType::from_elements(db, [decl_ty.inner_type(), cycle_recovery]); + TypeAndQualifiers::new(union, decl_ty.qualifiers()) + } else { + inference.declaration_type(definition) + } +} + +pub(crate) fn undecorated_type<'db>( + db: &'db dyn Db, + definition: Definition<'db>, +) -> Option> { + let inference = infer_definition_types(db, definition); + if let Some(cycle_recovery) = inference.cycle_recovery() { + Some(UnionType::from_elements( + db, + [inference.undecorated_type()?, cycle_recovery], + )) + } else { + inference.undecorated_type() + } } /// Infer the type of a (possibly deferred) sub-expression of a [`Definition`]. @@ -191,13 +217,17 @@ fn definition_expression_type<'db>( // expression is in the definition scope let inference = infer_definition_types(db, definition); if let Some(ty) = inference.try_expression_type(expression) { - ty + if let Some(cycle_recovery) = inference.cycle_recovery() { + UnionType::from_elements(db, [ty, cycle_recovery]) + } else { + ty + } } else { infer_deferred_types(db, definition).expression_type(expression) } } else { // expression is in a type-params sub-scope - infer_scope_types(db, scope).expression_type(expression) + infer_scope_expression_type(db, scope, expression) } } @@ -8242,13 +8272,9 @@ impl<'db> BoundTypeVarInstance<'db> { match self.typevar(db).explicit_variance(db) { Some(explicit_variance) => explicit_variance.compose(polarity), None => match self.binding_context(db) { - BindingContext::Definition(definition) => { - let type_inference = infer_definition_types(db, definition); - type_inference - .binding_type(definition) - .with_polarity(polarity) - .variance_of(db, self) - } + BindingContext::Definition(definition) => binding_type(db, definition) + .with_polarity(polarity) + .variance_of(db, self), BindingContext::Synthetic => TypeVarVariance::Invariant, }, } diff --git a/crates/ty_python_semantic/src/types/class.rs b/crates/ty_python_semantic/src/types/class.rs index 36097054412d0..6e8eb541a8729 100644 --- a/crates/ty_python_semantic/src/types/class.rs +++ b/crates/ty_python_semantic/src/types/class.rs @@ -31,8 +31,7 @@ use crate::types::{ ManualPEP695TypeAliasType, MaterializationKind, NormalizedVisitor, PropertyInstanceType, RecursiveTypeNormalizedVisitor, StringLiteralType, TypeAliasType, TypeContext, TypeMapping, TypeRelation, TypeVarBoundOrConstraints, TypeVarInstance, TypeVarKind, TypedDictParams, - UnionBuilder, VarianceInferable, declaration_type, determine_upper_bound, - infer_definition_types, + UnionBuilder, VarianceInferable, binding_type, declaration_type, determine_upper_bound, }; use crate::{ Db, FxIndexMap, FxOrderSet, Program, @@ -4994,8 +4993,7 @@ impl KnownClass { }; let definition = index.expect_single_definition(first_param); - let first_param = - infer_definition_types(db, definition).binding_type(definition); + let first_param = binding_type(db, definition); let bound_super = BoundSuperType::build( db, diff --git a/crates/ty_python_semantic/src/types/generics.rs b/crates/ty_python_semantic/src/types/generics.rs index 365e46a5f75cb..9bc26df5e849a 100644 --- a/crates/ty_python_semantic/src/types/generics.rs +++ b/crates/ty_python_semantic/src/types/generics.rs @@ -11,7 +11,6 @@ use crate::semantic_index::definition::Definition; use crate::semantic_index::scope::{FileScopeId, NodeWithScopeKind}; use crate::types::class::ClassType; use crate::types::class_base::ClassBase; -use crate::types::infer::infer_definition_types; use crate::types::instance::{Protocol, ProtocolInstanceType}; use crate::types::signatures::{Parameter, Parameters, Signature}; use crate::types::tuple::{TupleSpec, TupleType, walk_tuple_type}; @@ -20,7 +19,7 @@ use crate::types::{ HasDivergentTypeVisitor, HasRelationToVisitor, IsEquivalentVisitor, KnownClass, KnownInstanceType, MaterializationKind, NormalizedVisitor, RecursiveTypeNormalizedVisitor, Type, TypeMapping, TypeRelation, TypeVarBoundOrConstraints, TypeVarInstance, TypeVarVariance, - UnionType, binding_type, declaration_type, + UnionType, binding_type, declaration_type, undecorated_type, }; use crate::{Db, FxOrderSet}; @@ -43,8 +42,7 @@ fn enclosing_generic_contexts<'db>( } NodeWithScopeKind::Function(function) => { let definition = index.expect_single_definition(function.node(module)); - infer_definition_types(db, definition) - .undecorated_type() + undecorated_type(db, definition) .expect("function should have undecorated type") .into_function_literal()? .last_definition_signature(db) diff --git a/crates/ty_python_semantic/src/types/infer.rs b/crates/ty_python_semantic/src/types/infer.rs index 018104ab03a10..8c26569e41fd7 100644 --- a/crates/ty_python_semantic/src/types/infer.rs +++ b/crates/ty_python_semantic/src/types/infer.rs @@ -52,7 +52,7 @@ use crate::types::diagnostic::TypeCheckDiagnostics; use crate::types::unpacker::{UnpackResult, Unpacker}; use crate::types::{ ClassLiteral, CycleRecoveryType, DivergenceKind, DivergentType, RecursiveTypeNormalizedVisitor, - Truthiness, Type, TypeAndQualifiers, UnionBuilder, UnionType, + Truthiness, Type, TypeAndQualifiers, UnionBuilder, UnionType, declaration_type, }; use crate::unpack::Unpack; use builder::TypeInferenceBuilder; @@ -64,8 +64,10 @@ mod tests; /// Infer all types for a [`ScopeId`], including all definitions and expressions in that scope. /// Use when checking a scope, or needing to provide a type for an arbitrary expression in the /// scope. +/// When using types ​​in [`ScopeInference`], you must use [`ScopeInference::cycle_recovery`]. +/// Alternatively, consider using a cycle-safe function such as [`infer_scope_expression_type`]. #[salsa::tracked(returns(ref), cycle_fn=scope_cycle_recover, cycle_initial=scope_cycle_initial, heap_size=ruff_memory_usage::heap_size)] -pub(crate) fn infer_scope_types<'db>(db: &'db dyn Db, scope: ScopeId<'db>) -> ScopeInference<'db> { +pub(super) fn infer_scope_types<'db>(db: &'db dyn Db, scope: ScopeId<'db>) -> ScopeInference<'db> { let file = scope.file(db); let _span = tracing::trace_span!("infer_scope_types", scope=?scope.as_id(), ?file).entered(); @@ -97,10 +99,25 @@ fn scope_cycle_initial<'db>(db: &'db dyn Db, scope: ScopeId<'db>) -> ScopeInfere ) } +pub(crate) fn infer_scope_expression_type<'db>( + db: &'db dyn Db, + scope: ScopeId<'db>, + expr: impl Into, +) -> Type<'db> { + let inference = infer_scope_types(db, scope); + if let Some(cycle_recovery) = inference.cycle_recovery() { + UnionType::from_elements(db, [inference.expression_type(expr), cycle_recovery]) + } else { + inference.expression_type(expr) + } +} + /// Infer all types for a [`Definition`] (including sub-expressions). /// Use when resolving a place use or public type of a place. +/// When using types ​​in [`DefinitionInference`], you must use [`DefinitionInference::cycle_recovery`]. +/// Alternatively, consider using a cycle-safe function such as [`crate::types::binding_type`]. #[salsa::tracked(returns(ref), cycle_fn=definition_cycle_recover, cycle_initial=definition_cycle_initial, heap_size=ruff_memory_usage::heap_size)] -pub(crate) fn infer_definition_types<'db>( +pub(super) fn infer_definition_types<'db>( db: &'db dyn Db, definition: Definition<'db>, ) -> DefinitionInference<'db> { @@ -186,7 +203,9 @@ fn deferred_cycle_initial<'db>( /// Use rarely; only for cases where we'd otherwise risk double-inferring an expression: RHS of an /// assignment, which might be unpacking/multi-target and thus part of multiple definitions, or a /// type narrowing guard expression (e.g. if statement test node). -pub(crate) fn infer_expression_types<'db>( +/// When using types ​​in [`ExpressionInference`], you must use [`ExpressionInference::cycle_recovery`]. +/// Alternatively, consider using a cycle-safe function such as [`infer_expression_type`]. +pub(super) fn infer_expression_types<'db>( db: &'db dyn Db, expression: Expression<'db>, tcx: TypeContext<'db>, @@ -194,6 +213,7 @@ pub(crate) fn infer_expression_types<'db>( infer_expression_types_impl(db, InferExpression::new(db, expression, tcx)) } +/// When using types ​​in [`ExpressionInference`], you must use [`ExpressionInference::cycle_recovery`]. #[salsa::tracked(returns(ref), cycle_fn=expression_cycle_recover, cycle_initial=expression_cycle_initial, heap_size=ruff_memory_usage::heap_size)] pub(super) fn infer_expression_types_impl<'db>( db: &'db dyn Db, @@ -247,7 +267,7 @@ fn expression_cycle_initial<'db>( /// This is a small helper around [`infer_expression_types()`] to reduce the boilerplate. /// Use [`infer_expression_type()`] if it isn't guaranteed that `expression` is in the same file to /// avoid cross-file query dependencies. -pub(super) fn infer_same_file_expression_type<'db>( +pub(crate) fn infer_same_file_expression_type<'db>( db: &'db dyn Db, expression: Expression<'db>, tcx: TypeContext<'db>, @@ -473,8 +493,7 @@ pub(crate) fn nearest_enclosing_class<'db>( .find_map(|(_, ancestor_scope)| { let class = ancestor_scope.node().as_class()?; let definition = semantic.expect_single_definition(class); - infer_definition_types(db, definition) - .declaration_type(definition) + declaration_type(db, definition) .inner_type() .into_class_literal() }) @@ -566,6 +585,12 @@ impl<'db> ScopeInference<'db> { self.extra.as_ref().and_then(|extra| extra.cycle_recovery) } + /// When using `ScopeInference` during type inference, + /// use this method to get the cycle recovery type so that divergent types are propagated. + pub(super) fn cycle_recovery(&self) -> Option> { + self.fallback_type() + } + /// Returns the inferred return type of this function body (union of all possible return types), /// or `None` if the region is not a function body. /// In the case of methods, the return type of the superclass method is further unioned. @@ -582,8 +607,8 @@ impl<'db> ScopeInference<'db> { db, DivergenceKind::InferReturnType(self.scope), )); - if let Some(fall_back) = self.fallback_type() { - union = union.add(fall_back); + if let Some(cycle_recovery) = self.cycle_recovery() { + union = union.add(cycle_recovery); } let visitor = RecursiveTypeNormalizedVisitor::new(div); // Here, we use the dynamic type `Divergent` to detect divergent type inference and ensure that we obtain finite results. @@ -767,6 +792,13 @@ impl<'db> DefinitionInference<'db> { self.extra.as_ref().and_then(|extra| extra.cycle_recovery) } + /// When using `DefinitionInference` during type inference, + /// use this method to get the cycle recovery type so that divergent types are propagated. + #[allow(unused)] + pub(super) fn cycle_recovery(&self) -> Option> { + self.fallback_type() + } + pub(crate) fn undecorated_type(&self) -> Option> { self.extra.as_ref().and_then(|extra| extra.undecorated_type) } @@ -833,10 +865,16 @@ impl<'db> ExpressionInference<'db> { .unwrap_or_else(Type::unknown) } - pub(super) fn fallback_type(&self) -> Option> { + fn fallback_type(&self) -> Option> { self.extra.as_ref().and_then(|extra| extra.cycle_recovery) } + /// When using `ExpressionInference` during type inference, + /// use this method to get the cycle recovery type so that divergent types are propagated. + pub(super) fn cycle_recovery(&self) -> Option> { + self.fallback_type() + } + /// Returns true if all places in this expression are definitely bound. pub(crate) fn all_places_definitely_bound(&self) -> bool { self.extra diff --git a/crates/ty_python_semantic/src/types/infer/builder.rs b/crates/ty_python_semantic/src/types/infer/builder.rs index 8b75fa901645d..8b2d3d9e60b8b 100644 --- a/crates/ty_python_semantic/src/types/infer/builder.rs +++ b/crates/ty_python_semantic/src/types/infer/builder.rs @@ -74,7 +74,7 @@ use crate::types::function::{ }; use crate::types::generics::LegacyGenericBase; use crate::types::generics::{GenericContext, bind_typevar}; -use crate::types::infer::infer_expression_types_impl; +use crate::types::infer::{infer_expression_types_impl, infer_scope_expression_type}; use crate::types::instance::SliceLiteral; use crate::types::mro::MroErrorKind; use crate::types::signatures::Signature; @@ -93,7 +93,7 @@ use crate::types::{ SpecialFormType, SubclassOfType, TrackedConstraintSet, Truthiness, Type, TypeAliasType, TypeAndQualifiers, TypeContext, TypeQualifiers, TypeVarBoundOrConstraintsEvaluation, TypeVarDefaultEvaluation, TypeVarInstance, TypeVarKind, UnionBuilder, UnionType, binding_type, - todo_type, + infer_expression_type, todo_type, }; use crate::types::{ClassBase, add_inferred_python_version_hint_to_diagnostic}; use crate::unpack::{EvaluationMode, UnpackPosition}; @@ -441,7 +441,7 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> { InferenceRegion::Scope(scope) if scope == expr_scope => { self.expression_type(expression) } - _ => infer_scope_types(self.db(), expr_scope).expression_type(expression), + _ => infer_scope_expression_type(self.db(), expr_scope, expression), } } @@ -1822,10 +1822,14 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> { let definition_types = infer_definition_types(self.db(), definition); - function - .decorator_list - .iter() - .map(move |decorator| definition_types.expression_type(&decorator.expression)) + function.decorator_list.iter().map(move |decorator| { + let decorator_ty = definition_types.expression_type(&decorator.expression); + if let Some(cycle_recovery) = definition_types.cycle_recovery() { + UnionType::from_elements(self.db(), [decorator_ty, cycle_recovery]) + } else { + decorator_ty + } + }) } /// Returns `true` if the current scope is the function body scope of a function overload (that @@ -4516,7 +4520,12 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> { // Check non-star imports for deprecations if definition.kind(self.db()).as_star_import().is_none() { for ty in inferred.declaration_types() { - self.check_deprecated(alias, ty.inner); + let ty = if let Some(cycle_recovery) = inferred.cycle_recovery() { + UnionType::from_elements(self.db(), [ty.inner, cycle_recovery]) + } else { + ty.inner + }; + self.check_deprecated(alias, ty); } } self.extend_definition(inferred); @@ -5499,6 +5508,7 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> { let mut infer_iterable_type = || { let expression = self.index.expression(iterable); let result = infer_expression_types(self.db(), expression, TypeContext::default()); + let iterable = infer_expression_type(self.db(), expression, TypeContext::default()); // Two things are different if it's the first comprehension: // (1) We must lookup the `ScopedExpressionId` of the iterable expression in the outer scope, @@ -5507,10 +5517,10 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> { // because `ScopedExpressionId`s are only meaningful within their own scope, so // we'd add types for random wrong expressions in the current scope if comprehension.is_first() && target.is_name_expr() { - result.expression_type(iterable) + iterable } else { self.extend_expression_unchecked(result); - result.expression_type(iterable) + iterable } }; @@ -5549,7 +5559,7 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> { let definition = self.index.expect_single_definition(named); let result = infer_definition_types(self.db(), definition); self.extend_definition(result); - result.binding_type(definition) + binding_type(self.db(), definition) } else { // For syntactically invalid targets, we still need to run type inference: self.infer_expression(&named.target, TypeContext::default()); diff --git a/crates/ty_python_semantic/src/types/unpacker.rs b/crates/ty_python_semantic/src/types/unpacker.rs index 95bfcf89c2bc0..3a778b2dfa14b 100644 --- a/crates/ty_python_semantic/src/types/unpacker.rs +++ b/crates/ty_python_semantic/src/types/unpacker.rs @@ -54,7 +54,7 @@ impl<'db, 'ast> Unpacker<'db, 'ast> { let input = InferExpression::new(self.db(), value.expression(), TypeContext::default()); let inference = infer_expression_types_impl(self.db(), input); - let value_type = if let Some(cycle_recovery) = inference.fallback_type() { + let value_type = if let Some(cycle_recovery) = inference.cycle_recovery() { let visitor = RecursiveTypeNormalizedVisitor::new(cycle_recovery); inference .expression_type(value.expression().node_ref(self.db(), self.module())) From 151fa0d1d1dbe6f24fbec15c311dd9c84a463d6f Mon Sep 17 00:00:00 2001 From: Shunsuke Shibayama Date: Wed, 17 Sep 2025 16:55:34 +0900 Subject: [PATCH 081/105] Update cycle.md --- crates/ty_python_semantic/resources/mdtest/cycle.md | 13 +++++++++++++ 1 file changed, 13 insertions(+) diff --git a/crates/ty_python_semantic/resources/mdtest/cycle.md b/crates/ty_python_semantic/resources/mdtest/cycle.md index 0606f14e21caa..09f14efac92e3 100644 --- a/crates/ty_python_semantic/resources/mdtest/cycle.md +++ b/crates/ty_python_semantic/resources/mdtest/cycle.md @@ -43,3 +43,16 @@ def f(x: A): # TODO: should be `A | None`? reveal_type(x[0]) # revealed: Divergent ``` + +## Self-referential type variables + +```py +from typing import Generic, TypeVar + +B = TypeVar("B", bound="Base") + +# TODO: no error +# error: [invalid-argument-type] "`typing.TypeVar | typing.TypeVar` is not a valid argument to `Generic`" +class Base(Generic[B]): + pass +``` From 81a5247ae43f3cc10123b759e51714e780fa3ef5 Mon Sep 17 00:00:00 2001 From: Shunsuke Shibayama Date: Wed, 17 Sep 2025 17:18:58 +0900 Subject: [PATCH 082/105] Update function.rs --- .../ty_python_semantic/src/types/function.rs | 18 +++++++++++++++++- 1 file changed, 17 insertions(+), 1 deletion(-) diff --git a/crates/ty_python_semantic/src/types/function.rs b/crates/ty_python_semantic/src/types/function.rs index a599d03c136e7..40e8bf286f233 100644 --- a/crates/ty_python_semantic/src/types/function.rs +++ b/crates/ty_python_semantic/src/types/function.rs @@ -594,7 +594,7 @@ impl<'db> FunctionLiteral<'db> { self.last_definition(db).spans(db) } - #[salsa::tracked(returns(ref), heap_size=ruff_memory_usage::heap_size)] + #[salsa::tracked(returns(ref), cycle_fn=overloads_and_implementation_cycle_recover, cycle_initial=overloads_and_implementation_cycle_initial, heap_size=ruff_memory_usage::heap_size)] fn overloads_and_implementation( self, db: &'db dyn Db, @@ -713,6 +713,22 @@ impl<'db> FunctionLiteral<'db> { } } +fn overloads_and_implementation_cycle_recover<'db>( + _db: &'db dyn Db, + _value: &(Box<[OverloadLiteral<'db>]>, Option>), + _count: u32, + _function: FunctionLiteral<'db>, +) -> salsa::CycleRecoveryAction<(Box<[OverloadLiteral<'db>]>, Option>)> { + salsa::CycleRecoveryAction::Iterate +} + +fn overloads_and_implementation_cycle_initial<'db>( + _db: &'db dyn Db, + _function: FunctionLiteral<'db>, +) -> (Box<[OverloadLiteral<'db>]>, Option>) { + (Box::new([]), None) +} + /// Represents a function type, which might be a non-generic function, or a specialization of a /// generic function. #[salsa::interned(debug, heap_size=ruff_memory_usage::heap_size)] From caf569135091136b10a8ff4efb447a2804cea2ee Mon Sep 17 00:00:00 2001 From: Shunsuke Shibayama Date: Wed, 17 Sep 2025 17:30:03 +0900 Subject: [PATCH 083/105] Update infer.rs --- crates/ty_python_semantic/src/types/infer.rs | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/crates/ty_python_semantic/src/types/infer.rs b/crates/ty_python_semantic/src/types/infer.rs index 8c26569e41fd7..e9f4c50ff7a1c 100644 --- a/crates/ty_python_semantic/src/types/infer.rs +++ b/crates/ty_python_semantic/src/types/infer.rs @@ -628,10 +628,11 @@ impl<'db> ScopeInference<'db> { // 0th: Divergent // 1st: tuple[Divergent] | None // 2nd: tuple[tuple[Divergent] | None] | None => tuple[Divergent] | None - let previous_cycle_value = callee_ty.infer_return_type(db).unwrap(); - // In fixed-point iteration of return type inference, the return type must be monotonically widened and not "oscillate". - // Here, monotonicity is guaranteed by pre-unioning the type of the previous iteration into the current result. - union = union.add(previous_cycle_value.recursive_type_normalized(db, &visitor)); + if let Some(previous_cycle_value) = callee_ty.infer_return_type(db) { + // In fixed-point iteration of return type inference, the return type must be monotonically widened and not "oscillate". + // Here, monotonicity is guaranteed by pre-unioning the type of the previous iteration into the current result. + union = union.add(previous_cycle_value.recursive_type_normalized(db, &visitor)); + } let Some(extra) = &self.extra else { unreachable!( From 443253281e08a2c22c9f89f35a1a65c0d3a5ecdb Mon Sep 17 00:00:00 2001 From: Shunsuke Shibayama <45118249+mtshiba@users.noreply.github.com> Date: Wed, 17 Sep 2025 23:13:37 +0900 Subject: [PATCH 084/105] Update crates/ty_python_semantic/resources/mdtest/attributes.md Co-authored-by: Carl Meyer --- crates/ty_python_semantic/resources/mdtest/attributes.md | 1 + 1 file changed, 1 insertion(+) diff --git a/crates/ty_python_semantic/resources/mdtest/attributes.md b/crates/ty_python_semantic/resources/mdtest/attributes.md index 68139d03e29e4..747a153ae8a63 100644 --- a/crates/ty_python_semantic/resources/mdtest/attributes.md +++ b/crates/ty_python_semantic/resources/mdtest/attributes.md @@ -2287,6 +2287,7 @@ class Base: return Sub() class Sub(Base): + # TODO invalid override error def flip(self) -> "Base": return Base() From 43d306b20fe8f12b57536c968c52dcf9a75623ef Mon Sep 17 00:00:00 2001 From: Shunsuke Shibayama Date: Wed, 17 Sep 2025 23:20:58 +0900 Subject: [PATCH 085/105] Update type.md --- .../resources/mdtest/narrow/type.md | 14 +++++--------- 1 file changed, 5 insertions(+), 9 deletions(-) diff --git a/crates/ty_python_semantic/resources/mdtest/narrow/type.md b/crates/ty_python_semantic/resources/mdtest/narrow/type.md index f80736e1320d7..96a1bcbc847a3 100644 --- a/crates/ty_python_semantic/resources/mdtest/narrow/type.md +++ b/crates/ty_python_semantic/resources/mdtest/narrow/type.md @@ -146,23 +146,19 @@ def _(x: A | B): ## No special narrowing for custom `type` callable -`stub.pyi`: - -```pyi -from ty_extensions import TypeOf - -def type(x: object) -> TypeOf[int]: ... -``` - ```py -from stub import type +def type(x: object): + return int class A: ... class B: ... def _(x: A | B): + # The custom `type` function always returns `int`, + # so any branch other than `type(...) is int` is unreachable. if type(x) is A: reveal_type(x) # revealed: Never + # And the condition here is always `True` and has no effect on the narrowing of `x`. elif type(x) is int: reveal_type(x) # revealed: A | B else: From a778047c06269f20b9d28dcf364173a528486a31 Mon Sep 17 00:00:00 2001 From: Shunsuke Shibayama Date: Wed, 17 Sep 2025 23:40:01 +0900 Subject: [PATCH 086/105] Update return_type.md --- .../resources/mdtest/function/return_type.md | 10 ++++++++++ 1 file changed, 10 insertions(+) diff --git a/crates/ty_python_semantic/resources/mdtest/function/return_type.md b/crates/ty_python_semantic/resources/mdtest/function/return_type.md index 59b20d1dafa6e..5dd5b812e3139 100644 --- a/crates/ty_python_semantic/resources/mdtest/function/return_type.md +++ b/crates/ty_python_semantic/resources/mdtest/function/return_type.md @@ -383,6 +383,15 @@ def divergent2(value): reveal_type(divergent2((1,))) # revealed: tuple[Divergent] | list[Divergent] | None +def list_int(x: int): + if x > 0: + return list1(list_int(x - 1)) + else: + return list1(x) + +# TODO: should be `list[int]` +reveal_type(list_int(1)) # revealed: list[Divergent] | list[int] + def tuple_obj(cond: bool): if cond: x = object() @@ -421,6 +430,7 @@ class C: return D() class D(C): + # TODO invalid override error def flip(self) -> "C": return C() From e7024ae67a5b6ad822f6edd6867a8393c22fff1f Mon Sep 17 00:00:00 2001 From: Shunsuke Shibayama Date: Wed, 24 Sep 2025 23:47:58 +0900 Subject: [PATCH 087/105] revert unnecessary changes for the purpose of this PR --- .../resources/corpus/divergent.py | 6 + .../resources/mdtest/attributes.md | 18 -- .../resources/mdtest/cycle.md | 25 -- crates/ty_python_semantic/src/dunder_all.rs | 16 +- .../ty_python_semantic/src/semantic_model.rs | 4 +- crates/ty_python_semantic/src/types.rs | 286 ++++-------------- crates/ty_python_semantic/src/types/class.rs | 113 +++---- .../ty_python_semantic/src/types/function.rs | 9 +- .../ty_python_semantic/src/types/generics.rs | 5 +- crates/ty_python_semantic/src/types/infer.rs | 253 +++++++++------- .../src/types/infer/builder.rs | 179 +++-------- .../types/infer/builder/type_expression.rs | 6 +- .../src/types/infer/tests.rs | 41 +-- .../src/types/type_ordering.rs | 6 +- .../ty_python_semantic/src/types/unpacker.rs | 46 +-- crates/ty_python_semantic/src/unpack.rs | 1 - crates/ty_python_semantic/tests/corpus.rs | 3 + 17 files changed, 327 insertions(+), 690 deletions(-) diff --git a/crates/ty_python_semantic/resources/corpus/divergent.py b/crates/ty_python_semantic/resources/corpus/divergent.py index bc3c0e1fff66d..6632a7d834bdf 100644 --- a/crates/ty_python_semantic/resources/corpus/divergent.py +++ b/crates/ty_python_semantic/resources/corpus/divergent.py @@ -57,6 +57,9 @@ def unwrap(value): else: raise TypeError() +# TODO: If this is commented out, that is, if `infer_scope_types` is called before `infer_return_type`, it will panic. +reveal_type(unwrap(Foo())) + def descent(x: int, y: int): if x > y: y, x = descent(y, x) @@ -67,3 +70,6 @@ def descent(x: int, y: int): return (0, 1) else: return descent(x-1, y-1) + +# TODO: If this is commented out, that is, if `infer_scope_types` is called before `infer_return_type`, it will panic. +reveal_type(descent(5, 3)) diff --git a/crates/ty_python_semantic/resources/mdtest/attributes.md b/crates/ty_python_semantic/resources/mdtest/attributes.md index c583c1a20a835..8f5742bcc7348 100644 --- a/crates/ty_python_semantic/resources/mdtest/attributes.md +++ b/crates/ty_python_semantic/resources/mdtest/attributes.md @@ -2281,24 +2281,6 @@ class B: reveal_type(B().x) # revealed: Unknown | Literal[1] reveal_type(A().x) # revealed: Unknown | Literal[1] - -class Base: - def flip(self) -> "Sub": - return Sub() - -class Sub(Base): - # TODO invalid override error - def flip(self) -> "Base": - return Base() - -class C2: - def __init__(self, x: Sub): - self.x = x - - def replace_with(self, other: "C2"): - self.x = other.x.flip() - -reveal_type(C2(Sub()).x) # revealed: Unknown | Base ``` This case additionally tests our union/intersection simplification logic: diff --git a/crates/ty_python_semantic/resources/mdtest/cycle.md b/crates/ty_python_semantic/resources/mdtest/cycle.md index 09f14efac92e3..0bd3b5b2a6df4 100644 --- a/crates/ty_python_semantic/resources/mdtest/cycle.md +++ b/crates/ty_python_semantic/resources/mdtest/cycle.md @@ -31,28 +31,3 @@ p = Point() reveal_type(p.x) # revealed: Unknown | int reveal_type(p.y) # revealed: Unknown | int ``` - -## Self-referential bare type alias - -```py -A = list["A" | None] - -def f(x: A): - # TODO: should be `list[A | None]`? - reveal_type(x) # revealed: list[Divergent] - # TODO: should be `A | None`? - reveal_type(x[0]) # revealed: Divergent -``` - -## Self-referential type variables - -```py -from typing import Generic, TypeVar - -B = TypeVar("B", bound="Base") - -# TODO: no error -# error: [invalid-argument-type] "`typing.TypeVar | typing.TypeVar` is not a valid argument to `Generic`" -class Base(Generic[B]): - pass -``` diff --git a/crates/ty_python_semantic/src/dunder_all.rs b/crates/ty_python_semantic/src/dunder_all.rs index 83d8ec6a3418b..10eab9321a919 100644 --- a/crates/ty_python_semantic/src/dunder_all.rs +++ b/crates/ty_python_semantic/src/dunder_all.rs @@ -6,8 +6,8 @@ use ruff_python_ast::name::Name; use ruff_python_ast::statement_visitor::{StatementVisitor, walk_stmt}; use ruff_python_ast::{self as ast}; -use crate::semantic_index::global_scope; -use crate::types::{Truthiness, Type, infer_scope_expression_type}; +use crate::semantic_index::{SemanticIndex, semantic_index}; +use crate::types::{Truthiness, Type, TypeContext, infer_expression_types}; use crate::{Db, ModuleName, resolve_module}; #[allow(clippy::ref_option)] @@ -31,7 +31,8 @@ pub(crate) fn dunder_all_names(db: &dyn Db, file: File) -> Option { db: &'db dyn Db, file: File, + /// The semantic index for the module. + index: &'db SemanticIndex<'db>, + /// The origin of the `__all__` variable in the current module, [`None`] if it is not defined. origin: Option, @@ -53,10 +57,11 @@ struct DunderAllNamesCollector<'db> { } impl<'db> DunderAllNamesCollector<'db> { - fn new(db: &'db dyn Db, file: File) -> Self { + fn new(db: &'db dyn Db, file: File, index: &'db SemanticIndex<'db>) -> Self { Self { db, file, + index, origin: None, invalid: false, names: FxHashSet::default(), @@ -177,7 +182,8 @@ impl<'db> DunderAllNamesCollector<'db> { /// /// This function panics if `expr` was not marked as a standalone expression during semantic indexing. fn standalone_expression_type(&self, expr: &ast::Expr) -> Type<'db> { - infer_scope_expression_type(self.db, global_scope(self.db, self.file), expr) + infer_expression_types(self.db, self.index.expression(expr), TypeContext::default()) + .expression_type(expr) } /// Evaluate the given expression and return its truthiness. diff --git a/crates/ty_python_semantic/src/semantic_model.rs b/crates/ty_python_semantic/src/semantic_model.rs index 3ff4cdd79574c..11ea272246f85 100644 --- a/crates/ty_python_semantic/src/semantic_model.rs +++ b/crates/ty_python_semantic/src/semantic_model.rs @@ -12,7 +12,7 @@ use crate::semantic_index::definition::Definition; use crate::semantic_index::scope::FileScopeId; use crate::semantic_index::semantic_index; use crate::types::ide_support::all_declarations_and_bindings; -use crate::types::{Type, binding_type, infer_scope_expression_type}; +use crate::types::{Type, binding_type, infer_scope_types}; pub struct SemanticModel<'db> { db: &'db dyn Db, @@ -363,7 +363,7 @@ impl HasType for ast::ExprRef<'_> { let file_scope = index.expression_scope_id(self); let scope = file_scope.to_scope_id(model.db, model.file); - infer_scope_expression_type(model.db, scope, *self) + infer_scope_types(model.db, scope).expression_type(*self) } } diff --git a/crates/ty_python_semantic/src/types.rs b/crates/ty_python_semantic/src/types.rs index e040cacf25706..65285f5766a6b 100644 --- a/crates/ty_python_semantic/src/types.rs +++ b/crates/ty_python_semantic/src/types.rs @@ -24,11 +24,11 @@ pub(crate) use self::cyclic::{PairVisitor, TypeTransformer}; pub use self::diagnostic::TypeCheckDiagnostics; pub(crate) use self::diagnostic::register_lints; pub(crate) use self::infer::{ - TypeContext, infer_deferred_types, infer_expression_type, infer_isolated_expression, - infer_scope_expression_type, static_expression_truthiness, + TypeContext, infer_deferred_types, infer_definition_types, infer_expression_type, + infer_expression_types, infer_isolated_expression, infer_scope_types, + static_expression_truthiness, }; -use self::infer::{infer_definition_types, infer_expression_types, infer_scope_types}; -pub(crate) use self::signatures::{CallableSignature, Signature}; +pub(crate) use self::signatures::{CallableSignature, Parameter, Parameters, Signature}; pub(crate) use self::subclass_of::{SubclassOfInner, SubclassOfType}; use crate::module_name::ModuleName; use crate::module_resolver::{KnownModule, resolve_module}; @@ -39,7 +39,6 @@ use crate::semantic_index::scope::ScopeId; use crate::semantic_index::{imported_modules, place_table, semantic_index}; use crate::suppression::check_suppressions; use crate::types::call::{Binding, Bindings, CallArguments, CallableBinding}; -use crate::types::class::MethodDecorator; pub(crate) use crate::types::class_base::ClassBase; use crate::types::constraints::{ ConstraintSet, IteratorConstraintsExtension, OptionConstraintsExtension, @@ -62,15 +61,13 @@ pub use crate::types::ide_support::{ definitions_for_keyword_argument, definitions_for_name, find_active_signature_from_details, inlay_hint_function_argument_details, }; -use crate::types::infer::InferExpression; use crate::types::mro::{Mro, MroError, MroIterator}; pub(crate) use crate::types::narrow::infer_narrowing_constraint; -pub(crate) use crate::types::signatures::{Parameter, Parameters}; use crate::types::signatures::{ParameterForm, walk_signature}; use crate::types::tuple::TupleSpec; pub(crate) use crate::types::typed_dict::{TypedDictParams, TypedDictType, walk_typed_dict_type}; use crate::types::variance::{TypeVarVariance, VarianceInferable}; -use crate::unpack::{EvaluationMode, Unpack}; +use crate::unpack::EvaluationMode; pub use crate::util::diagnostics::add_inferred_python_version_hint_to_diagnostic; use crate::{Db, FxOrderSet, Module, Program}; pub(crate) use class::{ClassLiteral, ClassType, GenericAlias, KnownClass}; @@ -121,16 +118,13 @@ fn return_type_cycle_recover<'db>( } fn return_type_cycle_initial<'db>(db: &'db dyn Db, method: BoundMethodType<'db>) -> Type<'db> { - Type::divergent(DivergentType::new( - db, - DivergenceKind::InferReturnType( - method - .function(db) - .literal(db) - .last_definition(db) - .body_scope(db), - ), - )) + Type::divergent( + method + .function(db) + .literal(db) + .last_definition(db) + .body_scope(db), + ) } pub fn check_types(db: &dyn Db, file: File) -> Vec { @@ -164,11 +158,7 @@ pub fn check_types(db: &dyn Db, file: File) -> Vec { /// Infer the type of a binding. pub(crate) fn binding_type<'db>(db: &'db dyn Db, definition: Definition<'db>) -> Type<'db> { let inference = infer_definition_types(db, definition); - if let Some(cycle_recovery) = inference.cycle_recovery() { - UnionType::from_elements(db, [inference.binding_type(definition), cycle_recovery]) - } else { - inference.binding_type(definition) - } + inference.binding_type(definition) } /// Infer the type of a declaration. @@ -177,28 +167,7 @@ pub(crate) fn declaration_type<'db>( definition: Definition<'db>, ) -> TypeAndQualifiers<'db> { let inference = infer_definition_types(db, definition); - if let Some(cycle_recovery) = inference.cycle_recovery() { - let decl_ty = inference.declaration_type(definition); - let union = UnionType::from_elements(db, [decl_ty.inner_type(), cycle_recovery]); - TypeAndQualifiers::new(union, decl_ty.qualifiers()) - } else { - inference.declaration_type(definition) - } -} - -pub(crate) fn undecorated_type<'db>( - db: &'db dyn Db, - definition: Definition<'db>, -) -> Option> { - let inference = infer_definition_types(db, definition); - if let Some(cycle_recovery) = inference.cycle_recovery() { - Some(UnionType::from_elements( - db, - [inference.undecorated_type()?, cycle_recovery], - )) - } else { - inference.undecorated_type() - } + inference.declaration_type(definition) } /// Infer the type of a (possibly deferred) sub-expression of a [`Definition`]. @@ -220,17 +189,13 @@ fn definition_expression_type<'db>( // expression is in the definition scope let inference = infer_definition_types(db, definition); if let Some(ty) = inference.try_expression_type(expression) { - if let Some(cycle_recovery) = inference.cycle_recovery() { - UnionType::from_elements(db, [ty, cycle_recovery]) - } else { - ty - } + ty } else { infer_deferred_types(db, definition).expression_type(expression) } } else { // expression is in a type-params sub-scope - infer_scope_expression_type(db, scope, expression) + infer_scope_types(db, scope).expression_type(expression) } } @@ -273,6 +238,7 @@ pub(crate) type TryBoolVisitor<'db> = CycleDetector, Result>>; pub(crate) struct TryBool; +/// A [`TypeTransformer`] that is used in `normalized` methods. pub(crate) type NormalizedVisitor<'db> = TypeTransformer<'db, Normalized>; pub(crate) struct Normalized; @@ -287,7 +253,11 @@ pub(crate) struct RecursiveTypeNormalizedVisitor<'db> { impl<'db> RecursiveTypeNormalizedVisitor<'db> { fn new(div: Type<'db>) -> Self { - debug_assert!(matches!(div, Type::Dynamic(DynamicType::Divergent(_)))); + // TODO: Divergent only + debug_assert!(matches!( + div, + Type::Never | Type::Dynamic(DynamicType::Divergent(_)) + )); Self { transformer: NormalizedVisitor::default(), div, @@ -370,7 +340,7 @@ enum InstanceFallbackShadowsNonDataDescriptor { } bitflags! { - #[derive(Clone, Debug, Copy, PartialEq, Eq, Hash, PartialOrd, Ord)] + #[derive(Clone, Debug, Copy, PartialEq, Eq, Hash)] pub(crate) struct MemberLookupPolicy: u8 { /// Dunder methods are looked up on the meta-type of a type without potentially falling /// back on attributes on the type itself. For example, when implicitly invoked on an @@ -433,8 +403,6 @@ impl Default for MemberLookupPolicy { } } -impl get_size2::GetSize for MemberLookupPolicy {} - fn member_lookup_cycle_recover<'db>( _db: &'db dyn Db, _value: &PlaceAndQualifiers<'db>, @@ -446,22 +414,13 @@ fn member_lookup_cycle_recover<'db>( salsa::CycleRecoveryAction::Iterate } -#[allow(clippy::needless_pass_by_value)] fn member_lookup_cycle_initial<'db>( - db: &'db dyn Db, - self_type: Type<'db>, - name: Name, - policy: MemberLookupPolicy, + _db: &'db dyn Db, + _self: Type<'db>, + _name: Name, + _policy: MemberLookupPolicy, ) -> PlaceAndQualifiers<'db> { - Place::bound(Type::divergent(DivergentType::new( - db, - DivergenceKind::MemberLookupWithPolicy { - self_type, - name, - policy, - }, - ))) - .into() + Place::bound(Type::Never).into() } fn class_lookup_cycle_recover<'db>( @@ -475,22 +434,13 @@ fn class_lookup_cycle_recover<'db>( salsa::CycleRecoveryAction::Iterate } -#[allow(clippy::needless_pass_by_value)] fn class_lookup_cycle_initial<'db>( - db: &'db dyn Db, - self_type: Type<'db>, - name: Name, - policy: MemberLookupPolicy, + _db: &'db dyn Db, + _self: Type<'db>, + _name: Name, + _policy: MemberLookupPolicy, ) -> PlaceAndQualifiers<'db> { - Place::bound(Type::divergent(DivergentType::new( - db, - DivergenceKind::ClassLookupWithPolicy { - self_type, - name, - policy, - }, - ))) - .into() + Place::bound(Type::Never).into() } #[allow(clippy::trivially_copy_pass_by_ref)] @@ -909,8 +859,8 @@ impl<'db> Type<'db> { Self::Dynamic(DynamicType::Unknown) } - pub(crate) fn divergent(divergent: DivergentType<'db>) -> Self { - Self::Dynamic(DynamicType::Divergent(divergent)) + pub(crate) fn divergent(scope: ScopeId<'db>) -> Self { + Self::Dynamic(DynamicType::Divergent(DivergentType { scope })) } pub const fn is_unknown(&self) -> bool { @@ -3169,7 +3119,7 @@ impl<'db> Type<'db> { policy: MemberLookupPolicy, ) -> PlaceAndQualifiers<'db> { tracing::trace!("class_member: {}.{}", self.display(db), name); - let result = match self { + match self { Type::Union(union) => union.map_with_boundness_and_qualifiers(db, |elem| { elem.class_member_with_policy(db, name.clone(), policy) }), @@ -3187,30 +3137,7 @@ impl<'db> Type<'db> { .expect( "`Type::find_name_in_mro()` should return `Some()` when called on a meta-type", ), - }; - result.map_type(|ty| { - // In fixed-point iteration of type inference, the member type must be monotonically widened and not "oscillate". - // Here, monotonicity is guaranteed by pre-unioning the type of the previous iteration into the current result. - let previous_cycle_value = self.class_member_with_policy(db, name.clone(), policy); - - let ty = if let Some(previous_ty) = previous_cycle_value.place.ignore_possibly_unbound() - { - UnionType::from_elements(db, [ty, previous_ty]) - } else { - ty - }; - - let div = Type::divergent(DivergentType::new( - db, - DivergenceKind::ClassLookupWithPolicy { - self_type: self, - name, - policy, - }, - )); - let visitor = RecursiveTypeNormalizedVisitor::new(div); - ty.recursive_type_normalized(db, &visitor) - }) + } } /// This function roughly corresponds to looking up an attribute in the `__dict__` of an object. @@ -3664,7 +3591,7 @@ impl<'db> Type<'db> { let name_str = name.as_str(); - let result = match self { + match self { Type::Union(union) => union.map_with_boundness_and_qualifiers(db, |elem| { elem.member_lookup_with_policy(db, name_str.into(), policy) }), @@ -3732,20 +3659,20 @@ impl<'db> Type<'db> { // If an attribute is not available on the bound method object, // it will be looked up on the underlying function object: Type::FunctionLiteral(bound_method.function(db)) - .member_lookup_with_policy(db, name.clone(), policy) + .member_lookup_with_policy(db, name, policy) }) } }, Type::KnownBoundMethod(method) => method .class() .to_instance(db) - .member_lookup_with_policy(db, name.clone(), policy), + .member_lookup_with_policy(db, name, policy), Type::WrapperDescriptor(_) => KnownClass::WrapperDescriptorType .to_instance(db) - .member_lookup_with_policy(db, name.clone(), policy), + .member_lookup_with_policy(db, name, policy), Type::DataclassDecorator(_) => KnownClass::FunctionType .to_instance(db) - .member_lookup_with_policy(db, name.clone(), policy), + .member_lookup_with_policy(db, name, policy), Type::Callable(_) | Type::DataclassTransformer(_) if name_str == "__call__" => { Place::bound(self).into() @@ -3753,10 +3680,10 @@ impl<'db> Type<'db> { Type::Callable(callable) if callable.is_function_like(db) => KnownClass::FunctionType .to_instance(db) - .member_lookup_with_policy(db, name.clone(), policy), + .member_lookup_with_policy(db, name, policy), Type::Callable(_) | Type::DataclassTransformer(_) => { - Type::object().member_lookup_with_policy(db, name.clone(), policy) + Type::object().member_lookup_with_policy(db, name, policy) } Type::NominalInstance(instance) @@ -3814,11 +3741,9 @@ impl<'db> Type<'db> { policy, ), - Type::TypeAlias(alias) => { - alias - .value_type(db) - .member_lookup_with_policy(db, name.clone(), policy) - } + Type::TypeAlias(alias) => alias + .value_type(db) + .member_lookup_with_policy(db, name, policy), Type::EnumLiteral(enum_literal) if matches!(name_str, "name" | "_name_") @@ -4003,30 +3928,7 @@ impl<'db> Type<'db> { .try_call_dunder_get_on_attribute(db, owner_attr.clone()) .unwrap_or(owner_attr) } - }; - result.map_type(|ty| { - // In fixed-point iteration of type inference, the member type must be monotonically widened and not "oscillate". - // Here, monotonicity is guaranteed by pre-unioning the type of the previous iteration into the current result. - let previous_cycle_value = self.member_lookup_with_policy(db, name.clone(), policy); - - let ty = if let Some(previous_ty) = previous_cycle_value.place.ignore_possibly_unbound() - { - UnionType::from_elements(db, [ty, previous_ty]) - } else { - ty - }; - - let div = Type::divergent(DivergentType::new( - db, - DivergenceKind::MemberLookupWithPolicy { - self_type: self, - name, - policy, - }, - )); - let visotor = RecursiveTypeNormalizedVisitor::new(div); - ty.recursive_type_normalized(db, &visotor) - }) + } } /// Resolves the boolean value of the type and falls back to [`Truthiness::Ambiguous`] if the type doesn't implement `__bool__` correctly. @@ -7485,59 +7387,17 @@ impl<'db> KnownInstanceType<'db> { } } -#[allow(private_interfaces)] -#[derive(Clone, Debug, Eq, Hash, PartialEq, salsa::Update, get_size2::GetSize)] -pub enum DivergenceKind<'db> { - /// Divergence from `{FunctionLiteral, BoundMethodType}::infer_return_type`. - InferReturnType(ScopeId<'db>), - /// Divergence from `ClassLiteral::implicit_attribute_inner`. - ImplicitAttribute { - class_body_scope: ScopeId<'db>, - name: String, - target_method_decorator: MethodDecorator, - }, - /// Divergence from `Type::member_lookup_with_policy`. - MemberLookupWithPolicy { - self_type: Type<'db>, - name: Name, - policy: MemberLookupPolicy, - }, - /// Divergence from `Type::class_lookup_with_policy`. - ClassLookupWithPolicy { - self_type: Type<'db>, - name: Name, - policy: MemberLookupPolicy, - }, - /// Divergence from `infer_expression_type_impl`. - InferExpression(InferExpression<'db>), - /// Divergence from `infer_expression_types_impl`. - InferExpressionTypes(InferExpression<'db>), - /// Divergence from `infer_definition_types`. - InferDefinitionTypes(Definition<'db>), - /// Divergence from `infer_scope_types`. - InferScopeTypes(ScopeId<'db>), - /// Divergence from `infer_unpack_types`. - InferUnpackTypes(Unpack<'db>), -} - -pub(crate) type CycleRecoveryType<'db> = Type<'db>; - /// A type that is determined to be divergent during recursive type inference. /// This type must never be eliminated by dynamic type reduction /// (e.g. `Divergent` is assignable to `@Todo`, but `@Todo | Divergent` must not be reducted to `@Todo`). /// Otherwise, type inference cannot converge properly. /// For detailed properties of this type, see the unit test at the end of the file. -#[salsa::interned(debug, heap_size=ruff_memory_usage::heap_size)] -#[derive(PartialOrd, Ord)] +#[derive(Copy, Clone, Debug, Eq, Hash, PartialEq, salsa::Update, get_size2::GetSize)] pub struct DivergentType<'db> { - /// The kind of divergence. - #[returns(ref)] - kind: DivergenceKind<'db>, + /// The scope where this divergence was detected. + scope: ScopeId<'db>, } -// The Salsa heap is tracked separately. -impl get_size2::GetSize for DivergentType<'_> {} - #[derive(Copy, Clone, Debug, Eq, Hash, PartialEq, salsa::Update, get_size2::GetSize)] pub enum DynamicType<'db> { /// An explicitly annotated `typing.Any` @@ -8230,7 +8090,7 @@ impl<'db> TypeVarInstance<'db> { Some(TypeVarBoundOrConstraints::UpperBound(ty)) } - #[salsa::tracked(cycle_fn=lazy_constraint_cycle_recover, cycle_initial=lazy_constraint_cycle_initial)] + #[salsa::tracked] fn lazy_constraints(self, db: &'db dyn Db) -> Option> { let definition = self.definition(db)?; let module = parsed_module(db, definition.file(db)).load(db); @@ -8240,7 +8100,7 @@ impl<'db> TypeVarInstance<'db> { Some(TypeVarBoundOrConstraints::Constraints(ty)) } - #[salsa::tracked(cycle_fn=lazy_default_cycle_recover, cycle_initial=lazy_default_cycle_initial)] + #[salsa::tracked] fn lazy_default(self, db: &'db dyn Db) -> Option> { let definition = self.definition(db)?; let module = parsed_module(db, definition.file(db)).load(db); @@ -8270,40 +8130,6 @@ fn lazy_bound_cycle_initial<'db>( None } -#[allow(clippy::ref_option)] -fn lazy_constraint_cycle_recover<'db>( - _db: &'db dyn Db, - _value: &Option>, - _count: u32, - _self: TypeVarInstance<'db>, -) -> salsa::CycleRecoveryAction>> { - salsa::CycleRecoveryAction::Iterate -} - -fn lazy_constraint_cycle_initial<'db>( - _db: &'db dyn Db, - _self: TypeVarInstance<'db>, -) -> Option> { - None -} - -#[allow(clippy::ref_option)] -fn lazy_default_cycle_recover<'db>( - _db: &'db dyn Db, - _value: &Option>, - _count: u32, - _self: TypeVarInstance<'db>, -) -> salsa::CycleRecoveryAction>> { - salsa::CycleRecoveryAction::Iterate -} - -fn lazy_default_cycle_initial<'db>( - _db: &'db dyn Db, - _self: TypeVarInstance<'db>, -) -> Option> { - None -} - /// Where a type variable is bound and usable. #[derive(Clone, Copy, Debug, Hash, PartialEq, Eq, salsa::Update, get_size2::GetSize)] pub enum BindingContext<'db> { @@ -12000,11 +11826,7 @@ pub(crate) mod tests { let file = system_path_to_file(&db, "src/foo.py").unwrap(); let file_scope_id = FileScopeId::global(); let scope = file_scope_id.to_scope_id(&db, file); - - let div = Type::divergent(DivergentType::new( - &db, - DivergenceKind::InferReturnType(scope), - )); + let div = Type::divergent(scope); // The `Divergent` type must not be eliminated in union with other dynamic types, // as this would prevent detection of divergent type inference using `Divergent`. diff --git a/crates/ty_python_semantic/src/types/class.rs b/crates/ty_python_semantic/src/types/class.rs index 349036bb160b6..f4f55fe6af689 100644 --- a/crates/ty_python_semantic/src/types/class.rs +++ b/crates/ty_python_semantic/src/types/class.rs @@ -26,12 +26,12 @@ use crate::types::tuple::{TupleSpec, TupleType}; use crate::types::typed_dict::typed_dict_params_from_class_def; use crate::types::{ ApplyTypeMappingVisitor, Binding, BoundSuperError, BoundSuperType, CallableType, - DataclassParams, DeprecatedInstance, DivergenceKind, DivergentType, FindLegacyTypeVarsVisitor, - HasDivergentTypeVisitor, HasRelationToVisitor, IsEquivalentVisitor, KnownInstanceType, - ManualPEP695TypeAliasType, MaterializationKind, NormalizedVisitor, PropertyInstanceType, - RecursiveTypeNormalizedVisitor, StringLiteralType, TypeAliasType, TypeContext, TypeMapping, - TypeRelation, TypeVarBoundOrConstraints, TypeVarInstance, TypeVarKind, TypedDictParams, - UnionBuilder, VarianceInferable, binding_type, declaration_type, determine_upper_bound, + DataclassParams, DeprecatedInstance, FindLegacyTypeVarsVisitor, HasDivergentTypeVisitor, + HasRelationToVisitor, IsEquivalentVisitor, KnownInstanceType, ManualPEP695TypeAliasType, + MaterializationKind, NormalizedVisitor, PropertyInstanceType, RecursiveTypeNormalizedVisitor, + StringLiteralType, TypeAliasType, TypeContext, TypeMapping, TypeRelation, + TypeVarBoundOrConstraints, TypeVarInstance, TypeVarKind, TypedDictParams, UnionBuilder, + VarianceInferable, declaration_type, determine_upper_bound, infer_definition_types, }; use crate::{ Db, FxIndexMap, FxOrderSet, Program, @@ -106,22 +106,13 @@ fn implicit_attribute_recover<'db>( salsa::CycleRecoveryAction::Iterate } -#[allow(clippy::needless_pass_by_value)] fn implicit_attribute_initial<'db>( - db: &'db dyn Db, - class_body_scope: ScopeId<'db>, - name: String, - target_method_decorator: MethodDecorator, + _db: &'db dyn Db, + _class_body_scope: ScopeId<'db>, + _name: String, + _target_method_decorator: MethodDecorator, ) -> PlaceAndQualifiers<'db> { - Place::bound(Type::divergent(DivergentType::new( - db, - DivergenceKind::ImplicitAttribute { - class_body_scope, - name, - target_method_decorator, - }, - ))) - .into() + Place::Unbound.into() } fn try_mro_cycle_recover<'db>( @@ -181,19 +172,6 @@ fn try_metaclass_cycle_initial<'db>( }) } -fn into_callable_cycle_recover<'db>( - _db: &'db dyn Db, - _value: &Type<'db>, - _count: u32, - _self: ClassType<'db>, -) -> salsa::CycleRecoveryAction> { - salsa::CycleRecoveryAction::Iterate -} - -fn into_callable_cycle_initial<'db>(db: &'db dyn Db, _self: ClassType<'db>) -> Type<'db> { - Type::Callable(CallableType::bottom(db)) -} - /// A category of classes with code generation capabilities (with synthesized methods). #[derive(Clone, Copy, Debug, PartialEq, salsa::Update, get_size2::GetSize)] pub(crate) enum CodeGeneratorKind { @@ -1268,6 +1246,19 @@ impl<'db> ClassType<'db> { } } +fn into_callable_cycle_recover<'db>( + _db: &'db dyn Db, + _value: &Type<'db>, + _count: u32, + _self: ClassType<'db>, +) -> salsa::CycleRecoveryAction> { + salsa::CycleRecoveryAction::Iterate +} + +fn into_callable_cycle_initial<'db>(db: &'db dyn Db, _self: ClassType<'db>) -> Type<'db> { + Type::Callable(CallableType::bottom(db)) +} + impl<'db> From> for ClassType<'db> { fn from(generic: GenericAlias<'db>) -> ClassType<'db> { ClassType::Generic(generic) @@ -1294,7 +1285,7 @@ impl<'db> VarianceInferable<'db> for ClassType<'db> { /// A filter that describes which methods are considered when looking for implicit attribute assignments /// in [`ClassLiteral::implicit_attribute`]. -#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, PartialOrd, Ord, get_size2::GetSize)] +#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)] pub(super) enum MethodDecorator { None, ClassMethod, @@ -2960,7 +2951,7 @@ impl<'db> ClassLiteral<'db> { cycle_initial=implicit_attribute_initial, heap_size=ruff_memory_usage::heap_size, )] - pub(super) fn implicit_attribute_inner( + fn implicit_attribute_inner( db: &'db dyn Db, class_body_scope: ScopeId<'db>, name: String, @@ -2980,15 +2971,6 @@ impl<'db> ClassLiteral<'db> { let index = semantic_index(db, file); let class_map = use_def_map(db, class_body_scope); let class_table = place_table(db, class_body_scope); - let div = DivergentType::new( - db, - DivergenceKind::ImplicitAttribute { - class_body_scope, - name: name.clone(), - target_method_decorator, - }, - ); - let visitor = RecursiveTypeNormalizedVisitor::new(Type::divergent(div)); let is_valid_scope = |method_scope: &Scope| { if let Some(method_def) = method_scope.node().as_function() { @@ -3046,8 +3028,7 @@ impl<'db> ClassLiteral<'db> { index.expression(value), TypeContext::default(), ); - return Place::bound(inferred_ty.recursive_type_normalized(db, &visitor)) - .with_qualifiers(all_qualifiers); + return Place::bound(inferred_ty).with_qualifiers(all_qualifiers); } // If there is no right-hand side, just record that we saw a `Final` qualifier @@ -3062,19 +3043,6 @@ impl<'db> ClassLiteral<'db> { if !qualifiers.contains(TypeQualifiers::FINAL) { union_of_inferred_types = union_of_inferred_types.add(Type::unknown()); } - if let Place::Type(previous_cycle_type, _) = Self::implicit_attribute_inner( - db, - class_body_scope, - name.clone(), - target_method_decorator, - ) - .place - { - // In fixed-point iteration of type inference, the attribute type must be monotonically widened and not "oscillate". - // Here, monotonicity is guaranteed by pre-unioning the type of the previous iteration into the current result. - union_of_inferred_types = union_of_inferred_types - .add(previous_cycle_type.recursive_type_normalized(db, &visitor)); - } for (attribute_assignments, method_scope_id) in attribute_assignments(db, class_body_scope, &name) @@ -3136,8 +3104,7 @@ impl<'db> ClassLiteral<'db> { let inferred_ty = unpacked.expression_type(assign.target(&module)); - union_of_inferred_types = union_of_inferred_types - .add(inferred_ty.recursive_type_normalized(db, &visitor)); + union_of_inferred_types = union_of_inferred_types.add(inferred_ty); } TargetKind::Single => { // We found an un-annotated attribute assignment of the form: @@ -3150,8 +3117,7 @@ impl<'db> ClassLiteral<'db> { TypeContext::default(), ); - union_of_inferred_types = union_of_inferred_types - .add(inferred_ty.recursive_type_normalized(db, &visitor)); + union_of_inferred_types = union_of_inferred_types.add(inferred_ty); } } } @@ -3166,8 +3132,7 @@ impl<'db> ClassLiteral<'db> { let inferred_ty = unpacked.expression_type(for_stmt.target(&module)); - union_of_inferred_types = union_of_inferred_types - .add(inferred_ty.recursive_type_normalized(db, &visitor)); + union_of_inferred_types = union_of_inferred_types.add(inferred_ty); } TargetKind::Single => { // We found an attribute assignment like: @@ -3183,8 +3148,7 @@ impl<'db> ClassLiteral<'db> { let inferred_ty = iterable_ty.iterate(db).homogeneous_element_type(db); - union_of_inferred_types = union_of_inferred_types - .add(inferred_ty.recursive_type_normalized(db, &visitor)); + union_of_inferred_types = union_of_inferred_types.add(inferred_ty); } } } @@ -3199,8 +3163,7 @@ impl<'db> ClassLiteral<'db> { let inferred_ty = unpacked.expression_type(with_item.target(&module)); - union_of_inferred_types = union_of_inferred_types - .add(inferred_ty.recursive_type_normalized(db, &visitor)); + union_of_inferred_types = union_of_inferred_types.add(inferred_ty); } TargetKind::Single => { // We found an attribute assignment like: @@ -3218,8 +3181,7 @@ impl<'db> ClassLiteral<'db> { context_ty.enter(db) }; - union_of_inferred_types = union_of_inferred_types - .add(inferred_ty.recursive_type_normalized(db, &visitor)); + union_of_inferred_types = union_of_inferred_types.add(inferred_ty); } } } @@ -3235,8 +3197,7 @@ impl<'db> ClassLiteral<'db> { let inferred_ty = unpacked.expression_type(comprehension.target(&module)); - union_of_inferred_types = union_of_inferred_types - .add(inferred_ty.recursive_type_normalized(db, &visitor)); + union_of_inferred_types = union_of_inferred_types.add(inferred_ty); } TargetKind::Single => { // We found an attribute assignment like: @@ -3252,8 +3213,7 @@ impl<'db> ClassLiteral<'db> { let inferred_ty = iterable_ty.iterate(db).homogeneous_element_type(db); - union_of_inferred_types = union_of_inferred_types - .add(inferred_ty.recursive_type_normalized(db, &visitor)); + union_of_inferred_types = union_of_inferred_types.add(inferred_ty); } } } @@ -5019,7 +4979,8 @@ impl KnownClass { }; let definition = index.expect_single_definition(first_param); - let first_param = binding_type(db, definition); + let first_param = + infer_definition_types(db, definition).binding_type(definition); let bound_super = BoundSuperType::build( db, diff --git a/crates/ty_python_semantic/src/types/function.rs b/crates/ty_python_semantic/src/types/function.rs index f88c783afe9f9..222fc6f0590cc 100644 --- a/crates/ty_python_semantic/src/types/function.rs +++ b/crates/ty_python_semantic/src/types/function.rs @@ -79,8 +79,8 @@ use crate::types::signatures::{CallableSignature, Signature}; use crate::types::visitor::any_over_type; use crate::types::{ BoundMethodType, BoundTypeVarInstance, CallableType, ClassBase, ClassLiteral, ClassType, - DeprecatedInstance, DivergenceKind, DivergentType, DynamicType, FindLegacyTypeVarsVisitor, - HasRelationToVisitor, IsEquivalentVisitor, KnownClass, KnownInstanceType, NormalizedVisitor, + DeprecatedInstance, DynamicType, FindLegacyTypeVarsVisitor, HasRelationToVisitor, + IsEquivalentVisitor, KnownClass, KnownInstanceType, NormalizedVisitor, RecursiveTypeNormalizedVisitor, SpecialFormType, TrackedConstraintSet, Truthiness, Type, TypeMapping, TypeRelation, UnionBuilder, all_members, binding_type, todo_type, walk_generic_context, walk_type_mapping, @@ -97,10 +97,7 @@ fn return_type_cycle_recover<'db>( } fn return_type_cycle_initial<'db>(db: &'db dyn Db, function: FunctionType<'db>) -> Type<'db> { - Type::Dynamic(DynamicType::Divergent(DivergentType::new( - db, - DivergenceKind::InferReturnType(function.literal(db).last_definition(db).body_scope(db)), - ))) + Type::divergent(function.literal(db).last_definition(db).body_scope(db)) } /// A collection of useful spans for annotating functions. diff --git a/crates/ty_python_semantic/src/types/generics.rs b/crates/ty_python_semantic/src/types/generics.rs index 63887ec18da50..4a972c208d7aa 100644 --- a/crates/ty_python_semantic/src/types/generics.rs +++ b/crates/ty_python_semantic/src/types/generics.rs @@ -20,7 +20,7 @@ use crate::types::{ HasDivergentTypeVisitor, HasRelationToVisitor, IsEquivalentVisitor, KnownClass, KnownInstanceType, MaterializationKind, NormalizedVisitor, RecursiveTypeNormalizedVisitor, Type, TypeMapping, TypeRelation, TypeVarBoundOrConstraints, TypeVarInstance, TypeVarKind, - TypeVarVariance, UnionType, binding_type, declaration_type, undecorated_type, + TypeVarVariance, UnionType, binding_type, declaration_type, infer_definition_types, }; use crate::{Db, FxOrderSet}; @@ -43,7 +43,8 @@ fn enclosing_generic_contexts<'db>( } NodeWithScopeKind::Function(function) => { let definition = index.expect_single_definition(function.node(module)); - undecorated_type(db, definition) + infer_definition_types(db, definition) + .undecorated_type() .expect("function should have undecorated type") .into_function_literal()? .last_definition_signature(db) diff --git a/crates/ty_python_semantic/src/types/infer.rs b/crates/ty_python_semantic/src/types/infer.rs index 5a464f57d97e6..c5e62aa02b6fc 100644 --- a/crates/ty_python_semantic/src/types/infer.rs +++ b/crates/ty_python_semantic/src/types/infer.rs @@ -31,7 +31,7 @@ //! //! Many of our type inference Salsa queries implement cycle recovery via fixed-point iteration. In //! general, they initiate fixed-point iteration by returning an `Inference` type that returns -//! the `Divergent` type for all expressions, bindings, and declarations, and then they continue iterating +//! `Type::Never` for all expressions, bindings, and declarations, and then they continue iterating //! the query cycle until a fixed-point is reached. Salsa has a built-in fixed limit on the number //! of iterations, so if we fail to converge, Salsa will eventually panic. (This should of course //! be considered a bug.) @@ -53,9 +53,8 @@ use crate::types::diagnostic::TypeCheckDiagnostics; use crate::types::generics::Specialization; use crate::types::unpacker::{UnpackResult, Unpacker}; use crate::types::{ - ClassLiteral, CycleRecoveryType, DivergenceKind, DivergentType, KnownClass, - RecursiveTypeNormalizedVisitor, Truthiness, Type, TypeAndQualifiers, UnionBuilder, UnionType, - declaration_type, + ClassLiteral, KnownClass, RecursiveTypeNormalizedVisitor, Truthiness, Type, TypeAndQualifiers, + UnionBuilder, }; use crate::unpack::Unpack; use builder::TypeInferenceBuilder; @@ -64,13 +63,14 @@ mod builder; #[cfg(test)] mod tests; +/// How many fixpoint iterations to allow before falling back to Divergent type. +const ITERATIONS_BEFORE_FALLBACK: u32 = 10; + /// Infer all types for a [`ScopeId`], including all definitions and expressions in that scope. /// Use when checking a scope, or needing to provide a type for an arbitrary expression in the /// scope. -/// When using types ​​in [`ScopeInference`], you must use [`ScopeInference::cycle_recovery`]. -/// Alternatively, consider using a cycle-safe function such as [`infer_scope_expression_type`]. #[salsa::tracked(returns(ref), cycle_fn=scope_cycle_recover, cycle_initial=scope_cycle_initial, heap_size=ruff_memory_usage::heap_size)] -pub(super) fn infer_scope_types<'db>(db: &'db dyn Db, scope: ScopeId<'db>) -> ScopeInference<'db> { +pub(crate) fn infer_scope_types<'db>(db: &'db dyn Db, scope: ScopeId<'db>) -> ScopeInference<'db> { let file = scope.file(db); let _span = tracing::trace_span!("infer_scope_types", scope=?scope.as_id(), ?file).entered(); @@ -92,35 +92,14 @@ fn scope_cycle_recover<'db>( salsa::CycleRecoveryAction::Iterate } -fn scope_cycle_initial<'db>(db: &'db dyn Db, scope: ScopeId<'db>) -> ScopeInference<'db> { - ScopeInference::cycle_initial( - Type::divergent(DivergentType::new( - db, - DivergenceKind::InferScopeTypes(scope), - )), - scope, - ) -} - -pub(crate) fn infer_scope_expression_type<'db>( - db: &'db dyn Db, - scope: ScopeId<'db>, - expr: impl Into, -) -> Type<'db> { - let inference = infer_scope_types(db, scope); - if let Some(cycle_recovery) = inference.cycle_recovery() { - UnionType::from_elements(db, [inference.expression_type(expr), cycle_recovery]) - } else { - inference.expression_type(expr) - } +fn scope_cycle_initial<'db>(_db: &'db dyn Db, scope: ScopeId<'db>) -> ScopeInference<'db> { + ScopeInference::cycle_initial(scope) } /// Infer all types for a [`Definition`] (including sub-expressions). /// Use when resolving a place use or public type of a place. -/// When using types ​​in [`DefinitionInference`], you must use [`DefinitionInference::cycle_recovery`]. -/// Alternatively, consider using a cycle-safe function such as [`crate::types::binding_type`]. #[salsa::tracked(returns(ref), cycle_fn=definition_cycle_recover, cycle_initial=definition_cycle_initial, heap_size=ruff_memory_usage::heap_size)] -pub(super) fn infer_definition_types<'db>( +pub(crate) fn infer_definition_types<'db>( db: &'db dyn Db, definition: Definition<'db>, ) -> DefinitionInference<'db> { @@ -140,25 +119,25 @@ pub(super) fn infer_definition_types<'db>( } fn definition_cycle_recover<'db>( - _db: &'db dyn Db, + db: &'db dyn Db, _value: &DefinitionInference<'db>, - _count: u32, - _definition: Definition<'db>, + count: u32, + definition: Definition<'db>, ) -> salsa::CycleRecoveryAction> { - salsa::CycleRecoveryAction::Iterate + if count == ITERATIONS_BEFORE_FALLBACK { + salsa::CycleRecoveryAction::Fallback(DefinitionInference::cycle_fallback( + definition.scope(db), + )) + } else { + salsa::CycleRecoveryAction::Iterate + } } fn definition_cycle_initial<'db>( db: &'db dyn Db, definition: Definition<'db>, ) -> DefinitionInference<'db> { - DefinitionInference::cycle_initial( - definition.scope(db), - Type::divergent(DivergentType::new( - db, - DivergenceKind::InferDefinitionTypes(definition), - )), - ) + DefinitionInference::cycle_initial(definition.scope(db)) } /// Infer types for all deferred type expressions in a [`Definition`]. @@ -199,16 +178,14 @@ fn deferred_cycle_initial<'db>( db: &'db dyn Db, definition: Definition<'db>, ) -> DefinitionInference<'db> { - DefinitionInference::cycle_initial(definition.scope(db), Type::Never) + DefinitionInference::cycle_initial(definition.scope(db)) } /// Infer all types for an [`Expression`] (including sub-expressions). /// Use rarely; only for cases where we'd otherwise risk double-inferring an expression: RHS of an /// assignment, which might be unpacking/multi-target and thus part of multiple definitions, or a /// type narrowing guard expression (e.g. if statement test node). -/// When using types ​​in [`ExpressionInference`], you must use [`ExpressionInference::cycle_recovery`]. -/// Alternatively, consider using a cycle-safe function such as [`infer_expression_type`]. -pub(super) fn infer_expression_types<'db>( +pub(crate) fn infer_expression_types<'db>( db: &'db dyn Db, expression: Expression<'db>, tcx: TypeContext<'db>, @@ -218,7 +195,7 @@ pub(super) fn infer_expression_types<'db>( /// When using types ​​in [`ExpressionInference`], you must use [`ExpressionInference::cycle_recovery`]. #[salsa::tracked(returns(ref), cycle_fn=expression_cycle_recover, cycle_initial=expression_cycle_initial, heap_size=ruff_memory_usage::heap_size)] -pub(super) fn infer_expression_types_impl<'db>( +fn infer_expression_types_impl<'db>( db: &'db dyn Db, input: InferExpression<'db>, ) -> ExpressionInference<'db> { @@ -242,7 +219,7 @@ pub(super) fn infer_expression_types_impl<'db>( index, &module, ) - .finish_expression(input) + .finish_expression() } /// Infer the type of an expression in isolation. @@ -264,23 +241,25 @@ pub(crate) fn infer_isolated_expression<'db>( } fn expression_cycle_recover<'db>( - _db: &'db dyn Db, + db: &'db dyn Db, _value: &ExpressionInference<'db>, - _count: u32, - _input: InferExpression<'db>, + count: u32, + input: InferExpression<'db>, ) -> salsa::CycleRecoveryAction> { - salsa::CycleRecoveryAction::Iterate + if count == ITERATIONS_BEFORE_FALLBACK { + salsa::CycleRecoveryAction::Fallback(ExpressionInference::cycle_fallback( + input.expression(db).scope(db), + )) + } else { + salsa::CycleRecoveryAction::Iterate + } } fn expression_cycle_initial<'db>( db: &'db dyn Db, input: InferExpression<'db>, ) -> ExpressionInference<'db> { - let cycle_recovery = Type::divergent(DivergentType::new( - db, - DivergenceKind::InferExpressionTypes(input), - )); - ExpressionInference::cycle_initial(input.expression(db).scope(db), cycle_recovery) + ExpressionInference::cycle_initial(input.expression(db).scope(db)) } /// Infers the type of an `expression` that is guaranteed to be in the same file as the calling query. @@ -288,7 +267,7 @@ fn expression_cycle_initial<'db>( /// This is a small helper around [`infer_expression_types()`] to reduce the boilerplate. /// Use [`infer_expression_type()`] if it isn't guaranteed that `expression` is in the same file to /// avoid cross-file query dependencies. -pub(crate) fn infer_same_file_expression_type<'db>( +pub(super) fn infer_same_file_expression_type<'db>( db: &'db dyn Db, expression: Expression<'db>, tcx: TypeContext<'db>, @@ -320,16 +299,7 @@ fn infer_expression_type_impl<'db>(db: &'db dyn Db, input: InferExpression<'db>) // It's okay to call the "same file" version here because we're inside a salsa query. let inference = infer_expression_types_impl(db, input); - - let div = Type::divergent(DivergentType::new( - db, - DivergenceKind::InferExpression(input), - )); - let previous_cycle_value = infer_expression_type_impl(db, input); - let visitor = RecursiveTypeNormalizedVisitor::new(div); - let result_ty = inference.expression_type(input.expression(db).node_ref(db, &module)); - UnionType::from_elements(db, [result_ty, previous_cycle_value]) - .recursive_type_normalized(db, &visitor) + inference.expression_type(input.expression(db).node_ref(db, &module)) } fn single_expression_cycle_recover<'db>( @@ -341,27 +311,25 @@ fn single_expression_cycle_recover<'db>( salsa::CycleRecoveryAction::Iterate } -fn single_expression_cycle_initial<'db>(db: &'db dyn Db, input: InferExpression<'db>) -> Type<'db> { - Type::divergent(DivergentType::new( - db, - DivergenceKind::InferExpression(input), - )) +fn single_expression_cycle_initial<'db>( + _db: &'db dyn Db, + _input: InferExpression<'db>, +) -> Type<'db> { + Type::Never } /// An `Expression` with an optional `TypeContext`. /// /// This is a Salsa supertype used as the input to `infer_expression_types` to avoid /// interning an `ExpressionWithContext` unnecessarily when no type context is provided. -#[derive( - Debug, Clone, Copy, Eq, Hash, PartialEq, salsa::Supertype, salsa::Update, get_size2::GetSize, -)] -pub(super) enum InferExpression<'db> { +#[derive(Debug, Clone, Copy, Eq, Hash, PartialEq, salsa::Supertype, salsa::Update)] +enum InferExpression<'db> { Bare(Expression<'db>), WithContext(ExpressionWithContext<'db>), } impl<'db> InferExpression<'db> { - pub(super) fn new( + fn new( db: &'db dyn Db, expression: Expression<'db>, tcx: TypeContext<'db>, @@ -376,7 +344,7 @@ impl<'db> InferExpression<'db> { fn expression(self, db: &'db dyn Db) -> Expression<'db> { match self { - InferExpression::Bare(bare) => bare, + InferExpression::Bare(expression) => expression, InferExpression::WithContext(expression_with_context) => { expression_with_context.expression(db) } @@ -393,17 +361,13 @@ impl<'db> InferExpression<'db> { } } -/// An [`Expression`] with a [`TypeContext`]. +/// An `Expression` with a `TypeContext`. #[salsa::interned(debug, heap_size=ruff_memory_usage::heap_size)] -#[derive(PartialOrd, Ord)] -pub(super) struct ExpressionWithContext<'db> { +struct ExpressionWithContext<'db> { expression: Expression<'db>, tcx: TypeContext<'db>, } -/// The Salsa heap is tracked separately. -impl get_size2::GetSize for ExpressionWithContext<'_> {} - /// The type context for a given expression, namely the type annotation /// in an annotated assignment. /// @@ -461,6 +425,7 @@ pub(crate) fn static_expression_truthiness<'db>( let file = expression.file(db); let module = parsed_module(db, file).load(db); let node = expression.node_ref(db, &module); + inference.expression_type(node).bool(db) } @@ -496,7 +461,7 @@ pub(super) fn infer_unpack_types<'db>(db: &'db dyn Db, unpack: Unpack<'db>) -> U let mut unpacker = Unpacker::new(db, unpack.target_scope(db), &module); unpacker.unpack(unpack.target(db, &module), unpack.value(db)); - unpacker.finish(unpack) + unpacker.finish() } fn unpack_cycle_recover<'db>( @@ -508,12 +473,8 @@ fn unpack_cycle_recover<'db>( salsa::CycleRecoveryAction::Iterate } -fn unpack_cycle_initial<'db>(db: &'db dyn Db, unpack: Unpack<'db>) -> UnpackResult<'db> { - let cycle_recovery = Type::divergent(DivergentType::new( - db, - DivergenceKind::InferUnpackTypes(unpack), - )); - UnpackResult::cycle_initial(cycle_recovery) +fn unpack_cycle_initial<'db>(_db: &'db dyn Db, _unpack: Unpack<'db>) -> UnpackResult<'db> { + UnpackResult::cycle_initial(Type::Never) } /// Returns the type of the nearest enclosing class for the given scope. @@ -535,7 +496,8 @@ pub(crate) fn nearest_enclosing_class<'db>( .find_map(|(_, ancestor_scope)| { let class = ancestor_scope.node().as_class()?; let definition = semantic.expect_single_definition(class); - declaration_type(db, definition) + infer_definition_types(db, definition) + .declaration_type(definition) .inner_type() .into_class_literal() }) @@ -566,6 +528,35 @@ impl<'db> InferenceRegion<'db> { } } +#[derive(Debug, Clone, Copy, Eq, PartialEq, get_size2::GetSize, salsa::Update)] +enum CycleRecovery<'db> { + /// An initial-value for fixpoint iteration; all types are `Type::Never`. + Initial, + /// A divergence-fallback value for fixpoint iteration; all types are `Divergent`. + Divergent(ScopeId<'db>), +} + +impl<'db> CycleRecovery<'db> { + fn merge(self, other: Option>) -> Self { + if let Some(other) = other { + match (self, other) { + // It's important here that we keep the scope of `self` if merging two `Divergent`. + (Self::Divergent(scope), _) | (_, Self::Divergent(scope)) => Self::Divergent(scope), + _ => Self::Initial, + } + } else { + self + } + } + + fn fallback_type(self) -> Type<'db> { + match self { + Self::Initial => Type::Never, + Self::Divergent(scope) => Type::divergent(scope), + } + } +} + /// The inferred types for a scope region. #[derive(Debug, Eq, PartialEq, salsa::Update, get_size2::GetSize)] pub(crate) struct ScopeInference<'db> { @@ -580,8 +571,8 @@ pub(crate) struct ScopeInference<'db> { #[derive(Debug, Eq, PartialEq, get_size2::GetSize, salsa::Update, Default)] struct ScopeInferenceExtra<'db> { - /// The fallback type for missing expressions/bindings/declarations or recursive type inference. - cycle_recovery: Option>, + /// Is this a cycle-recovery inference result, and if so, what kind? + cycle_recovery: Option>, /// The diagnostics for this region. diagnostics: TypeCheckDiagnostics, @@ -593,14 +584,14 @@ struct ScopeInferenceExtra<'db> { } impl<'db> ScopeInference<'db> { - fn cycle_initial(cycle_recovery: CycleRecoveryType<'db>, scope: ScopeId<'db>) -> Self { + fn cycle_initial(scope: ScopeId<'db>) -> Self { Self { + scope, extra: Some(Box::new(ScopeInferenceExtra { - cycle_recovery: Some(cycle_recovery), + cycle_recovery: Some(CycleRecovery::Initial), ..ScopeInferenceExtra::default() })), expressions: FxHashMap::default(), - scope, } } @@ -624,7 +615,9 @@ impl<'db> ScopeInference<'db> { } fn fallback_type(&self) -> Option> { - self.extra.as_ref().and_then(|extra| extra.cycle_recovery) + self.extra + .as_ref() + .and_then(|extra| extra.cycle_recovery.map(CycleRecovery::fallback_type)) } /// When using `ScopeInference` during type inference, @@ -645,10 +638,7 @@ impl<'db> ScopeInference<'db> { } let mut union = UnionBuilder::new(db); - let div = Type::divergent(DivergentType::new( - db, - DivergenceKind::InferReturnType(self.scope), - )); + let div = Type::divergent(self.scope); if let Some(cycle_recovery) = self.cycle_recovery() { union = union.add(cycle_recovery); } @@ -738,8 +728,8 @@ pub(crate) struct DefinitionInference<'db> { #[derive(Debug, Eq, PartialEq, get_size2::GetSize, salsa::Update, Default)] struct DefinitionInferenceExtra<'db> { - /// The fallback type for missing expressions/bindings/declarations or recursive type inference. - cycle_recovery: Option>, + /// Is this a cycle-recovery inference result, and if so, what kind? + cycle_recovery: Option>, /// The definitions that are deferred. deferred: Box<[Definition<'db>]>, @@ -752,7 +742,7 @@ struct DefinitionInferenceExtra<'db> { } impl<'db> DefinitionInference<'db> { - fn cycle_initial(scope: ScopeId<'db>, cycle_recovery: CycleRecoveryType<'db>) -> Self { + fn cycle_initial(scope: ScopeId<'db>) -> Self { let _ = scope; Self { @@ -762,7 +752,23 @@ impl<'db> DefinitionInference<'db> { #[cfg(debug_assertions)] scope, extra: Some(Box::new(DefinitionInferenceExtra { - cycle_recovery: Some(cycle_recovery), + cycle_recovery: Some(CycleRecovery::Initial), + ..DefinitionInferenceExtra::default() + })), + } + } + + fn cycle_fallback(scope: ScopeId<'db>) -> Self { + let _ = scope; + + Self { + expressions: FxHashMap::default(), + bindings: Box::default(), + declarations: Box::default(), + #[cfg(debug_assertions)] + scope, + extra: Some(Box::new(DefinitionInferenceExtra { + cycle_recovery: Some(CycleRecovery::Divergent(scope)), ..DefinitionInferenceExtra::default() })), } @@ -832,7 +838,9 @@ impl<'db> DefinitionInference<'db> { } fn fallback_type(&self) -> Option> { - self.extra.as_ref().and_then(|extra| extra.cycle_recovery) + self.extra + .as_ref() + .and_then(|extra| extra.cycle_recovery.map(CycleRecovery::fallback_type)) } /// When using `DefinitionInference` during type inference, @@ -871,23 +879,39 @@ struct ExpressionInferenceExtra<'db> { /// The diagnostics for this region. diagnostics: TypeCheckDiagnostics, - /// The fallback type for missing expressions/bindings/declarations or recursive type inference. - cycle_recovery: Option>, + /// Is this a cycle recovery inference result, and if so, what kind? + cycle_recovery: Option>, /// `true` if all places in this expression are definitely bound all_definitely_bound: bool, } impl<'db> ExpressionInference<'db> { - fn cycle_initial(scope: ScopeId<'db>, cycle_recovery: CycleRecoveryType<'db>) -> Self { + fn cycle_initial(scope: ScopeId<'db>) -> Self { let _ = scope; Self { extra: Some(Box::new(ExpressionInferenceExtra { - cycle_recovery: Some(cycle_recovery), + cycle_recovery: Some(CycleRecovery::Initial), all_definitely_bound: true, ..ExpressionInferenceExtra::default() })), expressions: FxHashMap::default(), + + #[cfg(debug_assertions)] + scope, + } + } + + fn cycle_fallback(scope: ScopeId<'db>) -> Self { + let _ = scope; + Self { + extra: Some(Box::new(ExpressionInferenceExtra { + cycle_recovery: Some(CycleRecovery::Divergent(scope)), + all_definitely_bound: true, + ..ExpressionInferenceExtra::default() + })), + expressions: FxHashMap::default(), + #[cfg(debug_assertions)] scope, } @@ -909,11 +933,14 @@ impl<'db> ExpressionInference<'db> { } fn fallback_type(&self) -> Option> { - self.extra.as_ref().and_then(|extra| extra.cycle_recovery) + self.extra + .as_ref() + .and_then(|extra| extra.cycle_recovery.map(CycleRecovery::fallback_type)) } /// When using `ExpressionInference` during type inference, /// use this method to get the cycle recovery type so that divergent types are propagated. + #[allow(unused)] pub(super) fn cycle_recovery(&self) -> Option> { self.fallback_type() } diff --git a/crates/ty_python_semantic/src/types/infer/builder.rs b/crates/ty_python_semantic/src/types/infer/builder.rs index 161a5ee783c12..e8fbb7dade750 100644 --- a/crates/ty_python_semantic/src/types/infer/builder.rs +++ b/crates/ty_python_semantic/src/types/infer/builder.rs @@ -12,7 +12,7 @@ use super::{ DefinitionInference, DefinitionInferenceExtra, ExpressionInference, ExpressionInferenceExtra, InferenceRegion, ScopeInference, ScopeInferenceExtra, infer_deferred_types, infer_definition_types, infer_expression_types, infer_same_file_expression_type, - infer_scope_types, infer_unpack_types, + infer_unpack_types, }; use crate::module_name::{ModuleName, ModuleNameResolutionError}; use crate::module_resolver::{ @@ -76,7 +76,7 @@ use crate::types::function::{ use crate::types::generics::{ GenericContext, LegacyGenericBase, SpecializationBuilder, bind_typevar, }; -use crate::types::infer::{infer_expression_types_impl, infer_scope_expression_type}; +use crate::types::infer::CycleRecovery; use crate::types::instance::SliceLiteral; use crate::types::mro::MroErrorKind; use crate::types::signatures::{Parameter, Parameters, Signature}; @@ -88,14 +88,13 @@ use crate::types::typed_dict::{ }; use crate::types::visitor::any_over_type; use crate::types::{ - CallDunderError, CallableType, ClassLiteral, ClassType, CycleRecoveryType, DataclassParams, - DivergenceKind, DivergentType, DynamicType, InferExpression, IntersectionBuilder, - IntersectionType, KnownClass, KnownInstanceType, MemberLookupPolicy, MetaclassCandidate, - PEP695TypeAliasType, ParameterForm, RecursiveTypeNormalizedVisitor, SpecialFormType, - SubclassOfType, TrackedConstraintSet, Truthiness, Type, TypeAliasType, TypeAndQualifiers, - TypeContext, TypeMapping, TypeQualifiers, TypeVarBoundOrConstraintsEvaluation, - TypeVarDefaultEvaluation, TypeVarInstance, TypeVarKind, UnionBuilder, UnionType, binding_type, - infer_expression_type, todo_type, + CallDunderError, CallableType, ClassLiteral, ClassType, DataclassParams, DynamicType, + IntersectionBuilder, IntersectionType, KnownClass, KnownInstanceType, MemberLookupPolicy, + MetaclassCandidate, PEP695TypeAliasType, ParameterForm, SpecialFormType, SubclassOfType, + TrackedConstraintSet, Truthiness, Type, TypeAliasType, TypeAndQualifiers, TypeContext, + TypeMapping, TypeQualifiers, TypeVarBoundOrConstraintsEvaluation, TypeVarDefaultEvaluation, + TypeVarInstance, TypeVarKind, UnionBuilder, UnionType, binding_type, infer_scope_types, + todo_type, }; use crate::types::{ClassBase, add_inferred_python_version_hint_to_diagnostic}; use crate::unpack::{EvaluationMode, UnpackPosition}; @@ -264,8 +263,8 @@ pub(super) struct TypeInferenceBuilder<'db, 'ast> { /// For function definitions, the undecorated type of the function. undecorated_type: Option>, - /// The fallback type for missing expressions/bindings/declarations or recursive type inference. - cycle_recovery: Option>, + /// Did we merge in a sub-region with a cycle-recovery fallback, and if so, what kind? + cycle_recovery: Option>, /// `true` if all places in this expression are definitely bound all_definitely_bound: bool, @@ -306,22 +305,15 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> { } } - fn fallback_type(&self) -> Option> { - self.cycle_recovery + fn extend_cycle_recovery(&mut self, other_recovery: Option>) { + match &mut self.cycle_recovery { + Some(recovery) => *recovery = recovery.merge(other_recovery), + recovery @ None => *recovery = other_recovery, + } } - fn extend_cycle_recovery(&mut self, other: Option>) { - if let Some(other) = other { - match self.cycle_recovery { - Some(existing) => { - self.cycle_recovery = - Some(UnionType::from_elements(self.db(), [existing, other])); - } - None => { - self.cycle_recovery = Some(other); - } - } - } + fn fallback_type(&self) -> Option> { + self.cycle_recovery.map(CycleRecovery::fallback_type) } fn extend_definition(&mut self, inference: &DefinitionInference<'db>) { @@ -450,7 +442,7 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> { InferenceRegion::Scope(scope) if scope == expr_scope => { self.expression_type(expression) } - _ => infer_scope_expression_type(self.db(), expr_scope, expression), + _ => infer_scope_types(self.db(), expr_scope).expression_type(expression), } } @@ -1835,14 +1827,10 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> { let definition_types = infer_definition_types(self.db(), definition); - function.decorator_list.iter().map(move |decorator| { - let decorator_ty = definition_types.expression_type(&decorator.expression); - if let Some(cycle_recovery) = definition_types.cycle_recovery() { - UnionType::from_elements(self.db(), [decorator_ty, cycle_recovery]) - } else { - decorator_ty - } - }) + function + .decorator_list + .iter() + .map(move |decorator| definition_types.expression_type(&decorator.expression)) } /// Returns `true` if the current scope is the function body scope of a function overload (that @@ -3921,8 +3909,7 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> { ) => { self.store_expression_type(target, assigned_ty.unwrap_or(Type::unknown())); - let object_ty = - self.infer_maybe_standalone_expression(object, TypeContext::default()); + let object_ty = self.infer_expression(object, TypeContext::default()); if let Some(assigned_ty) = assigned_ty { self.validate_attribute_assignment( @@ -4535,12 +4522,7 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> { // Check non-star imports for deprecations if definition.kind(self.db()).as_star_import().is_none() { for ty in inferred.declaration_types() { - let ty = if let Some(cycle_recovery) = inferred.cycle_recovery() { - UnionType::from_elements(self.db(), [ty.inner, cycle_recovery]) - } else { - ty.inner - }; - self.check_deprecated(alias, ty); + self.check_deprecated(alias, ty.inner); } } self.extend_definition(inferred); @@ -5292,9 +5274,16 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> { let mut annotated_elt_tys = annotated_tuple.as_ref().map(Tuple::all_elements); let db = self.db(); + let divergent = Type::divergent(self.scope()); let element_types = elts.iter().map(|element| { let annotated_elt_ty = annotated_elt_tys.as_mut().and_then(Iterator::next).copied(); - self.infer_expression(element, TypeContext::new(annotated_elt_ty)) + let element_type = self.infer_expression(element, TypeContext::new(annotated_elt_ty)); + + if element_type.has_divergent_type(self.db(), divergent) { + divergent + } else { + element_type + } }); Type::heterogeneous_tuple(db, element_types) @@ -5595,7 +5584,6 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> { let mut infer_iterable_type = || { let expression = self.index.expression(iterable); let result = infer_expression_types(self.db(), expression, TypeContext::default()); - let iterable = infer_expression_type(self.db(), expression, TypeContext::default()); // Two things are different if it's the first comprehension: // (1) We must lookup the `ScopedExpressionId` of the iterable expression in the outer scope, @@ -5604,10 +5592,10 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> { // because `ScopedExpressionId`s are only meaningful within their own scope, so // we'd add types for random wrong expressions in the current scope if comprehension.is_first() && target.is_name_expr() { - iterable + result.expression_type(iterable) } else { self.extend_expression_unchecked(result); - iterable + result.expression_type(iterable) } }; @@ -5646,7 +5634,7 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> { let definition = self.index.expect_single_definition(named); let result = infer_definition_types(self.db(), definition); self.extend_definition(result); - binding_type(self.db(), definition) + result.binding_type(definition) } else { // For syntactically invalid targets, we still need to run type inference: self.infer_expression(&named.target, TypeContext::default()); @@ -6854,7 +6842,7 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> { match ctx { ExprContext::Load => self.infer_attribute_load(attribute), ExprContext::Store => { - self.infer_maybe_standalone_expression(value, TypeContext::default()); + self.infer_expression(value, TypeContext::default()); Type::Never } ExprContext::Del => { @@ -6862,7 +6850,7 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> { Type::Never } ExprContext::Invalid => { - self.infer_maybe_standalone_expression(value, TypeContext::default()); + self.infer_expression(value, TypeContext::default()); Type::unknown() } } @@ -9044,18 +9032,14 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> { expr_ty } - pub(super) fn finish_expression( - mut self, - input: InferExpression<'db>, - ) -> ExpressionInference<'db> { + pub(super) fn finish_expression(mut self) -> ExpressionInference<'db> { self.infer_region(); - let db = self.db(); let Self { context, mut expressions, scope, - mut bindings, + bindings, declarations, deferred, cycle_recovery, @@ -9085,12 +9069,6 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> { "Expression region can't have deferred types" ); - let div = Type::divergent(DivergentType::new( - db, - DivergenceKind::InferExpressionTypes(input), - )); - let visitor = RecursiveTypeNormalizedVisitor::new(div); - let previous_cycle_value = infer_expression_types_impl(db, input); let extra = (cycle_recovery.is_some() || !bindings.is_empty() || !diagnostics.is_empty() || !all_definitely_bound).then(|| { if bindings.len() > 20 { @@ -9100,17 +9078,6 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> { bindings.len() ); } - for (binding, binding_ty) in bindings.iter_mut() { - if let Some((_, previous_binding)) = previous_cycle_value.extra.as_deref() - .and_then(|extra| extra.bindings.iter().find(|(previous_binding, _)| previous_binding == binding)) { - *binding_ty = UnionType::from_elements( - db, - [*binding_ty, *previous_binding], - ).recursive_type_normalized(db, &visitor); - } else { - *binding_ty = binding_ty.recursive_type_normalized(db, &visitor); - } - } Box::new(ExpressionInferenceExtra { bindings: bindings.into_boxed_slice(), @@ -9121,11 +9088,6 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> { }); expressions.shrink_to_fit(); - for (expr, ty) in &mut expressions { - let previous_ty = previous_cycle_value.expression_type(*expr); - *ty = UnionType::from_elements(db, [*ty, previous_ty]) - .recursive_type_normalized(db, &visitor); - } ExpressionInference { expressions, @@ -9137,33 +9099,26 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> { pub(super) fn finish_definition(mut self) -> DefinitionInference<'db> { self.infer_region(); - let db = self.db(); let Self { context, mut expressions, scope, - mut bindings, - mut declarations, + bindings, + declarations, deferred, cycle_recovery, undecorated_type, all_definitely_bound: _, // builder only state typevar_binding_context: _, - deferred_state, + deferred_state: _, called_functions: _, index: _, - region, + region: _, returnees: _, } = self; - let (InferenceRegion::Definition(definition) | InferenceRegion::Deferred(definition)) = - region - else { - panic!("expected definition/deferred region"); - }; - let _ = scope; let diagnostics = context.finish(); @@ -9197,32 +9152,6 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> { } expressions.shrink_to_fit(); - if !matches!(region, InferenceRegion::Deferred(_)) && !deferred_state.is_deferred() { - let div = Type::divergent(DivergentType::new( - db, - DivergenceKind::InferDefinitionTypes(definition), - )); - let visitor = RecursiveTypeNormalizedVisitor::new(div); - let previous_cycle_value = infer_definition_types(db, definition); - for (expr, ty) in &mut expressions { - let previous_ty = previous_cycle_value.expression_type(*expr); - *ty = UnionType::from_elements(db, [*ty, previous_ty]) - .recursive_type_normalized(db, &visitor); - } - - for (binding, binding_ty) in bindings.iter_mut() { - let previous_ty = previous_cycle_value.binding_type(*binding); - *binding_ty = UnionType::from_elements(db, [*binding_ty, previous_ty]) - .recursive_type_normalized(db, &visitor); - } - for (declaration, TypeAndQualifiers { inner, .. }) in declarations.iter_mut() { - let previous_ty = previous_cycle_value - .declaration_type(*declaration) - .inner_type(); - *inner = UnionType::from_elements(db, [*inner, previous_ty]) - .recursive_type_normalized(db, &visitor); - } - } DefinitionInference { expressions, @@ -9280,22 +9209,6 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> { }); expressions.shrink_to_fit(); - if let NodeWithScopeKind::TypeAlias(_) = scope.node(db) { - // Don't perform recursive type normalization on type aliases - } else { - let div = Type::divergent(DivergentType::new( - db, - DivergenceKind::InferScopeTypes(scope), - )); - let visitor = RecursiveTypeNormalizedVisitor::new(div); - let previous_cycle_value = infer_scope_types(db, scope); - - for (expr, ty) in &mut expressions { - let previous_ty = previous_cycle_value.expression_type(*expr); - *ty = UnionType::from_elements(db, [*ty, previous_ty]) - .recursive_type_normalized(db, &visitor); - } - } ScopeInference { expressions, @@ -9561,10 +9474,6 @@ where self.0.iter().map(|(k, v)| (k, v)) } - fn iter_mut(&mut self) -> impl ExactSizeIterator { - self.0.iter_mut().map(|(k, v)| (&*k, v)) - } - fn insert(&mut self, key: K, value: V) { debug_assert!( !self.0.iter().any(|(existing, _)| existing == &key), diff --git a/crates/ty_python_semantic/src/types/infer/builder/type_expression.rs b/crates/ty_python_semantic/src/types/infer/builder/type_expression.rs index 4ffee16332620..e48d0582017a6 100644 --- a/crates/ty_python_semantic/src/types/infer/builder/type_expression.rs +++ b/crates/ty_python_semantic/src/types/infer/builder/type_expression.rs @@ -21,7 +21,11 @@ use crate::types::{ impl<'db> TypeInferenceBuilder<'db, '_> { /// Infer the type of a type expression. pub(super) fn infer_type_expression(&mut self, expression: &ast::Expr) -> Type<'db> { - let ty = self.infer_type_expression_no_store(expression); + let mut ty = self.infer_type_expression_no_store(expression); + let divergent = Type::divergent(self.scope()); + if ty.has_divergent_type(self.db(), divergent) { + ty = divergent; + } self.store_expression_type(expression, ty); ty } diff --git a/crates/ty_python_semantic/src/types/infer/tests.rs b/crates/ty_python_semantic/src/types/infer/tests.rs index 6eef61de10ef9..ec4a0d46c8e0f 100644 --- a/crates/ty_python_semantic/src/types/infer/tests.rs +++ b/crates/ty_python_semantic/src/types/infer/tests.rs @@ -5,44 +5,14 @@ use crate::place::{ConsideredDefinitions, Place, global_symbol}; use crate::semantic_index::definition::Definition; use crate::semantic_index::scope::FileScopeId; use crate::semantic_index::{global_scope, place_table, semantic_index, use_def_map}; -use crate::types::function::FunctionType; -use crate::types::{BoundMethodType, KnownClass, KnownInstanceType, UnionType, check_types}; +use crate::types::{KnownClass, KnownInstanceType, UnionType, check_types}; use ruff_db::diagnostic::Diagnostic; use ruff_db::files::{File, system_path_to_file}; use ruff_db::system::DbWithWritableSystem as _; use ruff_db::testing::{assert_function_query_was_not_run, assert_function_query_was_run}; -use salsa::Database; use super::*; -fn __() { - let _ = &Type::member_lookup_with_policy; - let _ = &Type::class_member_with_policy; - let _ = &FunctionType::infer_return_type; - let _ = &BoundMethodType::infer_return_type; - let _ = &ClassLiteral::implicit_attribute_inner; - let _ = &infer_expression_type_impl; - let _ = &infer_expression_types_impl; - let _ = &infer_definition_types; - let _ = &infer_scope_types; - let _ = &infer_unpack_types; -} -/// These queries refer to a value ​​from the previous cycle to ensure convergence. -/// Therefore, even when convergence is apparent, they will cycle at least once. -/// TODO: Is it possible to use the salsa API to get the value from the previous cycle (without doing anything if called for the first time)? -const QUERIES_USE_PREVIOUS_CYCLE_VALUE: [&str; 10] = [ - "Type < 'db >::member_lookup_with_policy_", - "Type < 'db >::class_member_with_policy_", - "FunctionType < 'db >::infer_return_type_", - "BoundMethodType < 'db >::infer_return_type_", - "ClassLiteral < 'db >::implicit_attribute_inner_", - "infer_expression_type_impl", - "infer_expression_types_impl", - "infer_definition_types", - "infer_scope_types", - "infer_unpack_types", -]; - #[track_caller] fn get_symbol<'db>( db: &'db TestDb, @@ -303,12 +273,6 @@ fn unbound_symbol_no_reachability_constraint_check() { .iter() .filter_map(|event| { if let salsa::EventKind::WillIterateCycle { database_key, .. } = event.kind { - if QUERIES_USE_PREVIOUS_CYCLE_VALUE.contains( - &db.ingredient_debug_name(database_key.ingredient_index()) - .as_ref(), - ) { - return None; - } Some(format!("{database_key:?}")) } else { None @@ -502,6 +466,7 @@ fn dependency_implicit_instance_attribute() -> anyhow::Result<()> { assert_eq!(attr_ty.display(&db).to_string(), "Unknown | str | None"); db.take_salsa_events() }; + assert_function_query_was_not_run( &db, infer_expression_types_impl, @@ -595,6 +560,7 @@ fn dependency_own_instance_member() -> anyhow::Result<()> { assert_eq!(attr_ty.display(&db).to_string(), "Unknown | str | None"); db.take_salsa_events() }; + assert_function_query_was_not_run( &db, infer_expression_types_impl, @@ -693,6 +659,7 @@ fn dependency_implicit_class_member() -> anyhow::Result<()> { assert_eq!(attr_ty.display(&db).to_string(), "Unknown | str"); db.take_salsa_events() }; + assert_function_query_was_not_run( &db, infer_expression_types_impl, diff --git a/crates/ty_python_semantic/src/types/type_ordering.rs b/crates/ty_python_semantic/src/types/type_ordering.rs index e2c4b3fb22e16..3a0f1cd251113 100644 --- a/crates/ty_python_semantic/src/types/type_ordering.rs +++ b/crates/ty_python_semantic/src/types/type_ordering.rs @@ -253,7 +253,7 @@ pub(super) fn union_or_intersection_elements_ordering<'db>( } /// Determine a canonical order for two instances of [`DynamicType`]. -fn dynamic_elements_ordering<'db>(left: DynamicType<'db>, right: DynamicType<'db>) -> Ordering { +fn dynamic_elements_ordering(left: DynamicType, right: DynamicType) -> Ordering { match (left, right) { (DynamicType::Any, _) => Ordering::Less, (_, DynamicType::Any) => Ordering::Greater, @@ -276,7 +276,9 @@ fn dynamic_elements_ordering<'db>(left: DynamicType<'db>, right: DynamicType<'db (DynamicType::TodoTypeAlias, _) => Ordering::Less, (_, DynamicType::TodoTypeAlias) => Ordering::Greater, - (DynamicType::Divergent(left), DynamicType::Divergent(right)) => left.cmp(&right), + (DynamicType::Divergent(left), DynamicType::Divergent(right)) => { + left.scope.cmp(&right.scope) + } (DynamicType::Divergent(_), _) => Ordering::Less, (_, DynamicType::Divergent(_)) => Ordering::Greater, } diff --git a/crates/ty_python_semantic/src/types/unpacker.rs b/crates/ty_python_semantic/src/types/unpacker.rs index 3a778b2dfa14b..b448053295f82 100644 --- a/crates/ty_python_semantic/src/types/unpacker.rs +++ b/crates/ty_python_semantic/src/types/unpacker.rs @@ -8,13 +8,9 @@ use ruff_python_ast::{self as ast, AnyNodeRef}; use crate::Db; use crate::semantic_index::ast_ids::node_key::ExpressionNodeKey; use crate::semantic_index::scope::ScopeId; -use crate::types::infer::{InferExpression, infer_expression_types_impl, infer_unpack_types}; use crate::types::tuple::{ResizeTupleError, Tuple, TupleLength, TupleSpec, TupleUnpacker}; -use crate::types::{ - DivergenceKind, DivergentType, RecursiveTypeNormalizedVisitor, Type, TypeCheckDiagnostics, - TypeContext, UnionType, -}; -use crate::unpack::{Unpack, UnpackKind, UnpackValue}; +use crate::types::{Type, TypeCheckDiagnostics, TypeContext, infer_expression_types}; +use crate::unpack::{UnpackKind, UnpackValue}; use super::context::InferContext; use super::diagnostic::INVALID_ASSIGNMENT; @@ -52,16 +48,9 @@ impl<'db, 'ast> Unpacker<'db, 'ast> { "Unpacking target must be a list or tuple expression" ); - let input = InferExpression::new(self.db(), value.expression(), TypeContext::default()); - let inference = infer_expression_types_impl(self.db(), input); - let value_type = if let Some(cycle_recovery) = inference.cycle_recovery() { - let visitor = RecursiveTypeNormalizedVisitor::new(cycle_recovery); - inference - .expression_type(value.expression().node_ref(self.db(), self.module())) - .recursive_type_normalized(self.db(), &visitor) - } else { - inference.expression_type(value.expression().node_ref(self.db(), self.module())) - }; + let value_type = + infer_expression_types(self.db(), value.expression(), TypeContext::default()) + .expression_type(value.expression().node_ref(self.db(), self.module())); let value_type = match value.kind() { UnpackKind::Assign => { @@ -184,25 +173,12 @@ impl<'db, 'ast> Unpacker<'db, 'ast> { } } - pub(crate) fn finish(mut self, unpack: Unpack<'db>) -> UnpackResult<'db> { - let db = self.db(); + pub(crate) fn finish(mut self) -> UnpackResult<'db> { self.targets.shrink_to_fit(); - let div = Type::divergent(DivergentType::new( - db, - DivergenceKind::InferUnpackTypes(unpack), - )); - let previous_cycle_value = infer_unpack_types(db, unpack); - let visitor = RecursiveTypeNormalizedVisitor::new(div); - for (expr, ty) in &mut self.targets { - let previous_ty = previous_cycle_value.expression_type(*expr); - *ty = UnionType::from_elements(db, [*ty, previous_ty]) - .recursive_type_normalized(db, &visitor); - } - UnpackResult { diagnostics: self.context.finish(), targets: self.targets, - cycle_recovery: None, + cycle_fallback_type: None, } } } @@ -215,7 +191,7 @@ pub(crate) struct UnpackResult<'db> { /// The fallback type for missing expressions. /// /// This is used only when constructing a cycle-recovery `UnpackResult`. - cycle_recovery: Option>, + cycle_fallback_type: Option>, } impl<'db> UnpackResult<'db> { @@ -241,7 +217,7 @@ impl<'db> UnpackResult<'db> { self.targets .get(&expr.into()) .copied() - .or(self.cycle_recovery) + .or(self.cycle_fallback_type) } /// Returns the diagnostics in this unpacking assignment. @@ -249,11 +225,11 @@ impl<'db> UnpackResult<'db> { &self.diagnostics } - pub(crate) fn cycle_initial(cycle_recovery: Type<'db>) -> Self { + pub(crate) fn cycle_initial(cycle_fallback_type: Type<'db>) -> Self { Self { targets: FxHashMap::default(), diagnostics: TypeCheckDiagnostics::default(), - cycle_recovery: Some(cycle_recovery), + cycle_fallback_type: Some(cycle_fallback_type), } } } diff --git a/crates/ty_python_semantic/src/unpack.rs b/crates/ty_python_semantic/src/unpack.rs index 9da45ea3e1dbf..cb07f2570a725 100644 --- a/crates/ty_python_semantic/src/unpack.rs +++ b/crates/ty_python_semantic/src/unpack.rs @@ -27,7 +27,6 @@ use crate::semantic_index::scope::{FileScopeId, ScopeId}; /// * a field of a type that is a return type of a cross-module query /// * an argument of a cross-module query #[salsa::tracked(debug, heap_size=ruff_memory_usage::heap_size)] -#[derive(PartialOrd, Ord)] pub(crate) struct Unpack<'db> { pub(crate) file: File, diff --git a/crates/ty_python_semantic/tests/corpus.rs b/crates/ty_python_semantic/tests/corpus.rs index d91da4b91d9de..83ad7ae1ffdae 100644 --- a/crates/ty_python_semantic/tests/corpus.rs +++ b/crates/ty_python_semantic/tests/corpus.rs @@ -169,6 +169,9 @@ fn run_corpus_tests(pattern: &str) -> anyhow::Result<()> { /// Whether or not the .py/.pyi version of this file is expected to fail #[rustfmt::skip] const KNOWN_FAILURES: &[(&str, bool, bool)] = &[ + // Fails with too-many-cycle-iterations due to a self-referential + // type alias, see https://github.com/astral-sh/ty/issues/256 + ("crates/ruff_linter/resources/test/fixtures/pyflakes/F401_34.py", true, true), ]; #[salsa::db] From 01321642409c4326c0c65c382f10a26faf2d0f19 Mon Sep 17 00:00:00 2001 From: Shunsuke Shibayama Date: Thu, 25 Sep 2025 00:46:17 +0900 Subject: [PATCH 088/105] Update types.rs --- crates/ty_python_semantic/src/types.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/crates/ty_python_semantic/src/types.rs b/crates/ty_python_semantic/src/types.rs index 8ef6fc2a67961..b68d96fdce52b 100644 --- a/crates/ty_python_semantic/src/types.rs +++ b/crates/ty_python_semantic/src/types.rs @@ -10075,7 +10075,7 @@ impl<'db> KnownBoundMethodType<'db> { property.recursive_type_normalized(db, visitor), ) } - KnownBoundMethodType::StrStartswith(_) => self, + KnownBoundMethodType::StrStartswith(_) | KnownBoundMethodType::PathOpen => self, } } From c58434db9791f8c375cdc29495d04a0e2467f733 Mon Sep 17 00:00:00 2001 From: Shunsuke Shibayama Date: Thu, 25 Sep 2025 01:51:21 +0900 Subject: [PATCH 089/105] use `any_over_type` in `has_divergent_type` --- crates/ty_python_semantic/src/types.rs | 128 +----------------- crates/ty_python_semantic/src/types/class.rs | 26 +--- .../ty_python_semantic/src/types/generics.rs | 24 +--- .../ty_python_semantic/src/types/instance.rs | 17 +-- .../src/types/signatures.rs | 40 +----- .../src/types/subclass_of.rs | 19 +-- crates/ty_python_semantic/src/types/tuple.rs | 13 +- 7 files changed, 21 insertions(+), 246 deletions(-) diff --git a/crates/ty_python_semantic/src/types.rs b/crates/ty_python_semantic/src/types.rs index b68d96fdce52b..18b760b4a2c33 100644 --- a/crates/ty_python_semantic/src/types.rs +++ b/crates/ty_python_semantic/src/types.rs @@ -67,6 +67,7 @@ use crate::types::signatures::{ParameterForm, walk_signature}; use crate::types::tuple::TupleSpec; pub(crate) use crate::types::typed_dict::{TypedDictParams, TypedDictType, walk_typed_dict_type}; use crate::types::variance::{TypeVarVariance, VarianceInferable}; +use crate::types::visitor::any_over_type; use crate::unpack::EvaluationMode; pub use crate::util::diagnostics::add_inferred_python_version_hint_to_diagnostic; use crate::{Db, FxOrderSet, Module, Program}; @@ -277,10 +278,6 @@ impl<'db> RecursiveTypeNormalizedVisitor<'db> { } } -/// A [`CycleDetector`] that is used in `has_divergent_type` methods. -pub(crate) type HasDivergentTypeVisitor<'db> = CycleDetector, bool>; -pub(crate) struct HasDivergentType; - /// How a generic type has been specialized. /// /// This matters only if there is at least one invariant type parameter. @@ -661,19 +658,6 @@ impl<'db> PropertyInstanceType<'db> { getter_equivalence.and(db, setter_equivalence) } - - fn has_divergent_type_impl( - self, - db: &'db dyn Db, - div: Type<'db>, - visitor: &HasDivergentTypeVisitor<'db>, - ) -> bool { - self.setter(db) - .is_some_and(|setter| setter.has_divergent_type_impl(db, div, visitor)) - || self - .getter(db) - .is_some_and(|getter| getter.has_divergent_type_impl(db, div, visitor)) - } } bitflags! { @@ -6772,79 +6756,7 @@ impl<'db> Type<'db> { } pub(super) fn has_divergent_type(self, db: &'db dyn Db, div: Type<'db>) -> bool { - let visitor = HasDivergentTypeVisitor::new(false); - self.has_divergent_type_impl(db, div, &visitor) - } - - fn has_divergent_type_impl( - self, - db: &'db dyn Db, - div: Type<'db>, - visitor: &HasDivergentTypeVisitor<'db>, - ) -> bool { - match self { - Type::Dynamic(DynamicType::Divergent(_)) => self == div, - Type::Union(union) => { - visitor.visit(self, || union.has_divergent_type_impl(db, div, visitor)) - } - Type::Intersection(intersection) => visitor.visit(self, || { - intersection.has_divergent_type_impl(db, div, visitor) - }), - Type::GenericAlias(alias) => visitor.visit(self, || { - alias - .specialization(db) - .has_divergent_type_impl(db, div, visitor) - }), - Type::NominalInstance(instance) => visitor.visit(self, || { - instance.class(db).has_divergent_type_impl(db, div, visitor) - }), - Type::Callable(callable) => { - visitor.visit(self, || callable.has_divergent_type_impl(db, div, visitor)) - } - Type::ProtocolInstance(protocol) => { - visitor.visit(self, || protocol.has_divergent_type_impl(db, div, visitor)) - } - Type::PropertyInstance(property) => { - visitor.visit(self, || property.has_divergent_type_impl(db, div, visitor)) - } - Type::TypeIs(type_is) => visitor.visit(self, || { - type_is - .return_type(db) - .has_divergent_type_impl(db, div, visitor) - }), - Type::SubclassOf(subclass_of) => visitor.visit(self, || { - subclass_of.has_divergent_type_impl(db, div, visitor) - }), - Type::TypedDict(typed_dict) => visitor.visit(self, || { - typed_dict - .defining_class() - .has_divergent_type_impl(db, div, visitor) - }), - Type::Never - | Type::AlwaysTruthy - | Type::AlwaysFalsy - | Type::WrapperDescriptor(_) - | Type::DataclassDecorator(_) - | Type::DataclassTransformer(_) - | Type::ModuleLiteral(_) - | Type::ClassLiteral(_) - | Type::IntLiteral(_) - | Type::BooleanLiteral(_) - | Type::LiteralString - | Type::StringLiteral(_) - | Type::BytesLiteral(_) - | Type::EnumLiteral(_) - | Type::BoundSuper(_) - | Type::SpecialForm(_) - | Type::KnownInstance(_) - | Type::NonInferableTypeVar(_) - | Type::TypeVar(_) - | Type::FunctionLiteral(_) - | Type::KnownBoundMethod(_) - | Type::BoundMethod(_) - | Type::Dynamic(_) - | Type::TypeAlias(_) => false, - } + any_over_type(db, self, &|ty| ty == div, false) } } @@ -9864,16 +9776,6 @@ impl<'db> CallableType<'db> { .is_equivalent_to_impl(db, other.signatures(db), visitor) }) } - - fn has_divergent_type_impl( - self, - db: &'db dyn Db, - div: Type<'db>, - visitor: &HasDivergentTypeVisitor<'db>, - ) -> bool { - self.signatures(db) - .has_divergent_type_impl(db, div, visitor) - } } /// Represents a specific instance of a bound method type for a builtin class. @@ -10953,17 +10855,6 @@ impl<'db> UnionType<'db> { ConstraintSet::from(sorted_self == other.normalized(db)) } - - fn has_divergent_type_impl( - self, - db: &'db dyn Db, - div: Type<'db>, - visitor: &HasDivergentTypeVisitor<'db>, - ) -> bool { - self.elements(db) - .iter() - .any(|ty| ty.has_divergent_type_impl(db, div, visitor)) - } } #[salsa::interned(debug, heap_size=IntersectionType::heap_size)] @@ -11202,21 +11093,6 @@ impl<'db> IntersectionType<'db> { ruff_memory_usage::order_set_heap_size(positive) + ruff_memory_usage::order_set_heap_size(negative) } - - fn has_divergent_type_impl( - self, - db: &'db dyn Db, - div: Type<'db>, - visitor: &HasDivergentTypeVisitor<'db>, - ) -> bool { - self.positive(db) - .iter() - .any(|ty| ty.has_divergent_type_impl(db, div, visitor)) - || self - .negative(db) - .iter() - .any(|ty| ty.has_divergent_type_impl(db, div, visitor)) - } } /// # Ordering diff --git a/crates/ty_python_semantic/src/types/class.rs b/crates/ty_python_semantic/src/types/class.rs index 7d7ea5ce4cd37..bc107019e2ae1 100644 --- a/crates/ty_python_semantic/src/types/class.rs +++ b/crates/ty_python_semantic/src/types/class.rs @@ -26,12 +26,12 @@ use crate::types::tuple::{TupleSpec, TupleType}; use crate::types::typed_dict::typed_dict_params_from_class_def; use crate::types::{ ApplyTypeMappingVisitor, Binding, BoundSuperError, BoundSuperType, CallableType, - DataclassParams, DeprecatedInstance, FindLegacyTypeVarsVisitor, HasDivergentTypeVisitor, - HasRelationToVisitor, IsEquivalentVisitor, KnownInstanceType, ManualPEP695TypeAliasType, - MaterializationKind, NormalizedVisitor, PropertyInstanceType, RecursiveTypeNormalizedVisitor, - StringLiteralType, TypeAliasType, TypeContext, TypeMapping, TypeRelation, - TypeVarBoundOrConstraints, TypeVarInstance, TypeVarKind, TypedDictParams, UnionBuilder, - VarianceInferable, declaration_type, determine_upper_bound, infer_definition_types, + DataclassParams, DeprecatedInstance, FindLegacyTypeVarsVisitor, HasRelationToVisitor, + IsEquivalentVisitor, KnownInstanceType, ManualPEP695TypeAliasType, MaterializationKind, + NormalizedVisitor, PropertyInstanceType, RecursiveTypeNormalizedVisitor, StringLiteralType, + TypeAliasType, TypeContext, TypeMapping, TypeRelation, TypeVarBoundOrConstraints, + TypeVarInstance, TypeVarKind, TypedDictParams, UnionBuilder, VarianceInferable, + declaration_type, determine_upper_bound, infer_definition_types, }; use crate::{ Db, FxIndexMap, FxOrderSet, Program, @@ -1223,20 +1223,6 @@ impl<'db> ClassType<'db> { } } - pub(super) fn has_divergent_type_impl( - self, - db: &'db dyn Db, - div: Type<'db>, - visitor: &HasDivergentTypeVisitor<'db>, - ) -> bool { - match self { - ClassType::NonGeneric(_) => false, - ClassType::Generic(generic) => generic - .specialization(db) - .has_divergent_type_impl(db, div, visitor), - } - } - pub(super) fn is_protocol(self, db: &'db dyn Db) -> bool { self.class_literal(db).0.is_protocol(db) } diff --git a/crates/ty_python_semantic/src/types/generics.rs b/crates/ty_python_semantic/src/types/generics.rs index 4a972c208d7aa..918ac2c758030 100644 --- a/crates/ty_python_semantic/src/types/generics.rs +++ b/crates/ty_python_semantic/src/types/generics.rs @@ -16,11 +16,11 @@ use crate::types::instance::{Protocol, ProtocolInstanceType}; use crate::types::signatures::{Parameter, Parameters, Signature}; use crate::types::tuple::{TupleSpec, TupleType, walk_tuple_type}; use crate::types::{ - ApplyTypeMappingVisitor, BoundTypeVarInstance, FindLegacyTypeVarsVisitor, - HasDivergentTypeVisitor, HasRelationToVisitor, IsEquivalentVisitor, KnownClass, - KnownInstanceType, MaterializationKind, NormalizedVisitor, RecursiveTypeNormalizedVisitor, - Type, TypeMapping, TypeRelation, TypeVarBoundOrConstraints, TypeVarInstance, TypeVarKind, - TypeVarVariance, UnionType, binding_type, declaration_type, infer_definition_types, + ApplyTypeMappingVisitor, BoundTypeVarInstance, FindLegacyTypeVarsVisitor, HasRelationToVisitor, + IsEquivalentVisitor, KnownClass, KnownInstanceType, MaterializationKind, NormalizedVisitor, + RecursiveTypeNormalizedVisitor, Type, TypeMapping, TypeRelation, TypeVarBoundOrConstraints, + TypeVarInstance, TypeVarKind, TypeVarVariance, UnionType, binding_type, declaration_type, + infer_definition_types, }; use crate::{Db, FxOrderSet}; @@ -988,20 +988,6 @@ impl<'db> Specialization<'db> { // A tuple's specialization will include all of its element types, so we don't need to also // look in `self.tuple`. } - - pub(crate) fn has_divergent_type_impl( - self, - db: &'db dyn Db, - div: Type<'db>, - visitor: &HasDivergentTypeVisitor<'db>, - ) -> bool { - self.types(db) - .iter() - .any(|ty| ty.has_divergent_type_impl(db, div, visitor)) - || self - .tuple_inner(db) - .is_some_and(|tuple| tuple.has_divergent_type_impl(db, div, visitor)) - } } /// A mapping between type variables and types. diff --git a/crates/ty_python_semantic/src/types/instance.rs b/crates/ty_python_semantic/src/types/instance.rs index fdfdd6e272978..e634c1eac174f 100644 --- a/crates/ty_python_semantic/src/types/instance.rs +++ b/crates/ty_python_semantic/src/types/instance.rs @@ -14,9 +14,8 @@ use crate::types::protocol_class::walk_protocol_interface; use crate::types::tuple::{TupleSpec, TupleType}; use crate::types::{ ApplyTypeMappingVisitor, ClassBase, ClassLiteral, FindLegacyTypeVarsVisitor, - HasDivergentTypeVisitor, HasRelationToVisitor, IsDisjointVisitor, IsEquivalentVisitor, - NormalizedVisitor, RecursiveTypeNormalizedVisitor, TypeMapping, TypeRelation, - VarianceInferable, + HasRelationToVisitor, IsDisjointVisitor, IsEquivalentVisitor, NormalizedVisitor, + RecursiveTypeNormalizedVisitor, TypeMapping, TypeRelation, VarianceInferable, }; use crate::{Db, FxOrderSet}; @@ -750,18 +749,6 @@ impl<'db> ProtocolInstanceType<'db> { pub(super) fn interface(self, db: &'db dyn Db) -> ProtocolInterface<'db> { self.inner.interface(db) } - - pub(super) fn has_divergent_type_impl( - self, - db: &'db dyn Db, - div: Type<'db>, - visitor: &HasDivergentTypeVisitor<'db>, - ) -> bool { - self.inner - .interface(db) - .members(db) - .any(|member| member.ty().has_divergent_type_impl(db, div, visitor)) - } } impl<'db> VarianceInferable<'db> for ProtocolInstanceType<'db> { diff --git a/crates/ty_python_semantic/src/types/signatures.rs b/crates/ty_python_semantic/src/types/signatures.rs index 49bbc7a41d6af..711ec9b23f3b9 100644 --- a/crates/ty_python_semantic/src/types/signatures.rs +++ b/crates/ty_python_semantic/src/types/signatures.rs @@ -21,9 +21,8 @@ use crate::types::constraints::{ConstraintSet, IteratorConstraintsExtension}; use crate::types::generics::{GenericContext, walk_generic_context}; use crate::types::{ ApplyTypeMappingVisitor, BindingContext, BoundTypeVarInstance, FindLegacyTypeVarsVisitor, - HasDivergentTypeVisitor, HasRelationToVisitor, IsEquivalentVisitor, KnownClass, - MaterializationKind, NormalizedVisitor, RecursiveTypeNormalizedVisitor, TypeMapping, - TypeRelation, VarianceInferable, todo_type, + HasRelationToVisitor, IsEquivalentVisitor, KnownClass, MaterializationKind, NormalizedVisitor, + RecursiveTypeNormalizedVisitor, TypeMapping, TypeRelation, VarianceInferable, todo_type, }; use crate::{Db, FxOrderSet}; use ruff_python_ast::{self as ast, name::Name}; @@ -238,17 +237,6 @@ impl<'db> CallableSignature<'db> { } } } - - pub(super) fn has_divergent_type_impl( - &self, - db: &'db dyn Db, - div: Type<'db>, - visitor: &HasDivergentTypeVisitor<'db>, - ) -> bool { - self.overloads - .iter() - .any(|signature| signature.has_divergent_type_impl(db, div, visitor)) - } } impl<'a, 'db> IntoIterator for &'a CallableSignature<'db> { @@ -1074,17 +1062,6 @@ impl<'db> Signature<'db> { pub(crate) fn with_definition(self, definition: Option>) -> Self { Self { definition, ..self } } - - fn has_divergent_type_impl( - &self, - db: &'db dyn Db, - div: Type<'db>, - visitor: &HasDivergentTypeVisitor<'db>, - ) -> bool { - self.return_ty - .is_some_and(|return_ty| return_ty.has_divergent_type_impl(db, div, visitor)) - || self.parameters.has_divergent_type_impl(db, div, visitor) - } } impl<'db> VarianceInferable<'db> for &Signature<'db> { @@ -1431,19 +1408,6 @@ impl<'db> Parameters<'db> { .enumerate() .rfind(|(_, parameter)| parameter.is_keyword_variadic()) } - - fn has_divergent_type_impl( - &self, - db: &'db dyn Db, - div: Type<'db>, - visitor: &HasDivergentTypeVisitor<'db>, - ) -> bool { - self.iter().any(|parameter| { - parameter - .annotated_type() - .is_some_and(|ty| ty.has_divergent_type_impl(db, div, visitor)) - }) - } } impl<'db, 'a> IntoIterator for &'a Parameters<'db> { diff --git a/crates/ty_python_semantic/src/types/subclass_of.rs b/crates/ty_python_semantic/src/types/subclass_of.rs index d7aa20c4f16ee..29e906ef0ff2d 100644 --- a/crates/ty_python_semantic/src/types/subclass_of.rs +++ b/crates/ty_python_semantic/src/types/subclass_of.rs @@ -4,9 +4,9 @@ use crate::types::constraints::ConstraintSet; use crate::types::variance::VarianceInferable; use crate::types::{ ApplyTypeMappingVisitor, BoundTypeVarInstance, ClassType, DynamicType, - FindLegacyTypeVarsVisitor, HasDivergentTypeVisitor, HasRelationToVisitor, IsDisjointVisitor, - KnownClass, MaterializationKind, MemberLookupPolicy, NormalizedVisitor, - RecursiveTypeNormalizedVisitor, SpecialFormType, Type, TypeMapping, TypeRelation, + FindLegacyTypeVarsVisitor, HasRelationToVisitor, IsDisjointVisitor, KnownClass, + MaterializationKind, MemberLookupPolicy, NormalizedVisitor, RecursiveTypeNormalizedVisitor, + SpecialFormType, Type, TypeMapping, TypeRelation, }; use crate::{Db, FxOrderSet}; @@ -203,19 +203,6 @@ impl<'db> SubclassOfType<'db> { .into_class() .is_some_and(|class| class.class_literal(db).0.is_typed_dict(db)) } - - pub(super) fn has_divergent_type_impl( - self, - db: &'db dyn Db, - div: Type<'db>, - visitor: &HasDivergentTypeVisitor<'db>, - ) -> bool { - match self.subclass_of { - SubclassOfInner::Dynamic(d @ DynamicType::Divergent(_)) => Type::Dynamic(d) == div, - SubclassOfInner::Dynamic(_) => false, - SubclassOfInner::Class(class) => class.has_divergent_type_impl(db, div, visitor), - } - } } impl<'db> VarianceInferable<'db> for SubclassOfType<'db> { diff --git a/crates/ty_python_semantic/src/types/tuple.rs b/crates/ty_python_semantic/src/types/tuple.rs index dc4595365b847..48094c005e9de 100644 --- a/crates/ty_python_semantic/src/types/tuple.rs +++ b/crates/ty_python_semantic/src/types/tuple.rs @@ -22,6 +22,7 @@ use std::hash::Hash; use itertools::{Either, EitherOrBoth, Itertools}; use crate::semantic_index::definition::Definition; +use crate::types::Truthiness; use crate::types::class::{ClassType, KnownClass}; use crate::types::constraints::{ConstraintSet, IteratorConstraintsExtension}; use crate::types::{ @@ -29,7 +30,6 @@ use crate::types::{ IsDisjointVisitor, IsEquivalentVisitor, NormalizedVisitor, RecursiveTypeNormalizedVisitor, Type, TypeMapping, TypeRelation, UnionBuilder, UnionType, }; -use crate::types::{HasDivergentTypeVisitor, Truthiness}; use crate::util::subscript::{Nth, OutOfBoundsError, PyIndex, PySlice, StepSizeZeroError}; use crate::{Db, FxOrderSet, Program}; @@ -285,17 +285,6 @@ impl<'db> TupleType<'db> { pub(crate) fn is_single_valued(self, db: &'db dyn Db) -> bool { self.tuple(db).is_single_valued(db) } - - pub(super) fn has_divergent_type_impl( - self, - db: &'db dyn Db, - div: Type<'db>, - visitor: &HasDivergentTypeVisitor<'db>, - ) -> bool { - self.tuple(db) - .all_elements() - .any(|ty| ty.has_divergent_type_impl(db, div, visitor)) - } } fn to_class_type_cycle_recover<'db>( From 112599f96f4ec671d416f302b7a5b5e7eb3785de Mon Sep 17 00:00:00 2001 From: Shunsuke Shibayama Date: Thu, 25 Sep 2025 02:19:50 +0900 Subject: [PATCH 090/105] revert unnecessary changes for the purpose of this PR --- crates/ty_python_semantic/src/types.rs | 10 +++++++--- 1 file changed, 7 insertions(+), 3 deletions(-) diff --git a/crates/ty_python_semantic/src/types.rs b/crates/ty_python_semantic/src/types.rs index 18b760b4a2c33..b7f57642d0317 100644 --- a/crates/ty_python_semantic/src/types.rs +++ b/crates/ty_python_semantic/src/types.rs @@ -8133,9 +8133,13 @@ impl<'db> BoundTypeVarInstance<'db> { match self.typevar(db).explicit_variance(db) { Some(explicit_variance) => explicit_variance.compose(polarity), None => match self.binding_context(db) { - BindingContext::Definition(definition) => binding_type(db, definition) - .with_polarity(polarity) - .variance_of(db, self), + BindingContext::Definition(definition) => { + let type_inference = infer_definition_types(db, definition); + type_inference + .binding_type(definition) + .with_polarity(polarity) + .variance_of(db, self) + } BindingContext::Synthetic => TypeVarVariance::Invariant, }, } From 0e63890b19e11d59d6193500319ad3a55ede603b Mon Sep 17 00:00:00 2001 From: Shunsuke Shibayama Date: Thu, 25 Sep 2025 02:28:59 +0900 Subject: [PATCH 091/105] use `CycleRecoveryAction::Fallback` even within `infer_scope_types` --- crates/ty_python_semantic/src/types/infer.rs | 21 +++++++++++++++++--- 1 file changed, 18 insertions(+), 3 deletions(-) diff --git a/crates/ty_python_semantic/src/types/infer.rs b/crates/ty_python_semantic/src/types/infer.rs index c5e62aa02b6fc..ac646ec4c65b3 100644 --- a/crates/ty_python_semantic/src/types/infer.rs +++ b/crates/ty_python_semantic/src/types/infer.rs @@ -86,10 +86,14 @@ pub(crate) fn infer_scope_types<'db>(db: &'db dyn Db, scope: ScopeId<'db>) -> Sc fn scope_cycle_recover<'db>( _db: &'db dyn Db, _value: &ScopeInference<'db>, - _count: u32, - _scope: ScopeId<'db>, + count: u32, + scope: ScopeId<'db>, ) -> salsa::CycleRecoveryAction> { - salsa::CycleRecoveryAction::Iterate + if count == ITERATIONS_BEFORE_FALLBACK { + salsa::CycleRecoveryAction::Fallback(ScopeInference::cycle_fallback(scope)) + } else { + salsa::CycleRecoveryAction::Iterate + } } fn scope_cycle_initial<'db>(_db: &'db dyn Db, scope: ScopeId<'db>) -> ScopeInference<'db> { @@ -595,6 +599,17 @@ impl<'db> ScopeInference<'db> { } } + fn cycle_fallback(scope: ScopeId<'db>) -> Self { + Self { + scope, + extra: Some(Box::new(ScopeInferenceExtra { + cycle_recovery: Some(CycleRecovery::Divergent(scope)), + ..ScopeInferenceExtra::default() + })), + expressions: FxHashMap::default(), + } + } + pub(crate) fn diagnostics(&self) -> Option<&TypeCheckDiagnostics> { self.extra.as_deref().map(|extra| &extra.diagnostics) } From 92f3e4cbeaa60d3e7d41040fcb2e5d34ed5da1ed Mon Sep 17 00:00:00 2001 From: Shunsuke Shibayama Date: Thu, 25 Sep 2025 14:11:36 +0900 Subject: [PATCH 092/105] lower the limit for `MAX_UNION_LITERALS` --- crates/ty_python_semantic/resources/corpus/divergent.py | 7 ++----- crates/ty_python_semantic/src/types/builder.rs | 4 +++- 2 files changed, 5 insertions(+), 6 deletions(-) diff --git a/crates/ty_python_semantic/resources/corpus/divergent.py b/crates/ty_python_semantic/resources/corpus/divergent.py index 6632a7d834bdf..1ef6726cf2563 100644 --- a/crates/ty_python_semantic/resources/corpus/divergent.py +++ b/crates/ty_python_semantic/resources/corpus/divergent.py @@ -57,9 +57,6 @@ def unwrap(value): else: raise TypeError() -# TODO: If this is commented out, that is, if `infer_scope_types` is called before `infer_return_type`, it will panic. -reveal_type(unwrap(Foo())) - def descent(x: int, y: int): if x > y: y, x = descent(y, x) @@ -71,5 +68,5 @@ def descent(x: int, y: int): else: return descent(x-1, y-1) -# TODO: If this is commented out, that is, if `infer_scope_types` is called before `infer_return_type`, it will panic. -reveal_type(descent(5, 3)) +def count_set_bits(n): + return 1 + count_set_bits(n & n - 1) if n else 0 diff --git a/crates/ty_python_semantic/src/types/builder.rs b/crates/ty_python_semantic/src/types/builder.rs index cc1c35790a90b..a8e24bc57f08d 100644 --- a/crates/ty_python_semantic/src/types/builder.rs +++ b/crates/ty_python_semantic/src/types/builder.rs @@ -206,7 +206,9 @@ enum ReduceResult<'db> { // // For now (until we solve https://github.com/astral-sh/ty/issues/957), keep this number // below 200, which is the salsa fixpoint iteration limit. -const MAX_UNION_LITERALS: usize = 199; +// +// If we can handle fixed-point iterations properly, we should be able to reset the limit to 199. +const MAX_UNION_LITERALS: usize = 188; pub(crate) struct UnionBuilder<'db> { elements: Vec>, From 706eff81951949f0c9a4363fe7242b8b5e7cca98 Mon Sep 17 00:00:00 2001 From: Shunsuke Shibayama Date: Thu, 25 Sep 2025 14:43:03 +0900 Subject: [PATCH 093/105] Update ty_walltime.rs --- crates/ruff_benchmark/benches/ty_walltime.rs | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/crates/ruff_benchmark/benches/ty_walltime.rs b/crates/ruff_benchmark/benches/ty_walltime.rs index 154f75511e81a..bf0ce69d5f870 100644 --- a/crates/ruff_benchmark/benches/ty_walltime.rs +++ b/crates/ruff_benchmark/benches/ty_walltime.rs @@ -199,7 +199,8 @@ static SYMPY: std::sync::LazyLock> = std::sync::LazyLock::new max_dep_date: "2025-06-17", python_version: PythonVersion::PY312, }, - 75000, + // TODO: With better decorator support, `__slots__` support, etc., it should be possible to reduce the number of errors considerably. + 70000, ) }); From 505dcc81ac1961ed18a25533be29de39c24720be Mon Sep 17 00:00:00 2001 From: Shunsuke Shibayama Date: Mon, 15 Dec 2025 19:53:06 +0900 Subject: [PATCH 094/105] revert unnecessary changes --- crates/ruff_benchmark/benches/ty_walltime.rs | 28 +++++++++---------- crates/ty_python_semantic/src/types.rs | 4 +-- crates/ty_python_semantic/src/types/infer.rs | 15 ---------- .../src/types/infer/builder.rs | 4 +-- .../types/infer/builder/type_expression.rs | 2 +- .../src/types/protocol_class.rs | 2 +- 6 files changed, 19 insertions(+), 36 deletions(-) diff --git a/crates/ruff_benchmark/benches/ty_walltime.rs b/crates/ruff_benchmark/benches/ty_walltime.rs index e92d5295e114a..f3ef925e19360 100644 --- a/crates/ruff_benchmark/benches/ty_walltime.rs +++ b/crates/ruff_benchmark/benches/ty_walltime.rs @@ -184,21 +184,19 @@ static PYDANTIC: Benchmark = Benchmark::new( 7000, ); -static SYMPY: std::sync::LazyLock> = std::sync::LazyLock::new(|| { - Benchmark::new( - RealWorldProject { - name: "sympy", - repository: "https://github.com/sympy/sympy", - commit: "22fc107a94eaabc4f6eb31470b39db65abb7a394", - paths: &["sympy"], - dependencies: &["mpmath"], - max_dep_date: "2025-06-17", - python_version: PythonVersion::PY312, - }, - // TODO: With better decorator support, `__slots__` support, etc., it should be possible to reduce the number of errors considerably. - 70000, - ) -}); +static SYMPY: Benchmark = Benchmark::new( + RealWorldProject { + name: "sympy", + repository: "https://github.com/sympy/sympy", + commit: "22fc107a94eaabc4f6eb31470b39db65abb7a394", + paths: &["sympy"], + dependencies: &["mpmath"], + max_dep_date: "2025-06-17", + python_version: PythonVersion::PY312, + }, + // TODO: With better decorator support, `__slots__` support, etc., it should be possible to reduce the number of errors considerably. + 70000, +); static TANJUN: Benchmark = Benchmark::new( RealWorldProject { diff --git a/crates/ty_python_semantic/src/types.rs b/crates/ty_python_semantic/src/types.rs index b0f0b93e95824..6dfe53611b440 100644 --- a/crates/ty_python_semantic/src/types.rs +++ b/crates/ty_python_semantic/src/types.rs @@ -2,8 +2,6 @@ use compact_str::{CompactString, ToCompactString}; use infer::nearest_enclosing_class; use itertools::{Either, Itertools}; use ruff_diagnostics::{Edit, Fix}; -use ruff_python_ast::name::Name; -use smallvec::{SmallVec, smallvec}; use std::borrow::Cow; use std::time::Duration; @@ -17,7 +15,9 @@ use ruff_db::diagnostic::{Annotation, Diagnostic, Span, SubDiagnostic, SubDiagno use ruff_db::files::File; use ruff_db::parsed::parsed_module; use ruff_python_ast as ast; +use ruff_python_ast::name::Name; use ruff_text_size::{Ranged, TextRange}; +use smallvec::{SmallVec, smallvec}; use type_ordering::union_or_intersection_elements_ordering; diff --git a/crates/ty_python_semantic/src/types/infer.rs b/crates/ty_python_semantic/src/types/infer.rs index 2f83063ed9380..e565b84622d4f 100644 --- a/crates/ty_python_semantic/src/types/infer.rs +++ b/crates/ty_python_semantic/src/types/infer.rs @@ -193,7 +193,6 @@ pub(crate) fn infer_expression_types<'db>( infer_expression_types_impl(db, InferExpression::new(db, expression, tcx)) } -/// When using types ​​in [`ExpressionInference`], you must use [`ExpressionInference::cycle_recovery`]. #[salsa::tracked(returns(ref), cycle_fn=expression_cycle_recover, cycle_initial=expression_cycle_initial, heap_size=ruff_memory_usage::heap_size)] pub(super) fn infer_expression_types_impl<'db>( db: &'db dyn Db, @@ -824,13 +823,6 @@ impl<'db> DefinitionInference<'db> { self.extra.as_ref().and_then(|extra| extra.cycle_recovery) } - /// When using `DefinitionInference` during type inference, - /// use this method to get the cycle recovery type so that divergent types are propagated. - #[allow(unused)] - pub(super) fn cycle_recovery(&self) -> Option> { - self.fallback_type() - } - pub(crate) fn undecorated_type(&self) -> Option> { self.extra.as_ref().and_then(|extra| extra.undecorated_type) } @@ -933,13 +925,6 @@ impl<'db> ExpressionInference<'db> { self.extra.as_ref().and_then(|extra| extra.cycle_recovery) } - /// When using `ExpressionInference` during type inference, - /// use this method to get the cycle recovery type so that divergent types are propagated. - #[allow(unused)] - pub(super) fn cycle_recovery(&self) -> Option> { - self.fallback_type() - } - /// Returns true if all places in this expression are definitely bound. pub(crate) fn all_places_definitely_bound(&self) -> bool { self.extra diff --git a/crates/ty_python_semantic/src/types/infer/builder.rs b/crates/ty_python_semantic/src/types/infer/builder.rs index 18c273895e38b..815a546bf3299 100644 --- a/crates/ty_python_semantic/src/types/infer/builder.rs +++ b/crates/ty_python_semantic/src/types/infer/builder.rs @@ -130,7 +130,7 @@ enum IntersectionOn { } #[derive(Debug, Clone, Copy, Eq, PartialEq)] -pub(super) struct TypeAndRange<'db> { +struct TypeAndRange<'db> { ty: Type<'db>, range: TextRange, } @@ -8432,7 +8432,7 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> { // Special handling for `TypedDict` method calls if let ast::Expr::Attribute(ast::ExprAttribute { value, attr, .. }) = func.as_ref() { - let value_type = self.expression_type(value.as_ref()); + let value_type = self.expression_type(value); if let Type::TypedDict(typed_dict_ty) = value_type { if matches!(attr.id.as_str(), "pop" | "setdefault") && !arguments.args.is_empty() { // Validate the key argument for `TypedDict` methods diff --git a/crates/ty_python_semantic/src/types/infer/builder/type_expression.rs b/crates/ty_python_semantic/src/types/infer/builder/type_expression.rs index 823dfc0d73fe8..d4c5701f1d2a9 100644 --- a/crates/ty_python_semantic/src/types/infer/builder/type_expression.rs +++ b/crates/ty_python_semantic/src/types/infer/builder/type_expression.rs @@ -580,7 +580,7 @@ impl<'db> TypeInferenceBuilder<'db, '_> { // we do not store types for sub-expressions. Re-infer the type here. builder.infer_expression(value, TypeContext::default()) } else { - builder.expression_type(value.as_ref()) + builder.expression_type(value) }; value_ty == Type::SpecialForm(SpecialFormType::Unpack) diff --git a/crates/ty_python_semantic/src/types/protocol_class.rs b/crates/ty_python_semantic/src/types/protocol_class.rs index 130055cf3b488..5d3719b8d9865 100644 --- a/crates/ty_python_semantic/src/types/protocol_class.rs +++ b/crates/ty_python_semantic/src/types/protocol_class.rs @@ -682,7 +682,7 @@ impl<'a, 'db> ProtocolMember<'a, 'db> { self.qualifiers } - pub(super) fn ty(&self) -> Type<'db> { + fn ty(&self) -> Type<'db> { match &self.kind { ProtocolMemberKind::Method(callable) => Type::Callable(*callable), ProtocolMemberKind::Property(property) => Type::PropertyInstance(*property), From 9fd7a6f9d930a8293b3ffe906e810715365f030f Mon Sep 17 00:00:00 2001 From: Shunsuke Shibayama Date: Thu, 18 Dec 2025 00:41:32 +0900 Subject: [PATCH 095/105] fix too-many-cycle panics --- .../resources/mdtest/function/return_type.md | 8 +++--- .../ty_python_semantic/src/types/builder.rs | 27 +++++++++++++++++-- crates/ty_python_semantic/src/types/infer.rs | 19 +++++++++++++ 3 files changed, 48 insertions(+), 6 deletions(-) diff --git a/crates/ty_python_semantic/resources/mdtest/function/return_type.md b/crates/ty_python_semantic/resources/mdtest/function/return_type.md index 3c6fbc839e4a8..3c22efb826997 100644 --- a/crates/ty_python_semantic/resources/mdtest/function/return_type.md +++ b/crates/ty_python_semantic/resources/mdtest/function/return_type.md @@ -429,7 +429,7 @@ def list_int(x: int): return list1(x) # TODO: should be `list[int]` -reveal_type(list_int(1)) # revealed: list[Divergent] | list[int] +reveal_type(list_int(1)) # revealed: list[Divergent] | list[int] | list[Divergent] def tuple_obj(cond: bool): if cond: @@ -454,7 +454,7 @@ def nested_scope(): return nested_scope() return inner() -reveal_type(nested_scope()) # revealed: Never +reveal_type(nested_scope()) # revealed: Divergent def eager_nested_scope(): class A: @@ -469,7 +469,7 @@ class C: return D() class D(C): - # TODO invalid override error + # error: [invalid-method-override] def flip(self) -> "C": return C() @@ -582,7 +582,7 @@ reveal_type(C().h(1)) # revealed: Literal[1] reveal_type(D().h(1)) # revealed: Literal[2] | Unknown reveal_type(C().h(True)) # revealed: Literal[True] reveal_type(D().h(True)) # revealed: Literal[2] | Unknown -reveal_type(C().i(1)) # revealed: list[Literal[1]] +reveal_type(C().i(1)) # revealed: list[int] # TODO: better type for list elements reveal_type(D().i(1)) # revealed: list[Unknown | int] | list[Unknown] diff --git a/crates/ty_python_semantic/src/types/builder.rs b/crates/ty_python_semantic/src/types/builder.rs index ae2c21759b21f..cc76d643d5c63 100644 --- a/crates/ty_python_semantic/src/types/builder.rs +++ b/crates/ty_python_semantic/src/types/builder.rs @@ -649,8 +649,31 @@ impl<'db> UnionBuilder<'db> { types.sort_unstable_by(|l, r| union_or_intersection_elements_ordering(self.db, l, r)); } match types.len() { - 0 => None, - 1 => Some(types[0]), + 0 => { + if self.recursively_defined.is_yes() { + // See the comment below for why this is necessary. + Some(Type::Union(UnionType::new( + self.db, + Box::from([Type::Never]), + self.recursively_defined, + ))) + } else { + None + } + } + 1 => { + if self.recursively_defined.is_yes() { + // We need to mark this type with a "recursively-defined" marker, so build it as a single-element recursively-defined union type. + // This will only happen very early in the fixed-point iteration, and a single-element union type should never appear in the final converged type. + Some(Type::Union(UnionType::new( + self.db, + Box::from([types[0]]), + self.recursively_defined, + ))) + } else { + Some(types[0]) + } + } _ => Some(Type::Union(UnionType::new( self.db, types.into_boxed_slice(), diff --git a/crates/ty_python_semantic/src/types/infer.rs b/crates/ty_python_semantic/src/types/infer.rs index e565b84622d4f..178d15df92d48 100644 --- a/crates/ty_python_semantic/src/types/infer.rs +++ b/crates/ty_python_semantic/src/types/infer.rs @@ -48,6 +48,7 @@ use crate::semantic_index::definition::Definition; use crate::semantic_index::expression::Expression; use crate::semantic_index::scope::ScopeId; use crate::semantic_index::{SemanticIndex, semantic_index, use_def_map}; +use crate::types::builder::RecursivelyDefined; use crate::types::diagnostic::TypeCheckDiagnostics; use crate::types::function::FunctionType; use crate::types::generics::Specialization; @@ -573,6 +574,23 @@ impl<'db> ScopeInference<'db> { *ty = ty.cycle_normalized(db, previous_ty, cycle); } + if let Some(extra) = &mut self.extra { + for (i, return_ty) in extra.return_types.iter_mut().enumerate() { + match previous_inference.extra.as_ref() { + Some(previous_extra) => { + if let Some(previous_return_ty) = previous_extra.return_types.get(i) { + *return_ty = return_ty.cycle_normalized(db, *previous_return_ty, cycle); + } else { + *return_ty = return_ty.recursive_type_normalized(db, cycle); + } + } + None => { + *return_ty = return_ty.recursive_type_normalized(db, cycle); + } + } + } + } + self } @@ -627,6 +645,7 @@ impl<'db> ScopeInference<'db> { let mut union = UnionBuilder::new(db); if let Some(cycle_recovery) = self.fallback_type() { + union = union.recursively_defined(RecursivelyDefined::Yes); union = union.add(cycle_recovery); } From d02c834fbc4cc2531319b7104884701e40bd8369 Mon Sep 17 00:00:00 2001 From: Shunsuke Shibayama Date: Thu, 18 Dec 2025 00:53:59 +0900 Subject: [PATCH 096/105] refactor --- crates/ty_python_semantic/src/types.rs | 42 +++++++------------ .../ty_python_semantic/src/types/function.rs | 2 +- crates/ty_python_semantic/src/types/infer.rs | 15 +------ 3 files changed, 17 insertions(+), 42 deletions(-) diff --git a/crates/ty_python_semantic/src/types.rs b/crates/ty_python_semantic/src/types.rs index 6dfe53611b440..5d84f095f7497 100644 --- a/crates/ty_python_semantic/src/types.rs +++ b/crates/ty_python_semantic/src/types.rs @@ -10704,7 +10704,6 @@ pub struct UnionTypeInstance<'db> { /// ``. For `Union[int, str]`, this field is `None`, as we infer /// the elements as type expressions. Use `value_expression_types` to get the /// corresponding value expression types. - #[expect(clippy::ref_option)] #[returns(ref)] _value_expr_types: Option]>>, @@ -12158,15 +12157,22 @@ impl<'db> BoundMethodType<'db> { } /// Infers this method scope's types and returns the inferred return type. - #[salsa::tracked(cycle_fn=return_type_cycle_recover, cycle_initial=return_type_cycle_initial, heap_size=get_size2::heap_size)] pub(crate) fn infer_return_type(self, db: &'db dyn Db) -> Type<'db> { - let scope = self - .function(db) - .literal(db) - .last_definition(db) - .body_scope(db); - let inference = infer_scope_types(db, scope); - inference.infer_return_type(db, scope, Type::BoundMethod(self)) + let inferred_return_type = self.function(db).infer_return_type(db); + // If the method is not final and the typing is implicit, the inferred return type should be unioned with `Unknown`. + // If any method in a base class does not have an annotated return type, `base_return_type` will include `Unknown`. + // On the other hand, if the return types of all methods in the base classes are annotated, there is no need to include `Unknown`. + if !self.is_final(db) { + UnionType::from_elements( + db, + [ + inferred_return_type, + self.base_return_type(db).unwrap_or(Type::unknown()), + ], + ) + } else { + inferred_return_type + } } #[salsa::tracked(heap_size=ruff_memory_usage::heap_size)] @@ -12305,24 +12311,6 @@ impl<'db> BoundMethodType<'db> { } } -fn return_type_cycle_recover<'db>( - db: &'db dyn Db, - cycle: &salsa::Cycle, - previous_return_type: &Type<'db>, - return_type: Type<'db>, - _self: BoundMethodType<'db>, -) -> Type<'db> { - return_type.cycle_normalized(db, *previous_return_type, cycle) -} - -fn return_type_cycle_initial<'db>( - _db: &'db dyn Db, - id: salsa::Id, - _method: BoundMethodType<'db>, -) -> Type<'db> { - Type::divergent(id) -} - #[derive(Debug, Copy, Clone, PartialEq, Eq, PartialOrd, Ord, Hash, get_size2::GetSize)] pub enum CallableTypeKind { /// Represents regular callable objects. diff --git a/crates/ty_python_semantic/src/types/function.rs b/crates/ty_python_semantic/src/types/function.rs index b05aa95b82eb6..3652f9c1488b2 100644 --- a/crates/ty_python_semantic/src/types/function.rs +++ b/crates/ty_python_semantic/src/types/function.rs @@ -1203,7 +1203,7 @@ impl<'db> FunctionType<'db> { pub(crate) fn infer_return_type(self, db: &'db dyn Db) -> Type<'db> { let scope = self.literal(db).last_definition(db).body_scope(db); let inference = infer_scope_types(db, scope); - inference.infer_return_type(db, scope, Type::FunctionLiteral(self)) + inference.infer_return_type(db, scope) } } diff --git a/crates/ty_python_semantic/src/types/infer.rs b/crates/ty_python_semantic/src/types/infer.rs index 178d15df92d48..530465bc2566d 100644 --- a/crates/ty_python_semantic/src/types/infer.rs +++ b/crates/ty_python_semantic/src/types/infer.rs @@ -631,12 +631,7 @@ impl<'db> ScopeInference<'db> { /// or `None` if the region is not a function body. /// In the case of methods, the return type of the superclass method is further unioned. /// If there is no superclass method and this method is not `final`, it will be unioned with `Unknown`. - pub(crate) fn infer_return_type( - &self, - db: &'db dyn Db, - scope: ScopeId<'db>, - callee_ty: Type<'db>, - ) -> Type<'db> { + pub(crate) fn infer_return_type(&self, db: &'db dyn Db, scope: ScopeId<'db>) -> Type<'db> { // TODO: coroutine function type inference // TODO: generator function type inference if scope.is_coroutine_function(db) || scope.is_generator_function(db) { @@ -661,14 +656,6 @@ impl<'db> ScopeInference<'db> { if use_def.can_implicitly_return_none(db) { union = union.add(Type::none(db)); } - if let Type::BoundMethod(method_ty) = callee_ty { - // If the method is not final and the typing is implicit, the inferred return type should be unioned with `Unknown`. - // If any method in a base class does not have an annotated return type, `base_return_type` will include `Unknown`. - // On the other hand, if the return types of all methods in the base classes are annotated, there is no need to include `Unknown`. - if !method_ty.is_final(db) { - union = union.add(method_ty.base_return_type(db).unwrap_or(Type::unknown())); - } - } union.build() } From c5bd2cde5557bf5b2264d05292e66cc40cea0d22 Mon Sep 17 00:00:00 2001 From: Shunsuke Shibayama Date: Thu, 18 Dec 2025 02:03:54 +0900 Subject: [PATCH 097/105] increase max_diagnostics --- crates/ruff_benchmark/benches/ty.rs | 2 +- crates/ruff_benchmark/benches/ty_walltime.rs | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/crates/ruff_benchmark/benches/ty.rs b/crates/ruff_benchmark/benches/ty.rs index 9ae6e9c40bd6a..57f6eab5a8dad 100644 --- a/crates/ruff_benchmark/benches/ty.rs +++ b/crates/ruff_benchmark/benches/ty.rs @@ -667,7 +667,7 @@ fn attrs(criterion: &mut Criterion) { max_dep_date: "2025-06-17", python_version: PythonVersion::PY313, }, - 120, + 136, ); bench_project(&benchmark, criterion); diff --git a/crates/ruff_benchmark/benches/ty_walltime.rs b/crates/ruff_benchmark/benches/ty_walltime.rs index 70afcaeb9f5ca..12abe9b11fcba 100644 --- a/crates/ruff_benchmark/benches/ty_walltime.rs +++ b/crates/ruff_benchmark/benches/ty_walltime.rs @@ -163,7 +163,7 @@ static PANDAS: Benchmark = Benchmark::new( max_dep_date: "2025-06-17", python_version: PythonVersion::PY312, }, - 4000, + 5430, ); static PYDANTIC: Benchmark = Benchmark::new( From 1d1afaa517836c1974ce4d780dcee47a7c258387 Mon Sep 17 00:00:00 2001 From: Shunsuke Shibayama Date: Thu, 18 Dec 2025 20:14:45 +0900 Subject: [PATCH 098/105] fix `Type::cycle_normalized` --- .../resources/corpus/divergent.py | 21 +++++++++++++++++++ .../resources/mdtest/function/return_type.md | 2 +- crates/ty_python_semantic/src/types.rs | 3 ++- 3 files changed, 24 insertions(+), 2 deletions(-) diff --git a/crates/ty_python_semantic/resources/corpus/divergent.py b/crates/ty_python_semantic/resources/corpus/divergent.py index 1ef6726cf2563..2bacad7f72b9e 100644 --- a/crates/ty_python_semantic/resources/corpus/divergent.py +++ b/crates/ty_python_semantic/resources/corpus/divergent.py @@ -70,3 +70,24 @@ def descent(x: int, y: int): def count_set_bits(n): return 1 + count_set_bits(n & n - 1) if n else 0 + +class Literal: + def __invert__(self): + return Literal() + +class OR: + def __invert__(self): + return AND() + +class AND: + def __invert__(self): + return OR() + +def to_NNF(cond): + if cond: + return ~to_NNF(cond) + if cond: + return OR() + if cond: + return AND() + return Literal() diff --git a/crates/ty_python_semantic/resources/mdtest/function/return_type.md b/crates/ty_python_semantic/resources/mdtest/function/return_type.md index 47aa1097234f4..dd071fa5efc11 100644 --- a/crates/ty_python_semantic/resources/mdtest/function/return_type.md +++ b/crates/ty_python_semantic/resources/mdtest/function/return_type.md @@ -429,7 +429,7 @@ def list_int(x: int): return list1(x) # TODO: should be `list[int]` -reveal_type(list_int(1)) # revealed: list[Divergent] | list[int] | list[Divergent] +reveal_type(list_int(1)) # revealed: list[Divergent] | list[Divergent] | list[int] def tuple_obj(cond: bool): if cond: diff --git a/crates/ty_python_semantic/src/types.rs b/crates/ty_python_semantic/src/types.rs index 0d96ff5be2aee..6f7848daf29e4 100644 --- a/crates/ty_python_semantic/src/types.rs +++ b/crates/ty_python_semantic/src/types.rs @@ -966,7 +966,8 @@ impl<'db> Type<'db> { if has_divergent_type_in_cycle(previous) && !has_divergent_type_in_cycle(self) { self } else { - UnionType::from_elements_cycle_recovery(db, [self, previous]) + // The current type is unioned to the previous type. Unioning in the reverse order can cause the fixed-point iterations to converge slowly or even fail. + UnionType::from_elements_cycle_recovery(db, [previous, self]) } } } From 7093727fb505cd284163f9afd8c2c5b47b1f11ed Mon Sep 17 00:00:00 2001 From: Shunsuke Shibayama Date: Thu, 18 Dec 2025 20:34:40 +0900 Subject: [PATCH 099/105] Update ty_walltime.rs --- crates/ruff_benchmark/benches/ty_walltime.rs | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/crates/ruff_benchmark/benches/ty_walltime.rs b/crates/ruff_benchmark/benches/ty_walltime.rs index 12abe9b11fcba..3759c0febe0da 100644 --- a/crates/ruff_benchmark/benches/ty_walltime.rs +++ b/crates/ruff_benchmark/benches/ty_walltime.rs @@ -143,7 +143,7 @@ static FREQTRADE: Benchmark = Benchmark::new( max_dep_date: "2025-06-17", python_version: PythonVersion::PY312, }, - 600, + 626, ); static PANDAS: Benchmark = Benchmark::new( @@ -195,7 +195,7 @@ static SYMPY: Benchmark = Benchmark::new( python_version: PythonVersion::PY312, }, // TODO: With better decorator support, `__slots__` support, etc., it should be possible to reduce the number of errors considerably. - 70000, + 58000, ); static TANJUN: Benchmark = Benchmark::new( From af8b75b3fd3a4c6cd59e0c07348d8561171abc87 Mon Sep 17 00:00:00 2001 From: Shunsuke Shibayama Date: Fri, 19 Dec 2025 02:15:24 +0900 Subject: [PATCH 100/105] add `IntersectionBuilder::intersection_hashes` --- .../ty_python_semantic/src/types/builder.rs | 59 +++++++++++++++++-- 1 file changed, 55 insertions(+), 4 deletions(-) diff --git a/crates/ty_python_semantic/src/types/builder.rs b/crates/ty_python_semantic/src/types/builder.rs index 4aa0f9f4a4e19..fc48dcf05bfef 100644 --- a/crates/ty_python_semantic/src/types/builder.rs +++ b/crates/ty_python_semantic/src/types/builder.rs @@ -44,7 +44,7 @@ use crate::types::{ TypeVarBoundOrConstraints, UnionType, }; use crate::{Db, FxOrderSet}; -use rustc_hash::FxHashSet; +use rustc_hash::{FxHashSet, FxHasher}; use smallvec::SmallVec; #[derive(Debug, Clone, Copy, PartialEq, Eq)] @@ -691,16 +691,22 @@ pub(crate) struct IntersectionBuilder<'db> { // but if a union is added to the intersection, we'll distribute ourselves over that union and // create a union of intersections. intersections: Vec>, + /// Stores hash values ​​of `intersections` to prevent adding identical `InnerIntersectionBuilder`s. + intersection_hashes: FxHashSet, order_elements: bool, db: &'db dyn Db, } impl<'db> IntersectionBuilder<'db> { pub(crate) fn new(db: &'db dyn Db) -> Self { + let inner = InnerIntersectionBuilder::default(); + let mut intersection_hashes = FxHashSet::default(); + intersection_hashes.insert(inner.hash_value()); Self { db, order_elements: false, - intersections: vec![InnerIntersectionBuilder::default()], + intersections: vec![inner], + intersection_hashes, } } @@ -709,6 +715,22 @@ impl<'db> IntersectionBuilder<'db> { db, order_elements: false, intersections: vec![], + intersection_hashes: FxHashSet::default(), + } + } + + fn update_hashes(&mut self) { + self.intersection_hashes.clear(); + for intersection in &self.intersections { + self.intersection_hashes.insert(intersection.hash_value()); + } + } + + fn extend(&mut self, other: IntersectionBuilder<'db>) { + for intersection in other.intersections { + if self.intersection_hashes.insert(intersection.hash_value()) { + self.intersections.push(intersection); + } } } @@ -728,6 +750,7 @@ impl<'db> IntersectionBuilder<'db> { for inner in &mut self.intersections { inner.positive.insert(ty); } + self.update_hashes(); return self; } seen_aliases.push(ty); @@ -748,7 +771,7 @@ impl<'db> IntersectionBuilder<'db> { .iter() .map(|elem| self.clone().add_positive_impl(*elem, seen_aliases)) .fold(IntersectionBuilder::empty(self.db), |mut builder, sub| { - builder.intersections.extend(sub.intersections); + builder.extend(sub); builder }) } @@ -800,6 +823,7 @@ impl<'db> IntersectionBuilder<'db> { for inner in &mut self.intersections { inner.add_positive(self.db, ty); } + self.update_hashes(); self } } @@ -809,6 +833,7 @@ impl<'db> IntersectionBuilder<'db> { for inner in &mut self.intersections { inner.add_positive(self.db, ty); } + self.update_hashes(); self } } @@ -838,6 +863,7 @@ impl<'db> IntersectionBuilder<'db> { for inner in &mut self.intersections { inner.negative.insert(ty); } + self.update_hashes(); return self; } seen_aliases.push(ty); @@ -879,7 +905,7 @@ impl<'db> IntersectionBuilder<'db> { positive_side.chain(negative_side).fold( IntersectionBuilder::empty(self.db), |mut builder, sub| { - builder.intersections.extend(sub.intersections); + builder.extend(sub); builder }, ) @@ -905,6 +931,7 @@ impl<'db> IntersectionBuilder<'db> { for inner in &mut self.intersections { inner.add_negative(self.db, ty); } + self.update_hashes(); self } } @@ -1318,6 +1345,30 @@ impl<'db> InnerIntersectionBuilder<'db> { } } } + + /// An element-order-independent hash value, unrelated to the hash value of the actual `IntersectionType`. + fn hash_value(&self) -> u64 { + use std::hash::{Hash, Hasher}; + + let mut sum = 0u64; + for ty in &self.positive { + sum = sum.wrapping_add({ + let mut hasher = FxHasher::default(); + ty.hash(&mut hasher); + hasher.finish() + }); + } + for ty in &self.negative { + sum = sum.wrapping_add({ + let mut hasher = FxHasher::default(); + ty.hash(&mut hasher); + // We add a constant to the negative types to differentiate them from positive types. + 1.hash(&mut hasher); + hasher.finish() + }); + } + sum + } } #[cfg(test)] From bbad56bfd33dcc8828ec56feb38496cd30360c73 Mon Sep 17 00:00:00 2001 From: Shunsuke Shibayama Date: Fri, 19 Dec 2025 09:43:43 +0900 Subject: [PATCH 101/105] Revert "add `IntersectionBuilder::intersection_hashes`" This reverts commit af8b75b3fd3a4c6cd59e0c07348d8561171abc87. --- .../ty_python_semantic/src/types/builder.rs | 59 ++----------------- 1 file changed, 4 insertions(+), 55 deletions(-) diff --git a/crates/ty_python_semantic/src/types/builder.rs b/crates/ty_python_semantic/src/types/builder.rs index fc48dcf05bfef..4aa0f9f4a4e19 100644 --- a/crates/ty_python_semantic/src/types/builder.rs +++ b/crates/ty_python_semantic/src/types/builder.rs @@ -44,7 +44,7 @@ use crate::types::{ TypeVarBoundOrConstraints, UnionType, }; use crate::{Db, FxOrderSet}; -use rustc_hash::{FxHashSet, FxHasher}; +use rustc_hash::FxHashSet; use smallvec::SmallVec; #[derive(Debug, Clone, Copy, PartialEq, Eq)] @@ -691,22 +691,16 @@ pub(crate) struct IntersectionBuilder<'db> { // but if a union is added to the intersection, we'll distribute ourselves over that union and // create a union of intersections. intersections: Vec>, - /// Stores hash values ​​of `intersections` to prevent adding identical `InnerIntersectionBuilder`s. - intersection_hashes: FxHashSet, order_elements: bool, db: &'db dyn Db, } impl<'db> IntersectionBuilder<'db> { pub(crate) fn new(db: &'db dyn Db) -> Self { - let inner = InnerIntersectionBuilder::default(); - let mut intersection_hashes = FxHashSet::default(); - intersection_hashes.insert(inner.hash_value()); Self { db, order_elements: false, - intersections: vec![inner], - intersection_hashes, + intersections: vec![InnerIntersectionBuilder::default()], } } @@ -715,22 +709,6 @@ impl<'db> IntersectionBuilder<'db> { db, order_elements: false, intersections: vec![], - intersection_hashes: FxHashSet::default(), - } - } - - fn update_hashes(&mut self) { - self.intersection_hashes.clear(); - for intersection in &self.intersections { - self.intersection_hashes.insert(intersection.hash_value()); - } - } - - fn extend(&mut self, other: IntersectionBuilder<'db>) { - for intersection in other.intersections { - if self.intersection_hashes.insert(intersection.hash_value()) { - self.intersections.push(intersection); - } } } @@ -750,7 +728,6 @@ impl<'db> IntersectionBuilder<'db> { for inner in &mut self.intersections { inner.positive.insert(ty); } - self.update_hashes(); return self; } seen_aliases.push(ty); @@ -771,7 +748,7 @@ impl<'db> IntersectionBuilder<'db> { .iter() .map(|elem| self.clone().add_positive_impl(*elem, seen_aliases)) .fold(IntersectionBuilder::empty(self.db), |mut builder, sub| { - builder.extend(sub); + builder.intersections.extend(sub.intersections); builder }) } @@ -823,7 +800,6 @@ impl<'db> IntersectionBuilder<'db> { for inner in &mut self.intersections { inner.add_positive(self.db, ty); } - self.update_hashes(); self } } @@ -833,7 +809,6 @@ impl<'db> IntersectionBuilder<'db> { for inner in &mut self.intersections { inner.add_positive(self.db, ty); } - self.update_hashes(); self } } @@ -863,7 +838,6 @@ impl<'db> IntersectionBuilder<'db> { for inner in &mut self.intersections { inner.negative.insert(ty); } - self.update_hashes(); return self; } seen_aliases.push(ty); @@ -905,7 +879,7 @@ impl<'db> IntersectionBuilder<'db> { positive_side.chain(negative_side).fold( IntersectionBuilder::empty(self.db), |mut builder, sub| { - builder.extend(sub); + builder.intersections.extend(sub.intersections); builder }, ) @@ -931,7 +905,6 @@ impl<'db> IntersectionBuilder<'db> { for inner in &mut self.intersections { inner.add_negative(self.db, ty); } - self.update_hashes(); self } } @@ -1345,30 +1318,6 @@ impl<'db> InnerIntersectionBuilder<'db> { } } } - - /// An element-order-independent hash value, unrelated to the hash value of the actual `IntersectionType`. - fn hash_value(&self) -> u64 { - use std::hash::{Hash, Hasher}; - - let mut sum = 0u64; - for ty in &self.positive { - sum = sum.wrapping_add({ - let mut hasher = FxHasher::default(); - ty.hash(&mut hasher); - hasher.finish() - }); - } - for ty in &self.negative { - sum = sum.wrapping_add({ - let mut hasher = FxHasher::default(); - ty.hash(&mut hasher); - // We add a constant to the negative types to differentiate them from positive types. - 1.hash(&mut hasher); - hasher.finish() - }); - } - sum - } } #[cfg(test)] From cc06d3e0d6d0eb34fb3207e3c0c2db15c856194f Mon Sep 17 00:00:00 2001 From: Shunsuke Shibayama Date: Fri, 19 Dec 2025 12:05:48 +0900 Subject: [PATCH 102/105] simply filter `InnerIntersectionBuilder`s using equality check --- crates/ty_python_semantic/src/types/builder.rs | 14 +++++++++++--- 1 file changed, 11 insertions(+), 3 deletions(-) diff --git a/crates/ty_python_semantic/src/types/builder.rs b/crates/ty_python_semantic/src/types/builder.rs index 4aa0f9f4a4e19..5bda8a1e5aed1 100644 --- a/crates/ty_python_semantic/src/types/builder.rs +++ b/crates/ty_python_semantic/src/types/builder.rs @@ -712,6 +712,14 @@ impl<'db> IntersectionBuilder<'db> { } } + fn extend(&mut self, sub: Self) { + for inner in sub.intersections { + if !self.intersections.contains(&inner) { + self.intersections.push(inner); + } + } + } + pub(crate) fn add_positive(self, ty: Type<'db>) -> Self { self.add_positive_impl(ty, &mut vec![]) } @@ -748,7 +756,7 @@ impl<'db> IntersectionBuilder<'db> { .iter() .map(|elem| self.clone().add_positive_impl(*elem, seen_aliases)) .fold(IntersectionBuilder::empty(self.db), |mut builder, sub| { - builder.intersections.extend(sub.intersections); + builder.extend(sub); builder }) } @@ -879,7 +887,7 @@ impl<'db> IntersectionBuilder<'db> { positive_side.chain(negative_side).fold( IntersectionBuilder::empty(self.db), |mut builder, sub| { - builder.intersections.extend(sub.intersections); + builder.extend(sub); builder }, ) @@ -939,7 +947,7 @@ impl<'db> IntersectionBuilder<'db> { } } -#[derive(Debug, Clone, Default)] +#[derive(Debug, Clone, Default, PartialEq, Eq)] struct InnerIntersectionBuilder<'db> { positive: FxOrderSet>, negative: FxOrderSet>, From 274bae84c863f4a2846b0eca0d37036fe4d07ec0 Mon Sep 17 00:00:00 2001 From: Shunsuke Shibayama Date: Sat, 20 Dec 2025 01:13:11 +0900 Subject: [PATCH 103/105] refactor --- crates/ty_python_semantic/src/types.rs | 4 ++-- crates/ty_python_semantic/src/types/infer.rs | 6 ++++-- 2 files changed, 6 insertions(+), 4 deletions(-) diff --git a/crates/ty_python_semantic/src/types.rs b/crates/ty_python_semantic/src/types.rs index 7f56b905e833c..2c96efa25b44b 100644 --- a/crates/ty_python_semantic/src/types.rs +++ b/crates/ty_python_semantic/src/types.rs @@ -12324,7 +12324,7 @@ impl<'db> BoundMethodType<'db> { Some(index.expect_single_definition(definition_scope.node(db).as_class()?)) } - pub(crate) fn is_final(self, db: &'db dyn Db) -> bool { + fn is_final(self, db: &'db dyn Db) -> bool { if self .function(db) .has_known_decorator(db, FunctionDecorators::FINAL) @@ -12342,7 +12342,7 @@ impl<'db> BoundMethodType<'db> { .any(|deco| deco == KnownFunction::Final) } - pub(super) fn base_return_type(self, db: &'db dyn Db) -> Option> { + fn base_return_type(self, db: &'db dyn Db) -> Option> { let class = binding_type(db, self.class_definition(db)?).to_class_type(db)?; let name = self.function(db).name(db); diff --git a/crates/ty_python_semantic/src/types/infer.rs b/crates/ty_python_semantic/src/types/infer.rs index 530465bc2566d..825abc6b50a96 100644 --- a/crates/ty_python_semantic/src/types/infer.rs +++ b/crates/ty_python_semantic/src/types/infer.rs @@ -639,9 +639,11 @@ impl<'db> ScopeInference<'db> { } let mut union = UnionBuilder::new(db); - if let Some(cycle_recovery) = self.fallback_type() { + // If this method is called early in the query cycle of `infer_scope_types`, `extra.return_types` will be empty. + // To properly propagate divergence, we must add `Divergent` to the union type. + if let Some(divergent) = self.fallback_type() { union = union.recursively_defined(RecursivelyDefined::Yes); - union = union.add(cycle_recovery); + union = union.add(divergent); } let Some(extra) = &self.extra else { From 0a80f9852225eff94ec2f9480c4a782615ea483a Mon Sep 17 00:00:00 2001 From: Shunsuke Shibayama Date: Tue, 23 Dec 2025 17:58:56 +0900 Subject: [PATCH 104/105] don't use return type context to specialize `__init__` type variables ref: https://github.com/astral-sh/ruff/pull/22068 --- .../resources/mdtest/assignment/annotations.md | 8 ++++++++ crates/ty_python_semantic/src/types.rs | 11 +++++++++++ crates/ty_python_semantic/src/types/call/bind.rs | 15 ++++++++++++--- 3 files changed, 31 insertions(+), 3 deletions(-) diff --git a/crates/ty_python_semantic/resources/mdtest/assignment/annotations.md b/crates/ty_python_semantic/resources/mdtest/assignment/annotations.md index d8595041fc828..4c79e5002718c 100644 --- a/crates/ty_python_semantic/resources/mdtest/assignment/annotations.md +++ b/crates/ty_python_semantic/resources/mdtest/assignment/annotations.md @@ -608,6 +608,14 @@ class X[T]: x1: X[int | None] = X() reveal_type(x1) # revealed: X[None] + +class Y[T]: + def __init__(self: Y[None]) -> None: ... + def pop(self) -> T: + raise NotImplementedError + +y1: Y[int | None] = Y() +reveal_type(y1) # revealed: Y[None] ``` ## Narrow generic unions diff --git a/crates/ty_python_semantic/src/types.rs b/crates/ty_python_semantic/src/types.rs index 5a20564779de2..28ad0ae11b5db 100644 --- a/crates/ty_python_semantic/src/types.rs +++ b/crates/ty_python_semantic/src/types.rs @@ -1390,6 +1390,13 @@ impl<'db> Type<'db> { matches!(self, Type::FunctionLiteral(..)) } + pub(crate) const fn as_bound_method(self) -> Option> { + match self { + Type::BoundMethod(bound_method_type) => Some(bound_method_type), + _ => None, + } + } + /// Detects types which are valid to appear inside a `Literal[…]` type annotation. pub(crate) fn is_literal_or_union_of_literals(&self, db: &'db dyn Db) -> bool { match self { @@ -12357,6 +12364,10 @@ impl<'db> BoundMethodType<'db> { .any(|deco| deco == KnownFunction::Final) } + fn is_init(self, db: &'db dyn Db) -> bool { + self.function(db).name(db) == "__init__" + } + fn base_return_type(self, db: &'db dyn Db) -> Option> { let class = binding_type(db, self.class_definition(db)?).to_class_type(db)?; let name = self.function(db).name(db); diff --git a/crates/ty_python_semantic/src/types/call/bind.rs b/crates/ty_python_semantic/src/types/call/bind.rs index 673d84c36d201..c2dcdd3627578 100644 --- a/crates/ty_python_semantic/src/types/call/bind.rs +++ b/crates/ty_python_semantic/src/types/call/bind.rs @@ -3155,9 +3155,18 @@ impl<'a, 'db> ArgumentTypeChecker<'a, 'db> { return None; } - // TODO: Ideally we would infer the annotated type _before_ the arguments if this call is part of an - // annotated assignment, to closer match the order of any unions written in the type annotation. - builder.infer(return_ty, call_expression_tcx).ok()?; + // For `__init__`, do not the use type context to widen the return type, + // as it can lead to argument assignability errors if the type variable + // is constrained by a narrower parameter type. + if self + .signature_type + .as_bound_method() + .is_none_or(|method| !method.is_init(self.db)) + { + // TODO: Ideally we would infer the annotated type _before_ the arguments if this call is part of an + // annotated assignment, to closer match the order of any unions written in the type annotation. + builder.infer(return_ty, call_expression_tcx).ok()?; + } // Otherwise, build the specialization again after inferring the complete type context. let specialization = builder From f5e5165910c9b52364badab313137e1871477cdb Mon Sep 17 00:00:00 2001 From: Shunsuke Shibayama Date: Wed, 31 Dec 2025 12:21:41 +0900 Subject: [PATCH 105/105] Update types.rs --- crates/ty_python_semantic/src/types.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/crates/ty_python_semantic/src/types.rs b/crates/ty_python_semantic/src/types.rs index e839e29e11773..a72086f60e580 100644 --- a/crates/ty_python_semantic/src/types.rs +++ b/crates/ty_python_semantic/src/types.rs @@ -12529,7 +12529,7 @@ impl<'db> BoundMethodType<'db> { .nth(1) .and_then(class_base::ClassBase::into_class)?; let base_member = base.class_member(db, name, MemberLookupPolicy::default()); - if let Place::Defined(Type::FunctionLiteral(base_func), _, _) = base_member.place { + if let Place::Defined(Type::FunctionLiteral(base_func), _, _, _) = base_member.place { if let [signature] = base_func.signature(db).overloads.as_slice() { let unspecialized_return_ty = signature.return_ty.unwrap_or_else(|| { let base_method_ty =