This notebook runs Galois for both ChatGPT and InstructGPT.

In [None]:
cd ../..

In [None]:
import json
import openai
import pandas as pd

from dotenv import dotenv_values
from src.utils import augment_questions

# Read the OpenAI API key from the 'key' entry returned by dotenv_values() so the secret is
# not hard-coded in the notebook; a missing 'key' entry fails here with a KeyError.
key = dotenv_values()['key']
openai.api_key = key

In [None]:
# Queries passed to every run below; later cells read its 'Question', 'Answer',
# 'Database' and 'Query' columns.
final_sample_df = pd.read_csv('data/Final_Queries.csv')

Since DuckDB does not expose its query plans in a parsed form, we had to parse them ourselves. This was feasible for non-join queries but not for join queries, so for those we ended up creating the query trees manually.

In [None]:
from src.QueryTree import Node

join_query_trees={}
# 40
# Which language is the most popular in Aruba?
q = 'SELECT T2.Language FROM country AS T1 JOIN countrylanguage AS T2 ON T1.Code = T2.CountryCode WHERE T1.Name = "Aruba" ORDER BY Percentage DESC LIMIT 1'
# ┌───────────────────────────┐                             
# │         PROJECTION        │                             
# │   ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─   │                             
# │           #[0.0]          │                             
# └─────────────┬─────────────┘                                                          
# ┌─────────────┴─────────────┐                             
# │           LIMIT           │                             
# └─────────────┬─────────────┘                                                          
# ┌─────────────┴─────────────┐                             
# │          ORDER_BY         │                             
# │   ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─   │                             
# │          ORDERS:          │                             
# │           #[0.1]          │                             
# └─────────────┬─────────────┘                                                          
# ┌─────────────┴─────────────┐                             
# │         PROJECTION        │                             
# │   ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─   │                             
# │          Language         │                             
# │         Percentage        │                             
# └─────────────┬─────────────┘                                                          
# ┌─────────────┴─────────────┐                             
# │           FILTER          │                             
# │   ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─   │                             
# │      (Name = 'Aruba')     │                             
# └─────────────┬─────────────┘                                                          
# ┌─────────────┴─────────────┐                             
# │      COMPARISON_JOIN      │                             
# │   ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─   │                             
# │           INNER           ├──────────────┐              
# │    (Code = CountryCode)   │              │              
# └─────────────┬─────────────┘              │                                           
# ┌─────────────┴─────────────┐┌─────────────┴─────────────┐
# │          SEQ_SCAN         ││          SEQ_SCAN         │
# │   ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─   ││   ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─   │
# │          country          ││      countrylanguage      │
# └───────────────────────────┘└───────────────────────────┘                             

# NOT PARSED
# Hand-built plan for query 40, assembled bottom-up:
# SEQ_SCAN(country) + SEQ_SCAN(countrylanguage) -> JOIN -> FILTER -> PROJECTION -> ORDER_BY -> PROJECTION.
# NOTE(review): unlike the other hand-built trees in this cell, the root (c7) is never stored in
# join_query_trees, so this tree is currently unused - confirm whether that is intentional.
# NOTE(review): the LIMIT step shown in the plan above has no node in this tree.
c1 = Node()
c1.text = ['SEQ_SCAN','country']
c1.op = 'SEQ_SCAN'
c1.args=['country']

c2 = Node()
c2.text = ['SEQ_SCAN','countrylanguage']
c2.op = 'SEQ_SCAN'
c2.args=['countrylanguage']

c3 = Node()
# NOTE(review): text says 'city' although this join is between country and countrylanguage;
# it looks copy-pasted - confirm whether `text` is used by the executor.
c3.text = ['JOIN','city']
c3.op = 'JOIN'
# Prompts for the join key on each side; '!!x!!' is the placeholder inside the question.
# NOTE(review): this key_left prompt ends with '?.' whereas the same prompt in the other trees
# ends with '?' - likely a stray period; left unchanged because it is part of the prompt text.
c3.key_left='What is the country code of !!x!!?. Answer briefly.'
c3.key_right = 'What is the country code of the country that speaks !!x!!? Answer briefly.'
c3.filter_key='left'

c4 = Node()
c4.text = ['FILTER',"(Name = 'Aruba')"]
c4.op = 'FILTER'
c4.args=["(Name = 'Aruba')"]

c5 = Node()
c5.text = ['PROJECTION','Language','Percentage']
c5.op = 'PROJECTION'
c5.args=['Language','Percentage']


c6 = Node()
c6.text = ['ORDER_BY','Percentage']
c6.op = 'ORDER_BY'
c6.args=['Percentage']


c7 = Node()
c7.text = ['PROJECTION','Language']
c7.op = 'PROJECTION'
c7.args=['Language']

# Link the nodes: the join takes both scans (l / r); every later node has a single child in `l`.
c3.l=c1
c3.r = c2
c4.l=c3
c5.l=c4
c6.l=c5
c7.l=c6






