In [None]:
# !pip install --upgrade "ray[tune]"
# !pip install "lightning-bolts"
# !pip install "torchvision"

# !pip install "pytorch-lightning==1.4" # need 1.4 for raytune - but 1.5.10 was used for all experiments!

In [1]:
import torch
from torch.nn import functional as F
import pytorch_lightning as pl
from pl_bolts.datamodules import MNISTDataModule
import os
from ray.tune.integration.pytorch_lightning import TuneReportCallback
from ray import tune

import tempfile

import torch
import torch.nn as nn
from torch import optim
from torch.utils.data import Dataset, DataLoader
from torchnlp.encoders import LabelEncoder

import pytorch_lightning as pl
from pytorch_lightning.callbacks import ModelCheckpoint, EarlyStopping
from pytorch_lightning.loggers import TensorBoardLogger
from loguru import logger

from transformers import BertTokenizerFast as BertTokenizer, BertModel, AdamW, get_linear_schedule_with_warmup, get_constant_schedule_with_warmup 
from transformers.optimization import Adafactor, AdafactorSchedule 
from transformers import RobertaTokenizerFast as RobertaTokenizer
from transformers import AutoTokenizer, AutoModel

from sklearn.model_selection import train_test_split
from sklearn.metrics import classification_report, multilabel_confusion_matrix
from sklearn.metrics import balanced_accuracy_score, f1_score, precision_score, recall_score, classification_report, confusion_matrix, roc_auc_score 


from data_utils import FewShotSampler, Mimic_ICD9_Processor, Mimic_ICD9_Triage_Processor, Mimic_Mortality_Processor

from bert_classifier import MimicBertModel, MimicDataset, MimicDataModule

import numpy as np
import pandas as pd

In [2]:
# redefine the mimic bert model to avoid certain logging etc

