In [1]:
import os
import torch
import networkx as nx
from tqdm import tqdm
from sentence_transformers import SentenceTransformer, util
import re
import faiss
import pickle
import pandas as pd
import logging
import numpy as np
import secrets
from typing import List, Tuple, Dict, Union, Optional, Any
from transformers import pipeline, Pipeline

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
def load_faiss_index_and_metadata(
    save_folder: str = "checkpoint",
    target_type: str = "node",
    metadata_output_path: str = "metadata.pkl"
) -> (faiss.Index, list):
    embedding_output_path = f"faiss_{target_type}_index.index"
    # Load FAISS index
    index_path = os.path.join(save_folder, embedding_output_path)
    if not os.path.exists(index_path):
        raise FileNotFoundError(f"FAISS index file not found at '{index_path}'.")
    print(f"Loading FAISS index from '{index_path}'...")
    index = faiss.read_index(index_path)
    return index

In [3]:
def agent_answer(LLM, input_info, message_mode=False, do_sample=True, temperature = 0.5, max_new_tokens=20, pad_token_id=None):
    if not pad_token_id:
        pad_token_id = LLM.tokenizer.eos_token_id
        
    if message_mode:
        messages = input_info
    else:
        messages = [
            {"role": "user", "content": input_info},
        ]
    outputs = LLM(
        messages,
        max_new_tokens=max_new_tokens,
        do_sample = do_sample,
        temperature = temperature,
        pad_token_id=pad_token_id
    )
    res = outputs[0]["generated_text"][-1]
    torch.cuda.empty_cache()
    return res['content']

In [4]:


# Configure logging
logging.basicConfig(
    level=logging.INFO,
    format="%(asctime)s [%(levelname)s] %(name)s: %(message)s",
    handlers=[
        logging.StreamHandler()
    ]
)
logger = logging.getLogger(__name__)

def load_llm_model(
    model_name: str = "meta-llama/Meta-Llama-3-8B-Instruct",
    task: str = "text-generation",
    torch_dtype: torch.dtype = torch.bfloat16,
    device_map: str = "auto"
) -> Pipeline:
    """
    Load a Language Model (LLM) using HuggingFace's pipeline.

    Args:
        model_name (str): The identifier of the model to load.
        task (str): The task for the pipeline (e.g., "text-generation").
        torch_dtype (torch.dtype): The data type for the model's tensors.
        device_map (str): Specifies how to allocate model layers on devices.

    Returns:
        Pipeline: An instance of HuggingFace's Pipeline configured for the specified task.
    """
    try:
        logger.info(f"Loading LLM model '{model_name}' for task '{task}' with dtype '{torch_dtype}' and device_map '{device_map}'.")
        llm = pipeline(
            task=task,
            model=model_name,
            torch_dtype=torch_dtype,
            device_map=device_map
        )
        logger.info("LLM model loaded successfully.")
        return llm
    except Exception as e:
        logger.error(f"Error loading LLM model '{model_name}': {e}")
        raise

def load_embedding_model(
    model_name: str = "stella_en_400M_v5",
    trust_remote_code: bool = True,
    device: Optional[str] = None
) -> SentenceTransformer:
    """
    Load a SentenceTransformer model.

    Args:
        model_name (str): The identifier of the SentenceTransformer model to load.
        trust_remote_code (bool): Whether to trust remote code from the model repository.
        device (Optional[str]): The device to load the model on ('cuda', 'cpu', etc.). If None, auto-detects.

    Returns:
        SentenceTransformer: An instance of SentenceTransformer loaded with the specified model.
    """
    try:
        if device is None:
            device = "cuda" if torch.cuda.is_available() else "cpu"
            logger.info(f"CUDA availability detected. Using device '{device}' for embedding model.")
        else:
            logger.info(f"Using specified device '{device}' for embedding model.")

        logger.info(f"Loading embedding model '{model_name}' with trust_remote_code={trust_remote_code}.")
        embedding_model = SentenceTransformer(model_name, trust_remote_code=trust_remote_code)
        embedding_model = embedding_model.to(device)
        logger.info("Embedding model loaded successfully.")
        return embedding_model
    except Exception as e:
        logger.error(f"Error loading embedding model '{model_name}': {e}")
        raise

In [5]:
def encode_sentence(
    model: SentenceTransformer,
    sentence: Union[str, List[str]],
    device: str = 'cuda'  # Default device
) -> np.ndarray:
    """
    Encode a sentence or a list of sentences into embeddings using the provided model.

    Args:
        model (SentenceTransformer): The sentence transformer model to use for encoding.
        sentence (Union[str, List[str]]): A single sentence or a list of sentences to encode.
        device (str, optional): The device to perform encoding on ('cpu' or 'cuda'). Defaults to 'cuda'.

    Returns:
        np.ndarray: The resulting embeddings as a NumPy array of type float32.

    Raises:
        ValueError: If the input sentence is neither a string nor a list of strings.
        RuntimeError: If encoding fails due to model or device issues.
    """
    try:
        # Validate input type
        if isinstance(sentence, str):
            sentences = [sentence]
        elif isinstance(sentence, list) and all(isinstance(s, str) for s in sentence):
            sentences = sentence
        else:
            raise ValueError("Input 'sentence' must be a string or a list of strings.")
        
        logger.info(f"Encoding {len(sentences)} sentence(s) on device '{device}'.")
        
        # Perform encoding
        embeddings = model.encode(
            sentences,
            device=device,
            convert_to_numpy=True,
            show_progress_bar=False  # Disable progress bar for cleaner logs
        ).astype(np.float32)
        torch.cuda.empty_cache()
        # Ensure consistent shape for single and multiple sentences
        if len(embeddings.shape) == 1:
            embeddings = embeddings.reshape(1, -1)
        logger.info(f"Encoding successful. Shape of embeddings: {embeddings.shape}.")
        return embeddings
    except Exception as e:
        logger.error(f"Failed to encode sentence(s): {e}")
        raise RuntimeError(f"Encoding failed: {e}") from e

