## Set the CUDA device

In [1]:
# Pin the notebook to a single GPU; device indices follow PCI bus order
# (presumably so the index matches nvidia-smi).
%env CUDA_DEVICE_ORDER=PCI_BUS_ID
%env CUDA_VISIBLE_DEVICES=6

env: CUDA_DEVICE_ORDER=PCI_BUS_ID
env: CUDA_VISIBLE_DEVICES=6


Install dependencies for computing metrics and plots:

In [5]:
# Uncomment to install the metric/plot dependencies. The PyPI package for sklearn is
# `scikit-learn` (the `sklearn` package is deprecated), and `%pip` targets the kernel's env.
# %pip install numpy scipy pandas seaborn matplotlib scikit-learn

## Basic imports

In [2]:
import jax
import jax.numpy as jnp
import flax
from entmax_jax import sparsemax
import numpy as np
from matplotlib import pyplot as plt
import seaborn as sns
import pandas as pd
from IPython.display import display, HTML
from functools import partial
import json
from entmax_jax.activations import sparsemax, entmax15
from sklearn.metrics import roc_auc_score, average_precision_score

from meta_expl.explainers import load_explainer
from meta_expl.models import load_model
from meta_expl.data.mlqe import dataloader

from evaluate_explanations import evaluate_word_level, evaluate_sentence_level, aggregate_pieces

In [3]:
# data utils
def unroll(list_of_lists):
    """Flatten exactly one level of nesting: ``[[a, b], [c]] -> [a, b, c]``."""
    flat = []
    for inner in list_of_lists:
        flat.extend(inner)
    return flat

def read_data(lp, split='dev'):
    """Load one MLQE-PE split of one language pair as a list of per-sentence dicts.

    Reads the line-aligned files ``data/mlqepe/{lp}/{split}.{src,mt,da,src-tags,tgt-tags}``
    (relative to the working directory) and zips them line by line.

    Args:
        lp: language-pair directory name, e.g. ``'ro-en'``.
        split: file prefix of the split, e.g. ``'train'``, ``'dev'`` or ``'test'``.

    Returns:
        One dict per line with keys ``original``, ``translation``, ``z_mean``,
        ``src_tags``, ``mt_tags`` and ``da`` (same value as ``z_mean``). Word tags
        are ints (``OK`` -> 0, ``BAD`` -> 1). If the files have different line
        counts the result is truncated to the shortest one (``zip`` semantics).
    """
    def tags_to_ints(line):
        return list(map(int, line.strip().replace('OK', '0').replace('BAD', '1').split()))

    def read_lines(ext, parse):
        # `with` closes each handle deterministically; explicit UTF-8 because the
        # corpora contain non-ASCII text and the locale default may differ.
        path = 'data/mlqepe/{}/{}.{}'.format(lp, split, ext)
        with open(path, 'r', encoding='utf-8') as f:
            return [parse(line) for line in f]

    data = {
        'original': read_lines('src', str.strip),
        'translation': read_lines('mt', str.strip),
        'z_mean': read_lines('da', lambda line: float(line.strip())),
        'src_tags': read_lines('src-tags', tags_to_ints),
        'mt_tags': read_lines('tgt-tags', tags_to_ints),
    }
    data['da'] = data['z_mean']
    return [dict(zip(data.keys(), values)) for values in zip(*data.values())]

def read_data_all(lps, split='dev'):
    """Concatenate ``read_data(lp, split)`` over several language pairs.

    Args:
        lps: iterable of language-pair names, e.g. ``["en-de", "ro-en"]``.
        split: split name passed through to ``read_data``.

    Returns:
        The per-sentence dicts of every language pair, in the order of ``lps``,
        with the same keys as ``read_data`` (``da`` equal to ``z_mean``).
    """
    # Concatenate whole records instead of per-key lists: aliasing the 'da' list to
    # the 'z_mean' list and extending both appended every score twice, which
    # shifted the DA scores of every language pair after the first one.
    return [row for lp in lps for row in read_data(lp, split)]

## Define arguments and load the data, tokenizer, model and explainer

