In [None]:
%pip install numexpr

In [16]:
%pip install google-search-results


Note: you may need to restart the kernel to use updated packages.


In [38]:
import html
import os
import time

import openai
from dotenv import load_dotenv
from langchain.agents import create_sql_agent
from langchain.agents import initialize_agent, Tool
from langchain.agents.agent_toolkits import SQLDatabaseToolkit
from langchain.agents.agent_types import AgentType
from langchain.chains import LLMMathChain
from langchain.chat_models import ChatOpenAI
from langchain.memory import ConversationBufferMemory
from langchain.prompts import MessagesPlaceholder
from langchain.utilities import SerpAPIWrapper
from langchain.utilities import SQLDatabase
from langchain_experimental.sql import SQLDatabaseChain
from openai import RateLimitError

In [39]:
# Read the .env file into the process environment, then pick up the OpenAI
# credential from it (the value is None when the variable is not set).
load_dotenv()

openai_api_key = os.environ.get('OPENAI_API_KEY')

In [40]:
# Open the SQLite database file Chinook.db through its connection URI.
db = SQLDatabase.from_uri("sqlite:///Chinook.db")


In [41]:
# Sanity check of the connection: fetch up to five rows from the Artist table.
db.run("select * from Artist limit 5")

"[(1, 'AC/DC'), (2, 'Accept'), (3, 'Aerosmith'), (4, 'Alanis Morissette'), (5, 'Alice In Chains')]"

In [42]:
# Chat model shared by the SQL chain, the math chain and the agent defined below.
llm = ChatOpenAI(
    model="gpt-4o-mini",
    temperature=0.0,  # presumably chosen for repeatable answers — confirm
    openai_api_key=openai_api_key,
    verbose=True,
)

In [37]:
# Show which model name the chat client is configured with.
print(llm.model_name)

gpt-4o-mini


In [57]:
import re
class CleanSQLDatabaseChain(SQLDatabaseChain):
    """SQLDatabaseChain variant that strips Markdown code fences from returned SQL.

    After the parent chain has produced its result, the ``sql_cmd`` entry of
    that result (when present) has its code-fence markers removed.
    """

    def _strip_markdown(self, sql: str) -> str:
        """Remove code-fence markers and an attached ``sql`` language hint.

        The hint is matched case-insensitively, so an upper-case ``SQL`` hint
        is removed together with the fence instead of being left in the query.

        Args:
            sql: Text that may contain Markdown code fences.

        Returns:
            ``sql`` without fence markers, with surrounding whitespace trimmed.
        """
        return re.sub(r"```(?:sql)?", "", sql, flags=re.IGNORECASE).strip()

    def _call(self, inputs: dict, run_manager=None):
        """Run the parent chain, then clean the ``sql_cmd`` entry of its result.

        NOTE(review): the cleaning happens after ``super()._call`` returns, so it
        only changes the value handed back to the caller. Presumably the parent
        has already executed the generated query by then — verify against the
        installed langchain_experimental version if the goal is to clean the SQL
        before it reaches the database.

        Args:
            inputs: Chain inputs, forwarded unchanged to the parent.
            run_manager: Optional callback manager, forwarded unchanged.

        Returns:
            The parent's result, with ``sql_cmd`` stripped when it is a string.
        """
        response = super()._call(inputs, run_manager)

        # Only touch the entry when it exists and is text; anything else is
        # passed through untouched instead of raising inside re.sub.
        if 'sql_cmd' in response and isinstance(response['sql_cmd'], str):
            response['sql_cmd'] = self._strip_markdown(response['sql_cmd'])

        return response

In [58]:
# Build the SQL question-answering chain from the chat model and the database.
db_chain = CleanSQLDatabaseChain.from_llm(llm, db, verbose=True)

In [44]:

# Math chain built on the same chat model; it backs the agent's "Calculator" tool.
llm_math_chain = LLMMathChain.from_llm(llm=llm, verbose=True)


In [45]:
# Web-search wrapper; its SerpAPI key comes from the environment
# (the value is None when SERPAPI_API_KEY is not set).
serpapi_api_key = os.environ.get('SERPAPI_API_KEY')
search = SerpAPIWrapper(serpapi_api_key=serpapi_api_key)

