Skip to content

Commit

Permalink
Enhanced typing of subscribe()
Browse files Browse the repository at this point in the history
This PR adds type var `T_Event` and uses it to type the `subscribe()` interfaces.

Fixes an issue where subscribing a handler would raise a type error if the handler is typed to receive a subclass of `Event`.

Allows type checkers to verify that the handler passed to `subscribe()` can support the event types it is assigned to handle.

Closes #846
  • Loading branch information
peterschutt committed Jan 13, 2024
1 parent 5bb4d81 commit dec885a
Show file tree
Hide file tree
Showing 5 changed files with 16 additions and 12 deletions.
4 changes: 3 additions & 1 deletion src/apscheduler/_events.py
Expand Up @@ -3,7 +3,7 @@
from datetime import datetime, timezone
from functools import partial
from traceback import format_tb
from typing import Any
from typing import Any, TypeVar
from uuid import UUID

import attrs
Expand All @@ -14,6 +14,8 @@
from ._structures import Job, JobResult
from ._utils import qualified_name

T_Event = TypeVar("T_Event", bound="Event")


@attrs.define(kw_only=True, frozen=True)
class Event:
Expand Down
5 changes: 3 additions & 2 deletions src/apscheduler/_schedulers/async_.py
Expand Up @@ -38,6 +38,7 @@
SchedulerStarted,
SchedulerStopped,
ScheduleUpdated,
T_Event,
)
from .._exceptions import (
CallableLookupError,
Expand Down Expand Up @@ -216,8 +217,8 @@ async def cleanup(self) -> None:

def subscribe(
self,
callback: Callable[[Event], Any],
event_types: type[Event] | Iterable[type[Event]] | None = None,
callback: Callable[[T_Event], Any],
event_types: type[T_Event] | Iterable[type[T_Event]] | None = None,
*,
one_shot: bool = False,
is_async: bool = True,
Expand Down
7 changes: 4 additions & 3 deletions src/apscheduler/_schedulers/sync.py
Expand Up @@ -15,7 +15,8 @@

from anyio.from_thread import BlockingPortal, start_blocking_portal

from .. import Event, current_scheduler
from .. import current_scheduler
from .._events import T_Event
from .._enums import CoalescePolicy, ConflictPolicy, RunState, SchedulerRole
from .._structures import Job, JobResult, Schedule, Task
from .._utils import UnsetValue, unset
Expand Down Expand Up @@ -157,8 +158,8 @@ def cleanup(self) -> None:

def subscribe(
self,
callback: Callable[[Event], Any],
event_types: Iterable[type[Event]] | None = None,
callback: Callable[[T_Event], Any],
event_types: Iterable[type[T_Event]] | None = None,
*,
one_shot: bool = False,
) -> Subscription:
Expand Down
6 changes: 3 additions & 3 deletions src/apscheduler/abc.py
Expand Up @@ -15,7 +15,7 @@

if TYPE_CHECKING:
from ._enums import ConflictPolicy
from ._events import Event
from ._events import Event, T_Event
from ._structures import Job, JobResult, Schedule, Task


Expand Down Expand Up @@ -135,8 +135,8 @@ async def publish_local(self, event: Event) -> None:
@abstractmethod
def subscribe(
self,
callback: Callable[[Event], Any],
event_types: Iterable[type[Event]] | None = None,
callback: Callable[[T_Event], Any],
event_types: Iterable[type[T_Event]] | None = None,
*,
is_async: bool = True,
one_shot: bool = False,
Expand Down
6 changes: 3 additions & 3 deletions src/apscheduler/eventbrokers/base.py
Expand Up @@ -11,7 +11,7 @@
from anyio.abc import TaskGroup

from .. import _events
from .._events import Event
from .._events import Event, T_Event
from .._exceptions import DeserializationError
from .._retry import RetryMixin
from ..abc import EventBroker, Serializer, Subscription
Expand Down Expand Up @@ -47,8 +47,8 @@ async def start(self, exit_stack: AsyncExitStack, logger: Logger) -> None:

def subscribe(
self,
callback: Callable[[Event], Any],
event_types: Iterable[type[Event]] | None = None,
callback: Callable[[T_Event], Any],
event_types: Iterable[type[T_Event]] | None = None,
*,
is_async: bool = True,
one_shot: bool = False,
Expand Down

0 comments on commit dec885a

Please sign in to comment.