In [4]:
# arguments
arch = 'xlm-roberta-base'  # also the checkpoint name the tokenizer is loaded from below
arch_mtl = 'xlm-r'  # NOTE(review): not referenced in the cells shown here
setup = 'no_teacher'  # "no_teacher", "static_teacher", "learnable_teacher"; NOTE(review): not referenced below
# langpairs = ["en-de", "en-zh", "et-en", "ne-en", "ro-en", "ru-en"]
lp = 'ro-en'  # language pair for the single-LP train/dev/test data
max_len = 256  # sequence length for the dummy inputs, the model and the dataloader
batch_size = 16
seed = 1  # NOTE(review): not referenced in the cells shown here
sep_token = "</s>" if 'xlm' in arch else "[SEP]"  # XLM-R separator vs. BERT-style separator
# Rebinds the imported `dataloader`; re-running this cell wraps it again with the
# same sep_token, which leaves the behaviour unchanged.
dataloader = partial(dataloader, sep_token=sep_token)
num_classes = 1
task_type = "regression"  # get_explanations only argmaxes the outputs for "classification"
teacher_dir = 'data/mlqe-xlmr-models/teacher_dir'
explainer_dir = 'data/mlqe-xlmr-models/teacher_expl_dir'

In [5]:
# create dummy inputs for model instantiation
# (presumably only their shapes/dtypes matter when initialising parameters)
input_ids = jnp.ones((batch_size, max_len), jnp.int32)
dummy_inputs = {
    "input_ids": input_ids,
    "attention_mask": jnp.ones_like(input_ids),
    # Single-segment inputs use token type 0 everywhere, and positions run
    # 0..max_len-1 in every row. These two entries used to be swapped
    # (types = arange, positions = ones).
    "token_type_ids": jnp.zeros_like(input_ids),
    "position_ids": jnp.broadcast_to(jnp.arange(input_ids.shape[-1]), input_ids.shape),
}
dummy_inputs['input_ids'].shape

INFO:absl:Unable to initialize backend 'tpu_driver': NOT_FOUND: Unable to find driver in registry given worker: 
INFO:absl:Unable to initialize backend 'tpu': INVALID_ARGUMENT: TpuPlatform is not available.


(16, 256)

### load data

In [6]:
# load data for the language pair `lp` from the arguments cell
# NOTE(review): the evaluation loops further down rebind `lp` (ending at "all"), so
# re-running this cell after them would try to read a non-existent "all" directory.
train_data = read_data(lp, "train")
valid_data = read_data(lp, "dev")
test_data = read_data(lp, "test")

### load tokenizer

In [7]:
from transformers import XLMRobertaTokenizerFast
# tokenizer loaded from the same `arch` name as the model arguments
tokenizer = XLMRobertaTokenizerFast.from_pretrained(arch)
# NOTE(review): cls_id / sep_id / pad_id are not referenced in the cells shown here
cls_id = tokenizer.cls_token_id
sep_id = tokenizer.sep_token_id
pad_id = tokenizer.pad_token_id

### load model and explainer

In [8]:
# Load the trained teacher and its trained explainer from disk. dummy_inputs and
# dummy_state are presumably only used to build the parameter structures — confirm.
teacher, teacher_params, dummy_state = load_model(teacher_dir, batch_size, max_len)
teacher_explainer, teacher_explainer_params = load_explainer(explainer_dir, dummy_inputs, state=dummy_state)

In [9]:
# Alternative to loading the explainer from `explainer_dir` (kept for reference):
# create a fresh attention explainer with the settings below instead.
# from meta_expl.utils import PRNGSequence
# from meta_expl.explainers import create_explainer
# keyseq = PRNGSequence(11)
# teacher_explainer_params={
#     'normalize_head_coeffs': 'sparsemax',
#     'normalizer_fn': 'softmax',
#     'aggregator_idx': 'mean',
#     'aggregator_dim': 'row',
#     'init_fn': 'uniform',
#     'layer_idx': 9,  #9, None
#     'head_idx': 5,  #5, None
# }
# explainer_type='attention_explainer'
# teacher_explainer, teacher_explainer_params = create_explainer(next(keyseq), dummy_inputs, dummy_state, 
#                                      explainer_type, explainer_args=teacher_explainer_params)

### look at the coefficients

In [10]:
# sparsemax-normalised head coefficients of the explainer, viewed as a (layer, head) grid
sparsemax(teacher_explainer_params['params']['head_coeffs']).reshape(12, 12)

