# Imports

In [None]:
# General includes.
import os
import io
import re
import errno
import gc
import random
import threading
import math
import itertools
import functools
from copy import deepcopy
import logging
import pickle
import tqdm
import hashlib

#from termcolor import colored, cprint
import colored
from datetime import datetime, timedelta
from pandarallel import pandarallel
pandarallel.initialize(progress_bar=True)
from matplotlib import pyplot as plt

# Typing includes.
from typing import Dict, List, Optional, Any, Tuple, Callable, Iterable

# Numerical includes.
import numpy as np
import pandas as pd
import torch
from sklearn.metrics import roc_auc_score, average_precision_score, accuracy_score
from sklearn.metrics import precision_recall_fscore_support, roc_auc_score, roc_curve, precision_recall_curve
from torch import nn
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms, utils

# pyHealth includes.
from pyhealth.datasets import BaseDataset, MIMIC3Dataset, eICUDataset, SampleDataset, split_by_patient
from pyhealth.datasets.utils import MODULE_CACHE_PATH, strptime, hash_str
from pyhealth.data import Patient, Visit, Event

In [None]:
# Model imports 
from transformers import BertTokenizer, BertModel, BertForMaskedLM, BertConfig, BertTokenizerFast
from transformers import TensorType
from transformers import AutoTokenizer, AutoConfig, AutoModel

In [None]:
# Local imports
from tasks.code_emb_funcs import *
from tasks.desc_emb_funcs import *
from tasks.eicu_funcs import *
from tasks.dataset_transforms import *
from tasks.collate_funcs import *
# from tasks import code_emb_funcs, desc_emb_funcs, eicu_funcs
# from tasks import dataset_transforms, collate_funcs
from trainlib import Trainer
from models import ModelA, EHRModel, CembRNN, DembRNN
from datasets import SimpleDataset, DatasetCacher

In [None]:
# https://stackoverflow.com/questions/5364050/reloading-submodules-in-ipython
%load_ext autoreload
%autoreload 2

# Globals

In [None]:
# ---- Logging -----------------------------------------------------------
# INFO-level root logging plus a logger for this notebook/module.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# ---- Device selection --------------------------------------------------
GPU_STR_ = 'cuda'
USE_GPU_ = False
BERT_USE_GPU_ = True  # Device flag used for the BERT embeddings.

# ---- Data locations ----------------------------------------------------
# DEV_ asks pyHealth to read only a small subset of MIMIC; see
# https://pyhealth.readthedocs.io/en/latest/api/datasets/pyhealth.datasets.MIMIC3Dataset.html#pyhealth.datasets.MIMIC3Dataset
DEV_ = True
# DATA_DIR_ = os.path.join(os.getcwd(), DATA_DIR_)
MIMIC_DATA_DIR_ = '~/sw/physionet.org/files/mimiciii/1.4'
EICU_DATA_DIR_ = '~/sw/eicu-collaborative-research-database-2.0/eicu-collaborative-research-database-2.0'

# ---- Training hyper-parameters -----------------------------------------
BATCH_SIZE_ = 32
# BERT requires a multiple of 12 (264 == 12 * 22).
EMBEDDING_DIM_ = 264
SHUFFLE_ = True
SAMPLE_MULTIPLIER_ = 1

# ---- Reproducibility ---------------------------------------------------
# Seed every RNG the notebook uses with the same value.
seed = 90210
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)
os.environ["PYTHONHASHSEED"] = str(seed)

def load_pickle(filename):
    """Read `filename` in binary mode and return the object unpickled from it.

    NOTE(review): unpickling can execute arbitrary code; only use this on
    files this project wrote itself.
    """
    with open(filename, "rb") as stream:
        payload = pickle.load(stream)
    return payload


def save_pickle(data, filename):
    """Pickle `data` into `filename`, creating or overwriting the file."""
    with open(filename, "wb") as sink:
        pickle.Pickler(sink).dump(data)
        
# https://stackoverflow.com/questions/50888391/pickle-of-object-with-getattr-method-in-python-returns-typeerror-object-no
class DotArgs(dict):
    """
    A dict whose items can also be read, written and deleted as attributes.

    Reading a missing attribute returns None (it is a plain `dict.get`),
    and deleting a missing attribute raises KeyError (it is a plain
    `dict.__delitem__`).
    """

    def __getattr__(self, name):
        # Only reached when normal attribute lookup fails.
        return dict.get(self, name)

    def __setattr__(self, name, value):
        # Attribute writes land in the dict items, never in __dict__.
        dict.__setitem__(self, name, value)

    def __delattr__(self, name):
        dict.__delitem__(self, name)

    # Explicit pickle hooks: without them the permissive __getattr__ above
    # can answer the hook lookup with None (see the linked StackOverflow
    # question above this class).
    def __getstate__(self):
        return self.__dict__

    def __setstate__(self, state):
        self.__dict__.update(state)


# Preprocessing

### Load MIMIC III Data

In [None]:
if False:
    # Disabled scratch cell: pokes at pyHealth's ICD9CM code map to see what
    # `lookup`, `get_ancestors` and `standardize` return for a few codes.
    from pyhealth.medcode import InnerMap, ICD9CM

    icd9cm = InnerMap.load("ICD9CM")
    icd9cm.lookup("428.0") # get detailed info
    icd9cm.get_ancestors("428.0") # get parents

    # Print the details and the ancestors of the *same* code each time.
    # (The original looked up "78942" but printed the ancestors of "78941".)
    for code in ("78951", "7895", "7894", "78942"):
        print(icd9cm.lookup(code)) # get detailed info
        print(f'{code} ancestors {icd9cm.get_ancestors(code)}') # get parents

    print(ICD9CM.standardize('78951'))
    print(ICD9CM.standardize('7895'))

In [None]:
def _compute_duration_minutes(start_datetime: str, end_datetime: str) -> float:
    '''Return duration in minutes as a float.
    '''
    # MIMIC-III uses the following format: 2146-07-22 00:00:00
    start = datetime.strptime(start_datetime, '%Y-%m-%d %H:%M:%S')
    end = datetime.strptime(end_datetime,   '%Y-%m-%d %H:%M:%S')
    return float((end - start).seconds)

class MIMIC3DatasetWrapper(MIMIC3Dataset):
    ''' Add extra tables to the MIMIC III dataset.

      Some of the tables we need like "D_ICD_DIAGNOSES", "D_ITEMS",
      "D_ICD_PROCEDURES" and "D_LABITEMS" are not supported out of the box.

      This class defines `parse_<table>()` methods that read the text columns
      of these dictionary tables into `self._text_descriptions`, and overrides
      `parse_prescriptions()` so that each prescription Event also carries the
      free-text drug columns.
    '''

    # We need to add storage for text-based lookup tables here.
    def __init__(self, *args, **kwargs):
        """Create the text-table storage, then build/load the base dataset.

        Args:
            *args, **kwargs: forwarded unchanged to `MIMIC3Dataset`. Only the
                `refresh_cache` keyword is additionally remembered here.
        """
        # These attributes are set BEFORE super().__init__() because the
        # parse_*() methods below write into them (presumably the base
        # constructor is what triggers the parsing -- confirm in pyHealth).
        self._valid_text_tables = ["D_ICD_DIAGNOSES", "D_ITEMS", "D_ICD_PROCEDURES", "D_LABITEMS"]
        self._text_descriptions = {x: {} for x in self._valid_text_tables}
        self._text_luts = {x: {} for x in self._valid_text_tables}
        # The pyHealth dataset cache doesn't know about this class's private members,
        # so remember the flag for our own cache in `_do_cache()`.
        self.refresh_cache = kwargs.get('refresh_cache', False)
        super().__init__(*args, **kwargs)
        self._do_cache()

    def _do_cache(self):
        '''Save or restore this class's extra lookup tables.

          The pyHealth dataset cache doesn't know about this class's private
          members. We keep a second pickle next to the base cache file so the
          additional luts are saved/restored too. The superclass is still
          responsible for saving/restoring `self.patients`.

          NOTE(review): if the superclass restored itself from its own cache
          (so no parse_*() ran) while the extras file is missing, the empty
          tables are what gets written here -- confirm and pass
          `refresh_cache=True` in that situation.
        '''
        # `filepath`, `dev` and `dataset_name` are not set in this class;
        # presumably they come from the pyHealth base dataset.
        self.extra_filepath = ''.join([
            os.path.splitext(self.filepath)[0],
            '_dev' if self.dev else '',
            '_extras.pkl',
        ])

        # Use the cache only if it exists and no refresh was requested.
        if os.path.exists(self.extra_filepath) and (not self.refresh_cache):
            from_pickle = load_pickle(self.extra_filepath)
            self._valid_text_tables = from_pickle['_valid_text_tables']
            self._text_descriptions = from_pickle['_text_descriptions']
            self._text_luts = from_pickle['_text_luts']
            # Log only once the load has actually succeeded.
            logger.info(
                f"Loaded {self.dataset_name} base dataset from {self.extra_filepath}"
            )
        else:
            # Persist whatever the parse_*() methods collected.
            logger.info(f"Processing {self.dataset_name} base dataset...")
            to_cache = {
                '_valid_text_tables': self._valid_text_tables,
                '_text_descriptions': self._text_descriptions,
                '_text_luts': self._text_luts,
            }
            save_pickle(to_cache, self.extra_filepath)
            # Log only once the save has actually succeeded.
            logger.info(f"Saved {self.dataset_name} base dataset to {self.extra_filepath}")

    def get_all_tables(self) -> List[str]:
        """Return the names of the text tables tracked by this wrapper."""
        return list(self._text_descriptions.keys())

    def get_text_dict(self, table_name: str) -> Optional[Dict[str, Dict[Any, Any]]]:
        """Return the parsed table for `table_name`, or None if it is unknown.

        A parsed table is the `DataFrame.to_dict(orient='split')` result
        stored by the parse_d_*() methods; a table that was never parsed is
        an empty dict.
        """
        return self._text_descriptions.get(table_name)

    def set_text_lut(self, table_name: str, lut: Dict[Any, Any]) -> None:
        """Store a caller-built lookup table under `table_name`."""
        self._text_luts[table_name] = lut

    def get_text_lut(self, table_name: str) -> Dict[Any, Any]:
        """Return the lookup table for `table_name` (KeyError if unknown)."""
        return self._text_luts[table_name]

    def _add_events_to_patient_dict(
        self,
        patient_dict: Dict[str, Patient],
        group_df: pd.DataFrame,
    ) -> Dict[str, Patient]:
        #TODO(botelho3) Imported from the PyHealth base dataset on GitHub to
        #support parse_prescriptions
        """Helper function which adds the events column of a df.groupby object to the patient dict.

        Will be called at the end of each `self.parse_[table_name]()` function.
        Args:
            patient_dict: a dict mapping patient_id to `Patient` object.
            group_df: a df.groupby object, having two columns: patient_id and events.
                - the patient_id column is the index of the patient
                - the events column is a list of <Event> objects
        Returns:
            The updated patient dict.
        """
        for _, events in group_df.items():
            for event in events:
                # `_add_event_to_patient_dict` is not defined here; presumably
                # it is provided by the pyHealth base dataset.
                patient_dict = self._add_event_to_patient_dict(patient_dict, event)
        return patient_dict

    def parse_prescriptions(self, patients: Dict[str, Patient]) -> Dict[str, Patient]:
        """Helper function which parses PRESCRIPTIONS table.

        TODO(botelho3) - we have to override this to include the text fields. The
        prescriptions table does not link to a separate D_ICD_* table in MIMIC-III
        that contains text descriptions of the prescription. The text descriptions
        are in the columns of this table. Regular pyHealth ignores these columns. We
        override this method to append pyHealth Event objects containing the text
        columns to each patient.

        Will be called in `self.parse_tables()`
        Docs:
            - PRESCRIPTIONS: https://mimic.mit.edu/docs/iii/tables/prescriptions/
        Args:
            patients: a dict of `Patient` objects indexed by patient_id.
        Returns:
            The updated patients dict.
        """
        table = "PRESCRIPTIONS"
        # Used when a start/end date is not a string.
        placeholder_date = '2142-07-18 00:00:00'
        # read table
        df = pd.read_csv(
            os.path.join(self.root, f"{table}.csv"),
            low_memory=False,
            dtype={"SUBJECT_ID": str, "HADM_ID": str, "NDC": str,
                   "DRUG_TYPE": str, "DRUG": str,
                   "PROD_STRENGTH": str, "ROUTE": str, "ENDDATE": str},
        )
        # drop records of the other patients
        df = df[df["SUBJECT_ID"].isin(patients.keys())]
        # NOTE(review): this drops a row when ANY column is missing, including
        # columns that are never read below -- confirm that is intended.
        df = df.dropna()
        # sort by start date and end date
        df = df.sort_values(
            ["SUBJECT_ID", "HADM_ID", "STARTDATE", "ENDDATE"], ascending=True
        )
        # group by patient
        group_df = df.groupby("SUBJECT_ID")

        # parallel unit for prescription (per patient). Keep `self` out of
        # this closure: it is handed to `parallel_apply`.
        def prescription_unit(p_id, p_info):
            events = []
            for v_id, v_info in p_info.groupby("HADM_ID"):
                zipped = zip(v_info["STARTDATE"], v_info["NDC"], v_info["DRUG_TYPE"],
                             v_info["DRUG"], v_info["PROD_STRENGTH"], v_info["ROUTE"],
                             v_info["ENDDATE"])
                for startdate, code, dtype, dname, dose, route, enddate in zipped:
                    if not isinstance(startdate, str):
                        startdate = placeholder_date
                    if not isinstance(enddate, str):
                        enddate = placeholder_date
                    assert isinstance(dname, str)
                    event = Event(
                        code=code,
                        table=table,
                        vocabulary="NDC",
                        visit_id=v_id,
                        patient_id=p_id,
                        timestamp=strptime(startdate),
                        dtype=dtype,
                        dname=dname,
                        dose=dose,
                        route=route,
                        duration=_compute_duration_minutes(startdate, enddate),
                    )
                    events.append(event)
            return events

        # parallel apply
        group_df = group_df.parallel_apply(
            lambda x: prescription_unit(x.SUBJECT_ID.unique()[0], x)
        )

        patients = self._add_events_to_patient_dict(patients, group_df)
        return patients

    def _parse_text_table(self, table: str, columns: List[str], sort_column: str) -> None:
        """Read the text columns of one dictionary table into `self._text_descriptions`.

        Args:
            table: table name; `<table>.csv` is read from `self.root`.
            columns: columns to load, all as strings. Rows with a missing
                value in any of them are dropped.
            sort_column: column the rows are sorted by (ascending).
        """
        print(f"Parsing {table}")
        assert table in self._valid_text_tables

        # read table
        df = pd.read_csv(
            os.path.join(self.root, f"{table}.csv"),
            usecols=columns,
            dtype={column: str for column in columns},
        )

        # drop rows with missing values
        df = df.dropna(subset=columns)
        # sort by the code/id column
        df = df.sort_values([sort_column], ascending=True)

        # 'split' orientation keeps the rows under the 'data' key.
        self._text_descriptions[table] = df.reset_index(drop=True).to_dict(orient='split')

    # Note the name has to match the table name exactly.
    # See https://github.com/sunlabuiuc/PyHealth/blob/master/pyhealth/datasets/mimic3.py#L71.
    def parse_d_icd_diagnoses(self, patients: Dict[str, Patient]) -> Dict[str, Patient]:
        """Helper function which parses D_ICD_DIAGNOSES table.
        Will be called in `self.parse_tables()`
        Docs:
            - D_ICD_DIAGNOSES: https://mimic.mit.edu/docs/iii/tables/d_icd_diagnoses/
        Args:
            patients: a dict of `Patient` objects indexed by patient_id.
        Returns:
            The unchanged patients dict.
        Note:
            This function doesn't update the patients dict like other parse_*() functions.
            Here we read the D_ICD_DIAGNOSES.csv file containing ICD9_CODE -> text mappings
            and store them in a dict `self._text_descriptions[table]`.

            The dict is used as a ICD9_CODE -> text diagnosis description lookup
            for DescEmb.
        """
        self._parse_text_table(
            "D_ICD_DIAGNOSES", ["ICD9_CODE", "SHORT_TITLE", "LONG_TITLE"], "ICD9_CODE"
        )
        # We haven't altered the patients dict, just return it.
        return patients

    def parse_d_labitems(self, patients: Dict[str, Patient]) -> Dict[str, Patient]:
        """Helper function which parses D_LABITEMS table.
        Will be called in `self.parse_tables()`
        Docs:
            - D_LABITEMS: https://mimic.mit.edu/docs/iii/tables/d_labitems/
        Args:
            patients: a dict of `Patient` objects indexed by patient_id.
        Returns:
            The unchanged patients dict.
        Note:
            This function doesn't update the patients dict like other parse_*() functions.
            Here we read the D_LABITEMS.csv file containing ITEMID -> text mappings
            and store them in a dict `self._text_descriptions[table]`.

            The dict is used as a ITEMID -> text lab measurement description lookup
            for DescEmb.
        """
        self._parse_text_table(
            "D_LABITEMS", ["ITEMID", "LABEL", "CATEGORY", "FLUID"], "ITEMID"
        )
        # We haven't altered the patients dict, just return it.
        return patients

    def parse_d_items(self, patients: Dict[str, Patient]) -> Dict[str, Patient]:
        # TODO(botelho3) - Note this may not be totally usable because the ITEMID
        # unique key only links to these tables using ITEMID
        #   - INPUTEVENTS_MV
        #   - OUTPUTEVENTS on ITEMID
        #   - PROCEDUREEVENTS_MV on ITEMID
        #
        # Not to the tables we want.
        """Helper function which parses D_ITEMS table.
        Will be called in `self.parse_tables()`
        Docs:
            - D_ITEMS: https://mimic.mit.edu/docs/iii/tables/d_items/
        Args:
            patients: a dict of `Patient` objects indexed by patient_id.
        Returns:
            The unchanged patients dict.
        Note:
            This function doesn't update the patients dict like other parse_*() functions.
            Here we read the D_ITEMS.csv file containing ITEMID -> text mappings
            and store them in a dict `self._text_descriptions[table]`.

            The dict is used as a ITEMID -> text inputs/output/procedure events lookup
            for DescEmb.
        """
        self._parse_text_table("D_ITEMS", ["ITEMID", "LABEL", "CATEGORY"], "ITEMID")
        # We haven't altered the patients dict, just return it.
        return patients

    def parse_d_icd_procedures(self, patients: Dict[str, Patient]) -> Dict[str, Patient]:
        """Helper function which parses D_ICD_PROCEDURES table.

        Will be called in `self.parse_tables()`
        Docs:
            - D_ICD_PROCEDURES: https://mimic.mit.edu/docs/iii/tables/d_icd_procedures/
        Args:
            patients: a dict of `Patient` objects indexed by patient_id.
        Returns:
            The unchanged patients dict.
        Note:
            This function doesn't update the patients dict like other parse_*() functions.
            Here we read the D_ICD_PROCEDURES.csv file containing ICD9_CODE -> text mappings
            and store them in a dict `self._text_descriptions[table]`.

            The dict is used as a ICD9_CODE -> text procedure description lookup for DescEmb.
        """
        self._parse_text_table(
            "D_ICD_PROCEDURES", ["ICD9_CODE", "SHORT_TITLE", "LONG_TITLE"], "ICD9_CODE"
        )
        # We haven't altered the patients dict, just return it.
        return patients
    

