diff --git a/backend/database/redis_db.py b/backend/database/redis_db.py index 00f76372ada..e33447a3b45 100644 --- a/backend/database/redis_db.py +++ b/backend/database/redis_db.py @@ -586,6 +586,11 @@ def has_silent_user_notification_been_sent(uid: str) -> bool: return r.exists(f'users:{uid}:silent_notification_sent') +def try_acquire_byok_llm_error_notification_lock(uid: str, provider: str, reason: str, ttl: int = 60 * 60 * 24) -> bool: + """Return True once per BYOK provider/error reason per TTL window.""" + return bool(r.set(f'users:{uid}:byok_llm_error:{provider}:{reason}', '1', ex=ttl, nx=True)) + + # ****************************************************** # ******* IMPORTANT CONVERSATION NOTIFICATIONS ********* # ****************************************************** @@ -636,8 +641,7 @@ def remove_conversation_summary_app_id(app_id: str) -> bool: # Lua script: atomic increment + TTL in a single round-trip. # Returns [current_count, ttl_remaining]. Sets TTL on first hit # and self-heals any key that lost its TTL (prevents permanent buckets). -_RATE_LIMIT_LUA = r.register_script( - """ +_RATE_LIMIT_LUA = r.register_script(""" local key = KEYS[1] local window = tonumber(ARGV[1]) local current = redis.call('INCR', key) @@ -650,8 +654,7 @@ def remove_conversation_summary_app_id(app_id: str) -> bool: ttl = window end return {current, ttl} -""" -) +""") def check_rate_limit(key: str, policy: str, max_requests: int, window: int) -> tuple[bool, int, int]: @@ -680,8 +683,7 @@ def check_rate_limit(key: str, policy: str, max_requests: int, window: int) -> t # Burst uses a sorted set keyed by timestamp-ms for sliding-window accuracy, # trimmed on every call (O(log n)). Daily char counter auto-expires at midnight # UTC (caller passes seconds_until_midnight_utc as the TTL). -_TTS_RATE_LIMIT_LUA = r.register_script( - """ +_TTS_RATE_LIMIT_LUA = r.register_script(""" local burst_key = KEYS[1] local daily_key = KEYS[2] local now_ms = tonumber(ARGV[1]) @@ -709,8 +711,7 @@ def check_rate_limit(key: str, policy: str, max_requests: int, window: int) -> t redis.call('EXPIRE', daily_key, daily_ttl) end return {0, 0} -""" -) +""") def _seconds_until_midnight_utc() -> int: diff --git a/backend/routers/pusher.py b/backend/routers/pusher.py index 2e4625076df..f1967ea3fc2 100644 --- a/backend/routers/pusher.py +++ b/backend/routers/pusher.py @@ -25,7 +25,7 @@ trigger_external_integrations, ) from utils.conversations.location import async_get_google_maps_location -from utils.byok import set_byok_keys +from utils.byok import set_byok_keys, set_byok_uid from utils.conversations.process_conversation import process_conversation from utils.executors import storage_executor from utils.webhooks import ( @@ -79,6 +79,7 @@ async def _process_conversation_task( """ if byok_keys: set_byok_keys(byok_keys) + set_byok_uid(uid) try: conversation_data = conversations_db.get_conversation(uid, conversation_id) if not conversation_data: diff --git a/backend/routers/sync.py b/backend/routers/sync.py index 671deba0881..ace4806c7f0 100644 --- a/backend/routers/sync.py +++ b/backend/routers/sync.py @@ -49,7 +49,7 @@ ) from utils import encryption -from utils.byok import get_byok_keys, set_byok_keys +from utils.byok import get_byok_keys, set_byok_keys, set_byok_uid from utils.log_sanitizer import sanitize from utils.stt.pre_recorded import deepgram_prerecorded, get_deepgram_model_for_language, postprocess_words from utils.stt.vad import vad_is_empty @@ -1359,6 +1359,7 @@ def _run_full_pipeline_background( Moved ALL heavy processing here so the v2 endpoint returns 202 immediately. """ set_byok_keys(byok_keys or {}) + set_byok_uid(uid if byok_keys else None) segmented_paths = set() wav_paths = [] stage_timings = {} @@ -1580,6 +1581,7 @@ def _process_one_segment(path): pass finally: set_byok_keys({}) + set_byok_uid(None) _cleanup_files(list(segmented_paths)) _cleanup_files(wav_paths) try: diff --git a/backend/tests/unit/test_byok_llm_logging.py b/backend/tests/unit/test_byok_llm_logging.py new file mode 100644 index 00000000000..b0961af6cee --- /dev/null +++ b/backend/tests/unit/test_byok_llm_logging.py @@ -0,0 +1,94 @@ +import os +import sys +import types +from unittest.mock import MagicMock, patch + +os.environ.setdefault('OPENAI_API_KEY', 'sk-test-fake-for-unit-tests') +os.environ.setdefault('ANTHROPIC_API_KEY', 'ant-test-fake-for-unit-tests') +os.environ.setdefault('ENCRYPTION_SECRET', 'omi_ZwB2ZNqB2HHpMK6wStk7sTpavJiPTFg7gXUHnc4tFABPU6pZ2c2DKgehtfgi4RZv') + +sys.modules.setdefault('database._client', MagicMock()) +llm_usage_stub = types.ModuleType('database.llm_usage') +llm_usage_stub.record_llm_usage = MagicMock() +sys.modules.setdefault('database.llm_usage', llm_usage_stub) + + +class _HTTPError(Exception): + def __init__(self, message: str, status_code: int): + super().__init__(message) + self.status_code = status_code + + +def test_classify_byok_llm_error_authentication(): + from utils.llm.byok_errors import classify_byok_llm_error + + assert classify_byok_llm_error(_HTTPError("bad api key", 401)) == 'invalid' + + +def test_classify_byok_llm_error_permission(): + from utils.llm.byok_errors import classify_byok_llm_error + + assert classify_byok_llm_error(_HTTPError("project denied", 403)) == 'permission' + + +def test_classify_byok_llm_error_insufficient_quota(): + from utils.llm.byok_errors import classify_byok_llm_error + + assert classify_byok_llm_error(_HTTPError("insufficient_quota", 429)) == 'quota' + + +def test_classify_byok_llm_error_ignores_transient_rate_limit(): + from utils.llm.byok_errors import classify_byok_llm_error + + assert classify_byok_llm_error(_HTTPError("rate limit reached, retry later", 429)) is None + + +@patch('utils.llm.byok_errors.get_byok_uid', return_value='user-1') +@patch('utils.llm.byok_errors.get_byok_key', return_value='sk-user') +@patch('utils.llm.byok_errors._send_byok_llm_error_notification') +def test_handle_llm_error_logs_byok_source(mock_send_notification, mock_get_key, mock_get_uid): + from utils.llm.byok_errors import handle_llm_error + + with patch('utils.llm.byok_errors.logger.error') as mock_log: + handle_llm_error(_HTTPError("insufficient_quota", 429), 'openai', feature='memories', model='gpt-test') + + log_args = mock_log.call_args.args + assert 'LLM error source=%s' in log_args[0] + assert log_args[1] == 'byok' + assert log_args[2] == 'openai' + assert log_args[8] == 'quota' + mock_send_notification.assert_called_once_with('user-1', 'openai', 'quota') + + +@patch('utils.llm.byok_errors.get_byok_uid', return_value='user-1') +@patch('utils.llm.byok_errors.get_byok_key', return_value=None) +def test_handle_llm_error_logs_platform_source(mock_get_key, mock_get_uid): + from utils.llm.byok_errors import handle_llm_error + + with patch('utils.llm.byok_errors.logger.error') as mock_log: + handle_llm_error(_HTTPError("server error", 500), 'openai', feature='memories', model='gpt-test') + + assert mock_log.call_args.args[1] == 'platform' + assert mock_log.call_args.args[8] == 'unknown' + + +def test_validate_byok_request_records_current_uid(): + from utils.byok import get_byok_uid, validate_byok_request + + with patch('utils.byok._check_byok_validity', return_value=None): + validate_byok_request('user-1') + + assert get_byok_uid() == 'user-1' + + +def test_llm_error_callback_uses_provider_context(): + from utils.llm.clients import _LLMErrorCallback + + callback = _LLMErrorCallback('openai', model='gpt-test', feature='memories') + error = _HTTPError('bad key', 401) + + with patch('utils.llm.clients.handle_llm_error') as mock_handle: + callback.on_llm_error(error) + + mock_handle.assert_called_once() + assert mock_handle.call_args.args[:2] == (error, 'openai') diff --git a/backend/tests/unit/test_byok_llm_notifications.py b/backend/tests/unit/test_byok_llm_notifications.py new file mode 100644 index 00000000000..1bf02c0b3a9 --- /dev/null +++ b/backend/tests/unit/test_byok_llm_notifications.py @@ -0,0 +1,100 @@ +import os +from types import SimpleNamespace +from unittest.mock import MagicMock, patch + +os.environ.setdefault('OPENAI_API_KEY', 'sk-test-fake-for-unit-tests') +os.environ.setdefault('ANTHROPIC_API_KEY', 'ant-test-fake-for-unit-tests') +os.environ.setdefault('ENCRYPTION_SECRET', 'omi_ZwB2ZNqB2HHpMK6wStk7sTpavJiPTFg7gXUHnc4tFABPU6pZ2c2DKgehtfgi4RZv') + + +class _HTTPError(Exception): + def __init__(self, message: str, status_code: int): + super().__init__(message) + self.status_code = status_code + + +@patch('utils.llm.byok_errors.messaging.send_each') +@patch('utils.llm.byok_errors.notification_db.get_all_tokens', return_value=['token-1']) +@patch('utils.llm.byok_errors.try_acquire_byok_llm_error_notification_lock', return_value=True) +@patch('utils.llm.byok_errors.get_byok_uid', return_value='user-1') +@patch('utils.llm.byok_errors.get_byok_key', return_value='sk-user') +def test_handle_llm_error_notifies_actionable_byok_error( + mock_get_key, + mock_get_uid, + mock_lock, + mock_get_tokens, + mock_send_each, +): + from utils.llm.byok_errors import handle_llm_error + + mock_send_each.return_value = SimpleNamespace(responses=[SimpleNamespace(success=True, exception=None)]) + + handle_llm_error(_HTTPError("insufficient_quota", 429), 'openai', feature='memories', model='gpt-test') + + mock_lock.assert_called_once_with('user-1', 'openai', 'quota') + mock_get_tokens.assert_called_once_with('user-1') + mock_send_each.assert_called_once() + message = mock_send_each.call_args.args[0][0] + assert message.data == {'type': 'byok_llm_error', 'provider': 'openai', 'reason': 'quota'} + + +@patch('utils.llm.byok_errors.messaging.send_each') +@patch('utils.llm.byok_errors.notification_db.get_all_tokens', return_value=['token-1']) +@patch('utils.llm.byok_errors.try_acquire_byok_llm_error_notification_lock', return_value=False) +@patch('utils.llm.byok_errors.get_byok_uid', return_value='user-1') +@patch('utils.llm.byok_errors.get_byok_key', return_value='sk-user') +def test_handle_llm_error_deduplicates_recent_notification( + mock_get_key, + mock_get_uid, + mock_lock, + mock_get_tokens, + mock_send_each, +): + from utils.llm.byok_errors import handle_llm_error + + handle_llm_error(_HTTPError("insufficient_quota", 429), 'openai', feature='memories', model='gpt-test') + + mock_lock.assert_called_once_with('user-1', 'openai', 'quota') + mock_send_each.assert_not_called() + + +@patch('utils.llm.byok_errors.messaging.send_each') +@patch('utils.llm.byok_errors.try_acquire_byok_llm_error_notification_lock') +@patch('utils.llm.byok_errors.get_byok_uid', return_value='user-1') +@patch('utils.llm.byok_errors.get_byok_key', return_value=None) +def test_handle_llm_error_does_not_notify_platform_error( + mock_get_key, + mock_get_uid, + mock_lock, + mock_send_each, +): + from utils.llm.byok_errors import handle_llm_error + + handle_llm_error(_HTTPError("insufficient_quota", 429), 'openai', feature='memories', model='gpt-test') + + mock_lock.assert_not_called() + mock_send_each.assert_not_called() + + +@patch('utils.llm.byok_errors.notification_db.remove_bulk_tokens') +@patch('utils.llm.byok_errors.messaging.send_each') +@patch('utils.llm.byok_errors.notification_db.get_all_tokens', return_value=['bad-token']) +@patch('utils.llm.byok_errors.try_acquire_byok_llm_error_notification_lock', return_value=True) +@patch('utils.llm.byok_errors.get_byok_uid', return_value='user-1') +@patch('utils.llm.byok_errors.get_byok_key', return_value='sk-user') +def test_handle_llm_error_removes_permanent_bad_tokens( + mock_get_key, + mock_get_uid, + mock_lock, + mock_get_tokens, + mock_send_each, + mock_remove_tokens, +): + from utils.llm.byok_errors import handle_llm_error + + fcm_error = SimpleNamespace(code='UNREGISTERED') + mock_send_each.return_value = SimpleNamespace(responses=[SimpleNamespace(success=False, exception=fcm_error)]) + + handle_llm_error(_HTTPError("bad key", 401), 'openai', feature='memories', model='gpt-test') + + mock_remove_tokens.assert_called_once_with(['bad-token']) diff --git a/backend/utils/byok.py b/backend/utils/byok.py index 39d355273dd..6c3c9fdfa1d 100644 --- a/backend/utils/byok.py +++ b/backend/utils/byok.py @@ -73,6 +73,7 @@ def invalidate_byok_state_cache(uid: str) -> None: # Keys for the current request, if the client supplied them. # Default is None (not {}) to avoid sharing a mutable object across contexts. _byok_ctx: ContextVar[Optional[Dict[str, str]]] = ContextVar('byok_keys', default=None) +_byok_uid_ctx: ContextVar[Optional[str]] = ContextVar('byok_uid', default=None) def get_byok_keys() -> Dict[str, str]: @@ -87,6 +88,16 @@ def get_byok_key(provider: str) -> Optional[str]: return keys.get(provider) +def get_byok_uid() -> Optional[str]: + """Return the authenticated uid for the current request, when known.""" + return _byok_uid_ctx.get() + + +def set_byok_uid(uid: Optional[str]) -> None: + """Attach the authenticated uid to the current request context.""" + _byok_uid_ctx.set(uid) + + def has_byok_keys() -> bool: """True if the current request carries at least one BYOK header.""" keys = _byok_ctx.get() @@ -127,10 +138,12 @@ async def dispatch(self, request: Request, call_next): if value: keys[provider] = value token = _byok_ctx.set(keys) + uid_token = _byok_uid_ctx.set(None) try: return await call_next(request) finally: _byok_ctx.reset(token) + _byok_uid_ctx.reset(uid_token) # --------------------------------------------------------------------------- @@ -203,6 +216,7 @@ def validate_byok_request(uid: str) -> None: if error: logger.warning('BYOK validation failed uid=%s: %s', uid, error) raise HTTPException(status_code=403, detail=error) + set_byok_uid(uid) def validate_byok_websocket(uid: str) -> Optional[str]: @@ -215,4 +229,6 @@ def validate_byok_websocket(uid: str) -> Optional[str]: error = _check_byok_validity(uid) if error: logger.warning('BYOK WS validation failed uid=%s: %s', uid, error) + else: + set_byok_uid(uid) return error diff --git a/backend/utils/llm/byok_errors.py b/backend/utils/llm/byok_errors.py new file mode 100644 index 00000000000..7d6ddfea22b --- /dev/null +++ b/backend/utils/llm/byok_errors.py @@ -0,0 +1,185 @@ +import asyncio +import logging +from typing import Optional + +from firebase_admin import messaging + +try: + import database.notifications as notification_db +except ImportError: + notification_db = None + +try: + from database.redis_db import try_acquire_byok_llm_error_notification_lock +except ImportError: + + def try_acquire_byok_llm_error_notification_lock( + uid: str, provider: str, reason: str, ttl: int = 60 * 60 * 24 + ) -> bool: + logger.error('BYOK LLM notification lock unavailable uid=%s provider=%s reason=%s', uid, provider, reason) + return False + + +from utils.byok import get_byok_key, get_byok_uid +from utils.executors import storage_executor, submit_with_context +from utils.log_sanitizer import sanitize + +logger = logging.getLogger(__name__) + +_PERMANENT_FAILURE_CODES = frozenset({'UNREGISTERED', 'INVALID_REGISTRATION_TOKEN', 'NOT_FOUND'}) +_QUOTA_ERROR_NAMES = frozenset({'RateLimitError'}) + + +def get_llm_error_source(provider: Optional[str]) -> str: + """Return platform/byok for the current request and provider.""" + if provider and get_byok_key(provider): + return 'byok' + return 'platform' + + +def classify_byok_llm_error(error: Exception) -> Optional[str]: + """Classify user-actionable BYOK failures for structured logging.""" + status_code = _get_status_code(error) + error_name = type(error).__name__ + error_text = sanitize(str(error)).lower() + + if status_code == 401 or error_name == 'AuthenticationError': + return 'invalid' + if status_code == 403 or error_name == 'PermissionDeniedError': + return 'permission' + if status_code == 429 or error_name in _QUOTA_ERROR_NAMES: + if 'insufficient_quota' in error_text or 'quota' in error_text: + return 'quota' + return None + + +def handle_llm_error( + error: Exception, + provider: Optional[str], + feature: Optional[str] = None, + model: Optional[str] = None, + operation: str = 'chat', +) -> None: + """Log LLM failures with source context.""" + source = get_llm_error_source(provider) + reason = classify_byok_llm_error(error) if source == 'byok' else None + uid = get_byok_uid() + status_code = _get_status_code(error) + + logger.error( + 'LLM error source=%s provider=%s feature=%s model=%s operation=%s uid=%s status_code=%s reason=%s ' + 'error_type=%s error=%s', + source, + provider or 'unknown', + feature or 'unknown', + model or 'unknown', + operation, + uid or 'unknown', + status_code or 'unknown', + reason or 'unknown', + type(error).__name__, + sanitize(str(error)), + ) + + if source == 'byok' and uid and provider and reason: + _send_byok_llm_error_notification(uid, provider, reason) + + +async def handle_llm_error_async( + error: Exception, + provider: Optional[str], + feature: Optional[str] = None, + model: Optional[str] = None, + operation: str = 'chat', +) -> None: + """Run LLM error handling off the event loop while preserving BYOK context.""" + future = submit_with_context(storage_executor, handle_llm_error, error, provider, feature, model, operation) + try: + await asyncio.wrap_future(future) + except Exception as e: + logger.error('Async LLM error handler failed provider=%s feature=%s: %s', provider, feature, e) + + +def _get_status_code(error: Exception) -> Optional[int]: + status_code = getattr(error, 'status_code', None) + if isinstance(status_code, int): + return status_code + + response = getattr(error, 'response', None) + response_status = getattr(response, 'status_code', None) + if isinstance(response_status, int): + return response_status + return None + + +def _send_byok_llm_error_notification(uid: str, provider: str, reason: str) -> None: + if notification_db is None: + logger.error('BYOK LLM notification database unavailable uid=%s provider=%s reason=%s', uid, provider, reason) + return + + provider_name = provider.capitalize() + if reason == 'quota': + body = f'Your {provider_name} BYOK key appears to be out of quota. Update it to restore AI features.' + elif reason == 'permission': + body = f'Your {provider_name} BYOK key was denied access. Check its project and permissions in Omi settings.' + else: + body = f'Your {provider_name} BYOK key was rejected. Update it in Omi settings to restore AI features.' + + try: + tokens = notification_db.get_all_tokens(uid) + except Exception as e: + logger.error( + 'BYOK LLM notification token lookup failed uid=%s provider=%s reason=%s: %s', uid, provider, reason, e + ) + return + + if not tokens: + logger.info('No tokens found for BYOK LLM notification uid=%s provider=%s reason=%s', uid, provider, reason) + return + + try: + acquired = try_acquire_byok_llm_error_notification_lock(uid, provider, reason) + except Exception as e: + logger.error('BYOK LLM notification lock failed uid=%s provider=%s reason=%s: %s', uid, provider, reason, e) + return + + if not acquired: + logger.info('BYOK LLM notification already sent recently uid=%s provider=%s reason=%s', uid, provider, reason) + return + + notification = messaging.Notification(title='omi', body=body) + data = {'type': 'byok_llm_error', 'provider': provider, 'reason': reason} + messages = [messaging.Message(token=token, notification=notification, data=data) for token in tokens] + + try: + response = messaging.send_each(messages) + except Exception as e: + logger.error('BYOK LLM notification send failed uid=%s provider=%s reason=%s: %s', uid, provider, reason, e) + return + + invalid_tokens = [] + success_count = 0 + for idx, result in enumerate(response.responses): + if result.success: + success_count += 1 + elif result.exception: + error_code = getattr(result.exception, 'code', None) + if error_code in _PERMANENT_FAILURE_CODES: + invalid_tokens.append(tokens[idx]) + else: + logger.error('BYOK LLM notification FCM send failed uid=%s error=%s', uid, result.exception) + + if invalid_tokens: + try: + notification_db.remove_bulk_tokens(invalid_tokens) + except Exception as e: + logger.error('BYOK LLM notification invalid token cleanup failed uid=%s: %s', uid, e) + + logger.info( + 'BYOK LLM notification sent uid=%s provider=%s reason=%s success=%s total=%s', + uid, + provider, + reason, + success_count, + len(tokens), + ) diff --git a/backend/utils/llm/clients.py b/backend/utils/llm/clients.py index 2d73a028f59..769bbee4c22 100644 --- a/backend/utils/llm/clients.py +++ b/backend/utils/llm/clients.py @@ -6,6 +6,7 @@ import anthropic import httpx from cachetools import TTLCache +from langchain_core.callbacks import BaseCallbackHandler from langchain_core.language_models import BaseChatModel from langchain_core.output_parsers import PydanticOutputParser from langchain_google_genai import ChatGoogleGenerativeAI @@ -14,12 +15,49 @@ from models.structured import Structured from utils.byok import get_byok_key +from utils.llm.byok_errors import handle_llm_error from utils.llm.usage_tracker import get_usage_callback logger = logging.getLogger(__name__) _usage_callback = get_usage_callback() + +class _LLMErrorCallback(BaseCallbackHandler): + """LangChain callback that tags provider errors with platform/BYOK source.""" + + def __init__(self, provider: str, model: str = '', feature: str = ''): + self.provider = provider + self.model = model + self.feature = feature + + def on_llm_error(self, error: BaseException, **kwargs) -> None: + if isinstance(error, Exception): + handle_llm_error(error, self.provider, feature=self.feature, model=self.model) + + +_llm_error_callbacks: Dict[Tuple[str, str, str], _LLMErrorCallback] = {} + + +def _get_llm_error_callback(provider: str, model: str = '', feature: str = '') -> _LLMErrorCallback: + key = (provider, model, feature) + if key not in _llm_error_callbacks: + _llm_error_callbacks[key] = _LLMErrorCallback(provider, model=model, feature=feature) + return _llm_error_callbacks[key] + + +def _with_llm_callbacks(kwargs: Dict[str, Any], provider: str, model: str = '', feature: str = '') -> Dict[str, Any]: + result = dict(kwargs) + callbacks = list(result.get('callbacks') or []) + if _usage_callback not in callbacks: + callbacks.append(_usage_callback) + error_callback = _get_llm_error_callback(provider, model=model, feature=feature) + if error_callback not in callbacks: + callbacks.append(error_callback) + result['callbacks'] = callbacks + return result + + # --------------------------------------------------------------------------- # BYOK (Bring Your Own Key) # @@ -56,6 +94,7 @@ class _OpenAIEmbeddingsProxy: """Transparent proxy for OpenAIEmbeddings that uses BYOK OpenAI when set.""" __slots__ = ('_model', '_default', '_ctor_kwargs') + _METHODS_TO_WRAP = {'embed_documents', 'aembed_documents', 'embed_query', 'aembed_query'} def __init__(self, model: str, default: OpenAIEmbeddings, ctor_kwargs: Dict[str, Any]): object.__setattr__(self, '_model', model) @@ -74,7 +113,28 @@ def _resolve(self) -> OpenAIEmbeddings: return self._default def __getattr__(self, name: str): - return getattr(self._resolve(), name) + attr = getattr(self._resolve(), name) + if name not in self._METHODS_TO_WRAP or not callable(attr): + return attr + if name.startswith('a'): + + async def _wrapped_async(*args, **kwargs): + try: + return await attr(*args, **kwargs) + except Exception as e: + handle_llm_error(e, 'openai', feature='embeddings', model=self._model, operation=name) + raise + + return _wrapped_async + + def _wrapped(*args, **kwargs): + try: + return attr(*args, **kwargs) + except Exception as e: + handle_llm_error(e, 'openai', feature='embeddings', model=self._model, operation=name) + raise + + return _wrapped _BYOK_CACHE_MAX_SIZE = 256 @@ -111,7 +171,10 @@ def _create_byok_client( model: str, provider: str, byok_key: str, streaming: bool = False, feature: str = '' ) -> Optional[ChatOpenAI]: """Create a ChatOpenAI using the user's BYOK key. Returns None if BYOK not supported for this provider.""" - kwargs: Dict[str, Any] = {'callbacks': [_usage_callback], 'request_timeout': 120, 'max_retries': 1} + callback_provider = _effective_byok_provider(model, provider) + kwargs: Dict[str, Any] = _with_llm_callbacks( + {'request_timeout': 120, 'max_retries': 1}, callback_provider, model=model, feature=feature + ) if model == 'gpt-5.1': kwargs['extra_body'] = {"prompt_cache_retention": "24h"} if streaming: @@ -148,6 +211,7 @@ def get_anthropic_client() -> anthropic.AsyncAnthropic: def get_openai_chat(model: str, **kwargs) -> ChatOpenAI: """Explicit factory; equivalent to using the module-level proxies.""" + kwargs = _with_llm_callbacks(kwargs, 'openai', model=model) byok = get_byok_key('openai') if byok: return _cached_openai_chat(model, byok, kwargs) @@ -417,11 +481,9 @@ def _get_or_create_openai_llm(model_name: str, streaming: bool = False) -> ChatO """Get or create a cached ChatOpenAI for an OpenAI model.""" key = (model_name, streaming, 'openai') if key not in _llm_cache: - kwargs: Dict[str, Any] = { - 'callbacks': [_usage_callback], - 'request_timeout': 120, - 'max_retries': 1, - } + kwargs: Dict[str, Any] = _with_llm_callbacks( + {'request_timeout': 120, 'max_retries': 1}, 'openai', model=model_name + ) if model_name == 'gpt-5.1': kwargs['extra_body'] = {"prompt_cache_retention": "24h"} if streaming: @@ -447,10 +509,10 @@ def _get_or_create_openrouter_llm( 'api_key': os.environ.get('OPENROUTER_API_KEY'), 'base_url': "https://openrouter.ai/api/v1", 'default_headers': {"X-Title": "Omi Chat"}, - 'callbacks': [_usage_callback], 'request_timeout': 120, 'max_retries': 1, } + kwargs = _with_llm_callbacks(kwargs, 'openrouter', model=api_model) if temperature is not None: kwargs['temperature'] = temperature if streaming: @@ -478,7 +540,7 @@ def _get_or_create_gemini_llm(model_name: str, streaming: bool = False) -> BaseC use_vertex = os.environ.get('USE_VERTEX_AI', '').lower() == 'true' gcp_project = os.environ.get('GOOGLE_CLOUD_PROJECT', '') if use_vertex else '' gemini_key = os.environ.get('GEMINI_API_KEY', '') - kwargs: Dict[str, Any] = {'callbacks': [_usage_callback], 'timeout': 120, 'max_retries': 1} + kwargs: Dict[str, Any] = _with_llm_callbacks({'timeout': 120, 'max_retries': 1}, 'gemini', model=model_name) if streaming: kwargs['streaming'] = True @@ -624,7 +686,10 @@ def get_qos_info() -> Dict[str, Dict[str, str]]: # Legacy module-level alias (kept for test compatibility). # Production code should use get_llm(feature) exclusively. # --------------------------------------------------------------------------- -llm_mini = ChatOpenAI(model='gpt-4.1-mini', callbacks=[_usage_callback], request_timeout=120, max_retries=1) +llm_mini = ChatOpenAI( + model='gpt-4.1-mini', + **_with_llm_callbacks({'request_timeout': 120, 'max_retries': 1}, 'openai', model='gpt-4.1-mini'), +) # --------------------------------------------------------------------------- # Embeddings, parser, utilities @@ -667,6 +732,10 @@ def gemini_embed_query(text: str) -> List[float]: 'taskType': 'RETRIEVAL_QUERY', } headers = {'x-goog-api-key': api_key, 'Content-Type': 'application/json'} - resp = httpx.post(url, json=payload, headers=headers, timeout=10) - resp.raise_for_status() - return resp.json()['embedding']['values'] + try: + resp = httpx.post(url, json=payload, headers=headers, timeout=10) + resp.raise_for_status() + return resp.json()['embedding']['values'] + except Exception as e: + handle_llm_error(e, 'gemini', feature='embeddings', model='embedding-001', operation='embed_query') + raise diff --git a/backend/utils/retrieval/agentic.py b/backend/utils/retrieval/agentic.py index 84c3697ffb4..e49a7eecce4 100644 --- a/backend/utils/retrieval/agentic.py +++ b/backend/utils/retrieval/agentic.py @@ -47,6 +47,7 @@ ) from utils.retrieval.tools.app_tools import load_app_tools, get_tool_status_message from utils.retrieval.safety import AgentSafetyGuard, SafetyGuardError +from utils.llm.byok_errors import handle_llm_error_async from utils.llm.clients import anthropic_client, ANTHROPIC_AGENT_MODEL from utils.llm.chat import _get_agentic_qa_prompt from utils.other.endpoints import timeit @@ -420,7 +421,7 @@ async def _run_anthropic_agent_stream( response = await stream.get_final_message() except Exception as e: - logger.error(f"Anthropic API error: {e}") + await handle_llm_error_async(e, 'anthropic', feature='chat_agent', model=ANTHROPIC_AGENT_MODEL) await callback.put_data(f"\n\nSorry, I encountered an error. Please try again.") await callback.end() return