<a href="https://colab.research.google.com/github/KaifAhmad1/code-test/blob/main/RAG_with_Knowledge_Graph_OpenSource.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

### **RAG with Knowledge Graph**
- This notebook demonstrates how to build a knowledge graph using website data, visualize it, and answer complex queries that cannot be handled by traditional naive RAG.
- It includes optimizations for cybersecurity data, utilizing the ReLiK model for entity detection and relation extraction, and integrating with Kuzu DB for efficient querying.

**Installations**

In [None]:
!pip install -q langchain langchain-experimental langchain-core langchain-community langchain-groq pandas networkx
!pip install -q mpi4py pyvis ampligraph transformers relik kuzu pykeen torch

**Web Data Loading**
- This function loads documents from a list of URLs.
- Loading data from websites is the first step in any data processing pipeline. It ensures that we have the raw data needed for further analysis.

In [None]:
from langchain_community.document_loaders import WebBaseLoader

def load_data_from_websites(urls):
    """Load documents from a list of URLs."""
    web_base_loader = WebBaseLoader(urls)
    documents = web_base_loader.load()
    print(f"Loaded {len(documents)} documents.")
    return documents

# List of websites to load data from
websites = [
    "https://www.scmagazine.com/home/security-news/",
    "https://thehackernews.com/",
    "https://www.securityweek.com/",
    "https://www.darkreading.com/",
    "https://krebsonsecurity.com/",
    "https://www.bleepingcomputer.com/",
    "https://threatpost.com/",
    "https://www.cyberscoop.com/",
    "https://www.infosecurity-magazine.com/",
    "https://www.zdnet.com/topic/security/",
]

documents = load_data_from_websites(websites)

Loaded 10 documents.


**Split Documents into Chunks**
- This function splits the loaded documents into smaller chunks for easier processing.
- Splitting documents into chunks helps in managing large texts and ensures that each chunk can be processed independently.
- This function is used to break down long documents into manageable pieces for entity detection and triplet extraction.

In [None]:
from langchain.text_splitter import RecursiveCharacterTextSplitter

def chunk_data(documents, chunk_size=500, chunk_overlap=50):
    """Split documents into smaller chunks."""
    text_splitter = RecursiveCharacterTextSplitter(
        chunk_size=chunk_size,
        chunk_overlap=chunk_overlap,
        length_function=len,
        is_separator_regex=False,
    )
    chunks = text_splitter.split_documents(documents)
    print(f"Number of chunks created: {len(chunks)}")
    return chunks

chunks = chunk_data(documents)

Number of chunks created: 365


**Create a DataFrame of Chunks**
- This function converts the list of chunks into a pandas DataFrame.
- Converting chunks into a DataFrame allows for easier data manipulation and analysis using pandas.
- This function is used to create a structured format for the chunks, which will be used in subsequent steps.

In [None]:
import pandas as pd

def create_chunks_dataframe(chunks):
    """Create a DataFrame from document chunks."""
    data = {'content': [chunk.page_content for chunk in chunks]}
    return pd.DataFrame(data)

chunks_df = create_chunks_dataframe(chunks)
chunks_df.head()

Unnamed: 0,content
0,404: Not FoundCISO StoriesTopicsEventsPodcasts...
1,in any form without prior authorization.\n ...
2,The Hacker News | #1 Trusted Cybersecurity New...
3,Contact/Tip Us\n\n\n\nReach out to get featur...
4,"In 2023, no fewer than 94 percent of businesse..."


**Detect and Classify Entities**
- This function uses a Named Entity Recognition (NER) pipeline to detect entities in the text chunks.
- Detecting entities helps in identifying key concepts and entities in the text, which is crucial for understanding the content.
- This function is used to extract entities from the text chunks and create a DataFrame of detected entities.

In [None]:
import requests
from requests.adapters import HTTPAdapter
from requests.packages.urllib3.util.retry import Retry
from relik import Relik
from langchain.schema.document import Document
import torch
import pandas as pd

