In [None]:
import pandas as pd
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from sklearn.preprocessing import StandardScaler
from sklearn.model_selection import train_test_split
import pickle
import os

In [None]:
# Seed the NumPy and PyTorch random number generators so repeated runs of the
# notebook produce the same splits, initial weights and batches.
np.random.seed(123)
torch.manual_seed(123)

# Use the GPU when PyTorch can see one, otherwise run on the CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
print(f"Using device: {device}")

class DiscreteTimeNN(nn.Module):
    """Feed-forward network that outputs a probability distribution over time bins.

    Each hidden width in ``hidden_layer_sizes`` becomes a lazily-sized linear
    layer followed by ReLU. The head produces ``num_bins + 1`` logits, which
    are softmax-normalised so every output row sums to one. The extra column
    holds the probability mass not assigned to any of the ``num_bins`` bins.
    """

    def __init__(self, hidden_layer_sizes, num_bins):
        super().__init__()

        # Lazy layers infer their input width on the first forward pass.
        hidden = [nn.LazyLinear(width) for width in hidden_layer_sizes]
        self.encoder_layers = nn.ModuleList(hidden)

        self.activation = nn.ReLU()
        self.prediction_head = nn.LazyLinear(num_bins + 1)
        self.softmax = nn.Softmax(dim=-1)

    def forward(self, x):
        """Map a batch of feature rows to per-bin probabilities."""
        hidden = x
        for linear in self.encoder_layers:
            hidden = self.activation(linear(hidden))
        return self.softmax(self.prediction_head(hidden))

class DiscreteFailureTimeNLL(torch.nn.Module):
    """Negative log-likelihood for discrete-time survival predictions.

    ``predictions`` is expected to have ``len(bin_boundaries)`` columns: one
    probability per interval ``(bin_boundaries[i], bin_boundaries[i + 1]]``,
    followed by a final column for "after the last boundary". This is the
    layout ``DiscreteTimeNN(..., num_bins)`` produces when
    ``num_bins == len(bin_boundaries) - 1``.

    - Observed events contribute the probability of the bin containing their
      time. An event later than the last boundary contributes the final (tail)
      column.
    - Censored rows contribute the survival probability at their time. Within
      a bin, that probability is linearly interpolated from the share of the
      bin already elapsed.

    ``tolerance`` is added inside the logs to avoid ``log(0)``.

    NOTE(review): an event at a time <= bin_boundaries[0] still matches no bin
    (intervals are open on the left), so its likelihood is just ``tolerance``.
    """

    def __init__(self, bin_boundaries, tolerance=1e-8):
        super().__init__()

        # Register as buffers so they move with the module (e.g. via .to(device)).
        self.register_buffer('bin_starts', torch.tensor(bin_boundaries[:-1], dtype=torch.float32))
        self.register_buffer('bin_ends', torch.tensor(bin_boundaries[1:], dtype=torch.float32))
        self.register_buffer('bin_lengths', self.bin_ends - self.bin_starts)

        self.tolerance = tolerance

    def _discretize_times(self, times):
        """Boolean (batch, num_bins) mask of the bin ``(start, end]`` containing each time."""
        return (
            (times[:, None] > self.bin_starts[None, :])
            & (times[:, None] <= self.bin_ends[None, :])
        )

    def _event_bin_mask(self, times):
        """Boolean (batch, num_bins + 1) mask; the last column flags times past the final boundary."""
        in_bins = self._discretize_times(times)
        past_last_boundary = times[:, None] > self.bin_ends[-1]
        return torch.cat([in_bins, past_last_boundary], dim=-1)

    def _get_proportion_of_bins_completed(self, times):
        """Fraction of each bin elapsed by each time, clamped to [0, 1]."""
        return torch.clamp(
            (times[:, None] - self.bin_starts[None, :]) / self.bin_lengths[None, :],
            min=0.0,
            max=1.0
        )

    def forward(self, predictions, event_indicators, event_times):
        """Return the mean negative log-likelihood over the batch.

        Args:
            predictions: (batch, num_bins + 1) probabilities, tail column last.
            event_indicators: (batch,) 1 for an observed event, 0 for censoring.
            event_times: (batch,) event or censoring times, in the units of
                ``bin_boundaries``.
        """
        # Ensure input tensors are on the same device as the predictions.
        event_indicators = event_indicators.to(predictions.device)
        event_times = event_times.to(predictions.device)

        # Include the tail column so events after the last boundary are scored
        # against the mass the model placed there instead of collapsing to tolerance.
        event_likelihood = torch.sum(
            self._event_bin_mask(event_times) * predictions,
            dim=-1
        ) + self.tolerance

        # Survival = 1 - probability mass already elapsed by the censoring time.
        nonevent_likelihood = 1 - torch.sum(
            self._get_proportion_of_bins_completed(event_times) * predictions[:, :-1],
            dim=-1
        ) + self.tolerance

        log_likelihood = event_indicators * torch.log(event_likelihood)
        log_likelihood += (1 - event_indicators) * torch.log(nonevent_likelihood)

        return -1. * torch.mean(log_likelihood)

