Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Cancel MediaRelay's consumer task when the last proxy is stopped #919

Open
wants to merge 3 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
9 changes: 9 additions & 0 deletions src/aiortc/contrib/media.py
Original file line number Diff line number Diff line change
Expand Up @@ -578,18 +578,24 @@ def _stop(self, proxy: RelayStreamTrack) -> None:
# unregister proxy
self.__log_debug("Stop proxy %s", id(proxy))
self.__proxies[track].discard(proxy)
if len(self.__proxies[track]) == 0 and track in self.__tasks:
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
if len(self.__proxies[track]) == 0 and track in self.__tasks:
# stop worker if this was the last proxy
if not self.__proxies[track] and track in self.__tasks:

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

How can track not be in self.__tasks ?

self.__tasks[track].cancel()

def __log_debug(self, msg: str, *args) -> None:
logger.debug(f"MediaRelay(%s) {msg}", id(self), *args)

async def __run_track(self, track: MediaStreamTrack) -> None:
self.__log_debug("Start reading source %s" % id(track))

task_cancelled_error = None
while True:
try:
frame = await track.recv()
except MediaStreamError:
frame = None
except asyncio.CancelledError as e:
frame = None
task_cancelled_error = e
for proxy in self.__proxies[track]:
if proxy._buffered:
proxy._queue.put_nowait(frame)
Expand All @@ -602,3 +608,6 @@ async def __run_track(self, track: MediaStreamTrack) -> None:
self.__log_debug("Stop reading source %s", id(track))
del self.__proxies[track]
del self.__tasks[track]

if task_cancelled_error:
raise task_cancelled_error
45 changes: 45 additions & 0 deletions tests/test_contrib_media.py
Original file line number Diff line number Diff line change
Expand Up @@ -196,6 +196,51 @@ async def test_audio_stop_consumer(self):
# stop source track
source.stop()

@asynctest
async def test_audio_stop_all_consumers_and_restart_new_consumers(self):
source = AudioStreamTrack()
relay = MediaRelay()
proxy1 = relay.subscribe(source)
proxy2 = relay.subscribe(source)

# read some frames
frame1, frame2 = await asyncio.gather(proxy1.recv(), proxy2.recv())
self.assertIsInstance(frame1, av.AudioFrame)
self.assertIsInstance(frame2, av.AudioFrame)

task = relay._MediaRelay__tasks[source]

# stop all consumers
proxy1.stop()
proxy2.stop()
exc1, exc2 = await asyncio.gather(
proxy1.recv(), proxy2.recv(), return_exceptions=True
)
self.assertIsInstance(exc1, MediaStreamError)
self.assertIsInstance(exc2, MediaStreamError)
self.assertTrue(task.cancelled())
self.assertEqual(relay._MediaRelay__proxies, {})
self.assertEqual(relay._MediaRelay__tasks, {})

# Start new consumers
proxy3 = relay.subscribe(source)
proxy4 = relay.subscribe(source)

# read some frames
for i in range(2):
frame3, frame4 = await asyncio.gather(
proxy3.recv(), proxy4.recv(), return_exceptions=True
)
self.assertIsInstance(frame3, av.AudioFrame)
self.assertIsInstance(frame4, av.AudioFrame)

# A new task should have been created and is running
task = relay._MediaRelay__tasks[source]
self.assertFalse(task.done())

# stop source track
source.stop()

@asynctest
async def test_audio_stop_consumer_unbuffered(self):
source = AudioStreamTrack()
Expand Down