## Random Survival Forest Training
This notebook demonstrates training a Random Survival Forest model with several options:
- Different input types (individual cohorts, merged data)
- With or without PCA
- Grid or random search
- Different cross-validation strategies
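The cross-validation strategies differ mainly in how samples are assigned to test folds. The training log further down reports one held-out test cohort per outer fold, which matches a leave-one-cohort-out scheme. Below is a minimal standard-library sketch that contrasts this scheme with shuffled K-fold. The function names and toy cohort labels are hypothetical and are not part of the project's code.

```python
import random
from collections import defaultdict


def leave_one_cohort_out(cohorts):
    """Yield (cohort, train_idx, test_idx); every cohort is held out exactly once."""
    by_cohort = defaultdict(list)
    for idx, cohort in enumerate(cohorts):
        by_cohort[cohort].append(idx)
    for cohort in sorted(by_cohort):
        test_idx = by_cohort[cohort]
        held_out = set(test_idx)
        train_idx = [i for i in range(len(cohorts)) if i not in held_out]
        yield cohort, train_idx, test_idx


def k_fold(n_samples, n_splits, seed=0):
    """Yield (train_idx, test_idx) for shuffled K-fold; cohort labels are ignored."""
    indices = list(range(n_samples))
    random.Random(seed).shuffle(indices)
    for k in range(n_splits):
        test_idx = sorted(indices[k::n_splits])
        held_out = set(test_idx)
        train_idx = [i for i in range(n_samples) if i not in held_out]
        yield train_idx, test_idx


# Hypothetical toy labels, not the notebook's data
cohorts = ["A", "A", "B", "C", "B", "C", "A"]
for cohort, train_idx, test_idx in leave_one_cohort_out(cohorts):
    print(cohort, test_idx)
```

Holding out whole cohorts estimates how well the model transfers to an unseen cohort. K-fold mixes cohorts across folds and therefore tends to give more optimistic scores when cohorts differ systematically.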


## Setup and Configuration

### Train Model
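The scores in the training log are presumably concordance indices computed by the project's `cindex_score` helper, whose implementation is not shown here. Examples include the inner-CV best score of 0.516 and the outer test score of 0.442 for fold 1. For orientation, random risk rankings score around 0.5 and a perfect ranking scores 1.0. The following is a minimal pure-Python sketch of Harrell's C-index. It is a simplified variant that skips pairs with tied survival times and is not necessarily how `cindex_score` works.

```python
def harrell_cindex(times, events, risk_scores):
    """Harrell's concordance index for right-censored survival data.

    A pair (i, j) is comparable when subject i had an observed event and
    times[i] < times[j]. It is concordant when i received the higher risk
    score; tied risk scores count as 0.5. Pairs with tied times are skipped.
    """
    concordant = 0.0
    comparable = 0
    n = len(times)
    for i in range(n):
        if not events[i]:
            continue  # a censored subject cannot anchor a comparable pair
        for j in range(n):
            if times[i] < times[j]:
                comparable += 1
                if risk_scores[i] > risk_scores[j]:
                    concordant += 1.0
                elif risk_scores[i] == risk_scores[j]:
                    concordant += 0.5
    if comparable == 0:
        raise ValueError("no comparable pairs")
    return concordant / comparable


# Toy data: higher risk should mean an earlier event
times = [2, 4, 6, 8]
events = [True, True, False, True]
print(harrell_cindex(times, events, [0.9, 0.7, 0.3, 0.1]))
```

With these toy values the ranking is perfect, so the printed value is 1.0. Reversing the risk scores gives 0.0.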


In [4]:
import os
import sys
import pandas as pd
import numpy as np
from pathlib import Path
import logging

# Add the project root (two levels above the current working directory) to sys.path
PROJECT_ROOT = os.path.dirname(os.path.dirname(os.getcwd()))
if PROJECT_ROOT not in sys.path:
    sys.path.append(PROJECT_ROOT)

# Create output directories for the trained model and results (if missing)
MODEL_DIR = os.path.join(os.getcwd(), 'model')
RESULTS_DIR = os.path.join(os.getcwd(), 'results')
os.makedirs(MODEL_DIR, exist_ok=True)
os.makedirs(RESULTS_DIR, exist_ok=True)

# Project-internal modules (importable because PROJECT_ROOT is on sys.path)
from preprocessing.data_container import DataContainer
from models.rsf_model import RSFModel
from utils.evaluation import cindex_score

# Set up logging
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(levelname)s - %(message)s'
)
logger = logging.getLogger(__name__)

# Data configuration
DATA_CONFIG = {
    'use_pca': False,
    'gene_type': 'intersection',
    'use_imputed': True,
    'use_cohorts': True
}

# Model/CV configuration
MODEL_CONFIG = {
    'params_cv': {
        'rsf__n_estimators': [5, 10],
        'rsf__min_samples_split': [3, 5],
        'rsf__min_samples_leaf': [2, 4]
    },
    'use_cohort_cv': False,
    'n_splits_inner': 3
}
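The grid in `MODEL_CONFIG['params_cv']` has 2 × 2 × 2 = 8 combinations. That is the count the log reports for each inner grid search. The `rsf__` prefix follows the scikit-learn `Pipeline` convention `<step>__<parameter>`. That `RSFModel` wraps a pipeline step named `rsf` is an assumption inferred from the prefix, and the log prints parameter names both with and without it. The stdlib sketch below expands the grid and strips the prefix. The helper names are hypothetical.

```python
from itertools import product

# Same grid as MODEL_CONFIG['params_cv'] above
params_cv = {
    'rsf__n_estimators': [5, 10],
    'rsf__min_samples_split': [3, 5],
    'rsf__min_samples_leaf': [2, 4],
}


def expand_grid(grid):
    """Return every parameter combination as a dict (Cartesian product over sorted keys)."""
    keys = sorted(grid)
    return [dict(zip(keys, values)) for values in product(*(grid[k] for k in keys))]


def strip_step_prefix(params, step='rsf'):
    """Drop the '<step>__' prefix, leaving the estimator's own parameter names."""
    prefix = step + '__'
    return {k[len(prefix):] if k.startswith(prefix) else k: v for k, v in params.items()}


combos = expand_grid(params_cv)
print(len(combos))                  # 8
print(strip_step_prefix(combos[0]))
# {'min_samples_leaf': 2, 'min_samples_split': 3, 'n_estimators': 5}
```

The first combination is the one the log reports as best for outer fold 1. Each extra value in any list multiplies the number of inner fits, so with 13,214 features even small grids are costly.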



In [None]:
try:
    # Create DataContainer and load data
    data_container = DataContainer(DATA_CONFIG, project_root=PROJECT_ROOT)
    X, y = data_container.load_data()

    # Initialize and train model
    rsf = RSFModel()
    rsf.fit(
        X=X,
        y=y,
        data_container=data_container,
        **MODEL_CONFIG
    )
    
    # Get and save feature importance
    importance_df = rsf.get_feature_importance(feature_names=X.columns)
    importance_df.to_csv(os.path.join(RESULTS_DIR, 'feature_importance.csv'))
    
    # Save model
    rsf.save(MODEL_DIR, "rsf_model")
    
    logger.info("Training completed successfully!")
    
except Exception as e:
    logger.error(f"Error during training: {e}")
    raise

2024-11-10 12:58:42,542 - INFO - Loading data...
2024-11-10 12:59:24,408 - INFO - Loaded data: 1091 samples, 13214 features
2024-11-10 12:59:26,270 - INFO - Starting model training...
2024-11-10 12:59:26,271 - INFO - Input data shape: X=(1091, 13214)
2024-11-10 12:59:26,271 - INFO - Starting nested cross-validation...
2024-11-10 12:59:26,272 - INFO - Data shape: X=(1091, 13214), groups=9 unique
2024-11-10 12:59:26,274 - INFO - 
Outer fold 1
2024-11-10 12:59:26,311 - INFO - Test cohort: Atlanta_2014_Long
2024-11-10 12:59:26,311 - INFO - Starting inner grid search with 8 parameter combinations
2024-11-10 13:01:08,109 - INFO - New best score: 0.516 with params: {'min_samples_leaf': 2, 'min_samples_split': 3, 'n_estimators': 5}
2024-11-10 13:14:59,270 - INFO - Best parameters: {'rsf__min_samples_leaf': 2, 'rsf__min_samples_split': 3, 'rsf__n_estimators': 5}
2024-11-10 13:14:59,271 - INFO - Test score: 0.442
2024-11-10 13:14:59,272 - INFO - 
Outer fold 2
2024-11-10 13:14:59,346 - INFO - Tes