In [None]:
# Disabled cell (flip `False` to `True` to run): builds the MIMIC-III dataset
# through MIMIC3DatasetWrapper and prints its statistics.
if False:
    print(f'Reading data from: `{MIMIC_DATA_DIR_}`')
    # Drop the dataset left over from a previous run of this cell and collect
    # garbage before building a new one.
    if 'mimic3base' in globals():
        del mimic3base 
    gc.collect()

    mimic3base = MIMIC3DatasetWrapper(
        # root="https://storage.googleapis.com/pyhealth/Synthetic_MIMIC-III/",
        root=MIMIC_DATA_DIR_,
        dataset_name='mimic_3_dataset',
        # The D_* entries are the text dictionary tables handled by the
        # wrapper's parse_d_*() methods.
        tables=["D_ICD_DIAGNOSES", "D_ICD_PROCEDURES", "D_ITEMS", "D_LABITEMS",
                "DIAGNOSES_ICD", "PROCEDURES_ICD", "PRESCRIPTIONS"], # "LABEVENTS"],
        # map all NDC codes to ATC 3-rd level codes in these tables
        # See https://en.wikipedia.org/wiki/Anatomical_Therapeutic_Chemical_Classification_System.
        code_mapping={"NDC": ("ATC", {"target_kwargs": {"level": 3}})},
        # Reads a subset of the data. Disable for full training run.
        dev = DEV_,
        # True = Slow, rebuilds the dataset instead of caching.
        refresh_cache=False,
    )

    mimic3base.stat()
    mimic3base.info()
    # The rest of this cell is commented-out exploration code: it dumps the
    # parsed text tables and builds the {code -> text columns} lookup tables.
    # table_names = mimic3base.get_all_tables()
    # print(table_names)

    # print('\033[92m' '====Tables====\n' '\033[0m')
    # # print(colored('====Tables====\n', 'green'))
    # # print(colored.fg('green') + '====Tables====\n')
    # for t in table_names:
    #     d = mimic3base.get_text_dict(t)
    #     print(f"Table: {t}")
    #     print(d['data'][:5])
    #     print('\n\n')

    # # Take the cached tables from the parse_tables function and build the {ICD9 -> (short_name, long_name)}
    # # lookup tables.
    # for t in table_names:
    #     d = mimic3base.get_text_dict(t)
    #     d = d['data']
    #     lut = {record[0]: record[1:] for record in d}
    #     mimic3base.set_text_lut(t,  lut)

    # print('\033[92m' '====Luts====\n' '\033[0m')
    # # print(f'{colored.fg("green")} ====Luts====\n')
    # for t in table_names:
    #     d = mimic3base.get_text_lut(t)
    #     print(f"Lut {t}:\n{dict(itertools.islice(d.items(), 2))}")


### Load eICU Data

In [None]:
if True:
    # Toggle cell: build the eICU dataset and print its statistics.
    print(f'Reading data from: `{EICU_DATA_DIR_}`')

    # Forget the dataset from a previous run of this cell (if there is one)
    # and collect garbage before building a new one.
    globals().pop('eicubase', None)
    gc.collect()

    eicubase = eICUDataset(
        root=EICU_DATA_DIR_,
        dataset_name='eicu_dataset',
        tables=["diagnosis", "treatment", "medication"],
        # NDC codes are mapped to 3rd-level ATC codes, see
        # https://en.wikipedia.org/wiki/Anatomical_Therapeutic_Chemical_Classification_System.
        code_mapping={"NDC": ("ATC", {"target_kwargs": {"level": 3}})},
        # dev=True would read only a subset of the data.
        dev=False,
        # refresh_cache=True is slow: it rebuilds the dataset instead of
        # using the cache.
        refresh_cache=False,
    )

    eicubase.stat()
    eicubase.info()

### Plot Utils

In [None]:
# Names of the embedding-model variants.
embed_model_types = [
    'code_emb',
    'desc_emb',
    'desc_emb_ft',
]
# Names of the prediction-model variants (same three names, separate list).
predict_model_types = [
    'code_emb',
    'desc_emb',
    'desc_emb_ft',
]
# Short names of the prediction tasks.
task_types = [
    'mort',
    'readm',
]

def _plot_roc_and_pr_curves(results, roc_title: str, auprc_label: str,
                            emb_type: str, task_type: str):
    '''Show the ROC and PR curves of one result tuple and print its AUPRC.'''
    # `auc` is not imported at the top of this file, so import it here.
    from sklearn.metrics import auc

    # Result layout: (p, r, f, roc_auc, rcurve, precision_curve, recall_curve, acc).
    _, _, _, _, rcurve, precision_curve, recall_curve, _ = results

    # rcurve[0] is plotted on the FPR axis and rcurve[1] on the TPR axis.
    plt.plot(rcurve[0], rcurve[1])
    plt.title(' '.join([roc_title, emb_type, 'for task:', task_type]))
    plt.xlabel('FPR')
    plt.ylabel('TPR')
    plt.show()

    plt.plot(recall_curve, precision_curve)
    plt.title(' '.join(['PR Curve', emb_type, task_type]))
    plt.xlabel('Recall')
    plt.ylabel('Precision')
    plt.show()
    print('{} AUPRC:  {:.2}'.format(auprc_label, auc(recall_curve, precision_curve)))


def PlotAucRecallResults(val_results, test_results, emb_type: str, task_type: str):
    ''' Plot AUC-ROC curve and P-R curve.

        val_results: 8-tuple of validation-set results, or a falsy value to
            skip the validation plots.
        test_results: 8-tuple of test-set results, or a falsy value to skip
            the test plots.
        emb_type: embedding name, used in the plot titles.
        task_type: task name, used in the plot titles.
    '''
    if val_results:
        _plot_roc_and_pr_curves(
            val_results, 'ROC Curve / AUC Val Set', 'Val', emb_type, task_type)

    if test_results:
        _plot_roc_and_pr_curves(
            test_results, 'ROC Curve / AUC Test Set', 'Test', emb_type, task_type)


    
def PlotDiffResults(results_a, results_b, task_type: str, labels: List[str]):
    '''Overlay the final-epoch ROC and PR curves of two runs.

        results_a, results_b: indexable as results[split][epoch], where split
            0 is validation and split 1 is test; each entry is the 8-tuple
            (p, r, f, roc_auc, rcurve, precision_curve, recall_curve, acc).
        task_type: task name, used in the plot titles.
        labels: legend prefixes; labels[0] is for run a, labels[1] for run b.
    '''
    for split_idx, split_name in enumerate(('validation', 'test')):
        # Only the last epoch of each run is drawn.
        _, _, _, _, roc_a, prec_a, rec_a, _ = results_a[split_idx][-1]
        _, _, _, _, roc_b, prec_b, rec_b, _ = results_b[split_idx][-1]
        legend_a = labels[0] + '_' + split_name
        legend_b = labels[1] + '_' + split_name

        # ROC curves of both runs in one figure.
        plt.plot(roc_a[0], roc_a[1], label=legend_a)
        plt.plot(roc_b[0], roc_b[1], label=legend_b)
        plt.title(' '.join(['ROC Curve / AUC', 'for task:', task_type]))
        plt.xlabel('FPR')
        plt.ylabel('TPR')
        plt.legend()
        plt.show()

        # PR curves of both runs in one figure.
        plt.plot(rec_a, prec_a, label=legend_a)
        plt.plot(rec_b, prec_b, label=legend_b)
        plt.title(' '.join(['PR Curve', 'for task:', task_type]))
        plt.xlabel('Recall')
        plt.ylabel('Precision')
        plt.legend()
        plt.show()

        
def PlotPrecisionRecallAcrossEpoch(val_results, test_results):
    '''Plot recall and precision against the epoch index.

        val_results: per-epoch validation results, or None/empty to skip.
        test_results: per-epoch test results, or None/empty to skip.

        In each per-epoch entry, index 0 is the precision and index 1 the
        recall.
    '''
    for split_name, results in (('validation', val_results), ('test', test_results)):
        if not results:
            continue
        # Each split gets its own epoch axis, so a missing or shorter
        # validation list no longer breaks the test curves.
        epochs = range(len(results))
        recall = [epoch_result[1] for epoch_result in results]
        precision = [epoch_result[0] for epoch_result in results]
        plt.plot(epochs, recall, label='_'.join(['recall', split_name]))
        plt.plot(epochs, precision, label='_'.join(['precision', split_name]))

    plt.title('Recall/Precision vs epoch')
    plt.xlabel('Epoch')
    plt.ylabel('Recall/Precision')
    plt.legend()
    plt.show()
    

def PlotPrecisionRecallDiffAcrossEpoch(results_a, results_b):
    '''Plot recall (first figure) and precision (second figure) of two runs.

        results_a, results_b: equally long per-epoch result sequences; in
            each entry, index 0 is the precision and index 1 the recall.

        NOTE(review): the 'validation' and 'test' series of a run are read
        from the same entries, so the two curves coincide. Other plot helpers
        in this file index results as results[split][epoch] -- confirm which
        layout is intended here.
    '''
    assert(len(results_a) == len(results_b))
    epochs = range(len(results_a))
    runs = (('a', results_a), ('b', results_b))
    split_names = ('validation', 'test')

    # One figure per metric; `column` is the metric's index in an entry.
    for metric, column in (('recall', 1), ('precision', 0)):
        for run_label, run_results in runs:
            values = [epoch_result[column] for epoch_result in run_results]
            for split_name in split_names:
                plt.plot(epochs, values, label='_'.join([metric, run_label, split_name]))
        axis_name = metric.capitalize()
        plt.title(axis_name + ' vs Epoch')
        plt.xlabel('Epoch')
        # The precision figure used to be labelled 'Recall' on the y-axis.
        plt.ylabel(axis_name)
        plt.legend()
        plt.show()


def PlotAccuracyAcrossEpoch(results_a, results_b=None):
    '''Plot validation and test accuracy against the epoch index.

        results_a: indexable as results_a[split][epoch], where split 0 is
            validation and split 1 is test; index 7 of each per-epoch entry
            is the accuracy.
        results_b: optional second run with the same layout and the same
            number of epochs; drawn in the same figure when given.

        This function used to be defined twice with identical bodies; this is
        the single remaining definition.
    '''
    if results_b:
        assert(len(results_a[0]) == len(results_b[0]))
    epochs = list(range(len(results_a[0])))

    for run_label, run_results in (('a', results_a), ('b', results_b)):
        if not run_results:
            continue
        for split_idx, split_name in enumerate(('validation', 'test')):
            accuracy = [run_results[split_idx][i][7] for i in epochs]
            # Every series is an accuracy series (run a used to be
            # labelled 'recall').
            plt.plot(epochs, accuracy,
                     label='_'.join(['accuracy', run_label, split_name]))

    plt.title('Accuracy vs Epoch')
    plt.xlabel('Epoch')
    plt.ylabel('Accuracy')
    plt.legend()
    plt.show()

    
