In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
import sys

sys.path.append("../")

In [3]:
from IPython.core.display import HTML

In [4]:
import os
import numpy as np

import torch
import torch.nn.functional as F

from xbert.engine import Engine, weight_of_evidence, difference_of_log_probabilities, calculate_correlation
from xbert import InputInstance, Config
from xbert.visualization import visualize_relevances
from xbert.occlusion.explainer import GradxInputExplainer

In [5]:
#import spacy
from tqdm import tqdm
from collections import defaultdict

from segtok.tokenizer import web_tokenizer, space_tokenizer
from transformers import RobertaTokenizer, RobertaForSequenceClassification  #, glue_convert_examples_to_features

In [6]:
# Device handed directly to `.to(...)` for both the model and the input tensors.
# A negative index is not a valid torch device, so fall back to the string "cpu"
# (not -1) when CUDA is unavailable. Note that `config_resample` further below
# carries its own hard-coded "cuda_device" entry.
CUDA_DEVICE = 0 if torch.cuda.is_available() else "cpu"

# Hugging Face model identifier used for both the tokenizer and the classifier.
MODEL_NAME = "roberta-large-mnli"

In [7]:
# Load the tokenizer and the sequence-classification model identified by MODEL_NAME.
tokenizer = RobertaTokenizer.from_pretrained(MODEL_NAME)

# Load the weights first, then move the module to the configured device.
model = RobertaForSequenceClassification.from_pretrained(MODEL_NAME)
model = model.to(CUDA_DEVICE)

In [8]:
# Directory holding the MNLI TSV files (relative to this notebook).
MNLI_DATASET_PATH = "../data/glue_data/MNLI/"

# Label name <-> class index. The index order is expected to match the output
# order of the classifier head -- verify against `model.config.id2label`.
MNLI_LABEL2IDX = {'contradiction': 0, 'neutral': 1, 'entailment': 2}
MNLI_IDX2LABEL = {idx: label for label, idx in MNLI_LABEL2IDX.items()}

In [9]:
def rindex(alist, value):
    """Return the index of the last occurrence of ``value`` in ``alist``.

    Raises:
        ValueError: if ``value`` does not occur in ``alist`` (propagated from
            ``.index``).
    """
    reversed_items = alist[::-1]
    # Position of the match counted from the end of the original sequence.
    distance_from_end = reversed_items.index(value)
    return len(alist) - 1 - distance_from_end


def byte_pair_offsets(input_ids, tokenizer):
    """Compute word boundaries (in sub-token positions) for an encoded sentence pair.

    The ids are converted back to token strings, ``<pad>`` tokens are dropped and
    the first and last remaining tokens are stripped. What is left is expected to
    be ``sentence 1, "</s>" separator(s), sentence 2``. A sub-token whose string
    starts with a space is treated as the beginning of a new word.

    Args:
        input_ids: list of token ids of one encoded sentence pair (may be padded).
        tokenizer: object providing ``convert_ids_to_tokens`` and
            ``convert_tokens_to_string``.

    Returns:
        ``(sent_1_offsets, sent_2_offsets)``. Each is a list of positions into
        ``input_ids`` such that consecutive entries ``(offsets[i], offsets[i + 1])``
        delimit the half-open sub-token span of word ``i``.

    Raises:
        ValueError: if no ``"</s>"`` token remains after stripping.
    """
    def get_offsets(tokens, start_offset):
        # The first sub-token always opens the first word, whether or not it
        # carries a leading space. Starting the scan at the second sub-token
        # avoids recording `start_offset` twice, which would create an empty
        # first span and shift every following word index by one.
        offsets = [start_offset]
        for t_idx, token in enumerate(tokens[1:], start_offset + 1):
            if token.startswith(" "):
                offsets.append(t_idx)
        # Closing boundary: one past the last sub-token of the sentence.
        offsets.append(start_offset + len(tokens))
        return offsets

    tokens = [tokenizer.convert_tokens_to_string(t)
              for t in tokenizer.convert_ids_to_tokens(input_ids, skip_special_tokens=False)]
    tokens = [token for token in tokens if token != "<pad>"]
    # Drop the leading and the trailing special token.
    tokens = tokens[1:-1]

    sent_1_end = tokens.index("</s>")
    # Index just past the last "</s>" separator.
    sent_2_start = len(tokens) - tokens[::-1].index("</s>")

    # "+1" maps positions in the stripped list back to positions in `input_ids`
    # (the leading special token was removed above).
    sent_1_offsets = get_offsets(tokens[:sent_1_end], start_offset=1)
    sent_2_offsets = get_offsets(tokens[sent_2_start:], start_offset=sent_2_start + 1)

    return sent_1_offsets, sent_2_offsets

In [10]:
from typing import List, Tuple


