In [4]:
import os
import sys

# Make the project root importable so that `from src....` works regardless of the
# kernel's working directory. sys.path entries must be directories; the root is
# the directory that contains this notebook (and the `src/` package).
NOTEBOOK_PATH = "/home/dzigen/Desktop/ITMO/ВКР/КМУ2024/inference.ipynb"
PROJECT_ROOT = os.path.dirname(NOTEBOOK_PATH)
# Guard keeps the cell idempotent: re-running it does not stack duplicate entries.
if PROJECT_ROOT not in sys.path:
    sys.path.insert(0, PROJECT_ROOT)

from src.retrievers.bm25colbert import BM25ColBertRetriever
from src.retrievers.bm25e5 import BM25E5Retriever
from src.readers.fid import FiDReader
from src.retrievers.e5 import E5Retriever

import torch
from tqdm import tqdm

In [None]:
# --- Paths -------------------------------------------------------------------
# Checkpoint of the tuned FiD reader, loaded by `reader.load_model` below.
TUNED_READER_PATH = '/home/dzigen/Desktop/ITMO/ВКР/КМУ2024/logs/reader_fid_squad/bestmodel.pt'
# Retrieval base file handed to each retriever's `load_base`
# (NOTE(review): .pkl suggests a pickle — only load trusted files).
BASE_PATH = '/home/dzigen/Desktop/ITMO/ВКР/КМУ2024/data/bases/scipdf_bm25_base.pkl'

# --- Reader settings -----------------------------------------------------------
# Prompt template applied to every (passage, question) pair: {c}=context, {q}=question.
READER_INPUT_FORMAT = "context: {c}\n\nquestion: {q}"
# Value passed as `max_length` to the reader's `generate` call.
READER_GEN_ML = 64

# --- Demo queries ----------------------------------------------------------------
# One question per method name, all sharing the same phrasing.
APPROACH_NAMES = (
    "RETRO", "kNN-LM", "DPR", "RAG", "FiD",
    "EMDR2", "Atlas", "REPLUG", "ColBERT",
)
QUESTIONS = [f"What is a {approach} approach" for approach in APPROACH_NAMES]

In [None]:
def inference(reader, retriever, queries, input_format=None, max_length=None, verbose=True):
    """Answer every query by retrieving passages and generating with the reader.

    For each query, ``retriever.search`` supplies context passages. Each passage is
    combined with the query through the prompt template, the list is tokenized with
    ``reader.tokenize``, reshaped into a single example with one row per passage,
    and decoded from ``reader.model.generate``.

    Args:
        reader: object exposing ``tokenize(list_of_str)`` (returning a mapping with
            ``'input_ids'`` and ``'attention_mask'`` tensors), a ``model`` with a
            ``generate`` method, and a ``tokenizer`` with ``eos_token_id`` and
            ``batch_decode``.
        retriever: object whose ``search(query)`` returns a 3-tuple
            ``(texts, scores, metadata)``; only ``texts`` is used here.
        queries: iterable of query strings.
        input_format: prompt template with ``{c}`` (context) and ``{q}`` (question)
            placeholders. ``None`` (default) uses ``READER_INPUT_FORMAT``, looked up
            at call time.
        max_length: ``max_length`` passed to ``generate``. ``None`` (default) uses
            ``READER_GEN_ML``, looked up at call time.
        verbose: if True (default), print each query, its contexts and its answer.

    Returns:
        list[str]: decoded answers in query order (every sequence returned by
        ``batch_decode`` for a query is appended). A query whose retrieval yields
        no passages contributes a single empty string instead of raising.
    """
    # Resolve defaults at call time so later edits to the config cell take effect.
    if input_format is None:
        input_format = READER_INPUT_FORMAT
    if max_length is None:
        max_length = READER_GEN_ML

    answers = []
    for query in tqdm(queries):
        if verbose:
            print("QUERY: ", query)
        texts, _scores, _metadata = retriever.search(query)

        if not texts:
            # With zero passages the input cannot be shaped as (1, 0, -1): the -1
            # dimension is ambiguous for an empty tensor and `view` raises. Record
            # an empty answer so `answers` stays aligned with `queries`.
            if verbose:
                print("CONTEXTS:\n", "<no passages retrieved>")
                print("ANSWER: ", "")
            answers.append("")
            continue

        if verbose:
            print("CONTEXTS:\n", '\n\n'.join(texts))

        formatted_txts = [input_format.format(q=query, c=text) for text in texts]
        tokenized_txts = reader.tokenize(formatted_txts)

        # Reshape per-passage token rows into one example of shape
        # (1, n_passages, seq_len). Assumes reader.tokenize pads all passages to a
        # common length (required for `view`) — TODO confirm in FiDReader.
        n_passages = len(texts)
        output = reader.model.generate(
            input_ids=tokenized_txts['input_ids'].view(1, n_passages, -1),
            attention_mask=tokenized_txts['attention_mask'].view(1, n_passages, -1),
            max_length=max_length,
            eos_token_id=reader.tokenizer.eos_token_id)

        predicted = reader.tokenizer.batch_decode(output, skip_special_tokens=True)
        answers.extend(predicted)

        if verbose:
            print("ANSWER: ", predicted[0])

    return answers

In [None]:
# Build the FiD reader and load the checkpoint at TUNED_READER_PATH (set in the
# configuration cell). Every retriever section below reuses this `reader`.
reader = FiDReader()
reader.load_model(TUNED_READER_PATH)

#### E5 + FiD

In [None]:
# NOTE(review): empty path — presumably makes load_model fall back to untuned
# weights; confirm against E5Retriever.load_model.
TUNED_RETRIEVER_PATH = ''

# Kept under its own name (not `colb_retriever`): the BM25ColBERT section below
# rebinds `colb_retriever`, which would discard this retriever before it is used.
e5_dense_retriever = E5Retriever()
e5_dense_retriever.load_model(TUNED_RETRIEVER_PATH)
e5_dense_retriever.load_base(BASE_PATH)

# Run the E5 + FiD evaluation for this section.
inference(reader, e5_dense_retriever, QUESTIONS)

#### BM25ColBERT + FiD

In [None]:
# NOTE(review): empty path — presumably makes load_model fall back to untuned
# weights; confirm against BM25ColBertRetriever.load_model.
TUNED_RETRIEVER_PATH = ''

# Build the BM25ColBertRetriever for this section: load model weights, then the
# retrieval base at BASE_PATH (same call order as the other retriever cells).
colb_retriever = BM25ColBertRetriever()
colb_retriever.load_model(TUNED_RETRIEVER_PATH)
colb_retriever.load_base(BASE_PATH)

In [None]:
inference(reader=reader, retriever=colb_retriever, queries=QUESTIONS)

#### BM25E5 + FiD

In [None]:
# NOTE(review): empty path — presumably makes load_model fall back to untuned
# weights; confirm against BM25E5Retriever.load_model.
TUNED_RETRIEVER_PATH = ''

# The BM25E5 + FiD section needs BM25E5Retriever; instantiating
# BM25ColBertRetriever here would just repeat the BM25ColBERT section.
# The name `e5_retriever` is what the following inference cell consumes.
e5_retriever = BM25E5Retriever()
e5_retriever.load_model(TUNED_RETRIEVER_PATH)
e5_retriever.load_base(BASE_PATH)

In [None]:
inference(reader=reader, retriever=e5_retriever, queries=QUESTIONS)