Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 9 additions & 8 deletions backend/database/redis_db.py
Original file line number Diff line number Diff line change
Expand Up @@ -586,6 +586,11 @@ def has_silent_user_notification_been_sent(uid: str) -> bool:
return r.exists(f'users:{uid}:silent_notification_sent')


def try_acquire_byok_llm_error_notification_lock(uid: str, provider: str, reason: str, ttl: int = 60 * 60 * 24) -> bool:
"""Return True once per BYOK provider/error reason per TTL window."""
return bool(r.set(f'users:{uid}:byok_llm_error:{provider}:{reason}', '1', ex=ttl, nx=True))


# ******************************************************
# ******* IMPORTANT CONVERSATION NOTIFICATIONS *********
# ******************************************************
Expand Down Expand Up @@ -636,8 +641,7 @@ def remove_conversation_summary_app_id(app_id: str) -> bool:
# Lua script: atomic increment + TTL in a single round-trip.
# Returns [current_count, ttl_remaining]. Sets TTL on first hit
# and self-heals any key that lost its TTL (prevents permanent buckets).
_RATE_LIMIT_LUA = r.register_script(
"""
_RATE_LIMIT_LUA = r.register_script("""
local key = KEYS[1]
local window = tonumber(ARGV[1])
local current = redis.call('INCR', key)
Expand All @@ -650,8 +654,7 @@ def remove_conversation_summary_app_id(app_id: str) -> bool:
ttl = window
end
return {current, ttl}
"""
)
""")


def check_rate_limit(key: str, policy: str, max_requests: int, window: int) -> tuple[bool, int, int]:
Expand Down Expand Up @@ -680,8 +683,7 @@ def check_rate_limit(key: str, policy: str, max_requests: int, window: int) -> t
# Burst uses a sorted set keyed by timestamp-ms for sliding-window accuracy,
# trimmed on every call (O(log n)). Daily char counter auto-expires at midnight
# UTC (caller passes seconds_until_midnight_utc as the TTL).
_TTS_RATE_LIMIT_LUA = r.register_script(
"""
_TTS_RATE_LIMIT_LUA = r.register_script("""
local burst_key = KEYS[1]
local daily_key = KEYS[2]
local now_ms = tonumber(ARGV[1])
Expand Down Expand Up @@ -709,8 +711,7 @@ def check_rate_limit(key: str, policy: str, max_requests: int, window: int) -> t
redis.call('EXPIRE', daily_key, daily_ttl)
end
return {0, 0}
"""
)
""")


