Set the CUDA device ID

In [1]:
# Make CUDA device ids follow PCI bus order and expose only GPU 4 to this process.
# These are set before `import jax` (next cells) so they are in place when the GPU backend starts.
# Keep comments on separate lines: text after `%env VAR=` becomes part of the value.
%env CUDA_DEVICE_ORDER=PCI_BUS_ID
%env CUDA_VISIBLE_DEVICES=4

env: CUDA_DEVICE_ORDER=PCI_BUS_ID
env: CUDA_VISIBLE_DEVICES=4


Install dependencies for computing metrics and plots:

In [2]:
# Uncomment to install the metric/plotting dependencies.
# `%pip` (not `!pip3`) installs into the environment this kernel runs in, and the PyPI
# package is `scikit-learn`: the deprecated `sklearn` name no longer installs.
# %pip install numpy scipy pandas seaborn matplotlib scikit-learn

## Basic imports

In [3]:
import jax
import jax.numpy as jnp
import flax
from entmax_jax import sparsemax
import numpy as np
from matplotlib import pyplot as plt
import seaborn as sns
import pandas as pd
from IPython.display import display, HTML
from functools import partial
import json
from entmax_jax.activations import sparsemax, entmax15
from sklearn.metrics import roc_auc_score, average_precision_score

from meta_expl.explainers import load_explainer
from meta_expl.models import load_model
from meta_expl.data.mlqe import dataloader

from evaluate_explanations import evaluate_word_level, evaluate_sentence_level, aggregate_pieces

In [4]:
# data utils
def unroll(list_of_lists):
    """Flatten exactly one level of nesting, e.g. ``[[a, b], [c]] -> [a, b, c]``."""
    flat = []
    for inner in list_of_lists:
        flat.extend(inner)
    return flat

def read_data(lp, split='dev', data_dir='data/mlqepe'):
    """Read one MLQE-PE split of a language pair as a list of per-sentence records.

    Expects five line-aligned files in ``{data_dir}/{lp}/``: ``{split}.src``, ``{split}.mt``,
    ``{split}.da``, ``{split}.src-tags`` and ``{split}.tgt-tags``.

    Args:
        lp: language-pair directory name, e.g. ``'ro-en'``.
        split: file stem of the split (``'train'``, ``'dev'``, ``'test'``).
        data_dir: root directory holding one sub-directory per language pair.

    Returns:
        One dict per line with keys ``original``, ``translation``, ``z_mean``, ``src_tags``,
        ``mt_tags`` and ``da`` (same value as ``z_mean``). Tags are ints: ``OK`` -> 0, ``BAD`` -> 1.

    Raises:
        ValueError: if the five files do not have the same number of lines (previously
            ``zip`` silently truncated to the shortest file).
    """
    def tags_to_ints(line):
        return list(map(int, line.strip().replace('OK', '0').replace('BAD', '1').split()))

    def read_lines(ext, parse):
        # `with` closes the handle (the old version leaked five open files per call);
        # explicit utf-8 because several language pairs contain non-ASCII text
        path = '{}/{}/{}.{}'.format(data_dir, lp, split, ext)
        with open(path, 'r', encoding='utf-8') as f:
            return [parse(line) for line in f]

    data = {
        'original': read_lines('src', str.strip),
        'translation': read_lines('mt', str.strip),
        'z_mean': read_lines('da', lambda line: float(line.strip())),
        'src_tags': read_lines('src-tags', tags_to_ints),
        'mt_tags': read_lines('tgt-tags', tags_to_ints),
    }
    data['da'] = data['z_mean']
    lengths = {key: len(values) for key, values in data.items()}
    if len(set(lengths.values())) > 1:
        raise ValueError('misaligned files for {}/{}: {}'.format(lp, split, lengths))
    return [dict(zip(data.keys(), values)) for values in zip(*data.values())]

def read_data_all(lps, split='dev'):
    """Concatenate the ``read_data(lp, split)`` records of several language pairs, in ``lps`` order.

    Each record carries the same keys as ``read_data`` returns (including ``da``).

    Fix: the previous version set ``data['da'] = data['z_mean']`` (the *same* list) and then
    extended every key, so each score was appended twice to that shared list. After ``zip``,
    every language pair after the first received another pair's ``z_mean``/``da`` values.
    """
    return [record for lp in lps for record in read_data(lp, split)]

## Define arguments and load the data, tokenizer, model, and explainers

In [5]:
# ---- experiment arguments ----
# model / run identifiers
arch = 'xlm-roberta-base'
arch_mtl = 'xlm-r'
setup = 'no_teacher'  # one of: "no_teacher", "static_teacher", "learnable_teacher"
# NOTE(review): `seed` is not used below; the explainer PRNG is seeded with a literal 11.
seed = 1

# data: language pair ("en-de", "en-zh", "et-en", "ne-en", "ro-en", "ru-en") and batching
lp = 'ro-en'
max_len = 256
batch_size = 16

# task head
num_classes = 1
task_type = "regression"

# XLM models separate segments with </s>, other (BERT-style) models with [SEP]
if 'xlm' in arch:
    sep_token = "</s>"
else:
    sep_token = "[SEP]"
# bind the separator once so every later dataloader(...) call uses it
dataloader = partial(dataloader, sep_token=sep_token)

# trained teacher model and its meta-learned explainer
teacher_dir = 'data/mlqe-xlmr-models/teacher_dir'
explainer_dir = 'data/mlqe-xlmr-models/teacher_expl_dir'

