Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
72 changes: 49 additions & 23 deletions backend/routers/pusher.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,9 @@
import asyncio
import json
import time
from collections import deque
from datetime import datetime, timezone
from typing import List
from typing import List, Set

from fastapi import APIRouter
from fastapi.websockets import WebSocketDisconnect, WebSocket
Expand All @@ -27,7 +28,6 @@
get_audio_bytes_webhook_seconds,
)
from utils.other.storage import upload_audio_chunk
from utils.other.task import safe_create_task
from utils.speaker_identification import extract_speaker_samples

router = APIRouter()
Expand Down Expand Up @@ -130,22 +130,38 @@ async def _websocket_util_trigger(
has_audio_apps_enabled = is_audio_bytes_app_enabled(uid)
private_cloud_sync_enabled = users_db.get_user_private_cloud_sync_enabled(uid)

# Queue for pending speaker sample extraction requests
speaker_sample_queue: List[dict] = []
# Track background tasks to cancel on cleanup (prevents memory leaks from fire-and-forget tasks)
bg_tasks: Set[asyncio.Task] = set()

# Queue for pending private cloud sync chunks
private_cloud_queue: List[dict] = []
def spawn(coro) -> asyncio.Task:
    """Create a tracked background task that will be cancelled on cleanup.

    The task is registered in the enclosing scope's ``bg_tasks`` set so the
    websocket teardown path can cancel anything still running (prevents the
    fire-and-forget leaks this PR is fixing). The done-callback removes the
    task from the set either way, so the set only ever holds live tasks.
    """
    task = asyncio.create_task(coro)
    bg_tasks.add(task)

    def on_done(t):
        # Always drop the reference first so the set cannot grow unbounded.
        bg_tasks.discard(t)
        if t.cancelled():
            # Cancelled during cleanup — expected, nothing to report.
            return
        # Retrieve the exception so it is not logged as "never retrieved".
        exc = t.exception()
        if exc:
            # NOTE(review): uid is passed as a second positional print arg —
            # matches the logging convention used elsewhere in this module.
            print(f"Unhandled exception in background task: {exc}", uid)

    task.add_done_callback(on_done)
    return task

# Queue for pending transcript events (batched for realtime integrations + webhooks)
transcript_queue: List[dict] = []
# Bounded queues — prevent unbounded memory growth during backpressure
speaker_sample_queue: deque = deque(maxlen=SPEAKER_SAMPLE_QUEUE_WARN_SIZE)
transcript_queue: deque = deque(maxlen=TRANSCRIPT_QUEUE_WARN_SIZE)
audio_bytes_queue: deque = deque(maxlen=AUDIO_BYTES_QUEUE_WARN_SIZE)

# Queue for pending audio bytes triggers (batched for app integrations + webhooks)
audio_bytes_queue: List[dict] = []
# private_cloud_queue stays unbounded — it carries irreplaceable user audio.
# Silent drops (via deque maxlen) would cause permanent data loss.
private_cloud_queue: List[dict] = []
audio_bytes_event = asyncio.Event() # Signals when items are added for instant wake

async def process_private_cloud_queue():
"""Background task that processes private cloud sync uploads with retry logic."""
nonlocal websocket_active, private_cloud_queue
nonlocal websocket_active

while websocket_active or len(private_cloud_queue) > 0:
await asyncio.sleep(PRIVATE_CLOUD_SYNC_PROCESS_INTERVAL)
Expand All @@ -155,7 +171,7 @@ async def process_private_cloud_queue():

# Process all pending chunks
chunks_to_process = private_cloud_queue.copy()
private_cloud_queue = []
private_cloud_queue.clear()

successful_conversation_ids = set() # Track conversations with successful uploads

Expand Down Expand Up @@ -197,7 +213,7 @@ async def process_private_cloud_queue():

async def process_speaker_sample_queue():
"""Background task that processes speaker sample extraction requests."""
nonlocal websocket_active, speaker_sample_queue
nonlocal websocket_active

while websocket_active or len(speaker_sample_queue) > 0:
await asyncio.sleep(SPEAKER_SAMPLE_PROCESS_INTERVAL)
Expand All @@ -211,14 +227,15 @@ async def process_speaker_sample_queue():
ready_requests = []
pending_requests = []

for request in speaker_sample_queue:
for request in list(speaker_sample_queue):
if current_time - request['queued_at'] >= SPEAKER_SAMPLE_MIN_AGE:
ready_requests.append(request)
else:
pending_requests.append(request)

# Keep pending requests in queue
speaker_sample_queue = pending_requests
# Keep pending requests in queue (rebuild deque with pending only)
speaker_sample_queue.clear()
speaker_sample_queue.extend(pending_requests)

# Process ready requests (fire and forget)
for request in ready_requests:
Expand All @@ -239,7 +256,7 @@ async def process_speaker_sample_queue():

async def process_transcript_queue():
"""Batched consumer for transcript events (realtime integrations + webhooks)."""
nonlocal websocket_active, transcript_queue
nonlocal websocket_active

while websocket_active or len(transcript_queue) > 0:
await asyncio.sleep(TRANSCRIPT_QUEUE_FLUSH_INTERVAL)
Expand All @@ -248,8 +265,8 @@ async def process_transcript_queue():
continue

# Process batch
batch = transcript_queue.copy()
transcript_queue = []
batch = list(transcript_queue)
transcript_queue.clear()

for item in batch:
segments = item['segments']
Expand All @@ -262,7 +279,7 @@ async def process_transcript_queue():

async def process_audio_bytes_queue():
"""Event-driven consumer for audio bytes triggers (app integrations + webhooks)."""
nonlocal websocket_active, audio_bytes_queue
nonlocal websocket_active

while websocket_active or len(audio_bytes_queue) > 0:
# Wait for signal or check periodically for shutdown
Expand All @@ -277,8 +294,8 @@ async def process_audio_bytes_queue():
continue

# Process all queued items
batch = audio_bytes_queue.copy()
audio_bytes_queue = []
batch = list(audio_bytes_queue)
audio_bytes_queue.clear()

for item in batch:
try:
Expand Down Expand Up @@ -335,7 +352,7 @@ async def receive_tasks():
language = res.get('language', 'en')
if conversation_id:
print(f"Pusher received process_conversation request: {conversation_id}", uid)
safe_create_task(_process_conversation_task(uid, conversation_id, language, websocket))
spawn(_process_conversation_task(uid, conversation_id, language, websocket))
continue

# Speaker sample extraction request - queue for background processing
Expand Down Expand Up @@ -459,6 +476,15 @@ async def receive_tasks():
print(f"Error during WebSocket operation: {e}")
finally:
websocket_active = False

# Cancel all tracked background tasks to prevent memory leaks
tasks_to_cancel = list(bg_tasks)
for task in tasks_to_cancel:
task.cancel()
if tasks_to_cancel:
await asyncio.gather(*tasks_to_cancel, return_exceptions=True)
bg_tasks.clear()

if websocket.client_state == WebSocketState.CONNECTED:
try:
await websocket.close(code=websocket_close_code)
Expand Down
15 changes: 15 additions & 0 deletions backend/testing/chaos-oom/Dockerfile.harness
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
# Chaos-OOM harness image: serves harness_main.py, which wraps a selectable
# pusher module with memory-introspection endpoints.
FROM python:3.11-slim

WORKDIR /app

# Only the web-stack runtime deps are installed; the pusher module's project
# imports are satisfied by the stubs in mock_deps/ (see harness_main.py).
RUN pip install --no-cache-dir fastapi uvicorn[standard] websockets starlette

COPY mock_deps/ /app/mock_deps/
COPY harness_main.py /app/
# PUSHER_MODULE env var selects which pusher.py to use (set at docker run time)
COPY pusher_vuln.py /app/
COPY pusher_fixed.py /app/

EXPOSE 8080

# Quiet log level keeps harness output from drowning the chaos-test metrics.
CMD ["uvicorn", "harness_main:app", "--host", "0.0.0.0", "--port", "8080", "--log-level", "warning"]
129 changes: 129 additions & 0 deletions backend/testing/chaos-oom/harness_main.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,129 @@
"""
Chaos engineering harness — FastAPI app wrapping pusher.py with memory introspection.

Usage:
PUSHER_MODULE=pusher_vuln uvicorn harness_main:app --host 0.0.0.0 --port 8080
PUSHER_MODULE=pusher_fixed uvicorn harness_main:app --host 0.0.0.0 --port 8080
"""

import asyncio
import importlib
import os
import sys
import tracemalloc

# Add mock_deps to Python path so pusher.py's imports resolve to our mocks
sys.path.insert(0, os.path.join(os.path.dirname(__file__), 'mock_deps'))

# Start tracemalloc for memory attribution
tracemalloc.start(10)

# Improvement #8: Monkeypatch asyncio.to_thread to track thread pool backlog
_orig_to_thread = asyncio.to_thread

to_thread_metrics = {
'submitted': 0,
'completed': 0,
'in_flight': 0,
'max_in_flight': 0,
}

# Limit thread pool to make backlog obvious
_max_workers = int(os.environ.get('TO_THREAD_WORKERS', '2'))
from concurrent.futures import ThreadPoolExecutor

_executor = ThreadPoolExecutor(max_workers=_max_workers)
asyncio.get_event_loop_policy() # ensure loop policy exists


async def tracked_to_thread(func, /, *args, **kwargs):
    """Drop-in replacement for ``asyncio.to_thread`` that records backlog metrics.

    Counts submissions/completions and tracks the in-flight high-water mark in
    the module-level ``to_thread_metrics`` dict, while delegating the actual
    call to the bounded ``_executor`` pool so backlog becomes observable.
    """
    metrics = to_thread_metrics
    metrics['submitted'] += 1
    metrics['in_flight'] += 1
    # High-water mark: only ever moves up.
    metrics['max_in_flight'] = max(metrics['max_in_flight'], metrics['in_flight'])
    try:
        running_loop = asyncio.get_running_loop()
        return await running_loop.run_in_executor(
            _executor, lambda: func(*args, **kwargs)
        )
    finally:
        # Runs whether the call returned or raised, so in_flight stays accurate.
        metrics['in_flight'] -= 1
        metrics['completed'] += 1


asyncio.to_thread = tracked_to_thread

from fastapi import FastAPI

# Import the pusher module specified by environment variable
pusher_module_name = os.environ.get('PUSHER_MODULE', 'pusher_vuln')
pusher = importlib.import_module(pusher_module_name)

app = FastAPI()
app.include_router(pusher.router)

# Create temp dirs the original main.py creates
for path in ['_temp', '_samples', '_segments', '_speech_profiles']:
os.makedirs(path, exist_ok=True)


@app.get('/health')
def health_check():
    """Liveness probe; also reports which pusher variant this harness loaded."""
    return {"status": "healthy", "module": pusher_module_name}


@app.get('/debug/memory')
async def debug_memory():
    """Return memory usage and top allocators for leak attribution.

    Response keys are stable (consumed by the chaos-test driver):
    ``rss_mb`` (peak RSS from the OS), ``traced_current_mb``/``traced_peak_mb``
    (tracemalloc), ``asyncio_tasks``, ``gc_collections``, ``top_allocations``,
    ``module``, ``safe_create_task_metrics``, ``pusher_debug``,
    ``to_thread_metrics``.
    """
    import resource
    import gc

    snapshot = tracemalloc.take_snapshot()
    top_stats = snapshot.statistics('lineno')

    # Peak RSS from the OS. ru_maxrss is kilobytes on Linux but *bytes* on
    # macOS, so scale per-platform instead of assuming Linux. Note this is the
    # high-water mark, not the current RSS.
    ru_maxrss = resource.getrusage(resource.RUSAGE_SELF).ru_maxrss
    rss_bytes = ru_maxrss if sys.platform == 'darwin' else ru_maxrss * 1024

    # Tracemalloc totals (Python-level allocations only; excludes C extensions)
    current, peak = tracemalloc.get_traced_memory()

    # Count asyncio tasks — an async endpoint runs inside the event loop, but
    # keep the defensive fallback in case no loop is running.
    try:
        task_count = len(asyncio.all_tasks())
    except RuntimeError:
        task_count = -1

    # GC stats for allocator retention analysis
    gc_stats = gc.get_stats()

    # Improvement #2: safe_create_task metrics (vuln only — fixed uses spawn).
    # Import lazily and best-effort: the fixed module does not provide these.
    task_metrics = {}
    try:
        from utils.other.task import get_task_metrics

        task_metrics = get_task_metrics()
    except (ImportError, AttributeError):
        pass

    # Improvement #4: Per-leak debug metrics exposed by the pusher module, if any
    pusher_debug = getattr(pusher, 'debug_metrics', {})

    # Improvement #8: Thread pool backlog metrics (copied so callers can't
    # mutate the shared counter dict)
    thread_metrics = dict(to_thread_metrics)

    return {
        "rss_mb": round(rss_bytes / 1024 / 1024, 2),
        "traced_current_mb": round(current / 1024 / 1024, 2),
        "traced_peak_mb": round(peak / 1024 / 1024, 2),
        "asyncio_tasks": task_count,
        "gc_collections": [s.get('collections', 0) for s in gc_stats],
        "top_allocations": [
            {"file": str(stat.traceback), "size_kb": round(stat.size / 1024, 1), "count": stat.count}
            for stat in top_stats[:15]
        ],
        "module": pusher_module_name,
        "safe_create_task_metrics": task_metrics,
        "pusher_debug": pusher_debug,
        "to_thread_metrics": thread_metrics,
    }
Loading