In [3]:
import copy
import os
import sys
from pathlib import Path

import lightgbm
import numpy as np
import pandas as pd
from sklearn.preprocessing import StandardScaler

# Make the project's src/ directory importable before loading project modules.
sys.path.append(os.path.abspath('../src'))

from preprocessing import data_loader
from models.model_factory import ModelFactory
from main import load_config, run_kfold_training, run_train_test_split


In [4]:
# Cohort selection and data loading.
# The loader is built with caching switched on; `preprocess_data` hands back
# two objects that the rest of the notebook uses as features and labels.
selected_cohorts = ["LUAD", "BRCA"]
data_load = data_loader.TCGADataLoader(use_cache=True)

# Announce the cohorts first, then fetch both tables in a single call.
print(f"Loading and preprocessing data for: {', '.join(selected_cohorts)}")
expression_data, mutation_data = data_load.preprocess_data(
    cancer_types=selected_cohorts,
)

Loading and preprocessing data for: LUAD, BRCA
Loading cached aligned expression and mutation data from:
C:\Users\KerenYlab.MEDICINE\OneDrive - Technion\Asaf\Expression_to_Mutation\mutation_prediction\cache\expression_aligned_BRCA-LUAD.pkl C:\Users\KerenYlab.MEDICINE\OneDrive - Technion\Asaf\Expression_to_Mutation\mutation_prediction\cache\mutation_aligned_BRCA-LUAD.pkl


In [None]:

# Project configuration, read from the YAML file one level above the notebook.
config_path = Path("../config/config.yaml")
config = load_config(config_path)

# Feature preprocessing: log(1 + x) followed by standardisation.
# The fitted scaler is kept so it can be inspected or reused afterwards.
X_log = np.log1p(expression_data)
scaler = StandardScaler().fit(X_log)
X_scaled = scaler.transform(X_log)

# All notebook artefacts go under a folder named after the joined cohorts.
cohort_suffix = "-".join(selected_cohorts)
results_root = Path("../results/notebook_multitask_test", cohort_suffix)
results_root.mkdir(parents=True, exist_ok=True)



In [None]:

# Benchmark several model families on the same scaled features and labels.
#
# Inputs taken from earlier cells: `config`, `X_scaled`, `mutation_data`,
# `selected_cohorts` and `results_root` (the results folder is created in the
# preprocessing cell, so it is not rebuilt here). `pd` comes from the
# notebook's import cell; without it the final summary would raise NameError
# only after every model had already been trained.
models_to_compare = ["multitask_nn", "neural_net", "lightgbm"]
comparison_rows = []

for model_name in models_to_compare:
    # Work on a private copy so switching the model name never leaks into `config`.
    config_variant = copy.deepcopy(config)
    config_variant['model']['name'] = model_name

    factory = ModelFactory()
    model = factory.get_model(
        model_name=model_name,
        input_size=X_scaled.shape[1],
        output_size=mutation_data.shape[1],
        config=config_variant,
    )

    # Each model writes its artefacts into its own sub-folder.
    model_dir = results_root / model_name
    metrics = run_train_test_split(
        model=model,
        X=X_scaled,
        Y=mutation_data,
        test_size=config_variant['preprocessing']['test_size'],
        output_dir=model_dir,
        config_meta={
            'config': config_variant,
            'cohorts': selected_cohorts,
            'model_name': model_name,
        },
        random_state=config_variant['preprocessing']['random_state'],
        label=model_name,
    )

    # Collapse the returned metrics to one mean value per metric
    # (presumably averaged over genes — confirm against run_train_test_split).
    mean_metrics = metrics.mean()
    comparison_rows.append({'model': model_name, **mean_metrics.to_dict()})

# One row per model, best mean F1 first; the last expression is the cell output.
comparison_df = pd.DataFrame(comparison_rows).set_index('model')
comparison_df.sort_values('f1', ascending=False)




✅ Train/test evaluation complete.
                    f1   roc_auc  accuracy
multitask_nn  0.304076  0.794594  0.723675
✅ Train/test evaluation complete.
                  f1  roc_auc  accuracy
