In [2]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, TensorDataset
import numpy as np
import torch
from tqdm import tqdm
from sklearn.metrics import confusion_matrix, roc_auc_score, accuracy_score
import pandas as pd
from torch.utils.data import DataLoader, TensorDataset
from sklearn.metrics import confusion_matrix, classification_report, roc_auc_score, accuracy_score, f1_score
from sklearn.metrics import roc_curve, precision_recall_curve, average_precision_score
import matplotlib.pyplot as plt
import seaborn as sns

In [3]:
# Define data paths (adjust these based on your data location)
train_data_path = "/home/taekim/enhanced_preprocessed_data/enhanced_set-a.csv"
val_data_path = "/home/taekim/enhanced_preprocessed_data/enhanced_set-b.csv"
test_data_path = "/home/taekim/enhanced_preprocessed_data/enhanced_set-c.csv"

train_outcomes_path = "/home/taekim/ml4h_data/p1/Outcomes-a.txt"
val_outcomes_path = "/home/taekim/ml4h_data/p1/Outcomes-b.txt"
test_outcomes_path = "/home/taekim/ml4h_data/p1/Outcomes-c.txt"

In [4]:
class PositionalEncoding(nn.Module):
    def __init__(self, d_model, dropout=0.1, max_len=1000):
        super(PositionalEncoding, self).__init__()
        self.dropout = nn.Dropout(p=dropout)
        
        # Create positional encoding matrix
        pe = torch.zeros(max_len, d_model)
        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-torch.log(torch.tensor(10000.0)) / d_model))
        pe[:, 0::2] = torch.sin(position * div_term)
        pe[:, 1::2] = torch.cos(position * div_term)
        pe = pe.unsqueeze(0)  # (1, max_len, d_model)
        
        # Register as buffer (not a parameter but part of the module)
        self.register_buffer('pe', pe)
        
    def forward(self, x):
        # Add positional encoding to input
        x = x + self.pe[:, :x.size(1), :]
        return self.dropout(x)

In [5]:
class TransformerClassifier(nn.Module):
    def __init__(self, input_dim, d_model=64, nhead=4, num_layers=2, dim_feedforward=128, 
                 dropout=0.1, output_dim=1, max_seq_length=1000):
        super(TransformerClassifier, self).__init__()
        # Input projection to match transformer dimension
        self.input_projection = nn.Linear(input_dim, d_model)
        
        # Positional encoding
        self.pos_encoder = PositionalEncoding(d_model, dropout, max_seq_length)
        
        # Transformer encoder
        encoder_layer = nn.TransformerEncoderLayer(
            d_model=d_model,
            nhead=nhead,
            dim_feedforward=dim_feedforward,
            dropout=dropout,
            batch_first=True
        )
        self.transformer_encoder = nn.TransformerEncoder(encoder_layer, num_layers=num_layers)
        
        # Global average pooling
        self.global_avg_pool = nn.AdaptiveAvgPool1d(1)
        
        # Output layer
        self.fc = nn.Linear(d_model, output_dim)
    
    def forward(self, x):
        # Project input features
        x = self.input_projection(x)
        
        # Add positional encoding
        x = self.pos_encoder(x)
        
        # Apply transformer encoder
        x = self.transformer_encoder(x)
        
        # Global pooling and classification
        x = x.transpose(1, 2)
        x = self.global_avg_pool(x)
        x = x.squeeze(-1)
        out = self.fc(x)
        
        return out

