From 50e0b044c03951285926230d289f48a5f0471e3f Mon Sep 17 00:00:00 2001 From: Ryan Kraus Date: Tue, 23 Jul 2024 11:17:58 -0400 Subject: [PATCH] Added jupyter notebook for agentic retrieval. --- ...agentic_rag_with_nemo_retriever_nims.ipynb | 879 ++++++++++++++++++ 1 file changed, 879 insertions(+) create mode 100644 notebooks/agentic_rag_with_nemo_retriever_nims.ipynb diff --git a/notebooks/agentic_rag_with_nemo_retriever_nims.ipynb b/notebooks/agentic_rag_with_nemo_retriever_nims.ipynb new file mode 100644 index 000000000..7bd753aad --- /dev/null +++ b/notebooks/agentic_rag_with_nemo_retriever_nims.ipynb @@ -0,0 +1,879 @@ +{ + "cells": [ + { + "attachments": {}, + "cell_type": "markdown", + "id": "919fe33c-0149-4f7d-b200-544a18986c9a", + "metadata": {}, + "source": [ + "# Agentic RAG pipeline with Nemo Retriever and LLM NIMs \n", + "\n", + "## Overview\n", + "\n", + "Retrieval-augmented generation (RAG) has proven to be an effective strategy for ensuring large language model (LLM) responses are up-to-date and not hallucinated. \n", + "\n", + "Various retrieval strategies have been proposed that can improve the recall of documents for generation. There is no one-size-fits-all all. The strategy (for example: chunk size, number of documents returned, semantic search vs graph retrieval, etc.) depends on your data. Although the retrieval strategies might differ, an agentic framework designed on top of your retrieval system that does reasoning, decision-making, and reflection on your retrieved data is becoming more common in modern RAG systems. An agent can be described as a system that can use an LLM to reason through a problem, create a plan to solve the problem, and execute the plan with the help of a set of tools. For example, LLMs are notoriously bad at solving math problems, giving an LLM a calculator “tool” that it can use to perform mathematical tasks while it reasons through a larger problem of calculating YoY increase of a company’s revenue can be described as an agentic workflow. \n", + "\n", + "As generative AI systems start transitioning towards entities capable of performing \"agentic\" tasks, we need robust models that have been trained on the ability to break down tasks, act as central planners, and have multi-step reasoning capabilities with model and system-level safety checks. With the Llama 3.1 family, Meta is launching a suite of LLMs spanning 8B, 70B, and 405B parameters with these tool-calling capabilities for agentic workloads. NVIDIA has partnered with Meta to make sure the latest Llama models can be deployed optimally through NVIDIA NIMs.\n", + "\n", + "Further, with the general availability of the NVIDIA NeMo Retriever collection of NIM microservices, enterprises have access to scalable software to customize their data-dependent RAG pipelines. The NeMo Retriever NIMs can be easily plugged into existing RAG pipelines and interfaces with open source LLM frameworks like LangChain or LlamaIndex, so you can easily integrate retriever models into generative AI applications.\n" + ] + }, + { + "cell_type": "markdown", + "id": "72f3ee57-68ab-4040-bd36-4014e2a23d96", + "metadata": {}, + "source": [ + "### Setup the Environment \n", + "\n", + "First, let's install a few packages for interfacing with NVIDIA embedding, raranking, LLM models and vector databases.\n", + "\n", + "Install the following system dependencies if they are not already available on your system with e.g. ```brew install``` for Mac. Depending on what document types you're parsing, you may not need all of these.\n", + "* poppler-utils (images and PDFs)\n", + "* tesseract-ocr(images and PDFs)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "a384cc48-0425-4e8f-aafc-cfb8e56025c9", + "metadata": {}, + "outputs": [], + "source": [ + "!pip install -U langchain_community unstructured[all-docs] langchain-nvidia-ai-endpoints langchainhub faiss-gpu langchain langgraph pandas rank_bm25" + ] + }, + { + "cell_type": "markdown", + "id": "385d9099-2737-4e51-88ad-87f701fd89d4", + "metadata": {}, + "source": [ + "### NeMo Retriever NIMs\n", + "\n", + "NeMo Retriever microservices can be used for embedding and reranking. These microservices can be deployed within the enterprise locally, and are packaged together with NVIDIA Triton Inference Server and NVIDIA TensorRT for optimized inference of text for embedding and reranking. Additional enterprise benefits include:\n", + "\n", + "**Scalable deployment**: Whether you're catering to a few users or millions, NeMo Retriever embedding and reranking NIMs can be scaled seamlessly to meet your demands.\n", + "\n", + "**Flexible integration**: Easily incorporate NeMo Retriever embedding and reranking NIMs into existing workflows and applications, thanks to the OpenAI-compliant API endpoints–and deploy anywhere your data resides.\n", + "\n", + "**Secure processing**: Your data privacy is paramount. NeMo Retriever embedding and reranking NIMs ensure that all inferences are processed securely, with rigorous data.\n", + "\n", + "NeMo Retriever embedding and reranking NIM microservices are available today. Developers can download and deploy docker containers locally.\n", + "\n", + "#### Access the Llama 3.1 405B model\n", + "\n", + "The new Llama 3.1 set of models can be seen as the first big push of open-source models towards serious agentic capabilities. These models can now become part of a larger automation system, with LLMs doing the planning and picking the right tools to solve a larger problem. Since NVIDIA Llama 3.1 NIMs have the necessary support for OpenAI style tool calling, libraries like LangChain can now be used with NIMs to bind LLMs to Pydantic classes and fill in objects/dictionaries. This combination makes it easier for developers to get structured outputs from NIM LLMs without having to resort to regex parsing. You can access Llama 3.1 405B at ai.nvidia.com. Follow these instructions to generate the API key\n" + ] + }, + { + "cell_type": "markdown", + "id": "5a393ca3-91be-4791-8494-5bf44a60e8d7", + "metadata": {}, + "source": [ + "### Architecture\n", + "\n", + "Retrieving passages or documents within a RAG pipeline without further validation and self-reflection can usually result in unhelpful responses and factual inaccuracies. Additionally, since the models aren't explicitly trained to follow facts from passages, post-generation verification is necessary. \n", + "\n", + "Multi-agent frameworks, like LangGraph, enable developers to group LLM application-level logic into nodes and edges, for finer levels of control over agentic decision-making. LangGraph with NVIDIA LangChain OSS connectors can be used for embedding, reranking, and implementing the necessary agentic RAG techniques with LLMs (as discussed previously). \n", + "\n", + "To implement this, an application developer must include the finer-level decision-making on top of their RAG pipeline. Figure below shows one of the many renditions on a router node depending on the use case. Here, the router takes a decision to rewrite the query with help on an LLM, perchance of better recall from the retrieve.\n", + "\n", + "![alt text](agentic_rag.png \"Title\")\n", + "\n", + "**Query decomposer**: Breaks down the question into multiple smaller logical questions, and is helpful when a question needs to be answered using chunks from multiple documents.\n", + "\n", + "**Router**: Decides if chunks need to be retrieved from the local retriever to answer the given question based on the relevancy of documents stored locally. Alternatively, ‌the agent can be programmed to do a web search or simply answer with an ‘I don't know.’\n", + "\n", + "**Retriever**: This is the internal implementation of the RAG pipeline. For example, a hybrid retriever of a semantic and keyword search retriever.\n", + "\n", + "**Grader**: Checks if the retrieved passages/chunks are relevant to the question at hand.\n", + "\n", + "**Hallucination checker**: Checks if the LLM generation from each chunk is relevant to the chunk. Post-generation verification is necessary since the models are not explicitly trained to follow facts from passages.\n", + "\n" + ] + }, + { + "cell_type": "markdown", + "id": "d443d9df-b730-4d58-9d78-a5f92b57ece8", + "metadata": {}, + "source": [ + "### Download the dataset\n", + "Let's download the NIH clinical studies datasets from docugami repository. It cont" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "8057e78a-e8d0-4489-85e5-3afc11e0ef6a", + "metadata": {}, + "outputs": [], + "source": [ + "!wget https://raw.githubusercontent.com/docugami/KG-RAG-datasets/main/nih-clinical-trial-protocols/download.csv\n", + "!wget https://raw.githubusercontent.com/docugami/KG-RAG-datasets/main/nih-clinical-trial-protocols/download.py\n", + "!python download.py" + ] + }, + { + "cell_type": "markdown", + "id": "32d23af0-49a7-424f-b0c5-29884392c3ea", + "metadata": {}, + "source": [ + "#### Step-1: Load and chunk the dataset\n", + "\n", + "Use Langchain dataloaders to load all the PDF files in the created directory and split them into chunks of 500 characters each" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "565a6d44-2c9f-4fff-b1ec-eea05df9350d", + "metadata": {}, + "outputs": [], + "source": [ + "from langchain.text_splitter import RecursiveCharacterTextSplitter\n", + "from langchain_community.document_loaders import DirectoryLoader\n", + "from langchain_nvidia_ai_endpoints import ChatNVIDIA\n", + "\n", + "loader = DirectoryLoader('./docs', glob=\"**/*.pdf\")\n", + "docs = loader.load()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "12d4a37a-35b8-471c-b6b4-dc82f66c9de3", + "metadata": {}, + "outputs": [], + "source": [ + "docs[0].page_content" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "ba8410f2-e716-4190-86ba-521b2c4e5103", + "metadata": {}, + "outputs": [], + "source": [ + "text_splitter = RecursiveCharacterTextSplitter(\n", + " chunk_size=500, chunk_overlap=100\n", + ")\n", + "doc_splits = text_splitter.split_documents(docs)" + ] + }, + { + "cell_type": "markdown", + "id": "5c6e0429-ef93-4b18-a030-6a5e64295b0e", + "metadata": {}, + "source": [ + "### Step-2: Initialize the Embedding, Reranking and LLM connectors\n", + "\n", + "#### Embedding and Reranking NIM\n", + "Use the NVIDIA OSS connectors to langchain to initialize the embedding, reranking and LLM models, after setting up the embedding and reranking NIMs locally using instructions here and here. point the ```base_url``` below to the ip address for your local machine. \n", + "\n", + "#### Llama 3.1 405B LLM\n", + "The latest Llama 3.1 405B model is hosted on ai.nvidia.com. Use the instruction here to obtain the API Key for access " + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "b1d02d99-d519-47c8-a2ae-a599a3bcb915", + "metadata": {}, + "outputs": [], + "source": [ + "from langchain_nvidia_ai_endpoints import NVIDIAEmbeddings, NVIDIARerank\n", + "from langchain_openai import ChatOpenAI\n", + "from langchain_core.prompts import ChatPromptTemplate\n", + "\n", + "# connect to an embedding NIM running at localhost:8080\n", + "embeddings = NVIDIAEmbeddings(\n", + " base_url=\"http://:8000/v1\", \n", + " model=\"nvidia/nv-embedqa-e5-v5\",\n", + " truncate=\"END\"\n", + ")\n", + "\n", + "reranker = NVIDIARerank(\n", + " base_url=\"http://:8000/v1\", \n", + " model=\"nvidia/nv-rerankqa-mistral-4b-v3\",\n", + " truncate=\"END\"\n", + ")\n", + "\n", + "llm = ChatOpenAI(\n", + " base_url=\"https://integrate.api.nvidia.com/v1\",\n", + " api_key=\"\",\n", + " model=\"meta/llama-3.1-405b-instruct\"\n", + ")" + ] + }, + { + "cell_type": "markdown", + "id": "601db4a1-64d3-4564-9f54-bd0faf40731f", + "metadata": {}, + "source": [ + "#### Step-3: Create a hybrid search retriever\n", + "\n", + "Load the documents into a keyword search store and semantic search FAISS vector database. We create a weighted hybrid of a keyword and semantic search for better retrieval recall, and a higher score is given to the keyword search retriever because of domain specific medical jargon. " + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "4778a3b3-08ab-4ce0-ad52-3a830c0aea00", + "metadata": {}, + "outputs": [], + "source": [ + "from langchain.retrievers import EnsembleRetriever\n", + "from langchain_community.retrievers import BM25Retriever\n", + "from langchain_community.vectorstores import FAISS\n", + "\n", + "bm25_retriever = BM25Retriever.from_documents(doc_splits)\n", + "faiss_vectorstore = FAISS.from_documents(doc_splits, embeddings)\n", + "\n", + "faiss_retriever = faiss_vectorstore.as_retriever(search_kwargs={\"k\": 2})\n", + "\n", + "hybrid_retriever = EnsembleRetriever(\n", + " retrievers=[bm25_retriever, faiss_retriever], weights=[0.7, 0.3]\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "df8f0d28-7871-4467-bca3-24dcca226ab6", + "metadata": {}, + "outputs": [], + "source": [ + "question = \"How does Get It Right First Time (GIRFT) Urology programme relate to TURBT and URS?\"" + ] + }, + { + "cell_type": "markdown", + "id": "40303a31-1c12-41f2-b4c5-3a35731078d5", + "metadata": {}, + "source": [ + "#### Step-4: Query decompostion with structured generation\n", + "\n", + "The new Llama 3.1 set of models can be seen as the first big push of open-source models towards serious agentic capabilities. These models can now become part of a larger automation system, with LLMs doing the planning and picking the right tools to solve a larger problem. Since NVIDIA Llama 3.1 NIMs have the necessary support for OpenAI style tool calling, libraries like LangChain can now be used with NIMs to bind LLMs to Pydantic classes and fill in objects/dictionaries. This combination makes it easier for developers to get structured outputs from NIM LLMs without having to resort to regex parsing. \n", + "\n", + "Here we user Llama 3.1 NIMs tool calling capability to split the initial query intp sub-queries" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "4aa25ae7-5624-40d9-95d8-a11a2a59a637", + "metadata": {}, + "outputs": [], + "source": [ + "from typing import Literal, Optional, Tuple, List\n", + "from langchain_core.pydantic_v1 import BaseModel, Field\n", + "\n", + "class SubQuery(BaseModel):\n", + " \"\"\"Given a user question, break it down into distinct sub questions that \\\n", + " you need to answer in order to answer the original question.\"\"\"\n", + "\n", + " questions: List[str] = Field(description=\"The list of sub questions\")\n", + "\n", + "sub_question_generator = llm.with_structured_output(SubQuery)\n", + "sub_question_generator.invoke(question)" + ] + }, + { + "cell_type": "markdown", + "id": "67594309-8649-4f3f-bcf8-b1bc8ab33699", + "metadata": {}, + "source": [ + "#### Step-5: Create a simple RAG chain with hybrid retriever" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "2030b107-618e-4dda-8e08-0d79e461aef8", + "metadata": {}, + "outputs": [], + "source": [ + "from langchain import hub\n", + "from langchain_core.output_parsers import StrOutputParser\n", + "\n", + "# Prompt\n", + "prompt = hub.pull(\"rlm/rag-prompt\")\n", + "\n", + "# Post-processing\n", + "def format_docs(docs):\n", + " return \"\\n\\n\".join(doc.page_content for doc in docs)\n", + "\n", + "# Chain\n", + "rag_chain = prompt | llm | StrOutputParser()\n", + "\n", + "# Run\n", + "docs = hybrid_retriever.get_relevant_documents(question)\n", + "generation = rag_chain.invoke({\"context\": format_docs(docs), \"question\": question})\n", + "print(generation)" + ] + }, + { + "cell_type": "markdown", + "id": "139ad15c-50a9-4edf-82dd-642536317413", + "metadata": {}, + "source": [ + "#### Step-6: Create a Retrieval grader with structured generation\n", + "\n", + "Checks if the retrieved passages/chunks are relevant to the question at hand." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "1fafad21-60cc-483e-92a3-6a7edb1838e3", + "metadata": {}, + "outputs": [], + "source": [ + "### Retrieval Grader\n", + "\n", + "# Data model\n", + "class GradeDocuments(BaseModel):\n", + " \"\"\"Binary score for relevance check on retrieved documents.\"\"\"\n", + "\n", + " binary_score: str = Field(\n", + " description=\"Documents are relevant to the question, 'yes' or 'no'\"\n", + " )\n", + "\n", + "\n", + "# LLM with function call\n", + "\n", + "retrieval_grader = llm.with_structured_output(GradeDocuments)\n", + "\n", + "# Prompt\n", + "system = \"\"\"You are a grader assessing relevance of a retrieved document to a user question. \\n \n", + " It does not need to be a stringent test. The goal is to filter out erroneous retrievals. \\n\n", + " If the document contains keyword(s) or semantic meaning related to the user question, grade it as relevant. \\n\n", + " Give a binary score 'yes' or 'no' score to indicate whether the document is relevant to the question.\"\"\"\n", + "\n", + "grade_prompt = ChatPromptTemplate.from_messages(\n", + " [\n", + " \n", + " (\"system\", system),\n", + " (\"human\", \"Retrieved document: \\n\\n {document} \\n\\n User question: {question}\"),\n", + " ]\n", + ")\n", + "\n", + "retrieval_grader = grade_prompt | retrieval_grader\n", + "docs = hybrid_retriever.get_relevant_documents(question)\n", + "doc_txt = docs[1].page_content\n", + "print(retrieval_grader.invoke({\"question\": question, \"document\": doc_txt}))" + ] + }, + { + "cell_type": "markdown", + "id": "3539a5e0-df6c-4d87-bb68-260ecd34db24", + "metadata": {}, + "source": [ + "#### Step-7: Create a hallucination checker with structured generation\n", + "Checks if the LLM generation from each chunk is relevant to the chunk. Post-generation verification is necessary since the models are not explicitly trained to follow facts from passages." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "e78931ec-940c-46ad-a0b2-f43f953f1fd7", + "metadata": {}, + "outputs": [], + "source": [ + "### Hallucination Grader\n", + "\n", + "# Data model\n", + "class GradeHallucinations(BaseModel):\n", + " \"\"\"Binary score for hallucination present in generation answer.\"\"\"\n", + "\n", + " binary_score: str = Field(\n", + " description=\"Answer is grounded in the facts, 'yes' or 'no'\"\n", + " )\n", + "\n", + "\n", + "hallucination_grader = llm.with_structured_output(GradeHallucinations)\n", + "\n", + "# Prompt\n", + "system = \"\"\"You are a grader assessing whether an LLM generation is grounded in / supported by a set of retrieved facts. \\n \n", + " Give a binary score 'yes' or 'no'. 'Yes' means that the answer is grounded in / supported by the set of facts.\"\"\"\n", + "hallucination_prompt = ChatPromptTemplate.from_messages(\n", + " [\n", + " (\"system\", system),\n", + " (\"human\", \"Set of facts: \\n\\n {documents} \\n\\n LLM generation: {generation}\"),\n", + " ]\n", + ")\n", + "\n", + "hallucination_grader = hallucination_prompt | hallucination_grader\n", + "hallucination_grader.invoke({\"documents\": format_docs(docs), \"generation\": generation})" + ] + }, + { + "cell_type": "markdown", + "id": "2931a1e2-a4d8-4708-90d1-34c48fc11dd2", + "metadata": {}, + "source": [ + "#### Step-7: Create a answer grader with structured generation\n", + "Checks if the final answer resolves the supplied question " + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "bd62276f-bf26-40d0-8cff-e07b10e00321", + "metadata": {}, + "outputs": [], + "source": [ + "### Answer Grader\n", + "\n", + "# Data model\n", + "class GradeAnswer(BaseModel):\n", + " \"\"\"Binary score to assess answer addresses question.\"\"\"\n", + "\n", + " binary_score: str = Field(\n", + " description=\"Answer addresses the question, 'yes' or 'no'\"\n", + " )\n", + "\n", + "\n", + "generation_grader = llm.with_structured_output(GradeAnswer)\n", + "\n", + "# Prompt\n", + "system = \"\"\"You are a grader assessing whether an answer addresses / resolves a question \\n \n", + " Give a binary score 'yes' or 'no'. Yes' means that the answer resolves the question.\"\"\"\n", + "answer_prompt = ChatPromptTemplate.from_messages(\n", + " [\n", + " (\"system\", system),\n", + " (\"human\", \"User question: \\n\\n {question} \\n\\n LLM generation: {generation}\"),\n", + " ]\n", + ")\n", + "\n", + "answer_grader = answer_prompt | generation_grader\n", + "answer_grader.invoke({\"question\": question, \"generation\": generation})" + ] + }, + { + "cell_type": "markdown", + "id": "f5c2b6b7-4077-4257-8c4d-f88ea6b25d7f", + "metadata": {}, + "source": [ + "#### Step-8: Question rewriting\n", + "If none of retrieved documents are unrelated to the given question, then we ask the LLM to rewrite the question again for easier retrieval. " + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "c6f4c70e-1660-4149-82c0-837f19fc9fb5", + "metadata": {}, + "outputs": [], + "source": [ + "### Question Re-writer\n", + "\n", + "# Prompt\n", + "system = \"\"\"You a question re-writer that converts an input question to a better version that is optimized \\n \n", + " for vectorstore retrieval. Look at the input and try to reason about the underlying semantic intent / meaning.\"\"\"\n", + "re_write_prompt = ChatPromptTemplate.from_messages(\n", + " [\n", + " (\"system\", system),\n", + " (\n", + " \"human\",\n", + " \"Here is the initial question: \\n\\n {question} \\n Formulate an improved question.\",\n", + " ),\n", + " ]\n", + ")\n", + "\n", + "question_rewriter = re_write_prompt | llm | StrOutputParser()\n", + "question_rewriter.invoke({\"question\": question})" + ] + }, + { + "cell_type": "markdown", + "id": "276001c5-c079-4e5b-9f42-81a06704d200", + "metadata": {}, + "source": [ + "#### Step-9: Langgraph setup \n", + "\n", + "Capture the flow in as a graph. Define the graph state, which is a data structure that is shared among the nodes of the graph, each node modifies the graph state depending on its function. " + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "f1617e9e-66a8-4c1a-a1fe-cc936284c085", + "metadata": {}, + "outputs": [], + "source": [ + "from typing import List\n", + "\n", + "from typing_extensions import TypedDict\n", + "\n", + "\n", + "class GraphState(TypedDict):\n", + " \"\"\"\n", + " Represents the state of our graph.\n", + "\n", + " Attributes:\n", + " question: question\n", + " generation: LLM generation\n", + " documents: list of documents\n", + " \"\"\"\n", + "\n", + " question: str\n", + " sub_questions: List[str]\n", + " generation: str\n", + " documents: List[str]" + ] + }, + { + "cell_type": "markdown", + "id": "65428bef-1159-4725-8c1a-33c5ea1dd2bf", + "metadata": {}, + "source": [ + "#### Step-10: Define the nodes as functions\n", + "Using the langchain constructs we have defined above for query decompostion, grading, retrieval, hallucination checking etc, we can write functions that act as nodes for the multi-agent graph." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "add509d8-6682-4127-8d95-13dd37d79702", + "metadata": {}, + "outputs": [], + "source": [ + "### Nodes\n", + "\n", + "def decompose(state):\n", + " \"\"\"\n", + " Retrieve documents\n", + "\n", + " Args:\n", + " state (dict): The current graph state\n", + "\n", + " Returns:\n", + " state (dict): New key added to state, documents, that contains retrieved documents\n", + " \"\"\"\n", + " print(\"---QUERY DECOMPOSITION ---\")\n", + " question = state[\"question\"]\n", + "\n", + " # Reranking\n", + " sub_queries = sub_question_generator.invoke(question)\n", + " return {\"sub_questions\": sub_queries.questions, \"question\": question}\n", + "\n", + "def retrieve(state):\n", + " \"\"\"\n", + " Retrieve documents\n", + "\n", + " Args:\n", + " state (dict): The current graph state\n", + "\n", + " Returns:\n", + " state (dict): New key added to state, documents, that contains retrieved documents\n", + " \"\"\"\n", + " print(\"---RETRIEVE---\")\n", + " sub_questions = state[\"sub_questions\"]\n", + "\n", + " # Retrieval\n", + " documents = []\n", + " for question in sub_questions:\n", + " docs = hybrid_retriever.get_relevant_documents(question)\n", + " documents.extend(docs)\n", + " return {\"documents\": documents, \"question\": question}\n", + "\n", + "\n", + "def rerank(state):\n", + " \"\"\"\n", + " Retrieve documents\n", + "\n", + " Args:\n", + " state (dict): The current graph state\n", + "\n", + " Returns:\n", + " state (dict): New key added to state, documents, that contains retrieved documents\n", + " \"\"\"\n", + " print(\"---RERANK---\")\n", + " question = state[\"question\"]\n", + " documents = state[\"documents\"]\n", + "\n", + " # Reranking\n", + " documents = reranker.compress_documents(query=question, documents=documents)\n", + " return {\"documents\": documents, \"question\": question}" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "3c2e6458-ce81-4a18-8e4b-1d395b0971fc", + "metadata": {}, + "outputs": [], + "source": [ + "def generate(state):\n", + " \"\"\"\n", + " Generate answer\n", + "\n", + " Args:\n", + " state (dict): The current graph state\n", + "\n", + " Returns:\n", + " state (dict): New key added to state, generation, that contains LLM generation\n", + " \"\"\"\n", + " print(\"---GENERATE---\")\n", + " question = state[\"question\"]\n", + " documents = state[\"documents\"]\n", + "\n", + " # RAG generation\n", + " generation = rag_chain.invoke({\"context\": documents, \"question\": question})\n", + " return {\"documents\": documents, \"question\": question, \"generation\": generation}\n", + "\n", + "\n", + "def grade_documents(state):\n", + " \"\"\"\n", + " Determines whether the retrieved documents are relevant to the question.\n", + "\n", + " Args:\n", + " state (dict): The current graph state\n", + "\n", + " Returns:\n", + " state (dict): Updates documents key with only filtered relevant documents\n", + " \"\"\"\n", + "\n", + " print(\"---CHECK DOCUMENT RELEVANCE TO QUESTION---\")\n", + " question = state[\"question\"]\n", + " documents = state[\"documents\"]\n", + "\n", + " # Score each doc\n", + " filtered_docs = []\n", + " for d in documents:\n", + " score = retrieval_grader.invoke(\n", + " {\"question\": question, \"document\": d.page_content}\n", + " )\n", + " grade = score.binary_score\n", + " if grade == \"yes\":\n", + " print(\"---GRADE: DOCUMENT RELEVANT---\")\n", + " filtered_docs.append(d)\n", + " else:\n", + " print(\"---GRADE: DOCUMENT NOT RELEVANT---\")\n", + " continue\n", + " return {\"documents\": filtered_docs, \"question\": question}\n", + "\n", + "\n", + "def transform_query(state):\n", + " \"\"\"\n", + " Transform the query to produce a better question.\n", + "\n", + " Args:\n", + " state (dict): The current graph state\n", + "\n", + " Returns:\n", + " state (dict): Updates question key with a re-phrased question\n", + " \"\"\"\n", + "\n", + " print(\"---TRANSFORM QUERY---\")\n", + " question = state[\"question\"]\n", + " documents = state[\"documents\"]\n", + "\n", + " # Re-write question\n", + " better_question = question_rewriter.invoke({\"question\": question})\n", + " return {\"documents\": documents, \"question\": better_question}" + ] + }, + { + "cell_type": "markdown", + "id": "90f319e2-0f42-47bd-bba0-b732c05c5506", + "metadata": {}, + "source": [ + "#### Step-11: Define graph edges \n", + " The nodes defined above are connected to each other through functional edges, defined programatically. Based on the graph state the edges might pass the state information to one of the multiple different nodes." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "9de8ad05-5461-428f-bb06-48bb7328c02d", + "metadata": {}, + "outputs": [], + "source": [ + "### Edges\n", + "\n", + "\n", + "def decide_to_generate(state):\n", + " \"\"\"\n", + " Determines whether to generate an answer, or re-generate a question.\n", + "\n", + " Args:\n", + " state (dict): The current graph state\n", + "\n", + " Returns:\n", + " str: Binary decision for next node to call\n", + " \"\"\"\n", + "\n", + " print(\"---ASSESS GRADED DOCUMENTS---\")\n", + " state[\"question\"]\n", + " filtered_documents = state[\"documents\"]\n", + "\n", + " if not filtered_documents:\n", + " # All documents have been filtered check_relevance\n", + " # We will re-generate a new query\n", + " print(\n", + " \"---DECISION: ALL DOCUMENTS ARE NOT RELEVANT TO QUESTION, TRANSFORM QUERY---\"\n", + " )\n", + " return \"transform_query\"\n", + " # We have relevant documents, so generate answer\n", + " print(\"---DECISION: GENERATE---\")\n", + " return \"generate\"\n", + " \n", + "def grade_generation_v_documents_and_question(state):\n", + " \"\"\"\n", + " Determines whether the generation is grounded in the document and answers question.\n", + "\n", + " Args:\n", + " state (dict): The current graph state\n", + "\n", + " Returns:\n", + " str: Decision for next node to call\n", + " \"\"\"\n", + "\n", + " print(\"---CHECK HALLUCINATIONS---\")\n", + " question = state[\"question\"]\n", + " documents = state[\"documents\"]\n", + " generation = state[\"generation\"]\n", + "\n", + " score = hallucination_grader.invoke(\n", + " {\"documents\": documents, \"generation\": generation}\n", + " )\n", + " grade = score.binary_score\n", + "\n", + " # Check hallucination\n", + " if grade == \"yes\":\n", + " print(\"---DECISION: GENERATION IS GROUNDED IN DOCUMENTS---\")\n", + " # Check question-answering\n", + " print(\"---GRADE GENERATION vs QUESTION---\")\n", + " score = answer_grader.invoke({\"question\": question, \"generation\": generation})\n", + " grade = score.binary_score\n", + " if grade == \"yes\":\n", + " print(\"---DECISION: GENERATION ADDRESSES QUESTION---\")\n", + " return \"useful\"\n", + " print(\"---DECISION: GENERATION DOES NOT ADDRESS QUESTION---\")\n", + " return \"not useful\"\n", + " pprint(\"---DECISION: GENERATION IS NOT GROUNDED IN DOCUMENTS, RE-TRY---\")\n", + " return \"not supported\"" + ] + }, + { + "cell_type": "markdown", + "id": "61cd5797-1782-4d78-a277-8196d13f3e1b", + "metadata": {}, + "source": [ + "#### Step-12: Build the graph\n", + "\n", + "We define the rules for how the nodes are connected to each other, we also use conditional edges, which can connect to different nodes based on the output of the functional edge" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "0e09ca9f-e36d-4ef4-a0d5-79fdbada9fe0", + "metadata": {}, + "outputs": [], + "source": [ + "from langgraph.graph import END, StateGraph, START\n", + "\n", + "workflow = StateGraph(GraphState)\n", + "\n", + "# Define the nodes\n", + "workflow.add_node(\"decompose\", decompose) #query decompostion\n", + "workflow.add_node(\"retrieve\", retrieve) # retrieve\n", + "workflow.add_node(\"rerank\", rerank) # rerank\n", + "workflow.add_node(\"grade_documents\", grade_documents) # grade documents\n", + "workflow.add_node(\"generate\", generate) # generatae\n", + "workflow.add_node(\"transform_query\", transform_query) # transform_query\n", + "\n", + "# Build graph\n", + "workflow.add_edge(START, \"decompose\")\n", + "workflow.add_edge(\"decompose\", \"retrieve\")\n", + "workflow.add_edge(\"retrieve\", \"rerank\")\n", + "workflow.add_edge(\"rerank\", \"grade_documents\")\n", + "workflow.add_conditional_edges(\n", + " \"grade_documents\",\n", + " decide_to_generate,\n", + " {\n", + " \"transform_query\": \"transform_query\",\n", + " \"generate\": \"generate\",\n", + " },\n", + ")\n", + "workflow.add_edge(\"transform_query\", \"retrieve\")\n", + "workflow.add_conditional_edges(\n", + " \"generate\",\n", + " grade_generation_v_documents_and_question,\n", + " {\n", + " \"not supported\": \"generate\",\n", + " \"useful\": END,\n", + " \"not useful\": \"transform_query\",\n", + " },\n", + ")\n", + "\n", + "# Compile\n", + "app = workflow.compile()" + ] + }, + { + "cell_type": "markdown", + "id": "cfa86db1-f484-4d61-8033-fa03e0c665d0", + "metadata": {}, + "source": [ + "#### Step-13: Run the multi-agent RAG workflow" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "7aea1890-37f8-46d6-8d10-34d744f820fe", + "metadata": {}, + "outputs": [], + "source": [ + "from pprint import pprint\n", + "\n", + "# Run\n", + "inputs = {\"question\": question}\n", + "for output in app.stream(inputs):\n", + " for key, value in output.items():\n", + " # Node\n", + " pprint(f\"Node '{key}':\")\n", + " # Optional: print full state at each node\n", + " # pprint.pprint(value[\"keys\"], indent=2, width=80, depth=None)\n", + " pprint(\"\\n---\\n\")\n", + "\n", + "# Final generation\n", + "pprint(value[\"generation\"])" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "394255a0-4c16-429a-941f-e5ca58ae969a", + "metadata": {}, + "outputs": [], + "source": [] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3 (ipykernel)", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.10.12" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +}