In [1]:
import os
from langchain_community.utilities import SQLDatabase

# OpenAI API 키 설정 (이 예제에서는 사용되지 않음)
api_key = os.getenv("OPENAI_API_KEY")

# SQLite 데이터베이스 경로
db_path = "/workspace/youngwoo/toyproject-datallm/modules/agent/tmp_dataset/Chinook.db"

# SQLite 데이터베이스 연결
db = SQLDatabase.from_uri(f"sqlite:///{db_path}")

# 데이터베이스 방언 출력 (예: sqlite)
print(db.dialect)

# 사용 가능한 테이블 이름 출력
print(db.get_usable_table_names())

# SQL 쿼리 실행 및 결과 출력
result = db.run("SELECT * FROM Artist LIMIT 10;")
print(result)

sqlite
['Album', 'Artist', 'Customer', 'Employee', 'Genre', 'Invoice', 'InvoiceLine', 'MediaType', 'Playlist', 'PlaylistTrack', 'Track']
[(1, 'AC/DC'), (2, 'Accept'), (3, 'Aerosmith'), (4, 'Alanis Morissette'), (5, 'Alice In Chains'), (6, 'Antônio Carlos Jobim'), (7, 'Apocalyptica'), (8, 'Audioslave'), (9, 'BackBeat'), (10, 'Billy Cobham')]


동적 few-shot 프롬프트 사용<br>
에이전트 성능을 최적화하기 위해 도메인 관련 지식을 포함한 사용자 정의 프롬프트를 제공할 수 있습니다. 이 경우, 예제 선택기를 사용하여 사용자 입력에 따라 동적으로 few-shot 프롬프트를 구성할 것입니다. 이는 모델이 참조할 수 있는 관련 쿼리를 프롬프트에 삽입함으로써 더 나은 쿼리를 생성하는 데 도움이 됩니다.

In [2]:
examples = [
    {"input": "List all artists.", "query": "SELECT * FROM Artist;"},
    {
        "input": "Find all albums for the artist 'AC/DC'.",
        "query": "SELECT * FROM Album WHERE ArtistId = (SELECT ArtistId FROM Artist WHERE Name = 'AC/DC');",
    },
    {"input": "List all tracks in the 'Rock' genre.", "query": "SELECT * FROM Track WHERE GenreId = (SELECT GenreId FROM Genre WHERE Name = 'Rock');"},
    {"input": "Find the total duration of all tracks.", "query": "SELECT SUM(Milliseconds) FROM Track;"},
    {"input": "List all customers from Canada.", "query": "SELECT * FROM Customer WHERE Country = 'Canada';"},
    {"input": "How many tracks are there in the album with ID 5?", "query": "SELECT COUNT(*) FROM Track WHERE AlbumId = 5;"},
    {"input": "Find the total number of invoices.", "query": "SELECT COUNT(*) FROM Invoice;"},
    {"input": "List all tracks that are longer than 5 minutes.", "query": "SELECT * FROM Track WHERE Milliseconds > 300000;"},
    {"input": "Who are the top 5 customers by total purchase?", "query": "SELECT CustomerId, SUM(Total) AS TotalPurchase FROM Invoice GROUP BY CustomerId ORDER BY TotalPurchase DESC LIMIT 5;"},
    {"input": "Which albums are from the year 2000?", "query": "SELECT * FROM Album WHERE strftime('%Y', ReleaseDate) = '2000';"},
    {"input": "How many employees are there", "query": 'SELECT COUNT(*) FROM "Employee"'},
]

이제 예제 선택기를 만듭니다. 이는 실제 사용자 입력을 받아 몇 가지 예제를 선택하여 few-shot 프롬프트에 추가합니다. 우리는 SemanticSimilarityExampleSelector를 사용하여 설정한 임베딩 및 벡터 스토어를 사용하여 입력과 가장 유사한 예제를 찾습니다:

In [4]:
from langchain_community.vectorstores import FAISS
from langchain_core.example_selectors import SemanticSimilarityExampleSelector
from langchain_openai import OpenAIEmbeddings

example_selector = SemanticSimilarityExampleSelector.from_examples(
    examples,
    OpenAIEmbeddings(),
    FAISS,
    k=5,
    input_keys=["input"],
)

