In [1]:
import pandas as pd
import time
import openai
import os
import sys


In [2]:
schema_linking_prompt = '''
Table all_star, columns = [*, player_id, year, game_num, game_id, team_id, league_id, gp, starting_pos]
Table appearances, columns = [*, year, team_id, league_id, player_id, g_all, gs, g_batting, g_defense, g_p, g_c, g_1b, g_2b, g_3b, g_ss, g_lf, g_cf, g_rf, g_of, g_dh, g_ph, g_pr]
Table manager_award, columns = [*, player_id, award_id, year, league_id, tie, notes]
Table player_award, columns = [*, player_id, award_id, year, league_id, tie, notes]
Table manager_award_vote, columns = [*, award_id, year, league_id, player_id, points_won, points_max, votes_first]
Table player_award_vote, columns = [*, award_id, year, league_id, player_id, points_won, points_max, votes_first]
Table batting, columns = [*, player_id, year, stint, team_id, league_id, g, ab, r, h, double, triple, hr, rbi, sb, cs, bb, so, ibb, hbp, sh, sf, g_idp]
Table batting_postseason, columns = [*, year, round, player_id, team_id, league_id, g, ab, r, h, double, triple, hr, rbi, sb, cs, bb, so, ibb, hbp, sh, sf, g_idp]
Table player_college, columns = [*, player_id, college_id, year]
Table fielding, columns = [*, player_id, year, stint, team_id, league_id, pos, g, gs, inn_outs, po, a, e, dp, pb, wp, sb, cs, zr]
Table fielding_outfield, columns = [*, player_id, year, stint, glf, gcf, grf]
Table fielding_postseason, columns = [*, player_id, year, team_id, league_id, round, pos, g, gs, inn_outs, po, a, e, dp, tp, pb, sb, cs]
Table hall_of_fame, columns = [*, player_id, yearid, votedby, ballots, needed, votes, inducted, category, needed_note]
Table home_game, columns = [*, year, league_id, team_id, park_id, span_first, span_last, games, openings, attendance]
Table manager, columns = [*, player_id, year, team_id, league_id, inseason, g, w, l, rank, plyr_mgr]
Table manager_half, columns = [*, player_id, year, team_id, league_id, inseason, half, g, w, l, rank]
Table player, columns = [*, player_id, birth_year, birth_month, birth_day, birth_country, birth_state, birth_city, death_year, death_month, death_day, death_country, death_state, death_city, name_first, name_last, name_given, weight, height, bats, throws, debut, final_game, retro_id, bbref_id]
Table park, columns = [*, park_id, park_name, park_alias, city, state, country]
Table pitching, columns = [*, player_id, year, stint, team_id, league_id, w, l, g, gs, cg, sho, sv, ipouts, h, er, hr, bb, so, baopp, era, ibb, wp, hbp, bk, bfp, gf, r, sh, sf, g_idp]
Table pitching_postseason, columns = [*, player_id, year, round, team_id, league_id, w, l, g, gs, cg, sho, sv, ipouts, h, er, hr, bb, so, baopp, era, ibb, wp, hbp, bk, bfp, gf, r, sh, sf, g_idp]
Table salary, columns = [*, year, team_id, league_id, player_id, salary]
Table college, columns = [*, college_id, name_full, city, state, country]
Table postseason, columns = [*, year, round, team_id_winner, league_id_winner, team_id_loser, league_id_loser, wins, losses, ties]
Table team, columns = [*, year, league_id, team_id, franchise_id, div_id, rank, g, ghome, w, l, div_win, wc_win, lg_win, ws_win, r, ab, h, double, triple, hr, bb, so, sb, cs, hbp, sf, ra, er, era, cg, sho, sv, ipouts, ha, hra, bba, soa, e, dp, fp, name, park, attendance, bpf, ppf, team_id_br, team_id_lahman45, team_id_retro]
Table team_franchise, columns = [*, franchise_id, franchise_name, active, na_assoc]
Table team_half, columns = [*, year, league_id, team_id, half, div_id, div_win, rank, g, w, l]
Foreign_keys = [
  all_star.player_id = player.player_id,
  all_star.team_id = team.team_id,
  all_star.league_id = team.league_id,
  all_star.year = team.year,

  appearances.player_id = player.player_id,
  appearances.team_id = team.team_id,
  appearances.league_id = team.league_id,
  appearances.year = team.year,

  manager_award.player_id = manager.player_id,
  manager_award.league_id = league.league_id,
  manager_award.year = manager.year,

  player_award.player_id = player.player_id,
  player_award.league_id = league.league_id,
  player_award.year = player.year,

  manager_award_vote.player_id = manager.player_id,
  manager_award_vote.league_id = league.league_id,
  manager_award_vote.year = manager.year,

  player_award_vote.player_id = player.player_id,
  player_award_vote.league_id = league.league_id,
  player_award_vote.year = player.year,

  batting.player_id = player.player_id,
  batting.team_id = team.team_id,
  batting.league_id = team.league_id,
  batting.year = team.year,

  batting_postseason.player_id = player.player_id,
  batting_postseason.team_id = team.team_id,
  batting_postseason.league_id = team.league_id,
  batting_postseason.year = team.year,

  player_college.player_id = player.player_id,
  player_college.college_id = college.college_id,

  fielding.player_id = player.player_id,
  fielding.team_id = team.team_id,
  fielding.league_id = team.league_id,
  fielding.year = team.year,

  fielding_outfield.player_id = player.player_id,
  fielding_outfield.year = fielding.year,

  fielding_postseason.player_id = player.player_id,
  fielding_postseason.team_id = team.team_id,
  fielding_postseason.league_id = team.league_id,
  fielding_postseason.year = team.year,

  hall_of_fame.player_id = player.player_id,

  home_game.team_id = team.team_id,
  home_game.league_id = team.league_id,
  home_game.year = team.year,
  home_game.park_id = park.park_id,

  manager.player_id = player.player_id,
  manager.team_id = team.team_id,
  manager.league_id = team.league_id,
  manager.year = team.year,

  manager_half.player_id = player.player_id,
  manager_half.team_id = team.team_id,
  manager_half.league_id = team.league_id,
  manager_half.year = team.year,

  salary.player_id = player.player_id,
  salary.team_id = team.team_id,
  salary.league_id = team.league_id,
  salary.year = team.year,

  team.franchise_id = team_franchise.franchise_id,
  team.league_id = league.league_id,
  team.year = league.year,

  team_half.team_id = team.team_id,
  team_half.league_id = team.league_id,
  team_half.year = team.year
]
    {"Q": "What is the full name and id of the college with the largest number of baseball players?",
    "A": "Let’s think step by step. In the question 'What is the full name and id of the college with the largest number of baseball players?', 
    we are asked:\n'full name and id of the college' 
    so we need column = [college.name_full, college.college_id]\n
    'largest number of baseball players' so we need column = [player_college.college_id, player_college.player_id]\nBased on the columns and tables, we need these Foreign_keys = [player_college.college_id = college.college_id].\nBased on the tables, columns, and Foreign_keys, The set of possible cell values are = ['baseball']. So the Schema_links are:\nSchema_links: [college.name_full, college.college_id, player_college.college_id, player_college.player_id, 'baseball']"}

    {
    "Q": "Find the full name and id of the college that has the most baseball players.",
    "A": "Let’s think step by step. In the question 'Find the full name and id of the college that has the most baseball players.', we are asked:\n'full name and id of the college' so we need column = [college.name_full, college.college_id]\n'has the most baseball players' so we need column = [player_college.college_id, player_college.player_id]\nBased on the columns and tables, we need these Foreign_keys = [player_college.college_id = college.college_id].\nBased on the tables, columns, and Foreign_keys, The set of possible cell values are = ['baseball']. So the Schema_links are:\nSchema_links: [college.name_full, college.college_id, player_college.college_id, player_college.player_id, 'baseball']"
  },
  {
    "Q": "What is the average salary of the players in the team named 'Boston Red Stockings'?",
    "A": "Let’s think step by step. In the question 'What is the average salary of the players in the team named 'Boston Red Stockings'?', we are asked:\n'average salary' so we need column = [salary.salary]\n'players in the team named 'Boston Red Stockings'' so we need column = [team.name]\nBased on the columns and tables, we need these Foreign_keys = [salary.player_id = player.player_id, player.team_id = team.team_id].\nBased on the tables, columns, and Foreign_keys, The set of possible cell values are = ['Boston Red Stockings']. So the Schema_links are:\nSchema_links: [salary.salary, salary.player_id = player.player_id, player.team_id = team.team_id, team.name, 'Boston Red Stockings']"
  },
  {
    "Q": "Compute the average salary of the players in the team called 'Boston Red Stockings'.",
    "A": "Let’s think step by step. In the question 'Compute the average salary of the players in the team called 'Boston Red Stockings'.', we are asked:\n'average salary' so we need column = [salary.salary]\n'players in the team called 'Boston Red Stockings'' so we need column = [team.name]\nBased on the columns and tables, we need these Foreign_keys = [salary.player_id = player.player_id, player.team_id = team.team_id].\nBased on the tables, columns, and Foreign_keys, The set of possible cell values are = ['Boston Red Stockings']. So the Schema_links are:\nSchema_links: [salary.salary, salary.player_id = player.player_id, player.team_id = team.team_id, team.name, 'Boston Red Stockings']"
  },
  {
    "Q": "What are the first and last names of players participating in the all-star game in 1998?",
    "A": "Let’s think step by step. In the question 'What are the first and last names of players participating in the all-star game in 1998?', we are asked:\n'first and last names of players' so we need column = [player.name_first, player.name_last]\n'participating in the all-star game in 1998' so we need column = [all_star.year]\nBased on the columns and tables, we need these Foreign_keys = [all_star.player_id = player.player_id].\nBased on the tables, columns, and Foreign_keys, The set of possible cell values are = [1998]. So the Schema_links are:\nSchema_links: [player.name_first, player.name_last, all_star.player_id = player.player_id, all_star.year, 1998]"
  },
  {
    "Q": "List the first and last names for players who participated in the all-star game in 1998.",
    "A": "Let’s think step by step. In the question 'List the first and last names for players who participated in the all-star game in 1998.', we are asked:\n'first and last names for players' so we need column = [player.name_first, player.name_last]\n'participated in the all-star game in 1998' so we need column = [all_star.year]\nBased on the columns and tables, we need these Foreign_keys = [all_star.player_id = player.player_id].\nBased on the tables, columns, and Foreign_keys, The set of possible cell values are = [1998]. So the Schema_links are:\nSchema_links: [player.name_first, player.name_last, all_star.player_id = player.player_id, all_star.year, 1998]"
  }
'''

