In [1]:
import os,sys,inspect
current_dir = os.path.dirname(os.path.abspath(inspect.getfile(inspect.currentframe())))
parent_dir = os.path.dirname(current_dir)
sys.path.insert(0, parent_dir) 
import torch
from torch import nn
from torch.utils.data import DataLoader
from torch.utils.data import Dataset
from transformers.modeling_utils import (WEIGHTS_NAME, PretrainedConfig, PreTrainedModel,
                             SequenceSummary, PoolerAnswerClass, PoolerEndLogits, PoolerStartLogits)
from transformers import XLNetTokenizer, XLNetForSequenceClassification, XLNetPreTrainedModel, XLNetModel
from torch.nn import CrossEntropyLoss, BCEWithLogitsLoss
from transformers import get_linear_schedule_with_warmup
from sklearn.metrics import f1_score, precision_score, recall_score, accuracy_score
from torch.utils.data.dataset import ConcatDataset
from XLNet import XLNetForMultiSequenceClassification, Dataset_multi, Dataset_3Way, get_predictions, Dataset_3Way_test

import pandas as pd
import numpy as np
import random
from IPython.display import clear_output
from tqdm.notebook import tqdm, trange

In [2]:
tokenizer = XLNetTokenizer.from_pretrained('xlnet-base-cased')
trainset = Dataset_3Way("RTE5_train", tokenizer=tokenizer, three_tasks=False)

In [3]:
from torch.utils.data import DataLoader
from torch.nn.utils.rnn import pad_sequence

def create_mini_batch(samples):
    task = [s[0] for s in samples]
    tokens_tensors = [s[1].squeeze(0) for s in samples]
    segments_tensors = [s[2].squeeze(0) for s in samples]
    masks_tensors = [s[3].squeeze(0) for s in samples]
    if samples[0][4] is not None:
        label_ids = torch.stack([s[4] for s in samples])
    else:
        label_ids = None
    # zero pad 到同一序列長度
    tokens_tensors = pad_sequence(tokens_tensors, batch_first=True)
    segments_tensors = pad_sequence(segments_tensors, batch_first=True)
    masks_tensors = pad_sequence(masks_tensors, batch_first=True)

    return tokens_tensors.squeeze(1), segments_tensors.squeeze(1), masks_tensors.squeeze(1), label_ids


# 初始化回傳訓練樣本的 DataLoader
# 利用 `collate_fn` 將 list of samples 合併成一個 mini-batch 

trainloader = DataLoader(trainset, batch_size=1,collate_fn=create_mini_batch, shuffle=True)

In [4]:
def create_mini_batch_test(samples):
    tokens_tensors = [s[0].squeeze(0) for s in samples]
    segments_tensors = [s[1].squeeze(0) for s in samples]
    masks_tensors = [s[2].squeeze(0) for s in samples]
    if samples[0][3] is not None:
        label_ids = torch.stack([s[3] for s in samples])
    else:
        label_ids = None
    # zero pad 到同一序列長度
    tokens_tensors = pad_sequence(tokens_tensors, batch_first=True)
    segments_tensors = pad_sequence(segments_tensors, batch_first=True)
    masks_tensors = pad_sequence(masks_tensors, batch_first=True)

    return tokens_tensors.squeeze(1), segments_tensors.squeeze(1), masks_tensors.squeeze(1), label_ids

In [5]:
def set_seed(seed):
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)

PRETRAINED_MODEL_NAME = "xlnet-base-cased"
model = XLNetForMultiSequenceClassification.from_pretrained(PRETRAINED_MODEL_NAME,
                                                            output_attentions=True,
                                                            dropout=0.1)

In [6]:
from torch.optim import AdamW

param_optimizer = list(model.named_parameters())
no_decay = ['bias', 'LayerNorm.weight']
optimizer_grouped_parameters = [
    {'params': [p for n, p in model.named_parameters() if not any(nd in n for nd in no_decay)], 'weight_decay': 0},
    {'params': [p for n, p in model.named_parameters() if any(nd in n for nd in no_decay)], 'weight_decay': 0}
]
optimizer = AdamW(optimizer_grouped_parameters, lr=2e-5, eps=1e-8)

In [7]:
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
print("device:", device)
model = model.to(device)

device: cuda:0


