# Biomedical Named Entity Recognition with BioBERT

**Ready-to-run Jupyter/Colab notebook** â€” train, evaluate, and run inference for biomedical NER using `dmis-lab/biobert-base-cased-v1.1` (HuggingFace).

**Notebook contents**
1. Install & setup
2. Load dataset (BC5CDR example via `datasets`)
3. Preprocessing & token-label alignment (BIO)
4. Model setup (`PubMedBERT`)
5. Training with `Trainer`
6. Evaluation
7. Inference helper & demo

**Notes**
- This notebook expects an environment with internet access (to download models/datasets). For Colab, select a GPU runtime.
- If you're behind a firewall, download datasets and models manually and adjust paths.

In [1]:
!pip install "transformers>=4.44" "datasets>=2.21" "seqeval" "torch"



## Imports

In [2]:
import os
from typing import List, Dict, Any
import transformers, datasets
import numpy as np
import torch
import torch.nn as nn

from datasets import load_dataset
from transformers import (
    AutoTokenizer,
    AutoModelForTokenClassification,
    DataCollatorForTokenClassification,
    TrainingArguments,
    Trainer,
    BertForTokenClassification
)

from seqeval.metrics import precision_score, recall_score, f1_score, classification_report
from transformers.modeling_outputs import TokenClassifierOutput


In [3]:
# Check versions
print('transformers', transformers.__version__)
print('datasets', datasets.__version__)

transformers 4.57.1
datasets 3.6.0


## Configuration

In [4]:
MODEL_NAME = "microsoft/BiomedNLP-PubMedBERT-base-uncased-abstract-fulltext"
DATASET_NAME = "tner/bc5cdr"  # pre-split, tokenized BC5CDR with tags:contentReference[oaicite:1]{index=1}

label_list = ["O", "B-Chemical", "B-Disease", "I-Disease", "I-Chemical"]
id2label = {i: l for i, l in enumerate(label_list)}
label2id = {l: i for i, l in enumerate(label_list)}
num_labels = len(label_list)

## Load dataset and tokenizer

In [5]:
# Load dataset (BC5CDR for chemicals/diseases) via HuggingFace datasets
print("Loading dataset:", DATASET_NAME)
tner_dataset = load_dataset(DATASET_NAME)

print("Splits:", tner_dataset)

Loading dataset: tner/bc5cdr
Splits: DatasetDict({
    train: Dataset({
        features: ['tokens', 'tags'],
        num_rows: 5228
    })
    validation: Dataset({
        features: ['tokens', 'tags'],
        num_rows: 5330
    })
    test: Dataset({
        features: ['tokens', 'tags'],
        num_rows: 5865
    })
})


In [6]:
# Inspect an example
example = tner_dataset['train'][0]
print('keys:', example.keys())
if 'tokens' in example:
    print('tokens sample:', example['tokens'][:40])
if 'tags' in example:
    print('tags sample:', example['tags'][:40])

keys: dict_keys(['tokens', 'tags'])
tokens sample: ['Naloxone', 'reverses', 'the', 'antihypertensive', 'effect', 'of', 'clonidine', '.']
tags sample: [1, 0, 0, 0, 0, 0, 1, 0]


In [7]:
# Preprocessing: tokenize and align labels (BIO scheme)
print("Loading tokenizer:", MODEL_NAME)
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)

Loading tokenizer: microsoft/BiomedNLP-PubMedBERT-base-uncased-abstract-fulltext


## Tokenization + label alignment

In [8]:
# Utility to align labels for tokenized inputs
def tokenize_and_align_labels(examples):
    """
    Tokenize the list of token sequences and align the BIO labels
    to the resulting wordpieces.

    examples["tokens"]: List[List[str]]
    examples["tags"]:   List[List[int]]  (indices into label_list)
    """
        
    tokenized = tokenizer(
        examples["tokens"],
        is_split_into_words=True,  # because we already have word tokens
        truncation=True,
        padding=False,
    )
    
    all_labels = examples['tags']
    new_labels = []

    for i, labels in enumerate(all_labels):
        # word_ids maps each subtoken position to its originating word index
        word_ids = tokenized.word_ids(batch_index=i)

        previous_word_id = None
        label_ids = []

        for word_id in word_ids:
            if word_id is None:
                # Special tokens (CLS, SEP, padding later)
                label_ids.append(-100)
            else:
                original_label_id = labels[word_id]

                if word_id != previous_word_id:
                    # First subtoken of the word: use original label
                    label_ids.append(original_label_id)
                else:
                    # Subsequent subtokens of the same word:
                    # convert B-* to I-* to respect BIO scheme
                    if original_label_id == label2id["B-Chemical"]:
                        label_ids.append(label2id["I-Chemical"])
                    elif original_label_id == label2id["B-Disease"]:
                        label_ids.append(label2id["I-Disease"])
                    else:
                        # For I-* or O, keep same
                        label_ids.append(original_label_id)

                previous_word_id = word_id

        new_labels.append(label_ids)

    tokenized["labels"] = new_labels
    return tokenized

