In [None]:
import getpass
import json
import os

import mysql.connector as ms
import pandas as pd
import requests as r
from openai import AzureOpenAI

In [None]:
# Azure OpenAI connection setup.
# The key and endpoint are read from the environment so secrets never have to
# be pasted into (and committed with) the notebook. The empty-string fallbacks
# match the previous placeholder values.
OPEN_AI_API_KEY = os.environ.get("AZURE_OPENAI_API_KEY", "")
OPEN_AI_API_ENDPOINT = os.environ.get("AZURE_OPENAI_ENDPOINT", "")
# Passed as `model` to every chat completion request below.
OPEN_AI_MODEL = "gpt-4o"
OPEN_AI_API_VERSION = "2024-05-01-preview"
# Cache slot for the schema metadata; filled in by a later cell when empty.
SQL_SETUP_PROMPT = None

client = AzureOpenAI(
    azure_endpoint=OPEN_AI_API_ENDPOINT,
    api_key=OPEN_AI_API_KEY,
    api_version=OPEN_AI_API_VERSION,
)

In [None]:
# MySQL connection setup.
MYSQL_SERVER = "localhost"
MYSQL_DB = "SalesDB"
MYSQL_USER = "root"
# The password must not live in the notebook source: take it from the
# MYSQL_PASSWORD environment variable, or ask for it interactively.
MYSQL_PASSWORD = os.environ.get("MYSQL_PASSWORD") or getpass.getpass(
    f"MySQL password for {MYSQL_USER}@{MYSQL_SERVER}: "
)

# Single shared connection used by exec_sql for every statement.
sqlconn = ms.connect(
    user=MYSQL_USER,
    password=MYSQL_PASSWORD,
    host=MYSQL_SERVER,
    database=MYSQL_DB,
    use_pure=True,
)

In [None]:
#Define Tool to execute sql statement/query

def exec_sql(sql_stmt: str):
    """Execute one SQL statement on the shared ``sqlconn`` connection.

    Args:
        sql_stmt: SQL text to execute.

    Returns:
        A list with one dict per row (column name -> value), or ``None`` when
        the statement produced no result set or the result set has no rows.
    """
    with sqlconn.cursor() as cur:
        cur.execute(sql_stmt)

        # Statements such as INSERT/UPDATE/DDL expose no result set
        # (description is None); calling fetchall() on them can raise instead
        # of returning an empty list, so bail out before fetching.
        if cur.description is None:
            return None

        columns = [desc[0] for desc in cur.description]
        rows = cur.fetchall()

    if not rows:
        return None

    return [dict(zip(columns, row)) for row in rows]


# Tool (function-calling) schema advertised to the model. The function name
# must match the Python callable dispatched in get_llm_response.
sql_tools = [
    {
        "type": "function",
        "function": {
            "name": "exec_sql",
            "description": "Execute sql query and return result dict object",
            "parameters": {
                "type": "object",
                "properties": {
                    "sql_stmt": {
                        "type": "string",
                        # Key was misspelled as "decription", so the model
                        # never saw the argument description.
                        "description": "sql query",
                    },
                },
                # Key was misspelled as "reauired", so the argument was not
                # actually marked as mandatory in the schema.
                "required": ["sql_stmt"],
            },
        },
    }
]

In [None]:
#Give the LLM schema context: collect table and column metadata from the database. Define a function for the same.

def sql_setup():
    """Collect column metadata for every table in the ``MYSQL_DB`` schema.

    Queries ``information_schema`` through ``exec_sql``: first for the table
    names, then once per table for its columns.

    Returns:
        dict mapping each table name to whatever ``exec_sql`` returned for its
        column query (a list of dicts with column name, data type and comment,
        or ``None``). Empty dict when the schema has no tables.
    """
    metadata = dict()

    sql_stmt = f"SELECT t.table_name FROM information_schema.tables t WHERE t.table_schema = '{MYSQL_DB}';"

    # exec_sql returns None (not an empty list) when nothing comes back,
    # which previously made the loop below raise TypeError.
    tables = exec_sql(sql_stmt) or []

    for table in tables:
        # The query selects exactly one column, so take its value directly
        # instead of relying on the result key being spelled 'TABLE_NAME'.
        # NOTE(review): key casing appears to vary by server version — confirm.
        table_name = next(iter(table.values()))

        sql_stmt = f"""
                    SELECT
                        c.column_name,
                        c.data_type,
                        c.column_comment
                    FROM information_schema.tables t
                    JOIN information_schema.columns c
                        ON t.table_name = c.table_name
                        AND t.table_schema = c.table_schema
                    WHERE t.table_schema = '{MYSQL_DB}'
                        AND t.table_name = '{table_name}'
                """

        metadata[table_name] = exec_sql(sql_stmt)

    return metadata


