# Mutation Prediction from Gene Expression

This notebook demonstrates the complete workflow:
1. Load and preprocess TCGA expression and mutation data
2. Run k-fold cross-validation for mutation prediction
3. Combine predictions from all folds
4. Compute per-cancer-type metrics
5. Visualize results


In [4]:
# Setup and Imports
import sys
import os
sys.path.append(os.path.abspath('../src'))

from pathlib import Path
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from sklearn.preprocessing import StandardScaler

from preprocessing import data_loader
from models.model_factory import ModelFactory
from main import load_config
from training.trainer import run_kfold_training
from evaluation.metrics import (
    combine_fold_predictions,
    compute_per_cancer_metrics_from_files,
    evaluate_multilabel,
)
from visualization.metrics import plot_per_cancer_metrics


In [5]:
# Cohorts included in this run (passed to the loader as `cancer_types`).
selected_cohorts = [
    'ACC', 'BLCA', 'BRCA', 'CESC', 'CHOL', 'COAD',
    'ESCA', 'GBM', 'HNSC', 'KICH', 'KIRC', 'KIRP',
    'LGG', 'LIHC', 'LUAD', 'LUSC', 'MESO', 'OV',
    'PAAD', 'PCPG', 'PRAD', 'READ', 'SARC', 'SKCM',
    'STAD', 'TGCT', 'THCA', 'THYM', 'UCEC', 'UVM',
]

# Loader is constructed with caching switched on.
data_load = data_loader.TCGADataLoader(use_cache=True)

# Report which cohorts are requested, then fetch the expression and
# mutation tables for exactly that list.
print(f"Loading and preprocessing data for: {', '.join(selected_cohorts)}")
expression_data, mutation_data = data_load.preprocess_data(
    cancer_types=selected_cohorts,
)

Loading and preprocessing data for: ACC, BLCA, BRCA, CESC, CHOL, COAD, ESCA, GBM, HNSC, KICH, KIRC, KIRP, LGG, LIHC, LUAD, LUSC, MESO, OV, PAAD, PCPG, PRAD, READ, SARC, SKCM, STAD, TGCT, THCA, THYM, UCEC, UVM
Loading cached aligned expression and mutation data from:
C:\Users\KerenYlab\Asaf\CodingProjects\expression_to_mutation_prediction\cache\expression_aligned_ACC-BLCA-BRCA-CESC-CHOL-COAD-ESCA-GBM-HNSC-KICH-KIRC-KIRP-LGG-LIHC-LUAD-LUSC-MESO-OV-PAAD-PCPG-PRAD-READ-SARC-SKCM-STAD-TGCT-THCA-THYM-UCEC-UVM.pkl C:\Users\KerenYlab\Asaf\CodingProjects\expression_to_mutation_prediction\cache\mutation_aligned_ACC-BLCA-BRCA-CESC-CHOL-COAD-ESCA-GBM-HNSC-KICH-KIRC-KIRP-LGG-LIHC-LUAD-LUSC-MESO-OV-PAAD-PCPG-PRAD-READ-SARC-SKCM-STAD-TGCT-THCA-THYM-UCEC-UVM.pkl


In [6]:
# Section banner: rule, title, rule (one per line).
print("="*60, "Loading and Preprocessing Data", "="*60, sep="\n")

# A separate loader instance for this section, again with caching enabled.
data_loader_obj = data_loader.TCGADataLoader(use_cache=True)

# An empty cohort list is forwarded as None instead of an empty list.
expression_data, mutation_data = data_loader_obj.preprocess_data(
    cancer_types=selected_cohorts or None
)

# Quick sanity report on the two tables.
print(f"\nExpression data shape: {expression_data.shape}")
print(f"Mutation data shape: {mutation_data.shape}")
print(f"Number of genes (mutations): {mutation_data.shape[1]}")
print(f"Number of samples: {expression_data.shape[0]}")


Loading and Preprocessing Data
Loading cached aligned expression and mutation data from:
C:\Users\KerenYlab\Asaf\CodingProjects\expression_to_mutation_prediction\cache\expression_aligned_ACC-BLCA-BRCA-CESC-CHOL-COAD-ESCA-GBM-HNSC-KICH-KIRC-KIRP-LGG-LIHC-LUAD-LUSC-MESO-OV-PAAD-PCPG-PRAD-READ-SARC-SKCM-STAD-TGCT-THCA-THYM-UCEC-UVM.pkl C:\Users\KerenYlab\Asaf\CodingProjects\expression_to_mutation_prediction\cache\mutation_aligned_ACC-BLCA-BRCA-CESC-CHOL-COAD-ESCA-GBM-HNSC-KICH-KIRC-KIRP-LGG-LIHC-LUAD-LUSC-MESO-OV-PAAD-PCPG-PRAD-READ-SARC-SKCM-STAD-TGCT-THCA-THYM-UCEC-UVM.pkl

Expression data shape: (8585, 19992)
Mutation data shape: (8585, 111)
Number of genes (mutations): 111
Number of samples: 8585


In [7]:
# Section banner: rule, title, rule (one per line).
print("="*60, "Preprocessing Features", "="*60, sep="\n")

# Step 1: log(1 + x) on the expression table.
X_log = np.log1p(expression_data)

# Step 2: standardise every column.
# NOTE(review): the scaler is fit on all samples before the k-fold split
# below; confirm whether run_kfold_training re-fits scaling per fold, since
# otherwise held-out samples influence the training-set statistics.
scaler = StandardScaler()
X_scaled = scaler.fit_transform(X_log)

# Step 3: re-attach the original sample index and column labels, which the
# scaler output (a bare array) no longer carries.
X_scaled_df = pd.DataFrame(
    data=X_scaled,
    columns=expression_data.columns,
    index=expression_data.index,
)

print(f"Preprocessed features shape: {X_scaled_df.shape}")
print(f"Features: log-transformed and standardized")


Preprocessing Features
Preprocessed features shape: (8585, 19992)
Features: log-transformed and standardized


In [9]:
# Output location for this run.
# A fixed suffix is used here; a per-cohort alternative would be
# "-".join(selected_cohorts) (falling back to "all_cancers" when empty).
cohort_suffix = 'all_cancers'

# Create the directory (and any missing parents); an existing one is reused.
results_root = Path("../results/kfold_prediction").joinpath(cohort_suffix)
results_root.mkdir(exist_ok=True, parents=True)

print(f"Results will be saved to: {results_root}")


Results will be saved to: ..\results\kfold_prediction\all_cancers


In [11]:
# Load the project configuration before modifying it. No earlier cell in the
# notebook assigns `config`, so on a fresh kernel (Restart & Run All) the
# assignment below would otherwise fail with NameError.
config_path = Path("../config/config.yaml")
config = load_config(config_path)

# Select the model for the k-fold run; this name is handed to
# ModelFactory.get_model and reused as the run label.
config['model']['name'] = 'multitask_nn'

In [None]:
# Section banner: rule, title, rule (one per line).
print("="*60, "Running K-Fold Cross-Validation", "="*60, sep="\n")

# Build the model named in the config; input width is the number of feature
# columns and output width is the number of mutation columns.
model_factory = ModelFactory()
model = model_factory.get_model(
    config=config,
    model_name=config['model']['name'],
    input_size=X_scaled_df.shape[1],
    output_size=mutation_data.shape[1],
)

# Summarise the run before the (long) training call.
print(f"\nModel: {config['model']['name']}")
print(f"Folds: {config['evaluation']['cv_folds']}")
print(f"Features: {X_scaled_df.shape[1]}")
print(f"Genes: {mutation_data.shape[1]}")
print(f"Samples: {X_scaled_df.shape[0]}")

# Metadata handed to the trainer alongside the data: the full config, the
# feature column names, and the cohort list ("all" when the list is empty).
shared_meta = dict(
    config=config,
    selected_features=X_scaled_df.columns.tolist(),
    cancer_types=selected_cohorts or "all",
)

# The random seed comes from config["preprocessing"]["random_state"] when
# present and defaults to 42 otherwise.
run_kfold_training(
    model=model,
    X=X_scaled_df,
    Y=mutation_data,
    k=config['evaluation']['cv_folds'],
    label=config['model']['name'],
    random_state=config.get("preprocessing", {}).get("random_state", 42),
    config_meta=shared_meta,
    output_dir=results_root,
)

print(f"\n✅ K-fold cross-validation complete!")


Running K-Fold Cross-Validation

Model: multitask_nn
Folds: 5
Features: 19992
Genes: 111
Samples: 8585

Fold 1/5...

Fold 2/5...

Fold 3/5...


In [None]:
# Section banner: rule, title, rule (one per line).
print("="*60, "Combining Predictions from All Folds", "="*60, sep="\n")

# Directory in which the merged per-fold outputs are expected.
combined_dir = results_root / "combined_predictions"

# Only merge the folds when no combined predictions file is present yet;
# an existing file is reused as-is.
if (combined_dir / "predictions.csv").exists():
    print(f"Combined predictions already exist at: {combined_dir}")
else:
    print("Combining predictions from folds...")
    combine_fold_predictions(
        kfold_dir=results_root,
        output_dir=results_root,
        include_fold_info=True,
    )

# Read both combined tables, using the first CSV column as the row index.
combined_preds = pd.read_csv(combined_dir.joinpath("predictions.csv"), index_col=0)
combined_probs = pd.read_csv(combined_dir.joinpath("probabilities.csv"), index_col=0)

print(f"\nCombined predictions: {len(combined_preds)} samples, {len(combined_preds.columns)} genes")


In [None]:
# Compute Overall Metrics
# Scores the combined out-of-fold predictions against the true mutation
# labels, writes the per-gene metric table to CSV, and saves a bar chart of
# the mean (± std across genes) of each key metric.
print("="*60)
print("Computing Overall Metrics")
print("="*60)

# Align rows by sample ID across all three tables, so that indexing
# combined_probs below cannot fail on a sample that only exists in
# combined_preds.
common_samples = (
    mutation_data.index
    .intersection(combined_preds.index)
    .intersection(combined_probs.index)
)

# Align columns by gene name as well. The arrays are handed to
# evaluate_multilabel positionally, so any extra or re-ordered column in the
# prediction files would otherwise pair a gene's labels with another gene's
# predictions without raising an error.
gene_columns = [
    gene for gene in mutation_data.columns
    if gene in combined_preds.columns and gene in combined_probs.columns
]

if len(common_samples) == 0 or not gene_columns:
    raise ValueError(
        "No overlapping samples/genes between mutation_data and the combined "
        "prediction files; check that they come from the same run."
    )

y_true_overall = mutation_data.loc[common_samples, gene_columns]
y_pred_overall = combined_preds.loc[common_samples, gene_columns]
y_prob_overall = combined_probs.loc[common_samples, gene_columns]

print(f"Common samples: {len(common_samples)}")
if len(gene_columns) != mutation_data.shape[1]:
    # Make a silent reduction of the evaluated gene set visible.
    print(f"Genes present in predictions: {len(gene_columns)} of {mutation_data.shape[1]}")

# Compute overall metrics (one row of metrics per gene name supplied).
overall_metrics = evaluate_multilabel(
    y_true=y_true_overall.values,
    y_pred=y_pred_overall.values,
    y_prob=y_prob_overall.values,
    mutation_names=gene_columns
)

# Save overall metrics
overall_metrics.to_csv(results_root / 'overall_metrics.csv')

# Print summary: mean of each key metric across genes.
key_metrics = ['roc_auc', 'auprc', 'f1', 'accuracy', 'precision', 'recall']
print("\nOverall Metrics (across all cancer types):")
print(overall_metrics[key_metrics].mean().round(3))

# Create summary visualization: bar height = mean, error bar = std.
fig, ax = plt.subplots(figsize=(10, 6))
metric_means = overall_metrics[key_metrics].mean()
metric_stds = overall_metrics[key_metrics].std()

ax.bar(range(len(metric_means)), metric_means.values,
       yerr=metric_stds.values, capsize=5, alpha=0.7)
ax.set_xticks(range(len(metric_means)))
ax.set_xticklabels(metric_means.index, rotation=45, ha='right')
ax.set_ylabel('Metric Value')
ax.set_title('Overall Performance Metrics (All Cancer Types Combined)')
ax.grid(axis='y', alpha=0.3)
fig.tight_layout()
fig.savefig(results_root / 'overall_metrics_summary.png', dpi=150, bbox_inches='tight')
# Close this specific figure so it is released rather than whichever figure
# happens to be current.
plt.close(fig)

print(f"\nOverall metrics saved to: {results_root / 'overall_metrics.csv'}")


In [None]:
# Section banner: rule, title, rule (one per line).
print("="*60, "Computing Per-Cancer-Type Metrics", "="*60, sep="\n")