def preprocess_data(data,
                    numerical_cols=['age', 'eGFR', 'sbp', 'bmi', 'tc', 'hdlc'],
                    categorical_cols=['diabetes', 'smoker', 'antihtn', 'statin'],
                    scaler=None):
    """
    Return a standardised copy of ``data`` together with the scaler used.

    When ``scaler`` is None, a new ``StandardScaler`` is fitted on
    ``numerical_cols`` (training phase). Otherwise the given scaler's
    ``transform`` is applied (test phase). Any ``categorical_cols`` present in
    the frame are cast to float; absent ones are skipped. The input frame is
    not modified.

    Returns:
        dict with keys ``'data'`` (the processed frame) and ``'scaler'``.
    """
    out = data.copy()

    fitting = scaler is None
    if fitting:
        scaler = StandardScaler()

    numeric_block = out[numerical_cols]
    if fitting:
        out[numerical_cols] = scaler.fit_transform(numeric_block)
    else:
        out[numerical_cols] = scaler.transform(numeric_block)

    present = [name for name in categorical_cols if name in out.columns]
    for name in present:
        out[name] = out[name].astype(float)

    return {'data': out, 'scaler': scaler}

def get_batches(*arrs, batch_size=1):
    """Yield aligned mini-batches, in order, from equally long array-likes.

    Each yielded item is a list holding one float32 tensor per input array.
    All of them are sliced over the same row range, and the final batch may
    be shorter than ``batch_size``.
    """
    length = len(arrs[0])
    for begin in range(0, length, batch_size):
        window = slice(begin, min(begin + batch_size, length))
        batch = [torch.tensor(column[window], dtype=torch.float32) for column in arrs]
        yield batch

def create_time_bins(times, num_bins=10):
    """Return ``num_bins + 1`` evenly spaced bin boundaries from 0 to ``max(times)``."""
    upper = np.max(times)
    return np.linspace(0, upper, num=num_bins + 1)

def _run_dnn_epoch(model, loss_fn, arrays, batch_size, optimizer=None):
    """Make one pass over ``arrays`` and return the per-sample mean loss.

    ``arrays`` is the (features, event indicators, event times) triple fed to
    ``get_batches``. With an ``optimizer``, the model is put in training mode
    and updated after every batch. Without one, it is put in eval mode and
    gradients are disabled. Returns NaN when ``arrays`` has no rows.
    """
    training = optimizer is not None
    model.train(training)
    total_loss = 0.0
    total_rows = 0
    with torch.set_grad_enabled(training):
        for batch_X, batch_s, batch_t in get_batches(*arrays, batch_size=batch_size):
            batch_X = batch_X.to(device)
            batch_s = batch_s.to(device)
            batch_t = batch_t.to(device)

            if training:
                optimizer.zero_grad()
            predictions = model(batch_X)
            loss = loss_fn(predictions, batch_s, batch_t)
            if training:
                loss.backward()
                optimizer.step()

            # loss_fn returns a batch mean; weight by batch size so a short
            # final batch does not count as much as a full one.
            n_rows = batch_X.shape[0]
            total_loss += loss.item() * n_rows
            total_rows += n_rows
    return total_loss / total_rows if total_rows else float('nan')


