In [25]:
import glob
import os
import json
import csv
import argparse
import re
import string

import sacrebleu
import editdistance

from typing import Dict, List, Tuple, Any, Union
from collections import Counter
from difflib import SequenceMatcher

In [None]:
# --- Evaluation configuration -------------------------------------------
# NOTE: despite its name, DATA_DIR points at a single JSONL file holding the
# reference answers, not at a directory.
DATA_DIR = "../multiloko_eval/dev.jsonl"

# Hub-style "<organisation>/<model>" identifier; its two parts name the
# sub-directories under ../model_output.
MODEL = "meta-llama/Meta-Llama-3-8B-Instruct"
model_f, model_name = MODEL.split("/")[:2]

# Directory shared by the prediction file (read) and the results file (written).
_model_output_dir = os.path.join("../model_output", model_f, model_name)

prediction_path = os.path.join(_model_output_dir, "evaluate_5shot.jsonl")
out_file = os.path.join(_model_output_dir, "evaluate_results_5shot.jsonl")


In [27]:
def remove_articles(text: str) -> str:
    """
    Replace every standalone English article ("a", "an", "the") with a single space.

    Matching is case-sensitive and bounded by word boundaries, so words that
    merely contain an article (e.g. "theme") are left untouched. Only these
    three English articles are handled.
    """
    article_pattern = re.compile(r"\b(a|an|the)\b")
    return article_pattern.sub(" ", text)

In [28]:
def white_space_fix(text: str) -> str:
    """
    Collapse every run of whitespace into one space and trim both ends.
    """
    # str.split() with no argument splits on any whitespace run and drops
    # empty leading/trailing pieces.
    words = text.split()
    return " ".join(words)

In [29]:
def remove_punc(text: str) -> str:
    """
    Strip punctuation characters from ``text``.

    The set of removed characters is the ASCII ``string.punctuation`` constant
    extended with a hand-picked list of CJK and Western European punctuation
    marks, since the builtin constant only covers ASCII.
    """
    # Adapted from https://stackoverflow.com/questions/36640587/how-to-remove-chinese-punctuation-in-python
    extra_punct = r"""„“«»¡¿《》！？｡。＂＃＄％＆＇（）＊＋，－／：；＜＝＞＠［＼］＾＿｀｛｜｝～｟｠｢｣､、〃》「」『』【】〔〕〖〗〘〙〚〛〜〝〞〟〰〾〿–—‘’‛“”„‟…‧﹏."""
    # Map every punctuation character to None so translate() deletes it.
    deletion_table = str.maketrans("", "", string.punctuation + extra_punct)
    return text.translate(deletion_table)

In [30]:
def normalize_answer(text: str) -> str:
    """
    Canonicalise an answer string before scoring.

    Steps, in order: lowercase, strip punctuation, blank out English articles,
    then collapse whitespace.
    """
    lowered = text.lower()
    without_punct = remove_punc(lowered)
    without_articles = remove_articles(without_punct)
    return white_space_fix(without_articles)

In [31]:
def postprocess_answers(input: Union[str, List[str]]) -> Union[str, List[str]]:
    """
    Normalise a single answer or every answer in a collection.

    A ``str`` is normalised and returned as a ``str``; any other input is
    iterated and a list of the normalised items is returned.
    """
    if not isinstance(input, str):
        return [normalize_answer(answer) for answer in input]
    return normalize_answer(input)

In [32]:
def parse_input_jsonl(file_and_lang: List[Tuple[str, str]]) -> Dict[str, Dict[str, List[str]]]:
    """
    Parse the reference JSONL file(s).

    Each non-blank line must be a JSON object with at least an ``"id"`` and a
    ``"targets"`` field. Only the id and the normalised targets are kept.

    Args:
        file_and_lang: ``(filename, language)`` pairs to load.

    Returns:
        ``{language: {id: normalised targets}}``. If the same language is
        listed more than once, the last file replaces the earlier ones.

    Raises:
        OSError: if a file cannot be opened.
        json.JSONDecodeError: if a non-blank line is not valid JSON.
        KeyError: if a record lacks ``"id"`` or ``"targets"``.
    """
    ret = {}
    for filename, language in file_and_lang:
        ret[language] = {}
        # Read as UTF-8 explicitly: the data is multilingual and the platform
        # default encoding is not guaranteed to be UTF-8.
        with open(filename, "r", encoding="utf-8") as f:
            for line in f:
                # Tolerate blank lines (e.g. a trailing newline at end of file).
                if not line.strip():
                    continue
                data = json.loads(line)
                ret[language][data["id"]] = postprocess_answers(data["targets"])
    return ret

