In [1]:
import os
import torch
import numpy as np
from itertools import chain, combinations
from functools import partial
from scipy.stats import pearsonr
from sklearn.metrics import roc_auc_score, average_precision_score
from matplotlib import pyplot as plt
import seaborn as sns
import pandas as pd
from IPython.core.display import display, HTML
from bertviz import head_view, model_view
import shutil

from sklearn.metrics import roc_auc_score, average_precision_score
from scipy.stats import pearsonr
from evaluate_explanations import evaluate_word_level, evaluate_sentence_level
from utils import aggregate_pieces, read_qe_files

In [15]:
def powerset(iterable):
    """Return every non-empty combination of the items in ``iterable``.

    The result is a list of tuples ordered by increasing size; within one size
    the tuples follow the order produced by ``itertools.combinations``.
    The empty combination is not included.
    """
    items = list(iterable)
    subsets = []
    # Start at size 1 so the empty tuple is never generated.
    for size in range(1, len(items) + 1):
        subsets.extend(combinations(items, size))
    return subsets

def rec_topk(gold, pred):
    """Recall among the top-k scored positions, with k = number of gold positives.

    Args:
        gold: sequence of 0/1 tags.
        pred: sequence of scores aligned with ``gold``.

    Returns:
        Fraction of the ``sum(gold)`` highest-scored positions whose gold tag is 1.
    """
    num_positive = sum(gold)
    # Indices sorted by descending score, cut to as many entries as there are positives.
    top_idxs = np.argsort(pred)[::-1][:int(num_positive)]
    hits = 0
    for idx in top_idxs:
        if int(gold[idx]) == 1:
            hits += 1
    return hits / num_positive

In [7]:
def aggregate_explanations(
    mt_model_explanations,
    src_model_explanations,
    mt_fp_mask,
    src_fp_mask,
    transform='none',
    reduction='sum',
    is_rembert=False
):
    """Aggregate per-sample explanation scores and trim the boundary positions.

    For every sample index the scores and masks are converted to tensors, the
    scores are optionally transformed, passed through ``aggregate_pieces`` and
    converted back to plain lists. All four input lists are updated in place
    (the mask entries become tensors); the two explanation lists are returned.

    Args:
        mt_model_explanations: per-sample scores for the MT side.
        src_model_explanations: per-sample scores for the source side.
        mt_fp_mask: per-sample mask handed to ``aggregate_pieces`` for the MT side.
        src_fp_mask: per-sample mask handed to ``aggregate_pieces`` for the source side.
        transform: 'pre' applies sigmoid(abs(x)) before aggregation, 'pre_neg'
            negates before aggregation, 'pos' applies sigmoid(abs(x)) after
            aggregation; any other value leaves the scores untouched.
        reduction: forwarded to ``aggregate_pieces``. With 'none' nothing is trimmed.
        is_rembert: when true, the MT scores keep their first position.

    Returns:
        Tuple ``(mt_model_explanations, src_model_explanations)``.
    """
    def _squash(scores):
        # Magnitude of each score mapped through a sigmoid.
        return torch.sigmoid(torch.abs(scores))

    # The scores cover each word plus <s> and </s>; unless reduction is 'none'
    # the first and last positions are dropped.
    trim = reduction != 'none'
    src_first = 1 if trim else 0
    last = -1 if trim else None
    mt_first = 0 if is_rembert else src_first

    for idx in range(len(src_model_explanations)):
        src_scores = torch.tensor(src_model_explanations[idx])
        mt_scores = torch.tensor(mt_model_explanations[idx])
        src_fp_mask[idx] = torch.tensor(src_fp_mask[idx])
        mt_fp_mask[idx] = torch.tensor(mt_fp_mask[idx])

        if transform == 'pre':
            src_scores, mt_scores = _squash(src_scores), _squash(mt_scores)
        elif transform == 'pre_neg':
            src_scores, mt_scores = -src_scores, -mt_scores

        src_scores = aggregate_pieces(src_scores, src_fp_mask[idx], reduction)
        mt_scores = aggregate_pieces(mt_scores, mt_fp_mask[idx], reduction)

        if transform == 'pos':
            src_scores, mt_scores = _squash(src_scores), _squash(mt_scores)

        # remove <s> and </s>
        src_model_explanations[idx] = src_scores[src_first:last].tolist()
        mt_model_explanations[idx] = mt_scores[mt_first:last].tolist()

    return mt_model_explanations, src_model_explanations

def get_explanations(lp, explainer, transform='none', reduction='sum', threshold=None, revert_mt_src=True):
    """Load the explanations of one explainer and post-process them.

    Args:
        lp: forwarded to ``read_explanations``.
        explainer: explainer name, forwarded to ``read_explanations``; names
            containing 'rembert' are treated as rembert explainers.
        transform: forwarded to ``aggregate_explanations``.
        reduction: forwarded to ``aggregate_explanations``.
        threshold: when given, every sample whose predicted score is above it
            has each explanation score ``x`` replaced by ``1 - x``.
        revert_mt_src: for rembert explainers, swap the MT and source
            explanations (and their masks) before aggregating.

    Returns:
        Tuple ``(mt_expls, src_expls)``.
    """
    pred_scores, mt_expls, src_expls, mt_fp_mask, src_fp_mask = read_explanations(lp, explainer)

    is_rembert = 'rembert' in explainer
    swap_sides = is_rembert and revert_mt_src
    if swap_sides:
        mt_expls, src_expls = src_expls, mt_expls
        mt_fp_mask, src_fp_mask = src_fp_mask, mt_fp_mask

    # Aggregation is skipped when no MT mask is available.
    if mt_fp_mask is not None:
        mt_expls, src_expls = aggregate_explanations(
            mt_expls,
            src_expls,
            mt_fp_mask,
            src_fp_mask,
            transform=transform,
            reduction=reduction,
            is_rembert=is_rembert,
        )

    if threshold is None:
        return mt_expls, src_expls

    def _flip(scores):
        return [1-x for x in scores]

    for idx in range(len(pred_scores)):
        if pred_scores[idx] > threshold:
            mt_expls[idx] = _flip(mt_expls[idx])
            src_expls[idx] = _flip(src_expls[idx])

    return mt_expls, src_expls

