In [1]:
def load_prompt(pth):
    """Read the UTF-8 text file at ``pth`` and return its entire contents as one string."""
    with open(pth, mode='r', encoding='utf-8') as prompt_file:
        return prompt_file.read()

def process_entity_extraction_prompt(content):
    """Substitute the delimiter placeholders of an extraction prompt with their default markers.

    :param content: Prompt text that may contain ``{completion_delimiter}``,
        ``{tuple_delimiter}`` and ``{record_delimiter}``.
    :return: The prompt with those three placeholders replaced; any other
        placeholder is left untouched.
    """
    # Applied in this order: completion marker, field separator, record separator.
    default_markers = (
        ('{completion_delimiter}', "<|COMPLETE|>"),
        ('{tuple_delimiter}', "<|>"),
        ('{record_delimiter}', "##"),
    )
    for placeholder, marker in default_markers:
        content = content.replace(placeholder, marker)
    return content
    
    
def load_document(folder_path):
    """Read every ``.txt`` file directly inside ``folder_path``.

    :param folder_path: Directory to scan (sub-directories are not visited).
    :return: List with the UTF-8 decoded contents of each ``.txt`` file,
        ordered by file name.
    """
    text_files_content = []
    # os.listdir() returns entries in arbitrary order; sorting keeps the
    # document order (and every index derived from it) identical between runs.
    for filename in sorted(os.listdir(folder_path)):
        if not filename.endswith('.txt'):
            continue
        file_path = os.path.join(folder_path, filename)
        # Skip directories that merely happen to be named "*.txt".
        if not os.path.isfile(file_path):
            continue
        with open(file_path, 'r', encoding='utf-8') as file:
            text_files_content.append(file.read())
    return text_files_content


def split_documents_into_chunks(documents, chunk_size=2000, overlap_size=400):
    """
    Splits documents into chunks of approximately chunk_size characters,
    ensuring each chunk ends at a sentence boundary with an overlap of overlap_size characters.

    Sentences are never split. A sentence that is longer than chunk_size on its
    own is emitted as a single oversized chunk, and the overlap carried into a
    new chunk is shortened when it would not leave room for the next sentence,
    so the function always makes progress.

    :param documents: List of documents (strings) to be split.
    :param chunk_size: Approximate maximum size of each chunk in characters.
    :param overlap_size: Desired overlap between chunks in characters.
    :return: List of chunks (across all documents, in order).
    """
    chunks = []
    for document in documents:
        sentences = sent_tokenize(document)
        current_chunk = []
        current_length = 0
        for sentence in sentences:
            sentence_length = len(sentence) + 1  # +1 for the joining space
            # Only flush when there is something to flush; an empty chunk can
            # never be made smaller, so the sentence is accepted as-is below.
            if current_chunk and current_length + sentence_length > chunk_size:
                chunks.append(' '.join(current_chunk).strip())

                # Collect trailing sentences of the flushed chunk as overlap.
                overlap_sentences = []
                overlap_length = 0
                for previous in reversed(current_chunk):
                    if overlap_length >= overlap_size:
                        break
                    overlap_sentences.insert(0, previous)
                    overlap_length += len(previous) + 1

                # Drop the oldest overlap sentences until the new sentence
                # fits; without this the same chunk would be flushed forever.
                while overlap_sentences and overlap_length + sentence_length > chunk_size:
                    overlap_length -= len(overlap_sentences.pop(0)) + 1

                current_chunk = overlap_sentences
                current_length = overlap_length

            current_chunk.append(sentence)
            current_length += sentence_length
        if current_chunk:
            chunks.append(' '.join(current_chunk).strip())
    return chunks


def entity_type_extraction(LLM, prompt, document, domain_name):
    """Ask the model which entity types are relevant for a sample of text.

    :param LLM: ``(engine, settings)`` pair. ``engine`` is called with a chat
        message list plus generation keyword arguments; ``settings`` must hold
        ``entity_type_extraction_length``, ``pad_token_id``, ``do_sample`` and
        ``temperature``. (Presumably a Hugging Face chat pipeline - confirm.)
    :param prompt: Template containing ``{domain}`` and ``{document}``.
    :param document: Text substituted for ``{document}``.
    :param domain_name: Text substituted for ``{domain}``.
    :return: The ``content`` of the last message in the first result's
        ``generated_text``.
    """
    engine, settings = LLM
    # {domain} is filled before {document}, as in the template contract.
    filled_prompt = prompt.replace('{domain}', domain_name).replace('{document}', document)
    conversation = [{"role": "user", "content": filled_prompt}]

    generation_options = {
        "max_new_tokens": settings["entity_type_extraction_length"],
        "pad_token_id": settings["pad_token_id"],
        "do_sample": settings["do_sample"],
        "temperature": settings["temperature"],
    }
    reply = engine(conversation, **generation_options)

    entities_types = reply[0]['generated_text'][-1]['content']
    # Release cached GPU memory held by the allocator after generation.
    torch.cuda.empty_cache()
    return entities_types

In [2]:
import torch
from tqdm import tqdm

def extract_elements_from_chunks(LLM, chunks, extraction_prompt, max_gleanings=3):
    """Extract entity/relationship records from every chunk, with follow-up "gleaning" rounds.

    For each chunk the ``{input_text}`` placeholder of ``extraction_prompt`` is
    filled in and sent to the model. Afterwards up to ``max_gleanings`` rounds
    ask the model for missed records; between rounds the model is asked whether
    more remain and the loop stops unless the (upper-cased) answer contains
    ``YES``.

    :param LLM: ``(engine, settings)`` pair. ``engine`` is called as
        ``engine(messages, max_new_tokens=..., pad_token_id=..., do_sample=...,
        temperature=...)`` and its reply is read as
        ``reply[-1]['generated_text'][-1]['content']``. ``settings`` must hold
        ``entity_extraction_length``, ``entity_condition_length``,
        ``pad_token_id``, ``do_sample`` and ``temperature``.
    :param chunks: Sized sequence of text chunks.
    :param extraction_prompt: Prompt template containing ``{input_text}``.
    :param max_gleanings: Maximum number of follow-up rounds per chunk.
    :return: One string per chunk, in order: the initial answer followed by
        each gleaned answer, newline-separated. A chunk whose initial call or
        parsing fails contributes ``""``.
    """
    LLM_engine, LLM_setting = LLM
    CONTINUE_PROMPT = (
        "MANY entities and relationships were missed in the last extraction. "
        "Add them below using the same format:\n"
    )
    LOOP_PROMPT = (
        "It appears some entities and relationships may have still been missed. "
        "Answer YES | NO if there are still entities or relationships that need to be added.\n"
    )

    def _generate(conversation, length_key):
        # Settings are looked up at call time so a missing key surfaces inside
        # the caller's try block.
        with torch.no_grad():  # inference only, no autograd bookkeeping
            return LLM_engine(
                conversation,
                max_new_tokens=LLM_setting[length_key],
                pad_token_id=LLM_setting["pad_token_id"],
                do_sample=LLM_setting["do_sample"],
                temperature=LLM_setting["temperature"],
            )

    def _reply_text(reply):
        # Text of the last message of the last returned conversation.
        return reply[-1]['generated_text'][-1]['content'].strip()

    elements = []
    total_chunks = len(chunks)
    pbar = tqdm(enumerate(chunks), total=total_chunks)

    for index, chunk in pbar:
        pbar.set_description(f"Processing chunk {index+1} of {total_chunks}")

        messages = [{"role": "user", "content": extraction_prompt.replace('{input_text}', chunk)}]

        # --- initial extraction -------------------------------------------
        try:
            response = _generate(messages, "entity_extraction_length")
        except Exception as e:
            print(f"Error during initial extraction for chunk {index+1}: {e}")
            elements.append("")
            continue  # keep the output aligned with the chunk list

        try:
            entities_and_relations = _reply_text(response)
        except (IndexError, KeyError) as e:
            print(f"Error parsing response for chunk {index+1}: {e}")
            elements.append("")
            continue

        results = entities_and_relations
        messages.append({"role": "assistant", "content": entities_and_relations})

        # --- gleaning rounds ----------------------------------------------
        for gleaning in range(max_gleanings):
            try:
                # Ask for records missed so far.
                messages.append({"role": "user", "content": CONTINUE_PROMPT})
                new_entities = _reply_text(_generate(messages, "entity_extraction_length"))
                results = results + '\n' + new_entities
                messages.append({"role": "assistant", "content": new_entities})

                # No point asking "anything left?" after the final round.
                if gleaning >= max_gleanings - 1:
                    break

                messages.append({"role": "user", "content": LOOP_PROMPT})
                loop_response = _reply_text(_generate(messages, "entity_condition_length")).upper()
                if "YES" not in loop_response:
                    break
                messages.append({"role": "assistant", "content": loop_response})
            except Exception as e:
                print(f"Error during gleaning {gleaning+1} for chunk {index+1}: {e}")
                break  # keep what was gathered so far
            finally:
                # Runs on every exit path of the round, including the breaks.
                torch.cuda.empty_cache()

        elements.append(results)

        # Drop the conversation before moving on to the next chunk.
        del messages
        torch.cuda.empty_cache()
    return elements

In [3]:
import html
import re
from typing import Any, Dict, Tuple, List, Optional
from collections.abc import Mapping
def clean_str(input: Any) -> str:
    """Strip, HTML-unescape and remove control characters from a string.

    Values that are not ``str`` are returned unchanged.
    """
    if isinstance(input, str):
        unescaped = html.unescape(input.strip())
        # Remove C0/C1 control characters (U+0000-U+001F and U+007F-U+009F).
        # https://stackoverflow.com/questions/4324790/removing-control-characters-from-a-string-in-python
        return re.sub(r"[\x00-\x1f\x7f-\x9f]", "", unescaped)
    return input


def _unpack_descriptions(data: Mapping) -> list[str]:
    value = data.get("description", None)
    return [] if value is None else value.split("\n")

def _unpack_source_ids(data: Mapping) -> list[str]:
    value = data.get("source_id", None)
    return [] if value is None else value.split(", ")

def _merge_unique(existing, addition, separator):
    """Append ``addition`` to the ``separator``-delimited string ``existing``.

    Blank items and duplicates are dropped; first-seen order is preserved so
    the result is identical from run to run.
    """
    items = existing.split(separator) if existing else []
    items.append(addition)
    return separator.join(dict.fromkeys(item for item in items if item))


def collect_elements_relationship(elements):
    """Build an undirected graph from the extraction output of every chunk.

    Each element is the raw model output for one chunk: records separated by
    ``##`` (or ``<##``), fields separated by ``<|>``. A record whose first
    field contains ``entity`` is read as ``(kind, name, type, description)``;
    one whose first field contains ``relationship`` as
    ``(kind, source, target, description, ..., weight)``. Records with fewer
    than four fields are ignored.

    Nodes and edges carry ``content`` (newline-joined descriptions) and
    ``source_id`` (``", "``-joined chunk indices); nodes also carry ``type``
    and edges ``weight``. Repeated mentions accumulate descriptions and chunk
    ids, and repeated relationships add up their weights.

    :param elements: List of extraction strings, one per chunk.
    :return: ``(graph, entity_names_by_chunk)`` where the second item maps the
        chunk index to the entity names found in that chunk, in order.
    """
    graph = nx.Graph()
    entitysss = {}
    for idx, group in enumerate(elements):
        # Kept immutable for the whole chunk: every record of this chunk must
        # be attributed to exactly this id.
        chunk_id = str(idx)
        # Specific setting for Qwen2.5, which emits "<##" as record separator.
        record_delimiter = '<##' if '<##' in group else '##'
        entitysss[idx] = []

        for raw_record in group.split(record_delimiter):
            # Drop the wrapping parentheses and normalise a spaced-out delimiter.
            record = re.sub(r"^\(|\)$", "", raw_record.strip())
            record = record.replace("< | >", "<|>")
            record_attributes = record.split('<|>')
            if len(record_attributes) < 4:
                continue
            record_kind = record_attributes[0]

            if "entity" in record_kind:
                entity_name = clean_str(record_attributes[1].upper())
                entity_type = clean_str(record_attributes[2].upper())
                entity_description = clean_str(record_attributes[3])
                entitysss[idx].append(entity_name)

                if entity_name in graph.nodes():
                    node = graph.nodes[entity_name]
                    # Descriptions live under "content"; merge with what is
                    # already there instead of replacing it.
                    node["content"] = _merge_unique(
                        node.get("content", ""), entity_description, "\n"
                    )
                    node["source_id"] = _merge_unique(
                        node.get("source_id", ""), chunk_id, ", "
                    )
                    # Keep the previous type when the new mention has none.
                    node["type"] = (
                        entity_type if entity_type != "" else node.get("type", "")
                    )
                else:
                    graph.add_node(
                        entity_name,
                        type=entity_type,
                        content=entity_description,
                        source_id=chunk_id,
                    )

            if "relationship" in record_kind:
                source_node = clean_str(record_attributes[1].upper())
                target_node = clean_str(record_attributes[2].upper())
                edge_description = clean_str(record_attributes[3])
                try:
                    weight = float(record_attributes[-1])
                except ValueError:
                    # Last field is not numeric (e.g. no weight was given).
                    weight = 1.0

                # Endpoints that were never declared as entities become
                # untyped placeholder nodes.
                for endpoint in (source_node, target_node):
                    if endpoint not in graph.nodes():
                        graph.add_node(
                            endpoint,
                            type="",
                            content="",
                            source_id=chunk_id,
                        )

                # Separate variable: the merged id list belongs to this edge
                # only and must not replace the chunk id of later records.
                edge_source_id = chunk_id
                if graph.has_edge(source_node, target_node):
                    edge_data = graph.get_edge_data(source_node, target_node)
                    if edge_data is not None:
                        weight += edge_data["weight"]
                        edge_description = _merge_unique(
                            edge_data.get("content", ""), edge_description, "\n"
                        )
                        edge_source_id = _merge_unique(
                            edge_data.get("source_id", ""), chunk_id, ", "
                        )
                graph.add_edge(
                    source_node,
                    target_node,
                    weight=weight,
                    content=edge_description,
                    source_id=edge_source_id,
                )
    return graph, entitysss

