Set the CUDA device ID

In [None]:
%env CUDA_DEVICE_ORDER=PCI_BUS_ID
%env CUDA_VISIBLE_DEVICES=4

Install dependencies for computing metrics and plots:

In [None]:
#!pip3 install numpy scipy pandas seaborn matplotlib scikit-learn

## Basic imports

In [None]:
import jax
import jax.numpy as jnp
import flax
from entmax_jax import sparsemax
import numpy as np
from matplotlib import pyplot as plt
import seaborn as sns
import pandas as pd
from IPython.display import display, HTML
from functools import partial
import json
from entmax_jax.activations import sparsemax, entmax15
from sklearn.metrics import roc_auc_score, average_precision_score

from meta_expl.explainers import load_explainer
from meta_expl.models import load_model
from meta_expl.data.mlqe import dataloader

from evaluate_explanations import evaluate_word_level, evaluate_sentence_level, aggregate_pieces

In [None]:
# data utils
def unroll(list_of_lists):
    """Flatten one level of nesting: concatenate the inner iterables into one list."""
    flat = []
    for inner in list_of_lists:
        flat.extend(inner)
    return flat

def read_data(lp, split='dev'):
    """Read one split of one language pair from ``data/mlqepe/<lp>/``.

    Five line-aligned files are read: ``<split>.src``, ``<split>.mt``,
    ``<split>.da``, ``<split>.src-tags`` and ``<split>.tgt-tags``.

    Args:
        lp: language-pair directory name, e.g. ``'ro-en'``.
        split: file stem of the split (``'train'``, ``'dev'``, ``'test'``).

    Returns:
        A list with one dict per line, with keys ``original``, ``translation``,
        ``z_mean``, ``src_tags``, ``mt_tags`` and ``da`` (``da`` holds the same
        value as ``z_mean``). If the files differ in length, the result is
        truncated to the shortest one.

    Raises:
        OSError: if one of the files cannot be opened.
        ValueError: if a score is not a float or a tag is not ``OK``/``BAD``.
    """
    def tags_to_ints(line):
        # 'OK' -> 0, 'BAD' -> 1, one integer per whitespace-separated tag
        return list(map(int, line.strip().replace('OK', '0').replace('BAD', '1').split()))

    def read_lines(ext):
        # `with` closes the handle deterministically; UTF-8 is requested explicitly
        # so that non-Latin text does not depend on the platform's default encoding
        with open('data/mlqepe/{}/{}.{}'.format(lp, split, ext), 'r', encoding='utf-8') as f:
            return [line.strip() for line in f]

    data = {
        'original': read_lines('src'),
        'translation': read_lines('mt'),
        'z_mean': [float(line) for line in read_lines('da')],
        'src_tags': [tags_to_ints(line) for line in read_lines('src-tags')],
        'mt_tags': [tags_to_ints(line) for line in read_lines('tgt-tags')],
    }
    data['da'] = data['z_mean']
    # turn the dict of columns into a list of per-sentence records
    return [dict(zip(data.keys(), v)) for v in zip(*data.values())]


def read_data_all(lps, split='dev'):
    """Read the given split for several language pairs and concatenate the records.

    The records are concatenated directly instead of being rebuilt column by
    column: ``z_mean`` and ``da`` used to share a single list that was extended
    once per key, which shifted the scores of every language pair after the
    first one out of alignment with their sentences.

    Args:
        lps: iterable of language-pair names, read in the given order.
        split: file stem of the split, forwarded to ``read_data``.

    Returns:
        A list of per-sentence dicts in the same format as ``read_data``;
        an empty list when ``lps`` is empty.
    """
    data = []
    for lp in lps:
        data.extend(read_data(lp, split))
    return data

## Define args and load stuff

In [None]:
# arguments: notebook-wide settings read as globals by the cells below
arch = 'xlm-roberta-base'  # checkpoint name handed to the tokenizer's from_pretrained
arch_mtl = 'xlm-r'
setup = 'no_teacher'  # "no_teacher", "static_teacher", "learnable_teacher"
# langpairs = ["en-de", "en-zh", "et-en", "ne-en", "ro-en", "ru-en"]
lp = 'ro-en'  # language pair loaded by the single-pair cells
max_len = 256
batch_size = 16
seed = 1
# separator string depends on the architecture family
sep_token = "</s>" if 'xlm' in arch else "[SEP]"
# NOTE: rebinds the imported `dataloader` name; re-running this cell wraps the partial again
dataloader = partial(dataloader, sep_token=sep_token)
num_classes = 1
task_type = "regression"  # get_expls applies argmax only when this is "classification"
# directories the trained teacher model and its explainer are loaded from
teacher_dir = 'data/mlqe-xlmr-models/teacher_dir'
explainer_dir = 'data/mlqe-xlmr-models/teacher_expl_dir'

