## Main notebook to run Attention-LSTM models: Single Fold

Author: Lin Lee Cheong <br>
Last Updated: 11/23/2020 <br>

In [1]:
import os
import argparse
import time
import pickle
import pandas as pd
import numpy as np
from collections import Counter
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from torchtext.datasets import text_classification
from torchtext.vocab import Vocab
from attn_lstm_model import AttentionRNN
from model_utils import (
    log,
    build_lstm_dataset,
    epoch_train_lstm,
    epoch_val_lstm,
    generate_batch,
    count_parameters,
    epoch_time
)

**OPTIONS**

In [2]:
# --- Data loading -----------------------------------------------------------
# Upper bound on rows handed to build_lstm_dataset (large enough to mean "all").
nrows = 1e9
# Minimum token frequency option for vocabulary building.
min_freq = 500
# CUDA device index; set to None to use the default CUDA device.
device_id = 1

# --- Input / output locations (fold 4, 30-day window) -----------------------
train_data_path = (
    "../../../data/readmission/fold_4/train/raw_train_data_1000_30days.csv"
)
valid_data_path = (
    "../../../data/readmission/fold_4/test/raw_test_data_1000_30days.csv"
)
model_save_path = "./lstm_model_30days/gen_attn_lstm_fold4"
results_path = "./lstm_results_30days/gen_attn_lstm_results_fold4"

# --- Training schedule ------------------------------------------------------
batch_size = 2046
N_EPOCHS = 20

# --- Model hyper-parameters -------------------------------------------------
EMBEDDING_DIM = 30
HIDDEN_DIM = 30
BIDIRECTIONAL = False
DROPOUT = 0.3  # TODO: remove dropout

In [3]:
# Select the compute device. Falls back to CPU when CUDA is unavailable;
# otherwise uses the default CUDA device (device_id is None) or the GPU
# with the configured index.
if not torch.cuda.is_available():
    DEVICE = torch.device('cpu')
elif device_id is None:
    DEVICE = torch.device('cuda')
else:
    DEVICE = torch.device(f'cuda:{device_id}')

In [4]:
# Make sure the parent directory of each output file exists before anything
# is written there.
for fp in [model_save_path, results_path]:
    out_dir = os.path.dirname(fp)
    # An empty dirname means the file lives in the current directory:
    # nothing to create (and os.makedirs('') would raise).
    if out_dir and not os.path.isdir(out_dir):
        # exist_ok guards against the directory appearing between the check
        # and the creation (e.g. both paths sharing one parent).
        os.makedirs(out_dir, exist_ok=True)
        print(f'New directory created: {out_dir}')

**READ IN TO GENERATE DATASET**

In [5]:
# Build the training dataset; no vocab is passed in, and its `_vocab`
# attribute is reused below for the validation split.
# `min_freq` comes from the OPTIONS cell so that changing the option actually
# takes effect here.
train_dataset = build_lstm_dataset(
    datapath=train_data_path, min_freq=min_freq, nrows=nrows, rev=False
)

# Build the validation dataset with the training vocabulary so both splits
# share the same token indices.
valid_dataset = build_lstm_dataset(
    datapath=valid_data_path,
    min_freq=min_freq,
    nrows=nrows,
    vocab=train_dataset._vocab,
    rev=False,
)

log('vocab length:', len(train_dataset._vocab))

    0.00: Build token list
    6.38: Build counter
    6.64: Build vocab
    6.64: Build data
    7.37: Build pytorch dataset
    7.37: Skipped 0 invalid patients
    7.37: Skipped 0 dead patients
    7.37: Done
    7.40: Build token list
    9.13: Build data
    9.29: Build pytorch dataset
    9.29: Skipped 0 invalid patients
    9.29: Skipped 0 dead patients
    9.29: Done
    9.29: vocab length: 5135


In [6]:
# TODO: extend build_lstm_dataset so it can process data with a provided (pre-built) vocabulary

In [7]:
# TODO: SAVE dataset, vocab
# torch.save(train_dataset, './tmp_train_dataset.pt')
# torch.save(valid_dataset,'./tmp_valid_dataset.pt')

In [8]:
# Settings shared by both loaders: batch size, collate function and number of
# worker processes.
_loader_kwargs = dict(
    batch_size=batch_size,
    collate_fn=generate_batch,
    num_workers=8,
)

# Training batches are reshuffled every epoch; validation order stays fixed.
train_dataloader = DataLoader(train_dataset, shuffle=True, **_loader_kwargs)
valid_dataloader = DataLoader(valid_dataset, shuffle=False, **_loader_kwargs)

**MODEL GENERATION**

In [9]:
# Report whether CUDA is visible and which device was selected.
# (torch is already imported in the notebook's import cell.)
log(torch.cuda.is_available())
log(DEVICE)

    9.29: True
    9.29: cuda:1


In [10]:
# Model sizes derived from the training dataset.
# Output size: number of entries in the dataset's `_labels`.
OUTPUT_DIM = len(train_dataset._labels)
# Input size: number of entries in the training vocabulary.
INPUT_DIM = len(train_dataset._vocab)

In [11]:
# Instantiate the attention LSTM from the configured sizes and move it to the
# selected device in one step.
model = AttentionRNN(
    INPUT_DIM,
    EMBEDDING_DIM,
    HIDDEN_DIM,
    OUTPUT_DIM,
    BIDIRECTIONAL,
    DROPOUT,
    padding_idx=0,
    device=DEVICE,
).to(DEVICE)

