In [1]:
%load_ext autoreload
%autoreload 2

import os
os.environ['CUDA_VISIBLE_DEVICES'] = '3'

import sys

In [2]:
import logging
import logging.handlers  # `import logging` alone does not load the handlers submodule
import sys

# Logger the training code reports to (the run output below shows
# "sequence_tagger_bert - INFO - Train loss ..." records under this name).
logger = logging.getLogger('sequence_tagger_bert')

formatter = logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s')

# Detach AND close handlers left over from a previous execution of this cell,
# so re-running it neither duplicates output nor leaks open handles on logs.txt.
for old_handler in list(logger.handlers):
    logger.removeHandler(old_handler)
    old_handler.close()

# Persistent log file, rotated at midnight.
fhandler = logging.handlers.TimedRotatingFileHandler(filename='logs.txt', when='midnight')
fhandler.setFormatter(formatter)
logger.addHandler(fhandler)

# Mirror every record to the notebook's stdout.
handler = logging.StreamHandler(sys.stdout)
handler.setFormatter(formatter)
logger.addHandler(handler)

logger.setLevel(logging.DEBUG)

# Stop records from also reaching ancestor (root) handlers: without this every
# message is printed twice, once in this format and once in the root's format.
logger.propagate = False

In [3]:
import torch

# Target CUDA device and the number of GPUs visible to this process.
device = torch.device('cuda')
n_gpu = torch.cuda.device_count()

# Sanity check: print the name of each visible GPU.
for gpu_idx in range(n_gpu):
    print(torch.cuda.get_device_name(gpu_idx))

Tesla V100-DGXS-16GB


In [4]:
# Run configuration. Where a value was changed for this run, the earlier
# setting is noted as "was ..." for reference.
CACHE_DIR = '../workdir/cache'

BATCH_SIZE = 8              # training batch size (was 16)
PRED_BATCH_SIZE = 1000      # batch size of the validation/test dataloaders
MAX_LEN = 128               # passed to SequenceTaggerBert as max_len
MAX_N_EPOCHS = 100          # upper bound on training epochs (was 4)
REDUCE_ON_PLATEAU = False
WEIGHT_DECAY = 0.01
LEARNING_RATE = 1e-5        # passed as lr_body to create_optimizer (was 3e-5)

In [5]:
from flair.datasets import ColumnCorpus


# CoNLL-2003 (English) in column format: column 0 is the token text,
# column 3 is the NER tag.
data_folder = '../workdir/conll2003/eng'
conll_columns = {0: 'text', 3: 'ner'}

corpus = ColumnCorpus(
    data_folder,
    conll_columns,
    train_file='train.txt',
    dev_file='dev.txt',
    test_file='test.txt',
)

# Per-split document/token statistics, printed as JSON.
print(corpus.obtain_statistics())

