diff --git a/experimental/README.md b/experimental/README.md
index e7e5279f4..3d271ed06 100644
--- a/experimental/README.md
+++ b/experimental/README.md
@@ -43,6 +43,10 @@ Experimental examples are sample code and deployments for RAG pipelines that are
 This example is able to ingest PDFs, PowerPoint slides, Word and other documents with complex data formats including text, images, slides and tables. It allows users to ask questions through a text interface and optionally with an image query, and it can respond with text and reference images, slides and tables in its response, along with source links and downloads.
+* [NVIDIA Knowledge Graph RAG](./knowledge_graph_rag)
+
+  This example implements a GPU-accelerated pipeline for creating and querying knowledge graphs with Retrieval-Augmented Generation (RAG). It leverages NVIDIA AI technologies and the RAPIDS ecosystem to process large-scale datasets efficiently. Users can interact with it through a chat interface, visualize the resulting knowledge graph, and run evaluations against synthetic data generated with NVIDIA's Nemotron-4 340B model.
+
 * [Run RAG-LLM in Azure Machine Learning](./AzureML)

 This example shows the configuration changes to using Docker containers and local GPUs that are required
diff --git a/experimental/knowledge_graph_rag/README.md b/experimental/knowledge_graph_rag/README.md
new file mode 100644
index 000000000..2ec888009
--- /dev/null
+++ b/experimental/knowledge_graph_rag/README.md
@@ -0,0 +1,170 @@
+# Knowledge Graphs for RAG with NVIDIA AI Foundation Models and Endpoints
+
+This repository implements a GPU-accelerated pipeline for creating and querying knowledge graphs using Retrieval-Augmented Generation (RAG). Our approach leverages NVIDIA's AI technologies and the RAPIDS ecosystem to process large-scale datasets efficiently.
+
+## Overview
+
+This project demonstrates:
+- Creation of knowledge graphs from various document sources
+- A simple script for downloading research papers on a given topic from arXiv
+- GPU-accelerated graph processing and analysis using [cuGraph](https://github.com/rapidsai/cugraph), NVIDIA's RAPIDS Graph Analytics library
+- Hybrid semantic search combining keyword and dense vector approaches
+- Integration of knowledge graphs into RAG workflows
+- Visualization of the knowledge graph through [Gephi-Lite](https://github.com/gephi/gephi-lite), an open-source web app for visualizing large graphs
+- Comprehensive evaluation metrics using NVIDIA's Nemotron-4 340B model for synthetic data generation and reward scoring
+
+## Technologies Used
+
+- **Frontend**: Streamlit
+- **Graph Representation and Optimization**: cuGraph (RAPIDS), NetworkX
+- **Vector Database**: Milvus
+- **LLMs**:
+  - NVIDIA AI Playground hosted models for graph creation and querying, providing numerous instruct-fine-tuned options
+  - NVIDIA AI Playground hosted Nemotron-4 340B model for synthetic data generation and evaluation reward scoring
+
+## Architecture Diagram
+
+The ingestion system leverages a high-throughput hosted LLM deployment that can process multiple document chunks in parallel. The LLM can optionally be fine-tuned for triple extraction, which shortens the prompt and enables greater accuracy and optimized inference.
+
+```mermaid
+graph TD
+    A[Document Collection] --> B{Document Splitter}
+    B --> |Chunk 1| C1[LLM Stream 1]
+    B --> |Chunk 2| C2[LLM Stream 2]
+    B --> |Chunk 3| C3[LLM Stream 3]
+    B --> |...| C4[...]
+ B --> |Chunk N| C5[LLM Stream N] + C1 --> D[Response Parser
and Aggregator] + C2 --> D + C3 --> D + C4 --> D + C5 --> D + D --> E[GraphML Generator] + E --> F[Single GraphML File] +``` + +Here's how the inference system is designed, incorporating both hybrid dense-vector search and sparse keyword-based search, reranking, and Knowledge Graph for multi-hop search: + +```mermaid +graph LR + E(User Query) --> A(FRONTEND
Chat UI
Streamlit) + A --Dense-Sparse
Retrieval--> C(Milvus Vector DB) + A --Multi-hop
Search--> F(Knowledge Graph
with cuGraph) + C --Hybrid
Chunks--> X(Reranker) + X -- Augmented
Prompt--> B((Hosted LLM API
NVIDIA AI Playground)) + F -- Graph Context
Triples--> B + B --> D(Streaming
Chat Response)
+```
+
+This architecture shows how the user query is processed through both the Milvus vector DB for traditional retrieval and the cuGraph-backed knowledge graph for multi-hop search. The results from both are then used to augment the prompt sent to the NVIDIA AI Playground backend.
+
+## Setup Steps
+
+Follow these steps to get the chatbot up and running in less than 5 minutes:
+
+### 1. Clone this repository to a Linux machine
+
+```bash
+git clone https://github.com/NVIDIA/GenerativeAIExamples/ && cd GenerativeAIExamples/experimental/knowledge_graph_rag
+```
+
+### 2. Get an NVIDIA AI Playground API Key
+
+```bash
+export NVIDIA_API_KEY="nvapi-*******************"
+```
+
+If you don't have an API key, follow [these instructions](https://github.com/NVIDIA/GenerativeAIExamples/blob/main/docs/api-catalog.md#get-an-api-key-for-the-accessing-models-on-the-api-catalog) to sign up for an NVIDIA AI Foundation developer account and obtain access.
+
+### 3. Create a Python virtual environment and activate it
+
+```bash
+cd knowledge_graph_rag
+pip install virtualenv
+python3 -m virtualenv venv
+source venv/bin/activate
+```
+
+### 4. Install the required packages
+
+```bash
+pip install -r requirements.txt
+```
+
+### 5. Set up a hosted Milvus vector database
+
+Follow the instructions [here](https://milvus.io/docs/install_standalone-docker.md) to deploy a hosted Milvus instance for the vector database backend. Note that it must be Milvus 2.4 or later to support [hybrid search](https://milvus.io/docs/multi-vector-search.md); we do not currently support disabling hybrid search for older Milvus versions.
+
+### 6. Launch the Streamlit frontend
+
+```bash
+streamlit run app.py
+```
+
+Open the URL in your browser to access the UI and chatbot!
+
+### 7. Upload documents and create the knowledge graph
+
+Upload your own documents to a folder, or use an existing folder for knowledge graph creation. Note that the implementation currently extracts text from PDFs only; it can be extended to other text file formats using the Unstructured.io data loader in LangChain.
+
+## Pipeline Components
+
+1. **Data Ingestion**:
+   - arXiv paper downloader
+   - Arbitrary document folder ingestion
+2. **Knowledge Graph Creation**:
+   - Uses the API Catalog models through the LangChain NVIDIA AI Endpoints interface
+3. **Graph Representation**: cuGraph + RAPIDS + NetworkX
+4. **Semantic Search**: Milvus 2.4.x for hybrid (keyword + dense vector) search
+5. **RAG Integration**: Custom workflow incorporating knowledge graph retrieval
+6. **Evaluation**: Comparison of different RAG approaches using the Nemotron-4 340B model
+
+## Evaluation Metrics
+
+We've implemented comprehensive evaluation metrics using NVIDIA's Nemotron-4 340B model, which is designed for synthetic data generation and reward scoring. Our evaluation compares different RAG approaches across five key attributes:
+
+1. **Helpfulness**: Overall helpfulness of the response to the prompt.
+2. **Correctness**: Inclusion of all pertinent facts without errors.
+3. **Coherence**: Consistency and clarity of expression.
+4. **Complexity**: Intellectual depth required to write the response.
+5. **Verbosity**: Amount of detail included in the response, relative to what is asked for in the prompt.
+
+## Evaluation Results
+
+We compared four RAG approaches on a small representative dataset using the Nemotron-4 340B reward model:
+
+![Evaluation Results](viz.png)
+
+Key takeaways:
+- Graph RAG significantly outperforms traditional Text RAG.
+- Combined Text and Graph RAG shows promise but doesn't consistently beat the ground truth yet; this may be due to the way we structure the augmented prompt for the LLM, and it needs more experimentation.
+- Our approach improves on verbosity and coherence compared to ground truth.
+
+While we're not beating long-context ground truth across the board, these results show the potential of integrating knowledge graphs into RAG systems. We're particularly excited about the improvements in verbosity and coherence. Next steps include refining how we combine text and graph retrieval to get the best of both worlds.
+
+## Component Swapping
+
+All components are designed to be swappable. Here are some options:
+
+- **Frontend**: The current Streamlit implementation can be replaced with other web frameworks.
+- **Retrieval**: The embedding and reranker models used for semantic search can be swapped for higher-performance alternatives. The number of entities retrieved prior to reranking and the document chunk size can also be changed.
+- **Vector DB**: While we use Milvus, it can be replaced with options like ChromaDB, Pinecone, FAISS, etc. Milvus is designed to be highly performant and to scale on GPU infrastructure.
+- **Backend**:
+  - Cloud Hosted: Currently uses NVIDIA AI Playground APIs, but can be deployed in a private DGX Cloud or on AWS/Azure/GCP with NVIDIA GPUs and LLMs.
+  - On-Prem/Locally Hosted: Smaller models like Llama2-7B or Mistral-7B can be run locally with appropriate hardware. A model can also be fine-tuned specifically for triple extraction in a given use case.
+
+## Future Work
+
+- Dynamic information incorporation into knowledge graphs (continuous updates)
+- Further refinement of evaluation metrics and the combined semantic + graph RAG pipeline
+- Investigating the impact of different graph structures and queries on RAG performance (single/multi-hop retrieval, BFS/DFS, etc.)
+- Expanding support for various document types and formats (multimodal RAG with knowledge graphs)
+- Fine-tuning the Nemotron-4 340B model for domain-specific evaluations
+
+## Contributing
+
+Please open a pull request against this repository; our team appreciates any and all contributions that add features. We will review and respond as soon as possible.
+
+## Acknowledgements
+
+This project utilizes NVIDIA's AI technologies, including the Nemotron-4 340B model, and the RAPIDS ecosystem. We thank the open-source community for their invaluable contributions to the tools and libraries used in this project.
\ No newline at end of file
diff --git a/experimental/knowledge_graph_rag/app.py b/experimental/knowledge_graph_rag/app.py
new file mode 100644
index 000000000..e19edc817
--- /dev/null
+++ b/experimental/knowledge_graph_rag/app.py
@@ -0,0 +1,138 @@
+# SPDX-FileCopyrightText: Copyright (c) 2023-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-License-Identifier: Apache-2.0
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
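+
+# Streamlit entry point for building the knowledge graph. High-level flow,
+# mirroring the code below:
+#   1. The user picks a directory of PDFs and an instruct model in the sidebar.
+#   2. process_documents() chunks the PDFs and uses the LLM to extract
+#      (subject, subject_type, relation, object, object_type) triples.
+#   3. The chunks are indexed into Milvus for hybrid search via SearchHandler.
+#   4. The triples are normalized into entities.csv / relations.csv / triples.csv,
+#      rebuilt as a directed NetworkX graph, and saved to knowledge_graph.graphml.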
+ +import os +import streamlit as st +from llama_index.core import SimpleDirectoryReader, KnowledgeGraphIndex +from utils.preprocessor import extract_triples +from llama_index.core import ServiceContext +import multiprocessing +import pandas as pd +import networkx as nx +from utils.lc_graph import process_documents, save_triples_to_csvs +from vectorstore.search import SearchHandler +from langchain_nvidia_ai_endpoints import ChatNVIDIA + +def load_data(input_dir, num_workers): + reader = SimpleDirectoryReader(input_dir=input_dir) + documents = reader.load_data(num_workers=num_workers) + return documents + +def has_pdf_files(directory): + for file in os.listdir(directory): + if file.endswith(".pdf"): + return True + return False + +st.title("Knowledge Graph RAG") + +st.subheader("Load Data from Files") + +# Variable for documents +if 'documents' not in st.session_state: + st.session_state['documents'] = None + +models = ChatNVIDIA.get_available_models() +available_models = [model.id for model in models if model.model_type=="chat" and "instruct" in model.id] +with st.sidebar: + llm = st.selectbox("Choose an LLM", available_models, index=available_models.index("mistralai/mixtral-8x7b-instruct-v0.1")) + st.write("You selected: ", llm) + llm = ChatNVIDIA(model=llm) + +def app(): + # Get the current working directory + cwd = os.getcwd() + + # Get a list of visible directories in the current working directory + directories = [d for d in os.listdir(cwd) if os.path.isdir(os.path.join(cwd, d)) and not d.startswith('.') and '__' not in d] + + # Create a dropdown menu for directory selection + selected_dir = st.selectbox("Select a directory:", directories, index=0) + + # Construct the full path of the selected directory + directory = os.path.join(cwd, selected_dir) + + if st.button("Process Documents"): + # Check if the selected directory has PDF files + res = has_pdf_files(directory) + if not res: + st.error("No PDF files found in directory! Only PDF files and text extraction are supported for now.") + st.stop() + documents, results = process_documents(directory, llm) + print(documents) + st.write(documents) + search_handler = SearchHandler("hybrid_demo3", use_bge_m3=True, use_reranker=True) + search_handler.insert_data(documents) + st.write(f"Processing complete. 
Total triples extracted: {len(results)}") + + with st.spinner("Saving triples to CSV files with Pandas..."): + # write the resulting entities to a CSV, relations to a CSV and all triples with IDs to a CSV + save_triples_to_csvs(results) + + with st.spinner("Loading the CSVs into dataframes..."): + # Load the triples from the CSV file + triples_df = pd.read_csv("triples.csv") + # Load the entities and relations DataFrames + entities_df = pd.read_csv("entities.csv") + relations_df = pd.read_csv("relations.csv") + + # with st.spinner("Creating the knowledge graph from these triples..."): + # Create a mapping from IDs to entity names and relation names + entity_name_map = entities_df.set_index("entity_id")["entity_name"].to_dict() + relation_name_map = relations_df.set_index("relation_id")["relation_name"].to_dict() + + # Create the graph from the triples DataFrame + G = nx.from_pandas_edgelist( + triples_df, + source="entity_id_1", + target="entity_id_2", + edge_attr="relation_id", + create_using=nx.DiGraph, + ) + + with st.spinner("Relabeling node integers to strings for future retrieval..."): + # Relabel the nodes with the actual entity names + G = nx.relabel_nodes(G, entity_name_map) + + # Relabel the edges with the actual relation names + edge_attributes = nx.get_edge_attributes(G, "relation_id") + + # Update the edges with the new relation names + new_edge_attributes = { + (u, v): relation_name_map[edge_attributes[(u, v)]] + for u, v in G.edges() + if edge_attributes[(u, v)] in relation_name_map + } + + nx.set_edge_attributes(G, new_edge_attributes, "relation") + + with st.spinner("Saving the graph to a GraphML file for further visualization and retrieval..."): + try: + nx.write_graphml(G, "knowledge_graph.graphml") + + # Verify by reading it back + G_loaded = nx.read_graphml("knowledge_graph.graphml") + if nx.is_directed(G_loaded): + st.success("GraphML file is valid and successfully loaded.") + else: + st.error("GraphML file is invalid.") + except Exception as e: + st.error(f"Error saving or loading GraphML file: {e}") + return + + st.success("Done!") + +if __name__ == "__main__": + app() diff --git a/experimental/knowledge_graph_rag/pages/chat.py b/experimental/knowledge_graph_rag/pages/chat.py new file mode 100644 index 000000000..8e0225361 --- /dev/null +++ b/experimental/knowledge_graph_rag/pages/chat.py @@ -0,0 +1,105 @@ +# SPDX-FileCopyrightText: Copyright (c) 2023-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
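+
+# Chat page: lets the user query the knowledge graph built by app.py. For each
+# message, when "Use knowledge graph" is toggled on, the prompt is augmented with
+# (a) hybrid sparse+dense passages retrieved from Milvus and reranked, and
+# (b) triples gathered by a depth-2 graph traversal from each entity the LLM
+# extracts out of the query. The augmented prompt is then streamed through the
+# selected NVIDIA API Catalog chat model.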
+
+from langchain.chains import GraphQAChain
+from langchain_community.graphs.networkx_graph import NetworkxEntityGraph, get_entities
+from langchain_nvidia_ai_endpoints import ChatNVIDIA
+
+import streamlit as st
+import json
+import networkx as nx
+st.set_page_config(layout="wide")
+
+from langchain_community.callbacks.streamlit import StreamlitCallbackHandler
+
+from langchain_core.output_parsers import StrOutputParser
+from langchain_core.prompts import ChatPromptTemplate
+
+from vectorstore.search import SearchHandler
+
+# Load the previously generated knowledge graph and wrap it for entity lookups
+G = nx.read_graphml("knowledge_graph.graphml")
+graph = NetworkxEntityGraph(G)
+
+models = ChatNVIDIA.get_available_models()
+available_models = [model.id for model in models if model.model_type == "chat" and "instruct" in model.id]
+
+with st.sidebar:
+    llm = st.selectbox("Choose an LLM", available_models, index=available_models.index("mistralai/mixtral-8x7b-instruct-v0.1"))
+    st.write("You selected: ", llm)
+    llm = ChatNVIDIA(model=llm)
+
+st.subheader("Chat with your knowledge graph!")
+
+if "messages" not in st.session_state:
+    st.session_state.messages = []
+
+for message in st.session_state.messages:
+    with st.chat_message(message["role"]):
+        st.markdown(message["content"])
+
+with st.sidebar:
+    use_kg = st.toggle("Use knowledge graph")
+
+user_input = st.chat_input("Can you tell me how research helps users to solve problems?")
+
+graph_chain = GraphQAChain.from_llm(llm=llm, graph=graph, verbose=True)
+
+prompt_template = ChatPromptTemplate.from_messages(
+    [("system", "You are a helpful AI assistant named Envie. You will reply to questions only based on the context that you are provided. If something is out of context, you will refrain from replying and politely decline to respond to the user."), ("user", "{input}")]
+)
+
+chain = prompt_template | llm | StrOutputParser()
+search_handler = SearchHandler("hybrid_demo3", use_bge_m3=True, use_reranker=True)
+
+if user_input:
+    st.session_state.messages.append({"role": "user", "content": user_input})
+    with st.chat_message("user"):
+        st.markdown(user_input)
+
+    with st.chat_message("assistant"):
+        if use_kg:
+            entity_string = llm.invoke("""Return a JSON with a single key 'entities' and a list of entities within this user query. Each element in your list MUST BE part of the user's query. Do not provide any explanation. If the returned JSON is not parseable in Python, you will be heavily penalized. For example, input: 'What is the difference between Apple and Google?' output: {"entities": ["Apple", "Google"]}. Always follow this output format. Here's the user query: """ + user_input)
+            try:
+                entities = json.loads(entity_string.content)['entities']
+                with st.expander("Extracted entities"):
+                    st.code(entities)
+                res = search_handler.search_and_rerank(user_input, k=5)
+                with st.expander("Retrieved and reranked sparse-dense hybrid search results"):
+                    st.write(res)
+                context = "Here are the relevant passages from the knowledge base: \n\n" + "\n".join(item.text for item in res)
+                all_triplets = []
+                for entity in entities:
+                    all_triplets.extend(graph_chain.graph.get_entity_knowledge(entity, depth=2))
+                context += "\n\nHere are the relationships from the knowledge graph: " + "\n".join(all_triplets)
+                with st.expander("All triplets"):
+                    st.code(context)
+            except Exception as e:
+                st.write("Faced exception: ", e)
+                context = "No graph triples were available to extract from the knowledge graph. Always provide a disclaimer if you know the answer to the user's question, since it is not grounded in the knowledge you are provided from the graph."
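+            # At this point `context` holds the reranked passages plus any graph
+            # triples, or, if extraction failed, a fallback instruction telling
+            # the model to add a disclaimer since the answer is not graph-grounded.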
+ message_placeholder = st.empty() + full_response = "" + + for response in chain.stream("Context: " + context + "\n\nUser query: " + user_input): + full_response += response + message_placeholder.markdown(full_response + "▌") + else: + message_placeholder = st.empty() + full_response = "" + for response in chain.stream(user_input): + full_response += response + message_placeholder.markdown(full_response + "▌") + message_placeholder.markdown(full_response) + + st.session_state.messages.append({"role": "assistant", "content": full_response}) diff --git a/experimental/knowledge_graph_rag/pages/evaluation.py b/experimental/knowledge_graph_rag/pages/evaluation.py new file mode 100644 index 000000000..60b63031e --- /dev/null +++ b/experimental/knowledge_graph_rag/pages/evaluation.py @@ -0,0 +1,265 @@ +# SPDX-FileCopyrightText: Copyright (c) 2023-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os +import streamlit as st +from llama_index.core import SimpleDirectoryReader, KnowledgeGraphIndex +from utils.preprocessor import generate_qa_pair +from llama_index.core import ServiceContext +import multiprocessing +import pandas as pd +import networkx as nx +from utils.lc_graph import process_documents, save_triples_to_csvs +from vectorstore.search import SearchHandler +from langchain_nvidia_ai_endpoints import ChatNVIDIA +import random +import pandas as pd +import time +import json +from langchain_core.output_parsers import StrOutputParser +from langchain_community.graphs.networkx_graph import NetworkxEntityGraph, get_entities +from langchain_core.prompts import ChatPromptTemplate + +from vectorstore.search import SearchHandler + +from concurrent.futures import ThreadPoolExecutor, as_completed + +from openai import OpenAI +reward_client = OpenAI( + base_url = "https://integrate.api.nvidia.com/v1", + api_key = os.environ["NVIDIA_API_KEY"] +) + +def get_reward_scores(question, answer): + completion = reward_client.chat.completions.create( + model="nvidia/nemotron-4-340b-reward", + messages=[{"role": "user", "content":question}, {"role":"assistant", "content":answer}] + ) + try: + content = completion.choices[0].message[0].content + res = content.split(",") + content_dict = {} + for item in res: + name, val = item.split(":") + content_dict[name] = float(val) + return content_dict + except: + return None + +def process_question(question, answer): + with ThreadPoolExecutor(max_workers=3) as executor: + future_text = executor.submit(get_text_RAG_response, question) + future_graph = executor.submit(get_graph_RAG_response, question) + future_combined = executor.submit(get_combined_RAG_response, question) + + text_RAG_response = future_text.result() + graph_RAG_response = future_graph.result() + combined_RAG_response = future_combined.result() + + return { + "question": question, + "gt_answer": answer, + "textRAG_answer": text_RAG_response, + "graphRAG_answer": graph_RAG_response, + "combined_answer": 
combined_RAG_response
+    }
+
+prompt_template = ChatPromptTemplate.from_messages(
+    [("system", "You are a helpful AI assistant named Envie. You will reply to questions only based on the context that you are provided. If something is out of context, you will refrain from replying and politely decline to respond to the user."), ("user", "{input}")]
+)
+
+def load_data(input_dir, num_workers):
+    reader = SimpleDirectoryReader(input_dir=input_dir)
+    documents = reader.load_data(num_workers=num_workers)
+    return documents
+
+def has_pdf_files(directory):
+    for file in os.listdir(directory):
+        if file.endswith(".pdf"):
+            return True
+    return False
+
+def get_text_RAG_response(question):
+    chain = prompt_template | llm | StrOutputParser()
+
+    search_handler = SearchHandler("hybrid_demo3", use_bge_m3=True, use_reranker=True)
+    res = search_handler.search_and_rerank(question, k=5)
+    context = "Here are the relevant passages from the knowledge base: \n\n" + "\n".join(item.text for item in res)
+    answer = chain.invoke("Context: " + context + "\n\nUser query: " + question)
+    return answer
+
+def get_graph_RAG_response(question):
+    chain = prompt_template | llm | StrOutputParser()
+    entity_string = llm.invoke("""Return a JSON with a single key 'entities' and a list of entities within this user query. Each element in your list MUST BE part of the user's query. Do not provide any explanation. If the returned JSON is not parseable in Python, you will be heavily penalized. For example, input: 'What is the difference between Apple and Google?' output: {"entities": ["Apple", "Google"]}. Always follow this output format. Here's the user query: """ + question)
+    G = nx.read_graphml("knowledge_graph.graphml")
+    graph = NetworkxEntityGraph(G)
+
+    try:
+        entities = json.loads(entity_string.content)['entities']
+        context = ""
+        all_triplets = []
+        for entity in entities:
+            all_triplets.extend(graph.get_entity_knowledge(entity, depth=2))
+        context = "Here are the relationships from the knowledge graph: " + "\n".join(all_triplets)
+    except Exception:
+        context = "No graph triples were available to extract from the knowledge graph. Always provide a disclaimer if you know the answer to the user's question, since it is not grounded in the knowledge you are provided from the graph."
+    answer = chain.invoke("Context: " + context + "\n\nUser query: " + question)
+    return answer
+
+def get_combined_RAG_response(question):
+    chain = prompt_template | llm | StrOutputParser()
+    entity_string = llm.invoke("""Return a JSON with a single key 'entities' and a list of entities within this user query. Each element in your list MUST BE part of the user's query. Do not provide any explanation. If the returned JSON is not parseable in Python, you will be heavily penalized. For example, input: 'What is the difference between Apple and Google?' output: {"entities": ["Apple", "Google"]}. Always follow this output format.
Here's the user query: """ + question) + G = nx.read_graphml("knowledge_graph.graphml") + graph = NetworkxEntityGraph(G) + + try: + entities = json.loads(entity_string.content)['entities'] + search_handler = SearchHandler("hybrid_demo3", use_bge_m3=True, use_reranker=True) + res = search_handler.search_and_rerank(question, k=5) + context = "Here are the relevant passages from the knowledge base: \n\n" + "\n".join(item.text for item in res) + all_triplets = [] + for entity in entities: + all_triplets.extend(graph.get_entity_knowledge(entity, depth=2)) + context += "\n\nHere are the relationships from the knowledge graph: " + "\n".join(all_triplets) + except Exception as e: + context = "No graph triples were available to extract from the knowledge graph. Always provide a disclaimer if you know the answer to the user's question, since it is not grounded in the knowledge you are provided from the graph." + answer = chain.invoke("Context: " + context + "\n\nUser query: " + question) + return answer + +st.title("Evaluations") + +st.subheader("Create synthetic Q&A pairs from large document chunks") + +# Variable for documents +if 'documents' not in st.session_state: + st.session_state['documents'] = None + +with st.sidebar: + llm_selectbox = st.selectbox("Choose an LLM", ["nvidia/nemotron-4-340b-instruct", "mistralai/mixtral-8x7b-instruct-v0.1", "meta/llama3-70b-instruct"], index=0) + st.write("You selected: ", llm_selectbox) + llm = ChatNVIDIA(model=llm_selectbox) + + num_data = st.slider("How many Q&A pairs to generate?", 10, 100, 50, step=10) + +def app(): + # Get the current working directory + cwd = os.getcwd() + + # Get a list of visible directories in the current working directory + directories = [d for d in os.listdir(cwd) if os.path.isdir(os.path.join(cwd, d)) and not d.startswith('.') and '__' not in d] + + # Create a dropdown menu for directory selection + selected_dir = st.selectbox("Select a directory:", directories, index=0) + + # Construct the full path of the selected directory + directory = os.path.join(cwd, selected_dir) + if st.button("Process Documents"): + # Check if the selected directory has PDF files + res = has_pdf_files(directory) + if not res: + st.error("No PDF files found in directory! 
Only PDF files and text extraction are supported for now.") + st.stop() + documents, results = process_documents(directory, llm, triplets=False, chunk_size=2000, chunk_overlap=200) + st.session_state["documents"] = documents + st.success("Finished splitting documents!") + + json_list = [] + if st.session_state["documents"] is not None: + if st.button("Create Q&A pairs"): + qa_docs = random.sample(st.session_state["documents"], num_data) + for doc in qa_docs: + res = generate_qa_pair(doc, llm) + st.write(res) + if res: + json_list.append(res) + + if len(json_list) > 0: + df = pd.DataFrame(json_list) + df.to_csv('qa_data.csv', index=False) + + if os.path.exists("qa_data.csv"): + with st.expander("Load Q&A data and run evaluations of text vs graph vs text+graph RAG"): + if st.button("Run"): + df_csv = pd.read_csv("qa_data.csv") + questions_list = df_csv["question"].tolist() + answers_list = df_csv["answer"].tolist() + + # Create an empty DataFrame to store results + results_df = pd.DataFrame(columns=[ + "question", "gt_answer", "textRAG_answer", + "graphRAG_answer", "combined_answer" + ]) + + # Create a placeholder for the DataFrame + df_placeholder = st.empty() + + # Create a progress bar + progress_bar = st.progress(0) + + # Process questions + for i, (question, answer) in enumerate(zip(questions_list, answers_list)): + result = process_question(question, answer) + + # Add new row to results_df + new_row = pd.DataFrame([result]) + results_df = pd.concat([results_df, new_row], ignore_index=True) + + # Update the displayed DataFrame + df_placeholder.dataframe(results_df) + + # Update progress bar + progress_bar.progress((i + 1) / len(questions_list)) + + # Optionally, save the combined results to a new CSV file + results_df.to_csv("combined_results.csv", index=False) + st.success("Combined results saved to 'combined_results.csv'") + + if os.path.exists("combined_results.csv"): + with st.expander("Run comparative evals for saved Q&A data"): + if st.button("Run scoring"): + combined_results = pd.read_csv("combined_results.csv") + + # Initialize new columns for scores + score_columns = ['gt', 'textRAG', 'graphRAG', 'combinedRAG'] + metrics = ['helpfulness', 'correctness', 'coherence', 'complexity', 'verbosity'] + + for row in range(len(combined_results)): + res_gt = get_reward_scores(combined_results["question"][row], combined_results["gt_answer"][row]) + + res_textRAG = get_reward_scores(combined_results["question"][row], combined_results["textRAG_answer"][row]) + + res_graphRAG = get_reward_scores(combined_results["question"][row], combined_results["graphRAG_answer"][row]) + + res_combinedRAG = get_reward_scores(combined_results["question"][row], combined_results["combined_answer"][row]) + + # Add scores to the DataFrame + for score_type, res in zip(score_columns, [res_gt, res_textRAG, res_graphRAG, res_combinedRAG]): + for metric in metrics: + combined_results.at[row, f'{score_type}_{metric}'] = res[metric] + + # Display progress + if row % 10 == 0: # Update every 10 rows + st.write(f"Processed {row + 1} out of {len(combined_results)} rows") + + # Save the updated DataFrame + combined_results.to_csv("combined_results_with_scores.csv", index=False) + st.success("Evaluation complete. 
Results saved to 'combined_results_with_scores.csv'") + + # Display the first few rows of the updated DataFrame + st.write("First few rows of the updated data:") + st.dataframe(combined_results.head()) + +if __name__ == "__main__": + app() \ No newline at end of file diff --git a/experimental/knowledge_graph_rag/pages/visualization.py b/experimental/knowledge_graph_rag/pages/visualization.py new file mode 100644 index 000000000..633c4b9d8 --- /dev/null +++ b/experimental/knowledge_graph_rag/pages/visualization.py @@ -0,0 +1,32 @@ +# SPDX-FileCopyrightText: Copyright (c) 2023-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import streamlit as st +import streamlit.components.v1 as components + +st.set_page_config(layout="wide") + +def app(): + st.title("Visualize the Knowledge Graph!") + st.subheader("Load a knowledge graph GraphML file from your system.") + st.write("If you used the previous step, it will be saved on your system as ```knowledge_graph.graphml```") + + components.iframe( + src="https://gephi.org/gephi-lite/", + height=800, + scrolling=True, + ) + +app() \ No newline at end of file diff --git a/experimental/knowledge_graph_rag/requirements.txt b/experimental/knowledge_graph_rag/requirements.txt new file mode 100644 index 000000000..82a9a1410 --- /dev/null +++ b/experimental/knowledge_graph_rag/requirements.txt @@ -0,0 +1,14 @@ +arxiv==2.1.0 +langchain==0.2.6 +langchain_community==0.2.6 +langchain_core==0.2.10 +langchain_nvidia_ai_endpoints==0.1.2 +llama_index==0.10.50 +networkx==3.2.1 +numpy==1.24.1 +pandas==2.2.2 +pymilvus==2.4.3 +Requests==2.32.3 +streamlit==1.30.0 +unstructured[all-docs] +tqdm==4.66.1 diff --git a/experimental/knowledge_graph_rag/utils/download_papers.py b/experimental/knowledge_graph_rag/utils/download_papers.py new file mode 100644 index 000000000..5030f3259 --- /dev/null +++ b/experimental/knowledge_graph_rag/utils/download_papers.py @@ -0,0 +1,115 @@ +# SPDX-FileCopyrightText: Copyright (c) 2023-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
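+
+# Standalone helper to fetch a corpus of arXiv PDFs for ingestion.
+# Example invocation (illustrative values; see the argparse flags defined below):
+#   python utils/download_papers.py -s "knowledge graphs,RAG" -n 20 -d papers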
+ +# Import necessary libraries +import argparse +import arxiv +import requests +import os +from datetime import datetime, timedelta +from tqdm import tqdm +from concurrent.futures import ThreadPoolExecutor +import time + +# Define a function to download a paper +def download_paper(result, download_dir, max_retries=3, retry_delay=5): + # Get the URL to download the paper + pdf_url = result.entry_id + '.pdf' + pdf_url = pdf_url.split("/")[-1] + pdf_url = f"https://arxiv.org/pdf/{pdf_url}" + retries = 0 + + # Try downloading the paper with retries + while retries < max_retries: + try: + pdf_response = requests.get(pdf_url) + print(pdf_response, pdf_url) + if pdf_response.status_code == 200: + # Save the paper to the download directory + pdf_filename = result.title.replace(' ', '_') + '.pdf' + pdf_filepath = os.path.join(download_dir, pdf_filename) + with open(pdf_filepath, 'wb') as f: + f.write(pdf_response.content) + print(f"Downloaded: {result.entry_id} - {result.title}") + return # Exit the function if download is successful + except requests.exceptions.RequestException as e: + print(f"Error downloading {result.entry_id}: {e}") + + # Retry logic + retries += 1 + print(f"Retrying download for {result.entry_id} in {retry_delay} seconds...") + time.sleep(retry_delay) + + print(f"Failed to download: {result.entry_id} - {result.title} (after {max_retries} retries)") + +# Define a function to download papers based on search criteria +def download_papers(search_terms, start_date, end_date, max_results=10, download_dir='papers', num_threads=4, max_retries=3, retry_delay=5): + # Create the search query based on search terms and dates + search_query = f"({search_terms}) AND submittedDate:[{start_date.strftime('%Y%m%d')} TO {end_date.strftime('%Y%m%d')}]" + + search = arxiv.Search( + query=search_query, + max_results=max_results, + sort_by=arxiv.SortCriterion.SubmittedDate, + ) + + # Create the download directory if it doesn't exist + os.makedirs(download_dir, exist_ok=True) + + # Use a thread pool to download papers in parallel + with ThreadPoolExecutor(max_workers=num_threads) as executor: + futures = [] + for result in tqdm(search.results(), total=max_results, unit='paper'): + # Submit download tasks to the executor + future = executor.submit(download_paper, result, download_dir, max_retries, retry_delay) + futures.append(future) + +# Main function to parse arguments and execute the download +if __name__ == "__main__": + parser = argparse.ArgumentParser(description='Download research papers from arXiv.org') + parser.add_argument('-s', '--search-terms', required=True, help='A comma-separated list of search terms') + parser.add_argument('-sd', '--start-date', help='Start date in the format YYYY-MM-DD (default: 10 years ago)') + parser.add_argument('-ed', '--end-date', help='End date in the format YYYY-MM-DD (default: today)') + parser.add_argument('-n', '--max-results', type=int, default=10, help='Maximum number of papers to download (default: 10)') + parser.add_argument('-d', '--download-dir', default='papers', help='Directory to save the downloaded papers (default: papers)') + parser.add_argument('-t', '--num-threads', type=int, default=4, help='Number of threads to use for parallel downloads (default: 4)') + parser.add_argument('-r', '--max-retries', type=int, default=3, help='Maximum number of retries for each download (default: 3)') + parser.add_argument('-rd', '--retry-delay', type=int, default=5, help='Delay in seconds between retries (default: 5)') + + args = parser.parse_args() + + # 
Handle the start date + if args.start_date: + try: + start_date = datetime.strptime(args.start_date, '%Y-%m-%d') + except ValueError: + print(f"Invalid start date format. Please use YYYY-MM-DD. Provided: {args.start_date}") + exit(1) + else: + start_date = datetime.now() - timedelta(days=365 * 10) # Default to 10 years ago + + # Handle the end date + if args.end_date: + try: + end_date = datetime.strptime(args.end_date, '%Y-%m-%d') + except ValueError: + print(f"Invalid end date format. Please use YYYY-MM-DD. Provided: {args.end_date}") + exit(1) + else: + end_date = datetime.now() # Default to today + + # Call the download_papers function with the provided arguments + download_papers(args.search_terms, start_date, end_date, args.max_results, args.download_dir, args.num_threads, args.max_retries, args.retry_delay) + diff --git a/experimental/knowledge_graph_rag/utils/lc_graph.py b/experimental/knowledge_graph_rag/utils/lc_graph.py new file mode 100644 index 000000000..5a5a81ac2 --- /dev/null +++ b/experimental/knowledge_graph_rag/utils/lc_graph.py @@ -0,0 +1,165 @@ +# SPDX-FileCopyrightText: Copyright (c) 2023-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
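+
+# Core ingestion pipeline: documents are loaded and chunked, each chunk is sent
+# to the LLM for triple extraction in its own worker process, and the collected
+# triples are normalized into CSVs for graph construction. Note that the 'fork'
+# multiprocessing start method used below assumes a POSIX (Linux/macOS) host.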
+
+from langchain_nvidia_ai_endpoints import ChatNVIDIA
+import concurrent.futures
+try:
+    from utils.preprocessor import extract_triples
+except ImportError:
+    # Fallback so this file can also be run directly from within utils/
+    from preprocessor import extract_triples
+from tqdm import tqdm
+from langchain_community.document_loaders import DirectoryLoader
+from langchain.text_splitter import RecursiveCharacterTextSplitter
+import multiprocessing
+import csv
+import streamlit as st
+
+# Function to process a single document (many of these run in parallel processes)
+def process_document(doc, llm):
+    try:
+        return extract_triples(doc, llm)
+    except Exception as e:
+        print(f"Error processing document: {e}")
+        return []
+
+def process_documents(directory, llm, triplets=True, chunk_size=500, chunk_overlap=100):
+    with st.spinner("Loading and splitting documents"):
+        loader = DirectoryLoader(directory)
+        raw_docs = loader.load()
+        text_splitter = RecursiveCharacterTextSplitter(chunk_size=chunk_size, chunk_overlap=chunk_overlap)
+        documents = text_splitter.split_documents(raw_docs)
+        st.write("Loaded docs, len(docs): " + str(len(documents)))
+
+    if not triplets:
+        return documents, []
+
+    multiprocessing.set_start_method('fork', force=True)
+
+    progress_bar = st.progress(0)  # Initialize the progress bar
+    progress_text = st.empty()  # Create a placeholder for the progress text
+
+    with concurrent.futures.ProcessPoolExecutor() as executor:
+        futures = [executor.submit(process_document, doc, llm) for doc in documents]
+        results = []
+        total_futures = len(futures)
+        completed_futures = 0
+
+        for future in concurrent.futures.as_completed(futures):
+            try:
+                result = future.result()
+                results.extend(result)
+            except Exception as e:
+                print(f"Error collecting result: {e}")
+
+            completed_futures += 1
+            progress = completed_futures / total_futures
+            progress_bar.progress(progress)  # Update the progress bar
+            progress_text.text(f"Processing: {completed_futures}/{total_futures}")  # Update the progress text
+
Total triples extracted:", len(results)) + return documents, results + +import pandas as pd + +def save_triples_to_csvs(triples): + # Create the triples DataFrame + triples_df = pd.DataFrame(triples, columns=['subject', 'subject_type', 'relation', 'object', 'object_type']) + + # Create the relations DataFrame + relations_df = pd.DataFrame({'relation_id': range(len(triples_df['relation'].unique())), 'relation_name': triples_df['relation'].unique()}) + + # Get unique entities (subjects and objects) from triples_df + entities = pd.concat([triples_df['subject'], triples_df['object']]).unique() + + entities_df = pd.DataFrame({ + 'entity_name': entities, + 'entity_type': [ + triples_df.loc[triples_df['subject'] == entity, 'subject_type'].iloc[0] + if entity in triples_df['subject'].values + else triples_df.loc[triples_df['object'] == entity, 'object_type'].dropna().iloc[0] + if not triples_df.loc[triples_df['object'] == entity, 'object_type'].empty + else 'Unknown' + for entity in entities + ] + }) + entities_df = entities_df.reset_index().rename(columns={'index': 'entity_id'}) + + # Merge triples_df with entities_df for subject + triples_with_ids = triples_df.merge(entities_df[['entity_id', 'entity_name']], left_on='subject', right_on='entity_name', how='left') + triples_with_ids = triples_with_ids.rename(columns={'entity_id': 'entity_id_1'}).drop(columns=['entity_name', 'subject', 'subject_type']) + + # Merge triples_with_ids with entities_df for object + triples_with_ids = triples_with_ids.merge(entities_df[['entity_id', 'entity_name']], left_on='object', right_on='entity_name', how='left') + triples_with_ids = triples_with_ids.rename(columns={'entity_id': 'entity_id_2'}).drop(columns=['entity_name', 'object', 'object_type']) + + # Merge triples_with_ids with relations_df to get relation IDs + triples_with_ids = triples_with_ids.merge(relations_df, left_on='relation', right_on='relation_name', how='left').drop(columns=['relation', 'relation_name']) + + # Select necessary columns and ensure correct data types + result_df = triples_with_ids[['entity_id_1', 'relation_id', 'entity_id_2']].astype({'entity_id_1': int, 'relation_id': int, 'entity_id_2': int}) + + # Write DataFrames to CSV files + entities_df.to_csv('entities.csv', index=False) + relations_df.to_csv('relations.csv', index=False) + result_df.to_csv('triples.csv', index=False) + +if __name__ == "__main__": + llm = ChatNVIDIA(model="ai-mixtral-8x7b-instruct") + results = process_documents("papers/", llm) + + # write the resulting entities to a CSV, relations to a CSV and all triples with IDs to a CSV + save_triples_to_csvs(results) + + # load the CSV triples, entities and relations into pandas objects (accelerated by cuDF/cuGraph) + import pandas as pd + import networkx as nx + + # Load the triples from the CSV file + triples_df = pd.read_csv("triples.csv", header=None, names=["Entity1_ID", "relation", "Entity2_ID"]) + + # Load the entities and relations DataFrames + entity_df = pd.read_csv("entities.csv", header=None, names=["ID", "Entity"]) + relations_df = pd.read_csv("relations.csv", header=None, names=["ID", "relation"]) + + # Create a mapping from IDs to entity names and relation names + entity_name_map = entity_df.set_index("ID")["Entity"].to_dict() + relation_name_map = relations_df.set_index("ID")["relation"].to_dict() + + # Create the graph from the triples DataFrame using accelerated networkX-cuGraph integration + G = nx.from_pandas_edgelist( + triples_df, + source="Entity1_ID", + target="Entity2_ID", + edge_attr="relation", 
+ create_using=nx.DiGraph, + ) + + # Relabel the nodes with the actual entity names + G = nx.relabel_nodes(G, entity_name_map) + + # Relabel the edges with the actual relation names + edge_attributes = nx.get_edge_attributes(G, "relation") + nx.set_edge_attributes(G, {(u, v): relation_name_map[edge_attributes[(u, v)]] for u, v in G.edges()}, "relation") + + # Save the graph to a GraphML file so it can be visualized in Gephi Lite + nx.write_graphml(G, "knowledge_graph.graphml") + + # Query the graph using LangChain + from langchain.chains import GraphQAChain + from langchain.indexes.graph import NetworkxEntityGraph + graph = NetworkxEntityGraph(G) + # print(graph.get_triples()) + + # llm.invoke("hello") + chain = GraphQAChain.from_llm(llm = llm, graph=graph, verbose=True) + res = chain.run("explain how URDFormer and vision transformer is related") + print(res) diff --git a/experimental/knowledge_graph_rag/utils/preprocessor.py b/experimental/knowledge_graph_rag/utils/preprocessor.py new file mode 100644 index 000000000..3861c3067 --- /dev/null +++ b/experimental/knowledge_graph_rag/utils/preprocessor.py @@ -0,0 +1,95 @@ +# SPDX-FileCopyrightText: Copyright (c) 2023-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import getpass +import os +import json +import ast +from langchain_nvidia_ai_endpoints import ChatNVIDIA + +if not os.environ.get("NVIDIA_API_KEY", "").startswith("nvapi-"): + nvapi_key = getpass.getpass("Enter your NVIDIA API key: ") + assert nvapi_key.startswith("nvapi-"), f"{nvapi_key[:5]}... is not a valid key" + os.environ["NVIDIA_API_KEY"] = nvapi_key + +from langchain_core.output_parsers import StrOutputParser +from langchain_core.prompts import ChatPromptTemplate + +def process_response(triplets_str): + triplets_list = ast.literal_eval(triplets_str) + json_triplets = [] + + for triplet in triplets_list: + try: + subject, subject_type, relation, object, object_type = triplet + json_triplet = { + "subject": subject, + "subject_type": subject_type, + "relation": relation, + "object": object, + "object_type": object_type + } + json_triplets.append(json_triplet) + except ValueError: + # Skip the malformed triplet and continue with the next one + continue + + return json_triplets + +def extract_triples(text, llm): + prompt = ChatPromptTemplate.from_messages( + [("system", """Note that the entities should not be generic, numerical, or temporal (like dates or percentages). Entities must be classified into the following categories: +- ORG: Organizations other than government or regulatory bodies +- ORG/GOV: Government bodies (e.g., "United States Government") +- ORG/REG: Regulatory bodies (e.g., "Food and Drug Administration") +- PERSON: Individuals (e.g., "Marie Curie") +- GPE: Geopolitical entities such as countries, cities, etc. 
(e.g., "Germany") +- INSTITUTION: Academic or research institutions (e.g., "Harvard University") +- PRODUCT: Products or services (e.g., "CRISPR technology") +- EVENT: Specific and Material Events (e.g., "Nobel Prize", "COVID-19 pandemic") +- FIELD: Academic fields or disciplines (e.g., "Quantum Physics") +- METRIC: Research metrics or indicators (e.g., "Impact Factor"), numerical values like "10%" is not a METRIC; +- TOOL: Research tools or methods (e.g., "Gene Sequencing", "Surveys") +- CONCEPT: Abstract ideas or notions or themes (e.g., "Quantum Entanglement", "Climate Change") + +The relationships 'r' between these entities must be represented by one of the following relation verbs set: Has, Announce, Operate_In, Introduce, Produce, Control, Participates_In, Impact, Positive_Impact_On, Negative_Impact_On, Relate_To, Is_Member_Of, Invests_In, Raise, Decrease. + +Remember to conduct entity disambiguation, consolidating different phrases or acronyms that refer to the same entity (for instance, "MIT" and "Massachusetts Institute of Technology" should be unified as "MIT"). Simplify each entity of the triplet to be less than four words. However, always make sure it is a sensible entity name and not a single letter or NAN value. + +From this text, your output Must be in python lis tof tuple with each tuple made up of ['h', 'type', 'r', 'o', 'type'], each element of the tuple is the string, where the relationship 'r' must be in the given relation verbs set above. Only output the list. As an Example, consider the following news excerpt: + Input :'Apple Inc. is set to introduce the new iPhone 14 in the technology sector this month. The product's release is likely to positively impact Apple's stock value.' + OUTPUT : ``` + [('Apple Inc.', 'COMP', 'Introduce', 'iPhone 14', 'PRODUCT'), + ('Apple Inc.', 'COMP', 'Operate_In', 'Technology Sector', 'SECTOR'), + ('iPhone 14', 'PRODUCT', 'Positive_Impact_On', 'Apple's Stock Value', 'FIN_INSTRUMENT')] + ``` + The output structure must not be anything apart from above OUTPUT structure. NEVER REPLY WITH any element as NAN. Just leave out the triple if you think it's not worth including or does not have an object. Do not provide ANY additional explanations, if it's not a Python parseable list of tuples, you will be penalized severely. Make the best possible decisions given the context."""), ("user", "{input}")]) + chain = prompt | llm | StrOutputParser() + response = chain.invoke({"input": text}) + print(response) + return process_response(response) + +def generate_qa_pair(text, llm): + prompt = ChatPromptTemplate.from_messages( + [("system", """You are a synthetic data generation model responsible for creating high quality question and answer pairs from text content provided to you. Given the paragraph as an input, create one high quality and highly complex question answer pair. The question should require a large portion of the context and multi-step advanced reasoning to answer. Make sure it is something a human may ask while reading this document. The answer should be highly detailed and comprehensive. Your output should be in a json format of one question answer pair. Restrict the question to the context information provided. Do not print anything else. 
The output MUST be JSON parseable."""), ("user", "{input}")]) + # llm = ChatNVIDIA(model="nvidia/nemotron-4-340b-instruct") + chain = prompt | llm | StrOutputParser() + response = chain.invoke({"input": text}) + print(response) + try: + parsed_response = json.loads(response) + return parsed_response + except: + return None \ No newline at end of file diff --git a/experimental/knowledge_graph_rag/vectorstore/search.py b/experimental/knowledge_graph_rag/vectorstore/search.py new file mode 100644 index 000000000..a0b860ef1 --- /dev/null +++ b/experimental/knowledge_graph_rag/vectorstore/search.py @@ -0,0 +1,94 @@ +# SPDX-FileCopyrightText: Copyright (c) 2023-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import random +import string +import numpy as np +from pymilvus import ( + utility, + FieldSchema, CollectionSchema, DataType, + Collection, AnnSearchRequest, RRFRanker, connections, +) + +class SearchHandler: + def __init__(self, collection_name, use_bge_m3=True, use_reranker=True): + self.use_bge_m3 = use_bge_m3 + self.use_reranker = use_reranker + self.collection_name = collection_name + self.dense_dim = None + self.ef = None + self.setup_embeddings() + self.setup_milvus_collection() + + def setup_embeddings(self): + if self.use_bge_m3: + from pymilvus.model.hybrid import BGEM3EmbeddingFunction + self.ef = BGEM3EmbeddingFunction(use_fp16=False, device="cpu") + self.dense_dim = self.ef.dim["dense"] + else: + self.dense_dim = 768 + self.ef = self.random_embedding + + def random_embedding(self, texts): + rng = np.random.default_rng() + return { + "dense": np.random.rand(len(texts), self.dense_dim), + "sparse": [{d: rng.random() for d in random.sample(range(1000), random.randint(20, 30))} for _ in texts], + } + + def setup_milvus_collection(self): + connections.connect("default", host="localhost", port="19530") + + fields = [ + FieldSchema(name="pk", dtype=DataType.VARCHAR, is_primary=True, auto_id=True, max_length=100), + FieldSchema(name="text", dtype=DataType.VARCHAR, max_length=2048), + FieldSchema(name="sparse_vector", dtype=DataType.SPARSE_FLOAT_VECTOR), + FieldSchema(name="dense_vector", dtype=DataType.FLOAT_VECTOR, dim=self.dense_dim), + ] + schema = CollectionSchema(fields, "") + self.collection = Collection(self.collection_name, schema, consistency_level="Strong") + + sparse_index = {"index_type": "SPARSE_INVERTED_INDEX", "metric_type": "IP"} + self.collection.create_index("sparse_vector", sparse_index) + dense_index = {"index_type": "FLAT", "metric_type": "IP"} + self.collection.create_index("dense_vector", dense_index) + self.collection.load() + + def insert_data(self, docs): + doc_page_content = [doc.page_content for doc in docs] + docs_embeddings = self.ef(doc_page_content) + entities = [doc_page_content, docs_embeddings["sparse"], docs_embeddings["dense"]] + self.collection.insert(entities) + self.collection.flush() + + def search_and_rerank(self, query, k=2): + query_embeddings = 
self.ef([query]) + + sparse_search_params = {"metric_type": "IP"} + sparse_req = AnnSearchRequest(query_embeddings["sparse"], "sparse_vector", sparse_search_params, limit=k) + dense_search_params = {"metric_type": "IP"} + dense_req = AnnSearchRequest(query_embeddings["dense"], "dense_vector", dense_search_params, limit=k) + + res = self.collection.hybrid_search([sparse_req, dense_req], rerank=RRFRanker(), limit=k, output_fields=['text']) + res = res[0] + + if self.use_reranker: + result_texts = [hit.fields["text"] for hit in res] + from pymilvus.model.reranker import BGERerankFunction + bge_rf = BGERerankFunction(device='cpu') + results = bge_rf(query, result_texts, top_k=k) + return results + else: + return res \ No newline at end of file diff --git a/experimental/knowledge_graph_rag/viz.png b/experimental/knowledge_graph_rag/viz.png new file mode 100644 index 000000000..6c18123f3 Binary files /dev/null and b/experimental/knowledge_graph_rag/viz.png differ