Skip to content

Commit 1dca53e

Browse files
committed
feat: Implement token limit for message history
1 parent 1db73db commit 1dca53e

File tree

4 files changed

+130
-2
lines changed

4 files changed

+130
-2
lines changed

src/ansari/agents/ansari.py

Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
import traceback
1313

1414
import litellm
15+
import tiktoken
1516

1617
from ansari.ansari_db import MessageLogger
1718
from ansari.ansari_logger import get_logger
@@ -45,6 +46,11 @@ def __init__(self, settings: Settings, message_logger: MessageLogger = None, jso
4546
self.message_history = [{"role": "system", "content": self.sys_msg}]
4647

4748
self.model = settings.MODEL
49+
try:
50+
self.encoder = tiktoken.encoding_for_model(self.model)
51+
except Exception:
52+
logger.warning(f"Could not get encoding for model {self.model}, falling back to cl100k_base")
53+
self.encoder = tiktoken.get_encoding("cl100k_base")
4854

4955
def _initialize_tools(self):
5056
"""Initialize tool instances. Can be overridden by subclasses."""
@@ -115,6 +121,16 @@ def get_completion(self, **kwargs):
115121
return litellm.completion(**kwargs)
116122

117123
def process_message_history(self, use_tool=True):
124+
if self._check_token_limit():
125+
msg = "The conversation has become too long. Please start a new conversation."
126+
# We intentionally do not append this notice to message_history,
127+
# to avoid growing it further; the UI displays whatever we yield,
128+
# so the user still sees the message.
129+
# The user is instructed to start a NEW conversation,
130+
# which resets the overlong history anyway.
131+
yield msg
132+
return
133+
118134
common_params = {
119135
"model": self.model,
120136
"messages": self.message_history,
@@ -452,3 +468,22 @@ def _log_truncated_message_history(self, message_history, count: int, failures:
452468
+ f"\n{trunc_msg_hist}\n"
453469
+ "-" * 60,
454470
)
471+
472+
def _check_token_limit(self):
    """Return True if the history's approximate token count exceeds the cap.

    Only textual content is counted: plain string ``content`` values and
    ``"text"`` entries inside list-form content, so this is an approximation
    (tool-call payloads are not included). If counting fails for any reason
    we fail open and return False rather than block the conversation.
    """
    try:
        total = 0
        for message in self.message_history:
            body = message.get("content", "")
            if isinstance(body, str):
                total += len(self.encoder.encode(body))
            elif isinstance(body, list):
                total += sum(
                    len(self.encoder.encode(part["text"]))
                    for part in body
                    if isinstance(part, dict) and "text" in part
                )

        logger.debug(f"Current token count: {total}")
        return total > self.settings.MAX_TOKEN_LIMIT
    except Exception as e:
        logger.warning(f"Error checking token limit: {e}")
        return False

src/ansari/app/main_api.py

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -38,7 +38,7 @@
3838
from ansari.agents.ansari_workflow import AnsariWorkflow
3939
from ansari.ansari_db import AnsariDB, MessageLogger, SourceType
4040
from ansari.ansari_logger import get_logger
41-
from ansari.routers.whatsapp_router import router as whatsapp_router
41+
4242
from ansari.config import Settings, get_settings
4343
from ansari.presenters.api_presenter import ApiPresenter
4444
from ansari.util.general_helpers import CORSMiddlewareWithLogging, get_extended_origins, register_to_mailing_list
@@ -89,7 +89,8 @@ async def lifespan(app: FastAPI):
8989
app = FastAPI(lifespan=lifespan)
9090

9191
# Include the WhatsApp API router
92-
app.include_router(whatsapp_router)
92+
93+
9394

9495

9596
# Custom exception handler, which aims to log FastAPI-related exceptions before raising them
@@ -147,8 +148,14 @@ def add_app_middleware():
147148

148149

149150
presenter = ApiPresenter(app, ansari)
151+
150152
presenter.present()
151153

154+
# Include the WhatsApp API router
155+
# NOTE: We import it here to avoid circular imports, as whatsapp_router imports db and presenter from this file
156+
from ansari.routers.whatsapp_router import router as whatsapp_router
157+
app.include_router(whatsapp_router)
158+
152159
cache = FanoutCache(get_settings().diskcache_dir, shards=4, timeout=1)
153160

154161

src/ansari/config.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -75,6 +75,7 @@ def get_resource_path(filename):
7575
VECTARA_API_KEY: SecretStr
7676

7777
MAX_FAILURES: int = Field(default=3)
78+
MAX_TOKEN_LIMIT: int = Field(default=100000)
7879

7980
MAWSUAH_VECTARA_CORPUS_KEY: str = Field(
8081
alias="MAWSUAH_VECTARA_CORPUS_KEY",

tests/unit/test_token_limit.py

Lines changed: 85 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,85 @@
1+
import pytest
2+
from unittest.mock import MagicMock, patch
3+
from src.ansari.agents.ansari import Ansari
4+
5+
@pytest.fixture
def mock_settings():
    """Settings stand-in with a tiny token budget (10) so the limit trips easily."""
    cfg = MagicMock()
    cfg.MODEL = "gpt-4o"
    cfg.MAX_TOKEN_LIMIT = 10
    cfg.PROMPT_PATH = "/test/prompts"
    cfg.SYSTEM_PROMPT_FILE_NAME = "system_msg_default"
    # All three secrets resolve to the same dummy value.
    for secret in (cfg.KALEMAT_API_KEY, cfg.VECTARA_API_KEY, cfg.USUL_API_TOKEN):
        secret.get_secret_value.return_value = "key"
    cfg.MAWSUAH_VECTARA_CORPUS_KEY = "corpus"
    return cfg
17+
18+
def test_token_limit_exceeded(mock_settings):
    """When the history's token count exceeds MAX_TOKEN_LIMIT, the agent
    yields the 'too long' notice instead of calling the model."""
    with patch("src.ansari.agents.ansari.PromptMgr") as prompt_mgr_cls, \
            patch("src.ansari.agents.ansari.SearchQuran"), \
            patch("src.ansari.agents.ansari.SearchHadith"), \
            patch("src.ansari.agents.ansari.SearchMawsuah"), \
            patch("src.ansari.agents.ansari.SearchTafsirEncyc"), \
            patch("tiktoken.encoding_for_model") as encoding_for_model:

        # Every encode() call reports 20 tokens — well over the limit of 10.
        fake_encoder = MagicMock()
        fake_encoder.encode.return_value = [1] * 20
        encoding_for_model.return_value = fake_encoder

        sys_prompt = MagicMock()
        sys_prompt.render.return_value = "sys"
        prompt_mgr_cls.return_value.bind.return_value = sys_prompt

        agent = Ansari(mock_settings)
        agent.message_history = [
            {"role": "system", "content": "sys"},
            {"role": "user", "content": "This message is long enough."},
        ]

        first_reply = next(agent.process_message_history())
        assert "The conversation has become too long" in first_reply
48+
49+
def test_token_limit_not_exceeded(mock_settings):
    """With a history under MAX_TOKEN_LIMIT, the agent streams the model's reply."""
    with patch("src.ansari.agents.ansari.PromptMgr") as prompt_mgr_cls, \
            patch("src.ansari.agents.ansari.SearchQuran"), \
            patch("src.ansari.agents.ansari.SearchHadith"), \
            patch("src.ansari.agents.ansari.SearchMawsuah"), \
            patch("src.ansari.agents.ansari.SearchTafsirEncyc"), \
            patch("tiktoken.encoding_for_model") as encoding_for_model, \
            patch("src.ansari.agents.ansari.litellm.completion") as completion:

        # Each encode() call reports 5 tokens — comfortably under the limit of 10.
        fake_encoder = MagicMock()
        fake_encoder.encode.return_value = [1] * 5
        encoding_for_model.return_value = fake_encoder

        sys_prompt = MagicMock()
        sys_prompt.render.return_value = "sys"
        prompt_mgr_cls.return_value.bind.return_value = sys_prompt

        # One streamed chunk carrying plain text and no tool calls.
        chunk = MagicMock()
        chunk.choices[0].delta.content = "Response"
        chunk.choices[0].delta.tool_calls = None
        completion.return_value = [chunk]

        agent = Ansari(mock_settings)
        agent.message_history = [
            {"role": "system", "content": "sys"},
            {"role": "user", "content": "Short."},
        ]

        first_reply = next(agent.process_message_history())
        assert first_reply == "Response"

0 commit comments

Comments
 (0)