In [1]:
import utils
from transformers import AutoTokenizer, AutoModelForTokenClassification, TrainingArguments, Trainer, DataCollatorForTokenClassification
import os, json
from tqdm import tqdm

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# Reading all files .txt and .ann from MACCROBAT directory
maccrobat_dir = "../data/MACCROBAT/"
all_file_ids = set()
for filename in os.listdir(maccrobat_dir):
    if filename.endswith(".txt"):
        file_id = filename[:-4]
        all_file_ids.add(file_id)
print(f"Total number of files: {len(all_file_ids)}")

Total number of files: 200


In [3]:
# Find frequency of unique labels in the dataset
label_freq = {}
for file_id in tqdm(all_file_ids):
    ann_filepath = os.path.join(maccrobat_dir, file_id + ".ann")
    with open(ann_filepath, "r", encoding="utf-8") as f:
        for line in f:
            if line.startswith("T"):
                parts = line.strip().split("\t")
                label_info = parts[1]
                label = label_info.split(" ")[0]
                if label not in label_freq:
                    label_freq[label] = 0
                label_freq[label] += 1

# Sort labels by frequency in descending order
label_freq = dict(sorted(label_freq.items(), key=lambda item: item[1], reverse=True))

# Top 10 most frequent labels, and age and sex labels
most_frequent_labels = list(label_freq.items())[:10]
most_frequent_labels.append(("Age", label_freq.get("Age", 0)))
most_frequent_labels.append(("Sex", label_freq.get("Sex", 0)))
print("Top 10 most frequent labels:")
for label, freq in most_frequent_labels:
    print(f"Label: {label}, Frequency: {freq}")

100%|██████████| 200/200 [00:00<00:00, 2799.66it/s]

Top 10 most frequent labels:
Label: Diagnostic_procedure, Frequency: 4598
Label: Sign_symptom, Frequency: 3382
Label: Biological_structure, Frequency: 2953
Label: Detailed_description, Frequency: 2920
Label: Lab_value, Frequency: 2848
Label: Disease_disorder, Frequency: 1362
Label: Medication, Frequency: 1080
Label: Therapeutic_procedure, Frequency: 1036
Label: Date, Frequency: 735
Label: Clinical_event, Frequency: 626
Label: Age, Frequency: 206
Label: Sex, Frequency: 191





In [4]:
# Tokenize and convert to BIO tagging format
tokenizer = AutoTokenizer.from_pretrained("medicalai/ClinicalBERT")

examples = []
for file_id in tqdm(all_file_ids):
    text, entities = utils.read_and_extract_maccrobat_file(file_id)
    encoding, tags = utils.bio_tagging(text, entities, tokenizer)
    examples.append({
        "input_ids": encoding["input_ids"],
        "attention_mask": encoding["attention_mask"],
        "labels": tags
    })

# Only consider most frequent labels for training
frequent_labels_set = set([label for label, freq in most_frequent_labels])
BIO_labels = []
for label in frequent_labels_set:
    BIO_labels.append("B-" + label)
    BIO_labels.append("I-" + label)


# Only consider most frequent labels for training
for example in examples:
    example["labels"] = [label if label in BIO_labels else "O" for label in example["labels"]]
label_list = list(BIO_labels)
label_list.sort()
print(f"Labels considered for training: {label_list}")

100%|██████████| 200/200 [00:01<00:00, 131.39it/s]

Labels considered for training: ['B-Age', 'B-Biological_structure', 'B-Clinical_event', 'B-Date', 'B-Detailed_description', 'B-Diagnostic_procedure', 'B-Disease_disorder', 'B-Lab_value', 'B-Medication', 'B-Sex', 'B-Sign_symptom', 'B-Therapeutic_procedure', 'I-Age', 'I-Biological_structure', 'I-Clinical_event', 'I-Date', 'I-Detailed_description', 'I-Diagnostic_procedure', 'I-Disease_disorder', 'I-Lab_value', 'I-Medication', 'I-Sex', 'I-Sign_symptom', 'I-Therapeutic_procedure']





In [5]:
# Map BIO tags to IDs
unique_tags = sorted(set(tag for ex in examples for tag in ex["labels"]))
tag2id = {tag: i for i, tag in enumerate(unique_tags)}
for ex in examples:
    ex["labels"] = [tag2id[tag] for tag in ex["labels"]]

