# **Package imports**

In [1]:
import ast
import os
import re
from urllib.parse import quote_plus

from langchain.memory import ChatMessageHistory
from langchain.schema import StrOutputParser
from langchain_chroma import Chroma
from langchain_community.tools.sql_database.tool import QuerySQLDatabaseTool
from langchain_community.utilities.sql_database import SQLDatabase
from langchain_core.example_selectors import SemanticSimilarityExampleSelector
from langchain_core.prompts import PromptTemplate
from langchain_core.prompts import ChatPromptTemplate, FewShotChatMessagePromptTemplate, MessagesPlaceholder
from langchain_core.runnables import RunnablePassthrough
from langchain_ollama import OllamaLLM
from langchain_ollama import OllamaEmbeddings

# **Set environment variables**

In [2]:
# API keys are read from the process environment (export them before starting the
# kernel); they must never be written in the notebook source.
# SECURITY: any key that was previously hardcoded in this cell should be treated
# as leaked and rotated.
missing_keys = [name for name in ("GOOGLE_API_KEY", "LANGCHAIN_API_KEY") if not os.environ.get(name)]
if missing_keys:
    print(f"Warning: missing environment variable(s): {', '.join(missing_keys)}")

# Only switch LangSmith tracing on when its API key is actually available.
os.environ["LANGCHAIN_TRACING_V2"] = "true" if os.environ.get("LANGCHAIN_API_KEY") else "false"

In [3]:
# PostgreSQL connection settings. Each value can be overridden through an
# environment variable so real credentials never have to be written in the
# notebook; the fallbacks are the local development defaults.
db_user = os.environ.get("NL2SQL_DB_USER", "postgres")
db_password = os.environ.get("NL2SQL_DB_PASSWORD", "postgres")
db_host = os.environ.get("NL2SQL_DB_HOST", "localhost")
db_name = os.environ.get("NL2SQL_DB_NAME", "Mydatabase")

In [4]:
# SQL Server connection settings for the ODBC connection string. Each value can
# be overridden through an environment variable so real credentials never have
# to be written in the notebook; the fallbacks are the local development defaults.
server = os.environ.get("NL2SQL_MSSQL_SERVER", "ASUS")
database = os.environ.get("NL2SQL_MSSQL_DATABASE", "BikesDB")
username = os.environ.get("NL2SQL_MSSQL_USER", "admin")
password = os.environ.get("NL2SQL_MSSQL_PASSWORD", "admin")

In [5]:
# Open a direct ODBC connection to the SQL Server database described by
# `server` / `database` / `username` / `password`.
import pyodbc
# The doubled braces are f-string escapes: the driver name is sent as {ODBC Driver 17 for SQL Server}.
# NOTE(review): the string carries both UID/PWD and Trusted_Connection=yes; per Microsoft's ODBC
# documentation Windows authentication is then used and UID/PWD are ignored — confirm which
# authentication mode is intended.
connection_string = f"DRIVER={{ODBC Driver 17 for SQL Server}};SERVER={server};DATABASE={database};UID={username};PWD={password};Trusted_Connection=yes"
#db = SQLDatabase.FROM_uri(f"mssql+pyodbc:///?odbc_connect={connection_string}")
conn = pyodbc.connect(connection_string)
# NOTE(review): `db` is a raw pyodbc cursor here and `conn` is never closed in the visible cells;
# the next section rebinds `db` to a LangChain SQLDatabase.
db = conn.cursor()

# NL2SQL Langchain

In [6]:
# Initiate db connection
# User name and password are percent-encoded so that characters such as "@", ":"
# or "/" inside them cannot corrupt the connection URI.
db = SQLDatabase.from_uri(
    f"postgresql+psycopg2://{quote_plus(db_user)}:{quote_plus(db_password)}@{db_host}/{db_name}"
)
# SQL Server alternative (needs the ODBC `connection_string` built in the pyodbc cell):
#db = SQLDatabase.from_uri(f"mssql+pyodbc:///?odbc_connect={connection_string}")

In [7]:
# Database details
# Dialect name of the connected engine; injected into the prompts as {db_type}.
db_type = db.dialect
# Names of the tables exposed by the SQLDatabase wrapper; used later for table selection.
table_names = db.get_usable_table_names()
# Schema description text (CREATE TABLE statements followed by sample rows, see the printed output).
db_info = db.get_table_info()

In [8]:
# Display the schema text (CREATE TABLE statements plus a few sample rows per table).
# NOTE(review): the sample rows include customer contact data; clear this output before sharing the notebook.
print(db_info)