In [3]:
classification_prompt  ='''
Q: "What is the full name and ID of the college with the largest number of baseball players?" schema_links: [college.name_full, college.college_id, player_college.player_id, player_college.college_id, baseball] A: Let’s think step by step. The SQL query for the question "What is the full name and ID of the college with the largest number of baseball players?" needs these tables = [college, player_college], so we need JOIN. Plus, it doesn't require nested queries with (INTERSECT, UNION, EXCEPT, IN, NOT IN), and we need the answer to the questions = [""]. So, we need JOIN and don't need nested queries, then the the SQL query can be classified as "NON-NESTED". Label: "NON-NESTED"

Q: "Find the full name and ID of the college that has the most baseball players." schema_links: [college.name_full, college.college_id, player_college.player_id, player_college.college_id, baseball] A: Let’s think step by step. The SQL query for the question "Find the full name and ID of the college that has the most baseball players." needs these tables = [college, player_college], so we need JOIN. Plus, it doesn't require nested queries with (INTERSECT, UNION, EXCEPT, IN, NOT IN), and we need the answer to the questions = [""]. So, we need JOIN and don't need nested queries, then the the SQL query can be classified as "NON-NESTED". Label: "NON-NESTED"

Q: "What is the average salary of the players in the team named 'Boston Red Stockings'?" schema_links: [salary.salary, team.name, team.team_id, salary.team_id, 'Boston Red Stockings'] A: Let’s think step by step. The SQL query for the question "What is the average salary of the players in the team named 'Boston Red Stockings'?" needs these tables = [salary, team], so we need JOIN. Plus, it doesn't require nested queries with (INTERSECT, UNION, EXCEPT, IN, NOT IN), and we need the answer to the questions = [""]. So, we need JOIN and don't need nested queries, then the the SQL query can be classified as "NON-NESTED". Label: "NON-NESTED"

Q: "Compute the average salary of the players in the team called 'Boston Red Stockings'." schema_links: [salary.salary, team.name, team.team_id, salary.team_id, 'Boston Red Stockings'] A: Let’s think step by step. The SQL query for the question "Compute the average salary of the players in the team called 'Boston Red Stockings'." needs these tables = [salary, team], so we need JOIN. Plus, it doesn't require nested queries with (INTERSECT, UNION, EXCEPT, IN, NOT IN), and we need the answer to the questions = [""]. So, we need JOIN and don't need nested queries, then the the SQL query can be classified as "NON-NESTED". Label: "NON-NESTED"

Q: "What are the first and last names of players participating in the all star game in 1998?" schema_links: [player.name_first, player.name_last, all_star.player_id, player.player_id, 1998] A: Let’s think step by step. The SQL query for the question "What are the first and last names of players participating in the all star game in 1998?" needs these tables = [player, all_star], so we need JOIN. Plus, it doesn't require nested queries with (INTERSECT, UNION, EXCEPT, IN, NOT IN), and we need the answer to the questions = [""]. So, we need JOIN and don't need nested queries, then the the SQL query can be classified as "NON-NESTED". Label: "NON-NESTED"

Q: "List the first and last name for players who participated in all star game in 1998." schema_links: [player.name_first, player.name_last, all_star.player_id, player.player_id, 1998] A: Let’s think step by step. The SQL query for the question "List the first and last name for players who participated in all star game in 1998." needs these tables = [player, all_star], so we need JOIN. Plus, it doesn't require nested queries with (INTERSECT, UNION, EXCEPT, IN, NOT IN), and we need the answer to the questions = [""]. So, we need JOIN and don't need nested queries, then the the SQL query can be classified as "NON-NESTED". Label: "NON-NESTED"
'''

