In [1]:
from collections import defaultdict, namedtuple
from typing import *

import numpy as np
from torch.utils.data import Dataset
from torch.utils.data.dataset import Subset
from openprompt.utils.logging import logger
from torchnlp.encoders import LabelEncoder
from typing import Union

import os

import pandas as pd
import numpy as np
from tqdm.auto import tqdm

import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader

import pytorch_lightning as pl
from pytorch_lightning import Trainer
from pytorch_lightning.callbacks import ModelCheckpoint, EarlyStopping
from pytorch_lightning.loggers import TensorBoardLogger
from loguru import logger

from transformers import BertTokenizerFast as BertTokenizer, BertModel, AdamW, get_linear_schedule_with_warmup
from transformers import RobertaTokenizerFast as RobertaTokenizer
from transformers import AutoTokenizer, AutoModel
os.environ["TOKENIZERS_PARALLELISM"] = "false"
from sklearn.model_selection import train_test_split
from sklearn.metrics import classification_report, multilabel_confusion_matrix

from openprompt.data_utils.data_processor import DataProcessor

from bert_classifier import MimicBertModel, MimicDataset, MimicDataModule
import argparse
from datetime import datetime
import warnings


# TODO adapt this to work for all datatasks - we need to return dataset with encoded labels and the original class labels

In [2]:
triage_data = pd.read_csv("../mimic3-icd9-data/intermediary-data/triage/train.csv")

In [3]:
triage_data.head()

Unnamed: 0,text,label,triage-category
0,: : : Sex: F Service: CARDIOTHORACIC Allergies...,4240,Cardiology
1,: : : Sex: F Service: NEONATOLOGY HISTORY: wee...,V3001,Obstetrics
2,: : : Sex: M Service: CARDIOTHORACIC Allergies...,41041,Cardiology
3,: : : Sex: F Service: MEDICINE Allergies: Peni...,51881,Respiratory
4,: : : Sex: M Service: ADMISSION DIAGNOSIS: . S...,41401,Cardiology


In [4]:
icd9_50_data = pd.read_csv("../mimic3-icd9-data/intermediary-data/top_50_icd9/train.csv")
icd9_50_data.head()

Unnamed: 0,text,label
0,: : : Sex: F Service: CARDIOTHORACIC Allergies...,4240
1,: : : Sex: F Service: NEONATOLOGY HISTORY: wee...,V3001
2,: : : Sex: M Service: CARDIOTHORACIC Allergies...,41041
3,: : : Sex: F Service: MEDICINE Allergies: Peni...,51881
4,: : : Sex: F Service: CARDIOTHORACIC Allergies...,3962


In [9]:
mortality_data = pd.read_csv("../clinical-outcomes-data/mimic3-clinical-outcomes/mp/train.csv")
mortality_data.head()

Unnamed: 0,id,text,hospital_expire_flag
0,107384,"CHIEF COMPLAINT: AMS, concern for toxic alcoho...",0
1,101061,CHIEF COMPLAINT: abdominal pain\n\nPRESENT ILL...,0
2,127180,CHIEF COMPLAINT: Bilateral Sub Dural Hematoma\...,0
3,168339,CHIEF COMPLAINT: Intracranial bleed\n\nPRESENT...,0
4,154044,CHIEF COMPLAINT: ischemic left foot\n\nPRESENT...,0


In [26]:
class Mimic_ICD9_Processor():


    '''
    Function to convert mimic icd9 dataset to a open prompt ready dataset. 
    
    We also instantiate a LabelEncoder() class which is fitted to the given dataset. Fortunately it appears
    to create the same mapping for each set, given each set contains all classes. 

    This is not ideal, and need to think of a better way to store the label encoder based on training data.
    

  
    
    '''
    # TODO Test needed
    def __init__(self):
        super().__init__()


    def balance_dataset(self,df, random_state = 42):
    
        '''
        Function to balance the training dataset - won't bother with the valid and test sets


        '''   

        # slightly clunky but works
        g = df.groupby('label')
        g = pd.DataFrame(g.apply(lambda x: x.sample(g.size().min(), random_state = random_state).reset_index(drop=True)))
        g.reset_index(drop=True, inplace = True)

        return g.sample(frac=1, random_state=random_state)

    def get_examples(self, data_dir, mode = "train", label_encoder = None,
                     generate_class_labels = True, class_labels_save_dir = "scripts/mimic_icd9_top50/", balance_data = False):

        path = f"{data_dir}/{mode}.csv"
        print(f"loading {mode} data")
        print(f"data path provided was: {path}")
        examples = []
        df = pd.read_csv(path)
        
        # balance data based on minority class if desired
        if balance_data:
            df = self.balance_dataset(df)

        # need to either initializer and fit the label encoder if not provided
        if label_encoder is None:
            self.label_encoder = LabelEncoder(np.unique(df.label).tolist(), reserved_labels = [])
        else: 
            print("we were given a label encoder")
            self.label_encoder = label_encoder

        # new df to fill in examples list
        new_dfs = []
        for idx, row in tqdm(df.iterrows()):
