# Enhanced Temporal Model: Time-Aware LSTM for Readmission Prediction

**Objective:** This notebook builds upon the initial LSTM proof-of-concept (`advanced_temporal_model_poc.ipynb`) by implementing and evaluating a **Time-Aware LSTM**. The key enhancement is the explicit modeling of irregular time intervals between clinical measurements, a crucial aspect of real-world EHR data often ignored by standard sequence models.

**Narrative:** While traditional ML models (Logistic Regression, LightGBM) provide a baseline, they treat patient data statically. Our first LSTM PoC introduced sequence modeling but didn't fully leverage the *timing* information. This enhanced model incorporates learned time embeddings, aiming to capture the significance of *when* events occur relative to each other. We will train this model and compare its performance (ROC AUC, PR AUC) against a strong baseline (LightGBM, trained on the *same data split*; no SMOTE or other resampling is applied in this notebook) to assess the potential value added by temporal awareness, even on the limited MIMIC-III Demo dataset.

**Methodology:**
1. Load processed data (`combined_features.csv`).
2. Prepare temporal data: Generate synthetic sequences *with explicit, irregular time intervals*.
3. Split data into training and testing sets.
4. Define the Time-Aware LSTM architecture using PyTorch (including `TimeEncoder`).
5. Implement a PyTorch `Dataset` and `DataLoader` with appropriate padding for sequences and intervals.
6. Train the Time-Aware LSTM model.
7. Train a baseline LightGBM model on the *same* train/test split for fair comparison (using static/aggregated features).
8. Evaluate both models using ROC AUC and Precision-Recall AUC.
9. Visualize the results: Training curves, ROC/PR comparison.
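10. Analyze attention weights in the context of time intervals (optional; skipped in this run because the model's forward pass does not return its attention weights).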
11. Discuss findings, limitations, and implications for the MLOps pipeline.

## 1. Imports and Setup

In [None]:
import os
import sys
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler
from sklearn.metrics import roc_auc_score, average_precision_score, precision_recall_curve, roc_curve, classification_report
import lightgbm as lgb # Baseline model
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from tqdm.notebook import tqdm # Use notebook version of tqdm

# Add project root to path for imports
project_root = os.path.abspath(os.path.join(os.getcwd(), '..'))
if project_root not in sys.path:
    sys.path.append(project_root)

# Import project utilities and model definitions
from src.utils import get_logger, load_config, get_data_path
from src.models.temporal_modeling import TimeEncoder, TimeAwarePatientLSTM, TemporalEHRDataset, get_attention_weights # Assuming these are defined here or imported

# --- Configuration ---
# Create the logger before the try block so the except branch can always use it
logger = get_logger('temporal_model_enhanced_nb') # Use specific logger name
try:
    config = load_config()
except FileNotFoundError:
    logger.error("Configuration file not found. Please ensure 'configs/config.yaml' exists.")
    # Provide default config or raise error if necessary
    config = {}

# Plotting setup
plt.style.use('seaborn-v0_8-whitegrid')
plt.rcParams['figure.figsize'] = (12, 6) # Adjusted default size
plt.rcParams['font.size'] = 12

# Define output directory for results
results_dir = os.path.join(os.getcwd(), 'results') # Save results within notebooks dir
os.makedirs(results_dir, exist_ok=True)

# Set random seed for reproducibility
RANDOM_STATE = 42
np.random.seed(RANDOM_STATE)
torch.manual_seed(RANDOM_STATE)

## 2. Load Data

In [None]:
data_path = None # Defined up front so the except branches can reference it safely
try:
    # Load processed data (static/aggregated features)
    data_path = get_data_path("processed", "combined_features", config)
    data = pd.read_csv(data_path)
    logger.info(f"Loaded data from {data_path}. Shape: {data.shape}")
    # Display basic info
    print(f"Loaded data with {data.shape[0]} rows and {data.shape[1]} columns")
    display(data.head())
except FileNotFoundError:
    logger.error(f"Processed data file not found at {data_path}. Cannot proceed.")
    # Handle error appropriately, e.g., raise Exception or exit
    data = pd.DataFrame() # Assign empty df to prevent further errors
except Exception as e:
    logger.error(f"Error loading data: {e}", exc_info=True)
    data = pd.DataFrame()

## 3. Data Preparation for Temporal Model

We need to restructure the data into sequences of measurements over time, including the time intervals *between* measurements. As the `combined_features.csv` contains aggregated data, we'll generate synthetic sequences for this demonstration, simulating irregular measurement intervals.
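
In symbols, the generator defined in the next cell can be summarized as follows, where $v$ is the aggregated value of one feature for one admission and $T$ is `seq_length`. The notation ($v$, $s$, $p_t$, $\Delta_t$, $\tau_t$) is introduced here for illustration only and does not appear in the code.

$$
\begin{aligned}
\Delta_0 &= 0, \qquad \Delta_t \sim \mathcal{U}(1, 8)\ \text{hours for } t = 1, \dots, T-1, \qquad \tau_t = \sum_{k=0}^{t} \Delta_k \\
s &= 0.8\,v + \eta, \qquad \eta \sim \mathcal{N}\!\left(0,\ (0.1\,|v| + 10^{-6})^2\right) \\
p_t &= \left(\frac{t}{T-1}\right)^{1.5}, \qquad x_t = s + (v - s)\,p_t + \epsilon_t, \qquad \epsilon_t \sim \mathcal{N}\!\left(0,\ (0.05\,|v| + 10^{-6})^2\right)
\end{aligned}
$$

The intervals $\Delta_t$ are sampled independently of both the feature values and the readmission label.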

In [None]:
def create_temporal_dataset_with_intervals(data, vital_features, lab_features, seq_length=24):
    """
    Creates a temporal dataset with explicit time intervals.
    Generates synthetic sequences ending at the aggregated values from 'data'.
    Simulates irregular intervals between measurements.

    Args:
        data (pd.DataFrame): DataFrame with processed aggregated features.
        vital_features (List[str]): Base names of vital sign features.
        lab_features (List[str]): Base names of lab value features.
        seq_length (int): Number of time steps in each sequence.
        
    Returns:
        X_temporal (Dict[str, np.ndarray]): Dictionary mapping hadm_id to sequence data [seq_length, num_features].
        time_intervals (Dict[str, np.ndarray]): Dictionary mapping hadm_id to time interval data [seq_length, 1]. Intervals represent hours since previous measurement (first is 0).
        timestamps (Dict[str, np.ndarray]): Dictionary mapping hadm_id to cumulative timestamp data [seq_length, 1]. Timestamps represent hours since admission.
        y (pd.Series): Series with readmission labels, indexed by hadm_id.
        temporal_feature_names (List[str]): List of feature names included in the sequences.
    """
    logger.info(f"Creating synthetic temporal dataset with sequence length {seq_length}...")
    # Extract target and patient IDs
    target_col = 'readmission_30day'
    if target_col not in data.columns:
        raise ValueError(f"Target column '{target_col}' not found in data.")
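    # Reset the index so the positional lookups below (data.loc[i, ...]) stay valid
    # even if the caller passes a filtered or re-indexed DataFrame
    data = data.reset_index(drop=True)
    y = data.set_index('hadm_id')[target_col].copy()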

    # Identify relevant feature columns based on base names (more robust)
    # Assumes aggregated features have names like 'heart_rate_mean', 'glucose_max', etc.
    all_feature_cols = [f for f in data.columns if any(vf in f for vf in vital_features) or any(lf in f for lf in lab_features)]
    
    # Determine a single 'final value' for each base feature (e.g., use mean if available, else max, etc.)
    temporal_feature_map = {}
    final_value_cols = []
    for base_feat in vital_features + lab_features:
        found_col = None
        for suffix in ['_mean', '_max', '_min', '_last', '']:
            potential_col = f"{base_feat}{suffix}"
            if potential_col in all_feature_cols:
                found_col = potential_col
                break
        if found_col:
            temporal_feature_map[base_feat] = found_col # Map base name to the column used for final value
            final_value_cols.append(found_col)
        else:
            logger.warning(f"No suitable aggregated column found for base feature: {base_feat}")

    temporal_feature_names = list(temporal_feature_map.keys()) # Use base names for consistency
    num_features = len(temporal_feature_names)
    logger.info(f"Generating sequences for {num_features} features: {temporal_feature_names}")

    X_temporal = {}
    time_intervals = {}
    timestamps = {}

    # Use tqdm for progress tracking
    for i in tqdm(range(len(data)), desc="Generating Sequences"):
        hadm_id = data.loc[i, 'hadm_id']
        
        # Generate time intervals (hours since previous measurement)
        intervals = np.zeros(seq_length)
        # Simulate irregular intervals (e.g., 1-8 hours)
        intervals[1:] = np.random.uniform(1, 8, seq_length - 1)
        cumulative_timestamps = np.cumsum(intervals)

        sequence = np.zeros((seq_length, num_features))

        for j, base_feat_name in enumerate(temporal_feature_names):
            final_val_col = temporal_feature_map[base_feat_name]
            final_val = data.loc[i, final_val_col]

            # Handle potential NaN final values (e.g., replace with 0 or median)
            if pd.isna(final_val):
                final_val = 0 # Simple imputation for demo

            # Simulate a plausible starting value (e.g., 80% of final + noise)
            start_val = final_val * 0.8 + np.random.normal(0, abs(final_val * 0.1) + 1e-6)
            
            # Generate a non-linear trajectory from start to final value
            progress = np.linspace(0, 1, seq_length) ** 1.5 
            trajectory = start_val + (final_val - start_val) * progress
            
            # Add random noise
            noise = np.random.normal(0, abs(final_val * 0.05) + 1e-6, seq_length)
            sequence[:, j] = trajectory + noise
        
        X_temporal[hadm_id] = sequence
        time_intervals[hadm_id] = intervals.reshape(-1, 1)
        timestamps[hadm_id] = cumulative_timestamps.reshape(-1, 1)
    
    logger.info("Synthetic temporal dataset generation complete.")
    return X_temporal, time_intervals, timestamps, y, temporal_feature_names

In [None]:
# Define feature sets (base names)
# These should ideally come from config, but hardcoding for notebook clarity
vital_features = ['heart_rate', 'sbp', 'dbp', 'mbp', 'resp_rate', 'temperature', 'spo2', 'gcs']
lab_features = [
    'aniongap', 'albumin', 'bands', 'bicarbonate', 'bilirubin', 'bun', 
    'calcium', 'chloride', 'creatinine', 'glucose', 'hematocrit', 'hemoglobin', 
    'lactate', 'platelet', 'potassium', 'ptt', 'inr', 'pt', 'sodium', 'wbc'
]

if not data.empty:
    # Create temporal dataset
    X_temporal, time_intervals, timestamps, y, temporal_feature_names = create_temporal_dataset_with_intervals(
        data, vital_features, lab_features, seq_length=24 # Using 24 time steps for POC
    )

    # Get admission IDs
    hadm_ids = list(X_temporal.keys())
    labels = y.loc[hadm_ids].values # Ensure labels align with the keys

    # Split data (use same random state for consistency across runs/comparisons)
    train_ids, test_ids, train_labels, test_labels = train_test_split(
        hadm_ids, labels, test_size=0.2, random_state=RANDOM_STATE, stratify=labels
    )
    logger.info(f"Data split: {len(train_ids)} train samples, {len(test_ids)} test samples.")

    # Display an example sequence and its intervals
    example_id = train_ids[0]
    example_sequence = X_temporal[example_id]
    example_intervals = time_intervals[example_id]
    example_timestamps = timestamps[example_id]

    plt.figure(figsize=(15, 5))
    # Plot first few features against cumulative time
    num_features_to_plot = min(5, len(temporal_feature_names))
    for i in range(num_features_to_plot):
        plt.plot(example_timestamps, example_sequence[:, i], label=temporal_feature_names[i], marker='o', linestyle='--')
    plt.title(f"Example Temporal Sequence (Features vs. Time Since Admission) - Admission {example_id}")
    plt.xlabel("Time Since Admission (Hours)")
    plt.ylabel("Simulated Value (Arbitrary Units)")
    plt.legend(loc='upper left', bbox_to_anchor=(1, 1))
    plt.grid(True)
    plt.tight_layout()
    plt.show()
else:
    logger.error("Data loading failed earlier, cannot prepare temporal data.")
    # Set dummy variables to avoid errors later, though execution should ideally stop
    X_temporal, time_intervals, timestamps, y, temporal_feature_names = {}, {}, {}, pd.Series(dtype=float), []
    train_ids, test_ids, train_labels, test_labels = [], [], [], []

## 4. Model Definition (Time-Aware LSTM)

In [None]:
# Using TimeEncoder and TimeAwarePatientLSTM from src.models.temporal_modeling
# These classes should be defined in that file or copied here for self-containment.
# Assuming they are imported correctly.
logger.info("Using TimeEncoder and TimeAwarePatientLSTM classes.")
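
The `TimeEncoder` imported from `src.models.temporal_modeling` is not shown in this notebook. As a rough illustration only, one plausible way to embed an interval is to log-compress it and pass it through a small learned projection. `SketchTimeEncoder` below is a hypothetical stand-in: its name, the `log1p` compression, and the layer sizes are assumptions, not the project's implementation. It only mirrors the interface used later in this notebook, taking `[batch, seq_len, 1]` and returning `[batch, seq_len, embed_dim]`.

```python
import torch
import torch.nn as nn


class SketchTimeEncoder(nn.Module):
    """Hypothetical stand-in for TimeEncoder: maps hour intervals to learned embeddings."""

    def __init__(self, embed_dim: int = 16):
        super().__init__()
        # Small MLP applied independently to every time step
        self.proj = nn.Sequential(
            nn.Linear(1, embed_dim),
            nn.Tanh(),
            nn.Linear(embed_dim, embed_dim),
        )

    def forward(self, intervals: torch.Tensor) -> torch.Tensor:
        # intervals: [batch, seq_len, 1], hours since the previous measurement (>= 0)
        # log1p compresses long gaps so a 48h gap does not dwarf a 1h gap
        return self.proj(torch.log1p(intervals))


# Shape check on random intervals in the same 1-8 hour range used for the synthetic data
encoder = SketchTimeEncoder(embed_dim=16)
demo_intervals = torch.rand(4, 24, 1) * 7 + 1
demo_embedding = encoder(demo_intervals)
print(demo_embedding.shape)  # torch.Size([4, 24, 16])
```

The embedding is concatenated with the feature vector at each step, so the LSTM input width becomes `input_dim + embed_dim`.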

## 5. Dataset and DataLoader

In [None]:
class EnhancedTemporalEHRDataset(Dataset):
    """PyTorch Dataset for temporal sequences with time intervals."""
    def __init__(self, sequences, time_intervals, labels, hadm_ids):
        self.sequences = sequences
        self.time_intervals = time_intervals
        self.labels = labels
        self.hadm_ids = hadm_ids

    def __len__(self):
        return len(self.labels)

    def __getitem__(self, idx):
        hadm_id = self.hadm_ids[idx]
        sequence = self.sequences[hadm_id]
        interval = self.time_intervals[hadm_id]
        label = self.labels[idx]
        return {
            'sequence': torch.FloatTensor(sequence),
            'intervals': torch.FloatTensor(interval),
            'label': torch.FloatTensor([label]) # Label as float tensor
        }

def collate_fn(batch):
    """Pads sequences and intervals in a batch."""
    # Sort batch by sequence length (optional but common)
    # batch.sort(key=lambda x: len(x['sequence']), reverse=True)
    
    sequences = [item['sequence'] for item in batch]
    intervals = [item['intervals'] for item in batch]
    labels = [item['label'] for item in batch]

    # Pad sequences and intervals
    padded_sequences = nn.utils.rnn.pad_sequence(sequences, batch_first=True, padding_value=0.0)
    padded_intervals = nn.utils.rnn.pad_sequence(intervals, batch_first=True, padding_value=0.0)
    
    # Stack labels
    labels_tensor = torch.stack(labels)

    return {
        'sequences': padded_sequences,
        'intervals': padded_intervals,
        'labels': labels_tensor
    }

if train_ids: # Only create datasets if data splitting was successful
    # Create datasets
    train_dataset = EnhancedTemporalEHRDataset(X_temporal, time_intervals, train_labels, train_ids)
    test_dataset = EnhancedTemporalEHRDataset(X_temporal, time_intervals, test_labels, test_ids)

    # Create dataloaders
    batch_size = config.get('models', {}).get('temporal_readmission', {}).get('batch_size', 32) # Get from config or default
    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, collate_fn=collate_fn)
    test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False, collate_fn=collate_fn)
    logger.info(f"Created DataLoaders with batch size: {batch_size}")
else:
     logger.warning("Skipping DataLoader creation due to empty train_ids.")
     train_loader, test_loader = None, None
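
All synthetic sequences here share one length (`seq_length=24`), so the zero padding in `collate_fn` never changes anything. With real admissions of differing lengths, padded steps would otherwise take part in the attention softmax. The sketch below shows one way a collate function could also return a boolean mask; `collate_with_mask` and the `mask` key are hypothetical additions, not part of this notebook's pipeline.

```python
import torch
from torch.nn.utils.rnn import pad_sequence


def collate_with_mask(batch):
    """Pad variable-length items and return a mask that is True on real time steps."""
    sequences = [item['sequence'] for item in batch]
    intervals = [item['intervals'] for item in batch]
    labels = torch.stack([item['label'] for item in batch])

    lengths = torch.tensor([seq.shape[0] for seq in sequences])
    padded_sequences = pad_sequence(sequences, batch_first=True, padding_value=0.0)
    padded_intervals = pad_sequence(intervals, batch_first=True, padding_value=0.0)

    # mask[b, t] is True when step t is a real measurement for item b
    max_len = padded_sequences.shape[1]
    mask = torch.arange(max_len).unsqueeze(0) < lengths.unsqueeze(1)

    return {
        'sequences': padded_sequences,
        'intervals': padded_intervals,
        'labels': labels,
        'mask': mask,
    }


# Two toy admissions with 3 and 5 time steps, 2 features each
toy_batch = [
    {'sequence': torch.ones(3, 2), 'intervals': torch.ones(3, 1), 'label': torch.tensor([0.0])},
    {'sequence': torch.ones(5, 2), 'intervals': torch.ones(5, 1), 'label': torch.tensor([1.0])},
]
toy_out = collate_with_mask(toy_batch)
print(toy_out['sequences'].shape)  # torch.Size([2, 5, 2])
print(toy_out['mask'].sum(dim=1))  # tensor([3, 5])
```

Inside a model, such a mask could be applied by setting the attention scores at padded steps to `-inf` before the softmax, for example with `scores.masked_fill(~mask.unsqueeze(-1), float('-inf'))`.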

## 6. Train the Time-Aware LSTM Model

In [None]:
lstm_test_results = {
    'labels': None, 
    'preds_proba': None,
    'roc_auc': np.nan,
    'pr_auc': np.nan
}

if train_loader and test_loader: # Check if loaders were created
    # Initialize model
    input_dim = len(temporal_feature_names)
    hidden_dim = config.get('models', {}).get('temporal_readmission', {}).get('hidden_dim', 64)
    num_layers = config.get('models', {}).get('temporal_readmission', {}).get('num_layers', 1)
    time_embed_dim = config.get('models', {}).get('temporal_readmission', {}).get('time_embed_dim', 16)
    dropout = config.get('models', {}).get('temporal_readmission', {}).get('dropout', 0.2)
    
    # Note: TimeAwarePatientLSTM expects num_static_features, but this notebook's dataset
    # carries no static features. Rather than adapt the imported model, we define a
    # simplified time-aware LSTM here that uses only the sequences and their intervals.
    class SimplifiedTimeAwareLSTM(nn.Module):
        def __init__(self, input_dim, hidden_dim, time_embed_dim=16, num_layers=1, dropout=0.2):
            super().__init__()
            self.time_encoder = TimeEncoder(time_embed_dim)
            self.lstm = nn.LSTM(
                input_dim + time_embed_dim, hidden_dim, num_layers=num_layers, 
                batch_first=True, dropout=(dropout if num_layers > 1 else 0)
            )
            self.attention = nn.Sequential(nn.Linear(hidden_dim, 32), nn.Tanh(), nn.Linear(32, 1))
            self.classifier = nn.Linear(hidden_dim, 1)

        def forward(self, x, time_intervals):
            time_encoding = self.time_encoder(time_intervals)
            x_with_time = torch.cat([x, time_encoding], dim=2)
            lstm_out, _ = self.lstm(x_with_time)
            attention_weights = torch.softmax(self.attention(lstm_out), dim=1)
            context = torch.sum(attention_weights * lstm_out, dim=1)
            # Return logits (BCEWithLogitsLoss expects logits)
            return self.classifier(context)

    model = SimplifiedTimeAwareLSTM(input_dim, hidden_dim, time_embed_dim, num_layers, dropout)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model.to(device)
    logger.info(f"Initialized SimplifiedTimeAwareLSTM model on {device}.")

    # Define loss function and optimizer
    criterion = nn.BCEWithLogitsLoss() # Use Logits loss
    learning_rate = config.get('models', {}).get('temporal_readmission', {}).get('learning_rate', 0.001)
    optimizer = optim.Adam(model.parameters(), lr=learning_rate)

    # Training loop
    num_epochs = config.get('models', {}).get('temporal_readmission', {}).get('num_epochs', 15) # Increase epochs slightly
    train_losses, test_losses, test_roc_aucs, test_pr_aucs = [], [], [], []

    logger.info(f"Starting LSTM training for {num_epochs} epochs...")
    for epoch in range(num_epochs):
        # Training phase
        model.train()
        epoch_train_loss = 0.0
        train_progress_bar = tqdm(train_loader, desc=f"Epoch {epoch+1}/{num_epochs} [Train]", leave=False)
        for batch in train_progress_bar:
            sequences = batch['sequences'].to(device)
            intervals = batch['intervals'].to(device)
            labels = batch['labels'].to(device).float() # Ensure labels are float

            optimizer.zero_grad()
            outputs = model(sequences, intervals) # Get logits
            loss = criterion(outputs.squeeze(), labels.squeeze()) # Squeeze outputs and labels
            loss.backward()
            optimizer.step()
            epoch_train_loss += loss.item()
            train_progress_bar.set_postfix({'loss': f'{loss.item():.4f}'})
        avg_train_loss = epoch_train_loss / len(train_loader)
        train_losses.append(avg_train_loss)

        # Evaluation phase
        model.eval()
        epoch_test_loss = 0.0
        all_labels_list = []
        all_preds_proba_list = []
        test_progress_bar = tqdm(test_loader, desc=f"Epoch {epoch+1}/{num_epochs} [Eval]", leave=False)
        with torch.no_grad():
            for batch in test_progress_bar:
                sequences = batch['sequences'].to(device)
                intervals = batch['intervals'].to(device)
                labels = batch['labels'].to(device).float()

                outputs = model(sequences, intervals)
                loss = criterion(outputs.squeeze(), labels.squeeze())
                epoch_test_loss += loss.item()
                
                # Apply sigmoid to logits to get probabilities
                probabilities = torch.sigmoid(outputs).squeeze()
                
                all_labels_list.append(labels.cpu().numpy())
                # Handle cases where batch size is 1 and probabilities might become scalar
                if probabilities.ndim == 0:
                    probabilities = probabilities.unsqueeze(0)
                all_preds_proba_list.append(probabilities.cpu().numpy())

        avg_test_loss = epoch_test_loss / len(test_loader)
        test_losses.append(avg_test_loss)
        
        # Concatenate results from all batches
        epoch_labels = np.concatenate([lbl.flatten() for lbl in all_labels_list])
        epoch_preds_proba = np.concatenate([prob.flatten() for prob in all_preds_proba_list])
        
        # Calculate metrics
        epoch_roc_auc = roc_auc_score(epoch_labels, epoch_preds_proba)
        epoch_pr_auc = average_precision_score(epoch_labels, epoch_preds_proba)
        test_roc_aucs.append(epoch_roc_auc)
        test_pr_aucs.append(epoch_pr_auc)

        print(f"Epoch {epoch+1}/{num_epochs} - Train Loss: {avg_train_loss:.4f}, Test Loss: {avg_test_loss:.4f}, Test ROC AUC: {epoch_roc_auc:.4f}, Test PR AUC: {epoch_pr_auc:.4f}")
    
    logger.info("LSTM training complete.")
    
    # Store final test results
    lstm_test_results['labels'] = epoch_labels
    lstm_test_results['preds_proba'] = epoch_preds_proba
    lstm_test_results['roc_auc'] = epoch_roc_auc
    lstm_test_results['pr_auc'] = epoch_pr_auc
    
    # Plot training curves
    plt.figure(figsize=(12, 5))
    plt.subplot(1, 2, 1)
    plt.plot(range(1, num_epochs + 1), train_losses, label='Train Loss')
    plt.plot(range(1, num_epochs + 1), test_losses, label='Test Loss')
    plt.xlabel('Epoch')
    plt.ylabel('Loss (BCEWithLogits)')
    plt.title('Training and Test Loss')
    plt.legend()
    plt.grid(True)

    plt.subplot(1, 2, 2)
    plt.plot(range(1, num_epochs + 1), test_roc_aucs, label='Test ROC AUC')
    plt.plot(range(1, num_epochs + 1), test_pr_aucs, label='Test PR AUC')
    plt.xlabel('Epoch')
    plt.ylabel('AUC Score')
    plt.title('Test AUC Scores')
    plt.legend()
    plt.grid(True)

    plt.tight_layout()
    plt.savefig(os.path.join(results_dir, 'time_aware_lstm_training_curves.png'))
    plt.show()
else:
    logger.error("Cannot train LSTM model as DataLoaders were not created.")
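
The training loop above uses an unweighted `BCEWithLogitsLoss`. If the readmission label turns out to be imbalanced (an assumption about the data, not something verified here), one lightweight option is the loss's `pos_weight` argument. The sketch below uses made-up toy labels in place of `train_labels`; the variable names are illustrative.

```python
import numpy as np
import torch
import torch.nn as nn

# Toy labels standing in for train_labels: 2 positives out of 10
toy_train_labels = np.array([0, 0, 0, 1, 0, 0, 0, 0, 1, 0])

n_pos = int(toy_train_labels.sum())
n_neg = int(len(toy_train_labels) - n_pos)
# Weight on the positive class = negatives per positive (guard against zero positives)
pos_weight = torch.tensor([n_neg / max(n_pos, 1)], dtype=torch.float32)

weighted_criterion = nn.BCEWithLogitsLoss(pos_weight=pos_weight)

# The same logit costs more when the true label is the rare positive class
logit = torch.tensor([0.0])
loss_pos = weighted_criterion(logit, torch.tensor([1.0]))
loss_neg = weighted_criterion(logit, torch.tensor([0.0]))
print(pos_weight.item())  # 4.0
print(loss_pos.item() > loss_neg.item())  # True
```

Weighting changes the scale of the predicted probabilities, so ranking metrics such as ROC AUC remain comparable, while calibrated risk estimates would need a separate check.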

## 7. Train Baseline Model (LightGBM)

In [None]:
lgbm_test_results = {
    'labels': None, 
    'preds_proba': None,
    'roc_auc': np.nan,
    'pr_auc': np.nan
}

if not data.empty:
    # Prepare data for LightGBM (using static/aggregated features)
    # Ensure we use the *exact same* train/test split based on hadm_id
    target_col = 'readmission_30day'
    feature_cols = [col for col in data.columns if col not in ['hadm_id', 'subject_id', 'admittime', 'dischtime', target_col]]
    
    # Filter data based on train/test IDs
    train_data_lgbm = data[data['hadm_id'].isin(train_ids)].copy()
    test_data_lgbm = data[data['hadm_id'].isin(test_ids)].copy()
    
    X_train_lgbm = train_data_lgbm[feature_cols]
    y_train_lgbm = train_data_lgbm[target_col]
    X_test_lgbm = test_data_lgbm[feature_cols]
    y_test_lgbm = test_data_lgbm[target_col]
    
    # Handle potential NaNs (simple mean imputation for baseline)
    X_train_lgbm = X_train_lgbm.fillna(X_train_lgbm.mean())
    X_test_lgbm = X_test_lgbm.fillna(X_train_lgbm.mean()) # Use train mean for test set
    
    # Scale features
    scaler = StandardScaler()
    X_train_lgbm_scaled = scaler.fit_transform(X_train_lgbm)
    X_test_lgbm_scaled = scaler.transform(X_test_lgbm)
    
    # Define LightGBM parameters (should ideally come from config/tuning)
    lgbm_params = {
        'objective': 'binary',
        'metric': 'auc',
        'boosting_type': 'gbdt',
        'n_estimators': 1000,
        'learning_rate': 0.05,
        'num_leaves': 31,
        'max_depth': -1,
        'seed': RANDOM_STATE,
        'n_jobs': -1,
        'verbose': -1, # Suppress verbose output
        'colsample_bytree': 0.8,
        'subsample': 0.8,
        'reg_alpha': 0.1,
        'reg_lambda': 0.1
        # Add scale_pos_weight if not using SMOTE, or handle imbalance separately
    }
    
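    # Train LightGBM model. Early stopping monitors a validation split carved from the
    # training data, so the test set is not used for model selection.
    X_fit_lgbm, X_val_lgbm, y_fit_lgbm, y_val_lgbm = train_test_split(
        X_train_lgbm_scaled, y_train_lgbm, test_size=0.2,
        random_state=RANDOM_STATE, stratify=y_train_lgbm
    )
    logger.info("Training LightGBM baseline model...")
    lgbm_model = lgb.LGBMClassifier(**lgbm_params)
    lgbm_model.fit(X_fit_lgbm, y_fit_lgbm, 
                   eval_set=[(X_val_lgbm, y_val_lgbm)], 
                   eval_metric='auc', 
                   callbacks=[lgb.early_stopping(100, verbose=False)]) # Early stopping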
    
    # Evaluate LightGBM model
    lgbm_preds_proba = lgbm_model.predict_proba(X_test_lgbm_scaled)[:, 1]
    lgbm_roc_auc = roc_auc_score(y_test_lgbm, lgbm_preds_proba)
    lgbm_pr_auc = average_precision_score(y_test_lgbm, lgbm_preds_proba)
    
    lgbm_test_results['labels'] = y_test_lgbm.values
    lgbm_test_results['preds_proba'] = lgbm_preds_proba
    lgbm_test_results['roc_auc'] = lgbm_roc_auc
    lgbm_test_results['pr_auc'] = lgbm_pr_auc
    
    logger.info(f"LightGBM Baseline - Test ROC AUC: {lgbm_roc_auc:.4f}, Test PR AUC: {lgbm_pr_auc:.4f}")
else:
    logger.error("Cannot train LightGBM model as data loading failed.")

## 8. Evaluate and Compare Models

In [None]:
plt.figure(figsize=(12, 5))

# Plot ROC Curves
plt.subplot(1, 2, 1)
if lgbm_test_results['labels'] is not None:
    fpr_lgbm, tpr_lgbm, _ = roc_curve(lgbm_test_results['labels'], lgbm_test_results['preds_proba'])
    plt.plot(fpr_lgbm, tpr_lgbm, label=f"LightGBM (AUC = {lgbm_test_results['roc_auc']:.3f})")

if lstm_test_results['labels'] is not None:
    fpr_lstm, tpr_lstm, _ = roc_curve(lstm_test_results['labels'], lstm_test_results['preds_proba'])
    plt.plot(fpr_lstm, tpr_lstm, label=f"Time-Aware LSTM (AUC = {lstm_test_results['roc_auc']:.3f})")

plt.plot([0, 1], [0, 1], 'k--', label='Random Chance')
plt.xlabel('False Positive Rate')
plt.ylabel('True Positive Rate')
plt.title('ROC Curve Comparison')
plt.legend()
plt.grid(True)

# Plot Precision-Recall Curves
plt.subplot(1, 2, 2)
if lgbm_test_results['labels'] is not None:
    precision_lgbm, recall_lgbm, _ = precision_recall_curve(lgbm_test_results['labels'], lgbm_test_results['preds_proba'])
    plt.plot(recall_lgbm, precision_lgbm, label=f"LightGBM (AUC = {lgbm_test_results['pr_auc']:.3f})")

if lstm_test_results['labels'] is not None:
    precision_lstm, recall_lstm, _ = precision_recall_curve(lstm_test_results['labels'], lstm_test_results['preds_proba'])
    plt.plot(recall_lstm, precision_lstm, label=f"Time-Aware LSTM (AUC = {lstm_test_results['pr_auc']:.3f})")

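# Calculate baseline precision (positive class ratio); fall back to the LSTM labels if the baseline did not run
ref_labels = lgbm_test_results['labels'] if lgbm_test_results['labels'] is not None else lstm_test_results['labels']
positive_ratio = np.mean(ref_labels) if ref_labels is not None else 0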
plt.axhline(positive_ratio, color='k', linestyle='--', label=f'Baseline ({positive_ratio:.3f})')
plt.xlabel('Recall')
plt.ylabel('Precision')
plt.title('Precision-Recall Curve Comparison')
plt.legend()
plt.grid(True)
plt.ylim([0.0, 1.05]) # Ensure y-axis starts at 0

plt.tight_layout()
plt.savefig(os.path.join(results_dir, 'time_aware_lstm_roc_pr_curves.png'))
plt.show()

# Print final comparison
print("--- Model Comparison ---")
print(f"LightGBM Baseline: ROC AUC = {lgbm_test_results['roc_auc']:.4f}, PR AUC = {lgbm_test_results['pr_auc']:.4f}")
print(f"Time-Aware LSTM:   ROC AUC = {lstm_test_results['roc_auc']:.4f}, PR AUC = {lstm_test_results['pr_auc']:.4f}")
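
With a test set as small as the one in the MIMIC-III Demo, a single AUC value per model says little about whether the two models really differ. A percentile bootstrap over test cases gives a rough interval. The sketch below is NumPy-only and uses a pairwise AUC helper for self-containment (`roc_auc_score` could be swapped in); the function names and the toy scores are illustrative and are not results from this notebook.

```python
import numpy as np


def pairwise_auc(labels, scores):
    """ROC AUC as the probability that a positive outranks a negative (ties count 0.5)."""
    labels = np.asarray(labels)
    scores = np.asarray(scores, dtype=float)
    pos = scores[labels == 1]
    neg = scores[labels == 0]
    diff = pos[:, None] - neg[None, :]
    return (np.sum(diff > 0) + 0.5 * np.sum(diff == 0)) / (len(pos) * len(neg))


def bootstrap_auc_ci(labels, scores, n_boot=1000, alpha=0.05, seed=42):
    """Percentile bootstrap interval for ROC AUC; resamples test cases with replacement."""
    rng = np.random.default_rng(seed)
    labels = np.asarray(labels)
    scores = np.asarray(scores, dtype=float)
    n = len(labels)
    aucs = []
    for _ in range(n_boot):
        idx = rng.integers(0, n, size=n)
        # Skip resamples that contain only one class; AUC is undefined there
        if labels[idx].min() == labels[idx].max():
            continue
        aucs.append(pairwise_auc(labels[idx], scores[idx]))
    lower, upper = np.quantile(aucs, [alpha / 2, 1 - alpha / 2])
    return lower, upper


# Toy example: 6 test cases with made-up scores
toy_labels = np.array([0, 0, 1, 0, 1, 1])
toy_scores = np.array([0.1, 0.4, 0.35, 0.2, 0.8, 0.7])
print(pairwise_auc(toy_labels, toy_scores))  # 8 of 9 positive/negative pairs ordered correctly
print(bootstrap_auc_ci(toy_labels, toy_scores, n_boot=200))
```

Applied to `lstm_test_results` and `lgbm_test_results`, heavily overlapping intervals would suggest that the observed gap is within sampling noise.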

## 9. Analyze Attention Weights (Optional)

In [None]:
# Function to get attention weights (assuming it's defined in temporal_modeling.py or here)
# def get_attention_weights(model, dataloader, device):
#     model.eval()
#     all_weights = []
#     with torch.no_grad():
#         for batch in dataloader:
#             sequences = batch['sequences'].to(device)
#             intervals = batch['intervals'].to(device)
#             # Need to modify model forward pass to return weights
#             _, weights = model(sequences, intervals) # Assuming model returns logits, weights
#             all_weights.append(weights.cpu().numpy())
#     return np.concatenate(all_weights, axis=0)

# # Get attention weights for the test set
# if test_loader:
#     try:
#         # Modify the model's forward pass temporarily or permanently to return weights
#         # This requires adjusting the SimplifiedTimeAwareLSTM definition
#         # For now, we'll skip this visualization as it requires model changes
#         logger.warning("Attention weight analysis requires model modification to return weights. Skipping.")
#         # attention_weights = get_attention_weights(model, test_loader, device)
#         # logger.info(f"Retrieved attention weights with shape: {attention_weights.shape}")
#         
#         # # Visualize attention for an example patient
#         # example_idx = 0
#         # example_attn = attention_weights[example_idx].flatten()
#         # example_ts = timestamps[test_ids[example_idx]].flatten()
#         
#         # plt.figure(figsize=(12, 4))
#         # plt.bar(range(len(example_attn)), example_attn)
#         # plt.xticks(range(len(example_attn)), [f"{t:.1f}h" for t in example_ts], rotation=45)
#         # plt.xlabel("Time Since Admission (Hours)")
#         # plt.ylabel("Attention Weight")
#         # plt.title(f"Attention Weights Over Time for Example Patient {test_ids[example_idx]}")
#         # plt.tight_layout()
#         # plt.show()
#     except Exception as e:
#         logger.error(f"Error during attention analysis: {e}", exc_info=True)
# else:
#     logger.warning("Cannot analyze attention weights as test_loader is not available.")
logger.info("Attention weight analysis skipped (requires model modification).")
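
The analysis is skipped because `SimplifiedTimeAwareLSTM.forward` returns only logits. One possible change is an optional flag that also returns the per-step weights. The sketch below is a self-contained, hypothetical variant: the class name, the `return_attention` flag, and the single linear layer standing in for `TimeEncoder` are all assumptions made so the example runs on its own.

```python
import torch
import torch.nn as nn


class AttentionReturningLSTM(nn.Module):
    """Hypothetical variant of the notebook's model whose forward can also return attention."""

    def __init__(self, input_dim, hidden_dim=64, time_embed_dim=16):
        super().__init__()
        # Stand-in for TimeEncoder (assumption): a single linear layer over the interval
        self.time_encoder = nn.Linear(1, time_embed_dim)
        self.lstm = nn.LSTM(input_dim + time_embed_dim, hidden_dim, batch_first=True)
        self.attention = nn.Sequential(nn.Linear(hidden_dim, 32), nn.Tanh(), nn.Linear(32, 1))
        self.classifier = nn.Linear(hidden_dim, 1)

    def forward(self, x, time_intervals, return_attention=False):
        x_with_time = torch.cat([x, self.time_encoder(time_intervals)], dim=2)
        lstm_out, _ = self.lstm(x_with_time)
        # Softmax over the time axis: one weight per step, summing to 1 per sequence
        attention_weights = torch.softmax(self.attention(lstm_out), dim=1)
        context = torch.sum(attention_weights * lstm_out, dim=1)
        logits = self.classifier(context)
        if return_attention:
            return logits, attention_weights.squeeze(-1)
        return logits


demo_model = AttentionReturningLSTM(input_dim=5)
demo_model.eval()
with torch.no_grad():
    demo_logits, demo_attn = demo_model(
        torch.randn(3, 24, 5), torch.rand(3, 24, 1) * 7 + 1, return_attention=True
    )
print(demo_logits.shape, demo_attn.shape)  # torch.Size([3, 1]) torch.Size([3, 24])
print(demo_attn.sum(dim=1))  # each row sums to approximately 1
```

Because the flag defaults to `False`, the existing training and evaluation loops would keep working unchanged, and the weights for one admission could be plotted against its cumulative timestamps.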

## 10. Discussion and Conclusion

**Findings:**
*   Compare the ROC AUC and PR AUC scores between the Time-Aware LSTM and the LightGBM baseline. Did explicitly modeling time intervals improve performance on this synthetic dataset with irregular measurement intervals?
*   Analyze the training curves. Did the LSTM converge well? Was there overfitting?
*   (If attention analysis was performed) Did the attention mechanism focus on specific time points? Did these align with significant changes in the synthetic feature trajectories?

**Limitations:**
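*   **Synthetic Data:** The primary limitation is the use of synthetically generated temporal sequences based on aggregated final values. This does *not* reflect true patient trajectories and likely oversimplifies the temporal dynamics. Note also that the intervals are drawn uniformly at random and each trajectory is derived from a single aggregated value plus noise, so neither is linked to the readmission label beyond what the static features already encode; any gap between the two models should therefore not be read as evidence for or against temporal signal. The goal here was purely to demonstrate the *mechanics* of the Time-Aware LSTM and compare it to a static baseline on data *designed* to have temporal structure.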
*   **MIMIC-III Demo Size:** The small size of the demo dataset limits the statistical significance of the results and the model's ability to learn complex patterns.
*   **Hyperparameter Tuning:** Limited hyperparameter tuning was performed for both models.
*   **Feature Engineering:** The synthetic generation process is basic. Real-world application requires sophisticated feature engineering from raw time-series data.

**Implications for MLOps:**
*   **Data Pipeline Complexity:** Handling true temporal data (extracting sequences, aligning timestamps, handling missing values within sequences) significantly increases the complexity of the data processing pipeline compared to using static/aggregated features.
*   **Feature Store Integration:** A feature store becomes even more critical for managing time-dependent features and ensuring consistency between training and inference (point-in-time correctness).
*   **Model Training Infrastructure:** Training deep learning models like LSTMs requires more computational resources (potentially GPUs) and different infrastructure compared to tree-based models.
*   **Monitoring:** Monitoring temporal models involves tracking sequence input drift and potentially concept drift related to temporal patterns, adding complexity to the monitoring strategy.
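
As a small illustration of the sequence-extraction step mentioned under Data Pipeline Complexity, the sketch below derives per-admission intervals from raw time-stamped events with pandas. The toy values and the column names (`charttime`, `heart_rate`) are assumptions chosen to resemble MIMIC-style tables; they are not taken from this project's pipeline.

```python
import pandas as pd

# Toy chart events with made-up values
events = pd.DataFrame({
    'hadm_id': [1, 1, 1, 2, 2],
    'charttime': pd.to_datetime([
        '2101-01-01 08:00', '2101-01-01 09:30', '2101-01-01 15:30',
        '2101-03-05 22:00', '2101-03-06 02:00',
    ]),
    'heart_rate': [88, 92, 101, 75, 79],
})

events = events.sort_values(['hadm_id', 'charttime'])
# Hours since the previous measurement within each admission; the first step gets 0
events['interval_hours'] = (
    events.groupby('hadm_id')['charttime'].diff().dt.total_seconds().div(3600).fillna(0.0)
)
# Hours since the first measurement of the admission
events['hours_since_start'] = events.groupby('hadm_id')['interval_hours'].cumsum()

print(events[['hadm_id', 'interval_hours', 'hours_since_start']])
```

The resulting `interval_hours` column plays the same role as the synthetic `time_intervals` used in this notebook, with zero at the first step of every admission.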

**Conclusion:** This notebook demonstrated the implementation of a Time-Aware LSTM, highlighting its potential to leverage temporal information in EHR data. Whatever difference the Section 8 comparison shows against the LightGBM baseline on synthetic data, applying the model to real, complex EHR sequences is necessary to truly evaluate its benefits. The exercise underscores the significant increase in MLOps complexity when moving from static to temporal modeling.