# =======================================
# 41
# what is the capital of states that have cities named durham
q = 'SELECT t2.capital FROM state AS t2 JOIN city AS t1 ON t2.state_name = t1.state_name WHERE t1.city_name = "durham";'
# ┌───────────────────────────┐                             
# │         PROJECTION        │                             
# │   ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─   │                             
# │          capital          │                             
# └─────────────┬─────────────┘                                                          
# ┌─────────────┴─────────────┐                             
# │           FILTER          │                             
# │   ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─   │                             
# │   (city_name = 'durham')  │                             
# └─────────────┬─────────────┘                                                          
# ┌─────────────┴─────────────┐                             
# │      COMPARISON_JOIN      │                             
# │   ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─   │                             
# │           INNER           ├──────────────┐              
# │ (state_name = state_name) │              │              
# └─────────────┬─────────────┘              │                                           
# ┌─────────────┴─────────────┐┌─────────────┴─────────────┐
# │          SEQ_SCAN         ││          SEQ_SCAN         │
# │   ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─   ││   ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─   │
# │           state           ││            city           │
# └───────────────────────────┘└───────────────────────────┘                             

# Hand-built plan for query 41, assembled bottom-up:
# SEQ_SCAN(state) + SEQ_SCAN(city) -> JOIN -> FILTER -> PROJECTION (root).
state_scan = Node()
state_scan.op = 'SEQ_SCAN'
state_scan.args = ['state']
state_scan.text = ['SEQ_SCAN', 'state']

city_scan = Node()
city_scan.op = 'SEQ_SCAN'
city_scan.args = ['city']
city_scan.text = ['SEQ_SCAN', 'city']

# The join carries one prompt per side for obtaining the join key; '!!x!!' is the
# placeholder inside each question.
state_city_join = Node()
state_city_join.op = 'JOIN'
state_city_join.text = ['JOIN', 'city']
state_city_join.key_left = 'What is the state name of !!x!!? Answer briefly.'
state_city_join.key_right = 'What is the state name of !!x!!? Answer briefly.'
state_city_join.filter_key = 'right'
state_city_join.l, state_city_join.r = state_scan, city_scan

durham_filter = Node()
durham_filter.op = 'FILTER'
durham_filter.args = ["(city_name = 'durham')"]
durham_filter.text = ['FILTER', "(city_name = 'durham')"]
durham_filter.filled_question = 'Is !!x!! the same as Durham?'
durham_filter.l = state_city_join

capital_projection = Node()
capital_projection.op = 'PROJECTION'
capital_projection.args = ['capital']
capital_projection.text = ['PROJECTION', 'capital']
capital_projection.filled_question = 'What is the capital of state of !!x!!?'
capital_projection.l = durham_filter

# Register the root of the tree under the SQL text of the query.
join_query_trees[q] = capital_projection



# NOT PARSED
# =======================================
# 42
# What are the regions that use English or Dutch?
q = 'SELECT DISTINCT T1.Region FROM country AS T1 JOIN countrylanguage AS T2 ON T1.Code = T2.CountryCode WHERE T2.Language = "English" OR T2.Language = "Dutch"'
# ┌───────────────────────────┐                             
# │          DISTINCT         │                             
# │   ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─   │                             
# │       "T1"."Region"       │                             
# └─────────────┬─────────────┘                                                          
# ┌─────────────┴─────────────┐                             
# │         PROJECTION        │                             
# │   ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─   │                             
# │           Region          │                             
# └─────────────┬─────────────┘                                                          
# ┌─────────────┴─────────────┐                             
# │           FILTER          │                             
# │   ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─   │                             
# │((Language = 'English') OR │                             
# │   (Language = 'Dutch'))   │                             
# └─────────────┬─────────────┘                                                          
# ┌─────────────┴─────────────┐                             
# │      COMPARISON_JOIN      │                             
# │   ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─   │                             
# │           INNER           ├──────────────┐              
# │    (Code = CountryCode)   │              │              
# └─────────────┬─────────────┘              │                                           
# ┌─────────────┴─────────────┐┌─────────────┴─────────────┐
# │          SEQ_SCAN         ││          SEQ_SCAN         │
# │   ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─   ││   ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─   │
# │          country          ││      countrylanguage      │
# └───────────────────────────┘└───────────────────────────┘                             

# Hand-built plan for query 42, assembled bottom-up:
# SEQ_SCAN(country) + SEQ_SCAN(countrylanguage) -> JOIN -> FILTER -> PROJECTION (root).
# The DISTINCT step shown at the top of the plan has no node in this tree.
country_scan = Node()
country_scan.op = 'SEQ_SCAN'
country_scan.args = ['country']
country_scan.text = ['SEQ_SCAN', 'country']

