In [22]:
import json
from tqdm import tqdm
import pandas as pd

def find_lcsubstr(str_a, str_b):
    """Return the longest-common-subsequence length normalised by the shorter input.

    Despite the name, this is the classic longest common *subsequence*
    recurrence (on a mismatch the best value is carried over from the left or
    upper cell), not the longest common substring. Only a single row of the DP
    table is kept, so extra memory is proportional to ``len(str_b)``.

    Args:
        str_a: first sequence (typically a string).
        str_b: second sequence (typically a string).

    Returns:
        ``0`` if either input is empty, otherwise
        ``lcs_length / min(len(str_a), len(str_b))`` as a float. The value is
        1.0 exactly when the shorter input is a subsequence of the longer one.
    """
    size_a, size_b = len(str_a), len(str_b)
    if size_a == 0 or size_b == 0:
        return 0

    # row[j] holds the LCS length of the processed prefix of str_a and str_b[:j].
    row = [0] * (size_b + 1)
    for item_a in str_a:
        diagonal = 0  # value of the previous row at column j-1
        for col, item_b in enumerate(str_b, start=1):
            above = row[col]  # previous row, same column (not yet overwritten)
            if item_a == item_b:
                row[col] = diagonal + 1
            else:
                row[col] = max(row[col - 1], above)
            diagonal = above
    return row[size_b] / min(size_b, size_a)

def data_loading(path):
    """Load a JSON-lines QA file into a DataFrame with a fixed set of columns.

    Each non-blank line of the file is parsed as one JSON object. Only the
    fields listed in ``all_keys`` are kept; a field that is absent from a
    record is stored as ``None`` so every column keeps the same length (the
    previous implementation produced ragged columns and failed inside
    ``pd.DataFrame`` in that case). Fields outside ``all_keys`` are ignored.

    Args:
        path: path of a UTF-8 encoded file with one JSON object per line.

    Returns:
        A ``pandas.DataFrame`` with the columns of ``all_keys`` (in that
        order), where ``-1`` in ``start_position`` / ``end_position`` has been
        replaced by ``0``.

    Raises:
        json.JSONDecodeError: if a non-blank line is not valid JSON.
    """
    all_keys = ['question_id','question','qas_id','doc_tokens','answer_text','start_position','end_position','is_impossible']
    final_dic = {key: [] for key in all_keys}
    with open(path,encoding='utf-8') as json_file:
        for line in tqdm(json_file):
            if not line.strip():
                # Tolerate blank lines (e.g. a trailing newline at end of file).
                continue
            data = json.loads(line)
            for key in all_keys:
                final_dic[key].append(data.get(key))
    final_pd = pd.DataFrame(final_dic)
    # -1 is mapped to 0 for both span boundaries.
    for column in ('start_position', 'end_position'):
        final_pd[column] = final_pd[column].replace({-1:0})
    return final_pd

# Input files: the cleaned test set and the model predictions to evaluate.
test_df_location = 'cleaned_data_test.json'
evaluate_file = '1130_roberta_3.0data.json'

test_df = data_loading(test_df_location)

# Pull the needed columns out as plain lists (same order as the DataFrame rows).
all_qas_id, all_question, all_doc_tokens, answer_all = (
    list(test_df[column_name])
    for column_name in ('qas_id', 'question', 'doc_tokens', 'answer_text')
)

# Lookup tables keyed by qas_id: reference answer, question text, source paragraph.
final_qas_map = {qas_id: answer for qas_id, answer in zip(all_qas_id, answer_all)}
final_que_map = {qas_id: question for qas_id, question in zip(all_qas_id, all_question)}
final_doc_map = {qas_id: doc for qas_id, doc in zip(all_qas_id, all_doc_tokens)}

# Question cue phrases; not referenced by any cell visible here.
# NOTE(review): presumably markers of short-answer questions - confirm intent.
shot_ans = [
    '是什么','在哪里','哪位','哪国','国家','国籍','哪个朝代','本名','原名','哪个',
    '哪一个','哪','最','几次','是谁','学名','第一个','第一','最大','最小',
    '最长','最深','最高','最多','第一','作者','叫什么名字','名字','叫什么',
]

1500it [00:00, 93762.38it/s]


In [23]:
import jieba

