# Knowledge Graph extraction from documents
Based on https://github.com/lamm-mit/GraphReasoning

This notebook creates a knowledge graph using functions in the GraphReasoning_Mod module which is adpated from https://github.com/lamm-mit/GraphReasoning to work with huggingface endpoints and for use with huggingface documents. 
- documents are scrapped using data.ipynb from pdfs, blogs, and youtube videos transcripts. 
- LLM model Mistral-Nemo-Instruct-2407
- Embedding model: dunzhang/stella_en_1.5B_v5 (chosen based on size and position on leaderboard DEC 2024)
- KG created without refinement loops in when identifying nodes. 
- KG uses simplify graph and additonal tools in GraphReasoning_Mod

- KG has metadata along edges and uses an updated generation prompt.
- KG does not keep source documents as nodes (could add later by referencing the pre-simplify.csv) 

_Note:Some file paths may have to be updated due to restructing of the repo_


## Setup

In [None]:
# %pip install -r requirements.txt -q

In [3]:
import sys
import os
root_path = "c:\\Users\\jonathan.kasprisin\\gitlab\\DNoK_GraphRAG"
os.chdir(root_path)
sys.path.append(root_path)

from GraphReasoning_Mod.graph_tools import *
from GraphReasoning_Mod.utils import *
from GraphReasoning_Mod.graph_generation import *
from GraphReasoning_Mod.graph_analysis import *


In [None]:
from langchain_huggingface import HuggingFaceEndpoint

#Initialize the model endpoint
HOST_URL_INF = ":8080" #Mistral-NeMo-Instruct-2407
MAX_NEW_TOKENS = 1012

TEMPERATURE = 0.2
TIMEOUT = 120
TOP_P = .9

llm = HuggingFaceEndpoint(
    endpoint_url=HOST_URL_INF,
    task="text-generation",
    max_new_tokens=MAX_NEW_TOKENS,
    do_sample=False,
    temperature = TEMPERATURE,
    timeout=TIMEOUT,
    top_p=TOP_P
)
#print(llm.invoke("What is HuggingFace?"))

In [2]:
from langchain_huggingface import HuggingFaceEmbeddings

model_name = "dunzhang/stella_en_1.5B_v5" #"BAAI/bge-small-en-v1.5" #dunzhang/stella_en_1.5B_v5
model_kwargs = {"device": "cpu"}
encode_kwargs = {"normalize_embeddings": True}
embd = HuggingFaceEmbeddings(
    model_name=model_name, model_kwargs=model_kwargs, encode_kwargs=encode_kwargs
)

## Import Documents

In [None]:
import pickle
import os

# load pickled documents
pickle_file_path = './data/storage/full_all_documents.pkl'
if os.path.exists(pickle_file_path):
    with open(pickle_file_path, 'rb') as f:
        all_pdf_docs, all_yt_docs, all_blog_docs = pickle.load(f)
else:
    print("Pickle file not found.")

#check if the documents are loaded
print("Number of PDF documents:", len(all_pdf_docs))
print("Number of YouTube documents:", len(all_yt_docs))
print("Number of blog documents:", len(all_blog_docs))


#standardize the metadata
all_pdf_docs, all_yt_docs, all_blog_docs = standardize_document_metadata(all_pdf_docs, all_yt_docs, all_blog_docs)

# Combine all documents into a single list
all_docs = all_pdf_docs+  all_yt_docs+  all_blog_docs

print(f"Total number of documents: {len(all_docs)}")

#print dictionary keys from metadata
print("Metadata keys for PDF documents:", all_pdf_docs[0].metadata.keys())
print("Metadata keys for yt documents:", all_yt_docs[0].metadata.keys())
print("Metadata keys for blog documents:", all_blog_docs[0].metadata.keys())


## Create create and save networkx graph

In [None]:
# # Initialize variables
# G_existing = None
# existing_node_embeddings = None
# failed_batches = []
# output_directory = 'data_no_refine3'  

# with open(f'{output_directory}/embeddings.pkl', 'rb') as f:
#     existing_node_embeddings = pickle.load(f)

# with open(f'{output_directory}/failed_batches.pkl', 'rb') as f:
#     failed_batches = pickle.load(f)

# graph_path=f'{output_directory}/final_augmented_graph.graphml'
# G_existing = nx.read_graphml(graph_path)
    


# # Process documents in chunks
# chunk_size = 1500
# batch_size = 5

# # Split all_docs into batches of size batch_size
# doc_batches = [all_docs[i:i + batch_size] for i in range(0, len(all_docs), batch_size)]

# doc_batches = doc_batches[453:]

