<a href="https://colab.research.google.com/github/BottleMin/Paper_Implement/blob/main/RAG/evaluate_rag.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
# Mount Google Drive at /content/drive; the project directory and model
# checkpoint paths used by later cells live under it.
import google.colab.drive as drive

drive.mount('/content/drive')

In [None]:
import locale
def getpreferredencoding(do_setlocale=True):
    """Report "UTF-8" as the preferred encoding, unconditionally.

    Installed below as a replacement for ``locale.getpreferredencoding``.
    ``do_setlocale`` is accepted only to keep the stdlib signature and is
    ignored.
    """
    # NOTE(review): presumably a workaround for a Colab runtime reporting a
    # non-UTF-8 locale — the original cell does not say why; confirm.
    return "UTF-8"


locale.getpreferredencoding = getpreferredencoding

%cd /content/drive/MyDrive/rag_project

!pip install faiss-cpu
!pip install faiss-gpu
!pip install transformers datasets


In [None]:
import argparse
import logging
import os
import faiss

import pandas as pd
import torch
from tqdm import tqdm
from transformers import RagTokenizer, RagRetriever, RagSequenceForGeneration
from transformers import logging as transformers_logging
from datasets import Dataset, load_from_disk

# Root logging at INFO so this notebook's logger.info(...) calls are shown.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
# Raise the transformers library's own log verbosity to INFO as well.
transformers_logging.set_verbosity_info()

# 데이터 정규화 및 평가 함수 정의

- `normalize_answer(s)`: 문자열을 정규화하는 함수. 소문자 변환, 구두점 제거, 관사(a, an, the) 제거, 불필요한 공백 제거를 순서대로 수행.

- `f1_score(prediction, ground_truth)`: 정규화된 예측 답변과 실제 답변 사이의 토큰 단위 F1 점수를 계산하는 함수.

- `exact_match_score(prediction, ground_truth)`: 정규화된 예측 답변과 실제 답변이 정확히 일치하는지 확인하는 함수.

In [None]:
import string
import re
from collections import Counter
from typing import Callable, Dict, Iterable, List

def normalize_answer(s):
    """Normalize an answer string before comparing it with another answer.

    Steps, applied in this order:
      1. lowercase;
      2. delete every character in ``string.punctuation`` (ASCII punctuation);
      3. replace the whole words "a", "an" and "the" with a space;
      4. collapse all whitespace runs to single spaces and trim the ends.
    """
    text = s.lower()
    text = text.translate(str.maketrans("", "", string.punctuation))
    text = re.sub(r"\b(a|an|the)\b", " ", text)
    return " ".join(text.split())


def f1_score(prediction, ground_truth):
    """Return the token-level F1 score between *prediction* and *ground_truth*.

    Both strings are normalized with ``normalize_answer`` and split on
    whitespace. Overlap is a multiset intersection, so a repeated token only
    counts as many times as it appears in both answers.

    If either answer normalizes to no tokens at all (e.g. it is only
    punctuation or an article), the score is 1.0 when both are empty and
    0.0 otherwise, matching the SQuAD v2.0 evaluation convention and keeping
    F1 consistent with ``exact_match_score`` on such inputs.

    Returns:
        A float in [0.0, 1.0].
    """
    prediction_tokens = normalize_answer(prediction).split()
    ground_truth_tokens = normalize_answer(ground_truth).split()
    if not prediction_tokens or not ground_truth_tokens:
        # Without this guard, e.g. "the" vs "a" scores F1 = 0 even though
        # both normalize to "" and exact_match_score calls them a match.
        return float(prediction_tokens == ground_truth_tokens)
    common = Counter(prediction_tokens) & Counter(ground_truth_tokens)
    num_same = sum(common.values())
    if num_same == 0:
        return 0.0
    precision = num_same / len(prediction_tokens)
    recall = num_same / len(ground_truth_tokens)
    return (2 * precision * recall) / (precision + recall)

def exact_match_score(prediction, ground_truth):
    """Return True when both answers are identical after ``normalize_answer``."""
    normalized_prediction = normalize_answer(prediction)
    normalized_truth = normalize_answer(ground_truth)
    return normalized_prediction == normalized_truth

# 평가 함수

- `evaluate_batch_e2e(args, rag_model, questions)`: 모델을 사용하여 질문에 대한 답변을 생성하는 함수.

질문을 토큰화해 모델 입력으로 준비한 뒤, 모델로 답변을 생성하고 생성된 토큰을 텍스트로 디코딩한다.

- `args.print_predictions`가 참이면 질문과 답변을 로그에 기록한다.

