In [16]:
import sys
sys.path.insert(0, "/home/aisummer/mikhail_workspace/nlp-service")

import pandas as pd 
import ast
import numpy as np
import json
from tqdm import tqdm
import os

from src.DocumentsParser.utils import DBS_DIR_DENSE_VECTORDB_NAME, DBS_DIR_SPARSE_VECTORDB_NAME
from src.DocumentsRetriever.Retriever import RetrieverModule, RetrieverConfig
from src.evaluation_metrics import RetrievalMetrics
from src.utils import DialogueState
from src.logger import Logger

In [114]:
# ---------------- experiment configuration: review/edit before every run ----------------
# Output JSON for this trial; the save cell refuses to run if this file already exists.
SAVE_LOGFILE = './logs/trial21.json'

# benchmark name -> data versions: 'db' selects ../../data/<name>/dbs/<db>,
# 'table' selects ../../data/<name>/tables/<table>/benchmark.csv
BENCHMARKS_INFO = {
    'sberquad': {'db': 'v2', 'table': 'v1'},
    'squadv2': {'db': 'v2', 'table': 'v1'},
}

# Embedding model directory and the encode kwargs forwarded to RetrieverConfig.
MODEL_PATH = "/home/aisummer/mikhail_workspace/nlp-service/models/intfloat/multilingual-e5-small"
ENCODE_KWARGS = {
    'normalize_embeddings': True,
    'prompt': 'query: ',
}

# Per-search-strategy parameters forwarded to RetrieverConfig(params=...).
PARAMS = {
    'similarity': {'k': 3},
    'mmr': {'lambda_mult': 0.5, 'fetch_k': 20, 'k': 3},
    'bm25': {'k': 3},
}
# Forwarded to RetrieverConfig(weights=...).
# NOTE(review): presumably one weight per PARAMS strategy, in the same order — confirm in RetrieverModule.
WEIGHTS = [0.5, 0.1, 0.4]

# Only the first BENCHES_SIZE rows of every benchmark table are evaluated.
BENCHES_SIZE = 1000
# ----------------------------------------------------------------------------------------

In [115]:
# Resolve, for every benchmark, the CSV table and the dense/sparse vector-DB
# directories selected by the versions in BENCHMARKS_INFO.
# (The global keeps its historical spelling `banchmark_paths`; later cells read it.)
banchmark_paths = {}
for name, version in BENCHMARKS_INFO.items():
    data_root = f"../../data/{name}"
    db_root = f"{data_root}/dbs/{version['db']}"
    banchmark_paths[name] = {
        'table': f"{data_root}/tables/{version['table']}/benchmark.csv",
        'dense_db': f"{db_root}/{DBS_DIR_DENSE_VECTORDB_NAME}",
        'sparse_db': f"{db_root}/{DBS_DIR_SPARSE_VECTORDB_NAME}",
    }

# One RetrieverConfig per benchmark: the model and hyperparameters are shared,
# only the vector-DB locations differ.
benchmark_config = {
    bench_name: RetrieverConfig(
        model_path=MODEL_PATH,
        densedb_path=bench_paths['dense_db'],
        sparsedb_path=bench_paths['sparse_db'],
        encode_kwargs=ENCODE_KWARGS,
        params=PARAMS,
        weights=WEIGHTS,
    )
    for bench_name, bench_paths in banchmark_paths.items()
}

In [116]:
# Load every benchmark table, keeping only its first BENCHES_SIZE rows.
# 'chunk_ids' and 'contexts' are stored in the CSV as Python-literal lists serialized
# as text, so they are parsed back with ast.literal_eval (literals only, no code execution).
# `.assign` returns a fresh frame, avoiding column assignment into an `.iloc` slice
# (the SettingWithCopyWarning pattern).
benchmarks_df = {}
for name, paths in banchmark_paths.items():
    benchmarks_df[name] = (
        pd.read_csv(paths['table'], sep=';')
        .iloc[:BENCHES_SIZE, :]
        .assign(
            chunk_ids=lambda frame: frame['chunk_ids'].map(ast.literal_eval),
            contexts=lambda frame: frame['contexts'].map(ast.literal_eval),
        )
    )

# initialize the metrics helper
metrics = RetrievalMetrics()

# logging
# NOTE(review): the meaning of the positional `False` flag is defined in src.logger — confirm.
logger = Logger(False)
log = logger.get_logger(__name__)

In [117]:
# Build one RetrieverModule per benchmark from its RetrieverConfig; all of them
# share the notebook logger `log`.
retrievers = {
    bench_name: RetrieverModule(retriever_config, log)
    for bench_name, retriever_config in benchmark_config.items()
}

No sentence-transformers model found with name /home/aisummer/mikhail_workspace/nlp-service/models/intfloat/multilingual-e5-small. Creating a new one with mean pooling.
No sentence-transformers model found with name /home/aisummer/mikhail_workspace/nlp-service/models/intfloat/multilingual-e5-small. Creating a new one with mean pooling.


