In [1]:
#!/usr/bin/env python3
import os, sys
from normalize_scores_utils import *

sys.path.append(os.path.abspath(os.path.join(os.getcwd(), "..")))
from aces.cli.evaluate import comp_aces_score

# Main Part: Table and Correlations

In [53]:
# Load the ACES scores
ACES_scores_2022_path = '/home/arnisa/Desktop/work/metric_sensitivity_analysis/scored_aces_final.tsv'
ACES_scores_2023_path = '/home/arnisa/Desktop/work/metric_sensitivity_analysis/aces-scored-2023-all-scores.quote_errors_removed.tsv'
# 2023 later
ACES_scores_2022 = load_ACES_scores(ACES_scores_2022_path, good_token='.-good', bad_token='.-bad', metric_mapping=METRIC_NAMES_MAPPING)
ACES_scores_2023 = load_ACES_scores(ACES_scores_2023_path, good_token='.-good', bad_token='.-bad', metric_mapping=METRIC_NAMES_MAPPING, skip_metrics=[])

# Load WMT scores
metrics_names = set(ACES_scores_2023.keys())
WMT_scores_path = '/home/arnisa/Desktop/work/metric_sensitivity_analysis/mt-metrics-eval-v2/wmt22/metric-scores'
WMT_scores = load_WMT_scores(WMT_scores_path, metrics_names)
metrics_names = list(set(ACES_scores_2023.keys()).intersection(set(WMT_scores.keys())))

# Get sensitivities (we only get mean(good-bad)) of the normalized scores
sensitivities, _, _, phenomena = calculate_sensitivities(ACES_scores_2023, WMT_scores, mapping=PHENOMENA_MAPPING)

# load the ACES scores from the paper
ACES_summary_2022 = load_ACES_scores_summary_2022()
ACES_summary_2023 = load_ACES_scores_summary_2023(skip_metrics=[])

en-uk


100%|██████████| 58/58 [00:00<00:00, 462.03it/s]


en-ja


100%|██████████| 58/58 [00:00<00:00, 336.60it/s]


liv-en


100%|██████████| 58/58 [00:00<00:00, 3509.98it/s]


uk-cs


100%|██████████| 58/58 [00:00<00:00, 289.21it/s]


zh-en


100%|██████████| 58/58 [00:00<00:00, 133.71it/s]


en-ru


100%|██████████| 58/58 [00:00<00:00, 234.13it/s]


de-fr


100%|██████████| 58/58 [00:00<00:00, 534.53it/s]


en-cs


100%|██████████| 58/58 [00:00<00:00, 227.78it/s]


en-de


100%|██████████| 58/58 [00:00<00:00, 140.07it/s]


fr-de


100%|██████████| 58/58 [00:00<00:00, 512.95it/s]


en-hr


100%|██████████| 58/58 [00:00<00:00, 346.41it/s]


ja-en


100%|██████████| 58/58 [00:00<00:00, 283.33it/s]


ru-sah


100%|██████████| 58/58 [00:00<00:00, 3466.07it/s]


en-zh


100%|██████████| 58/58 [00:00<00:00, 203.74it/s]


en-liv


100%|██████████| 58/58 [00:00<00:00, 3893.50it/s]


de-en


100%|██████████| 58/58 [00:00<00:00, 240.14it/s]


sah-ru


100%|██████████| 58/58 [00:00<00:00, 3315.84it/s]


cs-uk


100%|██████████| 58/58 [00:00<00:00, 402.95it/s]


uk-en


100%|██████████| 58/58 [00:00<00:00, 452.78it/s]


ru-en


100%|██████████| 58/58 [00:00<00:00, 361.32it/s]


cs-en


100%|██████████| 58/58 [00:00<00:00, 291.70it/s]


In [54]:
# From the ACES 2022 Paper:
METRICS_GROUPING_2022= {"baseline": ["BLEU", "f101spBLEU", "f200spBLEU", "chrF", "BERTScore", "BLEURT-20",
                                     "COMET-20", "COMET-QE", "YISI-1"],
                        "reference-based": ["COMET-22", 'metricx_xl_DA_2019', 'metricx_xl_MQM_2020', 'metricx_xxl_DA_2019', 'metricx_xxl_MQM_2020',
                                            'MS-COMET-22', "UniTE", "UniTE-ref"],
                        "reference-free": ["COMETKiwi", "Cross-QE", 'HWTSC-Teacher-Sim', 'HWTSC-TLM',
                                           'KG-BERTScore', "MS-COMET-QE-22", "UniTE-src"]

                    }

