In [None]:
# Install/upgrade the packages this notebook imports: LangChain (core, community, OpenAI bindings) and python-dotenv
%pip install --upgrade --quiet  langchain langchain-community langchain-openai python-dotenv

In [52]:
import os
from dotenv import find_dotenv, load_dotenv

# Locate a .env file and load the variables it declares into the environment
load_dotenv(find_dotenv())

# API key for the OpenAI models (None when OPENAI_API_KEY is not set)
openai_apikey = os.getenv("OPENAI_API_KEY")

In [53]:
from langchain_community.utilities import SQLDatabase
from langchain_openai import ChatOpenAI

# Open the Chinook sample database stored in the local SQLite file
db = SQLDatabase.from_uri("sqlite:///Chinook.db")

# OpenAI chat model used to review queries written against that database;
# temperature is set to 0 for the least random output
llm = ChatOpenAI(temperature=0, model="gpt-3.5-turbo")

# SQL Dialect Problem
SQL dialects are essentially different versions of the SQL language, each tweaked to suit the peculiarities of an individual database management system (DBMS). Therefore, effective error handling should take these differences into account.

In [54]:
from langchain.chains.sql_database.prompt import SQL_PROMPTS

# Dialect names that have an entry in LangChain's SQL_PROMPTS mapping
supported_dialects = list(SQL_PROMPTS)
print("Different SQL dialects: ", supported_dialects)

# Dialect reported by the database we are connected to
print("Our current dialect: ", db.dialect)

Different SQL dialects:  ['crate', 'duckdb', 'googlesql', 'mssql', 'mysql', 'mariadb', 'oracle', 'postgresql', 'sqlite', 'clickhouse', 'prestodb']
Our current dialect:  sqlite


# Feed the schema context
To write valid queries, the model needs to know our database in detail. To give it relevant context, we feed it the table names, their schemas, and a sample of rows from each table.

If the database schema is too large, we can restrict the context to the tables that users query most often (extracted from their query inputs). In our case, we feed the entire schema for effective error handling.

In [55]:
# Fetch the database context and keep its "table_info" entry
# (the CREATE TABLE statements followed by sample rows, as printed below)
db_context = db.get_context()
context = db_context["table_info"]
print(context)


CREATE TABLE "Album" (
	"AlbumId" INTEGER NOT NULL, 
	"Title" NVARCHAR(160) NOT NULL, 
	"ArtistId" INTEGER NOT NULL, 
	PRIMARY KEY ("AlbumId"), 
	FOREIGN KEY("ArtistId") REFERENCES "Artist" ("ArtistId")
)

/*
3 rows from Album table:
AlbumId	Title	ArtistId
1	For Those About To Rock We Salute You	1
2	Balls to the Wall	2
3	Restless and Wild	2
*/


CREATE TABLE "Artist" (
	"ArtistId" INTEGER NOT NULL, 
	"Name" NVARCHAR(120), 
	PRIMARY KEY ("ArtistId")
)

/*
3 rows from Artist table:
ArtistId	Name
1	AC/DC
2	Accept
3	Aerosmith
*/


