In [1]:
import os,sys,inspect
# Resolve the directory of the file reported for the current frame, take its
# parent, and put that parent first on sys.path.
# NOTE(review): presumably this is what lets the project-local modules imported
# below (XLNet, span_detection_metrics, utils) resolve -- confirm; inside a
# notebook the "file" of the current frame may not be a real path.
current_dir = os.path.dirname(os.path.abspath(inspect.getfile(inspect.currentframe())))
parent_dir = os.path.dirname(current_dir)
sys.path.insert(0, parent_dir) 
import torch
from torch import nn
from torch.nn import functional as F
from torch.utils.data import Dataset, SequentialSampler, DataLoader
from transformers import XLNetTokenizer, XLNetForSequenceClassification, XLNetPreTrainedModel, XLNetModel
from sklearn.metrics import f1_score, precision_score, recall_score, accuracy_score
from XLNet import (Dataset_Span_Detection,
                   XLNetForMultiSequenceClassification,
                   SpanDetectionResult, 
                   SquadExample,
                   SquadFeatures,
                   squad_convert_example_to_features)
from span_detection_metrics import compute_predictions_log_probs, span_evaluate
from utils import *


import pandas as pd
import numpy as np
import random
from tqdm.notebook import tqdm, trange
# Use CUDA when torch reports it as available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [2]:
# Tokenizer for the 'xlnet-base-cased' vocabulary; it is handed to the dataset
# here and to the prediction step later in the notebook.
tokenizer = XLNetTokenizer.from_pretrained('xlnet-base-cased')

# Evaluation data is read in a fixed order, one example per batch.
dataset = Dataset_Span_Detection("RTE5_test_span", tokenizer=tokenizer)
eval_sampler = SequentialSampler(dataset)
eval_dataloader = DataLoader(dataset, sampler=eval_sampler, batch_size=1)

# SECURITY: torch.load unpickles the file, which can execute arbitrary code.
# Only load checkpoints that come from a trusted source.
# The checkpoint is mapped onto `device` (not hard-coded to the CPU) because the
# evaluation loop moves every input tensor to `device`; loading the model onto
# the CPU while the inputs live on CUDA fails with a device-mismatch error.
model = torch.load('../3multi_task/multi_0.6, 25, 15.pkl', map_location=device)
model.to(device)

In [3]:
# Run the model over every evaluation example and collect, per example, the
# span-detection result, the SquadExample and the converted features.
# NOTE: `attention_weight_span` is defined in a later cell of this notebook and
# must have been executed before this cell is run.
all_results = []
all_examples = []
all_features = []

# Switching to evaluation mode does not depend on the batch, so do it once
# instead of on every iteration.
model.eval()

for data in tqdm(eval_dataloader, desc="Evaluating"):
    with torch.no_grad():
        # Batch layout as consumed here:
        #   data[0]    -> task
        #   data[1:6]  -> input_ids, attention_mask, token_type_ids, cls_index, p_mask
        #   data[6]    -> example_index
        #   data[7]    -> unique_id
        #   data[-4:]  -> question text, context text, answer text, answer start (char)
        task = data[0]
        example_index = data[6]
        unique_id = data[7]
        # Drop the leading dimension of each tensor and move it to the target device.
        input_ids, attention_mask, token_type_ids, cls_index, p_mask = [
            t.squeeze(0).to(device) for t in data[1:6]
        ]

        # Each of the last four fields is indexed with [0] to get the single
        # entry of this batch.
        question_text, context_text, answer_text, start_position_character = [
            t[0] for t in data[-4:]
        ]

        example = SquadExample(
            question_text=question_text,
            context_text=context_text,
            answer_text=answer_text,
            start_position_character=start_position_character,
            unique_id=unique_id,
        )

        feature = squad_convert_example_to_features(
            example,
            max_seq_length=384,
            doc_stride=128,
            max_query_length=64,
            is_training=False,
            example_index=example_index,
            unique_id=unique_id,
        )

        output = model(
            input_ids=input_ids,
            token_type_ids=token_type_ids,
            attention_mask=attention_mask,
            cls_index=cls_index,
            p_mask=p_mask,
            task=task,
        )

        # The span candidates come from the attention-based scorer rather than
        # from the first five entries of the model output.
        start_logits, start_top_index, end_logits, end_top_index, cls_logits = attention_weight_span(
            data, feature, output
        )

        result = SpanDetectionResult(
            unique_id,
            start_logits,
            end_logits,
            start_top_index=start_top_index,
            end_top_index=end_top_index,
            cls_logits=cls_logits,
        )

        all_results.append(result)
        all_examples.append(example)
        all_features.append(feature)

