diff --git a/src/ansari/agents/ansari.py b/src/ansari/agents/ansari.py
index c4be15a..d56b0d7 100644
--- a/src/ansari/agents/ansari.py
+++ b/src/ansari/agents/ansari.py
@@ -12,6 +12,7 @@
 import traceback
 
 import litellm
+import tiktoken
 
 from ansari.ansari_db import MessageLogger
 from ansari.ansari_logger import get_logger
@@ -45,6 +46,11 @@ def __init__(self, settings: Settings, message_logger: MessageLogger = None, jso
         self.message_history = [{"role": "system", "content": self.sys_msg}]
         self.model = settings.MODEL
+        try:
+            self.encoder = tiktoken.encoding_for_model(self.model)
+        except Exception:
+            logger.warning(f"Could not get encoding for model {self.model}, falling back to cl100k_base")
+            self.encoder = tiktoken.get_encoding("cl100k_base")
 
     def _initialize_tools(self):
         """Initialize tool instances. Can be overridden by subclasses."""
@@ -115,6 +121,16 @@ def get_completion(self, **kwargs):
         return litellm.completion(**kwargs)
 
     def process_message_history(self, use_tool=True):
+        if self._check_token_limit():
+            msg = "The conversation has become too long. Please start a new conversation."
+            # We deliberately don't append this notice to the history, to avoid
+            # growing it further. The UI displays whatever we yield, so the user
+            # still sees it. The next request would still carry the oversized
+            # history, but the user is instructed to start a NEW conversation,
+            # which resets the history.
+            yield msg
+            return
+
         common_params = {
             "model": self.model,
             "messages": self.message_history,
@@ -452,3 +468,22 @@ def _log_truncated_message_history(self, message_history, count: int, failures:
             + f"\n{trunc_msg_hist}\n"
             + "-" * 60,
         )
+
+    def _check_token_limit(self):
+        try:
+            num_tokens = 0
+            for message in self.message_history:
+                # Approximate token count: encode string content directly; for
+                # list-style content (multimodal blocks), count only the "text" parts.
+                content = message.get("content", "")
+                if isinstance(content, str):
+                    num_tokens += len(self.encoder.encode(content))
+                elif isinstance(content, list):
+                    for item in content:
+                        if isinstance(item, dict) and "text" in item:
+                            num_tokens += len(self.encoder.encode(item["text"]))
+
+            logger.debug(f"Current token count: {num_tokens}")
+            return num_tokens > self.settings.MAX_TOKEN_LIMIT
+        except Exception as e:
+            logger.warning(f"Error checking token limit: {e}")
+            return False
diff --git a/src/ansari/app/main_api.py b/src/ansari/app/main_api.py
index c9629b3..7ca4381 100644
--- a/src/ansari/app/main_api.py
+++ b/src/ansari/app/main_api.py
@@ -38,7 +38,7 @@
 from ansari.agents.ansari_workflow import AnsariWorkflow
 from ansari.ansari_db import AnsariDB, MessageLogger, SourceType
 from ansari.ansari_logger import get_logger
-from ansari.routers.whatsapp_router import router as whatsapp_router
+
 from ansari.config import Settings, get_settings
 from ansari.presenters.api_presenter import ApiPresenter
 from ansari.util.general_helpers import CORSMiddlewareWithLogging, get_extended_origins, register_to_mailing_list
@@ -89,7 +89,8 @@ async def lifespan(app: FastAPI):
 app = FastAPI(lifespan=lifespan)
 
 # Include the WhatsApp API router
-app.include_router(whatsapp_router)
+
+
 # Custom exception handler, which aims to log FastAPI-related exceptions before raising them
@@ -147,8 +148,14 @@ def add_app_middleware():
 presenter = ApiPresenter(app, ansari)
+
 presenter.present()
+
+# Include the WhatsApp API router
+# NOTE: We import it here to avoid circular imports, as whatsapp_router imports db and presenter from this file
+from ansari.routers.whatsapp_router import router as whatsapp_router
+app.include_router(whatsapp_router)
+
 cache = FanoutCache(get_settings().diskcache_dir, shards=4, timeout=1)
diff --git a/src/ansari/config.py b/src/ansari/config.py
index d91ce58..e350771 100644
--- a/src/ansari/config.py
+++ b/src/ansari/config.py
@@ -75,6 +75,7 @@ def get_resource_path(filename):
     VECTARA_API_KEY: SecretStr
 
     MAX_FAILURES: int = Field(default=3)
+    MAX_TOKEN_LIMIT: int = Field(default=100000)
 
     MAWSUAH_VECTARA_CORPUS_KEY: str = Field(
         alias="MAWSUAH_VECTARA_CORPUS_KEY",
diff --git a/tests/unit/test_token_limit.py b/tests/unit/test_token_limit.py
new file mode 100644
index 0000000..e31d4df
--- /dev/null
+++ b/tests/unit/test_token_limit.py
@@ -0,0 +1,85 @@
+import pytest
+from unittest.mock import MagicMock, patch
+
+from src.ansari.agents.ansari import Ansari
+
+
+@pytest.fixture
+def mock_settings():
+    settings = MagicMock()
+    settings.MODEL = "gpt-4o"
+    settings.MAX_TOKEN_LIMIT = 10
+    settings.PROMPT_PATH = "/test/prompts"
+    settings.SYSTEM_PROMPT_FILE_NAME = "system_msg_default"
+    settings.KALEMAT_API_KEY.get_secret_value.return_value = "key"
+    settings.VECTARA_API_KEY.get_secret_value.return_value = "key"
+    settings.USUL_API_TOKEN.get_secret_value.return_value = "key"
+    settings.MAWSUAH_VECTARA_CORPUS_KEY = "corpus"
+    return settings
+
+
+def test_token_limit_exceeded(mock_settings):
+    with patch("src.ansari.agents.ansari.PromptMgr") as mock_prompt_mgr, \
+         patch("src.ansari.agents.ansari.SearchQuran"), \
+         patch("src.ansari.agents.ansari.SearchHadith"), \
+         patch("src.ansari.agents.ansari.SearchMawsuah"), \
+         patch("src.ansari.agents.ansari.SearchTafsirEncyc"), \
+         patch("tiktoken.encoding_for_model") as mock_encoding:
+
+        # Mock encoding to return a length > 10
+        mock_encoder = MagicMock()
+        mock_encoder.encode.return_value = [1] * 20  # 20 tokens
+        mock_encoding.return_value = mock_encoder
+
+        mock_prompt = MagicMock()
+        mock_prompt.render.return_value = "sys"
+        mock_prompt_mgr.return_value.bind.return_value = mock_prompt
+
+        ansari = Ansari(mock_settings)
+
+        # Add a message
+        ansari.message_history = [
+            {"role": "system", "content": "sys"},
+            {"role": "user", "content": "This message is long enough."},
+        ]
+
+        # Process
+        response_gen = ansari.process_message_history()
+        response = next(response_gen)
+
+        assert "The conversation has become too long" in response
+
+
+def test_token_limit_not_exceeded(mock_settings):
+    with patch("src.ansari.agents.ansari.PromptMgr") as mock_prompt_mgr, \
+         patch("src.ansari.agents.ansari.SearchQuran"), \
+         patch("src.ansari.agents.ansari.SearchHadith"), \
+         patch("src.ansari.agents.ansari.SearchMawsuah"), \
+         patch("src.ansari.agents.ansari.SearchTafsirEncyc"), \
+         patch("tiktoken.encoding_for_model") as mock_encoding, \
+         patch("src.ansari.agents.ansari.litellm.completion") as mock_completion:
+
+        # Mock encoding to return a length < 10
+        mock_encoder = MagicMock()
+        mock_encoder.encode.return_value = [1] * 5  # 5 tokens
+        mock_encoding.return_value = mock_encoder
+
+        mock_prompt = MagicMock()
+        mock_prompt.render.return_value = "sys"
+        mock_prompt_mgr.return_value.bind.return_value = mock_prompt
+
+        # Mock completion response
+        mock_chunk = MagicMock()
+        mock_chunk.choices[0].delta.content = "Response"
+        mock_chunk.choices[0].delta.tool_calls = None
+        mock_completion.return_value = [mock_chunk]
+
+        ansari = Ansari(mock_settings)
+
+        # Add a short message
+        ansari.message_history = [
+            {"role": "system", "content": "sys"},
+            {"role": "user", "content": "Short."},
+        ]
+
+        # Process
+        response_gen = ansari.process_message_history()
+        response = next(response_gen)
+
+        assert response == "Response"
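
A note on the counting approach: _check_token_limit counts only message content (and, for list-style content, only the "text" parts), not role names, tool-call arguments, or the per-message overhead that chat APIs add, so it undercounts slightly; MAX_TOKEN_LIMIT is therefore best kept comfortably below the model's real context window. A minimal standalone sketch of the same logic, runnable on its own (the sample history below is illustrative):

    import tiktoken

    def count_history_tokens(message_history, model="gpt-4o"):
        # Mirror the constructor's fallback for models unknown to tiktoken.
        try:
            encoder = tiktoken.encoding_for_model(model)
        except KeyError:
            encoder = tiktoken.get_encoding("cl100k_base")

        num_tokens = 0
        for message in message_history:
            content = message.get("content", "")
            if isinstance(content, str):
                num_tokens += len(encoder.encode(content))
            elif isinstance(content, list):
                # Multimodal/tool messages store text in a list of blocks.
                for item in content:
                    if isinstance(item, dict) and "text" in item:
                        num_tokens += len(encoder.encode(item["text"]))
        return num_tokens

    history = [
        {"role": "system", "content": "You are a helpful assistant."},
        {"role": "user", "content": "How long is this conversation so far?"},
    ]
    print(count_history_tokens(history))  # approximate total, excludes API overhead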
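Since MAX_TOKEN_LIMIT is declared as a plain pydantic field on Settings, it should be overridable per deployment the same way as the other settings in config.py, presumably through the environment or the .env file. A sketch under that assumption (the value 50000 is illustrative, and the other required settings such as VECTARA_API_KEY must still be available to Settings):

    import os

    # Must be set before Settings is instantiated.
    os.environ["MAX_TOKEN_LIMIT"] = "50000"

    from ansari.config import Settings

    settings = Settings()  # remaining required fields still come from the env/.env
    print(settings.MAX_TOKEN_LIMIT)  # -> 50000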
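From a caller's point of view, the guard short-circuits the generator: once the history is over the limit, process_message_history yields the notice and returns without ever calling the model. A sketch of that behavior (assumes a fully configured environment so Ansari can be constructed):

    from ansari.agents.ansari import Ansari
    from ansari.config import get_settings

    agent = Ansari(get_settings())
    agent.message_history.append(
        {"role": "user", "content": "..."}  # imagine an oversized history here
    )

    for chunk in agent.process_message_history():
        # Over the limit, the only chunk is the "conversation has become too
        # long" notice; otherwise chunks stream from the model as usual.
        print(chunk)

The new unit tests exercise both branches with the tools, prompt manager, tokenizer, and litellm all mocked, so they should run without any credentials, e.g. with `python -m pytest tests/unit/test_token_limit.py`.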