In [4]:
import re
import nltk
from nltk.corpus import wordnet as wn
from nltk.stem import WordNetLemmatizer
from nltk.metrics import edit_distance
from difflib import SequenceMatcher
from sklearn.feature_extraction.text import TfidfVectorizer
from sklearn.metrics.pairwise import cosine_similarity

import numpy as np

# Ensure NLTK data is downloaded (only needed once)
nltk.download('wordnet', quiet=True)
nltk.download('omw-1.4', quiet=True)
nltk.download('punkt', quiet=True)
# Recent NLTK releases (3.9+) load the tokenizer models used by
# word_tokenize/sent_tokenize from 'punkt_tab' rather than 'punkt'; fetch both
# so tokenization works on old and new installations.
nltk.download('punkt_tab', quiet=True)
nltk.download('averaged_perceptron_tagger', quiet=True)

# Shared WordNet lemmatizer used by the name-normalisation helpers below.
lemmatizer = WordNetLemmatizer()

def description_similarity(desc1, desc2):
    """
    Computes cosine similarity between two descriptions using TF-IDF vectorization.
    Returns a score between 0 and 1.

    Descriptions that are empty, or that contain nothing but English stop
    words, carry no comparable terms; the score is 0.0 in that case instead of
    an error.
    """
    # Additional normalization
    desc1 = desc1.lower().strip()
    desc2 = desc2.lower().strip()

    # Nothing to compare: avoid fitting a vectorizer on empty text.
    if not desc1 or not desc2:
        return 0.0

    vectorizer = TfidfVectorizer(stop_words='english')
    try:
        vectors = vectorizer.fit_transform([desc1, desc2])
    except ValueError:
        # Raised by scikit-learn for an empty vocabulary (e.g. both texts
        # consist only of stop words or punctuation).
        return 0.0
    cosine_sim = cosine_similarity(vectors[0:1], vectors[1:2])
    return float(cosine_sim[0][0])

def get_abbreviation(phrase):
    """
    Builds an acronym from the first character of every word in ``phrase``.

    Words are separated by whitespace, hyphens or slashes. The acronym is
    lower-cased, passed through the shared lemmatizer and returned upper-cased.
    """
    initials = [
        piece[0].upper()
        for piece in re.split(r'[\s\-/]+', phrase.strip())
        if piece  # re.split yields '' for an empty phrase
    ]
    acronym = ''.join(initials)
    return lemmatizer.lemmatize(acronym.lower()).upper()

def normalize_name(name):
    """
    Produces a canonical form of an entity name.

    Runs of non-word characters become single spaces, the text is lower-cased
    and trimmed, then tokenized with ``nltk.word_tokenize``; every token is
    lemmatized and the tokens are re-joined with single spaces.
    """
    cleaned = re.sub(r'\W+', ' ', name).lower().strip()
    return ' '.join(
        lemmatizer.lemmatize(token) for token in nltk.word_tokenize(cleaned)
    )

def are_synonyms(word1, word2):
    """
    Tells whether two words share at least one WordNet lemma.

    Returns False when either word has no synsets; otherwise True exactly when
    the lower-cased lemma names of the two words overlap.
    """
    first_synsets = wn.synsets(word1.lower())
    second_synsets = wn.synsets(word2.lower())
    if not (first_synsets and second_synsets):
        return False

    def _lemma_names(synsets):
        # All lemma names reachable from the given synsets, lower-cased.
        return {
            lemma.name().lower()
            for synset in synsets
            for lemma in synset.lemmas()
        }

    return bool(_lemma_names(first_synsets) & _lemma_names(second_synsets))

def string_similarity(name1, name2):
    """
    Scores how alike two strings are, from 0.0 to 1.0.

    Both inputs are lower-cased and trimmed. The result is the mean of
    - an edit-distance score: ``1 - edit_distance / length of the longer string``
    - ``difflib.SequenceMatcher(...).ratio()``
    If either string is empty after trimming the score is 0.0.
    """
    left = name1.lower().strip()
    right = name2.lower().strip()
    if not left or not right:
        return 0.0

    longest = max(len(left), len(right))
    edit_score = 1 - edit_distance(left, right) / longest
    matcher_score = SequenceMatcher(None, left, right).ratio()
    return (edit_score + matcher_score) / 2

class NodeSimilarity:
    """Rule-based decision on whether two graph nodes denote the same entity."""

    def __init__(self):
        pass  # Stateless: no model or resources to set up

    def advanced_similarity(self, node1, node2, threshold=0.85):
        """
        Decides whether two nodes should be treated as the same entity.

        Each node is ``[entity_name, entity_type, entity_description]``.

        The checks run in this order:
        1. one normalized name equals the other's acronym (optionally with a
           trailing ``s``) -> 1
        2. the normalized names are WordNet synonyms -> 1
        3. otherwise ``0.6 * name score + 0.4 * description score`` is
           compared with ``threshold``; the name score is scaled by 1.1
           (capped at 1) when the types match and by 0.9 when they differ.

        Returns 1 for a match and 0 otherwise.
        """
        name_a, type_a, description_a = node1
        name_b, type_b, description_b = node2

        normalized_a = normalize_name(name_a)
        normalized_b = normalize_name(name_b)

        acronym_a = get_abbreviation(normalized_a).lower().strip()
        acronym_b = get_abbreviation(normalized_b).lower().strip()

        # Acronym shortcut, tolerant of a plural "s".
        a_is_acronym_of_b = normalized_a in (acronym_b, acronym_b + 's')
        b_is_acronym_of_a = normalized_b in (acronym_a, acronym_a + 's')
        if a_is_acronym_of_b or b_is_acronym_of_a:
            return 1

        if are_synonyms(normalized_a, normalized_b):
            return 1

        name_score = string_similarity(normalized_a, normalized_b)
        if type_a == type_b:
            # Same type: trust the name a little more, never above 1.
            name_score = name_score * 1.1
            if name_score > 1:
                name_score = 1
        else:
            # Different types: trust the name a little less.
            name_score = name_score * 0.9

        description_score = description_similarity(description_a, description_b)

        blended = (name_score * 0.6) + (description_score * 0.4)
        return 1 if blended >= threshold else 0

In [5]:
import numpy as np
import networkx as nx
from tqdm import tqdm

def merge_node_attributes(graph, nodes_to_merge):
    """
    Merges attributes of a group of nodes into a single node attribute dictionary.

    - ``type``: the most common non-empty type among the nodes (ties go to the
      first one encountered); if every type is empty, the most common value.
    - ``content``: the non-empty contents joined with a single space, in the
      order of ``nodes_to_merge``.
    - ``source_id``: union of the nodes' source ids, split and re-joined on
      ``", "`` (the separator this notebook uses when it writes source ids),
      without duplicates and in first-seen order.

    :param graph: Graph whose ``graph.nodes[name]`` gives the attribute dict.
    :param nodes_to_merge: Node names to combine.
    :return: Attribute dict containing only the keys that could be determined.
    """
    from collections import Counter

    types = []
    contents = []
    source_ids = []
    for n in nodes_to_merge:
        attrs = graph.nodes[n]
        if 'type' in attrs:
            types.append(attrs['type'])
        if 'content' in attrs:
            contents.append(attrs['content'])
        if attrs.get('source_id'):
            source_ids.extend(str(attrs['source_id']).split(', '))

    # Majority vote; placeholder (empty) types must not outvote real ones.
    if types:
        candidates = [t for t in types if t] or types
        merged_type = Counter(candidates).most_common(1)[0][0]
    else:
        merged_type = None

    # Concatenate all content (customize as needed)
    merged_content = " ".join(str(c) for c in contents if c)

    # Keep provenance: which chunks the merged node was seen in.
    merged_source_id = ", ".join(dict.fromkeys(s for s in source_ids if s))

    merged_attr = {}
    if merged_type is not None:
        merged_attr['type'] = merged_type
    if merged_content:
        merged_attr['content'] = merged_content
    if merged_source_id:
        merged_attr['source_id'] = merged_source_id

    return merged_attr


def _union_delimited(first, second, separator):
    """Union of two ``separator``-delimited strings: no blanks, no duplicates, first-seen order."""
    pieces = []
    for value in (first, second):
        if value:
            pieces.extend(str(value).split(separator))
    return separator.join(dict.fromkeys(piece for piece in pieces if piece))


def clean_graph(graph, similarity_model, SIMILARITY_THRESHOLD=0.85):
    """Merge nodes that the similarity model judges to be the same entity.

    Every pair of typed nodes is scored with
    ``similarity_model.advanced_similarity``; pairs reaching the threshold are
    linked, and each connected group of linked nodes collapses into one node.
    The representative of a group is its member that comes first in
    ``graph.nodes()``, so the result is the same on every run.

    Edges are re-attached to the representatives. Edges inside a group are
    dropped; edges that end up between the same pair of representatives are
    combined: ``weight`` values are summed, ``content`` lines and ``source_id``
    entries are united.

    :param graph: Source graph; nodes may carry ``type`` and ``content``.
    :param similarity_model: Object with
        ``advanced_similarity(node_a, node_b, threshold=...)`` where a node is
        ``[name, type, content]``.
    :param SIMILARITY_THRESHOLD: Passed to the model and used as the cut-off
        for its result.
    :return: A new ``nx.Graph``; the input graph is not modified.
    """
    node_names = list(graph.nodes())
    num_nodes = len(node_names)
    node_position = {name: position for position, name in enumerate(node_names)}

    # Nodes that are pairwise similar become neighbours in this helper graph.
    similarity_graph = nx.Graph()
    similarity_graph.add_nodes_from(node_names)

    for i in tqdm(range(num_nodes)):
        name_i = node_names[i]
        attrs_i = graph.nodes[name_i]
        type_i = attrs_i.get('type', '')
        node_i = [name_i, type_i, attrs_i.get('content', '')]
        for j in range(i + 1, num_nodes):
            name_j = node_names[j]
            attrs_j = graph.nodes[name_j]
            type_j = attrs_j.get('type', '')
            # Untyped nodes are placeholders and are never merged.
            if type_j != '' and type_i != '':
                node_j = [name_j, type_j, attrs_j.get('content', '')]
                sim_score = similarity_model.advanced_similarity(
                    node_i, node_j, threshold=SIMILARITY_THRESHOLD
                )
            else:
                sim_score = 0
            if sim_score >= SIMILARITY_THRESHOLD:
                similarity_graph.add_edge(name_i, name_j)

    # Each connected component is one group of nodes to collapse.
    merged_graph = nx.Graph()
    old_to_new = {}
    for component in nx.connected_components(similarity_graph):
        # Sets have no stable order; fall back to the original node order.
        members = sorted(component, key=node_position.__getitem__)
        representative = members[0]
        if len(members) == 1:
            merged_graph.add_node(representative, **graph.nodes[representative])
        else:
            merged_graph.add_node(representative, **merge_node_attributes(graph, members))
        for member in members:
            old_to_new[member] = representative

    # Re-attach the original edges to the representatives.
    for u, v, data in graph.edges(data=True):
        new_u = old_to_new[u]
        new_v = old_to_new[v]
        if new_u == new_v:
            continue  # edge between two nodes that are now the same node
        if merged_graph.has_edge(new_u, new_v):
            # Several original edges map onto this pair: combine them rather
            # than keeping only the first one.
            existing = merged_graph[new_u][new_v]
            if 'weight' in data:
                existing['weight'] = existing.get('weight', 0) + data['weight']
            if 'content' in data or 'content' in existing:
                existing['content'] = _union_delimited(
                    existing.get('content'), data.get('content'), "\n"
                )
            if 'source_id' in data or 'source_id' in existing:
                existing['source_id'] = _union_delimited(
                    existing.get('source_id'), data.get('source_id'), ", "
                )
        else:
            merged_graph.add_edge(new_u, new_v, **data)

    return merged_graph

