In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
import sys

sys.path.append("../")

In [3]:
from IPython.core.display import HTML

In [4]:
import os
import numpy as np

import torch
import torch.nn.functional as F

from xbert.engine import Engine, weight_of_evidence, difference_of_log_probabilities, calculate_correlation
from xbert import InputInstance, Config
from xbert.visualization import visualize_relevances
from xbert.occlusion.explainer import GradxInputExplainer, IntegrateGradExplainer

In [5]:
#import spacy
from tqdm import tqdm
from collections import defaultdict

from segtok.tokenizer import web_tokenizer, space_tokenizer
from transformers import RobertaTokenizer, RobertaForSequenceClassification  #, glue_convert_examples_to_features

In [6]:
CUDA_DEVICE = 0 # or -1 if no GPU is available

MODEL_NAME = "../models/CoLA/"

In [8]:
tokenizer = RobertaTokenizer.from_pretrained(MODEL_NAME)
model = RobertaForSequenceClassification.from_pretrained(MODEL_NAME).to(CUDA_DEVICE)

In [9]:
COLA_DATASET_PATH = "../data/glue_data/CoLA/"

In [10]:
def rindex(alist, value):
    return len(alist) - alist[-1::-1].index(value) - 1


def byte_pair_offsets(input_ids, tokenizer):
    def get_offsets(tokens, start_offset):
        offsets = [start_offset]
        for t_idx, token in enumerate(tokens, start_offset):
            if not token.startswith(" "):
                continue
            offsets.append(t_idx)
        offsets.append(start_offset + len(tokens))
        return offsets
        
    tokens = [tokenizer.convert_tokens_to_string(t)
              for t in tokenizer.convert_ids_to_tokens(input_ids, skip_special_tokens=False)]
    tokens = [token for token in tokens if token != "<pad>"]
    tokens = tokens[1:-1]
    
    offsets = get_offsets(tokens, start_offset=1)
    
    return offsets

In [21]:
from typing import List, Tuple


def read_cola_dataset(path: str) -> List[Tuple[List[str], str]]:
    dataset = []
    with open(path) as fin:
        fin.readline()
        for index, line in enumerate(fin):
            tokens = line.strip().split('\t')
            sent, target = tokens[3], tokens[1]
            dataset.append((sent, target))
            
    return dataset


def dataset_to_input_instances(dataset: List[Tuple[List[str], str]]) -> List[InputInstance]:
    input_instances = []
    for idx, (sent, _) in enumerate(dataset):
        instance = InputInstance(id_=idx, sent=web_tokenizer(sent))
        input_instances.append(instance)
        
    return input_instances


def get_labels(dataset: List[Tuple[List[str], List[str], str]]) -> List[str]:
    return [int(label) for _, label in dataset]

In [22]:
def collate_tokens(values, pad_idx, eos_idx=None, left_pad=False, move_eos_to_beginning=False):
    """Convert a list of 1d tensors into a padded 2d tensor."""
    size = max(v.size(0) for v in values)
    res = values[0].new(len(values), size).fill_(pad_idx)
    
    def copy_tensor(src, dst):
        assert dst.numel() == src.numel()
        if move_eos_to_beginning:
            assert src[-1] == eos_idx
            dst[0] = eos_idx
            dst[1:] = src[:-1]
        else:
            dst.copy_(src)

    for i, v in enumerate(values):
        copy_tensor(v, res[i][size - len(v):] if left_pad else res[i][:len(v)])
    return res

In [23]:
def encode_instance(input_instance):
    return tokenizer.encode(text=" ".join(input_instance.sent.tokens),
                            add_special_tokens=True,
                            return_tensors="pt")[0]

In [24]:
#def predict(input_instance, model, tokenizer, cuda_device):
#    input_ids = tokenizer.encode(text=input_instance.sent1.tokens,
#                                 text_pair=input_instance.sent2.tokens,
#                                 add_special_tokens=True,
#                                 return_tensors="pt").to(cuda_device)
#    
#    logits = model(input_ids)[0]
#    return F.softmax(logits, dim=-1)

def predict(input_instances, model, tokenizer, cuda_device):
    if isinstance(input_instances, InputInstance):
        input_instances = [input_instances]
    
    input_ids = [encode_instance(instance) for instance in input_instances]
    attention_mask = [torch.ones_like(t) for t in input_ids]
    
    input_ids = collate_tokens(input_ids, pad_idx=1).to(cuda_device)
    attention_mask = collate_tokens(attention_mask, pad_idx=0).to(cuda_device)
    
    logits = model(input_ids=input_ids, attention_mask=attention_mask)[0]
    return F.softmax(logits, dim=-1)

In [25]:
dataset = read_cola_dataset(os.path.join(COLA_DATASET_PATH, "dev.tsv"))
input_instances = dataset_to_input_instances(dataset)
labels = get_labels(dataset)

In [26]:
batch_size = 100

