In [1]:
import torch
# Locations of the three pickled dataset splits, relative to the working directory.
dir_train, dir_dev, dir_test = (
    'dataset/vlsp21/train.pkl',
    'dataset/vlsp21/dev.pkl',
    'dataset/vlsp21/test.pkl',
)
# Run on the GPU when PyTorch can see one, otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

In [2]:
import time
from transformers import BertTokenizer, BertConfig, BertModel
from TorchCRF import CRF
import torch.nn as nn
import torch
import warnings
warnings.filterwarnings("ignore")
import pickle

from utils.bert_loader import BERT_DATALOADER
from utils.bert_evaluate import BERT_EVALUATE
from utils.bert_predict import BERT_PREDICT
from utils.bert_visualize import BERT_VISUALIZE
from utils.bert_model import BERT_4_CRF, BERT_4_SOFTMAX

In [3]:
# Wall-clock reference used to report how long the three loads below take.
start = time.time()

# Tokenizer is loaded without lower-casing (do_lower_case=False).
tokenizer = BertTokenizer.from_pretrained(
    'mBert_Tokenizer',
    do_lower_case=False,
    use_fast=False,
)

# output_hidden_states=True asks the encoder to return every layer's hidden
# states in addition to the final one.
config = BertConfig.from_pretrained(
    'mBert_Config',
    output_hidden_states=True,
)
config.max_position_embeddings = 512

# The pooling head is dropped (add_pooling_layer=False).
bert_model = BertModel.from_pretrained(
    'mBert_Model',
    config=config,
    add_pooling_layer=False,
)
print("Loading: ", time.time() - start, " S")

file mBert_Tokenizer\config.json not found


Loading:  3.612210988998413  S


In [4]:
def read_dataset(dir_train):
    """Load and return the object stored in a pickle file.

    Args:
        dir_train: Path of the pickle file to read. Despite the name, the
            same function is used for every split.

    Returns:
        Whatever object was pickled into the file.

    Note:
        Unpickling can execute arbitrary code, so only pass files that come
        from a trusted source.
    """
    with open(dir_train, 'rb') as stream:
        return pickle.load(stream)
# Entity labels handed to BERT_DATALOADER and (prefixed with 'PAD') to BERT_EVALUATE.
# NOTE(review): the order presumably fixes the label indices - do not reorder
# without confirming against BERT_DATALOADER and any saved checkpoint.
tag = [
    'ADDRESS',
    'SKILL',
    'EMAIL',
    'PERSON',
    'PHONENUMBER',
    'MISCELLANEOUS',
    'QUANTITY',
    'PERSONTYPE',
    'ORGANIZATION',
    'PRODUCT',
    'IP',
    'LOCATION',
    'O',
    'DATETIME',
    'EVENT',
    'URL',
]

In [5]:
def _load_split(path):
    """Read one pickled split, wrap it in BERT_DATALOADER and build its dataloader.

    Returns the raw data, the BERT_DATALOADER instance and the dataloader, in
    that order.
    """
    raw = read_dataset(path)
    wrapper = BERT_DATALOADER(raw, tokenizer, tag, device)
    return raw, wrapper, wrapper.create_dataloader()


# Splits are processed in the same order as before: train, then dev, then test.
data_train, TRAIN_SET, train_dataloader = _load_split(dir_train)
data_dev, DEV_SET, dev_dataloader = _load_split(dir_dev)
data_test, TEST_SET, test_dataloader = _load_split(dir_test)

In [6]:
import torch
import torch.nn as nn
import torch.nn.functional as F

