In [1]:
import requests
from sklearn.metrics.pairwise import cosine_similarity
import numpy as np
import json
from tqdm import tqdm


In [2]:
# Endpoint of the embedding service; placeholder value, replace with the real URL.
URI = 'api_for_baichuan2_13B'

def call_api(input, word, mode, timeout=None):
    """POST one request to the service at ``URI`` and return its decoded JSON body.

    Args:
        input: Value sent under the ``'input'`` key. The name shadows the
            builtin but is kept so existing keyword callers keep working.
        word: Value sent under the ``'word'`` key.
        mode: Value sent under the ``'mode'`` key.
        timeout: Optional timeout in seconds forwarded to ``requests.post``.
            The default ``None`` waits indefinitely, as the original did.

    Returns:
        Whatever ``response.json()`` decodes from a 200 response.
        NOTE(review): callers in this notebook wrap it in ``np.array`` —
        presumably an embedding vector; confirm against the service.

    Raises:
        requests.HTTPError: if the service answers with any status other
            than 200. Previously this case returned ``None`` implicitly,
            which only failed later, far from the cause.
    """
    request = {
        'input': input,
        'word': word,
        'mode': mode,
    }
    response = requests.post(URI, json=request, timeout=timeout)

    # Fail loudly here instead of handing None to the numeric code downstream.
    if response.status_code != 200:
        raise requests.HTTPError(
            f'call_api: unexpected status {response.status_code} from {URI}',
            response=response,
        )

    return response.json()

In [3]:
def get_sim(input_list, correct_extend_map, wrong_extend_map, mode):
    """Compare how close "correct" vs. "wrong" label variants sit to the target label.

    For every key of ``correct_extend_map`` (the target label) the correct and
    wrong variants are paired position-wise. For each pair and each text of the
    key's text group, three prompts are embedded through ``call_api`` — one per
    label (target, correct variant, wrong variant) — and the value
    ``cos(target, correct) - cos(target, wrong)`` is recorded. The values are
    averaged over the texts, giving one number per pair.

    Args:
        input_list: Sequence of text groups; the group at position ``i`` is
            used for the ``i``-th key of ``correct_extend_map`` (iteration
            order), so the two must be aligned by the caller.
        correct_extend_map: Mapping ``label -> list of correct variants``.
        wrong_extend_map: Mapping ``label -> list of wrong variants``; must
            contain every key of ``correct_extend_map``.
        mode: Forwarded to ``call_api``. If it contains ``'zh'`` the Chinese
            prompt connector is used, otherwise the English one.

    Returns:
        dict mapping every key of ``correct_extend_map`` to a list with one
        mean similarity difference per (correct, wrong) pair. A key whose
        processing failed keeps whatever pairs were finished before the error
        (possibly an empty list).

    Notes:
        * Errors are reported with ``print`` and do not propagate. The
          ``try`` is scoped to a single key so one failing key no longer
          prevents the remaining keys from being processed.
        * The target-label embedding depends only on the text and the key,
          so it is requested once per text and reused for every pair. This
          assumes the embedding service is deterministic.
        * ``zip`` silently truncates to the shorter list if the correct and
          wrong variant lists differ in length.
    """
    conj = '实体类型：' if 'zh' in mode else 'Entity:'

    res_sim_dict = {k: [] for k in correct_extend_map.keys()}

    for idx, key in enumerate(correct_extend_map.keys()):
        try:
            cur_text_list = input_list[idx]
            cur_target_label = key
            cur_correct_extend_list = correct_extend_map[key]
            cur_wrong_extend_list = wrong_extend_map[key]

            correct_wrong_pairs = list(zip(cur_correct_extend_list, cur_wrong_extend_list))

            # Target embeddings keyed by text position: filled lazily on the
            # first pair, reused by all later pairs of the same key.
            target_embeddings = {}

            for cw_pair in tqdm(correct_wrong_pairs, total=len(correct_wrong_pairs), desc=f'get sim for {key}...'):
                cur_sim_list = []
                for text_idx, text in enumerate(cur_text_list):
                    if text_idx not in target_embeddings:
                        input_t = f'"{text}" {conj}{cur_target_label}'
                        target_embeddings[text_idx] = np.array(
                            call_api(input_t, cur_target_label, mode)
                        ).reshape(1, -1)
                    sent_embedding_t = target_embeddings[text_idx]

                    input_1 = f'"{text}" {conj}{cw_pair[0]}'
                    input_2 = f'"{text}" {conj}{cw_pair[1]}'

                    sent_embedding_1 = np.array(call_api(input_1, cw_pair[0], mode)).reshape(1, -1)
                    sent_embedding_2 = np.array(call_api(input_2, cw_pair[1], mode)).reshape(1, -1)

                    sim_A = cosine_similarity(sent_embedding_t, sent_embedding_1)[0][0]
                    sim_B = cosine_similarity(sent_embedding_t, sent_embedding_2)[0][0]
                    # Positive value: the correct variant is closer to the target.
                    cur_sim_list.append(sim_A - sim_B)
                res_sim_dict[key].append(np.mean(cur_sim_list))

        except Exception as e:
            # Best-effort: report and move on to the next key.
            print(f'error: {e}')
            print(f'key: {key},')
    return res_sim_dict
                
        

In [None]:
# Read the CMeEE sentences, one per line, trimming surrounding whitespace.
with open('./data/cmeee/input_list.txt', encoding='utf-8') as f:
    input_list = [line.strip() for line in f]

# Step through the lines 10 at a time and keep only the first 3 of each step.
# NOTE(review): presumably the file holds 10 sentences per entity type — confirm.
result_list = [input_list[i:i+3] for i in range(0, len(input_list), 10)]

# Label -> list of correct variants.
with open('./data/cmeee/select_correct_entity_extend_map_baichuan2_13B_zh.json', encoding='utf-8') as f:
    correct_extend_map = json.load(f)

# Label -> list of wrong variants.
with open('./data/cmeee/select_wrong_entity_extend_map_baichuan2_13B_zh.json', encoding='utf-8') as f:
    wrong_extend_map = json.load(f)

res_dict = get_sim(
    result_list,
    correct_extend_map,
    wrong_extend_map,
    mode='baichuan2_13B_zh',
)

# One line per label: the mean over all of its pair scores.
for k in res_dict:
    print(f'{k}: {np.mean(res_dict[k])}')

In [None]:
# Read the ACE05 sentences, one per line, trimming surrounding whitespace.
with open('./data/ace05/input_list.txt', encoding='utf-8') as f:
    input_list = [line.strip() for line in f]

# Step through the lines 10 at a time and keep only the first 3 of each step.
# NOTE(review): presumably the file holds 10 sentences per entity type — confirm.
result_list = [input_list[i:i+3] for i in range(0, len(input_list), 10)]

# Label -> list of correct variants.
with open('./data/ace05/select_correct_entity_extend_map_baichuan2_13B_en.json', encoding='utf-8') as f:
    correct_extend_map = json.load(f)

# Label -> list of wrong variants.
with open('./data/ace05/select_wrong_entity_extend_map_baichuan2_13B_en.json', encoding='utf-8') as f:
    wrong_extend_map = json.load(f)

res_dict = get_sim(
    result_list,
    correct_extend_map,
    wrong_extend_map,
    mode='baichuan2_13B_en',
)

# One line per label: the mean over all of its pair scores.
for k in res_dict:
    print(f'{k}: {np.mean(res_dict[k])}')