1 change: 1 addition & 0 deletions requirements.txt
@@ -1 +1,2 @@
lightning_sdk >= 2025.09.16
nest-asyncio
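
The new `nest-asyncio` requirement presumably backs the `nest_asyncio.apply()` call added to `chat` below: it patches an already-running event loop so it can be re-entered, which is what lets `chat` schedule tasks from environments like Jupyter that already own a loop. A minimal standalone sketch of the behavior it enables (not code from this PR):

```python
import asyncio

import nest_asyncio

nest_asyncio.apply()  # patch the loop to allow re-entrant use

async def outer() -> None:
    # Without nest_asyncio this raises "This event loop is already running".
    loop = asyncio.get_event_loop()
    loop.run_until_complete(asyncio.sleep(0))

asyncio.get_event_loop().run_until_complete(outer())
print("re-entrant run_until_complete succeeded")
```
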
241 changes: 199 additions & 42 deletions src/litai/llm.py
@@ -13,21 +13,25 @@
# limitations under the License.
"""LLM client class."""

import asyncio
import datetime
import itertools
import json
import logging
import os
import threading
import warnings
from typing import TYPE_CHECKING, Any, Dict, List, Literal, Optional, Sequence, Union
from asyncio import Task
from typing import TYPE_CHECKING, Any, AsyncIterator, Dict, Iterator, List, Literal, Optional, Sequence, Union

import nest_asyncio
import requests
from lightning_sdk.lightning_cloud.openapi import V1ConversationResponseChunk
from lightning_sdk.llm import LLM as SDKLLM

from litai.tools import LitTool
from litai.utils.supported_public_models import ModelLiteral
from litai.utils.utils import handle_model_error
from litai.utils.utils import handle_empty_response, handle_model_error

if TYPE_CHECKING:
from langchain_core.tools import StructuredTool
@@ -206,7 +210,7 @@ def _format_tool_response(
return LLM.call_tool(result, lit_tools) or ""
return json.dumps(result)

def _model_call(
def _model_call( # noqa: D417
self,
model: SDKLLM,
prompt: str,
@@ -258,6 +262,147 @@ def context_length(self, model: Optional[str] = None) -> int:
return self._llm.get_context_length(self._model)
return self._llm.get_context_length(model)

async def _peek_and_rebuild_async(
self,
agen: AsyncIterator[str],
) -> Optional[AsyncIterator[str]]:
"""Peek into an async iterator to check for non-empty content and rebuild it if necessary."""
peeked_items: List[str] = []
has_content_found = False

async for item in agen:
peeked_items.append(item)
if item != "":
has_content_found = True
break

if has_content_found:

async def rebuilt() -> AsyncIterator[str]:
for peeked_item in peeked_items:
yield peeked_item

async for remaining_item in agen:
yield remaining_item

return rebuilt()

return None
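
`_peek_and_rebuild_async` drains the stream only up to the first non-empty chunk, then returns a fresh iterator that replays the buffered prefix before handing off to the untouched remainder — the caller still sees the full stream, while an all-empty stream maps to `None`. A self-contained sketch of the same pattern (the demo stream is illustrative):

```python
import asyncio
from typing import AsyncIterator, List, Optional

async def peek_and_rebuild(agen: AsyncIterator[str]) -> Optional[AsyncIterator[str]]:
    peeked: List[str] = []
    async for item in agen:
        peeked.append(item)
        if item != "":
            # Found content: replay the buffered prefix, then the remainder.
            async def rebuilt() -> AsyncIterator[str]:
                for p in peeked:
                    yield p
                async for rest in agen:
                    yield rest
            return rebuilt()
    return None  # stream ended without any non-empty chunk

async def demo() -> None:
    async def chunks() -> AsyncIterator[str]:
        for c in ["", "", "Hello", " world"]:  # leading keep-alive chunks
            yield c
    stream = await peek_and_rebuild(chunks())
    assert stream is not None
    print("".join([c async for c in stream]))  # -> Hello world

asyncio.run(demo())
```
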

async def async_chat(
self,
models_to_try: List[SDKLLM],
prompt: str,
system_prompt: Optional[str],
max_tokens: Optional[int],
images: Optional[Union[List[str], str]],
conversation: Optional[str],
metadata: Optional[Dict[str, str]],
stream: bool,
full_response: Optional[bool] = None,
model: Optional[SDKLLM] = None,
tools: Optional[Sequence[Union[str, Dict[str, Any]]]] = None,
lit_tools: Optional[List[LitTool]] = None,
auto_call_tools: bool = False,
reasoning_effort: Optional[str] = None,
**kwargs: Any,
) -> Union[str, AsyncIterator[str], None]:
"""Sends a message to the LLM asynchronously with full retry/fallback logic."""
for sdk_model in models_to_try:
for attempt in range(self.max_retries):
try:
response = await self._model_call( # type: ignore[misc]
model=sdk_model,
prompt=prompt,
system_prompt=system_prompt,
max_completion_tokens=max_tokens,
images=images,
conversation=conversation,
metadata=metadata,
stream=stream,
tools=tools,
lit_tools=lit_tools,
full_response=full_response,
auto_call_tools=auto_call_tools,
reasoning_effort=reasoning_effort,
**kwargs,
)

if not stream and response:
return response
if stream and response:
non_empty_stream = await self._peek_and_rebuild_async(response)
if non_empty_stream:
return non_empty_stream
handle_empty_response(sdk_model, attempt, self.max_retries)
if sdk_model == model:
print(f"💥 Failed to override with model '{model}'")
except Exception as e:
handle_model_error(e, sdk_model, attempt, self.max_retries, self._verbose)
raise RuntimeError(f"💥 [LLM call failed after {self.max_retries} attempts]")

def sync_chat(
self,
models_to_try: List[SDKLLM],
prompt: str,
system_prompt: Optional[str],
max_tokens: Optional[int],
images: Optional[Union[List[str], str]],
conversation: Optional[str],
metadata: Optional[Dict[str, str]],
stream: bool,
model: Optional[SDKLLM] = None,
full_response: Optional[bool] = None,
tools: Optional[Sequence[Union[str, Dict[str, Any]]]] = None,
lit_tools: Optional[List[LitTool]] = None,
auto_call_tools: bool = False,
reasoning_effort: Optional[str] = None,
**kwargs: Any,
) -> Union[str, Iterator[str], None]:
"""Sends a message to the LLM synchronously with full retry/fallback logic."""
for sdk_model in models_to_try:
for attempt in range(self.max_retries):
try:
response = self._model_call(
model=sdk_model,
prompt=prompt,
system_prompt=system_prompt,
max_completion_tokens=max_tokens,
images=images,
conversation=conversation,
metadata=metadata,
stream=stream,
tools=tools,
lit_tools=lit_tools,
full_response=full_response,
auto_call_tools=auto_call_tools,
reasoning_effort=reasoning_effort,
**kwargs,
)

if not stream and response:
return response
if stream:
try:
peek_iter, return_iter = itertools.tee(response)
has_content = False
for chunk in peek_iter:
if chunk != "":
has_content = True
break
if has_content:
return return_iter
except StopIteration:
pass
handle_empty_response(sdk_model, attempt, self.max_retries)

except Exception as e:
if sdk_model == model:
print(f"💥 Failed to override with model '{model}'")
handle_model_error(e, sdk_model, attempt, self.max_retries, self._verbose)

raise RuntimeError(f"💥 [LLM call failed after {self.max_retries} attempts]")

def chat( # noqa: D417
self,
prompt: str,
@@ -272,7 +417,7 @@ def chat( # noqa: D417
auto_call_tools: bool = False,
reasoning_effort: Optional[Literal["none", "low", "medium", "high"]] = None,
**kwargs: Any,
) -> str:
) -> Union[str, Task[Union[str, AsyncIterator[str], None]], Iterator[str], None]:
"""Sends a message to the LLM and retrieves a response.

Args:
@@ -303,57 +448,61 @@ def chat( # noqa: D417
self._wait_for_model()
lit_tools = LitTool.convert_tools(tools)
processed_tools = [tool.as_tool() for tool in lit_tools] if lit_tools else None

models_to_try = []
sdk_model = None
if model:
try:
model_key = f"{model}::{self._teamspace}::{self._enable_async}"
if model_key not in self._sdkllm_cache:
self._sdkllm_cache[model_key] = SDKLLM(
name=model, teamspace=self._teamspace, enable_async=self._enable_async
)
sdk_model = self._sdkllm_cache[model_key]
return self._model_call(
model_key = f"{model}::{self._teamspace}::{self._enable_async}"
if model_key not in self._sdkllm_cache:
self._sdkllm_cache[model_key] = SDKLLM(
name=model, teamspace=self._teamspace, enable_async=self._enable_async
)
sdk_model = self._sdkllm_cache[model_key]
models_to_try.append(sdk_model)
models_to_try.extend(self.models)

if self._enable_async:
nest_asyncio.apply()
nest_asyncio.apply()

loop = asyncio.get_event_loop()
return loop.create_task(
self.async_chat(
models_to_try=models_to_try,
model=sdk_model,
prompt=prompt,
system_prompt=system_prompt,
max_completion_tokens=max_tokens,
max_tokens=max_tokens,
images=images,
conversation=conversation,
metadata=metadata,
stream=stream,
full_response=self._full_response,
tools=processed_tools,
lit_tools=lit_tools,
auto_call_tools=auto_call_tools,
reasoning_effort=reasoning_effort,
**kwargs,
)
except Exception as e:
print(f"💥 Failed to override with model '{model}'")
handle_model_error(e, sdk_model, 0, self.max_retries, self._verbose)
)

# Retry with fallback models
for model in self.models:
for attempt in range(self.max_retries):
try:
return self._model_call(
model=model,
prompt=prompt,
system_prompt=system_prompt,
max_completion_tokens=max_tokens,
images=images,
conversation=conversation,
metadata=metadata,
stream=stream,
tools=processed_tools,
lit_tools=lit_tools,
auto_call_tools=auto_call_tools,
reasoning_effort=reasoning_effort,
**kwargs,
)

except Exception as e:
handle_model_error(e, model, attempt, self.max_retries, self._verbose)

raise RuntimeError(f"💥 [LLM call failed after {self.max_retries} attempts]")
return self.sync_chat(
models_to_try=models_to_try,
model=sdk_model,
prompt=prompt,
system_prompt=system_prompt,
max_tokens=max_tokens,
images=images,
conversation=conversation,
metadata=metadata,
stream=stream,
full_response=self._full_response,
tools=processed_tools,
lit_tools=lit_tools,
auto_call_tools=auto_call_tools,
reasoning_effort=reasoning_effort,
**kwargs,
)
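
With this change `chat` is polymorphic in its return type: an `asyncio.Task` when async is enabled, a string or iterator otherwise. A hedged usage sketch — the import path and constructor arguments are assumptions, not confirmed by this diff:

```python
import asyncio

from litai.llm import LLM  # assumed import path

async def main() -> None:
    llm = LLM()  # constructor arguments depend on your setup
    result = llm.chat("Summarize this PR in one sentence.")
    if isinstance(result, asyncio.Task):
        print(await result)       # async client: chat() scheduled a Task
    elif isinstance(result, str):
        print(result)             # sync client, non-streaming
    elif result is not None:
        for chunk in result:      # sync client, streaming iterator
            print(chunk, end="")

asyncio.run(main())
```
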

@staticmethod
def call_tool(
@@ -491,7 +640,11 @@ def if_(self, input: str, question: str) -> bool:
Answer with only 'yes' or 'no'.
"""

response = self.chat(prompt).strip().lower()
response = self.chat(prompt)
if isinstance(response, str):
response = response.strip().lower()
else:
return False
return "yes" in response

def classify(self, input: str, choices: List[str]) -> str:
@@ -517,7 +670,11 @@ def classify(self, input: str, choices: List[str]) -> str:
Answer with only one of the choices.
""".strip()

response = self.chat(prompt).strip().lower()
response = self.chat(prompt)
if isinstance(response, str):
response = response.strip().lower()
else:
return normalized_choices[0]

if response in normalized_choices:
return response
10 changes: 10 additions & 0 deletions src/litai/utils/utils.py
@@ -182,3 +182,13 @@ def handle_model_error(e: Exception, model: SDKLLM, attempt: int, max_retries: i
print("-" * 50)
print(f"❌ All {max_retries} attempts failed for model {model.name}")
print("-" * 50)


def handle_empty_response(model: SDKLLM, attempt: int, max_retries: int) -> None:
"""Handles empty responses from model calls."""
if attempt < max_retries - 1:
print(f"🔁 Received empty response. Attempt {attempt + 1}/{max_retries} failed. Retrying...")
else:
print("-" * 50)
print(f"❌ All {max_retries} attempts received empty responses for model {model.name}.")
print("-" * 50)