In [1]:
import os

import torch

import pandas as pd

from torchmetrics import MatthewsCorrCoef

from mtqe.data.loaders import load_ced_test_data
from mtqe.utils.language_pairs import LI_LANGUAGE_PAIRS_WMT_21_CED
from mtqe.utils.metrics import williams_test
from mtqe.utils.paths import PREDICTIONS_DIR

In [2]:
def load_predictions(lp, model, seed):
    """
    Load the test-set predictions for a language pair, model and seed.

    Parameters
    ----------
    lp : str
        Language pair as it appears in the prediction file names, e.g. 'en-ja'.
    model : str
        Name of the model / experiment group; also the name of the directory
        under ``PREDICTIONS_DIR/ced_data`` that holds its prediction files.
    seed : int or str
        Seed of the run. For models that were not trained this can be 'llm'
        or any other string contained at the start of the file name.

    Returns
    -------
    pd.DataFrame
        Contents of the matching prediction CSV file.

    Raises
    ------
    FileNotFoundError
        If the prediction directory does not exist or holds no file matching
        the language pair and seed.
    """

    # Monolingual and second-step models keep their en-ja predictions in a
    # separate "<model>_enja" directory; everything else lives under "<model>".
    uses_enja_dir = lp == 'en-ja' and ("monolingual" in model or "second_step" in model)
    dir_name = f"{model}_enja" if uses_enja_dir else model
    dir_path = os.path.join(PREDICTIONS_DIR, "ced_data", dir_name)

    # the seed is contained at the start of the filename - make sure to ignore final timestamp!
    seed_token = str(seed)
    # os.listdir returns entries in arbitrary order, so sort to make the
    # choice reproducible when more than one file matches.
    candidates = sorted(
        f for f in os.listdir(dir_path)
        if lp in f and seed_token in f[:20] and 'test' in f and 'csv' in f
    )
    if not candidates:
        raise FileNotFoundError(
            f"No test prediction csv found in '{dir_path}' for language pair '{lp}' and seed '{seed}'"
        )

    return pd.read_csv(os.path.join(dir_path, candidates[0]))

In [3]:
# Run the metric on the GPU when one is visible to torch, otherwise on the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

# Binary Matthews correlation coefficient, shared by all cells below.
mcc = MatthewsCorrCoef(task="binary", num_classes=2)
mcc = mcc.to(device)

