In [None]:
import os

# Enable LangSmith tracing for the LangChain / LangGraph calls made later in this notebook.
# Traces are only uploaded when LANGCHAIN_API_KEY is available; the commented line below
# prompts for it (it also needs `import getpass`).
os.environ.update({"LANGCHAIN_TRACING_V2": "true"})
# os.environ["LANGCHAIN_API_KEY"] = getpass.getpass()

## Tools

In [None]:
# Import things that are needed generically
from langchain.pydantic_v1 import BaseModel, Field
from langchain.tools import BaseTool, StructuredTool, tool

In [None]:
class SearchNextStep(BaseModel):
    # Argument schema for one tree-navigation step (used as SummaryTree.args_schema).
    # The field descriptions may be surfaced to the LLM through the tool's argument
    # schema, so their wording matters.
    summary_level: str = Field(
        default='summary_0',
        description='the summary level under which the sub-tree will be explored.',
    )
    query: str = Field(
        description='a query for the information you expect to find in the sub-tree',
    )

In [None]:
from typing import Any, Optional, Type


    
class SummaryTree(BaseTool):
    """LangChain tool that retrieves a branch of a document summary tree.

    A branch is a relevant leaf chunk together with all of its ancestor summaries.
    The retrieval itself is not implemented yet: ``_run`` raises NotImplementedError.
    """

    # NOTE(review): OpenAI-style function calling only accepts tool names made of
    # a-z, A-Z, 0-9, '_' and '-'. This name contains a space and may be rejected if the
    # tool is bound that way; consider 'branch_retrieval'.
    name: str = 'branch retrieval'
    # Whitespace is collapsed so the indented triple-quoted text becomes one line.
    description: str = ' '.join('''
        This tool organizes the document in a summary tree.
        The leaf nodes are the chunks from the document and the non-leaf nodes are the summaries of their children.
        Higher-level nodes contain more general but less reliable information.
        Given a query and a summary level, the tool returns the relevant chunk and all its ancestors as a branch in the summary tree.
        This multi-granularity context is useful for connecting the relevant node with the remaining parts of the document.
    '''.split())
    args_schema: Type[BaseModel] = SearchNextStep
    return_direct: bool = False

    def _run(self, query: str, summary_level: str = 'summary_0', run_manager: Optional[Any] = None) -> str:
        """Return the branch relevant to ``query`` under ``summary_level``.

        Args:
            query: information the caller expects to find in the sub-tree.
            summary_level: summary level under which the sub-tree is explored.
            run_manager: LangChain callback manager (unused).

        Raises:
            NotImplementedError: always; the summary-tree lookup has not been written yet.
        """
        # The original cell had an unfinished `def __init__()` that made the whole cell a
        # SyntaxError. Fail explicitly at call time instead, so the class can at least be defined.
        raise NotImplementedError('SummaryTree branch retrieval is not implemented yet')

In [None]:
from enum import Enum

In [None]:
from llama_index.core import TreeIndex

In [None]:
from llama_index.core.node_parser import SentenceSplitter, SemanticSplitterNodeParser

In [None]:
# NOTE(review): exploratory cell. It constructs a SemanticSplitterNodeParser with no
# arguments and discards the result. The parser presumably requires an `embed_model`
# argument, so this call may fail validation. TODO: confirm, then pass one or remove the cell.
SemanticSplitterNodeParser()

# NavigateAgent

In [None]:
from typing import Dict, List, Literal
import os
import sys

# Disable HuggingFace tokenizers' parallelism before anything below can load a tokenizer
# (the project modules imported further down may do so at import time).
os.environ['TOKENIZERS_PARALLELISM'] = 'false'

from tqdm.notebook import tqdm
from transformers import AutoTokenizer
import seaborn as sb

# Make the project package `src` (two directories up) importable.
sys.path.append('../..')

from langchain_core.documents import Document
from langchain_core.language_models.chat_models import BaseChatModel
from langgraph.graph import END, StateGraph

# Placeholder API key for the OpenAI client.
# NOTE(review): presumably the LLM is served by an OpenAI-compatible endpoint that
# ignores the key. Confirm, because this overwrites any real key already in the environment.
os.environ["OPENAI_API_KEY"] = "EMPTY"

# Wildcard imports: names used later (Factory, MyStructure, and the GenerateQuery /
# GradeDocument / AnalyzeDocument / EvalCompleteInfo / GenerateAnswer chains) are
# expected to come from these modules.
from src import *
from src.test_utils import *
from src.summary_tree import *