In [6]:
import networkx as nx
import matplotlib.colors as mcolors
import matplotlib.cm as cm
from pyvis.network import Network
from IPython.display import IFrame, display

def visualize(graph, filename='graph2.html'):
    """Render ``graph`` as an interactive pyvis network and show it inline.

    Nodes are coloured by their ``type`` attribute (``'default'`` when absent)
    and use their ``content`` attribute as hover title. All node and edge
    attributes are forwarded to pyvis.

    :param graph: networkx graph to draw.
    :param filename: HTML file written by pyvis and embedded via an IFrame.
    """
    import colorsys  # stdlib; avoids the colormap lookup API that differs across Matplotlib versions

    # Step 1: collect the distinct node types.
    unique_types = {data.get('type', 'default') for _, data in graph.nodes(data=True)}
    ordered_types = sorted(unique_types)

    # Step 2: one colour per type, hues spread evenly over [0, 1). The end
    # point is excluded because hue 0 and hue 1 are both red.
    type_color_map = {
        node_type: mcolors.rgb2hex(
            colorsys.hsv_to_rgb(position / len(ordered_types), 1.0, 1.0)
        )
        for position, node_type in enumerate(ordered_types)
    }

    # Determine if graph is directed
    directed = graph.is_directed()

    # Step 3: build the pyvis network.
    net = Network(height='750px', width='100%', notebook=True, directed=directed)

    for node, data in graph.nodes(data=True):
        node_type = data.get('type', 'default')
        # Start from the stored attributes and let the display settings win,
        # so an attribute named label/title/color cannot cause a duplicate
        # keyword error.
        node_options = dict(data)
        node_options.update(
            label=node,
            title=data.get('content', ''),
            color=type_color_map.get(node_type, 'gray'),
        )
        net.add_node(node, **node_options)

    for u, v, edata in graph.edges(data=True):
        net.add_edge(u, v, **edata)

    # Customize the appearance
    net.repulsion(node_distance=200, central_gravity=0.3)
    net.toggle_physics(True)

    # Generate and display the network
    net.show(filename)

    # Display in Jupyter Notebook (if in such an environment)
    display(IFrame(filename, width='100%', height='750px'))

In [7]:
import json
import os

def save_meta_info(meta_info, file_path='meta_info.json'):
    """
    Save the meta_info dictionary to a JSON file.

    Errors are reported on stdout and not raised. The data is serialized
    before the file is opened, so a serialization failure never leaves a
    truncated file behind.

    Args:
        meta_info (dict): The meta information to save.
        file_path (str): The path to the JSON file where data will be saved.
            A bare file name is written to the current working directory.
    """
    try:
        # Serialize first: if this fails nothing on disk is touched.
        payload = json.dumps(meta_info, ensure_ascii=False, indent=4)

        # A bare file name has no directory part, and os.makedirs('') raises,
        # so only create the directory when there is one.
        directory = os.path.dirname(file_path)
        if directory:
            os.makedirs(directory, exist_ok=True)

        with open(file_path, 'w', encoding='utf-8') as f:
            f.write(payload)

        print(f"meta_info successfully saved to {file_path}")

    except TypeError as te:
        print("Serialization Error: Ensure all data in meta_info is JSON-serializable.")
        print(te)

    except Exception as e:
        print("An unexpected error occurred while saving meta_info:")
        print(e)
        
    
def load_meta_info(file_path='meta_info.json'):
    """
    Read a meta_info dictionary back from a JSON file.

    Problems are reported on stdout instead of being raised.

    Args:
        file_path (str): Location of the JSON file.

    Returns:
        The decoded JSON content, or None when the file is missing, is not
        valid JSON, or any other error occurs.
    """
    try:
        with open(file_path, mode='r', encoding='utf-8') as handle:
            loaded = json.load(handle)
        print(f"meta_info successfully loaded from {file_path}")
        return loaded
    except FileNotFoundError:
        print(f"The file {file_path} does not exist.")
    except json.JSONDecodeError as decode_error:
        print("Error decoding JSON. Ensure the file is properly formatted.")
        print(decode_error)
    except Exception as unexpected:
        print("An unexpected error occurred while loading meta_info:")
        print(unexpected)
    # Every failure path ends here.
    return None

In [8]:
import os
from pathlib import Path

def load_and_process_documents(folder_path):
    """Read the documents under ``folder_path`` and chunk them with the default settings.

    :return: ``(documents, chunks)`` - the raw document texts and the chunks
        produced from them.
    """
    loaded_documents = load_document(folder_path)
    return loaded_documents, split_documents_into_chunks(loaded_documents)

def extract_meta_info(LLM, folder_path, domain_name='Source Code Vulnerability'):
    """Run the full indexing pipeline over the documents in ``folder_path``.

    Steps: load and chunk the documents, ask the model for entity types on a
    random sample of up to three chunks, extract entities/relationships from
    every chunk, build the graph and merge similar nodes.

    :param LLM: ``(engine, settings)`` pair forwarded to the extraction helpers.
    :param folder_path: Directory holding the source documents.
    :param domain_name: Domain label inserted into the type-extraction prompt.
    :return: ``(meta_info, graph, updated_graph)`` where ``meta_info`` holds
        ``entity_type`` (str) and ``elements`` (list), ``graph`` is the raw
        graph and ``updated_graph`` the graph after node merging.
    """
    import random  # not imported at notebook level; kept local to this function

    llm_sim_model = NodeSimilarity()
    extraction_prompt = load_prompt(EXTRACTION_PROMPT_PATH)
    processed_extraction_prompt = process_entity_extraction_prompt(extraction_prompt)
    type_extraction_prompt = load_prompt(TYPE_EXTRACTION_PROMPT_PATH)
    _documents, chunks = load_and_process_documents(folder_path)

    # Extract entity types
    # Need more advanced techniques !!
    # random.sample() raises when asked for more items than exist, so cap the
    # sample size for small corpora. The sample is unseeded and therefore
    # differs between runs.
    sample_size = min(3, len(chunks))
    entity_type = entity_type_extraction(
        LLM,
        type_extraction_prompt,
        document='\n'.join(random.sample(chunks, sample_size)),
        domain_name=domain_name
    )
    # Insert the entity types into the prompt that is actually sent to the
    # model (the delimiter-processed one).
    processed_extraction_prompt = processed_extraction_prompt.replace('{entity_types}', entity_type)

    print("Start elements extraction")
    # Extract elements from chunks
    elements = extract_elements_from_chunks(LLM, chunks, processed_extraction_prompt)
    # Collect relationships
    graph, _ = collect_elements_relationship(elements)

    # Clean the graph
    updated_graph = clean_graph(graph, llm_sim_model)

    # Compile meta information
    meta_info = {
        "entity_type": entity_type,       # string
        "elements": elements,             # list
    }
    return meta_info, graph, updated_graph

In [9]:
import logging
import os
import networkx as nx
from pathlib import Path

# Root logger: INFO and above, formatted as "<time> [<LEVEL>] <message>".
logging.basicConfig(
    format='%(asctime)s [%(levelname)s] %(message)s',
    level=logging.INFO,
)

def load_indexing(folder_path, LLM, CHECKPOINT_PATH):
    """Load the index from a checkpoint, or build it and write the checkpoint.

    The checkpoint consists of the JSON file at ``CHECKPOINT_PATH`` plus
    ``org_graph.graphml`` and ``update_graph.graphml`` in the same directory.

    :param folder_path: Directory with the source documents (used only when
        no checkpoint exists).
    :param LLM: ``(engine, settings)`` pair forwarded to ``extract_meta_info``.
    :param CHECKPOINT_PATH: Location of the JSON checkpoint, as ``str`` or
        ``pathlib.Path``.
    :return: ``(meta_info, org_graph, updated_graph)``.
    :raises FileNotFoundError: The JSON checkpoint exists but a graph file is missing.
    :raises ValueError: The JSON checkpoint exists but could not be read.
    """
    # Accept plain strings as well as Path objects.
    checkpoint_path = Path(CHECKPOINT_PATH)
    checkpoint_dir = checkpoint_path.parent
    checkpoint_dir.mkdir(parents=True, exist_ok=True)
    org_graph_path = checkpoint_dir / 'org_graph.graphml'
    updated_graph_path = checkpoint_dir / 'update_graph.graphml'

    try:
        if checkpoint_path.exists():
            # Load existing meta information and graphs
            logging.info(f"Checkpoint found at {checkpoint_path}, loading existing meta information.")
            meta_info = load_meta_info(file_path=str(checkpoint_path))
            # load_meta_info reports problems by returning None; do not hand a
            # broken checkpoint to the caller as if it had loaded.
            if meta_info is None:
                raise ValueError(
                    f"Checkpoint {checkpoint_path} exists but could not be read; "
                    "fix or delete it to rebuild the index."
                )

            if not org_graph_path.exists() or not updated_graph_path.exists():
                raise FileNotFoundError("One or both of the graph checkpoint files are missing.")

            org_graph = nx.read_graphml(str(org_graph_path))
            updated_graph = nx.read_graphml(str(updated_graph_path))
            logging.info("Successfully loaded meta information and graphs from checkpoint.")
        else:
            # Extract and save new meta information if checkpoint doesn't exist
            logging.info("No checkpoint found. Extracting meta information from the provided folder and LLM.")
            meta_info, org_graph, updated_graph = extract_meta_info(LLM, folder_path)

            # Save the extracted information
            save_meta_info(meta_info, file_path=str(checkpoint_path))
            nx.write_graphml(org_graph, str(org_graph_path))
            nx.write_graphml(updated_graph, str(updated_graph_path))
            logging.info("Extracted and saved meta information and graphs to checkpoint.")

        return meta_info, org_graph, updated_graph

    except Exception as e:
        # Log with traceback, then let the caller decide what to do.
        logging.error(f"An error occurred while loading or generating the meta information: {e}", exc_info=True)
        raise

In [10]:
import networkx as nx
import html


def find_main_and_subgraphs(G: nx.Graph) -> Tuple[nx.Graph, List[nx.Graph]]:
    """Split a graph into its largest connected component and the remaining ones.

    :param G: Graph to analyse.
    :return: ``(main_graph, subgraphs)``: a copy of the largest connected
        component, and copies of all other components ordered from largest to
        smallest. A graph without any component is returned as ``(G, [])``.
    """
    # All connected components, largest first.
    components_by_size = sorted(nx.connected_components(G), key=len, reverse=True)
    if len(components_by_size) == 0:
        return G, []

    # The biggest component is the main graph; everything else is detached from it.
    largest, *remaining = components_by_size
    main_graph = G.subgraph(largest).copy()

    subgraphs = []
    for component_nodes in remaining:
        subgraphs.append(G.subgraph(component_nodes).copy())

    return main_graph, subgraphs