In [8]:
def ensemble_explanations(all_mt_expls, all_src_expls, weights, gold_mt_tags, gold_src_tags):
    """Combine several explainers with a weighted sum and evaluate the result.

    Args:
        all_mt_expls: one entry per explainer; each entry holds one list of
            MT-side scores per sample. For a given sample all explainers must
            produce the same number of scores, since they are stacked.
        all_src_expls: same layout for the source side.
        weights: the string 'uni' for equal weights (1 / number of explainers),
            otherwise anything ``torch.tensor`` accepts with one weight per
            explainer.
        gold_mt_tags: gold tags for the MT side, passed to ``evaluate_word_level``.
        gold_src_tags: gold tags for the source side, passed to ``evaluate_word_level``.

    Returns:
        Tuple ``(mt_auc, mt_ap, mt_rec_topk, src_auc, src_ap, src_rec_topk)``,
        the same six values that are printed, in the same order. Returning
        them lets callers collect results instead of parsing stdout.
    """
    # Check the type first so that array-like weights are never compared
    # element-wise against the string 'uni'.
    if isinstance(weights, str) and weights == 'uni':
        weights = torch.tensor(1) / len(all_mt_expls)
    else:
        weights = torch.tensor(weights)
    # Trailing axis so per-explainer weights broadcast over the score positions.
    w = weights.unsqueeze(-1)

    num_samples = len(all_mt_expls[0])
    num_explainers = len(all_mt_expls)
    mt_expls = []
    src_expls = []
    for i in range(num_samples):
        mt_stack = torch.stack([torch.tensor(all_mt_expls[e][i]) for e in range(num_explainers)])
        src_stack = torch.stack([torch.tensor(all_src_expls[e][i]) for e in range(num_explainers)])
        mt_expls.append((mt_stack * w).sum(0).tolist())
        src_expls.append((src_stack * w).sum(0).tolist())

    src_auc_score, src_ap_score, src_rec_topk = evaluate_word_level(gold_src_tags, src_expls, do_print=False)
    mt_auc_score, mt_ap_score, mt_rec_topk = evaluate_word_level(gold_mt_tags, mt_expls, do_print=False)
    print('{:.4f}\t{:.4f}\t{:.4f}\t{:.4f}\t{:.4f}\t{:.4f}'.format(
        mt_auc_score, mt_ap_score, mt_rec_topk, src_auc_score, src_ap_score, src_rec_topk
    ))
    return mt_auc_score, mt_ap_score, mt_rec_topk, src_auc_score, src_ap_score, src_rec_topk

In [10]:
# Evaluate every non-empty subset of the configured explainers on ro-en.
# NOTE(review): read_gold_data is not defined or imported in the cells visible
# here -- confirm it is available after Restart & Run All.
lp = 'ro-en'
gold_scores, mt_tokens, src_tokens, mt_word_tags, src_word_tags = read_gold_data(lp)
# Explainer name -> loader (get_explanations with explainer-specific options).
explanations={
#     'attn_norm_head_18_3': partial(get_explanations, lp=lp, reduction='sum', transform='none'),
#     'attn_norm_cross_head_18_3': partial(get_explanations, lp=lp, reduction='sum', transform='none'),
#     'attn_norm_finetuned_softmax_large_head_18_3': partial(get_explanations, lp=lp, reduction='sum', transform='none'),
#     'attn_norm_merge_05_head_19_2': partial(get_explanations, lp=lp, reduction='sum', transform='none'),
#     'attn_norm_merge_05_layer_19': partial(get_explanations, lp=lp, reduction='sum', transform='none'),
#     'attn_norm_rembert_layer_23': partial(get_explanations, lp=lp, reduction='sum', transform='none'),
    
    'wordlevel_lbda1000': partial(get_explanations),
    'wordlevel_lbda10000': partial(get_explanations),
    'wordlevel_lbda100000': partial(get_explanations),
#     'xlmr_large_allall_sl01': partial(get_explanations),
    'xlmr_large_roen_sl01': partial(get_explanations),
#     'xlmr_large_allall_reversed': partial(get_explanations),
#     'xlmr_large_roen_reversed': partial(get_explanations),
    'wordlevel_unbabel_lbda1000': partial(get_explanations),
    'wordlevel_unbabel_lbda10000': partial(get_explanations),
    'wordlevel_unbabel_lbda100000': partial(get_explanations),
    'wordlevel_rembert_lbda10000': partial(get_explanations, revert_mt_src=False),
}

# Load each explainer exactly once. Loading inside the subset loop would
# re-read the same explainer for every subset that contains it.
# threshold=95 is forwarded to get_explanations unchanged.
loaded_explanations = {
    name: load(lp=lp, explainer=name, threshold=95)
    for name, load in explanations.items()
}

for expl in powerset(list(explanations.keys())):
    print(expl)
    explanations_subset = [loaded_explanations[ex] for ex in expl]
    ensemble_explanations(
        all_mt_expls=[ex[0] for ex in explanations_subset],
        all_src_expls=[ex[1] for ex in explanations_subset],
        weights='uni',
        gold_mt_tags=mt_word_tags,
        gold_src_tags=src_word_tags
    )
    print('==='*20)

('wordlevel_lbda1000',)
0.9124	0.8395	0.7459	0.9003	0.8144	0.7178
('wordlevel_lbda10000',)
0.9284	0.8515	0.7638	0.9145	0.8246	0.7216
('wordlevel_lbda100000',)
0.9278	0.8494	0.7528	0.9140	0.8195	0.7199
('xlmr_large_roen_sl01',)
0.9207	0.8261	0.7184	0.9072	0.8113	0.7041
('wordlevel_unbabel_lbda1000',)
0.9206	0.8448	0.7488	0.9049	0.8214	0.7143
('wordlevel_unbabel_lbda10000',)
0.9259	0.8506	0.7607	0.9134	0.8261	0.7243
('wordlevel_unbabel_lbda100000',)
0.9197	0.8444	0.7473	0.9104	0.8200	0.7139
('wordlevel_rembert_lbda10000',)
0.9339	0.8623	0.7694	0.9178	0.8315	0.7307
('wordlevel_lbda1000', 'wordlevel_lbda10000')
0.9263	0.8490	0.7553	0.9140	0.8254	0.7247
('wordlevel_lbda1000', 'wordlevel_lbda100000')
0.9252	0.8518	0.7600	0.9130	0.8242	0.7227
('wordlevel_lbda1000', 'xlmr_large_roen_sl01')
0.9257	0.8443	0.7476	0.9134	0.8237	0.7131
('wordlevel_lbda1000', 'wordlevel_unbabel_lbda1000')
0.9218	0.8525	0.7610	0.9104	0.8299	0.7307
('wordlevel_lbda1000', 'wordlevel_unbabel_lbda10000')
0.9242	0.8598	0.

0.9293	0.8608	0.7699	0.9168	0.8329	0.7361
('wordlevel_lbda1000', 'wordlevel_unbabel_lbda1000', 'wordlevel_unbabel_lbda100000')
0.9278	0.8586	0.7673	0.9156	0.8322	0.7270
('wordlevel_lbda1000', 'wordlevel_unbabel_lbda1000', 'wordlevel_rembert_lbda10000')
0.9366	0.8690	0.7763	0.9215	0.8412	0.7368
('wordlevel_lbda1000', 'wordlevel_unbabel_lbda10000', 'wordlevel_unbabel_lbda100000')
0.9273	0.8594	0.7705	0.9159	0.8313	0.7276
('wordlevel_lbda1000', 'wordlevel_unbabel_lbda10000', 'wordlevel_rembert_lbda10000')
0.9364	0.8677	0.7784	0.9235	0.8394	0.7351
('wordlevel_lbda1000', 'wordlevel_unbabel_lbda100000', 'wordlevel_rembert_lbda10000')
0.9338	0.8650	0.7715	0.9220	0.8403	0.7362
('wordlevel_lbda10000', 'wordlevel_lbda100000', 'xlmr_large_roen_sl01')
0.9357	0.8570	0.7614	0.9220	0.8308	0.7264
('wordlevel_lbda10000', 'wordlevel_lbda100000', 'wordlevel_unbabel_lbda1000')
0.9351	0.8636	0.7719	0.9204	0.8360	0.7319
('wordlevel_lbda10000', 'wordlevel_lbda100000', 'wordlevel_unbabel_lbda10000')
0.9355	0.

