In [1]:
import os
from dotenv import load_dotenv
from pprint import pprint

from langchain_teddynote import logging
from models import sample_name_searcher, get_rag_instance

from utils import load_config, save_output2json
from prompt import load_system_prompt, load_invoke_input

In [None]:
# Load environment variables from the local .env file.
load_dotenv(dotenv_path=".env")

# Read the API keys from the environment.
OPENAI_API_KEY = os.getenv("OPENAI_API_KEY")
LANGCHAIN_API_KEY = os.getenv("LANGCHAIN_API_KEY")

# Enable LangSmith tracing (optional).
os.environ["LANGCHAIN_TRACING_V2"] = "true"

In [None]:
import json
import warnings
warnings.filterwarnings("ignore")

from langchain_core.messages import (
    BaseMessage,
    FunctionMessage,
    HumanMessage,
)
from langchain.tools.render import format_tool_to_openai_function
from langchain.prompts import ChatPromptTemplate, MessagesPlaceholder
from langgraph.graph import END, StateGraph
from langgraph.prebuilt.tool_executor import ToolExecutor, ToolInvocation
from langchain_core.output_parsers import JsonOutputParser

from langchain.tools import Tool

import operator
from typing import Annotated, Sequence, TypedDict
from langchain_openai import ChatOpenAI
from typing_extensions import TypedDict

import functools
from retriever import get_retriever


# A separate node is created for each agent and for the tools. This class
# defines the object that is passed between the nodes of the graph.
class AgentState(TypedDict):
    # Message history. The operator.add annotation is the reducer, so the
    # lists returned by nodes are concatenated onto the existing sequence.
    messages: Annotated[Sequence[BaseMessage], operator.add]
    # Name of the agent that produced the latest update; the "call_tool"
    # conditional edge reads it to route back to the calling agent.
    sender: str