class hyperMimicBertModel(pl.LightningModule):
    """Pretrained transformer encoder with an MLP head for single-label classification.

    Slimmed-down variant of the MIMIC BERT classifier meant for hyperparameter
    search. The encoder is loaded with ``AutoModel.from_pretrained`` and the
    hidden state at sequence position 0 is fed to a two-layer classification
    head trained with cross-entropy loss.

    Args:
        bert_model: Model name or path handed to ``AutoModel.from_pretrained``.
        num_labels: Output width of the classification head.
        class_labels: Stored on the instance as ``self.class_labels``; it is
            not otherwise read by this class. The default list is never mutated.
        bert_hidden_dim: Input width of the head; has to equal the encoder's
            hidden size.
        classifier_hidden_dim: Width of the head's hidden layer.
        n_training_steps: Stored only; no LR scheduler is configured here.
        n_warmup_steps: Stored only; no LR scheduler is configured here.
        dropout: Dropout probability inside the head, or ``None`` to disable.
        nr_frozen_epochs: If greater than 0 the encoder starts frozen and is
            unfrozen by ``on_epoch_end`` once
            ``current_epoch + 1 >= nr_frozen_epochs``.
        config: Optional mapping of tunable hyperparameters. Only ``"lr"`` is
            read; when it is absent ``DEFAULT_LR`` is used.
    """

    # Fallback learning rate, equal to the default of ``torch.optim.Adam``.
    DEFAULT_LR = 1e-3

    def __init__(self,
                 bert_model,
                 num_labels,
                 class_labels = [],
                 bert_hidden_dim=768,
                 classifier_hidden_dim=768,
                 n_training_steps=None,
                 n_warmup_steps=5000,
                 dropout = 0.1,
                 nr_frozen_epochs = 0,
                 config = None):

        super().__init__()
        logger.warning(f"Building model based on following architecture. {bert_model}")

        # set all the relevant parameters
        self.num_labels = num_labels
        self.class_labels = class_labels
        self.n_training_steps = n_training_steps
        self.n_warmup_steps = n_warmup_steps
        self.nr_frozen_epochs = nr_frozen_epochs

        # Tunable parameters come from `config`; tolerate a missing config so the
        # model can also be built outside of a Ray Tune trial.
        self.lr = (config or {}).get('lr', self.DEFAULT_LR)

        self.save_hyperparameters()

        self.bert = AutoModel.from_pretrained(f"{bert_model}", return_dict=True)
        # nn.Identity does nothing if the dropout is set to None
        self.classification_head = nn.Sequential(
            nn.Linear(bert_hidden_dim, classifier_hidden_dim),
            nn.ReLU(),
            nn.Dropout(dropout) if dropout is not None else nn.Identity(),
            nn.Linear(classifier_hidden_dim, num_labels),
        )

        self.criterion = nn.CrossEntropyLoss()

        self._frozen = False
        # Without this call `nr_frozen_epochs` had no effect: the encoder was
        # never frozen, so `unfreeze_encoder` was always a no-op.
        if self.nr_frozen_epochs > 0:
            self.freeze_encoder()

    def unfreeze_encoder(self) -> None:
        """ un-freezes the encoder layer. """
        if self._frozen:
            for param in self.bert.parameters():
                param.requires_grad = True
            self._frozen = False

    def freeze_encoder(self) -> None:
        """ freezes the encoder layer. """
        for param in self.bert.parameters():
            param.requires_grad = False
        self._frozen = True

    def forward(self, input_ids, attention_mask, labels=None):
        """Encode a batch and classify it from the first-token representation.

        Returns:
            ``(loss, logits)``. ``loss`` is the cross-entropy loss when
            ``labels`` is given and the integer ``0`` otherwise.
        """
        output = self.bert(input_ids, attention_mask=attention_mask)
        # last layer hidden states, shape: (batch_size, seq_length, bert_hidden_dim)
        last_hidden_state = output.last_hidden_state

        # The CLS token sits at the beginning of the sequence, so position 0
        # is taken as the sequence representation (instead of pooler_output).
        CLS_token_state = last_hidden_state[:, 0, :]
        logits = self.classification_head(CLS_token_state)

        loss = 0
        if labels is not None:
            loss = self.criterion(logits, labels)
        return loss, logits

    def _shared_step(self, batch, stage):
        """Run one forward pass, log ``"<stage>/loss"`` and return the step outputs."""
        labels = batch["labels"]
        loss, logits = self(batch["input_ids"], batch["attention_mask"], labels)
        self.log(f"{stage}/loss", loss, prog_bar=True, logger=True)
        return {"loss": loss, "predictions": logits.detach(), "labels": labels.detach()}

    def training_step(self, batch, batch_idx):
        """Compute and log the training loss for one batch."""
        return self._shared_step(batch, "train")

    def validation_step(self, batch, batch_idx):
        """Compute and log the validation loss for one batch."""
        return self._shared_step(batch, "valid")

    def test_step(self, batch, batch_idx):
        """Compute and log the test loss for one batch.

        Returns the same dict as ``validation_step``: ``test_epoch_end`` reads
        the ``"labels"`` and ``"predictions"`` entries, which a bare loss
        tensor does not provide.
        """
        return self._shared_step(batch, "test")

    def _log_epoch_metrics(self, outputs, stage):
        """Compute balanced accuracy and weighted F1 over an epoch and log them.

        The values are registered through ``self.log`` (with ``logger=False``)
        so that callbacks reading ``trainer.callback_metrics`` can see them,
        and are additionally written to the experiment logger against the
        epoch index, as before.
        """
        if not outputs:
            # nothing was collected for this epoch; there is nothing to score
            return

        labels = torch.cat([out["labels"] for out in outputs]).cpu().numpy()
        logits = torch.cat([out["predictions"] for out in outputs])
        # argmax over the class dimension gives the predicted label
        predictions = logits.argmax(dim=-1).cpu().numpy()

        # get sklearn based metrics
        acc = float(balanced_accuracy_score(labels, predictions))
        f1_weighted = float(f1_score(labels, predictions, average = 'weighted'))

        self.log(f"{stage}/balanced_accuracy", acc, logger=False)
        self.log(f"{stage}/f1_weighted", f1_weighted, logger=False)

        # log to tensorboard (skipped when the trainer runs without a logger)
        if self.logger is not None:
            self.logger.experiment.add_scalar(f'{stage}/balanced_accuracy', acc, self.current_epoch)
            self.logger.experiment.add_scalar(f'{stage}/f1_weighted', f1_weighted, self.current_epoch)

    def validation_epoch_end(self, outputs):
        """Aggregate validation-step outputs into epoch-level metrics."""
        logger.warning("on validation epoch end")
        self._log_epoch_metrics(outputs, "valid")

    def test_epoch_end(self, outputs):
        """Aggregate test-step outputs into epoch-level metrics."""
        self._log_epoch_metrics(outputs, "test")

    def configure_optimizers(self):
        """Return an Adam optimizer over all parameters using ``self.lr``."""
        return torch.optim.Adam(self.parameters(), lr=self.lr)

    def on_epoch_end(self):
        """ Pytorch lightning hook """
        logger.warning(f"On epoch {self.current_epoch}. Number of frozen epochs is: {self.nr_frozen_epochs}")
        if self.current_epoch + 1 >= self.nr_frozen_epochs:
            logger.warning("unfreezing PLM(encoder)")
            self.unfreeze_encoder()

# TODO - Adapt below to work with our mimic models and tasks

We will run a hyperparameter search for just one of the MIMIC tasks. Let's go with icd9_50,
since it appears to have the highest variability in performance?

In [None]:
# Settings specific to the mimic icd9_50 task.