In [47]:
# Tools offered to the agent. Each description is the text the model reads when
# deciding which tool to call, so it has to describe what the tool really does.
tools = [
    Tool(
        name="Search",
        func=search.run,
        description="useful for when you need to answer questions about current events. You should ask targeted questions",
    ),
    Tool(
        name="Calculator",
        func=llm_math_chain.run,
        description="useful for when you need to answer questions about math",
    ),
    Tool(
        # The tool name is kept as-is so anything referring to it keeps working.
        name="FooBar-DB",
        func=db_chain.run,
        # The original description only mentioned the placeholder "FooBar";
        # this chain queries Chinook.db, so the description names that data.
        description="useful for when you need to answer questions about the data stored in the Chinook music database, such as artists. Input should be in the form of a question containing full context",
    ),
]

In [48]:
# Conversation memory stored under the "memory" key and returned as message
# objects (return_messages=True); the MessagesPlaceholder in agent_kwargs uses
# the same "memory" name.
memory = ConversationBufferMemory(memory_key="memory", return_messages=True)

In [49]:
# Extra prompt pieces for the agent: a placeholder for the messages kept under
# the "memory" variable.
agent_kwargs = {
    "extra_prompt_messages": [MessagesPlaceholder(variable_name="memory")],
}

In [50]:
# Assemble the agent from the tools, chat model, prompt extras and memory
# defined in the cells above.
agent = initialize_agent(
    tools,
    llm,
    agent=AgentType.OPENAI_FUNCTIONS,
    verbose=True,
    agent_kwargs=agent_kwargs,
    memory=memory,
)

In [51]:
# Define the handle_chat function

def handle_chat(query, retries=3, delay=5):
    """Send ``query`` to the agent, retrying when the rate limit is hit.

    Args:
        query: User text passed to the agent under the "input" key.
        retries: Maximum number of attempts.
        delay: Seconds to wait between two attempts.

    Returns:
        Whatever ``agent.invoke`` returns for the first successful attempt.

    Raises:
        RuntimeError: When every attempt failed with ``RateLimitError``; the
            last rate-limit error is attached as the cause.
    """
    last_error = None
    for attempt in range(retries):
        try:
            return agent.invoke({"input": query})
        except RateLimitError as exc:
            last_error = exc
            # Waiting after the final attempt would only delay the failure.
            if attempt + 1 < retries:
                print(f"[Retry {attempt+1}] Rate limit hit. Waiting {delay} seconds...")
                time.sleep(delay)
    # RuntimeError is still caught by callers handling Exception; chaining keeps
    # the underlying rate-limit error visible in the traceback.
    raise RuntimeError("Rate limit exceeded after multiple retries.") from last_error


In [52]:
ONE_MILLION = 1000000


def calculate_cost(query, model, response):
    """Estimate the cost of one exchange from whitespace-separated word counts.

    Word counts stand in for token counts, so the result is only a rough
    estimate.

    Args:
        query: Prompt text sent to the model.
        model: Model name; 'gpt-4o-mini' and 'gpt-4o' have rates defined.
        response: Reply text, or a dict holding the reply text under the
            'output' key (the same key the chat UI reads from the value
            returned by handle_chat).

    Returns:
        The estimated cost as a float, or None when ``model`` has no rate
        defined.
    """
    # (input rate, output rate) per one million tokens.
    # NOTE(review): rates are hard-coded — verify against current pricing.
    rates = {
        'gpt-4o-mini': (0.6, 2.4),
        'gpt-4o': (5, 20),
    }
    if model not in rates:
        return None
    input_rate, output_rate = rates[model]

    # Accept the dict form as well, so the agent result can be passed directly.
    if isinstance(response, dict):
        response = response.get('output', '')

    num_input_token = len(query.strip().split())
    num_output_token = len(str(response).split())
    return (input_rate*num_input_token/ONE_MILLION) + (output_rate*num_output_token/ONE_MILLION)
    

In [None]:
# Example usage 01
query = 'How many artists are there in our database?'
query2 = 'Select * from Artist where ArtistId = 1'
response = handle_chat(query2)
print(response) 
# Estimate the cost for the query that was actually sent (query2), using the
# agent's reply text rather than the whole result dict.
cost = calculate_cost(query2, llm.model_name, response['output'])
# An f-string is used because the cost is a number and cannot be joined to a
# string with "+".
print(f"Cost of this prompt: {cost}")




