In [None]:
import re
import zhon
import json
import math
import time
import torch
import string
import datetime
import numpy as np
import pandas as pd
import collections
import Levenshtein
import scipy.stats as ss
import torch.nn.functional as F
import matplotlib.pyplot as plt

from LAC import LAC
from matplotlib.pyplot import figure
from zhon.hanzi import punctuation as CHN_punctuation

# Test set B (JSON lines) and the precomputed per-sample attribution scores
# (LIME / SHAP / IG), each stored as a pickled dict inside a .npy file.
TEST_DATA_PATH = "./data/sim_interpretation_B.txt"
LIME_SCORE_PATH = './lime_scores/output.npy'
SHAP_SCORE_PATH = './shap_scores/output.npy'
IG_SCORE_PATH = './ig_scores/output.npy'

# Test-set-B logits per model; each model has a "without_EF" and a "with_EF" variant.
MACBERT_LOGITS_PATH = "./macbert_large_without_EF/predictionB/logits_B_test.npy"
MACBERT_EF_LOGITS_PATH = "./macbert_large_with_EF/predictionB/logits_B_test.npy"

ROBERTA_LOGITS_PATH = "./roberta_large_without_EF/predictionB/logits_B_test.npy"
ROBERTA_EF_LOGITS_PATH = "./roberta_large_with_EF/predictionB/logits_B_test.npy"

BERT_CHN_LOGITS_PATH = "./bert_base_chinese_without_EF/predictionB/logits_B_test.npy"
BERT_CHN_EF_LOGITS_PATH = "./bert_base_chinese_with_EF/predictionB/logits_B_test.npy"

BERT_MULLING_LOGITS_PATH = "./bert_base_multilingual_without_EF/predictionB/logits_B_test.npy"
BERT_MULLING_EF_LOGITS_PATH = "./bert_base_multilingual_with_EF/predictionB/logits_B_test.npy"

In [None]:
# Tokens stripped when comparing sentences: ASCII + Chinese punctuation plus the
# Baidu stopword list (one entry per line).
ENG_punctuation = string.punctuation
P_LIST = list(ENG_punctuation) + list(CHN_punctuation)
# Context manager closes the file; the encoding is explicit rather than the
# platform default (assumes the stopword file is UTF-8 — confirm if it changes).
with open("./baidu_stopwords.txt", "r", encoding="utf-8") as stopwordsFile:
    baidu_stopwords = stopwordsFile.read()
STOPWORDS = baidu_stopwords.split('\n')
RM_TOKENS = STOPWORDS + P_LIST
lac = LAC(mode='rank')

In [None]:
# Run on the GPU whenever torch can see one, otherwise on the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [None]:
# function for calculating accuracy
def flat_accuracy(preds, labels):
    '''
    Fraction of rows whose arg-max over axis 1 of ``preds`` equals the
    corresponding entry of ``labels`` (both flattened before comparing).
    '''
    predicted = np.argmax(preds, axis=1).flatten()
    truth = labels.flatten()
    n_correct = np.count_nonzero(predicted == truth)
    return n_correct / len(truth)

# function for timing
def format_time(elapsed):
    '''
    Round a duration in seconds to the nearest whole second and render it
    with str(datetime.timedelta), e.g. 3661 -> '1:01:01'.
    '''
    whole_seconds = int(round(elapsed))
    as_delta = datetime.timedelta(seconds=whole_seconds)
    return str(as_delta)

def Levenshtein_similarity(string1, string2):
    '''
    Similarity of two strings as computed by ``Levenshtein.ratio``.

    output: scalar returned by Levenshtein.ratio(string1, string2)
    '''
    return Levenshtein.ratio(string1, string2)

### Test data

In [None]:
# Load test set B (one JSON object per line) into a DataFrame; 'sentence'
# joins query and title with a "[SEP]" marker.
LCQMC_testB_dic = {'id':[], 'query':[], 'title':[], 'text_q_seg':[],'text_t_seg':[] }
# JSON text is UTF-8 (RFC 8259), so state the encoding instead of relying on the
# platform default. Blank lines (e.g. a doubled trailing newline) are skipped
# rather than crashing json.loads.
with open(TEST_DATA_PATH, 'r', encoding='utf-8') as f:
    for line in f:
        if not line.strip():
            continue
        line_dic = json.loads(line)
        for k in line_dic.keys():
            LCQMC_testB_dic[k].append(line_dic[k])

LCQMC_testB = pd.DataFrame.from_dict(LCQMC_testB_dic)
LCQMC_testB['sentence'] = LCQMC_testB['query'] +"[SEP]" + LCQMC_testB['title']
LCQMC_testB

### Bayes Iteration Prediction

In [None]:
# Gold data for test set B (provides 'sent_label' used for accuracy), one JSON
# object per line. Read as UTF-8 explicitly; blank lines are skipped.
GOLD_DATA_PATH  = './GOLD_DATA_PATH/Data.txt'
with open(GOLD_DATA_PATH, 'r', encoding='utf-8') as file:
    lines = file.readlines()
data_B = pd.DataFrame([json.loads(line) for line in lines if line.strip()])

In [None]:
def softmax(x):
    '''
    Softmax along axis 0 of ``x``.

    For a 1-D vector this is the ordinary softmax; for a 2-D array every column
    is normalised independently (callers pass ``logits.transpose()`` so that
    each column is one sample).

    The per-column maximum is subtracted before exponentiating. The result is
    mathematically unchanged, but large logits no longer overflow to inf, which
    would turn the output into nan.
    '''
    x = np.asarray(x)
    if x.size == 0:
        # np.max cannot reduce an empty axis; an empty input gives an empty output.
        return np.exp(x)
    shifted = np.exp(x - np.max(x, axis=0))
    return shifted / np.sum(shifted, axis=0)

def refine_sen(sen):
    '''
    Return ``sen`` with every character that appears in RM_TOKENS
    (stopwords + punctuation) removed.

    NOTE: a later cell redefines ``refine_sen`` without any filtering; callers
    see whichever definition was executed last.
    '''
    kept = (ch for ch in sen if ch not in RM_TOKENS)
    return ''.join(kept)

def update_prob(p, pie):
    '''
    One Bayes update of probability ``p`` with prior ``pie``:
    p*pie / (p*pie + (1-p)*(1-pie)).
    '''
    agree = p * pie
    disagree = (1 - p) * (1 - pie)
    return agree / (agree + disagree)

In [None]:
def Bayes_iteration_prediction(logits, infomation, num_iter, print_log=True, rm_tokens=None):
    '''
    Use the Bayes iteration prediction algorithm to obtain predicted labels.

    Every pair starts from the model's positive-class probability (softmax of
    ``logits``). A prior is built per sentence: the mean iteration-0 probability
    over all pairs in which the sentence appears (as query or title), after
    stripping ``rm_tokens`` characters. On each iteration a pair's probability is
    updated with the larger of its two sentences' priors via
    p*pie / (p*pie + (1-p)*(1-pie)). The priors are not recomputed between iterations.

    logits: array of shape (n_pairs, 2), binary classification logits
    infomation: DataFrame row-aligned with ``logits``, with columns
        'text_q', 'text_t' and 'sent_label'; 'sent_label' is only used to
        report accuracy
    num_iter: number of iterations; 0 returns the plain arg-max prediction
    print_log: whether to print the accuracy after every iteration
    rm_tokens: characters stripped before grouping sentences; defaults to the
        notebook-level RM_TOKENS

    returns: dict with 'final_prediction' (list of 0/1 ints) and 'acc_records'
        (accuracy at iterations 0..num_iter)
    '''
    logits = np.asarray(logits)
    drop = set(RM_TOKENS if rm_tokens is None else rm_tokens)

    def _strip(sentence):
        # Local copy of the stopword/punctuation stripping, so a later
        # redefinition of the global ``refine_sen`` cannot change this algorithm.
        return ''.join(ch for ch in sentence if ch not in drop)

    # Labels come from the argument, not from a global DataFrame.
    labels = np.asarray(infomation['sent_label'])

    def _accuracy(pred):
        return float(np.mean(labels == np.asarray(pred)))

    # Row-wise softmax, shifted by the row maximum for numerical stability.
    shifted = np.exp(logits - logits.max(axis=1, keepdims=True))
    prob = (shifted / shifted.sum(axis=1, keepdims=True))[:, 1]
    pred0 = np.argmax(logits, axis=1)

    acc_records = [_accuracy(pred0)]
    if print_log:
        print('pred_iter0', acc_records[0])

    queries = [_strip(s) for s in list(infomation['text_q'])]
    titles = [_strip(s) for s in list(infomation['text_t'])]

    # Per-sentence prior = mean iteration-0 probability of the pairs containing it.
    prob_sum = collections.defaultdict(float)
    count = collections.Counter()
    for q, t, p in zip(queries, titles, prob):
        for sent in (q, t):
            prob_sum[sent] += p
            count[sent] += 1
    prior = {sent: prob_sum[sent] / count[sent] for sent in count}
    pies = [max(prior[q], prior[t]) for q, t in zip(queries, titles)]

    final_prediction = [int(v) for v in pred0]
    current = list(prob)
    for b in range(num_iter):
        current = [p * pie / (p * pie + (1 - p) * (1 - pie)) for p, pie in zip(current, pies)]
        final_prediction = [int(p > 0.5) for p in current]
        acc_records.append(_accuracy(final_prediction))
        if print_log:
            print('pred_iter' + str(b + 1), acc_records[-1])

    return {'final_prediction': final_prediction, 'acc_records': acc_records}

### Multiple models

In [None]:
macbert_large_logits = np.load(MACBERT_LOGITS_PATH, allow_pickle=True)
result = Bayes_iteration_prediction(macbert_large_logits, data_B, 14, print_log=True)
final_prediction_macbert, acc_records = result['final_prediction'], np.array(result['acc_records'])*100
np.save('./macbert_large_without_EF/predictionB/bayes_prediction_label.npy', np.array(final_prediction_macbert))

macbert_large_with_EF_logits = np.load(MACBERT_EF_LOGITS_PATH, allow_pickle=True)
result_EF = Bayes_iteration_prediction(macbert_large_with_EF_logits, data_B, 14, print_log=True)
final_prediction_macbert_EF, acc_records_EF = result_EF['final_prediction'], np.array(result_EF['acc_records'])*100
np.save('./macbert_large_with_EF/predictionB/bayes_prediction_label.npy', np.array(final_prediction_macbert_EF))

# plot accuracy (%) against Bayes iteration for both variants.
# Guide lines use plt.hlines, whose xmin/xmax are data coordinates; axhline's
# xmin/xmax are axes fractions in [0, 1], so an iteration index there would
# draw a full-width line.
plt.figure(figsize=(5, 3.5),dpi=100)
for records, color, label in ((acc_records, 'orange', 'MacBERT-BIP'),
                              (acc_records_EF, 'firebrick', 'MacBERT-EF-BIP')):
    last = len(records) - 1
    # one light guide line per iteration, from x=0 to that iteration's point
    plt.hlines(records, 0, np.arange(len(records)), colors='lightgrey', linestyles='--')
    # iteration-0 accuracy across the whole run, final accuracy in darker grey
    plt.hlines(records[0], 0, last, colors='lightgrey', linestyles='--')
    plt.hlines(records[-1], 0, last, colors='grey', linestyles='--')
    plt.plot(records, '-^', color=color, label=label)
plt.xlim((-1.5,15))
plt.xlabel('Bayes iteration',fontsize=15)
plt.ylabel('Accuracy',fontsize=15)
plt.legend(loc = 'lower center')
plt.show()

In [None]:
roberta_large_logits = np.load(ROBERTA_LOGITS_PATH, allow_pickle=True)
result = Bayes_iteration_prediction(roberta_large_logits, data_B, 14)
final_prediction_roberta, acc_records = result['final_prediction'], np.array(result['acc_records'])*100
np.save('./roberta_large_without_EF/predictionB/bayes_prediction_label.npy', np.array(final_prediction_roberta))

roberta_large_EF_logits = np.load(ROBERTA_EF_LOGITS_PATH, allow_pickle=True)
result_EF = Bayes_iteration_prediction(roberta_large_EF_logits, data_B, 14)
final_prediction_roberta_EF, acc_records_EF = result_EF['final_prediction'], np.array(result_EF['acc_records'])*100
np.save('./roberta_large_with_EF/predictionB/bayes_prediction_label.npy', np.array(final_prediction_roberta_EF))

# plot accuracy (%) against Bayes iteration for both variants.
# Guide lines use plt.hlines, whose xmin/xmax are data coordinates; axhline's
# xmin/xmax are axes fractions in [0, 1], so an iteration index there would
# draw a full-width line.
plt.figure(figsize=(5, 3.5),dpi=100)
for records, color, label in ((acc_records, 'orange', 'RoBERTa-BIP'),
                              (acc_records_EF, 'firebrick', 'RoBERTa-EF-BIP')):
    last = len(records) - 1
    # one light guide line per iteration, from x=0 to that iteration's point
    plt.hlines(records, 0, np.arange(len(records)), colors='lightgrey', linestyles='--')
    # iteration-0 accuracy across the whole run, final accuracy in darker grey
    plt.hlines(records[0], 0, last, colors='lightgrey', linestyles='--')
    plt.hlines(records[-1], 0, last, colors='grey', linestyles='--')
    plt.plot(records, '-^', color=color, label=label)

plt.xlim((-1.5,15))
plt.ylim((86, 89))
plt.xlabel('Bayes iteration',fontsize=15)
plt.ylabel('Accuracy',fontsize=15)
plt.legend(loc = 'lower center')
plt.show()

In [None]:
bert_base_chinese_logits = np.load(BERT_CHN_LOGITS_PATH, allow_pickle=True)
result = Bayes_iteration_prediction(bert_base_chinese_logits, data_B, 14)
final_prediction_bert_base_chinese, acc_records = result['final_prediction'], np.array(result['acc_records'])*100
np.save('./bert_base_chinese_without_EF/predictionB/bayes_prediction_label.npy', np.array(final_prediction_bert_base_chinese))

bert_base_chinese_EF_logits = np.load(BERT_CHN_EF_LOGITS_PATH, allow_pickle=True)
result_EF = Bayes_iteration_prediction(bert_base_chinese_EF_logits, data_B, 14)
final_prediction_bert_base_chinese_EF, acc_records_EF = result_EF['final_prediction'], np.array(result_EF['acc_records'])*100
np.save('./bert_base_chinese_with_EF/predictionB/bayes_prediction_label.npy', np.array(final_prediction_bert_base_chinese_EF))

# plot accuracy (%) against Bayes iteration for both variants.
# Guide lines use plt.hlines, whose xmin/xmax are data coordinates; axhline's
# xmin/xmax are axes fractions in [0, 1], so an iteration index there would
# draw a full-width line.
plt.figure(figsize=(5, 3.5),dpi=100)
for records, color, label in ((acc_records, 'orange', 'bert-chinese-BIP'),
                              (acc_records_EF, 'firebrick', 'bert-chinese-EF-BIP')):
    last = len(records) - 1
    # one light guide line per iteration, from x=0 to that iteration's point
    plt.hlines(records, 0, np.arange(len(records)), colors='lightgrey', linestyles='--')
    # iteration-0 accuracy across the whole run, final accuracy in darker grey
    plt.hlines(records[0], 0, last, colors='lightgrey', linestyles='--')
    plt.hlines(records[-1], 0, last, colors='grey', linestyles='--')
    plt.plot(records, '-^', color=color, label=label)

plt.xlim((-1.5,15))
plt.xlabel('Bayes iteration',fontsize=15)
plt.ylabel('Accuracy',fontsize=15)
plt.legend(loc = 'lower center')
plt.show()

In [None]:
bert_base_multilang_logits = np.load(BERT_MULLING_LOGITS_PATH, allow_pickle=True)
result = Bayes_iteration_prediction(bert_base_multilang_logits, data_B, 14)
final_prediction_bert_base_multilang, acc_records = result['final_prediction'], np.array(result['acc_records'])*100
np.save('./bert_base_multilingual_without_EF/predictionB/bayes_prediction_label.npy', np.array(final_prediction_bert_base_multilang))

bert_base_multilang_EF_logits = np.load(BERT_MULLING_EF_LOGITS_PATH, allow_pickle=True)
result_EF = Bayes_iteration_prediction(bert_base_multilang_EF_logits, data_B, 14)
final_prediction_bert_base_multilang_EF, acc_records_EF = result_EF['final_prediction'], np.array(result_EF['acc_records'])*100
np.save('./bert_base_multilingual_with_EF/predictionB/bayes_prediction_label.npy', np.array(final_prediction_bert_base_multilang_EF))

# plot accuracy (%) against Bayes iteration for both variants.
# Guide lines use plt.hlines, whose xmin/xmax are data coordinates; axhline's
# xmin/xmax are axes fractions in [0, 1], so an iteration index there would
# draw a full-width line.
plt.figure(figsize=(5, 3.5),dpi=100)
for records, color, label in ((acc_records, 'orange', 'bert-multilingual-BIP'),
                              (acc_records_EF, 'firebrick', 'bert-multilingual-EF-BIP')):
    last = len(records) - 1
    # one light guide line per iteration, from x=0 to that iteration's point
    plt.hlines(records, 0, np.arange(len(records)), colors='lightgrey', linestyles='--')
    # iteration-0 accuracy across the whole run, final accuracy in darker grey
    plt.hlines(records[0], 0, last, colors='lightgrey', linestyles='--')
    plt.hlines(records[-1], 0, last, colors='grey', linestyles='--')
    plt.plot(records, '-^', color=color, label=label)

plt.xlim((-1.5,15))
plt.xlabel('Bayes iteration',fontsize=15)
plt.ylabel('Accuracy',fontsize=15)
plt.legend(loc = 'lower center')
plt.show()

### Rationale

In [None]:
def reorder(score, token_index, mode='abs'):
    '''
    Sort (token index, score) pairs.

    input
        score: list of scores, one per entry of ``token_index``
        token_index: list of token positions (expected unique; if an index
            repeats, its last score wins, as in a dict)
        mode:
            'abs'       - by |score|, descending; returned scores are |score|
            'sequence'  - by score, descending
            'token_seq' - by token index, ascending (positional order)
        Ties keep their input order (Python's sort is stable).
    output
        dict with 'sorted_token_index' and 'sorted_score' lists
    raises
        ValueError if ``mode`` is not one of the options above
        (previously an empty dict was returned and callers hit a KeyError)
    '''
    assert len(score) == len(token_index)
    if mode == 'abs':
        paired = dict(zip(token_index, (abs(s) for s in score)))
        by_position, descending = 1, True
    elif mode == 'sequence':
        paired = dict(zip(token_index, score))
        by_position, descending = 1, True
    elif mode == 'token_seq':
        paired = dict(zip(token_index, score))
        by_position, descending = 0, False
    else:
        raise ValueError("mode must be 'abs', 'sequence' or 'token_seq', got %r" % (mode,))

    ordered = sorted(paired.items(), key=lambda item: item[by_position], reverse=descending)
    return {
        'sorted_token_index': [idx for idx, _ in ordered],
        'sorted_score': [val for _, val in ordered],
    }

#### LAC scores and LIME scores

In [None]:
def single_LAC_expand(LAC_result):
    '''
    Expand a word-level LAC result to one entry per character.

    input: single LAC_result: list of lists [LAC_token, LAC_ner, LAC_imp],
        one entry per word
    output: [chars, ner_per_char, imp_per_char, num] where ``num[k]`` is the
        number of characters of word k, and word k's ner/imp value is repeated
        ``num[k]`` times

    Uses comprehensions instead of ``sum(list_of_lists, [])``, which copies the
    accumulated list on every addition (quadratic in the sentence length).
    '''
    words, ner_tags, importance = LAC_result[0], LAC_result[1], LAC_result[2]
    number = [len(word) for word in words]
    expand_LAC_token = [ch for word in words for ch in word]
    expand_LAC_ner = [ner_tags[k] for k in range(len(ner_tags)) for _ in range(number[k])]
    expand_LAC_imp = [importance[k] for k in range(len(importance)) for _ in range(number[k])]

    return [expand_LAC_token, expand_LAC_ner, expand_LAC_imp, number]

In [None]:
# Assemble per-sample rationale features on top of the LIME output. Row i of
# every list below is built from row i of the LIME dict and row i of
# LCQMC_testB (assumed to be the same sample — TODO confirm they are row-aligned).
# Re-creates the LAC analyser in 'rank' mode (same settings as the setup cell).
lac = LAC(mode='rank')
# NOTE(review): allow_pickle=True unpickles arbitrary objects; only load trusted score files.
dual_ranking_rationale = np.load(LIME_SCORE_PATH, allow_pickle=True).item()
shap_scores = np.load(SHAP_SCORE_PATH, allow_pickle=True).item()
ig_scores = np.load(IG_SCORE_PATH, allow_pickle=True).item()

# New per-sample fields; most hold a [query, title] pair per sample.
dual_ranking_rationale['token'] = []   # dataset token segmentation
dual_ranking_rationale['query'] = []   # raw query string
dual_ranking_rationale['title'] = []   # raw title string
dual_ranking_rationale['ner'] = []     # LAC tag (index 1 of lac.run output) per aligned token
dual_ranking_rationale['lac'] = []     # LAC importance (index 2 of lac.run output) per aligned token
dual_ranking_rationale['num'] = []     # number of aligned tokens in each LAC word
dual_ranking_rationale['piece'] = []   # word-level LAC segmentation
dual_ranking_rationale['laclime'] = [] #LAC-word-wise mean of the token LIME scores
dual_ranking_rationale['laclime_rank'] = [] # scipy rankdata (ascending) of |laclime|
dual_ranking_rationale['ori_laclime_rank'] = [] # scipy rankdata (ascending) of signed laclime
dual_ranking_rationale['shap_score'] = shap_scores['rationale_score']
dual_ranking_rationale['ig_score'] = ig_scores['rationale_score']

for i in range(len(dual_ranking_rationale['id'])): 
    
    # Put the LIME rationale tokens and scores back into token-position order.
    q = reorder(dual_ranking_rationale['rationale_score'][i][0], dual_ranking_rationale['rationale'][i][0], mode='token_seq')
    t = reorder(dual_ranking_rationale['rationale_score'][i][1], dual_ranking_rationale['rationale'][i][1], mode='token_seq')
    
    dual_ranking_rationale['rationale'][i] = [q['sorted_token_index'], t['sorted_token_index']]
    dual_ranking_rationale['rationale_score'][i] = [q['sorted_score'], t['sorted_score']]    

    dual_ranking_rationale['token'].append([LCQMC_testB['text_q_seg'][i], LCQMC_testB['text_t_seg'][i]])
    dual_ranking_rationale['query'].append(LCQMC_testB['query'][i])
    dual_ranking_rationale['title'].append(LCQMC_testB['title'][i])
    # Word-level LAC result is kept as 'piece', then expanded to one entry per character.
    LAC_result_q, LAC_result_t = lac.run(LCQMC_testB['query'][i]), lac.run(LCQMC_testB['title'][i])
    dual_ranking_rationale['piece'].append([LAC_result_q[0],LAC_result_t[0]])
    LAC_result_q, LAC_result_t = single_LAC_expand(LAC_result_q) ,single_LAC_expand(LAC_result_t) 

    LAC_token_q, LAC_token_t = LAC_result_q[0], LAC_result_t[0]
    LAC_ner_q, LAC_ner_t = LAC_result_q[1], LAC_result_t[1]
    LAC_imp_q, LAC_imp_t = LAC_result_q[2], LAC_result_t[2]
    LAC_num_q, LAC_num_t = LAC_result_q[3], LAC_result_t[3]
    
    
    # Fast path: the dataset tokens already equal the LAC characters one-to-one.
    if LCQMC_testB['text_q_seg'][i] == LAC_token_q and LCQMC_testB['text_t_seg'][i] == LAC_token_t:
        dual_ranking_rationale['ner'].append([LAC_ner_q, LAC_ner_t])
        dual_ranking_rationale['lac'].append([LAC_imp_q, LAC_imp_t])
        dual_ranking_rationale['num'].append([LAC_num_q, LAC_num_t])
        #add LIME scores based on LAC segmentation (mean)
        # c_sum_* holds the LAC word boundaries; every token receives the mean
        # LIME score of the LAC word it belongs to.
        c_sum_q, c_sum_t = [0] + list(np.cumsum(LAC_num_q)), [0] + list(np.cumsum(LAC_num_t)) 
        laclime_q = sum([[np.mean(q['sorted_score'][c_sum_q[k]:c_sum_q[k+1]])]*LAC_num_q[k] for k in range(len(LAC_num_q))],[])
        laclime_t = sum([[np.mean(t['sorted_score'][c_sum_t[k]:c_sum_t[k+1]])]*LAC_num_t[k] for k in range(len(LAC_num_t))],[])
        dual_ranking_rationale['laclime'].append([laclime_q, laclime_t])
        dual_ranking_rationale['laclime_rank'].append([ list(ss.rankdata([abs(x) for x in laclime_q])), 
                                                  list(ss.rankdata([abs(x) for x in laclime_t])) ])
        dual_ranking_rationale['ori_laclime_rank'].append([ list(ss.rankdata([x for x in laclime_q])), 
                                                  list(ss.rankdata([x for x in laclime_t])) ])
        
    else:
        # Segmentations differ: walk the dataset tokens; wherever token j does not
        # match LAC entry j, merge LAC entry j+1 into entry j (first tag kept,
        # importance = int() of the mean of the two) and decrement the count of
        # the LAC word whose span covers position j.
        # NOTE(review): if the two segmentations never re-align, j+1 runs past the
        # end of the LAC lists and this raises IndexError.
        j = 0
        while j < len(LCQMC_testB['text_q_seg'][i]):
            if LCQMC_testB['text_q_seg'][i][j] == LAC_result_q[0][j]:
                j+=1
            else:
                LAC_result_q[0][j] = LAC_result_q[0][j]+LAC_result_q[0][j+1]
                del LAC_result_q[0][j+1]
                LAC_result_q[1][j] = LAC_result_q[1][j]  #use the first ner
                del LAC_result_q[1][j+1]
                LAC_result_q[2][j] = int((LAC_result_q[2][j]+LAC_result_q[2][j+1])/2)
                del LAC_result_q[2][j+1]
                # first word whose cumulative end lies beyond j, i.e. the word containing j
                need_reduce_idx = [j - k < 0 for k in np.cumsum(LAC_result_q[3])].index(True)
                LAC_result_q[3][need_reduce_idx] = LAC_result_q[3][need_reduce_idx]-1
            
        # Same alignment for the title.
        j = 0
        while j < len(LCQMC_testB['text_t_seg'][i]):
            if LCQMC_testB['text_t_seg'][i][j] == LAC_result_t[0][j]:
                j+=1
            else:
                LAC_result_t[0][j] = LAC_result_t[0][j]+LAC_result_t[0][j+1]
                del LAC_result_t[0][j+1]
                LAC_result_t[1][j] = LAC_result_t[1][j] #use the first ner
                del LAC_result_t[1][j+1]
                LAC_result_t[2][j] = int((LAC_result_t[2][j]+LAC_result_t[2][j+1])/2)
                del LAC_result_t[2][j+1]
                need_reduce_idx = [j - k < 0 for k in np.cumsum(LAC_result_t[3])].index(True)
                LAC_result_t[3][need_reduce_idx] = LAC_result_t[3][need_reduce_idx]-1
                
        LAC_token_q, LAC_token_t = LAC_result_q[0], LAC_result_t[0]
        LAC_ner_q, LAC_ner_t = LAC_result_q[1], LAC_result_t[1]
        LAC_imp_q, LAC_imp_t = LAC_result_q[2], LAC_result_t[2]
        LAC_num_q, LAC_num_t = LAC_result_q[3], LAC_result_t[3]
        dual_ranking_rationale['ner'].append([LAC_ner_q, LAC_ner_t])
        dual_ranking_rationale['lac'].append([LAC_imp_q, LAC_imp_t])
        dual_ranking_rationale['num'].append([LAC_num_q, LAC_num_t])

        #add LIME scores based on LAC segmentation (mean)
        # Same per-word averaging as the fast path, using the adjusted word counts.
        c_sum_q, c_sum_t = [0] + list(np.cumsum(LAC_num_q)), [0] + list(np.cumsum(LAC_num_t)) 
        laclime_q = sum([[np.mean(q['sorted_score'][c_sum_q[k]:c_sum_q[k+1]])]*LAC_num_q[k] for k in range(len(LAC_num_q))],[])
        laclime_t = sum([[np.mean(t['sorted_score'][c_sum_t[k]:c_sum_t[k+1]])]*LAC_num_t[k] for k in range(len(LAC_num_t))],[])
        dual_ranking_rationale['laclime'].append([laclime_q, laclime_t])
        dual_ranking_rationale['laclime_rank'].append([ list(ss.rankdata([abs(x) for x in laclime_q])), 
                                                  list(ss.rankdata([abs(x) for x in laclime_t])) ])
        dual_ranking_rationale['ori_laclime_rank'].append([ list(ss.rankdata([x for x in laclime_q])), 
                                                  list(ss.rankdata([x for x in laclime_t])) ])

### Lexical category scores

In [None]:
# Group the LAC tag of every token by that token's LAC importance level
# (3 -> ner_a, 2 -> ner_b, 1 -> ner_c, 0 -> ner_d), over queries and titles alike.
ner_by_level = {3: [], 2: [], 1: [], 0: []}
for i, sample_lac in enumerate(dual_ranking_rationale['lac']):
    sample_ner = dual_ranking_rationale['ner'][i]
    for side in (0, 1):
        for j, level in enumerate(sample_lac[side]):
            if level in ner_by_level:
                ner_by_level[level].append(sample_ner[side][j])
ner_a, ner_b, ner_c, ner_d = ner_by_level[3], ner_by_level[2], ner_by_level[1], ner_by_level[0]

# Each occurrence votes with weight (level + 1). The weighted/raw count ratio is
# the tag's mean (level + 1), rounded to the nearest 0.5.
aug_frequency = collections.Counter(ner_a * 4 + ner_b * 3 + ner_c * 2 + ner_d)
ori_frequency = collections.Counter(ner_a + ner_b + ner_c + ner_d)
fine_imp_dic = {tag: round(2 * aug_frequency[tag] / ori_frequency[tag]) / 2 for tag in ori_frequency}

print(dict(sorted(fine_imp_dic.items(), key=lambda item: -item[1])))
# Per-token lexical importance, laid out like 'ner' ([query, title] per sample).
dual_ranking_rationale['lexicality'] = [
    [[fine_imp_dic[tag] for tag in dual_ranking_rationale['ner'][i][0]],
     [fine_imp_dic[tag] for tag in dual_ranking_rationale['ner'][i][1]]]
    for i in range(len(dual_ranking_rationale['id']))
]

### Rationale (Dual ranking + Bi-criteria Denoising)

In [None]:
def refine_sen(sen):
    '''
    Join the characters of ``sen`` back together unchanged (no filtering).

    NOTE(review): this redefines, and therefore shadows, the earlier
    stopword-stripping ``refine_sen``; re-running an earlier cell that calls
    ``refine_sen`` after this one will use this no-filter version.
    '''
    return ''.join(sen)

def refine_sen2(sen):
    '''Return ``sen`` with every character that appears in RM_TOKENS removed.'''
    return ''.join(filter(lambda ch: ch not in RM_TOKENS, sen))

def dual_ranking_bi_criteria_denoising(dual_ranking_rationale, rationale_base_pred, rationale_method = 'dual_ranking', denoising = True, denoising_k =3):
    ''' 
    using dual ranking and Bi-criteria denoising algorithm
    
    input: 
        dual_ranking_rationale: pd.DataFrame that needs denoise, containing columns 
            ['id','query','title','piece','rationale', 'lac', 'lexicality','laclime_rank', 'rationale_score'(lime score),'shap_score','ig_score']
        rationale_base_pred:  model prediction, list of 0 and 1
        rationale_method: string, method of rationale ranking. 
            Options: {'dual_ranking', 'lexicality_only', 'lac_only', 'positive_lime', 'proportional_lime', 'positive_shap','proportional_shap', 'positive_ig','proportional_ig'}
        denoising: bool, whether to use denoising algorithm
        denoising_k: int, denoising parameter in Bi-criteria denoising algorithm
    output: 
        final_output: pd.DataFrame with columns ['id', 'label', 'rationale']
    '''
    dual_ranking_rationale['label'] =  list(rationale_base_pred)
    final_output = np.load(LIME_SCORE_PATH, allow_pickle=True).item()

    refined_sen_Q = {}
    for i in range(len(dual_ranking_rationale['id'])):
        refined_sen = refine_sen(dual_ranking_rationale['query'][i])
        if refined_sen not in refined_sen_Q:
            refined_sen_Q[refined_sen] = 1
        else:
            refined_sen_Q[refined_sen] += 1
        refined_sen = refine_sen(dual_ranking_rationale['title'][i])
        if refined_sen not in refined_sen_Q:
            refined_sen_Q[refined_sen] = 1
        else:
            refined_sen_Q[refined_sen] += 1

    refined_sen_Q2 = {}
    for i in range(len(dual_ranking_rationale['id'])):
        refined_sen = refine_sen2(dual_ranking_rationale['query'][i])
        if refined_sen not in refined_sen_Q2:
            refined_sen_Q2[refined_sen] = 1
        else:
            refined_sen_Q2[refined_sen] += 1
            
        refined_sen = refine_sen2(dual_ranking_rationale['title'][i])
        if refined_sen not in refined_sen_Q2:
            refined_sen_Q2[refined_sen] = 1
        else:
            refined_sen_Q2[refined_sen] += 1

    for i in range(len(dual_ranking_rationale['rationale'])): 
        text_q_seg = LCQMC_testB[LCQMC_testB['id'] == dual_ranking_rationale['id'][i]]['text_q_seg'].item()
        text_t_seg = LCQMC_testB[LCQMC_testB['id'] == dual_ranking_rationale['id'][i]]['text_t_seg'].item()
        common_seg = set(text_q_seg).intersection(set(text_t_seg))

        text_q  = LCQMC_testB[LCQMC_testB['id'] == dual_ranking_rationale['id'][i]]['query'].item()
        text_t  = LCQMC_testB[LCQMC_testB['id'] == dual_ranking_rationale['id'][i]]['title'].item()
        if refined_sen_Q[refine_sen(text_q)] > refined_sen_Q[refine_sen(text_t)]:
            criteria  = 'ODB_q'
        elif refined_sen_Q[refine_sen(text_q)] < refined_sen_Q[refine_sen(text_t)]:
            criteria  = 'ODB_t'
        else:
            if refined_sen_Q2[refine_sen2(text_q)] > refined_sen_Q2[refine_sen2(text_t)]:
                criteria = 'RDB_q'
            elif refined_sen_Q2[refine_sen2(text_q)] < refined_sen_Q2[refine_sen2(text_t)]:
                criteria = 'RDB_t'
            else:
                criteria = 'NONE'
        
        
        q_seg_from_piece = list(''.join([p for p in dual_ranking_rationale['piece'][i][0] if p not in STOPWORDS+P_LIST]))
        remove_q_seg_from_piece = list(''.join([p for p in dual_ranking_rationale['piece'][i][0] if p in STOPWORDS+P_LIST]))
        t_seg_from_piece = list(''.join([p for p in dual_ranking_rationale['piece'][i][1] if p not in STOPWORDS+P_LIST]))
        remove_t_seg_from_piece = list(''.join([p for p in dual_ranking_rationale['piece'][i][1] if p in STOPWORDS+P_LIST]))
        critical_uncommon_seg = set(q_seg_from_piece).union(set(t_seg_from_piece)) - set(q_seg_from_piece).intersection(set(t_seg_from_piece)) - set(remove_q_seg_from_piece).union(set(remove_t_seg_from_piece))-set(STOPWORDS)
        
        
        if dual_ranking_rationale['label'][i] == 1:
            if rationale_method == 'positive_lime':
                rationale = [k for k in range(len(text_q_seg)) if dual_ranking_rationale['rationale_score'][i][0][k]>0]
                LAC_imp = [dual_ranking_rationale['rationale_score'][i][0][k] for k in rationale]                
                final_output['rationale'][i][0]  = reorder(LAC_imp, rationale, mode='sequence')['sorted_token_index']
                
                rationale = [k for k in range(len(text_t_seg)) if dual_ranking_rationale['rationale_score'][i][1][k]>0]
                LAC_imp = [dual_ranking_rationale['rationale_score'][i][1][k] for k in rationale]
                final_output['rationale'][i][1]  = reorder(LAC_imp, rationale, mode='sequence')['sorted_token_index']
            
            elif rationale_method == 'proportional_lime':
                rationale = [k for k in range(len(text_q_seg))]
                LAC_imp = [dual_ranking_rationale['rationale_score'][i][0][k] for k in rationale]                
                final_output['rationale'][i][0]  = reorder(LAC_imp, rationale, mode='sequence')['sorted_token_index'][:math.ceil(len(text_q_seg)*0.705)]
                
                rationale = [k for k in range(len(text_t_seg))]
                LAC_imp = [dual_ranking_rationale['rationale_score'][i][1][k] for k in rationale]
                final_output['rationale'][i][1]  = reorder(LAC_imp, rationale, mode='sequence')['sorted_token_index'][:math.ceil(len(text_t_seg)*0.705)]

            elif rationale_method ==  'positive_shap':
                rationale = [k for k in range(len(text_q_seg)) if dual_ranking_rationale['shap_score'][i][0][k]>0]
                LAC_imp = [dual_ranking_rationale['shap_score'][i][0][k] for k in rationale]                
                final_output['rationale'][i][0]  = reorder(LAC_imp, rationale, mode='sequence')['sorted_token_index']
                
                rationale = [k for k in range(len(text_t_seg)) if dual_ranking_rationale['shap_score'][i][1][k]>0]
                LAC_imp = [dual_ranking_rationale['shap_score'][i][1][k] for k in rationale]
                final_output['rationale'][i][1]  = reorder(LAC_imp, rationale, mode='sequence')['sorted_token_index']
            
            elif rationale_method ==  'proportional_shap':
                rationale = [k for k in range(len(text_q_seg))]
                LAC_imp = [dual_ranking_rationale['shap_score'][i][0][k] for k in rationale]                
                final_output['rationale'][i][0]  = reorder(LAC_imp, rationale, mode='sequence')['sorted_token_index'][:math.ceil(len(text_t_seg)*0.705)]
                
                rationale = [k for k in range(len(text_t_seg))]
                LAC_imp = [dual_ranking_rationale['shap_score'][i][1][k] for k in rationale]
                final_output['rationale'][i][1]  = reorder(LAC_imp, rationale, mode='sequence')['sorted_token_index'][:math.ceil(len(text_t_seg)*0.705)]


            elif rationale_method == 'positive_ig':
                rationale = [k for k in range(len(text_q_seg)) if dual_ranking_rationale['ig_score'][i][0][k]>0]
                LAC_imp = [dual_ranking_rationale['ig_score'][i][0][k] for k in rationale]                
                final_output['rationale'][i][0]  = reorder(LAC_imp, rationale, mode='sequence')['sorted_token_index']
                
                rationale = [k for k in range(len(text_t_seg)) if dual_ranking_rationale['ig_score'][i][1][k]>0]
                LAC_imp = [dual_ranking_rationale['ig_score'][i][1][k] for k in rationale]
                final_output['rationale'][i][1]  = reorder(LAC_imp, rationale, mode='sequence')['sorted_token_index']

            elif rationale_method == 'proportional_ig':
                rationale = [k for k in range(len(text_q_seg))]
                LAC_imp = [dual_ranking_rationale['ig_score'][i][0][k] for k in rationale]                
                final_output['rationale'][i][0]  = reorder(LAC_imp, rationale, mode='sequence')['sorted_token_index'][:math.ceil(len(text_t_seg)*0.705)]
                
                rationale = [k for k in range(len(text_t_seg))]
                LAC_imp = [dual_ranking_rationale['ig_score'][i][1][k] for k in rationale]
                final_output['rationale'][i][1]  = reorder(LAC_imp, rationale, mode='sequence')['sorted_token_index'][:math.ceil(len(text_t_seg)*0.705)]


            elif rationale_method == 'lexicality_only':
                rationale = [k for k in range(len(text_q_seg)) if dual_ranking_rationale['lexicality'][i][0][k]>1]
                LAC_imp = [dual_ranking_rationale['lexicality'][i][0][k] for k in rationale]
                final_output['rationale'][i][0]  = reorder(LAC_imp, rationale, mode='sequence')['sorted_token_index']

                rationale = [k for k in range(len(text_t_seg)) if dual_ranking_rationale['lexicality'][i][1][k]>1]
                LAC_imp = [dual_ranking_rationale['lexicality'][i][1][k] for k in rationale]
                final_output['rationale'][i][1]  = reorder(LAC_imp, rationale, mode='sequence')['sorted_token_index']

            elif rationale_method == 'lac_only':
                rationale = [k for k in range(len(text_q_seg)) if dual_ranking_rationale['lac'][i][0][k]>0]
                LAC_imp = [dual_ranking_rationale['lexicality'][i][0][k] for k in rationale]
                final_output['rationale'][i][0]  = reorder(LAC_imp, rationale, mode='sequence')['sorted_token_index']

                rationale = [k for k in range(len(text_t_seg)) if dual_ranking_rationale['lac'][i][1][k]>0]
                LAC_imp = [dual_ranking_rationale['lexicality'][i][1][k] for k in rationale]
                final_output['rationale'][i][1]  = reorder(LAC_imp, rationale, mode='sequence')['sorted_token_index']

            elif rationale_method == 'dual_ranking':
                rationale = [k for k in range(len(text_q_seg)) if dual_ranking_rationale['lac'][i][0][k]>0 and text_q_seg[k] not in P_LIST  and (text_q_seg[k] in common_seg or dual_ranking_rationale['lac'][i][0][k]>=2)]
                max_LAC_imp = max([dual_ranking_rationale['lexicality'][i][0][k] for k in rationale])
                LAC_imp = [dual_ranking_rationale['lexicality'][i][0][k]+ min(max_LAC_imp-dual_ranking_rationale['lexicality'][i][0][k], dual_ranking_rationale['lac'][i][0][k]/100) for k in rationale]
                final_output['rationale'][i][0]  = reorder(LAC_imp, rationale, mode='sequence')['sorted_token_index']

                rationale = [k for k in range(len(text_t_seg)) if dual_ranking_rationale['lac'][i][1][k] >0 and text_t_seg[k] not in P_LIST and (text_t_seg[k] in common_seg or dual_ranking_rationale['lac'][i][1][k]>=2)]
                max_LAC_imp = max([dual_ranking_rationale['lexicality'][i][1][k] for k in rationale])
                LAC_imp = [dual_ranking_rationale['lexicality'][i][1][k]+ min(max_LAC_imp-dual_ranking_rationale['lexicality'][i][1][k], dual_ranking_rationale['lac'][i][1][k]/100) for k in rationale]
                final_output['rationale'][i][1]  = reorder(LAC_imp, rationale, mode='sequence')['sorted_token_index']
                            
            else:
                pass

            #denoising for alignment（label 1)
            if denoising and (criteria == 'ODB_q' or criteria == 'RDB_q'):
                new_rationale = []
                for qk in final_output['rationale'][i][0]:
                    for tk in final_output['rationale'][i][1]:
                        if text_t_seg[tk] == text_q_seg[qk]:
                            new_rationale.append(tk)
                            final_output['rationale'][i][1].remove(tk)
                            break
                final_output['rationale'][i][1] = new_rationale
            
            if denoising and (criteria == 'ODB_t' or criteria == 'RDB_t'):
                new_rationale = []
                for tk in final_output['rationale'][i][1]:
                    for qk in final_output['rationale'][i][0]:
                        if text_q_seg[qk] == text_t_seg[tk]:
                            new_rationale.append(qk)
                            final_output['rationale'][i][0].remove(qk)
                            break
                final_output['rationale'][i][0] = new_rationale


        if dual_ranking_rationale['label'][i] == 0:
            if rationale_method == 'positive_lime':
                rationale = [k for k in range(len(text_q_seg)) if dual_ranking_rationale['rationale_score'][i][0][k]>0]
                LAC_imp = [dual_ranking_rationale['rationale_score'][i][0][k] for k in rationale]
                final_output['rationale'][i][0]  = reorder(LAC_imp, rationale, mode='sequence')['sorted_token_index']
                
                rationale = [k for k in range(len(text_t_seg)) if dual_ranking_rationale['rationale_score'][i][1][k]>0]
                LAC_imp = [dual_ranking_rationale['rationale_score'][i][1][k] for k in rationale]
                final_output['rationale'][i][1]  = reorder(LAC_imp, rationale, mode='sequence')['sorted_token_index']

            elif rationale_method == 'proportional_lime':
                rationale = [k for k in range(len(text_q_seg))]
                LAC_imp = [dual_ranking_rationale['rationale_score'][i][0][k] for k in rationale]                
                final_output['rationale'][i][0]  = reorder(LAC_imp, rationale, mode='sequence')['sorted_token_index'][:math.ceil(len(text_q_seg)*0.705)]
                
                rationale = [k for k in range(len(text_t_seg))]
                LAC_imp = [dual_ranking_rationale['rationale_score'][i][1][k] for k in rationale]
                final_output['rationale'][i][1]  = reorder(LAC_imp, rationale, mode='sequence')['sorted_token_index'][:math.ceil(len(text_t_seg)*0.705)]

            elif rationale_method == 'positive_shap':
                rationale = [k for k in range(len(text_q_seg)) if dual_ranking_rationale['shap_score'][i][0][k]>0]
                LAC_imp = [dual_ranking_rationale['shap_score'][i][0][k] for k in rationale]
                final_output['rationale'][i][0]  = reorder(LAC_imp, rationale, mode='sequence')['sorted_token_index']
                
                rationale = [k for k in range(len(text_t_seg)) if dual_ranking_rationale['shap_score'][i][1][k]>0]
                LAC_imp = [dual_ranking_rationale['shap_score'][i][1][k] for k in rationale]
                final_output['rationale'][i][1]  = reorder(LAC_imp, rationale, mode='sequence')['sorted_token_index']

            elif rationale_method ==  'proportional_shap':
                rationale = [k for k in range(len(text_q_seg))]
                LAC_imp = [dual_ranking_rationale['shap_score'][i][0][k] for k in rationale]                
                final_output['rationale'][i][0]  = reorder(LAC_imp, rationale, mode='sequence')['sorted_token_index'][:math.ceil(len(text_t_seg)*0.705)]
                
                rationale = [k for k in range(len(text_t_seg))]
                LAC_imp = [dual_ranking_rationale['shap_score'][i][1][k] for k in rationale]
                final_output['rationale'][i][1]  = reorder(LAC_imp, rationale, mode='sequence')['sorted_token_index'][:math.ceil(len(text_t_seg)*0.705)]

            elif rationale_method == 'positive_ig':
                rationale = [k for k in range(len(text_q_seg)) if dual_ranking_rationale['ig_score'][i][0][k]>0]
                LAC_imp = [dual_ranking_rationale['ig_score'][i][0][k] for k in rationale]                
                final_output['rationale'][i][0]  = reorder(LAC_imp, rationale, mode='sequence')['sorted_token_index']
                
                rationale = [k for k in range(len(text_t_seg)) if dual_ranking_rationale['ig_score'][i][1][k]>0]
                LAC_imp = [dual_ranking_rationale['ig_score'][i][1][k] for k in rationale]
                final_output['rationale'][i][1]  = reorder(LAC_imp, rationale, mode='sequence')['sorted_token_index']

            elif rationale_method == 'proportional_ig':
                rationale = [k for k in range(len(text_q_seg))]
                LAC_imp = [dual_ranking_rationale['ig_score'][i][0][k] for k in rationale]                
                final_output['rationale'][i][0]  = reorder(LAC_imp, rationale, mode='sequence')['sorted_token_index'][:math.ceil(len(text_t_seg)*0.705)]
                
                rationale = [k for k in range(len(text_t_seg))]
                LAC_imp = [dual_ranking_rationale['ig_score'][i][1][k] for k in rationale]
                final_output['rationale'][i][1]  = reorder(LAC_imp, rationale, mode='sequence')['sorted_token_index'][:math.ceil(len(text_t_seg)*0.705)]

            elif rationale_method == 'lexicality_only':
                rationale = [k for k in range(len(text_q_seg)) if dual_ranking_rationale['lexicality'][i][0][k]>1]
                LAC_imp = [dual_ranking_rationale['lexicality'][i][0][k] for k in rationale]
                final_output['rationale'][i][0]  = reorder(LAC_imp, rationale, mode='sequence')['sorted_token_index']
                
                rationale = [k for k in range(len(text_t_seg)) if dual_ranking_rationale['lexicality'][i][1][k]>1]
                LAC_imp = [dual_ranking_rationale['lexicality'][i][1][k] for k in rationale]
                final_output['rationale'][i][1]  = reorder(LAC_imp, rationale, mode='sequence')['sorted_token_index']
           
            elif rationale_method == 'lac_only':
                rationale = [k for k in range(len(text_q_seg)) if dual_ranking_rationale['lac'][i][0][k]>0]
                LAC_imp = [dual_ranking_rationale['lexicality'][i][0][k] for k in rationale]
                final_output['rationale'][i][0]  = reorder(LAC_imp, rationale, mode='sequence')['sorted_token_index']
                
                rationale = [k for k in range(len(text_t_seg)) if dual_ranking_rationale['lac'][i][1][k]>0]
                LAC_imp = [dual_ranking_rationale['lexicality'][i][1][k] for k in rationale]
                final_output['rationale'][i][1]  = reorder(LAC_imp, rationale, mode='sequence')['sorted_token_index']

            elif rationale_method == 'dual_ranking':
                rationale = [k for k in range(len(text_q_seg)) if dual_ranking_rationale['lac'][i][0][k]>0 and text_q_seg[k] not in P_LIST]
                max_LAC_imp = max([dual_ranking_rationale['lexicality'][i][0][k] for k in rationale])
                LAC_imp = [dual_ranking_rationale['lexicality'][i][0][k]+min(max_LAC_imp-dual_ranking_rationale['lexicality'][i][0][k], dual_ranking_rationale['laclime_rank'][i][0][k]/100 ) if text_q_seg[k] in critical_uncommon_seg else dual_ranking_rationale['lexicality'][i][0][k] for k in rationale]
                final_output['rationale'][i][0]  = reorder(LAC_imp, rationale, mode='sequence')['sorted_token_index']
                
                rationale = [k for k in range(len(text_t_seg)) if dual_ranking_rationale['lac'][i][1][k] >0 and text_t_seg[k] not in P_LIST]
                max_LAC_imp = max([dual_ranking_rationale['lexicality'][i][1][k] for k in rationale])
                LAC_imp = [dual_ranking_rationale['lexicality'][i][1][k]+min(max_LAC_imp-dual_ranking_rationale['lexicality'][i][1][k], dual_ranking_rationale['laclime_rank'][i][1][k]/100) if text_t_seg[k] in critical_uncommon_seg else dual_ranking_rationale['lexicality'][i][1][k] for k in rationale]
                final_output['rationale'][i][1]  = reorder(LAC_imp, rationale, mode='sequence')['sorted_token_index']
           
    
            else:
                pass

            #denoising for alignment (label 0)
            if denoising and (criteria == 'ODB_q' or criteria == 'RDB_q'):
                new_rationale = []
                uncommon_flag = 0 
                for qk in final_output['rationale'][i][0]:
                    if len(final_output['rationale'][i][1])==0:
                        break                
                    first_t = final_output['rationale'][i][1][0]
                    
                    if text_t_seg[first_t]== text_q_seg[qk]:
                        new_rationale.append(first_t)
                        final_output['rationale'][i][1].remove(first_t)
                        if len(final_output['rationale'][i][1])>0:
                            continue
                        else:
                            break
                            
                    else:                    
                        if text_q_seg[qk] not in text_t_seg:
                            if uncommon_flag < denoising_k:
                                uncommon_flag +=1 
                                uncommon_t_1st = final_output['rationale'][i][1][0] 
                                new_rationale.append(uncommon_t_1st)
                                final_output['rationale'][i][1].remove(uncommon_t_1st)
                                continue
                            else:
                                break
                        else:
                            for tk in final_output['rationale'][i][1]:
                                if text_t_seg[tk] == text_q_seg[qk]:
                                    new_rationale.append(tk)
                                    final_output['rationale'][i][1].remove(tk)
                                    break
                new_rationale += final_output['rationale'][i][1]
                final_output['rationale'][i][1] = new_rationale
                
                
            if denoising and (criteria == 'ODB_t' or criteria == 'RDB_t'):
                new_rationale = []
                uncommon_flag = 0 
                for tk in final_output['rationale'][i][1]:
                    if len(final_output['rationale'][i][0])==0:
                        break
                        
                    first_q = final_output['rationale'][i][0][0]
                    
                    if text_q_seg[first_q]== text_t_seg[tk]:
                        new_rationale.append(first_q)
                        final_output['rationale'][i][0].remove(first_q)
                        if len(final_output['rationale'][i][0])>0:
                            continue 
                        else:
                            break
                            
                    else:                    
                        if text_t_seg[tk] not in text_q_seg:
                            if uncommon_flag < denoising_k:
                                uncommon_flag +=1 
                                uncommon_q_1st = final_output['rationale'][i][0][0] 
                                new_rationale.append(uncommon_q_1st)
                                final_output['rationale'][i][0].remove(uncommon_q_1st)
                                continue
                            else:
                                break
                        else:
                            for qk in final_output['rationale'][i][0]:
                                if text_q_seg[qk] == text_t_seg[tk]:
                                    new_rationale.append(qk)
                                    final_output['rationale'][i][0].remove(qk)
                                    break
                new_rationale += final_output['rationale'][i][0]
                final_output['rationale'][i][0] = new_rationale

    return final_output

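### Write outputs

Write the dual-ranking rationales (with alignment denoising) to a tab-separated file. Each line holds one sample: id, label, then the comma-separated token indices selected for each of the two texts.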

In [None]:
final_output = dual_ranking_bi_criteria_denoising(dual_ranking_rationale, final_prediction_macbert_EF, rationale_method = 'dual_ranking', denoising = True)
# One TSV line per sample: id, label, comma-joined indices for text 1 (rationale[i][0]),
# comma-joined indices for text 2 (rationale[i][1]).
# `with` guarantees the file is flushed and closed even if a row raises mid-loop;
# an explicit encoding keeps the output independent of the platform locale.
with open('./rationale_results/sim_rationale_dual_ranking.txt', 'w', encoding='utf-8') as out_file:
    for i in range(len(final_output['id'])):
        fields = [
            str(final_output['id'][i]),
            str(final_output['label'][i]),
            ','.join(str(tok) for tok in final_output['rationale'][i][0]),
            ','.join(str(tok) for tok in final_output['rationale'][i][1]),
        ]
        out_file.write('\t'.join(fields) + '\n')

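### Some baselines

Each cell below writes, in the same file format, the rationales from one baseline: dual ranking without denoising, lexicality ranking only, or LIME / SHAP / IG scores, each in a positive-score variant and a proportional top-k variant. These files are for comparison against the dual-ranking output above.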

In [None]:
# No denoising 
final_output = dual_ranking_bi_criteria_denoising(dual_ranking_rationale, final_prediction_macbert_EF, rationale_method = 'dual_ranking', denoising = False)
# One TSV line per sample: id, label, comma-joined indices for text 1 (rationale[i][0]),
# comma-joined indices for text 2 (rationale[i][1]).
# `with` guarantees the file is flushed and closed even if a row raises mid-loop.
with open('./rationale_results/sim_rationale_wo_denoising.txt', 'w', encoding='utf-8') as out_file:
    for i in range(len(final_output['id'])):
        fields = [
            str(final_output['id'][i]),
            str(final_output['label'][i]),
            ','.join(str(tok) for tok in final_output['rationale'][i][0]),
            ','.join(str(tok) for tok in final_output['rationale'][i][1]),
        ]
        out_file.write('\t'.join(fields) + '\n')

In [None]:
# lexicality ranking only 
final_output = dual_ranking_bi_criteria_denoising(dual_ranking_rationale, final_prediction_macbert_EF, rationale_method= 'lexicality_only',  denoising = False)
# One TSV line per sample: id, label, comma-joined indices for text 1 (rationale[i][0]),
# comma-joined indices for text 2 (rationale[i][1]).
# `with` guarantees the file is flushed and closed even if a row raises mid-loop.
with open('./rationale_results/sim_rationale_lexicality.txt', 'w', encoding='utf-8') as out_file:
    for i in range(len(final_output['id'])):
        fields = [
            str(final_output['id'][i]),
            str(final_output['label'][i]),
            ','.join(str(tok) for tok in final_output['rationale'][i][0]),
            ','.join(str(tok) for tok in final_output['rationale'][i][1]),
        ]
        out_file.write('\t'.join(fields) + '\n')

In [None]:
# lime positive
final_output = dual_ranking_bi_criteria_denoising(dual_ranking_rationale, final_prediction_macbert_EF, rationale_method = 'positive_lime', denoising = False)
# One TSV line per sample: id, label, comma-joined indices for text 1 (rationale[i][0]),
# comma-joined indices for text 2 (rationale[i][1]).
# `with` guarantees the file is flushed and closed even if a row raises mid-loop.
with open('./rationale_results/sim_rationale_lime.txt', 'w', encoding='utf-8') as out_file:
    for i in range(len(final_output['id'])):
        fields = [
            str(final_output['id'][i]),
            str(final_output['label'][i]),
            ','.join(str(tok) for tok in final_output['rationale'][i][0]),
            ','.join(str(tok) for tok in final_output['rationale'][i][1]),
        ]
        out_file.write('\t'.join(fields) + '\n')

In [None]:
# proportional lime
final_output = dual_ranking_bi_criteria_denoising(dual_ranking_rationale, final_prediction_macbert_EF, rationale_method = 'proportional_lime', denoising = False)
# One TSV line per sample: id, label, comma-joined indices for text 1 (rationale[i][0]),
# comma-joined indices for text 2 (rationale[i][1]).
# `with` guarantees the file is flushed and closed even if a row raises mid-loop.
with open('./rationale_results/sim_rationale_proportional_lime.txt', 'w', encoding='utf-8') as out_file:
    for i in range(len(final_output['id'])):
        fields = [
            str(final_output['id'][i]),
            str(final_output['label'][i]),
            ','.join(str(tok) for tok in final_output['rationale'][i][0]),
            ','.join(str(tok) for tok in final_output['rationale'][i][1]),
        ]
        out_file.write('\t'.join(fields) + '\n')

In [None]:
# shap positive 
final_output = dual_ranking_bi_criteria_denoising(dual_ranking_rationale, final_prediction_macbert_EF, rationale_method = 'positive_shap', denoising = False)
# One TSV line per sample: id, label, comma-joined indices for text 1 (rationale[i][0]),
# comma-joined indices for text 2 (rationale[i][1]).
# `with` guarantees the file is flushed and closed even if a row raises mid-loop.
with open('./rationale_results/sim_rationale_shap.txt', 'w', encoding='utf-8') as out_file:
    for i in range(len(final_output['id'])):
        fields = [
            str(final_output['id'][i]),
            str(final_output['label'][i]),
            ','.join(str(tok) for tok in final_output['rationale'][i][0]),
            ','.join(str(tok) for tok in final_output['rationale'][i][1]),
        ]
        out_file.write('\t'.join(fields) + '\n')

In [None]:
# proportional shap  
final_output = dual_ranking_bi_criteria_denoising(dual_ranking_rationale, final_prediction_macbert_EF, rationale_method = 'proportional_shap', denoising = False)
# One TSV line per sample: id, label, comma-joined indices for text 1 (rationale[i][0]),
# comma-joined indices for text 2 (rationale[i][1]).
# `with` guarantees the file is flushed and closed even if a row raises mid-loop.
with open('./rationale_results/sim_rationale_proportional_shap.txt', 'w', encoding='utf-8') as out_file:
    for i in range(len(final_output['id'])):
        fields = [
            str(final_output['id'][i]),
            str(final_output['label'][i]),
            ','.join(str(tok) for tok in final_output['rationale'][i][0]),
            ','.join(str(tok) for tok in final_output['rationale'][i][1]),
        ]
        out_file.write('\t'.join(fields) + '\n')

In [None]:
# ig positive
final_output = dual_ranking_bi_criteria_denoising(dual_ranking_rationale, final_prediction_macbert_EF, rationale_method = 'positive_ig', denoising = False)
# One TSV line per sample: id, label, comma-joined indices for text 1 (rationale[i][0]),
# comma-joined indices for text 2 (rationale[i][1]).
# `with` guarantees the file is flushed and closed even if a row raises mid-loop.
with open('./rationale_results/sim_rationale_ig.txt', 'w', encoding='utf-8') as out_file:
    for i in range(len(final_output['id'])):
        fields = [
            str(final_output['id'][i]),
            str(final_output['label'][i]),
            ','.join(str(tok) for tok in final_output['rationale'][i][0]),
            ','.join(str(tok) for tok in final_output['rationale'][i][1]),
        ]
        out_file.write('\t'.join(fields) + '\n')


In [None]:
# proportional ig
final_output = dual_ranking_bi_criteria_denoising(dual_ranking_rationale, final_prediction_macbert_EF, rationale_method = 'proportional_ig', denoising = False)
# One TSV line per sample: id, label, comma-joined indices for text 1 (rationale[i][0]),
# comma-joined indices for text 2 (rationale[i][1]).
# `with` guarantees the file is flushed and closed even if a row raises mid-loop.
with open('./rationale_results/sim_rationale_proportional_ig.txt', 'w', encoding='utf-8') as out_file:
    for i in range(len(final_output['id'])):
        fields = [
            str(final_output['id'][i]),
            str(final_output['label'][i]),
            ','.join(str(tok) for tok in final_output['rationale'][i][0]),
            ','.join(str(tok) for tok in final_output['rationale'][i][1]),
        ]
        out_file.write('\t'.join(fields) + '\n')