language_scan = Node()
language_scan.op = 'SEQ_SCAN'
language_scan.args = ['countrylanguage']
language_scan.text = ['SEQ_SCAN', 'countrylanguage']

# The join carries one prompt per side for obtaining the join key; '!!x!!' is the
# placeholder inside each question.
country_language_join = Node()
country_language_join.op = 'JOIN'
country_language_join.text = ['JOIN', 'city']
country_language_join.key_left = 'What is the country code of !!x!!? Answer briefly.'
country_language_join.key_right = 'What is the country code of the country that speaks !!x!!? Answer briefly.'
country_language_join.filter_key = 'right'
country_language_join.l, country_language_join.r = country_scan, language_scan

# NOTE(review): the SQL predicate is an OR, but the two conditions are listed as separate args,
# the same shape the AND filters in other trees use - confirm how the executor combines them.
language_filter = Node()
language_filter.op = 'FILTER'
language_filter.args = ["(Language = 'English')", "(Language = 'Dutch')"]
language_filter.text = ['FILTER', "(Language = 'English')", "(Language = 'Dutch')"]
language_filter.l = country_language_join

region_projection = Node()
region_projection.op = 'PROJECTION'
region_projection.args = ['Region']
region_projection.text = ['PROJECTION', 'Region']
region_projection.filled_question = 'What is the region that speaks !!x!!?'
region_projection.l = language_filter

# Register the root of the tree under the SQL text of the query.
join_query_trees[q] = region_projection





# NOT PARSED
# =======================================
# 43
# What is the official language spoken in the country whose head of state is Beatrix?
q = 'SELECT T2.Language FROM country AS T1 JOIN countrylanguage AS T2 ON T1.Code = T2.CountryCode WHERE T1.HeadOfState = "Beatrix" AND T2.IsOfficial = "T"'
# ┌───────────────────────────┐                             
# │         PROJECTION        │                             
# │   ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─   │                             
# │          Language         │                             
# └─────────────┬─────────────┘                                                          
# ┌─────────────┴─────────────┐                             
# │           FILTER          │                             
# │   ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─   │                             
# │ (HeadOfState = 'Beatrix') │                             
# │ (IsOfficial = CAST('T' AS │                             
# │          BOOLEAN))        │                             
# └─────────────┬─────────────┘                                                          
# ┌─────────────┴─────────────┐                             
# │      COMPARISON_JOIN      │                             
# │   ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─   │                             
# │           INNER           ├──────────────┐              
# │    (Code = CountryCode)   │              │              
# └─────────────┬─────────────┘              │                                           
# ┌─────────────┴─────────────┐┌─────────────┴─────────────┐
# │          SEQ_SCAN         ││          SEQ_SCAN         │
# │   ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─   ││   ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─   │
# │          country          ││      countrylanguage      │
# └───────────────────────────┘└───────────────────────────┘                             

# Hand-built plan for query 43, assembled bottom-up:
# SEQ_SCAN(country) + SEQ_SCAN(countrylanguage) -> JOIN -> FILTER -> PROJECTION (root).
country_scan = Node()
country_scan.op = 'SEQ_SCAN'
country_scan.args = ['country']
country_scan.text = ['SEQ_SCAN', 'country']

language_scan = Node()
language_scan.op = 'SEQ_SCAN'
language_scan.args = ['countrylanguage']
language_scan.text = ['SEQ_SCAN', 'countrylanguage']

# The join carries one prompt per side for obtaining the join key; '!!x!!' is the
# placeholder inside each question.
country_language_join = Node()
country_language_join.op = 'JOIN'
country_language_join.text = ['JOIN', 'city']
country_language_join.key_left = 'What is the country code of !!x!!? Answer briefly.'
country_language_join.key_right = 'What is the country code of the country that speaks !!x!!? Answer briefly.'
country_language_join.filter_key = 'left'
country_language_join.l, country_language_join.r = country_scan, language_scan

# Both conditions of the WHERE clause are listed as separate filter arguments.
official_language_filter = Node()
official_language_filter.op = 'FILTER'
official_language_filter.args = ["(HeadOfState = 'Beatrix')", "(IsOfficial = CAST('T' AS BOOLEAN))"]
official_language_filter.text = ['FILTER', "(HeadOfState = 'Beatrix')", "(IsOfficial = CAST('T' AS BOOLEAN))"]
official_language_filter.l = country_language_join

language_projection = Node()
language_projection.op = 'PROJECTION'
language_projection.args = ['Language']
language_projection.text = ['PROJECTION', 'Language']
language_projection.filled_question = 'What is the language of !!x!!?'
language_projection.l = official_language_filter

# Register the root of the tree under the SQL text of the query.
join_query_trees[q] = language_projection