In [None]:
# create dummy inputs for model instantiation
# (all-ones placeholders of shape (batch_size, max_len); only used to build the model/explainers)
input_ids = jnp.ones((batch_size, max_len), jnp.int32)
dummy_inputs = {
    "input_ids": input_ids,
    "attention_mask": jnp.ones_like(input_ids),
    # NOTE(review): this is a 1-D arange of length max_len while the other entries are
    # (batch_size, max_len) -- confirm this is what load_model expects for token_type_ids
    "token_type_ids": jnp.arange(jnp.atleast_2d(input_ids).shape[-1]),
    "position_ids": jnp.ones_like(input_ids),
}
# displayed as the cell output: sanity check of the placeholder shape
dummy_inputs['input_ids'].shape

### load data

In [None]:
# load data: train/dev/test splits of the language pair selected in `lp`
train_data = read_data(lp, "train")
valid_data = read_data(lp, "dev")
test_data = read_data(lp, "test")

### load tokenizer

In [None]:
from transformers import XLMRobertaTokenizerFast
# tokenizer matching the `arch` checkpoint; it is read as a global by the helper functions below
tokenizer = XLMRobertaTokenizerFast.from_pretrained(arch)
# ids of the special tokens, kept for convenience
cls_id = tokenizer.cls_token_id
sep_id = tokenizer.sep_token_id
pad_id = tokenizer.pad_token_id

### load model and explainer

In [None]:
# load the trained teacher from `teacher_dir`; the third return value (`dummy_state`) is
# reused below as the `state` argument when loading/creating explainers
teacher, teacher_params, dummy_state = load_model(teacher_dir, dummy_inputs, batch_size, max_len)
# load the trained (meta-learned, per the section headings) explainer from `explainer_dir`
teacher_explainer, teacher_explainer_params = load_explainer(explainer_dir, dummy_inputs, state=dummy_state)

In [None]:
from meta_expl.utils import PRNGSequence
from meta_expl.explainers import create_explainer
# key sequence shared by every create_explainer call in this notebook
keyseq = PRNGSequence(11)
# freshly initialised attention explainer with no layer/head restriction
# (evaluated below under "all attention layers and heads")
teacher_explainer_params_non_trained={
    'normalize_head_coeffs': 'sparsemax',
    'normalizer_fn': 'softmax',
    'aggregator_idx': 'mean',
    'aggregator_dim': 'row',
    'init_fn': 'uniform',
    'layer_idx': None,
    'head_idx': None
}
# NOTE: the args dict name is rebound here to the explainer's parameters
teacher_explainer_non_trained, teacher_explainer_params_non_trained = create_explainer(
    key=next(keyseq),
    inputs=dummy_inputs,
    state=dummy_state,
    explainer_type='attention_explainer',
    explainer_args=teacher_explainer_params_non_trained,
)

In [None]:
# freshly initialised attention explainer restricted to a single layer and head
# (evaluated below under "best attention head")
best_head_teacher_explainer_params={
    'normalize_head_coeffs': 'sparsemax',
    'normalizer_fn': 'softmax',
    'aggregator_idx': 'mean',
    'aggregator_dim': 'row',
    'init_fn': 'uniform',
    'layer_idx': 9,  #9, None
    'head_idx': 5,  #5, None
}
# NOTE: the args dict name is rebound here to the explainer's parameters
best_head_teacher_explainer, best_head_teacher_explainer_params = create_explainer(
    key=next(keyseq), 
    inputs=dummy_inputs, 
    state=dummy_state, 
    explainer_type='attention_explainer', 
    explainer_args=best_head_teacher_explainer_params
)