class NewMultiAgentRAG:
    def __init__(
        self, 
        file_folder:str="./data/raw", 
        file_number:int=1, 
        chunk_size: int=500, 
        chunk_overlap: int=100, 
        search_k: int=10,       
        system_prompt:str = None, 
        model_name:str="gpt-4o",
        save_graph_png:bool=False,
    ):
        ## 파일 명 설정
        if file_number < 10:
            file_name = f"paper_00{file_number}"
        elif file_number < 100:
            file_name = f"paper_0{file_number}"
        else:
            file_name = f"paper_{file_number}"

        ## retriever 설정
        self.retriever = get_retriever(
            file_folder=file_folder, 
            file_name=file_name, 
            chunk_size=chunk_size, 
            chunk_overlap=chunk_overlap, 
            search_k=search_k
        )
        self.retriever_tool = Tool(
            name="retriever",
            func=self.retriever.get_relevant_documents,
            description="Retrieve relevant documents based on a query."
        )

        ## Coordinator Agent
        self.coordinator_system_prompt = system_prompt["coordinator_system_prompt"]

        ## researcher 시스템 프롬프트
        self.researcher_system_prompt = system_prompt["researcher_system_prompt"]

        ## verifier 시스템 프롬프트
        self.verifier_system_prompt = """You are a meticulous verifier agent specializing in the domain of battery technology.
Your primary task is to verify the accuracy of the Researcher's answers by using the search tool to cross-check the extracted information from research papers on batteries, formatted into JSON.  

Your responsibilities include validating the following:  

### Accuracy:  
Extracted values through documents retrieved via the search tool must be verified to ensure they match accurately.

### Completeness:  
Confirm that all fields in the JSON structure are either filled with accurate values from the battery-related sections of the PDF or marked as "None" if not mentioned in the document.  

If any field is missing or only partially extracted, explicitly state:  
- **Which fields are incomplete or missing**  
- **Whether the missing information exists in the PDF but was not extracted, or is genuinely absent**  
- **Suggestions for improvement (e.g., re-extraction, manual verification, or alternative sources if applicable)**  

### Consistency:  
Verify that the JSON structure, format, and data types adhere strictly to the required schema for battery-related research data.  

### Corrections:  
Identify and highlight any errors, including:  
- **Inaccurate values** (i.e., extracted values that do not match the PDF)  
- **Missing data** (i.e., fields left empty when information is available)  
- **Formatting inconsistencies** (i.e., data types or schema mismatches)  

For any issues found, provide a **clear and actionable correction**, including:  
- **The specific field in question**  
- **The nature of the issue (incorrect value, missing data, formatting error, etc.)**  
- **Suggestions or corrections to resolve the issue**  

### Handling Missing Data:  
If certain information is genuinely **not found** in the PDF, specify:  
- **Which fields could not be located**  
- **Confirmation that they are absent from the document**  
- **A recommendation to keep the field as `"None"` or any alternative solutions**  

### Final Output:  
If the JSON is entirely correct, confirm its validity and output the JSON structure exactly as provided.  
Include the phrase `### Final Output` before printing the JSON. This ensures the output is clearly marked and easy to locate.  

### Scope:  
Focus **exclusively** on battery-related content extracted from the PDF.  
Ignore any reference content or information outside the provided document.  
"""
        
        ## agent 및 node 생성
        self.model_name = model_name
        llm = ChatOpenAI(model=self.model_name, temperature=0.1)

        # Research agent and node
        self.research_agent = self.create_agent(
            llm,
            [self.retriever_tool],
            system_message=self.researcher_system_prompt,
        )
        self.research_node = functools.partial(self.agent_node, agent=self.research_agent, name="Researcher")

        # Data_Verifier
        self.verifier_agent = self.create_agent(
            llm,
            [self.retriever_tool],
            system_message=self.verifier_system_prompt,
        )
        self.verifier_node = functools.partial(self.agent_node, agent=self.verifier_agent, name="Data_Verifier")

        # Json_Processor
        self.json_processor_system_prompt = ChatPromptTemplate.from_messages(
            [
                (
                    "system",
                    """You are a JSON Processor Agent. Your sole responsibility is to process the response generated by an LLM and ensure the accurate extraction of the JSON content within the response. Follow these instructions precisely:

### Instructions:
1. **Extract JSON Only**:
- Identify the ```json``` block within the provided response.
- Extract and output the content within the ```json``` block exactly as it appears.

2. **No Modifications**:
- Do not modify, add, or remove any part of the JSON content.
- Preserve the relevancerag structure, field names, and values without alteration.

3. **No Hallucination**:
- Do not interpret, infer, or generate additional content.

4. **Output Format**:
- Respond with the extracted JSON content only.
- Do not include any explanations, comments, or surrounding text.
- The output must be a clean, valid JSON.

### Your Role:
Ensure the integrity and consistency of the JSON data by strictly adhering to these instructions. Your output should always be concise and compliant with the above rules."""
                ),
                MessagesPlaceholder(variable_name="messages"),
            ]
        )

        self.json_processor_agent = self.json_processor_system_prompt | ChatOpenAI(model=self.model_name, temperature=0.1) | JsonOutputParser()
        self.json_processor_node = functools.partial(self.json_processor_agent_node, agent=self.json_processor_agent, name="Json_Processor")

        self.tools = [self.retriever_tool]
        self.tool_executor = ToolExecutor(self.tools)
        

        ## graph 구축
        workflow = StateGraph(AgentState)

        workflow.add_node("Researcher", self.research_node)
        workflow.add_node("Data_Verifier", self.verifier_node)
        workflow.add_node("call_tool", self.tool_node)
        workflow.add_node("Json_Processor", self.json_processor_node)

        workflow.add_edge("Json_Processor", END)
        workflow.add_conditional_edges(
            "Researcher",
            self.router,
            {"continue": "Data_Verifier", "call_tool": "call_tool"},
        )
        workflow.add_conditional_edges(
            "Data_Verifier",
            self.router,
            {"continue": "Researcher", "call_tool": "call_tool", "process_output": "Json_Processor"},
        )
        workflow.add_conditional_edges(
            "call_tool",
            lambda x: x["sender"],
            {
                "Researcher": "Researcher",
                "Data_Verifier": "Data_Verifier",
            },
        )

        workflow.set_entry_point("Researcher")
        self.graph = workflow.compile()   
        
        if save_graph_png:        
            self.graph.get_graph().draw_mermaid_png(output_file_path="./graph_img/multiagentrag_graph.png")


    def create_agent(self, llm, tools, system_message: str):
        """Return a runnable agent: a chat prompt piped into the function-bound LLM.

        Args:
            llm: Chat model exposing ``bind_functions``.
            tools: Tools offered to the model; their names are listed in the
                system prompt and their OpenAI function specs are bound to ``llm``.
            system_message: Role-specific instructions appended to the shared
                system prompt.

        Returns:
            The composed ``prompt | llm.bind_functions(...)`` runnable.
        """
        # Convert every tool into its OpenAI function specification.
        openai_functions = [format_tool_to_openai_function(tool) for tool in tools]

        # NOTE(review): the sentence ending in "final answer or deliverable,"
        # is left dangling in the template; it is runtime text and kept as is.
        system_template = (
            "You are a helpful AI assistant, collaborating with other assistants."
            " Use the provided tools to progress towards answering the question."
            " If you are unable to fully answer, that's OK, another assistant with different tools "
            " will help where you left off. Execute what you can to make progress."
            " If you or any of the other assistants have the final answer or deliverable,"
            " You have access to the following tools: {tool_names}.\n{system_message}"
        )
        base_prompt = ChatPromptTemplate.from_messages(
            [
                ("system", system_template),
                MessagesPlaceholder(variable_name="messages"),
            ]
        )

        # Pre-fill both template variables so only "messages" remains open.
        filled_prompt = base_prompt.partial(
            system_message=system_message,
            tool_names=", ".join([tool.name for tool in tools]),
        )
        return filled_prompt | llm.bind_functions(openai_functions)
    
    
    def agent_node(self, state, agent, name):
        """Invoke ``agent`` on the current state and package its reply as a state update.

        A reply that is not a ``FunctionMessage`` is rebuilt as a
        ``HumanMessage`` carrying ``name``, keeping every other field of the
        original message.

        Returns:
            ``{"messages": [reply], "sender": name}``.
        """
        reply = agent.invoke(state)
        if not isinstance(reply, FunctionMessage):
            # Drop the original type/name and re-tag the message with this node's name.
            fields = reply.dict(exclude={"type", "name"})
            reply = HumanMessage(**fields, name=name)
        return {"messages": [reply], "sender": name}


    def json_processor_agent_node(self, state, agent, name):
        """Ask the JSON processor agent to extract the final JSON from the last message.

        Args:
            state: Graph state; only the content of the last entry of
                ``state["messages"]`` is forwarded to the agent.
            agent: Runnable that receives ``{"messages": [...]}`` and returns
                the parsed JSON.
            name: Node name, recorded as the ``sender`` of the update.

        Returns:
            ``{"messages": [...parsed JSON...], "sender": name}``.
        """
        final_response = state["messages"][-1].content
        parsed = agent.invoke(
            {
                "messages": [
                    HumanMessage(content=f"""Convert Final Output in the given response into a JSON format.: {final_response}""")
                ]
            }
        )
        # "messages" is merged into the state with operator.add, i.e. list
        # concatenation. A parsed JSON object (dict) would raise a TypeError
        # there, so anything that is not already a list is wrapped in one.
        if not isinstance(parsed, list):
            parsed = [parsed]
        # The state schema tracks the producing node under "sender" (as
        # agent_node does); "name" is not a state key.
        return {"messages": parsed, "sender": name}


    def tool_node(self, state):
        """Execute the tool requested by the last message and return its output.

        The last message is expected to carry a ``function_call`` entry in
        ``additional_kwargs``. For the ``retriever`` tool the query is prefixed
        with the content of the first message of the conversation.

        Returns:
            ``{"messages": [FunctionMessage]}``, a one-element list to be
            appended to the existing message list.
        """
        history = state["messages"]
        opening_message = history[0]
        latest_message = history[-1]

        # Decode the function call issued by the agent.
        call = latest_message.additional_kwargs["function_call"]
        tool_input = json.loads(call["arguments"])
        tool_name = call["name"]

        if tool_name == "retriever":
            # Enrich the agent's query with the original request as context.
            original_query = tool_input.get("__arg1", "")
            tool_input["__arg1"] = f"Context: {opening_message.content} | Query: {original_query}"

        # A lone positional argument can be passed directly as the value.
        if len(tool_input) == 1 and "__arg1" in tool_input:
            tool_input = tool_input["__arg1"]

        action = ToolInvocation(tool=tool_name, tool_input=tool_input)

        # Run the tool through the executor.
        response = self.tool_executor.invoke(action)

        # Wrap the tool output in a FunctionMessage named after the tool.
        return {
            "messages": [
                FunctionMessage(
                    content=f"{tool_name} response: {response!s}",
                    name=action.tool,
                )
            ]
        }
    
    
    def router(self, state):
        """Choose the next route from the most recent message in the state.

        Returns:
            ``"call_tool"`` when the last message carries a ``function_call``,
            ``"process_output"`` when its content contains ``"Final Output"``,
            otherwise ``"continue"``.
        """
        latest = state["messages"][-1]

        # A pending function call always wins: the previous agent wants a tool.
        if "function_call" in latest.additional_kwargs:
            return "call_tool"

        # "Final Output" marks that an agent considers the work finished.
        return "process_output" if "Final Output" in latest.content else "continue"

