In [None]:
# Jupyter notebook Specific imports
%matplotlib inline

import warnings
warnings.filterwarnings('ignore')

In [None]:
# Imports injecting into namespace
from tqdm.auto import tqdm
tqdm.pandas()

import sys
sys.path.append('../../')

In [None]:
# General imports
import os
import json
import pickle
from pathlib import Path

import pandas as pd
import numpy as np
from getpass import getpass
import argparse

from sklearn.preprocessing import StandardScaler
from sklearn.exceptions import NotFittedError

import torch as T
from torch import nn
from pytorch_lightning import Trainer

from lightsaber import constants as C
import lightsaber.data_utils.utils as du
from lightsaber.data_utils import pt_dataset as ptd
from lightsaber.trainers import pt_trainer as ptr

from lightsaber.model_lib.pt_sota_models import rnn

In [None]:
import logging
# getLogger() without a name returns the root logger; filter_preprocessor
# reports through it with log.info().
# NOTE(review): no handler or level is configured in this notebook, so INFO
# records may not be displayed — confirm whether that is intended.
log = logging.getLogger()

In [None]:
# Enter (or REPLACE the getpass() call with) the data path containing the
# mimic files. getpass() does not echo what is typed.
data_dir = Path(getpass())
assert data_dir.is_dir(), 'the entered data path is not an existing directory'

# The YAML config is a template: its DATA_DIR placeholder is substituted with
# the path entered above before parsing. Path.read_text() opens and closes the
# file itself, so no file handle is left open in the kernel.
# NOTE(review): parsing uses du._Loader — keep the config file trusted.
expt_conf = du.yaml.load(
    Path('./ihm_expt_config.yml').read_text().format(DATA_DIR=data_dir),
    Loader=du._Loader,
)

### Data Transformation functions

Transform/filter functions process the data at run time.
Users can either use the pre-packaged filters/transforms or write their own
and pass them in at run time.

In [None]:
@ptd.functoolz.curry
def filter_fillna(data, target, fill_value=0., time_order_col=None):
    """Forward-fill missing values within each group, then fill what remains.

    Rows are grouped by the first index level of ``data``. Before
    forward-filling they are ordered by all index levels followed by
    ``time_order_col`` (when given). Values that are still missing afterwards
    are replaced with ``fill_value``.

    Parameters
    ----------
    data : pd.DataFrame
        Frame to fill. It is not modified; a new frame with the same index,
        columns and row order is returned.
    target : object
        Passed through unchanged.
    fill_value : scalar, dict or Series
        Forwarded to ``DataFrame.fillna`` for values the forward-fill could
        not resolve.
    time_order_col : label or list-like of labels, optional
        Column(s) giving the within-group ordering.

    Returns
    -------
    tuple
        ``(filled_data, target)``.
    """
    # Index levels become the leading columns of the flattened frame, which
    # carries a unique positional RangeIndex.
    flat = data.reset_index()
    idx_cols = list(flat.columns[:data.index.nlevels])

    # Accept either a single column label or a list-like of labels.
    if time_order_col is None:
        time_cols = []
    elif pd.api.types.is_list_like(time_order_col):
        time_cols = list(time_order_col)
    else:
        time_cols = [time_order_col]
    # De-duplicate while keeping order, in case a time column is also an index level.
    sort_cols = list(dict.fromkeys(idx_cols + time_cols))

    # mergesort is stable, so rows that tie on the sort keys keep their
    # original relative order.
    filled = (flat.sort_values(sort_cols, kind='mergesort')
                  .groupby(idx_cols[0])
                  .ffill())

    # DataFrame.update aligns on index labels. `filled` and `flat` share the
    # same RangeIndex, so every forward-filled value lands on the row it came
    # from. (Updating `data` directly would align positional labels against
    # the original index and silently skip or misplace the values.)
    flat.update(filled)

    result = flat[list(data.columns)]
    result.index = data.index
    result = result.fillna(fill_value)

    return result, target


