In [3]:
import collections
import re
import string

Answer normalization and scoring helpers

In [4]:
def normalize_answer(s):
    """Canonicalize an answer string so that superficially different answers compare equal.

    The transformations are applied in this order:
      1. lowercase the text;
      2. delete every character in ``string.punctuation`` (ASCII punctuation only);
      3. replace the whole words "a", "an" and "the" with a space;
      4. collapse whitespace runs to single spaces and trim both ends.

    Because punctuation is removed before articles, "a-the" becomes "athe"
    (no longer separate words) and is left untouched by step 3.
    """
    punctuation_chars = set(string.punctuation)
    article_pattern = re.compile(r'\b(a|an|the)\b', re.UNICODE)

    lowered = s.lower()
    without_punc = ''.join(ch for ch in lowered if ch not in punctuation_chars)
    without_articles = re.sub(article_pattern, ' ', without_punc)
    return ' '.join(without_articles.split())


In [5]:
def get_tokens(s):
    """Split the normalized form of ``s`` into whitespace-separated tokens.

    Falsy input (``""``, ``None``) short-circuits to an empty list without
    invoking ``normalize_answer``.
    """
    return normalize_answer(s).split() if s else []

In [6]:
def compute_exact(a_gold, a_pred):
    """Return 1 when gold and predicted answers match after ``normalize_answer``, otherwise 0."""
    normalized_gold = normalize_answer(a_gold)
    normalized_pred = normalize_answer(a_pred)
    return 1 if normalized_gold == normalized_pred else 0

In [7]:
def compute_f1(a_gold, a_pred):
    """Token-level F1, precision and recall between a gold answer and a prediction.

    Both strings are tokenized with ``get_tokens`` (normalize, then split on
    whitespace). Overlap is counted as a multiset intersection, so a token
    repeated in both answers is credited up to its smaller count.

    Args:
        a_gold: reference answer text.
        a_pred: predicted answer text.

    Returns:
        ``(f1, precision, recall)``. If either side has no tokens after
        normalization, all three are 1 when both are empty and 0 otherwise.
        If the answers share no tokens, all three are 0.
    """
    gold_toks = get_tokens(a_gold)
    pred_toks = get_tokens(a_pred)

    # The empty case must be handled before the overlap check. Two empty answers
    # have no common tokens, but they are a perfect match, not a miss.
    if not gold_toks or not pred_toks:
        score = int(gold_toks == pred_toks)
        return score, score, score

    common = collections.Counter(gold_toks) & collections.Counter(pred_toks)
    num_same = sum(common.values())
    if num_same == 0:
        return 0, 0, 0

    precision = num_same / len(pred_toks)
    recall = num_same / len(gold_toks)
    f1 = (2 * precision * recall) / (precision + recall)

    return f1, precision, recall


Metrics cho RAG + reranker

In [None]:
generated_answers = 'rag_with_reranker_answers.txt'
correct_answers = 'answers.txt'
with open(generated_answers, 'r', encoding='utf-8') as f:
    pred_ans = f.readlines()
with open(correct_answers, 'r', encoding='utf-8') as f:
    corr_ans = f.readlines()

print(f"Số câu hỏi: {len(corr_ans)}")
print(f"Số câu trả lời dự đoán: {len(pred_ans)}")

# Predictions are paired with gold answers by line index. A count mismatch
# (e.g. a stray trailing blank line) means the two files may be misaligned,
# and downstream scores would be silently wrong or raise IndexError.
assert len(pred_ans) == len(corr_ans), (
    f"{len(pred_ans)} predicted answers vs {len(corr_ans)} gold answers"
)


Số câu hỏi: 21
Số câu trả lời dự đoán: 21


In [15]:
output_lines = []
tot_f1 = tot_precision = tot_recall = exact_match = 0

for i in range(len(corr_ans)):
    # A gold line may list several acceptable answers separated by ';'.
    # The prediction is scored against each one, and the best is kept.
    reference_answers = corr_ans[i].strip().split(';')
    prediction = pred_ans[i]

    best_f1 = best_precision_val = best_recall = 0
    for reference_answer in reference_answers:
        f1, prec, rec = compute_f1(reference_answer, prediction)
        if f1 >= best_f1:
            best_f1 = f1
            best_precision_val = prec
            best_recall = rec
    f1 = best_f1
    prec = best_precision_val
    rec = best_recall

    # Exact match uses the same per-reference maximum as F1. Comparing against
    # the whole ';'-joined line (which normalizes to "x y") could never match
    # a question that has more than one acceptable answer.
    ex_mtch = max(compute_exact(ref, prediction) for ref in reference_answers)

    tot_f1 += f1
    tot_precision += prec
    tot_recall += rec
    exact_match += ex_mtch

    output_lines.append(
        f'Predicted answer: {pred_ans[i]}\n'
        f'Correct answer: {corr_ans[i]}\n'
        f'f1: {f1}\nPrecision: {prec}\nRecall: {rec}\nExact Match: {ex_mtch}\n'
    )


In [17]:
# Average over the questions that were actually scored. The scoring loop
# iterates over corr_ans, so extra lines in the prediction file must not
# enlarge the denominator.
num_questions = len(corr_ans)
avg_f1 = tot_f1 / num_questions
avg_precision = tot_precision / num_questions
avg_recall = tot_recall / num_questions

print(f'F1: {avg_f1:.4f}')
print(f'Precision: {avg_precision:.4f}')
print(f'Recall: {avg_recall:.4f}')
print(f'Total Exact Match: {exact_match} / {num_questions}')


F1: 0.6129
Precision: 0.6466
Recall: 0.7103
Total Exact Match: 7 / 21
