In [1]:
# %pip (unlike !pip) installs into the environment of the running kernel.
%pip install -qU langchain langgraph langchain-openai langchain-community langsmith python-dotenv

In [1]:
# =========================================================
# My SQL Agent Project
# - Chinook SQLite DB
# - LangChain + LangGraph
# =========================================================

import os
from pathlib import Path
import uuid
import requests

# (Optional) load settings such as OPENAI_API_KEY from a local .env file.
try:
    from dotenv import load_dotenv
    load_dotenv()
except ImportError:
    # python-dotenv is optional; fall back to the plain process environment.
    pass

# Never hardcode the API key in the notebook. An unconditional assignment here
# would also clobber a real key loaded from .env, so only prompt when the
# variable is missing or empty.
if not os.environ.get("OPENAI_API_KEY"):
    import getpass

    os.environ["OPENAI_API_KEY"] = getpass.getpass("Enter OPENAI_API_KEY: ")

from langchain_openai import ChatOpenAI

# Single source of truth for the chat model used by every LLM in this notebook.
MODEL_NAME = "gpt-4o"
# MODEL_NAME = "mistral-large-latest"
print("Using model:", MODEL_NAME)

Using model: gpt-4o


In [2]:
# =========================================================
# 1) SQLite DB (Chinook) 다운로드 & 로드
# =========================================================
# Local file the Chinook sample database is stored in.
db_path = Path("Chinook.db")

if not db_path.exists():
    url = "https://storage.googleapis.com/benchmarks-artifacts/chinook/Chinook.db"
    # Without a timeout a stalled connection would hang the notebook forever.
    r = requests.get(url, timeout=30)
    if r.status_code == 200:
        with open(db_path, "wb") as f:
            f.write(r.content)
        print("Chinook.db downloaded.")
    else:
        raise RuntimeError(f"Failed to download Chinook.db: {r.status_code}")
else:
    print("Chinook.db already exists.")

from langchain_community.utilities import SQLDatabase

# Build the URI from db_path so the download target and the opened database
# can never drift apart.
db = SQLDatabase.from_uri(f"sqlite:///{db_path.as_posix()}")
print("DB dialect:", db.dialect)
print("Tables:", db.get_usable_table_names())

Chinook.db already exists.
DB dialect: sqlite
Tables: ['Album', 'Artist', 'Customer', 'Employee', 'Genre', 'Invoice', 'InvoiceLine', 'MediaType', 'Playlist', 'PlaylistTrack', 'Track']


In [3]:
# =========================================================
# 2) ToolNode + Fallback 유틸
# =========================================================
from typing import Any, Annotated, Literal
from typing_extensions import TypedDict

from langchain_core.messages import ToolMessage, AIMessage, HumanMessage, AnyMessage
from langchain_core.runnables import RunnableLambda, RunnableWithFallbacks

from langgraph.prebuilt import ToolNode
from langgraph.graph import StateGraph, START, END
from langgraph.graph.message import add_messages
from langgraph.checkpoint.memory import MemorySaver

def handle_tool_error(state) -> dict:
    """Convert a tool failure into ToolMessages that relay the error to the LLM.

    Reads the exception stored under the ``"error"`` key of ``state`` and the
    tool calls of the last message, and returns a ``{"messages": [...]}``
    update containing one ToolMessage per tool call, each carrying the
    error's repr.
    """
    error = state.get("error")
    pending_calls = state["messages"][-1].tool_calls
    error_text = f"Here is the error: {repr(error)}\n\nPlease fix your mistakes."

    replies = []
    for call in pending_calls:
        # Each reply is tied to its originating call through tool_call_id.
        replies.append(ToolMessage(content=error_text, tool_call_id=call["id"]))
    return {"messages": replies}

def create_tool_node_with_fallback(tools: list) -> RunnableWithFallbacks[Any, dict]:
    """Build a ToolNode for ``tools`` that falls back to ``handle_tool_error``.

    The fallback is registered with ``exception_key="error"``, which is the
    key ``handle_tool_error`` reads the exception from.
    """
    tool_node = ToolNode(tools)
    error_fallback = RunnableLambda(handle_tool_error)
    return tool_node.with_fallbacks([error_fallback], exception_key="error")