# for batch_idx, doc_batch in tqdm(enumerate(doc_batches), total=len(doc_batches), desc="Processing batches..."):
#     try:
#         G_existing, existing_node_embeddings, res = add_new_subgraph_from_docs(
#             input_docs=doc_batch,
#             llm=llm,
#             embd=embd,
#             data_dir_output=f"./{output_directory}/",
#             verbatim=False,
#             size_threshold=10,
#             chunk_size=chunk_size,
#             do_Louvain_on_new_graph=True,
#             include_contextual_proximity=False,
#             repeat_refine=0,
#             similarity_threshold=1.0,
#             do_simplify_graph=False,
#             return_only_giant_component=False,
#             save_common_graph=False,
#             G_exisiting=G_existing,
#             graph_GraphML_exisiting=None,
#             existing_node_embeddings=existing_node_embeddings,
#             add_source_nodes=False
#         )

#         print(f"Processed batch {batch_idx}, updated graph stats:", res)
#         with open(f'{output_directory}/embeddings.pkl', 'wb') as f:
#             pickle.dump(existing_node_embeddings, f)
#         with open(f'{output_directory}/failed_batches.pkl', 'wb') as f:
#             pickle.dump(failed_batches, f)

#     except Exception as e:
#         # Log the failed batch index
#         failed_batches.append(batch_idx)
#         print(f"Error processing batch {batch_idx} with batch size {batch_size}: {e}")
#         traceback.print_exc()

#         with open(f'{output_directory}/error_log.txt', 'a', encoding='utf-8') as f:
#             f.write(f"Batch {batch_idx} failed. Error: {e}\n")

# # Final graph statistics and saving
# print("Final graph statistics:", res if 'res' in locals() else "No successful batches")
# print("Failed batch indices:", failed_batches)

## Simplify

In [16]:
from sklearn.neighbors import NearestNeighbors
import tqdm as tqdm

def simplify_graph_v2(graph_in, node_embeddings, embd, llm=None, similarity_threshold=0.95, use_llm=False,
                   data_dir_output='./', graph_root='simple_graph', verbatim=False, max_tokens=2048, 
                   temperature=0.3, generate=None):
    """
    Simplifies a graph by merging similar nodes. Modified for large graphs and memory constraints
    """
    if verbatim:
            print("calculating KNN...")

    graph = graph_in.copy()
    
    nodes = list(node_embeddings.keys())
    embeddings_matrix = np.array([np.array(node_embeddings[node]).flatten() for node in nodes])

    assert len(nodes) == embeddings_matrix.shape[0], "simplify_graph: Number of nodes and embeddings do not match"
    
    #check if any node embeddings are None
    none_in_embeddings = False
    if None in embeddings_matrix:
        none_in_embeddings = True

        
    assert not none_in_embeddings, "simplify_graph: None in embeddings"

     # Using NearestNeighbors to reduce memory usage
    nn = NearestNeighbors(metric='cosine', n_neighbors=30, n_jobs=-1)  # n_neighbors > 1 to compare
    nn.fit(embeddings_matrix)
    distances, indices = nn.kneighbors(embeddings_matrix)


    node_mapping = {}
    nodes_to_recalculate = set()
    merged_nodes = set()  # Keep track of nodes that have been merged

    if verbatim:
        print("Start merge")

    for i, neighbors in tqdm.tqdm(enumerate(indices), total=len(nodes), desc="Merging nodes..."):
        node_i = nodes[i]
        for j, dist in zip(neighbors[1:], distances[i][1:]):  # Skip self (index 0 is itself)
            if dist <= (1 - similarity_threshold):
                node_j = nodes[j]
                if node_i != node_j and node_j not in merged_nodes and node_i not in merged_nodes:
                    try:
                        if graph.degree(node_i) >= graph.degree(node_j):
                            node_to_keep, node_to_merge = node_i, node_j
                        else:
                            node_to_keep, node_to_merge = node_j, node_i

                        # if verbatim:
                        #     print(f"Merging: {node_to_merge} --> {node_to_keep}")

                        node_mapping[node_to_merge] = node_to_keep
                        nodes_to_recalculate.add(node_to_keep)
                        merged_nodes.add(node_to_merge)
                    except Exception as e:
                        print(f"Error merging nodes {node_i} and {node_j}: {e}")
    if verbatim:
        print ("Now relabel. ")
    # Create the simplified graph by relabeling nodes. removes adds edges from to_merge to to_keep and removes to_merge.
    new_graph = nx.relabel_nodes(graph, node_mapping, copy=True)
    if verbatim:
        print ("New graph generated, nodes relabled. ")
    # Recalculate embeddings for nodes that have been merged or renamed.
    recalculated_embeddings = regenerate_node_embeddings(new_graph, nodes_to_recalculate, embd, verbatim=verbatim)
    if verbatim:
        print ("Relcaulated embeddings... ")
    # Update the embeddings dictionary with the recalculated embeddings.
    updated_embeddings = {**node_embeddings, **recalculated_embeddings}

    # Remove embeddings for nodes that no longer exist in the graph.
    for node in merged_nodes:
        updated_embeddings.pop(node, None)
    if verbatim:
        print ("Now save graph... ")

    # Save the simplified graph to a file.
    graph_path = f'{data_dir_output}/{graph_root}_graphML_simplified.graphml'
    nx.write_graphml(new_graph, graph_path)

    if verbatim:
        print(f"Graph simplified and saved to {graph_path}")

    return new_graph, updated_embeddings

