# 03 - Model Training and Evaluation

## Overview
In this notebook, we'll train and compare multiple ML models for cancer classification:

1. **Baseline**: Logistic Regression
2. **Tree-based**: Random Forest, XGBoost
3. **SVM**: Support Vector Machine

We'll also:
- Use cross-validation for robust evaluation
- Analyze feature importance (which genes matter most)
- Create publication-quality visualizations
- Interpret results biologically

## 1. Setup

In [None]:
import pickle
import warnings
from pathlib import Path

import joblib
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
warnings.filterwarnings('ignore')

# ML imports
from sklearn.linear_model import LogisticRegression
from sklearn.ensemble import RandomForestClassifier
from sklearn.svm import SVC
from xgboost import XGBClassifier

# Evaluation
from sklearn.model_selection import cross_val_score, StratifiedKFold
from sklearn.metrics import (
    accuracy_score, classification_report, confusion_matrix,
    roc_auc_score, roc_curve, auc, precision_recall_fscore_support
)
from sklearn.preprocessing import label_binarize

# For SHAP (model interpretation). SHAP is optional: the SHAP cells further
# down check SHAP_AVAILABLE and skip themselves when it is missing.
try:
    import shap
    SHAP_AVAILABLE = True
except ImportError:
    SHAP_AVAILABLE = False
    print("SHAP not installed. Install with: pip install shap")

plt.style.use('seaborn-v0_8-whitegrid')
RANDOM_STATE = 42

# Output locations written by the savefig / joblib.dump / to_csv calls below.
# Create them up front so Restart & Run All works on a fresh checkout instead
# of failing with FileNotFoundError at the first save.
for _out_dir in ('../results/figures', '../models'):
    Path(_out_dir).mkdir(parents=True, exist_ok=True)

print("Libraries loaded!")

In [None]:
# Load the preprocessed train/test splits and the artifacts saved alongside them.
# NOTE: pickle.load can execute arbitrary code -- only load these .pkl files if
# they were produced by your own preprocessing step.
X_train, X_test, y_train, y_test = (
    np.load(f'../data/processed/{split}.npy')
    for split in ('X_train', 'X_test', 'y_train', 'y_test')
)

with open('../data/processed/selected_genes.pkl', 'rb') as f:
    selected_genes = pickle.load(f)
with open('../data/processed/label_encoder.pkl', 'rb') as f:
    label_encoder = pickle.load(f)

# Class names in encoded-label order (index k <-> encoded label k).
class_names = label_encoder.classes_
n_classes = len(class_names)

for label, shape in (("Training data", X_train.shape), ("Test data", X_test.shape)):
    print(f"{label}: {shape}")
print(f"Classes: {class_names}")

## 2. Define Models

We'll compare several commonly used classifiers:

In [None]:
# Define models with reasonable hyperparameters.
# NOTE(review): LogisticRegression and the RBF SVC are scale-sensitive; this
# assumes the preprocessing step already standardized X -- TODO confirm.
models = {
    # For multiclass targets the lbfgs solver already fits a multinomial
    # (softmax) model by default, so the deprecated multi_class='multinomial'
    # argument is dropped without changing the fitted model.
    'Logistic Regression': LogisticRegression(
        max_iter=1000,
        random_state=RANDOM_STATE,
        solver='lbfgs',
        C=1.0
    ),
    'Random Forest': RandomForestClassifier(
        n_estimators=200,
        max_depth=20,
        min_samples_split=5,
        random_state=RANDOM_STATE,
        n_jobs=-1
    ),
    'SVM (RBF)': SVC(
        kernel='rbf',
        C=10,
        gamma='scale',
        probability=True,  # required for predict_proba (ROC-AUC below)
        random_state=RANDOM_STATE
    ),
    # use_label_encoder is deprecated/ignored in current XGBoost; the targets are
    # already integer-encoded (0..n_classes-1) by label_encoder, so it is not needed.
    'XGBoost': XGBClassifier(
        n_estimators=200,
        max_depth=6,
        learning_rate=0.1,
        random_state=RANDOM_STATE,
        eval_metric='mlogloss'
    )
}

print("Models defined:")
for name in models:
    print(f"  - {name}")

## 3. Cross-Validation Comparison

First, let's compare all models using 5-fold cross-validation on training data:

In [None]:
# Stratified 5-fold CV on the training split only; the test set stays untouched
# until the final evaluation. shuffle + fixed seed keeps the folds reproducible.
cv = StratifiedKFold(n_splits=5, shuffle=True, random_state=RANDOM_STATE)

print("Running 5-fold cross-validation...\n")
print("-" * 50)

# model name -> array of per-fold accuracies
cv_results = {}
for model_name, estimator in models.items():
    print(f"Training {model_name}...", end=" ")
    fold_scores = cross_val_score(
        estimator, X_train, y_train,
        cv=cv, scoring='accuracy', n_jobs=-1,
    )
    cv_results[model_name] = fold_scores
    mean_acc, spread = fold_scores.mean(), fold_scores.std() * 2
    print(f"Accuracy: {mean_acc:.4f} (+/- {spread:.4f})")

print("-" * 50)

In [None]:
# Visualize CV results: a box per model plus the individual fold scores.
fig, ax = plt.subplots(figsize=(10, 6))

cv_df = pd.DataFrame(cv_results)
cv_df_melted = cv_df.melt(var_name='Model', value_name='Accuracy')

sns.boxplot(data=cv_df_melted, x='Model', y='Accuracy', palette='husl', ax=ax)
sns.stripplot(data=cv_df_melted, x='Model', y='Accuracy', color='black', alpha=0.5, ax=ax)

ax.set_ylabel('Accuracy', fontsize=12)
ax.set_xlabel('')
ax.set_title(f'{cv.get_n_splits()}-Fold Cross-Validation Results', fontsize=14)
# Zoom in on the high-accuracy range, but never clip a fold: the lower limit
# drops below 0.9 whenever some fold scored lower than that.
y_floor = min(0.9, cv_df_melted['Accuracy'].min() - 0.02)
ax.set_ylim([y_floor, 1.01])

# Annotate each box with its mean. x positions follow cv_df's column order
# (insertion order of cv_results), which is also the category order seaborn uses.
for i, name in enumerate(cv_df.columns):
    mean_acc = cv_results[name].mean()
    ax.text(i, mean_acc + 0.008, f'{mean_acc:.3f}', ha='center', fontsize=10)

plt.tight_layout()
plt.savefig('../results/figures/cv_comparison.png', dpi=150, bbox_inches='tight')
plt.show()

## 4. Train Final Models and Evaluate on Test Set

Now let's train on all training data and evaluate on the held-out test set:

In [None]:
# Train all models on the full training set, then evaluate each one once on
# the held-out test set.
trained_models = {}
test_results = {}

# The binarized test labels (one indicator column per encoded class) do not
# depend on the model, so build them once instead of on every loop iteration.
y_test_bin = label_binarize(y_test, classes=range(n_classes))

print("Training on full training set and evaluating on test set...\n")
print("-" * 60)

for name, model in models.items():
    print(f"Training {name}...", end=" ")

    # Fit on training data (this fits the estimator object held in `models`)
    model.fit(X_train, y_train)
    trained_models[name] = model

    # Predict on test set
    y_pred = model.predict(X_test)
    y_prob = model.predict_proba(X_test)

    # Calculate metrics
    accuracy = accuracy_score(y_test, y_pred)

    # Multi-class ROC-AUC: per-class one-vs-rest AUC, weighted by class support
    roc_auc = roc_auc_score(y_test_bin, y_prob, average='weighted', multi_class='ovr')

    test_results[name] = {
        'accuracy': accuracy,
        'roc_auc': roc_auc,
        'y_pred': y_pred,
        'y_prob': y_prob
    }

    print(f"Test Accuracy: {accuracy:.4f}, ROC-AUC: {roc_auc:.4f}")

print("-" * 60)

In [None]:
# One row per model: CV mean/std (training folds) next to held-out test metrics.
summary_columns = [
    'Model', 'CV Accuracy (mean)', 'CV Accuracy (std)',
    'Test Accuracy', 'Test ROC-AUC',
]
summary_rows = [
    {
        'Model': model_name,
        'CV Accuracy (mean)': cv_results[model_name].mean(),
        'CV Accuracy (std)': cv_results[model_name].std(),
        'Test Accuracy': metrics['accuracy'],
        'Test ROC-AUC': metrics['roc_auc'],
    }
    for model_name, metrics in test_results.items()
]
results_df = pd.DataFrame(summary_rows, columns=summary_columns).round(4)

print("\nMODEL COMPARISON SUMMARY")
print("=" * 70)
print(results_df.to_string(index=False))

## 5. Detailed Analysis of Best Model

Let's analyze the best-performing model in detail:

In [None]:
# Select the best model by mean cross-validation accuracy on the training data.
# The test set is reserved for the final estimate: choosing the winner by test
# accuracy and then reporting that same test accuracy biases it upward.
best_model_name = max(cv_results, key=lambda name: cv_results[name].mean())
best_model = trained_models[best_model_name]
best_results = test_results[best_model_name]

print(f"Best Model: {best_model_name}")
print(f"CV Accuracy (selection criterion): {cv_results[best_model_name].mean():.4f}")
print(f"Test Accuracy: {best_results['accuracy']:.4f}")
print(f"Test ROC-AUC: {best_results['roc_auc']:.4f}")

In [None]:
# Per-class precision / recall / F1 of the selected model on the test set.
# labels= pins the report to every encoded class so target_names stays aligned
# even if a class is missing from y_test or from the predictions.
print("\nCLASSIFICATION REPORT")
print("=" * 60)
print(classification_report(
    y_test, best_results['y_pred'],
    labels=np.arange(n_classes), target_names=class_names, zero_division=0
))

In [None]:
# Confusion Matrix of the selected model (rows = true class, cols = predicted).
# labels= guarantees an n_classes x n_classes matrix that matches class_names,
# even when a class never occurs in y_test or in the predictions.
cm = confusion_matrix(y_test, best_results['y_pred'], labels=np.arange(n_classes))

fig, axes = plt.subplots(1, 2, figsize=(14, 5))

# Absolute numbers
ax1 = axes[0]
sns.heatmap(cm, annot=True, fmt='d', cmap='Blues',
            xticklabels=class_names, yticklabels=class_names, ax=ax1)
ax1.set_ylabel('True Label', fontsize=12)
ax1.set_xlabel('Predicted Label', fontsize=12)
ax1.set_title(f'Confusion Matrix ({best_model_name})', fontsize=14)

# Row-normalized percentages (each row sums to 100%; diagonal = per-class recall).
# Rows with no test samples would be 0/0 = NaN; show them as 0 instead.
ax2 = axes[1]
row_totals = cm.sum(axis=1, keepdims=True)
cm_norm = np.divide(cm * 100.0, row_totals,
                    out=np.zeros(cm.shape, dtype=float), where=row_totals > 0)
sns.heatmap(cm_norm, annot=True, fmt='.1f', cmap='Blues',
            xticklabels=class_names, yticklabels=class_names, ax=ax2)
ax2.set_ylabel('True Label', fontsize=12)
ax2.set_xlabel('Predicted Label', fontsize=12)
ax2.set_title('Confusion Matrix (Normalized %)', fontsize=14)

plt.tight_layout()
plt.savefig('../results/figures/confusion_matrix.png', dpi=150, bbox_inches='tight')
plt.show()

## 6. ROC Curves (Multi-class)

In [None]:
# One-vs-rest ROC curve per cancer type for the selected model.
y_test_bin = label_binarize(y_test, classes=range(n_classes))
y_prob = best_results['y_prob']
colors = sns.color_palette('husl', n_classes)

fig, ax = plt.subplots(figsize=(10, 8))

for class_idx, class_name in enumerate(class_names):
    # Column class_idx of y_test_bin / y_prob both refer to encoded class class_idx.
    fpr, tpr, _ = roc_curve(y_test_bin[:, class_idx], y_prob[:, class_idx])
    class_auc = auc(fpr, tpr)
    ax.plot(fpr, tpr, color=colors[class_idx], lw=2,
            label=f'{class_name} (AUC = {class_auc:.3f})')

# Chance-level reference line
ax.plot([0, 1], [0, 1], 'k--', lw=2, label='Random Classifier')
ax.set(xlim=(0.0, 1.0), ylim=(0.0, 1.05))
ax.set_xlabel('False Positive Rate', fontsize=12)
ax.set_ylabel('True Positive Rate', fontsize=12)
ax.set_title(f'ROC Curves - {best_model_name}', fontsize=14)
ax.legend(loc='lower right')

plt.tight_layout()
plt.savefig('../results/figures/roc_curves.png', dpi=150, bbox_inches='tight')
plt.show()

## 7. Feature Importance Analysis

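Which genes are most important for classification? This is key for biological interpretation. The importances below are the Random Forest's built-in (impurity-based) scores, so they describe that model specifically — not necessarily the best model selected above.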

In [None]:
# Impurity-based importances come built in with a fitted Random Forest;
# one value per input column, in the same order as selected_genes.
rf_model = trained_models['Random Forest']
importances = rf_model.feature_importances_

# Rows keep their original labels after sorting (label = feature position).
gene_table = pd.DataFrame({'gene': selected_genes, 'importance': importances})
importance_df = gene_table.sort_values(by='importance', ascending=False)

banner = "TOP 30 MOST IMPORTANT GENES (Random Forest)"
print(banner)
print("=" * 50)
print(importance_df.head(30).to_string(index=False))

In [None]:
# Visualize top features as horizontal bars, most important at the top.
# Clamp to the number of available genes so bar count and tick count always match.
top_n = min(25, len(importance_df))
top_features = importance_df.head(top_n)

fig, ax = plt.subplots(figsize=(10, 10))

# barh draws position 0 at the bottom, so reverse to put the top gene first.
ax.barh(range(top_n), top_features['importance'].values[::-1], color='steelblue')
ax.set_yticks(range(top_n))
ax.set_yticklabels(top_features['gene'].values[::-1])
ax.set_xlabel('Feature Importance', fontsize=12)
ax.set_title(f'Top {top_n} Most Important Genes for Cancer Classification', fontsize=14)

plt.tight_layout()
plt.savefig('../results/figures/feature_importance.png', dpi=150, bbox_inches='tight')
plt.show()

In [None]:
# Expression patterns of the top-10 genes across cancer types (training set).
top_rows = importance_df.head(10)
top_genes = top_rows['gene'].values
# importance_df was built from a fresh DataFrame in feature order and
# sort_values keeps row labels, so each row label is the gene's column position
# in X. This avoids selected_genes.index(), which only exists on Python lists.
top_gene_idx = top_rows.index.to_numpy()

# Get expression data for top genes
X_top = X_train[:, top_gene_idx]

# Mean expression per cancer type (encoded label k -> class_names[k])
mean_expr = pd.DataFrame(X_top, columns=top_genes)
mean_expr['cancer'] = np.asarray(class_names)[y_train]
mean_expr_grouped = mean_expr.groupby('cancer').mean()

# Heatmap (genes as rows). center=0 presumes standardized expression values --
# TODO confirm the preprocessing step scaled X.
fig, ax = plt.subplots(figsize=(12, 6))
sns.heatmap(mean_expr_grouped.T, cmap='RdBu_r', center=0, annot=True, fmt='.2f', ax=ax)
ax.set_title('Mean Expression of Top Predictive Genes by Cancer Type', fontsize=14)
ax.set_xlabel('Cancer Type', fontsize=12)
ax.set_ylabel('Gene', fontsize=12)
fig.tight_layout()
fig.savefig('../results/figures/top_genes_heatmap.png', dpi=150, bbox_inches='tight')
plt.show()

## 8. SHAP Analysis (Model Interpretability)

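SHAP values quantify how much each gene pushed an individual prediction toward or away from each class. To keep computation fast, they are computed for the Random Forest on the first 50 test samples: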

In [None]:
if not SHAP_AVAILABLE:
    print("Skipping SHAP analysis (shap not installed)")
else:
    print("Computing SHAP values (this may take a minute)...")

    # Explain only the first 50 test samples to keep the computation fast.
    X_explain = X_test[:50]

    # TreeExplainer works directly on the fitted Random Forest.
    explainer = shap.TreeExplainer(rf_model)
    shap_values = explainer.shap_values(X_explain)

    print("SHAP values computed!")

In [None]:
if SHAP_AVAILABLE:
    # Depending on the installed shap version, TreeExplainer.shap_values returns
    # either a list with one (n_samples, n_features) array per class or a single
    # (n_samples, n_features, n_classes) array. Normalize to a per-class list so
    # that shap_per_class[k] always refers to class k (indexing a 3-D array with
    # [k] would select sample k instead).
    if isinstance(shap_values, list):
        shap_per_class = shap_values
    else:
        shap_arr = np.asarray(shap_values)
        shap_per_class = ([shap_arr[..., k] for k in range(shap_arr.shape[-1])]
                          if shap_arr.ndim == 3 else [shap_arr])

    # Summary plot for each class; squeeze=False keeps axes 2-D even for 1 column.
    fig, axes = plt.subplots(1, n_classes, figsize=(20, 6), squeeze=False)

    for ax, class_name, class_shap in zip(axes[0], class_names, shap_per_class):
        plt.sca(ax)  # summary_plot draws on the current pyplot axes
        # plot_size=None stops summary_plot from resizing the shared 20x6 figure.
        shap.summary_plot(class_shap, X_explain,
                          feature_names=selected_genes,
                          max_display=10, plot_size=None, show=False)
        ax.set_title(f'{class_name}', fontsize=12)

    plt.tight_layout()
    plt.savefig('../results/figures/shap_summary.png', dpi=150, bbox_inches='tight')
    plt.show()

## 9. Save Best Model

In [None]:
def _model_slug(name):
    """Turn a display name like 'SVM (RBF)' into a filename stem like 'svm_rbf'."""
    return name.lower().replace(" ", "_").replace("(", "").replace(")", "")


# Make sure the output directory exists (joblib.dump does not create it).
Path('../models').mkdir(parents=True, exist_ok=True)

# Save the best model. It uses the same slug as the per-model files below, so
# the name never contains parentheses.
model_path = f'../models/{_model_slug(best_model_name)}_best.joblib'
joblib.dump(best_model, model_path)
print(f"Best model saved to: {model_path}")

# Save all trained models
for name, model in trained_models.items():
    path = f'../models/{_model_slug(name)}.joblib'
    joblib.dump(model, path)

print("\nAll models saved to models/ directory")

## 10. Final Summary

In [None]:
# Final run summary: dataset, selected model, per-class scores, top genes.
# Counts are derived from the loaded data so they stay correct if it changes.
n_samples = len(X_train) + len(X_test)
# labels= keeps precision[i] aligned with class_names[i] even for absent classes.
precision, recall, f1, _ = precision_recall_fscore_support(
    y_test, best_results['y_pred'], labels=np.arange(n_classes), zero_division=0
)

print("="*70)
print("PROJECT SUMMARY: TCGA Pan-Cancer Classification")
print("="*70)
print()
print("DATASET:")
print(f"  - {n_samples} tumor samples from {n_classes} cancer types")
# The pre-selection gene count is not available in this notebook; hard-coded.
print("  - Original features: 20,531 genes")
print(f"  - Selected features: {len(selected_genes):,} genes (variance + F-test selection)")
print()
print("BEST MODEL:", best_model_name)
print(f"  - Test Accuracy: {best_results['accuracy']:.2%}")
print(f"  - Test ROC-AUC: {best_results['roc_auc']:.4f}")
print()
print("PER-CLASS PERFORMANCE:")
for i, class_name in enumerate(class_names):
    print(f"  {class_name}: Precision={precision[i]:.2f}, Recall={recall[i]:.2f}, F1={f1[i]:.2f}")
print()
print("TOP 5 PREDICTIVE GENES:")
# Enumerate for the rank: importance_df's row labels are original feature
# positions (kept through sort_values), not ranks.
for rank, row in enumerate(importance_df.head(5).itertuples(index=False), start=1):
    print(f"  {rank}. {row.gene} (importance: {row.importance:.4f})")
print()
print("="*70)
print("KEY INSIGHTS:")
print("  - Gene expression patterns can accurately classify cancer types")
print("  - Ensemble methods (RF, XGBoost) perform well on this data")
print("  - Top predictive genes are biologically meaningful markers")
print("  - This demonstrates the power of ML in cancer genomics")
print("="*70)

In [None]:
# Save results to CSV for the README.
# Create the directory first: to_csv does not create missing parent folders.
results_dir = Path('../results')
results_dir.mkdir(parents=True, exist_ok=True)
results_df.to_csv(results_dir / 'model_comparison.csv', index=False)
importance_df.to_csv(results_dir / 'feature_importance.csv', index=False)

print("Results saved to results/ directory")