In [14]:
import os 
from dotenv import load_dotenv

load_dotenv()

OPENAI_API_KEY = os.getenv("OPENAI_API_KEY")

In [15]:
few_shots = [
    {
        'Question': "Which city has spent the highest amount over the years?",
        'SQL_Query': """
select City , sum(Amount) as Total_Amount
from cc
group by City
order by Total_Amount desc
limit 1;
"""
    },
    {
        'Question': "Which card type has the highest amount over the years",
        'SQL_Query': """
select Card_Type , sum(Amount) as Total_Amount
from cc
group by Card_Type
order by Total_Amount desc
limit 1 ;
"""
    },
    {
        'Question': "Which expense type has the highest amount over the years",
        'SQL_Query': """
select Exp_Type, sum(Amount) as Total_Amount
from cc
group by Exp_Type
order by Total_Amount desc
limit 1;
"""
    },
    {
        'Question': "what is the total amount spent between males and females in numbers and percentage?",
        'SQL_Query': """
select Gender , sum(Amount) as Total_Amount, sum(Amount)*100/
(select (sum(Amount)) as S
from cc) as percent_of_total
from cc
group by Gender
order by percent_of_total desc;
"""
    },
    {
        'Question': "Show the month wise spend across the years",
        'SQL_Query': """
select Month, sum(Amount) as Total_Amount
from cc
group by month
order by Amount desc;
"""
    }
]

In [16]:
from langchain_community.utilities import SQLDatabase

db_uri = "your_mysql_uri/credit_card"   #credit_card is the name of the database in which the data is stored
db = SQLDatabase.from_uri(db_uri)
db.table_info

'\nCREATE TABLE creditcards (\n\t`index` INTEGER, \n\t`City` TEXT, \n\t`Date` TEXT, \n\t`Card Type` TEXT, \n\t`Exp Type` TEXT, \n\t`Gender` TEXT, \n\t`Amount` INTEGER, \n\t`Year` INTEGER, \n\t`Month` TEXT\n)ENGINE=InnoDB COLLATE utf8mb4_0900_ai_ci DEFAULT CHARSET=utf8mb4\n\n/*\n3 rows from creditcards table:\nindex\tCity\tDate\tCard Type\tExp Type\tGender\tAmount\tYear\tMonth\n0\tDelhi, India\t29-10-2014\tGold\tBills\tF\t82475\t2014\tOctober\n1\tGreater Mumbai, India\t22-08-2014\tPlatinum\tBills\tF\t32555\t2014\tAugust\n2\tBengaluru, India\t27-08-2014\tSilver\tBills\tF\t101738\t2014\tAugust\n*/'

In [17]:
from langchain_core.output_parsers import StrOutputParser
from langchain_core.runnables import RunnablePassthrough
from langchain_openai import ChatOpenAI

llm = ChatOpenAI(api_key=OPENAI_API_KEY)

In [18]:
from langchain_huggingface.embeddings import HuggingFaceEmbeddings
embeddings = HuggingFaceEmbeddings(model_name='sentence-transformers/all-MiniLM-L6-v2')



In [19]:
to_vectorize = [" ".join(example.values()) for example in few_shots]

In [20]:
from langchain_community.vectorstores.chroma import Chroma

vectorstore = Chroma.from_texts(to_vectorize, embedding=embeddings, metadatas=few_shots)

In [21]:
from langchain.prompts import SemanticSimilarityExampleSelector

example_selector = SemanticSimilarityExampleSelector(
    vectorstore=vectorstore,
    k=2
)
example_selector.select_examples({"Question":"What card type is used mostly?"})

[{'Question': 'Which card type has the highest amount over the years',
  'SQL_Query': '\nselect Card_Type , sum(Amount) as Total_Amount\nfrom cc\ngroup by Card_Type\norder by Total_Amount desc\nlimit 1 ;\n'},
 {'Question': 'Which card type has the highest amount over the years',
  'SQL_Query': '\nselect Card_Type , sum(Amount) as Total_Amount\nfrom cc\ngroup by Card_Type\norder by Total_Amount desc\nlimit 1 ;\n'}]

In [22]:
from langchain.chains.sql_database.prompt import PROMPT_SUFFIX, _mysql_prompt

print(_mysql_prompt)

You are a MySQL expert. Given an input question, first create a syntactically correct MySQL query to run, then look at the results of the query and return the answer to the input question.
Unless the user specifies in the question a specific number of examples to obtain, query for at most {top_k} results using the LIMIT clause as per MySQL. You can order the results to return the most informative data in the database.
Never query for all columns from a table. You must query only the columns that are needed to answer the question. Wrap each column name in backticks (`) to denote them as delimited identifiers.
Pay attention to use only the column names you can see in the tables below. Be careful to not query for columns that do not exist. Also, pay attention to which column is in which table.
Pay attention to use CURDATE() function to get the current date, if the question involves "today".

Use the following format:

Question: Question here
SQLQuery: SQL Query to run
SQLResult: Result of

In [23]:
from langchain.prompts.prompt import PromptTemplate

example_prompt = PromptTemplate(
    input_variables=["Question", "SQL_Query"],
    template="\nQuestion: {Question}\nSQL_Query: {SQL_Query}"
)

In [24]:
from langchain.prompts import FewShotPromptTemplate

few_shot_prompt = FewShotPromptTemplate(
    example_selector=example_selector,
    example_prompt=example_prompt,
    prefix=_mysql_prompt,
    suffix=PROMPT_SUFFIX,
    input_variables=["input", "table_info", "top_k"] 
)

In [25]:
from langchain_experimental.sql import SQLDatabaseChain

new_chain = SQLDatabaseChain.from_llm(llm, db, verbose=True, prompt=few_shot_prompt)

In [32]:
new_chain("In which month most amount was spend by females using a signature card type?")



[1m> Entering new SQLDatabaseChain chain...[0m
In which month most amount was spend by females using a signature card type?
SQLQuery:[32;1m[1;3mSELECT Month, SUM(Amount) AS Total_Amount
FROM creditcards
WHERE `Card Type` = 'Signature' AND Gender = 'F'
GROUP BY Month
ORDER BY Total_Amount DESC
LIMIT 1;[0m
SQLResult: [33;1m[1;3m[('October', Decimal('66991912'))][0m
Answer:[32;1m[1;3mIn the month of October, the most amount was spent by females using a signature card type.[0m
[1m> Finished chain.[0m


{'query': 'In which month most amount was spend by females using a signature card type?',
 'result': 'In the month of October, the most amount was spent by females using a signature card type.'}