In [4]:
easy_prompt = '''  
Q: "What is the full name and ID of the college with the largest number of baseball players?" Schema_links: [college.name_full, college.college_id, player_college.player_id, player_college.college_id, baseball] SQL: SELECT college.name_full, college.college_id FROM college JOIN player_college ON college.college_id = player_college.college_id WHERE player_college.team = 'baseball' GROUP BY college.college_id ORDER BY COUNT(player_college.player_id) DESC LIMIT 1

Q: "Find the full name and ID of the college that has the most baseball players." Schema_links: [college.name_full, college.college_id, player_college.player_id, player_college.college_id, baseball] SQL: SELECT college.name_full, college.college_id FROM college JOIN player_college ON college.college_id = player_college.college_id WHERE player_college.team = 'baseball' GROUP BY college.college_id ORDER BY COUNT(player_college.player_id) DESC LIMIT 1

Q: "What is the average salary of the players in the team named 'Boston Red Stockings'?" Schema_links: [salary.salary, team.name, team.team_id, salary.team_id, 'Boston Red Stockings'] SQL: SELECT AVG(salary.salary) FROM salary JOIN team ON salary.team_id = team.team_id WHERE team.name = 'Boston Red Stockings'

Q: "Compute the average salary of the players in the team called 'Boston Red Stockings'." Schema_links: [salary.salary, team.name, team.team_id, salary.team_id, 'Boston Red Stockings'] SQL: SELECT AVG(salary.salary) FROM salary JOIN team ON salary.team_id = team.team_id WHERE team.name = 'Boston Red Stockings'

Q: "What are the first and last names of players participating in the all star game in 1998?" Schema_links: [player.name_first, player.name_last, all_star.player_id, player.player_id, 1998] SQL: SELECT player.name_first, player.name_last FROM player JOIN all_star ON player.player_id = all_star.player_id WHERE all_star.year = 1998

Q: "List the first and last name for players who participated in all star game in 1998." Schema_links: [player.name_first, player.name_last, all_star.player_id, player.player_id, 1998] SQL: SELECT player.name_first, player.name_last FROM player JOIN all_star ON player.player_id = all_star.player_id WHERE all_star.year = 1998

'''

