Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions pandas_dataclasses/specs.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
AnyField,
DataClass,
FType,
deannotate,
get_annotated,
get_dtype,
get_ftype,
get_name,
Expand Down Expand Up @@ -154,5 +154,5 @@ def get_fieldspec(field: AnyField) -> Optional[AnyFieldSpec]:
return ScalarFieldSpec(
type=ftype.value,
name=name,
data=ScalarSpec(deannotate(field.type), field.default),
data=ScalarSpec(get_annotated(field.type), field.default),
)
81 changes: 41 additions & 40 deletions pandas_dataclasses/typing.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
Hashable,
Iterator,
Optional,
Type,
Tuple,
TypeVar,
Union,
)
Expand Down Expand Up @@ -42,12 +42,6 @@
THashable = TypeVar("THashable", bound=Hashable)


class Collection(Collection[TCovariant], Protocol):
"""Type hint equivalent to typing.Collection."""

pass


class DataClass(Protocol):
"""Type hint for dataclass objects."""

Expand All @@ -72,6 +66,14 @@ class FType(Enum):
OTHER = "other"
"""Annotation for other fields."""

@classmethod
def annotates(cls, tp: Any) -> bool:
"""Check if any ftype annotates a type hint."""
if get_origin(tp) is not Annotated:
return False

return any(isinstance(arg, cls) for arg in get_args(tp))


# type hints (public)
Attr = Annotated[T, FType.ATTR]
Expand All @@ -92,44 +94,48 @@ class FType(Enum):

# runtime functions
def deannotate(tp: Any) -> Any:
"""Recursively remove annotations from a type hint."""
"""Recursively remove annotations in a type hint."""

class Temporary:
__annotations__ = dict(type=tp)

return get_type_hints(Temporary)["type"]


def get_annotations(tp: Any) -> Iterator[Any]:
"""Extract all annotations from a type hint."""
def find_annotated(tp: Any) -> Iterator[Any]:
"""Generate all annotated types in a type hint."""
args = get_args(tp)

if get_origin(tp) is Annotated:
yield from get_annotations(args[0])
yield from args[1:]
yield tp
yield from find_annotated(args[0])
else:
yield from chain(*map(get_annotations, args))
yield from chain(*map(find_annotated, args))


def get_collections(tp: Any) -> Iterator[Type[Collection[Any]]]:
"""Extract all collection types from a type hint."""
args = get_args(tp)
def get_annotated(tp: Any) -> Any:
"""Extract the first ftype-annotated type."""
for annotated in filter(FType.annotates, find_annotated(tp)):
return deannotate(annotated)

if get_origin(tp) is Collection:
yield tp
else:
yield from chain(*map(get_collections, args))
raise TypeError("Could not find any ftype-annotated type.")


def get_annotations(tp: Any) -> Tuple[Any, ...]:
"""Extract annotations of the first ftype-annotated type."""
for annotated in filter(FType.annotates, find_annotated(tp)):
return get_args(annotated)[1:]

raise TypeError("Could not find any ftype-annotated type.")


def get_dtype(tp: Any) -> Optional[AnyDType]:
"""Extract a dtype (most outer data type) from a type hint."""
"""Extract a NumPy or pandas data type."""
try:
collection = list(get_collections(tp))[-1]
except IndexError:
dtype = get_args(get_annotated(tp))[1]
except TypeError:
raise TypeError(f"Could not find any dtype in {tp!r}.")

dtype = get_args(collection)[0]

if dtype is Any or dtype is type(None):
return

Expand All @@ -140,21 +146,16 @@ def get_dtype(tp: Any) -> Optional[AnyDType]:


def get_ftype(tp: Any, default: FType = FType.OTHER) -> FType:
"""Extract an ftype (most outer FType) from a type hint."""
for annotation in reversed(list(get_annotations(tp))):
if isinstance(annotation, FType):
return annotation

return default
"""Extract an ftype if found or return given default."""
try:
return get_annotations(tp)[0]
except (IndexError, TypeError):
return default


def get_name(tp: Any, default: Hashable = None) -> Hashable:
"""Extract a name (most outer hashable) from a type hint."""
for annotation in reversed(list(get_annotations(tp))):
if isinstance(annotation, FType):
continue

if isinstance(annotation, Hashable):
return annotation

return default
"""Extract a name if found or return given default."""
try:
return get_annotations(tp)[1]
except (IndexError, TypeError):
return default
65 changes: 43 additions & 22 deletions tests/test_typing.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,17 @@
# standard library
from typing import Any, Optional, Union
from typing import Any, Union


