In [75]:
import os
import json
import operator
import warnings
warnings.filterwarnings("ignore")

from dotenv import load_dotenv
from pprint import pprint

from langchain_teddynote import logging
from models import Agent, get_rag_instance

from utils import save_output2json
from prompt import load_system_prompt, load_invoke_input

from langchain_openai import ChatOpenAI
from langchain_core.messages import BaseMessage, FunctionMessage, HumanMessage
from langgraph.graph import END, StateGraph
from langgraph.prebuilt.tool_executor import ToolExecutor, ToolInvocation
from typing import Annotated, Sequence, TypedDict, List
from typing_extensions import TypedDict
import functools

from retriever.retriever_handler import get_retriever
from utils.model_handler import get_llm
from utils.utils import format_docs
from langchain.prompts import PromptTemplate
from langchain_core.output_parsers import CommaSeparatedListOutputParser
from langchain.tools import Tool
from langchain_core.runnables import RunnablePassthrough
from langchain.tools.render import format_tool_to_openai_function
from langchain.prompts import ChatPromptTemplate, MessagesPlaceholder

from pydantic import BaseModel
from typing import Literal

In [76]:
# .env 파일 로드
load_dotenv(dotenv_path=".env")

# API 키 가져오기
OPENAI_API_KEY = os.getenv("OPENAI_API_KEY")
LANGCHAIN_API_KEY = os.getenv("LANGCHAIN_API_KEY")

# LangSmith 추적 기능을 활성화합니다. (선택적)
os.environ["LANGCHAIN_TRACING_V2"] = "true"

### SampleNameSearcher

In [77]:
retriever = get_retriever(
            file_folder="./data/raw", 
            file_number=11,
            chunk_size=500, 
            chunk_overlap=100, 
            search_k=10
        )

model = ChatOpenAI(model_name="gpt-4o", temperature=0.1)

sample_name_retriever_prompt = """
You are an expert assistant specializing in extracting information from research papers related to battery technology. Your role is to carefully analyze the provided document.

Document:
{context}
"""

output_parser = CommaSeparatedListOutputParser()
format_instructions = output_parser.get_format_instructions()

prompt = ChatPromptTemplate.from_messages([
    ("system", sample_name_retriever_prompt), 
    ("human", "{sample_name_question}")
])

sample_name_searcher_chain = (
    {
        "context": retriever | format_docs, 
        "sample_name_question": RunnablePassthrough()
    }
    | prompt 
    | model 
    | output_parser
)

In [78]:
sample_name_question = """
Use all of the NCM cathode sample names (e.g., 'NCM-622', 'pristine NCM', 'M-NCM') provided in the electrochemical performance section. You just output sample names. Do Not output like '- NCM622' , just output 'NCM622.
"""

In [79]:
sample_names = sample_name_searcher_chain.invoke(sample_name_question)

In [80]:
sample_names

['NR0', 'NR1', 'NR3', 'NR5']

In [81]:
# 각 에이전트와 도구에 대한 다른 노드를 생성할 것입니다. 이 클래스는 그래프의 각 노드 사이에서 전달되는 객체를 정의합니다.
class AgentState(TypedDict):
    messages: Annotated[Sequence[BaseMessage], operator.add]
    sender: str

### Supervisor

In [82]:
from langchain_core.messages import AIMessage

def create_supervisor(model_name, members: list, system_prompt: str=None):            
    options_for_next = ["FINISH"] + members

    # 작업자 선택 응답 모델 정의: 다음 작업자를 선택하거나 작업 완료를 나타냄
    class RouteResponse(BaseModel):
        next: Literal[*options_for_next]

    # ChatPromptTemplate 생성
    prompt = ChatPromptTemplate.from_messages(
        [
            ("system", system_prompt),
            MessagesPlaceholder(variable_name="messages"),
            (
                "system",
                "Given the conversation above, who should act next? "
                "Or should we FINISH? Select one of: {options}",
            ),
        ]
    ).partial(options=str(options_for_next), members=", ".join(members))

    llm = get_llm(model_name, temperature=0.1)
    
    return prompt | llm.with_structured_output(RouteResponse)

