Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
133 changes: 80 additions & 53 deletions core/shared/router/intent_router.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from langchain_core.messages import SystemMessage, AIMessage
import re
from typing import List, Optional

from core.shared.states.states import CustomsAgentState
from core.shared.utils.llm import get_llm
Expand All @@ -13,80 +14,106 @@
SESSION_CHECK_MESSAGE_COUNT
)

def intent_router(state: CustomsAgentState) -> CustomsAgentState:
"""사용자 쿼리의 의도를 분류합니다."""

# 현재 쿼리가 숫자 선택인지 확인
current_query = state["query"].strip()
is_number_selection = False

# 숫자 선택 패턴 확인 (1번, 2번, 3번, 1, 2, 3 등)
for pattern in NUMBER_SELECTION_PATTERNS:
if re.match(pattern, current_query):
is_number_selection = True
break

# 질문 형태 감지
is_question = any(re.search(pattern, current_query) for pattern in QUESTION_PATTERNS)

# 이전 대화에서 관세 예측 중인지 확인
messages = state.get("messages", [])
is_in_tariff_session = False


def _is_number_selection(query: str) -> bool:
"""숫자 선택 패턴인지 확인합니다."""
return any(re.match(pattern, query) for pattern in NUMBER_SELECTION_PATTERNS)


def _is_question(query: str) -> bool:
"""질문 형태인지 확인합니다."""
return any(re.search(pattern, query) for pattern in QUESTION_PATTERNS)


def _is_in_tariff_session(state: CustomsAgentState) -> bool:
"""관세 예측 세션 중인지 확인합니다."""
# 1. state의 intent 필드 확인 (가장 중요)
if state.get("intent") == "tariff_prediction":
is_in_tariff_session = True
return True

messages = state.get("messages", [])
if not messages:
return False

# 2. 최근 메시지들을 확인하여 관세 예측 세션 중인지 판단
for msg in messages[-SESSION_CHECK_MESSAGE_COUNT:]: # 최근 메시지 확인
recent_messages = messages[-SESSION_CHECK_MESSAGE_COUNT:]
for msg in recent_messages:
if hasattr(msg, 'content') and isinstance(msg.content, str):
content = msg.content.lower()
# 관세 예측 관련 키워드가 있거나 HS 코드 선택 메시지가 있으면 관세 예측 세션으로 판단
if any(keyword in content for keyword in TARIFF_SESSION_KEYWORDS):
is_in_tariff_session = True
break
return True

# 3. 마지막 메시지에서 세션 상태 확인
last_msg = messages[-1]
if hasattr(last_msg, 'content') and isinstance(last_msg.content, str):
content = last_msg.content.lower()
if 'tariff_prediction' in content or '관세 예측 세션' in content:
return True

return False


def _classify_with_llm(query: str) -> str:
"""LLM을 사용하여 의도를 분류합니다."""
try:
llm = get_llm()
result = llm.invoke([
SystemMessage(content=INTENT_CLASSIFICATION_PROMPT.format(query=query))
])

intent = str(result.content).strip()
return intent if intent in INTENT_TYPES else DEFAULT_INTENT

except Exception:
# LLM 분류 실패 시 기본값 사용
return DEFAULT_INTENT


def _add_classification_message(state: CustomsAgentState, intent: str, reason: str) -> None:
"""의도 분류 완료 메시지를 추가합니다."""
state["messages"].append(
AIMessage(content=f"의도 분류 완료: {intent} ({reason})")
)


def intent_router(state: CustomsAgentState) -> CustomsAgentState:
"""사용자 쿼리의 의도를 분류합니다."""

# 3. 이전 의도가 tariff_prediction이었는지도 확인
if messages and len(messages) > 0:
last_msg = messages[-1]
if hasattr(last_msg, 'content') and isinstance(last_msg.content, str):
if 'tariff_prediction' in last_msg.content or '관세 예측 세션' in last_msg.content:
is_in_tariff_session = True
current_query = state["query"].strip()
if not current_query:
state["intent"] = DEFAULT_INTENT
_add_classification_message(state, DEFAULT_INTENT, "빈 쿼리")
return state

# 세션 연속성 확인
is_in_tariff_session = _is_in_tariff_session(state)

# 관세 예측 세션 중이면 무조건 tariff_prediction으로 분류
if is_in_tariff_session:
state["intent"] = "tariff_prediction" # type: ignore
state["messages"].append(AIMessage(content=f"의도 분류 완료: {state['intent']} (관세 예측 세션 연속성 유지)"))
state["intent"] = "tariff_prediction"
_add_classification_message(state, "tariff_prediction", "관세 예측 세션 연속성 유지")
return state

# 패턴 기반 분류
is_number_selection = _is_number_selection(current_query)
is_question = _is_question(current_query)

# 질문 형태이면서 관세 예측 세션이 아닌 경우 QnA로 분류
if is_question and not is_in_tariff_session:
state["intent"] = "qna" # type: ignore
state["messages"].append(AIMessage(content=f"의도 분류 완료: {state['intent']} (질문 형태 감지)"))
if is_question:
state["intent"] = "qna"
_add_classification_message(state, "qna", "질문 형태 감지")
return state

# 숫자 선택이지만 관세 예측 세션이 아닌 경우에도 tariff_prediction으로 분류
# (HS 코드 직접 입력 등의 경우)
if is_number_selection:
state["intent"] = "tariff_prediction" # type: ignore
state["messages"].append(AIMessage(content=f"의도 분류 완료: {state['intent']} (숫자 선택 감지)"))
state["intent"] = "tariff_prediction"
_add_classification_message(state, "tariff_prediction", "숫자 선택 감지")
return state

# LLM 기반 의도 분류를 수행
llm = get_llm()

try:
result = llm.invoke([
SystemMessage(content=INTENT_CLASSIFICATION_PROMPT.format(query=state["query"]))
])

intent = str(result.content).strip()
if intent not in INTENT_TYPES:
intent = DEFAULT_INTENT # 기본값을 tariff_prediction으로 변경
except Exception as e:
# LLM 분류 실패 시 기본값 사용
intent = DEFAULT_INTENT # 오류 시에도 tariff_prediction으로 분류
intent = _classify_with_llm(current_query)
state["intent"] = intent
_add_classification_message(state, intent, "LLM 분류")

state["intent"] = intent # type: ignore
state["messages"].append(AIMessage(content=f"의도 분류 완료: {intent} (LLM 분류)"))
return state