# --- data ---
root_data_dir = "../"
dataset = "icd9_50"
label_col = "label"
n_labels = 50

# --- encoder / tokenisation ---
pretrained_model_name = "emilyalsentzer/Bio_ClinicalBERT"
max_tokens = 512

# --- schedule ---
total_training_steps = 25000
warmup_steps = 50

In [None]:
# Hyperparameter search space: each entry is a Ray Tune sampler.
config = dict(
    batch_size=tune.choice([2, 4, 8]),
    grad_accum_steps=tune.choice([2,10,30]),
    lr=tune.loguniform(1e-4, 1e-1),
)

In [None]:
tokenizer = AutoTokenizer.from_pretrained(f"{pretrained_model_name}")

# TODO update the dataloading to use the custom dataprocessors from data_utils in this folder

if dataset == "icd9_50":
    logger.warning(f"Using the following dataset: {dataset} ")
    Processor = Mimic_ICD9_Processor
    # update data_dir
    data_dir = f"{root_data_dir}/mimic3-icd9-data/intermediary-data/top_50_icd9/"

    # are we doing any downsampling or balancing etc
    class_weights = False
    balance_data = False

    # get different splits - the processor will return a dataframe and class_labels for each, but we only need training class_labels
    train_df, class_labels = Processor().get_examples(data_dir = data_dir, mode = "train", class_weights = class_weights, balance_data = balance_data)
    val_df,_ = Processor().get_examples(data_dir = data_dir, mode = "valid", class_weights = class_weights, balance_data = balance_data)
    test_df,_ = Processor().get_examples(data_dir = data_dir, mode = "test", class_weights = class_weights, balance_data = balance_data)
else:
    # Fail here rather than with a NameError on train_df / val_df / test_df
    # in a later cell.
    raise NotImplementedError(f"No data loading is set up for dataset {dataset!r}; only 'icd9_50' is supported.")

In [None]:
# Build one model instance outside of Ray Tune.
# The model reads its learning rate from `config`, so a concrete value has to be
# supplied here (the search-space `config` holds samplers, not numbers).
model = hyperMimicBertModel(bert_model=pretrained_model_name,
                             num_labels=n_labels,
                             class_labels=class_labels,
                             n_warmup_steps=warmup_steps,
                             n_training_steps=total_training_steps,
                             config={"lr": 1e-3},
                             )

In [3]:


def train_mimic(config, dataset = "icd9_50",
                pretrained_model_name ="emilyalsentzer/Bio_ClinicalBERT" ,
                root_data_dir = "../",max_tokens = 512, label_col = "label",
                n_labels = 50, num_epochs=5, warmup_steps = 50, total_training_steps = 20000,
                num_gpus=0, data_dir = "~/ray_tune_results/"):
    """Train ``hyperMimicBertModel`` once with the hyperparameters in ``config``.

    Written as a Ray Tune trainable: validation metrics are forwarded to Tune
    by ``TuneReportCallback`` at the end of every validation run.

    Args:
        config: Mapping with ``"batch_size"`` (datamodule batch size) and
            ``"lr"`` (read by the model). An optional ``"grad_accum_steps"``
            sets the Trainer's gradient accumulation; it defaults to 1.
        dataset: Task to train on. Only ``"icd9_50"`` is supported.
        pretrained_model_name: Name or path for the tokenizer and the encoder.
        root_data_dir: Directory that contains ``mimic3-icd9-data/``.
        max_tokens: Maximum token length handed to the datamodule.
        label_col: Label column name handed to the datamodule.
        n_labels: Number of output classes.
        num_epochs: Maximum number of training epochs.
        warmup_steps: Forwarded to the model as ``n_warmup_steps``.
        total_training_steps: Forwarded to the model as ``n_training_steps``.
        num_gpus: Number of GPUs given to the Trainer.
        data_dir: Not used by this function; kept so existing callers that
            pass it keep working.

    Raises:
        NotImplementedError: If ``dataset`` is not ``"icd9_50"``.
    """
    # Fail fast: previously an unknown dataset surfaced only later as a
    # NameError on train_df / model.
    if dataset != "icd9_50":
        raise NotImplementedError(f"No data loading is set up for dataset {dataset!r}; only 'icd9_50' is supported.")

    tokenizer = AutoTokenizer.from_pretrained(f"{pretrained_model_name}")

    # TODO update the dataloading to use the custom dataprocessors from data_utils in this folder
    logger.warning(f"Using the following dataset: {dataset} ")
    Processor = Mimic_ICD9_Processor
    # update data_dir
    root_data_dir = f"{root_data_dir}/mimic3-icd9-data/intermediary-data/top_50_icd9/"

    # are we doing any downsampling or balancing etc
    class_weights = False
    balance_data = False

    # get different splits - the processor will return a dataframe and class_labels for each, but we only need training class_labels
    train_df, class_labels = Processor().get_examples(data_dir = root_data_dir, mode = "train", class_weights = class_weights, balance_data = balance_data)
    val_df,_ = Processor().get_examples(data_dir = root_data_dir, mode = "valid", class_weights = class_weights, balance_data = balance_data)
    test_df,_ = Processor().get_examples(data_dir = root_data_dir, mode = "test", class_weights = class_weights, balance_data = balance_data)

    # load model
    model = hyperMimicBertModel(bert_model=pretrained_model_name,
                                num_labels=n_labels,
                                class_labels=class_labels,
                                n_warmup_steps=warmup_steps,
                                n_training_steps=total_training_steps,
                                config = config)

    # push data through pipeline
    # instantiate datamodule
    data_module = MimicDataModule(
        train_df,
        val_df,
        test_df,
        tokenizer,
        batch_size=config["batch_size"],
        max_token_len=max_tokens,
        label_col = label_col,
        num_workers=1,
    )

    # Tune-side name -> Lightning metric name. TuneReportCallback only forwards
    # keys that are present in trainer.callback_metrics.
    metrics = {"loss": "valid/loss", "acc": "valid/balanced_accuracy"}
    trainer = pl.Trainer(
        max_epochs=num_epochs,
        gpus=num_gpus,
        # honour the tuned accumulation steps; 1 is the Trainer's own default
        accumulate_grad_batches=config.get("grad_accum_steps", 1),
        progress_bar_refresh_rate=0,
        callbacks=[TuneReportCallback(metrics, on="validation_end")])

    trainer.fit(model, data_module)