In [9]:
print("Tokenizing and aligning labels...")
remove_columns = tner_dataset["train"].column_names  # ["tokens", "tags"]
tokenized_datasets = tner_dataset.map(
    tokenize_and_align_labels,
    batched=True,
    remove_columns=remove_columns,
)

tokenized_datasets

Tokenizing and aligning labels...


DatasetDict({
    train: Dataset({
        features: ['input_ids', 'token_type_ids', 'attention_mask', 'labels'],
        num_rows: 5228
    })
    validation: Dataset({
        features: ['input_ids', 'token_type_ids', 'attention_mask', 'labels'],
        num_rows: 5330
    })
    test: Dataset({
        features: ['input_ids', 'token_type_ids', 'attention_mask', 'labels'],
        num_rows: 5865
    })
})

## Data collator

In [10]:
data_collator = DataCollatorForTokenClassification(tokenizer=tokenizer)

## Model

In [11]:
class WeightedBertForTokenClassification(BertForTokenClassification):
    def __init__(self, config, label_weights=None):
        super().__init__(config)
        if label_weights is not None:
            self.register_buffer("label_weights", torch.tensor(label_weights, dtype=torch.float))
        else:
            self.label_weights = None

    def forward(
        self,
        input_ids=None,
        attention_mask=None,
        token_type_ids=None,
        labels=None,
        **kwargs,  # <-- Required to absorb Trainer extra args
    ):
        # Run BERT encoder
        outputs = self.bert(
            input_ids=input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
        )

        sequence_output = outputs[0]
        logits = self.classifier(sequence_output)

        loss = None
        if labels is not None:
            loss_fct = nn.CrossEntropyLoss(
                weight=self.label_weights,
                ignore_index=-100
            )
            loss = loss_fct(
                logits.view(-1, self.num_labels),
                labels.view(-1)
            )

        return TokenClassifierOutput(
            loss=loss,
            logits=logits,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )

In [12]:
print("Loading model:", MODEL_NAME)
from transformers import AutoConfig

config = AutoConfig.from_pretrained(
    MODEL_NAME,
    num_labels=num_labels,
    id2label=id2label,
    label2id=label2id,
)

label_weights = [1.0, 1.0, 1.5, 1.5, 1.0]  # [O, B-chem, B-dis, I-dis, I-chem]

model = WeightedBertForTokenClassification.from_pretrained(
    MODEL_NAME,
    config=config,
    label_weights=label_weights,
)


Loading model: microsoft/BiomedNLP-PubMedBERT-base-uncased-abstract-fulltext


Some weights of WeightedBertForTokenClassification were not initialized from the model checkpoint at microsoft/BiomedNLP-PubMedBERT-base-uncased-abstract-fulltext and are newly initialized: ['classifier.bias', 'classifier.weight', 'label_weights']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


## Metrics

In [15]:
def compute_metrics(p):
    """
    p is an EvalPrediction with:
    - p.predictions: np.array (batch, seq_len, num_labels)
    - p.label_ids:   np.array (batch, seq_len)
    """
    predictions, labels = p
    predictions = np.argmax(predictions, axis=2)

    true_labels: List[List[str]] = []
    true_predictions: List[List[str]] = []

    for pred_seq, label_seq in zip(predictions, labels):
        # filter out positions where label == -100
        valid_indices = label_seq != -100
        pred_seq = pred_seq[valid_indices]
        label_seq = label_seq[valid_indices]

        true_labels.append([id2label[int(l)] for l in label_seq])
        true_predictions.append([id2label[int(p_i)] for p_i in pred_seq])

    precision = precision_score(true_labels, true_predictions)
    recall = recall_score(true_labels, true_predictions)
    f1 = f1_score(true_labels, true_predictions)

    return {
        "precision": precision,
        "recall": recall,
        "f1": f1,
    }

## Training configuration

In [16]:
training_args = TrainingArguments(
    output_dir="./pubmedbert_bc5cdr_weighted",
    eval_strategy="epoch",
    save_strategy="epoch",
    metric_for_best_model="f1",
    greater_is_better=True,
    load_best_model_at_end=True,

    learning_rate=1e-5,
    num_train_epochs=4,
    per_device_train_batch_size=8,
    per_device_eval_batch_size=8,
    weight_decay=0.05,
    warmup_ratio=0.1,

    dataloader_pin_memory=False,  # for MPS
    label_smoothing_factor=0.0,
)

## Trainer

In [17]:
trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=tokenized_datasets["train"],
    eval_dataset=tokenized_datasets["validation"],
    processing_class=tokenizer,          # or tokenizer=tokenizer on older HF
    data_collator=data_collator,
    compute_metrics=compute_metrics,
)


## Training + evaluation

