Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions _requirements/test.txt
Original file line number Diff line number Diff line change
Expand Up @@ -5,3 +5,5 @@ mypy==1.16.1
psutil>=7.0.0
pytest-asyncio>=1.1.0
openai>=1.97.1
langchain>=0.3.27
langchain-core>=0.3.72
2 changes: 2 additions & 0 deletions src/litai/llm.py
Original file line number Diff line number Diff line change
Expand Up @@ -298,6 +298,7 @@ def chat( # noqa: D417
str: The response from the LLM.
"""
self._wait_for_model()
tools = LitTool.convert_tools(tools)
tool_schema = [tool.as_tool() for tool in tools] if tools else None
if tool_schema:
tool_context = (
Expand Down Expand Up @@ -359,6 +360,7 @@ def call_tool(response: str, tools: Optional[List[LitTool]] = None) -> Optional[
parsed = json.loads(response)
tool_name = parsed["tool"]
tool_args = parsed["parameters"]
tools = LitTool.convert_tools(tools)
for tool in tools:
if tool.name == tool_name:
return tool.run(**tool_args)
Expand Down
52 changes: 51 additions & 1 deletion src/litai/tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,10 +15,13 @@

import json
from inspect import Signature, signature
from typing import Any, Callable, Dict, Optional, Union
from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Union

from pydantic import BaseModel, ConfigDict, Field

if TYPE_CHECKING:
from langchain_core.tools.structured import StructuredTool


class LitTool(BaseModel):
"""A tool is a function that can be used to interact with the world."""
Expand Down Expand Up @@ -79,6 +82,53 @@ def as_tool(self, json_mode: bool = False) -> Union[str, Dict[str, Any]]:
"parameters": self._extract_parameters(),
}

@classmethod
def from_langchain(cls, tool: "StructuredTool") -> "LitTool":
"""Convert a LangChain StructuredTool to a LitTool."""

class LangchainTool(LitTool):
def setup(self) -> None:
super().setup()
self.name: str = tool.name
self.description: str = tool.description
self._tool = tool

def run(self, *args: Any, **kwargs: Any) -> Any:
return self._tool.func(*args, **kwargs) # type: ignore

def _extract_parameters(self) -> Dict[str, Any]:
return self._tool.args_schema.model_json_schema() # type: ignore

return LangchainTool()

@classmethod
def convert_tools(cls, tools: Optional[List[Any]]) -> List["LitTool"]:
"""Convert a list of tools into LitTool instances.

- Passes through LitTool instances.
- Wraps LangChain StructuredTool objects.
- Raises TypeError for unsupported types.
"""
if tools is None:
return []
if len(tools) == 0:
return []

lit_tools = []

for tool in tools:
if isinstance(tool, LitTool):
lit_tools.append(tool)

# LangChain StructuredTool - check by type name and module
elif type(tool).__name__ == "StructuredTool" and type(tool).__module__ == "langchain_core.tools.structured":
lit_tools.append(cls.from_langchain(tool))

else:
raise TypeError(f"Unsupported tool type: {type(tool)}")

return lit_tools


def tool(func: Optional[Callable] = None) -> Union[LitTool, Callable]:
"""Decorator to convert a function into a LitTool instance.
Expand Down
18 changes: 18 additions & 0 deletions tests/test_llm.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
from unittest.mock import AsyncMock, MagicMock, patch

import pytest
from langchain_core.tools import tool as langchain_tool

from litai import LLM, tool

Expand Down Expand Up @@ -410,3 +411,20 @@ def test_dump_debug(mock_makedirs, mock_open):
assert "Test response" in written_content
assert "📛 Exception:" in written_content
assert "Test exception" in written_content


@patch("litai.llm.SDKLLM")
def test_call_langchain_tools(mock_sdkllm):
@langchain_tool
def get_weather(city: str) -> str:
"""Get the weather of a given city."""
return f"Weather in {city} is sunny."

llm = LLM()
with patch.object(
llm,
"chat",
return_value=json.dumps({"type": "function_call", "tool": "get_weather", "parameters": {"city": "London"}}),
):
result = llm.chat("how is the weather in London?", tools=[get_weather])
assert llm.call_tool(result, tools=[get_weather]) == "Weather in London is sunny."
32 changes: 32 additions & 0 deletions tests/test_tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
"""Unit tests for tools module."""

import pytest
from langchain_core.tools import tool as langchain_tool

from litai import LitTool, tool

Expand Down Expand Up @@ -182,3 +183,34 @@ def setup(self) -> None:
assert tool_instance.state == 1, "State initialized with 1"
tool_instance.state += 1
assert tool_instance.state == 2, "State not incremented. Should be 2"


def test_from_langchain():
@langchain_tool
def get_weather(city: str) -> str:
"""Get the weather of a given city."""
return f"Weather in {city} is sunny."

lit_tool = LitTool.from_langchain(get_weather)
assert isinstance(lit_tool, LitTool)
assert lit_tool.name == "get_weather"
assert lit_tool.description == "Get the weather of a given city."
assert lit_tool.as_tool() == {
"name": "get_weather",
"description": "Get the weather of a given city.",
"parameters": get_weather.args_schema.model_json_schema(),
}


def test_convert_tools_empty():
lit_tools = LitTool.convert_tools([])
assert len(lit_tools) == 0


def test_convert_tools_unsupported_type():
def get_weather(city: str) -> str:
"""Get the weather of a given city."""
return f"Weather in {city} is sunny."

with pytest.raises(TypeError, match="Unsupported tool type: <class 'function'>"):
LitTool.convert_tools([get_weather])
Loading