In [1]:
import torch
from torch import nn
from torch.utils.data import DataLoader
from transformers.modeling_utils import (WEIGHTS_NAME, PretrainedConfig, PreTrainedModel,
                             SequenceSummary, PoolerAnswerClass, PoolerEndLogits, PoolerStartLogits)
from transformers import XLNetTokenizer, XLNetForSequenceClassification, XLNetPreTrainedModel, XLNetModel
from torch.nn import CrossEntropyLoss, BCEWithLogitsLoss
from transformers import get_linear_schedule_with_warmup
from sklearn.metrics import f1_score, precision_score, recall_score, accuracy_score
from XLNet import XLNetForMultiSequenceClassification


import pandas as pd
import numpy as np
import random
from IPython.display import clear_output
import re
from utils import *
from tqdm.notebook import tqdm
device = torch.device("cpu")

In [14]:
class Dataset_MRR(Dataset):
    """Map-style dataset over one tab-separated file of sentence pairs.

    Every row is expected to hold exactly seven columns, read positionally as
    ``text_a, text_b, text_eval, label, t3, t5, t7``. Missing cells are
    replaced by empty strings when the file is loaded.

    NOTE(review): ``Dataset`` is not imported explicitly in this notebook; it
    presumably arrives through ``from utils import *`` - confirm.

    Args:
        mode: Path of the data file without its ``.tsv`` extension. Only
            ``"data/RTE5_test"`` is accepted.
        tokenizer: Object exposing ``encode_plus``; used to encode each pair.
    """

    def __init__(self, mode, tokenizer):
        assert mode in ["data/RTE5_test"]
        self.mode = mode
        self.df = pd.read_csv(mode + ".tsv", sep="\t").fillna("")
        self.len = len(self.df)
        self.tokenizer = tokenizer

    def __getitem__(self, idx):
        """Encode row ``idx`` and return it as a 10-tuple.

        Returns:
            ``(input_ids, token_type_ids, attention_mask, label, text_a,
            text_b, text_eval, t3, t5, t7)``. The three encoded tensors come
            straight from ``encode_plus(..., return_tensors='pt')``; the text
            fields are passed through untouched. ``create_mini_batch`` reads
            all ten positions.
        """
        text_a, text_b, text_eval, label, t3, t5, t7 = self.df.iloc[idx, :].values
        label_tensor = torch.tensor(label)

        # Use the tokenizer this dataset was built with. The notebook rebinds
        # the global `tokenizer` in a later cell, so relying on the global
        # would make the encoding depend on cell execution order.
        inputs = self.tokenizer.encode_plus(text_a, text_b, return_tensors='pt', add_special_tokens=True)
        tokens_tensor = inputs['input_ids']
        segments_tensor = inputs['token_type_ids']
        masks_tensor = inputs['attention_mask']

        # t3 / t5 / t7 must be part of the sample: the collate function
        # indexes positions 7-9 and calculate() selects among them via top_k.
        return (tokens_tensor, segments_tensor, masks_tensor, label_tensor,
                text_a, text_b, text_eval, t3, t5, t7)

    def __len__(self):
        """Number of rows in the loaded file."""
        return self.len
    
# Notebook-level tokenizer for 'xlnet-base-cased'; it is passed to Dataset_MRR
# and calculate() in the cells below.
tokenizer = XLNetTokenizer.from_pretrained('xlnet-base-cased')

In [15]:
# Reads data/RTE5_test.tsv (Dataset_MRR appends the ".tsv" extension itself).
dataset = Dataset_MRR("data/RTE5_test", tokenizer=tokenizer)

In [16]:
from torch.utils.data import DataLoader
from torch.nn.utils.rnn import pad_sequence

def create_mini_batch(samples):
    """Collate a list of dataset samples into one padded mini-batch.

    Args:
        samples: Sequence of 10-tuples laid out as
            ``(input_ids, token_type_ids, attention_mask, label, text_a,
            text_b, text_eval, t3, t5, t7)``.

    Returns:
        A 10-tuple ``(input_ids, token_type_ids, attention_mask, label_ids,
        text_a, text_b, text_eval, t3, t5, t7)``. The three encoded fields are
        ``(batch, max_seq_len)`` tensors, right-padded with 0 (the
        ``pad_sequence`` default). ``label_ids`` is the stacked labels, or
        ``None`` when the first sample carries no label. The remaining fields
        are plain Python lists with one entry per sample.
    """
    def pad_field(index):
        # Samples carry (1, seq_len) tensors. pad_sequence pads along the
        # first axis, so each tensor is flattened to (seq_len,) first;
        # otherwise sequences of different lengths cannot be batched.
        return pad_sequence([s[index].reshape(-1) for s in samples],
                            batch_first=True)

    tokens_tensors = pad_field(0)
    segments_tensors = pad_field(1)
    masks_tensors = pad_field(2)

    if samples[0][3] is not None:
        label_ids = torch.stack([s[3] for s in samples])
    else:
        label_ids = None

    text_a = [s[4] for s in samples]
    text_b = [s[5] for s in samples]
    text_eval = [s[6] for s in samples]
    t3 = [s[7] for s in samples]
    t5 = [s[8] for s in samples]
    t7 = [s[9] for s in samples]

    return (tokens_tensors, segments_tensors, masks_tensors, label_ids,
            text_a, text_b, text_eval, t3, t5, t7)


# Build the DataLoader that yields the evaluation samples.
# `collate_fn` merges a list of samples into a single mini-batch.
# calculate() reads element [0] of every text field and compares one label per
# step, so it expects exactly one sample per batch.
BATCH_SIZE = 1
testloader = DataLoader(dataset, batch_size=BATCH_SIZE, collate_fn=create_mini_batch)

