In this notebook, we use the 🤗 Transformers library to fine-tune the SpanBERT/spanbert-large-cased model on the ade_corpus_v2 dataset. The goal is for the fine-tuned model to perform Named Entity Recognition (NER), identifying both Adverse Drug Reactions (ADRs) and drug names.

This was originally run on an ml.p3.2xlarge instance on AWS SageMaker.

In [None]:
# !pip install torch==1.5.0

In [None]:
# ! pip install datasets transformers seqeval

In [None]:
# ! pip install spacy 

In [1]:
from datasets import Dataset, ClassLabel, Sequence, load_dataset, load_metric
import numpy as np
import pandas as pd
from spacy import displacy
import transformers
from transformers import (AutoModelForTokenClassification, 
                          AutoTokenizer, 
                          DataCollatorForTokenClassification,
                          pipeline,
                          TrainingArguments, 
                          Trainer)

# Dataset Exploration

We use the Ade_corpus_v2_drug_ade_relation subset of the ade_corpus_v2 dataset, which provides labeled spans for drug names and adverse effects.

See dataset page here: https://huggingface.co/datasets/ade_corpus_v2

In [2]:
datasets = load_dataset("ade_corpus_v2", "Ade_corpus_v2_drug_ade_relation")

Reusing dataset ade_corpus_v2 (/home/ec2-user/.cache/huggingface/datasets/ade_corpus_v2/Ade_corpus_v2_drug_ade_relation/1.0.0/940d61334dbfac6b01ac5d00286a2122608b8dc79706ee7e9206a1edb172c559)



In [3]:
datasets

DatasetDict({
    train: Dataset({
        features: ['text', 'drug', 'effect', 'indexes'],
        num_rows: 6821
    })
})

In [4]:
datasets["train"][10]

{'text': 'BACKGROUND: How to best treat psychotic patients who have had past clozapine-induced agranulocytosis or granulocytopenia remains a problem.',
 'drug': 'clozapine',
 'effect': 'granulocytopenia',
 'indexes': {'drug': {'start_char': [67], 'end_char': [76]},
  'effect': {'start_char': [104], 'end_char': [120]}}}

# Dataset Consolidation

On closer inspection, the dataset repeats a sentence once for every (drug, adverse effect) pair it contains. For example, this sentence appears seven times:

{'text': 'After therapy for diabetic coma with insulin (containing the preservative cresol) and electrolyte solutions was started, the patient complained of increasing myalgia, developed a high fever and respiratory and metabolic acidosis and lost consciousness.', 'drug': 'insulin', 'effect': 'increasing myalgia', 'indexes': {'drug': {'start_char': [37], 'end_char': [44]}, 'effect': {'start_char': [147], 'end_char': [165]}}}
{'text': 'After therapy for diabetic coma with insulin (containing the preservative cresol) and electrolyte solutions was started, the patient complained of increasing myalgia, developed a high fever and respiratory and metabolic acidosis and lost consciousness.', 'drug': 'cresol', 'effect': 'lost consciousness', 'indexes': {'drug': {'start_char': [74], 'end_char': [80]}, 'effect': {'start_char': [233], 'end_char': [251]}}}
{'text': 'After therapy for diabetic coma with insulin (containing the preservative cresol) and electrolyte solutions was started, the patient complained of increasing myalgia, developed a high fever and respiratory and metabolic acidosis and lost consciousness.', 'drug': 'cresol', 'effect': 'high fever', 'indexes': {'drug': {'start_char': [74], 'end_char': [80]}, 'effect': {'start_char': [179], 'end_char': [189]}}}
{'text': 'After therapy for diabetic coma with insulin (containing the preservative cresol) and electrolyte solutions was started, the patient complained of increasing myalgia, developed a high fever and respiratory and metabolic acidosis and lost consciousness.', 'drug': 'insulin', 'effect': 'high fever', 'indexes': {'drug': {'start_char': [37], 'end_char': [44]}, 'effect': {'start_char': [179], 'end_char': [189]}}}
{'text': 'After therapy for diabetic coma with insulin (containing the preservative cresol) and electrolyte solutions was started, the patient complained of increasing myalgia, developed a high fever and respiratory and metabolic acidosis and lost consciousness.', 'drug': 'insulin', 'effect': 'lost consciousness', 'indexes': {'drug': {'start_char': [37], 'end_char': [44]}, 'effect': {'start_char': [233], 'end_char': [251]}}}
{'text': 'After therapy for diabetic coma with insulin (containing the preservative cresol) and electrolyte solutions was started, the patient complained of increasing myalgia, developed a high fever and respiratory and metabolic acidosis and lost consciousness.', 'drug': 'insulin', 'effect': 'respiratory and metabolic acidosis', 'indexes': {'drug': {'start_char': [37], 'end_char': [44]}, 'effect': {'start_char': [194], 'end_char': [228]}}}
{'text': 'After therapy for diabetic coma with insulin (containing the preservative cresol) and electrolyte solutions was started, the patient complained of increasing myalgia, developed a high fever and respiratory and metabolic acidosis and lost consciousness.', 'drug': 'cresol', 'effect': 'respiratory and metabolic acidosis', 'indexes': {'drug': {'start_char': [74], 'end_char': [80]}, 'effect': {'start_char': [194], 'end_char': [228]}}}

In [5]:
consolidated_dataset = {}

for row in datasets["train"]:
    drug_idx = row["indexes"]["drug"]
    effect_idx = row["indexes"]["effect"]

    if row["text"] in consolidated_dataset:
        # sentence already seen: merge this row's drug/effect pair into its entry
        entry = consolidated_dataset[row["text"]]
        entry["drug_indices_start"].update(drug_idx["start_char"])
        entry["drug_indices_end"].update(drug_idx["end_char"])
        entry["effect_indices_start"].update(effect_idx["start_char"])
        entry["effect_indices_end"].update(effect_idx["end_char"])
        entry["drug"].append(row["drug"])
        entry["effect"].append(row["effect"])

    else:
        consolidated_dataset[row["text"]] = {
            "text": row["text"],
            "drug": [row["drug"]],
            "effect": [row["effect"]],
            # sets, because the same span recurs across a sentence's duplicated rows
            # (e.g. "insulin" is paired with four different effects above)
            "drug_indices_start": set(drug_idx["start_char"]),
            "drug_indices_end": set(drug_idx["end_char"]),
            "effect_indices_start": set(effect_idx["start_char"]),
            "effect_indices_end": set(effect_idx["end_char"]),
        }
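The cell above keeps starts and ends in separate sets, so it later has to rely on sorting to re-pair them. A minimal stdlib-only sketch of an alternative, assuming rows shaped like the records shown earlier, stores each (start, end) pair as one tuple so pairing never depends on sort order. The helpers `consolidate_spans` and `make_row` are hypothetical and not part of the notebook.

```python
from collections import defaultdict


def consolidate_spans(rows):
    """Group rows by sentence text, collecting (start, end) span tuples per entity type.

    Hypothetical helper: each row is shaped like an ade_corpus_v2 record, i.e.
    {"text": str, "indexes": {"drug": {"start_char": [...], "end_char": [...]},
                              "effect": {"start_char": [...], "end_char": [...]}}}
    """
    merged = defaultdict(lambda: {"drug": set(), "effect": set()})
    for row in rows:
        entry = merged[row["text"]]
        for kind in ("drug", "effect"):
            idx = row["indexes"][kind]
            # zip keeps each start with its own end; the set drops repeated spans
            entry[kind].update(zip(idx["start_char"], idx["end_char"]))
    # sort by position so downstream code can walk the spans left to right
    return {text: {kind: sorted(spans) for kind, spans in entry.items()}
            for text, entry in merged.items()}


def make_row(text, drug_span, effect_span):
    """Build a minimal record with one drug span and one effect span."""
    return {"text": text,
            "indexes": {"drug": {"start_char": [drug_span[0]], "end_char": [drug_span[1]]},
                        "effect": {"start_char": [effect_span[0]], "end_char": [effect_span[1]]}}}


# The seven insulin/cresol rows shown above, reduced to their spans
SENTENCE = "insulin/cresol sentence"  # placeholder for the full text
pairs = [((37, 44), (147, 165)), ((74, 80), (233, 251)), ((74, 80), (179, 189)),
         ((37, 44), (179, 189)), ((37, 44), (233, 251)), ((37, 44), (194, 228)),
         ((74, 80), (194, 228))]
rows = [make_row(SENTENCE, drug, effect) for drug, effect in pairs]

result = consolidate_spans(rows)
print(result[SENTENCE])
# {'drug': [(37, 44), (74, 80)], 'effect': [(147, 165), (179, 189), (194, 228), (233, 251)]}
```

With paired tuples, the non-overlap assumption behind re-pairing sorted starts and ends is no longer needed, although `generate_row_labels` would then have to read the tuple format.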

With the dataset consolidated, we next need to assign a label to every token of each sentence. First, we convert the consolidated records into a pandas DataFrame, save it as JSON Lines, and reload it as a Hugging Face Dataset object.

In [6]:
df = pd.DataFrame(list(consolidated_dataset.values()))

In [7]:
df.head()

| | text | drug | effect | drug_indices_start | drug_indices_end | effect_indices_start | effect_indices_end |
|---|---|---|---|---|---|---|---|
| 0 | Intravenous azithromycin-induced ototoxicity. | [azithromycin] | [ototoxicity] | {12} | {24} | {33} | {44} |
| 1 | Immobilization, while Paget's bone disease was... | [dihydrotachysterol] | [increased calcium-release] | {91} | {109} | {143} | {168} |
| 2 | Unaccountable severe hypercalcemia in a patien... | [dihydrotachysterol] | [hypercalcemia] | {84} | {102} | {21} | {34} |
| 3 | METHODS: We report two cases of pseudoporphyri... | [naproxen, oxaprozin] | [pseudoporphyria, pseudoporphyria] | {58, 71} | {80, 66} | {32} | {47} |
| 4 | Naproxen, the most common offender, has been a... | [Naproxen] | [erythropoietic protoporphyria] | {0} | {8} | {134} | {163} |


In [8]:
# sets are unordered (see row 3 above: ends {80, 66}), so sort each column;
# because no two spans of the same type overlap, the i-th sorted start then
# pairs with the i-th sorted end
for col in ["drug_indices_start", "drug_indices_end",
            "effect_indices_start", "effect_indices_end"]:
    df[col] = df[col].apply(sorted)  # sorted() on a set returns a list
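Sorting starts and ends independently only yields correct pairs when spans of the same type never overlap or nest. A small sketch of a check for that assumption (the function name `pair_sorted_spans` is hypothetical) pairs the i-th start with the i-th end and raises if the result is not a list of disjoint, non-empty spans:

```python
def pair_sorted_spans(starts, ends):
    """Pair the i-th smallest start with the i-th smallest end and validate the result.

    Raises ValueError when the pairing is not a list of disjoint, non-empty spans,
    which is exactly when independent sorting would mis-pair them.
    """
    if len(starts) != len(ends):
        raise ValueError(f"{len(starts)} starts but {len(ends)} ends")
    spans = list(zip(sorted(starts), sorted(ends)))
    for start, end in spans:
        if start >= end:
            raise ValueError(f"empty or inverted span ({start}, {end})")
    # consecutive spans must not overlap
    for (_, prev_end), (next_start, _) in zip(spans, spans[1:]):
        if next_start < prev_end:
            raise ValueError(f"overlapping spans around offset {next_start}")
    return spans


# Row 3 of the table above: drug starts {58, 71}, drug ends {80, 66}
print(pair_sorted_spans({58, 71}, {80, 66}))  # [(58, 66), (71, 80)]
```

Nested spans such as (0, 20) and (5, 10) would be re-paired as (0, 10) and (5, 20), which the overlap check rejects.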

In [9]:
df.head()

| | text | drug | effect | drug_indices_start | drug_indices_end | effect_indices_start | effect_indices_end |
|---|---|---|---|---|---|---|---|
| 0 | Intravenous azithromycin-induced ototoxicity. | [azithromycin] | [ototoxicity] | [12] | [24] | [33] | [44] |
| 1 | Immobilization, while Paget's bone disease was... | [dihydrotachysterol] | [increased calcium-release] | [91] | [109] | [143] | [168] |
| 2 | Unaccountable severe hypercalcemia in a patien... | [dihydrotachysterol] | [hypercalcemia] | [84] | [102] | [21] | [34] |
| 3 | METHODS: We report two cases of pseudoporphyri... | [naproxen, oxaprozin] | [pseudoporphyria, pseudoporphyria] | [58, 71] | [66, 80] | [32] | [47] |
| 4 | Naproxen, the most common offender, has been a... | [Naproxen] | [erythropoietic protoporphyria] | [0] | [8] | [134] | [163] |


In [10]:
# write one JSON object per line (JSON Lines) so the datasets JSON loader can read it back
df.to_json("dataset.jsonl", orient="records", lines=True)
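JSON has no set type, which is one more reason to turn each index set into a sorted list before export. A small stdlib sketch (the helper `to_jsonl_line` is hypothetical, and it uses Python's json module rather than pandas) shows that a set is rejected, while the list form serializes to a single JSON Lines row:

```python
import json


def to_jsonl_line(record):
    """Serialize one record as a single JSON Lines row (no embedded newlines)."""
    return json.dumps(record)


record = {"text": "Intravenous azithromycin-induced ototoxicity.",
          "drug_indices_start": {12}}
try:
    to_jsonl_line(record)
except TypeError as err:
    print("sets are not JSON serializable:", err)

record["drug_indices_start"] = sorted(record["drug_indices_start"])  # set -> sorted list
print(to_jsonl_line(record))
# {"text": "Intravenous azithromycin-induced ototoxicity.", "drug_indices_start": [12]}
```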

In [11]:
cons_dataset = load_dataset("json", data_files="dataset.jsonl")

Using custom data configuration default-1cc54ed808739822


Downloading and preparing dataset json/default to /home/ec2-user/.cache/huggingface/datasets/json/default-1cc54ed808739822/0.0.0/c2d554c3377ea79c7664b93dc65d0803b45e3279000f993c7bfd18937fd7f426...



Dataset json downloaded and prepared to /home/ec2-user/.cache/huggingface/datasets/json/default-1cc54ed808739822/0.0.0/c2d554c3377ea79c7664b93dc65d0803b45e3279000f993c7bfd18937fd7f426. Subsequent calls will reuse this data.



In [12]:
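# ade_corpus_v2 ships only a train split, so we hold out our own test split.
# train_test_split defaults to test_size=0.25 and shuffles without a fixed seed;
# pass e.g. seed=42 for a reproducible split.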
cons_dataset = cons_dataset["train"].train_test_split()

In [13]:
cons_dataset

DatasetDict({
    train: Dataset({
        features: ['text', 'drug', 'effect', 'drug_indices_start', 'drug_indices_end', 'effect_indices_start', 'effect_indices_end'],
        num_rows: 3203
    })
    test: Dataset({
        features: ['text', 'drug', 'effect', 'drug_indices_start', 'drug_indices_end', 'effect_indices_start', 'effect_indices_end'],
        num_rows: 1068
    })
})
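Because `train_test_split` was called with its defaults, a quarter of the unique sentences is held out. A quick arithmetic check, assuming (as we understand the datasets library to do) that a fractional test size is rounded up and the training split takes the remainder:

```python
import math

n_rows_before = 6821          # rows in Ade_corpus_v2_drug_ade_relation
n_sentences = 3203 + 1068     # unique sentences after consolidation
test_size = 0.25              # train_test_split default

n_test = math.ceil(test_size * n_sentences)  # 1067.75 rounds up to 1068
n_train = n_sentences - n_test

print(n_sentences, n_train, n_test)  # 4271 3203 1068
print(f"{n_rows_before / n_sentences:.2f} rows per sentence on average")  # 1.60
```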

# Token Labeling

In [14]:
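# BIO scheme: B- marks the first token of an entity, I- the following tokens, O everything else
label_list = ['O', 'B-DRUG', 'I-DRUG', 'B-EFFECT', 'I-EFFECT']

custom_seq = Sequence(feature=ClassLabel(num_classes=len(label_list), names=label_list))

# Note: assigning into .features does not create an "ner_tags" column; the label ids
# used for training are written to the "labels" column by generate_row_labels below.
cons_dataset["train"].features["ner_tags"] = custom_seq
cons_dataset["test"].features["ner_tags"] = custom_seq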
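The model only ever predicts indices into `label_list`, which is why the pipeline output later in the notebook reports `LABEL_0` to `LABEL_4`. A common pattern, sketched here as an assumption rather than something this notebook does, is to build `id2label`/`label2id` maps and pass them to `from_pretrained` so predictions carry readable names:

```python
label_list = ['O', 'B-DRUG', 'I-DRUG', 'B-EFFECT', 'I-EFFECT']

# integer id <-> BIO tag name, in the same order as label_list
id2label = dict(enumerate(label_list))
label2id = {label: idx for idx, label in id2label.items()}

print(id2label[3], label2id["I-DRUG"])  # B-EFFECT 2

# Hypothetical usage (not run in this notebook): store the maps in the model config
# model = AutoModelForTokenClassification.from_pretrained(
#     model_checkpoint, num_labels=len(label_list), id2label=id2label, label2id=label2id)
```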

In [15]:
model_checkpoint = "SpanBERT/spanbert-large-cased"

In [16]:
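# Note: the labeling output below shows "Ampicillin" tokenized as "am", "##pic", ...,
# i.e. this tokenizer appears to lowercase its input even though the checkpoint is cased;
# passing do_lower_case=False may be worth trying.
tokenizer = AutoTokenizer.from_pretrained(model_checkpoint)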

In [17]:
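def generate_row_labels(row, verbose=False):
    """Given a row from the consolidated `Ade_corpus_v2_drug_ade_relation` dataset,
    tokenizes its text and generates integer BIO label ids for drug and effect entities.

    Assumes span boundaries coincide with token boundaries: an entity is opened only by
    a token that starts exactly at a span's start offset and closed only by a token that
    ends exactly at the span's end offset, so misaligned spans are mislabeled silently.
    """

    text = row["text"]

    labels = []
    label = "O"
    prefix = ""
    
    # while iterating through tokens, advance these to walk through the sorted drug and effect spans
    drug_index = 0
    effect_index = 0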
    tokens = tokenizer(text, return_offsets_mapping=True)

    for n in range(len(tokens["input_ids"])):
        offset_start, offset_end = tokens["offset_mapping"][n]

        # should only happen for [CLS] and [SEP]
        if offset_end - offset_start == 0:
            labels.append(-100)
            continue
        
        if drug_index < len(row["drug_indices_start"]) and offset_start == row["drug_indices_start"][drug_index]:
            label = "DRUG"
            prefix = "B-"

        elif effect_index < len(row["effect_indices_start"]) and offset_start == row["effect_indices_start"][effect_index]:
            label = "EFFECT"
            prefix = "B-"
        
        labels.append(label_list.index(f"{prefix}{label}"))
            
        if drug_index < len(row["drug_indices_end"]) and offset_end == row["drug_indices_end"][drug_index]:
            label = "O"
            prefix = ""
            drug_index += 1
            
        elif effect_index < len(row["effect_indices_end"]) and offset_end == row["effect_indices_end"][effect_index]:
            label = "O"
            prefix = ""
            effect_index += 1

        # need to transition "inside" if we just entered an entity
        if prefix == "B-":
            prefix = "I-"
    
    if verbose:
        print(f"{row}\n")
        orig = tokenizer.convert_ids_to_tokens(tokens["input_ids"])
        for n in range(len(labels)):
            print(orig[n], labels[n])
    tokens["labels"] = labels
    
    return tokens

In [18]:
# sanity check: print each token next to its label id for one training example

generate_row_labels(cons_dataset["train"][2], verbose=True)

{'text': 'Ampicillin-associated seizures.', 'drug': ['Ampicillin'], 'effect': ['seizures'], 'drug_indices_start': [0], 'drug_indices_end': [10], 'effect_indices_start': [22], 'effect_indices_end': [30]}

[CLS] -100
am 1
##pic 2
##ill 2
##in 2
- 0
associated 0
seizure 3
##s 4
. 0
[SEP] -100


{'input_ids': [101, 1821, 20437, 7956, 1394, 118, 2628, 20752, 1116, 119, 102], 'token_type_ids': [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], 'attention_mask': [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1], 'offset_mapping': [(0, 0), (0, 2), (2, 5), (5, 8), (8, 10), (10, 11), (11, 21), (22, 29), (29, 30), (30, 31), (0, 0)], 'labels': [-100, 1, 2, 2, 2, 0, 0, 3, 4, 0, -100]}
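`generate_row_labels` opens an entity only on an exact start-offset match and closes it only on an exact end-offset match. A hedged, stdlib-only alternative (the function `bio_from_offsets` is hypothetical) labels a token as part of a span whenever their character ranges overlap, and gives B- to the first overlapping token. Run on the offsets printed above, it reproduces the same label ids:

```python
LABEL_LIST = ["O", "B-DRUG", "I-DRUG", "B-EFFECT", "I-EFFECT"]


def bio_from_offsets(offsets, spans, ignore_index=-100):
    """Assign a BIO label id to each token from its (start, end) character offsets.

    spans: list of (start, end, kind) character spans, kind in {"DRUG", "EFFECT"}.
    A token belongs to a span if their character ranges overlap, so entity
    boundaries that fall inside a word piece are still labeled.
    """
    started = set()  # indices of spans that have already emitted their B- tag
    labels = []
    for tok_start, tok_end in offsets:
        if tok_start == tok_end:  # zero-width offsets: [CLS], [SEP]
            labels.append(ignore_index)
            continue
        tag = "O"
        for i, (span_start, span_end, kind) in enumerate(spans):
            if tok_start < span_end and tok_end > span_start:  # ranges overlap
                tag = ("I-" if i in started else "B-") + kind
                started.add(i)
                break
        labels.append(LABEL_LIST.index(tag))
    return labels


# Offsets and spans copied from the "Ampicillin-associated seizures." example above
offsets = [(0, 0), (0, 2), (2, 5), (5, 8), (8, 10), (10, 11), (11, 21),
           (22, 29), (29, 30), (30, 31), (0, 0)]
spans = [(0, 10, "DRUG"), (22, 30, "EFFECT")]
print(bio_from_offsets(offsets, spans))
# [-100, 1, 2, 2, 2, 0, 0, 3, 4, 0, -100]
```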

In [19]:
labeled_dataset = cons_dataset.map(generate_row_labels)


In [20]:
labeled_dataset

DatasetDict({
    train: Dataset({
        features: ['text', 'drug', 'effect', 'drug_indices_start', 'drug_indices_end', 'effect_indices_start', 'effect_indices_end', 'input_ids', 'token_type_ids', 'attention_mask', 'offset_mapping', 'labels'],
        num_rows: 3203
    })
    test: Dataset({
        features: ['text', 'drug', 'effect', 'drug_indices_start', 'drug_indices_end', 'effect_indices_start', 'effect_indices_end', 'input_ids', 'token_type_ids', 'attention_mask', 'offset_mapping', 'labels'],
        num_rows: 1068
    })
})

# SpanBERT Model Fine-Tuning

We are now ready to fine-tune SpanBERT on our dataset. This section is adapted from the 🤗 token classification example notebook: https://github.com/huggingface/notebooks/blob/master/examples/token_classification.ipynb

In [21]:
task = "ner"  # only used below to name the output directory
batch_size = 16

In [22]:
model = AutoModelForTokenClassification.from_pretrained(model_checkpoint, num_labels=len(label_list))

Some weights of BertForTokenClassification were not initialized from the model checkpoint at SpanBERT/spanbert-large-cased and are newly initialized: ['classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [23]:
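model_name = model_checkpoint.split("/")[-1]
args = TrainingArguments(
    f"{model_name}-finetuned-{task}",  # output directory
    evaluation_strategy="epoch",  # named eval_strategy in recent transformers releases
    learning_rate=1e-5,
    per_device_train_batch_size=batch_size,
    per_device_eval_batch_size=batch_size,
    num_train_epochs=5,
    weight_decay=0.05,
    # log every step; the "Training Loss" column in the epoch table is then the
    # most recent single-batch loss, so it fluctuates from epoch to epoch
    logging_steps=1,
)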

In [24]:
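# pads each batch dynamically; labels are padded with -100 so padding is ignored by the loss and metrics
data_collator = DataCollatorForTokenClassification(tokenizer)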

In [25]:
# load_metric is deprecated in newer datasets releases; there, use evaluate.load("seqeval")
metric = load_metric("seqeval")

In [26]:
def compute_metrics(eval_pred):
    logits, labels = eval_pred
    predictions = np.argmax(logits, axis=2)

    # Remove ignored positions (label -100: special tokens and padding)
    true_predictions = [
        [label_list[p] for (p, l) in zip(prediction, label) if l != -100]
        for prediction, label in zip(predictions, labels)
    ]
    true_labels = [
        [label_list[l] for (p, l) in zip(prediction, label) if l != -100]
        for prediction, label in zip(predictions, labels)
    ]

    # seqeval's precision/recall/F1 are entity-level; its accuracy is token-level
    results = metric.compute(predictions=true_predictions, references=true_labels)
    return {
        "precision": results["overall_precision"],
        "recall": results["overall_recall"],
        "f1": results["overall_f1"],
        "accuracy": results["overall_accuracy"],
    }
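The -100 filter in `compute_metrics` drops padding positions added by the data collator as well as `[CLS]` and `[SEP]`. A plain-Python sketch of the same argmax-and-filter step on a toy batch (the function `align_predictions` is hypothetical; it mirrors the NumPy code above and is not Trainer internals):

```python
LABEL_LIST = ["O", "B-DRUG", "I-DRUG", "B-EFFECT", "I-EFFECT"]


def align_predictions(logits, label_ids, label_names=LABEL_LIST):
    """Turn per-token logits and gold ids into parallel tag lists, skipping -100 positions."""
    all_preds, all_refs = [], []
    for seq_logits, seq_labels in zip(logits, label_ids):
        preds, refs = [], []
        for token_logits, gold in zip(seq_logits, seq_labels):
            if gold == -100:  # special token or collator padding: not scored
                continue
            # argmax; like np.argmax, ties resolve to the first index
            best = max(range(len(token_logits)), key=token_logits.__getitem__)
            preds.append(label_names[best])
            refs.append(label_names[gold])
        all_preds.append(preds)
        all_refs.append(refs)
    return all_preds, all_refs


# One toy sequence: [CLS], two drug word pieces, [SEP], one padding position
logits = [[[5, 0, 0, 0, 0], [0, 3, 1, 0, 0], [0, 1, 4, 0, 0], [6, 0, 0, 0, 0], [9, 0, 0, 0, 0]]]
label_ids = [[-100, 1, 2, -100, -100]]
print(align_predictions(logits, label_ids))
# ([['B-DRUG', 'I-DRUG']], [['B-DRUG', 'I-DRUG']])
```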

In [27]:
trainer = Trainer(
    model,
    args,
    train_dataset=labeled_dataset["train"],
    eval_dataset=labeled_dataset["test"],
    data_collator=data_collator,
    tokenizer=tokenizer,
    compute_metrics=compute_metrics,
)

In [28]:
trainer.train()

The following columns in the training set  don't have a corresponding argument in `BertForTokenClassification.forward` and have been ignored: effect_indices_start, effect_indices_end, drug_indices_start, drug, drug_indices_end, text, effect, offset_mapping.
***** Running training *****
  Num examples = 3203
  Num Epochs = 5
  Instantaneous batch size per device = 16
  Total train batch size (w. parallel, distributed & accumulation) = 16
  Gradient Accumulation steps = 1
  Total optimization steps = 1005


| Epoch | Training Loss | Validation Loss | Precision | Recall | F1 | Accuracy |
|---|---|---|---|---|---|---|
| 1 | 0.079 | 0.201275 | 0.803863 | 0.890775 | 0.84509 | 0.94769 |
| 2 | 0.15 | 0.176926 | 0.815992 | 0.9 | 0.85594 | 0.950132 |
| 3 | 0.1274 | 0.162108 | 0.849773 | 0.899631 | 0.873992 | 0.955414 |
| 4 | 0.0631 | 0.16765 | 0.849689 | 0.90738 | 0.877587 | 0.955386 |
| 5 | 0.0303 | 0.169118 | 0.849723 | 0.905535 | 0.876742 | 0.956039 |


The following columns in the evaluation set  don't have a corresponding argument in `BertForTokenClassification.forward` and have been ignored: effect_indices_start, effect_indices_end, drug_indices_start, drug, drug_indices_end, text, effect, offset_mapping.
***** Running Evaluation *****
  Num examples = 1068
  Batch size = 16
The following columns in the evaluation set  don't have a corresponding argument in `BertForTokenClassification.forward` and have been ignored: effect_indices_start, effect_indices_end, drug_indices_start, drug, drug_indices_end, text, effect, offset_mapping.
***** Running Evaluation *****
  Num examples = 1068
  Batch size = 16
Saving model checkpoint to spanbert-large-cased-finetuned-ner/checkpoint-500
Configuration saved in spanbert-large-cased-finetuned-ner/checkpoint-500/config.json
Model weights saved in spanbert-large-cased-finetuned-ner/checkpoint-500/pytorch_model.bin
tokenizer config file saved in spanbert-large-cased-finetuned-ner/checkpoint-500/toke

TrainOutput(global_step=1005, training_loss=0.18620062940082147, metrics={'train_runtime': 300.6498, 'train_samples_per_second': 53.268, 'train_steps_per_second': 3.343, 'total_flos': 1932024043823868.0, 'train_loss': 0.18620062940082147, 'epoch': 5.0})
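The step count in the log follows from the batch size: the last batch of each epoch is partial, so an epoch takes ceil(3203 / 16) steps, and the checkpoint-500 directory is consistent with the default `save_steps` of 500. A quick check that also recomputes the reported throughput from `train_runtime`, assuming samples per second counts every example once per epoch:

```python
import math

num_examples = 3203
batch_size = 16
num_epochs = 5

steps_per_epoch = math.ceil(num_examples / batch_size)  # 3203 = 200 * 16 + 3, so 201 steps
total_steps = steps_per_epoch * num_epochs
print(steps_per_epoch, total_steps)  # 201 1005

train_runtime = 300.6498  # seconds, from TrainOutput above
samples_per_second = num_examples * num_epochs / train_runtime
steps_per_second = total_steps / train_runtime
print(f"{samples_per_second:.3f} samples/s, {steps_per_second:.3f} steps/s")
```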

In [29]:
predictions, labels, _ = trainer.predict(labeled_dataset["test"])
predictions = np.argmax(predictions, axis=2)

# Remove ignored index (special tokens)
true_predictions = [
    [label_list[p] for (p, l) in zip(prediction, label) if l != -100]
    for prediction, label in zip(predictions, labels)
]
true_labels = [
    [label_list[l] for (p, l) in zip(prediction, label) if l != -100]
    for prediction, label in zip(predictions, labels)
]

results = metric.compute(predictions=true_predictions, references=true_labels)
results

The following columns in the test set  don't have a corresponding argument in `BertForTokenClassification.forward` and have been ignored: effect_indices_start, effect_indices_end, drug_indices_start, drug, drug_indices_end, text, effect, offset_mapping.
***** Running Prediction *****
  Num examples = 1068
  Batch size = 16


{'DRUG': {'precision': 0.9330357142857143,
  'recall': 0.9661016949152542,
  'f1': 0.9492808478425435,
  'number': 1298},
 'EFFECT': {'precision': 0.7772020725388601,
  'recall': 0.8498583569405099,
  'f1': 0.8119079837618403,
  'number': 1412},
 'overall_precision': 0.8497229916897507,
 'overall_recall': 0.9055350553505535,
 'overall_f1': 0.8767416934619507,
 'overall_accuracy': 0.9560389628830261}
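The overall scores are micro-averages over entity types. Back-calculating integer counts from the per-type values above (our inference, e.g. DRUG recall 0.9661 × 1298 gold entities = 1254 true positives; seqeval does not print these counts) reproduces both the per-type and overall numbers, and shows that overall precision is pulled toward the EFFECT value because the model predicted more EFFECT entities:

```python
# (true positives, predicted entities, gold entities) per entity type.
# NOTE: back-calculated from the precision/recall/"number" values above.
counts = {"DRUG": (1254, 1344, 1298), "EFFECT": (1200, 1544, 1412)}


def prf(tp, n_pred, n_gold):
    """Precision, recall and F1 from raw entity counts."""
    precision = tp / n_pred
    recall = tp / n_gold
    f1 = 2 * precision * recall / (precision + recall)
    return precision, recall, f1


for kind, (tp, n_pred, n_gold) in counts.items():
    print(kind, [round(x, 4) for x in prf(tp, n_pred, n_gold)])

# micro-average: pool the counts across types, then compute once
tp, n_pred, n_gold = (sum(c[i] for c in counts.values()) for i in range(3))
overall_p, overall_r, overall_f1 = prf(tp, n_pred, n_gold)
print("overall", [round(x, 4) for x in (overall_p, overall_r, overall_f1)])
```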

# See Model Outputs

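We wrap the fine-tuned model in a token-classification pipeline so we can run it on arbitrary text. Because the model was created without label names, the pipeline reports generic labels LABEL_0 to LABEL_4, whose numbers are indices into label_list.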

In [35]:
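# grouped_entities=True merges adjacent tokens with the same predicted label;
# newer transformers releases express this as aggregation_strategy="simple"
effect_ner_model = pipeline(task="ner", model=model, tokenizer=tokenizer, device=0, grouped_entities=True)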

In [36]:
# an example from our held-out test split
effect_ner_model(labeled_dataset["test"][4]["text"])

[{'entity_group': 'LABEL_0',
  'score': 0.97856647,
  'word': 'possible',
  'start': 0,
  'end': 8},
 {'entity_group': 'LABEL_3',
  'score': 0.96306676,
  'word': 'se',
  'start': 9,
  'end': 11},
 {'entity_group': 'LABEL_4',
  'score': 0.9704819,
  'word': '##rotonin syndrome',
  'start': 11,
  'end': 27},
 {'entity_group': 'LABEL_0',
  'score': 0.97992325,
  'word': 'associated with',
  'start': 28,
  'end': 43},
 {'entity_group': 'LABEL_1',
  'score': 0.97092986,
  'word': 'c',
  'start': 44,
  'end': 45},
 {'entity_group': 'LABEL_2',
  'score': 0.9740207,
  'word': '##lomipramine',
  'start': 45,
  'end': 56},
 {'entity_group': 'LABEL_0',
  'score': 0.9799778,
  'word': 'after withdrawal of',
  'start': 57,
  'end': 76},
 {'entity_group': 'LABEL_1',
  'score': 0.9302365,
  'word': 'c',
  'start': 77,
  'end': 78},
 {'entity_group': 'LABEL_2',
  'score': 0.9093829,
  'word': '##lozapine',
  'start': 78,
  'end': 86},
 {'entity_group': 'LABEL_0',
  'score': 0.98005694,
  'word': '.',


Next, we try the first few examples from the Medications section of Wikipedia's Adverse effect article and visualize the predictions with spaCy's displaCy visualizer:

https://en.wikipedia.org/wiki/Adverse_effect#Medications

In [32]:
tokens = effect_ner_model("having fever after taking paracetamol")
tokens

[{'entity_group': 'LABEL_0',
  'score': 0.9698605,
  'word': 'having',
  'start': 0,
  'end': 6},
 {'entity_group': 'LABEL_3',
  'score': 0.9575445,
  'word': 'fever',
  'start': 7,
  'end': 12},
 {'entity_group': 'LABEL_0',
  'score': 0.9792889,
  'word': 'after taking',
  'start': 13,
  'end': 25},
 {'entity_group': 'LABEL_1',
  'score': 0.9705644,
  'word': 'para',
  'start': 26,
  'end': 30},
 {'entity_group': 'LABEL_2',
  'score': 0.97412354,
  'word': '##cetamol',
  'start': 30,
  'end': 37}]
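Because the labels are generic `LABEL_k` names, the grouped output keeps the B- and I- pieces of one entity apart ("para" and "##cetamol" above). A stdlib sketch of a post-processing step (the function `merge_bio_groups` is hypothetical) maps each group back to `label_list` and merges a B-X group with the I-X groups that follow it:

```python
LABEL_LIST = ["O", "B-DRUG", "I-DRUG", "B-EFFECT", "I-EFFECT"]


def merge_bio_groups(groups, label_names=LABEL_LIST):
    """Merge pipeline groups tagged B-X followed by I-X into single entities of type X.

    Assumes generic "LABEL_<k>" names whose <k> indexes label_names, as in the output above.
    A stray I-X with no preceding B-X or I-X starts a new entity.
    """
    entities = []
    prev_kind = None  # entity type of the previous group, None after an "O" group
    for group in groups:
        label_id = int(group["entity_group"].rsplit("_", 1)[1])  # "LABEL_3" -> 3
        name = label_names[label_id]
        if name == "O":
            prev_kind = None
            continue
        prefix, kind = name.split("-", 1)
        if prefix == "I" and prev_kind == kind:
            entities[-1]["end"] = group["end"]  # extend the open entity
        else:
            entities.append({"start": group["start"], "end": group["end"], "label": kind})
        prev_kind = kind
    return entities


sentence = "having fever after taking paracetamol"
groups = [  # the pipeline output above, scores and words omitted
    {"entity_group": "LABEL_0", "start": 0, "end": 6},
    {"entity_group": "LABEL_3", "start": 7, "end": 12},
    {"entity_group": "LABEL_0", "start": 13, "end": 25},
    {"entity_group": "LABEL_1", "start": 26, "end": 30},
    {"entity_group": "LABEL_2", "start": 30, "end": 37},
]
for ent in merge_bio_groups(groups):
    print(ent["label"], sentence[ent["start"]:ent["end"]])
# EFFECT fever
# DRUG paracetamol
```

The resulting dicts carry the start, end and label keys that displaCy's manual mode reads, so they could stand in for the per-piece B-/I- entities when rendering. Parsing the number after the underscore also avoids reading only the last character of `entity_group`, which would break with ten or more labels.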

In [33]:
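def visualize_entities(sentence):
    tokens = effect_ner_model(sentence)
    entities = []
    
    for token in tokens:
        # entity_group is "LABEL_<k>"; the last character is k (valid while there are fewer than 10 labels)
        label = int(token["entity_group"][-1])
        if label != 0:  # skip "O" groups
            token["label"] = label_list[label]  # displaCy reads "start", "end" and "label"
            entities.append(token)
    
    params = [{"text": sentence,
               "ents": entities,
               "title": None}]
    
    # In Jupyter, displacy.render displays the HTML inline (and returns None);
    # outside a notebook it returns the HTML markup as a string
    return displacy.render(params, style="ent", manual=True, options={
        "colors": {
            "B-DRUG": "#f08080",
            "I-DRUG": "#f08080",
            "B-EFFECT": "#9bddff",
            "I-EFFECT": "#9bddff",
        },
    })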

In [34]:
examples = [
    "Abortion, miscarriage or uterine hemorrhage associated with misoprostol (Cytotec), a labor-inducing drug.",
    "Addiction to many sedatives and analgesics, such as diazepam, morphine, etc.",
    "Birth defects associated with thalidomide",
    "Bleeding of the intestine associated with aspirin therapy",
    "Cardiovascular disease associated with COX-2 inhibitors (i.e. Vioxx)",
    "Deafness and kidney failure associated with gentamicin (an antibiotic)",
    "having fever after taking paracetamol"
]

for example in examples:
    visualize_entities(example)
    print(f"{'*' * 50}\n")

[displaCy renders each example sentence with its predicted B-/I- DRUG and EFFECT tokens highlighted, each followed by a row of 50 asterisks; the rendered HTML is not preserved in this export.]
