# KG Index and VectorStore Index on Hotpot QA with LlamaIndex


This notebook demonstrates the following:

    1) Building a custom `KnowledgeGraphIndex` from extracted triples.
    2) Building a customized `BaseRetriever` that retrieves KG triples and text documents jointly.
    3) Automated, LLM-based evaluation of KG RAG question answering on HotpotQA.
    
This notebook compares a `VectorIndexRetriever`, a `KGTableRetriever`, and a joint retriever for RAG on the `hotpot_qa` dataset. The triples can be extracted by following `hotpot_qa_extraction.ipynb` or by running `hotpot_qa_kgs.py`.

In [1]:
import os, sys

open_ai_key = '...'
os.environ['OPENAI_API_KEY'] = open_ai_key

sys.path = ['/Users/walder2/kg_uq/'] + sys.path
path_to_data = '/Users/walder2/kg_uq/hotpot_qa_data'

from hotpot_qa_data.hotpot_data_load import load_hotpot_kgs

import numpy as np 

from llama_index import (
    SimpleDirectoryReader,
    ServiceContext,
    KnowledgeGraphIndex,
    VectorStoreIndex,
    get_response_synthesizer,
    QueryBundle,
    Response
)

from llama_index.llms import OpenAI
from llama_index.graph_stores import SimpleGraphStore
from llama_index.schema import NodeWithScore
from llama_index.query_engine import RetrieverQueryEngine
from llama_index.node_parser import SentenceSplitter

# Retrievers
from llama_index.retrievers import (
    BaseRetriever,
    VectorIndexRetriever,
    KGTableRetriever,
)

#Evaluators
from llama_index.evaluation import CorrectnessEvaluator, BaseEvaluator
from typing import Dict, List, Tuple

import time
import asyncio 
import nest_asyncio




In [2]:
kg, query_answer = load_hotpot_kgs(path_to_data=path_to_data, query_answer=True)

### Look at the data

`doc_id` is the id of the question group, and `sub_idx` is the id of the subgraph extracted from context entry `j` of that question. Some of the text is messy; we can clean that up later.

Also note that `file_path` is included. It records which text file each triple came from, which matters if you want to trace top-matching subgraphs back to their source documents.

In [3]:
kg

| | head | head_type | relation | tail | tail_type | file_path | doc_id | sub_idx |
|---|---|---|---|---|---|---|---|---|
| 0 | Radio City | place | isFirstPrivateFMStationIn | India | place | /Users/walder2/kg_uq/hotpot_qa_data/txt_files/... | 0 | 0 |
| 1 | Radio City | place | wasStartedOn | 3 July 2001 | event | /Users/walder2/kg_uq/hotpot_qa_data/txt_files/... | 0 | 0 |
| 2 | Radio City | place | broadcastsOn | 91.1 megahertz | measurement | /Users/walder2/kg_uq/hotpot_qa_data/txt_files/... | 0 | 0 |
| 3 | Radio City | place | broadcastsFrom | Mumbai | place | /Users/walder2/kg_uq/hotpot_qa_data/txt_files/... | 0 | 0 |
| 4 | Radio City | place | broadcastsFrom | Bengaluru | place | /Users/walder2/kg_uq/hotpot_qa_data/txt_files/... | 0 | 0 |
| ... | ... | ... | ... | ... | ... | ... | ... | ... |
| 1026 | 2014 Liqui Moly Bathurst 12 Hour | event | wasHeldOn | 9 February 2014 | date | /Users/walder2/kg_uq/hotpot_qa_data/txt_files/... | 9 | 9 |
| 1027 | 2014 Liqui Moly Bathurst 12 Hour | event | wasTheTwelfthRunningOf | Bathurst 12 Hour | event | /Users/walder2/kg_uq/hotpot_qa_data/txt_files/... | 9 | 9 |
| 1028 | 2014 Liqui Moly Bathurst 12 Hour | event | included | GT3 cars | thing | /Users/walder2/kg_uq/hotpot_qa_data/txt_files/... | 9 | 9 |
| 1029 | 2014 Liqui Moly Bathurst 12 Hour | event | included | GT4 cars | thing | /Users/walder2/kg_uq/hotpot_qa_data/txt_files/... | 9 | 9 |
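The `doc_id`, `sub_idx` and `file_path` columns described above can be checked by grouping rows per subgraph. The sketch below does this in plain Python on a few toy rows copied from the table. The file names `doc_a.txt` and `doc_b.txt` are hypothetical placeholders, because the real paths are truncated in the output. The helper `group_by_subgraph` is also hypothetical.

```python
from collections import defaultdict

# Toy rows mirroring the columns of `kg`; file names are hypothetical placeholders.
rows = [
    {"head": "Radio City", "relation": "isFirstPrivateFMStationIn", "tail": "India",
     "file_path": "doc_a.txt", "doc_id": 0, "sub_idx": 0},
    {"head": "Radio City", "relation": "wasStartedOn", "tail": "3 July 2001",
     "file_path": "doc_a.txt", "doc_id": 0, "sub_idx": 0},
    {"head": "2014 Liqui Moly Bathurst 12 Hour", "relation": "included", "tail": "GT3 cars",
     "file_path": "doc_b.txt", "doc_id": 9, "sub_idx": 9},
]

def group_by_subgraph(rows):
    """Map (doc_id, sub_idx) to the set of source files and the list of (h, r, t) triples."""
    groups = defaultdict(lambda: {"files": set(), "triples": []})
    for row in rows:
        g = groups[(row["doc_id"], row["sub_idx"])]
        g["files"].add(row["file_path"])
        g["triples"].append((row["head"], row["relation"], row["tail"]))
    return dict(groups)

subgraphs = group_by_subgraph(rows)
for key, g in sorted(subgraphs.items()):
    print(key, sorted(g["files"]), len(g["triples"]))
# (0, 0) ['doc_a.txt'] 2
# (9, 9) ['doc_b.txt'] 1
```

On the real DataFrame, `kg.groupby(['doc_id', 'sub_idx'])` gives the same grouping.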


# Building a KnowledgeGraphIndex from extracted triples

Pick an LLM to use (the default is gpt-3.5-turbo). `service_context` handles document chunking and determines which LLM to call. The path for `documents` should point to `'./hotpot_qa_data/txt_files'`.

In [4]:
llm = OpenAI(temperature=0)
service_context = ServiceContext.from_defaults(llm=llm, chunk_size=512)
documents = SimpleDirectoryReader(path_to_data + '/txt_files').load_data()

Define an empty `KnowledgeGraphIndex`. We will fill it with our extracted triples, each linked to the `Node` that contains the document the triples were extracted from.

In [5]:
kg_index = KnowledgeGraphIndex(
    [],
    service_context=service_context,
)

node_parser = SentenceSplitter()
nodes = node_parser.get_nodes_from_documents(documents)

file_to_node = {node.metadata['file_path']: k for k, node in enumerate(nodes)}
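`file_to_node` stores one node index per file. If `SentenceSplitter` splits a file into more than one node, the dict comprehension keeps only the last of them. The sketch below shows the collision and a list-valued alternative. It uses plain strings in place of `Node` objects, and the file names are hypothetical. This notebook does not check whether any HotpotQA context file actually spans several nodes.

```python
from collections import defaultdict

# Hypothetical source file of each node, in node order: "doc_a.txt" was split into two chunks.
node_files = ["doc_a.txt", "doc_a.txt", "doc_b.txt"]

# Same pattern as the cell above: a later chunk overwrites an earlier one.
last_only = {f: k for k, f in enumerate(node_files)}

# Alternative: keep every chunk index per file.
all_chunks = defaultdict(list)
for k, f in enumerate(node_files):
    all_chunks[f].append(k)

print(last_only)         # {'doc_a.txt': 1, 'doc_b.txt': 2}
print(dict(all_chunks))  # {'doc_a.txt': [0, 1], 'doc_b.txt': [2]}
```

With the list-valued mapping, one option would be to attach each triple to every chunk of its file. That is a design choice, since the triples do not record which chunk they came from.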

We fill the `KnowledgeGraphIndex` by upserting each triple together with the `Node` it was extracted from.

In [6]:
for doc_id in kg['doc_id'].unique():
    idx = kg['doc_id'] == doc_id

    for sub_id in kg[idx]['sub_idx'].unique():
        # rows of subgraph `sub_id` within question group `doc_id`
        tmp = kg[idx & (kg['sub_idx'] == sub_id)]
        for h, r, t, f in zip(tmp['head'], tmp['relation'], tmp['tail'], tmp['file_path']):
            # store the triple and link it to the node of the file it came from
            kg_index.upsert_triplet_and_node((h, r, t), nodes[file_to_node[f]])
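Based on our reading of LlamaIndex's `KnowledgeGraphIndex` at the time of writing, each `upsert_triplet_and_node` call does two things. It stores the triple in the graph store, and it registers the triple's subject and object as keywords pointing at the node. The class below is a toy, in-memory stand-in written only to illustrate that idea. `ToyKGIndex` is a hypothetical name, this is not the library's implementation, and the node ids are made up.

```python
from collections import defaultdict
from typing import Dict, List, Set, Tuple

class ToyKGIndex:
    """Hypothetical stand-in: a subject -> (relation, object) store plus a keyword -> node-id table."""

    def __init__(self) -> None:
        self.rel_map: Dict[str, List[Tuple[str, str]]] = defaultdict(list)
        self.keyword_to_nodes: Dict[str, Set[str]] = defaultdict(set)

    def upsert_triplet_and_node(self, triplet: Tuple[str, str, str], node_id: str) -> None:
        subj, rel, obj = triplet
        if (rel, obj) not in self.rel_map[subj]:  # skip duplicate edges
            self.rel_map[subj].append((rel, obj))
        # Both ends of the triple become lookup keywords for the source node.
        self.keyword_to_nodes[subj].add(node_id)
        self.keyword_to_nodes[obj].add(node_id)

index = ToyKGIndex()
index.upsert_triplet_and_node(("Radio City", "broadcastsFrom", "Mumbai"), "node-0")
index.upsert_triplet_and_node(("Radio City", "broadcastsFrom", "Bengaluru"), "node-0")
index.upsert_triplet_and_node(("Radio City", "broadcastsFrom", "Mumbai"), "node-0")  # duplicate, ignored
print(index.rel_map["Radio City"])       # [('broadcastsFrom', 'Mumbai'), ('broadcastsFrom', 'Bengaluru')]
print(index.keyword_to_nodes["Mumbai"])  # {'node-0'}
```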

# Define a JointRetriever for joint KG-triple and text retrieval

In [7]:
class JointRetriever(BaseRetriever):
    """Custom retriever that performs both Vector search and Knowledge Graph search"""

    def __init__(
        self,
        vector_retriever: VectorIndexRetriever,
        kg_retriever: KGTableRetriever,
        mode: str = "OR",
    ) -> None:
        """Init params."""

        self._vector_retriever = vector_retriever
        self._kg_retriever = kg_retriever
        if mode not in ("AND", "OR"):
            raise ValueError("Invalid mode.")
        self._mode = mode
        super().__init__()

    def _retrieve(self, query_bundle: QueryBundle) -> List[NodeWithScore]:
        """Retrieve nodes given query."""

        vector_nodes = self._vector_retriever.retrieve(query_bundle)
        kg_nodes = self._kg_retriever.retrieve(query_bundle)

        vector_ids = {n.node.node_id for n in vector_nodes}
        kg_ids = {n.node.node_id for n in kg_nodes}

        combined_dict = {n.node.node_id: n for n in vector_nodes}
        combined_dict.update({n.node.node_id: n for n in kg_nodes})

        if self._mode == "AND":
            retrieve_ids = vector_ids.intersection(kg_ids)
        else:
            retrieve_ids = vector_ids.union(kg_ids)

        retrieve_nodes = [combined_dict[rid] for rid in retrieve_ids]
        return retrieve_nodes
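To see what the AND/OR modes do without calling any API, the function below reproduces the merge step of `JointRetriever` on plain `(node_id, score)` pairs. It is an illustrative sketch: `merge_results` is a hypothetical name and the scores are made up. Unlike `JointRetriever`, which returns nodes in set-iteration order, it sorts by score so the output is deterministic.

```python
from typing import Dict, List, Tuple

Scored = Tuple[str, float]  # (node_id, score)

def merge_results(vector_hits: List[Scored], kg_hits: List[Scored], mode: str = "OR") -> List[Scored]:
    """Combine two hit lists by node id, mirroring JointRetriever's AND/OR modes."""
    if mode not in ("AND", "OR"):
        raise ValueError("Invalid mode.")
    combined: Dict[str, float] = dict(vector_hits)
    combined.update(dict(kg_hits))  # as in JointRetriever, the KG entry wins on a shared id
    vector_ids = {nid for nid, _ in vector_hits}
    kg_ids = {nid for nid, _ in kg_hits}
    keep = vector_ids & kg_ids if mode == "AND" else vector_ids | kg_ids
    # Sort by descending score, then id, so the order is reproducible.
    return sorted(((nid, combined[nid]) for nid in keep), key=lambda x: (-x[1], x[0]))

vector_hits = [("n1", 0.82), ("n2", 0.75)]
kg_hits = [("n2", 0.9), ("n3", 0.6)]
print(merge_results(vector_hits, kg_hits, "OR"))   # [('n2', 0.9), ('n1', 0.82), ('n3', 0.6)]
print(merge_results(vector_hits, kg_hits, "AND"))  # [('n2', 0.9)]
```

Sorting only makes sense if both retrievers' scores are on a comparable scale. Vector similarities and KG retrieval scores generally are not, which is one reason to keep the original set-based merge and let the response synthesizer weigh the context.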

Create a `VectorStoreIndex` object for RAG with just the text. 

In [8]:
vector_index = VectorStoreIndex.from_documents(documents)

Now we instantiate retrievers for the KG and the text, which are passed to the `JointRetriever` object for joint retrieval.

In [9]:
# create custom retriever
vector_retriever = VectorIndexRetriever(index=vector_index)
kg_retriever = KGTableRetriever(
    index=kg_index, retriever_mode="keyword", include_text=False
)
joint_retriever = JointRetriever(vector_retriever, kg_retriever)

# create response synthesizer
response_synthesizer = get_response_synthesizer(
    service_context=service_context,
    response_mode="tree_summarize",
)
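Roughly, as we understand it, the `KGTableRetriever` with `retriever_mode="keyword"` and `include_text=False` works in three steps. An LLM extracts keywords from the question. The keywords are matched against the indexed entities. Only the matching triples, not the source text, are passed on. The function below is a toy version of that lookup. The LLM keyword extraction is replaced by a hypothetical case-insensitive substring match against known subject names, and `keyword_kg_lookup` is not a LlamaIndex function.

```python
from typing import Dict, List, Tuple

def keyword_kg_lookup(query: str, rel_map: Dict[str, List[Tuple[str, str]]]) -> List[Tuple[str, str, str]]:
    """Toy keyword retrieval: find known subjects mentioned in the query and return their triples."""
    q = query.lower()
    # Hypothetical stand-in for LLM keyword extraction: substring match on subject names.
    keywords = [subj for subj in rel_map if subj.lower() in q]
    return [(subj, rel, obj) for subj in keywords for rel, obj in rel_map[subj]]

rel_map = {
    "Radio City": [("broadcastsFrom", "Mumbai"), ("wasStartedOn", "3 July 2001")],
    "2014 Liqui Moly Bathurst 12 Hour": [("wasHeldOn", "9 February 2014")],
}
print(keyword_kg_lookup("When was Radio City started?", rel_map))
# [('Radio City', 'broadcastsFrom', 'Mumbai'), ('Radio City', 'wasStartedOn', '3 July 2001')]
```

A lookup like this returns nothing when the question names an entity differently from how it was stored. That is one plausible reason a KG-only engine can miss answers.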

Create query engines for all three cases: text, KG, and joint.

In [10]:
joint_query_engine = RetrieverQueryEngine(
    retriever=joint_retriever,
    response_synthesizer=response_synthesizer,
)

vector_query_engine = vector_index.as_query_engine()

# only use triples from the KG
kg_keyword_query_engine = kg_index.as_query_engine(
    include_text=False,
    retriever_mode="keyword",
    response_mode="tree_summarize",
)
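The joint and KG engines both use `response_mode="tree_summarize"`. Roughly, this mode groups the retrieved text into LLM calls and has the LLM answer over each group. It then recursively combines the partial answers until a single answer remains. The function below is a conceptual sketch only. It uses a fake "LLM" that just brackets its inputs, and a hypothetical fixed `fan_in` stands in for LlamaIndex's packing of text into the context window. It is not the library's implementation.

```python
from typing import Callable, List

def tree_summarize(chunks: List[str], summarize: Callable[[List[str]], str], fan_in: int = 2) -> str:
    """Summarize groups of `fan_in` texts, level by level, until one text remains."""
    if not chunks:
        raise ValueError("need at least one chunk")
    if fan_in < 2:
        raise ValueError("fan_in must be >= 2")
    level = list(chunks)
    while True:
        level = [summarize(level[i:i + fan_in]) for i in range(0, len(level), fan_in)]
        if len(level) == 1:
            return level[0]

def fake_llm(texts: List[str]) -> str:
    """Stand-in for an LLM call: join the inputs with '+' and wrap them in brackets."""
    return "[" + "+".join(texts) + "]"

print(tree_summarize(["a", "b", "c"], fake_llm))  # [[a+b]+[c]]
```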

# Evaluation of responses with RAG

Define a `CorrectnessEvaluator` that checks whether the response correctly answers the query, given a reference answer. Other evaluators are useful in the wild, where no reference answer is available, e.g.:

`ResponseSourceEvaluator`: uses an LLM to decide whether the response is consistent with the sources, which is a good measure for hallucination detection.

`QueryResponseEvaluator`: uses an LLM to decide whether the response is relevant to the original query, which is a good measure for checking whether the query was answered.


I've defined a class, `EvaluateResponse`, that runs a query engine over the queries and uses `CorrectnessEvaluator` to check whether each response answers its query, given the correct answer. We can also write custom evaluators that follow specified guidelines. These will come in handy later when we extend the QA to self-defined embeddings, etc.
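The evaluation here relies on an LLM judge. For comparison, a cheap non-LLM check is the exact-match and token-F1 scoring used by SQuAD-style QA benchmarks, including HotpotQA's answer evaluation. The sketch below re-implements that answer normalization for illustration. It is not wired into the `BaseEvaluator` interface and was not used to produce any results in this notebook. The helper `contains_answer` is a hypothetical extra check.

```python
import re
import string
from collections import Counter

PUNCT = set(string.punctuation)

def normalize_answer(s: str) -> str:
    """SQuAD-style normalization: lowercase, drop punctuation, drop articles, collapse spaces."""
    s = s.lower()
    s = "".join(ch for ch in s if ch not in PUNCT)
    s = re.sub(r"\b(a|an|the)\b", " ", s)
    return " ".join(s.split())

def exact_match(prediction: str, reference: str) -> bool:
    return normalize_answer(prediction) == normalize_answer(reference)

def token_f1(prediction: str, reference: str) -> float:
    pred = normalize_answer(prediction).split()
    ref = normalize_answer(reference).split()
    overlap = sum((Counter(pred) & Counter(ref)).values())
    if overlap == 0:
        return 0.0
    precision, recall = overlap / len(pred), overlap / len(ref)
    return 2 * precision * recall / (precision + recall)

def contains_answer(prediction: str, reference: str) -> bool:
    """Hypothetical lenient check: the normalized reference appears inside the prediction."""
    return normalize_answer(reference) in normalize_answer(prediction)

pred, ref = "Arthur's Magazine was started first.", "Arthur's Magazine"
print(exact_match(pred, ref), round(token_f1(pred, ref), 3), contains_answer(pred, ref))
# False 0.571 True
```

Exact match and F1 penalize full-sentence answers that are correct, while the containment check passes them.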

In [11]:
nest_asyncio.apply()  # allow asyncio.run inside Jupyter's already-running event loop

class EvaluateResponse:
    """Run a query engine over query/answer pairs and grade each response."""

    def __init__(self, evaluator: BaseEvaluator, query_engine: RetrieverQueryEngine) -> None:
        self.eval = evaluator
        self.q = query_engine

    async def run_query(self, x: Dict[str, str]):
        try:
            return await self.q.aquery(x['query'])
        except Exception:
            # keep going if a single query fails (e.g. an API error)
            return Response(response="Error, query failed.")

    async def run_batch(self, batch: List[Dict[str, str]]):
        # issue all queries in the batch concurrently
        return await asyncio.gather(*(self.run_query(y) for y in batch))

    def evaluate(self, x: List[Dict[str, str]], batch_size: int = 5) -> Tuple[int, list]:
        total_correct = 0
        all_results = []
        for start in range(0, len(x), batch_size):
            batch_x = x[start:start + batch_size]
            responses = asyncio.run(self.run_batch(batch_x))

            for y, res in zip(batch_x, responses):
                eval_result = self.eval.evaluate(
                    query=y['query'], reference=y['answer'], response=res.response
                )
                total_correct += 1 if eval_result.passing else 0
                all_results.append(eval_result)

            time.sleep(1)  # crude rate limiting between batches
        return total_correct, all_results
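Instead of catching exceptions inside `run_query`, you can let `asyncio.gather(..., return_exceptions=True)` collect them, so that one failed query does not cancel the rest of the batch. The sketch below shows that batching pattern with the standard library only. `fake_query` and `run_in_batches` are hypothetical stand-ins, not LlamaIndex calls. Inside Jupyter, `asyncio.run` additionally needs `nest_asyncio`, as applied above.

```python
import asyncio
from typing import List

async def fake_query(q: str) -> str:
    """Hypothetical stand-in for `query_engine.aquery`; fails on empty queries."""
    await asyncio.sleep(0.01)
    if not q:
        raise ValueError("empty query")
    return f"answer to {q}"

async def run_in_batches(queries: List[str], batch_size: int = 5) -> List[str]:
    results: List[str] = []
    for start in range(0, len(queries), batch_size):
        batch = queries[start:start + batch_size]
        # return_exceptions=True returns exceptions as values instead of raising the first one.
        outcomes = await asyncio.gather(*(fake_query(q) for q in batch), return_exceptions=True)
        results.extend("Error, query failed." if isinstance(o, Exception) else o for o in outcomes)
        await asyncio.sleep(0.1)  # non-blocking pause between batches
    return results

print(asyncio.run(run_in_batches(["q1", "", "q3"], batch_size=2)))
# ['answer to q1', 'Error, query failed.', 'answer to q3']
```

Results come back in the same order as the queries, which is what lets `evaluate` pair each response with its reference answer via `zip`.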

Run each query engine on the queries and check correctness.

In [12]:
evaluator = CorrectnessEvaluator(service_context=service_context)
kg_tot, kg_res = EvaluateResponse(evaluator, kg_keyword_query_engine).evaluate(query_answer)
print('KG evaluation complete...')
time.sleep(1)
txt_tot, txt_res = EvaluateResponse(evaluator, vector_query_engine).evaluate(query_answer)
print('Text evaluation complete...')
time.sleep(1)
joint_tot, joint_res = EvaluateResponse(evaluator, joint_query_engine).evaluate(query_answer)
print('Joint evaluation complete...')

KG evaluation complete...
Text evaluation complete...
Joint evaluation complete...


Take a look at the results. For each response, the LLM evaluator reports whether it is correct and gives feedback explaining why.

In [13]:
for i, x in enumerate(query_answer):
    print('-------------\nQuery: %s\nAnswer %s\n' % (repr(x['query']), repr(x['answer']) ))

    print('KG (%s): %s\nFeedback: %s\n' % (repr(kg_res[i].passing), repr(kg_res[i].response), repr(kg_res[i].feedback)))
    
    print('Text (%s): %s\nFeedback: %s\n' % (repr(txt_res[i].passing), repr(txt_res[i].response), repr(txt_res[i].feedback)))
    
    print('Joint (%s): %s\nFeedback: %s\n---------------\n' % (repr(joint_res[i].passing), repr(joint_res[i].response), repr(joint_res[i].feedback)))


-------------
Query: "Which magazine was started first Arthur's Magazine or First for Women?"
Answer "Arthur's Magazine"

KG (True): "Arthur's Magazine was started first."
Feedback: 'The generated answer is relevant and fully correct. It provides the correct information and is concise. The only improvement could be to include the name of the other magazine mentioned in the query, "First for Women", to make the answer more complete.'

Text (True): "Arthur's Magazine was started first."
Feedback: "The generated answer is relevant and fully correct. It provides the correct information that Arthur's Magazine was started first. However, it could be improved by providing more context or additional details about the magazine."

Joint (True): "Arthur's Magazine was started before First for Women."
Feedback: "The generated answer is relevant and correct. It provides the correct information that Arthur's Magazine was started before First for Women. However, it could be improved by providing more

In [14]:
print(f"KG Correct: {kg_tot}, Text Correct: {txt_tot}, Joint Correct: {joint_tot}, Total: {len(query_answer)}")

KG Correct: 2, Text Correct: 5, Joint Correct: 6, Total: 10
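The totals correspond to 20% (KG), 50% (text) and 60% (joint) accuracy on the 10 queries. They do not show which questions each engine gets right. The per-query `EvaluationResult` lists are already in memory, so a small helper can show, for instance, which questions the joint engine answers that the KG-only engine misses. The helper names are hypothetical, and the flags in the example are toy values, not this notebook's results.

```python
from typing import Dict, List

def summarize_passing(passing: Dict[str, List[bool]]) -> Dict[str, float]:
    """Accuracy per engine from per-query pass/fail flags."""
    return {name: sum(flags) / len(flags) for name, flags in passing.items()}

def only_passed_by(a: List[bool], b: List[bool]) -> List[int]:
    """Indices of queries that engine `a` answered correctly and engine `b` did not."""
    return [i for i, (x, y) in enumerate(zip(a, b)) if x and not y]

# Toy flags for four queries (not the notebook's actual per-query results).
passing = {
    "kg": [True, False, False, True],
    "joint": [True, True, False, True],
}
print(summarize_passing(passing))                       # {'kg': 0.5, 'joint': 0.75}
print(only_passed_by(passing["joint"], passing["kg"]))  # [1]
```

With the real results, the flags would come from, e.g., `[r.passing for r in kg_res]`, assuming every `passing` value is a bool.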


#### Things to do...

In the output above, some correct answers appear to be marked incorrect because the answer is "too verbose". We can address this by changing the evaluator's prompt. Here are some ideas to try:

    1) Clean up the triples, since some of the text was messy, and rerun the evaluation.
    2) Adjust the prompts used by the evaluators.
    3) Check whether the top retrieved documents (subgraphs) match the supporting context that `hotpot_qa` marks for each question.
    4) Use manually defined embeddings for the KGs (this requires fitting a heterogeneous GNN and passing it in as the embedding method).
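For the first item, the notebook does not specify what makes the triple text "messy". The cleaning rules below are therefore assumptions: collapse repeated whitespace, strip surrounding quotes, drop triples with an empty field, and drop case-insensitive duplicates of head and tail. `clean_triples` is a hypothetical helper, meant to run on the `(head, relation, tail)` tuples before the upsert loop.

```python
import re
from typing import Iterable, List, Tuple

Triple = Tuple[str, str, str]

def clean_triples(triples: Iterable[Triple]) -> List[Triple]:
    """Collapse whitespace, strip stray quotes, drop empty fields and duplicates (order kept)."""
    seen = set()
    cleaned: List[Triple] = []
    for triple in triples:
        h, r, t = (re.sub(r"\s+", " ", str(x)).strip().strip("\"'").strip() for x in triple)
        if not (h and r and t):
            continue  # drop triples with an empty field
        key = (h.lower(), r, t.lower())  # assumption: entity case does not matter
        if key in seen:
            continue
        seen.add(key)
        cleaned.append((h, r, t))
    return cleaned

raw = [
    ("Radio City ", "broadcastsFrom", "Mumbai"),
    ("Radio  City", "broadcastsFrom", "mumbai"),
    ("\"Radio City\"", "wasStartedOn", "3 July 2001"),
    ("Radio City", "broadcastsOn", ""),
]
print(clean_triples(raw))
# [('Radio City', 'broadcastsFrom', 'Mumbai'), ('Radio City', 'wasStartedOn', '3 July 2001')]
```

Case-insensitive deduplication also merges distinct entities whose names differ only in case. Whether that is desirable depends on the extraction output.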