In [6]:
# Create dummy inputs for model instantiation (also reused below to build the gradient explainers).
# Fix: the original put `arange` into token_type_ids (1-D, values up to max_len - 1) and `ones`
# into position_ids, which looks swapped. XLM-R uses a single segment id (0), and positions
# should increase along the sequence. Both now have the same (batch_size, max_len) shape as input_ids.
input_ids = jnp.ones((batch_size, max_len), jnp.int32)
dummy_inputs = {
    "input_ids": input_ids,
    "attention_mask": jnp.ones_like(input_ids),
    "token_type_ids": jnp.zeros_like(input_ids),
    "position_ids": jnp.broadcast_to(
        jnp.arange(input_ids.shape[-1], dtype=jnp.int32), input_ids.shape
    ),
}
dummy_inputs['input_ids'].shape

(16, 256)

### load data

In [7]:
# read the train / dev / test splits of the selected language pair `lp`
train_data, valid_data, test_data = (read_data(lp, split) for split in ("train", "dev", "test"))

### load tokenizer

In [8]:
from transformers import XLMRobertaTokenizerFast

# tokenizer of the pretrained checkpoint named by `arch`
tokenizer = XLMRobertaTokenizerFast.from_pretrained(arch)
# special-token ids (not referenced by the other cells in this notebook)
cls_id, sep_id, pad_id = (tokenizer.cls_token_id, tokenizer.sep_token_id, tokenizer.pad_token_id)

### load model and explainer

In [10]:
# Load the trained teacher from `teacher_dir` and its trained (meta-learned) explainer from `explainer_dir`.
# NOTE(review): the third value of load_model appears to be model state; it is passed as `state=`
# to every explainer built in this notebook.
teacher, teacher_params, dummy_state = load_model(teacher_dir, dummy_inputs, batch_size, max_len)
teacher_explainer, teacher_explainer_params = load_explainer(explainer_dir, dummy_inputs, state=dummy_state)

In [27]:
from meta_expl.utils import PRNGSequence
from meta_expl.explainers import create_explainer

# One PRNG key stream shared by the create_explainer(...) calls in this notebook.
# NOTE(review): seeded with a literal 11, not with the `seed` argument from the config cell.
keyseq = PRNGSequence(11)

# Baseline attention explainer that is not meta-trained (all layers, all heads).
# The constructor arguments get their own `..._args` name; previously the same variable
# first held this dict and was then overwritten with the explainer params.
teacher_explainer_args_non_trained = {
    'normalize_head_coeffs': 'sparsemax',
    'normalizer_fn': 'softmax',
    'aggregator_idx': 'mean',
    'aggregator_dim': 'row',
    'init_fn': 'uniform',
    'layer_idx': None,
    'head_idx': None,
}
teacher_explainer_non_trained, teacher_explainer_params_non_trained = create_explainer(
    key=next(keyseq),
    inputs=dummy_inputs,
    state=dummy_state,
    explainer_type='attention_explainer',
    explainer_args=teacher_explainer_args_non_trained,
)

In [14]:
# Attention explainer restricted to a single (layer, head).
# NOTE(review): the indices look 0-based. (9, 5) would be layer 10 / head 6 in the 1-based
# listing of non-zero head coefficients in this notebook; confirm against create_explainer.
# Arguments live under `..._args` so the name is not reused for the returned params.
best_head_teacher_explainer_args = {
    'normalize_head_coeffs': 'sparsemax',
    'normalizer_fn': 'softmax',
    'aggregator_idx': 'mean',
    'aggregator_dim': 'row',
    'init_fn': 'uniform',
    'layer_idx': 9,
    'head_idx': 5,
}
best_head_teacher_explainer, best_head_teacher_explainer_params = create_explainer(
    key=next(keyseq),
    inputs=dummy_inputs,
    state=dummy_state,
    explainer_type='attention_explainer',
    explainer_args=best_head_teacher_explainer_args,
)

In [40]:
# Attention explainer restricted to one layer (all heads of that layer).
# NOTE(review): layer_idx is presumably 0-based (layer 11 in 1-based numbering). The old
# `#9, None` comments were leftovers from the best-head cell.
# Arguments live under `..._args` so the name is not reused for the returned params.
best_layer_teacher_explainer_args = {
    'normalize_head_coeffs': 'sparsemax',
    'normalizer_fn': 'softmax',
    'aggregator_idx': 'mean',
    'aggregator_dim': 'row',
    'init_fn': 'uniform',
    'layer_idx': 10,
    'head_idx': None,
}
best_layer_teacher_explainer, best_layer_teacher_explainer_params = create_explainer(
    key=next(keyseq),
    inputs=dummy_inputs,
    state=dummy_state,
    explainer_type='attention_explainer',
    explainer_args=best_layer_teacher_explainer_args,
)

In [18]:
# Gradient x input explainer: it needs the teacher's embedding-gradient function.
input_gradient_model_extras = {
    "grad_fn": teacher.apply(teacher_params, dummy_inputs, method=teacher.embeddings_grad_fn),
}
input_gradient_teacher_explainer, input_gradient_teacher_explainer_params = create_explainer(
    key=next(keyseq),
    inputs=dummy_inputs,
    state=dummy_state,
    explainer_type='gradient_input_explainer',
    model_extras=input_gradient_model_extras,
)

In [19]:
# Integrated-gradients explainer: it also needs the teacher's embedding-gradient function.
int_gradient_model_extras = {
    "grad_fn": teacher.apply(teacher_params, dummy_inputs, method=teacher.embeddings_grad_fn),
}
int_gradient_teacher_explainer, int_gradient_teacher_explainer_params = create_explainer(
    key=next(keyseq),
    inputs=dummy_inputs,
    state=dummy_state,
    explainer_type='integrated_gradients_explainer',
    model_extras=int_gradient_model_extras,
)

