In [None]:
import torch
import datasets
import pandas as pd
import numpy as np

from sklearn.metrics import roc_auc_score, average_precision_score, precision_recall_fscore_support
from scipy.stats import pearsonr, spearmanr
from transformers import AutoTokenizer

In [None]:
# Pretrained 't5-small' tokenizer shared by the cells below: it tokenizes the
# reviews/evidences, provides the EOS id used to delimit evidence spans, and
# maps token ids back to their string pieces.
t5_tokenizer = AutoTokenizer.from_pretrained('t5-small')

In [None]:
# Running counters of evidence spans that were / were not located by
# find_index_sublist (updated as a side effect of every call).
total_ok = 0
total_er = 0

def find_index_sublist(v, u_pad):
    """
    Locate a tokenized evidence span inside a tokenized sequence.

    :param v: list of token ids to search in
    :param u_pad: list of token ids of the evidence; it must contain the
        tokenizer's EOS id. Only the ids after the first one and before the
        first EOS are searched for.
    :returns: (start, end) of the first match, such that v[start:end] equals
        the searched ids, or (None, None) when there is no match
    :raises ValueError: if u_pad does not contain the EOS id

    Side effect: increments the module-level total_ok / total_er counters.
    """
    global total_ok, total_er
    m = u_pad.index(t5_tokenizer.eos_token_id)
    # NOTE(review): the first evidence id is dropped before searching, so the
    # returned span does not cover it -- presumably to tolerate a different
    # tokenization of the first piece in context; confirm this is intended.
    u = u_pad[1:m]
    n = len(u)
    # `+ 1` so that a match ending exactly at the last element of v is tried too
    # (the previous bound `len(v) - n` skipped that final start position).
    for i in range(len(v) - n + 1):
        if v[i:i+n] == u:
            total_ok += 1
            return i, i+n
    total_er += 1
    return None, None


def convert_evidences_to_mask(x, e):
    """
    Build a 0/1 mask over x marking the positions covered by located evidences.

    :param x: sequence searched by find_index_sublist (only its length and
        contents are used)
    :param e: iterable of evidences, each passed to find_index_sublist
    :returns: list of len(x) ints; 1 inside every found [start, end) span,
        0 elsewhere. Evidences that are not found contribute nothing.
    """
    # locate every evidence first, then paint the spans that were found
    spans = [find_index_sublist(x, evidence) for evidence in e]
    mask = [0] * len(x)
    for start, end in spans:
        if start is None or end is None:
            continue
        for pos in range(start, end):
            mask[pos] = 1
    return mask

def get_gold_explanations_from_movies_dataset(ds_movies):
    """
    Tokenize every sample that has evidences and build its gold evidence mask.

    :param ds_movies: iterable of samples exposing 'review' and 'evidences'
    :returns: tuple (valid_idxs, inputs_ids, gold_explanations) of parallel
        lists: position of each kept sample in ds_movies, the token ids of its
        review, and the 0/1 mask from convert_evidences_to_mask.
        Samples with no evidences are skipped.
    """
    kept_idxs, token_ids, masks = [], [], []
    for idx, sample in enumerate(ds_movies):
        evidences = sample['evidences']
        if len(evidences) == 0:
            continue

        # reviews and evidences are both truncated to 512 tokens
        review_enc = t5_tokenizer(sample['review'], truncation=True, max_length=512)
        evidence_enc = t5_tokenizer(evidences, truncation=True, max_length=512)
        gold_mask = convert_evidences_to_mask(review_enc['input_ids'], evidence_enc['input_ids'])

        kept_idxs.append(idx)
        token_ids.append(review_enc['input_ids'])
        masks.append(gold_mask)

    return kept_idxs, token_ids, masks

In [None]:
def validate_word_level_data(gold_explanations, model_explanations):
    """
    Drop samples whose gold tags are all 0 or all 1.

    :param gold_explanations: list of gold tag sequences
    :param model_explanations: list of model score sequences, paired with gold
    :returns: (valid_gold, valid_model) holding only the pairs whose gold
        sequence sums to neither 0 nor its own length
    """
    kept_gold, kept_model = [], []
    for gold, model in zip(gold_explanations, model_explanations):
        n_positive = sum(gold)
        # keep only samples that mix both tags
        if n_positive not in (0, len(gold)):
            kept_gold.append(gold)
            kept_model.append(model)
    return kept_gold, kept_model


