# LSTM model for Sequential Medical Data & RNNs comparison

### Our Goal:

1.  **Build a sequential patient history**: Instead of one averaged vector per patient, we will create a *sequence of vectors* for each patient, where each vector represents a single hospital visit.
2.  **Use a Recurrent Neural Network (RNN)**: We will feed this 3D data `(patients, timesteps, features)` into an **LSTM** (Long Short-Term Memory) network.
3.  **Compare Performance**: We will show that this sequence-aware model outperforms the previous "bag-of-words" models on the 1-year mortality prediction task, and compare it with GRU and Bidirectional LSTM variants.

# 📂 **Data Access & Setup**

This project uses the **MIMIC-IV v3.1** dataset. Due to the sensitive nature of clinical data and regulation, the raw datasets are not included in this repository.

**1. Requesting Access**
To run this pipeline, you must have a signed Data Use Agreement (DUA):

* Training: Complete the [CITI Data or Specimens Researchers training](https://about.citiprogram.org/).
* PhysioNet: Create an account and request access via the [MIMIC-IV PhysioNet Page](https://physionet.org/content/mimiciv/3.1/).

**2. Local Setup**
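Once access is granted, download the following files and place them in a folder named `data/` in the root of this project. The loading cells read each file directly from `data_path`, without the `hosp/` or `icu/` subfolder.

* `hosp/patients.csv.gz`
* `hosp/admissions.csv.gz`
* `hosp/diagnoses_icd.csv.gz`
* `hosp/procedures_icd.csv.gz`
* `hosp/prescriptions.csv.gz` (loaded in Part 2 together with the diagnosis and procedure codes)
* `icu/icustays.csv.gz`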

**3. Running on Google Colab**
If using Google Colab, upload these files to your Google Drive and update the `data_path` variable at the top of the notebook to point to your Drive folder.
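Before running the notebook, it can help to confirm that every input file is where the loading cells expect it. The helper below is a minimal sketch, not part of the original pipeline: `REQUIRED_FILES` and `find_missing_files` are hypothetical names, and the file list assumes the files sit directly inside `data_path`, as the `pd.read_csv` calls later in the notebook suggest.

```python
import os

# File names read by the notebook's loading cells (assumed to sit directly
# inside data_path, without the hosp/ or icu/ subfolders).
REQUIRED_FILES = [
    "patients.csv.gz",
    "admissions.csv.gz",
    "diagnoses_icd.csv.gz",
    "procedures_icd.csv.gz",
    "prescriptions.csv.gz",
    "icustays.csv.gz",
]


def find_missing_files(data_path, required=REQUIRED_FILES):
    """Return the required file names that are not present in data_path."""
    return [
        name for name in required
        if not os.path.isfile(os.path.join(data_path, name))
    ]


# Hypothetical usage: "data" stands in for your own data_path.
missing = find_missing_files("data")
if missing:
    print("Missing files:", ", ".join(missing))
else:
    print("All required files found.")
```

On Colab, the same check works once `data_path` points at the mounted Drive folder.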


- - -


We use the **Med2Vec model built in the previous notebook**, 'Word2Vec for Clinical Codes Embedding & Logistic Regression 1-year mortality prediction'.

## Setup: Import Libraries

We will need `tensorflow` and `keras` to build our LSTM. You will also need `gensim` to load the model from the previous notebook (Word2Vec for Clinical Codes Embedding & Logistic Regression 1-year mortality prediction).

In [None]:
%pip install scipy tensorflow gensim scikit-learn pandas numpy

In [None]:
import pandas as pd
import numpy as np
import os
from tqdm import tqdm
tqdm.pandas()
import time
# To load our pre-trained embeddings
from gensim.models import Word2Vec

# For splitting and class weights
from sklearn.model_selection import train_test_split
from sklearn.utils import class_weight
from sklearn.metrics import confusion_matrix, classification_report

# For building the RNN
import tensorflow as tf
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Input, LSTM, GRU, Dense, Dropout, Masking, Bidirectional
from tensorflow.keras.preprocessing.sequence import pad_sequences
from tensorflow.keras.metrics import AUC
import matplotlib.pyplot as plt
import seaborn as sns

# Data
import pickle

## Part 1: Load our Med2vec Model and Define Cohort

First, we load the Med2Vec model (`med2vec_mimic4.w2v`) we trained in the previous notebook (Word2Vec for Clinical Codes Embedding & Logistic Regression 1-year mortality prediction). We will use its word vectors (`med2vec_model.wv`) as our embedding lookup.

Second, we will re-build our 1-year mortality cohort, exactly as we did in the previous notebooks.
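To make the label definition concrete, here is a small sketch on toy data. The four patients and their dates are invented for illustration and are not MIMIC records; the two label lines mirror the logic used in the cohort cell (death more than 0 and at most 365 days after the index date).

```python
import pandas as pd

# Toy patients (hypothetical values, not MIMIC data).
toy = pd.DataFrame({
    "subject_id": [1, 2, 3, 4],
    # Index date = ICU discharge time
    "outtime": pd.to_datetime(["2150-01-01"] * 4),
    # Date of death; None means no recorded death
    "dod": pd.to_datetime(["2150-03-01", "2152-01-01", None, "2150-01-01"]),
})

# Days between the index date and death (NaN when there is no recorded death)
toy["time_to_death"] = (toy["dod"] - toy["outtime"]).dt.days

# Positive label only if death occurs within (0, 365] days of the index date
toy["label_1yr_mortality"] = (toy["time_to_death"] <= 365) & (toy["time_to_death"] > 0)

print(toy[["subject_id", "time_to_death", "label_1yr_mortality"]])
```

Only the first toy patient is labeled positive: the second dies after more than a year, the third has no recorded death (comparisons with NaN are False), and the fourth has a time to death of 0 days, which the `> 0` condition excludes.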

In [None]:
data_path = '' # write your data path

EMBEDDING_MODEL_PATH = "med2vec_mimic4.w2v" # From previous notebook

# Load Pre-trained Embeddings
try:
    med2vec_model = Word2Vec.load(EMBEDDING_MODEL_PATH)
    embedding_dim = med2vec_model.vector_size
    model_vocab = set(med2vec_model.wv.key_to_index.keys())
    print(f"Loaded Med2Vec model with {len(model_vocab)} codes and {embedding_dim}-dim vectors.")
except FileNotFoundError:
    print(f"Error: '{EMBEDDING_MODEL_PATH}' not found. Please run the previous lab.")
    # Create dummy vars to allow notebook to be read
    med2vec_model = None
    embedding_dim = 100
    model_vocab = set()

# Load and Define Cohort
print("Rebuilding cohort from previous lab...")
patients = pd.read_csv(os.path.join(data_path,'patients.csv.gz'))
admissions = pd.read_csv(os.path.join(data_path,'admissions.csv.gz'))
icustays = pd.read_csv(os.path.join(data_path,'icustays.csv.gz'))

patients['dod'] = pd.to_datetime(patients['dod'])
admissions['admittime'] = pd.to_datetime(admissions['admittime'])
admissions['dischtime'] = pd.to_datetime(admissions['dischtime'])
icustays['intime'] = pd.to_datetime(icustays['intime'])
icustays['outtime'] = pd.to_datetime(icustays['outtime'])

admissions = admissions.merge(patients[['subject_id', 'anchor_age', 'anchor_year', 'dod']], on='subject_id')
admissions['age_at_admission'] = admissions['anchor_age'] + (admissions['admittime'].dt.year - admissions['anchor_year'])
adult_admissions = admissions[admissions['age_at_admission'] >= 18]
adult_icu = icustays.merge(adult_admissions, on=['subject_id', 'hadm_id'])
adult_icu = adult_icu.sort_values(by='intime')
first_icu_stays = adult_icu.groupby('subject_id').first().reset_index()
cohort = first_icu_stays[first_icu_stays['deathtime'].isnull()].copy()
cohort["index_date"] = cohort['outtime']

cohort["time_to_death"] = (cohort['dod'] - cohort['index_date']).dt.days
cohort['label_1yr_mortality'] = (cohort["time_to_death"] <= 365) & (cohort["time_to_death"] > 0)

# This is our final cohort: subject_id and their label
final_cohort = cohort[['subject_id', 'label_1yr_mortality']].set_index('subject_id')

print(f"Cohort of {len(final_cohort)} patients recreated.")

## Part 2: Building Sequential Data

This is the most important new step. We will create patient histories as a **sequence of visits**.

1.  Load all codes (diagnoses, procedures, prescriptions) for *all* admissions.
2.  Create a "visit vector" for *every* admission (`hadm_id`) by averaging the embeddings of all codes in that visit (just like we did for the *patient* vector in the previous lab).
3.  Find all hospital admissions for the patients in our cohort, *sorted by time*.
4.  Group these ordered visits by patient to create our final sequences (a list of visit vectors for each patient).
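The four steps above can be sketched on toy data as follows. The 2-dimensional embeddings, code names, and admission tables are invented for illustration, and `average_vector` is a simplified stand-in for the notebook's `get_avg_vector`.

```python
import numpy as np
import pandas as pd

# Hypothetical 2-dimensional embeddings for three codes.
toy_embeddings = {
    "I10": np.array([1.0, 0.0]),
    "E119": np.array([0.0, 1.0]),
    "aspirin": np.array([1.0, 1.0]),
}

# Step 1: all codes per admission (toy values).
codes = pd.DataFrame({
    "hadm_id": [10, 10, 11, 12, 12],
    "feature": ["I10", "E119", "aspirin", "I10", "UNKNOWN"],
})

# Admissions listed out of chronological order on purpose.
visits = pd.DataFrame({
    "subject_id": [1, 1, 2],
    "hadm_id": [11, 10, 12],
    "admittime": pd.to_datetime(["2150-05-01", "2150-01-01", "2150-02-01"]),
})


def average_vector(code_list, embeddings, dim):
    """Average the embeddings of known codes; zero vector if none are known."""
    known = [embeddings[c] for c in code_list if c in embeddings]
    return np.mean(known, axis=0) if known else np.zeros(dim)


# Step 2: one averaged vector per admission.
visit_vecs = {
    hadm_id: average_vector(group.tolist(), toy_embeddings, dim=2)
    for hadm_id, group in codes.groupby("hadm_id")["feature"]
}

# Steps 3 and 4: sort admissions by time, then group them per patient.
ordered = (
    visits.sort_values(["subject_id", "admittime"])
    .groupby("subject_id")["hadm_id"]
    .apply(list)
)
sequences = {sid: [visit_vecs[h] for h in hadm_ids] for sid, hadm_ids in ordered.items()}

print(ordered.to_dict())
print(sequences[1])
```

Patient 1 ends up with two visit vectors in chronological order (admission 10 before admission 11), and the unknown code in admission 12 is simply ignored in the average.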

In [None]:
# Load all features
print("Loading all features...")
diagnoses = pd.read_csv(os.path.join(data_path,'diagnoses_icd.csv.gz'))
procedures = pd.read_csv(os.path.join(data_path,'procedures_icd.csv.gz'))
prescriptions = pd.read_csv(os.path.join(data_path,'prescriptions.csv.gz'))

diag_features = diagnoses[['hadm_id', 'icd_code']].rename(columns={'icd_code':'feature'})
proc_features = procedures[['hadm_id', 'icd_code']].rename(columns={'icd_code':'feature'})
presc_features = prescriptions[['hadm_id', 'drug']].rename(columns={'drug':'feature'})

all_features = pd.concat([diag_features, proc_features, presc_features]).dropna()

# Create Visit Vectors for ALL hadm_ids
print("Building visit-level corpus...")
# Group all codes by hadm_id
visit_corpus = all_features.groupby('hadm_id')['feature'].apply(list)

# Function to average embeddings for a list of codes
def get_avg_vector(codes, model_wv, vocab, dim):
    vec = np.zeros(dim)
    count = 0
    if isinstance(codes, list):
        for code in codes:
            if code in vocab:
                vec += model_wv.wv[code]
                count += 1
    if count > 0:
        vec /= count
    return vec

print("Calculating average embedding for every visit...")
visit_vectors = visit_corpus.progress_apply(
    get_avg_vector,
    args=(med2vec_model, model_vocab, embedding_dim)
)

# We now have a mapping: {hadm_id: [0.1, -0.4, ...], ...}
visit_vector_map = dict(zip(visit_vectors.index, visit_vectors.values))
print("Visit vector map created.")

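# Optional: reload a previously cached map (written by the next cell) instead of
# the one computed above. Skipped on the first run, when the file does not exist yet.
if os.path.exists('visit_vector_map.pkl'):
    with open('visit_vector_map.pkl', 'rb') as f:
        visit_vector_map = pickle.load(f)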


In [None]:
# Save the dictionary to a file
with open('visit_vector_map.pkl', 'wb') as f:
    pickle.dump(visit_vector_map, f)

print("✅ Visit vectors saved successfully!")

In [None]:
# Get Ordered Visit Sequences for Cohort Patients
print("Building ordered patient sequences...")
# Get all admissions for the patients in our cohort
cohort_subjects = final_cohort.index.unique()
all_cohort_admissions = admissions[
    admissions['subject_id'].isin(cohort_subjects)
][['subject_id', 'hadm_id', 'admittime']]

In [None]:
# Sort by patient and then by admission time
all_cohort_admissions = all_cohort_admissions.sort_values(by=['subject_id', 'admittime'])

# Group by subject_id to get the list of ordered hadm_ids
ordered_hadm_sequences = all_cohort_admissions.groupby('subject_id')['hadm_id'].apply(list)

# Create Final (X, y) Data
print("Mapping visit vectors to patient sequences...")
X_sequences = []
y_labels = []
default_vector = np.zeros(embedding_dim)

# Iterate through our ordered patient sequences
for subject_id, hadm_ids in tqdm(ordered_hadm_sequences.items(), desc="Mapping sequences"):
    patient_sequence = []
    for hadm_id in hadm_ids:
        # Get the pre-calculated vector for that visit
        # Use a default (zero) vector if hadm_id had no valid codes
        visit_vec = visit_vector_map.get(hadm_id, default_vector)
        patient_sequence.append(visit_vec)

    # Only add if the patient is in our final (labeled) cohort
    if subject_id in final_cohort.index:
        X_sequences.append(patient_sequence)
        y_labels.append(final_cohort.loc[subject_id]['label_1yr_mortality'])

# 'X_sequences' is a list of lists of vectors. 'y_labels' is a list of booleans, cast to 0/1 below.
y = np.array(y_labels).astype(int)

In [None]:
# Calculate sequence lengths
sequence_lengths = [len(seq) for seq in X_sequences]

# Statistics
min_visits = min(sequence_lengths)
max_visits = max(sequence_lengths)
mean_visits = np.mean(sequence_lengths)
median_visits = np.median(sequence_lengths)

print(f"Minimum number of visits: {min_visits}")
print(f"Maximum number of visits: {max_visits}")
print(f"Mean: {mean_visits:.2f}")
print(f"Median: {median_visits:.1f}")

# Distribution
import matplotlib.pyplot as plt
plt.figure(figsize=(10, 4))
plt.hist(sequence_lengths, bins=50, color='steelblue', edgecolor='black')
plt.axvline(mean_visits, color='red', linestyle='--', label=f'Mean: {mean_visits:.1f}')
plt.xlabel('Number of Visits')
plt.ylabel('Number of Patients')
plt.title('Distribution of Visit Sequences per Patient')
plt.legend()
plt.grid(alpha=0.3)
plt.show()

## Part 3: Padding the Sequences

As we see above, our patient sequences have different lengths. To train an RNN in a batch, Keras requires all sequences in a batch to have the **same length**.

We use `pad_sequences` to fix this.
* It will find the longest sequence in the dataset (e.g., 50 visits).
* It will pad all shorter sequences with zero-vectors at the **beginning** (this is `padding='pre'`).

Our final data `X` will be a 3D NumPy array of shape:
**(num_patients, max_sequence_length, num_features)**
e.g., `(45000, 50, 100)`
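To see what pre-padding does to the data, the sketch below reproduces the behaviour in plain NumPy. `pad_pre` is a hypothetical helper written for illustration; it is meant to mimic `pad_sequences` with `padding='pre'` and `truncating='pre'` on lists of visit vectors, not to replace it.

```python
import numpy as np


def pad_pre(sequences, dim, maxlen=None):
    """Left-pad (and left-truncate) lists of visit vectors to a common length."""
    if maxlen is None:
        maxlen = max(len(seq) for seq in sequences)
    out = np.zeros((len(sequences), maxlen, dim), dtype="float32")
    for i, seq in enumerate(sequences):
        if len(seq) == 0:
            continue  # patient with no visits stays all zeros
        kept = np.asarray(seq[-maxlen:], dtype="float32")  # keep the most recent visits
        out[i, -len(kept):] = kept  # real visits sit at the end, zeros in front
    return out


# Two toy patients: one visit and three visits, 2 features per visit.
toy_sequences = [
    [[1.0, 1.0]],
    [[2.0, 2.0], [3.0, 3.0], [4.0, 4.0]],
]

padded = pad_pre(toy_sequences, dim=2)
truncated = pad_pre(toy_sequences, dim=2, maxlen=2)

print(padded.shape)
print(padded[0])
print(truncated[1])
```

With pre-padding, the most recent visit is always the last timestep, so the recurrent layer reads it immediately before producing its output. With `maxlen=2`, the oldest visit of the second patient is dropped.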

In [None]:
X_padded = pad_sequences(
    X_sequences,
    padding='pre',  # Pad at the beginning of the sequence
    truncating='pre', # Truncate from the beginning if too long
    dtype='float32' # Keras needs a consistent float type
)

print(f"Shape of X after padding: {X_padded.shape}")

X = X_padded
# 'y' is already our aligned NumPy array of labels

## Part 4: Building the LSTM Model

Now we build our sequence model. We will use a `Sequential` Keras model.
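The model uses a `Masking(mask_value=0.0)` layer so that padded timesteps are skipped. As a rough NumPy illustration of the rule it applies (a timestep is kept only if at least one feature differs from `mask_value`), consider the sketch below. `timestep_mask` is a hypothetical helper written for this example, not a Keras function.

```python
import numpy as np


def timestep_mask(batch, mask_value=0.0):
    """True where a timestep has at least one feature different from mask_value."""
    return np.any(batch != mask_value, axis=-1)


# Toy batch: 2 patients, 3 timesteps, 2 features.
batch = np.array([
    [[0.0, 0.0], [0.0, 0.0], [0.3, -0.1]],  # two padded steps, then one real visit
    [[0.5, 0.0], [0.0, 0.0], [0.2, 0.4]],   # middle visit is an all-zero vector
], dtype="float32")

mask = timestep_mask(batch)
print(mask)
```

One consequence worth keeping in mind: a real visit whose vector is all zeros (for example, an admission with no codes in the embedding vocabulary, which receives the default zero vector) is indistinguishable from padding and is skipped as well.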

In [None]:
# Get the shape of our input
# Timesteps are left as None so the model accepts any padded length
# (X is re-padded to a shorter maxlen before training in Part 5).
# X.shape[2] = n_features (e.g., 100)
input_shape = (None, X.shape[2])

model = Sequential()

# 1. Input Layer: Tell the model the shape of the data
model.add(Input(shape=input_shape))

# 2. Masking Layer: Ignore padding (0.0 values)
# This is critical for padded sequences!
model.add(Masking(mask_value=0.0))

# 3. LSTM Layer: The 'memory' of our network
model.add(LSTM(64)) # 64-unit memory cell

# 4. Dropout Layer: For regularization
model.add(Dropout(0.3))

# 5. Output Layer: Final prediction
model.add(Dense(1, activation='sigmoid')) # Sigmoid for binary classification

# Compile the model
model.compile(
    optimizer='adam',
    loss='binary_crossentropy',
    metrics=[AUC(name='auroc'), AUC(curve='PR', name='auprc')]  # AUROC and AUPRC
)

model.summary()

## Part 5: Training & Evaluating the Model

Now we're ready to train. We will split our 3D data `X` and labels `y` into training and testing sets.

We also need to handle the **class imbalance** (from Lab 3, we know far more patients survive). We'll use `class_weight` to tell the model to pay more attention to the rare 'Mortality' class.
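For intuition, the `'balanced'` heuristic in scikit-learn weights each class by `n_samples / (n_classes * count_of_class)`. The sketch below recomputes it with NumPy on toy labels; `balanced_class_weights` is a hypothetical helper, and it assumes integer labels `0..k-1` that all appear at least once.

```python
import numpy as np


def balanced_class_weights(y):
    """Weight each class by n_samples / (n_classes * class_count)."""
    counts = np.bincount(y)  # samples per class, indexed by label
    weights = len(y) / (len(counts) * counts)
    return {cls: float(w) for cls, w in enumerate(weights)}


# Toy labels: 9 survivors (0) and 1 death (1).
y_toy = np.array([0] * 9 + [1])
toy_weights = balanced_class_weights(y_toy)
print(toy_weights)
```

In this toy case the rare class gets a weight of 5.0 and the common class a weight of about 0.56, so each misclassified death contributes roughly nine times more to the loss than a misclassified survivor.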

In [None]:
# Limit sequence length for computational reasons
# Most patients have < 20 visits, so truncate at 20
X_padded = pad_sequences(
    X_sequences,
    maxlen=20,  # keep only the 20 most recent visits
    padding='pre',
    truncating='pre',
    dtype='float32'
)

X = X_padded


In [None]:
# 1. Split the data
patient_ids = np.array(ordered_hadm_sequences.index)

X_train, X_test, y_train, y_test, id_train, id_test = train_test_split(
    X, y, patient_ids,
    test_size=0.2,
    random_state=42,
    stratify=y
)
print(f"Train data shape: {X_train.shape}")
print(f"Test data shape: {X_test.shape}")

# 2. Calculate class weights to handle imbalance
weights = class_weight.compute_class_weight(
    'balanced',
    classes=np.unique(y_train),
    y=y_train
)
class_weights = {0: weights[0], 1: weights[1]}
print(f"Class Weights: {class_weights}")

# 3. Train the model
print("\nTraining LSTM model...")
history = model.fit(
    X_train,
    y_train,
    validation_split=0.2, # Use 20% of training data for validation
    epochs=10,            # Start with 10 passes over the data
    batch_size=64,
    class_weight=class_weights,
    verbose=1
)

print("Model training complete.")

In [None]:
# 4. Plot training history

print("Plotting training history...")
plt.figure(figsize=(12, 5))
plt.subplot(1, 3, 1)
plt.plot(history.history['loss'], label='Train Loss')
plt.plot(history.history['val_loss'], label='Validation Loss')
plt.title('Model Loss')
plt.xlabel('Epoch')
plt.ylabel('Loss')
plt.legend()

plt.subplot(1, 3, 2)
plt.plot(history.history['auroc'], label='Train AUROC')
plt.plot(history.history['val_auroc'], label='Validation AUROC')
plt.title('Model AUROC')
plt.xlabel('Epoch')
plt.ylabel('AUROC')
plt.legend()

plt.subplot(1, 3, 3)
plt.plot(history.history['auprc'], label='Train AUPRC')
plt.plot(history.history['val_auprc'], label='Validation AUPRC')
plt.title('Model AUPRC')
plt.xlabel('Epoch')
plt.ylabel('AUPRC')
plt.legend()
plt.tight_layout()
plt.show()

# 5. Evaluate on the held-out test set
print("\nEvaluating on test set...")
results = model.evaluate(X_test, y_test)
print(f"Test Loss: {results[0]:.4f}")
print(f"Test AUROC: {results[1]:.4f}")
print(f"Test AUPRC: {results[2]:.4f}")

In [None]:
# COMPARE MODEL with GRU and Bidirectional LSTM

# Store results
results_comparison = []

# ============================================================
# MODEL 2: GRU
# ============================================================
print("=" * 60)
print("MODEL 2: GRU")
print("=" * 60)

model_gru = Sequential([
    Input(shape=(X.shape[1], X.shape[2])),
    Masking(mask_value=0.0),
    GRU(64),
    Dropout(0.3),
    Dense(1, activation='sigmoid')
])

model_gru.compile(
    optimizer='adam',
    loss='binary_crossentropy',
    metrics=[AUC(name='auroc'), AUC(curve='PR', name='auprc')]
)

print(f"Parameters: {model_gru.count_params():,}")

start_time = time.time()
history_gru = model_gru.fit(
    X_train, y_train,
    validation_split=0.2,
    epochs=10,
    batch_size=32,
    class_weight=class_weights,
    verbose=0
)
train_time_gru = time.time() - start_time

test_results_gru = model_gru.evaluate(X_test, y_test, verbose=0)

results_comparison.append({
    'Model': 'GRU',
    'Parameters': model_gru.count_params(),
    'Train Time (s)': train_time_gru,
    'Test Loss': test_results_gru[0],
    'Test AUROC': test_results_gru[1],
    'Test AUPRC': test_results_gru[2]
})

print(f"Training time: {train_time_gru:.1f}s")
print(f"Test AUROC: {test_results_gru[1]:.4f}")
print(f"Test AUPRC: {test_results_gru[2]:.4f}\n")

# ============================================================
# MODEL 3: Bidirectional LSTM
# ============================================================
print("=" * 60)
print("MODEL 3: Bidirectional LSTM")
print("=" * 60)

model_bilstm = Sequential([
    Input(shape=(X.shape[1], X.shape[2])),
    Masking(mask_value=0.0),
    Bidirectional(LSTM(64)),
    Dropout(0.3),
    Dense(1, activation='sigmoid')
])

model_bilstm.compile(
    optimizer='adam',
    loss='binary_crossentropy',
    metrics=[AUC(name='auroc'), AUC(curve='PR', name='auprc')]
)

print(f"Parameters: {model_bilstm.count_params():,}")

start_time = time.time()
history_bilstm = model_bilstm.fit(
    X_train, y_train,
    validation_split=0.2,
    epochs=10,
    batch_size=32,
    class_weight=class_weights,
    verbose=0
)
train_time_bilstm = time.time() - start_time

test_results_bilstm = model_bilstm.evaluate(X_test, y_test, verbose=0)

results_comparison.append({
    'Model': 'Bidirectional LSTM',
    'Parameters': model_bilstm.count_params(),
    'Train Time (s)': train_time_bilstm,
    'Test Loss': test_results_bilstm[0],
    'Test AUROC': test_results_bilstm[1],
    'Test AUPRC': test_results_bilstm[2]
})

print(f"Training time: {train_time_bilstm:.1f}s")
print(f"Test AUROC: {test_results_bilstm[1]:.4f}")
print(f"Test AUPRC: {test_results_bilstm[2]:.4f}\n")



In [None]:
# Get predictions on test set
y_pred_probs = model.predict(X_test).flatten()

# Convert probabilities to binary predictions (threshold = 0.5)
y_pred = (y_pred_probs >= 0.5).astype(int)

# Calculate confusion matrix
cm = confusion_matrix(y_test, y_pred)

# Extract values
TN, FP, FN, TP = cm.ravel()

print("=" * 60)
print("CONFUSION MATRIX ANALYSIS")
print("=" * 60)
print(f"\nTrue Negatives (TN):  {TN:>6} - Correctly predicted survivors")
print(f"False Positives (FP): {FP:>6} - Predicted death, actually survived")
print(f"False Negatives (FN): {FN:>6} - Predicted survival, actually died")
print(f"True Positives (TP):  {TP:>6} - Correctly predicted deaths")

print(f"\n{'─' * 60}")
print(f"Total False Positives:  {FP}")
print(f"Total False Negatives:  {FN}")
print(f"{'─' * 60}")

# Calculate metrics
accuracy = (TP + TN) / (TP + TN + FP + FN)
precision = TP / (TP + FP) if (TP + FP) > 0 else 0
recall = TP / (TP + FN) if (TP + FN) > 0 else 0
f1 = 2 * (precision * recall) / (precision + recall) if (precision + recall) > 0 else 0

print(f"\nPERFORMANCE METRICS:")
print(f"  Accuracy:  {accuracy:.4f}")
print(f"  Precision: {precision:.4f} (of predicted deaths, {precision*100:.1f}% were correct)")
print(f"  Recall:    {recall:.4f} (detected {recall*100:.1f}% of actual deaths)")
print(f"  F1-Score:  {f1:.4f}")

# Visualize confusion matrix
fig, axes = plt.subplots(1, 2, figsize=(14, 5))

# Heatmap with counts
sns.heatmap(cm, annot=True, fmt='d', cmap='Blues',
            xticklabels=['Survived', 'Died'],
            yticklabels=['Survived', 'Died'],
            ax=axes[0], cbar=False)
axes[0].set_xlabel('Predicted Label', fontsize=12)
axes[0].set_ylabel('True Label', fontsize=12)
axes[0].set_title('Confusion Matrix (Counts)', fontsize=14, fontweight='bold')

# Normalized heatmap
cm_normalized = cm.astype('float') / cm.sum(axis=1)[:, np.newaxis]
sns.heatmap(cm_normalized, annot=True, fmt='.2%', cmap='Blues',
            xticklabels=['Survived', 'Died'],
            yticklabels=['Survived', 'Died'],
            ax=axes[1], cbar=False)
axes[1].set_xlabel('Predicted Label', fontsize=12)
axes[1].set_ylabel('True Label', fontsize=12)
axes[1].set_title('Confusion Matrix (Normalized)', fontsize=14, fontweight='bold')

plt.tight_layout()
plt.show()

# Detailed classification report
print("\n" + "=" * 60)
print("DETAILED CLASSIFICATION REPORT")
print("=" * 60)
print(classification_report(y_test, y_pred,
                          target_names=['Survived (0)', 'Died (1)'],
                          digits=4))

# Clinical interpretation
print("=" * 60)
print("CLINICAL INTERPRETATION")
print("=" * 60)
print(f"\nFalse Positives ({FP} patients):")
print("  → Predicted to die but survived")
print("  → Clinical impact: Unnecessary interventions, patient anxiety")
print(f"\nFalse Negatives ({FN} patients):")
print("  → Predicted to survive but died")
print("  → Clinical impact: CRITICAL - Missed high-risk patients!")
print("\n" + "=" * 60)