In [6]:
# Modified calculate_class_weights function
def calculate_class_weights(labels, print_info=False):
    """
    Calculate class weights for imbalanced data
    
    Parameters:
    -----------
    labels : torch.Tensor or numpy.ndarray
        The target labels (0 or 1)
    print_info : bool, default=False
        Whether to print class distribution info
        
    Returns:
    --------
    dict
        Dictionary mapping class indices to weights
    """
    # Convert to torch tensor if it's not already
    if not isinstance(labels, torch.Tensor):
        labels = torch.tensor(labels)
    
    # Count instances of each class
    class_counts = torch.bincount(labels.long())
    
    if print_info:
        print("Class 0 (negative) count:", class_counts[0].item())
        print("Class 1 (positive) count:", class_counts[1].item())
    
    # Calculate weights (inversely proportional to class frequency)
    n_samples = len(labels)
    n_classes = len(class_counts)
    
    # Formula: weight = n_samples / (n_classes * class_count)
    weights = n_samples / (n_classes * class_counts.float())
    
    if print_info:
        print("Class 0 weight:", weights[0].item())
        print("Class 1 weight:", weights[1].item())
    
    # Return weights as a dictionary mapping class indices to weights
    return {i: weights[i].item() for i in range(len(weights))}


In [7]:


def train_transformer_model(model, train_loader, val_loader, learning_rate=0.0005, class_weights=None, num_epochs=50, patience=3):
    """
    Train the Transformer model for binary classification.
    
    Parameters:
    -----------
    model : nn.Module
        The transformer model
    train_loader : DataLoader
        DataLoader for training data
    val_loader : DataLoader
        DataLoader for validation data
    learning_rate : float, default=0.0005
        Learning rate for optimizer
    num_epochs : int, default=50
        Maximum number of training epochs
    patience : int, default=5
        Number of epochs with no improvement after which training will be stopped
    
    Returns:
    --------
    model : nn.Module
        The trained transformer model
    train_losses : list
        List of training losses for each epoch
    val_losses : list
        List of validation losses for each epoch
    """
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    model.to(device)
    
    # Create weight tensor for BCEWithLogitsLoss if class_weights provided
    if class_weights is not None:
        # For binary classification with BCEWithLogitsLoss, we need a single weight
        # for the positive class (class 1)
        pos_weight = torch.tensor([class_weights[1] / class_weights[0] ], device=device)
        print(f"Using positive class weight: {pos_weight.item():.4f}")
        
        # Use weighted BCE loss for imbalanced classes
        criterion = nn.BCEWithLogitsLoss(pos_weight=pos_weight)
    else:
        # Use standard BCE loss if no weights provided
        criterion = nn.BCEWithLogitsLoss()
    optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate, weight_decay=1e-3)
    
    # Learning rate scheduler
    scheduler = torch.optim.lr_scheduler.StepLR(
        optimizer, 
        step_size=10,     # Reduce LR every 10 epochs
        gamma=0.5         # Multiply LR by 0.5 at each step
    )
    
    # For tracking training progress
    train_losses = []
    val_losses = []
    best_val_loss = float('inf')
    early_stopping_counter = 0
    
    # Training loop
    for epoch in range(num_epochs):
        # Training phase
        model.train()
        train_loss = 0.0
        
        for inputs, targets in train_loader:
            inputs, targets = inputs.to(device), targets.to(device)
            
            # Forward pass
            outputs = model(inputs)
            loss = criterion(outputs, targets)
            
            # Backward pass and optimize
            optimizer.zero_grad()
            loss.backward()
            # Clip gradients to prevent exploding gradients (common in transformers)
            torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
            optimizer.step()
            
            train_loss += loss.item() * inputs.size(0)
        
        # Calculate average training loss
        train_loss = train_loss / len(train_loader.dataset)
        train_losses.append(train_loss)
        
        # Validation phase
        model.eval()
        val_loss = 0.0
        all_val_outputs = []
        all_val_targets = []
        
        with torch.no_grad():
            for inputs, targets in val_loader:
                inputs, targets = inputs.to(device), targets.to(device)
                
                # Forward pass
                outputs = model(inputs)
                loss = criterion(outputs, targets)
                
                val_loss += loss.item() * inputs.size(0)
                
                # Store outputs and targets for AUC calculation
                all_val_outputs.append(torch.sigmoid(outputs).cpu())
                all_val_targets.append(targets.cpu())
        
        # Calculate average validation loss
        val_loss = val_loss / len(val_loader.dataset)
        val_losses.append(val_loss)
        
        # Calculate validation AUC for monitoring
        val_outputs = torch.cat(all_val_outputs).numpy().flatten()
        val_targets = torch.cat(all_val_targets).numpy().flatten()
        try:
            from sklearn.metrics import roc_auc_score
            val_auc = roc_auc_score(val_targets, val_outputs)
            print(f"Validation AUC: {val_auc:.4f}")
        except:
            val_auc = 0
            print("Could not calculate validation AUC")
        
        # Print training progress
        print(f"Epoch {epoch+1}/{num_epochs}")
        print(f"Train Loss: {train_loss:.6f}, Val Loss: {val_loss:.6f}")
        
        # Learning rate scheduling
        scheduler.step()
        
        # Early stopping check
        if val_loss < best_val_loss:
            best_val_loss = val_loss
            early_stopping_counter = 0
            # Save the best model
            torch.save(model.state_dict(), 'best_transformer_classifier.pth')
            print(f"Saved new best model with val_loss: {val_loss:.6f}")
        else:
            early_stopping_counter += 1
            if early_stopping_counter >= patience:
                print(f"Early stopping after {epoch+1} epochs")
                break
    
    # Load the best model
    model.load_state_dict(torch.load('best_transformer_classifier.pth'))
    
    return model, train_losses, val_losses

