In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, TensorDataset
import numpy as np
import pandas as pd
from tqdm import tqdm
from sklearn.metrics import confusion_matrix, roc_auc_score, accuracy_score, f1_score
from sklearn.metrics import roc_curve, precision_recall_curve, average_precision_score
import matplotlib.pyplot as plt
import seaborn as sns

In [2]:
class TripletEmbedding(nn.Module):
    """
    Turn a sequence of (t, z, v) observation triplets into one vector per step.

    Components of each triplet (last axis of the input):
    - index 0, t: observation time, passed through a 1 -> embedding_dim linear layer
    - index 1, z: variable identifier; it is cast to an integer and used as a row
      index into an embedding table with ``num_variables`` rows (it is NOT consumed
      as a one-hot vector)
    - index 2, v: observed value, passed through a 1 -> embedding_dim linear layer

    The three embeddings are concatenated and mixed back down to ``embedding_dim``
    by a final linear layer. Any scaling of t and v is the caller's responsibility;
    this module does not enforce it.
    """
    def __init__(self, num_variables=41, embedding_dim=64):
        super().__init__()
        self.num_variables = num_variables
        self.embedding_dim = embedding_dim

        # Creation order is kept stable so seeded initialisation is reproducible.
        self.time_projection = nn.Linear(1, embedding_dim)                       # t -> vector
        self.variable_embedding = nn.Embedding(num_variables, embedding_dim)     # z -> vector
        self.value_projection = nn.Linear(1, embedding_dim)                      # v -> vector
        self.combined_projection = nn.Linear(3 * embedding_dim, embedding_dim)   # fuse the three

    def forward(self, x):
        """
        Args:
            x: Tensor of shape [batch_size, seq_length, 3] holding (t, z, v) triplets.

        Returns:
            Tensor of shape [batch_size, seq_length, embedding_dim].
        """
        # Unpacking doubles as a rank check: anything that is not 3-D raises here.
        _n_batch, _n_steps, _n_feat = x.shape

        time_feat = x[..., 0, None]    # [batch_size, seq_length, 1]
        var_index = x[..., 1].long()   # [batch_size, seq_length], truncated to integer ids
        value_feat = x[..., 2, None]   # [batch_size, seq_length, 1]

        parts = (
            self.time_projection(time_feat),
            self.variable_embedding(var_index),
            self.value_projection(value_feat),
        )

        # [batch_size, seq_length, 3*embedding_dim] -> [batch_size, seq_length, embedding_dim]
        return self.combined_projection(torch.cat(parts, dim=2))


