In [1]:
#!/usr/bin/env python3
import os, sys
sys.path.append(os.path.abspath(os.path.join(os.getcwd(), "ACES_private/score_normalization")))
from normalize_scores_utils import *

In [3]:
# Root directory holding the local data files used by this notebook.
# The default is the location used when the notebook was written; set the
# ACES_DATA_ROOT environment variable to run it from another machine/checkout.
DATA_ROOT = os.environ.get('ACES_DATA_ROOT', '/home/arnisa/Desktop/work/metric_sensitivity_analysis')

# Load the ACES scores
ACES_scores_2022_path = os.path.join(DATA_ROOT, 'scored_aces_final.tsv')
# 2023 later
ACES_scores = load_ACES_scores(ACES_scores_2022_path)

# Load WMT scores; the metric names found in the ACES scores are handed to the
# loader (presumably to restrict which metrics are read - confirm in
# normalize_scores_utils).
metrics_names = set(ACES_scores.keys())
WMT_scores_path = os.path.join(DATA_ROOT, 'mt-metrics-eval-v2', 'wmt22', 'metric-scores')
WMT_scores = load_WMT_scores(WMT_scores_path, metrics_names)
# From here on metrics_names is a list of the metrics present in BOTH score sets.
# NOTE: it is built from a set, so its order is not guaranteed to be stable across runs.
metrics_names = list(set(ACES_scores.keys()).intersection(set(WMT_scores.keys())))

# Get sensitivities (we only get mean(good-bad)) of the normalized scores.
# The two middle return values are not used in this notebook.
sensitivities, _, _, phenomena = calculate_sensitivities(ACES_scores, WMT_scores, mapping=PHENOMENA_MAPPING)

# load the ACES scores from the paper
ACES_summary_2022 = load_ACES_scores_summary_2022()

INFO:logger:Loading ACES the scores...


en-uk


100%|██████████| 30/30 [00:00<00:00, 198.86it/s]


en-ja


100%|██████████| 30/30 [00:00<00:00, 135.09it/s]


liv-en


100%|██████████| 30/30 [00:00<00:00, 1001.59it/s]


uk-cs


100%|██████████| 30/30 [00:00<00:00, 129.45it/s]


zh-en


100%|██████████| 30/30 [00:00<00:00, 47.03it/s]


en-ru


100%|██████████| 30/30 [00:00<00:00, 85.52it/s]


de-fr


100%|██████████| 30/30 [00:00<00:00, 234.23it/s]


en-cs


100%|██████████| 30/30 [00:00<00:00, 84.91it/s]


en-de


100%|██████████| 30/30 [00:00<00:00, 49.49it/s]


fr-de


100%|██████████| 30/30 [00:00<00:00, 193.03it/s]


en-hr


100%|██████████| 30/30 [00:00<00:00, 139.42it/s]


ja-en


100%|██████████| 30/30 [00:00<00:00, 119.77it/s]


ru-sah


100%|██████████| 30/30 [00:00<00:00, 953.16it/s]


en-zh


100%|██████████| 30/30 [00:00<00:00, 62.93it/s]


en-liv


100%|██████████| 30/30 [00:00<00:00, 1057.38it/s]


de-en


100%|██████████| 30/30 [00:00<00:00, 87.81it/s]


sah-ru


100%|██████████| 30/30 [00:00<00:00, 808.38it/s]


cs-uk


100%|██████████| 30/30 [00:00<00:00, 96.92it/s]


uk-en


100%|██████████| 30/30 [00:00<00:00, 99.82it/s]


ru-en


100%|██████████| 30/30 [00:00<00:00, 99.98it/s]


cs-en


100%|██████████| 30/30 [00:00<00:00, 77.67it/s]


In [4]:
# Display the nested mapping metric -> phenomenon -> sensitivity produced by calculate_sensitivities.
sensitivities

