In [None]:
import boto3
import json
import os
import numpy as np

# Shared SageMaker runtime client; visual_preprocess passes it to
# get_vector_by_sm_endpoint, which calls invoke_endpoint on it.
smr_client = boto3.client("sagemaker-runtime")

In [None]:
def get_vector_by_sm_endpoint(questions, sm_client, endpoint_name, instruction=''):
    """Embed text through a SageMaker endpoint and return the first embedding.

    The request body is JSON with the fields ``inputs``, ``is_query`` (always
    ``True``), ``instruction`` and ``parameters`` (always an empty string).

    Args:
        questions: Value sent as the ``inputs`` field. Callers in this notebook
            pass a single string.
        sm_client: Object exposing ``invoke_endpoint(EndpointName=..., Body=...,
            ContentType=...)``, e.g. a boto3 ``sagemaker-runtime`` client.
        endpoint_name: Name of the endpoint to invoke.
        instruction: Optional retrieval instruction forwarded to the model.
            Instructions noted by the original author:
              Chinese: 为这个句子生成表示以用于检索相关文章：
              English: Represent this sentence for searching relevant passages:

    Returns:
        The first element of the ``sentence_embeddings`` field of the decoded
        JSON response.
    """
    payload = {
        "inputs": questions,
        "is_query": True,
        "instruction": instruction,
        "parameters": "",
    }

    response_model = sm_client.invoke_endpoint(
        EndpointName=endpoint_name,
        Body=json.dumps(payload),
        ContentType="application/json",
    )

    # The response body is a stream of UTF-8 encoded JSON.
    json_obj = json.loads(response_model['Body'].read().decode('utf8'))
    return json_obj['sentence_embeddings'][0]

def similarity_calc(vec1, vec2):
    """Return the cosine similarity of two vectors.

    Computed as ``dot(vec1, vec2) / (|vec1| * |vec2|)``, where each magnitude
    is the square root of the vector's dot product with itself.
    """
    magnitude_product = np.sqrt(np.dot(vec1, vec1)) * np.sqrt(np.dot(vec2, vec2))
    return np.dot(vec1, vec2) / magnitude_product

def similarity_calc_stat(vec1, vec2_list):
    """Return ``(max, mean)`` of the similarities between vec1 and each candidate.

    Each similarity is obtained from ``similarity_calc(vec1, candidate)``.
    """
    scores = []
    for candidate in vec2_list:
        scores.append(similarity_calc(vec1, candidate))
    return np.max(scores), np.mean(scores)

def similarity_stat(vec1, vec2_list, pos_sim_val):
    """Count how many candidates are more similar to vec1 than a reference score.

    Args:
        vec1: Query vector.
        vec2_list: Iterable of candidate vectors compared against ``vec1``
            with ``similarity_calc``.
        pos_sim_val: Reference similarity (the caller passes the similarity of
            the positive example).

    Returns:
        A tuple ``(wrong_cnt, sim_list)``: ``wrong_cnt`` is the number of
        candidates whose similarity is strictly greater than ``pos_sim_val``
        (ties are not counted), and ``sim_list`` holds the similarity of every
        candidate in input order.
    """
    sim_list = [similarity_calc(vec1, vec2) for vec2 in vec2_list]
    wrong_cnt = sum(1 for item in sim_list if item > pos_sim_val)

    return wrong_cnt, sim_list

