In [19]:
import os

import numpy as np
import pandas as pd

import matplotlib.pyplot as plt

from sklearn.metrics import matthews_corrcoef

## Load COMETKiwi 2022 model

In [20]:
# log into Huggingface hub
from huggingface_hub import notebook_login
notebook_login()

VBox(children=(HTML(value='<center> <img\nsrc=https://huggingface.co/front/assets/huggingface_logo-noborder.sv…

In [21]:
from comet import download_model, load_from_checkpoint

model_path = download_model("Unbabel/wmt22-cometkiwi-da")
model = load_from_checkpoint(model_path)


Fetching 5 files:   0%|          | 0/5 [00:00<?, ?it/s]

Lightning automatically upgraded your loaded checkpoint from v1.8.2 to v1.9.5. To apply the upgrade to your files permanently, run `python -m pytorch_lightning.utilities.upgrade_checkpoint --file ../../../.cache/huggingface/hub/models--Unbabel--wmt22-cometkiwi-da/snapshots/b3a8aea5a5fc22db68a554b92b3d96eb6ea75cc9/checkpoints/model.ckpt`
Encoder model frozen.


## Load data

TODO: set the `main_dir` path below to parent directory of `ARC-MTQE`.

In [27]:
# TODO: set path to parent directory of ARC-MTQE
main_dir = "/Users/rjersakova/Documents/"

data_dir = os.path.join(main_dir, "ARC-MTQE", "mlqe-pe", "data")

path_data = os.path.join(data_dir, "catastrophic_errors", "encs_majority_test_blind.tsv")
path_labels = os.path.join(data_dir, "catastrophic_errors_goldlabels", "encs_majority_test_goldlabels", "goldlabels.txt")


In [28]:
# data has to be in [{"src":"...", "mt":"..."}, {...}] format
df_data = pd.read_csv(path_data, sep='\t', header=None, names=["idx", "source", "target"])
data = []
for i, row in df_data.iterrows():
    data.append({"src": row["source"], "mt": row["target"]})

# turn ground truth labels into 0/1 
df_labels = pd.read_csv(path_labels, sep='\t', header=None, names=["lang_pair", "ref", "idx", "label"])
# ERROR = 1, NOT = 0
df_labels["error"] = np.where(df_labels["label"]=="NOT", 0, 1)


## Make predictions

In [29]:
# get COMETKiwi prediction for each translation
model_output = model.predict(data, batch_size=8, gpus=0)
print (model_output)

GPU available: False, used: False
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
Predicting DataLoader 0: 100%|██████████| 125/125 [03:18<00:00,  1.59s/it]

Prediction([('scores', [0.8365867137908936, 0.5143743753433228, 0.7877252697944641, 0.4733821451663971, 0.7254721522331238, 0.7805802822113037, 0.8630162477493286, 0.8904730081558228, 0.6905339360237122, 0.6576987504959106, 0.8327204585075378, 0.4592285752296448, 0.6654027104377747, 0.6696677207946777, 0.8877825736999512, 0.8606453537940979, 0.704640805721283, 0.7503564953804016, 0.7028335332870483, 0.8855043053627014, 0.775689423084259, 0.8642303347587585, 0.7640092372894287, 0.698534369468689, 0.6150897145271301, 0.8028667569160461, 0.4873690605163574, 0.33454611897468567, 0.49928438663482666, 0.6849006414413452, 0.8019722700119019, 0.887939453125, 0.7032549977302551, 0.7155061364173889, 0.8888627290725708, 0.6260592341423035, 0.5937173366546631, 0.799419105052948, 0.6121482253074646, 0.6342429518699646, 0.7091159820556641, 0.6825821399688721, 0.7018523216247559, 0.7312526702880859, 0.584925651550293, 0.44523611664772034, 0.8398633003234863, 0.787706732749939, 0.4344383478164673, 0.5




In [30]:
# min(model_output.scores), max(model_output.scores)

# this might be unnecessary step but it ensures that the indexes of the data and the gold labels are the same
# an alternative would be to save the comet scores in df_labels directly
df_data['comet_score'] = model_output.scores

df_results = pd.merge(df_data, df_labels, on="idx")

## Evaluate

In [31]:
# turn COMETKiwi predictions into binary labels for a range of thresholds
# evaluate accuracy of prediction
thresholds = np.arange(0.1, 1, 0.1)
y_true = df_results["error"]

for t in thresholds:

    # we use 1 to indicate ERROR --> values below threshold are flagged as errors
    y_hat = (df_results['comet_score'] <= t).astype(int)

    print(t, matthews_corrcoef(y_true, y_hat))

0.1 0.0
0.2 0.0
0.30000000000000004 0.1136298865897307
0.4 0.17361933066582394
0.5 0.2831032951475931
0.6 0.3697324982623221
0.7000000000000001 0.3637324741187898
0.8 0.3030024569183256
0.9 0.015273470515558312