def compute_auc_score(gold_explanations, model_explanations):
    """
    Average of the per-sample roc_auc_score values.

    :param gold_explanations: list of gold tag sequences
    :param model_explanations: list of model score sequences (same order)
    :returns: sum of per-sample scores divided by the number of gold samples
    """
    n_samples = len(gold_explanations)
    total = 0
    for i, gold in enumerate(gold_explanations):
        total += roc_auc_score(gold, model_explanations[i])
    return total / n_samples


def compute_ap_score(gold_explanations, model_explanations):
    """
    Average of the per-sample average_precision_score values.

    :param gold_explanations: list of gold tag sequences
    :param model_explanations: list of model score sequences (same order)
    :returns: sum of per-sample scores divided by the number of gold samples
    """
    n_samples = len(gold_explanations)
    total = 0
    for i, gold in enumerate(gold_explanations):
        total += average_precision_score(gold, model_explanations[i])
    return total / n_samples


def compute_rec_topk(gold_explanations, model_explanations):
    """
    Average recall at top-K, where K is the sum of each sample's gold tags.

    For every sample the K highest-scored positions (reversed argsort of the
    model scores) are taken, and the fraction of them whose gold tag equals 1
    is computed relative to K.

    :param gold_explanations: list of gold tag sequences
    :param model_explanations: list of model score sequences (same order)
    :returns: mean of the per-sample recalls
    """
    total = 0
    for i, gold in enumerate(gold_explanations):
        n_gold = sum(gold)
        ranked = np.argsort(model_explanations[i])[::-1]
        top_k = ranked[:int(n_gold)]
        hits = sum(1 for idx in top_k if gold[idx] == 1)
        total += hits / n_gold
    return total / len(gold_explanations)


def compute_precision_recall_fscore_support(gold_explanations, model_explanations, threshold=0.0):
    """
    Average per-sample binary precision, recall and F-score.

    Model scores are binarized as `score > threshold` before being passed to
    precision_recall_fscore_support (average='binary', zero_division=1).

    :param gold_explanations: list of gold tag sequences
    :param model_explanations: list of model score sequences (same order)
    :param threshold: scores strictly above it become 1, the rest 0
    :returns: (precision, recall, fscore), each averaged over the gold samples
    """
    n_samples = len(gold_explanations)
    prec_total, rec_total, f_total = 0, 0, 0
    for i, gold in enumerate(gold_explanations):
        # boolean comparison turned into a 0/1 integer array
        binarized = 1 * (np.array(model_explanations[i]) > threshold)
        prec, rec, fscore, _ = precision_recall_fscore_support(
            gold, binarized, average='binary', zero_division=1
        )
        prec_total += prec
        rec_total += rec
        f_total += fscore
    return prec_total / n_samples, rec_total / n_samples, f_total / n_samples
    

def evaluate_word_level(gold_explanations, model_explanations, verbose=True, threshold=0.0):
    """
    Run all word-level metrics on the samples kept by validate_word_level_data.

    :param gold_explanations: list of gold tag sequences
    :param model_explanations: list of model score sequences (same order)
    :param verbose: when True, print every metric with 4 decimals
    :param threshold: forwarded to compute_precision_recall_fscore_support
    :returns: tuple (prec, rec, f1, auc_score, ap_score, rec_topk)
    """
    gold, model = validate_word_level_data(gold_explanations, model_explanations)
    prec, rec, f1 = compute_precision_recall_fscore_support(gold, model, threshold=threshold)
    auc_score = compute_auc_score(gold, model)
    ap_score = compute_ap_score(gold, model)
    rec_topk = compute_rec_topk(gold, model)

    metrics = (prec, rec, f1, auc_score, ap_score, rec_topk)
    if verbose:
        templates = (
            'Prec: {:.4f}',
            'Rec: {:.4f}',
            'F1: {:.4f}',
            'AUC score: {:.4f}',
            'AP score: {:.4f}',
            'Recall at top-K: {:.4f}',
        )
        for template, value in zip(templates, metrics):
            print(template.format(value))
    return metrics


