In [1]:
import sys
sys.path.insert(0, "/home/aisummer/mikhail_workspace/nlp-service")

import pandas as pd 
import ast
import numpy as np
import json
import os

from src.DocumentsParser.utils import DBS_DIR_DENSE_VECTORDB_NAME, DBS_DIR_SPARSE_VECTORDB_NAME
from src.DocumentsRetriever.Retriever import RetrieverModule, RetrieverConfig
from src.evaluation_metrics import RetrievalMetrics
from src.utils import DialogueState
from src.logger import Logger

In [None]:
# !!! TO CHANGE !!!
# Destination of the JSON log written by the last cell of the notebook.
SAVE_LOGFILE = './logs/trial1.json'

# Benchmark name -> dataset version; both are used to build the data paths.
BENCHMARKS_INFO = {
    'sberquad': 'v1',
    'squadv2': 'v1',
}

# Location of the embedding model handed to RetrieverConfig as `model_path`.
MODEL_PATH = "/home/aisummer/mikhail_workspace/nlp-service/models/intfloat/multilingual-e5-small"

# Handed to RetrieverConfig as `encode_kwargs`.
ENCODE_KWARGS = {
    'normalize_embeddings': False,
    'prompt': 'query: ',
}

# Handed to RetrieverConfig as `params`; one entry per search strategy.
PARAMS = {
    'similarity': {'k': 4},
    'mmr': {'lambda_mult': 0.5, 'fetch_k': 20, 'k': 4},
    'bm25': {'k': 4},
}
# !!! TO CHANGE !!!

In [2]:
# For every benchmark, resolve the CSV table and the dense / sparse
# vector-store directories from its name and dataset version.
banchmark_paths = {
    name: {
        'table': f"../../data/{name}/{version}/benchmark.csv",
        'dense_db': f"../../data/{name}/dbs/{version}/{DBS_DIR_DENSE_VECTORDB_NAME}",
        'sparse_db':  f"../../data/{name}/dbs/{version}/{DBS_DIR_SPARSE_VECTORDB_NAME}",
    }
    for name, version in BENCHMARKS_INFO.items()
}

# One RetrieverConfig per benchmark: only the two store paths differ, while the
# model, encode kwargs and search params are shared by all benchmarks.
benchmark_config = {
    bench_name: RetrieverConfig(
        model_path=MODEL_PATH,
        densedb_path=bench_paths['dense_db'],
        sparsedb_path=bench_paths['sparse_db'],
        encode_kwargs=ENCODE_KWARGS,
        params=PARAMS,
    )
    for bench_name, bench_paths in banchmark_paths.items()
}

In [None]:
# Load every benchmark table into a DataFrame keyed by benchmark name.
benchmarks_df = {}
for name, bench_path in banchmark_paths.items():
    # `bench_path` is already this benchmark's path dict, so it is indexed by
    # 'table' directly; the frame is stored by key because `benchmarks_df` is
    # a dict and has no `.append`.
    bench_table = pd.read_csv(bench_path['table'])
    # The CSV stores 'chunk_ids' as text; parse each cell back into a Python
    # literal (presumably a list of ids — confirm against the benchmark files).
    bench_table['chunk_ids'] = bench_table['chunk_ids'].map(ast.literal_eval)
    benchmarks_df[name] = bench_table

# Initialise the metrics helper.
metrics = RetrievalMetrics()

# logging
logger = Logger(True)
log = logger.get_logger(__name__)

In [None]:
# Initialise one retriever per benchmark from its config; all of them share
# the same logger instance.
retrievers = {}
for bench_name, bench_config in benchmark_config.items():
    retrievers[bench_name] = RetrieverModule(bench_config, log)

In [None]:
def get_relevant_chunk_ids(df, retriever):
    """Run the retriever on every question of a benchmark table.

    Args:
        df: DataFrame with a 'question' column; one row per query.
        retriever: object exposing ``base_search(state)``, which is expected
            to fill ``state.base_relevant_docs``.

    Returns:
        A list with one entry per row of ``df`` (in row order); each entry is
        the list of ``metadata['chunk_id']`` values of the retrieved documents.
    """
    relevant_chunk_ids = []
    # Iterate over the column values instead of `df['question'][i]`: the latter
    # is a label lookup and fails when the frame's index is not 0..n-1
    # (e.g. after filtering or sampling).
    for question in df['question']:
        state = DialogueState(query=question)
        retriever.base_search(state)
        relevant_chunk_ids.append([item.metadata['chunk_id'] for item in state.base_relevant_docs])

    return relevant_chunk_ids


# Per-benchmark results, in the iteration order of `benchmarks_df`.
bench_relevant_chunk_ids = []
benchmarks_score = []
for name, bench_df in benchmarks_df.items():
    # Retrieve the chunk ids for every query of THIS benchmark.
    relevant_chunk_ids = get_relevant_chunk_ids(bench_df, retrievers[name])
    bench_relevant_chunk_ids.append(relevant_chunk_ids)
    bench_golden_chunk_ids = bench_df['chunk_ids']

    # Pair each query's retrieved ids with its golden ids. Zipping with the
    # DataFrame itself would iterate over column names, not over rows.
    query_pairs = list(zip(relevant_chunk_ids, bench_golden_chunk_ids))

    # Compute the metrics on the current benchmark only; the accumulated
    # `bench_relevant_chunk_ids` holds one list per benchmark and must not be
    # compared with a single benchmark's golden ids.
    score = {
        'MRR': metrics.MRR(relevant_chunk_ids, bench_golden_chunk_ids),
        'mAP': metrics.mAP(relevant_chunk_ids, bench_golden_chunk_ids),
        'Recall': np.mean([metrics.recall(relevant_ids, golden_ids)
                           for relevant_ids, golden_ids in query_pairs]),
        'Precision': np.mean([metrics.precision(relevant_ids, golden_ids)
                              for relevant_ids, golden_ids in query_pairs])
    }

    benchmarks_score.append(score)

In [None]:
# Save the run configuration together with the computed scores.
if os.path.exists(SAVE_LOGFILE):
    # Refuse to overwrite the log of an earlier trial.
    print("Файл существует!")
    raise ValueError(f"Log file already exists: {SAVE_LOGFILE}")

# `benchmarks_score` is built in the iteration order of the benchmarks, which
# derives from BENCHMARKS_INFO, so zipping the two labels each score with its
# benchmark name.
log_data = {'info': BENCHMARKS_INFO,  'model_path': MODEL_PATH,
            'encode_kwargs': ENCODE_KWARGS, 'params': PARAMS,
            'scores': dict(zip(BENCHMARKS_INFO, benchmarks_score))}

# Create the target directory on first use; skip when the path has no
# directory component (os.makedirs('') would raise).
log_dir = os.path.dirname(SAVE_LOGFILE)
if log_dir:
    os.makedirs(log_dir, exist_ok=True)

with open(SAVE_LOGFILE, 'w', encoding='utf-8') as fd:
    # `default=float` converts numpy scalar metrics that json cannot encode natively.
    fd.write(json.dumps(log_data, indent=1, default=float))