In [1]:
%pip install transformers -q
%pip install accelerate -U -q
%pip install wandb -q

[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m7.6/7.6 MB[0m [31m26.3 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m268.8/268.8 kB[0m [31m18.2 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m7.8/7.8 MB[0m [31m48.9 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m1.3/1.3 MB[0m [31m55.7 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m251.2/251.2 kB[0m [31m4.3 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m2.1/2.1 MB[0m [31m21.2 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m188.8/188.8 kB[0m [31m21.6 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m218.8/218.8 kB[0m [31m24.0 MB/s[0m eta [36m0:00:00[0m
[?25h  Preparing metadata (setup.py) ..

In [2]:
import torch
import torch.nn as nn
from torch.nn import BCEWithLogitsLoss
from torch.optim import AdamW
from torch.utils.data import Dataset, DataLoader
import numpy as np
from sklearn.metrics import f1_score

from transformers import AutoTokenizer, BertForTokenClassification, get_scheduler
import importlib
from importlib import reload
from tqdm.auto import tqdm

In [3]:
model_checkpoint = 'michiyasunaga/BioLinkBERT-large'
model_name = 'custom_model'

In [4]:
# Set device
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(device)

cuda


In [5]:
using_wandb = False

In [6]:
import wandb
wandb.login()
using_wandb = True

<IPython.core.display.Javascript object>

[34m[1mwandb[0m: Logging into wandb.ai. (Learn how to deploy a W&B server locally: https://wandb.me/wandb-server)
[34m[1mwandb[0m: You can find your API key in your browser here: https://wandb.ai/authorize
wandb: Paste an API key from your profile and hit enter, or press ctrl+c to quit:

 ··········


[34m[1mwandb[0m: Appending key for api.wandb.ai to your netrc file: /root/.netrc


In [7]:
from google.colab import drive
drive.mount('/content/drive', force_remount=True)

Mounted at /content/drive


In [8]:
# If you are using Colab
dir_path = "/content/drive/Othercomputers/my_computer/dl-nlp_project_named-entity-recognition/"
module_path = dir_path[9:].replace("/", ".")
# imports
data_module = importlib.import_module(module_path + "data")
load_data = data_module.load_data
extract_sentences_and_labels = data_module.extract_sentences_and_labels
generate_label_vocab = data_module.generate_label_vocab
encode_labels = data_module.encode_labels
build_label_to_idx = data_module.build_label_to_idx
build_idx_to_label = data_module.build_idx_to_label
build_word_to_idx = data_module.build_word_to_idx
build_idx_to_word = data_module.build_idx_to_word
split_data = data_module.split_data

In [9]:
reload(data_module)

<module 'drive.Othercomputers.my_computer.dl-nlp_project_named-entity-recognition.data' from '/content/drive/Othercomputers/my_computer/dl-nlp_project_named-entity-recognition/data.py'>

In [10]:
# If you are NOT using colab
# dir_path = ""
# from data_new import (
#     prepare_data_pipeline,
#     TRAIN_DATA_PATH,
#     TEST_DATA_PATH,
#     PAD,
#     tensor_to_sentences,
#     tensor_to_labels,
# )

In [11]:
train_file_path = dir_path + "data/train.json"
test_file_path = dir_path + "data/test.json"

In [12]:
train_data, test_data = load_data(train_file_path, test_file_path)
train_sentences, train_raw_labels = extract_sentences_and_labels(train_data)
test_sentences, test_raw_labels = extract_sentences_and_labels(test_data)

# Generate label vocabulary
label_vocab = generate_label_vocab(train_raw_labels + test_raw_labels)

# Encode labels pre-transformer
train_encoded_labels = encode_labels(train_raw_labels, label_vocab, train_sentences)
test_labels = encode_labels(test_raw_labels, label_vocab, test_sentences)

word_to_idx = build_word_to_idx(train_sentences + test_sentences)
idx_to_word = build_idx_to_word(word_to_idx)
label_to_idx = build_label_to_idx(label_vocab)
idx_to_label = build_idx_to_label(label_to_idx)

train_sentences, train_labels, val_sentences, val_labels = split_data(
    train_sentences, train_encoded_labels
)

In [13]:
SPECIAL_TOKEN = "<SPC>"

class Labels():
    def __init__(self, num_classes, names):
        super().__init__()
        names.append(SPECIAL_TOKEN)
        self.names = names
        self.num_classes = num_classes + 1

    def __getitem__(self, label_vector):
        return [
            self.names[idx]
            for idx, value in enumerate(label_vector)
            if value == 1
        ]

    def num_classes(self):
        return self.num_classes

    def decode(self, label_vector):
        return self.__getitem__(label_vector)

    def encode(self, names):
        indexes = []
        for name in names:
            index = self.names.index(name)
            indexes.append(index)
        tensor = torch.zeros(self.num_classes)
        for index in indexes:
            tensor[index] = 1
        return tensor

    def tensor2sentence(self, tensor):
        return [self.decode(vector) for vector in tensor]

ner_labels = Labels(num_classes=len(label_vocab), names=label_vocab)

In [14]:
id2label = ner_labels.decode
label2id = ner_labels.encode

In [15]:
class NERDataset(Dataset):
    def __init__(self, sentences, labels):
        self.sentences = sentences
        self.ner_labels = labels
        self.num_rows = len(sentences)
        self.input_ids = None
        self.attention_mask = None
        self.aligned_labels = None
        self.features = {'id': range(self.num_rows),
                         'tokens': self.sentences,
                         'ner_labels': self.ner_labels,
                         'input_ids': self.input_ids,
                         'attention_mask': self.attention_mask,
                         'labels': self.aligned_labels}
        self.tokenized = False

    def __getitem__(self, idx):
        if self.tokenized:
            item = {
                'id': idx,
                'tokens': self.sentences[idx],
                'ner_labels': self.ner_labels[idx],
                'input_ids': self.input_ids[idx],
                'attention_mask': self.attention_mask[idx],
                'labels': self.aligned_labels[idx]
            }
        else:
            item = {
                'id': idx,
                'tokens': self.sentences[idx],
                'ner_labels': self.ner_labels[idx],
            }
        return item

    def __len__(self):
        return self.num_rows

    def tokenize(self):
        tokenized_inputs = tokenize_and_align_labels(self[:])
        self.input_ids = torch.Tensor(tokenized_inputs['input_ids']).to(device)
        self.attention_mask = torch.Tensor(tokenized_inputs['attention_mask']).to(device)
        self.aligned_labels = torch.Tensor(tokenized_inputs['labels']).to(device)
        self.tokenized = True


In [16]:
datasets = {
    'train': NERDataset(train_sentences, train_labels),
    'val': NERDataset(val_sentences, val_labels),
    'test': NERDataset(test_sentences, test_labels)
}

In [17]:
ner_labels.names

['NumberAffected',
 'NumberPatientsArm',
 'PMID',
 'MinAge',
 'Title',
 'Author',
 'Precondition',
 'Frequency',
 'SubGroupDescription',
 'TimePoint',
 'Journal',
 'PValueChangeValue',
 'SdDevBL',
 'ConclusionComment',
 'DoseValue',
 'CTDesign',
 'DiffGroupAbsValue',
 'ResultMeasuredValue',
 'PublicationYear',
 'AggregationMethod',
 'Drug',
 'FinalNumPatientsArm',
 'AllocationRatio',
 'PvalueDiff',
 'ConfIntervalDiff',
 'Country',
 'SdDevResValue',
 'ObservedResult',
 'ConfIntervalChangeValue',
 'PercentageAffected',
 'ObjectiveDescription',
 'DoseDescription',
 'RelativeChangeValue',
 'SdDevChangeValue',
 'AvgAge',
 'NumberPatientsCT',
 '<SPC>']

In [18]:
tokenizer = AutoTokenizer.from_pretrained(model_checkpoint)

Downloading (…)okenizer_config.json:   0%|          | 0.00/379 [00:00<?, ?B/s]

Downloading (…)solve/main/vocab.txt:   0%|          | 0.00/225k [00:00<?, ?B/s]

Downloading (…)/main/tokenizer.json:   0%|          | 0.00/447k [00:00<?, ?B/s]

Downloading (…)cial_tokens_map.json:   0%|          | 0.00/112 [00:00<?, ?B/s]

In [19]:
def tokenize_and_align_labels(examples):
    tokenized_inputs = tokenizer(examples["tokens"], truncation=True, is_split_into_words=True, padding=True)

    label_list = []
    for i, labels in enumerate(examples['ner_labels']):
        word_ids = tokenized_inputs.word_ids(batch_index=i)  # Map tokens to their respective word.
        previous_word_idx = None
        label_ids = []
        for word_idx in word_ids:
            if word_idx is None:
                label_ids.append(label2id([SPECIAL_TOKEN]))
            elif word_idx != previous_word_idx:
                label_ids.append(torch.Tensor(labels[word_idx] + [0]))
            previous_word_idx = word_idx

        padded_length = len(tokenized_inputs['input_ids'][i])
        for i in range(padded_length - len(label_ids)):
            label_ids.append(label2id([SPECIAL_TOKEN]))
        label_ids = torch.stack(label_ids)
        label_list.append(label_ids)

    tokenized_inputs["labels"] = torch.stack(label_list)
    return tokenized_inputs

In [20]:
datasets['train'].tokenize()
datasets['val'].tokenize()
datasets['test'].tokenize()

Asking to truncate to max_length but no maximum length is provided and the model has no predefined maximum length. Default to no truncation.


In [21]:
from transformers import DistilBertForTokenClassification
class CustomDistilBertForTokenClassification(DistilBertForTokenClassification):
    def __init__(self, config):
        super(CustomDistilBertForTokenClassification, self).__init__(config)
        self.loss_fct = BCEWithLogitsLoss()

    def forward(self, input_ids=None, attention_mask=None, labels=None, id=None, tokens=None, ner_labels=None, **kwargs):
        outputs = super().forward(input_ids=input_ids.int(), attention_mask=attention_mask.int(), **kwargs)
        return outputs['logits']


In [22]:
from transformers import BertForTokenClassification
class CustomBertForTokenClassification(BertForTokenClassification):
    def __init__(self, config):
        super(CustomBertForTokenClassification, self).__init__(config)
        self.loss_fct = BCEWithLogitsLoss()

    def forward(self, input_ids=None, attention_mask=None, labels=None, id=None, tokens=None, ner_labels=None, **kwargs):
        outputs = super().forward(input_ids=input_ids.int(), attention_mask=attention_mask.int(), **kwargs)
        return outputs['logits']

In [23]:
def mask_and_flatten_logits_and_labels(logits, labels):
    mask = labels[:, :, -1] != 1
    logits = logits[mask]
    labels = labels[mask]

    flat_logits = logits.view(-1, logits.shape[-1])
    flat_labels = labels.view(-1, labels.shape[-1])
    return flat_logits, flat_labels

In [24]:
# model = CustomTokenClassification.from_pretrained(model_checkpoint, num_labels=ner_labels.num_classes)
model = CustomBertForTokenClassification.from_pretrained(model_checkpoint, num_labels=ner_labels.num_classes)
model.to(device);

Downloading (…)lve/main/config.json:   0%|          | 0.00/560 [00:00<?, ?B/s]

Downloading pytorch_model.bin:   0%|          | 0.00/1.33G [00:00<?, ?B/s]

Some weights of CustomBertForTokenClassification were not initialized from the model checkpoint at michiyasunaga/BioLinkBERT-large and are newly initialized: ['classifier.weight', 'classifier.bias']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [25]:
config = {}
config["num_epochs"] = 15
config["batch_size"] = 1 # only 1 works
config["lr"] = 2e-5
config["num_warmup_steps"] = 0
config["model_checkpoint"] = model_checkpoint

In [26]:
if using_wandb:
  wandb.init(project="DL-NLP-Clinical-Trial-NER", config=config)
  config = wandb.config

[34m[1mwandb[0m: Currently logged in as: [33mreylord[0m. Use [1m`wandb login --relogin`[0m to force relogin


In [27]:
label_list = ner_labels.names

train_dataloader = DataLoader(
    datasets['train'], shuffle=True, batch_size=config['batch_size'],
)
val_dataloader = DataLoader(
    datasets['val'], shuffle=True, batch_size=config['batch_size'],
)

optimizer = AdamW(model.parameters(), lr=config['lr'])

num_training_steps = config["num_epochs"] * len(train_dataloader)

lr_scheduler = get_scheduler(
    name="linear", optimizer=optimizer, num_warmup_steps=config["num_warmup_steps"], num_training_steps=num_training_steps
)

loss_fct = BCEWithLogitsLoss()

progress_bar = tqdm(range(num_training_steps))

for epoch in range(config["num_epochs"]):
    model.train()
    epoch_loss = 0

    for batch in train_dataloader:
        labels = batch.pop('labels')

        logits = model(**batch)

        flat_logits, flat_labels = mask_and_flatten_logits_and_labels(logits, labels)

        loss = loss_fct(flat_logits, flat_labels)
        loss.backward()

        optimizer.step()
        lr_scheduler.step()
        optimizer.zero_grad()

        epoch_loss += loss.item() * labels.size(0)

        progress_bar.update(1)

    epoch_loss = epoch_loss / len(train_dataloader)
    progress_bar.write(f"Epoch {epoch}, Loss: {epoch_loss}")

    model.eval()

    preds = []
    true_labels = []

    for batch in val_dataloader:
        labels = batch.pop("labels")

        with torch.no_grad():
            logits = model(**batch)

        flat_logits, flat_labels = mask_and_flatten_logits_and_labels(logits, labels)

        pred = flat_logits.heaviside(torch.tensor(
            [0.0], device=device)).int().tolist()
        true_label = flat_labels.int().tolist()

        preds.extend(pred)
        true_labels.extend(true_label)

    f1 = f1_score(true_labels, preds, average='micro')
    progress_bar.write(f"f1 micro: {f1}")
    if using_wandb:
        wandb.log({"train_loss": epoch_loss, "micro_f1": f1, "epoch": epoch})
progress_bar.close()
model.save_pretrained(model_name)


  0%|          | 0/19500 [00:00<?, ?it/s]

Epoch 0, Loss: 0.06336923739765413
f1 micro: 0.2552083333333333
Epoch 1, Loss: 0.02923450159822376
f1 micro: 0.415041782729805
Epoch 2, Loss: 0.023584417085199115
f1 micro: 0.12630208333333334
Epoch 3, Loss: 0.023151620796505505
f1 micro: 0.4455598455598456
Epoch 4, Loss: 0.029753165467051215
f1 micro: 0.3097549699491447
Epoch 5, Loss: 0.018633879907339668
f1 micro: 0.46185737976782754
Epoch 6, Loss: 0.01627886500988657
f1 micro: 0.5843045843045843
Epoch 7, Loss: 0.013926863184091276
f1 micro: 0.6513070447496677
Epoch 8, Loss: 0.013536613511700685
f1 micro: 0.678585617798967
Epoch 9, Loss: 0.011488836632343008
f1 micro: 0.6515513126491647
Epoch 10, Loss: 0.010180796083027067
f1 micro: 0.6550770446463848
Epoch 11, Loss: 0.009196650322365503
f1 micro: 0.6932610379550737
Epoch 12, Loss: 0.008655384090111162
f1 micro: 0.6917057902973397
Epoch 13, Loss: 0.007758592219414333
f1 micro: 0.7014581734458941
Epoch 14, Loss: 0.007352741909943198
f1 micro: 0.6972126765941199


In [28]:
test_dataloader = DataLoader(
    datasets['test'], shuffle=True, batch_size=1
)
model.eval()

preds = []
true_labels = []

progress_bar = tqdm(range(len(test_dataloader)))
for batch in test_dataloader:
    labels = batch.pop("labels")

    with torch.no_grad():
        logits = model(**batch)

    flat_logits, flat_labels = mask_and_flatten_logits_and_labels(logits, labels)

    pred = flat_logits.heaviside(torch.tensor(
        [0.0], device=device)).int().tolist()
    true_label = flat_labels.int().tolist()

    preds.extend(pred)
    true_labels.extend(true_label)

    progress_bar.update(1)

f1 = f1_score(true_labels, preds, average='micro', zero_division=1)
f1_per_class = f1_score(true_labels, preds, average=None, zero_division=1)
for label, score in zip(label_list, f1_per_class):
    print(f"{label}: {score:.4f}")
progress_bar.write(f"f1 micro: {f1}")
if using_wandb:
    wandb.log({"test_micro_f1": f1})
    wandb.finish()
progress_bar.close()

  0%|          | 0/385 [00:00<?, ?it/s]

NumberAffected: 0.0000
NumberPatientsArm: 0.0000
PMID: 0.0000
MinAge: 0.0000
Title: 0.9809
Author: 0.2801
Precondition: 0.5824
Frequency: 0.0000
SubGroupDescription: 0.0000
TimePoint: 0.2921
Journal: 0.9000
PValueChangeValue: 0.4038
SdDevBL: 0.0000
ConclusionComment: 0.9070
DoseValue: 0.0000
CTDesign: 0.0000
DiffGroupAbsValue: 0.4221
ResultMeasuredValue: 0.1806
PublicationYear: 0.0000
AggregationMethod: 1.0000
Drug: 1.0000
FinalNumPatientsArm: 1.0000
AllocationRatio: 0.0000
PvalueDiff: 0.6055
ConfIntervalDiff: 0.5985
Country: 0.0000
SdDevResValue: 0.0000
ObservedResult: 0.1293
ConfIntervalChangeValue: 0.4762
PercentageAffected: 0.6220
ObjectiveDescription: 0.9098
DoseDescription: 0.0000
RelativeChangeValue: 1.0000
SdDevChangeValue: 0.3143
AvgAge: 0.0000
NumberPatientsCT: 0.3571
<SPC>: 1.0000
f1 micro: 0.6913932477413219


VBox(children=(Label(value='0.001 MB of 0.012 MB uploaded (0.000 MB deduped)\r'), FloatProgress(value=0.097466…

0,1
epoch,▁▁▂▃▃▃▄▅▅▅▆▇▇▇█
micro_f1,▃▅▁▅▃▅▇▇█▇▇████
test_micro_f1,▁
train_loss,█▄▃▃▄▂▂▂▂▂▁▁▁▁▁

0,1
epoch,14.0
micro_f1,0.69721
test_micro_f1,0.69139
train_loss,0.00735


In [31]:
index = 17
examples = datasets['test'][index:index+1]
print(examples['tokens'][0])

labels = examples['labels']
with torch.no_grad():
    logits = model(**examples)

flat_logits, flat_labels = mask_and_flatten_logits_and_labels(logits, labels)

pred = flat_logits.heaviside(torch.tensor(
    [0.0], device=device)).int().tolist()
true_label = flat_labels.int().tolist()
f1 = f1_score(true_label, pred, average='micro', zero_division=1)

print(pred)
print(true_label)
print(ner_labels.tensor2sentence(true_label))
print(ner_labels.tensor2sentence(pred))
print(f"Micro-F1 Score: {f1:.3f}")

['Mean', 'number', 'of', 'overall', 'hypoglycaemic', 'events', 'with', 'basal', '-', 'bolus', 'and', 'premix', 'was', '13', '.', '99', 'and', '18', '.', '54', 'events', '/', 'patient', 'year', ',', 'respectively', '(', 'difference', ':', '-', '3', '.', '90', ';', '95', '%', 'CI', ':', '-', '10', '.', '40', ',', '2', '.', '60', ';', 'p', '=', '0', '.', '2385', ')', '.']
[[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0

In [30]:
wandb.finish()