Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
52 changes: 52 additions & 0 deletions interface/graph_builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -128,6 +128,48 @@ def render_sequence(sequence: List[str]) -> str:
# 프리셋에서는 QUERY_MAKER 자동 포함
use_query_maker = True

# GET_TABLE_INFO 설정
st.subheader("GET_TABLE_INFO 설정")
_prev_cfg = st.session_state.get("graph_config", {})

_retriever_options = {
"기본": "벡터 검색 (기본)",
"Reranker": "Reranker 검색 (정확도 향상)",
}
_retriever_keys = list(_retriever_options.keys())
_retriever_default = _prev_cfg.get("retriever_name", "기본")
_retriever_index = (
_retriever_keys.index(_retriever_default)
if _retriever_default in _retriever_keys
else 0
)

retriever_name = st.selectbox(
"테이블 검색기",
options=_retriever_keys,
format_func=lambda x: _retriever_options[x],
index=_retriever_index,
)

top_n = st.slider(
"검색할 테이블 정보 개수",
min_value=1,
max_value=20,
value=int(_prev_cfg.get("top_n", 5)),
step=1,
)

_device_options = ["cpu", "cuda"]
_device_default = _prev_cfg.get("device", "cpu")
_device_index = (
_device_options.index(_device_default) if _device_default in _device_options else 0
)
device = st.selectbox(
"모델 실행 장치",
options=_device_options,
index=_device_index,
)


def build_sequence_with_qm(
preset: str, use_profile: bool, use_context: bool, use_qm: bool
Expand Down Expand Up @@ -166,6 +208,9 @@ def build_sequence_with_qm(
"use_profile": use_profile,
"use_context": use_context,
"use_query_maker": use_query_maker,
"retriever_name": retriever_name,
"top_n": top_n,
"device": device,
}

# 선택이 바뀌면 자동으로 세션 그래프 갱신
Expand All @@ -174,13 +219,20 @@ def build_sequence_with_qm(
_builder = build_state_graph(sequence)
st.session_state["graph"] = _builder.compile()
st.session_state["graph_config"] = config
# Lang2SQL 메인 UI에서 기본값으로 사용할 옵션 전달
st.session_state["default_retriever_name"] = retriever_name
st.session_state["default_top_n"] = top_n
st.session_state["default_device"] = device
st.info("그래프가 세션에 적용되었습니다.")

# 수동 새로고침 버튼
if st.button("세션 그래프 새로고침"):
_builder = build_state_graph(sequence)
st.session_state["graph"] = _builder.compile()
st.session_state["graph_config"] = config
st.session_state["default_retriever_name"] = retriever_name
st.session_state["default_top_n"] = top_n
st.session_state["default_device"] = device
st.success("세션 그래프가 새로고침되었습니다.")

with st.expander("현재 세션 그래프 설정"):
Expand Down
22 changes: 17 additions & 5 deletions interface/lang2sql.py
Original file line number Diff line number Diff line change
Expand Up @@ -302,29 +302,41 @@ def should_show(_key: str) -> bool:
index=0,
)

_device_options = ["cpu", "cuda"]
_default_device = st.session_state.get("default_device", "cpu")
_device_index = (
_device_options.index(_default_device) if _default_device in _device_options else 0
)
device = st.selectbox(
"모델 실행 장치를 선택하세요:",
options=["cpu", "cuda"],
index=0,
options=_device_options,
index=_device_index,
)

retriever_options = {
"기본": "벡터 검색 (기본)",
"Reranker": "Reranker 검색 (정확도 향상)",
}

_retriever_keys = list(retriever_options.keys())
_default_retriever = st.session_state.get("default_retriever_name", "기본")
_retriever_index = (
_retriever_keys.index(_default_retriever)
if _default_retriever in _retriever_keys
else 0
)
user_retriever = st.selectbox(
"검색기 유형을 선택하세요:",
options=list(retriever_options.keys()),
options=_retriever_keys,
format_func=lambda x: retriever_options[x],
index=0,
index=_retriever_index,
)

user_top_n = st.slider(
"검색할 테이블 정보 개수:",
min_value=1,
max_value=20,
value=5,
value=int(st.session_state.get("default_top_n", 5)),
step=1,
help="검색할 테이블 정보의 개수를 설정합니다. 값이 클수록 더 많은 테이블 정보를 검색하지만 처리 시간이 길어질 수 있습니다.",
)
Expand Down
2 changes: 1 addition & 1 deletion version.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,4 +18,4 @@
- PATCH는 1로 증가합니다.
"""

__version__ = "0.2.1"
__version__ = "0.2.2"