CREATE TABLE customers (
	customernumber INTEGER NOT NULL, 
	customername VARCHAR(50) NOT NULL, 
	contactlastname VARCHAR(50) NOT NULL, 
	contactfirstname VARCHAR(50) NOT NULL, 
	phone VARCHAR(50) NOT NULL, 
	addressline1 VARCHAR(50) NOT NULL, 
	addressline2 VARCHAR(50) DEFAULT NULL::character varying, 
	city VARCHAR(50) NOT NULL, 
	state VARCHAR(50) DEFAULT NULL::character varying, 
	postalcode VARCHAR(15) DEFAULT NULL::character varying, 
	country VARCHAR(50) NOT NULL, 
	salesrepemployeenumber INTEGER, 
	creditlimit NUMERIC(10, 2) DEFAULT NULL::numeric, 
	CONSTRAINT customers_pkey PRIMARY KEY (customernumber), 
	CONSTRAINT customers_salesrepemployeenumber_fkey FOREIGN KEY(salesrepemployeenumber) REFERENCES employees (employeenumber)
)

/*
3 rows from customers table:
customernumber	customername	contactlastname	contactfirstname	phone	addressline1	addressline2	city	state	postalcode	country	salesrepemployeenumber	creditlimit
103	Atelier graphique	Schmitt	Carine 	40.32.2555	54, rue Roya

In [9]:
# Query to get tables and columns information, restricted to the "public" schema
# and listed in declaration order.
query = """
    SELECT TABLE_NAME, COLUMN_NAME, DATA_TYPE
    FROM INFORMATION_SCHEMA.COLUMNS
    WHERE TABLE_SCHEMA = 'public'
    ORDER BY TABLE_NAME, ORDINAL_POSITION;
   """
# Execute the query to get metadata
execute_query = QuerySQLDatabaseTool(db=db)
raw_metadata = execute_query.invoke(query)

# The tool returns its rows as text (the repr of a list of tuples) and an empty
# string when there are no rows. Any other text is reported as an explicit failure
# instead of surfacing as an opaque parsing error.
try:
    metadata = ast.literal_eval(raw_metadata) if raw_metadata else []
except (ValueError, SyntaxError) as exc:
    raise RuntimeError(f"Could not read table metadata: {raw_metadata!r}") from exc
print(metadata)


[('customers', 'customernumber', 'integer'), ('customers', 'customername', 'character varying'), ('customers', 'contactlastname', 'character varying'), ('customers', 'contactfirstname', 'character varying'), ('customers', 'phone', 'character varying'), ('customers', 'addressline1', 'character varying'), ('customers', 'addressline2', 'character varying'), ('customers', 'city', 'character varying'), ('customers', 'state', 'character varying'), ('customers', 'postalcode', 'character varying'), ('customers', 'country', 'character varying'), ('customers', 'salesrepemployeenumber', 'integer'), ('customers', 'creditlimit', 'numeric'), ('employees', 'employeenumber', 'integer'), ('employees', 'lastname', 'character varying'), ('employees', 'firstname', 'character varying'), ('employees', 'extension', 'character varying'), ('employees', 'email', 'character varying'), ('employees', 'officecode', 'character varying'), ('employees', 'reportsto', 'integer'), ('employees', 'jobtitle', 'character var

In [10]:
# Group the (table, column, type) rows by table name:
# {table_name: [{"name": column_name, "type": data_type}, ...]}
tables_info = {}
for table_name, column_name, data_type in metadata:
    tables_info.setdefault(table_name, []).append({"name": column_name, "type": data_type})

# Output the JSON result
print(tables_info)

{'customers': [{'name': 'customernumber', 'type': 'integer'}, {'name': 'customername', 'type': 'character varying'}, {'name': 'contactlastname', 'type': 'character varying'}, {'name': 'contactfirstname', 'type': 'character varying'}, {'name': 'phone', 'type': 'character varying'}, {'name': 'addressline1', 'type': 'character varying'}, {'name': 'addressline2', 'type': 'character varying'}, {'name': 'city', 'type': 'character varying'}, {'name': 'state', 'type': 'character varying'}, {'name': 'postalcode', 'type': 'character varying'}, {'name': 'country', 'type': 'character varying'}, {'name': 'salesrepemployeenumber', 'type': 'integer'}, {'name': 'creditlimit', 'type': 'numeric'}], 'employees': [{'name': 'employeenumber', 'type': 'integer'}, {'name': 'lastname', 'type': 'character varying'}, {'name': 'firstname', 'type': 'character varying'}, {'name': 'extension', 'type': 'character varying'}, {'name': 'email', 'type': 'character varying'}, {'name': 'officecode', 'type': 'character vary

In [11]:
# Define system prompt
# Template variables: {db_type} (SQL dialect), {tables_info} (table -> columns mapping)
# and {question} (the user's request). The model is asked to wrap its answer as
# <SQL>...</> so the query can be pulled out of the raw completion afterwards.
system_prompt = PromptTemplate.from_template(
    """
    You are a {db_type} expert. Given an input question, first create a syntactically correct {db_type} query.

    Pay attention to use only the column names you can see in the tables below. Also, pay attention to which column is in which table.

    Use CURRENT_DATE function in the returned SQL Query if the question involves "today".

    Generate the SQL query without trying to run it or making assumptions about missing data.

    Only use the following tables information in JSON format, with each entry containing the table name as the key and a list of its column names along with their types as values :
    {tables_info}

    Question: {question}

    In case you couldn't generate a correct SQL query ask the user for more details or to reformulate his question using the language of the user's question.

    Respond by only giving the generated SQL query with no extra explanations or details in this format <SQL>SQLQUERY</> and don't try to run it or expect results!

    """
)

In [12]:
# Load Model
llm = OllamaLLM(model="mistral")

# Generate query chain: enrich the incoming {"question": ...} mapping with the
# dialect and the schema, render the system prompt, then call the model.
with_db_context = RunnablePassthrough.assign(
    db_type=lambda _: db_type,
    tables_info=lambda _: tables_info,
)
generate_query = with_db_context | system_prompt | llm

# Invoke query
query_result = generate_query.invoke({"question": "Quel est le prix du 'Trek Conduit+ - 2016' ?"})
print(query_result)

 <SQL>SELECT products.msrp FROM products WHERE products.productname = 'Trek Conduit+ - 2016';</SQL>


In [13]:
# Function to extract query
# Opening tag, lazily captured body, then either closing form the model emits
# ("</>" or "</SQL>") or the end of the text when the closing tag is missing.
_SQL_BLOCK_PATTERN = re.compile(r"<SQL>(.*?)(?:</(?:SQL)?>|\Z)", re.IGNORECASE | re.DOTALL)
# Only the wrapper tags themselves; never a bare "<" or ">" comparison operator.
_SQL_TAG_PATTERN = re.compile(r"</?SQL>|</>", re.IGNORECASE)


def extract_sql_query(text: str) -> str:
    """
    Pull the SQL statement out of a model answer wrapped as <SQL>...</> (or <SQL>...</SQL>).

    Only the wrapper tags are removed, so comparison operators inside the statement
    ("<", ">", "<>") are preserved. Surrounding whitespace and any text outside the
    tags are discarded.

    :param text: Raw model output
    :return: The SQL statement, or the stripped input text when no <SQL> tag is present
    """
    match = _SQL_BLOCK_PATTERN.search(text)
    if match:
        return match.group(1).strip()
    # No opening tag: remove any stray closing tag and return the text itself.
    return _SQL_TAG_PATTERN.sub("", text).strip()

In [14]:
# Strip the <SQL> wrapper from the model answer and run the statement as-is.
# NOTE(review): nothing validates the generated SQL before it is executed here.
query = extract_sql_query(query_result)
# Rebuilds the same tool that the metadata cell already created.
execute_query = QuerySQLDatabaseTool(db=db)
execute_query.invoke(query)

''

In [15]:
# End-to-end pipeline: question -> generated SQL -> tag stripping -> execution.
# NOTE(review): the generated statement is executed without any validation step.
chain = generate_query | StrOutputParser() | extract_sql_query| execute_query
chain.invoke({"question": "what is the name of the most expensive product ?"})

"[('1952 Alpine Renault 1300',)]"

In [16]:
# Statements that modify data, schema or privileges. A keyword blocklist is only a
# first line of defence: the database role used by the notebook should itself be
# read-only.
_RESTRICTED_SQL_PATTERN = re.compile(
    r"\b(UPDATE|DELETE|INSERT|DROP|TRUNCATE|ALTER|CREATE|GRANT|REVOKE|MERGE)\b",
    re.IGNORECASE,
)
_PERMISSION_DENIED_MESSAGE = "You don't have permission to execute such an action on the selected database."


def verify_sql_query(sql_query: str) -> str:
    """
    Verify if the SQL query contains restricted commands (UPDATE, DELETE, INSERT, DROP,
    TRUNCATE, ALTER, CREATE, GRANT, REVOKE, MERGE).

    Keywords are matched as whole words, case-insensitively, anywhere in the text, so a
    statement that merely quotes one of them in a string literal is also refused.

    :param sql_query: The SQL query to verify
    :return: A permission error message if a restricted command is found, otherwise the
             query unchanged
    """
    if _RESTRICTED_SQL_PATTERN.search(sql_query):
        return _PERMISSION_DENIED_MESSAGE

    return sql_query

In [17]:
# Prompt that turns the raw SQL result into a one-line answer in the user's language.
answer_prompt = PromptTemplate.from_template(
    """
    Given the following user question, corresponding SQL query, and SQL result, answer the user question.

    This is the user Question: {question}
    and this is the SQL Result: {result}
    Answer:
    Provide a concise response directly answering the user's question in the language of the question, based solely on the SQL result.
    Respond in one line with no extra explanations or details, using the language of the user's question. Also PLEASE REMEMBER TO RESPOND EITHER IN FRENCH OR ENGLISH MAKE SURE TO CHECK THE LANGUAGE THE QUESTION USER USED.

    """
)
rephrase_answer = answer_prompt | llm | StrOutputParser()

# Full pipeline: generate SQL -> strip tags -> validate -> execute (only if allowed) -> rephrase.
# verify_sql_query returns its input unchanged when the statement is allowed, so the
# query is executed only when the validated text equals the extracted SQL. Searching
# the text for the word "permission" would wrongly block any legitimate query that
# happens to contain it (for example a table named "permissions").
chain_v1 = (
    RunnablePassthrough.assign(query=generate_query)
    .assign(sql=lambda x: extract_sql_query(x["query"]))
    .assign(validated_query=lambda x: verify_sql_query(x["sql"]))
    .assign(result=lambda x: execute_query.invoke(x["sql"]) if x["validated_query"] == x["sql"] else x["validated_query"])
    | rephrase_answer
)
response = chain_v1.invoke({"question": "Update the price of product Trek Conduit+ - 2016 with price 10000 euros"})
print(response)

 I'm sorry, but you don't have the necessary permissions to update the price of Trek Conduit+ - 2016 in this database.


In [18]:
# Same write request phrased in French, to check the refusal is answered in the question's language.
response = chain_v1.invoke({"question": "Mettre à jour le prix du produit 'Trek Conduit+ - 2016' avec le prix 10000 euros"})
print(response)

 Nous n'avons pas les droits nécessaires pour mettre à jour le prix du produit 'Trek Conduit+ - 2016' à 10000 euros sur la base de données sélectionnée.


# Add few-shot examples

In [19]:
# Few-shot examples written against a bike-store schema (products, orders,
# order_items, staffs, stores, stocks, categories).
# NOTE(review): this list is superseded by the next cell, which rebinds `examples`.
# NOTE(review): most queries use T-SQL syntax (TOP, "+" string concatenation) while
# the monthly-orders example uses DATE_TRUNC — confirm the target dialect before reuse.
examples = [
    {
        "question": "Listez tous les produits dont le prix est supérieur à 100.",
        # Column name aligned with the other examples, which all use list_price.
        "query": "<SQL>SELECT product_name, list_price FROM products WHERE list_price > 100;</>"
    },
    {
        "question": "Affichez les commandes passées après le 1er janvier 2023.",
        "query": "<SQL>SELECT order_id, order_date FROM orders WHERE order_date > '2023-01-01';</>"
    },
    {
        "question": "Obtenez la liste des employés travaillant dans le département 'Baldwin Bikes'.",
        "query": "<SQL>SELECT s.first_name + ' ' + s.last_name AS staff_name FROM staffs s JOIN stores st ON s.store_id = st.store_id WHERE st.store_name = 'Baldwin Bikes';</>"
    },
    {
        "question": "Affichez le chiffre d'affaires total par client.",
        "query": "<SQL>SELECT o.customer_id, SUM(oi.quantity * oi.list_price * (1 - oi.discount)) AS total_revenue FROM order_items oi JOIN orders o ON oi.order_id = o.order_id GROUP BY o.customer_id ORDER BY total_revenue DESC;</>"
    },
    {
        "question": "Trouvez les clients qui n'ont jamais passé de commande.",
        "query": "<SQL>SELECT * FROM customers WHERE customer_id NOT IN (SELECT DISTINCT customer_id FROM orders);</>"
    },
    {
        "question": "Affichez les 5 produits les plus vendus.",
        "query": "<SQL>SELECT TOP 5 product_id, SUM(quantity) AS total_quantity FROM order_items GROUP BY product_id ORDER BY total_quantity DESC;</>"
    },
    {
        "question": "Obtenez la liste des departements basés dans la ville Rowlett.",
        "query": "<SQL>SELECT store_name FROM stores WHERE city = 'Rowlett';</>"
    },
    {
        "question": "Affichez le nombre total de commandes passées par mois.",
        "query": "<SQL>SELECT DATE_TRUNC('month', order_date) AS mois, COUNT(*) AS nombre_commandes FROM orders GROUP BY mois ORDER BY mois;</>"
    },
    {
        "question": "Trouvez le produit le plus cher.",
        "query": "<SQL>SELECT TOP 1 product_name, list_price  FROM products ORDER BY list_price DESC;</>"
    },
    {
        "question": "Affichez la liste des clients avec leur nombre de commandes.",
        "query": "<SQL>SELECT customers.customer_id, customers.name, COUNT(orders.order_id) AS nombre_commandes FROM customers LEFT JOIN orders ON customers.customer_id = orders.customer_id GROUP BY customers.customer_id, customers.name;</>"
    },
    {
        "question": "Affichez la liste des commandes avec les détails des produits achetés.",
        "query": "<SQL>SELECT orders.order_id, products.product_name, order_items.quantity, order_items.list_price FROM orders JOIN order_items ON orders.order_id = order_items.order_id JOIN products ON order_items.product_id = products.product_id;</>"
    },
    {
        "question": "Affichez la liste des clients ayant dépensé plus de 5000 en total.",
        "query": "<SQL>SELECT o.customer_id, SUM(oi.quantity * oi.list_price * (1 - oi.discount)) AS total_spent FROM order_items oi JOIN orders o ON oi.order_id = o.order_id GROUP BY o.customer_id HAVING SUM(oi.quantity * oi.list_price * (1 - oi.discount)) > 5000 ORDER BY total_spent DESC;</>"
    },
    {
        "question": "Trouvez les 3 clients ayant passé le plus grand nombre de commandes.",
        "query": "<SQL>SELECT TOP 3 customer_id, COUNT(order_id) AS total_commandes FROM orders GROUP BY customer_id ORDER BY total_commandes DESC;</>"
    },
    {
        "question": "Listez les produits dont la quantité en stock est inférieure à 10.",
        "query": "<SQL>SELECT product_id, quantity FROM stocks WHERE quantity < 10;</>"
    },
    {
        "question": "Affichez la moyenne des prix des produits par catégorie.",
        "query": "<SQL>SELECT  c.category_name, AVG(p.list_price) AS prix_moyen FROM products p JOIN categories c ON p.category_id = c.category_id GROUP BY c.category_id, c.category_name;</>"
    },
    {
        "question": "Trouvez le client ayant passé la commande la plus chère.",
        "query": "<SQL>SELECT TOP 1 o.customer_id, SUM(oi.quantity * oi.list_price * (1 - oi.discount)) AS total_order_price FROM order_items oi JOIN orders o ON oi.order_id = o.order_id GROUP BY o.customer_id ORDER BY total_order_price DESC;</>"
    },
    {
        "question": "Listez les commandes livrées en retard.",
        "query": "<SQL>SELECT order_date, shipped_date, required_date FROM orders WHERE shipped_date > required_date;</>"
    },
    {
        "question": "Affichez les employés avec leur nombre total de ventes.",
        "query": "<SQL>SELECT s.staff_id, s.first_name + ' ' + s.last_name AS staff_name, COUNT(o.order_id) AS total_ventes FROM staffs s JOIN orders o ON s.staff_id = o.staff_id GROUP BY s.staff_id, s.first_name, s.last_name;</>"
    }
]



In [20]:
# Few-shot examples for the PostgreSQL sample schema (customers, employees, offices,
# orders, payments, products, orderdetails). Every query uses only table and column
# names from that schema so the examples cannot steer the model toward names that do
# not exist in the connected database.
examples = [
    {
        "question": "Listez tous les clients avec une limite de crédit supérieure à 50000.",
        "query": "<SQL>SELECT customername, creditlimit FROM customers WHERE creditlimit > 50000;</>"
    },
    {
        "question": "Affichez les employés qui rapportent à l'employé numéro 1002",
        "query": "<SQL>SELECT employeenumber, lastname, firstname FROM employees WHERE reportsto = 1002;</>"
    },
    {
        "question": "Listez tous les bureaux situés aux États-Unis.",
        "query": "<SQL>SELECT officecode, city, state FROM offices WHERE country = 'USA';</>"
    },
    {
        "question": "Affichez toutes les commandes dont la date de commande est après le 1er janvier 2023.",
        "query": "<SQL>SELECT ordernumber, orderdate FROM orders WHERE orderdate > '2023-01-01';</>"
    },
    {
        "question": "Affichez les paiements effectués pour un client spécifique (ex: client numéro 103).",
        "query": "<SQL>SELECT checknumber, paymentdate, amount FROM payments WHERE customernumber = 103;</>"
    },
    {
        "question": "Listez tous les produits appartenant à la ligne de produits 'Classic Cars'.",
        "query": "<SQL>SELECT productcode, productname FROM products WHERE productline = 'Classic Cars';</>"
    },
    {
        "question": "Affichez les détails de la commande numéro 10100.",
        "query": "<SQL>SELECT productcode, quantityordered, priceeach FROM orderdetails WHERE ordernumber = 10100;</>"
    },
    {
        "question": "Trouvez le nombre total de commandes passées par chaque client.",
        "query": "<SQL>SELECT customernumber, COUNT(ordernumber) AS total_orders FROM orders GROUP BY customernumber;</>"
    },
    {
        "question": "Affichez la liste des employés et leurs supérieurs hiérarchiques.",
        "query": "<SQL>SELECT e1.employeenumber, e1.firstname, e1.lastname, e2.firstname AS manager_firstname, e2.lastname AS manager_lastname FROM employees e1 LEFT JOIN employees e2 ON e1.reportsto = e2.employeenumber;</>"
    },
    {
        "question": "Listez les produits avec un stock inférieur à 10 unités.",
        "query": "<SQL>SELECT productcode, productname, quantityinstock FROM products WHERE quantityinstock < 10;</>"
    },
    {
        # Uses orderdetails/products and LIMIT; the earlier wording relied on another
        # schema (order_items, product_id) and on T-SQL's TOP, neither valid here.
        "question": "Affichez le produit le plus vendu.",
        "query": "<SQL>SELECT p.productcode, p.productname, SUM(od.quantityordered) AS total_quantity FROM orderdetails od JOIN products p ON od.productcode = p.productcode GROUP BY p.productcode, p.productname ORDER BY total_quantity DESC LIMIT 1;</>"
    }
]


In [21]:
# Render every example as a single message: the question followed by its tagged SQL query.
example_prompt = ChatPromptTemplate.from_template("{question}\n {query}")

# Static few-shot block: all examples are included, whatever the question is.
few_shot_prompt = FewShotChatMessagePromptTemplate(
    example_prompt=example_prompt,
    examples=examples,
    input_variables=["question"],
)
# Preview of the rendered examples (the example template only uses question and query).
print(few_shot_prompt.format(question="Combien y a-t-il de commandes ?"))

Human: Listez tous les clients avec une limite de crédit supérieure à 50000.
 <SQL>SELECT customername, creditlimit FROM customers WHERE creditlimit > 50000;</>
Human: Affichez les employés qui rapportent à l'employé numéro 1002
 <SQL>SELECT employeenumber, lastname, firstname FROM employees WHERE reportsto = 1002;</>
Human: Listez tous les bureaux situés aux États-Unis.
 <SQL>SELECT officecode, city, state FROM offices WHERE country = 'USA';</>
Human: Affichez toutes les commandes dont la date de commande est après le 1er janvier 2023.
 <SQL>SELECT ordernumber, orderdate FROM orders WHERE orderdate > '2023-01-01';</>
Human: Affichez les paiements effectués pour un client spécifique (ex: client numéro 103).
 <SQL>SELECT checknumber, paymentdate, amount FROM payments WHERE customernumber = 103;</>
Human: Listez tous les produits appartenant à la ligne de produits 'Classic Cars'.
 <SQL>SELECT productcode, productname FROM products WHERE productline = 'Classic Cars';</>
Human: Affichez le

# Dynamic few-shot example selection

In [22]:
# Index the examples in a vector store so the ones closest to a question can be retrieved.
# NOTE(review): "mistral" is the same model used for generation; a dedicated embedding model
# may rank examples better — the selection shown below for a products question returns
# employee and payment examples.
embedding_model = OllamaEmbeddings(model="mistral")
vectorstore = Chroma()
# Presumably clears the default collection so re-running this cell does not index the
# examples twice — TODO confirm.
vectorstore.delete_collection()
# NOTE(review): an instance is passed as the third argument; the LangChain API reference
# describes that parameter as a vector store class — confirm this is intended.
example_selector = SemanticSimilarityExampleSelector.from_examples(
    examples,
    embedding_model,
    vectorstore,
    k=2,  # number of examples returned per question
    input_keys=["question"]  # match on the question text only

)
example_selector.select_examples({"question": "Afficher les top 5 produits"})

[{'query': '<SQL>SELECT employeenumber, lastname, firstname FROM employees WHERE reportsto = 1002;</>',
  'question': "Affichez les employés qui rapportent à l'employé numéro 1002"},
 {'query': '<SQL>SELECT checknumber, paymentdate, amount FROM payments WHERE customernumber = 103;</>',
  'question': 'Affichez les paiements effectués pour un client spécifique (ex: client numéro 103).'}]

In [23]:
# Few-shot block whose examples come from the similarity selector instead of a fixed list.
few_shot_prompt = FewShotChatMessagePromptTemplate(
    example_prompt=example_prompt,
    example_selector=example_selector,
    input_variables=["question"]
)
# NOTE(review): top_k and table_info are not referenced by example_prompt; they look like
# leftovers — confirm they can be dropped.
print(few_shot_prompt.format(question="Afficher les top 5 produits", top_k=2, table_info=""))

Human: Listez tous les produits appartenant à la ligne de produits 'Classic Cars'.
 <SQL>SELECT productcode, productname FROM products WHERE productline = 'Classic Cars';</>
Human: Listez tous les bureaux situés aux États-Unis.
 <SQL>SELECT officecode, city, state FROM offices WHERE country = 'USA';</>


In [24]:
# Complete query-generation prompt: system instructions, few-shot examples, the
# conversation so far, then the user's question.
final_prompt = ChatPromptTemplate.from_messages([
    ("system", "You are a {db_type} expert. Given an input question, first create a syntactically correct {db_type} query.\n\n"
               "Pay attention to use only the column names you can see in the tables below. Also, pay attention to which column is in which table.\n\n"
               "Use the CURRENT_DATE function in the returned SQL query if the question involves 'today.'\n\n"
               "Note: The provided data is only a sample and does not represent the full database. Assume the database contains a complete dataset.\n"
               "Generate the SQL query without trying to run it or making assumptions about missing data.\n\n"
               "Only use the following tables:\n"
               "{tables_info}\n\n"
               "Important: Respond by only giving the generated SQL query with no extra explanations or details in this format <SQL>SQLQUERY</> and don't try to run it or expect results!\n\n"
               "Below are a number of examples of questions and their corresponding SQL queries:"
    ),
    few_shot_prompt,
    # Slot for earlier chat turns passed as "messages". Without it the history given
    # to the chains is silently ignored and follow-up questions lose their context.
    # It is optional, so callers that pass no history keep working unchanged.
    MessagesPlaceholder(variable_name="messages", optional=True),
    ("human", "{question}")
])

# Example usage
formatted_prompt = final_prompt.format(
    question="Quel est le prix du 'Trek Conduit+ - 2016' ?",
    tables_info=tables_info,
    db_type=db_type
)

print(formatted_prompt)


System: You are a postgresql expert. Given an input question, first create a syntactically correct postgresql query.

Pay attention to use only the column names you can see in the tables below. Also, pay attention to which column is in which table.

Use the CURRENT_DATE function in the returned SQL query if the question involves 'today.'

Note: The provided data is only a sample and does not represent the full database. Assume the database contains a complete dataset.
Generate the SQL query without trying to run it or making assumptions about missing data.

Only use the following tables:
{'customers': [{'name': 'customernumber', 'type': 'integer'}, {'name': 'customername', 'type': 'character varying'}, {'name': 'contactlastname', 'type': 'character varying'}, {'name': 'contactfirstname', 'type': 'character varying'}, {'name': 'phone', 'type': 'character varying'}, {'name': 'addressline1', 'type': 'character varying'}, {'name': 'addressline2', 'type': 'character varying'}, {'name': 'cit

In [22]:
# Query generator built on the few-shot chat prompt: add the dialect and the full
# schema to the incoming mapping, render final_prompt, then call the model.
with_db_context = RunnablePassthrough.assign(
    db_type=lambda _: db_type,
    tables_info=lambda _: tables_info,
)
generate_query = with_db_context | final_prompt | llm

query_result = generate_query.invoke({"question": "Quel est le prix du 'Trek Conduit+ - 2016' ?"})
print(query_result)

 <SQL>SELECT list_price FROM products WHERE product_name = 'Trek Conduit+ - 2016';</>


In [23]:
# Full pipeline on top of the few-shot query generator:
# generate SQL -> strip tags -> validate -> execute (only if allowed) -> rephrase.
# verify_sql_query returns its input unchanged when the statement is allowed, so the
# query is executed only when the validated text equals the extracted SQL. Searching
# the text for the word "permission" would wrongly block any legitimate query that
# happens to contain it (for example a table named "permissions").
chain_v2 = (
    RunnablePassthrough.assign(query=generate_query)
    .assign(sql=lambda x: extract_sql_query(x["query"]))
    .assign(validated_query=lambda x: verify_sql_query(x["sql"]))
    .assign(result=lambda x: execute_query.invoke(x["sql"]) if x["validated_query"] == x["sql"] else x["validated_query"])
    | rephrase_answer
)
response = chain_v2.invoke({"question": "Quel est le prix du 'Trek Conduit+ - 2016' ?"})
print(response)

 Le prix du 'Trek Conduit+ - 2016' est de 25000,00 €.


In [24]:
# A destructive request: the keyword check inside chain_v2 is expected to refuse it rather than run SQL.
response = chain_v2.invoke({"question": "Supprimez le produit 'Trek Conduit+ - 2016' ?"})
print(response)

 Je suis désolé, mais je ne peux pas supprimer le produit 'Trek Conduit+ - 2016' pour le moment. (I'm sorry, but I can't delete the product 'Trek Conduit+ - 2016' for now.)


# Dynamic relevant table selection

In [25]:
# Prompt asking the model which tables matter for a question; the expected answer
# is a comma-separated list of table names.
select_table_prompt = PromptTemplate.from_template(
    """
    You are SQL expert. Given an input question, select the relevant SQL tables that may be relevant for the user question.

    Question: {question}
    Table names: {table_names}

    Respond by returning the selected tables names in this format:
    TABLE_NAME1,TABLE_NAME2..

    """
)

# Select tables chain: add the known table names to the input, render the prompt, call the model
add_table_names = RunnablePassthrough.assign(table_names=lambda _: table_names)
select_tables = add_table_names | select_table_prompt | llm

# Invoke query
tables_result = select_tables.invoke({"question": "Trouvez le client ayant passé la commande la plus chère."})
print(tables_result)

 customers, orders, order_items, products


In [26]:
def filter_table_info(selected_tables: str, available_tables: "dict | None" = None) -> dict:
    """
    Filter the tables_info dictionary to only include selected table names.

    The selection usually comes straight from a language model, so names are matched
    case-insensitively and surrounding whitespace, quotes, backticks or brackets are
    ignored. Names that are not in the catalog are dropped. If nothing matches, the
    whole catalog is returned, because an empty schema would leave the query prompt
    with no tables to work from.

    :param selected_tables: A string containing table names separated by commas (or
                            new lines), e.g. "orders,customers"
    :param available_tables: Optional catalog to filter (table name -> column list);
                             defaults to the notebook-level tables_info
    :return: A dictionary with filtered table info, in selection order
    """
    catalog = tables_info if available_tables is None else available_tables
    # Map lower-cased names to the catalog's own spelling for case-insensitive lookups.
    canonical_names = {name.lower(): name for name in catalog}

    filtered_info = {}
    for token in re.split(r"[,\n]", selected_tables or ""):
        # Remove decoration a model may add around a name: `orders`, "orders", [orders] ...
        key = token.strip().strip("`'\"[]().").strip().lower()
        name = canonical_names.get(key)
        if name is not None:
            filtered_info[name] = catalog[name]

    return filtered_info if filtered_info else dict(catalog)

In [27]:
# Query generator with dynamic table selection: first ask the model which tables
# are relevant, then keep only those tables in the schema given to final_prompt.
pick_tables = RunnablePassthrough.assign(table_names=select_tables)
narrow_schema = RunnablePassthrough.assign(
    db_type=lambda _: db_type,
    tables_info=lambda inputs: filter_table_info(inputs["table_names"]),
)
generate_query = pick_tables | narrow_schema | final_prompt | llm

query_result = generate_query.invoke({"question": "Quel est le prix du 'Trek Conduit+ - 2016' ?"})
print(query_result)

 <SQL>SELECT p.price FROM products p WHERE p.name = 'Trek Conduit+ - 2016';</>


In [36]:
# Full pipeline on top of the table-selecting query generator:
# generate SQL -> strip tags -> validate -> execute (only if allowed) -> rephrase.
# verify_sql_query returns its input unchanged when the statement is allowed, so the
# query is executed only when the validated text equals the extracted SQL. Searching
# the text for the word "permission" would wrongly block any legitimate query that
# happens to contain it (for example a table named "permissions").
chain_v3 = (
    RunnablePassthrough.assign(query=generate_query)
    .assign(sql=lambda x: extract_sql_query(x["query"]))
    .assign(validated_query=lambda x: verify_sql_query(x["sql"]))
    .assign(result=lambda x: execute_query.invoke(x["sql"]) if x["validated_query"] == x["sql"] else x["validated_query"])
    | rephrase_answer
)
response = chain_v3.invoke({"question": "Quel est le prix du 'Trek Conduit+ - 2016' ?"})
print(response)

 Le prix du 'Trek Conduit+ - 2016' est de 25000,00.


# Add memory to the chatbot

In [29]:
# Start an empty conversation history and ask the first question through chain_v3.
# The history is passed under the "messages" key of the chain input.
history = ChatMessageHistory()
question = "Combien y a-t-il des clients avec nombre des commandes supérieur à 5 ?"
response = chain_v3.invoke({"question": question , "messages": history.messages})
response

' Il y a 1441 clients avec plus de 5 commandes.'

In [30]:
# Record the previous exchange, then ask a follow-up that only makes sense with that context.
history.add_user_message(question)
history.add_ai_message(response)
# NOTE(review): the history only influences the answer if a prompt inside chain_v3 declares
# a `messages` slot — confirm before relying on the follow-up result.
response = chain_v3.invoke({"question": "Pouvez vous lister leurs noms?" , "messages": history.messages})
response

' Les noms de personnes sont : Agatha, Alexandria, Angelika, Annett, Carman, Carlena, César, Cesar, Dorothea, Dewayne, Edris, Elana, Eliseo, Elmo, Ernie, Genoveva, Inger, Ivan, Jasper, Justina, Latashia, Lurlene, Mercedez, Miranda, Ophelia, Pandora, Renna, Rona, Sheryl, Sonja, Thad, Ulysse, Wynona, Yvonne.'

In [40]:
# Repeat the first-turn experiment with chain_v2 (full schema, no table selection),
# starting from a fresh, empty history.
history = ChatMessageHistory()
question = "Combien y a-t-il des clients avec nombre des commandes supérieur à 5 ?"
response = chain_v2.invoke({"question": question , "messages": history.messages})
response

' Il y a 25 clients avec nombre des commandes supérieur à 5.'

In [41]:
# Record the previous exchange, then ask the same follow-up through chain_v2.
history.add_user_message(question)
history.add_ai_message(response)
# NOTE(review): the history only influences the answer if a prompt inside chain_v2 declares
# a `messages` slot — confirm before relying on the follow-up result.
response = chain_v2.invoke({"question": "Pouvez vous lister leurs noms?" , "messages": history.messages})
response

' Voici les noms des personnes : Agatha, Alexandria, Annett, Angelika, Carlena, Carman, Dorothea, Destiny, Eleanor, Elana, Eliseo, Genoveva, Harper, Inger, Ivette, Jasper, Justina, Lamar, Lurlene, Mark, Miranda, Ophelia, Sonja, Tempie, Ulrike.'