In [33]:
import os
import torch
import tensorflow_datasets as tfds
import datasets
from transformers import AutoTokenizer
from transformers import AutoModelForTokenClassification
from transformers import TrainingArguments
from transformers import Trainer
from transformers import DataCollatorForTokenClassification
from transformers import pipeline
import evaluate
import numpy as np

In [34]:
from utils_display import pc

In [35]:
# Prefer the GPU when CUDA is available, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
pc("Device", device)

[34mDevice[0m: cuda


In [36]:
# Load the "seqeval" metric through the `evaluate` library; compute_metrics calls seqeval.compute().
seqeval = evaluate.load("seqeval")

# CoNLL-2003 dataset

In [37]:
# Both splits are stored on disk under local_datasets/conll2003 as "<split>.hf" directories.
path_to_conll2003_dataset = os.path.join("local_datasets", "conll2003")
dataset_train, dataset_test = [
    datasets.load_from_disk(os.path.join(path_to_conll2003_dataset, split_dir))
    for split_dir in ("train.hf", "test.hf")
]

# Report the size of each split.
pc("Number of samples in the train dataset", len(dataset_train))
pc("Number of samples in the test dataset", len(dataset_test))

[34mNumber of samples in the train dataset[0m: 14042
[34mNumber of samples in the test dataset[0m: 3454


In [38]:
# Index of the training example that the inspection cells below display.
sample_index = 2

In [39]:
# Tag -> integer id lookup tables. The id of a tag is its position in the list below,
# so the order of each list must not be changed.
pos_tags2indices = {
    tag: index
    for index, tag in enumerate([
        '"', "''", '#', '$', '(', ')', ',', '.', ':', '``', 'CC', 'CD', 'DT',
        'EX', 'FW', 'IN', 'JJ', 'JJR', 'JJS', 'LS', 'MD', 'NN', 'NNP', 'NNPS',
        'NNS', 'NN|SYM', 'PDT', 'POS', 'PRP', 'PRP$', 'RB', 'RBR', 'RBS', 'RP',
        'SYM', 'TO', 'UH', 'VB', 'VBD', 'VBG', 'VBN', 'VBP', 'VBZ', 'WDT',
        'WP', 'WP$', 'WRB',
    ])
}

chunk_tags2indices = {
    tag: index
    for index, tag in enumerate([
        'O', 'B-ADJP', 'I-ADJP', 'B-ADVP', 'I-ADVP', 'B-CONJP', 'I-CONJP', 'B-INTJ', 'I-INTJ',
        'B-LST', 'I-LST', 'B-NP', 'I-NP', 'B-PP', 'I-PP', 'B-PRT', 'I-PRT', 'B-SBAR',
        'I-SBAR', 'B-UCP', 'I-UCP', 'B-VP', 'I-VP',
    ])
}

ner_tags2indices = {
    tag: index
    for index, tag in enumerate([
        'O', 'B-PER', 'I-PER', 'B-ORG', 'I-ORG', 'B-LOC', 'I-LOC', 'B-MISC', 'I-MISC',
    ])
}

# NER tag names ordered by id, so that ner_label_list[id] is the tag string.
ner_label_list = list(ner_tags2indices)

In [40]:
# Count and report the size of each tag vocabulary (same print order as before: POS, CHUNK, NER).
number_of_pos_tags = len(pos_tags2indices)
pc("Number of POS tags", number_of_pos_tags)

number_of_chunks_tags = len(chunk_tags2indices)
pc("Number of CHUNK tags", number_of_chunks_tags)

number_of_ner_tags = len(ner_tags2indices)
pc("Number of NER tags", number_of_ner_tags)

[34mNumber of POS tags[0m: 47
[34mNumber of CHUNK tags[0m: 23
[34mNumber of NER tags[0m: 9


In [41]:
def create_dico_indices2tags(dico_tags2indices: dict) -> dict:
    """Invert a tag -> index dictionary into an index -> tag dictionary.

    Args:
        dico_tags2indices: mapping from tag to index.

    Returns:
        A new dictionary mapping each index back to its tag. If several tags
        share one index, the tag that comes last in iteration order is kept.
    """
    return {index: tag for tag, index in dico_tags2indices.items()}

In [42]:
# Inverse (index -> tag) lookup tables, used to turn dataset ids back into tag names.
pos_indices2tags = create_dico_indices2tags(pos_tags2indices)
chunk_indices2tags = create_dico_indices2tags(chunk_tags2indices)
ner_indices2tags = create_dico_indices2tags(ner_tags2indices)

In [43]:
def print_sample(sample) -> None:
    """Print one dataset example as a text table with one row per token.

    Every row shows the token position, the token itself, and the POS, chunk
    and NER annotations, each as an integer id followed by its tag name. The
    ids are translated with the module-level ``pos_indices2tags``,
    ``chunk_indices2tags`` and ``ner_indices2tags`` dictionaries.

    Args:
        sample: mapping with "tokens", "pos", "chunks" and "ner" entries, each
            indexable by token position.
    """
    rule = "-"*74
    header_format = "{:<4} | {:<20} | {:<3} {:<10} | {:<3} {:<10} | {:<3} {:<10}"
    row_format = "{:<5} | {:<20} | {:<3} {:<10} | {:<3} {:<10} | {:<3} {:<10}"

    print(rule)
    print(header_format.format("INDEX", "TOKEN", "", "POS", "", "CHUNK", "", "NER"))
    print(rule)

    for position, token in enumerate(sample["tokens"]):
        # Look up each annotation id, then its tag name, layer by layer.
        pos_id = sample["pos"][position]
        pos_name = pos_indices2tags[pos_id]

        chunk_id = sample["chunks"][position]
        chunk_name = chunk_indices2tags[chunk_id]

        ner_id = sample["ner"][position]
        ner_name = ner_indices2tags[ner_id]

        print(row_format.format(
            position, token, pos_id, pos_name, chunk_id, chunk_name, ner_id, ner_name))

    print(rule)

In [44]:
pc("Sample index", sample_index, break_line=True)

# Dump every raw field of the selected training example, then render it as a table.
sample = dataset_train[sample_index]
for field_name, field_value in sample.items():
    pc(field_name, field_value)
print_sample(sample=sample)

[34mSample index[0m: 2

[34mindex[0m: 2
[34mpos[0m: [22, 27, 16, 21, 42, 12, 21, 42, 16, 10, 22, 22, 22, 22, 27, 21, 20, 37, 12, 21, 27, 16, 10, 16, 24, 6, 12, 21, 21, 38, 15, 22, 7]
[34mchunks[0m: [11, 11, 12, 12, 21, 11, 12, 21, 1, 0, 11, 12, 12, 12, 11, 12, 21, 22, 11, 12, 12, 1, 2, 2, 11, 0, 11, 12, 12, 21, 13, 11, 0]
[34mner[0m: [5, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 2, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]
[34mtokens[0m: ['Thailand', "'s", 'powerful', 'military', 'thinks', 'the', 'government', 'is', 'dishonest', 'and', 'Prime', 'Minister', 'Banharn', 'Silpa-archa', "'s", 'resignation', 'might', 'solve', 'the', 'nation', "'s", 'political', 'and', 'economic', 'woes', ',', 'an', 'opinion', 'poll', 'showed', 'on', 'Thursday', '.']
--------------------------------------------------------------------------
INDEX | TOKEN                |     POS        |     CHUNK      |     NER       
----------------------------------------------------------------------

# BERT model

In [45]:
# Name of the pretrained checkpoint loaded by both the tokenizer and the model below.
# NOTE(review): this is an uncased checkpoint; the decoded sample further down is fully
# lowercased, so capitalisation is not available to the model. Confirm this is intended for NER.
model_checkpoint = 'bert-base-uncased'

In [46]:
# Tokenizer matching the checkpoint; used for dataset tokenization, the collator, the Trainer and the pipeline.
tokenizer = AutoTokenizer.from_pretrained(model_checkpoint)

In [47]:
def tokenize_and_align_labels(examples):
    """Tokenize pre-split words and align word-level NER labels with the sub-word tokens.

    Uses the module-level ``tokenizer`` and ``ner_tags2indices``.

    Label assignment per produced token:
      * special tokens (word id ``None``) get -100 so the loss ignores them;
      * the first sub-token of a word gets the word's label;
      * the remaining sub-tokens of a word get the word's label when
        ``label_all_tokens`` is true, with ``B-XXX`` replaced by ``I-XXX`` so
        that a word split into several pieces is still a single entity
        (otherwise each piece would start a new entity, which corrupts both
        the training targets and the entity-level seqeval scores); they get
        -100 when ``label_all_tokens`` is false.

    Args:
        examples: batch with a "tokens" entry (list of word lists) and a "ner"
            entry (list of per-word label id lists).

    Returns:
        The tokenizer output for the batch with an added "labels" entry holding
        one aligned label list per example.
    """
    label_all_tokens = True

    # Id of each B-XXX tag -> id of the matching I-XXX tag (only when that I- tag exists).
    begin_to_inside = {
        index: ner_tags2indices["I-" + tag[2:]]
        for tag, index in ner_tags2indices.items()
        if tag.startswith("B-") and "I-" + tag[2:] in ner_tags2indices
    }

    tokenized_inputs = tokenizer(examples["tokens"], truncation=True, is_split_into_words=True)
    labels = []
    for i, label in enumerate(examples["ner"]):
        word_ids = tokenized_inputs.word_ids(batch_index=i)
        previous_word_idx = None
        label_ids = []
        for word_idx in word_ids:
            # Special tokens have a word id that is None. We set the label to -100 so they are automatically
            # ignored in the loss function.
            if word_idx is None:
                label_ids.append(-100)
            # We set the label for the first token of each word.
            elif word_idx != previous_word_idx:
                label_ids.append(label[word_idx])
            # Continuation tokens of a word: propagate the label, demoting B-XXX to I-XXX.
            elif label_all_tokens:
                word_label = label[word_idx]
                label_ids.append(begin_to_inside.get(word_label, word_label))
            # Or ignore them in the loss, depending on the label_all_tokens flag.
            else:
                label_ids.append(-100)
            previous_word_idx = word_idx

        labels.append(label_ids)

    tokenized_inputs["labels"] = labels
    return tokenized_inputs

In [48]:
def tokenize_dataset(dataset):
    """Apply ``tokenize_and_align_labels`` to a whole dataset in batched mode.

    No columns are removed, so the result keeps the original columns alongside
    whatever the mapped function adds.

    Args:
        dataset: object exposing a ``map(function, batched=...)`` method.

    Returns:
        Whatever ``dataset.map`` returns.
    """
    map_options = {"batched": True}
    return dataset.map(tokenize_and_align_labels, **map_options)

In [49]:
# Tokenize both splits and attach the aligned "labels" column.
tokenized_dataset_train = tokenize_dataset(dataset=dataset_train)
tokenized_dataset_test = tokenize_dataset(dataset=dataset_test)

In [50]:
# Gather, for the selected example, the original words, the stored input ids and aligned labels,
# and the word id of every produced token (obtained by re-tokenizing the words on their own).
sample = dataset_train[sample_index]
sample_tokens = sample["tokens"]
sample_input_ids = tokenized_dataset_train[sample_index]["input_ids"]
tokenized_sample_input = tokenizer(sample_tokens, is_split_into_words=True)
word_indices = tokenized_sample_input.word_ids()
sample_labels_aligned = tokenized_dataset_train[sample_index]["labels"]

# Print each sequence with its length so the alignment between them can be checked by eye.
pc("Sample tokens", sample_tokens)
pc("Number of sample tokens", len(sample_tokens), break_line=True)
pc("Input ids", sample_input_ids)
pc("Number of input ids", len(sample_input_ids), break_line=True)
pc("Word indices", word_indices)
pc("Number of word indices", len(word_indices), break_line=True)
pc("Aligned labels", sample_labels_aligned)
pc("Number of aligned labels", len(sample_labels_aligned))

[34mSample tokens[0m: ['Thailand', "'s", 'powerful', 'military', 'thinks', 'the', 'government', 'is', 'dishonest', 'and', 'Prime', 'Minister', 'Banharn', 'Silpa-archa', "'s", 'resignation', 'might', 'solve', 'the', 'nation', "'s", 'political', 'and', 'economic', 'woes', ',', 'an', 'opinion', 'poll', 'showed', 'on', 'Thursday', '.']
[34mNumber of sample tokens[0m: 33

[34mInput ids[0m: [101, 6504, 1005, 1055, 3928, 2510, 6732, 1996, 2231, 2003, 9841, 21821, 2102, 1998, 3539, 2704, 7221, 8167, 2078, 9033, 14277, 2050, 1011, 7905, 2050, 1005, 1055, 8172, 2453, 9611, 1996, 3842, 1005, 1055, 2576, 1998, 3171, 24185, 2229, 1010, 2019, 5448, 8554, 3662, 2006, 9432, 1012, 102]
[34mNumber of input ids[0m: 48

[34mWord indices[0m: [None, 0, 1, 1, 2, 3, 4, 5, 6, 7, 8, 8, 8, 9, 10, 11, 12, 12, 12, 13, 13, 13, 13, 13, 13, 14, 14, 15, 16, 17, 18, 19, 20, 20, 21, 22, 23, 24, 24, 25, 26, 27, 28, 29, 30, 31, 32, None]
[34mNumber of word indices[0m: 48

[34mAligned labels[0m: [-100, 5, 0, 

In [51]:
# One table row per original word: the word ids, input ids and labels of all sub-tokens it produced.
table_row = "{:<3} | {:<15} | {:<30} | {:<55} | {:<35}"
print(table_row.format("", "TOKEN", "WORD IDS", "INPUT IDS", "NER LABELS"))

for word_position, word in enumerate(sample_tokens):
    # Positions, in the tokenized sequence, of the sub-tokens that belong to this word.
    token_positions = [
        position for position, word_id in enumerate(word_indices) if word_id == word_position
    ]
    print(table_row.format(
        word_position,
        word,
        ", ".join(str(word_indices[position]) for position in token_positions),
        ", ".join(str(sample_input_ids[position]) for position in token_positions),
        ", ".join(str(sample_labels_aligned[position]) for position in token_positions)))

    | TOKEN           | WORD IDS                       | INPUT IDS                                               | NER LABELS                         
0   | Thailand        | 0                              | 6504                                                    | 5                                  
1   | 's              | 1, 1                           | 1005, 1055                                              | 0, 0                               
2   | powerful        | 2                              | 3928                                                    | 0                                  
3   | military        | 3                              | 2510                                                    | 0                                  
4   | thinks          | 4                              | 6732                                                    | 0                                  
5   | the             | 5                              | 1996                                 

In [52]:
# Dump every field of the selected example from the untokenized training set.
pc("Dataset, sample index", sample_index)
for field_name, field_value in dataset_train[sample_index].items():
    pc(field_name, field_value)

[34mDataset, sample index[0m: 2
[34mindex[0m: 2
[34mpos[0m: [22, 27, 16, 21, 42, 12, 21, 42, 16, 10, 22, 22, 22, 22, 27, 21, 20, 37, 12, 21, 27, 16, 10, 16, 24, 6, 12, 21, 21, 38, 15, 22, 7]
[34mchunks[0m: [11, 11, 12, 12, 21, 11, 12, 21, 1, 0, 11, 12, 12, 12, 11, 12, 21, 22, 11, 12, 12, 1, 2, 2, 11, 0, 11, 12, 12, 21, 13, 11, 0]
[34mner[0m: [5, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 2, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]
[34mtokens[0m: ['Thailand', "'s", 'powerful', 'military', 'thinks', 'the', 'government', 'is', 'dishonest', 'and', 'Prime', 'Minister', 'Banharn', 'Silpa-archa', "'s", 'resignation', 'might', 'solve', 'the', 'nation', "'s", 'political', 'and', 'economic', 'woes', ',', 'an', 'opinion', 'poll', 'showed', 'on', 'Thursday', '.']


In [53]:
# Dump every field of the selected example from the tokenized training set
# (original columns plus the columns added during tokenization).
# The row is fetched once instead of being re-read from the dataset for every key.
pc("Tokenized dataset, sample index", sample_index)
for key, value in tokenized_dataset_train[sample_index].items():
    pc(key, value)

[34mTokenized ataset, sample index[0m: 2
[34mindex[0m: 2
[34mpos[0m: [22, 27, 16, 21, 42, 12, 21, 42, 16, 10, 22, 22, 22, 22, 27, 21, 20, 37, 12, 21, 27, 16, 10, 16, 24, 6, 12, 21, 21, 38, 15, 22, 7]
[34mchunks[0m: [11, 11, 12, 12, 21, 11, 12, 21, 1, 0, 11, 12, 12, 12, 11, 12, 21, 22, 11, 12, 12, 1, 2, 2, 11, 0, 11, 12, 12, 21, 13, 11, 0]
[34mner[0m: [5, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 2, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]
[34mtokens[0m: ['Thailand', "'s", 'powerful', 'military', 'thinks', 'the', 'government', 'is', 'dishonest', 'and', 'Prime', 'Minister', 'Banharn', 'Silpa-archa', "'s", 'resignation', 'might', 'solve', 'the', 'nation', "'s", 'political', 'and', 'economic', 'woes', ',', 'an', 'opinion', 'poll', 'showed', 'on', 'Thursday', '.']
[34minput_ids[0m: [101, 6504, 1005, 1055, 3928, 2510, 6732, 1996, 2231, 2003, 9841, 21821, 2102, 1998, 3539, 2704, 7221, 8167, 2078, 9033, 14277, 2050, 1011, 7905, 2050, 1005, 1055, 8172, 2453, 9611, 1996, 

In [54]:
# Decode the input ids back to text to see what the model actually receives (special tokens included).
print(tokenizer.decode(tokenized_dataset_train[sample_index]["input_ids"]))

[CLS] thailand ' s powerful military thinks the government is dishonest and prime minister banharn silpa - archa ' s resignation might solve the nation ' s political and economic woes, an opinion poll showed on thursday. [SEP]


In [55]:
# Load the pretrained checkpoint with a token-classification head that has one output per NER tag.
# id2label / label2id pass the tag names so that predictions can be reported as tag strings.
# The classification head is not part of the checkpoint and is newly initialized (see the warning below).
model = AutoModelForTokenClassification.from_pretrained(
    pretrained_model_name_or_path=model_checkpoint,
    num_labels=number_of_ner_tags,
    id2label=ner_indices2tags,
    label2id=ner_tags2indices,
)

Some weights of BertForTokenClassification were not initialized from the model checkpoint at bert-base-uncased and are newly initialized: ['classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [56]:
# Fine-tuning configuration: evaluate, log and checkpoint once per epoch; metrics are reported to
# Weights & Biases. The first positional argument is the output directory.
args = TrainingArguments(
    "bert-finetuned-ner",
    eval_strategy="epoch",
    logging_strategy="epoch",
    save_strategy="epoch",
    # The previous value (1e-3) is far above the range recommended for fine-tuning BERT
    # (2e-5 to 5e-5); such a large step typically wrecks the pretrained weights and the
    # model collapses to predicting a single tag.
    learning_rate=2e-5,
    num_train_epochs=5,
    weight_decay=0.01,
    per_device_train_batch_size=8,
    per_device_eval_batch_size=8,
    report_to="wandb"
)

In [57]:
def compute_metrics(p):
    """Compute overall seqeval precision, recall, F1 and accuracy for a set of predictions.

    Uses the module-level ``ner_label_list`` (id -> tag name) and ``seqeval`` metric.

    Args:
        p: pair ``(predictions, labels)``. The predicted class is taken with an
            argmax over axis 2 of ``predictions``; positions whose label is
            -100 are excluded from both predictions and references.

    Returns:
        dict with "precision", "recall", "f1" and "accuracy", read from the
        corresponding ``overall_*`` entries of the seqeval result.
    """
    scores, gold_ids = p
    predicted_ids = np.argmax(scores, axis=2)

    true_predictions = []
    true_labels = []
    for predicted_row, gold_row in zip(predicted_ids, gold_ids):
        # Keep only the positions that carry a real label.
        kept_pairs = [
            (predicted_id, gold_id)
            for predicted_id, gold_id in zip(predicted_row, gold_row)
            if gold_id != -100
        ]
        true_predictions.append([ner_label_list[predicted_id] for predicted_id, _ in kept_pairs])
        true_labels.append([ner_label_list[gold_id] for _, gold_id in kept_pairs])

    results = seqeval.compute(predictions=true_predictions, references=true_labels)

    return {
        metric_name: results["overall_" + metric_name]
        for metric_name in ("precision", "recall", "f1", "accuracy")
    }

In [58]:
# Collator handed to the Trainer to assemble token-classification batches with this tokenizer.
data_collator = DataCollatorForTokenClassification(tokenizer=tokenizer)

In [59]:
# Wire the model, training arguments, tokenized datasets, collator and metric function together.
# NOTE(review): the test split is used as the evaluation set during training — confirm that no
# separate validation split is intended.
trainer = Trainer(
    model=model,
    args=args,
    train_dataset=tokenized_dataset_train,
    eval_dataset=tokenized_dataset_test,
    data_collator=data_collator,
    compute_metrics=compute_metrics,
    processing_class=tokenizer,
)

In [60]:
# Launch fine-tuning.
# NOTE(review): the saved output below ends in a KeyboardInterrupt, so the captured run did not complete.
trainer.train()

[34m[1mwandb[0m: Currently logged in as: [33mgzahnd[0m to [32mhttps://api.wandb.ai[0m. Use [1m`wandb login --relogin`[0m to force relogin


Epoch,Training Loss,Validation Loss


KeyboardInterrupt: 

# Evaluation

In [None]:
# Build an NER inference pipeline around the in-memory model and tokenizer.
# aggregation_strategy="first" asks the pipeline to group sub-tokens into entities rather than
# returning one prediction per sub-token (see the TokenClassificationPipeline documentation).
nlp = pipeline("ner", model=model, tokenizer=tokenizer, aggregation_strategy="first")

In [None]:
# Free-text example sentence used to try the pipeline.
text = "She then related that, by the permission of Elizabeth, she had passed the evening of the night on which the murder had been committed at the house of an aunt at Chene, a village situated at about a league from Geneva."

In [None]:
# Split the example text into tokens with the model tokenizer.
text_tokenized = tokenizer.tokenize(text)

In [None]:
# Run the pipeline on the raw string. The pipeline tokenizes its input itself; a list argument is
# treated as a batch of independent texts, so passing the word-piece list would classify every
# piece in isolation instead of the sentence.
text_ner = nlp(text)

In [None]:
# Show the input text, its tokenization and the pipeline output.
pc("Text", text)
pc("Tokenized text", text_tokenized)
pc("NER", text_ner)