# Build {sample_id: cancer_type} from the indicator columns whose names start
# with 'cancer_': every row where such a column equals 1 is assigned that
# column's name with the 'cancer_' text removed. If a sample is flagged in
# several columns, the last column processed wins.
sample_to_cancer = {}
cancer_flag_mask = expression_data.columns.str.startswith('cancer_')
for flag_column in expression_data.columns[cancer_flag_mask]:
    flagged_samples = expression_data.index[expression_data[flag_column] == 1]
    sample_to_cancer.update(
        dict.fromkeys(flagged_samples, flag_column.replace('cancer_', ''))
    )

cancer_types_found = sorted(set(sample_to_cancer.values()))
print(f"\nFound {len(cancer_types_found)} cancer types: {', '.join(cancer_types_found)}")
print(f"Total samples mapped: {len(sample_to_cancer)}")


In [None]:
# Per-cancer-type metrics, computed from the combined prediction files.
per_cancer_output_dir = results_root.joinpath("per_cancer_metrics")

# Inputs: true labels, the unscaled expression table (passed as X together
# with the explicit sample -> cancer mapping), and the two combined CSVs.
per_cancer_metrics = compute_per_cancer_metrics_from_files(
    y_true=mutation_data,
    X=expression_data,
    sample_to_cancer=sample_to_cancer,
    predictions_path=combined_dir.joinpath("predictions.csv"),
    probabilities_path=combined_dir.joinpath("probabilities.csv"),
    output_dir=per_cancer_output_dir,
)

print(f"\n✅ Per-cancer-type metrics computed!")
print(f"   Results saved to: {per_cancer_output_dir}")


In [None]:
# Section banner: rule, title, rule (one per line).
print("="*60, "Creating Per-Cancer Visualizations", "="*60, sep="\n")

# Figures go in their own sub-folder of the results root.
viz_output_dir = results_root.joinpath("visualizations")

# Plot the per-cancer metric tables for the same key metrics used in the
# overall summary.
plot_per_cancer_metrics(
    figsize=(18, 12),
    key_metrics=key_metrics,
    output_dir=viz_output_dir,
    per_cancer_metrics=per_cancer_metrics,
)

print(f"\n✅ All visualizations created!")
print(f"   Visualizations saved to: {viz_output_dir}")


In [None]:
# Section banner: rule, title, rule (one per line).
print("="*60, "Summary: Per-Cancer-Type Performance", "="*60, sep="\n")

# Long-format records: one row per (cancer type, metric) holding the mean and
# std of that metric column. Metrics absent from a cancer type's table are
# skipped.
summary_data = [
    {
        'cancer_type': cancer_type,
        'metric': metric,
        'mean': metrics_df[metric].mean(),
        'std': metrics_df[metric].std(),
    }
    for cancer_type, metrics_df in per_cancer_metrics.items()
    for metric in key_metrics
    if metric in metrics_df.columns
]
summary_df = pd.DataFrame(summary_data)

# Wide view: cancer types as rows, metrics as columns, cell = mean.
display_summary = summary_df.pivot(values='mean', columns='metric', index='cancer_type')
print("\nAverage Metrics by Cancer Type:")
print(display_summary.round(3))

# For every metric that produced at least one record, report the cancer type
# with the highest mean.
print("\n" + "="*60)
print("Best Performing Cancer Types by Metric:")
print("="*60)
for metric in key_metrics:
    if metric not in summary_df['metric'].values:
        continue
    best = summary_df[summary_df['metric'] == metric].nlargest(1, 'mean')
    if best.empty:
        continue
    print(f"{metric.upper():12s}: {best.iloc[0]['cancer_type']:8s} (mean={best.iloc[0]['mean']:.3f})")

print("\n" + "="*60)
print(f"All results saved to: {results_root}")
print("="*60)


In [10]:
# Cohort list for the second workflow in this notebook (passed to the loader
# as `cancer_types`).
selected_cohorts = [
    'ACC', 'BLCA', 'BRCA', 'CESC', 'CHOL',
    'COAD', 'ESCA', 'GBM', 'HNSC', 'KICH',
    'KIRC', 'KIRP', 'LGG', 'LIHC', 'LUAD',
    'LUSC', 'MESO', 'OV', 'PAAD', 'PCPG',
    'PRAD', 'READ', 'SARC', 'SKCM', 'STAD',
    'TGCT', 'THCA', 'THYM', 'UCEC', 'UVM',
]

# Fresh loader with caching enabled.
data_load = data_loader.TCGADataLoader(use_cache=True)

# Announce the request, then (re)load the expression and mutation tables.
print(f"Loading and preprocessing data for: {', '.join(selected_cohorts)}")
expression_data, mutation_data = data_load.preprocess_data(
    cancer_types=selected_cohorts,
)

Loading and preprocessing data for: ACC, BLCA, BRCA, CESC, CHOL, COAD, ESCA, GBM, HNSC, KICH, KIRC, KIRP, LGG, LIHC, LUAD, LUSC, MESO, OV, PAAD, PCPG, PRAD, READ, SARC, SKCM, STAD, TGCT, THCA, THYM, UCEC, UVM
Loading cached aligned expression and mutation data from:
C:\Users\KerenYlab\Asaf\CodingProjects\expression_to_mutation_prediction\cache\expression_aligned_ACC-BLCA-BRCA-CESC-CHOL-COAD-ESCA-GBM-HNSC-KICH-KIRC-KIRP-LGG-LIHC-LUAD-LUSC-MESO-OV-PAAD-PCPG-PRAD-READ-SARC-SKCM-STAD-TGCT-THCA-THYM-UCEC-UVM.pkl C:\Users\KerenYlab\Asaf\CodingProjects\expression_to_mutation_prediction\cache\mutation_aligned_ACC-BLCA-BRCA-CESC-CHOL-COAD-ESCA-GBM-HNSC-KICH-KIRC-KIRP-LGG-LIHC-LUAD-LUSC-MESO-OV-PAAD-PCPG-PRAD-READ-SARC-SKCM-STAD-TGCT-THCA-THYM-UCEC-UVM.pkl


In [11]:
# --- Configuration -------------------------------------------------------
config_path = Path("../config/config.yaml")
config = load_config(config_path)

# --- Feature scaling: log(1 + x), then per-column standardisation --------
X_log = np.log1p(expression_data)
scaler = StandardScaler()
X_scaled = scaler.fit_transform(X_log)

