# RETAIN - Reverse Time Attention Model

## 1. Overview


This notebook implements **RETAIN**, a recurrent neural network with reverse-time attention for interpretable healthcare predictions, allowing us to identify which medical events most influence the model's output.

In [1]:
import os
import pickle
import random
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F


# Set seed
seed = 24
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)
os.environ["PYTHONHASHSEED"] = str(seed)

# Define data path
DATA_PATH = "data2"

# Suppress warnings about ragged sequences
import warnings
warnings.filterwarnings('ignore', message='.*creating an ndarray from ragged nested sequences.*')

### About Raw Data

- Dataset: Synthetic data based on [MIMIC-III](https://mimic.physionet.org/gettingstarted/access/)  
- Input: Sequences of diagnosis codes for each patient  
- Task: Predict Heart Failure occurrence  
- Data is already preprocessed and ready to be loaded into the model.


In [2]:
pids = pickle.load(open(os.path.join(DATA_PATH,'train/pids.pkl'), 'rb')) 
vids = pickle.load(open(os.path.join(DATA_PATH,'train/vids.pkl'), 'rb'))
hfs = pickle.load(open(os.path.join(DATA_PATH,'train/hfs.pkl'), 'rb'))
seqs = pickle.load(open(os.path.join(DATA_PATH,'train/seqs.pkl'), 'rb'))
types = pickle.load(open(os.path.join(DATA_PATH,'train/types.pkl'), 'rb'))
rtypes = pickle.load(open(os.path.join(DATA_PATH,'train/rtypes.pkl'), 'rb'))

# Print lengths to verify
print("Number of patients (pids):", len(pids))
print("Number of visits (vids):", len(vids))
print("Number of heart failure labels (hfs):", len(hfs))
print("Number of sequences (seqs):", len(seqs))
print("Number of diagnosis types (types):", len(types))
print("Number of related types (rtypes):", len(rtypes))

Number of patients (pids): 1000
Number of visits (vids): 1000
Number of heart failure labels (hfs): 1000
Number of sequences (seqs): 1000
Number of diagnosis types (types): 619
Number of related types (rtypes): 619


where

- `pids`: contains the patient ids
- `vids`: contains a list of visit ids for each patient
- `hfs`: contains the heart failure label (0: normal, 1: heart failure) for each patient
- `seqs`: contains a list of visit (in ICD9 codes) for each patient
- `types`: contains the map from ICD9 codes to ICD-9 labels
- `rtypes`: contains the map from ICD9 labels to ICD9 codes


In [3]:
# take the 3rd patient as an example

print("Patient ID:", pids[3])
print("Heart Failure:", hfs[3])
print("# of visits:", len(vids[3]))
for visit in range(len(vids[3])):
    print(f"\t{visit}-th visit id:", vids[3][visit])
    print(f"\t{visit}-th visit diagnosis labels:", seqs[3][visit])
    print(f"\t{visit}-th visit diagnosis codes:", [rtypes[label] for label in seqs[3][visit]])

Patient ID: 47537
Heart Failure: 0
# of visits: 2
	0-th visit id: 0
	0-th visit diagnosis labels: [12, 103, 262, 285, 290, 292, 359, 416, 39, 225, 275, 294, 326, 267, 93]
	0-th visit diagnosis codes: ['DIAG_041', 'DIAG_276', 'DIAG_518', 'DIAG_560', 'DIAG_567', 'DIAG_569', 'DIAG_707', 'DIAG_785', 'DIAG_155', 'DIAG_456', 'DIAG_537', 'DIAG_571', 'DIAG_608', 'DIAG_529', 'DIAG_263']
	1-th visit id: 1
	1-th visit diagnosis labels: [12, 103, 240, 262, 290, 292, 319, 359, 510, 513, 577, 307, 8, 280, 18, 131]
	1-th visit diagnosis codes: ['DIAG_041', 'DIAG_276', 'DIAG_482', 'DIAG_518', 'DIAG_567', 'DIAG_569', 'DIAG_599', 'DIAG_707', 'DIAG_995', 'DIAG_998', 'DIAG_V09', 'DIAG_584', 'DIAG_031', 'DIAG_553', 'DIAG_070', 'DIAG_305']


### Data Overview

- `seqs` is a **3-level nested list**:  
  - `seqs[i][j][k]` gives the **k-th diagnosis code** for the **j-th visit** of the **i-th patient**.  
  - Example: `seqs[0][0]` → diagnosis codes of the **first visit of the first patient**.  
  - ICD9 codes like `DIAG_276` can be looked up online (e.g., *disorders of fluid electrolyte and acid-base balance*).

- `hfs` is a **list of heart failure labels**:  
  - `1` → patient has heart failure  
  - `0` → patient does not have heart failure  

- **Number of heart failure patients in the training set:**  
  - `sum(hfs)` gives the total number of HF patients  
  - Fraction of HF patients: `sum(hfs) / len(hfs)`  


In [4]:
print("number of heart failure patients:", sum(hfs))
print("ratio of heart failure patients: %.2f" % (sum(hfs) / len(hfs)))

number of heart failure patients: 548
ratio of heart failure patients: 0.55


## 2. Build the dataset

### - CustomDataset

In [5]:
from torch.utils.data import Dataset

class CustomDataset(Dataset):
    
    def __init__(self, seqs, hfs):
        """
        Store seqs and hfs as lists. Do NOT convert to np.array since sequences are ragged.
        """
        self.x = seqs  # keep as list
        self.y = hfs
    
    def __len__(self):
        return len(self.x)
    
    def __getitem__(self, index):
        x = self.x[index]
        y = self.y[index]
        return x, y

# Create dataset instance
dataset = CustomDataset(seqs, hfs)

# Quick check
print("Number of patients in dataset:", len(dataset))
print("Example patient first visit:", dataset[0][0])
print("Example patient label:", dataset[0][1])


Number of patients in dataset: 1000
Example patient first visit: [[85, 112, 346, 380, 269, 511, 114, 103, 530, 597, 511], [85, 103, 112, 513, 511, 19, 149, 530, 186, 66]]
Example patient label: 1


In [6]:
dataset = CustomDataset(seqs, hfs)

print(len(dataset))

1000


### Collate Function

In this section, we define a **collate function (`collate_fn`)** for batch training RNNs on patient diagnosis sequences.

**Key points:**

1. **Variable-length sequences:**  
   - Each patient has multiple visits, each with a varying number of diagnosis codes.  
   - Example: `seqs[i][j][k]` → k-th code of j-th visit of i-th patient.

2. **Padding:**  
   - Pad visits and diagnosis codes with `0` so all patients in a batch have the same shape.  
   - Ensures tensor compatibility for batch processing.

3. **Masking:**  
   - Create a mask with `1` for original codes and `0` for padded values.  
   - Allows the model to **ignore padded values** during training.

4. **Reversed sequences:**  
   - Flip visits in time (only true visits) for bi-directional RNNs.  
   - Create a reversed mask correspondingly.

In [7]:
def collate_fn(data):
    """
    Collate a list of samples into a batch for RNN training.

    Args:
        data: list of samples from CustomDataset, each (sequence, label)
    
    Returns:
        x: tensor (#patients, max_visits, max_codes), type=torch.long
        masks: tensor (#patients, max_visits, max_codes), type=torch.bool
        rev_x: same as x but reversed in time
        rev_masks: same as masks but reversed in time
        y: tensor (#patients), type=torch.float
    """
    sequences, labels = zip(*data)
    y = torch.tensor(labels, dtype=torch.float)
    
    num_patients = len(sequences)
    max_num_visits = max(len(patient) for patient in sequences)
    max_num_codes = max(len(visit) for patient in sequences for visit in patient)
    
    x = torch.zeros((num_patients, max_num_visits, max_num_codes), dtype=torch.long)
    rev_x = torch.zeros_like(x)
    masks = torch.zeros_like(x, dtype=torch.bool)
    rev_masks = torch.zeros_like(x, dtype=torch.bool)
    
    for i_patient, patient in enumerate(sequences):
        for j_visit, visit in enumerate(patient):
            visit_tensor = torch.tensor(visit, dtype=torch.long)
            x[i_patient, j_visit, :len(visit)] = visit_tensor
            rev_x[i_patient, len(patient)-j_visit-1, :len(visit)] = visit_tensor
            masks[i_patient, j_visit, :len(visit)] = True
            rev_masks[i_patient, len(patient)-j_visit-1, :len(visit)] = True
    
    return x, masks, rev_x, rev_masks, y


In [8]:
from torch.utils.data import DataLoader

loader = DataLoader(dataset, batch_size=10, collate_fn=collate_fn)
loader_iter = iter(loader)
x, masks, rev_x, rev_masks, y = next(loader_iter)

print(x.dtype)
print(y.dtype)
print(masks.dtype)
print(x.shape)
print(y.shape)



torch.int64
torch.float32
torch.bool
torch.Size([10, 3, 24])
torch.Size([10])


Now we have `CustomDataset` and `collate_fn()`. I split the dataset into training and validation sets.

In [9]:
from torch.utils.data.dataset import random_split

split = int(len(dataset)*0.8)

lengths = [split, len(dataset) - split]
train_dataset, val_dataset = random_split(dataset, lengths)

print("Length of train dataset:", len(train_dataset))
print("Length of val dataset:", len(val_dataset))

Length of train dataset: 800
Length of val dataset: 200


### DataLoader

In [10]:
from torch.utils.data import DataLoader

def load_data(train_dataset, val_dataset, collate_fn):
    
    batch_size = 32
    train_loader = DataLoader(train_dataset, batch_size=batch_size, collate_fn=collate_fn, shuffle=True)
    val_loader = DataLoader(val_dataset, batch_size=batch_size, collate_fn=collate_fn)
    
    return train_loader, val_loader


train_loader, val_loader = load_data(train_dataset, val_dataset, collate_fn)

## 3. RETAIN

RETAIN is essentially a RNN model with attention mechanism.
 
The idea of attention is quite simple: it boils down to weighted averaging. Let us consider machine translation in class as an example. When generating a translation of a source text, we first pass the source text through an encoder (an LSTM or an equivalent model) to obtain a sequence of encoder hidden states $\boldsymbol{h}_1, \dots, \boldsymbol{h}_T$. Then, at each step of generating a translation (decoding), we selectively attend to these encoder hidden states, that is, we construct a context vector $\boldsymbol{c}_i$ that is a weighted average of encoder hidden states.

$$\boldsymbol{c}_i = \underset{j}{\Sigma} a_{ij}\boldsymbol{h}_j$$

We choose the weights $a_{ij}$ based both on encoder hidden states $\boldsymbol{h}_1, \dots, \boldsymbol{h}_T$ and decoder hidden states $\boldsymbol{s}_1, \dots, \boldsymbol{s}_T$ and normalize them so that they encode a categorical probability distribution $p(\boldsymbol{h}_j | \boldsymbol{s}_i)$.

$$\boldsymbol{a}_{i} = \text{Softmax}\left( a(\boldsymbol{s}_i, \boldsymbol{h}_j) \right)$$

RETAIN has two different attention mechanisms. 
- One is to help figure out what are the important visits. This attention $\alpha_i$, which is scalar for the i-th visit, tells you the importance of the i-th visit.
- Then we have another similar attention mechanism. But in this case, this attention ways $\mathbf{\beta}_i$ is a vector. That gives us a more detailed view of underlying cause of the input. That is, which are the important features within a visit.

<img src=./img3/retain-1.png>

Unfolded view of RETAIN’s architecture: Given input sequence $\mathbf{x}_1 , . . . , \mathbf{x}_i$, we predict the label $\mathbf{y}_i$. 
- Step 1: Embedding, 
- Step 2: generating $\alpha$ values using RNN-$\alpha$, 
- Step 3: generating $\mathbf{\beta}$ values using RNN-$\beta$, 
- Step 4: Generating the context vector using attention and representation vectors, 
- Step 5: Making prediction. 

Note that in Steps 2 and 3 we use RNN in the reversed time.

<img src=./img3/retain-2.png>

<img src=./img3/retain-3.png>

- **AlphaAttention**

In [11]:
class AlphaAttention(nn.Module):
    def __init__(self, hidden_dim):
        super().__init__()
        self.a_att = nn.Linear(hidden_dim, 1)

    def forward(self, g):
        """
        Arguments:
            g: Tensor of shape (batch_size, seq_length, hidden_dim) 
               Output of RNN-alpha

        Returns:
            alpha: Tensor of shape (batch_size, seq_length, 1)
                   Normalized attention weights over visits
        """
        a = self.a_att(g)
        alpha = torch.softmax(a, dim=1)

        return alpha


- **BetaAttention**

In [12]:
class BetaAttention(nn.Module):
    def __init__(self, hidden_dim):
        super().__init__()
        self.b_att = nn.Linear(hidden_dim, hidden_dim)

    def forward(self, h):
        """
        Arguments:
            h: Tensor of shape (batch_size, seq_length, hidden_dim)
               Output of RNN-beta

        Returns:
            beta: Tensor of shape (batch_size, seq_length, hidden_dim)
                  Feature-wise attention weights
        """
        beta = torch.tanh(self.b_att(h))
        return beta


- **Attention Sum** 

In [13]:
def attention_sum(alpha, beta, rev_v, rev_masks):
    """
    Compute RETAIN context vector c = sum_t alpha_t * (beta_t ⊙ v_t),
    while ignoring padded visits.

    Args:
        alpha: tensor (B, T, 1)      -- visit-level attention weights
        beta:  tensor (B, T, H)      -- feature-level attention vectors
        rev_v: tensor (B, T, H)      -- visit embeddings in reversed time
        rev_masks: tensor (B, T, C)  -- padded masks (per-code) in reversed time

    Returns:
        c: tensor (B, H) -- context vector per patient
    """
    v_alpha = beta * alpha     

    visit_mask = rev_masks.sum(dim=2) > 0 
    
    visit_mask_f = visit_mask.unsqueeze(-1).to(dtype=rev_v.dtype)

    masked_v = rev_v * visit_mask_f

    v_alpha_masked = v_alpha * masked_v

    c = v_alpha_masked.sum(dim=1)

    return c


- **Build RETAIN**

Now, we can build the RETAIN model.

In [14]:
def sum_embeddings_with_mask(x, masks):
    """
    Mask select the embeddings for true visits (not padding visits) and then sum the embeddings for each visit up.

    Arguments:
        x: the embeddings of diagnosis sequence of shape (batch_size, # visits, # diagnosis codes, embedding_dim)
        masks: the padding masks of shape (batch_size, # visits, # diagnosis codes)

    Outputs:
        sum_embeddings: the sum of embeddings of shape (batch_size, # visits, embedding_dim)
    """
    
    x = x * masks.unsqueeze(-1)
    x = torch.sum(x, dim = -2)
    return x

In [15]:
class RETAIN(nn.Module):
    
    def __init__(self, num_codes, embedding_dim=128):
        super().__init__()
        # Define the embedding layer using `nn.Embedding`. Set `embDimSize` to 128.
        self.embedding = nn.Embedding(num_codes, embedding_dim)
        # Define the RNN-alpha using `nn.GRU()`; Set `hidden_size` to 128. Set `batch_first` to True.
        self.rnn_a = nn.GRU(embedding_dim, embedding_dim, batch_first=True)
        # Define the RNN-beta using `nn.GRU()`; Set `hidden_size` to 128. Set `batch_first` to True.
        self.rnn_b = nn.GRU(embedding_dim, embedding_dim, batch_first=True)
        # Define the alpha-attention using `AlphaAttention()`;
        self.att_a = AlphaAttention(embedding_dim)
        # Define the beta-attention using `BetaAttention()`;
        self.att_b = BetaAttention(embedding_dim)
        # Define the linear layers using `nn.Linear()`;
        self.fc = nn.Linear(embedding_dim, 1)
        # Define the final activation layer using `nn.Sigmoid().
        self.sigmoid = nn.Sigmoid()
    
    def forward(self, x, masks, rev_x, rev_masks):
        """
        Arguments:
            rev_x: the diagnosis sequence in reversed time of shape (# visits, batch_size, # diagnosis codes)
            rev_masks: the padding masks in reversed time of shape (# visits, batch_size, # diagnosis codes)

        Outputs:
            probs: probabilities of shape (batch_size)
        """
        # 1. Pass the reversed sequence through the embedding layer;
        rev_x = self.embedding(rev_x)
        # 2. Sum the reversed embeddings for each diagnosis code up for a visit of a patient.
        rev_x = sum_embeddings_with_mask(rev_x, rev_masks)
        # 3. Pass the reversed embegginds through the RNN-alpha and RNN-beta layer separately;
        g, _ = self.rnn_a(rev_x)
        h, _ = self.rnn_b(rev_x)
        # 4. Obtain the alpha and beta attentions using `AlphaAttention()` and `BetaAttention()`;
        alpha = self.att_a(g)
        beta = self.att_b(h)
        # 5. Sum the attention up using `attention_sum()`;
        c = attention_sum(alpha, beta, rev_x, rev_masks)
        # 6. Pass the context vector through the linear and activation layers.
        logits = self.fc(c)
        probs = self.sigmoid(logits)
        return probs.squeeze()
    

# load the model here
retain = RETAIN(num_codes = len(types))
retain

RETAIN(
  (embedding): Embedding(619, 128)
  (rnn_a): GRU(128, 128, batch_first=True)
  (rnn_b): GRU(128, 128, batch_first=True)
  (att_a): AlphaAttention(
    (a_att): Linear(in_features=128, out_features=1, bias=True)
  )
  (att_b): BetaAttention(
    (b_att): Linear(in_features=128, out_features=128, bias=True)
  )
  (fc): Linear(in_features=128, out_features=1, bias=True)
  (sigmoid): Sigmoid()
)

## 4. Training and Inferencing


In [16]:
from sklearn.metrics import precision_recall_fscore_support, roc_auc_score

def eval_model(model, val_loader):
    """
    Evaluate the model.
    
    Arguments:
        model: the RNN model
        val_loader: validation dataloader
        
    Outputs:
        precision: overall precision score
        recall: overall recall score
        f1: overall f1 score
        roc_auc: overall roc_auc score
    """
    model.eval()
    y_pred = torch.tensor([], dtype=torch.long)
    y_score = torch.tensor([], dtype=torch.float)
    y_true = torch.tensor([], dtype=torch.long)

    for x, masks, rev_x, rev_masks, y in val_loader:
        y_hat = model(x, masks, rev_x, rev_masks)
        y_score = torch.cat((y_score, y_hat.detach().cpu()), dim=0)
        y_hat_bin = (y_hat > 0.5).int()
        y_pred = torch.cat((y_pred, y_hat_bin.detach().cpu()), dim=0)
        y_true = torch.cat((y_true, y.detach().cpu()), dim=0)

    # Calculate metrics
    precision, recall, f1, _ = precision_recall_fscore_support(y_true, y_pred, average='binary')
    roc_auc = roc_auc_score(y_true, y_score)

    return precision, recall, f1, roc_auc


In [17]:
def train(model, train_loader, val_loader, n_epochs, optimizer, criterion):
    """
    Train the RNN model.

    Arguments:
        model: the RNN model
        train_loader: training dataloader
        val_loader: validation dataloader
        n_epochs: total number of epochs
        optimizer: optimizer
        criterion: loss function
    """
    for epoch in range(n_epochs):
        model.train()
        train_loss = 0.0

        for x, masks, rev_x, rev_masks, y in train_loader:
            optimizer.zero_grad()
            
            y_hat = model(x, masks, rev_x, rev_masks)
            y_hat = y_hat.view(y_hat.shape[0])
            
            loss = criterion(y_hat, y)
            loss.backward()
            optimizer.step()
            
            train_loss += loss.item()
        
        train_loss /= len(train_loader)
        print(f'Epoch {epoch+1} \t Training Loss: {train_loss:.6f}')
        
        # Evaluate on validation set
        p, r, f, roc_auc = eval_model(model, val_loader)
        print(f'Epoch {epoch+1} \t Validation Precision: {p:.2f}, Recall: {r:.2f}, F1: {f:.2f}, ROC AUC: {roc_auc:.2f}')


In [18]:
# load the model
retain = RETAIN(num_codes = len(types))

# load the loss function
criterion = nn.BCELoss()
# load the optimizer
optimizer = torch.optim.Adam(retain.parameters(), lr=1e-3)
# number of epochs to train the model
n_epochs = 5
train(retain, train_loader, val_loader, n_epochs=5, optimizer=optimizer, criterion=criterion)
 

Epoch 1 	 Training Loss: 0.646113
Epoch 1 	 Validation Precision: 0.75, Recall: 0.83, F1: 0.78, ROC AUC: 0.83
Epoch 2 	 Training Loss: 0.469833
Epoch 2 	 Validation Precision: 0.77, Recall: 0.74, F1: 0.75, ROC AUC: 0.84
Epoch 3 	 Training Loss: 0.293150
Epoch 3 	 Validation Precision: 0.78, Recall: 0.81, F1: 0.79, ROC AUC: 0.83
Epoch 4 	 Training Loss: 0.148171
Epoch 4 	 Validation Precision: 0.77, Recall: 0.82, F1: 0.79, ROC AUC: 0.84
Epoch 5 	 Training Loss: 0.068694
Epoch 5 	 Validation Precision: 0.82, Recall: 0.79, F1: 0.80, ROC AUC: 0.85


## 5. Sensitivity analysis

I will train the same model but with different hyperparameters. I will be using 0.1 and 0.001 for learning rate, and 16, 128 for embedding dimensions. It shows how model performance varies with different values of learning rate and embedding dimensions.

In [19]:
lr_hyperparameter = [1e-1, 1e-3]
embedding_dim_hyperparameter = [8, 128]
n_epochs = 5
results = {}

for lr in lr_hyperparameter:
    for embedding_dim in embedding_dim_hyperparameter:
        print ('='*50)
        print ({'learning rate': lr, "embedding_dim": embedding_dim})
        print ('-'*50)
        """ 
        TODO: 
            1. Load the model by specifying `embedding_dim` as input to RETAIN. It will create different model with different embedding dimension.
            2. Load the loss function `nn.BCELoss`
            3. Load the optimizer `torch.optim.Adam` with learning rate using `lr` variable
        """
        # load the model
        retain = RETAIN(len(types), embedding_dim)

        # load the loss function
        criterion = nn.BCELoss()
        # load the optimizer
        optimizer = torch.optim.Adam(retain.parameters(), lr)
        n_epochs = 5
        train(retain, train_loader, val_loader, n_epochs=5, optimizer=optimizer, criterion=criterion)
        roc_auc = train(retain, train_loader, val_loader, n_epochs=5, optimizer=optimizer, criterion=criterion)
        results['lr:{},emb:{}'.format(str(lr), str(embedding_dim))] =  roc_auc

{'learning rate': 0.1, 'embedding_dim': 8}
--------------------------------------------------
Epoch 1 	 Training Loss: 0.671263
Epoch 1 	 Validation Precision: 0.72, Recall: 0.66, F1: 0.69, ROC AUC: 0.78
Epoch 2 	 Training Loss: 0.562615
Epoch 2 	 Validation Precision: 0.69, Recall: 0.84, F1: 0.76, ROC AUC: 0.81
Epoch 3 	 Training Loss: 0.531844
Epoch 3 	 Validation Precision: 0.66, Recall: 0.82, F1: 0.73, ROC AUC: 0.78
Epoch 4 	 Training Loss: 0.491145
Epoch 4 	 Validation Precision: 0.81, Recall: 0.56, F1: 0.66, ROC AUC: 0.78
Epoch 5 	 Training Loss: 0.464876
Epoch 5 	 Validation Precision: 0.68, Recall: 0.78, F1: 0.72, ROC AUC: 0.77
Epoch 1 	 Training Loss: 0.503799
Epoch 1 	 Validation Precision: 0.69, Recall: 0.81, F1: 0.74, ROC AUC: 0.76
Epoch 2 	 Training Loss: 0.456363
Epoch 2 	 Validation Precision: 0.69, Recall: 0.85, F1: 0.76, ROC AUC: 0.79
Epoch 3 	 Training Loss: 0.390751
Epoch 3 	 Validation Precision: 0.70, Recall: 0.65, F1: 0.67, ROC AUC: 0.76
Epoch 4 	 Training Loss: 0