# NOT PARSED
# =======================================
# 44
# what state has no rivers
# SELECT state_name FROM state WHERE state_name NOT IN ( SELECT traverse FROM river );
# ┌───────────────────────────┐                             
# │         PROJECTION        │                             
# │   ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─   │                             
# │         state_name        │                             
# └─────────────┬─────────────┘                                                          
# ┌─────────────┴─────────────┐                             
# │           FILTER          │                             
# │   ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─   │                             
# │       (NOT SUBQUERY)      │                             
# └─────────────┬─────────────┘                                                          
# ┌─────────────┴─────────────┐                             
# │      COMPARISON_JOIN      │                             
# │   ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─   │                             
# │            MARK           ├──────────────┐              
# │   (state_name = #[8.0])   │              │              
# └─────────────┬─────────────┘              │                                           
# ┌─────────────┴─────────────┐┌─────────────┴─────────────┐
# │          SEQ_SCAN         ││         PROJECTION        │
# │   ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─   ││   ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─   │
# │           state           ││          traverse         │
# └───────────────────────────┘└─────────────┬─────────────┘                             
#                              ┌─────────────┴─────────────┐
#                              │          SEQ_SCAN         │
#                              │   ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─   │
#                              │           river           │
#                              └───────────────────────────┘                             






# NOT PARSED
# =======================================
# 45
# how many rivers do not traverse the state with the capital albany
# SELECT COUNT ( river_name ) FROM river WHERE traverse NOT IN ( SELECT state_name FROM state WHERE capital = "albany" );
# ┌───────────────────────────┐                             
# │         PROJECTION        │                             
# │   ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─   │                             
# │     count(river_name)     │                             
# └─────────────┬─────────────┘                                                          
# ┌─────────────┴─────────────┐                             
# │         AGGREGATE         │                             
# │   ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─   │                             
# │     count(river_name)     │                             
# └─────────────┬─────────────┘                                                          
# ┌─────────────┴─────────────┐                             
# │           FILTER          │                             
# │   ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─   │                             
# │       (NOT SUBQUERY)      │                             
# └─────────────┬─────────────┘                                                          
# ┌─────────────┴─────────────┐                             
# │      COMPARISON_JOIN      │                             
# │   ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─   │                             
# │            MARK           ├──────────────┐              
# │    (traverse = #[8.0])    │              │              
# └─────────────┬─────────────┘              │                                           
# ┌─────────────┴─────────────┐┌─────────────┴─────────────┐
# │          SEQ_SCAN         ││         PROJECTION        │
# │   ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─   ││   ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─   │
# │           river           ││         state_name        │
# └───────────────────────────┘└─────────────┬─────────────┘                             
#                              ┌─────────────┴─────────────┐
#                              │           FILTER          │
#                              │   ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─   │
#                              │    (capital = 'albany')   │
#                              └─────────────┬─────────────┘                             
#                              ┌─────────────┴─────────────┐
#                              │          SEQ_SCAN         │
#                              │   ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─   │
#                              │           state           │
#                              └───────────────────────────┘                             








# NOT PARSED
# =======================================
# 46
# What is the number of distinct continents where Chinese is spoken?
q ='SELECT COUNT( DISTINCT Continent) FROM country AS T1 JOIN countrylanguage AS T2 ON T1.Code = T2.CountryCode WHERE T2.Language = "Chinese"'
# ┌───────────────────────────┐                             
# │         PROJECTION        │                             
# │   ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─   │                             
# │ count(DISTINCT Continent) │                             
# └─────────────┬─────────────┘                                                          
# ┌─────────────┴─────────────┐                             
# │         AGGREGATE         │                             
# │   ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─   │                             
# │ count(DISTINCT Continent) │                             
# └─────────────┬─────────────┘                                                          
# ┌─────────────┴─────────────┐                             
# │           FILTER          │                             
# │   ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─   │                             
# │   (Language = 'Chinese')  │                             
# └─────────────┬─────────────┘                                                          
# ┌─────────────┴─────────────┐                             
# │      COMPARISON_JOIN      │                             
# │   ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─   │                             
# │           INNER           ├──────────────┐              
# │    (Code = CountryCode)   │              │              
# └─────────────┬─────────────┘              │                                           
# ┌─────────────┴─────────────┐┌─────────────┴─────────────┐
# │          SEQ_SCAN         ││          SEQ_SCAN         │
# │   ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─   ││   ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─   │
# │          country          ││      countrylanguage      │
# └───────────────────────────┘└───────────────────────────┘                             

# Hand-built plan for query 46, assembled bottom-up:
# SEQ_SCAN(country) + SEQ_SCAN(countrylanguage) -> JOIN -> FILTER -> AGGREGATE (root).
country_scan = Node()
country_scan.op = 'SEQ_SCAN'
country_scan.args = ['country']
country_scan.text = ['SEQ_SCAN', 'country']

language_scan = Node()
language_scan.op = 'SEQ_SCAN'
language_scan.args = ['countrylanguage']
language_scan.text = ['SEQ_SCAN', 'countrylanguage']