METRICS_GROUPING_2023 = {
    "baseline": ["BERTScore", "BLEU", "BLEURT-20", "chrF", "COMET-22", "COMETKiwi", 
                 "f200spBLEU", "MS-COMET-QE-22", "Random-sysname", "YiSi-1"],
    "reference-based": ["eBLEU", "embed_llama", "MaTESe", "MetricX-23", "MetricX-23-b", 
                        "MetricX-23-c", "partokengram_F", "tokengram_F", "XCOMET-Ensemble",
                        "XCOMET-XL", "XCOMET-XXL", "XLsim"],
    "reference-free": ["cometoid22-wmt21", "cometoid22-wmt22", "cometoid22-wmt23", "CometKiwi-XL", 
                       "CometKiwi-XXL", "GEMBA-MQM", "KG-BERTScore", "MetricX-23-QE", "MetricX-23-QE-b", 
                       "MetricX-23-QE-c", "XCOMET-QE-Ensemble", "XLsimQE"]                           
}

METRIC_MAPPING_BACK = {'YISI-1':'YiSi-1'}

In [75]:
import decimal
import re

def format_number(number:float, max_phenomena:bool = False, dec:str = '0.000') -> str:
    if number == -np.inf:
        return '----'
    if number < 0:
        out = '-' + str(decimal.Decimal(-number).quantize(decimal.Decimal(dec)))
    else:
        out = '\phantom{-}' + str(decimal.Decimal(number).quantize(decimal.Decimal(dec)))
    if max_phenomena:
        return '\colorbox[HTML]{B2EAB1}{\\textbf{' + out + '}}'
    else:
        return out
    
def format_metric(metric:str):
    return re.sub(r'_', '\\_', metric)

def find_max_on_col(scores:Dict[str, Dict[str, Dict[str, int]]], metrics_names:List[str] = metrics_names) -> Dict[str,str]:
    max_metrics = [[] for metric in metrics_names]
    avgs = []
    for i,p in enumerate(PHENOMENA):
        col = []
        for metric in metrics_names:
            if metric not in scores and metric not in METRIC_MAPPING_BACK:
                col.append(-np.inf)
            else:
                if metric not in scores:           
                    metric = METRIC_MAPPING_BACK[metric]
                else:
                    col.append(sensitivities[metric][p])
        max_ids = np.where(col == np.max(col))[0]
        for max_id in max_ids:
            max_metrics[max_id].append(i)
        avgs.append(np.average(col))
    return max_metrics, avgs

def generate_summary_table(scores:Dict[str, Dict[str, Dict[str, int]]], metrics_groups:Dict[str,list] = METRICS_GROUPING_2022):
    out = ''
    metrics_names = []
    for group in metrics_groups.values():
        metrics_names.extend(group)
    max_in_columns, avgs = find_max_on_col(scores, metrics_names=metrics_names)

    aces_scores_col = []
    for group, metrics in metrics_groups.items():
        for metric in metrics:
            row = {}
            for p_id, p in enumerate(PHENOMENA):
                if metric not in scores and metric not in METRIC_MAPPING_BACK:
                    out += '&\t ---- \t' 
                    row[p] = 0.0
                else:
                    if metric not in scores:           
                        metric = METRIC_MAPPING_BACK[metric]
                    row[p] = scores[metric][p]
            aces_scores_col.append(comp_aces_score(row))
    max_aces_ids = np.where(list(aces_scores_col) == np.max(aces_scores_col))[0]

    m_id = 0
    for group, metrics in metrics_groups.items():
        for metric in metrics:
            out += format_metric(metric) + '\t\t\t\t\t'
            
            for p_id, p in enumerate(PHENOMENA):
                if metric not in scores and metric not in METRIC_MAPPING_BACK:
                    out += '&\t ---- \t' 
                else:
                    if metric not in scores:           
                        metric = METRIC_MAPPING_BACK[metric]
                        # print(p, metric, )
                    # print(p, metric, scores[metric][p])
                    if metric in metrics_names:
                        max_ids = max_in_columns[metrics_names.index(metric)]
                        out += '&\t' + format_number(scores[metric][p], max_phenomena=p_id in max_ids) + '\t'   
                    elif metric in METRIC_NAMES_MAPPING:
                        max_ids = max_in_columns[metrics_names.index(METRIC_NAMES_MAPPING[metric])]
                        out += '&\t' + format_number(scores[metric][p], max_phenomena=p_id in max_ids) + '\t'   
                    else:
                        out += '&\t ---- \t' 
                                 
            out += '&\t' + format_number(aces_scores_col[m_id], dec='0.00', max_phenomena=m_id in max_aces_ids) + '\t \\\\ \n'
            m_id += 1
        out += '\midrule \n'
    out += 'Average\t\t\t\t\t'
    for p_id, p in enumerate(PHENOMENA):
        out += '&\t' + format_number(avgs[p_id], max_phenomena=False) + '\t'
    out += '\\\\ \n\\bottomrule\t\t\t\t\t'
    return out

