In [None]:
import os
import re
import faiss
import json
import pandas as pd
from elasticsearch.helpers import bulk
from difflib import SequenceMatcher
from typing import List, Dict, Any


class SemanticHybridSearch:
    """
    Combine Elasticsearch (lexical) and Faiss (semantic) search over one dataset.

    The class loads a JSON dump of documents into an Elasticsearch index, loads
    a Faiss vector index from disk, and exposes lexical, semantic and hybrid
    search methods. Hybrid results can optionally be post-processed to strip
    text that overlaps between email bodies.

    Attributes:
        data (list): Records addressed by the positions Faiss returns.
        es_client (Elasticsearch): Client used for lexical searches.
        embedding_model: Object whose ``encode`` turns queries into embeddings.
        vector_index (faiss.Index): Faiss index used for semantic searches.
        elastic_index_name (str): Name of the Elasticsearch index (the base
            name of the index file that was loaded).
        elastic_index (str): Value returned by ``load_elastic_index`` (the
            index name).
    """

    def __init__(
        self,
        es_client,
        embedding_model,
        data: list,
        elastic_index_path: str,
        vector_index_path: str,
    ):
        """
        Initialize the search backends.

        Args:
            es_client (Elasticsearch): Elasticsearch client.
            embedding_model: Model for encoding queries into embeddings.
            data (list): Dataset used for resolving semantic search hits.
            elastic_index_path (str): Path to the JSON file holding the
                documents to index in Elasticsearch.
            vector_index_path (str): Path to the Faiss vector index file.
        """
        self.data = data
        self.es_client = es_client
        self.embedding_model = embedding_model
        # Set the default BEFORE loading: load_elastic_index() stores the real
        # index name, and resetting it afterwards would make every later
        # elastic_search() target the index "".
        self.elastic_index_name = ""
        self.vector_index = self.load_vector_index(vector_index_path)
        self.elastic_index = self.load_elastic_index(elastic_index_path)

    def load_elastic_index(self, elastic_index_path: str):
        """
        Bulk-load documents from a JSON file into Elasticsearch.

        The file must contain a list of objects, each with ``_id`` and
        ``_source`` keys. The index name is the base name of the file and is
        stored in ``self.elastic_index_name``.

        Args:
            elastic_index_path (str): Path to the JSON document dump.

        Returns:
            str: Name of the Elasticsearch index the documents were sent to.
        """
        with open(elastic_index_path, encoding="utf-8") as f:
            documents = json.load(f)

        self.elastic_index_name = os.path.basename(elastic_index_path)
        print(f"Loading Index {self.elastic_index_name}")

        actions = [
            {
                "_index": self.elastic_index_name,
                "_id": doc["_id"],
                "_source": doc["_source"],
            }
            for doc in documents
        ]
        bulk(self.es_client, actions)
        return self.elastic_index_name

    def load_vector_index(self, vector_index_path: str):
        """
        Load the Faiss vector index from a file.

        Args:
            vector_index_path (str): Path to the Faiss vector index file.

        Returns:
            faiss.Index: Loaded Faiss index.
        """
        print(f"Loading Index {os.path.basename(vector_index_path)}")
        return faiss.read_index(vector_index_path)

    def elastic_search(self, query: dict, top_k: int = 3) -> list:
        """
        Perform a keyword-based search using Elasticsearch.

        Args:
            query (dict): Elasticsearch request body.
            top_k (int): Maximum number of results to return. Defaults to 3.
                This only truncates the returned hits; how many hits the
                server sends back is governed by the query itself.

        Returns:
            list: The ``_source`` of up to ``top_k`` hits, in hit order.
        """
        results = self.es_client.search(index=self.elastic_index_name, body=query)
        return [result["_source"] for result in results["hits"]["hits"][:top_k]]

    def semantic_search(self, query: str, top_k: int = 3) -> list:
        """
        Perform a semantic search using Faiss.

        Args:
            query (str): Search query.
            top_k (int): Number of neighbours to request. Defaults to 3.

        Returns:
            list: Records from ``self.data`` for the neighbours found; may be
            shorter than ``top_k`` when the index has fewer matches.
        """
        embedding = self.embedding_model.encode([query]).astype("float32")
        _distances, idx = self.vector_index.search(embedding, top_k)
        # Faiss pads missing neighbours with the label -1; indexing the data
        # with it would silently return the LAST record, so skip those slots.
        return [self.data[i] for i in idx[0] if i >= 0]

    def hybrid_search(
        self,
        elastic_query: dict,
        semantic_query: str,
        top_k: tuple = (3, 3),
        clean_overlap: bool = True,
    ) -> list:
        """
        Perform a hybrid search combining results from Elasticsearch and Faiss.

        Args:
            elastic_query (dict): Elasticsearch query for lexical search.
            semantic_query (str): Query string for semantic search.
            top_k (tuple): Number of results to take from the
                (elastic, semantic) searches. Defaults to (3, 3).
            clean_overlap (bool): Whether to remove overlapping text between
                the email bodies of the results. Defaults to True.

        Returns:
            list: Combined search results with exact duplicate rows removed,
            elastic hits first.
        """
        elastic_results = self.elastic_search(elastic_query, top_k[0])
        semantic_results = self.semantic_search(semantic_query, top_k[1])

        # Round-trip through a DataFrame to drop rows that are identical in
        # every column.
        hybrid_concat = pd.concat(
            [pd.DataFrame(elastic_results), pd.DataFrame(semantic_results)],
            ignore_index=True,
        ).drop_duplicates()
        hybrid_results = hybrid_concat.to_dict(orient="records")

        if clean_overlap:
            return self._extract_unique_content(hybrid_results)
        return hybrid_results

    def _clean_text(self, text: str) -> str:
        """
        Collapse every run of whitespace to one space and trim the ends.

        Args:
            text (str): The input text to be cleaned.

        Returns:
            str: The cleaned text.
        """
        return re.sub(r"\s+", " ", text).strip()

    def _find_overlap(self, text1: str, text2: str) -> str:
        """
        Find the longest common substring between two texts.

        Args:
            text1 (str): The first text to compare.
            text2 (str): The second text to compare.

        Returns:
            str: The longest common substring, or an empty string if the texts
            share no characters.
        """
        # autojunk must be off: with the default heuristic, any character that
        # is frequent in a text of 200+ characters is ignored when seeding
        # matches, which hides long shared passages in ordinary prose.
        matcher = SequenceMatcher(None, text1, text2, autojunk=False)
        match = matcher.find_longest_match(0, len(text1), 0, len(text2))
        return text1[match.a : match.a + match.size] if match.size > 0 else ""

    def _extract_unique_content(
        self, emails: List[Dict[str, Any]]
    ) -> List[Dict[str, Any]]:
        """
        Remove text from each email body that overlaps with an earlier email.

        For every email, the longest common substring with each earlier email
        in the list is located (on whitespace-normalized bodies); when it is
        longer than 10 characters it is deleted from the current body. All
        other fields are copied through unchanged.

        Args:
            emails (List[Dict[str, Any]]): Email records with the keys
                'Origin', 'Subject', 'To', 'From', 'Cc', 'Bcc', 'Date',
                'Attachment_Count' and 'Mail_Body'.

        Returns:
            List[Dict[str, Any]]: Records with the same keys as the input,
            with 'Mail_Body' replaced by its de-overlapped, cleaned text.

        Note:
            Earlier list entries are treated as the reference: text is only
            ever removed from the later email of a pair.
        """
        # Normalize each body once instead of once per pairwise comparison.
        cleaned_bodies = [self._clean_text(email["Mail_Body"]) for email in emails]
        unique_contents = []

        for i, email in enumerate(emails):
            current_email = cleaned_bodies[i]
            unique_content = current_email

            for previous_email in cleaned_bodies[:i]:
                overlap = self._find_overlap(previous_email, current_email)

                # Short overlaps are just common words or phrases; keep them.
                if len(overlap) > 10:
                    unique_content = unique_content.replace(overlap, "").strip()

            unique_contents.append(
                {
                    "Origin": email["Origin"],
                    "Subject": email["Subject"],
                    "To": email["To"],
                    "From": email["From"],
                    "Cc": email["Cc"],
                    "Bcc": email["Bcc"],
                    "Date": email["Date"],
                    "Attachment_Count": email["Attachment_Count"],
                    "Mail_Body": unique_content,
                }
            )

        return unique_contents


