In [1]:
import os
from langchain.llms import GooglePalm
from dotenv import load_dotenv
from langchain.utilities import (SQLDatabase)

In [2]:
%pwd

'c:\\Users\\Asus\\Machine_learning\\Practice\\Co2_emission\\notebook'

In [3]:
os.chdir("../")

In [4]:
load_dotenv()

palm_api_key    = os.getenv('PALM_API_KEY')
db_user         = os.getenv('db_user')
db_password     = os.getenv('db_password')
db_host         = os.getenv('db_host')
db_name         = os.getenv('db_name')

In [6]:
db_uri = f"mysql+pymysql://{db_user}:{db_password}@{db_host}/{db_name}"
db = SQLDatabase.from_uri(f"mysql+pymysql://{db_user}:{db_password}@{db_host}/{db_name}",sample_rows_in_table_info=3)
llm = GooglePalm(google_api_key=palm_api_key,temperature=0.65)

In [7]:
type(db.table_info)

str

In [8]:
db.get_table_names()



['country_emissions', 'country_table']

In [9]:
from langchain_experimental.sql import SQLDatabaseChain


db_chain = SQLDatabaseChain.from_llm(llm, db, verbose=True)

In [11]:

qns1 = db_chain("which country has maximum co2 emission in the year of `2002`")