def extract_entities_with_relik(chunks_df, max_retries=3):
    """Extract entities from document chunks using the ReLiK entity extraction model."""
    session = requests.Session()
    retry = Retry(total=max_retries, backoff_factor=1, status_forcelist=[500, 502, 503, 504])
    adapter = HTTPAdapter(max_retries=retry)
    session.mount('http://', adapter)
    session.mount('https://', adapter)

    relik = Relik.from_pretrained(
        "relik-ie/relik-cie-tiny",
        device="cuda" if torch.cuda.is_available() else "cpu",
        precision="bf16" if torch.cuda.is_available() else "fp32",
        skip_metadata=True,  # don't load index metadata to keep low memory requirements
    )

    documents = [Document(page_content=chunk) for chunk in chunks_df['content']]
    entities = []

    for doc in documents:
        print(f"Processing document: {doc.page_content[:50]}...")  # Print the first 50 characters of the document
        if not doc.page_content.strip():
            print("Warning: Empty document content")
            continue

        # Use the ReLiK model to extract entities
        try:
            relik_output = relik(doc.page_content)
            for span in relik_output.spans:
                entities.append({
                    'Entity': span.text,
                    'Type': span.label
                })
        except Exception as e:
            print(f"Error processing document: {e}")

    entities_df = pd.DataFrame(entities)
    return entities_df

entities_df = extract_entities_with_relik(chunks_df)
entities_df.head()

                ___              __         
               /\_ \      __    /\ \        
 _ __     __   \//\ \    /\_\   \ \ \/'\    