In [6]:
def num_tokens(text: str, tokenizer) -> int:
    """
    Return the number of tokens in a string using the specified tokenizer.

    Args:
        text (str): The input string.
        tokenizer: Tokenizer to use for encoding the text.

    Returns:
        int: Number of tokens in the string.
    """
    return len(tokenizer.encode(text, add_special_tokens=False))

def expand_query(query: str, reports: List[str], llm_config: Dict) -> Tuple[str, Dict[str, int]]:
    """
    Expand the query using a random community report template.

    Args:
        query (str): The original search query.
        reports (List[str]): List of report templates to base the expansion on.
        llm_config (Dict): Configuration for the LLM, including engine and settings.

    Returns:
        Tuple[str, Dict[str, int]]: Expanded query text and the token usage details.
    """
    # Validate inputs
    if not reports:
        log.warning("Reports list is empty. Returning the original query.")
        return query, {"llm_calls": 0, "prompt_tokens": 0, "output_tokens": 0}

    if "engine" not in llm_config or "settings" not in llm_config:
        log.error("Invalid LLM configuration. Missing 'engine' or 'settings'.")
        return query, {"llm_calls": 0, "prompt_tokens": 0, "output_tokens": 0}

    # Extract LLM engine and settings
    LLM_engine = llm_config["engine"]
    LLM_settings = llm_config["settings"]

    # Select a random template
    template = secrets.choice(reports)

    # Construct the prompt
    prompt = (
        f"Create a hypothetical answer to the following query: {query}\n\n"
        f"Format it to follow the structure of the template below:\n\n"
        f"{template}\n"
        "Ensure that the hypothetical answer does not reference new named entities "
        "that are not present in the original query."
    )

    # Generate the text using the LLM
    try:
        text = agent_answer(
            LLM_engine,
            prompt,
            do_sample=LLM_settings["do_sample"],
            max_new_tokens=LLM_settings["max_new_tokens"],
            temperature=LLM_settings["temperature"]
        )
    except Exception as e:
        log.error("LLM generation failed: %s", str(e))
        return query, {"llm_calls": 0, "prompt_tokens": 0, "output_tokens": 0}

    # Token usage tracking
    prompt_tokens = num_tokens(prompt, LLM_engine.tokenizer)
    output_tokens = num_tokens(text, LLM_engine.tokenizer)
    token_ct = {
        "llm_calls": 1,
        "prompt_tokens": prompt_tokens,
        "output_tokens": output_tokens,
    }

    # Handle empty LLM response
    if not text.strip():
        log.warning("Failed to generate expansion for query: %s", query)
        return query, token_ct

    return text, token_ct

In [7]:
def faiss_index_to_numpy(index):
    """
    Converts a Faiss IndexFlatL2 object to a NumPy array.

    Parameters:
    - index (faiss.IndexFlatL2): The Faiss index to convert.

    Returns:
    - numpy.ndarray: A 2D NumPy array where each row is a vector from the index.
    """
    # Ensure the index is an instance of IndexFlatL2
    if not isinstance(index, faiss.IndexFlatL2):
        raise TypeError("The provided index is not an instance of faiss.IndexFlatL2")

    # Get the dimensionality and total number of vectors
    d = index.d
    ntotal = index.ntotal

    # Pre-allocate a NumPy array to hold all vectors
    vectors = np.zeros((ntotal, d), dtype='float32')

    # Reconstruct each vector and store it in the NumPy array
    for i in range(ntotal):
        index.reconstruct(i, vectors[i])

    return vectors
def build_faiss_index(embeddings: np.ndarray, index_type: str = "IndexFlatIP") -> faiss.Index:
    """
    Build a FAISS index from document embeddings.

    Args:
        embeddings (np.ndarray): 2D array of normalized document embeddings.
        index_type (str, optional): Type of FAISS index. Defaults to "IndexFlatIP".

    Returns:
        faiss.Index: FAISS index.
    """
    try:
        dimension = embeddings.shape[1]
        if index_type == "IndexFlatIP":
            index = faiss.IndexFlatIP(dimension)
            logger.info("Using FAISS IndexFlatIP for inner product similarity.")
        elif index_type == "IndexFlatL2":
            index = faiss.IndexFlatL2(dimension)
            logger.info("Using FAISS IndexFlatL2 for L2 distance similarity.")
        else:
            raise ValueError(f"Unsupported index type: {index_type}")

        index.add(embeddings)
        logger.info(f"FAISS index built with {index.ntotal} vectors.")
        return index
    except Exception as e:
        logger.error(f"Error building FAISS index: {e}")
        raise

def normalize_embeddings(embeddings: np.ndarray) -> np.ndarray:
    norms = np.linalg.norm(embeddings, axis=1, keepdims=True)
    # Avoid division by zero
    norms[norms == 0] = 1
    normalized_embeddings = embeddings / norms
    logger.info("Normalized embeddings to unit vectors.")
    return normalized_embeddings