In [4]:
# =========================================================
# 3) SQL Toolkit & Custom Tool 정의
# =========================================================
from langchain_community.agent_toolkits import SQLDatabaseToolkit

# Build the stock SQL toolkit and pull out the two tools this graph uses.
toolkit_llm = ChatOpenAI(model=MODEL_NAME)
toolkit = SQLDatabaseToolkit(db=db, llm=toolkit_llm)
tools = toolkit.get_tools()


def _tool_named(name):
    """Return the first toolkit tool whose ``name`` matches (StopIteration if none)."""
    return next(filter(lambda candidate: candidate.name == name, tools))


list_tables_tool = _tool_named("sql_db_list_tables")
get_schema_tool = _tool_named("sql_db_schema")

print("list_tables_tool / get_schema_tool ready.")

from langchain_core.tools import tool

@tool
def db_query_tool(query: str) -> str:
    """
    Execute a SELECT query against the Chinook SQLite database and return the result as a string.
    If the query fails, the database error message (starting with "Error:") is returned.
    If the query succeeds but matches no rows, a message saying so is returned.
    """
    # NOTE: with @tool, the docstring above is also the tool description that
    # is shown to the LLM, so keep it accurate.
    result = db.run_no_throw(query)
    if not result:
        # run_no_throw reports failures as an "Error: ..." string; an empty
        # result means the statement ran but produced no rows. Reporting that
        # as a failure would make the agent keep rewriting a valid query.
        return "Query executed successfully but returned no rows."
    return result

print("db_query_tool ready.")



list_tables_tool / get_schema_tool ready.
db_query_tool ready.


In [5]:
# =========================================================
# 4) SQL Query Checker (LLM)
# =========================================================
from langchain_core.prompts import ChatPromptTemplate

# System prompt for the reviewer LLM that double-checks a generated query.
query_check_system = """
You are a SQL expert with a strong attention to detail.
You are helping in a personal project that builds a SQL agent
over the Chinook SQLite database.

Double check the SQLite query for common mistakes, including:
- Using NOT IN with NULL values
- Using UNION when UNION ALL should have been used
- Using BETWEEN for exclusive ranges
- Data type mismatch in predicates
- Properly quoting identifiers
- Using the correct number of arguments for functions
- Casting to the correct data type
- Using the proper columns for joins

If there are any of the above mistakes, rewrite the query.
If there are no mistakes, just reproduce the original query.

You will call the appropriate tool to execute the query after running this check.
"""

# Only db_query_tool is bound here, and tool_choice forces the model to call it.
_query_check_base_llm = ChatOpenAI(model=MODEL_NAME, temperature=0)
query_check_llm = _query_check_base_llm.bind_tools(
    [db_query_tool],
    tool_choice="db_query_tool",
)

# System instructions followed by whatever messages the caller passes in.
_query_check_messages = [
    ("system", query_check_system),
    ("placeholder", "{messages}"),
]
query_check_prompt = ChatPromptTemplate.from_messages(_query_check_messages)

# prompt | llm chain
query_check = query_check_prompt | query_check_llm

print("query_check chain ready.")



query_check chain ready.


In [6]:
# =========================================================
# 5) 에이전트 State & 그래프 초기화
# =========================================================
class AgentState(TypedDict):
    """Graph state passed between nodes: the running list of chat messages."""

    # add_messages is attached as the reducer for this key; per LangGraph's
    # documentation it merges node updates into the existing list instead of
    # overwriting it.
    messages: Annotated[list[AnyMessage], add_messages]

# Graph builder that the nodes and edges below are registered on.
workflow = StateGraph(AgentState)

# ---------------------------------------------------------
# (1) 테이블 목록 가져오기
# ---------------------------------------------------------
def first_tool_call(state: AgentState) -> dict[str, list[AIMessage]]:
    """Emit a hard-coded AI message that calls ``sql_db_list_tables``.

    No LLM is involved: the returned state update contains a single empty
    AIMessage whose only tool call requests the table list, using the fixed
    id ``"initial_tool_call"``.
    """
    list_tables_call = {
        "name": "sql_db_list_tables",
        "args": {},
        "id": "initial_tool_call",
    }
    kickoff = AIMessage(content="", tool_calls=[list_tables_call])
    return {"messages": [kickoff]}