# Sample document and question used to exercise the agent.
# Explicit UTF-8 so the read does not depend on the platform's default locale encoding,
# and a context manager so the file handle is closed.
with open('article.txt', encoding='utf-8') as article_file:
    article = article_file.read()
question = "Why didn't the skipper follow the new cook's advice about avoiding Vesta?"



In [None]:
class NavigateState(BaseModel):
    """
    Represents the state of our graph.

    Attributes:
        question: the question to answer.
        queries: search queries produced by the most recent generate_query step.
        retriever: key of NavigateAgent.retrievers to search ('dpr' or 'tree').
        retrieve_root: not read by NavigateAgent in this notebook.
        documents: documents from the current retrieval turn; grade_doc narrows them
            to the ones graded relevant.
        aggretate_info: analysis results accumulated across turns; joined with newlines
            to form the context passed to the chains.
        retrieve_cnts: number of retrieval turns run so far (incremented by retrieve_doc).
        answer: the final answer written by generate_answer.
    """

    question: str
    # Mutable list defaults are safe on pydantic models: each instance gets its own copy.
    queries: List[str] = []
    retriever: Literal['dpr', 'tree'] = 'dpr'
    retrieve_root: int = -1
    documents: List[Document] = []
    aggretate_info: List[str] = []
    retrieve_cnts: int = 0
    answer: str = ''

