In [1]:
import comet_ml

import os
import collections

from transformers import BertTokenizer, BertModel
import torch
import numpy as np
import random

import torch.optim as optim
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader

from ner.utils import create_dataset_and_document_dataloader
from ner.trainer import Trainer
from ner.model import DocumentContextBertBaseNER

from tqdm import tqdm

import warnings
warnings.filterwarnings('ignore')

SEED = 693

"""torch.manual_seed(SEED)
random.seed(SEED)
numpy.random.seed(SEED)"""

def seed_everything(seed=42):
    random.seed(seed)
    os.environ['PYTHONHASHSEED'] = str(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

seed_everything(SEED)

comet_ml is installed but `COMET_API_KEY` is not set.


In [2]:
torch.cuda.get_device_name(device=2)

'TITAN RTX'

In [3]:
TOKENIZER = BertTokenizer.from_pretrained('bert-base-cased', do_lower_case=False)
torch.cuda.set_device(2)
DEVICE = 'cuda' if torch.cuda.is_available else 'cpu'
EPOCHS = 5
BATCH_SIZE = 32

In [4]:
train_dataset, train_documents, train_dataloader = create_dataset_and_document_dataloader('ontonotes', "data/ontonotes/train", batch_size=BATCH_SIZE, shuffle=False, tokenizer=TOKENIZER)
eval_dataset, eval_documents, eval_dataloader = create_dataset_and_document_dataloader('ontonotes', "data/ontonotes/development", batch_size=BATCH_SIZE, shuffle=False, tokenizer=TOKENIZER)
test_dataset, test_documents, test_dataloader = create_dataset_and_document_dataloader('ontonotes', "data/ontonotes/test", batch_size=BATCH_SIZE, shuffle=False, tokenizer=TOKENIZER)

In [5]:
eval_dataset.idx2tag = train_dataset.idx2tag
eval_dataset.tag2idx = train_dataset.tag2idx
test_dataset.idx2tag = train_dataset.idx2tag
test_dataset.tag2idx = train_dataset.tag2idx

### SEED 693

In [6]:
classes = len(train_dataset.ner_tags)

model = DocumentContextBertBaseNER(classes, DEVICE).to(DEVICE)
optimizer = optim.AdamW(model.parameters(), lr=1e-5)
criterion = nn.CrossEntropyLoss(ignore_index=0).to(DEVICE)

params = {
    'model': 'Bert-Base-Cased',
    'corpus': 'ontonotes',
    'hidden_size': 768,
    'batch_size': 32,
    'shuffle_batch': False,
    'optimizer': 'AdamW',
    'learning_rate': 1e-5,
    'epochs': 5,
    'last_epoch_lstm': False,
    'seed': SEED
}

experiment = comet_ml.Experiment(api_key='fxEY7T7JQW6R5I9DkDazSYRpp', project_name='ner-with-nonlocal-features', workspace='ryzhtus')

trainer = Trainer(experiment, model, TOKENIZER, params, optimizer, criterion, None, False, EPOCHS, False, train_dataloader, eval_dataloader, test_dataloader,
                  train_documents, eval_documents, test_documents, train_dataset.tag2idx, train_dataset.idx2tag, DEVICE)

COMET INFO: Experiment is live on comet.ml https://www.comet.ml/ryzhtus/ner-with-nonlocal-features/42da3a6d2ab542dcae791ea7dbf3c68e



In [7]:
trainer.fit()

[1 / 5] Train: Loss = 0.24871, F1-score = 70.79%, Repeated Entities Accuracy = 61.99%: 100%|██████████| 3174/3174 [2:12:33<00:00,  2.51s/it]
[1 / 5] Eval : Loss = 0.11768, F1-score = 87.83%, Repeated Entities Accuracy = 83.99%: 100%|██████████| 433/433 [21:02<00:00,  2.91s/it]
[2 / 5] Train: Loss = 0.08741, F1-score = 89.36%, Repeated Entities Accuracy = 87.44%: 100%|██████████| 3174/3174 [2:06:37<00:00,  2.39s/it]
[2 / 5] Eval : Loss = 0.09484, F1-score = 90.30%, Repeated Entities Accuracy = 88.34%: 100%|██████████| 433/433 [21:02<00:00,  2.92s/it]
[3 / 5] Train: Loss = 0.06235, F1-score = 92.11%, Repeated Entities Accuracy = 92.33%: 100%|██████████| 3174/3174 [2:06:37<00:00,  2.39s/it]
[3 / 5] Eval : Loss = 0.09157, F1-score = 90.61%, Repeated Entities Accuracy = 88.62%: 100%|██████████| 433/433 [20:55<00:00,  2.90s/it]
[4 / 5] Train: Loss = 0.01184, F1-score = 0.00%, Repeated Entities Accuracy = 0.00%:  71%|███████   | 2239/3174 [1:29:01<35:09,  2.26s/it]  IOPub message rate exceede

In [8]:
trainer.test()

Test : Loss = 0.08477, F1-score = 89.90%, Repeated Entities Accuracy = 90.08%: 100%|██████████| 335/335 [16:48<00:00,  3.01s/it]


              precision    recall  f1-score   support

    CARDINAL       0.80      0.80      0.80      1122
        DATE       0.77      0.84      0.80      1898
       EVENT       0.47      0.53      0.50       113
         FAC       0.60      0.67      0.63       232
         GPE       0.93      0.89      0.91      3324
    LANGUAGE       0.93      0.58      0.72        24
         LAW       0.57      0.50      0.53        78
         LOC       0.58      0.70      0.63       269
       MONEY       0.80      0.90      0.85       494
        NORP       0.86      0.89      0.87      1010
     ORDINAL       0.77      0.83      0.80       193
         ORG       0.84      0.87      0.85      3060
     PERCENT       0.92      0.92      0.92       668
      PERSON       0.91      0.90      0.91      2938
     PRODUCT       0.49      0.60      0.54       134
    QUANTITY       0.68      0.66      0.67       201
        TIME       0.60      0.58      0.59       296
 WORK_OF_ART       0.59    

In [9]:
experiment.end()

COMET INFO: ---------------------------
COMET INFO: Comet.ml Experiment Summary
COMET INFO: ---------------------------
COMET INFO:   Data:
COMET INFO:     display_summary_level : 1
COMET INFO:     url                   : https://www.comet.ml/ryzhtus/ner-with-nonlocal-features/42da3a6d2ab542dcae791ea7dbf3c68e
COMET INFO:   Metrics [count] (min, max):
COMET INFO:     test_Test F1                      : 0.8990269224328801
COMET INFO:     test_Test Precision               : 0.8585033046691778
COMET INFO:     test_Test RE Accuracy             : 0.9008446566287184
COMET INFO:     test_Test Recall                  : 0.9435657098223004
COMET INFO:     train_Train F1 [5]                : (0.7079271200992082, 0.9511713606518263)
COMET INFO:     train_Train Precision [5]         : (0.7405927332079623, 0.9319636474467319)
COMET INFO:     train_Train RE Accuracy [5]       : (0.6199491437879932, 0.9614207954229409)
COMET INFO:     train_Train Recall [5]            : (0.6780213667057835, 0.971187474