In [5]:
medium_prompt = '''
Q: "What is the full name and ID of the college with the largest number of baseball players?" Schema_links: [college.name_full, college.college_id, player_college.player_id, player_college.college_id, baseball] A: Let’s think step by step. For creating the SQL for the given question, we need to join these tables = [college, player_college]. First, create an intermediate representation, then use it to construct the SQL query.
Intermediate_representation: SELECT college.name_full, college.college_id FROM college JOIN player_college ON college.college_id = player_college.college_id WHERE player_college.team = 'baseball' GROUP BY college.college_id ORDER BY COUNT(player_college.player_id) DESC LIMIT 1
SQL: SELECT college.name_full, college.college_id FROM college JOIN player_college ON college.college_id = player_college.college_id WHERE player_college.team = 'baseball' GROUP BY college.college_id ORDER BY COUNT(player_college.player_id) DESC LIMIT 1

Q: "Find the full name and ID of the college that has the most baseball players." Schema_links: [college.name_full, college.college_id, player_college.player_id, player_college.college_id, baseball] A: Let’s think step by step. For creating the SQL for the given question, we need to join these tables = [college, player_college]. First, create an intermediate representation, then use it to construct the SQL query.
Intermediate_representation: SELECT college.name_full, college.college_id FROM college JOIN player_college ON college.college_id = player_college.college_id WHERE player_college.team = 'baseball' GROUP BY college.college_id ORDER BY COUNT(player_college.player_id) DESC LIMIT 1
SQL: SELECT college.name_full, college.college_id FROM college JOIN player_college ON college.college_id = player_college.college_id WHERE player_college.team = 'baseball' GROUP BY college.college_id ORDER BY COUNT(player_college.player_id) DESC LIMIT 1

Q: "What is the average salary of the players in the team named 'Boston Red Stockings'?" Schema_links: [salary.salary, team.name, team.team_id, salary.team_id, 'Boston Red Stockings'] A: Let’s think step by step. For creating the SQL for the given question, we need to join these tables = [salary, team]. First, create an intermediate representation, then use it to construct the SQL query.
Intermediate_representation: SELECT AVG(salary.salary) FROM salary JOIN team ON salary.team_id = team.team_id WHERE team.name = 'Boston Red Stockings'
SQL: SELECT AVG(salary.salary) FROM salary JOIN team ON salary.team_id = team.team_id WHERE team.name = 'Boston Red Stockings'

Q: "Compute the average salary of the players in the team called 'Boston Red Stockings'." Schema_links: [salary.salary, team.name, team.team_id, salary.team_id, 'Boston Red Stockings'] A: Let’s think step by step. For creating the SQL for the given question, we need to join these tables = [salary, team]. First, create an intermediate representation, then use it to construct the SQL query.
Intermediate_representation: SELECT AVG(salary.salary) FROM salary JOIN team ON salary.team_id = team.team_id WHERE team.name = 'Boston Red Stockings'
SQL: SELECT AVG(salary.salary) FROM salary JOIN team ON salary.team_id = team.team_id WHERE team.name = 'Boston Red Stockings'

Q: "What are the first and last names of players participating in the all star game in 1998?" Schema_links: [player.name_first, player.name_last, all_star.player_id, player.player_id, 1998] A: Let’s think step by step. For creating the SQL for the given question, we need to join these tables = [player, all_star]. First, create an intermediate representation, then use it to construct the SQL query.
Intermediate_representation: SELECT player.name_first, player.name_last FROM player JOIN all_star ON player.player_id = all_star.player_id WHERE all_star.year = 1998
SQL: SELECT player.name_first, player.name_last FROM player JOIN all_star ON player.player_id = all_star.player_id WHERE all_star.year = 1998

Q: "List the first and last name for players who participated in all star game in 1998." Schema_links: [player.name_first, player.name_last, all_star.player_id, player.player_id, 1998] A: Let’s think step by step. For creating the SQL for the given question, we need to join these tables = [player, all_star]. First, create an intermediate representation, then use it to construct the SQL query.
Intermediate_representation: SELECT player.name_first, player.name_last FROM player JOIN all_star ON player.player_id = all_star.player_id WHERE all_star.year = 1998
SQL: SELECT player.name_first, player.name_last FROM player JOIN all_star ON player.player_id = all_star.player_id WHERE all_star.year = 1998
'''

