diff --git a/llm_utils/tools/datahub.py b/llm_utils/tools/datahub.py index 4991a79..f6b0b32 100644 --- a/llm_utils/tools/datahub.py +++ b/llm_utils/tools/datahub.py @@ -1,4 +1,5 @@ import os +import re from typing import List, Dict, Optional, TypeVar, Callable, Iterable, Any from langchain.schema import Document @@ -87,21 +88,41 @@ def _get_column_info( return column_info +def _extract_dataset_name_from_urn(urn: str) -> Optional[str]: + """URN 문자열에서 데이터셋 이름(예: delta.default.stg_gh_events)만 추출. + + 지원 패턴: + - dataset URN: urn:li:dataset:(urn:li:dataPlatform:dbt,delta.default.stg_gh_events,PROD) + - schemaField URN: urn:li:schemaField:(urn:li:dataset:(urn:li:dataPlatform:dbt,delta.default.stg_gh_events,PROD),event_id) + """ + match = re.search( + r"urn:li:dataset:\(urn:li:dataPlatform:[^,]+,([^,]+),[^)]+\)", urn + ) + if match: + return match.group(1) + return None + + def get_info_from_db(max_workers: int = 8) -> List[Document]: table_info = _get_table_info(max_workers=max_workers) fetcher = _get_fetcher() urns = list(fetcher.get_urns()) urn_table_mapping = {} + display_name_by_table = {} for urn in urns: - table_name = fetcher.get_table_name(urn) - if table_name: - urn_table_mapping[table_name] = urn - - def process_table_info(item: tuple[str, str]) -> str: - table_name, table_description = item + original_name = fetcher.get_table_name(urn) + if original_name: + urn_table_mapping[original_name] = urn + parsed_name = _extract_dataset_name_from_urn(urn) + if parsed_name: + display_name_by_table[original_name] = parsed_name + + def process_table_info(item: tuple[str, str, str]) -> str: + original_table_name, table_description, display_table_name = item + # 컬럼 조회는 기존 테이블 이름으로 수행 (urn_table_mapping과 일치) column_info = _get_column_info( - table_name, urn_table_mapping, max_workers=max_workers + original_table_name, urn_table_mapping, max_workers=max_workers ) column_info_str = "\n".join( [ @@ -109,10 +130,21 @@ def process_table_info(item: tuple[str, str]) -> str: for col in column_info ] ) - return f"{table_name}: {table_description}\nColumns:\n {column_info_str}" + used_name = display_table_name or original_table_name + return f"{used_name}: {table_description}\nColumns:\n {column_info_str}" + + # 표시용 이름을 세 번째 파라미터로 함께 전달 + items_with_display = [ + ( + name, + desc, + display_name_by_table.get(name, name), + ) + for name, desc in table_info.items() + ] table_info_str_list = parallel_process( - table_info.items(), + items_with_display, process_table_info, max_workers=max_workers, desc="컬럼 정보 수집 중",