workflow.add_node("first_tool_call", first_tool_call)

# ---------------------------------------------------------
# (2) list_tables_tool / get_schema_tool / schema LLM
# ---------------------------------------------------------
# Node that executes the table-listing tool (with the error fallback attached).
list_tables_node = create_tool_node_with_fallback([list_tables_tool])
workflow.add_node("list_tables_tool", list_tables_node)

from langchain_openai import ChatOpenAI as OpenAIChat

# LLM that can request table schemas through get_schema_tool.
_schema_base_llm = OpenAIChat(model=MODEL_NAME, temperature=0)
schema_llm = _schema_base_llm.bind_tools([get_schema_tool])

def model_get_schema(state: AgentState) -> dict[str, list[AIMessage]]:
    """Invoke ``schema_llm`` on the conversation so far and return its reply.

    The reply is wrapped in a ``{"messages": [...]}`` state update.
    """
    history = state["messages"]
    reply = schema_llm.invoke(history)
    return {"messages": [reply]}

workflow.add_node("model_get_schema", model_get_schema)

# Node that executes the schema tool (with the error fallback attached).
get_schema_node = create_tool_node_with_fallback([get_schema_tool])
workflow.add_node("get_schema_tool", get_schema_node)



<langgraph.graph.state.StateGraph at 0x7cd44d9f5c40>

In [7]:
# =========================================================
# 6) 쿼리 생성/해석 LLM
# =========================================================
# System prompt for the LLM that writes queries and interprets their results.
QUERY_GEN_INSTRUCTION = """
You are a SQL expert with a strong attention to detail.

You are helping with a personal project: a SQL agent over the Chinook
SQLite database (a digital media store).

You can:
- Define SQL queries,
- Analyze query results,
- Interpret the results and answer the user's question.

Read the messages below and identify:
- User question
- Table schemas (DDL)
- Query statements and query results (or errors), if any.

Rules:

1. If there is no query result yet that can answer the question,
   create a syntactically correct SQLite query to answer the user question.
   DO NOT run any DML statements (INSERT, UPDATE, DELETE, DROP, etc.).

2. If you create a query, respond ONLY with the query statement.
   Example: "SELECT * FROM Artist LIMIT 5;"

3. If a query was already executed but produced an error,
   respond with the same error message you found.
   Example: "Error: Artist table doesn't exist"

4. If a query was already executed successfully,
   interpret the result and answer the question with this pattern:
   Answer: <<question answer>>
"""

# System instructions followed by the conversation passed in as "messages".
_query_gen_messages = [
    ("system", QUERY_GEN_INSTRUCTION),
    ("placeholder", "{messages}"),
]
query_gen_prompt = ChatPromptTemplate.from_messages(_query_gen_messages)

# Plain chat model: no tools are bound, so it answers with text only.
query_gen_llm = OpenAIChat(model=MODEL_NAME, temperature=0)

def query_gen_node(state: AgentState):
    """Run the query-generation LLM over the current conversation.

    Per the system prompt, the model is expected to reply with either a bare
    SQL statement (when a query is still needed) or a final ``Answer: ...``
    (when a result is available). The reply is returned as a
    ``{"messages": [...]}`` state update.
    """
    prompt_value = query_gen_prompt.format_prompt(messages=state["messages"])
    chat_messages = prompt_value.to_messages()
    message = query_gen_llm.invoke(chat_messages)
    return {"messages": [message]}

workflow.add_node("query_gen", query_gen_node)



<langgraph.graph.state.StateGraph at 0x7cd44d9f5c40>

In [8]:
# =========================================================
# 7) 쿼리 체크 노드 & 실행 노드
# =========================================================
def model_check_query_node(state: AgentState) -> dict[str, list[AIMessage]]:
    """Send only the most recent message through the ``query_check`` chain.

    The chain's output is returned as a ``{"messages": [...]}`` state update.
    """
    latest_message = state["messages"][-1]
    checked = query_check.invoke({"messages": [latest_message]})
    return {"messages": [checked]}