In [None]:
## Supervisor 시스템 프롬프트
supervisor_system_prompt = """
You are the Supervisor, responsible for managing the structured extraction and verification process of lithium-ion battery-related scientific papers. Your primary task is to coordinate the Researcher and Verifier agents to ensure that the final extracted information is accurate, structured, and technically precise before submission.

## Workflow:
1. Question Breakdown:

    - You receive a user query that requires extracting information from a scientific paper.
    - You break down the query into multiple sub-questions, categorizing them into four specific groups for systematic extraction.
    - You then assign each set of sub-questions to different Researcher Agents.

2. Researcher Phase:

    - Each Researcher extracts information from the paper in response to their assigned sub-questions.
    - The extracted information must be written as a single, coherent technical paragraph, not in bullet points or lists.
    - The response must use precise scientific language, ensuring accuracy and logical flow.

3. Verification Phase:

    - Once a Researcher submits their response, it is passed to a Verifier Agent for validation.
    - The Verifier must check the Researcher’s response using a retriever tool to ensure that the information exactly matches what is stated in the paper.
    - If any incorrect, missing, or assumed information is detected, the Verifier provides detailed feedback and requests a correction from the Researcher.
    - The Researcher must revise the response based on the feedback and resubmit it for final verification.

4. Final Answer Compilation:

    - If the Verifier confirms that the extracted information is correct, they submit the validated response to the Supervisor with the marker: `### Complete Verification`.
    - The Supervisor collects all verified responses from multiple Verifiers and compiles them into a single structured answer, ensuring coherence and completeness.
    - The final output is then delivered to the user.

## Guidelines:
    - Every response must be a well-structured technical paragraph using precise scientific terminology (e.g., "was synthesized by," "was reported," "was not mentioned").
    - No assumptions or hallucinations are allowed; if a detail is not reported in the paper, it must be explicitly stated.
    - The Verifier must always cross-check the Researcher's response against the source paper before approving.
    - The Supervisor only compiles and submits fully verified information—no modifications should be made after verification.

As the Supervisor, your role is to ensure the smooth execution of this multi-step process and deliver an accurate, structured, and expert-level summary of the scientific paper.
"""

In [84]:
members = [f"Researcher{i}" for i in range(1, 5)] + [f"Verifier{i}" for i in range(1, 5)]

# Supervisor Agent 생성
def supervisor_agent(state):
    # 프롬프트와 LLM을 결합하여 체인 구성
    supervisor_chain = create_supervisor("gpt-4o", members, supervisor_system_prompt)
    # Agent 호출
    return supervisor_chain.invoke(state)

### Researcher & Verifier

In [85]:
def create_agent(model_name, tools, system_message: str):
    # 에이전트를 생성합니다.
    functions = [format_tool_to_openai_function(t) for t in tools]
    prompt = ChatPromptTemplate.from_messages(
        [
            (
                "system",
                "You are a helpful AI assistant, collaborating with other assistants."
                " Use the provided tools to progress towards answering the question."                                        
                " If you are unable to fully answer, that's OK, another assistant with different tools "
                " will help where you left off. Execute what you can to make progress."
                " If you or any of the other assistants have the final answer or deliverable,"
                " You have access to the following tools: {tool_names}.\n{system_message}",
            ),
            MessagesPlaceholder(variable_name="messages"),
        ]
    )
    prompt = prompt.partial(system_message=system_message)
    prompt = prompt.partial(tool_names=", ".join(
        [tool.name for tool in tools]))
    llm = get_llm(model_name, temperature=0.1)
    
    return prompt | llm.bind_functions(functions)

