In [1]:
import pickle
import numpy as np
from pytorch_lightning import (
    LightningDataModule,
    LightningModule,
    Trainer,
    seed_everything,
)
import torch

from torch.utils.data import DataLoader
from transformers import (
    AdamW,
    AutoConfig,
    AutoModelForSequenceClassification,
    AutoTokenizer,
    get_linear_schedule_with_warmup,
)

import datasets
from typing import Optional
from datetime import datetime

# import evaluate
from torchmetrics.classification import BinaryAccuracy

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
SEED = 0
rng = np.random.default_rng(SEED)
GEN_SEED = torch.Generator().manual_seed(SEED)
seed_everything(SEED)
MODEL_NAME = "microsoft/BiomedNLP-PubMedBERT-base-uncased-abstract-fulltext"

Global seed set to 0


In [3]:
with open("../../dataset.pkl", "rb") as f:
    tuple_dataset = pickle.load(f)

print("loaded dataset")
print("dataset: ", len(tuple_dataset))
print(tuple_dataset[:1])

loaded dataset
dataset:  77998
[('patientregistry: idcenter=100, idpatient=3080, sex=M, yeardiagnosisdiabetes=1988-01-01, levelofeducation=[UNK], maritalstatus=[UNK], profession=[UNK], yearofbirth=1936-01-01 00:00:00, yearfirstaccess=1991-01-01 00:00:00, yearofdeath=[UNK], diagnosis: idcenter=100, idpatient=3080, date=1991-10-29 00:00:00, amdcode=AMD097, meaning=Cigarette smoke, value=N, idcenter=100, idpatient=3080, date=2004-02-04 00:00:00, amdcode=AMD044, meaning=Ischemic heart disease, value=414, idcenter=100, idpatient=3080, date=2004-02-04 00:00:00, amdcode=AMD247, meaning=Other comorbidities, value=414.9, idcenter=100, idpatient=3080, date=2004-07-21 00:00:00, amdcode=AMD130, meaning=Non diabetic retinopathy, value=[UNK], idcenter=100, idpatient=3080, date=2005-02-21 00:00:00, amdcode=AMD049, meaning=Coronary bypass, value=S, idcenter=100, idpatient=3080, date=2005-02-21 00:00:00, amdcode=AMD247, meaning=Other comorbidities, value=36.10, idcenter=100, idpatient=3080, date=2005-0

In [None]:
# class PubMedBERTDataset(Dataset):
#     def __init__(self, data):
#         # here data is a list of tuples,
#         # each containing the patient history string and their label
#         self.data = data

#     def __len__(self):
#         return len(self.data)

#     def __getitem__(self, idx):
#         patient_history = self.data[idx][0]
#         label = self.data[idx][1]
#         return patient_history, label

In [4]:
def convert_to_huggingfaceDataset(tuple_dataset):
    # here data is a list of tuples,
    # each containing the patient history string and their label
    # we need to convert it to a hugginface dataset
    dict_list = [{"label": data[1], "text": data[0]} for data in tuple_dataset]
    dataset = datasets.Dataset.from_list(dict_list)
    return dataset

In [5]:
# max_len = [0] * 100
# index = [0] * 100

# # Save the 100 most length sentence and their indexes
# for i, (ph, l) in enumerate(dataset):
#     # Count the number of words in the ph variable
#     ph_len = len(ph.replace(' ', '=').split('='))
#     for j, x in enumerate(max_len):
#         if ph_len > x:
#             max_len.insert(j, ph_len)
#             index.insert(j, i)
#             max_len.pop()
#             index.pop()
#             break

# print(max_len)
# print(index)

In [6]:
class PubMedBERTDataModule(LightningDataModule):
    def __init__(
        self,
        tuple_dataset,
        model_name_with_path: str,
        max_seq_length: int = 512,  # 512 is the max length of BERT and PubMedBERT but I need 32768
        train_batch_size: int = 4,
        eval_batch_size: int = 4,
        **kwargs,
    ):
        super().__init__()
        self.model_name_with_path = model_name_with_path
        self.max_seq_length = max_seq_length
        self.train_batch_size = train_batch_size
        self.eval_batch_size = eval_batch_size
        self.tokenizer = AutoTokenizer.from_pretrained(
            self.model_name_with_path, use_fast=True
        )

    def setup(self, stage=None):
        dataset = convert_to_huggingfaceDataset(tuple_dataset)
        tokenized_dataset = dataset.map(
            self.convert_to_features,
            batched=True,
            remove_columns=["text", "label"],
        )
        tokenized_dataset.set_format(type="torch")

        # split dataset into train and validation sampling randomly
        # use 20% of training data for validation
        train_set_size = int(len(tokenized_dataset) * 0.8)
        valid_set_size = len(tokenized_dataset) - train_set_size

        # split the dataset randomly into two
        self.train_data, self.valid_data = torch.utils.data.random_split(
            tokenized_dataset, [train_set_size, valid_set_size], generator=GEN_SEED
        )

    def prepare_data(self):
        AutoTokenizer.from_pretrained(
            self.model_name_with_path,
            use_fast=True,
        )

    def train_dataloader(self):
        return DataLoader(
            self.train_data, batch_size=self.train_batch_size, shuffle=True
        )

    def val_dataloader(self):
        return DataLoader(
            self.valid_data, batch_size=self.eval_batch_size, shuffle=False
        )

    def test_dataloader(self):
        # placeholder
        return DataLoader(
            self.valid_data, batch_size=self.eval_batch_size, shuffle=False
        )

    def convert_to_features(self, example_batch, indices=None):
        # Tokenize the patient history
        features = self.tokenizer(
            text=example_batch["text"],
            max_length=self.max_seq_length,
            padding="longest",
            truncation=True,
            return_tensors="pt",
        )
        # Rename label to labels to make it easier to pass to model forward
        features["labels"] = example_batch["label"]

        return features

In [7]:
# dm = PubMedBERTDataModule(tuple_dataset, MODEL_NAME)
# dm.prepare_data()
# dm.setup("fit")
# next(iter(dm.train_dataloader()))

In [8]:
class PubMedBERTTransformer(LightningModule):
    def __init__(
        self,
        model_name_or_path: str,
        num_labels: int = 2,
        learning_rate: float = 2e-5,
        adam_epsilon: float = 1e-8,
        warmup_steps: int = 0,
        weight_decay: float = 0.0,
        train_batch_size: int = 16,
        eval_batch_size: int = 16,
        eval_splits: Optional[list] = None,
        **kwargs,
    ):
        super().__init__()

        self.save_hyperparameters()

        self.config = AutoConfig.from_pretrained(
            model_name_or_path, num_labels=num_labels
        )
        self.model = AutoModelForSequenceClassification.from_pretrained(
            model_name_or_path, config=self.config
        )
        self.metric = BinaryAccuracy()
        self.validation_step_outputs = []

    def forward(self, **inputs):
        return self.model(
            **inputs
        )
    
    def step(self, batch):
        outputs = self(**batch)
        loss, logits = outputs[:2]
        if self.hparams.num_labels > 1:
            preds = logits.argmax(axis=1)
        elif self.hparams.num_labels == 1:
            preds = logits.squeeze()
        labels = batch["labels"]
        return {"loss": loss, "logits": logits, "preds": preds, "labels": labels}

    def training_step(self, batch, batch_idx):
        outputs = self.step(batch)
        value = self.metric(outputs["preds"], outputs["labels"])
        self.log("train_acc_step", value, on_epoch=True)
        self.log("train_loss", outputs["loss"], prog_bar=True)
        return outputs["loss"]

    def validation_step(self, batch, batch_idx, dataloader_idx=0):
        outputs = self.step(batch)
        # self.validation_step_outputs.append(preds)
        value = self.metric(outputs["preds"], outputs["labels"])
        self.log("train_acc_step", value, on_epoch=True)
        return {"loss": outputs["loss"], "preds": outputs["preds"], "labels": outputs["labels"]}

    # def on_validation_epoch_end(self):
    # print("on_validation_epoch_end")
    # print(self.validation_step_outputs)
    # preds = (
    #     torch.cat([x["preds"] for x in self.validation_step_outputs])
    #     .detach()
    #     .cpu()
    #     .numpy()
    # )
    # labels = (
    #     torch.cat([x["labels"] for x in self.validation_step_outputs])
    #     .detach()
    #     .cpu()
    #     .numpy()
    # )
    # loss = torch.stack([x["loss"] for x in self.validation_step_outputs]).mean()
    # self.log("val_loss", loss, prog_bar=True)
    # self.log_dict(
    #     self.metric.compute(predictions=preds, references=labels), prog_bar=True
    # )
    # self.validation_step_outputs.clear()  # free memory

    def configure_optimizers(self):
        """Prepare optimizer and schedule (linear warmup and decay)"""
        model = self.model
        no_decay = ["bias", "LayerNorm.weight"]
        optimizer_grouped_parameters = [
            {
                "params": [
                    p
                    for n, p in model.named_parameters()
                    if not any(nd in n for nd in no_decay)
                ],
                "weight_decay": self.hparams.weight_decay,
            },
            {
                "params": [
                    p
                    for n, p in model.named_parameters()
                    if any(nd in n for nd in no_decay)
                ],
                "weight_decay": 0.0,
            },
        ]
        optimizer = torch.optim.Adam(
            optimizer_grouped_parameters,
            lr=self.hparams.learning_rate,
            eps=self.hparams.adam_epsilon,
        )

        scheduler = get_linear_schedule_with_warmup(
            optimizer,
            num_warmup_steps=self.hparams.warmup_steps,
            num_training_steps=self.trainer.estimated_stepping_batches,
        )
        scheduler = {"scheduler": scheduler, "interval": "step", "frequency": 1}
        return [optimizer], [scheduler]

In [9]:
dm = PubMedBERTDataModule(tuple_dataset, MODEL_NAME)
dm.setup("fit")
# print(next(iter(dm.train_dataloader())))

model = PubMedBERTTransformer(
    model_name_or_path=MODEL_NAME,
)

trainer = Trainer(
    max_epochs=2,
    accelerator="auto",
    devices="auto",
)
trainer.fit(model=model, datamodule=dm)

Map: 100%|██████████| 150/150 [00:00<00:00, 187.21 examples/s]
Some weights of BertForSequenceClassification were not initialized from the model checkpoint at microsoft/BiomedNLP-PubMedBERT-base-uncased-abstract-fulltext and are newly initialized: ['classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
Map: 100%|██████████| 150/150 [00:00<00:00, 223.64 examples/s]
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
Loading `train_dataloader` to estimate number of stepping batches.
  rank_zero_warn(
  rank_zero_warn(

  | Name   | Type                          | Params
---------------------------------------------------------
0 | model  | BertForSequenceClassification | 109 M 
1 | metric | BinaryAccuracy                | 0     
-------------------------

Sanity Checking DataLoader 0:   0%|          | 0/2 [00:00<?, ?it/s]

  rank_zero_warn(


Epoch 1: 100%|██████████| 30/30 [00:23<00:00,  1.29it/s, v_num=0, train_loss=0.0398]

`Trainer.fit` stopped: `max_epochs=2` reached.


Epoch 1: 100%|██████████| 30/30 [00:25<00:00,  1.18it/s, v_num=0, train_loss=0.0398]