def run_leiden(
    graph: nx.Graph, args: Dict[str, Any]
) -> Tuple[Dict[int, Dict[int, List[str]]], Dict[int, int]]:
    """Cluster ``graph`` with hierarchical Leiden and group node ids by community.

    ``args`` may carry ``max_cluster_size`` (default 10), ``use_lcc`` (default
    True), ``seed`` (default 0xDEADBEEF) and ``levels``. When ``levels`` is
    missing or None, every level produced by the clustering is reported.

    Returns ``(results_by_level, community_hierarchy_map)``: the first maps
    level -> community id -> list of node ids, the second is the hierarchy
    mapping produced by ``_compute_leiden_communities``.
    """
    node_id_to_community_map, community_hierarchy_map = _compute_leiden_communities(
        graph=graph,
        max_cluster_size=args.get("max_cluster_size", 10),
        use_lcc=args.get("use_lcc", True),
        seed=args.get("seed", 0xDEADBEEF),
    )

    requested_levels = args.get("levels")
    if requested_levels is None:
        # No explicit selection: report every level, lowest first.
        requested_levels = sorted(node_id_to_community_map.keys())

    results_by_level: Dict[int, Dict[int, List[str]]] = {}
    for level in requested_levels:
        members_by_community: Dict[int, List[str]] = {}
        results_by_level[level] = members_by_community
        for node_id, community_id in node_id_to_community_map[level].items():
            members_by_community.setdefault(community_id, []).append(node_id)
    return results_by_level, community_hierarchy_map


def _compute_leiden_communities(
    graph: nx.Graph,
    max_cluster_size: int,
    use_lcc: bool,
    seed=0xDEADBEEF,
) -> Tuple[Dict[int, Dict[str, int]], Dict[int, int]]:
    """Run graspologic's ``hierarchical_leiden`` and reshape its output.

    When ``use_lcc`` is true the clustering runs on the stabilised largest
    connected component instead of the full graph.

    Returns ``(results, hierarchy)``: ``results`` maps level -> node -> cluster
    id, and ``hierarchy`` maps each cluster id to its parent cluster id, with
    -1 recorded for clusters that report no parent.
    """
    from graspologic.partition import hierarchical_leiden

    target_graph = stable_largest_connected_component(graph) if use_lcc else graph

    assignments = hierarchical_leiden(
        target_graph, max_cluster_size=max_cluster_size, random_seed=seed
    )

    cluster_by_level: Dict[int, Dict[str, int]] = {}
    parent_of_cluster: Dict[int, int] = {}
    for entry in assignments:
        cluster_by_level.setdefault(entry.level, {})[entry.node] = entry.cluster
        parent = entry.parent_cluster
        parent_of_cluster[entry.cluster] = -1 if parent is None else parent
    return cluster_by_level, parent_of_cluster


def stable_largest_connected_component(graph: nx.Graph) -> nx.Graph:
    """Extract the largest connected component and rebuild it in a deterministic order.

    The component is taken from a copy of ``graph`` and then passed through
    ``_stabilize_graph``. Node-name normalisation (``normalize_node_names``)
    is not applied here.
    """
    from graspologic.utils import largest_connected_component

    component = largest_connected_component(graph.copy())
    return _stabilize_graph(component)


def normalize_node_names(graph: nx.Graph | nx.DiGraph) -> nx.Graph | nx.DiGraph:
    """Relabel every node with its upper-cased, stripped and HTML-unescaped name."""
    renamed = {}
    for name in graph.nodes():
        renamed[name] = html.unescape(name.upper().strip())  # type: ignore
    return nx.relabel_nodes(graph, renamed)


def _stabilize_graph(graph: nx.Graph) -> nx.Graph:
    """Rebuild ``graph`` with nodes and edges inserted in sorted order.

    Two graphs with the same relationships then iterate identically, which
    matters for consumers whose logic depends on node or edge order. Nodes are
    sorted by their identifier; edges are sorted by the text
    ``"<source> -> <target>"``.
    """
    directed = graph.is_directed()
    stable = nx.DiGraph() if directed else nx.Graph()

    stable.add_nodes_from(sorted(graph.nodes(data=True), key=lambda item: item[0]))

    edge_records = list(graph.edges(data=True))
    if not directed:
        # In an undirected graph A-B and B-A are the same edge, but which
        # endpoint is reported first can vary. Put the smaller endpoint first
        # so the edge always has the same orientation.
        edge_records = [
            (second, first, attrs) if first > second else (first, second, attrs)
            for first, second, attrs in edge_records
        ]

    edge_records.sort(key=lambda record: f"{record[0]} -> {record[1]}")

    stable.add_edges_from(edge_records)
    return stable

def subgraph_collecting(results_by_level, gtype = 'major'):
    """Flatten community results into a single ``{community id: node list}`` dict.

    With ``gtype == 'major'`` the input is a ``level -> community -> nodes``
    mapping and all levels are merged into one dict (a later level overwrites
    an earlier one on a key clash). For any other ``gtype`` the input is an
    iterable of graph objects, and each is stored under its position as the
    list of its nodes.
    """
    if gtype != 'major':
        return {
            position: list(component.nodes())
            for position, component in enumerate(results_by_level)
        }

    collected = {}
    for communities in results_by_level.values():
        collected.update(communities)
    return collected

In [11]:
def generate_node_description(nodes_list, G):
    """
    Generate descriptive strings for the nodes in `nodes_list` and the edges between them.

    Parameters
    ----------
    nodes_list : list
        Nodes to describe. Only edges whose two endpoints are both in this
        list are reported.
    G : networkx.Graph
        Graph containing the nodes and edges. Both node text and edge text are
        read from the 'content' attribute; a missing attribute is treated as an
        empty string.

    Returns
    -------
    node_descriptions : str
        One ``"<node>: <content>"`` entry per node, separated by blank lines.
    edge_descriptions : str
        One ``"<node><|><neighbor>: <content>"`` entry per edge that has
        non-empty content, separated by blank lines. Each undirected pair is
        reported once, in the orientation in which it is first encountered.
    """
    node_descriptions = "\n\n".join(
        f"{node}: {G.nodes[node].get('content', '')}"
        for node in nodes_list
    )

    # Set membership keeps the neighbour filter O(1) instead of scanning the
    # list for every neighbour of every node.
    members = set(nodes_list)
    seen_pairs = set()
    edge_lines = []

    for node in nodes_list:
        for neighbor in G.neighbors(node):
            if neighbor not in members:
                continue

            # Sorting the pair makes (a, b) and (b, a) the same key, so an
            # undirected edge is only emitted once.
            pair = tuple(sorted([node, neighbor]))
            if pair in seen_pairs:
                continue
            seen_pairs.add(pair)

            edge_description = G[node][neighbor].get('content', '')
            # Edges without any text are skipped entirely.
            if edge_description:
                edge_lines.append(f"{node}<|>{neighbor}: {edge_description}")

    edge_descriptions = "\n\n".join(edge_lines).strip()
    return node_descriptions.strip(), edge_descriptions

In [12]:
def agent_answer(LLM, cur_testcase, do_sample=True, temperature = 0.5, max_new_tokens=20, pad_token_id=None):
    """Send one user message to a chat-style generation callable and return the reply text.

    Parameters
    ----------
    LLM : callable
        Called as ``LLM(messages, **generation_kwargs)``. It must expose
        ``LLM.tokenizer.eos_token_id`` and return a list whose first item has a
        ``"generated_text"`` list of message dicts.
    cur_testcase : str
        Content of the single user message.
    do_sample, temperature, max_new_tokens
        Forwarded unchanged to ``LLM``.
    pad_token_id : int, optional
        Padding token id forwarded to ``LLM``. Only ``None`` falls back to the
        tokenizer's EOS id, so an explicit id of 0 is respected.

    Returns
    -------
    The ``'content'`` field of the last message in the generated conversation.
    """
    # Compare against None rather than truthiness: 0 is a valid token id.
    if pad_token_id is None:
        pad_token_id = LLM.tokenizer.eos_token_id

    messages = [
        {"role": "user", "content": cur_testcase},
    ]
    outputs = LLM(
        messages,
        max_new_tokens=max_new_tokens,
        do_sample=do_sample,
        temperature=temperature,
        pad_token_id=pad_token_id,
    )
    reply = outputs[0]["generated_text"][-1]
    # Hand cached GPU memory back to the allocator between generations.
    torch.cuda.empty_cache()
    return reply['content']

In [13]:
def get_leaf_nodes(community_hierarchy_map):
    """Return the set of communities that are not the parent of any other community.

    ``community_hierarchy_map`` maps each community id to its parent id, with
    -1 marking "no parent".
    """
    parents = {
        parent_id
        for parent_id in community_hierarchy_map.values()
        if parent_id != -1
    }
    return {community for community in community_hierarchy_map if community not in parents}

In [14]:
from prompt import graph_merging_finestage
def extract_val(text):
    """Pull the first number out of an LLM reply.

    Parameters
    ----------
    text : str
        Raw reply text. Anything that is not a string is treated as
        "no number found".

    Returns
    -------
    float or int
        ``float(text)`` when the whole text is digits; otherwise the first
        number found in the text (float if it contains a '.', int if not);
        0 when no number can be extracted.
    """
    if not isinstance(text, str):
        return 0

    if text.isdigit():
        try:
            return float(text)
        except ValueError:
            # str.isdigit() also accepts characters such as superscript digits
            # that float() rejects; fall through to the regex search.
            pass

    match = re.search(r'\d+\.?\d*', text)
    if match is None:
        return 0
    token = match.group(0)
    return float(token) if '.' in token else int(token)

    
def subgraph_matching(major_community, minor_community, prompt, LLM, G):
    """Ask the LLM to rate every (minor community, major community) pair.

    Parameters
    ----------
    major_community : dict
        Maps a major community key to its list of nodes.
    minor_community : dict
        Maps a minor community key to its list of nodes.
    prompt : str
        Template formatted with ``snodes_info``, ``sedges_info``,
        ``tnodes_info`` and ``tedges_info``.
    LLM : tuple
        ``(engine, settings)``; ``settings`` must provide ``"do_sample"``,
        ``"subgraph_matching_length"`` and ``"temperature"``.
    G : networkx.Graph
        Graph from which node and edge descriptions are generated.

    Returns
    -------
    dict
        ``{minor key: {major key: raw LLM reply}}``.
    """
    LLM_engine, LLM_setting = LLM

    # The major-community descriptions do not depend on the minor community,
    # so build them once instead of once per pair.
    major_descriptions = {
        tkey: generate_node_description(maj_graph, G)
        for tkey, maj_graph in major_community.items()
    }

    evaluation_form = {}
    for skey, min_graph in tqdm(minor_community.items()):
        evaluation_form[skey] = {}
        min_nodes_des, min_edges_des = generate_node_description(min_graph, G)
        for tkey, (maj_nodes_des, maj_edges_des) in major_descriptions.items():
            # The source edge slot must receive the edge description, not the
            # node description a second time.
            cur_prompt = prompt.format(snodes_info=min_nodes_des, sedges_info=min_edges_des,
                                       tnodes_info=maj_nodes_des, tedges_info=maj_edges_des)

            res = agent_answer(LLM_engine, cur_prompt,
                               do_sample=LLM_setting["do_sample"],
                               max_new_tokens=LLM_setting["subgraph_matching_length"],
                               temperature=LLM_setting["temperature"])

            evaluation_form[skey][tkey] = res
    return evaluation_form

In [15]:
from prompt import summarize_descriptions
def update_entity_description(updated_graph, prompt, LLM):
    """Replace each node's 'content' with an LLM-written summary of it.

    Parameters
    ----------
    updated_graph : networkx.Graph
        Source graph. The work is done on ``updated_graph.copy()``, which is
        the object returned.
    prompt : str
        Template formatted with ``description_list`` and ``entity_name``.
    LLM : tuple
        ``(engine, settings)``; ``settings`` must provide ``"do_sample"``,
        ``"update_entity_description"`` (max new tokens) and ``"temperature"``.

    Returns
    -------
    networkx.Graph
        The copied graph with summarised 'content' on every node that had
        non-empty content. Nodes with missing or empty content are left as is.
    """
    LLM_engine, LLM_setting = LLM
    graph = updated_graph.copy()
    for node in tqdm(graph.nodes()):
        node_info = graph.nodes[node]
        # .get() so that a node without a 'content' attribute is skipped
        # instead of raising KeyError, matching how content is read elsewhere.
        cur_content = node_info.get("content")
        if not cur_content:
            continue
        cur_prompt = prompt.format(description_list=cur_content, entity_name=node)
        updated_description = agent_answer(
            LLM_engine, cur_prompt,
            do_sample=LLM_setting["do_sample"],
            max_new_tokens=LLM_setting["update_entity_description"],
            temperature=LLM_setting["temperature"],
        )
        node_info["content"] = updated_description
    return graph