In [18]:
def train_and_evaluate():
    # Train
    trainer.train()

    # Evaluate on validation and test
    print("Validation metrics:")
    val_metrics = trainer.evaluate(tokenized_datasets["validation"])
    print(val_metrics)

    print("Test metrics:")
    test_metrics = trainer.evaluate(tokenized_datasets["test"])
    print(test_metrics)

    # Optional: detailed report on test set
    print("Detailed seqeval report on test set:")
    predictions, labels, _ = trainer.predict(tokenized_datasets["test"])
    predictions = np.argmax(predictions, axis=2)

    true_labels = []
    true_predictions = []

    for pred_seq, label_seq in zip(predictions, labels):
        valid_indices = label_seq != -100
        pred_seq = pred_seq[valid_indices]
        label_seq = label_seq[valid_indices]

        true_labels.append([id2label[int(l)] for l in label_seq])
        true_predictions.append([id2label[int(p_i)] for p_i in pred_seq])

    print(classification_report(true_labels, true_predictions))

In [19]:
train_and_evaluate()

Epoch,Training Loss,Validation Loss,Precision,Recall,F1
1,0.3945,0.097584,0.86983,0.877177,0.873488
2,0.0953,0.092474,0.872252,0.901991,0.886873
3,0.0662,0.104711,0.868938,0.917318,0.892473
4,0.0427,0.106975,0.87954,0.916588,0.897682


Validation metrics:


{'eval_loss': 0.10697530210018158, 'eval_precision': 0.8795397698849424, 'eval_recall': 0.9165884683557501, 'eval_f1': 0.8976820177677933, 'eval_runtime': 24.237, 'eval_samples_per_second': 219.911, 'eval_steps_per_second': 27.52, 'epoch': 4.0}
Test metrics:
{'eval_loss': 0.11681059002876282, 'eval_precision': 0.8581758580794446, 'eval_recall': 0.9074319502497706, 'eval_f1': 0.8821168425746989, 'eval_runtime': 42.8298, 'eval_samples_per_second': 136.937, 'eval_steps_per_second': 17.138, 'epoch': 4.0}
Detailed seqeval report on test set:
              precision    recall  f1-score   support

    Chemical       0.91      0.93      0.92      5385
     Disease       0.80      0.87      0.84      4424

   micro avg       0.86      0.91      0.88      9809
   macro avg       0.85      0.90      0.88      9809
weighted avg       0.86      0.91      0.88      9809



## Inference helper

In [20]:
def ner_inference(text: str, max_length: int = 256):
    """
    Run NER on a new biomedical sentence/abstract.
    Returns entities with type and char spans.
    """
    model.eval()
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model.to(device)

    # Tokenize as raw text (not pre-split)
    encoded = tokenizer(
        text,
        truncation=True,
        max_length=max_length,
        return_tensors="pt",
    ).to(device)

    with torch.no_grad():
        outputs = model(**encoded)
        logits = outputs.logits  # (1, seq_len, num_labels)
        pred_ids = torch.argmax(logits, dim=-1).cpu().numpy()[0]

    # Map subtokens back to words using tokenizer
    tokens = tokenizer.convert_ids_to_tokens(encoded["input_ids"][0])
    # We will reconstruct entity spans in a simple way: group consecutive non-"O"
    entities = []
    current_entity = None

    # skip [CLS] (0) and stop at [SEP]
    for i, (token, label_id) in enumerate(zip(tokens, pred_ids)):
        if token in [tokenizer.cls_token, tokenizer.sep_token, tokenizer.pad_token]:
            if current_entity is not None:
                entities.append(current_entity)
                current_entity = None
            if token == tokenizer.sep_token:
                break
            continue

        label = id2label[int(label_id)]
        if label == "O":
            if current_entity is not None:
                entities.append(current_entity)
                current_entity = None
            continue

        # label is B-* or I-*
        label_type = label.split("-", 1)[1]

        # Approximate char span via tokenizer offsets
        offsets = encoded.token_to_chars(i)
        if offsets is None:
            # This can happen rarely; we just skip char span
            start_char, end_char = None, None
        else:
            start_char, end_char = offsets.start, offsets.end

        if current_entity is None:
            current_entity = {
                "type": label_type,
                "text": text[start_char:end_char] if start_char is not None else token,
                "start": start_char,
                "end": end_char,
            }
        else:
            # Same type? continue span
            if current_entity["type"] == label_type:
                if start_char is not None and end_char is not None:
                    # extend span
                    current_entity["end"] = end_char
                    current_entity["text"] = text[current_entity["start"]:current_entity["end"]]
            else:
                # different type, close previous and start new
                entities.append(current_entity)
                current_entity = {
                    "type": label_type,
                    "text": text[start_char:end_char] if start_char is not None else token,
                    "start": start_char,
                    "end": end_char,
                }

    if current_entity is not None:
        entities.append(current_entity)

    return entities


In [21]:
# Example inference after training:
example = "Paracetamol can cause liver toxicity in high doses."
ents = ner_inference(example)
print("\nExample inference:")
print("Text:", example)
for e in ents:
    print(e)


Example inference:
Text: Paracetamol can cause liver toxicity in high doses.
{'type': 'Chemical', 'text': 'Paracetamol', 'start': 0, 'end': 11}
{'type': 'Disease', 'text': 'liver toxicity', 'start': 22, 'end': 36}
