In [26]:
from sqlglot import parse_one, exp , diff
from sqlglot.errors import ErrorLevel
# Deliberately incomplete query (it ends with a dangling WHERE) used to inspect
# what the parser produces for partial SQL.
query = "SELECT DISTINCT T1.name , T1.release_year FROM singer AS T1 JOIN singer_in_concert AS T2 ON T1.name = T2.Singer_ID WHERE "
# Parsed with ErrorLevel.IGNORE; the recorded output shows that a Select tree
# is still returned for this input.
q_parsed = parse_one(query , error_level=ErrorLevel.IGNORE)
print(type(q_parsed))
# Both walks are created up front and consumed by the loops below.
tree_walk_dfs = q_parsed.dfs()
tree_walk_bfs = q_parsed.bfs()
# Depth-first walk: print every item it yields.
for i in tree_walk_dfs:
    print(i)
print('===========================')
# Breadth-first walk over the same tree.
for i in tree_walk_bfs:
    print(i)

<class 'sqlglot.expressions.Select'>
SELECT DISTINCT T1.name, T1.release_year FROM singer AS T1 JOIN singer_in_concert AS T2 ON T1.name = T2.Singer_ID WHERE
DISTINCT
T1.name
name
T1
T1.release_year
release_year
T1
FROM singer AS T1
singer AS T1
singer
T1
T1
JOIN singer_in_concert AS T2 ON T1.name = T2.Singer_ID
singer_in_concert AS T2
singer_in_concert
T2
T2
T1.name = T2.Singer_ID
T1.name
name
T1
T2.Singer_ID
Singer_ID
T2
WHERE
SELECT DISTINCT T1.name, T1.release_year FROM singer AS T1 JOIN singer_in_concert AS T2 ON T1.name = T2.Singer_ID WHERE
DISTINCT
T1.name
T1.release_year
FROM singer AS T1
JOIN singer_in_concert AS T2 ON T1.name = T2.Singer_ID
WHERE
name
T1
release_year
T1
singer AS T1
singer_in_concert AS T2
T1.name = T2.Singer_ID
singer
T1
singer_in_concert
T2
T1.name
T2.Singer_ID
T1
T2
name
T1
Singer_ID
T2


In [4]:
#Code for sequence-level-voting ensemble using Query Skeleton Similarity and/or Schema-Linking Similarity
from sqlglot import parse_one, exp , diff
from sqlglot.diff import Keep
import random
import json

def jaccard_similarity(set_1 , set_2):
    """Return the Jaccard index of two sets: size of intersection / size of union.

    When both sets are empty the union is empty as well; that case is scored
    as 1 (the two sets are considered identical) instead of dividing by zero.
    """
    shared_count = len(set_1.intersection(set_2))
    total_count = len(set_1.union(set_2))

    # Guard clause for the empty-vs-empty comparison.
    if not total_count:
        return 1

    return shared_count / total_count

def generator_to_set(myGenerator):
    """Collect ``node.this.this`` of every item produced by ``myGenerator`` into a set.

    Used on the results of ``parse_one(query).find_all(exp.Table)`` and
    ``find_all(exp.Column)``; presumably ``.this.this`` is the bare identifier
    name of the table/column (verify against the sqlglot version in use).
    Duplicates collapse because the result is a set.
    """
    # The iterable is fully materialised first, then the attribute is read.
    nodes = tuple(myGenerator)
    return {node.this.this for node in nodes}

def combine_similarities(column_similarity, table_similarity , tree_similarity):
    """Return the unweighted mean of the three similarity scores."""
    # Added left to right in the same order as before so float results match.
    total = column_similarity + table_similarity + tree_similarity
    return total / 3