In [86]:
## Researcher 시스템 프롬프트
researcher_system_prompt = """
You are a Researcher Agent, responsible for extracting structured and detailed information from scientific papers, specifically in the field of lithium-ion battery materials. Your task is to answer a specific sub-question assigned by the Supervisor by retrieving information from the given paper.

## Your Responsibilities:
    1. Extract information strictly from the paper without making assumptions or hallucinations.
    2. Write the response as a single, coherent technical paragraph (not bullet points or lists).
    3. Use precise scientific language, such as:
        - "was synthesized by..."
        - "was reported to exhibit..."
        - "was not mentioned in the paper..."
    4. If a detail is not reported in the paper, explicitly state it (e.g., “The composition data was not mentioned.”).
    5. Maintain logical flow and consistency in your paragraph.
    6. Once you complete your response, submit it to the Verifier for validation.

After submitting, you may receive feedback from the Verifier. If your response contains incorrect, missing, or unclear information, you must revise and resubmit it based on the Verifier’s comments.
"""

## Verifier 시스템 프롬프트
verifier_system_prompt = """
You are a Verifier Agent, responsible for validating the responses provided by the Researcher. Your primary task is to check for accuracy, completeness, and adherence to scientific precision by cross-referencing the information with the given scientific paper.

## Your Responsibilities:
    1. Verify the Researcher’s response using the retriever tool to ensure it is fully supported by the paper.
    2. Identify any incorrect, missing, or assumed information and provide precise feedback to the Researcher.
    3. If a required detail is not present in the paper, ensure that the Researcher has correctly stated that it was "not reported."
    4. Maintain strict scientific language and formatting consistency.
    5. Once a response is fully verified and correct, submit it to the Supervisor with the marker: `### Complete Verification`

If you find errors, return the response to the Researcher with clear feedback so they can revise and resubmit.
"""

In [87]:
## retriever 설정
retriever = get_retriever(
    file_folder="./data/raw", 
    file_number=11,
    chunk_size=500, 
    chunk_overlap=100, 
    search_k=10
)
retriever_tool = Tool(
    name="retriever",
    func=retriever.get_relevant_documents,
    description="Retrieve relevant documents based on a query."
)     

In [88]:
researcher_agent = create_agent("gpt-4o", [retriever_tool], researcher_system_prompt)
verifier_agent = create_agent("gpt-4o", [retriever_tool], verifier_system_prompt)

### Node

In [89]:
# def sample_name_searcher_node(state):
#     """
#     SampleNameSearcher가 질문을 받고 샘플 이름을 찾아 state에 저장하는 함수.
#     """
#     sample_name_question = state.get("sample_name_question", "")  # 사용자의 질문을 가져옴
#     sample_name = sample_name_searcher_chain.invoke(sample_name_question)  # 질문과 관련된 샘플 이름 검색
#     state["sample_name"] = sample_name  # 결과 저장
    
#     return state  # Supervisor에게 전달할 state 반환

# def supervisor_node(state):
#     """
#     Supervisor가 SampleNameSearcher의 결과를 받아서 다음 단계를 수행하는 함수.
#     """
#     sample_names = state.get("sample_name", "")  # SampleNameSearcher에서 찾은 샘플 이름
#     supervisor_question = state.get("supervisor_question", "")  # 원래 질문

#     # Supervisor가 사용할 입력 데이터 생성
#     combined_input = f"Sample Name: {sample_names}\nQuestion: {supervisor_question}"
    
#     result = supervisor_agent.invoke(combined_input)  # Supervisor 실행
    
#     state["sub_question"] = result  # 결과 저장
#     return state  # 다음 노드로 전달

In [90]:
def agent_node(state, agent, name):
    agent_response = agent.invoke(state)

    if isinstance(agent_response, FunctionMessage):
        pass
    else:
        agent_response = HumanMessage(**agent_response.dict(exclude={"type", "name"}), name=name)
    return {
        "messages":  [
            HumanMessage(content=agent_response["messages"][-1].content, name=name)
        ]
    }

In [91]:
# supervisor_node = functools.partial(agent_node, agent=supervisor_agent, name="Supervisor")
# sample_name_searcher_node = functools.partial(agent_node, agent=sample_name_searcher_chain, name="SampleNameSearcher")

# researcher_node = functools.partial(agent_node, agent=researcher_agent, name="Researcher")
# verifier_node = functools.partial(agent_node, agent=verifier_agent, name="Verifier")

