In [1]:
import os,sys,inspect
current_dir = os.path.dirname(os.path.abspath(inspect.getfile(inspect.currentframe())))
parent_dir = os.path.dirname(current_dir)
sys.path.insert(0, parent_dir) 
import torch
from torch import nn
from torch.utils.data import DataLoader
from torch.utils.data import Dataset, RandomSampler
from transformers.modeling_utils import (WEIGHTS_NAME, PretrainedConfig, PreTrainedModel,
                             SequenceSummary, PoolerAnswerClass, PoolerEndLogits, PoolerStartLogits)
from transformers import XLNetTokenizer, XLNetForSequenceClassification, XLNetPreTrainedModel, XLNetModel
from torch.nn import CrossEntropyLoss, BCEWithLogitsLoss
from transformers import get_linear_schedule_with_warmup
from sklearn.metrics import f1_score, precision_score, recall_score, accuracy_score
from torch.utils.data.dataset import ConcatDataset
from XLNet import (Dataset_3Way,
                  Dataset_multi, 
                  Dataset_Span_Detection, 
                  XLNetForMultiSequenceClassification, 
                  get_predictions)

import pandas as pd
import numpy as np
import random
from tqdm.notebook import tqdm, trange

In [2]:
tokenizer = XLNetTokenizer.from_pretrained('xlnet-base-cased')

trainset_3way = Dataset_3Way("RTE5_train", tokenizer=tokenizer, three_tasks=True)
trainset_multi = Dataset_multi("train_multi_label", tokenizer=tokenizer, three_tasks=True)
trainset_span = Dataset_Span_Detection("train_span_detection", tokenizer=tokenizer)
trainset = ConcatDataset([trainset_span, trainset_multi, trainset_3way])
train_sampler = RandomSampler(trainset)
train_dataloader = DataLoader(trainset, sampler=train_sampler, batch_size=1)

In [3]:
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
print("device:", device)
def set_seed(seed):
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)

PRETRAINED_MODEL_NAME = "xlnet-base-cased"
model = XLNetForMultiSequenceClassification.from_pretrained(PRETRAINED_MODEL_NAME,
                                                            output_attentions=True,
                                                            dropout=0.1)
model = model.to(device)

device: cuda:0


In [4]:
    testset = Dataset_3Way("RTE5_test", tokenizer=tokenizer, three_tasks=True)
    testloader = DataLoader(testset, batch_size=1)
    predictions = get_predictions(model, testloader)

    df_pred = pd.DataFrame({"label": predictions.tolist()})
        
    pred_Y = df_pred['label'].values
    test_Y = pd.read_csv("../data/RTE5_test.tsv", sep='\t').fillna("")['label'].values



In [4]:
from torch.optim import AdamW

param_optimizer = list(model.named_parameters())
no_decay = ['bias', 'LayerNorm.weight']
optimizer_grouped_parameters = [
    {'params': [p for n, p in model.named_parameters() if not any(nd in n for nd in no_decay)], 'weight_decay': 0},
    {'params': [p for n, p in model.named_parameters() if any(nd in n for nd in no_decay)], 'weight_decay': 0}
]
optimizer = AdamW(optimizer_grouped_parameters, lr=2e-5, eps=1e-8)

