In [27]:
import torch
from torch import nn
from torch.utils.data import DataLoader
from transformers.modeling_utils import (WEIGHTS_NAME, PretrainedConfig, PreTrainedModel,
                             SequenceSummary, PoolerAnswerClass, PoolerEndLogits, PoolerStartLogits)
from transformers import XLNetTokenizer, XLNetForSequenceClassification, XLNetPreTrainedModel, XLNetModel
from torch.nn import CrossEntropyLoss, BCEWithLogitsLoss
from transformers import get_linear_schedule_with_warmup
from sklearn.metrics import f1_score, precision_score, recall_score, accuracy_score


import pandas as pd
import numpy as np
import random
from IPython.display import clear_output
import re
from utils import *
from tqdm.notebook import tqdm
# Notebook-global compute device: CUDA when torch reports it as available, CPU otherwise.
# Read by XLNetForMultiSequenceClassification when it places its class-weight tensors.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [38]:
# Forward pass that keeps only the last element of the model output in `first`.
# NOTE(review): `model`, `input_ids` and `token_type_ids` are not defined in any cell
# visible here, so this line relies on pre-existing kernel state. It also runs without
# `model.eval()` / `torch.no_grad()`, unlike the following cell.
first = model(input_ids, token_type_ids=token_type_ids)[-1]

In [35]:
# Inference-mode forward pass; the last element of the output is stored as `attention`.
# NOTE(review): presumably the per-layer attention maps, which requires the model to be
# configured to return attentions — confirm. `model`, `input_ids` and `token_type_ids`
# come from kernel state that is not defined in the cells shown here.
model.eval()
with torch.no_grad():
    attention = model(input_ids, token_type_ids=token_type_ids)[-1]

In [39]:
# Rank (sentence-A token, sentence-B token) pairs by their summed attention score.
# NOTE(review): `attention`, `tokens` and `token_type_ids` must already exist in the
# kernel; `tokens` is rebound below, so the raw tokenizer output is lost after this cell.

# Zero the <sep>/<cls> rows and columns and stack the layers into one tensor.
attn = format_attention(attention, tokens)  
# Strip tokenizer word-boundary markers so tokens read as plain text.
tokens = format_special_chars(tokens)
# Position of the first token whose token-type id is 1 (taken as the start of sentence B).
sentence_b_start = token_type_ids[0].tolist().index(1)
slice_a = slice(0, sentence_b_start)
slice_b = slice(sentence_b_start, len(tokens))
# Keep, for every layer and head, only the rows in slice_a and the columns in slice_b.
attn_data = attn[:, :, slice_a, slice_b]
sentence_a_tokens = tokens[slice_a]
sentence_b_tokens = tokens[slice_b]
# (token_a, token_b, score) triples; pairs whose score is exactly 0 are dropped by pair_match.
pair = pair_match(sentence_a_tokens, sentence_b_tokens, attn_data=attn_data)
# Highest score first.
pair = sorted(pair, key=lambda pair: pair[2], reverse=True)
# Drop the scores (and pairs containing an empty token), keeping the ranked order.
pair = pair_without_score(pair)
pair