#             print(row)
            body, label = row['text'],row['label']
            label = self.label_encoder.encode(label).numpy() 
            # add to new df
            new_df = pd.DataFrame({'text':[body],'label':[label]})
            
            new_dfs.append(new_df)
          
        
#         concat all examples
        all_dfs = pd.concat(new_dfs)
        logger.info(f"Returning {len(all_dfs)} samples!") 

#         now we want to return a list of the non-encoded labels based on the fitted label encoder
        if generate_class_labels:
            logger.info(f"generating class labels!")
            class_labels = self.generate_class_labels()

        return all_dfs, class_labels

    def generate_class_labels(self):
        # now we want to return a list of the non-encoded labels based on the fitted label encoder
        try:
            return list(self.label_encoder.tokens.keys())
        except:
            print("No class labels as haven't fitted any data yet. Run get_examples first!")
            raise NotImplementedError

    
    def load_class_labels(self, file_path = "./scripts/mimic_icd9_top50/labels.txt"):
        # function to load pre-generated class labels
        # returns list of class labels

        text_file = open(f"{file_path}", "r")

        class_labels = text_file.read().split("\n")

        return class_labels

In [15]:
class Mimic_ICD9_Triage_Processor():


    '''
    Function to convert mimic icd9 dataset to a open prompt ready dataset. 
    
    We also instantiate a LabelEncoder() class which is fitted to the given dataset. Fortunately it appears
    to create the same mapping for each set, given each set contains all classes. 

    This is not ideal, and need to think of a better way to store the label encoder based on training data.
    

  
    
    '''
    # TODO Test needed
    def __init__(self):
        super().__init__()


    def balance_dataset(self,df, random_state = 42):
    
        '''
        Function to balance the training dataset - won't bother with the valid and test sets


        '''   

        # slightly clunky but works
        g = df.groupby('label')
        g = pd.DataFrame(g.apply(lambda x: x.sample(g.size().min(), random_state = random_state).reset_index(drop=True)))
        g.reset_index(drop=True, inplace = True)

        return g.sample(frac=1, random_state=random_state)

    def get_examples(self, data_dir, mode = "train", label_encoder = None,
                     generate_class_labels = True, class_labels_save_dir = "scripts/mimic_icd9_top50/", balance_data = False):

        path = f"{data_dir}/{mode}.csv"
        print(f"loading {mode} data")
        print(f"data path provided was: {path}")
        examples = []
        df = pd.read_csv(path)
        
        # balance data based on minority class if desired
        if balance_data:
            df = self.balance_dataset(df)

        # need to either initializer and fit the label encoder if not provided
        if label_encoder is None:
            self.label_encoder = LabelEncoder(np.unique(df["triage-category"]).tolist(), reserved_labels = [])
        else: 
            print("we were given a label encoder")
            self.label_encoder = label_encoder

        # new df to fill in examples list
        new_dfs = []
        for idx, row in tqdm(df.iterrows()):
#             print(row)
            body, label = row['text'],row['triage-category']
            label = self.label_encoder.encode(label).numpy() 
            # add to new df
            new_df = pd.DataFrame({'text':[body],'label':[label]})
            
            new_dfs.append(new_df)
          
        
#         concat all examples
        all_dfs = pd.concat(new_dfs)
        logger.info(f"Returning {len(all_dfs)} samples!") 

#         now we want to return a list of the non-encoded labels based on the fitted label encoder
        if generate_class_labels:
            logger.info(f"generating class labels!")
            class_labels = self.generate_class_labels()

        return all_dfs, class_labels

    def generate_class_labels(self):
        # now we want to return a list of the non-encoded labels based on the fitted label encoder
        try:
            return list(self.label_encoder.tokens.keys())
        except:
            print("No class labels as haven't fitted any data yet. Run get_examples first!")
            raise NotImplementedError

    
    def load_class_labels(self, file_path = "./scripts/mimic_icd9_top50/labels.txt"):
        # function to load pre-generated class labels
        # returns list of class labels

        text_file = open(f"{file_path}", "r")

        class_labels = text_file.read().split("\n")

        return class_labels