In [76]:
METRIC_MAPPING_BACK = {'YISI-1':'YiSi-1', 'BERTScore':'BERTscore', 'COMETKiwi':'CometKiwi', 'MaTESe':'MATESE'}

print(generate_summary_table(sensitivities, metrics_groups=METRICS_GROUPING_2023))

&	 ---- 	&	 ---- 	&	 ---- 	&	 ---- 	&	 ---- 	&	 ---- 	&	 ---- 	&	 ---- 	&	 ---- 	&	 ---- 	&	 ---- 	&	 ---- 	&	 ---- 	&	 ---- 	&	 ---- 	&	 ---- 	&	 ---- 	&	 ---- 	&	 ---- 	&	 ---- 	&	 ---- 	&	 ---- 	&	 ---- 	&	 ---- 	&	 ---- 	&	 ---- 	&	 ---- 	&	 ---- 	&	 ---- 	&	 ---- 	&	 ---- 	&	 ---- 	&	 ---- 	&	 ---- 	&	 ---- 	&	 ---- 	&	 ---- 	&	 ---- 	&	 ---- 	&	 ---- 	&	 ---- 	&	 ---- 	&	 ---- 	&	 ---- 	&	 ---- 	&	 ---- 	&	 ---- 	&	 ---- 	&	 ---- 	&	 ---- 	&	 ---- 	&	 ---- 	&	 ---- 	&	 ---- 	&	 ---- 	&	 ---- 	&	 ---- 	&	 ---- 	&	 ---- 	&	 ---- 	&	 ---- 	&	 ---- 	&	 ---- 	&	 ---- 	&	 ---- 	&	 ---- 	&	 ---- 	&	 ---- 	&	 ---- 	&	 ---- 	&	 ---- 	&	 ---- 	&	 ---- 	&	 ---- 	&	 ---- 	&	 ---- 	&	 ---- 	&	 ---- 	&	 ---- 	&	 ---- 	&	 ---- 	&	 ---- 	&	 ---- 	&	 ---- 	&	 ---- 	&	 ---- 	&	 ---- 	&	 ---- 	&	 ---- 	&	 ---- 	&	 ---- 	&	 ---- 	&	 ---- 	&	 ---- 	&	 ---- 	&	 ---- 	&	 ---- 	&	 ---- 	&	 ---- 	&	 ---- 	&	 ---- 	&	 ---- 	&	 ---- 	&	 ---- 	&	 ---- 	&	 ---- 	&	 ---- 	&	 ---- 	&	 ---- 	&	 ---- 	&	 ---- 	&

In [63]:
'Random-sysname' in sensitivities

False

In [8]:
ACES_scores_2023.keys()