# The join carries one prompt per side for obtaining the join key; '!!x!!' is the
# placeholder inside each question.
country_language_join = Node()
country_language_join.op = 'JOIN'
country_language_join.text = ['JOIN', 'city']
country_language_join.key_left = 'What is the country code of !!x!!? Answer briefly.'
country_language_join.key_right = 'What is the country code of the country that speaks !!x!!? Answer briefly.'
country_language_join.filter_key = 'right'
country_language_join.l, country_language_join.r = country_scan, language_scan

chinese_filter = Node()
chinese_filter.op = 'FILTER'
chinese_filter.args = ["(Language = 'Chinese')"]
chinese_filter.text = ['FILTER', "(Language = 'Chinese')"]
chinese_filter.filled_question = 'Is !!x!! the same as Chinese?'
chinese_filter.l = country_language_join

continent_count = Node()
continent_count.op = 'AGGREGATE'
continent_count.args = ['count(Continent)']
continent_count.text = ['AGGREGATE', 'count(Continent)']
continent_count.filled_question = 'What is the continent that speaks !!x!!?'
continent_count.l = chinese_filter

# The plan's top-level PROJECTION over the aggregate is deliberately not modelled (it was
# previously present only as commented-out code), so the AGGREGATE node is the registered root.
join_query_trees[q] = continent_count






# NOT PARSED
# =======================================
# 47
# Which region is the city Kabul located in?
q = 'SELECT Region FROM country AS T1 JOIN city AS T2 ON T1.Code = T2.CountryCode WHERE T2.Name = "Kabul"'
# ┌───────────────────────────┐                             
# │         PROJECTION        │                             
# │   ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─   │                             
# │           Region          │                             
# └─────────────┬─────────────┘                                                          
# ┌─────────────┴─────────────┐                             
# │           FILTER          │                             
# │   ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─   │                             
# │      (Name = 'Kabul')     │                             
# └─────────────┬─────────────┘                                                          
# ┌─────────────┴─────────────┐                             
# │      COMPARISON_JOIN      │                             
# │   ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─   │                             
# │           INNER           ├──────────────┐              
# │    (Code = CountryCode)   │              │              
# └─────────────┬─────────────┘              │                                           
# ┌─────────────┴─────────────┐┌─────────────┴─────────────┐
# │          SEQ_SCAN         ││          SEQ_SCAN         │
# │   ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─   ││   ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─   │
# │          country          ││            city           │
# └───────────────────────────┘└───────────────────────────┘                             

# Hand-built plan for query 47, assembled bottom-up:
# SEQ_SCAN(country) + SEQ_SCAN(city) -> JOIN -> FILTER -> PROJECTION (root).
country_scan = Node()
country_scan.op = 'SEQ_SCAN'
country_scan.args = ['country']
country_scan.text = ['SEQ_SCAN', 'country']

city_scan = Node()
city_scan.op = 'SEQ_SCAN'
city_scan.args = ['city']
city_scan.text = ['SEQ_SCAN', 'city']

# The join carries one prompt per side for obtaining the join key; '!!x!!' is the
# placeholder inside each question.
country_city_join = Node()
country_city_join.op = 'JOIN'
country_city_join.text = ['JOIN', 'city']
country_city_join.key_left = 'What is the country code of !!x!!? Answer briefly.'
country_city_join.key_right = 'What is the country code of the country of !!x!!? Answer briefly.'
country_city_join.filter_key = 'right'
country_city_join.l, country_city_join.r = country_scan, city_scan

kabul_filter = Node()
kabul_filter.op = 'FILTER'
kabul_filter.args = ["(Name = 'Kabul')"]
kabul_filter.text = ['FILTER', "(Name = 'Kabul')"]
kabul_filter.filled_question = 'Is !!x!! the same as Kabul?'
kabul_filter.l = country_city_join

region_projection = Node()
region_projection.op = 'PROJECTION'
region_projection.args = ['Region']
region_projection.text = ['PROJECTION', 'Region']
region_projection.filled_question = 'What is the region of !!x!!?'
region_projection.l = kabul_filter

# Register the root of the tree under the SQL text of the query.
join_query_trees[q] = region_projection








# NOT PARSED
# =======================================
# 48
# How many official languages does Afghanistan have?
q = 'SELECT COUNT(*) FROM country AS T1 JOIN countrylanguage AS T2 ON T1.Code = T2.CountryCode WHERE T1.Name = "Afghanistan" AND IsOfficial = "T"'
# ┌───────────────────────────┐                             
# │         PROJECTION        │                             
# │   ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─   │                             
# │        count_star()       │                             
# └─────────────┬─────────────┘                                                          
# ┌─────────────┴─────────────┐                             
# │         AGGREGATE         │                             
# │   ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─   │                             
# │        count_star()       │                             
# └─────────────┬─────────────┘                                                          
# ┌─────────────┴─────────────┐                             
# │           FILTER          │                             
# │   ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─   │                             
# │   (Name = 'Afghanistan')  │                             
# │ (IsOfficial = CAST('T' AS │                             
# │          BOOLEAN))        │                             
# └─────────────┬─────────────┘                                                          
# ┌─────────────┴─────────────┐                             
# │      COMPARISON_JOIN      │                             
# │   ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─   │                             
# │           INNER           ├──────────────┐              
# │    (Code = CountryCode)   │              │              
# └─────────────┬─────────────┘              │                                           
# ┌─────────────┴─────────────┐┌─────────────┴─────────────┐
# │          SEQ_SCAN         ││          SEQ_SCAN         │
# │   ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─   ││   ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─   │
# │          country          ││      countrylanguage      │
# └───────────────────────────┘└───────────────────────────┘                             

