In [1]:
#!/usr/bin/env python3
import os, sys
# NOTE(review): later cells use many names they never define (load_ACES_scores,
# calculate_sensitivities, PHENOMENA, METRIC_NAMES_MAPPING, np, copy, go, Dict, List, ...).
# Presumably they all arrive through this star import -- confirm against
# normalize_scores_utils and prefer explicit imports.
from normalize_scores_utils import *

import logging
# Notebook-wide logger with its threshold set to INFO.
logger = logging.getLogger('logger')
logger.setLevel(logging.INFO)

# Main Part: Table and Correlations

In [None]:
# Root folder holding the ACES score files and the WMT metric-score folders.
# Set the METRIC_SENSITIVITY_DATA environment variable to point somewhere else;
# the fallback is the path this notebook was originally written against.
data_folder = os.environ.get(
    'METRIC_SENSITIVITY_DATA',
    '/mnt/c/Users/user/Desktop/work/metric_sensitivity_analysis',
)

# Load the ACES scores for both years
ACES_scores_2022_path = os.path.join(data_folder, 'aces-scored-2022-all-scores.only.quote_errors_removed.tsv')
ACES_scores_2023_path = os.path.join(data_folder, 'aces-scored-2023-all-scores.quote_errors_removed.tsv')
ACES_scores_2022 = load_ACES_scores(ACES_scores_2022_path, good_token='.-good', bad_token='.-bad', mapping=METRIC_NAMES_MAPPING)
ACES_scores_2023 = load_ACES_scores(ACES_scores_2023_path, good_token='.-good', bad_token='.-bad', mapping=METRIC_NAMES_MAPPING, skip_metrics=[])
metrics_names_2022 = set(ACES_scores_2022.keys())
metrics_names_2023 = set(ACES_scores_2023.keys())

# Both WMT loaders receive the metrics that occur in either ACES year.
# Each call gets its own copy of the set, as in the original code.
all_ACES_metrics_names = metrics_names_2022.union(metrics_names_2023)

# Load WMT22 metric scores
WMT22_scores_path = os.path.join(data_folder, 'WMT22-metric-scores')
WMT22_scores = load_WMT_scores(WMT22_scores_path, set(all_ACES_metrics_names))

# Load WMT23 metric scores
WMT23_scores_path = os.path.join(data_folder, 'wmt23metrics-submissions-v2')
WMT23_scores = load_WMT_scores_23(WMT23_scores_path, set(all_ACES_metrics_names))

# calculate sensitivities; from here on metrics_names_<year> lists only the metrics
# that have both ACES scores and WMT scores for that year
metrics_names_2022 = list(set(ACES_scores_2022.keys()).intersection(set(WMT22_scores.keys())))
sensitivities_2022, _, _, phenomena_2022, means_good_2022, means_bad_2022 = calculate_sensitivities(ACES_scores_2022, WMT22_scores, mapping=PHENOMENA_MAPPING)
metrics_names_2023 = list(set(ACES_scores_2023.keys()).intersection(set(WMT23_scores.keys())))
sensitivities_2023, _, _, phenomena_2023, means_good_2023, means_bad_2023 = calculate_sensitivities(ACES_scores_2023, WMT23_scores, mapping=PHENOMENA_MAPPING)

# load the ACES scores from the paper
ACES_summary_2022 = load_ACES_scores_summary_2022()
ACES_summary_2023 = load_ACES_scores_summary_2023(skip_metrics=[])

cs-en


100%|██████████| 57/57 [00:05<00:00, 10.27it/s]


cs-uk


100%|██████████| 57/57 [00:03<00:00, 18.20it/s]


de-en


100%|██████████| 57/57 [00:06<00:00,  8.69it/s]


de-fr


100%|██████████| 57/57 [00:04<00:00, 13.09it/s]


en-cs


100%|██████████| 57/57 [00:09<00:00,  5.73it/s]


en-de


100%|██████████| 57/57 [00:12<00:00,  4.60it/s]


en-hr


100%|██████████| 57/57 [00:07<00:00,  7.13it/s]


en-ja


100%|██████████| 57/57 [00:07<00:00,  7.65it/s]


en-liv


100%|██████████| 57/57 [00:02<00:00, 23.12it/s]


en-ru


100%|██████████| 57/57 [00:10<00:00,  5.57it/s]


en-uk


100%|██████████| 57/57 [00:06<00:00,  8.90it/s]


en-zh


100%|██████████| 57/57 [00:09<00:00,  5.85it/s]


fr-de


100%|██████████| 57/57 [00:03<00:00, 18.91it/s]


ja-en


100%|██████████| 57/57 [00:05<00:00,  9.90it/s]


liv-en


100%|██████████| 57/57 [00:02<00:00, 24.14it/s]


ru-en


 61%|██████▏   | 35/57 [00:02<00:01, 13.74it/s]

In [None]:
# NOTE: marked by the author as not working -- to be discarded later.
# NOTE(review): no other cell in the visible notebook reads sensitivities_unscaled_2022/2023.
sensitivities_unscaled_2022,  _ = calculate_sensitivities_self_scaled(ACES_scores_2022, mapping=PHENOMENA_MAPPING)
sensitivities_unscaled_2023,  _ = calculate_sensitivities_self_scaled(ACES_scores_2023, mapping=PHENOMENA_MAPPING)

