In [1]:
import sys
import random
import pandas as pd
import ast
from tqdm import tqdm
import json

#BASE_DIR = "/home/dzigen/Desktop/ITMO/smiles2024/RAG-project-SMILES-2024-"
BASE_DIR = "/trinity/home/team06/workspace/mikhail_workspace/rag_project"
RANDOM_SEED = 42

sys.path.insert(0, BASE_DIR)
random.seed(RANDOM_SEED)

from src.Reader import LLM_Model
from src.utils import ReaderMetrics, save_reader_trial_log, prepare_reader_configs, load_benchmarks_df
from src.utils import evaluate_reader


  from .autonotebook import tqdm as notebook_tqdm


In [2]:
import torch
import gc

torch.cuda.empty_cache()
gc.collect()

124

In [3]:
# !!! TO CHANGE !!!
TRIAL = 2
BENCHMARKS_MAXSIZE = 500
BENCHMARKS_INFO = {'mtssquad': {'table': 'v1'}}

READER_PARAMS = {
    'prompts': {
        "assistant": "Отвечай на вопросы, используя информацию из текстов в списке ниже:",
        "system": "Ты вопросно-ответная система. Все ответы генерируй на русском языке. По вопросам отвечай кратко, чётко и конкретно. Не генерируй излишнюю информацию.",
    },
    'gen': {'max_new_tokens': 512, 'eos_token_id': 79097},
    'data_operate': {'batch_size': 1, 'num_workers':8}
    }

ADDITIONAL_PARAMS = {
    'unrel_c_mltp': 4
}

BERTSCORE_MODEL_PATH = "ru_electra_medium"
# !!! TO CHANGE !!!

SAVE_LOGDIR = f'./logs/trial{TRIAL}'
SAVE_HYPERPARAMS = f'{SAVE_LOGDIR}/hyperparams.json'
SAVE_READERCACHE = f'{SAVE_LOGDIR}/reader_cache.json'
SAVE_RETRIEVERCACHE = f'{SAVE_LOGDIR}/retriever_cache.json'

In [4]:
banchmarks_path = {}
for name, version in BENCHMARKS_INFO.items():
    banchmarks_path[name] = {
        'table': f"{BASE_DIR}/data/{name}/tables/{version['table']}/benchmark.csv",
        'chunked_docs': f"{BASE_DIR}/data/{name}/tables/{version['table']}/chunked_docs.csv"
    }

In [5]:
benchmarks_df = load_benchmarks_df(banchmarks_path, BENCHMARKS_MAXSIZE)

In [6]:
reader_config = prepare_reader_configs(READER_PARAMS)
reader_metrics = ReaderMetrics(base_dir=BASE_DIR, model_path=BERTSCORE_MODEL_PATH)

Loading Meteor...
Loading ExactMatch


In [7]:
READER = LLM_Model(reader_config)

Loading checkpoint shards: 100%|██████████| 4/4 [03:18<00:00, 49.74s/it]


In [8]:
READER.config.data_operate.batch_size = 1

In [9]:
# prepare raw unrelevant contexts
raw_contexts = {}
retriever_cache = {}
for name in banchmarks_path.keys():
    print(name)
    cur_chunked_df = pd.read_csv(banchmarks_path[name]['chunked_docs'], sep=';')
    cur_chunked_df['metadata'] = cur_chunked_df['metadata'].map(lambda v: ast.literal_eval(v))
    chunks_dict = {cur_chunked_df['metadata'][i]['chunk_id']: cur_chunked_df['chunks'][i] for i in range(cur_chunked_df.shape[0])}
    all_chunk_ids = set(chunks_dict.keys())

    raw_contexts[name] = []
    retriever_cache[name] = []
    for i in tqdm(range(benchmarks_df[name].shape[0])):
        cur_relevant_chunk_ids = set(benchmarks_df[name]['chunk_ids'][i])
        unrelevant_chunk_ids = all_chunk_ids.difference(cur_relevant_chunk_ids)
        
        selected_unrelevant_chunk_ids = random.sample(unrelevant_chunk_ids, ADDITIONAL_PARAMS['unrel_c_mltp'] * len(cur_relevant_chunk_ids))
        raw_contexts[name].append([chunks_dict[idx] for idx in selected_unrelevant_chunk_ids])
        retriever_cache[name].append(selected_unrelevant_chunk_ids)

mtssquad


since Python 3.9 and will be removed in a subsequent version.
  selected_unrelevant_chunk_ids = random.sample(unrelevant_chunk_ids, ADDITIONAL_PARAMS['unrel_c_mltp'] * len(cur_relevant_chunk_ids))
100%|██████████| 500/500 [00:00<00:00, 1968.49it/s]


In [10]:
contexts = {name: [reader_config.prompts.assistant + "\n\n" + "\n\n".join([f'{i+1}. {doc}' for i, doc in enumerate(docs)]) 
                   for docs in contexts] for name, contexts in raw_contexts.items()}

In [11]:
reader_scores, reader_cache = evaluate_reader(benchmarks_df, READER, reader_metrics, contexts)

  return _methods._mean(a, axis=axis, dtype=dtype,
  ret = ret.dtype.type(ret / rcount)
100%|██████████| 500/500 [44:51<00:00,  5.38s/it, BLEU2=0.175, BLEU1=0.239, ExactMatch=0.002, METEOR=0.372, BertScore=nan]    


In [12]:
save_reader_trial_log(SAVE_LOGDIR, reader_scores, SAVE_HYPERPARAMS, SAVE_READERCACHE, 
                      reader_cache, BENCHMARKS_INFO, BENCHMARKS_MAXSIZE, READER_PARAMS, ADDITIONAL_PARAMS)

In [13]:
with open(SAVE_RETRIEVERCACHE, 'w', encoding='utf-8') as fd:
    fd.write(json.dumps(retriever_cache, indent=1, ensure_ascii=False))  