# BPL Lab Meeting, 18 November 2024
## Environmental Rights Chatbot

First let's check out the project [README.md](https://github.com/Better-Planet-Laboratory/erv_chatbot) for information about the objective of the chatbot and data sources.

Next steps: 
1) Prepare data and load dataset(s) into SQL database
2) Define unit tests
3) Build LLM
4) Run LLM on possible queries




### 1) Load dataset(s) into SQL database

Resources: 
* https://python.langchain.com/docs/how_to/sql_csv/#sql

In [1]:
import os
from pathlib import Path
import pandas as pd
import geopandas as gpd
import rasterio as rio
from rasterstats import zonal_stats
from langchain_community.utilities import SQLDatabase
from sqlalchemy import create_engine

base_path = Path(os.getcwd()).parent
os.chdir(base_path)

In [2]:
# Functions

# Pretty print of dfs
def printdf(df: pd.DataFrame, num_rows: int=5, ignore_geometry: bool = True):
    if hasattr(df, 'geometry') and ignore_geometry:
        df = df[[col for col in df.columns if col!='geometry']]
    print(df.head(num_rows).to_string())

# Calculate zonal statistics in every polygon
def calculate_zonal_stats(list_of_stats: list, geodata_path: Path, raster_path: Path, save_csv_path: Path = None):

    # Load data
    polygons = gpd.read_file(geodata_path)
    with rio.open(raster_path) as src:
        if src.crs != polygons.crs:
            polygons = polygons.to_crs(src.crs)  # align CRS, if applicable
        raster = src.read(1)
        affine = src.transform
        nodata_value = src.nodata

    # Calculate
    stats = zonal_stats(polygons, raster,
                        stats=list_of_stats,
                        affine=affine, nodata=nodata_value)

    # Reattach to original data
    polygon_x_raster = pd.concat([
        pd.DataFrame(polygons.drop(columns='geometry')),  # drop geometry to reduce size of dataset
        pd.DataFrame(stats)],
        axis=1)

    # Export, if requested
    if save_csv_path is not None:
        polygon_x_raster.to_csv(save_csv_path, index=False)

    return polygon_x_raster

# Write a DataFrame to a SQLite table and wrap the database with LangChain's SQLDatabase
def data_to_sql(df, name_of_database: str, if_exists_replace: bool = False):
    database_path = f'data/{name_of_database}.db'
    engine = create_engine(f'sqlite:///{database_path}')
    if not Path(database_path).exists() or if_exists_replace:
        # pandas also writes the (unnamed) DataFrame index as a column called "index" by default
        df.to_sql(name_of_database, engine, if_exists='replace')
    db = SQLDatabase(engine=engine)

    return db
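`calculate_zonal_stats` hands the actual computation to `rasterstats.zonal_stats`. Conceptually, for each polygon it keeps only the raster cells that fall inside the polygon and are not `nodata`, and then summarizes them. Below is a minimal NumPy sketch of that idea for a single zone. It assumes the polygon has already been rasterized into a boolean mask. The `zonal_summary` helper, the toy raster and the hand-written `zone_mask` are all hypothetical. rasterstats does the rasterization itself, and its cell-inclusion rules may differ.

```python
import numpy as np

def zonal_summary(raster: np.ndarray, zone_mask: np.ndarray, nodata: float) -> dict:
    """Summarize the raster cells that lie inside one zone and are not nodata."""
    values = raster[zone_mask & (raster != nodata)]  # 1-D array of valid cells in the zone
    return {
        "mean": float(values.mean()),
        "max": float(values.max()),
        "min": float(values.min()),
        "median": float(np.median(values)),
        "std": float(values.std()),  # NumPy default: population std (ddof=0)
    }

# Toy 3x3 raster with one nodata cell in the centre (hypothetical values)
raster = np.array([[1.0, 2.0, 3.0],
                   [4.0, -9999.0, 6.0],
                   [7.0, 8.0, 9.0]])

# Hypothetical rasterized polygon covering the top-left 2x2 block
zone_mask = np.array([[True, True, False],
                      [True, True, False],
                      [False, False, False]])

print(zonal_summary(raster, zone_mask, nodata=-9999.0))
```

The centre cell is `nodata`, so the statistics are computed over the three cells 1, 2 and 4 only.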

In [3]:
# Compute per-country air-quality statistics and load them into a SQLite database
stats = ['mean', 'max', 'min', 'median', 'std']
df = calculate_zonal_stats(
    list_of_stats=stats,
    geodata_path=Path('data/input/WB_countries_Admin0_10m/'),
    raster_path=Path('data/input/airQuality.tif')
)
df = df[['WB_NAME', 'CONTINENT'] + stats]
db = data_to_sql(df = df, name_of_database = 'airxcntry', if_exists_replace = False)
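`data_to_sql` calls `df.to_sql` without `index=False`. Because of this, pandas also stores the DataFrame's row index as an extra column, named `index` when the index is unnamed. The sketch below illustrates this with an in-memory database. It uses plain `sqlite3` instead of the SQLAlchemy engine above, plus a few of the rows that `printdf(df)` displays. `table_columns` is a hypothetical helper.

```python
import sqlite3
import pandas as pd

# A few rows from the printdf(df) output (subset of columns)
sample = pd.DataFrame({
    "WB_NAME": ["Indonesia", "Malaysia", "Chile"],
    "CONTINENT": ["Asia", "Asia", "South America"],
    "mean": [14.905112, 12.393976, 16.120725],
    "max": [36.920002, 22.320000, 35.599998],
})

def table_columns(con: sqlite3.Connection, table: str) -> list:
    """Column names of a SQLite table, read via PRAGMA table_info."""
    return [row[1] for row in con.execute(f"PRAGMA table_info({table})")]

con = sqlite3.connect(":memory:")
sample.to_sql("with_index", con, if_exists="replace")                 # pandas default: index=True
sample.to_sql("without_index", con, if_exists="replace", index=False)

print(table_columns(con, "with_index"))     # ['index', 'WB_NAME', 'CONTINENT', 'mean', 'max']
print(table_columns(con, "without_index"))  # ['WB_NAME', 'CONTINENT', 'mean', 'max']
```

Passing `index=False` in `data_to_sql` would drop the extra column. Note that this also changes the schema the LLM sees, since that column is part of the table description passed to the model.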

In [4]:
printdf(df)

     WB_NAME      CONTINENT       mean        max   min  median       std
0  Indonesia           Asia  14.905112  36.920002  6.76   14.26  3.897446
1   Malaysia           Asia  12.393976  22.320000  7.04   12.22  2.588654
2      Chile  South America  16.120725  35.599998  4.16   15.10  5.632093
3    Bolivia  South America  28.121455  43.380001  6.92   29.52  8.072599
4       Peru  South America  27.543493  43.840000  4.66   28.26  8.422236


### 2) Define unit tests

In [5]:
# Potential queries
q_simple = "What is the average and standard deviation of air quality in Indonesia?"
q_relative = "Which five countries have the worst air quality?"
q_summarize = "What is the average air quality on each continent?"
q_filter = "In how many countries did maximum air quality surpass 100 micrograms per cubic meter?"

In [6]:
# Manually find answers from dataframe
answer_key = {}

# 1) "What is the average and standard deviation of air quality in Indonesia?"
answer_key = answer_key | {
    'q_simple': [df[df['WB_NAME'] == "Indonesia"][['mean', 'std']].round(2).to_dict('records')]
}

# 2) "Which five countries have the worst air quality?"
answer_key = answer_key | {
    'q_relative': df.sort_values('mean', ascending=False).head(5)['WB_NAME'].values.tolist()
}

# 3) "What is the average air quality on each continent?"
answer_key = answer_key | {
    'q_summarize': [df.groupby('CONTINENT')['mean'].mean().round(2).to_dict()]
}

# 4) "In how many countries did maximum air quality surpass 100 micrograms per cubic meter?"
answer_key = answer_key | {
    'q_filter': df[df['max'] > 100]['WB_NAME'].count()
}
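The answer key above is computed with pandas. The chatbot, however, answers through SQL, so it may also help to keep a hand-written reference query for each test question. The queries below sketch what correct SQL could look like for the `airxcntry` table. They run against an in-memory copy of only the five rows printed earlier, not the full dataset, so their outputs will not match the answer key. "Worst" is read as highest `mean`, matching the answer key. The `reference_sql` dictionary is hypothetical.

```python
import sqlite3

# The five rows printed by printdf(df) earlier (subset of columns)
rows = [
    ("Indonesia", "Asia", 14.905112, 36.920002, 3.897446),
    ("Malaysia", "Asia", 12.393976, 22.320000, 2.588654),
    ("Chile", "South America", 16.120725, 35.599998, 5.632093),
    ("Bolivia", "South America", 28.121455, 43.380001, 8.072599),
    ("Peru", "South America", 27.543493, 43.840000, 8.422236),
]

con = sqlite3.connect(":memory:")
con.execute('CREATE TABLE airxcntry ("WB_NAME" TEXT, "CONTINENT" TEXT, "mean" REAL, "max" REAL, "std" REAL)')
con.executemany("INSERT INTO airxcntry VALUES (?, ?, ?, ?, ?)", rows)

# Hypothetical reference SQL, one query per test question
reference_sql = {
    "q_simple": "SELECT ROUND(mean, 2), ROUND(std, 2) FROM airxcntry WHERE WB_NAME = 'Indonesia'",
    "q_relative": "SELECT WB_NAME FROM airxcntry ORDER BY mean DESC LIMIT 5",
    "q_summarize": "SELECT CONTINENT, ROUND(AVG(mean), 2) FROM airxcntry GROUP BY CONTINENT ORDER BY CONTINENT",
    "q_filter": "SELECT COUNT(*) FROM airxcntry WHERE max > 100",
}

for name, sql in reference_sql.items():
    print(name, con.execute(sql).fetchall())
```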


In [7]:
answer_key

{'q_simple': [[{'mean': 14.91, 'std': 3.9}]],
 'q_relative': ['Qatar',
  'Bangladesh',
  'Nigeria',
  'Saudi Arabia',
  'United Arab Emirates'],
 'q_summarize': [{'Africa': 29.7,
   'Asia': 31.31,
   'Europe': 12.78,
   'North America': 11.88,
   'Oceania': 5.72,
   'Seven seas (open ocean)': 7.44,
   'South America': 17.68}],
 'q_filter': 10}
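The chain returns free text, so comparing its output with `answer_key` needs some lenient matching. One possible sketch, where the helpers `name_recall` and `values_within` and the tolerance of 0.05 are all hypothetical choices, does two things. It checks which expected country names appear in a response, and whether reported numbers are close to the expected ones. The example inputs are the final `q_relative` sentence and some of the `q_summarize` figures that the chain produces in section 4.

```python
def name_recall(response: str, expected_names: list) -> float:
    """Fraction of expected names that appear (case-insensitively) in the response text."""
    text = response.lower()
    hits = sum(1 for name in expected_names if name.lower() in text)
    return hits / len(expected_names)

def values_within(reported: dict, expected: dict, tol: float = 0.05) -> dict:
    """For each expected key: True if a reported value exists and is within tol of the expected value."""
    return {k: k in reported and abs(reported[k] - v) <= tol for k, v in expected.items()}

response = ("The five countries with the worst air quality are Qatar, Bahrain, "
            "Kuwait, Bangladesh, and the United Arab Emirates.")
expected = ["Qatar", "Bangladesh", "Nigeria", "Saudi Arabia", "United Arab Emirates"]
print(name_recall(response, expected))  # 0.6

reported = {"Africa": 22.63, "Asia": 14.91, "Europe": 9.92}      # from the chain's q_summarize answer
expected_avg = {"Africa": 29.7, "Asia": 31.31, "Europe": 12.78}  # from answer_key
print(values_within(reported, expected_avg))  # {'Africa': False, 'Asia': False, 'Europe': False}
```

Plain substring matching is crude. For example, an expected "Niger" would also match "Nigeria". A stricter version might tokenize the response or ask the model for structured output.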

### 3) Build LLM

Resources:
* https://python.langchain.com/docs/tutorials/sql_qa

We will eventually chain together agents that a) turn the question into a SQL query, b) run that SQL query, and c) return the SQL result as a natural-language response. But first, let's explore each agent in turn.
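Before wiring up LangChain, the data flow of these three steps can be sketched with plain functions. This is not LangChain's API. `write_sql` is a hypothetical stand-in for the LLM that just returns a fixed query. `run_sql` uses Python's built-in `sqlite3`. `answer` only formats the inputs that a real model would receive. The point is that each step adds a key to a shared dictionary, so the last step can see the question, the query and the result.

```python
import sqlite3

def write_sql(question: str) -> str:
    """Stand-in for step a): a real LLM would generate this SQL from the question."""
    return "SELECT mean, std FROM airxcntry WHERE WB_NAME = 'Indonesia'"

def run_sql(con: sqlite3.Connection, query: str) -> str:
    """Step b): execute the query and return the rows as a string."""
    return str(con.execute(query).fetchall())

def answer(state: dict) -> str:
    """Stand-in for step c): a real LLM would turn these inputs into a natural-language answer."""
    return f"Q: {state['question']} | SQL: {state['query']} | Result: {state['result']}"

def pipeline(con: sqlite3.Connection, question: str) -> dict:
    state = {"question": question}
    state["query"] = write_sql(state["question"])   # a) question -> SQL
    state["result"] = run_sql(con, state["query"])  # b) SQL -> rows
    state["answer"] = answer(state)                 # c) question + SQL + rows -> answer
    return state

con = sqlite3.connect(":memory:")
con.execute("CREATE TABLE airxcntry (WB_NAME TEXT, mean REAL, std REAL)")
con.execute("INSERT INTO airxcntry VALUES ('Indonesia', 14.905112, 3.897446)")

out = pipeline(con, "What is the average and standard deviation of air quality in Indonesia?")
print(out["result"])  # [(14.905112, 3.897446)]
```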

In [8]:
from langchain_ollama import ChatOllama
from langchain.chains import create_sql_query_chain
from langchain_core.prompts import PromptTemplate
from langchain_community.tools.sql_database.tool import QuerySQLDataBaseTool
from operator import itemgetter
from langchain_core.output_parsers import StrOutputParser
from langchain_core.runnables import RunnablePassthrough

In [9]:
# Load model
model = ChatOllama(model="llama3.1", seed=123)

#### a) Generate SQL query

In [10]:
# a) Generate SQL query (default)
sql_agent_default = create_sql_query_chain(model, db)
sql_code_default = sql_agent_default.invoke({"question": q_simple})
print(sql_code_default)

Question: What is the average and standard deviation of air quality in Indonesia?
SQLQuery: SELECT "mean", "std" FROM airxcntry WHERE "WB_NAME" = 'Indonesia


##### Can anyone tell whether this SQL query will run successfully? Let's find out...

In [11]:
db.run(sql_code_default)

OperationalError: (sqlite3.OperationalError) near "Question": syntax error
[SQL: Question: What is the average and standard deviation of air quality in Indonesia?
SQLQuery: SELECT "mean", "std" FROM airxcntry WHERE "WB_NAME" = 'Indonesia]
(Background on this error at: https://sqlalche.me/e/20/e3q8)

##### Oops! Let's see if some prompt engineering can improve the model's output.
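The error arises because the raw output still contains the `Question:` line and the `SQLQuery:` label. As a result, SQLite tries to parse the word `Question` as SQL. Alongside prompt engineering, a complementary safeguard is to post-process the model output before running it. Below is a minimal sketch with a hypothetical `extract_sql` helper that keeps only the text after the last `SQLQuery:` label. The example input adds the closing quote after `'Indonesia`; stripping the labels alone would not repair the missing quote in the output above.

```python
import re

def extract_sql(raw: str) -> str:
    """Return the text after the last 'SQLQuery:' label (or the whole string if there is none)."""
    sql = re.split(r"SQLQuery:\s*", raw)[-1]
    # Drop any trailing 'SQLResult:' section the model may have appended
    sql = sql.split("SQLResult:")[0]
    return sql.strip()

raw = ("Question: What is the average and standard deviation of air quality in Indonesia?\n"
       "SQLQuery: SELECT \"mean\", \"std\" FROM airxcntry WHERE \"WB_NAME\" = 'Indonesia'")
print(extract_sql(raw))  # SELECT "mean", "std" FROM airxcntry WHERE "WB_NAME" = 'Indonesia'
```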


In [12]:
# Default prompt
sql_prompt_default = sql_agent_default.get_prompts()[0]
sql_prompt_default.pretty_print()

You are a SQLite expert. Given an input question, first create a syntactically correct SQLite query to run, then look at the results of the query and return the answer to the input question.
Unless the user specifies in the question a specific number of examples to obtain, query for at most 5 results using the LIMIT clause as per SQLite. You can order the results to return the most informative data in the database.
Never query for all columns from a table. You must query only the columns that are needed to answer the question. Wrap each column name in double quotes (") to denote them as delimited identifiers.
Pay attention to use only the column names you can see in the tables below. Be careful to not query for columns that do not exist. Also, pay attention to which column is in which table.
Pay attention to use date('now') function to get the current date, if the question involves "today".

Use the following format:

Question: Question here
SQLQuery: SQL Query to run
SQLResult: Result

##### Prompt engineering

Remove the following instructions, which may be confusing the model: 
* Pay attention to use date('now') function to get the current date, if the question involves "today".
* Wrap each column name in double quotes (") to denote them as delimited identifiers.

Then add the following instruction, based on how we see the model interpreting our request:
* Return ONLY the SQLQuery with no other context given. Do not include the question in the output. Do not include the title 'SQLQuery' in the output.


In [13]:
# Improved prompt
sql_prompt_text = '''
    You are a SQLite expert. Given an input question, first create a syntactically correct {dialect} query to run, then look at the results of the query and return the answer to the input question.
    Unless the user specifies in the question a specific number of examples to obtain, query for at most {top_k} results using the LIMIT clause as per SQLite. You can order the results to return the most informative data in the database.
    Never query for all columns from a table. You must query only the columns that are needed to answer the question. 
    Pay attention to use only the column names you can see in the tables below. Be careful to not query for columns that do not exist. Also, pay attention to which column is in which table.
    Return ONLY the SQLQuery with no other context given. Do not include the question in the output. Do not include the title 'SQLQuery' in the output.
    
    Use the following format:
    Question: Question here
    SQLQuery: SQL Query to run
    SQLResult: Result of the SQLQuery
    Answer: Final answer here
    
    Only use the following tables:
    {table_info}
    
    Question: {input}
'''

sql_prompt = PromptTemplate(
    input_variables = ["input", "table_info", "top_k", "dialect"],
    template=sql_prompt_text
)

In [14]:
# New agent based on improved prompt
sql_agent = create_sql_query_chain(model, db, prompt=sql_prompt)

##### Now let's see what our new model returns when asked to generate SQL code.


In [15]:
sql_code = sql_agent.invoke({"question": q_simple})
print(sql_code)
print(db.run(sql_code))

print("\nAnswer key:")
answer_key['q_simple']

SELECT mean, std FROM airxcntry WHERE WB_NAME = 'Indonesia'
[(14.905112004829341, 3.897445770948378)]

Answer key:


[[{'mean': 14.91, 'std': 3.9}]]

##### Yay! We haven't actually built another agent to execute the SQL query yet, so let's do that now and chain it to the agent that writes the SQL code.


#### b) Execute SQL query

In [16]:
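# b) Execute SQL query
execute_sql_query = QuerySQLDataBaseTool(db=db)  # tool that runs a SQL string against db
write_sql_query = create_sql_query_chain(model, db, prompt=sql_prompt)  # same setup as sql_agent above
sql_chain = write_sql_query | execute_sql_query  # generated SQL is piped straight into the tool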
sql_chain.invoke({"question": q_simple})

'[(14.905112004829341, 3.897445770948378)]'
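`QuerySQLDataBaseTool` hands the generated SQL to the database. To our understanding of LangChain's source, it goes through `SQLDatabase.run_no_throw`, which returns an error message as a string instead of raising an exception. That would explain why, in section 4, the answer model ends up discussing a syntax error rather than the chain failing outright. Below is a minimal stdlib sketch of that "no-throw" behaviour. The `run_no_throw` function here is our own hypothetical stand-in, not LangChain's implementation.

```python
import sqlite3

def run_no_throw(con: sqlite3.Connection, sql: str) -> str:
    """Run SQL and return the rows as a string; on failure, return the error text instead of raising."""
    try:
        return str(con.execute(sql).fetchall())
    except sqlite3.Error as exc:
        return f"Error: {exc}"

con = sqlite3.connect(":memory:")
con.execute('CREATE TABLE airxcntry (WB_NAME TEXT, "max" REAL)')
con.executemany("INSERT INTO airxcntry VALUES (?, ?)", [("Bolivia", 43.380001), ("Peru", 43.84)])

print(run_no_throw(con, 'SELECT COUNT(*) FROM airxcntry WHERE "max" > 40'))  # [(2,)]
print(run_no_throw(con, "Question: how many? SQLQuery: SELECT 1"))  # Error: near "Question": syntax error
```

Returning the error text lets the downstream model see what went wrong. The catch is that a failed query can silently turn into a confident-sounding answer unless the output is checked.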

#### c) Natural language response

In [17]:
# Prompt for agent that will interact with user
answer_prompt = PromptTemplate.from_template(
    """Given the following user question, corresponding SQL query, and SQL result, answer the user question.

    Question: {question}
    SQL Query: {query}
    SQL Result: {result}
    Answer:
    
    """
)
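`PromptTemplate.from_template` treats `{question}`, `{query}` and `{result}` as variables (f-string style by default). To see what the answer model actually receives, the same substitution can be approximated with Python's `str.format`. The sketch below uses the question, SQL and result from In [15] and In [16] above. It is only meant to show the filled-in text, not to replace `PromptTemplate`.

```python
answer_template = """Given the following user question, corresponding SQL query, and SQL result, answer the user question.

    Question: {question}
    SQL Query: {query}
    SQL Result: {result}
    Answer:
    """

# Values taken from the outputs of In [15] and In [16]
filled = answer_template.format(
    question="What is the average and standard deviation of air quality in Indonesia?",
    query="SELECT mean, std FROM airxcntry WHERE WB_NAME = 'Indonesia'",
    result="[(14.905112004829341, 3.897445770948378)]",
)
print(filled)
```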

In [18]:
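# Chain all agents together
chain = (
    # a) add the generated SQL to the inputs under "query"
    RunnablePassthrough.assign(query=write_sql_query).assign(
        # b) run that SQL and add its output under "result"
        result=itemgetter("query") | execute_sql_query
    )
    # c) fill the answer prompt with question/query/result, call the model, return plain text
    | answer_prompt | model | StrOutputParser()
)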


In [19]:
# Test chain
final_response = chain.invoke({"question": q_simple})
print(final_response)

The average air quality in Indonesia is approximately 14.91 and the standard deviation of air quality in Indonesia is approximately 3.90. 

This information can be used to assess the overall air quality in Indonesia and its variations, which may be useful for policymakers or researchers interested in environmental health issues.


### 4) Answers to other queries

In [20]:
# q_relative
# More specific alternative wording: q_relative = "Which five countries have the worst average air quality?"

print(q_relative, "\n")

r_relative = chain.invoke({"question": q_relative})
print(r_relative)

print("\nAnswer key:", answer_key['q_relative'])

print("\nSQL code written:", sql_agent.invoke({"question": q_relative}))


Which five countries have the worst air quality? 

Based on the SQL query and result, it appears that the air quality is being measured by a "min" value (likely the minimum concentration of some pollutant) and the countries with the worst air quality are those with the highest "min" values.

However, since the actual data being used to calculate the "min" value is not provided in the query or result, we can only make an educated guess about what kind of air quality metric is being measured. 

Assuming that the "min" value represents a minimum concentration of some pollutant (such as particulate matter), it seems that Qatar, Bahrain, Kuwait, Bangladesh, and the United Arab Emirates have the worst air quality based on this particular metric.

Therefore, the answer to the user's question could be:

"The five countries with the worst air quality are Qatar, Bahrain, Kuwait, Bangladesh, and the United Arab Emirates."

Answer key: ['Qatar', 'Bangladesh', 'Nigeria', 'Saudi Arabia', 'United Ara

##### Human assessment: A little verbose, and it relied on the `min` column instead of `mean`, which could be defensible but should be standardized. The answer greatly improved when the user added more context (e.g., asking for the worst *average* air quality).
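Why the choice of column matters: on the five rows printed earlier, ranking by `mean` and ranking by `min` give different orders. Below is a small sketch with an in-memory SQLite table. It uses only those sample rows, not the full dataset, and `top_by` is a hypothetical helper.

```python
import sqlite3

rows = [  # (WB_NAME, mean, min) from the printdf(df) output
    ("Indonesia", 14.905112, 6.76),
    ("Malaysia", 12.393976, 7.04),
    ("Chile", 16.120725, 4.16),
    ("Bolivia", 28.121455, 6.92),
    ("Peru", 27.543493, 4.66),
]
con = sqlite3.connect(":memory:")
con.execute('CREATE TABLE airxcntry (WB_NAME TEXT, "mean" REAL, "min" REAL)')
con.executemany("INSERT INTO airxcntry VALUES (?, ?, ?)", rows)

def top_by(column: str, n: int = 3) -> list:
    """Names of the n rows with the highest value in `column` (column must be trusted input)."""
    sql = f'SELECT WB_NAME FROM airxcntry ORDER BY "{column}" DESC LIMIT ?'
    return [r[0] for r in con.execute(sql, (n,))]

print(top_by("mean"))  # ['Bolivia', 'Peru', 'Chile']
print(top_by("min"))   # ['Malaysia', 'Bolivia', 'Indonesia']
```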

In [21]:
# q_summarize

print(q_summarize, "\n")

r_summarize = chain.invoke({"question": q_summarize})
print(r_summarize)

print("\nAnswer key:", answer_key['q_summarize'])

print("\nSQL code written:", sql_agent.invoke({"question": q_summarize}))


What is the average air quality on each continent? 

Based on the SQL result, it appears that the average air quality on each continent is as follows:

* Africa: 22.63
* Asia: 14.91
* Europe: 9.92
* North America: 16.97
* Oceania: 12.76
* Seven seas (open ocean): 4.88
* South America: 16.12

Therefore, the answer to the user question is:

"The average air quality on each continent is as follows:
Africa: 22.63
Asia: 14.91
Europe: 9.92
North America: 16.97
Oceania: 12.76
Seven seas (open ocean): 4.88
South America: 16.12"

Answer key: [{'Africa': 29.7, 'Asia': 31.31, 'Europe': 12.78, 'North America': 11.88, 'Oceania': 5.72, 'Seven seas (open ocean)': 7.44, 'South America': 17.68}]

SQL code written: SELECT CONTINENT, mean FROM airxcntry GROUP BY CONTINENT


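##### Human assessment: Incorrect! The generated query selects the bare `mean` column with `GROUP BY CONTINENT` but no aggregate function. SQLite therefore returns the `mean` of a single row per continent instead of averaging; here it looks like the first entry in each continent. The query needs `AVG(mean)`.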
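SQLite accepts a bare (non-aggregated) column in a `GROUP BY` query and returns that column's value from one row of each group, whereas some other databases reject such a query. The sketch below runs on the five sample rows and contrasts the generated query with an `AVG` version.

```python
import sqlite3

rows = [("Indonesia", "Asia", 14.905112), ("Malaysia", "Asia", 12.393976),
        ("Chile", "South America", 16.120725), ("Bolivia", "South America", 28.121455),
        ("Peru", "South America", 27.543493)]
con = sqlite3.connect(":memory:")
con.execute("CREATE TABLE airxcntry (WB_NAME TEXT, CONTINENT TEXT, mean REAL)")
con.executemany("INSERT INTO airxcntry VALUES (?, ?, ?)", rows)

# Query written by the model: bare column, no aggregate -> one row's value per group
bare = dict(con.execute("SELECT CONTINENT, mean FROM airxcntry GROUP BY CONTINENT"))
# Corrected query: average the per-country means within each continent
avg = dict(con.execute(
    "SELECT CONTINENT, ROUND(AVG(mean), 2) FROM airxcntry GROUP BY CONTINENT ORDER BY CONTINENT"))

print(bare)  # one member country's mean per continent
print(avg)   # {'Asia': 13.65, 'South America': 23.93}
```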

In [22]:
# q_filter

print(q_filter, "\n")

r_filter = chain.invoke({"question": q_filter})
print(r_filter)

print("\nAnswer key:", answer_key['q_filter'])

print("\nSQL code written:", sql_agent.invoke({"question": q_filter}))


In how many countries did maximum air quality surpass 100 micrograms per cubic meter? 

The SQL query is trying to count the number of rows in the `airxcntry` table where the maximum air quality index (`max`) is greater than 100. However, there's a syntax error.

Looking at the SQL result, it seems that SQLite (the database system being used) doesn't like the word "index" as part of the SELECT clause. This is because "index" is a reserved keyword in SQLite.

To fix this issue, we can use backticks to escape the word "index", or simply replace it with the actual column name if it's not using the default index column name (which I assume is not the case here).

However, based on the SQL query and result provided, it seems that there might be a different error. The `SELECT COUNT(index) FROM airxcntry WHERE max > 100` query should actually be `SELECT COUNT(*) FROM airxcntry WHERE max > 100`. The column name is not "index" but rather an auto-incrementing integer value.

So, the correct SQL 

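##### Human assessment: Again verbose, and it never got around to actually answering the question. The underlying failure is that the generated query used `COUNT(index)`. `index` is the column pandas' `to_sql` writes by default, and it is a keyword in SQLite. Not sure its suggested fix is the best / most efficient SQL query either.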
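For reference, `index` only needs quoting: written as `"index"`, it is a valid column name. The sketch below rebuilds the five sample rows with an `index` column, as `to_sql` would write it. On this table, `COUNT(*)`, `COUNT("index")` and `SUM("max" > threshold)` all return the same count. The threshold of 40 is purely illustrative, because none of the five sample rows exceed 100. This shows equivalence only; it does not measure efficiency.

```python
import sqlite3

con = sqlite3.connect(":memory:")
con.execute('CREATE TABLE airxcntry ("index" INTEGER, WB_NAME TEXT, "max" REAL)')
con.executemany("INSERT INTO airxcntry VALUES (?, ?, ?)", [
    (0, "Indonesia", 36.920002), (1, "Malaysia", 22.32), (2, "Chile", 35.599998),
    (3, "Bolivia", 43.380001), (4, "Peru", 43.84),
])

threshold = 40  # illustrative only; none of these sample rows exceed 100
queries = {
    "count_star": 'SELECT COUNT(*) FROM airxcntry WHERE "max" > ?',
    "count_index": 'SELECT COUNT("index") FROM airxcntry WHERE "max" > ?',  # quoted keyword is fine
    "sum_condition": 'SELECT SUM("max" > ?) FROM airxcntry',               # comparisons evaluate to 0/1
}
counts = {name: con.execute(sql, (threshold,)).fetchone()[0] for name, sql in queries.items()}
print(counts)  # {'count_star': 2, 'count_index': 2, 'sum_condition': 2}
```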