In [1]:
import os
import sys

# Make the project's helper packages (util/, experiments/) importable.
# Note: these paths are relative to the kernel's working directory.
sys.path[0:0] = ['../util', '../experiments']

# Disable Weights & Biases logging (if installed) and turn off HF tokenizers parallelism
os.environ.update({"WANDB_DISABLED": "true", "TOKENIZERS_PARALLELISM": "false"})

In [2]:
from pathlib import Path
import transformers
import datasets
from transformers import AutoModelForTokenClassification, AutoTokenizer, Trainer, TrainingArguments, pipeline, DataCollatorForTokenClassification, EarlyStoppingCallback, trainer_utils
from huggingface_utils import load_custom_dataset, LabelAligner, compute_metrics, eval_on_test_set
from run_experiment import get_train_args
from convert_annotations import entity_values

In [3]:
# Keep the notebook output readable: remove the default 🤗 transformers log handler
# and only let 🤗 datasets report errors.
transformers.logging.disable_default_handler()
datasets.logging.set_verbosity_error()

# Parameters

In [4]:
# Entity granularity: 'fine', or 'coarse' to look at high-level entity classes only
level = 'fine'
# Span length: 'long', or 'short' to consider short spans ignoring specifications
spans = 'long'

In [5]:
# (entity granularity, span length) -> Hydra experiment config selected in the compose step
config_files = dict([
    (('coarse', 'short'), '01_ggponc_coarse_short.yaml'),
    (('fine', 'short'), '02_ggponc_fine_short.yaml'),
    (('coarse', 'long'), '03_ggponc_coarse_long.yaml'),
    (('fine', 'long'), '04_ggponc_fine_long.yaml'),
])

In [6]:
import hydra
from hydra import compose, initialize

# Reset Hydra's global state first so this cell can be re-executed in the same kernel
hydra.core.global_hydra.GlobalHydra.instance().clear()

experiments_dir = Path('..') / 'experiments'
initialize(config_path=experiments_dir, job_name='foo')

# Load the YAML matching the chosen (level, spans) pair. The overrides pin cuda=0 and set
# link=false; the meaning of `link` is defined by the experiment config (not visible here).
config_overrides = ['cuda=0', 'link=false']
config = compose(config_name=config_files[(level, spans)], overrides=config_overrides)

In [7]:
# Paths of the train / dev / test splits, as declared in the experiment config
train_file, dev_file, test_file = (
    config['train_dataset'],
    config['dev_dataset'],
    config['test_dataset'],
)

# Set up the IOB-encoded dataset with train / dev / test splits

In [8]:
# Returns a DatasetDict with train / dev / test splits plus the list of tag strings
# (the tag set is selected by the config's `task` entry).
dataset, tags = load_custom_dataset(
    train=train_file,
    dev=dev_file,
    test=test_file,
    tag_strings=config['task'],
)

  0%|          | 0/3 [00:00<?, ?it/s]

In [9]:
checkpoint_name = config['base_model_checkpoint']
tokenizer = AutoTokenizer.from_pretrained(checkpoint_name)
# Only "fast" tokenizers are accepted. This is presumably because LabelAligner relies on
# fast-tokenizer features such as offset mappings (note the offset_mapping column below); confirm.
assert isinstance(tokenizer, transformers.PreTrainedTokenizerFast)

In [10]:
# Helper from huggingface_utils bound to the tokenizer above. It presumably maps the word-level
# IOB tags onto the tokenizer's sub-word tokens; it is used in the next cell via
# tokenize_and_align_labels.
label_aligner = LabelAligner(tokenizer)

In [11]:
def _tokenize_and_align(examples):
    """Tokenize one batch of examples and align its tags via LabelAligner.

    Adds the tokenizer outputs and a `labels` column (see the dataset summary below).
    Whether non-first sub-tokens are labelled is controlled by the config's
    `label_all_tokens` flag.
    """
    return label_aligner.tokenize_and_align_labels(examples, config['label_all_tokens'])


