Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
291 changes: 118 additions & 173 deletions getstream/video/rtc/audio_track.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,9 +14,10 @@

class AudioStreamTrack(aiortc.mediastreams.MediaStreamTrack):
"""
Audio stream track that accepts PcmData objects directly from a queue.
Audio stream track that accepts PcmData objects and buffers them as bytes.

Works with PcmData objects instead of raw bytes, avoiding format conversion issues.
Internally buffers as bytes for efficient memory usage.

Usage:
track = AudioStreamTrack(sample_rate=48000, channels=2)
Expand All @@ -34,7 +35,7 @@ def __init__(
sample_rate: int = 48000,
channels: int = 1,
format: str = "s16",
max_queue_size: int = 100,
audio_buffer_size_ms: int = 30000, # 30 seconds default
):
"""
Initialize an AudioStreamTrack that accepts PcmData objects.
Expand All @@ -43,92 +44,119 @@ def __init__(
sample_rate: Target sample rate in Hz (default: 48000)
channels: Number of channels - 1=mono, 2=stereo (default: 1)
format: Audio format - "s16" or "f32" (default: "s16")
max_queue_size: Maximum number of PcmData objects in queue (default: 100)
audio_buffer_size_ms: Maximum buffer size in milliseconds (default: 30000ms = 30s)
"""
super().__init__()
self.sample_rate = sample_rate
self.channels = channels
self.format = format
self.max_queue_size = max_queue_size
self.audio_buffer_size_ms = audio_buffer_size_ms

logger.debug(
"Initialized AudioStreamTrack",
extra={
"sample_rate": sample_rate,
"channels": channels,
"format": format,
"max_queue_size": max_queue_size,
"audio_buffer_size_ms": audio_buffer_size_ms,
},
)

# Create async queue for PcmData objects
self._queue = asyncio.Queue()
# Internal bytearray buffer for audio data
self._buffer = bytearray()
self._buffer_lock = asyncio.Lock()

# Timing for frame pacing
self._start = None
self._timestamp = None

# Buffer for chunks smaller than 20ms
self._buffer = None
self._last_frame_time = None

# Calculate bytes per sample based on format
self._bytes_per_sample = 2 if format == "s16" else 4 # s16=2 bytes, f32=4 bytes
self._bytes_per_frame = int(
aiortc.mediastreams.AUDIO_PTIME
* self.sample_rate
* self.channels
* self._bytes_per_sample
)

async def write(self, pcm: PcmData):
"""
Add PcmData to the queue.
Add PcmData to the buffer.

The PcmData will be automatically resampled/converted to match
the track's configured sample_rate, channels, and format.
the track's configured sample_rate, channels, and format,
then converted to bytes and stored in the buffer.

Args:
pcm: PcmData object with audio data
"""
# Check if queue is getting too large and trim if necessary
if self._queue.qsize() >= self.max_queue_size:
dropped_items = 0
while self._queue.qsize() >= self.max_queue_size:
try:
self._queue.get_nowait()
self._queue.task_done()
dropped_items += 1
except asyncio.QueueEmpty:
break

logger.warning(
"Audio queue overflow, dropped items max is %d. pcm duration %s ms",
self.max_queue_size,
pcm.duration_ms,
extra={
"dropped_items": dropped_items,
"queue_size": self._queue.qsize(),
},
# Normalize the PCM data to target format immediately
pcm_normalized = self._normalize_pcm(pcm)

# Convert to bytes
audio_bytes = pcm_normalized.to_bytes()

async with self._buffer_lock:
# Check buffer size before adding
max_buffer_bytes = int(
(self.audio_buffer_size_ms / 1000)
* self.sample_rate
* self.channels
* self._bytes_per_sample
)

await self._queue.put(pcm)
# Add new data to buffer first
self._buffer.extend(audio_bytes)

# Check if we exceeded the limit
if len(self._buffer) > max_buffer_bytes:
# Calculate how many bytes to drop from the beginning
bytes_to_drop = len(self._buffer) - max_buffer_bytes
dropped_ms = (
bytes_to_drop
/ (self.sample_rate * self.channels * self._bytes_per_sample)
) * 1000

# TODO: do not perform logging inside critical section, move this code outside the lock
logger.warning(
"Audio buffer overflow, dropping %.1fms of audio. Buffer max is %dms",
dropped_ms,
self.audio_buffer_size_ms,
extra={
"buffer_size_bytes": len(self._buffer),
"incoming_bytes": len(audio_bytes),
"dropped_bytes": bytes_to_drop,
},
)

# Drop from the beginning of the buffer to keep latest data
self._buffer = self._buffer[bytes_to_drop:]

buffer_duration_ms = (
len(self._buffer)
/ (self.sample_rate * self.channels * self._bytes_per_sample)
) * 1000

logger.debug(
"Added PcmData to queue",
"Added audio to buffer",
extra={
"pcm_samples": len(pcm.samples)
if pcm.samples.ndim == 1
else pcm.samples.shape,
"pcm_sample_rate": pcm.sample_rate,
"pcm_channels": pcm.channels,
"queue_size": self._queue.qsize(),
"pcm_duration_ms": pcm.duration_ms,
"buffer_duration_ms": buffer_duration_ms,
"buffer_size_bytes": len(self._buffer),
},
)

async def flush(self) -> None:
"""
Clear any pending audio from the queue and buffer.
Clear any pending audio from the buffer.
Playback stops immediately.
"""
cleared = 0
while not self._queue.empty():
try:
self._queue.get_nowait()
self._queue.task_done()
cleared += 1
except asyncio.QueueEmpty:
break

self._buffer = None
logger.debug("Flushed audio queue", extra={"cleared_items": cleared})
async with self._buffer_lock:
bytes_cleared = len(self._buffer)
self._buffer.clear()

logger.debug("Flushed audio buffer", extra={"cleared_bytes": bytes_cleared})

async def recv(self) -> Frame:
"""
Expand All @@ -142,23 +170,56 @@ async def recv(self) -> Frame:

# Calculate samples needed for 20ms frame
samples_per_frame = int(aiortc.mediastreams.AUDIO_PTIME * self.sample_rate)
wakeup_time = 0

# Initialize timestamp if not already done
if self._timestamp is None:
self._start = time.time()
self._timestamp = 0
self._last_frame_time = time.time()
else:
# Use timestamp-based pacing to avoid drift over time
# This ensures we stay synchronized with the expected audio rate
# even if individual frames have slight timing variations
self._timestamp += samples_per_frame
start_ts = self._start or time.time()
wait = start_ts + (self._timestamp / self.sample_rate) - time.time()
if wait > 0:
wakeup_time += time.time() + wait
await asyncio.sleep(wait)
if time.time() - wakeup_time > 0.1:
logger.warning(
f"scheduled sleep took {time.time() - wakeup_time} instead of {wait}, this can happen if there is something blocking the main event-loop, you can enable --debug mode to see blocking traces"
)

# Get or accumulate PcmData to fill a 20ms frame
pcm_for_frame = await self._get_pcm_for_frame(samples_per_frame)
self._last_frame_time = time.time()

# Get 20ms of audio data from buffer
async with self._buffer_lock:
if len(self._buffer) >= self._bytes_per_frame:
# We have enough data
audio_bytes = bytes(self._buffer[: self._bytes_per_frame])
self._buffer = self._buffer[self._bytes_per_frame :]
elif len(self._buffer) > 0:
# We have some data but not enough - pad with silence
audio_bytes = bytes(self._buffer)
padding_needed = self._bytes_per_frame - len(audio_bytes)
audio_bytes += bytes(padding_needed) # Pad with zeros (silence)
self._buffer.clear()

logger.debug(
"Padded audio frame with silence",
extra={
"available_bytes": len(audio_bytes) - padding_needed,
"required_bytes": self._bytes_per_frame,
"padding_bytes": padding_needed,
},
)
else:
# No data at all - emit silence
audio_bytes = bytes(self._bytes_per_frame)

# Create AudioFrame
# Determine layout and format
layout = "stereo" if self.channels == 2 else "mono"

# Convert format name: "s16" -> "s16", "f32" -> "flt"
Expand All @@ -172,20 +233,7 @@ async def recv(self) -> Frame:
frame = AudioFrame(format=av_format, layout=layout, samples=samples_per_frame)

# Fill frame with data
if pcm_for_frame is not None:
audio_bytes = pcm_for_frame.to_bytes()

# Write to the single plane (packed format has 1 plane)
if len(audio_bytes) >= frame.planes[0].buffer_size:
frame.planes[0].update(audio_bytes[: frame.planes[0].buffer_size])
else:
# Pad with silence if not enough data
padding = bytes(frame.planes[0].buffer_size - len(audio_bytes))
frame.planes[0].update(audio_bytes + padding)
else:
# No data available, return silence
for plane in frame.planes:
plane.update(bytes(plane.buffer_size))
frame.planes[0].update(audio_bytes)

# Set frame properties
frame.pts = self._timestamp
Expand All @@ -194,108 +242,6 @@ async def recv(self) -> Frame:

return frame

async def _get_pcm_for_frame(self, samples_needed: int) -> PcmData | None:
"""
Get or accumulate PcmData to fill exactly samples_needed samples.

This method handles:
- Buffering partial chunks
- Resampling to target sample rate
- Converting to target channels
- Converting to target format
- Chunking to exact frame size

Args:
samples_needed: Number of samples needed for the frame

Returns:
PcmData with exactly samples_needed samples, or None if no data available
"""
# Start with buffered data if any
if self._buffer is not None:
pcm_accumulated = self._buffer
self._buffer = None
else:
pcm_accumulated = None

# Try to get data from queue
try:
# Don't wait too long - if no data, return silence
while True:
# Check if we have enough samples
if pcm_accumulated is not None:
current_samples = (
len(pcm_accumulated.samples)
if pcm_accumulated.samples.ndim == 1
else pcm_accumulated.samples.shape[1]
)
if current_samples >= samples_needed:
break

# Try to get more data
if self._queue.empty():
# No more data available
break

pcm_chunk = await asyncio.wait_for(self._queue.get(), timeout=0.01)
self._queue.task_done()

# Resample/convert to target format
pcm_chunk = self._normalize_pcm(pcm_chunk)

# Accumulate
if pcm_accumulated is None:
pcm_accumulated = pcm_chunk
else:
pcm_accumulated = pcm_accumulated.append(pcm_chunk)

except asyncio.TimeoutError:
pass

# If no data at all, return None (will produce silence)
if pcm_accumulated is None:
return None

# Get the number of samples we have
current_samples = (
len(pcm_accumulated.samples)
if pcm_accumulated.samples.ndim == 1
else pcm_accumulated.samples.shape[1]
)

# If we have exactly the right amount, return it
if current_samples == samples_needed:
return pcm_accumulated

# If we have more than needed, split it
if current_samples > samples_needed:
# Calculate duration needed in seconds
duration_needed_s = samples_needed / self.sample_rate

# Use head() to get exactly what we need
pcm_for_frame = pcm_accumulated.head(
duration_s=duration_needed_s, pad=False, pad_at="end"
)

# Calculate what's left in seconds
duration_used_s = (
len(pcm_for_frame.samples)
if pcm_for_frame.samples.ndim == 1
else pcm_for_frame.samples.shape[1]
) / self.sample_rate

# Buffer the rest
self._buffer = pcm_accumulated.tail(
duration_s=pcm_accumulated.duration - duration_used_s,
pad=False,
pad_at="start",
)

return pcm_for_frame

# If we have less than needed, return what we have (will be padded with silence)
return pcm_accumulated

def _normalize_pcm(self, pcm: PcmData) -> PcmData:
"""
Normalize PcmData to match the track's target format.
Expand All @@ -306,9 +252,8 @@ def _normalize_pcm(self, pcm: PcmData) -> PcmData:
Returns:
PcmData resampled/converted to target sample_rate, channels, and format
"""
# Resample to target sample rate and channels if needed
if pcm.sample_rate != self.sample_rate or pcm.channels != self.channels:
pcm = pcm.resample(self.sample_rate, target_channels=self.channels)

pcm = pcm.resample(self.sample_rate, target_channels=self.channels)

# Convert format if needed
if self.format == "s16" and pcm.format != "s16":
Expand Down
Loading