In [1]:
# Install necessary packages
%pip install langchain_ollama langchain langchain_community pyodbc mssql langchain_huggingface faiss-cpu

Note: you may need to restart the kernel to use updated packages.


In [2]:
import os
import sqlite3
import uuid
import gradio as gr
from langchain_ollama import ChatOllama
from langchain_community.utilities import SQLDatabase
from langchain_community.tools.sql_database.tool import QuerySQLDataBaseTool, InfoSQLDatabaseTool, ListSQLDatabaseTool, QuerySQLCheckerTool
from langchain_huggingface import HuggingFaceEmbeddings
from langchain_community.vectorstores import FAISS
from langchain_core.example_selectors import SemanticSimilarityExampleSelector
from langchain_core.prompts import FewShotPromptTemplate, PromptTemplate, ChatPromptTemplate, SystemMessagePromptTemplate

  from .autonotebook import tqdm as notebook_tqdm


In [4]:
# Set up environment variables for LangChain API and other configurations
langsmith_api_key = os.environ.get("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")
os.environ["LANGCHAIN_TRACING_V2"] = "true"
os.environ["LANGCHAIN_PROJECT"] = "Local SQL Agent"
os.environ["LANGCHAIN_ENDPOINT"] = "https://api.smith.langchain.com"
os.environ["LANGCHAIN_API_KEY"] = "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"
os.environ["TOKENIZERS_PARALLELISM"] = "false"

In [5]:
# Initialize the LLM model
llm = ChatOllama(model="llama3.1:8b-instruct-q4_0")

In [6]:
# Step 1: SQLite Setup and SQL Script Execution
def setup_sqlite_db(db_path='chinook.db', sql_script_path='chinook.sql'):
    conn = sqlite3.connect(db_path)
    with open(sql_script_path, 'r') as file:
        sql_script = file.read()
    cursor = conn.cursor()
    cursor.executescript(sql_script)
    conn.commit()
    cursor.execute("SELECT name FROM sqlite_master WHERE type='table';")
    tables = cursor.fetchall()
    print("Tables in the database:", tables)
    return conn

conn = setup_sqlite_db()

Tables in the database: [('Album',), ('Artist',), ('Customer',), ('Employee',), ('Genre',), ('Invoice',), ('InvoiceLine',), ('MediaType',), ('Playlist',), ('PlaylistTrack',), ('Track',)]


In [7]:
# Step 2: Create the SQLDatabase object for LangChain tools
database_path = 'chinook.db'
connection_url = f"sqlite:///{database_path}"
db = SQLDatabase.from_uri(connection_url, sample_rows_in_table_info=3)

In [8]:
# Example Queries
examples = [
    { "input": "List all artists.", "query": "SELECT * FROM Artist;" },
    { "input": "Find all albums for the artist 'AC/DC'.", "query": "SELECT * FROM Album WHERE ArtistId = (SELECT ArtistId FROM Artist WHERE Name = 'AC/DC');" },
    { "input": "List all tracks in the 'Rock' genre.", "query": "SELECT * FROM Track WHERE GenreId = (SELECT GenreId FROM Genre WHERE Name = 'Rock');" },
    { "input": "Find the total duration of all tracks.", "query": "SELECT SUM(Milliseconds) FROM Track;" },
    { "input": "List all customers from Canada.", "query": "SELECT * FROM Customer WHERE Country = 'Canada';" },
    { "input": "How many tracks are there in the album with ID 5?", "query": "SELECT COUNT(*) FROM Track WHERE AlbumId = 5;" },
    { "input": "Find the total number of Albums.", "query": "SELECT COUNT(DISTINCT AlbumId) FROM Album;" },
    { "input": "List all tracks that are longer than 5 minutes.", "query": "SELECT * FROM Track WHERE Milliseconds > 300000;" },
    { "input": "Who are the top 5 customers by total purchase?", "query": "SELECT CustomerId, SUM(Total) AS TotalPurchase FROM Invoice GROUP BY CustomerId ORDER BY TotalPurchase DESC LIMIT 5;" },
    { "input": "How many employees are there?", "query": "SELECT COUNT(*) FROM Employee;" },
]
print(len(examples))

10


In [9]:
# Set up Embeddings and Example Selector
embeddings = HuggingFaceEmbeddings(model_name='sentence-transformers/all-MiniLM-L6-v2')
example_selector = SemanticSimilarityExampleSelector.from_examples(
    examples,
    embeddings,
    FAISS,
    k=2,
    input_keys=["input"],
)



In [10]:
# Define System Prompt Template
system_prefix = """You are an agent designed to interact with the local SQLite database.
Given an input question, create a syntactically correct SQL query to run, then look at the results of the query and return the answer.
Unless the user specifies a specific number of examples they wish to obtain, always limit your query to at most 5 results.
You can order the results by a relevant column to return the most interesting examples in the database.
Never query for all the columns from a specific table, only ask for the relevant columns given the question.

You have access to the following tools for interacting with the database:

{tools}

Use the following format:

Question: the input question you must answer
Thought: you should always think about what to do
Action: the action to take, should be one of {tool_names}
Action Input: the input to the action
Observation: the result of the action
... (this Thought/Action/Action Input/Observation can repeat N times)
Thought: I now know the final answer
Final Answer: the final answer to the original input question

You MUST double check your query before executing it. If you get an error while executing a query, rewrite the query and try again.

DO NOT make any DML statements (INSERT, UPDATE, DELETE, DROP etc.) to the database.

If the question does not seem related to the database, just return "I don't know" as the answer.
If you see you are repeating yourself, just provide final answer and exit.

Here are some examples of user inputs and their corresponding SQL queries:"""