def _print_final_accuracy(results):
    '''Print the last-epoch accuracy (index 7 of an entry) for one run.'''
    if len(results) > 1:
        # results[0] holds the validation epochs, results[1] the test epochs.
        accuracy_val = results[0][-1][7]
        accuracy_test = results[1][-1][7]
        print('val accuracy:  {:.2}'.format(accuracy_val))
        print('test accuracy: {:.2}'.format(accuracy_test))
    else:
        # NOTE(review): with a single element, that element itself is read as
        # the result entry (results[-1][7]), not as a list of epochs --
        # confirm this is the intended layout.
        accuracy_test = results[-1][7]
        print('test accuracy: {:.2}'.format(accuracy_test))


def PrintFinalAccuracy(results_a, results_b=None):
    '''Print the final validation/test accuracy of one or two runs.

        results_a: results of the first run, normally [val_epochs,
            test_epochs]; nothing is printed for it when empty or None.
        results_b: optional results of a second run, same layout.
    '''
    if results_a and results_b:
        assert(len(results_a[0]) == len(results_b[0]))

    if results_a:
        _print_final_accuracy(results_a)
    if results_b:
        _print_final_accuracy(results_b)
    
    
def PrintTrainTime(times):
    """Print the total and the per-epoch training times.

    Args:
        times: iterable of numeric durations, one per epoch (printed with an
            's' suffix for the total).
    """
    times = list(times)  # iterated twice below; also accepts generators
    # '.2f' = two decimals. The former '.2' meant two significant digits,
    # which printed 123.75 as '1.2e+02' and raised ValueError for ints.
    print('Total train time: {:.2f} s'.format(sum(times)))
    print('Per epoch times: ' + ', '.join('{:.2f}'.format(t) for t in times))


### Tasks

Declare tasks for 2 of the 5 prediction tasks specified in the paper. We will create dataloaders for each task that contain the ICD codes and the raw text for each (patient, visit).

#### CodeEMB Pred tasks

#### DescEmb Pred Tasks

#### Test Load Readmission Dataset

In [None]:
if False:
    # Disabled smoke test: build the readmission (DescEmb) task dataset once,
    # print its stats and the derived code->index table, then drop it.
    # set_task() returns a SampleEHRDataset object
    READMISSION_PER_PATIENT_ICD_9_CODE_COUNT_ = {}
    READMISSION_PER_PATIENT_ICD_9_CODE2IDX_ = {}
    # The count dict is bound as the task function's first argument;
    # presumably the task fills it while set_task() runs -- confirm in tasks/.
    task_fn = functools.partial(readmission_pred_task_demb, READMISSION_PER_PATIENT_ICD_9_CODE_COUNT_)
    readm_dataset = mimic3base.set_task(task_fn, task_name=readmission_pred_task_demb.__name__)
    # Index every collected code by its position in sorted order.
    READMISSION_PER_PATIENT_ICD_9_CODE2IDX_ = {
        code: idx for idx, code in enumerate(sorted(READMISSION_PER_PATIENT_ICD_9_CODE_COUNT_.keys()))
    }
    readm_dataset.stat()
    # NOTE(review): bare expression nested in an `if`; it displays nothing.
    readm_dataset.samples[1]
    # TODO(botelho3) could try a freq codes limit on this.
    print(f"READMISSION_PER_PATIENT_ICD_9_CODE2IDX_ len: {len(READMISSION_PER_PATIENT_ICD_9_CODE2IDX_)}\n"
          f"{READMISSION_PER_PATIENT_ICD_9_CODE2IDX_}")
    del readm_dataset

#### Test Load Mortality Dataset

In [None]:
if False:
    # Disabled smoke test: build the mortality (DescEmb) task dataset once,
    # print its stats and the derived code->index table, then drop it.
    MORTALITY_PER_PATIENT_ICD_9_CODE_COUNT_ = {}
    MORTALITY_PER_PATIENT_ICD_9_CODE2IDX_ = {}
    # The count dict is bound as the task function's first argument;
    # presumably the task fills it while set_task() runs -- confirm in tasks/.
    task_fn = functools.partial(mortality_pred_task_demb, MORTALITY_PER_PATIENT_ICD_9_CODE_COUNT_)
    mor_dataset = mimic3base.set_task(task_fn, task_name=mortality_pred_task_demb.__name__)
    # Index every collected code by its position in sorted order.
    MORTALITY_PER_PATIENT_ICD_9_CODE2IDX_ = {
        code: idx for idx, code in enumerate(sorted(MORTALITY_PER_PATIENT_ICD_9_CODE_COUNT_.keys()))
    }
    mor_dataset.stat()
    # NOTE(review): bare expression nested in an `if`; it displays nothing.
    mor_dataset.samples[1]
    # TODO(botelho3) could try a freq codes limit on this.
    print(f"MORTALITY_PER_PATIENT_ICD_9_CODE2IDX_ len: {len(MORTALITY_PER_PATIENT_ICD_9_CODE2IDX_)}\n"
          f"{MORTALITY_PER_PATIENT_ICD_9_CODE2IDX_}")
    del mor_dataset

### DataLoaders and Collate

#### Bert Collate

#### Bert Load

In [None]:
# Build the DescEmb mortality dataset: a pyHealth task dataset wrapped so that
# each sample is passed through the BERT text-embedding transform.
gc.collect()

# Create the transform that will take each sample (visit) in the dataset
# and convert the text description of the visit into a single embedding.
bert_xform = BertTextEmbedTransform(None, EMBEDDING_DIM_, use_tokenizer_fast=True, use_gpu=BERT_USE_GPU_)
BERT_EMBEDDING_SIZE = bert_xform.bert_config.hidden_size

# Mortality Task Datasets: set_task -> SampleEHRDataset(SampleBaseDataset)
# Add a transform to convert the visit into an embedding.
MORT_DEMB_CODE2IDX = {}
MORT_DEMB_CODE_COUNT = {}
# MORT_DEMB_CODE_COUNT is bound as the task function's first argument;
# presumably the task records every code it sees there -- confirm in tasks/.
task_fn = functools.partial(mortality_pred_task_demb, MORT_DEMB_CODE_COUNT)
mortality_demb_dataset = mimic3base.set_task(task_fn, task_name=mortality_pred_task_demb.__name__)
# Index every collected code by its position in sorted order.
MORT_DEMB_CODE2IDX = {
    code: idx for idx, code in enumerate(sorted(MORT_DEMB_CODE_COUNT.keys()))
}
print(f"MORT_DEMB_CODE2IDX {len(MORT_DEMB_CODE2IDX)}")
# Wrap the default pyHealth dataset class in our own wrapper. The wrapper takes each
# sample and applies BERT to xform text->pytorch.tensor.
mortality_demb_dataset = TextEmbedDataset(mortality_demb_dataset, transform=bert_xform)


# mort_demb_train_ds, mort_demb_val_ds, mort_demb_test_ds = split_by_patient(mortality_demb_dataset, [0.8, 0.1, 0.1])

# # create dataloaders (torch.data.DataLoader)
# # mort_train_loader = get_dataloader(train_ds, batch_size=32, shuffle=True, collate_fn)
# # mort_val_loader = get_dataloader(val_ds, batch_size=32, shuffle=False)
# # mort_test_loader = get_dataloader(test_ds, batch_size=32, shuffle=False)

# mort_demb_train_loader = DataLoader(
#     mort_demb_train_ds,
#     batch_size=BATCH_SIZE_,
#     shuffle=SHUFFLE_,
#     collate_fn=bert_per_patient_collate_function,
# )
# mort_demb_val_loader = DataLoader(
#     mort_demb_val_ds,
#     batch_size=BATCH_SIZE_,
#     shuffle=SHUFFLE_,
#     collate_fn=bert_per_patient_collate_function,
# )
# mort_demb_test_loader = DataLoader(
#     mort_demb_test_ds,
#     batch_size=BATCH_SIZE_,
#     shuffle=SHUFFLE_,
#     collate_fn=bert_per_patient_collate_function,
# )


In [None]:
# Build the DescEmb readmission dataset: a pyHealth task dataset wrapped so
# that each sample is passed through the BERT text-embedding transform.
gc.collect()

# Create the transform that will take each sample (visit) in the dataset
# and convert the text description of the visit into a single embedding.
# NOTE: this rebinds `bert_xform` and BERT_EMBEDDING_SIZE with a new
# transform instance built from the same arguments as the mortality cell.
bert_xform = BertTextEmbedTransform(None, EMBEDDING_DIM_, use_tokenizer_fast=True, use_gpu=BERT_USE_GPU_)
BERT_EMBEDDING_SIZE = bert_xform.bert_config.hidden_size

# Readmission Task Datasets: set_task -> SampleEHRDataset(SampleBaseDataset)
# Add a transform to convert the visit into an embedding.
READM_DEMB_CODE2IDX = {}
READM_DEMB_CODE_COUNT = {}
# READM_DEMB_CODE_COUNT is bound as the task function's first argument;
# presumably the task records every code it sees there -- confirm in tasks/.
task_fn = functools.partial(readmission_pred_task_demb, READM_DEMB_CODE_COUNT)
readmission_demb_dataset = mimic3base.set_task(task_fn, task_name=readmission_pred_task_demb.__name__)
# Index every collected code by its position in sorted order.
READM_DEMB_CODE2IDX = {
    code: idx for idx, code in enumerate(sorted(READM_DEMB_CODE_COUNT.keys()))
}
print(f"READM_DEMB_CODE2IDX {len(READM_DEMB_CODE2IDX)}")
# Wrap the default pyHealth dataset class in our own wrapper. The wrapper takes each
# sample and applies BERT to xform text->pytorch.tensor.
readmission_demb_dataset = TextEmbedDataset(readmission_demb_dataset, transform=bert_xform)

In [None]:
if False:
    # Verify DataLoader properties.
    # Quick test without running the whole RNN training process.
    # NOTE(review): disabled. `mort_demb_train_loader` is only created in
    # commented-out code in this part of the notebook, so enabling this cell
    # as-is would presumably raise NameError -- confirm before use.

    # from torch.utils.data import DataLoader

    # loader = DataLoader(mort_, batch_size=10, collate_fn=collate_fn)
    loader_iter = iter(mort_demb_train_loader)
    # for _ in loader_iter:
    #     pass
    try:
        x, masks, rev_x, rev_masks, y = next(loader_iter)
    except StopIteration as e:
        # NOTE(review): when the loader is empty, x/masks/... stay unbound
        # and the asserts below fail with NameError instead.
        print(e)

    # Expected dtypes of the first batch.
    assert x.dtype == torch.float
    assert rev_x.dtype == torch.float
    assert y.dtype == torch.float
    assert masks.dtype == torch.bool
    assert rev_masks.dtype == torch.bool

    # Expected shapes of the first batch (hard-coded 3/105/10 dimensions).
    assert x.shape == (BATCH_SIZE_, 3, 105)
    assert y.shape == (BATCH_SIZE_, 1)
    assert masks.shape == (BATCH_SIZE_, 10, 3)

    # assert x[0][0].sum() == 9
    # assert masks[0].sum() == 2

#### CodeEmb Collate 

#### CodeEmb Load

In [None]:
# Build the CodeEmb mortality task dataset and its code->index lookup table.
# Mortality Task Datasets: set_task -> SampleEHRDataset(SampleBaseDataset)
MORT_CEMB_CODE2IDX = {}
MORT_CEMB_CODE_COUNT = {}
# MORT_CEMB_CODE_COUNT is bound as the task function's first argument.
mortality_cemb_ds = mimic3base.set_task(
    task_fn=functools.partial(mortality_pred_task_cemb, MORT_CEMB_CODE_COUNT),
    task_name=mortality_pred_task_cemb.__name__)
# The set_task(...) function iterates over all samples.
# Applying the task to each before returning a new dataset.
# Since all samples have been processed we have observed all codes and can
# build the code->index LUT.
# NOTE(review): dict keys are already unique, so the set(...) is redundant.
MORT_CEMB_CODE2IDX = {
    code: idx for idx, code in enumerate(sorted(set(MORT_CEMB_CODE_COUNT.keys())))
}
print(f'MORT_CEMB_CODE2IDX {len(MORT_CEMB_CODE2IDX)}')

# # We need to provide the code->index LUT to the collate function.
# The LUT is bound as the collate function's first positional argument.
wrap_code_emb_per_visit_collate_function = functools.partial(
    code_emb_per_visit_collate_function,
    MORT_CEMB_CODE2IDX)

# # Split the dataset into train, val, and test.
# mort_cemb_train_ds, mort_cemb_val_ds, mort_cemb_test_ds = split_by_patient(mortality_cemb_ds, [0.8, 0.1, 0.1])
# mort_cemb_train_loader = DataLoader(
#     mort_cemb_train_ds,
#     batch_size=BATCH_SIZE_,
#     shuffle=SHUFFLE_,
#     collate_fn=wrap_code_emb_per_visit_collate_function,
# )
# mort_cemb_val_loader = DataLoader(
#     mort_cemb_val_ds,
#     batch_size=BATCH_SIZE_,
#     shuffle=SHUFFLE_,
#     collate_fn=wrap_code_emb_per_visit_collate_function,
# )
# mort_cemb_test_loader = DataLoader(
#     mort_cemb_test_ds,
#     batch_size=BATCH_SIZE_,
#     shuffle=SHUFFLE_,
#     collate_fn=wrap_code_emb_per_visit_collate_function,
# )

In [None]:
# Build the CodeEmb readmission task dataset and its code->index lookup table.
# Readmission Task Datasets: set_task -> SampleEHRDataset(SampleBaseDataset)
READM_CEMB_CODE2IDX = {}
READM_CEMB_CODE_COUNT = {}
# READM_CEMB_CODE_COUNT is bound as the task function's first argument.
readm_cemb_ds = mimic3base.set_task(
    task_fn=functools.partial(readmission_pred_task_cemb, READM_CEMB_CODE_COUNT),
    task_name=readmission_pred_task_cemb.__name__)
# The set_task(...) function iterates over all samples.
# Applying the task to each before returning a new dataset.
# Since all samples have been processed we have observed all codes and can
# build the code->index LUT.
# NOTE(review): dict keys are already unique, so the set(...) is redundant.
READM_CEMB_CODE2IDX = {
    code: idx for idx, code in enumerate(sorted(set(READM_CEMB_CODE_COUNT.keys())))
}
print(f'READM_CEMB_CODE2IDX {len(READM_CEMB_CODE2IDX)}')

