In [2]:
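# Code for a sequence-level voting ensemble using my custom scoring mechanism: candidates are compared with a
# similarity that averages Jaccard similarity of the referenced columns and tables with a tree-based (sqlglot AST diff)
# similarity, rather than relying on Jaccard similarity alone.
# This ensemble is performed on the last beams gathered from the clause-level ensemble.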
import torch
from nltk.tokenize import word_tokenize
from sqlglot import parse_one, exp , diff
from sqlglot.diff import Keep
from collections import defaultdict , Counter
import numpy as np
import random
import math

def jaccard_similarity(set_1 , set_2):
    """Return the Jaccard index of two sets: size of intersection / size of union.

    Two empty sets are considered identical, so the score is 1 in that case.
    """
    union_size = len(set_1.union(set_2))
    if not union_size:
        return 1
    shared_size = len(set_1.intersection(set_2))
    return shared_size / union_size

def generator_to_set(myGenerator):
    """Collect ``node.this.this`` for every node yielded by ``myGenerator`` into a set.

    Used on the iterators from ``parsed.find_all(exp.Column)`` / ``find_all(exp.Table)``;
    presumably ``node.this`` is the sqlglot identifier and ``.this.this`` its name
    string (confirm against sqlglot). The iterator is fully consumed.
    """
    return {node.this.this for node in myGenerator}

def combine_similarities(column_similarity, table_similarity , tree_similarity):
    """Unweighted mean of the column, table and tree similarity scores."""
    components = (column_similarity, table_similarity, tree_similarity)
    return sum(components) / len(components)
    
def query_similarity(q1_parsed , q2_parsed):
    """Score how alike two parsed SQL queries are, on a 0..1 scale.

    The score is the unweighted mean (via ``combine_similarities``) of:
      * the Jaccard similarity of the column names referenced by each query,
      * the Jaccard similarity of the table names referenced by each query,
      * the fraction of edit operations in sqlglot's tree ``diff`` that are ``Keep``.

    Args:
        q1_parsed, q2_parsed: sqlglot expression trees, as returned by ``parse_one``.
    """
    def names_of(node_type, tree):
        # Names of every node of `node_type` anywhere in `tree`.
        return generator_to_set(tree.find_all(node_type))

    column_similarity = jaccard_similarity(names_of(exp.Column, q1_parsed), names_of(exp.Column, q2_parsed))
    table_similarity = jaccard_similarity(names_of(exp.Table, q1_parsed), names_of(exp.Table, q2_parsed))

    edit_script = diff(q1_parsed, q2_parsed)
    kept = sum(1 for edit in edit_script if isinstance(edit, Keep))
    tree_similarity = kept / len(edit_script)

    return combine_similarities(column_similarity, table_similarity, tree_similarity)


def similarity_meassure( hypothesis_query, reference_query, in_table_keywords):
    """Similarity between two SQL strings, or 0 if either one fails to parse.

    Backticks are rewritten to double quotes before parsing, presumably so that
    MySQL-style backtick-quoted identifiers parse under sqlglot's default dialect.

    Args:
        hypothesis_query: candidate SQL string.
        reference_query: SQL string it is compared against.
        in_table_keywords: currently unused; kept so existing call sites keep working.

    Returns:
        ``query_similarity`` of the two parsed queries, or 0 when parsing fails.
    """
    hypothesis_query = hypothesis_query.replace('`' , '"')
    reference_query = reference_query.replace('`' , '"')
    try:
        hypothesis_parsed = parse_one(hypothesis_query)
        reference_parsed = parse_one(reference_query)
    except Exception:
        # Unparseable SQL (e.g. a truncated beam) contributes no similarity.
        # Catch `Exception`, not everything, so Ctrl-C / SystemExit still propagate.
        return 0
    return query_similarity( hypothesis_parsed , reference_parsed )

def ensemble( decoded_text_list , inputs_log_prob , batch_text ):
    """Pick one SQL candidate by sequence-level voting over all candidates.

    Each candidate ``j`` is scored as

        log p(j) + logsumexp over k != j of [ log p(k) + similarity(j, k) ]

    where ``log p`` is the candidate's log-probability from ``inputs_log_prob`` and
    ``similarity`` is ``similarity_meassure``. The highest-scoring candidate is
    returned; ties go to the earliest candidate.

    Args:
        decoded_text_list: list of candidate SQL strings, ordered component-major:
            candidate ``c * num_beams + b`` is beam ``b`` of component ``c``
            (the ordering implied by the log-prob indexing below).
        inputs_log_prob: list with one torch tensor per component, indexed as
            ``[0, beam]``; presumably shape (1, num_beams) -- confirm with caller.
        batch_text: list of prompt strings, one per component. Only the first is
            read, to cut out the schema section that is tokenized and passed to
            ``similarity_meassure``.

    Returns:
        The selected candidate string from ``decoded_text_list``.
    """
    num_beams = inputs_log_prob[0].size(-1)

    def log_prob_of(candidate_index):
        # Map a flat candidate index back to (component, beam).
        component, beam = divmod(candidate_index, num_beams)
        return inputs_log_prob[component][0, beam]

    table_creation_part_prompt = batch_text[0].split('Given the following database schema:')[-1].split('Answer the following')[0]
    in_table_keywords = word_tokenize( table_creation_part_prompt )

    selection_score_list = []
    for j, decoded_text in enumerate(decoded_text_list):
        # Log-sum-exp over every *other* candidate. `None` marks "no term yet":
        # a real term or partial sum can be exactly 0, so 0 cannot be the sentinel.
        selection_score = None
        for index, other_response in enumerate(decoded_text_list):
            if index == j:
                continue
            term = log_prob_of(index) + similarity_meassure( decoded_text, other_response, in_table_keywords )
            if selection_score is None:
                selection_score = term
            else:
                # Numerically stable log(exp(a) + exp(b)).
                alpha = max(term, selection_score)
                beta = min(term, selection_score)
                selection_score = alpha + torch.log1p(torch.exp(beta - alpha))
        if selection_score is None:
            # Single candidate: nobody to vote with, fall back to its own log-prob.
            selection_score = 0
        selection_score_list.append(selection_score + log_prob_of(j))

    # max over indices returns the first maximal candidate, like list.index(max(...)).
    max_index_score = max(range(len(selection_score_list)), key=selection_score_list.__getitem__)
    return decoded_text_list[max_index_score]

#To use the above algorithm we need the following things: decoded_text_list , inputs_log_prob , batch_text 


In [3]:
import pickle as pkl  # also imported by the loading cell; repeated so this cell does not silently depend on it

# Relies on outputs_seq / prompts / outputs_log_prob from the loading cell (In [1]), which must run first.
# Presumed layout per question (inferred from the original stride-4 selection -- confirm):
#   outputs_seq      : NUM_COMPONENTS * BEAMS_PER_COMPONENT strings, component-major
#   prompts          : NUM_COMPONENTS prompt strings
#   outputs_log_prob : NUM_COMPONENTS log-prob tensors, aligned with prompts
NUM_COMPONENTS = 5
BEAMS_PER_COMPONENT = 4
OUTPUTS_PER_QUESTION = NUM_COMPONENTS * BEAMS_PER_COMPONENT  # 20

NUMBER_OF_SAMPLES = len(outputs_seq) // OUTPUTS_PER_QUESTION
final_output_seqs = []
for sample_idx in range(NUMBER_OF_SAMPLES):
    output_seq_strt_idx = sample_idx * OUTPUTS_PER_QUESTION
    component_strt_idx = sample_idx * NUM_COMPONENTS
    # First output of every component block: offsets 0, 4, 8, 12, 16.
    decoded_text_list = outputs_seq[ output_seq_strt_idx : output_seq_strt_idx + OUTPUTS_PER_QUESTION : BEAMS_PER_COMPONENT ]
    # Keep only column 0 of each log-prob tensor, so ensemble() sees num_beams == 1.
    inputs_log_prob = [ log_prob[ : , 0:1 ] for log_prob in outputs_log_prob[ component_strt_idx : component_strt_idx + NUM_COMPONENTS ] ]
    batch_text = prompts[ component_strt_idx : component_strt_idx + NUM_COMPONENTS ]
    final_output_seqs.append( ensemble( decoded_text_list , inputs_log_prob , batch_text ) )

with open('./final_output_weighted_treebased_scoring.pkl' , 'wb') as f:
    pkl.dump(final_output_seqs , f)

In [1]:
#grouping the pieces of output_sequences
import pickle as pkl

def gather_threads_outputs(dir_name , post_fix, last_start=1000, chunk_size=50, run_tag='-4'):
    """Load and concatenate the pickled output chunks written by the generation threads.

    Chunk files are read, in this order, from
        ./<dir_name>/output_sequences_<start>_<start+chunk_size><run_tag><post_fix>.pkl
    for start = 0, chunk_size, 2*chunk_size, ... (< last_start), then the open-ended
        ./<dir_name>/output_sequences_<last_start>_end<run_tag><post_fix>.pkl
    The defaults reproduce the original names: 0_50, 50_100, ..., 950_1000, 1000_end, tag '-4'
    (the meaning of the '-4' tag is not visible here).

    Args:
        dir_name: directory holding the chunks, relative to the working directory (e.g. 'outputs1').
        post_fix: '' for decoded sequences, '_batch_text' for prompts,
            '_inputs_log_prob' for log-probabilities.
        last_start, chunk_size, run_tag: optional overrides of the file-naming scheme.

    Returns:
        One list with the contents of every chunk (each chunk must be iterable), in chunk order.

    Note: pickle.load can execute arbitrary code -- only point this at files you produced.
    """
    prefix = './' + dir_name + '/output_sequences_'
    chunk_names = [f'{start}_{start + chunk_size}' for start in range(0, last_start, chunk_size)]
    chunk_names.append(f'{last_start}_end')

    output_sequences = []
    for chunk_name in chunk_names:
        file_name = prefix + chunk_name + run_tag + post_fix + '.pkl'
        with open(file_name , 'rb') as f:
            output_sequences.extend(pkl.load(f))
    return output_sequences

output_dir = 'outputs1'
outputs_seq = gather_threads_outputs(output_dir , '') # 20 answers per question: 0-19 for question 1, 20-39 for question 2, etc.
prompts = gather_threads_outputs(output_dir , '_batch_text') # 5 prompts per question, all in one list: 0-4 for question 1, 5-9 for question 2, etc.
outputs_log_prob = gather_threads_outputs(output_dir , '_inputs_log_prob') # 5 log-probability tensors per question, aligned with `prompts`: 0-4 for question 1, 5-9 for question 2, etc.
# The ensemble cell slices prompts and log-probs with the same indices, so they must line up.
assert len(prompts) == len(outputs_log_prob), 'prompts and log-prob chunks are misaligned'


  return torch.load(io.BytesIO(b))
