In [1]:
import os
from dotenv import load_dotenv
load_dotenv()
os.environ['OPENAI_API_KEY'] = os.getenv("OPENAI_API_KEY")
os.environ["LANGCHAIN_TRACING_V2"]="true"
os.environ["LANGCHAIN_API_KEY"] = os.getenv("LANGCHAIN_API_KEY")


In [3]:
%pip install langchain_community

Collecting langchain_community
  Downloading langchain_community-0.2.12-py3-none-any.whl (2.3 MB)
     ---------------------------------------- 0.0/2.3 MB ? eta -:--:--
      --------------------------------------- 0.0/2.3 MB ? eta -:--:--
      --------------------------------------- 0.0/2.3 MB ? eta -:--:--
     - -------------------------------------- 0.1/2.3 MB 651.6 kB/s eta 0:00:04
     - -------------------------------------- 0.1/2.3 MB 595.3 kB/s eta 0:00:04
     --- ------------------------------------ 0.2/2.3 MB 1.0 MB/s eta 0:00:03
     --- ------------------------------------ 0.2/2.3 MB 1.0 MB/s eta 0:00:03
     ---- ----------------------------------- 0.2/2.3 MB 793.0 kB/s eta 0:00:03
     ---- ----------------------------------- 0.3/2.3 MB 711.1 kB/s eta 0:00:03
     ------ --------------------------------- 0.4/2.3 MB 908.0 kB/s eta 0:00:03
     -------- ------------------------------- 0.5/2.3 MB 1.0 MB/s eta 0:00:02
     -------- ------------------------------- 0.5/2.3 M


[notice] A new release of pip is available: 23.0.1 -> 24.2
[notice] To update, run: python.exe -m pip install --upgrade pip


In [5]:
%pip install pymysql

Collecting pymysql
  Downloading PyMySQL-1.1.1-py3-none-any.whl (44 kB)
     ---------------------------------------- 0.0/45.0 kB ? eta -:--:--
     ------------------------------------ --- 41.0/45.0 kB 2.0 MB/s eta 0:00:01
     ------------------------------------ --- 41.0/45.0 kB 2.0 MB/s eta 0:00:01
     ------------------------------------ --- 41.0/45.0 kB 2.0 MB/s eta 0:00:01
     ------------------------------------ --- 41.0/45.0 kB 2.0 MB/s eta 0:00:01
     ------------------------------------ --- 41.0/45.0 kB 2.0 MB/s eta 0:00:01
     -------------------------------------- 45.0/45.0 kB 138.9 kB/s eta 0:00:00
Installing collected packages: pymysql
Successfully installed pymysql-1.1.1
Note: you may need to restart the kernel to use updated packages.



[notice] A new release of pip is available: 23.0.1 -> 24.2
[notice] To update, run: python.exe -m pip install --upgrade pip


In [6]:
from langchain_community.utilities.sql_database import SQLDatabase
db_user = "root"
db_password = "root"
db_host = "localhost"  
db_port = "3306"  
db_name = "classicmodels"
db = SQLDatabase.from_uri(f"mysql+pymysql://{db_user}:{db_password}@{db_host}/{db_name}")

In [7]:
print(db.dialect)

mysql


In [8]:
print(db.get_usable_table_names())

['customers', 'employees', 'offices', 'orderdetails', 'orders', 'payments', 'productlines', 'products']


In [9]:
print(db.table_info)


CREATE TABLE customers (
	`customerNumber` INTEGER NOT NULL, 
	`customerName` VARCHAR(50) NOT NULL, 
	`contactLastName` VARCHAR(50) NOT NULL, 
	`contactFirstName` VARCHAR(50) NOT NULL, 
	phone VARCHAR(50) NOT NULL, 
	`addressLine1` VARCHAR(50) NOT NULL, 
	`addressLine2` VARCHAR(50), 
	city VARCHAR(50) NOT NULL, 
	state VARCHAR(50), 
	`postalCode` VARCHAR(15), 
	country VARCHAR(50) NOT NULL, 
	`salesRepEmployeeNumber` INTEGER, 
	`creditLimit` DECIMAL(10, 2), 
	PRIMARY KEY (`customerNumber`), 
	CONSTRAINT customers_ibfk_1 FOREIGN KEY(`salesRepEmployeeNumber`) REFERENCES employees (`employeeNumber`)
)ENGINE=InnoDB DEFAULT CHARSET=utf8mb4 COLLATE utf8mb4_0900_ai_ci