In [11]:
# Dynamic Few-Shot Prompt Template
dynamic_few_shot_prompt = FewShotPromptTemplate(
    example_selector=example_selector,
    example_prompt=PromptTemplate.from_template("User input: {input}\nSQL query: {query}"),
    input_variables=["input"],
    prefix=system_prefix,
    suffix=""
)

# Full Chat Prompt Template
full_prompt = ChatPromptTemplate.from_messages(
    [
        SystemMessagePromptTemplate(prompt=dynamic_few_shot_prompt),
        ("human", "{input}"),
        ("system", "{agent_scratchpad}"),
    ]
)

In [12]:
# Step 3: Create Tools Using the SQLDatabase Object
tools = [
    QuerySQLDataBaseTool(db=db),
    InfoSQLDatabaseTool(db=db),
    ListSQLDatabaseTool(db=db),
    QuerySQLCheckerTool(db=db, llm=llm)
]

# Utility Functions for SQL Query Handling
def check_sql_query(query, max_retries=3):
    for i in range(max_retries):
        try:
            print(f"Attempt {i+1}: Checking SQL query...")
            action_output = QuerySQLCheckerTool(db=db, llm=llm).run(query)
            if action_output:
                print("Query check successful.")
                return action_output
            else:
                print("Query check failed. Retrying...")
        except Exception as e:
            print(f"Error on attempt {i+1}: {e}")
    print("Max retries reached. Exiting with error.")
    return None

def execute_sql_query(query):
    try:
        print("Executing SQL query directly...")
        query_tool = QuerySQLDataBaseTool(db=db)
        result = query_tool.run(query)
        print("Query executed. Result:", result)
        return result
    except Exception as e:
        print(f"Error during SQL query execution: {e}")
        return None

def extract_sql_query_from_prompt(prompt_val):
    for message in prompt_val.messages:
        if "SQL query:" in message.content:
            start_index = message.content.index("SQL query:") + len("SQL query:")
            sql_query = message.content[start_index:].strip()
            return sql_query
    return None

In [14]:
# Step 4: Process the Prompt and Generate Response
def process_prompt(input_question):
    print("Starting query processing...")
    prompt_val = full_prompt.invoke(
        {
            "input": input_question,
            "tool_names": [tool.name for tool in tools],
            "tools": [tool.name + " - " + tool.description.strip() for tool in tools],
            "agent_scratchpad": [],
        }
    )
    
    print(prompt_val)

    generated_query = extract_sql_query_from_prompt(prompt_val)
    if generated_query:
        checked_query = check_sql_query(generated_query)
        if checked_query:
            result = execute_sql_query(checked_query)
            if result:
                final_answer = result
                print("Final Answer:", final_answer)
                return final_answer
            else:
                return "Execution failed after validation."
        else:
            return "Query check failed after multiple attempts. Aborting execution."
    else:
        return "No valid SQL query generated by the prompt."
    
# Format the Output
def format_output(data):
    if not isinstance(data, list):
        return str(data)
    formatted_rows = ["   ".join([str(item) if item is not None else '' for item in row]) for row in data]
    return "\n".join(formatted_rows)

In [16]:
# Gradio Interface Setup
def respond(message):
    """
    Processes user input and returns a response formatted for the Gradio chatbot interface.

    Args:
        message (str): The user's message input.

    Returns:
        list: A list of tuples with the user input and the system response.
    """
    session_id = uuid.uuid4().hex
    print("Session ID: ", session_id)
    print(f"Received user input: {message}")
    final_answer = process_prompt(message)
    formatted_answer = format_output(final_answer)
    return [(message, formatted_answer)]

# Launch Gradio Interface
with gr.Blocks() as demo:
    chatbot = gr.Chatbot()
    msg = gr.Textbox(label="Your Question", placeholder="Ask your database a question...")
    clear = gr.ClearButton([msg, chatbot])

    # When the user submits a message, it is processed and the response is displayed
    msg.submit(respond, [msg], chatbot)

demo.launch(share=True)

Running on local URL:  http://127.0.0.1:7882
Running on public URL: https://9e708f587b3a60c871.gradio.live

This share link expires in 72 hours. For free permanent hosting and GPU upgrades, run `gradio deploy` from Terminal to deploy to Spaces (https://huggingface.co/spaces)




Session ID:  00af11a9565942a3b05587375cc55ee9
Received user input: Find the total number of Albums.
Starting query processing...
messages=[SystemMessage(content='You are an agent designed to interact with the local SQLite database.\nGiven an input question, create a syntactically correct SQL query to run, then look at the results of the query and return the answer.\nUnless the user specifies a specific number of examples they wish to obtain, always limit your query to at most 5 results.\nYou can order the results by a relevant column to return the most interesting examples in the database.\nNever query for all the columns from a specific table, only ask for the relevant columns given the question.\n\nYou have access to the following tools for interacting with the database:\n\n[\'sql_db_query - Execute a SQL query against the database and get back the result..\\n    If the query is not correct, an error message will be returned.\\n    If an error is returned, rewrite the query, check th