# Mutation Prediction from Gene Expression - Data Exploration

This notebook demonstrates how to:
1. Load and preprocess TCGA expression and mutation data
2. Run gene ablation analysis
3. Perform SHAP analysis for model interpretation

## Setup Instructions

1. Make sure you have installed all dependencies: `pip install -r requirements.txt`
2. Update the data paths in `config/config.yaml` to point to your data directories
3. Run the cells sequentially


In [8]:
import copy
import os
import sys
from pathlib import Path

import lightgbm
import numpy as np
import pandas as pd
from sklearn.preprocessing import StandardScaler

# Make the project's src/ directory importable before loading local modules
sys.path.append(os.path.abspath('../src'))

from preprocessing import data_loader
from models.model_factory import ModelFactory
from main import load_config, run_kfold_training, run_train_test_split
from interpretation.ablation import run_gene_ablation_analysis


In [9]:
# Initialize components
selected_cohorts = ["LUAD", "BRCA"]
data_load = data_loader.TCGADataLoader(use_cache=True)

# Load and preprocess data for the selected cohorts
print(f"Loading and preprocessing data for: {', '.join(selected_cohorts)}")
expression_data, mutation_data = data_load.preprocess_data(cancer_types=selected_cohorts)

Loading and preprocessing data for: LUAD, BRCA
Loading cached aligned expression and mutation data from:
C:\Users\KerenYlab.MEDICINE\OneDrive - Technion\Asaf\Expression_to_Mutation\mutation_prediction\cache\expression_aligned_BRCA-LUAD.pkl C:\Users\KerenYlab.MEDICINE\OneDrive - Technion\Asaf\Expression_to_Mutation\mutation_prediction\cache\mutation_aligned_BRCA-LUAD.pkl


In [10]:
config_path = Path("../config/config.yaml")
config = load_config(config_path)

X_log = np.log1p(expression_data)
scaler = StandardScaler()
X_scaled = scaler.fit_transform(X_log)

cohort_suffix = "-".join(selected_cohorts)
results_root = Path("../results/notebook_multitask_test") / cohort_suffix
results_root.mkdir(parents=True, exist_ok=True)

# Convert to DataFrame for ablation functions
X_scaled_df = pd.DataFrame(X_scaled, index=expression_data.index, columns=expression_data.columns)
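
The preprocessing in the cell above applies `np.log1p` and then a per-gene z-score. As a minimal sketch of what those two steps do, the toy example below repeats them with plain NumPy on a made-up 3x2 matrix (the values are illustrative, not TCGA data). It assumes the default `StandardScaler` behavior, which centers each column and divides by the population standard deviation (`ddof=0`).

```python
import numpy as np

# Toy expression matrix: 3 samples x 2 genes (made-up values, not TCGA data).
toy_expression = np.array([
    [0.0, 10.0],
    [3.0, 100.0],
    [8.0, 1000.0],
])

# Step 1: log1p compresses the dynamic range and maps a count of 0 to 0.
toy_log = np.log1p(toy_expression)

# Step 2: per-gene z-score using the population std (ddof=0),
# which is what StandardScaler computes with its default settings.
toy_scaled = (toy_log - toy_log.mean(axis=0)) / toy_log.std(axis=0)

print(toy_scaled.mean(axis=0))  # each column is centered near 0
print(toy_scaled.std(axis=0))   # each column has unit standard deviation
```

After scaling, every gene has mean 0 and unit variance across samples, so genes with very different raw expression ranges contribute on a comparable scale.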


In [11]:
# Import ablation functions
from interpretation.ablation import (
    run_gene_ablation_analysis,
    load_ablation_results
)

# Initialize model factory
model_factory = ModelFactory()

# For a quick test, use only the first 4 mutation genes.
# To run on all genes, replace the next line with: test_genes = mutation_data.columns.tolist()
test_genes = mutation_data.columns[:4].tolist()
mutation_data_subset = mutation_data[test_genes].copy()

print(f"Expression data shape: {X_scaled_df.shape}")
print(f"Mutation data shape: {mutation_data_subset.shape}")
print(f"Number of genes: {len(mutation_data_subset.columns)}")
print(f"Genes: {', '.join(mutation_data_subset.columns.tolist())}")


Expression data shape: (1285, 19964)
Mutation data shape: (1285, 4)
Number of genes: 4
Genes: ZNF831, OBSCN, ABCA13, ASTN1


In [12]:
# Use the multitask neural network for the ablation run
config['model']['name'] = 'multitask_nn'

In [13]:
# Create a dict mapping each sample ID to its cancer type.
# X_scaled_df keeps the cancer_* columns, but after log1p and standardization they are
# no longer 0/1 indicators, so build the mapping from the raw expression_data
# and pass it to the ablation function.
sample_to_cancer = {}
for cancer_col in expression_data.columns[expression_data.columns.str.startswith('cancer_')]:
    cancer_type_name = cancer_col.replace('cancer_', '')
    samples_with_cancer = expression_data.index[expression_data[cancer_col] == 1]
    for sample_id in samples_with_cancer:
        sample_to_cancer[sample_id] = cancer_type_name

print(f"Created sample_to_cancer mapping with {len(sample_to_cancer)} samples")
print(f"Cancer types: {set(sample_to_cancer.values())}")

Created sample_to_cancer mapping with 1285 samples
Cancer types: {'LUAD', 'BRCA'}
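
The nested loop above can also be written without explicit iteration. The sketch below is one possible vectorized variant on a toy frame; the column names `GENE_A`, `cancer_BRCA`, `cancer_LUAD` and the sample IDs are made up for illustration. Unlike the loop, it keeps only samples flagged for exactly one cancer type, which is an assumption about the data rather than something the loader guarantees.

```python
import pandas as pd

# Toy frame with one expression column and two one-hot cancer indicators.
toy = pd.DataFrame(
    {
        "GENE_A": [1.2, 0.4, 3.3],
        "cancer_BRCA": [1, 0, 1],
        "cancer_LUAD": [0, 1, 0],
    },
    index=["s1", "s2", "s3"],
)

prefix = "cancer_"
cancer_cols = [c for c in toy.columns if c.startswith(prefix)]
indicators = toy[cancer_cols]

# Keep samples flagged for exactly one cancer type (assumption, see lead-in).
single_label = indicators[indicators.sum(axis=1) == 1]

# idxmax returns the name of the flagged column; slice off the prefix.
labels = single_label.idxmax(axis=1).str[len(prefix):]
toy_sample_to_cancer = labels.to_dict()

print(toy_sample_to_cancer)
```

The result has the same shape as `sample_to_cancer`: a plain dict keyed by sample ID.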


In [14]:
# Run ablation analysis
# This will train a model for each gene removal
ablation_results = run_gene_ablation_analysis(
    model_factory=model_factory,
    X=X_scaled_df,
    Y=mutation_data_subset,
    config=config,
    output_dir=results_root,
    sample_to_cancer=sample_to_cancer,
    eval_mode="kfold",  # Use "kfold" for cross-validation
    test_size=0.2,
    k=2,
    random_state=42,
)



🔬 Running ablation with baseline comparison for 4 genes...
   Mode: kfold
   Folds: 2

  Fold 1/2...

  Fold 2/2...

📊 Combining predictions from 2 folds...
   ✓ Baseline predictions combined (1285 samples)
   ✓ Ablation predictions combined for 4 genes
   Combined predictions saved to: ..\results\notebook_multitask_test\LUAD-BRCA\ablation\combined_predictions

📊 Computing difference matrices by cancer type from combined predictions...
   ✓ Clustermap saved for f1
   ✓ Clustermap saved for roc_auc
   ✓ Clustermap saved for auprc
   ✓ Clustermap saved for accuracy
   ✓ Clustermap saved for precision
   ✓ Clustermap saved for recall
   ✓ Clustermap saved for specificity
   ✓ Clustermap saved for mcc
     ✓ BRCA (8 metrics, 778 samples)
   ✓ Clustermap saved for f1
   ✓ Clustermap saved for roc_auc
   ✓ Clustermap saved for auprc
   ✓ Clustermap saved for accuracy
   ✓ Clustermap saved for precision
   ✓ Clustermap saved for recall
   ✓ Clustermap saved for specificity
   ✓ Clustermap saved for mcc
     ✓ LUAD (8 metrics, 507 samples)

📊 Building differe
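
The log above mentions difference matrices computed per cancer type from the combined baseline and ablation predictions. The internals of `run_gene_ablation_analysis` are not shown in this notebook, so the following is only a hypothetical sketch of that idea for a single target gene and a single metric (accuracy at a 0.5 threshold). The toy predictions, the helper `accuracy`, and the sign convention (ablated minus baseline) are all assumptions made for illustration.

```python
import pandas as pd

samples = ["s1", "s2", "s3", "s4", "s5", "s6"]
cancer = pd.Series(["BRCA", "BRCA", "BRCA", "LUAD", "LUAD", "LUAD"], index=samples)
y_true = pd.Series([1, 0, 1, 0, 1, 0], index=samples)

# Made-up predicted probabilities from a baseline model and from a model
# retrained after ablating one gene.
baseline_prob = pd.Series([0.9, 0.2, 0.7, 0.1, 0.8, 0.3], index=samples)
ablated_prob = pd.Series([0.6, 0.2, 0.4, 0.1, 0.8, 0.3], index=samples)


def accuracy(y, prob, threshold=0.5):
    """Fraction of samples whose thresholded prediction matches the label."""
    return float(((prob >= threshold).astype(int) == y).mean())


rows = {}
for cancer_type in sorted(cancer.unique()):
    mask = cancer == cancer_type
    base = accuracy(y_true[mask], baseline_prob[mask])
    abl = accuracy(y_true[mask], ablated_prob[mask])
    # Assumed convention: negative values mean the ablation hurt performance.
    rows[cancer_type] = {"baseline": base, "ablated": abl, "difference": abl - base}

diff_table = pd.DataFrame.from_dict(rows, orient="index")
print(diff_table)
```

In this toy case the ablation lowers BRCA accuracy and leaves LUAD unchanged, which is the kind of cohort-specific effect a per-cancer-type difference matrix is meant to surface.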

In [None]:
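# Train a tree-based model for SHAP analysis.
# The SHAP cells below use TreeExplainer, and their saved output reports a 'lightgbm' model,
# so config['model']['name'] must name a tree-based model here rather than the
# 'multitask_nn' value set earlier for the ablation run.
# SHAP explains an already-fitted model, so this notebook fits on the full dataset
# without a train-test split.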
print("Training tree-based model for SHAP analysis...")

model_shap = model_factory.get_model(
    model_name=config['model']['name'],
    input_size=X_scaled_df.shape[1],
    output_size=len(mutation_data_subset.columns),
    config=config,
)

print(f"   Training on {len(X_scaled_df)} samples...")
model_shap.fit(X_scaled_df.values, mutation_data_subset.values)

print("   Model trained successfully!")

🔬 Training tree-based model for SHAP analysis...
   Training on 1285 samples...
[LightGBM] [Info] Number of positive: 66, number of negative: 1219
[LightGBM] [Info] Auto-choosing col-wise multi-threading, the overhead of testing was 1.025880 seconds.
You can set `force_col_wise=true` to remove the overhead.
[LightGBM] [Info] Total Bins 3880826
[LightGBM] [Info] Number of data points in the train set: 1285, number of used features: 19292
[LightGBM] [Info] [binary:BoostFromScore]: pavg=0.051362 -> initscore=-2.916131
[LightGBM] [Info] Start training from score -2.916131
[LightGBM] [Info] Number of positive: 120, number of negative: 1165
[LightGBM] [Info] Auto-choosing col-wise multi-threading, the overhead of testing was 2.085411 seconds.
You can set `force_col_wise=true` to remove the overhead.
[LightGBM] [Info] Total Bins 3880826
[LightGBM] [Info] Number of data points in the train set: 1285, number of used features: 19292
[LightGBM] [Info] [binary:BoostFromScore]: pavg=0.093385 -> 
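
The `[binary:BoostFromScore]` lines show how imbalanced the mutation labels are. For the binary objective with default settings, LightGBM starts boosting from the log-odds of the positive-class rate. The short check below recomputes the logged `pavg` and `initscore` for the first gene from the counts in the log (66 positives, 1219 negatives). It assumes no class weighting is in effect, which is consistent with the logged numbers.

```python
import math

# Counts taken from the first "[LightGBM] [Info] Number of positive" line above.
n_pos, n_neg = 66, 1219

# Positive-class rate and its log-odds (the initial raw score for boosting).
p_avg = n_pos / (n_pos + n_neg)
init_score = math.log(p_avg / (1.0 - p_avg))

print(f"pavg={p_avg:.6f} -> initscore={init_score:.6f}")
```

With only about 5% positives for this gene, accuracy alone is a weak metric; this is one reason the ablation step also reports AUPRC, MCC, and recall.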

In [None]:
# Run SHAP Analysis for All Genes
try:
    from interpretation.shap_analysis import compute_shap_for_all_genes
    from visualization.shap import visualize_all_shap_results
    
    print("\nRunning SHAP Analysis...")
    print("="*60)
    
    # Setup output directory
    shap_output_dir = results_root / "shap_analysis" / "combined"
    shap_output_dir.mkdir(parents=True, exist_ok=True)
    
    gene_names = mutation_data_subset.columns.tolist()
    
    # Optional: Select specific genes to explain (set to None to analyze all genes)
    selected_genes = None  # Example: selected_genes = ['ZNF831', 'OBSCN']
    
    print(f"   Model: {config['model']['name']}")
    if selected_genes:
        print(f"   Selected genes: {len(selected_genes)} ({', '.join(selected_genes)})")
    else:
        print(f"   Genes: {len(gene_names)} ({', '.join(gene_names)})")
    print(f"   Dataset: {len(X_scaled_df)} samples (full dataset)")
    
    # Compute SHAP values for all genes (or selected genes if specified)
    shap_results = compute_shap_for_all_genes(
        model=model_shap,
        X_explain=X_scaled_df,  # Full dataset (no train-test split needed)
        gene_names=gene_names,
        save_dir=shap_output_dir,
        selected_genes=selected_genes,  # Optional: analyze only selected genes
    )
    
    print(f"\n   SHAP computation complete!")
    print(f"   Results for {len(shap_results)} genes")
    
    # Visualize per gene (summary plots and bar plots)
    print(f"\n   Creating visualizations...")
    visualize_all_shap_results(
        all_shap_results=shap_results,
        output_dir=shap_output_dir,
        max_features=50,  # Top 50 features
        include_heatmap=False,  # Skip per-sample heatmaps
    )
    
    print(f"\n   Visualizations saved to: {shap_output_dir}")
    
except ImportError as e:
    print(f"Warning: SHAP not available. Install with: pip install shap")
    print(f"   Error: {e}")
except Exception as e:
    print(f"Error: {e}")
    raise



🔬 Running SHAP Analysis...
   Model: lightgbm
   Genes: 4 (ZNF831, OBSCN, ABCA13, ASTN1)
   Dataset: 1285 samples (full dataset)

🔬 Computing SHAP values for 4 genes...
   Dataset: 1285 samples

  [1/4] Processing ZNF831...
   Computing SHAP for ZNF831 using TreeExplainer...

  [2/4] Processing OBSCN...
   Computing SHAP for OBSCN using TreeExplainer...

  [3/4] Processing ABCA13...
   Computing SHAP for ABCA13 using TreeExplainer...

  [4/4] Processing ASTN1...
   Computing SHAP for ASTN1 using TreeExplainer...

✅ SHAP computation complete for 4 genes

   ✅ SHAP computation complete!
   Results for 4 genes

   📊 Creating visualizations...

📊 Creating model-level SHAP visualizations for 4 genes...
   Creating model-level visualizations for ZNF831...
     ✓ Visualizations saved to ..\results\notebook_multitask_test\LUAD-BRCA\shap_analysis\combined\ZNF831
   Creating model-level visualizations for OBSCN...
     ✓ Visualizations saved to ..\results\notebook_multitask_test\LUAD-BRCA\shap_analysis\combined\OBSCN
   Creating model-level visualizations for ABCA13...
     ✓ Visualizations saved to ..\results\notebook_multitask_test\LUAD-BRCA\shap_analysis\combined\ABCA13
   Creating model-level visualizations for ASTN1...
     ✓ Visualizations sa
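
The visualization call above limits plots to the top 50 features (`max_features=50`). How `visualize_all_shap_results` ranks features is not shown here; a common convention for SHAP bar plots is to rank by mean absolute SHAP value across samples. The sketch below illustrates that convention on a made-up SHAP matrix of shape `(n_samples, n_features)`; the feature names and values are hypothetical.

```python
import numpy as np
import pandas as pd

feature_names = ["gene_A", "gene_B", "gene_C", "gene_D"]

# Made-up SHAP values: rows are samples, columns are expression features.
toy_shap = np.array([
    [0.1, -2.0, 0.5, 0.0],
    [-0.2, 1.0, -0.5, 0.0],
    [0.3, -3.0, 0.2, 0.0],
])

# Global importance: mean absolute contribution of each feature.
mean_abs_shap = pd.Series(np.abs(toy_shap).mean(axis=0), index=feature_names)

# Keep the top-k features, analogous to max_features in the call above.
top_k = 2
top_features = mean_abs_shap.sort_values(ascending=False).head(top_k)
print(top_features)
```

Taking the absolute value first matters: a feature that pushes predictions strongly in both directions would otherwise average out to roughly zero.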

In [None]:
# Run SHAP Analysis Per Cancer Type (Optional)
try:
    from interpretation.shap_analysis import compute_shap_per_cancer_type
    from visualization.shap import visualize_all_shap_results
    
    print("\nRunning SHAP Analysis Per Cancer Type...")
    print("="*60)
    
    # Setup output directory
    shap_cancer_output_dir = results_root / "shap_analysis" / "per_cancer_type"
    
    gene_names = mutation_data_subset.columns.tolist()
    cancer_type_names = sorted(set(sample_to_cancer.values()))
    
    # Optional: Select specific genes to explain (set to None to analyze all genes)
    selected_genes = None  # Example: selected_genes = ['ZNF831', 'OBSCN']
    
    print(f"   Cancer types: {cancer_type_names}")
    if selected_genes:
        print(f"   Selected genes: {len(selected_genes)} ({', '.join(selected_genes)})")
    else:
        print(f"   Genes: {len(gene_names)} ({', '.join(gene_names)})")
    print(f"   Dataset: {len(X_scaled_df)} samples")
    
    # Compute SHAP per cancer type
    per_cancer_shap_results = compute_shap_per_cancer_type(
        model=model_shap,
        X_explain=X_scaled_df,  # Full dataset
        gene_names=gene_names,
        sample_to_cancer=sample_to_cancer,  # Dict built in the sample_to_cancer cell above (In [13])
        output_dir=shap_cancer_output_dir,
        selected_genes=selected_genes,  # Optional: analyze only selected genes
    )
    
    print(f"\n   Per-cancer-type SHAP analysis complete!")
    
    # Visualize per cancer type
    print(f"\n   Creating per-cancer-type visualizations...")
    for cancer_type, gene_results in per_cancer_shap_results.items():
        if len(gene_results) > 0:
            print(f"     Processing {cancer_type}...")
            cancer_dir = shap_cancer_output_dir / "cancer_type" / cancer_type
            visualize_all_shap_results(
                all_shap_results=gene_results,
                output_dir=cancer_dir,
                max_features=50,
                include_heatmap=False,
            )
            print(f"     {cancer_type} complete")
    
    print(f"\n   All per-cancer-type visualizations saved to: {shap_cancer_output_dir}")
    
except ImportError as e:
    print(f"Warning: SHAP not available. Install with: pip install shap")
    print(f"   Error: {e}")
except Exception as e:
    print(f"Error: {e}")
    raise



🔬 Running SHAP Analysis Per Cancer Type...
   Cancer types: ['BRCA', 'LUAD']
   Genes: 4 (ZNF831, OBSCN, ABCA13, ASTN1)
   Dataset: 1285 samples

🔬 Running SHAP analysis per cancer type...
   Cancer types: ['BRCA', 'LUAD']
   Total samples: 1285

Processing: BRCA
   Samples: 778

🔬 Computing SHAP values for 4 genes...
   Dataset: 778 samples

  [1/4] Processing ZNF831...
   Computing SHAP for ZNF831 using TreeExplainer...

  [2/4] Processing OBSCN...
   Computing SHAP for OBSCN using TreeExplainer...

  [3/4] Processing ABCA13...
   Computing SHAP for ABCA13 using TreeExplainer...

  [4/4] Processing ASTN1...
   Computing SHAP for ASTN1 using TreeExplainer...

✅ SHAP computation complete for 4 genes
   ✅ BRCA complete! (4 genes)

Processing: LUAD
   Samples: 507

🔬 Computing SHAP values for 4 genes...
   Dataset: 507 samples

  [1/4] Processing ZNF831...
   Computing SHAP for ZNF831 using TreeExplainer...

  [2/4] Processing OBSCN...
   Computing SHAP for OBSCN using TreeExplainer...

  [3/4] Processing ABCA13...
   Computing SHAP for ABCA13 using TreeExplainer...

  [4/4] Processing ASTN1...
   Computing SHAP for ASTN1 using TreeExplainer...

✅ SHAP computation complete for 4 genes
   ✅ LUAD complete! (4 genes)

   ✅ Per-cancer-type SHAP analysis complete!

   📊 Creating per-cancer-type visualizations...
     Processing BRCA...

📊 Creating model-level SHAP visualizations for 4 genes.
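
The per-cancer-type run explains the same fitted model on one cohort at a time (778 BRCA and 507 LUAD samples in the log above). The internals of `compute_shap_per_cancer_type` are not shown, so the following is only a sketch of how a sample-to-cancer dict can be used to split a feature matrix into per-cohort subsets. The toy frame, the sample IDs, and the choice to skip unlabeled samples are assumptions for illustration.

```python
import pandas as pd

# Toy scaled expression matrix with four samples.
toy_X = pd.DataFrame(
    {"gene_A": [0.1, -0.4, 1.2, 0.7], "gene_B": [1.0, 0.3, -0.8, 0.0]},
    index=["s1", "s2", "s3", "s4"],
)

# Sample-to-cancer mapping; "s4" is deliberately left unlabeled.
toy_map = {"s1": "BRCA", "s2": "LUAD", "s3": "BRCA"}

# Align labels to the rows of toy_X; unlabeled samples become NaN.
labels = pd.Series(toy_map).reindex(toy_X.index)

# One sub-frame per cancer type; NaN labels never match and are skipped.
subsets = {
    cancer_type: toy_X.loc[labels == cancer_type]
    for cancer_type in sorted(labels.dropna().unique())
}

for cancer_type, subset in subsets.items():
    print(cancer_type, subset.index.tolist())
```

Each subset could then be passed to the explainer in place of the full matrix, which is what the per-cohort sample counts in the log suggest is happening.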