In [None]:
using_wandb = False
using_colab = True

In [None]:
pmid_post_processing = False
data_augmentation = False
section_headers = False

In [None]:
if using_wandb:
  %pip install wandb -q
  import wandb
  wandb.login()

In [None]:
if using_colab:
    from google.colab import drive

    drive.mount("/content/drive", force_remount=True)

Mounted at /content/drive


In [None]:
model_checkpoint = "distilbert-base-uncased"
# model_checkpoint = 'michiyasunaga/BioLinkBERT-large'
# model_checkpoint = 'microsoft/BiomedNLP-PubMedBERT-base-uncased-abstract'
model_name = "custom_model_3"

In [None]:
%pip install transformers -q
%pip install accelerate -U -q

[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m7.6/7.6 MB[0m [31m19.7 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m294.8/294.8 kB[0m [31m33.3 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m7.8/7.8 MB[0m [31m56.4 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m1.3/1.3 MB[0m [31m58.0 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m258.1/258.1 kB[0m [31m4.2 MB/s[0m eta [36m0:00:00[0m
[?25h

In [None]:
import torch
import torch.nn as nn
from torch.nn import BCEWithLogitsLoss
from torch.optim import AdamW
from torch.utils.data import Dataset, DataLoader
import numpy as np
from sklearn.metrics import f1_score
from transformers import AutoTokenizer, get_scheduler
import importlib
import json
import random
import os
from importlib import reload
from tqdm.auto import tqdm

In [None]:
# Set device
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(device)

cuda


In [None]:
if using_colab:
    # dir_path = (
    #     "drive/Othercomputers/my_computer/dl-nlp_project_named-entity-recognition/"
    # )
    dir_path = "drive/MyDrive/dl-nlp_project_named-entity-recognition/"
    module_path = dir_path.replace("/", ".")
    # imports
    data_module = importlib.import_module(module_path + "data")
    load_data = data_module.load_data
    extract_sentences_and_labels = data_module.extract_sentences_and_labels
    generate_label_vocab = data_module.generate_label_vocab
    split_data = data_module.split_data

else:
    dir_path = "./"
    from data import (
        load_data,
        extract_sentences_and_labels,
        generate_label_vocab,
        split_data,
    )

In [None]:
train_file_path = dir_path + "data/train.json"
test_file_path = dir_path + "data/test.json"

In [None]:
train_data, test_data = load_data(train_file_path, test_file_path)
train_sentences, train_raw_labels = extract_sentences_and_labels(train_data)
test_sentences, test_raw_labels = extract_sentences_and_labels(test_data)

# Generate label vocabulary
label_vocab = generate_label_vocab(train_raw_labels + test_raw_labels)

In [None]:
SPECIAL_TOKEN = "<SPC>"


class Labels:
    def __init__(self, num_classes, names):
        super().__init__()
        self.names = names
        print(self.names)
        self.num_classes = num_classes

    def __getitem__(self, label_vector):
        return [self.names[idx] for idx, value in enumerate(label_vector) if value == 1]

    def decode(self, label_vector):
        return self.__getitem__(label_vector)

    def encode(self, names):
        indexes = []
        for name in names:
            index = self.names.index(name)
            indexes.append(index)
        tensor = torch.zeros(self.num_classes)
        for index in indexes:
            tensor[index] = 1
        return tensor

    def tensor2sentence(self, tensor):
        return [self.decode(vector) for vector in tensor]


ner_labels = Labels(
    num_classes=len(label_vocab) + 1, names=label_vocab + [SPECIAL_TOKEN]
)
id2label = ner_labels.decode
label2id = ner_labels.encode
ner_labels.num_classes

['NumberAffected', 'PercentageAffected', 'NumberPatientsArm', 'ObservedResult', 'CTDesign', 'FinalNumPatientsArm', 'DoseDescription', 'ConfIntervalDiff', 'TimePoint', 'Journal', 'ConfIntervalChangeValue', 'Country', 'Precondition', 'Title', 'PMID', 'SdDevChangeValue', 'ConclusionComment', 'NumberPatientsCT', 'ResultMeasuredValue', 'Drug', 'RelativeChangeValue', 'SdDevBL', 'AllocationRatio', 'DiffGroupAbsValue', 'MinAge', 'AvgAge', 'PublicationYear', 'DoseValue', 'AggregationMethod', 'ObjectiveDescription', 'Author', 'PvalueDiff', 'Frequency', 'PValueChangeValue', 'SubGroupDescription', 'SdDevResValue', '<SPC>']


37

In [None]:
class NERDataset(Dataset):
    def __init__(self, abstracts):
        self.abstracts = abstracts
        self.count = len(abstracts)

    def __getitem__(self, idx):
        return self.abstracts[idx]

    def __len__(self):
        return self.count

    def tokenize(self):
        for abstract in self.abstracts:
            tokenized_inputs = tokenize_and_align_labels(
                [sentence["tokens"] for sentence in abstract["sentences"]],
                [sentence["labels_list"] for sentence in abstract["sentences"]],
            )
            for key in tokenized_inputs.keys():
                abstract[key] = torch.Tensor(tokenized_inputs[key]).to(device)

In [None]:
def extract_abstracts(json_file_path):
    with open(json_file_path, "r") as file:
        data = json.load(file)

    abstracts = []

    for entry in data:
        abstract = {"id": entry["abstract_id"], "sentences": []}
        for sentence in entry["sentences"]:
            tokens = sentence.pop("words")

            entities = sentence.pop("entities")
            labels_list = [torch.zeros(ner_labels.num_classes) for x in tokens]
            for label_entity in entities:
                start_pos = label_entity["start_pos"]
                end_pos = label_entity["end_pos"]
                label = label_entity["label"]
                label_id = label2id([label]).argmax().item()
                for label_index in range(start_pos, end_pos + 1):
                    labels_list[label_index][label_id] = 1
            sentence["tokens"] = tokens
            sentence["labels_list"] = labels_list
            sentence["id"] = sentence.pop("sentence_id")
            abstract["sentences"].append(sentence)
        abstract["length"] = len(abstract["sentences"])
        abstract["sentence_length"] = np.sum(
            [len(sentence["tokens"]) for sentence in abstract["sentences"]]
        )
        abstracts.append(abstract)

    return abstracts

In [None]:
def split_abstracts(abstracts, val_split=0.2):
    random.shuffle(abstracts)
    num_val = int(len(abstracts) * val_split)
    val_abstracts = abstracts[:num_val]
    train_abstracts = abstracts[num_val:]
    return train_abstracts, val_abstracts

In [None]:
abstracts = extract_abstracts(train_file_path)
train_abstracts, val_abstracts = split_abstracts(abstracts)
test_abstracts = extract_abstracts(test_file_path)

print(
    f"Train abstracts: {len(train_abstracts)}, sentences: {sum([abstract['length'] for abstract in train_abstracts])}"
)
print(
    f"Val abstracts: {len(val_abstracts)}, sentences: {sum([abstract['length'] for abstract in val_abstracts])}"
)
print(
    f"Test abstracts: {len(test_abstracts)}, sentences: {sum([abstract['length'] for abstract in test_abstracts])}"
)

Train abstracts: 55, sentences: 1149
Val abstracts: 13, sentences: 296
Test abstracts: 20, sentences: 385


In [None]:
datasets = {
    "train": NERDataset(train_abstracts),
    "val": NERDataset(val_abstracts),
    "test": NERDataset(test_abstracts),
}

print(len(datasets["train"]))
print(len(datasets["val"]))
print(len(datasets["test"]))

55
13
20


In [None]:
tokenizer = AutoTokenizer.from_pretrained(model_checkpoint)

In [None]:
def tokenize_and_align_labels(sentences, original_labels_lists):
    # print(f"Sentences: {sentences}")
    # print(f"OG Labels_lists: {original_labels_lists}")
    tokenized_inputs = tokenizer(
        sentences,
        is_split_into_words=True,
        padding=True
        # max_length=max_length,
        # padding="max_length",
    )
    # print(f"Tokenized_inputs: {tokenized_inputs}")
    # print(f"Length: {len(tokenized_inputs['input_ids'])}")
    # print(f"max_length: {max_length}")

    label_list = []
    for i, labels in enumerate(original_labels_lists):
        word_ids = tokenized_inputs.word_ids(
            batch_index=i
        )  # Map tokens to their respective word.
        previous_word_idx = None
        label_ids = []
        for word_idx in word_ids:
            if word_idx is None:
                label_ids.append(label2id([SPECIAL_TOKEN]))
            elif word_idx != previous_word_idx:
                label_ids.append(torch.Tensor(labels[word_idx]))
            else:
                label_ids.append(label2id([SPECIAL_TOKEN]))
            previous_word_idx = word_idx

        padded_length = len(tokenized_inputs["input_ids"][i])
        for i in range(padded_length - len(label_ids)):
            label_ids.append(label2id([SPECIAL_TOKEN]))
        label_ids = torch.stack(label_ids)
        # print(f"Label_ids_shape: {label_ids.shape}")
        label_list.append(label_ids)

    # print(f"Label_list: {label_list}")
    tokenized_inputs["labels"] = torch.stack(label_list)
    return tokenized_inputs

In [None]:
datasets["train"].tokenize()
datasets["val"].tokenize()
datasets["test"].tokenize()

In [None]:
from transformers import BertForTokenClassification, DistilBertForTokenClassification


# class CustomTokenClassification(BertForTokenClassification):
class CustomTokenClassification(DistilBertForTokenClassification):
    def __init__(self, config):
        super(CustomTokenClassification, self).__init__(config)
        self.loss_fct = BCEWithLogitsLoss()

    def forward(self, input_ids=None, attention_mask=None, labels=None, **kwargs):
        outputs = super().forward(
            input_ids=input_ids.int(), attention_mask=attention_mask.int(), **kwargs
        )

        # Check if the outputs is a dictionary and extract logits
        if isinstance(outputs, dict):
            logits = outputs.get("logits", None)
        else:  # Assume the outputs is a tuple and logits are the first element
            logits = outputs[0]

        return logits

In [None]:
def mask_and_flatten_logits_and_labels(logits, labels):
    mask = labels[:, :, -1] != 1
    logits = logits[mask]
    labels = labels[mask]

    flat_logits = logits.view(-1, logits.shape[-1])
    flat_labels = labels.view(-1, labels.shape[-1])
    return flat_logits, flat_labels

In [None]:
def flatten_list(list_of_lists):
    flattened_list = []
    for sublist in list_of_lists:
        for item in sublist:
            flattened_list.append(item)
    return flattened_list

In [1]:
def custom_collate(batch):
    if len(batch) != 1:
        return None
    abstract = batch[0]
    if "id" in abstract.keys():
        abstract.pop("id")
    if "sentences" in abstract.keys():
        abstract.pop("sentences")
    if "length" in abstract.keys():
        abstract.pop("length")
    if "sentence_length" in abstract.keys():
        abstract.pop("sentence_length")
    data = []
    input_ids = []
    attention_mask = []
    labels = []
    for x, y, z in zip(
        abstract["input_ids"], abstract["attention_mask"], abstract["labels"]
    ):
        input_ids_len = np.sum([len(x) for x in input_ids])
        if input_ids_len > tokenizer.max_model_input_sizes[model_checkpoint] - 100:
            print(f"input cut off at {input_ids_len}")
            data.append(
                {
                    "input_ids": torch.cat(input_ids),
                    "attention_mask": torch.cat(attention_mask),
                    "labels": torch.cat(labels),
                }
            )
            input_ids = []
            attention_mask = []
            labels = []
        input_ids_list = x.tolist()
        if 0.0 in input_ids_list:
            padding_index = input_ids_list.index(0.0)
            input_ids.append(x[:padding_index])
            attention_mask.append(y[:padding_index])
            labels.append(z[:padding_index])
        else:
            input_ids.append(x[:])
            attention_mask.append(y[:])
            labels.append(z[:])
    data.append(
        {
            "input_ids": torch.cat(input_ids),
            "attention_mask": torch.cat(attention_mask),
            "labels": torch.cat(labels),
        }
    )
    for data_object in data:
        print(data_object["input_ids"].shape)
        padding_length = (
            tokenizer.max_model_input_sizes[model_checkpoint]
            - data_object["input_ids"].shape[0]
        )
        torch.cat([data_object["input_ids"], torch.zeros(padding_length)])
        torch.cat([data_object["attention_mask"], torch.zeros(padding_length)])
        print(data_object["labels"].shape)
        torch.cat(
            [data_object["labels"], label2id([SPECIAL_TOKEN] * padding_length)]
        )  # this is wrong
        print(data_object["input_ids"].shape)
    return data

In [None]:
model = CustomTokenClassification.from_pretrained(
    model_checkpoint, num_labels=ner_labels.num_classes
)
model.to(device)

Some weights of CustomTokenClassification were not initialized from the model checkpoint at distilbert-base-uncased and are newly initialized: ['classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [None]:
config = {}
config["num_epochs"] = 25
config["batch_size"] = 1
config["lr"] = 1e-4
config["num_warmup_steps"] = 0
config["model_checkpoint"] = model_checkpoint
config["section_headers"] = section_headers
config["data_augmentation"] = data_augmentation
# config["augmentations"] = augmentations
config["pmid_post_processing"] = pmid_post_processing

In [None]:
if using_wandb:
    wandb.init(project="DL-NLP-Clinical-Trial-NER", config=config)
    config = wandb.config

In [None]:
train_dataloader = DataLoader(
    datasets["train"],
    shuffle=True,
    batch_size=config["batch_size"],
    collate_fn=lambda x: custom_collate(x),
)
val_dataloader = DataLoader(
    datasets["val"],
    shuffle=True,
    batch_size=config["batch_size"],
    collate_fn=lambda x: custom_collate(x),
)

optimizer = AdamW(model.parameters(), lr=config["lr"])

num_training_steps = config["num_epochs"] * len(train_dataloader)

lr_scheduler = get_scheduler(
    name="linear",
    optimizer=optimizer,
    num_warmup_steps=config["num_warmup_steps"],
    num_training_steps=num_training_steps,
)

loss_fct = BCEWithLogitsLoss()

progress_bar = tqdm(range(num_training_steps))

for epoch in range(config["num_epochs"]):
    model.train()
    epoch_loss = 0.0

    for batch in train_dataloader:
        labels_list = []
        logits_list = []
        for data_object in batch:
            labels_list.append(data_object["labels"])

            logits = model(**data_object)
            logits_list.append(logits)
        foo()

        flat_logits, flat_labels = mask_and_flatten_logits_and_labels(logits, labels)

        loss = loss_fct(flat_logits, flat_labels)
        loss.backward()

        optimizer.step()
        lr_scheduler.step()
        optimizer.zero_grad()

        epoch_loss += loss.item() * labels.size(0)

        progress_bar.update(1)

    epoch_loss = epoch_loss / len(train_dataloader)
    progress_bar.write(f"Epoch {epoch}, Loss: {epoch_loss:.3f}")

    model.eval()

    preds = []
    true_labels = []

    for batch in val_dataloader:
        labels = batch["labels"]

        with torch.no_grad():
            logits = model(**batch)

        flat_logits, flat_labels = mask_and_flatten_logits_and_labels(logits, labels)

        pred = flat_logits.heaviside(torch.tensor([0.0], device=device)).int().tolist()
        true_label = flat_labels.int().tolist()

        preds.extend(pred)
        true_labels.extend(true_label)

    f1 = f1_score(true_labels, preds, average="micro")
    progress_bar.write(f"f1 micro: {f1:.3f}")
    if using_wandb:
        wandb.log({"train_loss": epoch_loss, "micro_f1": f1, "epoch": epoch})
progress_bar.close()
model.save_pretrained(model_name)

  0%|          | 0/1375 [00:00<?, ?it/s]

437
470
482
torch.Size([437])
torch.Size([470])
torch.Size([482])
torch.Size([246])


RuntimeError: ignored

In [None]:
label_list = ner_labels.names

test_dataloader = DataLoader(
    datasets["test"],
    shuffle=False,
    batch_size=1,
    collate_fn=lambda x: custom_collate(x),
)
model.eval()

preds = []
true_labels = []

dataloader = test_dataloader
progress_bar = tqdm(range(len(dataloader)))
for batch in dataloader:
    labels = batch["labels"]

    with torch.no_grad():
        logits = model(**batch)

    flat_logits, flat_labels = mask_and_flatten_logits_and_labels(logits, labels)

    pred = flat_logits.heaviside(torch.tensor([0.0], device=device)).int().tolist()
    true_label = flat_labels.int().tolist()

    preds.extend(pred)
    true_labels.extend(true_label)

    progress_bar.update(1)

f1 = f1_score(true_labels, preds, average="micro")
f1_per_class = f1_score(true_labels, preds, average=None, zero_division=1)
for label, score in zip(label_list, f1_per_class):
    print(f"{label}: {score:.4f}")
progress_bar.write(f"f1 micro: {f1}")
if using_wandb:
    wandb.log({"test_micro_f1": f1})
    per_class_table = wandb.Table(columns=label_list, data=[f1_per_class])
    wandb.log({"F1_per_label": per_class_table})
    wandb.finish()
progress_bar.close()

  0%|          | 0/20 [00:00<?, ?it/s]

NumberAffected: 0.0000
PercentageAffected: 0.9465
NumberPatientsArm: 0.8000
ObservedResult: 0.1679
CTDesign: 0.0000
FinalNumPatientsArm: 1.0000
DoseDescription: 0.1282
ConfIntervalDiff: 0.7466
TimePoint: 0.5424
Journal: 0.9764
ConfIntervalChangeValue: 0.0000
Country: 0.0000
Precondition: 0.3851
Title: 0.9587
PMID: 1.0000
SdDevChangeValue: 0.0000
ConclusionComment: 0.9092
NumberPatientsCT: 0.7586
ResultMeasuredValue: 0.5882
Drug: 1.0000
RelativeChangeValue: 1.0000
SdDevBL: 0.0000
AllocationRatio: 0.9730
DiffGroupAbsValue: 0.6220
MinAge: 0.0000
AvgAge: 0.0000
PublicationYear: 0.8696
DoseValue: 0.6667
AggregationMethod: 1.0000
ObjectiveDescription: 0.8955
Author: 0.9429
PvalueDiff: 0.8859
Frequency: 0.5634
PValueChangeValue: 0.2718
SubGroupDescription: 0.0000
SdDevResValue: 0.3111
<SPC>: 1.0000
f1 micro: 0.7782091725309391


In [None]:
from random import randint

example_count = 30
index = randint(0, len(datasets["test"]) - example_count - 1)
print(f"Index: {index}")
examples = datasets["test"][index : index + example_count]
labels = examples["labels"]
with torch.no_grad():
    logits = model(**examples)

flat_logits, flat_labels = mask_and_flatten_logits_and_labels(logits, labels)


preds = flat_logits.heaviside(torch.tensor([0.0], device=device)).int().tolist()
true_labels = flat_labels.int().tolist()
print(len(true_label))
print(len(pred))
f1 = f1_score(true_labels, preds, average="micro")

tokens_lists = test_original_tokens[index : index + example_count]
tokens = flatten_list(tokens_lists)

for true_label, pred, token in zip(true_labels, preds, tokens):
    print(f"true:{id2label(true_label)}, pred:{id2label(pred)}, token: {token}")
print(f"Micro-F1 Score: {f1:.3f}")

Index: 49
33
33
true:[], pred:[], token: A
true:[], pred:[], token: significantly
true:[], pred:[], token: greater
true:[], pred:[], token: proportion
true:[], pred:[], token: of
true:[], pred:[], token: patients
true:[], pred:[], token: achieved
true:[], pred:[], token: an
true:[], pred:[], token: A1C
true:[], pred:[], token: <
true:[], pred:[], token: 7
true:[], pred:[], token: %
true:[], pred:[], token: with
true:[], pred:[], token: sitagliptin
true:[], pred:[], token: (
true:['PercentageAffected'], pred:['PercentageAffected'], token: 47
true:['PercentageAffected'], pred:['PercentageAffected'], token: .
true:['PercentageAffected'], pred:['PercentageAffected'], token: 0
true:[], pred:[], token: %
true:[], pred:[], token: )
true:[], pred:[], token: than
true:[], pred:[], token: with
true:[], pred:[], token: placebo
true:[], pred:[], token: (
true:['PercentageAffected'], pred:['PercentageAffected'], token: 18
true:['PercentageAffected'], pred:['PercentageAffected'], token: .
true:['Per