# Hand-built plan for query 48, assembled bottom-up:
# SEQ_SCAN(country) + SEQ_SCAN(countrylanguage) -> JOIN -> FILTER -> AGGREGATE (root).
country_scan = Node()
country_scan.op = 'SEQ_SCAN'
country_scan.args = ['country']
country_scan.text = ['SEQ_SCAN', 'country']

language_scan = Node()
language_scan.op = 'SEQ_SCAN'
language_scan.args = ['countrylanguage']
language_scan.text = ['SEQ_SCAN', 'countrylanguage']

# The join carries one prompt per side for obtaining the join key; '!!x!!' is the
# placeholder inside each question.
country_language_join = Node()
country_language_join.op = 'JOIN'
country_language_join.text = ['JOIN', 'city']
country_language_join.key_left = 'What is the country code of !!x!!? Answer briefly.'
country_language_join.key_right = 'What is the country code of the country that speaks !!x!!? Answer briefly.'
country_language_join.filter_key = 'left'
country_language_join.l, country_language_join.r = country_scan, language_scan

# Both conditions of the WHERE clause are listed as separate filter arguments.
official_language_filter = Node()
official_language_filter.op = 'FILTER'
official_language_filter.args = ["(Name = 'Afghanistan')", "(IsOfficial = CAST('T' AS BOOLEAN))"]
official_language_filter.text = ['FILTER', "(Name = 'Afghanistan')", "(IsOfficial = CAST('T' AS BOOLEAN))"]
official_language_filter.l = country_language_join

row_count = Node()
row_count.op = 'AGGREGATE'
row_count.args = ['count_star()']
row_count.text = ['AGGREGATE', 'count_star()']
row_count.l = official_language_filter

# The plan's top-level PROJECTION over the aggregate is deliberately not modelled (it was
# previously present only as commented-out code), so the AGGREGATE node is the registered root.
join_query_trees[q] = row_count




# NOT PARSED
# =======================================
# 49
# which capitals are not major cities
q = 'SELECT t2.capital FROM state AS t2 JOIN city AS t1 ON t2.capital = t1.city_name WHERE t1.population <= 150000;'
# ┌───────────────────────────┐                             
# │         PROJECTION        │                             
# │   ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─   │                             
# │          capital          │                             
# └─────────────┬─────────────┘                                                          
# ┌─────────────┴─────────────┐                             
# │           FILTER          │                             
# │   ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─   │                             
# │ (population <= CAST(150000│                             
# │         AS BIGINT))       │                             
# └─────────────┬─────────────┘                                                          
# ┌─────────────┴─────────────┐                             
# │      COMPARISON_JOIN      │                             
# │   ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─   │                             
# │           INNER           ├──────────────┐              
# │   (capital = city_name)   │              │              
# └─────────────┬─────────────┘              │                                           
# ┌─────────────┴─────────────┐┌─────────────┴─────────────┐
# │          SEQ_SCAN         ││          SEQ_SCAN         │
# │   ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─   ││   ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─   │
# │           state           ││            city           │
# └───────────────────────────┘└───────────────────────────┘                             

# Hand-built plan for query 49, assembled bottom-up:
# SEQ_SCAN(state) + SEQ_SCAN(city) -> JOIN -> FILTER -> PROJECTION (root).
state_scan = Node()
state_scan.op = 'SEQ_SCAN'
state_scan.args = ['state']
state_scan.text = ['SEQ_SCAN', 'state']

city_scan = Node()
city_scan.op = 'SEQ_SCAN'
city_scan.args = ['city']
city_scan.text = ['SEQ_SCAN', 'city']

# The join carries one prompt per side for obtaining the join key; '!!x!!' is the
# placeholder inside each question.
capital_city_join = Node()
capital_city_join.op = 'JOIN'
capital_city_join.text = ['JOIN', 'city']
capital_city_join.key_left = 'What is the capital of !!x!!? Answer briefly.'
capital_city_join.key_right = 'What is the name of !!x!!? Answer briefly.'
capital_city_join.filter_key = 'right'
capital_city_join.l, capital_city_join.r = state_scan, city_scan