# --- Output directory ----------------------------------------------------
# Fixed suffix; "-".join(selected_cohorts) would give a per-cohort folder.
cohort_suffix = 'all_cancers'
results_root = Path("../results/notebook_multitask_test").joinpath(cohort_suffix)
results_root.mkdir(exist_ok=True, parents=True)

# --- Labelled feature table ----------------------------------------------
# Wrap the scaled array in a DataFrame carrying the original index/columns,
# which is the form the downstream functions are given.
X_scaled_df = pd.DataFrame(
    data=X_scaled,
    columns=expression_data.columns,
    index=expression_data.index,
)


In [13]:
# Import ablation functions
from interpretation.ablation import (
    run_gene_ablation_analysis,
    load_ablation_results
)

# Initialize model factory
model_factory = ModelFactory()

# Quick-test subset: keep only the first N_TEST_GENES mutation columns so
# downstream steps run faster. Set N_TEST_GENES = None to keep every gene
# (a slice end of None selects all columns), which still defines
# `mutation_data_subset` for the cells that read it.
N_TEST_GENES = 4
test_genes = mutation_data.columns[:N_TEST_GENES].tolist()
mutation_data_subset = mutation_data[test_genes].copy()

print(f"Expression data shape: {X_scaled_df.shape}")
print(f"Mutation data shape: {mutation_data_subset.shape}")
print(f"Number of genes: {len(mutation_data_subset.columns)}")
print(f"Genes: {', '.join(mutation_data_subset.columns.tolist())}")


Expression data shape: (8585, 19992)
Mutation data shape: (8585, 4)
Number of genes: 4
Genes: VPS13B, OBSCN, ABCA13, SVEP1


In [14]:
# Select the model by name; the k-fold cell below passes this value to
# ModelFactory.get_model and also uses it as the run label.
config['model']['name'] = 'multitask_nn'

In [None]:
# Sets the model name to 'multitask_nn'. NOTE(review): the cell directly
# above performs the identical assignment, so this one looks redundant.
config['model']['name'] = 'multitask_nn'

In [16]:
# Section banner: rule, title, rule (one per line).
print("="*60, "Running K-Fold Cross-Validation", "="*60, sep="\n")

from evaluation.metrics import compute_per_cancer_metrics_kfold
import matplotlib.pyplot as plt
import seaborn as sns

# Build the model selected in config['model']['name']; input width is the
# number of feature columns, output width the number of mutation columns.
model_factory = ModelFactory()
model = model_factory.get_model(
    config=config,
    model_name=config['model']['name'],
    input_size=X_scaled_df.shape[1],
    output_size=mutation_data.shape[1],
)

# K-fold artefacts live in their own sub-folder of the results root.
kfold_output_dir = results_root.joinpath("kfold_prediction")
kfold_output_dir.mkdir(exist_ok=True, parents=True)

# Summarise the run before the (long) training call.
print(f"\nRunning {config['evaluation']['cv_folds']}-fold cross-validation...")
print(f"   Model: {config['model']['name']}")
print(f"   Genes: {len(mutation_data.columns)}")
print(f"   Samples: {len(X_scaled_df)}")

# Metadata handed to the trainer alongside the data.
shared_meta = dict(
    config=config,
    selected_features=X_scaled_df.columns.tolist(),
    cancer_types=selected_cohorts,
)

# Train on the full mutation table. The seed is read from
# config["preprocessing"]["random_state"], defaulting to 42 when absent.
run_kfold_training(
    model=model,
    X=X_scaled_df,
    Y=mutation_data,
    k=config['evaluation']['cv_folds'],
    label=config['model']['name'],
    random_state=config.get("preprocessing", {}).get("random_state", 42),
    config_meta=shared_meta,
    output_dir=kfold_output_dir,
)

print(f"\n✅ K-fold cross-validation complete!")
print(f"   Results saved to: {kfold_output_dir}")


Running K-Fold Cross-Validation

Running 5-fold cross-validation...
   Model: multitask_nn
   Genes: 111
   Samples: 8585

Fold 1/5...

Fold 2/5...

Fold 3/5...

Fold 4/5...

Fold 5/5...

Combining predictions from all folds...
   Combined 8585 samples from 5 folds
   Saved to: ..\results\notebook_multitask_test\all_cancers\kfold_prediction\combined_predictions
K-fold evaluation complete.

✅ K-fold cross-validation complete!
   Results saved to: ..\results\notebook_multitask_test\all_cancers\kfold_prediction


In [17]:
# Section banner: rule, title, rule (one per line).
print("="*60, "Computing Per-Cancer-Type Metrics", "="*60, sep="\n")

# Build {sample_id: cancer_type} only when no earlier cell has already
# defined it. Rows where a 'cancer_*' indicator column equals 1 are assigned
# that column's name with the 'cancer_' text removed.
if 'sample_to_cancer' not in locals():
    sample_to_cancer = {}
    cancer_flag_mask = expression_data.columns.str.startswith('cancer_')
    for flag_column in expression_data.columns[cancer_flag_mask]:
        flagged_samples = expression_data.index[expression_data[flag_column] == 1]
        sample_to_cancer.update(
            dict.fromkeys(flagged_samples, flag_column.replace('cancer_', ''))
        )

# Per-fold, per-cancer metrics are written under the k-fold output folder.
per_cancer_output_dir = kfold_output_dir.joinpath("per_cancer_metrics")

# NOTE(review): y_true is the reduced `mutation_data_subset`, whereas the
# k-fold run above was trained with the full `mutation_data` — confirm that
# evaluating only the subset genes is intended here.
per_cancer_metrics = compute_per_cancer_metrics_kfold(
    y_true=mutation_data_subset,
    X=expression_data,  # unscaled table, passed together with the mapping
    sample_to_cancer=sample_to_cancer,
    kfold_dir=kfold_output_dir,
    output_dir=per_cancer_output_dir,
)

print(f"\n✅ Per-cancer-type metrics computed!")
print(f"   Results saved to: {per_cancer_output_dir}")

print(f"\n✅ Per-cancer-type metrics computed!")
print(f"   Results saved to: {per_cancer_output_dir}")


Computing Per-Cancer-Type Metrics
Processing 5 folds from: ..\results\notebook_multitask_test\all_cancers\kfold_prediction

   Processing fold_1...
Loading predictions from: ..\results\notebook_multitask_test\all_cancers\kfold_prediction\fold_1\predictions.csv
Loading probabilities from: ..\results\notebook_multitask_test\all_cancers\kfold_prediction\fold_1\probabilities.csv

