In [1]:
import ast
import copy
import json
import os
import sys

import numpy as np
import pandas as pd
from tqdm import tqdm

# The project root must be on sys.path before the project-local imports below.
sys.path.insert(0, "/home/dzigen/Desktop/ITMO/smiles2024/RAG-project-SMILES-2024-")

from src.Retriever import ThresholdRetriever, ThresholdRetrieverConfig
from src.utils import RetrieverMetrics

In [2]:
# !!! TO CHANGE !!!  -- per-trial settings start here

# JSON file the trial results are written to (the save cell refuses to
# overwrite an existing file, so pick a fresh name for every trial).
SAVE_LOGFILE = './logs/trial2.json'

# Benchmark name -> versions of the dense DB and of the benchmark table;
# both are interpolated into the data paths built in the next cell.
BENCHMARKS_INFO = {
    'mtssquad': {
        'db': 'v1',
        'table': 'v1',
    },
}

# Keyword arguments forwarded to ThresholdRetrieverConfig.
# Key order is kept as-is because this dict is dumped verbatim into the log.
CUSTOM_ARGS = {
    "model_path": "/home/dzigen/Desktop/nlp_models/intfloat/multilingual-e5-small",
    "densedb_kwargs": {
        'metadata': {"hnsw:space": "l2"},
    },
    "model_kwargs": {
        'device': 'cuda',
    },
    "encode_kwargs": {
        'normalize_embeddings': False,
        'prompt': 'query: ',
    },
    # 'max_k' is also reused below as the k of Recall@k / Precision@k / F1@k.
    "params": {
        'fetch_k': 50,
        'threshold': 8.0,
        'max_k': 3,
    },
}

# Only the first BENCHES_SIZE rows of every benchmark table are evaluated.
BENCHES_SIZE = 1000

# !!! TO CHANGE !!!  -- per-trial settings end here

In [3]:
# Resolve where each benchmark's table and dense DB live on disk.
# NOTE: the misspelled name `banchmark_paths` is kept on purpose because
# later cells read it.
banchmark_paths = {}
for name, version in BENCHMARKS_INFO.items():
    banchmark_paths[name] = {
        'table': f"../../data/{name}/tables/{version['table']}/benchmark.csv",
        'dense_db': f"../../data/{name}/dbs/{version['db']}/densedb"
    }

# Build one retriever config per benchmark.
benchmark_config = {}
for name, paths in banchmark_paths.items():
    # CUSTOM_ARGS is still updated in place, as before, so the dict that the
    # save cell dumps into the log keeps the dense-DB path and name.
    CUSTOM_ARGS['densedb_path'] = paths['dense_db']
    CUSTOM_ARGS['densedb_kwargs']['name'] = name

    # Hand every config its own deep copy of the arguments. Passing
    # CUSTOM_ARGS directly would let all configs share the same nested
    # 'densedb_kwargs' dict, so the 'name' written for a later benchmark
    # could silently replace the one seen by earlier configs.
    config = ThresholdRetrieverConfig(**copy.deepcopy(CUSTOM_ARGS))

    benchmark_config[name] = config

In [4]:
# Load the benchmark datasets, keeping only the first BENCHES_SIZE rows.
benchmarks_df = {}
for name, paths in banchmark_paths.items():
    # .copy() detaches the row slice from the frame returned by read_csv, so
    # the column assignments below write to an independent DataFrame instead
    # of to a slice (which can trigger pandas' SettingWithCopyWarning).
    bench_df = pd.read_csv(paths['table'], sep=';').iloc[:BENCHES_SIZE, :].copy()

    # These columns are stored in the CSV as Python-literal strings
    # (presumably lists -- confirm against the data); parse them back.
    for column in ('chunk_ids', 'contexts'):
        bench_df[column] = bench_df[column].map(ast.literal_eval)

    benchmarks_df[name] = bench_df

# Initialise the metrics helper.
metrics = RetrieverMetrics()

In [5]:
# Initialise the retrievers: one ThresholdRetriever per benchmark, each built
# from that benchmark's config.
retrievers = dict(
    (bench_name, ThresholdRetriever(bench_cfg))
    for bench_name, bench_cfg in benchmark_config.items()
)

