In [1]:
from pytorch_transformers import WEIGHTS_NAME, AdamW, BertConfig, BertTokenizer, WarmupLinearSchedule
from torch import nn
from torch.utils.data import DataLoader, RandomSampler, SequentialSampler, TensorDataset
import numpy as np
import logging, os, sys, torch
import torch.nn.functional as F
from tqdm import tqdm, trange
from seqeval.metrics import classification_report
from processor import *
from model import *

In [2]:
raw_data_path = 'data/'
model_path = 'model'
max_seq_length = 128
do_train = False
do_eval = True
train_batch_size = 32
eval_batch_size = 8
num_train_epochs = 3
max_grad_norm = 1
gradient_accumulation_steps = 1

device = torch.device("cuda")
if not os.path.exists(model_path):
    os.makedirs(model_path)

processor = DataProcessor()
label_list = ["O", "B-MISC", "I-MISC",  "B-PER", "I-PER", "B-ORG", "I-ORG", "B-LOC", "I-LOC", "[CLS]", "[SEP]"]
label_map = {i : label for i, label in enumerate(label_list,1)}
tokenizer = BertTokenizer.from_pretrained('bert-base-cased', do_lower_case=False)

config = BertConfig.from_pretrained('bert-base-cased', num_labels=len(label_list) + 1, finetuning_task='ner')
model = Ner.from_pretrained('bert-base-cased', from_tf = False, config = config)

model.to(device)

param_optimizer = list(model.named_parameters())
no_decay = ['bias','LayerNorm.weight']
optimizer_grouped_parameters = [
    {'params': [p for n, p in param_optimizer if not any(nd in n for nd in no_decay)], 'weight_decay': 0.01},
    {'params': [p for n, p in param_optimizer if any(nd in n for nd in no_decay)], 'weight_decay': 0.0}        ]
optimizer = AdamW(optimizer_grouped_parameters, lr=5e-5, eps=1e-8)
scheduler = WarmupLinearSchedule(optimizer, warmup_steps=131, t_total=1314)

In [4]:
if do_train:
    train_features = from_raw_to_feature(processor.get_train_sample(raw_data_path), label_list, max_seq_length, tokenizer)
    train_data = TensorDataset(torch.tensor([f.token_id for f in train_features], dtype=torch.long), 
                                torch.tensor([f.token_mask for f in train_features], dtype=torch.long), 
                                torch.tensor([f.label_id for f in train_features], dtype=torch.long),
                                torch.tensor([f.valid_id for f in train_features], dtype=torch.long),
                                torch.tensor([f.label_mask for f in train_features], dtype=torch.long))
        
    train_dataloader = DataLoader(train_data, sampler=RandomSampler(train_data), batch_size=train_batch_size)

    model.train()
    for _ in trange(int(num_train_epochs), desc="Epoch"):
        for step, batch in enumerate(tqdm(train_dataloader, desc="Iteration")):
            batch = tuple(t.to(device) for t in batch)
            token_id, token_mask, segment_id, label_id, valid_id,l_mask = batch
            loss = model(token_id, segment_id, token_mask, label_id,valid_id,l_mask)

            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), max_grad_norm)

            optimizer.step()
            scheduler.step()
            model.zero_grad()

    model.save_pretrained(model_path)
    tokenizer.save_pretrained(model_path)
    label_map = {i : label for i, label in enumerate(label_list,1)}
else:
    model = Ner.from_pretrained(model_path)
    tokenizer = BertTokenizer.from_pretrained(model_path, do_lower_case=False)

model.to(device)