researcher1_node = functools.partial(agent_node, agent=researcher_agent, name="Researcher1")
verifier1_node = functools.partial(agent_node, agent=verifier_agent, name="Verifier1")

researcher2_node = functools.partial(agent_node, agent=researcher_agent, name="Researcher2")
verifier2_node = functools.partial(agent_node, agent=verifier_agent, name="Verifier2")

researcher3_node = functools.partial(agent_node, agent=researcher_agent, name="Researcher3")
verifier3_node = functools.partial(agent_node, agent=verifier_agent, name="Verifier3")

researcher4_node = functools.partial(agent_node, agent=researcher_agent, name="Researcher4")
verifier4_node = functools.partial(agent_node, agent=verifier_agent, name="Verifier4")

In [92]:
## set tool
tools = [retriever_tool]
tool_executor = ToolExecutor(tools)

In [93]:
def tool_node(state):
    # 그래프에서 도구를 실행하는 함수입니다.
    # 에이전트 액션을 입력받아 해당 도구를 호출하고 결과를 반환합니다.
    messages = state["messages"]
    
    # 계속 조건에 따라 마지막 메시지가 함수 호출을 포함하고 있음을 알 수 있습니다.
    first_message = messages[0]
    last_message = messages[-1]
    
    # ToolInvocation을 함수 호출로부터 구성합니다.
    tool_input = json.loads(last_message.additional_kwargs["function_call"]["arguments"])
    tool_name = last_message.additional_kwargs["function_call"]["name"]
    
    if tool_name == "retriever":
        base_query = tool_input.get("__arg1", "")  # 기존 query 가져오기
        refined_query = f"Context: {first_message.content} | Query: {base_query}"
        tool_input["__arg1"] = refined_query
    
    # 단일 인자 입력은 값으로 직접 전달할 수 있습니다.
    if len(tool_input) == 1 and "__arg1" in tool_input:
        tool_input = next(iter(tool_input.values()))
    
    action = ToolInvocation(
        tool=tool_name,
        tool_input=tool_input,
    )
    
    # 도구 실행자를 호출하고 응답을 받습니다.
    response = tool_executor.invoke(action)
    
    # 응답을 사용하여 FunctionMessage를 생성합니다.
    function_message = FunctionMessage(
        content=f"{tool_name} response: {str(response)}", name=action.tool
    )
    
    # 기존 리스트에 추가될 리스트를 반환합니다.
    return {"messages": [function_message]}

In [94]:
def router(state):
    # 상태 정보를 기반으로 다음 단계를 결정하는 라우터 함수
    messages = state["messages"]
    last_message = messages[-1]
    if "function_call" in last_message.additional_kwargs:
        return "call_tool"
    
    if "### Complete Verification" in last_message.content:
        return "save_answer"
    
    if "### Final Answer" in last_message.content:
        return "output"
    
    return "continue"

### Graph

In [None]:
## graph 구축
workflow = StateGraph(AgentState)
workflow.add_node("Supervisor", supervisor_agent)

workflow.add_node("Researcher1", researcher1_node)
workflow.add_node("Researcher2", researcher2_node)
workflow.add_node("Researcher3", researcher3_node)
workflow.add_node("Researcher4", researcher4_node)

workflow.add_node("Verifier1", verifier1_node)
workflow.add_node("Verifier2", verifier2_node)
workflow.add_node("Verifier3", verifier3_node)
workflow.add_node("Verifier4", verifier4_node)

for i in range(1, 5):
    workflow.add_node(f"call_tool{i}", tool_node)
    workflow.add_edge("Supervisor", f"Researcher{i}")
    workflow.add_conditional_edges(
        f"Researcher{i}",
        router,
        {"continue": f"Verifier{i}", "call_tool": f"call_tool{i}"},
    )
    workflow.add_conditional_edges(
        f"Verifier{i}",
        router,
        {"continue": f"Researcher{i}", "call_tool": f"call_tool{i}", "save_answer": "Supervisor"}, 
    )
    workflow.add_conditional_edges(
        f"call_tool{i}",
        lambda x: x["sender"],
        {
            f"Researcher{i}": f"Researcher{i}",
            f"Verifier{i}": f"Verifier{i}",
        },
    )
    