def compute_cosine_similarity(
    query_embedding: np.ndarray,
    index: faiss.Index,
    top_k: int = 5
) -> Tuple[np.ndarray, np.ndarray]:
    try:
        # Normalize the query embedding
        query_norm = np.linalg.norm(query_embedding)
        if query_norm == 0:
            logger.warning("Query embedding has zero norm. Returning empty results.")
            return np.array([]), np.array([])

        normalized_query = query_embedding / query_norm
        normalized_query = normalized_query.astype(np.float32)
        distances, indices = index.search(normalized_query, top_k)
        logger.info(f"Retrieved top-{top_k} similar documents.")
        return distances, indices
    except Exception as e:
        logger.error(f"Error computing cosine similarity: {e}")
        raise

def build_context(query_embedding, embed_data, target_df, idx_col = 'community_key', top_k = 5):
    document_embeddings = faiss_index_to_numpy(embed_data)
    normalized_document_embeddings = normalize_embeddings(document_embeddings)
    faiss_index = build_faiss_index(normalized_document_embeddings, index_type="IndexFlatIP")
    distances, indices = compute_cosine_similarity(query_embedding, faiss_index, top_k=top_k)
    if indices.size == 0:
        logger.warning("No similar documents found.")
        return pd.DataFrame()

    # Validate indices
    max_index = len(target_df) - 1
    if np.any(indices > max_index) or np.any(indices < 0):
        raise ValueError("Some indices are out of bounds of the community dataframe.")

    selected_nodes = target_df[target_df[idx_col].isin(indices[0])].reset_index(drop=True)
    return selected_nodes

In [8]:
from prompt import drift_search_system_prompt

import re
import json

def extract_json(text):
    """
    Extracts the JSON content from a markdown code block.

    Args:
        text (str): The input text containing the JSON code block.

    Returns:
        str: The extracted JSON string.

    Raises:
        ValueError: If no JSON code block is found.
    """
    pattern = r"```json\s*(\{.*?\})\s*```"
    match = re.search(pattern, text, re.DOTALL)
    if match:
        return match.group(1)
    else:
        raise ValueError("JSON code block not found.")

def clean_json(json_str):
    """
    Cleans the extracted JSON string to make it valid for parsing.
    Specifically handles multi-line strings with triple quotes.

    Args:
        json_str (str): The raw JSON string extracted from the code block.

    Returns:
        str: The cleaned JSON string.
    """
    # Replace triple quotes with escaped double quotes
    json_str = json_str.replace('"""', '"')

    # Use regex to find the value of "intermediate_answer" and escape necessary characters
    def replace_multiline(match):
        content = match.group(1)
        # Escape backslashes and double quotes
        content = content.replace('\\', '\\\\').replace('"', '\\"')
        # Replace newlines with escaped newline characters
        content = content.replace('\n', '\\n')
        return f'"intermediate_answer": "{content}"'

    json_str = re.sub(
        r'"intermediate_answer":\s*"([^"]*)"', 
        replace_multiline, 
        json_str, 
        flags=re.DOTALL
    )
    return json_str

def parse_json(cleaned_json_str):
    return json.loads(cleaned_json_str)

def obtain_task_json(text):
    raw_json = extract_json(text)
    cleaned_json = clean_json(raw_json)
    data = parse_json(cleaned_json)
    return data

def decompose_query(
        query: str, reports: pd.DataFrame, llm_config: dict
    ) -> tuple[dict, dict[str, int]]:
        """
        Decompose the query into subqueries based on the fetched global structures.

        Args:
            query (str): The original search query.
            reports (pd.DataFrame): DataFrame containing community reports.

        Returns
        -------
        tuple[dict, int, int]: Parsed response and the number of prompt and output tokens used.
        """
        LLM_engine = llm_config["engine"]
        LLM_settings = llm_config["settings"]
        community_reports = "\n\n".join(reports["content"].tolist())
        prompt = drift_search_system_prompt.DRIFT_PRIMER_PROMPT.format(
            query=query, community_reports=community_reports
        )
        text = agent_answer(
            LLM_engine,
            prompt,
            do_sample=LLM_settings["do_sample"],
            max_new_tokens=LLM_settings["max_new_tokens"],
            temperature=LLM_settings["temperature"]
        )
        token_ct = {
            "llm_calls": 1,
            "prompt_tokens": num_tokens(prompt, LLM_engine.tokenizer),
            "output_tokens": num_tokens(text, LLM_engine.tokenizer),
        }
        return text, token_ct

In [9]:
node_df = pd.read_csv('checkpoint/nodes_info.csv')
edge_df = pd.read_csv('checkpoint/edge_info.csv')
community_df = pd.read_csv('checkpoint/community_info.csv')

In [10]:
node_df['content'] = [i.replace("Here is a comprehensive summary of the data:\n\n", "") for i in node_df['content'].tolist()]

In [11]:
node_description_embed = load_faiss_index_and_metadata(target_type='node')
edge_description_embed = load_faiss_index_and_metadata(target_type='edge')
community_description_embed = load_faiss_index_and_metadata(target_type='graph')

Loading FAISS index from 'checkpoint\faiss_node_index.index'...
Loading FAISS index from 'checkpoint\faiss_edge_index.index'...
Loading FAISS index from 'checkpoint\faiss_graph_index.index'...


In [12]:
LLM = load_llm_model()
embedding_model = load_embedding_model()

