diff --git a/requirements.txt b/requirements.txt
index 7f9490a..21044e2 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -1 +1,2 @@
 lightning_sdk >= 2025.09.16
+nest-asyncio
diff --git a/src/litai/llm.py b/src/litai/llm.py
index 458673e..0784dcd 100644
--- a/src/litai/llm.py
+++ b/src/litai/llm.py
@@ -13,21 +13,25 @@
 # limitations under the License.
 """LLM client class."""
 
+import asyncio
 import datetime
+import itertools
 import json
 import logging
 import os
 import threading
 import warnings
-from typing import TYPE_CHECKING, Any, Dict, List, Literal, Optional, Sequence, Union
+from asyncio import Task
+from typing import TYPE_CHECKING, Any, AsyncIterator, Dict, Iterator, List, Literal, Optional, Sequence, Union
 
+import nest_asyncio
 import requests
 from lightning_sdk.lightning_cloud.openapi import V1ConversationResponseChunk
 from lightning_sdk.llm import LLM as SDKLLM
 
 from litai.tools import LitTool
 from litai.utils.supported_public_models import ModelLiteral
-from litai.utils.utils import handle_model_error
+from litai.utils.utils import handle_empty_response, handle_model_error
 
 if TYPE_CHECKING:
     from langchain_core.tools import StructuredTool