In [6]:
# Split train/eval
from sklearn.model_selection import train_test_split
train_data, eval_data = train_test_split(examples, test_size=0.2, random_state=42)

In [7]:
# Create Datasets
class NERDataset:
    def __init__(self, data): self.data = data
    def __len__(self): return len(self.data)
    def __getitem__(self, idx): return {
        "input_ids": self.data[idx]["input_ids"],
        "attention_mask": self.data[idx]["attention_mask"],
        "labels": self.data[idx]["labels"]
    }

train_dataset = NERDataset(train_data)
eval_dataset = NERDataset(eval_data)

data_collator = DataCollatorForTokenClassification(tokenizer)

In [8]:
from seqeval.metrics import classification_report, f1_score, accuracy_score, precision_score, recall_score

def compute_metrics(p):
    # Convert predictions and labels from IDs to tag strings
    predictions = p.predictions.argmax(-1)
    true_labels = p.label_ids
    # Remove ignored index (usually -100) and convert to tag names
    pred_tags = [
        [unique_tags[pred] for (pred, label) in zip(pred_seq, label_seq) if label != -100]
        for pred_seq, label_seq in zip(predictions, true_labels)
    ]
    true_tags = [
        [unique_tags[label] for (pred, label) in zip(pred_seq, label_seq) if label != -100]
        for pred_seq, label_seq in zip(predictions, true_labels)
    ]
    return {
        "precision": precision_score(true_tags, pred_tags),
        "recall": recall_score(true_tags, pred_tags),
        "f1": f1_score(true_tags, pred_tags),
        "accuracy": accuracy_score(true_tags, pred_tags)
    }

In [9]:
model = AutoModelForTokenClassification.from_pretrained("medicalai/ClinicalBERT", num_labels=len(unique_tags))
model_dir = "../models/rebalanced_finetuned_ClinicalBERT"

training_args = TrainingArguments(
    output_dir=model_dir,
    learning_rate=5e-4,
    per_device_train_batch_size=8,
    per_device_eval_batch_size=8,
    num_train_epochs=10,
    logging_steps=5,                # Log every 50 steps
    logging_dir=f"{model_dir}/logs", # Directory for logs
    weight_decay=0.001,
)

trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=train_dataset,
    eval_dataset=eval_dataset,
    data_collator=data_collator,
    compute_metrics=compute_metrics
)

trainer.train()
trainer.save_model(model_dir)
tokenizer.save_pretrained(model_dir)

Some weights of DistilBertForTokenClassification were not initialized from the model checkpoint at medicalai/ClinicalBERT and are newly initialized: ['classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


Step,Training Loss
5,2.686
10,1.3917
15,1.104
20,0.8908
25,0.7102
30,0.6553
35,0.5916
40,0.5622
45,0.3987
50,0.3874


('../models/rebalanced_finetuned_ClinicalBERT\\tokenizer_config.json',
 '../models/rebalanced_finetuned_ClinicalBERT\\special_tokens_map.json',
 '../models/rebalanced_finetuned_ClinicalBERT\\vocab.txt',
 '../models/rebalanced_finetuned_ClinicalBERT\\added_tokens.json',
 '../models/rebalanced_finetuned_ClinicalBERT\\tokenizer.json')

In [10]:
eval_results = trainer.evaluate()
print(eval_results)

{'eval_loss': 1.0644633769989014, 'eval_precision': 0.5305275637225845, 'eval_recall': 0.6349769421780773, 'eval_f1': 0.5780720167931537, 'eval_accuracy': 0.8072265625, 'eval_runtime': 0.8357, 'eval_samples_per_second': 47.864, 'eval_steps_per_second': 5.983, 'epoch': 10.0}


In [13]:
import torch
from seqeval.metrics import classification_report

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)

# Get predictions and true labels for the whole eval set
predictions, true_labels = [], []
for example in eval_dataset:
    input_ids = torch.tensor([example["input_ids"]]).to(device)
    attention_mask = torch.tensor([example["attention_mask"]]).to(device)
    with torch.no_grad():
        outputs = model(input_ids, attention_mask=attention_mask)
        pred_ids = outputs.logits.argmax(-1).squeeze().tolist()
    true_label_ids = example["labels"]
    # Remove ignored index (-100) if present
    pred_tags = [unique_tags[pid] for pid, lid in zip(pred_ids, true_label_ids) if lid != -100]
    true_tags = [unique_tags[lid] for lid in true_label_ids if lid != -100]
    predictions.append(pred_tags)
    true_labels.append(true_tags)