class FocalLoss(nn.Module):
    """Multi-class focal loss.

    Each class log-probability is scaled by ``(1 - p) ** gamma`` before the
    negative log-likelihood is taken, which down-weights examples the model
    already classifies with high probability.

    Args:
        gamma: Exponent of the ``(1 - p)`` modulating factor.
        weight: Optional per-class weights forwarded to ``F.nll_loss``.
        ignore_index: Target value that is excluded from the loss.
    """

    def __init__(self, gamma=2, weight=None, ignore_index=-100):
        super().__init__()
        self.gamma = gamma
        self.weight = weight
        self.ignore_index = ignore_index

    def forward(self, input, target):
        """Compute the focal loss.

        Args:
            input: Logits of shape ``[N, C]``.
            target: Class indices of shape ``[N]``.

        Returns:
            The loss reduced with ``F.nll_loss``'s default reduction.
        """
        log_probs = F.log_softmax(input, dim=1)
        probs = log_probs.exp()
        # Modulating factor: close to 0 for confident predictions, close to 1 otherwise.
        modulation = (1 - probs) ** self.gamma
        focal_log_probs = modulation * log_probs
        return F.nll_loss(
            focal_log_probs,
            target,
            weight=self.weight,
            ignore_index=self.ignore_index,
        )


class LabelSmoothingCrossEntropy(nn.Module):
    """Cross-entropy with uniform label smoothing.

    The loss is ``eps / C * smooth + (1 - eps) * nll`` where ``C`` is the
    number of classes, ``smooth`` is the negative sum of the log-probabilities
    over all classes and ``nll`` is the negative log-likelihood of the target.

    Rows whose target equals ``ignore_index`` are excluded from BOTH terms.
    Previously only the ``nll`` term skipped them, so ignored (e.g. padded)
    positions still contributed to the smoothing term.

    Args:
        eps: Smoothing strength.
        reduction: ``'mean'``, ``'sum'`` or ``'none'`` (as for ``F.nll_loss``).
        ignore_index: Target value that is excluded from the loss.
    """

    def __init__(self, eps=0.1, reduction='mean', ignore_index=-100):
        super(LabelSmoothingCrossEntropy, self).__init__()
        self.eps = eps
        self.reduction = reduction
        self.ignore_index = ignore_index

    def forward(self, output, target):
        """Compute the smoothed loss.

        Args:
            output: Logits of shape ``[N, C]``.
            target: Class indices of shape ``[N]``.

        Returns:
            A scalar for ``'mean'`` / ``'sum'``, or a ``[N]`` tensor for
            ``'none'`` in which ignored rows are 0. With ``'mean'`` and every
            row ignored the result is NaN, matching ``F.nll_loss``.
        """
        c = output.size()[-1]
        log_preds = F.log_softmax(output, dim=-1)
        keep = target != self.ignore_index
        smooth = -log_preds.sum(dim=-1)
        if self.reduction == 'sum':
            smooth = smooth[keep].sum()
        elif self.reduction == 'mean':
            # Average over kept rows only, the same denominator nll_loss uses.
            smooth = smooth[keep].mean()
        else:
            # Keep the per-row shape; masked_fill (not multiplication) so that
            # an ignored row holding -inf cannot turn into NaN.
            smooth = smooth.masked_fill(~keep, 0.0)
        nll = F.nll_loss(log_preds, target, reduction=self.reduction,
                         ignore_index=self.ignore_index)
        return smooth * self.eps / c + (1 - self.eps) * nll

