In [148]:
import re
def parse_response(response):
    """Extract the last fenced SQL block from a model response.

    The response is searched for Markdown code fences tagged ``sql`` (the tag
    is matched case-insensitively). Leading ``--`` comment lines of the last
    block, together with blank lines or indentation around them, are dropped
    so the returned text starts at the first line of actual SQL.

    Args:
        response: Raw response text.

    Returns:
        The SQL text of the last fenced block, or ``""`` when the response
        contains no such block (a notice is printed in that case).
    """
    pattern = r"```sql\s*(.*?)\s*```"

    sql_blocks = re.findall(pattern, response, re.DOTALL | re.IGNORECASE)

    if not sql_blocks:
        print("No SQL blocks found.")
        return ""

    # Only the last block is used.
    last_sql = sql_blocks[-1].strip()
    while last_sql.startswith('--'):
        # Drop the comment line and re-strip, so that a following blank line
        # or an indented comment does not end the loop prematurely.
        last_sql = '\n'.join(last_sql.splitlines()[1:]).strip()
    return last_sql

# Spot-check a single candidate.
# NOTE(review): `list_of_candidates` is not defined in this cell; it is bound
# by the cell that loads the candidate pickle further down, so this line fails
# on a fresh kernel. The index presumably means "candidate 0 of sample 197 with
# 4 candidates per sample" -- confirm.
print(list_of_candidates[197*4+0])

SELECT 
    CAST(COUNT(CASE WHEN A.element = 'o' THEN 1 ELSE NULL END) AS REAL) / 
    COUNT(DISTINCT B.molecule_id) AS avg_oxygen_atoms
FROM 
    atom AS A
INNER JOIN 
    bond AS B
ON 
    A.molecule_id = B.molecule_id
WHERE 
    B.bond_type = '-';


In [2]:
#Library functions and classes needed to run the preprocessing and the ensemble method implemented in CodeS.
#The ensemble method executes the last beams one by one, in order of their probability, and returns the first one that executes.
from transformers.trainer_utils import set_seed
from utils.classifier_model import SchemaItemClassifier
from utils.db_utils import check_sql_executability, detect_special_char
from schema_item_filter import SchemaItemClassifierInference, filter_schema
import pickle as pkl
from utils_SPIDER.post_process import process_duplication

def post_process_get_sql_from_gentext(gen_text):
    """Turn generated text into a single-line SQL statement.

    All whitespace (newlines included) is collapsed to single spaces, the
    result is passed through ``process_duplication``, and ``SELECT`` is
    prepended unless the text already opens a statement.

    Args:
        gen_text: Text produced by the model; either a full query or the part
            that follows a leading ``SELECT``.

    Returns:
        The normalized SQL string.
    """
    # str.split() without arguments already splits on newlines, so this both
    # removes line breaks and squeezes repeated spaces.
    sql = " ".join(gen_text.split())
    # lstrip covers the case where the helper hands back leading whitespace.
    sql = process_duplication(sql).lstrip()

    # A text that already is a full statement must not get a second SELECT.
    # The check ignores letter case and also accepts CTE queries ("WITH ...");
    # a case-sensitive test would produce "SELECT select ..." or
    # "SELECT WITH ...", neither of which is valid SQL.
    if sql.lower().startswith(("select", "with ")):
        return sql
    return "SELECT " + sql


def post_process(sql, schema_items):
    """Backtick-quote special column names and cut the SQL at the first ';'.

    Args:
        sql: SQL text to clean up.
        schema_items: Iterable of table dicts; each must provide a
            ``"column_names"`` list.

    Returns:
        The SQL on one line, with every column name flagged by
        ``detect_special_char`` wrapped in backticks, truncated at the first
        ``;`` and terminated by exactly one ``;``.
    """
    import re  # local import keeps this block self-contained

    sql = sql.replace("\n", " ")

    special_names = {
        column_name
        for table in schema_items
        for column_name in table["column_names"]
        if column_name and detect_special_char(column_name)
    }

    if special_names:
        # Quote all names in ONE pass, trying longer names first. Replacing the
        # names one after another breaks when a name is contained in another
        # one (e.g. "School Type" inside "Magnet School Type"): the shorter
        # name gets quoted inside the longer one and the result is no longer a
        # valid identifier.
        longest_first = sorted(special_names, key=lambda name: (-len(name), name))
        pattern = "|".join(re.escape(name) for name in longest_first)
        sql = re.sub(pattern, lambda match: "`" + match.group(0) + "`", sql)

    # Names that were already quoted in the input now carry doubled backticks.
    while "``" in sql:
        sql = sql.replace("``", "`")

    # Keep only the first statement.
    sql = sql.split(";")[0].strip() + ";"

    return sql

