In [1]:
import os
import sys
import pandas as pd
import numpy as np
from pathlib import Path
import logging
import torch
from datetime import datetime

# Setup paths: the project root is two levels above the notebook's working
# directory. It is put at the FRONT of sys.path so the project's generically
# named packages (`preprocessing`, `models`, `utils`) take precedence over any
# identically named packages installed in the environment.
PROJECT_ROOT = os.path.dirname(os.path.dirname(os.getcwd()))
if PROJECT_ROOT not in sys.path:
    sys.path.insert(0, PROJECT_ROOT)

# Global seed for numpy / torch randomness that is not controlled by the
# model's own `random_state` argument.
SEED = 42
np.random.seed(SEED)
torch.manual_seed(SEED)

# Create timestamped output directories so each run keeps its own artifacts.
timestamp = datetime.now().strftime('%Y%m%d_%H%M%S')
BASE_DIR = os.getcwd()
MODEL_DIR = os.path.join(BASE_DIR, 'models', timestamp)
RESULTS_DIR = os.path.join(BASE_DIR, 'results', timestamp)
os.makedirs(MODEL_DIR, exist_ok=True)
os.makedirs(RESULTS_DIR, exist_ok=True)

# Project-local imports (need PROJECT_ROOT on sys.path, set above).
from preprocessing.data_container import DataContainer
from models.deep_surv_model import DeepSurvModel
from utils.evaluation import cindex_score

# Log to both this run's training.log and the notebook output.
# force=True (Python 3.8+) removes and closes handlers installed by a previous
# execution of this cell; without it basicConfig is a no-op on re-run and all
# logging would keep going to the PREVIOUS run's log file.
log_file = os.path.join(RESULTS_DIR, 'training.log')
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(levelname)s - %(message)s',
    handlers=[
        logging.FileHandler(log_file),
        logging.StreamHandler()
    ],
    force=True,
)
logger = logging.getLogger(__name__)

# Select the compute device.
device = 'cuda' if torch.cuda.is_available() else 'cpu'
logger.info(f"Using device: {device}")


2024-11-10 17:37:05,135 - INFO - Using device: cpu


In [2]:
# Data configuration passed to DataContainer.
DATA_CONFIG = {
    'use_pca': False,
    'gene_type': 'intersection',
    'use_imputed': True,
    'use_cohorts': True
}

# Seed forwarded to DeepSurvModel in both training modes (single source of truth).
RANDOM_STATE = 42

# Model configuration with or without CV
USE_CV = True  # Set to False for direct training

if USE_CV:
    # Hyperparameter grid and CV options; the training cell forwards these
    # keys to DeepSurvModel.fit as keyword arguments.
    MODEL_CONFIG = {
        'params_cv': {
            'hidden_layers': [[32, 16], [64, 32], [32, 32, 16]],
            'learning_rate': [0.001, 0.0001],
            'batch_size': [32, 64],
            'num_epochs': [100]
        },
        'use_cohort_cv': True,
        'n_splits_inner': 5
    }
else:
    # Fixed hyperparameters; forwarded to the DeepSurvModel constructor below.
    MODEL_CONFIG = {
        'hidden_layers': [32, 16],
        'learning_rate': 0.001,
        'batch_size': 64,
        'num_epochs': 100,
        'device': device,
        'random_state': RANDOM_STATE
    }

# Persist the run configuration next to the results. The seed is recorded
# explicitly because in CV mode it is not part of MODEL_CONFIG.
import json  # NOTE(review): belongs in the top import cell

config_file = os.path.join(RESULTS_DIR, 'config.json')
with open(config_file, 'w', encoding='utf-8') as f:
    json.dump({
        'data_config': DATA_CONFIG,
        'model_config': MODEL_CONFIG,
        'use_cv': USE_CV,
        'random_state': RANDOM_STATE
    }, f, indent=4)