In [None]:
# freshly initialised attention explainer restricted to one layer, with no head restriction
# (evaluated below under "best attention layer")
best_layer_teacher_explainer_params={
    'normalize_head_coeffs': 'sparsemax',
    'normalizer_fn': 'softmax',
    'aggregator_idx': 'mean',
    'aggregator_dim': 'row',
    'init_fn': 'uniform',
    'layer_idx': 10,  #9, None
    'head_idx': None,  #5, None
}
# NOTE: the args dict name is rebound here to the explainer's parameters
best_layer_teacher_explainer, best_layer_teacher_explainer_params = create_explainer(
    key=next(keyseq), 
    inputs=dummy_inputs, 
    state=dummy_state, 
    explainer_type='attention_explainer', 
    explainer_args=best_layer_teacher_explainer_params
)

In [None]:
# gradient-x-input explainer; it is given the teacher's `embeddings_grad_fn` output
# (computed on the dummy inputs) under the "grad_fn" key of model_extras
input_gradient_teacher_explainer, input_gradient_teacher_explainer_params = create_explainer(
    key=next(keyseq), 
    inputs=dummy_inputs, 
    state=dummy_state, 
    explainer_type='gradient_input_explainer', 
    model_extras={
        "grad_fn": teacher.apply(
            teacher_params, dummy_inputs, method=teacher.embeddings_grad_fn
        )
    }
)

In [None]:
# integrated-gradients explainer; built the same way as the gradient-x-input one,
# with the teacher's `embeddings_grad_fn` output passed under "grad_fn"
int_gradient_teacher_explainer, int_gradient_teacher_explainer_params = create_explainer(
    key=next(keyseq), 
    inputs=dummy_inputs, 
    state=dummy_state, 
    explainer_type='integrated_gradients_explainer', 
    model_extras={
        "grad_fn": teacher.apply(
            teacher_params, dummy_inputs, method=teacher.embeddings_grad_fn
        )
    }
)

### look at the coefficients

In [None]:
# sparsemax-normalised head coefficients viewed as a 12x12 grid
# (presumably layers x heads -- the cell below averages the last axis to get per-layer values)
sparsemax(teacher_explainer_params['params']['head_coeffs']).reshape(12, 12)

In [None]:
# list the grid positions that received a non-zero coefficient, printed 1-based
hc = sparsemax(teacher_explainer_params['params']['head_coeffs']).reshape(12, 12)
for a, b in zip(*hc.nonzero()):
    print(a+1, b+1)

In [None]:
# check the layers with the highest coefficients
# (mean over the last axis of the 12x12 grid; pairs are (1-based layer, mean coefficient))
layer_coeffs = sparsemax(teacher_explainer_params['params']['head_coeffs']).reshape(12, 12).mean(-1).tolist()
# sorted in ascending order, so the highest-weighted layers come last
sorted(list(zip(list(range(1, len(layer_coeffs)+1)), layer_coeffs)), key=lambda k: k[1])

## Get explanations