workflow.set_entry_point("Supervisor")
workflow.add_edge("Supervisor", END)
graph = workflow.compile()   

In [104]:
print(graph.get_graph().draw_mermaid())
# graph.get_graph().draw_mermaid_png(output_file_path="MultiAgentSupervisor.png")

%%{init: {'flowchart': {'curve': 'linear'}}}%%
graph TD;
	__start__([<p>__start__</p>]):::first
	Supervisor(Supervisor)
	Researcher1(Researcher1)
	Researcher2(Researcher2)
	Researcher3(Researcher3)
	Researcher4(Researcher4)
	Verifier1(Verifier1)
	Verifier2(Verifier2)
	Verifier3(Verifier3)
	Verifier4(Verifier4)
	call_tool1(call_tool1)
	call_tool2(call_tool2)
	call_tool3(call_tool3)
	call_tool4(call_tool4)
	__end__([<p>__end__</p>]):::last
	Supervisor --> Researcher1;
	Supervisor --> Researcher2;
	Supervisor --> Researcher3;
	Supervisor --> Researcher4;
	Supervisor --> __end__;
	__start__ --> Supervisor;
	Researcher1 -. &nbsp;continue&nbsp; .-> Verifier1;
	Researcher1 -. &nbsp;call_tool&nbsp; .-> call_tool1;
	Verifier1 -. &nbsp;continue&nbsp; .-> Researcher1;
	Verifier1 -. &nbsp;call_tool&nbsp; .-> call_tool1;
	Verifier1 -. &nbsp;save_answer&nbsp; .-> Supervisor;
	call_tool1 -.-> Researcher1;
	call_tool1 -.-> Verifier1;
	Researcher2 -. &nbsp;continue&nbsp; .-> Verifier2;
	Researcher2 -. 

샘플 별로 다 찾아라는 프롬프트 추가

In [97]:
supervisor_question = f"""
## Sample Names: 
{sample_names}

## Question : 
- The exact chemical composition (stoichiometry) of the cathode active material (CAM) used in the study
- The type of commercially used NCM (e.g., LiNixCoyMnzO2), including code names or abbreviations (e.g., NCM811, N-NCM, SC-NCM) if applicable
- The lithium source used for synthesizing the NCM
- The synthesis method used for preparing the NCM samples (e.g., co-precipitation, sol-gel), including intermediate steps if reported
- The crystallization method used (e.g., solid-state sintering, hydrothermal), and the specific final temperatures and durations used in the process
- Whether any doping technique was applied to the CAM, and if so, which element(s) were doped
- Whether any coating layer was applied to the CAM, and if so, what material was used for coating
- The electrode composition, including the mass ratios of active material, conductive additive (e.g., Super P), and binder (e.g., PVDF)
- The type of electrolyte solvent used (e.g., EC/DEC, EC/EMC/DEC), including volume or mass ratios if available
- The lithium salt used in the electrolyte (e.g., LiPF6), and its concentration (e.g., 1 M or 1 mol/L)
- Any additives used in the electrolyte, including their names and concentrations if reported
- The mass loading (in mg/cm²) of the NCM active material used in the electrode
- Particle size information from SEM or TEM images, including both secondary and primary particles if available
- Descriptions of the particle shape observed in SEM or TEM data
- Observations on particle distribution or uniformity reported in the SEM or TEM analysis
- If a coating was applied, the reported properties and thickness of the coating layer as observed in SEM or TEM
- The crystal structure and lattice characteristics (e.g., crystal plane spacing, presence of layered structure) from structural analysis
- The voltage range and temperature used in electrochemical tests
- The specific discharge capacities reported at various C-rates:
  - 0.1C
  - 0.2C
  - 0.5C
  - 1.0C
  - 2.0C
  - Any additional C-rate values and performance data reported
"""


In [None]:
graph.invoke(input={
    "messages": [HumanMessage(content=supervisor_question, name="Supervisor")]
    }
)


TypeError: 'HumanMessage' object is not subscriptable