In [None]:
def main(
    data_folder: str = "./data",
    file_num_list: list = None,
    category_number: int = 1,
    chunk_size: int = 500,
    chunk_overlap: int = 100,
    search_k: int = 10,
    config_folder: str = "./config",
    rag_method: str = "multiagent-rag",
    model_name: str = "gpt-4o",
    save_graph_png: bool = False,
):
    """Run the selected RAG pipeline over a list of papers for one category.

    For every paper number a graph is built through ``get_rag_instance``,
    invoked with the configured input, and the answer for the chosen category
    is printed, saved with ``save_output2json`` and collected.

    Args:
        data_folder: Base data folder; papers are read from
            ``{data_folder}/input_data/``.
        file_num_list: Paper numbers to process. Defaults to ``[11]``.
        category_number: 1-based index into the four category names.
        chunk_size: Chunk size forwarded to ``get_rag_instance``.
        chunk_overlap: Chunk overlap forwarded to ``get_rag_instance``.
        search_k: ``search_k`` forwarded to ``get_rag_instance``.
        config_folder: Folder the system prompt and invoke input are loaded from.
        rag_method: RAG variant name used for config lookup and graph creation.
        model_name: Model name forwarded to ``get_rag_instance``.
        save_graph_png: Forwarded to ``get_rag_instance``.

    Returns:
        A list with one extracted answer per processed paper.

    Raises:
        ValueError: If ``category_number`` is out of range, or a graph result
            has none of the ``answer`` / ``discussion`` / ``messages`` entries.
    """
    category_names = ["CAM (Cathode Active Material)", "Electrode (half-cell)", "Morphological Properties", "Cathode Performance"]

    # Reject out-of-range values explicitly; 0 would otherwise silently pick
    # the last category through negative indexing.
    if not 1 <= category_number <= len(category_names):
        raise ValueError(
            f"category_number must be between 1 and {len(category_names)}, got {category_number!r}"
        )
    category_name = category_names[category_number - 1]

    # Avoid a shared mutable default argument.
    if file_num_list is None:
        file_num_list = [11]

    ## Load the system prompt and the invoke input
    system_prompt = load_system_prompt(config_folder=config_folder, category_number=category_number, rag_method=rag_method)
    invoke_input = load_invoke_input(config_folder=config_folder, category_number=category_number, rag_method=rag_method)

    total_answer = []

    ## Iterate over each paper
    for file_number in file_num_list:
        print(f"#####    {file_number}번째 논문    #####")
        print(f"##       rag method     : {rag_method}")
        print(f"##       category name  : {category_name}")

        ## Build the graph for this paper
        voltai_graph = get_rag_instance(
            rag_method=rag_method,
            file_folder=f"{data_folder}/input_data/",
            file_number=file_number,
            chunk_size=chunk_size,
            chunk_overlap=chunk_overlap,
            search_k=search_k,
            system_prompt=system_prompt,
            model_name=model_name,
            save_graph_png=save_graph_png,
        ).graph

        ## The invoke input is either a dict of keyword arguments or a sequence of positional ones
        if isinstance(invoke_input, dict):
            result = voltai_graph.invoke(**invoke_input)
        else:
            result = voltai_graph.invoke(*invoke_input)

        ## Where the answer lives depends on the RAG method
        if result.get("answer"):
            temp_answer = result["answer"][0][category_name]
        elif result.get("discussion"):
            temp_answer = result["discussion"][category_name]
        elif result.get("messages"):
            temp_answer = result["messages"][-1][category_name]
        else:
            # Previously this fell through and failed later with an
            # UnboundLocalError on temp_answer.
            raise ValueError(
                f"Graph result for paper {file_number} has no non-empty "
                "'answer', 'discussion' or 'messages' entry"
            )

        print(f"##       print {file_number} result")
        print("------------------------------------")
        pprint(temp_answer, sort_dicts=False)

        ## Save the answer as JSON
        save_output2json(each_answer=temp_answer, file_num=file_number, rag_method=rag_method, category_number=category_number)

        total_answer.append(temp_answer)

    return total_answer