In [8]:


# Import necessary functions from your original code
def load_data(time_series_path, outcomes_path):
    """Reuse the load_data function from your original code"""

    time_series = pd.read_csv(time_series_path)
    outcomes = pd.read_csv(outcomes_path)
    outcomes = outcomes.rename(columns={'RecordID': 'PatientID'})
    patient_outcomes = dict(zip(outcomes['PatientID'], outcomes['In-hospital_death']))
    time_series['In-hospital_death'] = time_series['PatientID'].map(patient_outcomes)
    time_series = time_series.dropna(subset=['In-hospital_death'])
    return time_series

In [9]:
# 1. Load and merge the data with corresponding outcomes (reuse from LSTM code)
print("Loading and merging data...")
# Reuse the load_data function from your original code
train_data = load_data(train_data_path, train_outcomes_path)
val_data = load_data(val_data_path, val_outcomes_path)
test_data = load_data(test_data_path, test_outcomes_path)

Loading and merging data...


In [10]:



def prepare_patient_sequences(data, seq_length=49):
    """Reuse the prepare_patient_sequences function from your original code"""

    # Get all unique patient IDs
    patient_ids = data['PatientID'].unique()
    
    # Define features (all columns except categorical and outcome columns)
    exclude_cols = ['PatientID', 'Hours', 'Gender', 'ICUType', 'In-hospital_death']
    feature_cols = [col for col in data.columns if col not in exclude_cols]
    
    sequences = []
    targets = []
    seq_patient_ids = []
    
    # Create one sequence per patient using the first seq_length hours
    for patient_id in patient_ids:
        # Get data for this patient
        patient_data = data[data['PatientID'] == patient_id].sort_values('Hours')
        
        # Get the outcome for this patient
        outcome = patient_data['In-hospital_death'].iloc[0]
        
        # Skip patients with insufficient data
        if len(patient_data) < seq_length:
            continue
        
        # Extract features for the first seq_length hours
        X_patient = patient_data[feature_cols].values[:seq_length]
        
        # If less than seq_length hours, pad with zeros
        if len(X_patient) < seq_length:
            print("padding..")
            padding = np.zeros((seq_length - len(X_patient), len(feature_cols)))
            X_patient = np.concatenate([X_patient, padding])
        
        sequences.append(X_patient)
        targets.append(outcome)
        seq_patient_ids.append(patient_id)
    
    return np.array(sequences), np.array(targets), np.array(seq_patient_ids)