2024-12-22 22:33:31,611 [INFO] __main__: Loading LLM model 'meta-llama/Meta-Llama-3-8B-Instruct' for task 'text-generation' with dtype 'torch.bfloat16' and device_map 'auto'.
2024-12-22 22:33:31,902 [INFO] accelerate.utils.modeling: We will use 90% of the memory on device 0 for storing the model, and 10% for the buffer to avoid OOM. You can set `max_memory` in to a higher value to use more memory (at your own risk).
Loading checkpoint shards: 100%|█████████████████████████████████████████████████████████| 4/4 [00:06<00:00,  1.55s/it]
2024-12-22 22:33:38,463 [INFO] __main__: LLM model loaded successfully.
2024-12-22 22:33:38,464 [INFO] __main__: CUDA availability detected. Using device 'cuda' for embedding model.
2024-12-22 22:33:38,464 [INFO] __main__: Loading embedding model 'stella_en_400M_v5' with trust_remote_code=True.
2024-12-22 22:33:38,466 [INFO] sentence_transformers.SentenceTransformer: Use pytorch device_name: cuda
2024-12-22 22:33:38,467 [INFO] sentence_transformers.Sentenc

In [13]:
def extract_community_json(text: str) -> dict:
    """
    Extracts JSON content from a given text string and validates required fields.

    Args:
        text (str): The text containing JSON.

    Returns:
        dict: The extracted JSON as a dictionary if valid, else empty dict.
    """
    # Regex to find JSON within code blocks
    code_block_pattern = r"```json\s*(\{.*?\})\s*```"
    match = re.search(code_block_pattern, text, re.DOTALL)

    if match:
        json_str = match.group(1)
    else:
        # If no code block, attempt to find JSON directly
        json_start = text.find('{')
        json_end = text.rfind('}')
        if json_start == -1 or json_end == -1:
            raise ValueError("No JSON object found in the response.")
        json_str = text[json_start:json_end+1]

    # Parse the JSON string
    data = json.loads(json_str)
    return data