population_filter = Node()
population_filter.op = 'FILTER'
population_filter.args = ["(population <= CAST(150000 AS BIGINT))"]
population_filter.text = ['FILTER', "(population <= CAST(150000 AS BIGINT))"]
population_filter.l = capital_city_join

capital_projection = Node()
capital_projection.op = 'PROJECTION'
capital_projection.args = ['capital']
capital_projection.text = ['PROJECTION', 'capital']
capital_projection.filled_question = 'What is the capital of the state of !!x!!?'
capital_projection.l = population_filter

# Register the root of the tree under the SQL text of the query.
join_query_trees[q] = capital_projection


# NOT PARSED
# =======================================

In [None]:
# Here we translate the operations given by duckdb into questions with placeholders for values,
# then augment some of the questions to improve performance.
# The file is opened in a `with` block so its handle is closed as soon as the JSON is parsed.
with open('data/question_maps.json', 'r') as question_maps_file:
    question_maps = json.load(question_maps_file)
augmented_question_maps = augment_questions(question_maps)

## ChatGPT

In [None]:
# Instruction prompt, passed as `instr` to the ChatGPT runs below.
# NOTE(review): "you will give you the answer" reads like a typo for "you will give me the answer";
# the string is left untouched because changing the prompt may change the model's outputs.
inst_chatgpt = "You are a highly intelligent question answering bot. If I ask you a question that is rooted in truth, you will give you the answer. If I ask you a question that is nonsense, trickery, or has no clear answer, you will respond with 'Unknown'. You will answer concisely."

In [None]:
# Few-shot demonstrations for ChatGPT as [question, answer] pairs; the last pair shows the
# expected 'Unknown' reply to a nonsense question.
fewshot_chatgpt = [
    ['What is human life expectancy in the United States?', '78.'],
    ['Who was president of the United States in 1955?', 'Dwight D. Eisenhower.'],
    ['Which party was founded by Gramsci?', 'Comunista.'],
    ['What is the capital of France?', 'Paris.'],
    ['What is a continent starting with letter O?', 'Oceania.'],
    ['Where were the 1992 Olympics held?', 'Barcelona.'],
    ['How many squigs are in a bonk?', 'Unknown'],
]

### Galois

In [None]:
from src.chatgpt_galois import GPT_SPWJ_seq

In [None]:
# Few-shot ChatGPT run with the instruction prompt; the hand-built join trees are supplied
# through query_plan_dict.
GPT_SPWJ_seq(
    model_arch='gpt-3.5-turbo',
    df=final_sample_df,
    instr=inst_chatgpt,
    few_shots=fewshot_chatgpt,
    inst_funct=1,
    label='Chat-GPT3-FS-new',
    augmented_question_maps=augmented_question_maps,
    query_plan_dict=join_query_trees,
    verbose=True,
)


In [None]:
# zero shot no instruction
# i=1
# GPT_SPW_seq(model_arch='gpt-3.5-turbo',
#             df=final_sample_df.iloc[i:i+1],
#             instr="",
#             few_shots=[],
#             inst_funct=1,
#             label='Chat-GPT3-ZS-new',
#             augmented_question_maps=augmented_question_maps,
#             query_plan_dict=join_query_trees,
#             verbose=True)

### Single Question

In [None]:
from src.chatgpt_galois import run_question

In [None]:
# Ask each question directly (without Galois) and put the answers next to the reference
# columns of final_sample_df, in a fixed column order.
single_question_answers = run_question(final_sample_df, inst_chatgpt, fewshot_chatgpt)

chatgpt_final_result_df = (
    pd.DataFrame(single_question_answers, columns=['Single Question Answer'])
    .assign(**{
        'Question': final_sample_df['Question'],
        'Query Answer': final_sample_df['Answer'],
        'Database': final_sample_df['Database'],
        'Query': final_sample_df['Query'],
    })
    [['Question', 'Query Answer', 'Database', 'Query', 'Single Question Answer']]
)


In [None]:
# Print, for each stored result, the gold question and the last entry of its 'LP Answers' list.
# (This cell only prints; it does not write a CSV file.)
# NOTE(review): the ChatGPT run above is labelled 'Chat-GPT3-FS-new' but this reads
# 'Chat-GPT3-FS.json' - confirm this is the intended results file.
# The file is opened in a `with` block so its handle is closed once the JSON is parsed.
with open('data/results/chat_gpt/Chat-GPT3-FS.json', 'r') as results_file:
    results = json.load(results_file)
for r in results:
    print(r['Gold Question'])
    print(r['LP Answers'][-1])
    print("#######################################")


We end up with an empty file similar to `data/results/chat_gpt/Chat_GPT_label.csv`.

In [None]:
from src.utils import get_cardinality

# print Galois results and check cardinality
get_cardinality(results)

### CoT