def train_dnn_model(data,
                    event_var,
                    time_var,
                    prediction_times=(5, 10),
                    gender_val=None,
                    hidden_layers=(64, 32),
                    num_bins=20,
                    epochs=200,
                    learning_rate=0.001,
                    batch_size=100,
                    validation_split=0.1,
                    patience=15,
                    min_delta=1e-4):
    """
    Train a DiscreteTimeNN with a stratified validation split and early stopping.

    Args:
        data: DataFrame with the feature columns, ``event_var``, ``time_var``
            and (when ``gender_val`` is given) a ``'sex'`` column.
        event_var: column with the event indicator; also used for stratification.
        time_var: column with the event/censoring time.
        prediction_times: not used during training; kept for API compatibility.
        gender_val: if not None, train only on rows where ``sex == gender_val``.
        hidden_layers: hidden layer widths of the network.
        num_bins: number of equal-width time bins spanning [0, max training time].
        epochs, learning_rate, batch_size: Adam training settings.
        validation_split: fraction of rows held out for early stopping.
        patience: epochs without improvement before stopping.
        min_delta: minimum decrease in validation loss that counts as improvement.

    Returns:
        dict with ``'model'`` (weights restored from the best validation epoch),
        ``'scaler'``, ``'bin_boundaries'`` and ``'feature_cols'``.
    """
    # Filter by gender if specified
    if gender_val is not None:
        data_filtered = data[data['sex'] == gender_val].copy()
    else:
        data_filtered = data.copy()

    feature_cols = ['age', 'eGFR', 'sbp', 'bmi', 'tc', 'hdlc',
                    'diabetes', 'smoker', 'antihtn', 'statin']

    train_data, val_data = train_test_split(
        data_filtered,
        test_size=validation_split,
        stratify=data_filtered[event_var],
        random_state=123
    )

    print(f"Train: {len(train_data)}, Validation: {len(val_data)}")

    # The scaler is fitted on the training split only, then reused for validation.
    train_processed = preprocess_data(train_data, scaler=None)
    scaler = train_processed['scaler']
    train_features_np = train_processed['data'][feature_cols].values
    val_features_np = preprocess_data(val_data, scaler=scaler)['data'][feature_cols].values

    # Create time bins based on training data
    bin_boundaries = create_time_bins(train_data[time_var].values, num_bins)

    train_arrays = (train_features_np, train_data[event_var].values, train_data[time_var].values)
    val_arrays = (val_features_np, val_data[event_var].values, val_data[time_var].values)

    model = DiscreteTimeNN(hidden_layers, num_bins).to(device)
    loss_fn = DiscreteFailureTimeNLL(bin_boundaries).to(device)
    optimizer = optim.Adam(model.parameters(), lr=learning_rate)

    best_val_loss = float('inf')
    best_model_state = None
    epochs_without_improvement = 0

    for epoch in range(epochs):
        avg_train_loss = _run_dnn_epoch(model, loss_fn, train_arrays, batch_size, optimizer)
        avg_val_loss = _run_dnn_epoch(model, loss_fn, val_arrays, batch_size)

        if avg_val_loss < best_val_loss - min_delta:
            best_val_loss = avg_val_loss
            # state_dict() returns references to the live parameter tensors, and
            # dict.copy() is shallow. Clone them, or later optimizer steps would
            # overwrite the "best" snapshot.
            best_model_state = {
                name: tensor.detach().clone()
                for name, tensor in model.state_dict().items()
            }
            epochs_without_improvement = 0
        else:
            epochs_without_improvement += 1

        if (epoch + 1) % 20 == 0:
            print(f"Epoch {epoch + 1}: Train Loss: {avg_train_loss:.4f}, Val Loss: {avg_val_loss:.4f}")
            print(f"Best Val Loss: {best_val_loss:.4f}, Patience: {epochs_without_improvement}/{patience}")

        if epochs_without_improvement >= patience:
            print(f"Early stopping at epoch {epoch + 1}")
            break

    # Load best model
    if best_model_state is not None:
        model.load_state_dict(best_model_state)
        print(f"Loaded best model with validation loss: {best_val_loss:.4f}")

    return {
        'model': model,
        'scaler': scaler,
        'bin_boundaries': bin_boundaries,
        'feature_cols': feature_cols
    }