이제 FewShotPromptTemplate을 만듭니다. 여기서는 예제 선택기, 각 예제를 형식화할 예제 프롬프트, 그리고 형식화된 예제 앞뒤에 넣을 문자열 접두사와 접미사가 필요합니다:



In [5]:
from langchain_core.prompts import (
    ChatPromptTemplate,
    FewShotPromptTemplate,
    MessagesPlaceholder,
    PromptTemplate,
    SystemMessagePromptTemplate,
)

system_prefix = """You are an agent designed to interact with a SQL database.
Given an input question, create a syntactically correct {dialect} query to run, then look at the results of the query and return the answer.
Unless the user specifies a specific number of examples they wish to obtain, always limit your query to at most {top_k} results.
You can order the results by a relevant column to return the most interesting examples in the database.
Never query for all the columns from a specific table, only ask for the relevant columns given the question.
You have access to tools for interacting with the database.
Only use the given tools. Only use the information returned by the tools to construct your final answer.
You MUST double check your query before executing it. If you get an error while executing a query, rewrite the query and try again.

DO NOT make any DML statements (INSERT, UPDATE, DELETE, DROP etc.) to the database.

If the question does not seem related to the database, just return "I don't know" as the answer.

Here are some examples of user inputs and their corresponding SQL queries:"""

few_shot_prompt = FewShotPromptTemplate(
    example_selector=example_selector,
    example_prompt=PromptTemplate.from_template(
        "User input: {input}\nSQL query: {query}"
    ),
    input_variables=["input", "dialect", "top_k"],
    prefix=system_prefix,
    suffix="",
)

전체 프롬프트는 human message 템플릿과 agent_scratchpad MessagesPlaceholder가 있는 chat 프롬프트여야 합니다. few-shot 프롬프트는 시스템 메시지로 사용됩니다:

In [6]:
full_prompt = ChatPromptTemplate.from_messages(
    [
        SystemMessagePromptTemplate(prompt=few_shot_prompt),
        ("human", "{input}"),
        MessagesPlaceholder("agent_scratchpad"),
    ]
)

# 예제 형식화된 프롬프트
prompt_val = full_prompt.invoke(
    {
        "input": "How many arists are there",
        "top_k": 5,
        "dialect": "SQLite",
        "agent_scratchpad": [],
    }
)
print(prompt_val.to_string())

System: You are an agent designed to interact with a SQL database.
Given an input question, create a syntactically correct SQLite query to run, then look at the results of the query and return the answer.
Unless the user specifies a specific number of examples they wish to obtain, always limit your query to at most 5 results.
You can order the results by a relevant column to return the most interesting examples in the database.
Never query for all the columns from a specific table, only ask for the relevant columns given the question.
You have access to tools for interacting with the database.
Only use the given tools. Only use the information returned by the tools to construct your final answer.
You MUST double check your query before executing it. If you get an error while executing a query, rewrite the query and try again.

DO NOT make any DML statements (INSERT, UPDATE, DELETE, DROP etc.) to the database.

If the question does not seem related to the database, just return "I don't 

이제 사용자 정의 프롬프트로 에이전트를 생성할 수 있습니다:

In [10]:
from langchain_community.agent_toolkits import create_sql_agent
from langchain_openai import ChatOpenAI
llm = ChatOpenAI(model="gpt-3.5-turbo", temperature=0)
agent = create_sql_agent(
    llm=llm,
    db=db,
    prompt=full_prompt,
    verbose=True,
    agent_type="openai-tools",
)

# 시도해 봅시다:
agent.invoke({"input": "How many artists are there?"})



