Set the CUDA device ID

In [None]:
# Number GPUs by PCI bus ID so the device index below refers to a fixed physical card.
%env CUDA_DEVICE_ORDER=PCI_BUS_ID
# Expose only GPU 5 to this kernel. NOTE(review): presumably this must run before
# jax is imported (next cell) for the restriction to take effect — confirm.
%env CUDA_VISIBLE_DEVICES=5

In [None]:
import jax
import jax.numpy as jnp
import numpy as np
from functools import partial
from entmax_jax.activations import sparsemax, entmax15

from meta_expl.explainers import load_explainer
from meta_expl.models import load_model
from meta_expl.data.mlqe import dataloader
from meta_expl.utils import PRNGSequence, mse_loss, pearson

In [None]:
# data utils
def unroll(list_of_lists):
    """Flatten exactly one level of nesting, e.g. [[1, 2], [3]] -> [1, 2, 3]."""
    flat = []
    for inner in list_of_lists:
        flat.extend(inner)
    return flat

def read_data(lp, split='dev', data_dir='data/mlqepe'):
    """Load one language pair's split as a list of per-sentence records.

    Reads five line-aligned files from ``{data_dir}/{lp}/``: ``{split}.src``,
    ``{split}.mt``, ``{split}.da``, ``{split}.src-tags`` and ``{split}.tgt-tags``.
    Line *i* of every file describes sentence *i*.

    Args:
        lp: language-pair directory name, e.g. ``'en-de'``.
        split: file-name prefix, e.g. ``'dev'`` or ``'test'``.
        data_dir: root directory holding one sub-directory per language pair.

    Returns:
        A list of dicts with keys ``'original'`` and ``'translation'``
        (stripped strings), ``'z_mean'`` (float), ``'src_tags'`` and
        ``'mt_tags'`` (lists of ints, ``OK`` -> 0, ``BAD`` -> 1). If the files
        have different line counts, the result is truncated to the shortest.
    """
    def path(ext):
        return '{}/{}/{}.{}'.format(data_dir, lp, split, ext)

    def read_lines(ext, parse):
        # ``with`` closes each handle deterministically; explicit UTF-8 avoids
        # depending on the platform locale for non-Latin scripts.
        with open(path(ext), 'r', encoding='utf-8') as f:
            return [parse(line) for line in f]

    def tags_to_ints(line):
        return list(map(int, line.strip().replace('OK', '0').replace('BAD', '1').split()))

    columns = {
        'original': read_lines('src', str.strip),
        'translation': read_lines('mt', str.strip),
        'z_mean': read_lines('da', lambda line: float(line.strip())),
        'src_tags': read_lines('src-tags', tags_to_ints),
        'mt_tags': read_lines('tgt-tags', tags_to_ints),
    }
    # Transpose column lists into one dict per sentence.
    return [dict(zip(columns.keys(), row)) for row in zip(*columns.values())]

def read_data_all(lps, split='dev'):
    """Concatenate the records of several language pairs for one split.

    Args:
        lps: iterable of language-pair names accepted by ``read_data``.
        split: split name forwarded to ``read_data`` (e.g. ``'dev'``, ``'test'``).

    Returns:
        One list of record dicts: language pairs in the order given by ``lps``,
        sentences in file order within each pair.
    """
    # Every record from read_data already carries all five keys, so rows can be
    # concatenated directly instead of being split into columns and re-zipped.
    return [row for lp in lps for row in read_data(lp, split)]

## Define arguments and load the tokenizer, models, and explainers

In [None]:
# Experiment arguments (read as module-level globals by later cells).
arch = 'xlm-roberta-base'
# Experiment setup; one of "no_teacher", "static_teacher", "learnable_teacher".
setup = 'static_teacher'

seed, max_len, batch_size = 9, 256, 16

# The separator token depends on the tokenizer family:
# XLM-R uses "</s>", BERT-style tokenizers use "[SEP]".
if 'xlm' in arch:
    sep_token = "</s>"
else:
    sep_token = "[SEP]"

# Single-output regression objective.
num_classes, task_type, criterion = 1, "regression", mse_loss

# Bind the separator once so later calls only pass data/tokenizer/batching args.
dataloader = partial(dataloader, sep_token=sep_token)