In [None]:
if False:
    # Verify DataLoader properties.
    # Quick test without running the whole RNN training process.
    # NOTE(review): disabled. `mort_cemb_train_loader` is only created in
    # commented-out code in this part of the notebook -- confirm before use.

    # from torch.utils.data import DataLoader

    # loader = DataLoader(mort_, batch_size=10, collate_fn=collate_fn)
    loader_iter = iter(mort_cemb_train_loader)
    # Walk the whole loader once, discarding every batch.
    for _ in loader_iter:
        pass

    # NOTE(review): the loop above exhausts `loader_iter`, so this next()
    # always raises StopIteration and x/masks/... are never bound.
    try:
        x, masks, rev_x, rev_masks, y = next(loader_iter)
    except StopIteration as e:
        print(e)

    # assert x.dtype == torch.float
    # assert rev_x.dtype == torch.float
    # assert y.dtype == torch.float
    # assert masks.dtype == torch.bool
    # assert rev_masks.dtype == torch.bool

    # assert x.shape == (BATCH_SIZE_, 3, 105)
    # assert y.shape == (BATCH_SIZE_, 1)

#### Bert ICU Load

In [None]:
# Build the DescEmb mortality dataset from `eicubase`, wrapped so that each
# sample is passed through the BERT text-embedding transform.
# Create the transform that will take each sample (visit) in the dataset
# and convert the text description of the visit into a single embedding.
bert_xform = BertTextEmbedTransform(None, EMBEDDING_DIM_,
                                    use_tokenizer_fast=True, use_gpu=BERT_USE_GPU_)
BERT_EMBEDDING_SIZE = bert_xform.bert_config.hidden_size

# Mortality Task Datasets: set_task -> SampleEHRDataset(SampleBaseDataset)
# Add a transform to convert the visit into an embedding.
# NOTE(review): unlike the MIMIC cells, these tables use the generic names
# CODE2IDX / CODE_COUNT.
CODE2IDX = {}
CODE_COUNT = {}
# CODE_COUNT is bound as the task function's first argument.
task_fn = functools.partial(eicu_mortality_pred_task_demb, CODE_COUNT)
eicu_mortality_demb_dataset = eicubase.set_task(task_fn, task_name=eicu_mortality_pred_task_demb.__name__)
# Index every collected code by its position in sorted order.
CODE2IDX = {
    code: idx for idx, code in enumerate(sorted(CODE_COUNT.keys()))
}
# Wrap the default pyHealth dataset class in our own wrapper. The wrapper takes each
# sample and applies BERT to xform text->pytorch.tensor.
# should_cache=False is passed explicitly here (the MIMIC cells omit it).
eicu_mortality_demb_dataset = TextEmbedDataset(eicu_mortality_demb_dataset,
                                               transform=bert_xform,
                                               should_cache=False)

In [None]:
# indices = range(0, min(50000, len(eicu_mortality_demb_dataset)) )
# eicu_mort_demb_loader = DataLoader(
#     torch.utils.data.Subset(eicu_mortality_demb_dataset, indices),
#     batch_size=BATCH_SIZE_,
#     shuffle=SHUFFLE_,
#     collate_fn=bert_per_patient_collate_function_new_trainer,
# )
# loader_iter = iter(eicu_mort_demb_loader)
# # for _ in loader_iter:
# #     pass
# try:
#     x, masks, rev_x, rev_masks, y = next(loader_iter)
# except StopIteration as e:
#     print(e)

#### Dataset Caching

Caches a dataset to disk via pickle after the preprocessing/collate functions have been applied. This can be used to transform samples into a more compact representation, or to perform preprocessing once instead of repeating it during each epoch.

##### Cemb

In [None]:
# Write the collated CodeEmb mortality batches to the dataset cache.
# In order to cut down the data size the collate fn is going to have to happen later during batch run.
cacher = DatasetCacher()

# Bind the mortality code->index LUT as the collate function's first argument.
wrap_code_emb_per_visit_collate_function = functools.partial(
    code_emb_per_visit_collate_function,
    MORT_CEMB_CODE2IDX)
mort_cemb_loader = DataLoader(
    mortality_cemb_ds,
    batch_size=BATCH_SIZE_,
    shuffle=SHUFFLE_,
    collate_fn=wrap_code_emb_per_visit_collate_function,
)
# Extra metadata handed to the cacher alongside the batches.
extra_data = {
    'code2idx': MORT_CEMB_CODE2IDX,
    'embed_index_size': len(MORT_CEMB_CODE2IDX),
    'keywords': ['x', 'masks', 'rev_x', 'rev_masks', 'y'],
}
mort_cemb_cacher_metadata = cacher.DatasetToCacheFromLoader(mort_cemb_loader,
                      mortality_pred_task_cemb,
                      batch_size=0,  # we already batched using loader+collate.
                      overwrite=False,
                      extra_data=extra_data)
# i = iter(mort_cemb_loader)
# print(next(i))

In [None]:
# Reload the CodeEmb mortality data from the cache and print its first item.
print(f"len dataset {len(mort_cemb_loader.dataset)}")
cacher = DatasetCacher()
# Lookup parameters: batch_size 0 matches the value used when caching.
metadata_in = {'batch_size': 0, 'length': len(mort_cemb_loader.dataset)}
# NOTE: rebinds `mort_cemb_loader` to the cache-backed loader, so re-running
# this cell measures the length of the cached loader's dataset instead.
mort_cemb_loader, mort_cemb_metadata = (
    cacher.DataloaderFromCache(mortality_pred_task_cemb.__name__,
                               metadata_in['batch_size'],
                               metadata_in['length'])
)
i = iter(mort_cemb_loader)
print(next(i))

In [None]:
# Write the collated CodeEmb readmission batches to the dataset cache.
# In order to cut down the data size the collate fn is going to have to happen later during batch run.
cacher = DatasetCacher()

# Bind the readmission code->index LUT as the collate function's first
# argument. NOTE: this rebinds the name the mortality cell also uses.
wrap_code_emb_per_visit_collate_function = functools.partial(
    code_emb_per_visit_collate_function,
    READM_CEMB_CODE2IDX)
readm_cemb_loader = DataLoader(
    readm_cemb_ds,
    batch_size=BATCH_SIZE_,
    shuffle=SHUFFLE_,
    collate_fn=wrap_code_emb_per_visit_collate_function,
)
# Extra metadata handed to the cacher alongside the batches.
extra_data = {
    'code2idx': READM_CEMB_CODE2IDX,
    'embed_index_size': len(READM_CEMB_CODE2IDX),
    'keywords': ['x', 'masks', 'rev_x', 'rev_masks', 'y'],
}
readm_cemb_cacher_metadata = cacher.DatasetToCacheFromLoader(readm_cemb_loader,
                      readmission_pred_task_cemb,
                      batch_size=0,  # we already batched using loader+collate.
                      overwrite=False,
                      extra_data=extra_data)
# i = iter(readm_cemb_loader)
# print(next(i))

In [None]:
# Reload the CodeEmb readmission data from the cache and print its first item.
print(f"len dataset {len(readm_cemb_loader.dataset)}")
cacher = DatasetCacher()
# Lookup parameters: batch_size 0 matches the value used when caching.
metadata_in = {'batch_size': 0, 'length': len(readm_cemb_loader.dataset)}
# NOTE: rebinds `readm_cemb_loader` to the cache-backed loader, so re-running
# this cell measures the length of the cached loader's dataset instead.
readm_cemb_loader, readm_cemb_metadata = (
    cacher.DataloaderFromCache(readmission_pred_task_cemb.__name__,
                               metadata_in['batch_size'],
                               metadata_in['length'])
)
i = iter(readm_cemb_loader)
print(next(i))

In [None]:
if False:
    # Optional cleanup of the CodeEmb datasets and loaders (disabled).
    del readm_cemb_ds
    del readm_cemb_loader
    del mortality_cemb_ds
    # The mortality loader is bound as `mort_cemb_loader`; the former
    # `mortality_cemb_loader` was never defined and would raise NameError.
    del mort_cemb_loader

##### Demb 

In [None]:
# Setup for caching the DescEmb mortality samples.
# Might need packed_sequence for this.
cacher = DatasetCacher()

# Only the first (at most) 10000 samples are cached.
indices = list(range(0, min(10000, len(mortality_demb_dataset)) ))
print(len(mortality_demb_dataset))
print(len(torch.utils.data.Subset(mortality_demb_dataset, indices)))
# 768 is hard-coded here; presumably the BERT hidden size -- TODO confirm
# against BERT_EMBEDDING_SIZE.
extra_data = {
    'embed_index_size': 768,
    'keywords': ['x', 'masks', 'rev_x', 'rev_masks', 'y'],
}

def numpy_conversion_function(sample):
    """Collate a one-sample batch into an ``(ndarray, label)`` pair.

    Args:
        sample: the batch passed in by the DataLoader; only its first
            element, a ``(tensor, label)`` pair, is used.

    Returns:
        Tuple of the tensor converted to a numpy array and the label.
    """
    # This is faster than pickling tensors
    # https://github.com/pytorch/pytorch/issues/9168
    tensor, label = sample[0]
    # detach()/cpu() leave a plain CPU tensor unchanged but keep .numpy()
    # from raising on tensors that require grad or live on the GPU.
    return (tensor.detach().cpu().numpy(), label)

def bytes_conversion_function(sample):
    """Collate a one-sample batch into ``(serialized tensor bytes, label)``.

    Only ``sample[0]``, a ``(tensor, label)`` pair, is used; the tensor is
    written with ``torch.save`` into an in-memory buffer.
    """
    # Original author's rationale: faster than pickling tensors, see
    # https://github.com/pytorch/pytorch/issues/9168
    tensor, label = sample[0]
    with io.BytesIO() as buffer:
        torch.save(tensor, buffer)
        payload = buffer.getvalue()
    return (payload, label)
    
# Feed the subset one sample at a time (batch_size=1) through the numpy
# conversion and write the results to the dataset cache.
mort_demb_loader = DataLoader(
    torch.utils.data.Subset(mortality_demb_dataset, indices),
    batch_size=1,
    shuffle=SHUFFLE_,
    collate_fn=numpy_conversion_function,
    # collate_fn=bytes_conversion_function,
)
mort_demb_cacher_metadata = cacher.DatasetToCacheFromLoader(mort_demb_loader,
                      mortality_pred_task_demb,
                      batch_size=0,  # we already batched using loader+collate.
                      overwrite=False,
                      extra_data=extra_data)
# i = iter(mort_demb_loader)
# print(next(i))

In [None]:
# Reload the DescEmb mortality data from the cache and print its first item.
print(f"len dataset {len(mort_demb_loader.dataset)}")
cacher = DatasetCacher()
# Lookup parameters: batch_size 0 matches the value used when caching.
metadata_in = {'batch_size': 0, 'length': len(mort_demb_loader.dataset)}
# NOTE: rebinds `mort_demb_loader` to the cache-backed loader, so re-running
# this cell measures the length of the cached loader's dataset instead.
mort_demb_loader, mort_demb_metadata = (
    cacher.DataloaderFromCache(mortality_pred_task_demb.__name__,
                               metadata_in['batch_size'],
                               metadata_in['length'])
)
i = iter(mort_demb_loader)
print(next(i))

In [None]:
# Setup for caching the DescEmb readmission samples.
# Might need packed_sequence for this.
cacher = DatasetCacher()

# Only the first (at most) 10000 samples are cached.
indices = list(range(0, min(10000, len(readmission_demb_dataset)) ))
print(len(readmission_demb_dataset))
print(len(torch.utils.data.Subset(readmission_demb_dataset, indices)))
# 768 is hard-coded here; presumably the BERT hidden size -- TODO confirm
# against BERT_EMBEDDING_SIZE.
extra_data = {
    'embed_index_size': 768,
    'keywords': ['x', 'masks', 'rev_x', 'rev_masks', 'y'],
}

def numpy_conversion_function(sample):
    """Collate a one-sample batch into an ``(ndarray, label)`` pair.

    Args:
        sample: the batch passed in by the DataLoader; only its first
            element, a ``(tensor, label)`` pair, is used.

    Returns:
        Tuple of the tensor converted to a numpy array and the label.
    """
    # This is faster than pickling tensors
    # https://github.com/pytorch/pytorch/issues/9168
    tensor, label = sample[0]
    # detach()/cpu() leave a plain CPU tensor unchanged but keep .numpy()
    # from raising on tensors that require grad or live on the GPU.
    return (tensor.detach().cpu().numpy(), label)

def bytes_conversion_function(sample):
    """Collate a one-sample batch into ``(serialized tensor bytes, label)``.

    Only ``sample[0]``, a ``(tensor, label)`` pair, is used; the tensor is
    written with ``torch.save`` into an in-memory buffer.
    """
    # Original author's rationale: faster than pickling tensors, see
    # https://github.com/pytorch/pytorch/issues/9168
    tensor, label = sample[0]
    with io.BytesIO() as buffer:
        torch.save(tensor, buffer)
        payload = buffer.getvalue()
    return (payload, label)
    
# Feed the subset one sample at a time (batch_size=1) through the numpy
# conversion and write the results to the dataset cache.
readm_demb_loader = DataLoader(
    torch.utils.data.Subset(readmission_demb_dataset, indices),
    batch_size=1,
    shuffle=SHUFFLE_,
    collate_fn=numpy_conversion_function,
    # collate_fn=bytes_conversion_function,
)
readm_demb_cacher_metadata = cacher.DatasetToCacheFromLoader(readm_demb_loader,
                      readmission_pred_task_demb,
                      batch_size=0,  # we already batched using loader+collate.
                      overwrite=False,
                      extra_data=extra_data)
# i = iter(readm_demb_loader)
# print(next(i))

In [None]:
# Reload the DescEmb readmission data from the cache and print its first item.
print(f"len dataset {len(readm_demb_loader.dataset)}")
cacher = DatasetCacher()
# Lookup parameters: batch_size 0 matches the value used when caching.
metadata_in = {'batch_size': 0, 'length': len(readm_demb_loader.dataset)}
# NOTE: rebinds `readm_demb_loader` to the cache-backed loader, so re-running
# this cell measures the length of the cached loader's dataset instead.
readm_demb_loader, readm_demb_metadata = (
    cacher.DataloaderFromCache(readmission_pred_task_demb.__name__,
                               metadata_in['batch_size'],
                               metadata_in['length'])
)
i = iter(readm_demb_loader)
print(next(i))

In [None]:
if False:
    # Optional cleanup of the MIMIC DescEmb objects (disabled).
    # NOTE(review): the "Free some memory" cell further down also deletes
    # `bert_xform` and `mortality_demb_dataset`; enabling this block would
    # make those later `del` statements raise NameError.
    del bert_xform
    del mort_demb_loader
    del mortality_demb_dataset
    del readm_demb_loader
    del readmission_demb_dataset

##### Demb-eICU

In [None]:
# Setup for caching the eICU DescEmb mortality samples.
# Might need packed_sequence for this.
cacher = DatasetCacher()

# indices = list(range(0, min(10000, len(mortality_demb_dataset)) ))
# print(len(mortality_demb_dataset))
# print(len(torch.utils.data.Subset(mortality_demb_dataset, indices)))
# 768 is hard-coded here; presumably the BERT hidden size -- TODO confirm
# against BERT_EMBEDDING_SIZE.
extra_data = {
    'embed_index_size': 768,
    'keywords': ['x', 'masks', 'rev_x', 'rev_masks', 'y'],
}


