In [1]:
#!/usr/bin/env python3
import os, sys
sys.path.append(os.path.abspath(os.path.join(os.getcwd(), "ACES_private/score_normalization")))
from normalize_scores_utils import *

# Normalize the Metric Names: Should be same for all

In [52]:
# get the metric names from WMT scores
lang_dir = os.path.join(WMT_scores_path, 'cs-en')
files = os.listdir(lang_dir)
WMT_metric_names = []
os.listdir(lang_dir)
for file in files:
    if '-refC.seg.score' in file:
        WMT_metric_names.append(file[:-15])
    elif '-src.seg.score' in file:
        WMT_metric_names.append(file[:-14])

In [58]:
# NOTE: These ones are not in WMT metrics: 'MATESE', 'MATESE-QE', 'MEE', 'MEE2', 'MEE4'
# The scores are also missing in the ACES 2022 scores

# NOTE: Also these from the ACES-2023 (but they are not listed in the metrics shared task paper anyway):
# 'Calibri-COMET22', 'Calibri-COMET22-QE'
# These don't exist in WMT yet: 'CometKiwi-XL', 'CometKiwi-XXL','GEMBA-MQM','MEE4','MEE4_stsb_xlm',
# 'MetricX-23', 'MetricX-23-QE','MetricX-23-QE-b','MetricX-23-QE-c','MetricX-23-b','MetricX-23-c','Random-sysname',
# and more..

METRIC_NAMES_MAPPING = {
    # For ACES 2022
    'COMET-QE-Baseline':'COMET-QE',
    # For ACES 2023
    'BERTscore':'BERTScore',
    'COMET':'COMET-22',
    'CometKiwi':'COMETKiwi',
    'MaTESe':'MATESE', 
}

In [55]:
set(ACES_scores_2023.keys()) - set(np.unique(WMT_metric_names))

{'BERTscore',
 'COMET',
 'Calibri-COMET22',
 'Calibri-COMET22-QE',
 'CometKiwi',
 'CometKiwi-XL',
 'CometKiwi-XXL',
 'GEMBA-MQM',
 'MEE4',
 'MEE4_stsb_xlm',
 'MaTESe',
 'MetricX-23',
 'MetricX-23-QE',
 'MetricX-23-QE-b',
 'MetricX-23-QE-c',
 'MetricX-23-b',
 'MetricX-23-c',
 'Random-sysname',
 'XCOMET-Ensemble',
 'XCOMET-QE-Ensemble',
 'XCOMET-XL',
 'XCOMET-XXL',
 'XLsim',
 'XLsimQE',
 'cometoid22-wmt21',
 'cometoid22-wmt22',
 'cometoid22-wmt23',
 'eBLEU',
 'embed_llama',
 'mt0regressor',
 'mtsamp-bleurt0p2p1-qe',
 'mtsamp-bleurtxv1p-qe',
 'partokengram_F',
 'prismRef',
 'prismSrc2',
 'spBLEU',
 'tokengram_F'}

In [57]:
print(list(set(ACES_scores_2023.keys())))

['mtsamp-bleurtxv1p-qe', 'tokengram_F', 'mtsamp-bleurt0p2p1-qe', 'CometKiwi', 'CometKiwi-XL', 'prismRef', 'BERTscore', 'MaTESe', 'cometoid22-wmt23', 'COMET', 'prismSrc2', 'BLEURT-20', 'MetricX-23-QE', 'BLEU', 'Random-sysname', 'KG-BERTScore', 'spBLEU', 'XCOMET-XL', 'CometKiwi-XXL', 'cometoid22-wmt21', 'embed_llama', 'YiSi-1', 'XLsimQE', 'XCOMET-XXL', 'Calibri-COMET22', 'mt0regressor', 'MetricX-23-QE-c', 'partokengram_F', 'MetricX-23', 'XCOMET-QE-Ensemble', 'cometoid22-wmt22', 'MetricX-23-QE-b', 'MetricX-23-b', 'chrF', 'MS-COMET-QE-22', 'eBLEU', 'MetricX-23-c', 'XCOMET-Ensemble', 'GEMBA-MQM', 'XLsim', 'Calibri-COMET22-QE', 'MEE4_stsb_xlm', 'MEE4']


In [10]:
set(ACES_scores_2022.keys())

