In [None]:
# Runtime switches: Google Colab / Drive integration and Weights & Biases tracking.
using_colab = True
using_wandb = False

In [None]:
# Optional experiment tracking: install, import and log in to Weights & Biases.
# NOTE(review): the wandb version is not pinned, so re-runs may install a different release.
if using_wandb:
  %pip install wandb -q
  import wandb
  wandb.login()

In [None]:
# Colab only: (re)mount Google Drive so the project folder on Drive is readable
# through the relative "drive/..." paths used further down.
if using_colab:
    from google.colab import drive
    drive.mount("/content/drive", force_remount=True)

In [None]:
# Directory name passed to model.save_pretrained() after training.
model_name = "custom_model_3"

# Pretrained encoder to fine-tune. Other checkpoints that were tried:
#   "michiyasunaga/BioLinkBERT-large"
#   "microsoft/BiomedNLP-PubMedBERT-base-uncased-abstract"
# (a BERT-family checkpoint also needs a matching base class for CustomTokenClassification)
model_checkpoint = "distilbert-base-uncased"

In [None]:
# Install / upgrade the Hugging Face libraries used below (quiet mode).
# NOTE(review): versions are unpinned; pin them (pkg==X.Y.Z) for reproducible runs.
%pip install transformers -q
%pip install accelerate -U -q

In [None]:
import torch
import torch.nn as nn
from torch.nn import BCEWithLogitsLoss
from torch.optim import AdamW
from torch.utils.data import Dataset, DataLoader
import numpy as np
from sklearn.metrics import f1_score
from transformers import AutoTokenizer, AutoModelForTokenClassification, get_scheduler
import importlib
import json
import random
import os
from importlib import reload
from tqdm.auto import tqdm

In [None]:
# Use the GPU when PyTorch can see one; the model and dataset tensors are moved here later.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
print(device)

In [None]:
if not using_colab:
    # Local run: the project's `data` module sits next to this notebook.
    dir_path = "./"
    from data import (
        extract_sentences_and_labels,
        generate_label_vocab,
        load_data,
        split_data,
    )
else:
    # Colab run: the project lives on the mounted Drive, so its `data` module is
    # imported by dotted path relative to the working directory.
    dir_path = "drive/Othercomputers/my_computer/dl-nlp_project_named-entity-recognition/"
    # dir_path = "drive/MyDrive/dl-nlp_project_named-entity-recognition/"
    module_path = dir_path.replace("/", ".")
    data_module = importlib.import_module(f"{module_path}data")
    (
        load_data,
        extract_sentences_and_labels,
        generate_label_vocab,
        split_data,
    ) = (
        data_module.load_data,
        data_module.extract_sentences_and_labels,
        data_module.generate_label_vocab,
        data_module.split_data,
    )

In [None]:
# Corpus files, resolved relative to the project directory chosen above.
test_file_path = f"{dir_path}data/test.json"
train_file_path = f"{dir_path}data/train.json"

In [None]:
# Raw corpus load through the project's `data` helpers.
train_data, test_data = load_data(train_file_path, test_file_path)
(train_sentences, train_raw_labels), (test_sentences, test_raw_labels) = (
    extract_sentences_and_labels(train_data),
    extract_sentences_and_labels(test_data),
)

# The vocabulary is built from both splits so every label seen in test has an index.
label_vocab = generate_label_vocab(train_raw_labels + test_raw_labels)

In [None]:
# Pseudo-label used for positions that carry no real NER label
# (tokenizer special tokens, non-first sub-words, padding, section-header prefix).
SPECIAL_TOKEN = "<SPC>"


class Labels:
    """Bidirectional mapping between label names and multi-hot label vectors.

    Position ``i`` of a vector corresponds to ``names[i]``; a value of 1 marks
    the label as present.
    """

    def __init__(self, num_classes, names):
        super().__init__()
        self.names = names
        print(self.names)  # show the label order once, at construction time
        self.num_classes = num_classes

    def __getitem__(self, label_vector):
        """Return the names whose position in ``label_vector`` equals 1."""
        picked = []
        for position, flag in enumerate(label_vector):
            if flag == 1:
                picked.append(self.names[position])
        return picked

    def decode(self, label_vector):
        """Alias of indexing: multi-hot vector -> list of label names."""
        return self[label_vector]

    def encode(self, names):
        """Build a float multi-hot vector of length ``num_classes`` for ``names``.

        Raises:
            ValueError: if a name is not in the vocabulary (from ``list.index``).
        """
        positions = [self.names.index(label_name) for label_name in names]
        encoded = torch.zeros(self.num_classes)
        for position in positions:
            encoded[position] = 1
        return encoded

    def tensor2sentence(self, tensor):
        """Decode every row of ``tensor`` into its list of label names."""
        return list(map(self.decode, tensor))