def read_mnli_dataset(path: str) -> List[Tuple[str, str, str]]:
    """Read an MNLI TSV file into ``(sentence1, sentence2, label)`` string triples.

    The first line is treated as a header and skipped. For every other non-blank
    line, columns 8 and 9 are taken as the two sentences and the last column as
    the label.

    Args:
        path: path to the TSV file.

    Returns:
        One ``(sent1, sent2, target)`` tuple of raw strings per data line.

    Raises:
        IndexError: if a non-blank line has fewer than ten tab-separated columns.
    """
    dataset = []
    # Explicit encoding so the result does not depend on the platform default.
    with open(path, encoding="utf-8") as fin:
        next(fin, None)  # skip the header line (tolerates an empty file)
        for line in fin:
            line = line.strip()
            if not line:
                # Blank lines (e.g. a trailing newline) carry no example.
                continue
            fields = line.split('\t')
            sent1, sent2, target = fields[8], fields[9], fields[-1]
            dataset.append((sent1, sent2, target))

    return dataset


def dataset_to_input_instances(dataset: List[Tuple[List[str], List[str], str]]) -> List[InputInstance]:
    """Wrap every dataset row in an ``InputInstance``.

    Both sentences are split with ``web_tokenizer``; the row's position in
    ``dataset`` becomes the instance id. The label of each row is ignored.
    """
    return [
        InputInstance(id_=position,
                      sent1=web_tokenizer(first_sentence),
                      sent2=web_tokenizer(second_sentence))
        for position, (first_sentence, second_sentence, _) in enumerate(dataset)
    ]


def get_labels(dataset: List[Tuple[List[str], List[str], str]]) -> List[str]:
    """Return the third element (the label) of every row, in dataset order."""
    targets = []
    for _first, _second, target in dataset:
        targets.append(target)
    return targets

In [11]:
def collate_tokens(values, pad_idx, eos_idx=None, left_pad=False, move_eos_to_beginning=False):
    """Stack a list of 1d tensors into one 2d tensor, padding with ``pad_idx``.

    Args:
        values: non-empty list of 1d tensors.
        pad_idx: fill value for positions not covered by a sequence.
        eos_idx: end-of-sequence id; only used when ``move_eos_to_beginning``.
        left_pad: if true, sequences are right-aligned (padding on the left).
        move_eos_to_beginning: if true, each sequence must end with ``eos_idx``;
            that id is written first and the rest of the sequence is shifted
            right by one position.

    Returns:
        Tensor of shape ``(len(values), longest sequence length)`` created from
        ``values[0]`` via ``.new(...)``.
    """
    max_len = max(seq.size(0) for seq in values)
    batch = values[0].new(len(values), max_len).fill_(pad_idx)

    for row_idx, seq in enumerate(values):
        seq_len = len(seq)
        # View into the output row that this sequence occupies.
        if left_pad:
            target = batch[row_idx][max_len - seq_len:]
        else:
            target = batch[row_idx][:seq_len]
        assert target.numel() == seq.numel()

        if not move_eos_to_beginning:
            target.copy_(seq)
            continue

        # Rotate the trailing EOS to the front of the sequence.
        assert seq[-1] == eos_idx
        target[0] = eos_idx
        target[1:] = seq[:-1]

    return batch

In [12]:
def encode_instance(input_instance):
    """Encode one instance as a sentence pair with the module-level ``tokenizer``.

    The tokens of ``sent1`` and ``sent2`` are each joined with single spaces and
    passed as ``text`` / ``text_pair``. The tokenizer returns a batch of one;
    its first (only) row is returned.
    """
    first_text = " ".join(input_instance.sent1.tokens)
    second_text = " ".join(input_instance.sent2.tokens)
    encoded_batch = tokenizer.encode(text=first_text,
                                     text_pair=second_text,
                                     add_special_tokens=True,
                                     return_tensors="pt")
    return encoded_batch[0]

In [13]:
#def predict(input_instance, model, tokenizer, cuda_device):
#    input_ids = tokenizer.encode(text=input_instance.sent1.tokens,
#                                 text_pair=input_instance.sent2.tokens,
#                                 add_special_tokens=True,
#                                 return_tensors="pt").to(cuda_device)
#    
#    logits = model(input_ids)[0]
#    return F.softmax(logits, dim=-1)

