In [9]:
import re
import json

In [10]:
def read_data(file_path):
    """Load the JSON document stored at ``file_path``.

    Args:
        file_path: Path of a UTF-8 encoded JSON file.

    Returns:
        The decoded JSON value (whatever ``json`` produces for the file).
    """
    with open(file_path, mode='r', encoding='utf-8') as handle:
        raw_text = handle.read()
    return json.loads(raw_text)

In [11]:
def remove_punctuation(item):
    """Normalise a string for substring matching.

    First drops every ASCII-parenthesised span ``(...)`` together with its
    content, then removes every character that is neither a word character
    nor whitespace, and finally trims surrounding whitespace.

    Args:
        item: The string to clean.

    Returns:
        The cleaned string.
    """
    without_parentheses = re.compile(r"\(.*?\)").sub("", item)
    only_words = re.compile(r"[^\w\s]").sub("", without_parentheses)
    return only_words.strip()

In [12]:
#def text(item, ref_output, gen_output):
    #return
def diagnosis(name, gen_output):
    """Sliding-window fuzzy check that ``gen_output`` mentions ``name``.

    The name is normalised with ``remove_punctuation``. A window of
    ``len(name) + 2`` characters slides over ``gen_output``; the name counts
    as found when enough of its characters occur inside one window:

    * up to 2 characters: all of them,
    * 3 to 5 characters: at least 3,
    * 6 to 8 characters: at least 4,
    * longer names: all but 4.

    Args:
        name: Entity name to look for.
        gen_output: Generated answer text.

    Returns:
        1 if some window satisfies the threshold, otherwise 0.
    """
    name = remove_punctuation(name)
    char_count = len(name)
    if char_count == 0:
        # An empty name would need zero matching characters and therefore
        # "match" any text; treat it as not found instead.
        return 0

    if char_count <= 2:
        req_match = char_count
    elif char_count <= 5:
        req_match = 3
    elif char_count <= 8:
        req_match = 4
    else:
        req_match = char_count - 4

    window_size = char_count + 2
    # When the text is shorter than the window, still inspect it once as a
    # single (short) window; otherwise short answers could never match.
    last_start = max(len(gen_output) - window_size, 0)
    for start in range(last_start + 1):
        window = gen_output[start:start + window_size]
        match_count = sum(1 for char in name if char in window)
        if match_count >= req_match:
            return 1
    return 0
def items(contexts, ref_output, gen_output):
    """Ratio of knowledge items found in the generated vs. the reference text.

    ``contexts`` may be a single value, a flat list, a list of lists, or a
    list mixing plain values and lists; nested lists are flattened one level.
    Each item is normalised with ``remove_punctuation``, duplicates and empty
    items are dropped (an empty string is a substring of every text and would
    otherwise be counted as a match on both sides), and the remaining items
    are searched as substrings.

    Args:
        contexts: Item or (possibly nested) list of items to look for.
        ref_output: Reference answer text.
        gen_output: Generated answer text.

    Returns:
        ``gen_count / ref_count`` where the counts are the numbers of distinct
        items found in ``gen_output`` and ``ref_output``; 1 when no item is
        found in the reference. The ratio can exceed 1 when the generated
        text contains more of the items than the reference does.
    """
    top_level = contexts if isinstance(contexts, list) else [contexts]
    flattened = []
    for entry in top_level:
        # Flatten per element so that a mix of strings and lists (e.g. one
        # missing field next to a populated one) is handled too.
        if isinstance(entry, list):
            flattened.extend(entry)
        else:
            flattened.append(entry)

    normalized = (remove_punctuation(entry) for entry in flattened)
    # dict.fromkeys removes duplicates while keeping first-seen order.
    unique_items = [text for text in dict.fromkeys(normalized) if text]

    ref_count = sum(1 for text in unique_items if text in ref_output)
    gen_count = sum(1 for text in unique_items if text in gen_output)
    return gen_count / ref_count if ref_count != 0 else 1