In [None]:
#Interact with LLM - prompt generation and invoke chat completion API to get response

def generate_prompt(sql_setup_prompt):
    """Build the system prompt that gives the model its task and schema context.

    Args:
        sql_setup_prompt: Schema metadata to embed; it is rendered with its
            normal string conversion (e.g. the dict returned by ``sql_setup``).

    Returns:
        The prompt text: instructions, the SQL engine, the metadata and the
        expected output format, each on unindented lines.
    """
    # Assemble the prompt from unindented lines. The previous version indented
    # a triple-quoted string and then stripped every pair of spaces, which also
    # deleted double spaces inside the embedded metadata.
    lines = [
        "",
        "You are a SQL AI assistant that helps to write complex SQL queries. "
        "Use the metadata below to build SQL queries.",
        f"Add the schema name {MYSQL_DB} to all table names in the generated query. "
        "You have access to the tool 'exec_sql'; use this tool to execute the query and fetch the result.",
        "Parse the result returned by the tool and convert it into tabular format. "
        "Limit your queries to the provided dataset.",
        "",
        "sql engine: mysql",
        "",
        f"{sql_setup_prompt}",
        "",
        "Output format -",
        "sql statement:",
        "result:",
        "",
    ]

    return "\n".join(lines)



def get_llm_response(prompt, user_query, sql_tools):
    """Answer a user question with the model, letting it run SQL via ``exec_sql``.

    The model is first called with the tool schema. Every tool call it requests
    is executed and its outcome (rows or an error message) is sent back, then a
    second completion produces the final answer. When the model requests no
    tool call, its first answer is used directly.

    Args:
        prompt: System prompt text (instructions plus schema metadata).
        user_query: The user's natural-language question.
        sql_tools: Tool definitions passed as ``tools`` to the chat completion.

    Returns:
        The content of the final assistant message; it is also printed.
    """
    messages = [
        # The system prompt was previously dropped: the key was misspelled
        # ("contnet") and the `prompt` argument was never used.
        {"role": "system", "content": prompt},
        {"role": "user", "content": user_query},
    ]

    chat_completion = client.chat.completions.create(
        model=OPEN_AI_MODEL,
        messages=messages,
        temperature=0,
        tools=sql_tools,
        tool_choice="auto",
        top_p=0.95,
        frequency_penalty=0,
        presence_penalty=0,
        timeout=200,
    )

    response_message = chat_completion.choices[0].message

    if not response_message.tool_calls:
        # Nothing to execute: the first answer is already the final one, so a
        # second round-trip to the model is unnecessary.
        print("No tool call made by model.")
        print(response_message.content)
        return response_message.content

    messages.append(response_message)

    # Reply to every requested tool call, including failed or unknown ones, so
    # each tool_call_id in the assistant message has a matching tool message.
    for tool_call in response_message.tool_calls:
        if tool_call.function.name == "exec_sql":
            try:
                function_args = json.loads(tool_call.function.arguments)
                result = exec_sql(sql_stmt=function_args.get("sql_stmt"))
                # default=str keeps non-JSON values (dates, decimals) readable.
                content = json.dumps(result, default=str)
            except Exception as exc:
                # Report the failure to the model instead of aborting the cell;
                # model-written SQL is expected to fail occasionally.
                content = f"error: {exc}"
        else:
            content = f"error: unknown tool {tool_call.function.name!r}"

        messages.append({
            "tool_call_id": tool_call.id,
            "role": "tool",
            "name": tool_call.function.name,
            "content": content,
        })

    final_response = client.chat.completions.create(
        model=OPEN_AI_MODEL,
        messages=messages,
    )

    final_content = final_response.choices[0].message.content
    print(final_content)
    return final_content


In [None]:
# Reuse metadata already held in SQL_SETUP_PROMPT; query the database for it
# only when the slot is still empty, then render the prompt from it.
SQL_SETUP_PROMPT = SQL_SETUP_PROMPT or sql_setup()

prompt = generate_prompt(sql_setup_prompt=SQL_SETUP_PROMPT)

Time to test some user queries against the database.

In [None]:
# Sample natural-language question sent to the model together with the
# prepared prompt and the SQL tool definitions.
user_query = "Find top 3 products sold"

result = get_llm_response(
    prompt=prompt,
    user_query=user_query,
    sql_tools=sql_tools,
)