In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np
import pandas as pd
from sklearn.model_selection import StratifiedKFold, train_test_split
from sklearn.preprocessing import StandardScaler
import warnings
warnings.filterwarnings('ignore')

In [None]:
# Read the cohort table that every later cell uses as `data`.
data = pd.read_csv(filepath_or_buffer='../../data/strict_cohort2/all_data.csv')

In [None]:
# Pick the compute device once: CUDA when PyTorch reports it as available, CPU otherwise.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')
print(f"Using device: {device}")

class DiscreteTimeNN(torch.nn.Module):
    """Fully connected network producing a probability vector over time bins.

    The network stacks one lazily-shaped linear layer per entry of
    ``hidden_layer_sizes`` (each followed by ReLU) and finishes with a linear
    head of width ``num_bins + 1`` passed through a softmax over the last
    dimension, so every output row is non-negative and sums to one.

    Args:
        hidden_layer_sizes: iterable of output widths for the hidden layers.
        num_bins: number of time bins; the head emits one extra column.
    """

    def __init__(self, hidden_layer_sizes, num_bins):
        super().__init__()

        # Input widths are inferred on the first forward pass (LazyLinear).
        hidden_stack = nn.ModuleList()
        for width in hidden_layer_sizes:
            hidden_stack.append(nn.LazyLinear(width))
        self.encoder_layers = hidden_stack

        self.activation = nn.ReLU()
        self.prediction_head = nn.LazyLinear(num_bins + 1)
        self.softmax = nn.Softmax(dim=-1)

    def forward(self, x):
        """Return softmax probabilities of shape ``(..., num_bins + 1)`` for input ``x``."""
        hidden_state = x
        for dense in self.encoder_layers:
            hidden_state = self.activation(dense(hidden_state))

        logits = self.prediction_head(hidden_state)
        return self.softmax(logits)

class DiscreteFailureTimeNLL(torch.nn.Module):
    """Negative log-likelihood for discrete-time failure predictions.

    ``predictions`` is expected to have one column per time bin followed by a
    final column (the mass left after the last bin boundary). For rows with
    ``event_indicators == 1`` the likelihood is the probability of the bin
    containing the event time; for rows with ``event_indicators == 0`` it is
    one minus the probability accumulated up to the time, with the bin
    containing the time counted proportionally to how much of it has elapsed.

    Args:
        bin_boundaries: increasing sequence of ``num_bins + 1`` bin edges.
        tolerance: small constant added to each likelihood before ``log``.
    """
    
    def __init__(self, bin_boundaries, tolerance=1e-8):
        super(DiscreteFailureTimeNLL, self).__init__()
        
        # Register as buffers so they move with the module
        self.register_buffer('bin_starts', torch.tensor(bin_boundaries[:-1], dtype=torch.float32))
        self.register_buffer('bin_ends', torch.tensor(bin_boundaries[1:], dtype=torch.float32))
        self.register_buffer('bin_lengths', self.bin_ends - self.bin_starts)
        
        self.tolerance = tolerance
    
    def _discretize_times(self, times):
        """Return a boolean ``(len(times), num_bins)`` mask marking each time's bin.

        Bins are half-open ``(start, end]``; the first bin is additionally
        closed on the left so a time equal to the first boundary is not lost.
        Times outside the boundaries match no bin.
        """
        column_times = times[:, None]
        membership = (
            (column_times > self.bin_starts[None, :])
            & (column_times <= self.bin_ends[None, :])
        )
        if self.bin_starts.numel() > 0:
            # A time exactly on the first boundary belongs to the first bin.
            membership[:, 0] |= times == self.bin_starts[0]
        return membership
    
    def _get_proportion_of_bins_completed(self, times):
        """Return, per time and bin, the elapsed fraction of that bin in ``[0, 1]``."""
        return torch.clamp(
            (times[:, None] - self.bin_starts[None, :]) / self.bin_lengths[None, :],
            min=0.0,
            max=1.0
        )
    
    def forward(self, predictions, event_indicators, event_times):
        """Return the mean negative log-likelihood over the batch.

        Args:
            predictions: ``(batch, num_bins + 1)`` probabilities.
            event_indicators: ``(batch,)`` tensor, 1 for an observed event, 0 otherwise.
            event_times: ``(batch,)`` tensor of event or censoring times.
        """
        # Ensure input tensors are on the same device as the model
        event_indicators = event_indicators.to(predictions.device)
        event_times = event_times.to(predictions.device)
        
        bin_probs = predictions[:, :-1]
        
        event_likelihood = torch.sum(
            self._discretize_times(event_times) * bin_probs,
            dim=-1
        )
        if self.bin_ends.numel() > 0:
            # An event later than the last boundary falls in the tail column;
            # without this it would receive only `tolerance` as likelihood.
            beyond_last = event_times > self.bin_ends[-1]
            event_likelihood = event_likelihood + beyond_last * predictions[:, -1]
        event_likelihood = event_likelihood + self.tolerance
        
        survival = 1 - torch.sum(
            self._get_proportion_of_bins_completed(event_times) * bin_probs,
            dim=-1
        )
        # Floating-point rounding can push the survival slightly below zero,
        # which would make the log below return NaN.
        nonevent_likelihood = torch.clamp(survival, min=0.0) + self.tolerance
        
        log_likelihood = (
            event_indicators * torch.log(event_likelihood)
            + (1 - event_indicators) * torch.log(nonevent_likelihood)
        )
        
        return -1. * torch.mean(log_likelihood)