neural_net  0.301886  0.78241  0.845082
[LightGBM] [Info] Number of positive: 52, number of negative: 976
[LightGBM] [Info] Auto-choosing col-wise multi-threading, the overhead of testing was 0.964833 seconds.
You can set `force_col_wise=true` to remove the overhead.
[LightGBM] [Info] Total Bins 3800537
[LightGBM] [Info] Number of data points in the train set: 1028, number of used features: 19210
[LightGBM] [Info] [binary:BoostFromScore]: pavg=0.050584 -> initscore=-2.932219
[LightGBM] [Info] Start training from score -2.932219
[LightGBM] [Info] Number of positive: 95, number of negative: 933
[LightGBM] [Info] Auto-choosing col-wise multi-threading, the overhead of testing was 0.935503 seconds.
You can set `force_col_wise=true` to remove the overhead.
[LightGBM] [Info] Total Bins 3800537
[Li

In [7]:
# Fit the multitask network on every sample and pull out its head weights.
from training.extract_weights import train_and_extract_head_weights
import pandas as pd

# Same preprocessing as the earlier cells: log1p, then standardisation.
X_log = np.log1p(expression_data)
scaler = StandardScaler()
X_scaled = scaler.fit_transform(X_log)

# Start from a private copy of the base config and override the architecture.
config_weights = copy.deepcopy(config)
config_weights['model']['name'] = 'multitask_nn'
mt_settings = config_weights['model']['multitask_nn']
mt_settings.update(
    hidden_layers=[2048, 1536, 1024],
    head_layers=[],  # Simple linear output
    dropout_rate=0.2,
    learning_rate=0.0005,
    epochs=60,
    batch_size=128,
)

# Destination folder for the trained model and the extracted weights.
weights_output_dir = results_root / "head_weights_extraction"
weights_output_dir.mkdir(parents=True, exist_ok=True)

print("Training multitask model on full dataset...")
print(f"Hidden layers: {mt_settings['hidden_layers']}")
print(f"Head layers: {mt_settings['head_layers']} (simple linear)")
print(f"Output directory: {weights_output_dir}")

# Train on the full data set and collect the head weights.
factory = ModelFactory()
head_weights = train_and_extract_head_weights(
    model_factory=factory,
    X=X_scaled,
    Y=mutation_data,
    config=config_weights,
    hidden_layers=[2048, 1536, 1024],
    head_layers=None,  # Simple linear output
    output_dir=weights_output_dir,
)

# Report what came back.
rule = "=" * 60
print("\n" + rule)
print("Head Weights Extraction Complete!")
print(rule)
print(f"Weight matrix shape: {head_weights['weight_matrix_shape']}")
print(f"Number of genes: {len(head_weights['gene_names'])}")
print("\nFirst 10 genes:")
for gene in head_weights['gene_names'][:10]:
    print(f"  - {gene}")

# Rank genes by the L2 norm of their row in the head weight matrix.
weight_matrix = head_weights['all_weights'].cpu().numpy()
weight_norms = np.linalg.norm(weight_matrix, axis=1)
weights_df = pd.DataFrame(
    {'gene': head_weights['gene_names'], 'weight_norm': weight_norms}
).sort_values('weight_norm', ascending=False)

print("\nTop 10 genes by weight norm:")
print(weights_df.head(10))

# Persist the ranking next to the other extraction artefacts.
norms_csv_path = weights_output_dir / "gene_weight_norms.csv"
weights_df.to_csv(norms_csv_path, index=False)
print(f"\nGene weight norms saved to: {norms_csv_path}")


Training multitask model on full dataset...
Hidden layers: [2048, 1536, 1024]
Head layers: [] (simple linear)
Output directory: ..\results\notebook_multitask_test\LUAD-BRCA\head_weights_extraction

🧬 Training multitask model on full dataset...
   Genes: 129
   Hidden layers: [2048, 1536, 1024]
   Head layers: Simple linear
   Training on full dataset (no validation split)...
   Extracting head weights...
   Model and weights saved to: ..\results\notebook_multitask_test\LUAD-BRCA\head_weights_extraction
✅ Training complete!
   Head weight matrix shape: torch.Size([129, 1024])

Head Weights Extraction Complete!
Weight matrix shape: torch.Size([129, 1024])
Number of genes: 129

First 10 genes:
  - ZNF831
  - OBSCN
  - ABCA13
  - ASTN1
  - SVEP1
  - ASXL3
  - SPTA1
  - FAT3
  - FAT1
  - FAT4

Top 10 genes by weight norm:
        gene  weight_norm
95     KEAP1     1.429592
52     STK11     1.385130
106   PIK3CA     1.366320
93      CDH1     1.357066
114    GATA3     1.353912
65     NRX

In [None]:
# Run gene ablation analysis: one model is trained per removed gene and the
# resulting per-metric matrices are returned in `metric_matrices`.
from interpretation.ablation import run_gene_ablation_analysis
import pandas as pd

# Prepare data (using the same preprocessing as before)
X_log = np.log1p(expression_data)
scaler = StandardScaler()
X_scaled = scaler.fit_transform(X_log)

# Configure model for ablation (should be multitask_nn)
config_ablation = copy.deepcopy(config)
config_ablation['model']['name'] = 'multitask_nn'
# Use the same architecture as the head weights extraction for consistency
config_ablation['model']['multitask_nn'].update({
    'hidden_layers': [2048, 1536, 1024],
    'head_layers': [],  # Simple linear output
    'dropout_rate': 0.2,
    'learning_rate': 0.0005,
    'epochs': 60,
    'batch_size': 128,
})

# Single source of truth for the evaluation mode, so the banner below can
# never disagree with what is actually passed to the analysis.
eval_mode = 'train_test'  # Use 'kfold' for cross-validation

# Create output directory
ablation_output_dir = results_root / "gene_ablation_analysis"
ablation_output_dir.mkdir(parents=True, exist_ok=True)

print("="*60)
print("Running Gene Ablation Analysis")
print("="*60)
print(f"Evaluation mode: {eval_mode}")
print(f"Number of genes: {mutation_data.shape[1]}")
print(f"This will train {mutation_data.shape[1]} models (one per removed gene)")
print(f"Output directory: {ablation_output_dir}")
print("="*60)

# Run ablation analysis
# NOTE(review): `results_root` is passed as output_dir although the folder
# created and printed above is `results_root / "gene_ablation_analysis"`.
# Presumably run_gene_ablation_analysis appends that sub-folder itself;
# confirm against its implementation before changing either side.
factory = ModelFactory()
metric_matrices = run_gene_ablation_analysis(
    model_factory=factory,
    X=X_scaled,
    Y=mutation_data,
    config=config_ablation,
    output_dir=results_root,
    eval_mode=eval_mode,
    test_size=config_ablation['preprocessing']['test_size'],
    random_state=config_ablation['preprocessing']['random_state'],
)

print("\n" + "="*60)
print("Ablation Analysis Complete!")
print("="*60)
print(f"Generated {len(metric_matrices)} metric matrices:")
for metric_name in metric_matrices.keys():
    print(f"  - {metric_name}")

# Display a sample matrix (F1 score)
if 'f1' in metric_matrices:
    # The multiplication sign was stored as mis-decoded bytes; restored here.
    print("\nF1 Score Ablation Matrix (removed_gene × evaluated_gene):")
    print("Rows = gene removed, Columns = gene evaluated")
    print(metric_matrices['f1'].head(10))


In [None]:
# Visualize and explore ablation results
import matplotlib.pyplot as plt
import seaborn as sns


def _plot_ablation_heatmap(matrix, *, cmap, center, cbar_label, title, out_path):
    """Draw one ablation matrix as a heatmap, save it as a PNG and show it.

    Args:
        matrix: DataFrame loaded from an ablation CSV; rows are labelled
            "Removed Gene" and columns "Evaluated Gene" on the figure.
        cmap: Colormap name handed to seaborn.
        center: Value at which seaborn centers the colormap.
        cbar_label: Text shown next to the colorbar.
        title: Figure title.
        out_path: Destination path of the saved PNG.
    """
    fig, ax = plt.subplots(figsize=(16, 14))
    sns.heatmap(
        matrix,
        cmap=cmap,
        center=center,
        annot=False,  # Set to True to show values (may be cluttered with many genes)
        fmt='.3f',
        cbar_kws={'label': cbar_label},
        square=False,
        linewidths=0.5,
        linecolor='gray',
        ax=ax,
    )
    ax.set_title(title, fontsize=14, pad=20)
    ax.set_xlabel('Evaluated Gene', fontsize=12)
    ax.set_ylabel('Removed Gene', fontsize=12)
    fig.tight_layout()
    fig.savefig(out_path, dpi=150, bbox_inches='tight')
    plt.show()
    plt.close(fig)  # large figures should not accumulate in kernel memory


# Load ablation matrices from saved files
ablation_dir = results_root / "gene_ablation_analysis"
print(f"Loading ablation matrices from: {ablation_dir}")

# Load all metric matrices; sorted so the listing order is stable across runs.
metric_matrices = {}
metric_files = sorted(ablation_dir.glob("ablation_matrix_*.csv"))
for metric_file in metric_files:
    metric_name = metric_file.stem.replace("ablation_matrix_", "")
    metric_matrices[metric_name] = pd.read_csv(metric_file, index_col=0)

print(f"\nLoaded {len(metric_matrices)} metric matrices:")
for metric_name in metric_matrices.keys():
    print(f"  - {metric_name}")

# Visualize F1 score matrix as heatmap
if 'f1' in metric_matrices:
    f1_matrix = metric_matrices['f1']

    # Center the colormap on the mean of the non-NaN entries (0 if there are none).
    f1_values = f1_matrix.values
    f1_observed = f1_values[~np.isnan(f1_values)]
    f1_center = f1_observed.mean() if f1_observed.size else 0

    _plot_ablation_heatmap(
        f1_matrix,
        cmap='viridis',
        center=f1_center,
        cbar_label='F1 Score',
        title='Gene Ablation Analysis - F1 Score Matrix\n(Removed Gene × Evaluated Gene)',
        out_path=ablation_dir / "f1_ablation_heatmap.png",
    )

    # Analyze which gene removals have the biggest impact
    # Calculate mean F1 for each row (when that gene was removed)
    row_means = f1_matrix.mean(axis=1, skipna=True).sort_values(ascending=False)

    print("\nTop 10 genes whose removal has the LEAST impact (highest mean F1 when removed):")
    print(row_means.head(10))

    print("\nTop 10 genes whose removal has the MOST impact (lowest mean F1 when removed):")
    print(row_means.tail(10))

    # Calculate mean F1 for each column (how well that gene is predicted when others are removed)
    col_means = f1_matrix.mean(axis=0, skipna=True).sort_values(ascending=False)

    print("\nTop 10 genes that are predicted BEST when others are removed:")
    print(col_means.head(10))

    print("\nTop 10 genes that are predicted WORST when others are removed:")
    print(col_means.tail(10))

# Visualize ROC-AUC matrix
if 'roc_auc' in metric_matrices:
    roc_matrix = metric_matrices['roc_auc']

    _plot_ablation_heatmap(
        roc_matrix,
        cmap='plasma',
        center=0.5,
        cbar_label='ROC-AUC',
        title='Gene Ablation Analysis - ROC-AUC Matrix\n(Removed Gene × Evaluated Gene)',
        out_path=ablation_dir / "roc_auc_ablation_heatmap.png",
    )

print(f"\nAll visualizations saved to: {ablation_dir}")


In [None]:
import itertools
import json

# Grid search over multitask NN architectures and optimiser settings.
#
# NOTE(review): `X_small`, `Y_small` and `selected_features` are not defined in
# any visible cell of this notebook, so this cell fails with NameError on a
# fresh kernel. Presumably they come from a feature-selection step that was
# removed — restore that step (or substitute the intended inputs) before
# relying on Restart & Run All.

# Define search grid for multitask NN
hidden_layer_sets = [
    [2048, 1536, 1024, 512],
    [2048, 1024, 512],
    [1536, 1024, 512],
    [1024, 768, 512],
    [1024, 512, 256],
]
head_layer_sets = [
    [512, 256],
    [256, 128],
    [512],
    [256],
    [128],  # the comma here was missing, which made the whole cell a SyntaxError
    [],
]
dropout_rates = [0.2, 0.3]
learning_rates = [5e-4, 1e-4]

# Full Cartesian product of the four axes; no combination is filtered out.
search_params = [
    {
        'hidden_layers': hidden_layers,
        'head_layers': head_layers,
        'dropout_rate': dropout_rate,
        'learning_rate': lr,
    }
    for hidden_layers, head_layers, dropout_rate, lr in itertools.product(
        hidden_layer_sets, head_layer_sets, dropout_rates, learning_rates
    )
]

grid_root = results_root / "multitask_grid"
grid_root.mkdir(parents=True, exist_ok=True)

grid_rows = []

for run_idx, params in enumerate(search_params, start=1):
    # Private config copy per run so overrides never accumulate across runs.
    cfg = copy.deepcopy(config)
    cfg['model']['name'] = 'multitask_nn'
    cfg['model']['multitask_nn'].update(params)

    factory = ModelFactory()
    model = factory.get_model(
        model_name='multitask_nn',
        input_size=X_small.shape[1],
        output_size=Y_small.shape[1],
        config=cfg,
    )

    # Each run gets its own numbered output folder under the grid root.
    run_dir = grid_root / f"run_{run_idx:02d}"
    metrics = run_train_test_split(
        model=model,
        X=X_small,
        Y=Y_small,
        test_size=cfg['preprocessing']['test_size'],
        output_dir=run_dir,
        config_meta={
            'selected_features': selected_features,
            'config': cfg,
            'cohorts': selected_cohorts,
            'search_params': params,
        },
        random_state=cfg['preprocessing']['random_state'],
        label=f"multitask_nn_run_{run_idx:02d}",
    )

    # One summary row per run: the hyper-parameters plus the mean of each metric.
    mean_metrics = metrics.mean()
    grid_rows.append({
        'run': run_idx,
        **params,
        **mean_metrics.to_dict(),
    })

# Rank all runs by mean F1; the last expression is the cell output.
grid_df = pd.DataFrame(grid_rows).set_index('run')
grid_df.sort_values('f1', ascending=False)