def predict_with_dnn(model_dict, test_data, prediction_times=(5, 10)):
    """
    Predict the cumulative event probability by each horizon in ``prediction_times``.

    The model's first ``num_bins`` output columns are per-bin event
    probabilities; the trailing tail column is ignored. The probability of an
    event by time ``t`` is the sum of the bin probabilities, each weighted by
    the fraction of its bin lying before ``t``. Bins ending by ``t`` count
    fully, bins starting at or after ``t`` not at all, and a straddling bin
    partially. This matches the linear within-bin interpolation used by
    DiscreteFailureTimeNLL's censoring term.

    Args:
        model_dict: output of ``train_dnn_model``.
        test_data: DataFrame containing ``model_dict['feature_cols']``.
        prediction_times: horizons, in the same units as the training time column.

    Returns:
        dict mapping ``f"{t}y"`` to a float32 array with one probability per
        row of ``test_data``. Every horizon gets an entry.
    """
    model = model_dict['model']
    scaler = model_dict['scaler']
    bin_boundaries = np.asarray(model_dict['bin_boundaries'], dtype=np.float64)
    feature_cols = model_dict['feature_cols']

    # Preprocess test data with the scaler fitted at training time
    test_processed = preprocess_data(test_data, scaler=scaler)
    test_features_np = test_processed['data'][feature_cols].values

    bin_starts = bin_boundaries[:-1]
    bin_lengths = np.diff(bin_boundaries)

    model.eval()
    with torch.no_grad():
        test_x = torch.tensor(test_features_np, dtype=torch.float32).to(device)
        # Keep only the per-bin columns; the last column is the post-horizon tail.
        bin_probs = model(test_x)[:, :-1].cpu().numpy()

    predictions = {}
    for pred_time in prediction_times:
        fraction_elapsed = np.clip((pred_time - bin_starts) / bin_lengths, 0.0, 1.0)
        cum_event_prob = bin_probs @ fraction_elapsed
        predictions[f"{pred_time}y"] = np.clip(cum_event_prob, 0.0, 1.0).astype(np.float32)

    return predictions

In [None]:
# Pipeline: add 5y/10y DNN risk columns to the relaxed cohort.
#   Stage 1 copies predictions that already exist in the strict cohort for the
#           same PATID.
#   Stage 2 trains one model per event/horizon/sex on the strict cohort,
#           pickles it, and fills the remaining missing values.
# Everything below runs exactly once. It was previously indented under the
# column-initialisation loop and so repeated 6 times.
print("Loading data...")
strict = pd.read_csv('../../data/strict_cohort2/all_data.csv')
data = pd.read_csv('../../data/relax_cohort2/all_data.csv')  # Assuming this is your main data

# Initialize new columns
new_cols = ["10y_dnn_cvd", "10y_dnn_ascvd", "10y_dnn_hf",
            "5y_dnn_cvd", "5y_dnn_ascvd", "5y_dnn_hf"]
for col in new_cols:
    data[col] = np.nan

print(f"Initialized columns: {new_cols}")
print(f"Data columns containing 'dnn': {[col for col in data.columns if 'dnn' in col]}")

