In [68]:
import openai
import csv
import os
import openai
import pandas as pd
import collections
import ast
import time
import collections
import json
import sqlite3
import tiktoken
import xxhash
import numpy as np
from openai import OpenAI
from langchain.chains import LLMChain
from langchain.prompts.few_shot import FewShotPromptTemplate
from langchain.prompts.prompt import PromptTemplate
from langchain.prompts.example_selector import SemanticSimilarityExampleSelector
from langchain_community.vectorstores import Chroma
from langchain.embeddings.openai import OpenAIEmbeddings
from langchain.embeddings import HuggingFaceEmbeddings
import pickle
import seaborn as sns
import matplotlib.pyplot as plt

In [69]:
# Load environment variables from .env file
from dotenv import load_dotenv
import os
import openai
from openai import OpenAI

# Load the OPENAI_* settings from .env; override=True lets values from the
# file replace variables that are already set in the process environment.
load_dotenv(override=True)

# Read each setting once and reuse it for every consumer below.
_api_key = os.getenv("OPENAI_API_KEY")
_base_url = os.getenv("OPENAI_BASE_URL")

# Configure both the module-level openai settings and the explicit client
# object that the generation helpers call.
openai.api_key = _api_key
openai.base_url = _base_url
client = OpenAI(api_key=_api_key, base_url=_base_url)

# Only report whether a key is present; never print the key itself.
print(f"API Key set: {bool(_api_key)}")
print(f"Base URL set: {_base_url}")


API Key set: True
Base URL set: https://api.agicto.cn/v1


In [70]:
def save(fname, d):
    """Pickle the object ``d`` into the file at path ``fname`` (overwriting it)."""
    with open(fname, mode='wb') as handle:
        pickle.dump(d, handle)
def clean_query(sql_query):
    """Turn raw model output into a bare SQL statement.

    Markdown code fences, semicolons and triple double-quotes are removed.
    If what remains does not *start* (after leading whitespace) with the
    keyword SELECT or WITH, ``'SELECT '`` is prepended, since the generation
    prompts end in ``SELECT `` and the model may return only the continuation.

    The keyword test is anchored at the start of the text. The earlier
    ``'SELECT' in text[:10]`` check misfired on output with more than a few
    leading whitespace characters (doubling the keyword), on continuations
    that begin with a sub-query such as ``(SELECT ...) - (SELECT ...)``
    (no keyword added) and on complete ``WITH ...`` queries (keyword added).

    Args:
        sql_query: Text returned by the model.

    Returns:
        The cleaned SQL string. Surrounding whitespace is left untouched.
    """
    import re  # function-scope import keeps this helper self-contained

    # Same removal order as before: the longer fence marker goes first.
    for marker in ("```sql", "```", ';', '"""'):
        sql_query = sql_query.replace(marker, '')
    if not re.match(r'\s*(?:SELECT|WITH)\b', sql_query, re.IGNORECASE):
        sql_query = 'SELECT ' + sql_query
    return sql_query
def num_tokens_from_string(string: str, encoding_name: str) -> int:
    """Return how many tokens ``string`` encodes to.

    ``encoding_name`` is passed to ``tiktoken.encoding_for_model``, so despite
    the parameter's name the value is looked up as a model name.
    """
    tokenizer = tiktoken.encoding_for_model(encoding_name)
    return len(tokenizer.encode(string))

In [71]:
def generate_db_schema(database):
    """Return the CREATE TABLE statements of a database as one prompt string.

    The statements are read from ``sqlite_master`` of
    ``./databases/<database>/<database>.sqlite`` and joined with blank lines,
    in the order SQLite reports them.

    Args:
        database: Name of the database; used for both the directory and the
            file name.

    Returns:
        The CREATE TABLE statements separated by ``"\\n\\n"``. SQLite's internal
        ``sqlite_sequence`` bookkeeping table is left out.
    """
    from contextlib import closing  # function-scope import keeps the block self-contained

    db_path = f'./databases/{database}/{database}.sqlite'
    # closing() guarantees the connection is released even if the query fails.
    with closing(sqlite3.connect(db_path, uri=True)) as conn:
        # One query returns name and DDL together, so no table name ever has
        # to be interpolated into SQL text.
        rows = conn.execute(
            "SELECT name, sql FROM sqlite_master WHERE type='table'"
        ).fetchall()

    create_statements = [
        create_sql
        for name, create_sql in rows
        # Compare the name itself: the previous check compared the whole row
        # tuple with a string and therefore never skipped anything.
        if name != 'sqlite_sequence' and create_sql is not None
    ]
    return "\n\n".join(create_statements)

In [72]:
import time as time
from multiprocessing import Process, Queue
import query_module

def evalfunc(sql_source, sql_target, database, source='kaggle'):
    """Check whether ``sql_source`` returns the same rows as ``sql_target``.

    ``sql_source`` runs in a child process via ``query_module.execute_query``
    so that it can be abandoned when it takes too long; ``sql_target`` runs
    in-process. The child presumably puts either the fetched rows or the
    raised exception on the queue (an Exception result is re-raised here).

    Rows are compared independently of column order. Row order only matters
    when the literal text ``'ORDER BY'`` occurs in ``sql_target``; otherwise
    the two results are compared as sets of rows.
    NOTE(review): that keyword test is case-sensitive and the set comparison
    ignores duplicate rows - both kept as they were.

    Args:
        sql_source: Candidate SQL query.
        sql_target: Reference (gold) SQL query.
        database: Database name; resolved to
            ``./databases/<database>/<database>.sqlite``.
        source: Not used by this function; kept for interface compatibility.

    Returns:
        A ``(matched, errors)`` tuple on every path. ``errors`` is empty
        unless the database file is missing, the query timed out, or an
        exception occurred, in which case it holds that single exception.
    """
    from queue import Empty  # raised by multiprocessing.Queue.get on timeout

    def canonical(row):
        # Column-order independent form of a row; hash() is used as the sort
        # key because a row may mix types that cannot be compared directly.
        return tuple(sorted(row, key=hash))

    db_path = f'./databases/{database}/{database}.sqlite'
    if not os.path.isfile(db_path):
        print("cannot find file", db_path)
        # Same two-element shape as every other exit, so callers can unpack.
        return False, [FileNotFoundError(f'cannot find file {db_path}')]

    timeout = 120
    output = Queue()
    query_process = Process(target=query_module.execute_query, args=(db_path, sql_source, output))
    query_process.start()

    try:
        try:
            source_results = output.get(True, timeout + 5)
            query_process.join(timeout)
            timed_out = query_process.is_alive()
        except Empty:
            # Nothing arrived in time: the child is still busy with the query.
            timed_out = True
        if timed_out:
            print("process terminated")
            return False, [Exception('SQL query took too much time to execute.')]
        if isinstance(source_results, Exception):
            raise source_results

        connection = sqlite3.connect(db_path)
        try:
            target_results = connection.execute(sql_target).fetchall()
        finally:
            connection.close()

        # If the lengths don't match there is no point comparing rows.
        if len(source_results) != len(target_results):
            return False, []

        if 'ORDER BY' in sql_target:
            # Order matters: compare position by position.
            for a, b in zip(source_results, target_results):
                if canonical(a) != canonical(b):
                    return False, []
        else:
            lset = {canonical(a) for a in source_results}
            rset = {canonical(b) for b in target_results}
            if lset != rset:
                return False, []
    # Any error while running or comparing counts as "not a match".
    except Exception as ex:
        print(ex)
        return False, [ex]
    finally:
        # Never leave a runaway query process behind, whichever path we took.
        if query_process.is_alive():
            query_process.terminate()
            query_process.join()
    return True, []


def outputHash(sql_source, database):
    """Return an xxh128 hex digest of the rows produced by ``sql_source``.

    Each row is first put into a column-order independent form. When the
    literal text ``'ORDER BY'`` occurs in the query the rows are folded into
    the digest one after another (row order matters); otherwise the digest is
    taken over the string form of the set of rows.

    NOTE(review): the sort key is ``hash()`` and the unordered branch hashes
    ``str(set)``. Python randomises ``str`` hashes per interpreter run unless
    PYTHONHASHSEED is fixed, so digests of results containing text may differ
    between runs - confirm before comparing stored digests.

    Args:
        sql_source: SQL query to execute.
        database: Database name; resolved to
            ``./databases/<database>/<database>.sqlite``.

    Returns:
        The hex digest string, or ``False`` if opening the database or
        running the query raised an exception.
    """
    db_path = f'./databases/{database}/{database}.sqlite'
    output_hash = ''
    # Start unset so the cleanup below is safe even when connect() itself
    # fails (previously that path raised UnboundLocalError from `finally`).
    connection = None
    try:
        connection = sqlite3.connect(db_path)
        source_results = connection.execute(sql_source).fetchall()
        output_hash = xxhash.xxh128_hexdigest(str(len(source_results)), seed=123)
        if 'ORDER BY' in sql_source:
            for a in source_results:
                lhs = tuple(sorted(list(a), key=lambda x: hash(x)))
                output_hash = xxhash.xxh128_hexdigest(output_hash + str(lhs), seed=123)
        else:
            lset = set()
            for a in source_results:
                lset.add(tuple(sorted(list(a), key=lambda x: hash(x))))
            output_hash = xxhash.xxh128_hexdigest(str(lset), seed=123)
    except Exception:
        return False
    finally:
        if connection is not None:
            connection.close()
    return output_hash


def execute(sql, database, source='kaggle'):
    """Run ``sql`` against a database and return all fetched rows.

    Args:
        sql: SQL statement to execute.
        database: Database name; resolved to
            ``./databases/<database>/<database>.sqlite``.
        source: Not used by this function; kept for interface compatibility.

    Returns:
        The list of rows from ``fetchall()``, or ``False`` when the database
        file is missing, the statement raises, or the run is interrupted
        from the keyboard.
    """
    db_path = f'./databases/{database}/{database}.sqlite'

    if not os.path.isfile(db_path):
        print("cannot find file")
        return False

    # Start unset so the cleanup is safe even if connect() itself fails.
    connection = None
    try:
        connection = sqlite3.connect(db_path)
        return connection.execute(sql).fetchall()
    except KeyboardInterrupt:
        print("KeyboardInterrupt")
        return False
    # Any failure is reported to the caller as False.
    except Exception as ex:
        print(ex)
        return False
    finally:
        # Single cleanup point. The handlers used to close the cursor and
        # connection themselves and `finally` then closed them a second time,
        # which could raise and replace the intended `return False`.
        if connection is not None:
            connection.close()

In [73]:
def _chat_completion_with_perplexity(model, prompt, t):
    """Send ``prompt`` as a single user message and score the reply.

    Shared implementation behind the three ``*_generation`` helpers, which
    differ only in the model name they pass.

    Args:
        model: Model name passed to the chat completions endpoint.
        prompt: Text sent as the only user message.
        t: Sampling temperature.

    Returns:
        ``(text, perplexity)``: the stripped reply text and
        ``exp(-mean(token logprobs))``. If the response carries no token
        log-probabilities the perplexity is ``nan``; a reply without content
        yields an empty string. Both cases previously raised.
    """
    response = client.chat.completions.create(
        model=model,
        messages=[{"role": "user", "content": prompt}],
        n=1,
        stream=False,
        temperature=t,
        max_tokens=4096,
        logprobs=True,
    )
    choice = response.choices[0]
    # `logprobs` (or its `content`) can be absent, e.g. when the serving
    # endpoint does not return token log-probabilities.
    token_entries = getattr(choice.logprobs, 'content', None) or []
    logprobs = [token.logprob for token in token_entries]
    perplexity_score = np.exp(-np.mean(logprobs)) if logprobs else float('nan')
    text = choice.message.content or ''
    return text.strip(), perplexity_score


def GPT4_turbo_generation(prompt, t = 0.0):
    """Generate with ``gpt-4-turbo``; returns ``(text, perplexity)``."""
    return _chat_completion_with_perplexity('gpt-4-turbo', prompt, t)


def GPT4o_generation(prompt, t = 0.0):
    """Generate with ``gpt-4o``; returns ``(text, perplexity)``."""
    return _chat_completion_with_perplexity('gpt-4o', prompt, t)


def GPT35_generation(prompt, t = 0.0):
    """Generate with ``gpt-3.5-turbo``; returns ``(text, perplexity)``."""
    return _chat_completion_with_perplexity('gpt-3.5-turbo', prompt, t)

In [74]:
# Load the benchmark questions and, if present, the recorded user-study
# sessions; then remove the user-study questions from the benchmark frame.
df = pd.read_csv('kaggle_dataset.csv')
userstudy = []
try:
    # NOTE(security): pickle.load can execute arbitrary code contained in the
    # file - only load a user_study.pkl that you produced yourself.
    with open('./user_study.pkl', 'rb') as f:
        userstudy = pickle.load(f)
except (EOFError, FileNotFoundError, pickle.UnpicklingError):
    print("Warning: Could not load user_study.pkl (file missing or empty). Proceeding with empty userstudy list.")
    userstudy = []

survey_questions = []
# add original gold query inside
for d in userstudy:
    q = d['nl']
    if 'Question2Ask' in d:
        # The offending question is the assertion message. The previous
        # `assert ..., print(...)` raised an AssertionError whose message
        # was None.
        assert len(d['Question2Ask']) == len(d['Answer2Question']), q
    # Rows of the benchmark whose question text equals this study question.
    matching_rows = df.loc[df['nl'] == q]
    d['gold'] = matching_rows['sql'].values[0]
    # NOTE(review): this stores the whole array of matches, whereas 'gold'
    # stores only the first element - confirm downstream code expects that.
    d["target_schema"] = matching_rows['target_schema'].values
    survey_questions.append(q)

print(len(df))
# drop the user study questions
df = df[~df['nl'].isin(survey_questions)]
print(len(df))


272
272


In [75]:
# Renumber the rows from 0 after the filtering above; drop=False (the
# default) keeps the previous row labels as a regular column.
df = df.reset_index(drop=False)

In [76]:
# Few-shot examples for answering a clarification question from the gold query.
# Every example ends with a line of the form: answer_to_cq = "<choice>"
feedback_prefix_v1='''/* some examples are provided */
/* example question: */
How many acres burned in fires in California each year between 2000 and 2005?
/* example gold sql query*/
SELECT\n  SUM(FIRE_SIZE),\n  FIRE_YEAR\nFROM Fires\nWHERE\n  State = "CA" AND FIRE_YEAR BETWEEN 2000 AND 2005\nGROUP BY\n  FIRE_YEAR
/* example clarification question*/
What information should the output table contain? a) two columns: the total acres burned and the year, b) one column: the total acres burned for each year, c) one column: the total acres burned across all target years, d) other (please specify).
/* example reasoning */
Output table is determined by the SELECT clause in the gold sql query. The gold query uses ‘SELECT  SUM(FIRE_SIZE), FIRE_YEAR’. As a result, the output table has two columns, the total acres burned and the year. Hence, choice a is correct.
/* example  answer*/
answer_to_cq = "a) two columns: the total acres burned and the year"

/* example question: */
Which states had the largest number of fires in 2001?
/* example gold sql query*/
SELECT\n  State\nFROM Fires\nWHERE\n  FIRE_YEAR = 2001\nGROUP BY\n  State\nORDER BY\n  COUNT(*) DESC\nLIMIT 1;
/* example clarification question*/
Is the largest number of fires referring to? a) the total size of all fire incidents, b) the number of fire incidents, c) the largest size of all fire incidents, d) other (please specify).
/* example reasoning */
The clarification question is asking about how to represent the largest number of fires. The gold query uses ‘ORDER BY COUNT(*) DESC LIMIT 1’ to find the largest number of fires. As a result, choice b is correct.
/* example  answer*/
answer_to_cq = "b) the number of fire incidents"

/* example question: */
What was the most common cause of fire between 2000 and 2005?
/* example gold sql query*/
SELECT\n  STAT_CAUSE_DESCR\nFROM Fires\nWHERE\n  FIRE_YEAR BETWEEN 2000 AND 2005\nGROUP BY\n  STAT_CAUSE_DESCR\nORDER BY\n  COUNT(*) DESC\nLIMIT 1;
/* example clarification question*/
Which information should be used to represent the 'cause of fire'? a) the code that represents the cause, b) the description of the cause, c) both the code and the description of the cause, d) other (please specify).
/* example reasoning */
The clarification question is asking for which column should be used to represent the cause of fire. The gold query uses the STAT_CAUSE_DESCR to represent the cause. As a result, choice b is correct.
/* example  answer*/
answer_to_cq = "b) the description of the cause"

/* example question: */
Whose CDs sells best?
/* example gold sql query*/
SELECT\n  artist\nFROM torrents\nGROUP BY\n  artist\nORDER BY\n  SUM(totalSnatched) DESC\nLIMIT 1;
/* example clarification question*/
Which column should be used to identify music related to 'CD'? a) groupName, b) tag, c) releaseType, d) other (please specify)
/* example reasoning */
The gold query does not use a WHERE clause to filter the CDs. Hence, the CD information is not contained in the tag column or the release type column. As a result, choice a, b, and c are all wrong.
/* example  answer*/
answer_to_cq = "d) Consider all music; No filter on ‘CD’"

/* example question: */
How many people wrote comments for the question "Any additional notes or comments."?
/* example gold sql query*/
SELECT COUNT(T1.UserID) FROM Answer AS T1 INNER JOIN Question AS T2 ON T1.QuestionID = T2.questionid WHERE T2.questiontext LIKE 'Any additional notes or comments' AND T1.AnswerText IS NOT NULL
/* example clarification question*/
How to determine if a user has provided comments? a) no check needed, b) see if `AnswerText` column has empty string, c) other (please specify).
/* example reasoning */
In the gold SQL query, it checks “T1.AnswerText IS NOT NULL”. Hence, choice a and b are both wrong.
/* example  answer*/
answer_to_cq = "c) ‘wrote comments’ imply `AnswerText` is not a NULL value"

/* example question: */
Calculate the difference between the number of customers and the number of subscribers who did the trip in June 2013. 
/* example gold sql query*/
SELECT SUM(IIF(subscription_type = 'Subscriber', 1, 0)) - SUM(IIF(subscription_type = 'Customer', 1, 0)) FROM trip WHERE start_date LIKE '6/%/2013%'
/* example clarification question*/
What predicate value should be used to determine a trip in June 2013? a) start_date > 06/2013, b) start_date = ‘June 2013’, c) other (please specify).
/* example reasoning */
The gold sql query uses start_date LIKE '6/%/2013%' to find trips in June 2013.
/* example  answer*/
answer_to_cq = "c) start_date LIKE '6/%/2013%'"


/* example question: */
Identify the players who weigh 120 kg.
/* example gold sql query*/
SELECT T2.PlayerName FROM weight_info AS T1 INNER JOIN PlayerInfo AS T2 ON T1.weight_id = T2.weight WHERE T1.weight_in_kg = 120
/* example clarification question*/
What fields should be contained in the output? a) one column of player name, b) one column of player id, c) two columns of player name and player ids, d) other (please specify).
/* example reasoning */
The gold query selects ‘SELECT T2.PlayerName’. Hence, a is correct.
/* example  answer*/
answer_to_cq = "a) one column of player name"

/* example question: */
How many reviews are created for the podcast "Scaling Global" under?
/* example gold sql query*/
SELECT COUNT(T2.content) FROM podcasts AS T1 INNER JOIN reviews AS T2 ON T2.podcast_id = T1.podcast_id WHERE T1.title = 'Scaling Global'
/* example clarification question*/
Which column represents the reviews? a) `podcast` column, b) `content` column, c) other (please specify).
/* example reasoning */
The gold query uses “COUNT(T2.content)” to determine the number of reviews. Hence, b is correct in which the `content` column represents the reviews.
/* example  answer*/
answer_to_cq = "b) `content` column"
\n\n
'''


# Prompt template for answering one clarification question from the gold query.
# str.format placeholders: {nlq} (natural-language question), {query} (gold
# SQL query), {question} (the multiple-choice clarification question).
# The model is told to finish with a line of the form: answer_to_cq = "".
feedback_v2 = """/* Given the following Natural Language Question: */
{nlq}
/* And the following Gold Query: */
{query}
/* Answer the following multiple choice clarification question truthfully based on the Gold Query: */
{question}

/* Follow these steps:
1. Identify which portion of the Gold Query answers the clarification question.
2. Evaluate the correctness of each multiple choice answer based only on the Gold Query.
3. If none of the choices are correct or you select "other (please specify)", provide a short answer for the clarification question.
4. Output the final answer in the format: answer_to_cq = "".

Let’s proceed step by step. */
"""

In [77]:
# Few-shot examples for generating the next multiple-choice clarification
# question. Each example names the ambiguity category that remains
# (AmbQuestion / AmbTableColumn / AmbOutput / AmbValue) and ends with a
# `mul_choice_cq = ...` line.
cq_prefix_v1 = '''/* some examples are provided */
/* example question: */
Which artist/group is most productive?
/* example previous clarification questions and user replies: */
clarification questions: How to rank artist/group productivity? a) rank by the number of records produced, b) rank by the total number of downloads, c) other (please specify).
user: b) rank by the total number of downloads
/* example reasoning and remaining ambiguity type*/
It is clear that the SQL answer should use ORDER BY and LIMIT 1 based on the sum of total downloads. However, it is unclear what columns should be used to represent the 'artist/group'.  Both the `artist` and the `groupName` columns contain information about 'artist/group'. ‘AmbTableColumn’ remains.
/* example clarification question */
mul_choice_cq = "Which columns represent the 'artist/group' information? a) the artist column only, b) the groupName column only, c) both the artist column and the groupName column, d) other (please specify)."

/* example question: */
Which Premier League matches ended in a draw in 2016?
/* example previous clarification questions and user replies: */
clarification questions: Is the year '2016' referring to? a) season is 2016, b) season is either 2015/2016 or 2016/2017, c) the date time is at year 2016, d) other (specify).
user: a) season is 2016,
clarification questions: How to find the 'Premier league'? a) consider all leagues, b) consider only the league with name 'Premier League', c) other (specify).
user: b) consider only the league with name 'Premier League'
/* example reasoning and remaining ambiguity type*/
It is clear that the SQL answer to this question needs to contain a WHERE clause for three conditions: 'Premier League', 'draw', and 'in 2016'. However, the question did not specify what fields should be contained in the output table. 'AmbOutput' remains.
/* example clarification question */
mul_choice_cq = “What fields represent the target 'matches'? a) all fields from football data table, b) the `league` column, c) other (specify).”

/* example question: */
Which type of crime has the highest rate of ‘Investigation complete’?
/* example previous clarification questions and user replies: */
No previous clarification questions.
/* example reasoning and remaining ambiguity type*/
It is clear that the SQL answer to this question needs to contain a WHERE clause to find crimes that have 'Investigation complete' outcomes, uses ORDER BY and LIMIT 1 to find the type of crime with the highest rate, and the output table has only one row. However, it needs to be clarified i) what predicate value should be used for 'Investigation complete', and ii) how to represent the 'rate', and iii) if the output table contains only the crime type column or the crime type column with the highest rate aggregate. Hence, this question is ambiguous because of 'AmbValue', 'AmbQuestion', and ‘AmbOutput’.
/* example clarification question */
mul_choice_cq = “What information should be used to find 'Investigation complete'? a) see if outcome contains the phrase 'Investigation complete', b)  see if outcome is 'Investigation complete; no suspect identified', c) other (please specify).”

/* example question: */
For award winners, which position has the most hall of fame players?
/* example previous clarification questions and user replies: */
clarification questions: How should the 'position' for players be identified? a) by the `award_id` column, b) by the `category` column, c) other (please specify).
user: c)  by the `note` column
/* example reasoning and remaining ambiguity type*/
It is clear that the answer should use the `note` column for player ‘positions’. However, it is unclear what fields should contain in the output table. Hence ‘AmbOutput’ remains.
/* example clarification question */
mul_choice_cq = “What fields should be contained in the output table? a) one field: the position, b) two fields: the position and the number of hall-of-fame players, c) other (please specify).”

/* example question: */
How many Wisconsin school districts receive federal funding?
/* example previous clarification questions and user replies: */
clarification question: How to determine if a district has received federal funding? a) based on the t_fed_rev is larger than 0, b) the answer does not need to consider this aspect, c) other (please specify).
user: c) every school in `FINREV_FED_17` table has received federal funding.
/* example reasoning and remaining ambiguity type*/
It is clear that every school in `FINREV_FED_17` table have received federal funding. However, it is unclear if the word ‘Wisconsin’ refers to the state or the school district. Hence, ‘AmbQuestion’ remains.
/* example clarification question */
mul_choice_cq = “Is 'Wisconsin school districts' referring to? a) all school districts in the state Wisconsin, b) school districts with names that contain Wisconsin, c) other (please specify).”

/* example question: */
How many 2-year public schools are there in "California"?
/* example previous clarification questions and user replies: */
clarification question: Which column(s) should be used to find ‘2-year public schools’? a) `level` column, b) `control` column, c) other (please specify).
user: c) use both `level` and `control` columns to find ‘2-year public schools’ information.
/* example reasoning and remaining ambiguity type*/
It is clear that the correct SQL answer should have a WHERE clause with filters based on the `level` and `control` columns. However, it is unclear what predicate values should be used for these two columns. Hence, ‘AmbValue’ remains.
/* example clarification question */
mul_choice_cq = “What predicate values should be used for the `level` and `control` columns to find  ‘2-year public schools’? a) ‘2-year’ and ‘public’, b) ‘2’ and ‘public’, c) other (please specify).” 

/* example question: */
Calculate the total beat of the crimes reported in a community area in the central side with a population of 50,000 and above.
/* example previous clarification questions and user replies: */
clarification question: What column and predicate value should be used to determine ‘central side’? a) Column `side` in table `Community_Area` with value ‘central’, b) Column `side` in table `Community_Area` with value ‘Central’, c) other (please specify).
user: b) Column `side` in table `Community_Area` with value ‘Central’
/* example reasoning and remaining ambiguity type*/
It is clear that the output table should contain a single number and use the predicate ‘Central’ in `Community_Area`.`side`; However, it is not clear which column of statistics is ‘total beat’ referring to. Hence, AmbTableColumn remains.
/* example clarification question */
mul_choice_cq = “Which column is related to ‘total beat’? a) `Crime`.`beat`, b) `Crime`.`report_no`, c) other (please specify).”

/* example question: */
Of all the nonessential genes that are not of the motorprotein class and whose phenotype is cell cycle defects, how many do not have a physical type of interaction?
/* example previous clarification questions and user replies: */
No previous clarification questions.
/* example reasoning and remaining ambiguity type*/
It is clear that ‘phenotype’ is referring to the `Phenotype` column, ‘motorprotein class’ is referring to the `class` column, ‘nonessential genes’ is referring to the `essential` column, and `physical type` is referring to the `type` column. However, it is unclear what fields should be contained in the output table, and hence ‘AmbOutput’ remains.
/* example clarification question */
mul_choice_cq = “What fields should be included in the output table? a) One column for the number of genes b) Two columns for GeneID and physical type c) Other (please specify).”
\n\n
'''


# Prompt template asking the model for the next multiple-choice clarification
# question. str.format placeholders: {question} (natural-language question),
# {schema} (database schema), {sqls} (incorrect SQL answers so far),
# {cqs} (previous clarification questions and user replies).
# The model is told to answer in the form: mul_choice_cq = "".
SRA = """/* Ask the user a new multiple choice clarification question to help you find the correct SQL answer for the following question: */
{question}
/* Given the following database schema: */
{schema}
/* And the following incorrect sql answers: */
{sqls}
/* And the following previous clarification questions and user replies: */
{cqs}

/* Consider the following ambiguity categories:
    - AmbQuestion: Is the question itself ambiguous?
    - AmbTableColumn: Is there ambiguity in mapping the entities from the QUESTION to tables and columns in the DATABASE SCHEMA?
    - AmbOutput: What fields and how many fields should be included in the output table?
    - AmbValue: What predicate value should be used to filter results?
*/

/* The clarification question should be easy to understand for people with no coding experience. */

/* Let's think step by step to generate the helpful multiple choice clarification question.
1. Summarize the clear information based on previous clarification questions and incorrect queries.
2. Evaluate whether AmbQuestion, AmbTableColumn, AmbOutput, and AmbValue remain in formulating an SQL query, considering each category individually.
3. Ask a new multiple-choice question to address the remaining ambiguities and assist in identifying the correct SQL query. Use format: mul_choice_cq = "".
*/

"""


# Variant of the clarification-question prompt that lets the model stop:
# step 3 tells it to output "NO AMBIGUITY" when nothing remains unclear,
# otherwise to answer in the form mul_choice_cq = "".
# str.format placeholders: {question}, {schema}, {sqls}, {cqs}.
SRA_ES = """/* Ask the user a new multiple choice clarification question to help you find the correct SQL answer for the following question: */
{question}
/* Given the following database schema: */
{schema}
/* And the following incorrect sql answers: */
{sqls}
/* And the following previous clarification questions and user replies: */
{cqs}

/* Consider the following ambiguity categories:
    - AmbQuestion: Is the question itself ambiguous?
    - AmbTableColumn: Is there ambiguity in mapping the entities from the QUESTION to tables and columns in the DATABASE SCHEMA?
    - AmbOutput: What fields and how many fields should be included in the output table?
    - AmbValue: What predicate value should be used to filter results?
*/

/* The clarification question should be easy to understand for people with no coding experience. */

/* Let's think step by step to generate the helpful multiple choice clarification question.
1. Summarize the clear information based on previous clarification questions and incorrect queries.
2. Evaluate whether AmbQuestion, AmbTableColumn, AmbOutput, and AmbValue remain in formulating an SQL query, considering each category individually.
3. If no remaining ambiguities are identified, then output "NO AMBIGUITY".
   Else, ask a new multiple-choice question to address the remaining ambiguities and assist in identifying the correct SQL query. Use format: mul_choice_cq = "".
*/

"""

In [78]:
# DAIL SQLNoRule
# First-attempt SQL prompt. Placeholders: {schema}, {metadata}, {question}.
# It ends with "SELECT " so that the model continues the statement.
sql_generation = (
    "/* Given the following database schema: */\n"
    "{schema}\n"
    "\n"
    "{metadata}\n"
    "/* Answer the following with no explanation: {question} */\n"
    "SELECT "
)

# update to code representation
# SQL prompt that also shows earlier wrong answers and the user's replies.
# Placeholders: {schema}, {sqls}, {cqas}, {metadata}, {question}.
sql_generation_v2 = (
    "/* Given the following database schema: */\n"
    "{schema}\n"
    "/* And the following incorrect sql answers: */\n"
    "{sqls}\n"
    "/* And the following user replies to help you write the correct sql query: */\n"
    "{cqas}\n"
    "\n"
    "{metadata}\n"
    "/* Answer the following with no explanation: {question} */\n"
    "SELECT "
)

# update to code representation
# Repair prompt for a query that failed to execute.
# Placeholders: {schema}, {invalidSQL}, {ex} (the exception message).
# The stray backtick that used to follow {schema} has been removed so the
# schema is not followed by an unmatched quote character.
fix_invalid_v1 = """/* Given the following database schema: */
{schema}
/* And the following inexecutable sql query */
{invalidSQL}
/* And the following exception message */
{ex}

/* Fix the exception and write a new executable SQL query with no explanation */
SELECT """

# update to code representation
# Self-debug SQL prompt: shows earlier wrong answers but no user replies.
# Placeholders: {schema}, {sqls}, {metadata}, {question}.
sql_generation_selfdebug = (
    "/* Given the following database schema: */\n"
    "{schema}\n"
    "/* And the following incorrect sql answers: */\n"
    "{sqls}\n"
    "\n"
    "{metadata}\n"
    "/* Answer the following with no explanation: {question} */\n"
    "SELECT "
)

In [79]:
# Six self-debug demonstrations, each separated by a line of dashes. Every
# demonstration has the same three parts: the incorrect answer(s), the
# question, and the corrected query.
# Fixed here: a stray corrected query copied from example 3 into example 4,
# the "asnwers" misspelling (the generation template says "answers"), and the
# missing "Answer the following ..." header in the last example.
selfdebug_examples_prefix = '''/* Given the following incorrect sql answers: */
SELECT creation, COUNT(*) FROM department GROUP BY creation ORDER BY
COUNT(*) DESC LIMIT 1
/* Answer the following with no explanation: In which year were most departments established? */
SELECT creation FROM department GROUP BY creation ORDER BY COUNT(*) DESC LIMIT 1
-------
/* Given the following incorrect sql answers: */
SELECT customers.customer_name FROM customers JOIN orders ON customers.customer_id = orders.customer_id WHERE orders.order_status = "On Road" AND orders.order_status = "Shipped"
/* Answer the following with no explanation: Which customers have both "On Road" and "Shipped" as order status? List the customer names. */
SELECT customers.customer_name FROM customers JOIN orders ON customers.customer_id = orders.customer_id WHERE orders.order_status = "On Road" INTERSECT SELECT customers.customer_name FROM customers JOIN orders ON customers.customer_id = orders.customer_id WHERE orders.order_status = "Shipped"
-------
/* Given the following incorrect sql answers: */
SELECT origin FROM flight WHERE destination = "HONO"
/* Answer the following with no explanation: Show origins of all flights with destination Honolulu. */
SELECT origin FROM flight WHERE destination = "Honolulu"
-------
/* Given the following incorrect sql answers: */
SELECT AVG(long) FROM station WHERE id IN (SELECT station_id FROM status WHERE bikes_available <= 10)
/* Answer the following with no explanation: What is the average longitude of stations that never had bike availability more than 10? */
SELECT AVG(long) FROM station WHERE id NOT IN (SELECT station_id FROM status WHERE bikes_available > 10)
-------
/* Given the following incorrect sql answers: */
SELECT name, nationality FROM host WHERE age = (SELECT MIN(age) FROM host)
/* Answer the following with no explanation: Show the name and the nationality of the oldest host. */
SELECT name, nationality FROM host ORDER BY age DESC LIMIT 1
-------
/* Given the following incorrect sql answers: */
SELECT COUNT(status) FROM city
/* Answer the following with no explanation: How many different statuses do cities have? */
SELECT COUNT(DISTINCT status) FROM city
-------'''
# Splitting on the separator leaves a trailing empty string after the last
# separator; only the first six pieces are used below.
selfdebug_examples = selfdebug_examples_prefix.split('-------')

# selfdebug_few_shot[k] is the first k + 1 demonstrations joined by newlines.
selfdebug_few_shot = ['\n'.join(selfdebug_examples[:count]) for count in range(1, 7)]



In [80]:
# One-line header announcing few-shot examples (same wording as the first
# line of the example prefixes above).
fewshot_prefix = '/* some examples are provided */\n'

In [81]:
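# One-time setup, kept for reference: uncomment this cell to (re)build the
# Chroma store persisted under ./userstudy_chroma from the user-study records.
# embeddings = OpenAIEmbeddings()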
# # generate vectorstore for userstudy
# examples = []
# for study in userstudy:
#     t = {'nl':study['nl'], 'gold':study['gold']}
#     feedback = ""
#     if 'Question2Ask' in study:
#         for q, a in zip(study['Question2Ask'], study['Answer2Question']):
#             feedback += "multiple choice clarification question: "+q+'\nuser: '+a+'\n'
#     t['feedback'] = feedback
#     examples.append(t)
# to_vectorize = [example['nl'] for example in examples]
# userstudy_vectorstore = Chroma.from_texts(to_vectorize, embeddings, metadatas=examples,  persist_directory="./userstudy_chroma")


In [82]:
# Embedding function handed to the vector store below.
embeddings = OpenAIEmbeddings()
# Open the Chroma store persisted in ./userstudy_chroma (nothing is added to it here).
userstudy_vectorstore = Chroma(persist_directory="./userstudy_chroma", embedding_function=embeddings)

In [83]:
def baselineFewShot(data_frame, history_log, log_name, rounds, num_of_tests, model_name, vectorstore, num_examples, data_source, with_metadata):
    """Self-debug baseline: regenerate SQL from earlier wrong attempts, with no clarification questions.

    For every test index in ``range(num_of_tests)`` that has no ``'num_cq_asked'``
    entry yet, the single SQL stored in ``history_log[index]['sql_log']`` is
    evaluated with ``evalfunc``. While no attempt passes, up to ``rounds`` new
    queries are generated from ``sql_generation_selfdebug`` (which receives the
    previous attempts), prefixed with the first ``num_examples`` self-debug
    examples. Whenever ``evalfunc`` reports an exception, one repair call is made
    through ``fix_invalid_v1``.

    Args:
        data_frame: test cases, read positionally with ``iloc``. Columns used:
            ``sql``/``target_db``/``nl``/``target_schema`` for 'kaggle';
            ``SQL``/``question``/``db_id`` (and ``evidence`` when
            ``with_metadata``) for 'bird'.
        history_log: dict keyed by test index. Each unprocessed entry must hold
            a ``'sql_log'`` list with exactly one ``(order, prompt, sql, pscore)``
            record. Mutated in place.
        log_name: file the whole log is pickled to (via ``save``) after the loop.
        rounds: maximum number of regeneration turns per test.
        num_of_tests: number of leading test indices to process.
        model_name: 'gpt35turbo' or 'gpt4turbo'.
        vectorstore: accepted but not used by this function.
        num_examples: how many self-debug examples to prepend, between 0 and
            ``len(selfdebug_few_shot)``. 0 sends the prompt without examples.
        data_source: 'kaggle' or 'bird'.
        with_metadata: when true and ``data_source`` is 'bird', the ``evidence``
            column is passed to the prompt as ``metadata``.

    Side effects:
        Appends to ``history_log[index]['sql_log']`` and sets
        ``history_log[index]['num_cq_asked']`` to 0 (initial or repaired initial
        query passed), ``turn + 1`` (passed in that turn) or ``"Failed"``.
    """
    assert model_name in ['gpt35turbo', 'gpt4turbo']
    # Fail fast instead of hitting an unbound local inside the loop.
    assert data_source in ['kaggle', 'bird']
    # 0 used to wrap around to selfdebug_few_shot[-1] (all examples); it now means zero-shot.
    assert 0 <= num_examples <= len(selfdebug_few_shot)
    generation = GPT35_generation if model_name == 'gpt35turbo' else GPT4_turbo_generation

    if num_examples > 0:
        examples_block = fewshot_prefix + selfdebug_few_shot[num_examples - 1]
    else:
        examples_block = ''

    def repair_invalid_sql(entry, order, bad_sql, error, dbschema, nlq, gold, dbname, attempts):
        """Replace a non-executable attempt by one model-repaired query; return its evaluation flag."""
        attempts.pop(bad_sql, None)
        invalid_prompt = fix_invalid_v1.format(schema=dbschema, question=nlq,
                                               invalidSQL=bad_sql, ex=error[0])
        sql, pscore = generation(invalid_prompt)
        sql = clean_query(sql)
        entry['sql_log'].append((order, invalid_prompt, sql, pscore))
        attempts[sql] = None
        execution, _ = evalfunc(sql, gold, dbname, data_source)
        return execution

    for index in range(num_of_tests):
        if index in history_log and "num_cq_asked" in history_log[index]:
            # skip tests already seen
            continue

        assert index in history_log
        entry = history_log[index]
        assert len(entry['sql_log']) == 1

        d = data_frame.iloc[[index]]
        evidence = ''
        # Insertion-ordered de-duplicated attempts. A plain set of strings iterates
        # in a different order on every interpreter run, which made prompts unstable.
        attempts = {}
        if data_source == 'kaggle':
            gold = d['sql'].values[0]
            dbname = d['target_db'].values[0]
            nlq = d['nl'].values[0]
            dbschema = d['target_schema'].values[0]
        else:  # 'bird'
            gold = d['SQL'].values[0]
            nlq = d['question'].values[0]
            dbname = d['db_id'].values[0]
            dbschema = generate_db_schema(dbname)
            if with_metadata:
                evidence = d['evidence'].values[0]
                if not isinstance(evidence, str):
                    # Missing values arrive as NaN; do not print "nan" into prompts.
                    evidence = ''
        print(nlq, index, dbname, evidence)

        order, _, sql_query, pscore = entry['sql_log'][0]
        order += 1
        sql_query = clean_query(sql_query)
        attempts[sql_query] = None
        execution, exception = evalfunc(sql_query, gold, dbname, data_source)

        if exception:
            execution = repair_invalid_sql(entry, order, sql_query, exception,
                                           dbschema, nlq, gold, dbname, attempts)
            order += 1
        if execution:
            entry['num_cq_asked'] = 0
            continue

        for turn in range(rounds):
            sql_prompt = sql_generation_selfdebug.format(schema=dbschema, question=nlq,
                                                         sqls=";\n".join(attempts), metadata=evidence)
            sql_prompt = examples_block + sql_prompt
            sql_query, pscore = generation(sql_prompt)
            # The raw model output is what gets logged; the cleaned form is evaluated.
            entry['sql_log'].append((order, sql_prompt, sql_query, pscore))
            order += 1
            sql_query = clean_query(sql_query)
            attempts[sql_query] = None
            execution, exception = evalfunc(sql_query, gold, dbname, data_source)
            if exception:
                execution = repair_invalid_sql(entry, order, sql_query, exception,
                                               dbschema, nlq, gold, dbname, attempts)
                order += 1
            if execution:
                entry['num_cq_asked'] = turn + 1
                break
        if 'num_cq_asked' not in entry:
            entry['num_cq_asked'] = "Failed"
    save(log_name, history_log)

In [84]:
def askClarificationQuestions(data_frame, history_log, log_name, rounds, num_of_tests, model_name, vectorstore_feedback, num_examples, data_source, with_metadata):
    """Repair wrong SQL by asking clarification questions for a fixed number of turns.

    For every test index in ``range(num_of_tests)`` without a ``'num_cq_asked'``
    entry, the single stored SQL is evaluated with ``evalfunc``. While no attempt
    passes, each of up to ``rounds`` turns (1) generates a multiple-choice
    clarification question with the ``SRA`` prompt, (2) obtains an answer from
    ``GPT4o_generation`` using a prompt that contains the gold SQL, and
    (3) regenerates the SQL from the accumulated question/answer transcript with
    few-shot examples selected from ``vectorstore_feedback``. Attempts for which
    ``evalfunc`` reports an exception get one repair call via ``fix_invalid_v1``.

    Args:
        data_frame: test cases, read positionally with ``iloc``. Columns used:
            ``sql``/``target_db``/``nl``/``target_schema`` for 'kaggle';
            ``SQL``/``question``/``db_id`` (and ``evidence`` when
            ``with_metadata``) for 'bird'.
        history_log: dict keyed by test index. Each unprocessed entry must hold
            a one-record ``'sql_log'`` plus ``'cq_log'`` and ``'feedback_log'``
            lists. Mutated in place.
        log_name: file the whole log is pickled to (via ``save``) after the loop.
        rounds: maximum number of clarification turns per test.
        num_of_tests: number of leading test indices to process.
        model_name: 'gpt35turbo' or 'gpt4turbo' (question and SQL generation).
        vectorstore_feedback: vector store backing the semantic example selector.
        num_examples: ``k`` passed to the example selector.
        data_source: 'kaggle' or 'bird'.
        with_metadata: when true and evidence text is available, it opens the
            transcript as a ``user:`` line and is passed as ``metadata``.

    Side effects:
        Appends to the three logs of each entry and sets ``'num_cq_asked'`` to 0,
        ``turn + 1`` or ``"Failed"``.
    """
    assert model_name in ['gpt35turbo', 'gpt4turbo']
    # Fail fast instead of hitting an unbound local inside the loop.
    assert data_source in ['kaggle', 'bird']
    generation = GPT35_generation if model_name == 'gpt35turbo' else GPT4_turbo_generation

    feedback_example_selector = SemanticSimilarityExampleSelector(
        vectorstore=vectorstore_feedback,
        k=num_examples,
    )

    feedback_example_prompt = PromptTemplate(
        input_variables=['nl', 'gold', 'feedback'],
        template="/* Given the following user feedback on clarification questions */\n{feedback}\n/* Answer the following with no explanation: {nl} */\n{gold}",
    )
    sql_generation_feedback_few_shot_prompt = FewShotPromptTemplate(
        example_selector=feedback_example_selector,
        example_prompt=feedback_example_prompt,
        suffix=sql_generation_v2,
        input_variables=["question", "schema", "sqls", "cqas", "metadata"],
    )

    def render_dialogue(turns):
        """Format the alternating [question, answer, question, ...] list as prompt text."""
        text = ""
        for pos, utterance in enumerate(turns):
            speaker = "multiple choice clarification question: " if pos % 2 == 0 else "user: "
            text += speaker + utterance + '\n'
        return text

    def repair_invalid_sql(entry, order, bad_sql, error, dbschema, nlq, gold, dbname, attempts):
        """Replace a non-executable attempt by one model-repaired query; return its evaluation flag."""
        attempts.pop(bad_sql, None)
        invalid_prompt = fix_invalid_v1.format(schema=dbschema, question=nlq,
                                               invalidSQL=bad_sql, ex=error[0])
        sql, pscore = generation(invalid_prompt)
        sql = clean_query(sql)
        entry['sql_log'].append((order, invalid_prompt, sql, pscore))
        attempts[sql] = None
        execution, _ = evalfunc(sql, gold, dbname, data_source)
        return execution

    for index in range(num_of_tests):
        if index in history_log and "num_cq_asked" in history_log[index]:
            # skip tests already seen
            print('skip')
            continue

        assert index in history_log
        entry = history_log[index]
        assert len(entry['sql_log']) == 1

        d = data_frame.iloc[[index]]
        cqs_and_answers = []
        evidence = ''
        # Insertion-ordered de-duplicated attempts. A plain set of strings iterates
        # in a different order on every interpreter run, which made prompts unstable.
        attempts = {}
        if data_source == 'kaggle':
            gold = d['sql'].values[0]
            dbname = d['target_db'].values[0]
            nlq = d['nl'].values[0]
            dbschema = d['target_schema'].values[0]
        else:  # 'bird'
            gold = d['SQL'].values[0]
            nlq = d['question'].values[0]
            dbname = d['db_id'].values[0]
            dbschema = generate_db_schema(dbname)
            if with_metadata:
                evidence = d['evidence'].values[0]
                if not isinstance(evidence, str):
                    # Missing values arrive as NaN; string concatenation below would fail.
                    evidence = ''
        print("nl: ", nlq, index, dbname, evidence)

        order, _, sql_query, pscore = entry['sql_log'][0]
        order += 1
        sql_query = clean_query(sql_query)
        attempts[sql_query] = None
        execution, exception = evalfunc(sql_query, gold, dbname, data_source)

        if exception:
            execution = repair_invalid_sql(entry, order, sql_query, exception,
                                           dbschema, nlq, gold, dbname, attempts)
            order += 1

        if execution:
            entry['num_cq_asked'] = 0
            continue

        for turn in range(rounds):
            # Transcript for the question generator. The evidence line is added only
            # when there is evidence; an empty one used to produce a bare "user: "
            # line and hid the "no previous clarification question" sentinel.
            cqas = ""
            if with_metadata and evidence:
                cqas = 'user: ' + evidence + '\n'
            cqas += render_dialogue(cqs_and_answers)
            if cqas == "":
                cqas = "no previous clarification question.\n"
            cq_prompt = SRA.format(schema=dbschema, question=nlq,
                                   sqls=";\n".join(attempts), cqs=cqas)
            cq_prompt = cq_prefix_v1 + cq_prompt
            cq, pscore = generation(cq_prompt)
            entry['cq_log'].append((order, cq_prompt, cq, pscore))
            order += 1
            if "mul_choice_cq = " in cq:
                cq = cq.split("mul_choice_cq = ")[-1]

            # The answer is generated from a prompt that includes the gold SQL.
            feedback_prompt = feedback_v2.format(query=gold, question=cq, nlq=nlq)
            feedback_prompt = feedback_prefix_v1 + feedback_prompt
            feedback, pscore = GPT4o_generation(feedback_prompt)
            entry['feedback_log'].append((order, feedback_prompt, feedback, pscore))
            order += 1
            if "answer_to_cq =" in feedback:
                feedback = feedback.split("answer_to_cq =")[-1]
            cqs_and_answers.append(cq)
            cqs_and_answers.append(feedback)

            # fix incorrect sql based on user feedback; the transcript holds at
            # least one question/answer pair here, so it is never empty
            cqas = render_dialogue(cqs_and_answers)

            # use examples from the user study
            sql_prompt = sql_generation_feedback_few_shot_prompt.format(
                schema=dbschema, question=nlq,
                sqls=";\n".join(attempts), cqas=cqas, metadata=evidence)
            sql_prompt = "/* some examples are provided */\n" + sql_prompt
            sql_query, pscore = generation(sql_prompt)
            sql_query = clean_query(sql_query)
            entry['sql_log'].append((order, sql_prompt, sql_query, pscore))
            order += 1
            attempts[sql_query] = None
            execution, exception = evalfunc(sql_query, gold, dbname, data_source)
            if exception:
                execution = repair_invalid_sql(entry, order, sql_query, exception,
                                               dbschema, nlq, gold, dbname, attempts)
                order += 1
            if execution:
                entry['num_cq_asked'] = turn + 1
                break

        if 'num_cq_asked' not in entry:
            entry['num_cq_asked'] = "Failed"
    save(log_name, history_log)

In [85]:
def askCQsBreakNoAmb(data_frame, history_log, log_name, rounds, num_of_tests, model_name, vectorstore_feedback, num_examples, data_source, with_metadata):
    """Clarification-question repair loop that stops early when the model reports no ambiguity.

    Same procedure as the fixed-turn clarification loop, with two differences
    visible in this function: the question prompt is ``SRA_ES``, and a turn whose
    generated question contains ``"NO AMBIGUITY"`` ends the loop immediately.
    Because no passing query was recorded in that case, the test is stored as
    ``"Failed"``.

    Args:
        data_frame: test cases, read positionally with ``iloc``. Columns used:
            ``sql``/``target_db``/``nl``/``target_schema`` for 'kaggle';
            ``SQL``/``question``/``db_id`` (and ``evidence`` when
            ``with_metadata``) for 'bird'.
        history_log: dict keyed by test index. Each unprocessed entry must hold
            a one-record ``'sql_log'`` plus ``'cq_log'`` and ``'feedback_log'``
            lists. Mutated in place.
        log_name: file the whole log is pickled to (via ``save``) after the loop.
        rounds: maximum number of clarification turns per test.
        num_of_tests: number of leading test indices to process.
        model_name: 'gpt35turbo' or 'gpt4turbo' (question and SQL generation).
        vectorstore_feedback: vector store backing the semantic example selector.
        num_examples: ``k`` passed to the example selector.
        data_source: 'kaggle' or 'bird'.
        with_metadata: when true and evidence text is available, it opens the
            transcript as a ``user:`` line and is passed as ``metadata``.

    Side effects:
        Appends to the three logs of each entry and sets ``'num_cq_asked'`` to 0,
        ``turn + 1`` or ``"Failed"``.
    """
    assert model_name in ['gpt35turbo', 'gpt4turbo']
    # Fail fast instead of hitting an unbound local inside the loop.
    assert data_source in ['kaggle', 'bird']
    generation = GPT35_generation if model_name == 'gpt35turbo' else GPT4_turbo_generation

    feedback_example_selector = SemanticSimilarityExampleSelector(
        vectorstore=vectorstore_feedback,
        k=num_examples,
    )
    feedback_example_prompt = PromptTemplate(
        input_variables=['nl', 'gold', 'feedback'],
        template="\nExample Question: {nl}\nExample Feedback:{feedback}\nExample Answer: {gold}",
    )

    sql_generation_feedback_few_shot_prompt = FewShotPromptTemplate(
        example_selector=feedback_example_selector,
        example_prompt=feedback_example_prompt,
        suffix=sql_generation_v2,
        input_variables=["question", "schema", "sqls", "cqas", "metadata"],
    )

    def render_dialogue(turns):
        """Format the alternating [question, answer, question, ...] list as prompt text."""
        text = ""
        for pos, utterance in enumerate(turns):
            speaker = "multiple choice clarification question: " if pos % 2 == 0 else "user: "
            text += speaker + utterance + '\n'
        return text

    def repair_invalid_sql(entry, order, bad_sql, error, dbschema, nlq, gold, dbname, attempts):
        """Replace a non-executable attempt by one model-repaired query; return its evaluation flag."""
        attempts.pop(bad_sql, None)
        invalid_prompt = fix_invalid_v1.format(schema=dbschema, question=nlq,
                                               invalidSQL=bad_sql, ex=error[0])
        sql, pscore = generation(invalid_prompt)
        sql = clean_query(sql)
        entry['sql_log'].append((order, invalid_prompt, sql, pscore))
        attempts[sql] = None
        execution, _ = evalfunc(sql, gold, dbname, data_source)
        return execution

    for index in range(num_of_tests):
        if index in history_log and "num_cq_asked" in history_log[index]:
            # skip tests already seen
            continue
        assert index in history_log
        entry = history_log[index]
        assert len(entry['sql_log']) == 1

        d = data_frame.iloc[[index]]
        cqs_and_answers = []
        evidence = ''
        # Insertion-ordered de-duplicated attempts. A plain set of strings iterates
        # in a different order on every interpreter run, which made prompts unstable.
        attempts = {}
        if data_source == 'kaggle':
            gold = d['sql'].values[0]
            dbname = d['target_db'].values[0]
            nlq = d['nl'].values[0]
            dbschema = d['target_schema'].values[0]
        else:  # 'bird'
            gold = d['SQL'].values[0]
            nlq = d['question'].values[0]
            dbname = d['db_id'].values[0]
            dbschema = generate_db_schema(dbname)
            if with_metadata:
                evidence = d['evidence'].values[0]
                if not isinstance(evidence, str):
                    # Missing values arrive as NaN; string concatenation below would fail.
                    evidence = ''
        print("nl: ", nlq, index, dbname, evidence)

        order, _, sql_query, pscore = entry['sql_log'][0]
        order += 1
        sql_query = clean_query(sql_query)
        attempts[sql_query] = None
        execution, exception = evalfunc(sql_query, gold, dbname, data_source)

        if exception:
            execution = repair_invalid_sql(entry, order, sql_query, exception,
                                           dbschema, nlq, gold, dbname, attempts)
            order += 1

        if execution:
            entry['num_cq_asked'] = 0
            continue

        for turn in range(rounds):
            # Transcript for the question generator. The evidence line is added only
            # when there is evidence; an empty one used to produce a bare "user: "
            # line and hid the "no previous clarification question" sentinel.
            cqas = ""
            if with_metadata and evidence:
                cqas = 'user: ' + evidence + '\n'
            cqas += render_dialogue(cqs_and_answers)
            if cqas == "":
                cqas = "no previous clarification question.\n"
            cq_prompt = SRA_ES.format(schema=dbschema, question=nlq,
                                      sqls=";\n".join(attempts), cqs=cqas)
            cq_prompt = cq_prefix_v1 + cq_prompt
            cq, pscore = generation(cq_prompt)
            entry['cq_log'].append((order, cq_prompt, cq, pscore))
            order += 1
            if "NO AMBIGUITY" in cq:
                # Early stop: nothing left to ask. The test ends up as "Failed" below.
                break
            if "mul_choice_cq = " in cq:
                cq = cq.split("mul_choice_cq = ")[-1]

            # The answer is generated from a prompt that includes the gold SQL.
            feedback_prompt = feedback_v2.format(query=gold, question=cq, nlq=nlq)
            feedback_prompt = feedback_prefix_v1 + feedback_prompt
            feedback, pscore = GPT4o_generation(feedback_prompt)
            entry['feedback_log'].append((order, feedback_prompt, feedback, pscore))
            order += 1
            if "answer_to_cq =" in feedback:
                feedback = feedback.split("answer_to_cq =")[-1]
            cqs_and_answers.append(cq)
            cqs_and_answers.append(feedback)

            # fix incorrect sql based on user feedback; the transcript holds at
            # least one question/answer pair here, so it is never empty
            cqas = render_dialogue(cqs_and_answers)

            # use examples from the user study
            sql_prompt = sql_generation_feedback_few_shot_prompt.format(
                schema=dbschema, question=nlq,
                sqls=";\n".join(attempts), cqas=cqas, metadata=evidence)
            sql_prompt = "/* some examples are provided */\n" + sql_prompt
            sql_query, pscore = generation(sql_prompt)
            sql_query = clean_query(sql_query)
            entry['sql_log'].append((order, sql_prompt, sql_query, pscore))
            order += 1
            attempts[sql_query] = None
            execution, exception = evalfunc(sql_query, gold, dbname, data_source)
            if exception:
                execution = repair_invalid_sql(entry, order, sql_query, exception,
                                               dbschema, nlq, gold, dbname, attempts)
                order += 1
            if execution:
                entry['num_cq_asked'] = turn + 1
                break

        if 'num_cq_asked' not in entry:
            entry['num_cq_asked'] = "Failed"
    save(log_name, history_log)

In [None]:
def visualize_results(res_m1_zero, res_m1_few, res_m2_zero, res_m2_few, res_m3_zero, res_m3_few, test_subset=None):
    """Save the per-test results of three methods to JSON and plot them.

    Each ``res_*`` argument is a DataFrame with at least ``is_correct`` and
    ``rounds`` columns (the shape produced by ``log_to_df``), or an empty
    DataFrame / ``None`` when that method/mode combination was not run.

    Steps:
        1. Tag every row with ``Method`` (M1..M3), ``Mode`` (Zero/Few) and a
           ``Status`` label, and overwrite ``experiment_results.json`` in the
           working directory.
        2. Reload that file and draw: a grid with accuracy (top row) and average
           rounds (bottom row), a stacked correctness breakdown, and, when
           ``test_subset`` has a ``difficulty`` column, a pie chart of it.

    Only modes that actually have rows get a column of panels, so a run with
    nothing but few-shot results no longer shows blank zero-shot axes.

    Returns:
        None. Returns early, without plotting, when there are no rows at all.
    """
    print("\n--- Visualization of Results ---")

    # 1. Construct unified data structure & Save to JSON
    data_map = {
        'M1_Zero': res_m1_zero, 'M1_Few': res_m1_few,
        'M2_Zero': res_m2_zero, 'M2_Few': res_m2_few,
        'M3_Zero': res_m3_zero, 'M3_Few': res_m3_few
    }

    json_data = []
    for key, frame in data_map.items():
        # None is tolerated the same way as an empty frame: nothing to report.
        if frame is None or frame.empty:
            continue
        method, mode = key.split('_')
        records = frame.to_dict(orient='records')
        for r in records:
            r['Method'] = method
            r['Mode'] = mode
            # Status drives the stacked bar: correct with 0 rounds means the
            # initial query was already right.
            if not r['is_correct']:
                r['Status'] = 'Incorrect'
            elif r.get('rounds', 0) == 0:
                r['Status'] = 'Initial Correct'
            else:
                r['Status'] = 'Fixed Correct'
        json_data.extend(records)

    json_path = 'experiment_results.json'
    with open(json_path, 'w') as f:
        json.dump(json_data, f, indent=2, default=str)
    print(f"Results saved to {json_path}")

    # 2. Read from JSON for plotting (as requested)
    print(f"Reading results from {json_path} for plotting...")
    with open(json_path, 'r') as f:
        loaded_data = json.load(f)

    full_df = pd.DataFrame(loaded_data)

    if full_df.empty:
        print("No data to visualize.")
        return

    # Set Academic Style
    sns.set_theme(style="white")
    sns.set_context("paper", font_scale=1.2)

    # Define Palettes: M1=Blue, M2=Orange, M3=Green
    palette_dict = {'M1': '#4E79A7', 'M2': '#F28E2B', 'M3': '#59A14F'}

    # One column of panels per mode that has data (full_df is non-empty, so >= 1).
    modes = [m for m in ('Zero', 'Few') if (full_df['Mode'] == m).any()]
    n_modes = len(modes)

    # --- Figure 1: Accuracy (top row) & Rounds (bottom row) ---
    # squeeze=False keeps axes 2-D even when only one mode is present.
    fig1, axes1 = plt.subplots(2, n_modes, figsize=(7 * n_modes, 10), squeeze=False)
    fig1.suptitle('Performance Overview', fontsize=16, fontweight='bold')

    agg_df = full_df.groupby(['Method', 'Mode']).agg(
        Accuracy=('is_correct', 'mean'),
        Avg_Rounds=('rounds', 'mean')
    ).reset_index()

    for i, mode in enumerate(modes):
        subset = agg_df[agg_df['Mode'] == mode]

        # Top Row: Accuracy
        ax_acc = axes1[0, i]
        sns.barplot(data=subset, x='Method', y='Accuracy', palette=palette_dict, ax=ax_acc)
        ax_acc.set_title(f'Accuracy ({mode}-shot)')
        ax_acc.set_ylim(0, 1.1)
        ax_acc.set_ylabel('Success Rate' if i == 0 else '')
        for p in ax_acc.patches:
            ax_acc.annotate(f'{p.get_height():.2f}', (p.get_x() + p.get_width() / 2., p.get_height()),
                            ha='center', va='center', xytext=(0, 5), textcoords='offset points', fontweight='bold')

        # Bottom Row: Rounds
        ax_rds = axes1[1, i]
        sns.barplot(data=subset, x='Method', y='Avg_Rounds', palette=palette_dict, ax=ax_rds)
        ax_rds.set_title(f'Avg Interaction Rounds ({mode}-shot)')
        ax_rds.set_ylabel('Rounds' if i == 0 else '')
        # Add values on top and leave headroom above the tallest bar
        max_h = 0
        for p in ax_rds.patches:
            h = p.get_height()
            max_h = max(max_h, h)
            ax_rds.annotate(f'{h:.2f}', (p.get_x() + p.get_width() / 2., h),
                            ha='center', va='bottom', fontsize=10, fontweight='bold')
        ax_rds.set_ylim(0, max_h * 1.2 if max_h > 0 else 1)

    plt.tight_layout(rect=[0, 0.03, 1, 0.95])
    plt.show()

    # --- Figure 2: Stacked Bar Chart (Correctness Breakdown) ---
    status_order = ['Initial Correct', 'Fixed Correct', 'Incorrect']
    breakdown_list = []
    for (method, mode), group in full_df.groupby(['Method', 'Mode']):
        total = len(group)  # groupby never yields empty groups, so total > 0
        for status in status_order:
            share = (group['Status'] == status).sum() / total
            breakdown_list.append({'Method': method, 'Mode': mode, 'Type': status, 'Prop': share})

    breakdown_df = pd.DataFrame(breakdown_list)
    status_palette = {'Initial Correct': '#1f77b4', 'Fixed Correct': '#2ca02c', 'Incorrect': '#d62728'}

    fig2, axes2 = plt.subplots(1, n_modes, figsize=(7 * n_modes, 6), squeeze=False)
    fig2.suptitle('Correctness Breakdown', fontsize=16, fontweight='bold')

    for i, mode in enumerate(modes):
        ax = axes2[0, i]
        subset = breakdown_df[breakdown_df['Mode'] == mode]
        # reindex fixes the column order and fills any missing status with 0
        pivot_df = (
            subset.pivot(index='Method', columns='Type', values='Prop')
            .fillna(0)
            .reindex(columns=status_order, fill_value=0)
        )

        pivot_df.plot(kind='bar', stacked=True, color=[status_palette[c] for c in pivot_df.columns], ax=ax)
        ax.set_title(f'Breakdown ({mode}-shot)')
        ax.set_ylabel('Proportion' if i == 0 else '')
        ax.set_ylim(0, 1.0)
        if i == n_modes - 1:
            # A single shared legend, placed outside the right-most panel
            ax.legend(title='Status', bbox_to_anchor=(1.05, 1), loc='upper left')
        else:
            ax.get_legend().remove()

    plt.tight_layout(rect=[0, 0.03, 1, 0.95])
    plt.show()

    # --- Figure 3: Difficulty Distribution ---
    if test_subset is not None and 'difficulty' in test_subset.columns:
        fig3, ax3 = plt.subplots(figsize=(6, 6))
        diff_counts = test_subset['difficulty'].value_counts()
        colors_map = {'Hard': '#ff9999', 'Medium': '#ffff99', 'Simple': '#66b3ff'}
        # Unknown difficulty labels fall back to grey
        pie_colors = [colors_map.get(l, '#cccccc') for l in diff_counts.index]

        ax3.pie(diff_counts, labels=diff_counts.index, autopct='%1.1f%%', startangle=140, colors=pie_colors)
        ax3.set_title('Test Sample Difficulty Distribution', fontsize=14, fontweight='bold')
        plt.tight_layout()
        plt.show()


# --- Experiment configuration ---
source = 'kaggle'
modelname = 'gpt4turbo' 
k_shot = 3
with_metadata = True

# Define generation function
if modelname == 'gpt35turbo':
    generation = GPT35_generation
elif modelname == 'gpt4turbo':
    generation = GPT4_turbo_generation
else:
    # Printing and carrying on would leave `generation` undefined, or worse,
    # bound to whatever an earlier cell run left behind.
    raise ValueError(f"unsupported modelname {modelname!r}; expected 'gpt35turbo' or 'gpt4turbo'")

import copy
import pandas as pd

def log_to_df(history_log, max_rounds=4):
    """Flatten a history log into one row per test.

    Args:
        history_log: dict mapping a test id to its log entry. Only the entry's
            ``'num_cq_asked'`` value is read; a missing value is treated like
            ``'Failed'``.
        max_rounds: number of rounds recorded for failed tests. Defaults to 4,
            the value this notebook previously hard-coded.

    Returns:
        DataFrame with columns ``id``, ``rounds`` and ``is_correct``. An empty
        log gives an empty frame that still has these columns.
    """
    results = []
    for idx, data in history_log.items():
        rounds = data.get('num_cq_asked', 'Failed')
        is_correct = rounds != 'Failed'
        if not is_correct:
            # Failed tests are charged the full round budget
            rounds = max_rounds
        results.append({'id': idx, 'rounds': rounds, 'is_correct': is_correct})
    return pd.DataFrame(results, columns=['id', 'rounds', 'is_correct'])

# Helper to generate initial logs (replacing batch file reading)
def generate_initial_logs(dataframe, limit, generation_func, k_shot, with_metadata):
    """Create the starting history log by generating one SQL query per test.

    Args:
        dataframe: test cases with ``nl`` and ``target_schema`` columns, and
            optionally ``evidence``.
        limit: maximum number of leading rows to process; capped at
            ``len(dataframe)``.
        generation_func: callable taking a prompt and returning
            ``(sql_text, pscore)``.
        k_shot: accepted but not used by this function.
        with_metadata: when true, a row's ``evidence`` text (if any) is added
            to the prompt.

    Returns:
        dict mapping row position to ``{'sql_log': [[0, 'initial generation',
        sql, pscore]], 'cq_log': [], 'feedback_log': []}``. If generation or
        cleaning raises, the error is printed and the entry falls back to
        ``"SELECT * FROM table"`` with a pscore of 0.0.
    """
    logs = {}
    print(f"Generating initial SQLs for {limit} samples...")
    for index in range(min(limit, len(dataframe))):
        d = dataframe.iloc[index]
        nlq = d['nl']
        dbschema = d['target_schema']
        evidence = d.get('evidence', '') if with_metadata else ''
        if not isinstance(evidence, str):
            # Missing evidence arrives as NaN, which is truthy and used to be
            # written into the prompt as the text "nan".
            evidence = ''

        # Construct prompt (Simple Few-Shot or Zero-Shot)
        # Using a basic prompt structure similar to what might be expected
        prompt = f"/* Given the following database schema: */\n{dbschema}\n"
        if with_metadata and evidence:
            prompt += f"{evidence}\n"
        prompt += f"/* Answer the following: {nlq} */\nSELECT "

        # Generate; a failure for one row must not abort the whole batch
        try:
            sql_query, pscore = generation_func(prompt)
            sql_query = clean_query(sql_query)
        except Exception as e:
            print(f"Error generating for {index}: {e}")
            sql_query = "SELECT * FROM table" # Fallback
            pscore = 0.0

        logs[index] = {
            'sql_log': [[0, 'initial generation', sql_query, pscore]],
            'cq_log': [],
            'feedback_log': [],
        }
    return logs

# Main Execution Parameters
limit = 5 # Set to a small number for testing/demo. Increase to len(df) for full run.
# Never ask for more tests than df has rows: generate_initial_logs stops at len(df),
# while the method functions assert that every index below `limit` is in the log.
limit = min(limit, len(df))
# Turns allowed per test. log_to_df records failed tests as 4 rounds, so keep the two in step.
max_rounds = 4

# Generate Base Logs once
base_history_log = generate_initial_logs(df, limit, generation, k_shot, with_metadata)

# Each method works on its own deep copy, so all three start from the same initial SQL.

# Method 1: Baseline
print("Running Method 1 (Baseline)...")
log_m1 = copy.deepcopy(base_history_log)
baselineFewShot(df, log_m1, f'{with_metadata}_{source}_{modelname}_baseline.pkl', max_rounds, limit, modelname, userstudy_vectorstore, k_shot, source, with_metadata)
res_m1_few = log_to_df(log_m1)

# Method 2: Clarification Questions
print("Running Method 2 (Clarification)...")
log_m2 = copy.deepcopy(base_history_log)
askClarificationQuestions(df, log_m2, f'{with_metadata}_{source}_{modelname}_cq.pkl', max_rounds, limit, modelname, userstudy_vectorstore, k_shot, source, with_metadata)
res_m2_few = log_to_df(log_m2)

# Method 3: Break No Ambiguity
print("Running Method 3 (Break No Amb)...")
log_m3 = copy.deepcopy(base_history_log)
askCQsBreakNoAmb(df, log_m3, f'{with_metadata}_{source}_{modelname}_break_cq.pkl', max_rounds, limit, modelname, userstudy_vectorstore, k_shot, source, with_metadata)
res_m3_few = log_to_df(log_m3)

# Visualization (no zero-shot runs here, hence the empty frames)
visualize_results(pd.DataFrame(), res_m1_few, pd.DataFrame(), res_m2_few, pd.DataFrame(), res_m3_few)


Generating initial SQLs for 5 samples...
Running Method 1 (Baseline)...
Name the most popular release on houston. 0 WhatCDHipHop 
Which albums have been downloaded more than 100 times? 1 WhatCDHipHop 
Find me top 5 most popular releases after 2000? 2 WhatCDHipHop 
what release types are captured in this data set? 3 WhatCDHipHop 
which tags exist? 4 WhatCDHipHop 
Running Method 2 (Clarification)...
nl:  Name the most popular release on houston. 0 WhatCDHipHop 
nl:  Which albums have been downloaded more than 100 times? 1 WhatCDHipHop 
nl:  Find me top 5 most popular releases after 2000? 2 WhatCDHipHop 
nl:  what release types are captured in this data set? 3 WhatCDHipHop 
nl:  which tags exist? 4 WhatCDHipHop 
Running Method 3 (Break No Amb)...
nl:  Name the most popular release on houston. 0 WhatCDHipHop 
nl:  Which albums have been downloaded more than 100 times? 1 WhatCDHipHop 
nl:  Find me top 5 most popular releases after 2000? 2 WhatCDHipHop 
nl:  what release types are captured i

NameError: name 'visualize_results' is not defined

In [None]:

import matplotlib.pyplot as plt
import seaborn as sns
import json
import pandas as pd

def visualize_results(res_m1_zero, res_m1_few, res_m2_zero, res_m2_few, res_m3_zero, res_m3_few, test_subset=None):
    """Persist the experiment results to JSON and plot summary figures.

    The six result frames are tagged with their method (M1/M2/M3) and mode
    (Zero/Few), written to ``experiment_results.json`` in the current working
    directory, read back from that file, and plotted as:

    1. a 2x2 grid with accuracy (top row) and average interaction rounds
       (bottom row), one column per mode;
    2. a stacked bar chart splitting results into Initial Correct,
       Fixed Correct and Incorrect, one panel per mode;
    3. optionally, a pie chart of the ``difficulty`` column of ``test_subset``.

    Args:
        res_m1_zero, res_m1_few, res_m2_zero, res_m2_few, res_m3_zero,
        res_m3_few: Result DataFrames, one per method/mode combination. Empty
            frames are skipped. Non-empty frames must contain ``is_correct``;
            ``rounds`` is optional and treated as 0 when absent.
        test_subset: Optional DataFrame; Figure 3 is drawn only when it has a
            ``difficulty`` column.

    Returns:
        None. If every result frame is empty, only the (empty) JSON file is
        written and no figure is produced.
    """
    print("\n--- Visualization of Results ---")

    # 1. Construct unified data structure & Save to JSON
    data_map = {
        'M1_Zero': res_m1_zero, 'M1_Few': res_m1_few,
        'M2_Zero': res_m2_zero, 'M2_Few': res_m2_few,
        'M3_Zero': res_m3_zero, 'M3_Few': res_m3_few
    }
    json_data = _collect_result_records(data_map)

    json_path = 'experiment_results.json'
    with open(json_path, 'w') as f:
        # default=str stringifies any value json cannot serialise natively.
        json.dump(json_data, f, indent=2, default=str)
    print(f"Results saved to {json_path}")

    # 2. Read from JSON for plotting (as requested)
    print(f"Reading results from {json_path} for plotting...")
    with open(json_path, 'r') as f:
        loaded_data = json.load(f)

    full_df = pd.DataFrame(loaded_data)

    if full_df.empty:
        print("No data to visualize.")
        return

    # The Status logic treats a missing 'rounds' as 0; apply the same rule to
    # the column itself so the aggregation below neither raises KeyError (no
    # record had 'rounds') nor yields NaN bars (only some records had it).
    if 'rounds' not in full_df.columns:
        full_df['rounds'] = 0
    full_df['rounds'] = full_df['rounds'].fillna(0)

    # Set Academic Style
    sns.set_theme(style="white")
    sns.set_context("paper", font_scale=1.2)

    # Define Palettes: M1=Blue, M2=Orange, M3=Green
    palette_dict = {'M1': '#4E79A7', 'M2': '#F28E2B', 'M3': '#59A14F'}
    status_palette = {'Initial Correct': '#2ca02c', 'Fixed Correct': '#1f77b4', 'Incorrect': '#d62728'}
    modes = ['Zero', 'Few']

    _plot_performance_overview(full_df, modes, palette_dict)
    _plot_correctness_breakdown(full_df, modes, status_palette)
    _plot_difficulty_distribution(test_subset)


def _collect_result_records(data_map):
    """Flatten the per-configuration frames into one list of tagged records.

    Args:
        data_map: Mapping of ``'<Method>_<Mode>'`` keys (e.g. ``'M1_Few'``) to
            result DataFrames. Empty frames are skipped.

    Returns:
        A list of dicts, one per result row, each extended with ``Method``,
        ``Mode`` and ``Status`` keys.
    """
    collected = []
    for key, frame in data_map.items():
        if frame.empty:
            continue
        method, mode = key.split('_')
        rows = frame.to_dict(orient='records')
        for row in rows:
            row['Method'] = method
            row['Mode'] = mode
            # Status drives the stacked bar chart: a correct answer with zero
            # rounds counts as correct from the start, otherwise as fixed.
            if row['is_correct']:
                if row.get('rounds', 0) == 0:
                    row['Status'] = 'Initial Correct'
                else:
                    row['Status'] = 'Fixed Correct'
            else:
                row['Status'] = 'Incorrect'
        collected.extend(rows)
    return collected


def _plot_performance_overview(full_df, modes, palette_dict):
    """Draw Figure 1: accuracy (top row) and mean rounds (bottom row) per mode."""
    fig1, axes1 = plt.subplots(2, 2, figsize=(14, 10))
    fig1.suptitle('Performance Overview', fontsize=16, fontweight='bold')

    agg_df = full_df.groupby(['Method', 'Mode']).agg(
        Accuracy=('is_correct', 'mean'),
        Avg_Rounds=('rounds', 'mean')
    ).reset_index()

    for i, mode in enumerate(modes):
        subset = agg_df[agg_df['Mode'] == mode]
        if subset.empty:
            # No results for this mode: its column of panels is left blank.
            continue

        # Top Row: Accuracy
        ax_acc = axes1[0, i]
        sns.barplot(data=subset, x='Method', y='Accuracy', palette=palette_dict, ax=ax_acc)
        ax_acc.set_title(f'Accuracy ({mode}-shot)')
        ax_acc.set_ylim(0, 1.1)
        ax_acc.set_ylabel('Success Rate' if i == 0 else '')
        for p in ax_acc.patches:
            ax_acc.annotate(f'{p.get_height():.2f}', (p.get_x() + p.get_width() / 2., p.get_height()),
                            ha='center', va='center', xytext=(0, 5), textcoords='offset points', fontweight='bold')

        # Bottom Row: Rounds
        ax_rds = axes1[1, i]
        sns.barplot(data=subset, x='Method', y='Avg_Rounds', palette=palette_dict, ax=ax_rds)
        ax_rds.set_title(f'Avg Interaction Rounds ({mode}-shot)')
        ax_rds.set_ylabel('Rounds' if i == 0 else '')
        # Label each bar and track the tallest one to leave headroom for the text.
        max_h = 0
        for p in ax_rds.patches:
            h = p.get_height()
            if h > max_h:
                max_h = h
            ax_rds.annotate(f'{h:.2f}', (p.get_x() + p.get_width() / 2., h),
                            ha='center', va='bottom', fontsize=10, fontweight='bold')
        ax_rds.set_ylim(0, max_h * 1.2 if max_h > 0 else 1)

    plt.tight_layout(rect=[0, 0.03, 1, 0.95])
    plt.show()


def _plot_correctness_breakdown(full_df, modes, status_palette):
    """Draw Figure 2: stacked proportions of each Status per method, one panel per mode."""
    status_order = ['Initial Correct', 'Fixed Correct', 'Incorrect']

    breakdown_list = []
    for (method, mode), group in full_df.groupby(['Method', 'Mode']):
        total = len(group)
        for status in status_order:
            count = len(group[group['Status'] == status])
            breakdown_list.append({'Method': method, 'Mode': mode, 'Type': status, 'Prop': count / total})

    breakdown_df = pd.DataFrame(breakdown_list)
    if breakdown_df.empty:
        return

    fig2, axes2 = plt.subplots(1, 2, figsize=(14, 6))
    fig2.suptitle('Correctness Breakdown', fontsize=16, fontweight='bold')

    for i, mode in enumerate(modes):
        ax = axes2[i]
        subset = breakdown_df[breakdown_df['Mode'] == mode]
        if subset.empty:
            continue
        pivot_df = subset.pivot(index='Method', columns='Type', values='Prop').fillna(0)
        # Ensure all columns exist and appear in a fixed stacking order.
        for col in status_order:
            if col not in pivot_df.columns:
                pivot_df[col] = 0
        pivot_df = pivot_df[status_order]

        pivot_df.plot(kind='bar', stacked=True, color=[status_palette[c] for c in pivot_df.columns], ax=ax)
        ax.set_title(f'Breakdown ({mode}-shot)')
        ax.set_ylabel('Proportion' if i == 0 else '')
        ax.set_ylim(0, 1.0)
        if i == 1:
            # Single legend for the figure, placed outside the right-hand panel.
            ax.legend(title='Status', bbox_to_anchor=(1.05, 1), loc='upper left')
        else:
            ax.get_legend().remove()

    plt.tight_layout(rect=[0, 0.03, 1, 0.95])
    plt.show()


def _plot_difficulty_distribution(test_subset):
    """Draw Figure 3: pie chart of ``test_subset['difficulty']``, if that column exists."""
    if test_subset is None or 'difficulty' not in test_subset.columns:
        return

    plt.figure(figsize=(6, 6))
    diff_counts = test_subset['difficulty'].value_counts()
    colors_map = {'Hard': '#ff9999', 'Medium': '#ffff99', 'Simple': '#66b3ff'}
    # Labels outside colors_map fall back to grey.
    pie_colors = [colors_map.get(l, '#cccccc') for l in diff_counts.index]

    plt.pie(diff_counts, labels=diff_counts.index, autopct='%1.1f%%', startangle=140, colors=pie_colors)
    plt.title('Test Sample Difficulty Distribution', fontsize=14, fontweight='bold')
    plt.tight_layout()
    plt.show()


In [None]:
# Example Usage:
# visualize_results(res_m1_zero, res_m1_few, res_m2_zero, res_m2_few, res_m3_zero, res_m3_few, test_subset=test_subset)


In [None]:

import matplotlib.pyplot as plt
import seaborn as sns
import json
import pandas as pd

def visualize_results(res_m1_zero, res_m1_few, res_m2_zero, res_m2_few, res_m3_zero, res_m3_few, test_subset=None):
    """Save the experiment results to JSON, reload them, and plot summary figures.

    NOTE(review): an earlier cell defines a function with the same name; once
    this cell runs, this definition is the one in effect. Consider keeping a
    single definition in the notebook.

    Side effects: writes ``experiment_results.json`` to the current working
    directory and shows up to three matplotlib figures (performance overview,
    correctness breakdown, and -- when ``test_subset`` has a ``difficulty``
    column -- a difficulty pie chart).

    Args:
        res_m1_zero, res_m1_few, res_m2_zero, res_m2_few, res_m3_zero,
        res_m3_few: Result DataFrames for methods M1-M3 in Zero/Few mode.
            Empty frames are ignored. Non-empty frames must provide
            ``is_correct``; ``rounds`` is optional and counted as 0 if absent.
        test_subset: Optional DataFrame used only for the difficulty pie chart.

    Returns:
        None. Returns before plotting when all result frames are empty.
    """
    print("\n--- Visualization of Results ---")

    status_names = ['Initial Correct', 'Fixed Correct', 'Incorrect']

    # 1. Tag every result row with its method/mode and a Status label.
    frames_by_key = {
        'M1_Zero': res_m1_zero, 'M1_Few': res_m1_few,
        'M2_Zero': res_m2_zero, 'M2_Few': res_m2_few,
        'M3_Zero': res_m3_zero, 'M3_Few': res_m3_few
    }

    all_records = []
    for key, frame in frames_by_key.items():
        if frame.empty:
            continue
        method, mode = key.split('_')
        for row in frame.to_dict(orient='records'):
            row['Method'] = method
            row['Mode'] = mode
            # Correct with zero rounds -> correct from the start; correct
            # after at least one round -> fixed; otherwise incorrect.
            if not row['is_correct']:
                row['Status'] = 'Incorrect'
            elif row.get('rounds', 0) == 0:
                row['Status'] = 'Initial Correct'
            else:
                row['Status'] = 'Fixed Correct'
            all_records.append(row)

    json_path = 'experiment_results.json'
    with open(json_path, 'w') as out_file:
        # default=str stringifies any value json cannot serialise natively.
        json.dump(all_records, out_file, indent=2, default=str)
    print(f"Results saved to {json_path}")

    # 2. Plot from the file that was just written, not from the in-memory list.
    print(f"Reading results from {json_path} for plotting...")
    with open(json_path, 'r') as in_file:
        results_df = pd.DataFrame(json.load(in_file))

    if results_df.empty:
        print("No data to visualize.")
        return

    # Missing 'rounds' is already read as 0 when computing Status; make the
    # column agree so the mean below cannot raise KeyError (column absent in
    # every record) or come out as NaN (absent for some methods only).
    if 'rounds' not in results_df.columns:
        results_df['rounds'] = 0
    results_df['rounds'] = results_df['rounds'].fillna(0)

    # Set Academic Style
    sns.set_theme(style="white")
    sns.set_context("paper", font_scale=1.2)

    # Define Palettes: M1=Blue, M2=Orange, M3=Green
    palette_dict = {'M1': '#4E79A7', 'M2': '#F28E2B', 'M3': '#59A14F'}
    modes = ['Zero', 'Few']

    # --- Figure 1: 2x2 Grid (Accuracy & Rounds) ---
    fig1, axes1 = plt.subplots(2, 2, figsize=(14, 10))
    fig1.suptitle('Performance Overview', fontsize=16, fontweight='bold')

    summary = results_df.groupby(['Method', 'Mode']).agg(
        Accuracy=('is_correct', 'mean'),
        Avg_Rounds=('rounds', 'mean')
    ).reset_index()

    for col_idx, mode in enumerate(modes):
        mode_summary = summary[summary['Mode'] == mode]
        if mode_summary.empty:
            # Nothing recorded for this mode; both of its panels stay blank.
            continue
        first_column = col_idx == 0

        # Top Row: Accuracy
        acc_ax = axes1[0, col_idx]
        sns.barplot(data=mode_summary, x='Method', y='Accuracy', palette=palette_dict, ax=acc_ax)
        acc_ax.set_title(f'Accuracy ({mode}-shot)')
        acc_ax.set_ylim(0, 1.1)
        acc_ax.set_ylabel('Success Rate' if first_column else '')
        for bar in acc_ax.patches:
            bar_height = bar.get_height()
            acc_ax.annotate(f'{bar_height:.2f}', (bar.get_x() + bar.get_width() / 2., bar_height),
                            ha='center', va='center', xytext=(0, 5), textcoords='offset points', fontweight='bold')

        # Bottom Row: Rounds
        rounds_ax = axes1[1, col_idx]
        sns.barplot(data=mode_summary, x='Method', y='Avg_Rounds', palette=palette_dict, ax=rounds_ax)
        rounds_ax.set_title(f'Avg Interaction Rounds ({mode}-shot)')
        rounds_ax.set_ylabel('Rounds' if first_column else '')
        # Label each bar and remember the tallest to leave headroom for labels.
        tallest = 0
        for bar in rounds_ax.patches:
            bar_height = bar.get_height()
            if bar_height > tallest:
                tallest = bar_height
            rounds_ax.annotate(f'{bar_height:.2f}', (bar.get_x() + bar.get_width() / 2., bar_height),
                               ha='center', va='bottom', fontsize=10, fontweight='bold')
        rounds_ax.set_ylim(0, tallest * 1.2 if tallest > 0 else 1)

    plt.tight_layout(rect=[0, 0.03, 1, 0.95])
    plt.show()

    # --- Figure 2: Stacked Bar Chart (Correctness Breakdown) ---
    shares = []
    for (method, mode), group in results_df.groupby(['Method', 'Mode']):
        group_size = len(group)
        for status in status_names:
            hits = len(group[group['Status'] == status])
            shares.append({'Method': method, 'Mode': mode, 'Type': status, 'Prop': hits / group_size})

    breakdown_df = pd.DataFrame(shares)
    status_palette = {'Initial Correct': '#1f77b4', 'Fixed Correct': '#2ca02c', 'Incorrect': '#d62728'}

    if not breakdown_df.empty:
        fig2, axes2 = plt.subplots(1, 2, figsize=(14, 6))
        fig2.suptitle('Correctness Breakdown', fontsize=16, fontweight='bold')

        for col_idx, mode in enumerate(modes):
            panel = axes2[col_idx]
            mode_shares = breakdown_df[breakdown_df['Mode'] == mode]
            if mode_shares.empty:
                continue

            stacked = mode_shares.pivot(index='Method', columns='Type', values='Prop').fillna(0)
            # Ensure all columns exist, then fix the stacking order.
            for status in status_names:
                if status not in stacked.columns:
                    stacked[status] = 0
            stacked = stacked[status_names]

            stacked.plot(kind='bar', stacked=True, color=[status_palette[c] for c in stacked.columns], ax=panel)
            panel.set_title(f'Breakdown ({mode}-shot)')
            panel.set_ylabel('Proportion' if col_idx == 0 else '')
            panel.set_ylim(0, 1.0)
            if col_idx == 1:
                # One shared legend, placed outside the right-hand panel.
                panel.legend(title='Status', bbox_to_anchor=(1.05, 1), loc='upper left')
            else:
                panel.get_legend().remove()

        plt.tight_layout(rect=[0, 0.03, 1, 0.95])
        plt.show()

    # --- Figure 3: Difficulty Distribution ---
    if test_subset is not None and 'difficulty' in test_subset.columns:
        plt.figure(figsize=(6, 6))
        difficulty_counts = test_subset['difficulty'].value_counts()
        colors_map = {'Hard': '#ff9999', 'Medium': '#ffff99', 'Simple': '#66b3ff'}
        # Labels outside colors_map fall back to grey.
        pie_colors = [colors_map.get(label, '#cccccc') for label in difficulty_counts.index]

        plt.pie(difficulty_counts, labels=difficulty_counts.index, autopct='%1.1f%%', startangle=140, colors=pie_colors)
        plt.title('Test Sample Difficulty Distribution', fontsize=14, fontweight='bold')
        plt.tight_layout()
        plt.show()


In [None]:
# Example Usage:
# NOTE(review): this call assumes res_m*_zero, res_m*_few and test_subset all
# exist in the kernel. The run cell visible earlier only creates the *_few
# frames and passes pd.DataFrame() for the zero-shot slots -- confirm the
# zero-shot frames and test_subset are produced elsewhere before running this.
visualize_results(res_m1_zero, res_m1_few, res_m2_zero, res_m2_few, res_m3_zero, res_m3_few, test_subset=test_subset)