CREATE TABLE "Customer" (
	"CustomerId" INTEGER NOT NULL, 
	"FirstName" NVARCHAR(40) NOT NULL, 
	"LastName" NVARCHAR(20) NOT NULL, 
	"Company" NVARCHAR(80), 
	"Address" NVARCHAR(70), 
	"City" NVARCHAR(40), 
	"State" NVARCHAR(40), 
	"Country" NVARCHAR(40), 
	"PostalCode" NVARCHAR(10), 
	"Phone" NVARCHAR(24), 
	"Fax" NVARCHAR(24), 
	"Email" NVARCHAR(60) NOT NULL, 
	"SupportRepId" INTEGER, 
	PRIMARY KEY ("CustomerId"), 
	FOREIGN KEY("SupportRepId") REFERENCES "Empl

# Prompt Template Customization

## Zero-shot prompting
In this section, we tell the model which specific tasks to perform and feed it the right information (SQL dialect, database schema, some sample rows, ...). In this prompting strategy, we do not provide any examples (zero-shot) and let the model decide by itself.

In [56]:
from langchain_core.prompts import ChatPromptTemplate

# System instructions for the zero-shot SQL reviewer.
# Placeholders: {context} is the schema description, {dialect} the SQL dialect name.
# The checklist under task 2 is phrased consistently as *mistakes* to detect, and
# task 3 asks for both the position and the type of every error.
system_template = """
You are an expert SQL assistant with deep knowledge of database schemas and SQL dialects. 
Given this database schema context: {context}
Your main tasks are to:
1. Analyze in-depth the user's {dialect} query
2. Detect all errors in the queries according to the provided database schema, including:
- Using non-existent tables or columns, or incorrect data types for the columns
- Using NOT IN with NULL values
- Using UNION when UNION ALL should have been used
- Using BETWEEN for exclusive ranges
- Improperly quoting identifiers
- Using the wrong number of arguments for functions
3. Identify the position where each error is detected (row:column) and its type (Lexical, Syntax, Semantic, Logical, Referential Integrity or others ...)
4. Suggest corrections to fix the errors.
Please ensure to tailor your responses based on the provided database schema context. 
If many errors exist, then list all details of each error in the same block then move to the next one, errors are ordered by the error's position ascending.
Each error block should be formatted as follows:
>> Error in error_position_row:error_position_column 
Type: ... Type of Error ...
Error: ... description ...
Correction: ... description ...
Corrected SQL query in {dialect}:
```sql
    ...
```
"""

# {dialect} is bound once from the connected database;
# {context} and {query} remain to be supplied when the prompt is invoked.
prompt_template = ChatPromptTemplate.from_messages(
    [("system", system_template), ("user", "{query}")]
).partial(dialect=db.dialect)

In [58]:
from langchain_core.output_parsers import StrOutputParser

# Output parser that reduces the model's response to its text (a plain string)
parser = StrOutputParser()

# Pipeline: customised prompt -> model -> parsed response text
chain = prompt_template | llm | parser

# Try the chain on a sample query, passing the database schema as context
response = chain.invoke(
    {
        "query": "SELECT * FROM employees WHERE (id = 1);",
        "context": db.get_context()["table_info"],
    }
)

print(response)

>> Error in 1:15 
Type: Semantic Error
Error: Table "employees" does not exist in the provided schema.
Correction: Use the correct table name "Employee" instead of "employees".

Corrected SQL query in sqlite:
```sql
SELECT * FROM Employee WHERE (EmployeeId = 1);
```


# Prompt Template Customization

## Few-shot prompting
In this section, we tell the model which specific tasks to perform and feed it the right information (SQL dialect, database schema, some sample rows, ...). In this prompting strategy, we provide the model with a list of examples (few-shot) that can be either static (i.e. examples already defined by the administrator) or dynamic (i.e. selected by comparing the input query to the most similar known errors, as in RAG models). In the following, we try the static approach.

In [76]:
from langchain_core.prompts import FewShotPromptTemplate, PromptTemplate

# Static few-shot examples, stored as flat dicts (one key per prompt field).
# - Table names match the schema that is fed as context (Artist, Album, Customer, ...),
#   so the "correction" of every example is itself valid against that schema.
#   NOTE(review): Track and Invoice follow the standard Chinook naming; confirm
#   against the full output of db.get_context().
# - error_position is "row:column", both 1-indexed, and points at the first
#   offending token of the input query.
flat_examples = [
    {
        "input": "SELEC * FROM Artist;",
        "error_type": "Lexical Error",
        "error_position": "1:1",
        "error_message": "Misspelled keyword 'SELEC'.",
        "correction": "SELECT * FROM Artist;"
    },
    {
        "input": "SELECT * Artist;",
        "error_type": "Syntax Error",
        "error_position": "1:10",
        "error_message": "Missing 'FROM' keyword.",
        "correction": "SELECT * FROM Artist;"
    },
    {
        "input": "SELECT non_existent_column FROM Album;",
        "error_type": "Semantic Error",
        "error_position": "1:8",
        "error_message": "Column 'non_existent_column' does not exist.",
        "correction": "SELECT AlbumId, Title FROM Album;"
    },
    {
        "input": "SELECT * FROM Track WHERE TrackId = 1&;",
        "error_type": "Lexical Error",
        "error_position": "1:38",
        "error_message": "Unrecognized symbol '&'.",
        "correction": "SELECT * FROM Track WHERE TrackId = 1;"
    },
    {
        "input": "SELECT * FROM Customer WHERE (CustomerId = 1;",
        "error_type": "Syntax Error",
        "error_position": "1:30",
        "error_message": "Unmatched parentheses.",
        "correction": "SELECT * FROM Customer WHERE (CustomerId = 1);"
    },
    {
        "input": "SELECT Name FROM Artist GROUP BY ArtistId;",
        "error_type": "Semantic Error",
        "error_position": "1:8",
        "error_message": "Incorrect use of GROUP BY. 'Name' should be in the GROUP BY clause.",
        "correction": "SELECT Name, COUNT(*) FROM Artist GROUP BY Name;"
    },
    {
        "input": "SELECT * FROM Invoice WHERE InvoiceId = 'string_value';",
        "error_type": "Semantic Error",
        "error_position": "1:41",
        "error_message": "Incorrect data type for 'InvoiceId'. Expected integer.",
        "correction": "SELECT * FROM Invoice WHERE InvoiceId = 1;"
    },
    {
        "input": "SELECT ArtistId FROM Artist, Album WHERE ArtistId = 1;",
        "error_type": "Semantic Error",
        "error_position": "1:8",
        "error_message": "Ambiguous column 'ArtistId'. Specify the table name.",
        "correction": "SELECT Artist.ArtistId FROM Artist, Album WHERE Artist.ArtistId = 1;"
    },
    {
        "input": "INSERT INTO Invoice (CustomerId) VALUES (999);",
        "error_type": "Referential Integrity Error",
        "error_position": "1:1",
        "error_message": "Foreign key violation: 'CustomerId' does not exist in the referenced table.",
        "correction": "Ensure the foreign key value exists in the referenced table."
    },
    {
        "input": "SELECT * FROM Employee WHERE EmployeeId > 10 AND EmployeeId < 5;",
        "error_type": "Logical Error",
        "error_position": "1:50",
        "error_message": "Logical error: No value can satisfy both conditions 'EmployeeId > 10' and 'EmployeeId < 5'.",
        "correction": "SELECT * FROM Employee WHERE EmployeeId > 10 OR EmployeeId < 5;"
    },
]

# How a single example is rendered inside the prompt
example_prompt = PromptTemplate.from_template(
    "Input:\n{input}\n\nOutput:\nError Type: {error_type}\nError Position: {error_position}\nError Message: {error_message}\nCorrection: {correction}"
)

# Full prompt = prefix (role, schema context, expected format) + rendered examples + suffix (the user's query).
# The suffix reuses the "Input:" / "Output:" labels of the examples so the model
# continues the exact pattern it has just been shown.
few_shot_prompt_template = FewShotPromptTemplate(
    examples=flat_examples,
    example_prompt=example_prompt,
    prefix=">> You are an SQL assistant with deep knowledge of database schemas and SQL dialects, specifically with this database schema context: {context}. \n\n>> Your main tasks are to analyze user-provided SQL queries in {dialect}, detect any errors (lexical, syntax, semantic, referential integrity, logical), identify the position where each error is detected, and suggest corrections. Use this format for the output:\nError Type: ...\nError Position: row:column\nError Message: ...\nCorrection: ...;\n\n>>Here are some examples:",
    suffix=">> Now, given the following SQL query in {dialect}, analyze and correct it:\n\nInput:\n{input}\n\nOutput:\n",
    input_variables=["input", "dialect", "context"],
)


In [77]:
# Render the few-shot prompt for a sample query to see the text the model receives
input_query = "SELEC name FRM artists;"

formatted_prompt = few_shot_prompt_template.format(
    context=db.get_context()["table_info"],
    dialect=db.dialect,  # dialect of the connected database
    input=input_query,
)
print(formatted_prompt)

>> You are an SQL assistant with deep knowledge of database schemas and SQL dialects, specifically with this database schema context: 
CREATE TABLE "Album" (
	"AlbumId" INTEGER NOT NULL, 
	"Title" NVARCHAR(160) NOT NULL, 
	"ArtistId" INTEGER NOT NULL, 
	PRIMARY KEY ("AlbumId"), 
	FOREIGN KEY("ArtistId") REFERENCES "Artist" ("ArtistId")
)

/*
3 rows from Album table:
AlbumId	Title	ArtistId
1	For Those About To Rock We Salute You	1
2	Balls to the Wall	2
3	Restless and Wild	2
*/


CREATE TABLE "Artist" (
	"ArtistId" INTEGER NOT NULL, 
	"Name" NVARCHAR(120), 
	PRIMARY KEY ("ArtistId")
)

/*
3 rows from Artist table:
ArtistId	Name
1	AC/DC
2	Accept
3	Aerosmith
*/


CREATE TABLE "Customer" (
	"CustomerId" INTEGER NOT NULL, 
	"FirstName" NVARCHAR(40) NOT NULL, 
	"LastName" NVARCHAR(20) NOT NULL, 
	"Company" NVARCHAR(80), 
	"Address" NVARCHAR(70), 
	"City" NVARCHAR(40), 
	"State" NVARCHAR(40), 
	"Country" NVARCHAR(40), 
	"PostalCode" NVARCHAR(10), 
	"Phone" NVARCHAR(24), 
	"Fax" NVARCHAR(24), 


In [78]:
# Running the few-shot prompt end to end

from langchain_core.output_parsers import StrOutputParser

# Output parser that reduces the model's response to its text (a plain string)
parser = StrOutputParser()

# Pipeline: few-shot prompt -> model -> parsed response text
chain = few_shot_prompt_template | llm | parser

# Try the chain on a sample query, passing the dialect and the database schema as context
response = chain.invoke(
    {
        "context": db.get_context()["table_info"],
        "dialect": db.dialect,  # dialect of the connected database
        "input": "SELECT Name FROM artists GROUP BY ArtistId;",
    }
)

print(response)

Error Type: Semantic Error
Error Position: 1:1
Error Message: Incorrect use of GROUP BY. 'Name' should be in the GROUP BY clause.
Correction: SELECT Name, ArtistId FROM artists GROUP BY ArtistId;