[('season', 'season'),
 ('opening', 'opening'),
 ('moto', 'moto'),
 ('p', 'p'),
 ('g', 'g'),
 ('ino', 'ino'),
 ('tar', 'tar'),
 ('ossi', 'ossi'),
 ('r', 'r'),
 ('qa', 'qa'),
 ('valent', 'valent'),
 ('tar', 'tar'),
 ('qa', 'qa'),
 ('season', 'opening'),
 ('win', 'won'),
 ('p', 'p'),
 ('the', 'the'),
 ('g', 'moto'),
 ('p', '.'),
 ('.', '.'),
 ('opening', 'season'),
 ('defending', 'won'),
 ('a', 'won'),
 ('for', 'won'),
 ('p', '.'),
 ('af', '.'),
 ('defending', 'the'),
 ('bidding', 'won'),
 ('g', 'p'),
 ('.', 'won'),
 ('p', 'moto'),
 ('tar', 'qa'),
 ('champion', 'won'),
 ('defending', '.'),
 ('qa', 'tar'),
 ('p', 'g'),
 ('qa', 'opening'),
 ('world', '.'),
 ('up', 'opening'),
 ('torrential', '.'),
 ('tar', 'moto'),
 ("'", '.'),
 ('season', 'the'),
 ('case', '.'),
 ('.', '.'),
 ('champion', 'won'),
 ('world', 'the'),
 ('position', '.'),
 ('sun', '.'),
 ('champion', '.'),
 ('the', 'won'),
 ('riders', '.'),
 ('the', '.'),
 ('opening', 'the'),
 ('and', 'ossi'),
 ('it', '.'),
 ('world', 'won'),

In [5]:
class XLNetForMultiSequenceClassification(XLNetPreTrainedModel):
    """XLNet encoder with two classification heads on top of one sequence summary.

    * ``logits_proj_3way``  -- ``num_labels_3way`` (3) outputs, trained with cross-entropy.
    * ``logits_proj_multi`` -- ``num_labels_multi`` (5) outputs, trained with
      ``BCEWithLogitsLoss`` using ``class_weights_multi`` as ``pos_weight``.

    Which head is used is decided inside ``forward`` from ``labels`` (see there).
    """

    def __init__(self, config):
        """Build the encoder, the summary layer and both projection heads from ``config``."""
        super().__init__(config)
        self.num_labels = 3
        self.num_labels_3way = 3
        self.num_labels_multi = 5

        self.transformer = XLNetModel(config)
        self.sequence_summary = SequenceSummary(config)
        self.logits_proj_3way = nn.Linear(config.d_model, self.num_labels_3way)
        self.logits_proj_multi = nn.Linear(config.d_model, self.num_labels_multi)

        self.weights_3way = [1, 1.5, 3]
        self.weights_multi = [4, 2, 4, 2, 2]
        # These are plain tensor attributes (not registered buffers), so `model.to(...)`
        # does not move them; `forward` re-homes the one it uses onto the logits' device.
        # NOTE(review): `class_weights_3way` is never read by `forward` (the 3-way
        # CrossEntropyLoss is unweighted) — confirm whether that is intentional.
        self.class_weights_3way = torch.FloatTensor(self.weights_3way).to(device)
        self.class_weights_multi = torch.FloatTensor(self.weights_multi).to(device)

        self.init_weights()

    def forward(self, input_ids, attention_mask=None, mems=None, perm_mask=None, target_mapping=None,
                token_type_ids=None, input_mask=None, head_mask=None, labels=None, inputs_embeds=None):
        """Run the encoder and one of the two heads.

        Args:
            input_ids ... inputs_embeds: forwarded unchanged to ``XLNetModel``.
            labels: optional targets that also select the head:
                * ``None``                   -> 3-way logits, no loss;
                * size exactly ``[1]``        -> 3-way head + ``CrossEntropyLoss``;
                * any other size              -> multi head + weighted ``BCEWithLogitsLoss``
                  (``labels`` is passed to the loss as-is).

        Returns:
            ``(logits,) + transformer_outputs[1:]`` without labels, otherwise
            ``(loss, logits) + transformer_outputs[1:]``.
        """
        transformer_outputs = self.transformer(input_ids,
                                               attention_mask=attention_mask,
                                               mems=mems,
                                               perm_mask=perm_mask,
                                               target_mapping=target_mapping,
                                               token_type_ids=token_type_ids,
                                               input_mask=input_mask,
                                               head_mask=head_mask,
                                               inputs_embeds=inputs_embeds)

        summary = self.sequence_summary(transformer_outputs[0])

        # Inference: no labels means the 3-way head and no loss.
        if labels is None:
            logits = self.logits_proj_3way(summary)
            return (logits,) + transformer_outputs[1:]

        # NOTE(review): the task is inferred purely from the label shape, so a 3-way
        # batch with more than one example would be routed to the multi head.
        if labels.size() == torch.Size([1]):
            logits = self.logits_proj_3way(summary)
            loss_fct = CrossEntropyLoss()
            loss = loss_fct(logits.view(-1, self.num_labels_3way), labels.view(-1))
        else:
            logits = self.logits_proj_multi(summary)
            # pos_weight must live on the same device as the logits; the attribute was
            # created on the notebook-global `device`, which may differ from the model's.
            pos_weight = self.class_weights_multi.to(logits.device)
            loss_fct = BCEWithLogitsLoss(pos_weight=pos_weight)
            loss = loss_fct(logits.view(-1, self.num_labels_multi), labels)

        return (loss, logits) + transformer_outputs[1:]

In [192]:
# Load the object stored in 'test.pkl', remapping its tensors onto the CPU.
# NOTE(review): torch.load unpickles the file, which can execute arbitrary code — only
# load checkpoints from a trusted source. Presumably a whole pickled model, in which case
# its class definition must be available in the kernel first — confirm.
model2 = torch.load('test.pkl',map_location=torch.device('cpu'))

In [200]:
# Load the object stored in 'acc_0.5_complete.pkl', remapping its tensors onto the CPU.
# NOTE(review): torch.load unpickles the file, which can execute arbitrary code — only
# load checkpoints from a trusted source.
model_single = torch.load('acc_0.5_complete.pkl',map_location=torch.device('cpu'))

In [31]:
import torch
import random
from torch.utils.data import Dataset
from torch.utils.data.dataset import ConcatDataset
from torch.utils.data import DataLoader
from torch.nn.utils.rnn import pad_sequence

def format_special_chars(tokens):
    """Return a new list with tokenizer boundary markers removed from every token.

    The substitutions are applied in a fixed order: 'Ġ' and '▁' become a space,
    '</w>' is deleted, and finally every space is deleted.
    """
    substitutions = (('Ġ', ' '), ('▁', ' '), ('</w>', ''), (' ', ''))
    cleaned = []
    for token in tokens:
        for marker, replacement_text in substitutions:
            token = token.replace(marker, replacement_text)
        cleaned.append(token)
    return cleaned

def format_attention(attention, tokens):
    """Zero the attention of the special tokens and stack all layers into one tensor.

    Args:
        attention: iterable of per-layer tensors, each of rank 4
            (expected ``1 x num_heads x seq_len x seq_len``).
        tokens: token strings aligned with the ``seq_len`` axes; positions holding
            ``"<sep>"`` or ``"<cls>"`` have their row and column set to zero
            (in batch entry 0) for every layer and head.

    Returns:
        ``torch.stack`` of the per-layer tensors after ``squeeze(0)``, i.e.
        ``num_layers x num_heads x seq_len x seq_len`` for batch size 1.

    Raises:
        ValueError: if any layer tensor is not rank 4. The check runs before any
            indexing, so a wrong-rank input fails with this error instead of an
            ``IndexError`` from the masking step.

    The input tensors are left untouched; masking is done on a copy.
    """
    special_positions = [pos for pos, tok in enumerate(tokens) if tok in ("<sep>", "<cls>")]

    formatted = []
    for layer_attention in attention:
        # 1 x num_heads x seq_len x seq_len
        if len(layer_attention.shape) != 4:
            raise ValueError("Wrong attention length, attention length must be 4")
        # Work on a copy so the caller's model outputs are not modified.
        layer = layer_attention.clone()
        for pos in special_positions:
            layer[0, :, pos, :] = 0
            layer[0, :, :, pos] = 0
        formatted.append(layer.squeeze(0))
    # num_layers x num_heads x seq_len x seq_len
    return torch.stack(formatted)

def look_score(attn_data, index_a, index_b):
    """Sum the (index_a, index_b) attention entry over every layer and head.

    Returns the total rounded to three decimal places.
    """
    total = 0.
    for layer_heads in attn_data:
        for head_matrix in layer_heads:
            # `.tolist()` turns the selected element into a plain Python number.
            total += head_matrix[index_a][index_b].tolist()
    return round(total, 3)

def pair_match(sentence_a_tokens, sentence_b_tokens, attn_data=None):
    """Build every (token_a, token_b) combination of the two token lists.

    Without ``attn_data`` all combinations are returned as 2-tuples (used for the
    evaluation pairs). With ``attn_data`` each combination becomes a
    ``(token_a, token_b, score)`` triple, where ``score`` comes from ``look_score``;
    combinations whose score equals 0 (the zeroed special tokens) are left out.
    """
    if attn_data is None:
        # for evaluation pairs
        return [(tok_a, tok_b) for tok_a in sentence_a_tokens for tok_b in sentence_b_tokens]

    scored = []
    for row, tok_a in enumerate(sentence_a_tokens):
        for col, tok_b in enumerate(sentence_b_tokens):
            weight = look_score(attn_data, row, col)
            # filter the special token
            if weight != 0:
                scored.append((tok_a, tok_b, weight))
    return scored

def pair_without_score(pair):
    """Strip the score from ``(token_a, token_b, score)`` triples.

    Triples in which either token is the empty string are skipped; order is kept.
    """
    return [
        (token_a, token_b)
        for token_a, token_b, _score in pair
        if not (token_a == '' or token_b == '')
    ]

def MRR_calculate(pair_truth, pair_all):
    """Mean reciprocal rank of the ground-truth pairs within a ranked pair list.

    Args:
        pair_truth: the query pairs to look up.
        pair_all: ranked list of candidate pairs, best first.

    Returns:
        The average over ``pair_truth`` of ``1 / rank``, where ``rank`` is the 1-based
        position of the query's FIRST occurrence in ``pair_all``; a query that never
        occurs contributes 0. Returns ``0.0`` when ``pair_truth`` is empty.
    """
    if len(pair_truth) == 0:
        # Nothing to evaluate; avoids dividing by zero below.
        return 0.

    reciprocal_sum = 0.
    for query in pair_truth:
        for rank, candidate in enumerate(pair_all, start=1):
            if candidate == query:
                reciprocal_sum += 1 / rank
                # Only the best (first) hit counts. The ranked list can contain the same
                # pair several times, and adding every hit would push the score above 1.
                break
    return reciprocal_sum / len(pair_truth)

def MRR_mean(pair_truth, pair_all, top_k, times):
    """Average MRR over ``times`` random samples of the ground-truth pairs.

    Args:
        pair_truth: ground-truth pairs to sample from.
        pair_all: ranked list of candidate pairs passed to ``MRR_calculate``.
        top_k: number of pairs drawn per sample (``random.choices``, i.e. with replacement).
        times: number of samples to average over; must be positive.

    Returns:
        The mean of the ``times`` MRR scores, or ``0.0`` when ``pair_truth`` is empty.
    """
    if len(pair_truth) == 0:
        # random.choices cannot draw from an empty population.
        return 0.

    total = 0.
    for _ in range(times):
        # Draw a fresh sample on every repetition. Sampling once outside the loop would
        # score the same sample `times` times and make the averaging a no-op.
        sampled = random.choices(pair_truth, k=top_k)
        total += MRR_calculate(sampled, pair_all)
    return total / times

def explainability_compare(model, tokenizer, sentence_a, sentence_b, test_sentence_a):
    """Score how well the model's attention ranking recovers the annotated span pairs.

    Steps:
        1. Encode ``(sentence_a, sentence_b)`` with special tokens, run ``model`` in eval
           mode and take the last element of its output as the attention maps.
        2. Rank every (sentence-A token, sentence-B token) pair by summed attention,
           highest first.
        3. Encode ``(test_sentence_a, sentence_b)`` without special tokens and build all
           token pairs from it as the ground truth.
        4. Compute the MRR of the ground-truth pairs within the ranking.

    Args:
        model: model whose output's last element holds the attention maps
            (assumes it is configured to return attentions — TODO confirm).
        tokenizer: tokenizer providing ``encode_plus`` and ``convert_ids_to_tokens``.
        sentence_a, sentence_b: the sentence pair fed to the model.
        test_sentence_a: the annotated span paired with ``sentence_b`` for evaluation.

    Returns:
        ``(mrr, n_truth_pairs)``.
    """
    # Put the inputs where the model actually lives. Choosing CUDA merely because it is
    # available breaks for a model that was loaded onto the CPU.
    try:
        device = next(model.parameters()).device
    except (StopIteration, AttributeError):
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    inputs = tokenizer.encode_plus(sentence_a, sentence_b, return_tensors='pt', add_special_tokens=True)
    input_ids = inputs['input_ids'].to(device)
    token_type_ids = inputs['token_type_ids'].to(device)
    # Index the single batch row instead of squeeze(): squeeze() would collapse a
    # one-token sequence into a scalar.
    tokens = tokenizer.convert_ids_to_tokens(input_ids[0].tolist())

    model.eval()
    with torch.no_grad():
        attention = model(input_ids, token_type_ids=token_type_ids)[-1]

    attn = format_attention(attention, tokens)
    tokens = format_special_chars(tokens)
    # First position with token-type id 1 is taken as the start of sentence B.
    sentence_b_start = token_type_ids[0].tolist().index(1)
    slice_a = slice(0, sentence_b_start)
    slice_b = slice(sentence_b_start, len(tokens))
    attn_data = attn[:, :, slice_a, slice_b]
    sentence_a_tokens = tokens[slice_a]
    sentence_b_tokens = tokens[slice_b]
    scored_pairs = pair_match(sentence_a_tokens, sentence_b_tokens, attn_data=attn_data)
    scored_pairs = sorted(scored_pairs, key=lambda triple: triple[2], reverse=True)
    ranked_pairs = pair_without_score(scored_pairs)

    # Ground truth: only tokenised, never fed to the model, so it stays on the CPU.
    test_inputs = tokenizer.encode_plus(test_sentence_a, sentence_b, return_tensors='pt', add_special_tokens=False)
    test_input_ids = test_inputs['input_ids']
    test_token_type_ids = test_inputs['token_type_ids']
    test_tokens = tokenizer.convert_ids_to_tokens(test_input_ids[0].tolist())
    test_tokens = format_special_chars(test_tokens)
    test_sentence_b_start = test_token_type_ids[0].tolist().index(1)
    test_slice_a = slice(0, test_sentence_b_start)
    test_slice_b = slice(test_sentence_b_start, len(test_tokens))
    test_sentence_a_tokens = test_tokens[test_slice_a]
    test_sentence_b_tokens = test_tokens[test_slice_b]
    test_pair = pair_match(test_sentence_a_tokens, test_sentence_b_tokens, attn_data=None)

    return MRR_calculate(test_pair, ranked_pairs), len(test_pair)

In [120]:
import xml.etree.ElementTree as ET

In [210]:
# Convert the annotated RTE-5 XML file into a TSV with normalised text columns.
# Each <pair> element provides: attribute `entailment`, and child elements
# <t> (text), <h> (hypothesis) and <a> (evaluation / attention-span text).
root = ET.parse('RTE5_test_AttnSpan.xml').getroot()
text = []
hypothesis = []
entailment = []
# NOTE: this rebinds the notebook-global `attention`, which an earlier cell uses for the
# model's attention output; re-run that cell before reusing the attention tensors.
attention = []

label_mapping = {'ENTAILMENT': 0, 'UNKNOWN': 1, 'CONTRADICTION': 2}

replacement = {"hasn't": 'has not', 
               "couldn't": 'could not', 
               "wasn't": 'was not', 
               "weren't": 'were not', 
               "doesn't": 'does not',
               "don't": 'do not',
               '"': '',
              }


def _normalize_text(raw):
    """Lower-case `raw`, apply the `replacement` table, then strip brackets, '!' and '-'."""
    cleaned = raw.lower()
    for word, rep in replacement.items():
        cleaned = cleaned.replace(word.lower(), rep)
    return re.sub(r"([\(\)\[\]\{\}!-])", "", cleaned)


for type_tag in root.findall('pair'):

    e = type_tag.get('entailment')
    t = _normalize_text(type_tag.find('t').text)
    h = _normalize_text(type_tag.find('h').text)
    # Bug fix: the contraction replacement for `a` used to be applied to `h` by mistake,
    # so the evaluation text kept its contractions and quotes while text/hypothesis did not.
    a = _normalize_text(type_tag.find('a').text)

    text.append(t)
    hypothesis.append(h)
    attention.append(a)
    entailment.append(label_mapping[e])

df_test = pd.DataFrame((zip(text, hypothesis, attention, entailment)), columns=['text_a', 'text_b', 'eval_text','label'])
df_test.to_csv("test_2.tsv", sep="\t", index=False)