def preprocess_data(data, 
                   numerical_cols=['age', 'eGFR', 'sbp', 'bmi', 'tc', 'hdlc'],
                   categorical_cols=['diabetes', 'smoker', 'antihtn', 'statin'],
                   scaler=None):
    """
    Scale the numerical columns and cast the categorical columns to float.

    Args:
        data: DataFrame to process; it is copied, never modified.
        numerical_cols: columns passed through the scaler.
        categorical_cols: columns cast to float when present in ``data``.
        scaler: an already-fitted scaler to apply, or None to fit a new
            StandardScaler on ``data`` itself.

    Returns:
        dict with 'data' (the processed copy) and 'scaler' (the scaler used).
    """
    frame = data.copy()
    numeric_block = frame[numerical_cols]

    if scaler is not None:
        # A fitted scaler was supplied: apply it without refitting
        scaled_values = scaler.transform(numeric_block)
    else:
        # No scaler supplied: fit a fresh one on this data
        scaler = StandardScaler()
        scaled_values = scaler.fit_transform(numeric_block)
    frame[numerical_cols] = scaled_values

    # Only categorical columns that actually exist are converted
    for name in (c for c in categorical_cols if c in frame.columns):
        frame[name] = frame[name].astype(float)

    return {'data': frame, 'scaler': scaler}

def get_batches(*arrs, batch_size=1):
    """Yield consecutive mini-batches as lists of float32 tensors.

    Every array in ``arrs`` is cut at the same positions, taken from the
    length of the first array; the last batch may be shorter. Order is
    preserved (no shuffling).
    """
    n_rows = len(arrs[0])
    for start in range(0, n_rows, batch_size):
        window = slice(start, min(start + batch_size, n_rows))
        batch = []
        for arr in arrs:
            batch.append(torch.tensor(arr[window], dtype=torch.float32))
        yield batch

def create_time_bins(times, num_bins=10):
    """Return ``num_bins + 1`` equally spaced bin edges from 0 to the largest time."""
    latest_time = np.max(times)
    edge_count = num_bins + 1
    return np.linspace(0, latest_time, edge_count)