{'metricx_xl_DA_2019': {'undertranslation': 0.3671614119894317,
  'untranslated': 1.7529706602024642,
  'mistranslation': 0.46832359419799086,
  'do not translate': 0.5465601074166211,
  'omission': 0.4409283424473139,
  'overtranslation': 0.6809637398565015,
  'punctuation': 0.28239264672353126,
  'wrong language': 0.7460884373810064,
  'real-world knowledge': 0.5027537175827648,
  'addition': 0.0950902947197431},
 'MEE': {'undertranslation': nan,
  'untranslated': nan,
  'mistranslation': nan,
  'do not translate': 0.33100090392079723,
  'omission': nan,
  'overtranslation': nan,
  'punctuation': 0.2064284308793095,
  'wrong language': nan,
  'real-world knowledge': nan,
  'addition': nan},
 'COMETKiwi': {'undertranslation': 0.3592402492974571,
  'untranslated': 0.6638297184083296,
  'mistranslation': 0.5732563461813599,
  'do not translate': 0.2624675537257028,
  'omission': 0.4366934012988357,
  'overtranslation': 0.571782725379702,
  'punctuation': 0.22746049678056637,
  'wrong la

In [5]:
# Sanity check: metrics_names lists the metrics in the same order as the keys of sensitivities.
metrics_names == list(sensitivities.keys())

True

In [8]:
# Build one table per score type, shaped {metric: [score for each phenomenon]}.
# The inner lists follow the order of `phenomena`.
means = {
    name: [sensitivities[name][ph] for ph in phenomena]
    for name in metrics_names
}
# Only metrics that have an entry in ACES_summary_2022 end up in `tau`.
tau = {
    name: [ACES_summary_2022[name][ph] for ph in phenomena]
    for name in metrics_names
    if name in ACES_summary_2022
}

In [9]:
# Display means: metric -> list of sensitivities, one per entry of `phenomena` (same order).
means

{'metricx_xl_DA_2019': [0.3671614119894317,
  1.7529706602024642,
  0.46832359419799086,
  0.5465601074166211,
  0.4409283424473139,
  0.6809637398565015,
  0.28239264672353126,
  0.7460884373810064,
  0.5027537175827648,
  0.0950902947197431],
 'MEE': [nan,
  nan,
  nan,
  0.33100090392079723,
  nan,
  nan,
  0.2064284308793095,
  nan,
  nan,
  nan],
 'COMETKiwi': [0.3592402492974571,
  0.6638297184083296,
  0.5732563461813599,
  0.2624675537257028,
  0.4366934012988357,
  0.571782725379702,
  0.22746049678056637,
  -0.5734240497846572,
  0.3392520839189057,
  0.12424218010601351],
 'chrF': [-0.22512321669961785,
  1.514542123624745,
  -0.07828824967779315,
  0.2776495256981501,
  0.2517183282944141,
  -0.28815119987870125,
  0.1017329547199582,
  0.6501324871223509,
  -0.09071221405154077,
  0.0594788644725885],
 'UniTE-ref': [0.18290898643881418,
  0.08114389371218786,
  0.2853519412386816,
  0.29993927966972367,
  0.37929767387818797,
  0.3765773864095423,
  0.15783458316341825,
  

In [52]:
import plotly.express as px
def grouped_line_plot(groups: List[Dict[str,list]], metrics_names: List[str], group_labels: List[str], phenomena: List[str]):
    '''
    Draw one line per (group, metric) pair over the phenomena and show the figure.

    Inputs:
        1. groups: the score tables to compare (e.g. means and tau scores), each with
        format = {
            metric1: [score for phenomenon 1, score for phenomenon 2, ..]
        }
        2. A list of the labels for:
            the groups (mean (good-bad), tau, ...) - one label per entry of `groups`
            metrics - the metrics to draw from every group
            phenomena (the order is important because in means and tau scores the scores are ordered acc. to phenomena)

    Behaviour:
        - Each group uses its own colour palette and each metric position its own colour
          inside that palette. Palettes and colours wrap around when there are more
          groups/metrics than predefined colours, so colours may repeat.
        - A metric that is missing from a group is skipped with a warning instead of
          aborting the whole plot.

    Returns: None (the figure is displayed with fig.show()).
    Raises: AssertionError if `groups` or `metrics_names` is empty, or if the number of
        groups and group labels differ.
    '''
    # Imported here so the function works even if `go` is not already defined
    # in the notebook namespace.
    import warnings
    import plotly.graph_objects as go

    assert len(groups) > 0 and len(groups) == len(group_labels) and len(metrics_names) > 0, \
        "need at least one group and one metric, and exactly one label per group"
    fig = go.Figure()
    # One palette per group: bluish tones for the first, reddish tones for the second.
    colors = [['lightsteelblue',  'aqua', 'aquamarine', 'darkturquoise'],
        ['chocolate', 'coral', 'crimson', 'orange']]
    for i, (group, label) in enumerate(zip(groups, group_labels)):
        palette = colors[i % len(colors)]
        for j, metric in enumerate(metrics_names):
            if metric not in group:
                warnings.warn(f"metric {metric!r} has no scores in group {label!r}; skipping it", stacklevel=2)
                continue
            # Index by the metric's position (not by the number of traces drawn) so a
            # metric keeps the same colour slot in every group even when some are skipped.
            fig.add_trace(go.Scatter(x=phenomena, y=group[metric], mode='lines', name=label+" - "+metric,
                          line=dict(color=palette[j % len(palette)])))
    fig.update_layout(
        title=" ".join(group_labels)
        )
    fig.show()

In [53]:
# Metric families; any of these lists can be passed as `metrics_names` below.
bleus = ["BLEU", "f101spBLEU", "f200spBLEU"]
comets = ["COMET-20", "COMET-22", "MS-COMET-22", "MS-COMET-QE-22"]
xl = ["metricx_xl_DA_2019", "metricx_xl_MQM_2020", "metricx_xxl_DA_2019", "metricx_xxl_MQM_2020"]
unite = ["UniTE", "UniTE-src", "UniTE-ref"]

# Compare mean(good-bad) with tau for the COMET family.
grouped_line_plot(
    groups=[means, tau],
    metrics_names=comets,
    group_labels=["Mean(good-bad)", "tau"],
    phenomena=phenomena,
)