Computing per-cancer-type metrics...
   Cancer types: ['ACC', 'BLCA', 'BRCA', 'CESC', 'CHOL', 'COAD', 'ESCA', 'GBM', 'HNSC', 'KICH', 'KIRC', 'KIRP', 'LGG', 'LIHC', 'LUAD', 'LUSC', 'MESO', 'OV', 'PAAD', 'PCPG', 'PRAD', 'READ', 'SARC', 'SKCM', 'STAD', 'TGCT', 'THCA', 'THYM', 'UCEC', 'UVM']
   Total samples: 1717
   Genes: 4
   ACC: 19 samples
   BLCA: 89 samples
   BRCA: 157 samples
   CESC: 66 samples
   CHOL: 10 samples
   COAD: 62 samples
   ESCA: 31 samples
   GBM: 22 samples


  _warn_prf(average, modifier, f"{metric.capitalize()} is", result.shape[0])
  _warn_prf(average, modifier, f"{metric.capitalize()} is", result.shape[0])
  _warn_prf(average, modifier, f"{metric.capitalize()} is", result.shape[0])
  _warn_prf(average, modifier, f"{metric.capitalize()} is", result.shape[0])
  _warn_prf(average, modifier, f"{metric.capitalize()} is", result.shape[0])
  _warn_prf(average, modifier, f"{metric.capitalize()} is", result.shape[0])
  _warn_prf(average, modifier, f"{metric.capitalize()} is", result.shape[0])


   HNSC: 89 samples
   KICH: 12 samples
   KIRC: 75 samples
   KIRP: 61 samples
   LGG: 82 samples
   LIHC: 69 samples
   LUAD: 102 samples
   LUSC: 92 samples
   MESO: 18 samples
   OV: 2 samples
   PAAD: 34 samples
   PCPG: 30 samples


  _warn_prf(average, modifier, f"{metric.capitalize()} is", result.shape[0])
  _warn_prf(average, modifier, f"{metric.capitalize()} is", result.shape[0])
  _warn_prf(average, modifier, f"{metric.capitalize()} is", result.shape[0])
  _warn_prf(average, modifier, f"{metric.capitalize()} is", result.shape[0])
  _warn_prf(average, modifier, f"{metric.capitalize()} is", result.shape[0])
  _warn_prf(average, modifier, f"{metric.capitalize()} is", result.shape[0])
  _warn_prf(average, modifier, f"{metric.capitalize()} is", result.shape[0])
  _warn_prf(average, modifier, f"{metric.capitalize()} is", result.shape[0])
  _warn_prf(average, modifier, f"{metric.capitalize()} is", result.shape[0])
  _warn_prf(average, modifier, f"{metric.capitalize()} is", result.shape[0])
  _warn_prf(average, modifier, f"{metric.capitalize()} is", result.shape[0])


   PRAD: 106 samples
   READ: 22 samples
   SARC: 49 samples
   SKCM: 96 samples
   STAD: 73 samples
   TGCT: 33 samples
   THCA: 93 samples
   THYM: 16 samples
   UCEC: 87 samples
   UVM: 20 samples

   Processing fold_2...
Loading predictions from: ..\results\notebook_multitask_test\all_cancers\kfold_prediction\fold_2\predictions.csv
Loading probabilities from: ..\results\notebook_multitask_test\all_cancers\kfold_prediction\fold_2\probabilities.csv


  _warn_prf(average, modifier, f"{metric.capitalize()} is", result.shape[0])
  _warn_prf(average, modifier, f"{metric.capitalize()} is", result.shape[0])
  _warn_prf(average, modifier, f"{metric.capitalize()} is", result.shape[0])
  _warn_prf(average, modifier, f"{metric.capitalize()} is", result.shape[0])
  _warn_prf(average, modifier, f"{metric.capitalize()} is", result.shape[0])
  _warn_prf(average, modifier, f"{metric.capitalize()} is", result.shape[0])
  _warn_prf(average, modifier, f"{metric.capitalize()} is", result.shape[0])
  _warn_prf(average, modifier, f"{metric.capitalize()} is", result.shape[0])



Computing per-cancer-type metrics...
   Cancer types: ['ACC', 'BLCA', 'BRCA', 'CESC', 'CHOL', 'COAD', 'ESCA', 'GBM', 'HNSC', 'KICH', 'KIRC', 'KIRP', 'LGG', 'LIHC', 'LUAD', 'LUSC', 'MESO', 'OV', 'PAAD', 'PCPG', 'PRAD', 'READ', 'SARC', 'SKCM', 'STAD', 'TGCT', 'THCA', 'THYM', 'UCEC', 'UVM']
   Total samples: 1717
   Genes: 4
   ACC: 14 samples
   BLCA: 88 samples
   BRCA: 157 samples
   CESC: 64 samples
   CHOL: 8 samples
   COAD: 51 samples
   ESCA: 36 samples
   GBM: 28 samples
   HNSC: 102 samples
   KICH: 11 samples
   KIRC: 80 samples
   KIRP: 51 samples


  _warn_prf(average, modifier, f"{metric.capitalize()} is", result.shape[0])
  _warn_prf(average, modifier, f"{metric.capitalize()} is", result.shape[0])
  _warn_prf(average, modifier, f"{metric.capitalize()} is", result.shape[0])
  _warn_prf(average, modifier, f"{metric.capitalize()} is", result.shape[0])
  _warn_prf(average, modifier, f"{metric.capitalize()} is", result.shape[0])
  _warn_prf(average, modifier, f"{metric.capitalize()} is", result.shape[0])
  _warn_prf(average, modifier, f"{metric.capitalize()} is", result.shape[0])
  _warn_prf(average, modifier, f"{metric.capitalize()} is", result.shape[0])
  _warn_prf(average, modifier, f"{metric.capitalize()} is", result.shape[0])


   LGG: 90 samples
   LIHC: 69 samples
   LUAD: 115 samples
   LUSC: 106 samples
   MESO: 11 samples
   OV: 2 samples
   PAAD: 26 samples
   PCPG: 44 samples
   PRAD: 88 samples
   READ: 18 samples
   SARC: 52 samples
   SKCM: 95 samples
   STAD: 75 samples


  _warn_prf(average, modifier, f"{metric.capitalize()} is", result.shape[0])
  _warn_prf(average, modifier, f"{metric.capitalize()} is", result.shape[0])
  _warn_prf(average, modifier, f"{metric.capitalize()} is", result.shape[0])
  _warn_prf(average, modifier, f"{metric.capitalize()} is", result.shape[0])
  _warn_prf(average, modifier, f"{metric.capitalize()} is", result.shape[0])
  _warn_prf(average, modifier, f"{metric.capitalize()} is", result.shape[0])
  _warn_prf(average, modifier, f"{metric.capitalize()} is", result.shape[0])
  _warn_prf(average, modifier, f"{metric.capitalize()} is", result.shape[0])
  _warn_prf(average, modifier, f"{metric.capitalize()} is", result.shape[0])
  _warn_prf(average, modifier, f"{metric.capitalize()} is", result.shape[0])
  _warn_prf(average, modifier, f"{metric.capitalize()} is", result.shape[0])
  _warn_prf(average, modifier, f"{metric.capitalize()} is", result.shape[0])
  _warn_prf(average, modifier, f"{metric.capitalize()} is", result.shape[0])

   TGCT: 25 samples
   THCA: 86 samples
   THYM: 21 samples
   UCEC: 88 samples
   UVM: 16 samples

   Processing fold_3...
