In [None]:
from langchain_community.document_loaders import PyPDFLoader
from langchain_community.vectorstores import Chroma
from langchain_openai import OpenAIEmbeddings
from langchain_openai import ChatOpenAI

In [None]:
# Load the paper and split it into page-level text chunks for embedding.
file_path = './pdfs/human-agent-collab-problem-solving.pdf'

docs = PyPDFLoader(file_path).load_and_split()

# Fail fast: if no text could be extracted (e.g. an image-only PDF), the
# problem would otherwise only show up later, when building or querying the
# vector store.
assert docs, f"No text chunks were extracted from {file_path}"

len(docs)

In [None]:
# Peek at the first few chunks to sanity-check the extracted text.
docs[:5]

In [None]:
# Embed the chunks into a Chroma collection and expose it as a retriever.
# Stable, position-based ids make re-running this cell idempotent: Chroma
# upserts by id, so chunks overwrite themselves instead of possibly being
# appended again to the same "rag-chroma" collection (which would make the
# retriever return duplicate passages if the in-process client keeps the
# collection between runs).
# NOTE(review): if the PDF is re-split into fewer chunks, stale higher-numbered
# ids would remain until the kernel is restarted.
vectordb = Chroma.from_documents(
    documents=docs,
    collection_name="rag-chroma",
    embedding=OpenAIEmbeddings(),
    ids=[f"chunk-{i}" for i in range(len(docs))],
)

retriever = vectordb.as_retriever()

In [None]:
from langchain.tools.retriever import create_retriever_tool

# Wrap the retriever as a named tool; the name and description are what the
# chat model sees when deciding whether to call it.
retriever_tool = create_retriever_tool(
    retriever,
    name='retrieve_info_from_paper',
    description='Search and return information about a paper.',
)

# Tools made available to the agent.
tools = [retriever_tool]

In [None]:
from typing import Annotated, Sequence, TypedDict
from langchain_core.messages import BaseMessage
from langgraph.graph.message import add_messages

In [None]:
class AgentState(TypedDict):
    """Graph state passed between nodes: the running conversation."""

    # `add_messages` is this key's reducer: messages returned by a node are
    # merged into the existing list rather than replacing it.
    messages: Annotated[Sequence[BaseMessage], add_messages]



In [None]:
def agent(state: AgentState):
    """Run one LLM turn over the conversation so far.

    Builds a GPT-4o chat model (temperature 0, streaming enabled), binds the
    notebook-level ``tools`` so the model may emit tool calls, and invokes it
    on the accumulated messages.

    Returns a partial state update holding the model's reply, which the
    ``add_messages`` reducer on ``AgentState.messages`` merges into history.
    """
    llm_with_tools = ChatOpenAI(
        model='gpt-4o',
        temperature=0,
        streaming=True,
    ).bind_tools(tools)

    reply = llm_with_tools.invoke(state['messages'])
    return {'messages': [reply]}


In [None]:
from langgraph.graph import END, StateGraph, START
from langgraph.prebuilt import ToolNode, tools_condition

# Graph initialization: every node reads and updates an AgentState.
graph = StateGraph(AgentState)

# Nodes: the LLM turn, and a node that executes the tool calls it emits.
graph.add_node('agent', agent)
# Build the ToolNode from the same `tools` list that `agent` binds to the
# model, so any tool the model can request is one this node can execute.
retrieve_node = ToolNode(tools)
graph.add_node('retrieve', retrieve_node)

# Edges: START -> agent; agent -> retrieve when tools_condition routes to
# 'tools', otherwise END; retrieve -> agent so it can use the tool results.
graph.add_edge(START, 'agent')
graph.add_conditional_edges('agent', tools_condition, {'tools': 'retrieve', END: END})
graph.add_edge('retrieve', 'agent')

compiled_graph = graph.compile()

In [None]:
from IPython.display import Image, display

try:
    display(Image(compiled_graph.get_graph(xray=True).draw_mermaid_png()))
except Exception as exc:
    # Rendering the diagram requires extra dependencies and is optional, so
    # don't fail the notebook; do say why the diagram is missing instead of
    # silently showing nothing.
    print(f"Skipping graph diagram ({type(exc).__name__}: {exc})")

In [None]:
import pprint

inputs = {
    "messages": [
        ("user", "In this paper how do the authors set up the collaboration between the human and the LLMs?"),
    ]
}
# Each streamed item maps the name of the node that just ran to its output.
for output in compiled_graph.stream(inputs):
    for key, value in output.items():
        # Plain print for headers/separators: pprint would show these strings'
        # repr (surrounding quotes, and the separator's newlines as a literal
        # backslash-n) instead of the text itself.
        print(f"Output from node '{key}':")
        print("---")
        pprint.pprint(value, indent=2, width=80, depth=None)
    print("\n---\n")

In [None]:
# Ask a second question; invoke() runs the graph to completion and returns
# the final state.
inputs = {
    "messages": [
        (
            "user",
            "What is the name of the framework in this paper that sets up the collaboration between the human and the LLMs?",
        ),
    ]
}
output = compiled_graph.invoke(inputs)

# Display the text of the last message in the final conversation.
output['messages'][-1].content