# One output column per vocabulary label plus a trailing SPECIAL_TOKEN column;
# rows whose last column is 1 are later masked out of loss and metrics.
ner_labels = Labels(
    num_classes=len(label_vocab) + 1,
    names=label_vocab + [SPECIAL_TOKEN],
)
# Note: "id2label" maps a multi-hot vector to names, "label2id" maps names to a vector.
id2label, label2id = ner_labels.decode, ner_labels.encode
ner_labels.num_classes

In [None]:
class NERDataset(Dataset):
    """Sentence-level NER dataset holding word tokens and multi-hot word labels.

    Before ``tokenize()`` is called, items expose only ``id``, ``tokens`` and
    ``ner_labels``. Afterwards they also carry the sub-word ``input_ids``,
    ``attention_mask`` and word-aligned ``labels`` produced by
    ``tokenize_and_align_labels``, stored as tensors on ``device``.

    Indexing with a slice (e.g. ``ds[:]``) returns a dict of sliced columns;
    ``tokenize()`` relies on this to feed the whole dataset to the tokenizer.
    """

    def __init__(self, sentences, labels):
        self.sentences = sentences
        self.ner_labels = labels
        self.num_rows = len(sentences)
        self.input_ids = None
        self.attention_mask = None
        self.aligned_labels = None
        self.tokenized = False
        self._refresh_features()

    def _refresh_features(self):
        # Column overview; rebuilt after tokenize() so it never shows stale None tensors.
        self.features = {
            "id": range(self.num_rows),
            "tokens": self.sentences,
            "ner_labels": self.ner_labels,
            "input_ids": self.input_ids,
            "attention_mask": self.attention_mask,
            "labels": self.aligned_labels,
        }

    def __getitem__(self, idx):
        item = {
            "id": idx,
            "tokens": self.sentences[idx],
            "ner_labels": self.ner_labels[idx],
        }
        if self.tokenized:
            item["input_ids"] = self.input_ids[idx]
            item["attention_mask"] = self.attention_mask[idx]
            item["labels"] = self.aligned_labels[idx]
        return item

    def __len__(self):
        return self.num_rows

    def tokenize(self):
        """Tokenize all sentences and store the aligned tensors on ``device``.

        Token ids and the attention mask are integer indices, so they are kept
        as ``torch.long`` instead of float32. The multi-hot labels stay float32
        because BCEWithLogitsLoss expects float targets.
        """
        tokenized_inputs = tokenize_and_align_labels(self[:])
        self.input_ids = torch.as_tensor(
            tokenized_inputs["input_ids"], dtype=torch.long
        ).to(device)
        self.attention_mask = torch.as_tensor(
            tokenized_inputs["attention_mask"], dtype=torch.long
        ).to(device)
        self.aligned_labels = torch.as_tensor(
            tokenized_inputs["labels"], dtype=torch.float
        ).to(device)
        self.tokenized = True
        self._refresh_features()

In [None]:
def extract_sentences(json_file_path):
    """Load a JSON corpus and return word tokens with multi-hot labels per sentence.

    The file holds a list of documents, each with a ``sentences`` list whose
    items carry ``words`` and ``entities`` (``start_pos``/``end_pos`` are
    inclusive word indices, ``label`` is a vocabulary name).

    Returns:
        (tokens_per_sentence, labels_per_sentence). Each label entry is a float
        vector of length ``ner_labels.num_classes`` with a 1 for every entity
        covering that word.

    Overlapping entities are merged: each sets its own bit instead of the
    last one overwriting the others. This matches
    ``extract_sentences_with_section_headers``.
    """
    with open(json_file_path, "r", encoding="utf-8") as file:
        data = json.load(file)

    all_tokens = []
    all_labels = []
    for entry in data:
        for sentence in entry["sentences"]:
            tokens = sentence["words"]
            labels_list = [torch.zeros(ner_labels.num_classes) for _ in tokens]
            for label_entity in sentence["entities"]:
                label_id = label2id([label_entity["label"]]).argmax().item()
                for label_index in range(
                    label_entity["start_pos"], label_entity["end_pos"] + 1
                ):
                    labels_list[label_index][label_id] = 1
            all_tokens.append(tokens)
            all_labels.append(labels_list)

    return all_tokens, all_labels