/\`'__\ /'__`\   \ \ \   \/\ \   \ \ , <    
\ \ \/ /\  __/    \_\ \_  \ \ \   \ \ \\`\  
 \ \_\ \ \____\   /\____\  \ \_\   \ \_\ \_\
  \/_/  \/____/   \/____/   \/_/    \/_/\/_/
                                            
                                            

[2024-08-13 05:01:20,517] [INFO] [relik.inference.annotator.from_pretrained:700] [PID:561] [RANK:0] Loading Relik from relik-ie/relik-cie-tiny[39m
[2024-08-13 05:01:20,521] [INFO] [relik.inference.annotator.from_pretrained:701] [PID:561] [RANK:0] {
    '_target_': 'relik.inference.annotator.Relik',
    'index': {
        'span': {
            '_target_': 'relik.retriever.indexers.inmemory.InMemoryDocumentIndex.from_pretrained',
            'name_or_path': 'relik-ie/index-e5-small-v2-wikipedia-matryoshka',
        },
        'triplet': {
            '_target_': 'relik.retriever.indexers.

Unnamed: 0,Entity,Type
0,2024,--NME--
1,CyberRisk Alliance,--NME--
2,this website,--NME--
3,CyberRisk Alliance,--NME--
4,Hacker News,Hacker News


**Initializing LLM for Knowledge Extraction**
- Initialize the Mistral LLM using Groq for knowledge extraction.

In [None]:
from langchain_groq import ChatGroq
GROQ_API_KEY = "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"

# Initialize the Mistral LLM using Groq
llm = ChatGroq(
    temperature=0,
    model="mixtral-8x7b-32768",
    api_key=GROQ_API_KEY
)

**Extract Triplets using ReLiK**
- Use the ReLiK model for relation extraction.
- Extracting triplets helps in understanding the relationships between entities, which is essential for building a knowledge graph.
- This function is used to extract triplets from the text chunks and create a list of triplets.

In [None]:
import requests
from requests.adapters import HTTPAdapter
from requests.packages.urllib3.util.retry import Retry
import torch
from relik import Relik

def extract_triplets_relik(chunks, max_retries=3):
    """Extract triplets using the ReLiK model."""
    session = requests.Session()
    retry = Retry(total=max_retries, backoff_factor=1, status_forcelist=[500, 502, 503, 504])
    adapter = HTTPAdapter(max_retries=retry)
    session.mount('http://', adapter)
    session.mount('https://', adapter)

    relik = Relik.from_pretrained(
        "relik-ie/relik-relation-extraction-small-wikipedia",
        device="cuda" if torch.cuda.is_available() else "cpu",
        precision="bf16" if torch.cuda.is_available() else "fp32",
        skip_metadata=True,
    )
    print("ReLiK model loaded successfully")

    all_triplets = []
    for chunk_id, chunk in enumerate(chunks):
        if not chunk.page_content or not isinstance(chunk.page_content, str):
            print(f"Skipping invalid chunk: {chunk}")
            continue

        print(f"Processing chunk with length: {len(chunk.page_content)}")

        try:
            relik_output = relik(chunk.page_content)
            if relik_output is None:
                print(f"ReLiK output is None for chunk with length: {len(chunk.page_content)}")
                continue

            for triplet in relik_output.triplets:
                all_triplets.append({
                    'chunk_id': chunk_id,
                    'Subject': triplet.subject.text,
                    'Predicate': triplet.label,
                    'Object': triplet.object.text
                })
        except IndexError as e:
            print(f"IndexError encountered: {e}")
            # No need to check the length of windows if relik_output is None
            if relik_output is not None:
                try:
                    if hasattr(relik_output, 'windows'):
                        print(f"Length of windows: {len(relik_output.windows)}")
                    if hasattr(relik_output, 'windows_candidates'):
                        print(f"Length of windows_candidates: {len(relik_output.windows_candidates)}")
                except Exception as inner_e:
                    print(f"Inner exception while accessing relik_output attributes: {inner_e}")
        except Exception as e:
            print(f"Unexpected error encountered: {e}")

    return all_triplets

triplets = extract_triplets_relik(chunks)
print(f"Number of triplets extracted: {len(triplets)}")

                ___              __         
               /\_ \      __    /\ \        
 _ __     __   \//\ \    /\_\   \ \ \/'\    
/\`'__\ /'__`\   \ \ \   \/\ \   \ \ , <    
\ \ \/ /\  __/    \_\ \_  \ \ \   \ \ \\`\  
 \ \_\ \ \____\   /\____\  \ \_\   \ \_\ \_\
  \/_/  \/____/   \/____/   \/_/    \/_/\/_/
                                            
                                            





config.yaml:   0%|          | 0.00/619 [00:00<?, ?B/s]

[2024-08-13 05:17:09,138] [INFO] [relik.inference.annotator.from_pretrained:700] [PID:561] [RANK:0] Loading Relik from relik-ie/relik-relation-extraction-small-wikipedia[39m
[2024-08-13 05:17:09,142] [INFO] [relik.inference.annotator.from_pretrained:701] [PID:561] [RANK:0] {
    '_target_': 'relik.inference.annotator.Relik',
    'index': {
        'triplet': {
            '_target_': 'relik.retriever.indexers.inmemory.InMemoryDocumentIndex.from_pretrained',
            'name_or_path': 'relik-ie/encoder-e5-small-v2-wikipedia-relations-index',
        },
    },
    'metadata_fields': [],
    'reader': {
        '_target_': 'relik.reader.pytorch_modules.triplet.RelikReaderForTripletExtraction',
        'transformer_model': 'relik-ie/relik-reader-deberta-v3-small-re-wikipedia',
    },
    'retriever': {
        'triplet': {
            '_target_': 'relik.retriever.pytorch_modules.model.GoldenRetriever',
            'question_encoder': 'relik-ie/encoder-e5-small-v2-wikipedia-relations',
  

config.json:   0%|          | 0.00/895 [00:00<?, ?B/s]

configuration_relik.py:   0%|          | 0.00/1.70k [00:00<?, ?B/s]

model.safetensors:   0%|          | 0.00/586M [00:00<?, ?B/s]

tokenizer_config.json:   0%|          | 0.00/6.94k [00:00<?, ?B/s]

spm.model:   0%|          | 0.00/2.46M [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/8.65M [00:00<?, ?B/s]

added_tokens.json:   0%|          | 0.00/613 [00:00<?, ?B/s]

special_tokens_map.json:   0%|          | 0.00/1.42k [00:00<?, ?B/s]

[2024-08-13 05:17:26,868] [INFO] [relik.inference.utils.load_reader:383] [PID:561] [RANK:0] Moving reader to `cuda`.[39m
[2024-08-13 05:17:26,874] [INFO] [relik.inference.utils.load_reader:386] [PID:561] [RANK:0] Setting precision of reader to `torch.bfloat16`.[39m
ReLiK model loaded successfully
Processing chunk with length: 461




[36m[2024-08-13 05:17:27,837] [DEBUG] [relik.reader.data.relik_reader_re_data.__iter__:399] [PID:561] [RANK:0] Dataset finished: 1 number of elements processed[39m
Processing chunk with length: 149
[36m[2024-08-13 05:17:28,658] [DEBUG] [relik.reader.data.relik_reader_re_data.__iter__:399] [PID:561] [RANK:0] Dataset finished: 1 number of elements processed[39m
Processing chunk with length: 421
[36m[2024-08-13 05:17:29,459] [DEBUG] [relik.reader.data.relik_reader_re_data.__iter__:399] [PID:561] [RANK:0] Dataset finished: 2 number of elements processed[39m
Processing chunk with length: 440
[36m[2024-08-13 05:17:30,288] [DEBUG] [relik.reader.data.relik_reader_re_data.__iter__:399] [PID:561] [RANK:0] Dataset finished: 3 number of elements processed[39m
Processing chunk with length: 496
[36m[2024-08-13 05:17:31,501] [DEBUG] [relik.reader.data.relik_reader_re_data.__iter__:399] [PID:561] [RANK:0] Dataset finished: 2 number of elements processed[39m
Processing chunk with length: 496


**Analyze Relationships**
- This function converts the list of triplets into a DataFrame.
- Converting triplets into a DataFrame allows for easier analysis and manipulation of the relationships between entities.
- This function is used to create a structured format for the triplets, which will be used in subsequent steps.

In [33]:
def extract_relationships(triplets):
    """Extract relationships from the triplets."""
    relationships = []
    for triplet in triplets:
        relationships.append((triplet["Subject"], triplet["Predicate"], triplet["Object"]))
    return relationships

# Convert the list of relationships into a DataFrame
relationships = extract_relationships(triplets)
triplets_df = pd.DataFrame(relationships, columns=["subject", "predicate", "object"])
triplets_df.head()

TypeError: list indices must be integers or slices, not str

**Calculate Contextual Proximity**
- This function calculates the contextual proximity between nodes by merging the DataFrame with itself and counting the occurrences of node pairs within the same chunk.
- Contextual proximity helps in understanding the relationships between entities that appear together in the same context.
- This function is used to create a DataFrame of contextual proximity relationships between entities.


In [None]:
import numpy as np

def calculate_contextual_proximity(df):
    """Calculate contextual proximity between nodes."""
    long_format_df = pd.melt(
        df, id_vars=["chunk_id"], value_vars=["subject", "object"], value_name="node"
    )
    long_format_df.drop(columns=["variable"], inplace=True)
    merged_df = pd.merge(long_format_df, long_format_df, on="chunk_id", suffixes=("_1", "_2"))
    self_loops_index = merged_df[merged_df["node_1"] == merged_df["node_2"]].index
    merged_df = merged_df.drop(index=self_loops_index).reset_index(drop=True)

    # Convert chunk_id to string before joining
    merged_df['chunk_id'] = merged_df['chunk_id'].astype(str)

    grouped_df = (
        merged_df.groupby(["node_1", "node_2"])
        .agg({"chunk_id": [",".join, "count"]})
        .reset_index()
    )
    grouped_df.columns = ["node_1", "node_2", "chunk_id", "count"]
    grouped_df.replace("", np.nan, inplace=True)
    grouped_df.dropna(subset=["node_1", "node_2"], inplace=True)
    grouped_df = grouped_df[grouped_df["count"] != 1]
    grouped_df["edge"] = "contextual proximity"
    return grouped_df

# Calculate contextual proximity
contextual_proximity_df = calculate_contextual_proximity(triplets_df)
contextual_proximity_df.head()

Unnamed: 0,node_1,node_2,chunk_id,count,edge
0,.top,August 2024,137137,2,contextual proximity
1,.top,Chinese,137137137137137137140140140140140140,12,contextual proximity
2,.top,ICANN,138138143143144144144,7,contextual proximity
3,.top,Internet Corporation for Assigned Names and Nu...,138138,2,contextual proximity
4,.top,Jiangsu,143143143143,4,contextual proximity


**Merge DataFrames**
- This function merges the concepts DataFrame with the contextual proximity DataFrame and aggregates the data.
- Merging the DataFrames allows for a comprehensive view of the relationships between entities.
- This function is used to create a merged DataFrame that will be used to create the graph.

In [27]:
def merge_dataframes(concepts_df, contextual_proximity_df):
    """Merge the concepts DataFrame with the contextual proximity DataFrame."""
    merged_df = pd.concat([concepts_df, contextual_proximity_df], axis=0, ignore_index=True, sort=False)
    return merged_df
# Merge DataFrames
merged_df = merge_dataframes(triplets_df, contextual_proximity_df)
merged_df.head()

Unnamed: 0,chunk_id,subject,predicate,object,node_1,node_2,count,edge
0,1,this website,digital rights management system,CyberRisk Alliance,,,,
1,3,RSS Feeds,has use,Social Media,,,,
2,3,RSS Feeds,has use,Social Media,,,,
3,3,RSS Feeds,has use,Social Media,,,,
4,3,Email Alerts,has use,Social Media,,,,


**Create NetworkX Graph**
- This function creates a NetworkX graph from the merged DataFrame, with nodes and edges representing the relationships between entities.
- Creating a graph allows for visualizing and analyzing the relationships between entities.
- This function is used to create a graph that will be used for further analysis and visualization.

In [31]:
import networkx as nx
from pyvis.network import Network
import torch
from pykeen.triples import TriplesFactory
from pykeen.models import RotatE
from pykeen.training import SLCWATrainingLoop
from pykeen.evaluation import RankBasedEvaluator

def create_network_graph(triplets_df):
    """Create a network graph using NetworkX and Pyvis, with centrality metrics and PyKEEN RotatE embeddings."""
    G = nx.MultiDiGraph()

    for _, row in triplets_df.iterrows():
        subject = str(row['subject'])
        obj = str(row['object'])
        predicate = str(row['predicate'])

        # Add nodes and edges
        G.add_node(subject, label=subject, title=subject)
        G.add_node(obj, label=obj, title=obj)
        G.add_edge(subject, obj, label=predicate, title=predicate)

    # Calculate centrality metrics
    degree_centrality = nx.degree_centrality(G)
    betweenness_centrality = nx.betweenness_centrality(G)
    pagerank = nx.pagerank(G)

    # Add centrality metrics as node attributes
    nx.set_node_attributes(G, degree_centrality, 'degree_centrality')
    nx.set_node_attributes(G, betweenness_centrality, 'betweenness_centrality')
    nx.set_node_attributes(G, pagerank, 'pagerank')

    # Pyvis Visualization
    net = Network(notebook=True, height="750px", width="100%", bgcolor="#222222", font_color="white")
    net.from_nx(G)

    # Customize node appearance based on centrality metrics
    for node in net.nodes:
        node['size'] = 10 + (degree_centrality.get(node['id'], 0) * 50)
        node['color'] = f"rgb({int(255 * betweenness_centrality.get(node['id'], 0))}, {int(255 * pagerank.get(node['id'], 0))}, 100)"

    net.set_options("""
    var options = {
      "nodes": {
        "shape": "dot",
        "font": {
          "size": 12
        }
      },
      "edges": {
        "width": 2,
        "color": {
          "inherit": true
        }
      },
      "physics": {
        "forceAtlas2Based": {
          "gravitationalConstant": -50,
          "centralGravity": 0.01,
          "springLength": 230,
          "springConstant": 0.18
        },
        "maxVelocity": 50,
        "solver": "forceAtlas2Based",
        "timestep": 0.22,
        "stabilization": {
          "iterations": 150
        }
      }
    }
    """)

    return net, G


# Create TriplesFactory for PyKEEN
triplets = merged_df[['subject', 'predicate', 'object']].values.tolist()
tf = TriplesFactory.from_labeled_triples(triplets)

# PyKEEN Training
model = RotatE(triples_factory=tf, embedding_dim=200, random_seed=42)
training_loop = SLCWATrainingLoop(model=model, triples_factory=tf)
losses = training_loop.train(num_epochs=50, batch_size=1024, use_gpu=True)

# Visualize the network graph
net, G = create_network_graph(merged_df)
net.show("network_graph.html")

TypeError: list indices must be integers or slices, not tuple

**PyKeen Training**

In [30]:
# Visualize the network graph
net, G = create_network_graph(merged_df)
net.show("network_graph.html")



AssertionError: 

**Metrics**
- Evaluation of the model


In [25]:
# Print evaluation metrics
print("Model Evaluation Metrics:")
print(metrics)

# Print top 5 nodes for each centrality metric
print("\nTop 5 nodes by Degree Centrality:")
print(sorted(nx.degree_centrality(G).items(), key=lambda x: x[1], reverse=True)[:5])

print("\nTop 5 nodes by Betweenness Centrality:")
print(sorted(nx.betweenness_centrality(G).items(), key=lambda x: x[1], reverse=True)[:5])

print("\nTop 5 nodes by PageRank:")
print(sorted(nx.pagerank(G).items(), key=lambda x: x[1], reverse=True)[:5])

Model Evaluation Metrics:
<pykeen.evaluation.rank_based_evaluator.RankBasedMetricResults object at 0x7f9c6221e260>

Top 5 nodes by Degree Centrality:
[('Windows', 0.1832797427652733), ('Microsoft', 0.13504823151125403), ('Google', 0.10610932475884244), ('U.S.', 0.08681672025723472), ('Android', 0.0594855305466238)]

Top 5 nodes by Betweenness Centrality:
[('U.S.', 0.003906674744085621), ('Google', 0.0018808477147635543), ('Ukraine', 0.0016180727071262511), ('Microsoft', 0.001518399428367274), ('Windows', 0.0010627501540405218)]

Top 5 nodes by PageRank:
[('U.S.', 0.08645919507541587), ('Russia', 0.04550903466068782), ('Ukraine', 0.03517356585833704), ('Microsoft', 0.03042955350789955), ('Google', 0.021278035108809445)]


In [26]:
# Function to get entity embeddings
def get_entity_embeddings(model, entity_name):
    entity_id = model.triples_factory.entity_to_id[entity_name]
    return model.entity_embeddings(torch.tensor([entity_id])).detach().numpy()

# Print embeddings for a few sample entities
sample_entities = list(G.nodes())[:5]  # Get first 5 entities
print("\nSample Entity Embeddings:")
for entity in sample_entities:
    print(f"{entity}: {get_entity_embeddings(model, entity)}")


Sample Entity Embeddings:


AttributeError: 'RotatE' object has no attribute 'triples_factory'

**Set up Kuzu DB and Create the Schema**
- We'll set up the Kuzu DB and create the schema:


In [None]:
import kuzu

def setup_kuzu_db(db_name):
    """Set up Kuzu DB and create the schema."""
    db = kuzu.Database(db_name)
    conn = kuzu.Connection(db)

    conn.execute("CREATE NODE TABLE Movie (name STRING, PRIMARY KEY(name))")
    conn.execute("CREATE NODE TABLE Person (name STRING, birthDate STRING, PRIMARY KEY(name))")
    conn.execute("CREATE REL TABLE ActedIn (FROM Person TO Movie)")

    # Create additional tables and relationships
    conn.execute("CREATE NODE TABLE Incident (name STRING, date STRING, type STRING, PRIMARY KEY(name))")
    conn.execute("CREATE REL TABLE InvolvedIn (FROM Person TO Incident)")
    conn.execute("CREATE REL TABLE Targeted (FROM Incident TO Movie)")

    return conn

# Set up Kuzu DB
conn = setup_kuzu_db("test_db")

**Insert Data into Kuzu DB**
- We'll insert data into the Kuzu DB

In [None]:
def insert_data_into_kuzu(conn, triplets_df):
    """Insert data into Kuzu DB."""
    for _, row in triplets_df.iterrows():
        subject, predicate, obj = row['subject'], row['predicate'], row['object']
        if predicate == 'ActedIn':
            conn.execute(f"CREATE (:Person {{name: '{subject}'}})-[:ActedIn]->(:Movie {{name: '{obj}'}})")
        elif predicate == 'InvolvedIn':
            conn.execute(f"CREATE (:Person {{name: '{subject}'}})-[:InvolvedIn]->(:Incident {{name: '{obj}'}})")
        elif predicate == 'Targeted':
            conn.execute(f"CREATE (:Incident {{name: '{subject}'}})-[:Targeted]->(:Movie {{name: '{obj}'}})")

# Insert data into Kuzu DB
insert_data_into_kuzu(conn, triplets_df)

**Create KuzuQAChain and Query the Graph**
- We'll create a KuzuQAChain for querying the graph


In [None]:
from langchain.chains import KuzuQAChain
from langchain_community.graphs import KuzuGraph

def create_kuzu_qa_chain(db, api_key):
    """Create KuzuQAChain for querying the graph."""
    graph = KuzuGraph(db)
    chain = KuzuQAChain.from_llm(llm, graph=graph, verbose=True)
    return chain

# Create KuzuQAChain
chain = create_kuzu_qa_chain(db, GROQ_API_KEY)

# Process queries
queries = [
    "List all details on BFSI security incidents in India.",
    "List all ransomware attacks targeting the healthcare industry in the last 7 days.",
    "Provide recent incidents related to Lockbit Ransomware gang.",
    "Provide recent incidents related to BlackBasta Ransomware."
]

# Query the graph
for query in queries:
    print(f"Query: {query}")
    result = chain.run(query)
    print(result)
    print("\n")

**Refresh Schema Information**
- We'll refresh the schema information



In [None]:
def refresh_schema(graph):
    """Refresh the schema information needed to generate Cypher statements."""
    graph.refresh_schema()
    print("Schema information refreshed.")

# Refresh schema
refresh_schema(graph)

**Add Indexing**
- We'll create indexes on the Kuzu DB

In [None]:
def create_indexes(conn):
    """Create indexes on the Kuzu DB."""
    safe_execute(conn, "CREATE INDEX ON Movie(name)")
    safe_execute(conn, "CREATE INDEX ON Person(name)")
    safe_execute(conn, "CREATE INDEX ON Incident(name)")

# Create indexes
create_indexes(conn)

**Add More Complex Queries**
- We'll run more complex queries using the KuzuQAChain

In [None]:
def run_complex_queries(chain):
    """Run more complex queries using the KuzuQAChain."""
    complex_queries = [
        "Which actors were involved in incidents related to Lockbit Ransomware?",
        "List all movies targeted by ransomware attacks in the last month.",
        "Who is the oldest actor involved in any incident?"
    ]

    for query in complex_queries:
        print(f"Query: {query}")
        result = chain.run(query)
        print(result)
        print("\n")

# Run complex queries
run_complex_queries(chain)