In [None]:
import os
import re

import numpy as np
import pandas as pd
import seaborn as sns
import yaml
from matplotlib import pyplot as plt

In [None]:
# Categories as specified in the dataset are different to the paper.
# This YAML file gives the mapping between them.
# A malformed file now raises yaml.YAMLError right here. Previously the error
# was only printed, which left `cat_correction` undefined (or stale from an
# earlier run of this cell) for the cells that use it further down.
with open("../configs/demetr/cat_correction.yaml") as stream:
    cat_correction = yaml.safe_load(stream)

In [None]:
# Map categories to error severity.
# Severity is as specified in the paper.
# A malformed file now raises yaml.YAMLError right here. Previously the error
# was only printed, which left `cat_severity` undefined (or stale from an
# earlier run of this cell) when the severity column is built later on.
with open("../configs/demetr/cat_severity.yaml") as stream:
    cat_severity = yaml.safe_load(stream)

### Load M4ST results

In [252]:
# Root directory holding the output JSON files produced by running
# scripts/demetr/process_demetr.py (the files are also available in the
# project Sharepoint).
# The path is relative, so it is resolved against the notebook's working
# directory.
m4st_res_dir = "../outputs/demetr"

In [None]:
# os.listdir returns entries in arbitrary order; sort them so the row order of
# the concatenated results (and of the CSV written from them) is reproducible.
res_files = sorted(os.listdir(m4st_res_dir))

In [None]:
# Read every result file into its own dataframe; they are concatenated later.
# File names are expected to look like
#   <metric>_<base|critical|major|minor>_id<category>_<description>.json
# e.g. BLASER_Ref_critical_id8_negation.json
results_dataframes = []

# MetricX produces MQM-style scores out of 25 where lower is better.
METRICX_MAX_SCORE = 25

# Compiled once: the category id, and the severity token that ends the metric name
category_pattern = re.compile(r"_id(\d+)_")
severity_pattern = re.compile(r"_(base|critical|major|minor)")

for res_file in res_files:
    res_path = os.path.join(m4st_res_dir, res_file)

    # The results directory also holds subdirectories and may hold stray
    # non-JSON files (e.g. .DS_Store); skip them explicitly instead of relying
    # on read_json raising IsADirectoryError.
    if not os.path.isfile(res_path) or not res_file.endswith(".json"):
        continue

    category_match = category_pattern.search(res_file)
    severity_match = severity_pattern.search(res_file)
    if category_match is None or severity_match is None:
        msg = f"Unexpected result file name (cannot parse metric/category): {res_file}"
        raise ValueError(msg)

    # Category id is the number following "_id"
    cat = int(category_match.group(1))

    # Metric name is everything before the severity token
    metric = res_file[: severity_match.start()]

    # Transpose the frame as read from JSON before adding the per-file columns
    res_df = pd.read_json(res_path).T

    # Reverse MetricX scores so that they match the other metrics
    # (lower is worse)
    if "metricx" in metric:
        res_df["mt_score"] = METRICX_MAX_SCORE - res_df.mt_score
        res_df["disfluent_score"] = METRICX_MAX_SCORE - res_df.disfluent_score

    res_df["metric"] = metric
    res_df["sentence_id"] = res_df.index
    res_df["category"] = cat
    results_dataframes.append(res_df)

In [None]:
# Stack the per-file frames into one long results table and preview it
all_res = pd.concat(results_dataframes, axis=0)
all_res.head()

In [None]:
# Remap the dataset's category ids onto the numbering used in the paper
corrected_categories = all_res["category"].replace(cat_correction)
all_res["category"] = corrected_categories

In [None]:
# DEMETR accuracy: a row counts as correct when the MT output is scored
# strictly higher than its perturbed ("disfluent") counterpart
mt_scores = all_res["mt_score"]
perturbed_scores = all_res["disfluent_score"]
all_res["correct"] = mt_scores > perturbed_scores

In [None]:
# Accuracy is reversed for category 35 (reference as translation) so need to adjust that.
# Take an explicit copy: this frame gets its "correct" column overwritten in a
# following cell, and writing to a boolean-mask slice of all_res would trigger
# pandas' SettingWithCopyWarning.
cat_to_rev = all_res.loc[all_res["category"] == 35].copy()
cat_to_rev

In [None]:
# For the reversed category the perturbed text should receive the higher score
reversed_correct = cat_to_rev["mt_score"] < cat_to_rev["disfluent_score"]
cat_to_rev["correct"] = reversed_correct
cat_to_rev

