In [None]:
# %% [markdown]
# # ALS Diagnosis with Machine Learning and SHAP
# 
# This notebook provides a quick-start guide for running the ALS diagnosis pipeline.
# 
# ## Overview
# 1. Data Download and Processing
# 2. Feature Selection (MMPC + Ridge + SFFS)
# 3. Model Training and SHAP Analysis
# 4. Results Interpretation

# %% [markdown]
# ## Setup

# %%
import sys
import os
from pathlib import Path
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
import warnings
# Silence every warning category for the remainder of the session.
warnings.filterwarnings('ignore')

# Make modules under ./src importable. The entry is relative, so it is
# resolved against the current working directory.
sys.path.append('src')

# Directories the notebook creates up front; mkdir is a no-op for any that
# already exist.
for required_dir in (
    'data/raw',
    'data/processed',
    'data/results/feature_selection',
    'data/results/shap_analysis',
    'logs',
):
    Path(required_dir).mkdir(parents=True, exist_ok=True)

print("Setup completed!")

# %% [markdown]
# ## Step 1: Data Download and Processing

# %%
from data_processing.geo_downloader import GEODownloader

# Initialize downloader
downloader = GEODownloader(base_dir="data")

# Download and process datasets (this may take a few minutes)
print("Downloading and processing GEO datasets...")
try:
    gse_ids = ["GSE112676", "GSE112680"]
    expression_data, metadata = downloader.process_datasets(gse_ids)
    downloader.save_processed_data(expression_data, metadata)

    print("✓ Data processing completed!")
    print(f"Expression data shape: {expression_data.shape}")
    print(f"Sample distribution: {metadata['group'].value_counts().to_dict()}")

except Exception as e:
    print(f"❌ Error in data processing: {str(e)}")
    print("Please check your internet connection and try again")

    # The next step reads these two files from disk. If a previous run left
    # them behind, the notebook can still proceed; otherwise stop here so the
    # failure is not masked by an unrelated FileNotFoundError further down.
    # NOTE(review): presumably these are the files save_processed_data
    # writes — confirm against GEODownloader.
    processed_files = (
        Path("data/processed/combined_expression_data.csv"),
        Path("data/processed/sample_metadata.csv"),
    )
    if not all(path.is_file() for path in processed_files):
        raise
    print("⚠ Continuing with previously processed files found in data/processed/")

# %% [markdown]
# ## Step 2: Data Exploration

# %%
# Reload the processed tables from disk. The expression table uses its first
# column as the row index; the metadata table keeps a default index.
expression_data = pd.read_csv(
    "data/processed/combined_expression_data.csv",
    index_col=0,
)
metadata = pd.read_csv("data/processed/sample_metadata.csv")

n_genes, n_samples = expression_data.shape
print("Dataset Information:")
print(f"Genes: {n_genes}")
print(f"Samples: {n_samples}")
print(f"Sample groups: {metadata['group'].value_counts().to_dict()}")

# Bar chart of the number of samples in each group
fig, ax = plt.subplots(figsize=(10, 6))
metadata['group'].value_counts().plot(kind='bar', ax=ax, rot=0)
ax.set_title('Sample Distribution')
ax.set_ylabel('Number of Samples')
fig.tight_layout()
plt.show()

# %% [markdown]
# ## Step 3: Feature Selection Pipeline

# %%
from feature_selection_pipeline import FeatureSelectionPipeline
import logging

# Setup logging
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')

# Prepare data for feature selection
X = expression_data.T  # Transpose: samples as rows, genes as columns
labels = metadata['group'].map({'ALS': 1, 'Control': 0})

# `metadata` was read without an index column, so `labels` carries a default
# integer index while `X` is indexed by the expression table's column headers.
# A boolean Series derived from `labels` cannot be aligned with `X` (pandas
# raises IndexingError), so the filter below is applied by position instead.
# NOTE(review): assumes metadata rows are in the same order as the expression
# columns — confirm against GEODownloader.save_processed_data.
if len(labels) != len(X):
    raise ValueError(
        f"Sample count mismatch: {len(X)} expression samples vs "
        f"{len(labels)} metadata rows"
    )