In [None]:
def _leading_section_header(tokens):
    """Return the section header introduced by ``tokens``, or [] if there is none.

    Only sentences whose first token is uppercase are considered. The header is
    the run of uppercase tokens immediately before the first ':' token, for
    example ``INCLUSION CRITERIA :``. A non-uppercase token before the ':'
    discards the words collected so far, as before. A sentence that never
    reaches a ':' is not a header. Previously such sentences produced headers
    out of trailing acronyms, e.g. ``HIV and AIDS`` -> ``["AIDS"]``.
    """
    if not tokens or not tokens[0].isupper():
        return []
    header = []
    for word in tokens:
        if word == ":":
            return header
        if word.isupper():
            header.append(word)
        else:
            header = []
    return []


def extract_sentences_with_section_headers(json_file_path):
    """Load a JSON corpus, prefixing each sentence with its current section header.

    A detected header carries over to the following sentences of the same
    document until a new header appears. While a header is active, the
    sentence tokens become ``["[", *header, "]", *words]`` and the prefix
    positions get the SPECIAL_TOKEN label vector.

    Returns:
        (headed_tokens, labels, original_tokens): three parallel lists with one
        entry per sentence. ``labels`` holds one float multi-hot vector per
        headed token; overlapping entities set several bits on the same word.
    """
    with open(json_file_path, "r", encoding="utf-8") as file:
        data = json.load(file)

    headed_tokens_all = []
    labels_all = []
    tokens_all = []

    for entry in data:
        current_section_header = []  # reset at every new document
        for sentence in entry["sentences"]:
            tokens = sentence["words"]

            section_header = _leading_section_header(tokens)
            if section_header:
                current_section_header = section_header

            labels_list = [torch.zeros(ner_labels.num_classes) for _ in tokens]
            for label_entity in sentence["entities"]:
                label_id = label2id([label_entity["label"]]).argmax().item()
                for label_index in range(
                    label_entity["start_pos"], label_entity["end_pos"] + 1
                ):
                    # Set this entity's bit; other entities on the same word keep theirs.
                    labels_list[label_index][label_id] = 1

            headed_tokens = tokens
            if current_section_header:
                headed_tokens = ["["] + current_section_header + ["]"] + tokens
                # "[", header words and "]" are labelled SPECIAL_TOKEN so they are masked out.
                labels_list = [
                    label2id([SPECIAL_TOKEN])
                    for _ in range(len(current_section_header) + 2)
                ] + labels_list

            headed_tokens_all.append(headed_tokens)
            labels_all.append(labels_list)
            tokens_all.append(tokens)

    return headed_tokens_all, labels_all, tokens_all

In [None]:
section_headers = True

# Pick the sentence reader. With section headers, each sentence is prefixed by
# "[ HEADER ]" and the unprefixed words are returned as well.
if not section_headers:
    train_sentences, train_labels = extract_sentences(train_file_path)
    test_sentences, test_labels = extract_sentences(test_file_path)
    test_original_tokens = test_sentences
else:
    train_sentences, train_labels, train_original_tokens = (
        extract_sentences_with_section_headers(train_file_path)
    )
    test_sentences, test_labels, test_original_tokens = (
        extract_sentences_with_section_headers(test_file_path)
    )

# Carve the validation split out of the training data (same call in both modes).
# train_original_tokens will be misaligned because of the validation split
train_sentences, train_labels, val_sentences, val_labels = split_data(
    train_sentences, train_labels
)

print(len(train_sentences), len(train_labels))
print(len(val_sentences), len(val_labels))
print(len(test_sentences), len(test_labels))

In [None]:
data_augmentation = True
# (file stem, number of added sentences) per augmentation file; recorded in `config`.
augmentations = []

