# BYOKG RAG Demo
This notebook demonstrates a Retrieval-Augmented Generation (RAG) system built on top of a knowledge graph (KG). The system takes a natural language question, retrieves relevant information from the graph, and uses that information to generate an answer.

The system consists of six components:

1. **Graph Store**: Manages the knowledge graph data structure
2. **KG Linker**: Links natural language queries to graph entities and paths
3. **Entity Linker**: Matches entities from text to graph nodes
4. **Triplet Retriever**: Retrieves relevant triplets from the graph
5. **Path Retriever**: Finds paths between entities in the graph
6. **Query Engine**: Orchestrates all components to answer questions

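### Setup
If you haven't already, install the toolkit and its dependencies by following the instructions in [README.md](../../byokg-rag/README.md).
The first cell below checks that the package imports correctly. If the import fails, run the `pip install` cell that follows and then retry the import.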

In [None]:
import graphrag_toolkit.byokg_rag

In [None]:
!pip install https://github.com/awslabs/graphrag-toolkit/archive/refs/tags/v3.11.2.zip#subdirectory=byokg-rag

### Graph Store
The `LocalKGStore` class provides an interface for working with the knowledge graph. Here we:
1. Initialize the graph store
2. Load data from a CSV file
3. Get basic statistics about the graph
4. Examine sample edges for a specific node (Wynton Marsalis)
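
To make the idea of one-hop edges concrete, here is a minimal sketch of a triplet store in plain Python. It is not how `LocalKGStore` is implemented; the toy triplets, the relation names, and the helper `build_one_hop_index` are all invented for illustration and are not taken from `freebase_tiny_kg.csv`.

```python
from collections import defaultdict

# Toy (head, relation, tail) triplets, invented for illustration only.
TRIPLETS = [
    ("Person A", "born_in", "City B"),
    ("City B", "film_genre", "Genre C"),
    ("Person A", "plays", "Instrument D"),
]


def build_one_hop_index(triplets):
    """Map each head node to {relation: [tail, ...]} (hypothetical helper)."""
    index = defaultdict(lambda: defaultdict(list))
    for head, relation, tail in triplets:
        index[head][relation].append(tail)
    return index


index = build_one_hop_index(TRIPLETS)
# Every node appears as a head or a tail of at least one triplet.
nodes = {node for head, _, tail in TRIPLETS for node in (head, tail)}

print(f"The toy graph has {len(nodes)} nodes and {len(TRIPLETS)} edges.")
print("One-hop edges of 'Person A':", dict(index["Person A"]))
```

Indexing edges by head node keeps neighbor lookups cheap, which matters because every retriever in this notebook starts from a node and expands outward.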

In [None]:
from graphrag_toolkit.byokg_rag.graphstore import LocalKGStore

graph_store = LocalKGStore()
graph_store.read_from_csv('freebase_tiny_kg.csv')
# Print graph statistics
schema = graph_store.get_schema()
number_of_nodes = len(graph_store.nodes())
number_of_edges = len(graph_store.get_triplets())
print(f"The graph has {number_of_nodes} nodes and {number_of_edges} edges.")

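# Let's also look at a few neighboring edges of the node "Wynton Marsalis"
import random
one_hop_edges = graph_store.get_one_hop_edges(["Wynton Marsalis"])
neighbor_edges = list(one_hop_edges["Wynton Marsalis"].items())
# Sample at most 3 edges so the call does not fail for nodes with fewer neighbors
sample_triplets = random.sample(neighbor_edges, min(3, len(neighbor_edges)))
print("Some neighboring edges of node 'Wynton Marsalis' are: ", sample_triplets)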


### Question Answering

We define a sample question and its ground truth answer to test our system. The question requires reasoning through multiple hops in the knowledge graph to find the answer.

In [None]:
question = "What genre of film is associated with the place where Wynton Marsalis was born?"
answer = "Backstage Musical"

### KG Linker
The `KGLinker` uses an LLM (Claude 3.5 Sonnet) to:
1. Extract entities from the question
2. Identify potential relationship paths in the graph
3. Generate initial responses based on its knowledge
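
The linker's output is later turned into a dictionary with the keys `entity-extraction`, `path-extraction`, and `draft-answer-generation`. The sketch below shows one way such a parse could work, assuming the LLM reply wraps each section in tags named after those keys. That reply format, the sample reply, and the helper `parse_tagged_sections` are assumptions made for illustration; the actual prompt format used by `KGLinker` may differ.

```python
import re

KEYS = ["entity-extraction", "path-extraction", "draft-answer-generation"]

# Hypothetical LLM reply; the tag-per-section layout is an assumption.
SAMPLE_REPLY = """
<entity-extraction>
Person A
</entity-extraction>
<path-extraction>
born_in -> film_genre
</path-extraction>
<draft-answer-generation>
Genre C
</draft-answer-generation>
"""


def parse_tagged_sections(reply, keys):
    """Return {key: [non-empty lines inside <key>...</key>]} (hypothetical helper)."""
    artifacts = {}
    for key in keys:
        tag = re.escape(key)
        match = re.search(rf"<{tag}>(.*?)</{tag}>", reply, re.DOTALL)
        body = match.group(1) if match else ""
        artifacts[key] = [line.strip() for line in body.splitlines() if line.strip()]
    return artifacts


artifacts_sketch = parse_tagged_sections(SAMPLE_REPLY, KEYS)
print(artifacts_sketch)
```

Returning an empty list for a missing section keeps downstream steps simple, since they can iterate over every key without checking whether the LLM produced it.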

In [None]:

from graphrag_toolkit.byokg_rag.graph_connectors import KGLinker
from graphrag_toolkit.byokg_rag.llm import BedrockGenerator

# Initialize the LLM
llm_generator = BedrockGenerator(
    model_name='us.anthropic.claude-3-5-sonnet-20240620-v1:0',
    region_name='us-west-2',
)

kg_linker = KGLinker(graph_store=graph_store, llm_generator=llm_generator)
response = kg_linker.generate_response(
    question=question,
    schema=schema,
    graph_context="Not provided. Use the above schema to understand the graph.",
)
response


In [None]:
artifacts = kg_linker.parse_response(response)
artifacts

### Entity Linking
The `EntityLinker` uses fuzzy string matching to:
1. Match extracted entities to actual nodes in the graph
2. Link potential answers to graph nodes
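
As a rough illustration of fuzzy matching, the sketch below links noisy mentions to node names with the standard library's `difflib`. This is a stand-in technique, not the algorithm behind `FuzzyStringIndex`; the helper `link_mentions` and the `cutoff` value are assumptions chosen for the example.

```python
import difflib


def link_mentions(mentions, node_names, cutoff=0.6):
    """Return the best-matching node name for each mention that clears the cutoff."""
    # Compare in lowercase, but return the original node spelling.
    lowered = {name.lower(): name for name in node_names}
    linked = []
    for mention in mentions:
        matches = difflib.get_close_matches(
            mention.lower(), list(lowered), n=1, cutoff=cutoff
        )
        if matches:
            linked.append(lowered[matches[0]])
    return linked


node_names = ["Wynton Marsalis", "Backstage Musical", "New York Times"]
# The first mention has a typo; the second has no plausible match and is dropped.
print(link_mentions(["wynton marsallis", "xyzzy"], node_names))
```

Dropping mentions that fall below the cutoff is a design choice: an unlinked mention costs less than a retrieval that starts from the wrong node.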

In [None]:
from graphrag_toolkit.byokg_rag.indexing import FuzzyStringIndex
from graphrag_toolkit.byokg_rag.graph_retrievers import EntityLinker

# Add graph nodes text for string matching
string_index = FuzzyStringIndex()
string_index.add(graph_store.nodes())
retriever = string_index.as_entity_matcher()
entity_linker = EntityLinker(retriever=retriever)

linked_entities = entity_linker.link(artifacts["entity-extraction"], return_dict=False)
linked_answers = entity_linker.link(artifacts["draft-answer-generation"], return_dict=False)
linked_entities, linked_answers

### Triplet Retrieval
The `AgenticRetriever` uses an LLM to:
1. Navigate the graph starting from linked entities
2. Select relevant relations based on the question
3. Expand those relations and decide which entities to explore next
4. Return the relevant triplets (head -> relation -> tail) for the question
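
The loop described above can be sketched without an LLM by swapping in a simple keyword-overlap rule for relation selection. Everything here is hypothetical: the toy graph, the `select_relations` stand-in, and the `max_hops` parameter are illustrative and do not reflect the internals of `AgenticRetriever`.

```python
# Toy graph as {head: {relation: [tails]}}, invented for illustration.
EDGES = {
    "Person A": {"born_in": ["City B"], "plays": ["Instrument D"]},
    "City B": {"film_genre": ["Genre C"]},
}


def select_relations(question, relations):
    """Stand-in for the LLM: keep relations that share a word with the question."""
    words = set(question.lower().replace("?", "").split())
    return [rel for rel in relations if words & set(rel.lower().split("_"))]


def iterative_retrieve(question, source_nodes, edges, max_hops=2):
    """Expand selected relations hop by hop and collect verbalized triplets."""
    frontier = list(source_nodes)
    seen = set(source_nodes)
    triplets = []
    for _ in range(max_hops):
        next_frontier = []
        for head in frontier:
            out_edges = edges.get(head, {})
            for relation in select_relations(question, out_edges):
                for tail in out_edges[relation]:
                    triplets.append(f"{head} -> {relation} -> {tail}")
                    if tail not in seen:
                        seen.add(tail)
                        next_frontier.append(tail)
        if not next_frontier:
            break  # nothing new to explore
        frontier = next_frontier
    return triplets


question_sketch = "What film genre is linked to the city where Person A was born?"
for line in iterative_retrieve(question_sketch, ["Person A"], EDGES):
    print(line)
```

Pruning relations at every hop is what keeps the traversal from fanning out over the whole graph; in the real retriever that pruning decision is delegated to the LLM.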


In [None]:
from graphrag_toolkit.byokg_rag.graph_retrievers import AgenticRetriever
from graphrag_toolkit.byokg_rag.graph_retrievers import GTraversal, TripletGVerbalizer
graph_traversal = GTraversal(graph_store)
graph_verbalizer = TripletGVerbalizer()
triplet_retriever = AgenticRetriever(
    llm_generator=llm_generator, 
    graph_traversal=graph_traversal,
    graph_verbalizer=graph_verbalizer)

In [None]:
triplet_context = triplet_retriever.retrieve(query=question, source_nodes=linked_entities)
triplet_context

### Path Retrieval
The `PathRetriever` uses the identified metapaths and candidate answers to:
1. Retrieve actual paths in the graph following the metapath
2. Retrieve shortest paths connecting question entities and candidate answers (if any) 
3. Verbalize the paths for context
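
These three steps can be sketched on a toy graph as follows. The sketch assumes directed edges and uses breadth-first search for shortest paths; the helpers `follow_metapath`, `shortest_path`, and `verbalize`, as well as the toy data, are illustrative and are not the `PathRetriever` implementation.

```python
from collections import deque

# Toy graph as {head: {relation: [tails]}}, invented for illustration.
EDGES = {
    "Person A": {"born_in": ["City B"], "plays": ["Instrument D"]},
    "City B": {"film_genre": ["Genre C"]},
}


def follow_metapath(edges, start, metapath):
    """Return every concrete path from `start` that follows the relation sequence."""
    paths = [[start]]
    for relation in metapath:
        extended = []
        for path in paths:
            for tail in edges.get(path[-1], {}).get(relation, []):
                extended.append(path + [relation, tail])
        paths = extended  # paths that cannot take this relation are dropped
    return paths


def shortest_path(edges, source, target):
    """Breadth-first search over directed edges; returns None if unreachable."""
    queue = deque([[source]])
    visited = {source}
    while queue:
        path = queue.popleft()
        node = path[-1]
        if node == target:
            return path
        for relation, tails in edges.get(node, {}).items():
            for tail in tails:
                if tail not in visited:
                    visited.add(tail)
                    queue.append(path + [relation, tail])
    return None


def verbalize(path):
    """Render an alternating node/relation list as readable text."""
    return " -> ".join(path)


for path in follow_metapath(EDGES, "Person A", ["born_in", "film_genre"]):
    print("Metapath match:", verbalize(path))
print("Shortest path:", verbalize(shortest_path(EDGES, "Person A", "Genre C")))
```

Metapath following checks a path shape the LLM proposed, while the shortest-path search needs only a start node and a candidate answer, so the two can surface different evidence.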

In [None]:
from graphrag_toolkit.byokg_rag.graph_retrievers import PathRetriever
from graphrag_toolkit.byokg_rag.graph_retrievers import GTraversal, PathVerbalizer
graph_traversal = GTraversal(graph_store)
path_verbalizer = PathVerbalizer()
path_retriever = PathRetriever(
    graph_traversal=graph_traversal,
    path_verbalizer=path_verbalizer)

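# Split each extracted metapath string ("rel1 -> rel2 -> ...") into a list of relations
metapaths = [[component.strip() for component in path.split("->")] for path in artifacts["path-extraction"]]
# Also add the 1-hop and 2-hop prefixes of longer metapaths as additional candidates
shortened_paths = []
for path in metapaths:
    if len(path) > 1:
        shortened_paths.append(path[:1])
for path in metapaths:
    if len(path) > 2:
        shortened_paths.append(path[:2])
metapaths += shortened_paths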
path_context = path_retriever.retrieve(linked_entities, metapaths, linked_answers)
path_context

### Retrieval Evaluation

We evaluate whether the answer is found in the retrieved context.
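
This check is a plain substring test over the joined context. A minimal sketch of it as a reusable function is shown below; the helper name `answer_retrieved` and the case-insensitive comparison are assumptions added for illustration (the notebook cell itself compares case-sensitively).

```python
def answer_retrieved(answer, context_lines):
    """Return True if the answer string occurs anywhere in the retrieved context."""
    haystack = "\n".join(context_lines).casefold()
    return answer.casefold() in haystack


# Toy context lines, invented for illustration.
toy_context = [
    "Person A -> born_in -> City B",
    "City B -> film_genre -> Genre C",
]
print(answer_retrieved("genre c", toy_context))
print(answer_retrieved("Genre Z", toy_context))
```

A substring test is a coarse metric: it reports whether the answer was retrieved at all, not whether it was retrieved for the right reason or how much irrelevant context came with it.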

In [None]:
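# Merge both context lists and drop duplicates
context = list(set(triplet_context + path_context))
if answer in '\n'.join(context):
    print(f"Success! Ground-truth answer `{answer}` retrieved!")
else:
    print("Failure: ground-truth answer not found in the retrieved context.")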


### BYOKG RAG Pipeline

The `ByoKGQueryEngine` combines all components to:
1. Process natural language questions
2. Retrieve relevant context from the graph
3. Generate answers based on the retrieved information

In [None]:
from graphrag_toolkit.byokg_rag.byokg_query_engine import ByoKGQueryEngine
byokg_query_engine = ByoKGQueryEngine(
    graph_store=graph_store,
    kg_linker=kg_linker,
    triplet_retriever=triplet_retriever,
    path_retriever=path_retriever,
    entity_linker=entity_linker
)

In [None]:
retrieved_context = byokg_query_engine.query(question)
answers, response = byokg_query_engine.generate_response(question, "\n".join(retrieved_context))

print("Retrieved context: ", "\n".join(retrieved_context))
print("Generated answers: ", answers)
if answer in '\n'.join(answers):
    print(f"Success! Ground-truth answer `{answer}` retrieved!")
else:
    print("Failure: ground-truth answer not found in the generated answers.")


### Testing with Another Question

In [None]:
question = "What are all airports in the city where New York Times is circulated?"
retrieved_context = byokg_query_engine.query(question)


In [None]:
print("Retrieved context: ", "\n".join(retrieved_context))
answers, response = byokg_query_engine.generate_response(question, "\n".join(retrieved_context))
print("Generated response: ", response)


### Testing with Scoring-based Retriever
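
A scoring-based retriever ranks candidate triplets against the query and keeps the top-k instead of asking an LLM which relations to follow. The sketch below illustrates that idea with a token-overlap score standing in for the reranker model. The scoring function, the helper names, and the toy candidates are assumptions for illustration; `LocalGReranker` uses a learned model, not this heuristic.

```python
def overlap_score(query, text):
    """Fraction of query tokens that also appear in the verbalized triplet."""
    query_tokens = set(query.lower().split())
    text_tokens = set(text.lower().replace("->", " ").replace("_", " ").split())
    if not query_tokens:
        return 0.0
    return len(query_tokens & text_tokens) / len(query_tokens)


def rerank(query, candidates, topk=2):
    """Sort candidates by descending score and keep the best `topk`."""
    ranked = sorted(candidates, key=lambda c: overlap_score(query, c), reverse=True)
    return ranked[:topk]


# Toy verbalized triplets, invented for illustration.
candidates = [
    "Person A -> plays -> Instrument D",
    "City B -> film_genre -> Genre C",
    "Person A -> born_in -> City B",
]
query_sketch = "film genre of the city where person a was born"
for line in rerank(query_sketch, candidates, topk=2):
    print(line)
```

Compared with the agentic approach, scoring needs no LLM calls during traversal, at the cost of depending on how well the scoring model captures multi-hop relevance.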

In [None]:
from graphrag_toolkit.byokg_rag.graph_retrievers import GraphScoringRetriever
from graphrag_toolkit.byokg_rag.graph_retrievers import GTraversal, TripletGVerbalizer
from graphrag_toolkit.byokg_rag.graph_retrievers import LocalGReranker
graph_traversal = GTraversal(graph_store)
graph_verbalizer = TripletGVerbalizer()

device = "cuda"  # change to "cpu" if no GPU is available
graph_reranker = LocalGReranker(model_name='BAAI/bge-reranker-v2-m3', topk=10, device=device)
triplet_retriever = GraphScoringRetriever(
    graph_traversal=graph_traversal,
    graph_verbalizer=graph_verbalizer,
    graph_reranker=graph_reranker
    )

In [None]:
triplet_context = triplet_retriever.retrieve(query="What genre of film is associated with the place where Wynton Marsalis was born?",
                                            source_nodes=["Wynton Marsalis"])
print("Retrieved context: ", "\n".join(triplet_context))
answer = "Backstage Musical"
if answer in '\n'.join(triplet_context):
    print(f"Success! Ground-truth answer `{answer}` retrieved!")
else:
    print("Failure: ground-truth answer not found in the retrieved context.")