# Remove samples with missing labels
valid_samples = labels.notna().to_numpy()
X = X.loc[valid_samples]
# Index the labels by sample so X and y stay aligned, and store them as
# integers (unmapped groups become NaN, which forces a float dtype).
y = pd.Series(
    labels.to_numpy()[valid_samples].astype(int),
    index=X.index,
    name='group',
)

print(f"Feature selection input: {X.shape[0]} samples, {X.shape[1]} genes")
print(f"Class distribution: {y.value_counts().to_dict()}")

# %%
# Fit the feature selection pipeline on the prepared samples and labels
print("Running feature selection pipeline...")
print("This may take 15-30 minutes depending on your hardware...")

pipeline = FeatureSelectionPipeline()
pipeline.fit(X, y)

# Collect the winning configuration from the fitted pipeline
best_features = pipeline.get_selected_features()
best_algorithm = pipeline.get_best_algorithm()
# Element 3 of best_config_ is what gets reported as the CV score below.
best_score = pipeline.best_config_[3]

print("\n🎯 FEATURE SELECTION RESULTS:")
print(f"Best algorithm: {best_algorithm}")
print(f"Selected features: {len(best_features)}")
print(f"Best CV score: {best_score:.4f}")

# Numbered listing of the selected genes
print("\nSelected genes:")
for rank, gene in enumerate(best_features, start=1):
    print(f"  {rank:2d}. {gene}")

# Persist the results; the SHAP step reads them back from this directory
pipeline.save_results("data/results/feature_selection")

# %% [markdown]
# ## Step 4: SHAP Analysis

# %%
from shap_interpretability import SHAPAnalyzer

# Load feature selection results
# NOTE: unpickling can execute arbitrary code stored in the file, so this must
# only ever be pointed at results written by this pipeline itself.
import pickle
with open("data/results/feature_selection/pipeline_results.pkl", 'rb') as f:
    pipeline_results = pickle.load(f)

best_config = pipeline_results['best_config']
selected_features = best_config['features']
best_algorithm = best_config['algorithm']

print(f"Loading SHAP analysis for {best_algorithm} with {len(selected_features)} features")

# Prepare data: samples as rows, restricted to the selected genes
X_selected = expression_data.T[selected_features]
labels = metadata['group'].map({'ALS': 1, 'Control': 0})

# `metadata` was read without an index column, so `labels` carries a default
# integer index while `X_selected` is indexed by the expression table's column
# headers. A boolean Series derived from `labels` cannot be aligned with it
# (pandas raises IndexingError), so the filter is applied by position instead.
# NOTE(review): assumes metadata rows are in the same order as the expression
# columns — confirm against GEODownloader.save_processed_data.
if len(labels) != len(X_selected):
    raise ValueError(
        f"Sample count mismatch: {len(X_selected)} expression samples vs "
        f"{len(labels)} metadata rows"
    )

# Remove missing labels
valid_samples = labels.notna().to_numpy()
X_selected = X_selected.loc[valid_samples]
# Index the labels by sample so features and labels stay aligned, and store
# them as integers (unmapped groups become NaN, which forces a float dtype).
y = pd.Series(
    labels.to_numpy()[valid_samples].astype(int),
    index=X_selected.index,
    name='group',
)

# %%
# Build the analyzer for the algorithm chosen during feature selection
print("Running SHAP analysis...")

analyzer = SHAPAnalyzer(model_type=best_algorithm)

# Fit on the selected genes and report the held-out metrics returned by fit()
# NOTE(review): test_size=0.1 is presumably a 10% hold-out — confirm in
# SHAPAnalyzer.fit.
training_results = analyzer.fit(X_selected, y, test_size=0.1)
print("Model Performance:")
test_accuracy = training_results['test_accuracy']
print(f"  Test Accuracy: {test_accuracy:.4f}")
test_auc = training_results['test_auc']
print(f"  Test AUC: {test_auc:.4f}")

# Build the explainer, then compute SHAP values with the sample caps below
print("Creating SHAP explainer...")
analyzer.create_shap_explainer(background_samples=50)

print("Calculating SHAP values...")
analyzer.calculate_shap_values(max_samples=100)

print("✓ SHAP analysis completed!")

# %%
# Directory handed to every SHAP plotting/report call from here on
output_dir = "data/results/shap_analysis"

print("Generating SHAP visualizations...")

# Feature importance plot; its return value is kept as importance_df
importance_df = analyzer.plot_feature_importance(
    output_dir,
    show_plot=True,
)