In [6]:
hard_prompt = '''
Q: "What is the average salary of the players in the team named 'Boston Red Stockings'?" Schema_links: [salary.salary, team.name, team.team_id, salary.team_id, 'Boston Red Stockings'] 
A: Let’s think step by step. "What is the average salary of the players in the team named 'Boston Red Stockings'?" can be solved by knowing the answer to the following sub-question: "What is the average salary for players in the team 'Boston Red Stockings'?" The SQL query for the sub-question "What is the average salary for players in the team 'Boston Red Stockings'?" is SELECT AVG(salary.salary) FROM salary JOIN team ON salary.team_id = team.team_id WHERE team.name = 'Boston Red Stockings' So, the answer to the question "What is the average salary of the players in the team named 'Boston Red Stockings'?" is = Intermediate_representation: select avg(salary.salary) from salary join team on salary.team_id = team.team_id where team.name = 'Boston Red Stockings' SQL: SELECT AVG(salary.salary) FROM salary JOIN team ON salary.team_id = team.team_id WHERE team.name = 'Boston Red Stockings'

Q: "What are the first and last names of players participating in the all star game in 1998?" Schema_links: [player.name_first, player.name_last, all_star.player_id, player.player_id, 1998] A: Let’s think step by step. "What are the first and last names of players participating in the all star game in 1998?" can be solved by knowing the answer to the following sub-question: "What are the first and last names of players who participated in the all star game in 1998?" The SQL query for the sub-question "What are the first and last names of players who participated in the all star game in 1998?" is SELECT player.name_first, player.name_last FROM player JOIN all_star ON player.player_id = all_star.player_id WHERE all_star.year = 1998 So, the answer to the question "What are the first and last names of players participating in the all star game in 1998?" is = Intermediate_representation: select player.name_first, player.name_last from player join all_star on player.player_id = all_star.player_id where all_star.year = 1998 SQL: SELECT player.name_first, player.name_last FROM player JOIN all_star ON player.player_id = all_star.player_id WHERE all_star.year = 1998

Q: "Find the room number of the rooms which can sit 50 to 100 students and their buildings." Schema_links: [classroom.building, classroom.room_number, classroom.capacity, 50, 100] 
A: Let’s think step by step. "Find the room number of the rooms which can sit 50 to 100 students and their buildings." can be solved by knowing the answer to the following sub-question: "What are the room numbers and buildings for rooms with capacity between 50 and 100 students?" The SQL query for the sub-question "What are the room numbers and buildings for rooms with capacity between 50 and 100 students?" is SELECT building, room_number FROM classroom WHERE capacity BETWEEN 50 AND 100 So, the answer to the question "Find the room number of the rooms which can sit 50 to 100 students and their buildings." is = Intermediate_representation: select building, room_number from classroom where capacity between 50 and 100 SQL: SELECT building, room_number FROM classroom WHERE capacity BETWEEN 50 AND 100

Q: "Find the buildings which have rooms with capacity more than 50." Schema_links: [classroom.building, classroom.capacity, 50] A: Let’s think step by step. "Find the buildings which have rooms with capacity more than 50." can be solved by knowing the answer to the following sub-question: "Which buildings have rooms with a capacity greater than 50?" The SQL query for the sub-question "Which buildings have rooms with a capacity greater than 50?" is SELECT DISTINCT building FROM classroom WHERE capacity > 50 So, the answer to the question "Find the buildings which have rooms with capacity more than 50." is = Intermediate_representation: select distinct building from classroom where capacity > 50 SQL: SELECT DISTINCT building FROM classroom WHERE capacity > 50

Q: "Find the department name of the instructor whose name contains 'Soisalon'." Schema_links: [instructor.dept_name, instructor.name, Soisalon] A: Let’s think step by step. "Find the department name of the instructor whose name contains 'Soisalon'." can be solved by knowing the answer to the following sub-question: "What is the department name of the instructor whose name contains 'Soisalon'?" The SQL query for the sub-question "What is the department name of the instructor whose name contains 'Soisalon'?" is SELECT dept_name FROM instructor WHERE name LIKE '%Soisalon%' So, the answer to the question "Find the department name of the instructor whose name contains 'Soisalon'." is = Intermediate_representation: select dept_name from instructor where name like '%Soisalon%' SQL: SELECT dept_name FROM instructor WHERE name LIKE '%Soisalon%'

'''

In [7]:
DATASET_SCHEMA =  "dataset/tables.json"
DATASET = "dataset/dev.json"
OUTPUT_FILE = 'dataset/predicted_sql.txt'

In [9]:
def load_data(DATASET):
    return pd.read_json(DATASET)

In [10]:
def hard_prompt_maker(test_sample_text,database,schema_links,sub_questions):
  instruction = "# Use the intermediate representation and the schema links to generate the SQL queries for each of the questions.\n"
  fields = find_fields_MYSQL_like("baseball_1")
  fields += "Foreign_keys = " + find_foreign_keys_MYSQL_like("baseball_1") + '\n'
  fields += find_fields_MYSQL_like(database)
  fields += "Foreign_keys = " + find_foreign_keys_MYSQL_like(database) + '\n'
  stepping = f'''\nA: Let's think step by step. "{test_sample_text}" can be solved by knowing the answer to the following sub-question "{sub_questions}".'''
  fields += "\n"
  prompt = instruction +fields + hard_prompt + 'Q: "' + test_sample_text + '"' + '\nschema_links: ' + schema_links + stepping +'\nThe SQL query for the sub-question"'
  return prompt

In [11]:
def medium_prompt_maker(test_sample_text,database,schema_links):
  instruction = "# Use the the schema links and Intermediate_representation to generate the SQL queries for each of the questions.\n"
  fields = find_fields_MYSQL_like("baseball_1")
  fields += "Foreign_keys = " + find_foreign_keys_MYSQL_like("baseball_1") + '\n'
  fields += find_fields_MYSQL_like(database)
  fields += "Foreign_keys = " + find_foreign_keys_MYSQL_like(database) + '\n'
  fields += "\n"
  prompt = instruction +fields + medium_prompt + 'Q: "' + test_sample_text + '\nSchema_links: ' + schema_links + '\nA: Let’s think step by step.'
  return prompt
def easy_prompt_maker(test_sample_text,database,schema_links):
  instruction = "# Use the the schema links to generate the SQL queries for each of the questions.\n"
  fields = find_fields_MYSQL_like("baseball_1")
  fields += find_fields_MYSQL_like(database)
  fields += "\n"
  prompt = instruction +fields + easy_prompt + 'Q: "' + test_sample_text + '\nSchema_links: ' + schema_links + '\nSQL:'
  return prompt
