From ac34622584038cf0b69f4c9f33e971c0dc234353 Mon Sep 17 00:00:00 2001 From: "Yuichiro Tachibana (Tsuchiya)" Date: Mon, 21 Aug 2023 14:30:02 +0900 Subject: [PATCH 1/3] Cancel MediaRelay's consumer task when the last proxy is stopped --- src/aiortc/contrib/media.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/src/aiortc/contrib/media.py b/src/aiortc/contrib/media.py index e67e8ca94..30861ae43 100644 --- a/src/aiortc/contrib/media.py +++ b/src/aiortc/contrib/media.py @@ -578,6 +578,10 @@ def _stop(self, proxy: RelayStreamTrack) -> None: # unregister proxy self.__log_debug("Stop proxy %s", id(proxy)) self.__proxies[track].discard(proxy) + if len(self.__proxies[track]) == 0 and track in self.__tasks: + self.__tasks[track].cancel() + del self.__tasks[track] + del self.__proxies[track] def __log_debug(self, msg: str, *args) -> None: logger.debug(f"MediaRelay(%s) {msg}", id(self), *args) From 5f0e87c23796f6bf0bc8cf008a285aff4eab8ffa Mon Sep 17 00:00:00 2001 From: "Yuichiro Tachibana (Tsuchiya)" Date: Mon, 20 Nov 2023 13:29:31 +0900 Subject: [PATCH 2/3] Add a unit test checking the task is canceleed when all the consumers are stopped --- tests/test_contrib_media.py | 45 +++++++++++++++++++++++++++++++++++++ 1 file changed, 45 insertions(+) diff --git a/tests/test_contrib_media.py b/tests/test_contrib_media.py index dfe4331c1..63b9836cf 100644 --- a/tests/test_contrib_media.py +++ b/tests/test_contrib_media.py @@ -196,6 +196,51 @@ async def test_audio_stop_consumer(self): # stop source track source.stop() + @asynctest + async def test_audio_stop_all_consumers_and_restart_new_consumers(self): + source = AudioStreamTrack() + relay = MediaRelay() + proxy1 = relay.subscribe(source) + proxy2 = relay.subscribe(source) + + # read some frames + frame1, frame2 = await asyncio.gather(proxy1.recv(), proxy2.recv()) + self.assertIsInstance(frame1, av.AudioFrame) + self.assertIsInstance(frame2, av.AudioFrame) + + task = relay._MediaRelay__tasks[source] + + # stop all consumers + proxy1.stop() + proxy2.stop() + exc1, exc2 = await asyncio.gather( + proxy1.recv(), proxy2.recv(), return_exceptions=True + ) + self.assertIsInstance(exc1, MediaStreamError) + self.assertIsInstance(exc2, MediaStreamError) + self.assertTrue(task.cancelled()) + self.assertEqual(relay._MediaRelay__proxies, {}) + self.assertEqual(relay._MediaRelay__tasks, {}) + + # Start new consumers + proxy3 = relay.subscribe(source) + proxy4 = relay.subscribe(source) + + # read some frames + for i in range(2): + frame3, frame4 = await asyncio.gather( + proxy3.recv(), proxy4.recv(), return_exceptions=True + ) + self.assertIsInstance(frame3, av.AudioFrame) + self.assertIsInstance(frame4, av.AudioFrame) + + # A new task should have been created and is running + task = relay._MediaRelay__tasks[source] + self.assertFalse(task.done()) + + # stop source track + source.stop() + @asynctest async def test_audio_stop_consumer_unbuffered(self): source = AudioStreamTrack() From 67b5b013e97bcd75e68c8d7604332b63aade615d Mon Sep 17 00:00:00 2001 From: "Yuichiro Tachibana (Tsuchiya)" Date: Mon, 20 Nov 2023 13:46:28 +0900 Subject: [PATCH 3/3] Fix MediaRelay.__run_track to handle asyncio.CancelledError --- src/aiortc/contrib/media.py | 9 +++++++-- 1 file changed, 7 insertions(+), 2 deletions(-) diff --git a/src/aiortc/contrib/media.py b/src/aiortc/contrib/media.py index 30861ae43..a530fdaf5 100644 --- a/src/aiortc/contrib/media.py +++ b/src/aiortc/contrib/media.py @@ -580,8 +580,6 @@ def _stop(self, proxy: RelayStreamTrack) -> None: self.__proxies[track].discard(proxy) if len(self.__proxies[track]) == 0 and track in self.__tasks: self.__tasks[track].cancel() - del self.__tasks[track] - del self.__proxies[track] def __log_debug(self, msg: str, *args) -> None: logger.debug(f"MediaRelay(%s) {msg}", id(self), *args) @@ -589,11 +587,15 @@ def __log_debug(self, msg: str, *args) -> None: async def __run_track(self, track: MediaStreamTrack) -> None: self.__log_debug("Start reading source %s" % id(track)) + task_cancelled_error = None while True: try: frame = await track.recv() except MediaStreamError: frame = None + except asyncio.CancelledError as e: + frame = None + task_cancelled_error = e for proxy in self.__proxies[track]: if proxy._buffered: proxy._queue.put_nowait(frame) @@ -606,3 +608,6 @@ async def __run_track(self, track: MediaStreamTrack) -> None: self.__log_debug("Stop reading source %s", id(track)) del self.__proxies[track] del self.__tasks[track] + + if task_cancelled_error: + raise task_cancelled_error