def get_most_similar_query(query_list):
    """Pick the candidate query that is most similar to the other candidates.

    Every query has its backticks replaced by double quotes and is parsed with
    ``parse_one``. Queries that fail to parse are reported on stdout and take
    no part in the vote. Among the parseable queries, the winner is the one
    whose summed ``query_similarity`` to all the others is largest (the first
    one on ties).

    Args:
        query_list: non-empty list of SQL strings, one per ensemble component.

    Returns:
        A ``(query, index)`` tuple. ``index`` is always the position of the
        chosen query in ``query_list`` (not its position among the parseable
        queries), so callers can use it to look up other data belonging to the
        same ensemble component. When at least one query parses, ``query`` is
        the normalised text (backticks replaced); when none parses, a randomly
        chosen original query is returned.

    Raises:
        ValueError: if ``query_list`` is empty.
    """
    if len(query_list) == 0:
        raise ValueError('query_list must contain at least one query')

    valid_query_list = []
    valid_q_parsed_list = []
    # Position in query_list of every entry kept in the two lists above.
    valid_index_list = []
    for i , query in enumerate(query_list):
        query = query.replace('`' , '"')
        try:
            q_parsed = parse_one(query)
        except Exception:
            # Not a bare except: KeyboardInterrupt/SystemExit must still propagate.
            print('One of the queries are invalid: index = ' , i )
            continue
        valid_q_parsed_list.append(q_parsed)
        valid_query_list.append(query)
        valid_index_list.append(i)

    if len(valid_query_list)==0:
        # Nothing parsed: fall back to a random candidate.
        index = random.randint(0, len(query_list)-1)
        return query_list[index] , index

    if len(valid_query_list)==1:
        # Report where the only parseable query sits in the original list.
        return valid_query_list[0] , valid_index_list[0]

    sim_sum_list = []
    for i , comparing_query in enumerate(valid_q_parsed_list):
        sim_sum = 0
        for j , q in enumerate(valid_q_parsed_list):
            if j == i:
                continue
            sim_sum += query_similarity( comparing_query , q )
        sim_sum_list.append(sim_sum)

    max_sim_position = sim_sum_list.index(max(sim_sum_list))
    # Translate the position among valid queries back to the caller's indexing.
    return valid_query_list[max_sim_position] , valid_index_list[max_sim_position]
    
    
def query_similarity(q1_parsed , q2_parsed):
    """Score how alike two parsed queries are.

    The score is ``combine_similarities`` applied to three components:

    * Jaccard similarity of the names collected from the ``exp.Column`` nodes,
    * Jaccard similarity of the names collected from the ``exp.Table`` nodes,
    * the fraction of entries returned by ``diff(q1_parsed, q2_parsed)`` that
      are ``Keep`` instances.
    """
    def names_of(parsed, node_type):
        # Names of every node of the given expression type in one query.
        return generator_to_set(parsed.find_all(node_type))

    column_similarity = jaccard_similarity(
        names_of(q1_parsed, exp.Column),
        names_of(q2_parsed, exp.Column),
    )
    table_similarity = jaccard_similarity(
        names_of(q1_parsed, exp.Table),
        names_of(q2_parsed, exp.Table),
    )

    # Share of the edit script that leaves nodes untouched.
    edits = diff(q1_parsed , q2_parsed)
    kept = sum(1 for edit in edits if isinstance(edit, Keep))
    tree_similarity = kept / len(edits)

    return combine_similarities( column_similarity , table_similarity , tree_similarity )

def seq_level_ensemble_sql_similarity(pred_file_list , output_file_path , reference_file_index=0 ):
    """Ensemble several prediction files by picking, per question, the most similar query.

    Every file in ``pred_file_list`` is read as JSON and is expected to hold a
    ``'questions'`` list whose entries have ``'response'`` and ``'prompt'``
    keys. For each question index the responses of all files are passed to
    ``get_most_similar_query``; the chosen response, together with the prompt
    taken from the file at the returned index, overwrites that question in a
    copy of the file at ``reference_file_index``. The result is written as
    JSON to ``output_file_path``. The number of questions is taken from the
    first file.
    """
    def read_json(path):
        with open(path , 'r') as handle:
            return json.loads(handle.read())

    # One parsed document per ensemble component.
    component_outputs = [read_json(path) for path in pred_file_list]

    # The template is read again from disk so that it is a separate object
    # from the entries in component_outputs.
    ensembled_output = read_json(pred_file_list[ reference_file_index ])

    question_count = len(component_outputs[0]['questions'])
    for i in range(question_count):
        print('processing question: ' , i)
        candidates = [component['questions'][i]['response'] for component in component_outputs]
        ensembled_query , query_index = get_most_similar_query(candidates)
        target = ensembled_output['questions'][i]
        target['response'] = ensembled_query
        target['prompt'] = component_outputs[ query_index ]['questions'][i]['prompt']

    with open(output_file_path , 'w') as out_file:
        json.dump(ensembled_output , out_file)


