From 5359fe63510debc7f593e4c1ae2d72f2d277adee Mon Sep 17 00:00:00 2001 From: Max Kahan Date: Sat, 23 Aug 2025 03:18:59 +0100 Subject: [PATCH] refactor: location and core methods --- getstream/video/rtc/__init__.py | 3 +- getstream/video/rtc/connection_manager.py | 9 +----- getstream/video/rtc/location_discovery.py | 37 +++++++++-------------- getstream/video/rtc/pc.py | 2 +- tests/rtc/test_location_discovery.py | 3 +- 5 files changed, 19 insertions(+), 35 deletions(-) diff --git a/getstream/video/rtc/__init__.py b/getstream/video/rtc/__init__.py index b7d29b20..e07e0bdf 100644 --- a/getstream/video/rtc/__init__.py +++ b/getstream/video/rtc/__init__.py @@ -1,4 +1,5 @@ import logging +from typing import Optional from getstream.video.call import Call from getstream.video.rtc.location_discovery import ( @@ -44,7 +45,7 @@ async def discover_location(): async def join( - call: Call, user_id: str = None, create=True, **kwargs + call: Call, user_id: Optional[str] = None, create=True, **kwargs ) -> ConnectionManager: """ Join a call. This method will: diff --git a/getstream/video/rtc/connection_manager.py b/getstream/video/rtc/connection_manager.py index 963dd942..033cc90a 100644 --- a/getstream/video/rtc/connection_manager.py +++ b/getstream/video/rtc/connection_manager.py @@ -355,7 +355,7 @@ async def leave(self): await self._network_monitor.stop_monitoring() await self._peer_manager.close() if self._ws_client: - await self._ws_client.close() + self._ws_client.close() self._ws_client = None if self._coordinator_ws_client: await self._coordinator_ws_client.disconnect() @@ -385,13 +385,6 @@ async def add_tracks(self, audio=None, video=None): """Add multiple audio and video tracks in a single negotiation.""" await self._peer_manager.add_tracks(audio, video) - async def addTrack(self, track, track_info=None): - """Add a single track (backward compatibility).""" - if track.kind == "video": - await self.add_tracks(video=track) - else: - await self.add_tracks(audio=track) - async def start_recording( self, recording_types, user_ids=None, output_dir="recordings" ): diff --git a/getstream/video/rtc/location_discovery.py b/getstream/video/rtc/location_discovery.py index 4baaa400..af3051b6 100644 --- a/getstream/video/rtc/location_discovery.py +++ b/getstream/video/rtc/location_discovery.py @@ -8,7 +8,7 @@ import logging import http.client import functools -from typing import Optional, Protocol +from typing import Optional, Protocol, ContextManager from contextlib import contextmanager # Constants matching the Go implementation @@ -23,12 +23,11 @@ class HTTPClient(Protocol): """Protocol defining the HTTP client interface.""" - def request(self, method: str, url: str, body=None, headers=None, **kwargs): + def request(self, method: str, url: str, body=None, headers=None, **kwargs) -> None: """Make an HTTP request.""" ... - @contextmanager - def response(self): + def response(self) -> ContextManager[http.client.HTTPResponse]: """Get the HTTP response.""" ... @@ -69,33 +68,25 @@ def discover(self, context=None) -> str: Returns: The 3-character location code (e.g. "IAD", "FRA") """ + # Basic validation to match previous behavior and provide fast-fail parsed_url = self.url.split("://", 1) if len(parsed_url) != 2: self.logger.warning("Invalid URL format: %s", self.url) return FALLBACK_LOCATION_NAME - protocol, host_path = parsed_url - host = host_path.split("/", 1)[0] - path = "/" + host_path.split("/", 1)[1] if "/" in host_path else "/" - for i in range(self.max_retries): self.logger.info("Discovering location, attempt %d", i + 1) try: - if protocol.lower() == "https": - conn = http.client.HTTPSConnection(host, timeout=1) - else: - conn = http.client.HTTPConnection(host, timeout=1) - - conn.request("HEAD", path) - response = conn.getresponse() - - if response.status != 200: - self.logger.warning("Unexpected status code: %d", response.status) - continue - - pop_name = response.getheader(HEADER_CLOUDFRONT_POP, "") - response.read() # Read and discard the response body - conn.close() + # Use injected HTTP client (or default) for requests + self.client.request("HEAD", self.url) + with self.client.response() as response: + if response.status != 200: + self.logger.warning( + "Unexpected status code: %d", response.status + ) + continue + + pop_name = response.getheader(HEADER_CLOUDFRONT_POP, "") if len(pop_name) < 3: self.logger.warning("Invalid pop name: %s", pop_name) diff --git a/getstream/video/rtc/pc.py b/getstream/video/rtc/pc.py index e298817e..c1801d5e 100644 --- a/getstream/video/rtc/pc.py +++ b/getstream/video/rtc/pc.py @@ -137,7 +137,7 @@ async def on_track(track: aiortc.mediastreams.MediaStreamTrack): handler = AudioTrackHandler( relay.subscribe(track), lambda pcm: self.emit("audio", pcm, user) ) - asyncio.ensure_future(handler.start()) + asyncio.create_task(handler.start()) self.emit("track_added", relay.subscribe(track), user) diff --git a/tests/rtc/test_location_discovery.py b/tests/rtc/test_location_discovery.py index 0f2e839a..bcfa69d0 100644 --- a/tests/rtc/test_location_discovery.py +++ b/tests/rtc/test_location_discovery.py @@ -20,7 +20,6 @@ def setUp(self): self.discovery = HTTPHintLocationDiscovery( url=STREAM_PROD_URL, max_retries=3, - client=self.client_mock, logger=self.logger_mock, ) @@ -67,7 +66,7 @@ def test_discover_success(self, mock_https_connection): self.assertEqual(location, "IAD") # Verify that the correct HTTP request was made - mock_conn.request.assert_called_once_with("HEAD", "/") + mock_conn.request.assert_called_once_with("HEAD", "/", None, {}) mock_response.getheader.assert_called_once_with(HEADER_CLOUDFRONT_POP, "") mock_response.read.assert_called_once() mock_conn.close.assert_called_once()