In [1]:
# Mount Google Drive at /content/drive; the evaluation file paths used below
# live under /content/drive/MyDrive. This import is only available on Colab.
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


In [8]:
import string
import re
from collections import Counter

# Model predictions to score.
prediction_f = (
    '/content/drive/MyDrive/CMU_FALL2024/11711_ANLP/HW2_ANLP/FinalModelEval/'
    'RELEASED_QA_MODEL_OUTPUT.txt'
)
# Reference (gold) answers the predictions are scored against.
reference_f = (
    '/content/drive/MyDrive/CMU_FALL2024/11711_ANLP/HW2_ANLP/FinalModelEval/'
    'RELEASED_QA_ANSWERS.txt'
)

# Function to normalize text (removes punctuation, articles, etc.)
def normalize_answer(s):
    """Canonicalise an answer string so two answers can be compared token-wise.

    Applied in this order:
      1. lowercase the text;
      2. drop every character found in ``string.punctuation``;
      3. replace the standalone words "a", "an" and "the" with a space;
      4. collapse whitespace runs to single spaces and trim both ends.

    Punctuation is removed before articles, so "a-team" becomes "ateam"
    rather than "team".
    """
    punctuation = set(string.punctuation)

    lowered = s.lower()
    unpunctuated = ''.join(filter(lambda ch: ch not in punctuation, lowered))
    without_articles = re.sub(r'\b(a|an|the)\b', ' ', unpunctuated)
    return ' '.join(without_articles.split())

# Function to calculate F1 score
def f1_score(prediction, ground_truth):
    """Token-level F1 between a predicted answer and a reference answer.

    Both strings are normalized with ``normalize_answer`` and split on
    whitespace. The overlap is the multiset intersection of the two token
    lists, so a repeated token only counts as often as it appears on both
    sides.

    Args:
        prediction: Predicted answer text.
        ground_truth: Reference answer text.

    Returns:
        The F1 score as a float between 0.0 and 1.0.
    """
    prediction_tokens = normalize_answer(prediction).split()
    ground_truth_tokens = normalize_answer(ground_truth).split()

    # If either side normalizes to nothing, precision/recall are undefined.
    # Score 1.0 only when both are empty, so F1 agrees with exact_match_score
    # (which reports a match for two empty normalized answers).
    if not prediction_tokens or not ground_truth_tokens:
        return float(prediction_tokens == ground_truth_tokens)

    overlap = Counter(prediction_tokens) & Counter(ground_truth_tokens)
    num_same = sum(overlap.values())

    # Always return a float so per-question scores have a single type.
    if num_same == 0:
        return 0.0

    precision = num_same / len(prediction_tokens)
    recall = num_same / len(ground_truth_tokens)
    return 2 * precision * recall / (precision + recall)

# Function to calculate exact match
def exact_match_score(prediction, ground_truth):
    """Return True when both answers are identical after ``normalize_answer``."""
    normalized_prediction = normalize_answer(prediction)
    normalized_reference = normalize_answer(ground_truth)
    return normalized_prediction == normalized_reference

# Function to calculate recall
def recall_score(prediction, ground_truth):
    """Token-level recall of a predicted answer against a reference answer.

    Both strings are normalized with ``normalize_answer`` and split on
    whitespace. Recall is the size of the multiset token overlap divided by
    the number of reference tokens.

    Args:
        prediction: Predicted answer text.
        ground_truth: Reference answer text.

    Returns:
        The recall as a float between 0.0 and 1.0.
    """
    prediction_tokens = normalize_answer(prediction).split()
    ground_truth_tokens = normalize_answer(ground_truth).split()

    # With an empty reference there is nothing to recall. Score 1.0 only when
    # the prediction is empty too, so recall agrees with exact_match_score;
    # always return a float so the result type does not depend on the input.
    if not ground_truth_tokens:
        return float(not prediction_tokens)

    overlap = Counter(prediction_tokens) & Counter(ground_truth_tokens)
    num_same = sum(overlap.values())
    return num_same / len(ground_truth_tokens)

# Per-question score tuples; evaluate_with_recall appends one entry per scored pair.
question_scores = []