# Summary plot
analyzer.plot_summary(
    output_dir,
    show_plot=True,
)

print("✓ Visualizations generated!")

# %% [markdown]
# ## Step 5: Individual Sample Analysis

# %%
# Per-sample explanations
print("Analyzing individual samples...")

# Waterfall plots for sample indices 0 and 1
for sample_position in (0, 1):
    analyzer.plot_waterfall(
        sample_idx=sample_position,
        output_dir=output_dir,
        show_plot=True,
    )

# %% [markdown]
# ## Step 6: Gene Interaction Analysis

# %%
# Compute the gene-gene matrix returned by the analyzer
print("Analyzing gene interactions...")
correlation_matrix = analyzer.analyze_gene_interactions(output_dir)

print("Top gene pairs with strongest interactions:")
# Keep only the strict upper triangle (k=1): the diagonal is dropped and each
# pair appears once.
mask = np.triu(np.ones(correlation_matrix.shape, dtype=bool), k=1)
upper_triangle = correlation_matrix.where(mask)
# stack() discards the masked-out NaN cells; rank the rest by absolute value
correlations = upper_triangle.stack().sort_values(key=abs, ascending=False)
print(correlations.head(10))

# %% [markdown]
# ## Step 7: Comprehensive Report

# %%
# Build the interpretation report from the fitted analyzer
print("Generating comprehensive interpretation report...")
report = analyzer.generate_interpretation_report(output_dir)

print("\n🎯 FINAL RESULTS SUMMARY:")
print(f"Model: {best_algorithm}")
print(f"Test Accuracy: {training_results['test_accuracy']:.4f}")
print(f"Test AUC: {training_results['test_auc']:.4f}")
print(f"Number of features: {len(selected_features)}")

# First ten entries of the report's gene ranking
print("\nTop 10 Most Important Genes:")
for rank, entry in enumerate(report['top_genes'][:10], start=1):
    gene_name = entry['gene']
    shap_importance = entry['mean_abs_shap']
    print(f"  {rank:2d}. {gene_name:15s} - Importance: {shap_importance:.4f}")

top_gene = report['most_important_gene']
print(f"\nMost Important Gene: {top_gene['name']}")
print(f"Importance Score: {top_gene['importance']:.4f}")

# %% [markdown]
# ## Step 8: Results Summary

# %%
# Closing banner and checklist of completed stages
banner = "="*60
print(banner)
print("ALS DIAGNOSIS PIPELINE - COMPLETED")
print(banner)
for completed_stage in (
    "✓ Data downloaded and processed",
    "✓ Feature selection completed",
    "✓ Model training finished",
    "✓ SHAP analysis completed",
    "✓ Comprehensive report generated",
):
    print(completed_stage)

# Headline numbers gathered from the earlier steps
print("\n📊 KEY METRICS:")
print(f"Selected Genes: {len(selected_features)}")
print(f"Best Algorithm: {best_algorithm}")
print(f"Cross-validation Score: {best_score:.4f}")
print(f"Test Accuracy: {training_results['test_accuracy']:.4f}")
print(f"Test AUC: {training_results['test_auc']:.4f}")

# Where the artifacts were written
print("\n📁 Results Location:")
for location_line in (
    "  - Feature Selection: data/results/feature_selection/",
    "  - SHAP Analysis: data/results/shap_analysis/",
    "  - Visualizations: PNG files in results directories",
):
    print(location_line)

# First five entries of the report's gene ranking
print("\n🧬 Top 5 Genes for ALS Diagnosis:")
for rank, entry in enumerate(report['top_genes'][:5], start=1):
    print(f"  {rank}. {entry['gene']} (importance: {entry['mean_abs_shap']:.4f})")

# %% [markdown]
# ## Next Steps
# 
# 1. **Biological Validation**: Research the biological roles of top-ranked genes
# 2. **External Validation**: Test on independent datasets
# 3. **Clinical Integration**: Develop clinical decision support tools
# 4. **Publication**: Document findings for scientific publication
# 
# ## Files Generated
# 
# - `data/results/feature_selection/`: Complete feature selection results
# - `data/results/shap_analysis/`: SHAP interpretability analysis
# - Various CSV files with detailed gene rankings and statistics
# - PNG visualization files for presentations and publications

# %%