In [149]:
from langchain_chroma import Chroma
from langchain_openai import OpenAIEmbeddings

# Embedding model used for the Chroma collection below.
embedding_function = OpenAIEmbeddings(model="text-embedding-3-large")

# Chroma collection persisted under ./income_tax_collection
# (the directory and the collection share the same name).
vector_store = Chroma(
    collection_name="income_tax_collection",
    embedding_function=embedding_function,
    persist_directory="income_tax_collection",
)

# Retriever returning the 3 best matches per query; used by the `retrieve` node.
retriever = vector_store.as_retriever(search_kwargs=dict(k=3))


In [150]:
from typing_extensions import List, TypedDict
from langchain_core.documents import Document
from langgraph.graph import StateGraph

class AgentState(TypedDict):
    """Shared state passed between the nodes of the RAG graph.

    Fields:
        query: the user question; the `rewrite` node replaces it with a
            rewritten version.
        context: documents returned by the retriever for `query`.
        answer: the answer produced by the `generate` node.
    """

    query: str
    context: List[Document]
    answer: str


# Builder on which nodes and edges are registered before compiling.
graph_builder = StateGraph(state_schema=AgentState)

In [151]:
# Retriever node
def retrieve(state: AgentState):
    """Fetch documents for the current query.

    Args:
        state: graph state; only ``state["query"]`` is read.

    Returns:
        Partial state update ``{"context": <retrieved documents>}``.
    """
    retrieved_docs = retriever.invoke(state["query"])
    return {"context": retrieved_docs}

In [152]:
from langchain_openai import ChatOpenAI

# Shared chat model for relevance grading, query rewriting and answering.
# temperature=0 keeps the routing decision and rewrites as repeatable as the API
# allows, so re-running the notebook tends to follow the same graph path.
llm = ChatOpenAI(model="gpt-4o", temperature=0)


In [None]:
from langchain import hub

# Prompts pulled from the LangChain Hub, in order:
#   generate_prompt      <- "rlm/rag-prompt" (invoked with "context" and "question")
#   relevence_doc_prompt <- "langchain-ai/rag-document-relevance"
#                           (invoked with "documents" and "question")
generate_prompt, relevence_doc_prompt = (
    hub.pull(handle)
    for handle in ("rlm/rag-prompt", "langchain-ai/rag-document-relevance")
)

In [154]:
# Answer node
def generate(state: AgentState):
    """Answer the query from the retrieved documents.

    Args:
        state: graph state; reads ``state["query"]`` and ``state["context"]``.

    Returns:
        Partial state update ``{"answer": <answer text>}``. The answer is a plain
        string (via ``StrOutputParser``), matching ``AgentState.answer: str``,
        rather than the chat model's message object.
    """
    # Local import so this cell does not depend on a later cell having been run.
    from langchain_core.output_parsers import StrOutputParser

    # Render the documents as plain text. Passing the Document list directly would
    # put the list's repr (including metadata) into the prompt.
    context_text = "\n\n".join(doc.page_content for doc in state["context"])

    rag_chain = generate_prompt | llm | StrOutputParser()
    answer = rag_chain.invoke({"context": context_text, "question": state["query"]})
    return {"answer": answer}


In [155]:
from typing import Literal

# Document-relevance router (used as the conditional edge out of "retrieve")
def check_relevence_doc(state: AgentState) -> Literal["generate", "rewrite"]:
    """Grade the retrieved documents against the query and choose the next node.

    Args:
        state: graph state; reads ``state["query"]`` and ``state["context"]``.

    Returns:
        ``'generate'`` when the grader's ``Score`` equals 1, otherwise
        ``'rewrite'``.
    """
    question = state["query"]
    documents = state["context"]

    grader = relevence_doc_prompt | llm
    verdict = grader.invoke({"documents": documents, "question": question})
    print(f"dec relevence response: {verdict}")

    # NOTE(review): indexing assumes the hub prompt yields a mapping with a
    # "Score" key (structured output) — confirm against the prompt's schema.
    return 'generate' if verdict['Score'] == 1 else 'rewrite'


In [156]:
# Sample user question used as the graph's initial input.
query = "연봉 5000만원 이상의 직장인의 소득세는?"

In [157]:
from langchain_core.prompts import PromptTemplate
from langchain_core.output_parsers import StrOutputParser

# Rewrite rules: each entry maps a kind of expression to the term to use instead.
dictionary = ['사람과 관련된 표현 -> 거주자']

# The template uses PromptTemplate's f-string syntax (it is NOT a Python f-string):
# `{query}` is a runtime input, and `{dictionary}` is filled once via
# partial_variables, so callers only need to pass "query".
rewrite_prompt = PromptTemplate.from_template(
"""
사용자의 질문을 보고, 우리의 사전을 참고하여 사용자의 질문을 변경해주세요.
변경된 질문만 출력해주세요.
사전 : {dictionary}
사용자의 질문 : {query}
""",
    partial_variables={"dictionary": "\n".join(dictionary)},
)

def rewrite(state: AgentState):
    """Rewrite the query using the dictionary before retrieval is retried.

    Args:
        state: graph state; only ``state["query"]`` is read.

    Returns:
        Partial state update ``{"query": <rewritten query text>}``.
    """
    rewrite_chain = rewrite_prompt | llm | StrOutputParser()
    rewritten = rewrite_chain.invoke({"query": state["query"]})
    return {"query": rewritten}

In [None]:
# Register the graph's nodes (name -> callable), in the same order as before.
for node_name, node_fn in (
    ("retrieve", retrieve),
    ("generate", generate),
    ("rewrite", rewrite),
):
    graph_builder.add_node(node_name, node_fn)

In [None]:
# 그래프 흐름 (graph flow)
# START → retrieve → check_relevence_doc → generate → END
#                                        ↘ rewrite → retrieve (loops back and retries retrieval)

from langgraph.graph import START, END

# Wire the graph: retrieval always runs first.
graph_builder.add_edge(START, "retrieve")

# check_relevence_doc returns the next node's name. The identity path_map spells
# out the two possible targets, matching its Literal return annotation.
graph_builder.add_conditional_edges(
    "retrieve",
    check_relevence_doc,
    {"generate": "generate", "rewrite": "rewrite"},
)

# A rewritten query loops back to retrieval; generating an answer ends the run.
graph_builder.add_edge("rewrite", "retrieve")
graph_builder.add_edge("generate", END)

In [160]:
# Compile the graph (no checkpointer, so no state is persisted between runs).
graph = graph_builder.compile(checkpointer=None)

In [None]:
from IPython.display import Image, display

# Visualize the compiled graph as a Mermaid-rendered PNG.
graph_png = graph.get_graph().draw_mermaid_png()
display(Image(graph_png))

In [None]:
# Run the graph on the sample question; the final state is the cell's output.
initial_state = dict(query=query)
graph.invoke(input=initial_state)