@ptd.functoolz.curry
def filter_preprocessor(data, target, cols=None, preprocessor=None, refit=False):
    """Apply ``preprocessor`` to selected columns of ``data``.

    Parameters
    ----------
    data : pd.DataFrame
        Frame holding the columns to transform.
    target : object
        Passed through unchanged.
    cols : list-like of labels, optional
        Columns handed to the preprocessor; all columns when ``None``.
    preprocessor : object with ``fit`` / ``transform``, optional
        When ``None`` the data is returned untouched.
    refit : bool
        Fit the preprocessor on the selected columns before transforming.

    Returns
    -------
    tuple
        ``(data, target)`` with the transformed columns put back in the
        original column order.

    Raises
    ------
    Exception
        If the preprocessor has not been fitted and ``refit`` is False.
    """
    # Nothing to apply: hand the inputs straight back.
    if preprocessor is None:
        return data, target

    original_columns = data.columns
    row_index = data.index

    # Split the columns into the ones to transform and the ones to carry over.
    if cols is None:
        cols = original_columns
    untouched_cols = original_columns.difference(cols)
    selected = data[cols]

    # Optionally (re)fit on the selected columns.
    if refit:
        preprocessor.fit(selected)
        log.info(f'Fitting pre-proc: {preprocessor}')

    try:
        transformed = preprocessor.transform(selected)
    except NotFittedError:
        raise Exception(f"{preprocessor} not fitted. pass fitted preprocessor or set refit=True")
    processed = pd.DataFrame(data=transformed, index=row_index, columns=cols)

    # Re-attach the columns that were not transformed (copied via .values).
    if not untouched_cols.empty:
        carried_over = pd.DataFrame(data=data[untouched_cols].values,
                                    index=row_index,
                                    columns=untouched_cols)
        processed = pd.concat((carried_over, processed), axis=1)

    # Restore the original column order.
    return processed[original_columns], target

### Reading data along with usage of pre-processor

In [None]:
# Scaler shared by all splits: fitted via the training filter (refit=True)
# and reused as-is by the filters built for the other datasets.
preprocessor = StandardScaler()

# NOTE(review): presumably the filters run in list order, i.e. scaling first
# and filling afterwards — confirm that `normal_values` are meant to be
# inserted after the numerical columns have been scaled.
train_filter = [
    filter_preprocessor(
        cols=expt_conf['numerical'],
        preprocessor=preprocessor,
        refit=True,
    ),
    filter_fillna(
        fill_value=expt_conf['normal_values'],
        time_order_col=expt_conf['time_order_col'],
    ),
]

# Transform that drops the time-order column(s).
transform = ptd.transform_drop_cols(cols_to_drop=expt_conf['time_order_col'])

In [None]:
# Training split, read with the filter list that fits the scaler.
train_dataset = ptd.BaseDataset(
    # files of the split
    tgt_file=expt_conf['train']['tgt_file'],
    feat_file=expt_conf['train']['feat_file'],
    # column roles taken from the experiment config
    idx_col=expt_conf['idx_cols'],
    tgt_col=expt_conf['tgt_col'],
    feat_columns=expt_conf['feat_cols'],
    time_order_col=expt_conf['time_order_col'],
    category_map=expt_conf['category_map'],
    # run-time processing
    transform=transform,
    filter=train_filter,
)
# print(train_dataset.data.head())
print(train_dataset.shape, len(train_dataset))

In [None]:
# For other datasets use fitted preprocessors: same steps as the training
# filter, but with refit=False so the already-fitted scaler is only applied.
fitted_filter = [
    filter_preprocessor(
        cols=expt_conf['numerical'],
        preprocessor=preprocessor,
        refit=False,
    ),
    filter_fillna(
        fill_value=expt_conf['normal_values'],
        time_order_col=expt_conf['time_order_col'],
    ),
]

In [None]:
# Validation split, read with the already-fitted filter list.
val_dataset = ptd.BaseDataset(
    # files of the split
    tgt_file=expt_conf['val']['tgt_file'],
    feat_file=expt_conf['val']['feat_file'],
    # column roles taken from the experiment config
    idx_col=expt_conf['idx_cols'],
    tgt_col=expt_conf['tgt_col'],
    feat_columns=expt_conf['feat_cols'],
    time_order_col=expt_conf['time_order_col'],
    category_map=expt_conf['category_map'],
    # run-time processing
    transform=transform,
    filter=fitted_filter,
)
print(val_dataset.shape, len(val_dataset))

