In [None]:
import os

# Set the LANGCHAIN_TRACING_V2 flag for this kernel (presumably read by LangChain/LangSmith
# to enable run tracing — see the LangSmith docs).
os.environ["LANGCHAIN_TRACING_V2"] = "true"
# Optional interactive prompt for the API key. `getpass` is not imported in any visible
# cell, so add `import getpass` before un-commenting this line.
# os.environ["LANGCHAIN_API_KEY"] = getpass.getpass()

## Tools

In [None]:
# Import things that are needed generically
from langchain.pydantic_v1 import BaseModel, Field
from langchain.tools import BaseTool, StructuredTool, tool

In [None]:
# Two-field argument model: the summary level to search under (defaults to
# 'summary_0') and a required free-text query.
# (Plain comments on purpose: a class docstring would be picked up as the
# schema description of the pydantic model.)
class SearchNextStep(BaseModel):
    summary_level: str = Field(
        'summary_0',
        description='the summary level under which the sub-tree will be explored.',
    )
    query: str = Field(
        ...,
        description='a query for the information you expect to find in the sub-tree',
    )

In [None]:
from typing import Any, Optional, Type


    
class SummaryTree(BaseTool):
    """Tool skeleton for branch retrieval over a document's summary tree.

    The tool accepts the arguments declared by ``SearchNextStep`` (a summary
    level and a query). The retrieval logic is not written yet: ``_run`` is a
    placeholder that raises ``NotImplementedError`` so the class is at least
    syntactically valid and its intent is explicit.
    """

    # Annotated like the two fields below so every override is declared the same way.
    name: str = 'branch retrieval'
    # The literal is collapsed to single spaces by the split()/join() pair.
    # NOTE(review): the text still contains unfinished fragments ("In the initial call,"
    # and "if  and"); it is left untouched because it is the runtime prompt shown to the
    # model — finish the wording when the tool is implemented.
    description: str = ' '.join('''
        This tool organizes the document in a summary tree. 
        The leaf nodes are the chunks from the document and the non-leaf nodes are the summaries of their children. 
        Higher-level nodes contain more general but less reliable information. 
        In the initial call, 
        Given a query, if  and a summary level, the tool will return the relevant chunk and all its ancestors as a branch in the summary tree. provide the multi-granularity context. 
        This context is useful in connecting the current relevant node with the remaining parts in the document.
    '''.split())
    args_schema: Type[BaseModel] = SearchNextStep
    return_direct: bool = False

    def _run(self, query: str, summary_level: str = 'summary_0', run_manager: Optional[Any] = None) -> str:
        """Placeholder for the tool body.

        Args:
            query: what to look for in the selected sub-tree.
            summary_level: level under which to search; same default as ``SearchNextStep``.
            run_manager: optional callback manager supplied by the tool runner.

        Raises:
            NotImplementedError: always, until the retrieval logic is written.
        """
        raise NotImplementedError('SummaryTree._run is not implemented yet')

In [None]:
from enum import Enum

In [None]:
from llama_index.core import TreeIndex

In [None]:
from llama_index.core.node_parser import SentenceSplitter, SemanticSplitterNodeParser

In [None]:
# Scratch call: constructs a SemanticSplitterNodeParser with no arguments; the instance is
# only shown as the cell output and is not stored anywhere.
# NOTE(review): confirm the parser can be built without an explicit embedding model.
SemanticSplitterNodeParser()

# NavigateAgent

In [None]:
import sys
# Put the directory two levels above the notebook on the import path so `src` resolves.
sys.path.append('../..')

# NOTE(review): later cells use defaultdict, Dict, List, np, NavigateAgent, Factory and
# QualityDataset without importing them in any visible cell — presumably they all arrive
# through this star import; confirm, and prefer explicit imports.
from src.summary_tree import *
from tqdm.notebook import tqdm

from langsmith import Client
from langsmith.schemas import Run
from uuid import UUID
import pickle