In [None]:
# LaTeX colour boxes, strongest green to strongest red with white in the middle.
# Raw strings keep the backslash literal: '\c' is not a valid escape sequence and the
# non-raw form triggers a warning on recent Python versions.
# (An 11-step palette that also had green3/red3 used to be assigned here and was
# immediately overwritten; only this 9-step palette was ever in effect.)
COLORS = [r'\colorbox{green1}', r'\colorbox{green2}', r'\colorbox{green4}', r'\colorbox{green5}', r'\colorbox{white}', r'\colorbox{red1}', r'\colorbox{red2}', r'\colorbox{red4}', r'\colorbox{red5}']

# From the ACES 2022 Paper:
# NOTE(review): neither short grouping is used by the print calls below, which pass
# METRICS_GROUPING_2023 -- keep or drop deliberately.
METRICS_GROUPING_SHORT_2022 = {
    "baseline": ["BLEU", "COMET-20", "COMET-QE"],
    "reference-based": ["COMET-22", 'metricx_xl_MQM_2020', "UniTE", "UniTE-ref"],
    "reference-free": ["COMETKiwi", 'KG-BERTScore', "UniTE-src"],
}
METRICS_GROUPING_SHORT_2023 = {
    "baseline": ["BLEU", "BERTScore", "BLEURT-20", "YISI-1"],
    "reference-based": ["COMET-22", 'metricx_xl_MQM_2020', "UniTE", "UniTE-ref"],
    "reference-free": ["COMETKiwi", 'KG-BERTScore', "UniTE-src"],
}

# NOTE(review): this subset is not passed to the calls below, which use PHENOMENA.
phenomena = ['omission', 'mistranslation', 'untranslated', 'real-world knowledge', 'wrong language']
# phenomena = ["hallucination-number-level-1", "hallucination-number-level-2", "hallucination-number-level-3"]

# Print the LaTeX header and body of the 2023 sensitivity table.
print(make_header(scores=ACES_scores_2023, ACES_column=True, p_header_1=PHENOMENA_HEADER_1, p_header_2=PHENOMENA_HEADER_2))
print(generate_summary_table(sensitivities_2023, metrics_groups=METRICS_GROUPING_2023, phenomena=PHENOMENA, ACES_column=True))
# print(make_footer(averages=SUMMARY_AVERAGES_2023, phenomena=PHENOMENA))

\begin{table*}[ht] 
 \small 
 \setlength{\tabcolsep}{3.75pt} 
 \centering 
 \begin{tabular}{@{}lccccccccccc@{}} 
 \\\toprule 
 & \hyperref[sec:addition-omission]{\textbf{addition}} & \hyperref[sec:addition-omission]{\textbf{omission}} & \hyperref[sec:source-disambig]{\textbf{mistranslation}} & \hyperref[sec:untranslated]{\textbf{untranslated}} & \hyperref[sec:do-not-translate]{\textbf{do not}} & \hyperref[sec:overtranslation_undertranslation]{\textbf{overtranslation}} & \hyperref[sec:overtranslation_undertranslation]{\textbf{undertranslation}} & \hyperref[sec:real-world-knowledge]{\textbf{real-world}} & \hyperref[sec:wrong_language]{\textbf{wrong}} & \hyperref[sec:punctuation]{\textbf{punctuation}} & \textbf{ACES-} \\
 &  &  &  &  & \hyperref[sec:do-not-translate]{\textbf{translate}} &  &  & \hyperref[sec:real-world-knowledge]{\textbf{knowledge}} & \hyperref[sec:wrong_language]{\textbf{language}} &  & \textbf{Score}\\
\midrule
\textit{\textbf{Examples}}  & \textit{931} & \textit{951} &

In [7]:
"BERTScore" in WMT23_scores.keys()

False

# Table for Mistranslation Groups - discourse, hallucinations, other