if data_augmentation:
    # NOTE(review): this cell extends train_sentences in place, so re-running it
    # adds the augmentation data a second time. The shuffle is unseeded.
    labels_data_folder = f"{dir_path}data/labels/"
    # Sorted and filtered, so the set and order of files no longer depends on
    # the filesystem listing. Stray non-JSON files (e.g. .DS_Store) are ignored
    # instead of crashing json.load.
    files = sorted(
        name for name in os.listdir(labels_data_folder) if name.endswith(".json")
    )
    new_sentences = []
    new_labels = []
    for file_name in files:
        print(file_name)
        with open(
            os.path.join(labels_data_folder, file_name), "r", encoding="utf-8"
        ) as file:
            data = json.load(file)
        new_sentences += data["sentences"]
        # Each word's list of label names becomes a multi-hot vector.
        new_labels += [
            [label2id(labels) for labels in labels_list]
            for labels_list in data["labels_lists"]
        ]
        augmentations.append((os.path.splitext(file_name)[0], len(data["sentences"])))

    # list(...) so concatenation works whether split_data returned lists or tuples.
    combined = list(
        zip(list(train_sentences) + new_sentences, list(train_labels) + new_labels)
    )
    random.shuffle(combined)
    train_sentences, train_labels = zip(*combined)
    print(len(train_sentences), len(train_labels))
    print(f"Added {len(new_sentences)} sentences and labels")

In [None]:
# One NERDataset per split, keyed "train" / "val" / "test".
datasets = {
    split_name: NERDataset(split_sentences, split_labels)
    for split_name, split_sentences, split_labels in (
        ("train", train_sentences, train_labels),
        ("val", val_sentences, val_labels),
        ("test", test_sentences, test_labels),
    )
}

print(*(len(datasets[name]) for name in ("train", "val", "test")), sep="\n")

In [None]:
# Sub-word tokenizer matching the pretrained checkpoint.
tokenizer = AutoTokenizer.from_pretrained(pretrained_model_name_or_path=model_checkpoint)

In [None]:
def tokenize_and_align_labels(examples):
    """Tokenize pre-split sentences and align word-level multi-hot labels to sub-words.

    Args:
        examples: mapping with ``"tokens"`` (list of word lists) and
            ``"ner_labels"`` (per sentence, one label vector per word), such as
            the dict returned by ``NERDataset[:]``.

    Returns:
        The tokenizer output (sequences padded to the longest one and
        truncated to the model maximum) with an added ``"labels"`` tensor
        holding one label vector per sub-word position. Only the first
        sub-word of each word keeps that word's labels. Tokenizer special
        tokens, later sub-words and padding get the SPECIAL_TOKEN vector, so
        they can be masked out of loss and metrics.
    """
    tokenized_inputs = tokenizer(
        examples["tokens"], truncation=True, is_split_into_words=True, padding=True
    )
    # Loop-invariant; torch.stack copies it into each output row, so sharing is safe.
    special_vector = label2id([SPECIAL_TOKEN])

    label_list = []
    for sentence_idx, labels in enumerate(examples["ner_labels"]):
        # Map each sub-word position to its source word index (None for special tokens).
        word_ids = tokenized_inputs.word_ids(batch_index=sentence_idx)
        previous_word_idx = None
        label_ids = []
        for word_idx in word_ids:
            if word_idx is not None and word_idx != previous_word_idx:
                label_ids.append(torch.Tensor(labels[word_idx]))
            else:
                label_ids.append(special_vector)
            previous_word_idx = word_idx

        # Defensive: pad labels up to the sequence length if word_ids came up short.
        padded_length = len(tokenized_inputs["input_ids"][sentence_idx])
        label_ids.extend([special_vector] * (padded_length - len(label_ids)))
        label_list.append(torch.stack(label_ids))

    tokenized_inputs["labels"] = torch.stack(label_list)
    return tokenized_inputs

In [None]:
# Tokenize every split once up front (train, then val, then test).
for _split in ("train", "val", "test"):
    datasets[_split].tokenize()

In [None]:
from transformers import BertForTokenClassification, DistilBertForTokenClassification


