# Advanced Temporal Model Proof of Concept

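This notebook demonstrates a proof-of-concept implementation of a temporal model (an LSTM with attention) for predicting hospital readmissions from time-series EHR data based on the MIMIC dataset. Because the processed data are aggregated per admission, the sequences used here are simulated from those aggregated features. A full implementation would replace them with real time-stamped measurements.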

Unlike traditional ML models that treat features as static, this approach explicitly models the temporal dynamics of patient data, potentially capturing important patterns that develop over time.
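Here is a small illustration of what static features can miss. It uses two made-up heart-rate trajectories with opposite trends. Order-insensitive aggregates such as the mean, minimum, and maximum are identical for both. A simple least-squares slope separates them. A sequence model can, in principle, learn such directional patterns directly from the time steps.

```python
import numpy as np

t = np.arange(24)                      # 24 hourly time steps
improving = np.linspace(120, 90, 24)   # heart rate falling over the stay (made-up values)
worsening = np.linspace(90, 120, 24)   # heart rate rising over the stay (made-up values)

# Order-insensitive aggregates cannot tell the two trajectories apart
for name, agg in [("mean", np.mean), ("min", np.min), ("max", np.max)]:
    print(f"{name}: improving={agg(improving):.1f}, worsening={agg(worsening):.1f}")

# A least-squares slope (one simple temporal feature) separates them
slope_improving = np.polyfit(t, improving, 1)[0]
slope_worsening = np.polyfit(t, worsening, 1)[0]
print(f"slope: improving={slope_improving:.2f}, worsening={slope_worsening:.2f}")
```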

In [None]:
import os
import sys
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
from sklearn.preprocessing import StandardScaler
from sklearn.metrics import roc_auc_score, average_precision_score, precision_recall_curve

# Add project root to path for imports
sys.path.append(os.path.abspath(os.path.join(os.getcwd(), '..')))

# Import project utilities
from src.utils import get_logger, load_config, get_data_path

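# Configure matplotlib. The 'seaborn-whitegrid' style name was deprecated in
# matplotlib 3.6 (renamed to 'seaborn-v0_8-whitegrid') and later removed.
plt.style.use('seaborn-v0_8-whitegrid' if 'seaborn-v0_8-whitegrid' in plt.style.available
              else 'seaborn-whitegrid')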
plt.rcParams['figure.figsize'] = (12, 8)
plt.rcParams['font.size'] = 12

## 1. Load and Prepare Time-Series Data

For this POC, we'll restructure our data to preserve the temporal nature of patient measurements. Instead of aggregating features over the entire stay, we'll organize them into sequences.
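The sketch below shows what this restructuring could look like with real measurements. It assumes a hypothetical long-format table `events` with columns `hadm_id`, `hours_since_admit`, `feature`, and `value`. These names are illustrative and are not the MIMIC schema. Measurements are binned into hourly steps, and repeated values within an hour are averaged. The last observation is then carried forward.

```python
import numpy as np
import pandas as pd

def events_to_sequences(events, features, seq_length=24):
    """Bin long-format measurements into fixed-length hourly sequences.

    Assumes a hypothetical `events` DataFrame with columns
    hadm_id, hours_since_admit, feature, value.
    Returns a dict mapping hadm_id -> array of shape (seq_length, len(features)).
    """
    # Keep measurements from the first `seq_length` hours of each admission
    in_window = events["hours_since_admit"].between(0, seq_length, inclusive="left")
    ev = events[in_window].copy()
    ev["step"] = ev["hours_since_admit"].astype(int)  # hour bin: 0, 1, ..., seq_length-1

    sequences = {}
    for hadm_id, grp in ev.groupby("hadm_id"):
        # Average repeated measurements within the same hour
        wide = grp.pivot_table(index="step", columns="feature", values="value", aggfunc="mean")
        # Fix the time grid and the column order; unobserved cells become NaN
        wide = wide.reindex(index=range(seq_length), columns=features)
        # Carry the last observation forward; steps before the first observation stay NaN
        sequences[hadm_id] = wide.ffill().to_numpy(dtype=float)
    return sequences
```

Admissions with no measurements inside the window do not appear in the result. Any leading NaNs would still need an imputation rule before training.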

In [None]:
# Load configuration
config = load_config()
logger = get_logger('temporal_model_poc')

# Load processed data
data_path = get_data_path("processed", "combined_features", config)
data = pd.read_csv(data_path)

# Display basic info
print(f"Loaded data with {data.shape[0]} rows and {data.shape[1]} columns")
data.head()

In [None]:
# For this POC, we'll simulate time-series data if the actual data is already aggregated
# In a real implementation, you would extract the actual temporal measurements from the MIMIC database

def create_temporal_dataset(data, vital_features, lab_features, seq_length=24, seed=42):
    """
    Create a temporal dataset from the processed data.
    
    For this POC, we simulate temporal data by:
    1. Using the existing (aggregated) features as the final values
    2. Generating a synthetic trajectory leading up to each final value
    
    Args:
        data: DataFrame with processed features
        vital_features: List of vital sign names (matched as substrings of column names)
        lab_features: List of lab value names (matched as substrings of column names)
        seq_length: Number of time steps in each sequence
        seed: Random seed so the simulated sequences are reproducible
        
    Returns:
        X_temporal: Dictionary mapping hadm_id to an array of shape (seq_length, n_features)
        y: Series with readmission labels
    """
    rng = np.random.default_rng(seed)
    
    # Extract target and admission IDs
    y = data['readmission_30day'].copy()
    hadm_ids = data['hadm_id'].values
    
    # Select columns whose names contain any vital or lab feature name.
    # Note: the column order follows data.columns, not the order of the input lists.
    temporal_features = [f for f in data.columns
                         if any(vf in f for vf in vital_features)
                         or any(lf in f for lf in lab_features)]
    
    # Positional array of final values. Missing values are imputed with the column
    # median so that NaNs do not propagate into the LSTM and its loss.
    final_matrix = data[temporal_features].astype(float)
    final_matrix = final_matrix.fillna(final_matrix.median()).to_numpy()
    
    # Create dictionary to store sequences for each admission
    X_temporal = {}
    
    for i, hadm_id in enumerate(hadm_ids):
        # Row i by position (data.loc[i] would fail if the index is not 0..n-1)
        final_values = final_matrix[i]
        sequence = np.zeros((seq_length, len(temporal_features)))
        
        for j, final_val in enumerate(final_values):
            # Start at roughly 80% of the final value, plus a small random offset
            start_val = final_val * 0.8 + rng.normal(0, 0.1)
            
            # Linear trajectory from the start value to the final value
            trajectory = np.linspace(start_val, final_val, seq_length)
            
            # Add Gaussian noise with std equal to 5% of |final value|
            trajectory += rng.normal(0, abs(final_val) * 0.05, seq_length)
            
            sequence[:, j] = trajectory
        
        # Store sequence for this admission
        X_temporal[hadm_id] = sequence
    
    return X_temporal, y
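
The simulation above does three things. It draws a start value near 80% of the final value, interpolates linearly to the final value, and adds Gaussian noise with a standard deviation of 5% of the final value's magnitude. The same construction can be written with NumPy broadcasting over all admissions at once. `simulate_sequences` is a hypothetical helper. It is statistically equivalent to the loop but not bit-for-bit identical, because random numbers are drawn in a different order.

```python
import numpy as np

def simulate_sequences(final_values, seq_length=24, seed=0):
    """Vectorized linear-ramp-plus-noise simulation.

    final_values: array of shape (n_admissions, n_features)
    returns: array of shape (n_admissions, seq_length, n_features)
    """
    rng = np.random.default_rng(seed)
    final_values = np.asarray(final_values, dtype=float)

    # Start at ~80% of the final value plus a small offset, per admission and feature
    start = final_values * 0.8 + rng.normal(0.0, 0.1, size=final_values.shape)

    # Fraction of the way from start to final at each step: 0, ..., 1
    frac = np.linspace(0.0, 1.0, seq_length)[None, :, None]
    trajectory = start[:, None, :] + frac * (final_values - start)[:, None, :]

    # Noise std is 5% of |final value|, broadcast across time steps
    noise = rng.normal(size=trajectory.shape) * (np.abs(final_values)[:, None, :] * 0.05)
    return trajectory + noise
```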

In [None]:
# Define vital sign and lab value features
vital_features = ['heart_rate', 'sbp', 'dbp', 'mbp', 'resp_rate', 'temperature', 'spo2']
lab_features = ['wbc', 'hgb', 'platelet', 'sodium', 'potassium', 'bicarbonate', 'bun', 'creatinine', 'glucose']

# Create temporal dataset
X_temporal, y = create_temporal_dataset(data, vital_features, lab_features)

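# Sequence columns follow the order of data.columns (same matching rule as in
# create_temporal_dataset), so look up which columns are vital signs for labelling
temporal_features = [f for f in data.columns
                     if any(vf in f for vf in vital_features)
                     or any(lf in f for lf in lab_features)]
vital_idx = [i for i, f in enumerate(temporal_features)
             if any(vf in f for vf in vital_features)]

# Display an example sequence
example_id = list(X_temporal.keys())[0]
example_sequence = X_temporal[example_id]

plt.figure(figsize=(14, 8))
for i in vital_idx[:5]:  # Plot the first 5 vital-sign columns
    plt.plot(example_sequence[:, i], label=temporal_features[i])
plt.title(f"Example Temporal Sequence for Admission {example_id}")
plt.xlabel("Time Step")
plt.ylabel("Value")
plt.legend()
plt.show()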

## 2. Implement LSTM Model for Temporal Data

Now we'll implement an LSTM model that can process these temporal sequences and predict readmission risk.
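With `batch_first=True`, `nn.LSTM` consumes tensors shaped `(batch, time, features)`. It returns one hidden vector per time step, which is what the attention layer in the model pools over. The quick shape check below uses random inputs, and the sizes are arbitrary.

```python
import torch
import torch.nn as nn

batch_size, seq_len, n_features, hidden_dim = 4, 24, 16, 64
x = torch.randn(batch_size, seq_len, n_features)

lstm = nn.LSTM(n_features, hidden_dim, num_layers=2, batch_first=True)
out, (h_n, c_n) = lstm(x)

print(out.shape)  # torch.Size([4, 24, 64]): one hidden vector per time step (top layer)
print(h_n.shape)  # torch.Size([2, 4, 64]): final hidden state per layer (not batch-first)
```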

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from sklearn.model_selection import train_test_split

# Define a PyTorch dataset for our temporal data
class TemporalEHRDataset(Dataset):
    def __init__(self, sequences, labels, hadm_ids):
        self.sequences = sequences
        self.labels = labels
        self.hadm_ids = hadm_ids
        
    def __len__(self):
        return len(self.labels)
    
    def __getitem__(self, idx):
        hadm_id = self.hadm_ids[idx]
        sequence = self.sequences[hadm_id]
        label = self.labels[idx]
        return torch.FloatTensor(sequence), torch.FloatTensor([label])

# Define the LSTM model
class TemporalPatientLSTM(nn.Module):
    def __init__(self, input_dim, hidden_dim, num_layers=2, dropout=0.2):
        super().__init__()
        self.lstm = nn.LSTM(
            input_dim, 
            hidden_dim, 
            num_layers=num_layers, 
            batch_first=True,
            dropout=dropout
        )
        self.attention = nn.Linear(hidden_dim, 1)  # Simple attention mechanism
        self.classifier = nn.Linear(hidden_dim, 1)
        
    def forward(self, x):
        # Process sequence with LSTM
        lstm_out, _ = self.lstm(x)  # lstm_out shape: [batch_size, seq_len, hidden_dim]
        
        # Apply attention to focus on important time steps
        attention_weights = torch.softmax(self.attention(lstm_out), dim=1)
        context = torch.sum(attention_weights * lstm_out, dim=1)  # Weighted sum
        
        # Classify
        return torch.sigmoid(self.classifier(context))
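
The forward pass above corresponds to the following equations. Here $\mathbf{h}_t$ is the top-layer LSTM output at step $t$, $\mathbf{w}, b$ are the parameters of `self.attention`, and $\mathbf{v}, b_c$ are those of `self.classifier`:

$$
e_t = \mathbf{w}^\top \mathbf{h}_t + b, \qquad
\alpha_t = \frac{\exp(e_t)}{\sum_{s=1}^{T} \exp(e_s)}, \qquad
\mathbf{c} = \sum_{t=1}^{T} \alpha_t \, \mathbf{h}_t, \qquad
\hat{y} = \sigma\!\left(\mathbf{v}^\top \mathbf{c} + b_c\right)
$$

The softmax runs over the time dimension (`dim=1`), so the weights $\alpha_1, \dots, \alpha_T$ are non-negative and sum to 1 for each admission.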

In [None]:
# Prepare data for PyTorch
hadm_ids = list(X_temporal.keys())
labels = y.values

# Split data
train_ids, test_ids, train_labels, test_labels = train_test_split(
    hadm_ids, labels, test_size=0.2, random_state=42, stratify=labels
)

# Create datasets
train_dataset = TemporalEHRDataset(X_temporal, train_labels, train_ids)
test_dataset = TemporalEHRDataset(X_temporal, test_labels, test_ids)

# Create dataloaders
train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=32, shuffle=False)

# Initialize model
input_dim = X_temporal[hadm_ids[0]].shape[1]  # Number of features
hidden_dim = 64
model = TemporalPatientLSTM(input_dim, hidden_dim)

# Define loss function and optimizer
criterion = nn.BCELoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)

# Training loop
num_epochs = 10
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)

# Lists to store metrics
train_losses = []
test_losses = []
test_aucs = []

for epoch in range(num_epochs):
    # Training
    model.train()
    train_loss = 0.0
    for sequences, batch_labels in train_loader:
        # `batch_labels` (not `labels`) so the label array defined above is not overwritten
        sequences, batch_labels = sequences.to(device), batch_labels.to(device)
        
        # Forward pass
        outputs = model(sequences)
        loss = criterion(outputs, batch_labels)
        
        # Backward pass and optimize
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        
        train_loss += loss.item() * sequences.size(0)
    
    train_loss /= len(train_loader.dataset)
    train_losses.append(train_loss)
    
    # Evaluation (the test set doubles as a monitoring set here; use a separate
    # validation split if these curves drive model selection or early stopping)
    model.eval()
    test_loss = 0.0
    all_labels = []
    all_preds = []
    
    with torch.no_grad():
        for sequences, batch_labels in test_loader:
            sequences, batch_labels = sequences.to(device), batch_labels.to(device)
            
            # Forward pass
            outputs = model(sequences)
            loss = criterion(outputs, batch_labels)
            
            test_loss += loss.item() * sequences.size(0)
            
            # Store predictions and labels
            all_labels.extend(batch_labels.cpu().numpy())
            all_preds.extend(outputs.cpu().numpy())
    
    test_loss /= len(test_loader.dataset)
    test_losses.append(test_loss)
    
    # Calculate AUC
    all_labels = np.array(all_labels).flatten()
    all_preds = np.array(all_preds).flatten()
    test_auc = roc_auc_score(all_labels, all_preds)
    test_aucs.append(test_auc)
    
    print(f"Epoch {epoch+1}/{num_epochs}, "
          f"Train Loss: {train_loss:.4f}, "
          f"Test Loss: {test_loss:.4f}, "
          f"Test AUC: {test_auc:.4f}")
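
The loop feeds the sequences to the LSTM without rescaling, even though `StandardScaler` is imported at the top. If the processed features are not already standardized, features with large magnitudes can dominate the input. A minimal sketch follows. It assumes per-feature scaling across all time steps, fitted only on the training admissions so that test statistics do not leak into training. `standardize_sequences` is a hypothetical helper.

```python
import numpy as np
from sklearn.preprocessing import StandardScaler

def standardize_sequences(sequences, train_ids):
    """Standardize each feature using statistics from training admissions only.

    sequences: dict mapping hadm_id -> array of shape (seq_length, n_features)
    train_ids: iterable of hadm_ids in the training split
    """
    # Stack all time steps of all training sequences: (n_train * seq_length, n_features)
    train_stack = np.concatenate([sequences[hid] for hid in train_ids], axis=0)
    scaler = StandardScaler().fit(train_stack)
    # Apply the same per-feature mean/std to every admission (train and test)
    scaled = {hid: scaler.transform(seq) for hid, seq in sequences.items()}
    return scaled, scaler

# Usage with the notebook's variables:
# X_scaled, scaler = standardize_sequences(X_temporal, train_ids)
```

The resulting `X_scaled` would replace `X_temporal` in two places: when constructing `TemporalEHRDataset`, and when calling `get_attention_weights`.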

## 3. Visualize Results and Compare with Traditional Models

In [None]:
# Plot training curves
plt.figure(figsize=(12, 5))

plt.subplot(1, 2, 1)
plt.plot(range(1, num_epochs+1), train_losses, 'b-', label='Train Loss')
plt.plot(range(1, num_epochs+1), test_losses, 'r-', label='Test Loss')
plt.xlabel('Epoch')
plt.ylabel('Loss')
plt.title('Training and Test Loss')
plt.legend()

plt.subplot(1, 2, 2)
plt.plot(range(1, num_epochs+1), test_aucs, 'g-')
plt.xlabel('Epoch')
plt.ylabel('AUC')
plt.title('Test AUC')

plt.tight_layout()
plt.show()
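
The training loop reports only ROC-AUC. Because 30-day readmission is typically the minority class, average precision (already imported as `average_precision_score`) is a useful complement. A random ranking has an average precision of roughly the positive-class prevalence, whereas its ROC-AUC is 0.5. The small hypothetical helper below could be applied to `all_labels` and `all_preds` from the final epoch.

```python
import numpy as np
from sklearn.metrics import roc_auc_score, average_precision_score

def summarize_predictions(y_true, y_score):
    """Threshold-free metrics for a binary classifier."""
    y_true = np.asarray(y_true).ravel()
    y_score = np.asarray(y_score).ravel()
    return {
        "roc_auc": roc_auc_score(y_true, y_score),
        "average_precision": average_precision_score(y_true, y_score),
        # Roughly the average precision of a random ranking
        "prevalence": float(y_true.mean()),
    }

# e.g. summarize_predictions(all_labels, all_preds) after the training loop
```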

In [None]:
# Compare with traditional models
from src.models.model import ReadmissionModel

# Initialize and train traditional models
traditional_models = {
    "Logistic Regression": ReadmissionModel(),
    "Random Forest": ReadmissionModel(),
    "XGBoost": ReadmissionModel(),
    "LightGBM": ReadmissionModel()
}

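# Train each model with its corresponding algorithm.
# Note: these models use the aggregated features, and their reported AUC comes from
# ReadmissionModel's own evaluation, which may not use the same test split as the LSTM.
for name, trad_model in traditional_models.items():
    # e.g. "Random Forest" -> "random_forest"
    algorithm = name.lower().replace(" ", "_")
    # `trad_model` (not `model`) so the trained LSTM used in Section 4 is not overwritten
    metrics = trad_model.fit(data, algorithm=algorithm)
    print(f"{name} AUC: {metrics['roc_auc']:.4f}")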

## 4. Analyze Attention Weights for Interpretability

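One advantage of the temporal model is that its attention mechanism indicates which time steps contributed most to a prediction. Because the sequences in this POC are synthetic, the attention pattern below reflects the simulation rather than real clinical dynamics.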

In [None]:
# Function to extract attention weights
def get_attention_weights(model, sequence):
    model.eval()
    with torch.no_grad():
        sequence_tensor = torch.FloatTensor(sequence).unsqueeze(0).to(device)
        lstm_out, _ = model.lstm(sequence_tensor)
        attention_weights = torch.softmax(model.attention(lstm_out), dim=1)
    return attention_weights.cpu().numpy().squeeze()

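# Get a sample sequence from a readmitted patient
readmitted_ids = [hid for hid, label in zip(test_ids, test_labels) if label == 1]
if readmitted_ids:
    sample_id = readmitted_ids[0]
    sample_sequence = X_temporal[sample_id]
    
    # Get attention weights
    attention_weights = get_attention_weights(model, sample_sequence)
    
    # Sequence columns follow data.columns order (see create_temporal_dataset),
    # so look up which columns are vital signs before labelling the plot
    temporal_cols = [f for f in data.columns
                     if any(vf in f for vf in vital_features)
                     or any(lf in f for lf in lab_features)]
    vital_cols = [(i, f) for i, f in enumerate(temporal_cols)
                  if any(vf in f for vf in vital_features)]
    
    # Plot sequence with attention weights
    plt.figure(figsize=(14, 10))
    
    # Plot vital signs
    plt.subplot(2, 1, 1)
    for i, feature in vital_cols[:5]:  # Plot the first 5 vital-sign columns
        plt.plot(sample_sequence[:, i], label=feature)
    plt.title(f"Vital Signs for Readmitted Patient {sample_id}")
    plt.xlabel("Time Step")
    plt.ylabel("Value")
    plt.legend()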
    
    # Plot attention weights
    plt.subplot(2, 1, 2)
    plt.bar(range(len(attention_weights)), attention_weights)
    plt.title("Attention Weights")
    plt.xlabel("Time Step")
    plt.ylabel("Weight")
    
    plt.tight_layout()
    plt.show()
    
    # Identify the most important time steps
    top_indices = np.argsort(-attention_weights)[:5]
    print("Most important time steps:")
    for idx in top_indices:
        print(f"Time step {idx}: Weight {attention_weights[idx]:.4f}")
else:
    print("No readmitted patients in the test set")

## 5. Discussion and Conclusions

### Advantages of Temporal Modeling

1. **Capturing Temporal Patterns**: The LSTM model can identify patterns in how vital signs and lab values change over time, potentially detecting deterioration or improvement trends that static models would miss.

2. **Attention Mechanism**: The attention weights provide interpretability by highlighting which time points were most important for the prediction, which could help clinicians understand when critical changes occurred.

3. **Handling Variable-Length Sequences**: This POC does not use this capability, since every sequence has a fixed length of 24 time steps. LSTMs can, however, process variable-length sequences through padding and packing, which accommodates different hospital stay durations.
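
The sketch below shows how variable-length stays could be handled. It assumes sequences are zero-padded to a common length and that their true lengths are known. `pack_padded_sequence` keeps the LSTM from processing padding, and masking the attention scores with `-inf` gives padded steps zero weight. `MaskedAttentionLSTM` is a hypothetical variant of `TemporalPatientLSTM`, not part of the project code.

```python
import torch
import torch.nn as nn
from torch.nn.utils.rnn import pad_sequence, pack_padded_sequence, pad_packed_sequence

class MaskedAttentionLSTM(nn.Module):
    def __init__(self, input_dim, hidden_dim):
        super().__init__()
        self.lstm = nn.LSTM(input_dim, hidden_dim, batch_first=True)
        self.attention = nn.Linear(hidden_dim, 1)
        self.classifier = nn.Linear(hidden_dim, 1)

    def forward(self, x, lengths):
        # x: (B, T_max, F) zero-padded; lengths: (B,) true lengths, each >= 1
        packed = pack_padded_sequence(x, lengths.cpu(), batch_first=True, enforce_sorted=False)
        packed_out, _ = self.lstm(packed)
        # Unpack back to (B, T_max, H) in the original batch order
        lstm_out, _ = pad_packed_sequence(packed_out, batch_first=True, total_length=x.size(1))

        scores = self.attention(lstm_out).squeeze(-1)  # (B, T_max)
        steps = torch.arange(x.size(1), device=x.device)
        mask = steps[None, :] < lengths.to(x.device)[:, None]  # True for real steps
        scores = scores.masked_fill(~mask, float("-inf"))
        weights = torch.softmax(scores, dim=1)  # padded steps get exactly zero weight

        context = torch.bmm(weights.unsqueeze(1), lstm_out).squeeze(1)  # (B, H)
        return torch.sigmoid(self.classifier(context)), weights

# Usage: three stays of different lengths with 3 features each
seqs = [torch.randn(5, 3), torch.randn(2, 3), torch.randn(4, 3)]
lengths = torch.tensor([len(s) for s in seqs])
x = pad_sequence(seqs, batch_first=True)  # (3, 5, 3), zero-padded
probs, weights = MaskedAttentionLSTM(input_dim=3, hidden_dim=8)(x, lengths)
```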

### Limitations and Future Work

1. **Data Requirements**: Temporal models require more detailed data with timestamps, which may not always be available or may require additional preprocessing.

2. **Computational Complexity**: Training LSTM models is more computationally intensive than traditional ML models.

3. **Hyperparameter Tuning**: These models have more hyperparameters to tune, including LSTM layers, hidden dimensions, and attention mechanisms.

### Next Steps

1. **Real Temporal Data**: Replace the synthetic temporal data with actual time-series measurements from the MIMIC database.

2. **More Sophisticated Architectures**: Explore bidirectional LSTMs, transformer models, or temporal convolutional networks.

3. **Multi-modal Integration**: Combine temporal data with static features (demographics, comorbidities) for a more comprehensive model.

4. **Clinical Validation**: Work with clinicians to validate the patterns identified by the attention mechanism.
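
Next steps 2 and 3 can be combined in one sketch. The first piece is a bidirectional LSTM, whose per-step output concatenates forward and backward states, so the attention and classifier layers take `2 * hidden_dim` inputs. Its attention-pooled context is then concatenated with an encoding of static features such as demographics. `MultiModalBiLSTM`, `static_dim`, and the static encoder are hypothetical. A bidirectional encoder also assumes the prediction is made after the whole stay has been observed, for example at discharge.

```python
import torch
import torch.nn as nn

class MultiModalBiLSTM(nn.Module):
    """Hypothetical sketch: BiLSTM + attention over temporal features,
    concatenated with an encoding of static features."""

    def __init__(self, temporal_dim, static_dim, hidden_dim=64, static_hidden=16,
                 num_layers=2, dropout=0.2):
        super().__init__()
        self.lstm = nn.LSTM(temporal_dim, hidden_dim, num_layers=num_layers,
                            batch_first=True, dropout=dropout, bidirectional=True)
        # Forward and backward states are concatenated: 2 * hidden_dim per step
        self.attention = nn.Linear(2 * hidden_dim, 1)
        self.static_encoder = nn.Sequential(nn.Linear(static_dim, static_hidden), nn.ReLU())
        self.classifier = nn.Linear(2 * hidden_dim + static_hidden, 1)

    def forward(self, x_temporal, x_static):
        lstm_out, _ = self.lstm(x_temporal)                    # (B, T, 2*hidden_dim)
        weights = torch.softmax(self.attention(lstm_out), dim=1)
        context = torch.sum(weights * lstm_out, dim=1)         # (B, 2*hidden_dim)
        static = self.static_encoder(x_static)                 # (B, static_hidden)
        combined = torch.cat([context, static], dim=1)
        return torch.sigmoid(self.classifier(combined))

# Usage with random inputs: 4 admissions, 24 steps, 16 temporal and 5 static features
demo_model = MultiModalBiLSTM(temporal_dim=16, static_dim=5)
probs = demo_model(torch.randn(4, 24, 16), torch.randn(4, 5))
```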