# Calculate accuracy
correct = sum(p == t for pred_seq, true_seq in zip(predictions, true_labels) for p, t in zip(pred_seq, true_seq))
total = sum(len(true_seq) for true_seq in true_labels)
accuracy = correct / total
print(f"Overall Accuracy: {accuracy:.4f}")

# Print per-label metrics
print("Evaluation Set Metrics:")
print(classification_report(true_labels, predictions))

Overall Accuracy: 0.8072
Evaluation Set Metrics:
                       precision    recall  f1-score   support

                  Age       1.00      0.95      0.98        42
 Biological_structure       0.57      0.72      0.64       420
       Clinical_event       0.71      0.71      0.71        78
                 Date       0.63      0.72      0.67        65
 Detailed_description       0.30      0.44      0.36       396
 Diagnostic_procedure       0.65      0.71      0.68       594
     Disease_disorder       0.29      0.35      0.32       147
            Lab_value       0.54      0.61      0.58       331
           Medication       0.78      0.82      0.80        84
                  Sex       1.00      0.97      0.99        40
         Sign_symptom       0.58      0.67      0.62       497
Therapeutic_procedure       0.33      0.46      0.38       125

            micro avg       0.53      0.63      0.58      2819
            macro avg       0.61      0.68      0.64      2819
    

In [14]:
example = eval_dataset[0]
input_ids = torch.tensor([example["input_ids"]]).to(device)
attention_mask = torch.tensor([example["attention_mask"]]).to(device)

model.eval()
with torch.no_grad():
    outputs = model(input_ids, attention_mask=attention_mask)
    pred_ids = outputs.logits.argmax(-1).squeeze().tolist()
    pred_tags = [unique_tags[pid] for pid in pred_ids]

tokens = tokenizer.convert_ids_to_tokens(example["input_ids"])
true_label_ids = example["labels"]
true_tags = [unique_tags[lid] for lid in true_label_ids]

for tok, true, pred in zip(tokens, true_tags, pred_tags):
    print(f"{tok:15} | True: {true:15} | Pred: {pred:15}")

[CLS]           | True: O               | Pred: O              
a               | True: O               | Pred: O              
76              | True: B-Age           | Pred: B-Age          
-               | True: I-Age           | Pred: I-Age          
year            | True: I-Age           | Pred: I-Age          
-               | True: I-Age           | Pred: I-Age          
old             | True: I-Age           | Pred: I-Age          
woman           | True: B-Sex           | Pred: B-Sex          
presented       | True: B-Clinical_event | Pred: B-Clinical_event
to              | True: O               | Pred: O              
our             | True: O               | Pred: O              
hospital        | True: O               | Pred: O              
with            | True: O               | Pred: O              
com             | True: O               | Pred: O              
##plaints       | True: O               | Pred: O              
of              | True: O             

Problem: Some labels don't make sense in the context of BIO tagging scheme, leading to lower accuracy.

Solution: These cases can be corrected by simple heuristics. For example, "I-" cannot exist without a preceding "B-" or "B-" following another "I-" can be corrected to "I-". This can help improve accuracy.

Cases to consider:
- A "B-" tag that is immediately followed by an "I-" of a different type should be changed to "B-" of the same type as the following "I-".
- "O" tags between "B-" and "I-" of the same type should be changed to "I-".

### Label Annealing Heuristics Implementation

In [15]:
# Let's Implement Label Annealing Heuristics to correct some BIO tagging errors
# Cases to consider:
# - A "B-" tag that is immediately followed by an "I-" of a different type should be changed to "B-" of the same type as the following "I-".
# - "O" tags between "B-" and "I-" of the same type should be changed to "I-".