In [3]:
class PositionalEncoding(nn.Module):
    """
    Add fixed sinusoidal positional encodings to a batch-first sequence, then apply dropout.

    The table is precomputed once for ``max_len`` positions and stored as the
    non-trainable buffer ``pe`` with shape (1, max_len, d_model). Even columns hold
    sines and odd columns hold cosines of ``position * 10000^(-2k / d_model)``.

    Args:
        d_model: Width of the vectors being encoded. Both even and odd widths work.
        dropout: Dropout probability applied after the encoding is added.
        max_len: Number of positions to precompute; inputs may not be longer than this.
    """
    def __init__(self, d_model, dropout=0.1, max_len=5000):
        super(PositionalEncoding, self).__init__()
        self.dropout = nn.Dropout(p=dropout)

        # Create positional encoding matrix
        pe = torch.zeros(max_len, d_model)
        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-torch.log(torch.tensor(10000.0)) / d_model))
        angles = position * div_term  # (max_len, ceil(d_model / 2))
        pe[:, 0::2] = torch.sin(angles)
        # With an odd d_model there is one fewer cosine column than sine column,
        # so the last frequency is dropped for the cosine half.
        pe[:, 1::2] = torch.cos(angles[:, : d_model // 2])
        pe = pe.unsqueeze(0)  # (1, max_len, d_model)

        # Register as buffer (not a parameter but part of the module)
        self.register_buffer('pe', pe)

    def forward(self, x):
        """
        Args:
            x: Tensor of shape (batch_size, seq_length, d_model).

        Returns:
            Tensor of the same shape with positional encodings added and dropout applied.

        Raises:
            RuntimeError: If seq_length exceeds the ``max_len`` the table was built for.
        """
        seq_len = x.size(1)
        table_len = self.pe.size(1)
        if seq_len > table_len:
            # Fail with a readable message instead of a broadcast error (or, when
            # max_len == 1, a silent broadcast of position 0 over every step).
            raise RuntimeError(
                f"Sequence length {seq_len} exceeds positional encoding max_len {table_len}"
            )

        # Add positional encoding to input
        x = x + self.pe[:, :seq_len, :]
        return self.dropout(x)

In [None]:
class TripletTransformerClassifier(nn.Module):
    """
    Sequence classifier over (t, z, v) triplets.

    Pipeline: triplet embedding -> sinusoidal positional encoding -> stack of
    batch-first Transformer encoder layers -> mean over the sequence axis ->
    linear head producing ``output_dim`` raw scores (no sigmoid is applied here).

    Args:
        num_variables: Number of rows in the variable embedding table.
        d_model: Embedding / model width.
        nhead: Attention heads per encoder layer.
        num_layers: Number of stacked encoder layers.
        dim_feedforward: Hidden width of each encoder layer's feed-forward part.
        dropout: Dropout used by the positional encoding and the encoder layers.
        output_dim: Number of output scores per sequence.
        max_seq_length: Number of positions the positional-encoding table covers.
    """
    def __init__(self, num_variables=41, d_model=64, nhead=4, num_layers=2,
                 dim_feedforward=128, dropout=0.1, output_dim=1, max_seq_length=5000):
        super().__init__()

        # Input stages: per-step embedding, then position information.
        self.triplet_embedding = TripletEmbedding(num_variables, d_model)
        self.pos_encoder = PositionalEncoding(d_model, dropout, max_seq_length)

        # Encoder stack built from a single layer template.
        layer_config = dict(
            d_model=d_model,
            nhead=nhead,
            dim_feedforward=dim_feedforward,
            dropout=dropout,
            batch_first=True,
        )
        self.transformer_encoder = nn.TransformerEncoder(
            nn.TransformerEncoderLayer(**layer_config),
            num_layers=num_layers,
        )

        # Collapses the sequence axis to length 1 by averaging.
        self.global_avg_pool = nn.AdaptiveAvgPool1d(1)

        # Classification head.
        self.fc = nn.Linear(d_model, output_dim)

    def forward(self, x):
        """
        Args:
            x: Tensor of (t, z, v) triplets, [batch_size, seq_length, 3].

        Returns:
            Tensor of raw scores, [batch_size, output_dim].
        """
        encoded = self.transformer_encoder(self.pos_encoder(self.triplet_embedding(x)))

        # The pooling layer averages over the last axis, so move the sequence axis
        # there first: [batch, seq, d_model] -> [batch, d_model, seq] -> [batch, d_model].
        # NOTE(review): no padding mask is used anywhere, so every position,
        # including any zero-padded ones, contributes to attention and to the mean.
        pooled = self.global_avg_pool(encoded.transpose(1, 2)).squeeze(-1)

        return self.fc(pooled)

In [16]:
def load_triplet_data(triplet_data_path, outcomes_path):
    """
    Read a triplet CSV and its outcomes file.

    Args:
        triplet_data_path: Path to the CSV holding the triplet rows.
        outcomes_path: Path to the CSV-formatted outcomes file; it must contain
            the columns 'RecordID' and 'In-hospital_death'.

    Returns:
        Tuple ``(triplet_data, patient_outcomes)`` where ``triplet_data`` is the
        triplet DataFrame exactly as read, and ``patient_outcomes`` is a dict
        mapping each record id to its 'In-hospital_death' value.
    """
    triplets_df = pd.read_csv(triplet_data_path)

    # The outcomes file calls the key 'RecordID'; expose it as 'PatientID'.
    outcomes_df = pd.read_csv(outcomes_path).rename(columns={'RecordID': 'PatientID'})

    outcome_by_patient = dict(
        zip(outcomes_df['PatientID'], outcomes_df['In-hospital_death'])
    )

    return triplets_df, outcome_by_patient

In [17]:
def prepare_patient_triplet_sequences(data, patient_outcomes, max_seq_length=5000):
    """
    Build one fixed-length triplet sequence per patient.

    Rows are grouped by 'PatientID' in a single pass. Patients are emitted in
    order of first appearance and each patient's rows keep their original order.
    Sequences longer than ``max_seq_length`` keep only their first
    ``max_seq_length`` rows; shorter ones are padded at the end with all-zero
    rows. Padding rows are therefore indistinguishable from a genuine (0, 0, 0)
    triplet.

    Args:
        data: DataFrame with columns 'PatientID', 't', 'z', 'v'.
        patient_outcomes: Dictionary mapping PatientID to outcome. Patients
            missing from it are skipped.
        max_seq_length: Length every returned sequence is truncated/padded to.

    Returns:
        Tuple of (sequences, targets, patient_ids) as numpy arrays, where
        ``sequences`` has shape (n_patients, max_seq_length, 3). When no patient
        qualifies, ``sequences`` has shape (0, max_seq_length, 3) and the other
        two arrays are empty.
    """
    sequences = []
    targets = []
    seq_patient_ids = []

    # One groupby pass instead of re-filtering the whole frame for every patient.
    # sort=False keeps first-appearance order; rows with a missing PatientID are
    # dropped by groupby, and they could never match an outcome anyway.
    for patient_id, patient_rows in data.groupby('PatientID', sort=False):
        # Skip patients not in outcomes
        if patient_id not in patient_outcomes:
            continue

        # Extract triplets (t, z, v), keeping at most max_seq_length rows
        triplets = patient_rows[['t', 'z', 'v']].values[:max_seq_length]

        # If less than max_seq_length triplets, pad with zeros
        n_missing = max_seq_length - len(triplets)
        if n_missing > 0:
            triplets = np.concatenate([triplets, np.zeros((n_missing, 3))])

        sequences.append(triplets)
        targets.append(patient_outcomes[patient_id])
        seq_patient_ids.append(patient_id)

    if not sequences:
        # Keep the documented 3-D layout even when nothing matched, so that
        # downstream tensor conversion sees a consistent shape.
        return np.empty((0, max_seq_length, 3)), np.array(targets), np.array(seq_patient_ids)

    return np.array(sequences), np.array(targets), np.array(seq_patient_ids)

In [18]:

def train_transformer_model(model, train_loader, val_loader, learning_rate=0.0005, num_epochs=50, patience=5,
                            pos_weight=5.0, checkpoint_path='best_triplet_transformer.pth'):
    """
    Train the Transformer model for binary classification.

    Uses BCE-with-logits loss with a positive-class weight, Adam with weight
    decay, a step learning-rate schedule, gradient-norm clipping, and early
    stopping on validation loss. The weights of the best epoch are kept in
    memory, written to ``checkpoint_path``, and restored into ``model`` before
    returning.

    Parameters:
    -----------
    model : nn.Module
        The transformer model; must return raw logits shaped like the targets
    train_loader : DataLoader
        DataLoader for training data
    val_loader : DataLoader
        DataLoader for validation data
    learning_rate : float, default=0.0005
        Learning rate for optimizer
    num_epochs : int, default=50
        Maximum number of training epochs
    patience : int, default=5
        Number of epochs with no improvement after which training will be stopped
    pos_weight : float, default=5.0
        Weight applied to the positive class in the loss; adjust to the class
        distribution of the training data
    checkpoint_path : str, default='best_triplet_transformer.pth'
        File the best-epoch state dict is written to

    Returns:
    --------
    model : nn.Module
        The trained transformer model, carrying the best-epoch weights when at
        least one epoch improved on the initial (infinite) validation loss
    train_losses : list
        List of training losses for each epoch
    val_losses : list
        List of validation losses for each epoch
    """
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    model.to(device)

    # Define loss function and optimizer
    criterion = nn.BCEWithLogitsLoss(pos_weight=torch.tensor([float(pos_weight)], device=device))
    optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate, weight_decay=1e-3)

    # Learning rate scheduler
    scheduler = torch.optim.lr_scheduler.StepLR(
        optimizer,
        step_size=10,     # Reduce LR every 10 epochs
        gamma=0.5         # Multiply LR by 0.5 at each step
    )

    # For tracking training progress
    train_losses = []
    val_losses = []
    best_val_loss = float('inf')
    best_state = None  # in-memory copy of the best weights seen in THIS run
    early_stopping_counter = 0

    # Training loop
    for epoch in range(num_epochs):
        # Training phase
        model.train()
        train_loss = 0.0

        for inputs, targets in train_loader:
            inputs, targets = inputs.to(device), targets.to(device)

            # Forward pass
            outputs = model(inputs)
            loss = criterion(outputs, targets)

            # Backward pass and optimize
            optimizer.zero_grad()
            loss.backward()
            # Clip gradients to prevent exploding gradients (common in transformers)
            torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
            optimizer.step()

            # Weight by batch size so the epoch average is per-sample
            train_loss += loss.item() * inputs.size(0)

        # Calculate average training loss
        train_loss = train_loss / len(train_loader.dataset)
        train_losses.append(train_loss)

        # Validation phase
        model.eval()
        val_loss = 0.0
        all_val_outputs = []
        all_val_targets = []

        with torch.no_grad():
            for inputs, targets in val_loader:
                inputs, targets = inputs.to(device), targets.to(device)

                # Forward pass
                outputs = model(inputs)
                loss = criterion(outputs, targets)

                val_loss += loss.item() * inputs.size(0)

                # Store outputs and targets for AUC calculation
                all_val_outputs.append(torch.sigmoid(outputs).cpu())
                all_val_targets.append(targets.cpu())

        # Calculate average validation loss
        val_loss = val_loss / len(val_loader.dataset)
        val_losses.append(val_loss)

        # Print training progress
        print(f"Epoch {epoch+1}/{num_epochs}")
        print(f"Train Loss: {train_loss:.6f}, Val Loss: {val_loss:.6f}")

        # Calculate validation AUC for monitoring only; it does not drive early stopping
        val_outputs = torch.cat(all_val_outputs).numpy().flatten()
        val_targets = torch.cat(all_val_targets).numpy().flatten()
        try:
            val_auc = roc_auc_score(val_targets, val_outputs)
            print(f"Validation AUC: {val_auc:.4f}")
        except ValueError as err:
            # Raised e.g. when the scores contain NaN or the labels are degenerate;
            # anything else is a real bug and should surface.
            print(f"Could not calculate validation AUC: {err}")

        # Learning rate scheduling
        scheduler.step()

        # Early stopping check
        if val_loss < best_val_loss:
            best_val_loss = val_loss
            early_stopping_counter = 0
            # Snapshot on CPU so the copy is independent of further training
            best_state = {name: tensor.detach().cpu().clone()
                          for name, tensor in model.state_dict().items()}
            # Save the best model
            torch.save(best_state, checkpoint_path)
            print(f"Saved new best model with val_loss: {val_loss:.6f}")
        else:
            early_stopping_counter += 1
            if early_stopping_counter >= patience:
                print(f"Early stopping after {epoch+1} epochs")
                break

    # Restore the best weights from memory rather than from disk: a file at
    # checkpoint_path may be missing or left over from an earlier run when no
    # epoch of this run improved (e.g. NaN losses or num_epochs=0).
    if best_state is not None:
        model.load_state_dict(best_state)
    else:
        print("No epoch improved the validation loss; returning the model without restoring a checkpoint")

    return model, train_losses, val_losses

In [19]:

def evaluate_model_simple(model, test_loader, threshold=0.5, save_figures=True, figure_prefix='triplet_model'):
    """
    Simplified evaluation function that calculates AUROC, AUPRC, accuracy and F1,
    and draws a confusion matrix heatmap plus ROC and precision-recall curves.

    Parameters:
    -----------
    model : nn.Module
        The trained model; its outputs are treated as logits and passed through a sigmoid
    test_loader : DataLoader
        DataLoader for test data
    threshold : float, default=0.5
        Probabilities greater than or equal to this value are predicted as class 1
    save_figures : bool, default=True
        Whether to save the figures to disk
    figure_prefix : str, default='triplet_model'
        Prefix for saved figure filenames

    Returns:
    --------
    tuple
        (metrics_dict, raw_predictions, binary_predictions, true_values).
        metrics_dict has the keys 'AUC', 'AUPRC', 'Accuracy', 'F1 Score' and
        'Confusion Matrix'; the confusion matrix is always 2x2 with rows/columns
        ordered as labels [0, 1].
    """
    # Use GPU if available
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    model.to(device)

    # Evaluation mode
    model.eval()

    # For storing predictions and true values
    all_preds = []
    all_targets = []

    # No gradient computation for evaluation
    with torch.no_grad():
        for inputs, targets in tqdm(test_loader, desc="Evaluating"):
            inputs, targets = inputs.to(device), targets.to(device)

            # Forward pass
            outputs = model(inputs)
            probs = torch.sigmoid(outputs)

            # Store predictions and targets
            all_preds.append(probs.cpu().numpy())
            all_targets.append(targets.cpu().numpy())

    # Concatenate batches
    predictions = np.concatenate(all_preds).flatten()
    true_values = np.concatenate(all_targets).flatten()

    # Convert to binary predictions
    binary_predictions = (predictions >= threshold).astype(int)

    # Calculate AUROC
    try:
        auroc = roc_auc_score(true_values, predictions)
    except Exception as e:
        print(f"AUROC calculation failed: {str(e)}")
        auroc = np.nan

    # Calculate AUPRC
    try:
        auprc = average_precision_score(true_values, predictions)
    except Exception as e:
        print(f"AUPRC calculation failed: {str(e)}")
        auprc = np.nan

    # Calculate accuracy and F1 score
    accuracy = accuracy_score(true_values, binary_predictions)
    f1 = f1_score(true_values, binary_predictions, zero_division=0)

    # Fix the label set so the matrix stays 2x2 even when only one class
    # shows up in both the targets and the thresholded predictions.
    cm = confusion_matrix(true_values, binary_predictions, labels=[0, 1])

    # Plot confusion matrix
    cm_fig, cm_ax = plt.subplots(figsize=(8, 6))
    sns.heatmap(cm, annot=True, fmt='d', cmap='Blues', ax=cm_ax)
    cm_ax.set_xlabel('Predicted Label')
    cm_ax.set_ylabel('True Label')
    cm_ax.set_title('Confusion Matrix')
    if save_figures:
        cm_fig.savefig(f'{figure_prefix}_confusion_matrix.png', dpi=300, bbox_inches='tight')

    # Create ROC and PR curves side by side
    curves_fig, (roc_ax, pr_ax) = plt.subplots(1, 2, figsize=(16, 6))

    # ROC Curve
    fpr, tpr, _ = roc_curve(true_values, predictions)
    roc_ax.plot(fpr, tpr, 'b-', label=f'AUROC = {auroc:.3f}')
    roc_ax.plot([0, 1], [0, 1], 'r--', label='Random')
    roc_ax.set_xlabel('False Positive Rate')
    roc_ax.set_ylabel('True Positive Rate')
    roc_ax.set_title('ROC Curve')
    roc_ax.legend()
    roc_ax.grid(True)

    # Precision-Recall Curve
    precision_curve, recall_curve, _ = precision_recall_curve(true_values, predictions)
    pr_ax.plot(recall_curve, precision_curve, 'g-', label=f'AUPRC = {auprc:.3f}')
    pr_ax.set_xlabel('Recall')
    pr_ax.set_ylabel('Precision')
    pr_ax.set_title('Precision-Recall Curve')
    pr_ax.legend()
    pr_ax.grid(True)

    curves_fig.tight_layout()
    if save_figures:
        curves_fig.savefig(f'{figure_prefix}_roc_pr_curves.png', dpi=300, bbox_inches='tight')

    plt.show()  # This will display figures in interactive environments like Jupyter

    # Create metrics dictionary
    metrics = {
        'AUC': auroc,
        'AUPRC': auprc,
        'Accuracy': accuracy,
        'F1 Score': f1,
        'Confusion Matrix': cm
    }

    return metrics, predictions, binary_predictions, true_values





In [20]:
def run_triplet_transformer_experiment(train_data_path, val_data_path, test_data_path, 
                                      train_outcomes_path, val_outcomes_path, test_outcomes_path,
                                      batch_size=32, learning_rate=0.001, num_epochs=50, 
                                      max_seq_length=5000, num_variables=41):
    """
    End-to-end experiment: load the three splits, build fixed-length sequences,
    train a TripletTransformerClassifier, evaluate it on the test split, print
    the metrics and save a training-history plot.

    Parameters:
    -----------
    train_data_path, val_data_path, test_data_path : str
        Paths to the triplet data of the training, validation and test splits
    train_outcomes_path, val_outcomes_path, test_outcomes_path : str
        Paths to the outcomes files of the training, validation and test splits
    batch_size : int, default=32
        Batch size used by all three DataLoaders
    learning_rate : float, default=0.001
        Learning rate for optimizer
    num_epochs : int, default=50
        Maximum number of training epochs
    max_seq_length : int, default=5000
        Length every patient sequence is truncated/padded to
    num_variables : int, default=41
        Number of different variables in the dataset

    Returns:
    --------
    model : nn.Module
        The trained transformer model
    metrics : dict
        Dictionary of evaluation metrics
    train_losses : list
        List of training losses for each epoch
    val_losses : list
        List of validation losses for each epoch
    predictions : numpy.ndarray
        Model predictions (probabilities)
    true_values : numpy.ndarray
        True target values
    """
    split_names = ('train', 'val', 'test')
    source_paths = {
        'train': (train_data_path, train_outcomes_path),
        'val': (val_data_path, val_outcomes_path),
        'test': (test_data_path, test_outcomes_path),
    }

    # 1. Read every split (triplet frame + outcome lookup), train first
    print("Loading data...")
    loaded = {name: load_triplet_data(*source_paths[name]) for name in split_names}

    # 2. Turn each split into padded/truncated per-patient sequences
    print("Preparing sequences...")
    arrays = {}
    for name in split_names:
        frame, outcome_map = loaded[name]
        features, labels, _patient_ids = prepare_patient_triplet_sequences(
            frame, outcome_map, max_seq_length)
        arrays[name] = (features, labels)

    # 3. Wrap as float tensors; labels become a column so they line up with the logits
    datasets = {}
    for name in split_names:
        features, labels = arrays[name]
        feature_tensor = torch.FloatTensor(features)
        label_tensor = torch.FloatTensor(labels.reshape(-1, 1))
        datasets[name] = TensorDataset(feature_tensor, label_tensor)

    # 4. Only the training loader shuffles
    train_loader = DataLoader(datasets['train'], batch_size=batch_size, shuffle=True)
    val_loader = DataLoader(datasets['val'], batch_size=batch_size)
    test_loader = DataLoader(datasets['test'], batch_size=batch_size)

    # 5. Build the classifier
    print("Creating Triplet Transformer model...")
    model_config = dict(
        num_variables=num_variables,
        d_model=64,
        nhead=4,
        num_layers=2,
        dim_feedforward=128,
        dropout=0.3,
        output_dim=1,
        max_seq_length=max_seq_length,
    )
    model = TripletTransformerClassifier(**model_config)

    # 6. Fit with early stopping on the validation split
    print("Training model...")
    model, train_losses, val_losses = train_transformer_model(
        model=model,
        train_loader=train_loader,
        val_loader=val_loader,
        learning_rate=learning_rate,
        num_epochs=num_epochs,
    )

    # 7. Score the held-out split
    print("Evaluating model on test set...")
    metrics, predictions, binary_predictions, true_values = evaluate_model_simple(
        model, test_loader, figure_prefix='triplet_transformer')

    # 8. Report scalar metrics first, then the matrix on its own
    print("\nTest Set Performance:")
    scalar_metrics = [(key, val) for key, val in metrics.items() if key != 'Confusion Matrix']
    for metric, value in scalar_metrics:
        print(f"{metric}: {value:.6f}")

    print("\nConfusion Matrix:")
    print(metrics['Confusion Matrix'])

    # 9. Loss curves over epochs
    plt.figure(figsize=(10, 5))
    for series, label in ((train_losses, 'Training Loss'), (val_losses, 'Validation Loss')):
        plt.plot(series, label=label)
    plt.xlabel('Epochs')
    plt.ylabel('Loss')
    plt.title('Triplet Transformer Model Training History')
    plt.legend()
    plt.savefig('triplet_transformer_training_history.png')

    return model, metrics, train_losses, val_losses, predictions, true_values

In [21]:
# Example usage
if __name__ == "__main__":
    import os

    # Input locations default to the original author's directories but can be
    # redirected without editing the notebook by setting these environment variables.
    triplet_dir = os.environ.get("TRIPLET_DATA_DIR", "/home/taekim/enhanced_preprocessed_data")
    outcomes_dir = os.environ.get("TRIPLET_OUTCOMES_DIR", "/home/taekim/ml4h_data/p1")

    # Define data paths
    train_data_path = os.path.join(triplet_dir, "patient_triplets_set-a.csv")
    val_data_path = os.path.join(triplet_dir, "patient_triplets_set-b.csv")
    test_data_path = os.path.join(triplet_dir, "patient_triplets_set-c.csv")

    train_outcomes_path = os.path.join(outcomes_dir, "Outcomes-a.txt")
    val_outcomes_path = os.path.join(outcomes_dir, "Outcomes-b.txt")
    test_outcomes_path = os.path.join(outcomes_dir, "Outcomes-c.txt")

    # Fail fast: report every missing input up front instead of discovering a
    # bad path only after the earlier splits have already been read.
    required_paths = (
        train_data_path, val_data_path, test_data_path,
        train_outcomes_path, val_outcomes_path, test_outcomes_path,
    )
    missing_paths = [path for path in required_paths if not os.path.isfile(path)]
    if missing_paths:
        raise FileNotFoundError("Missing input files:\n" + "\n".join(missing_paths))

    # Run experiment
    model, metrics, train_losses, val_losses, predictions, true_values = run_triplet_transformer_experiment(
        train_data_path=train_data_path,
        val_data_path=val_data_path,
        test_data_path=test_data_path,
        train_outcomes_path=train_outcomes_path,
        val_outcomes_path=val_outcomes_path,
        test_outcomes_path=test_outcomes_path,
        batch_size=32,
        learning_rate=0.001,
        num_epochs=50,
        max_seq_length=5000,
        num_variables=41
    )

Loading data...


KeyboardInterrupt: 