In [65]:
def visual_preprocess(input_file, pair_a_name, pair_b_name, neg_name, endpoint_name, instruction=''):
    """Embed every example of a JSONL file and collect ranking/similarity stats.

    Each non-blank line of ``input_file`` must be a JSON object with the keys
    ``query``, ``pos`` (a sequence; only its first element is used) and
    ``neg`` (a sequence of hashable items; duplicates are dropped).

    Args:
        input_file: Path of the JSONL file to evaluate.
        pair_a_name: Not used by this function; kept for call-site compatibility.
        pair_b_name: Not used by this function; kept for call-site compatibility.
        neg_name: Not used by this function; kept for call-site compatibility.
        endpoint_name: SageMaker endpoint used to embed every text.
        instruction: Instruction sent along with the query only; the positive
            and negative texts are embedded without an instruction.

    Returns:
        A tuple ``(pos_rank_list, pos_similarity_list, neg_similarity_list)``:
        per example, the number of negatives scoring strictly higher than the
        positive (0 means the positive ranks first) and the query/positive
        similarity; plus the flat list of all query/negative similarities.

    Note:
        Endpoint calls go through the module-level ``smr_client``.
    """
    pos_rank_list = []
    neg_similarity_list = []
    pos_similarity_list = []

    # Parse the whole file up front so the handle is closed before the (slow)
    # endpoint calls start. UTF-8 is explicit so non-ASCII text does not depend
    # on the platform default encoding; blank lines are skipped.
    with open(input_file, 'r', encoding='utf-8') as input_f:
        test_items = [json.loads(line) for line in input_f if line.strip()]

    for idx, test_item in enumerate(test_items):
        q = test_item['query']
        p = test_item['pos'][0]
        # De-duplicate while keeping file order, so results are reproducible
        # across runs (set() ordering of strings varies between processes).
        neg_p_list = list(dict.fromkeys(test_item['neg']))

        q_emb = get_vector_by_sm_endpoint(q, smr_client, endpoint_name, instruction)
        p_emb = get_vector_by_sm_endpoint(p, smr_client, endpoint_name)
        neg_p_embs = [get_vector_by_sm_endpoint(neg_p, smr_client, endpoint_name) for neg_p in neg_p_list]

        pos_sim = similarity_calc(q_emb, p_emb)
        pos_rank, neg_sim_vals = similarity_stat(q_emb, neg_p_embs, pos_sim)

        pos_similarity_list.append(pos_sim)
        pos_rank_list.append(pos_rank)
        neg_similarity_list.extend(neg_sim_vals)

        # Lightweight progress report every 100 examples.
        if idx % 100 == 0:
            print(f"{idx}-th : pos_sim: {pos_sim}, pos_rank:{pos_rank}")

    return pos_rank_list, pos_similarity_list, neg_similarity_list

In [59]:
import seaborn as sns
import matplotlib.pyplot as plt
from collections import Counter
import random

def plot_stat(pos_rank_list, pos_similarity_list, neg_similarity_list, plot_name):
    """Plot rank and similarity statistics side by side and save them as a PNG.

    The left panel is a pie chart of positive-example ranks (ranks 0-4 are
    labelled ``Top1``..``Top5``, anything larger ``other``). The right panel
    overlays histograms of the positive similarities (green) and a random
    sample of the negative similarities (red).

    Args:
        pos_rank_list: Zero-based rank of the positive example per query.
        pos_similarity_list: Query/positive similarity per query.
        neg_similarity_list: Query/negative similarities over all queries.
        plot_name: Output path without extension; ``.png`` is appended.
    """
    def gen_label_name(k):
        return "other" if k > 4 else f"Top{k+1}"

    counter = Counter(gen_label_name(item) for item in pos_rank_list)

    # Set the style before the axes are created so that the first figure
    # produced in a session is styled the same way as the following ones.
    sns.set_style("whitegrid")
    fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(12, 5))

    ax1.pie(list(counter.values()), labels=list(counter.keys()), autopct='%1.1f%%', startangle=140, labeldistance=1.1)
    ax1.set_title('PostiveLabel Rank')

    # Sample as many negatives as there are positives so both histograms are
    # comparable; cap at the number available so random.sample cannot raise.
    sample_size = min(len(pos_similarity_list), len(neg_similarity_list))
    sample_neg_similarity_list = random.sample(neg_similarity_list, sample_size)
    sns.histplot(pos_similarity_list, color='green', kde=True, ax=ax2)
    sns.histplot(sample_neg_similarity_list, color='red', kde=True, ax=ax2)
    ax2.set_title('Similarity Dist')
    ax2.set_xlabel('Sim_Value')
    ax2.set_ylabel('Frequency')

    # Lay out and save this figure explicitly instead of relying on pyplot's
    # "current figure" state.
    fig.tight_layout()
    fig.savefig(f"{plot_name}.png", dpi=100, bbox_inches='tight')