In [None]:
# CoT
cot_ex = """Let us think step by step:
First, get a list of all continents: Asia, Africa, North America, South America, Antarctica, Europe, and Australia.
Second, for each continent, get the surface area: Asia: 31,033,131, Africa: 29,648,481	,	North America: 21,330,000		,
	South America: 17,461,112		,	Antarctica: 13,720,000, Europe: 22,134,710, 	Australia: 8,486,460
Third, Get the maximum value: 31,033,131.
So the answer is Asia."""

from src.chatgpt_galois import run_CoT

# Run the chain-of-thought baseline and put its answers next to the reference columns of
# final_sample_df.
CoT_answers = run_CoT(final_sample_df,inst_chatgpt,cot_ex)
chatgpt_cot_result_df = pd.DataFrame(CoT_answers,columns=['CoT Answer'])
chatgpt_cot_result_df['Question'] = final_sample_df['Question']
chatgpt_cot_result_df['Query Answer'] = final_sample_df['Answer']
chatgpt_cot_result_df['Database'] = final_sample_df['Database']
chatgpt_cot_result_df['Query'] = final_sample_df['Query']
# Keep the reordered CoT results in their own frame. Assigning them to chatgpt_final_result_df
# (as before) silently replaced the single-question results built earlier in the notebook.
chatgpt_cot_result_df = chatgpt_cot_result_df[['Question', 'Query Answer', 'Database','Query','CoT Answer']]
#chatgpt_cot_result_df.to_csv('data/results/chat_gpt/Chat_GPT_CoT_Label.csv',index=False)

## InstructGPT


In [None]:
inst_gpt3 = 'I am a highly intelligent question answering bot. If you ask me a question that is rooted in truth, I will give you the answer. If you ask me a question that is nonsense, trickery, or has no clear answer, I will respond with "Unknown".\n\n'

gpt3_few_shot = """Q: What is human life expectancy in the United States?
A: 78.
 
Q: Who was president of the United States in 1955?
A: Dwight D. Eisenhower.
 
Q: Which party was founded by Gramsci?
A: Comunista.
 
Q: What is the capital of France?
A: Paris.

Q: What is a continent starting with letter O?
A: Oceania.
 
Q: Where were the 1992 Olympics held?
A: Barcelona.
 
Q: How many squigs are in a bonk?
A: Unknown

"""

In [None]:
def gpt3_form(q:str) -> str:
    """Wrap a question in the 'Q: ...' / 'A:' layout, leaving the answer open.

    Args:
        q: The question text.

    Returns:
        The string 'Q: <q>' followed by a newline and 'A:'.
    """
    return ''.join(('Q: ', q, '\nA:'))

### Galois

In [None]:
from src.instructgpt_galois import GPT_SPWJ_seq

In [None]:
# Few-shot InstructGPT run with the instruction prompt; gpt3_form is passed as inst_funct and
# the hand-built join trees are supplied through query_plan_dict.
GPT_SPWJ_seq(
    model_arch='text-davinci-003',
    df=final_sample_df,
    instr=inst_gpt3,
    few_shots=gpt3_few_shot,
    inst_funct=gpt3_form,
    label='Inst-GPT3-FS',
    augmented_question_maps=augmented_question_maps,
    query_plan_dict=join_query_trees,
    verbose=True,
)


### Single Question

In [None]:
from src.chatgpt_galois import run_question

In [None]:
# Ask each question directly (without Galois) and put the answers next to the reference
# columns of final_sample_df, in a fixed column order.
# NOTE(review): run_question is imported from src.chatgpt_galois in the preceding cell, yet this
# call passes a fourth argument (gpt3_form) that the ChatGPT call earlier in the notebook does
# not - confirm this is the intended implementation for InstructGPT.
single_question_answers = run_question(final_sample_df, inst_gpt3, gpt3_few_shot, gpt3_form)

instruct_final_result_df = (
    pd.DataFrame(single_question_answers, columns=['Single Question Answer'])
    .assign(**{
        'Question': final_sample_df['Question'],
        'Query Answer': final_sample_df['Answer'],
        'Database': final_sample_df['Database'],
        'Query': final_sample_df['Query'],
    })
    [['Question', 'Query Answer', 'Database', 'Query', 'Single Question Answer']]
)


In [None]:
# Print, for each stored result, the gold question and the last entry of its 'LP Answers' list.
# (This cell only prints; it does not write a CSV file.)
# The file is opened in a `with` block so its handle is closed once the JSON is parsed.
with open('data/results/instruct_gpt/Inst-GPT3-FS.json', 'r') as results_file:
    results = json.load(results_file)
for r in results:
    print(r['Gold Question'])
    print(r['LP Answers'][-1])
    print("#######################################")


We end up with an empty file similar to `data/results/instruct_gpt/Inst_GPT_label.csv`.

In [None]:
from src.utils import get_cardinality

# print Galois results and check cardinality
get_cardinality(results)