# Switch the base class to BertForTokenClassification for BERT-family checkpoints.
# class CustomTokenClassification(BertForTokenClassification):
class CustomTokenClassification(DistilBertForTokenClassification):
    """Token classifier that returns raw per-token logits for multi-label NER.

    The loss is computed outside the model (BCEWithLogitsLoss in the training
    loop). ``forward`` therefore accepts, and ignores, the extra dataset
    columns ``id``, ``labels``, ``tokens`` and ``ner_labels`` that arrive via
    ``model(**batch)``.
    """

    def __init__(self, config):
        super().__init__(config)
        # Kept for callers using model.loss_fct; the training loop builds its own instance.
        self.loss_fct = BCEWithLogitsLoss()

    def forward(
        self,
        input_ids=None,
        attention_mask=None,
        id=None,
        labels=None,
        tokens=None,
        ner_labels=None,
        **kwargs
    ):
        """Run the encoder and return the token-classification logits.

        ``input_ids`` / ``attention_mask`` are cast to int32 when given. Either
        may be omitted (e.g. when ``inputs_embeds`` is passed via ``kwargs``);
        previously that crashed on ``None.int()``.
        """
        if input_ids is not None:
            input_ids = input_ids.int()
        if attention_mask is not None:
            attention_mask = attention_mask.int()
        outputs = super().forward(
            input_ids=input_ids, attention_mask=attention_mask, **kwargs
        )

        # ModelOutput (return_dict=True) is a dict subclass; otherwise a tuple with logits first.
        if isinstance(outputs, dict):
            return outputs.get("logits", None)
        return outputs[0]

In [None]:
def mask_and_flatten_logits_and_labels(logits, labels):
    """Drop SPECIAL_TOKEN positions and return 2-D (positions, classes) tensors.

    ``labels`` is indexed as (batch, seq_len, classes) with the SPECIAL_TOKEN
    column last. Every position whose last label column equals 1 is removed
    from both ``logits`` and ``labels``. The survivors are returned in their
    original order.
    """
    keep = labels[:, :, -1].ne(1)
    kept_logits = logits[keep]
    kept_labels = labels[keep]
    return (
        kept_logits.view(-1, kept_logits.shape[-1]),
        kept_labels.view(-1, kept_labels.shape[-1]),
    )

In [None]:
def flatten_list(list_of_lists):
    """Concatenate the items of each inner iterable into one list (one level deep)."""
    return [item for sublist in list_of_lists for item in sublist]

In [None]:
def custom_collate(batch):
    """Stack ``input_ids``, ``attention_mask`` and ``labels`` of dataset items into batch tensors.

    Any other keys of the items (e.g. ``id``, ``tokens``, ``ner_labels``) are dropped.
    """
    return {
        key: torch.stack([item[key] for item in batch])
        for key in ("input_ids", "attention_mask", "labels")
    }

In [None]:
# Classification head sized to every vocabulary label plus the SPECIAL_TOKEN column.
model = CustomTokenClassification.from_pretrained(
    pretrained_model_name_or_path=model_checkpoint,
    num_labels=ner_labels.num_classes,
)
model.to(device)

In [None]:
# Hyper-parameters and data options for this run (also sent to W&B when enabled).
config = {
    "num_epochs": 25,
    "batch_size": 16,
    "lr": 2e-5,
    "num_warmup_steps": 0,
    "model_checkpoint": model_checkpoint,
    "section_headers": section_headers,
    "data_augmentation": data_augmentation,
    "augmentations": augmentations,
}

In [None]:
if using_wandb:
    # Start the W&B run, then read hyper-parameters back through wandb.config.
    wandb.init(config=config, project="DL-NLP-Clinical-Trial-NER")
    config = wandb.config

In [None]:
# Shuffle training batches each epoch. Validation order does not affect F1, so keep it fixed.
train_dataloader = DataLoader(
    datasets["train"],
    shuffle=True,
    batch_size=config["batch_size"],
    collate_fn=custom_collate,
)
val_dataloader = DataLoader(
    datasets["val"],
    shuffle=False,
    batch_size=config["batch_size"],
    collate_fn=custom_collate,
)

optimizer = AdamW(model.parameters(), lr=config["lr"])

num_training_steps = config["num_epochs"] * len(train_dataloader)

# "linear": warm up for num_warmup_steps, then decay the learning rate linearly to 0.
lr_scheduler = get_scheduler(
    name="linear",
    optimizer=optimizer,
    num_warmup_steps=config["num_warmup_steps"],
    num_training_steps=num_training_steps,
)

# Multi-label objective: independent sigmoid + binary cross-entropy per label column.
loss_fct = BCEWithLogitsLoss()

progress_bar = tqdm(range(num_training_steps))