DeviceArray([[0.05825656, 0.05605122, 0.06579615, 0.06052823, 0.05019351,
              0.0858976 , 0.04817941, 0.05621345, 0.01868655, 0.02941416,
              0.04372283, 0.05819663],
             [0.        , 0.        , 0.        , 0.        , 0.        ,
              0.        , 0.        , 0.        , 0.        , 0.        ,
              0.        , 0.        ],
             [0.        , 0.        , 0.        , 0.        , 0.        ,
              0.        , 0.        , 0.        , 0.        , 0.        ,
              0.        , 0.        ],
             [0.        , 0.        , 0.        , 0.        , 0.        ,
              0.        , 0.        , 0.        , 0.        , 0.        ,
              0.        , 0.        ],
             [0.        , 0.        , 0.        , 0.        , 0.        ,
              0.        , 0.        , 0.        , 0.        , 0.        ,
              0.0433885 , 0.        ],
             [0.        , 0.        , 0.        , 0.        , 0. 

In [11]:
hc = sparsemax(teacher_explainer_params['params']['head_coeffs']).reshape(12, 12)
# print the 1-based (layer, head) pair of every head with a non-zero coefficient
for a, b in zip(*hc.nonzero()):
    print(a+1, b+1)

1 1
1 2
1 3
1 4
1 5
1 6
1 7
1 8
1 9
1 10
1 11
1 12
5 11
8 9
8 10
9 6
9 8
10 6
11 2
11 4


In [12]:
# check the layers with the highest coefficients:
# (1-based layer, mean head coefficient) pairs, in ascending order of the mean
layer_coeffs = sparsemax(teacher_explainer_params['params']['head_coeffs']).reshape(12, 12).mean(-1).tolist()
sorted(enumerate(layer_coeffs, start=1), key=lambda pair: pair[1])

[(2, 0.0),
 (3, 0.0),
 (4, 0.0),
 (6, 0.0),
 (7, 0.0),
 (12, 0.0),
 (11, 0.003190424060449004),
 (5, 0.003615708090364933),
 (10, 0.004049660172313452),
 (9, 0.005110216327011585),
 (8, 0.014772631227970123),
 (1, 0.052594687789678574)]

## Get explanations

In [13]:
def _stack_attentions(teacher_attn):
    """Stack the per-layer attention maps into (batch, layers, heads, seqlen, seqlen)."""
    return jnp.stack(teacher_attn['attentions']).transpose([1, 0, 2, 3, 4])


def _average_over_queries(attn, mask):
    """Average attention rows (axis -2) over non-pad positions.

    ``attn`` is (batch, ..., seqlen, seqlen) and ``mask`` is (batch, seqlen) with
    1 for real tokens. Any axes between batch and the two seqlen axes (e.g. heads)
    are broadcast over. Returns (batch, ..., seqlen).
    """
    extra = attn.ndim - 3  # number of axes between batch and the two seqlen axes
    mask = mask.reshape(mask.shape[:1] + (1,) * extra + mask.shape[1:])
    return (attn * mask[..., None]).sum(-2) / mask.sum(-1)[..., None]


def get_explanations(data, strategy='mtl', layer_id=0, head_id=0):
    """Run the teacher over ``data`` and build one attention-based explanation per sentence.

    Args:
        data: list of per-sentence dicts as returned by ``read_data`` / ``read_data_all``.
        strategy: how the explanation is derived from the teacher:
            - ``'mtl'``: output of the trained teacher explainer.
            - ``'top_layer_head'``: the single head with the largest sparsemax
              head coefficient.
            - ``'layer_average'``: mean over all heads of layer ``layer_id``.
            - ``'layer_head'``: head ``head_id`` of layer ``layer_id``.
            - ``'nonzero_heads'``: one explanation per head with a non-zero
              sparsemax coefficient; each sentence's explanation is then a list
              of per-head lists (not supported by get_src_and_mt_explanations).
        layer_id: 0-based layer index for ``'layer_average'`` / ``'layer_head'``.
        head_id: 0-based head index for ``'layer_head'``.

    Returns:
        ``(tokens, masks, explanations, outputs)`` with one entry per sentence.
        ``tokens``, ``masks`` and ``explanations`` have padding removed (assumes
        right padding). ``masks`` flags word pieces starting with ``'▁'``, and
        ``outputs`` holds the rows of the teacher output as lists.

    Raises:
        ValueError: for an unknown ``strategy`` (previously such values fell into a
            branch that crashed on undefined variables).
    """
    strategies = ('mtl', 'top_layer_head', 'layer_average', 'layer_head', 'nonzero_heads')
    if strategy not in strategies:
        raise ValueError('unknown strategy {!r}; expected one of {}'.format(strategy, strategies))

    # loop-invariant: normalised head coefficients (flat, layer-major as used by divmod below)
    if strategy in ('top_layer_head', 'nonzero_heads'):
        head_coeffs = sparsemax(teacher_explainer_params['params']['head_coeffs'])
        top_joint_id = int(jnp.argmax(head_coeffs))

    all_tokens, all_masks, all_explanations, all_outputs = [], [], [], []
    num_batches = (len(data) + batch_size - 1) // batch_size  # ceil division
    batches = dataloader(data, tokenizer, batch_size=batch_size, max_len=max_len, shuffle=False)
    for batch_idx, (x, _) in enumerate(batches, start=1):
        print('{} of {}'.format(batch_idx, num_batches), end='\r')

        # teacher prediction and attention maps
        y_teacher, teacher_attn = teacher.apply(teacher_params, **x, deterministic=True)
        if task_type == "classification":
            y_teacher = jnp.argmax(y_teacher, axis=-1)
        mask = x['attention_mask']

        if strategy == 'mtl':
            # explanation given to the student by the teacher explainer
            teacher_expl, _ = teacher_explainer.apply(teacher_explainer_params, x, teacher_attn)
        else:
            attentions = _stack_attentions(teacher_attn)
            num_layers, num_heads = attentions.shape[1], attentions.shape[2]
            if strategy == 'top_layer_head':
                top_layer_id, top_head_id = divmod(top_joint_id, num_heads)
                attn = attentions[:, top_layer_id, top_head_id]
            elif strategy == 'layer_average':
                attn = attentions[:, layer_id].mean(1)
            elif strategy == 'layer_head':
                attn = attentions[:, layer_id, head_id]
            else:  # 'nonzero_heads'
                rows, cols = head_coeffs.reshape(num_layers, num_heads).nonzero()
                attn = attentions[:, rows, cols]  # batch x num_nonzero x seqlen x seqlen
            teacher_expl = _average_over_queries(attn, mask)

        # convert to Python lists and drop the padded positions of each sentence
        batch_tokens = [tokenizer.convert_ids_to_tokens(ids) for ids in x['input_ids'].tolist()]
        batch_expls = teacher_expl.tolist()
        for tokens, expl, n in zip(batch_tokens, batch_expls, mask.sum(-1).tolist()):
            tokens = tokens[:n]
            all_tokens.append(tokens)
            all_masks.append([tk.startswith('▁') for tk in tokens])
            if strategy == 'nonzero_heads':
                all_explanations.append([head_expl[:n] for head_expl in expl])
            else:
                all_explanations.append(expl[:n])
        all_outputs.extend(y_teacher.tolist())

    return all_tokens, all_masks, all_explanations, all_outputs

In [None]:
# explanations from the trained teacher explainer on the dev set, then the list lengths
valid_tokens, valid_masks, valid_explanations, valid_outputs = get_explanations(valid_data, strategy='mtl')
[len(seq) for seq in (valid_tokens, valid_masks, valid_explanations, valid_outputs)]

### Aggregate word-piece scores into word scores, separately for SRC and MT

In [14]:
import torch
from utils import aggregate_pieces

def get_src_and_mt_explanations(all_tokens, all_fp_masks, all_explanations, reduction):
    """Split each sentence-pair explanation into SRC and MT and pool word pieces into words.

    Args:
        all_tokens: per-sentence lists of word-piece strings (padding removed).
        all_fp_masks: per-sentence "first piece of a word" flags, aligned with the tokens.
        all_explanations: per-sentence per-piece scores, aligned with the tokens.
        reduction: pooling passed to ``aggregate_pieces`` ('first', 'sum', 'mean', 'max').

    Returns:
        ``(src_expls, mt_expls, src_pieces, mt_pieces)``. These are the word-level
        scores of SRC without their first and last entries, the word-level scores of
        MT without their last entry, and the raw word-piece lists of each side. The
        SRC piece list ends with the separator.
    """
    src_expls, mt_expls, src_pieces, mt_pieces = [], [], [], []
    for tokens, expl, fp_mask in zip(all_tokens, all_explanations, all_fp_masks):
        # everything up to and including the first separator is the source side;
        # assumes the "<s> src </s> mt </s>" layout (no CLS token before mt)
        cut = tokens.index(tokenizer.sep_token) + 1

        # pool word-piece scores into word scores with the torch helper
        src_words = aggregate_pieces(torch.tensor(expl[:cut]), torch.tensor(fp_mask[:cut]), reduction).tolist()
        mt_words = aggregate_pieces(torch.tensor(expl[cut:]), torch.tensor(fp_mask[cut:]), reduction).tolist()

        src_pieces.append(tokens[:cut])
        mt_pieces.append(tokens[cut:])
        # drop the entries of <s> and </s> on the source side, and of </s> on the mt side
        src_expls.append(src_words[1:-1])
        mt_expls.append(mt_words[:-1])
    return src_expls, mt_expls, src_pieces, mt_pieces

In [None]:
# how word pieces are pooled into word scores (passed through to aggregate_pieces)
reduction = 'sum'  # first, sum, mean, max
src_expls, mt_expls, src_pieces, mt_pieces = get_src_and_mt_explanations(
    valid_tokens, valid_masks, valid_explanations, reduction=reduction
)

## Evaluate the explanations against the word-level QE tags

In [None]:
# gold annotations and teacher predictions for the dev set of `lp`
gold_src_tokens = [row['original'].split() for row in valid_data]
gold_mt_tokens = [row['translation'].split() for row in valid_data]
gold_expls_src, gold_expls_mt, gold_scores = (
    [row[key] for row in valid_data] for key in ('src_tags', 'mt_tags', 'z_mean')
)

pred_expls_src, pred_expls_mt = src_expls, mt_expls
pred_scores = unroll(valid_outputs)

In [None]:
# sentence level: gold DA (z_mean) vs. the teacher predictions
_ = evaluate_sentence_level(gold_scores, pred_scores)

In [None]:
# word level, source side. NOTE(review): unlike the loops below, no
# filter_diff_seq_len here — presumably evaluate_word_level needs equal lengths per sentence.
_ = evaluate_word_level(gold_expls_src, pred_expls_src)

In [None]:
# word level, translation side (same caveat about unfiltered length mismatches)
_ = evaluate_word_level(gold_expls_mt, pred_expls_mt)

## Evaluate all LPs

In [17]:
def filter_diff_seq_len(gold, pred):
    """Keep only the sentences whose gold and predicted sequences have the same length.

    Prints how many sentences were dropped. Items beyond the shorter of the two
    lists are ignored (``zip`` semantics).

    Returns:
        ``(new_gold, new_pred)``: the surviving items, in their original order.
    """
    pairs = list(zip(gold, pred))
    kept = [(g, p) for g, p in pairs if len(p) == len(g)]
    print('filtered:', len(pairs) - len(kept))
    return [g for g, p in kept], [p for g, p in kept]

In [18]:
# Sentence- and word-level evaluation of the trained teacher explainer ('mtl') on
# `split`, for each language pair and for all of them concatenated ("all").
# NOTE: this rebinds module-level names (lp, data, valid_*, gold_*, pred_*) that
# later cells read; after the loop they hold the "all" results.
langpairs = ["en-de", "en-zh", "et-en", "ne-en", "ro-en", "ru-en", "all"]
split = "test"
for lp in langpairs:
    if lp == "all":
        data = read_data_all(langpairs[:-1], split)
    else:
        data = read_data(lp, split)
    valid_tokens, valid_masks, valid_explanations, valid_outputs = get_explanations(data, strategy='mtl')
    print('')
    print(lp)
    print('----------')
    src_expls, mt_expls, src_pieces, mt_pieces = get_src_and_mt_explanations(
        valid_tokens, valid_masks, valid_explanations, reduction='sum'
    )
    # gold_src_tokens / gold_mt_tokens are not used inside this loop
    gold_src_tokens = [inp['original'].split() for inp in data]
    gold_mt_tokens = [inp['translation'].split() for inp in data]
    gold_expls_src = [inp['src_tags'] for inp in data]
    gold_expls_mt = [inp['mt_tags'] for inp in data]
    gold_scores = [inp['z_mean'] for inp in data]
    pred_expls_src = src_expls
    pred_expls_mt = mt_expls
    pred_scores = unroll(valid_outputs)
    # Drop sentences whose word count differs between gold tags and explanation.
    # pred_scores is not filtered, so the sentence-level metrics use every sentence.
    gold_expls_src, pred_expls_src = filter_diff_seq_len(gold_expls_src, pred_expls_src)
    gold_expls_mt, pred_expls_mt = filter_diff_seq_len(gold_expls_mt, pred_expls_mt)
    evaluate_sentence_level(gold_scores, pred_scores)
    evaluate_word_level(gold_expls_src, pred_expls_src)
    evaluate_word_level(gold_expls_mt, pred_expls_mt)

63 of 63
ru-en
----------
filtered: 1
filtered: 1
Pearson: 0.6226
Spearman: 0.5725
MAE: 0.5733
RMSE: 0.7866
AUC score: 0.5845
AP score: 0.5396
Recall at top-K: 0.3974
AUC score: 0.5769
AP score: 0.5154
Recall at top-K: 0.3845


In [26]:
# Word-level evaluation of one fixed attention head on the test split, for each
# language pair and for all of them together. layer_id / head_id are 0-based
# indices into the stacked attentions, i.e. "10 6" in the 1-based non-zero head listing.
layer_id = 9
head_id = 5
langpairs = ["en-de", "en-zh", "et-en", "ne-en", "ro-en", "ru-en", "all"]
split = "test"  # train, dev, test
for lp in langpairs:
    if lp == "all":
        data = read_data_all(langpairs[:-1], split)
    else:
        data = read_data(lp, split)
    valid_tokens, valid_masks, valid_explanations, valid_outputs = get_explanations(
        data, strategy='layer_head', layer_id=layer_id, head_id=head_id
    )
    src_expls, mt_expls, src_pieces, mt_pieces = get_src_and_mt_explanations(
        valid_tokens, valid_masks, valid_explanations, reduction='sum'
    )
    print('')
    print('{} | LAYER: {} | HEAD: {}'.format(lp, layer_id, head_id))
    print('---')
    gold_expls_src = [inp['src_tags'] for inp in data]
    gold_expls_mt = [inp['mt_tags'] for inp in data]
    # keep only sentences whose word counts agree between gold tags and explanation
    gold_expls_src, src_expls = filter_diff_seq_len(gold_expls_src, src_expls)
    gold_expls_mt, mt_expls = filter_diff_seq_len(gold_expls_mt, mt_expls)
    _ = evaluate_word_level(gold_expls_src, src_expls)
    _ = evaluate_word_level(gold_expls_mt, mt_expls)

375 of 376
all | LAYER: 9 | HEAD: 5
---
filtered: 1
filtered: 1
AUC score: 0.6674
AP score: 0.5972
Recall at top-K: 0.4827
AUC score: 0.6401
AP score: 0.6007
Recall at top-K: 0.4911


In [27]:
# Same fixed-head word-level evaluation as the previous cell, on the dev split.
# layer_id / head_id are 0-based ("10 6" in the 1-based non-zero head listing).
layer_id = 9
head_id = 5
langpairs = ["en-de", "en-zh", "et-en", "ne-en", "ro-en", "ru-en", "all"]
split = "dev"  # train, dev, test
for lp in langpairs:
    if lp == "all":
        data = read_data_all(langpairs[:-1], split)
    else:
        data = read_data(lp, split)
    valid_tokens, valid_masks, valid_explanations, valid_outputs = get_explanations(
        data, strategy='layer_head', layer_id=layer_id, head_id=head_id
    )
    src_expls, mt_expls, src_pieces, mt_pieces = get_src_and_mt_explanations(
        valid_tokens, valid_masks, valid_explanations, reduction='sum'
    )
    print('')
    print('{} | LAYER: {} | HEAD: {}'.format(lp, layer_id, head_id))
    print('---')
    gold_expls_src = [inp['src_tags'] for inp in data]
    gold_expls_mt = [inp['mt_tags'] for inp in data]
    # keep only sentences whose word counts agree between gold tags and explanation
    gold_expls_src, src_expls = filter_diff_seq_len(gold_expls_src, src_expls)
    gold_expls_mt, mt_expls = filter_diff_seq_len(gold_expls_mt, mt_expls)
    _ = evaluate_word_level(gold_expls_src, src_expls)
    _ = evaluate_word_level(gold_expls_mt, mt_expls)

63 of 63
en-de | LAYER: 9 | HEAD: 5
---
filtered: 0
filtered: 0
AUC score: 0.6711
AP score: 0.4827
Recall at top-K: 0.3530
AUC score: 0.6569
AP score: 0.5044
Recall at top-K: 0.3771
63 of 63
en-zh | LAYER: 9 | HEAD: 5
---
filtered: 0
filtered: 0
AUC score: 0.5652
AP score: 0.4812
Recall at top-K: 0.3638
AUC score: 0.5422
AP score: 0.4911
Recall at top-K: 0.3787
63 of 63
et-en | LAYER: 9 | HEAD: 5
---
filtered: 0
filtered: 0
AUC score: 0.6835
AP score: 0.5592
Recall at top-K: 0.4343
AUC score: 0.6724
AP score: 0.5500
Recall at top-K: 0.4374
63 of 63
ne-en | LAYER: 9 | HEAD: 5
---
filtered: 0
filtered: 0
AUC score: 0.6967
AP score: 0.7661
Recall at top-K: 0.6845
AUC score: 0.6348
AP score: 0.7920
Recall at top-K: 0.7121
63 of 63
ro-en | LAYER: 9 | HEAD: 5
---
filtered: 0
filtered: 0
AUC score: 0.7237
AP score: 0.6055
Recall at top-K: 0.4929
AUC score: 0.7159
AP score: 0.5837
Recall at top-K: 0.4584
63 of 63
ru-en | LAYER: 9 | HEAD: 5
---
filtered: 0
filtered: 1
AUC score: 0.6547
AP score

## Plotting the distribution of predictions and AUC scores

In [None]:
# define options for seaborn: no top/right spines, light dotted grid
custom_params = {
    'axes.spines.right': False,
    'axes.spines.top': False,
    'grid.color': '.85',
    'grid.linestyle': ':'
}
# set_theme returns None, so there is nothing to suppress. The old `_ = ...,` built a
# 1-tuple and shadowed IPython's `_` output variable.
sns.set_theme(style='whitegrid', rc=custom_params)

def plot_da_vs_expl_metric(metric_fn, das, e_golds, e_preds, da_label='da'):
    """Plot a per-sentence explanation metric against the sentence DA score.

    Sentences whose gold tags are all 0 or all 1 are skipped, since ranking
    metrics such as ROC-AUC are undefined with a single class.

    Args:
        metric_fn: callable ``metric_fn(gold_tags, pred_scores) -> float``,
            e.g. ``roc_auc_score`` (a ``functools.partial`` of one also works).
        das: per-sentence DA scores (gold or predicted), aligned with the lists below.
        e_golds: per-sentence binary gold word tags.
        e_preds: per-sentence word-level explanation scores, same lengths as ``e_golds``.
        da_label: axis label used for the DA values (default ``'da'``).

    Draws a 1x3 figure: a 2-D histogram of DA vs. metric, the DA histogram,
    and the metric histogram.
    """
    # str(fn).split()[1] only works for plain functions; use __name__ instead,
    # looking through functools.partial, with repr() as a last resort.
    base_fn = getattr(metric_fn, 'func', metric_fn)
    metric_name = getattr(base_fn, '__name__', repr(metric_fn))

    pairs = [
        (da, metric_fn(gold, pred))
        for da, gold, pred in zip(das, e_golds, e_preds)
        if sum(gold) not in (0, len(gold))  # need both classes present
    ]
    x = np.array([da for da, score in pairs])
    y = np.array([score for da, score in pairs])

    fig, axs = plt.subplots(1, 3, figsize=(16, 4))
    sns.histplot(x=x, y=y, ax=axs[0])
    axs[0].set_xlabel(da_label)
    axs[0].set_ylabel(metric_name)
    sns.histplot(x, bins=20, ax=axs[1])
    axs[1].set_xlabel(da_label)
    sns.histplot(y, bins=20, ax=axs[2])
    axs[2].set_xlabel(metric_name)

In [None]:
# plot predicted DA vs AUC for src and mt (dev set of `lp`, trained explainer).
# Everything is recomputed from `valid_data`. Under Run-All, the evaluation loops above
# leave pred_scores, gold_expls_* and pred_expls_* from different cells/splits, and they
# filter the tag lists but not pred_scores, so those globals are no longer aligned.
plot_tokens, plot_masks, plot_expls, plot_outputs = get_explanations(valid_data, strategy='mtl')
pred_expls_src, pred_expls_mt, _, _ = get_src_and_mt_explanations(
    plot_tokens, plot_masks, plot_expls, reduction='sum'
)
pred_scores = unroll(plot_outputs)
gold_expls_src = [inp['src_tags'] for inp in valid_data]
gold_expls_mt = [inp['mt_tags'] for inp in valid_data]

# drop length mismatches jointly so DA, gold tags and predictions stay aligned
src_rows = [
    (da, gold, pred)
    for da, gold, pred in zip(pred_scores, gold_expls_src, pred_expls_src)
    if len(gold) == len(pred)
]
plot_da_vs_expl_metric(roc_auc_score, *(list(col) for col in zip(*src_rows)))

In [None]:
# NOTE(review): relies on pred_scores, gold_expls_mt and pred_expls_mt describing the same
# sentences in the same order. The evaluation loops above overwrite them (and filter only
# the tag lists), so make sure they were last derived together from one dataset.
plot_da_vs_expl_metric(roc_auc_score, pred_scores, gold_expls_mt, pred_expls_mt)

## Check results for all layers (very slow: the teacher is re-run on the dev set for every layer)

In [None]:
# Word-level scores of the per-layer head average on the dev set of `lp`.
# Gold tags are rebuilt from `valid_data`. Under Run-All, the gold_expls_* globals were
# last set by the evaluation loops above for another split/language pair, so zipping
# them with these dev-set explanations would pair up the wrong sentences.
layer_gold_src = [inp['src_tags'] for inp in valid_data]
layer_gold_mt = [inp['mt_tags'] for inp in valid_data]
for layer_id in range(12):
    valid_tokens, valid_masks, valid_explanations, valid_outputs = get_explanations(
        valid_data, strategy='layer_average', layer_id=layer_id
    )
    src_expls, mt_expls, src_pieces, mt_pieces = get_src_and_mt_explanations(
        valid_tokens, valid_masks, valid_explanations, reduction='sum'
    )
    print('LAYER: {}'.format(layer_id))
    # drop sentences whose word count differs, as in the other evaluation loops
    _ = evaluate_word_level(*filter_diff_seq_len(layer_gold_src, src_expls))
    _ = evaluate_word_level(*filter_diff_seq_len(layer_gold_mt, mt_expls))
    print('---')

## Check results for every head in every layer (even slower: 144 full passes over the dev set)

In [None]:
# Word-level scores of every single head on the dev set of `lp` (12 x 12 passes).
# Gold tags are rebuilt from `valid_data`. Under Run-All, the gold_expls_* globals were
# last set by the evaluation loops above for another split/language pair, so zipping
# them with these dev-set explanations would pair up the wrong sentences.
head_gold_src = [inp['src_tags'] for inp in valid_data]
head_gold_mt = [inp['mt_tags'] for inp in valid_data]
for layer_id in range(12):
    for head_id in range(12):
        valid_tokens, valid_masks, valid_explanations, valid_outputs = get_explanations(
            valid_data, strategy='layer_head', layer_id=layer_id, head_id=head_id
        )
        src_expls, mt_expls, src_pieces, mt_pieces = get_src_and_mt_explanations(
            valid_tokens, valid_masks, valid_explanations, reduction='sum'
        )
        print('LAYER: {} | HEAD: {}'.format(layer_id, head_id))
        # drop sentences whose word count differs, as in the other evaluation loops
        _ = evaluate_word_level(*filter_diff_seq_len(head_gold_src, src_expls))
        _ = evaluate_word_level(*filter_diff_seq_len(head_gold_mt, mt_expls))
        print('---')