In [1]:
import os
import sys
import pandas as pd
import numpy as np
from pathlib import Path
import logging
import torch


# Setup paths: the project root is two directory levels above the current
# working directory; put it on sys.path so the project-local packages
# imported further down can be resolved.
_cwd = os.getcwd()
PROJECT_ROOT = os.path.dirname(os.path.dirname(_cwd))
if PROJECT_ROOT not in sys.path:
    sys.path.append(PROJECT_ROOT)

# Setup directories: 'model' and 'results' live next to the notebook (in the
# working directory) and are created on first run if they do not exist yet.
MODEL_DIR, RESULTS_DIR = (
    os.path.join(_cwd, subdir) for subdir in ('model', 'results')
)
for _out_dir in (MODEL_DIR, RESULTS_DIR):
    os.makedirs(_out_dir, exist_ok=True)

# Imports
from preprocessing.data_container import DataContainer
from models.deep_surv_model import DeepSurvModel
from utils.evaluation import cindex_score


In [2]:
# Setup logging: INFO-level messages formatted as
# "<timestamp> - <level> - <message>".
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(levelname)s - %(message)s'
)
logger = logging.getLogger(__name__)

# Seed the global RNGs once, up front, so that a Restart & Run All gives
# repeatable results (the training below uses a random validation split and
# dropout).
# NOTE(review): this assumes the project code draws from the global NumPy /
# PyTorch generators rather than from its own private ones - confirm.
SEED = 42
np.random.seed(SEED)
torch.manual_seed(SEED)

# Options handed to DataContainer; the meaning of each key is defined there.
DATA_CONFIG = {
    'use_pca': False,
    'gene_type': 'intersection',
    'use_imputed': True,
    'validation_split': 0.2,
    'use_cohorts': True
}

# Parameter grid for nested CV: 2 x 2 x 2 x 2 = 16 combinations in total.
PARAM_GRID = {
    'hidden_layers': [[32, 16], [64, 32]],
    'learning_rate': [0.001, 0.0001],
    'dropout': [0.3, 0.4],
    'batch_size': [32, 64]
}

# Build the project's data loader from the config above and read the feature
# matrix X together with the targets y.
data_container = DataContainer(
    DATA_CONFIG,
    project_root=PROJECT_ROOT,
)
X, y = data_container.load_data()







2024-11-10 12:09:43,301 - INFO - Loading data...
2024-11-10 12:10:35,978 - INFO - Loaded data: 1091 samples, 13214 features


In [3]:
# Nested-CV run. Known issue: this does not work yet - the inner grid search
# aborts with a KeyError while creating the random train/val split
# ("None of [Index([...])] are in the [columns]").
# NOTE(review): that message suggests integer row positions are being used as
# column labels somewhere in the model's split code - confirm in DeepSurvModel.
deep_surv = DeepSurvModel(use_nested_cv=True)
deep_surv.fit(X=X, y=y, data_container=data_container,
              param_grid=PARAM_GRID)  # hyperparameter grid to search


2024-11-10 12:10:53,091 - INFO - Using device: cpu
2024-11-10 12:10:53,092 - INFO - Starting nested CV for DeepSurv...
2024-11-10 12:10:53,094 - INFO - Starting nested cross-validation...
2024-11-10 12:10:53,095 - INFO - Data shape: X=(1091, 13214), groups=9 unique
2024-11-10 12:10:53,097 - INFO - 
Outer fold 1
2024-11-10 12:10:53,143 - INFO - Test cohort: Atlanta_2014_Long
2024-11-10 12:10:53,143 - INFO - Starting inner grid search with 16 parameter combinations
2024-11-10 12:10:53,184 - INFO - Using device: cpu
2024-11-10 12:10:53,186 - INFO - Starting model training...
2024-11-10 12:10:53,186 - INFO - Input data shape: X=(743, 13214)
2024-11-10 12:10:53,188 - INFO - Using random 20.0% validation split
2024-11-10 12:10:53,190 - ERROR - Error creating train/val split: "None of [Index([283,  44, 703,  97, 697, 187, 141, 404, 173, 461,\n       ...\n       520,  74, 176, 279, 513, 342, 127, 671,  86, 725],\n      dtype='int64', length=594)] are in the [columns]"
2024-11-10 12:10:53,191 -

KeyError: "None of [Index([283,  44, 703,  97, 697, 187, 141, 404, 173, 461,\n       ...\n       520,  74, 176, 279, 513, 342, 127, 671,  86, 725],\n      dtype='int64', length=594)] are in the [columns]"

In [4]:
# Single training run without nested CV: the hyperparameters are passed to
# fit() directly instead of being searched over a grid.
_direct_params = {
    'hidden_layers': [64, 32],
    'batch_size': 64,
    'learning_rate': 0.001,
    'n_epochs': 100,
    'early_stopping': True,
    'patience': 10,
    'dropout': 0.4,
}
deep_surv = DeepSurvModel(use_nested_cv=False)
deep_surv.fit(X=X, y=y, data_container=data_container, **_direct_params)

2024-11-10 12:11:09,282 - INFO - Using device: cpu
2024-11-10 12:11:09,283 - INFO - Starting model training...
2024-11-10 12:11:09,283 - INFO - Input data shape: X=(1091, 13214)
2024-11-10 12:11:09,286 - INFO - Using cohort CamCap_2016_Ross_Adams for validation
2024-11-10 12:11:09,353 - INFO - Training samples: 979, Validation samples: 112
2024-11-10 12:11:11,137 - INFO - Starting training for 100 epochs...
  uncensored_likelihood = risk_scores.T - log_risk
2024-11-10 12:11:11,389 - INFO - Epoch 1/100: Train Loss = 103.4071, Val Loss = 248.2084
2024-11-10 12:11:11,536 - INFO - Epoch 2/100: Train Loss = 99.8294, Val Loss = 246.9686
2024-11-10 12:11:11,695 - INFO - Epoch 3/100: Train Loss = 97.2198, Val Loss = 246.7682
2024-11-10 12:11:11,937 - INFO - Epoch 4/100: Train Loss = 96.4392, Val Loss = 246.3821
2024-11-10 12:11:12,170 - INFO - Epoch 5/100: Train Loss = 95.4574, Val Loss = 244.9789
2024-11-10 12:11:12,406 - INFO - Epoch 6/100: Train Loss = 95.6403, Val Loss = 245.0984
2024-11-1