In [16]:
import json
import os
import networkx as nx
from transformers import pipeline
from tqdm import tqdm
import torch

def merge_discrete_nodes_into_major_graph(updated_graph, LLM):
    """
    Link the disconnected components of `updated_graph` to the main component.

    If both `checkpoint/merge_results.json` and `checkpoint/concat_graph.graphml`
    already exist they are loaded and returned without any recomputation.
    Otherwise the merge is computed, saved to those two files and returned.

    The computation summarises node descriptions with the LLM, clusters the
    graph with Leiden, asks the LLM which leaf community each minor component
    matches best, and then adds an edge between a minor node and a matched
    major node whenever the NLI classifier reports entailment in either
    direction.

    Parameters:
    - updated_graph (networkx.Graph): graph whose components are to be merged.
    - LLM: passed through to `update_entity_description` and
      `subgraph_matching`, which unpack it as an (engine, settings) pair.

    Returns:
    - A dictionary containing:
        "concat_graph": the merged graph
        "non_relevant_keys": keys of minor communities that received no edge
        "G2G_finematch": best-scoring major communities for each minor community
        "major_community": major community structure
        "minor_community": minor community structure
    """
    json_path = 'checkpoint/merge_results.json'
    graph_path = 'checkpoint/concat_graph.graphml'

    # Reuse a previous run when both checkpoint files are present.
    # NOTE(review): JSON object keys are always strings, so integer community
    # ids come back as str from this cache while a fresh run returns ints;
    # confirm that downstream consumers tolerate both.
    if os.path.exists(json_path) and os.path.exists(graph_path):
        with open(json_path, 'r') as f:
            loaded_data = json.load(f)
        concat_graph = nx.read_graphml(graph_path)
        return {
            "concat_graph": concat_graph,
            "non_relevant_keys": loaded_data.get("non_relevant_keys", []),
            "G2G_finematch": loaded_data.get("G2G_finematch", {}),
            "major_community": loaded_data.get("major_community", {}),
            "minor_community": loaded_data.get("minor_community", {})
        }

    # Summarise node descriptions before any matching is done.
    print("Start Updating the Description")
    updated_graph = update_entity_description(updated_graph, prompt=summarize_descriptions.SUMMARIZE_PROMPT, LLM=LLM)
    NLI_pipe = pipeline("text-classification", model="sileod/deberta-v3-base-tasksource-nli")
    _, minor_subgraphs = find_main_and_subgraphs(updated_graph)

    # Cluster size is capped at the size of the largest minor component. A
    # fully connected graph has no minor components, so fall back to
    # run_leiden's own default of 10 instead of crashing on an empty max().
    max_group = max((len(min_graph.nodes()) for min_graph in minor_subgraphs), default=10)
    seg_args = {
        "max_cluster_size": max_group,
        "use_lcc": True,
        "seed": 0xDEADBEEF
    }

    results_by_level, community_hierarchy_map = run_leiden(updated_graph, seg_args)
    leaf_nodes = get_leaf_nodes(community_hierarchy_map)
    major_community = subgraph_collecting(results_by_level, gtype='major')
    major_community_leaf = {key: major_community[key] for key in leaf_nodes}
    minor_community = subgraph_collecting(minor_subgraphs, gtype='minor')
    print("Start Graph Corase-Stage Merging")
    merge_prompt = graph_merging_finestage.Graph_Merging_PROMPT_FINE
    graph_merging_res = subgraph_matching(
        major_community_leaf,
        minor_community,
        merge_prompt,
        LLM,
        updated_graph
    )

    # For every minor community keep all major communities tied for the top score.
    G2G_finematch = {}
    for mkey, score_list in graph_merging_res.items():
        filter_score = {key: extract_val(score) for key, score in score_list.items()}
        if not filter_score:
            # No major leaf community was available to score against.
            G2G_finematch[mkey] = []
            continue
        max_value = max(filter_score.values())
        G2G_finematch[mkey] = [key for key, value in filter_score.items() if value == max_value]

    non_relevant_keys = []
    concat_graph = updated_graph.copy()
    print("Start Graph Fine-Stage Merging")
    for min_key, min_nodes in tqdm(minor_community.items(), desc="Merging Nodes"):
        related_major_group = G2G_finematch.get(min_key, [])
        maj_nodes = [leaf for key in related_major_group for leaf in major_community_leaf.get(key, [])]
        linked = False

        for snode in min_nodes:
            # The source description is the same for every target; read it once.
            snode_info = updated_graph.nodes[snode].get('content')
            if not snode_info:
                continue
            for tnode in maj_nodes:
                tnode_info = updated_graph.nodes[tnode].get('content')
                if not tnode_info:
                    continue

                # Entailment is checked in both directions.
                label_forward = NLI_pipe([{"text": snode_info, "text_pair": tnode_info}])[0]
                label_backward = NLI_pipe([{"text": tnode_info, "text_pair": snode_info}])[0]
                flabel = label_forward.get('label', '')
                blabel = label_backward.get('label', '')

                if 'entailment' in flabel.lower() or 'entailment' in blabel.lower():
                    linked = True
                    concat_graph.add_edge(
                        snode,
                        tnode,
                        weight=1,
                        content="High relevant edges based on hypothesis",
                        source_id=""
                    )
        if not linked:
            non_relevant_keys.append(min_key)

    # Make sure the checkpoint directory exists before writing into it.
    os.makedirs(os.path.dirname(graph_path), exist_ok=True)
    nx.write_graphml(concat_graph, graph_path)

    # Everything except the graph itself is JSON-serialisable.
    serializable_dict = {
        "non_relevant_keys": non_relevant_keys,
        "G2G_finematch": G2G_finematch,
        "major_community": major_community,
        "minor_community": minor_community
    }

    with open(json_path, 'w') as f:
        json.dump(serializable_dict, f, indent=4)

    result_dict = {"concat_graph": concat_graph}
    result_dict.update(serializable_dict)
    return result_dict

  from .autonotebook import tqdm as notebook_tqdm


In [17]:
import pandas as pd
def obtain_node_edge_df(G):
    """Flatten a graph into a node table and an edge table.

    Nodes receive consecutive integer ids in the order ``G.nodes()`` yields
    them; edges receive consecutive integer ids in the order
    ``G.edges(data=True)`` yields them. Attributes that are absent on a node
    or edge are stored as None.

    Returns
    -------
    (nodes_df, edges_df)
        ``nodes_df`` has columns node_id, node_name, type, content.
        ``edges_df`` has columns edge_id, source_node_id, target_node_id,
        source_node, target_node, weight, content, source_id.
    """
    ordered_nodes = list(G.nodes())
    node_to_id = {name: position for position, name in enumerate(ordered_nodes)}

    def _node_row(name):
        attrs = G.nodes[name]
        return {
            "node_id": node_to_id[name],
            "node_name": name,
            "type": attrs.get("type", None),
            "content": attrs.get("content", None),
        }

    nodes_df = pd.DataFrame([_node_row(name) for name in ordered_nodes])

    edge_rows = [
        {
            "edge_id": edge_id,
            "source_node_id": node_to_id[u],
            "target_node_id": node_to_id[v],
            "source_node": u,
            "target_node": v,
            "weight": attrs.get("weight", None),
            "content": attrs.get("content", None),
            "source_id": attrs.get("source_id", None),
        }
        for edge_id, (u, v, attrs) in enumerate(G.edges(data=True))
    ]
    edges_df = pd.DataFrame(edge_rows)
    return nodes_df, edges_df

In [18]:
def obtain_node_info(node_name, df):
    """Return the first row of `df` whose 'node_name' equals `node_name`, as a dict.

    When the row's 'content' is a string, the boilerplate lead-in
    "Here is a comprehensive summary of the data:" (plus the blank line after
    it) is removed. Missing content (None/NaN) is returned unchanged instead
    of raising. An unknown `node_name` raises IndexError from ``.iloc[0]``.
    """
    info = df.loc[df["node_name"] == node_name].iloc[0].to_dict()
    content = info['content']
    # Nodes without a description carry None/NaN here, which has no .replace().
    if isinstance(content, str):
        info['content'] = content.replace("Here is a comprehensive summary of the data:\n\n", "")
    return info
    
def obtain_edge_info(source: any, target: any, df: pd.DataFrame):
    """Look up the edge between `source` and `target` regardless of direction.

    Returns the first matching row of `df` as a dict, or None when no row
    links the two nodes. Raises KeyError when `df` lacks the 'source_node' or
    'target_node' column.
    """
    required_columns = {'source_node', 'target_node'}
    if not required_columns.issubset(df.columns):
        raise KeyError(f"DataFrame must include: {required_columns}")

    forward = (df['source_node'] == source) & (df['target_node'] == target)
    if source == target:
        # Self-loop: there is no distinct reverse orientation to consider.
        mask = forward
    else:
        backward = (df['source_node'] == target) & (df['target_node'] == source)
        mask = forward | backward

    matches = df[mask]
    if matches.empty:
        return None
    return matches.to_dict(orient='records')[0]

def estimiate_token_length(sequence_info, LLM):
    """Count the tokens in the comma-joined string form of `sequence_info`.

    Each item is converted with ``str``; the joined text is encoded with
    ``LLM.tokenizer`` without special tokens, and the number of resulting
    token ids is returned.
    """
    joined = ','.join(map(str, sequence_info))
    token_ids = LLM.tokenizer.encode(joined, add_special_tokens=False)
    return len(token_ids)

In [19]:
import pandas as pd
from tqdm import tqdm

def _append_if_within_budget(record, key_fields, bucket, used_tokens, length_limit, LLM_engine):
    """Append `record` to `bucket` when its estimated token cost still fits the budget.

    Returns the (possibly increased) token count and whether the record was added.
    """
    cost = estimiate_token_length([record[field] for field in key_fields], LLM_engine)
    if used_tokens + cost > length_limit:
        return used_tokens, False
    bucket.append(record)
    return used_tokens + cost, True