def numpy_conversion_function(sample):
    """Collate a one-sample batch into an ``(ndarray, label)`` pair.

    Args:
        sample: the batch passed in by the DataLoader; only its first
            element, a ``(tensor, label)`` pair, is used.

    Returns:
        Tuple of the tensor converted to a numpy array and the label.
    """
    # This is faster than pickling tensors
    # https://github.com/pytorch/pytorch/issues/9168
    tensor, label = sample[0]
    # detach()/cpu() leave a plain CPU tensor unchanged but keep .numpy()
    # from raising on tensors that require grad or live on the GPU.
    return (tensor.detach().cpu().numpy(), label)

def bytes_conversion_function(sample):
    """Collate a one-sample batch into ``(serialized tensor bytes, label)``.

    Only ``sample[0]``, a ``(tensor, label)`` pair, is used; the tensor is
    written with ``torch.save`` into an in-memory buffer.
    """
    # Original author's rationale: faster than pickling tensors, see
    # https://github.com/pytorch/pytorch/issues/9168
    tensor, label = sample[0]
    with io.BytesIO() as buffer:
        torch.save(tensor, buffer)
        payload = buffer.getvalue()
    return (payload, label)

# Cache at most the first 50000 eICU samples, fed one at a time
# (batch_size=1) through the numpy conversion.
indices = range(0, min(50000, len(eicu_mortality_demb_dataset)) )
eicu_mort_demb_loader = DataLoader(
    torch.utils.data.Subset(eicu_mortality_demb_dataset, indices),
    batch_size=1,
    shuffle=SHUFFLE_,
    collate_fn=numpy_conversion_function,
)

# Stored under an eICU-specific name: this result used to be assigned to
# `mort_demb_cacher_metadata`, silently overwriting the MIMIC metadata.
eicu_mort_demb_cacher_metadata = cacher.DatasetToCacheFromLoader(eicu_mort_demb_loader,
                      eicu_mortality_pred_task_demb,
                      batch_size=0,  # we already batched using loader+collate.
                      overwrite=False,
                      extra_data=extra_data)

In [None]:
# Reload the eICU DescEmb mortality data from the cache and print its first item.
print(f"len dataset {len(eicu_mort_demb_loader.dataset)}")
cacher = DatasetCacher()
# Lookup parameters: batch_size 0 matches the value used when caching.
metadata_in = {'batch_size': 0, 'length': len(eicu_mort_demb_loader.dataset)}
# NOTE: rebinds `eicu_mort_demb_loader` to the cache-backed loader, so
# re-running this cell measures the cached loader's dataset instead.
eicu_mort_demb_loader, eicu_mort_demb_metadata = (
    cacher.DataloaderFromCache(eicu_mortality_pred_task_demb.__name__,
                               metadata_in['batch_size'],
                               metadata_in['length'])
)
i = iter(eicu_mort_demb_loader)
print(next(i))

In [None]:
if False:
    # Optional cleanup of the eICU DescEmb objects (disabled).
    # NOTE(review): the "Free some memory" cell further down also deletes
    # `bert_xform`; enabling this block would make that `del` raise NameError.
    del bert_xform
    del eicu_mort_demb_loader
    del eicu_mortality_demb_dataset

### Bert Fine Tune Collate

In [None]:
# Build the mortality dataset used for BERT fine-tuning: a pyHealth task
# dataset wrapped with the fine-tune transform and limited to a subset.
# Create the transform that will take each sample (visit) in the dataset
# and convert the text description of the visit into a single embedding.
bert_ft_xform = BertFineTuneTransform()
BERT_FT_EMBEDDING_SIZE = bert_ft_xform.emb_size

# Mortality Task Datasets: set_task -> SampleEHRDataset(SampleBaseDataset)
# Add a transform to convert the visit into an embedding.
MORT_DEMBFT_CODE2IDX = {}
MORT_DEMBFT_CODE_COUNT = {}
# MORT_DEMBFT_CODE_COUNT is bound as the task function's first argument.
task_fn = functools.partial(mortality_pred_task_demb, MORT_DEMBFT_CODE_COUNT)
mortality_dembft_ds = mimic3base.set_task(task_fn, task_name=mortality_pred_task_demb.__name__)
# Index every collected code by its position in sorted order.
MORT_DEMBFT_CODE2IDX = {
    code: idx for idx, code in enumerate(sorted(MORT_DEMBFT_CODE_COUNT.keys()))
}


# Keep at most the first 10000 samples.
indices = list(range(0, min(10000, len(mortality_dembft_ds)) ))
print(len(mortality_dembft_ds))
print(len(torch.utils.data.Subset(mortality_dembft_ds, indices)))

# The label now names the table actually printed; it used to read
# "MORT_DEMB_CODE2IDX", which is a different table in this notebook.
print(f"MORT_DEMBFT_CODE2IDX {len(MORT_DEMBFT_CODE2IDX)}")
# Wrap the default pyHealth dataset class in our own wrapper. The wrapper takes each
# sample and applies BERT to xform text->pytorch.tensor.
tmp = TextEmbedDataset(mortality_dembft_ds, transform=bert_ft_xform, should_cache=False)
mortality_dembft_dataset = torch.utils.data.Subset(tmp, indices)
print(len(mortality_dembft_dataset))

In [None]:
if True:
    # Verify DataLoader properties.
    # Quick test without running the whole RNN training process.
    # Builds a loader over the fine-tune subset and pulls a single batch.
    mort_dembft_loader = DataLoader(
        mortality_dembft_dataset,
        batch_size=BATCH_SIZE_,
        shuffle=SHUFFLE_,
        collate_fn=bert_fine_tune_collate,
    )
    # loader = DataLoader(mort_, batch_size=10, collate_fn=collate_fn)
    loader_iter = iter(mort_dembft_loader)
    # for _ in loader_iter:
    #     pass
    try:
        # The collate output is unpacked into exactly five values.
        x, masks, rev_x, rev_masks, y = next(loader_iter)
    except StopIteration as e:
        # Only reached when the loader yields no batch at all.
        print(e)

In [None]:
# Free some memory
# NOTE(review): `tmp` and `mort_dembft_loader` from the cells above still
# reference the fine-tune dataset/transform, so those objects presumably stay
# alive until these names are dropped as well -- confirm.
del bert_ft_xform
del mortality_dembft_ds
del mortality_dembft_dataset
del bert_xform
del mortality_demb_dataset
# Collect once, after every name has been dropped; the collection used to run
# before the last two `del`s and so could not reclaim those objects.
gc.collect()

# Condensed Training using Trainer

### CodeEmb - Mort - Dev

In [None]:
# Trainer configuration: CodeEmb model, mortality task, dev subset.
# Reload the project modules and re-import their names.
import importlib
import trainlib
import models
import datasets
importlib.reload(trainlib)
importlib.reload(models)
importlib.reload(datasets)
from trainlib import Trainer
from models import ModelA, EHRModel, CembRNN, DembRNN
from datasets import SimpleDataset, DatasetCacher


embed_model_types = ['code_emb', 'desc_emb', 'desc_emb_ft']
predict_model_types = ['code_emb', 'desc_emb', 'desc_emb_ft']
task_types = ['mort', 'readm']
# Dev set.
# MORT_DEMB_CODE2IDX 1625
# MORT_CEMB_CODE2IDX 1456
# READM_DEMB_CODE2IDX 1523
# READM_CEMB_CODE2IDX 1456
# Full set.
# READMISSION_PER_PATIENT_ICD_9_CODE2IDX_ len: 6555
# MORTALITY_PER_PATIENT_ICD_9_CODE2IDX_ len: 6732
# Full set.
# MORT_DEMB_CODE2IDX 6732
# MORT_CEMB_CODE2IDX 6546
# READM_CEMB_CODE2IDX 6546
# READM_DEMB_CODE2IDX 6555

args = DotArgs()
args.save_dir = os.path.join(os.getcwd(), 'modelcache')
# exist_ok makes this idempotent and avoids the exists()-then-mkdir race.
os.makedirs(args.save_dir, exist_ok=True)
args.save_prefix = 'checkpoints'
args.random_seed = 90210
args.task = 'mort'
args.db_name = 'mimic'
args.is_dev = True  # if dev mode on we use a small subset of the full dataset.
assert(args.task in task_types)

# Model Args
# Load either a code_emb embed model or a desc_emb embed model.
args.embed_model_type = 'code_emb'
args.predict_model_type = 'code_emb'
args.collate_fn=None
# Hard-coded to the dev-set MORT_CEMB_CODE2IDX size recorded above.
args.embed_index_size = 1456 # mort_cemb_metadata['extra_data']['embed_index_size']
args.pred_embed_size = 128

# Training Args
args.load_pretrained_weights = False
args.n_epochs = 10
args.learning_rate = 1.0e-3

In [None]:
# Train with the configuration in `args`. The result is indexed by the
# following cell as [0] val results, [1] test results, [2] epoch times.
trainer = Trainer(args)
cemb_mort_dev_results = trainer.train()

In [None]:
# Report the CodeEmb / mortality / dev run.
# One entry per epoch in cemb_mort_dev_results.
val_set_results = cemb_mort_dev_results[0]
test_set_results = cemb_mort_dev_results[1]
PrintTrainTime(cemb_mort_dev_results[2])
# The AUC/recall plot only uses the last epoch ([-1]) of each split.
PlotAucRecallResults(val_set_results[-1], test_set_results[-1],
                     emb_type='Cemb', task_type='Mortality')
PlotPrecisionRecallAcrossEpoch(val_set_results, test_set_results)
PlotAccuracyAcrossEpoch(cemb_mort_dev_results)

### DescEmb - Mort - Dev

In [None]:
# Trainer configuration: DescEmb model, mortality task, dev subset.
# Reload the project modules and re-import their names.
import importlib
import trainlib
import models
import datasets
importlib.reload(trainlib)
importlib.reload(models)
importlib.reload(datasets)
from trainlib import Trainer
from models import ModelA, EHRModel, CembRNN, DembRNN
from datasets import SimpleDataset, DatasetCacher

# MORT_DEMB_CODE2IDX 1625
# MORT_CEMB_CODE2IDX 1456

args = DotArgs()
args.save_dir = os.path.join(os.getcwd(), 'modelcache')
# exist_ok makes this idempotent and avoids the exists()-then-mkdir race.
os.makedirs(args.save_dir, exist_ok=True)
args.save_prefix = 'checkpoints'
args.random_seed = 90210
args.task = 'mort'
args.db_name = 'mimic'
args.is_dev = True  # if dev mode on we use a small subset of the full dataset.
# `task_types` is defined in the CodeEmb - Mort - Dev configuration cell.
assert(args.task in task_types)

# Model Args
# Load either a code_emb embed model or a desc_emb embed model.
args.embed_model_type = 'desc_emb'
args.predict_model_type = 'desc_emb'
args.override_batch_size = BATCH_SIZE_
args.collate_fn = bert_per_patient_collate_function_new_trainer
args.embed_index_size = 0 # unused mort_demb_metadata['extra_data']['embed_index_size']=768
args.pred_embed_size = 0  # unused 

# Training Args
args.load_pretrained_weights = False
args.n_epochs = 10
args.learning_rate = 1.0e-3

In [None]:
# Train with the configuration in `args`, then drop the trainer and collect
# garbage before the next run.
trainer = Trainer(args)
demb_mort_dev_results = trainer.train()
del trainer
gc.collect()

In [None]:
# Report the DescEmb / mortality / dev run and compare it with CodeEmb.
# One entry per epoch in demb_mort_dev_results.
val_set_results = demb_mort_dev_results[0]
test_set_results = demb_mort_dev_results[1]
PrintTrainTime(demb_mort_dev_results[2])
# The AUC/recall plot only uses the last epoch ([-1]) of each split.
PlotAucRecallResults(val_set_results[-1], test_set_results[-1],
                     emb_type='Demb', task_type='Mortality')
PlotPrecisionRecallAcrossEpoch(val_set_results, test_set_results)
PlotAccuracyAcrossEpoch(demb_mort_dev_results)
# Requires `cemb_mort_dev_results` from the CodeEmb - Mort - Dev run.
labels = ['cemb', 'demb']
PlotDiffResults(cemb_mort_dev_results, demb_mort_dev_results, task_type='Mortality', labels=labels)


### CodeEmb - Readm - Dev

In [None]:
# Trainer configuration: CodeEmb model, readmission task, dev subset.
# Reload the project modules and re-import their names.
import importlib
import trainlib
import models
import datasets
importlib.reload(trainlib)
importlib.reload(models)
importlib.reload(datasets)
from trainlib import Trainer
from models import ModelA, EHRModel, CembRNN, DembRNN
from datasets import SimpleDataset, DatasetCacher

# Dev set.
# MORT_DEMB_CODE2IDX 1625
# MORT_CEMB_CODE2IDX 1456
# READM_DEMB_CODE2IDX 1523
# READM_CEMB_CODE2IDX 1456
# Full set.
# READMISSION_PER_PATIENT_ICD_9_CODE2IDX_ len: 6555
# MORTALITY_PER_PATIENT_ICD_9_CODE2IDX_ len: 6732
# Full set.
# MORT_DEMB_CODE2IDX 6732
# MORT_CEMB_CODE2IDX 6546

args = DotArgs()
args.save_dir = os.path.join(os.getcwd(), 'modelcache')
# exist_ok makes this idempotent and avoids the exists()-then-mkdir race.
os.makedirs(args.save_dir, exist_ok=True)
args.save_prefix = 'checkpoints'
args.random_seed = 90210
args.task = 'readm'
args.db_name = 'mimic'
args.is_dev = True  # if dev mode on we use a small subset of the full dataset.
# `task_types` is defined in the CodeEmb - Mort - Dev configuration cell.
assert(args.task in task_types)

# Model Args
# Load either a code_emb embed model or a desc_emb embed model.
args.embed_model_type = 'code_emb'
args.predict_model_type = 'code_emb'
args.collate_fn=None
# Hard-coded to the dev-set READM_CEMB_CODE2IDX size recorded above.
args.embed_index_size = 1456
args.pred_embed_size = 128

# Training Args
args.load_pretrained_weights = False
args.n_epochs = 10
args.learning_rate = 1.0e-3

In [None]:
# Train with the configuration in `args`. The result is indexed by the
# following cell as [0] val results, [1] test results, [2] epoch times.
trainer = Trainer(args)
cemb_readm_dev_results = trainer.train()

In [None]:
# Report the CodeEmb / readmission / dev run.
# One entry per epoch in cemb_readm_dev_results.
val_set_results = cemb_readm_dev_results[0]
test_set_results = cemb_readm_dev_results[1]
PrintTrainTime(cemb_readm_dev_results[2])
# The AUC/recall plot only uses the last epoch ([-1]) of each split.
PlotAucRecallResults(val_set_results[-1], test_set_results[-1],
                     emb_type='Cemb', task_type='Readmission')
