In [3]:
from collections import defaultdict, namedtuple
from typing import *

import numpy as np
from torch.utils.data import Dataset
from torch.utils.data.dataset import Subset
from openprompt.utils.logging import logger

from typing import Union

import os

import pandas as pd
import numpy as np
from tqdm.auto import tqdm

import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader

import pytorch_lightning as pl
from pytorch_lightning import Trainer
from pytorch_lightning.callbacks import ModelCheckpoint, EarlyStopping
from pytorch_lightning.loggers import TensorBoardLogger
from loguru import logger

from transformers import BertTokenizerFast as BertTokenizer, BertModel, AdamW, get_linear_schedule_with_warmup
from transformers import RobertaTokenizerFast as RobertaTokenizer
from transformers import AutoTokenizer, AutoModel
os.environ["TOKENIZERS_PARALLELISM"] = "false"
from sklearn.model_selection import train_test_split
from sklearn.metrics import classification_report, multilabel_confusion_matrix

from bert_classifier import MimicBertModel, MimicDataset, MimicDataModule
import argparse
from datetime import datetime
import warnings



# adapt this to work for pytorch lightning models?
class FewShotSampler(object):
    '''
    Few-shot sampler: draws a small subset of examples from a training set (and optionally a dev set),
    either a fixed total number of examples or a fixed number of examples per label.

    Adapted from OpenPrompt's ``FewShotSampler``. Here the label is read from the record key ``label_col``,
    and ``__call__`` returns pandas DataFrames.

    Args:
        num_examples_total(:obj:`int`, optional): Sampling strategy ``I``: total number of examples to draw.
        num_examples_per_label(:obj:`int`, optional): Sampling strategy ``II``: number of examples to draw for each label.
        also_sample_dev(:obj:`bool`, optional): Whether to also sample a dev set.
        num_examples_total_dev(:obj:`int`, optional): Strategy ``I`` size for the dev set.
        num_examples_per_label_dev(:obj:`int`, optional): Strategy ``II`` size for the dev set.
        label_col(:obj:`str`, optional): Key holding the label in each record. Defaults to ``"hospital_expire_flag"``.

    Exactly one of ``num_examples_total`` / ``num_examples_per_label`` must be given. When ``also_sample_dev``
    is set and neither dev size is given, the train sizes are reused for the dev set.

    Raises:
        ValueError: if both or neither train sizes are given, or both dev sizes are given.
    '''

    def __init__(self,
                 num_examples_total: Optional[int]=None,
                 num_examples_per_label: Optional[int]=None,
                 also_sample_dev: Optional[bool]=False,
                 num_examples_total_dev: Optional[int]=None,
                 num_examples_per_label_dev: Optional[int]=None,
                 label_col = "hospital_expire_flag"
                 ):
        if num_examples_total is None and num_examples_per_label is None:
            raise ValueError("num_examples_total and num_examples_per_label can't be both None.")
        elif num_examples_total is not None and num_examples_per_label is not None:
            raise ValueError("num_examples_total and num_examples_per_label can't be both set.")

        # Always define the dev sizes so _sample(sample_twice=True) never hits a missing attribute.
        self.num_examples_total_dev = None
        self.num_examples_per_label_dev = None
        if also_sample_dev:
            if num_examples_total_dev is not None and num_examples_per_label_dev is not None:
                raise ValueError("num_examples_total_dev and num_examples_per_label_dev can't be both set.")
            elif num_examples_total_dev is None and num_examples_per_label_dev is None:
                logger.warning("specify neither num_examples_total_dev nor num_examples_per_label_dev, "
                               "set to default (equal to train set setting).")
                self.num_examples_total_dev = num_examples_total
                self.num_examples_per_label_dev = num_examples_per_label
            else:
                self.num_examples_total_dev = num_examples_total_dev
                self.num_examples_per_label_dev = num_examples_per_label_dev

        self.num_examples_total = num_examples_total
        self.num_examples_per_label = num_examples_per_label
        self.also_sample_dev = also_sample_dev
        self.label_col = label_col

    def __call__(self,
                 dataset: Union[Dataset, List],
                 valid_dataset: Optional[Union[Dataset, List]] = None,
                 seed: Optional[int] = None
                ) -> Union[Dataset, List]:
        '''
        Sample a few-shot training set, and optionally a dev set.

        Args:
            dataset: Train records: a list of dicts, a pandas DataFrame, or a torch ``Dataset`` whose items are dicts.
            valid_dataset (optional): Dev records in any of the same forms. Defaults to None.
            seed (:obj:`int`, optional): Seed for the sampling RNG.

        Returns:
            - ``valid_dataset`` is None and ``also_sample_dev`` is False: a DataFrame of the sampled train records.
            - ``valid_dataset`` is None and ``also_sample_dev`` is True: ``(train_df, dev_df)``. The dev sample
              is drawn from the train records that were not selected for training.
            - ``valid_dataset`` is given: ``(train_df, valid_df)``. ``valid_df`` is sampled with the dev sizes
              when ``also_sample_dev`` is True; otherwise it is ``valid_dataset`` unsampled, as a DataFrame.
        '''
        dataset = self._as_records(dataset)
        if valid_dataset is None:
            if self.also_sample_dev:
                train, dev = self._sample(dataset, seed, sample_twice=True)
                return self._to_frame(train), self._to_frame(dev)
            return self._to_frame(self._sample(dataset, seed, sample_twice=False))

        train_df = self._to_frame(self._sample(dataset, seed))
        if self.also_sample_dev:
            # Previously the sampled dev set was dropped and the train sizes were used for it.
            valid_dataset = self._sample(self._as_records(valid_dataset), seed, use_dev_setting=True)
        return train_df, self._to_frame(valid_dataset)

    def _sample(self,
                data: Union[Dataset, List],
                seed: Optional[int],
                sample_twice = False,
                use_dev_setting = False,
               ) -> Union[Dataset, List]:
        '''
        Draw a few-shot sample from positionally indexable ``data``.

        Args:
            data: A list of dict records or a torch ``Dataset`` of dicts.
            seed: Seed for a fresh ``RandomState``; ``None`` seeds from OS entropy.
            sample_twice: If True, also draw a dev sample from the positions not chosen for training.
            use_dev_setting: If True, size the (single) sample with the ``*_dev`` settings instead of the train ones.

        Returns:
            A ``Subset`` when ``data`` is a torch ``Dataset``, otherwise a list of records. Returns a
            ``(train, dev)`` pair of those when ``sample_twice`` is True.
        '''
        # A fresh generator per call keeps results reproducible for a given seed
        # (RandomState(None) is the same as RandomState()).
        self.rng = np.random.RandomState(seed)

        if use_dev_setting:
            per_label, total = self.num_examples_per_label_dev, self.num_examples_total_dev
        else:
            per_label, total = self.num_examples_per_label, self.num_examples_total
        selected_ids = self._select(data, list(range(len(data))), per_label, total)

        if not sample_twice:
            return self._subset(data, selected_ids)

        # Dev examples come only from positions not already used for training.
        selected_set = set(selected_ids)
        remain_ids = [i for i in range(len(data)) if i not in selected_set]
        selected_ids_dev = self._select(data, remain_ids,
                                        self.num_examples_per_label_dev, self.num_examples_total_dev)
        return self._subset(data, selected_ids), self._subset(data, selected_ids_dev)

    def _select(self, data, indices: List, num_per_label, num_total):
        '''Choose positions from ``indices``: per-label sampling when ``num_per_label`` is set, otherwise total sampling.'''
        if num_per_label is None:
            return self.sample_total(indices, num_total)
        # Explicit check instead of assert so validation survives `python -O`; empty data has nothing to check.
        if len(data) > 0 and self.label_col not in data[0].keys():
            raise ValueError("sampling by label requires every record to have the '{}' key.".format(self.label_col))
        labels = [data[i][self.label_col] for i in indices]
        return self.sample_per_label(indices, labels, num_per_label)

    @staticmethod
    def _subset(data, ids: List):
        '''Select ``ids`` from ``data``: a torch ``Subset`` for Datasets, a list of records otherwise.'''
        if isinstance(data, Dataset):
            return Subset(data, ids)
        return [data[i] for i in ids]

    @staticmethod
    def _as_records(data):
        '''Convert a DataFrame into a list of row dicts so it can be indexed by position; pass other inputs through.'''
        if isinstance(data, pd.DataFrame):
            return data.to_dict(orient="records")
        return data

    @staticmethod
    def _to_frame(data):
        '''Materialize records (list of dicts, torch ``Subset``/``Dataset``, or DataFrame) as a DataFrame.'''
        if isinstance(data, pd.DataFrame):
            return data
        if isinstance(data, list):
            return pd.DataFrame(data)
        # Torch datasets are not list-like to pandas, so index them explicitly.
        return pd.DataFrame([data[i] for i in range(len(data))])

    def sample_total(self, indices: List, num_examples_total):
        '''
        Use the total number of examples for few-shot sampling (Strategy ``I``).

        Args:
            indices(:obj:`List`): Candidate positions into the dataset. The list is shuffled in place.
            num_examples_total(:obj:`int`): Number of examples to keep (``None`` keeps all).

        Returns:
            :obj:`List`: At most ``num_examples_total`` randomly chosen positions.
        '''
        self.rng.shuffle(indices)
        selected_ids = indices[:num_examples_total]
        # Log the count at info level; the full id list can be thousands of entries long.
        logger.info("Selected {} examples (mixed)".format(len(selected_ids)))
        logger.debug("Selected examples (mixed) {}".format(selected_ids))
        return selected_ids

    def sample_per_label(self, indices: List, labels, num_examples_per_label):
        '''
        Use the number of examples per class for few-shot sampling (Strategy ``II``).
        A label with fewer than ``num_examples_per_label`` examples contributes all of them, and a warning is logged.

        Args:
            indices(:obj:`List`): Candidate positions into the dataset.
            labels(:obj:`List`): Label of each entry of ``indices``, in the same order.
            num_examples_per_label(:obj:`int`): Maximum number of examples to keep for each label.

        Returns:
            :obj:`List`: Shuffled selected positions, with at most ``num_examples_per_label`` per distinct label.
        '''
        ids_per_label = defaultdict(list)
        for idx, label in zip(indices, labels):
            ids_per_label[label].append(idx)

        selected_ids = []
        for label, ids in ids_per_label.items():
            tmp = np.array(ids)
            self.rng.shuffle(tmp)
            if len(tmp) < num_examples_per_label:
                logger.warning("Not enough examples of label {} can be sampled".format(label))
            selected_ids.extend(tmp[:num_examples_per_label].tolist())
        # Shuffle across labels so the result is not grouped by class.
        selected_ids = np.array(selected_ids)
        self.rng.shuffle(selected_ids)
        selected_ids = selected_ids.tolist()
        logger.info("Selected {} examples across {} labels".format(len(selected_ids), len(ids_per_label)))
        logger.debug("Selected examples {}".format(selected_ids))
        return selected_ids