In [5]:
def calculate(model, dataloader, tokenizer, unique=False, in_un=False, top_k=None):
    """Evaluate ``model`` on ``dataloader`` and summarise accuracy, length and MRR.

    For every batch the model's predicted class is compared with the label,
    and ``explainability_compare`` is called to obtain an ``(MRR, length)``
    pair. Both are accumulated per class (label 0 -> ``entail``, label 1 ->
    ``neutral``, anything else -> ``contradict``) and split by whether the
    prediction was correct.

    The loop assumes one sample per batch: it reads element ``[0]`` of each
    text field and converts label / prediction with ``.item()``.

    Args:
        model: Classifier called with ``input_ids``, ``token_type_ids`` and
            ``attention_mask``; element 0 of its output is used as the logits.
        dataloader: Sized iterable of batches in the layout produced by
            ``create_mini_batch``.
        tokenizer: Forwarded unchanged to ``explainability_compare``.
        unique: Forwarded unchanged to ``explainability_compare``.
        in_un: Forwarded unchanged to ``explainability_compare``.
        top_k: 3, 5 or 7 selects batch field 7, 8 or 9 as the evaluation
            sentence; any other value selects field 5. The value is also
            forwarded to ``explainability_compare``.

    Returns:
        dict with ``total``, ``total_MRR``, ``total_acc``, ``total_mean_len``
        followed, for each of ``entail`` / ``neutral`` / ``contradict``, by
        ``<cls>_total``, ``<cls>_acc``, ``<cls>_mean_len``, ``<cls>_MRR``,
        ``<cls>_correct``, ``<cls>_correct_mean_len``, ``<cls>_MRR_c``,
        ``<cls>_incorrect``, ``<cls>_incorrect_mean_len`` and
        ``<cls>_MRR_inc``. A ratio whose denominator is zero is reported as
        ``0.0``; counts are never adjusted to avoid a division by zero.
    """
    class_names = ('entail', 'neutral', 'contradict')
    stats = {
        name: {'total': 0, 'total_len': 0,
               'correct': 0, 'correct_len': 0,
               'MRR_c': 0., 'MRR_inc': 0.}
        for name in class_names
    }

    def ratio(numerator, denominator, ndigits):
        """Rounded quotient; 0.0 when there is nothing to average over."""
        if denominator == 0:
            return 0.0
        return round(numerator / denominator, ndigits)

    # Batch field that holds the evaluation sentence for the requested top_k.
    # NOTE(review): the fallback reads field 5 (text_b), not field 6
    # (text_eval). Kept exactly as in the original - confirm which is intended.
    eval_field = {3: 7, 5: 8, 7: 9}.get(top_k, 5)

    total = len(dataloader)

    model.eval()
    on_cuda = next(model.parameters()).is_cuda
    with torch.no_grad():
        data_iterator = tqdm(dataloader, desc='Iteration')
        for data in data_iterator:
            if on_cuda:
                # Only tensors can be moved; the text fields are lists of
                # strings and must keep their positions in the batch.
                data = [t.to("cuda:0") if torch.is_tensor(t) else t for t in data]

            # predict
            tokens_tensors, segments_tensors, masks_tensors = data[:3]
            sentence_a = data[4][0]
            sentence_b = data[5][0]
            eval_sentence = data[eval_field][0]
            outputs = model(input_ids=tokens_tensors,
                            token_type_ids=segments_tensors,
                            attention_mask=masks_tensors)
            logits = outputs[0]
            _, pred = torch.max(logits.data, 1)

            # Plain Python ints keep the comparisons device-independent.
            label_id = int(data[3].item())
            is_correct = int(pred.item()) == label_id

            MRR, length = explainability_compare(model,
                                                 tokenizer,
                                                 sentence_a,
                                                 sentence_b,
                                                 eval_sentence,
                                                 unique=unique,
                                                 in_un=in_un,
                                                 top_k=top_k,
                                                )

            # divide 3 class
            class_name = class_names[label_id] if label_id in (0, 1) else 'contradict'
            bucket = stats[class_name]
            bucket['total'] += 1
            bucket['total_len'] += length
            if is_correct:
                bucket['correct'] += 1
                bucket['correct_len'] += length
                bucket['MRR_c'] += MRR
            else:
                bucket['MRR_inc'] += MRR

    mrr_sum = 0.
    correct_sum = 0
    len_sum = 0
    for name in class_names:
        mrr_sum += stats[name]['MRR_c']
        mrr_sum += stats[name]['MRR_inc']
        correct_sum += stats[name]['correct']
        len_sum += stats[name]['total_len']

    result = {
        'total': total,
        'total_MRR': ratio(mrr_sum, total, 4),
        'total_acc': ratio(correct_sum, total, 2),
        'total_mean_len': ratio(len_sum, total, 1),
    }
    for name in class_names:
        s = stats[name]
        incorrect = s['total'] - s['correct']
        result[name + '_total'] = s['total']
        result[name + '_acc'] = ratio(s['correct'], s['total'], 2)
        result[name + '_mean_len'] = ratio(s['total_len'], s['total'], 1)
        result[name + '_MRR'] = ratio(s['MRR_c'] + s['MRR_inc'], s['total'], 4)
        result[name + '_correct'] = s['correct']
        result[name + '_correct_mean_len'] = ratio(s['correct_len'], s['correct'], 1)
        result[name + '_MRR_c'] = ratio(s['MRR_c'], s['correct'], 4)
        result[name + '_incorrect'] = incorrect
        result[name + '_incorrect_mean_len'] = ratio(s['total_len'] - s['correct_len'], incorrect, 2)
        result[name + '_MRR_inc'] = ratio(s['MRR_inc'], incorrect, 4)
    return result
        

# Multi

In [6]:
%%time
# torch.load unpickles the file, so only checkpoints from a trusted source
# should be loaded this way. The object is mapped onto the CPU.
model = torch.load('contra_63_24.pkl', map_location=torch.device('cpu'))
# unique=True; in_un and top_k keep their defaults.
model_multi_result = calculate(model, testloader, tokenizer, unique=True)

