diff --git a/interface/core/result_renderer.py b/interface/core/result_renderer.py index e085735..51886b8 100644 --- a/interface/core/result_renderer.py +++ b/interface/core/result_renderer.py @@ -13,7 +13,7 @@ from infra.observability.token_usage import TokenUtils from utils.databases import DatabaseFactory from utils.llm.llm_response_parser import LLMResponseParser -from viz.display_chart import DisplayChart +from utils.visualization.display_chart import DisplayChart def display_result(res: dict) -> None: diff --git a/pyproject.toml b/pyproject.toml index 573955c..3147fd9 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -70,7 +70,6 @@ packages = [ "interface", "engine", "infra", - "viz", "prompt", "utils", ] diff --git a/utils/llm/README.md b/utils/llm/README.md index 2113d38..f52b406 100644 --- a/utils/llm/README.md +++ b/utils/llm/README.md @@ -5,7 +5,7 @@ Lang2SQL 파이프라인에서 LLM, 검색(RAG), 그래프 워크플로우, DB ### Depth 0: 최상위 유틸리티 - (Moved) `engine/query_executor.py`: Lang2SQL 그래프 선택/컴파일/실행 진입점. -- (Moved) `viz/display_chart.py`: LLM 활용 Plotly 시각화 유틸. +- (Moved) `utils/visualization/display_chart.py`: LLM 활용 Plotly 시각화 유틸. - (Moved) `infra/monitoring/check_server.py`: GMS 헬스체크. - (Moved) `infra/db/connect_db.py`: ClickHouse 연결/실행. - (Moved) `infra/observability/token_usage.py`: LLM 메시지의 `usage_metadata` 합산 토큰 집계. @@ -89,7 +89,7 @@ sql = extract_sql_from_result(res) ``` ```python -from viz.display_chart import DisplayChart +from utils.visualization.display_chart import DisplayChart chart = DisplayChart(question="지난달 매출 추이", sql=sql, df_metadata=str(df.dtypes)) code = chart.generate_plotly_code() @@ -103,5 +103,3 @@ fig = chart.get_plotly_figure(code, df) - `display_chart.py` → OpenAI LLM(선택적)로 코드 생성 → Plotly 실행 - `connect_db.py` → ClickHouse 클라이언트로 SQL 실행 - `llm_response_parser.py` → 결과 파서 - - diff --git a/viz/display_chart.py b/utils/visualization/display_chart.py similarity index 63% rename from viz/display_chart.py rename to utils/visualization/display_chart.py index 47fd459..0a977d6 100644 --- a/viz/display_chart.py +++ b/utils/visualization/display_chart.py @@ -1,12 +1,20 @@ -import re -from langchain_openai import ChatOpenAI -from langchain_core.messages import HumanMessage, SystemMessage -import pandas as pd +""" +SQL 쿼리 결과를 Plotly로 시각화하는 모듈 + +이 모듈은 Lang2SQL 실행 결과를 다양한 형태의 차트로 시각화하는 기능을 제공합니다. +LLM을 활용하여 적절한 Plotly 코드를 생성하고 실행합니다. +""" + import os +import re +from typing import Optional +import pandas as pd import plotly import plotly.express as px import plotly.graph_objects as go +from langchain_core.messages import HumanMessage, SystemMessage +from langchain_openai import ChatOpenAI class DisplayChart: @@ -17,12 +25,29 @@ class DisplayChart: plotly코드를 출력하여 excute한 결과를 fig 객체로 반환합니다. """ - def __init__(self, question, sql, df_metadata): + def __init__(self, question: str, sql: str, df_metadata: str): + """ + DisplayChart 인스턴스를 초기화합니다. + + Args: + question (str): 사용자 질문 + sql (str): 실행된 SQL 쿼리 + df_metadata (str): 데이터프레임 메타데이터 + """ self.question = question self.sql = sql self.df_metadata = df_metadata - def llm_model_for_chart(self, message_log): + def llm_model_for_chart(self, message_log) -> Optional[str]: + """ + LLM 모델을 사용하여 차트 생성 코드를 생성합니다. + + Args: + message_log: LLM에 전달할 메시지 목록 + + Returns: + Optional[str]: 생성된 차트 코드 또는 None + """ provider = os.getenv("LLM_PROVIDER") if provider == "openai": llm = ChatOpenAI( @@ -31,18 +56,29 @@ def llm_model_for_chart(self, message_log): ) result = llm.invoke(message_log) return result + return None def _extract_python_code(self, markdown_string: str) -> str: + """ + 마크다운 문자열에서 Python 코드 블록을 추출합니다. + + Args: + markdown_string: 마크다운 형식의 문자열 + + Returns: + str: 추출된 Python 코드 + """ # Strip whitespace to avoid indentation errors in LLM-generated code - markdown_string = markdown_string.content.split("```")[1][6:].strip() + if hasattr(markdown_string, "content"): + markdown_string = markdown_string.content.split("```")[1][6:].strip() + else: + markdown_string = str(markdown_string) # Regex pattern to match Python code blocks - pattern = r"```[\w\s]*python\n([\s\S]*?)```|```([\s\S]*?)```" # 여러 문자와 공백 뒤에 python이 나오고, 줄바꿈 이후의 모든 내용 + pattern = r"```[\w\s]*python\n([\s\S]*?)```|```([\s\S]*?)```" # Find all matches in the markdown string - matches = re.findall( - pattern, markdown_string, re.IGNORECASE - ) # 대소문자 구분 안함 + matches = re.findall(pattern, markdown_string, re.IGNORECASE) # Extract the Python code from the matches python_code = [] @@ -55,13 +91,27 @@ def _extract_python_code(self, markdown_string: str) -> str: return python_code[0] - def _sanitize_plotly_code(self, raw_plotly_code): + def _sanitize_plotly_code(self, raw_plotly_code: str) -> str: + """ + Plotly 코드에서 불필요한 부분을 제거합니다. + + Args: + raw_plotly_code: 원본 Plotly 코드 + + Returns: + str: 정리된 Plotly 코드 + """ # Remove the fig.show() statement from the plotly code plotly_code = raw_plotly_code.replace("fig.show()", "") - return plotly_code def generate_plotly_code(self) -> str: + """ + LLM을 사용하여 Plotly 코드를 생성합니다. + + Returns: + str: 생성된 Plotly 코드 + """ if self.question is not None: system_msg = f"The following is a pandas DataFrame that contains the results of the query that answers the question the user asked: '{self.question}'" else: @@ -82,20 +132,33 @@ def generate_plotly_code(self) -> str: ] plotly_code = self.llm_model_for_chart(message_log) + if plotly_code is None: + return "" return self._sanitize_plotly_code(self._extract_python_code(plotly_code)) def get_plotly_figure( self, plotly_code: str, df: pd.DataFrame, dark_mode: bool = True - ) -> plotly.graph_objs.Figure: - + ) -> Optional[plotly.graph_objs.Figure]: + """ + Plotly 코드를 실행하여 Figure 객체를 생성합니다. + + Args: + plotly_code: 실행할 Plotly 코드 + df: 데이터프레임 + dark_mode: 다크 모드 사용 여부 + + Returns: + Optional[plotly.graph_objs.Figure]: 생성된 Figure 객체 또는 None + """ ldict = {"df": df, "px": px, "go": go} + fig = None + try: - exec(plotly_code, globals(), ldict) + exec(plotly_code, globals(), ldict) # noqa: S102 fig = ldict.get("fig", None) - except Exception as e: - + except Exception: # Inspect data types numeric_cols = df.select_dtypes(include=["number"]).columns.tolist() categorical_cols = df.select_dtypes( diff --git a/viz/__init__.py b/viz/__init__.py deleted file mode 100644 index e3a21c9..0000000 --- a/viz/__init__.py +++ /dev/null @@ -1 +0,0 @@ -"""시각화 계층 패키지"""