In [11]:
# 2. Prepare sequences (reuse the prepare_patient_sequences function)
print("Preparing sequences...")
X_train, y_train, train_patient_ids = prepare_patient_sequences(train_data)
X_val, y_val, val_patient_ids = prepare_patient_sequences(val_data)
X_test, y_test, test_patient_ids = prepare_patient_sequences(test_data)

Preparing sequences...


In [16]:
def evaluate_model_simple(model, test_loader, threshold=0.5, save_figures=True, figure_prefix='model'):
    """
    Simplified evaluation function that calculates AUROC, AUPRC, and generates
    a confusion matrix heatmap.
    
    Parameters:
    -----------
    model : nn.Module
        The trained model
    test_loader : DataLoader
        DataLoader for test data
    threshold : float, default=0.5
        Threshold for binary classification
    save_figures : bool, default=True
        Whether to save the figures to disk
    figure_prefix : str, default='model'
        Prefix for saved figure filenames
        
    Returns:
    --------
    tuple
        (metrics_dict, raw_predictions, binary_predictions, true_values)
    """
    # Use GPU if available
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    model.to(device)
    
    # Evaluation mode
    model.eval()
    
    # For storing predictions and true values
    all_preds = []
    all_targets = []
    
    # No gradient computation for evaluation
    with torch.no_grad():
        for inputs, targets in tqdm(test_loader, desc="Evaluating"):
            inputs, targets = inputs.to(device), targets.to(device)
            
            # Forward pass
            outputs = model(inputs)
            probs = torch.sigmoid(outputs)
            
            # Store predictions and targets
            all_preds.append(probs.cpu().numpy())
            all_targets.append(targets.cpu().numpy())
    
    # Concatenate batches
    predictions = np.concatenate(all_preds).flatten()
    true_values = np.concatenate(all_targets).flatten()
    
    # Convert to binary predictions
    binary_predictions = (predictions >= threshold).astype(int)
    
    # Calculate AUROC
    try:
        auroc = roc_auc_score(true_values, predictions)
    except Exception as e:
        print(f"AUROC calculation failed: {str(e)}")
        auroc = np.nan
    
    # Calculate AUPRC
    try:
        auprc = average_precision_score(true_values, predictions)
    except Exception as e:
        print(f"AUPRC calculation failed: {str(e)}")
        auprc = np.nan
    
    # Calculate accuracy and F1 score
    accuracy = accuracy_score(true_values, binary_predictions)
    f1 = f1_score(true_values, binary_predictions, zero_division=0)
    
    # Calculate confusion matrix
    cm = confusion_matrix(true_values, binary_predictions)
    
    # Plot confusion matrix
    plt.figure(figsize=(8, 6))
    sns.heatmap(cm, annot=True, fmt='d', cmap='Blues')
    plt.xlabel('Predicted Label')
    plt.ylabel('True Label')
    plt.title('Confusion Matrix')
    if save_figures:
        plt.savefig(f'{figure_prefix}_confusion_matrix.png', dpi=300, bbox_inches='tight')
    
    # Create ROC and PR curves
    plt.figure(figsize=(16, 6))
    
    # ROC Curve
    plt.subplot(1, 2, 1)
    fpr, tpr, _ = roc_curve(true_values, predictions)
    plt.plot(fpr, tpr, 'b-', label=f'AUROC = {auroc:.3f}')
    plt.plot([0, 1], [0, 1], 'r--', label='Random')
    plt.xlabel('False Positive Rate')
    plt.ylabel('True Positive Rate')
    plt.title('ROC Curve')
    plt.legend()
    plt.grid(True)
    
    # Precision-Recall Curve
    plt.subplot(1, 2, 2)
    precision_curve, recall_curve, _ = precision_recall_curve(true_values, predictions)
    plt.plot(recall_curve, precision_curve, 'g-', label=f'AUPRC = {auprc:.3f}')
    plt.xlabel('Recall')
    plt.ylabel('Precision')
    plt.title('Precision-Recall Curve')
    plt.legend()
    plt.grid(True)
    
    plt.tight_layout()
    if save_figures:
        plt.savefig(f'{figure_prefix}_roc_pr_curves.png', dpi=300, bbox_inches='tight')
    
    plt.show()  # This will display figures in interactive environments like Jupyter
    
    # Create metrics dictionary
    metrics = {
        'AUC': auroc,
        'AUPRC': auprc,
        'Accuracy': accuracy,
        'F1 Score': f1,
        'Confusion Matrix': cm
    }
    
    return metrics, predictions, binary_predictions, true_values