def evaluate_with_recall(predictions, references):
    """Compute average exact match, F1 and recall over paired answers.

    Predictions and references are paired by position with ``zip``, so items
    beyond the shorter of the two sequences are not scored. Averages are
    divided by ``len(references)``, which means references without a matching
    prediction contribute zero to every metric.

    Side effect: one tuple per scored pair is appended to the module-level
    ``question_scores`` list. The list is not cleared here, so repeated calls
    accumulate entries.

    Args:
        predictions: Sequence of predicted answer strings.
        references: Sequence of reference answer strings.

    Returns:
        A dict with keys 'exact_match', 'f1' and 'recall', each a percentage
        (0-100). All three are 0.0 when ``references`` is empty.
    """
    total = len(references)
    # Nothing to average over: return zeros instead of dividing by zero.
    if total == 0:
        return {'exact_match': 0.0, 'f1': 0.0, 'recall': 0.0}

    f1 = exact_match = recall = 0

    for question_num, (pred, ref) in enumerate(zip(predictions, references), start=1):
        exact_match_dif = exact_match_score(pred, ref)
        f1_dif = f1_score(pred, ref)
        recall_dif = recall_score(pred, ref)

        question_scores.append((1, "Question: ", question_num, " scores are -> ", 'exact_match: ', exact_match_dif, ' f1: ', f1_dif, ' recall: ', recall_dif))

        exact_match += exact_match_dif
        f1 += f1_dif
        recall += recall_dif

    return {
        'exact_match': 100.0 * exact_match / total,
        'f1': 100.0 * f1 / total,
        'recall': 100.0 * recall / total,
    }

def evaluate(prediction_file=prediction_f, reference_file=reference_f):
    """Score a predictions file against a reference file and print the result.

    Both files are read line by line and passed to ``evaluate_with_recall``,
    which pairs them by position. Every line in the files is evaluated.

    Args:
        prediction_file: Path to the file of predicted answers.
        reference_file: Path to the file of reference answers.

    Returns:
        None. The metrics dict is printed to stdout.
    """
    # Explicit encoding so reading does not depend on the platform's locale.
    with open(reference_file, 'r', encoding='utf-8') as f:
        references = f.readlines()

    with open(prediction_file, 'r', encoding='utf-8') as f:
        predictions = f.readlines()

    # Layout of the released test set, per the original author's note:
    # the first 100 is event, the second 100 is general info, the third 100
    # is sports, the fourth 100 is music and culture.

    results = evaluate_with_recall(predictions, references)
    print('Accuracy: ')
    print(results)

# RAG model: print the aggregate metrics on the released test set, then the
# per-question score tuples collected in question_scores during evaluation.
print("---RAG model accuracy on released test data---")
evaluate(prediction_f)
for elem in question_scores:
    print(elem)

---RAG model accuracy on released test data---
Accuracy: 
{'exact_match': 9.24731182795699, 'f1': 18.64688227426668, 'recall': 19.025293105938275}
(1, 'Question: ', 1, ' scores are -> ', 'exact_match: ', False, ' f1: ', 0.4, ' recall: ', 0.5)
(1, 'Question: ', 2, ' scores are -> ', 'exact_match: ', False, ' f1: ', 0, ' recall: ', 0.0)
(1, 'Question: ', 3, ' scores are -> ', 'exact_match: ', True, ' f1: ', 1.0, ' recall: ', 1.0)
(1, 'Question: ', 4, ' scores are -> ', 'exact_match: ', False, ' f1: ', 0, ' recall: ', 0.0)
(1, 'Question: ', 5, ' scores are -> ', 'exact_match: ', False, ' f1: ', 0.6666666666666666, ' recall: ', 0.5)
(1, 'Question: ', 6, ' scores are -> ', 'exact_match: ', False, ' f1: ', 0, ' recall: ', 0.0)
(1, 'Question: ', 7, ' scores are -> ', 'exact_match: ', False, ' f1: ', 0.9090909090909091, ' recall: ', 0.8333333333333334)
(1, 'Question: ', 8, ' scores are -> ', 'exact_match: ', False, ' f1: ', 0, ' recall: ', 0.0)
(1, 'Question: ', 9, ' scores are -> ', 'exact_ma