# Recurrent Neural Network (RNN)

## 1. Overview

In this homework, we will build a **bi-directional RNN** to predict **Heart Failure** from patients' diagnosis codes.  
The recurrent nature of RNNs allows modeling **temporal relationships** between multiple visits of a patient.

### About Raw Data

- Dataset: Synthetic data based on [MIMIC-III](https://mimic.physionet.org/gettingstarted/access/)  
- Input: Sequences of diagnosis codes for each patient  
- Task: Predict Heart Failure occurrence  
- Data is already preprocessed and ready to be loaded into the model.


In [1]:
import os
import sys
import pickle
import random
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F


# Set seed
seed = 24
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)
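# Note: setting PYTHONHASHSEED here only affects subprocesses; the running interpreter's hash seed is fixed at startup
os.environ["PYTHONHASHSEED"] = str(seed)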

# Define data path
DATA_PATH = "data2"

# Suppress warnings about ragged sequences
import warnings
warnings.filterwarnings('ignore', message='.*creating an ndarray from ragged nested sequences.*')


In [2]:
def load_pickle(name):
    # Use a context manager so each file handle is closed after loading
    with open(os.path.join(DATA_PATH, 'train', f'{name}.pkl'), 'rb') as f:
        return pickle.load(f)

pids = load_pickle('pids')
vids = load_pickle('vids')
hfs = load_pickle('hfs')
seqs = load_pickle('seqs')
types = load_pickle('types')
rtypes = load_pickle('rtypes')

# Print lengths to verify
print("Number of patients (pids):", len(pids))
print("Number of per-patient visit-id lists (vids):", len(vids))
print("Number of heart failure labels (hfs):", len(hfs))
print("Number of sequences (seqs):", len(seqs))
print("Number of diagnosis types (types):", len(types))
print("Number of reverse-map entries (rtypes):", len(rtypes))

Number of patients (pids): 1000
Number of per-patient visit-id lists (vids): 1000
Number of heart failure labels (hfs): 1000
Number of sequences (seqs): 1000
Number of diagnosis types (types): 619
Number of reverse-map entries (rtypes): 619


where

- `pids`: contains the patient IDs
- `vids`: contains a list of visit IDs for each patient
- `hfs`: contains the heart failure label (0: normal, 1: heart failure) for each patient
- `seqs`: contains, for each patient, a list of visits, where each visit is a list of integer diagnosis labels
- `types`: contains the map from ICD-9 codes to integer labels
- `rtypes`: contains the reverse map, from integer labels back to ICD-9 codes


In [3]:
# take the patient at index 3 (the 4th patient) as an example

print("Patient ID:", pids[3])
print("Heart Failure:", hfs[3])
print("# of visits:", len(vids[3]))
for visit in range(len(vids[3])):
    print(f"\t{visit}-th visit id:", vids[3][visit])
    print(f"\t{visit}-th visit diagnosis labels:", seqs[3][visit])
    print(f"\t{visit}-th visit diagnosis codes:", [rtypes[label] for label in seqs[3][visit]])

Patient ID: 47537
Heart Failure: 0
# of visits: 2
	0-th visit id: 0
	0-th visit diagnosis labels: [12, 103, 262, 285, 290, 292, 359, 416, 39, 225, 275, 294, 326, 267, 93]
	0-th visit diagnosis codes: ['DIAG_041', 'DIAG_276', 'DIAG_518', 'DIAG_560', 'DIAG_567', 'DIAG_569', 'DIAG_707', 'DIAG_785', 'DIAG_155', 'DIAG_456', 'DIAG_537', 'DIAG_571', 'DIAG_608', 'DIAG_529', 'DIAG_263']
	1-th visit id: 1
	1-th visit diagnosis labels: [12, 103, 240, 262, 290, 292, 319, 359, 510, 513, 577, 307, 8, 280, 18, 131]
	1-th visit diagnosis codes: ['DIAG_041', 'DIAG_276', 'DIAG_482', 'DIAG_518', 'DIAG_567', 'DIAG_569', 'DIAG_599', 'DIAG_707', 'DIAG_995', 'DIAG_998', 'DIAG_V09', 'DIAG_584', 'DIAG_031', 'DIAG_553', 'DIAG_070', 'DIAG_305']
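The decoding above uses `rtypes[label]`, which only works if `rtypes` is the inverse of `types`. A minimal sketch of that relationship, assuming both are plain Python dicts (the notebook does not show their concrete type); `types_excerpt` and `rtypes_excerpt` are hypothetical names holding three label/code pairs taken from the first visit printed above:

```python
# Hypothetical excerpt of `types` (ICD-9 code -> integer label), using three
# label/code pairs from the first visit of the patient at index 3
types_excerpt = {'DIAG_041': 12, 'DIAG_276': 103, 'DIAG_518': 262}

# Invert the mapping to get label -> code, which is how `rtypes` is used above
rtypes_excerpt = {label: code for code, label in types_excerpt.items()}

visit = [12, 103, 262]
codes = [rtypes_excerpt[label] for label in visit]
print(codes)  # ['DIAG_041', 'DIAG_276', 'DIAG_518']

# Round trip: code -> label -> code recovers the original code
assert all(rtypes_excerpt[types_excerpt[c]] == c for c in types_excerpt)
```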


### Data Overview

- `seqs` is a **3-level nested list**:  
  - `seqs[i][j][k]` gives the **k-th diagnosis label** (an integer) for the **j-th visit** of the **i-th patient**; `rtypes` maps it back to an ICD-9 code.  
  - Example: `seqs[0][0]` → diagnosis labels of the **first visit of the first patient**.  
  - ICD-9 codes such as `DIAG_276` (ICD-9 276, *disorders of fluid, electrolyte, and acid-base balance*) can be looked up online.

- `hfs` is a **list of heart failure labels**:  
  - `1` → patient has heart failure  
  - `0` → patient does not have heart failure  

- **Number of heart failure patients in the training set:**  
  - `sum(hfs)` gives the total number of HF patients  
  - Fraction of HF patients: `sum(hfs) / len(hfs)`  


In [4]:
print("number of heart failure patients:", sum(hfs))
print("ratio of heart failure patients: %.2f" % (sum(hfs) / len(hfs)))

number of heart failure patients: 548
ratio of heart failure patients: 0.55


## 2. Build the dataset

### CustomDataset

In [5]:
from torch.utils.data import Dataset

class CustomDataset(Dataset):
    
    def __init__(self, seqs, hfs):
        """
        Store seqs and hfs as lists. Do NOT convert to np.array since sequences are ragged.
        """
        self.x = seqs  # keep as list
        self.y = hfs
    
    def __len__(self):
        return len(self.x)
    
    def __getitem__(self, index):
        x = self.x[index]
        y = self.y[index]
        return x, y

# Create dataset instance
dataset = CustomDataset(seqs, hfs)

# Quick check
print("Number of patients in dataset:", len(dataset))
print("Example patient visits:", dataset[0][0])
print("Example patient label:", dataset[0][1])


Number of patients in dataset: 1000
Example patient visits: [[85, 112, 346, 380, 269, 511, 114, 103, 530, 597, 511], [85, 103, 112, 513, 511, 19, 149, 530, 186, 66]]
Example patient label: 1


In [6]:
dataset = CustomDataset(seqs, hfs)

print(len(dataset))

1000


### Collate Function

In this section, we define a **collate function (`collate_fn`)** for batch training RNNs on patient diagnosis sequences.

**Key points:**

1. **Variable-length sequences:**  
   - Each patient has multiple visits, each with a varying number of diagnosis codes.  
   - Example: `seqs[i][j][k]` → k-th code of j-th visit of i-th patient.

2. **Padding:**  
   - Pad visits and diagnosis codes with `0` so all patients in a batch have the same shape.  
   - Ensures tensor compatibility for batch processing.

3. **Masking:**  
   - Create a mask with `1` for original codes and `0` for padded values.  
   - Allows the model to **ignore padded values** during training.

4. **Reversed sequences:**  
   - Flip visits in time (only true visits) for bi-directional RNNs.  
   - Create a reversed mask correspondingly.
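The padding, masking and reversal steps can be traced on a tiny batch. A minimal sketch with a hypothetical two-patient toy batch (not from the dataset) that mirrors the loop structure used in `collate_fn`, and shows why only the true visits are reversed:

```python
import torch

# Hypothetical toy batch: patient 0 has 2 visits, patient 1 has 1 visit
batch = [[[5, 7, 9], [2]], [[4, 1]]]

max_visits = max(len(p) for p in batch)              # 2
max_codes = max(len(v) for p in batch for v in p)    # 3

x = torch.zeros((len(batch), max_visits, max_codes), dtype=torch.long)
rev_x = torch.zeros_like(x)
mask = torch.zeros_like(x, dtype=torch.bool)

for i, patient in enumerate(batch):
    n = len(patient)
    for j, visit in enumerate(patient):
        codes = torch.tensor(visit, dtype=torch.long)
        x[i, j, :len(visit)] = codes                 # pad codes on the right with 0
        mask[i, j, :len(visit)] = True               # True for real codes, False for padding
        rev_x[i, n - j - 1, :len(visit)] = codes     # reverse only the n true visits

print(x[1].tolist())                         # [[4, 1, 0], [0, 0, 0]]
print(rev_x[0].tolist())                     # [[2, 0, 0], [5, 7, 9]]
print(rev_x[1].tolist())                     # [[4, 1, 0], [0, 0, 0]]
# Flipping the whole visit axis would instead move the padded visit to the front
print(torch.flip(x[1], dims=[0]).tolist())   # [[0, 0, 0], [4, 1, 0]]
```

Keeping padded visits at the end in both directions means the "last true visit" can be found the same way for `x` and `rev_x`.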

In [7]:
def collate_fn(data):
    """
    Collate a list of samples into a batch for RNN training.

    Args:
        data: list of samples from CustomDataset, each (sequence, label)
    
    Returns:
        x: tensor (#patients, max_visits, max_codes), type=torch.long
        masks: tensor (#patients, max_visits, max_codes), type=torch.bool
        rev_x: same as x but reversed in time
        rev_masks: same as masks but reversed in time
        y: tensor (#patients), type=torch.float
    """
    sequences, labels = zip(*data)
    y = torch.tensor(labels, dtype=torch.float)
    
    num_patients = len(sequences)
    max_num_visits = max(len(patient) for patient in sequences)
    max_num_codes = max(len(visit) for patient in sequences for visit in patient)
    
    x = torch.zeros((num_patients, max_num_visits, max_num_codes), dtype=torch.long)
    rev_x = torch.zeros_like(x)
    masks = torch.zeros_like(x, dtype=torch.bool)
    rev_masks = torch.zeros_like(x, dtype=torch.bool)
    
    for i_patient, patient in enumerate(sequences):
        for j_visit, visit in enumerate(patient):
            visit_tensor = torch.tensor(visit, dtype=torch.long)
            x[i_patient, j_visit, :len(visit)] = visit_tensor
            rev_x[i_patient, len(patient)-j_visit-1, :len(visit)] = visit_tensor
            masks[i_patient, j_visit, :len(visit)] = True
            rev_masks[i_patient, len(patient)-j_visit-1, :len(visit)] = True
    
    return x, masks, rev_x, rev_masks, y


In [8]:
from torch.utils.data import DataLoader

loader = DataLoader(dataset, batch_size=10, collate_fn=collate_fn)
loader_iter = iter(loader)
x, masks, rev_x, rev_masks, y = next(loader_iter)

print(x.dtype)
print(y.dtype)
print(masks.dtype)
print(x.shape)
print(y.shape)



torch.int64
torch.float32
torch.bool
torch.Size([10, 3, 24])
torch.Size([10])


With `CustomDataset` and `collate_fn()` in place, we split the dataset 80/20 into training and validation sets.

In [9]:
from torch.utils.data.dataset import random_split

split = int(len(dataset)*0.8)

lengths = [split, len(dataset) - split]
train_dataset, val_dataset = random_split(dataset, lengths)

print("Length of train dataset:", len(train_dataset))
print("Length of val dataset:", len(val_dataset))

Length of train dataset: 800
Length of val dataset: 200
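`random_split` draws from PyTorch's global RNG by default, so the split above depends on the earlier `torch.manual_seed(seed)` call and on any random operations run before this cell. A sketch of making the split independent of that by passing an explicitly seeded `torch.Generator` through the `generator` argument; `seeded_split` is a hypothetical helper, and `range(n_items)` stands in for the dataset:

```python
import torch
from torch.utils.data import random_split

def seeded_split(n_items, train_frac=0.8, seed=24):
    """Hypothetical helper: split indices 0..n_items-1 reproducibly."""
    split = int(n_items * train_frac)
    gen = torch.Generator().manual_seed(seed)   # private RNG, unaffected by global state
    return random_split(range(n_items), [split, n_items - split], generator=gen)

a_train, a_val = seeded_split(1000)
b_train, b_val = seeded_split(1000)
print(len(a_train), len(a_val))                    # 800 200
print(list(a_val.indices) == list(b_val.indices))  # True
```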


### DataLoader

In [10]:
from torch.utils.data import DataLoader

def load_data(train_dataset, val_dataset, collate_fn):
    
    batch_size = 32
    train_loader = DataLoader(train_dataset, batch_size=batch_size, collate_fn=collate_fn, shuffle=True)
    val_loader = DataLoader(val_dataset, batch_size=batch_size, collate_fn=collate_fn)
    
    return train_loader, val_loader


train_loader, val_loader = load_data(train_dataset, val_dataset, collate_fn)

## 3. Naive Bi-directional RNN

In this section, we implement a **naive bi-directional RNN** for heart failure prediction.

**Key steps:**

1. **Embedding layer:**  
   - Transform the diagnosis codes of each visit into dense vectors with `nn.Embedding(num_embeddings, embedding_dim)`.  
   - `num_embeddings` = total number of unique diagnosis codes.  
   - `embedding_dim` = size of the embedding vector for each code.

2. **Visit representation:**  
   - Sum the code embeddings within each visit, using the mask so that padded codes contribute nothing.

3. **Bi-directional RNN:**  
   - One GRU reads the visits in chronological order; a second GRU reads the time-reversed visits.  
   - Together they capture temporal dependencies across visits in both directions.

4. **Output:**  
   - Take each GRU's hidden state at the last true visit, concatenate the two, and pass the result through a linear layer and a sigmoid to obtain the probability of heart failure.


**Summary:**  
- Each patient sequence → Embedding → masked sum per visit → forward and reverse GRUs → Linear → Sigmoid → Prediction.  
- Masks zero out padded codes in the per-visit sum and locate each patient's last true visit, so hidden states computed over padded visits are never used.
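The forward half of this pipeline can be traced shape by shape. A minimal sketch with random weights and a hypothetical two-patient batch (the sizes `B`, `V`, `C`, `D` are arbitrary choices for illustration); the reverse direction repeats the same steps on the time-reversed input:

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
B, V, C, D = 2, 3, 4, 8                          # patients, visit slots, code slots, embedding size
x = torch.randint(0, 619, (B, V, C))             # padded code labels (619 codes, as in `types`)
masks = torch.zeros(B, V, C, dtype=torch.bool)
masks[0, :, :2] = True                           # patient 0: 3 true visits, 2 codes each
masks[1, 0, :] = True                            # patient 1: 1 true visit, 4 codes

emb = nn.Embedding(619, D)(x)                    # (B, V, C, D)
visits = (emb * masks.unsqueeze(-1)).sum(dim=2)  # (B, V, D): padded codes contribute 0
out, _ = nn.GRU(D, D, batch_first=True)(visits)  # (B, V, D): one hidden state per visit
lengths = masks.any(dim=2).sum(dim=1)            # true visits per patient: [3, 1]
last = out[torch.arange(B), lengths - 1]         # (B, D): state at each last true visit
print(emb.shape, visits.shape, out.shape, last.shape)
```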


### 3.1 Mask Selection

In [11]:
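def sum_embeddings_with_mask(x, masks):
    """
    Sum the embeddings of the true (non-padded) codes in each visit.

    x: (batch_size, # visits, # codes, embedding_dim); masks: (batch_size, # visits, # codes)
    Returns: (batch_size, # visits, embedding_dim)
    """
    # Add a trailing axis so the mask broadcasts over embedding_dim, zero out padding, sum over codes
    return (x * masks.unsqueeze(-1)).sum(dim=2)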
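Masking matters here because padding uses the value `0`, which is also an index into the embedding table (and, assuming labels run from 0 to 618, a real diagnosis label), so an unmasked padded slot would add the embedding of code `0` to the visit. A small check with a hypothetical 5-code embedding table and one visit whose true codes are `[2, 4]`:

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
emb = nn.Embedding(5, 3)
# One patient, one visit with true codes [2, 4], padded with 0 to length 4
x = torch.tensor([[[2, 4, 0, 0]]])
mask = torch.tensor([[[True, True, False, False]]])

e = emb(x)                                       # (1, 1, 4, 3)
masked = (e * mask.unsqueeze(-1)).sum(dim=2)     # sums only codes 2 and 4
unmasked = e.sum(dim=2)                          # also adds emb(0) once per padded slot
expected = emb.weight[2] + emb.weight[4]
print(torch.allclose(masked[0, 0], expected))                         # True
print(torch.allclose(unmasked[0, 0], expected + 2 * emb.weight[0]))   # True
```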

In [12]:
import random
import ast
import inspect


def uses_loop(function):
    loop_statements = ast.For, ast.While, ast.AsyncFor

    nodes = ast.walk(ast.parse(inspect.getsource(function)))
    return any(isinstance(node, loop_statements) for node in nodes)

def generate_random_mask(batch_size, max_num_visits , max_num_codes):
    num_visits = [random.randint(1, max_num_visits) for _ in range(batch_size)]
    num_codes = []
    for n in num_visits:
        num_codes_visit = [0] * max_num_visits
        for i in range(n):
            num_codes_visit[i] = (random.randint(1, max_num_codes))
        num_codes.append(num_codes_visit)
    masks = [torch.ones((l,), dtype=torch.bool) for num_codes_visit in num_codes for l in num_codes_visit]
    masks = torch.stack([torch.cat([i, i.new_zeros(max_num_codes - i.size(0))], 0) for i in masks], 0)
    masks = masks.view((batch_size, max_num_visits, max_num_codes)).bool()
    return masks


batch_size = 16
max_num_visits = 10
max_num_codes = 20
embedding_dim = 100

torch.random.manual_seed(7)
x = torch.randn((batch_size, max_num_visits , max_num_codes, embedding_dim))
masks = generate_random_mask(batch_size, max_num_visits , max_num_codes)
out = sum_embeddings_with_mask(x, masks)

assert uses_loop(sum_embeddings_with_mask) is False
assert out.shape == (batch_size, max_num_visits, embedding_dim)




In [13]:
def get_last_visit(hidden_states, masks):
    """
    Obtain the hidden state for the last true visit (not padding visits).

    Arguments:
        hidden_states: tensor of shape (batch_size, # visits, embedding_dim)
        masks: tensor of shape (batch_size, # visits, # diagnosis codes)

    Outputs:
        last_hidden_state: tensor of shape (batch_size, embedding_dim)
    """ 
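    visit_mask = masks.any(dim=2)      # (batch_size, # visits): True where a visit has any code
    lengths = visit_mask.sum(dim=1)    # number of true visits per patient
    # Assumes true visits come first, so the last true visit sits at index lengths - 1
    batch_indices = torch.arange(hidden_states.size(0), device=hidden_states.device)
    last_hidden_state = hidden_states[batch_indices, lengths - 1, :]

    return last_hidden_state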
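The indexing `hidden_states[batch_indices, lengths - 1, :]` picks one visit per patient with advanced indexing. A hand-checkable sketch with hypothetical hidden states whose value at visit `j` is simply `j`; note that it assumes every patient has at least one true visit, since a length of 0 would give index `-1` and silently select the last (padded) slot:

```python
import torch

# Two patients, 4 visit slots, hidden size 2; the hidden state at visit j is filled with j
hidden = torch.arange(4).float().view(1, 4, 1).expand(2, 4, 2)
# Masks with a codes axis of size 1: patient 0 has 3 true visits, patient 1 has 1
masks = torch.tensor([[[1], [1], [1], [0]],
                      [[1], [0], [0], [0]]], dtype=torch.bool)

lengths = masks.any(dim=2).sum(dim=1)            # tensor([3, 1])
last = hidden[torch.arange(2), lengths - 1]      # rows (0, 2) and (1, 0)
print(lengths.tolist(), last.tolist())           # [3, 1] [[2.0, 2.0], [0.0, 0.0]]
```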


In [14]:
print(uses_loop(get_last_visit))

max_num_visits = 10
batch_size = 16
max_num_codes = 20
embedding_dim = 100

torch.random.manual_seed(7)
hidden_states = torch.randn((batch_size, max_num_visits, embedding_dim))
masks = generate_random_mask(batch_size, max_num_visits , max_num_codes)
out = get_last_visit(hidden_states, masks)

print(out.shape) 


False
torch.Size([16, 100])


### 3.2 Build NaiveRNN

In [15]:
class NaiveRNN(nn.Module):
    def __init__(self, num_codes):
        super().__init__()
        self.embedding = nn.Embedding(num_embeddings=num_codes, embedding_dim=128)
        self.rnn = nn.GRU(input_size=128, hidden_size=128, batch_first=True)
        self.rev_rnn = nn.GRU(input_size=128, hidden_size=128, batch_first=True)
        self.fc = nn.Linear(256, 1)
        self.sigmoid = nn.Sigmoid()
    
    def forward(self, x, masks, rev_x, rev_masks):
        batch_size = x.shape[0]
        
        # Forward direction
        x = self.embedding(x)
        x = sum_embeddings_with_mask(x, masks)
        output, _ = self.rnn(x)
        true_h_n = get_last_visit(output, masks)
        
        # Reverse direction
        rev_x = self.embedding(rev_x)
        rev_x = sum_embeddings_with_mask(rev_x, rev_masks)
        output_rev, _ = self.rev_rnn(rev_x)
        true_h_n_rev = get_last_visit(output_rev, rev_masks)
        
        # Concatenate hidden states and pass through linear + sigmoid
        logits = self.fc(torch.cat([true_h_n, true_h_n_rev], dim=1))
        probs = self.sigmoid(logits)
        return probs.view(batch_size)
    

naive_rnn = NaiveRNN(num_codes=len(types))
naive_rnn


NaiveRNN(
  (embedding): Embedding(619, 128)
  (rnn): GRU(128, 128, batch_first=True)
  (rev_rnn): GRU(128, 128, batch_first=True)
  (fc): Linear(in_features=256, out_features=1, bias=True)
  (sigmoid): Sigmoid()
)
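`NaiveRNN` runs two unidirectional GRUs and needs `rev_x`/`rev_masks` from `collate_fn`. PyTorch can express the same idea with a single `nn.GRU(..., bidirectional=True)` over a packed sequence (`torch.nn.utils.rnn.pack_padded_sequence`), which also skips padded visits entirely. The sketch below swaps in that technique; `PackedBiGRU` is a hypothetical class, not part of the assignment, and it is not claimed to reproduce the results reported in this notebook:

```python
import torch
import torch.nn as nn
from torch.nn.utils.rnn import pack_padded_sequence


class PackedBiGRU(nn.Module):
    """Hypothetical alternative to NaiveRNN: one bidirectional GRU over packed visits."""

    def __init__(self, num_codes, dim=128):
        super().__init__()
        self.embedding = nn.Embedding(num_codes, dim)
        # bidirectional=True keeps separate weights for the forward and backward directions
        self.rnn = nn.GRU(dim, dim, batch_first=True, bidirectional=True)
        self.fc = nn.Linear(2 * dim, 1)

    def forward(self, x, masks):
        # (B, V, C) -> (B, V, C, D) -> masked sum over codes -> (B, V, D)
        visits = (self.embedding(x) * masks.unsqueeze(-1)).sum(dim=2)
        # Number of true visits per patient; pack_padded_sequence wants lengths on the CPU
        lengths = masks.any(dim=2).sum(dim=1).cpu()
        packed = pack_padded_sequence(visits, lengths, batch_first=True, enforce_sorted=False)
        _, h_n = self.rnn(packed)  # h_n: (2, B, D) for one layer, in the original batch order
        # h_n[0]: forward state after the last true visit; h_n[1]: backward state after the first visit
        h = torch.cat([h_n[0], h_n[1]], dim=1)
        return torch.sigmoid(self.fc(h)).squeeze(-1)


torch.manual_seed(0)
x = torch.randint(0, 619, (4, 3, 5))
masks = torch.zeros(4, 3, 5, dtype=torch.bool)
masks[:, 0, :2] = True   # every patient has at least one true visit
masks[:2, 1, :3] = True  # patients 0 and 1 have a second visit

model = PackedBiGRU(num_codes=619).eval()
probs = model(x, masks)
print(probs.shape)  # torch.Size([4])

# Changing the values in padded slots leaves the predictions unchanged
x_repadded = x.masked_fill(~masks, 0)
print(torch.allclose(probs, model(x_repadded, masks)))  # True
```

The backward direction of a packed bidirectional GRU starts at each patient's last true visit, which is what the separate `rev_rnn` achieves in `NaiveRNN` by reading `rev_x`.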

## 4. Model Training

**Loss and Optimizer**

In [16]:
criterion = nn.BCELoss()
optimizer = torch.optim.Adam(naive_rnn.parameters(), lr=1e-3)
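`nn.BCELoss` expects probabilities, which is why `NaiveRNN` ends with a sigmoid. PyTorch also provides `nn.BCEWithLogitsLoss`, which applies the sigmoid internally and is documented as more numerically stable; using it would mean dropping the final `nn.Sigmoid` from the model and thresholding logits at 0 instead of probabilities at 0.5. A quick check on hypothetical random logits that the two losses agree:

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
logits = torch.randn(8)
targets = torch.randint(0, 2, (8,)).float()

# BCELoss takes probabilities; BCEWithLogitsLoss fuses the sigmoid into the loss
loss_probs = nn.BCELoss()(torch.sigmoid(logits), targets)
loss_logits = nn.BCEWithLogitsLoss()(logits, targets)
print(torch.allclose(loss_probs, loss_logits, atol=1e-5))  # True
```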

**Evaluate**

In [17]:
from sklearn.metrics import precision_recall_fscore_support, roc_auc_score

def eval_model(model, val_loader):
    """
    Evaluate the model.

    Arguments:
        model: the RNN model
        val_loader: validation dataloader

    Outputs:
        precision: overall precision score
        recall: overall recall score
        f1: overall f1 score
        roc_auc: overall roc_auc score
    """
    model.eval()
    y_score, y_true = [], []

    # Disable gradient tracking during evaluation
    with torch.no_grad():
        for x, masks, rev_x, rev_masks, y in val_loader:
            y_hat = model(x, masks, rev_x, rev_masks)
            y_score.append(y_hat.cpu())
            y_true.append(y.cpu())

    y_score = torch.cat(y_score).numpy()
    y_true = torch.cat(y_true).long().numpy()
    # Threshold predicted probabilities at 0.5 for the binary metrics
    y_pred = (y_score > 0.5).astype(int)

    # Calculate metrics
    precision, recall, f1, _ = precision_recall_fscore_support(y_true, y_pred, average='binary')
    roc_auc = roc_auc_score(y_true, y_score)

    return precision, recall, f1, roc_auc
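`eval_model` thresholds the predicted probabilities at 0.5 for precision, recall and F1, while ROC AUC uses the raw scores and does not depend on any threshold. A hand-checkable toy example with hypothetical labels and scores:

```python
from sklearn.metrics import precision_recall_fscore_support, roc_auc_score

y_true = [1, 1, 0, 0, 1]
y_score = [0.9, 0.4, 0.2, 0.6, 0.8]            # hypothetical predicted probabilities
y_pred = [int(s > 0.5) for s in y_score]       # [1, 0, 0, 1, 1]

# TP = 2, FP = 1, FN = 1 -> precision = recall = F1 = 2/3
p, r, f, _ = precision_recall_fscore_support(y_true, y_pred, average='binary')
# 5 of the 6 (positive, negative) pairs are ranked correctly -> AUC = 5/6
auc = roc_auc_score(y_true, y_score)
print(round(p, 3), round(r, 3), round(f, 3), round(auc, 3))  # 0.667 0.667 0.667 0.833
```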


**Training and evaluation**

In [18]:
def train(model, train_loader, val_loader, n_epochs, optimizer, criterion):
    """
    Train the RNN model.

    Arguments:
        model: the RNN model
        train_loader: training dataloader
        val_loader: validation dataloader
        n_epochs: total number of epochs
        optimizer: optimizer
        criterion: loss function
    """
    for epoch in range(n_epochs):
        model.train()
        train_loss = 0.0

        for x, masks, rev_x, rev_masks, y in train_loader:
            optimizer.zero_grad()
            
            y_hat = model(x, masks, rev_x, rev_masks)
            y_hat = y_hat.view(y_hat.shape[0])
            
            loss = criterion(y_hat, y)
            loss.backward()
            optimizer.step()
            
            train_loss += loss.item()
        
        train_loss /= len(train_loader)
        print(f'Epoch {epoch+1} \t Training Loss: {train_loss:.6f}')
        
        # Evaluate on validation set
        p, r, f, roc_auc = eval_model(model, val_loader)
        print(f'Epoch {epoch+1} \t Validation Precision: {p:.2f}, Recall: {r:.2f}, F1: {f:.2f}, ROC AUC: {roc_auc:.2f}')


In [19]:
# number of epochs to train the model
n_epochs = 5
train(naive_rnn, train_loader, val_loader, n_epochs=n_epochs, optimizer=optimizer, criterion=criterion)


Epoch 1 	 Training Loss: 0.616105
Epoch 1 	 Validation Precision: 0.71, Recall: 0.89, F1: 0.79, ROC AUC: 0.84
Epoch 2 	 Training Loss: 0.436958
Epoch 2 	 Validation Precision: 0.71, Recall: 0.84, F1: 0.77, ROC AUC: 0.84
Epoch 3 	 Training Loss: 0.331513
Epoch 3 	 Validation Precision: 0.73, Recall: 0.80, F1: 0.76, ROC AUC: 0.85
Epoch 4 	 Training Loss: 0.227049
Epoch 4 	 Validation Precision: 0.73, Recall: 0.85, F1: 0.79, ROC AUC: 0.85
Epoch 5 	 Training Loss: 0.141484
Epoch 5 	 Validation Precision: 0.74, Recall: 0.86, F1: 0.79, ROC AUC: 0.85


In [20]:
p, r, f, roc_auc = eval_model(naive_rnn, val_loader)

print("Precision:", p)
print("Recall:", r)
print("F1 Score:", f)
print("ROC AUC:", roc_auc)


Precision: 0.7355371900826446
Recall: 0.8640776699029126
F1 Score: 0.7946428571428571
ROC AUC: 0.8537683915523971
