In [1]:
# use svg graphics, display inline
%matplotlib inline
%config InlineBackend.figure_format = 'svg'

import glob
import re
import copy
import sys
import time
import ast
from typing import Any, Dict, Tuple

# basic scientific computing imports
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import statsmodels.formula.api as smf

# ai
from tqdm import tqdm
from Bronco import bronco

# hex colors for plotting
SOFT_PURPLE = '#8565C4'
SOFT_RED = '#C23F38'
SOFT_GREEN = '#56B000'
NEUTRAL_GREY = '#A9A9A9'

# display config
pd.set_option('display.float_format', lambda x: '%.3f' % x)
plt.rcParams['figure.figsize'] = 6, 4
plt.style.use('ggplot')
np.set_printoptions(suppress=True)
np.random.seed(42)

print(sys.version)

3.11.5 (main, Sep 11 2023, 08:31:25) [Clang 14.0.6 ]


Insight structure:

```python
{
    'question': question, 
    'query': query, 
    'query_result': result, 
    'answer': answer
}
```

# Load and clean the dataset

In [2]:
df = (
    pd.read_csv('data/spaceship_titanic.csv')
    .dropna(how='any')
    .rename(columns=str.lower)
    .rename(columns={
        'passengerid': 'passenger_id',
        'homeplanet': 'home_planet',
        'cryosleep': 'did_use_cryosleep',
        'vip': 'is_vip_passenger',
        'transported': 'did_survive_trip'
    })
    .astype({
        'did_survive_trip': float,
        'did_use_cryosleep': float,
        'is_vip_passenger': float
    })
    .assign(
        total_spent_dollars=lambda d: d.roomservice + d.foodcourt + d.shoppingmall + d.spa + d.vrdeck
    )
    .drop(columns=['roomservice', 'foodcourt', 'shoppingmall', 'spa', 'vrdeck'])

)
df.sample(3)

Unnamed: 0,passenger_id,home_planet,did_use_cryosleep,cabin,destination,age,is_vip_passenger,name,did_survive_trip,total_spent_dollars
8441,9014_01,Europa,1.0,B/293/P,TRAPPIST-1e,29.0,0.0,Nunkib Dishearly,1.0,0.0
8058,8615_02,Earth,1.0,G/1389/S,55 Cancri e,13.0,0.0,Jarena Buckentry,1.0,0.0
320,0358_01,Earth,1.0,G/50/S,TRAPPIST-1e,50.0,0.0,Ronna Connon,1.0,0.0


-----

In [3]:
question_to_query_prompt = '''
# Context
Your job is to write a sql query that answers the following question:
{question}

Below is a list of columns and sample values. Your query should only use the data contained in the table. The table name is `{table_name}`.

# Columns and sample values
{table_sample}

If the question is not a question or is answerable with the given columns, respond to the best of your ability.
Do not use columns that aren't in the table.
Ensure that the query runs and returns the correct output.

# Your query:
'''

def extract_first_sql_query(input_string):
    pattern = r'(?:SELECT|WITH)(?:.|\n)*?(?=;|$)'
    query = re.search(pattern, input_string, flags=re.IGNORECASE)
    
    output_query = query.group(0) if query else None
    
    if output_query is None:
        return None
    return output_query + ';' if not output_query.endswith(';') else output_query

def question_to_query(question: str, db_connection, table_name: str) -> str:
    '''
    Converts a natural language question into an SQL query based on a given table in a database.

    This function takes a natural language question and converts it into an SQL query. It uses the `bronco` module
    to interact with a machine learning model (GPT-4), which generates the SQL query based on the input question,
    the name of the table, and a sample of data from the table.

    Parameters:
    question (str): A natural language question that is to be converted into an SQL query.
    db_connection: A database connection object used to interact with the database.
    table_name (str): The name of the table in the database to be queried.

    Returns:
    str: An SQL query string generated from the input question.
    '''

    table_sample = bronco.data.get_table_sample(db_connection, table_name)
    query_generator = bronco.LLMFunction(
        prompt_template=question_to_query_prompt,
        model_name=bronco.GPT_4,
        parser=extract_first_sql_query
    )

    return query_generator.generate({
        'question': question,
        'table_name': table_name,
        'table_sample': table_sample
    })