class NavigateAgent:
    """Question-answering agent that loops over query generation, retrieval, grading and analysis.

    ``create_workflow`` compiles a LangGraph ``StateGraph`` over ``NavigateState``:

        generate_query -> retrieve_doc -> grade_doc
        grade_doc --(relevant docs)-----------------> analyze_doc
        grade_doc --(none relevant, budget left)----> generate_query
        grade_doc --(none relevant, budget spent)---> generate_answer
        analyze_doc --(complete or budget spent)----> generate_answer
        analyze_doc --(incomplete)------------------> generate_query
        generate_answer -> END

    The budget is ``max_retrieve_turn`` calls to ``retrieve_doc``.
    """

    class Nodes:
        # Graph node names. Each value is also the name of the NavigateAgent method that
        # implements the node (create_workflow looks the methods up with getattr).
        GENERATE_QUERY = 'generate_query'
        RETRIEVE_DOC = 'retrieve_doc'
        GRADE_DOC = 'grade_doc'
        ANALYZE_DOC = 'analyze_doc'
        GENERATE_ANSWER = 'generate_answer'

    def __init__(self, llm:BaseChatModel, retrievers:Dict[str, MyStructure], max_retrieve_turn:int=5) -> None:
        """
        Args:
            llm: chat model passed to every sub-chain.
            retrievers: retrievers keyed by ``NavigateState.retriever``.
            max_retrieve_turn: number of retrieval turns after which an answer is generated
                from whatever information has been gathered so far.
        """
        self.llm = llm
        self.retrievers = retrievers
        self.generate_query_chain = GenerateQuery(self.llm)
        self.grade_doc_chain = GradeDocument(self.llm)
        self.analyze_doc_chain = AnalyzeDocument(self.llm)
        self.eval_complete_info_chain = EvalCompleteInfo(self.llm)
        self.generate_answer_chain = GenerateAnswer(self.llm)
        self.max_retrieve_turn = max_retrieve_turn

    @staticmethod
    def _joined_context(state:NavigateState) -> str:
        """Accumulated analysis results joined into a single context string."""
        return '\n'.join(state.aggretate_info)

    def _per_document_kwargs(self, state:NavigateState) -> dict:
        """Batched keyword arguments (one entry per document) for the grade/analyze chains.

        ``contexts`` is included only once some information has been aggregated.
        """
        documents = [doc.page_content for doc in state.documents]
        kwargs = {'documents': documents, 'questions': [state.question] * len(documents)}
        if state.aggretate_info:
            kwargs['contexts'] = [self._joined_context(state)] * len(documents)
        return kwargs

    def generate_query(self, state:NavigateState):
        """Generate search queries for the question, using gathered information as context if any."""
        print('generate_query')
        if state.aggretate_info:
            result = self.generate_query_chain(questions=[state.question], contexts=[self._joined_context(state)])[0]
        else:
            result = self.generate_query_chain(questions=[state.question])[0]
        state.queries = result.queries
        return state

    def retrieve_doc(self, state:NavigateState):
        """Search every query with the selected retriever, de-duplicating hits by ``metadata['i']``.

        Raises:
            KeyError: if ``state.retriever`` is not a key of ``self.retrievers``.
            NotImplementedError: for any retriever other than ``'dpr'``.
        """
        print('retrieve_doc')
        retriever = self.retrievers[state.retriever]
        if state.retriever != 'dpr':
            # A retriever without an implementation here would otherwise yield no documents
            # and only burn retrieval turns, so fail loudly instead.
            raise NotImplementedError(f"retrieval with {state.retriever!r} is not implemented")
        seen_ids = set()
        documents = []
        for query in state.queries:
            for doc in retriever.vectorstore.similarity_search(query):
                if doc.metadata['i'] not in seen_ids:
                    seen_ids.add(doc.metadata['i'])
                    documents.append(doc)
        state.documents = documents
        state.retrieve_cnts += 1
        return state

    def grade_doc(self, state:NavigateState):
        """Keep only the documents whose grade contains 'yes' (case-insensitive)."""
        print('grade_doc')
        if not state.documents:
            # Nothing to grade: skip the LLM call. The router handles the empty case.
            return state
        batch_results = self.grade_doc_chain(**self._per_document_kwargs(state))
        state.documents = [doc for doc, result in zip(state.documents, batch_results) if 'yes' in result.binary_score.lower()]
        return state

    def check_non_empty_retrieval(self, state:NavigateState):
        """Route after grading: analyze, retry with new queries, or give up and answer."""
        print('check_non_empty_retrieval')
        if state.documents:
            return 'Not empty'
        if state.retrieve_cnts >= self.max_retrieve_turn:
            # Without this check the empty branch loops back to generate_query forever
            # (until LangGraph's recursion limit), because max_retrieve_turn was only
            # enforced on the analyze_doc branch.
            return 'Exhausted'
        return 'Empty'

    def analyze_doc(self, state:NavigateState):
        """Analyze each kept document and append the results to ``state.aggretate_info``."""
        print('analyze_doc')
        batch_results = self.analyze_doc_chain(**self._per_document_kwargs(state))
        state.aggretate_info.extend(batch_results)
        return state

    def eval_complete_info(self, state:NavigateState):
        """Route after analysis: answer if the info is judged complete or the budget is spent."""
        print('eval_complete_info')
        result = self.eval_complete_info_chain(contexts=[self._joined_context(state)], questions=[state.question])[0]
        if 'yes' in result.binary_score.lower() or state.retrieve_cnts >= self.max_retrieve_turn:
            return 'generate_answer'
        return 'update_query'

    def generate_answer(self, state:NavigateState):
        """Generate the final answer from the aggregated information."""
        print('generate_answer')
        state.answer = self.generate_answer_chain(contexts=[self._joined_context(state)], questions=[state.question])[0]
        return state

    def create_workflow(self):
        """Assemble and compile the LangGraph workflow; returns the compiled app."""
        workflow = StateGraph(NavigateState)
        # Register every public attribute of Nodes as a node backed by the same-named method.
        for attr_name, attr_value in vars(self.Nodes).items():
            if not attr_name.startswith('_'):
                workflow.add_node(attr_value, getattr(self, attr_value))

        workflow.set_entry_point(self.Nodes.GENERATE_QUERY)
        workflow.add_edge(self.Nodes.GENERATE_QUERY, self.Nodes.RETRIEVE_DOC)
        workflow.add_edge(self.Nodes.RETRIEVE_DOC, self.Nodes.GRADE_DOC)
        workflow.add_conditional_edges(
            self.Nodes.GRADE_DOC,
            self.check_non_empty_retrieval,
            {
                'Not empty': self.Nodes.ANALYZE_DOC,
                'Empty': self.Nodes.GENERATE_QUERY,
                'Exhausted': self.Nodes.GENERATE_ANSWER,
            },
        )
        workflow.add_conditional_edges(
            self.Nodes.ANALYZE_DOC,
            self.eval_complete_info,
            {
                "update_query": self.Nodes.GENERATE_QUERY,
                "generate_answer": self.Nodes.GENERATE_ANSWER,
            },
        )
        workflow.add_edge(self.Nodes.GENERATE_ANSWER, END)
        return workflow.compile()
    


In [None]:
# The Factory supplies the LLM (used below as `f.llm`) and indexes `article` into the
# retrievers handed to NavigateAgent.
# NOTE(review): presumably build_corpus returns a dict keyed like NavigateState.retriever
# ('dpr' / 'tree'). Confirm against src.
f = Factory()
retrievers = f.build_corpus(article)

In [None]:
# Build the agent graph and run it on the sample question. The final state is kept under
# a name so later cells do not have to rely on `_` / `Out[N]`.
agent = NavigateAgent(f.llm, retrievers)
app = agent.create_workflow()
final_state = app.invoke(NavigateState(question=question))
final_state