try:
    # Create DataContainer and load data
    logger.info("Loading data...")
    data_container = DataContainer(DATA_CONFIG, project_root=PROJECT_ROOT)
    X, y = data_container.load_data()

    logger.info(f"Loaded data with shape: X={X.shape}")

    # Save the feature (column) names the model is trained on.
    feature_names = pd.DataFrame({'feature': X.columns})
    feature_names.to_csv(os.path.join(RESULTS_DIR, 'feature_names.csv'), index=False)

    # Initialize DeepSurv
    logger.info("Initializing DeepSurv model...")
    if USE_CV:
        # In CV mode the hyperparameters are passed to fit(); only runtime
        # settings are given to the constructor.
        deep_surv = DeepSurvModel(device=device, random_state=RANDOM_STATE)
    else:
        deep_surv = DeepSurvModel(**MODEL_CONFIG)

except Exception as e:
    # logger.exception also writes the traceback to training.log.
    logger.exception(f"Error during initialization: {str(e)}")
    raise


2024-11-10 17:37:07,489 - INFO - Loading data...
2024-11-10 17:37:07,490 - INFO - Loading data...
2024-11-10 17:37:55,030 - INFO - Loaded data: 1091 samples, 13214 features
2024-11-10 17:37:57,463 - INFO - Loaded data with shape: X=(1091, 13214)
2024-11-10 17:37:57,499 - INFO - Initializing DeepSurv model...


In [3]:
try:
    logger.info("Starting model training...")

    if USE_CV:
        logger.info(f"Cross-validation config: {MODEL_CONFIG}")
        # Nested CV: the grid ('params_cv') and CV options in MODEL_CONFIG are
        # forwarded to fit() as keyword arguments.
        # NOTE(review): the recorded run failed inside this call during outer
        # fold 1 with "Model must be fitted before predicting", which suggests
        # DeepSurvModel's CV loop predicts with a model that was never fitted —
        # TODO investigate using the traceback now written to training.log.
        deep_surv.fit(
            X=X,
            y=y,
            data_container=data_container,
            **MODEL_CONFIG
        )

        # Summarize and persist the CV results.
        cv_summary = deep_surv.cv_results_
        cv_results = pd.DataFrame(cv_summary['cv_results'])
        logger.info("Cross-validation results:")
        logger.info(f"Mean c-index: {cv_summary['mean_score']:.3f} "
                    f"± {cv_summary['std_score']:.3f}")
        cv_results.to_csv(os.path.join(RESULTS_DIR, 'cv_results.csv'))

    else:
        # Direct training on a train/validation split produced by DataContainer.
        X_train, y_train, X_val, y_val = data_container.get_train_val_split(X, y)
        deep_surv.fit(
            X=X_train,
            y=y_train,
            validation_data=(X_val, y_val)
        )

        # Evaluate on the held-out validation set.
        val_pred = deep_surv.predict(X_val)
        val_score = cindex_score(y_val, val_pred)
        logger.info(f"Validation c-index: {val_score:.3f}")

    logger.info("Model training completed successfully!")

except Exception as e:
    # logger.exception keeps the full traceback in training.log; the bare
    # message alone does not show where inside fit() the failure happened.
    logger.exception(f"Error during training: {str(e)}")
    raise

2024-11-10 17:38:01,375 - INFO - Starting model training...
2024-11-10 17:38:01,377 - INFO - Cross-validation config: {'params_cv': {'hidden_layers': [[32, 16], [64, 32], [32, 32, 16]], 'learning_rate': [0.001, 0.0001], 'batch_size': [32, 64], 'num_epochs': [100]}, 'use_cohort_cv': True, 'n_splits_inner': 5}
2024-11-10 17:38:01,378 - INFO - Starting DeepSurv training...
2024-11-10 17:38:01,379 - INFO - Starting nested cross-validation for DeepSurv...
2024-11-10 17:38:01,381 - INFO - Outer fold 1
2024-11-10 17:38:05,289 - ERROR - Error during training: Model must be fitted before predicting