Loading predictions from: ..\results\notebook_multitask_test\all_cancers\kfold_prediction\fold_3\predictions.csv
Loading probabilities from: ..\results\notebook_multitask_test\all_cancers\kfold_prediction\fold_3\probabilities.csv

Computing per-cancer-type metrics...
   Cancer types: ['ACC', 'BLCA', 'BRCA', 'CESC', 'CHOL', 'COAD', 'ESCA', 'GBM', 'HNSC', 'KICH', 'KIRC', 'KIRP', 'LGG', 'LIHC', 'LUAD', 'LUSC', 'MESO', 'OV', 'PAAD', 'PCPG', 'PRAD', 'READ', 'SARC', 'SKCM', 'STAD', 'TGCT', 'THCA', 'THYM', 'UCEC', 'UVM']
   Total samples: 1717
   Genes: 4
   ACC: 13 samples
   BLCA: 77 samples
   BRCA: 143 samples
   CESC: 44 samples


  _warn_prf(average, modifier, f"{metric.capitalize()} is", result.shape[0])
  _warn_prf(average, modifier, f"{metric.capitalize()} is", result.shape[0])
  _warn_prf(average, modifier, f"{metric.capitalize()} is", result.shape[0])
  _warn_prf(average, modifier, f"{metric.capitalize()} is", result.shape[0])
  _warn_prf(average, modifier, f"{metric.capitalize()} is", result.shape[0])
  _warn_prf(average, modifier, f"{metric.capitalize()} is", result.shape[0])
  _warn_prf(average, modifier, f"{metric.capitalize()} is", result.shape[0])
  _warn_prf(average, modifier, f"{metric.capitalize()} is", result.shape[0])
  _warn_prf(average, modifier, f"{metric.capitalize()} is", result.shape[0])
  _warn_prf(average, modifier, f"{metric.capitalize()} is", result.shape[0])
  _warn_prf(average, modifier, f"{metric.capitalize()} is", result.shape[0])
  _warn_prf(average, modifier, f"{metric.capitalize()} is", result.shape[0])
  _warn_prf(average, modifier, f"{metric.capitalize()} is", result.shape[0])

   CHOL: 7 samples
   COAD: 55 samples
   ESCA: 33 samples
   GBM: 20 samples
   HNSC: 106 samples
   KICH: 18 samples
   KIRC: 62 samples
   KIRP: 61 samples
   LGG: 120 samples
   LIHC: 74 samples
   LUAD: 104 samples
   LUSC: 88 samples
   MESO: 22 samples


  _warn_prf(average, modifier, f"{metric.capitalize()} is", result.shape[0])
  _warn_prf(average, modifier, f"{metric.capitalize()} is", result.shape[0])
  _warn_prf(average, modifier, f"{metric.capitalize()} is", result.shape[0])
  _warn_prf(average, modifier, f"{metric.capitalize()} is", result.shape[0])
  _warn_prf(average, modifier, f"{metric.capitalize()} is", result.shape[0])
  _warn_prf(average, modifier, f"{metric.capitalize()} is", result.shape[0])
  _warn_prf(average, modifier, f"{metric.capitalize()} is", result.shape[0])
  _warn_prf(average, modifier, f"{metric.capitalize()} is", result.shape[0])
  _warn_prf(average, modifier, f"{metric.capitalize()} is", result.shape[0])
  _warn_prf(average, modifier, f"{metric.capitalize()} is", result.shape[0])
  _warn_prf(average, modifier, f"{metric.capitalize()} is", result.shape[0])
  _warn_prf(average, modifier, f"{metric.capitalize()} is", result.shape[0])
  _warn_prf(average, modifier, f"{metric.capitalize()} is", result.shape[0])

   OV: 3 samples
   PAAD: 43 samples
   PCPG: 33 samples
   PRAD: 89 samples
   READ: 16 samples
   SARC: 49 samples
   SKCM: 103 samples
   STAD: 86 samples
   TGCT: 26 samples
   THCA: 108 samples
   THYM: 21 samples
   UCEC: 81 samples


  _warn_prf(average, modifier, f"{metric.capitalize()} is", result.shape[0])
  _warn_prf(average, modifier, f"{metric.capitalize()} is", result.shape[0])
  _warn_prf(average, modifier, f"{metric.capitalize()} is", result.shape[0])
  _warn_prf(average, modifier, f"{metric.capitalize()} is", result.shape[0])
  _warn_prf(average, modifier, f"{metric.capitalize()} is", result.shape[0])
  _warn_prf(average, modifier, f"{metric.capitalize()} is", result.shape[0])
  _warn_prf(average, modifier, f"{metric.capitalize()} is", result.shape[0])
  _warn_prf(average, modifier, f"{metric.capitalize()} is", result.shape[0])
  _warn_prf(average, modifier, f"{metric.capitalize()} is", result.shape[0])
  _warn_prf(average, modifier, f"{metric.capitalize()} is", result.shape[0])
  _warn_prf(average, modifier, f"{metric.capitalize()} is", result.shape[0])


   UVM: 12 samples

   Processing fold_4...
Loading predictions from: ..\results\notebook_multitask_test\all_cancers\kfold_prediction\fold_4\predictions.csv
Loading probabilities from: ..\results\notebook_multitask_test\all_cancers\kfold_prediction\fold_4\probabilities.csv

