## Random Survival Forest Training
Dieses Notebook demonstriert das Training eines Random-Survival-Forest-Modells mit verschiedenen Optionen:
- Verschiedene Input-Typen (Kohorten, merged data)
- Mit/ohne PCA
- Grid/Random Search
- Verschiedene Cross-Validation Strategien


In [6]:
import os
import sys
import pandas as pd
import numpy as np
from pathlib import Path
from sklearn.preprocessing import StandardScaler
from sksurv.ensemble import RandomSurvivalForest
import logging

In [7]:
# Get absolute path to project root
def find_project_root(current_path, marker_file='requirements.txt'):
    """Return the closest directory at or above ``current_path`` that holds ``marker_file``.

    The search starts at the resolved ``current_path`` and walks upwards
    through every ancestor directory, including the filesystem root.

    Args:
        current_path: Path (string or path-like) where the search starts.
        marker_file: Name of the file whose presence marks the project root.

    Returns:
        The matching directory as a string.

    Raises:
        FileNotFoundError: If neither the start directory nor any of its
            ancestors contains ``marker_file``.
    """
    start = Path(current_path).resolve()
    # ``parents`` ends with the filesystem root, so a marker placed there is
    # found as well instead of being skipped.
    for candidate in (start, *start.parents):
        if (candidate / marker_file).exists():
            return str(candidate)
    raise FileNotFoundError(f"Could not find project root with marker file {marker_file}")

# Locate the project root; if no marker file is found, fall back to the
# directory two levels above the current working directory and warn.
try:
    PROJECT_ROOT = find_project_root(os.getcwd())
except FileNotFoundError:
    print("Warning: Could not find project root automatically.")
    PROJECT_ROOT = os.path.dirname(os.path.dirname(os.getcwd()))

# Put the project root on the import path exactly once so that project-local
# packages can be imported.
if PROJECT_ROOT not in sys.path:
    sys.path.append(PROJECT_ROOT)

# Output directories are created below the current working directory.
MODEL_DIR, RESULTS_DIR = (
    os.path.join(os.getcwd(), subdir) for subdir in ('model', 'results')
)
os.makedirs(MODEL_DIR, exist_ok=True)
os.makedirs(RESULTS_DIR, exist_ok=True)





In [8]:
from preprocessing.data_loader import DataLoader
from preprocessing.dimension_reduction import PCADimensionReduction
from models.rsf_model import RSFModel
from utils.evaluation import cindex_score
from utils.visualization import plot_survival_curves, plot_cv_results

## Setup und Konfiguration

In [9]:


# Logging: INFO level, each record prefixed with timestamp and level name
logging.basicConfig(
    format='%(asctime)s - %(levelname)s - %(message)s',
    level=logging.INFO,
)
logger = logging.getLogger(__name__)

# Run configuration for data selection, cross-validation and the RSF grid
CONFIG = {
    # Data options
    'USE_COHORTS': True,          # True for cohort-wise CV
    'USE_PCA': False,             # True for PCA dimensionality reduction
    'GENE_TYPE': 'intersection',  # 'intersection', 'common_genes', or 'all_genes'
    'USE_IMPUTED': True,          # True for imputed data

    # CV options
    'USE_COHORT_CV_INNER': True,  # True for cohort-wise inner CV
    'N_SPLITS_INNER': 5,          # Number of inner CV splits when not cohort-based
    'USE_PARALLEL': True,         # Parallel processing

    # RSF parameters (candidate values for the search grid)
    'N_ESTIMATORS': [10, 20],
    'MAX_DEPTH': [3, 5],
    'MIN_SAMPLES_SPLIT': [5, 10],
    'MIN_SAMPLES_LEAF': [3, 5],
}

# Persist the configuration next to the results so a run can be traced back
# to its settings.
config_path = os.path.join(RESULTS_DIR, 'config.csv')
pd.Series(CONFIG).to_csv(config_path)
logger.info(f"Saved configuration to {config_path}")

# Create the data loader for the project directory
logger.info("Initializing data loader...")
loader = DataLoader(PROJECT_ROOT)

# Load the merged data and build the survival target. Any failure is logged
# and then re-raised so the notebook stops at this cell.
try:
    logger.info("Loading data...")
    X, pdata = loader.get_merged_data(
        gene_type=CONFIG['GENE_TYPE'],
        use_imputed=CONFIG['USE_IMPUTED'],
    )

    logger.info("Preparing survival data...")
    y = loader.prepare_survival_data(pdata)

    # Cohort label of a sample = part of its index entry before the first '.'
    groups = (
        np.array([sample_id.split('.')[0] for sample_id in X.index])
        if CONFIG['USE_COHORTS']
        else None
    )

    # Basic shape / dtype overview
    logger.info("\nData shapes:")
    logger.info(f"X: {X.shape}")
    logger.info(f"y: {y.shape}")
    logger.info(f"y dtype: {y.dtype}")

    # Sanity checks on the structured survival array; the event indicator is
    # read from 'status' when that field exists, otherwise from 'event'.
    event_field = 'status' if 'status' in y.dtype.names else 'event'
    logger.info("\nSurvival data validation:")
    logger.info(f"Field names: {y.dtype.names}")
    logger.info(f"Event field: {event_field}")
    logger.info(f"Number of events: {y[event_field].sum()}")
    logger.info(f"Time range: [{y['time'].min():.1f}, {y['time'].max():.1f}]")

    if CONFIG['USE_COHORTS']:
        logger.info("\nCohort distribution:")
        logger.info(pd.Series(groups).value_counts())

except Exception as e:
    logger.error(f"Error loading/preparing data: {str(e)}")
    raise