dict_keys(['COMET-QE', 'BERTScore', 'COMET-22', 'COMETKiwi', 'BLEU', 'BLEURT-20', 'COMET-20', 'Cross-QE', 'HWTSC-TLM', 'HWTSC-Teacher-Sim', 'KG-BERTScore', 'MS-COMET-22', 'MS-COMET-QE-22', 'UniTE', 'UniTE-ref', 'UniTE-src', 'YiSi-1', 'chrF', 'f101spBLEU', 'f200spBLEU', 'metricx_xl_DA_2019', 'metricx_xl_MQM_2020', 'metricx_xxl_DA_2019', 'metricx_xxl_MQM_2020', 'Calibri-COMET22', 'Calibri-COMET22-QE', 'CometKiwi-XL', 'CometKiwi-XXL', 'GEMBA-MQM', 'MEE4_stsb_xlm', 'MetricX-23', 'MetricX-23-QE', 'MetricX-23-QE-b', 'MetricX-23-QE-c', 'MetricX-23-b', 'MetricX-23-c', 'Random-sysname', 'XCOMET-Ensemble', 'XCOMET-QE-Ensemble', 'XCOMET-XL', 'XCOMET-XXL', 'XLsim', 'XLsimQE', 'cometoid22-wmt21', 'cometoid22-wmt22', 'cometoid22-wmt23', 'eBLEU', 'embed_llama', 'mt0regressor', 'mtsamp-bleurt0p2p1-qe', 'mtsamp-bleurtxv1p-qe', 'partokengram_F', 'prismRef', 'prismSrc2', 'spBLEU', 'tokengram_F'])

# Plots

In [5]:
# create groups here:
means = {metric:[sensitivities[metric][p] for p in phenomena] for metric in metrics_names}
tau = {metric:[ACES_summary_2022[metric][p] for p in phenomena] for metric in metrics_names if metric in ACES_summary_2022}

In [15]:
print(metrics_names)

['COMET-20', 'YiSi-1', 'metricx_xxl_MQM_2020', 'UniTE', 'MATESE', 'f200spBLEU', 'metricx_xl_DA_2019', 'UniTE-src', 'MEE', 'KG-BERTScore', 'BERTScore', 'BLEU', 'Cross-QE', 'chrF', 'MEE4', 'MATESE-QE', 'HWTSC-Teacher-Sim', 'MEE2', 'REUSE', 'MS-COMET-QE-22', 'metricx_xl_MQM_2020', 'MS-COMET-22', 'HWTSC-TLM', 'COMET-22', 'metricx_xxl_DA_2019', 'UniTE-ref', 'f101spBLEU', 'BLEURT-20', 'COMETKiwi']


In [23]:
# Group the metrics:
bleus = ["BLEU", "f101spBLEU", "f200spBLEU"]
comets = ["COMET-20", "COMET-22", "MS-COMET-22", "MS-COMET-QE-22"]
xl = ["metricx_xl_DA_2019", "metricx_xl_MQM_2020", "metricx_xxl_DA_2019", "metricx_xxl_MQM_2020"]
unite = ["UniTE", "UniTE-src", "UniTE-ref"]

In [41]:
grouped_line_plot([means, tau], comets, ["Mean(good-bad)", "tau"], phenomena)

In [38]:
def grouped_line_plot(groups: List[Dict[str,list]], metrics_names: List[str], group_labels: List[str], phenomena: List[str]):
    '''
    Inputs: 
        1. means and tau scores
        format = {
            metric1: [score for phenomenon 1, score for phenomenon 2, ..]
        }
        2. A list of the labels for: 
            the groups (mean (good-bad), tau, ...)
            metrics
            phenomena (the order is important because in means and tau scores the scores are ordered acc. to phenomena)
    '''
    assert len(groups) > 0 and len(groups) == len(group_labels) and len(metrics_names) > 0
    fig = go.Figure()
    colors = [['lightsteelblue',  'aqua', 'aquamarine', 'darkturquoise'],
        ['chocolate', 'coral', 'crimson', 'orange']]
    for i,group in enumerate(groups):
        for j,metric in enumerate(metrics_names):
            fig.add_trace(go.Scatter(x=phenomena, y=group[metric],mode='lines',name=group_labels[i]+" - "+metric,
                          line=dict(color=colors[i][j])))
    fig.update_layout(
        title=" ".join(group_labels)
        )
    fig.show()