Skip to content

Commit

Permalink
TYP: simple return types (pandas-dev#54786)
Browse files Browse the repository at this point in the history
* Return None

* Return simple types

* ruff false positive

* isort+mypy

* typo, use " for cast

* SingleArrayManager.dtype can also be a numpy dtype

* comments + test assert on CI

* wider return types at the cost of one fewer mypy ignore

* DatetimeArray reaches IntervalArray._combined

* avoid some ignores

* remove assert False
  • Loading branch information
twoertwein committed Aug 31, 2023
1 parent eafceae commit c74a071
Show file tree
Hide file tree
Showing 40 changed files with 237 additions and 145 deletions.
2 changes: 1 addition & 1 deletion pandas/_testing/_io.py
Original file line number Diff line number Diff line change
Expand Up @@ -118,7 +118,7 @@ def round_trip_localpath(writer, reader, path: str | None = None):
return obj


def write_to_compressed(compression, path, data, dest: str = "test"):
def write_to_compressed(compression, path, data, dest: str = "test") -> None:
"""
Write data to a compressed file.
Expand Down
2 changes: 1 addition & 1 deletion pandas/arrays/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@
]


def __getattr__(name: str):
def __getattr__(name: str) -> type[NumpyExtensionArray]:
if name == "PandasArray":
# GH#53694
import warnings
Expand Down
2 changes: 1 addition & 1 deletion pandas/compat/pickle_compat.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@
from collections.abc import Generator


def load_reduce(self):
def load_reduce(self) -> None:
stack = self.stack
args = stack.pop()
func = stack[-1]
Expand Down
4 changes: 2 additions & 2 deletions pandas/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -491,7 +491,7 @@ def box_with_array(request):


@pytest.fixture
def dict_subclass():
def dict_subclass() -> type[dict]:
"""
Fixture for a dictionary subclass.
"""
Expand All @@ -504,7 +504,7 @@ def __init__(self, *args, **kwargs) -> None:


@pytest.fixture
def non_dict_mapping_subclass():
def non_dict_mapping_subclass() -> type[abc.Mapping]:
"""
Fixture for a non-dictionary mapping subclass.
"""
Expand Down
20 changes: 11 additions & 9 deletions pandas/core/apply.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,9 +54,9 @@

if TYPE_CHECKING:
from collections.abc import (
Generator,
Hashable,
Iterable,
Iterator,
Sequence,
)

Expand Down Expand Up @@ -253,7 +253,7 @@ def transform(self) -> DataFrame | Series:

return result

def transform_dict_like(self, func):
def transform_dict_like(self, func) -> DataFrame:
"""
Compute transform in the case of a dict-like func
"""
Expand Down Expand Up @@ -315,7 +315,7 @@ def compute_list_like(
op_name: Literal["agg", "apply"],
selected_obj: Series | DataFrame,
kwargs: dict[str, Any],
) -> tuple[list[Hashable], list[Any]]:
) -> tuple[list[Hashable] | Index, list[Any]]:
"""
Compute agg/apply results for list-like input.
Expand All @@ -330,7 +330,7 @@ def compute_list_like(
Returns
-------
keys : list[hashable]
keys : list[Hashable] or Index
Index labels for result.
results : list
Data for result. When aggregating with a Series, this can contain any
Expand Down Expand Up @@ -370,12 +370,14 @@ def compute_list_like(
new_res = getattr(colg, op_name)(func, *args, **kwargs)
results.append(new_res)
indices.append(index)
keys = selected_obj.columns.take(indices)
# error: Incompatible types in assignment (expression has type "Any |
# Index", variable has type "list[Any | Callable[..., Any] | str]")
keys = selected_obj.columns.take(indices) # type: ignore[assignment]

return keys, results

def wrap_results_list_like(
self, keys: list[Hashable], results: list[Series | DataFrame]
self, keys: Iterable[Hashable], results: list[Series | DataFrame]
):
from pandas.core.reshape.concat import concat

Expand Down Expand Up @@ -772,7 +774,7 @@ def result_columns(self) -> Index:

@property
@abc.abstractmethod
def series_generator(self) -> Iterator[Series]:
def series_generator(self) -> Generator[Series, None, None]:
pass

@abc.abstractmethod
Expand Down Expand Up @@ -1014,7 +1016,7 @@ class FrameRowApply(FrameApply):
axis: AxisInt = 0

@property
def series_generator(self):
def series_generator(self) -> Generator[Series, None, None]:
return (self.obj._ixs(i, axis=1) for i in range(len(self.columns)))

@property
Expand Down Expand Up @@ -1075,7 +1077,7 @@ def apply_broadcast(self, target: DataFrame) -> DataFrame:
return result.T

@property
def series_generator(self):
def series_generator(self) -> Generator[Series, None, None]:
values = self.values
values = ensure_wrapped_if_datetimelike(values)
assert len(values) > 0
Expand Down
4 changes: 2 additions & 2 deletions pandas/core/arrays/arrow/extension_types.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@ def __ne__(self, other) -> bool:
def __hash__(self) -> int:
return hash((str(self), self.freq))

def to_pandas_dtype(self):
def to_pandas_dtype(self) -> PeriodDtype:
return PeriodDtype(freq=self.freq)


Expand Down Expand Up @@ -105,7 +105,7 @@ def __ne__(self, other) -> bool:
def __hash__(self) -> int:
return hash((str(self), str(self.subtype), self.closed))

def to_pandas_dtype(self):
def to_pandas_dtype(self) -> IntervalDtype:
return IntervalDtype(self.subtype.to_pandas_dtype(), self.closed)


Expand Down
2 changes: 1 addition & 1 deletion pandas/core/arrays/categorical.py
Original file line number Diff line number Diff line change
Expand Up @@ -2410,7 +2410,7 @@ def _mode(self, dropna: bool = True) -> Categorical:
# ------------------------------------------------------------------
# ExtensionArray Interface

def unique(self):
def unique(self) -> Self:
"""
Return the ``Categorical`` whose ``categories`` and ``codes`` are
unique.
Expand Down
38 changes: 22 additions & 16 deletions pandas/core/arrays/interval.py
Original file line number Diff line number Diff line change
Expand Up @@ -110,7 +110,7 @@
)


IntervalSideT = Union[TimeArrayLike, np.ndarray]
IntervalSide = Union[TimeArrayLike, np.ndarray]
IntervalOrNA = Union[Interval, float]

_interval_shared_docs: dict[str, str] = {}
Expand Down Expand Up @@ -219,8 +219,8 @@ def ndim(self) -> Literal[1]:
return 1

# To make mypy recognize the fields
_left: IntervalSideT
_right: IntervalSideT
_left: IntervalSide
_right: IntervalSide
_dtype: IntervalDtype

# ---------------------------------------------------------------------
Expand All @@ -237,8 +237,8 @@ def __new__(
data = extract_array(data, extract_numpy=True)

if isinstance(data, cls):
left: IntervalSideT = data._left
right: IntervalSideT = data._right
left: IntervalSide = data._left
right: IntervalSide = data._right
closed = closed or data.closed
dtype = IntervalDtype(left.dtype, closed=closed)
else:
Expand Down Expand Up @@ -280,8 +280,8 @@ def __new__(
@classmethod
def _simple_new(
cls,
left: IntervalSideT,
right: IntervalSideT,
left: IntervalSide,
right: IntervalSide,
dtype: IntervalDtype,
) -> Self:
result = IntervalMixin.__new__(cls)
Expand All @@ -299,7 +299,7 @@ def _ensure_simple_new_inputs(
closed: IntervalClosedType | None = None,
copy: bool = False,
dtype: Dtype | None = None,
) -> tuple[IntervalSideT, IntervalSideT, IntervalDtype]:
) -> tuple[IntervalSide, IntervalSide, IntervalDtype]:
"""Ensure correctness of input parameters for cls._simple_new."""
from pandas.core.indexes.base import ensure_index

Expand Down Expand Up @@ -1031,8 +1031,8 @@ def _concat_same_type(cls, to_concat: Sequence[IntervalArray]) -> Self:
raise ValueError("Intervals must all be closed on the same side.")
closed = closed_set.pop()

left = np.concatenate([interval.left for interval in to_concat])
right = np.concatenate([interval.right for interval in to_concat])
left: IntervalSide = np.concatenate([interval.left for interval in to_concat])
right: IntervalSide = np.concatenate([interval.right for interval in to_concat])

left, right, dtype = cls._ensure_simple_new_inputs(left, right, closed=closed)

Expand Down Expand Up @@ -1283,7 +1283,7 @@ def _format_space(self) -> str:
# Vectorized Interval Properties/Attributes

@property
def left(self):
def left(self) -> Index:
"""
Return the left endpoints of each Interval in the IntervalArray as an Index.
Expand All @@ -1303,7 +1303,7 @@ def left(self):
return Index(self._left, copy=False)

@property
def right(self):
def right(self) -> Index:
"""
Return the right endpoints of each Interval in the IntervalArray as an Index.
Expand Down Expand Up @@ -1855,11 +1855,17 @@ def isin(self, values) -> npt.NDArray[np.bool_]:
return isin(self.astype(object), values.astype(object))

@property
def _combined(self) -> IntervalSideT:
left = self.left._values.reshape(-1, 1)
right = self.right._values.reshape(-1, 1)
def _combined(self) -> IntervalSide:
# error: Item "ExtensionArray" of "ExtensionArray | ndarray[Any, Any]"
# has no attribute "reshape" [union-attr]
left = self.left._values.reshape(-1, 1) # type: ignore[union-attr]
right = self.right._values.reshape(-1, 1) # type: ignore[union-attr]
if needs_i8_conversion(left.dtype):
comb = left._concat_same_type([left, right], axis=1)
# error: Item "ndarray[Any, Any]" of "Any | ndarray[Any, Any]" has
# no attribute "_concat_same_type"
comb = left._concat_same_type( # type: ignore[union-attr]
[left, right], axis=1
)
else:
comb = np.concatenate([left, right], axis=1)
return comb
Expand Down
9 changes: 3 additions & 6 deletions pandas/core/arrays/period.py
Original file line number Diff line number Diff line change
Expand Up @@ -948,7 +948,7 @@ def _check_timedeltalike_freq_compat(self, other):
return lib.item_from_zerodim(delta)


def raise_on_incompatible(left, right):
def raise_on_incompatible(left, right) -> IncompatibleFrequency:
"""
Helper function to render a consistent error message when raising
IncompatibleFrequency.
Expand Down Expand Up @@ -1089,7 +1089,7 @@ def validate_dtype_freq(dtype, freq: timedelta | str | None) -> BaseOffset:


def validate_dtype_freq(
dtype, freq: BaseOffsetT | timedelta | str | None
dtype, freq: BaseOffsetT | BaseOffset | timedelta | str | None
) -> BaseOffsetT:
"""
If both a dtype and a freq are available, ensure they match. If only
Expand All @@ -1110,10 +1110,7 @@ def validate_dtype_freq(
IncompatibleFrequency : mismatch between dtype and freq
"""
if freq is not None:
# error: Incompatible types in assignment (expression has type
# "BaseOffset", variable has type "Union[BaseOffsetT, timedelta,
# str, None]")
freq = to_offset(freq) # type: ignore[assignment]
freq = to_offset(freq)

if dtype is not None:
dtype = pandas_dtype(dtype)
Expand Down
6 changes: 4 additions & 2 deletions pandas/core/arrays/sparse/array.py
Original file line number Diff line number Diff line change
Expand Up @@ -702,7 +702,9 @@ def npoints(self) -> int:
"""
return self.sp_index.npoints

def isna(self):
# error: Return type "SparseArray" of "isna" incompatible with return type
# "ndarray[Any, Any] | ExtensionArraySupportsAnyAll" in supertype "ExtensionArray"
def isna(self) -> Self: # type: ignore[override]
# If null fill value, we want SparseDtype[bool, true]
# to preserve the same memory usage.
dtype = SparseDtype(bool, self._null_fill_value)
Expand Down Expand Up @@ -1421,7 +1423,7 @@ def all(self, axis=None, *args, **kwargs):

return values.all()

def any(self, axis: AxisInt = 0, *args, **kwargs):
def any(self, axis: AxisInt = 0, *args, **kwargs) -> bool:
"""
Tests whether at least one of the elements evaluates to True
Expand Down
3 changes: 2 additions & 1 deletion pandas/core/arrays/string_.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,7 @@
NumpySorter,
NumpyValueArrayLike,
Scalar,
Self,
npt,
type_t,
)
Expand Down Expand Up @@ -135,7 +136,7 @@ def type(self) -> type[str]:
return str

@classmethod
def construct_from_string(cls, string):
def construct_from_string(cls, string) -> Self:
"""
Construct a StringDtype from a string.
Expand Down
4 changes: 3 additions & 1 deletion pandas/core/arrays/string_arrow.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,8 @@
npt,
)

from pandas import Series


ArrowStringScalarOrNAT = Union[str, libmissing.NAType]

Expand Down Expand Up @@ -547,7 +549,7 @@ def _cmp_method(self, other, op):
result = super()._cmp_method(other, op)
return result.to_numpy(np.bool_, na_value=False)

def value_counts(self, dropna: bool = True):
def value_counts(self, dropna: bool = True) -> Series:
from pandas import Series

result = super().value_counts(dropna)
Expand Down
Loading

0 comments on commit c74a071

Please sign in to comment.