In [None]:
# Write the reversed accuracy for category 35 back into the main dataframe.
# The values are recomputed from all_res itself and assigned positionally
# (.to_numpy()): the index of all_res is the sentence id, which repeats across
# metrics, so label-based alignment of a Series is fragile here, and this way
# the cell does not depend on the current state of `cat_to_rev`.
is_reversed_cat = all_res["category"] == 35
all_res.loc[is_reversed_cat, "correct"] = (
    all_res.loc[is_reversed_cat, "mt_score"]
    < all_res.loc[is_reversed_cat, "disfluent_score"]
).to_numpy()

In [None]:
# Inspect the category-35 rows to confirm the reversed accuracy is in place
is_cat_35 = all_res["category"] == 35
all_res.loc[is_cat_35]

In [None]:
# Look up each row's severity level from its (corrected) category id
severity_per_row = all_res["category"].map(cat_severity)
all_res["severity"] = severity_per_row

In [None]:
# Preview the first rows now that the accuracy and severity columns exist
all_res.head()

In [None]:
# Save out to file so we have a corrected set of results.
# Create the target directory first so a fresh checkout does not fail with
# FileNotFoundError (OSError) when "all" does not exist yet.
all_res_csv = "../outputs/demetr/all/all.csv"
os.makedirs(os.path.dirname(all_res_csv), exist_ok=True)
all_res.to_csv(all_res_csv, index=False)

In [None]:
# Replace the raw metric identifiers (taken from the result file names) with
# the display names used in the plots
metric_display_names = {
    # COMET family
    "wmt22-comet-da": "COMET-22-Ref",
    "COMET_Ref": "COMET-21-Ref",
    "COMET-QE": "COMET-21-QE",
    "wmt22-cometkiwi-da": "COMETKiwi-22",
    # BLEU
    "Bleu": "BLEU",
    # BLASER
    "BLASER_QE": "BLASER-2_QE",
    "BLASER_Ref": "BLASER-2_Ref",
    # MetricX-24, quality-estimation variants
    "google_metricx-24-hybrid-large-v2p6-bfloat16_qe": "MetricX-24L-16-QE",
    "google_metricx-24-hybrid-large-v2p6_qe": "MetricX-24L-QE",
    "google_metricx-24-hybrid-xl-v2p6-bfloat16_qe": "MetricX-24XL-16-QE",
    "google_metricx-24-hybrid-xl-v2p6_qe": "MetricX-24XL-QE",
    "google_metricx-24-hybrid-xxl-v2p6_qe": "MetricX-24XXL-QE",
    "google_metricx-24-hybrid-xxl-v2p6-bfloat16_qe": "MetricX-24XXL-16-QE",
    # MetricX-24, reference-based variants
    "google_metricx-24-hybrid-xl-v2p6_ref": "MetricX-24XL-Ref",
    "google_metricx-24-hybrid-xl-v2p6-bfloat16_ref": "MetricX-24XL-16-Ref",
    "google_metricx-24-hybrid-xxl-v2p6-bfloat16_ref": "MetricX-24XXL-16-Ref",
    "google_metricx-24-hybrid-xxl-v2p6_ref": "MetricX-24XXL-Ref",
    "google_metricx-24-hybrid-large-v2p6-bfloat16_ref": "MetricX-24L-16-Ref",
    "google_metricx-24-hybrid-large-v2p6_ref": "MetricX-24L-Ref",
}
all_res["metric"] = all_res.metric.replace(metric_display_names)

In [None]:
# Mean accuracy per metric over every row, listed best-first
accuracy_by_metric = all_res.groupby("metric").correct.mean()
accuracy_by_metric.sort_values(ascending=False)

In [None]:
# Mean accuracy per source language, one "x" marker per language
fig, axs = plt.subplots()
by_language = all_res.groupby("source_language")["correct"].mean()
axs.plot(by_language, "x")
# One tick per language actually present, rather than a hard-coded 10
plt.xticks(np.arange(len(by_language)), by_language.index, rotation=45)
# Values are the mean of a boolean column, i.e. a fraction in [0, 1], so the
# axis is not labelled as a percentage
plt.ylabel("DEMETR accuracy")
plt.xlabel("Source language")
plt.title("Mean performance across all 35 categories")

In [None]:
# Mean accuracy per severity level as a bar chart
fig, axs = plt.subplots()
by_severity = all_res.groupby("severity")["correct"].mean()
# Draw on the axes created above explicitly instead of relying on pyplot's
# "current axes" state
by_severity.plot(kind="bar", ax=axs)
# One tick per severity level actually present, rather than a hard-coded 4
plt.xticks(np.arange(len(by_severity)), by_severity.index, rotation=0)
plt.ylabel("DEMETR accuracy")
plt.xlabel("Severity")
plt.title("Mean performance for each error type")