[1m> Entering new SQLDatabaseChain chain...[0m
which country has maximum co2 emission in the year of `2002`
SQLQuery:[32;1m[1;3mSELECT Country_name FROM country_emissions WHERE 2002 = (SELECT MAX(2002) FROM country_emissions)[0m
SQLResult: [33;1m[1;3m[('Africa Eastern and Southern',), ('Afghanistan',), ('Africa Western and Central',), ('Angola',), ('Albania',), ('Andorra',), ('Arab World',), ('United Arab Emirates',), ('Argentina',), ('Armenia',), ('Antigua and Barbuda',), ('Australia',), ('Austria',), ('Azerbaijan',), ('Burundi',), ('Belgium',), ('Benin',), ('Burkina Faso',), ('Bangladesh',), ('Bulgaria',), ('Bahrain',), ('Bahamas, The',), ('Bosnia and Herzegovina',), ('Belarus',), ('Belize',), ('Bolivia',), ('Brazil',), ('Barbados',), ('Brunei Darussalam',), ('Bhutan',), ('Botswana',), ('Central African Republic',), ('Canada',), ('Central Europe and the Baltics',), ('Switzerland',), ('Chile',), ('China',), ("Cote d'Ivoire",), ('Cameroon',), ('Congo, Dem. Rep.',), ('Congo, Re

In [12]:
db_chain("SELECT ce.Country_name, ct.IncomeGroup FROM country_emissions AS ce JOIN country_table AS ct ON ce.Country_code = ct.Country_Code ORDER BY ce.`2002` DESC LIMIT 1;")




[1m> Entering new SQLDatabaseChain chain...[0m
SELECT ce.Country_name, ct.IncomeGroup FROM country_emissions AS ce JOIN country_table AS ct ON ce.Country_code = ct.Country_Code ORDER BY ce.`2002` DESC LIMIT 1;
SQLQuery:[32;1m[1;3mSELECT ce.Country_name, ct.IncomeGroup FROM country_emissions AS ce JOIN country_table AS ct ON ce.Country_code = ct.Country_Code ORDER BY ce.`2002` DESC LIMIT 1[0m
SQLResult: [33;1m[1;3m[('Qatar', 'High income')][0m
Answer:[32;1m[1;3mQatar, High income[0m
[1m> Finished chain.[0m


{'query': 'SELECT ce.Country_name, ct.IncomeGroup FROM country_emissions AS ce JOIN country_table AS ct ON ce.Country_code = ct.Country_Code ORDER BY ce.`2002` DESC LIMIT 1;',
 'result': 'Qatar, High income'}

In [13]:
db_chain('SELECT Country_name FROM country_emissions WHERE `2002` = (SELECT MAX(`2002`) as max_value FROM country_emissions);')



[1m> Entering new SQLDatabaseChain chain...[0m
SELECT Country_name FROM country_emissions WHERE `2002` = (SELECT MAX(`2002`) as max_value FROM country_emissions);
SQLQuery:[32;1m[1;3mSELECT Country_name FROM country_emissions WHERE `2002` = (SELECT MAX(`2002`) as max_value FROM country_emissions);[0m
SQLResult: [33;1m[1;3m[('Qatar',)][0m
Answer:[32;1m[1;3mQatar[0m
[1m> Finished chain.[0m


{'query': 'SELECT Country_name FROM country_emissions WHERE `2002` = (SELECT MAX(`2002`) as max_value FROM country_emissions);',
 'result': 'Qatar'}

In [60]:
few_shots = [
    {
        'Question'  :'which country has maximum co2 emission in the year of 2002',
        'SQLQuery'  :'SELECT Country_name FROM country_emissions WHERE `2002` = (SELECT MAX(`2002`) as max_value FROM country_emissions);',
        'SQLResult' :'Result of the SQL query',
        'Answer'    :'Qatar'
    },
    {
        'Question'  :'Which country had the maximum CO2 emissions in the year 2002, and to which income group category do they belong?',
        'SQLQuery'  :'SELECT ce.Country_name, ct.IncomeGroup FROM country_emissions AS ce JOIN country_table AS ct ON ce.Country_code = ct.Country_Code ORDER BY ce.`2002` DESC LIMIT 1;',
        'SQLResult' :'Result of the SQL query',
        'Answer'    :'Qatar, High income',
    },
    {
        'Question'  :'how many countries emit co2 in the range of 4.25 to 20 in the year of 2010',
        'SQLQuery'  :'SELECT COUNT(ce.Country_code) FROM country_emissions AS ce WHERE ce.`2010` BETWEEN 4.25 AND 20;',
        'SQLResult' :'Result of the SQL query',
        'Answer'    :'84',
    },
    {
        'Question'  :'No of countries emit co2 in the range of 1.25 to 15 in the year of 2010 AND 1990',
        'SQLQuery'  :'SELECT COUNT(country_name) AS country_count FROM country_emissions WHERE (`2010` BETWEEN 1.25 AND 15) AND (`1990` BETWEEN 1.25 AND 15);',
        'SQLResult' :'Result of the SQL query',
        'Answer'    :'119',
    },
    {
        'Question'  : 'Identify the number of countries that exceeded 2 standard deviations in 2013',
        'SQLQuery'  : 'SELECT COUNT(Country_name) FROM country_emissions JOIN (SELECT AVG(`1990`) + 2 * STDDEV(`1990`) AS upper_limit, AVG(`1990`) - 2 * STDDEV(`1990`) AS lower_limit FROM country_emissions) AS std ON country_emissions.`1990` < std.lower_limit OR country_emissions.`1990` > std.upper_limit;',
        'SQLResult' : 'Result of the SQL query',
        'Answer'    : '11', 
    },
    {
        'Question'  : 'Identify the countries that exceeded 2 standard deviations in 2013',
        'SQLQuery'  : 'SELECT Country_name FROM country_emissions JOIN (SELECT AVG(`2013`) + 2 * STDDEV(`2013`) AS upper_limit_2013, AVG(`2013`) - 2 * STDDEV(`2013`) AS lower_limit_2013, AVG(`2015`) + 2 * STDDEV(`2015`) AS upper_limit_2015, AVG(`2015`) - 2 * STDDEV(`2015`) AS lower_limit_2015 FROM country_emissions) AS std ON (country_emissions.`2013` < std.lower_limit_2013 OR country_emissions.`2013` > std.upper_limit_2013) AND(country_emissions.`2015` < std.lower_limit_2015 OR country_emissions.`2015` > std.upper_limit_2015);',
        'SQLResult' : 'Result of the SQL query',
        'Answer'    : 'Australia, Bahrain, Brunei Darussalam, Canada, Estonia, Kazakhstan, Kuwait, Luxembourg, Oman, Qatar, Saudi Arabia, Trinidad and Tobago, United Arab Emirates, United States', 
    },
    {
        'Question'  : 'Identify the total countries that exceeded 2 standard deviations in 2013 and 2015',
        'SQLQuery'  : 'SELECT COUNT(Country_name) FROM country_emissions JOIN (SELECT AVG(`2013`) + 2 * STDDEV(`2013`) AS upper_limit_2013, AVG(`2013`) - 2 * STDDEV(`2013`) AS lower_limit_2013, AVG(`2015`) + 2 * STDDEV(`2015`) AS upper_limit_2015, AVG(`2015`) - 2 * STDDEV(`2015`) AS lower_limit_2015 FROM country_emissions) AS std ON (country_emissions.`2013` < std.lower_limit_2013 OR country_emissions.`2013` > std.upper_limit_2013) AND(country_emissions.`2015` < std.lower_limit_2015 OR country_emissions.`2015` > std.upper_limit_2015);',
        'SQLResult' : 'Result of the SQL query',
        'Answer'    : '13', 
    }


]


### Creating Semantic Similarity Based example selector

- create embedding on the few_shots
- Store the embeddings in Chroma DB
- Retrieve the the top most Semantically close example from the vector store

In [61]:
from langchain.prompts import SemanticSimilarityExampleSelector
from langchain.embeddings import HuggingFaceEmbeddings
from langchain.vectorstores import Chroma

In [62]:
embeddings = HuggingFaceEmbeddings(model_name='sentence-transformers/all-MiniLM-L6-v2')

In [63]:
to_vectorize = [' '.join(sent.values()) for sent in few_shots]
vectorstore  = Chroma.from_texts(to_vectorize,embeddings,metadatas=few_shots)

In [64]:
example_selector = SemanticSimilarityExampleSelector(
    vectorstore=vectorstore,
    k=2)

In [65]:
example_selector.select_examples({'Question':'find the number of countries that were underperforming in the years 1990,'})

[{'Answer': '119',
  'Question': 'No of countries emit co2 in the range of 1.25 to 15 in the year of 2010 AND 1990',
  'SQLQuery': 'SELECT COUNT(country_name) AS country_count FROM country_emissions WHERE (`2010` BETWEEN 1.25 AND 15) AND (`1990` BETWEEN 1.25 AND 15);',
  'SQLResult': 'Result of the SQL query'},
 {'Answer': '119',
  'Question': 'No of countries emit co2 in the range of 1.25 to 15 in the year of 2010 AND 1990',
  'SQLQuery': 'SELECT COUNT(country_name) AS country_count FROM country_emissions WHERE (`2010` BETWEEN 1.25 AND 15) AND (`1990` BETWEEN 1.25 AND 15);',
  'SQLResult': 'Result of the SQL query'}]

In [66]:
from langchain.chains.sql_database.prompt import _mysql_prompt,PROMPT_SUFFIX,MYSQL_PROMPT

In [67]:
print(_mysql_prompt)
print(PROMPT_SUFFIX)

You are a MySQL expert. Given an input question, first create a syntactically correct MySQL query to run, then look at the results of the query and return the answer to the input question.
Unless the user specifies in the question a specific number of examples to obtain, query for at most {top_k} results using the LIMIT clause as per MySQL. You can order the results to return the most informative data in the database.
Never query for all columns from a table. You must query only the columns that are needed to answer the question. Wrap each column name in backticks (`) to denote them as delimited identifiers.
Pay attention to use only the column names you can see in the tables below. Be careful to not query for columns that do not exist. Also, pay attention to which column is in which table.
Pay attention to use CURDATE() function to get the current date, if the question involves "today".

Use the following format:

Question: Question here
SQLQuery: SQL Query to run
SQLResult: Result of

## Setting up PromptTemplete using input variables

In [68]:
from langchain.prompts.prompt import PromptTemplate

example_prompt = PromptTemplate(
    input_variables =['Question','SQLQuery','SQLResult','Answer'],
    template        ='\nQuestion: {Question}\nSQLQuery: {SQLQuery}\nSQLResult: {SQLResult}\nAnswer: {Answer}',
)

In [69]:
print(example_prompt.template)


Question: {Question}
SQLQuery: {SQLQuery}
SQLResult: {SQLResult}
Answer: {Answer}


### Few short learning

In [70]:
print(MYSQL_PROMPT)

input_variables=['input', 'table_info', 'top_k'] template='You are a MySQL expert. Given an input question, first create a syntactically correct MySQL query to run, then look at the results of the query and return the answer to the input question.\nUnless the user specifies in the question a specific number of examples to obtain, query for at most {top_k} results using the LIMIT clause as per MySQL. You can order the results to return the most informative data in the database.\nNever query for all columns from a table. You must query only the columns that are needed to answer the question. Wrap each column name in backticks (`) to denote them as delimited identifiers.\nPay attention to use only the column names you can see in the tables below. Be careful to not query for columns that do not exist. Also, pay attention to which column is in which table.\nPay attention to use CURDATE() function to get the current date, if the question involves "today".\n\nUse the following format:\n\nQues

In [71]:
from langchain.prompts import FewShotPromptTemplate


few_shot_prompt = FewShotPromptTemplate(
                    example_selector= example_selector,
                    example_prompt  = example_prompt,
                    prefix          = _mysql_prompt,
                    suffix          = PROMPT_SUFFIX,
                    input_variables=['input', 'table_info', 'top_k'])

In [72]:
new_chain = SQLDatabaseChain.from_llm(llm=llm,db=db,verbose=True,prompt=few_shot_prompt)

In [73]:
new_chain('which country has maximum co2 emission in the year of 2002')



[1m> Entering new SQLDatabaseChain chain...[0m
which country has maximum co2 emission in the year of 2002
SQLQuery:[32;1m[1;3mSELECT Country_name FROM country_emissions AS ce JOIN country_table AS ct ON ce.Country_code = ct.Country_Code ORDER BY ce.`2002` DESC LIMIT 1;[0m
SQLResult: [33;1m[1;3m[('Qatar',)][0m
Answer:[32;1m[1;3mQatar[0m
[1m> Finished chain.[0m


{'query': 'which country has maximum co2 emission in the year of 2002',
 'result': 'Qatar'}

In [74]:
new_chain('how many countries emit co2 in the range of 4.25 to 20 in the year of 2010 ')



[1m> Entering new SQLDatabaseChain chain...[0m
how many countries emit co2 in the range of 4.25 to 20 in the year of 2010 
SQLQuery:[32;1m[1;3mSELECT COUNT(ce.Country_code) FROM country_emissions AS ce WHERE ce.`2010` BETWEEN 4.25 AND 20;[0m
SQLResult: [33;1m[1;3m[(84,)][0m
Answer:[32;1m[1;3m84[0m
[1m> Finished chain.[0m


{'query': 'how many countries emit co2 in the range of 4.25 to 20 in the year of 2010 ',
 'result': '84'}

In [75]:
new_chain.run('No of countries emit co2 in the range of 0.5 to 5 in the year of 1990 and 2015')



[1m> Entering new SQLDatabaseChain chain...[0m
No of countries emit co2 in the range of 0.5 to 5 in the year of 1990 and 2015
SQLQuery:[32;1m[1;3mSELECT COUNT(country_name) AS country_count FROM country_emissions WHERE (`2015` BETWEEN 0.5 AND 5) AND (`1990` BETWEEN 0.5 AND 5);[0m
SQLResult: [33;1m[1;3m[(82,)][0m
Answer:[32;1m[1;3m82[0m
[1m> Finished chain.[0m


'82'

In [76]:
new_chain('Identify the number of countries that exceeded 1 standard deviations in 2013')



[1m> Entering new SQLDatabaseChain chain...[0m
Identify the number of countries that exceeded 1 standard deviations in 2013
SQLQuery:[32;1m[1;3mSELECT COUNT(Country_name) FROM country_emissions JOIN (SELECT AVG(`2013`) + STDDEV(`2013`) AS upper_limit_2013, AVG(`2013`) - STDDEV(`2013`) AS lower_limit_2013 FROM country_emissions) AS std ON (country_emissions.`2013` < std.lower_limit_2013 OR country_emissions.`2013` > std.upper_limit_2013);[0m
SQLResult: [33;1m[1;3m[(27,)][0m
Answer:[32;1m[1;3m27[0m
[1m> Finished chain.[0m


{'query': 'Identify the number of countries that exceeded 1 standard deviations in 2013',
 'result': '27'}

In [77]:
few_shots

[{'Question': 'which country has maximum co2 emission in the year of 2002',
  'SQLQuery': 'SELECT Country_name FROM country_emissions WHERE `2002` = (SELECT MAX(`2002`) as max_value FROM country_emissions);',
  'SQLResult': 'Result of the SQL query',
  'Answer': 'Qatar'},
 {'Question': 'Which country had the maximum CO2 emissions in the year 2002, and to which income group category do they belong?',
  'SQLQuery': 'SELECT ce.Country_name, ct.IncomeGroup FROM country_emissions AS ce JOIN country_table AS ct ON ce.Country_code = ct.Country_Code ORDER BY ce.`2002` DESC LIMIT 1;',
  'SQLResult': 'Result of the SQL query',
  'Answer': 'Qatar, High income'},
 {'Question': 'how many countries emit co2 in the range of 4.25 to 20 in the year of 2010',
  'SQLQuery': 'SELECT COUNT(ce.Country_code) FROM country_emissions AS ce WHERE ce.`2010` BETWEEN 4.25 AND 20;',
  'SQLResult': 'Result of the SQL query',
  'Answer': '84'},
 {'Question': 'No of countries emit co2 in the range of 1.25 to 15 in the 

In [79]:
new_chain('In which incomegroup category do most of the highest emittig country belongs?')



[1m> Entering new SQLDatabaseChain chain...[0m
In which incomegroup category do most of the highest emittig country belongs?
SQLQuery:[32;1m[1;3mSELECT ct.IncomeGroup, COUNT(*) FROM country_emissions AS ce JOIN country_table AS ct ON ce.Country_code = ct.Country_Code GROUP BY ct.IncomeGroup ORDER BY COUNT(*) DESC LIMIT 1;[0m
SQLResult: [33;1m[1;3m[('High income', 58)][0m
Answer:[32;1m[1;3mHigh income[0m
[1m> Finished chain.[0m


{'query': 'In which incomegroup category do most of the highest emittig country belongs?',
 'result': 'High income'}

In [80]:
new_chain('In which incomegroup category do most of the lowest emittig country belongs?')



[1m> Entering new SQLDatabaseChain chain...[0m
In which incomegroup category do most of the lowest emittig country belongs?
SQLQuery:[32;1m[1;3mSELECT IncomeGroup FROM country_emissions AS ce JOIN country_table AS ct ON ce.Country_code = ct.Country_Code GROUP BY IncomeGroup ORDER BY AVG(ce.`2020`) LIMIT 1;[0m
SQLResult: [33;1m[1;3m[('Low income',)][0m
Answer:[32;1m[1;3mLow income[0m
[1m> Finished chain.[0m


{'query': 'In which incomegroup category do most of the lowest emittig country belongs?',
 'result': 'Low income'}

In [None]:
new_chain('which incomegroup category is belongs to who has ')