In [37]:
class Mimic_Mortality_Processor():


    '''
    Function to convert mimic mortality prediction dataset from the clinical outcomes paper: https://aclanthology.org/2021.eacl-main.75/
    
    to a open prompt ready dataset. 
    
    We also instantiate a LabelEncoder() class which is fitted to the given dataset. Fortunately it appears
    to create the same mapping for each set, given each set contains all classes.    
    
    '''
    # TODO Test needed
    def __init__(self):
        super().__init__()   

    def get_ce_class_weights(self,df):
        
        '''
        Function to calculate class weights to pass to cross entropy loss in pytorch framework.
        
        Here we use the sklearn compute_class_weight function.
        
        Returns: un-normalized class weights inverse to sample size. i.e. lower number given to majority class
        '''

        # calculate class weights 
        ce_class_weights = compute_class_weight("balanced", classes = np.unique(df["hospital_expire_flag"]),
                                             y = df['hospital_expire_flag'] )

        return ce_class_weights
    
    def get_weighted_sampler_class_weights(self, df, normalized = True):
        
        '''
        Function to create array of per sample class weights to pass to the weighted random sampler.
        
        Purpose is to create batches which sample from the entire dataset based on class weights ->
        this attempts to create balanced batches during training.
        
        DO NOT SHUFFLE DATASET WHEN TRAINING - use weightedrandomsampler
        '''
        
        if normalized:
            nSamples = df["hospital_expire_flag"].value_counts()
            class_weights = [1 - (x / sum(nSamples)) for x in nSamples]
            
        # can use the class weights derived from the get_ce_class weights function
        else:
            class_weights = self.get_ce_class_weights(df)
        
        # creata dict for easy mapping
        class_weights_dict = {0:class_weights[0], 1:class_weights[1]}        
        
        # then need to assign these class specific weights to each sample based on their class
        class_weights_array = df["hospital_expire_flag"].map(class_weights_dict)
        
        return class_weights_array
        
        
        
        
    def balance_dataset(self,df, random_state = 42):
    
        '''
        Function to balance the training dataset - won't bother with the valid and test sets


        '''   


        # slightly clunky but works
        g = df.groupby('hospital_expire_flag')
        g = pd.DataFrame(g.apply(lambda x: x.sample(g.size().min(), random_state = random_state).reset_index(drop=True)))
        g.reset_index(drop=True, inplace = True)

        return g.sample(frac=1, random_state=random_state)

    def get_examples(self, data_dir, mode = "train", label_encoder = None,
                     generate_class_labels = True, class_labels_save_dir = "./scripts/mimic_mortality/",
                     balance_data = False, class_weights = False, sampler_weights = False):

        path = f"{data_dir}/{mode}.csv"
        print(f"loading {mode} data")
        print(f"data path provided was: {path}")
        examples = []
        df = pd.read_csv(path)
        
        # if balance data - balance based on minority class
        
        if balance_data:
            df = self.balance_dataset(df)

        # map the binary classification label to a new string class label
        df["label"] = df["hospital_expire_flag"].map({0:"alive",1:"deceased"})
        
        # need to either initializer and fit the label encoder if not provided
        if label_encoder is None:
            self.label_encoder = LabelEncoder(np.unique(df["label"]).tolist(),reserved_labels = [])
        else: 
            print("we were given a label encoder")
            self.label_encoder = label_encoder
        
        # calculate class_weights
        if class_weights:
            print("getting class weights")
            task_class_weights = self.get_ce_class_weights(df)
        
        # calculate all sample weights for weighted sampler
        if sampler_weights:
            print("getting weights for sampler!")
            sampler_class_weights = self.get_weighted_sampler_class_weights(df)
        
        # new df to fill in examples list
        new_dfs = []
        for idx, row in tqdm(df.iterrows()):
            body, label = row['text'],row['label']
            label = self.label_encoder.encode(label).numpy() 
            # add to new df
            new_df = pd.DataFrame({'text':[body],'label':[int(label)]})
            
            new_dfs.append(new_df)
        
#         concat all examples
        all_dfs = pd.concat(new_dfs)
        logger.info(f"Returning {len(all_dfs)} samples!") 
#         now we want to return a list of the non-encoded labels based on the fitted label encoder
        if generate_class_labels:
            class_labels = self.generate_class_labels()            
        if class_weights and sampler_weights:
            print("cannot return both class and sample weights. Just returning samples")
            return all
        if class_weights:
            return all_dfs, class_labels, task_class_weights
        elif sampler_weights:
            return all_dfs, class_labels, sampler_class_weights
        else:
            return all_dfs, class_labels

    def generate_class_labels(self):
        # now we want to return a list of the non-encoded labels based on the fitted label encoder
        try:
            return list(self.label_encoder.tokens.keys())
        except:
            print("No class labels as haven't fitted any data yet. Run get_examples first!")
            raise NotImplementedError

    
    def load_class_labels(self, file_path = "./scripts/mimic_mortality/labels.txt"):
        # function to load pre-generated class labels
        # returns list of class labels

        text_file = open(f"{file_path}", "r")

        class_labels = text_file.read().split("\n")

        return class_labels

