From ebf5e738c2b28bc8bdd3e1f8d700947f8b69ac64 Mon Sep 17 00:00:00 2001 From: jazzthief Date: Wed, 28 Dec 2022 17:37:00 +0100 Subject: [PATCH 01/20] Apply some runtime types, start fixing errors --- lib/sqlalchemy/sql/elements.py | 174 +++++++++++++++++++-------------- 1 file changed, 98 insertions(+), 76 deletions(-) diff --git a/lib/sqlalchemy/sql/elements.py b/lib/sqlalchemy/sql/elements.py index a416b6ac096..7fb2fd94794 100644 --- a/lib/sqlalchemy/sql/elements.py +++ b/lib/sqlalchemy/sql/elements.py @@ -4,7 +4,6 @@ # # This module is part of SQLAlchemy and is released under # the MIT License: https://www.opensource.org/licenses/mit-license.php -# mypy: allow-untyped-defs, allow-untyped-calls """Core SQL expression elements, including :class:`_expression.ClauseElement`, :class:`_expression.ColumnElement`, and derived classes. @@ -113,6 +112,20 @@ from ..engine.interfaces import CoreExecuteOptionsParameter from ..engine.interfaces import SchemaTranslateMapType from ..engine.result import Result + from re import Match + from .selectable import Select + from .sqltypes import NullType + from .sqltypes import ARRAY + from mypy_extensions import NoReturn + from ._py_util import cache_anon_map + from .annotation import AnnotatedBindParameter + from typing import Tuple + from .annotation import AnnotatedClauseList + from .annotation import AnnotatedBooleanClauseList + from .annotation import AnnotatedLabel + from sqlalchemy.util._py_collections import immutabledict + from ._py_util import prefix_anon_map + from .sqltypes import Integer _NUMERIC = Union[float, Decimal] _NUMBER = Union[float, int, Decimal] @@ -408,7 +421,11 @@ def _clone(self, **kw: Any) -> Self: c._is_clone_of = cc if cc is not None else self return c - def _negate_in_binary(self, negated_op, original_op): + def _negate_in_binary( + self: SelfClauseElement, + negated_op: Callable[..., Any], + original_op: Callable[..., Any] + ) -> SelfClauseElement: """a hook to allow the right side of a binary expression to respond to a negation of the binary expression. @@ -417,7 +434,7 @@ def _negate_in_binary(self, negated_op, original_op): """ return self - def _with_binary_element_type(self, type_): + def _with_binary_element_type(self, type_: Any) -> Any: """in the context of binary expression, convert the type of this object to the one given. @@ -427,7 +444,7 @@ def _with_binary_element_type(self, type_): return self @property - def _constructor(self): + def _constructor(self) -> Any: """return the 'constructor' for this ClauseElement. This is for the purposes for creating a new object of @@ -439,7 +456,7 @@ def _constructor(self): return self.__class__ @HasMemoized.memoized_attribute - def _cloned_set(self): + def _cloned_set(self) -> Any: """Return the set consisting all cloned ancestors of this ClauseElement. @@ -462,13 +479,13 @@ def _cloned_set(self): return s @property - def entity_namespace(self): + def entity_namespace(self) -> NoReturn: raise AttributeError( "This SQL expression has no entity namespace " "with which to filter from." ) - def __getstate__(self): + def __getstate__(self) -> Dict[str, Any]: d = self.__dict__.copy() d.pop("_is_clone_of", None) d.pop("_generate_cache_key", None) @@ -680,7 +697,7 @@ def _compile_w_cache( return compiled_sql, extracted_params, cache_hit - def __invert__(self): + def __invert__(self) -> Union[ColumnElement[bool], ClauseElement]: # undocumented element currently used by the ORM for # relationship.contains() if hasattr(self, "negation_clause"): @@ -693,10 +710,10 @@ def _negate(self) -> ClauseElement: assert isinstance(grouped, ColumnElement) return UnaryExpression(grouped, operator=operators.inv) - def __bool__(self): + def __bool__(self) -> None: raise TypeError("Boolean value of this clause is not defined") - def __repr__(self): + def __repr__(self) -> str: friendly = self.description if friendly is None: return object.__repr__(self) @@ -1858,7 +1875,7 @@ def _dedupe_anon_label_idx(self, idx: int) -> str: return self._dedupe_anon_tq_label_idx(idx) @property - def _proxy_key(self): + def _proxy_key(self) -> Union[None, quoted_name, str]: wce = self.wrapped_column_expression if not wce._is_text_clause: @@ -2008,7 +2025,12 @@ def __init__( else: self.type = type_ - def _with_value(self, value, maintain_key=False, required=NO_ARG): + def _with_value( + self, + value: int, + maintain_key: bool = False, + required: Union[bool, _NoArg] = NO_ARG + ) -> BindParameter: """Return a copy of this :class:`.BindParameter` with the given value set. """ @@ -2064,7 +2086,7 @@ def render_literal_execute(self) -> BindParameter[_T]: literal_execute=True, ) - def _negate_in_binary(self, negated_op, original_op): + def _negate_in_binary(self, negated_op: Callable, original_op: Callable) -> BindParameter: if self.expand_op is original_op: bind = self._clone() bind.expand_op = negated_op @@ -2072,7 +2094,7 @@ def _negate_in_binary(self, negated_op, original_op): else: return self - def _with_binary_element_type(self, type_): + def _with_binary_element_type(self, type_: Any) -> BindParameter: c = ClauseElement._clone(self) c.type = type_ return c @@ -2094,7 +2116,7 @@ def _clone(self, maintain_key: bool = False, **kw: Any) -> Self: ) return c - def _gen_cache_key(self, anon_map, bindparams): + def _gen_cache_key(self, anon_map: cache_anon_map, bindparams: Union[List[AnnotatedBindParameter], List[BindParameter]]) -> Tuple[str, type, Union[Tuple[type, Tuple[str, int]], Tuple[type]], Union[_truncated_label, quoted_name, str], bool]: _gen_cache_ok = self.__class__.__dict__.get("inherit_cache", False) if not _gen_cache_ok: @@ -2117,14 +2139,14 @@ def _gen_cache_key(self, anon_map, bindparams): self.literal_execute, ) - def _convert_to_unique(self): + def _convert_to_unique(self) -> None: if not self.unique: self.unique = True self.key = _anonymous_label.safe_construct( id(self), self._orig_key or "param", sanitize_key=True ) - def __getstate__(self): + def __getstate__(self) -> Dict[str, Any]: """execute a deferred value for serialization purposes.""" d = self.__dict__.copy() @@ -2142,7 +2164,7 @@ def __setstate__(self, state): ) self.__dict__.update(state) - def __repr__(self): + def __repr__(self) -> str: return "%s(%r, %r, type_=%r)" % ( self.__class__.__name__, self.key, @@ -2164,7 +2186,7 @@ class TypeClause(DQLDMLClauseElement): ("type", InternalTraversal.dp_type) ] - def __init__(self, type_): + def __init__(self, type_: Any) -> None: self.type = type_ @@ -2245,7 +2267,7 @@ def _is_star(self): def __init__(self, text: str): self._bindparams: Dict[str, BindParameter[Any]] = {} - def repl(m): + def repl(m: Match) -> str: self._bindparams[m.group(1)] = BindParameter(m.group(1)) return ":%s" % m.group(1) @@ -2539,7 +2561,7 @@ def comparator(self): # be using this method. return self.type.comparator_factory(self) # type: ignore - def self_group(self, against=None): + def self_group(self, against: Union[Callable, None, builtin_function_or_method] = None) -> TextClause: if against is operators.in_op: return Grouping(self) else: @@ -2738,7 +2760,7 @@ def append(self, clause): def _from_objects(self) -> List[FromClause]: return list(itertools.chain(*[c._from_objects for c in self.clauses])) - def self_group(self, against=None): + def self_group(self, against: Callable = None) -> Union[AnnotatedClauseList, Grouping]: if self.group and operators.is_precedent(self.operator, against): return Grouping(self) else: @@ -2761,7 +2783,7 @@ class OperatorExpression(ColumnElement[_T]): def is_comparison(self): return operators.is_comparison(self.operator) - def self_group(self, against=None): + def self_group(self, against: Union[Callable, builtin_function_or_method] = None) -> Any: if ( self.group and operators.is_precedent(self.operator, against) @@ -2923,7 +2945,7 @@ class BooleanClauseList(ExpressionClauseList[bool]): __visit_name__ = "expression_clauselist" inherit_cache = True - def __init__(self, *arg, **kw): + def __init__(self, *arg, **kw) -> None: raise NotImplementedError( "BooleanClauseList has a private constructor" ) @@ -3116,7 +3138,7 @@ def or_( def _select_iterable(self) -> _SelectIterable: return (self,) - def self_group(self, against=None): + def self_group(self, against: Union[Callable, None, builtin_function_or_method] = None) -> Union[AnnotatedBooleanClauseList, BooleanClauseList, Grouping]: if not self.clauses: return self else: @@ -3173,7 +3195,7 @@ def __init__( def _select_iterable(self) -> _SelectIterable: return (self,) - def _bind_param(self, operator, obj, type_=None, expanding=False): + def _bind_param(self, operator: Callable, obj: List[Tuple[int, str]], type_: Optional[Any] = None, expanding: bool = False) -> BindParameter: if expanding: return BindParameter( None, @@ -3199,7 +3221,7 @@ def _bind_param(self, operator, obj, type_=None, expanding=False): ] ) - def self_group(self, against=None): + def self_group(self, against: Optional[Callable] = None) -> Tuple: # Tuple is parenthesized by definition. return self @@ -3352,7 +3374,7 @@ def _from_objects(self) -> List[FromClause]: return self.clause._from_objects @property - def wrapped_column_expression(self): + def wrapped_column_expression(self) -> Union[BindParameter, ColumnClause, Column]: return self.clause @@ -3402,7 +3424,7 @@ def _from_objects(self) -> List[FromClause]: return self.clause._from_objects @HasMemoized.memoized_attribute - def typed_expression(self): + def typed_expression(self) -> Any: if isinstance(self.clause, BindParameter): bp = self.clause._clone() bp.type = self.type @@ -3411,10 +3433,10 @@ def typed_expression(self): return self.clause @property - def wrapped_column_expression(self): + def wrapped_column_expression(self) -> BinaryExpression: return self.clause - def self_group(self, against=None): + def self_group(self, against: Callable = None) -> TypeCoerce: grouped = self.clause.self_group(against=against) if grouped is not self.clause: return TypeCoerce(grouped, self.type) @@ -3618,7 +3640,7 @@ def _order_by_label_element(self) -> Optional[Label[Any]]: def _from_objects(self) -> List[FromClause]: return self.element._from_objects - def _negate(self): + def _negate(self) -> UnaryExpression: if self.type._type_affinity is type_api.BOOLEANTYPE._type_affinity: return UnaryExpression( self.self_group(against=operators.inv), @@ -3629,7 +3651,7 @@ def _negate(self): else: return ClauseElement._negate(self) - def self_group(self, against=None): + def self_group(self, against: Union[Callable, builtin_function_or_method] = None) -> Any: if self.operator and operators.is_precedent(self.operator, against): return Grouping(self) else: @@ -3702,7 +3724,7 @@ def reverse_operate(self, op, other, **kwargs): class AsBoolean(WrapsColumnExpression[bool], UnaryExpression[bool]): inherit_cache = True - def __init__(self, element, operator, negate): + def __init__(self, element: Any, operator: Callable, negate: Callable) -> None: self.element = element self.type = type_api.BOOLEANTYPE self.operator = operator @@ -3712,10 +3734,10 @@ def __init__(self, element, operator, negate): self._is_implicitly_boolean = element._is_implicitly_boolean @property - def wrapped_column_expression(self): + def wrapped_column_expression(self) -> BindParameter: return self.element - def self_group(self, against=None): + def self_group(self, against: Callable = None) -> AsBoolean: return self def _negate(self): @@ -3816,7 +3838,7 @@ def _flattened_operator_clauses( ) -> typing_Tuple[ColumnElement[Any], ...]: return (self.left, self.right) - def __bool__(self): + def __bool__(self) -> bool: """Implement Python-side "bool" for BinaryExpression as a simple "identity" check for the left and right attributes, if the operator is "eq" or "ne". Otherwise the expression @@ -3865,7 +3887,7 @@ def __invert__( def _from_objects(self) -> List[FromClause]: return self.left._from_objects + self.right._from_objects - def _negate(self): + def _negate(self) -> Union[BinaryExpression, UnaryExpression]: if self.negate is not None: return BinaryExpression( self.left, @@ -3895,7 +3917,7 @@ class Slice(ColumnElement[Any]): ("step", InternalTraversal.dp_clauseelement), ] - def __init__(self, start, stop, step, _name=None): + def __init__(self, start: int, stop: int, step: Optional[int], _name: Union[None, quoted_name, str] = None) -> None: self.start = coercions.expect( roles.ExpressionElementRole, start, @@ -3916,7 +3938,7 @@ def __init__(self, start, stop, step, _name=None): ) self.type = type_api.NULLTYPE - def self_group(self, against=None): + def self_group(self, against: builtin_function_or_method = None) -> Slice: assert against is operator.getitem return self @@ -3935,10 +3957,10 @@ class GroupedElement(DQLDMLClauseElement): element: ClauseElement - def self_group(self, against=None): + def self_group(self, against: Union[Callable, builtin_function_or_method] = None) -> Grouping: return self - def _ungroup(self): + def _ungroup(self) -> Select: return self.element._ungroup() @@ -3964,7 +3986,7 @@ def __init__( # nulltype assignment issue self.type = getattr(element, "type", type_api.NULLTYPE) # type: ignore - def _with_binary_element_type(self, type_): + def _with_binary_element_type(self, type_: Union[String, TupleType]) -> Grouping: return self.__class__(self.element._with_binary_element_type(type_)) @util.memoized_property @@ -3988,7 +4010,7 @@ def _proxies(self) -> List[ColumnElement[Any]]: def _from_objects(self) -> List[FromClause]: return self.element._from_objects - def __getattr__(self, attr): + def __getattr__(self, attr: str) -> Union[Callable, builtin_function_or_method]: return getattr(self.element, attr) def __getstate__(self): @@ -4127,7 +4149,7 @@ def _interpret_range( return lower, upper @util.memoized_property - def type(self): + def type(self) -> Union[NullType, String]: return self.element.type @util.ro_non_memoized_property @@ -4185,7 +4207,7 @@ def __reduce__(self): tuple(self.order_by) if self.order_by is not None else () ) - def over(self, partition_by=None, order_by=None, range_=None, rows=None): + def over(self, partition_by: ColumnClause = None, order_by: ColumnClause = None, range_: Optional[Tuple[int, int]] = None, rows: Optional[Tuple[int, int]] = None) -> Over: """Produce an OVER clause against this :class:`.WithinGroup` construct. @@ -4202,7 +4224,7 @@ def over(self, partition_by=None, order_by=None, range_=None, rows=None): ) @util.memoized_property - def type(self): + def type(self) -> String: wgt = self.element.within_group_type(self) if wgt is not None: return wgt @@ -4259,7 +4281,7 @@ def __init__( self.func = func self.filter(*criterion) - def filter(self, *criterion): + def filter(self, *criterion: BinaryExpression) -> FunctionFilter: """Produce an additional FILTER against the function. This method adds additional criteria to the initial criteria @@ -4323,14 +4345,14 @@ def over( rows=rows, ) - def self_group(self, against=None): + def self_group(self, against: builtin_function_or_method = None) -> Grouping: if operators.is_precedent(operators.filter_op, against): return Grouping(self) else: return self @util.memoized_property - def type(self): + def type(self) -> ARRAY: return self.func.type @util.ro_non_memoized_property @@ -4362,7 +4384,7 @@ def description(self) -> str: return self.name @HasMemoized.memoized_attribute - def _tq_key_label(self): + def _tq_key_label(self) -> Union[_anonymous_label, _truncated_label]: """table qualified label based on column key. for table-bound columns this is _; @@ -4387,11 +4409,11 @@ def _tq_label(self) -> Optional[str]: return self._gen_tq_label(self.name) @HasMemoized.memoized_attribute - def _render_label_in_columns_clause(self): + def _render_label_in_columns_clause(self) -> bool: return True @HasMemoized.memoized_attribute - def _non_anon_label(self): + def _non_anon_label(self) -> Any: return self.name def _gen_tq_label( @@ -4516,10 +4538,10 @@ def __reduce__(self): return self.__class__, (self.name, self._element, self.type) @HasMemoized.memoized_attribute - def _render_label_in_columns_clause(self): + def _render_label_in_columns_clause(self) -> bool: return True - def _bind_param(self, operator, obj, type_=None, expanding=False): + def _bind_param(self, operator: Union[Callable, builtin_function_or_method], obj: Union[List[str], int, str], type_: Optional[Any] = None, expanding: bool = False) -> BindParameter: return BindParameter( None, obj, @@ -4535,24 +4557,24 @@ def _is_implicitly_boolean(self): return self.element._is_implicitly_boolean @HasMemoized.memoized_attribute - def _allow_label_resolve(self): + def _allow_label_resolve(self) -> bool: return self.element._allow_label_resolve @property - def _order_by_label_element(self): + def _order_by_label_element(self) -> Label: return self @HasMemoized.memoized_attribute def element(self) -> ColumnElement[_T]: return self._element.self_group(against=operators.as_) - def self_group(self, against=None): + def self_group(self, against: Union[Callable, None, builtin_function_or_method] = None) -> Union[AnnotatedLabel, Label]: return self._apply_to_inner(self._element.self_group, against=against) def _negate(self): return self._apply_to_inner(self._element._negate) - def _apply_to_inner(self, fn, *arg, **kw): + def _apply_to_inner(self, fn: Callable, *arg: Any, **kw: Any) -> Union[AnnotatedLabel, Label]: sub_element = fn(*arg, **kw) if sub_element is not self._element: return Label(self.name, sub_element, type_=self.type) @@ -4560,11 +4582,11 @@ def _apply_to_inner(self, fn, *arg, **kw): return self @property - def primary_key(self): + def primary_key(self) -> bool: return self.element.primary_key @property - def foreign_keys(self): + def foreign_keys(self) -> Set: return self.element.foreign_keys def _copy_internals( @@ -4691,7 +4713,7 @@ class is usable by itself in those cases where behavioral requirements _is_multiparam_column = False @property - def _is_star(self): + def _is_star(self) -> bool: return self.is_literal and self.name == "*" def __init__( @@ -4710,7 +4732,7 @@ def __init__( self.is_literal = is_literal - def get_children(self, *, column_tables=False, **kw): + def get_children(self, *, column_tables: bool = False, **kw: Any) -> List: # override base get_children() to not return the Table # or selectable that is parent to this column. Traversals # expect the columns of tables and subqueries to be leaf nodes. @@ -4723,7 +4745,7 @@ def entity_namespace(self): else: return super().entity_namespace - def _clone(self, detect_subquery_cols=False, **kw): + def _clone(self, detect_subquery_cols: bool = False, **kw: Any) -> Union[ColumnClause, Column]: if ( detect_subquery_cols and self.table is not None @@ -4745,11 +4767,11 @@ def _from_objects(self) -> List[FromClause]: return [] @HasMemoized.memoized_attribute - def _render_label_in_columns_clause(self): + def _render_label_in_columns_clause(self) -> bool: return self.table is not None @property - def _ddl_label(self): + def _ddl_label(self) -> _truncated_label: return self._gen_tq_label(self.name, dedupe_on_key=False) def _compare_name_for_result(self, other): @@ -4922,7 +4944,7 @@ def _create_collation_expression( type_=expr.type, ) - def __init__(self, collation): + def __init__(self, collation: str) -> None: self.collation = collation @@ -5033,10 +5055,10 @@ def __new__(cls, value: str, quote: Optional[bool]) -> quoted_name: self.quote = quote return self - def __reduce__(self): + def __reduce__(self) -> Tuple[type, Tuple[str, None]]: return quoted_name, (str(self), self.quote) - def _memoized_method_lower(self): + def _memoized_method_lower(self) -> str: if self.quote: return self else: @@ -5057,7 +5079,7 @@ def _find_columns(clause: ClauseElement) -> Set[ColumnClause[Any]]: return cols -def _type_from_args(args): +def _type_from_args(args: Any) -> Union[Integer, NullType, String]: for a in args: if not a.type._isnull: return a.type @@ -5081,7 +5103,7 @@ def _corresponding_column_or_error(fromclause, column, require_embedded=False): class AnnotatedColumnElement(Annotated): _Annotated__element: ColumnElement[Any] - def __init__(self, element, values): + def __init__(self, element: Any, values: Union[Dict[str, Any], immutabledict]) -> None: Annotated.__init__(self, element, values) for attr in ( "comparator", @@ -5095,7 +5117,7 @@ def __init__(self, element, values): if self.__dict__.get(attr, False) is None: self.__dict__.pop(attr) - def _with_annotations(self, values): + def _with_annotations(self, values: immutabledict) -> Any: clone = super()._with_annotations(values) clone.__dict__.pop("comparator", None) return clone @@ -5106,7 +5128,7 @@ def name(self): return self._Annotated__element.name @util.memoized_property - def table(self): + def table(self) -> NoReturn: """pull 'table' from parent, if not present""" return self._Annotated__element.table @@ -5227,7 +5249,7 @@ def safe_construct( return _anonymous_label(label) - def __add__(self, other): + def __add__(self, other: Any) -> _anonymous_label: if "%" in other and not isinstance(other, _anonymous_label): other = str(other).replace("%", "%%") else: @@ -5240,7 +5262,7 @@ def __add__(self, other): ) ) - def __radd__(self, other): + def __radd__(self, other: str) -> _anonymous_label: if "%" in other and not isinstance(other, _anonymous_label): other = str(other).replace("%", "%%") else: @@ -5253,7 +5275,7 @@ def __radd__(self, other): ) ) - def apply_map(self, map_): + def apply_map(self, map_: Union[cache_anon_map, prefix_anon_map]) -> str: if self.quote is not None: # preserve quoting only if necessary return quoted_name(self % map_, self.quote) From 5ec44bb62809cb432b667440dcf56ed10ccc5202 Mon Sep 17 00:00:00 2001 From: jazzthief Date: Fri, 30 Dec 2022 12:02:52 +0100 Subject: [PATCH 02/20] Allow untyped calls; fix some more types --- lib/sqlalchemy/sql/elements.py | 175 +++++++++++++++++++++++---------- 1 file changed, 121 insertions(+), 54 deletions(-) diff --git a/lib/sqlalchemy/sql/elements.py b/lib/sqlalchemy/sql/elements.py index 7fb2fd94794..261f05c2aa3 100644 --- a/lib/sqlalchemy/sql/elements.py +++ b/lib/sqlalchemy/sql/elements.py @@ -4,6 +4,7 @@ # # This module is part of SQLAlchemy and is released under # the MIT License: https://www.opensource.org/licenses/mit-license.php +# mypy: allow-untyped-calls """Core SQL expression elements, including :class:`_expression.ClauseElement`, :class:`_expression.ColumnElement`, and derived classes. @@ -79,11 +80,22 @@ from ..util.typing import Self if typing.TYPE_CHECKING: + from re import Match + + from mypy_extensions import NoReturn + + from sqlalchemy.util._py_collections import immutabledict + from ._py_util import cache_anon_map + from ._py_util import prefix_anon_map from ._typing import _ColumnExpressionArgument from ._typing import _ColumnExpressionOrStrLabelArgument from ._typing import _InfoType from ._typing import _PropagateAttrsType from ._typing import _TypeEngineArgument + from .annotation import AnnotatedBindParameter + from .annotation import AnnotatedBooleanClauseList + from .annotation import AnnotatedClauseList + from .annotation import AnnotatedLabel from .cache_key import _CacheKeyTraversalType from .cache_key import CacheKey from .compiler import Compiled @@ -99,6 +111,11 @@ from .selectable import FromClause from .selectable import NamedFromClause from .selectable import TextualSelect + from .selectable import Select + from .sqltypes import ARRAY + from .sqltypes import Integer + from .sqltypes import NullType + from .sqltypes import String from .sqltypes import TupleType from .type_api import TypeEngine from .visitors import _CloneCallableType @@ -112,20 +129,6 @@ from ..engine.interfaces import CoreExecuteOptionsParameter from ..engine.interfaces import SchemaTranslateMapType from ..engine.result import Result - from re import Match - from .selectable import Select - from .sqltypes import NullType - from .sqltypes import ARRAY - from mypy_extensions import NoReturn - from ._py_util import cache_anon_map - from .annotation import AnnotatedBindParameter - from typing import Tuple - from .annotation import AnnotatedClauseList - from .annotation import AnnotatedBooleanClauseList - from .annotation import AnnotatedLabel - from sqlalchemy.util._py_collections import immutabledict - from ._py_util import prefix_anon_map - from .sqltypes import Integer _NUMERIC = Union[float, Decimal] _NUMBER = Union[float, int, Decimal] @@ -424,7 +427,7 @@ def _clone(self, **kw: Any) -> Self: def _negate_in_binary( self: SelfClauseElement, negated_op: Callable[..., Any], - original_op: Callable[..., Any] + original_op: Callable[..., Any], ) -> SelfClauseElement: """a hook to allow the right side of a binary expression to respond to a negation of the binary expression. @@ -710,7 +713,7 @@ def _negate(self) -> ClauseElement: assert isinstance(grouped, ColumnElement) return UnaryExpression(grouped, operator=operators.inv) - def __bool__(self) -> None: + def __bool__(self) -> NoReturn: raise TypeError("Boolean value of this clause is not defined") def __repr__(self) -> str: @@ -1545,6 +1548,7 @@ def proxy_set(self) -> FrozenSet[ColumnElement[Any]]: @util.memoized_property def _expanded_proxy_set(self) -> FrozenSet[ColumnElement[Any]]: + # type: ignore [no-untyped-call] return frozenset(_expand_cloned(self.proxy_set)) def _uncached_proxy_list(self) -> List[ColumnElement[Any]]: @@ -1875,7 +1879,7 @@ def _dedupe_anon_label_idx(self, idx: int) -> str: return self._dedupe_anon_tq_label_idx(idx) @property - def _proxy_key(self) -> Union[None, quoted_name, str]: + def _proxy_key(self) -> Optional[str]: wce = self.wrapped_column_expression if not wce._is_text_clause: @@ -2027,12 +2031,13 @@ def __init__( def _with_value( self, - value: int, + value: Optional[_T], maintain_key: bool = False, - required: Union[bool, _NoArg] = NO_ARG - ) -> BindParameter: + required: Union[bool, _NoArg] = NO_ARG, + ) -> BindParameter[_T]: """Return a copy of this :class:`.BindParameter` with the given value set. + """ cloned = self._clone(maintain_key=maintain_key) cloned.value = value @@ -2086,7 +2091,9 @@ def render_literal_execute(self) -> BindParameter[_T]: literal_execute=True, ) - def _negate_in_binary(self, negated_op: Callable, original_op: Callable) -> BindParameter: + def _negate_in_binary( + self, negated_op: Callable[..., Any], original_op: Callable[..., Any] + ) -> BindParameter[_T]: if self.expand_op is original_op: bind = self._clone() bind.expand_op = negated_op @@ -2094,7 +2101,7 @@ def _negate_in_binary(self, negated_op: Callable, original_op: Callable) -> Bind else: return self - def _with_binary_element_type(self, type_: Any) -> BindParameter: + def _with_binary_element_type(self, type_: Any) -> ClauseElement: c = ClauseElement._clone(self) c.type = type_ return c @@ -2116,7 +2123,17 @@ def _clone(self, maintain_key: bool = False, **kw: Any) -> Self: ) return c - def _gen_cache_key(self, anon_map: cache_anon_map, bindparams: Union[List[AnnotatedBindParameter], List[BindParameter]]) -> Tuple[str, type, Union[Tuple[type, Tuple[str, int]], Tuple[type]], Union[_truncated_label, quoted_name, str], bool]: + def _gen_cache_key( + self, + anon_map: cache_anon_map, + bindparams: Union[List[AnnotatedBindParameter], List[BindParameter]], + ) -> typing_Tuple[ + str, + type, + Union[Tuple[type, Tuple[str, int]], Tuple[type]], + Union[_truncated_label, quoted_name, str], + bool, + ]: _gen_cache_ok = self.__class__.__dict__.get("inherit_cache", False) if not _gen_cache_ok: @@ -2245,7 +2262,9 @@ class TextClause( def _hide_froms(self) -> Iterable[FromClause]: return () - def __and__(self, other): + def __and__( + self, other: _ColumnExpressionArgument[bool] + ) -> ColumnElement[bool]: # support use in select.where(), query.filter() return and_(self, other) @@ -2261,13 +2280,13 @@ def _select_iterable(self) -> _SelectIterable: _allow_label_resolve = False @property - def _is_star(self): + def _is_star(self) -> bool: return self.text == "*" def __init__(self, text: str): self._bindparams: Dict[str, BindParameter[Any]] = {} - def repl(m: Match) -> str: + def repl(m: Match[Any]) -> str: self._bindparams[m.group(1)] = BindParameter(m.group(1)) return ":%s" % m.group(1) @@ -2561,7 +2580,9 @@ def comparator(self): # be using this method. return self.type.comparator_factory(self) # type: ignore - def self_group(self, against: Union[Callable, None, builtin_function_or_method] = None) -> TextClause: + def self_group( + self, against: Union[Callable[..., Any], OperatorType, None] = None + ) -> Union[Grouping[TextClause], TextClause]: if against is operators.in_op: return Grouping(self) else: @@ -2760,7 +2781,9 @@ def append(self, clause): def _from_objects(self) -> List[FromClause]: return list(itertools.chain(*[c._from_objects for c in self.clauses])) - def self_group(self, against: Callable = None) -> Union[AnnotatedClauseList, Grouping]: + def self_group( + self, against: Callable = None + ) -> Union[AnnotatedClauseList, Grouping]: if self.group and operators.is_precedent(self.operator, against): return Grouping(self) else: @@ -2780,10 +2803,10 @@ class OperatorExpression(ColumnElement[_T]): group: bool = True @property - def is_comparison(self): + def is_comparison(self) -> bool: return operators.is_comparison(self.operator) - def self_group(self, against: Union[Callable, builtin_function_or_method] = None) -> Any: + def self_group(self, against: Optional[Callable[..., Any]] = None) -> Any: if ( self.group and operators.is_precedent(self.operator, against) @@ -3138,7 +3161,9 @@ def or_( def _select_iterable(self) -> _SelectIterable: return (self,) - def self_group(self, against: Union[Callable, None, builtin_function_or_method] = None) -> Union[AnnotatedBooleanClauseList, BooleanClauseList, Grouping]: + def self_group( + self, against: Optional[Callable[..., Any]] = None + ) -> Union[AnnotatedBooleanClauseList, BooleanClauseList, Grouping]: if not self.clauses: return self else: @@ -3195,7 +3220,13 @@ def __init__( def _select_iterable(self) -> _SelectIterable: return (self,) - def _bind_param(self, operator: Callable, obj: List[Tuple[int, str]], type_: Optional[Any] = None, expanding: bool = False) -> BindParameter: + def _bind_param( + self, + operator: Callable, + obj: List[Tuple[int, str]], + type_: Optional[Any] = None, + expanding: bool = False, + ) -> BindParameter: if expanding: return BindParameter( None, @@ -3221,7 +3252,7 @@ def _bind_param(self, operator: Callable, obj: List[Tuple[int, str]], type_: Opt ] ) - def self_group(self, against: Optional[Callable] = None) -> Tuple: + def self_group(self, against: Optional[OperatorType] = None) -> Tuple: # Tuple is parenthesized by definition. return self @@ -3374,7 +3405,9 @@ def _from_objects(self) -> List[FromClause]: return self.clause._from_objects @property - def wrapped_column_expression(self) -> Union[BindParameter, ColumnClause, Column]: + def wrapped_column_expression( + self, + ) -> Union[BindParameter, ColumnClause, Column]: return self.clause @@ -3433,10 +3466,10 @@ def typed_expression(self) -> Any: return self.clause @property - def wrapped_column_expression(self) -> BinaryExpression: + def wrapped_column_expression(self) -> ColumnElement[Any]: return self.clause - def self_group(self, against: Callable = None) -> TypeCoerce: + def self_group(self, against: Optional[OperatorType] = None) -> TypeCoerce: grouped = self.clause.self_group(against=against) if grouped is not self.clause: return TypeCoerce(grouped, self.type) @@ -3651,7 +3684,7 @@ def _negate(self) -> UnaryExpression: else: return ClauseElement._negate(self) - def self_group(self, against: Union[Callable, builtin_function_or_method] = None) -> Any: + def self_group(self, against: Optional[Callable[..., Any]] = None) -> Any: if self.operator and operators.is_precedent(self.operator, against): return Grouping(self) else: @@ -3724,7 +3757,9 @@ def reverse_operate(self, op, other, **kwargs): class AsBoolean(WrapsColumnExpression[bool], UnaryExpression[bool]): inherit_cache = True - def __init__(self, element: Any, operator: Callable, negate: Callable) -> None: + def __init__( + self, element: Any, operator: Callable, negate: Callable + ) -> None: self.element = element self.type = type_api.BOOLEANTYPE self.operator = operator @@ -3734,10 +3769,10 @@ def __init__(self, element: Any, operator: Callable, negate: Callable) -> None: self._is_implicitly_boolean = element._is_implicitly_boolean @property - def wrapped_column_expression(self) -> BindParameter: + def wrapped_column_expression(self) -> ColumnElement[Any]: return self.element - def self_group(self, against: Callable = None) -> AsBoolean: + def self_group(self, against: Optional[OperatorType] = None) -> AsBoolean: return self def _negate(self): @@ -3838,7 +3873,7 @@ def _flattened_operator_clauses( ) -> typing_Tuple[ColumnElement[Any], ...]: return (self.left, self.right) - def __bool__(self) -> bool: + def __bool__(self) -> bool: # type: ignore [override] """Implement Python-side "bool" for BinaryExpression as a simple "identity" check for the left and right attributes, if the operator is "eq" or "ne". Otherwise the expression @@ -3887,7 +3922,7 @@ def __invert__( def _from_objects(self) -> List[FromClause]: return self.left._from_objects + self.right._from_objects - def _negate(self) -> Union[BinaryExpression, UnaryExpression]: + def _negate(self) -> Any: if self.negate is not None: return BinaryExpression( self.left, @@ -3917,7 +3952,13 @@ class Slice(ColumnElement[Any]): ("step", InternalTraversal.dp_clauseelement), ] - def __init__(self, start: int, stop: int, step: Optional[int], _name: Union[None, quoted_name, str] = None) -> None: + def __init__( + self, + start: int, + stop: int, + step: Optional[int], + _name: Union[None, quoted_name, str] = None, + ) -> None: self.start = coercions.expect( roles.ExpressionElementRole, start, @@ -3938,7 +3979,9 @@ def __init__(self, start: int, stop: int, step: Optional[int], _name: Union[None ) self.type = type_api.NULLTYPE - def self_group(self, against: builtin_function_or_method = None) -> Slice: + def self_group( + self, against: Optional[Callable[..., Any]] = None + ) -> Slice: assert against is operator.getitem return self @@ -3957,7 +4000,7 @@ class GroupedElement(DQLDMLClauseElement): element: ClauseElement - def self_group(self, against: Union[Callable, builtin_function_or_method] = None) -> Grouping: + def self_group(self, against: Union[Callable] = None) -> Grouping: return self def _ungroup(self) -> Select: @@ -3986,7 +4029,9 @@ def __init__( # nulltype assignment issue self.type = getattr(element, "type", type_api.NULLTYPE) # type: ignore - def _with_binary_element_type(self, type_: Union[String, TupleType]) -> Grouping: + def _with_binary_element_type( + self, type_: Union[String, TupleType] + ) -> Grouping: return self.__class__(self.element._with_binary_element_type(type_)) @util.memoized_property @@ -4010,7 +4055,7 @@ def _proxies(self) -> List[ColumnElement[Any]]: def _from_objects(self) -> List[FromClause]: return self.element._from_objects - def __getattr__(self, attr: str) -> Union[Callable, builtin_function_or_method]: + def __getattr__(self, attr: str) -> Union[Callable]: return getattr(self.element, attr) def __getstate__(self): @@ -4207,7 +4252,13 @@ def __reduce__(self): tuple(self.order_by) if self.order_by is not None else () ) - def over(self, partition_by: ColumnClause = None, order_by: ColumnClause = None, range_: Optional[Tuple[int, int]] = None, rows: Optional[Tuple[int, int]] = None) -> Over: + def over( + self, + partition_by: ColumnClause = None, + order_by: ColumnClause = None, + range_: Optional[Tuple[int, int]] = None, + rows: Optional[Tuple[int, int]] = None, + ) -> Over: """Produce an OVER clause against this :class:`.WithinGroup` construct. @@ -4345,7 +4396,9 @@ def over( rows=rows, ) - def self_group(self, against: builtin_function_or_method = None) -> Grouping: + def self_group( + self, against: Optional[Callable[..., Any]] = None + ) -> Grouping: if operators.is_precedent(operators.filter_op, against): return Grouping(self) else: @@ -4541,7 +4594,13 @@ def __reduce__(self): def _render_label_in_columns_clause(self) -> bool: return True - def _bind_param(self, operator: Union[Callable, builtin_function_or_method], obj: Union[List[str], int, str], type_: Optional[Any] = None, expanding: bool = False) -> BindParameter: + def _bind_param( + self, + operator: Union[Callable], + obj: Union[List[str], int, str], + type_: Optional[Any] = None, + expanding: bool = False, + ) -> BindParameter: return BindParameter( None, obj, @@ -4568,13 +4627,17 @@ def _order_by_label_element(self) -> Label: def element(self) -> ColumnElement[_T]: return self._element.self_group(against=operators.as_) - def self_group(self, against: Union[Callable, None, builtin_function_or_method] = None) -> Union[AnnotatedLabel, Label]: + def self_group( + self, against: Union[Callable, None] = None + ) -> Union[AnnotatedLabel, Label]: return self._apply_to_inner(self._element.self_group, against=against) def _negate(self): return self._apply_to_inner(self._element._negate) - def _apply_to_inner(self, fn: Callable, *arg: Any, **kw: Any) -> Union[AnnotatedLabel, Label]: + def _apply_to_inner( + self, fn: Callable, *arg: Any, **kw: Any + ) -> Union[AnnotatedLabel, Label]: sub_element = fn(*arg, **kw) if sub_element is not self._element: return Label(self.name, sub_element, type_=self.type) @@ -4745,7 +4808,9 @@ def entity_namespace(self): else: return super().entity_namespace - def _clone(self, detect_subquery_cols: bool = False, **kw: Any) -> Union[ColumnClause, Column]: + def _clone( + self, detect_subquery_cols: bool = False, **kw: Any + ) -> Union[ColumnClause, Column]: if ( detect_subquery_cols and self.table is not None @@ -5103,7 +5168,9 @@ def _corresponding_column_or_error(fromclause, column, require_embedded=False): class AnnotatedColumnElement(Annotated): _Annotated__element: ColumnElement[Any] - def __init__(self, element: Any, values: Union[Dict[str, Any], immutabledict]) -> None: + def __init__( + self, element: Any, values: Union[Dict[str, Any], immutabledict] + ) -> None: Annotated.__init__(self, element, values) for attr in ( "comparator", From 710c1781a34de2d851726583a87bd133a7900381 Mon Sep 17 00:00:00 2001 From: jazzthief Date: Mon, 2 Jan 2023 18:39:10 +0100 Subject: [PATCH 03/20] Fix more types --- lib/sqlalchemy/sql/elements.py | 33 ++++++++++++++++++--------------- 1 file changed, 18 insertions(+), 15 deletions(-) diff --git a/lib/sqlalchemy/sql/elements.py b/lib/sqlalchemy/sql/elements.py index 261f05c2aa3..5b9ff9edd7d 100644 --- a/lib/sqlalchemy/sql/elements.py +++ b/lib/sqlalchemy/sql/elements.py @@ -96,6 +96,7 @@ from .annotation import AnnotatedBooleanClauseList from .annotation import AnnotatedClauseList from .annotation import AnnotatedLabel + from .base import _EntityNamespace from .cache_key import _CacheKeyTraversalType from .cache_key import CacheKey from .compiler import Compiled @@ -482,7 +483,7 @@ def _cloned_set(self) -> Any: return s @property - def entity_namespace(self) -> NoReturn: + def entity_namespace(self) -> Union[_EntityNamespace, NoReturn]: raise AttributeError( "This SQL expression has no entity namespace " "with which to filter from." @@ -1879,7 +1880,7 @@ def _dedupe_anon_label_idx(self, idx: int) -> str: return self._dedupe_anon_tq_label_idx(idx) @property - def _proxy_key(self) -> Optional[str]: + def _proxy_key(self) -> Optional[str]: # type: ignore [override] wce = self.wrapped_column_expression if not wce._is_text_clause: @@ -2174,7 +2175,7 @@ def __getstate__(self) -> Dict[str, Any]: d["value"] = v return d - def __setstate__(self, state): + def __setstate__(self, state: Dict[str, Any]) -> None: if state.get("unique", False): state["key"] = _anonymous_label.safe_construct( id(self), state.get("_orig_key", "param"), sanitize_key=True @@ -2765,7 +2766,7 @@ def _select_iterable(self) -> _SelectIterable: [elem._select_iterable for elem in self.clauses] ) - def append(self, clause): + def append(self, clause: Any) -> None: if self.group_contents: self.clauses.append( coercions.expect(self._text_converter_role, clause).self_group( @@ -2782,8 +2783,8 @@ def _from_objects(self) -> List[FromClause]: return list(itertools.chain(*[c._from_objects for c in self.clauses])) def self_group( - self, against: Callable = None - ) -> Union[AnnotatedClauseList, Grouping]: + self, against: OperatorType + ) -> Union[AnnotatedClauseList, Grouping[ClauseList]]: if self.group and operators.is_precedent(self.operator, against): return Grouping(self) else: @@ -2968,7 +2969,7 @@ class BooleanClauseList(ExpressionClauseList[bool]): __visit_name__ = "expression_clauselist" inherit_cache = True - def __init__(self, *arg, **kw) -> None: + def __init__(self, *arg: Any, **kw: Any) -> NoReturn: raise NotImplementedError( "BooleanClauseList has a private constructor" ) @@ -3222,11 +3223,11 @@ def _select_iterable(self) -> _SelectIterable: def _bind_param( self, - operator: Callable, + operator: OperatorType, obj: List[Tuple[int, str]], - type_: Optional[Any] = None, + type_: Optional[TypeEngine[_T]] = None, expanding: bool = False, - ) -> BindParameter: + ) -> Union[BindParameter[_T], Tuple]: if expanding: return BindParameter( None, @@ -3407,7 +3408,9 @@ def _from_objects(self) -> List[FromClause]: @property def wrapped_column_expression( self, - ) -> Union[BindParameter, ColumnClause, Column]: + ) -> Union[ + BindParameter[Any], ColumnClause[Any], ColumnElement[Any], Column[Any] + ]: return self.clause @@ -4802,7 +4805,7 @@ def get_children(self, *, column_tables: bool = False, **kw: Any) -> List: return [] @property - def entity_namespace(self): + def entity_namespace(self) -> Union[_EntityNamespace, NoReturn]: if self.table is not None: return self.table.entity_namespace else: @@ -4810,7 +4813,7 @@ def entity_namespace(self): def _clone( self, detect_subquery_cols: bool = False, **kw: Any - ) -> Union[ColumnClause, Column]: + ) -> Union[ColumnClause[_T], Column[_T]]: if ( detect_subquery_cols and self.table is not None @@ -5016,7 +5019,7 @@ def __init__(self, collation: str) -> None: class _IdentifiedClause(Executable, ClauseElement): __visit_name__ = "identified" - def __init__(self, ident): + def __init__(self, ident: Any) -> None: self.ident = ident @@ -5184,7 +5187,7 @@ def __init__( if self.__dict__.get(attr, False) is None: self.__dict__.pop(attr) - def _with_annotations(self, values: immutabledict) -> Any: + def _with_annotations(self, values: immutabledict[str, Any]) -> Any: clone = super()._with_annotations(values) clone.__dict__.pop("comparator", None) return clone From 2ca8fc9b7d784102b19ed2af7e6847169fa51401 Mon Sep 17 00:00:00 2001 From: jazzthief Date: Tue, 3 Jan 2023 17:56:35 +0100 Subject: [PATCH 04/20] Fix some more types --- lib/sqlalchemy/sql/elements.py | 152 +++++++++++++++++++++------------ 1 file changed, 96 insertions(+), 56 deletions(-) diff --git a/lib/sqlalchemy/sql/elements.py b/lib/sqlalchemy/sql/elements.py index 5b9ff9edd7d..b99d10a7c33 100644 --- a/lib/sqlalchemy/sql/elements.py +++ b/lib/sqlalchemy/sql/elements.py @@ -84,7 +84,6 @@ from mypy_extensions import NoReturn - from sqlalchemy.util._py_collections import immutabledict from ._py_util import cache_anon_map from ._py_util import prefix_anon_map from ._typing import _ColumnExpressionArgument @@ -96,8 +95,10 @@ from .annotation import AnnotatedBooleanClauseList from .annotation import AnnotatedClauseList from .annotation import AnnotatedLabel + from .annotation import SupportsAnnotations from .base import _EntityNamespace from .cache_key import _CacheKeyTraversalType + from .cache_key import CacheConst from .cache_key import CacheKey from .compiler import Compiled from .compiler import SQLCompiler @@ -114,7 +115,7 @@ from .selectable import TextualSelect from .selectable import Select from .sqltypes import ARRAY - from .sqltypes import Integer + from .sqltypes import Boolean from .sqltypes import NullType from .sqltypes import String from .sqltypes import TupleType @@ -135,9 +136,9 @@ _NUMBER = Union[float, int, Decimal] _T = TypeVar("_T", bound="Any") +_O = TypeVar("_O", bound="object") _OPT = TypeVar("_OPT", bound="Any") _NT = TypeVar("_NT", bound="_NUMERIC") - _NMT = TypeVar("_NMT", bound="_NUMBER") @@ -1570,7 +1571,9 @@ def shares_lineage(self, othercolumn: ColumnElement[Any]) -> bool: return bool(self.proxy_set.intersection(othercolumn.proxy_set)) - def _compare_name_for_result(self, other: ColumnElement[Any]) -> bool: + def _compare_name_for_result( + self, other: ColumnElement[Any] + ) -> Union[bool, FrozenSet[ColumnElement[_T]]]: """Return True if the given column element compares to this one when targeting within a result row.""" @@ -2127,13 +2130,20 @@ def _clone(self, maintain_key: bool = False, **kw: Any) -> Self: def _gen_cache_key( self, anon_map: cache_anon_map, - bindparams: Union[List[AnnotatedBindParameter], List[BindParameter]], - ) -> typing_Tuple[ - str, - type, - Union[Tuple[type, Tuple[str, int]], Tuple[type]], - Union[_truncated_label, quoted_name, str], - bool, + bindparams: Union[ + List[AnnotatedBindParameter], List[BindParameter[_T]] + ], + ) -> Optional[ + Union[ + typing_Tuple[str, Type[BindParameter[_T]]], + typing_Tuple[ + str, + Type[BindParameter[_T]], + Union[CacheConst, typing_Tuple[Any, ...]], + str, + bool, + ], + ] ]: _gen_cache_ok = self.__class__.__dict__.get("inherit_cache", False) @@ -2284,7 +2294,7 @@ def _select_iterable(self) -> _SelectIterable: def _is_star(self) -> bool: return self.text == "*" - def __init__(self, text: str): + def __init__(self, text: str) -> None: self._bindparams: Dict[str, BindParameter[Any]] = {} def repl(m: Match[Any]) -> str: @@ -2590,7 +2600,9 @@ def self_group( return self -class Null(SingletonConstant, roles.ConstExprRole[None], ColumnElement[None]): +class Null( # type: ignore [misc] + SingletonConstant, roles.ConstExprRole[None], ColumnElement[None] +): """Represent the NULL keyword in a SQL statement. :class:`.Null` is accessed as a constant via the @@ -2604,7 +2616,7 @@ class Null(SingletonConstant, roles.ConstExprRole[None], ColumnElement[None]): _singleton: Null @util.memoized_property - def type(self): + def type(self) -> NullType: return type_api.NULLTYPE @classmethod @@ -2617,7 +2629,7 @@ def _instance(cls) -> Null: Null._create_singleton() -class False_( +class False_( # type: ignore [misc] SingletonConstant, roles.ConstExprRole[bool], ColumnElement[bool] ): """Represent the ``false`` keyword, or equivalent, in a SQL statement. @@ -2632,7 +2644,7 @@ class False_( _singleton: False_ @util.memoized_property - def type(self): + def type(self) -> Boolean: return type_api.BOOLEANTYPE def _negate(self) -> True_: @@ -2646,7 +2658,9 @@ def _instance(cls) -> False_: False_._create_singleton() -class True_(SingletonConstant, roles.ConstExprRole[bool], ColumnElement[bool]): +class True_( # type: ignore [misc] + SingletonConstant, roles.ConstExprRole[bool], ColumnElement[bool] +): """Represent the ``true`` keyword, or equivalent, in a SQL statement. :class:`.True_` is accessed as a constant via the @@ -2660,7 +2674,7 @@ class True_(SingletonConstant, roles.ConstExprRole[bool], ColumnElement[bool]): _singleton: True_ @util.memoized_property - def type(self): + def type(self) -> Boolean: return type_api.BOOLEANTYPE def _negate(self) -> False_: @@ -2716,7 +2730,7 @@ def __init__( group: bool = True, group_contents: bool = True, _literal_as_text_role: Type[roles.SQLRole] = roles.WhereHavingRole, - ): + ) -> None: self.operator = operator self.group = group self.group_contents = group_contents @@ -2807,7 +2821,9 @@ class OperatorExpression(ColumnElement[_T]): def is_comparison(self) -> bool: return operators.is_comparison(self.operator) - def self_group(self, against: Optional[Callable[..., Any]] = None) -> Any: + def self_group( + self, against: OperatorType + ) -> Union[Grouping[OperatorExpression[_T]], OperatorExpression[_T]]: if ( self.group and operators.is_precedent(self.operator, against) @@ -3163,7 +3179,7 @@ def _select_iterable(self) -> _SelectIterable: return (self,) def self_group( - self, against: Optional[Callable[..., Any]] = None + self, against: Optional[OperatorType] = None ) -> Union[AnnotatedBooleanClauseList, BooleanClauseList, Grouping]: if not self.clauses: return self @@ -3175,7 +3191,9 @@ def self_group( or_ = BooleanClauseList.or_ -class Tuple(ClauseList, ColumnElement[typing_Tuple[Any, ...]]): +class Tuple( # type: ignore [misc] + ClauseList, ColumnElement[typing_Tuple[Any, ...]] +): """Represent a SQL tuple.""" __visit_name__ = "tuple" @@ -3472,7 +3490,9 @@ def typed_expression(self) -> Any: def wrapped_column_expression(self) -> ColumnElement[Any]: return self.clause - def self_group(self, against: Optional[OperatorType] = None) -> TypeCoerce: + def self_group( + self, against: Optional[OperatorType] = None + ) -> TypeCoerce[_T]: grouped = self.clause.self_group(against=against) if grouped is not self.clause: return TypeCoerce(grouped, self.type) @@ -3757,11 +3777,13 @@ def reverse_operate(self, op, other, **kwargs): ) -class AsBoolean(WrapsColumnExpression[bool], UnaryExpression[bool]): +class AsBoolean( # type: ignore [misc] + WrapsColumnExpression[bool], UnaryExpression[bool] +): inherit_cache = True def __init__( - self, element: Any, operator: Callable, negate: Callable + self, element: Any, operator: OperatorType, negate: OperatorType ) -> None: self.element = element self.type = type_api.BOOLEANTYPE @@ -3778,7 +3800,7 @@ def wrapped_column_expression(self) -> ColumnElement[Any]: def self_group(self, against: Optional[OperatorType] = None) -> AsBoolean: return self - def _negate(self): + def _negate(self) -> Union[AsBoolean, False_, True_]: if isinstance(self.element, (True_, False_)): return self.element._negate() else: @@ -4003,14 +4025,16 @@ class GroupedElement(DQLDMLClauseElement): element: ClauseElement - def self_group(self, against: Union[Callable] = None) -> Grouping: + def self_group( + self, against: Optional[OperatorType] = None + ) -> GroupedElement: return self def _ungroup(self) -> Select: return self.element._ungroup() -class Grouping(GroupedElement, ColumnElement[_T]): +class Grouping(GroupedElement, ColumnElement[_T]): # type: ignore [misc] """Represent a grouping within a column expression""" _traverse_internals: _TraverseInternalsType = [ @@ -4034,11 +4058,11 @@ def __init__( def _with_binary_element_type( self, type_: Union[String, TupleType] - ) -> Grouping: + ) -> Grouping[_T]: return self.__class__(self.element._with_binary_element_type(type_)) @util.memoized_property - def _is_implicitly_boolean(self): + def _is_implicitly_boolean(self) -> bool: return self.element._is_implicitly_boolean @util.non_memoized_property @@ -4058,13 +4082,13 @@ def _proxies(self) -> List[ColumnElement[Any]]: def _from_objects(self) -> List[FromClause]: return self.element._from_objects - def __getattr__(self, attr: str) -> Union[Callable]: + def __getattr__(self, attr: str) -> Any: return getattr(self.element, attr) - def __getstate__(self): + def __getstate__(self) -> dict[str, Any]: return {"element": self.element, "type": self.type} - def __setstate__(self, state): + def __setstate__(self, state: dict[str, Any]) -> None: self.element = state["element"] self.type = state["type"] @@ -4430,7 +4454,7 @@ class NamedColumn(KeyedColumnElement[_T]): name: str key: str - def _compare_name_for_result(self, other): + def _compare_name_for_result(self, other: NamedColumn[_T]) -> bool: return (hasattr(other, "name") and self.name == other.name) or ( hasattr(other, "_label") and self._label == other._label ) @@ -4527,7 +4551,9 @@ def _make_proxy( return c.key, c -class Label(roles.LabeledColumnExprRole[_T], NamedColumn[_T]): +class Label( # type: ignore [misc] + roles.LabeledColumnExprRole[_T], NamedColumn[_T] +): """Represents a column label (AS). Represent a label, as typically applied to any column-level @@ -4590,7 +4616,11 @@ def __init__( self._proxies = [element] - def __reduce__(self): + def __reduce__( + self, + ) -> typing_Tuple[ + Type[Any], typing_Tuple[str, ColumnElement[_T], TypeEngine[_T]] + ]: return self.__class__, (self.name, self._element, self.type) @HasMemoized.memoized_attribute @@ -4615,7 +4645,7 @@ def _bind_param( ) @util.memoized_property - def _is_implicitly_boolean(self): + def _is_implicitly_boolean(self) -> bool: return self.element._is_implicitly_boolean @HasMemoized.memoized_attribute @@ -4631,16 +4661,16 @@ def element(self) -> ColumnElement[_T]: return self._element.self_group(against=operators.as_) def self_group( - self, against: Union[Callable, None] = None - ) -> Union[AnnotatedLabel, Label]: + self, against: Optional[OperatorType] = None + ) -> Union[AnnotatedLabel, Label[_T]]: return self._apply_to_inner(self._element.self_group, against=against) def _negate(self): return self._apply_to_inner(self._element._negate) def _apply_to_inner( - self, fn: Callable, *arg: Any, **kw: Any - ) -> Union[AnnotatedLabel, Label]: + self, fn: Callable[..., Any], *arg: Any, **kw: Any + ) -> Union[AnnotatedLabel, Label[_T]]: sub_element = fn(*arg, **kw) if sub_element is not self._element: return Label(self.name, sub_element, type_=self.type) @@ -4652,7 +4682,7 @@ def primary_key(self) -> bool: return self.element.primary_key @property - def foreign_keys(self) -> Set: + def foreign_keys(self) -> AbstractSet[ForeignKey]: return self.element.foreign_keys def _copy_internals( @@ -4715,7 +4745,7 @@ def _make_proxy( return self.key, e -class ColumnClause( +class ColumnClause( # type: ignore [misc] roles.DDLReferredColumnRole, roles.LabeledColumnExprRole[_T], roles.StrAsPlainColumnRole, @@ -4798,7 +4828,9 @@ def __init__( self.is_literal = is_literal - def get_children(self, *, column_tables: bool = False, **kw: Any) -> List: + def get_children( + self, *, column_tables: bool = False, **kw: Any + ) -> List[Any]: # override base get_children() to not return the Table # or selectable that is parent to this column. Traversals # expect the columns of tables and subqueries to be leaf nodes. @@ -4842,7 +4874,9 @@ def _render_label_in_columns_clause(self) -> bool: def _ddl_label(self) -> _truncated_label: return self._gen_tq_label(self.name, dedupe_on_key=False) - def _compare_name_for_result(self, other): + def _compare_name_for_result( + self, other: ColumnClause[_T] + ) -> Union[bool, FrozenSet[ColumnElement[_T]]]: if ( self.is_literal or self.table is None @@ -5123,7 +5157,9 @@ def __new__(cls, value: str, quote: Optional[bool]) -> quoted_name: self.quote = quote return self - def __reduce__(self) -> Tuple[type, Tuple[str, None]]: + def __reduce__( + self, + ) -> typing_Tuple[Type[quoted_name], typing_Tuple[str, Optional[bool]]]: return quoted_name, (str(self), self.quote) def _memoized_method_lower(self) -> str: @@ -5132,7 +5168,7 @@ def _memoized_method_lower(self) -> str: else: return str(self).lower() - def _memoized_method_upper(self): + def _memoized_method_upper(self) -> Union[quoted_name, str]: if self.quote: return self else: @@ -5147,7 +5183,7 @@ def _find_columns(clause: ClauseElement) -> Set[ColumnClause[Any]]: return cols -def _type_from_args(args: Any) -> Union[Integer, NullType, String]: +def _type_from_args(args: Iterable[Any]) -> Union[Any, NullType]: for a in args: if not a.type._isnull: return a.type @@ -5155,7 +5191,11 @@ def _type_from_args(args: Any) -> Union[Integer, NullType, String]: return type_api.NULLTYPE -def _corresponding_column_or_error(fromclause, column, require_embedded=False): +def _corresponding_column_or_error( + fromclause: FromClause, + column: KeyedColumnElement[_T], + require_embedded: bool = False, +) -> Optional[KeyedColumnElement[_T]]: c = fromclause.corresponding_column( column, require_embedded=require_embedded ) @@ -5171,9 +5211,7 @@ def _corresponding_column_or_error(fromclause, column, require_embedded=False): class AnnotatedColumnElement(Annotated): _Annotated__element: ColumnElement[Any] - def __init__( - self, element: Any, values: Union[Dict[str, Any], immutabledict] - ) -> None: + def __init__(self, element: Any, values: Mapping[str, Any]) -> None: Annotated.__init__(self, element, values) for attr in ( "comparator", @@ -5187,23 +5225,25 @@ def __init__( if self.__dict__.get(attr, False) is None: self.__dict__.pop(attr) - def _with_annotations(self, values: immutabledict[str, Any]) -> Any: + def _with_annotations( + self, values: Mapping[str, Any] + ) -> SupportsAnnotations: clone = super()._with_annotations(values) clone.__dict__.pop("comparator", None) return clone @util.memoized_property - def name(self): + def name(self) -> Any: """pull 'name' from parent, if not present""" return self._Annotated__element.name @util.memoized_property - def table(self) -> NoReturn: + def table(self) -> Any: """pull 'table' from parent, if not present""" return self._Annotated__element.table @util.memoized_property - def key(self): + def key(self) -> Union[str, None]: """pull 'key' from parent, if not present""" return self._Annotated__element.key From 7e6036fac211687a64d3f72e99c5dad0b7d080cf Mon Sep 17 00:00:00 2001 From: jazzthief Date: Thu, 5 Jan 2023 17:30:58 +0100 Subject: [PATCH 05/20] Fix some more types; focus on method overrides --- lib/sqlalchemy/sql/elements.py | 73 ++++++++++++++++++---------------- 1 file changed, 39 insertions(+), 34 deletions(-) diff --git a/lib/sqlalchemy/sql/elements.py b/lib/sqlalchemy/sql/elements.py index b99d10a7c33..0ebba07308c 100644 --- a/lib/sqlalchemy/sql/elements.py +++ b/lib/sqlalchemy/sql/elements.py @@ -91,10 +91,7 @@ from ._typing import _InfoType from ._typing import _PropagateAttrsType from ._typing import _TypeEngineArgument - from .annotation import AnnotatedBindParameter from .annotation import AnnotatedBooleanClauseList - from .annotation import AnnotatedClauseList - from .annotation import AnnotatedLabel from .annotation import SupportsAnnotations from .base import _EntityNamespace from .cache_key import _CacheKeyTraversalType @@ -2130,9 +2127,7 @@ def _clone(self, maintain_key: bool = False, **kw: Any) -> Self: def _gen_cache_key( self, anon_map: cache_anon_map, - bindparams: Union[ - List[AnnotatedBindParameter], List[BindParameter[_T]] - ], + bindparams: Union[List[Annotated], List[BindParameter[_T]]], ) -> Optional[ Union[ typing_Tuple[str, Type[BindParameter[_T]]], @@ -2592,7 +2587,7 @@ def comparator(self): return self.type.comparator_factory(self) # type: ignore def self_group( - self, against: Union[Callable[..., Any], OperatorType, None] = None + self, against: Optional[OperatorType] = None ) -> Union[Grouping[TextClause], TextClause]: if against is operators.in_op: return Grouping(self) @@ -2796,9 +2791,8 @@ def append(self, clause: Any) -> None: def _from_objects(self) -> List[FromClause]: return list(itertools.chain(*[c._from_objects for c in self.clauses])) - def self_group( - self, against: OperatorType - ) -> Union[AnnotatedClauseList, Grouping[ClauseList]]: + def self_group(self, against: Optional[OperatorType]) -> ClauseElement: + assert against is not None if self.group and operators.is_precedent(self.operator, against): return Grouping(self) else: @@ -2822,8 +2816,9 @@ def is_comparison(self) -> bool: return operators.is_comparison(self.operator) def self_group( - self, against: OperatorType + self, against: Optional[OperatorType] ) -> Union[Grouping[OperatorExpression[_T]], OperatorExpression[_T]]: + assert against is not None if ( self.group and operators.is_precedent(self.operator, against) @@ -3707,7 +3702,7 @@ def _negate(self) -> UnaryExpression: else: return ClauseElement._negate(self) - def self_group(self, against: Optional[Callable[..., Any]] = None) -> Any: + def self_group(self, against: OperatorType) -> Any: if self.operator and operators.is_precedent(self.operator, against): return Grouping(self) else: @@ -4004,9 +3999,7 @@ def __init__( ) self.type = type_api.NULLTYPE - def self_group( - self, against: Optional[Callable[..., Any]] = None - ) -> Slice: + def self_group(self, against: Optional[OperatorType] = None) -> Slice: assert against is operator.getitem return self @@ -4030,7 +4023,7 @@ def self_group( ) -> GroupedElement: return self - def _ungroup(self) -> Select: + def _ungroup(self) -> ClauseElement: return self.element._ungroup() @@ -4174,7 +4167,18 @@ def __init__( else: self.rows = self.range_ = None - def __reduce__(self): + def __reduce__( + self, + ) -> typing_Tuple( + Type[Over[_T]], + typing_Tuple( + ColumnElement[_T], + Optional[ClauseList], + Optional[ClauseList], + Optional[typing_Tuple[int, int]], + Optional[typing_Tuple[int, int]], + ), + ): return self.__class__, ( self.element, self.partition_by, @@ -4359,7 +4363,9 @@ def __init__( self.func = func self.filter(*criterion) - def filter(self, *criterion: BinaryExpression) -> FunctionFilter: + def filter( + self, *criterion: _ColumnExpressionArgument[bool] + ) -> FunctionFilter[_T]: """Produce an additional FILTER against the function. This method adds additional criteria to the initial criteria @@ -4423,16 +4429,14 @@ def over( rows=rows, ) - def self_group( - self, against: Optional[Callable[..., Any]] = None - ) -> Grouping: + def self_group(self, against: OperatorType) -> ClauseElement: if operators.is_precedent(operators.filter_op, against): return Grouping(self) else: return self @util.memoized_property - def type(self) -> ARRAY: + def type(self) -> TypeEngine[_T]: return self.func.type @util.ro_non_memoized_property @@ -4463,8 +4467,9 @@ def _compare_name_for_result(self, other: NamedColumn[_T]) -> bool: def description(self) -> str: return self.name + # QUESTION: Union[_anonymous_label, _truncated_label] @HasMemoized.memoized_attribute - def _tq_key_label(self) -> Union[_anonymous_label, _truncated_label]: + def _tq_key_label(self) -> Optional[str]: """table qualified label based on column key. for table-bound columns this is _; @@ -4493,7 +4498,7 @@ def _render_label_in_columns_clause(self) -> bool: return True @HasMemoized.memoized_attribute - def _non_anon_label(self) -> Any: + def _non_anon_label(self) -> Optional[str]: return self.name def _gen_tq_label( @@ -4629,11 +4634,11 @@ def _render_label_in_columns_clause(self) -> bool: def _bind_param( self, - operator: Union[Callable], + operator: OperatorType, obj: Union[List[str], int, str], - type_: Optional[Any] = None, + type_: Optional[_TypeEngineArgument[_T]] = None, expanding: bool = False, - ) -> BindParameter: + ) -> BindParameter[_T]: return BindParameter( None, obj, @@ -4653,24 +4658,24 @@ def _allow_label_resolve(self) -> bool: return self.element._allow_label_resolve @property - def _order_by_label_element(self) -> Label: + def _order_by_label_element(self) -> Label[_T]: return self @HasMemoized.memoized_attribute def element(self) -> ColumnElement[_T]: return self._element.self_group(against=operators.as_) - def self_group( - self, against: Optional[OperatorType] = None - ) -> Union[AnnotatedLabel, Label[_T]]: + # QUESTION: AnnotatedLabel + def self_group(self, against: Optional[OperatorType] = None) -> Label[_T]: return self._apply_to_inner(self._element.self_group, against=against) - def _negate(self): + def _negate(self) -> Label[_T]: return self._apply_to_inner(self._element._negate) + # QUESTION: AnnotatedLabel def _apply_to_inner( self, fn: Callable[..., Any], *arg: Any, **kw: Any - ) -> Union[AnnotatedLabel, Label[_T]]: + ) -> Label[_T]: sub_element = fn(*arg, **kw) if sub_element is not self._element: return Label(self.name, sub_element, type_=self.type) @@ -4871,7 +4876,7 @@ def _render_label_in_columns_clause(self) -> bool: return self.table is not None @property - def _ddl_label(self) -> _truncated_label: + def _ddl_label(self) -> Optional[str]: return self._gen_tq_label(self.name, dedupe_on_key=False) def _compare_name_for_result( From 0de1535faea509336735b878cb8c89d38b6d2b81 Mon Sep 17 00:00:00 2001 From: jazzthief Date: Mon, 9 Jan 2023 19:52:28 +0100 Subject: [PATCH 06/20] Disallow untyped calls; fix more types --- lib/sqlalchemy/sql/elements.py | 98 +++++++++++++++++++--------------- 1 file changed, 56 insertions(+), 42 deletions(-) diff --git a/lib/sqlalchemy/sql/elements.py b/lib/sqlalchemy/sql/elements.py index 0ebba07308c..cda13ea71da 100644 --- a/lib/sqlalchemy/sql/elements.py +++ b/lib/sqlalchemy/sql/elements.py @@ -4,7 +4,6 @@ # # This module is part of SQLAlchemy and is released under # the MIT License: https://www.opensource.org/licenses/mit-license.php -# mypy: allow-untyped-calls """Core SQL expression elements, including :class:`_expression.ClauseElement`, :class:`_expression.ColumnElement`, and derived classes. @@ -85,13 +84,11 @@ from mypy_extensions import NoReturn from ._py_util import cache_anon_map - from ._py_util import prefix_anon_map from ._typing import _ColumnExpressionArgument from ._typing import _ColumnExpressionOrStrLabelArgument from ._typing import _InfoType from ._typing import _PropagateAttrsType from ._typing import _TypeEngineArgument - from .annotation import AnnotatedBooleanClauseList from .annotation import SupportsAnnotations from .base import _EntityNamespace from .cache_key import _CacheKeyTraversalType @@ -1568,9 +1565,8 @@ def shares_lineage(self, othercolumn: ColumnElement[Any]) -> bool: return bool(self.proxy_set.intersection(othercolumn.proxy_set)) - def _compare_name_for_result( - self, other: ColumnElement[Any] - ) -> Union[bool, FrozenSet[ColumnElement[_T]]]: + # QUESTION: an override returns FrozenSet[ColumnElement[_T]] + def _compare_name_for_result(self, other: ColumnElement[Any]) -> bool: """Return True if the given column element compares to this one when targeting within a result row.""" @@ -2127,7 +2123,7 @@ def _clone(self, maintain_key: bool = False, **kw: Any) -> Self: def _gen_cache_key( self, anon_map: cache_anon_map, - bindparams: Union[List[Annotated], List[BindParameter[_T]]], + bindparams: List[BindParameter[_T]], ) -> Optional[ Union[ typing_Tuple[str, Type[BindParameter[_T]]], @@ -2285,6 +2281,7 @@ def _select_iterable(self) -> _SelectIterable: _allow_label_resolve = False + # QUESTION: mypy bug @property def _is_star(self) -> bool: return self.text == "*" @@ -2791,7 +2788,9 @@ def append(self, clause: Any) -> None: def _from_objects(self) -> List[FromClause]: return list(itertools.chain(*[c._from_objects for c in self.clauses])) - def self_group(self, against: Optional[OperatorType]) -> ClauseElement: + def self_group( + self, against: Optional[OperatorType] = None + ) -> ClauseElement: assert against is not None if self.group and operators.is_precedent(self.operator, against): return Grouping(self) @@ -2816,8 +2815,8 @@ def is_comparison(self) -> bool: return operators.is_comparison(self.operator) def self_group( - self, against: Optional[OperatorType] - ) -> Union[Grouping[OperatorExpression[_T]], OperatorExpression[_T]]: + self, against: Optional[OperatorType] = None + ) -> ColumnElement[Any]: assert against is not None if ( self.group @@ -3173,9 +3172,10 @@ def or_( def _select_iterable(self) -> _SelectIterable: return (self,) + # caught AnnotatedBooleanClauseList at runtime def self_group( self, against: Optional[OperatorType] = None - ) -> Union[AnnotatedBooleanClauseList, BooleanClauseList, Grouping]: + ) -> ColumnElement[Any]: if not self.clauses: return self else: @@ -3237,10 +3237,10 @@ def _select_iterable(self) -> _SelectIterable: def _bind_param( self, operator: OperatorType, - obj: List[Tuple[int, str]], - type_: Optional[TypeEngine[_T]] = None, + obj: List[typing_Tuple[int, str]], + type_: Optional[_TypeEngineArgument[_T]] = None, expanding: bool = False, - ) -> Union[BindParameter[_T], Tuple]: + ) -> BindParameter[_T]: if expanding: return BindParameter( None, @@ -3591,7 +3591,7 @@ def __init__( element: ColumnElement[Any], operator: Optional[OperatorType] = None, modifier: Optional[OperatorType] = None, - type_: Optional[_TypeEngineArgument[_T]] = None, + type_: Optional[Union[_TypeEngineArgument[_T], Boolean]] = None, wraps_column_expression: bool = False, ): self.operator = operator @@ -3691,7 +3691,7 @@ def _order_by_label_element(self) -> Optional[Label[Any]]: def _from_objects(self) -> List[FromClause]: return self.element._from_objects - def _negate(self) -> UnaryExpression: + def _negate(self) -> ClauseElement: if self.type._type_affinity is type_api.BOOLEANTYPE._type_affinity: return UnaryExpression( self.self_group(against=operators.inv), @@ -3702,7 +3702,10 @@ def _negate(self) -> UnaryExpression: else: return ClauseElement._negate(self) - def self_group(self, against: OperatorType) -> Any: + def self_group( + self, against: Optional[OperatorType] = None + ) -> ColumnElement[Any]: + assert against is not None if self.operator and operators.is_precedent(self.operator, against): return Grouping(self) else: @@ -3756,7 +3759,9 @@ def _create_all( # operate and reverse_operate are hardwired to # dispatch onto the type comparator directly, so that we can # ensure "reversed" behavior. - def operate(self, op, *other, **kwargs): + def operate( + self, op: OperatorType, *other: Any, **kwargs: Any + ) -> ColumnElement[_T]: if not operators.is_comparison(op): raise exc.ArgumentError( "Only comparison operators may be used with ANY/ALL" @@ -3764,7 +3769,9 @@ def operate(self, op, *other, **kwargs): kwargs["reverse"] = kwargs["_any_all_expr"] = True return self.comparator.operate(operators.mirror(op), *other, **kwargs) - def reverse_operate(self, op, other, **kwargs): + def reverse_operate( + self, op: OperatorType, other: Any, **kwargs: Any + ) -> NoReturn: # comparison operators should never call reverse_operate assert not operators.is_comparison(op) raise exc.ArgumentError( @@ -4169,16 +4176,16 @@ def __init__( def __reduce__( self, - ) -> typing_Tuple( + ) -> typing_Tuple[ Type[Over[_T]], - typing_Tuple( + typing_Tuple[ ColumnElement[_T], Optional[ClauseList], Optional[ClauseList], Optional[typing_Tuple[int, int]], Optional[typing_Tuple[int, int]], - ), - ): + ], + ]: return self.__class__, ( self.element, self.partition_by, @@ -4225,7 +4232,7 @@ def _interpret_range( return lower, upper @util.memoized_property - def type(self) -> Union[NullType, String]: + def type(self) -> TypeEngine[_T]: return self.element.type @util.ro_non_memoized_property @@ -4278,18 +4285,23 @@ def __init__( *util.to_list(order_by), _literal_as_text_role=roles.ByOfRole ) - def __reduce__(self): + def __reduce__( + self, + ) -> typing_Tuple[ + Type[WithinGroup[_T]], + typing_Tuple[Union[FunctionElement[_T], ColumnElement[Any]], ...], + ]: return self.__class__, (self.element,) + ( tuple(self.order_by) if self.order_by is not None else () ) def over( self, - partition_by: ColumnClause = None, - order_by: ColumnClause = None, - range_: Optional[Tuple[int, int]] = None, - rows: Optional[Tuple[int, int]] = None, - ) -> Over: + partition_by: Optional[ColumnClause[_T]] = None, + order_by: Optional[ColumnClause[_T]] = None, + range_: Optional[typing_Tuple[Optional[int], Optional[int]]] = None, + rows: Optional[typing_Tuple[Optional[int], Optional[int]]] = None, + ) -> Over[_T]: """Produce an OVER clause against this :class:`.WithinGroup` construct. @@ -4306,7 +4318,7 @@ def over( ) @util.memoized_property - def type(self) -> String: + def type(self) -> TypeEngine[_T]: wgt = self.element.within_group_type(self) if wgt is not None: return wgt @@ -4429,7 +4441,10 @@ def over( rows=rows, ) - def self_group(self, against: OperatorType) -> ClauseElement: + def self_group( + self, against: Optional[OperatorType] = None + ) -> ColumnElement[_T]: + assert against is not None if operators.is_precedent(operators.filter_op, against): return Grouping(self) else: @@ -4458,7 +4473,7 @@ class NamedColumn(KeyedColumnElement[_T]): name: str key: str - def _compare_name_for_result(self, other: NamedColumn[_T]) -> bool: + def _compare_name_for_result(self, other: ColumnElement[Any]) -> bool: return (hasattr(other, "name") and self.name == other.name) or ( hasattr(other, "_label") and self._label == other._label ) @@ -4665,14 +4680,14 @@ def _order_by_label_element(self) -> Label[_T]: def element(self) -> ColumnElement[_T]: return self._element.self_group(against=operators.as_) - # QUESTION: AnnotatedLabel + # caught AnnotatedLabel at runtime def self_group(self, against: Optional[OperatorType] = None) -> Label[_T]: return self._apply_to_inner(self._element.self_group, against=against) - def _negate(self) -> Label[_T]: + def _negate(self) -> ColumnElement[_T]: return self._apply_to_inner(self._element._negate) - # QUESTION: AnnotatedLabel + # caught AnnotatedLabel at runtime def _apply_to_inner( self, fn: Callable[..., Any], *arg: Any, **kw: Any ) -> Label[_T]: @@ -4848,9 +4863,7 @@ def entity_namespace(self) -> Union[_EntityNamespace, NoReturn]: else: return super().entity_namespace - def _clone( - self, detect_subquery_cols: bool = False, **kw: Any - ) -> Union[ColumnClause[_T], Column[_T]]: + def _clone(self, detect_subquery_cols: bool = False, **kw: Any) -> Any: if ( detect_subquery_cols and self.table is not None @@ -4880,8 +4893,8 @@ def _ddl_label(self) -> Optional[str]: return self._gen_tq_label(self.name, dedupe_on_key=False) def _compare_name_for_result( - self, other: ColumnClause[_T] - ) -> Union[bool, FrozenSet[ColumnElement[_T]]]: + self, other: ColumnElement[Any] + ) -> Union[bool, FrozenSet[ColumnElement[Any]]]: if ( self.is_literal or self.table is None @@ -5390,7 +5403,8 @@ def __radd__(self, other: str) -> _anonymous_label: ) ) - def apply_map(self, map_: Union[cache_anon_map, prefix_anon_map]) -> str: + # caught map_: Union[cache_anon_map, prefix_anon_map] + def apply_map(self, map_: Mapping[str, Any]) -> str: if self.quote is not None: # preserve quoting only if necessary return quoted_name(self % map_, self.quote) From b119b9e3b308595a57bcd46d6ec3d114586335a5 Mon Sep 17 00:00:00 2001 From: jazzthief Date: Mon, 9 Jan 2023 19:54:37 +0100 Subject: [PATCH 07/20] Add return annotation to fix untyped calls from sql.elements --- lib/sqlalchemy/sql/base.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/lib/sqlalchemy/sql/base.py b/lib/sqlalchemy/sql/base.py index 1752a4dc1ab..2eaf0bcde77 100644 --- a/lib/sqlalchemy/sql/base.py +++ b/lib/sqlalchemy/sql/base.py @@ -185,7 +185,7 @@ def proxy_set(self) -> FrozenSet[ColumnElement[Any]]: raise NotImplementedError() @classmethod - def _create_singleton(cls): + def _create_singleton(cls) -> None: obj = object.__new__(cls) obj.__init__() # type: ignore From ef6ca74f675d0c69b49834225f76c6532e9a99b6 Mon Sep 17 00:00:00 2001 From: jazzthief Date: Mon, 9 Jan 2023 20:03:35 +0100 Subject: [PATCH 08/20] Clean up --- lib/sqlalchemy/sql/elements.py | 1 - 1 file changed, 1 deletion(-) diff --git a/lib/sqlalchemy/sql/elements.py b/lib/sqlalchemy/sql/elements.py index cda13ea71da..376244202a5 100644 --- a/lib/sqlalchemy/sql/elements.py +++ b/lib/sqlalchemy/sql/elements.py @@ -130,7 +130,6 @@ _NUMBER = Union[float, int, Decimal] _T = TypeVar("_T", bound="Any") -_O = TypeVar("_O", bound="object") _OPT = TypeVar("_OPT", bound="Any") _NT = TypeVar("_NT", bound="_NUMERIC") _NMT = TypeVar("_NMT", bound="_NUMBER") From 3a135f7714997b5c7cbc6cb91c9d67b184984310 Mon Sep 17 00:00:00 2001 From: jazzthief Date: Mon, 9 Jan 2023 20:09:56 +0100 Subject: [PATCH 09/20] Clean up comments --- lib/sqlalchemy/sql/elements.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/lib/sqlalchemy/sql/elements.py b/lib/sqlalchemy/sql/elements.py index 376244202a5..4aab76e4ae7 100644 --- a/lib/sqlalchemy/sql/elements.py +++ b/lib/sqlalchemy/sql/elements.py @@ -132,6 +132,7 @@ _T = TypeVar("_T", bound="Any") _OPT = TypeVar("_OPT", bound="Any") _NT = TypeVar("_NT", bound="_NUMERIC") + _NMT = TypeVar("_NMT", bound="_NUMBER") @@ -1564,7 +1565,6 @@ def shares_lineage(self, othercolumn: ColumnElement[Any]) -> bool: return bool(self.proxy_set.intersection(othercolumn.proxy_set)) - # QUESTION: an override returns FrozenSet[ColumnElement[_T]] def _compare_name_for_result(self, other: ColumnElement[Any]) -> bool: """Return True if the given column element compares to this one when targeting within a result row.""" @@ -2280,7 +2280,6 @@ def _select_iterable(self) -> _SelectIterable: _allow_label_resolve = False - # QUESTION: mypy bug @property def _is_star(self) -> bool: return self.text == "*" @@ -4481,7 +4480,7 @@ def _compare_name_for_result(self, other: ColumnElement[Any]) -> bool: def description(self) -> str: return self.name - # QUESTION: Union[_anonymous_label, _truncated_label] + # caught Union[_anonymous_label, _truncated_label] at runtime @HasMemoized.memoized_attribute def _tq_key_label(self) -> Optional[str]: """table qualified label based on column key. From 28f43c9c7f7d8ea972a7329dba3a3eb3c07cc6c9 Mon Sep 17 00:00:00 2001 From: jazzthief Date: Tue, 10 Jan 2023 13:03:53 +0100 Subject: [PATCH 10/20] Refine some types; clean up --- lib/sqlalchemy/sql/elements.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/lib/sqlalchemy/sql/elements.py b/lib/sqlalchemy/sql/elements.py index 4aab76e4ae7..5e6e5500546 100644 --- a/lib/sqlalchemy/sql/elements.py +++ b/lib/sqlalchemy/sql/elements.py @@ -455,7 +455,7 @@ def _constructor(self) -> Any: return self.__class__ @HasMemoized.memoized_attribute - def _cloned_set(self) -> Any: + def _cloned_set(self) -> Set[ClauseElement]: """Return the set consisting all cloned ancestors of this ClauseElement. @@ -1544,7 +1544,6 @@ def proxy_set(self) -> FrozenSet[ColumnElement[Any]]: @util.memoized_property def _expanded_proxy_set(self) -> FrozenSet[ColumnElement[Any]]: - # type: ignore [no-untyped-call] return frozenset(_expand_cloned(self.proxy_set)) def _uncached_proxy_list(self) -> List[ColumnElement[Any]]: From ccad2f082b3b0937fd4528bdab9d4b01b07ac3dd Mon Sep 17 00:00:00 2001 From: jazzthief Date: Tue, 10 Jan 2023 15:48:44 +0100 Subject: [PATCH 11/20] Remove remaining ignores; refine some types --- lib/sqlalchemy/sql/elements.py | 34 ++++++++++++---------------------- 1 file changed, 12 insertions(+), 22 deletions(-) diff --git a/lib/sqlalchemy/sql/elements.py b/lib/sqlalchemy/sql/elements.py index 5e6e5500546..e07caf4a42f 100644 --- a/lib/sqlalchemy/sql/elements.py +++ b/lib/sqlalchemy/sql/elements.py @@ -696,7 +696,7 @@ def _compile_w_cache( return compiled_sql, extracted_params, cache_hit - def __invert__(self) -> Union[ColumnElement[bool], ClauseElement]: + def __invert__(self) -> ClauseElement: # undocumented element currently used by the ORM for # relationship.contains() if hasattr(self, "negation_clause"): @@ -709,7 +709,7 @@ def _negate(self) -> ClauseElement: assert isinstance(grouped, ColumnElement) return UnaryExpression(grouped, operator=operators.inv) - def __bool__(self) -> NoReturn: + def __bool__(self) -> bool: raise TypeError("Boolean value of this clause is not defined") def __repr__(self) -> str: @@ -1874,7 +1874,7 @@ def _dedupe_anon_label_idx(self, idx: int) -> str: return self._dedupe_anon_tq_label_idx(idx) @property - def _proxy_key(self) -> Optional[str]: # type: ignore [override] + def _proxy_key(self) -> Optional[str]: wce = self.wrapped_column_expression if not wce._is_text_clause: @@ -2589,9 +2589,7 @@ def self_group( return self -class Null( # type: ignore [misc] - SingletonConstant, roles.ConstExprRole[None], ColumnElement[None] -): +class Null(SingletonConstant, roles.ConstExprRole[None], ColumnElement[None]): """Represent the NULL keyword in a SQL statement. :class:`.Null` is accessed as a constant via the @@ -2618,7 +2616,7 @@ def _instance(cls) -> Null: Null._create_singleton() -class False_( # type: ignore [misc] +class False_( SingletonConstant, roles.ConstExprRole[bool], ColumnElement[bool] ): """Represent the ``false`` keyword, or equivalent, in a SQL statement. @@ -2647,9 +2645,7 @@ def _instance(cls) -> False_: False_._create_singleton() -class True_( # type: ignore [misc] - SingletonConstant, roles.ConstExprRole[bool], ColumnElement[bool] -): +class True_(SingletonConstant, roles.ConstExprRole[bool], ColumnElement[bool]): """Represent the ``true`` keyword, or equivalent, in a SQL statement. :class:`.True_` is accessed as a constant via the @@ -3183,9 +3179,7 @@ def self_group( or_ = BooleanClauseList.or_ -class Tuple( # type: ignore [misc] - ClauseList, ColumnElement[typing_Tuple[Any, ...]] -): +class Tuple(ClauseList, ColumnElement[typing_Tuple[Any, ...]]): """Represent a SQL tuple.""" __visit_name__ = "tuple" @@ -3776,9 +3770,7 @@ def reverse_operate( ) -class AsBoolean( # type: ignore [misc] - WrapsColumnExpression[bool], UnaryExpression[bool] -): +class AsBoolean(WrapsColumnExpression[bool], UnaryExpression[bool]): inherit_cache = True def __init__( @@ -3897,7 +3889,7 @@ def _flattened_operator_clauses( ) -> typing_Tuple[ColumnElement[Any], ...]: return (self.left, self.right) - def __bool__(self) -> bool: # type: ignore [override] + def __bool__(self) -> bool: """Implement Python-side "bool" for BinaryExpression as a simple "identity" check for the left and right attributes, if the operator is "eq" or "ne". Otherwise the expression @@ -4031,7 +4023,7 @@ def _ungroup(self) -> ClauseElement: return self.element._ungroup() -class Grouping(GroupedElement, ColumnElement[_T]): # type: ignore [misc] +class Grouping(GroupedElement, ColumnElement[_T]): """Represent a grouping within a column expression""" _traverse_internals: _TraverseInternalsType = [ @@ -4568,9 +4560,7 @@ def _make_proxy( return c.key, c -class Label( # type: ignore [misc] - roles.LabeledColumnExprRole[_T], NamedColumn[_T] -): +class Label(roles.LabeledColumnExprRole[_T], NamedColumn[_T]): """Represents a column label (AS). Represent a label, as typically applied to any column-level @@ -4762,7 +4752,7 @@ def _make_proxy( return self.key, e -class ColumnClause( # type: ignore [misc] +class ColumnClause( roles.DDLReferredColumnRole, roles.LabeledColumnExprRole[_T], roles.StrAsPlainColumnRole, From 85aeecc81e309d4a7026ac71f3fe1c5c1fda9020 Mon Sep 17 00:00:00 2001 From: jazzthief Date: Tue, 10 Jan 2023 16:24:51 +0100 Subject: [PATCH 12/20] Refine more types --- lib/sqlalchemy/sql/elements.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/lib/sqlalchemy/sql/elements.py b/lib/sqlalchemy/sql/elements.py index e07caf4a42f..cec2a72ce07 100644 --- a/lib/sqlalchemy/sql/elements.py +++ b/lib/sqlalchemy/sql/elements.py @@ -696,7 +696,7 @@ def _compile_w_cache( return compiled_sql, extracted_params, cache_hit - def __invert__(self) -> ClauseElement: + def __invert__(self) -> Union[operators.Operators, ClauseElement]: # undocumented element currently used by the ORM for # relationship.contains() if hasattr(self, "negation_clause"): @@ -4016,7 +4016,7 @@ class GroupedElement(DQLDMLClauseElement): def self_group( self, against: Optional[OperatorType] = None - ) -> GroupedElement: + ) -> ClauseElement: return self def _ungroup(self) -> ClauseElement: From 85b8101e38896693d37864c973fcf36308b415bf Mon Sep 17 00:00:00 2001 From: jazzthief Date: Fri, 13 Jan 2023 11:07:48 +0100 Subject: [PATCH 13/20] More type fixes --- lib/sqlalchemy/sql/elements.py | 16 +++++++--------- 1 file changed, 7 insertions(+), 9 deletions(-) diff --git a/lib/sqlalchemy/sql/elements.py b/lib/sqlalchemy/sql/elements.py index cec2a72ce07..e543338660f 100644 --- a/lib/sqlalchemy/sql/elements.py +++ b/lib/sqlalchemy/sql/elements.py @@ -1381,7 +1381,7 @@ def _non_anon_label(self) -> Optional[str]: """ return getattr(self, "name", None) - _render_label_in_columns_clause = True + _render_label_in_columns_clause: bool = True """A flag used by select._columns_plus_names that helps to determine we are actually going to render in terms of "SELECT AS