def predict(input_instances, model, tokenizer, cuda_device):
    """Return class probabilities for one instance or a list of instances.

    Args:
        input_instances: a single ``InputInstance`` or a list of them.
        model: sequence classifier called with ``input_ids`` and
            ``attention_mask``; its first output is taken as the logits.
        tokenizer: tokenizer used to encode each sentence pair and to determine
            the padding id.
        cuda_device: device passed to ``.to(...)`` for the input tensors.

    Returns:
        Tensor of softmax probabilities over the last dimension of the logits,
        one row per instance.

    Note:
        Gradients are not disabled here; wrap the call in ``torch.no_grad()``
        for pure inference.
    """
    if isinstance(input_instances, InputInstance):
        input_instances = [input_instances]

    # Encode with the tokenizer that was passed in (it was previously ignored in
    # favour of the module-level one), mirroring what `encode_instance` does.
    input_ids = [
        tokenizer.encode(text=" ".join(instance.sent1.tokens),
                         text_pair=" ".join(instance.sent2.tokens),
                         add_special_tokens=True,
                         return_tensors="pt")[0]
        for instance in input_instances
    ]
    attention_mask = [torch.ones_like(t) for t in input_ids]

    # Pad with the tokenizer's own padding id; 1 is kept as the fallback because
    # it was the previously hard-coded value.
    pad_idx = getattr(tokenizer, "pad_token_id", None)
    if pad_idx is None:
        pad_idx = 1

    input_ids = collate_tokens(input_ids, pad_idx=pad_idx).to(cuda_device)
    # Padded positions get mask value 0 so the model can ignore them.
    attention_mask = collate_tokens(attention_mask, pad_idx=0).to(cuda_device)

    logits = model(input_ids=input_ids, attention_mask=attention_mask)[0]
    return F.softmax(logits, dim=-1)

In [14]:
# Read the matched dev split, then derive the gold labels and the tokenized
# instances from the same rows so both stay index-aligned.
dataset = read_mnli_dataset(path=os.path.join(MNLI_DATASET_PATH, "dev_matched.tsv"))
labels = get_labels(dataset)
input_instances = dataset_to_input_instances(dataset)

In [15]:
batch_size = 100
n_eval = 1000  # number of leading dev examples used for this sanity check

# Slice once so the progress bar, the batching and the accuracy all refer to the
# same set of instances (previously the bar total was computed from the full
# dataset and batches were cut from the unsliced list).
eval_instances = input_instances[:n_eval]

ncorrect, nsamples = 0, 0
# tqdm infers the number of batches from the range itself.
for i in tqdm(range(0, len(eval_instances), batch_size)):
    batch_instances = eval_instances[i: i + batch_size]
    with torch.no_grad():
        probs = predict(batch_instances, model, tokenizer, CUDA_DEVICE)
        predictions = probs.argmax(dim=-1).cpu().numpy().tolist()
    for batch_idx, instance in enumerate(batch_instances):
        # the instance id is also the position in the list of labels
        idx = instance.id
        true_label = labels[idx]
        pred_label = MNLI_IDX2LABEL[predictions[batch_idx]]
        ncorrect += int(true_label == pred_label)
        nsamples += 1

# Avoid a ZeroDivisionError when there is nothing to evaluate.
accuracy = float(ncorrect) / float(nsamples) if nsamples else float("nan")
print('| Accuracy: ', accuracy)

 10%|█         | 10/98 [00:18<02:43,  1.86s/it]

| Accuracy:  0.909





In [16]:
def batcher(batch_instances):
    """Return, for each instance, the probability the model gives its gold label.

    Uses the module-level ``model``, ``tokenizer``, ``CUDA_DEVICE`` and
    ``labels``. The gold label is looked up through ``instance.id``, which is
    the position in the list of labels.

    Args:
        batch_instances: list of instances accepted by ``predict``.

    Returns:
        List of Python floats, one per instance, in input order.
    """
    with torch.no_grad():
        probs = predict(batch_instances, model, tokenizer, CUDA_DEVICE).cpu().numpy().tolist()

    probabilities = []
    for batch_idx, instance in enumerate(batch_instances):
        # the instance id is also the position in the list of labels
        true_label_idx = MNLI_LABEL2IDX[labels[instance.id]]
        probabilities.append(probs[batch_idx][true_label_idx])

    return probabilities
    
    
