In [1]:
%pip install pyautogen==0.2.32 langchain==0.2.10 langchain_openai==0.1.17 langchain_community==0.2.9 load_dotenv==0.1.0



[1m[[0m[34;49mnotice[0m[1;39;49m][0m[39;49m A new release of pip is available: [0m[31;49m24.1.2[0m[39;49m -> [0m[32;49m24.2[0m
[1m[[0m[34;49mnotice[0m[1;39;49m][0m[39;49m To update, run: [0m[32;49mpip install --upgrade pip[0m
Note: you may need to restart the kernel to use updated packages.


In [2]:
import autogen
from langchain_community.utilities import SQLDatabase
from langchain.chains import create_sql_query_chain
from langchain_openai import AzureChatOpenAI
import os
from dotenv import load_dotenv

load_dotenv()

db = SQLDatabase.from_uri("sqlite:///kb-creditcard.db")
print (db.dialect)
print(db.get_usable_table_names())




  from .autonotebook import tqdm as notebook_tqdm


sqlite
['bill', 'bill_add', 'card', 'customer', 'customer_own_card', 'domestic_approve', 'issued_card', 'oversea_approve', 'payment', 'point']


In [3]:

llm = AzureChatOpenAI(
    azure_endpoint=os.environ["AZURE_OPENAI_ENDPOINT"],
    azure_deployment=os.environ["AZURE_OPENAI_DEPLOYMENT_NAME"],
    openai_api_version=os.environ["AZURE_OPENAI_API_VERSION"],
)

In [4]:
from langchain_core.output_parsers import StrOutputParser
from langchain_core.prompts import ChatPromptTemplate

chain = create_sql_query_chain(llm, db)

system = """Double check the user's {dialect} query for common mistakes, including:
- Using NOT IN with NULL values
- Using UNION when UNION ALL should have been used
- Using BETWEEN for exclusive ranges
- Data type mismatch in predicates
- Properly quoting identifiers
- Using the correct number of arguments for functions
- Casting to the correct data type
- Using the proper columns for joins

If there are any of the above mistakes, rewrite the query. If there are no mistakes, just reproduce the original query.

Output the final SQL query only.
Don't include markdown or code blocks. 
"""
prompt = ChatPromptTemplate.from_messages(
    [("system", system), ("human", "{query}")]
).partial(dialect=db.dialect)
validation_chain = prompt | llm | StrOutputParser()

full_chain = {"query": chain} | validation_chain



query = full_chain.invoke(
    {
        # "question": "고객들이 가장 많이 사용하는 카드명은 뭐야?"
        # "question": "신용카드로 가장 많은 금액을 사용한 업종은 뭐야?"
        "question": "50대 고객들이 사용하는 카드 중 가장 많이 사용하는 카드는 무엇인가요?"
    }
)

print(query)

db.run(query)
# result[0].get("contents")

SELECT "c"."card_name", COUNT(*) AS "usage_count"
FROM "customer" AS "cu"
JOIN "customer_own_card" AS "coc" ON "cu"."customer_id" = "coc"."customer_id"
JOIN "issued_card" AS "ic" ON "coc"."card_no" = "ic"."card_no"
JOIN "card" AS "c" ON "ic"."card_id" = "c"."card_id"
WHERE "cu"."age" BETWEEN 50 AND 59
GROUP BY "c"."card_name"
ORDER BY "usage_count" DESC
LIMIT 1;


"[('KB 국민 청춘대로카드', 10)]"

In [5]:

config_list = autogen.config_list_from_json(
    env_or_file="OAI_CONFIG_LIST.json",
)

def create_sql(query):
    return full_chain.invoke({"question": query})
    # result = db.run(q)
    # print(f"query: {q} \n result: {result}")
    # return result

llm_function_config = {
    "config_list": config_list, 
    "stream": True,
    # "seed": 42,
    "functions": [
        {
            "name": "create_sql",
            "parameters": {
                "type": "object",
                "properties": {
                    "query": {
                        "type": "string",
                        "description": "Create sql to execute on the database"
                    }
                },
                "required": ["query"]
            }
        }
    ]
}



In [19]:
import json
import re

import sqlite3

from textwrap import dedent
from typing import Callable, Dict, List, Optional, Literal, Optional, Union

from autogen import ConversableAgent
from autogen.coding import CodeExecutor, CodeExtractor, MarkdownCodeExtractor, CodeBlock, CodeResult
from autogen.runtime_logging import log_new_agent, logging_enabled

class SqlQueryCodeExecutor(CodeExecutor):
    def __init__(self, db_path, **kwargs):
        self.db_path = db_path

    @property
    def code_extractor(self) -> CodeExtractor:
        return MarkdownCodeExtractor()

    def execute_code_blocks(self, code_blocks: List[CodeBlock]) -> CodeResult:
        logs_all = ""
        for idx, code_block in enumerate(code_blocks, start=1):
            lang, code = code_block.language, code_block.code
            lang = lang.lower()

            if lang != "sql":
                logs_all += "\n" + f"Skipping execution: language not specified (code block #{idx})."
                continue

            try:
                with sqlite3.connect(self.db_path) as cnn:
                    cnn.row_factory = sqlite3.Row
                    cur = cnn.cursor()
                    cur.execute(code)
                    cres = [dict(row) for row in cur.fetchall()]
                    res_json = json.dumps(cres, default=str)
            except Exception as e:
                exitcode = -1
                logs_all += f"\nError: {str(e)}"
                break
            exitcode = 0
            logs_all += res_json
        return CodeResult(exit_code=exitcode, output=logs_all)

    def restart(self) -> None:
        pass 


class UserProxyAgentForSql(ConversableAgent):
    DEFAULT_USER_PROXY_AGENT_DESCRIPTIONS = {
        "ALWAYS": dedent(\
            """An attentive HUMAN user who can answer questions about the task, and can perform tasks such as running sql query
               or inputting command line commands at a Linux terminal and reporting back the execution results."""),
        "TERMINATE": dedent(\
            """A user that can run sql query or input command line commands at a Linux terminal and report back the execution results."""),
        "NEVER": dedent(\
            """An sql bot that performs no other action than running sql query (provided to it's quoted in sql query blocks)."""),
    }

    def __init__(
        self,
        name: str,
        db_path: str,
        is_termination_msg: Optional[Callable[[Dict], bool]] = None,
        max_consecutive_auto_reply: Optional[int] = None,
        human_input_mode: Literal["ALWAYS", "NEVER", "TERMINATE"] = "NEVER",
        default_auto_reply: Union[str, Dict] = "",
        description: Optional[str] = None,
    ):
        sql_executor = SqlQueryCodeExecutor(db_path=db_path)
        super().__init__(
            name=name,
            is_termination_msg=is_termination_msg,
            max_consecutive_auto_reply=max_consecutive_auto_reply,
            human_input_mode=human_input_mode,
            code_execution_config={"executor": sql_executor},
            default_auto_reply=default_auto_reply,
            description=(
                description if description is not None else self.DEFAULT_USER_PROXY_AGENT_DESCRIPTIONS[human_input_mode]
            ),
        )

        if logging_enabled():
            log_new_agent(self, locals())

    def run_code(self, code, **kwargs):
        return -1, "Not supported", None

    def execute_code_blocks(self, code_blocks):
        return -1, "Not supported"    

In [20]:
user_proxy = UserProxyAgentForSql(
    name="User_proxy",
    db_path = "kb-creditcard.db",
    human_input_mode="NEVER",
    max_consecutive_auto_reply=3,
    is_termination_msg=lambda x: x.get("content", "") and x.get("content", "").rstrip().endswith("TERMINATE"),
)

sql_agent = autogen.AssistantAgent(
  name="chatbot",
  llm_config=llm_function_config,
  system_message="""
    질문 받은 것을 그대로 create_sql Function의 query파라미터에 넣어 SQL을 생성해주세요.
  """,
  
  human_input_mode="NEVER"
)

user_proxy.register_function(
    function_map={
        "create_sql": create_sql
    }
)

sql_agent.register_function(
    function_map={
        "create_sql": create_sql
    }
)
user_proxy.initiate_chat(sql_agent, message="50대 고객들이 사용하는 카드 중 가장 많이 보유한 카드이름은 무엇인가요?.")


[33mUser_proxy[0m (to chatbot):

50대 고객들이 사용하는 카드 중 가장 많이 보유한 카드이름은 무엇인가요?.

--------------------------------------------------------------------------------


[33mchatbot[0m (to User_proxy):


[32m***** Suggested function call: create_sql *****[0m
Arguments: 
{
  "query": "SELECT card_name, COUNT(*) as count FROM cards WHERE age >= 50 GROUP BY card_name ORDER BY count DESC LIMIT 1;"
}
[32m***********************************************[0m

--------------------------------------------------------------------------------
[35m
>>>>>>>> EXECUTING FUNCTION create_sql...[0m
[33mUser_proxy[0m (to chatbot):

[32m***** Response from calling function (create_sql) *****[0m
SELECT card.card_name, COUNT(*) as count 
FROM customer
JOIN customer_own_card ON customer.customer_id = customer_own_card.customer_id
JOIN issued_card ON customer_own_card.card_no = issued_card.card_no
JOIN card ON issued_card.card_id = card.card_id
WHERE customer.age >= 50
GROUP BY card.card_name
ORDER BY count DESC
LIMIT 1;
[32m*******************************************************[0m

--------------------------------------------------------------------------------


ChatResult(chat_id=None, chat_history=[{'content': '50대 고객들이 사용하는 카드 중 가장 많이 보유한 카드이름은 무엇인가요?.', 'role': 'assistant'}, {'content': '', 'function_call': {'arguments': '{\n  "query": "SELECT card_name, COUNT(*) as count FROM cards WHERE age >= 50 GROUP BY card_name ORDER BY count DESC LIMIT 1;"\n}', 'name': 'create_sql'}, 'role': 'assistant'}, {'content': 'SELECT card.card_name, COUNT(*) as count \nFROM customer\nJOIN customer_own_card ON customer.customer_id = customer_own_card.customer_id\nJOIN issued_card ON customer_own_card.card_no = issued_card.card_no\nJOIN card ON issued_card.card_id = card.card_id\nWHERE customer.age >= 50\nGROUP BY card.card_name\nORDER BY count DESC\nLIMIT 1;', 'name': 'create_sql', 'role': 'function'}, {'content': 'Here is the SQL query to find out the most popular card among 50s customers:\n\n```sql\nSELECT card.card_name, COUNT(*) as count \nFROM customer\nJOIN customer_own_card ON customer.customer_id = customer_own_card.customer_id\nJOIN issued_card ON cu