In [33]:
def parse_output_jsonl(filename: str) -> Dict[str, Dict[str, str]]:
    """
    Parse a predictions JSONL file.

    Expected format of every non-blank line is
    ``{"language": "language", "id": "id", "prediction": "prediction"}``.

    Args:
        filename: path of the JSONL file holding the model predictions.

    Returns:
        ``{language: {id: normalised prediction}}``. If a ``(language, id)``
        pair occurs more than once, the last occurrence wins.

    Raises:
        OSError: if the file cannot be opened.
        json.JSONDecodeError: if a non-blank line is not valid JSON.
        AssertionError: if a record lacks a required field or has an empty
            ``"language"`` / ``"id"``.
    """
    ret = {}
    # Read as UTF-8 explicitly: predictions are multilingual and the platform
    # default encoding is not guaranteed to be UTF-8.
    with open(filename, "r", encoding="utf-8") as f:
        for line_number, line in enumerate(f, start=1):
            # Tolerate blank lines (e.g. a trailing newline at end of file).
            if not line.strip():
                continue
            data = json.loads(line)
            assert "language" in data, f"line {line_number}: missing 'language'"
            assert "id" in data, f"line {line_number}: missing 'id'"
            assert "prediction" in data, f"line {line_number}: missing 'prediction'"
            assert data["language"] != "", f"line {line_number}: empty 'language'"
            assert data["id"] != "", f"line {line_number}: empty 'id'"
            if data["language"] not in ret:
                ret[data["language"]] = {}
            ret[data["language"]][data["id"]] = postprocess_answers(data["prediction"])
    return ret

In [34]:
def output_results(results: Dict[str, Any], output_file: Union[str, None]) -> None:
    """
    Serialise ``results`` as indented JSON.

    Args:
        results: JSON-serialisable results dictionary.
        output_file: destination path; when falsy (``None`` or ``""``) the
            JSON is printed to standard output instead.

    Non-ASCII characters are written verbatim (``ensure_ascii=False``), so the
    file is opened explicitly as UTF-8.
    """
    if output_file:
        with open(output_file, "w", encoding="utf-8") as f:
            json.dump(results, f, ensure_ascii=False, indent=4)
    else:
        # The formatting options belong to json.dumps, not to print():
        # passing them to print() raises TypeError.
        print(json.dumps(results, ensure_ascii=False, indent=4))

In [35]:
def f1(prediction: str, targets: List[str]) -> float:
    """
    Token-level F1 between ``prediction`` and the best-matching target.

    Both strings are split on whitespace; overlap is counted as a multiset
    intersection of tokens, so repeated tokens only match as often as they
    appear on both sides.

    Args:
        prediction: the predicted answer.
        targets: the acceptable reference answers.

    Returns:
        The highest F1 over all targets, always as a ``float``.

    Raises:
        ValueError: if ``targets`` is empty.
    """
    def _f1(pred_tokens: List[str], gt_tokens: List[str]) -> float:
        common = Counter(pred_tokens) & Counter(gt_tokens)
        num_same = sum(common.values())
        if num_same == 0:
            # Return a float (not int 0) so the result type is consistent.
            # This branch also covers empty token lists, avoiding a division
            # by zero below.
            return 0.0
        precision = num_same / len(pred_tokens)
        recall = num_same / len(gt_tokens)
        return (2 * precision * recall) / (precision + recall)

    # Tokenise the prediction once instead of once per target.
    pred_tokens = prediction.split()
    return max(_f1(pred_tokens, target.split()) for target in targets)


def exact_match(prediction: str, targets: List[str]) -> float:
    """
    Return 1.0 if ``prediction`` equals at least one target, otherwise 0.0.

    Raises ``ValueError`` when ``targets`` is empty.
    """
    per_target = [1.0 if prediction == candidate else 0.0 for candidate in targets]
    return max(per_target)


def sentence_bleu(prediction: str, targets: List[str], **kwargs: Any) -> float:
    """
    Sentence-level BLEU of ``prediction`` against ``targets`` via sacrebleu.

    Extra keyword arguments are forwarded unchanged to
    ``sacrebleu.sentence_bleu``; the ``score`` attribute of its result is
    returned.
    """
    bleu_result = sacrebleu.sentence_bleu(prediction, targets, **kwargs)
    return bleu_result.score


def sentence_chrf(prediction: str, targets: List[str], **kwargs: Any) -> float:
    """
    Sentence-level chrF of ``prediction`` against ``targets`` via sacrebleu.

    Extra keyword arguments are forwarded unchanged to
    ``sacrebleu.sentence_chrf``; the ``score`` attribute of its result is
    returned.
    """
    chrf_result = sacrebleu.sentence_chrf(prediction, targets, **kwargs)
    return chrf_result.score


