diff --git a/backend/models/chat.py b/backend/models/chat.py
index 32c77533c59..740cb2739af 100644
--- a/backend/models/chat.py
+++ b/backend/models/chat.py
@@ -77,7 +77,10 @@ def _sync_app_and_plugin_ids(cls, data: Any) -> Any:

     @staticmethod
     def get_messages_as_string(
-        messages: List['Message'], use_user_name_if_available: bool = False, use_plugin_name_if_available: bool = False
+        messages: List['Message'],
+        use_user_name_if_available: bool = False,
+        use_plugin_name_if_available: bool = False,
+        include_file_info: bool = False,
     ) -> str:
         sorted_messages = sorted(messages, key=lambda m: m.created_at)

@@ -90,16 +93,27 @@ def get_sender_name(message: Message) -> str:
             #     return plugin.name RESTORE ME
             return message.sender.upper()  # TODO: use app id

-        formatted_messages = [
-            f"({message.created_at.strftime('%d %b %Y at %H:%M UTC')}) {get_sender_name(message)}: {message.text}"
-            for message in sorted_messages
-        ]
+        formatted_messages = []
+        for message in sorted_messages:
+            msg_text = (
+                f"({message.created_at.strftime('%d %b %Y at %H:%M UTC')}) {get_sender_name(message)}: {message.text}"
+            )
+
+            # Add file info if requested and files exist
+            if include_file_info and message.files_id and len(message.files_id) > 0:
+                file_info = f" [Files attached: {len(message.files_id)} file(s), IDs: {', '.join(message.files_id)}]"
+                msg_text += file_info
+
+            formatted_messages.append(msg_text)

         return '\n'.join(formatted_messages)

     @staticmethod
     def get_messages_as_xml(
-        messages: List['Message'], use_user_name_if_available: bool = False, use_plugin_name_if_available: bool = False
+        messages: List['Message'],
+        use_user_name_if_available: bool = False,
+        use_plugin_name_if_available: bool = False,
+        include_file_info: bool = False,
     ) -> str:
         sorted_messages = sorted(messages, key=lambda m: m.created_at)

@@ -112,27 +126,35 @@ def get_sender_name(message: Message) -> str:
             #     return plugin.name RESTORE ME
             return message.sender.upper()  # TODO: use app id

-        formatted_messages = [
-            f"""
-
-
-            {message.created_at.strftime('%d %b %Y at %H:%M UTC')}
-
-
-            {get_sender_name(message)}
-
-
-            {message.text}
-
-            {('' + ''.join(f"{file.name}" for file in message.files) + '') if message.files and len(message.files) > 0 else ''}
-
-            """.replace(
-                ' ', ''
-            )
-            .replace('\n\n\n', '\n\n')
-            .strip()
-            for message in sorted_messages
-        ]
+        formatted_messages = []
+        for message in sorted_messages:
+            # Build file section if requested
+            file_section = ""
+            if include_file_info and message.files and len(message.files) > 0:
+                file_section = '\n'
+                for file in message.files:
+                    file_section += f' \n'
+                file_section += ''
+            elif include_file_info and message.files_id and len(message.files_id) > 0:
+                # Fallback if files not loaded but IDs exist
+                file_section = '\n'
+                for file_id in message.files_id:
+                    file_section += f' \n'
+                file_section += ''
+            elif message.files and len(message.files) > 0:
+                # Original behavior when include_file_info is False
+                file_section = (
+                    '' + ''.join(f"{file.name}" for file in message.files) + ''
+                )
+
+            msg = f"""
+{message.created_at.strftime('%d %b %Y at %H:%M UTC')}
+{get_sender_name(message)}
+{message.text}
+{file_section}
+"""
+
+            formatted_messages.append(msg.replace(' ', '').strip())

         return '\n'.join(formatted_messages)

diff --git a/backend/utils/llm/chat.py b/backend/utils/llm/chat.py
index 6fcde741d9c..ff9830dac5a 100644
--- a/backend/utils/llm/chat.py
+++ b/backend/utils/llm/chat.py
@@ -392,7 +392,7 @@ def _get_qa_rag_prompt(
     )


