In [None]:
%pip install --upgrade --quiet  langchain langchain-community langchain-openai python-dotenv

In [1]:
import os
from dotenv import load_dotenv, find_dotenv
# load the .env variables
load_dotenv(find_dotenv())

# Importing the openai api key to use openai models
openai_apikey = os.environ.get("OPENAI_API_KEY")

In [2]:
from langchain_openai import ChatOpenAI
from langchain_community.utilities import SQLDatabase

# Loading the Chinook database
db = SQLDatabase.from_uri("sqlite:///Chinook.db")
# Loading the openai model that consumes this db
llm = ChatOpenAI(model="gpt-3.5-turbo", temperature=0)

# SQL Dialect Problem
SQL dialects are essentially different versions of the SQL language, each tweaked to suit the peculiarities of individual database systems (SGBD). Therefore, an effective error handling should take into account these differences.

In [6]:
from langchain.chains.sql_database.prompt import SQL_PROMPTS

# List of supported dialects
print("Different SQL dialects: ", list(SQL_PROMPTS))

# Our database dialect
print("Our current dialect: ", db.dialect)

Different SQL dialects:  ['crate', 'duckdb', 'googlesql', 'mssql', 'mysql', 'mariadb', 'oracle', 'postgresql', 'sqlite', 'clickhouse', 'prestodb']
Our current dialect:  sqlite


# Feed the schema context
In order to be able to write valid queries, the model needs to know in details our database. So to give it a relevant context, we will feed it with the table names, their schemas, and a sample of rows from each table. 

If the database schema is too large, we can choose only specific tables that are mostly used by users (extracted from their query inputs). In our case, we are going to feed the entire schema for effective error handling.  

In [5]:
context = db.get_context()["table_info"]
print(context)


CREATE TABLE "Album" (
	"AlbumId" INTEGER NOT NULL, 
	"Title" NVARCHAR(160) NOT NULL, 
	"ArtistId" INTEGER NOT NULL, 
	PRIMARY KEY ("AlbumId"), 
	FOREIGN KEY("ArtistId") REFERENCES "Artist" ("ArtistId")
)

/*
3 rows from Album table:
AlbumId	Title	ArtistId
1	For Those About To Rock We Salute You	1
2	Balls to the Wall	2
3	Restless and Wild	2
*/


CREATE TABLE "Artist" (
	"ArtistId" INTEGER NOT NULL, 
	"Name" NVARCHAR(120), 
	PRIMARY KEY ("ArtistId")
)

/*
3 rows from Artist table:
ArtistId	Name
1	AC/DC
2	Accept
3	Aerosmith
*/