def edit_distance(prediction_tokens: str, target_tokens: str) -> float:
    """
    Edit distance between the prediction and one target, as computed by
    ``editdistance.eval`` and converted to ``float``.
    """
    raw_distance = editdistance.eval(prediction_tokens, target_tokens)
    return float(raw_distance)


def edit_distance_many(prediction_tokens: str, target_tokens: List[str]) -> float:
    """
    Smallest edit distance between the prediction and any of the candidate
    targets.

    Raises ``ValueError`` when ``target_tokens`` is empty.
    """
    distances = [
        edit_distance(prediction_tokens, candidate) for candidate in target_tokens
    ]
    return float(min(distances))


def edit_similarity(prediction: str, targets: List[str]) -> float:
    """
    Highest ``difflib.SequenceMatcher`` ratio between ``prediction`` and any
    target (no junk heuristic is supplied).

    Raises ``ValueError`` when ``targets`` is empty.
    """
    ratios = []
    for candidate in targets:
        matcher = SequenceMatcher(None, prediction, candidate)
        ratios.append(matcher.ratio())
    return max(ratios)


def evaluate_all(reference_answers, our_answers):
    """
    Score every prediction with every metric and aggregate the scores.

    Args:
        reference_answers: ``{language: {id: targets}}`` as produced by
            ``parse_input_jsonl``.
        our_answers: ``{language: {id: prediction}}`` as produced by
            ``parse_output_jsonl``.

    Returns:
        A dictionary with one entry per language plus ``"group_metrics"``:

        * ``results[lang][id]`` holds the per-example score of each metric
          together with the ``"targets"`` and ``"prediction"`` that were scored;
        * ``results[lang][metric]`` holds the mean of that metric over the
          language's examples (stored next to the example ids, so an example
          whose id equals a metric name is overwritten by the mean);
        * ``results["group_metrics"]`` holds ``average_``/``max_``/``min_``/
          ``gap_`` of the per-language means for each metric. It is left empty
          when there are no languages to aggregate.

    Raises:
        KeyError: if a predicted language or example id has no reference.
        ZeroDivisionError: if a language in ``our_answers`` has no examples.
    """
    metrics = {
        "em": exact_match,
        "f1": f1,
        "bleu": sentence_bleu,
        "chrf": sentence_chrf,
        "edit_distance": edit_distance_many,
        "edit_similarity": edit_similarity,
    }

    results = {}
    for lang, examples in our_answers.items():
        if lang not in reference_answers:
            raise KeyError(f"no reference answers for language {lang!r}")
        ref_current = reference_answers[lang]

        # Per-example scores are kept in their own dict so the aggregation
        # below never has to tell example ids apart from metric names.
        per_example = {}
        for example_id, prediction in examples.items():
            if example_id not in ref_current:
                raise KeyError(
                    f"no reference answer for id {example_id!r} in language {lang!r}"
                )
            targets = ref_current[example_id]
            scores = {
                metric_name: metric(prediction, targets)
                for metric_name, metric in metrics.items()
            }
            scores["targets"] = targets
            scores["prediction"] = prediction
            per_example[example_id] = scores

        # Language-level means are stored alongside the example entries.
        lang_results = dict(per_example)
        for metric_name in metrics:
            total = sum(scores[metric_name] for scores in per_example.values())
            lang_results[metric_name] = total / len(per_example)
        results[lang] = lang_results

    # Aggregate across languages. "group_metrics" is a reserved key and is
    # never treated as a language.
    languages = [lang for lang in results if lang != "group_metrics"]
    group_metrics = {}
    if languages:  # avoid dividing by zero when there is nothing to aggregate
        for metric_name in metrics:
            per_language = [results[lang][metric_name] for lang in languages]
            highest = max(per_language)
            lowest = min(per_language)
            group_metrics[f"average_{metric_name}"] = sum(per_language) / len(per_language)
            group_metrics[f"max_{metric_name}"] = highest
            group_metrics[f"min_{metric_name}"] = lowest
            group_metrics[f"gap_{metric_name}"] = highest - lowest
    results["group_metrics"] = group_metrics
    return results

In [36]:
# --- Run the evaluation ---------------------------------------------------
# Reference file(s) paired with the language key used to look them up.
input_parse_tuples = [(DATA_DIR, "italian")]
predictions, output = prediction_path, out_file

# Load references and predictions, score them, then write the JSON report.
reference_answers = parse_input_jsonl(input_parse_tuples)
our_answers = parse_output_jsonl(predictions)
results = evaluate_all(reference_answers, our_answers)
output_results(results, output)