def apply_label_annealing(tags):
    corrected_tags = tags.copy()
    # - A "B-" tag that is immediately followed by an "I-" of a different type should be changed to "B-" of the same type as the following "I-".
    for i in range(len(corrected_tags) - 1):
        if corrected_tags[i].startswith("B-") and corrected_tags[i+1].startswith("I-"):
            curr_type = corrected_tags[i][2:]
            next_type = corrected_tags[i+1][2:]
            if curr_type != next_type:
                corrected_tags[i] = "B-" + next_type
    # - All "O" tags between "B-"/"I-" and "I-" of the same type should be changed to "I-".
    i = 0
    while i < len(corrected_tags) - 2:
        if corrected_tags[i].startswith(("B-", "I-")):
            start_type = corrected_tags[i][2:]
            j = i + 1
            while j < len(corrected_tags) and corrected_tags[j] == "O":
                j += 1
            if j < len(corrected_tags) and corrected_tags[j].startswith("I-"):
                end_type = corrected_tags[j][2:]
                if start_type == end_type:
                    for k in range(i + 1, j):
                        corrected_tags[k] = "I-" + start_type
            i = j
        else:
            i += 1
    return corrected_tags


In [16]:
from seqeval.metrics import classification_report

# Get predictions and true labels for the whole eval set
predictions, true_labels = [], []
for example in eval_dataset:
    input_ids = torch.tensor([example["input_ids"]]).to(device)
    attention_mask = torch.tensor([example["attention_mask"]]).to(device)
    with torch.no_grad():
        outputs = model(input_ids, attention_mask=attention_mask)
        pred_ids = outputs.logits.argmax(-1).squeeze().tolist()
    true_label_ids = example["labels"]
    # Remove ignored index (-100) if present
    pred_tags = [unique_tags[pid] for pid, lid in zip(pred_ids, true_label_ids) if lid != -100]
    true_tags = [unique_tags[lid] for lid in true_label_ids if lid != -100]
    predictions.append(apply_label_annealing(pred_tags))
    true_labels.append(true_tags)

# Calculate accuracy
correct = sum(p == t for pred_seq, true_seq in zip(predictions, true_labels) for p, t in zip(pred_seq, true_seq))
total = sum(len(true_seq) for true_seq in true_labels)
accuracy = correct / total
print(f"Overall Accuracy: {accuracy:.4f}")

# Print per-label metrics
print("Evaluation Set Metrics:")
print(classification_report(true_labels, predictions))

Overall Accuracy: 0.8032
Evaluation Set Metrics:
                       precision    recall  f1-score   support

                  Age       1.00      0.95      0.98        42
 Biological_structure       0.60      0.72      0.65       420
       Clinical_event       0.71      0.71      0.71        78
                 Date       0.66      0.72      0.69        65
 Detailed_description       0.32      0.44      0.37       396
 Diagnostic_procedure       0.69      0.72      0.70       594
     Disease_disorder       0.32      0.35      0.33       147
            Lab_value       0.56      0.61      0.59       331
           Medication       0.80      0.83      0.81        84
                  Sex       1.00      0.97      0.99        40
         Sign_symptom       0.60      0.67      0.63       497
Therapeutic_procedure       0.35      0.46      0.39       125

            micro avg       0.56      0.64      0.60      2819
            macro avg       0.63      0.68      0.65      2819
    

In [17]:
# Check outputs with label annealing applied
import torch

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)

example = eval_dataset[0]
input_ids = torch.tensor([example["input_ids"]]).to(device)
attention_mask = torch.tensor([example["attention_mask"]]).to(device)

model.eval()
with torch.no_grad():
    outputs = model(input_ids, attention_mask=attention_mask)
    pred_ids = outputs.logits.argmax(-1).squeeze().tolist()
    pred_tags = [unique_tags[pid] for pid in pred_ids]
    pred_tags = apply_label_annealing(pred_tags)

tokens = tokenizer.convert_ids_to_tokens(example["input_ids"])
true_label_ids = example["labels"]
true_tags = [unique_tags[lid] for lid in true_label_ids]

for tok, true, pred in zip(tokens, true_tags, pred_tags):
    print(f"{tok:15} | True: {true:15} | Pred: {pred:15}")

[CLS]           | True: O               | Pred: O              
a               | True: O               | Pred: O              
76              | True: B-Age           | Pred: B-Age          
-               | True: I-Age           | Pred: I-Age          
year            | True: I-Age           | Pred: I-Age          
-               | True: I-Age           | Pred: I-Age          
old             | True: I-Age           | Pred: I-Age          
woman           | True: B-Sex           | Pred: B-Sex          
presented       | True: B-Clinical_event | Pred: B-Clinical_event
to              | True: O               | Pred: O              
our             | True: O               | Pred: O              
hospital        | True: O               | Pred: O              
with            | True: O               | Pred: O              
com             | True: O               | Pred: O              
##plaints       | True: O               | Pred: O              
of              | True: O             