In [39]:
# !pip install seqeval

In [40]:
# !pip install evaluate

In [41]:
from google.colab import drive
drive.mount('/content/drive')
base_path = '/content/drive/MyDrive/NLP/'

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


In [42]:
import os
os.environ["WANDB_DISABLED"] = "true"

In [43]:
import json
from collections import defaultdict
import numpy as np
from transformers import AutoTokenizer, AutoModelForTokenClassification
from transformers import TrainingArguments, Trainer
from transformers import EarlyStoppingCallback
from sklearn.metrics import classification_report as sk_classification_report
from sklearn.metrics import accuracy_score
from datasets import Dataset
import torch
import random
from collections import Counter
from transformers import set_seed

SEED = 42

random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)
if torch.cuda.is_available():
    torch.cuda.manual_seed_all(SEED)

torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False

set_seed(SEED)

def convert_to_iob(entries):
    samples = []
    label_map = defaultdict(list)

    for entry in entries:
        tokens = entry["tokens"]
        text = entry["text"]
        spans = entry.get("spans", [])

        labels = ["O"] * len(tokens)

        for span in spans:
            start_token = span["token_start"]
            end_token = span["token_end"]
            label = span["label"]

            labels[start_token] = f"B-{label}"
            for i in range(start_token + 1, end_token + 1):
                labels[i] = f"I-{label}"

            label_map[label].append(span)

        token_label_pairs = [
            (token["text"], label)
            for token, label in zip(tokens, labels)
        ]
        samples.append(token_label_pairs)

    return samples, label_map

entries_silver = []
entries_golden = []
with open(base_path+"dataset/synthesis/silver.jsonl", "r", encoding="utf-8") as f:
    for line in f:
        entries_silver.append(json.loads(line))

with open(base_path+"dataset/cleaned/NER/processed_merged.jsonl", "r", encoding="utf-8") as f:
    for line in f:
        entries_golden.append(json.loads(line))

silver_data, label_stats_silver = convert_to_iob(entries_silver)
golden_data, label_stats_golden = convert_to_iob(entries_golden)

# Printed label statistics
print(f"{len(label_stats_golden)} types of entities are discovered:")
for label, spans in label_stats_golden.items():
    print(f"- {label}: {len(spans)} samples")

7 types of entities are discovered:
- AGE_ONSET: 93 samples
- PATIENT: 246 samples
- HPO_TERM: 2525 samples
- GENE: 252 samples
- GENE_VARIANT: 404 samples
- AGE_FOLLOWUP: 76 samples
- AGE_DEATH: 29 samples


In [44]:
label_list = ["O"] + [
    f"{pre}-{label}"
    for label in label_stats_golden.keys()
    for pre in ["B", "I"]
]

# model_name = "bert-base-cased"
# model_name = "allenai/scibert_scivocab_uncased"
# model_name = "dmis-lab/biobert-base-cased-v1.1"
model_name = "microsoft/BiomedNLP-PubMedBERT-base-uncased-abstract"
# model_name = "prajjwal1/bert-tiny"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForTokenClassification.from_pretrained(
    model_name,
    num_labels=len(label_list),
    id2label={i: label for i, label in enumerate(label_list)},
    label2id={label: i for i, label in enumerate(label_list)}
)


