<a href="https://colab.research.google.com/github/Sogo95/Question-and-Answer-Test-Train-Overlap/blob/main/Code.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
pip install wget


Collecting wget
  Downloading wget-3.2.zip (10 kB)
  Preparing metadata (setup.py) ... [?25l[?25hdone
Building wheels for collected packages: wget
  Building wheel for wget (setup.py) ... [?25l[?25hdone
  Created wheel for wget: filename=wget-3.2-py3-none-any.whl size=9656 sha256=6778fdb8379f9e4c794cf45779489a8317fafbb4fb1448281207133683949840
  Stored in directory: /root/.cache/pip/wheels/8b/f1/7f/5c94f0a7a505ca1c81cd1d9208ae2064675d97582078e6c769
Successfully built wget
Installing collected packages: wget
Successfully installed wget-3.2


# Téléchargement des données

In [None]:
# Copyright (c) 2020-present, Facebook, Inc.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
#
"""Download data for calculating question overlap"""
import os

import wget


# The standalone script resolved DATA_DIR relative to ``__file__``; a notebook
# cell has no ``__file__``, so the current working directory is used instead.
DATA_DIR = os.path.join(os.getcwd(), 'data')

# (download URL, local file name) pairs for the test-set question/answer files.
TEST_SETS_TO_DOWNLOAD = [
    ('https://dl.fbaipublicfiles.com/qaoverlap/data/nq-test.qa.csv', 'nq-test.qa.csv'),
    ('https://dl.fbaipublicfiles.com/qaoverlap/data/triviaqa-test.qa.csv', 'triviaqa-test.qa.csv'),
    ('https://dl.fbaipublicfiles.com/qaoverlap/data/webquestions-test.qa.csv', 'webquestions-test.qa.csv'),
]
# (download URL, local file name) pairs for the overlap annotation files.
ANNOTATIONS_TO_DOWNLOAD = [
    ('https://dl.fbaipublicfiles.com/qaoverlap/data/nq-annotations.jsonl', 'nq-annotations.jsonl'),
    ('https://dl.fbaipublicfiles.com/qaoverlap/data/triviaqa-annotations.jsonl', 'triviaqa-annotations.jsonl'),
    ('https://dl.fbaipublicfiles.com/qaoverlap/data/webquestions-annotations.jsonl', 'webquestions-annotations.jsonl'),
]

os.makedirs(DATA_DIR, exist_ok=True)
for link, dest in TEST_SETS_TO_DOWNLOAD + ANNOTATIONS_TO_DOWNLOAD:
    dest_path = os.path.join(DATA_DIR, dest)
    # Skip files that are already present so re-running this cell does not
    # download them again.
    if os.path.exists(dest_path):
        continue
    wget.download(link, dest_path)


# Évaluation des datasets

In [None]:
from collections import Counter
import string
import re
import json
import os
import argparse
import ast

# Evaluation subsets. Each annotation record's 'labels' list names the subsets
# the example belongs to; scores are reported per subset in this order.
ANNOTATIONS = [
    'total',
    'question_overlap',
    'no_question_overlap',
    'answer_overlap',
    'no_answer_overlap',
    'answer_overlap_only'
]

# Data lives under <cwd>/data, the same directory the download cell writes to.
DIRNAME = os.getcwd()
DATA_DIR = os.path.join(DIRNAME, 'data')
# Gold question/answer files, keyed by dataset name. Built from DATA_DIR so the
# directory is spelled once and the platform separator is used.
REFERENCE_PATHS = {
    'triviaqa': os.path.join(DATA_DIR, 'triviaqa-test.qa.csv'),
    'nq': os.path.join(DATA_DIR, 'nq-test.qa.csv'),
    'webquestions': os.path.join(DATA_DIR, 'webquestions-test.qa.csv'),
}
# Overlap annotation files (JSON Lines), keyed by dataset name.
ANNOTATION_PATHS = {
    'triviaqa': os.path.join(DATA_DIR, 'triviaqa-annotations.jsonl'),
    'nq': os.path.join(DATA_DIR, 'nq-annotations.jsonl'),
    'webquestions': os.path.join(DATA_DIR, 'webquestions-annotations.jsonl'),
}



def normalize_answer(s):
    """Lower text and remove punctuation, articles and extra whitespace.

    Args:
        s: Text to normalize. A list or tuple is joined with single spaces
            first (its items converted with ``str``). Any other non-string
            value, such as the int that ``ast.literal_eval`` yields for a
            prediction like ``1984``, is converted with ``str``.

    Returns:
        The normalized string.
    """
    # Accept multi-answer containers and non-string scalars instead of
    # failing on ``.lower()`` / ``str.join``.
    if isinstance(s, (list, tuple)):
        s = ' '.join(str(part) for part in s)
    elif not isinstance(s, str):
        s = str(s)

    def remove_articles(text):
        return re.sub(r'\b(a|an|the)\b', ' ', text)

    def white_space_fix(text):
        return ' '.join(text.split())

    def remove_punc(text):
        exclude = set(string.punctuation)
        return ''.join(ch for ch in text if ch not in exclude)

    def lower(text):
        return text.lower()

    return white_space_fix(remove_articles(remove_punc(lower(s))))

def f1_score(prediction, ground_truth):
    """Token-level F1 between a prediction and a single ground-truth answer.

    Both inputs go through ``normalize_answer`` and are split on whitespace;
    shared tokens are counted as a multiset intersection. Returns ``0`` when
    the two share no tokens.
    """
    pred_toks = normalize_answer(prediction).split()
    gold_toks = normalize_answer(ground_truth).split()
    overlap = sum((Counter(pred_toks) & Counter(gold_toks)).values())
    if not overlap:
        return 0
    precision = overlap / len(pred_toks)
    recall = overlap / len(gold_toks)
    return 2 * precision * recall / (precision + recall)


def exact_match_score(prediction, ground_truth):
    """Return True when both answers are identical after ``normalize_answer``."""
    normalized_prediction = normalize_answer(prediction)
    return normalized_prediction == normalize_answer(ground_truth)


def metric_max_over_ground_truths(metric_fn, prediction, ground_truths):
    """Score *prediction* against every ground truth and keep the best score.

    Raises ValueError (from ``max``) when *ground_truths* is empty.
    """
    return max(metric_fn(prediction, truth) for truth in ground_truths)


def read_references(fi, sep='\t'):
    """Read a gold-answer file into ``[{'references': [...], 'id': i}, ...]``.

    Each line holds a question and an answer-list cell separated by *sep*;
    ``id`` is the 0-based line number and the question text is discarded.

    Args:
        fi: Path of the reference file (decoded as UTF-8).
        sep: Column separator.

    Returns:
        One dict per line with the parsed answer list and the line id.

    Raises:
        ValueError: If a line does not split into exactly two columns.
        SyntaxError / ValueError: If an answer cell cannot be parsed even
            after collapsing doubled quotes.
    """
    def parse_pandas_answer(a_string):
        # Cells are either a Python list literal (``['a', 'b']``) or a
        # CSV-quoted one (``"[""a"", ""b""]"``): collapse the doubled quotes and
        # strip the outer pair before evaluating.
        try:
            if a_string.startswith('['):
                return ast.literal_eval(a_string)
            return ast.literal_eval(a_string.replace('""', '"')[1:-1])
        except (SyntaxError, ValueError):
            # Retry with repeated collapsing, presumably for cells whose
            # quotes were escaped more than once.
            return ast.literal_eval(a_string.replace('""', '"').replace('""', '"').replace('""', '"')[1:-1])

    references = []
    with open(fi, encoding='utf-8') as f:
        for i, line in enumerate(f):
            _question, answer_str = line.strip('\n').split(sep)
            references.append({'references': parse_pandas_answer(answer_str), 'id': i})
    return references


def read_jsonl(path):
    """Parse a JSON Lines file (UTF-8) into a list of objects, one per line.

    Whitespace-only lines, such as a stray trailing empty line, are skipped
    instead of being passed to ``json.loads`` (which would raise on them).
    """
    with open(path, encoding='utf-8') as f:
        return [json.loads(line) for line in f if line.strip()]


def read_lines(path):
    """Return every line of *path* (decoded as UTF-8) with surrounding whitespace stripped."""
    with open(path, encoding='utf-8') as handle:
        return [line.strip() for line in handle]


def read_annotations(annotations_data_path):
    """Load the overlap annotations, one JSON object per line, via ``read_jsonl``."""
    annotations = read_jsonl(annotations_data_path)
    return annotations


def safe_parse_prediction(prediction_str):
    """Try to parse a prediction with ``ast.literal_eval``.

    On SyntaxError/ValueError the problem is reported on stdout and the raw
    string is returned unchanged.
    """
    try:
        parsed = ast.literal_eval(prediction_str)
    except (SyntaxError, ValueError) as err:
        print(f"Erreur lors de l'évaluation de la prédiction : {prediction_str}")
        print(f"Détails de l'erreur : {err}")
        parsed = prediction_str  # fall back to the raw text
    return parsed


def read_predictions(path):
    """Load predictions as a list of ``{'id': ..., 'prediction': ...}`` dicts.

    Files ending in ``.json`` or ``.jsonl`` are read as JSON Lines. Each object
    is expected to carry its own ``id`` and ``prediction`` keys, which is
    what ``get_scores`` reads. Any other file is read as tab-separated text:
    the second column of line *i* is parsed with ``safe_parse_prediction`` and
    given id *i*.
    """
    # Match real extensions; the old ``endswith('json')`` also matched names
    # that merely end in the letters "json".
    if path.endswith(('.json', '.jsonl')):
        return read_jsonl(path)
    return [
        {'id': i, 'prediction': safe_parse_prediction(line.split('\t')[1])}
        for i, line in enumerate(read_lines(path))
    ]


def _get_scores(answers, refs, fn):
    """Score each answer against its own reference list, keeping the best match."""
    scores = []
    for answer, answer_refs in zip(answers, refs):
        scores.append(metric_max_over_ground_truths(fn, answer, answer_refs))
    return scores


def get_scores(predictions, references, annotations, annotation_labels=None):
    """Compute exact-match and F1 for every annotation label.

    Args:
        predictions: Dicts with ``id`` and ``prediction`` keys.
        references: Dicts with ``id`` and ``references`` (list of gold answers).
        annotations: Dicts with ``id`` and ``labels`` (labels the example has).
        annotation_labels: Labels to report; defaults to ``ANNOTATIONS``.

    Returns:
        ``{label: {'exact_match', 'f1_score', 'n_examples'}}``. Both scores are
        percentages; a label with no matching examples reports 0.0 for both
        instead of dividing by zero.

    Raises:
        AssertionError: If the three inputs do not cover the same set of ids.
    """
    predictions_map = {p['id']: p for p in predictions}
    references_map = {r['id']: r for r in references}
    annotations_map = {a['id']: a for a in annotations}
    # Explicit raises rather than ``assert`` so the checks survive ``python -O``.
    # AssertionError is kept so existing handlers still match. The
    # annotations-vs-references check is implied by these two.
    if predictions_map.keys() != references_map.keys():
        raise AssertionError('predictions file doesnt match the gold references file ')
    if predictions_map.keys() != annotations_map.keys():
        raise AssertionError('prediction file doesnt match the annotation file ')

    if annotation_labels is None:
        annotation_labels = ANNOTATIONS

    results = {}
    for annotation_label in annotation_labels:
        annotation_ids = [annotation['id'] for annotation in annotations if annotation_label in annotation['labels']]
        preds = [predictions_map[idd]['prediction'] for idd in annotation_ids]
        refs = [references_map[idd]['references'] for idd in annotation_ids]
        em = _get_scores(preds, refs, exact_match_score)
        f = _get_scores(preds, refs, f1_score)
        n_examples = len(annotation_ids)
        results[annotation_label] = {
            'exact_match': 100 * sum(em) / n_examples if n_examples else 0.0,
            'f1_score': 100 * sum(f) / n_examples if n_examples else 0.0,
            'n_examples': n_examples,
        }
    return results


def _print_score(label, results_dict):
    print('-' * 50)
    print('Label       :', label)
    print('N examples  : ', results_dict['n_examples'])
    print('Exact Match : ', results_dict['exact_match'])
    print('F1 score    : ', results_dict['f1_score'])


def _main(predictions_path, references_path, annotations_path):
    """Load the three inputs, score them, and print one summary per label in ANNOTATIONS."""
    scores = get_scores(
        read_predictions(predictions_path),
        read_references(references_path),
        read_annotations(annotations_path),
    )
    for label in ANNOTATIONS:
        _print_score(label, scores[label])


def main(predictions_path, dataset_name):
    """Evaluate *predictions_path* against the gold data of *dataset_name*.

    Args:
        predictions_path: File holding the predictions to score.
        dataset_name: Key of ``REFERENCE_PATHS`` / ``ANNOTATION_PATHS``.

    Raises:
        KeyError: If *dataset_name* is not a known dataset.
        FileNotFoundError: If the reference or annotation file is missing.
            It subclasses ``Exception``, so existing ``except Exception``
            handlers still catch it.
    """
    references_path = REFERENCE_PATHS[dataset_name]
    annotations_path = ANNOTATION_PATHS[dataset_name]
    if not os.path.exists(references_path):
        raise FileNotFoundError('References expected at ' + references_path
                                + ' not found, please download them using the download script (see readme)')
    if not os.path.exists(annotations_path):
        raise FileNotFoundError('Annotations expected at ' + annotations_path
                                + ' not found, please download them using the download script (see readme)')
    _main(predictions_path, references_path, annotations_path)


if __name__ == '__main__':
    # NOTE(review): predictions_path below is the dataset's own gold file
    # (the same file REFERENCE_PATHS points to), so this scores the references
    # against themselves. Point it at real model outputs for meaningful numbers.
    datasets = ['nq', 'triviaqa', 'webquestions']
    for dataset_name in datasets:
        predictions_path = os.path.join(DATA_DIR, f'{dataset_name}-test.qa.csv')
        print(f'Évaluation pour le dataset: {dataset_name}')
        main(predictions_path, dataset_name)
        # Two blank lines between datasets.
        print(end='\n\n')


Évaluation pour le dataset: nq
--------------------------------------------------
Label       : total
N examples  :  3610
Exact Match :  57.8393351800554
F1 score    :  85.74091619998404
--------------------------------------------------
Label       : question_overlap
N examples  :  324
Exact Match :  54.93827160493827
F1 score    :  84.2159443858211
--------------------------------------------------
Label       : no_question_overlap
N examples  :  672
Exact Match :  60.56547619047619
F1 score    :  86.71763270137613
--------------------------------------------------
Label       : answer_overlap
N examples  :  2297
Exact Match :  52.54680017414018
F1 score    :  83.78062285910569
--------------------------------------------------
Label       : no_answer_overlap
N examples  :  1313
Exact Match :  67.0982482863671
F1 score    :  89.17030980546777
--------------------------------------------------
Label       : answer_overlap_only
N examples  :  315
Exact Match :  53.01587301587302
F1 sco