In [4]:
results_to_answer_prompt = '''
# Task
Based on the results of a SQL query, provide a brief summary of the key findings and explicitly answer the following question:
{question}

The query results from the table `{table_name}` are as follows:

# Query Results Table
{query_results_table}

In 2-5 sentences, summarize the main insights from the query results and give a clear and direct answer to the original question.

# Summary and Answer:
'''

def interpret_query_results(question, query_result):

    result_interpreter = bronco.LLMFunction(
        prompt_template=results_to_answer_prompt,
        model_name=bronco.GPT_4
    )

    return result_interpreter.generate({
        'question': question,
        'table_name': table_name,
        'query_results_table': query_result
    })

In [5]:
def question_to_insight(question: str, df: pd.DataFrame, table_name: str) -> Dict[str, Any]:
    '''
    Generates insights from a given question by querying a DataFrame.

    This function takes a question as input, uploads a DataFrame to a sqlite3 in-memory
    database under a specified table name, then generates a SQL query from the question.
    It executes this query against the database, interprets the results, and returns an
    insight in the form of a dictionary.

    Parameters:
    question (str): The question to be answered.
    df (pd.DataFrame): The DataFrame to be queried.
    table_name (str): The name of the table in the database. This should be meaningful
                      as it helps the model understand the context of the data.

    Returns:
    Dict[str, Any]: A dictionary containing the original question, generated SQL query,
                    query results, and the interpreted answer.
    '''
    
    db_connection = bronco.data.upload_df_to_sql(df, table_name)

    query = question_to_query(
        question=question,
        db_connection=db_connection,
        table_name=table_name
    )

    current_insight = {
        'question': question,
        'query': query
    }
    current_insight = maybe_fix_insight(current_insight)
    
    current_insight['query_result'] = bronco.data.query_database(db_connection, query)
    db_connection.close()
    
    current_insight['answer'] = interpret_query_results(question, current_insight['query_result'])
    insight = maybe_fix_insight(current_insight)
    
    return insight

def pretty_print_insight(insight):
    for key in insight:
        print(f'{key}:\n{insight[key]}\n')

In [11]:
fixer_prompt = '''
# Context
Your job is to find the error and fix it in a python dictionary that represents a data insight. It contains up to 4 elements: (question, query, query_result, answer). You will be given a brief explanation of what went wrong. Note that the `'query_result'` is never wrong, only the `'query'` or the `'answer'`

# Task
Return the fixed version of the dictionary. Return it in python dictionary form. Do not add extra key-value pairs.

# Example output
{{
    'question': 'What are the top 3 countries in terms of user count?',
    'query': 'SELECT country, COUNT(DISTINCT userid) AS total_users from users GROUP BY 1 ORDER BY 2 DESC LIMIT 3',
    'query_result': "     country  total_users\n0        US       120000\n1     Spain        95000\n2    France        87000",
    'answer': 'The top 3 countries are US, Spain, France'
}}

# Incorrect dictionary
{insight_dict}

# Mistake explanation
{explanation}

# Your fixed insight dict
'''

def extract_dict(input_string):
    pattern = r'\{[^{}]*\}'

    match = re.search(pattern, input_string)

    if not match:
        return {}

    # Safely evaluate the matched string to a dictionary
    try:
        return ast.literal_eval(match.group(0))
    except ValueError:
        return {}


insight_fixer = bronco.LLMFunction(
    prompt_template=fixer_prompt,
    model_name=bronco.GPT_3_5_TURBO,
    parser=extract_dict
)

fix = insight_fixer.generate({
    'insight_dict': bad_insight,
    'explanation': validation['reasoning']
})

print(fix)

NameError: name 'bad_insight' is not defined