{'BERTScore',
 'BLEU',
 'BLEURT-20',
 'COMET-20',
 'COMET-22',
 'COMET-QE-Baseline',
 'COMETKiwi',
 'Cross-QE',
 'HWTSC-TLM',
 'HWTSC-Teacher-Sim',
 'KG-BERTScore',
 'MATESE',
 'MATESE-QE',
 'MEE',
 'MEE2',
 'MEE4',
 'MS-COMET-22',
 'MS-COMET-QE-22',
 'REUSE',
 'UniTE',
 'UniTE-ref',
 'UniTE-src',
 'YiSi-1',
 'chrF',
 'f101spBLEU',
 'f200spBLEU',
 'metricx_xl_DA_2019',
 'metricx_xl_MQM_2020',
 'metricx_xxl_DA_2019',
 'metricx_xxl_MQM_2020'}

In [53]:
set(np.unique(WMT_metric_names))

{'BERTScore',
 'BLEU',
 'BLEURT-20',
 'COMET-20',
 'COMET-22',
 'COMET-QE',
 'COMETKiwi',
 'Cross-QE',
 'HWTSC-TLM',
 'HWTSC-Teacher-Sim',
 'KG-BERTScore',
 'MS-COMET-22',
 'MS-COMET-QE-22',
 'REUSE',
 'UniTE',
 'UniTE-ref',
 'UniTE-src',
 'YiSi-1',
 'chrF',
 'f101spBLEU',
 'f200spBLEU',
 'metricx_xl_DA_2019',
 'metricx_xl_MQM_2020',
 'metricx_xxl_DA_2019',
 'metricx_xxl_MQM_2020'}

# Main Part: Table and Correlations

In [61]:
# Load the ACES scores
ACES_scores_2022_path = '/home/arnisa/Desktop/work/metric_sensitivity_analysis/scored_aces_final.tsv'
ACES_scores_2023_path = '/home/arnisa/Desktop/work/metric_sensitivity_analysis/aces-scored-2023-all-scores.quote_errors_removed.tsv'
# 2023 later
ACES_scores_2022 = load_ACES_scores(ACES_scores_2022_path, good_token='.-good', bad_token='.-bad', metric_mapping=METRIC_NAMES_MAPPING)
ACES_scores_2023 = load_ACES_scores(ACES_scores_2023_path, good_token='.-good', bad_token='.-bad', metric_mapping=METRIC_NAMES_MAPPING)

# Load WMT scores
metrics_names = set(ACES_scores_2022.keys())
WMT_scores_path = '/home/arnisa/Desktop/work/metric_sensitivity_analysis/mt-metrics-eval-v2/wmt22/metric-scores'
WMT_scores = load_WMT_scores(WMT_scores_path, metrics_names)
metrics_names = list(set(ACES_scores_2022.keys()).intersection(set(WMT_scores.keys())))

# Get sensitivities (we only get mean(good-bad)) of the normalized scores
sensitivities, _, _, phenomena = calculate_sensitivities(ACES_scores_2022, WMT_scores, mapping=PHENOMENA_MAPPING)

# load the ACES scores from the paper
ACES_summary_2022 = load_ACES_scores_summary_2022()

en-uk


100%|██████████| 30/30 [00:00<00:00, 222.72it/s]


en-ja


100%|██████████| 30/30 [00:00<00:00, 158.93it/s]


liv-en


100%|██████████| 30/30 [00:00<00:00, 1485.81it/s]


uk-cs


100%|██████████| 30/30 [00:00<00:00, 141.04it/s]


zh-en


100%|██████████| 30/30 [00:00<00:00, 54.41it/s]


en-ru


100%|██████████| 30/30 [00:00<00:00, 98.59it/s]


de-fr


100%|██████████| 30/30 [00:00<00:00, 334.34it/s]


en-cs


100%|██████████| 30/30 [00:00<00:00, 112.93it/s]


en-de


100%|██████████| 30/30 [00:00<00:00, 61.70it/s]


fr-de


100%|██████████| 30/30 [00:00<00:00, 235.29it/s]


en-hr


100%|██████████| 30/30 [00:00<00:00, 170.61it/s]


ja-en


100%|██████████| 30/30 [00:00<00:00, 136.88it/s]


ru-sah


100%|██████████| 30/30 [00:00<00:00, 1304.96it/s]


en-zh


100%|██████████| 30/30 [00:00<00:00, 90.94it/s]


en-liv


100%|██████████| 30/30 [00:00<00:00, 1692.43it/s]


de-en


100%|██████████| 30/30 [00:00<00:00, 120.55it/s]


sah-ru


100%|██████████| 30/30 [00:00<00:00, 1485.15it/s]