CREATE TABLE "Customer" (
	"CustomerId" INTEGER NOT NULL, 
	"FirstName" NVARCHAR(40) NOT NULL, 
	"LastName" NVARCHAR(20) NOT NULL, 
	"Company" NVARCHAR(80), 
	"Address" NVARCHAR(70), 
	"City" NVARCHAR(40), 
	"State" NVARCHAR(40), 
	"Country" NVARCHAR(40), 
	"PostalCode" NVARCHAR(10), 
	"Phone" NVARCHAR(24), 
	"Fax" NVARCHAR(24), 
	"Email" NVARCHAR(60) NOT NULL, 
	"SupportRepId" INTEGER, 
	PRIMARY KEY ("CustomerId"), 
	FOREIGN KEY("SupportRepId") REFERENCES "Empl

# Prompt Template Customization

## Zero-shot prompting
In this section, we need to tell our model the specific tasks to do and feed it with the right information (sql dialect, database schema, some db samples..). In this prompting strategy, we do not provide any example (zero-shot) and let the model decide by itself.

In [26]:
from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder

system_template = """
You are an expert SQL assistant with deep knowledge of database schemas and SQL dialects. 
Given this database schema context: {context}
Your main tasks are to:
1. Analyze in-depth the user's {dialect} query
2. Detect all errors in the queries according to the provided database schema, including:
- Using proper tables, columns and Correct Data types for the columns
- Using NOT IN with NULL values
- Using UNION when UNION ALL should have been used
- Using BETWEEN for exclusive ranges
- Properly quoting identifiers
- Using the correct number of arguments for functions
3. Identify the position where each error is detected (row:column) its their type (Lexical, Syntax, Semantic, Logical, Referential Integrity or others ...)
4. Suggest corrections to fix the errors.
Please ensure to tailor your responses based on the provided database schema context. 
You have to list and correct all detected errors (not only one) with detailed response, if no error found then return "The input query is correct" as the answer.
Each error response should be formatted with these following details:
>> Error in error_position_row:error_position_column 
Type: ... Type of Error ...
Error: ... description ...
Correction: ... description ...
Corrected SQL query in {dialect}:
```sql
    ...
```
You have access to tools for interacting with the database.
DO NOT make any DML statements (INSERT, UPDATE, DELETE, DROP etc.) to the database.
If the question does not seem related to the database, just return "I don't know" as the answer.
"""
prompt_template = ChatPromptTemplate.from_messages(
    [
        ("system", system_template), 
        ("user", "{input}"),
        MessagesPlaceholder("agent_scratchpad")
    ]
).partial(dialect=db.dialect)

In [27]:
from langchain_community.agent_toolkits import create_sql_agent

agent_executor = create_sql_agent(llm, db=db, agent_type="openai-tools", prompt=prompt_template , verbose=True)
agent_executor.invoke(
    {
        "context" : db.get_context()["table_info"],
        "input": "SELECT * FROM employees WHERE (id = 1);"
    }
)



[1m> Entering new SQL Agent Executor chain...[0m
[32;1m[1;3m
Invoking: `sql_db_query_checker` with `{'query': 'SELECT * FROM employees WHERE (id = 1);'}`


[0m[36;1m[1;3mSELECT * FROM employees WHERE id = 1;[0m[32;1m[1;3mError in 1:28 
Type: Semantic Error
Error: Table name 'employees' does not exist in the provided schema.
Correction: Replace 'employees' with 'Employee' to match the existing table name.
Corrected SQL query in sqlite:
```sql
SELECT * FROM Employee WHERE id = 1;
```[0m

[1m> Finished chain.[0m


{'context': '\nCREATE TABLE "Album" (\n\t"AlbumId" INTEGER NOT NULL, \n\t"Title" NVARCHAR(160) NOT NULL, \n\t"ArtistId" INTEGER NOT NULL, \n\tPRIMARY KEY ("AlbumId"), \n\tFOREIGN KEY("ArtistId") REFERENCES "Artist" ("ArtistId")\n)\n\n/*\n3 rows from Album table:\nAlbumId\tTitle\tArtistId\n1\tFor Those About To Rock We Salute You\t1\n2\tBalls to the Wall\t2\n3\tRestless and Wild\t2\n*/\n\n\nCREATE TABLE "Artist" (\n\t"ArtistId" INTEGER NOT NULL, \n\t"Name" NVARCHAR(120), \n\tPRIMARY KEY ("ArtistId")\n)\n\n/*\n3 rows from Artist table:\nArtistId\tName\n1\tAC/DC\n2\tAccept\n3\tAerosmith\n*/\n\n\nCREATE TABLE "Customer" (\n\t"CustomerId" INTEGER NOT NULL, \n\t"FirstName" NVARCHAR(40) NOT NULL, \n\t"LastName" NVARCHAR(20) NOT NULL, \n\t"Company" NVARCHAR(80), \n\t"Address" NVARCHAR(70), \n\t"City" NVARCHAR(40), \n\t"State" NVARCHAR(40), \n\t"Country" NVARCHAR(40), \n\t"PostalCode" NVARCHAR(10), \n\t"Phone" NVARCHAR(24), \n\t"Fax" NVARCHAR(24), \n\t"Email" NVARCHAR(60) NOT NULL, \n\t"Support

# Prompt Template Customization

## Few-shot prompting
In this section, we need to tell our model the specific tasks to do and feed it with the right information (sql dialect, database schema, some db samples..). In this prompting strategy, we provide the model with a list of examples (few-shot) that can be either static (i.e already defined examples by the administrator) or dynamic (i.e by comparing the input query to the most similar errors --> RAG models). 

We will attempt in the following the dynamic approch : using a SemanticSimilarityExampleSelector, we will perform a semantic search using the embeddings and vector store we configure to find the examples most similar to the user's input.

In [None]:
%pip install faiss-cpu

In [33]:
from langchain_core.prompts import ChatPromptTemplate, FewShotPromptTemplate, MessagesPlaceholder, PromptTemplate, SystemMessagePromptTemplate
from langchain_community.vectorstores import FAISS
from langchain_core.example_selectors import SemanticSimilarityExampleSelector
from langchain_openai import OpenAIEmbeddings


# Examples based on the Chinook database schema with pre-processed flat structure
examples = [
    {
        "input": "SELEC * FROM artists;",
        "error_type": "Lexical Error",
        "error_position": "1:1",
        "error_message": "Misspelled keyword 'SELEC'.",
        "correction": "SELECT * FROM artists;"
    },
    {
        "input": "SELECT * artists;",
        "error_type": "Syntax Error",
        "error_position": "1:9",
        "error_message": "Missing 'FROM' keyword.",
        "correction": "SELECT * FROM artists;"
    },
    {
        "input": "SELECT non_existent_column FROM albums;",
        "error_type": "Semantic Error",
        "error_position": "1:8",
        "error_message": "Column 'non_existent_column' does not exist.",
        "correction": "SELECT AlbumId, Title FROM albums;"
    },
    {
        "input": "SELECT * FROM tracks WHERE TrackId = 1&;",
        "error_type": "Lexical Error",
        "error_position": "1:35",
        "error_message": "Unrecognized symbol '&'.",
        "correction": "SELECT * FROM tracks WHERE TrackId = 1;"
    },
    {
        "input": "SELECT * FROM customers WHERE (CustomerId = 1;",
        "error_type": "Syntax Error",
        "error_position": "1:39",
        "error_message": "Unmatched parentheses.",
        "correction": "SELECT * FROM customers WHERE (CustomerId = 1);"
    },
    {
        "input": "SELECT Name FROM artists GROUP BY ArtistId;",
        "error_type": "Semantic Error",
        "error_position": "1:1",
        "error_message": "Incorrect use of GROUP BY. 'Name' should be in the GROUP BY clause.",
        "correction": "SELECT Name, COUNT(*) FROM artists GROUP BY Name;"
    },
    {
        "input": "SELECT * FROM invoices WHERE InvoiceId = 'string_value';",
        "error_type": "Semantic Error",
        "error_position": "1:36",
        "error_message": "Incorrect data type for 'InvoiceId'. Expected integer.",
        "correction": "SELECT * FROM invoices WHERE InvoiceId = 1;"
    },
    {
        "input": "SELECT ArtistId FROM artists, albums WHERE ArtistId = 1;",
        "error_type": "Semantic Error",
        "error_position": "1:40",
        "error_message": "Ambiguous column 'ArtistId'. Specify the table name.",
        "correction": "SELECT artists.ArtistId FROM artists, albums WHERE artists.ArtistId = 1;"
    },
    {
        "input": "INSERT INTO invoices (CustomerId) VALUES (999);",
        "error_type": "Referential Integrity Error",
        "error_position": "1:1",
        "error_message": "Foreign key violation: 'CustomerId' does not exist in the referenced table.",
        "correction": "Ensure the foreign key value exists in the referenced table."
    },
    {
        "input": "SELECT * FROM employees WHERE EmployeeId > 10 AND EmployeeId < 5;",
        "error_type": "Logical Error",
        "error_position": "1:50",
        "error_message": "Logical error: No value can satisfy both conditions 'EmployeeId > 10' and 'EmployeeId < 5'.",
        "correction": "SELECT * FROM employees WHERE EmployeeId > 10 OR EmployeeId < 5;"
    },
]

example_selector = SemanticSimilarityExampleSelector.from_examples(
    examples,
    OpenAIEmbeddings(),
    FAISS,
    k=5,
    input_keys=["input"],
)

system_prefix = """
>> You are an SQL assistant with deep knowledge of database schemas and SQL dialects. \n
>> Your main tasks are to analyze user-provided SQL queries in {dialect}, detect any errors (lexical, syntax, semantic, referential integrity, logical), identify the position where each error is detected, and suggest corrections. 
Please ensure to tailor your responses based on the provided database schema context. 
Use this format for the output:\nError Type: ...\nError Position: row:column\nError Message: ...\nCorrection: ...\nFinal Corrected Query: ```sql\n...```\n
You have to list all detected errors (not only one) with detailed response ordered by their position, if no error found then return "The input query is correct" as the answer.
You have access to tools for interacting with the database.
DO NOT make any DML statements (INSERT, UPDATE, DELETE, DROP etc.) to the database.
If the question does not seem related to the database, just return "I don't know" as the answer.
Here are some examples of user inputs and their corresponding SQL queries:
"""

few_shot_prompt = FewShotPromptTemplate(
    example_selector=example_selector,
    example_prompt = PromptTemplate.from_template(
        "Input:\n{input}\n\nOutput:\nError Type: {error_type}\nError Position: {error_position}\nError Message: {error_message}\nCorrection: {correction}"
    ),
    prefix=system_prefix,
    suffix=">> Now, given the following SQL query in {dialect}, analyze and correct it:\n\nUser Input:\n{input}\n\nUser Output:\n",
    input_variables=["input", "dialect", "context"],
)

full_prompt = ChatPromptTemplate.from_messages(
    [
        SystemMessagePromptTemplate(prompt=few_shot_prompt),
        MessagesPlaceholder("agent_scratchpad"),
    ]
)

# Example formatted prompt
prompt_val = full_prompt.invoke(
    {
        "input": "SELECT Name FROM artists GROUP BY ArtistId;",
        "dialect": db.dialect, # use the current dialect
        "agent_scratchpad": [],
    }
)
print(prompt_val.to_string())

System: 
>> You are an SQL assistant with deep knowledge of database schemas and SQL dialects. 

>> Your main tasks are to analyze user-provided SQL queries in sqlite, detect any errors (lexical, syntax, semantic, referential integrity, logical), identify the position where each error is detected, and suggest corrections. 
Please ensure to tailor your responses based on the provided database schema context. 
Use this format for the output:
Error Type: ...
Error Position: row:column
Error Message: ...
Correction: ...
Final Corrected Query: ```sql
...```

You have to list all detected errors (not only one) with detailed response ordered by their position, if no error found then return "The input query is correct" as the answer.
You have access to tools for interacting with the database.
DO NOT make any DML statements (INSERT, UPDATE, DELETE, DROP etc.) to the database.
If the question does not seem related to the database, just return "I don't know" as the answer.
Here are some examples 

In [44]:
# Example of usage
from langchain_community.agent_toolkits import create_sql_agent

agent = create_sql_agent(
    llm=llm,
    db=db,
    prompt=full_prompt,
    verbose=True,
    agent_type="openai-tools",
)
res = agent.invoke({"input": "INSERT INTO orders (order_id, customer_id, order_date) VALUES (1001, 9999, '2023-01-01');"})
print(res)



[1m> Entering new SQL Agent Executor chain...[0m
[32;1m[1;3m
Invoking: `sql_db_query_checker` with `{'query': "INSERT INTO orders (order_id, customer_id, order_date) VALUES (1001, 9999, '2023-01-01');"}`


[0m[36;1m[1;3mINSERT INTO orders (order_id, customer_id, order_date) VALUES (1001, 9999, '2023-01-01');[0m[32;1m[1;3mError Type: Referential Integrity Error
Error Position: 1:37
Error Message: Foreign key violation: 'customer_id' does not exist in the referenced table.
Correction: Ensure the foreign key value exists in the referenced table.
Final Corrected Query: ```sql
INSERT INTO orders (order_id, customer_id, order_date) VALUES (1001, 1, '2023-01-01');
```[0m

[1m> Finished chain.[0m
{'input': "INSERT INTO orders (order_id, customer_id, order_date) VALUES (1001, 9999, '2023-01-01');", 'output': "Error Type: Referential Integrity Error\nError Position: 1:37\nError Message: Foreign key violation: 'customer_id' does not exist in the referenced table.\nCorrection: Ensur