In [2]:
def read_csv(data_dir, filename, **kwargs):
    '''
    Load ``filename`` from ``data_dir`` into a DataFrame.

    The path is built with ``os.path.join`` rather than string concatenation, so ``data_dir`` works with
    or without a trailing separator. Extra keyword arguments are forwarded to ``pd.read_csv`` (e.g.
    ``dtype={"label": str}`` to keep label codes such as ``"0389"`` as text). ``index_col`` defaults to
    ``None``, as before.
    '''
    kwargs.setdefault("index_col", None)
    return pd.read_csv(os.path.join(data_dir, filename), **kwargs)

In [4]:
# Load the pre-split MIMIC-III top-50 ICD-9 train/valid/test CSVs; few-shot sampling is applied to the
# training split further down before anything is passed to the model pipeline.
# (Alternative source: "../clinical-outcomes-data/mimic3-clinical-outcomes/mp/")
data_dir = "../mimic3-icd9-data/intermediary-data/top_50_icd9/"

train_df = read_csv(data_dir, "train.csv")
val_df = read_csv(data_dir, "valid.csv")
test_df = read_csv(data_dir, "test.csv")

logger.warning("train_df shape: {} and train_df cols:{}".format(train_df.shape, train_df.columns))





In [5]:
train_df.head(n=5)