In [7]:
# Define bert 4 layer
class BERT_LSTM_SOFTMAX(nn.Module):
    """Token classifier: encoder hidden states -> linear head -> LSTM -> linear.

    The last four entries of the encoder's second output are each passed
    through ``dropout_1`` and ``classifier_1``; the four results are
    concatenated, fed to an LSTM, concatenated with the last hidden-state
    entry, and classified.

    NOTE(review): ``classifier_2..4`` and ``dropout_2..4`` are registered (so
    they appear in the state dict) but ``forward_custom`` never calls them.
    This may be a copy-paste slip; it is kept as-is because changing it would
    alter the outputs of existing checkpoints.
    NOTE(review): the LSTM is built without ``batch_first=True`` and with
    ``num_layers = 4 * num_labels``. If inputs are (batch, seq_len, features),
    as the logits comment suggests, the recurrence runs over the first axis -
    confirm this is intended.
    """

    def __init__(self, bert_model, num_labels):
        super().__init__()
        self.num_labels = num_labels
        self.bert = bert_model
        # Created in the order classifier_i, dropout_i so parameter names and
        # initialisation order stay exactly as before.
        for head in (1, 2, 3, 4):
            setattr(self, f"classifier_{head}", nn.Linear(768, num_labels))
            setattr(self, f"dropout_{head}", nn.Dropout(0.2))
        # Final head consumes LSTM output (256) concatenated with a 768-wide hidden state.
        self.classifier = nn.Linear(256 + 768, num_labels)
        self.dropout = nn.Dropout(0.25)

        self.lstm = nn.LSTM(
            input_size=4 * num_labels,
            hidden_size=256,
            num_layers=4 * num_labels,
            dropout=0.2,
        )

    def forward_custom(self, input_ids, attention_mask=None,
                       head_mask=None, labels=None):
        """Run the model and optionally compute the focal loss.

        Args:
            input_ids: Token ids passed to the encoder.
            attention_mask: Optional mask passed to the encoder; when given,
                only positions where it equals 1 contribute to the loss.
            head_mask: Accepted for interface compatibility; not used.
            labels: Optional target label ids. Label 0 is ignored by the loss.

        Returns:
            ``(logits, *extra)`` without labels, ``(loss, logits, *extra)``
            with labels, where ``extra`` is the encoder output from index 2 on.
        """
        encoder_out = self.bert(input_ids=input_ids, attention_mask=attention_mask)
        # assumes index 1 holds the per-layer hidden states (pooler disabled,
        # output_hidden_states enabled) - TODO confirm for the installed transformers version
        layer_states = encoder_out[1]
        top_layer = layer_states[-1]

        # Same dropout and linear head applied to each of the last four layers.
        head_logits = [
            self.classifier_1(self.dropout_1(layer_states[-depth]))
            for depth in (1, 2, 3, 4)
        ]

        recurrent_out, _ = self.lstm(torch.cat(head_logits, -1))
        features = self.dropout(torch.cat((recurrent_out, top_layer), -1))
        logits = self.classifier(features)  # bsz, seq_len, num_labels
        result = (logits,) + encoder_out[2:]  # add hidden states and attention if they are here

        if labels is None:
            return result

        criterion = FocalLoss(ignore_index=0)
        flat_logits = logits.view(-1, self.num_labels)
        flat_labels = labels.view(-1)
        if attention_mask is None:
            loss = criterion(flat_logits, flat_labels)
        else:
            keep = attention_mask.view(-1) == 1
            loss = criterion(flat_logits[keep], flat_labels[keep])
        return (loss,) + result

In [8]:
# Checkpoint holding the fine-tuned weights (a state dict, as it is passed
# straight to load_state_dict below).
PATH = 'model/bert_lstm_softmax_ner.pt'
model = BERT_LSTM_SOFTMAX(bert_model, num_labels=len(TRAIN_SET.tag2idx))

# map_location remaps tensors that were saved from another device onto
# `device`; without it a checkpoint written on a GPU fails to load on a
# CPU-only machine. torch.load unpickles the file, so only load checkpoints
# that come from a trusted source.
state_dict = torch.load(PATH, map_location=device)

# strict=False tolerates key mismatches; report them so that layers left at
# their random initialisation do not go unnoticed.
load_report = model.load_state_dict(state_dict, strict=False)
if load_report.missing_keys or load_report.unexpected_keys:
    print("Checkpoint key mismatch - missing:", load_report.missing_keys,
          "unexpected:", load_report.unexpected_keys)
model.to(device)