PlotPrecisionRecallAcrossEpoch(val_set_results, test_set_results)
PlotAccuracyAcrossEpoch(cemb_readm_dev_results)

### DescEmb - Readm - Dev

In [None]:
# Trainer configuration: DescEmb model, readmission task, dev subset.
# Reload the project modules and re-import their names.
import importlib
import trainlib
import models
import datasets
importlib.reload(trainlib)
importlib.reload(models)
importlib.reload(datasets)
from trainlib import Trainer
from models import ModelA, EHRModel, CembRNN, DembRNN
from datasets import SimpleDataset, DatasetCacher

args = DotArgs()
args.save_dir = os.path.join(os.getcwd(), 'modelcache')
# exist_ok makes this idempotent and avoids the exists()-then-mkdir race.
os.makedirs(args.save_dir, exist_ok=True)
args.save_prefix = 'checkpoints'
args.random_seed = 90210
args.task = 'readm'
args.db_name = 'mimic'
args.is_dev = True  # if dev mode on we use a small subset of the full dataset.
# `task_types` is defined in the CodeEmb - Mort - Dev configuration cell.
assert(args.task in task_types)

# Model Args
# Load either a code_emb embed model or a desc_emb embed model.
args.embed_model_type = 'desc_emb'
args.predict_model_type = 'desc_emb'
args.override_batch_size = BATCH_SIZE_
args.collate_fn = bert_per_patient_collate_function_new_trainer
args.embed_index_size = 0 # unused mort_demb_metadata['extra_data']['embed_index_size']=768
args.pred_embed_size = 0  # unused 

# Training Args
args.load_pretrained_weights = False
args.n_epochs = 10
args.learning_rate = 1.0e-3

In [None]:
# Train with the configuration in `args`, then drop the trainer and collect
# garbage before the next run.
trainer = Trainer(args)
demb_readm_dev_results = trainer.train()
del trainer
gc.collect()

In [None]:
# Report the DescEmb / readmission / dev run and compare it with CodeEmb.
# One entry per epoch in demb_readm_dev_results.
val_set_results = demb_readm_dev_results[0]
test_set_results = demb_readm_dev_results[1]
PrintTrainTime(demb_readm_dev_results[2])
# These are readmission results, so the plots are labelled 'Readmission'
# (they used to carry the copy-pasted 'Mortality' label).
# The AUC/recall plot only uses the last epoch ([-1]) of each split.
PlotAucRecallResults(val_set_results[-1], test_set_results[-1],
                     emb_type='Demb', task_type='Readmission')
PlotPrecisionRecallAcrossEpoch(val_set_results, test_set_results)
PlotAccuracyAcrossEpoch(demb_readm_dev_results)
# Requires `cemb_readm_dev_results` from the CodeEmb - Readm - Dev run.
labels = ['cemb', 'demb']
PlotDiffResults(cemb_readm_dev_results, demb_readm_dev_results, task_type='Readmission', labels=labels)

### CodeEmb - Mort - Full

In [None]:
# Trainer configuration: CodeEmb model, mortality task, full dataset.
# Reload the project modules and re-import their names.
import importlib
import trainlib
import models
import datasets
importlib.reload(trainlib)
importlib.reload(models)
importlib.reload(datasets)
from trainlib import Trainer
from models import ModelA, EHRModel, CembRNN, DembRNN
from datasets import SimpleDataset, DatasetCacher

# Dev set.
# MORT_DEMB_CODE2IDX 1625
# MORT_CEMB_CODE2IDX 1456
# Full set.
# READMISSION_PER_PATIENT_ICD_9_CODE2IDX_ len: 6555
# MORTALITY_PER_PATIENT_ICD_9_CODE2IDX_ len: 6732
# Full set.
# MORT_DEMB_CODE2IDX 6732
# MORT_CEMB_CODE2IDX 6546

args = DotArgs()
args.save_dir = os.path.join(os.getcwd(), 'modelcache')
# exist_ok makes this idempotent and avoids the exists()-then-mkdir race.
os.makedirs(args.save_dir, exist_ok=True)
args.save_prefix = 'checkpoints'
args.random_seed = 90210
args.task = 'mort'
args.db_name = 'mimic'
args.is_dev = False  # if dev mode on we use a small subset of the full dataset.
# `task_types` is defined in the CodeEmb - Mort - Dev configuration cell.
assert(args.task in task_types)

# Model Args
# Load either a code_emb embed model or a desc_emb embed model.
args.embed_model_type = 'code_emb'
args.predict_model_type = 'code_emb'
args.collate_fn=None
# Hard-coded to the full-set MORT_CEMB_CODE2IDX size recorded above.
args.embed_index_size = 6546  # mort_cemb_metadata['extra_data']['embed_index_size']
args.pred_embed_size = 128

# Training Args
args.load_pretrained_weights = False
args.n_epochs = 10
# The full-set runs use 1.0e-4; the dev runs use 1.0e-3.
args.learning_rate = 1.0e-4

In [None]:
# Train with the configuration in `args`, then drop the trainer and collect
# garbage before the next run.
trainer = Trainer(args)
cemb_mort_full_results = trainer.train()
del trainer
gc.collect()

In [None]:
# Report the CodeEmb / mortality / full-dataset run.
# One entry per epoch in cemb_mort_full_results.
val_set_results = cemb_mort_full_results[0]
test_set_results = cemb_mort_full_results[1]
PrintTrainTime(cemb_mort_full_results[2])
# The AUC/recall plot only uses the last epoch ([-1]) of each split.
PlotAucRecallResults(val_set_results[-1], test_set_results[-1],
                     emb_type='Cemb', task_type='Mortality')
PlotPrecisionRecallAcrossEpoch(val_set_results, test_set_results)
PlotAccuracyAcrossEpoch(cemb_mort_full_results)

### DescEmb - Mort - Full

In [None]:
# Trainer configuration: DescEmb model, mortality task, full dataset.
# Reload the project modules and re-import their names.
import importlib
import trainlib
import models
import datasets
importlib.reload(trainlib)
importlib.reload(models)
importlib.reload(datasets)
from trainlib import Trainer
from models import ModelA, EHRModel, CembRNN, DembRNN
from datasets import SimpleDataset, DatasetCacher

# MORT_DEMB_CODE2IDX 1625
# MORT_CEMB_CODE2IDX 1456

args = DotArgs()
args.save_dir = os.path.join(os.getcwd(), 'modelcache')
# exist_ok makes this idempotent and avoids the exists()-then-mkdir race.
os.makedirs(args.save_dir, exist_ok=True)
args.save_prefix = 'checkpoints'
args.random_seed = 90210
args.task = 'mort'
args.db_name = 'mimic'
args.is_dev = False  # if dev mode on we use a small subset of the full dataset.
# `task_types` is defined in the CodeEmb - Mort - Dev configuration cell.
assert(args.task in task_types)

# Model Args
# Load either a code_emb embed model or a desc_emb embed model.
args.embed_model_type = 'desc_emb'
args.predict_model_type = 'desc_emb'
args.override_batch_size = BATCH_SIZE_
args.collate_fn = bert_per_patient_collate_function_new_trainer 
args.embed_index_size = 0 # unused mort_demb_metadata['extra_data']['embed_index_size']=768
args.pred_embed_size = 0  # unused 

# Training Args
args.load_pretrained_weights = False
args.n_epochs = 10
# The full-set runs use 1.0e-4; the dev runs use 1.0e-3.
args.learning_rate = 1.0e-4

In [None]:
# Build a Trainer from the current `args` and keep what train() returns for the
# plotting cell that follows.
trainer = Trainer(args)
demb_mort_full_results = trainer.train()
# Drop the trainer and force a GC pass (presumably to release model/GPU memory
# before the next experiment -- TODO confirm).
del trainer
gc.collect()

In [None]:
# demb_mort_full_results is indexed as [0] = validation results, [1] = test
# results, [2] = the value handed to PrintTrainTime.
# The [-1] below selects the last entry of each result list (presumably one
# entry per epoch, so the final epoch -- TODO confirm against Trainer.train()).
val_set_results = demb_mort_full_results[0]
test_set_results = demb_mort_full_results[1]
PrintTrainTime(demb_mort_full_results[2])
PlotAucRecallResults(val_set_results[-1], test_set_results[-1],
                     emb_type='Demb', task_type='Mortality')
PlotPrecisionRecallAcrossEpoch(val_set_results, test_set_results)
PlotAccuracyAcrossEpoch(demb_mort_full_results)
# Compare the two mortality runs; requires cemb_mort_full_results from the
# CodeEmb cells to still be defined in the notebook session.
labels = ['cemb', 'demb']
PlotDiffResults(cemb_mort_full_results, demb_mort_full_results, task_type='Mortality', labels=labels)

### CodeEmb - Readm - Full 

In [None]:
# Re-import the local modules so that edits made on disk are picked up before
# the run (the notebook also enables `%autoreload 2`; the explicit reloads are
# kept unchanged).
import importlib
import trainlib
import models
import datasets
importlib.reload(trainlib)
importlib.reload(models)
importlib.reload(datasets)
from trainlib import Trainer
from models import ModelA, EHRModel, CembRNN, DembRNN
from datasets import SimpleDataset, DatasetCacher

# Vocabulary sizes recorded from earlier runs, kept for reference.
# Dev set.
# MORT_DEMB_CODE2IDX 1625
# MORT_CEMB_CODE2IDX 1456
# READM_DEMB_CODE2IDX 1523
# READM_CEMB_CODE2IDX 1456
# Full set.
# READMISSION_PER_PATIENT_ICD_9_CODE2IDX_ len: 6555
# MORTALITY_PER_PATIENT_ICD_9_CODE2IDX_ len: 6732
# Full set.
# MORT_DEMB_CODE2IDX 6732
# MORT_CEMB_CODE2IDX 6546

# Run configuration: code_emb model, 'readm' task, 'mimic' database, full dataset.
args = DotArgs()
args.save_dir = os.path.join(os.getcwd(), 'modelcache')
# makedirs(exist_ok=True) replaces the exists()/mkdir() pair, which could race
# or fail when the directory appears between the check and the create.
os.makedirs(args.save_dir, exist_ok=True)
args.save_prefix = 'checkpoints'
args.random_seed = 90210
args.task = 'readm'
args.db_name = 'mimic'
args.is_dev = False  # if dev mode on we use a small subset of the full dataset.
assert args.task in task_types

# Model Args
# Load either a code_emb embed model or a desc_emb embed model.
args.embed_model_type = 'code_emb'
args.predict_model_type = 'code_emb'
args.collate_fn = None
# NOTE(review): 6546 is the full-set MORT_CEMB_CODE2IDX size; no full-set
# READM_CEMB_CODE2IDX size is recorded above -- confirm it is also 6546.
args.embed_index_size = 6546
args.pred_embed_size = 128

# Training Args
args.load_pretrained_weights = False
args.n_epochs = 10
args.learning_rate = 1.0e-4

In [None]:
# Back-of-envelope batch count (presumably samples / batch size * train split
# -- TODO confirm what the 0.8 refers to):
# 33426 / 32 * 0.8 = 837
# Build a Trainer from the current `args` and keep what train() returns for the
# plotting cell that follows.
trainer = Trainer(args)
cemb_readm_full_results = trainer.train()
# Drop the trainer and force a GC pass (presumably to release model/GPU memory
# before the next experiment -- TODO confirm).
del trainer
gc.collect()

In [None]:
# cemb_readm_full_results is indexed as [0] = validation results, [1] = test
# results, [2] = the value handed to PrintTrainTime.
# The [-1] below selects the last entry of each result list (presumably one
# entry per epoch, so the final epoch -- TODO confirm against Trainer.train()).
val_set_results = cemb_readm_full_results[0]
test_set_results = cemb_readm_full_results[1]
PrintTrainTime(cemb_readm_full_results[2])
PlotAucRecallResults(val_set_results[-1], test_set_results[-1],
                     emb_type='Cemb', task_type='Readmission')
PlotPrecisionRecallAcrossEpoch(val_set_results, test_set_results)
PlotAccuracyAcrossEpoch(cemb_readm_full_results)

### DescEmb - Readm - Full 

In [None]:
# Re-import the local modules so that edits made on disk are picked up before
# the run (the notebook also enables `%autoreload 2`; the explicit reloads are
# kept unchanged).
import importlib
import trainlib
import models
import datasets
importlib.reload(trainlib)
importlib.reload(models)
importlib.reload(datasets)
from trainlib import Trainer
from models import ModelA, EHRModel, CembRNN, DembRNN
from datasets import SimpleDataset, DatasetCacher

# Run configuration: desc_emb model, 'readm' task, 'mimic' database, full dataset.
args = DotArgs()
args.save_dir = os.path.join(os.getcwd(), 'modelcache')
# makedirs(exist_ok=True) replaces the exists()/mkdir() pair, which could race
# or fail when the directory appears between the check and the create.
os.makedirs(args.save_dir, exist_ok=True)
args.save_prefix = 'checkpoints'
args.random_seed = 90210
args.task = 'readm'
args.db_name = 'mimic'
args.is_dev = False  # if dev mode on we use a small subset of the full dataset.
assert args.task in task_types

# Model Args
# Load either a code_emb embed model or a desc_emb embed model.
args.embed_model_type = 'desc_emb'
args.predict_model_type = 'desc_emb'
args.override_batch_size = BATCH_SIZE_
args.collate_fn = bert_per_patient_collate_function_new_trainer
args.embed_index_size = 0  # unused mort_demb_metadata['extra_data']['embed_index_size']=768
args.pred_embed_size = 0  # unused

# Training Args
args.load_pretrained_weights = False
args.n_epochs = 10
args.learning_rate = 1.0e-4

In [None]:
# Back-of-envelope batch count (presumably samples / batch size * train split
# -- TODO confirm what the 0.8 refers to):
# 10000 / 32 * 0.8 = 250
# Build a Trainer from the current `args` and keep what train() returns for the
# plotting cell that follows.
trainer = Trainer(args)
demb_readm_full_results = trainer.train()
# Drop the trainer and force a GC pass (presumably to release model/GPU memory
# before the next experiment -- TODO confirm).
del trainer
gc.collect()

In [None]:
# demb_readm_full_results is indexed as [0] = validation results, [1] = test
# results, [2] = the value handed to PrintTrainTime.
# The [-1] below selects the last entry of each result list (presumably one
# entry per epoch, so the final epoch -- TODO confirm against Trainer.train()).
val_set_results = demb_readm_full_results[0]
test_set_results = demb_readm_full_results[1]
PrintTrainTime(demb_readm_full_results[2])
PlotAucRecallResults(val_set_results[-1], test_set_results[-1],
                     emb_type='Demb', task_type='Readmission')
PlotPrecisionRecallAcrossEpoch(val_set_results, test_set_results)
PlotAccuracyAcrossEpoch(demb_readm_full_results)
# Compare the two readmission runs; requires cemb_readm_full_results from the
# CodeEmb cells to still be defined in the notebook session.
labels = ['cemb', 'demb']
PlotDiffResults(cemb_readm_full_results, demb_readm_full_results, task_type='Readmission', labels=labels)