HBox(children=(IntProgress(value=0, description='Evaluating', max=600, style=ProgressStyle(description_width='…




In [11]:
# Scratch tensor of five values; the next cells use it to try out indexing.
a = torch.tensor([0.4354, 0.3759, 0.0914, 0.0453, 0.0204])

In [29]:
# Add a trailing dimension to `a`, then take the entry at index 2 along the
# first axis (a one-element tensor).
a.unsqueeze(1)[2]

tensor([0.0914])

In [19]:
# Display the scratch tensor.
a

tensor([0.4354, 0.3759, 0.0914, 0.0453, 0.0204])

In [78]:
# Give the start scores and start indices of every collected result a leading
# dimension of size 1. The end scores / end indices are left unchanged.
# The dim() guard keeps this cell idempotent: without it, every re-run of the
# cell would stack one more leading dimension onto the same tensors.
for result in all_results:
    if result.start_logits.dim() == 1:
        result.start_logits = result.start_logits.unsqueeze(0)
    if result.start_top_index.dim() == 1:
        result.start_top_index = result.start_top_index.unsqueeze(0)

In [79]:
# Inspect the shape of the first result's start scores.
all_results[0].start_logits.size()

torch.Size([1, 5])

In [86]:
# Settings handed to compute_predictions_log_probs below.
start_n_top = 5
end_n_top = 5
n_best_size = 20
max_answer_length = 60
min_answer_length = 1
# Not used anywhere in this cell.
do_lower_case=False

output_dir = "../evaluation/"
prefix = ''
# The two paths below are passed to compute_predictions_log_probs as output
# files, so make sure their directory exists first.
# NOTE(review): assumes the helper writes to these paths -- confirm.
os.makedirs(output_dir, exist_ok=True)
output_prediction_file = os.path.join(output_dir, "prediction_{}.json".format(prefix))
output_nbest_file = os.path.join(output_dir, "nbest_predictions_{}.json".format(prefix))

predictions = compute_predictions_log_probs(
    all_examples,
    all_features,
    all_results,
    n_best_size,
    max_answer_length,
    min_answer_length,
    output_prediction_file,
    output_nbest_file,
    start_n_top,
    end_n_top,
    tokenizer,
    verbose_logging=True,
)

# Score the predictions against the collected examples.
# Note: this rebinds `result`, which the earlier loops also use as loop variable.
result = span_evaluate(all_examples, predictions)

In [87]:
# Display whatever `result` currently holds (the span_evaluate metrics when the
# previous cell was the last one to assign it).
result

OrderedDict([('exact', 0.0),
             ('f1', 0.9756985980364193),
             ('total', 600),
             ('HasAns_exact', 0.0),
             ('HasAns_f1', 0.9756985980364193),
             ('HasAns_total', 600),
             ('best_exact', 0.0),
             ('best_exact_thresh', 0.0),
             ('best_f1', 0.9756985980364193),
             ('best_f1_thresh', 0.0)])

In [15]:
# Take ids, tokens and segment ids from the first element of the first entry
# collected in all_features.
input_ids = all_features[0][0].input_ids
tokens = all_features[0][0].tokens
token_type_ids = all_features[0][0].token_type_ids
# NOTE(review): `output` is a leftover from the evaluation loop, i.e. the model
# output of whichever example was processed last -- it does not necessarily
# belong to all_features[0]. Confirm before relying on this pairing.
attention = output[5]
attn = format_attention(attention, tokens)

In [16]:
# Position of the first segment id equal to 1; .index raises ValueError when
# there is none.
sentence_b_start = token_type_ids.index(1)

In [17]:
# Split the token positions at that boundary: everything before it (a) and
# everything from it up to the end of `tokens` (b).
slice_a = slice(0, sentence_b_start)
slice_b = slice(sentence_b_start, len(tokens))

In [18]:
# Keep the first two axes of `attn` whole and restrict the third axis to the
# slice_a positions and the fourth axis to the slice_b positions.
attn_data = attn[:, :, slice_a, slice_b]
sentence_a_tokens = tokens[slice_a]
sentence_b_tokens = tokens[slice_b]
# Score the two token groups against each other using that attention block.
pair = pair_match_accumulation(sentence_a_tokens, sentence_b_tokens, attn_data)


In [9]:
def attention_weight_span(data, feature, output, start_n_top=5, end_n_top=25):
    """Derive top-k span candidates from the model's attention output.

    The attention at ``output[5]`` is formatted with ``format_attention`` and
    restricted to the block that pairs the tokens before the first
    segment-id-1 token with the tokens from that position on. That block is
    scored by ``pair_match_accumulation``; the scores are softmax-normalised
    and the largest entries are returned.

    Args:
        data: Batch coming from the evaluation dataloader. It is not read by
            this function; the parameter is kept so existing calls still work.
        feature: Sequence whose first element provides ``tokens`` and
            ``token_type_ids``.
        output: Model output; the element at index 5 is used as the attention.
        start_n_top: Number of entries returned for the start candidates.
        end_n_top: Number of entries returned for the end candidates.

    Returns:
        A tuple ``(start_top_probs, start_top_index, end_top_probs,
        end_top_index, cls_logits)``. Both value/index pairs are drawn from the
        same softmax distribution (these are probabilities, not
        log-probabilities). ``cls_logits`` is always ``0``. When fewer scores
        than requested are available, k is reduced to the number of scores
        instead of letting ``torch.topk`` raise.

    Raises:
        ValueError: If ``token_type_ids`` contains no ``1``.
    """
    tokens = feature[0].tokens
    token_type_ids = feature[0].token_type_ids

    attn = format_attention(output[5], tokens)

    # Boundary between the two token groups: first position with segment id 1.
    sentence_b_start = token_type_ids.index(1)
    slice_a = slice(0, sentence_b_start)
    slice_b = slice(sentence_b_start, len(tokens))

    attn_score = pair_match_accumulation(
        tokens[slice_a], tokens[slice_b], attn[:, :, slice_a, slice_b]
    )

    probs = F.softmax(torch.tensor(attn_score), dim=-1)

    # torch.topk raises when k exceeds the size of the dimension, so clamp k.
    n_candidates = probs.size(-1)
    start_top_probs, start_top_index = torch.topk(
        probs, min(start_n_top, n_candidates), dim=-1
    )
    end_top_probs, end_top_index = torch.topk(
        probs, min(end_n_top, n_candidates), dim=-1
    )
    cls_logits = 0

    return (start_top_probs, start_top_index, end_top_probs, end_top_index, cls_logits)

In [19]:
# Convert the pair scores to a tensor, rebinding the same name.
# NOTE(review): re-running this cell feeds an existing tensor back into
# torch.tensor -- consider binding the result to a new name.
pair = torch.tensor(pair)

In [20]:
# Softmax over the last dimension. Despite the variable name, these are
# probabilities, not log-probabilities.
pair_log_probs = F.softmax(pair, dim=-1)

In [32]:
# Peek at the first five entries.
pair_log_probs[:5]

tensor([1.7767e-10, 1.0510e-10, 3.7801e-10, 6.9056e-11, 5.4780e-10])

In [40]:
# Zero out, in place, the entries selected by start_top_index.
# NOTE(review): needs `start_top_index` from kernel state (it is assigned in a
# cell further down), and the in-place write permanently changes
# pair_log_probs until the softmax cell is re-run.
pair_log_probs[start_top_index] = 0

In [41]:
# Display the full tensor after the in-place update.
pair_log_probs

tensor([0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 5.2611e-11,
        5.4890e-10, 1.2520e-10, 5.1879e-11, 2.0303e-09, 4.1818e-10, 7.0310e-11,
        1.2945e-09, 1.0122e-09, 4.8054e-10, 7.8518e-10, 1.7207e-10, 9.3816e-10,
        2.0141e-09, 1.5369e-10, 6.0906e-10, 1.1986e-09, 1.6682e-10, 2.9706e-10,
        4.2920e-10, 5.1539e-10, 3.1197e-10, 1.5980e-10, 6.3291e-09, 4.0786e-10,
        2.9646e-10, 5.1870e-09, 1.6672e-09, 1.3115e-09, 2.7920e-10, 3.9995e-09,
        1.5514e-09, 1.0042e-09, 8.7124e-10, 1.2626e-09, 4.4290e-09, 9.6287e-10,
        5.0086e-09, 2.2250e-10, 6.0759e-11, 1.5261e-10, 4.7576e-10, 1.2646e-10,
        4.5102e-11, 1.6375e-09, 8.6776e-10, 1.8611e-09, 4.7815e-10, 1.2971e-09,
        8.8706e-10, 7.2120e-10, 3.9344e-10, 3.8664e-11, 2.9351e-10, 1.3940e-09,
        1.6434e-10, 5.4671e-10, 4.6032e-10, 9.2937e-11, 1.1903e-09, 3.2373e-10,
        1.1408e-10, 8.5143e-10, 7.5289e-10, 1.3590e-10, 7.3915e-11, 1.6900e-10,
        2.9765e-10, 8.9598e-10, 1.9479e-

In [21]:
# Five largest entries of pair_log_probs along the last dimension, with their positions.
start_top_log_probs, start_top_index = torch.topk(pair_log_probs, 5, dim=-1)

In [42]:
# Twenty-five largest entries of pair_log_probs along the last dimension, with their positions.
end_top_log_probs, end_top_index = torch.topk(pair_log_probs, 25, dim=-1)

In [43]:
# Display the positions chosen for the end candidates.
end_top_index

tensor([131, 129, 130, 123, 124, 132, 118, 133, 135, 134, 119, 117, 116, 114,
        115, 113, 120, 121, 122,  28,  31,  42,  40,  35,  99])

In [23]:
# Display the positions chosen for the start candidates.
start_top_index

tensor([136, 127, 128, 125, 126])

In [15]:
# Dump the instance attributes of whatever `result` currently refers to.
result.__dict__

{'cls_logits': 0,
 'end_logits': tensor([3.2124e-01, 2.7649e-01, 1.2141e-01, 8.4030e-02, 5.3260e-02, 4.0092e-02,
         3.2989e-02, 2.4463e-02, 2.2336e-02, 1.3766e-02, 3.0440e-03, 1.8555e-03,
         1.7060e-03, 1.3181e-03, 6.5651e-04, 6.4933e-04, 2.0954e-04, 1.6319e-04,
         5.6483e-05, 5.0802e-05, 2.9516e-05, 2.2918e-05, 2.1284e-05, 1.0665e-05,
         9.6691e-06]),
 'end_top_index': tensor([ 16,  21,  18,  22,  23,  20,  19,  85,  14,  17,  87,  88, 136,  84,
          83,  86,  15,  80,   5,   6,  82,  13,  81,   4, 134]),
 'start_logits': tensor([0.3212, 0.2765, 0.1214, 0.0840, 0.0533]),
 'start_top_index': tensor([16, 21, 18, 22, 23]),
 'unique_id': tensor([1000001])}

In [16]:
# Build a SpanDetectionResult by hand from the scratch top-k tensors.
# NOTE(review): the id is hard-coded to 100001, while the result printed by the
# previous cell shows unique_id 1000001 -- confirm which value is intended.
result = SpanDetectionResult(
            100001,
            start_top_log_probs,
            end_top_log_probs,
            start_top_index=start_top_index,
            end_top_index=end_top_index,
            cls_logits=0,
)