Some weights of BertForTokenClassification were not initialized from the model checkpoint at microsoft/BiomedNLP-PubMedBERT-base-uncased-abstract and are newly initialized: ['classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [45]:
def encode_data(examples):
    tokenized_inputs = tokenizer(
        examples["tokens"],
        truncation=True,
        is_split_into_words=True,
        padding="max_length",
        max_length=256
    )

    label = examples["labels"]
    word_ids = tokenized_inputs.word_ids()
    previous_word_idx = None
    label_ids = []

    for word_idx in word_ids:
        if word_idx is None:
            label_ids.append(-100)
        elif word_idx != previous_word_idx:
            label_ids.append(label[word_idx])
        else:
            label_ids.append(-100)
        previous_word_idx = word_idx

    tokenized_inputs["labels"] = label_ids
    return tokenized_inputs

formatted_data_silver = [{"tokens": [t[0] for t in sample], "labels": [label_list.index(t[1]) for t in sample]}
                 for sample in silver_data]
formatted_data_golden = [{"tokens": [t[0] for t in sample], "labels": [label_list.index(t[1]) for t in sample]}
                 for sample in golden_data]

from sklearn.model_selection import train_test_split
train_set, temp_set = train_test_split(formatted_data_golden, test_size=0.3, random_state=SEED)
val_set, test_set = train_test_split(temp_set, test_size=0.5, random_state=SEED)

encoded_train = [encode_data(d) for d in train_set]
encoded_val = [encode_data(d) for d in val_set]
encoded_test = [encode_data(d) for d in test_set]

In [46]:
training_args = TrainingArguments(
    output_dir=base_path + "./model/NER",
    eval_strategy="epoch",
    num_train_epochs=15,
    per_device_train_batch_size=8,
    per_device_eval_batch_size=8,
    learning_rate=5e-5,
    weight_decay=0.05,
    warmup_ratio=0.05,
    save_strategy="epoch",
    load_best_model_at_end=True,
    seed=SEED,
    metric_for_best_model="strict_f1",
    save_total_limit=2,
    logging_strategy="epoch",
    report_to="none",
)

from evaluate import load
seqeval = load("seqeval")

from itertools import chain

def is_loose_match(true_tag, pred_tag):
    if true_tag != "O" and pred_tag != "O":
        return true_tag.split("-")[-1] == pred_tag.split("-")[-1]
    return False

def strip_prefix(tag):
    return tag.split("-")[-1] if tag != "O" else "O"

def compute_metrics(p):
    predictions, labels = p
    pred_ids = np.argmax(predictions, axis=2)

    true_labels = [[label_list[l] for l in lab if l != -100] for lab in labels]
    pred_labels = [[label_list[p] for p, l in zip(pred, lab) if l != -100]
            for pred, lab in zip(pred_ids, labels)]

    # Strict matching
    strict = seqeval.compute(predictions=pred_labels, references=true_labels)
    strict_precision = strict["overall_precision"]
    strict_recall = strict["overall_recall"]
    strict_f1 = strict["overall_f1"]

    # Loose matching
    flat_true = list(chain.from_iterable(true_labels))
    flat_pred = list(chain.from_iterable(pred_labels))

    adjusted_pred = [t if is_loose_match(t, p) else p for t, p in zip(flat_true, flat_pred)]

    flat_true_no_prefix = [strip_prefix(t) for t in flat_true]
    adjusted_pred_no_prefix = [strip_prefix(p) for p in adjusted_pred]

    labels_without_O = sorted((set(flat_true_no_prefix) | set(adjusted_pred_no_prefix)) - {"O"})

    report = sk_classification_report(
        flat_true_no_prefix,
        adjusted_pred_no_prefix,
        labels=labels_without_O,
        output_dict=True,
        zero_division=0
    )

    loose_precision = report["weighted avg"]["precision"]
    loose_recall = report["weighted avg"]["recall"]
    loose_f1 = report["weighted avg"]["f1-score"]

    return {
        "precision": loose_precision,
        "recall": loose_recall,
        "f1": loose_f1,

        "strict_precision": strict_precision,
        "strict_recall": strict_recall,
        "strict_f1": strict_f1
    }

trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=encoded_train,
    eval_dataset=encoded_val,
    tokenizer=tokenizer,
    compute_metrics=compute_metrics,
    callbacks=[EarlyStoppingCallback(early_stopping_patience=2)]
)