In [None]:
# Test split, read with the already-fitted filter list.
test_dataset = ptd.BaseDataset(
    # files of the split
    tgt_file=expt_conf['test']['tgt_file'],
    feat_file=expt_conf['test']['feat_file'],
    # column roles taken from the experiment config
    idx_col=expt_conf['idx_cols'],
    tgt_col=expt_conf['tgt_col'],
    feat_columns=expt_conf['feat_cols'],
    time_order_col=expt_conf['time_order_col'],
    category_map=expt_conf['category_map'],
    # run-time processing
    transform=transform,
    filter=fitted_filter,
)
print(test_dataset.shape, len(test_dataset))

In [None]:
# For most models you need to change only this part
input_dim, target_dim = train_dataset.shape
output_dim = 2

# Class weights from the first target column: each label gets
# (largest label count) / (its own count), so the most frequent label is ~1.
# The small constant keeps the division away from zero.
weight_labels = train_dataset.target.iloc[:, 0].value_counts()
weight_labels = weight_labels.max() / ((weight_labels + 0.0000001) ** 1)
# Order the weights by label value before turning them into a tensor.
weight_labels = weight_labels.sort_index()
weights = T.FloatTensor(weight_labels.values).to(train_dataset.device)
print(weights)

## Single Run

In [None]:
# For most models you need to change only this part
# Request the first GPU only when CUDA is usable, so the notebook also runs on
# CPU-only machines instead of failing when the Trainer is created.
# NOTE(review): assumes the installed pytorch_lightning treats gpus=None as
# "run on CPU" — confirm for the version in use.
hparams = argparse.Namespace(gpus=[0] if T.cuda.is_available() else None,
                             lr=0.01,
                             max_epochs=100,
                             batch_size=32,
                             hidden_dim=32,
                             rnn_class='LSTM',
                             n_layers=1,
                             dropout=0.1,
                             recurrent_dropout=0.1,
                             bidirectional=False,
                             )


# Recurrent classifier configured from the hyper-parameters above.
base_model = rnn.RNNClassifier(input_dim, output_dim,
                               hidden_dim=hparams.hidden_dim,
                               rnn_class=hparams.rnn_class,
                               n_layers=hparams.n_layers,
                               dropout=hparams.dropout,
                               recurrent_dropout=hparams.recurrent_dropout,
                               bidirectional=hparams.bidirectional
                              )

# Cross-entropy weighted by the class weights computed from the training targets.
criterion = nn.CrossEntropyLoss(weight=weights)

# Optional: a custom optimizer / scheduler can be built here and handed to
# PyModel through the commented-out keyword arguments below.
# optimizer = T.optim.Adam(base_model.parameters(),
#                          lr=hparams.lr,
#                          weight_decay=1e-5,  # standard value
#                          )
# scheduler = T.optim.lr_scheduler.ReduceLROnPlateau(optimizer, 'min')

# Creating the wrapped model
wrapped_model = ptr.PyModel(hparams, base_model,
                            train_dataset=train_dataset,
                            val_dataset=val_dataset, # None
                            test_dataset=test_dataset, # test_dataset
                            #optimizer=optimizer,
                            loss_func=criterion,
                            #scheduler=scheduler,
                            collate_fn=ptd.collate_fn
                            )

In [None]:
# Training
# Run switches; overfit_pct is also passed to the MLflow training call.
overfit_pct = 0
fast_dev_run = True  # True if debugging
terminate_on_nan = True
auto_lr_find = False

trainer = Trainer(
    max_epochs=hparams.max_epochs,
    gpus=hparams.gpus,
    default_root_dir=os.path.join('./out/', 'classifier_ihm'),
    terminate_on_nan=terminate_on_nan,
    auto_lr_find=auto_lr_find,
    overfit_pct=overfit_pct,
    fast_dev_run=fast_dev_run,
)

In [None]:
# MLflow run settings: experiment name, objects passed as artifacts, and tags.
mlflow_conf = {'experiment_name': 'classifier_ihm'}
artifacts = {'preprocessor': preprocessor}
experiment_tags = {'model': 'RNNClassifier'}

# Train through the MLflow-aware helper; it returns the metrics together with
# the test labels, predictions and predicted probabilities.
metrics, test_y, test_yhat, test_pred_proba = ptr.run_training_with_mlflow(
    mlflow_conf,
    trainer,
    wrapped_model,
    overfit_pct=overfit_pct,
    artifacts=artifacts,
    **experiment_tags
)


print(metrics)