In [5]:

# pred_file_1 = './llama_pred/SPIDER-TEST_SQL_0-SHOT_CTX-200_ANS-2048_Llama_7b.json'
# pred_file_2 = './codeS_pred/codes-1b_BIRD_table_num_5_column_num_6_1-shot_max_tokens_8192_max_new_tokens_256.json'
# pred_file_3 = './codeS_pred/codes-1b_BIRD_table_num_5_column_num_6_3-shot_max_tokens_8192_max_new_tokens_256.json'
# pred_file_4 = './codeS_pred/codes-1b_BIRD_table_num_5_column_num_6_5-shot_max_tokens_8192_max_new_tokens_256.json'

# pred_file_1 = './llama_pred/SPIDER_beam_4_lenpen0-TEST_SQL_3-SHOT_0-3_EUCDISMASKPRESKLSIMTHR_QA-EXAMPLE_CTX-200_ANS-2048_llama_7b.json'
# pred_file_2 = './llama_pred/SPIDER_beam_4_lenpen0-TEST_SQL_3-SHOT_3-6_EUCDISMASKPRESKLSIMTHR_QA-EXAMPLE_CTX-200_ANS-2048_llama_7b.json'
# pred_file_3 = './llama_pred/SPIDER_beam_4_lenpen0-TEST_SQL_3-SHOT_6-9_EUCDISMASKPRESKLSIMTHR_QA-EXAMPLE_CTX-200_ANS-2048_llama_7b.json'
# pred_file_4 = './llama_pred/SPIDER_beam_4_lenpen0-TEST_SQL_3-SHOT_9-12_EUCDISMASKPRESKLSIMTHR_QA-EXAMPLE_CTX-200_ANS-2048_llama_7b.json'
# pred_file_5 = './llama_pred/SPIDER_beam_4_lenpen0-TEST_SQL_3-SHOT_12-15_EUCDISMASKPRESKLSIMTHR_QA-EXAMPLE_CTX-200_ANS-2048_llama_7b.json'

# Prediction files to ensemble: one JSON file per ensemble component.
pred_file_1 = './codeS_pred/codes-1b_beam4_lenpen05_BIRD_table_num_5_column_num_6_5-shot_0_5_max_tokens_8192_max_new_tokens_256.json'
pred_file_2 = './codeS_pred/codes-1b_beam4_lenpen05_BIRD_table_num_5_column_num_6_5-shot_5_10_max_tokens_8192_max_new_tokens_256.json'
pred_file_3 = './codeS_pred/codes-1b_beam4_lenpen05_BIRD_table_num_5_column_num_6_5-shot_10_15_max_tokens_8192_max_new_tokens_256.json'
pred_file_4 = './codeS_pred/codes-1b_beam4_lenpen05_BIRD_table_num_5_column_num_6_5-shot_15_20_max_tokens_8192_max_new_tokens_256.json'
pred_file_5 = './codeS_pred/codes-1b_beam4_lenpen05_BIRD_table_num_5_column_num_6_5-shot_20_25_max_tokens_8192_max_new_tokens_256.json'

# pred_file_1 = './llama_pred/BIRD-TEST_SQL_0-SHOT_CTX-200_ANS-2048_evidence_Llama_7b.json'
# pred_file_2 = './llama_pred/BIRD-TEST_SQL_1-SHOT_EUCDISMASKPRESKLSIMTHR_QA-EXAMPLE_CTX-200_ANS-2048_Llama_7b.json'
# pred_file_3 = './llama_pred/BIRD-TEST_SQL_3-SHOT_EUCDISMASKPRESKLSIMTHR_QA-EXAMPLE_CTX-200_ANS-2048_Llama_7b.json'
# pred_file_4 = './llama_pred/BIRD-TEST_SQL_5-SHOT_EUCDISMASKPRESKLSIMTHR_QA-EXAMPLE_CTX-200_ANS-2048_Llama_7b.json'