In [14]:
from collections import Counter
import math
class DriftAction:
    """
    Represent an action containing a query, answer, score, and follow-up actions.

    This class encapsulates action strings produced by the LLM in a structured way.
    """
    def __init__(
        self,
        query: str,
        answer: str | None = None,
        follow_ups: list["DriftAction"] | None = None,
    ):
        """
        Initialize the DriftAction with a query, optional answer, and follow-up actions.

        Args:
            query (str): The query for the action.
            answer (Optional[str]): The answer to the query, if available.
            follow_ups (Optional[list[DriftAction]]): A list of follow-up actions.
        """
        self.query = query
        self.answer: str | None = answer  # Corresponds to an 'intermediate_answer'
        self.score: float | None = None
        self.follow_ups: list[DriftAction] = (
            follow_ups if follow_ups is not None else []
        )
        
    @property
    def is_complete(self) -> bool:
        """Check if the action is complete (i.e., an answer is available)."""
        return self.answer is not None

    def get_sorted_frequency_dict(self, numbers):
        frequency = Counter(numbers)
        sorted_items = sorted(frequency.items(), key=lambda item: (-item[1], item[0]))
        sorted_frequency_dict = dict(sorted_items)
        return sorted_frequency_dict

    def locate_specific_row(self, edge_df, source_node, target_node, use_ids=False):
        if use_ids:
            mask = (edge_df['source_node_id'] == source_node) & (edge_df['target_node_id'] == target_node)
        else:
            mask = (edge_df['source_node'] == source_node) & (edge_df['target_node'] == target_node)

        matched_rows = edge_df[mask]
        return matched_rows
    
    def get_neighbors(self, edge_df, node_name):
        neighbors_as_source = edge_df[edge_df['source_node'] == node_name]['target_node']
        neighbors_as_target = edge_df[edge_df['target_node'] == node_name]['source_node']
        neighbors = pd.concat([neighbors_as_source, neighbors_as_target]).unique()
        return neighbors.tolist()
    
    
    def _get_header(self, attributes: list[str]) -> list[str]:
        header = ["id", "title"]
        attributes = [col for col in attributes if col not in header]
        attributes = [col for col in attributes if col != 'rating']
        header.extend(attributes)
        return header 
        
    def serialize(self, include_follow_ups: bool = True) -> dict[str, Any]:
        """
        Serialize the action to a dictionary.

        Args:
            include_follow_ups (bool): Whether to include follow-up actions in the serialization.

        Returns
        -------
        dict[str, Any]
            Serialized action as a dictionary.
        """
        data = {
            "query": self.query,
            "answer": self.answer,
            "score": self.score,
        }
        if include_follow_ups:
            data["follow_ups"] = [action.serialize() for action in self.follow_ups]
        return data
    
    
    def asearch(self,
                global_query: str,
                llm_config:dict,
                node_embed: Any,
                node_df: pd.DataFrame,
                edge_df: pd.DataFrame,
                community_df: pd.DataFrame,
                top_k: int = 10,
                neighbor_top_k: int = 5,
                scorer: Any = None):
        
        LLM_engine = llm_config["engine"]
        LLM_settings = llm_config["settings"]
        max_context_size = llm_config["max_context_size"]
        final_content = []
        cur_length = 0
        query_embed = encode_sentence(model=llm_config['embedding'], sentence=self.query)
        # obtain relevant entities
        selected_entities = build_context(query_embedding = query_embed, 
                                          embed_data = node_embed,target_df = node_df, 
                                          idx_col = 'node_id', top_k = top_k)
        
        # check these entity corresponding community and sort by number of appearance
        # DFS community selection
        community_rank = {}
        for level in range(1, 4):
            cur_level = selected_entities[f"Community_Level_{level}"]
            for idx, com_key in enumerate(cur_level):
                if not math.isnan(com_key):
                    if idx not in community_rank:
                        community_rank[idx] = [com_key]
                    else:
                        community_rank[idx].append(com_key)

        DFS_community_choice = [int(community_rank[key][-1]) for key in community_rank if community_rank[key]]
        sorted_community = self.get_sorted_frequency_dict(DFS_community_choice)
        
        # obtain releveant communities report to the entities and maintain within max_length
        community_reports_str = [
            content
            for key in sorted_community.keys()
            for content in community_df.loc[community_df['community_key'] == key, 'content'].astype(str).tolist()]
        
        community_reports = [extract_community_json(i) for i in community_reports_str]

        sorted_keys = list(sorted_community.keys())
        community_reports = [
            {**item, "id": sorted_keys[idx]}
            for idx, item in enumerate(community_reports)
        ]

        attributes = (
            list(community_reports[0].keys())
            if community_reports[0].keys()
            else []
        )

        header = self._get_header(attributes)
        concatenated_community_report = "\n".join(
            ".".join(str(report[attr]) for attr in header)
            for report in community_reports
        )
        head_line = ",".join(header)
        concatenated_community_report = "-----Report-----\n" + f"{head_line}\n" + concatenated_community_report
        report_len = num_tokens(concatenated_community_report, LLM_engine.tokenizer)
        if cur_length + report_len < max_context_size:
            final_content.append(concatenated_community_report)        
        
        
        # return the selected entity information
        concatenated_entity_report = '\n'.join(
            f"{row.node_id},{row.node_name},{row.content}"
            for row in selected_entities.itertuples(index=False)
        )
        
        node_len = num_tokens(concatenated_entity_report, LLM_engine.tokenizer)
        if cur_length + node_len < max_context_size:
            final_content.append(concatenated_entity_report)        
        
        concatenated_entity_report = "-----Entities-----\n" + "id,entity,description\n" +concatenated_entity_report
        entity_list = selected_entities["node_name"].tolist()
        neighbor_set = []
        neighbor_info = {}
        for entity_name in entity_list:
            cur_neighbor = self.get_neighbors(edge_df, entity_name)[:neighbor_top_k]
            for nei in cur_neighbor:
                if (entity_name, nei) not in neighbor_set and (nei, entity_name) not in neighbor_set:
                    neighbor_set.append((entity_name, nei))
                    target_info = self.locate_specific_row(edge_df, entity_name, nei)
                    if len(target_info) == 0:
                        target_info = self.locate_specific_row(edge_df, nei, entity_name)
                    neighbor_info[(entity_name, nei)] = target_info['content'].iloc[0]

        entity_len = num_tokens(concatenated_entity_report, LLM_engine.tokenizer)
        # return the selected entity relationship information
        concatenated_edge_report = '\n'.join(
            f"{edge[0]},{edge[1]},{neighbor_info[edge]}"
            for edge in neighbor_info)
        
        concatenated_edge_report = "-----Relationships-----\n" + "id,source,target,description\n" +concatenated_edge_report
        
        edge_len = num_tokens(concatenated_edge_report, LLM_engine.tokenizer)
        if cur_length + edge_len < max_context_size:
            final_content.append(concatenated_edge_report)
        
        context_result = "\n\n".join(final_content)
        # TODO, add text unit info
        
        search_prompt = drift_search_system_prompt.DRIFT_LOCAL_SYSTEM_PROMPT.format(
            context_data=context_result,
            response_type="multiple paragraphs",
            global_query=global_query,
        )
        search_messages = [
            {"role": "system", "content": search_prompt},
            {"role": "user", "content": self.query},
        ]        

        llm_response = agent_answer(
            LLM_engine,
            input_info=search_messages,
            message_mode=True,
            do_sample=LLM_settings["do_sample"],
            max_new_tokens=LLM_settings["max_new_tokens"],
            temperature=LLM_settings["temperature"]
        )

        parsed_output = parse_llm_output(llm_response)
        
        self.answer = parsed_output['sections']
        self.follow_ups = parsed_output['follow_up_questions']
        self.score = parsed_output['score']
        return self

In [15]:
from collections.abc import Callable
import random
class QueryState:
    """Manage the state of the query, including a graph of actions."""

    def __init__(self):
        self.graph = nx.MultiDiGraph()

    def add_action(self, action: DriftAction, metadata: dict[str, Any] | None = None):
        """Add an action to the graph with optional metadata."""
        self.graph.add_node(action, **(metadata or {}))
        
    def relate_actions(
        self, parent: DriftAction, child: DriftAction, weight: float = 1.0
    ):
        """Relate two actions in the graph."""
        self.graph.add_edge(parent, child, weight=weight)
        
    def add_all_follow_ups(
        self,
        action: DriftAction,
        follow_ups: list[DriftAction] | list[str],
        weight: float = 1.0,
    ):
        """Add all follow-up actions and links them to the given action."""
        if len(follow_ups) == 0:
            logger.warning("No follow-up actions for action: %s", action.query)

        for follow_up in follow_ups:
            if isinstance(follow_up, str):
                follow_up = DriftAction(query=follow_up)
            elif not isinstance(follow_up, DriftAction):
                logger.warning(
                    "Follow-up action is not a string, found type: %s", type(follow_up)
                )

            self.add_action(follow_up)
            self.relate_actions(action, follow_up, weight)
            
    def find_incomplete_actions(self) -> list[DriftAction]:
        """Find all unanswered actions in the graph."""
        return [node for node in self.graph.nodes if not node.is_complete]
    
    def rank_incomplete_actions(
        self, scorer: Callable[[DriftAction], float] | None = None
    ) -> list[DriftAction]:
        """Rank all unanswered actions in the graph if scorer available."""
        unanswered = self.find_incomplete_actions()
        if scorer:
            for node in unanswered:
                node.compute_score(scorer)
            return sorted(
                unanswered,
                key=lambda node: (
                    node.score if node.score is not None else float("-inf")
                ),
                reverse=True,
            )

        # shuffle the list if no scorer
        random.shuffle(unanswered)
        return list(unanswered)
    
    def serialize(
        self, include_context: bool = False
    ) -> dict[str, Any] | tuple[dict[str, Any], dict[str, Any], str]:
        """Serialize the graph to a dictionary, including nodes and edges."""
        # Create a mapping from nodes to unique IDs
        node_to_id = {node: idx for idx, node in enumerate(self.graph.nodes())}

        # Serialize nodes
        nodes: list[dict[str, Any]] = [
            {
                **node.serialize(include_follow_ups=False),
                "id": node_to_id[node],
                **self.graph.nodes[node],
            }
            for node in self.graph.nodes()
        ]

        # Serialize edges
        edges: list[dict[str, Any]] = [
            {
                "source": node_to_id[u],
                "target": node_to_id[v],
                "weight": edge_data.get("weight", 1.0),
            }
            for u, v, edge_data in self.graph.edges(data=True)
        ]
        return {"nodes": nodes, "edges": edges}