workflow.add_node("correct_query", model_check_query_node)

# Node that runs db_query_tool (with the error fallback attached).
execute_query_node = create_tool_node_with_fallback([db_query_tool])
workflow.add_node("execute_query", execute_query_node)



<langgraph.graph.state.StateGraph at 0x7cd44d9f5c40>

In [9]:
# =========================================================
# 8) 조건부 분기 로직
# =========================================================
def should_continue(state: AgentState) -> Literal[END, "correct_query", "query_gen"]:
    """Choose the node that follows ``query_gen`` from the last message's text.

    Returns:
        END             if the text starts with "Answer:" (final answer);
        "query_gen"     if the text starts with "Error:" (retry generation);
        "correct_query" otherwise, including when the content is not a string.
    """
    last = state["messages"][-1]
    content = getattr(last, "content", "")

    if isinstance(content, str):
        # Model replies may begin with whitespace or a newline. Ignore it so
        # that a final answer is not misrouted to the query checker.
        text = content.lstrip()
        if text.startswith("Answer:"):
            return END
        if text.startswith("Error:"):
            return "query_gen"

    # Anything else is treated as a candidate query that still needs checking.
    return "correct_query"



In [10]:
# =========================================================
# 9) 엣지 연결 & 그래프 컴파일
# =========================================================
# Unconditional transitions, listed as (source, target) pairs.
linear_edges = [
    (START, "first_tool_call"),
    ("first_tool_call", "list_tables_tool"),
    ("list_tables_tool", "model_get_schema"),
    ("model_get_schema", "get_schema_tool"),
    ("get_schema_tool", "query_gen"),
    ("correct_query", "execute_query"),
    ("execute_query", "query_gen"),
]
for source, target in linear_edges:
    workflow.add_edge(source, target)

# After query_gen, should_continue picks the next node (or END).
workflow.add_conditional_edges("query_gen", should_continue)

checkpointer = MemorySaver()
app = workflow.compile(checkpointer=checkpointer)
print("LangGraph app compiled.")


LangGraph app compiled.


In [None]:
# =========================================================
# 10) 실행 함수 & 테스트
# =========================================================
from langchain_core.runnables import RunnableConfig

def run_graph(message: str, recursion_limit: int = 30, verbose: bool = True):
    """Run the compiled graph on one user question and return the final text.

    Args:
        message: The user's question.
        recursion_limit: Passed to the run config as ``recursion_limit``.
        verbose: When True, print the last message of the final state.

    Returns:
        The ``content`` of the last message in the final state.
    """
    # Each call gets a fresh thread_id for the checkpointer.
    thread_id = str(uuid.uuid4())
    config = RunnableConfig(
        recursion_limit=recursion_limit,
        configurable={"thread_id": thread_id},
    )
    final_state = app.invoke({"messages": [HumanMessage(content=message)]}, config)
    last = final_state["messages"][-1]
    if verbose:
        print("---- Last Message ----", last, "----------------------", sep="\n")
    return last.content

# Smoke-test the agent with two natural-language questions (asked in Korean).
employee_question = "Andrew Adams 직원의 인적 정보를 모두 조회해줘."
spending_question = "2009년에 어느 국가의 고객이 가장 많이 지출했고, 얼마를 지출했는지 알려줘."

print("\n=== Test 1 ===")
run_graph(employee_question, verbose=True)

print("\n=== Test 2 ===")
run_graph(spending_question, verbose=True)


In [None]:
# =========================================================
# 11) LangSmith Evaluator 를 활용한 SQL Agent 평가
# =========================================================
from langsmith import Client

# 클라이언트 초기화
# LangSmith client (configured from the environment).
client = Client()

