In [1]:
import os
import warnings
warnings.filterwarnings('ignore')

# Seed the API keys only when they are not already present in the environment,
# so a real key exported in the shell/kernel is never overwritten by the
# placeholder text below. Replace the placeholders (or export the variables)
# before running the notebook; never commit a real key here.
os.environ.setdefault('OPENAI_API_KEY', "Your ChatGPT API KEY")
os.environ.setdefault('TAVILY_API_KEY', "Your TAVILY API KEY")

In [2]:
from langchain_openai import ChatOpenAI
from typing import TypedDict, Annotated
from pydantic import BaseModel, Field

### Profiling LLM

In [3]:
from langchain_core.messages import SystemMessage

# Schema of the role-play profile collected from the user. The model is bound
# to the chat LLM as a tool below, so its field names are what the LLM fills in
# when it issues the tool call. (Documented with comments rather than a
# docstring so the schema derived from this class stays exactly as it was.)
class Profile(BaseModel):
    character_name: str  # name of the character to play
    universe: str  # universe (film, game, ...) the character belongs to
    requirements: str  # the user's requirements for the conversation
    user_name: str  # the user's own name

# System prompt for the profiling stage: it tells the model which four pieces
# of information to collect, to ask instead of guessing, and to call the bound
# tool only after the user has confirmed the collected values.
profile_system_prompt = '''Your role is to become a character who engages in conversation with the user.
To do this, you should collect the following information from the user:

- What the character's name is
- What universe(세계관, 영화, 게임 등) does the character belong to
- What the user's requirements are
- What the user's name is

If you cannot determine this information, ask the user directly to clarify — preferably using a bullet-point or structured format. Do not make assumptions.
Once you have all the necessary information, confirm it with the user one more time, and then call the relevant tool.'''



# Base chat model, and a variant with `Profile` bound as a tool; the profiling
# node invokes the tool-bound variant.
llm  = ChatOpenAI(model="gpt-4o-mini")
profiling_llm = llm.bind_tools([Profile])

### Web Search Tool

In [4]:
from langchain_community.tools.tavily_search import TavilySearchResults

# Tavily web-search tool shared by the profile search and the in-conversation
# web search nodes; configured with max_results=3.
web_search_tool = TavilySearchResults(max_results=3)


### Profile Search Query Generation Chain

In [5]:
from langchain_core.output_parsers import StrOutputParser
from langchain_core.prompts import ChatPromptTemplate

In [6]:
# System prompt asking the model to turn a character profile into one web
# search query and to output nothing but that query.
profile_search_system_prompt = """Your role is to investigate character information based on the details provided by the user.
Given a character's name and universe, generate the most effective and natural web search query to gather information about the character's background, personality, and dialogues.

You MUST output the search query only."""

# The user turn carries the collected profile through the `{profile}` variable.
profile_web_search_prompt = ChatPromptTemplate.from_messages(
    [
        ("system", profile_search_system_prompt),
        ("user", "Character profile:\n{profile}")
    ]
)

# prompt -> chat model -> plain string (the search query).
profile_web_search_chain = profile_web_search_prompt | llm | StrOutputParser()

### Define Vector DB and Retriever

In [7]:
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain_community.vectorstores import Chroma
from langchain_openai import OpenAIEmbeddings
from langchain_core.documents import Document

In [8]:


# Embedding model used to index and query the collected documents.
embd = OpenAIEmbeddings()

# Chroma collection that the search nodes add document chunks to.
# NOTE(review): no persist directory is passed here, so the collection is
# presumably not persisted across kernel restarts — confirm.
vectorstore = Chroma(
    collection_name="rag-chroma",
    embedding_function=embd
)