def generate_leaf_report(G, community_dict, leaf_nodes, length_limit, community_report_dict, prompt, LLM, nodes_df, edges_df):
    """
    Generate a report for each leaf community. This function:
    1. Extracts the subgraph of each leaf community.
    2. Ranks its edges by the summed degree (in `G`) of their endpoints.
    3. Adds node and edge records in that order while the token budget allows,
       then adds community nodes that have no intra-community edge.
    4. Sends the compiled tables to the LLM to generate the community report.

    Parameters
    ----------
    G : networkx.Graph
        The main graph.
    community_dict : dict
        Maps a community identifier to the list of nodes in that community.
    leaf_nodes : list
        Identifiers of the leaf communities to report on.
    length_limit : int
        Maximum estimated token count for the node and edge context.
    community_report_dict : dict
        Dictionary that receives the resulting reports (updated in place).
    prompt : str
        Template formatted with ``input_text``.
    LLM : tuple
        ``(engine, settings)``; ``settings`` must provide ``"do_sample"``,
        ``"community_sum_length"`` and ``"temperature"``.
    nodes_df : pd.DataFrame
        Node table as consumed by `obtain_node_info`.
    edges_df : pd.DataFrame
        Edge table as consumed by `obtain_edge_info`.

    Returns
    -------
    dict
        `community_report_dict` with one entry per leaf community, holding
        'node_info', 'edge_info', 'len_info', 'community_report' and
        'community_report_len'.
    """
    LLM_engine, LLM_setting = LLM
    node_fields = ["node_id", "node_name", "content"]
    edge_fields = ["edge_id", "source_node", "target_node", "content"]

    for community in tqdm(leaf_nodes):
        cur_nodes = community_dict[community]
        subG = G.subgraph(cur_nodes)

        # Edge priority is the sum of endpoint degrees, measured on the full graph.
        priority_list = {
            (source, target): G.degree(source) + G.degree(target)
            for source, target in subG.edges
        }
        sorted_priority_list = sorted(priority_list.items(), key=lambda x: x[1], reverse=True)

        context_size = 0
        cur_node_info = []
        cur_edge_info = []
        existing_node_ids = set()

        # Each edge appears once in the priority list, so no edge de-duplication
        # is needed; nodes are shared between edges and tracked explicitly.
        for (source, target), _priority in sorted_priority_list:
            for endpoint in (source, target):
                if endpoint in existing_node_ids:
                    continue
                context_size, added = _append_if_within_budget(
                    obtain_node_info(endpoint, nodes_df), node_fields,
                    cur_node_info, context_size, length_limit, LLM_engine,
                )
                if added:
                    existing_node_ids.add(endpoint)

            edge_info = obtain_edge_info(source, target, edges_df)
            # An edge missing from edges_df is skipped rather than dereferenced.
            if edge_info is not None:
                context_size, _added = _append_if_within_budget(
                    edge_info, edge_fields,
                    cur_edge_info, context_size, length_limit, LLM_engine,
                )

        # Nodes without any intra-community edge (e.g. a single-node community)
        # are never reached through the edge list; include them if they fit.
        for node in subG.nodes:
            if node in existing_node_ids or subG.degree(node) != 0:
                continue
            context_size, added = _append_if_within_budget(
                obtain_node_info(node, nodes_df), node_fields,
                cur_node_info, context_size, length_limit, LLM_engine,
            )
            if added:
                existing_node_ids.add(node)

        # Passing `columns=` keeps the expected header even when nothing was
        # collected, where selecting columns from an empty frame would raise.
        tmp_node_df = pd.DataFrame(cur_node_info, columns=node_fields)
        tmp_node_df = tmp_node_df.rename(columns={'node_id': 'id', 'node_name': 'entity', 'content': 'description'})
        node_content = tmp_node_df.to_csv(index=False, sep=',', encoding='utf-8').replace('"', '')

        tmp_edge_df = pd.DataFrame(cur_edge_info, columns=edge_fields)
        tmp_edge_df = tmp_edge_df.rename(columns={"edge_id": "id", "source_node": "source",
                                                  "target_node": "target", "content": "description"})
        edge_content = tmp_edge_df.to_csv(index=False, sep=',', encoding='utf-8').replace('"', '')

        input_data = f"Entities\n{node_content}Relationships\n{edge_content}"

        report_text = agent_answer(LLM_engine, prompt.format(input_text=input_data),
                                   do_sample=LLM_setting["do_sample"],
                                   max_new_tokens=LLM_setting["community_sum_length"],
                                   temperature=LLM_setting["temperature"])

        # NOTE(review): 'len_info' is a character count while `length_limit`
        # and 'community_report_len' are token counts - confirm which unit
        # consumers of 'len_info' expect.
        community_report_dict[community] = {
            'node_info': node_content,
            'edge_info': edge_content,
            'len_info': len(node_content) + len(edge_content),
            'community_report': report_text,
            'community_report_len': estimiate_token_length([report_text], LLM_engine)
        }

    return community_report_dict


def generate_parent_report(G, community_dict, children_map, length_limit, community_report_dict, prompt, LLM):
    """
    Generate a report for every parent community from its children's reports.

    For each parent, the node and edge tables of its children are merged and
    sent to the LLM. If the children's combined 'len_info' exceeds
    `length_limit`, children are swapped - largest first - for their
    already-written report text until the total fits or every child has been
    swapped.

    Parameters
    ----------
    G : networkx.Graph
        The main graph (not used by this function).
    community_dict : dict
        Community id -> node list (not used by this function).
    children_map : dict
        Parent community id -> list of child community ids. Parents are
        processed in reverse key order.
    length_limit : int
        Maximum allowed length of the aggregated child content.
    community_report_dict : dict
        Existing reports; children without an entry are ignored. Updated in
        place with the parent reports.
    prompt : str
        Template formatted with ``input_text``.
    LLM : tuple
        ``(engine, settings)``; ``settings`` must provide ``"do_sample"``,
        ``"community_sum_length"`` and ``"temperature"``.

    Returns
    -------
    dict
        `community_report_dict` including the parent community reports.
    """
    LLM_engine, LLM_setting = LLM

    def merge_csv_contents(csv_list):
        """Concatenate CSV strings, keeping the header line of the first one only."""
        merged_lines = []
        for position, csv_text in enumerate(csv_list):
            rows = csv_text.strip().split('\n')
            # Every CSV after the first drops its leading header row.
            merged_lines.extend(rows if position == 0 else rows[1:])
        return "\n".join(merged_lines)

    # NOTE(review): reversing the key order is presumably meant to visit deeper
    # parents before their ancestors; that only holds if children_map was
    # filled top-down - confirm against the caller.
    for community in tqdm(list(children_map)[::-1]):
        child_communities = children_map[community]
        if not child_communities:
            continue

        # Only children that already have a report can contribute.
        parent_info = {
            child: community_report_dict[child]
            for child in child_communities
            if child in community_report_dict
        }
        raw_sizes = {child: report['len_info'] for child, report in parent_info.items()}
        cur_length = sum(raw_sizes.values())

        # Children whose raw tables are replaced by their report text.
        # NOTE(review): 'len_info' is a character count whereas
        # 'community_report_len' is a token count, so this running total mixes
        # units - confirm the intended unit of `length_limit`.
        summarized = []
        if cur_length > length_limit:
            largest_first = sorted(raw_sizes.items(), key=lambda item: item[1], reverse=True)
            for child, raw_size in largest_first:
                if cur_length <= length_limit:
                    break
                summarized.append(child)
                # Swap this child's raw size for the size of its report.
                cur_length += parent_info[child]['community_report_len'] - raw_size

        kept = [child for child in raw_sizes if child not in summarized]

        pre_report_list = [parent_info[child]['community_report'] for child in summarized]
        content_node_final = merge_csv_contents([parent_info[child]['node_info'] for child in kept])
        content_edge_final = merge_csv_contents([parent_info[child]['edge_info'] for child in kept])

        # Summaries, when present, are placed ahead of the raw tables.
        input_data = f"Entities\n{content_node_final}\nRelationships\n{content_edge_final}"
        if pre_report_list:
            community_content = "\n".join(pre_report_list)
            input_data = f"Relevant Community\n{community_content}\n" + input_data

        report_text = agent_answer(LLM_engine, prompt.format(input_text=input_data),
                                   do_sample=LLM_setting["do_sample"],
                                   max_new_tokens=LLM_setting["community_sum_length"],
                                   temperature=LLM_setting["temperature"])

        entry = {
            'node_info': content_node_final,
            'edge_info': content_edge_final,
            'len_info': len(content_node_final) + len(content_edge_final),
        }
        if pre_report_list:
            entry['contained_community'] = community_content
        entry['community_report'] = report_text
        entry['community_report_len'] = estimiate_token_length([report_text], LLM_engine)
        community_report_dict[community] = entry

    return community_report_dict

In [20]:
from collections import defaultdict
import logging
from prompt import community_report

# Configure logging: request INFO-level output on the root logger and create
# the logger used by the report-generation functions below.
# NOTE(review): logging.basicConfig does nothing when the root logger already
# has handlers - confirm this call takes effect in the notebook's run order.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

def generate_community_report(
    concat_info: Dict[str, Any],
    results_by_level: Dict[Any, Dict[Any, Any]],
    community_hierarchy_map: Dict[Any, Any],
    nodes_df: pd.DataFrame,
    edges_df: pd.DataFrame,
    LLM: Any,
    LLM_args: Dict[str, Any],
    length_limit: int
) -> Dict[Any, Any]:
    """
    Build reports for every community in the hierarchy, leaves first.

    Args:
        concat_info (Dict[str, Any]): Must contain the merged graph under 'concat_graph'.
        results_by_level (Dict[Any, Dict[Any, Any]]): level -> community id -> node list.
        community_hierarchy_map (Dict[Any, Any]): child community id -> parent id (-1 for roots).
        nodes_df (pd.DataFrame): Node table passed to the leaf report generator.
        edges_df (pd.DataFrame): Edge table passed to the leaf report generator.
        LLM (Any): Language model engine.
        LLM_args (Dict[str, Any]): Generation settings, paired with the engine as (LLM, LLM_args).
        length_limit (int): Length budget forwarded to both report generators.

    Returns:
        Dict[Any, Any]: Reports keyed by community id.

    Raises:
        ValueError: If 'concat_graph' is missing from `concat_info`.
        Any exception from the report generators is logged and re-raised.
    """
    try:
        # Flatten all levels into one community id -> node list mapping.
        community_dict = {}
        for level_communities in results_by_level.values():
            community_dict.update(level_communities)
        logger.debug(f"Constructed community_dict with {len(community_dict)} communities.")

        # Invert the child -> parent mapping; parent == -1 marks a root.
        children_map = defaultdict(list)
        root_nodes = []
        for child, parent in community_hierarchy_map.items():
            if parent == -1:
                root_nodes.append(child)
                continue
            children_map[parent].append(child)
        logger.debug(f"Identified {len(root_nodes)} root nodes.")

        # A leaf is any community with no recorded children.
        leaf_nodes = [
            node for node in community_hierarchy_map
            if not children_map.get(node)
        ]
        logger.info(f"Found {len(leaf_nodes)} leaf nodes.")

        G = concat_info.get('concat_graph')
        if G is None:
            logger.error("concat_graph not found in concat_info.")
            raise ValueError("concat_graph is required in concat_info.")

        llm_bundle = (LLM, LLM_args)
        report_prompt = community_report.COMMUNITY_REPORT_PROMPT

        logger.info("Generating leaf communities report.")
        community_report_dict = generate_leaf_report(
            G=G,
            community_dict=community_dict,
            leaf_nodes=leaf_nodes,
            length_limit=length_limit,
            community_report_dict={},
            prompt=report_prompt,
            LLM=llm_bundle,
            nodes_df=nodes_df,
            edges_df=edges_df
        )

        logger.info("Generating parent communities report.")
        community_report_dict = generate_parent_report(
            G=G,
            community_dict=community_dict,
            children_map=children_map,
            length_limit=length_limit,
            community_report_dict=community_report_dict,
            prompt=report_prompt,
            LLM=llm_bundle
        )

        logger.info("Community report generation completed successfully.")
        return community_report_dict

    except Exception as e:
        logger.exception("An error occurred while generating the community report.")
        raise e

