In [1]:
import comet_ml

import os
import collections

from transformers import BertTokenizer
import torch
import numpy as np
import random

import torch.optim as optim
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader

from ner.utils import create_dataset_and_document_dataloader
from ner.trainer import Trainer
from ner.model import DocumentContextBertBaseNER, BertNERBiLSTM

from tqdm import tqdm

import warnings
warnings.filterwarnings('ignore')

SEED = 693

"""torch.manual_seed(SEED)
random.seed(SEED)
numpy.random.seed(SEED)"""

def seed_everything(seed=42):
    random.seed(seed)
    os.environ['PYTHONHASHSEED'] = str(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

seed_everything(SEED)

comet_ml is installed but `COMET_API_KEY` is not set.


In [2]:
torch.cuda.get_device_name(device=0)

'TITAN V'

In [3]:
# TOKENIZER = RobertaTokenizer.from_pretrained("roberta-base", do_lower_case=False)
TOKENIZER = BertTokenizer.from_pretrained("bert-base-cased", do_lower_case=False)
torch.cuda.set_device(1)
DEVICE = 'cuda' if torch.cuda.is_available else 'cpu'
EPOCHS = 4
BATCH_SIZE = 32

In [4]:
train_dataset, train_documents, train_dataloader = create_dataset_and_document_dataloader('conll', "data/conll2003/train.txt", batch_size=BATCH_SIZE, shuffle=True, tokenizer=TOKENIZER)
eval_dataset, eval_documents, eval_dataloader = create_dataset_and_document_dataloader('conll', "data/conll2003/valid.txt", batch_size=BATCH_SIZE, shuffle=True, tokenizer=TOKENIZER)
test_dataset, test_documents, test_dataloader = create_dataset_and_document_dataloader('conll', "data/conll2003/test.txt", batch_size=BATCH_SIZE, shuffle=False, tokenizer=TOKENIZER)

In [5]:
eval_dataset.idx2tag = train_dataset.idx2tag
eval_dataset.tag2idx = train_dataset.tag2idx
test_dataset.idx2tag = train_dataset.idx2tag
test_dataset.tag2idx = train_dataset.tag2idx

### SEED 693

In [6]:
classes = len(train_dataset.ner_tags)

#model = DocumentContextRobertaLargeNER(classes, TOKENIZER, DEVICE).to(DEVICE)
model = BertNERBiLSTM(classes).to(DEVICE)
# model = DocumentContextBertBaseNER(classes, DEVICE).to(DEVICE)
optimizer = optim.AdamW(model.parameters(), lr=1e-4)
criterion = nn.CrossEntropyLoss(ignore_index=0).to(DEVICE)

params = {
    'model': 'BERT-Base + BiLSTM',
    'corpus': 'CoNLL',
    'document_context': False, 
    'optimize_document_calculation': False,
    'detach_additional_context': False, 
    'hidden_size': 768,
    'batch_size': 32,
    'shuffle_batch': True,
    'optimizer': 'AdamW',
    'learning_rate': 1e-4,
    'epochs': 4,
    'last_epoch_lstm': False,
    'seed': SEED
}

experiment = comet_ml.Experiment(api_key='fxEY7T7JQW6R5I9DkDazSYRpp', project_name='ner-with-nonlocal-features', workspace='ryzhtus')

trainer = Trainer(experiment, model, TOKENIZER, params, optimizer, criterion, None, False, EPOCHS, False, train_dataloader, eval_dataloader, test_dataloader,
                  None, None, None, train_dataset.tag2idx, train_dataset.idx2tag, DEVICE)

COMET INFO: Experiment is live on comet.ml https://www.comet.ml/ryzhtus/ner-with-nonlocal-features/b9731005202948aa965743ac52068550



In [46]:
def merge_tokens_into_words(tokens, tokenized_tags):
        """
        Returns a dict, where keys are words and values are lists with corresponding positions of tokens forming the word
        {0: [0], 1: [1, 2], 2: [3, 4, 5], 3: [6], ...}
        """
        words_ids = {}
        current_word_ids = []
        current_word_bpe = []
        word_id = 0

        for idx in range(len(tokens) - 1):
            if ('##' not in tokens[idx]) and ('##' not in tokens[idx + 1]):
                current_word_ids.append(idx)
                current_word_bpe.append(tokens[idx])
                words_ids[word_id] = {'bpe': current_word_bpe, 'positions': current_word_ids}
                word_id += 1
                current_word_bpe = []
                current_word_ids = []
            elif ('##' in tokens[idx]) and ('##' not in tokens[idx + 1]):
                current_word_ids.append(idx)
                current_word_bpe.append(tokens[idx])
                words_ids[word_id] = {'bpe': current_word_bpe, 'positions': current_word_ids}
                word_id += 1
                current_word_bpe = []
                current_word_ids = []
            else:
                current_word_ids.append(idx)
                current_word_bpe.append(tokens[idx])

        words = []
        tags = []
        for word in words_ids:
            bpe = words_ids[word]['bpe']
            pos = words_ids[word]['positions']
            word = "".join(bpe).replace('##', '')
            print(pos[0])
            tag = tokenized_tags[pos[0]]

            words.append(word)
            tags.append(tag)

        return words, tags

In [62]:
tokens = TOKENIZER.convert_ids_to_tokens(test_dataset[4][0])
tags = [train_dataset.idx2tag[tag.item()] for tag in test_dataset[4][1]]

for word, tag in zip(tokens, tags):
    print(word, tag)

[CLS] O
But O
China B-LOC
saw O
their O
luck O
desert O
them O
in O
the O
second O
match O
of O
the O
group O
, O
crashing O
to O
a O
surprise O
2 O
- O
0 O
defeat O
to O
newcomer O
##s O
Uzbekistan B-LOC
. O
[SEP] O


In [65]:
trainer.test_labels[idx]

['O',
 'B-LOC',
 'O',
 'O',
 'O',
 'O',
 'O',
 'O',
 'O',
 'O',
 'O',
 'O',
 'O',
 'O',
 'O',
 'O',
 'O',
 'O',
 'O',
 'O',
 'O',
 'O',
 'O',
 'O',
 'O']

In [48]:
idx = 4

tokens = TOKENIZER.convert_ids_to_tokens(test_dataset[idx][0])[1: len(test_dataset[idx][0]) - 1]
token_tags = trainer.test_labels[idx]

words, tags = merge_tokens_into_words(tokens, token_tags)
for word, tag in zip(words, tags):
    print(word, tag)

0
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
26


IndexError: list index out of range

In [7]:
trainer.fit()

[1 / 4] Train: Loss = 0.25690, F1-score = 80.71%, Repeated Entities Accuracy = 70.99%: 100%|██████████| 439/439 [01:31<00:00,  4.77it/s]
[1 / 4] Eval : Loss = 0.09301, F1-score = 94.03%, Repeated Entities Accuracy = 92.77%: 100%|██████████| 102/102 [00:07<00:00, 14.17it/s]
[2 / 4] Train: Loss = 0.06759, F1-score = 95.75%, Repeated Entities Accuracy = 95.79%: 100%|██████████| 439/439 [01:31<00:00,  4.82it/s]
[2 / 4] Eval : Loss = 0.07867, F1-score = 94.89%, Repeated Entities Accuracy = 92.48%: 100%|██████████| 102/102 [00:07<00:00, 14.20it/s]
[3 / 4] Train: Loss = 0.04168, F1-score = 97.36%, Repeated Entities Accuracy = 97.51%: 100%|██████████| 439/439 [01:33<00:00,  4.69it/s]
[3 / 4] Eval : Loss = 0.08459, F1-score = 95.09%, Repeated Entities Accuracy = 93.58%: 100%|██████████| 102/102 [00:07<00:00, 14.30it/s]
[4 / 4] Train: Loss = 0.03454, F1-score = 97.80%, Repeated Entities Accuracy = 98.02%: 100%|██████████| 439/439 [01:30<00:00,  4.86it/s]
[4 / 4] Eval : Loss = 0.09169, F1-score =

In [8]:
trainer.test()

Test : Loss = 0.17691, F1-score = 91.28%, Repeated Entities Accuracy = 93.15%: 100%|██████████| 108/108 [00:06<00:00, 16.93it/s]


              precision    recall  f1-score   support

         LOC     0.9241    0.8899    0.9067      2887
        MISC     0.5344    0.7675    0.6301      1226
         ORG     0.8819    0.8658    0.8738      3354
         PER     0.9315    0.9216    0.9265      2907

   micro avg     0.8490    0.8765    0.8625     10374
   macro avg     0.8180    0.8612    0.8343     10374
weighted avg     0.8665    0.8765    0.8689     10374



In [9]:
experiment.end()

COMET INFO: ---------------------------
COMET INFO: Comet.ml Experiment Summary
COMET INFO: ---------------------------
COMET INFO:   Data:
COMET INFO:     display_summary_level : 1
COMET INFO:     url                   : https://www.comet.ml/ryzhtus/ner-with-nonlocal-features/b9731005202948aa965743ac52068550
COMET INFO:   Metrics [count] (min, max):
COMET INFO:     test_Test F1                      : 0.9128327645051195
COMET INFO:     test_Test Precision               : 0.8541227565944944
COMET INFO:     test_Test RE Accuracy             : 0.9315126050420168
COMET INFO:     test_Test Recall                  : 0.9802096313127612
COMET INFO:     train_Train F1 [4]                : (0.8071459598340125, 0.9779594479744222)
COMET INFO:     train_Train Precision [4]         : (0.78093926604192, 0.966777089634644)
COMET INFO:     train_Train RE Accuracy [4]       : (0.7099428112772148, 0.9802138503047867)
COMET INFO:     train_Train Recall [4]            : (0.8351726116477001, 0.989403517966

In [7]:
trainer.fit()

[1 / 5] Train: Loss = 0.61121, F1-score = 47.22%, Repeated Entities Accuracy = 36.38%: 100%|██████████| 439/439 [03:13<00:00,  2.27it/s]
[1 / 5] Eval : Loss = 0.25971, F1-score = 81.63%, Repeated Entities Accuracy = 76.71%: 100%|██████████| 102/102 [00:39<00:00,  2.60it/s]
[2 / 5] Train: Loss = 0.18943, F1-score = 87.36%, Repeated Entities Accuracy = 81.91%: 100%|██████████| 439/439 [03:13<00:00,  2.27it/s]
[2 / 5] Eval : Loss = 0.12220, F1-score = 92.38%, Repeated Entities Accuracy = 92.92%: 100%|██████████| 102/102 [00:39<00:00,  2.60it/s]
[3 / 5] Train: Loss = 0.09398, F1-score = 94.25%, Repeated Entities Accuracy = 92.51%: 100%|██████████| 439/439 [03:15<00:00,  2.24it/s]
[3 / 5] Eval : Loss = 0.07423, F1-score = 95.30%, Repeated Entities Accuracy = 95.96%: 100%|██████████| 102/102 [00:39<00:00,  2.59it/s]
[4 / 5] Train: Loss = 0.05878, F1-score = 96.47%, Repeated Entities Accuracy = 96.36%: 100%|██████████| 439/439 [03:15<00:00,  2.25it/s]
[4 / 5] Eval : Loss = 0.06626, F1-score =

In [8]:
trainer.test()

Test : Loss = 0.13074, F1-score = 93.36%, Repeated Entities Accuracy = 94.37%: 100%|██████████| 108/108 [00:33<00:00,  3.18it/s]


              precision    recall  f1-score   support

         LOC       0.92      0.90      0.91      2890
        MISC       0.69      0.71      0.70      1228
         ORG       0.87      0.88      0.88      3372
         PER       0.94      0.93      0.94      2948

   micro avg       0.88      0.88      0.88     10438
   macro avg       0.86      0.86      0.86     10438
weighted avg       0.88      0.88      0.88     10438



In [9]:
experiment.end()

COMET INFO: ---------------------------
COMET INFO: Comet.ml Experiment Summary
COMET INFO: ---------------------------
COMET INFO:   Data:
COMET INFO:     display_summary_level : 1
COMET INFO:     url                   : https://www.comet.ml/ryzhtus/ner-with-nonlocal-features/8377cda2cbee42ceb856e53b3dc99500
COMET INFO:   Metrics [count] (min, max):
COMET INFO:     test_Test F1                      : 0.9335873895721463
COMET INFO:     test_Test Precision               : 0.8969511383304487
COMET INFO:     test_Test RE Accuracy             : 0.9436974789915966
COMET INFO:     test_Test Recall                  : 0.9733439283392328
COMET INFO:     train_Train F1 [5]                : (0.47223302802041706, 0.9740197080590424)
COMET INFO:     train_Train Precision [5]         : (0.5067422268322919, 0.9610909384329863)
COMET INFO:     train_Train RE Accuracy [5]       : (0.3638466955034705, 0.9761593401066291)
COMET INFO:     train_Train Recall [5]            : (0.44212431573371, 0.9873010610

In [10]:
from torchviz import make_dot

In [11]:
model.train()

for param in model.bert.parameters():
    param.requires_grad = False

mean_embeddings_for_batch_documents = {}
sentences_from_documents = {}

document_id = 0

mean_embeddings_for_batch_documents[document_id] = model.get_document_context(train_documents[document_id].to(DEVICE), train_documents.collect_all_positions_for_each_word(document_id))
sentences_from_documents[document_id] = train_documents.get_document_words_by_sentences(document_id)

for param in model.bert.parameters():
    param.requires_grad = True

tokens, tags, mask, sentence_id, document_id = train_dataset[0]

dot = make_dot(model(tokens.unsqueeze(0).to(DEVICE), [document_id], [sentence_id], mean_embeddings_for_batch_documents, sentences_from_documents), params=dict(model.named_parameters()), show_attrs=True, show_saved=True)

In [13]:
dot.format = 'png'
dot.render('NERonBERTwithDocumentContextFullInfoWithoutGradForAdd')

ExecutableNotFound: failed to execute ['dot', '-Kdot', '-Tpng', '-O', 'NERonBERTwithDocumentContextFullInfoWithoutGradForAdd'], make sure the Graphviz executables are on your systems' PATH

In [14]:
dot -V

NameError: name 'V' is not defined