In [9]:
""" compare 3 Embedding models 
1. base
2. epoch 2
3. epoch 25
"""

import json
from tqdm.notebook import tqdm
import pandas as pd
import os

In [7]:
# Load the train/validation splits from JSON.
# Files are opened with 'r' (read-only): they are never written here, so write
# access ('r+') is not requested and read-only files/mounts still work.
train_dataset_path = 'da_dataset/da_train_dataset.json'
val_dataset_path = 'da_dataset/da_val_dataset.json'


def load_json(path):
    """Read the UTF-8 encoded JSON file at `path` and return the parsed object."""
    with open(path, 'r', encoding='utf-8') as f:
        return json.load(f)


train_dataset = load_json(train_dataset_path)
# val_dataset is passed to evaluate_st, which reads its 'corpus', 'queries'
# and 'relevant_docs' entries.
val_dataset = load_json(val_dataset_path)

In [10]:
from sentence_transformers.evaluation import InformationRetrievalEvaluator
from sentence_transformers import SentenceTransformer

def evaluate_st(
    dataset,
    model_id,
    name,
    output_path='results/',
):
    """Evaluate a SentenceTransformer model with an information-retrieval evaluator.

    Args:
        dataset: mapping with 'corpus', 'queries' and 'relevant_docs' entries,
            passed unchanged to InformationRetrievalEvaluator.
        model_id: model path/identifier passed to SentenceTransformer(...).
        name: evaluator name; it appears in the results CSV filename
            (e.g. Information-Retrieval_evaluation_<name>_results.csv).
        output_path: directory the evaluator writes its results CSV into.
            It is created (including parents) if missing. Defaults to
            'results/', which is where the result CSVs are read back from.

    Returns:
        Whatever the evaluator call returns (a single float score in this
        notebook's outputs).
    """
    corpus = dataset['corpus']
    queries = dataset['queries']
    relevant_docs = dataset['relevant_docs']

    evaluator = InformationRetrievalEvaluator(queries, corpus, relevant_docs, name=name)
    model = SentenceTransformer(model_id)

    # Create the directory the evaluator actually writes to. makedirs handles
    # nested paths and, with exist_ok, is safe on re-runs without a separate
    # existence check.
    os.makedirs(output_path, exist_ok=True)

    return evaluator(model, output_path=output_path)

In [11]:
# Local directories of the three models being compared.
kosim, epoch_2, epoch_25 = (
    "./BM-K/KoSimCSE-roberta-multitask/",  # base model
    "./da_finetune_epoch_2/",              # fine-tuned checkpoint, epoch 2
    "./da_finetune_epoch_25/",             # fine-tuned checkpoint, epoch 25
)

In [12]:
# Baseline: the base KoSimCSE model.
evaluate_st(dataset=val_dataset, model_id=kosim, name="KoSimCSE")

0.5660207518685775

In [13]:
# Fine-tuned checkpoint after epoch 2.
evaluate_st(dataset=val_dataset, model_id=epoch_2, name="Epoch2")

0.8981735343096556

In [14]:
# Fine-tuned checkpoint after epoch 25.
evaluate_st(dataset=val_dataset, model_id=epoch_25, name="Epoch25")

0.9023171302459595

In [15]:
# Load the per-model IR metric CSVs written by the evaluator into results/.
# The evaluator appends a new row when its CSV already exists, so re-running
# the notebook would stack stale rows; keep only the most recent row per model.
# NOTE(review): append-on-existing-file behaviour is that of
# sentence-transformers' InformationRetrievalEvaluator — confirm for the installed version.
def read_latest_ir_results(csv_path):
    """Return the last (most recently appended) row of a results CSV as a one-row DataFrame."""
    return pd.read_csv(csv_path).tail(1).reset_index(drop=True)


df_st_kosim = read_latest_ir_results("results/Information-Retrieval_evaluation_KoSimCSE_results.csv")
df_st_epoch2 = read_latest_ir_results("results/Information-Retrieval_evaluation_Epoch2_results.csv")
df_st_epoch25 = read_latest_ir_results("results/Information-Retrieval_evaluation_Epoch25_results.csv")

In [16]:
# Tag each model's metric frame with its label, then stack them into one
# table indexed by model for side-by-side comparison.
for _frame, _label in (
    (df_st_kosim, 'KoSimCSE'),
    (df_st_epoch2, 'epoch2'),
    (df_st_epoch25, 'epoch25'),
):
    _frame['model'] = _label

df_st_all = (
    pd.concat([df_st_kosim, df_st_epoch2, df_st_epoch25])
    .set_index('model')
)
df_st_all

Unnamed: 0_level_0,epoch,steps,cos_sim-Accuracy@1,cos_sim-Accuracy@3,cos_sim-Accuracy@5,cos_sim-Accuracy@10,cos_sim-Precision@1,cos_sim-Recall@1,cos_sim-Precision@3,cos_sim-Recall@3,...,dot_score-Recall@1,dot_score-Precision@3,dot_score-Recall@3,dot_score-Precision@5,dot_score-Recall@5,dot_score-Precision@10,dot_score-Recall@10,dot_score-MRR@10,dot_score-NDCG@10,dot_score-MAP@100
model,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1,Unnamed: 12_level_1,Unnamed: 13_level_1,Unnamed: 14_level_1,Unnamed: 15_level_1,Unnamed: 16_level_1,Unnamed: 17_level_1,Unnamed: 18_level_1,Unnamed: 19_level_1,Unnamed: 20_level_1,Unnamed: 21_level_1
KoSimCSE,-1,-1,0.456543,0.634193,0.694365,0.775549,0.456543,0.456543,0.211398,0.634193,...,0.409742,0.197389,0.592168,0.132951,0.664756,0.075263,0.752627,0.517378,0.573762,0.526105
epoch2,-1,-1,0.843362,0.948424,0.964661,0.978032,0.843362,0.843362,0.316141,0.948424,...,0.818529,0.31455,0.943649,0.192741,0.963706,0.097899,0.978988,0.883868,0.907693,0.884566
epoch25,-1,-1,0.859599,0.936963,0.95702,0.974212,0.859599,0.859599,0.312321,0.936963,...,0.851958,0.312321,0.936963,0.191595,0.957975,0.097421,0.974212,0.897436,0.916352,0.898317