[1m> Entering new SQL Agent Executor chain...[0m
[32;1m[1;3m
Invoking: `sql_db_query` with `{'query': 'SELECT COUNT(*) FROM Artist;'}`


[0m[36;1m[1;3m[(275,)][0m[32;1m[1;3mThere are 275 artists in the database.[0m

[1m> Finished chain.[0m


{'input': 'How many artists are there?',
 'output': 'There are 275 artists in the database.'}

고유 명사가 많은 열을 다루기<br>

주소, 노래 이름 또는 아티스트와 같은 고유 명사를 포함하는 열을 필터링하려면, 먼저 철자를 확인하여 데이터를 정확하게 필터링해야 합니다.<br>

이를 위해 데이터베이스에 존재하는 모든 고유 명사를 포함한 벡터 스토어를 생성할 수 있습니다. 그런 다음 사용자가 질문에 고유 명사를 포함할 때마다 에이전트가 해당 벡터 스토어를 쿼리하여 해당 단어의 정확한 철자를 찾도록 할 수 있습니다. 이렇게 하면 에이전트가 대상 쿼리를 작성하기 전에 사용자가 참조하는 엔티티를 정확하게 이해할 수 있습니다.<br>

먼저 각 엔티티에 대한 고유 값을 가져오는 함수를 정의합니다:<br>


In [11]:
import ast
import re

def query_as_list(db, query):
    res = db.run(query)
    res = [el for sub in ast.literal_eval(res) for el in sub if el]
    res = [re.sub(r"\b\d+\b", "", string).strip() for string in res]
    return list(set(res))

artists = query_as_list(db, "SELECT Name FROM Artist")
albums = query_as_list(db, "SELECT Title FROM Album")
albums[:5]

['Os Cães Ladram Mas A Caravana Não Pára',
 'War',
 'Mais Do Mesmo',
 "Up An' Atom",
 'Riot Act']

['Os Cães Ladram Mas A Caravana Não Pára',
 'War',
 'Mais Do Mesmo',
 "Up An' Atom",
 'Riot Act']

이제 사용자 정의 검색 도구를 생성하고 최종 에이전트를 만듭니다:


In [12]:
from langchain.agents.agent_toolkits import create_retriever_tool

vector_db = FAISS.from_texts(artists + albums, OpenAIEmbeddings())
retriever = vector_db.as_retriever(search_kwargs={"k": 5})
description = """Use to look up values to filter on. Input is an approximate spelling of the proper noun, output is \
valid proper nouns. Use the noun most similar to the search."""
retriever_tool = create_retriever_tool(
    retriever,
    name="search_proper_nouns",
    description=description,
)


In [13]:
system = """You are an agent designed to interact with a SQL database.
Given an input question, create a syntactically correct {dialect} query to run, then look at the results of the query and return the answer.
Unless the user specifies a specific number of examples they wish to obtain, always limit your query to at most {top_k} results.
You can order the results by a relevant column to return the most interesting examples in the database.
Never query for all the columns from a specific table, only ask for the relevant columns given the question.
You have access to tools for interacting with the database.
Only use the given tools. Only use the information returned by the tools to construct your final answer.
You MUST double check your query before executing it. If you get an error while executing a query, rewrite the query and try again.

DO NOT make any DML statements (INSERT, UPDATE, DELETE, DROP etc.) to the database.

If you need to filter on a proper noun, you must ALWAYS first look up the filter value using the "search_proper_nouns" tool! 

You have access to the following tables: {table_names}

If the question does not seem related to the database, just return "I don't know" as the answer."""

prompt = ChatPromptTemplate.from_messages(
    [("system", system), ("human", "{input}"), MessagesPlaceholder("agent_scratchpad")]
)
agent = create_sql_agent(
    llm=llm,
    db=db,
    extra_tools=[retriever_tool],
    prompt=prompt,
    agent_type="openai-tools",
    verbose=True,
)

agent.invoke({"input": "How many albums does alis in chain have?"})



[1m> Entering new SQL Agent Executor chain...[0m
[32;1m[1;3m
Invoking: `search_proper_nouns` with `{'query': 'alis in chain'}`


[0m[36;1m[1;3mAlice In Chains

Aisha Duo

Xis

Da Lama Ao Caos

A-Sides[0m[32;1m[1;3m
Invoking: `sql_db_query` with `{'query': "SELECT COUNT(*) AS album_count FROM Album WHERE ArtistId IN (SELECT ArtistId FROM Artist WHERE Name = 'Alice In Chains')"}`


[0m[36;1m[1;3m[(1,)][0m[32;1m[1;3mAlice In Chains has 1 album.[0m

[1m> Finished chain.[0m


{'input': 'How many albums does alis in chain have?',
 'output': 'Alice In Chains has 1 album.'}

보시다시피, 에이전트는 search_proper_nouns 도구를 사용하여 특정 아티스트에 대해 데이터베이스를 올바르게 쿼리하는 방법을 확인했습니다.