# Directories handed to load_model / load_explainer when loading.
teacher_dir = 'data/mlqe-xlmr-models/teacher_dir'
teacher_expl_dir = 'data/mlqe-xlmr-models/teacher_expl_dir'
student_dir = 'data/mlqe-xlmr-models/student_dir'
student_expl_dir = 'data/mlqe-xlmr-models/student_expl_dir'

In [None]:
# Create dummy inputs for model/explainer instantiation. The values are
# placeholders; presumably only shapes/dtypes matter — TODO confirm in load_explainer.
input_ids = jnp.ones((batch_size, max_len), jnp.int32)
dummy_inputs = {
    "input_ids": input_ids,
    "attention_mask": jnp.ones_like(input_ids),
    # Segment ids: a single segment (0) for every token, shaped like input_ids.
    "token_type_ids": jnp.zeros_like(input_ids),
    # Position ids: 0..max_len-1, repeated for every row of the batch.
    "position_ids": jnp.broadcast_to(
        jnp.arange(jnp.atleast_2d(input_ids).shape[-1]), input_ids.shape
    ),
}
dummy_inputs['input_ids'].shape

### Load the tokenizer

In [None]:
# Load the tokenizer for ``arch``. AutoTokenizer resolves the concrete (fast)
# tokenizer class from the checkpoint, so it stays correct when ``arch`` is not
# an XLM-R model (a case the sep_token logic already anticipates).
# XLMRobertaTokenizerFast is still imported so the name remains available.
from transformers import AutoTokenizer, XLMRobertaTokenizerFast
tokenizer = AutoTokenizer.from_pretrained(arch)
# Special-token ids, read from the tokenizer rather than hard-coded.
cls_id = tokenizer.cls_token_id
sep_id = tokenizer.sep_token_id
pad_id = tokenizer.pad_token_id

### Load the models and explainers

In [None]:
# Load teacher and student. load_model's third return value (dummy_state) is
# passed to load_explainer together with dummy_inputs.
teacher, teacher_params, dummy_state = load_model(teacher_dir, batch_size, max_len)
teacher_expl, teacher_expl_params = load_explainer(teacher_expl_dir, dummy_inputs, state=dummy_state)
# dummy_state is rebound here, so the student explainer is built against the
# student's state rather than the teacher's — keep this load order.
student, student_params, dummy_state = load_model(student_dir, batch_size, max_len)
student_expl, student_expl_params = load_explainer(student_expl_dir, dummy_inputs, state=dummy_state)

### (Optional) Create a fixed teacher explainer from a specific layer and head

In [None]:
# from meta_expl.explainers import create_explainer
# keyseq = PRNGSequence(11)
# teacher_explainer_params={
#     'normalize_head_coeffs': 'sparsemax',
#     'normalizer_fn': 'softmax',
#     'aggregator_idx': 'mean',
#     'aggregator_dim': 'row',
#     'init_fn': 'uniform',
#     'layer_idx': 9,  #9, None
#     'head_idx': 5,  #5, None
# }
# explainer_type='attention_explainer'
# teacher_explainer, teacher_explainer_params = create_explainer(next(keyseq), dummy_inputs, dummy_state, 
#                                      explainer_type, explainer_args=teacher_explainer_params)

### Inspect the teacher explainer's head coefficients

In [None]:
# Sparsemax-normalized teacher head coefficients shown as a 12 x 12 grid
# (assumed layers x heads of xlm-roberta-base — TODO derive from the model config).
jnp.reshape(sparsemax(teacher_expl_params['params']['head_coeffs']), (12, 12))

In [None]:
# Print the (row, column) index of every coefficient sparsemax left non-zero;
# rows are treated as layers by the ranking cell that follows.
hc = jnp.reshape(sparsemax(teacher_expl_params['params']['head_coeffs']), (12, 12))
for a, b in zip(*jnp.nonzero(hc)):
    print(a, b)

In [None]:
# Check the layers with the highest coefficients: mean coefficient per row.
layer_coeffs = hc.mean(-1).tolist()
# (1-based layer number, mean coefficient) pairs in ascending order of
# coefficient, so the highest-weighted layers appear last.
sorted(enumerate(layer_coeffs, start=1), key=lambda pair: pair[1])