# Load the pickled evaluation set; perform_codeS_ensemble_on_list_of_sqlintext
# below reads it as a module-level name.
# NOTE(review): pickle.load can run arbitrary code -- only load trusted files.
with open('./eval_set.pkl', 'rb') as eval_set_file:
    eval_set = pkl.load(eval_set_file)
def codeS_ensemble(generated_sqls , eval_data):
    """Pick the final SQL among the candidates of one sample.

    Every candidate is cleaned with ``post_process_get_sql_from_gentext`` and
    ``post_process``. The candidates are then tried in the given order and the
    first one for which ``check_sql_executability`` reports no error is
    returned.

    Args:
        generated_sqls: Candidate SQL texts, sorted by decreasing probability
            (index 0 has the highest priority).
        eval_data: Sample dict providing ``["schema"]["schema_items"]`` and
            ``["db_path"]``.

    Returns:
        The first executable candidate; if none executes, the first candidate
        (when non-blank); ``"SQL placeholder"`` when there is nothing to
        return, including an empty candidate list.
    """
    candidates = [
        post_process(
            post_process_get_sql_from_gentext(generated_sql),
            eval_data["schema"]["schema_items"],
        )
        for generated_sql in generated_sqls
    ]

    for candidate in candidates:
        # None means the candidate ran without an execution error.
        if check_sql_executability(candidate, eval_data["db_path"]) is None:
            return candidate

    # Nothing executed: fall back to the most probable candidate. The
    # emptiness check avoids an IndexError for an empty candidate list.
    if candidates and candidates[0].strip() != "":
        return candidates[0]
    return "SQL placeholder"