def check_relevance(question, answer, do_filter=False, threshould = 0.2):
    """Heuristically decide whether ``answer`` is relevant to ``question``.

    The rules are applied in this order:

    1. If ``answer`` contains any phrase of ``right_list`` it is relevant.
    2. Otherwise, if ``answer`` has at most 20 jieba tokens and contains any
       phrase of ``check_list`` it is not relevant.
    3. Otherwise it is relevant iff the share of distinct question tokens that
       also occur among the answer tokens is at least ``threshould``.

    An empty question (no tokens) has no overlap to measure and is treated as
    not relevant in step 3; previously this raised ``ZeroDivisionError``.

    Args:
        question: question text.
        answer: candidate answer text.
        do_filter: when True also return the accepted answer text; when False
            return only the flag.
        threshould: minimum token-overlap ratio for step 3 (parameter name
            kept as-is for backward compatibility).

    Returns:
        ``flag`` (1 = relevant, 0 = not relevant) when ``do_filter`` is False;
        ``(valid_answer, flag)`` when ``do_filter`` is True, where
        ``valid_answer`` is ``answer`` if ``flag`` is 1 and ``''`` otherwise.
    """
    check_list = ['建议','当地','需要','确定','咨询','询问','客服','相关','部门','有关部门','私信','同求','您好','办理','手续','及时']
    right_list = ['这种情况','情况','病史','描述','叙述','根治','缓解','改善','症状','表现','食用','口服','服用','导致','表现','形容',
    '危险','根据你的描述','引起的','根据','正常','可能是','你说的','导致的','造成的','检查一下','试试','尝试','进行治疗','检查','治疗',
    '药物治疗','治疗中','相关治疗','治疗方式','治疗方案','治疗效果','常规治疗','其他治疗','治疗后','手术治疗','治疗方法','综合治疗',
    '系统治疗','治疗手段','治疗方面','中医治疗','治疗措施','治疗过程中','疾病的治疗','对症治疗','持续治疗','临床治疗','其它治疗',
    '西医治疗','保守治疗','医治','长期治疗','治疗效果不佳','一般治疗','治疗法']

    def _result(flag):
        # Shape the return value according to do_filter.
        if do_filter:
            return (answer if flag == 1 else ''), flag
        return flag

    # Rule 1: a whitelisted phrase makes the answer relevant outright.
    if any(item in answer for item in right_list):
        return _result(1)

    current_sentence = list(jieba.cut(answer, cut_all=False))

    # Rule 2: short answers containing a blacklisted phrase are rejected.
    if len(current_sentence) <= 20 and any(item in answer for item in check_list):
        return _result(0)

    # Rule 3: token-overlap ratio between question and answer.
    final_q = set(jieba.cut(question, cut_all=False))
    if not final_q:
        # No question tokens: overlap is undefined, treat as not relevant.
        return _result(0)
    candidate_all = final_q.intersection(current_sentence)
    if len(candidate_all) / len(final_q) >= threshould:
        return _result(1)
    return _result(0)

In [33]:
def _share_of_hits(flags):
    """Return the fraction of entries equal to 1, or 0.0 for an empty list."""
    return flags.count(1) / len(flags) if flags else 0.0