### look at the coefficients

In [20]:
# Sparsemax-normalized head coefficients of the meta-learned explainer, shown as a 12 x 12 grid
# (rows are averaged as "layers" further down). Zeros are heads the explainer ignores.
sparsemax(teacher_explainer_params['params']['head_coeffs']).reshape(12, 12)

DeviceArray([[0.05825656, 0.05605122, 0.06579615, 0.06052823, 0.05019351,
              0.0858976 , 0.04817941, 0.05621345, 0.01868655, 0.02941416,
              0.04372283, 0.05819663],
             [0.        , 0.        , 0.        , 0.        , 0.        ,
              0.        , 0.        , 0.        , 0.        , 0.        ,
              0.        , 0.        ],
             [0.        , 0.        , 0.        , 0.        , 0.        ,
              0.        , 0.        , 0.        , 0.        , 0.        ,
              0.        , 0.        ],
             [0.        , 0.        , 0.        , 0.        , 0.        ,
              0.        , 0.        , 0.        , 0.        , 0.        ,
              0.        , 0.        ],
             [0.        , 0.        , 0.        , 0.        , 0.        ,
              0.        , 0.        , 0.        , 0.        , 0.        ,
              0.0433885 , 0.        ],
             [0.        , 0.        , 0.        , 0.        , 0. 

In [21]:
# list the (layer, head) pairs with a non-zero coefficient, printed 1-based
hc = sparsemax(teacher_explainer_params['params']['head_coeffs']).reshape(12, 12)
nonzero_rows, nonzero_cols = hc.nonzero()
for layer, head in zip(nonzero_rows, nonzero_cols):
    print(layer + 1, head + 1)

1 1
1 2
1 3
1 4
1 5
1 6
1 7
1 8
1 9
1 10
1 11
1 12
5 11
8 9
8 10
9 6
9 8
10 6
11 2
11 4


In [22]:
# check the layers with the highest coefficients: mean head coefficient per layer,
# as (1-based layer, mean) pairs sorted in ascending order
layer_coeffs = sparsemax(teacher_explainer_params['params']['head_coeffs']).reshape(12, 12).mean(-1).tolist()
sorted(enumerate(layer_coeffs, start=1), key=lambda pair: pair[1])

[(2, 0.0),
 (3, 0.0),
 (4, 0.0),
 (6, 0.0),
 (7, 0.0),
 (12, 0.0),
 (11, 0.003190424060449004),
 (5, 0.003615708090364933),
 (10, 0.004049660172313452),
 (9, 0.005110216327011585),
 (8, 0.014772631227970123),
 (1, 0.052594687789678574)]

## Get explanations

In [50]:
def get_expls(data, t, t_p, t_e, t_e_p, s=None, s_p=None, s_e=None, s_e_p=None, is_grad_based=False):
    """Run a model and an explainer over ``data`` and collect per-example results.

    Uses the notebook-level ``dataloader``, ``tokenizer``, ``batch_size``, ``max_len`` and
    ``task_type``.

    Args:
        data: list of example dicts (as built by ``read_data``).
        t, t_p: teacher model and its parameters.
        t_e, t_e_p: explainer applied to the teacher, and its parameters.
        s, s_p, s_e, s_e_p: optional student model / explainer and their parameters. If ``s`` is
            given, the student prediction and explanation are computed but NOT returned.
        is_grad_based: if True, pass the teacher's ``embeddings_grad_fn`` result to the explainer
            as ``grad_fn`` (the gradient-based explainers need it).

    Returns:
        ``(all_tokens, all_masks, all_explanations, all_outputs)``, one entry per example.
        Tokens, masks and explanations are cut to the example's attention-mask length. Masks are
        True for pieces starting with '▁'. Outputs are the teacher predictions (argmax for
        classification).
    """
    all_tokens = []
    all_masks = []
    all_explanations = []
    all_outputs = []
    # ceil division; the old `len(data) // batch_size + 1` over-counted when len(data)
    # is a multiple of batch_size
    num_batches = (len(data) + batch_size - 1) // batch_size
    batches = dataloader(data, tokenizer, batch_size=batch_size, max_len=max_len, shuffle=False)
    for batch_idx, (x, _y) in enumerate(batches):
        print('{} of {}'.format(batch_idx + 1, num_batches), end='\r')

        y_teacher, teacher_attn = t.apply(t_p, **x, deterministic=True)
        y_teacher = jnp.argmax(y_teacher, axis=-1) if task_type == "classification" else y_teacher

        if is_grad_based:
            teacher_extras = {
                "grad_fn": t.apply(t_p, x, method=t.embeddings_grad_fn)
            }
            teacher_expl, _ = t_e.apply(t_e_p, x, teacher_attn, **teacher_extras)
        else:
            teacher_expl, _ = t_e.apply(t_e_p, x, teacher_attn)

        if s is not None:
            y_student, student_attn = s.apply(s_p, **x)
            y_student = jnp.argmax(y_student, axis=-1) if task_type == "classification" else y_student
            # fixed: this called the undefined global `student_explainer` instead of `s_e`
            student_expl, _ = s_e.apply(s_e_p, x, student_attn)

        # convert everything to Python lists
        batch_tokens = [tokenizer.convert_ids_to_tokens(ids) for ids in x['input_ids'].tolist()]
        batch_masks = [[tk.startswith('▁') for tk in tokens] for tokens in batch_tokens]
        batch_expls = teacher_expl.tolist()
        batch_valid_len = x['attention_mask'].sum(-1).tolist()

        # filter out padding (own loop variables, so the batch index is not shadowed)
        for n, tokens, mask, expl in zip(batch_valid_len, batch_tokens, batch_masks, batch_expls):
            all_tokens.append(tokens[:n])
            all_masks.append(mask[:n])
            all_explanations.append(expl[:n])
        all_outputs.extend(y_teacher.tolist())

    return all_tokens, all_masks, all_explanations, all_outputs

In [26]:
# explanations of the meta-learned explainer on the dev split of `lp`
valid_tokens, valid_masks, valid_explanations, valid_outputs = get_expls(
    valid_data, teacher, teacher_params, teacher_explainer, teacher_explainer_params
)
[len(part) for part in (valid_tokens, valid_masks, valid_explanations, valid_outputs)]

63 of 63

[1000, 1000, 1000, 1000]

### Aggregate scores for word pieces in SRC and MT independently

In [28]:
import torch
from utils import aggregate_pieces

def get_src_and_mt_explanations(all_tokens, all_fp_masks, all_explanations, reduction):
    """Split every example into SRC / MT parts and pool word-piece scores into word scores.

    Each piece sequence is cut right after the first ``tokenizer.sep_token``; the source part
    keeps that separator. Scores are pooled with ``aggregate_pieces(scores, first_piece_mask,
    reduction)``. Then the first and last pooled entries (<s>, </s>) are dropped from the source,
    and the last one (</s>) from the MT.

    Returns:
        ``(src_expls, mt_expls, src_pieces, mt_pieces)``: per-example word-level scores for
        SRC and MT, and the raw word pieces of each side.
    """
    src_expls, mt_expls, src_pieces, mt_pieces = [], [], [], []
    for pieces, scores, first_piece in zip(all_tokens, all_explanations, all_fp_masks):
        # assumes the "<s> src </s> mt </s>" layout (no CLS before the mt part)
        cut = pieces.index(tokenizer.sep_token) + 1

        src_words = aggregate_pieces(torch.tensor(scores[:cut]), torch.tensor(first_piece[:cut]), reduction)
        mt_words = aggregate_pieces(torch.tensor(scores[cut:]), torch.tensor(first_piece[cut:]), reduction)

        src_pieces.append(pieces[:cut])
        mt_pieces.append(pieces[cut:])
        src_expls.append(src_words.tolist()[1:-1])  # drop <s> and </s>
        mt_expls.append(mt_words.tolist()[:-1])  # drop the final </s>
    return src_expls, mt_expls, src_pieces, mt_pieces

In [29]:
# how word-piece scores are pooled into one score per word: 'first', 'sum', 'mean' or 'max'
reduction = 'sum'
src_expls, mt_expls, src_pieces, mt_pieces = get_src_and_mt_explanations(
    valid_tokens, valid_masks, valid_explanations, reduction
)

## Evaluating explanations by comparing explanations with word-level QE tags

In [30]:
# gold annotations of the dev set
gold_src_tokens, gold_mt_tokens = (
    [example[field].split() for example in valid_data] for field in ('original', 'translation')
)
gold_expls_src, gold_expls_mt, gold_scores = (
    [example[field] for example in valid_data] for field in ('src_tags', 'mt_tags', 'z_mean')
)

# predictions: pooled explanation scores (word level) and flattened model outputs (sentence level)
pred_expls_src, pred_expls_mt = src_expls, mt_expls
pred_scores = unroll(valid_outputs)

In [31]:
# Sentence level: gold `z_mean` scores vs. the teacher's predicted scores.
_ = evaluate_sentence_level(gold_scores, pred_scores)

Pearson: 0.7941
Spearman: 0.7420
MAE: 0.5030
RMSE: 0.6408


In [32]:
# Word level, source side: gold OK(0)/BAD(1) tags vs. pooled explanation scores.
_ = evaluate_word_level(gold_expls_src, pred_expls_src)

AUC score: 0.7050
AP score: 0.5724
Recall at top-K: 0.4449


In [33]:
# Word level, MT side: gold OK(0)/BAD(1) tags vs. pooled explanation scores.
_ = evaluate_word_level(gold_expls_mt, pred_expls_mt)

AUC score: 0.7028
AP score: 0.5823
Recall at top-K: 0.4555


## Evaluate all LPs

In [34]:
def filter_diff_seq_len(gold, pred):
    """Keep only the (gold, pred) pairs whose sequences have the same length.

    Prints how many pairs were dropped and returns ``(new_gold, new_pred)``.
    """
    new_gold, new_pred = [], []
    n_dropped = 0
    for g, p in zip(gold, pred):
        if len(p) != len(g):
            n_dropped += 1
            continue
        new_gold.append(g)
        new_pred.append(p)
    print('filtered:', n_dropped)
    return new_gold, new_pred

In [49]:
def _report_plausibility(lp, data, tokens, masks, explanations, outputs, reduction):
    """Print sentence-level and SRC/MT word-level plausibility metrics for one language pair."""
    print('')
    print(lp)
    print('----------')
    src_expls, mt_expls = get_src_and_mt_explanations(
        tokens, masks, explanations, reduction=reduction
    )[:2]
    gold_expls_src = [inp['src_tags'] for inp in data]
    gold_expls_mt = [inp['mt_tags'] for inp in data]
    gold_scores = [inp['z_mean'] for inp in data]
    pred_scores = unroll(outputs)
    gold_expls_src, pred_expls_src = filter_diff_seq_len(gold_expls_src, src_expls)
    gold_expls_mt, pred_expls_mt = filter_diff_seq_len(gold_expls_mt, mt_expls)
    evaluate_sentence_level(gold_scores, pred_scores)
    evaluate_word_level(gold_expls_src, pred_expls_src)
    evaluate_word_level(gold_expls_mt, pred_expls_mt)


def eval_plausibility_all_lps(t, t_p, t_e, t_e_p, split='dev', is_grad_based=False,
                              reduction='sum', langpairs=None):
    """Evaluate an explainer's plausibility on each language pair and on their union ("all").

    Args:
        t, t_p: teacher model and parameters.
        t_e, t_e_p: explainer and its parameters.
        split: data split to read for every language pair.
        is_grad_based: forwarded to ``get_expls`` (gradient-based explainers).
        reduction: word-piece pooling passed to ``get_src_and_mt_explanations`` (default 'sum').
        langpairs: individual language pairs to evaluate; defaults to the six MLQE-PE pairs.

    The "all" results pool the per-pair predictions already computed, instead of re-reading
    every pair and running the model a second time. This also sidesteps ``read_data_all``,
    whose aliased ``z_mean``/``da`` list gave wrong gold scores for every pair after the first.
    """
    if langpairs is None:
        langpairs = ["en-de", "en-zh", "et-en", "ne-en", "ro-en", "ru-en"]
    # pooled (data, tokens, masks, explanations, outputs) over all language pairs
    pooled = ([], [], [], [], [])
    for lp in langpairs:
        data = read_data(lp, split)
        results = get_expls(data, t, t_p, t_e, t_e_p, is_grad_based=is_grad_based)
        for acc, part in zip(pooled, (data,) + tuple(results)):
            acc.extend(part)
        _report_plausibility(lp, data, *results, reduction)
    if pooled[0]:
        _report_plausibility("all", *pooled, reduction)

### meta-learned explainer

In [38]:
# meta-learned explainer (loaded from `explainer_dir`)
eval_plausibility_all_lps(
    t=teacher,
    t_p=teacher_params,
    t_e=teacher_explainer,
    t_e_p=teacher_explainer_params,
    split='dev',
)

63 of 63
en-de
----------
filtered: 0
filtered: 0
Pearson: 0.4250
Spearman: 0.4282
MAE: 0.5056
RMSE: 0.6900
AUC score: 0.6426
AP score: 0.4733
Recall at top-K: 0.3470
AUC score: 0.6534
AP score: 0.5085
Recall at top-K: 0.3781
63 of 63
en-zh
----------
filtered: 0
filtered: 0
Pearson: 0.4491
Spearman: 0.4538
MAE: 0.5167
RMSE: 0.6402
AUC score: 0.6808
AP score: 0.5521
Recall at top-K: 0.4279
AUC score: 0.5191
AP score: 0.4727
Recall at top-K: 0.3612
63 of 63
et-en
----------
filtered: 0
filtered: 0
Pearson: 0.6332
Spearman: 0.6368
MAE: 0.5586
RMSE: 0.6877
AUC score: 0.6563
AP score: 0.5336
Recall at top-K: 0.4050
AUC score: 0.6423
AP score: 0.5209
Recall at top-K: 0.4038
63 of 63
ne-en
----------
filtered: 0
filtered: 0
Pearson: 0.6564
Spearman: 0.6537
MAE: 0.5005
RMSE: 0.6361
AUC score: 0.6564
AP score: 0.7447
Recall at top-K: 0.6555
AUC score: 0.5409
AP score: 0.7359
Recall at top-K: 0.6748
63 of 63
ro-en
----------
filtered: 0
filtered: 0
Pearson: 0.7941
Spearman: 0.7420
MAE: 0.5030
R

### all attention layers and heads

In [41]:
# untrained attention explainer over all layers and heads
eval_plausibility_all_lps(
    t=teacher,
    t_p=teacher_params,
    t_e=teacher_explainer_non_trained,
    t_e_p=teacher_explainer_params_non_trained,
    split='dev',
)

63 of 63
en-de
----------
filtered: 0
filtered: 0
Pearson: 0.4250
Spearman: 0.4282
MAE: 0.5056
RMSE: 0.6900
AUC score: 0.6032
AP score: 0.3703
Recall at top-K: 0.2211
AUC score: 0.6302
AP score: 0.4796
Recall at top-K: 0.3478
63 of 63
en-zh
----------
filtered: 0
filtered: 0
Pearson: 0.4491
Spearman: 0.4538
MAE: 0.5167
RMSE: 0.6402
AUC score: 0.6793
AP score: 0.5264
Recall at top-K: 0.4004
AUC score: 0.5152
AP score: 0.4522
Recall at top-K: 0.3411
63 of 63
et-en
----------
filtered: 0
filtered: 0
Pearson: 0.6332
Spearman: 0.6368
MAE: 0.5586
RMSE: 0.6877
AUC score: 0.5966
AP score: 0.4451
Recall at top-K: 0.3131
AUC score: 0.6055
AP score: 0.4696
Recall at top-K: 0.3580
63 of 63
ne-en
----------
filtered: 0
filtered: 0
Pearson: 0.6564
Spearman: 0.6537
MAE: 0.5005
RMSE: 0.6361
AUC score: 0.5823
AP score: 0.6946
Recall at top-K: 0.6150
AUC score: 0.5465
AP score: 0.7314
Recall at top-K: 0.6822
63 of 63
ro-en
----------
filtered: 0
filtered: 0
Pearson: 0.7941
Spearman: 0.7420
MAE: 0.5030
R

### gradient x input

In [52]:
# gradient x input explainer (needs the teacher's gradient function at every batch)
eval_plausibility_all_lps(
    t=teacher,
    t_p=teacher_params,
    t_e=input_gradient_teacher_explainer,
    t_e_p=input_gradient_teacher_explainer_params,
    split='dev',
    is_grad_based=True,
)

63 of 63
en-de
----------
filtered: 0
filtered: 0
Pearson: 0.4250
Spearman: 0.4282
MAE: 0.5056
RMSE: 0.6900
AUC score: 0.5799
AP score: 0.4146
Recall at top-K: 0.2876
AUC score: 0.5952
AP score: 0.4574
Recall at top-K: 0.3381
63 of 63
en-zh
----------
filtered: 0
filtered: 0
Pearson: 0.4491
Spearman: 0.4538
MAE: 0.5167
RMSE: 0.6402
AUC score: 0.6132
AP score: 0.4871
Recall at top-K: 0.3530
AUC score: 0.5061
AP score: 0.4421
Recall at top-K: 0.3324
63 of 63
et-en
----------
filtered: 0
filtered: 0
Pearson: 0.6332
Spearman: 0.6368
MAE: 0.5586
RMSE: 0.6877
AUC score: 0.6024
AP score: 0.4689
Recall at top-K: 0.3401
AUC score: 0.5356
AP score: 0.4436
Recall at top-K: 0.3320
63 of 63
ne-en
----------
filtered: 0
filtered: 0
Pearson: 0.6564
Spearman: 0.6537
MAE: 0.5005
RMSE: 0.6361
AUC score: 0.6145
AP score: 0.7128
Recall at top-K: 0.6268
AUC score: 0.4922
AP score: 0.7057
Recall at top-K: 0.6428
63 of 63
ro-en
----------
filtered: 0
filtered: 0
Pearson: 0.7941
Spearman: 0.7420
MAE: 0.5030
R

### integrated gradients

In [53]:
# integrated-gradients explainer (needs the teacher's gradient function at every batch)
eval_plausibility_all_lps(
    t=teacher,
    t_p=teacher_params,
    t_e=int_gradient_teacher_explainer,
    t_e_p=int_gradient_teacher_explainer_params,
    split='dev',
    is_grad_based=True,
)

63 of 63
en-de
----------
filtered: 0
filtered: 0
Pearson: 0.4250
Spearman: 0.4282
MAE: 0.5056
RMSE: 0.6900
AUC score: 0.5855
AP score: 0.3960
Recall at top-K: 0.2602
AUC score: 0.6028
AP score: 0.4574
Recall at top-K: 0.3267
63 of 63
en-zh
----------
filtered: 0
filtered: 0
Pearson: 0.4491
Spearman: 0.4538
MAE: 0.5167
RMSE: 0.6402
AUC score: 0.6275
AP score: 0.4976
Recall at top-K: 0.3707
AUC score: 0.4932
AP score: 0.4519
Recall at top-K: 0.3442
63 of 63
et-en
----------
filtered: 0
filtered: 0
Pearson: 0.6332
Spearman: 0.6368
MAE: 0.5586
RMSE: 0.6877
AUC score: 0.6042
AP score: 0.4778
Recall at top-K: 0.3446
AUC score: 0.5212
AP score: 0.4157
Recall at top-K: 0.3084
63 of 63
ne-en
----------
filtered: 0
filtered: 0
Pearson: 0.6564
Spearman: 0.6537
MAE: 0.5005
RMSE: 0.6361
AUC score: 0.6355
AP score: 0.7299
Recall at top-K: 0.6406
AUC score: 0.4848
AP score: 0.6963
Recall at top-K: 0.6507
63 of 63
ro-en
----------
filtered: 0
filtered: 0
Pearson: 0.7941
Spearman: 0.7420
MAE: 0.5030
R

### best attention layer

In [44]:
# attention explainer restricted to the chosen single layer
eval_plausibility_all_lps(
    t=teacher,
    t_p=teacher_params,
    t_e=best_layer_teacher_explainer,
    t_e_p=best_layer_teacher_explainer_params,
    split='dev',
)

63 of 63
en-de
----------
filtered: 0
filtered: 0
Pearson: 0.4250
Spearman: 0.4282
MAE: 0.5056
RMSE: 0.6900
AUC score: 0.6370
AP score: 0.4117
Recall at top-K: 0.2774
AUC score: 0.6468
AP score: 0.4944
Recall at top-K: 0.3659
63 of 63
en-zh
----------
filtered: 0
filtered: 0
Pearson: 0.4491
Spearman: 0.4538
MAE: 0.5167
RMSE: 0.6402
AUC score: 0.5755
AP score: 0.4636
Recall at top-K: 0.3471
AUC score: 0.5258
AP score: 0.4814
Recall at top-K: 0.3716
63 of 63
et-en
----------
filtered: 0
filtered: 0
Pearson: 0.6332
Spearman: 0.6368
MAE: 0.5586
RMSE: 0.6877
AUC score: 0.6420
AP score: 0.4908
Recall at top-K: 0.3636
AUC score: 0.6840
AP score: 0.5630
Recall at top-K: 0.4460
63 of 63
ne-en
----------
filtered: 0
filtered: 0
Pearson: 0.6564
Spearman: 0.6537
MAE: 0.5005
RMSE: 0.6361
AUC score: 0.6824
AP score: 0.7523
Recall at top-K: 0.6751
AUC score: 0.6812
AP score: 0.8055
Recall at top-K: 0.7226
63 of 63
ro-en
----------
filtered: 0
filtered: 0
Pearson: 0.7941
Spearman: 0.7420
MAE: 0.5030
R

### best attention head

In [45]:
# attention explainer restricted to the chosen single (layer, head)
eval_plausibility_all_lps(
    t=teacher,
    t_p=teacher_params,
    t_e=best_head_teacher_explainer,
    t_e_p=best_head_teacher_explainer_params,
    split='dev',
)

63 of 63
en-de
----------
filtered: 0
filtered: 0
Pearson: 0.4250
Spearman: 0.4282
MAE: 0.5056
RMSE: 0.6900
AUC score: 0.6727
AP score: 0.4776
Recall at top-K: 0.3459
AUC score: 0.6713
AP score: 0.5155
Recall at top-K: 0.3882
63 of 63
en-zh
----------
filtered: 0
filtered: 0
Pearson: 0.4491
Spearman: 0.4538
MAE: 0.5167
RMSE: 0.6402
AUC score: 0.5591
AP score: 0.4675
Recall at top-K: 0.3478
AUC score: 0.5385
AP score: 0.4850
Recall at top-K: 0.3742
63 of 63
et-en
----------
filtered: 0
filtered: 0
Pearson: 0.6332
Spearman: 0.6368
MAE: 0.5586
RMSE: 0.6877
AUC score: 0.6960
AP score: 0.5576
Recall at top-K: 0.4243
AUC score: 0.6961
AP score: 0.5714
Recall at top-K: 0.4589
63 of 63
ne-en
----------
filtered: 0
filtered: 0
Pearson: 0.6564
Spearman: 0.6537
MAE: 0.5005
RMSE: 0.6361
AUC score: 0.6956
AP score: 0.7635
Recall at top-K: 0.6817
AUC score: 0.6919
AP score: 0.8145
Recall at top-K: 0.7285
63 of 63
ro-en
----------
filtered: 0
filtered: 0
Pearson: 0.7941
Spearman: 0.7420
MAE: 0.5030
R

### last layer attention

In [68]:
# Softmax-normalized scalar-mix weights of the teacher, one per mixed hidden state.
flax.linen.softmax(teacher_params['params']['scalarmix']['coeffs'])  # first item is the embedding layer

DeviceArray([0.07685792, 0.07677209, 0.0766765 , 0.07660621, 0.07668299,
             0.0766904 , 0.07666271, 0.07671636, 0.07712432, 0.07718218,
             0.07728937, 0.07735462, 0.07738439], dtype=float32)

In [69]:
# Index of the largest scalar-mix weight (index 0 is the embedding layer).
flax.linen.softmax(teacher_params['params']['scalarmix']['coeffs']).argmax()

DeviceArray(12, dtype=int32)

In [55]:
# Attention explainer restricted to the last transformer layer. The scalar mix has 13 weights
# (index 0 = embeddings), so its argmax 12 is the last of the 12 layers; layer_idx 11 presumably
# addresses that layer 0-based.
# Uses `last_layer_*` names. This cell used to overwrite `best_layer_teacher_explainer(_params)`,
# so re-running the "best attention layer" cell afterwards silently evaluated this layer instead.
last_layer_teacher_explainer_args = {
    'normalize_head_coeffs': 'sparsemax',
    'normalizer_fn': 'softmax',
    'aggregator_idx': 'mean',
    'aggregator_dim': 'row',
    'init_fn': 'uniform',
    'layer_idx': 11,
    'head_idx': None,
}
last_layer_teacher_explainer, last_layer_teacher_explainer_params = create_explainer(
    key=next(keyseq),
    inputs=dummy_inputs,
    state=dummy_state,
    explainer_type='attention_explainer',
    explainer_args=last_layer_teacher_explainer_args,
)
eval_plausibility_all_lps(
    teacher,
    teacher_params,
    last_layer_teacher_explainer,
    last_layer_teacher_explainer_params,
    split='dev'
)

63 of 63
en-de
----------
filtered: 0
filtered: 0
Pearson: 0.4250
Spearman: 0.4282
MAE: 0.5056
RMSE: 0.6900
AUC score: 0.5082
AP score: 0.3179
Recall at top-K: 0.1803
AUC score: 0.4891
AP score: 0.3287
Recall at top-K: 0.1978
63 of 63
en-zh
----------
filtered: 0
filtered: 0
Pearson: 0.4491
Spearman: 0.4538
MAE: 0.5167
RMSE: 0.6402
AUC score: 0.6086
AP score: 0.4767
Recall at top-K: 0.3456
AUC score: 0.4945
AP score: 0.4421
Recall at top-K: 0.3293
63 of 63
et-en
----------
filtered: 0
filtered: 0
Pearson: 0.6332
Spearman: 0.6368
MAE: 0.5586
RMSE: 0.6877
AUC score: 0.5070
AP score: 0.4023
Recall at top-K: 0.2681
AUC score: 0.5028
AP score: 0.3712
Recall at top-K: 0.2703
63 of 63
ne-en
----------
filtered: 0
filtered: 0
Pearson: 0.6564
Spearman: 0.6537
MAE: 0.5005
RMSE: 0.6361
AUC score: 0.5478
AP score: 0.6808
Recall at top-K: 0.5897
AUC score: 0.4768
AP score: 0.6785
Recall at top-K: 0.6477
63 of 63
ro-en
----------
filtered: 0
filtered: 0
Pearson: 0.7941
Spearman: 0.7420
MAE: 0.5030
R

## Plotting the distribution of predictions and AUC scores

In [None]:
# define options for seaborn: hide the top/right spines, light dotted grid
custom_params = {
    'axes.spines.right': False,
    'axes.spines.top': False,
    'grid.color': '.85',
    'grid.linestyle': ':'
}
# plain call: the old `_ = sns.set_theme(...),` built a throwaway 1-tuple because of the trailing comma
sns.set_theme(style='whitegrid', rc=custom_params)

def plot_da_vs_expl_metric(metric_fn, das, e_golds, e_preds):
    """Plot a per-sentence explanation metric against the sentence-level score.

    For every sentence whose gold tags contain both OK (0) and BAD (1), computes
    ``metric_fn(gold, pred)`` and draws three panels: a 2-D histogram of score vs. metric,
    and the histograms of the scores and of the metric values.

    Args:
        metric_fn: callable ``(gold_tags, pred_scores) -> float``, e.g. ``roc_auc_score``.
        das: one sentence-level score per sentence (x axis, labeled 'da').
        e_golds: per-sentence lists of binary gold tags.
        e_preds: per-sentence lists of predicted word scores.
    """
    # label with the metric's name; the old `str(metric_fn).split()[1]` only worked for
    # objects whose repr is "<function NAME at ...>"
    metric_name = getattr(metric_fn, '__name__', repr(metric_fn))
    x = []
    y = []
    for da, gold, pred in zip(das, e_golds, e_preds):
        # skip all-OK / all-BAD (and empty) sentences: metrics such as ROC AUC need both classes
        if sum(gold) == 0 or sum(gold) == len(gold):
            continue
        y.append(metric_fn(gold, pred))
        x.append(da)
    x = np.array(x)
    y = np.array(y)
    fig, axs = plt.subplots(1, 3, figsize=(16, 4))
    sns.histplot(x=x, y=y, ax=axs[0])
    axs[0].set_xlabel('da')
    axs[0].set_ylabel(metric_name)
    sns.histplot(x, bins=20, ax=axs[1])
    axs[1].set_xlabel('da')
    sns.histplot(y, bins=20, ax=axs[2])
    axs[2].set_xlabel(metric_name)

In [None]:
# plot predicted DA vs. per-sentence AUC, source side (MT side in the next cell)
plot_da_vs_expl_metric(
    metric_fn=roc_auc_score,
    das=pred_scores,
    e_golds=gold_expls_src,
    e_preds=pred_expls_src,
)

In [None]:
# plot predicted DA vs. per-sentence AUC, MT side
plot_da_vs_expl_metric(
    metric_fn=roc_auc_score,
    das=pred_scores,
    e_golds=gold_expls_mt,
    e_preds=pred_expls_mt,
)

## Check results for every layer (slow: one full pass over the dev set per layer)

In [None]:
# Word-level plausibility of single-layer attention explainers, one layer at a time.
# Fixed: this cell called `get_explanations(..., strategy='layer_average', layer_id=...)`, which
# is not defined in this notebook (NameError). It now builds a layer-restricted attention explainer
# (same arguments as the "best attention layer" cell) and runs `get_expls`.
# Results go to `layer_*` names so the dev-set globals used by the plotting cells are not clobbered.
for layer_id in range(12):
    layer_explainer, layer_explainer_params = create_explainer(
        key=next(keyseq),
        inputs=dummy_inputs,
        state=dummy_state,
        explainer_type='attention_explainer',
        explainer_args={
            'normalize_head_coeffs': 'sparsemax',
            'normalizer_fn': 'softmax',
            'aggregator_idx': 'mean',
            'aggregator_dim': 'row',
            'init_fn': 'uniform',
            'layer_idx': layer_id,
            'head_idx': None,
        },
    )
    layer_tokens, layer_masks, layer_explanations, layer_outputs = get_expls(
        valid_data, teacher, teacher_params, layer_explainer, layer_explainer_params
    )
    layer_src_expls, layer_mt_expls = get_src_and_mt_explanations(
        layer_tokens, layer_masks, layer_explanations, reduction='sum'
    )[:2]
    print('LAYER: {}'.format(layer_id))
    evaluate_word_level(gold_expls_src, layer_src_expls)
    evaluate_word_level(gold_expls_mt, layer_mt_expls)
    print('---')

## Check results for every head in every layer (very slow: 12 × 12 = 144 full passes over the dev set)

In [None]:
# Word-level plausibility of single-head attention explainers for every (layer, head).
# Fixed: this cell called `get_explanations(..., strategy='layer_head', ...)`, which is not defined
# in this notebook (NameError). It now builds a (layer, head)-restricted attention explainer (same
# arguments as the "best attention head" cell) and runs `get_expls`.
# Results go to `head_*` names so the dev-set globals used by the plotting cells are not clobbered.
for layer_id in range(12):
    for head_id in range(12):
        head_explainer, head_explainer_params = create_explainer(
            key=next(keyseq),
            inputs=dummy_inputs,
            state=dummy_state,
            explainer_type='attention_explainer',
            explainer_args={
                'normalize_head_coeffs': 'sparsemax',
                'normalizer_fn': 'softmax',
                'aggregator_idx': 'mean',
                'aggregator_dim': 'row',
                'init_fn': 'uniform',
                'layer_idx': layer_id,
                'head_idx': head_id,
            },
        )
        head_tokens, head_masks, head_explanations, head_outputs = get_expls(
            valid_data, teacher, teacher_params, head_explainer, head_explainer_params
        )
        head_src_expls, head_mt_expls = get_src_and_mt_explanations(
            head_tokens, head_masks, head_explanations, reduction='sum'
        )[:2]
        print('LAYER: {} | HEAD: {}'.format(layer_id, head_id))
        evaluate_word_level(gold_expls_src, head_src_expls)
        evaluate_word_level(gold_expls_mt, head_mt_expls)
        print('---')