ncorrect, nsamples = 0, 0
for i in tqdm(range(0, len(input_instances), batch_size), total=len(input_instances) // batch_size):
    batch_instances = input_instances[i: i + batch_size]
    with torch.no_grad():
        probs = predict(batch_instances, model, tokenizer, CUDA_DEVICE)
        #print(probs)
        predictions = probs.argmax(dim=-1).cpu().numpy().tolist()
        #print(predictions)
        for batch_idx, instance in enumerate(batch_instances):
            # the instance id is also the position in the list of labels
            idx = instance.id
            true_label = labels[idx]
            pred_label = predictions[batch_idx]
            ncorrect += int(true_label == pred_label)
            nsamples += 1
print('| Accuracy: ', float(ncorrect)/float(nsamples))

11it [00:01,  6.95it/s]                       

| Accuracy:  0.8262955854126679





In [29]:
def batcher(batch_instances):
    true_label_indices = []
    probabilities = []
    with torch.no_grad():
        probs = predict(batch_instances, model, tokenizer, CUDA_DEVICE).cpu().numpy().tolist()
        for batch_idx, instance in enumerate(batch_instances):
            # the instance id is also the position in the list of labels
            idx = instance.id
            true_label_idx = labels[idx]
            true_label_indices.append(true_label_idx)
            probabilities.append(probs[batch_idx][true_label_idx])
    
    return probabilities
    
    
def batcher_gradient(batch_instances):
    input_ids = [encode_instance(instance) for instance in batch_instances]
    attention_mask = [torch.ones_like(t) for t in input_ids]
    
    input_ids = collate_tokens(input_ids, pad_idx=1).to(CUDA_DEVICE)
    attention_mask = collate_tokens(attention_mask, pad_idx=0).to(CUDA_DEVICE)
    
    inputs_embeds = model.roberta.embeddings(input_ids=input_ids).detach()

    true_label_idx_list = [labels[instance.id] for instance in batch_instances]
    true_label_idx_tensor = torch.tensor(true_label_idx_list, dtype=torch.long, device=CUDA_DEVICE)
    
    # output_getter extracts the first entry of the return tuple and also applies a softmax to the
    # log probabilities
    explainer = IntegrateGradExplainer(model=model,
                                       input_key="inputs_embeds",
                                       output_getter=lambda x: F.softmax(x[0], dim=-1))
    inputs_embeds.requires_grad = True
    expl = explainer.explain(inp={"inputs_embeds": inputs_embeds, "attention_mask": attention_mask},
                             ind=true_label_idx_tensor)
    
    input_ids_np = input_ids.cpu().numpy()
    expl_np = expl.cpu().numpy()

    relevances = []
    for b_idx in range(input_ids_np.shape[0]):
        offsets = byte_pair_offsets(input_ids_np[b_idx].tolist(), tokenizer)
        
        relevance_dict = defaultdict(float)
        for token_idx, (token_start, token_end) in enumerate(zip(offsets, offsets[1:])):
            relevance = expl_np[b_idx][token_start: token_end].sum()
            relevance_dict[("sent", token_idx)] = relevance
        relevances.append(relevance_dict)

    return relevances
    

config_unk = Config.from_dict({
    "strategy": "unk_replacement",
    "batch_size": 128,
    "unk_token": "<unk>"
})

config_gradient = Config.from_dict({
    "strategy": "gradient",
    "batch_size": 128
})

config_resample = Config.from_dict({
    "strategy": "bert_lm_sampling",
    "cuda_device": 0,
    "bert_model": "bert-base-uncased",
    "batch_size": 256,
    "n_samples": 100,
    "verbose": False
})

unknown_engine = Engine(config_unk, batcher)
resample_engine = Engine(config_resample, batcher)
gradient_engine = Engine(config_gradient, batcher_gradient)

In [39]:
instance_idx = 0
n = 20

unk_candidate_instances, unk_candidate_results = unknown_engine.run(input_instances[instance_idx: instance_idx+n])
res_candidate_instances, res_candidate_results = resample_engine.run(input_instances[instance_idx: instance_idx+n])
grad_candidate_instances, grad_candidate_results = gradient_engine.run(input_instances[instance_idx: instance_idx+n])

100%|██████████| 20/20 [00:00<00:00, 26108.33it/s]
2it [00:00,  6.45it/s]                       
100%|██████████| 20/20 [00:01<00:00, 10.14it/s]
14it [00:03,  4.17it/s]                        
100%|██████████| 20/20 [00:00<00:00, 204600.20it/s]
1it [00:07,  7.75s/it]


In [40]:
unk_relevances = unknown_engine.relevances(unk_candidate_instances, unk_candidate_results)
res_relevances = resample_engine.relevances(res_candidate_instances, res_candidate_results)
grad_relevances = gradient_engine.relevances(grad_candidate_instances, grad_candidate_results)

In [41]:
labels_true = labels[instance_idx: instance_idx+n]
labels_pred = [predict(instance, model, tokenizer, CUDA_DEVICE)[0].argmax().item() for instance in input_instances[instance_idx: instance_idx+n]]

In [42]:
HTML(visualize_relevances(input_instances[instance_idx: instance_idx+n], unk_relevances, labels_true, labels_pred))

In [43]:
HTML(visualize_relevances(input_instances[instance_idx: instance_idx+n], res_relevances, labels_true, labels_pred))

In [44]:
HTML(visualize_relevances(input_instances[instance_idx: instance_idx+n], grad_relevances, labels_true, labels_pred))

In [45]:
print(calculate_correlation(unk_relevances, res_relevances))
print(calculate_correlation(unk_relevances, grad_relevances))
print(calculate_correlation(res_relevances, grad_relevances))

0.1441721835195447
0.06491314281530187
0.10737025456902813