/*
3 rows from customers table:
customerNumber	customerName	contactLastName	contactFirstName	phone	addressLine1	addressLine2	city	state	postalCode	country	salesRepEmployeeNumber	creditLimit
103	Atelier graphique	Schmitt	Carine 	40.32.2555	54, rue Royale	None	Nantes	None	44000	France	1370	21000.00
112	Signal Gift Stores	King	Je

In [10]:
from langchain.chains import create_sql_query_chain
from langchain_openai import ChatOpenAI
llm = ChatOpenAI(model="gpt-3.5-turbo-0125", temperature=0)
generate_query = create_sql_query_chain(llm , db)
query = generate_query.invoke({"question": "what is price of `1968 Ford Mustang`"})
query

"SELECT `buyPrice`, `MSRP`\nFROM products\nWHERE `productName` = '1968 Ford Mustang'\nLIMIT 1;"

In [11]:
from langchain_community.tools.sql_database.tool import QuerySQLDataBaseTool
execute_query_tool = QuerySQLDataBaseTool(db = db)
execute_query_tool.invoke(query)

"[(Decimal('95.34'), Decimal('194.57'))]"

In [13]:
 from operator import itemgetter

 from langchain_core.output_parsers import StrOutputParser
 from langchain_core.prompts import PromptTemplate
 from langchain_core.runnables import RunnablePassthrough

 answer_prompt = PromptTemplate.from_template(
     """Given the following user question, corresponding SQL query, and SQL result, answer the user question.

 Question: {question}
 SQL Query: {query}
 SQL Result: {result}
 Answer: """
 )

 rephrase_answer = answer_prompt | llm | StrOutputParser()

 chain = (
     RunnablePassthrough.assign(query=generate_query).assign(
         result=itemgetter("query") | execute_query_tool
     )
     | rephrase_answer
 )

 chain.invoke({"question": "what is price of `1968 Ford Mustang"})
#On invoking the chain , question -> generate_query[(llm , db) chain] generates query and stores it in query variable -> itemgetter extracts the query item and passes it to execute_query_tool which executes the query , generates the result and stores it in result variable
#Now there is (question , query , result) inside the chain -> rephrase_answer
#Inside rephrase_answer chain , (question ,query ,answer) -> prompt template -> llm -> output parser -> output

'The price of the 1968 Ford Mustang is $95.34.'

In [14]:
chain.invoke({"question": "How many customers have an order count greater than 5"})

'There are 2 customers who have an order count greater than 5.'

In [15]:
chain.invoke({"question": "Find the top 5 products with the highest total sales revenue"})

'The top 5 products with the highest total sales revenue are:\n1. 1992 Ferrari 360 Spider red - $276,839.98\n2. 2001 Ferrari Enzo - $190,755.86\n3. 1952 Alpine Renault 1300 - $190,017.96\n4. 2003 Harley-Davidson Eagle Drag Bike - $170,686.00\n5. 1968 Ford Mustang - $161,531.48'

In [10]:
chain.invoke({"question": "List the top 3 countries with the highest number of orders"})

'The top 3 countries with the highest number of orders are:\n1. USA with 112 orders\n2. France with 37 orders\n3. Spain with 36 orders'

#Few Shot Prompting

In [43]:
examples = [
    {
  "input": "Find the total number of orders placed by customers in the USA.",
  "query": "SELECT COUNT(*) FROM orders o JOIN customers c ON o.customerNumber = c.customerNumber WHERE c.country = 'USA';"
},
{
  "input": "Retrieve the names and cities of all employees who work in the 'EMEA' territory.",
  "query": "SELECT firstName, lastName, city FROM employees e JOIN offices o ON e.officeCode = o.officeCode WHERE o.territory = 'EMEA';"
},
{
  "input": "List all products that have more than 1000 units in stock.",
  "query": "SELECT productName FROM products WHERE quantityInStock > 1000;"
},
{
  "input": "Get the details of the order with the highest total amount.",
  "query": "SELECT o.orderNumber, SUM(od.quantityOrdered * od.priceEach) AS totalAmount FROM orders o JOIN orderdetails od ON o.orderNumber = od.orderNumber GROUP BY o.orderNumber ORDER BY totalAmount DESC LIMIT 1;"
},
{
  "input": "Find the average credit limit of customers in Germany.",
  "query": "SELECT AVG(creditLimit) FROM customers WHERE country = 'Germany';"
},
{
  "input": "List all orders that were shipped after the required date.",
  "query": "SELECT * FROM orders WHERE shippedDate > requiredDate;"
},
{
  "input": "Show the product names and their respective vendors for all products in the 'Planes' product line.",
  "query": "SELECT productName, productVendor FROM products WHERE productLine = 'Planes';"
},
{
  "input": "Get the number of employees who report to the employee with employee number 1002.",
  "query": "SELECT COUNT(*) FROM employees WHERE reportsTo = 1002;"
},
{
  "input": "Retrieve the details of all payments made before January 1, 2023.",
  "query": "SELECT * FROM payments WHERE paymentDate < '2023-01-01';"
}

]

In [20]:
%pip install faiss-cpu


Collecting faiss-cpu
  Downloading faiss_cpu-1.8.0.post1-cp310-cp310-win_amd64.whl (14.6 MB)
     ---------------------------------------- 0.0/14.6 MB ? eta -:--:--
     ---------------------------------------- 0.0/14.6 MB ? eta -:--:--
     ---------------------------------------- 0.0/14.6 MB ? eta -:--:--
     --------------------------------------- 0.1/14.6 MB 653.6 kB/s eta 0:00:23
     --------------------------------------- 0.1/14.6 MB 653.6 kB/s eta 0:00:23
     --------------------------------------- 0.1/14.6 MB 653.6 kB/s eta 0:00:23
     --------------------------------------- 0.1/14.6 MB 653.6 kB/s eta 0:00:23
     --------------------------------------- 0.1/14.6 MB 653.6 kB/s eta 0:00:23
     --------------------------------------- 0.1/14.6 MB 653.6 kB/s eta 0:00:23
     --------------------------------------- 0.1/14.6 MB 209.5 kB/s eta 0:01:10
     --------------------------------------- 0.1/14.6 MB 209.5 kB/s eta 0:01:10
     --------------------------------------- 0.1/14


[notice] A new release of pip is available: 23.0.1 -> 24.2
[notice] To update, run: python.exe -m pip install --upgrade pip


In [44]:
from langchain_community.vectorstores import FAISS
from langchain_core.example_selectors import SemanticSimilarityExampleSelector
from langchain_openai import OpenAIEmbeddings

example_selector = SemanticSimilarityExampleSelector.from_examples(
    examples,
    OpenAIEmbeddings(),
    FAISS,
    k=5,
    input_keys=["input"],
)



In [45]:
from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder,FewShotChatMessagePromptTemplate,PromptTemplate
example_prompt = ChatPromptTemplate.from_messages(
     [
         ("human", "{input}\nSQLQuery:"),
         ("ai", "{query}"),
     ]
 )
few_shot_prompt = FewShotChatMessagePromptTemplate(
     example_prompt=example_prompt,
     example_selector=example_selector,
     input_variables=["input","top_k"],
 )
print(few_shot_prompt.format(input="How many customers with credit limit more than 50000?"))

Human: Find the average credit limit of customers in Germany.
SQLQuery:
AI: SELECT AVG(creditLimit) FROM customers WHERE country = 'Germany';
Human: Find the total number of orders placed by customers in the USA.
SQLQuery:
AI: SELECT COUNT(*) FROM orders o JOIN customers c ON o.customerNumber = c.customerNumber WHERE c.country = 'USA';
Human: List all products that have more than 1000 units in stock.
SQLQuery:
AI: SELECT productName FROM products WHERE quantityInStock > 1000;
Human: Retrieve the details of all payments made before January 1, 2023.
SQLQuery:
AI: SELECT * FROM payments WHERE paymentDate < '2023-01-01';
Human: Get the number of employees who report to the employee with employee number 1002.
SQLQuery:
AI: SELECT COUNT(*) FROM employees WHERE reportsTo = 1002;


In [28]:
def few_shot_prompt_generate_template(question):
        examples = [
        {
    "input": "Find the total number of orders placed by customers in the USA.",
    "query": "SELECT COUNT(*) FROM orders o JOIN customers c ON o.customerNumber = c.customerNumber WHERE c.country = 'USA';"
    },
    {
    "input": "Retrieve the names and cities of all employees who work in the 'EMEA' territory.",
    "query": "SELECT firstName, lastName, city FROM employees e JOIN offices o ON e.officeCode = o.officeCode WHERE o.territory = 'EMEA';"
    },
    {
    "input": "List all products that have more than 1000 units in stock.",
    "query": "SELECT productName FROM products WHERE quantityInStock > 1000;"
    },
    {
    "input": "Get the details of the order with the highest total amount.",
    "query": "SELECT o.orderNumber, SUM(od.quantityOrdered * od.priceEach) AS totalAmount FROM orders o JOIN orderdetails od ON o.orderNumber = od.orderNumber GROUP BY o.orderNumber ORDER BY totalAmount DESC LIMIT 1;"
    },
    {
    "input": "Find the average credit limit of customers in Germany.",
    "query": "SELECT AVG(creditLimit) FROM customers WHERE country = 'Germany';"
    },
    {
    "input": "List all orders that were shipped after the required date.",
    "query": "SELECT * FROM orders WHERE shippedDate > requiredDate;"
    },
    {
    "input": "Show the product names and their respective vendors for all products in the 'Planes' product line.",
    "query": "SELECT productName, productVendor FROM products WHERE productLine = 'Planes';"
    },
    {
    "input": "Get the number of employees who report to the employee with employee number 1002.",
    "query": "SELECT COUNT(*) FROM employees WHERE reportsTo = 1002;"
    },
    {
    "input": "Retrieve the details of all payments made before January 1, 2023.",
    "query": "SELECT * FROM payments WHERE paymentDate < '2023-01-01';"
    }

    ]

        from langchain_community.vectorstores import FAISS
        from langchain_core.example_selectors import SemanticSimilarityExampleSelector
        from langchain_openai import OpenAIEmbeddings

        example_selector = SemanticSimilarityExampleSelector.from_examples(
            examples,
            OpenAIEmbeddings(),
            FAISS,
            k=5,
            input_keys=["input"],
        )

        example_selector.select_examples({"input": question}) 

        from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder,FewShotChatMessagePromptTemplate,PromptTemplate
        example_prompt = ChatPromptTemplate.from_messages(
                [
                    ("human", "{input}\nSQLQuery:"),
                    ("ai", "{query}"),
                ]
            )
        few_shot_prompt = FewShotChatMessagePromptTemplate(
                example_prompt=example_prompt,
                example_selector=example_selector,
                input_variables=["input","top_k"],
            )
        few_shot_prompt.format(input=question)
        # print(few_shot_prompt.format(input=question))
        # print(type(few_shot_prompt))
    
        
    

In [13]:
few_shot_prompt_generate_template("How many customers with credit limit more than 50000")


In [25]:
context = db.get_context()
print(list(context))
print(context["table_info"])

['table_info', 'table_names']

CREATE TABLE customers (
	`customerNumber` INTEGER NOT NULL, 
	`customerName` VARCHAR(50) NOT NULL, 
	`contactLastName` VARCHAR(50) NOT NULL, 
	`contactFirstName` VARCHAR(50) NOT NULL, 
	phone VARCHAR(50) NOT NULL, 
	`addressLine1` VARCHAR(50) NOT NULL, 
	`addressLine2` VARCHAR(50), 
	city VARCHAR(50) NOT NULL, 
	state VARCHAR(50), 
	`postalCode` VARCHAR(15), 
	country VARCHAR(50) NOT NULL, 
	`salesRepEmployeeNumber` INTEGER, 
	`creditLimit` DECIMAL(10, 2), 
	PRIMARY KEY (`customerNumber`), 
	CONSTRAINT customers_ibfk_1 FOREIGN KEY(`salesRepEmployeeNumber`) REFERENCES employees (`employeeNumber`)
)ENGINE=InnoDB DEFAULT CHARSET=utf8mb4 COLLATE utf8mb4_0900_ai_ci

/*
3 rows from customers table:
customerNumber	customerName	contactLastName	contactFirstName	phone	addressLine1	addressLine2	city	state	postalCode	country	salesRepEmployeeNumber	creditLimit
103	Atelier graphique	Schmitt	Carine 	40.32.2555	54, rue Royale	None	Nantes	None	44000	France	1370	21000.00


In [27]:
final_prompt = ChatPromptTemplate.from_messages(
     [
         ("system", "You are a MySQL expert. Given an input question, create a syntactically correct MySQL query to run. Unless otherwise specificed.\n Here is the relevant table info: {table_info}\n\nBelow are a number of examples of questions and their corresponding SQL queries."),
         few_shot_prompt,
         ("human", "{input}"),
     ]
 )
print(final_prompt.format(input="How many products are there?",table_info=context["table_info"]))


System: You are a MySQL expert. Given an input question, create a syntactically correct MySQL query to run. Unless otherwise specificed.
 Here is the relevant table info: 
CREATE TABLE customers (
	`customerNumber` INTEGER NOT NULL, 
	`customerName` VARCHAR(50) NOT NULL, 
	`contactLastName` VARCHAR(50) NOT NULL, 
	`contactFirstName` VARCHAR(50) NOT NULL, 
	phone VARCHAR(50) NOT NULL, 
	`addressLine1` VARCHAR(50) NOT NULL, 
	`addressLine2` VARCHAR(50), 
	city VARCHAR(50) NOT NULL, 
	state VARCHAR(50), 
	`postalCode` VARCHAR(15), 
	country VARCHAR(50) NOT NULL, 
	`salesRepEmployeeNumber` INTEGER, 
	`creditLimit` DECIMAL(10, 2), 
	PRIMARY KEY (`customerNumber`), 
	CONSTRAINT customers_ibfk_1 FOREIGN KEY(`salesRepEmployeeNumber`) REFERENCES employees (`employeeNumber`)
)ENGINE=InnoDB DEFAULT CHARSET=utf8mb4 COLLATE utf8mb4_0900_ai_ci

/*
3 rows from customers table:
customerNumber	customerName	contactLastName	contactFirstName	phone	addressLine1	addressLine2	city	state	postalCode	country	sa

In [16]:
generate_query = create_sql_query_chain(llm, db,final_prompt)
chain = (
 RunnablePassthrough.assign(query=generate_query).assign(
     result=itemgetter("query") | execute_query_tool
 )
 | rephrase_answer
 )
chain.invoke({"question": "How many customers with credit limit more than 50000"})


'There are 85 customers with a credit limit greater than 50000.'

In [15]:
from langchain.chains.sql_database.prompt import SQL_PROMPTS

list(SQL_PROMPTS)

['crate',
 'duckdb',
 'googlesql',
 'mssql',
 'mysql',
 'mariadb',
 'oracle',
 'postgresql',
 'sqlite',
 'clickhouse',
 'prestodb']

#Exploring SQL Prompt Templates

In [23]:
from langchain.chains import create_sql_query_chain

chain = create_sql_query_chain(llm, db)
chain.get_prompts()[0].pretty_print()

You are a MySQL expert. Given an input question, first create a syntactically correct MySQL query to run, then look at the results of the query and return the answer to the input question.
Unless the user specifies in the question a specific number of examples to obtain, query for at most 5 results using the LIMIT clause as per MySQL. You can order the results to return the most informative data in the database.
Never query for all columns from a table. You must query only the columns that are needed to answer the question. Wrap each column name in backticks (`) to denote them as delimited identifiers.
Pay attention to use only the column names you can see in the tables below. Be careful to not query for columns that do not exist. Also, pay attention to which column is in which table.
Pay attention to use CURDATE() function to get the current date, if the question involves "today".

Use the following format:

Question: Question here
SQLQuery: SQL Query to run
SQLResult: Result of the S

In [24]:
prompt_with_context = chain.get_prompts()[0].partial(table_info=context["table_info"])
print(prompt_with_context.pretty_repr())


You are a MySQL expert. Given an input question, first create a syntactically correct MySQL query to run, then look at the results of the query and return the answer to the input question.
Unless the user specifies in the question a specific number of examples to obtain, query for at most 5 results using the LIMIT clause as per MySQL. You can order the results to return the most informative data in the database.
Never query for all columns from a table. You must query only the columns that are needed to answer the question. Wrap each column name in backticks (`) to denote them as delimited identifiers.
Pay attention to use only the column names you can see in the tables below. Be careful to not query for columns that do not exist. Also, pay attention to which column is in which table.
Pay attention to use CURDATE() function to get the current date, if the question involves "today".

Use the following format:

Question: Question here
SQLQuery: SQL Query to run
SQLResult: Result of the S

In [30]:
%pip install pandas

Collecting pandas
  Using cached pandas-2.2.2-cp310-cp310-win_amd64.whl (11.6 MB)
Collecting tzdata>=2022.7
  Using cached tzdata-2024.1-py2.py3-none-any.whl (345 kB)
Collecting pytz>=2020.1
  Using cached pytz-2024.1-py2.py3-none-any.whl (505 kB)
Installing collected packages: pytz, tzdata, pandas
Successfully installed pandas-2.2.2 pytz-2024.1 tzdata-2024.1
Note: you may need to restart the kernel to use updated packages.



[notice] A new release of pip is available: 23.0.1 -> 24.2
[notice] To update, run: python.exe -m pip install --upgrade pip


#Dynamic Relevant Table Selection

In [31]:
from operator import itemgetter
from langchain.chains.openai_tools import create_extraction_chain_pydantic
from langchain_core.pydantic_v1 import BaseModel, Field
from typing import List
import pandas as pd

def get_table_details():
    # Read the CSV file into a DataFrame
    table_description = pd.read_csv("classicmodels_table_descriptions.csv")
    table_docs = []

    # Iterate over the DataFrame rows to create Document objects
    table_details = ""
    for index, row in table_description.iterrows():
        table_details = table_details + "Table Name:" + row['Table'] + "\n" + "Table Description:" + row['Description'] + "\n\n"

    return table_details


class Table(BaseModel):
    """Table in SQL database."""

    name: str = Field(description="Name of table in SQL database.")

# table_names = "\n".join(db.get_usable_table_names())
table_details = get_table_details()
print(table_details)


Table Name:productlines
Table Description:Contains information about various product lines, including text descriptions, HTML formatted descriptions, and images. Each product line categorizes a group of products and is identified by a unique product line code.

Table Name:products
Table Description:Stores details of individual products offered by the company. Each product is associated with a product line, has a unique product code, and includes attributes such as name, scale, vendor, description, quantity in stock, buying price, and Manufacturer's Suggested Retail Price (MSRP).

Table Name:offices
Table Description:Contains information about the company's offices, including office codes, location details (city, state, country), contact numbers, and address details. Each office is uniquely identified by an office code.

Table Name:employees
Table Description:Stores data about employees, including their employee numbers, names, job titles, contact information, and office codes. It also 

In [90]:
#Remember to include ALL POTENTIALLY RELEVANT tables, even if you're not sure that they're needed.
#Using create_extraction_chain_pydantic
# Construct the prompt
table_details_prompt = f"""Given the following SQL tables:

{table_details}

Return the names of ALL the SQL tables that MIGHT be relevant to the user's question. 
Remember to include ALL POTENTIALLY RELEVANT tables, even if you're not sure that they're needed.
User's question: "{input}"

"""

# Create the table extraction chain
table_chain = create_extraction_chain_pydantic(Table, llm, system_message=table_details_prompt)

# Invoke the chain with the user input
tables = table_chain.invoke({"input": "Top 5 products with highest sales revenue"})
print(tables)


[Table(name='productlines'), Table(name='products'), Table(name='orders'), Table(name='orderdetails'), Table(name='customers'), Table(name='payments')]


In [32]:
#Using Pydantic Tool parser
from langchain_core.output_parsers.openai_tools import PydanticToolsParser
from langchain_core.prompts import ChatPromptTemplate
from langchain_core.pydantic_v1 import BaseModel, Field


class Table(BaseModel):
    """Table in SQL database."""

    name: str = Field(description="Name of table in SQL database.")


system = f"""Return the names of ALL the SQL tables that MIGHT be relevant to the user question. \
The tables details are:

{table_details}

Remember to include ALL POTENTIALLY RELEVANT tables, even if you're not sure that they're needed."""

prompt = ChatPromptTemplate.from_messages(
    [
        ("system", system),
        ("human", "{input}"),
    ]
)
llm_with_tools = llm.bind_tools([Table])
output_parser = PydanticToolsParser(tools=[Table])

table_chain = prompt | llm_with_tools | output_parser

table_chain.invoke({"input": "List the top 3 countries with the highest number of orders"})

[Table(name='orders'), Table(name='customers'), Table(name='employees')]

In [33]:

def get_tables(tables: List[Table]) -> List[str]:
    tables  = [table.name for table in tables]
    table_names = {}
    table_names["tables"] = tables
    return table_names

table_chain = {"input": itemgetter("question")} | prompt | llm_with_tools | output_parser| get_tables
select_table = table_chain.invoke({"question": "List the top 3 countries with the highest number of orders"})
select_table

{'tables': ['orders', 'customers', 'employees']}

In [96]:
def get_tables(tables: List[Table]) -> List[str]:
    tables  = [table.name for table in tables]
    table_names = {}
    table_names["tables"] = tables
    return table_names

select_table_chain = {"input": itemgetter("question")} | create_extraction_chain_pydantic(Table, llm, system_message=table_details_prompt) | get_tables
select_table = select_table_chain.invoke({"question": "List the top 3 countries with the highest number of orders"})
select_table

{'tables': ['productlines',
  'products',
  'offices',
  'employees',
  'customers',
  'payments',
  'orders',
  'orderdetails']}

#Enhancing Chatbots with Memory for Follow-up Database Queries

In [34]:
from langchain.memory import ChatMessageHistory
history = ChatMessageHistory()


In [35]:
final_prompt = ChatPromptTemplate.from_messages(
     [
         ("system", "You are a MySQL expert. Given an input question, create a syntactically correct MySQL query to run. Unless otherwise specificed.\nHere are the tables info : {table_info}\nHere are the tables that you can refer to : {selected_tables}\n\nBelow are a number of examples of questions and their corresponding SQL queries. Those examples are just for reference and should be considered while answering follow up questions"),
         few_shot_prompt,
         MessagesPlaceholder(variable_name="messages"),
         ("human", "{input}"),
     ]
      
 )


In [36]:
generate_query = create_sql_query_chain(llm, db,final_prompt)
chain =(RunnablePassthrough.assign(selected_tables=table_chain)|
       RunnablePassthrough.assign(query=generate_query).assign(
        result=itemgetter("query") | execute_query_tool)) | rephrase_answer
 


In [37]:
history.messages

[]

In [38]:
def get_response(question):
    few_shot_prompt_generate_template(question)
    response = chain.invoke({"question": question,"messages":history.messages})
    history.add_user_message(question)
    history.add_ai_message(response)
    return response




In [39]:
response = get_response("How many customers with creditlimit > 50000")
response


'There are 85 customers with a credit limit greater than 50000.'

In [40]:
response = get_response("What are their names")
response

"The names of the customers with a credit limit greater than $50,000 are as follows:\n- Signal Gift Stores\n- Australian Collectors, Co.\n- La Rochelle Gifts\n- Baane Mini Imports\n- Mini Gifts Distributors Ltd.\n- Blauer See Auto, Co.\n- Mini Wheels Co.\n- Land of Toys Inc.\n- Euro+ Shopping Channel\n- Volvo Model Replicas, Co\n- Danish Wholesale Imports\n- Saveley & Henriot, Co.\n- Dragon Souveniers, Ltd.\n- Muscle Machine Inc\n- Diecast Classics Inc.\n- Technics Stores Inc.\n- Handji Gifts& Co\n- Herkku Gifts\n- Daedalus Designs Imports\n- La Corne D'abondance, Co.\n- Gift Depot Inc.\n- Osaka Souveniers Co.\n- Vitachrome Inc.\n- Toys of Finland, Co.\n- AV Stores, Co.\n- Clover Collections, Co.\n- UK Collectables, Ltd.\n- Canadian Gift Exchange Network\n- Online Mini Collectables\n- Toys4GrownUps.com\n- Mini Caravy\n- King Kong Collectables, Co.\n- Enaco Distributors\n- Heintze Collectables\n- Collectable Mini Designs Co.\n- giftsbymail.co.uk\n- Alpha Cognac\n- Amica Models & Co.\n- 

In [124]:
response = get_response("What are their creditlimits")
response

"The credit limits of customers with credit limits greater than $50,000 are as follows:\n- Signal Gift Stores: $71,800.00\n- Australian Collectors, Co.: $117,300.00\n- La Rochelle Gifts: $118,200.00\n- Baane Mini Imports: $81,700.00\n- Mini Gifts Distributors Ltd.: $210,500.00\n- Blauer See Auto, Co.: $59,700.00\n- Mini Wheels Co.: $64,600.00\n- Land of Toys Inc.: $114,900.00\n- Euro+ Shopping Channel: $227,600.00\n- Volvo Model Replicas, Co: $53,100.00\n- Danish Wholesale Imports: $83,400.00\n- Saveley & Henriot, Co.: $123,900.00\n- Dragon Souveniers, Ltd.: $103,800.00\n- Muscle Machine Inc: $138,500.00\n- Diecast Classics Inc.: $100,600.00\n- Technics Stores Inc.: $84,600.00\n- Handji Gifts& Co: $97,900.00\n- Herkku Gifts: $96,800.00\n- Daedalus Designs Imports: $82,900.00\n- La Corne D'abondance, Co.: $84,300.00\n- Gift Depot Inc.: $84,300.00\n- Osaka Souveniers Co.: $81,200.00\n- Vitachrome Inc.: $76,400.00\n- Toys of Finland, Co.: $96,500.00\n- AV Stores, Co.: $136,800.00\n- Clove

In [108]:
history.messages

[HumanMessage(content='How many customers with creditlimit > 50000'),
 AIMessage(content='There are 85 customers with a credit limit greater than 50000.'),
 HumanMessage(content='What are their names'),
 AIMessage(content="The names of the customers with a credit limit greater than $50,000 are as follows:\n- Signal Gift Stores\n- Australian Collectors, Co.\n- La Rochelle Gifts\n- Baane Mini Imports\n- Mini Gifts Distributors Ltd.\n- Blauer See Auto, Co.\n- Mini Wheels Co.\n- Land of Toys Inc.\n- Euro+ Shopping Channel\n- Volvo Model Replicas, Co\n- Danish Wholesale Imports\n- Saveley & Henriot, Co.\n- Dragon Souveniers, Ltd.\n- Muscle Machine Inc\n- Diecast Classics Inc.\n- Technics Stores Inc.\n- Handji Gifts& Co\n- Herkku Gifts\n- Daedalus Designs Imports\n- La Corne D'abondance, Co.\n- Gift Depot Inc.\n- Osaka Souveniers Co.\n- Vitachrome Inc.\n- Toys of Finland, Co.\n- AV Stores, Co.\n- Clover Collections, Co.\n- UK Collectables, Ltd.\n- Canadian Gift Exchange Network\n- Online Min