In [21]:
def load_community(
    path: str = 'checkpoint',
    concat_info: Optional[Dict[str, Any]] = None,
    nodes_df: Any = None,
    edges_df: Any = None,
    llm_config: Optional[Dict[str, Any]] = None,
    seg_args: Optional[Dict[str, Any]] = None,
    meta_info_filename: str = 'community_meta_info.json',
    length_limit: int = 1000
) -> Dict[str, Any]:
    """
    Load community meta information from a file if it exists; otherwise, perform community
    segmentation, generate reports, and save the meta information.

    Args:
        path (str, optional): Directory holding the meta info file. Defaults to 'checkpoint'.
        concat_info (Optional[Dict[str, Any]]): Must contain the key 'concat_graph' when no
            cached meta info file exists. Not read when the cache file is present.
        nodes_df (Any, optional): Node table forwarded to generate_community_report.
        edges_df (Any, optional): Edge table forwarded to generate_community_report.
        llm_config (Optional[Dict[str, Any]], optional): LLM configuration; the keys
            'engine' and 'settings' are read. May be None when the cache file is present.
        seg_args (Optional[Dict[str, Any]], optional): Arguments forwarded to run_leiden.
        meta_info_filename (str, optional): Filename of the meta info file.
            Defaults to 'community_meta_info.json'.
        length_limit (int, optional): Length limit forwarded to generate_community_report.
            Defaults to 1000.

    Returns:
        Dict[str, Any]: The loaded meta information, or a freshly built dictionary with the
        keys 'results_by_level', 'community_hierarchy_map' and 'community_report_dict'.

    Raises:
        KeyError: If segmentation is required and 'concat_graph' is not available in concat_info.
        Exception: Any error raised while loading or generating is logged and re-raised.
    """
    # The cached path never touches the LLM, so a missing llm_config must not crash here.
    llm_source = llm_config or {}
    LLM = llm_source.get("engine")
    LLM_args = llm_source.get("settings", {})

    try:
        meta_info_file = os.path.join(path, meta_info_filename)
        logger.debug(f"Constructed meta_info_file path: {meta_info_file}")

        if os.path.exists(meta_info_file):
            logger.info(f"Loading community meta info from {meta_info_file}.")
            community_meta_info = load_meta_info(file_path=meta_info_file)
            logger.info("Community meta info loaded successfully.")
        else:
            logger.info("Meta info file not found. Running community segmentation using Leiden algorithm.")

            # Validate presence of 'concat_graph' in concat_info (also covers concat_info=None)
            if not concat_info or 'concat_graph' not in concat_info:
                logger.error("'concat_graph' key is missing from concat_info.")
                raise KeyError("'concat_graph' must be present in concat_info.")

            if LLM is None:
                logger.warning("No LLM engine found in llm_config; community report generation may fail.")

            # Run Leiden algorithm to get community results
            results_by_level, community_hierarchy_map = run_leiden(
                concat_info['concat_graph'],
                seg_args
            )
            logger.info("Community segmentation completed successfully.")

            # Generate community reports
            logger.info("Generating community reports.")
            community_report_dict = generate_community_report(
                concat_info=concat_info,
                results_by_level=results_by_level,
                community_hierarchy_map=community_hierarchy_map,
                nodes_df=nodes_df,
                edges_df=edges_df,
                LLM=LLM,
                LLM_args=LLM_args,
                length_limit=length_limit
            )
            logger.info("Community reports generated successfully.")

            # Compile meta information
            community_meta_info = {
                "results_by_level": results_by_level,
                "community_hierarchy_map": community_hierarchy_map,
                "community_report_dict": community_report_dict,
            }
            logger.debug(f"Compiled community_meta_info: {community_meta_info.keys()}")

            # Ensure the directory exists before writing the cache file
            os.makedirs(path, exist_ok=True)
            logger.debug(f"Ensured the directory '{path}' exists.")

            logger.info(f"Saving community meta info to {meta_info_file}.")
            save_meta_info(community_meta_info, file_path=meta_info_file)
            logger.info("Community meta info saved successfully.")

        return community_meta_info

    except FileNotFoundError as fnf_error:
        logger.exception(f"File not found error: {fnf_error}")
        raise
    except KeyError as key_error:
        logger.exception(f"Key error: {key_error}")
        raise
    except Exception:
        logger.exception("An unexpected error occurred while loading or generating community meta information.")
        raise

In [22]:
import logging
from typing import Tuple, Dict, Any
from prompt import missing_prediction

# Configure logging at INFO level.
# NOTE: logging.basicConfig does nothing when the root logger already has handlers,
# so if logging was configured earlier in the session that configuration stays in effect.
logging.basicConfig(level=logging.INFO)
# Module-level logger used by the helper functions defined in this notebook.
logger = logging.getLogger(__name__)

def flat_dict(target_dict: Dict[str, Any]) -> str:
    """
    Render a dictionary as text, one "key: value" entry per line.

    Args:
        target_dict (Dict[str, Any]): Mapping to render; iteration order is preserved.

    Returns:
        str: The entries joined by newlines (an empty string for an empty mapping).
    """
    rendered_lines = []
    for name, entry in target_dict.items():
        rendered_lines.append(f"{name}: {entry}")
    return "\n".join(rendered_lines)


def obtain_neighbor_info(
    graph: nx.Graph, 
    node: str, 
    max_neighbor: Optional[int] = 8
) -> Tuple[str, str]:
    """
    Collect the attributes of a node's neighbors and the content of the connecting edges.

    Args:
        graph: The graph data structure.
        node (str): The node identifier whose neighborhood is inspected.
        max_neighbor (Optional[int]): When not None, only this many neighbors are kept,
            taken in descending order of degree. None keeps every neighbor.

    Returns:
        Tuple[str, str]: Flattened neighbor-node information and flattened edge information.
        Both are empty strings when the neighbors cannot be retrieved.
    """
    try:
        adjacent = list(graph.neighbors(node))
    except Exception as e:
        logger.error(f"Error retrieving neighbors for node '{node}': {e}")
        return "", ""

    if max_neighbor is not None:
        # Highest-degree neighbors first, then truncate to the requested amount
        by_degree = sorted(adjacent, key=lambda n: graph.degree(n), reverse=True)
        adjacent = by_degree[:max_neighbor]

    node_details = {}
    edge_details = {}
    for other in adjacent:
        node_details[other] = graph.nodes[other]
        label = f"{node}-{other}"
        attrs = graph.get_edge_data(node, other)
        if not attrs or 'content' not in attrs:
            logger.warning(f"Missing 'content' for edge '{label}'.")
            edge_details[label] = "No content available."
        else:
            edge_details[label] = attrs['content']

    return flat_dict(node_details), flat_dict(edge_details)


def fill_missing_node(graph, all_node_types: str, llm_config: Dict[str, Any]) -> Dict[str, str]:
    """
    Fills in missing node types and content by generating them using an LLM.

    A node is considered missing when its 'type' or its 'content' attribute is absent or
    empty. The same rule is used both to select the nodes and to compute the total shown
    in the progress log, so the "done/total" counter stays consistent.

    Args:
        graph: The graph data structure.
        all_node_types (str): A string listing all possible node types.
        llm_config (Dict[str, Any]): Configuration for the LLM, including engine and settings.

    Returns:
        Dict[str, str]: A dictionary mapping node identifiers to the raw text returned by
        the LLM. Nodes for which the prompt could not be built or no text came back are omitted.
    """
    llm_engine = llm_config.get("engine")
    llm_settings = llm_config.get("settings", {})
    filled_nodes = {}

    # Select once, with the exact predicate that decides whether a node gets filled.
    missing_nodes = [
        node for node, attrs in graph.nodes(data=True)
        if not attrs.get('type') or not attrs.get('content')
    ]
    total_missing = len(missing_nodes)

    sidx = 1
    for node in missing_nodes:
        neighbor_nodes_info, neighbor_edges_info = obtain_neighbor_info(graph, node)

        try:
            prompt = missing_prediction.NODE_MISSING_PROMPT.format(
                cur_node=node,
                neighbor_Node=neighbor_nodes_info,
                edge_info=neighbor_edges_info,
                all_node_type=all_node_types
            )
        except KeyError as e:
            logger.error(f"Missing placeholder in NODE_MISSING_PROMPT: {e}")
            continue

        entity_text = agent_answer(
            LLM=llm_engine,
            cur_testcase=prompt,
            do_sample=llm_settings.get("do_sample", False),
            max_new_tokens=llm_settings.get("max_new_tokens", 150),
            temperature=llm_settings.get("temperature", 0)
        )
        if entity_text:
            filled_nodes[node] = entity_text
            logger.info(f"{sidx}/{total_missing}:Filled node '{node}' with content.")
            sidx += 1
        else:
            logger.warning(f"No content returned for node '{node}'.")

    return filled_nodes


def fill_missing_edge(graph, llm_config: Dict[str, Any]) -> Dict[str, str]:
    """
    Fills in missing edge descriptions by generating them using an LLM.

    An edge is considered missing when its 'content' attribute equals the placeholder
    text "High relevant edges based on hypothesis".

    Args:
        graph: The graph data structure.
        llm_config (Dict[str, Any]): Configuration for the LLM, including engine and settings.
            The optional setting 'filling_max_neighbor' caps how many neighbors of each
            endpoint are put into the prompt; it falls back to 8 when absent.

    Returns:
        Dict[str, str]: A dictionary mapping "<source>-<target>" keys to the raw text
        returned by the LLM.
    """
    placeholder = "High relevant edges based on hypothesis"

    llm_engine = llm_config.get("engine")
    llm_settings = llm_config.get("settings", {})
    # A missing setting used to raise KeyError mid-run; fall back to the
    # default neighbor cap of obtain_neighbor_info instead.
    max_neighbor = llm_settings.get("filling_max_neighbor", 8)
    filled_edges = {}

    missing_edges = [
        (source, target)
        for source, target, attrs in graph.edges(data=True)
        if attrs.get('content') == placeholder
    ]
    total_missing = len(missing_edges)

    sidx = 1
    for source, target in missing_edges:
        source_node = graph.nodes[source]
        target_node = graph.nodes[target]
        source_nodes_info, source_edges_info = obtain_neighbor_info(graph, source, max_neighbor=max_neighbor)
        target_nodes_info, target_edges_info = obtain_neighbor_info(graph, target, max_neighbor=max_neighbor)
        try:
            prompt = missing_prediction.EDGE_MISSING_PROMPT.format(
                source=source,
                target=target,
                source_info=source_node,
                target_info=target_node,
                source_neighbor_info=source_nodes_info,
                source_edge_info=source_edges_info,
                target_neighbor_info=target_nodes_info,
                target_edge_info=target_edges_info
            )
        except KeyError as e:
            logger.error(f"Missing placeholder in EDGE_MISSING_PROMPT: {e}")
            continue

        edge_text = agent_answer(
            LLM=llm_engine,
            cur_testcase=prompt,
            do_sample=llm_settings.get("do_sample", False),
            max_new_tokens=llm_settings.get("max_new_tokens", 150),
            temperature=llm_settings.get("temperature", 0)
        )
        if edge_text:
            # The "<source>-<target>" key format is what fill_missing_graph reads back
            # from the cache file, so it is kept unchanged.
            filled_edges[f"{source}-{target}"] = edge_text
            logger.info(f"{sidx}/{total_missing}:Filled edge '{source}-{target}' with content.")
            sidx += 1
        else:
            logger.warning(f"No content returned for edge '{source}-{target}'.")

    return filled_edges

In [23]:
def extract_json(text: str) -> dict:
    """
    Pull a JSON object out of free-form text and make sure the expected fields exist.

    A ```json fenced block is preferred; otherwise the span from the first '{' to the
    last '}' is parsed. 'Content' is always expected, and 'Type' is expected as well when
    the word 'Type' occurs in the JSON text. Expected fields that are absent are set to "N/A".

    Args:
        text (str): The text containing JSON.

    Returns:
        dict: The parsed JSON object, or an empty dict when nothing could be extracted or parsed.
    """
    try:
        # Prefer a fenced ```json block when present
        fenced = re.search(r"```json\s*(\{.*?\})\s*```", text, re.DOTALL)
        if fenced is not None:
            payload = fenced.group(1)
        else:
            # Otherwise take everything between the outermost braces
            first_brace = text.find('{')
            last_brace = text.rfind('}')
            if -1 in (first_brace, last_brace):
                raise ValueError("No JSON object found in the response.")
            payload = text[first_brace:last_brace + 1]

        parsed = json.loads(payload)

        # Decide which fields must be present, then backfill the absent ones
        expected = ['Content']
        if 'Type' in payload:
            expected = ['Type', 'Content']
        for name in expected:
            if name in parsed:
                continue
            logger.warning(f"Missing '{name}' in the extracted JSON.")
            parsed[name] = "N/A"  # or handle as needed

        return parsed
    except json.JSONDecodeError as e:
        logger.error(f"JSON decoding failed: {e}")
        return {}
    except Exception as e:
        logger.error(f"Error extracting JSON: {e}")
        return {}

