In [1]:
!pip install seqeval

Collecting seqeval
  Downloading seqeval-1.2.2.tar.gz (43 kB)
[?25l     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m0.0/43.6 kB[0m [31m?[0m eta [36m-:--:--[0m[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m43.6/43.6 kB[0m [31m3.8 MB/s[0m eta [36m0:00:00[0m
[?25h  Preparing metadata (setup.py) ... [?25l[?25hdone
Building wheels for collected packages: seqeval
  Building wheel for seqeval (setup.py) ... [?25l[?25hdone
  Created wheel for seqeval: filename=seqeval-1.2.2-py3-none-any.whl size=16162 sha256=e8b4e92d3e3cf38b64a8e8b3f13799e0fa5788438487515991718135c342f9bd
  Stored in directory: /root/.cache/pip/wheels/5f/b8/73/0b2c1a76b701a677653dd79ece07cfabd7457989dbfbdcd8d7
Successfully built seqeval
Installing collected packages: seqeval
Successfully installed seqeval-1.2.2


In [2]:
!pip install evaluate

Collecting evaluate
  Downloading evaluate-0.4.5-py3-none-any.whl.metadata (9.5 kB)
Downloading evaluate-0.4.5-py3-none-any.whl (84 kB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m84.1/84.1 kB[0m [31m7.7 MB/s[0m eta [36m0:00:00[0m
[?25hInstalling collected packages: evaluate
Successfully installed evaluate-0.4.5


In [3]:
# Mount Google Drive into the Colab filesystem (this API only exists on Colab).
from google.colab import drive
drive.mount('/content/drive')
# Project root on Drive. Later cells build dataset and model-output paths by
# plain string concatenation with this value, so the trailing slash matters.
base_path = '/content/drive/MyDrive/NLP/'

Mounted at /content/drive


In [4]:
import os
# Set before any transformers import in this notebook.
# NOTE(review): presumably switches off the Weights & Biases integration —
# confirm against the installed transformers version.
os.environ["WANDB_DISABLED"] = "true"

In [5]:
import json
from collections import defaultdict
import numpy as np
from transformers import AutoTokenizer, AutoModelForTokenClassification
from transformers import TrainingArguments, Trainer
from transformers import EarlyStoppingCallback
from sklearn.metrics import classification_report as sk_classification_report
from sklearn.metrics import accuracy_score
from datasets import Dataset
import torch
import random
from collections import Counter
from transformers import set_seed

SEED = 42


def _seed_everything(seed):
    """Seed the Python, NumPy and PyTorch RNGs and set the cuDNN determinism flags."""
    for seeder in (random.seed, np.random.seed, torch.manual_seed):
        seeder(seed)

    # Only touch the CUDA generators when a GPU is actually present.
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)

    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

    # transformers' own seeding helper, called last as in the original cell.
    set_seed(seed)


_seed_everything(SEED)

def convert_to_iob(entries):
    """Turn span-annotated entries into per-token IOB-tagged samples.

    Args:
        entries: Iterable of dicts. Each must have a ``"tokens"`` list whose
            items are dicts with a ``"text"`` key, and may have a ``"spans"``
            list whose items carry ``"token_start"``, ``"token_end"``
            (inclusive) and ``"label"``.

    Returns:
        A ``(samples, label_map)`` tuple. ``samples`` holds one list of
        ``(token_text, iob_tag)`` pairs per entry. ``label_map`` is a
        ``defaultdict(list)`` mapping each entity label to every span dict
        seen with that label, in first-seen order.

    Raises:
        IndexError: If a span's token indices lie beyond the token list.
    """
    samples = []
    label_map = defaultdict(list)

    for entry in entries:
        tokens = entry["tokens"]
        # The entry-level "text" field is not needed for tagging, so it is not
        # read here; requiring it made entries without it fail with KeyError.
        labels = ["O"] * len(tokens)

        for span in entry.get("spans", []):
            start_token = span["token_start"]
            end_token = span["token_end"]
            label = span["label"]

            # First token opens the entity, the rest (end inclusive) continue it.
            # Spans are applied in order, so a later span overwrites an earlier
            # one on any token they share.
            labels[start_token] = f"B-{label}"
            for i in range(start_token + 1, end_token + 1):
                labels[i] = f"I-{label}"

            label_map[label].append(span)

        samples.append(
            [(token["text"], tag) for token, tag in zip(tokens, labels)]
        )

    return samples, label_map

def load_jsonl(path):
    """Read a JSON Lines file into a list of parsed objects.

    Blank lines (for example a trailing empty line at the end of the file) are
    skipped instead of being handed to ``json.loads``, which would raise.
    """
    with open(path, "r", encoding="utf-8") as f:
        return [json.loads(line) for line in f if line.strip()]


# Silver and golden annotation sets, one JSON object per line.
entries_silver = load_jsonl(base_path + "dataset/synthesis/silver.jsonl")
entries_golden = load_jsonl(base_path + "dataset/cleaned/NER/processed_merged.jsonl")

silver_data, label_stats_silver = convert_to_iob(entries_silver)
golden_data, label_stats_golden = convert_to_iob(entries_golden)

# Print how many annotated spans each entity type has in the golden set
print(f"{len(label_stats_golden)} types of entities are discovered:")
for label, spans in label_stats_golden.items():
    print(f"- {label}: {len(spans)} samples")

7 types of entities are discovered:
- AGE_ONSET: 93 samples
- PATIENT: 246 samples
- HPO_TERM: 2525 samples
- GENE: 252 samples
- GENE_VARIANT: 404 samples
- AGE_FOLLOWUP: 76 samples
- AGE_DEATH: 29 samples


In [6]:
# Tag inventory: "O" first, then a B-/I- pair for every entity type, in the
# order the types appear as keys of the golden label statistics.
label_list = ["O"]
for entity_type in label_stats_golden.keys():
    label_list.extend(f"{prefix}-{entity_type}" for prefix in ("B", "I"))

# Other checkpoints left commented out by the author:
# model_name = "bert-base-cased"
# model_name = "allenai/scibert_scivocab_uncased"
# model_name = "dmis-lab/biobert-base-cased-v1.1"
# model_name = "prajjwal1/bert-tiny"
model_name = "microsoft/BiomedNLP-PubMedBERT-base-uncased-abstract"

# Index <-> tag lookups handed to the model config.
id2label = dict(enumerate(label_list))
label2id = {tag: idx for idx, tag in enumerate(label_list)}

tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForTokenClassification.from_pretrained(
    model_name,
    num_labels=len(label_list),
    id2label=id2label,
    label2id=label2id,
)


The secret `HF_TOKEN` does not exist in your Colab secrets.
To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.
You will be able to reuse this secret in all of your notebooks.
Please note that authentication is recommended but still optional to access public models or datasets.


tokenizer_config.json:   0%|          | 0.00/28.0 [00:00<?, ?B/s]

config.json:   0%|          | 0.00/385 [00:00<?, ?B/s]

vocab.txt: 0.00B [00:00, ?B/s]

pytorch_model.bin:   0%|          | 0.00/440M [00:00<?, ?B/s]

Some weights of BertForTokenClassification were not initialized from the model checkpoint at microsoft/BiomedNLP-PubMedBERT-base-uncased-abstract and are newly initialized: ['classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [7]:
def encode_data(examples):
    """Tokenize one pre-split example and align its word labels to sub-tokens.

    Args:
        examples: Dict with ``"tokens"`` (list of word strings) and
            ``"labels"`` (one label id per word).

    Returns:
        The tokenizer output (truncated / padded to 256 positions) with an
        added ``"labels"`` entry: the first sub-token of each word carries the
        word's label id, every other position carries ``-100``, the value the
        metric code in this notebook filters out.
    """
    encoding = tokenizer(
        examples["tokens"],
        is_split_into_words=True,
        truncation=True,
        max_length=256,
        padding="max_length",
    )

    word_labels = examples["labels"]
    aligned = []
    last_word = None

    for current_word in encoding.word_ids():
        # A position is labelled only when it starts a new word; positions
        # with no word id and continuation sub-tokens are masked.
        starts_new_word = current_word is not None and current_word != last_word
        aligned.append(word_labels[current_word] if starts_new_word else -100)
        last_word = current_word

    encoding["labels"] = aligned
    return encoding

def to_model_format(samples, label_to_id):
    """Split ``(token, tag)`` samples into parallel token / label-id lists.

    Args:
        samples: List of samples, each a list of ``(token_text, iob_tag)`` pairs.
        label_to_id: Mapping from IOB tag string to integer label id.

    Returns:
        One ``{"tokens": [...], "labels": [...]}`` dict per sample.

    Raises:
        ValueError: If a sample uses a tag that is missing from ``label_to_id``.
    """
    formatted = []
    for sample in samples:
        try:
            label_ids = [label_to_id[tag] for _, tag in sample]
        except KeyError as exc:
            # Same exception type as the previous `list.index` lookup, but
            # naming the offending tag.
            raise ValueError(f"{exc.args[0]!r} is not in label_list") from exc
        formatted.append({"tokens": [token for token, _ in sample], "labels": label_ids})
    return formatted


# Built once so each tag lookup is O(1) instead of a linear scan of label_list.
label_to_id = {tag: idx for idx, tag in enumerate(label_list)}

formatted_data_silver = to_model_format(silver_data, label_to_id)
formatted_data_golden = to_model_format(golden_data, label_to_id)

from sklearn.model_selection import train_test_split
# 70% of the golden data goes to training; the remaining 30% is halved into
# validation and test. Silver data is added to the training split only.
temp_set1, temp_set2 = train_test_split(formatted_data_golden, test_size=0.3, random_state=SEED)
train_set = temp_set1 + formatted_data_silver
val_set, test_set = train_test_split(temp_set2, test_size=0.5, random_state=SEED)

encoded_train = [encode_data(d) for d in train_set]
encoded_val = [encode_data(d) for d in val_set]
encoded_test = [encode_data(d) for d in test_set]

In [8]:
from itertools import chain

from evaluate import load

# Hyper-parameters and bookkeeping for the fine-tuning run.
training_args = TrainingArguments(
    # Output location and reproducibility
    output_dir=base_path + "./model/NER",
    seed=SEED,
    report_to="none",
    # Optimisation schedule
    num_train_epochs=15,
    learning_rate=2e-5,
    weight_decay=0.03,
    warmup_ratio=0.15,
    per_device_train_batch_size=8,
    per_device_eval_batch_size=8,
    # Evaluate, log and checkpoint once per epoch
    eval_strategy="epoch",
    logging_strategy="epoch",
    save_strategy="epoch",
    save_total_limit=2,
    # Model selection is keyed on the "strict_f1" value that compute_metrics returns
    load_best_model_at_end=True,
    metric_for_best_model="strict_f1",
)

# Scorer obtained through the `evaluate` loader; used for the strict metrics.
seqeval = load("seqeval")

def is_loose_match(true_tag, pred_tag):
    """Return True when both tags are entity tags of the same entity type.

    The ``B-``/``I-`` prefix is ignored. Only the first hyphen separates the
    prefix from the type, so hyphenated type names (e.g. ``B-X-LINKED``) are
    compared on the whole type rather than on their last fragment.

    Args:
        true_tag: Reference IOB tag, e.g. ``"B-GENE"`` or ``"O"``.
        pred_tag: Predicted IOB tag.

    Returns:
        ``False`` if either tag is ``"O"``; otherwise whether the types match.
    """
    if true_tag == "O" or pred_tag == "O":
        return False
    return true_tag.split("-", 1)[-1] == pred_tag.split("-", 1)[-1]

def strip_prefix(tag):
    """Drop the ``B-``/``I-`` prefix from an IOB tag; ``"O"`` is returned as is.

    Splits on the first hyphen only, so entity types that themselves contain a
    hyphen (e.g. ``I-X-LINKED`` -> ``X-LINKED``) are kept whole.
    """
    if tag == "O":
        return "O"
    return tag.split("-", 1)[-1]

def compute_metrics(p):
    """Compute strict (seqeval) and loose (token-level, prefix-agnostic) scores.

    Args:
        p: Pair ``(predictions, labels)``. ``predictions`` is arg-maxed over
            axis 2 to get label ids; positions whose label is ``-100`` are
            dropped from both sequences.

    Returns:
        Dict with the loose scores under ``"precision"``, ``"recall"`` and
        ``"f1"`` (weighted average over entity types, ``"O"`` excluded) and
        the seqeval overall scores under ``"strict_precision"``,
        ``"strict_recall"`` and ``"strict_f1"``.
    """
    logits, gold_ids = p
    predicted_ids = np.argmax(logits, axis=2)

    # Decode ids to tag strings, keeping only positions with a real label.
    true_labels = []
    pred_labels = []
    for pred_row, gold_row in zip(predicted_ids, gold_ids):
        true_labels.append([label_list[g] for g in gold_row if g != -100])
        pred_labels.append(
            [label_list[q] for q, g in zip(pred_row, gold_row) if g != -100]
        )

    # Strict matching
    strict = seqeval.compute(predictions=pred_labels, references=true_labels)

    # Loose matching: a prediction with the right entity type counts as
    # correct regardless of its B-/I- prefix.
    flat_true = list(chain.from_iterable(true_labels))
    flat_pred = list(chain.from_iterable(pred_labels))

    gold_types = []
    pred_types = []
    for gold_tag, pred_tag in zip(flat_true, flat_pred):
        effective_tag = gold_tag if is_loose_match(gold_tag, pred_tag) else pred_tag
        gold_types.append(strip_prefix(gold_tag))
        pred_types.append(strip_prefix(effective_tag))

    entity_types = sorted((set(gold_types) | set(pred_types)) - {"O"})

    report = sk_classification_report(
        gold_types,
        pred_types,
        labels=entity_types,
        output_dict=True,
        zero_division=0,
    )
    loose = report["weighted avg"]

    return {
        "precision": loose["precision"],
        "recall": loose["recall"],
        "f1": loose["f1-score"],

        "strict_precision": strict["overall_precision"],
        "strict_recall": strict["overall_recall"],
        "strict_f1": strict["overall_f1"],
    }

# Fine-tune on the training split, evaluating on the validation split, with
# early stopping (patience 2) attached as a callback.
trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=encoded_train,
    eval_dataset=encoded_val,
    # `tokenizer=` is deprecated on Trainer in recent transformers releases
    # (the previous run emitted a warning pointing at this call);
    # `processing_class=` is the replacement argument.
    processing_class=tokenizer,
    compute_metrics=compute_metrics,
    callbacks=[EarlyStoppingCallback(early_stopping_patience=2)],
)

trainer.train()

Downloading builder script: 0.00B [00:00, ?B/s]

  trainer = Trainer(


Epoch,Training Loss,Validation Loss,Precision,Recall,F1,Strict Precision,Strict Recall,Strict F1
1,1.5579,0.349107,0.603009,0.643621,0.621882,0.392537,0.476449,0.430442
2,0.2359,0.25403,0.663134,0.721311,0.689274,0.484115,0.57971,0.527617
3,0.1384,0.210803,0.78043,0.676408,0.718859,0.556436,0.509058,0.531693
4,0.0876,0.211813,0.76487,0.770492,0.76569,0.590832,0.630435,0.609991
5,0.0601,0.249211,0.801633,0.707056,0.748891,0.625954,0.594203,0.609665
6,0.0434,0.299097,0.704897,0.856023,0.767746,0.579937,0.67029,0.621849
7,0.0328,0.274417,0.770941,0.806842,0.785616,0.604502,0.681159,0.640545
8,0.0239,0.297382,0.786479,0.779045,0.78107,0.607973,0.663043,0.634315
9,0.0195,0.324728,0.777334,0.796151,0.784053,0.627551,0.668478,0.647368
10,0.0134,0.33132,0.789228,0.779758,0.783017,0.62735,0.664855,0.645558


  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


TrainOutput(global_step=2548, training_loss=0.1604805708960798, metrics={'train_runtime': 1073.7729, 'train_samples_per_second': 20.256, 'train_steps_per_second': 2.542, 'total_flos': 2652473795328000.0, 'train_loss': 0.1604805708960798, 'epoch': 14.0})

In [9]:
# Run the trained model on the held-out test split.
dataset_test = encoded_test
results = trainer.predict(dataset_test)

logits, label_ids = results.predictions, results.label_ids
pred_indices = np.argmax(logits, axis=-1)

# Decode ids to tag strings, dropping every position whose reference label is
# -100 (the masking value written by encode_data).
true_labels = [
    [label_list[gold] for gold in gold_row if gold != -100]
    for gold_row in label_ids
]
pred_labels = [
    [label_list[guess] for guess, gold in zip(pred_row, gold_row) if gold != -100]
    for pred_row, gold_row in zip(pred_indices, label_ids)
]

In [10]:
from itertools import chain

from seqeval.metrics import classification_report as seqeval_classification_report
from sklearn.metrics import accuracy_score

# Strict evaluation on the test split: seqeval report over the tag sequences,
# plus plain tag-level accuracy over the flattened sequences.
report = seqeval_classification_report(true_labels, pred_labels, output_dict=True)

flat_true = list(chain.from_iterable(true_labels))
flat_pred = list(chain.from_iterable(pred_labels))

accuracy = accuracy_score(flat_true, flat_pred)
print(f"Accuracy: {accuracy:.4f}")

weighted = report['weighted avg']
print(f"Precision: {weighted['precision']:.4f}")
print(f"Recall: {weighted['recall']:.4f}")
print(f"F1-Score: {weighted['f1-score']:.4f}")

# Full per-entity table
print(seqeval_classification_report(true_labels, pred_labels))

Accuracy: 0.9360
Precision: 0.6434
Recall: 0.7300
F1-Score: 0.6830
              precision    recall  f1-score   support

   AGE_DEATH       0.25      0.33      0.29         3
AGE_FOLLOWUP       0.60      0.55      0.57        11
   AGE_ONSET       0.29      0.45      0.36        11
        GENE       0.90      0.88      0.89        40
GENE_VARIANT       0.78      0.95      0.86        78
    HPO_TERM       0.60      0.69      0.64       392
     PATIENT       0.72      0.74      0.73        39

   micro avg       0.64      0.73      0.68       574
   macro avg       0.59      0.66      0.62       574
weighted avg       0.64      0.73      0.68       574



In [11]:
from sklearn.metrics import classification_report as sk_classification_report
from itertools import chain

# Loose matching
def is_loose_match(true_tag, pred_tag):
    """Return True when both tags are entity tags of the same entity type.

    The ``B-``/``I-`` prefix is ignored. Only the first hyphen separates the
    prefix from the type, so hyphenated type names (e.g. ``B-X-LINKED``) are
    compared on the whole type rather than on their last fragment.

    Args:
        true_tag: Reference IOB tag, e.g. ``"B-GENE"`` or ``"O"``.
        pred_tag: Predicted IOB tag.

    Returns:
        ``False`` if either tag is ``"O"``; otherwise whether the types match.
    """
    if true_tag == "O" or pred_tag == "O":
        return False
    true_entity = true_tag.split("-", 1)[-1]
    pred_entity = pred_tag.split("-", 1)[-1]
    return true_entity == pred_entity

def strip_prefix(tag):
    """Drop the ``B-``/``I-`` prefix from an IOB tag; ``"O"`` is returned as is.

    Splits on the first hyphen only, so entity types that themselves contain a
    hyphen (e.g. ``I-X-LINKED`` -> ``X-LINKED``) are kept whole.
    """
    if tag == "O":
        return "O"
    return tag.split("-", 1)[-1]

# Loose evaluation on the test split: token-level scores where a prediction
# with the right entity type counts as correct whatever its B-/I- prefix.
flat_true = list(chain.from_iterable(true_labels))
flat_pred = list(chain.from_iterable(pred_labels))

adjusted_pred = [
    gold if is_loose_match(gold, guess) else guess
    for gold, guess in zip(flat_true, flat_pred)
]

flat_true_no_prefix = list(map(strip_prefix, flat_true))
adjusted_pred_no_prefix = list(map(strip_prefix, adjusted_pred))

# Entity types seen on either side; "O" is left out of the averaged scores.
labels_without_O = sorted(
    set(flat_true_no_prefix).union(adjusted_pred_no_prefix).difference({"O"})
)

accuracy = accuracy_score(flat_true_no_prefix, adjusted_pred_no_prefix)
print(f"Accuracy: {accuracy:.4f}")

report_options = dict(labels=labels_without_O, zero_division=0)

report = sk_classification_report(
    flat_true_no_prefix,
    adjusted_pred_no_prefix,
    output_dict=True,
    **report_options,
)
weighted = report['weighted avg']

print(f"Precision: {weighted['precision']:.4f}")
print(f"Recall: {weighted['recall']:.4f}")
print(f"F1-Score: {weighted['f1-score']:.4f}")

# Full per-entity table
print(sk_classification_report(
    flat_true_no_prefix,
    adjusted_pred_no_prefix,
    **report_options,
))


Accuracy: 0.9406
Precision: 0.7589
Recall: 0.7993
F1-Score: 0.7764
              precision    recall  f1-score   support

   AGE_DEATH       0.78      0.78      0.78         9
AGE_FOLLOWUP       0.60      0.21      0.32        28
   AGE_ONSET       0.50      0.61      0.55        23
        GENE       0.85      0.88      0.86        40
GENE_VARIANT       0.82      0.94      0.88       180
    HPO_TERM       0.76      0.80      0.78      1161
     PATIENT       0.72      0.69      0.71        49

   micro avg       0.76      0.80      0.78      1490
   macro avg       0.72      0.70      0.70      1490
weighted avg       0.76      0.80      0.78      1490