Unnamed: 0,text,label
0,: : : Sex: F Service: CARDIOTHORACIC Allergies...,4240
1,: : : Sex: F Service: NEONATOLOGY HISTORY: wee...,V3001
2,: : : Sex: M Service: CARDIOTHORACIC Allergies...,41041
3,: : : Sex: F Service: MEDICINE Allergies: Peni...,51881
4,: : : Sex: F Service: CARDIOTHORACIC Allergies...,3962


In [6]:
# One dict per row (column name -> value), so the sampler can index records by position.
train_dict = train_df.to_dict("records")

In [7]:
next(iter(train_dict))

{'text': ": : : Sex: F Service: CARDIOTHORACIC Allergies: Patient recorded as having No Known Allergies to Drugs : Chief Complaint: SOB with exertion, heart murmur since y/o Major Surgical or Invasive Procedure: Mitral valve replacement(mm CE tissue History of Present Illness: y/o female with known MVP who was diagnosed with a heart murmur at age . She was evaluated with serial TTE's which showed worsening MR. Echo showed LVEF % with Mitral valve regurgitant fraction of %. She denies any symptoms. Past Medical History: Hyperlipidemia, MVP/MR, Depression, Obesity Social History: social Etoh, live with mother, deniesDA or tobacco use Family History: noncontributory Physical Exam: y/o F in bed NAD Neuro AA&Ox, nonfocal Chest CTAB resp unlab median sternotomy stable, c/d/i no d/c, RRR no m/r/g chest tubes and epicardial wires removed. Abd S/NT/ND/BS+ EXT warm with trace edema Pertinent Results: RADIOLOGY Preliminary Report CHEST (PA & LAT : AM CHEST (PA & LAT Reason: assess LLL atelectasis

In [13]:
# Draw up to 500 training records per label (seeded), mirroring how OpenPrompt samples its datasets.
support_sampler = FewShotSampler(num_examples_per_label=500, also_sample_dev=False, label_col="label")
dataset = {'few_train': support_sampler(train_dict, seed=1)}



2022-03-21 15:18:24.629 | INFO     | __main__:sample_per_label:190 - Not enough examples of label 4240 can be sampled
2022-03-21 15:18:24.631 | INFO     | __main__:sample_per_label:190 - Not enough examples of label 41041 can be sampled
2022-03-21 15:18:24.633 | INFO     | __main__:sample_per_label:190 - Not enough examples of label 3962 can be sampled
2022-03-21 15:18:24.634 | INFO     | __main__:sample_per_label:190 - Not enough examples of label 51884 can be sampled
2022-03-21 15:18:24.635 | INFO     | __main__:sample_per_label:190 - Not enough examples of label 430 can be sampled
2022-03-21 15:18:24.636 | INFO     | __main__:sample_per_label:190 - Not enough examples of label 4280 can be sampled
2022-03-21 15:18:24.638 | INFO     | __main__:sample_per_label:190 - Not enough examples of label 42823 can be sampled
2022-03-21 15:18:24.639 | INFO     | __main__:sample_per_label:190 - Not enough examples of label V3401 can be sampled
2022-03-21 15:18:24.639 | INFO     | __main__:sample_

In [14]:
dataset["few_train"]["label"].value_counts(sort=True)

41071    500
51881    500
0389     500
4241     500
V3001    500
41401    500
V3000    500
V3101    482
431      460
4240     323
486      312
5070     294
430      284
4280     283
41011    281
41041    277
5789     224
5849     213
1983     197
5770     193
99662    180
43411    177
99859    177
42731    165
03842    160
56212    157
43491    144
4373     137
51884    135
V3401    135
5712     131
4271     126
85221    123
42823    122
03811    120
03849    119
41519    117
4321     116
4414     115
0380     109
53240    108
99811    104
1623     103
3962     102
5715     101
042       99
43310     99
44101     93
29181     92
5761      92
Name: label, dtype: int64

In [None]:
# Push the data through the Lightning pipeline. The training split is the few-shot sample drawn above
# (dataset['few_train']); the point of the sampling cell is to train on it rather than on the full train_df.
# NOTE(review): `tokenizer`, `batch_size`, `max_tokens` and `args` are not defined anywhere in this notebook
# (the cell appears to be copied from an argparse-driven script). It raises NameError on a fresh kernel
# until they are defined in a config cell above.
data_module = MimicDataModule(
    dataset['few_train'],
    val_df,
    test_df,
    tokenizer,
    batch_size=batch_size,
    max_token_len=max_tokens,
    label_col = args.label_col,
    num_workers=args.loader_workers,
)