for epoch in range(config["num_epochs"]):
    model.train()
    epoch_loss = 0.0

    for batch in train_dataloader:
        labels = batch.pop("labels")

        logits = model(**batch)

        # Only real word positions (rows without the SPECIAL_TOKEN bit) enter the loss.
        flat_logits, flat_labels = mask_and_flatten_logits_and_labels(logits, labels)

        loss = loss_fct(flat_logits, flat_labels)
        loss.backward()

        optimizer.step()
        lr_scheduler.step()
        optimizer.zero_grad()

        # Weight each batch's mean loss by its number of sentences ...
        epoch_loss += loss.item() * labels.size(0)

        progress_bar.update(1)

    # ... and divide by the number of sentences. Dividing by the number of
    # batches would overstate the loss about batch_size-fold.
    epoch_loss = epoch_loss / len(train_dataloader.dataset)
    progress_bar.write(f"Epoch {epoch}, Loss: {epoch_loss:.3f}")

    model.eval()

    preds = []
    true_labels = []

    for batch in val_dataloader:
        labels = batch.pop("labels")

        with torch.no_grad():
            logits = model(**batch)

        flat_logits, flat_labels = mask_and_flatten_logits_and_labels(logits, labels)

        # A label is predicted when its logit is > 0 (sigmoid probability > 0.5).
        pred = (flat_logits > 0).int().tolist()
        true_label = flat_labels.int().tolist()

        preds.extend(pred)
        true_labels.extend(true_label)

    f1 = f1_score(true_labels, preds, average="micro")
    progress_bar.write(f"f1 micro: {f1:.3f}")
    if using_wandb:
        wandb.log({"train_loss": epoch_loss, "micro_f1": f1, "epoch": epoch})
progress_bar.close()
model.save_pretrained(model_name)

In [None]:
label_list = ner_labels.names

# All test items were padded to one common length when the split was tokenized,
# so batching should not change the predictions (up to floating-point noise).
# The training batch size is therefore used instead of 1, which was much slower.
test_dataloader = DataLoader(
    datasets["test"],
    shuffle=False,
    batch_size=config["batch_size"],
    collate_fn=custom_collate,
)
model.eval()

preds = []
true_labels = []

dataloader = test_dataloader
progress_bar = tqdm(range(len(dataloader)))
for batch in dataloader:
    labels = batch.pop("labels")

    with torch.no_grad():
        logits = model(**batch)

    flat_logits, flat_labels = mask_and_flatten_logits_and_labels(logits, labels)

    # A label is predicted when its logit is > 0 (sigmoid probability > 0.5).
    pred = (flat_logits > 0).int().tolist()
    true_label = flat_labels.int().tolist()

    preds.extend(pred)
    true_labels.extend(true_label)

    progress_bar.update(1)

f1 = f1_score(true_labels, preds, average="micro")
# zero_division=1: a label whose score is undefined (no true and no predicted positives) reports 1.0.
f1_per_class = f1_score(true_labels, preds, average=None, zero_division=1)
for label, score in zip(label_list, f1_per_class):
    print(f"{label}: {score:.4f}")
progress_bar.write(f"f1 micro: {f1}")
if using_wandb:
    wandb.log({"test_micro_f1": f1})
    per_class_table = wandb.Table(columns=label_list, data=[f1_per_class])
    wandb.log({"F1_per_label": per_class_table})
    wandb.finish()
progress_bar.close()

In [None]:
# Show word-level predictions for a random window of consecutive test sentences.
example_count = 30
# randint is inclusive. The largest valid start keeps index + example_count <= len,
# and max(0, ...) keeps this working when the test set is smaller than the window.
max_start = max(0, len(datasets["test"]) - example_count)
index = random.randint(0, max_start)
print(f"Index: {index}")
examples = datasets["test"][index : index + example_count]
labels = examples["labels"]
model.eval()
with torch.no_grad():
    logits = model(**examples)

# After masking, one row per original word should remain (first sub-word only,
# section-header prefix removed), in the same order as test_original_tokens.
# This assumes no sentence was truncated by the tokenizer.
flat_logits, flat_labels = mask_and_flatten_logits_and_labels(logits, labels)

preds = (flat_logits > 0).int().tolist()
true_labels = flat_labels.int().tolist()
# Report the sizes of *these* predictions, not the stale loop variables
# left over from the test-evaluation cell.
print(len(true_labels))
print(len(preds))
f1 = f1_score(true_labels, preds, average="micro")

tokens_lists = test_original_tokens[index : index + example_count]
tokens = flatten_list(tokens_lists)

for true_label, pred, token in zip(true_labels, preds, tokens):
    print(f"true:{id2label(true_label)}, pred:{id2label(pred)}, token: {token}")
print(f"Micro-F1 Score: {f1:.3f}")