# LangSmith client built with default arguments.
client = Client()

In [None]:
# Walk every LangSmith project, record which root traces belong to 'tree' / 'dpr'
# projects, and group each trace's LangGraph node runs by their step number.
# Both results are then cached on disk as pickles.

# 'tree' / 'dpr' -> trace ids of the projects whose name contains that tag
# ('tree' wins when a name contains both).
project_map = defaultdict(list)
# trace id -> {langgraph_step -> [{'metadata', 'inputs', 'outputs'} for each run at that step]}
trace2runs: Dict[UUID, Dict[int, List[dict]]] = {}

# Materialise the listing so tqdm reports the real number of projects
# instead of a hard-coded total.
projects = list(client.list_projects())
for project in tqdm(projects):
    traces = list(client.list_runs(project_name=project.name, is_root=True))
    trace_ids = [t.trace_id for t in traces]
    if 'tree' in project.name:
        project_map['tree'].extend(trace_ids)
    elif 'dpr' in project.name:
        project_map['dpr'].extend(trace_ids)
    for trace in traces:
        step2runs = defaultdict(list)
        # Skip the lookup when the trace lists no children.
        # NOTE(review): presumably an empty/missing id filter would not scope the query
        # to this trace — confirm against the LangSmith client.
        child_run_ids = trace.child_run_ids
        child_runs = list(client.list_runs(run_ids=child_run_ids)) if child_run_ids else []
        # Iterate in reverse of the order returned by the client (as the original
        # `[::-1]` did) so runs are appended to each step in that reversed order.
        for run in reversed(child_runs):
            # `extra` or its 'metadata' entry may be absent; such runs are not node runs.
            metadata = (run.extra or {}).get('metadata') or {}
            if 'langgraph_node' not in metadata:
                continue
            step2runs[metadata['langgraph_step']].append({'metadata': metadata, 'inputs': run.inputs, 'outputs': run.outputs})
        trace2runs[trace.trace_id] = step2runs

with open('result.pickle', 'wb') as f_out:
    pickle.dump(trace2runs, f_out)

with open('project_map.pickle', 'wb') as f_out:
    pickle.dump(project_map, f_out)

In [None]:
# Read back the two pickle caches ('result.pickle' and 'project_map.pickle').
# NOTE: pickle.load can execute arbitrary code from the file — only load caches
# that were produced by this notebook.
with open('result.pickle', 'rb') as results_file:
    trace2runs = pickle.load(results_file)

with open('project_map.pickle', 'rb') as project_map_file:
    project_map = pickle.load(project_map_file)

In [None]:
def get_steps(step2runs:Dict[int, Any], node:str):
    """Return the step keys whose first recorded run belongs to ``node``.

    Args:
        step2runs: mapping from step key to the list of run records at that
            step; each record is indexed as ``record['metadata']['langgraph_node']``.
        node: value compared (with ``==``) against the first record's node.

    Returns:
        A list of matching step keys, in the mapping's iteration order.
        Steps with an empty run list are skipped; only the first record of
        each step is inspected.
    """
    matching_steps = []
    for step, runs in step2runs.items():
        if not runs:
            continue
        first_run_node = runs[0]['metadata']['langgraph_node']
        if first_run_node == node:
            matching_steps.append(step)
    return matching_steps

In [None]:
# Keep only the traces that have at least one step whose first run is the REFORM_QUERY node
# and at least one step whose first run is the GENERATE_ANSWER node.
trace_keys = [trace_key for trace_key, step2runs in trace2runs.items() if get_steps(step2runs, NavigateAgent.Nodes.REFORM_QUERY) and get_steps(step2runs, NavigateAgent.Nodes.GENERATE_ANSWER)]

In [None]:
# Number of traces in trace_keys.
len(trace_keys)