In [16]:
import re

def parse_llm_output(text: str) -> dict:
    """
    Parse the LLM output into a structured dictionary.
    """
    # This dictionary will hold all parsed content
    parsed_output = {
        "title": None,
        "sections": {},        # key: section heading, value: content
        "references": None,
        "score": None,
        "follow_up_questions": []
    }

    # Regex patterns
    # 1) Headings: lines that start and end with double asterisks (**Heading**)
    heading_pattern = re.compile(r"\*\*(.+?)\*\*")
    # 2) Score pattern: "**Score:** some_number"
    score_pattern = re.compile(r"\*\*Score:\*\*\s*(\d+)")
    # 3) Follow-up questions heading
    followup_pattern = re.compile(r"\*\*Follow-up Questions:\*\*")

    # Split the text into lines
    lines = text.splitlines()

    # A helper to track which heading we are currently in
    current_heading = None
    # Buffer to hold text lines for the current heading
    section_lines = []

    # We’ll manually detect references, score, and follow-up questions as we go
    in_followup_section = False

    for line in lines:
        line_stripped = line.strip()

        # Check if we've matched a heading
        heading_match = heading_pattern.match(line_stripped)
        # Check for follow-up questions heading
        followup_match = followup_pattern.match(line_stripped)
        # Check for score
        score_match = score_pattern.match(line_stripped)

        # If line is empty, just skip it
        if not line_stripped:
            continue

        # If it's the "Score:" line
        if score_match:
            parsed_output["score"] = int(score_match.group(1))
            continue

        # If it's the "Follow-up Questions:" heading
        if followup_match:
            # Save the current heading's content (if any) before clearing
            if current_heading and section_lines:
                parsed_output["sections"][current_heading] = "\n".join(section_lines).strip()
                section_lines = []
            current_heading = "Follow-up Questions"
            in_followup_section = True
            continue

        # If it’s any other heading
        if heading_match and not followup_match:
            # Save the previous heading’s content in the dictionary
            if current_heading and section_lines:
                parsed_output["sections"][current_heading] = "\n".join(section_lines).strip()
                section_lines = []

            # The new heading
            current_heading = heading_match.group(1).strip()
            # Check if it's 'References'
            if current_heading.lower() == "references":
                in_followup_section = False  # references is separate
            else:
                in_followup_section = False
            continue

        # If we're in the follow-up questions section, each line is a potential question
        if in_followup_section:
            # Typically, follow-up questions are numbered lines. Let's grab them.
            # You can refine this logic if your format changes.
            parsed_output["follow_up_questions"].append(line_stripped)
        else:
            # Otherwise, this line belongs to the current heading's content
            section_lines.append(line_stripped)

    # After the loop, we still might have leftover content for the last heading
    if current_heading and section_lines and current_heading != "Follow-up Questions":
        content = "\n".join(section_lines).strip()
        # If the heading is 'References', store them separately
        if current_heading.lower() == "references":
            parsed_output["references"] = content
        else:
            parsed_output["sections"][current_heading] = content

    # The first heading might logically be the "title" if the structure calls for that.
    # If you consider the very first heading in the text to be the title, you can do:
    all_headings = list(parsed_output["sections"].keys())
    if all_headings:
        parsed_output["title"] = all_headings[0]
        first_heading_content = parsed_output["sections"].pop(all_headings[0])
        parsed_output["sections"]["intro"] = first_heading_content

    return parsed_output

In [17]:

from concurrent.futures import ThreadPoolExecutor, as_completed