dataset = dataset.map(_tokenize_and_align, batched=True)

  0%|          | 0/47 [00:00<?, ?ba/s]

  0%|          | 0/10 [00:00<?, ?ba/s]

  0%|          | 0/11 [00:00<?, ?ba/s]

In [12]:
# Integer class id -> tag string, in the order returned by load_custom_dataset
id2label = {label_id: tag for label_id, tag in enumerate(tags)}
id2label

{0: 'O',
 1: 'B-Other_Finding',
 2: 'I-Other_Finding',
 3: 'B-Diagnosis_or_Pathology',
 4: 'I-Diagnosis_or_Pathology',
 5: 'B-Therapeutic',
 6: 'I-Therapeutic',
 7: 'B-Diagnostic',
 8: 'I-Diagnostic',
 9: 'B-Nutrient_or_Body_Substance',
 10: 'I-Nutrient_or_Body_Substance',
 11: 'B-External_Substance',
 12: 'I-External_Substance',
 13: 'B-Clinical_Drug',
 14: 'I-Clinical_Drug'}

In [13]:
# Show split sizes and the columns added by tokenization / label alignment
dataset

DatasetDict({
    train: Dataset({
        features: ['_tags', 'attention_mask', 'fname', 'input_ids', 'labels', 'offset_mapping', 'sentence_id', 'special_tokens_mask', 'tags', 'token_type_ids', 'tokens'],
        num_rows: 46291
    })
    dev: Dataset({
        features: ['_tags', 'attention_mask', 'fname', 'input_ids', 'labels', 'offset_mapping', 'sentence_id', 'special_tokens_mask', 'tags', 'token_type_ids', 'tokens'],
        num_rows: 9685
    })
    test: Dataset({
        features: ['_tags', 'attention_mask', 'fname', 'input_ids', 'labels', 'offset_mapping', 'sentence_id', 'special_tokens_mask', 'tags', 'token_type_ids', 'tokens'],
        num_rows: 10743
    })
})

# Configure and train a 🤗 token classification model

In [14]:
from run_experiment import get_train_args

In [15]:
num_train_epochs = 10 # Remove this line to train for default value of 100 epochs

In [16]:
config['num_train_epochs'] = num_train_epochs

In [17]:
# Build the training arguments (used as Trainer(args=...)) from the experiment config.
# cp_path is presumably the checkpoint directory. report_to=[] is presumably forwarded to
# TrainingArguments to disable tracking integrations; confirm in run_experiment.get_train_args.
training_args = get_train_args(
    cp_path='../ner_results',
    run_name='ner_baseline',
    report_to=[],
    **config,
    resume_from_checkpoint=None,
)

INFO:run_experiment:ner_baseline


In [18]:
def model_init():
    """Build a fresh token-classification model on top of the configured base checkpoint.

    Passed to ``Trainer(model_init=...)``. The duplicated initialization warnings in the
    outputs show it is invoked both when the Trainer is built and again by ``tr.train()``.

    Returns:
        An ``AutoModelForTokenClassification`` model with ``len(tags)`` output labels whose
        config carries consistent ``id2label`` and ``label2id`` mappings.
    """
    return AutoModelForTokenClassification.from_pretrained(
        config['base_model_checkpoint'],
        num_labels=len(tags),
        id2label=id2label,
        # Also pass the inverse mapping. Otherwise label2id may stay at generic "LABEL_i"
        # defaults that disagree with id2label in the saved model config.
        label2id={label: label_id for label_id, label in id2label.items()},
    )

# Batch collator for token classification, built from the same tokenizer as the dataset
data_collator = DataCollatorForTokenClassification(tokenizer)

# The Trainer builds its own model through model_init, trains on the train split and
# reports per-entity precision / recall / F1 on the dev split (see the training log).
train_split = dataset["train"]
dev_split = dataset["dev"]
metrics_fn = compute_metrics(tags, True)