def perform_cv_discrete_nn_with_validation(data, 
                                         event_var, 
                                         time_var,
                                         prediction_times=[5, 10],
                                         folds=5,
                                         hidden_layers=[64, 32],
                                         num_bins=20,
                                         epochs=200,  # Increased since we'll use early stopping
                                         learning_rate=0.001,
                                         batch_size=100,
                                         stratify_var="sex",
                                         validation_split=0.1,  # 10% for validation
                                         patience=15,           # Early stopping patience
                                         min_delta=1e-4):       # Minimum improvement threshold
    """
    Cross-validated event-probability predictions from a discrete-time network.

    For every group defined by ``stratify_var`` (or for all rows together when
    that column is unavailable), the rows are split into ``folds`` folds
    stratified on ``event_var``. Each fold trains a fresh ``DiscreteTimeNN``
    on the remaining rows, optionally holding out part of them for early
    stopping, and writes the predicted cumulative event probability of the
    held-out fold into the result vectors.

    Args:
        data: DataFrame holding the feature columns, ``event_var`` and ``time_var``.
        event_var: name of the event-indicator column.
        time_var: name of the time-to-event column.
        prediction_times: horizons at which the cumulative event probability
            is reported.
        folds: number of cross-validation folds per group.
        hidden_layers: hidden layer widths passed to ``DiscreteTimeNN``.
        num_bins: number of equal-width time bins between 0 and the largest
            training time.
        epochs: maximum number of training epochs per fold.
        learning_rate: Adam learning rate.
        batch_size: mini-batch size for training and validation.
        stratify_var: column whose values define separately modelled groups.
            The progress message labels value 1 as Female and any other value
            as Male. If falsy or not a column of ``data``, all rows form a
            single group.
        validation_split: Fraction of training data to use for validation;
            0 disables validation and early stopping.
        patience: Number of epochs to wait for improvement before stopping
        min_delta: Minimum change to qualify as an improvement

    Returns:
        dict mapping ``"year_<t>"`` to a NumPy array of length ``len(data)``
        with the out-of-fold event probability by time ``t`` for each row
        (row order of ``data``).
    """
    # Set seed for reproducibility
    torch.manual_seed(123)
    np.random.seed(123)
    
    # The device does not change between folds, so resolve it once
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    
    # One out-of-fold prediction vector per horizon, aligned with the rows of `data`
    results = {f"year_{pred_time}": np.zeros(len(data)) for pred_time in prediction_times}
    
    # Define feature columns
    feature_cols = ['age', 'eGFR', 'sbp', 'bmi', 'tc', 'hdlc', 
                   'diabetes', 'smoker', 'antihtn', 'statin']
    
    # Build (group value, positional row indices) pairs. Without a usable
    # stratification column all rows form one group; previously this case
    # skipped training entirely and returned all-zero predictions.
    stratified = bool(stratify_var) and stratify_var in data.columns
    if stratified:
        group_values = data[stratify_var]
        groups = [
            (group_val, np.where(group_values == group_val)[0])
            for group_val in group_values.unique()
        ]
    else:
        print("No stratification column available; running CV on all rows together")
        groups = [(None, np.arange(len(data)))]
    
    for sex_val, sex_indices in groups:
        if stratified:
            print(f"Processing sex group: {sex_val} ({'Female' if sex_val == 1 else 'Male'})")
            fold_tag, epoch_tag = f"sex {sex_val}", f"Sex {sex_val}"
        else:
            fold_tag, epoch_tag = "all rows", "All rows"
        
        # Subset data for the current group
        sex_data = data.iloc[sex_indices].copy()
        
        # Create stratified folds within this group
        skf = StratifiedKFold(n_splits=folds, shuffle=True, random_state=123)
        
        for fold_idx, (train_idx, test_idx) in enumerate(skf.split(sex_data, sex_data[event_var])):
            print(f"  Processing {fold_tag} fold {fold_idx + 1} of {folds}")
            
            # Get train/test data
            train_data = sex_data.iloc[train_idx].copy()
            test_data = sex_data.iloc[test_idx].copy()
            
            # Split training data into train/validation
            if validation_split > 0:
                train_train_idx, train_val_idx = train_test_split(
                    range(len(train_data)),
                    test_size=validation_split,
                    stratify=train_data[event_var],
                    random_state=123 + fold_idx
                )
                
                val_data = train_data.iloc[train_val_idx].copy()
                train_data = train_data.iloc[train_train_idx].copy()
                
                print(f"    Train: {len(train_data)}, Val: {len(val_data)}, Test: {len(test_data)}")
            
            # Convert local test indices to global (positional) indices
            test_index_global = sex_indices[test_idx]
            
            # Fit the scaler on the training rows only, then reuse it
            train_processed = preprocess_data(train_data, scaler=None)
            train_features = train_processed['data'][feature_cols]
            scaler = train_processed['scaler']
            
            # Process validation data
            if validation_split > 0:
                val_processed = preprocess_data(val_data, scaler=scaler)
                val_features_np = val_processed['data'][feature_cols].values
                val_events_np = val_data[event_var].values
                val_times_np = val_data[time_var].values
            
            # Process test data
            test_processed = preprocess_data(test_data, scaler=scaler)
            test_features_np = test_processed['data'][feature_cols].values
            
            # Create time bins based on training data
            bin_boundaries = create_time_bins(train_data[time_var].values, num_bins)
            
            # Convert to numpy arrays
            train_features_np = train_features.values
            train_events_np = train_data[event_var].values
            train_times_np = train_data[time_var].values
            
            # Initialize model and move to device
            model = DiscreteTimeNN(hidden_layers, num_bins).to(device)
            # Dry run so the LazyLinear parameters are materialised before the
            # optimizer receives them (recommended by the PyTorch lazy-module docs).
            with torch.no_grad():
                model(torch.zeros((1, train_features_np.shape[1]), dtype=torch.float32, device=device))
            loss_fn = DiscreteFailureTimeNLL(bin_boundaries).to(device)
            optimizer = optim.Adam(model.parameters(), lr=learning_rate)
            
            # Early stopping variables
            best_val_loss = float('inf')
            best_model_state = None
            epochs_without_improvement = 0
            
            for epoch in range(epochs):
                # Training phase
                model.train()
                epoch_train_losses = []
                
                for batch_X, batch_s, batch_t in get_batches(train_features_np, train_events_np, train_times_np, batch_size=batch_size):
                    batch_X = batch_X.to(device)
                    batch_s = batch_s.to(device)
                    batch_t = batch_t.to(device)
                    
                    optimizer.zero_grad()
                    predictions = model(batch_X)
                    loss = loss_fn(predictions, batch_s, batch_t)
                    loss.backward()
                    optimizer.step()
                    
                    epoch_train_losses.append(loss.item())
                
                avg_train_loss = np.mean(epoch_train_losses)
                
                # Validation phase
                if validation_split > 0:
                    model.eval()
                    val_epoch_losses = []
                    
                    with torch.no_grad():
                        for batch_X, batch_s, batch_t in get_batches(val_features_np, val_events_np, val_times_np, batch_size=batch_size):
                            batch_X = batch_X.to(device)
                            batch_s = batch_s.to(device)
                            batch_t = batch_t.to(device)
                            
                            predictions = model(batch_X)
                            loss = loss_fn(predictions, batch_s, batch_t)
                            val_epoch_losses.append(loss.item())
                    
                    avg_val_loss = np.mean(val_epoch_losses)
                    
                    # Early stopping logic
                    if avg_val_loss < best_val_loss - min_delta:
                        best_val_loss = avg_val_loss
                        # Clone every tensor: state_dict() values alias the live
                        # parameters, so a shallow dict copy would keep changing
                        # with later optimizer steps and the restore below would
                        # be a no-op.
                        best_model_state = {
                            name: tensor.detach().clone()
                            for name, tensor in model.state_dict().items()
                        }
                        epochs_without_improvement = 0
                    else:
                        epochs_without_improvement += 1
                    
                    # Print progress
                    if (epoch + 1) % 20 == 0:
                        print(f"    {epoch_tag} Fold {fold_idx + 1} Epoch {epoch + 1}")
                        print(f"      Train Loss: {avg_train_loss:.4f}, Val Loss: {avg_val_loss:.4f}")
                        print(f"      Best Val Loss: {best_val_loss:.4f}, Patience: {epochs_without_improvement}/{patience}")
                    
                    # Early stopping
                    if epochs_without_improvement >= patience:
                        print(f"    Early stopping at epoch {epoch + 1}")
                        break
                else:
                    # No validation - just print training loss
                    if (epoch + 1) % 20 == 0:
                        print(f"    {epoch_tag} Fold {fold_idx + 1} Epoch {epoch + 1} Train Loss: {avg_train_loss:.4f}")
            
            # Load best model if using validation
            if validation_split > 0 and best_model_state is not None:
                model.load_state_dict(best_model_state)
                print(f"    Loaded best model with validation loss: {best_val_loss:.4f}")
            
            # Make predictions on test set
            model.eval()
            with torch.no_grad():
                test_x = torch.tensor(test_features_np, dtype=torch.float32).to(device)
                test_predictions = model(test_x)
                bin_probs = test_predictions[:, :-1]
                
                bin_starts = bin_boundaries[:-1]
                bin_lengths = np.diff(bin_boundaries)
                
                # Cumulative event probability by each horizon. Each bin counts
                # in proportion to how much of it lies before the horizon, the
                # same interpolation the loss applies to censored rows. Counting
                # every bin whose start is <= the horizon would also include the
                # part of a bin that lies after the horizon.
                for pred_time in prediction_times:
                    elapsed_fraction = np.clip((pred_time - bin_starts) / bin_lengths, 0.0, 1.0)
                    weights = torch.tensor(elapsed_fraction, dtype=torch.float32, device=device)
                    cum_event_prob = torch.sum(bin_probs * weights, dim=-1)
                    event_prob = torch.clamp(cum_event_prob, 0, 1)
                    results[f"year_{pred_time}"][test_index_global] = event_prob.cpu().numpy()
    
    return results