In [5]:
%%time
EPOCHS = 20
batch_size = 8
scheduler = get_linear_schedule_with_warmup(optimizer, num_warmup_steps=0, num_training_steps=EPOCHS * ((len(train_dataloader)//batch_size)))
epochs_trained = 0

model.zero_grad()
train_iterator = trange(epochs_trained, EPOCHS, desc="Epoch")
set_seed(42)

for _ in train_iterator:
    epoch_iterator = tqdm(train_dataloader, desc="Iteration")
    
    model.train()
    running_loss = 0.0
    batch_cnt = 1
    loss = torch.zeros(1).to(device)
    
    for step, data in enumerate(epoch_iterator):
        if data[0] == torch.tensor([0]) or data[0] == torch.tensor([1]):
            task = data[0]
            input_ids, token_type_ids, attention_mask, labels = [t.squeeze(0).to(device) for t in data[1:]]
            outputs = model(input_ids=input_ids, 
                            token_type_ids=token_type_ids, 
                            attention_mask=attention_mask, 
                            labels=labels,
                            task=task
                           )
        else:
            task = data[0]
            input_ids, attention_mask, token_type_ids, start_positions, end_positions, cls_index, p_mask = [t.squeeze(0).to(device) for t in data[1:]]
            outputs = model(input_ids=input_ids, 
                            token_type_ids=token_type_ids, 
                            attention_mask=attention_mask, 
                            start_positions=start_positions,
                            end_positions=end_positions,
                            cls_index=cls_index,
                            p_mask=p_mask,
                            task=task)
        batch_cnt += 1
        loss = outputs[0]/batch_size
        loss.backward()
        
        if batch_cnt >= batch_size:
            optimizer.step()
            scheduler.step()
            model.zero_grad()
            batch_cnt = 0

        # 紀錄當前 batch loss
        running_loss += loss.item()
    epochs_trained += 1
    
    testset = Dataset_3Way("RTE5_test", tokenizer=tokenizer, three_tasks=True)
    testloader = DataLoader(testset, batch_size=1)
    predictions = get_predictions(model, testloader)

    df_pred = pd.DataFrame({"label": predictions.tolist()})
        
    pred_Y = df_pred['label'].values
    test_Y = pd.read_csv("../data/RTE5_test.tsv", sep='\t').fillna("")['label'].values

    accuracy = accuracy_score(test_Y, np.array(pred_Y))
    precision = precision_score(test_Y, pred_Y, average='macro')
    recall = recall_score(test_Y, pred_Y, average='macro')
    fscore = f1_score(test_Y, pred_Y, average='macro')
    
    CNT = 0
    TOTAL = 0
    for i in range(len(test_Y)):
        if test_Y[i] == 2:
            TOTAL += 1
        else:
            pass
        if test_Y[i] == 2 and predictions[i] == 2:
            CNT += 1
    contra = round((CNT/TOTAL)*100,1)
    if contra > 20 and accuracy > 0.58:
        torch.save(model, "3multi_%g, %g, %g.pkl" % (round(accuracy, 2), contra, epochs_trained))
    print("Accuracy: %g\tPrecision: %g\tRecall: %g\tF-score: %g Loss: %g" % (accuracy, precision, recall, fscore, running_loss))
    print(contra)
    print("------------------------------------------")

HBox(children=(IntProgress(value=0, description='Epoch', max=20, style=ProgressStyle(description_width='initia…

HBox(children=(IntProgress(value=0, description='Iteration', max=1318, style=ProgressStyle(description_width='…






  'precision', 'predicted', average, warn_for)
  'precision', 'predicted', average, warn_for)


Accuracy: 0.393333	Precision: 0.284098	Recall: 0.340794	F-score: 0.271781 Loss: 365.734
0.0
------------------------------------------


HBox(children=(IntProgress(value=0, description='Iteration', max=1318, style=ProgressStyle(description_width='…


Accuracy: 0.471667	Precision: 0.312589	Recall: 0.369206	F-score: 0.336956 Loss: 250.615
0.0
------------------------------------------


HBox(children=(IntProgress(value=0, description='Iteration', max=1318, style=ProgressStyle(description_width='…


Accuracy: 0.471667	Precision: 0.341256	Recall: 0.395397	F-score: 0.336981 Loss: 194.319
0.0
------------------------------------------


HBox(children=(IntProgress(value=0, description='Iteration', max=1318, style=ProgressStyle(description_width='…


Accuracy: 0.511667	Precision: 0.373959	Recall: 0.425873	F-score: 0.367374 Loss: 150.259
0.0
------------------------------------------


HBox(children=(IntProgress(value=0, description='Iteration', max=1318, style=ProgressStyle(description_width='…


Accuracy: 0.608333	Precision: 0.402134	Recall: 0.466984	F-score: 0.431249 Loss: 124.593
0.0
------------------------------------------


HBox(children=(IntProgress(value=0, description='Iteration', max=1318, style=ProgressStyle(description_width='…


Accuracy: 0.63	Precision: 0.532992	Recall: 0.497513	F-score: 0.476109 Loss: 102.008
4.4
------------------------------------------


HBox(children=(IntProgress(value=0, description='Iteration', max=1318, style=ProgressStyle(description_width='…


Accuracy: 0.568333	Precision: 0.502442	Recall: 0.459206	F-score: 0.462045 Loss: 80.9879
16.7
------------------------------------------


HBox(children=(IntProgress(value=0, description='Iteration', max=1318, style=ProgressStyle(description_width='…


Accuracy: 0.596667	Precision: 0.511815	Recall: 0.513016	F-score: 0.510326 Loss: 64.318
20.0
------------------------------------------


HBox(children=(IntProgress(value=0, description='Iteration', max=1318, style=ProgressStyle(description_width='…


Accuracy: 0.52	Precision: 0.529515	Recall: 0.54545	F-score: 0.489766 Loss: 55.3283
51.1
------------------------------------------


HBox(children=(IntProgress(value=0, description='Iteration', max=1318, style=ProgressStyle(description_width='…




  "type " + obj.__name__ + ". It won't be checked "
  "type " + obj.__name__ + ". It won't be checked "
  "type " + obj.__name__ + ". It won't be checked "
  "type " + obj.__name__ + ". It won't be checked "
  "type " + obj.__name__ + ". It won't be checked "
  "type " + obj.__name__ + ". It won't be checked "
  "type " + obj.__name__ + ". It won't be checked "
  "type " + obj.__name__ + ". It won't be checked "
  "type " + obj.__name__ + ". It won't be checked "
  "type " + obj.__name__ + ". It won't be checked "
  "type " + obj.__name__ + ". It won't be checked "
  "type " + obj.__name__ + ". It won't be checked "
  "type " + obj.__name__ + ". It won't be checked "
  "type " + obj.__name__ + ". It won't be checked "
  "type " + obj.__name__ + ". It won't be checked "
  "type " + obj.__name__ + ". It won't be checked "


Accuracy: 0.591667	Precision: 0.531832	Recall: 0.538889	F-score: 0.53443 Loss: 44.1462
33.3
------------------------------------------


HBox(children=(IntProgress(value=0, description='Iteration', max=1318, style=ProgressStyle(description_width='…


Accuracy: 0.625	Precision: 0.566964	Recall: 0.552698	F-score: 0.551986 Loss: 38.036
26.7
------------------------------------------


HBox(children=(IntProgress(value=0, description='Iteration', max=1318, style=ProgressStyle(description_width='…


Accuracy: 0.586667	Precision: 0.533432	Recall: 0.539524	F-score: 0.535963 Loss: 31.7507
36.7
------------------------------------------


HBox(children=(IntProgress(value=0, description='Iteration', max=1318, style=ProgressStyle(description_width='…


Accuracy: 0.605	Precision: 0.54135	Recall: 0.529312	F-score: 0.533221 Loss: 30.8632
28.9
------------------------------------------


HBox(children=(IntProgress(value=0, description='Iteration', max=1318, style=ProgressStyle(description_width='…


Accuracy: 0.58	Precision: 0.518899	Recall: 0.531217	F-score: 0.51793 Loss: 25.6166
28.9
------------------------------------------


HBox(children=(IntProgress(value=0, description='Iteration', max=1318, style=ProgressStyle(description_width='…


Accuracy: 0.613333	Precision: 0.55995	Recall: 0.561058	F-score: 0.558661 Loss: 24.4889
38.9
------------------------------------------


HBox(children=(IntProgress(value=0, description='Iteration', max=1318, style=ProgressStyle(description_width='…


Accuracy: 0.591667	Precision: 0.525376	Recall: 0.524233	F-score: 0.524775 Loss: 22.7235
28.9
------------------------------------------


HBox(children=(IntProgress(value=0, description='Iteration', max=1318, style=ProgressStyle(description_width='…


Accuracy: 0.568333	Precision: 0.510429	Recall: 0.505767	F-score: 0.505836 Loss: 21.0728
31.1
------------------------------------------


HBox(children=(IntProgress(value=0, description='Iteration', max=1318, style=ProgressStyle(description_width='…


Accuracy: 0.586667	Precision: 0.526543	Recall: 0.528889	F-score: 0.526602 Loss: 20.3371
33.3
------------------------------------------


HBox(children=(IntProgress(value=0, description='Iteration', max=1318, style=ProgressStyle(description_width='…


Accuracy: 0.586667	Precision: 0.530545	Recall: 0.536931	F-score: 0.532082 Loss: 19.4198
35.6
------------------------------------------


HBox(children=(IntProgress(value=0, description='Iteration', max=1318, style=ProgressStyle(description_width='…


Accuracy: 0.57	Precision: 0.509079	Recall: 0.516667	F-score: 0.511199 Loss: 19.5294
30.0
------------------------------------------

Wall time: 1h 43min 19s