BERT_LSTM_SOFTMAX(
  (bert): BertModel(
    (embeddings): BertEmbeddings(
      (word_embeddings): Embedding(119547, 768, padding_idx=0)
      (position_embeddings): Embedding(512, 768)
      (token_type_embeddings): Embedding(2, 768)
      (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
      (dropout): Dropout(p=0.1, inplace=False)
    )
    (encoder): BertEncoder(
      (layer): ModuleList(
        (0): BertLayer(
          (attention): BertAttention(
            (self): BertSelfAttention(
              (query): Linear(in_features=768, out_features=768, bias=True)
              (key): Linear(in_features=768, out_features=768, bias=True)
              (value): Linear(in_features=768, out_features=768, bias=True)
              (dropout): Dropout(p=0.1, inplace=False)
            )
            (output): BertSelfOutput(
              (dense): Linear(in_features=768, out_features=768, bias=True)
              (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine

In [11]:
# Evaluate on the dev split. 'PAD' is prepended to the tag list.
# NOTE(review): presumably this aligns tag_values with label index 0 being
# padding - confirm against BERT_DATALOADER.
BERT_EVALUATE(
    model=model,
    tokenizer=tokenizer,
    dataloader=dev_dataloader,
    tag_values=['PAD'] + tag,
    device=device,
    type_dataset='dev',
    model_type='softmax',
)

Progress Bar: 100%|████████████████████████████████████████████████████████████████████| 81/81 [01:55<00:00,  1.43s/it]


Validation loss: 0.14295242145013662
Validation F1-Score: 0.6949559409847702
               precision    recall  f1-score   support

      ADDRESS     0.3226    0.4444    0.3738        45
       PERSON     0.8987    0.9580    0.9274      4428
  PHONENUMBER     0.6190    1.0000    0.7647        26
MISCELLANEOUS     0.4049    0.5038    0.4490       262
     QUANTITY     0.6110    0.8175    0.6993      2263
   PERSONTYPE     0.4655    0.8238    0.5949      1016
 ORGANIZATION     0.7960    0.8320    0.8136      5440
      PRODUCT     0.4539    0.5383    0.4925      1553
     LOCATION     0.7736    0.9077    0.8353      4313
            O     0.9840    0.9486    0.9659    124694
     DATETIME     0.5847    0.8720    0.7000      1852
        EVENT     0.4531    0.6259    0.5257       802
          URL     0.8056    1.0000    0.8923        87

    micro avg     0.9327    0.9325    0.9326    146781
    macro avg     0.6287    0.7902    0.6950    146781
 weighted avg     0.9439    0.9325    0.9

In [12]:
# Evaluate on the test split with the same settings as the dev run.
# NOTE(review): 'PAD' is prepended to the tag list, presumably so that
# tag_values lines up with label index 0 being padding - confirm against
# BERT_DATALOADER.
BERT_EVALUATE(
    model=model,
    tokenizer=tokenizer,
    dataloader=test_dataloader,
    tag_values=['PAD'] + tag,
    device=device,
    type_dataset='test',
    model_type='softmax',
)

Progress Bar: 100%|██████████████████████████████████████████████████████████████████| 133/133 [03:12<00:00,  1.44s/it]


Validation loss: 0.12528892082715393
Validation F1-Score: 0.6476554072847293
               precision    recall  f1-score   support

      ADDRESS     0.1296    0.3333    0.1867        21
        EMAIL     0.5882    1.0000    0.7407        10
       PERSON     0.9329    0.9482    0.9405      9790
  PHONENUMBER     0.5000    0.6250    0.5556         8
MISCELLANEOUS     0.1442    0.2997    0.1947       297
     QUANTITY     0.8109    0.7851    0.7978      7739
   PERSONTYPE     0.5709    0.4401    0.4970      4183
 ORGANIZATION     0.7690    0.7670    0.7680      8808
      PRODUCT     0.4732    0.4833    0.4782      4329
     LOCATION     0.7396    0.8900    0.8079      4590
            O     0.9755    0.9662    0.9708    257459
     DATETIME     0.8229    0.9512    0.8824      4689
        EVENT     0.2200    0.6901    0.3336       568
          URL     0.8646    0.9677    0.9133       495

    micro avg     0.9388    0.9385    0.9386    302986
    macro avg     0.6101    0.7248    0.6