# Optional PCA: replace X by its principal-component representation and
# store the fitted transform in MODEL_DIR.
if CONFIG['USE_PCA']:
    logger.info("\nPerforming PCA...")
    # Remember the original feature count before X is overwritten; otherwise
    # the log below would report the reduced dimension as the starting one.
    n_features_before = X.shape[1]
    pca = PCADimensionReduction(variance_threshold=0.95)
    # NOTE(review): the PCA is fitted on the complete data set before any
    # cross-validation split, so held-out samples influence the projection.
    # Consider fitting it inside the CV pipeline instead - TODO confirm.
    X = pca.fit_transform(X)
    pca.save(os.path.join(MODEL_DIR, 'pca_transform.pkl'))
    logger.info(f"Reduced dimensions from {n_features_before} to {pca.n_components} components")
    logger.info(f"Explained variance ratio: {pca.explained_variance_ratio_.sum():.3f}")

# Assemble the pipeline steps: standardise the features, then fit the forest
logger.info("\nSetting up model pipeline...")
base_rsf = RandomSurvivalForest(random_state=42, n_estimators=10)
pipeline_steps = [('scaler', StandardScaler()), ('rsf', base_rsf)]

# Search grid keyed by '<step>__<parameter>' (the format sklearn's
# GridSearchCV expects for pipeline parameters)
param_grid = {
    'rsf__n_estimators': CONFIG['N_ESTIMATORS'],
    'rsf__max_depth': CONFIG['MAX_DEPTH'],
    'rsf__min_samples_split': CONFIG['MIN_SAMPLES_SPLIT'],
    'rsf__min_samples_leaf': CONFIG['MIN_SAMPLES_LEAF'],
}

# Fail early on a malformed grid: it must be a dict whose values are
# sequences (list, tuple or ndarray) of candidate settings.
if not isinstance(param_grid, dict):
    raise ValueError("param_grid must be a dictionary")
for param_name, param_values in param_grid.items():
    if isinstance(param_values, (list, tuple, np.ndarray)):
        continue
    raise ValueError(f"Values for parameter {param_name} must be a list")

# Report the grid and the number of combinations it spans
logger.info("\nParameter grid:")
for param, values in param_grid.items():
    logger.info(f"{param}: {values}")
logger.info(f"Total combinations: {np.prod(list(map(len, param_grid.values())))}")

# Create the model wrapper; the actual fit happens in the next cell
logger.info("\nTraining model...")
rsf = RSFModel()


2024-11-10 06:38:02,872 - INFO - Saved configuration to /Users/jonasschernich/Library/Mobile Documents/com~apple~CloudDocs/Uni/Master/9. Semester/Consulting/Organization/PCaPrognostics/models/rsf/results/config.csv
2024-11-10 06:38:02,875 - INFO - Initializing data loader...
2024-11-10 06:39:46,600 - INFO - Loading data...
2024-11-10 06:39:48,992 - INFO - Preparing survival data...
2024-11-10 06:39:49,019 - INFO - 
Data shapes:
2024-11-10 06:39:49,020 - INFO - X: (1091, 13214)
2024-11-10 06:39:49,020 - INFO - y: (1091,)
2024-11-10 06:39:49,021 - INFO - y dtype: [('status', '?'), ('time', '<f8')]
2024-11-10 06:39:49,022 - INFO - 
Survival data validation:
2024-11-10 06:39:49,024 - INFO - Field names: ('status', 'time')
2024-11-10 06:39:49,025 - INFO - Event field: status
2024-11-10 06:39:49,026 - INFO - Number of events: 559
2024-11-10 06:39:49,027 - INFO - Time range: [0.0, 120.0]
2024-11-10 06:39:49,028 - INFO - 
Cohort distribution:
2024-11-10 06:39:49,036 - INFO - Belfast_2018_Jain 

In [10]:
# Fit the model with the configured CV options; errors are logged and then
# re-raised so the failure stays visible in the notebook.
try:
    rsf.fit_model(
        X=X, y=y, groups=groups,
        fname='rsf_results', path=RESULTS_DIR,
        pipeline_steps=pipeline_steps, params_cv=param_grid,
        use_cohort_cv=CONFIG['USE_COHORT_CV_INNER'],
        n_splits_inner=CONFIG['N_SPLITS_INNER'],
        parallel=CONFIG['USE_PARALLEL'],
        refit=True,
    )
    logger.info("Model training completed successfully.")

except Exception as e:
    logger.error(f"\nError during model training: {str(e)}")
    raise

# Summarise the cross-validation outcome when the model exposes one
if hasattr(rsf, 'cv_results_'):
    logger.info("\nTraining Results:")
    cv_summary = rsf.cv_results_
    logger.info(f"Mean CV Score: {cv_summary['mean_score']:.3f} ± {cv_summary['std_score']:.3f}")
    if 'best_params' in cv_summary:
        logger.info("\nBest Parameters:")
        for param, value in cv_summary['best_params'].items():
            logger.info(f"{param}: {value}")

logger.info("\nTraining completed!")


2024-11-10 06:40:18,345 - INFO - Starting nested cross-validation...
2024-11-10 06:40:18,349 - INFO - Data shape: X=(1091, 13214), groups=9 unique
2024-11-10 06:40:18,353 - INFO - 
Outer fold 1
2024-11-10 06:40:18,435 - INFO - Test cohort: Atlanta_2014_Long
2024-11-10 06:40:18,436 - INFO - Starting inner grid search with 16 parameter combinations
2024-11-10 06:41:30,466 - INFO - New best score: 0.504 with params: {'max_depth': 3, 'min_samples_leaf': 3, 'min_samples_split': 5, 'n_estimators': 10}


KeyboardInterrupt: 

### Train Model
