In [35]:
!pip install langchain langchain-google-genai google-generativeai



In [74]:
import re
from langchain_google_genai import GoogleGenerativeAI
from langchain.utilities import SQLDatabase
from langchain_experimental.sql import SQLDatabaseChain
from langchain.prompts import PromptTemplate
import os

os.environ["GOOGLE_API_KEY"] = "AIzaSyCuUkkrRfVRGgtwud8z8oJp1GF2jnzYKqw"

db = SQLDatabase.from_uri("sqlite:///chinook.db")
llm = GoogleGenerativeAI(model="gemini-2.0-flash-thinking-exp", temperature=0)

prompt = PromptTemplate(
    input_variables=["input", "table_info", "dialect"],
    template="""
    [중요 지시사항]
    - SQL 쿼리만 생성하세요
    - 절대 마크다운, 코드 블록 사용 금지
    - 추가 설명 없이 순수 SQL만 출력

    데이터베이스 종류: {dialect}
    테이블 스키마: {table_info}

    질문: {input}
    SQL Query:
    """
)

def clean_sql_response(response):
    cleaned = re.sub(r'```sql', '', response, flags=re.IGNORECASE)
    cleaned = re.sub(r'`', '', cleaned)
    return cleaned.strip()

class CleanSQLChain(SQLDatabaseChain):
    def _call(self, inputs):
        result = super()._call(inputs)
        if 'result' in result:
            result['result'] = clean_sql_response(result['result'])
        return result

db_chain = CleanSQLChain.from_llm(
    llm=llm,
    db=db,
    prompt=prompt,
    verbose=True,
    use_query_checker=True,
    return_intermediate_steps=True,
    return_direct=True
)

try:
    response = db_chain.invoke({"query": "몇명의 직원이 있어?"})

    # SQL 추출 로직
    sql_generated = None
    for step in response['intermediate_steps']:
        if isinstance(step, dict):
            if 'sql_cmd' in step:
                sql_generated = step['sql_cmd']
                break
            elif 'sql' in step:  # 다른 키 확인
                sql_generated = step['sql']
                break
            elif 'query' in step:  # 추가 가능성
                sql_generated = step['query']
                break

    if sql_generated:
        cleaned_sql = clean_sql_response(sql_generated)
        print("\n✅ 생성된 SQL:", cleaned_sql)
    else:
        print("\n❌ SQL 쿼리를 찾을 수 없음")
        print("디버깅을 위한 intermediate_steps 내용:")
        print(response['intermediate_steps'])

    print("실행 결과:", response['result'])
except Exception as e:
    print(f"에러 발생: {str(e)}")



[1m> Entering new CleanSQLChain chain...[0m

[1m> Finished chain.[0m

✅ 생성된 SQL: SELECT COUNT(*) FROM employees
실행 결과: [(8,)]