# The order of this list defines the component indices; reference_file_index
# in the call below refers to a position in it.
pred_files = [ pred_file_1 , pred_file_2 , pred_file_3 , pred_file_4, pred_file_5]
# pred_files = [ pred_file_2 , pred_file_3 , pred_file_4]
# Destination of the ensembled predictions (overwritten on every run).
output_file_path = './codeS_pred/codes-1b_beam4_lenpen05_BIRD_MBRdbScore_table_num_5_column_num_6_5-shot-5Components_max_tokens_8192_max_new_tokens_256.json'

# Runs the similarity-based ensemble, using the first file as the output template.
seq_level_ensemble_sql_similarity(pred_files , output_file_path , reference_file_index=0 )


processing question:  0
processing question:  1
processing question:  2
processing question:  3
processing question:  4
processing question:  5
processing question:  6
processing question:  7
processing question:  8
processing question:  9
processing question:  10
processing question:  11
processing question:  12
processing question:  13
processing question:  14
processing question:  15
processing question:  16
processing question:  17
processing question:  18
processing question:  19
processing question:  20
processing question:  21
processing question:  22
processing question:  23
processing question:  24
processing question:  25
processing question:  26
processing question:  27
processing question:  28
One of the queries are invalid: index =  2
processing question:  29
processing question:  30
processing question:  31
processing question:  32
One of the queries are invalid: index =  3
processing question:  33
processing question:  34
processing question:  35
processing question:  36

In [6]:
#Code for sequence-level-voting ensemble using the BLEU metric.
from nltk.translate.bleu_score import sentence_bleu
from nltk.tokenize import word_tokenize
import json


# pred_file_1 = './llama_pred/SPIDER_beam_4_lenpen0-TEST_SQL_3-SHOT_0-3_EUCDISMASKPRESKLSIMTHR_QA-EXAMPLE_CTX-200_ANS-2048_llama_7b.json'
# pred_file_2 = './llama_pred/SPIDER_beam_4_lenpen0-TEST_SQL_3-SHOT_3-6_EUCDISMASKPRESKLSIMTHR_QA-EXAMPLE_CTX-200_ANS-2048_llama_7b.json'
# pred_file_3 = './llama_pred/SPIDER_beam_4_lenpen0-TEST_SQL_3-SHOT_6-9_EUCDISMASKPRESKLSIMTHR_QA-EXAMPLE_CTX-200_ANS-2048_llama_7b.json'
# pred_file_4 = './llama_pred/SPIDER_beam_4_lenpen0-TEST_SQL_3-SHOT_9-12_EUCDISMASKPRESKLSIMTHR_QA-EXAMPLE_CTX-200_ANS-2048_llama_7b.json'
# pred_file_5 = './llama_pred/SPIDER_beam_4_lenpen0-TEST_SQL_3-SHOT_12-15_EUCDISMASKPRESKLSIMTHR_QA-EXAMPLE_CTX-200_ANS-2048_llama_7b.json'

# Prediction files to ensemble with BLEU: one JSON file per ensemble component.
pred_file_1 = './codeS_pred/codes-1b_beam4_lenpen05_BIRD_table_num_5_column_num_6_5-shot_0_5_max_tokens_8192_max_new_tokens_256.json'
pred_file_2 = './codeS_pred/codes-1b_beam4_lenpen05_BIRD_table_num_5_column_num_6_5-shot_5_10_max_tokens_8192_max_new_tokens_256.json'
pred_file_3 = './codeS_pred/codes-1b_beam4_lenpen05_BIRD_table_num_5_column_num_6_5-shot_10_15_max_tokens_8192_max_new_tokens_256.json'
pred_file_4 = './codeS_pred/codes-1b_beam4_lenpen05_BIRD_table_num_5_column_num_6_5-shot_15_20_max_tokens_8192_max_new_tokens_256.json'
pred_file_5 = './codeS_pred/codes-1b_beam4_lenpen05_BIRD_table_num_5_column_num_6_5-shot_20_25_max_tokens_8192_max_new_tokens_256.json'

