# Data privacy

According to [OpenAI's official statement](https://help.openai.com/en/articles/5722486-how-your-data-is-used-to-improve-model-performance):

>OpenAI does not use data submitted by customers via our API to train OpenAI models or improve OpenAI’s service offering. In order to support the continuous improvement of our models, you can fill out this form to opt-in to share your data with us. 

According to [OpenAI's API data usage policies](https://openai.com/policies/api-data-usage-policies):

>1. OpenAI will not use data submitted by customers via our API to train or improve our models, unless you explicitly decide to share your data with us for this purpose. You can opt-in to share data.
>2. Any data sent through the API will be retained for abuse and misuse monitoring purposes for a maximum of 30 days, after which it will be deleted (unless otherwise required by law).

# Library import and client setup

In [19]:
import openai
import traceback
import pandas as pd

from google.cloud import bigquery
from google.cloud.exceptions import BadRequest
from typing import Dict, List, Union

In [20]:
import os

PROJECT_ID = "setalab"
API_KEY = os.environ["OPENAI_API_KEY"]  # read the key from the environment; never hard-code secrets

In [21]:
client = bigquery.Client(project=PROJECT_ID)

In [22]:
# Set up OpenAI API credentials
openai.api_key = API_KEY

# Name of the chat model to use
model_name = "gpt-3.5-turbo-0301"

# model_name = "gpt-4"
# GPT-4 is not available even to subscribed accounts; you need to join the waitlist: https://openai.com/waitlist/gpt-4-api

# Table info feed-in

In [23]:
def export_table_schema(table_name: str) -> str:
    """Get table schema from given table name

    Args:
        table_name (str): Name of table to extract schema.
        Note that the dataset should already have been initialized with the BigQuery client.
        This avoids feeding in the table's project name, which could contain sensitive info

    Returns:
        str: Schema of the table in format `column_1 type, column_2 type, ...`
    """
    table = client.get_table(table_name)
    schema = table.schema
    # Extract table schemas, except columns whose name includes "aeris"
    schema_fields = [f"{field.name} {field.field_type}" for field in schema if "aeris" not in field.name.lower()]
    schema_text = ", ".join(schema_fields)
    return schema_text
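
The filtering and formatting step above can be tried without a BigQuery connection. The sketch below is a minimal stand-in: `Field` is a hypothetical substitute for `bigquery.SchemaField` (only `name` and `field_type` are mimicked), and `format_schema` and its `excluded` parameter are illustrative names that do not appear in the notebook.

```python
from dataclasses import dataclass
from typing import Iterable


@dataclass
class Field:
    """Hypothetical stand-in for bigquery.SchemaField (name and type only)."""
    name: str
    field_type: str


def format_schema(fields: Iterable[Field], excluded: str = "aeris") -> str:
    """Join fields as `name TYPE, ...`, skipping names that contain `excluded`."""
    kept = [f"{f.name} {f.field_type}" for f in fields if excluded not in f.name.lower()]
    return ", ".join(kept)


fields = [
    Field("client_id", "INTEGER"),
    Field("Aeris_token", "STRING"),  # dropped: the lower-cased name contains "aeris"
    Field("gender", "STRING"),
]
print(format_schema(fields))
```

Only the column names and types reach the model; no row data is included in the schema text.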

In [24]:
# Initialize table(s) to feed into the model

def init_tables(tables: Union[str, List[str]]) -> List[Dict]:
    """Get table names and schema to feed to the GPT model

    Args:
        tables (Union[str, List[str]]): Table name in BigQuery. Must be in format `dataset.table_name`

    Returns:
        List[Dict]: List of dictionaries, each containing a table name and its schema
    """
    if not isinstance(tables, list):
        tables = [tables]
    
    tables_with_schema = [
        {
            "name": table,
            "schema": export_table_schema(table)
        } for table in tables
    ]

    return tables_with_schema


# GPT interaction class

In [25]:
class GPTQueryGenerator:
    def __init__(self, tables: List) -> None:
        self._messages = [
            {"role": "system", "content": f"Reference table name `{table['name']}`, schema `{table['schema']}`"} for table in tables
        ]
        # Initialize the list of messages with some instructions for better query
        self._messages += [
            {"role": "system", "content": "You are a SQL programmer who provides efficient queries."},
            {"role": "system", "content": "Use the reference tables with the corresponding schemas to compose SQL query in Google BigQuery dialect"},
            {"role": "system", "content": "Answer user's questions with SQL queries so that they can use the queries against the tables to answer their questions"},
            {"role": "system", "content": "Only join tables in reference if needed. Do not join unnecessarily."},
            {"role": "system", "content": "Only use the columns existing in the corresponding schema for each table"},
            # {"role": "assistant", "content": "If the user asks for something by X, then it's a good idea to have a GROUP BY X clause in the SQL"},
            # # {"role": "assistant", "content": "New column names in AS clause should not have any white spaces and do not put them inside quotes"},
            # {"role": "assistant", "content": "Remember to make sure all the columns are grouped or aggregated when you are using aggregation in SQL"},
            # {"role": "assistant", "content": "If the user does not mention any time period, the query should take into account all the dates available in the table"},
            # {"role": "assistant", "content": "When the user does not mention any specific requirements for the numeric values, default to average"},
            # {"role": "system", "content": "Only provide the SQL query. Do not provide anything else outside the query."},
        ]

    def generate_sql_query(self, question: str, temperature: float = 0.2) -> str:
        """Get the SQL query the model generated to answer the user's question. The context is specific to the tables the object was initialized with

        Args:
            question (str): Question for the query to answer
            temperature (float, optional): temperature parameter for the GPT model, range from `0-2`. Higher means more likely to generate more 'random' answer, lower means the answers will be more 'focused' to the topic given as input. Defaults to 0.2.

        Returns:
            str: Executable query, specific to the user's question and the schemas provided
        """
        print(f"Question:\n{question}")

        msg = self._messages + [{"role": "user", "content": question},]

        response = openai.ChatCompletion.create(
            model=model_name,
            messages=msg,
            max_tokens=1024,
            temperature=temperature,
            n=1,
            stop=None,
        )
        sql_query = response.choices[0].message['content']
        print(f"Answer:\n{sql_query}")
        if "```sql" in sql_query:
            sql_query = sql_query.split("```sql")[1].split("```")[0]
        elif "```" in sql_query:
            sql_query = sql_query.strip().split("```")[1].strip()
        # Adding User's previous question as context for conversation.
        self._messages += [{"role": "system", "name": "user_previous_question", "content": question}]
        return sql_query

    def execute_sql_query(self, sql_query: str) -> pd.DataFrame:
        """Run a SQL query on the predefined project.
        Note that the `client` variable for the BigQuery client must be initialized before running this method

        Args:
            sql_query (str): Query to run

        Returns:
            pd.DataFrame: Pandas dataframe containing the results of the query
        """
        result = client.query(sql_query).to_dataframe()
        return result
    
    def get_query_response(self, question: str = None) -> str:
        """Get the SQL response (with explanations and other filler words the GPT model is likely to give)

        Args:
            question (str, optional): What we want to ask the model. If not provided beforehand then it will be prompted as an input box

        Returns:
            str: Model's query response to the question
        """
        user_input = input("Input question:") if not question else question

        sql_query = self.generate_sql_query(user_input)
        print("Generated SQL Query:\n", sql_query)
        return sql_query

    def ask_question_and_validate(self, question: str = None, temperature: float=0.2, debug: bool = False) -> pd.DataFrame:
        """Ask GPT some questions on predefined tables. The query generated will be automatically executed

        Args:
            question (str, optional): What we want to ask the model. If not provided beforehand then it will be prompted as an input box.
            temperature (float, optional): temperature parameter for the GPT model, range from `0-2`. Higher means more likely to generate more 'random' answer, lower means the answers will be more 'focused' to the topic given as input. Defaults to 0.2.
            debug (bool, optional): Debug mode. Enable to get the full stack trace when query execution encounters errors. Defaults to False.

        Returns:
            pd.DataFrame: Pandas dataframe object as the result of running the query
        """
        user_input = input("Input question:") if not question else question

        sql_query = self.generate_sql_query(user_input, temperature)
        # print("Generated SQL Query:\n", sql_query)

        # Execute the generated SQL query
        try:
            result = self.execute_sql_query(sql_query)
        except BadRequest as err:
            # Notify the user about the bad query generated
            print(f"Running the generated query caused the error: {err.errors}")
            if debug:
                print(traceback.format_exc())
            print("Please try asking the question again, rephrasing it or providing additional guidance if necessary")
            # Let the model know that this query failed
            self._messages += [
                {"role": "system", "name": "bad_answer", "content": f"Question: '{user_input}'. Bad query generated: '{sql_query}'"},
            ]
            result = None
        else:
            # Give the model the context in which it generated an acceptable answer (the query did run)
            self._messages += [
                {"role": "system", "name": "acceptable_answer", "content": f"Question: '{user_input}'. Acceptable query generated: '{sql_query}'"},
            ]
        # print("Result:", result)
        return result
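
The fence-stripping step in `generate_sql_query` can also be written as a small standalone helper. The sketch below is one possible variant, not the notebook's code: `extract_sql` is a hypothetical name, and unlike the `split`-based logic it falls back to the whole answer when a fence is never closed.

```python
import re

FENCE = "`" * 3  # a Markdown code fence, built this way to keep the example readable

# Matches the first fenced block, with an optional "sql" language tag.
_FENCED_BLOCK = re.compile(FENCE + r"(?:sql)?\s*(.*?)" + FENCE, re.DOTALL | re.IGNORECASE)


def extract_sql(answer: str) -> str:
    """Return the first fenced block in `answer`, or the stripped answer if none."""
    match = _FENCED_BLOCK.search(answer)
    return match.group(1).strip() if match else answer.strip()


answer = "Here is the query:\n\n" + FENCE + "sql\nSELECT 1\n" + FENCE + "\n\nDone."
print(extract_sql(answer))
```

Keeping this logic in one function makes it easy to check against the different answer shapes the model returns (bare fence, `sql`-tagged fence, or no fence at all).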


# Interaction testing

In [26]:
table_names = [
    ## Test tables containing data from open data source
    "setalab.financial.trans",
    "setalab.financial.account",
    "setalab.financial.client",
    "setalab.financial.order",
]

In [27]:
tables = init_tables(table_names)
tables

[{'name': 'setalab.financial.trans',
  'schema': 'trans_id INTEGER, account_id INTEGER, date DATE, type STRING, operation STRING, amount INTEGER, balance INTEGER, k_symbol STRING, bank STRING, account INTEGER'},
 {'name': 'setalab.financial.account',
  'schema': 'account_id INTEGER, district_id INTEGER, frequency STRING, date DATE'},
 {'name': 'setalab.financial.client',
  'schema': 'client_id INTEGER, gender STRING, birth_date DATE, district_id INTEGER'},
 {'name': 'setalab.financial.order',
  'schema': 'order_id INTEGER, account_id INTEGER, bank_to STRING, account_to INTEGER, amount FLOAT, k_symbol STRING'}]

In [28]:
gpt = GPTQueryGenerator(tables=tables)

In [29]:
df = gpt.ask_question_and_validate("top 10 accounts with most transactions in all time period")
df

Question:
top 10 accounts with most transactions in all time period
Answer:
```
SELECT account_id, COUNT(*) AS num_transactions
FROM setalab.financial.trans
GROUP BY account_id
ORDER BY num_transactions DESC
LIMIT 10
```


Unnamed: 0,account_id,num_transactions
0,8261,675
1,3834,665
2,96,661
3,2932,655
4,9307,649
5,9265,643
6,5215,637
7,2762,634
8,1801,633
9,5952,628


In [30]:
df = gpt.ask_question_and_validate("number of male and female customers")
df

Question:
number of male and female customers
Answer:
To get the number of male and female customers, we can use the `setalab.financial.client` table and count the number of rows for each gender. Here's the SQL query:

```
SELECT gender, COUNT(*) as count
FROM setalab.financial.client
GROUP BY gender
```

This will return a table with two columns: `gender` and `count`. The `gender` column will have the values "M" or "F" for male and female respectively, and the `count` column will have the number of customers for each gender.


Unnamed: 0,gender,count
0,F,2645
1,M,2724


In [31]:
df = gpt.ask_question_and_validate("client's birth year distribution")
df

Question:
client's birth year distribution
Answer:
To get the birth year distribution of clients, we can use the `EXTRACT` function to extract the year from the `birth_date` column and then use the `GROUP BY` clause to group the results by birth year. Here's the query:

```
SELECT EXTRACT(YEAR FROM birth_date) AS birth_year, COUNT(*) AS num_clients
FROM setalab.financial.client
GROUP BY birth_year
ORDER BY birth_year
```

This will give us a table with two columns: `birth_year` and `num_clients`, where `birth_year` is the birth year of the clients and `num_clients` is the number of clients born in that year. The results will be ordered by birth year in ascending order.


Unnamed: 0,birth_year,num_clients
0,1911,2
1,1912,1
2,1913,3
3,1914,4
4,1915,4
...,...,...
72,1983,3
73,1984,1
74,1985,3
75,1986,1


In [32]:
df = gpt.ask_question_and_validate("show client's total spending by birth year.")
df

Question:
show client's total spending by birth year.
Answer:
To show client's total spending by birth year, we need to join the `setalab.financial.client` table with the `setalab.financial.account` and `setalab.financial.order` tables. Here is the SQL query to get the total spending by birth year:

```
SELECT EXTRACT(YEAR FROM c.birth_date) AS birth_year, SUM(o.amount) AS total_spending
FROM setalab.financial.client c
JOIN setalab.financial.account a ON c.client_id = a.account_id
JOIN setalab.financial.order o ON a.account_id = o.account_id
GROUP BY birth_year
ORDER BY birth_year ASC
```

This query will extract the birth year from the `birth_date` column of the `setalab.financial.client` table, join it with the `setalab.financial.account` and `setalab.financial.order` tables, and group the results by birth year. The `SUM` function is used to calculate the total spending for each birth year. The results will be ordered by birth year in ascending order.


Unnamed: 0,birth_year,total_spending
0,1911,2498.0
1,1912,1578.0
2,1913,12447.0
3,1914,11637.0
4,1915,6581.0
...,...,...
70,1981,105718.2
71,1982,132319.9
72,1983,8083.0
73,1985,2228.0


In [33]:
df = gpt.ask_question_and_validate("show the bank receiving the most money for each customer. Specify where you get the column client_id in every subquery if necessary. Table account does not have the column client_id", debug=True)
df

Question:
show the bank receiving the most money for each customer. Specify where you get the column client_id in every subquery if necessary. Table account does not have the column client_id
Answer:
To get the bank receiving the most money for each customer, we need to join the `setalab.financial.order` table with the `setalab.financial.account` table on the `account_id` column. Then, we can join the result with the `setalab.financial.client` table on the `district_id` column to get the `client_id`. Finally, we group the result by `client_id` and `bank_to` and select the `client_id`, `bank_to`, and the sum of `amount`. We can then use a subquery to get the bank receiving the most money for each customer.

Here's the SQL query:

```
SELECT
  client_id,
  bank_to,
  total_amount
FROM (
  SELECT
    c.client_id,
    o.bank_to,
    SUM(o.amount) AS total_amount,
    ROW_NUMBER() OVER (PARTITION BY c.client_id ORDER BY SUM(o.amount) DESC) AS rn
  FROM
    `setalab.financial.order` o
    JO
```

Unnamed: 0,client_id,bank_to,total_amount
0,1,AB,29465.0
1,2,QR,283767.0
2,3,QR,283767.0
3,4,QR,55405.5
4,5,QR,55405.5
...,...,...,...
5364,13955,QR,283767.0
5365,13956,QR,283767.0
5366,13968,OP,31740.8
5367,13971,OP,26257.0


# Remarks

## How we phrase the question is very important

The wording of the question strongly affects the query the model generates, as the examples below show.


### First example

This example was run without the following instruction:

```JSON
{
    "role": "assistant",
    "content": "If the user does not mention any time period, the query should take into account all the dates available in the table"
}
```


Take this question for example:

    "show percentage of death over population by country, order from highest to lowest"

The model will return the following SQL

```SQL
SELECT location, (total_deaths/population)*100 AS death_percentage
FROM test.owid_covid_data
WHERE population IS NOT NULL
GROUP BY location, population
ORDER BY death_percentage DESC
```

This query throws an error because `total_deaths` is neither grouped nor aggregated:

`BadRequest: 400 SELECT list expression references column total_deaths which is neither grouped nor aggregated`

But if we rephrase the question as follows:

    "show average percentage of death over population by country, order from highest to lowest"

This is the SQL we received:

```SQL
SELECT location, AVG(total_deaths_per_million/population)*100 AS avg_death_percentage
FROM test.owid_covid_data
GROUP BY location
ORDER BY avg_death_percentage DESC
```

This is more in line with what we want. The only difference is the word __"average"__ in the question.

### Second example

This example was run with the full set of instructions defined above in `GPTQueryGenerator.__init__`.

Question: __"show top 3 contintents with the highest death count per population"__

SQL:
```SQL
SELECT continent, SUM(total_deaths)/SUM(population) AS death_per_population
FROM test.owid_covid_data
GROUP BY continent
ORDER BY death_per_population DESC
LIMIT 3;
```

Question: __"show top 3 contintents with the highest average death count per population"__

SQL:
```SQL
SELECT continent, AVG(total_deaths/population) AS avg_death_per_pop
FROM test.owid_covid_data
GROUP BY continent
ORDER BY avg_death_per_pop DESC
LIMIT 3;
```
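
The two queries above differ in more than wording: `SUM(total_deaths)/SUM(population)` is a ratio of totals, while `AVG(total_deaths/population)` is a mean of per-row ratios, and the two generally disagree when rows have different populations. A minimal sketch with made-up numbers (the figures are hypothetical and not taken from `test.owid_covid_data`):

```python
# Hypothetical (deaths, population) rows for one continent; not real data.
rows = [(10, 1_000), (300, 100_000)]

# Equivalent of SUM(total_deaths) / SUM(population): large rows dominate.
ratio_of_sums = sum(d for d, _ in rows) / sum(p for _, p in rows)

# Equivalent of AVG(total_deaths / population): every row weighs the same.
mean_of_ratios = sum(d / p for d, p in rows) / len(rows)

print(f"ratio of sums:  {ratio_of_sums:.5f}")
print(f"mean of ratios: {mean_of_ratios:.5f}")
```

With these numbers the mean of ratios is roughly twice the ratio of sums, so a single added word in the question can change the ranking the query produces.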

## The temperature parameter

According to the [API documentation](https://platform.openai.com/docs/api-reference/chat/create):

```
temperature number  Optional    Defaults to 1

What sampling temperature to use, between 0 and 2. Higher values like 0.8 will make the output more random, while lower values like 0.2 will make it more focused and deterministic.
```

Lowering this parameter makes the answers less random, but it does not guarantee the same answer every time. For query generation it is best to keep it low.
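
To build intuition for why a low temperature gives more focused answers, the sketch below applies the standard temperature-scaled softmax to a few made-up token scores. This is the textbook formulation, assumed here for illustration only; neither the notebook nor the quoted documentation describes how the API implements sampling internally, and the `logits` values are hypothetical.

```python
import math
from typing import List


def softmax_with_temperature(logits: List[float], temperature: float) -> List[float]:
    """Turn raw scores into probabilities; temperature must be > 0."""
    scaled = [x / temperature for x in logits]
    peak = max(scaled)  # subtract the max for numerical stability
    exps = [math.exp(x - peak) for x in scaled]
    total = sum(exps)
    return [e / total for e in exps]


logits = [2.0, 1.0, 0.5]  # hypothetical scores for three candidate tokens
for t in (0.2, 1.0, 2.0):
    probs = softmax_with_temperature(logits, t)
    print(t, [round(p, 3) for p in probs])
```

At a low temperature the top candidate takes almost all of the probability mass, but not all of it, which matches the observation that answers become less random without becoming fully repeatable.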

## Fine-tuning

[More to explore]

### Detailed Instructions

A lightweight way to steer the model, without any actual retraining, is to give it some "rules" through the `messages` parameter. For example:

```python
messages = [
    ...
    {"role": "assistant", "content": "If the user asks for something by X, then the answer SQL should have a GROUP BY X clause"},
    ...
]
```

With this message automatically sent along with the question every time we query the model, the cases where the model "forgets" to aggregate a column drop drastically.

So, to apply the "GPT approach", we can summarize a set of "common sense" rules and feed them to the model as guidance.
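
One way to keep such rules manageable is to hold them in a plain list and assemble the `messages` payload from it. The sketch below is an assumption about how this could be organised, not the notebook's implementation: `build_messages` and `RULES` are hypothetical names, and sending the rules with the `system` role is a choice made for this sketch.

```python
from typing import Dict, List

# Hypothetical rule set, adapted from the instructions used in this notebook.
RULES = [
    "If the user asks for something by X, the SQL should have a GROUP BY X clause.",
    "Only use the columns existing in the corresponding schema for each table.",
]


def build_messages(tables: List[Dict[str, str]], rules: List[str], question: str) -> List[Dict[str, str]]:
    """Assemble chat messages: table context first, then rules, then the question."""
    messages = [
        {"role": "system", "content": f"Reference table name `{t['name']}`, schema `{t['schema']}`"}
        for t in tables
    ]
    messages += [{"role": "system", "content": rule} for rule in rules]
    messages.append({"role": "user", "content": question})
    return messages


tables = [{"name": "financial.client", "schema": "client_id INTEGER, gender STRING"}]
for message in build_messages(tables, RULES, "number of male and female customers"):
    print(message["role"], "|", message["content"])
```

Because the rules live in one list, adding or disabling a rule is a one-line change, and the same list can be reused across questions.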

### Further fine tuning (TODO)

Refer to:

- https://platform.openai.com/docs/guides/fine-tuning
- https://platform.openai.com/docs/guides/fine-tuning/preparing-your-dataset
- https://docs.wandb.ai/guides/integrations/openai?utm_source=wandb_docs&utm_medium=code&utm_campaign=OpenAI+API