In [None]:



def run_transformer_experiment(train_data_path, val_data_path, test_data_path, 
                              train_outcomes_path, val_outcomes_path, test_outcomes_path,
                              batch_size=1024, learning_rate=0.001, num_epochs=50):
    """
    Run a complete transformer experiment from data loading to evaluation.
    
    Parameters:
    -----------
    train_data_path : str
        Path to training time series data
    val_data_path : str
        Path to validation time series data
    test_data_path : str
        Path to test time series data
    train_outcomes_path : str
        Path to training outcomes data
    val_outcomes_path : str
        Path to validation outcomes data
    test_outcomes_path : str
        Path to test outcomes data
    batch_size : int, default=32
        Batch size for training (smaller than LSTM due to transformer's higher memory usage)
    learning_rate : float, default=0.001
        Learning rate for optimizer
    num_epochs : int, default=50
        Maximum number of training epochs
    
    Returns:
    --------
    model : nn.Module
        The trained transformer model
    metrics : dict
        Dictionary of evaluation metrics
    train_losses : list
        List of training losses for each epoch
    val_losses : list
        List of validation losses for each epoch
    predictions : numpy.ndarray
        Model predictions (probabilities)
    true_values : numpy.ndarray
        True target values
    """
    
    


    # 3. Convert to PyTorch tensors
    X_train_tensor = torch.FloatTensor(X_train)
    y_train_tensor = torch.FloatTensor(y_train.reshape(-1, 1))

     # 4. Calculate class weights using training data
    class_weights = calculate_class_weights(y_train, print_info=True)
    print("Class weights:", class_weights)
    
    X_val_tensor = torch.FloatTensor(X_val)
    y_val_tensor = torch.FloatTensor(y_val.reshape(-1, 1))
    
    X_test_tensor = torch.FloatTensor(X_test)
    y_test_tensor = torch.FloatTensor(y_test.reshape(-1, 1))
    
    # 4. Create DataLoaders
    train_dataset = TensorDataset(X_train_tensor, y_train_tensor)
    train_loader = DataLoader(
        train_dataset, 
        batch_size=batch_size,
        shuffle=True
    )

    for sample in train_loader:
        print(sample)
        break
    
    val_dataset = TensorDataset(X_val_tensor, y_val_tensor)
    val_loader = DataLoader(val_dataset, batch_size=batch_size)
    
    test_dataset = TensorDataset(X_test_tensor, y_test_tensor)
    test_loader = DataLoader(test_dataset, batch_size=batch_size)
    
    # 5. Create the transformer model
    input_dim = X_train.shape[2]  # Number of features
    output_dim = 1  # Binary classification
    
    print("Creating Transformer model...")
    model = TransformerClassifier(
        input_dim=input_dim,
        d_model=128,
        nhead=4,
        num_layers=3,
        dim_feedforward=512,
        dropout=0.3,
        output_dim=output_dim
    )
    
    # 6. Train the model
    print("Training model...")
    model, train_losses, val_losses = train_transformer_model(
        model=model,
        train_loader=train_loader,
        val_loader=val_loader,
        learning_rate=learning_rate,
        class_weights=class_weights,
        num_epochs=num_epochs
    )
    
    # 7. Evaluate on test set (reuse the evaluate_model function)
    print("Evaluating model on test set...")
    metrics, predictions, binary_predictions, true_values = evaluate_model_simple(model, test_loader)
    
    # 8. Print metrics
    print("\nTest Set Performance:")
    for metric, value in metrics.items():
        if metric != 'Confusion Matrix':
            print(f"{metric}: {value:.6f}")
    
    print("\nConfusion Matrix:")
    print(metrics['Confusion Matrix'])
    
    # 9. Plot training and validation loss
    import matplotlib.pyplot as plt
    plt.figure(figsize=(10, 5))
    plt.plot(train_losses, label='Training Loss')
    plt.plot(val_losses, label='Validation Loss')
    plt.xlabel('Epochs')
    plt.ylabel('Loss')
    plt.title('Transformer Model Training History')
    plt.legend()
    plt.savefig('transformer_training_history.png')
    
    # 10. Plot ROC curve if possible
    try:
        from sklearn.metrics import roc_curve
        fpr, tpr, _ = roc_curve(true_values, predictions)
        
        plt.figure(figsize=(10, 5))
        plt.plot(fpr, tpr, label=f'AUC = {metrics["AUC"]:.3f}')
        plt.plot([0, 1], [0, 1], 'k--')
        plt.xlabel('False Positive Rate')
        plt.ylabel('True Positive Rate')
        plt.title('Transformer ROC Curve')
        plt.legend()
        plt.savefig('transformer_roc_curve.png')
    except:
        print("Could not generate ROC curve, possibly due to having only one class in the test set.")
    
    return model, metrics, train_losses, val_losses, predictions, true_values

