In [1]:
from dotenv import load_dotenv
load_dotenv()

True

In [2]:
from langchain_community.utilities import SQLDatabase



sqlite_uri = 'sqlite:///./HRDataset.db' 
db = SQLDatabase.from_uri(sqlite_uri)

In [3]:
from langchain.chains import create_sql_query_chain
from langchain_openai import ChatOpenAI

llm = ChatOpenAI(model="gpt-3.5-turbo", temperature=0)
generate_query = create_sql_query_chain(llm, db)
# query = generate_query.invoke({"question": "which employee has the most absences and how many?"})
# # "what is price of `1968 Ford Mustang`"
# print(query)

In [4]:

from langchain_community.tools.sql_database.tool import QuerySQLDataBaseTool
execute_query = QuerySQLDataBaseTool(db=db)
# execute_query.invoke(query)

In [5]:
chain = generate_query | execute_query
# chain.invoke({"question": "Name of the third employe?"})

In [6]:
from operator import itemgetter

from langchain_core.output_parsers import StrOutputParser
from langchain_core.prompts import PromptTemplate
from langchain_core.runnables import RunnablePassthrough

answer_prompt = PromptTemplate.from_template(
    """Given the following user question, corresponding SQL query, and SQL result, answer the user questiony.Your answer should contain thank you at the end of messages.

Question: {question}
SQL Query: {query}
SQL Result: {result}
Answer: """
)

rephrase_answer = answer_prompt | llm | StrOutputParser()

chain = (
    RunnablePassthrough.assign(query=generate_query).assign(
        result=itemgetter("query") | execute_query
    )
    | rephrase_answer
)

chain.invoke({"question": "which employee has the most absences and how many?"})

'The employee with the most absences is Bates, Norman with a total of 20 absences. Thank you.'

In [7]:
examples = [
    {
        "input": "List all customers in France with a credit limit over 20,000.",
        "query": "SELECT * FROM customers WHERE country = 'France' AND creditLimit > 20000;"
    },
    {
        "input": "Get the highest payment amount made by any customer.",
        "query": "SELECT MAX(amount) FROM payments;"
    },
   
]

In [8]:
from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder,FewShotChatMessagePromptTemplate,PromptTemplate

example_prompt = ChatPromptTemplate.from_messages(
    [
        ("human", "{input}\nSQLQuery:"),
        ("ai", "{query}"),
    ]
)
few_shot_prompt = FewShotChatMessagePromptTemplate(
    example_prompt=example_prompt,
    examples=examples,
    # input_variables=["input","top_k"],
    input_variables=["input"],
)
print(few_shot_prompt.format(input1="How many employees are there?"))

Human: List all customers in France with a credit limit over 20,000.
SQLQuery:
AI: SELECT * FROM customers WHERE country = 'France' AND creditLimit > 20000;
Human: Get the highest payment amount made by any customer.
SQLQuery:
AI: SELECT MAX(amount) FROM payments;


In [9]:
from langchain_community.vectorstores import Chroma
from langchain_core.example_selectors import SemanticSimilarityExampleSelector
from langchain_openai import OpenAIEmbeddings

vectorstore = Chroma()
vectorstore.delete_collection()
example_selector = SemanticSimilarityExampleSelector.from_examples(
    examples,
    OpenAIEmbeddings(),
    vectorstore,
    k=2,
    input_keys=["input"],
)
example_selector.select_examples({"input": "how many employees we have?"})

[{'input': 'List all customers in France with a credit limit over 20,000.',
  'query': "SELECT * FROM customers WHERE country = 'France' AND creditLimit > 20000;"},
 {'input': 'Get the highest payment amount made by any customer.',
  'query': 'SELECT MAX(amount) FROM payments;'}]

In [10]:
few_shot_prompt = FewShotChatMessagePromptTemplate(
    example_prompt=example_prompt,
    example_selector=example_selector,
    input_variables=["input","top_k"],
)

In [11]:

final_prompt = ChatPromptTemplate.from_messages(
    [
        ("system", "You are a MySQL expert. Given an input question, create a syntactically correct MySQL query to run. Unless otherwise specificed.\n\nHere is the relevant table info: {table_info}\n\nBelow are a number of examples of questions and their corresponding SQL queries."),
        few_shot_prompt,
        ("human", "{input}"),
    ]
)

In [12]:
generate_query = create_sql_query_chain(llm, db,final_prompt)
chain = (
RunnablePassthrough.assign(query=generate_query).assign(
    result=itemgetter("query") | execute_query
)
| rephrase_answer
)

In [13]:
# from operator import itemgetter
# from langchain.chains.openai_tools import create_extraction_chain_pydantic
# from langchain_core.pydantic_v1 import BaseModel, Field
# from typing import List
# import pandas as pd

# def get_table_details():
#     # Read the CSV file into a DataFrame
#     table_description = pd.read_csv("database_table_descriptions.csv")
#     table_docs = []