## Evaluate teacher/student performance and simulability

In [None]:
def evaluate(data, return_outputs=False):
    """Score the student, the teacher, and student-teacher agreement on ``data``.

    Runs the notebook-global ``teacher``/``student`` models (with
    ``teacher_params``/``student_params``) over batches from the global
    ``dataloader`` and computes Pearson correlations.

    Args:
        data: list of records as returned by ``read_data`` / ``read_data_all``.
        return_outputs: if True, also return the concatenated predictions.

    Returns:
        ``(student_score, teacher_score, sim_score)``, where ``student_score`` is
        Pearson(student, gold), ``teacher_score`` is Pearson(teacher, gold), and
        ``sim_score`` (simulability) is Pearson(student, teacher). When
        ``return_outputs`` is True, a fourth element
        ``(all_outputs, all_y_sim, all_y)`` is appended.

    Raises:
        ValueError: if the dataloader yields no batches (e.g. empty ``data``).
    """
    # Ceiling division; assumes the dataloader keeps the final partial batch.
    num_batches = (len(data) + batch_size - 1) // batch_size
    all_outputs, all_y_sim, all_y = [], [], []
    batches = dataloader(data, tokenizer, batch_size=batch_size, max_len=max_len, shuffle=False)
    for i, (x, y) in enumerate(batches):
        print('{} of {}'.format(i + 1, num_batches), end='\r')
        # The teacher's prediction is the target the student is meant to simulate.
        y_sim = teacher.apply(teacher_params, **x)[0]
        outputs = student.apply(student_params, **x)[0]
        all_outputs.append(outputs)
        all_y_sim.append(y_sim)
        all_y.append(y)
    if not all_y:
        raise ValueError('evaluate() received no batches; is `data` empty?')
    all_outputs = jnp.concatenate(all_outputs, axis=0)
    all_y_sim = jnp.concatenate(all_y_sim, axis=0)
    all_y = jnp.concatenate(all_y, axis=0)
    student_score = pearson(all_outputs, all_y)
    teacher_score = pearson(all_y_sim, all_y)
    sim_score = pearson(all_outputs, all_y_sim)
    if return_outputs:
        return student_score, teacher_score, sim_score, (all_outputs, all_y_sim, all_y)
    return student_score, teacher_score, sim_score

### Evaluate on each language pair (LP)

In [None]:
# Evaluate on the dev split of each language pair separately.
langpairs = ["en-de", "en-zh", "et-en", "ne-en", "ro-en", "ru-en"]
split = 'dev'
for lp in langpairs:
    print(lp)
    student_score, teacher_score, sim_score = evaluate(read_data(lp, split))
    print('------------')
    print(f'Pearson (teacher): {teacher_score:.4f}')
    print(f'Pearson (student): {student_score:.4f}')
    print(f'Pearson (simulability): {sim_score:.4f}')
    print('')

In [None]:
# Evaluate on the concatenation of all language pairs for the current split.
print("Overall")
student_score, teacher_score, sim_score = evaluate(read_data_all(langpairs, split))
print('------------')
print(f'Pearson (teacher): {teacher_score:.4f}')
print(f'Pearson (student): {student_score:.4f}')
print(f'Pearson (simulability): {sim_score:.4f}')
print('')

In [None]:
# Evaluate on the test split of each language pair separately.
langpairs = ["en-de", "en-zh", "et-en", "ne-en", "ro-en", "ru-en"]
split = 'test'
for lp in langpairs:
    print(lp)
    student_score, teacher_score, sim_score = evaluate(read_data(lp, split))
    print('------------')
    print(f'Pearson (teacher): {teacher_score:.4f}')
    print(f'Pearson (student): {student_score:.4f}')
    print(f'Pearson (simulability): {sim_score:.4f}')
    print('')

In [None]:
# Evaluate on the concatenation of all language pairs for the current split.
print("Overall")
student_score, teacher_score, sim_score = evaluate(read_data_all(langpairs, split))
print('------------')
print(f'Pearson (teacher): {teacher_score:.4f}')
print(f'Pearson (student): {student_score:.4f}')
print(f'Pearson (simulability): {sim_score:.4f}')
print('')