In [None]:
def _count_new_documents(step2runs, step):
    """Return how many ``new_document_ids`` the first run at ``step`` reported.

    A step with no recorded runs counts as zero. ``.get`` is used instead of
    indexing so that probing a missing step does not insert an empty entry
    when ``step2runs`` is a defaultdict (and does not fail with IndexError).
    A run that is present but lacks the expected keys still raises, so a
    schema mismatch stays visible.
    """
    runs = step2runs.get(step)
    if not runs:
        return 0
    return len(runs[0]['outputs']['output']['new_document_ids'])


# For every selected trace, compute accepted / proposed new documents over its
# REFORM_QUERY rounds and file the ratio under 'dpr' and/or 'tree'.

# Membership is tested once per trace; sets make that a hash lookup instead of a list scan.
dpr_trace_ids = set(project_map['dpr'])
tree_trace_ids = set(project_map['tree'])

scores = defaultdict(list)    # tag -> per-trace accepted/proposed ratios
proposes = defaultdict(list)  # tag -> non-zero per-round proposal counts
for trace_key in trace_keys:
    step2runs = trace2runs[trace_key]
    # Not used further in this cell; kept because it is a notebook global that
    # other cells may inspect.
    answer_steps = get_steps(step2runs, NavigateAgent.Nodes.GENERATE_ANSWER)
    reform_steps = get_steps(step2runs, NavigateAgent.Nodes.REFORM_QUERY)
    propose_num = 0
    accept_num = 0
    temp_proposes = []
    for s in reform_steps:
        # NOTE(review): assumes the step right after a REFORM_QUERY step reports the
        # proposed documents and the one after that the accepted ones — the +1/+2
        # offsets are hard-coded; confirm against the graph definition.
        proposed = _count_new_documents(step2runs, s + 1)
        accepted = _count_new_documents(step2runs, s + 2)
        propose_num += proposed
        if proposed:
            temp_proposes.append(proposed)
        accept_num += accepted
        # Stop at the first round in which nothing was accepted; that round's
        # proposals have already been counted above.
        if accepted == 0:
            break

    # Traces without any proposal are left out to avoid dividing by zero.
    if propose_num > 0:
        ratio = accept_num / propose_num
        if trace_key in dpr_trace_ids:
            scores['dpr'].append(ratio)
            proposes['dpr'].extend(temp_proposes)
        if trace_key in tree_trace_ids:
            scores['tree'].append(ratio)
            proposes['tree'].extend(temp_proposes)

In [None]:
# Mean of the per-trace accepted/proposed ratios collected under scores['dpr'].
np.mean(scores['dpr'])

In [None]:
# Mean of the non-zero per-round proposal counts collected under proposes['dpr'].
np.mean(proposes['dpr'])

In [None]:
# Number of 'dpr' traces that contributed a ratio (those with at least one proposed document).
len(scores['dpr'])

In [None]:
# Mean of the per-trace accepted/proposed ratios collected under scores['tree'].
np.mean(scores['tree'])

In [None]:
# Mean of the non-zero per-round proposal counts collected under proposes['tree'].
np.mean(proposes['tree'])

In [None]:
# Number of 'tree' traces that contributed a ratio (those with at least one proposed document).
len(scores['tree'])

In [None]:
# NOTE: leftover loop variable — it is reset for every trace in the scoring loop, so this
# shows the proposal total of the last trace iterated only.
propose_num

In [None]:
# NOTE: leftover loop variable — it is reset for every trace in the scoring loop, so this
# shows the accepted total of the last trace iterated only.
accept_num

In [None]:
# Collect the 'score' of every run, at the steps listed in grade_steps, whose outputs
# contain an 'output' entry with a 'score' key.
# NOTE(review): `grade_steps` is not assigned in any visible cell, and `step2runs` is
# whatever trace an earlier cell left bound to that name — define both explicitly before
# this cell so it survives Restart & Run All.
grades = [run['outputs']['output']['score'] for grade_step in grade_steps for run in step2runs[grade_step] if 'output' in run['outputs'] and 'score' in run['outputs']['output']]

In [None]:
# Display the collected scores.
grades