In [8]:
%%time
EPOCHS = 20
batch_size = 8
scheduler = get_linear_schedule_with_warmup(optimizer, num_warmup_steps=0, num_training_steps=EPOCHS * ((len(trainloader)//batch_size)))
epochs_trained = 0

model.zero_grad()
train_iterator = trange(epochs_trained, EPOCHS, desc="Epoch")
set_seed(42)

for _ in train_iterator:
    epoch_iterator = tqdm(trainloader, desc="Iteration")
    
    model.train()
    running_loss = 0.0
    batch_cnt = 1
    loss = torch.zeros(1).to(device)
    
    for step, data in enumerate(epoch_iterator):
        task = 0
        tokens_tensors, segments_tensors, masks_tensors, labels = [t.to(device) for t in data]
        # forward pass
        outputs = model(input_ids=tokens_tensors, 
                        token_type_ids=segments_tensors, 
                        attention_mask=masks_tensors, 
                        labels=labels,
                        task=task,
                       )
        batch_cnt += 1
        loss = outputs[0]/batch_size
        loss.backward()
        
        if batch_cnt >= batch_size:
            optimizer.step()
            scheduler.step()
            model.zero_grad()
            batch_cnt = 0

        # 紀錄當前 batch loss
        running_loss += loss.item()
    epochs_trained += 1
        
    testset = Dataset_3Way_test("RTE5_test", tokenizer=tokenizer)
    testloader = DataLoader(testset, batch_size=1, 
                     collate_fn=create_mini_batch_test)
    predictions = get_predictions(model, testloader)

    df_pred = pd.DataFrame({"label": predictions.tolist()})
        
    pred_Y = df_pred['label'].values
    test_Y = pd.read_csv("../data/RTE5_test.tsv", sep='\t').fillna("")['label'].values

    accuracy = accuracy_score(test_Y, np.array(pred_Y))
    precision = precision_score(test_Y, pred_Y, average='macro')
    recall = recall_score(test_Y, pred_Y, average='macro')
    fscore = f1_score(test_Y, pred_Y, average='macro')
    
    CNT = 0
    TOTAL = 0
    for i in range(len(test_Y)):
        if test_Y[i] == 2:
            TOTAL += 1
        else:
            pass
        if test_Y[i] == 2 and predictions[i] == 2:
            CNT += 1
    contra = round((CNT/TOTAL)*100,1)
    if contra > 20 and accuracy > 0.55:
        torch.save(model, "single_task_%g, %g.pkl" % (round(accuracy, 2), contra))
    print("Accuracy: %g\tPrecision: %g\tRecall: %g\tF-score: %g Loss: %g" % (accuracy, precision, recall, fscore, running_loss))
    print(contra)
    print("------------------------------------------")
        


HBox(children=(IntProgress(value=0, description='Epoch', max=20, style=ProgressStyle(description_width='initia…

HBox(children=(IntProgress(value=0, description='Iteration', max=500, style=ProgressStyle(description_width='i…




Accuracy: 0.503333	Precision: 0.667504	Recall: 0.338624	F-score: 0.233436 Loss: 65.5674
1.1
------------------------------------------


HBox(children=(IntProgress(value=0, description='Iteration', max=500, style=ProgressStyle(description_width='i…




  'precision', 'predicted', average, warn_for)
  'precision', 'predicted', average, warn_for)


Accuracy: 0.505	Precision: 0.390011	Recall: 0.338571	F-score: 0.235314 Loss: 60.944
0.0
------------------------------------------


HBox(children=(IntProgress(value=0, description='Iteration', max=500, style=ProgressStyle(description_width='i…


Accuracy: 0.501667	Precision: 0.351224	Recall: 0.40873	F-score: 0.361736 Loss: 56.8543
0.0
------------------------------------------


HBox(children=(IntProgress(value=0, description='Iteration', max=500, style=ProgressStyle(description_width='i…


Accuracy: 0.481667	Precision: 0.453751	Recall: 0.444709	F-score: 0.414771 Loss: 44.695
18.9
------------------------------------------


HBox(children=(IntProgress(value=0, description='Iteration', max=500, style=ProgressStyle(description_width='i…




  "type " + obj.__name__ + ". It won't be checked "
  "type " + obj.__name__ + ". It won't be checked "
  "type " + obj.__name__ + ". It won't be checked "
  "type " + obj.__name__ + ". It won't be checked "
  "type " + obj.__name__ + ". It won't be checked "
  "type " + obj.__name__ + ". It won't be checked "
  "type " + obj.__name__ + ". It won't be checked "
  "type " + obj.__name__ + ". It won't be checked "
  "type " + obj.__name__ + ". It won't be checked "
  "type " + obj.__name__ + ". It won't be checked "
  "type " + obj.__name__ + ". It won't be checked "
  "type " + obj.__name__ + ". It won't be checked "
  "type " + obj.__name__ + ". It won't be checked "
  "type " + obj.__name__ + ". It won't be checked "
  "type " + obj.__name__ + ". It won't be checked "
  "type " + obj.__name__ + ". It won't be checked "


Accuracy: 0.553333	Precision: 0.493669	Recall: 0.492434	F-score: 0.481744 Loss: 27.0861
21.1
------------------------------------------


HBox(children=(IntProgress(value=0, description='Iteration', max=500, style=ProgressStyle(description_width='i…


Accuracy: 0.585	Precision: 0.506329	Recall: 0.48619	F-score: 0.489724 Loss: 12.721
20.0
------------------------------------------


HBox(children=(IntProgress(value=0, description='Iteration', max=500, style=ProgressStyle(description_width='i…


Accuracy: 0.565	Precision: 0.468424	Recall: 0.460899	F-score: 0.458997 Loss: 6.10726
12.2
------------------------------------------


HBox(children=(IntProgress(value=0, description='Iteration', max=500, style=ProgressStyle(description_width='i…


Accuracy: 0.555	Precision: 0.487032	Recall: 0.484444	F-score: 0.48426 Loss: 2.1431
23.3
------------------------------------------


HBox(children=(IntProgress(value=0, description='Iteration', max=500, style=ProgressStyle(description_width='i…


Accuracy: 0.531667	Precision: 0.463452	Recall: 0.467354	F-score: 0.455438 Loss: 1.05001
17.8
------------------------------------------


HBox(children=(IntProgress(value=0, description='Iteration', max=500, style=ProgressStyle(description_width='i…


Accuracy: 0.546667	Precision: 0.48344	Recall: 0.475714	F-score: 0.477418 Loss: 2.37527
26.7
------------------------------------------


HBox(children=(IntProgress(value=0, description='Iteration', max=500, style=ProgressStyle(description_width='i…


Accuracy: 0.561667	Precision: 0.452897	Recall: 0.452804	F-score: 0.445315 Loss: 0.789703
8.9
------------------------------------------


HBox(children=(IntProgress(value=0, description='Iteration', max=500, style=ProgressStyle(description_width='i…


Accuracy: 0.588333	Precision: 0.510505	Recall: 0.500053	F-score: 0.502415 Loss: 0.383393
21.1
------------------------------------------


HBox(children=(IntProgress(value=0, description='Iteration', max=500, style=ProgressStyle(description_width='i…


Accuracy: 0.573333	Precision: 0.501387	Recall: 0.487672	F-score: 0.490783 Loss: 0.283634
21.1
------------------------------------------


HBox(children=(IntProgress(value=0, description='Iteration', max=500, style=ProgressStyle(description_width='i…


Accuracy: 0.543333	Precision: 0.475712	Recall: 0.475556	F-score: 0.470004 Loss: 0.307446
20.0
------------------------------------------


HBox(children=(IntProgress(value=0, description='Iteration', max=500, style=ProgressStyle(description_width='i…


Accuracy: 0.571667	Precision: 0.498618	Recall: 0.487513	F-score: 0.490365 Loss: 0.237164
21.1
------------------------------------------


HBox(children=(IntProgress(value=0, description='Iteration', max=500, style=ProgressStyle(description_width='i…


Accuracy: 0.56	Precision: 0.478269	Recall: 0.475291	F-score: 0.474112 Loss: 0.197562
17.8
------------------------------------------


HBox(children=(IntProgress(value=0, description='Iteration', max=500, style=ProgressStyle(description_width='i…


Accuracy: 0.57	Precision: 0.476429	Recall: 0.466349	F-score: 0.465496 Loss: 0.184482
13.3
------------------------------------------


HBox(children=(IntProgress(value=0, description='Iteration', max=500, style=ProgressStyle(description_width='i…


Accuracy: 0.586667	Precision: 0.524394	Recall: 0.506984	F-score: 0.510659 Loss: 0.186831
23.3
------------------------------------------


HBox(children=(IntProgress(value=0, description='Iteration', max=500, style=ProgressStyle(description_width='i…


Accuracy: 0.573333	Precision: 0.493103	Recall: 0.476825	F-score: 0.478406 Loss: 0.167453
16.7
------------------------------------------


HBox(children=(IntProgress(value=0, description='Iteration', max=500, style=ProgressStyle(description_width='i…


Accuracy: 0.558333	Precision: 0.486064	Recall: 0.475291	F-score: 0.478216 Loss: 0.180338
21.1
------------------------------------------

Wall time: 48min 32s
