Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
50 changes: 41 additions & 9 deletions llm_utils/tools/datahub.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import os
import re
from typing import List, Dict, Optional, TypeVar, Callable, Iterable, Any

from langchain.schema import Document
Expand Down Expand Up @@ -87,32 +88,63 @@ def _get_column_info(
return column_info


def _extract_dataset_name_from_urn(urn: str) -> Optional[str]:
"""URN 문자열에서 데이터셋 이름(예: delta.default.stg_gh_events)만 추출.

지원 패턴:
- dataset URN: urn:li:dataset:(urn:li:dataPlatform:dbt,delta.default.stg_gh_events,PROD)
- schemaField URN: urn:li:schemaField:(urn:li:dataset:(urn:li:dataPlatform:dbt,delta.default.stg_gh_events,PROD),event_id)
"""
match = re.search(
r"urn:li:dataset:\(urn:li:dataPlatform:[^,]+,([^,]+),[^)]+\)", urn
)
if match:
return match.group(1)
return None


def get_info_from_db(max_workers: int = 8) -> List[Document]:
table_info = _get_table_info(max_workers=max_workers)

fetcher = _get_fetcher()
urns = list(fetcher.get_urns())
urn_table_mapping = {}
display_name_by_table = {}
for urn in urns:
table_name = fetcher.get_table_name(urn)
if table_name:
urn_table_mapping[table_name] = urn

def process_table_info(item: tuple[str, str]) -> str:
table_name, table_description = item
original_name = fetcher.get_table_name(urn)
if original_name:
urn_table_mapping[original_name] = urn
parsed_name = _extract_dataset_name_from_urn(urn)
if parsed_name:
display_name_by_table[original_name] = parsed_name

def process_table_info(item: tuple[str, str, str]) -> str:
original_table_name, table_description, display_table_name = item
# 컬럼 조회는 기존 테이블 이름으로 수행 (urn_table_mapping과 일치)
column_info = _get_column_info(
table_name, urn_table_mapping, max_workers=max_workers
original_table_name, urn_table_mapping, max_workers=max_workers
)
column_info_str = "\n".join(
[
f"{col['column_name']}: {col['column_description']}"
for col in column_info
]
)
return f"{table_name}: {table_description}\nColumns:\n {column_info_str}"
used_name = display_table_name or original_table_name
return f"{used_name}: {table_description}\nColumns:\n {column_info_str}"

# 표시용 이름을 세 번째 파라미터로 함께 전달
items_with_display = [
(
name,
desc,
display_name_by_table.get(name, name),
)
for name, desc in table_info.items()
]

table_info_str_list = parallel_process(
table_info.items(),
items_with_display,
process_table_info,
max_workers=max_workers,
desc="컬럼 정보 수집 중",
Expand Down