# Show the architecture and its parameter count.
log(model)
n_params = count_parameters(model)
log(f'Nb of params: {n_params}')

  "num_layers={}".format(dropout, num_layers))


    9.35: AttentionRNN(
  (embedding): Embedding(5135, 30, padding_idx=0)
  (rnn): LSTM(30, 30, dropout=0.3)
  (fc): Linear(in_features=30, out_features=1, bias=True)
  (dropout): Dropout(p=0.3, inplace=False)
)
    9.35: Nb of params: 161521


**MODEL TRAINING**

In [12]:
# Adam optimizer; the learning rate is multiplied by 0.9 every 4 scheduler steps.
optimizer = optim.Adam(model.parameters(), lr=0.02)
scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=4, gamma=0.9)

# Earlier alternative, kept for reference:
#    optimizer = optim.SGD(model.parameters(), lr=args.lr)
#    scheduler = torch.optim.lr_scheduler.StepLR(optimizer, 5, gamma=0.9) #LLC-2/12: less aggressive drops

# Binary cross-entropy on raw logits, placed on the selected device.
criterion = nn.BCEWithLogitsLoss().to(DEVICE)

In [13]:
# Train for up to N_EPOCHS. After every epoch the model is evaluated on the
# validation loader; the weights are checkpointed whenever the validation loss
# improves, and training stops early after `stop_num` consecutive epochs
# without improvement.
log('Train')
best_valid_loss = float("inf")
valid_worse_loss = 0  # consecutive epochs without a new best validation loss
stop_num = 6  # early-stopping patience, in epochs

for epoch in range(N_EPOCHS):

    start_time = time.time()

    train_loss, train_auc = epoch_train_lstm(
        model, train_dataloader, optimizer, criterion
    )

    valid_loss, valid_auc = epoch_val_lstm(
        model, valid_dataloader, criterion, return_preds=False
    )

    end_time = time.time()

    epoch_mins, epoch_secs = epoch_time(start_time, end_time)
    print(f"Epoch: {epoch+1:02} | Epoch Time: {epoch_mins}m {epoch_secs}s")

    # Report metrics before the early-stop check so the last epoch's numbers
    # are not lost when the loop breaks.
    log(
        f"Train Loss: {train_loss:.3f} | Train AUC: {train_auc:.2f} \t Val. Loss: {valid_loss:.3f} |  Val. AUC: {valid_auc:.4f}"
    )

    if valid_loss < best_valid_loss:
        # New best: checkpoint the weights and reset the patience counter.
        best_valid_loss = valid_loss
        torch.save(model.state_dict(), model_save_path)
        # 1-based epoch number, matching the "Epoch: NN" header above.
        print("Saved Model, epoch {}".format(epoch + 1))
        valid_worse_loss = 0

    else:
        valid_worse_loss += 1
        # `>=` rather than `==` so the stop still triggers if the counter
        # ever overshoots the threshold.
        if valid_worse_loss >= stop_num:
            print("EARLY STOP ------")
            break

    scheduler.step()

    9.35: Train
Epoch: 01 | Epoch Time: 2m 17s
Saved Model, epoch 0
   11.64: Train Loss: 0.409 | Train AUC: 0.63 	 Val. Loss: 0.800 |  Val. AUC: 0.6648
Epoch: 02 | Epoch Time: 2m 19s
Saved Model, epoch 1
   13.96: Train Loss: 0.400 | Train AUC: 0.67 	 Val. Loss: 0.800 |  Val. AUC: 0.6673
Epoch: 03 | Epoch Time: 2m 16s
Saved Model, epoch 2
   16.24: Train Loss: 0.398 | Train AUC: 0.67 	 Val. Loss: 0.797 |  Val. AUC: 0.6689
Epoch: 04 | Epoch Time: 2m 18s
   18.54: Train Loss: 0.397 | Train AUC: 0.67 	 Val. Loss: 0.797 |  Val. AUC: 0.6690
Epoch: 05 | Epoch Time: 2m 17s
Saved Model, epoch 4
   20.84: Train Loss: 0.396 | Train AUC: 0.68 	 Val. Loss: 0.797 |  Val. AUC: 0.6693
Epoch: 06 | Epoch Time: 2m 19s
   23.15: Train Loss: 0.396 | Train AUC: 0.68 	 Val. Loss: 0.798 |  Val. AUC: 0.6687
Epoch: 07 | Epoch Time: 2m 16s
   25.43: Train Loss: 0.396 | Train AUC: 0.68 	 Val. Loss: 0.797 |  Val. AUC: 0.6691
Epoch: 08 | Epoch Time: 2m 18s
Saved Model, epoch 7
   27.74: Train Loss: 0.395 | Train 

## Evaluate the best model on the validation set: predictions, feature importance, etc.

In [14]:
# Restore the best checkpoint written during training. `map_location` remaps
# the stored tensors onto the currently selected device, so the checkpoint
# also loads when the GPU it was saved from is not available.
model.load_state_dict(torch.load(model_save_path, map_location=DEVICE))

<All keys matched successfully>

In [15]:
# Re-run validation with the restored weights, this time also collecting the
# per-sample results.
# valid_results = (ids, predictions, labels, attn, events)
final_eval = epoch_val_lstm(model, valid_dataloader, criterion, return_preds=True)
valid_loss, valid_auc, valid_results = final_eval

In [16]:
# Persist the validation results to the configured results path.
torch.save(obj=valid_results, f=results_path)

In [17]:
# Validation AUC returned by the final evaluation pass above.
print(valid_auc)

0.6690754520801806


In [18]:
# Validation loss returned by the final evaluation pass above.
print(valid_loss)

0.7959693054747737