def classification_prompt_maker(test_sample_text,database,schema_links):
  instruction = "# For the given question, classify it as EASY, NON-NESTED, or NESTED based on nested queries and JOIN.\n"
  instruction += "\nif need nested queries: predict NESTED\n"
  instruction += "elif need JOIN and don't need nested queries: predict NON-NESTED\n"
  instruction += "elif don't need JOIN and don't need nested queries: predict EASY\n\n"
  fields = find_fields_MYSQL_like("baseball_1")
  fields += "Foreign_keys = " + find_foreign_keys_MYSQL_like("baseball_1") + '\n'
  fields += find_fields_MYSQL_like(database)
  fields += "Foreign_keys = " + find_foreign_keys_MYSQL_like(database) + '\n'
  fields += "\n"
  prompt = instruction + fields + classification_prompt + 'Q: "' + test_sample_text + '\nschema_links: ' + schema_links + '\nA: Let’s think step by step.'
  return prompt

In [12]:
def schema_linking_prompt_maker(test_sample_text,database):
  instruction = "# Find the schema_links for generating SQL queries for each question based on the database schema and Foreign keys.\n"
  fields = find_fields_MYSQL_like(database)
  foreign_keys = "Foreign_keys = " + find_foreign_keys_MYSQL_like(database) + '\n'
  prompt = instruction + schema_linking_prompt + fields +foreign_keys+ 'Q: "' + test_sample_text + """"\nA: Let’s think step by step."""
 
  return prompt


def find_foreign_keys_MYSQL_like(db_name):
  df = spider_foreign[spider_foreign['Database name'] == db_name]
  output = "["
  for index, row in df.iterrows():
    output += row['First Table Name'] + '.' + row['First Table Foreign Key'] + " = " + row['Second Table Name'] + '.' + row['Second Table Foreign Key'] + ','
  output= output[:-1] + "]"
  return output
def find_fields_MYSQL_like(db_name):
  df = spider_schema[spider_schema['Database name'] == db_name]
  df = df.groupby(' Table Name')
  output = ""
  for name, group in df:
    output += "Table " +name+ ', columns = ['
    for index, row in group.iterrows():
      output += row[" Field Name"]+','
    output = output[:-1]
    output += "]\n"
  return output
def find_primary_keys_MYSQL_like(db_name):
  df = spider_primary[spider_primary['Database name'] == db_name]
  output = "["
  for index, row in df.iterrows():
    output += row['Table Name'] + '.' + row['Primary Key'] +','
  output = output[:-1]
  output += "]\n"
  return output
def creatiing_schema(DATASET_JSON):
    schema_df = pd.read_json(DATASET_JSON)
    schema_df = schema_df.drop(['column_names','table_names'], axis=1)
    schema = []
    f_keys = []
    p_keys = []
    for index, row in schema_df.iterrows():
        tables = row['table_names_original']
        col_names = row['column_names_original']
        col_types = row['column_types']
        foreign_keys = row['foreign_keys']
        primary_keys = row['primary_keys']
        for col, col_type in zip(col_names, col_types):
            index, col_name = col
            if index == -1:
                for table in tables:
                    schema.append([row['db_id'], table, '*', 'text'])
            else:
                schema.append([row['db_id'], tables[index], col_name, col_type])
        for primary_key in primary_keys:
            index, column = col_names[primary_key]
            p_keys.append([row['db_id'], tables[index], column])
        for foreign_key in foreign_keys:
            first, second = foreign_key
            first_index, first_column = col_names[first]
            second_index, second_column = col_names[second]
            f_keys.append([row['db_id'], tables[first_index], tables[second_index], first_column, second_column])
    spider_schema = pd.DataFrame(schema, columns=['Database name', ' Table Name', ' Field Name', ' Type'])
    spider_primary = pd.DataFrame(p_keys, columns=['Database name', 'Table Name', 'Primary Key'])
    spider_foreign = pd.DataFrame(f_keys,
                        columns=['Database name', 'First Table Name', 'Second Table Name', 'First Table Foreign Key',
                                 'Second Table Foreign Key'])
    return spider_schema,spider_primary,spider_foreign
def debuger(test_sample_text,database,sql):
  instruction = """#### For the given question, use the provided tables, columns, foreign keys, and primary keys to fix the given SQLite SQL QUERY for any issues. If there are any problems, fix them. If there are no issues, return the SQLite SQL QUERY as is.
#### Use the following instructions for fixing the SQL QUERY:
1) Use the database values that are explicitly mentioned in the question.
2) Pay attention to the columns that are used for the JOIN by using the Foreign_keys.
3) Use DESC and DISTINCT when needed.
4) Pay attention to the columns that are used for the GROUP BY statement.
5) Pay attention to the columns that are used for the SELECT statement.
6) Only change the GROUP BY clause when necessary (Avoid redundant columns in GROUP BY).
7) Use GROUP BY on one column only.

"""
  fields = find_fields_MYSQL_like(database)
  fields += "Foreign_keys = " + find_foreign_keys_MYSQL_like(database) + '\n'
  fields += "Primary_keys = " + find_primary_keys_MYSQL_like(database)
  prompt = instruction + fields+ '#### Question: ' + test_sample_text + '\n#### SQLite SQL QUERY\n' + sql +'\n#### SQLite FIXED SQL QUERY\nSELECT'
  return prompt

In [13]:
from openai import OpenAI

client = OpenAI()

In [14]:
# prompts=debuger('what is the full name and id of the college with the largest number of baseball players?', 'baseball_1', 'SELECT college.name_full, college.college_id FROM college JOIN player_college ON college.college_id = player_college.college_id GROUP BY college.college_id ORDER BY COUNT(player_college.player_id) DESC LIMIT 1')

# GPT4_debug(prompts)