# Retriever over the collection, configured with k=5 per query.
doc_retriever = vectorstore.as_retriever(search_kwargs={"k": 5})

  vectorstore = Chroma(


### Define Retrieval Evaluation Chain

In [9]:

# Structured-output schema for the retrieval grader. The class docstring and
# the field description are left untouched because they are part of the schema
# handed to the model via `with_structured_output`.
class EvalDocuments(BaseModel):
    """
    Document evaluation class.
    The decision attribute can have a value of 'yes' or 'no' indicating the relevance of the document to the message.
    """
    decision: str = Field(description="Documents are relevant to message. This attribute can have a value of 'yes' or 'no'")
    

# `llm` is rebound to a temperature=0 model for grading. Chains built in
# earlier cells keep the model object they were constructed with.
llm = ChatOpenAI(model="gpt-4o-mini", temperature=0)
# Grader variant whose output is parsed into an `EvalDocuments` instance.
retrieval_llm_evaluator = llm.with_structured_output(EvalDocuments)

# Instructions for judging whether one retrieved document relates to the
# user's latest message.
retrieval_eval_system = """You are an evaluator responsible for assessing whether a retrieved document is relevant to the user's message.
The user is having a conversation with a character. If the document contains related keywords or is semantically connected to the user's message, evaluate it as relevant.
Your goal is to filter out incorrectly retrieved documents.
If the user's message is related to the document, output 'yes'; otherwise, output 'no'."""

# Template variables: {profile}, {document}, {message}.
retrieval_eval_prompt = ChatPromptTemplate.from_messages(
    [
        ("system", retrieval_eval_system),
        ("user", "Character Profile and User Name : {profile}\n\nRetrieved Document: {document}\n\nMessage: {message}")
    ]
)

# prompt -> structured grader; the result exposes a `.decision` attribute.
retrieval_evaluator = retrieval_eval_prompt | retrieval_llm_evaluator

### Define Web Search Query Generation Chain

In [10]:
# System prompt asking the model for a single web search query that would
# supply the knowledge needed to answer the user's message in character.
web_query_gen_system_prompt = """You are role-playing with user and acting as a given character. To respond to the user's messages, you need to perform web searches.
Your role is to generate an appropriate web search query to obtain the knowledge necessary to answer the user's messages.
Output the web search query — do not output anything else."""

# Template variables: {profile}, {message}.
web_query_gen_prompt = ChatPromptTemplate.from_messages(
    [
        ("system", web_query_gen_system_prompt),
        ("user", "Character Profile and User Name : {profile}\n\n  User's Message: {message}")
    ]
)

# `llm` is rebound again (no explicit temperature) for query generation.
llm = ChatOpenAI(model="gpt-4o-mini")
# prompt -> chat model -> plain string (the search query).
web_search_query_chain = web_query_gen_prompt | llm | StrOutputParser()

### Define Response Generation Chain

In [11]:
from langchain import hub

# rag_prompt = hub.pull("rlm/rag-prompt")

# Model used to produce the in-character reply.
llm = ChatOpenAI(model="gpt-4o-mini")

# System prompt for the role-play answer: use the retrieved context when
# helpful, admit ignorance, stay in character, and end with a follow-up question.
rag_system_prompt = """You are a character engaged in a roleplay with the user.
Respond to the user's messages based on the pieces of retrieved context provided.
If the context is not needed, you may answer without using it.
If you're asked something you don't know, simply say you don't know.
Pay close attention to the context of the conversation provided by the user, and respond in a way that stays true to the profile of the character you are roleplaying.
Always follow up your response with an appropriate question to keep the conversation going."""

# Template variables: {profile}, {context}, {messages}.
rag_user_prompt = """Character Profile and User Name : {profile}

Context: {context}

Conversations: {messages}

Answer:"""

rag_prompt = ChatPromptTemplate.from_messages(
    [
        ("system", rag_system_prompt),
        ("user", rag_user_prompt)
    ]
)

# prompt -> chat model -> plain string (the character's reply).
rag_chatbot_chain = rag_prompt | llm | StrOutputParser()

### Define Response Evaluation Chain

In [12]:
from langchain_core.messages import AIMessage, HumanMessage

# Structured-output schema for the response grader. The class docstring and
# the field description are left untouched because they are part of the schema
# handed to the model via `with_structured_output`.
class EvalResponse(BaseModel):
    """
    Response evaluation class.
    The decision attribute can have a value of 'yes' or 'no' indicating the relevance of the response to the conversation.
    """
    decision: str = Field(description="Response is relevant to message. This attribute can have a value of 'yes' or 'no'")
    
    
# temperature=0 model for grading the generated reply; its output is parsed
# into an `EvalResponse` instance.
llm = ChatOpenAI(model="gpt-4o-mini", temperature=0)
response_llm_evaluator = llm.with_structured_output(EvalResponse)

# Instructions for judging whether the drafted reply properly addresses the
# conversation, given the character profile and context.
response_eval_system = """You are a character engaged in a roleplay with the user.
Your role is to evaluate whether a response is appropriate to the conversation based on the character's profile and context.
Determine whether your response is appropriate to the conversation.
If your response properly addresses the issue or question in the conversation, output 'yes'; otherwise, output 'no'."""



# Template variables: {profile}, {context}, {messages}, {response}.
response_eval_user_prompt = """Character Profile and User Name : {profile}
Context: {context}

Conversations: {messages}

Your Response: {response}

Decision:"""

# Build the grader prompt with the `from_messages` factory, like every other
# prompt in this notebook, instead of calling the bare constructor with a
# positional list of (role, template) tuples.
response_eval_prompt = ChatPromptTemplate.from_messages(
    [
        ("system", response_eval_system),
        ("user", response_eval_user_prompt)
    ]
)

# prompt -> structured grader; the result exposes a `.decision` attribute.
response_eval_chain = response_eval_prompt | response_llm_evaluator

### Define Query Rewriter Chain

In [13]:
# Model used to rewrite a user question after a rejected answer.
llm = ChatOpenAI(model="gpt-4o-mini")

# System prompt: rewrite the incoming query so it works better for vector
# store retrieval or web search.
system = """You are a question rewriter that refines input queries to enhance their effectiveness for vector store retrieval or web searching by capturing their underlying semantic intent."""

# Template variable: {question}.
re_write_prompt = ChatPromptTemplate.from_messages(
    [
        ("system", system),
        ("user","Here is the original question:\n{question}\n\nPlease rewrite it to improve clarity and optimize it for retrieval."),
    ]
)

# No output parser is attached, so invoking this chain yields the chat
# model's message object rather than a plain string.
question_rewriter = re_write_prompt | llm

### Define Graph State

In [14]:
from typing import Annotated, List
from langgraph.graph.message import add_messages
from langgraph.graph import START, END

class ConversationState(TypedDict):
    """State passed between the nodes of the persona chat graph."""
    # Conversation history, merged through the `add_messages` reducer.
    messages: Annotated[list, add_messages]
    # Documents set by the retriever and narrowed/extended by later nodes.
    documents: List
    # Web search queries issued so far; also merged through `add_messages`.
    web_query: Annotated[list, add_messages]
    # Set by `evaluate_documents`; read by `decide_generation`.
    web_search_flag: bool
    # Draft reply written by `generate`.
    generation: str
    # Set by `evaluate_generation`; read by `decide_reponse`.
    retry_flag: bool
    # Number of rejected drafts for the current turn.
    retries:int
    # Tool-call arguments captured once profiling is complete.
    profile: dict

### Define Node and Routing Functions

In [15]:
def get_profile_messages(messages):
    """Return a new list with the profiling system prompt placed first.

    The result is one ``SystemMessage`` carrying ``profile_system_prompt``
    followed by the entries of ``messages``; the input list is not modified.
    """
    system_message = SystemMessage(content=profile_system_prompt)
    return [system_message] + messages


def profiling(state):
    """Graph node: run one turn of the profile-collection dialogue.

    Sends the conversation (prefixed with the profiling system prompt) to the
    tool-bound LLM and returns its reply as a ``messages`` update.
    """
    print("--- Profiling Node ---")
    reply = profiling_llm.invoke(get_profile_messages(state["messages"]))
    return {"messages": [reply]}


def route_message(state):
    """Entry router: choose between profile collection and normal chat.

    Returns ``"retriever"`` when the state holds a non-empty ``profile``,
    otherwise ``"profiling"``.
    """
    print("--- Routing Node ---")

    if not state.get("profile"):
        # No profile yet: keep collecting information from the user.
        print("\t-- Route To Profiling --")
        return "profiling"

    # Profile already collected: start the conversation flow.
    print("\t-- Route Message To Retriever --")
    return "retriever"


def check_profiling(state):
    """Router after the profiling node.

    Returns ``"profile_web_search"`` when the latest message is an
    ``AIMessage`` carrying tool calls (the profile is complete), otherwise
    ``"insufficient"``.
    """
    print("--- Check Profiling Node ---")
    last_message = state["messages"][-1]

    tool_called = isinstance(last_message, AIMessage) and last_message.tool_calls
    if not tool_called:
        print("\t-- Profile Information Insufficient --")
        return "insufficient"

    # A tool call means every profile field has been collected.
    print("\t-- Route Message To Profile Web Search --")
    return "profile_web_search"
        
        
    
def profile_web_search(state):
    """Graph node: finish profiling and seed the vector store.

    Reads the profile from the arguments of the first tool call on the latest
    message, generates a web search query for it, chunks the search results
    and adds them to ``vectorstore``.

    Returns a state update that
    - removes the tool-call message from the history via ``RemoveMessage``
      (instead of mutating the state's list in place) and appends a
      confirmation ``AIMessage``,
    - stores the profile, records the search query, and resets ``retries``.
    """
    # Local import keeps this block self-contained; the package is already a
    # dependency of this notebook.
    from langchain_core.messages import RemoveMessage

    print("--- Profile Web Search Node ---")
    tool_call_message = state["messages"][-1]
    profile = tool_call_message.tool_calls[0]["args"]

    # Search the web for background on the character.
    profile_search_query = profile_web_search_chain.invoke({"profile": profile})
    search_results = web_search_tool.invoke(profile_search_query)

    # Chunk the collected documents.
    docs = [Document(page_content=result["content"]) for result in search_results]
    splitter = RecursiveCharacterTextSplitter.from_tiktoken_encoder(
        chunk_size=550, chunk_overlap=50
    )
    doc_splits = splitter.split_documents(docs)

    # Store in the vector DB. Storing stays best-effort, but a failure is now
    # reported instead of being silently swallowed by a bare `except`.
    try:
        vectorstore.add_documents(doc_splits)
    except Exception as exc:
        print(f"\t-- Failed To Store Profile Documents: {exc!r} --")

    return {
        "messages": [
            # Drop the tool-call message from the history.
            RemoveMessage(id=tool_call_message.id),
            AIMessage(content="정보 수집 완료. 대화를 시작해주세요!"),
        ],
        "profile": profile,
        "web_query": profile_search_query,
        # Reset the query-rewrite counter for the upcoming conversation.
        "retries": 0,
    }


def retriever(state):
    """Graph node: fetch documents for the latest message.

    Queries ``doc_retriever`` with the content of the last message, stores the
    result under ``state["documents"]`` and returns the same state object.
    """
    print("--- Retriever Node ---")
    query_text = state["messages"][-1].content
    state["documents"] = doc_retriever.invoke(query_text)
    return state
    

def evaluate_documents(state):
    """Graph node: keep only the retrieved documents relevant to the message.

    Each document in ``state["documents"]`` is graded by
    ``retrieval_evaluator``. Relevant documents are kept; ``web_search_flag``
    is set when at least one document is rejected or when no relevant
    document remains at all (including an empty retrieval), so the graph falls
    back to web search instead of generating without any context.

    Returns the same state object with ``documents`` and ``web_search_flag``
    updated.
    """
    print("--- Evaluate Document Node ---")
    user_message = state["messages"][-1]
    # Grade against the message text, not the repr of the message object.
    message_text = getattr(user_message, "content", user_message)
    profile = state["profile"]

    filtered_docs = []
    web_search_flag = False

    for doc in state["documents"]:
        evaluation = retrieval_evaluator.invoke(
            {"profile": profile, "document": doc.page_content, "message": message_text}
        )
        # Tolerate casing/whitespace variations such as "Yes".
        if evaluation.decision.strip().lower() == "yes":
            print("\t-- Document Is Relevant To Message--")
            filtered_docs.append(doc)
        else:
            print("\t-- Document Is Not Relevant To Message--")
            web_search_flag = True

    if not filtered_docs:
        # Nothing usable was retrieved: web search is the only source left.
        web_search_flag = True

    state["documents"] = filtered_docs
    state["web_search_flag"] = web_search_flag

    return state


def decide_generation(state):
    """Router after document grading.

    Returns ``"web_search"`` when ``state["web_search_flag"]`` is truthy,
    otherwise ``"generate"``.
    """
    print("--- Decide Generation ---")

    if not state["web_search_flag"]:
        print("\t-- Route To Generate ---")
        return "generate"

    print("\t-- Route To Web Search --")
    return "web_search"


def web_search(state):
    """Graph node: extend the context with fresh web search results.

    Generates a search query from the profile and the text of the latest
    message, runs the search, chunks the results, and

    - records the query under ``state["web_query"]``,
    - tries to add the chunks to ``vectorstore`` (best-effort; a failure is
      reported, not swallowed),
    - appends the chunks to ``state["documents"]`` so they are available as
      context for this turn even if storing them failed.

    New lists are built instead of mutating the lists held in the incoming
    state. Returns the same state object.
    """
    print("--- Web Search Node ---")
    user_message = state["messages"][-1]
    # Send the message text to the prompt, not the repr of the message object.
    message_text = getattr(user_message, "content", user_message)
    profile = state["profile"]

    web_search_query = web_search_query_chain.invoke({"profile": profile, "message": message_text})
    search_results = web_search_tool.invoke(web_search_query)

    state["web_query"] = [*state.get("web_query", []), web_search_query]

    docs = [Document(page_content=result["content"]) for result in search_results]
    splitter = RecursiveCharacterTextSplitter.from_tiktoken_encoder(
        chunk_size=550, chunk_overlap=50
    )
    doc_splits = splitter.split_documents(docs)

    try:
        vectorstore.add_documents(doc_splits)
    except Exception as exc:
        print(f"\t-- Failed To Store Web Documents: {exc!r} --")

    state["documents"] = [*state["documents"], *doc_splits]

    return state

def generate(state):
    """Graph node: draft the character's reply.

    Invokes ``rag_chatbot_chain`` with the profile, the current documents as
    context and the conversation, stores the result under
    ``state["generation"]`` and returns the same state object.
    """
    print("--- Generate ---")
    conversation = state["messages"]
    character_profile = state["profile"]
    retrieved_context = state["documents"]

    chain_inputs = {
        "profile": character_profile,
        "context": retrieved_context,
        "messages": conversation,
    }
    state["generation"] = rag_chatbot_chain.invoke(chain_inputs)

    return state

def evaluate_generation(state):
    """Graph node: accept the drafted reply or request a retry.

    The draft in ``state["generation"]`` is accepted when the retry budget
    (3) is exhausted, or when ``response_eval_chain`` grades it ``"yes"``.

    - Accepted: the draft is added to the history as an ``AIMessage`` (it is
      the assistant's turn, so it must not be stored as a human message),
      ``retry_flag`` is cleared and ``retries`` is reset to 0.
    - Rejected: ``retry_flag`` is set, ``retries`` is incremented and the
      draft is removed from the returned state.

    Returns the same state object.
    """
    print("--- Evaluate Generation ---")
    messages = state["messages"]
    generation = state["generation"]

    if state["retries"] >= 3:
        # Retry budget exhausted: answer as-is to avoid an infinite loop.
        print("\t-- Reached The Maximum Number Of Retries. --")
        accepted = True
    else:
        evaluation = response_eval_chain.invoke(
            {
                "profile": state["profile"],
                "context": state["documents"],
                "messages": messages,
                "response": generation,
            }
        )
        accepted = evaluation.decision == "yes"
        if accepted:
            print("\t-- Response Addresses The Conversation --")
        else:
            print("\t-- Response Does Not Address The Conversation --")

    if accepted:
        # Build a new list rather than appending to the state's list in place.
        state["messages"] = [*messages, AIMessage(content=generation)]
        state["retry_flag"] = False
        state["retries"] = 0
    else:
        state["retry_flag"] = True
        state["retries"] += 1
        del state["generation"]

    return state
        
        
def decide_reponse(state):
    """Router after response grading.

    Returns ``"rewrite_query"`` when ``state['retry_flag']`` is truthy,
    otherwise ``"response"``.
    """
    print("--- Decide To Response ---")

    if state['retry_flag']:
        print("\t-- Rewrite query --")
        return "rewrite_query"

    print("\t-- Response --")
    return "response"
    
def rewrite_query(state):
    """Graph node: replace the latest user message with a rewritten question.

    The text of the last message is sent to ``question_rewriter``. The
    rewritten text is wrapped in a ``HumanMessage`` that reuses the original
    message id, so the user's turn stays a human turn (the rewriter itself
    returns a chat-model message) and the history entry is replaced rather
    than duplicated.

    Returns the same state object with ``messages`` set to a new list.
    """
    print("--- Rewrite User Question ---")
    messages = state["messages"]
    user_message = messages[-1]

    # Send the question text, not the repr of the message object.
    question_text = getattr(user_message, "content", user_message)
    rewritten = question_rewriter.invoke({"question": question_text})
    # Accept either a message-like object or a plain string from the chain.
    rewritten_text = getattr(rewritten, "content", rewritten)

    new_user_message = HumanMessage(
        content=rewritten_text,
        id=getattr(user_message, "id", None),
    )
    state["messages"] = [*messages[:-1], new_user_message]

    return state

### Define Graph

In [16]:
from langgraph.graph import StateGraph, START, END
from langgraph.checkpoint.memory import MemorySaver

# Assemble the persona chat graph over `ConversationState`.
PersonaGraph = StateGraph(ConversationState)

# Register the node functions; each node is registered under its function name.
PersonaGraph.add_node(profiling)
PersonaGraph.add_node(profile_web_search)
PersonaGraph.add_node(retriever)
PersonaGraph.add_node(evaluate_documents)
PersonaGraph.add_node(web_search)
PersonaGraph.add_node(generate)
PersonaGraph.add_node(evaluate_generation)
PersonaGraph.add_node(rewrite_query)

# Entry: go to retrieval when a profile exists, otherwise keep profiling.
PersonaGraph.add_conditional_edges(
    START,
    route_message,
    {
        "retriever": "retriever",
        "profiling": "profiling",
    }
)



# After profiling: a tool call leads to the profile web search; otherwise the
# run ends and waits for the user's next message.
PersonaGraph.add_conditional_edges(
    "profiling",
    check_profiling,
    {
        "profile_web_search": "profile_web_search",
        "insufficient": END
    }
)

# The profile search ends the run; the conversation starts on the next input.
PersonaGraph.add_edge("profile_web_search", END)

PersonaGraph.add_edge("retriever", "evaluate_documents")

# After grading: generate directly, or search the web first.
PersonaGraph.add_conditional_edges(
    "evaluate_documents",
    decide_generation,
    {
        "generate": "generate",
        "web_search": "web_search",
    }
)


PersonaGraph.add_edge("web_search", "generate")
PersonaGraph.add_edge("generate", "evaluate_generation")

# After grading the reply: finish, or rewrite the query and retrieve again.
PersonaGraph.add_conditional_edges(
    "evaluate_generation",
    decide_reponse,
    {
        "response": END,
        "rewrite_query": "rewrite_query"
    }
)

PersonaGraph.add_edge("rewrite_query", "retriever")

# Compile with a MemorySaver checkpointer; the test loop below passes a
# thread_id in its config when streaming.
memory = MemorySaver()
app = PersonaGraph.compile(checkpointer=memory)

### Graph Visualization

In [None]:
from IPython.display import Image, display
# Draw the compiled graph (xray=True) as a Mermaid PNG and show it inline.
display(Image(app.get_graph(xray=True).draw_mermaid_png()))

### Test

In [18]:
# Interactive console loop for trying out the graph. All turns share the same
# thread_id in the config passed to `app.stream`.
config = {"configurable": {"thread_id": "12"}}
while True:
    user_input = input("User: ")
    print("User:", user_input)
    # Typing quit / exit / q (any casing) ends the session.
    if user_input.lower() in ["quit", "exit", "q"]:
        print("Goodbye!")
        break
    final_response = None
    # Every value yielded by the stream is expected to carry a "messages"
    # list; the last message of the last value is what gets printed.
    for event in app.stream({"messages": [HumanMessage(content=user_input)]}, config):
        for value in event.values():
            final_response = value["messages"][-1].content
            
    print("Assistant:", final_response)

User: 안녕
--- Routing Node ---
	-- Route To Profiling --
--- Profiling Node ---
--- Check Profiling Node ---
	-- Profile Information Insufficient --
Assistant: 안녕하세요! 어떤 캐릭터와 이야기를 하고 싶은지 말씀해 주시면 도와드릴게요. 다음 정보를 제공해 주실 수 있나요?

- 캐릭터의 이름은 무엇인가요?
- 그 캐릭터는 어떤 세계관(영화, 게임 등)에 속하나요?
- 어떤 요구사항이 있나요?
- 당신의 이름은 무엇인가요? 

이 정보를 주시면 더욱 원활하게 도와드릴 수 있습니다!
User: 지우, 포켓몬스터, 요구사항은 딱히 없어. 내이름은 여행자야.
--- Routing Node ---
	-- Route To Profiling --
--- Profiling Node ---
--- Check Profiling Node ---
	-- Profile Information Insufficient --
Assistant: 감사합니다, 여행자님! 이제 정보를 정리해볼게요:

- 캐릭터 이름: 지우
- 세계관: 포켓몬스터
- 요구사항: 딱히 없음
- 당신의 이름: 여행자

이게 맞나요? 확인해 주시면, 다음 단계로 진행할게요!
User: 시작하자!
--- Routing Node ---
	-- Route To Profiling --
--- Profiling Node ---
--- Check Profiling Node ---
	-- Route Message To Profile Web Search --
--- Profile Web Search Node ---
Assistant: 정보 수집 완료. 대화를 시작해주세요!
User: 지우야 너의 최애 포켓몬은 누구야?
--- Routing Node ---
	-- Route Message To Retriever --
--- Retriever Node ---
--- Evaluate Document Node ---