# dependencies
import numpy as np
import pandas as pd
from pytest import mark
from typing_extensions import Annotated, Literal
from typing_extensions import Annotated as Ann
from typing_extensions import Literal as L
from pandas_dataclasses.typing import (
Attr,
Data,
FType,
Index,
Name,
get_dtype,
Expand All @@ -22,37 +24,56 @@
testdata_dtype = [
(Data[Any], None),
(Data[None], None),
(Data[int], np.dtype("int64")),
(Data[Literal["i8"]], np.dtype("int64")),
(Data[Literal["boolean"]], pd.BooleanDtype()),
(Data[int], np.dtype("i8")),
(Data[L["i8"]], np.dtype("i8")),
(Data[L["boolean"]], pd.BooleanDtype()),
(Data[L["category"]], pd.CategoricalDtype()),
(Index[Any], None),
(Index[None], None),
(Index[int], np.dtype("int64")),
(Index[Literal["i8"]], np.dtype("int64")),
(Index[Literal["boolean"]], pd.BooleanDtype()),
(Optional[Data[float]], np.dtype("float64")),
(Optional[Index[float]], np.dtype("float64")),
(Union[Data[float], str], np.dtype("float64")),
(Union[Index[float], str], np.dtype("float64")),
(Index[int], np.dtype("i8")),
(Index[L["i8"]], np.dtype("i8")),
(Index[L["boolean"]], pd.BooleanDtype()),
(Index[L["category"]], pd.CategoricalDtype()),
(Ann[Data[float], "data"], np.dtype("f8")),
(Ann[Index[float], "index"], np.dtype("f8")),
(Union[Ann[Data[float], "data"], Ann[Any, "any"]], np.dtype("f8")),
(Union[Ann[Index[float], "index"], Ann[Any, "any"]], np.dtype("f8")),
]

testdata_ftype = [
(Attr[Any], "attr"),
(Data[Any], "data"),
(Index[Any], "index"),
(Name[Any], "name"),
(Any, "other"),
(Attr[Any], FType.ATTR),
(Data[Any], FType.DATA),
(Index[Any], FType.INDEX),
(Name[Any], FType.NAME),
(Any, FType.OTHER),
(Ann[Attr[Any], "attr"], FType.ATTR),
(Ann[Data[Any], "data"], FType.DATA),
(Ann[Index[Any], "index"], FType.INDEX),
(Ann[Name[Any], "name"], FType.NAME),
(Ann[Any, "other"], FType.OTHER),
(Union[Ann[Attr[Any], "attr"], Ann[Any, "any"]], FType.ATTR),
(Union[Ann[Data[Any], "data"], Ann[Any, "any"]], FType.DATA),
(Union[Ann[Index[Any], "index"], Ann[Any, "any"]], FType.INDEX),
(Union[Ann[Name[Any], "name"], Ann[Any, "any"]], FType.NAME),
(Union[Ann[Any, "other"], Ann[Any, "any"]], FType.OTHER),
]

testdata_name = [
(Attr[Any], None),
(Data[Any], None),
(Index[Any], None),
(Name[Any], None),
(Annotated[Attr[Any], "attr"], "attr"),
(Annotated[Data[Any], "data"], "data"),
(Annotated[Index[Any], "index"], "index"),
(Annotated[Name[Any], "name"], "name"),
(Any, None),
(Ann[Attr[Any], "attr"], "attr"),
(Ann[Data[Any], "data"], "data"),
(Ann[Index[Any], "index"], "index"),
(Ann[Name[Any], "name"], "name"),
(Ann[Any, "other"], None),
(Union[Ann[Attr[Any], "attr"], Ann[Any, "any"]], "attr"),
(Union[Ann[Data[Any], "data"], Ann[Any, "any"]], "data"),
(Union[Ann[Index[Any], "index"], Ann[Any, "any"]], "index"),
(Union[Ann[Name[Any], "name"], Ann[Any, "any"]], "name"),
(Union[Ann[Any, "other"], Ann[Any, "any"]], None),
]


Expand All @@ -64,7 +85,7 @@ def test_get_dtype(tp: Any, dtype: Any) -> None:

@mark.parametrize("tp, ftype", testdata_ftype)
def test_get_ftype(tp: Any, ftype: Any) -> None:
assert get_ftype(tp).value == ftype
assert get_ftype(tp) is ftype


@mark.parametrize("tp, name", testdata_name)
Expand Down