In [3]:
# Maps each fine-grained mistranslation phenomenon to one of three coarse groups:
# 'discourse', 'hallucination' or 'other'.
# The two 'ambiguous-translation-wrong-sense-*' phenomena are commented out, so they
# are not assigned to any group.
PHENOMENA_MAPPING_MISTRANSLATION = {
    'ambiguous-translation-wrong-discourse-connective-since-causal': 'discourse',
    'ambiguous-translation-wrong-discourse-connective-since-temporal': 'discourse',
    'ambiguous-translation-wrong-discourse-connective-while-contrast': 'discourse',
    'ambiguous-translation-wrong-discourse-connective-while-temporal': 'discourse',
    'ambiguous-translation-wrong-gender-female-anti': 'other',
    'ambiguous-translation-wrong-gender-female-pro': 'other',
    'ambiguous-translation-wrong-gender-male-anti': 'other',
    'ambiguous-translation-wrong-gender-male-pro': 'other',
    # 'ambiguous-translation-wrong-sense-frequent': 'other',
    # 'ambiguous-translation-wrong-sense-infrequent': 'other',
    'anaphoric_group_it-they:deletion': 'discourse',
    'anaphoric_group_it-they:substitution': 'discourse',
    'anaphoric_intra_non-subject_it:deletion': 'discourse',
    'anaphoric_intra_non-subject_it:substitution': 'discourse',
    'anaphoric_intra_subject_it:deletion': 'discourse',
    'anaphoric_intra_subject_it:substitution': 'discourse',
    'anaphoric_intra_they:deletion': 'discourse',
    'anaphoric_intra_they:substitution': 'discourse',
    'anaphoric_singular_they:deletion': 'discourse',
    'anaphoric_singular_they:substitution': 'discourse',
    'coreference-based-on-commonsense': 'discourse',
    'hallucination-date-time': 'hallucination',
    'hallucination-named-entity-level-1': 'hallucination',
    'hallucination-named-entity-level-2': 'hallucination',
    'hallucination-named-entity-level-3': 'hallucination',
    'hallucination-number-level-1': 'hallucination',
    'hallucination-number-level-2': 'hallucination',
    'hallucination-number-level-3': 'hallucination',
    'hallucination-real-data-vs-ref-word': 'hallucination',
    'hallucination-real-data-vs-synonym': 'hallucination',
    'hallucination-unit-conversion-amount-matches-ref': 'hallucination',
    'hallucination-unit-conversion-unit-matches-ref': 'hallucination',
    'lexical-overlap': 'other',
    'modal_verb:deletion': 'other',
    'modal_verb:substitution': 'other',
    'nonsense': 'hallucination',
    'ordering-mismatch': 'other',
    'overly-literal-vs-correct-idiom': 'other',
    'overly-literal-vs-explanation': 'other',
    'overly-literal-vs-ref-word': 'other',
    'overly-literal-vs-synonym': 'other',
    'pleonastic_it:deletion': 'discourse',
    'pleonastic_it:substitution': 'discourse',
    'xnli-addition-contradiction': 'other',
    'xnli-addition-neutral': 'other',
    'xnli-omission-contradiction': 'other',
    'xnli-omission-neutral': 'other'
}

In [9]:
# The WMT22 scores are loaded as WMT22_scores; the older name WMT_scores is not
# defined anywhere in this notebook, so it only resolved through stale kernel state.
metrics_names_2022 = list(set(ACES_scores_2022.keys()).intersection(set(WMT22_scores.keys())))
# Sensitivities computed per coarse mistranslation group instead of per top-level phenomenon.
sensitivities_2022_mistranslation, _, _, phenomena_2022_mistranslation, _, _ = calculate_sensitivities(ACES_scores_2022, WMT22_scores, mapping=PHENOMENA_MAPPING_MISTRANSLATION)
# Tau correlations for the same groups, computed on the ACES scores after map_to_higher.
ACES_scores_2022_mistranslation = map_to_higher(ACES_scores_2022, PHENOMENA_MAPPING_MISTRANSLATION)
ACES_summary_2022_mistranslation, phenomena_mistranslation = calculate_tau_correlations(ACES_scores_2022_mistranslation, phenomena=set(PHENOMENA_MAPPING_MISTRANSLATION.values()))

In [10]:
# LaTeX colour boxes, strongest green to strongest red with white in the middle.
# Raw strings keep the backslash literal ('\c' is not a valid escape sequence).
# The 11-step palette that used to be assigned first was overwritten immediately.
COLORS = [r'\colorbox{green1}', r'\colorbox{green2}', r'\colorbox{green4}', r'\colorbox{green5}', r'\colorbox{white}', r'\colorbox{red1}', r'\colorbox{red2}', r'\colorbox{red4}', r'\colorbox{red5}']

# From the ACES 2022 Paper:
# NOTE(review): the short groupings are not used below, which passes METRICS_GROUPING_2022.
METRICS_GROUPING_SHORT_2022 = {
    "baseline": ["BLEU", "COMET-20", "COMET-QE"],
    "reference-based": ["COMET-22", 'metricx_xl_MQM_2020', "UniTE", "UniTE-ref"],
    "reference-free": ["COMETKiwi", 'KG-BERTScore', "UniTE-src"],
}
METRICS_GROUPING_SHORT_2023 = {
    "baseline": ["BLEU", "BERTScore", "BLEURT-20", "YISI-1"],
    "reference-based": ["COMET-22", 'metricx_xl_MQM_2020', "UniTE", "UniTE-ref"],
    "reference-free": ["COMETKiwi", 'KG-BERTScore', "UniTE-src"],
}

# Columns of this table: the three coarse groups of PHENOMENA_MAPPING_MISTRANSLATION.
phenomena = ["discourse", "hallucination", "other"]

PHENOMENA_HEADER_1_MISTRANSLATION = dict(zip(phenomena, ['\\textbf{disco.}', '\\textbf{halluci.}', '\\textbf{other}']))
PHENOMENA_HEADER_2_MISTRANSLATION = dict(zip(phenomena, ['', '',  '']))

# Examples per group: add up, over the fine-grained phenomena of each group, the
# length of the first element stored for BLEU.
NUM_SAMPLES_MISTRANSLATION = {p:0 for p in phenomena}
for p, target in PHENOMENA_MAPPING_MISTRANSLATION.items():
    NUM_SAMPLES_MISTRANSLATION[target] += len(ACES_scores_2022["BLEU"][p][0])