### DescEmbFt - Mort - Dev

In [None]:
# Re-import the local modules so that edits made on disk are picked up before
# the run (the notebook also enables `%autoreload 2`; the explicit reloads are
# kept unchanged).
import importlib
import trainlib
import models
import datasets
importlib.reload(trainlib)
importlib.reload(models)
importlib.reload(datasets)
from trainlib import Trainer
from models import ModelA, EHRModel, CembRNN, DembRNN
from datasets import SimpleDataset, DatasetCacher
import torch
# from GPUtil import showUtilization as gpu_usage

# Dev-set vocabulary sizes, kept for reference:
# MORT_DEMB_CODE2IDX 1625
# MORT_CEMB_CODE2IDX 1456

# Run configuration: desc_emb_ft model, 'mort' task, 'mimic' database, dev subset.
args = DotArgs()
args.save_dir = os.path.join(os.getcwd(), 'modelcache')
# makedirs(exist_ok=True) replaces the exists()/mkdir() pair, which could race
# or fail when the directory appears between the check and the create.
os.makedirs(args.save_dir, exist_ok=True)
args.save_prefix = 'checkpoints'
args.random_seed = 90210
args.task = 'mort'
args.db_name = 'mimic'
args.is_dev = True  # if dev mode on we use a small subset of the full dataset.
assert args.task in task_types

# Model Args
# Load either a code_emb embed model or a desc_emb embed model.
args.embed_model_type = 'desc_emb_ft'
args.predict_model_type = 'desc_emb_ft'
args.override_batch_size = BATCH_SIZE_
args.collate_fn = bert_fine_tune_collate
# NOTE(review): the same mortality_dembft_dataset object is passed for the dev
# and the full run -- confirm that Trainer subsamples it when is_dev is True.
args.no_use_cached_dataset = mortality_dembft_dataset
args.embed_index_size = 0  # unused mort_demb_metadata['extra_data']['embed_index_size']=768
args.pred_embed_size = 0  # unused

# Training Args
args.load_pretrained_weights = False
args.n_epochs = 10
args.learning_rate = 1.0e-4

In [None]:
# Release the previous run's Trainer and report CUDA memory before and after.
# Every CUDA call is guarded so the cell also runs on a host without a GPU.
if torch.cuda.is_available():
    torch.cuda.empty_cache()
    print(torch.cuda.memory_summary(device=None, abbreviated=False))
# `trainer` may already be gone (earlier training cells end with `del trainer`),
# so remove it with pop(); a plain `del trainer` would raise NameError here.
globals().pop('trainer', None)
gc.collect()
if torch.cuda.is_available():
    # Empty the allocator cache again now that the trainer has been collected,
    # so the second summary reflects the memory that was actually released.
    torch.cuda.empty_cache()
    print(torch.cuda.memory_summary(device=None, abbreviated=False))

In [None]:
# Build a Trainer from the current `args` and keep what train() returns for the
# plotting cell that follows. Unlike the earlier training cells, `trainer` is
# left alive here.
trainer = Trainer(args)
dembft_mort_dev_results = trainer.train()

In [None]:
# dembft_mort_dev_results is indexed as [0] = validation results, [1] = test
# results, [2] = the value handed to PrintTrainTime.
# The [-1] below selects the last entry of each result list (presumably one
# entry per epoch, so the final epoch -- TODO confirm against Trainer.train()).
val_set_results = dembft_mort_dev_results[0]
test_set_results = dembft_mort_dev_results[1]
PrintTrainTime(dembft_mort_dev_results[2])
PlotAucRecallResults(val_set_results[-1], test_set_results[-1],
                     emb_type='Dembft', task_type='Mortality')
PlotPrecisionRecallAcrossEpoch(val_set_results, test_set_results)
PlotAccuracyAcrossEpoch(dembft_mort_dev_results)

### DescEmbFt - Mort - Full 

In [None]:
# Re-import the local modules so that edits made on disk are picked up before
# the run (the notebook also enables `%autoreload 2`; the explicit reloads are
# kept unchanged).
import importlib
import trainlib
import models
import datasets
importlib.reload(trainlib)
importlib.reload(models)
importlib.reload(datasets)
from trainlib import Trainer
from models import ModelA, EHRModel, CembRNN, DembRNN
from datasets import SimpleDataset, DatasetCacher


# Dev-set vocabulary sizes, kept for reference:
# MORT_DEMB_CODE2IDX 1625
# MORT_CEMB_CODE2IDX 1456

# Run configuration: desc_emb_ft model, 'mort' task, 'mimic' database, full dataset.
args = DotArgs()
args.save_dir = os.path.join(os.getcwd(), 'modelcache')
# makedirs(exist_ok=True) replaces the exists()/mkdir() pair, which could race
# or fail when the directory appears between the check and the create.
os.makedirs(args.save_dir, exist_ok=True)
args.save_prefix = 'checkpoints'
args.random_seed = 90210
args.task = 'mort'
args.db_name = 'mimic'
args.is_dev = False  # if dev mode on we use a small subset of the full dataset.
assert args.task in task_types

# Model Args
# Load either a code_emb embed model or a desc_emb embed model.
args.embed_model_type = 'desc_emb_ft'
args.predict_model_type = 'desc_emb_ft'
args.override_batch_size = BATCH_SIZE_
args.collate_fn = bert_fine_tune_collate
args.no_use_cached_dataset = mortality_dembft_dataset
args.embed_index_size = 0  # unused mort_demb_metadata['extra_data']['embed_index_size']=768
args.pred_embed_size = 0  # unused


# Training Args
args.load_pretrained_weights = False
args.n_epochs = 10
args.learning_rate = 1.0e-4

In [None]:
# Release the previous run's Trainer and report CUDA memory before and after.
# Every CUDA call is guarded so the cell also runs on a host without a GPU.
if torch.cuda.is_available():
    torch.cuda.empty_cache()
    print(torch.cuda.memory_summary(device=None, abbreviated=False))
# `trainer` may not exist (e.g. after a kernel restart or when an earlier cell
# already deleted it), so remove it with pop(); `del trainer` would raise
# NameError in that case.
globals().pop('trainer', None)
gc.collect()
if torch.cuda.is_available():
    # Empty the allocator cache again now that the trainer has been collected,
    # so the second summary reflects the memory that was actually released.
    torch.cuda.empty_cache()
    print(torch.cuda.memory_summary(device=None, abbreviated=False))

In [None]:
# Build a Trainer from the current `args` and keep what train() returns for the
# plotting cell that follows. Unlike the earlier training cells, `trainer` is
# left alive here.
trainer = Trainer(args)
dembft_mort_full_results = trainer.train()

In [None]:
# dembft_mort_full_results is indexed as [0] = validation results, [1] = test
# results, [2] = the value handed to PrintTrainTime.
# The [-1] below selects the last entry of each result list (presumably one
# entry per epoch, so the final epoch -- TODO confirm against Trainer.train()).
val_set_results = dembft_mort_full_results[0]
test_set_results = dembft_mort_full_results[1]
PrintTrainTime(dembft_mort_full_results[2])
PlotAucRecallResults(val_set_results[-1], test_set_results[-1],
                     emb_type='Dembft', task_type='Mortality')
PlotPrecisionRecallAcrossEpoch(val_set_results, test_set_results)
PlotAccuracyAcrossEpoch(dembft_mort_full_results)

# eICU Eval

In [None]:
def eval_model(model, data_loader):
    """Run `model` over `data_loader` and compute binary-classification metrics.

    Args:
        model: module invoked as ``model(**sample_dict)``; its output is used
            as the score and thresholded at 0.5 to obtain the predicted label.
        data_loader: iterable of dicts; each dict must contain the key 'y'
            (ground-truth labels) and is passed whole to the model.

    Returns:
        A one-element list holding the tuple
        ``(precision, recall, f1, roc_auc, roc_curve, precision_curve,
        recall_curve, accuracy)``. Precision/recall/f1 use ``average='binary'``.

    Raises:
        ValueError: if `data_loader` yields no batches.
    """
    model.eval()
    score_chunks = []
    pred_chunks = []
    true_chunks = []
    # Inference only: skip autograd bookkeeping for the whole pass.
    with torch.no_grad():
        for sample_dict in tqdm.tqdm(data_loader):
            y = sample_dict['y']
            y_hat = model(**sample_dict).detach().to('cpu')
            score_chunks.append(y_hat)
            pred_chunks.append((y_hat > 0.5).long())
            true_chunks.append(y.detach().to('cpu'))
    if not score_chunks:
        raise ValueError('eval_model: data_loader produced no batches')

    # Concatenate once instead of growing the tensors on every batch.
    y_score = torch.cat(score_chunks, dim=0)
    y_pred = torch.cat(pred_chunks, dim=0)
    y_true = torch.cat(true_chunks, dim=0)

    p, r, f, _ = precision_recall_fscore_support(y_true, y_pred, average='binary')
    roc_auc = roc_auc_score(y_true, y_score)
    rcurve = roc_curve(y_true, y_score)
    precision_curve, recall_curve, _ = precision_recall_curve(y_true, y_score)
    accuracy = accuracy_score(y_true, y_pred)
    return [(p, r, f, roc_auc, rcurve, precision_curve, recall_curve, accuracy)]


def load_model_state_from_file(fname):
    """Load a saved checkpoint from ``<cwd>/modelcache/<fname>``.

    Args:
        fname: file name of the checkpoint inside the ``modelcache`` directory.

    Returns:
        Tuple ``(model_state_dict, args)`` read from the checkpoint's
        'model_state_dict' and 'args' entries.
    """
    path = os.path.join(os.getcwd(), 'modelcache', fname)
    # map_location='cpu' lets a checkpoint written on a GPU machine be read on a
    # CPU-only host; load_state_dict() copies the values into the model's own
    # parameters afterwards, so the model's device is unaffected.
    d = torch.load(path, map_location='cpu')
    model_state_dict = d['model_state_dict']
    args = d['args']
    print(f'Loading model state with args{args} from \n {path}')
    return model_state_dict, args


def load_model_from_file(embed_model_type: str,
                         task: str,
                         is_dev: bool):
    """Rebuild an EHRModel from its "best" checkpoint in the model cache.

    The checkpoint file name is derived from the embed model type, the task and
    the dev flag. The model is constructed from the args stored in the
    checkpoint, its weights are restored, and it is switched to eval mode.

    Returns:
        Tuple ``(model, args)`` where `args` are the ones stored in the checkpoint.
    """
    checkpoint_name = f'model_{embed_model_type}_task_{task}_isdev{is_dev}_best.pt'
    saved_state, saved_args = load_model_state_from_file(checkpoint_name)

    restored = EHRModel(saved_args)
    restored.load_state_dict(saved_state)
    restored.eval()
    return restored, saved_args

### DescEMB eICU Pred Task

In [None]:
# re_drug_prefix = re.compile('Event with eICU_DRUGNAME code{.*}from table medication')
# re_bar = re.compile('\|')
# ICD_9_LUT_ = {}
# ICD_10_LUT_ = {}
# fname = os.path.expanduser('~/sw/icd10cm-code descriptions- April 1 2023/icd10cm-codes- April 1 2023.txt')
# with open(fname, 'r') as f:
#     for l in f:
#         code, desc = l.split(sep=' ', maxsplit=1)
#         ICD_10_LUT_[code] = desc
        
# def eicu_mortality_pred_task_demb(CODE_COUNT, patient):
#     """
#     patient is a <pyhealth.data.Patient> object
#     """
#     samples = []
#     visits = []
#     kMaxListSize = 40
    
#     global_mortality_label = 0
#     # loop over all visits but the last one
#     for i in range(len(patient)):

#         # visit and next_visit are both <pyhealth.data.Visit> objects
#         # there are no vists.attr_dict keys
#         visit: Visit = patient[i]
#         mortality_label = 0 if visit.discharge_status == 'Alive' else 1
#         global_mortality_label |= mortality_label
        
#     # loop over all visits but the last one
#     for i, visit in enumerate(patient):
#         # visit: Visit.
        
#         # step 2: get code-based feature information
#         conditions = visit.get_code_list(table="diagnosis")
#         procedures = visit.get_code_list(table="treatment")
#         # drugs = [x.code for x in visit.get_event_list(table="medication")]
#         drugs_full = visit.get_event_list(table="medication")
#         drugs_full = [d.code for d in drugs_full]
#         # if i == 0: print([d.attr_dict for d in drugs_full])
#         # if i == 0: print(conditions)
#         # if i == 0: print(procedures)
#         # if i == 0: print(drugs)
#         # TODO(botelho3) - add this datasource back in once we have full MIMIC-III dataset.
#         # labevents = visit.get_code_list(table="LABEVENTS")

#         # step 3: exclusion criteria: visits without condition, procedure, or drug
#         if len(conditions) * len(procedures) == 0 * len(drugs_full) == 0:
#             # print(f'Excluded something 0 {len(conditions)}, {len(procedures)}, {len(drugs_full)}')
#             # print(f'conditions {conditions}')
#             # print(f'procedures {procedures}')
#             # print(f'drugs_full {drugs_full}')
#             continue
#         if len(conditions) + len(procedures) + len(drugs_full) < 5:
#             # Exclude stays with less than 5 procedures.
#             continue
        
#         # step 3.5: build text lists from the ICD codes
#         # diag_lut = mimic3base.get_text_lut("D_ICD_DIAGNOSES")
#         # proc_lut = mimic3base.get_text_lut("D_ICD_PROCEDURES")
        
#         # if i == 0: print(d_diag)
#         # if i == 0: print(d_proc)
#         # Index 0 is shortname, index 1 is longname.
#         # print([str(cond) + ' ' + str(d_diag.get(cond)) for cond in conditions])
#         # print(d_proc.get(procedures[0]))
#         # print(f'condition {conditions}')
#         # print(f'proc {procedures}')
#         # print(f'drugs {drugs_full}')
#         # conditions_text = [diag_lut.get(cond,("", ""))[1] for cond in conditions]
#         # procedures_text = [proc_lut.get(proc,("", ""))[1] for proc in procedures]
#         conditions = filter(lambda x: True if x[0].isalpha() else False, conditions)
#         conditions = [cond.replace('.', '') for cond in conditions]
#         conditions_text = [ICD_10_LUT_.get(cond, '') for cond in conditions] 
#         procedures_text = [re_bar.sub(' ', proc) for proc in procedures]
#         drugs_text = [re_drug_prefix.sub('\1', str(d)) for d in drugs_full]
#         # TODO(botelho3) - add the labevents data source back in once we have full MIMIC-III dataset.
#         # labevents_text =
        