tr = Trainer(
    model_init=model_init,
    args=training_args,
    data_collator=data_collator,
    tokenizer=tokenizer,
    train_dataset=train_split,
    eval_dataset=dev_split,
    compute_metrics=metrics_fn,
)

Some weights of the model checkpoint at deepset/gbert-base were not used when initializing BertForTokenClassification: ['cls.predictions.transform.LayerNorm.bias', 'cls.seq_relationship.bias', 'cls.predictions.transform.LayerNorm.weight', 'cls.predictions.bias', 'cls.seq_relationship.weight', 'cls.predictions.decoder.weight', 'cls.predictions.transform.dense.weight', 'cls.predictions.transform.dense.bias']
- This IS expected if you are initializing BertForTokenClassification from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertForTokenClassification from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of BertForTokenClassification were not initialized from the model checkpoint at deepset/gb

### Train the model

In [19]:
# Fine-tune with the arguments from get_train_args. Per-epoch dev metrics are logged below.
# NOTE(review): in the logged run, dev loss is lowest after epoch 2 and rises afterwards while
# training loss keeps falling. Whether the best rather than the last checkpoint is kept
# depends on get_train_args (e.g. load_best_model_at_end); confirm.
train_result = tr.train()

Some weights of the model checkpoint at deepset/gbert-base were not used when initializing BertForTokenClassification: ['cls.predictions.transform.LayerNorm.bias', 'cls.seq_relationship.bias', 'cls.predictions.transform.LayerNorm.weight', 'cls.predictions.bias', 'cls.seq_relationship.weight', 'cls.predictions.decoder.weight', 'cls.predictions.transform.dense.weight', 'cls.predictions.transform.dense.bias']
- This IS expected if you are initializing BertForTokenClassification from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertForTokenClassification from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of BertForTokenClassification were not initialized from the model checkpoint at deepset/gb