In [16]:
import re
def GPT4_generation(prompt):
    response = client.chat.completions.create(
        model="gpt-4o-mini",
        messages=[{"role": "user", "content": prompt}],
        n=1,
        stream=False,
        temperature=0.0,
        max_tokens=600,
        top_p=1.0,
        frequency_penalty=0.0,
        presence_penalty=0.0,
        stop=["Q:"]
    )
    return response.choices[0].message.content






def GPT4_debug(prompt):
  response = client.chat.completions.create(
    model="gpt-4o-mini",
    messages=[{"role": "user", "content": prompt}],
    n = 1,
    stream = False,
    temperature=0.0,
    max_tokens=350,
    top_p = 1.0,
    frequency_penalty=0.0,
    presence_penalty=0.0,
    stop = ["#", ";","\n\n"]
  )
   
  return response.choices[0].message.content

if __name__ == '__main__':
    spider_schema,spider_primary,spider_foreign = creatiing_schema(DATASET_SCHEMA)
    
    val_df = load_data(DATASET)
    print(f"Number of data samples {val_df.shape[0]}")
    CODEX = []
    for index, row in val_df[:10].iterrows():
        #if index < 405: continue #for testing
        print(f"index is {index}")
        print(row['query'])
        print(row['question'])

 
        schema_links = None
        while schema_links is None:
            try:
           
                schema_links = GPT4_generation(schema_linking_prompt_maker(row['question'], row['db_id']))
       
            except:
                time.sleep(3)
                pass
        try:
            schema_links = schema_links.split("Schema_links: ")[1]
           
        except:
            print("Slicing error for the schema_linking module")
            schema_links = "[]"
        #print(schema_links)
        classification = None
        while classification is None:
            try:
                classification = GPT4_generation(
                    classification_prompt_maker(row['question'], row['db_id'], schema_links[1:]))
          
            except:
                time.sleep(3)
                pass
        try:
            predicted_class = classification.split("Label: ")[1]
          
        except:
            print("Slicing error for the classification module")
            # predicted_class = '"NESTED"'
        #print(classification)
        if '"EASY"' in predicted_class:
            print("EASY")
            SQL = None
            while SQL is None:
                try:
                    SQL = GPT4_generation(easy_prompt_maker(row['question'], row['db_id'], schema_links))
                   
                except:
                    time.sleep(3)
                    pass
            try: 
                match = re.search(r"\s*```sql\n(.*?)```", SQL, re.DOTALL)
                
                if match:
                    SQL= match.group(1).strip().replace('```sql','')
                   
                else:
                    print("No final SQL query found.")
                     
            except Exception as e: 
                print("Error in extracting SQL:", str(e)) 
                SQL = "SELECT"
        elif '"NON-NESTED"' in predicted_class:
            print("NON-NESTED")
            SQL = None
            while SQL is None:
                try:
                    SQL = GPT4_generation(medium_prompt_maker(row['question'], row['db_id'], schema_links))
                   
                except:
                    time.sleep(3)
                    pass
            try: 
                match = re.search(r"SQL Query:\s*```sql\n(.*?)```", SQL, re.DOTALL)
              
                if match:
                    SQL= match.group(1).strip().replace('```sql','')
                   
                else:
                    print("No final SQL query found.")
                     
            except Exception as e: 
                print("Error in extracting SQL:", str(e)) 
                SQL = "SELECT"
            
        else:
            print('222222222222222222',classification)
            
            sub_questions = classification.split('questions = ["')[1].split('"]')[0]
            print("NESTED")
            SQL = None
            while SQL is None:
                try:
                    SQL = GPT4_generation(
                        hard_prompt_maker(row['question'], row['db_id'], schema_links, sub_questions))
                except:
                    time.sleep(3)
                    pass
            try:
                match = re.search(r"SQL Query:\s*```sql\n(.*?)```", SQL, re.DOTALL)
              
                if match:
                    SQL= match.group(1).strip().replace('```sql','')
                else:
                    print("No final SQL query found.")
                   
            except:
                print("SQL slicing error")
                SQL = "SELECT"
         
        debugged_SQL = None
        while debugged_SQL is None:
            try:
                 
                debugged_SQL = GPT4_debug(debuger(row['question'], row['db_id'], SQL)).replace("'''", " ")
                 
            except:
                time.sleep(3)
                pass
        SQL =  debugged_SQL
        print(SQL)
        CODEX.append([row['question'], SQL, row['query'], row['db_id']])
        #break
     
    df = pd.DataFrame(CODEX, columns=['NLQ', 'PREDICTED SQL', 'GOLD SQL', 'DATABASE'])
    df['PREDICTED SQL'] = (
    df['PREDICTED SQL']
    .str.replace(r"```sql", "", regex=True)
    .str.replace(r"```", "", regex=True)
    .str.strip()
)

    
    results = df['PREDICTED SQL'].tolist()
    df.to_csv('dataset/final.csv', index=False)

    with open(OUTPUT_FILE, 'w') as f:
        for line in results:
            f.write(f"{line}\n")