def perform_codeS_ensemble_on_list_of_sqlintext( list_of_candidates , candidate_range, candidate_indice ):
    """Run ``codeS_ensemble`` for every sample of a flat candidate list.

    Args:
        list_of_candidates: Flat list of candidate SQL texts. Candidates of
            the same sample are adjacent and every sample owns
            ``candidate_range`` consecutive slots.
        candidate_range: Number of slots per sample in ``list_of_candidates``.
        candidate_indice: Offsets (0-based, inside one sample's slots) of the
            candidates to use, in priority order.

    Returns:
        A list with the chosen SQL of every processed sample, or ``None``
        (after printing a message) when ``len(list_of_candidates)`` is not a
        multiple of ``candidate_range``.

    Raises:
        ValueError: If ``candidate_range`` is not positive or an offset lies
            outside ``[0, candidate_range)``.

    Note:
        Samples are paired with the module-level ``eval_set`` through ``zip``,
        so processing stops at the shorter of the two sequences.
    """
    if candidate_range <= 0:
        raise ValueError(f'candidate_range must be positive, got {candidate_range}')

    candidate_indice = list(candidate_indice)
    # An offset outside the sample's own slots would silently pick candidates
    # that belong to a neighbouring sample.
    out_of_range = [offset for offset in candidate_indice if not 0 <= offset < candidate_range]
    if out_of_range:
        raise ValueError(
            f'candidate_indice entries {out_of_range} are outside [0, {candidate_range})'
        )

    if len(list_of_candidates)%candidate_range:
        print( 'There is a problem and the given candidate_range is not correct. len(list_of_candidates)%candidate_range!=0' )
        return

    final_output_seqs = []
    sample_starts = range(0, len(list_of_candidates), candidate_range)
    for index, eval_data in zip(sample_starts, eval_set):
        generated_sqls = [list_of_candidates[index + sub_index] for sub_index in candidate_indice]
        final_output_seqs.append(codeS_ensemble(generated_sqls, eval_data))
        print('processed index ', index // candidate_range)
    return final_output_seqs
    
    

  from .autonotebook import tqdm as notebook_tqdm


In [132]:
#Performing codeS ensemble on the last beams of the beam search, for approaches where the start of the answer is the same for every candidate
#and the output is given as a list of torch tensors of shape: (input_ids_len, num_beams)
from transformers import AutoTokenizer
import json

# --- Configuration -----------------------------------------------------------
model_name = 'seeklhy/codes-1b'
dataset_path = './data/sft_bird_dev_text2sql.json'
sic_path = './sic_ckpts/sic_bird'
table_num = 5
column_num = 6
output_filename = '../../final_output_CodeS_preprocess_ensemble.pkl'

tokenizer = AutoTokenizer.from_pretrained(model_name)

# The evaluation set is read from a cached pickle. The commented lines show how
# it can be rebuilt from the raw dataset with the schema item classifier.
# eval_set = json.load(open(dataset_path))
# sic = SchemaItemClassifierInference(sic_path)
# eval_set = filter_schema(eval_set, "eval", sic, table_num, column_num)
with open('./eval_set.pkl', 'rb') as eval_set_file:
    eval_set = pkl.load(eval_set_file)

# NOTE(review): `input_ids` and `starting_batch` are not defined in this cell;
# they have to exist in the kernel already.
final_output_seqs = []
counter = 0
for ids, starting, eval_data in zip(input_ids, starting_batch, eval_set):
    # Keep the rows from `starting` on and swap the two axes before decoding
    # (presumably so that each decoded row is one beam -- confirm).
    beam_token_ids = ids[starting:, :].transpose(0, 1)
    generated_sqls = tokenizer.batch_decode(beam_token_ids, skip_special_tokens=True)
    print(counter)
    counter += 1
    final_generated_sql = codeS_ensemble(generated_sqls, eval_data)
    final_output_seqs.append(final_generated_sql)

with open(output_filename, 'wb') as output_file:
    pkl.dump(final_output_seqs, output_file)

In [151]:
#For when the output sql queries are given in a list of text. This is for clause level
import pickle as pkl
# Pickle holding the flat list of candidate texts. Exactly one assignment is
# active; the commented ones are other experiment outputs.
# NOTE(review): pickle.load can run arbitrary code -- only load trusted files.
# input_ids_filename = '../../output_CodeS_beam4_bird_wEvidence_1shot_SFTCodeS-7b_structured/output_CodeS_beam4_bird_wEvidence_1shot_SFTCodeS-7b_structured.pkl' #list of sql candidates in text
input_ids_filename = '../../vulcan_output/output_OmniORG_greedy_bird_wEvidence_0shot_OmniSQL-7b_chatTemplate.pkl' #list of sql candidates in text
# input_ids_filename = '../../output_1shot_diffDivpen_8in0-2shotComps_beam4_seq_mine_scaled_pureSeqlevelTreebased_SFTCodeS-7b_a1b1jointMargin6compProb.pkl' #list of sql candidates in text
# input_ids_filename = '../output_CodeS_beam4_bird_wEvidence_0-5shot_SFTCodeS-7b.pkl'
with open(input_ids_filename , 'rb') as f:
    list_of_candidates = pkl.load(f)
print(len(list_of_candidates))

# For OmniSQL: keep only the last fenced SQL block of every response.
# This rebinds list_of_candidates, so re-running only the lines from here on
# would parse the already-parsed texts again; re-run the whole cell instead.
list_of_candidates = [  parse_response(i) for i in list_of_candidates]

# File that receives the list of chosen SQL strings.
output_filename = '../../final_output_CodeS_preprocess_ensemble.pkl'


# Debug helper: compare the candidates inside each group of 20 with the first one of the group.
# for x in range( 0 , len(list_of_candidates) , 20):
#     for i in range (20):
#         print( list_of_candidates[i+x]==list_of_candidates[x] )
#     print('-----------------------------')
# candidate_range: number of consecutive slots every sample owns in list_of_candidates.
# candidate_range = 40
# candidate_range = 32
# candidate_range = 25
# candidate_range = 20
# candidate_range = 10
# candidate_range = 5
# candidate_range = 4
# candidate_range = 2
candidate_range = 1
# candidate_indice: offsets inside a sample's slots, in the priority order used by the ensemble.
# candidate_indice = [0,4,8,12,16,1,5,9,13,17,2,6,10,14,18,3,7,11,15,19]
# candidate_indice = [0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19]
# candidate_indice = [0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19,20,21,22,23,24]
# candidate_indice = [4,5,6,7]
# candidate_indice = [8,9,10,11]
# candidate_indice = [12,13,14,15]
# candidate_indice = [16,17,18,19]
# candidate_indice = [3]
candidate_indice = list(range(candidate_range))
# candidate_indice = [0,1,2,3,4,5,6,7,8,9]
# candidate_indice = [0,1,2,3,4]
# candidate_indice = [0,1,2,3]
# candidate_indice = [0,1]
# Choose one SQL per sample and store the resulting list.
final_output_seqs = perform_codeS_ensemble_on_list_of_sqlintext( list_of_candidates , candidate_range, candidate_indice )
with open(output_filename , 'wb')as f:
    pkl.dump(final_output_seqs , f)
    

1534
processed index  0
SQL execution runtime error: no such column: schools.Educational Option Type.
processed index  1
processed index  2
processed index  3
processed index  4
processed index  5
processed index  6
processed index  7
processed index  8
SQL execution runtime error: no such column: T1.Charter Funding Type.
processed index  9
processed index  10
processed index  11
processed index  12
processed index  13
processed index  14
processed index  15
processed index  16
processed index  17
processed index  18
processed index  19
processed index  20
processed index  21
processed index  22
processed index  23
processed index  24
processed index  25
SQL execution runtime error: no such column: f.FRPM Count (Ages 15-17).
processed index  26
processed index  27
processed index  28
SQL execution runtime error: no such column: s.School Type.
processed index  29
processed index  30
SQL execution runtime error: near "AS": syntax error.
processed index  31
SQL execution runtime error: no

In [4]:
import pickle as pkl
# Shard files to merge, in the order in which they are concatenated.
# Exactly one group of assignments is active; the commented groups belong to other experiments.
# file1 = '../../output_sequences_lenpen05_clause_mine_JaccSimNormalizedNoDel_lastForAllTreeBaseNoDel_extCls_0_383.pkl'
# file2 = '../../output_sequences_lenpen05_clause_mine_JaccSimNormalizedNoDel_lastForAllTreeBaseNoDel_extCls_383_767.pkl'
# file3 = '../../output_sequences_lenpen05_clause_mine_JaccSimNormalizedNoDel_lastForAllTreeBaseNoDel_extCls_767_1150.pkl'
# file4 = '../../output_sequences_lenpen05_clause_mine_JaccSimNormalizedNoDel_lastForAllTreeBaseNoDel_extCls_1150_end.pkl'

file1 = '../../output_CodeS_beam4_bird_wEvidence_1shot_SFTCodeS-7b_structured_0_767.pkl'
file2 = '../../output_CodeS_beam4_bird_wEvidence_1shot_SFTCodeS-7b_structured_767_end.pkl'

# file1 = '../../output_sequences_spider_lenpen01_seq_mine_treeSim_Qwen_NoDel_0_517.pkl'
# file2 = '../../output_sequences_spider_lenpen01_seq_mine_treeSim_Qwen_NoDel_517_end.pkl'


# file1 = '../../output_sequences_lenpen05_clause_mine_JaccSimNormalized_tokTresh70_lastForAllTreeBaseNoDel_extCls_0_787.pkl'
# file2 = '../../output_sequences_lenpen05_clause_mine_JaccSimNormalized_tokTresh70_lastForAllTreeBaseNoDel_extCls_787_end.pkl'

# Destination of the merged output list.
target_file = '../../output_CodeS_beam4_bird_wEvidence_1shot_SFTCodeS-7b_structured/output_CodeS_beam4_bird_wEvidence_1shot_SFTCodeS-7b_structured.pkl'
# file_list = [file1, file2, file3 , file4]
file_list = [file1, file2]

def unify_outputs( file_list , target_file ):
    """Concatenate sharded pickle outputs into a single pair of target pickles.

    For every shard ``<name>.pkl`` a companion ``<name>_inputs_log_prob.pkl``
    is read as well. The concatenated lists are written to ``target_file`` and
    to ``<target>_inputs_log_prob.pkl``.

    The companion names are built by cutting the last four characters of the
    path, i.e. the paths are expected to end in a 4-character extension such
    as ``.pkl``.

    Args:
        file_list: Shard paths, in the order in which they are concatenated.
        target_file: Path of the merged output pickle.

    Returns:
        None. Nothing is written when ``file_list`` is empty.
    """
    target_output_list = []
    target_inputs_log_prob_list = []
    read_any_shard = False

    for file_name in file_list:
        inputs_log_prob_file = file_name[0:-4] + '_inputs_log_prob.pkl'

        with open(file_name, 'rb') as f:
            target_output_list.extend(pkl.load(f))

        with open(inputs_log_prob_file, 'rb') as f:
            target_inputs_log_prob_list.extend(pkl.load(f))

        read_any_shard = True

    if not read_any_shard:
        return

    # Write once, after ALL shards were read: writing inside the loop rewrote
    # the targets for every shard and left a truncated merge on disk whenever
    # a later shard could not be read.
    with open(target_file, 'wb') as f:
        pkl.dump(target_output_list, f)

    target_inputs_log_prob_file = target_file[0:-4] + '_inputs_log_prob.pkl'
    with open(target_inputs_log_prob_file, 'wb') as f:
        pkl.dump(target_inputs_log_prob_list, f)


# Merge the shards listed in file_list into target_file (and the matching
# *_inputs_log_prob.pkl companion file).
unify_outputs( file_list , target_file )
# Earlier inline version of the merge (outputs only, no log-prob files), kept for reference:
# target_list = []
# with open( file1 , 'rb' ) as f:
#     target_list = pkl.load(f)
# with open( file2 , 'rb' ) as f:
#     target_list.extend( pkl.load(f) )
# # with open( file3 , 'rb' ) as f:
# #     target_list.extend( pkl.load(f) )

# with open(target_file , 'wb')as f:
#     pkl.dump( target_list , f )

In [64]:
### Code for comparing the result of two experiments.
def sort_numbers_in_file(filename):
    """Sort the integers stored one per line in ``filename``, in place.

    Blank lines are ignored on input and not written back. Any error (missing
    file, non-numeric line, ...) is printed instead of raised, and in that
    case the file is left as it was unless the failure happened while writing.

    Args:
        filename: Path of the text file to sort.
    """
    try:
        with open(filename, 'r') as file:
            # int() tolerates surrounding whitespace; blank lines (for example
            # an empty last line) are skipped because int('') would raise and
            # abort the whole sort.
            numbers = sorted(int(line) for line in file if line.strip())

        # The file is reopened for writing only after everything was parsed.
        with open(filename, 'w') as file:
            file.writelines(f"{number}\n" for number in numbers)
    except Exception as e:
        print(f"Error: {e}")

def difference_between_files(file1, file2):
    """Print the integers that occur in ``file1`` but not in ``file2``.

    Both files hold one integer per line; blank lines are ignored. The
    function prints the number of distinct values of each file, the sorted
    difference and its length. Nothing is written to disk and nothing is
    returned. Errors are printed instead of raised.

    Args:
        file1: Path of the file whose extra numbers are reported.
        file2: Path of the file to compare against.
    """
    try:
        with open(file1, 'r') as f1, open(file2, 'r') as f2:
            # Blank lines are skipped: int('') would raise and hide the result.
            numbers1 = {int(line) for line in f1 if line.strip()}
            numbers2 = {int(line) for line in f2 if line.strip()}

        difference = sorted(numbers1 - numbers2)
        print(f'{file1} set length: {len(numbers1)}')
        print(f'{file2} set length: {len(numbers2)}')
        print(difference)
        print('length of the list: ' , len(difference))
    except Exception as e:
        print(f"Error: {e}")

# Optional: sort a result file in place before comparing.
# sort_numbers_in_file("exec_result.txt")
# NOTE(review): the commented call below passes three arguments, but
# difference_between_files takes two (file1, file2) and only prints.
# difference_between_files("file1.txt", "file2.txt", "output.txt")
first = 'exec_result_divpen4_2.txt'
second = 'exec_result_divpen2_2.txt'
# Compare the two result files in both directions.
difference_between_files( first ,  second  ) # numbers present in first but not in second
print('__________________________')
difference_between_files( second ,  first  ) # numbers present in second but not in first

exec_result_divpen4_2.txt set length: 953
exec_result_divpen2_2.txt set length: 1004
[8, 59, 251, 326, 592, 708, 724, 879, 883, 1107, 1241, 1366, 1437, 1499, 1528]
length of the list:  15
__________________________
exec_result_divpen2_2.txt set length: 1004
exec_result_divpen4_2.txt set length: 953
[41, 54, 85, 90, 101, 122, 131, 146, 171, 188, 223, 232, 250, 286, 288, 319, 370, 384, 415, 429, 430, 445, 449, 462, 476, 478, 495, 506, 507, 522, 525, 556, 584, 590, 596, 617, 677, 778, 782, 851, 899, 900, 930, 936, 954, 995, 1001, 1037, 1093, 1125, 1150, 1165, 1194, 1201, 1202, 1243, 1259, 1260, 1282, 1285, 1289, 1301, 1387, 1390, 1490, 1530]
length of the list:  66


In [15]:
import json
# Collect the positions of all dev questions with the selected difficulty.
with open('./data/sft_data_collections/bird/dev/dev.json', 'r') as dev_file:
    contents = json.load(dev_file)

# Other values that were used for the comparison: 'simple', 'challenging'.
my_list = [
    position
    for position, entry in enumerate(contents)
    if entry['difficulty'] == 'moderate'
]
print(my_list)

[1, 4, 12, 23, 24, 25, 26, 27, 31, 32, 33, 34, 35, 37, 40, 43, 45, 47, 48, 49, 55, 65, 68, 72, 74, 76, 77, 79, 81, 85, 89, 93, 95, 98, 99, 100, 117, 118, 119, 120, 128, 129, 130, 131, 135, 136, 137, 138, 144, 145, 148, 150, 151, 152, 163, 168, 171, 175, 180, 182, 185, 186, 188, 189, 192, 193, 194, 197, 201, 208, 213, 222, 226, 228, 232, 234, 236, 237, 242, 243, 244, 245, 246, 250, 251, 254, 255, 258, 260, 267, 270, 272, 273, 280, 283, 284, 287, 298, 303, 310, 317, 320, 327, 329, 338, 344, 345, 346, 347, 349, 352, 360, 391, 397, 401, 402, 405, 407, 408, 409, 412, 417, 427, 432, 433, 434, 446, 450, 459, 462, 465, 466, 468, 469, 472, 473, 474, 479, 480, 483, 484, 486, 498, 499, 500, 501, 505, 508, 511, 515, 516, 517, 518, 520, 522, 527, 529, 530, 544, 557, 563, 565, 571, 572, 578, 581, 584, 587, 595, 604, 615, 633, 635, 637, 640, 652, 654, 657, 665, 672, 682, 683, 685, 692, 694, 707, 708, 716, 719, 723, 726, 728, 732, 733, 736, 739, 740, 751, 753, 758, 761, 766, 782, 786, 790, 794, 796, 7

In [11]:
#The code to provide the oracle prediction on Component probability:
from collections import Counter

prediction_file = 'exec_result_divpen2_2.txt'
NUMBER_of_beams = int(4) #Don't forget to set this properly!!!!
# Number of questions a probability is produced for
# (presumably the size of the evaluated dev set -- confirm).
NUM_QUESTIONS = 1534

sort_numbers_in_file( prediction_file )

# One line per correct candidate; the number on the line is the question index,
# so a question appears once for each of its correct candidates.
with open( prediction_file , 'r' ) as f:
    correct_q_list = [ int(line) for line in f if line.strip() ]

# Count all occurrences in one pass instead of calling list.count() once per
# question, which rescans the whole list every time.
correct_counts = Counter(correct_q_list)

# Oracle probability: share of correct candidates per question. The 0.1 keeps
# the value strictly positive for questions without any correct candidate.
question_prob = [
    (correct_counts[i] + 0.1) / NUMBER_of_beams
    for i in range(NUM_QUESTIONS)
]

with open('../../oracle_comp_prob_per_sample_divpen2_2.pkl' , 'wb') as f:
    pkl.dump(question_prob, f)
    
    

In [48]:
import math
from scipy.stats import kendalltau, pearsonr
# Oracle probabilities: one pickled list per experiment run. The order of the
# runs matters; it has to match Qquestion_prob and Cquestion_prob below.
question_prob = []
for oracle_prob_path in (
    # '../../oracle_comp_prob_per_sample_1.pkl',
    '../../oracle_comp_prob_per_sample_2.pkl',
    '../../oracle_comp_prob_per_sample_3.pkl',
    # '../../oracle_comp_prob_per_sample_4.pkl',
    # '../../oracle_comp_prob_per_sample_5.pkl',
    '../../oracle_comp_prob_per_sample_divpen05_1.pkl',
    '../../oracle_comp_prob_per_sample_divpen05_2.pkl',
    '../../oracle_comp_prob_per_sample_divpen1_1.pkl',
    '../../oracle_comp_prob_per_sample_divpen1_2.pkl',
    '../../oracle_comp_prob_per_sample_divpen2_1.pkl',
    '../../oracle_comp_prob_per_sample_divpen2_2.pkl',
):
    with open(oracle_prob_path, 'rb') as oracle_prob_file:
        question_prob.append(pkl.load(oracle_prob_file))

# Question log-probabilities, one pickled object per run, in the same run
# order as question_prob.
Qquestion_prob = []
for question_log_prob_path in (
    # '../../vulcan_output/output_CodeS_beam4_bird_wEvidence_1shot_SFTCodeS-7b_lastOf0-5shot_question_log_prob.pkl',
    '../../vulcan_output/output_CodeS_beam4_bird_wEvidence_1shot_SFTCodeS-7b_2ndOf0-5shot_question_log_prob.pkl',
    '../../vulcan_output/output_CodeS_beam4_bird_wEvidence_1shot_SFTCodeS-7b_3rdOf0-5shot_question_log_prob.pkl',
    # '../../vulcan_output/output_CodeS_beam4_bird_wEvidence_1shot_SFTCodeS-7b_4thOf0-5shot_question_log_prob.pkl',
    # '../../vulcan_output/output_CodeS_beam4_bird_wEvidence_1shot_SFTCodeS-7b_5thOf0-5shot_question_log_prob.pkl',
    '../../vulcan_output/output_CodeS_beam4_divpen05_bird_wEvidence_1shot_SFTCodeS-7b_1stOf0-5shot_question_log_prob.pkl',
    '../../vulcan_output/output_CodeS_beam4_divpen05_bird_wEvidence_1shot_SFTCodeS-7b_2ndOf0-5shot_question_log_prob.pkl',
    '../../vulcan_output/output_CodeS_beam4_divpen1_bird_wEvidence_1shot_SFTCodeS-7b_1stOf0-5shot_question_log_prob.pkl',
    '../../vulcan_output/output_CodeS_beam4_divpen1_bird_wEvidence_1shot_SFTCodeS-7b_2ndOf0-5shot_question_log_prob.pkl',
    '../../vulcan_output/output_CodeS_beam4_divpen2_bird_wEvidence_1shot_SFTCodeS-7b_1stOf0-5shot_question_log_prob.pkl',
    '../../vulcan_output/output_CodeS_beam4_divpen2_bird_wEvidence_1shot_SFTCodeS-7b_2ndOf0-5shot_question_log_prob.pkl',
):
    with open(question_log_prob_path, 'rb') as question_log_prob_file:
        Qquestion_prob.append(pkl.load(question_log_prob_file))

# Context log-probabilities, one pickled object per run, in the same run order
# as question_prob and Qquestion_prob.
Cquestion_prob = []
for context_log_prob_path in (
    # '../../vulcan_output/output_CodeS_beam4_bird_wEvidence_1shot_SFTCodeS-7b_lastOf0-5shot_context_log_prob.pkl',
    '../../vulcan_output/output_CodeS_beam4_bird_wEvidence_1shot_SFTCodeS-7b_2ndOf0-5shot_context_log_prob.pkl',
    '../../vulcan_output/output_CodeS_beam4_bird_wEvidence_1shot_SFTCodeS-7b_3rdOf0-5shot_context_log_prob.pkl',
    # '../../vulcan_output/output_CodeS_beam4_bird_wEvidence_1shot_SFTCodeS-7b_4thOf0-5shot_context_log_prob.pkl',
    # '../../vulcan_output/output_CodeS_beam4_bird_wEvidence_1shot_SFTCodeS-7b_5thOf0-5shot_context_log_prob.pkl',
    '../../vulcan_output/output_CodeS_beam4_divpen05_bird_wEvidence_1shot_SFTCodeS-7b_1stOf0-5shot_context_log_prob.pkl',
    '../../vulcan_output/output_CodeS_beam4_divpen05_bird_wEvidence_1shot_SFTCodeS-7b_2ndOf0-5shot_context_log_prob.pkl',
    '../../vulcan_output/output_CodeS_beam4_divpen1_bird_wEvidence_1shot_SFTCodeS-7b_1stOf0-5shot_context_log_prob.pkl',
    '../../vulcan_output/output_CodeS_beam4_divpen1_bird_wEvidence_1shot_SFTCodeS-7b_2ndOf0-5shot_context_log_prob.pkl',
    '../../vulcan_output/output_CodeS_beam4_divpen2_bird_wEvidence_1shot_SFTCodeS-7b_1stOf0-5shot_context_log_prob.pkl',
    '../../vulcan_output/output_CodeS_beam4_divpen2_bird_wEvidence_1shot_SFTCodeS-7b_2ndOf0-5shot_context_log_prob.pkl',
):
    with open(context_log_prob_path, 'rb') as context_log_prob_file:
        Cquestion_prob.append(pkl.load(context_log_prob_file))

def sort_by_other_list(main_list, ref_list):
    """Return the items of ``main_list`` ordered by descending ``ref_list`` value.

    The two lists are paired position by position. Pairs with an equal
    reference value are ordered by the item itself, also descending.
    """
    keyed_items = list(zip(ref_list, main_list))
    keyed_items.sort(reverse=True)
    return [item for _ref_value, item in keyed_items]

# Compare, per question, the oracle ranking of the runs with the ranking given
# by a linear mix of context and question log-probabilities.

# One label per run; the count has to match the number of loaded runs.
main_items = ['a', 'b', 'c', 'd', 'e', 'f', 'g', 'h']#, 'i' , 'j']

# Score = a * context log-prob + b * question log-prob
a = -1
b = 7

# Name of the statistic computed in the loop; keep it in sync with the active
# "tau, _ = ..." line so that the printed label stays truthful.
metric_name = "Pearson r"

total = 0
evaluating = 0
for index in range(1534):

    # Ground truth: log of the oracle probability of every run for this question.
    gt_ref_list = [ math.log(q[index]) for q in question_prob ]
    gt_order = sort_by_other_list(main_items , gt_ref_list)

    new_ref_list = [ a * p[index] + b * q[index] for p, q in zip(Cquestion_prob , Qquestion_prob) ]
    new_order = sort_by_other_list(main_items , new_ref_list)

    # tau, _ = kendalltau(gt_order, new_order) #Kendall test
    # Pearson test on two aggregated points (sum over the even-positioned runs
    # and sum over the odd-positioned runs).
    # NOTE(review): a correlation over two points can only be +1, -1 or
    # undefined (nan) -- confirm that this aggregation is intended.
    tau, _ = pearsonr([sum(gt_ref_list[::2]),sum(gt_ref_list[1::2])], [sum(new_ref_list[::2]),sum(new_ref_list[1::2])]) #pearson test

    # Skip questions where all runs share the same oracle value and questions
    # for which the statistic is undefined.
    if len(set(gt_ref_list)) > 1 and not math.isnan(tau):
        evaluating += 1
        total += tau

print('evaluating: ' , evaluating)
if evaluating:
    print(f"Mean {metric_name}:", total/evaluating)
else:
    # Avoids a ZeroDivisionError when no question qualified.
    print(f"Mean {metric_name}: undefined (no question could be evaluated)")


    

evaluating:  559
Kendall's Tau: 0.0626118067978533


  tau, _ = pearsonr([sum(gt_ref_list[::2]),sum(gt_ref_list[1::2])], [sum(new_ref_list[::2]),sum(new_ref_list[1::2])]) #pearson test