def predict_multiple_events_enhanced(data,
                                   events=['cvd', 'ascvd', 'hf'],
                                   prediction_times=[5, 10],
                                   folds=5,
                                   hidden_layers=[64, 32],
                                   num_bins=20,
                                   epochs=200,
                                   learning_rate=0.001,
                                   batch_size=100,
                                   validation_split=0.1,
                                   patience=15):
    """
    Add one cross-validated prediction column per (event, horizon) pair.

    For event ``e`` and horizon ``h`` the outcome columns are ``{e}_{h}y`` and
    ``time2{e}_{h}y``; pairs whose columns are missing from ``data`` are
    skipped with a warning. Predictions are stored in ``{h}y_dnn_{e}`` on a
    copy of ``data``, which is returned. The remaining arguments are passed
    through to ``perform_cv_discrete_nn_with_validation``.
    """
    result_data = data.copy()

    # Settings shared by every cross-validation run
    shared_cv_kwargs = dict(
        folds=folds,
        hidden_layers=hidden_layers,
        num_bins=num_bins,
        epochs=epochs,
        learning_rate=learning_rate,
        batch_size=batch_size,
        validation_split=validation_split,
        patience=patience,
    )

    targets = [(event, h) for event in events for h in prediction_times]
    for event, h in targets:
        event_var = f"{event}_{h}y"
        time_var = f"time2{event}_{h}y"

        print(f"CV for {event_var} using {time_var}")

        absent = [col for col in (event_var, time_var) if col not in data.columns]
        if absent:
            print("Warning: Required columns not found")
            continue

        # Run enhanced CV for this single horizon
        cv_results = perform_cv_discrete_nn_with_validation(
            data=data,
            event_var=event_var,
            time_var=time_var,
            prediction_times=[h],
            **shared_cv_kwargs
        )

        result_data[f"{h}y_dnn_{event}"] = cv_results[f"year_{h}"]

    return result_data

In [None]:
# Settings for the cross-validated DNN predictions; `batch_size` is left at
# the function default.
cv_settings = dict(
    events=['cvd', 'ascvd', 'hf'],
    prediction_times=[5, 10],
    folds=5,
    hidden_layers=[64, 32],
    num_bins=20,
    epochs=200,
    learning_rate=0.001,
    validation_split=0.1,
    patience=10,
)
results = predict_multiple_events_enhanced(data=data, **cv_settings)

In [None]:
# Preview the first 20 rows, including the newly added prediction columns.
results.iloc[:20]

In [None]:
# Write the table with the prediction columns back to the cohort file.
# NOTE: this overwrites the file loaded at the top of the notebook.
# index=False keeps the row index out of the file; otherwise every
# Restart & Run All would add another "Unnamed: 0" column on reload.
results.to_csv('../../data/strict_cohort2/all_data.csv', index=False)