35 changes: 35 additions & 0 deletions src/ansari/agents/ansari.py
@@ -12,6 +12,7 @@
 import traceback
 
 import litellm
+import tiktoken
 
 from ansari.ansari_db import MessageLogger
 from ansari.ansari_logger import get_logger
@@ -45,6 +46,11 @@ def __init__(self, settings: Settings, message_logger: MessageLogger = None, jso
         self.message_history = [{"role": "system", "content": self.sys_msg}]
 
         self.model = settings.MODEL
+        try:
+            self.encoder = tiktoken.encoding_for_model(self.model)
+        except Exception:
+            logger.warning(f"Could not get encoding for model {self.model}, falling back to cl100k_base")
+            self.encoder = tiktoken.get_encoding("cl100k_base")
 
     def _initialize_tools(self):
         """Initialize tool instances. Can be overridden by subclasses."""
@@ -115,6 +121,16 @@ def get_completion(self, **kwargs):
         return litellm.completion(**kwargs)
 
     def process_message_history(self, use_tool=True):
+        if self._check_token_limit():
+            msg = "The conversation has become too long. Please start a new conversation."
+            # Yield the notice without appending it to message_history: appending
+            # would only grow the already-oversized history, and the UI renders
+            # whatever we yield, so the user still sees this message. The next
+            # request would still carry the long history, but the user is
+            # instructed to start a NEW conversation, which resets it.
+            yield msg
+            return
+
         common_params = {
             "model": self.model,
             "messages": self.message_history,
@@ -452,3 +468,22 @@ def _log_truncated_message_history(self, message_history, count: int, failures:
             + f"\n{trunc_msg_hist}\n"
             + "-" * 60,
         )
+
+    def _check_token_limit(self):
+        try:
+            num_tokens = 0
+            for message in self.message_history:
+                # Approximate token count: only string content and text blocks are tallied
+                content = message.get("content", "")
+                if isinstance(content, str):
+                    num_tokens += len(self.encoder.encode(content))
+                elif isinstance(content, list):
+                    for item in content:
+                        if isinstance(item, dict) and "text" in item:
+                            num_tokens += len(self.encoder.encode(item["text"]))
+
+            logger.debug(f"Current token count: {num_tokens}")
+            return num_tokens > self.settings.MAX_TOKEN_LIMIT
+        except Exception as e:
+            logger.warning(f"Error checking token limit: {e}")
+            return False
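
For reviewers: a minimal standalone sketch of what `_check_token_limit` counts, runnable outside the agent (assumes `tiktoken` is installed; the two-message history is made up for illustration):

```python
import tiktoken

# Same fallback as __init__ above: unknown model names raise KeyError.
try:
    encoder = tiktoken.encoding_for_model("gpt-4o")
except KeyError:
    encoder = tiktoken.get_encoding("cl100k_base")

# Made-up history covering both content shapes the check handles:
# a plain string and a list of {"text": ...} blocks.
history = [
    {"role": "system", "content": "You are a helpful assistant."},
    {"role": "user", "content": [{"type": "text", "text": "Assalamu alaykum"}]},
]

num_tokens = 0
for message in history:
    content = message.get("content", "")
    if isinstance(content, str):
        num_tokens += len(encoder.encode(content))
    elif isinstance(content, list):
        for item in content:
            if isinstance(item, dict) and "text" in item:
                num_tokens += len(encoder.encode(item["text"]))

print(num_tokens)  # compared against MAX_TOKEN_LIMIT (default 100000) in the agent
```

Note that this deliberately undercounts: per-message formatting overhead and tool-call arguments are not tallied, which is acceptable for a coarse conversation-length cutoff.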
11 changes: 9 additions & 2 deletions src/ansari/app/main_api.py
@@ -38,7 +38,7 @@
 from ansari.agents.ansari_workflow import AnsariWorkflow
 from ansari.ansari_db import AnsariDB, MessageLogger, SourceType
 from ansari.ansari_logger import get_logger
-from ansari.routers.whatsapp_router import router as whatsapp_router
+
 from ansari.config import Settings, get_settings
 from ansari.presenters.api_presenter import ApiPresenter
 from ansari.util.general_helpers import CORSMiddlewareWithLogging, get_extended_origins, register_to_mailing_list
@@ -89,7 +89,8 @@ async def lifespan(app: FastAPI):
 app = FastAPI(lifespan=lifespan)
 
 # Include the WhatsApp API router
-app.include_router(whatsapp_router)
+
+
 
 
 # Custom exception handler, which aims to log FastAPI-related exceptions before raising them
@@ -147,8 +148,14 @@ def add_app_middleware():
 
 
 presenter = ApiPresenter(app, ansari)
+
 presenter.present()
 
+# Include the WhatsApp API router
+# NOTE: We import it here to avoid circular imports, as whatsapp_router imports db and presenter from this file
+from ansari.routers.whatsapp_router import router as whatsapp_router
+app.include_router(whatsapp_router)
+
 cache = FanoutCache(get_settings().diskcache_dir, shards=4, timeout=1)
 
 
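The NOTE above is the crux of this file's change. A runnable toy reconstruction of the cycle it breaks (`toy_api`/`toy_router` are illustrative stand-ins for `main_api` and `whatsapp_router`, not the real modules):

```python
import os
import sys
import tempfile
import textwrap

tmp = tempfile.mkdtemp()
sys.path.insert(0, tmp)

# toy_router imports toy_api back, just as whatsapp_router imports db and
# presenter from main_api.
with open(os.path.join(tmp, "toy_router.py"), "w") as f:
    f.write(textwrap.dedent("""
        import toy_api
        router = f"router wired to {toy_api.presenter}"
    """))

# toy_api binds presenter BEFORE importing the router, mirroring the
# relocated import in main_api.py.
with open(os.path.join(tmp, "toy_api.py"), "w") as f:
    f.write(textwrap.dedent("""
        presenter = "presenter"

        from toy_router import router  # deferred import: presenter now exists
    """))

import toy_api

print(toy_api.router)  # -> "router wired to presenter"
```

If `from toy_router import router` were moved to the top of `toy_api.py`, `toy_router` would see a half-initialized `toy_api` with no `presenter` bound yet and fail with AttributeError; deferring the import until after `presenter` exists is exactly the trick applied to `app.include_router(whatsapp_router)` here.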
1 change: 1 addition & 0 deletions src/ansari/config.py
@@ -75,6 +75,7 @@ def get_resource_path(filename):
     VECTARA_API_KEY: SecretStr
 
     MAX_FAILURES: int = Field(default=3)
+    MAX_TOKEN_LIMIT: int = Field(default=100000)
 
     MAWSUAH_VECTARA_CORPUS_KEY: str = Field(
         alias="MAWSUAH_VECTARA_CORPUS_KEY",
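
Assuming `Settings` uses pydantic-settings-style environment loading (as the `Field` aliases and `SecretStr` fields elsewhere in `config.py` suggest), the new cap is tunable per deployment without a code change; a toy sketch of that behavior:

```python
import os

from pydantic import Field
from pydantic_settings import BaseSettings


class ToySettings(BaseSettings):
    MAX_TOKEN_LIMIT: int = Field(default=100000)


print(ToySettings().MAX_TOKEN_LIMIT)  # -> 100000, the new default

# Environment variables override field defaults (matched case-insensitively),
# so the limit can be raised or lowered per deployment.
os.environ["MAX_TOKEN_LIMIT"] = "50000"
print(ToySettings().MAX_TOKEN_LIMIT)  # -> 50000
```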
85 changes: 85 additions & 0 deletions tests/unit/test_token_limit.py
@@ -0,0 +1,85 @@
+import pytest
+from unittest.mock import MagicMock, patch
+from src.ansari.agents.ansari import Ansari
+
+@pytest.fixture
+def mock_settings():
+    settings = MagicMock()
+    settings.MODEL = "gpt-4o"
+    settings.MAX_TOKEN_LIMIT = 10
+    settings.PROMPT_PATH = "/test/prompts"
+    settings.SYSTEM_PROMPT_FILE_NAME = "system_msg_default"
+    settings.KALEMAT_API_KEY.get_secret_value.return_value = "key"
+    settings.VECTARA_API_KEY.get_secret_value.return_value = "key"
+    settings.USUL_API_TOKEN.get_secret_value.return_value = "key"
+    settings.MAWSUAH_VECTARA_CORPUS_KEY = "corpus"
+    return settings
+
+def test_token_limit_exceeded(mock_settings):
+    with patch("src.ansari.agents.ansari.PromptMgr") as mock_prompt_mgr, \
+         patch("src.ansari.agents.ansari.SearchQuran"), \
+         patch("src.ansari.agents.ansari.SearchHadith"), \
+         patch("src.ansari.agents.ansari.SearchMawsuah"), \
+         patch("src.ansari.agents.ansari.SearchTafsirEncyc"), \
+         patch("tiktoken.encoding_for_model") as mock_encoding:
+
+        # Mock encoding to return a length > 10
+        mock_encoder = MagicMock()
+        mock_encoder.encode.return_value = [1] * 20  # 20 tokens
+        mock_encoding.return_value = mock_encoder
+
+        mock_prompt = MagicMock()
+        mock_prompt.render.return_value = "sys"
+        mock_prompt_mgr.return_value.bind.return_value = mock_prompt
+
+        ansari = Ansari(mock_settings)
+
+        # Add a message
+        ansari.message_history = [
+            {"role": "system", "content": "sys"},
+            {"role": "user", "content": "This message is long enough."},
+        ]
+
+        # Process
+        response_gen = ansari.process_message_history()
+        response = next(response_gen)
+
+        assert "The conversation has become too long" in response
+
+def test_token_limit_not_exceeded(mock_settings):
+    with patch("src.ansari.agents.ansari.PromptMgr") as mock_prompt_mgr, \
+         patch("src.ansari.agents.ansari.SearchQuran"), \
+         patch("src.ansari.agents.ansari.SearchHadith"), \
+         patch("src.ansari.agents.ansari.SearchMawsuah"), \
+         patch("src.ansari.agents.ansari.SearchTafsirEncyc"), \
+         patch("tiktoken.encoding_for_model") as mock_encoding, \
+         patch("src.ansari.agents.ansari.litellm.completion") as mock_completion:
+
+        # Mock encoding to return a length < 10
+        mock_encoder = MagicMock()
+        mock_encoder.encode.return_value = [1] * 5  # 5 tokens
+        mock_encoding.return_value = mock_encoder
+
+        mock_prompt = MagicMock()
+        mock_prompt.render.return_value = "sys"
+        mock_prompt_mgr.return_value.bind.return_value = mock_prompt
+
+        # Mock completion response
+        mock_chunk = MagicMock()
+        mock_chunk.choices[0].delta.content = "Response"
+        mock_chunk.choices[0].delta.tool_calls = None
+        mock_completion.return_value = [mock_chunk]
+
+        ansari = Ansari(mock_settings)
+
+        # Add a short message
+        ansari.message_history = [
+            {"role": "system", "content": "sys"},
+            {"role": "user", "content": "Short."},
+        ]
+
+        # Process
+        response_gen = ansari.process_message_history()
+        response = next(response_gen)
+
+        assert response == "Response"
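
Assuming pytest is installed, the new tests run in isolation with `python -m pytest tests/unit/test_token_limit.py -v`; the `src.`-prefixed imports mean they expect to be invoked from the repository root.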