# Stage 1: Copy paste existing predictions for overlapping patients
print("Stage 1: Copying existing predictions...")
print(f"Strict data columns containing 'dnn': {[col for col in strict.columns if 'dnn' in col]}")

# Set PATID as index for both dataframes for easier matching
data_indexed = data.set_index('PATID')
strict_indexed = strict.set_index('PATID')

# The set of overlapping patients does not depend on the column, so compute it once.
overlapping_patients = data_indexed.index.intersection(strict_indexed.index)

for col in new_cols:
    if col in strict_indexed.columns:
        # Copy non-null values from strict to relax data
        strict_values = strict_indexed.loc[overlapping_patients, col]
        non_null_mask = strict_values.notna()
        if non_null_mask.sum() > 0:
            data_indexed.loc[overlapping_patients[non_null_mask], col] = strict_values[non_null_mask]
            print(f"Copied {non_null_mask.sum()} values for {col}")

# Reset index back to regular dataframe
data = data_indexed.reset_index()

# Stage 2: Train DNN models on all strict data and predict for missing values
print("Stage 2: Training DNN models...")

model_specs = [
    {'event': 'cvd', 'prediction_times': [5, 10]},
    {'event': 'ascvd', 'prediction_times': [5, 10]},
    {'event': 'hf', 'prediction_times': [5, 10]}
]

# Create model directory if it doesn't exist
os.makedirs('../model', exist_ok=True)

for spec in model_specs:
    event = spec['event']
    prediction_times = spec['prediction_times']

    for pred_time in prediction_times:
        event_var = f"{event}_{pred_time}y"
        time_var = f"time2{event}_{pred_time}y"

        if event_var not in strict.columns or time_var not in strict.columns:
            print(f"Warning: Required columns {event_var} or {time_var} not found in strict data")
            continue

        print(f"Training models for {event} {pred_time}y...")

        # Train separate models for each gender
        for gender_val in [0, 1]:
            gender_str = "female" if gender_val == 1 else "male"
            print(f"  Training {gender_str} model...")

            model_dict = train_dnn_model(
                data=strict,
                event_var=event_var,
                time_var=time_var,
                prediction_times=[pred_time],
                gender_val=gender_val,
                hidden_layers=[64, 32],
                num_bins=20,
                epochs=200,
                learning_rate=0.001,
                batch_size=100,
                validation_split=0.1,
                patience=15
            )

            model_filename = f"../model/dnn_{event}_{pred_time}y_{gender_str}.pkl"
            with open(model_filename, 'wb') as f:
                pickle.dump(model_dict, f)

            print(f"  Saved model to {model_filename}")

            # Predict for missing values in main data
            target_col = f"{pred_time}y_dnn_{event}"

            if target_col not in data.columns:
                print(f"  Warning: Column {target_col} not found in data, skipping...")
                continue

            # Find rows where prediction is missing for this gender
            gender_mask = data['sex'] == gender_val
            missing_mask = data[target_col].isna()
            target_mask = gender_mask & missing_mask

            if target_mask.sum() > 0:
                print(f"  Predicting for {target_mask.sum()} missing {gender_str} patients...")

                test_data = data[target_mask].copy()
                predictions = predict_with_dnn(model_dict, test_data, [pred_time])

                # Guard against a horizon for which no prediction was produced
                # instead of failing with a KeyError mid-pipeline.
                pred_key = f"{pred_time}y"
                if pred_key not in predictions:
                    print(f"  Warning: No prediction produced for {pred_key}, skipping...")
                    continue

                data.loc[target_mask, target_col] = predictions[pred_key]

In [None]:
# Preview the first rows of the updated data frame, including the *_dnn_* columns.
data.head()

In [None]:
# Write the relaxed cohort, now with the *_dnn_* prediction columns, to
# all_data2.csv so the all_data.csv file read earlier is left unchanged.
# index=False keeps the pandas row index out of the CSV.
data.to_csv('../../data/relax_cohort2/all_data2.csv', index=False)