In [None]:
# Grouped bars: one cluster per source language, one bar per severity level
fig, axs = plt.subplots()
sev_by_lang = all_res.groupby(["source_language", "severity"])["correct"].mean()
severity_columns = sev_by_lang.unstack()
severity_columns.plot(kind="bar", ax=axs)
plt.xticks(rotation=45)
axs.set_ylabel("DEMETR accuracy")
axs.set_xlabel("Source language")
axs.set_title("Mean performance for each severity level, by language")
# Place the legend to the right of the axes
axs.legend(loc="right", bbox_to_anchor=(1.25, 0.5))

In [None]:
# Grouped bars: one cluster per source language, one bar per metric
fig, axs = plt.subplots()
sev_by_lang = all_res.groupby(["source_language", "metric"])["correct"].mean()
metric_columns = sev_by_lang.unstack()
metric_columns.plot(kind="bar", ax=axs)
plt.xticks(rotation=45)
axs.set_ylabel("DEMETR accuracy")
axs.set_xlabel("Source language")
axs.set_title("Mean performance for each metric, by language")
# Place the legend to the right of the axes
axs.legend(loc="right", bbox_to_anchor=(1.4, 0.5))

In [None]:
# Mean accuracy per source language for the COMET-family metrics only.
# NOTE(review): the original cell plotted every metric under this "COMET"
# title; the filter below keeps any metric whose name contains "comet"
# (case-insensitive, so XCOMET variants are included too) - confirm that this
# is the intended selection.
fig, axs = plt.subplots()
comet_res = all_res[all_res["metric"].str.contains("comet", case=False)]
sev_by_lang = comet_res.groupby(["source_language", "metric"])["correct"].mean()
# Flat (long-format) copy of the plotted values; built from the series computed
# in this cell rather than from whatever an earlier cell left in `sev_by_lang`
to_table = sev_by_lang.reset_index()
sev_by_lang.unstack().plot(kind="bar", ax=axs)
plt.xticks(rotation=45)
plt.ylabel("DEMETR accuracy")
plt.xlabel("Source language")
plt.title("Mean performance for COMET metrics, by language")
plt.legend(loc="right", bbox_to_anchor=(1.4, 0.5))

In [None]:
# Overall mean accuracy per metric, ordered best-first, one "x" per metric
fig, axs = plt.subplots()
overall_mean = all_res.groupby(["metric"])["correct"].mean()
sorted_overall_mean = overall_mean.sort_values(ascending=False)
axs.plot(sorted_overall_mean, "x")
plt.xticks(rotation=30)
axs.set_xlabel("Metric")
axs.set_ylabel("Accuracy")
# Title intentionally left off:
# plt.title("Mean performance across all languages")

In [None]:
# Mean accuracy for every (metric, category) pair, as a flat table
metric_category_accuracy = all_res.groupby(["metric", "category"])["correct"].mean()
corr_by_category = metric_category_accuracy.reset_index()
corr_by_category

In [None]:
# Subset of metrics shown in the box plot and the by-severity bar chart below.
# Entries that also appear as targets in the renaming step above ("BLEU",
# "COMET-22-Ref", "MetricX-24L-Ref", "BLASER-2_Ref") are display names.
metrics_to_plot = [
    "BLEU",
    "COMET-22-Ref",
    "MetricX-24L-Ref",
    "XCOMET-XL",
    "BLASER-2_Ref",
    "ChrF2",
]

In [None]:
# Keep only the rows belonging to the metrics selected for plotting
is_plotted_metric = corr_by_category["metric"].isin(metrics_to_plot)
corr_by_category = corr_by_category[is_plotted_metric]

In [None]:
# Per-metric medians, ordered by median accuracy (lowest first); the resulting
# index gives the box order for the plot
per_metric_median = corr_by_category.groupby("metric").median()
grouped = per_metric_median.sort_values(by="correct")

In [None]:
# Box plot of per-category accuracy for each selected metric, with boxes
# ordered by median accuracy (`grouped.index`)
fig, axs = plt.subplots()
g = sns.boxplot(
    corr_by_category,
    x="metric",
    y="correct",
    fill=False,
    ax=axs,
    width=0.5,
    order=grouped.index,
)
axs.set_xticklabels(rotation=20, labels=axs.get_xticklabels())
axs.set_xlabel("Metric")
axs.set_ylabel("Accuracy")
plt.tight_layout()
# Make sure the plots directory exists so savefig does not fail on a fresh
# checkout
boxplot_path = "../outputs/demetr/plots/metrics-boxplot.png"
os.makedirs(os.path.dirname(boxplot_path), exist_ok=True)
plt.savefig(boxplot_path)