In [None]:
# NOTE(review): grade_steps is not assigned in any visible cell; this only works if it
# already exists in the kernel.
grade_steps

In [None]:
# Outputs of every run recorded at the second entry of grade_steps (index 1).
# NOTE(review): grade_steps is not assigned in any visible cell.
[run['outputs'] for run in step2runs[grade_steps[1]]]

In [None]:
# Outputs of every run recorded at the second entry of retrieve_steps (index 1).
# NOTE(review): retrieve_steps is not assigned in any visible cell.
[run['outputs'] for run in step2runs[retrieve_steps[1]]]

In [None]:
# Steps of the currently bound step2runs whose first run is the REFORM_QUERY node.
get_steps(step2runs, NavigateAgent.Nodes.REFORM_QUERY)

In [None]:
# First run record at step 4 of the currently bound step2runs (hard-coded step index;
# raises IndexError when that step has no runs).
step2runs[4][0]

In [None]:
# Steps of the currently bound step2runs whose first run is the GENERATE_ANSWER node.
get_steps(step2runs, NavigateAgent.Nodes.GENERATE_ANSWER)

In [None]:
# Inputs of every run record at step 5 (hard-coded step index).
# The records in step2runs are plain dicts with 'metadata' / 'inputs' / 'outputs' keys,
# so they are read by key, not by attribute; .get avoids inserting an empty entry for a
# missing step when step2runs is a defaultdict.
[run['inputs'] for run in step2runs.get(5, [])]

In [None]:
# Node name of the first run record at step 6 (hard-coded step index).
# Records are plain dicts whose 'metadata' entry holds what used to be run.extra['metadata'].
step2runs[6][0]['metadata']['langgraph_node']

In [None]:
# Print the node name, input and output of the first run record at one step.
# Records are plain dicts ('metadata' / 'inputs' / 'outputs'), so they are read by key.
step = 4
step_run = step2runs[step][0]
print(step_run['metadata']['langgraph_node'])
print(step_run['inputs']['input'])
print(step_run['outputs']['output'])

In [None]:
# NOTE(review): `c` is not assigned in any visible cell; this only works if it already
# exists in the kernel. Shows the `extra` attribute of its fourth element.
c[3].extra

In [None]:
# NOTE(review): `c` is not assigned in any visible cell; this only works if it already
# exists in the kernel. Shows the `outputs` attribute of its second element.
c[1].outputs

In [None]:
import os
# Placeholder value rather than a real key.
# NOTE(review): presumably the client only needs the variable to be set (e.g. when talking
# to a local OpenAI-compatible endpoint) — confirm.
os.environ["OPENAI_API_KEY"] = "EMPTY"
# Factory is not imported in any visible cell (presumably it comes from the
# src.summary_tree star import). It is used below to build the retrievers.
f = Factory()

In [None]:
# Load the 'dev' split through QualityDataset; the first positional argument is passed as None.
# NOTE(review): the meaning of that first argument is not visible here — confirm.
dataset = QualityDataset(None, split='dev')

In [None]:
# Pick one dataset entry by index (hard-coded) and pull out its article text and its
# question/answer lists through the dataset helpers.
test_id = 19
article = dataset.get_article(dataset.data[test_id])
questions, answers = dataset.get_questions_and_answers(dataset.data[test_id])

In [None]:
# Build the DPR retriever, the tree retriever and the document list for the selected
# article, pointing each retriever at its per-article JSON file in the dataset directory.
dpr_path = os.path.join(dataset.data_dir, f'dpr_{test_id}.json')
tree_path = os.path.join(dataset.data_dir, f'tree_{test_id}.json')
dpr_retriever, tree_retriever, documents = f.build_corpus(
    article,
    dpr_file=dpr_path,
    tree_file=tree_path,
)


In [None]:
# Result of retrieve_children for the entry at index 14 of tree_retriever.docs (hard-coded index).
tree_retriever.retrieve_children(tree_retriever.docs[14])