### Check Recall@N and similarity distribution

In [83]:
# Names of deployed endpoints, passed as `endpoint_name` to visual_preprocess.
# Only endpoint_name_bge15_ft is referenced by the evaluation cell visible in
# this notebook; the others are presumably kept for comparison runs.
endpoint_name_bge = 'bge-large-zh-2023-09-15-08-29-52-242-endpoint'
endpoint_name_st = 'finetuned-mpnet-2023-09-16-06-05-37-797-endpoint'
endpoint_name_bge15 = 'bge-zh-15-2023-09-17-01-00-27-086-endpoint'
endpoint_name_st_ft = 'finetuned-mpnet-bz9-2023-09-17-23-26-37-824-endpoint'
endpoint_name_bge15_ft = 'bge15-finetuned-2023-09-18-02-41-16-015-endpoint'

In [80]:
!mkdir bge_finetune15
!mkdir bge_finetune15/eval
!mkdir bge_finetune15/test

In [None]:
%time
pos_rank_list, pos_similarity_list, neg_similarity_list = visual_preprocess('chatgpt_synthesis/qq1_valid.jsonl', 'origin_q', 'generated_q','generated_neg',endpoint_name_bge15_ft)
plot_stat(pos_rank_list, pos_similarity_list, neg_similarity_list, 'bge_finetune15/eval/origin_q-generated_q-generated_neg-bge')

pos_rank_list, pos_similarity_list, neg_similarity_list = visual_preprocess('chatgpt_synthesis/qq2_valid.jsonl', 'origin_q', 'generated_q', 'origin_neg',endpoint_name_bge15_ft)
plot_stat(pos_rank_list, pos_similarity_list, neg_similarity_list, 'bge_finetune15/eval/origin_q-generated_q-origin_neg-bge')

pos_rank_list, pos_similarity_list, neg_similarity_list = visual_preprocess('chatgpt_synthesis/oqgd_valid.jsonl', 'origin_q', 'generated_d', 'generated_d',endpoint_name_bge15_ft)
plot_stat(pos_rank_list, pos_similarity_list, neg_similarity_list, 'bge_finetune15/eval/origin_q-generated_d-generated_d-bge')

pos_rank_list, pos_similarity_list, neg_similarity_list = visual_preprocess('chatgpt_synthesis/gqod_valid.jsonl', 'generated_q', 'origin_d', 'origin_d',endpoint_name_bge15_ft)
plot_stat(pos_rank_list, pos_similarity_list, neg_similarity_list, 'bge_finetune15/eval/generated_q-origin_d-origin_d-bge')

pos_rank_list, pos_similarity_list, neg_similarity_list = visual_preprocess('chatgpt_synthesis/qq1_test.jsonl', 'origin_q', 'generated_q','generated_neg',endpoint_name_bge15_ft)
plot_stat(pos_rank_list, pos_similarity_list, neg_similarity_list, 'bge_finetune15/test/origin_q-generated_q-generated_neg-bge')

pos_rank_list, pos_similarity_list, neg_similarity_list = visual_preprocess('chatgpt_synthesis/qq2_test.jsonl', 'origin_q', 'generated_q', 'origin_neg',endpoint_name_bge15_ft)
plot_stat(pos_rank_list, pos_similarity_list, neg_similarity_list, 'bge_finetune15/test/origin_q-generated_q-origin_neg-bge')

pos_rank_list, pos_similarity_list, neg_similarity_list = visual_preprocess('chatgpt_synthesis/oqgd_test.jsonl', 'origin_q', 'generated_d', 'generated_d',endpoint_name_bge15_ft)
plot_stat(pos_rank_list, pos_similarity_list, neg_similarity_list, 'bge_finetune15/test/origin_q-generated_d-generated_d-bge')

pos_rank_list, pos_similarity_list, neg_similarity_list = visual_preprocess('chatgpt_synthesis/gqod_test.jsonl', 'generated_q', 'origin_d', 'origin_d',endpoint_name_bge15_ft)
plot_stat(pos_rank_list, pos_similarity_list, neg_similarity_list, 'bge_finetune15/test/generated_q-origin_d-origin_d-bge')