def run_drift_loop(llm_config, query_state, query, node_description_embed, node_df, edge_df, community_df):
    """
    Executes the DRIFT loop based on the provided configuration and state.

    Args:
        llm_config (dict): Configuration dictionary containing 'query_epoch' and 'drift_k_followups'.
        query_state: An object managing the state of queries and actions.
        query (str): The global query string.
        node_description_embed: Embedding for node descriptions.
        node_df (DataFrame): DataFrame containing node information.
        edge_df (DataFrame): DataFrame containing edge information.
        community_df (DataFrame): DataFrame containing community information.
    """
    # Configure the logger
    logger = logging.getLogger(__name__)
    logger.setLevel(logging.INFO)
    if not logger.handlers:
        # Prevent adding multiple handlers in interactive environments
        handler = logging.StreamHandler()
        formatter = logging.Formatter('%(asctime)s - %(levelname)s - %(message)s')
        handler.setFormatter(formatter)
        logger.addHandler(handler)

    max_epochs = llm_config.get('query_epoch', 0)
    drift_k_followups = llm_config.get('drift_k_followups', 0)

    if max_epochs <= 0:
        logger.warning("No epochs to run. Exiting DRIFT loop.")
        return

    for epoch in range(1, max_epochs + 1):
        actions = query_state.rank_incomplete_actions()
        if not actions:
            logger.info("No more actions to take. Exiting DRIFT loop.")
            break

        # Select top k follow-up actions
        selected_actions = actions[:drift_k_followups]
        if not selected_actions:
            logger.info("No actions selected after applying drift_k_followups. Exiting DRIFT loop.")
            break

        logger.debug(f"Epoch {epoch}: Processing {len(selected_actions)} actions.")

        # Function to process a single action with error handling
        def process_action(action):
            try:
                return action.asearch(
                    global_query=query,
                    llm_config=llm_config,
                    node_embed=node_description_embed,
                    node_df=node_df,
                    edge_df=edge_df,
                    community_df=community_df
                )
            except Exception as e:
                logger.error(f"Error processing action {action}: {e}")
                return None

        results = []
        # Use ThreadPoolExecutor to process actions in parallel if asearch is I/O bound
        with ThreadPoolExecutor(max_workers=drift_k_followups) as executor:
            future_to_action = {executor.submit(process_action, action): action for action in selected_actions}
            for future in as_completed(future_to_action):
                action_result = future.result()
                if action_result:
                    results.append(action_result)

        if not results:
            logger.warning(f"Epoch {epoch}: No successful action results. Continuing to next epoch.")
            continue

        # Update the query state with new actions and their follow-ups
        for action in results:
            query_state.add_action(action)
            follow_ups = getattr(action, 'follow_ups', None)
            if follow_ups:
                query_state.add_all_follow_ups(action, follow_ups)
            else:
                logger.warning(f"Action {action} has no 'follow_ups' attribute.")

        logger.info(f"Finished Epoch {epoch}/{max_epochs} Stage")

    logger.info("DRIFT loop completed.")
    return query_state

In [18]:
LLM_args = {"do_sample":False, "do_sample":0, 
            "entity_type_extraction_length":100,
            "pad_token_id":LLM.tokenizer.eos_token_id,
            "entity_extraction_length":2000,
            "entity_condition_length":10,
            "update_entity_description":500,
            "subgraph_matching_length": 10,
            "community_sum_length":1000,
            "max_new_tokens":2000,
            "temperature":0, "filling_max_neighbor":8}

llm_config = {
    "engine": LLM,
    "settings": LLM_args,
    "embedding": embedding_model,
    "query_epoch":2,
    "drift_k_followups":3,
    "max_context_size":6000
}

In [102]:
import networkx as nx
import math

def build_graph(response_state):
    G = nx.DiGraph()

    for node_info in response_state.get('nodes', []):
        node_id = node_info.get('id')
        if node_id is None:
            continue

        answer_value = node_info.get('answer', '')
        if answer_value:
            # Use `or 0.0` to replace None with 0.0, then convert to float
            score_value = node_info.get('score') or 0.0
            score_value = float(score_value)

            G.add_node(
                node_id,
                query=node_info.get('query'),
                answer=answer_value,
                score=score_value
            )

    for edge_info in response_state.get('edges', []):
        src = edge_info.get('source')
        tgt = edge_info.get('target')
        if not src or not tgt:
            continue
        if src in G and tgt in G:
            weight = edge_info.get('weight', 0.0)
            G.add_edge(src, tgt, weight=weight)

    return G


def find_roots_and_leaves(G):
    roots = [n for n, deg in G.in_degree() if deg == 0]
    leaves = [n for n, deg in G.out_degree() if deg == 0]
    return roots, leaves


def best_path_by_node_score(G):
    roots, leaves = find_roots_and_leaves(G)
    best_path = None
    best_score = -math.inf

    for root in roots:
        for leaf in leaves:
            for path in nx.all_simple_paths(G, root, leaf):
                # Convert each node’s 'score' to float in case of any leftover weirdness
                path_score = sum(float(G.nodes[n].get('score', 0.0)) for n in path)
                if path_score > best_score:
                    best_score = path_score
                    best_path = path

    return best_path, best_score

In [109]:
def generate_answer(query, graph, llm_config, use_all=True):
    """
    Generates an answer for a given user query based on relevant information from a graph.

    :param query:        The user query (string).
    :param graph:        A graph structure containing relevant information in its node attributes.
    :param llm_config:   A dictionary containing LLM engine and settings.
    :param use_all:      Boolean flag indicating whether to use all nodes or a best path subset.
    :return:             A string containing the final answer from the LLM.
    """

    # Prompt template
    prompt_template = (
        "You are provided with the following user query:\n"
        "\"{query}\"\n\n"
        "Along with relevant information:\n"
        "{content}\n\n"
        "Please analyze the query and the provided information to generate a clear, concise, "
        "and accurate final answer. Please only generate the answer!"
    )

    # Extract engine and settings from configuration
    LLM_engine = llm_config["engine"]
    LLM_settings = llm_config["settings"]

    # Determine which nodes to use
    if use_all:
        nodes_to_use = list(graph.nodes())
    else:
        best_path, best_score = best_path_by_node_score(graph)
        nodes_to_use = list(best_path)

    # Aggregate relevant info from each node
    content_snippets = []
    for node in nodes_to_use:
        cur_info = graph.nodes[node]  # Retrieve node attributes
        content_snippets.append(f"Query: {cur_info['query']}\nAnswer: {cur_info['answer']}\n")

    # Construct final content for the LLM prompt
    final_output = "".join(content_snippets)

    # Format the full prompt
    cur_prompt = prompt_template.format(query=query, content=final_output)

    # Call your LLM agent with the constructed prompt
    llm_response = agent_answer(
        LLM_engine,
        input_info=cur_prompt,  # Pass the prompt here
        message_mode=False,
        do_sample=LLM_settings["do_sample"],
        max_new_tokens=LLM_settings["max_new_tokens"],
        temperature=LLM_settings["temperature"]
    )
    return llm_response