In [18]:
torch._dynamo.list_backends()

['cudagraphs', 'inductor', 'onnxrt', 'openxla', 'tvm']

In [None]:


# Example usage
if __name__ == "__main__":
    
    
    # Run transformer experiment
    print("\n=== Running Transformer Experiment ===")
    transformer_model, transformer_metrics, transformer_train_losses, transformer_val_losses, transformer_predictions, transformer_true_values = run_transformer_experiment(
        train_data_path=train_data_path,
        val_data_path=val_data_path,
        test_data_path=test_data_path,
        train_outcomes_path=train_outcomes_path,
        val_outcomes_path=val_outcomes_path,
        test_outcomes_path=test_outcomes_path,
        batch_size=1024,  # Smaller batch size due to transformer's higher memory usage
        learning_rate=1e-5,
        num_epochs=50
    )
    



=== Running Transformer Experiment ===
Class 0 (negative) count: 3446
Class 1 (positive) count: 554
Class 0 weight: 0.5803830623626709
Class 1 weight: 3.6101083755493164
Class weights: {0: 0.5803830623626709, 1: 3.6101083755493164}
[tensor([[[ 0.2267,  1.7918,  0.5306,  ...,  0.0000,  8.3303,  0.0000],
         [ 0.2267,  1.7918,  0.5306,  ...,  0.0000,  8.3303,  0.0000],
         [ 0.2267,  1.7918,  0.5306,  ...,  0.0000,  8.3303,  0.0000],
         ...,
         [ 0.2267,  1.7918,  0.5878,  ...,  0.0000, -2.6697,  0.0000],
         [ 0.2267,  1.7918,  0.5878,  ...,  0.0000,  6.3303,  0.0000],
         [ 0.2267,  1.7918,  0.5878,  ...,  0.0000,  4.3303,  0.0000]],

        [[ 0.6667,  3.2581,  0.5878,  ..., -0.9950,  0.0000,  0.0000],
         [ 0.6667,  3.2581,  0.5878,  ..., -0.9950,  0.0000,  0.0000],
         [ 0.6667,  3.2581,  0.5878,  ..., -0.9950,  0.0000,  0.0000],
         ...,
         [ 0.6667,  3.0445,  0.4055,  ..., -1.0350,  0.0000,  0.0000],
         [ 0.6667,  3.0445