In [4]:
def find_median_seed(lp, model, all_seeds = (42, 89, 107, 928, 2710)):
    """
    Find seed of model with median performance (MCC) on language pair.

    Parameters
    ----------
    lp : str
        Language pair whose test data and predictions are loaded.
    model : str
        Name of the model / experiment group passed to ``load_predictions``.
    all_seeds : sequence
        Seeds to compare. With an even number of seeds the lower of the two
        middle MCC values is used, so the result is always one of the seeds.

    Returns
    -------
    The first seed (in the order given) whose MCC equals the median MCC.

    Raises
    ------
    ValueError
        If ``all_seeds`` is empty or no seed produced a valid MCC.
    """

    seeds = list(all_seeds)
    if not seeds:
        raise ValueError("all_seeds must contain at least one seed")

    # Inputs must live on the same device as the metric's internal state.
    metric_device = mcc.device

    true_labels = load_ced_test_data(lp)['score']
    true_labels = torch.as_tensor(true_labels.to_numpy(), dtype=torch.float32).to(metric_device)

    mcc_by_seed = []
    for seed in seeds:
        scores = load_predictions(lp, model, seed)['score']
        # Scores above 0.5 are mapped to class 1, everything else to class 0.
        preds = torch.as_tensor(scores.to_numpy(), dtype=torch.float32) > 0.5
        preds = preds.long().to(metric_device)
        # Argument order kept as in the rest of the notebook (labels first).
        mcc_by_seed.append(mcc(true_labels, preds).item())
    df = pd.DataFrame({"seeds": seeds, "MCCs": mcc_by_seed})

    # Series.median() averages the two middle values for an even count, which
    # would match no row; pick an actual observed value (lower median) instead.
    ranked = df['MCCs'].dropna().sort_values()
    if ranked.empty:
        raise ValueError(f"No valid MCC could be computed for model '{model}' on '{lp}'")
    median_mcc = ranked.iloc[(len(ranked) - 1) // 2]

    return df.loc[df['MCCs'] == median_mcc, 'seeds'].iloc[0]

In [5]:
## THIS IS THE MODEL WE WANT TO COMPARE AGAINST
# Every experiment group is tested against this model's predictions.

# Alternative: any name containing 'baseline' makes the comparison cell read the
# stored "<lp>_test_baseline_cometkiwi_22.csv" files instead of a trained model's predictions.
# BASELINE_MODEL = 'baseline'
BASELINE_MODEL = 'train_multilingual_auth_data_all'

In [6]:
# Seeds of the trained models; passed explicitly to find_median_seed.
all_seeds = [42, 89, 107, 928, 2710]

# Experiment groups to compare against BASELINE_MODEL (the baseline itself is skipped).
EXPERIMENT_GROUP_NAMES = [
    'train_monolingual_auth_data',
    'train_multilingual_auth_data_all',
    'second_step_base_auth_data',
    'second_step_base_demetr_data',
    'second_step_base_demetr_auth_data',
    'prompt_basic',
    'prompt_GEMBA',
    'wmt21_annotator',
]

# Results of williams_test below this value are flagged True in the output.
SIGNIFICANCE_LEVEL = 0.05


def _scores_to_tensor(scores):
    """Convert a pandas column to a float32 tensor on the same device as the mcc metric."""
    return torch.as_tensor(scores.to_numpy(), dtype=torch.float32).to(mcc.device)


# For a trained baseline, compare against its median-performing seed per language pair.
if 'baseline' not in BASELINE_MODEL:
    baseline_median_seeds = {}
    for lp in LI_LANGUAGE_PAIRS_WMT_21_CED:
        baseline_median_seeds[lp] = find_median_seed(lp, BASELINE_MODEL, all_seeds)

for EXPERIMENT_GROUP_NAME in EXPERIMENT_GROUP_NAMES:
    if EXPERIMENT_GROUP_NAME == BASELINE_MODEL:
        continue

    print(EXPERIMENT_GROUP_NAME)
    exp_results = {}
    for lp in LI_LANGUAGE_PAIRS_WMT_21_CED:
        labels_df = load_ced_test_data(lp)

        # Prompt-based and annotator predictions have no seeds; their files are found via 'llm'.
        if 'prompt' not in EXPERIMENT_GROUP_NAME and 'wmt21' not in EXPERIMENT_GROUP_NAME:
            exp_median_seed = find_median_seed(lp, EXPERIMENT_GROUP_NAME, all_seeds)
        else:
            exp_median_seed = 'llm'
        exp_df = load_predictions(lp, EXPERIMENT_GROUP_NAME, exp_median_seed)

        if 'baseline' not in BASELINE_MODEL:
            baseline_df = load_predictions(lp, BASELINE_MODEL, baseline_median_seeds[lp])
        else:
            baseline_df = pd.read_csv(os.path.join(PREDICTIONS_DIR, "ced_data", "baseline", f"{lp}_test_baseline_cometkiwi_22.csv"))

        # After the two merges: score_x = baseline scores, score_y = experiment
        # scores, score = labels from the test data.
        merged_df = baseline_df.merge(exp_df, on="idx")
        full_df = merged_df.merge(labels_df, on="idx")

        true_labels = _scores_to_tensor(full_df['score'])

        # Scores above 0.5 are mapped to class 1, everything else to class 0.
        baseline_preds = (_scores_to_tensor(full_df['score_x']) > 0.5).long()
        exp_preds = (_scores_to_tensor(full_df['score_y']) > 0.5).long()

        baseline_mcc = mcc(true_labels, baseline_preds)
        exp_mcc = mcc(true_labels, exp_preds)
        # MCC between the two sets of predictions, needed by the Williams test.
        metrics_mcc = mcc(baseline_preds, exp_preds)

        # Hand CPU tensors to williams_test (a no-op when already on the CPU).
        # NOTE(review): williams_test presumably returns a p-value - confirm in mtqe.utils.metrics.
        test_value = williams_test(baseline_mcc.cpu(), exp_mcc.cpu(), metrics_mcc.cpu())
        exp_results[lp] = [round(exp_mcc.item(), 3), test_value < SIGNIFICANCE_LEVEL]

    print(exp_results)


train_monolingual_auth_data
{'en-cs': [0.459, False], 'en-de': [0.478, False], 'en-ja': [0.173, True], 'en-zh': [0.28, True]}
second_step_base_auth_data
{'en-cs': [0.489, False], 'en-de': [0.472, False], 'en-ja': [0.243, False], 'en-zh': [0.304, False]}
second_step_base_demetr_data
{'en-cs': [0.489, False], 'en-de': [0.484, False], 'en-ja': [0.255, False], 'en-zh': [0.27, True]}
second_step_base_demetr_auth_data
{'en-cs': [0.472, False], 'en-de': [0.503, False], 'en-ja': [0.137, True], 'en-zh': [0.244, True]}
prompt_basic
{'en-cs': [0.39, True], 'en-de': [0.368, True], 'en-ja': [0.239, False], 'en-zh': [0.327, False]}
prompt_GEMBA
{'en-cs': [0.387, True], 'en-de': [0.333, True], 'en-ja': [0.193, True], 'en-zh': [0.308, False]}
wmt21_annotator
{'en-cs': [0.422, True], 'en-de': [0.475, False], 'en-ja': [0.187, True], 'en-zh': [0.294, False]}