def batcher_gradient(batch_instances):
    """Compute per-word relevances for a batch with ``GradxInputExplainer``.

    Uses the module-level ``model``, ``tokenizer`` and ``CUDA_DEVICE``.

    Steps:
      1. Encode and pad the instances (pad id 1).
      2. Embed the ids with ``model.roberta.embeddings`` and detach the result
         so it can serve as the input that gradients are taken against.
      3. Let the explainer produce an attribution tensor for that input.
      4. Sum the attributions over the sub-token span of every word, using the
         boundaries from ``byte_pair_offsets``.

    Args:
        batch_instances: list of instances accepted by ``encode_instance``.

    Returns:
        One dict per instance mapping ``(sent_id, word_index)`` to the summed
        relevance, where ``sent_id`` is ``"sent1"`` or ``"sent2"``.
    """
    input_ids = collate_tokens(
        [encode_instance(instance) for instance in batch_instances], pad_idx=1
    ).to(CUDA_DEVICE)

    inp = model.roberta.embeddings(input_ids=input_ids).detach()

    explainer = GradxInputExplainer(model, output_getter=lambda x: x[0])
    inp.requires_grad = True
    # NOTE(review): assumes `expl` is indexable as [batch][token position] and is
    # detached from the graph (otherwise `.numpy()` below would fail) -- confirm
    # against the explainer implementation.
    expl = explainer.explain(inp)

    input_ids_np = input_ids.cpu().numpy()
    expl_np = expl.cpu().numpy()

    relevances = []
    for b_idx in range(input_ids_np.shape[0]):
        sent1_offsets, sent2_offsets = byte_pair_offsets(input_ids_np[b_idx].tolist(), tokenizer)

        relevance_dict = defaultdict(float)
        for offsets, sent_id in zip([sent1_offsets, sent2_offsets], ["sent1", "sent2"]):
            # Consecutive offsets delimit the sub-token span of one word.
            for token_idx, (token_start, token_end) in enumerate(zip(offsets, offsets[1:])):
                relevance = expl_np[b_idx][token_start: token_end].sum()
                relevance_dict[(sent_id, token_idx)] = relevance
        relevances.append(relevance_dict)

    return relevances
    

# Configuration for the "unk_replacement" strategy; "<unk>" is the replacement
# token handed to the engine.
config_unk = Config.from_dict({
    "strategy": "unk_replacement",
    "batch_size": 128,
    "unk_token": "<unk>"
})

# Configuration for the "gradient" strategy.
config_gradient = Config.from_dict({
    "strategy": "gradient",
    "batch_size": 128
})

# Configuration for the "bert_lm_sampling" strategy.
# NOTE(review): "cuda_device" is hard-coded here and is independent of
# CUDA_DEVICE -- keep the two in sync when changing devices.
config_resample = Config.from_dict({
    "strategy": "bert_lm_sampling",
    "cuda_device": 0,
    "bert_model": "bert-base-uncased",
    "batch_size": 256,
    "n_samples": 100,
    "verbose": False
})

# One engine per strategy. The two occlusion-style engines share `batcher`
# (gold-label probabilities); the gradient engine uses `batcher_gradient`
# (per-word relevance dicts).
unknown_engine = Engine(config_unk, batcher)
resample_engine = Engine(config_resample, batcher)
gradient_engine = Engine(config_gradient, batcher_gradient)

In [17]:
# Window of instances to explain: `n` instances starting at `instance_idx`.
instance_idx = 0
n = 2

# Run each engine on the same window; every run returns a pair of
# (candidate instances, candidate results) that is consumed by `.relevances(...)`.
unk_candidate_instances, unk_candidate_results = unknown_engine.run(input_instances[instance_idx: instance_idx+n])
res_candidate_instances, res_candidate_results = resample_engine.run(input_instances[instance_idx: instance_idx+n])
grad_candidate_instances, grad_candidate_results = gradient_engine.run(input_instances[instance_idx: instance_idx+n])

In [18]:
# Turn each engine's candidates and results into relevances, one set per strategy.
unk_relevances = unknown_engine.relevances(unk_candidate_instances, unk_candidate_results)
res_relevances = resample_engine.relevances(res_candidate_instances, res_candidate_results)
grad_relevances = gradient_engine.relevances(grad_candidate_instances, grad_candidate_results)

In [19]:
# Gold labels for the explained window.
labels_true = labels[instance_idx: instance_idx+n]

# Predicted labels for the same window. This is pure inference, so autograd is
# disabled to avoid building a graph for every forward pass.
with torch.no_grad():
    labels_pred = [
        MNLI_IDX2LABEL[predict(instance, model, tokenizer, CUDA_DEVICE)[0].argmax().item()]
        for instance in input_instances[instance_idx: instance_idx+n]
    ]

In [20]:
# Display the relevances from `unknown_engine` for the selected window, together
# with the true and predicted labels.
HTML(visualize_relevances(input_instances[instance_idx: instance_idx+n], unk_relevances, labels_true, labels_pred))

In [21]:
# Display the relevances from `resample_engine` for the selected window, together
# with the true and predicted labels.
HTML(visualize_relevances(input_instances[instance_idx: instance_idx+n], res_relevances, labels_true, labels_pred))

In [22]:
# Display the relevances from `gradient_engine` for the selected window, together
# with the true and predicted labels.
HTML(visualize_relevances(input_instances[instance_idx: instance_idx+n], grad_relevances, labels_true, labels_pred))

In [24]:
# Pairwise correlation between the relevances of the three strategies, printed
# in the order: unk vs. resample, unk vs. gradient, resample vs. gradient.
for first_relevances, second_relevances in (
    (unk_relevances, res_relevances),
    (unk_relevances, grad_relevances),
    (res_relevances, grad_relevances),
):
    print(calculate_correlation(first_relevances, second_relevances))

0.6490628562805311
0.366195462088272
0.0921416301964605