HBox(children=(IntProgress(value=0, description='Iteration', max=600, style=ProgressStyle(description_width='i…




CPU times: user 45min 58s, sys: 32.1 s, total: 46min 30s
Wall time: 25min 54s


{'total': 600,
 'total_MRR': 0.1842,
 'total_acc': 0.56,
 'total_mean_len': 15.0,
 'entail_total': 300,
 'entail_acc': 0.62,
 'entail_mean_len': 15.5,
 'entail_MRR': 0.2048,
 'entail_correct': 185,
 'entail_correct_mean_len': 14.9,
 'entail_MRR_c': 0.2122,
 'entail_incorrect': 115,
 'entail_incorrect_mean_len': 16.65,
 'entail_MRR_inc': 0.1928,
 'neutral_total': 210,
 'neutral_acc': 0.62,
 'neutral_mean_len': 13.7,
 'neutral_MRR': 0.1621,
 'neutral_correct': 130,
 'neutral_correct_mean_len': 13.7,
 'neutral_MRR_c': 0.1619,
 'neutral_incorrect': 80,
 'neutral_incorrect_mean_len': 13.62,
 'neutral_MRR_inc': 0.1624,
 'contradict_total': 90,
 'contradict_acc': 0.21,
 'contradict_mean_len': 16.4,
 'contradict_MRR': 0.1671,
 'contradict_correct': 19,
 'contradict_correct_mean_len': 13.7,
 'contradict_MRR_c': 0.1834,
 'contradict_incorrect': 71,
 'contradict_incorrect_mean_len': 17.13,
 'contradict_MRR_inc': 0.1627}

In [64]:
%%time
# Same checkpoint file as the previous evaluation cell, reloaded onto the CPU
# (torch.load unpickles the file: trusted checkpoints only).
model = torch.load('contra_63_24.pkl', map_location=torch.device('cpu'))
# unique=True and in_un=True; top_k keeps its default.
model_multi_result_in_un = calculate(model, testloader, tokenizer, unique=True, in_un=True)
model_multi_result_in_un

HBox(children=(IntProgress(value=0, description='Iteration', max=600, style=ProgressStyle(description_width='i…


CPU times: user 47min 13s, sys: 34.1 s, total: 47min 47s
Wall time: 27min 2s


{'total': 600,
 'total_MRR': 0.4666,
 'total_acc': 0.56,
 'total_mean_len': 15.0,
 'entail_total': 300,
 'entail_acc': 0.62,
 'entail_mean_len': 15.5,
 'entail_MRR': 0.5621,
 'entail_correct': 185,
 'entail_correct_mean_len': 14.9,
 'entail_MRR_c': 0.5755,
 'entail_incorrect': 115,
 'entail_incorrect_mean_len': 16.65,
 'entail_MRR_inc': 0.5405,
 'neutral_total': 210,
 'neutral_acc': 0.62,
 'neutral_mean_len': 13.7,
 'neutral_MRR': 0.3305,
 'neutral_correct': 130,
 'neutral_correct_mean_len': 13.7,
 'neutral_MRR_c': 0.3187,
 'neutral_incorrect': 80,
 'neutral_incorrect_mean_len': 13.62,
 'neutral_MRR_inc': 0.3497,
 'contradict_total': 90,
 'contradict_acc': 0.21,
 'contradict_mean_len': 16.4,
 'contradict_MRR': 0.4655,
 'contradict_correct': 19,
 'contradict_correct_mean_len': 13.7,
 'contradict_MRR_c': 0.4719,
 'contradict_incorrect': 71,
 'contradict_incorrect_mean_len': 17.13,
 'contradict_MRR_inc': 0.4638}

In [70]:
%%time
# First evaluation of this cell: checkpoint 'single_056.pkl' with unique=True
# (torch.load unpickles the file: trusted checkpoints only).
model_single = torch.load('single_056.pkl',map_location=torch.device('cpu'))
model_single_result = calculate(model_single, testloader, tokenizer, unique=True)
# Not the last expression of the cell, so this line displays nothing.
model_single_result

# Second evaluation: 'contra_63_24.pkl' with every flag left at its default.
model = torch.load('contra_63_24.pkl', map_location=torch.device('cpu'))
model_multi_result_non_unique = calculate(model, testloader, tokenizer)
# Last expression of the cell: this is the dict shown as the cell output.
model_multi_result_non_unique

HBox(children=(IntProgress(value=0, description='Iteration', max=600, style=ProgressStyle(description_width='i…




{'total': 600,
 'total_MRR': 0.0248,
 'total_acc': 0.56,
 'total_mean_len': 245.8,
 'entail_total': 300,
 'entail_acc': 0.62,
 'entail_mean_len': 262.5,
 'entail_MRR': 0.0282,
 'entail_correct': 185,
 'entail_correct_mean_len': 253.2,
 'entail_MRR_c': 0.0302,
 'entail_incorrect': 115,
 'entail_incorrect_mean_len': 277.53,
 'entail_MRR_inc': 0.0249,
 'neutral_total': 210,
 'neutral_acc': 0.62,
 'neutral_mean_len': 209.0,
 'neutral_MRR': 0.0217,
 'neutral_correct': 130,
 'neutral_correct_mean_len': 211.1,
 'neutral_MRR_c': 0.0213,
 'neutral_incorrect': 80,
 'neutral_incorrect_mean_len': 205.65,
 'neutral_MRR_inc': 0.0223,
 'contradict_total': 90,
 'contradict_acc': 0.21,
 'contradict_mean_len': 275.6,
 'contradict_MRR': 0.0213,
 'contradict_correct': 19,
 'contradict_correct_mean_len': 230.4,
 'contradict_MRR_c': 0.0249,
 'contradict_incorrect': 71,
 'contradict_incorrect_mean_len': 287.69,
 'contradict_MRR_inc': 0.0203}

In [19]:
# Checkpoint 'multi_0.59, 31.pkl', default flags
# (torch.load unpickles the file: trusted checkpoints only).
model = torch.load('multi_0.59, 31.pkl', map_location=torch.device('cpu'))
# NOTE(review): rebinds `model_multi_result_non_unique`, the name another cell
# uses for the 'contra_63_24.pkl' run; which result the name holds depends on
# the order in which the cells were executed.
model_multi_result_non_unique = calculate(model, testloader, tokenizer)
model_multi_result_non_unique

HBox(children=(IntProgress(value=0, description='Iteration', max=600, style=ProgressStyle(description_width='i…




{'total': 600,
 'total_MRR': 0.0286,
 'total_acc': 0.58,
 'total_mean_len': 158.8,
 'entail_total': 300,
 'entail_acc': 0.78,
 'entail_mean_len': 171.2,
 'entail_MRR': 0.0322,
 'entail_correct': 233,
 'entail_correct_mean_len': 168.4,
 'entail_MRR_c': 0.0339,
 'entail_incorrect': 67,
 'entail_incorrect_mean_len': 181.19,
 'entail_MRR_inc': 0.0262,
 'neutral_total': 210,
 'neutral_acc': 0.44,
 'neutral_mean_len': 134.8,
 'neutral_MRR': 0.0247,
 'neutral_correct': 93,
 'neutral_correct_mean_len': 132.2,
 'neutral_MRR_c': 0.0243,
 'neutral_incorrect': 117,
 'neutral_incorrect_mean_len': 136.88,
 'neutral_MRR_inc': 0.0249,
 'contradict_total': 90,
 'contradict_acc': 0.27,
 'contradict_mean_len': 173.7,
 'contradict_MRR': 0.026,
 'contradict_correct': 24,
 'contradict_correct_mean_len': 157.3,
 'contradict_MRR_c': 0.0248,
 'contradict_incorrect': 66,
 'contradict_incorrect_mean_len': 179.62,
 'contradict_MRR_inc': 0.0264}

In [20]:
# Checkpoint 'multi_0.58, 21.pkl', default flags
# (torch.load unpickles the file: trusted checkpoints only).
model = torch.load('multi_0.58, 21.pkl', map_location=torch.device('cpu'))
model_multi_result_non_unique2 = calculate(model, testloader, tokenizer)
model_multi_result_non_unique2

HBox(children=(IntProgress(value=0, description='Iteration', max=600, style=ProgressStyle(description_width='i…




{'total': 600,
 'total_MRR': 0.0292,
 'total_acc': 0.61,
 'total_mean_len': 158.8,
 'entail_total': 300,
 'entail_acc': 0.85,
 'entail_mean_len': 171.2,
 'entail_MRR': 0.0329,
 'entail_correct': 255,
 'entail_correct_mean_len': 170.3,
 'entail_MRR_c': 0.0341,
 'entail_incorrect': 45,
 'entail_incorrect_mean_len': 176.42,
 'entail_MRR_inc': 0.0258,
 'neutral_total': 210,
 'neutral_acc': 0.45,
 'neutral_mean_len': 134.8,
 'neutral_MRR': 0.0251,
 'neutral_correct': 94,
 'neutral_correct_mean_len': 132.2,
 'neutral_MRR_c': 0.0249,
 'neutral_incorrect': 116,
 'neutral_incorrect_mean_len': 136.93,
 'neutral_MRR_inc': 0.0252,
 'contradict_total': 90,
 'contradict_acc': 0.18,
 'contradict_mean_len': 173.7,
 'contradict_MRR': 0.0266,
 'contradict_correct': 16,
 'contradict_correct_mean_len': 148.9,
 'contradict_MRR_c': 0.0261,
 'contradict_incorrect': 74,
 'contradict_incorrect_mean_len': 179.03,
 'contradict_MRR_inc': 0.0267}

In [21]:
# Checkpoint 'multi_0.58, 26.pkl', default flags
# (torch.load unpickles the file: trusted checkpoints only).
model = torch.load('multi_0.58, 26.pkl', map_location=torch.device('cpu'))
model_multi_result_non_unique3 = calculate(model, testloader, tokenizer)
model_multi_result_non_unique3

HBox(children=(IntProgress(value=0, description='Iteration', max=600, style=ProgressStyle(description_width='i…




{'total': 600,
 'total_MRR': 0.029,
 'total_acc': 0.6,
 'total_mean_len': 158.8,
 'entail_total': 300,
 'entail_acc': 0.87,
 'entail_mean_len': 171.2,
 'entail_MRR': 0.0327,
 'entail_correct': 260,
 'entail_correct_mean_len': 169.4,
 'entail_MRR_c': 0.0342,
 'entail_incorrect': 40,
 'entail_incorrect_mean_len': 183.25,
 'entail_MRR_inc': 0.023,
 'neutral_total': 210,
 'neutral_acc': 0.4,
 'neutral_mean_len': 134.8,
 'neutral_MRR': 0.0248,
 'neutral_correct': 83,
 'neutral_correct_mean_len': 131.4,
 'neutral_MRR_c': 0.0247,
 'neutral_incorrect': 127,
 'neutral_incorrect_mean_len': 137.06,
 'neutral_MRR_inc': 0.0249,
 'contradict_total': 90,
 'contradict_acc': 0.17,
 'contradict_mean_len': 173.7,
 'contradict_MRR': 0.0263,
 'contradict_correct': 15,
 'contradict_correct_mean_len': 132.8,
 'contradict_MRR_c': 0.0267,
 'contradict_incorrect': 75,
 'contradict_incorrect_mean_len': 181.84,
 'contradict_MRR_inc': 0.0263}

In [17]:
# Checkpoint 'multi_0.6, 22, 15.pkl', default flags
# (torch.load unpickles the file: trusted checkpoints only).
model = torch.load('multi_0.6, 22, 15.pkl', map_location=torch.device('cpu'))
model_multi_result_non_unique4 = calculate(model, testloader, tokenizer)
model_multi_result_non_unique4

HBox(children=(IntProgress(value=0, description='Iteration', max=600, style=ProgressStyle(description_width='i…




{'total': 600,
 'total_MRR': 0.0223,
 'total_acc': 0.6,
 'total_mean_len': 158.8,
 'entail_total': 300,
 'entail_acc': 0.85,
 'entail_mean_len': 171.2,
 'entail_MRR': 0.0247,
 'entail_correct': 254,
 'entail_correct_mean_len': 169.3,
 'entail_MRR_c': 0.0252,
 'entail_incorrect': 46,
 'entail_incorrect_mean_len': 181.85,
 'entail_MRR_inc': 0.0223,
 'neutral_total': 210,
 'neutral_acc': 0.42,
 'neutral_mean_len': 134.8,
 'neutral_MRR': 0.0196,
 'neutral_correct': 88,
 'neutral_correct_mean_len': 132.4,
 'neutral_MRR_c': 0.0198,
 'neutral_incorrect': 122,
 'neutral_incorrect_mean_len': 136.59,
 'neutral_MRR_inc': 0.0194,
 'contradict_total': 90,
 'contradict_acc': 0.21,
 'contradict_mean_len': 173.7,
 'contradict_MRR': 0.0203,
 'contradict_correct': 19,
 'contradict_correct_mean_len': 106.8,
 'contradict_MRR_c': 0.0226,
 'contradict_incorrect': 71,
 'contradict_incorrect_mean_len': 191.56,
 'contradict_MRR_inc': 0.0197}

In [18]:
# Checkpoint 'multi_0.6, 26, 13.pkl', default flags
# (torch.load unpickles the file: trusted checkpoints only).
model = torch.load('multi_0.6, 26, 13.pkl', map_location=torch.device('cpu'))
model_multi_result_non_unique5 = calculate(model, testloader, tokenizer)
model_multi_result_non_unique5

HBox(children=(IntProgress(value=0, description='Iteration', max=600, style=ProgressStyle(description_width='i…




{'total': 600,
 'total_MRR': 0.0236,
 'total_acc': 0.6,
 'total_mean_len': 158.8,
 'entail_total': 300,
 'entail_acc': 0.85,
 'entail_mean_len': 171.2,
 'entail_MRR': 0.0262,
 'entail_correct': 255,
 'entail_correct_mean_len': 170.2,
 'entail_MRR_c': 0.0271,
 'entail_incorrect': 45,
 'entail_incorrect_mean_len': 177.31,
 'entail_MRR_inc': 0.0212,
 'neutral_total': 210,
 'neutral_acc': 0.42,
 'neutral_mean_len': 134.8,
 'neutral_MRR': 0.0206,
 'neutral_correct': 88,
 'neutral_correct_mean_len': 135.6,
 'neutral_MRR_c': 0.0207,
 'neutral_incorrect': 122,
 'neutral_incorrect_mean_len': 134.28,
 'neutral_MRR_inc': 0.0205,
 'contradict_total': 90,
 'contradict_acc': 0.21,
 'contradict_mean_len': 173.7,
 'contradict_MRR': 0.0217,
 'contradict_correct': 19,
 'contradict_correct_mean_len': 171.4,
 'contradict_MRR_c': 0.0211,
 'contradict_incorrect': 71,
 'contradict_incorrect_mean_len': 174.27,
 'contradict_MRR_inc': 0.0219}

In [22]:
# Checkpoint 'multi_0.59, 43, 18.pkl', default flags
# (torch.load unpickles the file: trusted checkpoints only).
model = torch.load('multi_0.59, 43, 18.pkl', map_location=torch.device('cpu'))
model_multi_result_non_unique59_43 = calculate(model, testloader, tokenizer)
model_multi_result_non_unique59_43

HBox(children=(IntProgress(value=0, description='Iteration', max=600, style=ProgressStyle(description_width='i…




{'total': 600,
 'total_MRR': 0.0347,
 'total_acc': 0.61,
 'total_mean_len': 158.8,
 'entail_total': 300,
 'entail_acc': 0.81,
 'entail_mean_len': 171.2,
 'entail_MRR': 0.0386,
 'entail_correct': 242,
 'entail_correct_mean_len': 170.5,
 'entail_MRR_c': 0.0408,
 'entail_incorrect': 58,
 'entail_incorrect_mean_len': 174.21,
 'entail_MRR_inc': 0.0294,
 'neutral_total': 210,
 'neutral_acc': 0.43,
 'neutral_mean_len': 134.8,
 'neutral_MRR': 0.0305,
 'neutral_correct': 90,
 'neutral_correct_mean_len': 135.2,
 'neutral_MRR_c': 0.03,
 'neutral_incorrect': 120,
 'neutral_incorrect_mean_len': 134.53,
 'neutral_MRR_inc': 0.0308,
 'contradict_total': 90,
 'contradict_acc': 0.36,
 'contradict_mean_len': 173.7,
 'contradict_MRR': 0.0313,
 'contradict_correct': 32,
 'contradict_correct_mean_len': 173.1,
 'contradict_MRR_c': 0.0297,
 'contradict_incorrect': 58,
 'contradict_incorrect_mean_len': 173.97,
 'contradict_MRR_inc': 0.0321}

In [13]:
# Checkpoint 'contra_63_24.pkl', default flags
# (torch.load unpickles the file: trusted checkpoints only).
# NOTE(review): another cell evaluates the same file with the same arguments,
# yet the two recorded outputs differ slightly (e.g. total_mean_len 244.5 here
# vs 245.8 there), so the numbers depend on kernel state at run time - confirm
# after a clean Restart & Run All.
model = torch.load('contra_63_24.pkl', map_location=torch.device('cpu'))
model_2multi_result_non_unique = calculate(model, testloader, tokenizer)
model_2multi_result_non_unique

HBox(children=(IntProgress(value=0, description='Iteration', max=600, style=ProgressStyle(description_width='i…




{'total': 600,
 'total_MRR': 0.025,
 'total_acc': 0.56,
 'total_mean_len': 244.5,
 'entail_total': 300,
 'entail_acc': 0.62,
 'entail_mean_len': 261.6,
 'entail_MRR': 0.0283,
 'entail_correct': 186,
 'entail_correct_mean_len': 252.7,
 'entail_MRR_c': 0.0303,
 'entail_incorrect': 114,
 'entail_incorrect_mean_len': 275.93,
 'entail_MRR_inc': 0.0251,
 'neutral_total': 210,
 'neutral_acc': 0.62,
 'neutral_mean_len': 208.1,
 'neutral_MRR': 0.0217,
 'neutral_correct': 131,
 'neutral_correct_mean_len': 209.6,
 'neutral_MRR_c': 0.0214,
 'neutral_incorrect': 79,
 'neutral_incorrect_mean_len': 205.58,
 'neutral_MRR_inc': 0.0222,
 'contradict_total': 90,
 'contradict_acc': 0.21,
 'contradict_mean_len': 272.8,
 'contradict_MRR': 0.0216,
 'contradict_correct': 19,
 'contradict_correct_mean_len': 230.4,
 'contradict_MRR_c': 0.025,
 'contradict_incorrect': 71,
 'contradict_incorrect_mean_len': 284.1,
 'contradict_MRR_inc': 0.0207}

In [23]:
# Checkpoint '2multi_0.628333, 25.pkl', default flags
# (torch.load unpickles the file: trusted checkpoints only).
model = torch.load('2multi_0.628333, 25.pkl', map_location=torch.device('cpu'))
model_2multi_result_non_unique1 = calculate(model, testloader, tokenizer)
model_2multi_result_non_unique1

HBox(children=(IntProgress(value=0, description='Iteration', max=600, style=ProgressStyle(description_width='i…






{'total': 600,
 'total_MRR': 0.0185,
 'total_acc': 0.64,
 'total_mean_len': 158.8,
 'entail_total': 300,
 'entail_acc': 0.83,
 'entail_mean_len': 171.2,
 'entail_MRR': 0.0201,
 'entail_correct': 250,
 'entail_correct_mean_len': 172.3,
 'entail_MRR_c': 0.0201,
 'entail_incorrect': 50,
 'entail_incorrect_mean_len': 165.64,
 'entail_MRR_inc': 0.0201,
 'neutral_total': 210,
 'neutral_acc': 0.56,
 'neutral_mean_len': 134.8,
 'neutral_MRR': 0.0167,
 'neutral_correct': 118,
 'neutral_correct_mean_len': 130.9,
 'neutral_MRR_c': 0.0172,
 'neutral_incorrect': 92,
 'neutral_incorrect_mean_len': 139.85,
 'neutral_MRR_inc': 0.0161,
 'contradict_total': 90,
 'contradict_acc': 0.18,
 'contradict_mean_len': 173.7,
 'contradict_MRR': 0.0173,
 'contradict_correct': 16,
 'contradict_correct_mean_len': 137.4,
 'contradict_MRR_c': 0.0176,
 'contradict_incorrect': 74,
 'contradict_incorrect_mean_len': 181.51,
 'contradict_MRR_inc': 0.0172}

# Single

In [65]:
%%time
# Checkpoint 'single_056.pkl' with unique=True and in_un=True
# (torch.load unpickles the file: trusted checkpoints only).
model_single = torch.load('single_056.pkl',map_location=torch.device('cpu'))
model_single_result_in_un = calculate(model_single, testloader, tokenizer, unique=True, in_un=True)
model_single_result_in_un

HBox(children=(IntProgress(value=0, description='Iteration', max=600, style=ProgressStyle(description_width='i…


CPU times: user 46min 55s, sys: 33.4 s, total: 47min 28s
Wall time: 26min 49s


{'total': 600,
 'total_MRR': 0.4116,
 'total_acc': 0.52,
 'total_mean_len': 15.0,
 'entail_total': 300,
 'entail_acc': 0.44,
 'entail_mean_len': 15.5,
 'entail_MRR': 0.4951,
 'entail_correct': 133,
 'entail_correct_mean_len': 13.9,
 'entail_MRR_c': 0.5375,
 'entail_incorrect': 167,
 'entail_incorrect_mean_len': 16.86,
 'entail_MRR_inc': 0.4613,
 'neutral_total': 210,
 'neutral_acc': 0.73,
 'neutral_mean_len': 13.7,
 'neutral_MRR': 0.3011,
 'neutral_correct': 154,
 'neutral_correct_mean_len': 13.6,
 'neutral_MRR_c': 0.297,
 'neutral_incorrect': 56,
 'neutral_incorrect_mean_len': 13.82,
 'neutral_MRR_inc': 0.3125,
 'contradict_total': 90,
 'contradict_acc': 0.29,
 'contradict_mean_len': 16.4,
 'contradict_MRR': 0.3909,
 'contradict_correct': 26,
 'contradict_correct_mean_len': 17.6,
 'contradict_MRR_c': 0.423,
 'contradict_incorrect': 64,
 'contradict_incorrect_mean_len': 15.92,
 'contradict_MRR_inc': 0.3779}

In [71]:
# Checkpoint 'single_056.pkl', default flags
# (torch.load unpickles the file: trusted checkpoints only).
model_single = torch.load('single_056.pkl',map_location=torch.device('cpu'))
model_single_result_non_unique = calculate(model_single, testloader, tokenizer)
model_single_result_non_unique

HBox(children=(IntProgress(value=0, description='Iteration', max=600, style=ProgressStyle(description_width='i…




{'total': 600,
 'total_MRR': 0.0225,
 'total_acc': 0.52,
 'total_mean_len': 245.8,
 'entail_total': 300,
 'entail_acc': 0.44,
 'entail_mean_len': 262.5,
 'entail_MRR': 0.0256,
 'entail_correct': 133,
 'entail_correct_mean_len': 229.1,
 'entail_MRR_c': 0.0295,
 'entail_incorrect': 167,
 'entail_incorrect_mean_len': 289.12,
 'entail_MRR_inc': 0.0224,
 'neutral_total': 210,
 'neutral_acc': 0.73,
 'neutral_mean_len': 209.0,
 'neutral_MRR': 0.0198,
 'neutral_correct': 154,
 'neutral_correct_mean_len': 212.0,
 'neutral_MRR_c': 0.0197,
 'neutral_incorrect': 56,
 'neutral_incorrect_mean_len': 200.98,
 'neutral_MRR_inc': 0.0202,
 'contradict_total': 90,
 'contradict_acc': 0.29,
 'contradict_mean_len': 275.6,
 'contradict_MRR': 0.0185,
 'contradict_correct': 26,
 'contradict_correct_mean_len': 311.7,
 'contradict_MRR_c': 0.0194,
 'contradict_incorrect': 64,
 'contradict_incorrect_mean_len': 260.92,
 'contradict_MRR_inc': 0.0181}

# Pretrained

In [92]:
%%time
# Loads 'xlnet-base-cased' with a 3-label head and attention outputs enabled.
# NOTE(review): no fine-tuned checkpoint is loaded here, so the classification
# head is presumably untrained - confirm this is the intended baseline.
model_pretrained = XLNetForSequenceClassification.from_pretrained('xlnet-base-cased',output_attentions=True,
                                                                  num_labels=3)
# unique=True; in_un and top_k keep their defaults.
model_pretrained_result_unique = calculate(model_pretrained, testloader, tokenizer, unique=True)
model_pretrained_result_unique

HBox(children=(IntProgress(value=0, description='Iteration', max=600, style=ProgressStyle(description_width='i…


CPU times: user 46min 27s, sys: 32.4 s, total: 46min 59s
Wall time: 26min 27s


{'total': 600,
 'total_MRR': 0.1777,
 'total_acc': 0.46,
 'total_mean_len': 15.0,
 'entail_total': 300,
 'entail_acc': 0.8,
 'entail_mean_len': 15.5,
 'entail_MRR': 0.1961,
 'entail_correct': 240,
 'entail_correct_mean_len': 15.6,
 'entail_MRR_c': 0.1987,
 'entail_incorrect': 60,
 'entail_incorrect_mean_len': 15.53,
 'entail_MRR_inc': 0.1857,
 'neutral_total': 210,
 'neutral_acc': 0.15,
 'neutral_mean_len': 13.7,
 'neutral_MRR': 0.1584,
 'neutral_correct': 31,
 'neutral_correct_mean_len': 14.1,
 'neutral_MRR_c': 0.1548,
 'neutral_incorrect': 179,
 'neutral_incorrect_mean_len': 13.59,
 'neutral_MRR_inc': 0.159,
 'contradict_total': 90,
 'contradict_acc': 0.07,
 'contradict_mean_len': 16.4,
 'contradict_MRR': 0.1613,
 'contradict_correct': 6,
 'contradict_correct_mean_len': 19.3,
 'contradict_MRR_c': 0.1167,
 'contradict_incorrect': 84,
 'contradict_incorrect_mean_len': 16.19,
 'contradict_MRR_inc': 0.1645}

In [93]:
# Reuses the `model_pretrained` object created in an earlier cell;
# unique=True and in_un=True, top_k keeps its default.
model_pretrained_result_in_un = calculate(model_pretrained, testloader, tokenizer, unique=True, in_un=True)
model_pretrained_result_in_un

HBox(children=(IntProgress(value=0, description='Iteration', max=600, style=ProgressStyle(description_width='i…




{'total': 600,
 'total_MRR': 0.4017,
 'total_acc': 0.46,
 'total_mean_len': 15.0,
 'entail_total': 300,
 'entail_acc': 0.8,
 'entail_mean_len': 15.5,
 'entail_MRR': 0.4776,
 'entail_correct': 240,
 'entail_correct_mean_len': 15.6,
 'entail_MRR_c': 0.4874,
 'entail_incorrect': 60,
 'entail_incorrect_mean_len': 15.53,
 'entail_MRR_inc': 0.4384,
 'neutral_total': 210,
 'neutral_acc': 0.15,
 'neutral_mean_len': 13.7,
 'neutral_MRR': 0.3015,
 'neutral_correct': 31,
 'neutral_correct_mean_len': 14.1,
 'neutral_MRR_c': 0.2993,
 'neutral_incorrect': 179,
 'neutral_incorrect_mean_len': 13.59,
 'neutral_MRR_inc': 0.3019,
 'contradict_total': 90,
 'contradict_acc': 0.07,
 'contradict_mean_len': 16.4,
 'contradict_MRR': 0.3822,
 'contradict_correct': 6,
 'contradict_correct_mean_len': 19.3,
 'contradict_MRR_c': 0.3117,
 'contradict_incorrect': 84,
 'contradict_incorrect_mean_len': 16.19,
 'contradict_MRR_inc': 0.3873}

In [94]:
# Reuses the `model_pretrained` object created in an earlier cell; all flags
# keep their defaults.
model_pretrained_result_non_unique = calculate(model_pretrained, testloader, tokenizer)
model_pretrained_result_non_unique

HBox(children=(IntProgress(value=0, description='Iteration', max=600, style=ProgressStyle(description_width='i…




{'total': 600,
 'total_MRR': 0.0222,
 'total_acc': 0.46,
 'total_mean_len': 245.8,
 'entail_total': 300,
 'entail_acc': 0.8,
 'entail_mean_len': 262.5,
 'entail_MRR': 0.0248,
 'entail_correct': 240,
 'entail_correct_mean_len': 269.4,
 'entail_MRR_c': 0.0245,
 'entail_incorrect': 60,
 'entail_incorrect_mean_len': 235.18,
 'entail_MRR_inc': 0.0259,
 'neutral_total': 210,
 'neutral_acc': 0.15,
 'neutral_mean_len': 209.0,
 'neutral_MRR': 0.0201,
 'neutral_correct': 31,
 'neutral_correct_mean_len': 196.4,
 'neutral_MRR_c': 0.0203,
 'neutral_incorrect': 179,
 'neutral_incorrect_mean_len': 211.22,
 'neutral_MRR_inc': 0.02,
 'contradict_total': 90,
 'contradict_acc': 0.07,
 'contradict_mean_len': 275.6,
 'contradict_MRR': 0.0189,
 'contradict_correct': 6,
 'contradict_correct_mean_len': 296.3,
 'contradict_MRR_c': 0.017,
 'contradict_incorrect': 84,
 'contradict_incorrect_mean_len': 274.12,
 'contradict_MRR_inc': 0.0191}

In [96]:
import csv

# Write each pretrained-model result dict to its own one-row CSV file
# (header row = dict keys, data row = dict values).
pretrained_results = {
    'pretrained_non_unique.csv': model_pretrained_result_non_unique,
    'pretrained_unique.csv': model_pretrained_result_unique,
    'pretrained_in_un.csv': model_pretrained_result_in_un,
}

for csv_path, result in pretrained_results.items():
    # newline='' is what the csv module asks for on files handed to a writer;
    # without it some platforms emit an extra blank line after every row.
    with open(csv_path, 'w', newline='') as f:
        # Header and row come from the same dict, so they cannot get out of sync.
        w = csv.DictWriter(f, result.keys())
        w.writeheader()
        w.writerow(result)

In [10]:
# Hand-picked example pair used by the cells below: `sentence_a` and
# `sentence_b` are encoded together; `test_sentence_a` is a short excerpt of
# `sentence_a` that is later encoded with `sentence_b` in its place.
sentence_a = """losail qatar afp torrential rain caused the seasonopening qatar motogp to be cancelled on sunday leaving officials and teams in a frenzy before deciding to race on monday instead at this floodlit desert venue. monsoonlike conditions accompanied by swirling winds arrived just moments before australia's casey stoner on pole position was due to lead defending world champion valentino rossi and the other riders away on the warmup lap. it's just unlucky with the weather said australian ducati rider stoner the 2007 world champion who was bidding for a third successive win here."""
sentence_b = "valentino rossi won the seasonopening qatar motogp."
test_sentence_a = """torrential rain caused the seasonopening qatar motogp to be cancelled"""
# Rebinds the notebook-level `tokenizer`, this time with do_lower_case=True;
# every cell executed afterwards that reads the global name sees this instance.
tokenizer = XLNetTokenizer.from_pretrained('xlnet-base-cased', do_lower_case=True)

In [11]:
# Encode the example pair as a single sequence with special tokens.
inputs = tokenizer.encode_plus(sentence_a, sentence_b, return_tensors='pt', add_special_tokens=True)
input_ids = inputs['input_ids'].to(device)
token_type_ids = inputs['token_type_ids'].to(device)
# Token strings for the encoded ids, used to label the attention matrix later.
# (The bare `input_ids.squeeze()` statement was dropped: its result was discarded.)
tokens = tokenizer.convert_ids_to_tokens(input_ids.squeeze().tolist())

In [84]:
# Switch the model that is actually queried below into inference mode
# (the previous `model.eval()` acted on a different object).
model_pretrained.eval()
with torch.no_grad():
    # One forward pass provides both values: logits are the first output
    # element and the attention maps are the last one.
    outputs = model_pretrained(input_ids, token_type_ids=token_type_ids)
    logits = outputs[0]
    attention = outputs[-1]
logits

tensor([[-0.0703, -0.0253, -0.0086]])

In [13]:
# Rank sentence-A tokens by how they pair with sentence-B tokens under the
# model's attention. format_attention, format_special_chars, pair_match and
# unique_pair_without_score are not defined in this notebook; presumably they
# come from `from utils import *` - confirm.
attn = format_attention(attention, tokens)  
tokens = format_special_chars(tokens)
# First position whose segment id is 1 marks where sentence B starts.
sentence_b_start = token_type_ids[0].tolist().index(1)
slice_a = slice(0, sentence_b_start)
slice_b = slice(sentence_b_start, len(tokens))
# Keep sentence-A positions on the third axis and sentence-B positions on the
# fourth (assumes attn is laid out as layer, head, from, to - TODO confirm).
attn_data = attn[:, :, slice_a, slice_b]
sentence_a_tokens = tokens[slice_a]
sentence_b_tokens = tokens[slice_b]
pair = pair_match(sentence_a_tokens, sentence_b_tokens, attn_data=attn_data)
# Sort descending by element 2 of each entry (presumably the score - confirm).
pair = sorted(pair, key=lambda pair: pair[2], reverse=True)
pair = unique_pair_without_score(pair)
pair

['season',
 'p',
 'moto',
 'opening',
 'qa',
 'g',
 'tar',
 'ino',
 'ossi',
 'r',
 'valent',
 'win',
 'for',
 'defending',
 '.',
 'a',
 'the',
 'bidding',
 'world',
 'champion',
 'rider',
 'race',
 'pole',
 'af',
 'lead',
 'was',
 'riders',
 'and',
 'who',
 'australia',
 'up',
 'lap',
 'cancelled',
 'successive',
 'third',
 'frenzy',
 'warm',
 'torrential',
 'venue',
 'deciding',
 'ati',
 'il',
 'officials',
 'monsoon',
 'stone',
 '2007',
 'case',
 "'",
 'flood',
 'position',
 'desert',
 'sun',
 'teams',
 'mon',
 'before',
 'lucky',
 'instead',
 'just',
 'conditions',
 'n',
 'this',
 'un',
 'duc',
 'winds',
 'to',
 'weather',
 'on',
 'here',
 'it',
 'said',
 'at',
 'los',
 'day',
 'due',
 'rain',
 'by',
 'arrived',
 'y',
 'caused',
 'other',
 'be',
 's',
 'leaving',
 'with',
 'lit',
 'swirling',
 'away',
 'moments',
 'accompanied',
 'like',
 'in']

In [17]:
# Build the token pairs for the short excerpt paired with sentence B. The pair
# is encoded without special tokens and no attention data is supplied.
test_inputs = tokenizer.encode_plus(test_sentence_a, sentence_b, return_tensors='pt', add_special_tokens=False)
test_input_ids = test_inputs['input_ids']
test_token_type_ids = test_inputs['token_type_ids']
# (The bare `test_input_ids.squeeze()` statement was dropped: its result was discarded.)
test_tokens = tokenizer.convert_ids_to_tokens(test_input_ids.squeeze().tolist())
test_tokens = format_special_chars(test_tokens)
# First position whose segment id is 1 marks where sentence B starts.
test_sentence_b_start = test_token_type_ids[0].tolist().index(1)
test_slice_a = slice(0, test_sentence_b_start)
test_slice_b = slice(test_sentence_b_start, len(test_tokens))
test_sentence_a_tokens = test_tokens[test_slice_a]
test_sentence_b_tokens = test_tokens[test_slice_b]
test_pair = pair_match(test_sentence_a_tokens, test_sentence_b_tokens, attn_data=None)
test_pair = unique_pair_without_score(test_pair)