Computing per-cancer-type metrics...
   Cancer types: ['ACC', 'BLCA', 'BRCA', 'CESC', 'CHOL', 'COAD', 'ESCA', 'GBM', 'HNSC', 'KICH', 'KIRC', 'KIRP', 'LGG', 'LIHC', 'LUAD', 'LUSC', 'MESO', 'OV', 'PAAD', 'PCPG', 'PRAD', 'READ', 'SARC', 'SKCM', 'STAD', 'TGCT', 'THCA', 'THYM', 'UCEC', 'UVM']
   Total samples: 1717
   Genes: 4
   ACC: 18 samples
   BLCA: 74 samples
   BRCA: 169 samples
   CESC: 62 samples
   CHOL: 5 samples
   COAD: 58 samples
   ESCA: 39 samples


  _warn_prf(average, modifier, f"{metric.capitalize()} is", result.shape[0])
  _warn_prf(average, modifier, f"{metric.capitalize()} is", result.shape[0])
  _warn_prf(average, modifier, f"{metric.capitalize()} is", result.shape[0])
  _warn_prf(average, modifier, f"{metric.capitalize()} is", result.shape[0])
  _warn_prf(average, modifier, f"{metric.capitalize()} is", result.shape[0])
  _warn_prf(average, modifier, f"{metric.capitalize()} is", result.shape[0])
  _warn_prf(average, modifier, f"{metric.capitalize()} is", result.shape[0])
  _warn_prf(average, modifier, f"{metric.capitalize()} is", result.shape[0])
  _warn_prf(average, modifier, f"{metric.capitalize()} is", result.shape[0])
  _warn_prf(average, modifier, f"{metric.capitalize()} is", result.shape[0])
  _warn_prf(average, modifier, f"{metric.capitalize()} is", result.shape[0])
  _warn_prf(average, modifier, f"{metric.capitalize()} is", result.shape[0])
  _warn_prf(average, modifier, f"{metric.capitalize()} is", result.shape[0])

   GBM: 31 samples
   HNSC: 113 samples
   KICH: 9 samples
   KIRC: 59 samples
   KIRP: 46 samples
   LGG: 107 samples
   LIHC: 70 samples
   LUAD: 87 samples
   LUSC: 84 samples
   MESO: 10 samples
   OV: 4 samples
   PAAD: 35 samples
   PCPG: 29 samples
   PRAD: 105 samples
   READ: 15 samples
   SARC: 44 samples
   SKCM: 87 samples
   STAD: 81 samples
   TGCT: 29 samples
   THCA: 99 samples
   THYM: 31 samples
   UCEC: 100 samples
   UVM: 17 samples

   Processing fold_5...
Loading predictions from: ..\results\notebook_multitask_test\all_cancers\kfold_prediction\fold_5\predictions.csv
Loading probabilities from: ..\results\notebook_multitask_test\all_cancers\kfold_prediction\fold_5\probabilities.csv


  _warn_prf(average, modifier, f"{metric.capitalize()} is", result.shape[0])
  _warn_prf(average, modifier, f"{metric.capitalize()} is", result.shape[0])
  _warn_prf(average, modifier, f"{metric.capitalize()} is", result.shape[0])
  _warn_prf(average, modifier, f"{metric.capitalize()} is", result.shape[0])
  _warn_prf(average, modifier, f"{metric.capitalize()} is", result.shape[0])
  _warn_prf(average, modifier, f"{metric.capitalize()} is", result.shape[0])
  _warn_prf(average, modifier, f"{metric.capitalize()} is", result.shape[0])
  _warn_prf(average, modifier, f"{metric.capitalize()} is", result.shape[0])
  _warn_prf(average, modifier, f"{metric.capitalize()} is", result.shape[0])
  _warn_prf(average, modifier, f"{metric.capitalize()} is", result.shape[0])
  _warn_prf(average, modifier, f"{metric.capitalize()} is", result.shape[0])
  _warn_prf(average, modifier, f"{metric.capitalize()} is", result.shape[0])
  _warn_prf(average, modifier, f"{metric.capitalize()} is", result.shape[0])


Computing per-cancer-type metrics...
   Cancer types: ['ACC', 'BLCA', 'BRCA', 'CESC', 'CHOL', 'COAD', 'ESCA', 'GBM', 'HNSC', 'KICH', 'KIRC', 'KIRP', 'LGG', 'LIHC', 'LUAD', 'LUSC', 'MESO', 'OV', 'PAAD', 'PCPG', 'PRAD', 'READ', 'SARC', 'SKCM', 'STAD', 'TGCT', 'THCA', 'THYM', 'UCEC', 'UVM']
   Total samples: 1717
   Genes: 4
   ACC: 15 samples
   BLCA: 73 samples
   BRCA: 152 samples
   CESC: 44 samples
   CHOL: 5 samples
   COAD: 61 samples
   ESCA: 33 samples
   GBM: 41 samples
   HNSC: 84 samples
   KICH: 16 samples
   KIRC: 87 samples
   KIRP: 60 samples


  _warn_prf(average, modifier, f"{metric.capitalize()} is", result.shape[0])
  _warn_prf(average, modifier, f"{metric.capitalize()} is", result.shape[0])
  _warn_prf(average, modifier, f"{metric.capitalize()} is", result.shape[0])
  _warn_prf(average, modifier, f"{metric.capitalize()} is", result.shape[0])
  _warn_prf(average, modifier, f"{metric.capitalize()} is", result.shape[0])
  _warn_prf(average, modifier, f"{metric.capitalize()} is", result.shape[0])
  _warn_prf(average, modifier, f"{metric.capitalize()} is", result.shape[0])
  _warn_prf(average, modifier, f"{metric.capitalize()} is", result.shape[0])
  _warn_prf(average, modifier, f"{metric.capitalize()} is", result.shape[0])
  _warn_prf(average, modifier, f"{metric.capitalize()} is", result.shape[0])


   LGG: 101 samples
   LIHC: 75 samples
   LUAD: 99 samples
   LUSC: 102 samples
   MESO: 15 samples
   OV: 7 samples
   PAAD: 32 samples
   PCPG: 40 samples
   PRAD: 92 samples
   READ: 16 samples
   SARC: 39 samples
   SKCM: 82 samples


  _warn_prf(average, modifier, f"{metric.capitalize()} is", result.shape[0])
  _warn_prf(average, modifier, f"{metric.capitalize()} is", result.shape[0])
  _warn_prf(average, modifier, f"{metric.capitalize()} is", result.shape[0])
  _warn_prf(average, modifier, f"{metric.capitalize()} is", result.shape[0])
  _warn_prf(average, modifier, f"{metric.capitalize()} is", result.shape[0])
  _warn_prf(average, modifier, f"{metric.capitalize()} is", result.shape[0])
  _warn_prf(average, modifier, f"{metric.capitalize()} is", result.shape[0])
  _warn_prf(average, modifier, f"{metric.capitalize()} is", result.shape[0])
  _warn_prf(average, modifier, f"{metric.capitalize()} is", result.shape[0])
  _warn_prf(average, modifier, f"{metric.capitalize()} is", result.shape[0])
  _warn_prf(average, modifier, f"{metric.capitalize()} is", result.shape[0])
  _warn_prf(average, modifier, f"{metric.capitalize()} is", result.shape[0])
  _warn_prf(average, modifier, f"{metric.capitalize()} is", result.shape[0])

   STAD: 92 samples
   TGCT: 31 samples
   THCA: 99 samples
   THYM: 29 samples
   UCEC: 83 samples
   UVM: 12 samples