cs-uk


100%|██████████| 30/30 [00:00<00:00, 198.33it/s]


uk-en


100%|██████████| 30/30 [00:00<00:00, 219.16it/s]


ru-en


100%|██████████| 30/30 [00:00<00:00, 184.63it/s]


cs-en


100%|██████████| 30/30 [00:00<00:00, 141.71it/s]


In [60]:
def load_ACES_scores(ACES_scores_path: str, good_token:str = '.-good', bad_token:str ='.-bad', metric_mapping:Dict[str, str] = METRIC_MAPPING) -> Dict[str, Dict[str, List[List]]]: 
    '''
    Return the metric scores for each phenomena and metric and samples
    format = {
        'BLEU': {
            'addition': [[1,2,3,…],	# list of scores for the good translations
                        [1,2,3,…]],	# list of scores for the bad translations
            , …, 'all': [[1,2,3,…],	# list of scores for the good translations for all phenomena
                        [1,2,3,…]]	# list of scores for the bad translations for all phenomena
            }
    }
    '''
    if not os.path.exists(ACES_scores_path):
        logger.error('No ACES scores path: %s' %(ACES_scores_path))
        exit()
    logger.info('Loading ACES the scores...')
    ACES_scores = read_file(ACES_scores_path)
    phenomena = np.unique(ACES_scores['phenomena'])

    ACES_metrics = []
    for key in ACES_scores.keys():
        if good_token in key:
            ACES_metrics.append(key[:-len(good_token)])
        elif bad_token in key:
            ACES_metrics.append(key[:-len(bad_token)])
    metrics_names = np.unique(ACES_metrics)

    if metric_mapping == None:
        metric_mapping = dict(zip(list(metrics_names), list(metrics_names)))
    else:
        for metric in metrics_names:
            if metric not in metric_mapping:
                metric_mapping[metric] = metric

    template = dict(zip(phenomena, np.empty((len(phenomena),2))))
    ACES_metrics = {}
    for metric in metric_mapping.values():
        ACES_metrics[metric] = copy.deepcopy(template)
    for p in phenomena:
        ids = np.where(ACES_scores['phenomena']==p)[0]
        for metric in metrics_names:
            if type(ACES_metrics[metric_mapping[metric]][p][0]) != list:
                ACES_metrics[metric_mapping[metric]][p] = [list(ACES_scores[metric+good_token][ids]), list(ACES_scores[metric+bad_token][ids])]
            else:     
                ACES_metrics[metric_mapping[metric]][p][0].extend(list(ACES_scores[metric+good_token][ids]))
                ACES_metrics[metric_mapping[metric]][p][1].extend(list(ACES_scores[metric+bad_token][ids])) 
    for metric in metrics_names:
        if "all" not in ACES_metrics[metric_mapping[metric]]:
            ACES_metrics[metric_mapping[metric]]["all"] = [list(ACES_scores[metric+good_token]), list(ACES_scores[metric+bad_token])]
        else:
            ACES_metrics[metric_mapping[metric]]["all"][0].extend(list(ACES_scores[metric+good_token]))
            ACES_metrics[metric_mapping[metric]]["all"][1].extend(list(ACES_scores[metric+bad_token])) 
    return ACES_metrics

In [64]:
ACES_scores_2022.keys()

dict_keys(['COMET-QE', 'BERTScore', 'COMET-22', 'COMETKiwi', 'MATESE', 'BLEU', 'BLEURT-20', 'COMET-20', 'Cross-QE', 'HWTSC-TLM', 'HWTSC-Teacher-Sim', 'KG-BERTScore', 'MATESE-QE', 'MEE', 'MEE2', 'MEE4', 'MS-COMET-22', 'MS-COMET-QE-22', 'REUSE', 'UniTE', 'UniTE-ref', 'UniTE-src', 'YiSi-1', 'chrF', 'f101spBLEU', 'f200spBLEU', 'metricx_xl_DA_2019', 'metricx_xl_MQM_2020', 'metricx_xxl_DA_2019', 'metricx_xxl_MQM_2020'])

In [65]:
# From the Paper:
METRICS_GROUPING_2022: {"baseline": ["BLEU", "spBLEU", "chrF", "BERTScore", "BLEURT-20",
                                     "COMET-20", "COMET-QE", "YISI-1"],
                        "reference-based": ["COMET-22", 'metricx_xl_DA_2019', 'metricx_xl_MQM_2020', 'metricx_xxl_DA_2019', 'metricx_xxl_MQM_2020',
                                            'MS-COMET-22', "UniTE", "UniTE-ref"],
                        "reference-free": ["COMETKiwi", "Cross-QE", 'HWTSC-Teacher-Sim', 'HWTSC-TLM',
                                           'KG-BERTScore', "MS-COMET-QE-22", "UniTE-src"]

                    }