@@ -206,7 +210,7 @@ def _format_tool_response(
             return LLM.call_tool(result, lit_tools) or ""
         return json.dumps(result)
 
-    def _model_call(
+    def _model_call(  # noqa: D417
         self,
         model: SDKLLM,
         prompt: str,
@@ -258,6 +262,147 @@ def context_length(self, model: Optional[str] = None) -> int:
             return self._llm.get_context_length(self._model)
         return self._llm.get_context_length(model)
 
+    async def _peek_and_rebuild_async(
+        self,
+        agen: AsyncIterator[str],
+    ) -> Optional[AsyncIterator[str]]:
+        """Peek into an async iterator to check for non-empty content and rebuild it if necessary."""
+        peeked_items: List[str] = []
+        has_content_found = False
+
+        async for item in agen:
+            peeked_items.append(item)
+            if item != "":
+                has_content_found = True
+                break
+
+        if has_content_found:
+
+            async def rebuilt() -> AsyncIterator[str]:
+                for peeked_item in peeked_items:
+                    yield peeked_item
+
+                async for remaining_item in agen:
+                    yield remaining_item
+
+            return rebuilt()
+
+        return None
+
+    async def async_chat(
+        self,
+        models_to_try: List[SDKLLM],
+        prompt: str,
+        system_prompt: Optional[str],
+        max_tokens: Optional[int],
+        images: Optional[Union[List[str], str]],
+        conversation: Optional[str],
+        metadata: Optional[Dict[str, str]],
+        stream: bool,
+        full_response: Optional[bool] = None,
+        model: Optional[SDKLLM] = None,
+        tools: Optional[Sequence[Union[str, Dict[str, Any]]]] = None,
+        lit_tools: Optional[List[LitTool]] = None,
+        auto_call_tools: bool = False,
+        reasoning_effort: Optional[str] = None,
+        **kwargs: Any,
+    ) -> Union[str, AsyncIterator[str], None]:
+        """Sends a message to the LLM asynchronously with full retry/fallback logic."""
+        for sdk_model in models_to_try:
+            for attempt in range(self.max_retries):
+                try:
+                    response = await self._model_call(  # type: ignore[misc]
+                        model=sdk_model,
+                        prompt=prompt,
+                        system_prompt=system_prompt,
+                        max_completion_tokens=max_tokens,
+                        images=images,
+                        conversation=conversation,
+                        metadata=metadata,
+                        stream=stream,
+                        tools=tools,
+                        lit_tools=lit_tools,
+                        full_response=full_response,
+                        auto_call_tools=auto_call_tools,
+                        reasoning_effort=reasoning_effort,
+                        **kwargs,
+                    )
+
+                    if not stream and response:
+                        return response
+                    if stream and response:
+                        non_empty_stream = await self._peek_and_rebuild_async(response)
+                        if non_empty_stream:
+                            return non_empty_stream
+                    handle_empty_response(sdk_model, attempt, self.max_retries)
+                    if sdk_model == model:
+                        print(f"💥 Failed to override with model '{model}'")
+                except Exception as e:
+                    handle_model_error(e, sdk_model, attempt, self.max_retries, self._verbose)
+        raise RuntimeError(f"💥 [LLM call failed after {self.max_retries} attempts]")
+
+    def sync_chat(
+        self,
+        models_to_try: List[SDKLLM],
+        prompt: str,
+        system_prompt: Optional[str],
+        max_tokens: Optional[int],
+        images: Optional[Union[List[str], str]],
+        conversation: Optional[str],
+        metadata: Optional[Dict[str, str]],
+        stream: bool,
+        model: Optional[SDKLLM] = None,
+        full_response: Optional[bool] = None,
+        tools: Optional[Sequence[Union[str, Dict[str, Any]]]] = None,
+        lit_tools: Optional[List[LitTool]] = None,
+        auto_call_tools: bool = False,
+        reasoning_effort: Optional[str] = None,
+        **kwargs: Any,
+    ) -> Union[str, Iterator[str], None]:
+        """Sends a message to the LLM synchronously with full retry/fallback logic."""
+        for sdk_model in models_to_try:
+            for attempt in range(self.max_retries):
+                try:
+                    response = self._model_call(
+                        model=sdk_model,
+                        prompt=prompt,
+                        system_prompt=system_prompt,
+                        max_completion_tokens=max_tokens,
+                        images=images,
+                        conversation=conversation,
+                        metadata=metadata,
+                        stream=stream,
+                        tools=tools,
+                        lit_tools=lit_tools,
+                        full_response=full_response,
+                        auto_call_tools=auto_call_tools,
+                        reasoning_effort=reasoning_effort,
+                        **kwargs,
+                    )
+
+                    if not stream and response:
+                        return response
+                    if stream:
+                        try:
+                            peek_iter, return_iter = itertools.tee(response)
+                            has_content = False
+                            for chunk in peek_iter:
+                                if chunk != "":
+                                    has_content = True
+                                    break
+                            if has_content:
+                                return return_iter
+                        except StopIteration:
+                            pass
+                    handle_empty_response(sdk_model, attempt, self.max_retries)
+
+                except Exception as e:
+                    if sdk_model == model:
+                        print(f"💥 Failed to override with model '{model}'")
+                    handle_model_error(e, sdk_model, attempt, self.max_retries, self._verbose)
+
+        raise RuntimeError(f"💥 [LLM call failed after {self.max_retries} attempts]")
+
     def chat(  # noqa: D417
         self,
         prompt: str,
@@ -272,7 +417,7 @@
         auto_call_tools: bool = False,
         reasoning_effort: Optional[Literal["none", "low", "medium", "high"]] = None,
         **kwargs: Any,
-    ) -> str:
+    ) -> Union[str, Task[Union[str, AsyncIterator[str], None]], Iterator[str], None]:
         """Sends a message to the LLM and retrieves a response.
 
         Args:
@@ -303,57 +448,61 @@
         self._wait_for_model()
         lit_tools = LitTool.convert_tools(tools)
         processed_tools = [tool.as_tool() for tool in lit_tools] if lit_tools else None
+
+        models_to_try = []
+        sdk_model = None
         if model:
-            try:
-                model_key = f"{model}::{self._teamspace}::{self._enable_async}"
-                if model_key not in self._sdkllm_cache:
-                    self._sdkllm_cache[model_key] = SDKLLM(
-                        name=model, teamspace=self._teamspace, enable_async=self._enable_async
-                    )
-                sdk_model = self._sdkllm_cache[model_key]
-                return self._model_call(
+            model_key = f"{model}::{self._teamspace}::{self._enable_async}"
+            if model_key not in self._sdkllm_cache:
+                self._sdkllm_cache[model_key] = SDKLLM(
+                    name=model, teamspace=self._teamspace, enable_async=self._enable_async
+                )
+            sdk_model = self._sdkllm_cache[model_key]
+            models_to_try.append(sdk_model)
+        models_to_try.extend(self.models)
+
+        if self._enable_async:
+            nest_asyncio.apply()
+            nest_asyncio.apply()
+
+            loop = asyncio.get_event_loop()
+            return loop.create_task(
+                self.async_chat(
+                    models_to_try=models_to_try,
+                    model=sdk_model,
                     prompt=prompt,
                     system_prompt=system_prompt,
-                    max_completion_tokens=max_tokens,
+                    max_tokens=max_tokens,
                     images=images,
                     conversation=conversation,
                     metadata=metadata,
                     stream=stream,
+                    full_response=self._full_response,
                     tools=processed_tools,
                     lit_tools=lit_tools,
                     auto_call_tools=auto_call_tools,
                     reasoning_effort=reasoning_effort,
                     **kwargs,
                 )
-            except Exception as e:
-                print(f"💥 Failed to override with model '{model}'")
-                handle_model_error(e, sdk_model, 0, self.max_retries, self._verbose)
+            )
 
-        # Retry with fallback models
-        for model in self.models:
-            for attempt in range(self.max_retries):
-                try:
-                    return self._model_call(
-                        model=model,
-                        prompt=prompt,
-                        system_prompt=system_prompt,
-                        max_completion_tokens=max_tokens,
-                        images=images,
-                        conversation=conversation,
-                        metadata=metadata,
-                        stream=stream,
-                        tools=processed_tools,
-                        lit_tools=lit_tools,
-                        auto_call_tools=auto_call_tools,
-                        reasoning_effort=reasoning_effort,
-                        **kwargs,
-                    )
-
-                except Exception as e:
-                    handle_model_error(e, model, attempt, self.max_retries, self._verbose)
-
-        raise RuntimeError(f"💥 [LLM call failed after {self.max_retries} attempts]")
+        return self.sync_chat(
+            models_to_try=models_to_try,
+            model=sdk_model,
+            prompt=prompt,
+            system_prompt=system_prompt,
+            max_tokens=max_tokens,
+            images=images,
+            conversation=conversation,
+            metadata=metadata,
+            stream=stream,
+            full_response=self._full_response,
+            tools=processed_tools,
+            lit_tools=lit_tools,
+            auto_call_tools=auto_call_tools,
+            reasoning_effort=reasoning_effort,
+            **kwargs,
+        )
 
     @staticmethod
     def call_tool(
@@ -491,7 +640,11 @@ def if_(self, input: str, question: str) -> bool:
 
         Answer with only 'yes' or 'no'.
         """
-        response = self.chat(prompt).strip().lower()
+        response = self.chat(prompt)
+        if isinstance(response, str):
+            response = response.strip().lower()
+        else:
+            return False
         return "yes" in response
 
     def classify(self, input: str, choices: List[str]) -> str:
@@ -517,7 +670,11 @@ def classify(self, input: str, choices: List[str]) -> str:
 
         Answer with only one of the choices.
""".strip() - response = self.chat(prompt).strip().lower() + response = self.chat(prompt) + if isinstance(response, str): + response = response.strip().lower() + else: + return normalized_choices[0] if response in normalized_choices: return response diff --git a/src/litai/utils/utils.py b/src/litai/utils/utils.py index 06761e4..252c836 100644 --- a/src/litai/utils/utils.py +++ b/src/litai/utils/utils.py @@ -182,3 +182,13 @@ def handle_model_error(e: Exception, model: SDKLLM, attempt: int, max_retries: i print("-" * 50) print(f"❌ All {max_retries} attempts failed for model {model.name}") print("-" * 50) + + +def handle_empty_response(model: SDKLLM, attempt: int, max_retries: int) -> None: + """Handles empty responses from model calls.""" + if attempt < max_retries - 1: + print(f"🔁 Received empty response. Attempt {attempt + 1}/{max_retries} failed. Retrying...") + else: + print("-" * 50) + print(f"❌ All {max_retries} attempts received empty responses for model {model.name}.") + print("-" * 50) diff --git a/tests/test_llm.py b/tests/test_llm.py index b93ac50..233aa89 100644 --- a/tests/test_llm.py +++ b/tests/test_llm.py @@ -247,6 +247,175 @@ def mock_llm_constructor(name, teamspace="default-teamspace", **kwargs): ) +def test_empty_response_retries(monkeypatch): + """Test fallback model logic when main model fails.""" + from litai.llm import LLM as LLMCLIENT + + LLMCLIENT._sdkllm_cache.clear() + mock_main_model = MagicMock() + mock_main_model.name = "main-model" + mock_fallback_model = MagicMock() + mock_fallback_model.name = "fallback-model" + + mock_main_model.chat.side_effect = "" + mock_fallback_model.chat.side_effect = [ + "", + "", + "Fallback response", + ] + + def mock_llm_constructor(name, teamspace="default-teamspace", **kwargs): + if name == "main-model": + return mock_main_model + if name == "fallback-model": + return mock_fallback_model + raise ValueError(f"Unknown model: {name}") + + monkeypatch.setattr("litai.llm.SDKLLM", mock_llm_constructor) + + llm = LLM( + model="main-model", + fallback_models=["fallback-model"], + ) + + response = llm.chat(prompt="Hello") + assert response == "Fallback response" + + assert mock_main_model.chat.call_count == 3 + assert mock_fallback_model.chat.call_count == 3 + + mock_fallback_model.chat.assert_called_with( + prompt="Hello", + system_prompt=None, + max_completion_tokens=None, + images=None, + conversation=None, + metadata=None, + stream=False, + full_response=False, + tools=None, + reasoning_effort=None, + ) + + +def test_empty_response_retries_sync_stream(monkeypatch): + """Test that retries work correctly for sync streaming when empty responses are returned.""" + from litai.llm import LLM as LLMCLIENT + + LLMCLIENT._sdkllm_cache.clear() + + class MockSyncIterator: + def __init__(self, items): + self.items = items + self.index = 0 + + def __iter__(self): + return self + + def __next__(self): + if self.index < len(self.items): + item = self.items[self.index] + self.index += 1 + return item + raise StopIteration + + mock_responses = [ + MockSyncIterator([]), + MockSyncIterator([]), + MockSyncIterator(["hello", " world"]), + ] + + mock_main_model = MagicMock() + + def mock_llm_constructor(name, teamspace="default-teamspace", **kwargs): + if name == "main-model": + mock_main_model.chat.side_effect = mock_responses + mock_main_model.name = "main-model" + return mock_main_model + raise ValueError(f"Unknown model: {name}") + + monkeypatch.setattr("litai.llm.SDKLLM", mock_llm_constructor) + + llm = LLM( + model="main-model", + ) + + response = 
llm.chat("test prompt", stream=True) + + assert mock_main_model.chat.call_count == 3 + + result = "" + for chunk in response: + result += chunk + assert result == "hello world" + + +@pytest.mark.asyncio +async def test_empty_response_retries_async(monkeypatch): + """Test that retries work correctly for async and non streaming when empty responses are returned.""" + from litai.llm import LLM as LLMCLIENT + + LLMCLIENT._sdkllm_cache.clear() + mock_sdkllm = MagicMock() + mock_sdkllm.name = "mock-model" + + mock_sdkllm.chat = AsyncMock(side_effect=["", "", "Main response"]) + + monkeypatch.setattr("litai.llm.SDKLLM", lambda *args, **kwargs: mock_sdkllm) + + llm = LLM( + model="main-model", + enable_async=True, + ) + response = await llm.chat(prompt="Hello", stream=False) + + assert response == "Main response" + assert mock_sdkllm.chat.call_count == 3 + + +@pytest.mark.asyncio +async def test_empty_response_retries_async_stream(monkeypatch): + """Test that retries work correctly for async streaming when empty responses are returned.""" + from litai.llm import LLM as LLMCLIENT + + LLMCLIENT._sdkllm_cache.clear() + mock_sdkllm = MagicMock() + mock_sdkllm.name = "mock-model" + + class MockAsyncIterator: + def __init__(self, items): + self.items = items + self.index = 0 + + def __aiter__(self): + return self + + async def __anext__(self): + if self.index < len(self.items): + item = self.items[self.index] + self.index += 1 + return item + raise StopAsyncIteration + + mock_sdkllm.chat = AsyncMock( + side_effect=[MockAsyncIterator([]), MockAsyncIterator([]), MockAsyncIterator(["Main", " response"])] + ) + + monkeypatch.setattr("litai.llm.SDKLLM", lambda *args, **kwargs: mock_sdkllm) + + llm = LLM( + model="main-model", + enable_async=True, + ) + + response = await llm.chat(prompt="Hello", stream=True) + result = "" + async for chunk in response: + result += chunk + assert result == "Main response" + assert mock_sdkllm.chat.call_count == 3 + + @pytest.mark.asyncio async def test_llm_async_chat(monkeypatch): """Test async requests."""