# pred_file_1 = './llama_pred/SPIDER-TEST_SQL_0-SHOT_CTX-200_ANS-2048_Llama_7b.json'
# pred_file_2 = './llama_pred/SPIDER-TEST_SQL_1-SHOT_EUCDISQUESTIONMASK_QA-EXAMPLE_CTX-200_ANS-2048_Llama_7b.json'
# pred_file_3 = './llama_pred/SPIDER-TEST_SQL_3-SHOT_EUCDISQUESTIONMASK_QA-EXAMPLE_CTX-200_ANS-2048_Llama_7b_try2.json'
# pred_file_4 = './llama_pred/SPIDER-TEST_SQL_5-SHOT_EUCDISQUESTIONMASK_QA-EXAMPLE_CTX-200_ANS-2048_Llama_7b.json'

# pred_file_1 = './llama_pred/BIRD-TEST_SQL_0-SHOT_CTX-200_ANS-2048_evidence_Llama_7b.json'
# pred_file_2 = './llama_pred/BIRD-TEST_SQL_1-SHOT_EUCDISMASKPRESKLSIMTHR_QA-EXAMPLE_CTX-200_ANS-2048_Llama_7b.json'
# pred_file_3 = './llama_pred/BIRD-TEST_SQL_3-SHOT_EUCDISMASKPRESKLSIMTHR_QA-EXAMPLE_CTX-200_ANS-2048_Llama_7b.json'
# pred_file_4 = './llama_pred/BIRD-TEST_SQL_5-SHOT_EUCDISMASKPRESKLSIMTHR_QA-EXAMPLE_CTX-200_ANS-2048_Llama_7b.json'

# The first prediction file doubles as the output template: its per-question
# 'response' and 'prompt' fields are overwritten below and the result is dumped.
with open(pred_file_1 , 'r') as f:
    refrence_gen_prompt_response_file_byte = f.read()
    refrence_gen_prompt_response = json.loads(refrence_gen_prompt_response_file_byte)

pred_files = [ pred_file_1 , pred_file_2 , pred_file_3 , pred_file_4, pred_file_5]
generated_prompt_response_list = []

# Load every component's predictions; list position == component index.
for file in pred_files:
    with open(file , 'r') as f:
        generated_prompt_response_file_byte = f.read()
        generated_prompt_response = json.loads(generated_prompt_response_file_byte)
        generated_prompt_response_list.append(generated_prompt_response)

for i in range ( len(generated_prompt_response_list[0]['questions']) ): #number of questions we have in a dataset
    # Tokenize the response of every component for question i.
    tokenized_responses = []
    for j in range( len( generated_prompt_response_list ) ): #number of ensemble files we have for each question
        tokenized_responses.append( word_tokenize( generated_prompt_response_list[j]['questions'][i]['response'] ) )
    # Score each response with sentence_bleu, using all the other components'
    # responses as the references.
    blue_scores = []
    for j in range( len( generated_prompt_response_list ) ):
        temp_tokenized_responses = tokenized_responses.copy()
        tokenized_response = temp_tokenized_responses.pop(j)
        blue_score = sentence_bleu( temp_tokenized_responses , tokenized_response )
        blue_scores.append(blue_score)
    # Keep the highest-scoring component (the first one on ties, because
    # list.index returns the first match).
    max_bleu_score_value = max(blue_scores)
    max_index_bleu_score = blue_scores.index(max_bleu_score_value)
    refrence_gen_prompt_response['questions'][i]['response'] = generated_prompt_response_list[max_index_bleu_score]['questions'][i]['response']
    refrence_gen_prompt_response['questions'][i]['prompt'] = generated_prompt_response_list[max_index_bleu_score]['questions'][i]['prompt']
# Write the ensembled predictions; the output path is hard-coded here.
with open('./codeS_pred/codes-1b_beam4_lenpen05_BIRD_MBRbleu_table_num_5_column_num_6_5-shot-5Components_max_tokens_8192_max_new_tokens_256.json' , 'w' )as f:
    json.dump(refrence_gen_prompt_response , f)

    