def time(item, gen_output):
    """Check whether the generated text states a duration that fits the reference.

    Both texts are first normalised by replacing a few spelled-out durations
    (``sp_cases``) with digit forms. Durations are then extracted as
    ``(start, end, unit)`` triples; a single value such as ``3个月`` yields
    ``start == end``. The range separator and the second number are optional,
    so single durations are parsed as well as ranges (previously the separator
    was mandatory, which made single values - including most ``sp_cases``
    results - unparseable). Decimal numbers are accepted.

    Args:
        item: Reference duration description.
        gen_output: Generated answer text.

    Returns:
        * If ``item`` contains no parseable duration: 1 when ``item`` has none
          of the open-ended keywords, or when ``gen_output`` contains any of
          them; otherwise 0.
        * Otherwise: 1 when some duration in ``gen_output`` has the same unit
          as a reference duration and lies entirely inside it; otherwise 0.
    """
    sp_cases = {
        "一个月": "1个月", "两个月": "2个月","三个月": "3个月", 
        "半年": "6个月", "一年": "1年", "两年": "2年", "三年": "3年", "数年": "3-5年"
    }
    for case, value in sp_cases.items():
        item = item.replace(case, value)
        gen_output = gen_output.replace(case, value)

    keywords = ['终身', '长期', '间歇性','一直', '永久']
    number = r'\d+(?:\.\d+)?'
    # group 1: first number, group 2: optional second number, group 3: unit.
    # "周期" is listed before "周" so the longer unit wins.
    pattern = rf'({number})\s*(?:[—–\-~─－至到]+\s*({number})?)?\s*(?:个)?\s*(周期|月|年|天|周|日)'

    def extract_spans(text):
        """Return every duration in ``text`` as a (start, end, unit) triple."""
        spans = []
        for first, second, unit in re.findall(pattern, text):
            start = float(first)
            end = float(second) if second else start
            spans.append((start, end, unit))
        return spans

    time_info = extract_spans(item)
    if not time_info:
        # No numeric duration in the reference: fall back to keyword matching.
        item_has_keyword = any(keyword in item for keyword in keywords)
        gen_has_keyword = any(keyword in gen_output for keyword in keywords)
        if not item_has_keyword or gen_has_keyword:
            return 1
        return 0

    for t1, t2, u in extract_spans(gen_output):
        for start_time, end_time, unit in time_info:
            if u == unit and start_time <= t1 <= end_time and start_time <= t2 <= end_time:
                return 1
    return 0
def prob(item, gen_output):
    """Check that every percentage in the generated text fits the reference.

    Percentages in ``item`` are read as ranges: a percentage optionally
    followed (after any non-digit text) by a second percentage gives
    ``(min, max)``; a lone percentage gives a degenerate range.

    Args:
        item: Reference text containing percentages such as ``10%-20%``.
        gen_output: Generated answer text.

    Returns:
        * 1 if ``item`` contains no percentage.
        * 0 if ``item`` has percentages but ``gen_output`` has none.
        * Otherwise 1 when every percentage in ``gen_output`` lies inside at
          least one reference range, else 0. (Requiring a value to lie inside
          every range made a score of 1 impossible as soon as the reference
          listed two disjoint ranges.)
    """
    pattern = r'(\d+(?:\.\d+)?%)\s*\D*\s*(\d+(?:\.\d+)?%)?'
    ranges = []
    for match in re.findall(pattern, item):
        values = [float(value.replace('%', '')) for value in match if value]
        ranges.append((min(values), max(values)))
    if not ranges:
        return 1

    output_probs = [
        float(text.replace('%', ''))
        for text in re.findall(r'\d+(?:\.\d+)?%', gen_output)
    ]
    if not output_probs:
        return 0

    every_value_fits = all(
        any(lower <= value <= upper for lower, upper in ranges)
        for value in output_probs
    )
    return 1 if every_value_fits else 0
