Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion services/chatbot/src/chatbot/chat_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
get_or_create_session_id,
store_api_key,
store_model_name,
get_user_jwt
)

chat_bp = Blueprint("chat", __name__, url_prefix="/genai")
Expand Down Expand Up @@ -53,14 +54,15 @@ async def chat():
session_id = await get_or_create_session_id()
openai_api_key = await get_api_key(session_id)
model_name = await get_model_name(session_id)
user_jwt = await get_user_jwt()
if not openai_api_key:
return jsonify({"message": "Missing OpenAI API key. Please authenticate."}), 400
data = await request.get_json()
message = data.get("message", "").strip()
id = data.get("id", uuid4().int & (1 << 63) - 1)
if not message:
return jsonify({"message": "Message is required", "id": id}), 400
reply, response_id = await process_user_message(session_id, message, openai_api_key, model_name)
reply, response_id = await process_user_message(session_id, message, openai_api_key, model_name, user_jwt)
return jsonify({"id": response_id, "message": reply}), 200


Expand Down
4 changes: 2 additions & 2 deletions services/chatbot/src/chatbot/chat_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,13 +22,13 @@ async def delete_chat_history(session_id):
await db.chat_sessions.delete_one({"session_id": session_id})


async def process_user_message(session_id, user_message, api_key, model_name):
async def process_user_message(session_id, user_message, api_key, model_name, user_jwt):
history = await get_chat_history(session_id)
# generate a unique numeric id for the message that is random but unique
source_message_id = uuid4().int & (1 << 63) - 1
history.append({"id": source_message_id, "role": "user", "content": user_message})
# Run LangGraph agent
response = await execute_langgraph_agent(api_key, model_name, history, session_id)
response = await execute_langgraph_agent(api_key, model_name, history, user_jwt, session_id)
print("Response", response)
reply: Messages = response.get("messages", [{}])[-1]
print("Reply", reply.content)
Expand Down
9 changes: 5 additions & 4 deletions services/chatbot/src/chatbot/langgraph_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
from langgraph.prebuilt import create_react_agent

from .extensions import postgresdb
from .mcp_client import mcp_client
from .mcp_client import get_mcp_client


async def get_retriever_tool(api_key):
Expand Down Expand Up @@ -46,7 +46,7 @@ async def get_retriever_tool(api_key):
return retriever_tool


async def build_langgraph_agent(api_key, model_name):
async def build_langgraph_agent(api_key, model_name, user_jwt):
system_prompt = textwrap.dedent(
"""
You are crAPI Assistant — an expert agent that helps users explore and test the Completely Ridiculous API (crAPI), a vulnerable-by-design application for learning and evaluating modern API security issues.
Expand Down Expand Up @@ -86,6 +86,7 @@ async def build_langgraph_agent(api_key, model_name):
)
llm = ChatOpenAI(api_key=api_key, model=model_name)
toolkit = SQLDatabaseToolkit(db=postgresdb, llm=llm)
mcp_client = get_mcp_client(user_jwt)
mcp_tools = await mcp_client.get_tools()
db_tools = toolkit.get_tools()
tools = mcp_tools + db_tools
Expand All @@ -95,8 +96,8 @@ async def build_langgraph_agent(api_key, model_name):
return agent_node


async def execute_langgraph_agent(api_key, model_name, messages, session_id=None):
agent = await build_langgraph_agent(api_key, model_name)
async def execute_langgraph_agent(api_key, model_name, messages, user_jwt, session_id=None):
agent = await build_langgraph_agent(api_key, model_name, user_jwt)
print("messages", messages)
print("Session ID", session_id)
response = await agent.ainvoke({"messages": messages})
Expand Down
26 changes: 14 additions & 12 deletions services/chatbot/src/chatbot/mcp_client.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,16 @@
import asyncio
import os

from langchain_mcp_adapters.client import MultiServerMCPClient

mcp_client = MultiServerMCPClient(
{
"crapi": {
"transport": "streamable_http",
"url": "http://localhost:5500/mcp/",
"headers": {},
},
}
)
def get_mcp_client(user_jwt: str | None) -> MultiServerMCPClient:
headers = {}
if user_jwt:
headers["Authorization"] = f"Bearer {user_jwt}"

return MultiServerMCPClient(
{
"crapi": {
"transport": "streamable_http",
"url": "http://localhost:5500/mcp/",
"headers": headers,
}
}
)
6 changes: 6 additions & 0 deletions services/chatbot/src/chatbot/session_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,3 +57,9 @@ async def get_model_name(session_id):
if "model_name" not in doc:
return Config.DEFAULT_MODEL_NAME
return doc["model_name"]

async def get_user_jwt() -> str | None:
auth = request.headers.get("Authorization", "")
if auth.startswith("Bearer "):
return auth.replace("Bearer ", "")
return None
Loading