Skip to content

Commit

Permalink
Merge pull request #4659 from Textualize/type-checking-overload
Browse files Browse the repository at this point in the history
wrap overloads
  • Loading branch information
willmcgugan committed Jun 17, 2024
2 parents 2c97dec + 68706b0 commit 0d78a58
Show file tree
Hide file tree
Showing 11 changed files with 260 additions and 206 deletions.
12 changes: 7 additions & 5 deletions src/textual/_immutable_sequence_view.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
from __future__ import annotations

from sys import maxsize
from typing import Generic, Iterator, Sequence, TypeVar, overload
from typing import TYPE_CHECKING, Generic, Iterator, Sequence, TypeVar, overload

T = TypeVar("T")

Expand All @@ -19,11 +19,13 @@ def __init__(self, wrap: Sequence[T]) -> None:
"""
self._wrap = wrap

@overload
def __getitem__(self, index: int) -> T: ...
if TYPE_CHECKING:

@overload
def __getitem__(self, index: slice) -> ImmutableSequenceView[T]: ...
@overload
def __getitem__(self, index: int) -> T: ...

@overload
def __getitem__(self, index: slice) -> ImmutableSequenceView[T]: ...

def __getitem__(self, index: int | slice) -> T | ImmutableSequenceView[T]:
return (
Expand Down
10 changes: 6 additions & 4 deletions src/textual/_node_list.py
Original file line number Diff line number Diff line change
Expand Up @@ -157,11 +157,13 @@ def __iter__(self) -> Iterator[Widget]:
def __reversed__(self) -> Iterator[Widget]:
return reversed(self._nodes)

@overload
def __getitem__(self, index: int) -> Widget: ...
if TYPE_CHECKING:

@overload
def __getitem__(self, index: slice) -> list[Widget]: ...
@overload
def __getitem__(self, index: int) -> Widget: ...

@overload
def __getitem__(self, index: slice) -> list[Widget]: ...

def __getitem__(self, index: int | slice) -> Widget | list[Widget]:
return self._nodes[index]
72 changes: 36 additions & 36 deletions src/textual/_work_decorator.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,42 +33,42 @@ class WorkerDeclarationError(Exception):
"""An error in the declaration of a worker method."""


@overload
def work(
method: Callable[FactoryParamSpec, Coroutine[None, None, ReturnType]],
*,
name: str = "",
group: str = "default",
exit_on_error: bool = True,
exclusive: bool = False,
description: str | None = None,
thread: bool = False,
) -> Callable[FactoryParamSpec, "Worker[ReturnType]"]: ...


@overload
def work(
method: Callable[FactoryParamSpec, ReturnType],
*,
name: str = "",
group: str = "default",
exit_on_error: bool = True,
exclusive: bool = False,
description: str | None = None,
thread: bool = False,
) -> Callable[FactoryParamSpec, "Worker[ReturnType]"]: ...

if TYPE_CHECKING:

@overload
def work(
*,
name: str = "",
group: str = "default",
exit_on_error: bool = True,
exclusive: bool = False,
description: str | None = None,
thread: bool = False,
) -> Decorator[..., ReturnType]: ...
@overload
def work(
method: Callable[FactoryParamSpec, Coroutine[None, None, ReturnType]],
*,
name: str = "",
group: str = "default",
exit_on_error: bool = True,
exclusive: bool = False,
description: str | None = None,
thread: bool = False,
) -> Callable[FactoryParamSpec, "Worker[ReturnType]"]: ...

@overload
def work(
method: Callable[FactoryParamSpec, ReturnType],
*,
name: str = "",
group: str = "default",
exit_on_error: bool = True,
exclusive: bool = False,
description: str | None = None,
thread: bool = False,
) -> Callable[FactoryParamSpec, "Worker[ReturnType]"]: ...

@overload
def work(
*,
name: str = "",
group: str = "default",
exit_on_error: bool = True,
exclusive: bool = False,
description: str | None = None,
thread: bool = False,
) -> Decorator[..., ReturnType]: ...


def work(
Expand Down Expand Up @@ -103,7 +103,7 @@ def decorator(
method: (
Callable[DecoratorParamSpec, ReturnType]
| Callable[DecoratorParamSpec, Coroutine[None, None, ReturnType]]
)
),
) -> Callable[DecoratorParamSpec, Worker[ReturnType]]:
"""The decorator."""

Expand Down
70 changes: 40 additions & 30 deletions src/textual/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -1680,11 +1680,15 @@ def render(self) -> RenderResult:

ExpectType = TypeVar("ExpectType", bound=Widget)

@overload
def get_child_by_id(self, id: str) -> Widget: ...
if TYPE_CHECKING:

@overload
def get_child_by_id(self, id: str, expect_type: type[ExpectType]) -> ExpectType: ...
@overload
def get_child_by_id(self, id: str) -> Widget: ...

@overload
def get_child_by_id(
self, id: str, expect_type: type[ExpectType]
) -> ExpectType: ...

def get_child_by_id(
self, id: str, expect_type: type[ExpectType] | None = None
Expand All @@ -1709,13 +1713,15 @@ def get_child_by_id(
else self.screen.get_child_by_id(id, expect_type)
)

@overload
def get_widget_by_id(self, id: str) -> Widget: ...
if TYPE_CHECKING:

@overload
def get_widget_by_id(
self, id: str, expect_type: type[ExpectType]
) -> ExpectType: ...
@overload
def get_widget_by_id(self, id: str) -> Widget: ...

@overload
def get_widget_by_id(
self, id: str, expect_type: type[ExpectType]
) -> ExpectType: ...

def get_widget_by_id(
self, id: str, expect_type: type[ExpectType] | None = None
Expand Down Expand Up @@ -2044,21 +2050,23 @@ def _replace_screen(self, screen: Screen) -> Screen:
self.log.system(f"{screen} REMOVED")
return screen

@overload
def push_screen(
self,
screen: Screen[ScreenResultType] | str,
callback: ScreenResultCallbackType[ScreenResultType] | None = None,
wait_for_dismiss: Literal[False] = False,
) -> AwaitMount: ...
if TYPE_CHECKING:

@overload
def push_screen(
self,
screen: Screen[ScreenResultType] | str,
callback: ScreenResultCallbackType[ScreenResultType] | None = None,
wait_for_dismiss: Literal[True] = True,
) -> asyncio.Future[ScreenResultType]: ...
@overload
def push_screen(
self,
screen: Screen[ScreenResultType] | str,
callback: ScreenResultCallbackType[ScreenResultType] | None = None,
wait_for_dismiss: Literal[False] = False,
) -> AwaitMount: ...

@overload
def push_screen(
self,
screen: Screen[ScreenResultType] | str,
callback: ScreenResultCallbackType[ScreenResultType] | None = None,
wait_for_dismiss: Literal[True] = True,
) -> asyncio.Future[ScreenResultType]: ...

def push_screen(
self,
Expand Down Expand Up @@ -2120,13 +2128,15 @@ def push_screen(
else:
return await_mount

@overload
async def push_screen_wait(
self, screen: Screen[ScreenResultType]
) -> ScreenResultType: ...
if TYPE_CHECKING:

@overload
async def push_screen_wait(
self, screen: Screen[ScreenResultType]
) -> ScreenResultType: ...

@overload
async def push_screen_wait(self, screen: str) -> Any: ...
@overload
async def push_screen_wait(self, screen: str) -> Any: ...

async def push_screen_wait(
self, screen: Screen[ScreenResultType] | str
Expand Down
30 changes: 17 additions & 13 deletions src/textual/cache.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@

from __future__ import annotations

from typing import Dict, Generic, KeysView, TypeVar, overload
from typing import TYPE_CHECKING, Dict, Generic, KeysView, TypeVar, overload

CacheKey = TypeVar("CacheKey")
CacheValue = TypeVar("CacheValue")
Expand Down Expand Up @@ -127,13 +127,15 @@ def set(self, key: CacheKey, value: CacheValue) -> None:

__setitem__ = set

@overload
def get(self, key: CacheKey) -> CacheValue | None: ...
if TYPE_CHECKING:

@overload
def get(
self, key: CacheKey, default: DefaultValue
) -> CacheValue | DefaultValue: ...
@overload
def get(self, key: CacheKey) -> CacheValue | None: ...

@overload
def get(
self, key: CacheKey, default: DefaultValue
) -> CacheValue | DefaultValue: ...

def get(
self, key: CacheKey, default: DefaultValue | None = None
Expand Down Expand Up @@ -267,13 +269,15 @@ def set(self, key: CacheKey, value: CacheValue) -> None:

__setitem__ = set

@overload
def get(self, key: CacheKey) -> CacheValue | None: ...
if TYPE_CHECKING:

@overload
def get(
self, key: CacheKey, default: DefaultValue
) -> CacheValue | DefaultValue: ...
@overload
def get(self, key: CacheKey) -> CacheValue | None: ...

@overload
def get(
self, key: CacheKey, default: DefaultValue
) -> CacheValue | DefaultValue: ...

def get(
self, key: CacheKey, default: DefaultValue | None = None
Expand Down
50 changes: 30 additions & 20 deletions src/textual/css/query.py
Original file line number Diff line number Diff line change
Expand Up @@ -145,11 +145,13 @@ def __iter__(self) -> Iterator[QueryType]:
def __reversed__(self) -> Iterator[QueryType]:
return reversed(self.nodes)

@overload
def __getitem__(self, index: int) -> QueryType: ...
if TYPE_CHECKING:

@overload
def __getitem__(self, index: slice) -> list[QueryType]: ...
@overload
def __getitem__(self, index: int) -> QueryType: ...

@overload
def __getitem__(self, index: slice) -> list[QueryType]: ...

def __getitem__(self, index: int | slice) -> QueryType | list[QueryType]:
return self.nodes[index]
Expand Down Expand Up @@ -208,11 +210,13 @@ def exclude(self, selector: str) -> DOMQuery[QueryType]:
parent=self,
)

@overload
def first(self) -> QueryType: ...
if TYPE_CHECKING:

@overload
def first(self) -> QueryType: ...

@overload
def first(self, expect_type: type[ExpectType]) -> ExpectType: ...
@overload
def first(self, expect_type: type[ExpectType]) -> ExpectType: ...

def first(
self, expect_type: type[ExpectType] | None = None
Expand Down Expand Up @@ -242,11 +246,13 @@ def first(
else:
raise NoMatches(f"No nodes match {self!r} on {self.node!r}")

@overload
def only_one(self) -> QueryType: ...
if TYPE_CHECKING:

@overload
def only_one(self) -> QueryType: ...

@overload
def only_one(self, expect_type: type[ExpectType]) -> ExpectType: ...
@overload
def only_one(self, expect_type: type[ExpectType]) -> ExpectType: ...

def only_one(
self, expect_type: type[ExpectType] | None = None
Expand Down Expand Up @@ -287,11 +293,13 @@ def only_one(
pass
return the_one

@overload
def last(self) -> QueryType: ...
if TYPE_CHECKING:

@overload
def last(self, expect_type: type[ExpectType]) -> ExpectType: ...
@overload
def last(self) -> QueryType: ...

@overload
def last(self, expect_type: type[ExpectType]) -> ExpectType: ...

def last(
self, expect_type: type[ExpectType] | None = None
Expand All @@ -318,11 +326,13 @@ def last(
)
return last

@overload
def results(self) -> Iterator[QueryType]: ...
if TYPE_CHECKING:

@overload
def results(self) -> Iterator[QueryType]: ...

@overload
def results(self, filter_type: type[ExpectType]) -> Iterator[ExpectType]: ...
@overload
def results(self, filter_type: type[ExpectType]) -> Iterator[ExpectType]: ...

def results(
self, filter_type: type[ExpectType] | None = None
Expand Down
10 changes: 6 additions & 4 deletions src/textual/document/_document.py
Original file line number Diff line number Diff line change
Expand Up @@ -182,11 +182,13 @@ def start(self) -> Location:
def end(self) -> Location:
"""Returns the location of the end of the document."""

@overload
def __getitem__(self, line_index: int) -> str: ...
if TYPE_CHECKING:

@overload
def __getitem__(self, line_index: slice) -> list[str]: ...
@overload
def __getitem__(self, line_index: int) -> str: ...

@overload
def __getitem__(self, line_index: slice) -> list[str]: ...

@abstractmethod
def __getitem__(self, line_index: int | slice) -> str | list[str]:
Expand Down
Loading

0 comments on commit 0d78a58

Please sign in to comment.