# Test Classification Using LSTM

This notebook tests the usage of `LSTM` for classification on test data generated from [mk_test_data.py](./mk_test_data.py).

## Pre-amble

The following code cells import the required libraries and set up the notebook.

In [None]:
import os
#TEST_REGISTRATION = os.environ.get('test_registration', False)
USE_GPU = int(os.environ.get('use_gpu', 1))  # env vars are strings when set; convert before passing to Trainer(gpus=...)
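Environment variables read with `os.environ.get` are always strings when they are set, so launching with `use_gpu=0` yields `'0'` rather than the integer `0`. The sketch below shows one way to parse such a flag into an integer. `env_int` is a hypothetical helper, not part of `lightsaber`, and the optional `environ` argument exists only so the example can run without touching the real environment.

```python
import os

def env_int(name, default, environ=None):
    """Return environment variable `name` as an int, or `default` if unset/empty.

    Hypothetical helper: os.environ values are always strings, so they need an
    explicit conversion before being used as numeric arguments.
    """
    environ = os.environ if environ is None else environ
    raw = environ.get(name)
    if raw is None or raw.strip() == "":
        return default
    return int(raw)  # raises ValueError for non-numeric values such as "yes"

# A plain dict stands in for os.environ so the example has no side effects.
print(env_int("use_gpu", 1, environ={"use_gpu": "0"}))  # 0
print(env_int("use_gpu", 1, environ={}))                # 1
```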

In [None]:
import sys
sys.path.append('../')

In [None]:
# Jupyter notebook Specific imports
%matplotlib inline

import warnings
#warnings.filterwarnings('ignore')

# Imports injecting into namespace
from tqdm.auto import tqdm
tqdm.pandas()

# General imports
import os
import json
import pickle
from pathlib import Path

import pandas as pd
import numpy as np
from getpass import getpass
import argparse

from sklearn.preprocessing import StandardScaler
from sklearn.exceptions import NotFittedError

import torch as T
from torch import nn
from pytorch_lightning import Trainer
from pytorch_lightning.callbacks import ModelCheckpoint

from lightsaber import constants as C
import lightsaber.data_utils.utils as du
from lightsaber.data_utils import pt_dataset as ptd
from lightsaber.trainers import pt_trainer as ptr

from lightsaber.model_lib.pt_sota_models import rnn

In [None]:
import logging
logging.basicConfig(level=logging.INFO)
log = logging.getLogger()
#log.setLevel(logging.DEBUG)
log.info('Ready to log')

In [None]:
import io

data_dir = Path('./data')
assert data_dir.is_dir()

expt_conf = f"""
tgt_col: treatment

idx_cols: 
    - id
    - time
time_order_col: 
    - time_history

feat_cols: 
    - prev_cov1
    - prev_treat

train:
    tgt_file: '{data_dir}/easiest_sim_shifted_TGT_train.csv'
    feat_file: '{data_dir}/easiest_sim_shifted_FEAT_train.csv'

val:
    tgt_file: '{data_dir}/easiest_sim_shifted_TGT_val.csv'
    feat_file: '{data_dir}/easiest_sim_shifted_FEAT_val.csv'
    
test:
    tgt_file: '{data_dir}/easiest_sim_shifted_TGT_test.csv'
    feat_file: '{data_dir}/easiest_sim_shifted_FEAT_test.csv'

category_map:
    prev_treat: [0, 1]
    
numerical: 
    - prev_cov1

normal_values:
    prev_cov1: 0.
    prev_treat: 0
"""
expt_conf = du.yaml.load(io.StringIO(expt_conf), Loader=du._Loader)
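The `{data_dir}` placeholders are substituted by the Python f-string before the YAML is parsed, so the loaded config holds plain relative paths. Because `str(Path('./data'))` is `'data'`, each path starts with `data/` and is resolved against the notebook's working directory. The sketch below reproduces the file-naming pattern of the config; `split_files` is a hypothetical helper, not part of `lightsaber`.

```python
from pathlib import Path

def split_files(data_dir, split, prefix="easiest_sim_shifted"):
    """Return the target/feature CSV paths for one data split.

    Hypothetical helper that reproduces the naming pattern in the YAML config.
    as_posix() keeps forward slashes, matching the f-string output on any OS.
    """
    base = Path(data_dir)
    return {
        "tgt_file": (base / f"{prefix}_TGT_{split}.csv").as_posix(),
        "feat_file": (base / f"{prefix}_FEAT_{split}.csv").as_posix(),
    }

# str(Path('./data')) is 'data', so the leading './' disappears after interpolation.
print(f"{Path('./data')}")                         # data
print(split_files("./data", "train")["tgt_file"])  # data/easiest_sim_shifted_TGT_train.csv
```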

## Model Training

In general, users need to follow these steps to train an `LSTM` classification model.

* _Data Ingestion_: The first step involves setting up the pre-processors to train a classification model. In this example, we use `StandardScaler` from `scikit-learn` via filters defined within `lightsaber`.

  - Next, we read the train, validation, and test datasets. In some cases, users may also want to define a calibration dataset.

* _Model Definition_: Next, we define a base model for classification. In this example, we use a pre-packaged `LSTM` model from `lightsaber`.

* _Model Training_: Once the model is defined, we can use `lightsaber` to train it via the pre-packaged `PyModel` and the corresponding trainer code. This step also generates the relevant `metrics` for this problem.

### Data ingestion

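We start by reading the extracted cohort data and using a `StandardScaler` to demonstrate the proper usage of a pre-processor: it is fit on the training data only and then reused, without refitting, for the validation and test data.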

In [None]:
preprocessor = StandardScaler()
train_filter = [ptd.filter_preprocessor(cols=expt_conf['numerical'], 
                                        preprocessor=preprocessor,
                                        refit=True),
                ptd.filter_fillna(fill_value=expt_conf['normal_values'],
                                  time_order_col=expt_conf['time_order_col'])
                ]
transform = ptd.transform_drop_cols(cols_to_drop=expt_conf['time_order_col'])
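`filter_fillna` receives the `normal_values` map and the `time_order_col` from the config. Its internals are not shown in this notebook; a plausible reading, assumed here, is that missing features are forward-filled within each patient in time order, and anything still missing falls back to the configured normal value. The sketch below implements that assumed behavior in plain Python, with rows as dicts and `None` marking a missing value. `fill_missing` is hypothetical and is not `lightsaber`'s implementation.

```python
from itertools import groupby

def fill_missing(rows, normal_values, id_key="id", time_key="time"):
    """Forward-fill missing features per id in time order, then use normal values.

    Hypothetical re-implementation for illustration only: `rows` are dicts and
    missing values are represented by None.
    """
    rows = sorted(rows, key=lambda r: (r[id_key], r[time_key]))
    filled = []
    for _, group in groupby(rows, key=lambda r: r[id_key]):
        last_seen = {}  # reset for every id so values never leak across patients
        for row in group:
            new_row = dict(row)
            for col, normal in normal_values.items():
                if new_row.get(col) is None:
                    new_row[col] = last_seen.get(col, normal)
                else:
                    last_seen[col] = new_row[col]
            filled.append(new_row)
    return filled

rows = [
    {"id": 1, "time": 0, "prev_cov1": None, "prev_treat": 1},
    {"id": 1, "time": 1, "prev_cov1": 0.5, "prev_treat": None},
    {"id": 1, "time": 2, "prev_cov1": None, "prev_treat": None},
    {"id": 2, "time": 0, "prev_cov1": None, "prev_treat": None},
]
for r in fill_missing(rows, {"prev_cov1": 0.0, "prev_treat": 0}):
    print(r)
```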

In [None]:
train_dataset = ptd.BaseDataset(tgt_file=expt_conf['train']['tgt_file'],
                                feat_file=expt_conf['train']['feat_file'],
                                idx_col=expt_conf['idx_cols'],
                                tgt_col=expt_conf['tgt_col'],
                                feat_columns=expt_conf['feat_cols'],
                                time_order_col=expt_conf['time_order_col'],
                                category_map=expt_conf['category_map'],
                                transform=transform,
                                filter=train_filter,
                               )
# print(train_dataset.data.head())
print(train_dataset.shape, len(train_dataset))

In [None]:
# For other datasets use fitted preprocessors
fitted_filter = [ptd.filter_preprocessor(cols=expt_conf['numerical'], 
                                         preprocessor=preprocessor, refit=False),
                 ptd.filter_fillna(fill_value=expt_conf['normal_values'],
                                   time_order_col=expt_conf['time_order_col'])
                 ]

val_dataset = ptd.BaseDataset(tgt_file=expt_conf['val']['tgt_file'],
                              feat_file=expt_conf['val']['feat_file'],
                              idx_col=expt_conf['idx_cols'],
                              tgt_col=expt_conf['tgt_col'],
                              feat_columns=expt_conf['feat_cols'],
                              time_order_col=expt_conf['time_order_col'],
                              category_map=expt_conf['category_map'],
                              transform=transform,
                              filter=fitted_filter,
                              )

test_dataset = ptd.BaseDataset(tgt_file=expt_conf['test']['tgt_file'],
                               feat_file=expt_conf['test']['feat_file'],
                               idx_col=expt_conf['idx_cols'],
                               tgt_col=expt_conf['tgt_col'],
                               feat_columns=expt_conf['feat_cols'],
                               time_order_col=expt_conf['time_order_col'],
                               category_map=expt_conf['category_map'],
                               transform=transform,
                               filter=fitted_filter,
                               )

print(val_dataset.shape, len(val_dataset))
print(test_dataset.shape, len(test_dataset))
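`train_filter` fits the `StandardScaler` with `refit=True`, while `fitted_filter` reuses it with `refit=False`, so validation and test features are standardized with the training mean and standard deviation, z = (x - mean_train) / std_train. This avoids leaking statistics from held-out data into preprocessing. The single-column stand-in below (a hypothetical `MeanStdScaler`, not scikit-learn's class) shows the arithmetic; like `StandardScaler`, it uses the population standard deviation and leaves zero-variance features unscaled.

```python
import math

class MeanStdScaler:
    """Minimal single-column stand-in for sklearn's StandardScaler (illustration only)."""

    def __init__(self):
        self.mean_ = None
        self.scale_ = None

    def fit(self, values):
        n = len(values)
        self.mean_ = sum(values) / n
        var = sum((v - self.mean_) ** 2 for v in values) / n  # population variance (ddof=0)
        self.scale_ = math.sqrt(var) if var > 0 else 1.0      # zero variance: no scaling
        return self

    def transform(self, values):
        if self.mean_ is None:
            raise RuntimeError("call fit() before transform()")
        return [(v - self.mean_) / self.scale_ for v in values]

scaler = MeanStdScaler().fit([1.0, 2.0, 3.0, 4.0])  # fit on "train" only (refit=True)
print(scaler.transform([2.5, 5.0]))                 # reuse on "val"/"test" (refit=False)
```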

In [None]:
# Handling class imbalance: weight each class by (majority-class count / class count)
input_dim, target_dim = train_dataset.shape
output_dim = 2

weight_labels = train_dataset.target.iloc[:, 0].value_counts()
weight_labels = weight_labels.max() / (weight_labels + 1e-7)  # small epsilon avoids division by zero
weight_labels.sort_index(inplace=True)  # order the weights by class label (0, 1)
weights = T.FloatTensor(weight_labels.values).to(train_dataset.device)
print(weights)
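The weighting above is inverse-frequency relative to the majority class: the most common class gets a weight of about 1 and a class that is k times rarer gets a weight of about k. A plain-Python sketch of the same formula (the function name `inverse_frequency_weights` is hypothetical):

```python
from collections import Counter

def inverse_frequency_weights(labels, eps=1e-7):
    """Weight each class by (majority-class count) / (class count).

    Classes are returned in sorted order, matching `sort_index` on the
    value counts in the cell above.
    """
    counts = Counter(labels)
    majority = max(counts.values())
    return [majority / (counts[c] + eps) for c in sorted(counts)]

labels = [0] * 900 + [1] * 100
print(inverse_frequency_weights(labels))  # roughly [1.0, 9.0]
```

Only classes present in the training labels receive a weight, so a class missing from the training set would leave the weight vector shorter than `output_dim`, which `nn.CrossEntropyLoss(weight=...)` rejects.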

## Single Run

In [None]:
# For most models you need to change only this part
hparams = argparse.Namespace(lr=0.01,
                             batch_size=64,
                             hidden_dim=8,
                             rnn_class='LSTM',
                             n_layers=2,
                             dropout=0.1,
                             recurrent_dropout=0.1,
                             bidirectional=False,
                             )

hparams.rnn_class = C.PYTORCH_CLASS_DICT[hparams.rnn_class]

base_model = rnn.RNNClassifier(input_dim, output_dim, 
                               hidden_dim=hparams.hidden_dim,
                               rnn_class=hparams.rnn_class,
                               n_layers=hparams.n_layers,
                               dropout=hparams.dropout,
                               recurrent_dropout=hparams.recurrent_dropout,
                               bidirectional=hparams.bidirectional
                              )

criterion = nn.CrossEntropyLoss(weight=weights)
# optimizer = T.optim.Adam(base_model.parameters(),
#                          lr=hparams.lr,
#                          weight_decay=1e-5  # standard value)
#                          )

# scheduler = T.optim.lr_scheduler.ReduceLROnPlateau(optimizer, 'min')

# Creating the wrapped model
wrapped_model = ptr.PyModel(hparams, base_model,
                            train_dataset=train_dataset,
                            val_dataset=val_dataset,  # or None
                            test_dataset=test_dataset,
                            cal_dataset=val_dataset,  # validation set reused for calibration
                            #optimizer=optimizer,
                            loss_func=criterion,
                            #scheduler=scheduler,
                            collate_fn=ptd.collate_fn
                            )
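The `criterion` above is a class-weighted `nn.CrossEntropyLoss`. With the default mean reduction, PyTorch documents the loss as each sample's negative log-likelihood scaled by the weight of its true class, summed, and then divided by the sum of those weights (not by the batch size). The plain-Python sketch below reproduces that formula for illustration; `weighted_cross_entropy` is a hypothetical name, not a PyTorch function.

```python
import math

def weighted_cross_entropy(logits, targets, weights):
    """Mean-reduced, class-weighted cross-entropy over a batch.

    loss = sum_n w[y_n] * nll_n / sum_n w[y_n], where
    nll_n = logsumexp(logits_n) - logits_n[y_n].
    """
    total, norm = 0.0, 0.0
    for row, y in zip(logits, targets):
        m = max(row)  # subtract the max for numerical stability
        log_sum_exp = m + math.log(sum(math.exp(v - m) for v in row))
        nll = log_sum_exp - row[y]
        total += weights[y] * nll
        norm += weights[y]
    return total / norm

# Two samples, two classes; the minority class (1) carries weight 9.
print(weighted_cross_entropy([[2.0, 0.0], [0.0, 1.0]], [0, 1], [1.0, 9.0]))
```

Dividing by the summed weights keeps the loss on the same scale as the unweighted case, so the learning rate does not need retuning just because weights were added.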

In [None]:
# Training 
overfit_batches, fast_dev_run, terminate_on_nan, auto_lr_find, limit_batch = 0, False, False, False, 5
default_root_dir = os.path.join('./out/', 'classifier_test')
checkpoint_callback = ModelCheckpoint(dirpath=default_root_dir)
callbacks = [checkpoint_callback]

train_args = argparse.Namespace(gpus=USE_GPU,
                                max_epochs=2,
                                callbacks=callbacks,
                                default_root_dir=default_root_dir,
                                terminate_on_nan=terminate_on_nan,
                                auto_lr_find=auto_lr_find,
                                overfit_batches=overfit_batches,
                                fast_dev_run=fast_dev_run,  # True if debugging
                                limit_train_batches=limit_batch,
                                limit_val_batches=limit_batch,
                                limit_predict_batches=limit_batch,
                                log_every_n_steps=1,
                               )
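In PyTorch Lightning, an integer `limit_train_batches` caps the number of training batches per epoch (a float would instead be a fraction of the epoch). With `limit_batch = 5` and `max_epochs=2`, this run is therefore a short smoke test of at most 10 training batches, which is also why logging every step is useful here (Lightning's default `log_every_n_steps` is 50). The sketch below computes that bound, assuming the DataLoader keeps the last partial batch; the sample count used is made up for illustration, and `planned_train_batches` is hypothetical.

```python
import math

def planned_train_batches(n_samples, batch_size, max_epochs, limit_train_batches):
    """Upper bound on training batches when limit_train_batches is an int.

    Assumes the DataLoader keeps the final partial batch (drop_last=False).
    """
    per_epoch = math.ceil(n_samples / batch_size)
    return max_epochs * min(per_epoch, limit_train_batches)

# Hypothetical dataset of 1000 sequences with the settings used above.
print(planned_train_batches(1000, 64, 2, 5))  # 10
```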

In [None]:
mlflow_conf = dict(experiment_name=f'classifier_test')
artifacts = dict(preprocessor=preprocessor, 
                 weight_labels=weight_labels,
                )
experiment_tags = dict(model='RNNClassifier',
                       input_dim=input_dim,
                       output_dim=output_dim
                      )

(run_id, metrics, 
 y_val, y_val_hat, y_val_proba, 
 y_test, y_test_hat, y_test_proba) = ptr.run_training_with_mlflow(mlflow_conf, 
                                                                  train_args, 
                                                                  wrapped_model,
                                                                  artifacts=artifacts,
                                                                  **experiment_tags)

print(f"MLFlow Experiment: {mlflow_conf['experiment_name']} \t | Run ID: {run_id}")
print(metrics)

In [None]:
print(y_val.shape, y_val_hat.shape, y_val_proba.shape, 
      y_test.shape, y_test_hat.shape, y_test_proba.shape
     )
print(type(y_val), type(y_val_hat), type(y_val_proba), type(y_test), type(y_test_hat), type(y_test_proba))
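As an independent sanity check on the returned predictions, accuracy and ROC AUC can be recomputed directly from the labels and scores. This is not how `lightsaber` computes `metrics`; it is a dependency-free sketch, and the helper names are hypothetical. ROC AUC here is the probability that a randomly chosen positive scores higher than a randomly chosen negative, with ties counted as half.

```python
def accuracy(y_true, y_pred):
    """Fraction of predictions equal to the true label."""
    return sum(int(t == p) for t, p in zip(y_true, y_pred)) / len(y_true)

def roc_auc(y_true, scores):
    """ROC AUC via pairwise comparison; O(n_pos * n_neg), fine for a quick check."""
    pos = [s for t, s in zip(y_true, scores) if t == 1]
    neg = [s for t, s in zip(y_true, scores) if t == 0]
    if not pos or not neg:
        raise ValueError("need both classes to compute AUC")
    wins = 0.0
    for p in pos:
        for n in neg:
            wins += 1.0 if p > n else 0.5 if p == n else 0.0
    return wins / (len(pos) * len(neg))

print(accuracy([0, 1, 1], [0, 1, 0]))                    # 0.666...
print(roc_auc([0, 0, 1, 1], [0.1, 0.4, 0.35, 0.8]))      # 0.75
```

Assuming `y_test` and `y_test_proba` are NumPy arrays and the second probability column is the positive class, a call would look like `roc_auc(y_test.ravel().tolist(), y_test_proba[:, 1].tolist())`; check the shapes printed above before relying on that layout.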