ValueError: Model must be fitted before predicting

In [None]:
import matplotlib.pyplot as plt  # NOTE(review): belongs in the top import cell

try:
    # Save model
    logger.info("Saving model...")
    model_name = f"deep_surv_model_{timestamp}"
    deep_surv.save(MODEL_DIR, model_name)

    # Save training history. training_history_ is indexed by metric name;
    # presumably a dict of per-epoch lists ('train_loss', optionally
    # 'val_loss') — TODO confirm against DeepSurvModel.
    history_df = pd.DataFrame(deep_surv.training_history_)
    history_df.to_csv(os.path.join(RESULTS_DIR, 'training_history.csv'))

    # Plot training history only if a training loss was actually recorded.
    if 'train_loss' in history_df.columns and not history_df.empty:
        fig, ax = plt.subplots(figsize=(10, 6))
        try:
            ax.plot(history_df['train_loss'], label='Training Loss')
            if 'val_loss' in history_df.columns:
                ax.plot(history_df['val_loss'], label='Validation Loss')
            ax.set_xlabel('Epoch')
            ax.set_ylabel('Loss')
            ax.set_title('Training History')
            ax.legend()
            fig.savefig(os.path.join(RESULTS_DIR, 'training_history.png'))
        finally:
            # Release the figure even if saving fails.
            plt.close(fig)

except Exception as e:
    logger.exception(f"Error saving results: {str(e)}")
    raise

# Final summary logging
logger.info("Deep Survival Network training pipeline completed successfully!")
logger.info(f"Results saved in: {RESULTS_DIR}")
logger.info(f"Model saved in: {MODEL_DIR}")

# Final c-index evaluation (best-effort: failures are logged, not raised).
# NOTE: X contains samples the model was trained on (X is passed to fit() in
# the CV branch; X_train is drawn from X in the direct branch), so these are
# apparent / in-sample scores and are optimistic. Use the cross-validation or
# validation scores as the generalization estimate.

# Per-cohort performance, if cohort labels are available. Kept in its own try
# so a failure here does not prevent the overall score below.
try:
    groups = data_container.get_groups() if hasattr(data_container, 'get_groups') else None
    if groups is not None:
        # Positional boolean masks, so X and y are sliced identically even if
        # groups is a pandas Series carrying its own index.
        groups = np.asarray(groups)
        if len(groups) != len(X):
            raise ValueError(
                f"Cohort labels ({len(groups)}) do not match number of samples ({len(X)})"
            )

        logger.info("Performance by cohort:")
        cohort_rows = []
        for cohort in np.unique(groups):
            mask = groups == cohort
            n_cohort = int(np.count_nonzero(mask))
            try:
                cohort_pred = deep_surv.predict(X[mask])
                cohort_score = cindex_score(y[mask], cohort_pred)
            except Exception as e:
                # One cohort failing (possibly e.g. too few events for a defined
                # c-index) should not hide the remaining cohorts.
                logger.warning(f"Could not compute c-index for cohort {cohort} (n={n_cohort}): {e}")
                continue
            cohort_rows.append({'cohort': cohort, 'c_index': cohort_score, 'n': n_cohort})
            logger.info(f"{cohort}: {cohort_score:.3f} (n={n_cohort})")

        cohort_performance = pd.DataFrame(cohort_rows)
        if not cohort_performance.empty:
            cohort_performance.to_csv(
                os.path.join(RESULTS_DIR, 'cohort_performance.csv'), index=False
            )
except Exception as e:
    logger.exception(f"Error calculating per-cohort C-index: {str(e)}")

# Overall performance.
try:
    full_pred = deep_surv.predict(X)
    full_cindex = cindex_score(y, full_pred)
    logger.info("Model Performance:")
    logger.info(f"C-index on full dataset (includes training samples): {full_cindex:.3f}")
except Exception as e:
    logger.exception(f"Error calculating final C-index: {str(e)}")