def aggregate_pieces(x, mask, reduction='first'):
    """
    Aggregate per-piece values into one value per group of pieces.

    A new group starts at every position where `mask` is True; the first and
    last positions are always forced to start a group (they are treated as the
    <s> and </s> special tokens).

    :param x: tensor of shape (seq_len). The 'first' reduction also accepts
        (seq_len, hdim) since it only indexes rows; the sum/mean/max
        reductions are written for 1-D x.
    :param mask: bool tensor of shape (seq_len); it is cloned, not modified
    :param reduction: aggregation strategy (first, max, sum, mean); any other
        value returns x unchanged
    :returns: <s> word_1 word_2 ... </s>
        where word_i = aggregate(piece_i_1, piece_i_2, ...).
        For sum/mean/max the result is cast to float; for 'first' the dtype
        of x is kept.
    """
    # mark <s> and </s> as True
    special_mask = mask.clone()
    special_mask[0] = special_mask[-1] = True
    if reduction == 'first':
        return x[special_mask.bool()]
    if reduction not in ('sum', 'mean', 'max'):
        return x

    # group id of every piece: increases by one at each True position
    idx = special_mask.long().cumsum(dim=-1) - 1
    idx_unique_count = torch.bincount(idx)
    num_unique = idx_unique_count.shape[-1]
    if reduction == 'max':
        res = torch.stack([x[idx == i].max() for i in range(num_unique)]).to(x.device)
    else:
        res = torch.zeros(num_unique, device=x.device, dtype=x.dtype).scatter_add(0, idx, x)
        if reduction == 'mean':
            # Out-of-place division: an in-place `/=` raises for integer x
            # because the float result cannot be written into an integer tensor.
            res = res / idx_unique_count.float()
    return res.float()


def get_piece_explanations(all_explanations, all_fp_masks, reduction):
    """
    Aggregate piece-level scores of every sample with aggregate_pieces.

    :param all_explanations: iterable of per-piece score sequences
    :param all_fp_masks: iterable of per-piece masks, paired with the scores
    :param reduction: reduction name forwarded to aggregate_pieces
    :returns: list with one list per sample: the aggregated values without
        their first and last entries
    """
    def _aggregate_one(expl, fp_mask):
        # both sequences are converted to tensors before aggregation
        aggregated = aggregate_pieces(torch.tensor(expl), torch.tensor(fp_mask), reduction)
        # NOTE(review): the first and last aggregated entries are dropped as
        # <s> and </s>; this assumes the sequence really starts and ends with
        # special tokens -- confirm for the tokenizer in use.
        return aggregated.tolist()[1:-1]

    return [
        _aggregate_one(expl, fp_mask)
        for expl, fp_mask in zip(all_explanations, all_fp_masks)
    ]

In [None]:
# Load the 'movie_rationales' dataset, reusing the cached copy when it exists.
ds = datasets.load_dataset(
    'movie_rationales',
    download_mode=datasets.DownloadMode.REUSE_CACHE_IF_EXISTS,
)
# x: token ids of each test review that has evidences; z_gold: the matching
# 0/1 gold evidence masks. The positions of the kept samples are discarded.
_, x, z_gold = get_gold_explanations_from_movies_dataset(ds['test'])

In [None]:
df = pd.read_csv('../data/rationales/revised_imdb_test_sparsemap_30p.tsv', delimiter='\t')

# SECURITY NOTE: eval() executes whatever expression is stored in the 'z'
# column, so only run this on trusted files. If the column holds plain Python
# literals, ast.literal_eval is the safer drop-in.
z = df['z'].map(eval)
# First-piece mask: True on pieces that begin with '▁'. aggregate_pieces opens
# a new group at every True position, so the mask has to flag the pieces that
# start a group; the previous `not tk.startswith('▁')` flagged the
# continuation pieces instead.
fp_masks = [[tk.startswith('▁') for tk in t5_tokenizer.convert_ids_to_tokens(ids)] for ids in x]

# Trim masks and gold tags to the length of the predicted scores of the row
# with the same position (assumes df rows are aligned with x -- TODO confirm).
fp_masks_l = [t[:len(z[i])] for i,t in enumerate(fp_masks)]
z_gold_l = [t[:len(z[i])] for i,t in enumerate(z_gold)]

# Predicted scores are averaged per group; gold tags use max, so a group is
# positive as soon as one of its pieces is.
pred_expls = get_piece_explanations(z, fp_masks_l, reduction='mean')
gold_expls = get_piece_explanations(z_gold_l, fp_masks_l, reduction='max')
_ = evaluate_word_level(gold_expls, pred_expls, verbose=True, threshold=0.0)