def cost(item, gen_output):
    """Check that every number in the generated text fits a reference cost range.

    Reference ranges are read from parenthesised spans in ``item`` such as
    ``（5000-10000元）`` (ASCII or full-width parentheses; thousands separators
    and decimals allowed). Numbers in ``gen_output`` are parsed with the same
    conventions, so ``8,000`` and ``8000.5`` are each read as one number
    rather than being split into fragments.

    Args:
        item: Reference cost description.
        gen_output: Generated answer text.

    Returns:
        * 1 if ``item`` contains no parenthesised range.
        * 0 if ``item`` has a range but ``gen_output`` contains no number.
        * Otherwise 1 when every number in ``gen_output`` lies inside at
          least one reference range, else 0.
    """
    pattern = r'[（(]\s*([\d,]+(?:\.\d+)?)\s*[—–\-~─－]+\s*([\d,]+(?:\.\d+)?)\s*元?\s*[）)]'
    ranges = []
    for match in re.finditer(pattern, item):
        first = float(match.group(1).replace(',', ''))
        second = float(match.group(2).replace(',', ''))
        # Tolerate bounds written in reverse order.
        ranges.append((min(first, second), max(first, second)))
    if not ranges:
        return 1

    # Either a thousands-grouped number (1,234,567.8) or a plain one (1234.5).
    number_pattern = r'\d{1,3}(?:,\d{3})+(?:\.\d+)?|\d+(?:\.\d+)?'
    numbers_in_output = [
        float(num_str.replace(',', ''))
        for num_str in re.findall(number_pattern, gen_output)
    ]
    if not numbers_in_output:
        return 0

    every_number_fits = all(
        any(lower <= num <= upper for lower, upper in ranges)
        for num in numbers_in_output
    )
    return 1 if every_number_fits else 0
def status(item, gen_output):
    """Compare the yes/no polarity of a reference statement and a generated answer.

    The reference counts as "yes" unless it contains one of a fixed list of
    negative phrases. The generated text counts as "yes" unless a negation
    character appears close to one of the topic keywords (from 10 characters
    before the keyword to 3 characters after it).

    Args:
        item: Reference statement.
        gen_output: Generated answer text.

    Returns:
        1 if both polarities agree, otherwise 0.
    """
    negative_phrases = ("无传染性", "非传染性疾病", "无特定人群", "无特定的人群", "无特殊人群", "无特发人群", "否")
    ref_yes = all(phrase not in item for phrase in negative_phrases)

    gen_yes = True
    for hit in re.finditer(r"(传染病|传染性|医保|人群)", gen_output):
        left = max(0, hit.start() - 10)
        right = min(len(gen_output), hit.end() + 3)
        if re.search(r"(没|非|不|无)", gen_output[left:right]) is not None:
            # One negated mention is enough; further matches cannot undo it.
            gen_yes = False
            break

    return int(ref_yes == gen_yes)