In [None]:
# Display the currently defined hyperparameter configuration.
config

In [None]:
num_samples = 10
num_epochs = 5
gpus_per_trial = 0 # set this to higher if using GPU

# Fixed (non-sampled) hyperparameters: a single plain training run to check
# that train_mimic works before handing it to Ray Tune.
config = {
 "batch_size": 2,
 "grad_accum_steps": 2,
 "lr": 1e-3,
}

# Pass the settings declared above explicitly so that editing them actually
# changes the run (previously they were defined but never used here).
train_mimic(config, num_epochs=num_epochs, num_gpus=gpus_per_trial)

# bug below may suggest we want to edit the bert classifier for this hyperparameter search to not log so much?

In [8]:
num_samples = 10
num_epochs = 5
gpus_per_trial = 0 # set this to higher if using GPU

# Search space: Ray Tune samples one value per entry for every trial.
config = {
 "batch_size": tune.choice([2, 4, 8]),
 "lr": tune.loguniform(1e-4, 1e-1),
}

# Ray Tune runs each trial with the trial's own log directory as working
# directory, so train_mimic's relative default root_data_dir ("../") would not
# point at the data any more. Resolve it to an absolute path here, in the
# notebook process, before it is shipped to the workers.
trainable = tune.with_parameters(
    train_mimic,
    num_epochs=num_epochs,
    num_gpus=gpus_per_trial,
    root_data_dir=os.path.abspath("../"),
    data_dir = "~/ray_tune_results")

# NOTE(review): the TypeError recorded below this cell is Ray failing to
# cloudpickle the trainable. A likely cause is a notebook-global object that
# train_mimic or the model class captures (e.g. the loguru `logger`) - TODO
# confirm, for instance with ray.util.inspect_serializability(trainable).
analysis = tune.run(
    trainable,
    resources_per_trial={
        "cpu": 1,
        "gpu": gpus_per_trial
    },
    metric="loss",
    mode="min",
    config=config,
    num_samples=num_samples,
    name="tune_m"
    )

print(analysis.best_config)

TypeError: ray.cloudpickle.dumps(<class 'ray.tune.function_runner.wrap_function.<locals>.ImplicitFunc'>) failed.
To check which non-serializable variables are captured in scope, re-run the ray script with 'RAY_PICKLE_VERBOSE_DEBUG=1'. Other options: 
-Try reproducing the issue by calling `pickle.dumps(trainable)`. 
-If the error is typing-related, try removing the type annotations and try again.

In [None]:
# Diagnose which captured object prevents Ray from serializing `trainable`.
import pickle
from ray.util import inspect_serializability

# Reproduction suggested by Ray's error message. The failure is caught and
# printed so that the notebook still runs top to bottom.
# NOTE(review): Ray serializes with cloudpickle, and the wrapper returned by
# tune.with_parameters is presumably a nested function that plain pickle
# rejects regardless - so this error may differ from Ray's own. TODO confirm.
try:
    pickle.dumps(trainable)
except Exception as err:
    print(f"pickle.dumps(trainable) failed: {err!r}")

# Walks the objects captured by the trainable and reports those that Ray's
# serializer cannot handle.
inspect_serializability(trainable, name="trainable")