This notebook shows an example of using an LSTM to model `In-hospital mortality` (IHM) on the MIMIC-III dataset.

The data is presumed to have already been extracted for the cohort and described via a `yaml` configuration, as below:

```yaml

# USER DEFINED
tgt_col: y_true
idx_cols: stay
time_order_col: 
    - Hours
    - seqnum

feat_cols: null

train:
    tgt_file: '{DATA_DIR}/IHM_V0_COHORT_OUT_EXP-SPLIT0-train.csv'
    feat_file: '{DATA_DIR}/IHM_V0_FEAT_EXP-SPLIT0-train.csv'

val:
    tgt_file: '{DATA_DIR}/IHM_V0_COHORT_OUT_EXP-SPLIT0-val.csv'
    feat_file: '{DATA_DIR}/IHM_V0_FEAT_EXP-SPLIT0-val.csv'

test:
    tgt_file: '{DATA_DIR}/IHM_V0_COHORT_OUT_EXP-SPLIT0-test.csv'
    feat_file: '{DATA_DIR}/IHM_V0_FEAT_EXP-SPLIT0-test.csv'

# DATA DEFINITIONS

## Definitions of categorical data in the dataset
category_map:
  Capillary refill rate: ['0.0', '1.0']
  Glascow coma scale eye opening: ['To Pain', '3 To speech', '1 No Response', '4 Spontaneously',
                                   'To Speech', 'Spontaneously', '2 To pain', 'None'] 
  Glascow coma scale motor response: ['1 No Response' , '3 Abnorm flexion' , 'Abnormal extension' , 'No response',
                                      '4 Flex-withdraws' , 'Localizes Pain' , 'Flex-withdraws' , 'Obeys Commands',
                                      'Abnormal Flexion' , '6 Obeys Commands' , '5 Localizes Pain' , '2 Abnorm extensn']
  Glascow coma scale total: ['11', '10', '13', '12', '15', '14', '3', '5', '4', '7', '6', '9', '8']
  Glascow coma scale verbal response: ['1 No Response', 'No Response', 'Confused', 'Inappropriate Words', 'Oriented', 
                                       'No Response-ETT', '5 Oriented', 'Incomprehensible sounds', '1.0 ET/Trach', 
                                       '4 Confused', '2 Incomp sounds', '3 Inapprop words']

numerical: ['Heart Rate', 'Fraction inspired oxygen', 'Weight', 'Respiratory rate', 
            'pH', 'Diastolic blood pressure', 'Glucose', 'Systolic blood pressure',
            'Height', 'Oxygen saturation', 'Temperature', 'Mean blood pressure']

## Definitions of normal values in the dataset
normal_values:
  Capillary refill rate: 0.0
  Diastolic blood pressure: 59.0
  Fraction inspired oxygen: 0.21
  Glucose: 128.0
  Heart Rate: 86
  Height: 170.0
  Mean blood pressure: 77.0
  Oxygen saturation: 98.0
  Respiratory rate: 19
  Systolic blood pressure: 118.0
  Temperature: 36.6
  Weight: 81.0
  pH: 7.4
  Glascow coma scale eye opening: '4 Spontaneously'
  Glascow coma scale motor response: '6 Obeys Commands'
  Glascow coma scale total:  '15'
  Glascow coma scale verbal response: '5 Oriented'
```
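
The `{DATA_DIR}` entries in the configuration are not YAML syntax. They are placeholders that the notebook fills in with `str.format` before the text is parsed. A minimal sketch of that substitution step, using a hypothetical directory `/tmp/mimic` and a shortened template:

```python
from pathlib import PurePosixPath

# Shortened copy of the train section of the configuration above.
template = (
    "train:\n"
    "    tgt_file: '{DATA_DIR}/IHM_V0_COHORT_OUT_EXP-SPLIT0-train.csv'\n"
    "    feat_file: '{DATA_DIR}/IHM_V0_FEAT_EXP-SPLIT0-train.csv'\n"
)

# Hypothetical data directory, for illustration only.
data_dir = PurePosixPath("/tmp/mimic")

# str.format replaces every {DATA_DIR} with str(data_dir).
rendered = template.format(DATA_DIR=data_dir)
print(rendered)
```

Because `str.format` treats every brace as a placeholder, any literal `{` or `}` added to the configuration later would need to be doubled (`{{`, `}}`).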

## Preamble

In [None]:
# Jupyter notebook Specific imports
%matplotlib inline

import warnings
warnings.filterwarnings('ignore')

# Imports injecting into namespace
from tqdm.auto import tqdm
tqdm.pandas()

import sys
sys.path.append('../../')

# General imports
import os
import json
import pickle
from pathlib import Path

import pandas as pd
import numpy as np
from getpass import getpass
import argparse

from sklearn.preprocessing import StandardScaler
from sklearn.exceptions import NotFittedError

import torch as T
from torch import nn
from pytorch_lightning import Trainer

from lightsaber import constants as C
import lightsaber.data_utils.utils as du
from lightsaber.data_utils import pt_dataset as ptd
from lightsaber.trainers import pt_trainer as ptr

from lightsaber.model_lib.pt_sota_models import rnn

In [None]:
import logging
log = logging.getLogger()

In [None]:
data_dir = Path(getpass())  # enter or REPLACE with your data path containing the mimic files
assert data_dir.is_dir()

expt_conf = du.yaml.load(open('./ihm_expt_config.yml').read().format(DATA_DIR=data_dir),
                         Loader=du._Loader)

### Data Transformation functions

Transform and filter functions allow the data to be processed at run time.
Users can either use the pre-packaged filters/transforms or write their own and pass
them in at run time.
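
As a rough sketch of the idea, the example below assumes that a filter is a callable taking and returning `(data, target)` and that filters are applied in list order. That is an assumption about `lightsaber` internals, not something confirmed here. `functools.partial` stands in for currying, and `fill_missing`, `clip_values`, and `apply_filters` are hypothetical names operating on plain lists rather than data frames:

```python
from functools import partial


def fill_missing(data, target, fill_value=0.0):
    """Hypothetical filter: replace None entries with fill_value."""
    filled = [[fill_value if v is None else v for v in row] for row in data]
    return filled, target


def clip_values(data, target, low=0.0, high=1.0):
    """Hypothetical filter: clip every value into [low, high]."""
    clipped = [[min(max(v, low), high) for v in row] for row in data]
    return clipped, target


def apply_filters(data, target, filters):
    """Apply each filter in order, threading (data, target) through."""
    for f in filters:
        data, target = f(data, target)
    return data, target


# Options are bound up front; the data arrives later, at run time.
filters = [partial(fill_missing, fill_value=0.5),
           partial(clip_values, low=0.0, high=1.0)]

data, target = apply_filters([[None, 2.0], [-1.0, 0.3]], [0, 1], filters)
print(data, target)
```

Binding the options first means each filter can be configured once and then reused unchanged across datasets.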

In [None]:
@ptd.functoolz.curry
def filter_fillna(data, target, fill_value=0., time_order_col=None):
    data = data.copy()
    
    idx_cols = data.index.names
    if time_order_col is not None:
        try:
            sort_cols = idx_cols + time_order_col
        except TypeError:
            # time_order_col is a single column name, not a list
            sort_cols = idx_cols + [time_order_col]
    else:
        sort_cols = idx_cols
    
    data.update(data.reset_index()
               .sort_values(sort_cols)
               .groupby(idx_cols[0])
               .ffill())
    
    data.fillna(fill_value, inplace=True)
        
    return data, target


@ptd.functoolz.curry
def filter_preprocessor(data, target, cols=None, preprocessor=None, refit=False):
    if preprocessor is not None:
        all_columns = data.columns
        index = data.index

        # Extracting the columns to fit
        if cols is None:
            cols = all_columns
        _oCols = all_columns.difference(cols)
        xData = data[cols]
    
        # If fit required fitting it
        if refit:
            preprocessor.fit(xData)
            log.info(f'Fitting pre-proc: {preprocessor}')
  
        # Transforming data to be transformed
        try:
            xData = preprocessor.transform(xData)
        except NotFittedError:
            raise Exception(f"{preprocessor} not fitted. pass fitted preprocessor or set refit=True")
        xData = pd.DataFrame(columns=cols, data=xData, index=index)
        
        # Merging other columns if required
        if not _oCols.empty:
            tmp = pd.DataFrame(data=data[_oCols].values, 
                               columns=_oCols,
                               index=index)
            xData = pd.concat((tmp, xData), axis=1)
        
        # Re-ordering the columns to original order
        data = xData[all_columns]
    return data, target
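
The fill behaviour that `filter_fillna` aims for (carry the last observed value forward within a stay, then fall back to a default for anything still missing) can be illustrated in plain Python. This is a simplified sketch on lists, not the pandas implementation above; `ffill_then_default` and the sample pH readings are made up for illustration:

```python
def ffill_then_default(values, default):
    """Forward-fill None within one stay; use `default` for leading gaps."""
    filled, last = [], None
    for value in values:
        if value is not None:
            last = value
        filled.append(last if last is not None else default)
    return filled


# Hypothetical pH readings per stay, ordered by time; None marks a missing value.
stays = {
    "stay_a": [None, 7.3, None, 7.5, None],
    "stay_b": [None, None],
}
normal_ph = 7.4  # the `pH` entry of `normal_values` in the configuration

result = {stay: ffill_then_default(vals, normal_ph)
          for stay, vals in stays.items()}
print(result)
```

Filling is done per stay so that a value observed for one patient never leaks into another patient's record.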

## IHM Example

In general, a user needs to follow these steps to train an `LSTM` model for IHM:

1. Define the `filters` and `transforms` to be used. In this example, we will use a `StandardScaler` from `scikit-learn` using `filters` defined within `lightsaber`. 
2. Read the `train`, `test`, and `validation` datasets. In some cases, users may also want to define a `calibration dataset`.
3. Define the model. In this example, we will use a pre-packaged `LSTM` model.
4. Use `lightsaber` to chain the model via `pytorch-trainer` and generate metrics. 

### Reading data along with usage of pre-processor

In [None]:
preprocessor = StandardScaler()
train_filter = [filter_preprocessor(cols=expt_conf['numerical'], 
                                    preprocessor=preprocessor,
                                    refit=True),
                filter_fillna(fill_value=expt_conf['normal_values'],
                              time_order_col=expt_conf['time_order_col'])
                ]
transform = ptd.transform_drop_cols(cols_to_drop=expt_conf['time_order_col'])

In [None]:
train_dataset = ptd.BaseDataset(tgt_file=expt_conf['train']['tgt_file'],
                                feat_file=expt_conf['train']['feat_file'],
                                idx_col=expt_conf['idx_cols'],
                                tgt_col=expt_conf['tgt_col'],
                                feat_columns=expt_conf['feat_cols'],
                                time_order_col=expt_conf['time_order_col'],
                                category_map=expt_conf['category_map'],
                                transform=transform,
                                filter=train_filter,
                               )
# print(train_dataset.data.head())
print(train_dataset.shape, len(train_dataset))

In [None]:
# For other datasets use fitted preprocessors
fitted_filter = [filter_preprocessor(cols=expt_conf['numerical'], 
                                     preprocessor=preprocessor, refit=False),
                 filter_fillna(fill_value=expt_conf['normal_values'],
                               time_order_col=expt_conf['time_order_col'])
                 ]

val_dataset = ptd.BaseDataset(tgt_file=expt_conf['val']['tgt_file'],
                              feat_file=expt_conf['val']['feat_file'],
                              idx_col=expt_conf['idx_cols'],
                              tgt_col=expt_conf['tgt_col'],
                              feat_columns=expt_conf['feat_cols'],
                              time_order_col=expt_conf['time_order_col'],
                              category_map=expt_conf['category_map'],
                              transform=transform,
                              filter=fitted_filter,
                              )
print(val_dataset.shape, len(val_dataset))

test_dataset = ptd.BaseDataset(tgt_file=expt_conf['test']['tgt_file'],
                               feat_file=expt_conf['test']['feat_file'],
                               idx_col=expt_conf['idx_cols'],
                               tgt_col=expt_conf['tgt_col'],
                               feat_columns=expt_conf['feat_cols'],
                               time_order_col=expt_conf['time_order_col'],
                               category_map=expt_conf['category_map'],
                               transform=transform,
                               filter=fitted_filter,
                               )
print(test_dataset.shape, len(test_dataset))
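
The validation and test datasets use `refit=False` so that the scaling statistics come from the training split only. A minimal plain-Python sketch of that fit-once, transform-many pattern follows; `TinyScaler` is a hypothetical stand-in for `StandardScaler`, and the heart-rate values are made up:

```python
from statistics import fmean, pstdev


class TinyScaler:
    """Hypothetical single-feature scaler: subtract mean, divide by std."""

    def fit(self, values):
        self.mean_ = fmean(values)
        # Fall back to 1.0 for a constant feature to avoid dividing by zero.
        self.scale_ = pstdev(values) or 1.0
        return self

    def transform(self, values):
        if not hasattr(self, "mean_"):
            raise RuntimeError("scaler is not fitted")
        return [(v - self.mean_) / self.scale_ for v in values]


train_hr = [60.0, 80.0, 100.0]  # made-up training values
val_hr = [80.0, 120.0]          # made-up validation values

scaler = TinyScaler().fit(train_hr)         # statistics from train only
train_scaled = scaler.transform(train_hr)
val_scaled = scaler.transform(val_hr)       # reuses the train mean and std
print(train_scaled, val_scaled)
```

Refitting on the validation or test split would let information from held-out data shape the features and would put the splits on different scales.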

In [None]:
# For most models you need to change only this part
input_dim, target_dim = train_dataset.shape
output_dim = 2

weight_labels = train_dataset.target.iloc[:, 0].value_counts()
weight_labels = (weight_labels.max() / ((weight_labels + 0.0000001) ** (1)))
weight_labels.sort_index(inplace=True)
weights = T.FloatTensor(weight_labels.values).to(train_dataset.device)
print(weights)
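
The weighting computed above gives each class a weight of `max_count / count`, ordered by class label so that position `i` corresponds to class `i`. A plain-Python sketch of the same arithmetic, using a made-up label list and omitting the small epsilon added in the cell above:

```python
from collections import Counter

# Made-up labels: six negatives (0) and two positives (1).
labels = [0, 0, 0, 0, 0, 0, 1, 1]

counts = Counter(labels)
max_count = max(counts.values())

# Sort by class label so that weights[i] belongs to class i.
weights = [max_count / counts[label] for label in sorted(counts)]
print(weights)
```

The majority class ends up with weight 1.0 and rarer classes with proportionally larger weights, so errors on the minority class contribute more to the weighted loss.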

### Defining the user model

In [None]:
# For most models you need to change only this part
hparams = argparse.Namespace(gpus=[0],
                             lr=0.01,
                             max_epochs=100,
                             batch_size=32,
                             hidden_dim=32,
                             rnn_class='LSTM',
                             n_layers=1,
                             dropout=0.1,
                             recurrent_dropout=0.1,
                             bidirectional=False,
                             )


base_model = rnn.RNNClassifier(input_dim, output_dim, 
                               hidden_dim=hparams.hidden_dim,
                               rnn_class=hparams.rnn_class,
                               n_layers=hparams.n_layers,
                               dropout=hparams.dropout,
                               recurrent_dropout=hparams.recurrent_dropout,
                               bidirectional=hparams.bidirectional
                              )

criterion = nn.CrossEntropyLoss(weight=weights)
# optimizer = T.optim.Adam(base_model.parameters(),
#                          lr=hparams.lr,
#                          weight_decay=1e-5  # standard value
#                          )

# scheduler = T.optim.lr_scheduler.ReduceLROnPlateau(optimizer, 'min')



### Lightsaber to run training and evaluate

This part entails a few steps:

* create a wrapped model that takes in a base pytorch model and adds the training routines to the model
* associate a trainer for the wrapped model
* run training on the model with model tracking enabled

In [None]:
# Creating the wrapped model
wrapped_model = ptr.PyModel(hparams, base_model,
                            train_dataset=train_dataset,
                            val_dataset=val_dataset, # None
                            test_dataset=test_dataset, # test_dataset
                            #optimizer=optimizer,
                            loss_func=criterion,
                            #scheduler=scheduler,
                            collate_fn=ptd.collate_fn
                            )

In [None]:
# Training 
overfit_pct, fast_dev_run, terminate_on_nan, auto_lr_find = 0, True, True, False

trainer = Trainer(max_epochs=hparams.max_epochs, 
                  gpus=hparams.gpus,
                  default_root_dir=os.path.join('./out/', 'classifier_ihm'),
                  terminate_on_nan=terminate_on_nan,
                  auto_lr_find=auto_lr_find,
                  overfit_pct=overfit_pct,
                  fast_dev_run=fast_dev_run  # True if debugging
                 )

In [None]:
# Run Training with model tracking
mlflow_conf = dict(experiment_name='classifier_ihm')
artifacts = dict(preprocessor=preprocessor)
experiment_tags = dict(model='RNNClassifier')

(metrics, test_y, 
 test_yhat, test_pred_proba) = ptr.run_training_with_mlflow(mlflow_conf, 
                                                            trainer, 
                                                            wrapped_model, 
                                                            overfit_pct=overfit_pct,
                                                            artifacts=artifacts,
                                                            **experiment_tags)


print(metrics)