-def _get_agentic_qa_prompt(uid: str, app: Optional[App] = None) -> str:
+def _get_agentic_qa_prompt(uid: str, app: Optional[App] = None, messages: List[Message] = None) -> str:
     """
     Build the system prompt for the agentic agent, preserving the structure and instructions
     from _get_qa_rag_prompt while adding tool-calling capabilities.
@@ -400,6 +400,7 @@ def _get_agentic_qa_prompt(uid: str, app: Optional[App] = None) -> str:
     Args:
         uid: User ID
         app: Optional app/plugin for personalized behavior
+        messages: Optional message history for file context

     Returns:
         System prompt string
@@ -437,13 +438,30 @@ def _get_agentic_qa_prompt(uid: str, app: Optional[App] = None) -> str:

 {plugin_info}

+"""
+
+    # Add file context if messages contain files
+    file_context_section = ""
+    if messages:
+        message_history_with_files = Message.get_messages_as_string(messages, include_file_info=True)
+
+        # Check if any files are present
+        if '[Files attached:' in message_history_with_files:
+            file_context_section = f"""
+
+Recent conversation (includes file attachment IDs):
+{message_history_with_files}
+
+When you see [Files attached: X file(s), IDs: ...], you can reference those file IDs in search_files_tool.
+
+    """

     base_prompt = f"""
 You are Omi, a helpful AI assistant for {user_name}. You are designed to provide accurate, detailed, and comprehensive responses in the most personalized way possible.
-
+{file_context_section}
 Current date time in {user_name}'s timezone ({tz}): {current_datetime_str}
 Current date time ISO format: {current_datetime_iso}

diff --git a/backend/utils/other/chat_file.py b/backend/utils/other/chat_file.py
index caa3598cfd3..535225e08ac 100644
--- a/backend/utils/other/chat_file.py
+++ b/backend/utils/other/chat_file.py
@@ -163,7 +163,6 @@ def _ensure_thread_and_assistant(self):
             # Continue anyway - IDs will be recreated next time

     def _fill_question(self, uid, question, file_ids: List[str], thread_id: str):
-
         # OpenAI has a limit of 10 items in content array (1 text + max 9 images)
         files = chat_db.get_chat_files_desc(uid, files_id=file_ids, limit=9)

diff --git a/backend/utils/retrieval/agentic.py b/backend/utils/retrieval/agentic.py
index 2d400fad451..95285b84bbf 100644
--- a/backend/utils/retrieval/agentic.py
+++ b/backend/utils/retrieval/agentic.py
@@ -32,6 +32,7 @@
     update_action_item_tool,
     get_omi_product_info_tool,
     perplexity_search_tool,
+    search_files_tool,
 )
 from utils.retrieval.safety import AgentSafetyGuard, SafetyGuardError
 from utils.llm.clients import llm_agent, llm_agent_stream
@@ -126,6 +127,7 @@ def execute_agentic_chat(
         update_action_item_tool,
         get_omi_product_info_tool,
         perplexity_search_tool,
+        search_files_tool,
     ]

     # Convert messages to LangChain format and prepend system message
@@ -185,8 +187,8 @@ async def execute_agentic_chat_stream(
     Yields:
         Formatted chunks with "data: " or "think: " prefixes
     """
-    # Build system prompt
-    system_prompt = _get_agentic_qa_prompt(uid, app)
+    # Build system prompt with file context
+    system_prompt = _get_agentic_qa_prompt(uid, app, messages)

     # Get all tools
     tools = [
@@ -198,6 +200,7 @@ async def execute_agentic_chat_stream(
         update_action_item_tool,
         get_omi_product_info_tool,
         perplexity_search_tool,
+        search_files_tool,
     ]

     # Convert messages to LangChain format and prepend system message
@@ -225,6 +228,7 @@ async def execute_agentic_chat_stream(
             "thread_id": str(uuid.uuid4()),
             "conversations_collected": conversations_collected,
             "safety_guard": safety_guard,
+            "chat_session_id": chat_session.id if chat_session else None,
         }
     }

diff --git a/backend/utils/retrieval/graph.py b/backend/utils/retrieval/graph.py
index 4007e630b12..aef677f5d4e 100644
--- a/backend/utils/retrieval/graph.py
+++ b/backend/utils/retrieval/graph.py
@@ -130,25 +130,13 @@ def determine_conversation_type(
 ) -> Literal[
     "no_context_conversation",
     "agentic_context_dependent_conversation",
-    # "omi_question",
-    "file_chat_question",
     "persona_question",
 ]:
-    # chat with files by attachments on the last message
     print("determine_conversation_type")
-    messages = state.get("messages", [])
-    if len(messages) > 0 and len(messages[-1].files_id) > 0:
-        return "file_chat_question"

     # persona
     app: App = state.get("plugin_selected")
     if app and app.is_a_persona():
-        # file
-        question = state.get("parsed_question", "")
-        is_file_question = retrieve_is_file_question(question)
-        if is_file_question:
-            return "file_chat_question"
-
         return "persona_question"

     # chat
