diff --git a/interface/lang2sql.py b/interface/lang2sql.py index 973831a..199f3a7 100644 --- a/interface/lang2sql.py +++ b/interface/lang2sql.py @@ -7,10 +7,11 @@ import streamlit as st from langchain.chains.sql_database.prompt import SQL_PROMPTS -from langchain_core.messages import HumanMessage +from langchain_core.messages import AIMessage, HumanMessage from llm_utils.connect_db import ConnectDB from llm_utils.graph import builder +from llm_utils.llm_response_parser import LLMResponseParser DEFAULT_QUERY = "고객 데이터를 기반으로 유니크한 유저 수를 카운트하는 쿼리" SIDEBAR_OPTIONS = { @@ -51,18 +52,27 @@ def execute_query( device: str = "cpu", ) -> dict: """ - Lang2SQL 그래프를 실행하여 자연어 쿼리를 SQL 쿼리로 변환하고 결과를 반환합니다. + 자연어 쿼리를 SQL로 변환하고 실행 결과를 반환하는 Lang2SQL 그래프 인터페이스 함수입니다. + + 이 함수는 Lang2SQL 파이프라인(graph)을 세션 상태에서 가져오거나 새로 컴파일한 뒤, + 사용자의 자연어 질문을 SQL 쿼리로 변환하고 관련 메타데이터와 함께 결과를 반환합니다. + 내부적으로 LangChain의 `graph.invoke` 메서드를 호출합니다. Args: - query (str): 자연어로 작성된 사용자 쿼리. - database_env (str): 사용할 데이터베이스 환경 설정 이름. - retriever_name (str): 사용할 검색기 이름. - top_n (int): 검색할 테이블 정보의 개수. + query (str): 사용자가 입력한 자연어 기반 질문. + database_env (str): 사용할 데이터베이스 환경 이름 또는 키 (예: "dev", "prod"). + retriever_name (str, optional): 테이블 검색기 이름. 기본값은 "기본". + top_n (int, optional): 검색된 상위 테이블 수 제한. 기본값은 5. + device (str, optional): LLM 실행에 사용할 디바이스 ("cpu" 또는 "cuda"). 기본값은 "cpu". Returns: - dict: 변환된 SQL 쿼리 및 관련 메타데이터를 포함하는 결과 딕셔너리. + dict: 다음 정보를 포함한 Lang2SQL 실행 결과 딕셔너리: + - "generated_query": 생성된 SQL 쿼리 (`AIMessage`) + - "messages": 전체 LLM 응답 메시지 목록 + - "refined_input": AI가 재구성한 입력 질문 + - "searched_tables": 참조된 테이블 목록 등 추가 정보 """ - # 세션 상태에서 그래프 가져오기 + graph = st.session_state.get("graph") if graph is None: graph = builder.compile() @@ -102,22 +112,71 @@ def display_result( - 참조된 테이블 목록 - 쿼리 실행 결과 테이블 """ - total_tokens = summarize_total_tokens(res["messages"]) - - if st.session_state.get("show_total_token_usage", True): - st.write("총 토큰 사용량:", total_tokens) - if st.session_state.get("show_sql", True): - st.write("결과:", "\n\n```sql\n" + res["generated_query"].content + "\n```") - if st.session_state.get("show_result_description", True): - st.write("결과 설명:\n\n", res["messages"][-1].content) - if st.session_state.get("show_question_reinterpreted_by_ai", True): - st.write("AI가 재해석한 사용자 질문:\n", res["refined_input"].content) - if st.session_state.get("show_referenced_tables", True): - st.write("참고한 테이블 목록:", res["searched_tables"]) - if st.session_state.get("show_table", True): - sql = res["generated_query"] - df = database.run_sql(sql) - st.dataframe(df.head(10) if len(df) > 10 else df) + + def should_show(_key: str) -> bool: + st.markdown("---") + return st.session_state.get(_key, True) + + if should_show("show_total_token_usage"): + total_tokens = summarize_total_tokens(res["messages"]) + st.write("**총 토큰 사용량:**", total_tokens) + + if should_show("show_sql"): + generated_query = res.get("generated_query") + query_text = ( + generated_query.content + if isinstance(generated_query, AIMessage) + else str(generated_query) + ) + + try: + sql = LLMResponseParser.extract_sql(query_text) + st.markdown("**생성된 SQL 쿼리:**") + st.code(sql, language="sql") + except ValueError: + st.warning("SQL 블록을 추출할 수 없습니다.") + st.text(query_text) + + interpretation = LLMResponseParser.extract_interpretation(query_text) + if interpretation: + st.markdown("**결과 해석:**") + st.code(interpretation) + + if should_show("show_result_description"): + st.markdown("**결과 설명:**") + result_message = res["messages"][-1].content + + try: + sql = LLMResponseParser.extract_sql(result_message) + st.code(sql, language="sql") + except ValueError: + st.warning("SQL 블록을 추출할 수 없습니다.") + st.text(result_message) + + interpretation = LLMResponseParser.extract_interpretation(result_message) + if interpretation: + st.code(interpretation, language="plaintext") + + if should_show("show_question_reinterpreted_by_ai"): + st.markdown("**AI가 재해석한 사용자 질문:**") + st.code(res["refined_input"].content) + + if should_show("show_referenced_tables"): + st.markdown("**참고한 테이블 목록:**") + st.write(res.get("searched_tables", [])) + + if should_show("show_table"): + try: + sql_raw = ( + res["generated_query"].content + if isinstance(res["generated_query"], AIMessage) + else str(res["generated_query"]) + ) + sql = LLMResponseParser.extract_sql(sql_raw) + df = database.run_sql(sql) + st.dataframe(df.head(10) if len(df) > 10 else df) + except Exception as e: + st.error(f"쿼리 실행 중 오류 발생: {e}") db = ConnectDB() diff --git a/llm_utils/llm_response_parser.py b/llm_utils/llm_response_parser.py new file mode 100644 index 0000000..df94924 --- /dev/null +++ b/llm_utils/llm_response_parser.py @@ -0,0 +1,57 @@ +""" +LLM 응답 텍스트에서 특정 마크업 태그(``, `<해석>`)에 포함된 콘텐츠 블록을 추출하는 유틸리티 모듈입니다. + +이 모듈은 OpenAI, LangChain 등에서 생성된 LLM 응답 문자열에서 Markdown 코드 블록을 파싱하여, +SQL 쿼리 및 자연어 해석 설명을 분리하여 사용할 수 있도록 정적 메서드 형태의 API를 제공합니다. + +지원되는 태그: + - : SQL 코드 블록 (```sql ... ```) + - <해석>: 자연어 해석 블록 (```plaintext ... ```) +""" + +import re + + +class LLMResponseParser: + """ + LLM 응답 문자열에서 특정 태그(, <해석>)에 포함된 블록을 추출하는 유틸리티 클래스입니다. + + 주요 기능: + - 태그 내 ```sql ... ``` 블록에서 SQL 쿼리 추출 + - <해석> 태그 내 ```plaintext ... ``` 블록에서 자연어 해석 추출 + """ + + @staticmethod + def extract_sql(text: str) -> str: + """ + 태그 내부의 SQL 코드 블록만 추출합니다. + + Args: + text (str): 전체 LLM 응답 문자열. + + Returns: + str: SQL 쿼리 문자열 (```sql ... ``` 내부 텍스트). + + Raises: + ValueError: 태그 또는 SQL 코드 블록을 찾을 수 없는 경우. + """ + match = re.search(r"\s*```sql\n(.*?)```", text, re.DOTALL) + if match: + return match.group(1).strip() + raise ValueError("SQL 블록을 추출할 수 없습니다.") + + @staticmethod + def extract_interpretation(text: str) -> str: + """ + <해석> 태그 내부의 해석 설명 텍스트만 추출합니다. + + Args: + text (str): 전체 LLM 응답 문자열. + + Returns: + str: 해석 설명 텍스트. 블록이 존재하지 않으면 빈 문자열을 반환합니다. + """ + match = re.search(r"<해석>\s*```plaintext\n(.*?)```", text, re.DOTALL) + if match: + return match.group(1).strip() + return ""