diff --git a/backend/database/conversations.py b/backend/database/conversations.py index 1d9896ff443..fda6d2a62ed 100644 --- a/backend/database/conversations.py +++ b/backend/database/conversations.py @@ -319,18 +319,19 @@ def create_audio_files_from_chunks( if not chunks: return [] - # Group chunks based on 30-second gap rule + # Group chunks based on gap rule (90s threshold accommodates both 5s and 60s chunk durations) audio_files = [] current_group = [] + gap_threshold = 90 # seconds — must exceed max chunk duration (60s) to avoid false splits for i, chunk in enumerate(chunks): if not current_group: current_group.append(chunk) else: - # Check if there's a gap > 30 seconds between chunks + # Check if there's a gap between chunks exceeding the threshold prev_chunk = current_group[-1] time_gap = chunk['timestamp'] - prev_chunk['timestamp'] - if time_gap > 30: + if time_gap > gap_threshold: # Gap detected, finalize current group audio_file = _finalize_audio_file_group(uid, conversation_id, current_group, audio_files) if audio_file: @@ -372,11 +373,13 @@ def _finalize_audio_file_group( # Extract timestamps timestamps = [chunk['timestamp'] for chunk in chunk_group] - # Calculate started_at and duration from timestamps + # Calculate started_at and duration from timestamps and blob sizes started_at = datetime.fromtimestamp(chunk_group[0]['timestamp'], tz=timezone.utc) last_chunk_start = datetime.fromtimestamp(chunk_group[-1]['timestamp'], tz=timezone.utc) - # Add 5 seconds for the last chunk's duration - duration = (last_chunk_start - started_at).total_seconds() + 5.0 + # Estimate last chunk duration from blob size (PCM16 mono at 8kHz = 16000 bytes/sec) + last_chunk_size = chunk_group[-1].get('size', 0) + last_chunk_duration = last_chunk_size / 16000.0 if last_chunk_size > 0 else 5.0 + duration = (last_chunk_start - started_at).total_seconds() + last_chunk_duration return AudioFile( id=file_id, diff --git a/backend/routers/pusher.py b/backend/routers/pusher.py index 100833f916f..de60cd184d0 100644 --- a/backend/routers/pusher.py +++ b/backend/routers/pusher.py @@ -4,7 +4,7 @@ import time from collections import deque from datetime import datetime, timezone -from typing import List, Set +from typing import Dict, List, Set from fastapi import APIRouter from fastapi.websockets import WebSocketDisconnect, WebSocket @@ -27,7 +27,7 @@ realtime_transcript_webhook, get_audio_bytes_webhook_seconds, ) -from utils.other.storage import upload_audio_chunk +from utils.other.storage import upload_audio_chunk, upload_audio_chunks_batch from utils.speaker_identification import extract_speaker_samples import logging @@ -41,7 +41,8 @@ # Constants for private cloud sync PRIVATE_CLOUD_SYNC_PROCESS_INTERVAL = 1.0 -PRIVATE_CLOUD_CHUNK_DURATION = 5.0 +PRIVATE_CLOUD_CHUNK_DURATION = 60.0 +PRIVATE_CLOUD_BATCH_MAX_AGE = 60.0 # seconds — flush batch if oldest chunk exceeds this age PRIVATE_CLOUD_SYNC_MAX_RETRIES = 3 # Queue warning thresholds @@ -164,45 +165,44 @@ def on_done(t): audio_bytes_event = asyncio.Event() # Signals when items are added for instant wake async def process_private_cloud_queue(): - """Background task that processes private cloud sync uploads with retry logic.""" - nonlocal websocket_active - - while websocket_active or len(private_cloud_queue) > 0: - await asyncio.sleep(PRIVATE_CLOUD_SYNC_PROCESS_INTERVAL) - - if not private_cloud_queue: - continue + """Background task that batches private cloud sync uploads by conversation_id. 
- # Process all pending chunks - chunks_to_process = private_cloud_queue.copy() - private_cloud_queue.clear() - - successful_conversation_ids = set() # Track conversations with successful uploads - - for chunk_info in chunks_to_process: - chunk_data = chunk_info['data'] - conv_id = chunk_info['conversation_id'] - timestamp = chunk_info['timestamp'] - retries = chunk_info.get('retries', 0) - - try: - await asyncio.to_thread( - upload_audio_chunk, chunk_data, uid, conv_id, timestamp, cached_protection_level - ) - successful_conversation_ids.add(conv_id) - except Exception as e: - if retries < PRIVATE_CLOUD_SYNC_MAX_RETRIES: - # Re-queue with incremented retry count - chunk_info['retries'] = retries + 1 - private_cloud_queue.append(chunk_info) - logger.error(f"Private cloud upload failed (retry {retries + 1}): {e} {uid} {conv_id}") - else: - logger.info( - f"Private cloud upload failed after {PRIVATE_CLOUD_SYNC_MAX_RETRIES} retries, dropping chunk: {e} {uid} {conv_id}" - ) + Chunks are accumulated per conversation and flushed when: + - The batch reaches 60s of audio data, or + - The oldest chunk in the batch exceeds PRIVATE_CLOUD_BATCH_MAX_AGE, or + - The websocket disconnects (shutdown flush). + """ + nonlocal websocket_active - # Update audio_files for conversations with successful uploads - for conv_id in successful_conversation_ids: + # Pending batches keyed by conversation_id + pending: Dict[str, dict] = {} + + def _add_to_batch(chunk_info: dict): + conv_id = chunk_info['conversation_id'] + if conv_id not in pending: + pending[conv_id] = { + 'data': bytearray(), + 'conversation_id': conv_id, + 'timestamp': chunk_info['timestamp'], # oldest chunk timestamp + 'queued_at': time.monotonic(), + 'retries': 0, + } + batch = pending[conv_id] + batch['data'].extend(chunk_info['data']) + + async def _flush_batch(conv_id: str): + """Upload a batched chunk and update audio files.""" + batch = pending.pop(conv_id, None) + if not batch or len(batch['data']) == 0: + return + chunk_data = bytes(batch['data']) + timestamp = batch['timestamp'] + retries = batch.get('retries', 0) + try: + chunks_to_upload = [{'data': chunk_data, 'timestamp': timestamp}] + await asyncio.to_thread( + upload_audio_chunks_batch, chunks_to_upload, uid, conv_id, cached_protection_level + ) try: audio_files = await asyncio.to_thread(conversations_db.create_audio_files_from_chunks, uid, conv_id) if audio_files: @@ -214,6 +214,47 @@ async def process_private_cloud_queue(): ) except Exception as e: logger.error(f"Error updating audio files: {e} {uid} {conv_id}") + except Exception as e: + if retries < PRIVATE_CLOUD_SYNC_MAX_RETRIES: + batch['retries'] = retries + 1 + batch['data'] = bytearray(chunk_data) + batch['queued_at'] = time.monotonic() # reset age so next retry waits ~60s + pending[conv_id] = batch + logger.error(f"Private cloud batch upload failed (retry {retries + 1}): {e} {uid} {conv_id}") + else: + logger.info( + f"Private cloud batch upload failed after {PRIVATE_CLOUD_SYNC_MAX_RETRIES} retries, dropping: {e} {uid} {conv_id}" + ) + del chunk_data + + while websocket_active or len(private_cloud_queue) > 0 or len(pending) > 0: + await asyncio.sleep(PRIVATE_CLOUD_SYNC_PROCESS_INTERVAL) + + # Drain queue into pending batches + if private_cloud_queue: + chunks_to_process = private_cloud_queue.copy() + private_cloud_queue.clear() + for chunk_info in chunks_to_process: + _add_to_batch(chunk_info) + + if not pending: + continue + + now = time.monotonic() + batch_size_threshold = sample_rate * 2 * PRIVATE_CLOUD_CHUNK_DURATION + + 
# Determine which conversations to flush + conv_ids_to_flush = [] + for conv_id, batch in pending.items(): + batch_age = now - batch['queued_at'] + is_shutdown = not websocket_active + is_size_ready = len(batch['data']) >= batch_size_threshold + is_age_ready = batch_age >= PRIVATE_CLOUD_BATCH_MAX_AGE + if is_shutdown or is_size_ready or is_age_ready: + conv_ids_to_flush.append(conv_id) + + for conv_id in conv_ids_to_flush: + await _flush_batch(conv_id) async def process_speaker_sample_queue(): """Background task that processes speaker sample extraction requests.""" @@ -334,7 +375,28 @@ async def receive_tasks(): # Conversation ID if header_type == 103: - current_conversation_id = bytes(data[4:]).decode("utf-8") + new_conversation_id = bytes(data[4:]).decode("utf-8") + # Flush private cloud buffer for the old conversation before switching + if ( + private_cloud_sync_enabled + and current_conversation_id + and current_conversation_id != new_conversation_id + and len(private_cloud_sync_buffer) > 0 + ): + private_cloud_queue.append( + { + 'data': bytes(private_cloud_sync_buffer), + 'conversation_id': current_conversation_id, + 'timestamp': private_cloud_chunk_start_time or time.time(), + 'retries': 0, + } + ) + logger.info( + f"Flushed private cloud buffer on conversation switch: {len(private_cloud_sync_buffer)} bytes {uid}" + ) + private_cloud_sync_buffer = bytearray() + private_cloud_chunk_start_time = None + current_conversation_id = new_conversation_id logger.info(f"Pusher received conversation_id: {current_conversation_id} {uid}") continue @@ -405,7 +467,7 @@ async def receive_tasks(): private_cloud_chunk_start_time = buffer_start_timestamp private_cloud_sync_buffer.extend(audio_data) - # Queue chunk every 5 seconds (sample_rate * 2 bytes per sample * 5 seconds) + # Queue chunk every PRIVATE_CLOUD_CHUNK_DURATION seconds if len(private_cloud_sync_buffer) >= sample_rate * 2 * PRIVATE_CLOUD_CHUNK_DURATION: if len(private_cloud_queue) >= PRIVATE_CLOUD_QUEUE_WARN_SIZE: logger.warning(f"Warning: private_cloud_queue size {len(private_cloud_queue)} {uid}") diff --git a/backend/test.sh b/backend/test.sh index b2dafeb195b..7fef1458351 100755 --- a/backend/test.sh +++ b/backend/test.sh @@ -39,8 +39,11 @@ pytest tests/unit/test_translation_optimization.py -v pytest tests/unit/test_conversation_source_unknown.py -v pytest tests/unit/test_transcribe_conversation_cache.py -v pytest tests/unit/test_pusher_private_cloud_data_protection.py -v +pytest tests/unit/test_pusher_batch_upload.py -v pytest tests/unit/test_storage_upload_audio_chunk_data_protection.py -v +pytest tests/unit/test_storage_opus_encoding.py -v pytest tests/unit/test_people_conversations_500s.py -v pytest tests/unit/test_firestore_read_ops_cache.py -v pytest tests/unit/test_ws_auth_handshake.py -v pytest tests/unit/test_streaming_deepgram_backoff.py -v +pytest tests/unit/test_batch_upload_storage.py -v diff --git a/backend/tests/unit/test_batch_upload_storage.py b/backend/tests/unit/test_batch_upload_storage.py new file mode 100644 index 00000000000..08045593a20 --- /dev/null +++ b/backend/tests/unit/test_batch_upload_storage.py @@ -0,0 +1,676 @@ +"""Unit tests for upload_audio_chunks_batch (#5418 Phase 2). + +Verifies: +1. Batch upload with multiple chunks — streams to GCS +2. Single chunk batch uses single timestamp filename +3. Encrypted batch upload (enhanced protection) +4. Empty batch returns empty list +5. DB lookup count — only one fetch per batch when level is None +6. 
Unsorted input produces correctly ordered upload +""" + +import os +import sys +from unittest.mock import MagicMock, patch + +os.environ.setdefault("ENCRYPTION_SECRET", "omi_ZwB2ZNqB2HHpMK6wStk7sTpavJiPTFg7gXUHnc4tFABPU6pZ2c2DKgehtfgi4RZv") + +# Mock heavy dependencies at sys.modules level before importing storage +sys.modules.setdefault("database._client", MagicMock()) + +_mock_gcs_storage = MagicMock() +_mock_gcs_client_instance = MagicMock() +_mock_gcs_storage.Client.return_value = _mock_gcs_client_instance +sys.modules.setdefault("google.cloud.storage", _mock_gcs_storage) +sys.modules.setdefault("google.cloud.storage.transfer_manager", MagicMock()) +sys.modules.setdefault("google.cloud.exceptions", MagicMock()) +sys.modules.setdefault("google.oauth2", MagicMock()) +sys.modules.setdefault("google.oauth2.service_account", MagicMock()) + +from utils.other import storage as storage_mod + + +class _FakeNotFound(Exception): + """Fake NotFound exception for testing (storage_mod.NotFound is mocked).""" + + pass + + +def _collect_written_bytes(mock_blob): + """Collect all bytes written via blob.open().__enter__().write() calls.""" + mock_file = mock_blob.open.return_value.__enter__.return_value + written = b'' + for c in mock_file.write.call_args_list: + written += c[0][0] + return written + + +class TestBatchUpload: + """Tests for upload_audio_chunks_batch streaming to GCS.""" + + def _setup_mock_bucket(self): + mock_bucket = MagicMock() + mock_blob = MagicMock() + mock_bucket.blob.return_value = mock_blob + storage_mod.storage_client.bucket.return_value = mock_bucket + return mock_bucket, mock_blob + + @patch.object(storage_mod, 'users_db') + def test_batch_multiple_chunks_standard(self, mock_users_db): + """Multiple chunks streamed as single .batch.bin object.""" + mock_bucket, mock_blob = self._setup_mock_bucket() + + chunks = [ + {'data': b'\x01' * 100, 'timestamp': 1000.000}, + {'data': b'\x02' * 100, 'timestamp': 1005.000}, + {'data': b'\x03' * 100, 'timestamp': 1010.000}, + ] + + paths = storage_mod.upload_audio_chunks_batch( + chunks=chunks, + uid='test-uid', + conversation_id='conv-1', + data_protection_level='standard', + ) + + assert len(paths) == 1 + assert paths[0].endswith('.batch.bin') + assert '1000.000-1010.000' in paths[0] + # Verify streaming write was used (blob.open, not upload_from_string) + mock_blob.open.assert_called_once() + written = _collect_written_bytes(mock_blob) + assert len(written) == 300 + assert written[:100] == b'\x01' * 100 + assert written[100:200] == b'\x02' * 100 + assert written[200:] == b'\x03' * 100 + + @patch.object(storage_mod, 'users_db') + def test_single_chunk_batch(self, mock_users_db): + """Single chunk batch uses single timestamp in filename.""" + mock_bucket, mock_blob = self._setup_mock_bucket() + + chunks = [{'data': b'\x01' * 50, 'timestamp': 1000.000}] + + paths = storage_mod.upload_audio_chunks_batch( + chunks=chunks, + uid='test-uid', + conversation_id='conv-1', + data_protection_level='standard', + ) + + assert len(paths) == 1 + assert '1000.000.batch.bin' in paths[0] + mock_blob.open.assert_called_once() + + @patch.object(storage_mod, 'encryption') + @patch.object(storage_mod, 'users_db') + def test_batch_encrypted(self, mock_users_db, mock_encryption): + """Enhanced protection encrypts each chunk and streams to GCS.""" + mock_bucket, mock_blob = self._setup_mock_bucket() + mock_encryption.encrypt_audio_chunk.return_value = b'\xee' * 120 + + chunks = [ + {'data': b'\x01' * 100, 'timestamp': 1000.000}, + {'data': b'\x02' * 100, 
'timestamp': 1005.000}, + ] + + paths = storage_mod.upload_audio_chunks_batch( + chunks=chunks, + uid='test-uid', + conversation_id='conv-1', + data_protection_level='enhanced', + ) + + assert len(paths) == 1 + assert paths[0].endswith('.batch.enc') + assert mock_encryption.encrypt_audio_chunk.call_count == 2 + written = _collect_written_bytes(mock_blob) + assert len(written) == 240 + + @patch.object(storage_mod, 'users_db') + def test_empty_batch_returns_empty(self, mock_users_db): + """Empty chunk list returns empty list without any GCS ops.""" + mock_bucket, mock_blob = self._setup_mock_bucket() + + paths = storage_mod.upload_audio_chunks_batch( + chunks=[], + uid='test-uid', + conversation_id='conv-1', + data_protection_level='standard', + ) + + assert paths == [] + mock_blob.open.assert_not_called() + + @patch.object(storage_mod, 'users_db') + def test_db_lookup_once_per_batch(self, mock_users_db): + """When data_protection_level is None, DB is queried exactly once per batch.""" + mock_bucket, mock_blob = self._setup_mock_bucket() + mock_users_db.get_data_protection_level.return_value = 'standard' + + chunks = [ + {'data': b'\x01' * 50, 'timestamp': 1000.000}, + {'data': b'\x02' * 50, 'timestamp': 1005.000}, + {'data': b'\x03' * 50, 'timestamp': 1010.000}, + ] + + storage_mod.upload_audio_chunks_batch( + chunks=chunks, + uid='test-uid', + conversation_id='conv-1', + ) + + mock_users_db.get_data_protection_level.assert_called_once_with('test-uid') + + @patch.object(storage_mod, 'users_db') + def test_unsorted_input_produces_ordered_upload(self, mock_users_db): + """Chunks provided out of order are sorted by timestamp before upload.""" + mock_bucket, mock_blob = self._setup_mock_bucket() + + chunks = [ + {'data': b'\x03' * 50, 'timestamp': 1010.000}, + {'data': b'\x01' * 50, 'timestamp': 1000.000}, + {'data': b'\x02' * 50, 'timestamp': 1005.000}, + ] + + paths = storage_mod.upload_audio_chunks_batch( + chunks=chunks, + uid='test-uid', + conversation_id='conv-1', + data_protection_level='standard', + ) + + # Filename should reflect sorted order: first_ts-last_ts + assert '1000.000-1010.000' in paths[0] + # Streamed data should be in timestamp order + written = _collect_written_bytes(mock_blob) + assert written[:50] == b'\x01' * 50 + assert written[50:100] == b'\x02' * 50 + assert written[100:] == b'\x03' * 50 + + @patch.object(storage_mod, 'users_db') + def test_skips_db_when_level_provided(self, mock_users_db): + """When data_protection_level is explicitly provided, no DB read.""" + mock_bucket, mock_blob = self._setup_mock_bucket() + + storage_mod.upload_audio_chunks_batch( + chunks=[{'data': b'\x01' * 50, 'timestamp': 1000.000}], + uid='test-uid', + conversation_id='conv-1', + data_protection_level='standard', + ) + + mock_users_db.get_data_protection_level.assert_not_called() + + @patch.object(storage_mod, 'users_db') + def test_large_batch_streams_correctly(self, mock_users_db): + """Large batch (50 chunks) streams without regression.""" + mock_bucket, mock_blob = self._setup_mock_bucket() + + chunks = [{'data': b'\xaa' * 80_000, 'timestamp': 1000.000 + i * 5.0} for i in range(50)] + + paths = storage_mod.upload_audio_chunks_batch( + chunks=chunks, + uid='test-uid', + conversation_id='conv-1', + data_protection_level='standard', + ) + + assert len(paths) == 1 + assert paths[0].endswith('.batch.bin') + mock_file = mock_blob.open.return_value.__enter__.return_value + assert mock_file.write.call_count == 50 + + @patch.object(storage_mod, 'users_db') + def 
test_identical_timestamps_filename_and_order(self, mock_users_db): + """Chunks with identical timestamps produce valid filename and stable order.""" + mock_bucket, mock_blob = self._setup_mock_bucket() + + chunks = [ + {'data': b'\x01' * 50, 'timestamp': 1000.000}, + {'data': b'\x02' * 50, 'timestamp': 1000.000}, + ] + + paths = storage_mod.upload_audio_chunks_batch( + chunks=chunks, + uid='test-uid', + conversation_id='conv-1', + data_protection_level='standard', + ) + + # Same first_ts and last_ts → filename uses single timestamp + assert len(paths) == 1 + assert '1000.000.batch.bin' in paths[0] + + @patch.object(storage_mod, 'users_db') + def test_streaming_api_call_args(self, mock_users_db): + """Batch mode uses blob.open('wb') with correct args and does NOT call upload_from_string.""" + mock_bucket, mock_blob = self._setup_mock_bucket() + + chunks = [{'data': b'\x01' * 50, 'timestamp': 1000.000}] + + storage_mod.upload_audio_chunks_batch( + chunks=chunks, + uid='test-uid', + conversation_id='conv-1', + data_protection_level='standard', + ) + + mock_blob.open.assert_called_once_with('wb', content_type='application/octet-stream') + mock_blob.upload_from_string.assert_not_called() + + +class TestListAudioChunksBatchAware: + """Tests for list_audio_chunks handling .batch.bin/.batch.enc files.""" + + def _setup_mock_bucket_with_blobs(self, blob_names): + """Set up mock bucket that returns specified blob names from list_blobs.""" + mock_bucket = MagicMock() + mock_blobs = [] + for name in blob_names: + mock_blob = MagicMock() + mock_blob.name = name + mock_blob.size = 1000 + mock_blobs.append(mock_blob) + mock_bucket.list_blobs.return_value = mock_blobs + storage_mod.storage_client.bucket.return_value = mock_bucket + return mock_bucket + + def test_list_per_chunk_files(self): + """Standard per-chunk .bin files are listed correctly.""" + self._setup_mock_bucket_with_blobs( + [ + 'chunks/uid/conv/1000.000.bin', + 'chunks/uid/conv/1005.000.bin', + ] + ) + + result = storage_mod.list_audio_chunks('uid', 'conv') + assert len(result) == 2 + assert result[0]['timestamp'] == 1000.000 + assert result[1]['timestamp'] == 1005.000 + + def test_list_batch_bin_file(self): + """Batch .batch.bin file is listed with first timestamp.""" + self._setup_mock_bucket_with_blobs( + [ + 'chunks/uid/conv/1000.000-1010.000.batch.bin', + ] + ) + + result = storage_mod.list_audio_chunks('uid', 'conv') + assert len(result) == 1 + assert result[0]['timestamp'] == 1000.000 + assert result[0]['is_batch'] is True + + def test_list_batch_enc_file(self): + """Batch .batch.enc file is listed with first timestamp.""" + self._setup_mock_bucket_with_blobs( + [ + 'chunks/uid/conv/1000.000-1010.000.batch.enc', + ] + ) + + result = storage_mod.list_audio_chunks('uid', 'conv') + assert len(result) == 1 + assert result[0]['timestamp'] == 1000.000 + assert result[0]['is_batch'] is True + + def test_list_single_chunk_batch(self): + """Single-chunk batch file (no range) is listed correctly.""" + self._setup_mock_bucket_with_blobs( + [ + 'chunks/uid/conv/1000.000.batch.bin', + ] + ) + + result = storage_mod.list_audio_chunks('uid', 'conv') + assert len(result) == 1 + assert result[0]['timestamp'] == 1000.000 + assert result[0]['is_batch'] is True + + def test_list_mixed_per_chunk_and_batch(self): + """Mixed per-chunk and batch files are all listed and sorted.""" + self._setup_mock_bucket_with_blobs( + [ + 'chunks/uid/conv/1010.000-1020.000.batch.bin', + 'chunks/uid/conv/1000.000.bin', + 'chunks/uid/conv/1005.000.bin', + ] + ) + + result = 
storage_mod.list_audio_chunks('uid', 'conv') + assert len(result) == 3 + assert result[0]['timestamp'] == 1000.000 + assert result[0]['is_batch'] is False + assert result[1]['timestamp'] == 1005.000 + assert result[2]['timestamp'] == 1010.000 + assert result[2]['is_batch'] is True + + def test_list_skips_meta_json(self): + """Meta JSON files are not listed as chunks.""" + self._setup_mock_bucket_with_blobs( + [ + 'chunks/uid/conv/1000.000.bin', + 'chunks/uid/conv/1000.000.meta.json', + ] + ) + + result = storage_mod.list_audio_chunks('uid', 'conv') + assert len(result) == 1 + + def test_per_chunk_has_is_batch_false(self): + """Per-chunk files have is_batch=False.""" + self._setup_mock_bucket_with_blobs( + [ + 'chunks/uid/conv/1000.000.enc', + ] + ) + + result = storage_mod.list_audio_chunks('uid', 'conv') + assert result[0]['is_batch'] is False + + +class TestDeleteAudioChunksBatchAware: + """Tests for delete_audio_chunks handling batch files.""" + + def _setup_mock_bucket_with_blobs(self, blob_names): + """Set up mock bucket with blobs for exists/delete and list_blobs.""" + mock_bucket = MagicMock() + # Track which blobs exist + blob_map = {} + for name in blob_names: + mock_blob = MagicMock() + mock_blob.name = name + mock_blob.exists.return_value = True + blob_map[name] = mock_blob + + def make_blob(path): + if path in blob_map: + return blob_map[path] + mb = MagicMock() + mb.name = path + mb.exists.return_value = False + return mb + + mock_bucket.blob.side_effect = make_blob + + # list_blobs returns all blobs + list_blobs = [] + for name in blob_names: + lb = MagicMock() + lb.name = name + list_blobs.append(lb) + mock_bucket.list_blobs.return_value = list_blobs + storage_mod.storage_client.bucket.return_value = mock_bucket + return mock_bucket, blob_map + + def test_delete_per_chunk_by_timestamp(self): + """Per-chunk files are deleted by exact timestamp match.""" + mock_bucket, blob_map = self._setup_mock_bucket_with_blobs( + [ + 'chunks/uid/conv/1000.000.bin', + ] + ) + + storage_mod.delete_audio_chunks('uid', 'conv', [1000.000]) + + blob_map['chunks/uid/conv/1000.000.bin'].delete.assert_called_once() + + def test_delete_batch_by_start_timestamp(self): + """Batch files are deleted when start timestamp matches.""" + mock_bucket, blob_map = self._setup_mock_bucket_with_blobs( + [ + 'chunks/uid/conv/1000.000-1010.000.batch.bin', + ] + ) + + storage_mod.delete_audio_chunks('uid', 'conv', [1000.000]) + + # The batch blob found via list_blobs scan should be deleted + list_blobs = mock_bucket.list_blobs.return_value + list_blobs[0].delete.assert_called_once() + + def test_delete_tries_batch_extensions(self): + """Direct blob lookup tries .batch.enc and .batch.bin extensions.""" + mock_bucket, blob_map = self._setup_mock_bucket_with_blobs( + [ + 'chunks/uid/conv/1000.000.batch.enc', + ] + ) + + storage_mod.delete_audio_chunks('uid', 'conv', [1000.000]) + + blob_map['chunks/uid/conv/1000.000.batch.enc'].delete.assert_called_once() + + +class TestDownloadAudioChunksMergeBatchAware: + """Tests for download_audio_chunks_and_merge handling batch blobs.""" + + def _setup_mock_bucket(self, list_blobs_data, download_data): + """ + Set up mock bucket for download tests. 
+ list_blobs_data: list of (name, size) tuples for list_audio_chunks + download_data: dict of path -> bytes for download_as_bytes + """ + mock_bucket = MagicMock() + + # list_blobs for list_audio_chunks + list_blobs = [] + for name, size in list_blobs_data: + mb = MagicMock() + mb.name = name + mb.size = size + list_blobs.append(mb) + mock_bucket.list_blobs.return_value = list_blobs + + # blob download + def make_blob(path): + mb = MagicMock() + mb.name = path + if path in download_data: + mb.download_as_bytes.return_value = download_data[path] + else: + mb.download_as_bytes.side_effect = _FakeNotFound(f"Not found: {path}") + return mb + + mock_bucket.blob.side_effect = make_blob + storage_mod.storage_client.bucket.return_value = mock_bucket + return mock_bucket + + @patch.object(storage_mod, 'NotFound', _FakeNotFound) + def test_download_batch_blob_found(self): + """Batch blob is resolved via list_audio_chunks and downloaded once.""" + batch_path = 'chunks/uid/conv/1000.000-1010.000.batch.bin' + batch_data = b'\x01' * 100 + b'\x02' * 100 + + self._setup_mock_bucket( + list_blobs_data=[(batch_path, 200)], + download_data={batch_path: batch_data}, + ) + + result = storage_mod.download_audio_chunks_and_merge( + uid='uid', + conversation_id='conv', + timestamps=[1000.000], + fill_gaps=False, + ) + + assert result == batch_data + + @patch.object(storage_mod, 'NotFound', _FakeNotFound) + def test_download_per_chunk_still_works(self): + """Per-chunk .bin files are still downloaded correctly.""" + self._setup_mock_bucket( + list_blobs_data=[ + ('chunks/uid/conv/1000.000.bin', 100), + ('chunks/uid/conv/1005.000.bin', 100), + ], + download_data={ + 'chunks/uid/conv/1000.000.bin': b'\x01' * 100, + 'chunks/uid/conv/1005.000.bin': b'\x02' * 100, + }, + ) + + result = storage_mod.download_audio_chunks_and_merge( + uid='uid', + conversation_id='conv', + timestamps=[1000.000, 1005.000], + fill_gaps=False, + ) + + assert result == b'\x01' * 100 + b'\x02' * 100 + + @patch.object(storage_mod, 'NotFound', _FakeNotFound) + def test_download_batch_deduplicates(self): + """Multiple timestamps pointing to same batch blob download it once.""" + batch_path = 'chunks/uid/conv/1000.000-1010.000.batch.bin' + batch_data = b'\xaa' * 300 + + mock_bucket = self._setup_mock_bucket( + # list_audio_chunks only returns batch with first timestamp + list_blobs_data=[(batch_path, 300)], + download_data={batch_path: batch_data}, + ) + + result = storage_mod.download_audio_chunks_and_merge( + uid='uid', + conversation_id='conv', + timestamps=[1000.000], + fill_gaps=False, + ) + + assert result == batch_data + + @patch.object(storage_mod, 'NotFound', _FakeNotFound) + @patch.object(storage_mod, 'encryption') + def test_download_batch_encrypted_decrypts(self, mock_encryption): + """Encrypted batch blob is decrypted via decrypt_audio_file.""" + batch_path = 'chunks/uid/conv/1000.000-1010.000.batch.enc' + encrypted_data = b'\xee' * 200 + decrypted_data = b'\xdd' * 180 + + mock_encryption.decrypt_audio_file.return_value = decrypted_data + + self._setup_mock_bucket( + list_blobs_data=[(batch_path, 200)], + download_data={batch_path: encrypted_data}, + ) + + result = storage_mod.download_audio_chunks_and_merge( + uid='uid', + conversation_id='conv', + timestamps=[1000.000], + fill_gaps=False, + ) + + mock_encryption.decrypt_audio_file.assert_called_once_with(encrypted_data, 'uid') + assert result == decrypted_data + + +class TestCopyAudioChunksForMergeBatchAware: + """Tests for _copy_audio_chunks_for_merge preserving batch blob 
filenames.""" + + @classmethod + def setup_class(cls): + """Mock heavy transitive imports before loading merge_conversations.""" + for mod_name in [ + 'openai', + 'openai.resources', + 'openai._client', + 'utils.llm', + 'utils.llm.clients', + 'utils.apps', + 'database.apps', + 'database.memories', + 'database.tasks', + 'database.plugins', + 'database.notifications', + ]: + sys.modules.setdefault(mod_name, MagicMock()) + + @patch('utils.conversations.merge_conversations.conversations_db') + @patch('utils.conversations.merge_conversations.list_audio_chunks') + @patch('utils.conversations.merge_conversations.storage_client') + def test_copy_preserves_batch_filename(self, mock_storage_client, mock_list, mock_conv_db): + """Batch blob filenames are preserved during copy (not renamed to single-timestamp).""" + from utils.conversations.merge_conversations import _copy_audio_chunks_for_merge + + mock_bucket = MagicMock() + mock_storage_client.bucket.return_value = mock_bucket + + mock_list.return_value = [ + { + 'timestamp': 1000.000, + 'path': 'chunks/uid/conv-old/1000.000-1060.000.batch.bin', + 'size': 960000, + 'is_batch': True, + } + ] + mock_conv_db.create_audio_files_from_chunks.return_value = [] + + _copy_audio_chunks_for_merge('uid', [{'id': 'conv-old'}], 'conv-new') + + # Verify the copy preserved the batch filename + copy_call = mock_bucket.copy_blob.call_args + new_path = copy_call[0][2] # third positional arg is new_name + assert new_path == 'chunks/uid/conv-new/1000.000-1060.000.batch.bin' + + @patch('utils.conversations.merge_conversations.conversations_db') + @patch('utils.conversations.merge_conversations.list_audio_chunks') + @patch('utils.conversations.merge_conversations.storage_client') + def test_copy_preserves_single_chunk_filename(self, mock_storage_client, mock_list, mock_conv_db): + """Single-chunk filenames are also preserved during copy.""" + from utils.conversations.merge_conversations import _copy_audio_chunks_for_merge + + mock_bucket = MagicMock() + mock_storage_client.bucket.return_value = mock_bucket + + mock_list.return_value = [ + { + 'timestamp': 1000.000, + 'path': 'chunks/uid/conv-old/1000.000.bin', + 'size': 80000, + 'is_batch': False, + } + ] + mock_conv_db.create_audio_files_from_chunks.return_value = [] + + _copy_audio_chunks_for_merge('uid', [{'id': 'conv-old'}], 'conv-new') + + copy_call = mock_bucket.copy_blob.call_args + new_path = copy_call[0][2] + assert new_path == 'chunks/uid/conv-new/1000.000.bin' + + @patch('utils.conversations.merge_conversations.conversations_db') + @patch('utils.conversations.merge_conversations.list_audio_chunks') + @patch('utils.conversations.merge_conversations.storage_client') + def test_copy_mixed_single_and_batch(self, mock_storage_client, mock_list, mock_conv_db): + """Mixed single + batch blobs are all copied with original filenames.""" + from utils.conversations.merge_conversations import _copy_audio_chunks_for_merge + + mock_bucket = MagicMock() + mock_storage_client.bucket.return_value = mock_bucket + + mock_list.return_value = [ + { + 'timestamp': 1000.000, + 'path': 'chunks/uid/conv-old/1000.000.enc', + 'size': 80000, + 'is_batch': False, + }, + { + 'timestamp': 1010.000, + 'path': 'chunks/uid/conv-old/1010.000-1070.000.batch.enc', + 'size': 960000, + 'is_batch': True, + }, + ] + mock_conv_db.create_audio_files_from_chunks.return_value = [] + + _copy_audio_chunks_for_merge('uid', [{'id': 'conv-old'}], 'conv-new') + + assert mock_bucket.copy_blob.call_count == 2 + paths = [call[0][2] for call in 
mock_bucket.copy_blob.call_args_list] + assert 'chunks/uid/conv-new/1000.000.enc' in paths + assert 'chunks/uid/conv-new/1010.000-1070.000.batch.enc' in paths diff --git a/backend/tests/unit/test_pusher_batch_upload.py b/backend/tests/unit/test_pusher_batch_upload.py new file mode 100644 index 00000000000..4ecc5bbd2ba --- /dev/null +++ b/backend/tests/unit/test_pusher_batch_upload.py @@ -0,0 +1,516 @@ +"""Unit tests for pusher private cloud batch upload logic (Phase 2 of #5418). + +Tests the batching behavior in process_private_cloud_queue() without importing +the full pusher module. Mirrors the pattern used in test_pusher_private_cloud_data_protection.py. +""" + +import time + +import pytest + +# --- Reimplemented constants (mirrors pusher.py) --- + +PRIVATE_CLOUD_CHUNK_DURATION = 60.0 +PRIVATE_CLOUD_BATCH_MAX_AGE = 60.0 +PRIVATE_CLOUD_SYNC_MAX_RETRIES = 3 + + +def _add_to_batch(pending, chunk_info): + """Mirrors the _add_to_batch inner function in process_private_cloud_queue.""" + conv_id = chunk_info['conversation_id'] + if conv_id not in pending: + pending[conv_id] = { + 'data': bytearray(), + 'conversation_id': conv_id, + 'timestamp': chunk_info['timestamp'], + 'queued_at': chunk_info.get('_queued_at', time.monotonic()), + 'retries': 0, + } + batch = pending[conv_id] + batch['data'].extend(chunk_info['data']) + + +def _get_flush_candidates(pending, sample_rate, now, websocket_active=True): + """Mirrors the flush decision logic in process_private_cloud_queue.""" + batch_size_threshold = sample_rate * 2 * PRIVATE_CLOUD_CHUNK_DURATION + conv_ids_to_flush = [] + for conv_id, batch in pending.items(): + batch_age = now - batch['queued_at'] + is_shutdown = not websocket_active + is_size_ready = len(batch['data']) >= batch_size_threshold + is_age_ready = batch_age >= PRIVATE_CLOUD_BATCH_MAX_AGE + if is_shutdown or is_size_ready or is_age_ready: + conv_ids_to_flush.append(conv_id) + return conv_ids_to_flush + + +class TestBatchAccumulation: + """Tests that chunks for the same conversation accumulate into one batch.""" + + def test_multiple_chunks_same_conversation_batched(self): + """Multiple chunks for the same conversation produce one batch entry.""" + pending = {} + now = time.monotonic() + for i in range(12): + _add_to_batch( + pending, + { + 'data': b'\x00' * 80_000, + 'conversation_id': 'conv-1', + 'timestamp': 1000.0 + i * 5.0, + '_queued_at': now, + }, + ) + assert len(pending) == 1 + assert 'conv-1' in pending + assert len(pending['conv-1']['data']) == 80_000 * 12 + # Oldest timestamp preserved + assert pending['conv-1']['timestamp'] == 1000.0 + + def test_different_conversations_separate_batches(self): + """Chunks for different conversations go to separate batches.""" + pending = {} + now = time.monotonic() + _add_to_batch( + pending, {'data': b'\x01' * 100, 'conversation_id': 'conv-A', 'timestamp': 1.0, '_queued_at': now} + ) + _add_to_batch( + pending, {'data': b'\x02' * 200, 'conversation_id': 'conv-B', 'timestamp': 2.0, '_queued_at': now} + ) + _add_to_batch( + pending, {'data': b'\x03' * 150, 'conversation_id': 'conv-A', 'timestamp': 3.0, '_queued_at': now} + ) + assert len(pending) == 2 + assert len(pending['conv-A']['data']) == 250 + assert len(pending['conv-B']['data']) == 200 + # Oldest timestamps preserved + assert pending['conv-A']['timestamp'] == 1.0 + assert pending['conv-B']['timestamp'] == 2.0 + + +class TestSizeFlush: + """Tests that batches flush when they reach 60s of audio data.""" + + def test_flush_at_60s_threshold(self): + """Batch flushes when accumulated data 
reaches 60s at sample_rate=8000.""" + pending = {} + sample_rate = 8000 + now = time.monotonic() + # 60s of PCM16 at 8kHz = 8000 * 2 * 60 = 960,000 bytes + _add_to_batch( + pending, + { + 'data': b'\x00' * 960_000, + 'conversation_id': 'conv-1', + 'timestamp': 100.0, + '_queued_at': now, + }, + ) + flush = _get_flush_candidates(pending, sample_rate, now, websocket_active=True) + assert 'conv-1' in flush + + def test_no_flush_just_below_threshold(self): + """Batch does NOT flush when 1 byte below the 60s threshold.""" + pending = {} + sample_rate = 8000 + now = time.monotonic() + threshold = sample_rate * 2 * 60 # 960,000 + _add_to_batch( + pending, + { + 'data': b'\x00' * (threshold - 1), + 'conversation_id': 'conv-1', + 'timestamp': 100.0, + '_queued_at': now, + }, + ) + flush = _get_flush_candidates(pending, sample_rate, now, websocket_active=True) + assert 'conv-1' not in flush + + def test_no_flush_below_threshold(self): + """Batch does NOT flush when below 60s of data and within max age.""" + pending = {} + sample_rate = 8000 + now = time.monotonic() + # 30s of audio = 480,000 bytes, well under 960,000 + _add_to_batch( + pending, + { + 'data': b'\x00' * 480_000, + 'conversation_id': 'conv-1', + 'timestamp': 100.0, + '_queued_at': now, + }, + ) + flush = _get_flush_candidates(pending, sample_rate, now, websocket_active=True) + assert 'conv-1' not in flush + + +class TestMaxAgeFlush: + """Tests that the 60s max-age timer forces flush of idle conversations.""" + + def test_flush_after_max_age(self): + """Sub-threshold batch flushes when oldest chunk exceeds 60s age.""" + pending = {} + sample_rate = 8000 + old_time = time.monotonic() - 61.0 # 61 seconds ago + _add_to_batch( + pending, + { + 'data': b'\x00' * 1000, + 'conversation_id': 'conv-1', + 'timestamp': 100.0, + '_queued_at': old_time, + }, + ) + now = time.monotonic() + flush = _get_flush_candidates(pending, sample_rate, now, websocket_active=True) + assert 'conv-1' in flush + + def test_flush_at_exact_max_age(self): + """Batch flushes when age equals exactly 60s (>=).""" + pending = {} + sample_rate = 8000 + now = 1000.0 + queued_at = now - 60.0 # exactly 60s ago + pending['conv-1'] = { + 'data': bytearray(b'\x00' * 1000), + 'conversation_id': 'conv-1', + 'timestamp': 100.0, + 'queued_at': queued_at, + 'retries': 0, + } + flush = _get_flush_candidates(pending, sample_rate, now, websocket_active=True) + assert 'conv-1' in flush + + def test_no_flush_just_before_max_age(self): + """Batch does NOT flush when 0.1s before max age.""" + pending = {} + sample_rate = 8000 + now = 1000.0 + queued_at = now - 59.9 # 59.9s ago + pending['conv-1'] = { + 'data': bytearray(b'\x00' * 1000), + 'conversation_id': 'conv-1', + 'timestamp': 100.0, + 'queued_at': queued_at, + 'retries': 0, + } + flush = _get_flush_candidates(pending, sample_rate, now, websocket_active=True) + assert 'conv-1' not in flush + + def test_no_flush_before_max_age(self): + """Batch within max-age window does not flush.""" + pending = {} + sample_rate = 8000 + recent_time = time.monotonic() - 30.0 # 30 seconds ago + _add_to_batch( + pending, + { + 'data': b'\x00' * 1000, + 'conversation_id': 'conv-1', + 'timestamp': 100.0, + '_queued_at': recent_time, + }, + ) + now = time.monotonic() + flush = _get_flush_candidates(pending, sample_rate, now, websocket_active=True) + assert 'conv-1' not in flush + + +class TestRetryBackoff: + """Tests that failed uploads reset queued_at for natural backoff.""" + + def test_retry_resets_queued_at(self): + """After upload failure, queued_at is 
reset so batch won't re-flush for ~60s.""" + pending = {} + sample_rate = 8000 + now = 1000.0 + + # Simulate a batch that failed upload — retry logic resets queued_at + failed_batch = { + 'data': bytearray(b'\x00' * 1000), + 'conversation_id': 'conv-1', + 'timestamp': 100.0, + 'queued_at': now, # reset to "now" by retry logic + 'retries': 1, + } + pending['conv-1'] = failed_batch + + # 59.9s later — should NOT flush yet (backoff) + flush = _get_flush_candidates(pending, sample_rate, now + 59.9, websocket_active=True) + assert 'conv-1' not in flush + + # 60s later — should flush (backoff expired) + flush = _get_flush_candidates(pending, sample_rate, now + 60.0, websocket_active=True) + assert 'conv-1' in flush + + def test_retry_preserves_data_and_increments_count(self): + """Retry preserves chunk data and increments retry count.""" + # Mirrors the retry logic in _flush_batch + batch = { + 'data': bytearray(b'\xab' * 500), + 'conversation_id': 'conv-1', + 'timestamp': 100.0, + 'queued_at': 900.0, + 'retries': 0, + } + # Simulate failed upload — retry path + retries = batch['retries'] + chunk_data = bytes(batch['data']) + batch['retries'] = retries + 1 + batch['data'] = bytearray(chunk_data) + batch['queued_at'] = 1000.0 # reset + + assert batch['retries'] == 1 + assert len(batch['data']) == 500 + assert batch['queued_at'] == 1000.0 + + +class TestShutdownFlush: + """Tests that shutdown forces flush of all pending batches regardless of size/age.""" + + def test_shutdown_flushes_all_pending(self): + """All conversations flush on shutdown even if below thresholds.""" + pending = {} + sample_rate = 8000 + now = time.monotonic() + _add_to_batch( + pending, {'data': b'\x00' * 100, 'conversation_id': 'conv-A', 'timestamp': 1.0, '_queued_at': now} + ) + _add_to_batch( + pending, {'data': b'\x00' * 200, 'conversation_id': 'conv-B', 'timestamp': 2.0, '_queued_at': now} + ) + flush = _get_flush_candidates(pending, sample_rate, now, websocket_active=False) + assert set(flush) == {'conv-A', 'conv-B'} + + +class TestConversationSwitch: + """Tests that conversation switch flushes old conversation buffer.""" + + def test_conversation_switch_flushes_old_buffer(self): + """Mirrors the conversation switch flush in receive_tasks (header_type 103).""" + private_cloud_sync_buffer = bytearray(b'\x00' * 500) + current_conversation_id = 'conv-old' + new_conversation_id = 'conv-new' + private_cloud_chunk_start_time = 100.0 + private_cloud_queue = [] + + # Reproduce the flush logic from pusher.py header_type == 103 + if ( + current_conversation_id + and current_conversation_id != new_conversation_id + and len(private_cloud_sync_buffer) > 0 + ): + private_cloud_queue.append( + { + 'data': bytes(private_cloud_sync_buffer), + 'conversation_id': current_conversation_id, + 'timestamp': private_cloud_chunk_start_time or time.time(), + 'retries': 0, + } + ) + private_cloud_sync_buffer = bytearray() + private_cloud_chunk_start_time = None + + assert len(private_cloud_queue) == 1 + assert private_cloud_queue[0]['conversation_id'] == 'conv-old' + assert len(private_cloud_queue[0]['data']) == 500 + assert len(private_cloud_sync_buffer) == 0 + + def test_no_flush_on_same_conversation_id(self): + """No flush if conversation_id doesn't change.""" + private_cloud_sync_buffer = bytearray(b'\x00' * 500) + current_conversation_id = 'conv-1' + new_conversation_id = 'conv-1' + private_cloud_queue = [] + + if ( + current_conversation_id + and current_conversation_id != new_conversation_id + and len(private_cloud_sync_buffer) > 0 + ): + 
private_cloud_queue.append( + { + 'data': bytes(private_cloud_sync_buffer), + 'conversation_id': current_conversation_id, + 'timestamp': time.time(), + 'retries': 0, + } + ) + private_cloud_sync_buffer = bytearray() + + assert len(private_cloud_queue) == 0 + assert len(private_cloud_sync_buffer) == 500 + + def test_no_flush_on_empty_buffer(self): + """No flush when buffer is empty even if conversation_id changes.""" + private_cloud_sync_buffer = bytearray() + current_conversation_id = 'conv-old' + new_conversation_id = 'conv-new' + private_cloud_queue = [] + private_cloud_sync_enabled = True + + if ( + private_cloud_sync_enabled + and current_conversation_id + and current_conversation_id != new_conversation_id + and len(private_cloud_sync_buffer) > 0 + ): + private_cloud_queue.append( + { + 'data': bytes(private_cloud_sync_buffer), + 'conversation_id': current_conversation_id, + 'timestamp': time.time(), + 'retries': 0, + } + ) + + assert len(private_cloud_queue) == 0 + + def test_no_flush_when_no_current_conversation(self): + """No flush when current_conversation_id is None.""" + private_cloud_sync_buffer = bytearray(b'\x00' * 500) + current_conversation_id = None + new_conversation_id = 'conv-new' + private_cloud_queue = [] + private_cloud_sync_enabled = True + + if ( + private_cloud_sync_enabled + and current_conversation_id + and current_conversation_id != new_conversation_id + and len(private_cloud_sync_buffer) > 0 + ): + private_cloud_queue.append( + { + 'data': bytes(private_cloud_sync_buffer), + 'conversation_id': current_conversation_id, + 'timestamp': time.time(), + 'retries': 0, + } + ) + + assert len(private_cloud_queue) == 0 + + +# --- Tests for conversations.py gap threshold and duration logic --- + + +def _finalize_audio_file_group_duration(chunk_group): + """Mirrors _finalize_audio_file_group duration calculation from conversations.py.""" + from datetime import datetime, timezone + + started_at = datetime.fromtimestamp(chunk_group[0]['timestamp'], tz=timezone.utc) + last_chunk_start = datetime.fromtimestamp(chunk_group[-1]['timestamp'], tz=timezone.utc) + last_chunk_size = chunk_group[-1].get('size', 0) + last_chunk_duration = last_chunk_size / 16000.0 if last_chunk_size > 0 else 5.0 + duration = (last_chunk_start - started_at).total_seconds() + last_chunk_duration + return duration + + +def _group_chunks_by_gap(chunks, gap_threshold=90): + """Mirrors create_audio_files_from_chunks gap grouping from conversations.py.""" + groups = [] + current_group = [] + for chunk in chunks: + if not current_group: + current_group.append(chunk) + else: + time_gap = chunk['timestamp'] - current_group[-1]['timestamp'] + if time_gap > gap_threshold: + groups.append(current_group) + current_group = [chunk] + else: + current_group.append(chunk) + if current_group: + groups.append(current_group) + return groups + + +class TestAudioFileGapThreshold: + """Tests for the 90s gap threshold in create_audio_files_from_chunks.""" + + def test_gap_at_90s_no_split(self): + """Chunks 90s apart should NOT split (gap <= threshold).""" + chunks = [ + {'timestamp': 1000.0, 'size': 960_000}, + {'timestamp': 1090.0, 'size': 960_000}, + ] + groups = _group_chunks_by_gap(chunks, gap_threshold=90) + assert len(groups) == 1 + assert len(groups[0]) == 2 + + def test_gap_at_91s_splits(self): + """Chunks 91s apart should split (gap > threshold).""" + chunks = [ + {'timestamp': 1000.0, 'size': 960_000}, + {'timestamp': 1091.0, 'size': 960_000}, + ] + groups = _group_chunks_by_gap(chunks, gap_threshold=90) + assert 
len(groups) == 2 + assert len(groups[0]) == 1 + assert len(groups[1]) == 1 + + def test_60s_chunks_stay_grouped(self): + """Consecutive 60s chunks (normal batching pattern) stay in one group.""" + chunks = [{'timestamp': 1000.0 + i * 60.0, 'size': 960_000} for i in range(5)] + groups = _group_chunks_by_gap(chunks, gap_threshold=90) + assert len(groups) == 1 + assert len(groups[0]) == 5 + + def test_5s_chunks_stay_grouped(self): + """Legacy 5s chunks still group correctly.""" + chunks = [{'timestamp': 1000.0 + i * 5.0, 'size': 80_000} for i in range(12)] + groups = _group_chunks_by_gap(chunks, gap_threshold=90) + assert len(groups) == 1 + assert len(groups[0]) == 12 + + +class TestAudioFileDurationFromSize: + """Tests for blob-size-based duration calculation in _finalize_audio_file_group.""" + + def test_duration_from_60s_blob(self): + """60s of PCM16 at 8kHz = 960,000 bytes → duration should be ~60.0s.""" + chunks = [{'timestamp': 1000.0, 'size': 960_000}] + duration = _finalize_audio_file_group_duration(chunks) + assert abs(duration - 60.0) < 0.01 + + def test_duration_from_5s_blob(self): + """5s of PCM16 at 8kHz = 80,000 bytes → duration should be ~5.0s.""" + chunks = [{'timestamp': 1000.0, 'size': 80_000}] + duration = _finalize_audio_file_group_duration(chunks) + assert abs(duration - 5.0) < 0.01 + + def test_duration_fallback_no_size(self): + """When size is 0 or missing, falls back to 5.0s.""" + chunks = [{'timestamp': 1000.0, 'size': 0}] + duration = _finalize_audio_file_group_duration(chunks) + assert abs(duration - 5.0) < 0.01 + + chunks_no_key = [{'timestamp': 1000.0}] + duration2 = _finalize_audio_file_group_duration(chunks_no_key) + assert abs(duration2 - 5.0) < 0.01 + + def test_multi_chunk_duration(self): + """Duration across multiple 60s chunks: first→last gap + last chunk duration.""" + chunks = [ + {'timestamp': 1000.0, 'size': 960_000}, + {'timestamp': 1060.0, 'size': 960_000}, + {'timestamp': 1120.0, 'size': 960_000}, + ] + # Expected: (1120 - 1000) + 60.0 = 180.0 + duration = _finalize_audio_file_group_duration(chunks) + assert abs(duration - 180.0) < 0.01 + + +class TestConstants: + """Tests that batch constants are set correctly.""" + + def test_chunk_duration_is_60s(self): + """Chunk duration is 60s for batch upload.""" + assert PRIVATE_CLOUD_CHUNK_DURATION == 60.0 + + def test_batch_max_age_is_60s(self): + """Batch max age is 60s.""" + assert PRIVATE_CLOUD_BATCH_MAX_AGE == 60.0 diff --git a/backend/tests/unit/test_storage_opus_encoding.py b/backend/tests/unit/test_storage_opus_encoding.py new file mode 100644 index 00000000000..5d5b6f43de5 --- /dev/null +++ b/backend/tests/unit/test_storage_opus_encoding.py @@ -0,0 +1,617 @@ +"""Unit tests for Opus encoding/decoding in private cloud sync. 
+
+Verifies:
+- PCM→Opus→PCM roundtrip produces same-length output
+- Compression ratio is significant (>5x for 5s chunks)
+- Feature flag controls whether Opus encoding is used
+- Extension handling for .opus, .opus.enc, .bin, .enc
+- Timestamp parsing works for double-extension filenames
+- Upload produces correct extensions when Opus is enabled
+- Download decodes Opus back to PCM
+"""
+
+import os
+import struct
+import sys
+from unittest.mock import MagicMock, patch
+
+import pytest
+
+os.environ.setdefault("ENCRYPTION_SECRET", "omi_ZwB2ZNqB2HHpMK6wStk7sTpavJiPTFg7gXUHnc4tFABPU6pZ2c2DKgehtfgi4RZv")
+
+# Mock heavy dependencies at sys.modules level before importing storage
+sys.modules.setdefault("database._client", MagicMock())
+
+_mock_gcs_storage = MagicMock()
+_mock_gcs_client_instance = MagicMock()
+_mock_gcs_storage.Client.return_value = _mock_gcs_client_instance
+sys.modules.setdefault("google.cloud.storage", _mock_gcs_storage)
+sys.modules.setdefault("google.cloud.storage.transfer_manager", MagicMock())
+sys.modules.setdefault("google.cloud.exceptions", MagicMock())
+sys.modules.setdefault("google.oauth2", MagicMock())
+sys.modules.setdefault("google.oauth2.service_account", MagicMock())
+
+from utils.other import storage as storage_mod
+
+
+class TestOpusEncodeDecode:
+    """Tests for encode_pcm_to_opus and decode_opus_to_pcm."""
+
+    def test_roundtrip_preserves_length(self):
+        """Encode→decode produces same number of bytes as input."""
+        # 5 seconds of PCM16 at 16kHz mono = 160000 bytes
+        pcm_data = b'\x00' * 160000
+        opus_data = storage_mod.encode_pcm_to_opus(pcm_data)
+        decoded = storage_mod.decode_opus_to_pcm(opus_data)
+        assert len(decoded) == len(pcm_data)
+
+    def test_compression_ratio(self):
+        """Opus should achieve at least 5x compression on 5s PCM chunks."""
+        # Silence compresses very well; real audio ~10-12x
+        pcm_data = b'\x00' * 160000
+        opus_data = storage_mod.encode_pcm_to_opus(pcm_data)
+        ratio = len(pcm_data) / len(opus_data)
+        assert ratio > 5.0, f"Compression ratio {ratio:.1f}x is below 5x minimum"
+
+    def test_small_input_padded(self):
+        """Input smaller than one frame is padded but trimmed to original length on decode."""
+        # 100 bytes = less than one 20ms frame (640 bytes)
+        pcm_data = b'\x80' * 100
+        opus_data = storage_mod.encode_pcm_to_opus(pcm_data)
+        decoded = storage_mod.decode_opus_to_pcm(opus_data)
+        # Decoded length equals original input (trimmed from padded frame)
+        assert len(decoded) == len(pcm_data)
+
+    def test_exact_frame_boundary(self):
+        """Input exactly on frame boundary has no padding."""
+        frame_bytes = storage_mod.OPUS_FRAME_SIZE * storage_mod.OPUS_CHANNELS * 2  # 640
+        pcm_data = b'\x00' * (frame_bytes * 10)  # exactly 10 frames
+        opus_data = storage_mod.encode_pcm_to_opus(pcm_data)
+        decoded = storage_mod.decode_opus_to_pcm(opus_data)
+        assert len(decoded) == len(pcm_data)
+
+    def test_packet_count_header(self):
+        """Opus output starts with correct packet count and original PCM length."""
+        frame_bytes = storage_mod.OPUS_FRAME_SIZE * storage_mod.OPUS_CHANNELS * 2
+        pcm_data = b'\x00' * (frame_bytes * 5)  # 5 frames
+        opus_data = storage_mod.encode_pcm_to_opus(pcm_data)
+        # Header: uint32 LE packet count, then uint32 LE original PCM byte length
+        packet_count = struct.unpack_from('<I', opus_data, 0)[0]
+        original_len = struct.unpack_from('<I', opus_data, 4)[0]
+        assert packet_count == 5
+        assert original_len == len(pcm_data)
+
+
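+# A minimal sketch of the container layout these tests assume for encode_pcm_to_opus
+# output. It is inferred from the header asserts in test_packet_count_header and the
+# hand-crafted bad_opus payload in the fallback tests below (b'\x01\x00\x00\x00' =
+# packet count 1, b'\x80\x02\x00\x00' = PCM length 640, b'\xff\xff' = a deliberately
+# bad uint16 packet length), not taken from the encoder itself:
+#
+#   offset 0    uint32 LE   packet count
+#   offset 4    uint32 LE   original PCM byte length (decode trims padding back to this)
+#   offset 8+   per packet: uint16 LE payload length, then the Opus payload bytes
+#
+# e.g. struct.pack('<II', 1, 640) + struct.pack('<H', len(pkt)) + pkt would frame a
+# single-packet container under this assumed layout.
+
+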
+class _FakeNotFound(Exception):
+    """Fake NotFound exception for testing (storage_mod.NotFound is mocked)."""
+
+    pass
+
+
+class TestExtensionFallback:
+    """Tests for download extension fallback (.opus, .opus.enc, .bin, .enc)."""
+
+    def _blob_factory(self, ext_data_map):
+        """Create a blob factory where ext_data_map maps extension -> bytes.
+        Missing extensions raise _FakeNotFound."""
+
+        def factory(path):
+            mock_blob = MagicMock()
+            for ext, data in ext_data_map.items():
+                if path.endswith(f'.{ext}'):
+                    mock_blob.download_as_bytes.return_value = data
+                    return mock_blob
+            mock_blob.download_as_bytes.side_effect = _FakeNotFound('not found')
+            return mock_blob
+
+        return factory
+
+    @patch.object(storage_mod, 'NotFound', _FakeNotFound)
+    @patch.object(storage_mod, 'encryption')
+    def test_fallback_opus_corrupt_to_legacy_bin(self, mock_encryption):
+        """When .opus.enc exists but decrypt fails, falls back to .bin."""
+        mock_bucket = MagicMock()
+        pcm_data = b'\x00' * 640
+
+        mock_bucket.blob.side_effect = self._blob_factory(
+            {
+                'opus.enc': b'corrupt-opus-data',
+                'bin': pcm_data,
+            }
+        )
+        storage_mod.storage_client.bucket.return_value = mock_bucket
+        mock_encryption.decrypt_audio_file.side_effect = Exception("decrypt failed")
+
+        result = storage_mod.download_audio_chunks_and_merge('uid', 'conv', [1000.0], fill_gaps=False)
+        assert result == pcm_data
+
+    @patch.object(storage_mod, 'NotFound', _FakeNotFound)
+    def test_fallback_all_not_found_raises(self):
+        """When no extension exists for a timestamp, raises FileNotFoundError."""
+        mock_bucket = MagicMock()
+        mock_bucket.blob.side_effect = self._blob_factory({})  # nothing available
+        storage_mod.storage_client.bucket.return_value = mock_bucket
+
+        with pytest.raises(FileNotFoundError):
+            storage_mod.download_audio_chunks_and_merge('uid', 'conv', [1000.0], fill_gaps=False)
+
+    @patch.object(storage_mod, 'NotFound', _FakeNotFound)
+    def test_opus_decode_success_no_fallback(self):
+        """When .opus chunk is valid, uses it without trying .bin."""
+        mock_bucket = MagicMock()
+        pcm_data = b'\x00' * 640
+        opus_data = storage_mod.encode_pcm_to_opus(pcm_data)
+
+        call_log = []
+        original_factory = self._blob_factory({'opus': opus_data})
+
+        def tracking_factory(path):
+            call_log.append(path)
+            return original_factory(path)
+
+        mock_bucket.blob.side_effect = tracking_factory
+        storage_mod.storage_client.bucket.return_value = mock_bucket
+
+        result = storage_mod.download_audio_chunks_and_merge('uid', 'conv', [1000.0], fill_gaps=False)
+        assert len(result) == len(pcm_data)
+        # Should NOT have tried .bin after .opus succeeded
+        assert not any(p.endswith('.bin') for p in call_log)
+
+    @patch.object(storage_mod, 'NotFound', _FakeNotFound)
+    def test_fallback_opus_decode_error_to_bin(self):
+        """When .opus data is malformed (decode raises), falls back to .bin."""
+        mock_bucket = MagicMock()
+        pcm_data = b'\x00' * 640
+        bad_opus = b'\x01\x00\x00\x00\x80\x02\x00\x00\xff\xff'  # 1 pkt, pcm_len=640, bad pkt_len
+
+        mock_bucket.blob.side_effect = self._blob_factory(
+            {
+                'opus': bad_opus,
+                'bin': pcm_data,
+            }
+        )
+        storage_mod.storage_client.bucket.return_value = mock_bucket
+
+        result = storage_mod.download_audio_chunks_and_merge('uid', 'conv', [1000.0], fill_gaps=False)
+        assert result == pcm_data
+
+
+class TestBatchExtensionHelpers:
+    """Tests for batch extension support in helpers and PRIVATE_CLOUD_EXTENSIONS."""
+
+    def test_private_cloud_extensions_includes_batch(self):
+        """PRIVATE_CLOUD_EXTENSIONS includes .batch.bin and .batch.enc."""
+        assert '.batch.bin' in storage_mod.PRIVATE_CLOUD_EXTENSIONS
+        assert '.batch.enc' in storage_mod.PRIVATE_CLOUD_EXTENSIONS
+
+    @pytest.mark.parametrize(
+        "path,expected",
+        [
+            ("chunks/uid/conv/1000.000-1010.000.batch.bin", "batch.bin"),
+            ("chunks/uid/conv/1000.000-1010.000.batch.enc", "batch.enc"),
+            ("chunks/uid/conv/1000.000.batch.bin", 
"batch.bin"), + ], + ) + def test_get_extension_for_batch_path(self, path, expected): + assert storage_mod._get_extension_for_path(path) == expected + + @pytest.mark.parametrize( + "filename,expected", + [ + ("1000.000-1010.000.batch.bin", "1000.000-1010.000"), + ("1000.000-1010.000.batch.enc", "1000.000-1010.000"), + ("1000.000.batch.bin", "1000.000"), + ("1000.000.batch.enc", "1000.000"), + ], + ) + def test_strip_batch_extension(self, filename, expected): + assert storage_mod._strip_extension(filename) == expected + + +class TestListAudioChunksBatch: + """Tests for list_audio_chunks with batch blobs.""" + + def _make_mock_blob(self, name, size=1000): + blob = MagicMock() + blob.name = name + blob.size = size + return blob + + def test_lists_batch_bin_blobs(self): + """list_audio_chunks recognizes .batch.bin with range timestamp.""" + mock_bucket = MagicMock() + mock_bucket.list_blobs.return_value = [ + self._make_mock_blob('chunks/uid/conv/1000.000-1010.000.batch.bin', 480000), + ] + storage_mod.storage_client.bucket.return_value = mock_bucket + + chunks = storage_mod.list_audio_chunks('uid', 'conv') + + assert len(chunks) == 1 + assert chunks[0]['timestamp'] == 1000.0 + assert chunks[0]['is_batch'] is True + assert chunks[0]['path'] == 'chunks/uid/conv/1000.000-1010.000.batch.bin' + + def test_lists_batch_enc_blobs(self): + """list_audio_chunks recognizes .batch.enc with range timestamp.""" + mock_bucket = MagicMock() + mock_bucket.list_blobs.return_value = [ + self._make_mock_blob('chunks/uid/conv/1000.000-1010.000.batch.enc', 500000), + ] + storage_mod.storage_client.bucket.return_value = mock_bucket + + chunks = storage_mod.list_audio_chunks('uid', 'conv') + + assert len(chunks) == 1 + assert chunks[0]['timestamp'] == 1000.0 + assert chunks[0]['is_batch'] is True + + def test_single_timestamp_batch(self): + """Batch blob with single timestamp (short conversation).""" + mock_bucket = MagicMock() + mock_bucket.list_blobs.return_value = [ + self._make_mock_blob('chunks/uid/conv/1000.000.batch.bin', 160000), + ] + storage_mod.storage_client.bucket.return_value = mock_bucket + + chunks = storage_mod.list_audio_chunks('uid', 'conv') + + assert len(chunks) == 1 + assert chunks[0]['timestamp'] == 1000.0 + assert chunks[0]['is_batch'] is True + + def test_mixed_single_and_batch_blobs(self): + """Conversation with both single-chunk and batch blobs (migration period).""" + mock_bucket = MagicMock() + mock_bucket.list_blobs.return_value = [ + self._make_mock_blob('chunks/uid/conv/1000.000.opus', 8000), + self._make_mock_blob('chunks/uid/conv/1005.000.opus', 8000), + self._make_mock_blob('chunks/uid/conv/1010.000-1025.000.batch.bin', 480000), + ] + storage_mod.storage_client.bucket.return_value = mock_bucket + + chunks = storage_mod.list_audio_chunks('uid', 'conv') + + assert len(chunks) == 3 + assert chunks[0]['is_batch'] is False + assert chunks[1]['is_batch'] is False + assert chunks[2]['is_batch'] is True + assert chunks[2]['timestamp'] == 1010.0 + + def test_is_batch_false_for_single_blobs(self): + """Single-chunk blobs have is_batch=False.""" + mock_bucket = MagicMock() + mock_bucket.list_blobs.return_value = [ + self._make_mock_blob('chunks/uid/conv/1000.000.opus.enc', 8000), + ] + storage_mod.storage_client.bucket.return_value = mock_bucket + + chunks = storage_mod.list_audio_chunks('uid', 'conv') + + assert len(chunks) == 1 + assert chunks[0]['is_batch'] is False + + +class TestDeleteAudioChunksBatch: + """Tests for delete_audio_chunks with batch blobs.""" + + def 
+    def test_deletes_single_timestamp_batch(self):
+        """Finds and deletes batch blob with single timestamp."""
+        mock_bucket = MagicMock()
+        blob_map = {}
+
+        def blob_factory(path):
+            if path not in blob_map:
+                b = MagicMock()
+                b.name = path
+                b.exists.return_value = path.endswith('.batch.bin')
+                blob_map[path] = b
+            return blob_map[path]
+
+        mock_bucket.blob.side_effect = blob_factory
+        mock_bucket.list_blobs.return_value = []
+        storage_mod.storage_client.bucket.return_value = mock_bucket
+
+        storage_mod.delete_audio_chunks('uid', 'conv', [1000.0])
+
+        batch_path = 'chunks/uid/conv/1000.000.batch.bin'
+        assert batch_path in blob_map
+        blob_map[batch_path].delete.assert_called_once()
+
+    def test_deletes_range_named_batch_via_scan(self):
+        """Finds and deletes range-named batch blob by scanning."""
+        mock_bucket = MagicMock()
+
+        single_blob = MagicMock()
+        single_blob.exists.return_value = False
+        mock_bucket.blob.return_value = single_blob
+
+        batch_blob = MagicMock()
+        batch_blob.name = 'chunks/uid/conv/1000.000-1010.000.batch.bin'
+        mock_bucket.list_blobs.return_value = [batch_blob]
+        storage_mod.storage_client.bucket.return_value = mock_bucket
+
+        storage_mod.delete_audio_chunks('uid', 'conv', [1000.0, 1005.0, 1010.0])
+
+        batch_blob.delete.assert_called_once()
+
+
+class TestDownloadBatchBlobs:
+    """Tests for download_audio_chunks_and_merge with batch blobs."""
+
+    def _make_mock_blob(self, name, size=1000):
+        blob = MagicMock()
+        blob.name = name
+        blob.size = size
+        return blob
+
+    @patch.object(storage_mod, 'NotFound', type('FakeNotFound', (Exception,), {}))
+    def test_downloads_batch_blob_once(self):
+        """Batch blob covering multiple timestamps is downloaded once."""
+        mock_bucket = MagicMock()
+        pcm_data = b'\x00' * 480000
+
+        batch_blob_listed = self._make_mock_blob('chunks/uid/conv/1000.000-1010.000.batch.bin', 480000)
+        mock_bucket.list_blobs.return_value = [batch_blob_listed]
+
+        download_calls = []
+
+        def blob_factory(path):
+            b = MagicMock()
+            download_calls.append(path)
+            if path == 'chunks/uid/conv/1000.000-1010.000.batch.bin':
+                b.download_as_bytes.return_value = pcm_data
+            else:
+                b.download_as_bytes.side_effect = storage_mod.NotFound('not found')
+            return b
+
+        mock_bucket.blob.side_effect = blob_factory
+        storage_mod.storage_client.bucket.return_value = mock_bucket
+
+        result = storage_mod.download_audio_chunks_and_merge('uid', 'conv', [1000.0, 1005.0, 1010.0], fill_gaps=False)
+
+        assert result == pcm_data
+        batch_downloads = [p for p in download_calls if 'batch' in p]
+        assert len(batch_downloads) == 1
+
+    @patch.object(storage_mod, 'NotFound', type('FakeNotFound', (Exception,), {}))
+    def test_mixed_single_and_batch_download(self):
+        """Mix of single-chunk and batch blobs downloads correctly."""
+        mock_bucket = MagicMock()
+        single_pcm = b'\x01' * 160000
+        batch_pcm = b'\x02' * 320000
+
+        mock_bucket.list_blobs.return_value = [
+            self._make_mock_blob('chunks/uid/conv/1000.000.opus', 8000),
+            self._make_mock_blob('chunks/uid/conv/1005.000-1015.000.batch.bin', 320000),
+        ]
+
+        opus_encoded_single = storage_mod.encode_pcm_to_opus(single_pcm)
+
+        def blob_factory(path):
+            b = MagicMock()
+            if path == 'chunks/uid/conv/1005.000-1015.000.batch.bin':
+                b.download_as_bytes.return_value = batch_pcm
+            elif path.endswith('.opus') and '1000.000' in path:
+                b.download_as_bytes.return_value = opus_encoded_single
+            else:
+                b.download_as_bytes.side_effect = storage_mod.NotFound('not found')
+            return b
+
+        mock_bucket.blob.side_effect = blob_factory
+        storage_mod.storage_client.bucket.return_value = mock_bucket
+
+        result = storage_mod.download_audio_chunks_and_merge(
+            'uid', 'conv', [1000.0, 1005.0, 1010.0, 1015.0], fill_gaps=False
+        )
+
+        assert len(result) == len(single_pcm) + len(batch_pcm)
diff --git a/backend/utils/conversations/merge_conversations.py b/backend/utils/conversations/merge_conversations.py
index db0b9494c4e..a64a4cc0cc9 100644
--- a/backend/utils/conversations/merge_conversations.py
+++ b/backend/utils/conversations/merge_conversations.py
@@ -26,6 +26,7 @@
     list_audio_chunks,
     storage_client,
     private_cloud_sync_bucket,
+    _get_extension_for_path,
 )
 
 import logging
@@ -332,13 +333,14 @@ def _copy_audio_chunks_for_merge(
 
     Audio chunks are stored in GCS at:
         chunks/{uid}/{conversation_id}/{timestamp}.bin (or .enc for encrypted)
+        chunks/{uid}/{conversation_id}/{first_ts}-{last_ts}.batch.bin (batch blobs)
 
-    The timestamps in chunk filenames are absolute Unix timestamps (when chunk was recorded).
-    We keep the original timestamps since they represent the actual recording time.
+    The filenames contain absolute Unix timestamps (when chunk was recorded).
+    We preserve original filenames to maintain both single-chunk and batch blob naming.
 
     Strategy:
     - Copy all chunks from all conversations to new conversation path
-    - Keep original timestamps (they're absolute, not relative)
+    - Preserve original filenames (handles single and batch blobs)
    - Create AudioFile records from the copied chunks
 
     Args:
@@ -360,13 +362,10 @@
         chunks = list_audio_chunks(uid, conv_id)
         for chunk in chunks:
             has_chunks = True
-            original_ts = chunk['timestamp']
-
-            # Determine extension from original path
-            ext = 'enc' if chunk['path'].endswith('.enc') else 'bin'
-
-            # Copy to new path with same timestamp (it's absolute Unix time)
-            new_path = f'chunks/{uid}/{new_conversation_id}/{original_ts:.3f}.{ext}'
+            # Preserve original filename (handles both single and batch blob naming)
+            original_filename = chunk['path'].split('/')[-1]
+            new_path = f'chunks/{uid}/{new_conversation_id}/{original_filename}'
 
             source_blob = bucket.blob(chunk['path'])
             bucket.copy_blob(source_blob, bucket, new_path)
diff --git a/backend/utils/other/storage.py b/backend/utils/other/storage.py
index 1b74ff63a89..ca0aa673c87 100644
--- a/backend/utils/other/storage.py
+++ b/backend/utils/other/storage.py
@@ -2,11 +2,13 @@
 import io
 import json
 import os
+import struct
 import wave
 from typing import List
 from concurrent.futures import ThreadPoolExecutor, as_completed
 import threading
 
+import opuslib
 from google.cloud import storage
 from google.oauth2 import service_account
 from google.cloud.storage import transfer_manager
@@ -20,6 +22,15 @@
 
 logger = logging.getLogger(__name__)
 
+# Opus encoding constants
+OPUS_SAMPLE_RATE = 16000
+OPUS_CHANNELS = 1
+OPUS_FRAME_DURATION_MS = 20  # 20ms frames (standard for voice)
+OPUS_FRAME_SIZE = OPUS_SAMPLE_RATE * OPUS_FRAME_DURATION_MS // 1000  # 320 samples per frame
+
+# Valid private cloud sync extensions (longest first for correct matching)
+PRIVATE_CLOUD_EXTENSIONS = ['.batch.enc', '.batch.bin', '.opus.enc', '.opus', '.enc', '.bin']
+
 if os.environ.get('SERVICE_ACCOUNT_JSON'):
     service_account_info = json.loads(os.environ["SERVICE_ACCOUNT_JSON"])
     credentials = service_account.Credentials.from_service_account_info(service_account_info)
@@ -309,6 +320,128 @@ def delete_syncing_temporal_file(file_path: str):
 
 # ************************************************
 
 
+def encode_pcm_to_opus(pcm_data: bytes, sample_rate: int = OPUS_SAMPLE_RATE, channels: int = OPUS_CHANNELS) -> bytes:
+    """
+    Encode PCM16 audio to Opus.
+
+    Format: 4-byte little-endian packet count, 4-byte little-endian original
+    PCM length, then for each packet a 2-byte little-endian length prefix
+    followed by the Opus packet bytes. This allows exact packet-boundary
+    reconstruction (and padding trim) on decode.
+
+    Args:
+        pcm_data: Raw PCM16 audio bytes
+        sample_rate: Sample rate in Hz (default 16000)
+        channels: Number of audio channels (default 1)
+
+    Returns:
+        Length-prefixed Opus packets as bytes
+    """
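+    # Illustrative layout for a hypothetical one-packet blob (the 42-byte packet
+    # size is an assumed example; compare the reads in decode_opus_to_pcm below):
+    #   \x01\x00\x00\x00  packet_count = 1        (uint32 LE)
+    #   \x80\x02\x00\x00  original_pcm_len = 640  (uint32 LE)
+    #   \x2a\x00          pkt_len = 42            (uint16 LE)
+    #   ...followed by the 42 Opus packet bytes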
+    encoder = opuslib.Encoder(sample_rate, channels, opuslib.APPLICATION_VOIP)
+    frame_size = sample_rate * OPUS_FRAME_DURATION_MS // 1000
+    bytes_per_frame = frame_size * channels * 2  # 16-bit = 2 bytes per sample
+
+    packets = []
+    offset = 0
+    while offset + bytes_per_frame <= len(pcm_data):
+        frame = pcm_data[offset : offset + bytes_per_frame]
+        encoded = encoder.encode(frame, frame_size)
+        packets.append(encoded)
+        offset += bytes_per_frame
+
+    # Encode remaining samples (pad with silence)
+    if offset < len(pcm_data):
+        remaining = pcm_data[offset:]
+        padded = remaining + b'\x00' * (bytes_per_frame - len(remaining))
+        encoded = encoder.encode(padded, frame_size)
+        packets.append(encoded)
+
+    # Pack: [packet_count (4 bytes)] + [original_pcm_len (4 bytes)] + [len (2 bytes) + data] per packet
+    output = struct.pack('<II', len(packets), len(pcm_data))
+    for pkt in packets:
+        output += struct.pack('<H', len(pkt)) + pkt
+    return output
+
+
+def decode_opus_to_pcm(opus_data: bytes, sample_rate: int = OPUS_SAMPLE_RATE, channels: int = OPUS_CHANNELS) -> bytes:
+    """
+    Decode length-prefixed Opus packets back to PCM16.
+
+    Args:
+        opus_data: Length-prefixed Opus packets (from encode_pcm_to_opus)
+        sample_rate: Sample rate in Hz (default 16000)
+        channels: Number of audio channels (default 1)
+
+    Returns:
+        Raw PCM16 audio bytes
+
+    Raises:
+        ValueError: If opus_data is too short or has invalid header/packet structure
+    """
+    if len(opus_data) < 8:
+        raise ValueError(f"Opus data too short: {len(opus_data)} bytes (need at least 8 for header)")
+
+    decoder = opuslib.Decoder(sample_rate, channels)
+    frame_size = sample_rate * OPUS_FRAME_DURATION_MS // 1000
+
+    offset = 0
+    packet_count = struct.unpack_from('<I', opus_data, offset)[0]
+    offset += 4
+    original_pcm_len = struct.unpack_from('<I', opus_data, offset)[0]
+    offset += 4
+
+    pcm_parts = []
+    for i in range(packet_count):
+        if offset + 2 > len(opus_data):
+            raise ValueError(f"Truncated Opus data: expected packet {i}/{packet_count} length at offset {offset}")
+        pkt_len = struct.unpack_from('<H', opus_data, offset)[0]
+        offset += 2
+        if offset + pkt_len > len(opus_data):
+            raise ValueError(
+                f"Truncated Opus data: packet {i} needs {pkt_len} bytes at offset {offset}, only {len(opus_data) - offset} available"
+            )
+        pkt_data = opus_data[offset : offset + pkt_len]
+        offset += pkt_len
+        decoded = decoder.decode(pkt_data, frame_size)
+        pcm_parts.append(decoded)
+
+    result = b''.join(pcm_parts)
+    # Trim to original PCM length to remove padding from partial final frame
+    if original_pcm_len > 0 and original_pcm_len < len(result):
+        result = result[:original_pcm_len]
+    return result
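+
+
+# Round-trip note (illustrative): Opus is lossy, so a round trip through
+# encode_pcm_to_opus/decode_opus_to_pcm preserves the PCM byte length but not
+# the exact samples, e.g.:
+#   pcm = b'\x00\x00' * 320  # one 20ms frame at 16kHz mono
+#   assert len(decode_opus_to_pcm(encode_pcm_to_opus(pcm))) == len(pcm)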
+ """ + for ext in ('.batch.enc', '.batch.bin', '.opus.enc', '.opus', '.enc', '.bin'): + if filename.endswith(ext): + return filename[: -len(ext)] + return filename.rsplit('.', 1)[0] + + def upload_audio_chunk( chunk_data: bytes, uid: str, conversation_id: str, timestamp: float, data_protection_level: str = None ) -> str: @@ -334,34 +467,128 @@ def upload_audio_chunk( # Format timestamp to 3 decimal places for cleaner filenames formatted_timestamp = f'{timestamp:.3f}' + upload_data = encode_pcm_to_opus(chunk_data) + if protection_level == 'enhanced': - # Encrypt as length-prefixed binary - encrypted_chunk = encryption.encrypt_audio_chunk(chunk_data, uid) - path = f'chunks/{uid}/{conversation_id}/{formatted_timestamp}.enc' + encrypted_chunk = encryption.encrypt_audio_chunk(upload_data, uid) + path = f'chunks/{uid}/{conversation_id}/{formatted_timestamp}.opus.enc' blob = bucket.blob(path) blob.upload_from_string(encrypted_chunk, content_type='application/octet-stream') else: - # Standard - no encryption - path = f'chunks/{uid}/{conversation_id}/{formatted_timestamp}.bin' + path = f'chunks/{uid}/{conversation_id}/{formatted_timestamp}.opus' blob = bucket.blob(path) - blob.upload_from_string(chunk_data, content_type='application/octet-stream') + blob.upload_from_string(upload_data, content_type='application/octet-stream') + del upload_data return path +def upload_audio_chunks_batch( + chunks: List[dict], + uid: str, + conversation_id: str, + data_protection_level: str = None, +) -> List[str]: + """ + Upload multiple audio chunks to GCS in a single streaming write. + + Concatenates all chunk data into one GCS object (1 write op instead of N). + + Args: + chunks: List of dicts with 'data' (bytes) and 'timestamp' (float). + uid: User ID. + conversation_id: Conversation ID. + data_protection_level: Optional cached protection level. When provided, + skips the Firestore read. Falls back to DB read when None. + + Returns: + List of GCS paths for the uploaded batch. + """ + if not chunks: + return [] + + # Sort by timestamp for consistent ordering + sorted_chunks = sorted(chunks, key=lambda c: c['timestamp']) + + # Resolve protection level once for the entire batch + protection_level = ( + data_protection_level if data_protection_level is not None else users_db.get_data_protection_level(uid) + ) + + bucket = storage_client.bucket(private_cloud_sync_bucket) + + # Build batch filename from first and last timestamps + first_ts = f'{sorted_chunks[0]["timestamp"]:.3f}' + last_ts = f'{sorted_chunks[-1]["timestamp"]:.3f}' + batch_name = f'{first_ts}-{last_ts}' if len(sorted_chunks) > 1 else first_ts + + if protection_level == 'enhanced': + # Encrypt each chunk individually (length-prefixed), stream to GCS + path = f'chunks/{uid}/{conversation_id}/{batch_name}.batch.enc' + blob = bucket.blob(path) + with blob.open('wb', content_type='application/octet-stream') as f: + for chunk in sorted_chunks: + encrypted_chunk = encryption.encrypt_audio_chunk(chunk['data'], uid) + f.write(encrypted_chunk) + del encrypted_chunk + else: + # Standard — stream raw PCM data to GCS + path = f'chunks/{uid}/{conversation_id}/{batch_name}.batch.bin' + blob = bucket.blob(path) + with blob.open('wb', content_type='application/octet-stream') as f: + for chunk in sorted_chunks: + f.write(chunk['data']) + + return [path] + + def delete_audio_chunks(uid: str, conversation_id: str, timestamps: List[float]) -> None: - """Delete audio chunks after they've been merged.""" + """Delete audio chunks after they've been merged. 
+
+    if protection_level == 'enhanced':
+        # Encrypt each chunk individually (length-prefixed), stream to GCS
+        path = f'chunks/{uid}/{conversation_id}/{batch_name}.batch.enc'
+        blob = bucket.blob(path)
+        with blob.open('wb', content_type='application/octet-stream') as f:
+            for chunk in sorted_chunks:
+                encrypted_chunk = encryption.encrypt_audio_chunk(chunk['data'], uid)
+                f.write(encrypted_chunk)
+                del encrypted_chunk
+    else:
+        # Standard — stream raw PCM data to GCS
+        path = f'chunks/{uid}/{conversation_id}/{batch_name}.batch.bin'
+        blob = bucket.blob(path)
+        with blob.open('wb', content_type='application/octet-stream') as f:
+            for chunk in sorted_chunks:
+                f.write(chunk['data'])
+
+    return [path]
+
+
 def delete_audio_chunks(uid: str, conversation_id: str, timestamps: List[float]) -> None:
-    """Delete audio chunks after they've been merged."""
+    """Delete audio chunks after they've been merged.
+
+    Handles both single-chunk blobs (per-timestamp lookup) and batch blobs
+    (listed and matched by start timestamp).
+    """
     bucket = storage_client.bucket(private_cloud_sync_bucket)
 
+    deleted_batch_paths = set()
+
     for timestamp in timestamps:
         # Format timestamp to match upload format (3 decimal places)
         formatted_timestamp = f'{timestamp:.3f}'
-        # Try both encrypted and unencrypted paths
-        for extension in ['.enc', '.bin']:
+
+        # Try single-chunk extensions first
+        for extension in PRIVATE_CLOUD_EXTENSIONS:
+            if extension in ('.batch.enc', '.batch.bin'):
+                continue  # batch blobs handled separately below
             chunk_path = f'chunks/{uid}/{conversation_id}/{formatted_timestamp}{extension}'
             blob = bucket.blob(chunk_path)
             if blob.exists():
                 blob.delete()
 
+        # Try batch blobs: exact single-timestamp batch (e.g. "1000.000.batch.bin")
+        for batch_ext in ('.batch.enc', '.batch.bin'):
+            batch_path = f'chunks/{uid}/{conversation_id}/{formatted_timestamp}{batch_ext}'
+            if batch_path not in deleted_batch_paths:
+                blob = bucket.blob(batch_path)
+                if blob.exists():
+                    blob.delete()
+                    deleted_batch_paths.add(batch_path)
+
+    # Scan for range-named batch blobs whose start timestamp matches any requested timestamp
+    ts_set = {f'{ts:.3f}' for ts in timestamps}
+    prefix = f'chunks/{uid}/{conversation_id}/'
+    for blob in bucket.list_blobs(prefix=prefix):
+        if blob.name in deleted_batch_paths:
+            continue
+        filename = blob.name.split('/')[-1]
+        if '.batch.' not in filename:
+            continue
+        timestamp_str = _strip_extension(filename)
+        if '-' in timestamp_str:
+            start_ts = timestamp_str.split('-', 1)[0]
+            if start_ts in ts_set:
+                blob.delete()
+                deleted_batch_paths.add(blob.name)
+
 
 def list_audio_chunks(uid: str, conversation_id: str) -> List[dict]:
     """
@@ -376,18 +603,29 @@ def list_audio_chunks(uid: str, conversation_id: str) -> List[dict]:
 
     chunks = []
     for blob in blobs:
-        # Extract timestamp from filename (e.g., '1234567890.123.bin' or '1234567890.123.enc')
+        # Extract timestamp from filename
+        # Supports single-chunk: '1234567890.123.opus', '1234567890.123.opus.enc', etc.
+        # Supports batch: '1234567890.123-1234567900.123.batch.bin', '1234567890.123.batch.enc'
         filename = blob.name.split('/')[-1]
-        if filename.endswith('.bin') or filename.endswith('.enc'):
+        has_valid_ext = any(filename.endswith(ext) for ext in PRIVATE_CLOUD_EXTENSIONS)
+        if has_valid_ext:
             try:
-                # Remove extension (.bin or .enc)
-                timestamp_str = filename.rsplit('.', 1)[0]
-                timestamp = float(timestamp_str)
+                timestamp_str = _strip_extension(filename)
+                is_batch = '.batch.' in filename
+
+                if is_batch and '-' in timestamp_str:
+                    # Batch blob with timestamp range: "first_ts-last_ts"
+                    first_ts_str, last_ts_str = timestamp_str.split('-', 1)
+                    timestamp = float(first_ts_str)
+                else:
+                    timestamp = float(timestamp_str)
+
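+                # e.g. '1000.000-1010.000.batch.bin' -> timestamp=1000.0 with
+                # is_batch=True, while '1000.000.opus' -> timestamp=1000.0, is_batch=False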
                 chunks.append(
                     {
                         'timestamp': timestamp,
                         'path': blob.name,
                         'size': blob.size,
+                        'is_batch': is_batch,
                     }
                 )
             except ValueError:
@@ -422,6 +660,7 @@
     Download and merge audio chunks on-demand, handling mixed encryption states.
 
     Downloads chunks in parallel. Normalizes all chunks to unencrypted PCM format
    for consistent merging.
+    Supports both single-chunk blobs and batch blobs (from upload_audio_chunks_batch).
 
     Args:
         uid: User ID
@@ -437,47 +676,138 @@
 
     bucket = storage_client.bucket(private_cloud_sync_bucket)
 
+    # Resolve actual GCS paths — needed to find batch blobs whose filenames
+    # contain timestamp ranges instead of single timestamps
+    actual_chunks = list_audio_chunks(uid, conversation_id)
+    ts_set = {round(ts, 3) for ts in timestamps}
+
+    # Build batch blob map: for batch blobs, track which timestamps they cover
+    batch_paths = {}  # path -> chunk_info (deduplicate downloads)
+    ts_to_batch_path = {}  # timestamp -> batch_path (for timestamps inside batch range)
+    single_chunk_timestamps = []  # timestamps that have individual blobs
+
+    for chunk in actual_chunks:
+        if chunk.get('is_batch'):
+            path = chunk['path']
+            batch_paths[path] = chunk
+
+            # Parse batch range to determine covered timestamps
+            filename = path.split('/')[-1]
+            ts_str = _strip_extension(filename)
+            if '-' in ts_str:
+                start_str, end_str = ts_str.split('-', 1)
+                batch_start = float(start_str)
+                batch_end = float(end_str)
+            else:
+                batch_start = batch_end = float(ts_str)
+
+            # Map requested timestamps that fall within this batch's range
+            for ts in timestamps:
+                if batch_start <= round(ts, 3) <= batch_end:
+                    ts_to_batch_path[round(ts, 3)] = path
+        elif round(chunk['timestamp'], 3) in ts_set:
+            single_chunk_timestamps.append(chunk['timestamp'])
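+    # e.g. requesting [1000.0, 1005.0, 1010.0] when the bucket holds one blob
+    # '1000.000-1010.000.batch.bin': all three timestamps map to that path, so
+    # the blob is fetched exactly once below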
+
+    def _download_and_decode_blob(path: str) -> bytes | None:
+        """Download a blob and decode/decrypt based on extension."""
+        ext = _get_extension_for_path(path)
+        encrypted = ext in ('opus.enc', 'enc', 'batch.enc')
+        is_opus = ext in ('opus.enc', 'opus')
+
+        try:
+            chunk_data = bucket.blob(path).download_as_bytes()
+        except NotFound:
+            return None
+
+        try:
+            if encrypted:
+                raw_data = encryption.decrypt_audio_file(chunk_data, uid)
+            else:
+                raw_data = chunk_data
+
+            if is_opus:
+                pcm_data = decode_opus_to_pcm(raw_data, sample_rate=sample_rate)
+                del raw_data
+            else:
+                pcm_data = raw_data
+
+            return pcm_data
+        except Exception as e:
+            logger.warning(f"Failed to decode/decrypt {path}: {e}")
+            return None
+
     def download_single_chunk(timestamp: float) -> tuple[float, bytes | None]:
-        """Download a single chunk and return (timestamp, pcm_data)."""
+        """Download a single-chunk blob by trying extensions in priority order."""
         formatted_timestamp = f'{timestamp:.3f}'
-        chunk_path_enc = f'chunks/{uid}/{conversation_id}/{formatted_timestamp}.enc'
-        chunk_path_bin = f'chunks/{uid}/{conversation_id}/{formatted_timestamp}.bin'
 
-        chunk_data = None
-        is_encrypted = False
+        extensions_to_try = [
+            ('opus.enc', True, True),  # (ext, encrypted, opus)
+            ('enc', True, False),
+            ('opus', False, True),
+            ('bin', False, False),
+        ]
 
-        # Try encrypted first, then unencrypted
-        try:
-            chunk_data = bucket.blob(chunk_path_enc).download_as_bytes()
-            is_encrypted = True
-        except NotFound:
+        for ext, encrypted, opus in extensions_to_try:
+            chunk_path = f'chunks/{uid}/{conversation_id}/{formatted_timestamp}.{ext}'
             try:
-                chunk_data = bucket.blob(chunk_path_bin).download_as_bytes()
-                is_encrypted = False
+                chunk_data = bucket.blob(chunk_path).download_as_bytes()
             except NotFound:
-                logger.warning(f"Warning: Chunk not found for timestamp {formatted_timestamp}")
-                return (timestamp, None)
+                continue
+
+            try:
+                if encrypted:
+                    raw_data = encryption.decrypt_audio_file(chunk_data, uid)
+                else:
+                    raw_data = chunk_data
 
-        # Normalize to PCM (decrypt if needed
-        if is_encrypted:
-            pcm_data = encryption.decrypt_audio_file(chunk_data, uid)
-        else:
-            pcm_data = chunk_data
+                if opus:
+                    pcm_data = decode_opus_to_pcm(raw_data, sample_rate=sample_rate)
+                    del raw_data
+                else:
+                    pcm_data = raw_data
+
+                return (timestamp, pcm_data)
+            except Exception as e:
+                logger.warning(
+                    f"Failed to decode/decrypt {ext} chunk at {formatted_timestamp}: {e}, trying next format"
+                )
+                continue
 
-        return (timestamp, pcm_data)
+        logger.warning(f"Warning: Chunk not found for timestamp {formatted_timestamp}")
+        return (timestamp, None)
 
-    # Download chunks in parallel
+    # Download all data in parallel
     chunk_results = {}
-    max_workers = min(10, len(timestamps))
+
+    # Determine which timestamps need individual downloads vs batch downloads
+    individual_timestamps = [ts for ts in timestamps if round(ts, 3) not in ts_to_batch_path]
+    unique_batch_paths = set(ts_to_batch_path.values())
+
+    max_workers = min(10, len(individual_timestamps) + len(unique_batch_paths))
+    if max_workers == 0:
+        max_workers = 1
 
     with ThreadPoolExecutor(max_workers=max_workers) as executor:
-        future_to_timestamp = {executor.submit(download_single_chunk, ts): ts for ts in timestamps}
+        # Submit individual chunk downloads
+        individual_futures = {executor.submit(download_single_chunk, ts): ts for ts in individual_timestamps}
+
+        # Submit batch blob downloads (once per unique path)
+        batch_futures = {executor.submit(_download_and_decode_blob, path): path for path in unique_batch_paths}
 
-        for future in as_completed(future_to_timestamp):
+        # Collect individual results
+        for future in as_completed(individual_futures):
             timestamp, pcm_data = future.result()
             if pcm_data is not None:
                 chunk_results[timestamp] = pcm_data
 
+        # Collect batch results — assign full batch data at the batch's start timestamp
+        for future in as_completed(batch_futures):
+            path = batch_futures[future]
+            pcm_data = future.result()
+            if pcm_data is not None:
+                batch_info = batch_paths[path]
+                chunk_results[batch_info['timestamp']] = pcm_data
+
     # Merge chunks
     merged_data = bytearray()