From 7f7ceef4fa82c3ad8f39f70cafecb196a51edb2e Mon Sep 17 00:00:00 2001 From: Ethan Chen Date: Thu, 20 Nov 2025 13:36:52 -0800 Subject: [PATCH] feat: message_history storage in redis --- controllers/repo_controller.py | 6 +++--- db/session.py | 14 ++++++++++++++ main.py | 2 -- services/repo_service.py | 10 +++++++--- 4 files changed, 24 insertions(+), 8 deletions(-) create mode 100644 db/session.py diff --git a/controllers/repo_controller.py b/controllers/repo_controller.py index edcb580..1151aca 100644 --- a/controllers/repo_controller.py +++ b/controllers/repo_controller.py @@ -7,9 +7,9 @@ router = APIRouter() @router.get("/repo/query", status_code=status.HTTP_200_OK) -async def query_repo(question: str, repo_path: str, history: list[Any]): - response, history = await query(question, repo_path, history) - return JSONResponse(content={"response": response, "history": history}, status_code=status.HTTP_200_OK) +async def query_repo(question: str, repo_path: str, session_id: str): + response = await query(question, repo_path, session_id) + return JSONResponse(content={"response": response}, status_code=status.HTTP_200_OK) @router.post("/repo/optimize", status_code=status.HTTP_200_OK) async def optimize_repo(repo_path: str, language: Optional[str] = None, ref: Optional[str] = None): diff --git a/db/session.py b/db/session.py new file mode 100644 index 0000000..111b401 --- /dev/null +++ b/db/session.py @@ -0,0 +1,14 @@ +import redis +from typing import Any +import json, pickle + +r = redis.Redis(host='localhost', port=6379, decode_responses=False) + +def get_session(session_id: str): + if r.exists(session_id) == 0: + return [] + else: + return pickle.loads(r.get(session_id)) + +def set_session(session_id: str, history: list[Any]): + r.set(session_id, pickle.dumps(history)) diff --git a/main.py b/main.py index 2bdb088..5b5164f 100644 --- a/main.py +++ b/main.py @@ -2,9 +2,7 @@ from controllers.graph_controller import router as graph_router from controllers.repo_controller import router as repo_router from fastapi.middleware.cors import CORSMiddleware -import redis -r = redis.Redis(host='localhost', port=6379, decode_responses=True) app = FastAPI() @app.get("/") diff --git a/services/repo_service.py b/services/repo_service.py index 80ebdcd..1b082ab 100644 --- a/services/repo_service.py +++ b/services/repo_service.py @@ -5,10 +5,11 @@ from rich.panel import Panel from rich.markdown import Markdown from typing import Any +from db.session import get_session, set_session console = Console(width=None, force_terminal=True) -async def query(question: str, repo_path: str, history: list[Any]): +async def query(question: str, repo_path: str, session_id: str): init_session_log(_setup_common_initialization(repo_path)) log_session_event(f"USER: {question}") with MemgraphIngestor( @@ -16,10 +17,13 @@ async def query(question: str, repo_path: str, history: list[Any]): port=settings.MEMGRAPH_PORT, ) as ingestor: console.print("[bold green]Successfully connected to Memgraph.[/bold green]") + history = get_session(session_id) rag_agent = _initialize_services_and_agent(repo_path, ingestor) - response = await rag_agent.run(question + get_session_context(), history) + question_with_context = question + get_session_context() + response = await rag_agent.run(question_with_context, message_history=history) history.extend(response.new_messages()) - return response.output, history + set_session(session_id, history) + return response.output async def optimize(repo_path: str, language: str, ref: str):