Number of data samples 82
index is 0
SELECT T1.name_full ,  T1.college_id FROM college AS T1 JOIN player_college AS T2 ON T1.college_id  =  T2.college_id GROUP BY T1.college_id ORDER BY count(*) DESC LIMIT 1;
what is the full name and id of the college with the largest number of baseball players?
NON-NESTED
```sql
SELECT college.name_full, college.college_id 
FROM college 
JOIN player_college ON college.college_id = player_college.college_id 
GROUP BY college.college_id 
ORDER BY COUNT(player_college.player_id) DESC 
LIMIT 1
```
index is 1
SELECT T1.name_full ,  T1.college_id FROM college AS T1 JOIN player_college AS T2 ON T1.college_id  =  T2.college_id GROUP BY T1.college_id ORDER BY count(*) DESC LIMIT 1;
Find the full name and id of the college that has the most baseball players.
NON-NESTED
```sql
SELECT college.name_full, college.college_id 
FROM college 
JOIN player_college ON college.college_id = player_college.college_id 
GROUP BY college.college_id 
ORDER BY COUNT(player_colle

In [17]:
def evaluation_prompt(predicted_sql, gold_sql):
    return f"""
Evaluate the following two SQL queries based on these metrics:
- Accuracy (Is it doing what the gold query does?)
- Efficiency (Is it optimized?)
- Hallucination (Does it contain made-up columns or tables? )
- Completeness (Does it return everything it should?)
- Structure Similarity (Does the structure match the gold?)
- Readability (Is it easy to read?)
- Overall Score (Average of all above)

### Predicted SQL:
{predicted_sql}

### Gold SQL:
{gold_sql}

Give the response in the following format:

**Scores**:
- Accuracy: X
- Efficiency: X
- Hallucination: X
- Completeness: X
- Structure Similarity: X
- Readability: X
- Overall Score: X
"""


In [18]:
def evaluate_row(predicted_sql, gold_sql):
    prompt = evaluation_prompt(predicted_sql, gold_sql)
    response = client.chat.completions.create(
        model="gpt-4o",
        messages=[{"role": "user", "content": prompt}],
        temperature=0.0,
        max_tokens=500
    )
    return response.choices[0].message.content

In [19]:
df['Evaluation'] = df.apply(lambda row: evaluate_row(row['PREDICTED SQL'], row['GOLD SQL']), axis=1)


In [20]:
import re

# Define a function to extract scores from text
def extract_scores(text):
    pattern = r"- (\w[\w\s]*): (\d+\.?\d*)"
    matches = re.findall(pattern, text)

    score_dict = {
        'Accuracy': None,
        'Efficiency': None,
        'Hallucination': None,
        'Completeness': None,
        'Structure Similarity': None,
        'Readability': None,
        'Overall Score': None
    }

    for key, value in matches:
        key = key.strip()
        if key in score_dict:
            score_dict[key] = float(value)

    return pd.Series(score_dict)
score_df = df['Evaluation'].apply(extract_scores)
df = pd.concat([df, score_df], axis=1)
df.to_csv('few_results.csv')

In [21]:
df.T

Unnamed: 0,0,1,2,3,4,5,6,7,8,9
NLQ,what is the full name and id of the college wi...,Find the full name and id of the college that ...,What is average salary of the players in the t...,Compute the average salary of the players in t...,What are first and last names of players parti...,List the first and last name for players who p...,"What are the first name, last name and id of t...",Which player has the most all star game experi...,How many players enter hall of fame each year?,Count the number of players who enter hall of ...
PREDICTED SQL,"SELECT college.name_full, college.college_id \...","SELECT college.name_full, college.college_id \...",SELECT AVG(salary.salary) \nFROM salary \nJOIN...,SELECT AVG(salary.salary) \nFROM salary \nJOIN...,"SELECT player.name_first, player.name_last \nF...","SELECT DISTINCT player.name_first, player.name...","SELECT player.name_first, player.name_last, pl...","SELECT player.name_first, player.name_last, pl...","SELECT hall_of_fame.yearid, COUNT(DISTINCT hal...","SELECT hall_of_fame.yearid, COUNT(DISTINCT hal..."
GOLD SQL,"SELECT T1.name_full , T1.college_id FROM coll...","SELECT T1.name_full , T1.college_id FROM coll...",SELECT avg(T1.salary) FROM salary AS T1 JOIN t...,SELECT avg(T1.salary) FROM salary AS T1 JOIN t...,"SELECT name_first , name_last FROM player AS ...","SELECT name_first , name_last FROM player AS ...","SELECT T1.name_first , T1.name_last , T1.play...","SELECT T1.name_first , T1.name_last , T1.play...","SELECT yearid , count(*) FROM hall_of_fame GR...","SELECT yearid , count(*) FROM hall_of_fame GR..."
DATABASE,baseball_1,baseball_1,baseball_1,baseball_1,baseball_1,baseball_1,baseball_1,baseball_1,baseball_1,baseball_1
Evaluation,**Scores**:\n- Accuracy: 10 (Both queries achi...,**Scores**:\n- Accuracy: 10 (The predicted SQL...,**Scores**:\n- Accuracy: 0.5 (The predicted SQ...,**Scores**:\n- Accuracy: 7\n- Efficiency: 8\n-...,**Scores**:\n- Accuracy: 5 (The predicted SQL ...,**Scores**:\n- Accuracy: 1 (Both queries retri...,**Scores**:\n- Accuracy: 10 (Both queries achi...,**Scores**:\n- Accuracy: 10 (The predicted SQL...,**Scores**:\n- Accuracy: 8 (The predicted SQL ...,**Scores**:\n- Accuracy: 8 \n - The predicte...
Accuracy,10.0,10.0,0.5,7.0,5.0,1.0,10.0,10.0,8.0,8.0
Efficiency,10.0,10.0,1.0,8.0,5.0,0.8,10.0,10.0,7.0,7.0
Hallucination,10.0,10.0,0.5,9.0,5.0,1.0,10.0,10.0,9.0,9.0
Completeness,10.0,10.0,1.0,7.0,5.0,1.0,10.0,10.0,8.0,8.0
Structure Similarity,9.0,9.0,0.8,6.0,4.0,0.9,9.0,9.0,8.0,8.0