@@ -157,15 +145,6 @@ def determine_conversation_type(
     if not question or len(question) == 0:
         return "no_context_conversation"

-    # determine the follow-up question is chatting with files or not
-    is_file_question = retrieve_is_file_question(question)
-    if is_file_question:
-        return "file_chat_question"
-
-    # is_omi_question = retrieve_is_an_omi_question(question)
-    # if is_omi_question:
-    #     return "omi_question"
-
     requires = requires_context(question)
     if requires:
         return "agentic_context_dependent_conversation"
@@ -466,13 +445,13 @@ def file_chat_question(state: GraphState):
 # workflow.add_node("omi_question", omi_question)
 # workflow.add_node("context_dependent_conversation", context_dependent_conversation)
 workflow.add_node("agentic_context_dependent_conversation", agentic_context_dependent_conversation)
-workflow.add_node("file_chat_question", file_chat_question)
+# workflow.add_node("file_chat_question", file_chat_question)
 workflow.add_node("persona_question", persona_question)

 workflow.add_edge("no_context_conversation", END)
 # workflow.add_edge("omi_question", END)
 workflow.add_edge("persona_question", END)
-workflow.add_edge("file_chat_question", END)
+# workflow.add_edge("file_chat_question", END)
 workflow.add_edge("agentic_context_dependent_conversation", END)
 # workflow.add_edge("context_dependent_conversation", "retrieve_topics_filters")
 # workflow.add_edge("context_dependent_conversation", "retrieve_date_filters")

diff --git a/backend/utils/retrieval/tools/__init__.py b/backend/utils/retrieval/tools/__init__.py
index ebf55f9313f..c2c666356bc 100644
--- a/backend/utils/retrieval/tools/__init__.py
+++ b/backend/utils/retrieval/tools/__init__.py
@@ -24,6 +24,9 @@
 from .perplexity_tools import (
     perplexity_search_tool,
 )
+from .file_tools import (
+    search_files_tool,
+)

 __all__ = [
     'get_conversations_tool',
@@ -35,4 +38,5 @@
     'update_action_item_tool',
     'get_omi_product_info_tool',
     'perplexity_search_tool',
+    'search_files_tool',
 ]

diff --git a/backend/utils/retrieval/tools/file_tools.py b/backend/utils/retrieval/tools/file_tools.py
new file mode 100644
index 00000000000..17df6455578
--- /dev/null
+++ b/backend/utils/retrieval/tools/file_tools.py
@@ -0,0 +1,81 @@
+"""
+File search tools for the agentic chat system.
+
+These tools allow the LLM to search and query files uploaded to chat sessions.
+"""
+
+from langchain_core.runnables import RunnableConfig
+from langchain_core.tools import tool
+from typing import List, Optional
+import database.chat as chat_db
+from models.chat import ChatSession, FileChat
+from utils.other.chat_file import FileChatTool
+
+
+@tool
+def search_files_tool(question: str, file_ids: Optional[List[str]] = None, config: RunnableConfig = None) -> str:
+    """
+    Search and ask questions about files attached to the current chat session.
+    Use this when the user asks about documents, images, PDFs, or any files they've uploaded.
+
+    The conversation history shows which files are attached to which messages in the format:
+    [Files attached: X file(s), IDs: file_id_1, file_id_2, ...]
+
+    You can specify which files to search by providing their IDs, or omit file_ids to search all files.
+
+    Examples:
+    - User asks "what does the document say?" → Use file_ids from the most recent message with files
+    - User asks "compare the two PDFs I uploaded" → Use file_ids from messages with PDFs
+    - User asks "summarize all my files" → Don't specify file_ids (searches all)
+
+    Args:
+        question: The specific question to ask about the files
+        file_ids: Optional list of specific file IDs to search. If not provided, searches all session files.
+
+    Returns:
+        Answer based on the file contents
+    """
+    if config is None:
+        return "Configuration error: missing config"
+
+    uid = config['configurable']['user_id']
+    chat_session_id = config['configurable'].get('chat_session_id')
+
+    if not chat_session_id:
+        return "No active chat session. Files are not available."
+
+    try:
+        # Get session data
+        session_data = chat_db.get_chat_session_by_id(uid, chat_session_id)
+
+        if not session_data:
+            return "Chat session not found."
+
+        chat_session = ChatSession(**session_data)
+
+        # Determine which files to search
+        if file_ids and len(file_ids) > 0:
+            # Use specified files
+            # Validate that these files belong to the session
+            session_file_ids = set(chat_session.file_ids or [])
+            file_ids_to_search = [fid for fid in file_ids if fid in session_file_ids]
+
+            if not file_ids_to_search:
+                return "The specified files are not available in this chat session."
+        else:
+            # Use all session files
+            file_ids_to_search = chat_session.file_ids if chat_session.file_ids else []
+
+        if not file_ids_to_search:
+            return "No files have been uploaded to this chat session yet. Ask the user to upload files first."
+
+        # Use FileChatTool to query files
+        fc_tool = FileChatTool(uid, chat_session_id)
+        answer = fc_tool.process_chat_with_file(question, file_ids_to_search)
+
+        return answer
+
+    except ValueError as e:
+        return f"Session error: {str(e)}"
+    except Exception as e:
+        return f"I encountered an error while searching the files. Please try again or rephrase your question."