def _seconds_until_midnight_utc() -> int:
Expand Down
3 changes: 2 additions & 1 deletion backend/routers/pusher.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@
trigger_external_integrations,
)
from utils.conversations.location import async_get_google_maps_location
from utils.byok import set_byok_keys
from utils.byok import set_byok_keys, set_byok_uid
from utils.conversations.process_conversation import process_conversation
from utils.executors import storage_executor
from utils.webhooks import (
Expand Down Expand Up @@ -79,6 +79,7 @@ async def _process_conversation_task(
"""
if byok_keys:
set_byok_keys(byok_keys)
set_byok_uid(uid)
try:
conversation_data = conversations_db.get_conversation(uid, conversation_id)
if not conversation_data:
Expand Down
4 changes: 3 additions & 1 deletion backend/routers/sync.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@
)

from utils import encryption
from utils.byok import get_byok_keys, set_byok_keys
from utils.byok import get_byok_keys, set_byok_keys, set_byok_uid
from utils.log_sanitizer import sanitize
from utils.stt.pre_recorded import deepgram_prerecorded, get_deepgram_model_for_language, postprocess_words
from utils.stt.vad import vad_is_empty
Expand Down Expand Up @@ -1359,6 +1359,7 @@ def _run_full_pipeline_background(
Moved ALL heavy processing here so the v2 endpoint returns 202 immediately.
"""
set_byok_keys(byok_keys or {})
set_byok_uid(uid if byok_keys else None)
segmented_paths = set()
wav_paths = []
stage_timings = {}
Expand Down Expand Up @@ -1580,6 +1581,7 @@ def _process_one_segment(path):
pass
finally:
set_byok_keys({})
set_byok_uid(None)
_cleanup_files(list(segmented_paths))
_cleanup_files(wav_paths)
try:
Expand Down
94 changes: 94 additions & 0 deletions backend/tests/unit/test_byok_llm_logging.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,94 @@
import os
import sys
import types
from unittest.mock import MagicMock, patch

os.environ.setdefault('OPENAI_API_KEY', 'sk-test-fake-for-unit-tests')
os.environ.setdefault('ANTHROPIC_API_KEY', 'ant-test-fake-for-unit-tests')
os.environ.setdefault('ENCRYPTION_SECRET', 'omi_ZwB2ZNqB2HHpMK6wStk7sTpavJiPTFg7gXUHnc4tFABPU6pZ2c2DKgehtfgi4RZv')

sys.modules.setdefault('database._client', MagicMock())
llm_usage_stub = types.ModuleType('database.llm_usage')
llm_usage_stub.record_llm_usage = MagicMock()
sys.modules.setdefault('database.llm_usage', llm_usage_stub)


class _HTTPError(Exception):
def __init__(self, message: str, status_code: int):
super().__init__(message)
self.status_code = status_code


def test_classify_byok_llm_error_authentication():
from utils.llm.byok_errors import classify_byok_llm_error

assert classify_byok_llm_error(_HTTPError("bad api key", 401)) == 'invalid'


def test_classify_byok_llm_error_permission():
from utils.llm.byok_errors import classify_byok_llm_error

assert classify_byok_llm_error(_HTTPError("project denied", 403)) == 'permission'


def test_classify_byok_llm_error_insufficient_quota():
from utils.llm.byok_errors import classify_byok_llm_error

assert classify_byok_llm_error(_HTTPError("insufficient_quota", 429)) == 'quota'


def test_classify_byok_llm_error_ignores_transient_rate_limit():
from utils.llm.byok_errors import classify_byok_llm_error

assert classify_byok_llm_error(_HTTPError("rate limit reached, retry later", 429)) is None


@patch('utils.llm.byok_errors.get_byok_uid', return_value='user-1')
@patch('utils.llm.byok_errors.get_byok_key', return_value='sk-user')
@patch('utils.llm.byok_errors._send_byok_llm_error_notification')
def test_handle_llm_error_logs_byok_source(mock_send_notification, mock_get_key, mock_get_uid):
from utils.llm.byok_errors import handle_llm_error

with patch('utils.llm.byok_errors.logger.error') as mock_log:
handle_llm_error(_HTTPError("insufficient_quota", 429), 'openai', feature='memories', model='gpt-test')

log_args = mock_log.call_args.args
assert 'LLM error source=%s' in log_args[0]
assert log_args[1] == 'byok'
assert log_args[2] == 'openai'
assert log_args[8] == 'quota'
mock_send_notification.assert_called_once_with('user-1', 'openai', 'quota')


@patch('utils.llm.byok_errors.get_byok_uid', return_value='user-1')
@patch('utils.llm.byok_errors.get_byok_key', return_value=None)
def test_handle_llm_error_logs_platform_source(mock_get_key, mock_get_uid):
from utils.llm.byok_errors import handle_llm_error

with patch('utils.llm.byok_errors.logger.error') as mock_log:
handle_llm_error(_HTTPError("server error", 500), 'openai', feature='memories', model='gpt-test')

assert mock_log.call_args.args[1] == 'platform'
assert mock_log.call_args.args[8] == 'unknown'


def test_validate_byok_request_records_current_uid():
from utils.byok import get_byok_uid, validate_byok_request

with patch('utils.byok._check_byok_validity', return_value=None):
validate_byok_request('user-1')

assert get_byok_uid() == 'user-1'


def test_llm_error_callback_uses_provider_context():
from utils.llm.clients import _LLMErrorCallback

callback = _LLMErrorCallback('openai', model='gpt-test', feature='memories')
error = _HTTPError('bad key', 401)

with patch('utils.llm.clients.handle_llm_error') as mock_handle:
callback.on_llm_error(error)

mock_handle.assert_called_once()
assert mock_handle.call_args.args[:2] == (error, 'openai')
100 changes: 100 additions & 0 deletions backend/tests/unit/test_byok_llm_notifications.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,100 @@
import os
from types import SimpleNamespace
from unittest.mock import MagicMock, patch

os.environ.setdefault('OPENAI_API_KEY', 'sk-test-fake-for-unit-tests')
os.environ.setdefault('ANTHROPIC_API_KEY', 'ant-test-fake-for-unit-tests')
os.environ.setdefault('ENCRYPTION_SECRET', 'omi_ZwB2ZNqB2HHpMK6wStk7sTpavJiPTFg7gXUHnc4tFABPU6pZ2c2DKgehtfgi4RZv')


class _HTTPError(Exception):
def __init__(self, message: str, status_code: int):
super().__init__(message)
self.status_code = status_code


@patch('utils.llm.byok_errors.messaging.send_each')
@patch('utils.llm.byok_errors.notification_db.get_all_tokens', return_value=['token-1'])
@patch('utils.llm.byok_errors.try_acquire_byok_llm_error_notification_lock', return_value=True)
@patch('utils.llm.byok_errors.get_byok_uid', return_value='user-1')
@patch('utils.llm.byok_errors.get_byok_key', return_value='sk-user')
def test_handle_llm_error_notifies_actionable_byok_error(
mock_get_key,
mock_get_uid,
mock_lock,
mock_get_tokens,
mock_send_each,
):
from utils.llm.byok_errors import handle_llm_error

mock_send_each.return_value = SimpleNamespace(responses=[SimpleNamespace(success=True, exception=None)])

handle_llm_error(_HTTPError("insufficient_quota", 429), 'openai', feature='memories', model='gpt-test')

mock_lock.assert_called_once_with('user-1', 'openai', 'quota')
mock_get_tokens.assert_called_once_with('user-1')
mock_send_each.assert_called_once()
message = mock_send_each.call_args.args[0][0]
assert message.data == {'type': 'byok_llm_error', 'provider': 'openai', 'reason': 'quota'}


@patch('utils.llm.byok_errors.messaging.send_each')
@patch('utils.llm.byok_errors.notification_db.get_all_tokens', return_value=['token-1'])
@patch('utils.llm.byok_errors.try_acquire_byok_llm_error_notification_lock', return_value=False)
@patch('utils.llm.byok_errors.get_byok_uid', return_value='user-1')
@patch('utils.llm.byok_errors.get_byok_key', return_value='sk-user')
def test_handle_llm_error_deduplicates_recent_notification(
mock_get_key,
mock_get_uid,
mock_lock,
mock_get_tokens,
mock_send_each,
):
from utils.llm.byok_errors import handle_llm_error

handle_llm_error(_HTTPError("insufficient_quota", 429), 'openai', feature='memories', model='gpt-test')

mock_lock.assert_called_once_with('user-1', 'openai', 'quota')
mock_send_each.assert_not_called()


@patch('utils.llm.byok_errors.messaging.send_each')
@patch('utils.llm.byok_errors.try_acquire_byok_llm_error_notification_lock')
@patch('utils.llm.byok_errors.get_byok_uid', return_value='user-1')
@patch('utils.llm.byok_errors.get_byok_key', return_value=None)
def test_handle_llm_error_does_not_notify_platform_error(
mock_get_key,
mock_get_uid,
mock_lock,
mock_send_each,
):
from utils.llm.byok_errors import handle_llm_error

handle_llm_error(_HTTPError("insufficient_quota", 429), 'openai', feature='memories', model='gpt-test')

mock_lock.assert_not_called()
mock_send_each.assert_not_called()


@patch('utils.llm.byok_errors.notification_db.remove_bulk_tokens')
@patch('utils.llm.byok_errors.messaging.send_each')
@patch('utils.llm.byok_errors.notification_db.get_all_tokens', return_value=['bad-token'])
@patch('utils.llm.byok_errors.try_acquire_byok_llm_error_notification_lock', return_value=True)
@patch('utils.llm.byok_errors.get_byok_uid', return_value='user-1')
@patch('utils.llm.byok_errors.get_byok_key', return_value='sk-user')
def test_handle_llm_error_removes_permanent_bad_tokens(
mock_get_key,
mock_get_uid,
mock_lock,
mock_get_tokens,
mock_send_each,
mock_remove_tokens,
):
from utils.llm.byok_errors import handle_llm_error

fcm_error = SimpleNamespace(code='UNREGISTERED')
mock_send_each.return_value = SimpleNamespace(responses=[SimpleNamespace(success=False, exception=fcm_error)])

handle_llm_error(_HTTPError("bad key", 401), 'openai', feature='memories', model='gpt-test')

mock_remove_tokens.assert_called_once_with(['bad-token'])
16 changes: 16 additions & 0 deletions backend/utils/byok.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,7 @@ def invalidate_byok_state_cache(uid: str) -> None:
# Keys for the current request, if the client supplied them.
# Default is None (not {}) to avoid sharing a mutable object across contexts.
_byok_ctx: ContextVar[Optional[Dict[str, str]]] = ContextVar('byok_keys', default=None)
_byok_uid_ctx: ContextVar[Optional[str]] = ContextVar('byok_uid', default=None)


def get_byok_keys() -> Dict[str, str]:
Expand All @@ -87,6 +88,16 @@ def get_byok_key(provider: str) -> Optional[str]:
return keys.get(provider)


def get_byok_uid() -> Optional[str]:
"""Return the authenticated uid for the current request, when known."""
return _byok_uid_ctx.get()


def set_byok_uid(uid: Optional[str]) -> None:
"""Attach the authenticated uid to the current request context."""
_byok_uid_ctx.set(uid)


def has_byok_keys() -> bool:
"""True if the current request carries at least one BYOK header."""
keys = _byok_ctx.get()
Expand Down Expand Up @@ -127,10 +138,12 @@ async def dispatch(self, request: Request, call_next):
if value:
keys[provider] = value
token = _byok_ctx.set(keys)
uid_token = _byok_uid_ctx.set(None)
try:
return await call_next(request)
finally:
_byok_ctx.reset(token)
_byok_uid_ctx.reset(uid_token)


# ---------------------------------------------------------------------------
Expand Down Expand Up @@ -203,6 +216,7 @@ def validate_byok_request(uid: str) -> None:
if error:
logger.warning('BYOK validation failed uid=%s: %s', uid, error)
raise HTTPException(status_code=403, detail=error)
set_byok_uid(uid)


def validate_byok_websocket(uid: str) -> Optional[str]:
Expand All @@ -215,4 +229,6 @@ def validate_byok_websocket(uid: str) -> Optional[str]:
error = _check_byok_validity(uid)
if error:
logger.warning('BYOK WS validation failed uid=%s: %s', uid, error)
else:
set_byok_uid(uid)
return error
Loading