In [9]:
import uuid
import pandas as pd
from langchain_chroma import Chroma
from langchain_core.documents import Document
from langgraph.graph import StateGraph, END
from langchain_huggingface import HuggingFacePipeline
from langchain_huggingface import HuggingFaceEmbeddings
from langchain_core.chat_history import InMemoryChatMessageHistory
from langchain_core.messages import get_buffer_string
from langchain_core.runnables import RunnableConfig
from langchain_core.chat_history import BaseChatMessageHistory
from transformers import (
    AutoTokenizer, 
    AutoModelForCausalLM, 
    AutoModelForQuestionAnswering, 
    AutoModelForSeq2SeqLM, 
    pipeline)
from typing import Dict, Any


df = pd.read_csv("./data/data.csv", encoding="utf8")
texts = df['text'].tolist()
docs = [Document(page_content=text) for text in texts]

embedding_model = HuggingFaceEmbeddings(model_name="sentence-transformers/paraphrase-multilingual-MiniLM-L12-v2")

vectorstore = Chroma.from_documents(
    documents=docs,
    embedding=embedding_model,
    persist_directory="./MultiAgentChroma_db",
)


chats_by_ssetion_id = {}

def get_chat_history(session_id: str) -> BaseChatMessageHistory:
    if session_id not in chats_by_ssetion_id:
        chats_by_ssetion_id[session_id] = InMemoryChatMessageHistory()
    return chats_by_ssetion_id[session_id]

embedding_model = HuggingFaceEmbeddings(model_name="sentence-transformers/paraphrase-multilingual-MiniLM-L12-v2")
retriever = Chroma (
    persist_directory="./MultiAgentChroma_db",
    embedding_function=embedding_model
).as_retriever(search_kwargs={"k":3})

qa_model_id = "monologg/koelectra-base-v3-finetuned-korquad"
qa_tokenizer = AutoTokenizer.from_pretrained(qa_model_id)
qa_model = AutoModelForQuestionAnswering.from_pretrained(qa_model_id)
qa_pipeline = pipeline(
    "question-answering",
    model=qa_model,
    tokenizer=qa_tokenizer,
    device=-1
)
qa_llim = HuggingFacePipeline(pipeline=qa_pipeline)

sum_model_id = "lcw99/t5-base-korean-text-summary"
sum_tokenizer = AutoTokenizer.from_pretrained(sum_model_id)
sum_model = AutoModelForSeq2SeqLM.from_pretrained(sum_model_id)
sum_pipeline = pipeline(
    "summarization",
    model=sum_model,
    tokenizer=sum_tokenizer,
    device=-1,
)
sum_llim = HuggingFacePipeline(pipeline=sum_pipeline)

gen_model_id = "skt/kogpt2-base-v2"
gen_tokenizer = AutoTokenizer.from_pretrained(gen_model_id)
gen_model = AutoModelForCausalLM.from_pretrained(gen_model_id)
gen_tokenizer.model_max_length = 1024
gen_pipeline = pipeline(
    "text-generation",
    model=gen_model,
    tokenizer=gen_tokenizer,
    device=-1,
    max_new_tokens=300,
    do_sample=True)

gen_llim = HuggingFacePipeline(pipeline=gen_pipeline)

def rag_node(state: Dict[str, Any]):
    mcp = state["mcp"]
    payload = mcp["payload"]
    question = payload["question"]
    sesseion_id = payload["metadata"]["session_id"]

    chat_history = get_chat_history(sesseion_id)
    history_text = "\n".join([m.content for m in chat_history.messages])
    docs = retriever.invoke(question)
    top_docs = docs[:3]
    context = "\n".join([doc.page_content for doc in top_docs])
    full_context = f"{history_text}\n{context}" if history_text else context
    result = qa_pipeline(question=question, context=full_context)


    chat_history.add_user_message(question)
    chat_history.add_ai_message(result['answer'])

    return {
        "mcp": {
            "source": "rag_agent",
            "destination": mcp["source"],
            "intent":"answer",
            "payload": {
                "answer": result['answer'],
                "references": [doc.page_content for doc in top_docs],
                "metadata": {
                    "session_id": sesseion_id
                }
            }
        }
    }

def build_rag_agent():
    graph = StateGraph(dict)
    graph.set_entry_point("rag_node")
    graph.add_node("rag_node", rag_node)
    graph.set_finish_point("rag_node")
    return graph.compile()

rag_app = build_rag_agent()

def summarize_node(state: Dict[str, Any]):
    mcp = state["mcp"]
    payload = mcp["payload"]
    text = payload["question"]
    sesseion_id = payload["metadata"]["session_id"]
    result = sum_llim.invoke(text)
    
    return {
        "mcp": {
            "source": "summarize_agent",
            "destination": mcp["source"],
            "intent":"answer",
            "payload": {
                "answer": result,
                "metadata": {
                    "session_id": sesseion_id
                }
            }
        }
    }

