In [1]:
import os
import sys

# Root of the RAG project checkout; it is put at the front of sys.path so the
# `src.*` imports in this cell resolve against it.
# Set the RAG_PROJECT_DIR environment variable to override it instead of editing
# this hard-coded local path.
BASE_DIR = os.environ.get(
    "RAG_PROJECT_DIR", "/home/dzigen/Desktop/ITMO/smiles2024/RAG-project-SMILES-2024-"
)
# Guard the insert so re-running this cell does not stack duplicate sys.path entries.
if BASE_DIR not in sys.path:
    sys.path.insert(0, BASE_DIR)

import pandas as pd 
import ast
import numpy as np
import json
from tqdm import tqdm
from IPython.display import clear_output
import os

from src.Retriever import ThresholdRetriever
from src.Reader import LLM_model
from src.utils import ReaderMetrics, RetrieverMetrics, save_rag_trial_log, prepare_retriever_configs, prepare_reader_configs, load_benchmarks_df
from src.utils import evaluate_retriever, evaluate_reader

In [2]:
# !!! TO CHANGE !!!
# ---- Per-trial experiment settings: edit these before each run ----
TRIAL = 1
BENCHMARKS_MAXSIZE = 1000
BENCHMARKS_INFO = dict(mtssquad=dict(db='v2', table='v1'))

# Retriever part: embedding model location plus kwargs forwarded to the retriever setup.
RETRIEVER_PARAMS = dict(
    model_path="/home/dzigen/Desktop/nlp_models/intfloat/multilingual-e5-small",
    # Key contains a colon, so it must stay a literal dict entry.
    densedb_kwargs=dict(metadata={"hnsw:space": "ip"}),
    model_kwargs=dict(device='cuda'),
    encode_kwargs=dict(normalize_embeddings=True, prompt='query: '),
    params=dict(fetch_k=50, threshold=0.2, max_k=3),
)

# Reader part: generation settings handed to prepare_reader_configs.
# NOTE(review): max_tokens=-1 looks like a "no limit" sentinel — confirm with the reader backend.
READER_PARAMS = dict(
    temperature=0.2,
    top_k=40,
    top_p=0.95,
    min_p=0.05,
    typical_p=1,
    max_tokens=-1,
)
# !!! TO CHANGE !!!

# Output locations for this trial's logs (relative to the notebook's working directory).
SAVE_LOGDIR = './logs/trial{}'.format(TRIAL)
SAVE_HYPERPARAMS, SAVE_READERCACHE, SAVE_RETRIEVERCACHE = (
    f'{SAVE_LOGDIR}/{stem}.json'
    for stem in ('hyperparams', 'reader_cache', 'retriever_cache')
)

##### Configure the retriever

In [None]:
# Build the retriever configuration and the benchmarks location from the project root,
# BENCHMARKS_INFO and RETRIEVER_PARAMS.
# NOTE(review): the exact structure of retrievers_config / benchmarks_path is defined in
# src.utils.prepare_retriever_configs — confirm there before relying on it.
retrievers_config, benchmarks_path = prepare_retriever_configs(BASE_DIR, BENCHMARKS_INFO, RETRIEVER_PARAMS)
# Metric helper later passed to evaluate_retriever.
retriever_metrics = RetrieverMetrics()

In [4]:
# Load the benchmark data used by both evaluation steps.
# NOTE(review): BENCHMARKS_MAXSIZE presumably caps the number of loaded samples — confirm in
# src.utils.load_benchmarks_df.
benchmarks_df = load_benchmarks_df(benchmarks_path, BENCHMARKS_MAXSIZE)

##### Configure the reader

In [6]:
# Derive two reader configs from READER_PARAMS; both are passed to LLM_model in the next cell.
# NOTE(review): hardw_c is presumably hardware settings and hyperp_c generation hyperparameters —
# confirm in src.utils.prepare_reader_configs.
hardw_c, hyperp_c = prepare_reader_configs(READER_PARAMS)
# Metric helper later passed to evaluate_reader; it is given the project root directory.
reader_metrics = ReaderMetrics(base_dir=BASE_DIR)

In [8]:
# Instantiate the LLM reader from the configs prepared above.
# NOTE(review): this likely loads model weights and may be slow / memory-heavy — confirm in src.Reader.
READER = LLM_model(hardw_c, hyperp_c)

##### Evaluate the RAG pipeline

In [None]:
# Evaluate the retriever on the benchmarks; the returned scores and cache are persisted by
# save_rag_trial_log in the final cell.
retriever_scores, retriever_cache = evaluate_retriever(benchmarks_df, retrievers_config, retriever_metrics)

In [None]:
# Evaluate the reader on the benchmarks; the returned scores and cache are persisted by
# save_rag_trial_log in the final cell.
# NOTE(review): the fourth argument is an empty list — confirm whether retrieved contexts
# (e.g. derived from retriever_cache) were meant to be passed here.
reader_scores, reader_cache = evaluate_reader(benchmarks_df, READER, reader_metrics, [])

In [None]:
# Persist this trial's scores, caches and full hyperparameter set under SAVE_LOGDIR
# (hyperparams, reader cache and retriever cache each go to their own file path).
save_rag_trial_log(SAVE_LOGDIR, reader_scores, retriever_scores, 
                   SAVE_HYPERPARAMS, SAVE_READERCACHE, SAVE_RETRIEVERCACHE,
                   reader_cache, retriever_cache, BENCHMARKS_INFO, BENCHMARKS_MAXSIZE,
                   READER_PARAMS, RETRIEVER_PARAMS)