def load_prediction_RoBERTA(path, final_qas_map,final_que_map, final_doc_map):
    """Score a JSON-lines prediction file and collect per-sample display rows.

    Each non-blank line must be a JSON object with the keys ``question_ID``,
    ``answer`` and ``prediction``. Whitespace is removed from ``answer`` and
    ``prediction`` before comparison. A record counts as correct at the
    "100%" level when ``find_lcsubstr`` returns 1 and at the "60%" level when
    it returns at least 0.6; over all records, an empty answer with an empty
    prediction counts as correct at both levels. Accuracies are printed for
    all records and for the records whose reference answer in
    ``final_qas_map`` is non-empty. An accuracy over zero records is reported
    as 0.0 (previously this raised ``ZeroDivisionError``).

    Args:
        path: UTF-8 JSON-lines file with the model predictions.
        final_qas_map: mapping qas_id -> reference answer text.
        final_que_map: mapping qas_id -> question text.
        final_doc_map: mapping qas_id -> paragraph text (must be a ``str``,
            it is concatenated to a label).

    Returns:
        ``(final_dic, required_dic)``; both are lists of tuples of labelled
        strings with one entry per record whose reference answer is non-empty.
        ``final_dic`` rows are (paragraph, question, answer, prediction,
        answer relevancy, prediction relevancy); ``required_dic`` rows are
        (question, answer, answer relevancy, prediction, prediction relevancy).

    Raises:
        KeyError: if a record lacks a required key or its ``question_ID`` is
            missing from one of the maps.
    """
    final_content = []
    final_question = []
    final_answer = []
    final_prediction = []
    final_flag_prediction = []
    final_flag_answer = []

    prediction_result_100 = []
    prediction_result_60 = []
    prediction_not_null_60 = []
    prediction_not_null_100 = []
    count = 0

    with open(path, "r",encoding='utf-8') as f:
        all_data = f.readlines()

    for line in tqdm(all_data):
        if not line.strip():
            # Tolerate blank lines (e.g. a trailing newline at end of file).
            continue
        current_object = json.loads(line)
        qas_id = current_object['question_ID']
        actual_answer = ''.join(current_object['answer'].split())
        actual_answer_final = final_qas_map[qas_id]
        prediction = ''.join(current_object['prediction'].split())
        # NOTE(review): scoring uses the answer stored in the prediction file,
        # while the non-empty gate uses the answer from final_qas_map - confirm
        # the two sources are meant to differ.
        has_reference = actual_answer_final != ''

        if has_reference:
            question_text = final_que_map[qas_id]
            final_content.append('Paragraph:'+ final_doc_map[qas_id])
            # NOTE(review): this label has no ':' unlike the others; kept
            # unchanged because it ends up in the exported sample file.
            final_question.append('Question' + question_text)
            final_answer.append('Answer:' +actual_answer_final)
            final_prediction.append('Prediction:' + prediction)
            final_flag_prediction.append('Prediction_relevancy:'+str(check_relevance(question_text, prediction)))
            final_flag_answer.append('Answer_relevancy:'+str(check_relevance(question_text, actual_answer_final)))

        # One LCS computation per record serves both the overall and the
        # non-empty statistics.
        overlap_rate = find_lcsubstr(actual_answer, prediction)
        exact_hit = 1 if overlap_rate == 1 else 0      # containment counts as a full match
        partial_hit = 1 if overlap_rate >= 0.6 else 0

        if actual_answer == '' and prediction == '':
            # Correctly predicting "no answer" counts as a hit overall.
            prediction_result_100.append(1)
            prediction_result_60.append(1)
        else:
            prediction_result_100.append(exact_hit)
            prediction_result_60.append(partial_hit)

        if has_reference:
            prediction_not_null_100.append(exact_hit)
            prediction_not_null_60.append(partial_hit)
            count +=1

    print('RoBERTa all prediction, total length:{}'.format(len(prediction_result_100)))
    print('100% 相同情况下 准确率：', _share_of_hits(prediction_result_100))
    print('60% 相同情况下 准确率：', _share_of_hits(prediction_result_60))
    print('非空总数：{}'.format(count))
    print('100% 相同情况下 答案非空 准确率：', _share_of_hits(prediction_not_null_100))
    print('60% 相同情况下 答案非空 准确率：', _share_of_hits(prediction_not_null_60))

    final_dic = list(zip(final_content,final_question,final_answer,final_prediction,final_flag_answer, final_flag_prediction))

    required_dic = list(zip(final_question,final_answer,final_flag_answer, final_prediction, final_flag_prediction))
    return final_dic,required_dic

In [34]:
# Evaluate the RoBERTa predictions against the reference maps built from the test set.
final_dic_ROBERTA, required_dic_ROBERTA = load_prediction_RoBERTA(
    path=evaluate_file,
    final_qas_map=final_qas_map,
    final_que_map=final_que_map,
    final_doc_map=final_doc_map,
)

100%|████████████████████████████████████████████████████████████████████████████| 1500/1500 [00:00<00:00, 1838.85it/s]

RoBERTa all prediction, total length:1500
100% 相同情况下 准确率： 0.9533333333333334
60% 相同情况下 准确率： 0.9586666666666667
非空总数：1297
100% 相同情况下 答案非空 准确率： 0.9491133384734002
60% 相同情况下 答案非空 准确率： 0.9552814186584425





In [35]:
# Dump one tab-separated line per evaluated sample for manual inspection.
with open('sample_show.txt', 'w', encoding='utf-8') as f:
    f.writelines('\t'.join(i) + '\n' for i in required_dic_ROBERTA)