In [16]:
validator_prompt = '''
# Context
You will be given a python dictionary that represents a data insight. It contains up to 4 elements: (question, query, query_result, answer). Your job is to validate these elements and ensure that there are no major errors. These errors can include SQL errors, incorrect interpretations of the data, SQL code that doesn't achieve the correct objective, etc. Note that the `'query_result'` is never wrong, only the `'query'` or the `'answer'`

# Task
Validate the following insight dictionary and return a decision of [VALID] or [NOT_VALID], followed by a brief explanation of why.

# Example outputs
[NOT_VALID] The SQL query performs integer division instead of float division for a proportion
[VALID] There are no major issues
[NOT_VALID] The answer incorrectly interprets the data and draws the wrong conclusion

# Insight dictionary
{insight_dict}

# Your validation output
'''

def parse_label_with_regex(input_string):
    # Using regex to extract both the label and the reasoning
    match = re.search(r'\[(.*?)\]\s*(.*)', input_string)

    if not match:
        return {}
    return {
        'label': match.group(1), 
        'reasoning': match.group(2)
    }

def maybe_fix_insight(insight):
    
    insight_validator = bronco.LLMFunction(
        prompt_template=validator_prompt,
        model_name=bronco.GPT_4,
        parser=parse_label_with_regex
    )
    validation_result = insight_validator.generate({'insight_dict': insight})
    
    if validation_result['label'] == 'VALID':
        return insight

    print('Fixing insight...')

    insight_fixer = bronco.LLMFunction(
        prompt_template=fixer_prompt,
        model_name=bronco.GPT_4,
        parser=extract_dict
    )

    fixed_insight = insight_fixer.generate({
        'insight_dict': insight,
        'explanation': validation_result['reasoning']
    })

    original_keys = list(insight.keys())

    # removes hallucinated keys
    fixed_insight = {key: fixed_insight[key] for key in original_keys if key in fixed_insight}

    return fixed_insight

In [17]:
question = 'What is the probability of surviving the trip by home planet?'
table_name = 'spaceship_titanic'

insight = question_to_insight(question, df, table_name)

pretty_print_insight(insight)

Fixing insight...
question:
What is the probability of surviving the trip by home planet?

query:
SELECT 
    home_planet, 
    SUM(did_survive_trip) AS survived, 
    COUNT(*) AS total, 
    (CAST(SUM(did_survive_trip) AS FLOAT) / CAST(COUNT(*) AS FLOAT)) AS survival_probability 
FROM 
    spaceship_titanic 
GROUP BY 
    home_planet;

query_result:
  home_planet  survived  total  survival_probability
0       Earth  1518.000   3566                 0.426
1      Europa  1104.000   1673                 0.660
2        Mars   705.000   1367                 0.516

answer:
The probability of surviving the trip varies depending on the home planet of the passengers. Passengers from Europa have the highest survival probability at 66%, followed by Mars at 51.6%, and Earth at 42.6%. Therefore, the survival probability is highest for passengers from Europa and lowest for those from Earth.



# TODO:
Write 2 functions:
- An insight validator that ensures that there are no major flaws in the insight
- An insight fixer. This function should pass the entire insight to an LLM and identify which step caused the error. Example: The SQL query divides two integers and returns 0 instead of a percentage. Have it return a dictionary with the fixed step. From there, run the subsequent steps with the messed up part fixed.

In [32]:
# bad_insight = copy.copy(insight)
# good_insight = copy.copy(insight)

In [72]:
validation = insight_validator.generate({'insight_dict': bad_insight})
validation

{'label': 'NOT_VALID',
 'reasoning': 'The SQL query performs integer division instead of float division for a probability calculation.'}

{'question': 'What is the probability of surviving the trip by home planet?', 'query': 'SELECT \n    COUNT(*) AS total_passengers,\n    SUM(did_survive_trip) AS total_survived,\n    CAST(SUM(did_survive_trip) AS FLOAT) / COUNT(*) AS probability_survival\nFROM \n    spaceship_titanic;', 'query_result': 'total_passengers  total_survived  probability_survival\n0               1000             700               0.7', 'answer': 'The probability of surviving the trip by home planet is 0.7'}