In [None]:
def get_expls(data, t, t_p, t_e, t_e_p, s=None, s_p=None, s_e=None, s_e_p=None, is_grad_based=False):
    """Run the teacher and its explainer over `data` and collect per-sentence results.

    Reads the notebook globals `dataloader`, `tokenizer`, `batch_size`, `max_len`
    and `task_type`.

    Args:
        data: list of examples accepted by `dataloader`.
        t, t_p: teacher model and its parameters.
        t_e, t_e_p: teacher explainer and its parameters.
        s, s_p, s_e, s_e_p: optional student model/explainer and their parameters.
            When `s` is given the student is run as well, but its outputs are
            not part of the returned values.
        is_grad_based: if True, pass the teacher's `embeddings_grad_fn` output to
            the explainer as `grad_fn` (needed by the gradient-based explainers).

    Returns:
        Tuple `(all_tokens, all_masks, all_explanations, all_outputs)`; the first
        three hold one list per sentence, cut to the sentence's number of
        attended positions, and `all_outputs` holds the teacher predictions.
        A mask entry is True where the word piece starts with '▁'.
    """
    all_tokens = []
    all_masks = []
    all_explanations = []
    all_outputs = []
    for i, (x, _) in enumerate(dataloader(data, tokenizer, batch_size=batch_size, max_len=max_len, shuffle=False)):
        print('{} of {}'.format(i+1, len(data) // batch_size + 1), end='\r')

        y_teacher, teacher_attn = t.apply(t_p, **x, deterministic=True)
        y_teacher = jnp.argmax(y_teacher, axis=-1) if task_type == "classification" else y_teacher

        if is_grad_based:
            teacher_extras = {
                "grad_fn": t.apply(t_p, x, method=t.embeddings_grad_fn)
            }
            teacher_expl, _ = t_e.apply(t_e_p, x, teacher_attn, **teacher_extras)
        else:
            teacher_expl, _ = t_e.apply(t_e_p, x, teacher_attn)

        if s is not None:
            # the student results are currently unused by the returned values
            y_student, student_attn = s.apply(s_p, **x)
            y_student = jnp.argmax(y_student, axis=-1) if task_type == "classification" else y_student
            # use the explainer passed in as `s_e`; the previous code referenced an
            # undefined global `student_explainer` and raised NameError
            student_expl, _ = s_e.apply(s_e_p, x, student_attn)

        # convert everything to lists
        batch_tokens = [tokenizer.convert_ids_to_tokens(ids) for ids in x['input_ids'].tolist()]
        batch_masks = [[tk.startswith('▁') for tk in tokens] for tokens in batch_tokens]
        batch_expls = teacher_expl.tolist()

        # filter out pad: keep the first n positions of each sentence
        # (assumes padding is on the right -- TODO confirm against the dataloader)
        batch_valid_len = x['attention_mask'].sum(-1).tolist()
        for tokens, mask, expl, n in zip(batch_tokens, batch_masks, batch_expls, batch_valid_len):
            all_tokens.append(tokens[:n])
            all_masks.append(mask[:n])
            all_explanations.append(expl[:n])
        all_outputs.extend(y_teacher.tolist())

    return all_tokens, all_masks, all_explanations, all_outputs

In [None]:
# explanations of the loaded teacher explainer on the dev split of `lp`
valid_tokens, valid_masks, valid_explanations, valid_outputs = get_expls(
    valid_data, teacher, teacher_params, teacher_explainer, teacher_explainer_params 
)
# sanity check: the four lists should have the same length
list(map(len, [valid_tokens, valid_masks, valid_explanations, valid_outputs]))

### Aggregate scores for word pieces in SRC and MT independently

In [None]:
import torch
from utils import aggregate_pieces

def get_src_and_mt_explanations(all_tokens, all_fp_masks, all_explanations, reduction):
    """Split each sentence pair at the first separator and aggregate piece scores.

    Each input sequence is assumed to look like "<s> src </s> mt </s>" (no
    leading special token for the mt part). Everything up to and including the
    first `tokenizer.sep_token` is the source side, the rest is the mt side.

    Args:
        all_tokens: word pieces per sentence.
        all_fp_masks: per-piece boolean masks, split and forwarded to
            `aggregate_pieces` together with the scores.
        all_explanations: per-piece scores per sentence.
        reduction: reduction name forwarded to `aggregate_pieces`.

    Returns:
        Tuple `(src_expls, mt_expls, src_pieces, mt_pieces)`. The aggregated
        source scores drop their first and last entries (<s> and </s>), the
        aggregated mt scores drop their last entry (</s>); the piece lists
        keep the special tokens.
    """
    def _aggregate(scores, mask):
        # aggregate word-piece scores with the torch helper and return a plain list
        return aggregate_pieces(torch.tensor(scores), torch.tensor(mask), reduction).tolist()

    src_expls, mt_expls = [], []
    src_pieces, mt_pieces = [], []
    for pieces, scores, fp_mask in zip(all_tokens, all_explanations, all_fp_masks):
        # the source side ends right after the first separator token
        cut = pieces.index(tokenizer.sep_token) + 1

        src_scores = _aggregate(scores[:cut], fp_mask[:cut])
        mt_scores = _aggregate(scores[cut:], fp_mask[cut:])

        src_pieces.append(pieces[:cut])
        mt_pieces.append(pieces[cut:])
        # strip <s> and </s> from src, and the trailing </s> from mt
        src_expls.append(src_scores[1:-1])
        mt_expls.append(mt_scores[:-1])
    return src_expls, mt_expls, src_pieces, mt_pieces

In [None]:
# how the scores of the word pieces of one word are combined
reduction = 'sum'  # first, sum, mean, max
src_expls, mt_expls, src_pieces, mt_pieces = get_src_and_mt_explanations(
    valid_tokens, valid_masks, valid_explanations, reduction=reduction
)

## Evaluating explanations by comparing explanations with word-level QE tags

In [None]:
# gold side: whitespace-tokenised text, word-level tags and sentence scores of the dev split
gold_src_tokens = [inp['original'].split() for inp in valid_data]
gold_mt_tokens = [inp['translation'].split() for inp in valid_data]
gold_expls_src = [inp['src_tags'] for inp in valid_data]
gold_expls_mt = [inp['mt_tags'] for inp in valid_data]
gold_scores = [inp['z_mean'] for inp in valid_data]

# predicted side: aggregated explanation scores and flattened teacher outputs
pred_expls_src = src_expls
pred_expls_mt = mt_expls
pred_scores = unroll(valid_outputs)

In [None]:
# sentence level: gold scores vs. teacher predictions
_ = evaluate_sentence_level(gold_scores, pred_scores)

In [None]:
# word level, source side: gold tags vs. explanation scores
_ = evaluate_word_level(gold_expls_src, pred_expls_src)

In [None]:
# word level, translation side: gold tags vs. explanation scores
_ = evaluate_word_level(gold_expls_mt, pred_expls_mt)

## Evaluate all LPs

In [None]:
def filter_diff_seq_len(gold, pred):
    """Keep only the (gold, pred) pairs whose two sequences have the same length.

    Prints the number of discarded pairs and returns the surviving gold and
    predicted sequences as two parallel lists, in their original order.
    """
    pairs = list(zip(pred, gold))
    kept = [(g, p) for p, g in pairs if len(p) == len(g)]
    print('filtered:', len(pairs) - len(kept))
    return [g for g, _ in kept], [p for _, p in kept]

In [None]:
def eval_plausibility_all_lps(t, t_p, t_e, t_e_p, split='dev', is_grad_based=False, langpairs=None):
    """Evaluate an explainer on every language pair and on their concatenation.

    For each language pair, and finally for the pseudo-pair "all", this runs
    `get_expls`, aggregates the word-piece scores with reduction 'sum', drops
    sentences whose number of predicted scores differs from the number of gold
    tags, and prints the sentence-level and word-level (src, then mt) metrics.

    The "all" run reuses the records already read for the individual pairs, so
    every sentence keeps its own gold score and no file is read twice.

    Args:
        t, t_p: teacher model and its parameters.
        t_e, t_e_p: explainer and its parameters.
        split: data split to read for each language pair.
        is_grad_based: forwarded to `get_expls`.
        langpairs: language pairs to evaluate; defaults to the six pairs
            en-de, en-zh, et-en, ne-en, ro-en and ru-en.
    """
    if langpairs is None:
        langpairs = ["en-de", "en-zh", "et-en", "ne-en", "ro-en", "ru-en"]

    all_data = []
    for name in list(langpairs) + ["all"]:
        if name == "all":
            data = all_data
        else:
            data = read_data(name, split)
            # keep shallow copies so the "all" run sees the records as they were read
            all_data.extend(dict(d) for d in data)

        tokens, masks, explanations, outputs = get_expls(
            data, t, t_p, t_e, t_e_p, is_grad_based=is_grad_based
        )
        print('')
        print(name)
        print('----------')
        src_expls, mt_expls, _, _ = get_src_and_mt_explanations(
            tokens, masks, explanations, reduction='sum'
        )
        gold_scores = [inp['z_mean'] for inp in data]
        pred_scores = unroll(outputs)
        # word-level metrics need one score per gold tag, so mismatching sentences are dropped
        gold_expls_src, pred_expls_src = filter_diff_seq_len([inp['src_tags'] for inp in data], src_expls)
        gold_expls_mt, pred_expls_mt = filter_diff_seq_len([inp['mt_tags'] for inp in data], mt_expls)
        evaluate_sentence_level(gold_scores, pred_scores)
        evaluate_word_level(gold_expls_src, pred_expls_src)
        evaluate_word_level(gold_expls_mt, pred_expls_mt)

### meta-learned explainer

In [None]:
# explainer loaded from `explainer_dir`, evaluated on the dev split of every language pair
eval_plausibility_all_lps(
    teacher, 
    teacher_params, 
    teacher_explainer, 
    teacher_explainer_params, 
    split='dev'
)

### all attention layers and heads

In [None]:
# freshly initialised attention explainer without layer/head restriction
eval_plausibility_all_lps(
    teacher, 
    teacher_params, 
    teacher_explainer_non_trained, 
    teacher_explainer_params_non_trained,
    split='dev'
)

### gradient x input

In [None]:
# gradient-x-input explainer; is_grad_based=True makes get_expls pass `grad_fn` to the explainer
eval_plausibility_all_lps(
    teacher, 
    teacher_params, 
    input_gradient_teacher_explainer, 
    input_gradient_teacher_explainer_params,
    split='dev',
    is_grad_based=True
)

### integrated gradients

In [None]:
# integrated-gradients explainer; is_grad_based=True makes get_expls pass `grad_fn` to the explainer
eval_plausibility_all_lps(
    teacher, 
    teacher_params, 
    int_gradient_teacher_explainer, 
    int_gradient_teacher_explainer_params,
    split='dev',
    is_grad_based=True
)

### best attention layer

In [None]:
# attention explainer restricted to a single layer (layer_idx=10, no head restriction)
eval_plausibility_all_lps(
    teacher, 
    teacher_params, 
    best_layer_teacher_explainer, 
    best_layer_teacher_explainer_params,
    split='dev'
)

### best attention head

In [None]:
# attention explainer restricted to a single head (layer_idx=9, head_idx=5)
eval_plausibility_all_lps(
    teacher, 
    teacher_params, 
    best_head_teacher_explainer, 
    best_head_teacher_explainer_params,
    split='dev'
)

### last layer attention

In [None]:
# softmax-normalised scalar-mix coefficients of the teacher
flax.linen.softmax(teacher_params['params']['scalarmix']['coeffs'])  # first item is the embedding layer

In [None]:
# position of the largest scalar-mix coefficient (position 0 is the embedding layer, per the cell above)
flax.linen.softmax(teacher_params['params']['scalarmix']['coeffs']).argmax()

In [None]:
# last-layer attention explainer.
# NOTE: this reuses (and overwrites) the `best_layer_teacher_explainer*` names defined
# earlier for layer_idx=10; re-run that cell to get the earlier explainer back.
best_layer_teacher_explainer_params={
    'normalize_head_coeffs': 'sparsemax',
    'normalizer_fn': 'softmax',
    'aggregator_idx': 'mean',
    'aggregator_dim': 'row',
    'init_fn': 'uniform',
    'layer_idx': 11,  # presumably the 0-based index of the last of the 12 layers -- TODO confirm
    'head_idx': None,  # no head restriction
}
best_layer_teacher_explainer, best_layer_teacher_explainer_params = create_explainer(
    key=next(keyseq), 
    inputs=dummy_inputs, 
    state=dummy_state, 
    explainer_type='attention_explainer', 
    explainer_args=best_layer_teacher_explainer_params
)
# evaluate it on the dev split of every language pair
eval_plausibility_all_lps(
    teacher, 
    teacher_params, 
    best_layer_teacher_explainer, 
    best_layer_teacher_explainer_params,
    split='dev'
)

## Plotting the distribution of predictions and AUC scores

In [None]:
# define options for seaborn
# (hide the top/right spines and use a light dotted grid)
custom_params = {
    'axes.spines.right': False,
    'axes.spines.top': False,
    'grid.color': '.85',
    'grid.linestyle': ':'
}
# NOTE: the trailing comma makes `_` a 1-tuple; harmless, the theme is still applied
_ = sns.set_theme(style='whitegrid', rc=custom_params),

def plot_da_vs_expl_metric(metric_fn, das, e_golds, e_preds):
    """Plot a per-sentence explanation metric against the sentence score.

    Draws three panels: a 2-D histogram of score vs. metric, a histogram of the
    scores, and a histogram of the metric values.

    Args:
        metric_fn: callable `metric_fn(gold, pred)` returning a number, e.g.
            `roc_auc_score`. Its `__name__` is used as the axis label.
        das: one sentence-level score per sentence.
        e_golds: gold word-level tags (0/1) per sentence.
        e_preds: predicted word-level scores per sentence.
    """
    # take the label from the callable itself (unwrapping functools.partial) instead of
    # parsing its repr, which yields 'function' for builtins
    target = getattr(metric_fn, 'func', metric_fn)
    metric_name = getattr(target, '__name__', type(target).__name__)

    x = []
    y = []
    for da, gold, pred in zip(das, e_golds, e_preds):
        # skip sentences whose gold tags are all 0 or all 1 (this also skips empty ones)
        positives = sum(gold)
        if positives == 0 or positives == len(gold):
            continue
        y.append(metric_fn(gold, pred))
        x.append(da)
    x = np.array(x)
    y = np.array(y)
    fig, axs = plt.subplots(1, 3, figsize=(16, 4))
    sns.histplot(x=x, y=y, ax=axs[0])
    axs[0].set_xlabel('da')
    axs[0].set_ylabel(metric_name)
    sns.histplot(x, bins=20, ax=axs[1])
    axs[1].set_xlabel('da')
    sns.histplot(y, bins=20, ax=axs[2])
    axs[2].set_xlabel(metric_name)

In [None]:
# plot predicted DA vs AUC for src and mt
# (source side; uses the dev-split gold tags and explanations computed above)
plot_da_vs_expl_metric(roc_auc_score, pred_scores, gold_expls_src, pred_expls_src)

In [None]:
# same plot for the translation (mt) side
plot_da_vs_expl_metric(roc_auc_score, pred_scores, gold_expls_mt, pred_expls_mt)

## Check results for all layers (slooow -> very inefficient)

In [None]:
# Evaluate a freshly initialised attention explainer restricted to each layer in turn.
# `get_explanations` is not defined in this notebook, so each layer's explainer is built
# with `create_explainer` (same arguments as the single-layer cell above) and run
# through `get_expls`.
for layer_id in range(12):
    layer_explainer, layer_explainer_params = create_explainer(
        key=next(keyseq),
        inputs=dummy_inputs,
        state=dummy_state,
        explainer_type='attention_explainer',
        explainer_args={
            'normalize_head_coeffs': 'sparsemax',
            'normalizer_fn': 'softmax',
            'aggregator_idx': 'mean',
            'aggregator_dim': 'row',
            'init_fn': 'uniform',
            'layer_idx': layer_id,
            'head_idx': None,
        },
    )
    valid_tokens, valid_masks, valid_explanations, valid_outputs = get_expls(
        valid_data, teacher, teacher_params, layer_explainer, layer_explainer_params
    )
    src_expls, mt_expls, src_pieces, mt_pieces = get_src_and_mt_explanations(
        valid_tokens, valid_masks, valid_explanations, reduction='sum'
    )
    print('LAYER: {}'.format(layer_id))
    _ = evaluate_word_level(gold_expls_src, src_expls)
    _ = evaluate_word_level(gold_expls_mt, mt_expls)
    print('---')

## Check results for all heads in all layers ((very slow)^2)

In [None]:
# Evaluate a freshly initialised attention explainer restricted to each (layer, head) pair.
# `get_explanations` is not defined in this notebook, so each explainer is built with
# `create_explainer` (same arguments as the single-head cell above) and run through
# `get_expls`.
for layer_id in range(12):
    for head_id in range(12):
        head_explainer, head_explainer_params = create_explainer(
            key=next(keyseq),
            inputs=dummy_inputs,
            state=dummy_state,
            explainer_type='attention_explainer',
            explainer_args={
                'normalize_head_coeffs': 'sparsemax',
                'normalizer_fn': 'softmax',
                'aggregator_idx': 'mean',
                'aggregator_dim': 'row',
                'init_fn': 'uniform',
                'layer_idx': layer_id,
                'head_idx': head_id,
            },
        )
        valid_tokens, valid_masks, valid_explanations, valid_outputs = get_expls(
            valid_data, teacher, teacher_params, head_explainer, head_explainer_params
        )
        src_expls, mt_expls, src_pieces, mt_pieces = get_src_and_mt_explanations(
            valid_tokens, valid_masks, valid_explanations, reduction='sum'
        )
        print('LAYER: {} | HEAD: {}'.format(layer_id, head_id))
        _ = evaluate_word_level(gold_expls_src, src_expls)
        _ = evaluate_word_level(gold_expls_mt, mt_expls)
        print('---')
        print('---')