#         # step 4: assemble the samples into a pyHealth Visit.
#         visits.append(
#             {
#                 "visit_id": visit.visit_id,
#                 "patient_id": patient.patient_id,
#                 # the following keys can be the "feature_keys" or "label_key" for initializing downstream ML model
#                 "conditions": conditions,
#                 "procedures": procedures,
#                 "conditions_text": conditions_text,
#                 "procedures_text": procedures_text,
#                 "drugs_text": drugs_text,
#                 # "labevents": labevents,
#                 # "labevents_text": labevents_text
#                 "label": global_mortality_label,
#             }
#         )
   
    
#     # Return empty list, didn't meet exclusion criteria.
#     num_visits = len(visits)
#     if num_visits < 1:
#         return [] 
    
   
#     # pyHealth requires that all list fields in sample are equal size.
#     def pad_field(field, visits, empty_val: Any):
#         l = [empty_val for x in range(kMaxListSize)]
#         data = [x[field] for x in visits]
#         data = list(itertools.chain.from_iterable(data))
#         slice_size = min(kMaxListSize, len(data))
#         l[:slice_size] = data[:slice_size]
#         return l, slice_size
    
#     conditions, conditions_pad = pad_field("conditions", visits, '0')
#     conditions_text, conditions_text_pad = pad_field("conditions_text", visits, '')
#     procedures, procedures_pad = pad_field("procedures", visits, '0')
#     procedures_text, procedures_text_pad = pad_field("procedures_text", visits, '')
#     drugs_text, drugs_text_pad = pad_field("drugs_text", visits, '')
#     sample = {
#         "patient_id": patient.patient_id,
#         # TODO(botelho3) Why does pyhealth require a visit id in the keys if we're combining vists?
#         "visit_id": visits[0]["visit_id"],
#         "num_visits": num_visits,
#         # the following keys can be the "feature_keys" or "label_key" for initializing downstream ML model
#         "conditions": conditions,
#         "conditions_text": conditions_text,
#         # "procedures": procedures,
#         "procedures_text": procedures_text,
#         "drugs_text": drugs_text,
        
#         "conditions_pad": conditions_pad,
#         "procedures_pad": procedures_pad,
#         "conditions_text_pad": conditions_text_pad,
#         "procedures_text_pad": procedures_text_pad,
#         "drugs_text_pad": drugs_text_pad,
#         # "labevents": labevents,
#         # "labevents_text": labevents_text
#         "label": global_mortality_label,
#     }
   
#     # For every condition in the sample (all visits). Record frequency.
#     # Will be used to build code->index LUT.
#     for code in sample['conditions']:
#         CODE_COUNT[code] = CODE_COUNT.get(code, 0) + 1
       
#     # if len(CODE_COUNT) in [10,11,12]:
#     #     print(sample)
#     samples.append(sample)
#     return samples

### Test Load Mortality Dataset

In [None]:
# Disabled smoke test (guarded by `if False`): builds the eICU mortality dataset
# via eicubase.set_task and derives a code -> index table from the code counts
# that the task function fills in as a side effect.
# NOTE(review): the globals are named ICD_9 although this is the eICU task --
# confirm which coding system the codes actually use.
if False:
    EICU_MORTALITY_PER_PATIENT_ICD_9_CODE_COUNT_ = {}
    EICU_MORTALITY_PER_PATIENT_ICD_9_CODE2IDX_ = {}
    # Bind the count dict as the task function's first argument.
    task_fn = functools.partial(eicu_mortality_pred_task_demb, EICU_MORTALITY_PER_PATIENT_ICD_9_CODE_COUNT_)
    eicu_mor_dataset = eicubase.set_task(task_fn, task_name=eicu_mortality_pred_task_demb.__name__)
    # Indices are assigned in sorted code order.
    EICU_MORTALITY_PER_PATIENT_ICD_9_CODE2IDX_ = {
        code: idx for idx, code in enumerate(sorted(EICU_MORTALITY_PER_PATIENT_ICD_9_CODE_COUNT_.keys()))
    }
    eicu_mor_dataset.stat()
    # Bare expression; its value is discarded here.
    eicu_mor_dataset.samples[1]
    # TODO(botelho3) could try a freq codes limit on this.
    # print(f"MORTALITY_PER_PATIENT_ICD_9_CODE2IDX_ len: {len(EICU_MORTALITY_PER_PATIENT_ICD_9_CODE2IDX_)}\n"
    #       f"{EICU_MORTALITY_PER_PATIENT_ICD_9_CODE2IDX_}")
    del eicu_mor_dataset

### Load DescEmb - Mort - Full

In [None]:
# Restore the desc_emb / mort / is_dev=False "best" checkpoint from the model
# cache; the args returned alongside the model are discarded.
desc_emb_mort_full_model, _ = load_model_from_file('desc_emb', 'mort', is_dev=False) 
print(desc_emb_mort_full_model)

### Eval DescEmb - Mort - Full - eICU

In [None]:
# Score the restored desc_emb mortality model on the eICU loader.
# NOTE(review): eicu_mort_demb_loader is not defined in this part of the
# notebook -- confirm it is created by an earlier cell.
demb_mort_eicu_results = eval_model(desc_emb_mort_full_model, eicu_mort_demb_loader)

In [None]:
# eval_model returns a one-element list, so [-1] is its single metrics tuple.
# None is passed in the validation-results slot (only one result set exists).
PlotAucRecallResults(None, demb_mort_eicu_results[-1],
                     emb_type='Demb', task_type='Mortality eICU')
PrintFinalAccuracy(demb_mort_eicu_results)

### Load DescEmbFt - Mort - Full

In [None]:
# Restore the desc_emb_ft / mort / is_dev=False "best" checkpoint from the model
# cache; the args returned alongside the model are discarded.
demb_ft_mort_full_model, _ = load_model_from_file('desc_emb_ft', 'mort', is_dev=False) 
print(demb_ft_mort_full_model)

### Eval DescEmbFt - Mort - Full - eICU

In [None]:
# Transform applied to every sample (visit): turns the visit's text description
# into the input used by the fine-tuned BERT embedding model.
bert_ft_xform = BertFineTuneTransform()
BERT_FT_EMBEDDING_SIZE = bert_ft_xform.emb_size

# Mortality Task Datasets: set_task -> SampleEHRDataset(SampleBaseDataset).
# CODE_COUNT is bound as the task function's first argument and filled while
# set_task runs; CODE2IDX is then built from its keys in sorted order.
CODE2IDX = {}
CODE_COUNT = {}
task_fn = functools.partial(eicu_mortality_pred_task_demb, CODE_COUNT)
eicu_mortality_dembft_dataset = eicubase.set_task(
    task_fn,
    task_name=eicu_mortality_pred_task_demb.__name__,
)
CODE2IDX = {code: idx for idx, code in enumerate(sorted(CODE_COUNT))}

# Wrap the pyHealth dataset in our own dataset class so that the transform is
# applied to each sample (text -> pytorch.tensor); caching is disabled.
eicu_mortality_dembft_dataset = TextEmbedDataset(
    eicu_mortality_dembft_dataset,
    transform=bert_ft_xform,
    should_cache=False,
)

# Keep at most the first 50000 samples.
indices = range(min(50000, len(eicu_mortality_dembft_dataset)))
eicu_mortality_dembft_dataset = torch.utils.data.Subset(eicu_mortality_dembft_dataset, indices)
print(len(eicu_mortality_dembft_dataset))

In [None]:
# eicu_mort_dembft_loader = DataLoader(
#     torch.utils.data.Subset(eicu_mortality_dembft_dataset, indices),
#     batch_size=BATCH_SIZE_,
#     shuffle=SHUFFLE_,
#     collate_fn=bert_fine_tune_collate,
# )

# loader_iter = iter(eicu_mort_dembft_loader)
# # for _ in loader_iter:
# #     pass
# try:
#     x, masks, rev_x, rev_masks, y = next(loader_iter)
# except StopIteration as e:
#     print(e)

In [None]:
# Score the restored desc_emb_ft mortality model on the eICU loader.
# NOTE(review): the only construction of eicu_mort_dembft_loader visible in this
# part of the notebook is commented out (previous cell); unless an earlier cell
# defines it, this line raises NameError -- confirm.
dembft_mort_eicu_results = eval_model(demb_ft_mort_full_model, eicu_mort_dembft_loader)

In [None]:
# eval_model returns a one-element list, so [-1] is its single metrics tuple.
# None is passed in the validation-results slot (only one result set exists).
PlotAucRecallResults(None, dembft_mort_eicu_results[-1],
                     emb_type='Dembft', task_type='Mortality eICU')
PrintFinalAccuracy(dembft_mort_eicu_results)

# Condensed Training using Trainer

### DescEmb - Mort - Full - eICU (eval only)

In [None]:
# Re-import the local modules so that edits made on disk are picked up before
# the run (the notebook also enables `%autoreload 2`; the explicit reloads are
# kept unchanged).
import importlib
import trainlib
import models
import datasets
importlib.reload(trainlib)
importlib.reload(models)
importlib.reload(datasets)
from trainlib import Trainer
from models import ModelA, EHRModel, CembRNN, DembRNN
from datasets import SimpleDataset, DatasetCacher

# Dev-set vocabulary sizes, kept for reference:
# MORT_DEMB_CODE2IDX 1625
# MORT_CEMB_CODE2IDX 1456

# Run configuration: desc_emb model, 'mort' task, 'eicu' database, eval only.
args = DotArgs()
args.save_dir = os.path.join(os.getcwd(), 'modelcache')
# makedirs(exist_ok=True) replaces the exists()/mkdir() pair, which could race
# or fail when the directory appears between the check and the create.
os.makedirs(args.save_dir, exist_ok=True)
args.save_prefix = 'checkpoints'
args.random_seed = 90210
args.task = 'mort'
args.db_name = 'eicu'
args.is_dev = False  # if dev mode on we use a small subset of the full dataset.
assert args.task in task_types

# Model Args
# Load either a code_emb embed model or a desc_emb embed model.
args.eval_only = True
args.embed_model_type = 'desc_emb'
args.predict_model_type = 'desc_emb'
args.override_batch_size = BATCH_SIZE_
args.collate_fn = bert_per_patient_collate_function_new_trainer
args.embed_index_size = 0  # unused mort_demb_metadata['extra_data']['embed_index_size']=768
args.pred_embed_size = 0  # unused

# Training Args
# NOTE(review): with eval_only=True these are presumably ignored by Trainer --
# confirm.
args.load_pretrained_weights = False
args.n_epochs = 10
args.learning_rate = 1.0e-4

In [None]:
# Build a Trainer from the current `args`, then replace its model with the
# desc_emb mortality model restored from the model cache before calling train().
# NOTE(review): args.eval_only is True for this run, so train() presumably only
# evaluates -- confirm in trainlib.
trainer = Trainer(args)
desc_emb_mort_full_model, _ = load_model_from_file('desc_emb', 'mort', is_dev=False) 
trainer.model = desc_emb_mort_full_model
eicu_demb_mort_full_results = trainer.train()
# Drop the trainer and force a GC pass (presumably to release model/GPU memory
# before the next experiment -- TODO confirm).
del trainer
gc.collect()

In [None]:
# eicu_demb_mort_full_results is indexed as [1] = test results and [2] = the
# value handed to PrintTrainTime; index [0] is not used here.
# [-1] selects the last entry of the test results (presumably one entry per
# epoch -- TODO confirm). None is passed in the validation-results slot.
test_set_results = eicu_demb_mort_full_results[1]
PrintTrainTime(eicu_demb_mort_full_results[2])
PlotAucRecallResults(None, test_set_results[-1],
                     emb_type='Demb eICU', task_type='Mortality')
PrintFinalAccuracy(test_set_results)

### DescEmbFt - Mort - Full - eICU (eval only)

In [None]:
# Re-import the local modules so that edits made on disk are picked up before
# the run (the notebook also enables `%autoreload 2`; the explicit reloads are
# kept unchanged).
import importlib
import trainlib
import models
import datasets
importlib.reload(trainlib)
importlib.reload(models)
importlib.reload(datasets)
from trainlib import Trainer
from models import ModelA, EHRModel, CembRNN, DembRNN
from datasets import SimpleDataset, DatasetCacher

# Run configuration: desc_emb_ft model, 'mort' task, 'eicu' database, eval only.
args = DotArgs()
args.save_dir = os.path.join(os.getcwd(), 'modelcache')
# makedirs(exist_ok=True) replaces the exists()/mkdir() pair, which could race
# or fail when the directory appears between the check and the create.
os.makedirs(args.save_dir, exist_ok=True)
args.save_prefix = 'checkpoints'
args.random_seed = 90210
args.task = 'mort'
args.db_name = 'eicu'
args.is_dev = False  # if dev mode on we use a small subset of the full dataset.
assert args.task in task_types

# Model Args
# Load either a code_emb embed model or a desc_emb embed model.
args.eval_only = True
args.embed_model_type = 'desc_emb_ft'
args.predict_model_type = 'desc_emb_ft'
args.override_batch_size = BATCH_SIZE_
args.collate_fn = bert_fine_tune_collate
# Requires the cell that builds eicu_mortality_dembft_dataset to have run.
args.no_use_cached_dataset = eicu_mortality_dembft_dataset
args.embed_index_size = 0  # unused mort_demb_metadata['extra_data']['embed_index_size']=768
args.pred_embed_size = 0  # unused


# Training Args
# NOTE(review): with eval_only=True these are presumably ignored by Trainer --
# confirm.
args.load_pretrained_weights = False
args.n_epochs = 10
args.learning_rate = 1.0e-4

In [None]:
# Release the previous run's Trainer and report CUDA memory before and after.
# Every CUDA call is guarded so the cell also runs on a host without a GPU.
if torch.cuda.is_available():
    torch.cuda.empty_cache()
    print(torch.cuda.memory_summary(device=None, abbreviated=False))
# `trainer` may already be gone (the preceding eICU DescEmb cell ends with
# `del trainer`), so remove it with pop(); a plain `del trainer` would raise
# NameError here.
globals().pop('trainer', None)
gc.collect()
if torch.cuda.is_available():
    # Empty the allocator cache again now that the trainer has been collected,
    # so the second summary reflects the memory that was actually released.
    torch.cuda.empty_cache()
    print(torch.cuda.memory_summary(device=None, abbreviated=False))

In [None]:
# Build a Trainer from the current `args`, then replace its model with the
# desc_emb_ft mortality model restored from the model cache before calling
# train(). `trainer` is left alive after this cell.
# NOTE(review): args.eval_only is True for this run, so train() presumably only
# evaluates -- confirm in trainlib.
trainer = Trainer(args)
demb_ft_mort_full_model, _ = load_model_from_file('desc_emb_ft', 'mort', is_dev=False) 
trainer.model = demb_ft_mort_full_model
eicu_dembft_mort_full_results = trainer.train()

In [None]:
# eicu_dembft_mort_full_results is indexed as [1] = test results and [2] = the
# value handed to PrintTrainTime; index [0] is not used here.
# [-1] selects the last entry of the test results (presumably one entry per
# epoch -- TODO confirm). None is passed in the validation-results slot.
test_set_results = eicu_dembft_mort_full_results[1]
PrintTrainTime(eicu_dembft_mort_full_results[2])
PlotAucRecallResults(None, test_set_results[-1],
                     emb_type='Dembft eICU', task_type='Mortality')
PrintFinalAccuracy(test_set_results)