print(make_header(phenomena=phenomena, ACES_column=False, p_header_1=PHENOMENA_HEADER_1_MISTRANSLATION, p_header_2=PHENOMENA_HEADER_2_MISTRANSLATION, num_samples=NUM_SAMPLES_MISTRANSLATION))
print(generate_summary_table_double(ACES_summary_2022_mistranslation, sensitivities_2022_mistranslation, metrics_groups=METRICS_GROUPING_2022, phenomena=phenomena, ACES_column=False))
# print(make_footer(averages=SUMMARY_AVERAGES_2023, phenomena=PHENOMENA))

\begin{table*}[ht] 
 \small 
 \setlength{\tabcolsep}{3.75pt} 
 \centering 
 \begin{tabular}{@{}lccccccccccc@{}} 
 \\\toprule 
 & \textbf{disco.} & \textbf{halluci.} & \textbf{other} \\
 &  &  & \\
\midrule
\textit{\textbf{Examples}}  & \textit{3623} & \textit{9634} & \textit{9273}\\ 
 \midrule
BLEU					&	\colorbox{white}{\textbf{\phantom{-}0.142}}	&	\colorbox{red2}{\textbf{-0.497}}	&	\colorbox{red1}{\textbf{-0.187}}	&	\colorbox{white}{\textbf{-0.057}}	&	\colorbox{red2}{\textbf{-0.561}}	&	\colorbox{green1}{\textbf{\phantom{-}0.414}}		 \\ 
f101spBLEU					&	\colorbox{green1}{\textbf{\phantom{-}0.175}}	&	\colorbox{red1}{\textbf{-0.312}}	&	\colorbox{red1}{\textbf{-0.206}}	&	\colorbox{white}{\textbf{-0.042}}	&	\colorbox{red1}{\textbf{-0.373}}	&	\colorbox{green1}{\textbf{\phantom{-}0.231}}		 \\ 
f200spBLEU					&	\colorbox{white}{\textbf{\phantom{-}0.153}}	&	\colorbox{red1}{\textbf{-0.296}}	&	\colorbox{red1}{\textbf{-0.205}}	&	\colorbox{white}{\textbf{-0.042}}	&	\colorbox{red1}{\textbf{-0.364}

# Hallucination -Numbers double table

In [14]:
# Identity mapping: the three number-hallucination levels stay separate phenomena.
PHENOMENA_MAPPING_NUMBERS = {
    'hallucination-number-level-1': 'hallucination-number-level-1',
    'hallucination-number-level-2': 'hallucination-number-level-2',
    'hallucination-number-level-3': 'hallucination-number-level-3'
}
# The WMT22 scores are loaded as WMT22_scores; the older name WMT_scores is not
# defined anywhere in this notebook, so it only resolved through stale kernel state.
metrics_names_2022 = list(set(ACES_scores_2022.keys()).intersection(set(WMT22_scores.keys())))
sensitivities_2022_numbers, _, _, phenomena_2022_numbers, _, _ = calculate_sensitivities(ACES_scores_2022, WMT22_scores, mapping=PHENOMENA_MAPPING_NUMBERS)
ACES_summary_2022_numbers, phenomena_numbers = calculate_tau_correlations(ACES_scores_2022, phenomena=PHENOMENA_MAPPING_NUMBERS.values())


In [15]:
# LaTeX colour boxes, strongest green to strongest red with white in the middle.
# Raw strings keep the backslash literal ('\c' is not a valid escape sequence).
# The 11-step palette that used to be assigned first was overwritten immediately.
COLORS = [r'\colorbox{green1}', r'\colorbox{green2}', r'\colorbox{green4}', r'\colorbox{green5}', r'\colorbox{white}', r'\colorbox{red1}', r'\colorbox{red2}', r'\colorbox{red4}', r'\colorbox{red5}']

# From the ACES 2022 Paper:
# NOTE(review): the short groupings are not used below, which passes METRICS_GROUPING_2022.
METRICS_GROUPING_SHORT_2022 = {
    "baseline": ["BLEU", "COMET-20", "COMET-QE"],
    "reference-based": ["COMET-22", 'metricx_xl_MQM_2020', "UniTE", "UniTE-ref"],
    "reference-free": ["COMETKiwi", 'KG-BERTScore', "UniTE-src"],
}
METRICS_GROUPING_SHORT_2023 = {
    "baseline": ["BLEU", "BERTScore", "BLEURT-20", "YISI-1"],
    "reference-based": ["COMET-22", 'metricx_xl_MQM_2020', "UniTE", "UniTE-ref"],
    "reference-free": ["COMETKiwi", 'KG-BERTScore', "UniTE-src"],
}

# Columns of this table: the three number-hallucination levels.
phenomena = ["hallucination-number-level-1", "hallucination-number-level-2", "hallucination-number-level-3"]

PHENOMENA_HEADER_1_NUMBERS = dict(zip(phenomena, ['\\textbf{Level 1}', '\\textbf{Level 2}', '\\textbf{Level 3}']))
PHENOMENA_HEADER_2_NUMBERS = dict(zip(phenomena, ['', '',  '']))
# Examples per level: length of the first element stored for BLEU.
NUM_SAMPLES_NUMBERS = {p: len(ACES_scores_2022["BLEU"][p][0]) for p in phenomena}

# print(make_header(phenomena=phenomena, ACES_column=False, p_header_1=PHENOMENA_HEADER_1_NUMBERS, p_header_2=PHENOMENA_HEADER_2_NUMBERS, num_samples=NUM_SAMPLES_NUMBERS))
print(generate_summary_table_double(ACES_summary_2022_numbers, sensitivities_2022_numbers, metrics_groups=METRICS_GROUPING_2022, phenomena=phenomena, ACES_column=False))
# print(make_footer(averages=SUMMARY_AVERAGES_2023, phenomena=PHENOMENA))

BLEU					&	\colorbox{green3}{\textbf{\phantom{-}0.738}}	&	\colorbox{red3}{\textbf{-0.641}}	&	\colorbox{red5}{\textbf{-0.989}}	&	\colorbox{white}{\textbf{\phantom{-}0.357}}	&	\colorbox{white}{\textbf{-0.453}}	&	\colorbox{red5}{\textbf{-2.407}}		 \\ 
f101spBLEU					&	\colorbox{green3}{\textbf{\phantom{-}0.702}}	&	\colorbox{red3}{\textbf{-0.620}}	&	\colorbox{red5}{\textbf{-1.000}}	&	\colorbox{white}{\textbf{\phantom{-}0.154}}	&	\colorbox{white}{\textbf{-0.289}}	&	\colorbox{red3}{\textbf{-1.584}}		 \\ 
f200spBLEU					&	\colorbox{green3}{\textbf{\phantom{-}0.745}}	&	\colorbox{red3}{\textbf{-0.612}}	&	\colorbox{red5}{\textbf{-1.000}}	&	\colorbox{white}{\textbf{\phantom{-}0.148}}	&	\colorbox{white}{\textbf{-0.279}}	&	\colorbox{red3}{\textbf{-1.519}}		 \\ 
chrF					&	\colorbox{green5}{\textbf{\phantom{-}0.983}}	&	\colorbox{red3}{\textbf{-0.702}}	&	\colorbox{red5}{\textbf{-1.000}}	&	\colorbox{white}{\textbf{\phantom{-}0.115}}	&	\colorbox{white}{\textbf{-0.245}}	&	\colorbox{red2}{\textbf{-1.299}

# Functions to normalize summary scores and sensitivity scores to the range 0-1, then calculate their difference
Work in progress.

In [71]:
def scale_to_max_one(scores:Dict[str, Dict[str, Dict[str, int]]]):
    """Min-max rescale every score with one global minimum and maximum.

    `scores` is read as scores[metric][phenomenon] -> number. The extremes are taken
    over all metrics and phenomena together, so the smallest value overall becomes 0
    and the largest becomes 1. The input is left untouched; a rescaled deep copy is
    returned.
    """
    rescaled = copy.deepcopy(scores)
    values_per_metric = [list(by_phenomenon.values()) for by_phenomenon in scores.values()]
    highest = np.max([np.max(values) for values in values_per_metric])
    lowest = np.min([np.min(values) for values in values_per_metric])
    for by_phenomenon in rescaled.values():
        for phenomenon, value in by_phenomenon.items():
            by_phenomenon[phenomenon] = (value - lowest) / (highest - lowest)
    return rescaled

def scale_to_one(scores:Dict[str, Dict[str, Dict[str, int]]]):
    """Divide every score by the grand total so that all scores together sum to one.

    `scores` is read as scores[metric][phenomenon] -> number. The total runs over all
    metrics and phenomena. The input is left untouched; a normalised deep copy is
    returned.
    """
    grand_total = 0.0
    for by_phenomenon in scores.values():
        grand_total = grand_total + np.sum(list(by_phenomenon.values()))
    normalised = copy.deepcopy(scores)
    for by_phenomenon in normalised.values():
        for phenomenon in by_phenomenon:
            by_phenomenon[phenomenon] = by_phenomenon[phenomenon] / grand_total
    return normalised

def diff(scores1, scores2):
    """Return scores2 - scores1 per metric and phenomenon.

    Both arguments are read as scores[metric][phenomenon] -> number.

    For each metric key of `scores1`:
      * the output name is METRIC_NAMES_MAPPING[key] when the key is in that mapping,
        otherwise the key itself;
      * the matching `scores2` entry is looked up under the output name first and
        under METRIC_MAPPING_BACK[output name] second;
      * metrics with no match in `scores2` are printed and left out.
    Only phenomena present on both sides are included.

    Returns:
        dict: output metric name -> {phenomenon: scores2 value - scores1 value}.
    """
    scores_out = {}
    for metric1 in scores1:
        # Keep the original key for reading scores1; only the output / lookup name is renamed.
        metric = METRIC_NAMES_MAPPING[metric1] if metric1 in METRIC_NAMES_MAPPING else metric1
        if metric in scores2:
            metric2 = metric
        elif metric in METRIC_MAPPING_BACK and METRIC_MAPPING_BACK[metric] in scores2:
            metric2 = METRIC_MAPPING_BACK[metric]
        else:
            # No counterpart in scores2: report the metric and skip it.
            print(metric)
            continue
        first, second = scores1[metric1], scores2[metric2]
        scores_out[metric] = {p: second[p] - first[p] for p in first if p in second}
    return scores_out

# Plots

In [26]:
# create groups here:
means = {metric:[sensitivities_2022[metric][p] for p in PHENOMENA] for metric in metrics_names_2022}
tau = {metric:[ACES_summary_2022[metric][p] for p in PHENOMENA] for metric in metrics_names_2022 if metric in ACES_summary_2022}

In [21]:
# Group the metrics:
bleus = ["BLEU", "f101spBLEU", "f200spBLEU"]
comets = ["COMET-20", "COMET-22", "MS-COMET-22", "MS-COMET-QE-22"]
xl = ["metricx_xl_DA_2019", "metricx_xl_MQM_2020", "metricx_xxl_DA_2019", "metricx_xxl_MQM_2020"]
unite = ["UniTE", "UniTE-src", "UniTE-ref"]

In [28]:
# Plot the `means` and `tau` score lists of the BLEU family across PHENOMENA.
# NOTE: grouped_line_plot is defined in a later cell of this notebook; run that cell first.
grouped_line_plot([means, tau], bleus, ["Mean(good-bad)", "tau"], PHENOMENA)

In [38]:
def grouped_line_plot(groups: List[Dict[str,list]], metrics_names: List[str], group_labels: List[str], phenomena: List[str]):
    '''
    Draw one line per (group, metric) pair over the phenomena and show the figure.

    Inputs:
        groups: one dict per score type (e.g. means, tau), each in the format
            {
                metric1: [score for phenomenon 1, score for phenomenon 2, ..]
            }
        metrics_names: the metrics to draw; each must be a key of every group.
        group_labels: one label per group (mean (good-bad), tau, ...); used in the
            trace names and, joined by spaces, as the figure title.
        phenomena: x-axis labels. The order is important because the score lists
            in the groups are ordered according to the phenomena.

    Each group gets its own colour family and each metric one shade of it. Both
    palettes wrap around, so more than two groups or more than four metrics reuse
    colours instead of raising IndexError.
    '''
    assert len(groups) > 0 and len(groups) == len(group_labels) and len(metrics_names) > 0
    fig = go.Figure()
    colors = [['lightsteelblue',  'aqua', 'aquamarine', 'darkturquoise'],
        ['chocolate', 'coral', 'crimson', 'orange']]
    for i, group in enumerate(groups):
        palette = colors[i % len(colors)]
        for j, metric in enumerate(metrics_names):
            fig.add_trace(go.Scatter(x=phenomena, y=group[metric], mode='lines', name=group_labels[i]+" - "+metric,
                          line=dict(color=palette[j % len(palette)])))
    fig.update_layout(
        title=" ".join(group_labels)
        )
    fig.show()

# Extra Latex Tables

In [None]:
def find_max_on_col(scores:Dict[str, Dict[str, Dict[str, int]]], metrics_names:List[str], phenomena:List[str]=PHENOMENA, k_highest:int=1) -> Dict[str,str]:
    """Locate the top-k cells of every phenomenon column and average each column.

    `scores` is read as scores[metric][phenomenon] -> number. A metric name that is
    not a key of `scores` is first translated through METRIC_NAMES_MAPPING, then
    through METRIC_MAPPING_BACK; if it is still unknown it counts as -inf in the
    ranking and is left out of the average.

    Returns a tuple `(max_metrics, avgs)` (note: not the annotated Dict):
        max_metrics: {str(k): [[column indices], ...]} with one inner list per entry
            of `metrics_names`, holding the indices of the phenomena for which that
            metric has the (k+1)-th largest value of the column.
        avgs: one average per phenomenon, over the metrics found in `scores`.
    """
    def resolve(name):
        # Same lookup order as everywhere else in this file: forward mapping, then back-mapping.
        if name not in scores and name in METRIC_NAMES_MAPPING:
            return METRIC_NAMES_MAPPING[name]
        if name not in scores and name in METRIC_MAPPING_BACK:
            return METRIC_MAPPING_BACK[name]
        return name

    top_cells = {str(rank): [[] for _ in metrics_names] for rank in range(k_highest)}
    column_means = []
    for col_idx, phenomenon in enumerate(phenomena):
        column = []
        for name in metrics_names:
            key = resolve(name)
            column.append(scores[key][phenomenon] if key in scores else -np.inf)
        for rank in range(k_highest):
            # Value of the (rank+1)-th largest entry; every row equal to it is recorded.
            threshold = np.partition(column, -rank-1)[-rank-1]
            for row_idx in np.where(column == threshold)[0]:
                top_cells[str(rank)][row_idx].append(col_idx)
        as_array = np.array(column)
        column_means.append(np.average(as_array[as_array > -np.inf]))
    return top_cells, column_means

def make_header(scores:Dict[str, Dict[str, Dict[str, int]]]=None, phenomena:List[str]=PHENOMENA, p_header_1:dict=PHENOMENA_HEADER_1, p_header_2:dict=PHENOMENA_HEADER_2, num_samples:dict=None, phenomena_mapping:Dict[str,str]=None, ACES_column:bool=True) -> str:
    """Build the opening of the LaTeX table: preamble, two header rows, 'Examples' row.

    Args:
        scores: only needed when `num_samples` is not given. The example counts are
            then derived as len(scores["BLEU"][p][0]) for every fine-grained
            phenomenon p of `phenomena_mapping`, added up per mapped target.
        phenomena: column order of the table.
        p_header_1, p_header_2: first / second header line text per phenomenon.
        num_samples: precomputed example count per phenomenon; takes precedence
            over `scores`.
        phenomena_mapping: fine-grained -> column phenomenon; defaults to
            PHENOMENA_MAPPING.
        ACES_column: append the "ACES-" / "Score" header column.

    Returns:
        The LaTeX source up to and including the midrule after the 'Examples' row.

    Raises:
        ValueError: if neither `scores` nor `num_samples` is provided.
    """
    if phenomena_mapping is None:
        phenomena_mapping = PHENOMENA_MAPPING
    if num_samples is None:
        if scores is None:
            raise ValueError("make_header needs either `scores` or `num_samples`")
        num_samples = {p: 0 for p in phenomena}
        for p, target in phenomena_mapping.items():
            # Targets outside `phenomena` are counted too but never rendered, so a
            # subset of columns no longer fails with KeyError.
            num_samples[target] = num_samples.get(target, 0) + len(scores["BLEU"][p][0])
    # All backslashes are written as valid escapes; the emitted text is unchanged.
    res = "\\begin{table*}[ht] \n \\small \n \\setlength{\\tabcolsep}{3.75pt} \n \\centering \n \\begin{tabular}{@{}lccccccccccc@{}} \n \\\\\\toprule \n"
    res += "".join(" & " + p_header_1[p] for p in phenomena)
    if ACES_column:
        res += " & \\textbf{ACES-}"
    res += " \\\\\n"
    res += "".join(" & " + p_header_2[p] for p in phenomena)
    if ACES_column:
        res += " & \\textbf{Score}"
    res += "\\\\\n\\midrule\n\\textit{\\textbf{Examples}} "
    res += "".join(" & " + '\\textit{' + str(num_samples[p]) + '}' for p in phenomena)
    res += "\\\\ \n \\midrule"
    return res

def make_footer(averages:Dict[str, float], phenomena:List[str]=PHENOMENA) -> str:
    """Build the closing rows of the LaTeX table: a midrule, the averages row, bottomrule.

    Args:
        averages: value per phenomenon, indexed by the phenomenon itself
            (assumed to be a dict keyed by phenomenon name -- TODO confirm with callers).
        phenomena: column order; every entry must be present in `averages`.

    Returns:
        The LaTeX source of the footer. Each value is rendered with format_number.
    """
    # The backslash is written as a valid escape; the emitted text is unchanged.
    res = "\\midrule\nAverage (all metrics)\t"
    for p in phenomena:
        res += " & " + format_number(averages[p])
    res += "\\\\ \n \\bottomrule"
    return res

def generate_summary_table(scores:Dict[str, Dict[str, Dict[str, int]]], metrics_groups:Dict[str,list] = METRICS_GROUPING_2022, phenomena:List[str]=PHENOMENA, ACES_column:bool=True, global_colors:bool=True, k_highest:int=1, colors:List[str]=None) -> str:
    """Render the body rows of the LaTeX summary table.

    One row is produced per metric, in the order of `metrics_groups`, with a midrule
    after each group; one cell per phenomenon; optionally a final ACES-score cell;
    and a closing "Average" row with the per-phenomenon averages returned by
    find_max_on_col.

    Args:
        scores: read as scores[metric][phenomenon] -> number.
        metrics_groups: group name -> list of metric names. A name that is not a key
            of `scores` is translated through METRIC_NAMES_MAPPING, then
            METRIC_MAPPING_BACK; if still unknown its cells are rendered as "----".
        phenomena: column order.
        ACES_column: append a column with comp_aces_score of each row (metrics
            missing from `scores` contribute 0.0 for every phenomenon).
        global_colors: colour every cell through map_to_color against the largest and
            smallest score over the listed metrics; this forces k_highest to 1.
        k_highest: number of top ranks to highlight per column when global_colors
            is False.
        colors: highlight colour per rank, indexed by k; only read when
            global_colors is False and k_highest > 1.

    Returns:
        The LaTeX rows as one string.

    NOTE(review): with global_colors=False and k_highest > 1 the default colors=None
    cannot be indexed; callers must pass `colors` in that mode.
    NOTE(review): format_metric, format_number, map_to_color and comp_aces_score are
    not defined in this part of the file; their behaviour is assumed from the call sites.
    """
    # A previous version derived `colors` from the global COLORS palette here:
    #   if k_highest % 2 == 1:
    #       colors = COLORS[len(COLORS)//2-k_highest//2:len(COLORS)//2+k_highest//2+1]
    #   else:
    #       colors = COLORS[len(COLORS)//2-k_highest//2:len(COLORS)//2] + COLORS[len(COLORS)//2+1:len(COLORS)//2+k_highest//2+1]
    if global_colors:
        k_highest = 1
    out = ''
    # Flatten the groups into one list of resolved metric names; row indices below
    # (max_in_columns, aces_scores_col) refer to positions in this list.
    metrics_names = []
    for group in metrics_groups.values():
        for metric in group:
            if metric not in scores and metric in METRIC_NAMES_MAPPING:
                metrics_names.append(METRIC_NAMES_MAPPING[metric])
            elif metric not in scores and metric in METRIC_MAPPING_BACK:
                metrics_names.append(METRIC_MAPPING_BACK[metric])
            else:
                metrics_names.append(metric)
    # print(metrics_names)
    max_in_columns, avgs = find_max_on_col(scores, metrics_names=metrics_names, phenomena=phenomena, k_highest=k_highest)
    if ACES_column:
        # One ACES score per row, computed from that row's per-phenomenon scores.
        aces_scores_col = []
        for metric in metrics_names:
            # print(metric)
            if metric not in scores and metric in METRIC_NAMES_MAPPING:
                metric = METRIC_NAMES_MAPPING[metric]
            elif metric not in scores and metric in METRIC_MAPPING_BACK:
                metric = METRIC_MAPPING_BACK[metric]
            row = {}
            # print(metric)
            for p_id, p in enumerate(phenomena):
                if metric not in scores:
                    row[p] = 0.0
                else:
                    row[p] = scores[metric][p]
            aces_scores_col.append(comp_aces_score(row))
        # print(aces_scores_col)
        # Colour per row index for the ACES column; "" means no highlight.
        aces_scores_col_colors = {m_id:"" for m_id in range(len(metrics_names))}
    
    if global_colors:
        # Colour scale bounds: extremes over all phenomena of the listed metrics.
        maximum = np.max([np.max(list(p.values())) for metric,p in scores.items() if metric in metrics_names])
        minimum = np.min([np.min(list(p.values())) for metric,p in scores.items() if metric in metrics_names])
        # print(minimum, maximum, metrics_names)
    if ACES_column:
        if global_colors:
            # The ACES column is coloured on its own scale (its own max/min).
            for i in range(len(aces_scores_col)):
                aces_scores_col_colors[i] = map_to_color(aces_scores_col[i], np.max(aces_scores_col), np.min(aces_scores_col))
        elif k_highest == 1:
            max_aces_ids = np.where(list(aces_scores_col) == np.max(aces_scores_col))[0]
            for i in max_aces_ids:
                aces_scores_col_colors[i] = '\colorbox[HTML]{B2EAB1}'
        else:
            for k in range(k_highest):
                max_aces_ids = np.where(aces_scores_col == np.partition(aces_scores_col, -k-1)[-k-1])[0]
                for i in max_aces_ids:
                    aces_scores_col_colors[i] = colors[k]
                   
    for group, metrics in metrics_groups.items():
        for m_id, metric in enumerate(metrics):
            # Same name resolution as above, so metrics_names.index(metric) finds the row.
            if metric not in scores and metric in METRIC_NAMES_MAPPING:
                metric = METRIC_NAMES_MAPPING[metric]
            elif metric not in scores and metric in METRIC_MAPPING_BACK:
                metric = METRIC_MAPPING_BACK[metric]
            out += format_metric(metric) + '\t\t\t\t\t'
            for p_id, p in enumerate(phenomena):
                if metric not in scores:
                    out += '&\t ---- \t' 
                else:
                    if global_colors:
                        out += '&\t' + format_number(scores[metric][p], max_phenomena=True, color=map_to_color(scores[metric][p], maximum, minimum)) + '\t'   
                    elif k_highest == 1:
                        max_ids = max_in_columns['0'][metrics_names.index(metric)]
                        out += '&\t' + format_number(scores[metric][p], max_phenomena=p_id in max_ids) + '\t'   
                    else:
                        # NOTE(review): `color` is only assigned when this cell is among the
                        # top-k of its column; otherwise the value from an earlier cell is
                        # reused (or the name is unbound on the first such cell) -- confirm intent.
                        for k in range(k_highest):
                            max_ids = max_in_columns[str(k)][metrics_names.index(metric)]
                            if p_id in max_ids:
                                color=colors[k]
                                break
                        out += '&\t' + format_number(scores[metric][p], max_phenomena=True, color=color) + '\t'   
            if ACES_column:
                tmp_color =  aces_scores_col_colors[metrics_names.index(metric)]    
                out += '&\t' + format_number(aces_scores_col[metrics_names.index(metric)], dec='0.00', max_phenomena=tmp_color!="", color=tmp_color)
            out += '\t \\\\ \n'
        out += '\midrule \n'
    # Closing row: per-phenomenon averages, coloured relative to the largest and smallest average.
    out += 'Average\t\t\t\t\t'
    for p_id, p in enumerate(phenomena):
        out += '&\t' + format_number(avgs[p_id], max_phenomena=True, color=map_to_color(avgs[p_id], max=np.max(avgs), min=np.min(avgs))) + '\t'
    out += '\\\\'
    return out