## 3. NER Finetuning

### Prepare data for NER finetuning

In [5]:
import utils
from transformers import AutoTokenizer, AutoModelForTokenClassification, TrainingArguments, Trainer, DataCollatorForTokenClassification
import os, json
from tqdm import tqdm

In [2]:
# Reading all files .txt and .ann from MACCROBAT directory
maccrobat_dir = "../data/MACCROBAT/"
all_file_ids = set()
for filename in os.listdir(maccrobat_dir):
    if filename.endswith(".txt"):
        file_id = filename[:-4]
        all_file_ids.add(file_id)
print(f"Total number of files: {len(all_file_ids)}")

Total number of files: 200


In [None]:
# Tokenize and convert to BIO tagging format
tokenizer = AutoTokenizer.from_pretrained("medicalai/ClinicalBERT")

examples = []
for file_id in tqdm(all_file_ids):
    text, entities = utils.read_and_extract_maccrobat_file(file_id)
    encoding, tags = utils.bio_tagging(text, entities, tokenizer)
    examples.append({
        "input_ids": encoding["input_ids"],
        "attention_mask": encoding["attention_mask"],
        "labels": tags
    })

100%|██████████| 200/200 [00:01<00:00, 115.06it/s]


In [7]:
# Map BIO tags to IDs
unique_tags = sorted(set(tag for ex in examples for tag in ex["labels"]))
tag2id = {tag: i for i, tag in enumerate(unique_tags)}
for ex in examples:
    ex["labels"] = [tag2id[tag] for tag in ex["labels"]]

In [8]:
# Split train/eval
from sklearn.model_selection import train_test_split
train_data, eval_data = train_test_split(examples, test_size=0.2, random_state=42)

In [14]:
# Create Datasets
class NERDataset:
    def __init__(self, data): self.data = data
    def __len__(self): return len(self.data)
    def __getitem__(self, idx): return {
        "input_ids": self.data[idx]["input_ids"],
        "attention_mask": self.data[idx]["attention_mask"],
        "labels": self.data[idx]["labels"]
    }

train_dataset = NERDataset(train_data)
eval_dataset = NERDataset(eval_data)

data_collator = DataCollatorForTokenClassification(tokenizer)

### Training

In [13]:
from seqeval.metrics import classification_report, f1_score, accuracy_score, precision_score, recall_score

def compute_metrics(p):
    # Convert predictions and labels from IDs to tag strings
    predictions = p.predictions.argmax(-1)
    true_labels = p.label_ids
    # Remove ignored index (usually -100) and convert to tag names
    pred_tags = [
        [unique_tags[pred] for (pred, label) in zip(pred_seq, label_seq) if label != -100]
        for pred_seq, label_seq in zip(predictions, true_labels)
    ]
    true_tags = [
        [unique_tags[label] for (pred, label) in zip(pred_seq, label_seq) if label != -100]
        for pred_seq, label_seq in zip(predictions, true_labels)
    ]
    return {
        "precision": precision_score(true_tags, pred_tags),
        "recall": recall_score(true_tags, pred_tags),
        "f1": f1_score(true_tags, pred_tags),
        "accuracy": accuracy_score(true_tags, pred_tags)
    }

In [38]:
model = AutoModelForTokenClassification.from_pretrained("medicalai/ClinicalBERT", num_labels=len(unique_tags))
model_dir = "../models/finetuned_ClinicalBERT"

training_args = TrainingArguments(
    output_dir=model_dir,
    learning_rate=5e-4,
    per_device_train_batch_size=8,
    per_device_eval_batch_size=8,
    num_train_epochs=10,
    logging_steps=5,                # Log every 50 steps
    logging_dir=f"{model_dir}/logs", # Directory for logs
    weight_decay=0.01,
)

trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=train_dataset,
    eval_dataset=eval_dataset,
    data_collator=data_collator,
    compute_metrics=compute_metrics
)

trainer.train()
trainer.save_model(model_dir)
tokenizer.save_pretrained(model_dir)

Some weights of DistilBertForTokenClassification were not initialized from the model checkpoint at medicalai/ClinicalBERT and are newly initialized: ['classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


Step,Training Loss
5,2.9783
10,1.739
15,1.386
20,1.1157
25,0.9176
30,0.8676
35,0.7913
40,0.7585
45,0.5781
50,0.5425


('../models/finetuned_ClinicalBERT\\tokenizer_config.json',
 '../models/finetuned_ClinicalBERT\\special_tokens_map.json',
 '../models/finetuned_ClinicalBERT\\vocab.txt',
 '../models/finetuned_ClinicalBERT\\added_tokens.json',
 '../models/finetuned_ClinicalBERT\\tokenizer.json')

In [40]:
eval_results = trainer.evaluate()
print(eval_results)

{'eval_loss': 1.2086138725280762, 'eval_precision': 0.5343298707879403, 'eval_recall': 0.62322695035461, 'eval_f1': 0.5753648888282636, 'eval_accuracy': 0.776953125, 'eval_runtime': 2.361, 'eval_samples_per_second': 16.942, 'eval_steps_per_second': 2.118, 'epoch': 10.0}


In [37]:
import torch

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)

example = eval_dataset[0]
input_ids = torch.tensor([example["input_ids"]]).to(device)
attention_mask = torch.tensor([example["attention_mask"]]).to(device)