0.9357	0.8671	0.7810	0.9226	0.8351	0.7289
('wordlevel_lbda1000', 'wordlevel_lbda10000', 'wordlevel_lbda100000', 'wordlevel_rembert_lbda10000')
0.9374	0.8684	0.7796	0.9247	0.8408	0.7375
('wordlevel_lbda1000', 'wordlevel_lbda10000', 'xlmr_large_roen_sl01', 'wordlevel_unbabel_lbda1000')
0.9379	0.8671	0.7790	0.9238	0.8409	0.7404
('wordlevel_lbda1000', 'wordlevel_lbda10000', 'xlmr_large_roen_sl01', 'wordlevel_unbabel_lbda10000')
0.9369	0.8648	0.7765	0.9243	0.8386	0.7371
('wordlevel_lbda1000', 'wordlevel_lbda10000', 'xlmr_large_roen_sl01', 'wordlevel_unbabel_lbda100000')
0.9373	0.8630	0.7742	0.9239	0.8370	0.7300
('wordlevel_lbda1000', 'wordlevel_lbda10000', 'xlmr_large_roen_sl01', 'wordlevel_rembert_lbda10000')
0.9391	0.8678	0.7784	0.9270	0.8461	0.7417
('wordlevel_lbda1000', 'wordlevel_lbda10000', 'wordlevel_unbabel_lbda1000', 'wordlevel_unbabel_lbda10000')
0.9345	0.8668	0.7805	0.9220	0.8403	0.7407
('wordlevel_lbda1000', 'wordlevel_lbda10000', 'wordlevel_unbabel_lbda1000', 'wordlevel_unbabel

0.9350	0.8672	0.7785	0.9229	0.8368	0.7348
('wordlevel_lbda10000', 'wordlevel_lbda100000', 'wordlevel_unbabel_lbda10000', 'wordlevel_rembert_lbda10000')
0.9408	0.8712	0.7815	0.9249	0.8416	0.7372
('wordlevel_lbda10000', 'wordlevel_lbda100000', 'wordlevel_unbabel_lbda100000', 'wordlevel_rembert_lbda10000')
0.9390	0.8705	0.7782	0.9252	0.8438	0.7379
('wordlevel_lbda10000', 'xlmr_large_roen_sl01', 'wordlevel_unbabel_lbda1000', 'wordlevel_unbabel_lbda10000')
0.9382	0.8670	0.7766	0.9234	0.8412	0.7427
('wordlevel_lbda10000', 'xlmr_large_roen_sl01', 'wordlevel_unbabel_lbda1000', 'wordlevel_unbabel_lbda100000')
0.9365	0.8653	0.7762	0.9237	0.8395	0.7353
('wordlevel_lbda10000', 'xlmr_large_roen_sl01', 'wordlevel_unbabel_lbda1000', 'wordlevel_rembert_lbda10000')
0.9411	0.8722	0.7836	0.9259	0.8456	0.7422
('wordlevel_lbda10000', 'xlmr_large_roen_sl01', 'wordlevel_unbabel_lbda10000', 'wordlevel_unbabel_lbda100000')
0.9387	0.8676	0.7820	0.9230	0.8352	0.7306
('wordlevel_lbda10000', 'xlmr_large_roen_sl01'

0.9385	0.8669	0.7776	0.9254	0.8418	0.7419
('wordlevel_lbda1000', 'wordlevel_lbda10000', 'xlmr_large_roen_sl01', 'wordlevel_unbabel_lbda1000', 'wordlevel_unbabel_lbda100000')
0.9375	0.8669	0.7803	0.9235	0.8381	0.7357
('wordlevel_lbda1000', 'wordlevel_lbda10000', 'xlmr_large_roen_sl01', 'wordlevel_unbabel_lbda1000', 'wordlevel_rembert_lbda10000')
0.9414	0.8732	0.7836	0.9259	0.8450	0.7396
('wordlevel_lbda1000', 'wordlevel_lbda10000', 'xlmr_large_roen_sl01', 'wordlevel_unbabel_lbda10000', 'wordlevel_unbabel_lbda100000')
0.9377	0.8692	0.7854	0.9250	0.8403	0.7383
('wordlevel_lbda1000', 'wordlevel_lbda10000', 'xlmr_large_roen_sl01', 'wordlevel_unbabel_lbda10000', 'wordlevel_rembert_lbda10000')
0.9420	0.8709	0.7817	0.9273	0.8447	0.7433
('wordlevel_lbda1000', 'wordlevel_lbda10000', 'xlmr_large_roen_sl01', 'wordlevel_unbabel_lbda100000', 'wordlevel_rembert_lbda10000')
0.9404	0.8701	0.7828	0.9268	0.8429	0.7390
('wordlevel_lbda1000', 'wordlevel_lbda10000', 'wordlevel_unbabel_lbda1000', 'wordlevel_

0.9389	0.8707	0.7866	0.9236	0.8405	0.7417
('wordlevel_lbda10000', 'xlmr_large_roen_sl01', 'wordlevel_unbabel_lbda1000', 'wordlevel_unbabel_lbda10000', 'wordlevel_rembert_lbda10000')
0.9409	0.8731	0.7854	0.9260	0.8438	0.7432
('wordlevel_lbda10000', 'xlmr_large_roen_sl01', 'wordlevel_unbabel_lbda1000', 'wordlevel_unbabel_lbda100000', 'wordlevel_rembert_lbda10000')
0.9416	0.8751	0.7876	0.9257	0.8432	0.7384
('wordlevel_lbda10000', 'xlmr_large_roen_sl01', 'wordlevel_unbabel_lbda10000', 'wordlevel_unbabel_lbda100000', 'wordlevel_rembert_lbda10000')
0.9414	0.8727	0.7846	0.9265	0.8434	0.7395
('wordlevel_lbda10000', 'wordlevel_unbabel_lbda1000', 'wordlevel_unbabel_lbda10000', 'wordlevel_unbabel_lbda100000', 'wordlevel_rembert_lbda10000')
0.9398	0.8735	0.7889	0.9250	0.8436	0.7378
('wordlevel_lbda100000', 'xlmr_large_roen_sl01', 'wordlevel_unbabel_lbda1000', 'wordlevel_unbabel_lbda10000', 'wordlevel_unbabel_lbda100000')
0.9371	0.8674	0.7800	0.9225	0.8373	0.7372
('wordlevel_lbda100000', 'xlmr_larg

0.9388	0.8700	0.7838	0.9237	0.8391	0.7371
('wordlevel_lbda10000', 'wordlevel_lbda100000', 'xlmr_large_roen_sl01', 'wordlevel_unbabel_lbda1000', 'wordlevel_unbabel_lbda10000', 'wordlevel_rembert_lbda10000')
0.9410	0.8728	0.7823	0.9258	0.8433	0.7419
('wordlevel_lbda10000', 'wordlevel_lbda100000', 'xlmr_large_roen_sl01', 'wordlevel_unbabel_lbda1000', 'wordlevel_unbabel_lbda100000', 'wordlevel_rembert_lbda10000')
0.9410	0.8722	0.7825	0.9257	0.8430	0.7386
('wordlevel_lbda10000', 'wordlevel_lbda100000', 'xlmr_large_roen_sl01', 'wordlevel_unbabel_lbda10000', 'wordlevel_unbabel_lbda100000', 'wordlevel_rembert_lbda10000')
0.9407	0.8732	0.7853	0.9267	0.8439	0.7419
('wordlevel_lbda10000', 'wordlevel_lbda100000', 'wordlevel_unbabel_lbda1000', 'wordlevel_unbabel_lbda10000', 'wordlevel_unbabel_lbda100000', 'wordlevel_rembert_lbda10000')
0.9396	0.8733	0.7867	0.9250	0.8434	0.7376
('wordlevel_lbda10000', 'xlmr_large_roen_sl01', 'wordlevel_unbabel_lbda1000', 'wordlevel_unbabel_lbda10000', 'wordlevel_unb

In [11]:
# Evaluate every non-empty subset of the configured explainers on et-en.
# NOTE(review): read_gold_data is not defined or imported in the cells visible
# here -- confirm it is available after Restart & Run All.
lp = 'et-en'
gold_scores, mt_tokens, src_tokens, mt_word_tags, src_word_tags = read_gold_data(lp)
# Explainer name -> loader (get_explanations with explainer-specific options).
explanations={
#     'attn_finetuned_softmax_large_layer_19': partial(get_explanations, reduction='sum', transform='none'),
#     'attn_norm_head_18_3': partial(get_explanations, reduction='sum', transform='none'),
#     'attn_norm_cross_head_18_3': partial(get_explanations, reduction='sum', transform='none'),
#     'attn_norm_finetuned_softmax_large_head_18_3': partial(get_explanations, reduction='sum', transform='none'),
#     'attn_norm_merge_05_head_19_2': partial(get_explanations, reduction='sum', transform='none'),
#     'attn_norm_merge_05_layer_19': partial(get_explanations, reduction='sum', transform='none'),
#     'attn_norm_rembert_layer_23': partial(get_explanations, lp=lp, reduction='sum', transform='none'),

#     'attn_bottleneck_sigmoid_bce_sepmtsrc_lbda10': partial(get_explanations),
#     'attn_bottleneck_sigmoid_bce_jointmtsrc_lbda10': partial(get_explanations),
    'wordlevel_lbda1000': partial(get_explanations),
    'wordlevel_lbda10000': partial(get_explanations),
    'wordlevel_lbda100000': partial(get_explanations),
#     'xlmr_large_allall_sl01': partial(get_explanations),
    'xlmr_large_eten_sl01': partial(get_explanations),
#     'xlmr_large_allall_reversed': partial(get_explanations),
#     'xlmr_large_eten_reversed': partial(get_explanations),
    'wordlevel_unbabel_lbda1000': partial(get_explanations),
    'wordlevel_unbabel_lbda10000': partial(get_explanations),
    'wordlevel_unbabel_lbda100000': partial(get_explanations),
    'wordlevel_rembert_lbda10000': partial(get_explanations, revert_mt_src=False),
}

# Load each explainer exactly once. Loading inside the subset loop would
# re-read the same explainer for every subset that contains it.
loaded_explanations = {
    name: load(lp=lp, explainer=name)
    for name, load in explanations.items()
}

for expl in powerset(list(explanations.keys())):
    print(expl)
    explanations_subset = [loaded_explanations[ex] for ex in expl]
    ensemble_explanations(
        all_mt_expls=[ex[0] for ex in explanations_subset],
        all_src_expls=[ex[1] for ex in explanations_subset],
        weights='uni',
        gold_mt_tags=mt_word_tags,
        gold_src_tags=src_word_tags
    )
    print('==='*20)

('wordlevel_lbda1000',)
0.8806	0.8082	0.7051	0.8540	0.7584	0.6476
('wordlevel_lbda10000',)
0.8810	0.8138	0.7108	0.8584	0.7681	0.6558
('wordlevel_lbda100000',)
0.8800	0.8126	0.7136	0.8521	0.7558	0.6444
('xlmr_large_eten_sl01',)
0.8732	0.7978	0.6921	0.8483	0.7491	0.6350
('wordlevel_unbabel_lbda1000',)
0.8714	0.7975	0.6951	0.8430	0.7377	0.6221
('wordlevel_unbabel_lbda10000',)
0.8746	0.8036	0.7037	0.8541	0.7511	0.6301
('wordlevel_unbabel_lbda100000',)
0.8734	0.7971	0.6960	0.8440	0.7391	0.6179
('wordlevel_rembert_lbda10000',)
0.8815	0.8058	0.7035	0.8515	0.7470	0.6313
('wordlevel_lbda1000', 'wordlevel_lbda10000')
0.8856	0.8205	0.7166	0.8614	0.7731	0.6647
('wordlevel_lbda1000', 'wordlevel_lbda100000')
0.8871	0.8172	0.7118	0.8597	0.7658	0.6553
('wordlevel_lbda1000', 'xlmr_large_eten_sl01')
0.8874	0.8193	0.7141	0.8609	0.7655	0.6558
('wordlevel_lbda1000', 'wordlevel_unbabel_lbda1000')
0.8864	0.8183	0.7166	0.8594	0.7656	0.6562
('wordlevel_lbda1000', 'wordlevel_unbabel_lbda10000')
0.8870	0.8209	0.

0.8858	0.8202	0.7249	0.8622	0.7654	0.6535
('wordlevel_lbda1000', 'wordlevel_unbabel_lbda1000', 'wordlevel_unbabel_lbda100000')
0.8864	0.8193	0.7219	0.8597	0.7625	0.6443
('wordlevel_lbda1000', 'wordlevel_unbabel_lbda1000', 'wordlevel_rembert_lbda10000')
0.8947	0.8300	0.7318	0.8664	0.7739	0.6668
('wordlevel_lbda1000', 'wordlevel_unbabel_lbda10000', 'wordlevel_unbabel_lbda100000')
0.8884	0.8229	0.7258	0.8623	0.7636	0.6454
('wordlevel_lbda1000', 'wordlevel_unbabel_lbda10000', 'wordlevel_rembert_lbda10000')
0.8947	0.8299	0.7335	0.8685	0.7739	0.6565
('wordlevel_lbda1000', 'wordlevel_unbabel_lbda100000', 'wordlevel_rembert_lbda10000')
0.8935	0.8268	0.7275	0.8653	0.7712	0.6558
('wordlevel_lbda10000', 'wordlevel_lbda100000', 'xlmr_large_eten_sl01')
0.8897	0.8242	0.7211	0.8668	0.7729	0.6590
('wordlevel_lbda10000', 'wordlevel_lbda100000', 'wordlevel_unbabel_lbda1000')
0.8884	0.8219	0.7215	0.8637	0.7718	0.6641
('wordlevel_lbda10000', 'wordlevel_lbda100000', 'wordlevel_unbabel_lbda10000')
0.8891	0.

0.8906	0.8264	0.7244	0.8651	0.7726	0.6632
('wordlevel_lbda1000', 'wordlevel_lbda10000', 'wordlevel_lbda100000', 'wordlevel_rembert_lbda10000')
0.8956	0.8294	0.7263	0.8685	0.7786	0.6717
('wordlevel_lbda1000', 'wordlevel_lbda10000', 'xlmr_large_eten_sl01', 'wordlevel_unbabel_lbda1000')
0.8899	0.8249	0.7243	0.8660	0.7754	0.6696
('wordlevel_lbda1000', 'wordlevel_lbda10000', 'xlmr_large_eten_sl01', 'wordlevel_unbabel_lbda10000')
0.8910	0.8258	0.7247	0.8673	0.7751	0.6629
('wordlevel_lbda1000', 'wordlevel_lbda10000', 'xlmr_large_eten_sl01', 'wordlevel_unbabel_lbda100000')
0.8903	0.8262	0.7258	0.8665	0.7748	0.6645
('wordlevel_lbda1000', 'wordlevel_lbda10000', 'xlmr_large_eten_sl01', 'wordlevel_rembert_lbda10000')
0.8956	0.8331	0.7335	0.8693	0.7779	0.6679
('wordlevel_lbda1000', 'wordlevel_lbda10000', 'wordlevel_unbabel_lbda1000', 'wordlevel_unbabel_lbda10000')
0.8892	0.8256	0.7302	0.8660	0.7733	0.6619
('wordlevel_lbda1000', 'wordlevel_lbda10000', 'wordlevel_unbabel_lbda1000', 'wordlevel_unbabel

0.8899	0.8236	0.7217	0.8658	0.7691	0.6562
('wordlevel_lbda10000', 'wordlevel_lbda100000', 'wordlevel_unbabel_lbda10000', 'wordlevel_rembert_lbda10000')
0.8968	0.8326	0.7343	0.8705	0.7755	0.6597
('wordlevel_lbda10000', 'wordlevel_lbda100000', 'wordlevel_unbabel_lbda100000', 'wordlevel_rembert_lbda10000')
0.8954	0.8299	0.7311	0.8680	0.7744	0.6594
('wordlevel_lbda10000', 'xlmr_large_eten_sl01', 'wordlevel_unbabel_lbda1000', 'wordlevel_unbabel_lbda10000')
0.8878	0.8227	0.7234	0.8671	0.7750	0.6661
('wordlevel_lbda10000', 'xlmr_large_eten_sl01', 'wordlevel_unbabel_lbda1000', 'wordlevel_unbabel_lbda100000')
0.8875	0.8232	0.7203	0.8645	0.7731	0.6660
('wordlevel_lbda10000', 'xlmr_large_eten_sl01', 'wordlevel_unbabel_lbda1000', 'wordlevel_rembert_lbda10000')
0.8951	0.8312	0.7331	0.8683	0.7776	0.6688
('wordlevel_lbda10000', 'xlmr_large_eten_sl01', 'wordlevel_unbabel_lbda10000', 'wordlevel_unbabel_lbda100000')
0.8882	0.8232	0.7208	0.8665	0.7739	0.6608
('wordlevel_lbda10000', 'xlmr_large_eten_sl01'

0.8902	0.8258	0.7274	0.8681	0.7760	0.6657
('wordlevel_lbda1000', 'wordlevel_lbda10000', 'xlmr_large_eten_sl01', 'wordlevel_unbabel_lbda1000', 'wordlevel_unbabel_lbda100000')
0.8904	0.8259	0.7254	0.8662	0.7749	0.6665
('wordlevel_lbda1000', 'wordlevel_lbda10000', 'xlmr_large_eten_sl01', 'wordlevel_unbabel_lbda1000', 'wordlevel_rembert_lbda10000')
0.8960	0.8318	0.7324	0.8690	0.7757	0.6659
('wordlevel_lbda1000', 'wordlevel_lbda10000', 'xlmr_large_eten_sl01', 'wordlevel_unbabel_lbda10000', 'wordlevel_unbabel_lbda100000')
0.8909	0.8272	0.7268	0.8676	0.7762	0.6654
('wordlevel_lbda1000', 'wordlevel_lbda10000', 'xlmr_large_eten_sl01', 'wordlevel_unbabel_lbda10000', 'wordlevel_rembert_lbda10000')
0.8958	0.8332	0.7349	0.8706	0.7800	0.6706
('wordlevel_lbda1000', 'wordlevel_lbda10000', 'xlmr_large_eten_sl01', 'wordlevel_unbabel_lbda100000', 'wordlevel_rembert_lbda10000')
0.8951	0.8308	0.7318	0.8692	0.7782	0.6685
('wordlevel_lbda1000', 'wordlevel_lbda10000', 'wordlevel_unbabel_lbda1000', 'wordlevel_

0.8879	0.8240	0.7238	0.8649	0.7711	0.6594
('wordlevel_lbda10000', 'xlmr_large_eten_sl01', 'wordlevel_unbabel_lbda1000', 'wordlevel_unbabel_lbda10000', 'wordlevel_rembert_lbda10000')
0.8942	0.8307	0.7337	0.8694	0.7765	0.6654
('wordlevel_lbda10000', 'xlmr_large_eten_sl01', 'wordlevel_unbabel_lbda1000', 'wordlevel_unbabel_lbda100000', 'wordlevel_rembert_lbda10000')
0.8944	0.8312	0.7341	0.8683	0.7755	0.6644
('wordlevel_lbda10000', 'xlmr_large_eten_sl01', 'wordlevel_unbabel_lbda10000', 'wordlevel_unbabel_lbda100000', 'wordlevel_rembert_lbda10000')
0.8943	0.8304	0.7326	0.8691	0.7776	0.6671
('wordlevel_lbda10000', 'wordlevel_unbabel_lbda1000', 'wordlevel_unbabel_lbda10000', 'wordlevel_unbabel_lbda100000', 'wordlevel_rembert_lbda10000')
0.8941	0.8293	0.7336	0.8668	0.7714	0.6577
('wordlevel_lbda100000', 'xlmr_large_eten_sl01', 'wordlevel_unbabel_lbda1000', 'wordlevel_unbabel_lbda10000', 'wordlevel_unbabel_lbda100000')
0.8871	0.8237	0.7248	0.8623	0.7668	0.6537
('wordlevel_lbda100000', 'xlmr_larg

0.8889	0.8248	0.7259	0.8667	0.7722	0.6591
('wordlevel_lbda10000', 'wordlevel_lbda100000', 'xlmr_large_eten_sl01', 'wordlevel_unbabel_lbda1000', 'wordlevel_unbabel_lbda10000', 'wordlevel_rembert_lbda10000')
0.8949	0.8322	0.7355	0.8703	0.7765	0.6644
('wordlevel_lbda10000', 'wordlevel_lbda100000', 'xlmr_large_eten_sl01', 'wordlevel_unbabel_lbda1000', 'wordlevel_unbabel_lbda100000', 'wordlevel_rembert_lbda10000')
0.8952	0.8331	0.7372	0.8696	0.7760	0.6656
('wordlevel_lbda10000', 'wordlevel_lbda100000', 'xlmr_large_eten_sl01', 'wordlevel_unbabel_lbda10000', 'wordlevel_unbabel_lbda100000', 'wordlevel_rembert_lbda10000')
0.8949	0.8337	0.7377	0.8696	0.7771	0.6652
('wordlevel_lbda10000', 'wordlevel_lbda100000', 'wordlevel_unbabel_lbda1000', 'wordlevel_unbabel_lbda10000', 'wordlevel_unbabel_lbda100000', 'wordlevel_rembert_lbda10000')
0.8945	0.8300	0.7342	0.8685	0.7716	0.6582
('wordlevel_lbda10000', 'xlmr_large_eten_sl01', 'wordlevel_unbabel_lbda1000', 'wordlevel_unbabel_lbda10000', 'wordlevel_unb

In [30]:
# go to internal leaderboard and see .sh scripts to get the name of the explainers!
# NOTE(review): read_gold_data is not defined or imported in the cells visible
# here -- confirm it is available after Restart & Run All.
lp = 'ru-de'
# lp = 'en-de'

# Sample range [start, end) that is evaluated for the chosen language pair.
if lp == 'ru-de':
#     start, end = 1000, 2000  # if not word-level and want to select en-de
    start, end = 0, 2000     # otherwise
elif lp == 'en-de':
    start, end = 0, 1000
else:
    start, end = 0, 10000000

# XLMR-Large + finetuned + attn * norm - layer 20 - head 3	ru-de
# XLMR-Large + finetuned + attn * norm - layer 19 - head 12	ru-de
# Rembert - layer 13 - head 3	ru-de
# Rembert - layer 23 - head 13	ru-de
# Rembert - layer 23	ru-de
# XLMR-Large + finetuned + attn * norm - layer 19 - head 1	en-de
# XLMR-Large + finetuned + attn * norm - layer 18 - head 0	en-de
# Rembert - layer 13 - head 3	en-de
# Rembert - layer 24 - head 16	en-de

gold_scores, mt_tokens, src_tokens, mt_word_tags, src_word_tags = read_gold_data(lp)
# Explainer name -> loader (get_explanations with explainer-specific options).
explanations={
    'attn_finetuned_new_head_20_3': partial(get_explanations, reduction='sum', transform='none'),
    'attn_finetuned_new_head_19_12': partial(get_explanations, reduction='sum', transform='none'),
    'attn_norm_rembert_new_layer_23': partial(get_explanations, reduction='sum', transform='none', revert_mt_src=False),
    'attn_norm_rembert_new_head_13_3': partial(get_explanations, reduction='sum', transform='none', revert_mt_src=False),
    'attn_norm_rembert_new_head_23_13': partial(get_explanations, reduction='sum', transform='none', revert_mt_src=False),
    'attn_finetuned_new_head_19_1': partial(get_explanations, reduction='sum', transform='none'),
    'attn_finetuned_new_head_18_0': partial(get_explanations, reduction='sum', transform='none'),
    'attn_norm_rembert_new_head_24_16': partial(get_explanations, reduction='sum', transform='none', revert_mt_src=False),
#     'wordlevel_xlmr_allall_lbda10000': partial(get_explanations),
#     'wordlevel_unbabel_allall_lbda10000': partial(get_explanations),
#     'wordlevel_rembert_allall_lbda10000': partial(get_explanations, revert_mt_src=False),
#     'wordlevel_lbda100000': partial(get_explanations),
#     'wordlevel_unbabel_lbda100000': partial(get_explanations),
}

# Load each explainer exactly once. Loading inside the subset loop would
# re-read the same explainer for every subset that contains it.
loaded_explanations = {
    name: load(lp=lp, explainer=name)
    for name, load in explanations.items()
}

for expl in powerset(list(explanations.keys())):
    print(expl)
    explanations_subset = [loaded_explanations[ex] for ex in expl]
    ensemble_explanations(
        all_mt_expls=[ex[0][start:end] for ex in explanations_subset],
        all_src_expls=[ex[1][start:end] for ex in explanations_subset],
        weights='uni',
        gold_mt_tags=mt_word_tags[start:end],
        gold_src_tags=src_word_tags[start:end]
    )
    print('==='*20)

('attn_finetuned_new_head_20_3', 'attn_norm_rembert_new_head_13_3', 'attn_finetuned_new_head_19_1', 'attn_finetuned_new_head_18_0', 'attn_norm_rembert_new_head_24_16')
0.6982	0.5945	0.4681	0.7108	0.5679	0.4356
('attn_finetuned_new_head_20_3', 'attn_finetuned_new_head_19_12', 'attn_norm_rembert_new_head_13_3', 'attn_finetuned_new_head_19_1', 'attn_finetuned_new_head_18_0', 'attn_norm_rembert_new_head_24_16')
0.7000	0.5936	0.4689	0.7121	0.5711	0.4384


In [40]:
# go to internal leaderboard and see .sh scripts to get the name of the explainers!
# NOTE(review): read_gold_data is not defined or imported in the cells visible
# here -- confirm it is available after Restart & Run All.
lp = 'de-zh'
# lp = 'en-zh'

# Sample range [start, end) that is evaluated for the chosen language pair.
if lp == 'de-zh':
#     start, end = 0, 1000  # if not word-level and want to select en-zh
    start, end = 0, 2000
elif lp == 'en-zh':
    start, end = 0, 1000
else:
    start, end = 0, 10000000

# XLMR-Unbabel - de-zh - layer 8 - head 12	de-zh
# XLMR-Unbabel + all-all - layer 21 - head 12	de-zh
# Rembert - layer 23 - head 9	de-zh
# Rembert - layer 24 - head 16	de-zh
# XLMR-Unbabel - de-zh - layer 10 - head 7	en-zh
# XLMR-Unbabel - de-zh - layer 10 - head 8	en-zh
# XLMR-Unbabel - de-zh - layer 10	en-zh
# Rembert - layer 23 - head 12	en-zh
# Rembert - layer 23 - head 9	en-zh

# attn_norm_new
# attn_norm_metrics_new
# attn_norm_xlmr_new
# attn_norm_unbabel_new
# attn_finetuned_new
# attn_norm_rembert_new

gold_scores, mt_tokens, src_tokens, mt_word_tags, src_word_tags = read_gold_data(lp)
# Explainer name -> loader (get_explanations with explainer-specific options).
explanations={
    'attn_norm_metrics_new_head_8_12': partial(get_explanations, reduction='sum', transform='none'),
    'attn_norm_unbabel_new_head_21_12': partial(get_explanations, reduction='sum', transform='none'),
    'attn_norm_rembert_new_head_24_16': partial(get_explanations, reduction='sum', transform='none', revert_mt_src=False),
    'attn_norm_metrics_new_head_10_7': partial(get_explanations, reduction='sum', transform='none'),
    'attn_norm_metrics_new_head_10_8': partial(get_explanations, reduction='sum', transform='none'),
    'attn_norm_metrics_new_layer_10': partial(get_explanations, reduction='sum', transform='none'),
    'attn_norm_rembert_new_head_23_12': partial(get_explanations, reduction='sum', transform='none', revert_mt_src=False),
    'attn_norm_rembert_new_head_23_9': partial(get_explanations, reduction='sum', transform='none', revert_mt_src=False),
#     'wordlevel_xlmr_allall_lbda10000': partial(get_explanations),
#     'wordlevel_unbabel_allall_lbda10000': partial(get_explanations),
#     'wordlevel_rembert_allall_lbda10000': partial(get_explanations, revert_mt_src=False),
#     'wordlevel_lbda100000': partial(get_explanations),
#     'wordlevel_unbabel_lbda100000': partial(get_explanations),
}

# Load each explainer exactly once. Loading inside the subset loop would
# re-read the same explainer for every subset that contains it.
loaded_explanations = {
    name: load(lp=lp, explainer=name)
    for name, load in explanations.items()
}

for expl in powerset(list(explanations.keys())):
    print(expl)
    explanations_subset = [loaded_explanations[ex] for ex in expl]
    ensemble_explanations(
        all_mt_expls=[ex[0][start:end] for ex in explanations_subset],
        all_src_expls=[ex[1][start:end] for ex in explanations_subset],
        weights='uni',
        gold_mt_tags=mt_word_tags[start:end],
        gold_src_tags=src_word_tags[start:end]
    )
    print('==='*20)

('attn_norm_metrics_new_head_8_12',)
0.6240	0.5274	0.4137	0.6425	0.5089	0.3890
('attn_norm_unbabel_new_head_21_12',)
0.6112	0.4884	0.3607	0.5982	0.4408	0.3048
('attn_norm_rembert_new_head_24_16',)
0.6255	0.5245	0.4000	0.6389	0.4888	0.3573
('attn_norm_metrics_new_head_10_7',)
0.6024	0.4854	0.3702	0.5957	0.4603	0.3413
('attn_norm_metrics_new_head_10_8',)
0.5947	0.4755	0.3591	0.6156	0.4648	0.3441
('attn_norm_metrics_new_layer_10',)
0.5958	0.4797	0.3656	0.6162	0.4710	0.3474
('attn_norm_rembert_new_head_23_12',)
0.6247	0.4911	0.3766	0.6044	0.4523	0.3250
('attn_norm_rembert_new_head_23_9',)
0.6261	0.5005	0.3927	0.6381	0.4818	0.3529
('attn_norm_metrics_new_head_8_12', 'attn_norm_unbabel_new_head_21_12')
0.6421	0.5164	0.3873	0.6362	0.4804	0.3497
('attn_norm_metrics_new_head_8_12', 'attn_norm_rembert_new_head_24_16')
0.6465	0.5341	0.4120	0.6744	0.5239	0.3917
('attn_norm_metrics_new_head_8_12', 'attn_norm_metrics_new_head_10_7')
0.6040	0.4940	0.3836	0.6184	0.4749	0.3557
('attn_norm_metrics_new_h

0.6524	0.5405	0.4194	0.6754	0.5244	0.3902
('attn_norm_metrics_new_head_8_12', 'attn_norm_metrics_new_head_10_7', 'attn_norm_metrics_new_head_10_8')
0.6030	0.4805	0.3668	0.6257	0.4733	0.3538
('attn_norm_metrics_new_head_8_12', 'attn_norm_metrics_new_head_10_7', 'attn_norm_metrics_new_layer_10')
0.6025	0.4842	0.3736	0.6226	0.4718	0.3531
('attn_norm_metrics_new_head_8_12', 'attn_norm_metrics_new_head_10_7', 'attn_norm_rembert_new_head_23_12')
0.6195	0.4976	0.3930	0.6276	0.4809	0.3657
('attn_norm_metrics_new_head_8_12', 'attn_norm_metrics_new_head_10_7', 'attn_norm_rembert_new_head_23_9')
0.6135	0.5000	0.3919	0.6277	0.4805	0.3617
('attn_norm_metrics_new_head_8_12', 'attn_norm_metrics_new_head_10_8', 'attn_norm_metrics_new_layer_10')
0.6019	0.4784	0.3662	0.6291	0.4745	0.3555
('attn_norm_metrics_new_head_8_12', 'attn_norm_metrics_new_head_10_8', 'attn_norm_rembert_new_head_23_12')
0.6192	0.4919	0.3865	0.6338	0.4789	0.3595
('attn_norm_metrics_new_head_8_12', 'attn_norm_metrics_new_head_10_8',

0.6212	0.4967	0.3948	0.6275	0.4771	0.3580
('attn_norm_metrics_new_head_10_7', 'attn_norm_rembert_new_head_23_12', 'attn_norm_rembert_new_head_23_9')
0.6353	0.5007	0.3892	0.6256	0.4694	0.3475
('attn_norm_metrics_new_head_10_8', 'attn_norm_metrics_new_layer_10', 'attn_norm_rembert_new_head_23_12')
0.6264	0.4967	0.3904	0.6253	0.4685	0.3475
('attn_norm_metrics_new_head_10_8', 'attn_norm_metrics_new_layer_10', 'attn_norm_rembert_new_head_23_9')
0.6209	0.4959	0.3937	0.6305	0.4762	0.3567
('attn_norm_metrics_new_head_10_8', 'attn_norm_rembert_new_head_23_12', 'attn_norm_rembert_new_head_23_9')
0.6346	0.4974	0.3849	0.6294	0.4637	0.3365
('attn_norm_metrics_new_layer_10', 'attn_norm_rembert_new_head_23_12', 'attn_norm_rembert_new_head_23_9')
0.6341	0.4996	0.3875	0.6313	0.4693	0.3462
('attn_norm_metrics_new_head_8_12', 'attn_norm_unbabel_new_head_21_12', 'attn_norm_rembert_new_head_24_16', 'attn_norm_metrics_new_head_10_7')
0.6495	0.5251	0.3987	0.6561	0.4972	0.3664
('attn_norm_metrics_new_head_8_1

0.6102	0.4880	0.3788	0.6304	0.4767	0.3592
('attn_norm_metrics_new_head_8_12', 'attn_norm_metrics_new_head_10_7', 'attn_norm_rembert_new_head_23_12', 'attn_norm_rembert_new_head_23_9')
0.6259	0.5027	0.4010	0.6324	0.4831	0.3665
('attn_norm_metrics_new_head_8_12', 'attn_norm_metrics_new_head_10_8', 'attn_norm_metrics_new_layer_10', 'attn_norm_rembert_new_head_23_12')
0.6180	0.4885	0.3802	0.6351	0.4779	0.3602
('attn_norm_metrics_new_head_8_12', 'attn_norm_metrics_new_head_10_8', 'attn_norm_metrics_new_layer_10', 'attn_norm_rembert_new_head_23_9')
0.6097	0.4855	0.3769	0.6362	0.4784	0.3586
('attn_norm_metrics_new_head_8_12', 'attn_norm_metrics_new_head_10_8', 'attn_norm_rembert_new_head_23_12', 'attn_norm_rembert_new_head_23_9')
0.6248	0.4963	0.3917	0.6369	0.4812	0.3610
('attn_norm_metrics_new_head_8_12', 'attn_norm_metrics_new_layer_10', 'attn_norm_rembert_new_head_23_12', 'attn_norm_rembert_new_head_23_9')
0.6257	0.5007	0.3977	0.6352	0.4838	0.3667
('attn_norm_unbabel_new_head_21_12', 'attn

0.6381	0.4879	0.3633	0.6480	0.4693	0.3281
('attn_norm_rembert_new_head_24_16', 'attn_norm_metrics_new_layer_10', 'attn_norm_rembert_new_head_23_12', 'attn_norm_rembert_new_head_23_9')
0.6414	0.4961	0.3723	0.6520	0.4795	0.3379
('attn_norm_metrics_new_head_10_7', 'attn_norm_metrics_new_head_10_8', 'attn_norm_metrics_new_layer_10', 'attn_norm_rembert_new_head_23_12')
0.6256	0.4976	0.3934	0.6243	0.4719	0.3562
('attn_norm_metrics_new_head_10_7', 'attn_norm_metrics_new_head_10_8', 'attn_norm_metrics_new_layer_10', 'attn_norm_rembert_new_head_23_9')
0.6191	0.4947	0.3904	0.6259	0.4746	0.3550
('attn_norm_metrics_new_head_10_7', 'attn_norm_metrics_new_head_10_8', 'attn_norm_rembert_new_head_23_12', 'attn_norm_rembert_new_head_23_9')
0.6343	0.5029	0.3959	0.6290	0.4722	0.3511
('attn_norm_metrics_new_head_10_7', 'attn_norm_metrics_new_layer_10', 'attn_norm_rembert_new_head_23_12', 'attn_norm_rembert_new_head_23_9')
0.6337	0.5033	0.3964	0.6300	0.4745	0.3550
('attn_norm_metrics_new_head_10_8', 'attn_

0.6416	0.4983	0.3734	0.6550	0.4885	0.3602
('attn_norm_metrics_new_head_8_12', 'attn_norm_rembert_new_head_24_16', 'attn_norm_metrics_new_head_10_7', 'attn_norm_metrics_new_layer_10', 'attn_norm_rembert_new_head_23_9')
0.6381	0.4947	0.3702	0.6563	0.4912	0.3672
('attn_norm_metrics_new_head_8_12', 'attn_norm_rembert_new_head_24_16', 'attn_norm_metrics_new_head_10_7', 'attn_norm_rembert_new_head_23_12', 'attn_norm_rembert_new_head_23_9')
0.6482	0.5218	0.4030	0.6603	0.5067	0.3716
('attn_norm_metrics_new_head_8_12', 'attn_norm_rembert_new_head_24_16', 'attn_norm_metrics_new_head_10_8', 'attn_norm_metrics_new_layer_10', 'attn_norm_rembert_new_head_23_12')
0.6386	0.4902	0.3648	0.6557	0.4833	0.3542
('attn_norm_metrics_new_head_8_12', 'attn_norm_rembert_new_head_24_16', 'attn_norm_metrics_new_head_10_8', 'attn_norm_metrics_new_layer_10', 'attn_norm_rembert_new_head_23_9')
0.6350	0.4852	0.3605	0.6583	0.4867	0.3600
('attn_norm_metrics_new_head_8_12', 'attn_norm_rembert_new_head_24_16', 'attn_norm_

0.6433	0.4927	0.3659	0.6547	0.4746	0.3392
('attn_norm_rembert_new_head_24_16', 'attn_norm_metrics_new_head_10_8', 'attn_norm_metrics_new_layer_10', 'attn_norm_rembert_new_head_23_12', 'attn_norm_rembert_new_head_23_9')
0.6414	0.4857	0.3574	0.6536	0.4691	0.3286
('attn_norm_metrics_new_head_10_7', 'attn_norm_metrics_new_head_10_8', 'attn_norm_metrics_new_layer_10', 'attn_norm_rembert_new_head_23_12', 'attn_norm_rembert_new_head_23_9')
0.6320	0.5033	0.3997	0.6295	0.4756	0.3578
('attn_norm_metrics_new_head_8_12', 'attn_norm_unbabel_new_head_21_12', 'attn_norm_rembert_new_head_24_16', 'attn_norm_metrics_new_head_10_7', 'attn_norm_metrics_new_head_10_8', 'attn_norm_metrics_new_layer_10')
0.6416	0.4864	0.3569	0.6525	0.4771	0.3494
('attn_norm_metrics_new_head_8_12', 'attn_norm_unbabel_new_head_21_12', 'attn_norm_rembert_new_head_24_16', 'attn_norm_metrics_new_head_10_7', 'attn_norm_metrics_new_head_10_8', 'attn_norm_rembert_new_head_23_12')
0.6487	0.4991	0.3722	0.6529	0.4810	0.3522
('attn_norm

0.6529	0.4982	0.3700	0.6486	0.4717	0.3414
('attn_norm_unbabel_new_head_21_12', 'attn_norm_rembert_new_head_24_16', 'attn_norm_metrics_new_head_10_8', 'attn_norm_metrics_new_layer_10', 'attn_norm_rembert_new_head_23_12', 'attn_norm_rembert_new_head_23_9')
0.6490	0.4911	0.3641	0.6485	0.4667	0.3343
('attn_norm_unbabel_new_head_21_12', 'attn_norm_metrics_new_head_10_7', 'attn_norm_metrics_new_head_10_8', 'attn_norm_metrics_new_layer_10', 'attn_norm_rembert_new_head_23_12', 'attn_norm_rembert_new_head_23_9')
0.6335	0.4936	0.3805	0.6274	0.4597	0.3329
('attn_norm_rembert_new_head_24_16', 'attn_norm_metrics_new_head_10_7', 'attn_norm_metrics_new_head_10_8', 'attn_norm_metrics_new_layer_10', 'attn_norm_rembert_new_head_23_12', 'attn_norm_rembert_new_head_23_9')
0.6442	0.4897	0.3641	0.6550	0.4759	0.3452
('attn_norm_metrics_new_head_8_12', 'attn_norm_unbabel_new_head_21_12', 'attn_norm_rembert_new_head_24_16', 'attn_norm_metrics_new_head_10_7', 'attn_norm_metrics_new_head_10_8', 'attn_norm_metric

In [None]:
# Disabled one-off helper: copy OpenKiwi explanation outputs into this
# project's `experiments/explanations/` layout so they can be used in the
# ensemble alongside the other explainers.
#
# For every entry listed under /home/mtreviso/OpenKiwi/predictions/:
#   * `ok_lp`   = last 5 characters of the directory name with '_' removed
#                 (presumably the language pair, e.g. `xx-yy` -- TODO confirm);
#   * `ok_expl` = the directory name without its final 6 characters
#                 (presumably the explainer name -- TODO confirm);
#   * entries whose `ok_lp` resolves to 'test' are skipped;
#   * `aggregated_mt_scores.txt`, `aggregated_source_scores.txt` and
#     `sentence_scores` are copied to `mt_scores.txt`, `source_scores.txt` and
#     `sentence_scores.txt` under `experiments/explanations/{ok_lp}_{ok_expl}`.
#
# NOTE(review): depends on the IPython `!ls` shell capture and on a hard-coded
# absolute path on the author's machine; parameterize before re-enabling.
# NOTE(review): `os.mkdir` does not create missing parent directories, so
# `experiments/explanations/` must already exist when this runs.
# openkiwi_expl_dirs = !ls /home/mtreviso/OpenKiwi/predictions/
# for ok_dname in openkiwi_expl_dirs:
#     ok_lp = ok_dname[-5:].replace('_', '')
#     ok_expl = ok_dname[:-6]
#     if ok_lp == 'test':
#         continue
#     ok_dname = os.path.join('/home/mtreviso/OpenKiwi/predictions/', ok_dname)
#     new_dname = 'experiments/explanations/{}_{}'.format(ok_lp, ok_expl)
#     if not os.path.exists(new_dname):
#         os.mkdir(new_dname)
#     shutil.copy(os.path.join(ok_dname, 'aggregated_mt_scores.txt'), os.path.join(new_dname, 'mt_scores.txt'))
#     shutil.copy(os.path.join(ok_dname, 'aggregated_source_scores.txt'), os.path.join(new_dname, 'source_scores.txt'))
#     shutil.copy(os.path.join(ok_dname, 'sentence_scores'), os.path.join(new_dname, 'sentence_scores.txt'))