def build_summarize_agent():
    graph = StateGraph(dict)
    graph.set_entry_point("summarize_node")
    graph.add_node("summarize_node", summarize_node)
    graph.set_finish_point("summarize_node")
    return graph.compile()

summarize_app = build_summarize_agent()

def rephrase_node(state: Dict[str, Any]):
    mcp = state["mcp"]
    payload = mcp["payload"]
    question = payload["question"]
    sesseion_id = payload["metadata"]["session_id"]
    result = gen_llim.invoke(question)
    
    return {
        "mcp": {
            "source": "rephrase_agent",
            "destination": mcp["source"],
            "intent":"answer",
            "payload": {
                "answer": result,
                "metadata": {
                    "session_id": sesseion_id
                }
            }
        }
    }

def build_rephrase_agent():
    graph = StateGraph(dict)
    graph.set_entry_point("rephrase_node")
    graph.add_node("rephrase_node", rephrase_node)
    graph.set_finish_point("rephrase_node")
    return graph.compile()
rephrase_app = build_rephrase_agent()

def classify_intent(question:str) -> str:
    if "요약" in question:
        return "summarize"
    elif "정중" in question or "공손" in question or "예의" in question:
        return "rephrase"
    else:
        return "get_answer"
    
def supervisor_node(state: Dict[str, Any]):
    question = input("질문을 입력해라:").strip()

    sesseion_id = state.get("session_id", str(uuid.uuid4()))
    intent = classify_intent(question)

    dest_map = {
        "get_answer": "rag_agent",
        "summarize": "summarize_agent",
        "rephrase": "rephrase_agent"
    }
    mcp = {
        "source": "supervisor",
        "destination": dest_map[intent],
        "intent":"request",
        "payload": {
            "question": question,
            "metadata": {
                "session_id": sesseion_id
            }
        }
    }
    return {"mcp": mcp, "session_id": sesseion_id}

def get_answer(state: Dict[str, Any]):
    mcp = state["mcp"]
    print("\n [Agent answering...]")
    print("Answer:", mcp["payload"]["answer"])

    if "references" in mcp["payload"]:
        print("\n [References:]")
        for ref in mcp["payload"]["references"]:
            print("-", ref)
                  
    return state

def route_mcp(state):
    return state["mcp"]["destination"]
def ask_continue(state: Dict[str, Any]):
    user_input = input("\n계속 질문하시겠습니까? (y/n): ").strip()
    state["continue"] = user_input.lower().startswith('y')
    return state
def should_continue(state: Dict[str, Any]):
    return "supervisor" if state.get("continue") else END

def build_supervisor_graph():
    graph = StateGraph(dict)
    graph.set_entry_point("supervisor_node")
    graph.add_node("supervisor_node", supervisor_node)
    
    graph.add_node("rag_agent", lambda state:rag_app.invoke(state))
    graph.add_node("summarize_agent", lambda state:summarize_app.invoke(state))
    graph.add_node("rephrase_agent", lambda state:rephrase_app.invoke(state))
    
    graph.add_node("get_answer", get_answer)
    graph.add_node("ask_continue", ask_continue)

    graph.add_conditional_edges("supervisor_node", route_mcp, {
        "rag_agent": "rag_agent",
        "summarize_agent": "summarize_agent",
        "rephrase_agent": "rephrase_agent"
    })
    graph.add_edge("rag_agent", "get_answer")
    graph.add_edge("summarize_agent", "get_answer")
    graph.add_edge("rephrase_agent", "get_answer")
    graph.add_edge("get_answer", "ask_continue")

    graph.add_conditional_edges("ask_continue", should_continue, {
        "supervisor": "supervisor_node",
        END: END
    })
    
    return graph.compile()

supervisor_app = build_supervisor_graph()

supervisor_app.invoke({})



Device set to use cpu
Device set to use cpu
Device set to use cpu



 [Agent answering...]
Answer: 6개월간

 [References:]
- MES 시스템에 기록된 공정 로그는 6개월간 보관됩니다.
- MES 시스템에 기록된 공정 로그는 6개월간 보관됩니다.
- MES 시스템에 기록된 공정 로그는 6개월간 보관됩니다.


{'mcp': {'source': 'rag_agent',
  'destination': 'supervisor',
  'intent': 'answer',
  'payload': {'answer': '6개월간',
   'references': ['MES 시스템에 기록된 공정 로그는 6개월간 보관됩니다.',
    'MES 시스템에 기록된 공정 로그는 6개월간 보관됩니다.',
    'MES 시스템에 기록된 공정 로그는 6개월간 보관됩니다.'],
   'metadata': {'session_id': '6396df29-9b7f-4a01-a74b-00de82e44a4c'}}},
 'continue': False}