[1m> Entering new AgentExecutor chain...[0m
[32;1m[1;3m
Invoking: `FooBar-DB` with `Select * from Artist where ArtistId = 1`


[0m

[1m> Entering new SQLDatabaseChain chain...[0m
Select * from Artist where ArtistId = 1
SQLQuery:[32;1m[1;3mQuestion: Select * from Artist where ArtistId = 1  
SQLQuery: SELECT "ArtistId", "Name" FROM "Artist" WHERE "ArtistId" = 1[0m
SQLResult: [33;1m[1;3m[(1, 'AC/DC')][0m
Answer:[32;1m[1;3mQuestion: Select * from Artist where ArtistId = 1  
SQLQuery: SELECT "ArtistId", "Name" FROM "Artist" WHERE "ArtistId" = 1[0m
[1m> Finished chain.[0m
[38;5;200m[1;3mQuestion: Select * from Artist where ArtistId = 1  
SQLQuery: SELECT "ArtistId", "Name" FROM "Artist" WHERE "ArtistId" = 1[0m[32;1m[1;3mThe details for the artist with ArtistId = 1 are as follows:

- ArtistId: 1
- Name: "Artist A"

If you need more information or have any other queries, feel free to ask![0m

[1m> Finished chain.[0m
{'input': 'Select * from Artist where ArtistId = 1

UI For Chatbot

In [12]:
import gradio as gr
import streamlit as st



In [13]:
# Add a logo and title
# NOTE(review): logo_url below is an absolute path on one developer's machine,
# not a URL. A browser-rendered <img src> presumably cannot load it, and it will
# not exist on other machines — confirm and replace with a served or relative asset.
logo_url = "/Users/pro/Documents/Documents - pro’s MacBook Pro - 1/Học Viện/Năm 4-HKII/Kỹ thuật theo dõi giám sát an toàn mạng/chatbot_sql/logo.avif"  # Replace with the actual URL of your logo
# Render the page header: the title on the left and the logo on the right,
# laid out with a flex container. Raw HTML requires unsafe_allow_html=True.
st.markdown(
    f"""
    <div style="display: flex; justify-content: space-between; align-items: center;">
        <h1>Chinook Tunes</h1>
        <img src="{logo_url}" style="height: 50px;">
    </div>
    """,
    unsafe_allow_html=True,
)



DeltaGenerator()

In [6]:
# Initialize chat history: create the "history" list in the session state only
# when it is not there yet, so existing entries are not overwritten.
if "history" not in st.session_state:
    st.session_state["history"] = []

2025-05-29 17:26:28.776 Session state does not function when running a script without `streamlit run`


In [7]:
# Callback that forwards the typed message to the agent
def send_message():
    """Send the current input to the agent and record both sides in the history.

    Does nothing when the input field is empty. Otherwise the user text and the
    agent's 'output' are appended to st.session_state["history"] and the input
    field is cleared.
    """
    text = st.session_state.user_input
    if not text:
        return

    reply = handle_chat(text)
    transcript = st.session_state["history"]
    transcript.append(("You", text))
    transcript.append(("CT", reply['output']))
    st.session_state.user_input = ""

# Text box bound to the "user_input" session key; a change triggers send_message
user_input = st.text_input(
    "Enter your message:",
    key="user_input",
    on_change=send_message,
)



In [14]:
# Send button
# The click is handled through a callback instead of calling send_message()
# inline: send_message() resets st.session_state.user_input, and Streamlit only
# allows a widget-bound key to be reassigned before that widget is created in
# the current run, which is the case inside callbacks. send_message() itself
# ignores the click when the input is empty.
send_button = st.button("Send", on_click=send_message)



In [15]:
# Display chat history (newest entry first)
for user, message in reversed(st.session_state["history"]):
    # The text comes from the user or the model and is placed inside raw HTML,
    # so it is escaped first to keep it from being interpreted as markup.
    safe_message = html.escape(str(message))
    if user == "You":
        st.markdown(f"<div style='text-align: right;'><b>You:</b> {safe_message}</div>", unsafe_allow_html=True)
    else:
        st.markdown(f"<div style='text-align: left;'><b>CT:</b> {safe_message}</div>", unsafe_allow_html=True)

# Horizontal line for separation
st.markdown("---")



DeltaGenerator()