# (question, reference answer) pairs used as the evaluation set.
examples = [
    (
        "Which country's customers spent the most? And how much did they spend?",
        "The country whose customers spent the most is the USA, with a total spending of 523.06.",
    ),
    (
        "What was the most purchased track of 2013?",
        "The most purchased track of 2013 was Hot Girl.",
    ),
    (
        "How many albums does the artist Led Zeppelin have?",
        "Led Zeppelin has 14 albums",
    ),
    (
        "What is the total price for the album “Big Ones”?",
        "The total price for the album 'Big Ones' is 14.85",
    ),
    (
        "Which sales agent made the most in sales in 2009?",
        "Steve Johnson made the most sales in 2009",
    ),
]

dataset_name = "SQL Agent Response"

# Create the dataset only once; later runs reuse the existing one.
if client.has_dataset(dataset_name=dataset_name):
    print("LangSmith dataset already exists.")
else:
    dataset = client.create_dataset(dataset_name=dataset_name)
    inputs = tuple({"input": question} for question, _ in examples)
    outputs = tuple({"output": reference} for _, reference in examples)
    client.create_examples(inputs=inputs, outputs=outputs, dataset_id=dataset.id)
    print("LangSmith dataset created.")


In [None]:
# =========================================================
# 12) 에이전트의 SQL 쿼리 응답을 예측하기 위한 함수 정의
# =========================================================
from langchain_core.runnables import RunnableConfig
from langchain_core.messages import HumanMessage
import uuid

# 에이전트의 SQL 쿼리 응답을 예측하기 위한 함수 정의
def predict_sql_agent_answer(example: dict):
    """Run the agent on ``example["input"]`` and return its final answer.

    Used as the prediction function for answer evaluation. The result has the
    form ``{"response": <content of the last message>}``.
    """
    # A fresh thread_id per example for the checkpointer.
    thread_id = str(uuid.uuid4())
    config = RunnableConfig(configurable={"thread_id": thread_id})

    question = HumanMessage(content=example["input"])
    state = app.invoke({"messages": [question]}, config)

    final_message = state["messages"][-1]
    return {"response": final_message.content}


In [None]:
# =========================================================
# 13) LLM-as-judge 평가 프롬프트 및 평가자 정의
# =========================================================
from langchain import hub
from langchain_openai import ChatOpenAI

# Grading prompt pulled from the LangChain hub.
grade_prompt_answer_accuracy = hub.pull("langchain-ai/rag-answer-vs-reference")


# LLM-as-judge evaluator for the agent's answers.
def answer_evaluator(run, example) -> dict:
    """Grade the agent's answer against the reference answer with an LLM.

    Args:
        run: Evaluation run; the prediction is read from ``run.outputs["response"]``.
        example: Dataset example; the question is read from
            ``example.inputs["input"]`` and the reference answer from
            ``example.outputs["output"]``.

    Returns:
        ``{"key": "answer_v_reference_score", "score": <grader "Score" value>}``.
    """
    grader_inputs = {
        "question": example.inputs["input"],
        "correct_answer": example.outputs["output"],
        "student_answer": run.outputs["response"],
    }

    # Grading chain: hub prompt piped into a deterministic (temperature 0) model.
    llm = ChatOpenAI(model=MODEL_NAME, temperature=0)
    answer_grader = grade_prompt_answer_accuracy | llm

    grade = answer_grader.invoke(grader_inputs)
    return {"key": "answer_v_reference_score", "score": grade["Score"]}


In [None]:
# =========================================================
# 14) 평가 실행
# =========================================================
from langsmith.evaluation import evaluate

# 평가용 데이터셋 이름
# Name of the LangSmith dataset to evaluate against.
dataset_name = "SQL Agent Response"

try:
    # Run the evaluation.
    experiment_results = evaluate(
        predict_sql_agent_answer,  # prediction function under test
        data=dataset_name,         # evaluation dataset name
        evaluators=[answer_evaluator],  # list of evaluators
        num_repetitions=3,         # number of repetitions of the experiment
        experiment_prefix="sql-agent-eval",
        # Derive the model name from MODEL_NAME so the recorded metadata
        # stays correct when the model is switched.
        metadata={"version": f"chinook db, sql-agent-eval: {MODEL_NAME}"},
    )
    print("Evaluation finished.")
except Exception as e:
    # Best-effort: report the failure without aborting the notebook run.
    print("Evaluation error:", e)