#     # Iterate over the DataFrame rows to create Document objects
#     table_details = ""
#     for index, row in table_description.iterrows():
#         table_details = table_details + "Table Name:" + row['Table'] + "\n" + "Table Description:" + row['Description'] + "\n\n"

#     return table_details


# class Table(BaseModel):
#     """Table in SQL database."""

#     name: str = Field(description="Name of table in SQL database.")

# # table_names = "\n".join(db.get_usable_table_names())
# table_details = get_table_details()
# print(table_details)

In [14]:

# table_details_prompt = f"""Return the names of ALL the SQL tables that MIGHT be relevant to the user question. \
# The tables are:

# {table_details}

# Remember to include ALL POTENTIALLY RELEVANT tables, even if you're not sure that they're needed."""

# table_chain = create_extraction_chain_pydantic(Table, llm, system_message=table_details_prompt)
# tables = table_chain.invoke({"input": "give me details of customer and their order count"})
# tables
     

In [15]:
# def get_tables(tables: List[Table]) -> List[str]:
#     tables  = [table.name for table in tables]
#     return tables

# select_table = {"input": itemgetter("question")} | create_extraction_chain_pydantic(Table, llm, system_message=table_details_prompt) | get_tables
# select_table.invoke({"question": "give me details of customer and their order count"})

In [16]:
# chain = (
# RunnablePassthrough.assign(table_names_to_use=select_table) |
# RunnablePassthrough.assign(query=generate_query).assign(
#     result=itemgetter("query") | execute_query
# )
# | rephrase_answer
# )
# chain.invoke({"question": "How many cutomers with order count more than 5"})

In [17]:

final_prompt = ChatPromptTemplate.from_messages(
    [
        ("system", "You are a MySQL expert. Given an input question, create a syntactically correct MySQL query to run. Unless otherwise specificed.\n\nHere is the relevant table info: {table_info}\n\nBelow are a number of examples of questions and their corresponding SQL queries. Those examples are just for referecne and hsould be considered while answering follow up questions."),
        few_shot_prompt,
        MessagesPlaceholder(variable_name="messages"),
        ("human", "{input}"),
    ]
)
print(final_prompt.format(input="How many products are there?",table_info="some table info",messages=[]))

System: You are a MySQL expert. Given an input question, create a syntactically correct MySQL query to run. Unless otherwise specificed.

Here is the relevant table info: some table info

Below are a number of examples of questions and their corresponding SQL queries. Those examples are just for referecne and hsould be considered while answering follow up questions.
Human: List all customers in France with a credit limit over 20,000.
SQLQuery:
AI: SELECT * FROM customers WHERE country = 'France' AND creditLimit > 20000;
Human: Get the highest payment amount made by any customer.
SQLQuery:
AI: SELECT MAX(amount) FROM payments;
Human: How many products are there?


In [18]:
from langchain.memory import ChatMessageHistory
history = ChatMessageHistory()

generate_query = create_sql_query_chain(llm, db,final_prompt)

chain = (
# RunnablePassthrough.assign(table_names_to_use=select_table) |
RunnablePassthrough.assign(query=generate_query).assign(
    result=itemgetter("query") | execute_query
)
| rephrase_answer
)

In [19]:
question = "is norman bates male or female?"
response = chain.invoke({"question": question,"messages":history.messages})
print(response)

# history.add_user_message(question)
# history.add_ai_message(response)

# response = chain.invoke({"question": "now tell me its absences","messages":history.messages})
# print(response)

Norman Bates is male. Thank you.


In [25]:
history.messages

[HumanMessage(content=' give me the data of the first record'),
 AIMessage(content="The data of the first record is: \n('Adinolfi, Wilson  K', 10026, 0, 0, 1, 1, 5, 4, 0, 62506, 0, 19, 'Production Technician I', 'MA', 1960, '07/10/83', 'M ', 'Single', 'US Citizen', 'No', 'White', '7/5/2011', None, 'N/A-StillEmployed', 'Active', 'Production       ', 'Michael Albert', 22.0, 'LinkedIn', 'Exceeds', 4.6, 5, 0, '1/17/2019', 0, 1)\nThank you.")]

In [26]:
history.add_user_message(question)
history.add_ai_message(response)

In [27]:
history.messages

[HumanMessage(content=' give me the data of the first record'),
 AIMessage(content="The data of the first record is: \n('Adinolfi, Wilson  K', 10026, 0, 0, 1, 1, 5, 4, 0, 62506, 0, 19, 'Production Technician I', 'MA', 1960, '07/10/83', 'M ', 'Single', 'US Citizen', 'No', 'White', '7/5/2011', None, 'N/A-StillEmployed', 'Active', 'Production       ', 'Michael Albert', 22.0, 'LinkedIn', 'Exceeds', 4.6, 5, 0, '1/17/2019', 0, 1)\nThank you."),
 HumanMessage(content='is norman bates male or female?'),
 AIMessage(content='Norman Bates is male. Thank you.')]

In [28]:
response = chain.invoke({"question": "ok how many absences does he have?","messages":history.messages})
print(response)

Norman Bates has 20 absences. Thank you.
