diff --git a/getstream/video/rtc/connection_manager.py b/getstream/video/rtc/connection_manager.py index 927d6603..693b3f12 100644 --- a/getstream/video/rtc/connection_manager.py +++ b/getstream/video/rtc/connection_manager.py @@ -176,7 +176,7 @@ async def _on_subscriber_offer(self, event: events_pb2.SubscriberOffer): # Set the local description. aiortc will manage the SDP content. with telemetry.start_as_current_span( "rtc.on_subscriber_offer.set_local_description" - ): + ) as span: await self.subscriber_pc.setLocalDescription(answer) logger.debug( diff --git a/getstream/video/rtc/pc.py b/getstream/video/rtc/pc.py index 674cf459..1dad2984 100644 --- a/getstream/video/rtc/pc.py +++ b/getstream/video/rtc/pc.py @@ -5,6 +5,7 @@ import aiortc from aiortc.contrib.media import MediaRelay +from getstream.common import telemetry from getstream.video.rtc.track_util import AudioTrackHandler from pyee.asyncio import AsyncIOEventEmitter from aiortc.rtcrtpsender import RTCRtpSender @@ -84,10 +85,12 @@ async def handle_answer(self, response): type="answer", sdp=response.sdp ) - await self.setRemoteDescription(remote_description) - logger.debug( - f"Publisher remote description set successfully. {self.localDescription}" - ) + with telemetry.start_as_current_span("publisher.pc.handle_answer") as span: + span.set_attribute("remoteDescription", remote_description.sdp) + await self.setRemoteDescription(remote_description) + logger.debug( + f"Publisher remote description set successfully. {self.localDescription}" + ) async def wait_for_connected(self, timeout: float = 15.0): # If already connected, return immediately diff --git a/getstream/video/rtc/peer_connection.py b/getstream/video/rtc/peer_connection.py index 193eb7e5..cb8be111 100644 --- a/getstream/video/rtc/peer_connection.py +++ b/getstream/video/rtc/peer_connection.py @@ -103,12 +103,16 @@ async def add_tracks( self.publisher_pc.addTrack(relayed_video) logger.info(f"Added relayed video track {relayed_video.id}") - with telemetry.start_as_current_span("rtc.publisher_pc.create_offer"): + with telemetry.start_as_current_span( + "rtc.publisher_pc.create_offer" + ) as span: offer = await self.publisher_pc.createOffer() + span.set_attribute("sdp", offer.sdp) with telemetry.start_as_current_span( "rtc.publisher_pc.set_local_description" - ): + ) as span: + span.set_attribute("sdp", offer.sdp) await self.publisher_pc.setLocalDescription(offer) try: diff --git a/getstream/video/rtc/track_util.py b/getstream/video/rtc/track_util.py index 47ce54d7..f90fe14e 100644 --- a/getstream/video/rtc/track_util.py +++ b/getstream/video/rtc/track_util.py @@ -1,6 +1,8 @@ import asyncio +import copy import io import wave +from enum import Enum import av import numpy as np @@ -8,12 +10,12 @@ from typing import ( Dict, Any, - NamedTuple, Callable, Optional, Union, Iterator, AsyncIterator, + Literal, ) import logging @@ -25,26 +27,146 @@ logger = logging.getLogger(__name__) -class PcmData(NamedTuple): +class AudioFormat(str, Enum): """ - A named tuple representing PCM audio data. + Audio format constants for PCM data. + + Inherits from str to maintain backward compatibility with string-based APIs. Attributes: - format: The format of the audio data. + S16: Signed 16-bit integer format (range: -32768 to 32767) + F32: 32-bit floating point format (range: -1.0 to 1.0) + + Example: + >>> from getstream.video.rtc.track_util import AudioFormat, PcmData + >>> import numpy as np + >>> pcm = PcmData(samples=np.array([1, 2], np.int16), sample_rate=16000, format=AudioFormat.S16) + >>> pcm.format == AudioFormat.S16 + True + >>> pcm.format == "s16" # Can compare with string directly + True + """ + + S16 = "s16" # Signed 16-bit integer + F32 = "f32" # 32-bit float + + @staticmethod + def validate(fmt: str) -> str: + """ + Validate that a format string is one of the supported audio formats. + + Args: + fmt: Format string to validate + + Returns: + The validated format string + + Raises: + ValueError: If format is not supported + + Example: + >>> AudioFormat.validate("s16") + 's16' + >>> AudioFormat.validate("invalid") # doctest: +ELLIPSIS + Traceback (most recent call last): + ... + ValueError: Invalid audio format: 'invalid'. Must be one of: ... + """ + valid_formats = {f.value for f in AudioFormat} + if fmt not in valid_formats: + raise ValueError( + f"Invalid audio format: {fmt!r}. Must be one of: {', '.join(sorted(valid_formats))}" + ) + return fmt + + +# Type alias for audio format parameters +# Accepts both AudioFormat enum members and string literals for backwards compatibility +AudioFormatType = Union[AudioFormat, Literal["s16", "f32"]] + + +class PcmData: + """ + A class representing PCM audio data. + + Attributes: + format: The format of the audio data (use AudioFormat.S16 or AudioFormat.F32) sample_rate: The sample rate of the audio data. samples: The audio samples as a numpy array. pts: The presentation timestamp of the audio data. dts: The decode timestamp of the audio data. time_base: The time base for converting timestamps to seconds. + channels: Number of audio channels (1=mono, 2=stereo) """ - format: str - sample_rate: int - samples: NDArray = np.array([], dtype=np.int16) - pts: Optional[int] = None # Presentation timestamp - dts: Optional[int] = None # Decode timestamp - time_base: Optional[float] = None # Time base for converting timestamps to seconds - channels: int = 1 # Number of channels (1=mono, 2=stereo) + def __init__( + self, + sample_rate: int, + format: AudioFormatType, + samples: NDArray = None, + pts: Optional[int] = None, + dts: Optional[int] = None, + time_base: Optional[float] = None, + channels: int = 1, + ): + """ + Initialize PcmData. + + Args: + sample_rate: The sample rate of the audio data + format: The format of the audio data (use AudioFormat.S16 or AudioFormat.F32) + samples: The audio samples as a numpy array (default: empty array with dtype matching format) + pts: The presentation timestamp of the audio data + dts: The decode timestamp of the audio data + time_base: The time base for converting timestamps to seconds + channels: Number of audio channels (1=mono, 2=stereo) + + Raises: + TypeError: If samples dtype does not match the declared format + """ + self.format: AudioFormatType = format + self.sample_rate: int = sample_rate + + # Determine default dtype based on format when samples is None + if samples is None: + # Use same pattern as clear() method for consistency + fmt = (format or "").lower() + if fmt in ("f32", "float32"): + dtype = np.float32 + else: + dtype = np.int16 + samples = np.array([], dtype=dtype) + + # Strict validation: ensure samples dtype matches format + if isinstance(samples, np.ndarray): + fmt = (format or "").lower() + expected_dtype = np.float32 if fmt in ("f32", "float32") else np.int16 + + if samples.dtype != expected_dtype: + # Provide helpful error message with conversion suggestions + actual_dtype_name = ( + "float32" + if samples.dtype == np.float32 + else "int16" + if samples.dtype == np.int16 + else str(samples.dtype) + ) + expected_dtype_name = ( + "float32" if expected_dtype == np.float32 else "int16" + ) + + raise TypeError( + f"Dtype mismatch: format='{format}' requires samples with dtype={expected_dtype_name}, " + f"but got dtype={actual_dtype_name}. " + f"To fix: use .to_float32() for f32 format, or ensure samples match the declared format. " + f"For automatic conversion, use PcmData.from_data() instead." + ) + + self.samples: NDArray = samples + self.pts: Optional[int] = pts + self.dts: Optional[int] = dts + self.time_base: Optional[float] = time_base + self.channels: int = channels @property def stereo(self) -> bool: @@ -131,18 +253,26 @@ def from_bytes( cls, audio_bytes: bytes, sample_rate: int = 16000, - format: str = "s16", + format: AudioFormatType = AudioFormat.S16, channels: int = 1, ) -> "PcmData": """Build from raw PCM bytes (interleaved). + Args: + audio_bytes: Raw PCM audio bytes + sample_rate: Sample rate in Hz (default: 16000) + format: Audio format (default: AudioFormat.S16) + channels: Number of channels (default: 1 for mono) + Example: >>> import numpy as np >>> b = np.array([1, -1, 2, -2], dtype=np.int16).tobytes() - >>> pcm = PcmData.from_bytes(b, sample_rate=16000, format="s16", channels=2) + >>> pcm = PcmData.from_bytes(b, sample_rate=16000, format=AudioFormat.S16, channels=2) >>> pcm.samples.shape[0] # channels-first 2 """ + # Validate format + AudioFormat.validate(format) # Determine dtype and bytes per sample dtype: Any width: int @@ -200,16 +330,24 @@ def from_data( cls, data: Union[bytes, bytearray, memoryview, NDArray], sample_rate: int = 16000, - format: str = "s16", + format: AudioFormatType = AudioFormat.S16, channels: int = 1, ) -> "PcmData": """Build from bytes or numpy arrays. + Args: + data: Input audio data (bytes or numpy array) + sample_rate: Sample rate in Hz (default: 16000) + format: Audio format (default: AudioFormat.S16) + channels: Number of channels (default: 1 for mono) + Example: >>> import numpy as np - >>> PcmData.from_data(np.array([1, 2], np.int16), sample_rate=16000, format="s16", channels=1).channels + >>> PcmData.from_data(np.array([1, 2], np.int16), sample_rate=16000, format=AudioFormat.S16, channels=1).channels 1 """ + # Validate format + AudioFormat.validate(format) if isinstance(data, (bytes, bytearray, memoryview)): return cls.from_bytes( bytes(data), sample_rate=sample_rate, format=format, channels=channels @@ -280,10 +418,19 @@ def resample( return self # Prepare ndarray shape for AV input frame. - # Use planar input (s16p) with shape (channels, samples). + # Use planar format matching the input data type. in_layout = "mono" if self.channels == 1 else "stereo" cmaj = self.samples + + # Determine the format based on the input dtype if isinstance(cmaj, np.ndarray): + if cmaj.dtype == np.float32 or ( + self.format and self.format.lower() in ("f32", "float32") + ): + input_format = "fltp" # planar float32 + else: + input_format = "s16p" # planar int16 + if cmaj.ndim == 1: # (samples,) -> (channels, samples) if self.channels > 1: @@ -308,15 +455,23 @@ def resample( # Likely (samples, channels) cmaj = cmaj.T cmaj = np.ascontiguousarray(cmaj) - frame = av.AudioFrame.from_ndarray(cmaj, format="s16p", layout=in_layout) + else: + input_format = "s16p" # default to s16p for non-ndarray + + frame = av.AudioFrame.from_ndarray(cmaj, format=input_format, layout=in_layout) frame.sample_rate = self.sample_rate # Use provided resampler or create a new one if resampler is None: # Create new resampler for one-off use out_layout = "mono" if target_channels == 1 else "stereo" + # Keep the same format as input (convert planar to packed for output) + if input_format == "fltp": + output_format = "flt" # packed float32 + else: + output_format = "s16" # packed int16 resampler = av.AudioResampler( - format="s16", layout=out_layout, rate=target_sample_rate + format=output_format, layout=out_layout, rate=target_sample_rate ) # Resample the frame @@ -386,17 +541,35 @@ def resample( ): resampled_samples = resampled_samples.flatten() - # Ensure int16 dtype for s16 - if ( - isinstance(resampled_samples, np.ndarray) - and resampled_samples.dtype != np.int16 - ): - resampled_samples = resampled_samples.astype(np.int16) + # Determine output format based on input format + output_pcm_format = ( + AudioFormat.F32 if input_format == "fltp" else AudioFormat.S16 + ) + + # Ensure correct dtype matches the output format + if isinstance(resampled_samples, np.ndarray): + if output_pcm_format == "s16" and resampled_samples.dtype != np.int16: + # Convert to int16 for s16 format + if resampled_samples.dtype == np.float32: + # Float32 to int16: clip to [-1.0, 1.0], scale, round, convert + # This prevents overflow and ensures proper scaling + max_int16 = np.iinfo(np.int16).max # 32767 + resampled_samples = np.clip(resampled_samples, -1.0, 1.0) + resampled_samples = np.round( + resampled_samples * max_int16 + ).astype(np.int16) + else: + resampled_samples = resampled_samples.astype(np.int16) + elif ( + output_pcm_format == "f32" and resampled_samples.dtype != np.float32 + ): + # Ensure float32 for f32 format + resampled_samples = resampled_samples.astype(np.float32) return PcmData( - samples=resampled_samples, sample_rate=target_sample_rate, - format="s16", + format=output_pcm_format, + samples=resampled_samples, pts=self.pts, dts=self.dts, time_base=self.time_base, @@ -441,11 +614,7 @@ def to_bytes(self) -> bytes: # Fallback if isinstance(arr, (bytes, bytearray)): return bytes(arr) - try: - return bytes(arr) - except Exception: - logger.warning("Cannot convert samples to bytes; returning empty") - return b"" + return bytes(arr) def to_wav_bytes(self) -> bytes: """Return WAV bytes (header + frames). @@ -466,9 +635,9 @@ def to_wav_bytes(self) -> bytes: arr = arr.astype(np.float32) arr = (np.clip(arr, -1.0, 1.0) * 32767.0).astype(np.int16) frames = PcmData( - samples=arr, sample_rate=self.sample_rate, format="s16", + samples=arr, pts=self.pts, dts=self.dts, time_base=self.time_base, @@ -492,12 +661,27 @@ def to_wav_bytes(self) -> bytes: def to_float32(self) -> "PcmData": """Convert samples to float32 in [-1, 1]. + If the audio is already in f32 format, returns self without modification. + Example: >>> import numpy as np - >>> pcm = PcmData(samples=np.array([0, 1], np.int16), sample_rate=16000, format="s16", channels=1) + >>> pcm = PcmData(samples=np.array([0, 1], np.int16), sample_rate=16000, format=AudioFormat.S16, channels=1) >>> pcm.to_float32().samples.dtype == np.float32 True + >>> # Already f32 - returns self + >>> pcm_f32 = PcmData(samples=np.array([0.5], np.float32), sample_rate=16000, format=AudioFormat.F32) + >>> pcm_f32.to_float32() is pcm_f32 + True """ + # If already f32 format, return self without modification + if self.format in (AudioFormat.F32, "f32", "float32"): + # Additional check: verify the samples are actually float32 + if ( + isinstance(self.samples, np.ndarray) + and self.samples.dtype == np.float32 + ): + return self + arr = self.samples # Normalize to a numpy array for conversion @@ -530,9 +714,9 @@ def to_float32(self) -> "PcmData": arr_f32 = arr.astype(np.float32, copy=False) return PcmData( - samples=arr_f32, sample_rate=self.sample_rate, format="f32", + samples=arr_f32, pts=self.pts, dts=self.dts, time_base=self.time_base, @@ -540,13 +724,16 @@ def to_float32(self) -> "PcmData": ) def append(self, other: "PcmData") -> "PcmData": - """Append another chunk after adjusting it to match self. + """Append another chunk in-place after adjusting it to match self. + + Modifies this PcmData object and returns self for chaining. Example: >>> import numpy as np - >>> a = PcmData(samples=np.array([1, 2], np.int16), sample_rate=16000, format="s16", channels=1) - >>> b = PcmData(samples=np.array([3, 4], np.int16), sample_rate=16000, format="s16", channels=1) - >>> a.append(b).samples.tolist() + >>> a = PcmData(sample_rate=16000, format="s16", samples=np.array([1, 2], np.int16), channels=1) + >>> b = PcmData(sample_rate=16000, format="s16", samples=np.array([3, 4], np.int16), channels=1) + >>> _ = a.append(b) # modifies a in-place + >>> a.samples.tolist() [1, 2, 3, 4] """ @@ -593,9 +780,9 @@ def _ensure_ndarray(pcm: "PcmData") -> np.ndarray: else: arr = arr.astype(np.int16) other_adj = PcmData( - samples=arr, sample_rate=other_adj.sample_rate, format="s16", + samples=arr, pts=other_adj.pts, dts=other_adj.dts, time_base=other_adj.time_base, @@ -614,7 +801,7 @@ def _ensure_ndarray(pcm: "PcmData") -> np.ndarray: self_arr = _ensure_ndarray(self) other_arr = _ensure_ndarray(other_adj) - # If either is empty, return the other while preserving self's metadata + # If self is empty, replace with other while preserving self's metadata if _is_empty(self_arr): # Conform shape to target channels semantics and dtype if isinstance(other_arr, np.ndarray): @@ -626,15 +813,8 @@ def _ensure_ndarray(pcm: "PcmData") -> np.ndarray: else np.int16 ) other_arr = other_arr.astype(target_dtype, copy=False) - return PcmData( - samples=other_arr, - sample_rate=self.sample_rate, - format=self.format, - pts=self.pts, - dts=self.dts, - time_base=self.time_base, - channels=self.channels, - ) + self.samples = other_arr + return self if _is_empty(other_arr): return self @@ -659,15 +839,8 @@ def _ensure_ndarray(pcm: "PcmData") -> np.ndarray: "int16", ) and out.dtype != np.int16: out = out.astype(np.int16) - return PcmData( - samples=out, - sample_rate=self.sample_rate, - format=self.format, - pts=self.pts, - dts=self.dts, - time_base=self.time_base, - channels=self.channels, - ) + self.samples = out + return self else: # Multi-channel: normalize to (channels, samples) def _to_cmaj(arr: np.ndarray, channels: int) -> np.ndarray: @@ -696,15 +869,63 @@ def _to_cmaj(arr: np.ndarray, channels: int) -> np.ndarray: ) and out.dtype != np.int16: out = out.astype(np.int16) - return PcmData( - samples=out, - sample_rate=self.sample_rate, - format=self.format, - pts=self.pts, - dts=self.dts, - time_base=self.time_base, - channels=self.channels, - ) + self.samples = out + return self + + def copy(self) -> "PcmData": + """Create a deep copy of this PcmData object. + + Returns a new PcmData instance with copied samples and metadata, + allowing independent modifications without affecting the original. + + Example: + >>> import numpy as np + >>> a = PcmData(sample_rate=16000, format="s16", samples=np.array([1, 2], np.int16), channels=1) + >>> b = a.copy() + >>> _ = b.append(PcmData(sample_rate=16000, format="s16", samples=np.array([3, 4], np.int16), channels=1)) + >>> a.samples.tolist() # original unchanged + [1, 2] + >>> b.samples.tolist() # copy was modified + [1, 2, 3, 4] + """ + return PcmData( + sample_rate=self.sample_rate, + format=self.format, + samples=self.samples.copy() + if isinstance(self.samples, np.ndarray) + else copy.deepcopy(self.samples), + pts=self.pts, + dts=self.dts, + time_base=self.time_base, + channels=self.channels, + ) + + def clear(self) -> None: + """Clear all samples in this PcmData object in-place. + + Similar to list.clear() or dict.clear(), this method removes all samples + from the PcmData object while preserving all other metadata (sample_rate, + format, channels, timestamps, etc.). + + Returns None to match the behavior of standard Python collection methods. + + Example: + >>> import numpy as np + >>> a = PcmData(sample_rate=16000, format="s16", samples=np.array([1, 2, 3], np.int16), channels=1) + >>> a.clear() + >>> len(a.samples) + 0 + >>> a.sample_rate + 16000 + """ + # Determine the appropriate dtype based on format + fmt = (self.format or "").lower() + if fmt in ("f32", "float32"): + dtype = np.float32 + else: + dtype = np.int16 + + self.samples = np.array([], dtype=dtype) @classmethod def from_response( @@ -713,14 +934,22 @@ def from_response( *, sample_rate: int = 16000, channels: int = 1, - format: str = "s16", + format: AudioFormatType = AudioFormat.S16, ) -> Union["PcmData", Iterator["PcmData"], AsyncIterator["PcmData"]]: """Normalize provider response to PcmData or iterators of it. + Args: + response: Audio response (bytes, iterator, or async iterator) + sample_rate: Sample rate in Hz (default: 16000) + channels: Number of channels (default: 1) + format: Audio format (default: AudioFormat.S16) + Example: - >>> PcmData.from_response(bytes([0, 0]), sample_rate=16000, format="s16").sample_rate + >>> PcmData.from_response(bytes([0, 0]), sample_rate=16000, format=AudioFormat.S16).sample_rate 16000 """ + # Validate format + AudioFormat.validate(format) # bytes-like returns a single PcmData if isinstance(response, (bytes, bytearray, memoryview)): @@ -829,6 +1058,370 @@ def _gen(): f"Unsupported response type for PcmData.from_response: {type(response)}" ) + def chunks( + self, chunk_size: int, overlap: int = 0, pad_last: bool = False + ) -> Iterator["PcmData"]: + """ + Iterate over fixed-size chunks of audio data. + + Args: + chunk_size: Number of samples per chunk + overlap: Number of samples to overlap between chunks (for windowing) + pad_last: If True, pad the last chunk with zeros to match chunk_size + + Yields: + PcmData objects with chunk_size samples each + + Example: + >>> pcm = PcmData(samples=np.arange(10, dtype=np.int16), sample_rate=16000, format="s16") + >>> chunks = list(pcm.chunks(4, overlap=2)) + >>> len(chunks) # [0:4], [2:6], [4:8], [6:10], [8:10] + 5 + """ + # Ensure we have a 1D array for simpler chunking + if isinstance(self.samples, np.ndarray): + if self.samples.ndim == 2 and self.channels == 1: + samples = self.samples.flatten() + elif self.samples.ndim == 2: + # For multi-channel, work with channel-major format + samples = self.samples + else: + samples = self.samples + else: + # Convert bytes/other to ndarray first + temp = PcmData.from_bytes( + self.to_bytes(), + sample_rate=self.sample_rate, + format=self.format, + channels=self.channels, + ) + samples = temp.samples + + # Handle overlap + step = max(1, chunk_size - overlap) + + if self.channels > 1 and isinstance(samples, np.ndarray) and samples.ndim == 2: + # Multi-channel case: chunk along the samples axis + num_samples = samples.shape[1] + for i in range(0, num_samples, step): + end_idx = min(i + chunk_size, num_samples) + chunk_samples = samples[:, i:end_idx] + + # Check if we need to pad + if chunk_samples.shape[1] < chunk_size: + if pad_last and chunk_samples.shape[1] > 0: + pad_width = chunk_size - chunk_samples.shape[1] + chunk_samples = np.pad( + chunk_samples, + ((0, 0), (0, pad_width)), + mode="constant", + constant_values=0, + ) + elif chunk_samples.shape[1] == 0: + break + elif not pad_last: + # Yield incomplete chunk if it has samples + pass + + # Calculate timestamp for this chunk + chunk_pts = None + if self.pts is not None and self.time_base is not None: + chunk_pts = self.pts + int(i / self.sample_rate / self.time_base) + + yield PcmData( + sample_rate=self.sample_rate, + format=self.format, + samples=chunk_samples, + pts=chunk_pts, + dts=self.dts, + time_base=self.time_base, + channels=self.channels, + ) + else: + # Mono or 1D case + samples_1d = ( + samples.flatten() if isinstance(samples, np.ndarray) else samples + ) + total_samples = len(samples_1d) + + for i in range(0, total_samples, step): + end_idx = min(i + chunk_size, total_samples) + chunk_samples = samples_1d[i:end_idx] + + # Check if we need to pad + if len(chunk_samples) < chunk_size: + if pad_last and len(chunk_samples) > 0: + chunk_samples = np.pad( + chunk_samples, + (0, chunk_size - len(chunk_samples)), + mode="constant", + constant_values=0, + ) + elif len(chunk_samples) == 0: + break + elif not pad_last: + # Yield incomplete chunk if it has samples + pass + + # Calculate timestamp for this chunk + chunk_pts = None + if self.pts is not None and self.time_base is not None: + chunk_pts = self.pts + int(i / self.sample_rate / self.time_base) + + yield PcmData( + sample_rate=self.sample_rate, + format=self.format, + samples=chunk_samples, + pts=chunk_pts, + dts=self.dts, + time_base=self.time_base, + channels=1 + if isinstance(chunk_samples, np.ndarray) and chunk_samples.ndim == 1 + else self.channels, + ) + + def sliding_window( + self, window_size_ms: float, hop_ms: float, pad_last: bool = False + ) -> Iterator["PcmData"]: + """ + Generate sliding windows for analysis (useful for feature extraction). + + Args: + window_size_ms: Window size in milliseconds + hop_ms: Hop size in milliseconds + pad_last: If True, pad the last window with zeros + + Yields: + PcmData windows of the specified size + + Example: + >>> pcm = PcmData(samples=np.arange(800, dtype=np.int16), sample_rate=16000, format="s16") + >>> windows = list(pcm.sliding_window(25.0, 10.0)) # 25ms window, 10ms hop + >>> len(windows) # 400 samples per window, 160 sample hop + 5 + """ + window_samples = int(self.sample_rate * window_size_ms / 1000) + hop_samples = int(self.sample_rate * hop_ms / 1000) + overlap = max(0, window_samples - hop_samples) + + return self.chunks(window_samples, overlap=overlap, pad_last=pad_last) + + def tail( + self, + duration_s: float, + pad: bool = False, + pad_at: str = "end", + ) -> "PcmData": + """ + Keep only the last N seconds of audio. + + Args: + duration_s: Duration in seconds + pad: If True, pad with zeros when audio is shorter than requested + pad_at: Where to add padding ('start' or 'end'), default is 'end' + + Returns: + PcmData with the last N seconds + + Example: + >>> import numpy as np + >>> # 10 seconds of audio at 16kHz = 160000 samples + >>> pcm = PcmData(samples=np.arange(160000, dtype=np.int16), sample_rate=16000, format=AudioFormat.S16) + >>> # Get last 5 seconds + >>> tail = pcm.tail(duration_s=5.0) + >>> tail.duration + 5.0 + >>> # Get last 8 seconds, pad at start if needed + >>> short_pcm = PcmData(samples=np.arange(16000, dtype=np.int16), sample_rate=16000, format=AudioFormat.S16) + >>> padded = short_pcm.tail(duration_s=8.0, pad=True, pad_at='start') + >>> padded.duration + 8.0 + """ + target_samples = int(self.sample_rate * duration_s) + + # Validate pad_at parameter + if pad_at not in ("start", "end"): + raise ValueError(f"pad_at must be 'start' or 'end', got {pad_at!r}") + + # Get samples array + samples = self.samples + if not isinstance(samples, np.ndarray): + # Convert to ndarray first + samples = PcmData.from_bytes( + self.to_bytes(), + sample_rate=self.sample_rate, + format=self.format, + channels=self.channels, + ).samples + + # Handle multi-channel audio + if samples.ndim == 2 and self.channels > 1: + # Shape is (channels, samples) + num_samples = samples.shape[1] + if num_samples >= target_samples: + # Truncate: keep last target_samples + new_samples = samples[:, -target_samples:] + elif pad: + # Pad to reach target_samples + pad_width = target_samples - num_samples + if pad_at == "start": + new_samples = np.pad( + samples, + ((0, 0), (pad_width, 0)), + mode="constant", + constant_values=0, + ) + else: # pad_at == 'end' + new_samples = np.pad( + samples, + ((0, 0), (0, pad_width)), + mode="constant", + constant_values=0, + ) + else: + # Return as-is (shorter than requested, no padding) + new_samples = samples + else: + # Mono or 1D case + samples_1d = samples.flatten() if samples.ndim > 1 else samples + num_samples = len(samples_1d) + + if num_samples >= target_samples: + # Truncate: keep last target_samples + new_samples = samples_1d[-target_samples:] + elif pad: + # Pad to reach target_samples + pad_width = target_samples - num_samples + if pad_at == "start": + new_samples = np.pad( + samples_1d, (pad_width, 0), mode="constant", constant_values=0 + ) + else: # pad_at == 'end' + new_samples = np.pad( + samples_1d, (0, pad_width), mode="constant", constant_values=0 + ) + else: + # Return as-is (shorter than requested, no padding) + new_samples = samples_1d + + return PcmData( + sample_rate=self.sample_rate, + format=self.format, + samples=new_samples, + pts=self.pts, + dts=self.dts, + time_base=self.time_base, + channels=self.channels, + ) + + def head( + self, + duration_s: float, + pad: bool = False, + pad_at: str = "end", + ) -> "PcmData": + """ + Keep only the first N seconds of audio. + + Args: + duration_s: Duration in seconds + pad: If True, pad with zeros when audio is shorter than requested + pad_at: Where to add padding ('start' or 'end'), default is 'end' + + Returns: + PcmData with the first N seconds + + Example: + >>> import numpy as np + >>> # 10 seconds of audio at 16kHz = 160000 samples + >>> pcm = PcmData(samples=np.arange(160000, dtype=np.int16), sample_rate=16000, format=AudioFormat.S16) + >>> # Get first 3 seconds + >>> head = pcm.head(duration_s=3.0) + >>> head.duration + 3.0 + >>> # Get first 8 seconds, pad at end if needed + >>> short_pcm = PcmData(samples=np.arange(16000, dtype=np.int16), sample_rate=16000, format=AudioFormat.S16) + >>> padded = short_pcm.head(duration_s=8.0, pad=True, pad_at='end') + >>> padded.duration + 8.0 + """ + target_samples = int(self.sample_rate * duration_s) + + # Validate pad_at parameter + if pad_at not in ("start", "end"): + raise ValueError(f"pad_at must be 'start' or 'end', got {pad_at!r}") + + # Get samples array + samples = self.samples + if not isinstance(samples, np.ndarray): + # Convert to ndarray first + samples = PcmData.from_bytes( + self.to_bytes(), + sample_rate=self.sample_rate, + format=self.format, + channels=self.channels, + ).samples + + # Handle multi-channel audio + if samples.ndim == 2 and self.channels > 1: + # Shape is (channels, samples) + num_samples = samples.shape[1] + if num_samples >= target_samples: + # Truncate: keep first target_samples + new_samples = samples[:, :target_samples] + elif pad: + # Pad to reach target_samples + pad_width = target_samples - num_samples + if pad_at == "start": + new_samples = np.pad( + samples, + ((0, 0), (pad_width, 0)), + mode="constant", + constant_values=0, + ) + else: # pad_at == 'end' + new_samples = np.pad( + samples, + ((0, 0), (0, pad_width)), + mode="constant", + constant_values=0, + ) + else: + # Return as-is (shorter than requested, no padding) + new_samples = samples + else: + # Mono or 1D case + samples_1d = samples.flatten() if samples.ndim > 1 else samples + num_samples = len(samples_1d) + + if num_samples >= target_samples: + # Truncate: keep first target_samples + new_samples = samples_1d[:target_samples] + elif pad: + # Pad to reach target_samples + pad_width = target_samples - num_samples + if pad_at == "start": + new_samples = np.pad( + samples_1d, (pad_width, 0), mode="constant", constant_values=0 + ) + else: # pad_at == 'end' + new_samples = np.pad( + samples_1d, (0, pad_width), mode="constant", constant_values=0 + ) + else: + # Return as-is (shorter than requested, no padding) + new_samples = samples_1d + + return PcmData( + sample_rate=self.sample_rate, + format=self.format, + samples=new_samples, + pts=self.pts, + dts=self.dts, + time_base=self.time_base, + channels=self.channels, + ) + def patch_sdp_offer(sdp: str) -> str: """ @@ -1135,6 +1728,45 @@ async def detect_video_properties( logger.error(f"Error cleaning up buffered track: {e}") +def _normalize_audio_format( + pcm: PcmData, target_sample_rate: int, target_format: AudioFormatType +) -> PcmData: + """ + Helper function to normalize audio to target sample rate and format. + + Args: + pcm: Input audio data + target_sample_rate: Target sample rate + target_format: Target format (AudioFormat.S16 or AudioFormat.F32) + + Returns: + PcmData with target sample rate and format + """ + # Validate format + AudioFormat.validate(target_format) + # Resample if needed + if pcm.sample_rate != target_sample_rate: + pcm = pcm.resample(target_sample_rate) + + # Convert format if needed + if target_format == "f32" and pcm.format != "f32": + pcm = pcm.to_float32() + elif target_format == "s16" and pcm.format != "s16": + # Convert to s16 if needed + if pcm.format == "f32": + samples = pcm.samples + if isinstance(samples, np.ndarray): + samples = (np.clip(samples, -1.0, 1.0) * 32767.0).astype(np.int16) + pcm = PcmData( + sample_rate=pcm.sample_rate, + format="s16", + samples=samples, + channels=pcm.channels, + ) + + return pcm + + class AudioTrackHandler: """ A helper to receive raw PCM data from an aiortc AudioStreamTrack @@ -1240,8 +1872,8 @@ async def _run_track(self): self._on_audio_frame( PcmData( sample_rate=48_000, - samples=pcm_ndarray, format="s16", + samples=pcm_ndarray, pts=pts, dts=dts, time_base=time_base, diff --git a/tests/rtc/test_pcm_data.py b/tests/rtc/test_pcm_data.py index fa436f83..16be74ae 100644 --- a/tests/rtc/test_pcm_data.py +++ b/tests/rtc/test_pcm_data.py @@ -1,4 +1,5 @@ import numpy as np +import pytest from getstream.video.rtc.track_util import PcmData @@ -376,3 +377,620 @@ def test_append_empty_buffer_float32_adjusts_other_and_keeps_meta(): else: expected_samples = expected_pcm.samples assert np.allclose(out.samples[-expected_samples.shape[0] :], expected_samples) + + +def test_copy_creates_independent_copy(): + """Test that copy() creates a deep copy that can be modified independently.""" + sr = 16000 + original_samples = np.array([1, 2, 3, 4], dtype=np.int16) + pcm_original = PcmData( + sample_rate=sr, format="s16", samples=original_samples.copy(), channels=1 + ) + + # Create a copy + pcm_copy = pcm_original.copy() + + # Verify initial state - both should have same values but different objects + assert pcm_copy.sample_rate == pcm_original.sample_rate + assert pcm_copy.format == pcm_original.format + assert pcm_copy.channels == pcm_original.channels + assert np.array_equal(pcm_copy.samples, pcm_original.samples) + assert pcm_copy.samples is not pcm_original.samples # Different array objects + + # Modify the copy + pcm_copy.samples[0] = 999 + + # Original should be unchanged + assert pcm_original.samples[0] == 1 + assert pcm_copy.samples[0] == 999 + + +def test_append_modifies_in_place(): + """Test that append() modifies the original PcmData in-place and returns self.""" + sr = 16000 + a = np.array([1, 2, 3, 4], dtype=np.int16) + b = np.array([5, 6], dtype=np.int16) + + pcm_a = PcmData(sample_rate=sr, format="s16", samples=a, channels=1) + pcm_b = PcmData(sample_rate=sr, format="s16", samples=b, channels=1) + + # Store the id to verify it's the same object + original_id = id(pcm_a) + + # Append and check that it returns self + result = pcm_a.append(pcm_b) + + assert id(result) == original_id # Same object + assert result is pcm_a # Identity check + + # Check that pcm_a was modified in-place + expected = np.array([1, 2, 3, 4, 5, 6], dtype=np.int16) + assert np.array_equal(pcm_a.samples, expected) + assert np.array_equal(result.samples, expected) + + +def test_copy_append_does_not_modify_original(): + """Test the critical use case: pcm.copy().append(other) doesn't modify pcm.""" + sr = 16000 + a = np.array([1, 2, 3, 4], dtype=np.int16) + b = np.array([5, 6], dtype=np.int16) + + pcm_a = PcmData(sample_rate=sr, format="s16", samples=a, channels=1) + pcm_b = PcmData(sample_rate=sr, format="s16", samples=b, channels=1) + + # Create a copy and append to it + pcm_copy = pcm_a.copy().append(pcm_b) + + # Original should be unchanged + assert np.array_equal(pcm_a.samples, np.array([1, 2, 3, 4], dtype=np.int16)) + + # Copy should have the appended data + assert np.array_equal( + pcm_copy.samples, np.array([1, 2, 3, 4, 5, 6], dtype=np.int16) + ) + + # They should be different objects + assert pcm_a is not pcm_copy + assert pcm_a.samples is not pcm_copy.samples + + +def test_copy_preserves_all_metadata(): + """Test that copy() preserves all metadata including timestamps.""" + sr = 16000 + samples = np.array([1, 2, 3], dtype=np.int16) + pcm_original = PcmData( + sample_rate=sr, + format="s16", + samples=samples, + channels=2, + pts=12345, + dts=67890, + time_base=0.00001, + ) + + pcm_copy = pcm_original.copy() + + # Verify all metadata is copied + assert pcm_copy.sample_rate == pcm_original.sample_rate + assert pcm_copy.format == pcm_original.format + assert pcm_copy.channels == pcm_original.channels + assert pcm_copy.pts == pcm_original.pts + assert pcm_copy.dts == pcm_original.dts + assert pcm_copy.time_base == pcm_original.time_base + assert np.array_equal(pcm_copy.samples, pcm_original.samples) + + +def test_append_chaining_with_copy(): + """Test that append() can be chained and works correctly with copy().""" + sr = 16000 + a = np.array([1, 2], dtype=np.int16) + b = np.array([3, 4], dtype=np.int16) + c = np.array([5, 6], dtype=np.int16) + + pcm_a = PcmData(sample_rate=sr, format="s16", samples=a, channels=1) + pcm_b = PcmData(sample_rate=sr, format="s16", samples=b, channels=1) + pcm_c = PcmData(sample_rate=sr, format="s16", samples=c, channels=1) + + # Chain append on a copy + result = pcm_a.copy().append(pcm_b).append(pcm_c) + + # Original should be unchanged + assert np.array_equal(pcm_a.samples, np.array([1, 2], dtype=np.int16)) + + # Result should have all samples + assert np.array_equal(result.samples, np.array([1, 2, 3, 4, 5, 6], dtype=np.int16)) + + +# ===== Tests for clear() method (like list.clear()) ===== + + +def test_clear_wipes_samples_like_list_clear(): + """Test that clear() works like list.clear() - removes all items, returns None.""" + sr = 16000 + samples = np.array([1, 2, 3, 4, 5], dtype=np.int16) + pcm = PcmData(sample_rate=sr, format="s16", samples=samples, channels=1) + + # clear() should return None, like list.clear() + result = pcm.clear() + assert result is None + + # Samples should be empty + assert isinstance(pcm.samples, np.ndarray) + assert len(pcm.samples) == 0 + assert pcm.samples.dtype == np.int16 + + +def test_clear_preserves_metadata(): + """Test that clear() preserves all metadata like list.clear() preserves list identity.""" + sr = 16000 + samples = np.array([1, 2, 3, 4, 5], dtype=np.int16) + pcm = PcmData( + sample_rate=sr, + format="s16", + samples=samples, + channels=2, + pts=1000, + dts=2000, + time_base=0.001, + ) + + # Store original metadata and id + original_id = id(pcm) + original_sr = pcm.sample_rate + original_format = pcm.format + original_channels = pcm.channels + original_pts = pcm.pts + original_dts = pcm.dts + original_time_base = pcm.time_base + + # Clear + pcm.clear() + + # Object identity preserved (like list.clear()) + assert id(pcm) == original_id + + # Samples should be empty + assert len(pcm.samples) == 0 + + # Metadata should be preserved + assert pcm.sample_rate == original_sr + assert pcm.format == original_format + assert pcm.channels == original_channels + assert pcm.pts == original_pts + assert pcm.dts == original_dts + assert pcm.time_base == original_time_base + + +def test_clear_f32_uses_correct_dtype(): + """Test that clear() uses float32 dtype for f32 format.""" + sr = 16000 + samples = np.array([0.1, 0.2, 0.3], dtype=np.float32) + pcm = PcmData(sample_rate=sr, format="f32", samples=samples, channels=1) + + pcm.clear() + + # Samples should be empty with float32 dtype + assert isinstance(pcm.samples, np.ndarray) + assert len(pcm.samples) == 0 + assert pcm.samples.dtype == np.float32 + + +def test_clear_after_append(): + """Test that clear() can be used to clear accumulated samples.""" + sr = 16000 + a = np.array([1, 2, 3], dtype=np.int16) + b = np.array([4, 5, 6], dtype=np.int16) + + pcm_a = PcmData(sample_rate=sr, format="s16", samples=a, channels=1) + pcm_b = PcmData(sample_rate=sr, format="s16", samples=b, channels=1) + + # Append then clear + pcm_a.append(pcm_b) + assert len(pcm_a.samples) == 6 + + pcm_a.clear() + assert len(pcm_a.samples) == 0 + + # Can append again after clear + pcm_a.append(pcm_b) + assert np.array_equal(pcm_a.samples, np.array([4, 5, 6], dtype=np.int16)) + + +def test_clear_returns_none(): + """Test that clear() returns None like list.clear().""" + sr = 16000 + samples = np.array([1, 2, 3], dtype=np.int16) + + pcm = PcmData(sample_rate=sr, format="s16", samples=samples, channels=1) + result = pcm.clear() + + # Should return None like list.clear() + assert result is None + + # Samples should be empty + assert len(pcm.samples) == 0 + + +def test_clear_multiple_times(): + """Test that clear() can be called multiple times safely.""" + sr = 16000 + samples = np.array([1, 2, 3], dtype=np.int16) + pcm = PcmData(sample_rate=sr, format="s16", samples=samples, channels=1) + + # Clear multiple times + pcm.clear() + assert len(pcm.samples) == 0 + + pcm.clear() # Should work on empty PcmData + assert len(pcm.samples) == 0 + + pcm.clear() # Should work again + assert len(pcm.samples) == 0 + + # Metadata should still be intact + assert pcm.sample_rate == sr + assert pcm.format == "s16" + + +def test_resample_float32_preserves_float32_dtype(): + """ + REGRESSION TEST: Float32 audio must stay as float32 after resampling. + + Previously, resample() would silently downgrade float32 to int16, then mark + it as f32 format, completely mangling high-dynamic-range audio. + """ + # Create float32 audio with high dynamic range + sample_rate_in = 16000 + sample_rate_out = 48000 + num_samples = 1000 + + # Use values that would be truncated if converted to int16 + # E.g., 0.5 becomes 0 when converted to int16 directly + samples_f32 = np.linspace(-1.0, 1.0, num_samples, dtype=np.float32) + + pcm_16k = PcmData( + sample_rate=sample_rate_in, + format="f32", + samples=samples_f32, + channels=1, + ) + + # Resample to 48kHz + pcm_48k = pcm_16k.resample(sample_rate_out) + + # CRITICAL: Format must still be f32 + assert pcm_48k.format == "f32", f"Format should be 'f32', got '{pcm_48k.format}'" + + # CRITICAL: Samples must be float32, not int16 + assert pcm_48k.samples.dtype == np.float32, ( + f"Samples should be float32, got {pcm_48k.samples.dtype}. " + f"This means float32 audio was silently downgraded to int16!" + ) + + # Verify values are still in float range, not truncated to int16 range + assert np.any(np.abs(pcm_48k.samples) < 1.0), ( + "No fractional values found - data may have been truncated to integers" + ) + + +def test_resample_float32_to_stereo_preserves_float32(): + """Test that float32 stays float32 when resampling AND converting to stereo.""" + sample_rate_in = 16000 + sample_rate_out = 48000 + + # Create float32 mono audio + samples_f32 = np.array([0.1, 0.2, 0.3, 0.4, 0.5] * 100, dtype=np.float32) + + pcm_16k_mono = PcmData( + sample_rate=sample_rate_in, + format="f32", + samples=samples_f32, + channels=1, + ) + + # Resample to 48kHz stereo + pcm_48k_stereo = pcm_16k_mono.resample(sample_rate_out, target_channels=2) + + # Format must be f32 + assert pcm_48k_stereo.format == "f32" + + # Dtype must be float32 + assert pcm_48k_stereo.samples.dtype == np.float32, ( + f"Expected float32, got {pcm_48k_stereo.samples.dtype}" + ) + + # Should be stereo + assert pcm_48k_stereo.channels == 2 + assert pcm_48k_stereo.samples.ndim == 2 + assert pcm_48k_stereo.samples.shape[0] == 2 + + +def test_resample_int16_stays_int16(): + """Verify that int16 audio correctly stays int16 (baseline test).""" + sample_rate_in = 16000 + sample_rate_out = 48000 + + # Create int16 audio + samples_i16 = np.array([100, 200, 300, 400, 500] * 100, dtype=np.int16) + + pcm_16k = PcmData( + sample_rate=sample_rate_in, + format="s16", + samples=samples_i16, + channels=1, + ) + + # Resample to 48kHz + pcm_48k = pcm_16k.resample(sample_rate_out) + + # Format must be s16 + assert pcm_48k.format == "s16" + + assert pcm_48k.samples.dtype == np.int16, ( + f"Expected int16, got {pcm_48k.samples.dtype}" + ) + + +# ===== Tests for constructor default samples dtype ===== + + +def test_constructor_default_samples_respects_f32_format(): + """ + REGRESSION TEST: Constructor must create float32 empty array for f32 format. + + Previously, constructor always created int16 empty array regardless of format, + causing dtype mismatch when format="f32" but samples.dtype=int16. + """ + # Create PcmData with f32 format but no samples - should default to float32 empty array + pcm = PcmData(sample_rate=16000, format="f32", channels=1) + + # Verify format is f32 + assert pcm.format == "f32" + + # CRITICAL: Samples must be float32, not int16 + assert pcm.samples.dtype == np.float32, ( + f"Expected float32 empty array for f32 format, got {pcm.samples.dtype}" + ) + + # Should be empty + assert len(pcm.samples) == 0 + + +def test_constructor_default_samples_respects_s16_format(): + """Verify constructor creates int16 empty array for s16 format (baseline).""" + # Create PcmData with s16 format but no samples - should default to int16 empty array + pcm = PcmData(sample_rate=16000, format="s16", channels=1) + + # Verify format is s16 + assert pcm.format == "s16" + + # Samples must be int16 + assert pcm.samples.dtype == np.int16, ( + f"Expected int16 empty array for s16 format, got {pcm.samples.dtype}" + ) + + # Should be empty + assert len(pcm.samples) == 0 + + +def test_constructor_default_samples_respects_enum_f32(): + """Verify constructor works with AudioFormat enum for f32.""" + from getstream.video.rtc.track_util import AudioFormat + + # Create PcmData with AudioFormat.F32 enum + pcm = PcmData(sample_rate=16000, format=AudioFormat.F32, channels=1) + + # Verify format + assert pcm.format == AudioFormat.F32 + + # CRITICAL: Samples must be float32 + assert pcm.samples.dtype == np.float32, ( + f"Expected float32 empty array for AudioFormat.F32, got {pcm.samples.dtype}" + ) + + # Should be empty + assert len(pcm.samples) == 0 + + +def test_constructor_default_samples_respects_enum_s16(): + """Verify constructor works with AudioFormat enum for s16.""" + from getstream.video.rtc.track_util import AudioFormat + + # Create PcmData with AudioFormat.S16 enum + pcm = PcmData(sample_rate=16000, format=AudioFormat.S16, channels=1) + + # Verify format + assert pcm.format == AudioFormat.S16 + + # Samples must be int16 + assert pcm.samples.dtype == np.int16, ( + f"Expected int16 empty array for AudioFormat.S16, got {pcm.samples.dtype}" + ) + + # Should be empty + assert len(pcm.samples) == 0 + + +def test_constructor_default_samples_handles_float32_string(): + """Verify constructor handles 'float32' format string.""" + # Create PcmData with 'float32' format string + pcm = PcmData(sample_rate=16000, format="float32", channels=1) + + # Verify format + assert pcm.format == "float32" + + # Samples must be float32 + assert pcm.samples.dtype == np.float32, ( + f"Expected float32 empty array for 'float32' format, got {pcm.samples.dtype}" + ) + + # Should be empty + assert len(pcm.samples) == 0 + + +# ===== Tests for strict dtype validation ===== + + +def test_constructor_raises_on_int16_with_f32_format(): + """Test that constructor raises TypeError when passing int16 samples with f32 format.""" + samples = np.array([1, 2, 3], dtype=np.int16) + + with pytest.raises(TypeError) as exc_info: + PcmData(sample_rate=16000, format="f32", samples=samples, channels=1) + + # Verify error message is helpful + error_msg = str(exc_info.value) + assert "Dtype mismatch" in error_msg + assert "format='f32'" in error_msg + assert "float32" in error_msg + assert "int16" in error_msg + assert "to_float32()" in error_msg or "from_data()" in error_msg + + +def test_constructor_raises_on_float32_with_s16_format(): + """Test that constructor raises TypeError when passing float32 samples with s16 format.""" + samples = np.array([0.1, 0.2, 0.3], dtype=np.float32) + + with pytest.raises(TypeError) as exc_info: + PcmData(sample_rate=16000, format="s16", samples=samples, channels=1) + + # Verify error message is helpful + error_msg = str(exc_info.value) + assert "Dtype mismatch" in error_msg + assert "format='s16'" in error_msg + assert "int16" in error_msg + assert "float32" in error_msg + + +def test_constructor_raises_on_int16_with_enum_f32_format(): + """Test that constructor raises TypeError with AudioFormat.F32 enum.""" + from getstream.video.rtc.track_util import AudioFormat + + samples = np.array([1, 2, 3], dtype=np.int16) + + with pytest.raises(TypeError) as exc_info: + PcmData(sample_rate=16000, format=AudioFormat.F32, samples=samples, channels=1) + + error_msg = str(exc_info.value) + assert "Dtype mismatch" in error_msg + assert "float32" in error_msg + + +def test_constructor_raises_on_float32_with_enum_s16_format(): + """Test that constructor raises TypeError with AudioFormat.S16 enum.""" + from getstream.video.rtc.track_util import AudioFormat + + samples = np.array([0.1, 0.2, 0.3], dtype=np.float32) + + with pytest.raises(TypeError) as exc_info: + PcmData(sample_rate=16000, format=AudioFormat.S16, samples=samples, channels=1) + + error_msg = str(exc_info.value) + assert "Dtype mismatch" in error_msg + assert "int16" in error_msg + + +def test_constructor_accepts_matching_int16_with_s16(): + """Verify constructor accepts int16 samples with s16 format (should work).""" + samples = np.array([1, 2, 3], dtype=np.int16) + pcm = PcmData(sample_rate=16000, format="s16", samples=samples, channels=1) + + assert pcm.format == "s16" + assert pcm.samples.dtype == np.int16 + assert np.array_equal(pcm.samples, samples) + + +def test_constructor_accepts_matching_float32_with_f32(): + """Verify constructor accepts float32 samples with f32 format (should work).""" + samples = np.array([0.1, 0.2, 0.3], dtype=np.float32) + pcm = PcmData(sample_rate=16000, format="f32", samples=samples, channels=1) + + assert pcm.format == "f32" + assert pcm.samples.dtype == np.float32 + assert np.array_equal(pcm.samples, samples) + + +def test_from_data_handles_conversion_automatically(): + """Verify from_data() still handles automatic conversion (doesn't go through strict validation the same way).""" + # from_data should handle conversion internally + samples_int16 = np.array([1, 2, 3], dtype=np.int16) + + # This should work because from_data converts internally before calling __init__ + pcm = PcmData.from_data(samples_int16, sample_rate=16000, format="s16", channels=1) + + assert pcm.format == "s16" + assert pcm.samples.dtype == np.int16 + + +def test_direct_float32_to_int16_conversion_needs_clipping(): + """ + Direct test showing that float32→int16 conversion needs clipping and scaling. + + This demonstrates the bug: directly casting float32 to int16 produces wrong values. + """ + # Test values with typical audio scenarios + test_values_f32 = np.array([0.5, 1.0, 2.0, -1.5, -1.0, 0.0], dtype=np.float32) + + # What happens with direct cast (THE BUG)? + buggy_conversion = test_values_f32.astype(np.int16) + + # This produces: [0, 1, 2, -1, -1, 0] - completely wrong! + # 0.5 becomes 0 (should be ~16383) + # 2.0 becomes 2 (should be 32767 after clipping) + assert buggy_conversion[0] == 0 # 0.5 wrongly becomes 0 + assert buggy_conversion[2] == 2 # 2.0 wrongly becomes 2 + + # Correct conversion with clipping and scaling + max_int16 = np.iinfo(np.int16).max # 32767 + correct_conversion = np.clip(test_values_f32, -1.0, 1.0) * max_int16 + correct_conversion = np.round(correct_conversion).astype(np.int16) + + # Should be: [16384, 32767, 32767, -32767, -32767, 0] + # Note: 0.5 * 32767 = 16383.5, rounds to 16384 + assert correct_conversion[0] == 16384 # 0.5 * 32767, rounded + assert correct_conversion[1] == 32767 # 1.0 * 32767 + assert correct_conversion[2] == 32767 # 2.0 clipped to 1.0, then * 32767 + assert correct_conversion[3] == -32767 # -1.5 clipped to -1.0, then * 32767 + assert correct_conversion[4] == -32767 # -1.0 * 32767 + assert correct_conversion[5] == 0 # 0.0 * 32767 + + +def test_resample_with_extreme_values_should_clip(): + """ + Test that resample() properly clips extreme float values when converting to int16. + + Scenario: PyAV might return float32 values outside [-1.0, 1.0] after resampling. + When converting back to int16, these must be clipped and scaled properly. + """ + # Create int16 audio with values that when converted to float32 and back + # might go through the problematic conversion path + sample_rate_in = 16000 + sample_rate_out = 48000 + + # Create int16 samples near the limits + samples_i16 = np.array([16383, 32767, -16383, -32767, 0], dtype=np.int16) + + pcm_16k = PcmData( + sample_rate=sample_rate_in, + format="s16", + samples=samples_i16, + channels=1, + ) + + # Resample - this goes through PyAV which may return float32 + pcm_48k = pcm_16k.resample(sample_rate_out) + + # Result should still be int16 with valid range + assert pcm_48k.format == "s16" + assert pcm_48k.samples.dtype == np.int16 + + # All values should be in valid int16 range + assert np.all(pcm_48k.samples >= -32768) + assert np.all(pcm_48k.samples <= 32767) + + # Values should not be close to zero (which would indicate incorrect scaling) + # After resampling, we should have more samples but similar magnitudes + max_val = np.max(np.abs(pcm_48k.samples)) + assert max_val > 10000, ( + f"Resampled values seem too small: max={max_val}, might indicate scaling bug" + ) diff --git a/tests/rtc/test_track_util.py b/tests/rtc/test_track_util.py index 081e2943..006085d4 100644 --- a/tests/rtc/test_track_util.py +++ b/tests/rtc/test_track_util.py @@ -4,7 +4,11 @@ from unittest.mock import Mock import logging -from getstream.video.rtc.track_util import AudioTrackHandler, PcmData +from getstream.video.rtc.track_util import ( + AudioTrackHandler, + PcmData, + AudioFormat, +) import getstream.video.rtc.track_util as track_util @@ -34,7 +38,7 @@ def __init__(self, sample_rate=48000, samples=None, channels=1): self._samples = samples - def to_ndarray(self, format="s16"): + def to_ndarray(self, format=AudioFormat.S16): """Mock the to_ndarray method that's causing issues in newer PyAV.""" if format == "s16": return self._samples @@ -68,13 +72,16 @@ async def recv(self): return frame -@pytest.fixture(autouse=True) +@pytest.fixture def _monkeypatch_audio_frame(monkeypatch): """Monkey-patch `track_util.av.AudioFrame` so `isinstance(..., av.AudioFrame)` succeeds for the real PyAV class *and* any duck-typed test double that exposes the minimal interface we rely on (``sample_rate``, ``layout``, and ``to_ndarray``). This covers both ``MockAudioFrame`` and ad-hoc classes like ``BrokenAudioFrame`` defined inside individual tests. + + Note: This fixture is NOT auto-used because some tests (like resampling tests) + need the real PyAV AudioFrame.from_ndarray method. """ real_audio_frame_cls = track_util.av.AudioFrame @@ -99,7 +106,7 @@ class _AudioFrame(metaclass=_AudioFrameMeta): @pytest.mark.asyncio -async def test_audio_track_handler_initialization(): +async def test_audio_track_handler_initialization(_monkeypatch_audio_frame): """Test that AudioTrackHandler can be initialized properly.""" mock_track = MockAudioTrack() callback = Mock() @@ -113,7 +120,7 @@ async def test_audio_track_handler_initialization(): @pytest.mark.asyncio -async def test_audio_track_handler_start_stop(): +async def test_audio_track_handler_start_stop(_monkeypatch_audio_frame): """Test that AudioTrackHandler can start and stop properly.""" mock_track = MockAudioTrack() callback = Mock() @@ -131,7 +138,7 @@ async def test_audio_track_handler_start_stop(): @pytest.mark.asyncio -async def test_run_track_mono_audio(): +async def test_run_track_mono_audio(_monkeypatch_audio_frame): """Test _run_track with mono audio frames.""" # Create mock frames with mono audio frames = [ @@ -183,7 +190,7 @@ def callback(pcm_data: PcmData): @pytest.mark.asyncio -async def test_run_track_stereo_audio(): +async def test_run_track_stereo_audio(_monkeypatch_audio_frame): """Test _run_track with stereo audio frames (should convert to mono).""" # Create mock frames with stereo audio stereo_samples = np.array( @@ -223,7 +230,7 @@ def callback(pcm_data: PcmData): @pytest.mark.asyncio -async def test_run_track_wrong_sample_rate(): +async def test_run_track_wrong_sample_rate(_monkeypatch_audio_frame): """Test _run_track raises error for unsupported sample rate.""" # Create mock frame with wrong sample rate frames = [ @@ -254,7 +261,7 @@ def callback(pcm_data: PcmData): @pytest.mark.asyncio -async def test_run_track_cancelled_error(): +async def test_run_track_cancelled_error(_monkeypatch_audio_frame): """Test _run_track handles CancelledError gracefully.""" mock_track = MockAudioTrack(should_raise="cancelled") received_data = [] @@ -278,7 +285,7 @@ def callback(pcm_data: PcmData): @pytest.mark.asyncio -async def test_run_track_general_exception(): +async def test_run_track_general_exception(_monkeypatch_audio_frame): """Test _run_track handles general exceptions gracefully.""" mock_track = MockAudioTrack(should_raise=Exception("Test error")) received_data = [] @@ -302,7 +309,7 @@ def callback(pcm_data: PcmData): @pytest.mark.asyncio -async def test_run_track_to_ndarray_error(): +async def test_run_track_to_ndarray_error(_monkeypatch_audio_frame): """Test _run_track handles to_ndarray method errors (PyAV compatibility issue).""" # Create a mock frame that raises an error when to_ndarray is called @@ -312,7 +319,7 @@ def __init__(self): self.layout = Mock() self.layout.channels = [Mock()] # Single channel - def to_ndarray(self, format="s16"): + def to_ndarray(self, format=AudioFormat.S16): raise AttributeError("'AudioFrame' object has no attribute 'to_ndarray'") frames = [BrokenAudioFrame()] @@ -335,3 +342,607 @@ def callback(pcm_data: PcmData): # Should not have received any data due to to_ndarray error assert len(received_data) == 0 + + +@pytest.mark.asyncio +class TestPcmDataChunking: + """Test PcmData chunking and sliding window methods.""" + + def test_chunks_basic(self): + """Test basic chunking functionality.""" + # Create test audio with 10 samples + samples = np.arange(10, dtype=np.int16) + pcm = PcmData( + sample_rate=16000, format=AudioFormat.S16, samples=samples, channels=1 + ) + + # Get chunks of size 4 + chunks = list(pcm.chunks(4)) + + # Should have 3 chunks: [0:4], [4:8], [8:10] + assert len(chunks) == 3 + + # Check first chunk + np.testing.assert_array_equal( + chunks[0].samples, np.array([0, 1, 2, 3], dtype=np.int16) + ) + assert chunks[0].sample_rate == 16000 + assert chunks[0].format == "s16" + + # Check second chunk + np.testing.assert_array_equal( + chunks[1].samples, np.array([4, 5, 6, 7], dtype=np.int16) + ) + + # Check third chunk (partial) + np.testing.assert_array_equal( + chunks[2].samples, np.array([8, 9], dtype=np.int16) + ) + + def test_chunks_with_overlap(self): + """Test chunking with overlap.""" + samples = np.arange(10, dtype=np.int16) + pcm = PcmData( + samples=samples, sample_rate=16000, format=AudioFormat.S16, channels=1 + ) + + # Get chunks of size 4 with overlap of 2 + chunks = list(pcm.chunks(4, overlap=2)) + + # Should have chunks: [0:4], [2:6], [4:8], [6:10], [8:10] + assert len(chunks) == 5 + + np.testing.assert_array_equal( + chunks[0].samples, np.array([0, 1, 2, 3], dtype=np.int16) + ) + np.testing.assert_array_equal( + chunks[1].samples, np.array([2, 3, 4, 5], dtype=np.int16) + ) + np.testing.assert_array_equal( + chunks[2].samples, np.array([4, 5, 6, 7], dtype=np.int16) + ) + np.testing.assert_array_equal( + chunks[3].samples, np.array([6, 7, 8, 9], dtype=np.int16) + ) + np.testing.assert_array_equal( + chunks[4].samples, np.array([8, 9], dtype=np.int16) + ) # Final partial chunk + + def test_chunks_with_padding(self): + """Test chunking with padding for the last chunk.""" + samples = np.arange(10, dtype=np.int16) + pcm = PcmData( + samples=samples, sample_rate=16000, format=AudioFormat.S16, channels=1 + ) + + # Get chunks of size 4 with padding + chunks = list(pcm.chunks(4, pad_last=True)) + + # Should have 3 chunks, last one padded + assert len(chunks) == 3 + + # Last chunk should be padded with zeros + np.testing.assert_array_equal( + chunks[2].samples, np.array([8, 9, 0, 0], dtype=np.int16) + ) + + def test_chunks_stereo(self): + """Test chunking with stereo audio.""" + # Create stereo samples (2 channels, 8 samples each) + left_channel = np.arange(8, dtype=np.int16) + right_channel = np.arange(8, 16, dtype=np.int16) + samples = np.array([left_channel, right_channel]) # Shape: (2, 8) + + pcm = PcmData( + samples=samples, sample_rate=16000, format=AudioFormat.S16, channels=2 + ) + + # Get chunks of size 3 + chunks = list(pcm.chunks(3)) + + # Should have 3 chunks: samples 0-2, 3-5, 6-7 + assert len(chunks) == 3 + + # Check first chunk maintains stereo + assert chunks[0].channels == 2 + assert chunks[0].samples.shape == (2, 3) + np.testing.assert_array_equal( + chunks[0].samples[0], np.array([0, 1, 2], dtype=np.int16) + ) + np.testing.assert_array_equal( + chunks[0].samples[1], np.array([8, 9, 10], dtype=np.int16) + ) + + def test_sliding_window(self): + """Test sliding window functionality.""" + # Create 800 samples (50ms at 16kHz) + samples = np.arange(800, dtype=np.int16) + pcm = PcmData( + samples=samples, sample_rate=16000, format=AudioFormat.S16, channels=1 + ) + + # 25ms window = 400 samples, 10ms hop = 160 samples + windows = list(pcm.sliding_window(25.0, 10.0)) + + # Calculate expected number of windows + # With 800 samples, 400-sample windows, 160-sample hop: + # Windows start at: 0, 160, 320, 480, 640 + # Last full window starts at 400 (ends at 800) + assert len(windows) >= 3 + + # Check first window + assert len(windows[0].samples) == 400 + assert windows[0].samples[0] == 0 + + # Check second window (should start at 160) + assert len(windows[1].samples) == 400 + assert windows[1].samples[0] == 160 + + def test_sliding_window_with_padding(self): + """Test sliding window with padding.""" + samples = np.arange(500, dtype=np.int16) + pcm = PcmData( + samples=samples, sample_rate=16000, format=AudioFormat.S16, channels=1 + ) + + # 25ms window = 400 samples, 15ms hop = 240 samples + windows = list(pcm.sliding_window(25.0, 15.0, pad_last=True)) + + # Windows start at: 0, 240, 480 + # Last window needs padding + assert len(windows) == 3 + + # Last window should be padded + assert len(windows[-1].samples) == 400 + # Check that padding with zeros occurred + assert windows[-1].samples[-1] == 0 # Last value should be 0 (padded) + + +class TestPcmDataNoOpOptimizations: + """Test that PcmData methods avoid unnecessary work when data is already in target format.""" + + def test_resample_no_op_same_rate(self): + """Test that resample() returns self when already at target sample rate.""" + pcm = PcmData( + samples=np.array([1, 2, 3], dtype=np.int16), + sample_rate=16000, + format=AudioFormat.S16, + channels=1, + ) + + # Resample to same rate should return the same object (not a copy) + resampled = pcm.resample(16000) + assert resampled is pcm, ( + "resample() should return self when already at target rate" + ) + + def test_resample_no_op_same_rate_and_channels(self): + """Test that resample() returns self when rate and channels match.""" + pcm = PcmData( + samples=np.array([[1, 2], [3, 4]], dtype=np.int16), + sample_rate=16000, + format=AudioFormat.S16, + channels=2, + ) + + # Resample to same rate and channels should return the same object + resampled = pcm.resample(16000, target_channels=2) + assert resampled is pcm, ( + "resample() should return self when rate and channels match" + ) + + def test_resample_does_work_different_rate(self): + """Test that resample() creates new object when rate differs.""" + # Use more realistic audio length (10ms at 16kHz = 160 samples) + pcm = PcmData( + samples=np.ones(160, dtype=np.int16), + sample_rate=16000, + format=AudioFormat.S16, + channels=1, + ) + + # Resample to different rate should create new object + resampled = pcm.resample(48000) + assert resampled is not pcm, ( + "resample() should create new object for different rate" + ) + assert resampled.sample_rate == 48000 + # 160 samples at 16kHz -> 480 samples at 48kHz + assert len(resampled.samples) > len(pcm.samples) + + def test_resample_does_work_different_channels(self): + """Test that resample() creates new object when channels differ.""" + pcm = PcmData( + samples=np.array([1, 2, 3], dtype=np.int16), + sample_rate=16000, + format=AudioFormat.S16, + channels=1, + ) + + # Resample to different channels should create new object + resampled = pcm.resample(16000, target_channels=2) + assert resampled is not pcm, ( + "resample() should create new object for different channels" + ) + assert resampled.channels == 2 + + def test_to_float32_no_op_already_f32(self): + """Test that to_float32() returns self when already in f32 format.""" + pcm = PcmData( + samples=np.array([0.5, -0.5], dtype=np.float32), + sample_rate=16000, + format=AudioFormat.F32, + channels=1, + ) + + # Should return the same object + converted = pcm.to_float32() + assert converted is pcm, "to_float32() should return self when already f32" + + def test_to_float32_does_work_s16_format(self): + """Test that to_float32() converts when format is s16.""" + pcm = PcmData( + samples=np.array([16384, -16384], dtype=np.int16), + sample_rate=16000, + format=AudioFormat.S16, + channels=1, + ) + + # Should create new object with f32 format + converted = pcm.to_float32() + assert converted is not pcm, "to_float32() should create new object for s16" + assert converted.format == AudioFormat.F32 + assert converted.samples.dtype == np.float32 + # Check conversion: 16384 / 32768 = 0.5 + np.testing.assert_allclose(converted.samples, [0.5, -0.5], rtol=1e-4) + + def test_chained_operations_avoid_unnecessary_work(self): + """Test that chaining resample().to_float32() avoids work when already correct.""" + # Create f32 audio at 16kHz + pcm = PcmData( + samples=np.array([0.5, -0.5, 0.25], dtype=np.float32), + sample_rate=16000, + format=AudioFormat.F32, + channels=1, + ) + + # Chain operations that are already satisfied + result = pcm.resample(16000).to_float32() + + # Both operations should be no-ops, so we get back the original object + assert result is pcm, "Chained no-op operations should return original object" + + def test_chained_operations_do_necessary_work(self): + """Test that chaining resample().to_float32() does work when needed.""" + # Create s16 audio at 48kHz + pcm = PcmData( + samples=np.ones(480, dtype=np.int16) * 16384, + sample_rate=48000, + format=AudioFormat.S16, + channels=1, + ) + + # Chain operations that need work + result = pcm.resample(16000).to_float32() + + # Should have done both conversions + assert result is not pcm, "Should create new object when work is needed" + assert result.sample_rate == 16000, "Should have resampled to 16kHz" + assert result.format == AudioFormat.F32, "Should have converted to f32" + assert result.samples.dtype == np.float32, "Samples should be float32" + # 480 samples at 48kHz -> ~160 samples at 16kHz + assert 140 <= len(result.samples) <= 170 + + +class TestPcmDataHeadTail: + """Test PcmData head() and tail() methods for audio truncation and padding.""" + + def test_tail_truncate_longer_audio(self): + """Test tail() truncates audio longer than requested duration.""" + # 10 seconds of audio at 16kHz = 160000 samples + # Use modulo to keep values within int16 range + pcm = PcmData( + samples=(np.arange(160000) % 32000).astype(np.int16), + sample_rate=16000, + format=AudioFormat.S16, + channels=1, + ) + + # Get last 5 seconds + tail = pcm.tail(duration_s=5.0) + + assert tail.duration == 5.0 + assert len(tail.samples) == 80000 # 5s * 16000 samples/s + # Should contain the last 80000 samples + np.testing.assert_array_equal( + tail.samples, (np.arange(80000, 160000) % 32000).astype(np.int16) + ) + + def test_tail_return_whole_audio_when_shorter_no_pad(self): + """Test tail() returns whole audio when shorter than requested and pad=False.""" + # 1 second of audio + pcm = PcmData( + samples=np.arange(16000, dtype=np.int16), + sample_rate=16000, + format=AudioFormat.S16, + channels=1, + ) + + # Request 8 seconds but don't pad + tail = pcm.tail(duration_s=8.0, pad=False) + + # Should get the whole 1 second + assert tail.duration == 1.0 + assert len(tail.samples) == 16000 + np.testing.assert_array_equal(tail.samples, np.arange(16000, dtype=np.int16)) + + def test_tail_pad_at_start(self): + """Test tail() pads with zeros at start when audio is shorter.""" + # 1 second of audio + pcm = PcmData( + samples=np.arange(16000, dtype=np.int16), + sample_rate=16000, + format=AudioFormat.S16, + channels=1, + ) + + # Request 8 seconds, pad at start + tail = pcm.tail(duration_s=8.0, pad=True, pad_at="start") + + assert tail.duration == 8.0 + assert len(tail.samples) == 128000 # 8s * 16000 + # First 7 seconds should be zeros + np.testing.assert_array_equal( + tail.samples[:112000], np.zeros(112000, dtype=np.int16) + ) + # Last 1 second should be original data + np.testing.assert_array_equal( + tail.samples[112000:], np.arange(16000, dtype=np.int16) + ) + + def test_tail_pad_at_end(self): + """Test tail() pads with zeros at end when audio is shorter.""" + # 1 second of audio + pcm = PcmData( + samples=np.arange(16000, dtype=np.int16), + sample_rate=16000, + format=AudioFormat.S16, + channels=1, + ) + + # Request 3 seconds, pad at end + tail = pcm.tail(duration_s=3.0, pad=True, pad_at="end") + + assert tail.duration == 3.0 + assert len(tail.samples) == 48000 # 3s * 16000 + # First 1 second should be original data + np.testing.assert_array_equal( + tail.samples[:16000], np.arange(16000, dtype=np.int16) + ) + # Last 2 seconds should be zeros + np.testing.assert_array_equal( + tail.samples[16000:], np.zeros(32000, dtype=np.int16) + ) + + def test_tail_stereo(self): + """Test tail() works with stereo audio.""" + # 2 seconds of stereo audio + # Use values within int16 range + left = np.arange(32000, dtype=np.int16) + right = (np.arange(32000) + 100).astype( + np.int16 + ) # Offset to differentiate channels + samples = np.array([left, right]) # Shape: (2, 32000) + + pcm = PcmData( + samples=samples, + sample_rate=16000, + format=AudioFormat.S16, + channels=2, + ) + + # Get last 1 second + tail = pcm.tail(duration_s=1.0) + + assert tail.duration == 1.0 + assert tail.channels == 2 + assert tail.samples.shape == (2, 16000) + # Check last second of each channel + np.testing.assert_array_equal( + tail.samples[0], np.arange(16000, 32000, dtype=np.int16) + ) + np.testing.assert_array_equal( + tail.samples[1], (np.arange(16000, 32000) + 100).astype(np.int16) + ) + + def test_head_truncate_longer_audio(self): + """Test head() truncates audio longer than requested duration.""" + # 10 seconds of audio at 16kHz = 160000 samples + pcm = PcmData( + samples=np.arange(160000, dtype=np.int16), + sample_rate=16000, + format=AudioFormat.S16, + channels=1, + ) + + # Get first 3 seconds + head = pcm.head(duration_s=3.0) + + assert head.duration == 3.0 + assert len(head.samples) == 48000 # 3s * 16000 samples/s + # Should contain the first 48000 samples + np.testing.assert_array_equal(head.samples, np.arange(48000, dtype=np.int16)) + + def test_head_return_whole_audio_when_shorter_no_pad(self): + """Test head() returns whole audio when shorter than requested and pad=False.""" + # 1 second of audio + pcm = PcmData( + samples=np.arange(16000, dtype=np.int16), + sample_rate=16000, + format=AudioFormat.S16, + channels=1, + ) + + # Request 8 seconds but don't pad + head = pcm.head(duration_s=8.0, pad=False) + + # Should get the whole 1 second + assert head.duration == 1.0 + assert len(head.samples) == 16000 + np.testing.assert_array_equal(head.samples, np.arange(16000, dtype=np.int16)) + + def test_head_pad_at_end(self): + """Test head() pads with zeros at end when audio is shorter (default behavior).""" + # 1 second of audio + pcm = PcmData( + samples=np.arange(16000, dtype=np.int16), + sample_rate=16000, + format=AudioFormat.S16, + channels=1, + ) + + # Request 3 seconds, pad at end (default) + head = pcm.head(duration_s=3.0, pad=True) + + assert head.duration == 3.0 + assert len(head.samples) == 48000 # 3s * 16000 + # First 1 second should be original data + np.testing.assert_array_equal( + head.samples[:16000], np.arange(16000, dtype=np.int16) + ) + # Last 2 seconds should be zeros + np.testing.assert_array_equal( + head.samples[16000:], np.zeros(32000, dtype=np.int16) + ) + + def test_head_pad_at_start(self): + """Test head() pads with zeros at start when audio is shorter.""" + # 1 second of audio + pcm = PcmData( + samples=np.arange(16000, dtype=np.int16), + sample_rate=16000, + format=AudioFormat.S16, + channels=1, + ) + + # Request 3 seconds, pad at start + head = pcm.head(duration_s=3.0, pad=True, pad_at="start") + + assert head.duration == 3.0 + assert len(head.samples) == 48000 # 3s * 16000 + # First 2 seconds should be zeros + np.testing.assert_array_equal( + head.samples[:32000], np.zeros(32000, dtype=np.int16) + ) + # Last 1 second should be original data + np.testing.assert_array_equal( + head.samples[32000:], np.arange(16000, dtype=np.int16) + ) + + def test_head_stereo(self): + """Test head() works with stereo audio.""" + # 2 seconds of stereo audio + left = np.arange(32000, dtype=np.int16) + right = np.arange(32000, 64000, dtype=np.int16) + samples = np.array([left, right]) # Shape: (2, 32000) + + pcm = PcmData( + samples=samples, + sample_rate=16000, + format=AudioFormat.S16, + channels=2, + ) + + # Get first 1 second + head = pcm.head(duration_s=1.0) + + assert head.duration == 1.0 + assert head.channels == 2 + assert head.samples.shape == (2, 16000) + # Check first second of each channel + np.testing.assert_array_equal(head.samples[0], np.arange(16000, dtype=np.int16)) + np.testing.assert_array_equal( + head.samples[1], np.arange(32000, 48000, dtype=np.int16) + ) + + def test_tail_f32_format(self): + """Test tail() works with f32 format.""" + pcm = PcmData( + samples=np.arange(16000, dtype=np.float32) / 32768.0, + sample_rate=16000, + format=AudioFormat.F32, + channels=1, + ) + + tail = pcm.tail(duration_s=0.5, pad=True, pad_at="start") + + assert tail.duration == 0.5 + assert tail.format == AudioFormat.F32 + assert tail.samples.dtype == np.float32 + # Last 0.5s should be original data + expected = np.arange(8000, 16000, dtype=np.float32) / 32768.0 + np.testing.assert_array_almost_equal(tail.samples[-8000:], expected) + + def test_parameter_validation(self): + """Test that invalid parameters raise appropriate errors.""" + pcm = PcmData( + samples=np.arange(16000, dtype=np.int16), + sample_rate=16000, + format=AudioFormat.S16, + channels=1, + ) + + # Invalid pad_at value + with np.testing.assert_raises(ValueError): + pcm.tail(duration_s=1.0, pad_at="middle") + + def test_real_world_use_case(self): + """Test the original use case: truncate to last 8s, pad at start if needed.""" + + # Simulate the user's truncate_audio_to_last_n_seconds function + def truncate_audio_to_last_n_seconds( + audio_array, n_seconds=8, sample_rate=16000 + ): + """Original function from user's code.""" + max_samples = n_seconds * sample_rate + if len(audio_array) > max_samples: + return audio_array[-max_samples:] + elif len(audio_array) < max_samples: + padding = max_samples - len(audio_array) + return np.pad( + audio_array, (padding, 0), mode="constant", constant_values=0 + ) + return audio_array + + # Test case 1: Long audio (should truncate) + long_audio = np.arange(200000, dtype=np.int16) + pcm_long = PcmData( + samples=long_audio, sample_rate=16000, format=AudioFormat.S16 + ) + + result_original = truncate_audio_to_last_n_seconds(long_audio) + result_new = pcm_long.tail(duration_s=8.0, pad=True, pad_at="start") + + np.testing.assert_array_equal(result_new.samples, result_original) + + # Test case 2: Short audio (should pad at start) + short_audio = np.arange(16000, dtype=np.int16) + pcm_short = PcmData( + samples=short_audio, sample_rate=16000, format=AudioFormat.S16 + ) + + result_original = truncate_audio_to_last_n_seconds(short_audio) + result_new = pcm_short.tail(duration_s=8.0, pad=True, pad_at="start") + + np.testing.assert_array_equal(result_new.samples, result_original) + + # Test case 3: Exact length (should return as-is) + exact_audio = np.arange(128000, dtype=np.int16) + pcm_exact = PcmData( + samples=exact_audio, sample_rate=16000, format=AudioFormat.S16 + ) + + result_original = truncate_audio_to_last_n_seconds(exact_audio) + result_new = pcm_exact.tail(duration_s=8.0, pad=True, pad_at="start") + + np.testing.assert_array_equal(result_new.samples, result_original)