<a href="https://colab.research.google.com/github/KaifAhmad1/code-test/blob/main/KG_Enhanced_RAG.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [17]:
!pip install -qU langchain langchain-community faiss-cpu kuzu pyvis sentence-transformers transformers torch plotly pandas scikit-learn networkx
!pip install --upgrade -q torch torchvision

In [18]:
import kuzu
from langchain.document_loaders import WebBaseLoader
from langchain.text_splitter import CharacterTextSplitter
from langchain.vectorstores import FAISS
from langchain.embeddings import HuggingFaceEmbeddings
from langchain.llms import HuggingFacePipeline
from langchain.prompts import PromptTemplate
from langchain.chains import LLMChain
import networkx as nx
import plotly.graph_objects as go
import plotly.express as px
from sklearn.manifold import TSNE
import numpy as np
import pandas as pd

In [19]:
from transformers import pipeline

In [20]:
embeddings = HuggingFaceEmbeddings(model_name="all-MiniLM-L6-v2")

In [21]:
from huggingface_hub import notebook_login

notebook_login()

VBox(children=(HTML(value='<center> <img\nsrc=https://huggingface.co/front/assets/huggingface_logo-noborder.sv…

In [22]:
import transformers
import torch

model_id = "meta-llama/Meta-Llama-3.1-8B-Instruct"

pipeline = transformers.pipeline(
    "text-generation",
    model=model_id,
    model_kwargs={"torch_dtype": torch.bfloat16},
    device_map="auto",
)

messages = [
    {"role": "system", "content": "You are a pirate chatbot who always responds in pirate speak!"},
    {"role": "user", "content": "Who are you?"},
]

outputs = pipeline(
    messages,
    max_new_tokens=256,
)
print(outputs[0]["generated_text"][-1])

Loading checkpoint shards:   0%|          | 0/4 [00:00<?, ?it/s]

Setting `pad_token_id` to `eos_token_id`:128001 for open-end generation.


KeyboardInterrupt: 

In [None]:
llm = HuggingFacePipeline(pipeline=pipeline)

In [23]:
# Initialize Kuzu DB
db = kuzu.Database("my_knowledge_graph")
conn = kuzu.Connection(db)

In [25]:
# Create schema for the graph
conn.execute("CREATE NODE TABLE Entity (name STRING, PRIMARY KEY (name))")
conn.execute("CREATE REL TABLE Relation (FROM Entity TO Entity, predicate STRING)")

<kuzu.query_result.QueryResult at 0x7c3afc1c68c0>

In [26]:
websites = [
    "https://neurons-lab.com/",
    "https://neurons-lab.com/about-us/",
    "https://www.crunchbase.com/organization/neurons-lab",
]

In [27]:
loader = WebBaseLoader(websites)
documents = loader.load()

In [28]:
text_splitter = CharacterTextSplitter(chunk_size=1000, chunk_overlap=0)
texts = text_splitter.split_documents(documents)



In [29]:
vectorstore = FAISS.from_documents(texts, embeddings)

In [30]:
# Knowledge extraction and graph population (same as before)
kg_triple_extract_template = """
Extract up to 5 knowledge triplets from the text below in the form (subject, predicate, object).
Text: {text}
Triplets:
"""
kg_triple_extract_prompt = PromptTemplate(
    input_variables=["text"],
    template=kg_triple_extract_template,
)

kg_triple_extract_chain = LLMChain(llm=llm, prompt=kg_triple_extract_prompt)

for text in texts:
    triplets = kg_triple_extract_chain.run(text.page_content)
    for triplet in triplets.split('\n'):
        if triplet.strip():
            try:
                subject, predicate, obj = eval(triplet.strip())
                conn.execute("INSERT INTO Entity (name) VALUES ($1) ON CONFLICT DO NOTHING", [subject])
                conn.execute("INSERT INTO Entity (name) VALUES ($1) ON CONFLICT DO NOTHING", [obj])
                conn.execute("INSERT INTO Relation VALUES ($1, $2, $3)", [subject, obj, predicate])
            except:
                print(f"Failed to process triplet: {triplet}")

Setting `pad_token_id` to `eos_token_id`:128001 for open-end generation.


ValueError: Input length of input_ids is 177, but `max_length` is set to 20. This can lead to unexpected behavior. You should consider increasing `max_length` or, better yet, setting `max_new_tokens`.

In [None]:
# Function to retrieve graph data from Kuzu DB
def get_graph_data():
    nodes = conn.execute("MATCH (e:Entity) RETURN e.name").fetchall()
    edges = conn.execute("MATCH (e1:Entity)-[r:Relation]->(e2:Entity) RETURN e1.name, r.predicate, e2.name").fetchall()
    return nodes, edges

In [None]:
# Enhanced graph visualization using Plotly
def visualize_graph_plotly():
    nodes, edges = get_graph_data()
    G = nx.Graph()

    for node in nodes:
        G.add_node(node[0])

    for edge in edges:
        G.add_edge(edge[0], edge[2], label=edge[1])

    pos = nx.spring_layout(G)

    edge_x = []
    edge_y = []
    for edge in G.edges():
        x0, y0 = pos[edge[0]]
        x1, y1 = pos[edge[1]]
        edge_x.extend([x0, x1, None])
        edge_y.extend([y0, y1, None])

    edge_trace = go.Scatter(
        x=edge_x, y=edge_y,
        line=dict(width=0.5, color='#888'),
        hoverinfo='none',
        mode='lines')

    node_x = [pos[node][0] for node in G.nodes()]
    node_y = [pos[node][1] for node in G.nodes()]

    node_trace = go.Scatter(
        x=node_x, y=node_y,
        mode='markers',
        hoverinfo='text',
        marker=dict(
            showscale=True,
            colorscale='YlGnBu',
            reversescale=True,
            color=[],
            size=10,
            colorbar=dict(
                thickness=15,
                title='Node Connections',
                xanchor='left',
                titleside='right'
            ),
            line_width=2))

    node_adjacencies = []
    node_text = []
    for node, adjacencies in G.adjacency():
        node_adjacencies.append(len(adjacencies))
        node_text.append(f'{node}<br># of connections: {len(adjacencies)}')

    node_trace.marker.color = node_adjacencies
    node_trace.text = node_text

    fig = go.Figure(data=[edge_trace, node_trace],
                    layout=go.Layout(
                        title='Knowledge Graph',
                        titlefont_size=16,
                        showlegend=False,
                        hovermode='closest',
                        margin=dict(b=20,l=5,r=5,t=40),
                        annotations=[ dict(
                            text="",
                            showarrow=False,
                            xref="paper", yref="paper",
                            x=0.005, y=-0.002 ) ],
                        xaxis=dict(showgrid=False, zeroline=False, showticklabels=False),
                        yaxis=dict(showgrid=False, zeroline=False, showticklabels=False))
                    )

    fig.show()

In [None]:
# Embedding visualization
def visualize_embeddings():
    # Get embeddings
    doc_embeddings = [embeddings.embed_query(text.page_content) for text in texts]

    # Reduce dimensionality for visualization
    tsne = TSNE(n_components=3, random_state=42)
    vis_dims = tsne.fit_transform(doc_embeddings)

    # Create a DataFrame for Plotly
    df = pd.DataFrame(vis_dims, columns=['x', 'y', 'z'])
    df['text'] = [text.page_content[:100] + '...' for text in texts]  # Truncate text for readability

    # Create 3D scatter plot
    fig = px.scatter_3d(df, x='x', y='y', z='z', hover_data=['text'],
                        title='Document Embeddings Visualization')
    fig.show()

In [None]:
# Query function (same as before)
def query_graph(query):
    docs = vectorstore.similarity_search(query, k=2)
    context = "\n".join([doc.page_content for doc in docs])

    query_template = """
    Given the following context and question, provide a concise answer:
    Context: {context}
    Question: {question}
    Answer:
    """
    query_prompt = PromptTemplate(
        input_variables=["context", "question"],
        template=query_template,
    )
    query_chain = LLMChain(llm=llm, prompt=query_prompt)

    return query_chain.run({"context": context, "question": query})

In [None]:
# Main execution
if __name__ == "__main__":
    visualize_graph_plotly()
    visualize_embeddings()

    # Example query
    question = "How can Neurons Lab help with a fintech use case to solve fraud?"
    answer = query_graph(question)
    print(f"Question: {question}\nAnswer: {answer}")

    # Close the database connection
    conn.close()