model.eval()
with torch.no_grad():
    outputs = model(input_ids, attention_mask=attention_mask)
    pred_ids = outputs.logits.argmax(-1).squeeze().tolist()
    pred_tags = [unique_tags[pid] for pid in pred_ids]

tokens = tokenizer.convert_ids_to_tokens(example["input_ids"])
true_label_ids = example["labels"]
true_tags = [unique_tags[lid] for lid in true_label_ids]

for tok, true, pred in zip(tokens, true_tags, pred_tags):
    print(f"{tok:15} | True: {true:15} | Pred: {pred:15}")

[CLS]           | True: O               | Pred: O              
a               | True: O               | Pred: O              
38              | True: B-Age           | Pred: B-Age          
-               | True: I-Age           | Pred: I-Age          
year            | True: I-Age           | Pred: I-Age          
-               | True: I-Age           | Pred: I-Age          
old             | True: I-Age           | Pred: I-Age          
woman           | True: B-Sex           | Pred: B-Sex          
presented       | True: B-Clinical_event | Pred: B-Clinical_event
to              | True: O               | Pred: O              
our             | True: O               | Pred: O              
emergency       | True: B-Nonbiological_location | Pred: B-Nonbiological_location
department      | True: I-Nonbiological_location | Pred: I-Nonbiological_location
with            | True: O               | Pred: O              
severe          | True: B-Severity      | Pred: B-Severity     
ab

### Per Label Metrics

In [43]:
# Get predictions and true labels for the whole eval set
predictions, true_labels = [], []
for example in train_dataset:
    input_ids = torch.tensor([example["input_ids"]]).to(device)
    attention_mask = torch.tensor([example["attention_mask"]]).to(device)
    with torch.no_grad():
        outputs = model(input_ids, attention_mask=attention_mask)
        pred_ids = outputs.logits.argmax(-1).squeeze().tolist()
    true_label_ids = example["labels"]
    # Remove ignored index (-100) if present
    pred_tags = [unique_tags[pid] for pid, lid in zip(pred_ids, true_label_ids) if lid != -100]
    true_tags = [unique_tags[lid] for lid in true_label_ids if lid != -100]
    predictions.append(pred_tags)
    true_labels.append(true_tags)

# Print per-label metrics
print("Taraining Set Metrics:")
print(classification_report(true_labels, predictions))

Taraining Set Metrics:
                        precision    recall  f1-score   support

              Activity       0.94      0.89      0.91        72
        Administration       1.00      1.00      1.00        65
                   Age       0.99      0.97      0.98       165
                  Area       0.96      0.96      0.96        28
  Biological_attribute       0.00      0.00      0.00         9
  Biological_structure       0.99      0.99      0.99      1719
        Clinical_event       0.96      0.99      0.98       315
                 Color       0.95      1.00      0.97        38
           Coreference       0.90      0.91      0.91       184
                  Date       0.98      0.97      0.98       358
  Detailed_description       0.99      0.99      0.99      1593
  Diagnostic_procedure       1.00      1.00      1.00      2474
      Disease_disorder       0.99      0.99      0.99       641
              Distance       0.97      1.00      0.99        78
                

In [44]:
from seqeval.metrics import classification_report

# Get predictions and true labels for the whole eval set
predictions, true_labels = [], []
for example in eval_dataset:
    input_ids = torch.tensor([example["input_ids"]]).to(device)
    attention_mask = torch.tensor([example["attention_mask"]]).to(device)
    with torch.no_grad():
        outputs = model(input_ids, attention_mask=attention_mask)
        pred_ids = outputs.logits.argmax(-1).squeeze().tolist()
    true_label_ids = example["labels"]
    # Remove ignored index (-100) if present
    pred_tags = [unique_tags[pid] for pid, lid in zip(pred_ids, true_label_ids) if lid != -100]
    true_tags = [unique_tags[lid] for lid in true_label_ids if lid != -100]
    predictions.append(pred_tags)
    true_labels.append(true_tags)

# Print per-label metrics
print("Evaluation Set Metrics:")
print(classification_report(true_labels, predictions))

Evaluation Set Metrics:
                        precision    recall  f1-score   support

              Activity       0.00      0.00      0.00         4
        Administration       0.52      0.55      0.53        22
                   Age       1.00      0.95      0.97        41
                  Area       0.08      0.10      0.09        10
  Biological_attribute       0.00      0.00      0.00         1
  Biological_structure       0.58      0.73      0.65       404
        Clinical_event       0.81      0.68      0.74        84
                 Color       0.50      1.00      0.67         1
           Coreference       0.12      0.10      0.11        41
                  Date       0.65      0.80      0.71        64
  Detailed_description       0.37      0.44      0.40       424
  Diagnostic_procedure       0.64      0.70      0.67       670
      Disease_disorder       0.39      0.43      0.41       187
              Distance       0.29      0.50      0.37        18
               

  _warn_prf(average, modifier, msg_start, len(result))


Problem: There are a lot of labels in NER tasks, and some of them are very rare. When calculating per-label metrics (precision, recall, F1-score), the rare labels can lead to misleading results because they may not have enough examples to provide reliable statistics.

Solution: Find out the current label frequency distribution and reblance/remove the rare labels.