Ner(
  (bert): BertModel(
    (embeddings): BertEmbeddings(
      (word_embeddings): Embedding(28996, 768, padding_idx=0)
      (position_embeddings): Embedding(512, 768)
      (token_type_embeddings): Embedding(2, 768)
      (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
      (dropout): Dropout(p=0.1, inplace=False)
    )
    (encoder): BertEncoder(
      (layer): ModuleList(
        (0): BertLayer(
          (attention): BertAttention(
            (self): BertSelfAttention(
              (query): Linear(in_features=768, out_features=768, bias=True)
              (key): Linear(in_features=768, out_features=768, bias=True)
              (value): Linear(in_features=768, out_features=768, bias=True)
              (dropout): Dropout(p=0.1, inplace=False)
            )
            (output): BertSelfOutput(
              (dense): Linear(in_features=768, out_features=768, bias=True)
              (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
        

In [5]:
if do_eval:
    eval_features = from_raw_to_feature(processor.get_dev_sample(raw_data_path), label_list, max_seq_length, tokenizer)
    eval_data = TensorDataset(torch.tensor([f.token_id for f in eval_features], dtype=torch.long), 
                                torch.tensor([f.token_mask for f in eval_features], dtype=torch.long), 
                                torch.tensor([f.segment_id for f in eval_features], dtype=torch.long), 
                                torch.tensor([f.label_id for f in eval_features], dtype=torch.long), 
                                torch.tensor([f.valid_id for f in eval_features], dtype=torch.long), 
                                torch.tensor([f.label_mask for f in eval_features], dtype=torch.long))
        
    eval_dataloader = DataLoader(eval_data, sampler=SequentialSampler(eval_data), batch_size=eval_batch_size)
    model.eval()
    y_true = []
    y_pred = []
    label_map = {i : label for i, label in enumerate(label_list,1)}
    for token_id, token_mask, segment_id, label_id, valid_id, l_mask in tqdm(eval_dataloader, desc="Evaluating"):
        token_id = token_id.to(device)
        token_mask = token_mask.to(device)
        segment_id = segment_id.to(device)
        valid_id = valid_id.to(device)
        label_id = label_id.to(device)
        l_mask = l_mask.to(device)

        with torch.no_grad():
            network_output = model(token_id, segment_id, token_mask,valid_id=valid_id,attention_mask_label=l_mask)

        network_output = torch.argmax(F.log_softmax(network_output,dim=2),dim=2)
        network_output = network_output.detach().cpu().numpy()
        label_id = label_id.to('cpu').numpy()
        token_mask = token_mask.to('cpu').numpy()

        for i, label in enumerate(label_id):
            sentence_label_true = []
            sentence_label_pred = []
            for j,m in enumerate(label):
                if j == 0:
                    continue
                elif label_id[i][j] == len(label_map):
                    y_true.append(sentence_label_true)
                    y_pred.append(sentence_label_pred)
                    break
                else:
                    sentence_label_true.append(label_map[label_id[i][j]])
                    sentence_label_pred.append(label_map[network_output[i][j]])

    report = classification_report(y_true, y_pred,digits=4)
    phase_eval('PER', y_pred, y_true)
    phase_eval('ORG', y_pred, y_true)
    phase_eval('LOC', y_pred, y_true)
    phase_eval('MISC', y_pred, y_true)

    logging.basicConfig(level = logging.INFO)
    logger = logging.getLogger(__name__)
    logger.info("\n%s", report)

Evaluating: 100%|██████████| 407/407 [00:31<00:00, 12.93it/s]
  _warn_prf(average, modifier, msg_start, len(result))
INFO:__main__:
              precision    recall  f1-score   support

         LOC     0.9657    0.9657    0.9657      1837
        MISC     0.8937    0.9121    0.9028       922
         ORG     0.9238    0.9306    0.9272      1341
         PER     0.9651    0.9750    0.9700      1842
        SEP]     0.0000    0.0000    0.0000         0

   micro avg     0.9441    0.9524    0.9482      5942
   macro avg     0.7497    0.7567    0.7532      5942
weighted avg     0.9449    0.9524    0.9486      5942



PER by ture/total=1750/1805=0.9695290858725761
ORG by ture/total=1236/1267=0.9755327545382794
LOC by ture/total=1727/1777=0.971862689926843
MISC by ture/total=819/859=0.9534342258440046