In [66]:
sensitivities

{'COMET-22': {'mistranslation': 0.3888224376606538,
  'real-world knowledge': 0.2886188317537399,
  'undertranslation': 0.2965250375195492,
  'do not translate': 0.4270982024272212,
  'omission': 0.339629071269166,
  'overtranslation': 0.48378424720659197,
  'wrong language': -0.6934677117617782,
  'addition': 0.060463629648521644,
  'punctuation': 0.2744474024129998,
  'untranslated': 0.5161487508004713},
 'COMETKiwi': {'mistranslation': 0.5732563461813599,
  'real-world knowledge': 0.3392520839189057,
  'undertranslation': 0.3592402492974571,
  'do not translate': 0.2624675537257028,
  'omission': 0.4366934012988357,
  'overtranslation': 0.571782725379702,
  'wrong language': -0.5734240497846572,
  'addition': 0.12424218010601351,
  'punctuation': 0.22746049678056637,
  'untranslated': 0.6638297184083296},
 'COMET-QE': {'mistranslation': 0.06962277641161764,
  'real-world knowledge': 0.09709456900810254,
  'undertranslation': 0.19738572854782688,
  'do not translate': 0.0389508707864

In [None]:
def generate_summary_table(scores:Dict[str, Dict[str, Dict[str, int]]], metrics_groups:Dict[str,list] = METRICS_GROUPING_2022):
    

# Plots

In [5]:
# create groups here:
means = {metric:[sensitivities[metric][p] for p in phenomena] for metric in metrics_names}
tau = {metric:[ACES_summary_2022[metric][p] for p in phenomena] for metric in metrics_names if metric in ACES_summary_2022}

In [15]:
print(metrics_names)

['COMET-20', 'YiSi-1', 'metricx_xxl_MQM_2020', 'UniTE', 'MATESE', 'f200spBLEU', 'metricx_xl_DA_2019', 'UniTE-src', 'MEE', 'KG-BERTScore', 'BERTScore', 'BLEU', 'Cross-QE', 'chrF', 'MEE4', 'MATESE-QE', 'HWTSC-Teacher-Sim', 'MEE2', 'REUSE', 'MS-COMET-QE-22', 'metricx_xl_MQM_2020', 'MS-COMET-22', 'HWTSC-TLM', 'COMET-22', 'metricx_xxl_DA_2019', 'UniTE-ref', 'f101spBLEU', 'BLEURT-20', 'COMETKiwi']


In [23]:
# Group the metrics:
bleus = ["BLEU", "f101spBLEU", "f200spBLEU"]
comets = ["COMET-20", "COMET-22", "MS-COMET-22", "MS-COMET-QE-22"]
xl = ["metricx_xl_DA_2019", "metricx_xl_MQM_2020", "metricx_xxl_DA_2019", "metricx_xxl_MQM_2020"]
unite = ["UniTE", "UniTE-src", "UniTE-ref"]

In [41]:
grouped_line_plot([means, tau], comets, ["Mean(good-bad)", "tau"], phenomena)

In [38]:
def grouped_line_plot(groups: List[Dict[str,list]], metrics_names: List[str], group_labels: List[str], phenomena: List[str]):
    '''
    Inputs: 
        1. means and tau scores
        format = {
            metric1: [score for phenomenon 1, score for phenomenon 2, ..]
        }
        2. A list of the labels for: 
            the groups (mean (good-bad), tau, ...)
            metrics
            phenomena (the order is important because in means and tau scores the scores are ordered acc. to phenomena)
    '''
    assert len(groups) > 0 and len(groups) == len(group_labels) and len(metrics_names) > 0
    fig = go.Figure()
    colors = [['lightsteelblue',  'aqua', 'aquamarine', 'darkturquoise'],
        ['chocolate', 'coral', 'crimson', 'orange']]
    for i,group in enumerate(groups):
        for j,metric in enumerate(metrics_names):
            fig.add_trace(go.Scatter(x=phenomena, y=group[metric],mode='lines',name=group_labels[i]+" - "+metric,
                          line=dict(color=colors[i][j])))
    fig.update_layout(
        title=" ".join(group_labels)
        )
    fig.show()