In [13]:
def check_entities(instance, reference_data, generated_data, index, start_prompt, score, n, s7, n7, s8, n8, s9, n9):
    """Score all prompts belonging to one knowledge-graph entity.

    Prompts are numbered 0-13 and are visited in that order. A prompt is
    evaluated only when its source field(s) in ``instance`` are non-empty;
    each evaluated prompt consumes the next entry of ``reference_data`` /
    ``generated_data`` (starting at ``index``). Scores are kept as running
    means that are updated in place.

    Prompt slots:
        0-2   desc / cause / prevent: free text, counted in ``n`` but not scored.
        3     disease name, scored with ``diagnosis`` (requires ``symptom``).
        4-6   symptom / check / acompany, scored with ``items``.
        7     get_way, get_prob, easy_get -> ``status``, ``prob``, ``status``
              (sub-scores in ``s7``/``n7``).
        8     cure_way, cured_prob, cure_lasttime -> ``items``, ``prob``,
              ``time`` (sub-scores in ``s8``/``n8``).
        9     cost_money, yibao_status -> ``cost``, ``status``
              (sub-scores in ``s9``/``n9``).
        10-13 drug and diet fields, scored with ``items``.

    Args:
        instance: Mapping describing one entity; must contain ``'name'`` when
            prompt 3 is evaluated.
        reference_data: Sequence of mappings with an ``'output'`` entry.
        generated_data: Sequence of mappings with an ``'output'`` entry.
        index: Position of this entity's first prompt in the two sequences.
        start_prompt: First prompt slot to evaluate; only honoured when
            ``index == 0``, otherwise every slot is evaluated.
        score, n: Running mean and sample count per prompt slot (mutated).
        s7, n7, s8, n8, s9, n9: Running means / counts of the sub-scores of
            slots 7, 8 and 9 (mutated).

    Returns:
        ``(next_index, score, n, s7, n7, s8, n8, s9, n9)`` where
        ``next_index`` is ``index`` plus the number of prompts consumed.
    """
    EMPTY = ('', [])
    EMPTY_OR_NA = ('“暂无”', '', [])

    prompt_start = start_prompt if index == 0 else 0
    count = 0  # number of prompts consumed for this entity so far

    def field(key):
        """Value of ``key`` in the entity, '' when absent."""
        return instance.get(key, '')

    def answered(values, blanks=EMPTY):
        """True when at least one value is not a blank placeholder."""
        return any(value not in blanks for value in values)

    def next_generated():
        """Return the generated answer of the next prompt and advance."""
        nonlocal count
        generated_output = generated_data[index + count]['output']
        count += 1
        return generated_output

    def next_pair():
        """Return (reference, generated) answers of the next prompt and advance."""
        nonlocal count
        reference_output = reference_data[index + count]['output']
        generated_output = generated_data[index + count]['output']
        count += 1
        return reference_output, generated_output

    def record(totals, counts, slot, result):
        """Fold ``result`` into the running mean stored at ``totals[slot]``."""
        counts[slot] += 1
        totals[slot] = (totals[slot] * (counts[slot] - 1) + result) / counts[slot]

    def record_parts(totals, counts, contexts, results, blanks=EMPTY):
        """Record each sub-score whose source field is actually populated."""
        for part, (value, result) in enumerate(zip(contexts, results)):
            if value not in blanks:
                record(totals, counts, part, result)

    def score_items(slot, contexts, blanks=EMPTY):
        """Evaluate an ``items``-scored prompt if it is in range and populated."""
        if prompt_start <= slot and answered(contexts, blanks):
            reference_output, generated_output = next_pair()
            record(score, n, slot, items(contexts, reference_output, generated_output))

    # Slots 0-2: free-text prompts are not scored, but each one still
    # occupies a position in the answer sequences.
    for slot, key in enumerate(('desc', 'cause', 'prevent')):
        if prompt_start <= slot and answered([field(key)]):
            count += 1
            n[slot] += 1

    # Slot 3: the diagnosis prompt exists whenever symptoms are available.
    if prompt_start <= 3 and answered([field('symptom')]):
        record(score, n, 3, diagnosis(instance['name'], next_generated()))

    score_items(4, [field('symptom')])
    score_items(5, [field('check')])
    score_items(6, [field('acompany')], EMPTY_OR_NA)

    if prompt_start <= 7:
        contexts = [field('get_way'), field('get_prob'), field('easy_get')]
        if answered(contexts):
            generated_output = next_generated()
            results = (
                status(contexts[0], generated_output),
                prob(contexts[1], generated_output),
                status(contexts[2], generated_output),
            )
            record_parts(s7, n7, contexts, results)
            score[7] = sum(s7) / 3

    if prompt_start <= 8:
        contexts = [field('cure_way'), field('cured_prob'), field('cure_lasttime')]
        if answered(contexts):
            reference_output, generated_output = next_pair()
            results = (
                items(contexts[0], reference_output, generated_output),
                prob(contexts[1], generated_output),
                time(contexts[2], generated_output),
            )
            record_parts(s8, n8, contexts, results)
            score[8] = sum(s8) / 3

    if prompt_start <= 9:
        contexts = [field('cost_money'), field('yibao_status')]
        if answered(contexts):
            generated_output = next_generated()
            results = (
                cost(contexts[0], generated_output),
                status(contexts[1], generated_output),
            )
            # A "“暂无”" placeholder consumes the prompt but is not scored.
            record_parts(s9, n9, contexts, results, EMPTY_OR_NA)
            score[9] = sum(s9) / 2

    score_items(10, [field('common_drug'), field('recommand_drug')])
    score_items(11, [field('drug_detail')])
    score_items(12, [field('do_eat'), field('not_eat')])
    # Slot 13 must use its own threshold; guarding it with 12 skipped this
    # prompt whenever start_prompt was 13.
    score_items(13, [field('recommand_eat')])

    return index + count, score, n, s7, n7, s8, n8, s9, n9