In [19]:
query = "How to apply transformers for medical data classficaiton tasks?"
reports = list(community_df['content'])
augmented_query, token_ct = expand_query(query, reports=reports, llm_config=llm_config)
query_embed = encode_sentence(model=llm_config['embedding'], sentence=augmented_query)
selected_context = build_context(query_embedding = query_embed, 
                                 embed_data = community_description_embed, 
                                 target_df = community_df, idx_col = 'community_key', top_k = 5)
# add self-correction is tasks is not in json format
tasks, token_ct = decompose_query(query, reports=selected_context, llm_config=llm_config)
parsed_output = obtain_task_json(tasks)

2024-12-22 22:33:54,669 [INFO] __main__: Encoding 1 sentence(s) on device 'cuda'.
2024-12-22 22:33:54,730 [INFO] __main__: Encoding successful. Shape of embeddings: (1, 8192).
2024-12-22 22:33:54,733 [INFO] __main__: Normalized embeddings to unit vectors.
2024-12-22 22:33:54,733 [INFO] __main__: Using FAISS IndexFlatIP for inner product similarity.
2024-12-22 22:33:54,734 [INFO] __main__: FAISS index built with 56 vectors.
2024-12-22 22:33:54,739 [INFO] __main__: Retrieved top-5 similar documents.


In [20]:
init_action = DriftAction(query=query, 
                     follow_ups=parsed_output.get("follow_up_queries", []),
                     answer=parsed_output.get("intermediate_answer"))
init_action.score = parsed_output.get("score")

In [21]:
query_state = QueryState()
query_state.add_action(init_action)
query_state.add_all_follow_ups(init_action, init_action.follow_ups)

In [22]:
updated_query_state = run_drift_loop(llm_config, query_state, query, node_description_embed, node_df, edge_df, community_df)

2024-12-22 22:34:10,952 - INFO - Encoding 1 sentence(s) on device 'cuda'.
2024-12-22 22:34:10,952 [INFO] __main__: Encoding 1 sentence(s) on device 'cuda'.
2024-12-22 22:34:10,952 - INFO - Encoding 1 sentence(s) on device 'cuda'.
2024-12-22 22:34:10,954 - INFO - Encoding 1 sentence(s) on device 'cuda'.
2024-12-22 22:34:10,952 [INFO] __main__: Encoding 1 sentence(s) on device 'cuda'.
2024-12-22 22:34:10,954 [INFO] __main__: Encoding 1 sentence(s) on device 'cuda'.
2024-12-22 22:34:11,009 - INFO - Encoding successful. Shape of embeddings: (1, 8192).
2024-12-22 22:34:11,009 [INFO] __main__: Encoding successful. Shape of embeddings: (1, 8192).
2024-12-22 22:34:11,010 - INFO - Encoding successful. Shape of embeddings: (1, 8192).
2024-12-22 22:34:11,010 [INFO] __main__: Encoding successful. Shape of embeddings: (1, 8192).
2024-12-22 22:34:11,018 - INFO - Encoding successful. Shape of embeddings: (1, 8192).
2024-12-22 22:34:11,018 [INFO] __main__: Encoding successful. Shape of embeddings: (1,

In [23]:
response_state = query_state.serialize()

In [110]:
answer_graph = build_graph(response_state)

In [111]:
final_answer1 = generate_answer(query, answer_graph, llm_config, use_all=True)

You seem to be using the pipelines sequentially on GPU. In order to maximize efficiency please use a dataset


In [112]:
final_answer2 = generate_answer(query, answer_graph, llm_config, use_all=False)

In [113]:
print(final_answer1)

**Applying Transformers for Medical Data Classification Tasks**

Transformers have shown great promise in various machine learning applications, including natural language processing, computer vision, and bioinformatics. In the context of medical data classification tasks, transformers can be used to develop accurate and efficient models for disease diagnosis, treatment planning, and patient outcome prediction.

Transformers can be applied to medical data classification tasks in several ways, including text classification, image classification, and time series analysis. They offer several advantages, such as improved accuracy, efficient processing, and flexibility. However, there are still several challenges and future directions to explore, including data quality, domain adaptation, and explainability.

Transformers can be used to improve the interpretability of medical data classification models by analyzing attention mechanisms and using explainable AI (XAI) techniques. They can als

In [114]:
print(final_answer2)

Transformers can be applied to medical data classification tasks, particularly those involving sequential data such as medical text, time-series data, or genomic sequences. They can also be used to classify graph-structured data, such as medical images or protein-protein interaction networks. For medical data that involves both sequential and graph-structured data, a hybrid approach can be taken, combining sequential and graph transformers to leverage the strengths of both architectures. However, transformers may not be well-suited for large-scale medical datasets or complex relationships between variables, and future directions include developing more efficient and scalable architectures and incorporating domain-specific knowledge.
