diff --git a/getstream/video/rtc/audio_track.py b/getstream/video/rtc/audio_track.py index 92bd453b..a49797d3 100644 --- a/getstream/video/rtc/audio_track.py +++ b/getstream/video/rtc/audio_track.py @@ -14,9 +14,10 @@ class AudioStreamTrack(aiortc.mediastreams.MediaStreamTrack): """ - Audio stream track that accepts PcmData objects directly from a queue. + Audio stream track that accepts PcmData objects and buffers them as bytes. Works with PcmData objects instead of raw bytes, avoiding format conversion issues. + Internally buffers as bytes for efficient memory usage. Usage: track = AudioStreamTrack(sample_rate=48000, channels=2) @@ -34,7 +35,7 @@ def __init__( sample_rate: int = 48000, channels: int = 1, format: str = "s16", - max_queue_size: int = 100, + audio_buffer_size_ms: int = 30000, # 30 seconds default ): """ Initialize an AudioStreamTrack that accepts PcmData objects. @@ -43,13 +44,13 @@ def __init__( sample_rate: Target sample rate in Hz (default: 48000) channels: Number of channels - 1=mono, 2=stereo (default: 1) format: Audio format - "s16" or "f32" (default: "s16") - max_queue_size: Maximum number of PcmData objects in queue (default: 100) + audio_buffer_size_ms: Maximum buffer size in milliseconds (default: 30000ms = 30s) """ super().__init__() self.sample_rate = sample_rate self.channels = channels self.format = format - self.max_queue_size = max_queue_size + self.audio_buffer_size_ms = audio_buffer_size_ms logger.debug( "Initialized AudioStreamTrack", @@ -57,78 +58,105 @@ def __init__( "sample_rate": sample_rate, "channels": channels, "format": format, - "max_queue_size": max_queue_size, + "audio_buffer_size_ms": audio_buffer_size_ms, }, ) - # Create async queue for PcmData objects - self._queue = asyncio.Queue() + # Internal bytearray buffer for audio data + self._buffer = bytearray() + self._buffer_lock = asyncio.Lock() + + # Timing for frame pacing self._start = None self._timestamp = None - - # Buffer for chunks smaller than 20ms - self._buffer = None + self._last_frame_time = None + + # Calculate bytes per sample based on format + self._bytes_per_sample = 2 if format == "s16" else 4 # s16=2 bytes, f32=4 bytes + self._bytes_per_frame = int( + aiortc.mediastreams.AUDIO_PTIME + * self.sample_rate + * self.channels + * self._bytes_per_sample + ) async def write(self, pcm: PcmData): """ - Add PcmData to the queue. + Add PcmData to the buffer. The PcmData will be automatically resampled/converted to match - the track's configured sample_rate, channels, and format. + the track's configured sample_rate, channels, and format, + then converted to bytes and stored in the buffer. Args: pcm: PcmData object with audio data """ - # Check if queue is getting too large and trim if necessary - if self._queue.qsize() >= self.max_queue_size: - dropped_items = 0 - while self._queue.qsize() >= self.max_queue_size: - try: - self._queue.get_nowait() - self._queue.task_done() - dropped_items += 1 - except asyncio.QueueEmpty: - break - - logger.warning( - "Audio queue overflow, dropped items max is %d. pcm duration %s ms", - self.max_queue_size, - pcm.duration_ms, - extra={ - "dropped_items": dropped_items, - "queue_size": self._queue.qsize(), - }, + # Normalize the PCM data to target format immediately + pcm_normalized = self._normalize_pcm(pcm) + + # Convert to bytes + audio_bytes = pcm_normalized.to_bytes() + + async with self._buffer_lock: + # Check buffer size before adding + max_buffer_bytes = int( + (self.audio_buffer_size_ms / 1000) + * self.sample_rate + * self.channels + * self._bytes_per_sample ) - await self._queue.put(pcm) + # Add new data to buffer first + self._buffer.extend(audio_bytes) + + # Check if we exceeded the limit + if len(self._buffer) > max_buffer_bytes: + # Calculate how many bytes to drop from the beginning + bytes_to_drop = len(self._buffer) - max_buffer_bytes + dropped_ms = ( + bytes_to_drop + / (self.sample_rate * self.channels * self._bytes_per_sample) + ) * 1000 + + # TODO: do not perform logging inside critical section, move this code outside the lock + logger.warning( + "Audio buffer overflow, dropping %.1fms of audio. Buffer max is %dms", + dropped_ms, + self.audio_buffer_size_ms, + extra={ + "buffer_size_bytes": len(self._buffer), + "incoming_bytes": len(audio_bytes), + "dropped_bytes": bytes_to_drop, + }, + ) + + # Drop from the beginning of the buffer to keep latest data + self._buffer = self._buffer[bytes_to_drop:] + + buffer_duration_ms = ( + len(self._buffer) + / (self.sample_rate * self.channels * self._bytes_per_sample) + ) * 1000 + logger.debug( - "Added PcmData to queue", + "Added audio to buffer", extra={ - "pcm_samples": len(pcm.samples) - if pcm.samples.ndim == 1 - else pcm.samples.shape, - "pcm_sample_rate": pcm.sample_rate, - "pcm_channels": pcm.channels, - "queue_size": self._queue.qsize(), + "pcm_duration_ms": pcm.duration_ms, + "buffer_duration_ms": buffer_duration_ms, + "buffer_size_bytes": len(self._buffer), }, ) async def flush(self) -> None: """ - Clear any pending audio from the queue and buffer. + Clear any pending audio from the buffer. Playback stops immediately. """ - cleared = 0 - while not self._queue.empty(): - try: - self._queue.get_nowait() - self._queue.task_done() - cleared += 1 - except asyncio.QueueEmpty: - break - - self._buffer = None - logger.debug("Flushed audio queue", extra={"cleared_items": cleared}) + async with self._buffer_lock: + bytes_cleared = len(self._buffer) + self._buffer.clear() + + logger.debug("Flushed audio buffer", extra={"cleared_bytes": bytes_cleared}) async def recv(self) -> Frame: """ @@ -142,23 +170,56 @@ async def recv(self) -> Frame: # Calculate samples needed for 20ms frame samples_per_frame = int(aiortc.mediastreams.AUDIO_PTIME * self.sample_rate) + wakeup_time = 0 # Initialize timestamp if not already done if self._timestamp is None: self._start = time.time() self._timestamp = 0 + self._last_frame_time = time.time() else: + # Use timestamp-based pacing to avoid drift over time + # This ensures we stay synchronized with the expected audio rate + # even if individual frames have slight timing variations self._timestamp += samples_per_frame start_ts = self._start or time.time() wait = start_ts + (self._timestamp / self.sample_rate) - time.time() if wait > 0: + wakeup_time += time.time() + wait await asyncio.sleep(wait) + if time.time() - wakeup_time > 0.1: + logger.warning( + f"scheduled sleep took {time.time() - wakeup_time} instead of {wait}, this can happen if there is something blocking the main event-loop, you can enable --debug mode to see blocking traces" + ) - # Get or accumulate PcmData to fill a 20ms frame - pcm_for_frame = await self._get_pcm_for_frame(samples_per_frame) + self._last_frame_time = time.time() + + # Get 20ms of audio data from buffer + async with self._buffer_lock: + if len(self._buffer) >= self._bytes_per_frame: + # We have enough data + audio_bytes = bytes(self._buffer[: self._bytes_per_frame]) + self._buffer = self._buffer[self._bytes_per_frame :] + elif len(self._buffer) > 0: + # We have some data but not enough - pad with silence + audio_bytes = bytes(self._buffer) + padding_needed = self._bytes_per_frame - len(audio_bytes) + audio_bytes += bytes(padding_needed) # Pad with zeros (silence) + self._buffer.clear() + + logger.debug( + "Padded audio frame with silence", + extra={ + "available_bytes": len(audio_bytes) - padding_needed, + "required_bytes": self._bytes_per_frame, + "padding_bytes": padding_needed, + }, + ) + else: + # No data at all - emit silence + audio_bytes = bytes(self._bytes_per_frame) # Create AudioFrame - # Determine layout and format layout = "stereo" if self.channels == 2 else "mono" # Convert format name: "s16" -> "s16", "f32" -> "flt" @@ -172,20 +233,7 @@ async def recv(self) -> Frame: frame = AudioFrame(format=av_format, layout=layout, samples=samples_per_frame) # Fill frame with data - if pcm_for_frame is not None: - audio_bytes = pcm_for_frame.to_bytes() - - # Write to the single plane (packed format has 1 plane) - if len(audio_bytes) >= frame.planes[0].buffer_size: - frame.planes[0].update(audio_bytes[: frame.planes[0].buffer_size]) - else: - # Pad with silence if not enough data - padding = bytes(frame.planes[0].buffer_size - len(audio_bytes)) - frame.planes[0].update(audio_bytes + padding) - else: - # No data available, return silence - for plane in frame.planes: - plane.update(bytes(plane.buffer_size)) + frame.planes[0].update(audio_bytes) # Set frame properties frame.pts = self._timestamp @@ -194,108 +242,6 @@ async def recv(self) -> Frame: return frame - async def _get_pcm_for_frame(self, samples_needed: int) -> PcmData | None: - """ - Get or accumulate PcmData to fill exactly samples_needed samples. - - This method handles: - - Buffering partial chunks - - Resampling to target sample rate - - Converting to target channels - - Converting to target format - - Chunking to exact frame size - - Args: - samples_needed: Number of samples needed for the frame - - Returns: - PcmData with exactly samples_needed samples, or None if no data available - """ - # Start with buffered data if any - if self._buffer is not None: - pcm_accumulated = self._buffer - self._buffer = None - else: - pcm_accumulated = None - - # Try to get data from queue - try: - # Don't wait too long - if no data, return silence - while True: - # Check if we have enough samples - if pcm_accumulated is not None: - current_samples = ( - len(pcm_accumulated.samples) - if pcm_accumulated.samples.ndim == 1 - else pcm_accumulated.samples.shape[1] - ) - if current_samples >= samples_needed: - break - - # Try to get more data - if self._queue.empty(): - # No more data available - break - - pcm_chunk = await asyncio.wait_for(self._queue.get(), timeout=0.01) - self._queue.task_done() - - # Resample/convert to target format - pcm_chunk = self._normalize_pcm(pcm_chunk) - - # Accumulate - if pcm_accumulated is None: - pcm_accumulated = pcm_chunk - else: - pcm_accumulated = pcm_accumulated.append(pcm_chunk) - - except asyncio.TimeoutError: - pass - - # If no data at all, return None (will produce silence) - if pcm_accumulated is None: - return None - - # Get the number of samples we have - current_samples = ( - len(pcm_accumulated.samples) - if pcm_accumulated.samples.ndim == 1 - else pcm_accumulated.samples.shape[1] - ) - - # If we have exactly the right amount, return it - if current_samples == samples_needed: - return pcm_accumulated - - # If we have more than needed, split it - if current_samples > samples_needed: - # Calculate duration needed in seconds - duration_needed_s = samples_needed / self.sample_rate - - # Use head() to get exactly what we need - pcm_for_frame = pcm_accumulated.head( - duration_s=duration_needed_s, pad=False, pad_at="end" - ) - - # Calculate what's left in seconds - duration_used_s = ( - len(pcm_for_frame.samples) - if pcm_for_frame.samples.ndim == 1 - else pcm_for_frame.samples.shape[1] - ) / self.sample_rate - - # Buffer the rest - self._buffer = pcm_accumulated.tail( - duration_s=pcm_accumulated.duration - duration_used_s, - pad=False, - pad_at="start", - ) - - return pcm_for_frame - - # If we have less than needed, return what we have (will be padded with silence) - return pcm_accumulated - def _normalize_pcm(self, pcm: PcmData) -> PcmData: """ Normalize PcmData to match the track's target format. @@ -306,9 +252,8 @@ def _normalize_pcm(self, pcm: PcmData) -> PcmData: Returns: PcmData resampled/converted to target sample_rate, channels, and format """ - # Resample to target sample rate and channels if needed - if pcm.sample_rate != self.sample_rate or pcm.channels != self.channels: - pcm = pcm.resample(self.sample_rate, target_channels=self.channels) + + pcm = pcm.resample(self.sample_rate, target_channels=self.channels) # Convert format if needed if self.format == "s16" and pcm.format != "s16": diff --git a/tests/test_audio_stream_track.py b/tests/test_audio_stream_track.py new file mode 100644 index 00000000..ec4384bc --- /dev/null +++ b/tests/test_audio_stream_track.py @@ -0,0 +1,315 @@ +import asyncio +import time +import pytest +import numpy as np + +from getstream.video.rtc.audio_track import AudioStreamTrack +from getstream.video.rtc.track_util import PcmData, AudioFormat +import aiortc + + +class TestAudioStreamTrack: + """Test suite for AudioStreamTrack class.""" + + @pytest.mark.asyncio + async def test_initialization(self): + """Test that AudioStreamTrack initializes with correct parameters.""" + # Default initialization + track = AudioStreamTrack() + assert track.sample_rate == 48000 + assert track.channels == 1 + assert track.format == "s16" + assert track.audio_buffer_size_ms == 30000 + assert track.kind == "audio" + assert track.readyState == "live" + + # Custom initialization + track = AudioStreamTrack( + sample_rate=16000, + channels=2, + format="f32", + audio_buffer_size_ms=10000, + ) + assert track.sample_rate == 16000 + assert track.channels == 2 + assert track.format == "f32" + assert track.audio_buffer_size_ms == 10000 + + @pytest.mark.asyncio + async def test_write_and_recv_basic(self): + """Test basic write and receive functionality.""" + track = AudioStreamTrack(sample_rate=48000, channels=1, format="s16") + + # Create 40ms of audio data (should be enough for 2 frames) + samples_40ms = int(0.04 * 48000) # 1920 samples + audio_data = np.zeros(samples_40ms, dtype=np.int16) + pcm = PcmData( + samples=audio_data, + sample_rate=48000, + format=AudioFormat.S16, + channels=1, + ) + + # Write data to track + await track.write(pcm) + + # Receive first 20ms frame + frame1 = await track.recv() + assert frame1.sample_rate == 48000 + assert frame1.samples == 960 # 20ms at 48kHz + assert frame1.format.name == "s16" + assert len(frame1.layout.channels) == 1 + + # Receive second 20ms frame + frame2 = await track.recv() + assert frame2.sample_rate == 48000 + assert frame2.samples == 960 + + @pytest.mark.asyncio + async def test_format_conversion(self): + """Test that write converts formats correctly.""" + track = AudioStreamTrack(sample_rate=48000, channels=1, format="f32") + + # Write s16 data to f32 track + samples = np.array([100, 200, 300], dtype=np.int16) + pcm = PcmData( + samples=samples, + sample_rate=48000, + format=AudioFormat.S16, + channels=1, + ) + + await track.write(pcm) + + # Check that buffer contains f32 data + assert track._bytes_per_sample == 4 # f32 = 4 bytes + + @pytest.mark.asyncio + async def test_sample_rate_conversion(self): + """Test that write resamples audio correctly.""" + track = AudioStreamTrack(sample_rate=48000, channels=1, format="s16") + + # Write 16kHz data to 48kHz track + samples_16k = np.zeros(320, dtype=np.int16) # 20ms at 16kHz + pcm = PcmData( + samples=samples_16k, + sample_rate=16000, + format=AudioFormat.S16, + channels=1, + ) + + await track.write(pcm) + + # After resampling, we should have 3x more samples (960 samples) + expected_bytes = 960 * 2 # 960 samples * 2 bytes per s16 sample + assert len(track._buffer) == expected_bytes + + @pytest.mark.asyncio + async def test_channel_conversion(self): + """Test mono to stereo and stereo to mono conversion.""" + # Test mono to stereo + track_stereo = AudioStreamTrack(sample_rate=48000, channels=2, format="s16") + + samples_mono = np.zeros(960, dtype=np.int16) # 20ms mono + pcm_mono = PcmData( + samples=samples_mono, + sample_rate=48000, + format=AudioFormat.S16, + channels=1, + ) + + await track_stereo.write(pcm_mono) + + # Should have doubled the data for stereo + expected_bytes = 960 * 2 * 2 # samples * bytes_per_sample * channels + assert len(track_stereo._buffer) == expected_bytes + + # Test stereo to mono + track_mono = AudioStreamTrack(sample_rate=48000, channels=1, format="s16") + + samples_stereo = np.zeros((2, 960), dtype=np.int16) # 20ms stereo + pcm_stereo = PcmData( + samples=samples_stereo, + sample_rate=48000, + format=AudioFormat.S16, + channels=2, + ) + + await track_mono.write(pcm_stereo) + + # Should have halved the data for mono + expected_bytes = 960 * 2 # samples * bytes_per_sample + assert len(track_mono._buffer) == expected_bytes + + @pytest.mark.asyncio + async def test_buffer_overflow(self): + """Test that buffer drops old data when it exceeds max size.""" + # Set small buffer size for testing (100ms) + track = AudioStreamTrack( + sample_rate=48000, channels=1, format="s16", audio_buffer_size_ms=100 + ) + + # Try to write 200ms of data + samples_200ms = int(0.2 * 48000) # 9600 samples + audio_data = np.zeros(samples_200ms, dtype=np.int16) + pcm = PcmData( + samples=audio_data, + sample_rate=48000, + format=AudioFormat.S16, + channels=1, + ) + + await track.write(pcm) + + # Buffer should only contain 100ms worth of data + max_samples = int(0.1 * 48000) # 4800 samples + max_bytes = max_samples * 2 # * 2 bytes per s16 sample + assert len(track._buffer) == max_bytes + + @pytest.mark.asyncio + async def test_silence_emission(self): + """Test that recv emits silence when buffer is empty.""" + track = AudioStreamTrack(sample_rate=48000, channels=1, format="s16") + + # Receive frame without writing any data + frame = await track.recv() + assert frame.samples == 960 # 20ms at 48kHz + + # Extract audio data and verify it's silence (all zeros) + audio_data = frame.to_ndarray() + assert np.all(audio_data == 0) + + @pytest.mark.asyncio + async def test_partial_frame_padding(self): + """Test that partial frames are padded with silence.""" + track = AudioStreamTrack(sample_rate=48000, channels=1, format="s16") + + # Write only 10ms of data (480 samples) + samples_10ms = int(0.01 * 48000) + audio_data = np.ones(samples_10ms, dtype=np.int16) * 100 # Non-zero data + pcm = PcmData( + samples=audio_data, + sample_rate=48000, + format=AudioFormat.S16, + channels=1, + ) + + await track.write(pcm) + + # Receive 20ms frame + frame = await track.recv() + assert frame.samples == 960 # Full 20ms frame + + # Check that first half has data and second half is silence + audio_array = frame.to_ndarray() + assert np.any(audio_array[:480] != 0) # First 10ms has data + assert np.all(audio_array[480:] == 0) # Last 10ms is silence + + @pytest.mark.asyncio + async def test_flush(self): + """Test that flush clears the buffer.""" + track = AudioStreamTrack(sample_rate=48000, channels=1, format="s16") + + # Write some data + samples = np.zeros(960, dtype=np.int16) + pcm = PcmData( + samples=samples, + sample_rate=48000, + format=AudioFormat.S16, + channels=1, + ) + await track.write(pcm) + + # Verify data is in buffer + assert len(track._buffer) > 0 + + # Flush + await track.flush() + + # Verify buffer is empty + assert len(track._buffer) == 0 + + @pytest.mark.asyncio + async def test_frame_timing(self, monkeypatch): + """Test that frames are emitted at 20ms intervals.""" + track = AudioStreamTrack(sample_rate=48000, channels=1, format="s16") + + # Write enough data for multiple frames + samples = np.zeros(4800, dtype=np.int16) # 100ms of audio + pcm = PcmData( + samples=samples, + sample_rate=48000, + format=AudioFormat.S16, + channels=1, + ) + await track.write(pcm) + + # Record timing + frame_times = [] + + # Receive 3 frames and measure timing + for _ in range(3): + await track.recv() + frame_times.append(time.time()) + + # Check intervals between frames (should be ~20ms) + for i in range(1, len(frame_times)): + interval_ms = (frame_times[i] - frame_times[i - 1]) * 1000 + # Allow reasonable tolerance for asyncio.sleep() accuracy + assert 15.0 <= interval_ms <= 25.0 + + @pytest.mark.asyncio + async def test_continuous_streaming(self): + """Test continuous audio streaming scenario.""" + track = AudioStreamTrack(sample_rate=48000, channels=1, format="s16") + + # Simulate continuous writing and reading + write_task = asyncio.create_task(self._continuous_writer(track)) + read_task = asyncio.create_task(self._continuous_reader(track)) + + # Run for 100ms + await asyncio.sleep(0.1) + + # Cancel tasks + write_task.cancel() + read_task.cancel() + + try: + await write_task + await read_task + except asyncio.CancelledError: + pass + + async def _continuous_writer(self, track): + """Helper to continuously write audio data.""" + while True: + # Write 20ms chunks + samples = np.zeros(960, dtype=np.int16) + pcm = PcmData( + samples=samples, + sample_rate=48000, + format=AudioFormat.S16, + channels=1, + ) + await track.write(pcm) + await asyncio.sleep(0.02) # 20ms + + async def _continuous_reader(self, track): + """Helper to continuously read audio frames.""" + frames_received = 0 + while True: + frame = await track.recv() + frames_received += 1 + assert frame.samples == 960 + + @pytest.mark.asyncio + async def test_media_stream_error(self): + """Test that MediaStreamError is raised when track is not live.""" + track = AudioStreamTrack(sample_rate=48000, channels=1, format="s16") + + # Stop the track + track.stop() + + # Try to receive - should raise error + with pytest.raises(aiortc.mediastreams.MediaStreamError): + await track.recv()