# **RETAIN heart failure detection**

An implementation of [RETAIN: An Interpretable Predictive Model for Healthcare using Reverse Time Attention Mechanism](https://arxiv.org/abs/1608.05745) by Choi et al. to predict heart failure based on clinical records.

### 1. Load libraries

In [289]:
import numpy as np
import pandas as pd
import pickle as pickle

import torch
import torch.nn as nn


# **Prepare dataset:**

### 2. Load dataset
Dataset consists of several files:
* **pids:** patient ID
* **vids:** visit ID
* **hfs:** heart failures
* **seqs:** sequences of ICD-9 clinical records
* **types:** mapping of ICD-9 code to ID
* **rtypes:** index mapping of ICD-9 ID, used to simplify code

In [290]:
pids = pd.read_pickle('/dataset/pids.pkl')
vids = pd.read_pickle('/dataset/vids.pkl')
hfs = pd.read_pickle('/dataset/hfs.pkl')
seqs = pd.read_pickle('/dataset/seqs.pkl')
types = pd.read_pickle('/dataset/types.pkl')
rtypes = pd.read_pickle('/dataset/rtypes.pkl')

Drive already mounted at /content/drive/; to attempt to forcibly remount, call drive.mount("/content/drive/", force_remount=True).


### 3. Data preview

In [291]:
def print_patient_data(index):
    patient_stats = pd.DataFrame({
        "ID": pids[index],
        "Visits": len(vids[index]),
        "Heart failure": bool(hfs[index])
    }, index=[index])

    return patient_stats


print("Total patients:", len(pids))
print("Total amount of patients with heart failure:", sum(hfs))
print("Rato of heart failure patients:", (sum(hfs) / len(hfs)))
print("\nExample patient:")
print_patient_data(3)

Total patients: 1000
Total amount of patients with heart failure: 548
Rato of heart failure patients: 0.548

Example patient:


Unnamed: 0,ID,Visits,Heart failure
3,47537,2,False


In [292]:
def get_patient_visit(index, visit):

    patient_visit = pd.DataFrame({
        'event_id': seqs[index][visit]
    })

    patient_visit['event'] = patient_visit.event_id.map(rtypes)

    return patient_visit


print("Details of patient visit:")
get_patient_visit(3, 0)

Details of patient visit:


Unnamed: 0,event_id,event
0,12,DIAG_041
1,103,DIAG_276
2,262,DIAG_518
3,285,DIAG_560
4,290,DIAG_567
5,292,DIAG_569
6,359,DIAG_707
7,416,DIAG_785
8,39,DIAG_155
9,225,DIAG_456


### 3. Prepare dataset - padding, masking, reversing:

- **Padding:** inserts 0 to each visit to fill the gap between this visit and visit with highest length. We do that to have same dimensions across all sequences.
- **Masking:** creates a boolean mask to indicate which element was added.
- **Reversing:** according to paper we have to reverse the data to achieve "reverse time attention mechanism". We do reversing only on actual data not the padding.

In [293]:
def collate_fn(data):

    sequences, labels = zip(*data)

    num_patients = len(sequences)
    max_num_visits = 0
    max_num_codes = 0

    for patient in sequences:
        max_num_visits = max(max_num_visits, len(patient))
        for visit in patient:
            max_num_codes = max(max_num_codes, len(visit))

    shape = (num_patients, max_num_visits, max_num_codes)

    y = torch.tensor(labels, dtype=torch.float)
    x = torch.zeros(shape, dtype=torch.long)
    rev_x = torch.zeros(shape, dtype=torch.long)
    masks = torch.zeros(shape, dtype=torch.bool)
    rev_masks = torch.zeros(shape, dtype=torch.bool)

    for i, patient in enumerate(sequences):
        for j, visit in enumerate(patient):

            x[i,j,:len(visit)] = torch.tensor(visit, dtype=torch.long)
            masks[i,j, :len(visit)] = 1

            rev_x[i, len(patient) - j - 1, :len(visit)] = torch.tensor(visit, dtype=torch.long)
            rev_masks[i, len(patient) - j - 1, :len(visit)] = 1
            
    
    return x, masks, rev_x, rev_masks, y


def load_data(train_dataset, val_dataset, collate_fn):

    batch_size = 32

    train_loader = torch.utils.data.DataLoader(
        train_dataset, 
        batch_size=batch_size, 
        shuffle=True, 
        collate_fn=collate_fn
    )
    
    val_loader = torch.utils.data.DataLoader(
        val_dataset, 
        batch_size=batch_size, 
        shuffle=True, 
        collate_fn=collate_fn
    )
    
    return train_loader, val_loader



class CustomDataset(torch.utils.data.Dataset):
    def __init__(self, seqs, hfs):
        self.x = seqs
        self.y = hfs
    

    def __len__(self):
        return len(self.x)
    

    def __getitem__(self, index):
        return (self.x[index], self.y[index])


dataset = CustomDataset(seqs, hfs)
split = int(len(dataset)*0.8)

lengths = [split, len(dataset) - split]
train_dataset, val_dataset = torch.utils.data.dataset.random_split(dataset, lengths)


train_loader, val_loader = load_data(train_dataset, val_dataset, collate_fn)


# **RETAIN model**

### 4. Custom layers
- Alpha attention
- Beta attention
- Attention sum
- Sum embeddings with applied mask

In [294]:
class AlphaAttention(torch.nn.Module):

    def __init__(self, hidden_dim):
        super().__init__()
        self.a_att = nn.Linear(hidden_dim, 1)


    def forward(self, g):
        alpha = self.a_att(g)
        alpha = torch.softmax(alpha, 1)
        return alpha


class BetaAttention(torch.nn.Module):

    def __init__(self, hidden_dim):
        super().__init__()
        self.b_att = nn.Linear(hidden_dim, hidden_dim)


    def forward(self, h):
        beta = self.b_att(h)
        beta = torch.tanh(beta)
        return beta


def attention_sum(alpha, beta, rev_v, rev_masks):
    c = (torch.sum(rev_masks, -1) > 0).type(torch.float)
    c = c.unsqueeze(-1)
    c = torch.sum(alpha * beta * rev_v * c, dim=1)
    return c


def sum_embeddings_with_mask(x, masks):
    x = x * masks.unsqueeze(-1)
    x = torch.sum(x, dim = -2)
    return x

### 5. Model architecture

In [295]:
class RETAIN(nn.Module):
    
    def __init__(self, num_codes, embedding_dim=128):
        super().__init__()

        self.embedding = nn.Embedding(num_codes, embedding_dim)

        self.rnn_a = nn.GRU(embedding_dim, embedding_dim, batch_first=True)
        self.rnn_b = nn.GRU(embedding_dim, embedding_dim, batch_first=True)

        self.att_a = AlphaAttention(embedding_dim)
        self.att_b = BetaAttention(embedding_dim)

        self.fc = nn.Linear(embedding_dim, 1)

        self.sigmoid = nn.Sigmoid()


    def forward(self, x, masks, rev_x, rev_masks):

        rev_x = self.embedding(rev_x)
        rev_x = sum_embeddings_with_mask(rev_x, rev_masks)

        g, _ = self.rnn_a(rev_x)
        h, _ = self.rnn_b(rev_x)

        alpha = self.att_a(g)
        beta = self.att_b(h)

        c = attention_sum(alpha, beta, rev_x, rev_masks)

        logits = self.fc(c)
        probs = self.sigmoid(logits)
        return probs.squeeze()
    


### 6. Evaluation function

In [296]:
from sklearn.metrics import precision_recall_fscore_support, roc_auc_score


def eval(model, val_loader):

    
    y_pred = torch.LongTensor()
    y_score = torch.Tensor()
    y_true = torch.LongTensor()

    model.eval()

    for x, masks, rev_x, rev_masks, y in val_loader:
        y_logit = model(x, masks, rev_x, rev_masks)

        y_hat = y_logit >= 0.5

        y_score = torch.cat((y_score,  y_logit.detach().to('cpu')), dim=0)
        y_pred = torch.cat((y_pred,  y_hat.detach().to('cpu')), dim=0)
        y_true = torch.cat((y_true, y.detach().to('cpu')), dim=0)
    
    p, r, f, _ = precision_recall_fscore_support(y_true, y_pred, average='binary')
    roc_auc = roc_auc_score(y_true, y_score)
    
    return p, r, f, roc_auc



### 7. Train the model

In [297]:
retain = RETAIN(num_codes = len(types))

criterion = nn.BCELoss()
optimizer = torch.optim.Adam(retain.parameters(), lr=0.0003)
n_epochs = 100


def train(model, train_loader, val_loader, n_epochs):

    model.train()
    
    for epoch in range(n_epochs):

        train_loss = 0

        for x, masks, rev_x, rev_masks, y in train_loader:
            optimizer.zero_grad()
            y_hat = model(x, masks, rev_x, rev_masks)
            loss = criterion(y_hat, y)
            loss.backward()
            optimizer.step()
            train_loss += loss.item()


        if epoch % 10 == 0:
            train_loss = train_loss / len(train_loader)
            print('Epoch: {} \t Training Loss: {:.6f}'.format(epoch, train_loss))
            p, r, f, roc_auc = eval(model, val_loader)
            print('Epoch: {} \t Validation p: {:.2f}, r:{:.2f}, f: {:.2f}, roc_auc: {:.2f}'.format(epoch, p, r, f, roc_auc))

    return round(roc_auc, 2)


train(retain, train_loader, val_loader, n_epochs)

Epoch: 0 	 Training Loss: 0.667214
Epoch: 0 	 Validation p: 0.68, r:0.86, f: 0.76, roc_auc: 0.67
Epoch: 10 	 Training Loss: 0.075948
Epoch: 10 	 Validation p: 0.81, r:0.72, f: 0.76, roc_auc: 0.80
Epoch: 20 	 Training Loss: 0.012533
Epoch: 20 	 Validation p: 0.80, r:0.73, f: 0.77, roc_auc: 0.80
Epoch: 30 	 Training Loss: 0.005566
Epoch: 30 	 Validation p: 0.79, r:0.74, f: 0.77, roc_auc: 0.80
Epoch: 40 	 Training Loss: 0.003681
Epoch: 40 	 Validation p: 0.79, r:0.74, f: 0.76, roc_auc: 0.80
Epoch: 50 	 Training Loss: 0.002958
Epoch: 50 	 Validation p: 0.79, r:0.74, f: 0.76, roc_auc: 0.80
Epoch: 60 	 Training Loss: 0.002613
Epoch: 60 	 Validation p: 0.79, r:0.74, f: 0.76, roc_auc: 0.80
Epoch: 70 	 Training Loss: 0.002289
Epoch: 70 	 Validation p: 0.79, r:0.74, f: 0.76, roc_auc: 0.80
Epoch: 80 	 Training Loss: 0.002281
Epoch: 80 	 Validation p: 0.79, r:0.74, f: 0.76, roc_auc: 0.80
Epoch: 90 	 Training Loss: 0.002101
Epoch: 90 	 Validation p: 0.79, r:0.74, f: 0.76, roc_auc: 0.80


0.8

### 8. Evaluate

In [298]:
print(eval(retain, val_loader))
print(retain)

for x, masks, rev_x, rev_masks, y in val_loader:
    y_hat = retain(x, masks, rev_x, rev_masks)
    y_hat = y_hat >= 0.5
    print(y_hat)

(0.7876106194690266, 0.7416666666666667, 0.7639484978540773, 0.80125)
RETAIN(
  (embedding): Embedding(619, 128)
  (rnn_a): GRU(128, 128, batch_first=True)
  (rnn_b): GRU(128, 128, batch_first=True)
  (att_a): AlphaAttention(
    (a_att): Linear(in_features=128, out_features=1, bias=True)
  )
  (att_b): BetaAttention(
    (b_att): Linear(in_features=128, out_features=128, bias=True)
  )
  (fc): Linear(in_features=128, out_features=1, bias=True)
  (sigmoid): Sigmoid()
)
tensor([False,  True,  True,  True, False,  True,  True, False, False,  True,
         True, False, False,  True,  True,  True,  True, False, False, False,
        False, False, False, False,  True,  True, False,  True,  True, False,
         True,  True])
tensor([False,  True, False, False, False,  True,  True,  True,  True,  True,
        False,  True, False, False, False, False,  True, False, False,  True,
         True, False,  True,  True,  True,  True,  True,  True, False, False,
         True,  True])
tensor([Fals

### 9. Save the model

In [299]:
torch.save(retain, "model.pt")
