# Data privacy

According to [OpenAI's official statement](https://help.openai.com/en/articles/5722486-how-your-data-is-used-to-improve-model-performance):

>OpenAI does not use data submitted by customers via our API to train OpenAI models or improve OpenAI’s service offering. In order to support the continuous improvement of our models, you can fill out this form to opt-in to share your data with us.

According to [OpenAI's API data usage policies](https://openai.com/policies/api-data-usage-policies):

>1. OpenAI will not use data submitted by customers via our API to train or improve our models, unless you explicitly decide to share your data with us for this purpose. You can opt-in to share data.
>2. Any data sent through the API will be retained for abuse and misuse monitoring purposes for a maximum of 30 days, after which it will be deleted (unless otherwise required by law).

# Library import and client setup

In [None]:
import openai
import traceback
import pandas as pd
import tiktoken
import re

from datetime import datetime
from google.cloud import bigquery
from google.cloud.exceptions import BadRequest
from typing import Dict, List, Union

In [None]:
client = bigquery.Client()

__NOTE:__ Set `MAX_ALLOWED_TOKENS` so that there is a safe buffer (at least 100 tokens) between its value and the actual maximum number of tokens allowed by the API.
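One way to sanity-check the buffer is to add the prompt budget and the completion budget and compare the sum with the model's context window. The sketch below assumes a 4,096-token context window for `gpt-3.5-turbo-0301`; `CONTEXT_WINDOW`, `SAFETY_BUFFER`, and `has_safe_buffer` are illustrative names that are not part of this notebook.

```python
# Assumption: gpt-3.5-turbo-0301 accepts 4,096 tokens in total (prompt + completion).
CONTEXT_WINDOW = 4096
SAFETY_BUFFER = 100  # minimum headroom suggested in the note above


def has_safe_buffer(prompt_budget: int, completion_budget: int,
                    context_window: int = CONTEXT_WINDOW,
                    buffer: int = SAFETY_BUFFER) -> bool:
    """Return True if the two budgets leave at least `buffer` tokens of headroom."""
    return prompt_budget + completion_budget + buffer <= context_window


print(has_safe_buffer(2000, 1900))  # 2000 + 1900 + 100 = 4000, fits in 4096
print(has_safe_buffer(2000, 2000))  # 2000 + 2000 + 100 = 4100, does not fit
```

Under that assumption, using 2000 tokens for both the prompt and the completion leaves slightly less than the suggested 100-token buffer.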

In [None]:
# Set up OpenAI API credentials
openai.api_key = "YOUR_API_KEY"

# Default chat model
DEFAULT_GPT_MODEL = "gpt-3.5-turbo-0301"

# Token budget (see the note above)
MAX_ALLOWED_TOKENS = 2000

# model_name = "gpt-4"
# GPT-4 is not available even to subscribed accounts; access requires joining the waitlist: https://openai.com/waitlist/gpt-4-api

# Table info feed-in

In [None]:
def export_table_schema(table_name: str) -> str:
    """Get table schema from given table name

    Args:
        table_name (str): Name of table to extract schema.
        Note that the dataset should already have been initialized with the BigQuery client.
        This avoids passing the project name, which could contain sensitive info, to the model

    Returns:
        str: Schema of the table in format `column_1 type, column_2 type, ...`
    """
    table = client.get_table(table_name)
    schema = table.schema
    # Extract table schemas, except columns whose name includes "aeris"
    schema_fields = [f"{field.name} {field.field_type}" for field in schema if "aeris" not in field.name.lower()]
    schema_text = ", ".join(schema_fields)
    return schema_text
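The formatting and filtering step can be tried without a BigQuery connection. The following is a minimal sketch in which `Field` is a hypothetical stand-in for the schema field objects returned by the client (only the `name` and `field_type` attributes are mimicked), and `format_schema` is an illustrative helper rather than part of the notebook.

```python
from collections import namedtuple

# Hypothetical stand-in for a BigQuery schema field; only `name` and
# `field_type` are needed for the formatting step.
Field = namedtuple("Field", ["name", "field_type"])


def format_schema(fields, excluded_keyword="aeris"):
    """Render fields as `column_1 type, column_2 type, ...`, skipping excluded names."""
    return ", ".join(
        f"{f.name} {f.field_type}"
        for f in fields
        if excluded_keyword not in f.name.lower()
    )


fields = [
    Field("account_id", "INTEGER"),
    Field("Aeris_internal_flag", "STRING"),  # dropped: name contains "aeris"
    Field("date", "DATE"),
]
print(format_schema(fields))  # account_id INTEGER, date DATE
```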

In [None]:
# Initialize table(s) to feed into the model

def init_tables(tables: Union[str, List[str]]) -> List[Dict]:
    """Get table names and schema to feed to the GPT model

    Args:
        tables (Union[str, List[str]]): Table name in BigQuery. Must be in format `dataset.table_name`

    Returns:
        List[Dict]: List of dictionaries, each containing a table name and its schema
    """
    if not isinstance(tables, list):
        tables = [tables]

    tables_with_schema = [
        {
            "name": table,
            "schema": export_table_schema(table)
        } for table in tables
    ]

    return tables_with_schema


# GPT interaction class

In [None]:
class GPTQueryGenerator:
    def __init__(self, tables: List, model: str = "gpt-3.5-turbo-0301", max_tokens: int = 2000) -> None:
        self._model = model
        self._max_tokens = max_tokens
        self._messages = [
            {"role": "system", "content": f"TABLE `{table['name']}` - SCHEMA `{table['schema']}`"} for table in tables
        ]
        # Initialize the list of messages with some instructions for better query
        self._messages += [
            {
                "role": "system",
                "content": "Use the reference tables - schemas to compose SQL query in BigQuery dialect",
            },
            {"role": "system", "content": "Answer user's questions with SQL queries only"},
            # {"role": "system", "content": "Only join tables in reference if needed. Do not join unnecessarily."},
            {"role": "system", "content": "Only use the columns existing in the corresponding schema for each table"},
            {
                "role": "system",
                "content": "Columns of types DATE and TIMESTAMP are incompatible. Convert one to the other before any operation that involves both",
            },
            # {"role": "assistant", "content": "If the user ask for something by X, then it's a good idea to have a GROUP BY X clause in the SQL"},
            # {"role": "assistant", "content": "New column names in AS clause should not have any white spaces and do not put them inside quotes"},
            # {"role": "assistant", "content": "Remember to make sure all the columns are grouped nor aggregated when you are using aggragation in SQL"},
            # {"role": "assistant", "content": "If the user do not mention any time period, the query should take into account all the dates available in the table"},
            # {"role": "assistant", "content": "When the user do not mention any specific requirements for the numeric values, default to average"},
            # {"role": "system", "content": "Only provide the SQL query. Do not provide anything else outside the query."},
        ]

    def generate_sql_query(self, question: str, temperature: float = 0.2) -> str:
        """Get the SQL query the model generated to answer the user's question. The context is specific to the tables the object was initialized with

        Args:
            question (str): Question for the query to answer
            temperature (float, optional): Sampling temperature for the GPT model, in the range 0 to 2. Higher values make the answers more random; lower values keep them more focused on the input. Defaults to 0.2.

        Returns:
            str: Executable query, specific to the user's question and the schemas provided
        """
        print(f"Question:\n{question}")
        # Guarantee that the number tokens from our messages will not exceed the maximum allowed
        self._prune_messages()
        # Extract content without timestamp field because extra fields are not allowed for the API
        msg = [{key: value for key, value in d.items() if key != "timestamp"} for d in self._messages] + [
            {"role": "user", "content": question}
        ]

        response = openai.ChatCompletion.create(
            model=self._model,
            messages=msg,
            max_tokens=self._max_tokens,
            temperature=temperature,
            n=1,
            stop=None,
        )
        answer = response.choices[0].message["content"]
        print(f"Answer:\n{answer}")
        sql_query = self._parse_query(answer)
        # Adding User's previous question as context for conversation.
        self._messages += [
            {
                "role": "system",
                "name": "user_previous_question",
                "content": question,
                "timestamp": datetime.utcnow().timestamp(),
            }
        ]
        return sql_query

    def execute_sql_query(self, sql_query: str) -> pd.DataFrame:
        """Run a SQL query on the predefined project.
        Note that the `client` variable for the BigQuery client must be initialized before running this method

        Args:
            sql_query (str): Query to run

        Returns:
            pd.DataFrame: Pandas dataframe containing the results of the query
        """
        result = client.query(sql_query).to_dataframe()
        return result

    def get_query_response(self, question: str = None) -> str:
        """Get the SQL query for a question without executing it. The model's full answer (with the explanations and other filler words it is likely to give) is printed along the way

        Args:
            question (str, optional): What we want to ask the model. If not provided beforehand then it will be prompted as an input box

        Returns:
            str: Model's query response to the question
        """
        user_input = input("Input question:") if not question else question

        sql_query = self.generate_sql_query(user_input)
        print("Generated SQL Query:\n", sql_query)
        return sql_query

    def ask_question_and_validate(
        self, question: str = None, temperature: float = 0.2, debug: bool = False
    ) -> pd.DataFrame:
        """Ask GPT some questions on predefined tables. The query generated will be automatically executed

        Args:
            question (str, optional): What we want to ask the model. If not provided beforehand then it will be prompted as an input box.
            temperature (float, optional): Sampling temperature for the GPT model, in the range 0 to 2. Higher values make the answers more random; lower values keep them more focused on the input. Defaults to 0.2.
            debug (bool, optional): Debug mode. Enable to get the full stack trace when query execution encounters errors. Defaults to False.

        Returns:
            pd.DataFrame: Pandas dataframe object as the result of running the query
        """
        user_input = input("Input question:") if not question else question

        sql_query = self.generate_sql_query(user_input, temperature)

        # Execute the generated SQL query
        try:
            result = self.execute_sql_query(sql_query)
        except BadRequest as err:
            # Notify the user about the bad query generated
            print(f"Running the generated query caused the error: {err.errors}")
            if debug:
                print(traceback.format_exc())
            print("Please ask the question again, rephrase it, or provide additional guidance if necessary")
            result = None
        # print("Result:", result)
        return result

    def _parse_query(self, msg: str) -> str:
        """Extract the SQL from an answer from GPT

        Args:
            msg (str): Raw string answer from GPT

        Returns:
            str: The extracted SQL from the answer. If there's no SQL then return `None`
        """
        pattern = r"```[a-zA-Z]*\n([\s\S]+?)\n```"
        matching = re.search(pattern, msg)
        if matching:
            # The SQL is wrapped between two ``` fences; the capture group holds everything between them
            return matching.group(1)
        else:
            return None

    def _num_tokens_from_message(self) -> int:
        """Returns the number of tokens used by a list of messages."""
        # Only get the valid messages
        actual_content = [{key: value for key, value in d.items() if key != "timestamp"} for d in self._messages]
        try:
            encoding = tiktoken.encoding_for_model(self._model)
        except KeyError:
            encoding = tiktoken.get_encoding("cl100k_base")

        if self._model == "gpt-3.5-turbo-0301":  # note: future models may deviate from this
            num_tokens = 0
            for message in actual_content:
                num_tokens += 4  # every message follows <im_start>{role/name}\n{content}<im_end>\n
                for key, value in message.items():
                    num_tokens += len(encoding.encode(value))
                    if key == "name":  # if there's a name, the role is omitted
                        num_tokens += -1  # role is always required and always 1 token
            num_tokens += 2  # every reply is primed with <im_start>assistant
            return num_tokens
        else:
            raise NotImplementedError(
                f"""num_tokens_from_messages() is not presently implemented for model {self._model}.
        See https://github.com/openai/openai-python/blob/main/chatml.md for information on how messages are converted to tokens."""
            )

    def _prune_messages(self) -> None:
        """Adjust the prompts sent to the model so that the total number of tokens
        does not exceed the maximum number of tokens allowed
        """
        # Iteratively remove the earliest messages until the number of tokens is within the limit.
        while self._num_tokens_from_message() > self._max_tokens:
            # Only messages with a "timestamp" field (previous questions) can be pruned
            filtered_data = [d for d in self._messages if "timestamp" in d]
            if not filtered_data:
                # Only the permanent instructions are left, so nothing more can be removed
                break
            # Remove the message with the earliest timestamp
            earliest_entry = min(filtered_data, key=lambda x: x["timestamp"])
            self._messages.remove(earliest_entry)
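
The pruning idea used by the class can be exercised in isolation. The sketch below is not the class method itself: `prune_oldest` is a hypothetical free function, and `count_words` is a crude stand-in for the `tiktoken`-based counter (one token per whitespace-separated word), chosen only so the example runs offline.

```python
from typing import Callable, Dict, List


def prune_oldest(messages: List[Dict], limit: int,
                 count_tokens: Callable[[List[Dict]], int]) -> List[Dict]:
    """Drop the oldest timestamped messages until the count fits within `limit`."""
    kept = list(messages)
    while count_tokens(kept) > limit:
        prunable = [m for m in kept if "timestamp" in m]
        if not prunable:
            break  # only permanent instructions are left; nothing more to drop
        kept.remove(min(prunable, key=lambda m: m["timestamp"]))
    return kept


def count_words(messages: List[Dict]) -> int:
    """Crude stand-in for a tokenizer: one token per whitespace-separated word."""
    return sum(len(m["content"].split()) for m in messages)


history = [
    {"role": "system", "content": "Answer with SQL only"},               # permanent
    {"role": "system", "content": "first question", "timestamp": 1.0},   # oldest
    {"role": "system", "content": "second question", "timestamp": 2.0},
]
pruned = prune_oldest(history, limit=6, count_tokens=count_words)
print([m["content"] for m in pruned])  # ['Answer with SQL only', 'second question']
```

The explicit `break` matters: when only messages without a timestamp remain, there is nothing left to remove, and the loop has to stop even if the count is still above the limit.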

# Interaction testing

In [None]:
table_names = [
    ## Test tables containing data from open data source
    "test_financial_open_data.trans",
    "test_financial_open_data.account",
    "test_financial_open_data.client",
    "test_financial_open_data.order",

    ## Aeris tables
    # "cmp_dw_prov.aersys_accounts",
    # "cmp_dw_prov.device",
    # "cmp_dw_prov.product",
    # "ais_acp_int_ds.packet_data ",
    # "location.all_cell_locations",
]

In [None]:
tables = init_tables(table_names)
tables

[{'name': 'test_financial_open_data.trans',
  'schema': 'trans_id INTEGER, account_id INTEGER, date DATE, type STRING, operation STRING, amount INTEGER, balance INTEGER, k_symbol STRING, bank STRING, account INTEGER'},
 {'name': 'test_financial_open_data.account',
  'schema': 'account_id INTEGER, district_id INTEGER, frequency STRING, date DATE'},
 {'name': 'test_financial_open_data.client',
  'schema': 'client_id INTEGER, gender STRING, birth_date DATE, district_id INTEGER'},
 {'name': 'test_financial_open_data.order',
  'schema': 'order_id INTEGER, account_id INTEGER, bank_to STRING, account_to INTEGER, amount FLOAT, k_symbol STRING'}]

In [None]:
gpt = GPTQueryGenerator(tables=tables)

In [None]:
df = gpt.ask_question_and_validate("top 10 accounts with most transactions in all time period")
df

Question:
top 10 accounts with most transactions in all time period
Answer:
Here's the query to get the top 10 accounts with the most transactions in all time period:

```
SELECT account_id, COUNT(*) AS transaction_count
FROM test_financial_open_data.trans
GROUP BY account_id
ORDER BY transaction_count DESC
LIMIT 10
```


Unnamed: 0,account_id,transaction_count
0,8261,675
1,3834,665
2,96,661
3,2932,655
4,9307,649
5,9265,643
6,5215,637
7,2762,634
8,1801,633
9,5952,628


In [None]:
def extract_sql(s:str):
    pattern = r"```[a-zA-Z]*\n([\s\S]+?)\n```"
    m = re.search(pattern, s)[0]
    ind1 = m.find('\n')
    ind2 = m.rfind('\n')
    return m[ind1+1:ind2]


In [None]:
s = """Here's the query to get the top 10 accounts with the most transactions in all time period:

```sql
SELECT account_id, COUNT(*) AS transaction_count
FROM test_financial_open_data.trans
GROUP BY account_id
ORDER BY transaction_count DESC
LIMIT 10
```"""

In [None]:
print(extract_sql(s))

SELECT account_id, COUNT(*) AS transaction_count
FROM test_financial_open_data.trans
GROUP BY account_id
ORDER BY transaction_count DESC
LIMIT 10
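
Note that `extract_sql` raises a `TypeError` when the answer contains no fenced block, because `re.search` returns `None` in that case. A minimal sketch of a variant that returns `None` instead is shown below; `extract_sql_or_none`, `FENCE`, and `PATTERN` are illustrative names, and the three-backtick fence is built programmatically only so that this example can itself sit inside a fenced block.

```python
import re
from typing import Optional

FENCE = "`" * 3  # three backticks, built here so the example nests safely in markdown
PATTERN = re.compile(FENCE + r"[a-zA-Z]*\n([\s\S]+?)\n" + FENCE)


def extract_sql_or_none(answer: str) -> Optional[str]:
    """Return the body of the first fenced block in `answer`, or None if there is none."""
    match = PATTERN.search(answer)
    return match.group(1) if match else None


with_sql = f"Here is the query:\n\n{FENCE}sql\nSELECT 1\n{FENCE}"
print(extract_sql_or_none(with_sql))            # SELECT 1
print(extract_sql_or_none("No query, sorry."))  # None
```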


In [None]:
df = gpt.ask_question_and_validate("number of male and female customers")
df

Question:
number of male and female customers
Answer:
To get the number of male and female customers, we can use the `client` table and group by the `gender` column. Here's the query:

```
SELECT gender, COUNT(*) as count
FROM test_financial_open_data.client
GROUP BY gender
```

This will give us a result with two rows, one for each gender, and the count of customers for each gender.


Unnamed: 0,gender,count
0,F,2645
1,M,2724


In [None]:
df = gpt.ask_question_and_validate("client's birth year distribution")
df

Question:
client's birth year distribution
Answer:
Here's the SQL query to get the client's birth year distribution:

```
SELECT EXTRACT(YEAR FROM birth_date) AS birth_year, COUNT(*) AS count
FROM test_financial_open_data.client
GROUP BY birth_year
ORDER BY birth_year ASC
```

This query extracts the birth year from the `birth_date` column, groups the clients by birth year, and counts the number of clients in each group. The result is ordered by birth year in ascending order.


Unnamed: 0,birth_year,count
0,1911,2
1,1912,1
2,1913,3
3,1914,4
4,1915,4
...,...,...
72,1983,3
73,1984,1
74,1985,3
75,1986,1


In [None]:
df = gpt.ask_question_and_validate("show client's total spending by birth year.")
df

Question:
show client's total spending by birth year.
Answer:
```
SELECT 
  EXTRACT(YEAR FROM c.birth_date) AS birth_year, 
  SUM(t.amount) AS total_spending
FROM 
  test_financial_open_data.client c
  JOIN test_financial_open_data.account a ON c.client_id = a.account_id
  JOIN test_financial_open_data.trans t ON a.account_id = t.account_id
GROUP BY 
  birth_year
ORDER BY 
  birth_year
``` 

This query will join the `client`, `account`, and `trans` tables to get the total spending of each client by their birth year. We first extract the birth year from the `birth_date` column using the `EXTRACT` function. Then we join the tables and group the results by birth year. Finally, we sort the results by birth year in ascending order.


Unnamed: 0,birth_year,total_spending
0,1911,600994
1,1912,233099
2,1913,2605796
3,1914,2020703
4,1915,2331146
...,...,...
71,1982,35151313
72,1983,5052739
73,1985,715306
74,1986,1632106


In [None]:
df = gpt.ask_question_and_validate("show the bank receiving the most money for each customer. Specify where you get the column client_id in every subquery if necessary. Table account does not have the column client_id", debug=True)
df

Question:
show the bank receiving the most money for each customer. Specify where you get the column client_id in every subquery if necessary. Table account does not have the column client_id
Answer:
Here is the query that shows the bank receiving the most money for each customer. I have specified where I get the column client_id in every subquery if necessary. Since the table account does not have the column client_id, I have used the table trans to get the client_id.

```
SELECT 
    t.client_id, 
    o.bank_to, 
    SUM(o.amount) AS total_amount
FROM 
    test_financial_open_data.order o
    JOIN (
        SELECT DISTINCT 
            t1.account_id, 
            t2.client_id
        FROM 
            test_financial_open_data.trans t1
            JOIN test_financial_open_data.account t2 ON t1.account_id = t2.account_id
    ) t ON o.account_id = t.account_id
GROUP BY 
    t.client_id, 
    o.bank_to
HAVING 
    SUM(o.amount) = (
        SELECT 
            MAX(total_amount) 
        F

# Using ChatGPT to generate queries for Aeris data warehouse

In [None]:
client = bigquery.Client(project="aeriscom-acpcmn-pre-202006")

In [None]:
table_names = [
    ## Test tables containing data from open data source
    # "test_financial_open_data.trans",
    # "test_financial_open_data.account",
    # "test_financial_open_data.client",
    # "test_financial_open_data.order",

    ## Aeris tables
    "cmp_dw_prov.aersys_accounts",
    "cmp_dw_prov.device",
    "cmp_dw_prov.product",
    "ais_acp_int_ds.packet_data",
    "device_location.all_cell_locations",
]
tables = init_tables(table_names)
tables

[{'name': 'cmp_dw_prov.aersys_accounts',
  'schema': 'ACCOUNT_ID INTEGER, ACCOUNT_TYPE INTEGER, NAME STRING, DIVISION STRING, WWW_ADDR STRING, SERVICE_NAME STRING, TMIN_MAXREQUEST INTEGER, TMIN_MAXTOTAL STRING, DATAINFO_TYPE STRING, DATAREPORT_TYPE STRING, MIN_MAXREQUEST INTEGER, PREDEVELOPER_ID INTEGER, NOT_USED INTEGER, MARKETING_TEAM STRING, REMARKS STRING, REGION_CODE INTEGER, PEAKTIME_BEGIN INTEGER, PEAKTIME_END INTEGER, REP_TIME INTEGER'},
 {'name': 'cmp_dw_prov.device',
  'schema': 'REP_TIMESTAMP INTEGER, DEVICE_ID INTEGER, ACCOUNT_ID INTEGER, DEVICE_STATUS INTEGER, IMSI STRING, ICCID STRING, PRIMARY_MIN STRING, ESN STRING, IMEI STRING, MEID STRING, MSISDN STRING, RADIO_MAKER STRING, FW_VERSION STRING, HW_VERSION STRING, DEVICE_TYPE INTEGER, APPLICATION_TYPE STRING, ACTIVE_DATE INTEGER, DEACTIVATE_DATE INTEGER, DEVICE_DESCRIPTION STRING, TERM_DATE INTEGER, CARRIER_PROVIDER STRING, LAST_MODIFIED_BY STRING, LAST_MODIFIED_DATE INTEGER, APP_FW_VERSION STRING, APP_HW_VERSION STRING, 

In [None]:
gpt = GPTQueryGenerator(tables=tables, model=DEFAULT_GPT_MODEL, max_tokens=MAX_ALLOWED_TOKENS)

In [None]:
df = gpt.ask_question_and_validate(
    "show top 5 accounts with the most devices in United States. Note that table all_cell_locations does not have column mccMnc and table all_cell_locations does not have the column country"
)
df

Question:
show top 5 accounts with the most devices in United States. Note that table all_cell_locations does not have column mccMnc and table all_cell_locations does not have the column country
Answer:
To solve this question, we need to join the following tables: `cmp_dw_prov.device`, `cmp_dw_prov.aersys_accounts` and `device_location.all_cell_locations`. 

We can join the tables `cmp_dw_prov.device` and `cmp_dw_prov.aersys_accounts` using the `ACCOUNT_ID` column. Then, we can join the resulting table with `device_location.all_cell_locations` using the `msisdn` column from `cmp_dw_prov.device` and the `mcc` and `mnc` columns from `device_location.all_cell_locations`.

After that, we can filter the resulting table to only include rows where the `country` column from `device_location.all_cell_locations` is equal to "United States". Finally, we can group the resulting table by the `ACCOUNT_ID` column and count the number of devices for each account. We can then sort the resulting table b

In [None]:
gpt._num_tokens_from_message()

717

In [None]:
df = gpt.ask_question_and_validate("Show top 10 accounts with the most devices")
df

Question:
Show top 10 accounts with the most devices
Answer:
Sure! Here's the SQL query to show the top 10 accounts with the most devices:

```sql
SELECT ACCOUNT_ID, COUNT(*) AS DEVICE_COUNT
FROM cmp_dw_prov.device
GROUP BY ACCOUNT_ID
ORDER BY DEVICE_COUNT DESC
LIMIT 10
```

This query selects the `ACCOUNT_ID` column and counts the number of devices for each account using `COUNT(*)`. It then groups the results by `ACCOUNT_ID` and sorts the results in descending order by `DEVICE_COUNT`. Finally, it limits the results to the top 10 accounts with the most devices.


Unnamed: 0,ACCOUNT_ID,DEVICE_COUNT
0,40028,4267183
1,10931,3225152
2,10115,1715295
3,41071,1705211
4,10578,1629963
5,21561,1511710
6,16544,1500704
7,13647,1361136
8,11034,1077637
9,24012,781503


In [None]:
gpt._num_tokens_from_message()

733

In [None]:
df = gpt.ask_question_and_validate("Show top 10 accounts with the least devices")
df

Question:
Show top 10 accounts with the least devices
Answer:
To show the top 10 accounts with the least devices, we can use the `device` table and group by `account_id`, then order by the count of devices per account in ascending order and limit the results to 10. Here's the SQL query:

```sql
SELECT account_id, COUNT(*) AS device_count
FROM cmp_dw_prov.device
GROUP BY account_id
ORDER BY device_count ASC
LIMIT 10
```


Unnamed: 0,account_id,device_count
0,22078,1
1,26540,1
2,10986,1
3,18038,1
4,30205,1
5,38367,1
6,12561,1
7,20910,1
8,28902,1
9,12986,1


In [None]:
gpt._num_tokens_from_message()

749

In [None]:
df = gpt.ask_question_and_validate("Show top 10 accounts with most data transferred (in GB) yesterday")
df

Question:
Show top 10 accounts with most data transferred (in GB) yesterday
Answer:
To show the top 10 accounts with the most data transferred (in GB) yesterday, we need to join the `ais_acp_int_ds.packet_data` table with the `cmp_dw_prov.device` table to get the account ID and then aggregate the data transferred by account. We can use the following query:

```sql
SELECT
  a.ACCOUNT_ID,
  SUM(p.BIN + p.BOUT) / 1024 / 1024 / 1024 AS DATA_TRANSFERRED_GB
FROM
  `ais_acp_int_ds.packet_data` p
JOIN
  `cmp_dw_prov.device` d
ON
  p.deviceId = d.DEVICE_ID
JOIN
  `cmp_dw_prov.aersys_accounts` a
ON
  d.ACCOUNT_ID = a.ACCOUNT_ID
WHERE
  DATE(p.eventTime) = DATE_SUB(CURRENT_DATE(), INTERVAL 1 DAY)
GROUP BY
  a.ACCOUNT_ID
ORDER BY
  DATA_TRANSFERRED_GB DESC
LIMIT
  10
```

This query joins the `ais_acp_int_ds.packet_data` table with the `cmp_dw_prov.device` table on the `deviceId` column and then joins the `cmp_dw_prov.device` table with the `cmp_dw_prov.aersys_accounts` table on the `ACCOUNT_ID` c

Unnamed: 0,ACCOUNT_ID,DATA_TRANSFERRED_GB
0,10115,21873.192659
1,34105,10970.388692
2,13670,10896.835574
3,23567,6013.739582
4,19008,3826.402539
5,18930,2319.145089
6,11008,2286.52081
7,22958,2066.98323
8,18914,2047.222646
9,90308,1773.686794


In [None]:
gpt._num_tokens_from_message()

770

In [None]:
df = gpt.ask_question_and_validate("Show top 5 countries with the most devices")
df

Question:
Show top 5 countries with the most devices
Answer:
To show the top 5 countries with the most devices, we can use the `all_cell_locations` and `ais_acp_int_ds.packet_data` tables. We will join these tables on the `location_info` and `imsi` columns, respectively, to get the country for each device. Then we will group by country and count the number of devices in each country. Finally, we will order the results by the count of devices in descending order and limit the output to the top 5 countries.

Here's the SQL query:

```
SELECT
  SUBSTR(location_info, 1, 2) AS country_code,
  COUNT(DISTINCT deviceId) AS device_count
FROM
  device_location.all_cell_locations AS loc
JOIN
  ais_acp_int_ds.packet_data AS pd
ON
  loc.cid = CAST(pd.locationInfo AS INT64)
WHERE
  pd.imsi IS NOT NULL
GROUP BY
  country_code
ORDER BY
  device_count DESC
LIMIT
  5
```

This query will return the top 5 countries with the most devices, along with the number of devices in each country. Note that we are 

In [None]:
gpt._num_tokens_from_message()

742

# Remarks

## How we phrase the question is very important

Small changes in wording can substantially change the resulting query.


### First example

This example was run without the following instruction:

```JSON
{
    "role": "assistant",
    "content": "If the user do not mention any time period, the query should take into account all the dates available in the table"
}
```


Take this question for example:

    "show percentage of death over population by country, order from highest to lowest"

The model returns the following SQL:

```SQL
SELECT location, (total_deaths/population)*100 AS death_percentage
FROM test.owid_covid_data
WHERE population IS NOT NULL
GROUP BY location, population
ORDER BY death_percentage DESC
```

This query throws an error because `total_deaths` is neither grouped nor aggregated:

`BadRequest: 400 SELECT list expression references column total_deaths which is neither grouped nor aggregated`

But if we rephrase the question as follows:

    "show average percentage of death over population by country, order from highest to lowest"

This is the SQL we received:

```SQL
 SELECT location, AVG(total_deaths_per_million/population)*100 AS avg_death_percentage
FROM test.owid_covid_data
GROUP BY location
ORDER BY avg_death_percentage DESC
```

This is more in line with what we want. The only difference is the word __"average"__ that we added to the question.

### Second example

This example was run with the full set of instructions defined above in `GPTQueryGenerator.__init__`.

Question: __"show top 3 contintents with the highest death count per population"__

SQL:
```SQL
 SELECT continent, SUM(total_deaths)/SUM(population) AS death_per_population
FROM test.owid_covid_data
GROUP BY continent
ORDER BY death_per_population DESC
LIMIT 3;
```

Question: __"show top 3 contintents with the highest average death count per population"__

SQL:
```SQL
 SELECT continent, AVG(total_deaths/population) AS avg_death_per_pop
FROM test.owid_covid_data
GROUP BY continent
ORDER BY avg_death_per_pop DESC
LIMIT 3;
```

## The temperature parameter

According to the [API documentation](https://platform.openai.com/docs/api-reference/chat/create):

```
temperature number  Optional    Defaults to 1

What sampling temperature to use, between 0 and 2. Higher values like 0.8 will make the output more random, while lower values like 0.2 will make it more focused and deterministic.
```

Reducing this parameter makes the answers less random but does not guarantee the same answer every time, so it is best to keep it low.
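
As a conceptual illustration only: the usual textbook description of sampling temperature is that the model's scores (logits) are divided by the temperature before being turned into probabilities. The sketch below follows that description; the logits are made up, `softmax_with_temperature` is an illustrative helper, and nothing here is taken from the API's internals.

```python
import math
from typing import List


def softmax_with_temperature(logits: List[float], temperature: float) -> List[float]:
    """Convert logits to probabilities after dividing them by `temperature` (> 0)."""
    scaled = [x / temperature for x in logits]
    peak = max(scaled)  # subtract the max for numerical stability
    exps = [math.exp(x - peak) for x in scaled]
    total = sum(exps)
    return [e / total for e in exps]


logits = [2.0, 1.0, 0.5]  # made-up scores for three candidate tokens
for t in (0.2, 1.0, 2.0):
    probs = softmax_with_temperature(logits, t)
    print(t, [round(p, 3) for p in probs])
```

At a low temperature most of the probability mass sits on the top candidate, yet the other candidates keep a non-zero probability, which is consistent with answers becoming more stable without becoming fully repeatable.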

## Fine-tuning

[More to explore]

### Detailed Instructions

One way to steer the model without retraining it is to pass some "rules" in the `messages` parameter. For example:

```python
messages = [
    ...
    {"role": "assistant", "content": "If the user ask for something by X, then the answer SQL should have GROUP BY X clause"},
    ...
]
```

With this message automatically embedded into every request we send to the model, cases where the model "forgets" to aggregate a column dropped drastically.

So to apply the "GPT approach", we can summarize some "common sense" rules and feed them to the model as guidance.
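
A minimal sketch of how such rules could be kept in one list and prepended to every request is shown below. `build_messages` is a hypothetical helper, not part of the class above, and it sends the rules with the `system` role (the snippet above uses `assistant`), which is a choice made only for this example.

```python
from typing import Dict, List


def build_messages(tables: List[Dict], rules: List[str], question: str) -> List[Dict]:
    """Assemble the chat payload: table schemas, then rules, then the user's question."""
    messages = [
        {"role": "system", "content": f"TABLE `{t['name']}` - SCHEMA `{t['schema']}`"}
        for t in tables
    ]
    messages += [{"role": "system", "content": rule} for rule in rules]
    messages.append({"role": "user", "content": question})
    return messages


rules = [
    "Answer user's questions with SQL queries only",
    "If the user ask for something by X, then the answer SQL should have GROUP BY X clause",
]
tables = [
    {
        "name": "test_financial_open_data.client",
        "schema": "client_id INTEGER, gender STRING, birth_date DATE, district_id INTEGER",
    }
]
messages = build_messages(tables, rules, "number of male and female customers")
for message in messages:
    print(message["role"], "|", message["content"])
```

Keeping the rules in a plain list makes it easy to add or remove one and compare the generated queries.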

### Further fine tuning (TODO)

Refer to:

- https://platform.openai.com/docs/guides/fine-tuning
- https://platform.openai.com/docs/guides/fine-tuning/preparing-your-dataset
- https://docs.wandb.ai/guides/integrations/openai?utm_source=wandb_docs&utm_medium=code&utm_campaign=OpenAI+API