In [None]:
# Paper numbers processed by the main(...) calls in the cells below.
file_num_list = [39]

### Multiagent RAG

In [None]:
# Category 1 ("CAM (Cathode Active Material)") with the multi-agent RAG method.
multiagent_rag_c1_answer = main(file_num_list=file_num_list, category_number=1, rag_method="multiagent-rag")

##       ./config/multiagent-rag/c1-system-prompt.yaml를 불러왔습니다.
##       ./config/multiagent-rag/c1-question.yaml를 불러왔습니다.
#####    11번째 논문    #####
##       rag method     : multiagent-rag
##       category name  : CAM (Cathode Active Material)
##       paper_011 retriever를 생성했습니다.
##          - chunk_size    :500
##          - chunk_overlap :100
##          - retrieve_k    :10
##       print 11 result
------------------------------------
{'Stoichiometry information': {'LiNi1/3Co1/3Mn1/3O2': {'Li ratio': 1.0,
                                                       'Ni ratio': 0.33,
                                                       'Co ratio': 0.33,
                                                       'Mn ratio': 0.33,
                                                       'O ratio': 2.0}},
 'Commercial NCM used': {'LiNi1/3Co1/3Mn1/3O2': 'no'},
 'Lithium source': 'LiNO3',
 'Synthesis method': 'solution combustion',
 'Crystallization method': 'calcination',
 'Crystallization final

In [None]:
# multiagent_rag_c2_answer = main(file_num_list=file_num_list, category_number=2, rag_method="multiagent-rag")

##       ./config/multiagent-rag/c2-system-prompt.yaml를 불러왔습니다.
##       ./config/multiagent-rag/c2-question.yaml를 불러왔습니다.
#####    11번째 논문    #####
##       rag method     : multiagent-rag
##       category name  : Electrode (half-cell)
##       paper_011 retriever를 생성했습니다.
##          - chunk_size    :500
##          - chunk_overlap :100
##          - retrieve_k    :10
##       print 11 result
------------------------------------
{'Active material to Conductive additive to Binder ratio': '87:3:10',
 'Electrolyte': [{'Salt': 'LiPF6',
                  'Concentration': '1M',
                  'Solvent': 'EC:DMC',
                  'Solvent ratio': '1:1'}],
 'Additive': 'RGO, 5%',
 'Loading density (mass loading of NCM)': 'None',
 'Additional treatment for electrode': 'None'}


In [5]:
# Category 3 ("Morphological Properties") with the multi-agent RAG method.
multiagent_rag_c3_answer = main(file_num_list=file_num_list, category_number=3, rag_method="multiagent-rag")

## ./config/multiagent-rag/c3-system-prompt.yaml를 불러왔습니다.
## ./config/multiagent-rag/c3-question.yaml를 불러왔습니다.
#####    11번째 논문    #####
##       rag method     : multiagent-rag
##       category name  : Morphological Properties
##       paper_011 retriever를 생성했습니다.
##          - chunk_size    :500
##          - chunk_overlap :100
##          - retrieve_k    :10
##       print 11 result
------------------------------------
{'ParticleSize': {'NCM': '200-300 nm'},
 'ParticleShape': {'NCM': 'faceted morphology, regular polyhedrons with smooth '
                          'surfaces'},
 'ParticleDistribution': {'NCM': 'uniform distribution with no significant '
                                 'agglomeration'},
 'CoatingLayerCharacteristics': {'NCMeRGO': 'RGO sheets wrapped around NCM '
                                            'nanoparticles forming a '
                                            'core-shell-like structure'},
 'CrystalStructureAndLatticeCharacteristics': {'NCM': 'hexagona

In [7]:
# Category 4 ("Cathode Performance") with the multi-agent RAG method.
multiagent_rag_c4_answer = main(file_num_list=file_num_list, category_number=4, rag_method="multiagent-rag")

## ./config/multiagent-rag/c4-system-prompt.yaml를 불러왔습니다.
## ./config/multiagent-rag/c4-question.yaml를 불러왔습니다.
#####    11번째 논문    #####
##       rag method     : multiagent-rag
##       category name  : Cathode Performance
##       paper_011 retriever를 생성했습니다.
##          - chunk_size    :500
##          - chunk_overlap :100
##          - retrieve_k    :10
##       print 11 result
------------------------------------
{'NR0': [{'Voltage range': '2.5 - 4.3',
          'Temperature': 25,
          'C-rate and Specific capacity': [{'C-rate': 0.1, 'Capacity': 155.3},
                                           {'C-rate': 0.2, 'Capacity': 'None'},
                                           {'C-rate': 0.5, 'Capacity': 'None'},
                                           {'C-rate': 1.0, 'Capacity': 123},
                                           {'C-rate': 2.0, 'Capacity': 'None'},
                                           {'C-rate': 4.0, 'Capacity': 'None'},
                                 