No sentence-transformers model found with name /home/dzigen/Desktop/nlp_models/intfloat/multilingual-e5-small. Creating a new one with mean pooling.


In [6]:
def get_relevant_chunk_ids(df, retriever):
    """Run the retriever on every benchmark question and collect chunk ids.

    Args:
        df: DataFrame with a 'question' column; one query per row.
        retriever: object exposing ``invoke(question)``. Each returned item
            must hold, at position 2, a mapping with a 'chunk_id' key.

    Returns:
        A list with one entry per row of ``df`` (in row order); each entry is
        the list of 'chunk_id' values taken from the retriever output for
        that row's question.
    """
    # Iterate over the column values themselves rather than df['question'][i]:
    # the latter is a label lookup and breaks (KeyError / wrong row) whenever
    # the frame's index is not exactly 0..n-1, e.g. after filtering.
    questions = df['question'].tolist()

    relevant_chunk_ids = []
    for question in tqdm(questions, total=len(questions)):
        output = retriever.invoke(question)
        relevant_chunk_ids.append([item[2]['chunk_id'] for item in output])

    return relevant_chunk_ids


# Evaluate every benchmark: retrieve chunks for each question, then score the
# predictions against the golden chunk ids.
benchmarks_score = {}
for name, bench_df in benchmarks_df.items():
    print(name)

    # For every query, get the list of retrieved (predicted) chunk ids.
    predicted_ids = get_relevant_chunk_ids(bench_df, retrievers[name])
    golden_ids = bench_df['chunk_ids'].to_list()

    mrr_score = metrics.MRR(predicted_ids, golden_ids)
    print("MRR: ", mrr_score)

    # The per-query metrics below are all cut off at the retriever's max_k.
    top_k = benchmark_config[name].params['max_k']
    id_pairs = list(zip(predicted_ids, golden_ids))

    recall_values = [metrics.recall(pred, gold, k=top_k) for pred, gold in id_pairs]
    median_recall = np.median(recall_values)
    print("median Recall: ", median_recall)

    precision_values = [metrics.precision(pred, gold, k=top_k) for pred, gold in id_pairs]
    median_precision = np.median(precision_values)
    print("median Precision: ", median_precision)

    # f1_score may yield NaN (see the division warning in the cell output);
    # such values are counted as 0 before taking the median.
    raw_f1_values = [metrics.f1_score(pred, gold, k=top_k) for pred, gold in id_pairs]
    f1_values = [0 if np.isnan(value) else value for value in raw_f1_values]
    median_f1 = np.median(f1_values)
    print("median F1: ", median_f1)

    # Collect the metrics for this benchmark.
    benchmarks_score[name] = {
        'MRR': mrr_score,
        f'Recall@{top_k}': median_recall,
        f'Precision@{top_k}': median_precision,
        f'F1@{top_k}': median_f1,
    }

mtssquad


100%|██████████| 1000/1000 [00:11<00:00, 85.22it/s]


MRR:  0.8014998900000001
median Recall:  1.0
median Precision:  0.33333
median F1:  0.499996249990625


  return (2 * self.precision(pred_cands, gold_cands, k) * self.recall(pred_cands, gold_cands, k)) / (self.precision(pred_cands, gold_cands, k) + self.recall(pred_cands, gold_cands, k))


In [7]:
# Save the results of this trial.

# Never overwrite an existing trial log: abort instead.
if os.path.exists(SAVE_LOGFILE):
    print("Файл существует!")
    raise ValueError(f"log file already exists: {SAVE_LOGFILE}")

log_data = {'info': BENCHMARKS_INFO,
            'hyperp': {'args': CUSTOM_ARGS, 'benchmark_sizes': BENCHES_SIZE},
            'scores': benchmarks_score}

# Serialise BEFORE opening the file: if json.dumps fails, no empty log file is
# left behind to trip the "already exists" guard on the next run.
serialized_log = json.dumps(log_data, indent=1)

# Create the log directory on first use so the write cannot fail just because
# the directory is missing.
log_dir = os.path.dirname(SAVE_LOGFILE)
if log_dir:
    os.makedirs(log_dir, exist_ok=True)

with open(SAVE_LOGFILE, 'w', encoding='utf-8') as fd:
    fd.write(serialized_log)