In [7]:
import pickle

output_directory = 'data_no_refine3'  

with open(f'{output_directory}/embeddings.pkl', 'rb') as f:
    existing_node_embeddings = pickle.load(f)


graph_path=f'{output_directory}/final_augmented_graph.graphml'
G_existing = nx.read_graphml(graph_path)

nodes = list(existing_node_embeddings.keys())
embeddings_matrix = np.array([np.array(existing_node_embeddings[node]).flatten() for node in nodes])

print(f"Len nodes: {len(nodes)} \nlen embeddings: {len(embeddings_matrix)}") 

#check if any node embeddings are None
if None in embeddings_matrix:
    print( "simplify_graph: None in embeddings")



Len nodes: 77039 
len embeddings: 77039


In [None]:
#gen

In [17]:
import tqdm as tqdm
try:

    #simplify to .95 threshold
    simplified_graph, simplified_embeddings = simplify_graph_v2(G_existing, existing_node_embeddings, embd, similarity_threshold=0.95, graph_root='0.95threshold', data_dir_output=output_directory, verbatim=True)
    
    with open(f'{output_directory}/0.95threshold_embeddings.pkl', 'wb') as f:
            pickle.dump(simplified_embeddings, f)

    res= graph_statistics_and_plots_for_large_graphs(
            simplified_graph, data_dir=output_directory,
            include_centrality=False, make_graph_plot=False,
            root='simple_graph')

    print("simple graph statistics:", res ) 
except Exception as e:
    print(f"Error simplifying graph: {e}")
    traceback.print_exc()





calculating KNN...
Start merge


Merging nodes...: 100%|██████████| 77039/77039 [00:00<00:00, 82362.08it/s]


Now relabel. 
New graph generated, nodes relabled. 


Recalculating embeddings: 100%|██████████| 14974/14974 [1:41:08<00:00,  2.47it/s]


Relcaulated embeddings... 
Now save graph... 
Graph simplified and saved to data_no_refine3/0.95threshold_graphML_simplified.graphml
simple graph statistics: ({'Number of Nodes': 37709, 'Number of Edges': 58504, 'Average Degree': 3.102919727386035, 'Density': 8.228810139455911e-05, 'Connected Components': 3171, 'Number of Communities': 3228}, False)


In [None]:
output_directory = "./data/generated_graphs/GR_no_refine3/"
with open(f'{output_directory}/0.95threshold_embeddings.pkl', 'rb') as f:
    simplified_embeddings = pickle.load(f)

simplified_graph = nx.read_graphml(f'{output_directory}/0.95threshold_graphML_simplified.graphml')

#simplify to .85 threshold
simplified_graph, simplified_embeddings = simplify_graph(simplified_graph, simplified_embeddings, embd, similarity_threshold=0.85, graph_root='0.85threshold', data_dir_output=output_directory)

print(f"-->number of embeddings: {len(simplified_embeddings)}")
simplified_embeddings = update_node_embeddings(simplified_graph, simplified_embeddings, embd, remove_embeddings_for_nodes_no_longer_in_graph=True)

print(f"-->number of embeddings after removed nodes: {len(simplified_embeddings)}")
with open(f'{output_directory}/0.85threshold_embeddings.pkl', 'wb') as f:
            pickle.dump(simplified_embeddings, f)

res= graph_statistics_and_plots_for_large_graphs(
            simplified_graph, data_dir=output_directory,
            include_centrality=False, make_graph_plot=False,
            root='simple_graph')

print("simple graph statistics:", res ) 

simple graph statistics: ({'Number of Nodes': 18933, 'Number of Edges': 33153, 'Average Degree': 3.5021391221676437, 'Density': 0.00018498516385842193, 'Connected Components': 510, 'Number of Communities': 639}, False)


In [None]:
#simplify to .75 threshold
simplified_graph, simplified_embeddings = simplify_graph(simplified_graph, simplified_embeddings, embd, similarity_threshold=0.75, graph_root='0.75threshold', data_dir_output=output_directory)

print(f"-->number of embeddings: {len(simplified_embeddings)}")
simplified_embeddings = update_node_embeddings(simplified_graph, simplified_embeddings, embd, remove_embeddings_for_nodes_no_longer_in_graph=True)

print(f"-->number of embeddings after removed nodes: {len(simplified_embeddings)}")

with open(f'{output_directory}/0.75threshold_embeddings.pkl', 'wb') as f:
            pickle.dump(simplified_embeddings, f)

res= graph_statistics_and_plots_for_large_graphs(
            simplified_graph, data_dir=output_directory,
            include_centrality=False, make_graph_plot=False,
            root='simple_graph')

print("simple graph statistics:", res ) 

simple graph statistics: ({'Number of Nodes': 13596, 'Number of Edges': 24305, 'Average Degree': 3.5753162694910268, 'Density': 0.00026298758878198064, 'Connected Components': 199, 'Number of Communities': 280}, False)