Aggregating results across 5 folds...


  _warn_prf(average, modifier, f"{metric.capitalize()} is", result.shape[0])
  _warn_prf(average, modifier, f"{metric.capitalize()} is", result.shape[0])
  _warn_prf(average, modifier, f"{metric.capitalize()} is", result.shape[0])

KeyboardInterrupt



In [None]:
# Compute Overall Metrics from Combined Predictions
#
# Loads the combined k-fold predictions and probabilities from
# `kfold_output_dir / "combined_predictions"`, aligns them with
# `mutation_data_subset` (samples and gene columns), scores them with
# `evaluate_multilabel`, then writes a metrics CSV and a summary bar chart
# under `viz_output_dir`.
#
# Depends on names defined in earlier cells: `kfold_output_dir`,
# `mutation_data_subset`, `viz_output_dir`, `key_metrics`.
print("="*60)
print("Computing Overall Metrics from Combined Predictions")
print("="*60)

# `compute_per_cancer_metrics` is not used in this cell; the import is kept
# because later cells may rely on it being in the kernel namespace.
from evaluation.metrics import compute_per_cancer_metrics, evaluate_multilabel

# Load combined predictions
combined_dir = kfold_output_dir / "combined_predictions"
combined_preds_path = combined_dir / "predictions.csv"
combined_probs_path = combined_dir / "probabilities.csv"

if combined_preds_path.exists() and combined_probs_path.exists():
    print(f"\nLoading combined predictions from: {combined_dir}")
    combined_preds = pd.read_csv(combined_preds_path, index_col=0)
    combined_probs = pd.read_csv(combined_probs_path, index_col=0)

    print(f"   Combined predictions: {len(combined_preds)} samples, {len(combined_preds.columns)} genes")

    # Align samples across ALL three frames. Intersecting with the
    # probabilities index as well prevents a KeyError when the two CSVs do
    # not cover exactly the same samples.
    common_samples = (
        mutation_data_subset.index
        .intersection(combined_preds.index)
        .intersection(combined_probs.index)
    )
    print(f"   Common samples: {len(common_samples)}")
    if len(common_samples) == 0:
        raise ValueError(
            "No samples are shared between mutation_data_subset and the combined "
            f"predictions/probabilities in {combined_dir}"
        )

    gene_names = mutation_data_subset.columns.tolist()
    y_true_overall = mutation_data_subset.loc[common_samples]
    y_pred_overall = combined_preds.loc[common_samples]
    y_prob_overall = combined_probs.loc[common_samples]

    # The arrays below are compared positionally, so the gene columns must be
    # in the same order as `gene_names`. Reorder only when both CSVs carry
    # exactly the same (unique) gene labels; otherwise keep positional order
    # and let the shape check below catch a mismatch.
    labels_match = all(
        cols.is_unique and len(cols) == len(gene_names) and set(cols) == set(gene_names)
        for cols in (combined_preds.columns, combined_probs.columns)
    )
    if labels_match:
        y_pred_overall = y_pred_overall[gene_names]
        y_prob_overall = y_prob_overall[gene_names]

    # Guards against duplicated sample IDs (extra rows from .loc) and against
    # a different number of gene columns in the CSVs.
    if not (y_true_overall.shape == y_pred_overall.shape == y_prob_overall.shape):
        raise ValueError(
            "Shape mismatch after alignment: "
            f"y_true={y_true_overall.shape}, y_pred={y_pred_overall.shape}, "
            f"y_prob={y_prob_overall.shape}"
        )

    # Compute overall metrics (across all cancer types)
    overall_metrics = evaluate_multilabel(
        y_true=y_true_overall.values,
        y_pred=y_pred_overall.values,
        y_prob=y_prob_overall.values,
        mutation_names=gene_names
    )

    # Save overall metrics (make sure the output directory exists first)
    viz_output_dir.mkdir(parents=True, exist_ok=True)
    overall_metrics_path = viz_output_dir / 'overall_metrics_all_genes.csv'
    overall_metrics.to_csv(overall_metrics_path)

    # Print summary
    key_metric_values = overall_metrics[key_metrics]
    metric_means = key_metric_values.mean()
    metric_stds = key_metric_values.std()
    print("\nOverall Metrics (across all cancer types):")
    print(metric_means.round(3))

    # Create summary visualization: mean of each key metric, with its
    # standard deviation across rows of `overall_metrics` as error bars.
    fig, ax = plt.subplots(figsize=(10, 6))
    bar_positions = range(len(metric_means))
    ax.bar(bar_positions, metric_means.values,
           yerr=metric_stds.values, capsize=5, alpha=0.7)
    ax.set_xticks(bar_positions)
    ax.set_xticklabels(metric_means.index, rotation=45, ha='right')
    ax.set_ylabel('Metric Value')
    ax.set_title('Overall Performance Metrics (All Cancer Types Combined)')
    ax.grid(axis='y', alpha=0.3)
    fig.tight_layout()
    fig.savefig(viz_output_dir / 'overall_metrics_summary.png', dpi=150, bbox_inches='tight')
    # Close this specific figure so it is neither displayed inline nor kept in memory.
    plt.close(fig)

    print(f"\n   Overall metrics saved to: {overall_metrics_path}")
else:
    print(f"   Warning: Combined predictions not found at {combined_dir}")
    print("   Run k-fold training first to generate combined predictions.")
