## Random Survival Forest Training
Dieses Notebook demonstriert das Training eines Random-Survival-Forest-Modells mit verschiedenen Optionen:
- Verschiedene Input-Typen (Kohorten, merged data)
- Mit/ohne PCA
- Grid/Random Search
- Verschiedene Cross-Validation Strategien


## Setup und Konfiguration

### Train Model


In [4]:
import os
import sys
import pandas as pd
import numpy as np
from pathlib import Path
import logging

# Project root = two directory levels above the notebook's working directory.
# It is appended to sys.path so the project-local packages imported below
# can be resolved.
PROJECT_ROOT = os.path.abspath(os.path.join(os.getcwd(), os.pardir, os.pardir))
if PROJECT_ROOT not in sys.path:
    sys.path.append(PROJECT_ROOT)

# Output folders live next to the notebook (inside the current working
# directory) and are created on first run; existing folders are left alone.
MODEL_DIR, RESULTS_DIR = (
    os.path.join(os.getcwd(), subdir) for subdir in ('model', 'results')
)
os.makedirs(MODEL_DIR, exist_ok=True)
os.makedirs(RESULTS_DIR, exist_ok=True)

# Imports
from preprocessing.data_container import DataContainer
from models.rsf_model import RSFModel
from utils.evaluation import cindex_score

# Configure root logging for the notebook: INFO and above, each record
# rendered as "<timestamp> - <level> - <message>".
logging.basicConfig(
    format='%(asctime)s - %(levelname)s - %(message)s',
    level=logging.INFO,
)

# Logger named after the current module; used by the cells of this notebook.
logger = logging.getLogger(__name__)

# Data configuration: options handed to DataContainer when the data is loaded.
# NOTE(review): the exact meaning of each flag is defined by DataContainer,
# which is not visible from this notebook — confirm there before changing.
DATA_CONFIG = dict(
    use_pca=False,
    gene_type='intersection',
    use_imputed=True,
    use_cohorts=False,
)

# Model configuration: unpacked as keyword arguments into rsf.fit().
# 'params_cv' is the hyper-parameter grid (2 x 2 x 2 = 8 combinations).
MODEL_CONFIG = dict(
    params_cv=dict(
        n_estimators=[100, 200],
        min_samples_split=[5, 10],
        min_samples_leaf=[3, 5],
    ),
    use_cohort_cv=False,
    n_splits_inner=5,
)

In [5]:
try:
    # Build the data container from the notebook-level configuration and
    # fetch the feature matrix X together with the target y.
    logger.info("Loading data...")
    data_container = DataContainer(DATA_CONFIG, project_root=PROJECT_ROOT)
    X, y = data_container.load_data()

    logger.info(f"Loaded data with shape: X={X.shape}")

    # Create the Random Survival Forest wrapper.
    logger.info("Initializing RSF model...")
    rsf = RSFModel()

    # Baseline hyper-parameters, set before any fitting; the fixed
    # random_state keeps repeated runs comparable.
    rsf.set_params(
        **{
            'n_estimators': 100,
            'min_samples_split': 10,
            'min_samples_leaf': 5,
            'max_features': "sqrt",
            'random_state': 42,
        }
    )

except Exception as exc:
    # Log the failure for the notebook output, then let it propagate so the
    # cell is still marked as failed.
    logger.error(f"Error during initialization: {exc!s}")
    raise

2024-11-10 16:53:49,826 - INFO - Loading data...
2024-11-10 16:53:49,827 - INFO - Loading data...
2024-11-10 16:55:07,048 - INFO - Loaded data: 1091 samples, 13214 features
2024-11-10 16:55:10,584 - INFO - Loaded data with shape: X=(1091, 13214)
2024-11-10 16:55:10,585 - INFO - Initializing RSF model...


In [6]:
try:
    logger.info("Starting model training...")
    logger.info(f"Cross-validation config: {MODEL_CONFIG}")

    # Fit on the loaded data. MODEL_CONFIG is unpacked into keyword arguments,
    # so its keys must match parameter names accepted by RSFModel.fit.
    rsf.fit(X=X, y=y, data_container=data_container, **MODEL_CONFIG)

    logger.info("Model training completed successfully!")

except Exception as exc:
    # Log and re-raise. A manual kernel interrupt (KeyboardInterrupt) is not
    # an Exception subclass and therefore bypasses this handler.
    logger.error(f"Error during training: {exc!s}")
    raise

2024-11-10 16:55:18,072 - INFO - Starting model training...
2024-11-10 16:55:18,073 - INFO - Cross-validation config: {'params_cv': {'n_estimators': [100, 200], 'min_samples_split': [5, 10], 'min_samples_leaf': [3, 5]}, 'use_cohort_cv': False, 'n_splits_inner': 5}
2024-11-10 16:55:18,076 - INFO - Starting RSF training...
2024-11-10 16:55:18,077 - INFO - Parameter grid: {'n_estimators': [100, 200], 'min_samples_split': [5, 10], 'min_samples_leaf': [3, 5]}
2024-11-10 16:55:18,082 - INFO - Starting model training...
2024-11-10 16:55:18,084 - INFO - Input data shape: X=(1091, 13214)
2024-11-10 16:55:18,087 - INFO - Starting nested cross-validation...
2024-11-10 16:55:18,089 - INFO - Data shape: X=(1091, 13214), groups=9 unique
2024-11-10 16:55:18,092 - INFO - 
Outer fold 1
2024-11-10 16:55:18,188 - INFO - Test cohort: Atlanta_2014_Long
2024-11-10 16:55:18,190 - INFO - Starting inner grid search with 8 parameter combinations


KeyboardInterrupt: 

In [None]:
try:
    # Feature importance: requested from the model (labelled with the column
    # names of X) and written to RESULTS_DIR as CSV.
    logger.info("Calculating feature importance...")
    importance_df = rsf.get_feature_importance(feature_names=X.columns)
    importance_file = os.path.join(RESULTS_DIR, 'feature_importance.csv')
    importance_df.to_csv(importance_file)
    logger.info(f"Saved feature importance to {importance_file}")

    # Show the first ten rows of the importance table.
    # NOTE(review): "top 10" assumes get_feature_importance returns rows
    # already sorted by importance — confirm in RSFModel.
    print("\nTop 10 most important features:")
    print(importance_df.head(10))

    # Persist the model into MODEL_DIR under the name "rsf_model".
    logger.info("Saving model...")
    rsf.save(MODEL_DIR, "rsf_model")
    logger.info(f"Model saved to {MODEL_DIR}")

except Exception as exc:
    # Log the failure, then propagate it so the cell is marked as failed.
    logger.error(f"Error during final steps: {exc!s}")
    raise

In [None]:
# Collect the cross-validation results stored on the model into a table.
logger.info("\nCross-validation results:")
cv_results = pd.DataFrame(data=rsf.cv_results_)

# Headline numbers: average of the 'mean_score' column and average of the
# 'std_score' column, both rounded to three decimals.
print("\nOverall CV Performance:")
print(
    f"Mean C-index: {cv_results['mean_score'].mean():.3f} "
    f"± {cv_results['std_score'].mean():.3f}"
)

# Contents of the 'best_params' column.
print("\nBest parameters found:")
print(cv_results['best_params'])

# Write the full table to RESULTS_DIR for later inspection.
cv_results_file = os.path.join(RESULTS_DIR, 'cv_results.csv')
cv_results.to_csv(path_or_buf=cv_results_file)
logger.info(f"Saved detailed CV results to {cv_results_file}")
    