In [24]:
def fill_missing_graph(
    graph: nx.Graph,
    llm_config: Dict[str, Any],
    folder: str = "checkpoint",
    save_path: str = "missing_meta_info.json"
) -> nx.Graph:
    """
    Fills missing node and edge information in the graph using an LLM and caches the results.

    When the cache file already exists it is loaded instead of querying the LLM again.
    Edge entries are stored under "<source>-<target>" keys (as produced by
    fill_missing_edge); they are resolved by matching against the edges actually present
    in the graph, so node names need no particular quoting and may themselves contain '-'.

    Args:
        graph (nx.Graph): The original graph. It is copied and not modified.
        llm_config (Dict[str, Any]): Configuration for the LLM, including engine and settings.
            The optional key 'entity_type' is forwarded as the list of node types.
        folder (str, optional): Directory to save/load the meta information. Defaults to "checkpoint".
        save_path (str, optional): Filename for the meta information. Defaults to "missing_meta_info.json".

    Returns:
        nx.Graph: The updated graph with missing node and edge information filled.
    """
    # Create a copy of the graph to avoid mutating the original
    updated_graph = graph.copy()

    # Ensure the folder exists
    os.makedirs(folder, exist_ok=True)

    # Build the full file path
    loc = os.path.join(folder, save_path)

    if os.path.exists(loc):
        meta_info = load_meta_info(file_path=loc)
    else:
        # Fill missing nodes and edges
        filled_node_info = fill_missing_node(
            graph=updated_graph,
            all_node_types=llm_config.get('entity_type', ''),
            llm_config=llm_config
        )
        filled_edge_info = fill_missing_edge(
            graph=updated_graph,
            llm_config=llm_config
        )
        meta_info = {
            'missing_node': filled_node_info,
            'missing_edge': filled_edge_info
        }
        save_meta_info(meta_info, file_path=loc)

    # Extract filled node and edge information from meta_info
    filled_node_info = meta_info.get('missing_node', {})
    filled_edge_info = meta_info.get('missing_edge', {})

    # Update nodes with filled information
    for node, node_json in filled_node_info.items():
        if node not in updated_graph:
            # A cache written for a different graph may name nodes that are not here
            logger.warning(f"Node '{node}' does not exist in the graph.")
            continue
        node_content = extract_json(node_json)
        if node_content:
            types = node_content.get('Type')
            content = node_content.get('Content')
            if types:
                updated_graph.nodes[node]['type'] = types
            if content:
                updated_graph.nodes[node]['content'] = content
        else:
            # Fallback if JSON extraction failed
            updated_graph.nodes[node]['content'] = node_json
            logger.warning(f"Failed to extract JSON for node '{node}'. Updated 'content' with raw data.")

    # Map every "<source>-<target>" key back to its endpoints using the real edges,
    # instead of splitting the key text (which breaks for unquoted names).
    edge_lookup = {}
    for u, v in updated_graph.edges():
        edge_lookup.setdefault(f"{u}-{v}", (u, v))
    if not updated_graph.is_directed():
        # An undirected edge may have been recorded in either orientation
        for u, v in updated_graph.edges():
            edge_lookup.setdefault(f"{v}-{u}", (u, v))

    # Update edges with filled information
    for edge_key, edge_json in filled_edge_info.items():
        endpoints = edge_lookup.get(edge_key)
        if endpoints is None:
            logger.warning(f"Edge '{edge_key}' does not exist in the graph.")
            continue

        source, target = endpoints
        edge_content = extract_json(edge_json)
        if edge_content and 'Content' in edge_content:
            updated_graph.edges[source, target]['content'] = edge_content['Content']
        else:
            updated_graph.edges[source, target]['content'] = edge_json
            logger.warning(f"Failed to extract 'Content' for edge '{source}-{target}'. Updated 'content' with raw data.")

    return updated_graph

In [25]:
def locate_community(level_dict, node_name):
    """
    Find the community a node belongs to within one level.

    Args:
        level_dict: Mapping of community id to a container of member node names.
        node_name: The node to look up.

    Returns:
        The id of the first community whose members contain node_name, or None
        when no community contains it.
    """
    for cid in level_dict:
        if node_name in level_dict[cid]:
            return cid
    return None


def load_community_info(level_info, nodes_list):
    """
    Resolve the community id of every node in nodes_list for one level.

    Args:
        level_info: Mapping of community id to a container of member node names.
        nodes_list: Node names, in the order the result should follow.

    Returns:
        list: One entry per node: its community id, or the string 'None' when the
        node is in no community.
    """
    layer_com = []
    for node_name in nodes_list:
        cid = locate_community(level_info, node_name)
        # Compare against None explicitly: a community id of 0 is falsy but valid.
        layer_com.append(cid if cid is not None else 'None')
    return layer_com

def extract_community_json(text: str) -> dict:
    """
    Pull a JSON object out of a community report string.

    A ```json fenced block is preferred; otherwise the span from the first '{' to the
    last '}' is parsed. No field validation is performed.

    Args:
        text (str): The text containing JSON.

    Returns:
        dict: The parsed JSON object. NOTE: when extraction or parsing fails, the
        original `text` is returned unchanged (not an empty dict).
    """
    try:
        # Prefer a fenced ```json block when present
        fenced = re.search(r"```json\s*(\{.*?\})\s*```", text, re.DOTALL)
        if fenced is not None:
            payload = fenced.group(1)
        else:
            # Otherwise take everything between the outermost braces
            first_brace = text.find('{')
            last_brace = text.rfind('}')
            if -1 in (first_brace, last_brace):
                raise ValueError("No JSON object found in the response.")
            payload = text[first_brace:last_brace + 1]

        return json.loads(payload)
    except json.JSONDecodeError as e:
        logger.error(f"JSON decoding failed: {e}")
        return text
    except Exception as e:
        logger.error(f"Error extracting JSON: {e}")
        return text
    
    
def check_level(key, community_levels):
    """
    Return the first level whose community mapping contains `key`.

    Levels are examined in the iteration order of `community_levels`; None is
    returned when no level contains the key.
    """
    matches = (
        level
        for level, communities in community_levels.items()
        if key in communities
    )
    return next(matches, None)
        
def process_communities(nodes_df, community_meta_info):
    """
    Processes community information and updates the nodes DataFrame with community levels.
    Also creates a DataFrame containing community details.

    Parameters:
    - nodes_df (pd.DataFrame): DataFrame containing node information with a 'node_name' column.
      It is modified in place: one 'Community_Level_<n>' column is added per level.
    - community_meta_info (dict): Dictionary containing community metadata. The keys
      'results_by_level', 'community_report_dict' and 'community_hierarchy_map' are read.

    Returns:
    - tuple: (nodes_df, community_df). nodes_df is the same object that was passed in.
      community_df has the columns 'community_key', 'content', 'parent' and 'level'; a
      community that could not be processed yields a row of None values.
    """

    # Extract node names as a list for processing
    nodes_list = nodes_df["node_name"].tolist()

    # Community membership, organized by level
    community_levels = community_meta_info.get('results_by_level', {})

    # Iterate over each level to load and assign community information
    for level_str, level_data in community_levels.items():
        # Resolve the column name first so the error path below never refers to an
        # unbound (or previous-level) column name.
        try:
            # Convert level to integer for naming (e.g., '0' -> 1)
            level_int = int(level_str) + 1
        except (ValueError, TypeError) as e:
            print(f"Error processing level {level_str}: {e}")
            continue
        community_column = f'Community_Level_{level_int}'

        try:
            nodes_df[community_column] = load_community_info(level_data, nodes_list)
        except (ValueError, KeyError) as e:
            print(f"Error processing level {level_str}: {e}")
            nodes_df[community_column] = None  # Assign a default value

    community_report_dict = community_meta_info.get('community_report_dict', {})
    community_hierarchy_map = community_meta_info.get('community_hierarchy_map', {})

    # Sort community keys as integers
    try:
        sorted_keys = sorted(community_report_dict.keys(), key=lambda k: int(k))
    except (ValueError, TypeError) as e:
        print(f"Error sorting community keys: {e}")
        sorted_keys = list(community_report_dict.keys())  # Fallback to unsorted keys

    # Prepare data for the community DataFrame
    community_data = {
        "community_key": [],
        "content": [],
        "parent": [],
        "level": []
    }

    for key in sorted_keys:
        try:
            # Compute every value before appending so the four lists stay aligned
            key_int = int(key)
            cur_level = check_level(key, community_levels)
            # The report is stored in its original (unparsed) form: parsing it caused
            # errors in the query stage.
            community_report = community_report_dict[key].get('community_report', {})
            parent = community_hierarchy_map.get(key, None)
        except (ValueError, KeyError, TypeError) as e:
            # Handle missing or malformed data with a placeholder row
            print(f"Error processing community key {key}: {e}")
            key_int = community_report = parent = cur_level = None

        community_data["community_key"].append(key_int)
        community_data["content"].append(community_report)
        community_data["parent"].append(parent)
        community_data["level"].append(cur_level)

    # Create the community DataFrame
    community_df = pd.DataFrame(community_data)

    return nodes_df, community_df

In [26]:
import os
import torch
import networkx as nx
from tqdm import tqdm
from transformers import pipeline
import nltk
from sentence_transformers import SentenceTransformer, util
import torch
import re
import random
from nltk.tokenize import sent_tokenize
# Fetch the NLTK 'punkt' tokenizer data required by sent_tokenize.
nltk.download('punkt')
# Load the Llama-3 8B Instruct text-generation pipeline with bfloat16 weights;
# device placement is delegated to the library via device_map="auto".
LLM = pipeline("text-generation", "meta-llama/Meta-Llama-3-8B-Instruct", torch_dtype=torch.bfloat16, device_map="auto")

[nltk_data] Downloading package punkt to
[nltk_data]     C:\Users\hughj\AppData\Roaming\nltk_data...
[nltk_data]   Package punkt is already up-to-date!
2024-12-23 19:42:09,609 [INFO] We will use 90% of the memory on device 0 for storing the model, and 10% for the buffer to avoid OOM. You can set `max_memory` in to a higher value to use more memory (at your own risk).
Loading checkpoint shards: 100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 4/4 [00:12<00:00,  3.15s/it]


In [27]:
# fill the missing edge/node?
# Missing node and edge description

# Leiden settings
seg_args = {
    "max_cluster_size": 10,
    "use_lcc": True,
    "seed": 0xDEADBEEF
}
# Set a length limit for input token context.
length_limit = 7000
# LLM settings. "do_sample" used to appear twice in this literal (False, then 0);
# the later entry silently won. It is now declared once, as a boolean.
LLM_args = {"do_sample": False,
            "entity_type_extraction_length":100,
            "pad_token_id":LLM.tokenizer.eos_token_id,
            "entity_extraction_length":2000,
            "entity_condition_length":10,
            "update_entity_description":500,
            "subgraph_matching_length": 10,
            "community_sum_length":1000,
            "max_new_tokens":600,
            "temperature":0, "filling_max_neighbor":5}

llm_config = {
    "engine": LLM,
    "settings": LLM_args,
    "chunk_size": 2000,
    "overlap_size":400 # overlap tokens between chunks
}

In [28]:
# Paths of the indexing checkpoint and of the two prompt template files.
CHECKPOINT_PATH = Path("checkpoint/meta_info.json")
EXTRACTION_PROMPT_PATH = Path('entity_extraction_prompt.txt')
TYPE_EXTRACTION_PROMPT_PATH = Path('entity_type_extraction_prompt.txt')
folder_path = 'input'
# Run indexing; per the log output below, an existing checkpoint is loaded when found.
meta_info, org_graph, updated_graph = load_indexing(folder_path, (LLM, LLM_args), CHECKPOINT_PATH=CHECKPOINT_PATH)
# Expose the entity types to later steps (fill_missing_graph reads llm_config['entity_type']).
llm_config['entity_type'] = meta_info['entity_type']
# Improvement: merge discrete nodes into the major graph
concat_info = merge_discrete_nodes_into_major_graph(updated_graph, (LLM, LLM_args))

2024-12-23 19:42:22,627 [INFO] Checkpoint found at checkpoint\meta_info.json, loading existing meta information.
2024-12-23 19:42:22,767 [INFO] Successfully loaded meta information and graphs from checkpoint.


meta_info successfully loaded from checkpoint\meta_info.json


In [None]:
# Fill in the missing node/edge descriptions based on neighboring nodes.
# TODO: node names are not accurate and need to be refined.
new_graph = fill_missing_graph(concat_info['concat_graph'], llm_config)
# Replace the merged graph with the filled copy so later cells use it.
concat_info['concat_graph'] = new_graph

In [None]:
# Obtain the node and edge tables from the filled graph.
nodes_df, edges_df = obtain_node_edge_df(concat_info['concat_graph'])

In [None]:
# Load the cached community meta info from the default 'checkpoint' directory, or run
# segmentation and report generation and cache the result when no cache file exists.
community_meta_info = load_community(
    concat_info = concat_info,
    nodes_df = nodes_df,
    edges_df = edges_df, 
    llm_config = llm_config,
    seg_args = seg_args,
    length_limit = length_limit)

In [None]:
# Add the per-level community columns and build the community table.
# NOTE: process_communities modifies nodes_df in place and returns that same object,
# so updated_nodes_df and nodes_df refer to the same DataFrame.
updated_nodes_df, community_df = process_communities(nodes_df, community_meta_info)

In [None]:
# Persist the node, community and edge tables as CSV files under 'checkpoint/'
# (without the index column).
updated_nodes_df.to_csv('checkpoint/nodes_info.csv', index=False)
community_df.to_csv('checkpoint/community_info.csv', index=False)
edges_df.to_csv('checkpoint/edge_info.csv', index=False)