In [118]:
def get_relevant_chunk_ids(df, retriever):
    """Run the retriever on every question of a benchmark frame.

    Args:
        df: benchmark DataFrame with a 'question' column.
        retriever: object exposing ``base_search(state)``; after the call the
            retrieved documents are read from ``state.base_relevant_docs``.

    Returns:
        list[list]: one list per question, in row order, holding
        ``metadata['chunk_id']`` of each retrieved document in the order the
        retriever returned them.
    """
    relevant_chunk_ids = []
    # Iterate the column values positionally. `df['question'][i]` would be a
    # *label* lookup and silently breaks for an index other than 0..n-1.
    for question in tqdm(df['question'], total=len(df)):
        state = DialogueState(query=question)
        retriever.base_search(state)
        relevant_chunk_ids.append([doc.metadata['chunk_id'] for doc in state.base_relevant_docs])

    return relevant_chunk_ids


# Cut-off rank for the per-query Recall@k / Precision@k / F1@k metrics.
# NOTE(review): independent of the retriever's own PARAMS[...]['k']; keep in sync by hand if intended.
EVAL_K = 3


def _per_query(metric_fn, pred_lists, gold_lists, k):
    """Apply ``metric_fn(pred_ids, gold_ids, k=k)`` to each aligned (predicted, golden) pair."""
    return [metric_fn(pred_ids, gold_ids, k=k) for pred_ids, gold_ids in zip(pred_lists, gold_lists)]


benchmarks_score = {}
for name, bench_df in benchmarks_df.items():
    print(name)
    # for every query, retrieve the list of relevant chunk ids; compare to the golden ids
    bench_pred_chunk_ids = get_relevant_chunk_ids(bench_df, retrievers[name])
    bench_golden_chunk_ids = bench_df['chunk_ids'].to_list()

    mrr_score = metrics.MRR(bench_pred_chunk_ids, bench_golden_chunk_ids)
    print("MRR: ", mrr_score)

    recall = _per_query(metrics.recall, bench_pred_chunk_ids, bench_golden_chunk_ids, EVAL_K)
    median_recall = np.median(recall)
    print("median Recall: ", median_recall)

    precision = _per_query(metrics.precision, bench_pred_chunk_ids, bench_golden_chunk_ids, EVAL_K)
    median_precision = np.median(precision)
    print("median Precision: ", median_precision)

    # f1_score divides by (precision + recall), so when both are 0 it yields NaN
    # (0/0, with a RuntimeWarning); such queries are counted as F1 = 0.
    f1 = [0 if np.isnan(v) else v
          for v in _per_query(metrics.f1_score, bench_pred_chunk_ids, bench_golden_chunk_ids, EVAL_K)]
    median_f1 = np.median(f1)
    print("median F1: ", median_f1)

    # store the aggregated scores (each median computed once, reused for print and log)
    benchmarks_score[name] = {
        'MRR': mrr_score,
        'Recall': median_recall,
        'Precision': median_precision,
        'F1': median_f1,
    }

sberquad


100%|██████████| 1000/1000 [00:27<00:00, 36.96it/s]
  return (2 * self.precision(pred_cands, gold_cands, k) * self.recall(pred_cands, gold_cands, k)) / (self.precision(pred_cands, gold_cands, k) + self.recall(pred_cands, gold_cands, k))


MRR:  0.83216659
median Recall:  1.0
median Precision:  0.33333
median F1:  0.499996249990625
squadv2


100%|██████████| 1000/1000 [00:40<00:00, 24.84it/s]

MRR:  0.71383315
median Recall:  1.0
median Precision:  0.33333
median F1:  0.499996249990625



  return (2 * self.precision(pred_cands, gold_cands, k) * self.recall(pred_cands, gold_cands, k)) / (self.precision(pred_cands, gold_cands, k) + self.recall(pred_cands, gold_cands, k))


In [119]:
# Save this trial's configuration and scores as JSON; never overwrite an earlier trial.
if os.path.exists(SAVE_LOGFILE):
    print("Файл существует!")  # "File exists!"
    raise FileExistsError(f"Log file already exists, refusing to overwrite: {SAVE_LOGFILE}")

log_data = {'info': BENCHMARKS_INFO,
            'hyperp': {
                'model_path': MODEL_PATH, 'encode_kwargs': ENCODE_KWARGS,
                'params': PARAMS, 'weights': WEIGHTS, 'benchmark_sizes': BENCHES_SIZE},
            'scores': benchmarks_score}

# Create the target directory (e.g. './logs') if it does not exist yet.
log_dir = os.path.dirname(SAVE_LOGFILE)
if log_dir:
    os.makedirs(log_dir, exist_ok=True)

# Mode 'x' fails atomically with FileExistsError if the file appeared after the check above.
with open(SAVE_LOGFILE, 'x', encoding='utf-8') as fd:
    # default=float: serialize numeric scalars json cannot handle natively
    # (e.g. numpy float32 / int64, in case a metric helper returns one).
    json.dump(log_data, fd, indent=1, default=float)