In [72]:
from utils.post_process import get_exec_output
# Root directory of the databases; query_to_db appends "/<db_id>/<db_id>" to it.
db_dir = './DAIL-SQL/dataset/spider/database'
# db_dir = './DAIL-SQL/dataset/bird/database'
# import asyncio
def query_to_db(query , db_id):
    """Run ``query`` against the database identified by ``db_id``.

    The database path is built as ``<db_dir>/<db_id>/<db_id>`` from the
    module-level ``db_dir`` and passed, together with the query, to
    ``get_exec_output``. The two values that helper produces are returned
    unchanged as a ``(flag, denotation)`` tuple.
    """
    db_path = "{}/{}/{}".format(db_dir, db_id, db_id)
    outcome = get_exec_output(db_path, query)
    # Unpacking keeps the requirement that the helper yields exactly two values.
    flag, denotation = outcome
    return flag, denotation
    

# Execution-accuracy evaluation: every predicted query and its ground-truth
# query are executed on the question's database and their results compared.

# Imported here so the cell does not depend on an earlier cell having
# imported json.
import json

# Questions file: each entry supplies the 'prompt', the reference 'response'
# and the 'db_id' read in the loop below.
generated_prompts_file = './DAIL-SQL/dataset/process/SPIDER-TEST_SQL_0-SHOT_CTX-200_ANS-2048/questions.json'

with open(generated_prompts_file , 'r') as f:
    generated_prompts_file_byte = f.read()
    generated_prompts = json.loads(generated_prompts_file_byte)

# Predictions to evaluate: each entry supplies the predicted 'response'.
generated_response_file = './llama_pred/ENSEMBLE_seqLevelVote_SQLscore_SPIDER-TEST_SQL_0_1_3_5-SHOT_EUCDISQUESTIONMASK_QA-EXAMPLE_CTX-200_ANS-2048_Llama_7b.json'
with open(generated_response_file , 'r') as f:
    generated_response_file_byte = f.read()
    generated_response = json.loads(generated_response_file_byte)

from utils.post_process import result_eq
execution_accuracy = 0
counter = 0
for q_unit_gen , q_unit_truth in zip(generated_response['questions'] , generated_prompts['questions']):
    pred_response = q_unit_gen['response']
    # The ground-truth query is the last line of the prompt followed by the
    # reference response.
    start_of_answer = q_unit_truth['prompt'].splitlines()[-1]
    ground_truth = start_of_answer + ' ' + q_unit_truth['response']
    db_id = q_unit_truth['db_id']
    flag1, denotation1 = query_to_db(pred_response , db_id) #flag has ('result' , [data in columns])
    flag2, denotation2 = query_to_db(ground_truth , db_id)
    if flag1[0] != 'result' or flag2[0]!='result':
        # Either query did not produce a result: count the prediction as wrong.
        is_equal = False
        # print(  counter , '-' , flag1[0] , ' --> ' , 'pred_response: ' , pred_response )
    elif 'order by' in ground_truth.lower():
        # Case-insensitive so that spellings such as "Order By" are also
        # compared with row order taken into account.
        is_equal = result_eq(flag1[1] , flag2[1] , order_matters=True)
    else:
        is_equal = result_eq(flag1[1] , flag2[1] , order_matters=False)
    execution_accuracy += is_equal
    counter += 1
# NOTE(review): zip() stops at the shorter list while the denominator is the
# number of predictions; if the two files differ in length, confirm whether
# `counter` should be the denominator instead.
print( execution_accuracy/len(generated_response['questions']) )


0.31237911025145065


In [84]:
import json

# Two sample strings with embedded newlines, kept once as a list and once as a
# dict keyed by their integer position, then written out as indented JSON.
mystr = ['./sic_ckpts/sic_bird\nalskdjflk lskdjf\nksdijf' , 'sldkjf\n;lkjsdf;lkjafd']
mystrdict = dict(enumerate(mystr))

with open('./test.json' , 'w') as f:
    json.dump(mystr , f , indent = 2, ensure_ascii = False)

with open('./test1.json' , 'w') as f:
    json.dump(mystrdict , f , indent = 2, ensure_ascii = False)

In [13]:
#Playground
import torch
# Build a 1-D integer tensor, then reshape the same four values into 2x2,
# printing size and contents of each.
a = torch.tensor(list(range(4)))
print(a.size(), a, sep='\n')
b = a.reshape(2, 2)
print(b.size(), b, sep='\n')

torch.Size([4])
tensor([0, 1, 2, 3])
torch.Size([2, 2])
tensor([[0, 1],
        [2, 3]])