In [None]:
# Horizontal grouped bars: mean accuracy per severity level for each of the
# selected metrics
fig, axs = plt.subplots()
res_subset = all_res[all_res.metric.isin(metrics_to_plot)]
sev_by_lang = res_subset.groupby(["metric", "severity"])["correct"].mean()
# One row per metric, one column per severity level
sev_by_lang = sev_by_lang.unstack()
sev_by_lang.plot(kind="barh", ax=axs)
plt.xticks(rotation=0)
plt.xlabel("DEMETR accuracy")
plt.ylabel("Metric")
# plt.title("Mean performance for each severity level by metric")
plt.legend(bbox_to_anchor=(1, 1))
plt.tight_layout()
# Make sure the plots directory exists so savefig does not fail on a fresh
# checkout
severity_plot_path = "../outputs/demetr/plots/demetr-by-severity.png"
os.makedirs(os.path.dirname(severity_plot_path), exist_ok=True)
plt.savefig(severity_plot_path)

In [None]:
# Grouped bars for every metric: one cluster per metric, one bar per severity
fig, axs = plt.subplots()
sev_by_lang = all_res.groupby(["metric", "severity"])["correct"].mean()
severity_by_metric = sev_by_lang.unstack()
severity_by_metric.plot(kind="bar", ax=axs)
plt.xticks(rotation=30)
axs.set_ylabel("DEMETR accuracy")
axs.set_xlabel("Metric")
# Title intentionally left off:
# plt.title("Mean performance by severity for COMET metrics")
axs.legend(loc="right", bbox_to_anchor=(1.23, 0.5))

### BLASER only

Look at BLASER performance for three different perturbation types, for different language pairs.

In [None]:
# Results directory for the BLASER-only section. Same value as assigned near
# the top of the notebook; repeating it lets this section run without the
# earlier cell.
m4st_res_dir = "../outputs/demetr"

In [None]:
# Load the BLASER (reference-based) results for three perturbation types.
# The metric prefix is spelled "BLASER_Ref" for all three files, matching the
# other two file names here and the key used in the metric renaming step;
# "BLASER_REF" only resolved on case-insensitive filesystems.
blaser_new_15 = pd.read_json(
    os.path.join(m4st_res_dir, "BLASER_Ref_minor_id15_case.json")
)
blaser_new_8 = pd.read_json(
    os.path.join(m4st_res_dir, "BLASER_Ref_critical_id8_negation.json")
)
blaser_new_6 = pd.read_json(
    os.path.join(m4st_res_dir, "BLASER_Ref_critical_id6_addition.json")
)

In [None]:
# Transpose each frame as loaded from JSON so the score fields
# (mt_score, disfluent_score, ...) become columns
blaser_new_15, blaser_new_8, blaser_new_6 = (
    blaser_new_15.transpose(),
    blaser_new_8.transpose(),
    blaser_new_6.transpose(),
)

In [None]:
# Score gap between the original MT and its perturbed version, per sentence
for blaser_frame in (blaser_new_15, blaser_new_8, blaser_new_6):
    blaser_frame["diff"] = blaser_frame.mt_score - blaser_frame.disfluent_score

In [None]:
# Mean BLASER score difference (MT minus perturbed) per source language, one
# line per perturbation type
fig, axs = plt.subplots()

# Select "diff" before aggregating so only that column is averaged: a
# frame-wide groupby().mean() raises on non-numeric columns in pandas >= 2.0.
# Cast to float first because the column may be object dtype (the frames were
# transposed, and the transpose of a mixed-type frame is object dtype).
mean_diffs = [
    blaser_frame["diff"]
    .astype(float)
    .groupby(blaser_frame["source_language"])
    .mean()
    for blaser_frame in (blaser_new_15, blaser_new_8, blaser_new_6)
]
for mean_diff in mean_diffs:
    mean_diff.plot(ax=axs)

fig.legend(
    labels=["Pronoun case", "Negation", "Addition"],
    loc="right",
    bbox_to_anchor=(1.15, 0.5),
)
axs.set_ylabel("Score difference")
# One tick per language present (groupby returns them sorted), rather than a
# hard-coded count of 10
languages = mean_diffs[0].index
plt.xticks(np.arange(len(languages)), languages, rotation=30)