# BraTS2020 Survival Prediction: Comprehensive Feature Analysis
Binary classification of survival time (above/below median) using various feature engineering techniques combined with classical ML models.

Dataset: BraTS2020 with pre-computed autoencoder embeddings (3×14×14×12 per patient)

## 1. Environment Setup

Import the libraries required for data processing, machine learning algorithms, and evaluation.

In [None]:
import pandas as pd
import numpy as np
from datetime import datetime

from sklearn.model_selection import GridSearchCV, StratifiedKFold, train_test_split
from sklearn.preprocessing import StandardScaler
from sklearn.pipeline import Pipeline
from sklearn.decomposition import PCA

# Models
from sklearn.linear_model import LogisticRegression
from sklearn.svm import SVC
from sklearn.ensemble import RandomForestClassifier, GradientBoostingClassifier

# Metrics
from sklearn.metrics import (
    f1_score,
    classification_report,
    precision_score,
    recall_score,
    accuracy_score
)

# Configuration
np.random.seed(42)

## 2. Data Loading and Preparation

Load the pre-processed embeddings and create binary survival labels based on median survival time.

In [None]:
# Read the pickled DataFrame holding one row per patient.
# NOTE: unpickling executes arbitrary code, so only load files from a trusted source.
data_path = "./embeddings_aekl.pkl"
df = pd.read_pickle(data_path)

# Quick sanity report on what was loaded.
survival_days = df['survival_days']
first_embedding = df['embedding'].iloc[0]

print(f"Dataset loaded: {len(df)} patients")
print(f"Embedding shape: {first_embedding.shape}")
print(f"Survival range: {survival_days.min():.0f} - {survival_days.max():.0f} days")
print(f"Median survival: {survival_days.median():.0f} days")

Dataset loaded: 235 patients
Embedding shape: (3, 14, 14, 12)
Survival range: 5 - 1767 days
Median survival: 370 days


In [None]:
# Preview the first rows to sanity-check the loaded columns.
df.head()