In [14]:
def process_json_instances(start_entity, end_entity, start_prompt, kg_file, input_file, output_file, model):
    """Evaluate one model's answers and print a score report.

    Loads the knowledge graph, the reference QA data and the generated QA
    data, runs ``check_entities`` on every knowledge-graph entity from
    ``start_entity`` through ``end_entity`` (both inclusive, in file order),
    and prints the resulting scores as percentages grouped by metric family.
    The free-text prompts (desc / cause / prevent) are not part of the report.

    Args:
        start_entity: Name of the first entity to evaluate.
        end_entity: Name of the last entity to evaluate.
        start_prompt: Forwarded to ``check_entities``.
        kg_file: Path of the knowledge-graph JSON file.
        input_file: Path of the reference QA JSON file.
        output_file: Path of the generated QA JSON file.
        model: Label printed in the report header.

    Returns:
        None. The report is written to standard output.
    """
    kg_data = read_data(kg_file)
    reference_data = read_data(input_file)
    generated_data = read_data(output_file)

    score, n = [0] * 14, [0] * 14
    s7, n7 = [0] * 3, [0] * 3
    s8, n8 = [0] * 3, [0] * 3
    s9, n9 = [0] * 2, [0] * 2

    index = 0
    inside_range = False
    for instance in kg_data:
        inside_range = inside_range or instance['name'] == start_entity
        if inside_range:
            index, score, n, s7, n7, s8, n8, s9, n9 = check_entities(
                instance, reference_data, generated_data, index, start_prompt,
                score, n, s7, n7, s8, n8, s9, n9,
            )
        # The end entity itself is evaluated before the range is closed.
        if instance['name'] == end_entity:
            inside_range = False

    sections = [
        ("Sliding Window Fuzzy Match",
         ["diagnosis"],
         [score[3]]),
        ("Match Coverage Ratio",
         ["symptoms", "examinations", "complications", "treatments", "medications", "medication_details", "diet", "recipes"],
         [score[4], score[5], score[6], s8[0], score[10], score[11], score[12], score[13]]),
        ("Statistical Range Fit",
         ["morbidity rate", "cure rate", "duration", "cost"],
         [s7[1], s8[1], s8[2], s9[0]]),
        ("Consistency",
         ["transmission", "susceptible groups", "insurance_status"],
         [s7[0], s7[2], s9[1]]),
    ]

    header_rule = "——————————————————————————————————————————————————————————————————————————————————————"
    section_rule = "-----------------------------------------------------------------------------------------------------------------------------"
    footer_rule = "—————————————————————————————————————————————————————————————————————————————————————————————————————————————————————————————"

    print(header_rule)
    print(f"MODEL: {model}")
    print(header_rule)
    final_position = len(sections) - 1
    for position, (title, labels, values) in enumerate(sections):
        print(title)
        print(" ".join(f"{label:<13}" for label in labels))
        print(" ".join(f"{value*100:<13.2f}" for value in values))
        print(footer_rule if position == final_position else section_rule)

In [None]:
# Range of knowledge-graph entities to evaluate (both ends inclusive).
start_entity = '肺泡蛋白质沉积症'
#end_entity = '肺泡蛋白质沉积症'
end_entity = '糖尿病乳酸性酸中毒'
# First prompt slot to evaluate for the first entity.
start_prompt = 0

models = ['medka-8b']
#models = ['medka-8b', 'llama3.1-8b-chinese-chat', 'qwen2.5-7b-instruct', 'huatuogpt2-7b', 'apollo2-7b']
kg_file_path = 'dataset/CMKG/medical_data.json'
input_file_path = 'dataset/cMKGQA/evaluation_data.json'
for model in models:
    # Alternative: use the reference file as the generated file.
    #output_file_path = f'dataset/cMKGQA/evaluation_data.json'
    # Build the path from the loop variable so that each model reads its own
    # answers; a fixed file name made every model report the same scores.
    # NOTE(review): assumes outputs are stored as test_<model>.json - confirm.
    output_file_path = f'output/cMKGQA/test_{model}.json'
    process_json_instances(start_entity, end_entity, start_prompt, kg_file_path, input_file_path, output_file_path, model)