# test all the dataprocessors

In [31]:
# test the icd9 datasets
data_dir = "../mimic3-icd9-data/intermediary-data/top_50_icd9/"

In [32]:
# get different splits

dataset_train,train_class_labels = Mimic_ICD9_Processor().get_examples(data_dir = f"{data_dir}",
                                                                                 mode = "train")



loading train data
data path provided was: ../mimic3-icd9-data/intermediary-data/top_50_icd9//train.csv


0it [00:00, ?it/s]

2022-03-18 11:14:49.958 | INFO     | __main__:get_examples:70 - Returning 0 samples!
2022-03-18 11:14:49.958 | INFO     | __main__:get_examples:74 - generating class labels!


In [34]:
dataset_train

Unnamed: 0,text,label
0,: : : Sex: F Service: CARDIOTHORACIC Allergies...,15
0,: : : Sex: F Service: NEONATOLOGY HISTORY: wee...,47
0,: : : Sex: M Service: CARDIOTHORACIC Allergies...,11
0,: : : Sex: F Service: MEDICINE Allergies: Peni...,32
0,: : : Sex: F Service: CARDIOTHORACIC Allergies...,9
...,...,...
0,: : Service: HISTORY OF THE PRESENT ILLNESS: M...,35
0,: : : Sex: F Service: SURGERY Allergies: Patie...,43
0,: : Service: CARDIOTHORACIC Allergies: Penicil...,12
0,: : : Sex: M Service: Neonatology HISTORY OF P...,47


In [33]:
len(train_class_labels)

50

In [6]:
# icd9_triage
data_dir = "../mimic3-icd9-data/intermediary-data/triage/"

dataset_train,train_class_labels = Mimic_ICD9_Triage_Processor().get_examples(data_dir = f"{data_dir}",
                                                                                 mode = "train")

loading train data
data path provided was: ../mimic3-icd9-data/intermediary-data/triage//train.csv


0it [00:00, ?it/s]

2022-03-18 13:26:46.101 | INFO     | __main__:get_examples:70 - Returning 9559 samples!
2022-03-18 13:26:46.102 | INFO     | __main__:get_examples:74 - generating class labels!


In [7]:
dataset_train

Unnamed: 0,text,label
0,: : : Sex: F Service: CARDIOTHORACIC Allergies...,1
0,: : : Sex: F Service: NEONATOLOGY HISTORY: wee...,4
0,: : : Sex: M Service: CARDIOTHORACIC Allergies...,1
0,: : : Sex: F Service: MEDICINE Allergies: Peni...,6
0,: : : Sex: M Service: ADMISSION DIAGNOSIS: . S...,1
...,...,...
0,: : : Sex: F Service: MEDICINE Allergies: Pati...,0
0,: : : Sex: F Service: MEDICINE Allergies: Peni...,6
0,Unit No: : : : Sex: F Service: Neonatology was...,4
0,: : Service: CARDIOTHORACIC Allergies: Penicil...,1


In [8]:
train_class_labels

['AcuteMedicine',
 'Cardiology',
 'Gastroenterology',
 'Neurology',
 'Obstetrics',
 'Oncology',
 'Respiratory']

In [38]:
# mortality

data_dir = "../clinical-outcomes-data/mimic3-clinical-outcomes/mp/"


dataset_train,train_class_labels = Mimic_Mortality_Processor().get_examples(data_dir = f"{data_dir}",
                                                                                 mode = "train", balance_data = True)

loading train data
data path provided was: ../clinical-outcomes-data/mimic3-clinical-outcomes/mp//train.csv


0it [00:00, ?it/s]

2022-03-18 14:21:42.007 | INFO     | __main__:get_examples:126 - Returning 7068 samples!


In [33]:
dataset_train.head()

Unnamed: 0,text,label
0,CHIEF COMPLAINT: Abdominal Pain\n\nPRESENT ILL...,1
0,CHIEF COMPLAINT: Dyspnea\n\nPRESENT ILLNESS: 7...,0
0,CHIEF COMPLAINT: Septic shock\n\nPRESENT ILLNE...,1
0,CHIEF COMPLAINT: Jaw pain.\n\nPRESENT ILLNESS:...,0
0,CHIEF COMPLAINT: Syncope\n\nPRESENT ILLNESS: T...,0


In [39]:
dataset_train.dtypes

text     object
label     int64
dtype: object

In [24]:
train_class_labels

['alive', 'deceased']