Unnamed: 0,sample_id,embedding,age,survival_days
0,BraTS20_Training_001,[[[[0.2374728 0.11662745 0.18001464 0.2038001...,60.463,289
1,BraTS20_Training_002,[[[[0.38166767 0.34106487 0.38072276 0.3726496...,52.263,616
2,BraTS20_Training_003,[[[[0.22839475 0.10543397 0.1688725 0.1925266...,54.301,464
3,BraTS20_Training_004,[[[[0.3078203 0.22849745 0.27818102 0.2623945...,39.068,788
4,BraTS20_Training_005,[[[[0.18848613 0.03515071 0.11100617 0.1380027...,68.493,465


In [None]:
# Binarize survival at the cohort median: class 0 = at or below the median
# (short survival), class 1 = above the median (long survival).
median_survival = df["survival_days"].median()
labels = [0, 1]
bins = [
    df["survival_days"].min() - 1,  # lower edge sits below the minimum so it is included
    median_survival,
    df["survival_days"].max(),
]

df['survival_class'] = pd.cut(
    df['survival_days'], bins=bins, labels=labels, include_lowest=True
)

# Report how many patients fall into each class.
class_counts = df['survival_class'].value_counts().sort_index()
n_patients = len(df)
class_descriptions = {
    0: f"<= {median_survival:.0f} days",
    1: f"> {median_survival:.0f} days",
}

print("\nClass distribution:")
for class_id, description in class_descriptions.items():
    n_in_class = class_counts[class_id]
    print(f"Class {class_id} ({description}): {n_in_class} patients ({n_in_class/n_patients*100:.1f}%)")


Class distribution:
Class 0 (<= 370 days): 118 patients (50.2%)
Class 1 (> 370 days): 117 patients (49.8%)


## 3. Feature Engineering

Extract multiple feature representations from the (3x14x14x12) embeddings:

1. Channel-wise statistics (9-dim): Mean, std, max for each channel
2. Mean pooling (3-dim): Average across spatial dimensions
3. Spatial projections (variable-dim): Mean pooling along different axes
4. Flattened embeddings (7056-dim): Complete embedding as vector

*Check Report.pdf for a more detailed explanation.*

### 3.1 Helper functions

In [None]:
def compute_channel_statistics(embedding):
    """
    Summarize every channel of an embedding with its mean, std and max.

    The embedding is iterated along its first axis; each slice contributes
    three values, in the order (mean, std, max), computed over all of its
    elements.

    Returns: 1D array [mean_ch0, std_ch0, max_ch0, mean_ch1, ...] with
    3 values per channel (9 values for a 3-channel embedding).
    """
    reducers = (np.mean, np.std, np.max)
    return np.array([
        reduce_fn(channel_data)
        for channel_data in embedding
        for reduce_fn in reducers
    ])


def compute_spatial_projection(embedding, axes):
    """
    Average an embedding over the given axes and return the result as a 1D vector.

    Args:
        embedding: 4D array (C, H, W, D)
        axes: tuple of axes to average over

    Returns:
        Flattened 1D array of the values remaining after averaging.

    Examples:
        axes=(1, 2) -> Axial projection (average over height & width)
        axes=(1, 3) -> Sagittal projection (average over height & depth)
        axes=(2, 3) -> Coronal projection (average over width & depth)
    """
    pooled = np.mean(embedding, axis=axes)
    return pooled.flatten()

### 3.2 Extract raw features (no Scaling/PCA yet)

In [None]:
# One scalar column per channel: the mean over every element of that channel.
for channel in range(3):
    df[f'channel_{channel}_mean'] = df['embedding'].apply(
        lambda emb, ch=channel: np.mean(emb[ch]).item()
    )

# Per-channel mean/std/max stacked into a single vector per patient.
df['channel_statistics'] = df['embedding'].apply(compute_channel_statistics)

# Per-channel mean as one vector (same values as the three channel_*_mean columns).
df['mean_pooled'] = df['embedding'].apply(lambda emb: emb.mean(axis=(1, 2, 3)))

n_statistics = len(df['channel_statistics'].iloc[0])
n_pooled = len(df['mean_pooled'].iloc[0])

print("Basic features extracted")
print("-" * 60)
print("- Channel means: 1 feature each")
print(f"- Channel statistics: {n_statistics} features")
print(f"- Mean pooled: {n_pooled} features")

Basic features extracted
------------------------------------------------------------
- Channel means: 1 feature each
- Channel statistics: 9 features
- Mean pooled: 3 features


### 3.3 Extract spatial projections (Raw, no PCA)

In [None]:
# Map each projection name to the pair of embedding axes that gets averaged away.
projection_config = {
    'axial': (1, 2),      # Average over height & width -> depth profile
    'sagittal': (1, 3),   # Average over height & depth -> width profile
    'coronal': (2, 3)     # Average over width & depth -> height profile
}

print("Extracting spatial projections (raw features):")
print("-" * 60)

# Raw (un-scaled, un-reduced) projections; PCA is applied later, after the split.
raw_projections = {}
for view_name, pooled_axes in projection_config.items():
    stacked = np.stack([
        compute_spatial_projection(volume, pooled_axes)
        for volume in df['embedding']
    ])
    raw_projections[view_name] = stacked
    print(f"{view_name:10s} (axes {pooled_axes}): shape {stacked.shape}")

Extracting spatial projections (raw features):
------------------------------------------------------------
axial      (axes (1, 2)): shape (235, 36)
sagittal   (axes (1, 3)): shape (235, 42)
coronal    (axes (2, 3)): shape (235, 42)


### 3.4 Extract flattened embeddings (Raw)

In [None]:
# Unroll each patient's full embedding into a single row vector
# (3 × 14 × 14 × 12 = 7,056 values per patient).
flattened_embeddings = np.stack([volume.flatten() for volume in df['embedding']])

print(f"Flattened embedding shape: {flattened_embeddings.shape}")

Flattened embedding shape: (235, 7056)


### 3.5 Prepare Age feature (Raw)

In [None]:
# Age as an (n_patients, 1) column so it can be scaled and stacked like any other feature.
age_series = df['age']
age_raw = age_series.values.reshape(-1, 1)

print("Age statistics:")
print(f"Mean: {age_series.mean():.1f} years")
print(f"Range: {age_series.min():.0f} - {age_series.max():.0f} years")

Age statistics:
Mean: 61.2 years
Range: 19 - 87 years


## 4. Feature set definitions

Define all raw feature combinations.

In [None]:
def _feature_set(data, pca_components=None):
    """Build one feature-set entry; PCA is requested iff `pca_components` is given."""
    return {
        'data': data,
        'needs_pca': pca_components is not None,
        'pca_components': pca_components,
    }


# Raw feature sets (no preprocessing applied yet). Scaling and PCA happen
# later, after the train/test split.
raw_feature_sets = {
    # Basic channel features (already low-dimensional, no PCA needed)
    **{
        f'channel_{channel}_mean': _feature_set(
            df[f'channel_{channel}_mean'].values.reshape(-1, 1)
        )
        for channel in range(3)
    },
    'channel_statistics': _feature_set(np.stack(df['channel_statistics'].values)),
    'mean_pooled': _feature_set(np.stack(df['mean_pooled'].values).reshape(-1, 3)),

    # Spatial projections: keep enough components for 90% of the variance
    # (threshold selected heuristically).
    'pca_axial': _feature_set(raw_projections['axial'], pca_components=0.90),
    'pca_sagittal': _feature_set(raw_projections['sagittal'], pca_components=0.90),
    'pca_coronal': _feature_set(raw_projections['coronal'], pca_components=0.90),

    # Full embedding reduced to a fixed number of components (heuristic).
    'pca_full_embedding': _feature_set(flattened_embeddings, pca_components=10),
}

print(f"{len(raw_feature_sets)} raw feature sets defined")
for set_name, spec in raw_feature_sets.items():
    components = spec['pca_components']
    if not spec['needs_pca']:
        suffix = ""
    elif components > 1:
        # Integer setting: a fixed output dimensionality.
        suffix = f" (PCA → {components}D)"
    else:
        # Fractional setting: a target share of explained variance.
        suffix = f" (PCA → {components})"

    print(f"  - {set_name}: {spec['data'].shape}{suffix}")

9 raw feature sets defined
  - channel_0_mean: (235, 1)
  - channel_1_mean: (235, 1)
  - channel_2_mean: (235, 1)
  - channel_statistics: (235, 9)
  - mean_pooled: (235, 3)
  - pca_axial: (235, 36) (PCA → 0.9)
  - pca_sagittal: (235, 42) (PCA → 0.9)
  - pca_coronal: (235, 42) (PCA → 0.9)
  - pca_full_embedding: (235, 7056) (PCA → 10D)


## 5. Model(s) configurations

Define machine learning models and their hyperparameter grids for systematic evaluation.

In [None]:
# 5-fold stratified CV with a fixed seed, shared by every grid search.
cv_strategy = StratifiedKFold(n_splits=5, shuffle=True, random_state=42)


def _as_pipeline(estimator):
    """Wrap an estimator in a one-step pipeline whose step is named 'classifier'."""
    return Pipeline([('classifier', estimator)])


# Each entry maps a display name to (pipeline, hyperparameter grid).
# Grid keys use the 'classifier__' prefix to address the pipeline step.
model_configs = {
    'Logistic Regression': (
        _as_pipeline(LogisticRegression(max_iter=5000, random_state=42)),
        {
            'classifier__C': [0.01, 0.1, 1, 10],
            'classifier__penalty': ['l2'],
        },
    ),
    'Linear SVM': (
        _as_pipeline(SVC(kernel='linear', random_state=42)),
        {'classifier__C': [0.01, 0.1, 1, 10]},
    ),
    'RBF SVM': (
        _as_pipeline(SVC(kernel='rbf', random_state=42)),
        {
            'classifier__C': [0.01, 0.1, 1, 10],
            'classifier__gamma': ['scale', 0.001, 0.01],
        },
    ),
    'Random Forest': (
        _as_pipeline(RandomForestClassifier(random_state=42)),
        {
            'classifier__n_estimators': [100, 200, 500],
            'classifier__max_depth': [None, 5, 10],
            'classifier__min_samples_leaf': [1, 3, 5],
        },
    ),
    'Gradient Boosting': (
        _as_pipeline(GradientBoostingClassifier(random_state=42)),
        {
            'classifier__n_estimators': [100, 200],
            'classifier__learning_rate': [0.05, 0.1],
            'classifier__max_depth': [1, 2],
        },
    ),
}

print(f"{len(model_configs)} models configured:")
for configured_name in model_configs:
    print(f"  - {configured_name}")

5 models configured:
  - Logistic Regression
  - Linear SVM
  - RBF SVM
  - Random Forest
  - Gradient Boosting


## 6. Helper Function for Preprocessing

This function fits the scaler (and, optionally, PCA) on the training data only, then applies the fitted transformers to both the training and the test data.

In [None]:
def preprocess_features(X_train, X_test, needs_pca=False, n_components=None):
    """
    Standardize features and optionally reduce them with PCA.

    Both transformers are fitted on the training features only and then
    applied to the test features.

    Args:
        X_train: Training features (raw)
        X_test: Test features (raw)
        needs_pca: Whether to apply PCA
        n_components: Value passed to PCA as `n_components`. PCA is skipped
            when this is None, even if needs_pca=True.

    Returns:
        X_train_processed, X_test_processed, variance_explained
        (variance_explained is None when PCA is not applied)
    """
    standardizer = StandardScaler()
    train_scaled = standardizer.fit_transform(X_train)
    test_scaled = standardizer.transform(X_test)

    # Without a PCA request (or without a component setting) scaling is all we do.
    if not needs_pca or n_components is None:
        return train_scaled, test_scaled, None

    reducer = PCA(n_components=n_components, random_state=42)
    train_reduced = reducer.fit_transform(train_scaled)
    test_reduced = reducer.transform(test_scaled)
    # Total share of training variance retained by the kept components.
    retained_variance = reducer.explained_variance_ratio_.sum()
    return train_reduced, test_reduced, retained_variance
# Random Forest and Gradient Boosting do not require feature scaling,
# but scaling features doesn't hurt and makes the code cleaner.

## 7. Evaluation

Evaluate all combinations of:
- Feature sets (9 raw feature sets, plus the age-only baseline)
- With/without age (2 variants)
- Machine learning models (5 types)

In [None]:
# One result dict is appended here per (feature set, age variant, model) run.
all_results = []
target_labels = df['survival_class'].values

# Progress bookkeeping: every feature set is run with and without age for
# each model, plus one age-only baseline run per model.
experiment_count = 0
n_models = len(model_configs)
total_experiments = n_models * (2 * len(raw_feature_sets) + 1)

print(f"Starting evaluation: {total_experiments} total experiments")
print("=" * 80)

Starting evaluation: 95 total experiments


### 7.1 Baseline: Age-only Model
*Note: Only the metrics for the best hyperparameter setting per model are printed.*

In [None]:
print("\n[BASELINE] Age-only Models")
print("-" * 80)

experiment_name = "age_only"

# Hold out 20% of the patients as a test set, stratified by survival class.
X_train_raw, X_test_raw, y_train, y_test = train_test_split(
    age_raw, target_labels,
    train_size=0.80,
    random_state=42,
    stratify=target_labels
)

# Standardize age using statistics fitted on the training split only.
X_train, X_test, _ = preprocess_features(
    X_train_raw, X_test_raw,
    needs_pca=False
)

# Evaluate each model
for model_name, (pipeline, param_grid) in model_configs.items():
    experiment_count += 1

    # Grid search with cross-validation (model selection uses macro F1)
    grid_search = GridSearchCV(
        pipeline,
        param_grid,
        cv=cv_strategy,
        scoring='f1_macro',
        n_jobs=-1
    )
    grid_search.fit(X_train, y_train)

    # Test set predictions from the best estimator found by the grid search
    y_pred = grid_search.best_estimator_.predict(X_test)

    # Per-class metrics (index 0 -> class 0, index 1 -> class 1)
    precision_per_class = precision_score(y_test, y_pred, average=None)
    recall_per_class = recall_score(y_test, y_pred, average=None)
    f1_per_class = f1_score(y_test, y_pred, average=None)

    result = {
        'experiment': experiment_name,
        'features': 'age',
        'includes_age': True,
        'model': model_name,
        'feature_dim': 1,
        'variance_explained': None,

        # Cross-validation performance of the best hyperparameter setting
        'cv_f1_mean': grid_search.best_score_,
        'cv_f1_std': grid_search.cv_results_['std_test_score'][grid_search.best_index_],

        # Test set metrics (macro-averaged)
        'test_f1_macro': f1_score(y_test, y_pred, average='macro'),
        'test_f1_weighted': f1_score(y_test, y_pred, average='weighted'),
        'test_precision_macro': precision_score(y_test, y_pred, average='macro'),
        'test_recall_macro': recall_score(y_test, y_pred, average='macro'),
        'test_accuracy': accuracy_score(y_test, y_pred),

        # Per-class metrics
        'test_precision_class0': precision_per_class[0],
        'test_recall_class0': recall_per_class[0],
        'test_f1_class0': f1_per_class[0],
        'test_precision_class1': precision_per_class[1],
        'test_recall_class1': recall_per_class[1],
        'test_f1_class1': f1_per_class[1],

        # Best hyperparameters (listed once; the key was previously duplicated)
        'best_params': grid_search.best_params_
    }

    all_results.append(result)

    print(f"  [{experiment_count:2d}/{total_experiments}] {model_name:20s} -> "
          f"CV: {result['cv_f1_mean']:.4f}±{result['cv_f1_std']:.4f} | "
          f"Test F1: {result['test_f1_macro']:.4f} | "
          f"Acc: {result['test_accuracy']:.4f}")

print("\nBaseline evaluation complete")


[BASELINE] Age-only Models
--------------------------------------------------------------------------------
  [ 1/95] Logistic Regression  -> CV: 0.6163±0.1597 | Test F1: 0.5743 | Acc: 0.5745
  [ 2/95] Linear SVM           -> CV: 0.6216±0.1432 | Test F1: 0.5743 | Acc: 0.5745
  [ 3/95] RBF SVM              -> CV: 0.6259±0.1082 | Test F1: 0.5648 | Acc: 0.5745
  [ 4/95] Random Forest        -> CV: 0.5666±0.0932 | Test F1: 0.5957 | Acc: 0.5957
  [ 5/95] Gradient Boosting    -> CV: 0.6193±0.0926 | Test F1: 0.5648 | Acc: 0.5745

Baseline evaluation complete


### 7.2 Main Evaluation Loop

*Note: Only the metrics for the best hyperparameter setting per model are printed.*

In [None]:
print("\n" + "=" * 80)
print("MAIN EVALUATION: Feature sets with/without Age")
print("=" * 80)

for feature_name, feature_config in raw_feature_sets.items():
    # Extract raw feature data
    X_raw = feature_config['data']
    needs_pca = feature_config['needs_pca']
    n_components = feature_config['pca_components']

    # The split and the scaler/PCA fit do not depend on whether age is
    # appended, so they are computed once per feature set and shared by both
    # variants below. The fixed random_state keeps the split deterministic.
    X_train_raw, X_test_raw, y_train, y_test, train_idx, test_idx = train_test_split(
        X_raw,
        target_labels,
        np.arange(len(X_raw)),  # Row indices, used to retrieve the matching ages
        train_size=0.80,
        random_state=42,
        stratify=target_labels
    )

    # NOTE(review): scaler/PCA are fitted on the whole training split before
    # GridSearchCV runs, so each CV validation fold has contributed to the
    # fitted statistics. CV scores may be slightly optimistic; moving these
    # steps into the Pipeline would avoid that.
    X_train_processed, X_test_processed, var_explained = preprocess_features(
        X_train_raw, X_test_raw,
        needs_pca=needs_pca,
        n_components=n_components
    )

    # Test both with and without age
    for include_age in [False, True]:
        experiment_name = f"{feature_name}_with_age" if include_age else feature_name

        print(f"\n{experiment_name} - Raw shape: {X_raw.shape}")
        print("  " + "-" * 76)

        if include_age:
            # Select ages with the same row indices as the feature split
            age_train_raw = age_raw[train_idx]
            age_test_raw = age_raw[test_idx]

            age_train_scaled, age_test_scaled, _ = preprocess_features(
                age_train_raw, age_test_raw,
                needs_pca=False
            )

            # Age is appended after PCA so it is never mixed into the components
            X_train = np.hstack([X_train_processed, age_train_scaled])
            X_test = np.hstack([X_test_processed, age_test_scaled])
        else:
            X_train = X_train_processed
            X_test = X_test_processed

        final_dim = X_train.shape[1]
        # Compare against None explicitly: a truthiness test would also hide 0.0
        var_info = f" | Variance: {var_explained:.2%}" if var_explained is not None else ""
        print(f"  Processed shape: {X_train.shape}{var_info}")

        # Evaluate each model
        for model_name, (pipeline, param_grid) in model_configs.items():
            experiment_count += 1

            # Grid search with cross-validation (model selection uses macro F1)
            grid_search = GridSearchCV(
                pipeline,
                param_grid,
                cv=cv_strategy,
                scoring='f1_macro',
                n_jobs=-1
            )
            grid_search.fit(X_train, y_train)

            # Test set evaluation with the best estimator found by the search
            y_pred = grid_search.best_estimator_.predict(X_test)

            # Per-class metrics (index 0 -> class 0, index 1 -> class 1)
            precision_per_class = precision_score(y_test, y_pred, average=None)
            recall_per_class = recall_score(y_test, y_pred, average=None)
            f1_per_class = f1_score(y_test, y_pred, average=None)

            result = {
                'experiment': experiment_name,
                'features': feature_name,
                'includes_age': include_age,
                'model': model_name,
                'feature_dim': final_dim,
                'variance_explained': var_explained,

                # Cross-validation performance of the best hyperparameter setting
                'cv_f1_mean': grid_search.best_score_,
                'cv_f1_std': grid_search.cv_results_['std_test_score'][grid_search.best_index_],

                # Test set metrics (macro-averaged)
                'test_f1_macro': f1_score(y_test, y_pred, average='macro'),
                'test_f1_weighted': f1_score(y_test, y_pred, average='weighted'),
                'test_precision_macro': precision_score(y_test, y_pred, average='macro'),
                'test_recall_macro': recall_score(y_test, y_pred, average='macro'),
                'test_accuracy': accuracy_score(y_test, y_pred),

                # Per-class metrics
                'test_precision_class0': precision_per_class[0],
                'test_recall_class0': recall_per_class[0],
                'test_f1_class0': f1_per_class[0],
                'test_precision_class1': precision_per_class[1],
                'test_recall_class1': recall_per_class[1],
                'test_f1_class1': f1_per_class[1],

                # Best hyperparameters
                'best_params': grid_search.best_params_
            }

            all_results.append(result)

            # Progress update
            print(f"  [{experiment_count:2d}/{total_experiments}] {model_name:20s} -> "
                  f"CV: {result['cv_f1_mean']:.4f}±{result['cv_f1_std']:.4f} | "
                  f"Test F1: {result['test_f1_macro']:.4f} | "
                  f"Acc: {result['test_accuracy']:.4f}")


MAIN EVALUATION: Feature sets with/without Age

channel_0_mean - Raw shape: (235, 1)
  ----------------------------------------------------------------------------
  Processed shape: (188, 1)
  [ 6/95] Logistic Regression  -> CV: 0.4399±0.0812 | Test F1: 0.5317 | Acc: 0.5319
  [ 7/95] Linear SVM           -> CV: 0.3653±0.0503 | Test F1: 0.3188 | Acc: 0.4681
  [ 8/95] RBF SVM              -> CV: 0.4891±0.0694 | Test F1: 0.4671 | Acc: 0.4681
  [ 9/95] Random Forest        -> CV: 0.5292±0.0517 | Test F1: 0.4671 | Acc: 0.4681
  [10/95] Gradient Boosting    -> CV: 0.5242±0.0320 | Test F1: 0.5106 | Acc: 0.5106

channel_0_mean_with_age - Raw shape: (235, 1)
  ----------------------------------------------------------------------------
  Processed shape: (188, 2)
  [11/95] Logistic Regression  -> CV: 0.6008±0.1616 | Test F1: 0.5743 | Acc: 0.5745
  [12/95] Linear SVM           -> CV: 0.6009±0.1541 | Test F1: 0.5743 | Acc: 0.5745
  [13/95] RBF SVM              -> CV: 0.6102±0.1341 | Test F1: 0.

## 8. Results Analysis

In [None]:
# Collect every run into one table, best test macro-F1 first.
results_df = (
    pd.DataFrame(all_results)
    .sort_values(by='test_f1_macro', ascending=False)
)

# After sorting, the first row is the top-scoring configuration on the test set.
best = results_df.iloc[0]

print(f"Total experiments conducted: {len(results_df)}")
print("\nBest performing configuration:")
print(f"  Experiment: {best['experiment']}")
print(f"  Model: {best['model']}")
print(f"  Test F1: {best['test_f1_macro']:.4f}")
print(f"  Test Accuracy: {best['test_accuracy']:.4f}")

Total experiments conducted: 95

Best performing configuration:
  Experiment: pca_coronal
  Model: Linear SVM
  Test F1: 0.6557
  Test Accuracy: 0.6596


### 8.1 Top 10 Configurations

In [None]:
# Columns shown in the leaderboard; results_df is already sorted by test macro-F1.
leaderboard_columns = [
    'experiment',
    'model',
    'test_f1_macro',
    'test_precision_macro',
    'test_recall_macro',
    'test_accuracy',
]
top_10 = results_df.head(10)[leaderboard_columns]

banner = "=" * 100
print("\n" + banner)
print("TOP 10 CONFIGURATIONS BY TEST F1 SCORE")
print(banner)
print(top_10.to_string(index=False))
print(banner)


TOP 10 CONFIGURATIONS BY TEST F1 SCORE
             experiment               model  test_f1_macro  test_precision_macro  test_recall_macro  test_accuracy
            pca_coronal          Linear SVM       0.655678              0.671456           0.662138       0.659574
     pca_full_embedding          Linear SVM       0.632306              0.652941           0.641304       0.638298
     pca_axial_with_age   Gradient Boosting       0.612637              0.626437           0.619565       0.617021
            pca_coronal Logistic Regression       0.608333              0.634073           0.620471       0.617021
               age_only       Random Forest       0.595745              0.596014           0.596014       0.595745
  pca_sagittal_with_age          Linear SVM       0.595011              0.597985           0.596920       0.595745
channel_2_mean_with_age          Linear SVM       0.595011              0.597985           0.596920       0.595745
     pca_full_embedding Logistic Regress

## 9. Export Results


In [None]:
# Write the sorted results table to a CSV whose name carries the current
# local time, so repeated runs do not overwrite each other.
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
output_filename = "brats2020_results_" + timestamp + ".csv"
results_df.to_csv(output_filename, index=False)

print(f"\nResults exported to: {output_filename}")