In [None]:
@torch.no_grad()
def evaluate_batch_e2e(args, rag_model, questions):
    """Generate one answer string per question with a RAG generation model.

    Args:
        args: namespace read for ``device``, ``num_beams``, ``min_length``,
            ``max_length`` and ``print_predictions``.
        rag_model: model whose ``retriever`` exposes
            ``question_encoder_tokenizer`` (used to encode the questions) and
            ``generator_tokenizer`` (used to decode the outputs).
        questions: list of question strings, encoded as one padded batch.

    Returns:
        List of decoded answers with special tokens removed
        (``num_return_sequences=1``, so one per question).
    """
    retriever = rag_model.retriever
    encoded = retriever.question_encoder_tokenizer.batch_encode_plus(
        questions, return_tensors="pt", padding=True, truncation=True
    )

    generated = rag_model.generate(
        encoded['input_ids'].to(args.device),
        attention_mask=encoded['attention_mask'].to(args.device),
        num_beams=args.num_beams,
        min_length=args.min_length,
        max_length=args.max_length,
        early_stopping=True,
        num_return_sequences=1,
        # Forbid token id 0 twice in a row: per the original note, BART likes
        # to repeat BOS tokens.
        bad_words_ids=[[0, 0]],
    )
    decoded = retriever.generator_tokenizer.batch_decode(generated, skip_special_tokens=True)

    if args.print_predictions:
        for question, answer in zip(questions, decoded):
            logger.info("Q: {} - A: {}".format(question, answer))

    return decoded

In [None]:
# Evaluation configuration. parse_args([]) ignores the notebook kernel's own
# argv, so every value below is the declared default.
parser = argparse.ArgumentParser(description="Test Script")

parser.add_argument("--model_name_or_path", type=str, default="/content/drive/MyDrive/rag_project/trained_model", help="Path to pretrained model or model identifier")
parser.add_argument("--evaluation_set", type=str, default="nq-test.csv", help="Path to the evaluation dataset (CSV file)")
parser.add_argument("--predictions_path", type=str, default="predictions.txt", help="Path to save predictions")
parser.add_argument("--num_beams", type=int, default=4, help="Number of beams for generation")
parser.add_argument("--min_length", type=int, default=1, help="Minimum length of generated answers")
parser.add_argument("--max_length", type=int, default=50, help="Maximum length of generated answers")
parser.add_argument("--eval_batch_size", type=int, default=8, help="Batch size for evaluation")
parser.add_argument("--print_predictions", action="store_true", help="If true, print predictions")

args = parser.parse_args([])

args.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Knowledge corpus saved to disk without an attached index (presumably because
# it was dropped before save_to_disk — the directory name suggests so; confirm).
loaded_embeddings_dataset = Dataset.load_from_disk('wiki_dataset_without_indexes')

# Attach the pre-built FAISS index from 'embeddings.faiss' under the name
# "embeddings". Calling add_faiss_index(column='embeddings') here would
# instead rebuild a brand-new index from the column with the library's
# default settings and ignore the saved index file entirely.
loaded_embeddings_dataset.load_faiss_index('embeddings', 'embeddings.faiss')

# Passing indexed_dataset makes the retriever search this custom corpus.
retriever = RagRetriever.from_pretrained(args.model_name_or_path, indexed_dataset=loaded_embeddings_dataset)
model = RagSequenceForGeneration.from_pretrained(args.model_name_or_path, retriever=retriever)
model.to(args.device)

In [None]:
logger.info("***** Running evaluation *****")
logger.info("  Batch size = %d", args.eval_batch_size)
logger.info("  Predictions will be stored under %s", args.predictions_path)

import gc
import random

# Release memory left over from earlier cells before generating.
gc.collect()
torch.cuda.empty_cache()

# Load the evaluation set first: the random index below needs its length.
eval_data = pd.read_csv(args.evaluation_set)
question = eval_data['Question'].tolist()
answers = eval_data['Answer'].tolist()
if not question:
    raise ValueError(f"Evaluation set {args.evaluation_set!r} contains no questions")

# Evaluate a single, randomly chosen question.
random_batch = random.randrange(len(question))
batch_question = [question[random_batch]]
# NOTE(review): if the 'Answer' column holds several acceptable answers
# (e.g. a serialized list), it is scored below as one string — confirm format.
gold_answer = answers[random_batch]
print(f"Evaluating for question: {batch_question}")

# Run generation.
batch_answers = evaluate_batch_e2e(args, model, batch_question)

# Show the result.
predicted_answer = batch_answers[0]
print(f"Predicted answer: {predicted_answer}")
print(f'Gold answer: {gold_answer}')

# Persist the prediction where the log message above says it is stored.
with open(args.predictions_path, "w", encoding="utf-8") as predictions_file:
    predictions_file.write(predicted_answer + "\n")

In [None]:
 # Score the prediction against the gold answer. Bind the results to new names:
 # rebinding `f1_score` to a float would shadow the function and make any
 # re-run of this cell fail with "'float' object is not callable".
 f1 = f1_score(predicted_answer, gold_answer)
 em = exact_match_score(predicted_answer, gold_answer)
 print(f'f1 score : {f1}')
 print(f'exact match : {em}')