Epoch,Training Loss,Validation Loss,Clinical Drug Precision,Clinical Drug Recall,Clinical Drug F1,Clinical Drug Number,Diagnosis Or Pathology Precision,Diagnosis Or Pathology Recall,Diagnosis Or Pathology F1,Diagnosis Or Pathology Number,Diagnostic Precision,Diagnostic Recall,Diagnostic F1,Diagnostic Number,External Substance Precision,External Substance Recall,External Substance F1,External Substance Number,Nutrient Or Body Substance Precision,Nutrient Or Body Substance Recall,Nutrient Or Body Substance F1,Nutrient Or Body Substance Number,Other Finding Precision,Other Finding Recall,Other Finding F1,Other Finding Number,Therapeutic Precision,Therapeutic Recall,Therapeutic F1,Therapeutic Number,Overall Precision,Overall Recall,Overall F1,Overall Accuracy
1,0.314,0.312082,0.53517,0.646601,0.585632,1412,0.658379,0.734192,0.694222,7069,0.637019,0.687915,0.661489,3467,0.4,0.232558,0.294118,129,0.442396,0.595041,0.507489,484,0.497943,0.576479,0.534341,5459,0.654701,0.704014,0.678463,5855,0.602624,0.673298,0.636004,0.894012
2,0.226,0.296495,0.557651,0.733003,0.633415,1412,0.700743,0.74664,0.722964,7069,0.63671,0.723392,0.677289,3467,0.48913,0.348837,0.40724,129,0.560484,0.57438,0.567347,484,0.517188,0.614581,0.561694,5459,0.633065,0.750811,0.686929,5855,0.617823,0.707644,0.65969,0.900326
3,0.166,0.321142,0.612698,0.683428,0.646133,1412,0.712123,0.752016,0.731526,7069,0.649029,0.732333,0.688169,3467,0.445652,0.317829,0.371041,129,0.563043,0.535124,0.548729,484,0.53971,0.627404,0.580263,5459,0.66497,0.736294,0.698817,5855,0.640109,0.70601,0.671447,0.901267
4,0.1165,0.366398,0.594857,0.737252,0.658444,1412,0.721769,0.75046,0.735835,7069,0.633719,0.741563,0.683413,3467,0.463415,0.294574,0.36019,129,0.618557,0.495868,0.550459,484,0.539909,0.630702,0.581784,5459,0.660816,0.733049,0.695061,5855,0.638892,0.70911,0.672172,0.90085
5,0.0837,0.401041,0.614788,0.73017,0.66753,1412,0.727958,0.76503,0.746034,7069,0.667712,0.73666,0.700494,3467,0.366667,0.255814,0.30137,129,0.583333,0.549587,0.565957,484,0.563791,0.620077,0.590596,5459,0.685773,0.73766,0.710771,5855,0.659296,0.711874,0.684577,0.903114
6,0.0586,0.450769,0.613928,0.711756,0.659233,1412,0.719015,0.768143,0.742767,7069,0.653124,0.732622,0.690593,3467,0.389313,0.395349,0.392308,129,0.586957,0.557851,0.572034,484,0.574038,0.598644,0.586083,5459,0.683187,0.732195,0.706843,5855,0.657292,0.705801,0.680683,0.901862
7,0.0446,0.492896,0.613213,0.723088,0.663633,1412,0.736255,0.746357,0.741272,7069,0.691201,0.706951,0.698988,3467,0.5,0.325581,0.394366,129,0.610738,0.56405,0.586466,484,0.540801,0.648287,0.589686,5459,0.704346,0.741759,0.722569,5855,0.661449,0.709738,0.684743,0.90204
8,0.0324,0.521401,0.61094,0.703966,0.654163,1412,0.727773,0.767718,0.747212,7069,0.681005,0.735218,0.707074,3467,0.521277,0.379845,0.439462,129,0.630137,0.570248,0.598698,484,0.566301,0.6329,0.597751,5459,0.693469,0.745346,0.718472,5855,0.664544,0.716817,0.689691,0.903656
9,0.0251,0.554185,0.612394,0.713881,0.659254,1412,0.728125,0.768708,0.747867,7069,0.6722,0.735795,0.702561,3467,0.43956,0.310078,0.363636,129,0.62645,0.557851,0.590164,484,0.57053,0.631251,0.599356,5459,0.701301,0.745858,0.722894,5855,0.666174,0.716901,0.690607,0.903603
10,0.019,0.579063,0.615525,0.718839,0.663182,1412,0.732114,0.770123,0.750638,7069,0.677075,0.731757,0.703355,3467,0.472527,0.333333,0.390909,129,0.62037,0.553719,0.585153,484,0.568653,0.627404,0.596586,5459,0.694581,0.750811,0.721602,5855,0.666278,0.717403,0.690896,0.903887


# Evaluate the model on the test set

In [20]:
# Direct handle to the model held by the trainer after tr.train().
# NOTE(review): `model` is not referenced by any later cell shown in this notebook.
model = tr.model

In [21]:
from transformers.pipelines.token_classification import AggregationStrategy

In [22]:
# Score the fine-tuned trainer on the held-out test split. The resulting keys look like
# "test/overall_f1" (as read in the next cell), presumably prefixed via the final "test" argument.
test_metrics = eval_on_test_set(dataset["test"], tr, tokenizer, "test")

10743it [00:07, 1397.55it/s]


In [23]:
# Overall test-set scores, right-aligned labels, two decimals
summary_lines = [
    f"{label}: {test_metrics[key]:.2f}"
    for label, key in (
        ("F1", "test/overall_f1"),
        (" P", "test/overall_precision"),
        (" R", "test/overall_recall"),
    )
]
print("\n" + "\n".join(summary_lines) + "\n")


F1: 0.70
 P: 0.68
 R: 0.72



### Detailed analysis of model performance

See notebook: [03_NER_Analysis](03_NER_Analysis.ipynb)