In [15]:
import json
from typing import Dict, List, Annotated
from typing_extensions import TypedDict
from langchain_core.messages import HumanMessage, AIMessage
from langchain.prompts import PromptTemplate
from langchain_core.output_parsers import JsonOutputParser
from langgraph.graph import StateGraph, END
from langgraph.prebuilt import ToolExecutor

class State(TypedDict):
    """Shared state passed between the nodes of the investigation workflow."""

    # The accusation under investigation, as free text.
    accusation: str
    # Current search queries, keyed by the strings "elastic" and "semantic".
    queries: Dict[str, Any]
    # Records returned by the most recent search.
    search_results: List[Dict]
    # Structured information the LLM extracted from the search results.
    extracted_info: Dict
    # The LLM's latest evidence analysis.
    analysis: Dict
    # Loop counter checked by the agent's stop condition.
    search_count: int
    # Analysis from the previous round, or None before any refinement; the
    # agent passes it in as an input and reads it when deciding to continue.
    previous_analysis: Any

class DummyLLM:
    """Placeholder LLM: exposes ``invoke`` but performs no work."""

    def invoke(self, code: str):
        """Accept any input and return ``None`` without calling a model."""
        return None

class InvestigationAgent:
    """
    LLM-driven investigation loop built as a LangGraph workflow.

    The workflow generates search queries for an accusation, runs a hybrid
    search, extracts information, analyzes the evidence and, while the
    evidence is not sufficient, refines the queries and searches again.

    Attributes:
        llm: Object with an ``invoke(messages)`` method returning a message
            whose ``content`` is a JSON string.
        search_tool: Object with a ``hybrid_search(elastic_query,
            semantic_query)`` method.
        workflow: The compiled workflow graph.
    """

    def __init__(self, llm=None, search_tool=None):
        """
        Create the agent and compile its workflow.

        Args:
            llm: LLM to use. Defaults to a ``DummyLLM`` instance.
            search_tool: Search backend to use. Defaults to
                ``SemanticHybridSearch()``.
                NOTE(review): that default only works when
                ``SemanticHybridSearch`` can be constructed without
                arguments; pass a configured instance otherwise.
        """
        self.llm = llm if llm is not None else DummyLLM()
        self.search_tool = (
            search_tool if search_tool is not None else SemanticHybridSearch()
        )
        self.workflow = self._create_workflow()

    def _create_workflow(self) -> StateGraph:
        """Wire the nodes and edges together and return ``workflow.compile()``."""
        workflow = StateGraph(State)
        workflow.add_node("initial_query", self.initial_query_generation)
        workflow.add_node("search", self.perform_search)
        workflow.add_node("extract_info", self.information_extraction)
        workflow.add_node("analyze", self.evidence_analysis)
        workflow.add_node("refine_query", self.refine_query)

        workflow.add_edge("initial_query", "search")
        workflow.add_edge("search", "extract_info")
        workflow.add_edge("extract_info", "analyze")
        # After each analysis either stop or loop back through a refinement.
        workflow.add_conditional_edges("analyze", self.should_continue_search, {"end": END, "refine": "refine_query"})
        workflow.add_edge("refine_query", "search")
        workflow.set_entry_point("initial_query")

        return workflow.compile()

    @staticmethod
    def _render_prompt(template: str, **values) -> str:
        """
        Fill the named placeholders of a prompt, keeping all other braces literal.

        The prompt texts embed JSON examples (``{ "elastic": ... }``, ``{}``)
        next to real placeholders such as ``{accusation}``. Formatting them
        directly would treat every brace as a placeholder, so all braces are
        escaped first and only the placeholders named in ``values`` are
        re-enabled.

        Args:
            template (str): Prompt text containing ``{name}`` placeholders.
            **values: Replacement value for each placeholder.

        Returns:
            str: The prompt with the placeholders substituted.
        """
        escaped = template.replace("{", "{{").replace("}", "}}")
        for name in values:
            escaped = escaped.replace("{{" + name + "}}", "{" + name + "}")
        return PromptTemplate.from_template(escaped).format(**values)

    def _ask_llm_for_json(self, prompt_text: str):
        """Send one human message to the LLM and parse its reply as JSON."""
        ai_message = self.llm.invoke([HumanMessage(content=prompt_text)])
        return json.loads(ai_message.content)

    def initial_query_generation(self, state: State) -> Dict:
        """Ask the LLM for the first elastic/semantic queries for the accusation."""
        prompt_text = self._render_prompt(
            self._initial_query_prompt(),
            accusation=state['accusation'],
        )
        return {"queries": self._ask_llm_for_json(prompt_text)}

    def perform_search(self, state: Dict) -> Dict:
        """Run the hybrid search with the current queries."""
        results = self.search_tool.hybrid_search(
            state['queries']['elastic'], state['queries']['semantic']
        )
        return {"search_results": results}

    def information_extraction(self, state: State) -> Dict:
        """Ask the LLM to extract structured information from the search results."""
        prompt_text = self._render_prompt(
            self._information_extraction_prompt(),
            accusation=state['accusation'],
            # default=str: search records may carry values (e.g. dates) that
            # json cannot encode natively; their string form is enough here.
            results=json.dumps(state['search_results'], default=str),
        )
        return {"extracted_info": self._ask_llm_for_json(prompt_text)}

    def evidence_analysis(self, state: State) -> Dict:
        """Ask the LLM whether the extracted information supports the accusation."""
        prompt_text = self._render_prompt(
            self._analyze_evidence_prompt(),
            accusation=state['accusation'],
            info=json.dumps(state['extracted_info']),
            summary=state.get('summary', 'None'),
        )
        return {"analysis": self._ask_llm_for_json(prompt_text)}

    def refine_query(self, state: State) -> Dict:
        """
        Ask the LLM for refined queries and record that a refinement happened.

        Besides the new queries, the update increments ``search_count`` and
        stores the analysis that triggered this refinement as
        ``previous_analysis`` so the next stop decision can compare against it.
        """
        prompt_text = self._render_prompt(
            self._refine_search_prompt(),
            elastic_query=json.dumps(state['queries']['elastic']),
            semantic_query=state['queries']['semantic'],
            info=json.dumps(state['extracted_info']),
            areas=json.dumps(state['analysis']['areas_for_further_investigation']),
            accusation=state['accusation'],
        )
        return {
            "queries": self._ask_llm_for_json(prompt_text),
            "search_count": state['search_count'] + 1,
            "previous_analysis": state['analysis'],
        }

    def should_continue_search(self, state: State) -> str:
        """
        Decide whether to stop ("end") or refine the queries ("refine").

        Stops when two refinements have already happened, when the analysis
        calls the evidence "sufficient", or when a refinement produced no
        significant change in the analysis.
        """
        if state['search_count'] >= 2:
            return "end"
        if state['analysis']['sufficiency']['conclusion'] == "sufficient":
            return "end"
        # .get: the key is absent when the state schema has no such channel.
        previous_analysis = state.get('previous_analysis')
        if state['search_count'] > 0 and not self._significant_difference(previous_analysis, state['analysis']):
            return "end"
        return "refine"

    def _significant_difference(self, prev_analysis: Dict, current_analysis: Dict) -> bool:
        """
        Tell whether the latest analysis moved compared to the previous one.

        Heuristic: the analyses differ significantly when the sufficiency
        conclusion or its confidence score changed. With no previous analysis
        there is nothing to compare, which counts as a difference.

        Args:
            prev_analysis (Dict): Analysis from the previous round, or None.
            current_analysis (Dict): Analysis from the current round.

        Returns:
            bool: True if the analyses differ significantly.
        """
        if not prev_analysis:
            return True
        previous = prev_analysis.get("sufficiency", {})
        current = current_analysis.get("sufficiency", {})
        return (previous.get("conclusion"), previous.get("confidence_score")) != (
            current.get("conclusion"),
            current.get("confidence_score"),
        )

    def run_investigation(self, accusation: str) -> Dict:
        """
        Run the workflow for an accusation, printing each streamed step.

        Args:
            accusation (str): The accusation to investigate.

        Returns:
            Dict: The last chunk yielded by the workflow stream (empty if the
            stream yielded nothing).
        """
        inputs = {
            "accusation": accusation,
            "search_count": 0,
            "previous_analysis": None
        }

        output: Dict = {}
        for output in self.workflow.stream(inputs):
            # Each streamed chunk maps a node name to what that node produced.
            for node_name, update in output.items():
                print(f"Step: {node_name}")
                print(f"Output: {json.dumps(update, indent=2, default=str)}")
                print("---")

        return output

    @staticmethod
    def _initial_query_prompt() -> str:
        """Prompt template for generating the first queries; placeholder: accusation."""
        return """Task: Generate initial search queries for the following accusation, suitable for use with the SemanticHybridSearch tool. Accusation: {accusation} Response Format: Provide the response in JSON format with the following keys: elastic: Contains the Elasticsearch query in JSON format. semantic: Contains the semantic search query as a string. Guidelines: Unionized Search Approach: - Combine Elasticsearch and semantic search capabilities effectively. For example: Use Elasticsearch to filter specific fields (e.g., recipients, senders). Use semantic search to refine or specify the context within filtered results. - If only one type of search is required, leave the other key empty (e.g., {} for elastic or "" for semantic). Data Schema: { "Subject": "Subject of mail", "To": "All Recipients", "From": "Name of sender", "Cc": "All CC", "Bcc": "All BCC", "Date": "Date in datetime format", "Attachment_Count": "Number of attachments", "Mail_Body": "Content of the mail in plain text format" } Elasticsearch Query: - Focus on key terms and concepts relevant to the accusation. - Use appropriate Elasticsearch query DSL structures (e.g., bool, must, should, match, term). - Consider field-specific searches (e.g., subject, body, from, to) and apply boosts where necessary. - Ensure queries are broad enough to capture relevant information but specific enough to exclude irrelevant results. Semantic Search Query: - Use natural language to describe the context and meaning of the accusation. - Incorporate synonyms, related terms, and broader concepts to capture nuances beyond simple keywords. Efficiency and Contextual Relevance: - Adapt search strategies based on the unique aspects of each accusation. - Ensure objectivity and avoid bias in query generation. - Clearly distinguish between facts, inferences, and speculations. 
Output Example: { "elastic": { // Elasticsearch query here }, "semantic": "Semantic search string here" } Do not provide a preamble or an explanation, the output should strictly be in JSON format with no comments"""

    @staticmethod
    def _refine_search_prompt() -> str:
        """Prompt template for refining queries; placeholders: elastic_query, semantic_query, info, areas, accusation."""
        return """Task: Refine the search queries based on the current queries and extracted information to uncover more details about the accusation. Provide refined queries for both Elasticsearch and semantic search. Current Elasticsearch Query: {elastic_query} Current Semantic Query: {semantic_query} Extracted Info Summary: {info} Areas for Further Investigation: {areas} Accusation: {accusation} Guidelines: Unionized Search Approach: - Combine Elasticsearch and semantic search capabilities effectively. For example: Use Elasticsearch to filter specific fields (e.g., recipients, senders). Use semantic search to refine or specify the context within filtered results. - If only one type of search is required, leave the other key empty (e.g., {} for elastic or "" for semantic). Data Schema: { "Subject": "Subject of mail", "To": "All Recipients", "From": "Name of sender", "Cc": "All CC", "Bcc": "All BCC", "Date": "Date in datetime format", "Attachment_Count": "Number of attachments", "Mail_Body": "Content of the mail in plain text format" } Your refined queries should: - Build upon the insights gained from the extracted information. - Focus on areas where evidence is lacking or inconclusive. - Include any new relevant terms or concepts discovered in the previous search. - Be more specific than the initial queries, targeting the most promising areas for further investigation. - Utilize Elasticsearch-specific features for the lexical query and natural language for the semantic query. Refined Search Queries: { "elastic": { // Elasticsearch query here }, "semantic": "Semantic search string here" } Do not provide a preamble or an explanation, the output should strictly be in JSON format with no comments"""

    @staticmethod
    def _information_extraction_prompt() -> str:
        """Prompt template for information extraction; placeholders: accusation, results."""
        return """Task: Extract relevant information from the hybrid search results related to the following accusation: Accusation: {accusation} Hybrid Search Results: {results} Analyze the results, which combine Elasticsearch and Faiss search outcomes. Each result contains fields like "Subject", "To", "From", "Cc", "Bcc", "Date", "Attachment_Count", and "Mail_Body". Provide the following information in JSON format: { "accused_suspects": [], "incident_details": { "events": [ { "details": "", "description": "", "date": "", "uid":"", } ] }, "other_parties": { "name": { "relationship": "", "role": "", "uid":"uid", } }, "summary": "" } Ensure all relevant information is included within this structure. Omit any explanations or additional text outside the JSON."""

    @staticmethod
    def _analyze_evidence_prompt() -> str:
        """Prompt template for evidence analysis; placeholders: accusation, info, summary."""
        return """Task: Analyze the extracted information and determine if it provides sufficient evidence for the accusation. If not, suggest areas for further investigation. Accusation: {accusation} Extracted Information: {info} Summary of Previous Information: {summary} Provide your analysis in the following JSON format: { "credibility_and_reliability": { "events_analysis": [ { "event": "Description of the event", "credibility_score": "Score from 0-100", "reasoning": "Explanation for the credibility score", "uid": "The uid of the source where event is mentioned" } ], "relationships_analysis": [ { "entity1": "Name of first entity", "entity2": "Name of second entity", "relationship": "Description of relationship", "credibility_impact": "How this relationship affects credibility", "uid": "The uid of the source where entities are mentioned" } ], "overall_credibility_assessment": "Summary of overall credibility" }, "sufficiency": { "conclusion": "One of: sufficient, partial, insufficient", "confidence_score": "Score from 0-100", "conclusion_statement": "Detailed explanation of the sufficiency conclusion", "refrences":... ["List of the uids referenced"] }, "areas_for_further_investigation": [ "List of specific areas or questions needing further investigation" ] } Ensure all relevant analysis is included within this structure. Omit any explanations or additional text outside the JSON."""

In [24]:
from IPython.display import display
# Build the agent; its constructor creates the LLM, the search tool and the compiled workflow graph.
x = InvestigationAgent()

In [26]:
# Print an ASCII rendering of the compiled workflow graph to inspect its nodes and edges.
print(x.workflow.get_graph().draw_ascii())

              +-----------+                
              | __start__ |                
              +-----------+                
                    *                      
                    *                      
                    *                      
            +---------------+              
            | initial_query |              
            +---------------+              
                    *                      
                    *                      
                    *                      
               +--------+                  
               | search |                  
               +--------+                  
             ***         ***               
            *               *              
          **                 ***           
+--------------+                *          
| extract_info |                *          
+--------------+                *          
        *                       *          
        *                       