2019-09-26 21:45:56,524 Reading data from ../workdir/conll2003/eng
2019-09-26 21:45:56,524 Train: ../workdir/conll2003/eng/train.txt
2019-09-26 21:45:56,525 Dev: ../workdir/conll2003/eng/dev.txt
2019-09-26 21:45:56,525 Test: ../workdir/conll2003/eng/test.txt
{
    "TRAIN": {
        "dataset": "TRAIN",
        "total_number_of_documents": 14987,
        "number_of_documents_per_class": {},
        "number_of_tokens_per_tag": {},
        "number_of_tokens": {
            "total": 204567,
            "min": 1,
            "max": 113,
            "avg": 13.649629679055181
        }
    },
    "TEST": {
        "dataset": "TEST",
        "total_number_of_documents": 3684,
        "number_of_documents_per_class": {},
        "number_of_tokens_per_tag": {},
        "number_of_tokens": {
            "total": 46666,
            "min": 1,
            "max": 124,
            "avg": 12.667209554831704
        }
    },
    "DEV": {
        "dataset": "DEV",
        "total_number_of_documents": 346

In [None]:
from bert_sequence_tagger import SequenceTaggerBert, BertForTokenClassificationCustom, create_optimizer
from pytorch_transformers import BertTokenizer, BertForTokenClassification

from bert_sequence_tagger.bert_utils import make_bert_tag_dict_from_flair_corpus


# Tokenizer and weights come from the same cased checkpoint, hence no lowercasing.
bert_checkpoint = 'bert-base-cased'

bpe_tokenizer = BertTokenizer.from_pretrained(bert_checkpoint,
                                              cache_dir=CACHE_DIR,
                                              do_lower_case=False)

# Index <-> tag mappings derived from the flair corpus.
idx2tag, tag2idx = make_bert_tag_dict_from_flair_corpus(corpus)

# Custom token-classification head, moved to the GPU.
# Alternative (stock head): BertForTokenClassification.from_pretrained(...) with the same arguments.
model = (BertForTokenClassificationCustom
         .from_pretrained(bert_checkpoint, cache_dir=CACHE_DIR, num_labels=len(tag2idx))
         .cuda())

seq_tagger = SequenceTaggerBert(bert_model=model,
                                bpe_tokenizer=bpe_tokenizer,
                                idx2tag=idx2tag,
                                tag2idx=tag2idx,
                                max_len=MAX_LEN)

  _np_qint8 = np.dtype([("qint8", np.int8, 1)])
  _np_quint8 = np.dtype([("quint8", np.uint8, 1)])
  _np_qint16 = np.dtype([("qint16", np.int16, 1)])
  _np_quint16 = np.dtype([("quint16", np.uint16, 1)])
  _np_qint32 = np.dtype([("qint32", np.int32, 1)])
  np_resource = np.dtype([("resource", np.ubyte, 1)])


In [None]:
from torch.utils.data.dataset import Dataset
from torch.utils.data import TensorDataset, DataLoader, RandomSampler, SequentialSampler

from bert_sequence_tagger.bert_utils import prepare_flair_corpus
from bert_sequence_tagger.model_trainer_bert import ModelTrainerBert


# Early-stopping patience handed to ModelTrainerBert (presumably the number of
# epochs without validation improvement before stopping — confirm in the trainer).
PATIENCE = 2


def collate_fn(batch):
    """Transpose a batch of per-example tuples into a tuple of per-field tuples.

    E.g. [(a1, b1), (a2, b2)] -> ((a1, a2), (b1, b2)).
    """
    return tuple(zip(*batch))


train_dataset = prepare_flair_corpus(corpus.train)
train_sampler = RandomSampler(train_dataset)
train_dataloader = DataLoader(train_dataset,
                              sampler=train_sampler,
                              batch_size=BATCH_SIZE,
                              collate_fn=collate_fn)

val_dataset = prepare_flair_corpus(corpus.dev)
val_sampler = SequentialSampler(val_dataset)
val_dataloader = DataLoader(val_dataset,
                            sampler=val_sampler,
                            batch_size=PRED_BATCH_SIZE,
                            collate_fn=collate_fn)

# Total number of training batches over the whole run. len(train_dataloader)
# is the exact batches-per-epoch count (the final partial batch is kept), whereas
# len(corpus.train) / BATCH_SIZE is a fractional value that undercounts it.
n_train_steps = len(train_dataloader) * MAX_N_EPOCHS

optimizer = create_optimizer(model, full_finetuning=True,
                             weight_decay=WEIGHT_DECAY,
                             lr_body=LEARNING_RATE,
                             t_total=n_train_steps)

trainer = ModelTrainerBert(model=seq_tagger,
                           optimizer=optimizer,
                           train_dataloader=train_dataloader,
                           val_dataloader=val_dataloader,
                           patience=PATIENCE,
                           reduce_on_plateau=REDUCE_ON_PLATEAU,  # honour the config cell
                           number_of_steps=n_train_steps)

trainer.train(epochs=MAX_N_EPOCHS)

Epoch:   0%|          | 0/100 [00:00<?, ?it/s]

2019-09-26 21:50:49,117 - sequence_tagger_bert - INFO - Train loss: 0.8786812842545654
2019-09-26 21:50:49,117 Train loss: 0.8786812842545654
2019-09-26 21:50:58,612 - sequence_tagger_bert - INFO - Validation loss: 0.16806642711162567
2019-09-26 21:50:58,612 Validation loss: 0.16806642711162567
2019-09-26 21:50:58,613 - sequence_tagger_bert - INFO - Validation F1-Score: (0.6880476393024244, 0.7446436115637636)
2019-09-26 21:50:58,613 Validation F1-Score: (0.6880476393024244, 0.7446436115637636)


Epoch:   1%|          | 1/100 [04:43<7:48:27, 283.91s/it]

2019-09-26 21:55:33,984 - sequence_tagger_bert - INFO - Train loss: 0.11844969308136374
2019-09-26 21:55:33,984 Train loss: 0.11844969308136374
2019-09-26 21:55:43,250 - sequence_tagger_bert - INFO - Validation loss: 0.06370066478848457
2019-09-26 21:55:43,250 Validation loss: 0.06370066478848457
2019-09-26 21:55:43,252 - sequence_tagger_bert - INFO - Validation F1-Score: (0.8778575004171534, 0.9081668419207851)
2019-09-26 21:55:43,252 Validation F1-Score: (0.8778575004171534, 0.9081668419207851)


Epoch:   2%|▏         | 2/100 [09:28<7:44:04, 284.13s/it]

2019-09-26 22:00:09,026 - sequence_tagger_bert - INFO - Train loss: 0.06332421275229641
2019-09-26 22:00:09,026 Train loss: 0.06332421275229641
2019-09-26 22:00:18,244 - sequence_tagger_bert - INFO - Validation loss: 0.04755143076181412
2019-09-26 22:00:18,244 Validation loss: 0.04755143076181412
2019-09-26 22:00:18,246 - sequence_tagger_bert - INFO - Validation F1-Score: (0.9159066186929964, 0.9388895355604702)
2019-09-26 22:00:18,246 Validation F1-Score: (0.9159066186929964, 0.9388895355604702)


Epoch:   3%|▎         | 3/100 [14:03<7:34:54, 281.39s/it]

2019-09-26 22:04:41,957 - sequence_tagger_bert - INFO - Train loss: 0.043285651388034495
2019-09-26 22:04:41,957 Train loss: 0.043285651388034495
2019-09-26 22:04:51,073 - sequence_tagger_bert - INFO - Validation loss: 0.038951643742620945
2019-09-26 22:04:51,073 Validation loss: 0.038951643742620945
2019-09-26 22:04:51,074 - sequence_tagger_bert - INFO - Validation F1-Score: (0.9315344004022459, 0.9501452643811737)
2019-09-26 22:04:51,074 Validation F1-Score: (0.9315344004022459, 0.9501452643811737)


Epoch:   4%|▍         | 4/100 [18:36<7:26:06, 278.82s/it]

2019-09-26 22:09:23,211 - sequence_tagger_bert - INFO - Train loss: 0.029549416832729023
2019-09-26 22:09:23,211 Train loss: 0.029549416832729023
2019-09-26 22:09:32,482 - sequence_tagger_bert - INFO - Validation loss: 0.03621016349643469
2019-09-26 22:09:32,482 Validation loss: 0.03621016349643469
2019-09-26 22:09:32,484 - sequence_tagger_bert - INFO - Validation F1-Score: (0.9390560818174196, 0.9563752541388324)
2019-09-26 22:09:32,484 Validation F1-Score: (0.9390560818174196, 0.9563752541388324)


Epoch:   5%|▌         | 5/100 [23:17<7:22:41, 279.60s/it]

2019-09-26 22:14:04,638 - sequence_tagger_bert - INFO - Train loss: 0.021122847122169663
2019-09-26 22:14:04,638 Train loss: 0.021122847122169663
2019-09-26 22:14:13,775 - sequence_tagger_bert - INFO - Validation loss: 0.03441258333623409
2019-09-26 22:14:13,775 Validation loss: 0.03441258333623409
2019-09-26 22:14:13,777 - sequence_tagger_bert - INFO - Validation F1-Score: (0.9458641560188299, 0.9590861405758246)
2019-09-26 22:14:13,777 Validation F1-Score: (0.9458641560188299, 0.9590861405758246)


Epoch:   6%|▌         | 6/100 [27:59<7:18:50, 280.11s/it]

2019-09-26 22:18:52,349 - sequence_tagger_bert - INFO - Train loss: 0.016115604575272863
2019-09-26 22:18:52,349 Train loss: 0.016115604575272863
2019-09-26 22:19:01,461 - sequence_tagger_bert - INFO - Validation loss: 0.03568050405010581
2019-09-26 22:19:01,461 Validation loss: 0.03568050405010581
2019-09-26 22:19:01,463 - sequence_tagger_bert - INFO - Validation F1-Score: (0.9454514937898624, 0.9601209161725381)
2019-09-26 22:19:01,463 Validation F1-Score: (0.9454514937898624, 0.9601209161725381)


Epoch:   7%|▋         | 7/100 [32:46<7:17:40, 282.38s/it]

2019-09-26 22:23:35,784 - sequence_tagger_bert - INFO - Train loss: 0.012142518383372842
2019-09-26 22:23:35,784 Train loss: 0.012142518383372842
2019-09-26 22:23:44,954 - sequence_tagger_bert - INFO - Validation loss: 0.04033001232892275
2019-09-26 22:23:44,954 Validation loss: 0.04033001232892275
2019-09-26 22:23:44,956 - sequence_tagger_bert - INFO - Validation F1-Score: (0.9447852760736196, 0.9585323238206174)
2019-09-26 22:23:44,956 Validation F1-Score: (0.9447852760736196, 0.9585323238206174)


Epoch:   8%|▊         | 8/100 [37:30<7:13:29, 282.71s/it]

2019-09-26 22:28:16,775 - sequence_tagger_bert - INFO - Train loss: 0.009664351197577093
2019-09-26 22:28:16,775 Train loss: 0.009664351197577093
2019-09-26 22:28:25,943 - sequence_tagger_bert - INFO - Validation loss: 0.042247312143445015
2019-09-26 22:28:25,943 Validation loss: 0.042247312143445015
2019-09-26 22:28:25,945 - sequence_tagger_bert - INFO - Validation F1-Score: (0.9427275780987573, 0.9585846867749419)
2019-09-26 22:28:25,945 Validation F1-Score: (0.9427275780987573, 0.9585846867749419)


Epoch:   9%|▉         | 9/100 [42:11<7:07:59, 282.19s/it]

2019-09-26 22:32:56,639 - sequence_tagger_bert - INFO - Train loss: 0.008162430261471522
2019-09-26 22:32:56,639 Train loss: 0.008162430261471522
2019-09-26 22:33:06,393 - sequence_tagger_bert - INFO - Validation loss: 0.038676043041050434
2019-09-26 22:33:06,393 Validation loss: 0.038676043041050434
2019-09-26 22:33:06,395 - sequence_tagger_bert - INFO - Validation F1-Score: (0.9482412060301508, 0.9627696590118302)
2019-09-26 22:33:06,395 Validation F1-Score: (0.9482412060301508, 0.9627696590118302)


Epoch:  10%|█         | 10/100 [46:51<7:02:30, 281.67s/it]

2019-09-26 22:37:39,139 - sequence_tagger_bert - INFO - Train loss: 0.006518807236741245
2019-09-26 22:37:39,139 Train loss: 0.006518807236741245
2019-09-26 22:37:48,234 - sequence_tagger_bert - INFO - Validation loss: 0.04137409431859851
2019-09-26 22:37:48,234 Validation loss: 0.04137409431859851
2019-09-26 22:37:48,236 - sequence_tagger_bert - INFO - Validation F1-Score: (0.9498069498069499, 0.9616793626795371)
2019-09-26 22:37:48,236 Validation F1-Score: (0.9498069498069499, 0.9616793626795371)


Epoch:  11%|█         | 11/100 [51:33<6:57:53, 281.72s/it]

2019-09-26 22:42:14,409 - sequence_tagger_bert - INFO - Train loss: 0.004637632620961292
2019-09-26 22:42:14,409 Train loss: 0.004637632620961292
2019-09-26 22:42:23,606 - sequence_tagger_bert - INFO - Validation loss: 0.04289621999487281
2019-09-26 22:42:23,606 Validation loss: 0.04289621999487281
2019-09-26 22:42:23,608 - sequence_tagger_bert - INFO - Validation F1-Score: (0.9460525211846631, 0.9608925562205822)
2019-09-26 22:42:23,608 Validation F1-Score: (0.9460525211846631, 0.9608925562205822)


Epoch:  12%|█▏        | 12/100 [56:08<6:50:23, 279.82s/it]

2019-09-26 22:46:54,742 - sequence_tagger_bert - INFO - Train loss: 0.00398789727350876
2019-09-26 22:46:54,742 Train loss: 0.00398789727350876
2019-09-26 22:47:03,953 - sequence_tagger_bert - INFO - Validation loss: 0.04656124021857977
2019-09-26 22:47:03,953 Validation loss: 0.04656124021857977
2019-09-26 22:47:03,955 - sequence_tagger_bert - INFO - Validation F1-Score: (0.9453446394089497, 0.9603724178062264)
2019-09-26 22:47:03,955 Validation F1-Score: (0.9453446394089497, 0.9603724178062264)


Epoch:  13%|█▎        | 13/100 [1:00:49<6:45:57, 279.98s/it]

In [11]:
# Evaluate the trained tagger on the held-out test split, batched like validation.
test_dataset = prepare_flair_corpus(corpus.test)
test_sampler = SequentialSampler(test_dataset)
test_dataloader = DataLoader(test_dataset,
                             sampler=test_sampler,
                             batch_size=PRED_BATCH_SIZE,
                             collate_fn=collate_fn)

# Bind all three results to real names: assigning to `_` would shadow IPython's
# last-output variable and stop it (and `__`, `___`) from updating.
test_predictions, test_loss, test_f1 = seq_tagger.predict(test_dataloader, evaluate=True)
test_f1

(0.9143007822800387, 0.9306361914074436)

In [None]:
# Test-set `test_f1` pair recorded from the run above, kept for reference only —
# this is NOT recomputed. The meaning of each component is defined by
# SequenceTaggerBert.predict(..., evaluate=True).
RECORDED_TEST_F1 = (0.9143007822800387, 0.9306361914074436)