trainer.train()

  trainer = Trainer(


Epoch,Training Loss,Validation Loss,Precision,Recall,F1,Strict Precision,Strict Recall,Strict F1
1,1.0431,0.279032,0.642942,0.650748,0.645569,0.419732,0.45471,0.436522
2,0.2292,0.230018,0.666953,0.846044,0.739315,0.498695,0.692029,0.579666
3,0.1539,0.190838,0.811361,0.73913,0.769667,0.579119,0.643116,0.609442
4,0.1007,0.205735,0.766812,0.796151,0.779728,0.601246,0.699275,0.646566
5,0.0711,0.240439,0.768283,0.802566,0.783661,0.601246,0.699275,0.646566
6,0.0521,0.232355,0.797705,0.791162,0.791657,0.639033,0.67029,0.654288
7,0.0369,0.292072,0.806452,0.779758,0.79138,0.649396,0.681159,0.664898
8,0.0293,0.29612,0.763525,0.818247,0.788764,0.612245,0.706522,0.656013
9,0.0188,0.349296,0.839112,0.729152,0.778469,0.648598,0.628623,0.638454


  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


TrainOutput(global_step=513, training_loss=0.19278987597303782, metrics={'train_runtime': 362.6578, 'train_samples_per_second': 18.613, 'train_steps_per_second': 2.358, 'total_flos': 529188121728000.0, 'train_loss': 0.19278987597303782, 'epoch': 9.0})

In [47]:
dataset_test = encoded_test
results = trainer.predict(dataset_test)

logits = results.predictions
label_ids = results.label_ids

pred_indices = np.argmax(logits, axis=-1)

true_labels = []
pred_labels = []

for i in range(len(label_ids)):
    true_seq = label_ids[i]
    pred_seq = pred_indices[i]

    filtered_true = [label_list[l] for l in true_seq if l != -100]
    filtered_pred = [label_list[p] for p, l in zip(pred_seq, true_seq) if l != -100]

    true_labels.append(filtered_true)
    pred_labels.append(filtered_pred)

In [48]:
# Generate classification report
from seqeval.metrics import classification_report as seqeval_classification_report
report = seqeval_classification_report(true_labels, pred_labels, output_dict=True)

from sklearn.metrics import accuracy_score
from itertools import chain

flat_true = list(chain.from_iterable(true_labels))
flat_pred = list(chain.from_iterable(pred_labels))

accuracy = accuracy_score(flat_true, flat_pred)
print(f"Accuracy: {accuracy:.4f}")

print(f"Precision: {report['weighted avg']['precision']:.4f}")
print(f"Recall: {report['weighted avg']['recall']:.4f}")
print(f"F1-Score: {report['weighted avg']['f1-score']:.4f}")

print(seqeval_classification_report(true_labels, pred_labels))

Accuracy: 0.9365
Precision: 0.6660
Recall: 0.6934
F1-Score: 0.6783
              precision    recall  f1-score   support

   AGE_DEATH       0.00      0.00      0.00         3
AGE_FOLLOWUP       0.56      0.45      0.50        11
   AGE_ONSET       0.17      0.36      0.24        11
        GENE       0.89      0.85      0.87        40
GENE_VARIANT       0.83      0.83      0.83        78
    HPO_TERM       0.62      0.67      0.64       392
     PATIENT       0.76      0.72      0.74        39

   micro avg       0.65      0.69      0.67       574
   macro avg       0.55      0.56      0.55       574
weighted avg       0.67      0.69      0.68       574



In [49]:
from sklearn.metrics import classification_report as sk_classification_report
from itertools import chain

# Loose matching
def is_loose_match(true_tag, pred_tag):
    if true_tag != "O" and pred_tag != "O":
        true_entity = true_tag.split("-")[-1]
        pred_entity = pred_tag.split("-")[-1]
        return true_entity == pred_entity
    return False

def strip_prefix(tag):
    return tag.split("-")[-1] if tag != "O" else "O"

flat_true = list(chain.from_iterable(true_labels))
flat_pred = list(chain.from_iterable(pred_labels))

adjusted_pred = []
for t, p in zip(flat_true, flat_pred):
    if is_loose_match(t, p):
        adjusted_pred.append(t)
    else:
        adjusted_pred.append(p)

flat_true_no_prefix = [strip_prefix(t) for t in flat_true]
adjusted_pred_no_prefix = [strip_prefix(p) for p in adjusted_pred]

labels_without_O = sorted((set(flat_true_no_prefix) | set(adjusted_pred_no_prefix)) - {"O"})

accuracy = accuracy_score(flat_true_no_prefix, adjusted_pred_no_prefix)
print(f"Accuracy: {accuracy:.4f}")

report = sk_classification_report(
    flat_true_no_prefix,
    adjusted_pred_no_prefix,
    labels=labels_without_O,
    output_dict=True,
    zero_division=0
)

print(f"Precision: {report['weighted avg']['precision']:.4f}")
print(f"Recall: {report['weighted avg']['recall']:.4f}")
print(f"F1-Score: {report['weighted avg']['f1-score']:.4f}")

print(sk_classification_report(
    flat_true_no_prefix,
    adjusted_pred_no_prefix,
    labels=labels_without_O,
    zero_division=0
))


Accuracy: 0.9409
Precision: 0.7849
Recall: 0.7631
F1-Score: 0.7716
              precision    recall  f1-score   support

   AGE_DEATH       0.67      0.44      0.53         9
AGE_FOLLOWUP       0.56      0.18      0.27        28
   AGE_ONSET       0.27      0.35      0.30        23
        GENE       0.89      0.85      0.87        40
GENE_VARIANT       0.94      0.85      0.89       180
    HPO_TERM       0.77      0.78      0.77      1161
     PATIENT       0.75      0.67      0.71        49

   micro avg       0.78      0.76      0.77      1490
   macro avg       0.69      0.59      0.62      1490
weighted avg       0.78      0.76      0.77      1490

