diff --git a/tests/core/test_live_streaming.py b/tests/core/test_live_streaming.py index d012a43..ef0e893 100644 --- a/tests/core/test_live_streaming.py +++ b/tests/core/test_live_streaming.py @@ -1,4 +1,4 @@ -"""Deterministic tests for SignalR live streaming helpers.""" +"""Deterministic tests for the live streaming session abstraction.""" from __future__ import annotations @@ -8,7 +8,7 @@ import types from typing import Any, Callable, Dict, List, Tuple -if "httpx" not in sys.modules: # pragma: no cover - import shim for optional dependency +if "httpx" not in sys.modules: # pragma: no cover - optional dependency shim httpx_stub = types.ModuleType("httpx") class _StubResponse: @@ -44,6 +44,27 @@ def close(self) -> None: # noqa: D401 - interface compatibility from toptek.core import live +class DummyGateway: + """Gateway double that only tracks auth header refreshes.""" + + def __init__(self) -> None: + self._counter = 0 + self.base_url = "https://example.com/api" + self.auth_requests: List[str] = [] + + def auth_headers(self) -> Dict[str, str]: + token = f"token-{self._counter}" + self._counter += 1 + self.auth_requests.append(token) + return {"Authorization": token} + + def search_open_orders(self, payload: Dict[str, Any]) -> Dict[str, Any]: + return {"orders": payload} + + def search_positions(self, payload: Dict[str, Any]) -> Dict[str, Any]: + return {"positions": payload} + + class DummySignalRConnection: """Test double that mimics the minimal SignalR hub API surface.""" @@ -85,7 +106,7 @@ def off(self, event: str, identifier: Any | None = None) -> None: self.remove_listener(event, identifier) def send(self, method: str, args: List[Any]) -> None: - self.sent.append((method, args)) + self.sent.append((method, list(args))) def on_open(self, callback: Callable[[], None]) -> None: self._open_callbacks.append(callback) @@ -107,7 +128,7 @@ def emit(self, event: str, payload: Any) -> None: class DummyHubConnectionBuilder: - """Builder double compatible with :func:`connect_market_hub`.""" + """Builder double compatible with :class:`GatewayStreamingSession`.""" instances: List["DummyHubConnectionBuilder"] = [] @@ -119,7 +140,7 @@ def __init__(self) -> None: def with_url(self, url: str, options: Dict[str, Any] | None = None) -> "DummyHubConnectionBuilder": self.url = url - self.options = options + self.options = options or {} return self def build(self) -> DummySignalRConnection: @@ -131,88 +152,114 @@ def reset_builder_instances() -> None: DummyHubConnectionBuilder.instances.clear() -def test_connect_market_hub_merges_headers_and_closes(monkeypatch: pytest.MonkeyPatch) -> None: +def test_streaming_session_fanout_and_resubscribe(monkeypatch: pytest.MonkeyPatch) -> None: monkeypatch.setattr(live, "_require_signalr_builder", lambda: DummyHubConnectionBuilder) - opened: List[bool] = [] - closed: List[Any] = [] + gateway = DummyGateway() + session = live.GatewayStreamingSession(gateway) - handle = live.connect_market_hub( - "https://example.com/api", - hub_path="stream", - headers={"Authorization": "token"}, - options={"headers": {"User-Agent": "ProjectX"}}, - on_open=lambda: opened.append(True), - on_close=lambda exc: closed.append(exc), + ticker_events: List[Tuple[str, Any]] = [] + bar_events: List[Tuple[str, str, Any]] = [] + depth_events: List[Tuple[str, Any]] = [] + order_events: List[Tuple[str, Any]] = [] + position_events: List[Tuple[str, Any]] = [] + trade_events: List[Tuple[str, Any]] = [] + account_events: List[Any] = [] + + ticker_handle = session.market.subscribe_ticker( + "ES=F", lambda symbol, payload: ticker_events.append((symbol, payload)) + ) + session.market.subscribe_bars( + "NQ=F", + "1m", + lambda symbol, timeframe, payload: bar_events.append((symbol, timeframe, payload)), + ) + session.market.subscribe_depth( + "CL=F", + lambda symbol, payload: depth_events.append((symbol, payload)), ) - assert isinstance(handle, live.HubConnectionHandle) - builder = DummyHubConnectionBuilder.instances[-1] - assert builder.url == "https://example.com/api/stream" - assert builder.options == { - "headers": {"User-Agent": "ProjectX", "Authorization": "token"} - } + session.user.subscribe_orders( + "ACCT1", lambda account, payload: order_events.append((account, payload)) + ) + session.user.subscribe_positions( + "ACCT1", lambda account, payload: position_events.append((account, payload)) + ) + session.user.subscribe_trades( + "ACCT1", lambda account, payload: trade_events.append((account, payload)) + ) + session.user.subscribe_accounts(lambda payload: account_events.append(payload)) - connection = handle.connection - assert connection.started is True + assert len(DummyHubConnectionBuilder.instances) == 2 - connection.trigger_open() - assert opened == [True] + market_builder, user_builder = DummyHubConnectionBuilder.instances + assert market_builder.url == "https://example.com/api/marketHub" + assert user_builder.url == "https://example.com/api/userHub" - connection.trigger_close(None) - assert closed == [None] + market_connection = market_builder.connection + user_connection = user_builder.connection + assert market_connection.started is True + assert user_connection.started is True - handle.close() - assert connection.stopped is True + market_connection.trigger_open() + user_connection.trigger_open() + market_connection.emit("ticker_update", {"bid": 1}) + market_connection.emit("bar_update", {"close": 4100}) + market_connection.emit("depth_update", {"levels": []}) + user_connection.emit("order_update", {"id": 1}) + user_connection.emit("position_update", {"symbol": "ES=F"}) + user_connection.emit("trade_update", {"qty": 2}) + user_connection.emit("account_update", {"margin": 1000}) -def test_subscribe_ticker_dispatch_and_unsubscribe() -> None: - connection = DummySignalRConnection() - events: List[Tuple[str, Any]] = [] + assert ticker_events == [("ES=F", {"bid": 1})] + assert bar_events == [("NQ=F", "1m", {"close": 4100})] + assert depth_events == [("CL=F", {"levels": []})] + assert order_events == [("ACCT1", {"id": 1})] + assert position_events == [("ACCT1", {"symbol": "ES=F"})] + assert trade_events == [("ACCT1", {"qty": 2})] + assert account_events == [{"margin": 1000}] - handle = live.subscribe_ticker( - connection, - "ES=F", - lambda symbol, payload: events.append((symbol, payload)), - event="ticker", - ) + sent_methods = [method for method, _ in market_connection.sent] + assert sent_methods.count("SubscribeTicker") == 1 + assert sent_methods.count("SubscribeBars") == 1 + assert sent_methods.count("SubscribeDepth") == 1 - assert handle.connection is connection - assert connection.sent == [("SubscribeTicker", ["ES=F"])] + market_connection.trigger_close(None) - connection.emit("ticker", {"bid": 1}) - assert events == [("ES=F", {"bid": 1})] + assert len(DummyHubConnectionBuilder.instances) == 3 + reconnect_builder = DummyHubConnectionBuilder.instances[-1] + new_market_connection = reconnect_builder.connection + assert reconnect_builder.options == {"headers": {"Authorization": "token-3"}} - handle.unsubscribe() - assert ("UnsubscribeTicker", ["ES=F"]) in connection.sent + new_market_connection.trigger_open() + resubscribe_calls = [method for method, _ in new_market_connection.sent] + assert resubscribe_calls.count("SubscribeTicker") == 1 + assert resubscribe_calls.count("SubscribeBars") == 1 + assert resubscribe_calls.count("SubscribeDepth") == 1 - connection.emit("ticker", {"bid": 2}) - assert events == [("ES=F", {"bid": 1})] + new_market_connection.emit("ticker_update", {"bid": 2}) + assert ticker_events[-1] == ("ES=F", {"bid": 2}) + ticker_handle.close() + assert ("UnsubscribeTicker", ["ES=F"]) in new_market_connection.sent -def test_subscribe_bars_uses_handle_and_timeframe() -> None: - handle = live.HubConnectionHandle(DummySignalRConnection()) - events: List[Tuple[str, str, Any]] = [] + session.close() + assert new_market_connection.stopped is True + assert user_connection.stopped is True - subscription = live.subscribe_bars( - handle, - "NQ=F", - "1m", - lambda symbol, timeframe, payload: events.append((symbol, timeframe, payload)), - event="bars", - ) + assert gateway.auth_requests[:2] == ["token-0", "token-1"] - connection = handle.connection - assert connection.sent == [("SubscribeBars", ["NQ=F", "1m"])] - connection.emit("bars", {"close": 4100}) - assert events == [("NQ=F", "1m", {"close": 4100})] +def test_poll_helpers_use_gateway() -> None: + gateway = DummyGateway() + context = live.ExecutionContext(gateway=gateway, account_id="ACCT2") - subscription.unsubscribe() - assert ("UnsubscribeBars", ["NQ=F", "1m"]) in connection.sent + orders = live.poll_open_orders(context) + positions = live.poll_positions(context) - connection.emit("bars", {"close": 4200}) - assert events == [("NQ=F", "1m", {"close": 4100})] + assert orders == {"orders": {"accountId": "ACCT2"}} + assert positions == {"positions": {"accountId": "ACCT2"}} def test_utils_module_behaviour(tmp_path: Path, monkeypatch: pytest.MonkeyPatch) -> None: @@ -289,3 +336,4 @@ def _version_with_error(package: str) -> str: message = exc.value.args[0] assert "Missing packages" in message assert "Version mismatches" in message + diff --git a/tests/gui/test_ui_live_tab_state.py b/tests/gui/test_ui_live_tab_state.py index e04a465..fa810d0 100644 --- a/tests/gui/test_ui_live_tab_state.py +++ b/tests/gui/test_ui_live_tab_state.py @@ -7,6 +7,7 @@ import pytest tk = pytest.importorskip("tkinter") +pytest.importorskip("pandas") from tkinter import ttk # noqa: E402 from toptek.ui.live_tab import LiveTab # noqa: E402 diff --git a/toptek/core/gateway.py b/toptek/core/gateway.py index dc75c08..baf6713 100644 --- a/toptek/core/gateway.py +++ b/toptek/core/gateway.py @@ -83,6 +83,21 @@ def _headers(self) -> Dict[str, str]: "Content-Type": "application/json", } + @property + def base_url(self) -> str: + """Expose the configured base URL for downstream consumers.""" + + return self._config.base_url + + def auth_headers(self) -> Dict[str, str]: + """Return authorization headers, refreshing the JWT if required.""" + + if not self._token: + self.login() + else: + self._validate() + return dict(self._headers) + def _request(self, endpoint: str, payload: Dict[str, Any]) -> Dict[str, Any]: """Send a POST request with automatic token validation.""" diff --git a/toptek/core/live.py b/toptek/core/live.py index a115f4f..ae91f96 100644 --- a/toptek/core/live.py +++ b/toptek/core/live.py @@ -1,9 +1,11 @@ -"""Live trading utilities with optional SignalR streaming helpers.""" +"""Live trading streaming abstractions for ProjectX hubs.""" from __future__ import annotations +from collections import defaultdict from dataclasses import dataclass -from typing import Any, Callable, Dict, MutableMapping, Optional, Sequence +import threading +from typing import Any, Callable, Dict, Iterable, MutableMapping, Optional, Sequence, Tuple from .gateway import ProjectXGateway @@ -23,35 +25,19 @@ class ExecutionContext: @dataclass -class HubConnectionHandle: - """Wrapper around a SignalR hub connection with a close helper.""" +class HubSubscriptionHandle: + """Represents an active hub subscription that can be torn down.""" - connection: Any + _subscription: "_Subscription | None" def close(self) -> None: - """Stop the underlying connection if it exposes a stop method.""" + """Close the underlying subscription.""" - if hasattr(self.connection, "stop"): - self.connection.stop() + if self._subscription is not None: + self._subscription.close() + self._subscription = None - -@dataclass -class SubscriptionHandle: - """Represents an active SignalR subscription that can be torn down.""" - - connection: Any - event: str - handler: Callable[[Any], None] - handler_token: Any - unsubscribe_method: Optional[str] - unsubscribe_payload: Sequence[Any] - - def unsubscribe(self) -> None: - """Detach the handler and propagate an unsubscribe message.""" - - _remove_listener(self.connection, self.event, self.handler, self.handler_token) - if self.unsubscribe_method: - self.connection.send(self.unsubscribe_method, list(self.unsubscribe_payload)) + unsubscribe = close def poll_open_orders(context: ExecutionContext) -> Dict[str, object]: @@ -66,96 +52,448 @@ def poll_positions(context: ExecutionContext) -> Dict[str, object]: return context.gateway.search_positions({"accountId": context.account_id}) -def connect_market_hub( - base_url: str, - *, - hub_path: str = "marketHub", - headers: Optional[MutableMapping[str, str]] = None, - options: Optional[MutableMapping[str, Any]] = None, - auto_start: bool = True, - on_open: Optional[Callable[[], None]] = None, - on_close: Optional[Callable[[Optional[Exception]], None]] = None, -) -> HubConnectionHandle: - """Create and optionally start a SignalR hub connection.""" +class GatewayStreamingSession: + """Manage live SignalR hubs bound to a :class:`ProjectXGateway`.""" + + def __init__( + self, + gateway: ProjectXGateway, + *, + market_hub_path: str = "marketHub", + user_hub_path: str = "userHub", + reconnect_delay: float = 0.0, + connection_builder: Optional[Callable[[], Any]] = None, + ) -> None: + self._gateway = gateway + self._reconnect_delay = reconnect_delay + self._builder_factory = connection_builder or _require_signalr_builder + self.market = MarketHub(self, market_hub_path) + self.user = UserHub(self, user_hub_path) - builder_cls = _require_signalr_builder() - normalized_url = _join_url(base_url, hub_path) - connection_options = _merge_options(options, headers) + def close(self) -> None: + """Close both hub connections.""" + + self.market.close() + self.user.close() + + # ------------------------------------------------------------------ + # Internal helpers + # ------------------------------------------------------------------ + def _headers(self) -> Dict[str, str]: + return self._gateway.auth_headers() + + def _build_url(self, hub_path: str) -> str: + return _join_url(self._gateway.base_url, hub_path) + + def _connection_builder(self) -> Any: + builder_cls = self._builder_factory() + return builder_cls() + + def _refresh_token(self) -> None: + self._gateway.auth_headers() + + +class MarketHub: + """High level helpers for market hub subscriptions.""" + + def __init__(self, session: GatewayStreamingSession, hub_path: str) -> None: + self._session = session + self._hub = _StreamingHub(session, hub_path) + + def subscribe_ticker( + self, + symbol: str, + callback: Callable[[str, Any], None], + *, + event: str = "ticker_update", + subscribe_method: str = "SubscribeTicker", + unsubscribe_method: str = "UnsubscribeTicker", + ) -> HubSubscriptionHandle: + return self._hub.subscribe( + event=event, + subscribe_method=subscribe_method, + unsubscribe_method=unsubscribe_method, + args=(symbol,), + callback=callback, + prefix_args=(symbol,), + ) + + def subscribe_bars( + self, + symbol: str, + timeframe: str, + callback: Callable[[str, str, Any], None], + *, + event: str = "bar_update", + subscribe_method: str = "SubscribeBars", + unsubscribe_method: str = "UnsubscribeBars", + ) -> HubSubscriptionHandle: + return self._hub.subscribe( + event=event, + subscribe_method=subscribe_method, + unsubscribe_method=unsubscribe_method, + args=(symbol, timeframe), + callback=callback, + prefix_args=(symbol, timeframe), + ) + + def subscribe_depth( + self, + symbol: str, + callback: Callable[[str, Any], None], + *, + event: str = "depth_update", + subscribe_method: str = "SubscribeDepth", + unsubscribe_method: str = "UnsubscribeDepth", + ) -> HubSubscriptionHandle: + return self._hub.subscribe( + event=event, + subscribe_method=subscribe_method, + unsubscribe_method=unsubscribe_method, + args=(symbol,), + callback=callback, + prefix_args=(symbol,), + ) - connection = builder_cls().with_url(normalized_url, options=connection_options).build() + def close(self) -> None: + self._hub.close() + + +class UserHub: + """Helpers for user/account hub subscriptions.""" + + def __init__(self, session: GatewayStreamingSession, hub_path: str) -> None: + self._session = session + self._hub = _StreamingHub(session, hub_path) + + def subscribe_accounts( + self, + callback: Callable[[Any], None], + account_id: Optional[str] = None, + *, + event: str = "account_update", + subscribe_method: str = "SubscribeAccounts", + unsubscribe_method: str = "UnsubscribeAccounts", + ) -> HubSubscriptionHandle: + args = tuple(filter(None, (account_id,))) + prefix: Tuple[str, ...] = (account_id,) if account_id else tuple() + return self._hub.subscribe( + event=event, + subscribe_method=subscribe_method, + unsubscribe_method=unsubscribe_method, + args=args, + callback=callback, + prefix_args=prefix, + ) + + def subscribe_orders( + self, + account_id: str, + callback: Callable[[str, Any], None], + *, + event: str = "order_update", + subscribe_method: str = "SubscribeOrders", + unsubscribe_method: str = "UnsubscribeOrders", + ) -> HubSubscriptionHandle: + return self._account_subscription( + event, + subscribe_method, + unsubscribe_method, + account_id, + callback, + ) + + def subscribe_positions( + self, + account_id: str, + callback: Callable[[str, Any], None], + *, + event: str = "position_update", + subscribe_method: str = "SubscribePositions", + unsubscribe_method: str = "UnsubscribePositions", + ) -> HubSubscriptionHandle: + return self._account_subscription( + event, + subscribe_method, + unsubscribe_method, + account_id, + callback, + ) + + def subscribe_trades( + self, + account_id: str, + callback: Callable[[str, Any], None], + *, + event: str = "trade_update", + subscribe_method: str = "SubscribeTrades", + unsubscribe_method: str = "UnsubscribeTrades", + ) -> HubSubscriptionHandle: + return self._account_subscription( + event, + subscribe_method, + unsubscribe_method, + account_id, + callback, + ) - if on_open and hasattr(connection, "on_open"): - connection.on_open(on_open) - if on_close and hasattr(connection, "on_close"): - connection.on_close(on_close) + def close(self) -> None: + self._hub.close() + + def _account_subscription( + self, + event: str, + subscribe_method: str, + unsubscribe_method: str, + account_id: str, + callback: Callable[[str, Any], None], + ) -> HubSubscriptionHandle: + return self._hub.subscribe( + event=event, + subscribe_method=subscribe_method, + unsubscribe_method=unsubscribe_method, + args=(account_id,), + callback=callback, + prefix_args=(account_id,), + ) + + +class _StreamingHub: + """Manage a single SignalR hub connection and its subscriptions.""" + + def __init__(self, session: GatewayStreamingSession, hub_path: str) -> None: + self._session = session + self._hub_path = hub_path + self._lock = threading.RLock() + self._connection: Any | None = None + self._dispatchers: Dict[str, Callable[[Any], None]] = {} + self._tokens: Dict[str, Any] = {} + self._subscriptions: Dict[str, list[_Subscription]] = defaultdict(list) + self._stopping = False + self._reconnect_delay = session._reconnect_delay + self._pending_timer: threading.Timer | None = None + + def subscribe( + self, + *, + event: str, + subscribe_method: Optional[str], + unsubscribe_method: Optional[str], + args: Sequence[Any], + callback: Callable[..., None], + prefix_args: Sequence[Any] = (), + adapter: Optional[Callable[[Any], Any]] = None, + ) -> HubSubscriptionHandle: + subscription = _Subscription( + hub=self, + event=event, + subscribe_method=subscribe_method, + unsubscribe_method=unsubscribe_method, + args=tuple(args), + callback=callback, + prefix_args=tuple(prefix_args), + adapter=adapter, + ) + + connection: Any + should_start = False + with self._lock: + connection, should_start = self._ensure_connection_locked() + self._register_dispatcher_locked(event, connection) + self._subscriptions[event].append(subscription) + + if should_start: + _start_connection(connection) + + if subscribe_method: + connection.send(subscribe_method, list(args)) + subscription.synced = True + + return HubSubscriptionHandle(subscription) - if auto_start and hasattr(connection, "start"): - connection.start() + def close(self) -> None: + with self._lock: + self._stopping = True + if self._pending_timer is not None: + self._pending_timer.cancel() + self._pending_timer = None + connection = self._connection + self._connection = None + self._dispatchers.clear() + self._tokens.clear() + self._subscriptions.clear() + if connection is not None and hasattr(connection, "stop"): + connection.stop() + + # ------------------------------------------------------------------ + # Internal helpers + # ------------------------------------------------------------------ + def _ensure_connection_locked(self) -> tuple[Any, bool]: + if self._connection is not None: + return self._connection, False + + builder = self._session._connection_builder() + headers = self._session._headers() + options: MutableMapping[str, Any] = _merge_options({}, headers) + connection = builder.with_url( + self._session._build_url(self._hub_path), options=options + ).build() + + if hasattr(connection, "on_open"): + connection.on_open(self._handle_open) + if hasattr(connection, "on_close"): + connection.on_close(self._handle_close) + + for event, dispatcher in self._dispatchers.items(): + _, token = _register_handler(connection, event, dispatcher) + self._tokens[event] = token + + self._connection = connection + return connection, True + + def _register_dispatcher_locked(self, event: str, connection: Any) -> None: + if event in self._dispatchers: + if event not in self._tokens: + handler = self._dispatchers[event] + _, token = _register_handler(connection, event, handler) + self._tokens[event] = token + return - return HubConnectionHandle(connection) - - -def subscribe_ticker( - connection: HubConnectionHandle | Any, - symbol: str, - callback: Callable[[str, Any], None], - *, - event: str = "ticker_update", - subscribe_method: Optional[str] = "SubscribeTicker", - unsubscribe_method: Optional[str] = "UnsubscribeTicker", -) -> SubscriptionHandle: - """Attach a ticker listener and optionally send subscribe/unsubscribe calls.""" - - signalr_connection = _unwrap_connection(connection) - handler, token = _register_handler( - signalr_connection, - event, - _wrap_payload(callback, symbol), - ) - - if subscribe_method: - signalr_connection.send(subscribe_method, [symbol]) - - return SubscriptionHandle( - signalr_connection, - event, - handler, - token, - unsubscribe_method, - [symbol], - ) - - -def subscribe_bars( - connection: HubConnectionHandle | Any, - symbol: str, - timeframe: str, - callback: Callable[[str, str, Any], None], - *, - event: str = "bar_update", - subscribe_method: Optional[str] = "SubscribeBars", - unsubscribe_method: Optional[str] = "UnsubscribeBars", -) -> SubscriptionHandle: - """Attach a bar listener for the provided symbol and timeframe.""" - - signalr_connection = _unwrap_connection(connection) - handler, token = _register_handler( - signalr_connection, - event, - _wrap_payload(callback, symbol, timeframe), - ) - - if subscribe_method: - signalr_connection.send(subscribe_method, [symbol, timeframe]) - - return SubscriptionHandle( - signalr_connection, - event, - handler, - token, - unsubscribe_method, - [symbol, timeframe], - ) + def _dispatcher(message: Any) -> None: + self._fan_out(event, message) + + self._dispatchers[event] = _dispatcher + _, token = _register_handler(connection, event, _dispatcher) + self._tokens[event] = token + + def _fan_out(self, event: str, message: Any) -> None: + for subscription in list(self._subscriptions.get(event, [])): + subscription.deliver(message) + + def _handle_open(self) -> None: + subscriptions: Iterable[_Subscription] + with self._lock: + subscriptions = [ + subscription + for subs in self._subscriptions.values() + for subscription in subs + if not subscription.synced and subscription.subscribe_method + ] + if not subscriptions: + return + connection = self._connection + if connection is None: + return + for subscription in subscriptions: + connection.send(subscription.subscribe_method, list(subscription.args)) + with self._lock: + for subscription in subscriptions: + subscription.synced = True + + def _handle_close(self, error: Any = None) -> None: + with self._lock: + if self._stopping: + return + for subscriptions in self._subscriptions.values(): + for subscription in subscriptions: + subscription.synced = False + self._connection = None + self._tokens.clear() + if self._pending_timer is not None: + self._pending_timer.cancel() + self._pending_timer = None + self._session._refresh_token() + if self._reconnect_delay <= 0: + self._ensure_connection() + else: + timer = threading.Timer(self._reconnect_delay, self._ensure_connection) + timer.daemon = True + with self._lock: + if self._stopping: + return + self._pending_timer = timer + timer.start() + + def _ensure_connection(self) -> None: + connection: Any + should_start = False + with self._lock: + if self._stopping: + return + if self._pending_timer is not None: + self._pending_timer.cancel() + self._pending_timer = None + connection, should_start = self._ensure_connection_locked() + if should_start: + _start_connection(connection) + + def _remove_subscription(self, subscription: "_Subscription") -> None: + connection: Any | None = None + send_unsubscribe = False + unsubscribe_method: Optional[str] = None + args: Tuple[Any, ...] = () + + with self._lock: + subscriptions = self._subscriptions.get(subscription.event) + if not subscriptions or subscription not in subscriptions: + return + subscriptions.remove(subscription) + if not subscriptions: + self._subscriptions.pop(subscription.event, None) + handler = self._dispatchers.pop(subscription.event, None) + token = self._tokens.pop(subscription.event, None) + connection = self._connection + if connection is not None and handler is not None: + _remove_listener(connection, subscription.event, handler, token) + if subscription.unsubscribe_method and subscription.synced: + connection = connection or self._connection + send_unsubscribe = True + unsubscribe_method = subscription.unsubscribe_method + args = subscription.args + + if send_unsubscribe and connection is not None and unsubscribe_method: + connection.send(unsubscribe_method, list(args)) + + +class _Subscription: + """Track individual subscriber state.""" + + def __init__( + self, + *, + hub: _StreamingHub, + event: str, + subscribe_method: Optional[str], + unsubscribe_method: Optional[str], + args: Tuple[Any, ...], + callback: Callable[..., None], + prefix_args: Tuple[Any, ...], + adapter: Optional[Callable[[Any], Any]], + ) -> None: + self._hub = hub + self.event = event + self.subscribe_method = subscribe_method + self.unsubscribe_method = unsubscribe_method + self.args = tuple(arg for arg in args if arg is not None) + self.callback = callback + self.prefix_args = prefix_args + self.adapter = adapter + self.synced = False + self._closed = False + + def deliver(self, message: Any) -> None: + if self._closed: + return + payload = _extract_payload(message) + if self.adapter is not None: + payload = self.adapter(payload) + self.callback(*self.prefix_args, payload) + + def close(self) -> None: + if self._closed: + return + self._closed = True + self._hub._remove_subscription(self) def _require_signalr_builder(): @@ -167,24 +505,9 @@ def _require_signalr_builder(): return HubConnectionBuilder -def _unwrap_connection(connection: HubConnectionHandle | Any) -> Any: - return connection.connection if isinstance(connection, HubConnectionHandle) else connection - - -def _register_handler(connection: Any, event: str, handler: Callable[[Any], None]) -> tuple[Callable[[Any], None], Any]: - token = connection.on(event, handler) - return handler, token - - -def _wrap_payload( - callback: Callable[..., None], - *prefix_args: str, -) -> Callable[[Any], None]: - def _inner(message: Any) -> None: - payload = _extract_payload(message) - callback(*prefix_args, payload) - - return _inner +def _start_connection(connection: Any) -> None: + if hasattr(connection, "start"): + connection.start() def _extract_payload(message: Any) -> Any: @@ -215,6 +538,11 @@ def _remove_listener(connection: Any, event: str, handler: Callable[[Any], None] connection.off(event) +def _register_handler(connection: Any, event: str, handler: Callable[[Any], None]) -> tuple[Callable[[Any], None], Any]: + token = connection.on(event, handler) + return handler, token + + def _merge_options( options: Optional[MutableMapping[str, Any]], headers: Optional[MutableMapping[str, str]], @@ -235,11 +563,11 @@ def _join_url(base_url: str, hub_path: str) -> str: __all__ = [ "ExecutionContext", - "HubConnectionHandle", - "SubscriptionHandle", + "GatewayStreamingSession", + "HubSubscriptionHandle", + "MarketHub", + "UserHub", "poll_open_orders", "poll_positions", - "connect_market_hub", - "subscribe_ticker", - "subscribe_bars", ] +