# Biomedical Named Entity Recognition with BioBERT

**Ready-to-run Jupyter/Colab notebook** — train, evaluate, and run inference for biomedical NER using `dmis-lab/biobert-base-cased-v1.1` (HuggingFace).

**Notebook contents**
1. Install & setup
2. Load dataset (BC5CDR example via `datasets`)
3. Preprocessing & token-label alignment (BIO)
4. Model setup (`AutoModelForTokenClassification`)
5. Training with `Trainer`
6. Evaluation with `seqeval`
7. Inference helper & demo
8. Tips for improvements and extensions

**Notes**
- This notebook expects an environment with internet access (to download models/datasets). For Colab, select a GPU runtime.
- If you're behind a firewall, download datasets and models manually and adjust paths.

In [169]:
!pip install transformers datasets seqeval accelerate evaluate



In [171]:
# Check versions
import transformers, datasets
print('transformers', transformers.__version__)
print('datasets', datasets.__version__)

transformers 4.57.1
datasets 3.6.0


In [172]:
# Load dataset (BC5CDR for chemicals/diseases) via HuggingFace datasets
from datasets import load_dataset
tner_dataset = load_dataset('tner/bc5cdr')
print(tner_dataset)

DatasetDict({
    train: Dataset({
        features: ['tokens', 'tags'],
        num_rows: 5228
    })
    validation: Dataset({
        features: ['tokens', 'tags'],
        num_rows: 5330
    })
    test: Dataset({
        features: ['tokens', 'tags'],
        num_rows: 5865
    })
})


In [173]:
# Inspect an example
example = tner_dataset['train'][0]
print('keys:', example.keys())
if 'tokens' in example:
    print('tokens sample:', example['tokens'][:40])
if 'tags' in example:
    print('tags sample:', example['tags'][:40])

keys: dict_keys(['tokens', 'tags'])
tokens sample: ['Naloxone', 'reverses', 'the', 'antihypertensive', 'effect', 'of', 'clonidine', '.']
tags sample: [1, 0, 0, 0, 0, 0, 1, 0]


## Preprocessing

In [174]:
# Preprocessing: tokenize and align labels (BIO scheme)
from transformers import BertTokenizerFast

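# NOTE: the introduction names 'dmis-lab/biobert-base-cased-v1.1', but the run recorded
# in this notebook used 'bert-base-cased'. Swap the name below to fine-tune BioBERT.
MODEL_NAME = 'bert-base-cased'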
tokenizer = BertTokenizerFast.from_pretrained(MODEL_NAME)

# Utility to align labels for tokenized inputs
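def tokenize_and_align_labels(examples):
    # The dataset is already split into words, so tokenize word by word
    tokenized_inputs = tokenizer(
        examples['tokens'],
        is_split_into_words=True,
        truncation=True,
        padding='max_length',
        max_length=256,
    )
    all_labels = examples['tags']
    new_labels = []
    for i, labels in enumerate(all_labels):
        # word_ids maps each sub-token to the index of its source word (None for special/pad tokens)
        word_ids = tokenized_inputs.word_ids(batch_index=i)
        previous_word_idx = None
        label_ids = []
        for word_idx in word_ids:
            if word_idx is None:
                # special token or padding: ignored by the loss
                label_ids.append(-100)
            elif word_idx != previous_word_idx:
                # first sub-token of a word carries the word's label
                label_ids.append(labels[word_idx])
            else:
                # remaining sub-tokens of the same word are ignored by the loss
                label_ids.append(-100)
            previous_word_idx = word_idx
        new_labels.append(label_ids)
    tokenized_inputs['labels'] = new_labels
    return tokenized_inputs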

# Map over dataset
tokenized_datasets = tner_dataset.map(tokenize_and_align_labels, batched=True, remove_columns=tner_dataset['train'].column_names)
tokenized_datasets.set_format(type='torch')
tokenized_datasets


DatasetDict({
    train: Dataset({
        features: ['input_ids', 'token_type_ids', 'attention_mask', 'labels'],
        num_rows: 5228
    })
    validation: Dataset({
        features: ['input_ids', 'token_type_ids', 'attention_mask', 'labels'],
        num_rows: 5330
    })
    test: Dataset({
        features: ['input_ids', 'token_type_ids', 'attention_mask', 'labels'],
        num_rows: 5865
    })
})
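
The alignment rule used in `tokenize_and_align_labels` can be tried without a tokenizer. The following is a minimal sketch: the `word_ids` list is a hypothetical example of what a fast tokenizer might return for two words, and `align_labels` is an illustrative name, not part of the notebook.

```python
from typing import List, Optional

IGNORE_INDEX = -100  # positions with this label are skipped by the loss


def align_labels(word_ids: List[Optional[int]], word_labels: List[int]) -> List[int]:
    """Label only the first sub-token of each word; mask every other position."""
    aligned = []
    previous = None
    for word_id in word_ids:
        if word_id is None:          # special token or padding
            aligned.append(IGNORE_INDEX)
        elif word_id != previous:    # first sub-token of a new word
            aligned.append(word_labels[word_id])
        else:                        # continuation sub-token of the same word
            aligned.append(IGNORE_INDEX)
        previous = word_id
    return aligned


# Hypothetical word_ids for: [CLS] Na ##lo ##xone reverses [SEP] [PAD]
word_ids = [None, 0, 0, 0, 1, None, None]
word_labels = [1, 0]  # tags for the words "Naloxone" and "reverses"
print(align_labels(word_ids, word_labels))
```

Only one position per word contributes to the loss and to the metrics, which keeps the evaluation at word level regardless of how many pieces the tokenizer produces.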

## Model

In [175]:
def infer_label_map(dataset):
    unique_ids = sorted(
        set(tag for ex in dataset["train"]["tags"] for tag in ex)
    )

    # If 0 is present → assume "O"
    label_names = ["O"] if 0 in unique_ids else []

    # Find entity spans
    spans = []
    for tokens, tags in zip(dataset["train"]["tokens"], dataset["train"]["tags"]):
        current = []
        for tok, tag in zip(tokens, tags):
            if tag != 0:
                current.append(tok)
            elif current:
                spans.append(tuple(current))
                current = []
        if current:
            spans.append(tuple(current))

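    # NOTE: the spans collected above are not used; the entity types and their order
    # are hard-coded. The resulting ids are only correct if they match the dataset's
    # own label2id, so check them against the tner/bc5cdr dataset card. The demo
    # output later in this notebook (e.g. 'rheumatoid' tagged I-Chemical) suggests
    # the dataset may order its labels differently.
    entity_types = ["Chemical", "Disease"]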

    for ent in entity_types:
        label_names.append(f"B-{ent}")
        label_names.append(f"I-{ent}")

    return label_names


In [176]:
# Build the label list with infer_label_map (entity types are hard-coded there; adjust if the dataset differs)
label_names = infer_label_map(tner_dataset)

print(label_names)

id_to_label = {i: label_names[i] for i in range(len(label_names))}
label_to_id = {label_names[i]: i for i in range(len(label_names))}

print(id_to_label)
print(label_to_id)

['O', 'B-Chemical', 'I-Chemical', 'B-Disease', 'I-Disease']
{0: 'O', 1: 'B-Chemical', 2: 'I-Chemical', 3: 'B-Disease', 4: 'I-Disease'}
{'O': 0, 'B-Chemical': 1, 'I-Chemical': 2, 'B-Disease': 3, 'I-Disease': 4}
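
Because `infer_label_map` hard-codes the label order, the ids above are only right if the dataset uses the same order. A safer pattern is to write the mapping down explicitly and validate it. The sketch below assumes that the `label2id` published on the `tner/bc5cdr` dataset card is the one shown; that mapping is an assumption here and should be checked against the card before use. The function names are illustrative.

```python
# ASSUMPTION: believed to match the label2id on the tner/bc5cdr dataset card.
# Verify against the card before relying on it.
ASSUMED_LABEL2ID = {
    "O": 0,
    "B-Chemical": 1,
    "B-Disease": 2,
    "I-Disease": 3,
    "I-Chemical": 4,
}


def build_label_maps(label2id):
    """Return (id2label, label2id) after checking that ids are 0..n-1 with no gaps."""
    ids = sorted(label2id.values())
    if ids != list(range(len(ids))):
        raise ValueError(f"label ids must be contiguous from 0, got {ids}")
    id2label = {idx: name for name, idx in label2id.items()}
    return id2label, dict(label2id)


def check_tags_covered(tag_sequences, id2label):
    """Raise if the data contains a tag id that the map does not know about."""
    seen = {tag for seq in tag_sequences for tag in seq}
    unknown = seen - set(id2label)
    if unknown:
        raise ValueError(f"tag ids missing from the label map: {sorted(unknown)}")


id_to_label, label_to_id = build_label_maps(ASSUMED_LABEL2ID)
# Tags of the first training example shown earlier in the notebook
check_tags_covered([[1, 0, 0, 0, 0, 0, 1, 0]], id_to_label)
print(id_to_label)
```

If the assumed mapping is the correct one, id 2 means `B-Disease` rather than `I-Chemical`, which would change both the reported entity-level metrics and the labels printed in the inference demo.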


In [177]:
# Model setup for token classification
from transformers import BertConfig, BertForTokenClassification

config = BertConfig.from_pretrained(
    MODEL_NAME,
    num_labels=len(label_names),
    id2label=id_to_label,
    label2id=label_to_id
)

model = BertForTokenClassification.from_pretrained(
    MODEL_NAME,
    config=config
)

Some weights of BertForTokenClassification were not initialized from the model checkpoint at bert-base-cased and are newly initialized: ['classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


## Training

In [178]:
# Training with HuggingFace Trainer
from transformers import TrainingArguments, Trainer, DataCollatorForTokenClassification
import numpy as np
import evaluate

# Metrics - using seqeval for entity-level metrics
metric = evaluate.load("seqeval")

def compute_metrics(p):
    predictions, labels = p
    predictions = np.argmax(predictions, axis=-1)

    true_predictions = []
    true_labels = []

    for pred, lab in zip(predictions, labels):
        filtered_preds = []
        filtered_labels = []
        for p_i, l_i in zip(pred, lab):
            # skip positions masked with -100: special tokens, padding, non-first sub-tokens
            if l_i != -100:
                filtered_preds.append(id_to_label[int(p_i)])
                filtered_labels.append(id_to_label[int(l_i)])
        true_predictions.append(filtered_preds)
        true_labels.append(filtered_labels)

    results = metric.compute(predictions=true_predictions, references=true_labels)

    overall = {
        'overall_precision': results['overall_precision'],
        'overall_recall': results['overall_recall'],
        'overall_f1': results['overall_f1'],
        'overall_accuracy': results['overall_accuracy']
    }
    return overall

batch_size = 16
args = TrainingArguments(
    output_dir = 'biobert-ner-run',
    eval_strategy = 'epoch',
    learning_rate=2e-5,
    per_device_train_batch_size=batch_size,
    per_device_eval_batch_size=batch_size,
    num_train_epochs=5,
    weight_decay=0.01,
    logging_steps=50,
    push_to_hub=False,
    save_strategy='epoch'
)

data_collator = DataCollatorForTokenClassification(tokenizer)

trainer = Trainer(
    model=model,
    args=args,
    train_dataset=tokenized_datasets['train'],
    eval_dataset=tokenized_datasets['validation'],
    data_collator=data_collator,
    processing_class=tokenizer,
    compute_metrics=compute_metrics
)

# Fine-tune the model (a GPU is recommended; the run recorded below took about 43 minutes)
trainer.train()


| Epoch | Training Loss | Validation Loss | Overall Precision | Overall Recall | Overall F1 | Overall Accuracy |
|---|---|---|---|---|---|---|
| 1 | 0.0917 | 0.089167 | 0.85334 | 0.846017 | 0.849663 | 0.970309 |
| 2 | 0.0537 | 0.092946 | 0.847368 | 0.8733 | 0.860139 | 0.971558 |
| 3 | 0.0279 | 0.10485 | 0.857166 | 0.874514 | 0.865753 | 0.97278 |
| 4 | 0.0149 | 0.117621 | 0.858666 | 0.880424 | 0.869409 | 0.973211 |
| 5 | 0.0094 | 0.124546 | 0.856929 | 0.880586 | 0.868597 | 0.972908 |




TrainOutput(global_step=1635, training_loss=0.052293134391854665, metrics={'train_runtime': 2577.2463, 'train_samples_per_second': 10.143, 'train_steps_per_second': 0.634, 'total_flos': 3415241238988800.0, 'train_loss': 0.052293134391854665, 'epoch': 5.0})
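
The gap between token accuracy (about 0.97) and entity-level F1 (about 0.87) in the log follows from how the two are counted: accuracy scores every token, while the entity-level scores count a prediction only when the whole span and its type match. The following is a simplified sketch of that idea in plain Python. It is not `seqeval`'s implementation, and the function names and toy tags are illustrative.

```python
def bio_spans(tags):
    """Return the set of (type, start, end_exclusive) spans in one BIO tag sequence."""
    spans = set()
    start, current = None, None
    for i, tag in enumerate(list(tags) + ["O"]):  # trailing "O" flushes the last span
        prefix, _, ent_type = tag.partition("-")
        continues = prefix == "I" and ent_type == current
        if current is not None and not continues:
            spans.add((current, start, i))
            start, current = None, None
        # A span opens on B-, or on an I- that has nothing to continue (lenient reading)
        if prefix == "B" or (prefix == "I" and current is None):
            start, current = i, ent_type
    return spans


def entity_prf(references, predictions):
    """Exact-match precision, recall and F1 over entity spans."""
    true_spans, pred_spans = set(), set()
    for sent_idx, (ref, pred) in enumerate(zip(references, predictions)):
        true_spans |= {(sent_idx, *s) for s in bio_spans(ref)}
        pred_spans |= {(sent_idx, *s) for s in bio_spans(pred)}
    correct = len(true_spans & pred_spans)
    precision = correct / len(pred_spans) if pred_spans else 0.0
    recall = correct / len(true_spans) if true_spans else 0.0
    f1 = 2 * precision * recall / (precision + recall) if precision + recall else 0.0
    return precision, recall, f1


# Toy example: one boundary error on a two-token disease mention
refs = [["B-Chemical", "O", "O", "B-Disease", "I-Disease"]]
preds = [["B-Chemical", "O", "O", "B-Disease", "O"]]
print(entity_prf(refs, preds))
```

In this toy case four of five tokens are right, yet only one of the two gold entities is matched exactly, so a single boundary slip costs far more at entity level than at token level.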

## Evaluate

In [187]:
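# Quick qualitative check (run after training or after loading a fine-tuned checkpoint)
from transformers import pipeline

# Token-classification pipeline.
# NOTE: `ignore_subwords` is a deprecated argument; passing it appears to override
# `aggregation_strategy`, which would explain why the output below has one entry per
# sub-token (keys 'entity' and 'index') instead of grouped entities. Dropping it and
# using aggregation_strategy="first" is the usual route to word-level grouping, but the
# grouped output uses the key 'entity_group', so the next cell would need adjusting.
nlp = pipeline(
    "token-classification",
    model=model,
    tokenizer=tokenizer,
    aggregation_strategy="simple",
    ignore_subwords=True
)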

# Demo (replace with an actual abstract)
demo_text = "Aspirin is frequently used for pain relief in patients suffering from rheumatoid arthritis."
raw = nlp(demo_text)

print(raw)

Device set to use mps:0


[{'entity': 'B-Chemical', 'score': np.float32(0.99944955), 'index': 1, 'word': 'As', 'start': 0, 'end': 2}, {'entity': 'B-Chemical', 'score': np.float32(0.8595811), 'index': 2, 'word': '##pi', 'start': 2, 'end': 4}, {'entity': 'I-Disease', 'score': np.float32(0.8810123), 'index': 3, 'word': '##rin', 'start': 4, 'end': 7}, {'entity': 'I-Chemical', 'score': np.float32(0.9987203), 'index': 8, 'word': 'pain', 'start': 31, 'end': 35}, {'entity': 'I-Chemical', 'score': np.float32(0.99668103), 'index': 14, 'word': 'r', 'start': 70, 'end': 71}, {'entity': 'B-Disease', 'score': np.float32(0.8686689), 'index': 15, 'word': '##he', 'start': 71, 'end': 73}, {'entity': 'B-Disease', 'score': np.float32(0.9250321), 'index': 16, 'word': '##uma', 'start': 73, 'end': 76}, {'entity': 'B-Disease', 'score': np.float32(0.99834096), 'index': 17, 'word': '##to', 'start': 76, 'end': 78}, {'entity': 'B-Disease', 'score': np.float32(0.99886876), 'index': 18, 'word': '##id', 'start': 78, 'end': 80}, {'entity': 'B-

In [189]:
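def merge_subwords(entities):
    """Merge WordPiece continuation pieces ('##...') into the preceding token.

    The merged entry keeps the label of its first piece and the highest score
    among its pieces. Input dicts are copied, not modified.
    """
    merged = []
    for ent in entities:
        w = ent['word']
        if w.startswith('##') and merged:
            # attach to previous token (drop '##'), extend end, keep the max score
            prev = merged[-1]
            prev['word'] += w[2:]
            prev['end'] = ent['end']
            prev['score'] = max(prev['score'], float(ent['score']))
        else:
            # copy the entry and normalize numpy floats to Python float
            item = dict(ent)
            item['score'] = float(item['score'])
            merged.append(item)
    return merged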

def display_entities(entities):
    print(f"{'Entity Type':<12} | {'Text':<30} | {'Score':<6} | Start-End")
    print("-" * 65)
    for e in entities:
        entity_type = e['entity']
        word = e['word']
        score = round(float(e['score']), 3)  # round to 3 decimals
        start, end = e['start'], e['end']
        print(f"{entity_type:<12} | {word:<30} | {score:<6} | {start}-{end}")

ents_merged = merge_subwords(raw)  # from previous step
display_entities(ents_merged)


Entity Type  | Text                           | Score  | Start-End
-----------------------------------------------------------------
B-Chemical   | Aspirin                        | 0.999  | 0-7
I-Chemical   | pain                           | 0.999  | 31-35
I-Chemical   | rheumatoid                     | 0.999  | 70-80
B-Disease    | arthritis                      | 1.0    | 81-90


In [190]:
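# Inference helper: raw text -> list of labeled text chunks
import torch

# Prefer CUDA (e.g. a Colab GPU runtime), then Apple MPS, then CPU
if torch.cuda.is_available():
    device = torch.device("cuda")
elif torch.backends.mps.is_available():
    device = torch.device("mps")
else:
    device = torch.device("cpu")
model.to(device)
model.eval()

def get_entities_from_text(text):
    """Group consecutive sub-tokens that share the same predicted label.

    Limitations: 'O' runs are returned as chunks too, pieces are joined without
    spaces, and labels are compared literally, so a B-X followed by I-X is split
    into two chunks.
    """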
    tokens = tokenizer(text, return_tensors="pt")
    tokens = {k: v.to(device) for k, v in tokens.items()}

    with torch.no_grad():
        outputs = model(**tokens)
        preds = outputs.logits.argmax(-1).squeeze().tolist()

    input_ids = tokens["input_ids"].squeeze().tolist()
    id2label = model.config.id2label

    entities = []
    current_text = ""
    current_label = None

    for token_id, pred in zip(input_ids, preds):
        token = tokenizer.convert_ids_to_tokens(token_id)
        label = id2label[pred]

        # skip special tokens
        if token in ["[CLS]", "[SEP]"]:
            continue

        # remove subword markers
        clean_token = token.replace("##", "")

        if label == current_label:
            # continuation
            current_text += clean_token
        else:
            # start a new entity
            if current_label is not None:
                entities.append({"entity": current_label, "text": current_text})
            
            # start the next one
            current_label = label
            current_text = clean_token

    # append last entity
    if current_label is not None:
        entities.append({"entity": current_label, "text": current_text})

    return entities


print(get_entities_from_text(demo_text))



[{'entity': 'B-Chemical', 'text': 'Aspi'}, {'entity': 'I-Disease', 'text': 'rin'}, {'entity': 'O', 'text': 'isfrequentlyusedfor'}, {'entity': 'I-Chemical', 'text': 'pain'}, {'entity': 'O', 'text': 'reliefinpatientssufferingfrom'}, {'entity': 'I-Chemical', 'text': 'r'}, {'entity': 'B-Disease', 'text': 'heumatoidarthritis'}, {'entity': 'O', 'text': '.'}]
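
The output above shows the limits of grouping by identical label: `O` runs come back as entities and words are glued together. A sketch of BIO-aware grouping at word level follows. It assumes word-level labels are already available (for example from the first sub-token of each word); the function name and the labels in the example are hypothetical and are not model output.

```python
def group_entities(words, labels):
    """Group word-level BIO labels into entity mentions, skipping 'O' words."""
    entities = []
    current = None
    for word, label in zip(words, labels):
        prefix, _, ent_type = label.partition("-")
        # An I- tag of the same type extends the open mention
        if prefix == "I" and current is not None and current["type"] == ent_type:
            current["words"].append(word)
            continue
        # Anything else closes the open mention, if there is one
        if current is not None:
            entities.append(current)
            current = None
        # B- opens a mention; a stray I- is treated as an opening tag as well
        if prefix in ("B", "I"):
            current = {"type": ent_type, "words": [word]}
    if current is not None:
        entities.append(current)
    return [{"entity": e["type"], "text": " ".join(e["words"])} for e in entities]


# Hypothetical word-level labels for a shortened version of the demo sentence
words = ["Aspirin", "is", "used", "for", "pain", "in", "rheumatoid", "arthritis", "."]
labels = ["B-Chemical", "O", "O", "O", "B-Disease", "O", "B-Disease", "I-Disease", "O"]
print(group_entities(words, labels))
```

Joining with spaces is a simplification; using the character offsets returned by the tokenizer to slice the original text would preserve the exact surface form.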


## Tips, improvements and next steps

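- **Training schedule**: the run above already uses 5 epochs, and validation loss rises after epoch 1 while F1 levels off near 0.87 by epoch 4. Rather than training longer, add LR warmup and keep the best checkpoint (for example `load_best_model_at_end=True` with `metric_for_best_model` in `TrainingArguments`) or stop early.
- **Domain adaptation**: further pretrain on an in-domain corpus (for example a PubMed subset) with MLM before fine-tuning.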
- **Label scheme**: consider BIOES for improved boundary detection.
- **Data augmentation**: back-translation, mention replacement.
- **Evaluation**: produce per-entity confusion matrices and error analysis.
- **Deploy**: use FastAPI + Docker or a Gradio demo for quick sharing.
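
As an illustration of the label-scheme tip, the following sketch converts BIO tags to BIOES. It assumes well-formed BIO input (every `I-X` follows a `B-X` or `I-X` of the same type), and the function name is illustrative. Switching schemes would also mean extending the label list and the model's `num_labels`.

```python
def bio_to_bioes(tags):
    """Convert one well-formed BIO tag sequence to BIOES."""
    bioes = []
    for i, tag in enumerate(tags):
        if tag == "O":
            bioes.append(tag)
            continue
        prefix, ent_type = tag.split("-", 1)
        next_tag = tags[i + 1] if i + 1 < len(tags) else "O"
        continues = next_tag == f"I-{ent_type}"  # does the mention go on?
        if prefix == "B":
            # single-token mentions become S-, multi-token mentions keep B-
            bioes.append(f"B-{ent_type}" if continues else f"S-{ent_type}")
        else:
            # the last token of a multi-token mention becomes E-
            bioes.append(f"I-{ent_type}" if continues else f"E-{ent_type}")
    return bioes


print(bio_to_bioes(["B-Chemical", "O", "B-Disease", "I-Disease", "I-Disease"]))
```

The extra `S-` and `E-` tags give the model an explicit signal for mention boundaries, at the cost of a larger label set.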

