In [None]:
"""
Evaluation and Stacking Ensemble for Tau Protein Misfolding Prediction

This notebook:
1. Loads all base model predictions
2. Trains meta-learner (stacking)
3. Evaluates ensemble performance
4. Compares all models
5. Generates visualizations and final results
"""

import sys
sys.path.append('..')

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
from pathlib import Path
import json

# Import our models and utilities
from models import (
    LogisticMetaLearner,
    XGBoostMetaLearner,
    MLPMetaLearner,
)

from utils import (
    compute_classification_metrics,
    plot_roc_curve,
    plot_multiple_roc_curves,
    plot_confusion_matrix,
    plot_accuracy_bar,
    create_metrics_comparison_table,
    export_predictions,
    save_metrics_json,
    EMBEDDINGS_DIR,
    PREDICTIONS_DIR,
    SAVED_MODELS_DIR,
    METRICS_DIR,
)

# Plot defaults shared by every figure in this notebook: seaborn's white-grid
# theme, a 12x6 default canvas and 10pt text.
sns.set_style("whitegrid")
plt.rcParams.update({
    'figure.figsize': (12, 6),
    'font.size': 10,
})

# Confirm the imports worked and show where relative paths resolve from.
print("‚úÖ Imports successful!", f"Working directory: {Path.cwd()}", sep="\n")


In [None]:
"""
Load predictions from all base models

Reads the label vector of each split (train / val / test) from EMBEDDINGS_DIR
and, for base models A-D, the saved probability arrays from PREDICTIONS_DIR.
Every probability array is checked against the label vector of its split so
that a stale or truncated prediction file fails here instead of producing
misaligned metrics further down.
"""


def _load_model_probs(model_tag, description, expected_rows):
    """Load the train/val/test probability arrays of one base model.

    Args:
        model_tag: Lower-case model letter used in the file names
            (``model_<tag>_<split>_probs.npy``).
        description: Human-readable model description for the log line.
        expected_rows: Mapping from split name ('train', 'val', 'test') to the
            number of labelled samples in that split.

    Returns:
        Tuple ``(train_probs, val_probs, test_probs)`` of NumPy arrays.

    Raises:
        ValueError: If an array's first dimension differs from the number of
            labels in the corresponding split.
    """
    display_name = model_tag.upper()
    print(f"\nüì¶ Loading Model {display_name} ({description}) predictions...")

    loaded = []
    # Fixed order so the returned tuple is always (train, val, test).
    for split in ('train', 'val', 'test'):
        file_name = f'model_{model_tag}_{split}_probs.npy'
        probs = np.load(PREDICTIONS_DIR / file_name)
        if probs.shape[0] != expected_rows[split]:
            raise ValueError(
                f"{file_name} has {probs.shape[0]} rows but the {split} split "
                f"has {expected_rows[split]} labels"
            )
        loaded.append(probs)

    print(f"‚úÖ Model {display_name}: {loaded[0].shape}")
    return tuple(loaded)


print("=" * 80)
print("STEP 1: LOADING BASE MODEL PREDICTIONS")
print("=" * 80)

# Load labels
print("\nüì¶ Loading labels...")
y_train = np.load(EMBEDDINGS_DIR / 'labels_train.npy')
y_val = np.load(EMBEDDINGS_DIR / 'labels_val.npy')
y_test = np.load(EMBEDDINGS_DIR / 'labels_test.npy')

print(f"‚úÖ Labels loaded:")
print(f"  Train: {len(y_train)}")
print(f"  Val:   {len(y_val)}")
print(f"  Test:  {len(y_test)}")

# Row counts every prediction file must match, keyed by split name.
rows_per_split = {'train': len(y_train), 'val': len(y_val), 'test': len(y_test)}

# The per-model variable names are kept because later cells reference them.
train_prob_a, val_prob_a, test_prob_a = _load_model_probs('a', 'ProtBERT+SVM', rows_per_split)
train_prob_b, val_prob_b, test_prob_b = _load_model_probs('b', 'Fine-tuned', rows_per_split)
train_prob_c, val_prob_c, test_prob_c = _load_model_probs('c', 'CNN-BiLSTM', rows_per_split)
train_prob_d, val_prob_d, test_prob_d = _load_model_probs('d', 'Transformer', rows_per_split)

print("\n‚úÖ All base model predictions loaded!")


"""
Build meta-features for stacking ensemble

The probability arrays of base models A-D are concatenated column-wise, in
that fixed order, giving one meta-feature matrix per split. The training
matrix is then plotted as one histogram per column, split by true label.
"""

print("=" * 80)
print("STEP 2: BUILDING META-FEATURES")
print("=" * 80)

# Stack all probabilities to create meta-features
print("\nüî® Building meta-features from base model predictions...")

# Column order is A, B, C, D for every split; feature_names below relies on it.
X_meta_train = np.hstack([train_prob_a, train_prob_b, train_prob_c, train_prob_d])
X_meta_val = np.hstack([val_prob_a, val_prob_b, val_prob_c, val_prob_d])
X_meta_test = np.hstack([test_prob_a, test_prob_b, test_prob_c, test_prob_d])

print(f"‚úÖ Meta-features created:")
print(f"  Train: {X_meta_train.shape}")
print(f"  Val:   {X_meta_val.shape}")
print(f"  Test:  {X_meta_test.shape}")

# One name per meta-feature column, in hstack order. Also used by the
# meta-learner cells when reporting feature importance.
feature_names = [
    'Model A - Class 0', 'Model A - Class 1',
    'Model B - Class 0', 'Model B - Class 1',
    'Model C - Class 0', 'Model C - Class 1',
    'Model D - Class 0', 'Model D - Class 1'
]

# Fail loudly if the stacked matrices do not have one column per name;
# otherwise zip() below would silently drop columns and the importance
# reports in later cells would be mislabelled.
for split_label, meta_matrix in (
    ('train', X_meta_train),
    ('val', X_meta_val),
    ('test', X_meta_test),
):
    if meta_matrix.ndim != 2 or meta_matrix.shape[1] != len(feature_names):
        raise ValueError(
            f"{split_label} meta-features have shape {meta_matrix.shape}; "
            f"expected 2-D with {len(feature_names)} columns"
        )

# Report the column counts actually found rather than assumed values.
print(f"\nFeature breakdown:")
for model_letter, model_probs in (
    ('A', train_prob_a),
    ('B', train_prob_b),
    ('C', train_prob_c),
    ('D', train_prob_d),
):
    print(f"  Model {model_letter}: {model_probs.shape[1]} probabilities")
print(f"  Total:   {X_meta_train.shape[1]} meta-features")

# Visualize meta-features
print("\nüìä Visualizing meta-feature distributions...")

fig, axes = plt.subplots(2, 4, figsize=(16, 8))
axes = axes.flatten()

for idx, (ax, feature_name) in enumerate(zip(axes, feature_names)):
    # Overlay one histogram per true label so class separation is visible.
    for label in [0, 1]:
        mask = y_train == label
        ax.hist(X_meta_train[mask, idx], bins=30, alpha=0.6,
                label=f'Label {label}', edgecolor='black')

    ax.set_xlabel('Probability')
    ax.set_ylabel('Frequency')
    ax.set_title(feature_name, fontsize=9)
    ax.legend(fontsize=8)
    ax.grid(alpha=0.3)

plt.tight_layout()
plt.show()


In [None]:
"""
Fit the logistic-regression meta-learner on the stacked meta-features,
score it on all three splits, report its coefficients and persist it.
"""

print("=" * 80, "STEP 3A: TRAINING META-LEARNER (LOGISTIC REGRESSION)", "=" * 80, sep="\n")

# L2-regularised logistic regression with the hyper-parameters below.
print("\nüîß Initializing Logistic Regression meta-learner...")
logistic_params = dict(C=1.0, penalty='l2', max_iter=1000)
meta_logistic = LogisticMetaLearner(**logistic_params)

# NOTE(review): the meta-learner is fitted on the base models' *train*
# probabilities. If those are in-sample rather than out-of-fold predictions,
# the stack may overweight overfit base models. Confirm how the
# model_*_train_probs.npy files were produced.
print("\nüöÄ Training meta-learner...")
metrics_meta_log = meta_logistic.fit(X_meta=X_meta_train, y=y_train, X_val=X_meta_val, y_val=y_val)

# Hard predictions and probabilities per split; the tuple keeps the
# predict -> predict_proba call order for each split.
print("\nüîÆ Generating ensemble predictions (Logistic)...")
ensemble_train_pred_log, ensemble_train_prob_log = (
    meta_logistic.predict(X_meta_train),
    meta_logistic.predict_proba(X_meta_train),
)
ensemble_val_pred_log, ensemble_val_prob_log = (
    meta_logistic.predict(X_meta_val),
    meta_logistic.predict_proba(X_meta_val),
)
ensemble_test_pred_log, ensemble_test_prob_log = (
    meta_logistic.predict(X_meta_test),
    meta_logistic.predict_proba(X_meta_test),
)

# Score each split. The train figure is measured on the rows used for fitting.
print("\nüìä Evaluating ensemble (Logistic)...")
train_metrics_ens_log = compute_classification_metrics(
    y_train, ensemble_train_pred_log, ensemble_train_prob_log
)
val_metrics_ens_log = compute_classification_metrics(
    y_val, ensemble_val_pred_log, ensemble_val_prob_log
)
test_metrics_ens_log = compute_classification_metrics(
    y_test, ensemble_test_pred_log, ensemble_test_prob_log
)

print(
    f"\n‚úÖ Ensemble (Logistic) Results:",
    f"  Train accuracy: {train_metrics_ens_log['accuracy']:.4f}",
    f"  Val accuracy:   {val_metrics_ens_log['accuracy']:.4f}",
    f"  Test accuracy:  {test_metrics_ens_log['accuracy']:.4f}",
    f"  Test ROC-AUC:   {test_metrics_ens_log['roc_auc']:.4f}",
    sep="\n",
)

# One importance value per meta-feature, paired with its column name.
print("\nüîç Feature importance (meta-learner coefficients):")
feature_importance_log = meta_logistic.get_feature_importance()
for name, importance in zip(feature_names, feature_importance_log):
    print(f"  {name}: {importance:.4f}")

# Persist the fitted meta-learner.
logistic_model_path = SAVED_MODELS_DIR / 'meta_learner_logistic.pkl'
meta_logistic.save(logistic_model_path)
print(f"\nüíæ Meta-learner saved")


In [None]:
"""
Fit the XGBoost meta-learner on the stacked meta-features, score it on all
three splits, rank and plot its feature importances, and persist it.
"""

print("=" * 80, "STEP 3B: TRAINING META-LEARNER (XGBOOST)", "=" * 80, sep="\n")

# Hyper-parameters passed to the project's XGBoost meta-learner wrapper.
print("\nüîß Initializing XGBoost meta-learner...")
xgb_params = dict(
    n_estimators=100,
    max_depth=3,
    learning_rate=0.1,
    subsample=0.8,
    colsample_bytree=0.8,
    use_gpu=False,
)
meta_xgb = XGBoostMetaLearner(**xgb_params)

# The validation split is handed to fit() together with early_stopping_rounds.
print("\nüöÄ Training XGBoost meta-learner...")
metrics_meta_xgb = meta_xgb.fit(
    X_meta=X_meta_train, y=y_train,
    X_val=X_meta_val, y_val=y_val,
    early_stopping_rounds=10, verbose=False,
)

# Hard predictions and probabilities per split; the tuple keeps the
# predict -> predict_proba call order for each split.
print("\nüîÆ Generating ensemble predictions (XGBoost)...")
ensemble_train_pred_xgb, ensemble_train_prob_xgb = (
    meta_xgb.predict(X_meta_train),
    meta_xgb.predict_proba(X_meta_train),
)
ensemble_val_pred_xgb, ensemble_val_prob_xgb = (
    meta_xgb.predict(X_meta_val),
    meta_xgb.predict_proba(X_meta_val),
)
ensemble_test_pred_xgb, ensemble_test_prob_xgb = (
    meta_xgb.predict(X_meta_test),
    meta_xgb.predict_proba(X_meta_test),
)

# Score each split with the shared metrics helper.
print("\nüìä Evaluating ensemble (XGBoost)...")
train_metrics_ens_xgb = compute_classification_metrics(
    y_train, ensemble_train_pred_xgb, ensemble_train_prob_xgb
)
val_metrics_ens_xgb = compute_classification_metrics(
    y_val, ensemble_val_pred_xgb, ensemble_val_prob_xgb
)
test_metrics_ens_xgb = compute_classification_metrics(
    y_test, ensemble_test_pred_xgb, ensemble_test_prob_xgb
)

print(
    f"\n‚úÖ Ensemble (XGBoost) Results:",
    f"  Train accuracy: {train_metrics_ens_xgb['accuracy']:.4f}",
    f"  Val accuracy:   {val_metrics_ens_xgb['accuracy']:.4f}",
    f"  Test accuracy:  {test_metrics_ens_xgb['accuracy']:.4f}",
    f"  Test ROC-AUC:   {test_metrics_ens_xgb['roc_auc']:.4f}",
    sep="\n",
)

# Rank the meta-features from most to least important.
print("\nüîç Feature importance (XGBoost):")
feature_importance_xgb = meta_xgb.get_feature_importance()
importance_unsorted = pd.DataFrame({
    'Feature': feature_names,
    'Importance': feature_importance_xgb,
})
importance_df = importance_unsorted.sort_values(by='Importance', ascending=False)

print(importance_df.to_string(index=False))

# Horizontal bars with the y-axis flipped so the top feature is at the top.
importance_fig, importance_ax = plt.subplots(figsize=(10, 6))
importance_ax.barh(importance_df['Feature'], importance_df['Importance'], color='steelblue')
importance_ax.set_xlabel('Importance')
importance_ax.set_title('XGBoost Meta-Learner Feature Importance')
importance_ax.invert_yaxis()
plt.tight_layout()
plt.show()

# Persist the fitted meta-learner.
xgb_model_path = SAVED_MODELS_DIR / 'meta_learner_xgboost'
meta_xgb.save(xgb_model_path)
print(f"\nüíæ XGBoost meta-learner saved")


In [None]:
"""
Score the four base models on the test split, tabulate them next to the two
stacking ensembles, report the most accurate row and save the table as CSV.
"""

print("=" * 80, "STEP 4: COMPREHENSIVE MODEL COMPARISON", "=" * 80, sep="\n")

print("\nüìä Computing metrics for base models...")

# Hard labels for each base model: class 1 when its probability exceeds 0.5.
pred_a, pred_b, pred_c, pred_d = (
    (probs[:, 1] > 0.5).astype(int)
    for probs in (test_prob_a, test_prob_b, test_prob_c, test_prob_d)
)

# Test-set metrics per base model, computed in A-D order.
metrics_a = compute_classification_metrics(y_test, pred_a, test_prob_a)
metrics_b = compute_classification_metrics(y_test, pred_b, test_prob_b)
metrics_c = compute_classification_metrics(y_test, pred_c, test_prob_c)
metrics_d = compute_classification_metrics(y_test, pred_d, test_prob_d)

# Table rows, top to bottom: four base models, then the two ensembles.
model_labels = [
    'Model A\n(ProtBERT+SVM)',
    'Model B\n(Fine-tuned)',
    'Model C\n(CNN-BiLSTM)',
    'Model D\n(Transformer)',
    'Ensemble\n(Logistic)',
    'Ensemble\n(XGBoost)',
]
metrics_per_row = [
    metrics_a,
    metrics_b,
    metrics_c,
    metrics_d,
    test_metrics_ens_log,
    test_metrics_ens_xgb,
]

# Table column header -> key in the metrics dictionaries.
column_to_metric_key = {
    'Accuracy': 'accuracy',
    'Precision': 'precision',
    'Recall': 'recall',
    'F1-Score': 'f1_score',
    'ROC-AUC': 'roc_auc',
}

comparison_data = {'Model': model_labels}
for column, metric_key in column_to_metric_key.items():
    comparison_data[column] = [row_metrics[metric_key] for row_metrics in metrics_per_row]

df_comparison = pd.DataFrame(comparison_data)

print("\nüìã Test Set Performance Comparison:")
print(df_comparison.to_string(index=False))

# Row with the highest test accuracy (first one wins on ties).
best_idx = df_comparison['Accuracy'].idxmax()
best_model = df_comparison.at[best_idx, 'Model']
best_acc = df_comparison.at[best_idx, 'Accuracy']

print(f"\nüèÜ Best Model: {best_model}")
print(f"   Accuracy: {best_acc:.4f}")

# Persist the table next to the other metric artefacts.
comparison_csv_path = METRICS_DIR / 'all_models_comparison.csv'
df_comparison.to_csv(comparison_csv_path, index=False)
print(f"\nüíæ Comparison saved to: {comparison_csv_path}")


In [None]:
"""
Draw a 2x2 dashboard of the comparison table: accuracy, precision vs recall,
F1-score and ROC-AUC, then save it as a PNG.
"""

print("=" * 80, "STEP 5: VISUALIZATIONS", "=" * 80, sep="\n")

print("\nüìä Creating accuracy comparison chart...")

fig, axes = plt.subplots(2, 2, figsize=(16, 12))

# Base models share one colour; each ensemble gets its own.
bar_colors = ['steelblue'] * 4 + ['orange', 'red']

# The three single-metric panels differ only in axes, column and title.
single_metric_panels = (
    (axes[0, 0], 'Accuracy', 'Test Accuracy Comparison'),
    (axes[1, 0], 'F1-Score', 'F1-Score Comparison'),
    (axes[1, 1], 'ROC-AUC', 'ROC-AUC Comparison'),
)
for panel, metric_column, panel_title in single_metric_panels:
    panel.barh(df_comparison['Model'], df_comparison[metric_column], color=bar_colors)
    panel.set_xlabel(metric_column)
    panel.set_title(panel_title, fontweight='bold')
    panel.set_xlim([0, 1])
    panel.grid(alpha=0.3, axis='x')

# Precision & Recall: side-by-side bars per model in the top-right panel.
pr_panel = axes[0, 1]
x = np.arange(len(df_comparison))
width = 0.35
pr_panel.bar(x - width/2, df_comparison['Precision'], width, label='Precision', color='skyblue')
pr_panel.bar(x + width/2, df_comparison['Recall'], width, label='Recall', color='lightcoral')
pr_panel.set_ylabel('Score')
pr_panel.set_title('Precision vs Recall', fontweight='bold')
pr_panel.set_xticks(x)
pr_panel.set_xticklabels(df_comparison['Model'], rotation=45, ha='right')
pr_panel.legend()
pr_panel.grid(alpha=0.3, axis='y')
pr_panel.set_ylim([0, 1])

plt.tight_layout()

# Save before show().
comparison_png_path = METRICS_DIR / 'model_comparison.png'
plt.savefig(comparison_png_path, dpi=300, bbox_inches='tight')
plt.show()

print(f"‚úÖ Saved: {comparison_png_path}")


In [None]:
"""
Overlay the test-set ROC curves of all six models and list their AUC scores.
"""

print("\nüìä Creating ROC curves...")

# Legend label -> test-set probability array, in plotting order.
probs_dict = dict([
    ('Model A', test_prob_a),
    ('Model B', test_prob_b),
    ('Model C', test_prob_c),
    ('Model D', test_prob_d),
    ('Ensemble (Log)', ensemble_test_prob_log),
    ('Ensemble (XGB)', ensemble_test_prob_xgb),
])

roc_png_path = METRICS_DIR / 'roc_curves_all_models.png'

# The helper is given the save path and returns the scores iterated below.
auc_scores = plot_multiple_roc_curves(
    y_true=y_test,
    y_probs_dict=probs_dict,
    title='ROC Curves: All Models',
    save_path=roc_png_path,
    show=True,
)

print("\nüìä ROC-AUC Scores:")
for model_label, auc_value in auc_scores.items():
    print(f"  {model_label}: {auc_value:.4f}")

print(f"\n‚úÖ Saved: {roc_png_path}")


In [None]:
"""
Plot test-set confusion matrices for all six models (four base models and
both stacking ensembles) in a 2x3 grid and save the figure as a PNG.
"""

# Imported once for the cell rather than on every loop iteration.
from sklearn.metrics import confusion_matrix

print("\nüìä Creating confusion matrices...")

fig, axes = plt.subplots(2, 3, figsize=(16, 10))
axes = axes.flatten()

# (panel title, hard test-set predictions) in display order.
models_to_plot = [
    ('Model A', pred_a),
    ('Model B', pred_b),
    ('Model C', pred_c),
    ('Model D', pred_d),
    ('Ensemble (Logistic)', ensemble_test_pred_log),
    ('Ensemble (XGBoost)', ensemble_test_pred_xgb)
]

# Tick labels for class 0 and class 1 respectively.
class_names = ['Normal', 'Misfolding']

for ax, (model_name, predictions) in zip(axes, models_to_plot):
    # Pin the label set so the matrix is always 2x2 and lines up with
    # class_names, even if only one class is present.
    # Assumes labels are encoded as 0/1, as the thresholded base-model
    # predictions are.
    cm = confusion_matrix(y_test, predictions, labels=[0, 1])

    sns.heatmap(cm, annot=True, fmt='d', cmap='Blues', ax=ax,
                xticklabels=class_names,
                yticklabels=class_names)
    ax.set_ylabel('True Label')
    ax.set_xlabel('Predicted Label')
    ax.set_title(f'{model_name}', fontweight='bold')

plt.tight_layout()
plt.savefig(METRICS_DIR / 'confusion_matrices.png', dpi=300, bbox_inches='tight')
plt.show()

print(f"‚úÖ Saved: {METRICS_DIR / 'confusion_matrices.png'}")


In [None]:
"""
Export final predictions with protein IDs

Writes the XGBoost ensemble's test-set predictions, keyed by protein ID, to
CSV, and dumps the test metrics of every model to a single JSON file.
"""

print("=" * 80)
print("STEP 6: EXPORTING FINAL PREDICTIONS")
print("=" * 80)

# Load protein IDs
print("\nüì¶ Loading protein IDs...")
protein_ids_test = pd.read_csv(EMBEDDINGS_DIR / 'protein_ids_test.csv')['protein_id'].values

# The IDs are joined to predictions purely by row position, so a length
# mismatch means the ID file and the test split are out of sync.
if len(protein_ids_test) != len(y_test):
    raise ValueError(
        f"protein_ids_test.csv has {len(protein_ids_test)} IDs but the test "
        f"split has {len(y_test)} labels"
    )

# NOTE(review): the XGBoost ensemble is exported unconditionally. Nothing in
# this cell checks that it is the best performer, so consult the comparison
# table if the logistic ensemble scored higher.
print("\nüíæ Exporting ensemble predictions...")
df_predictions = export_predictions(
    protein_ids=protein_ids_test,
    y_true=y_test,
    y_pred=ensemble_test_pred_xgb,
    y_prob=ensemble_test_prob_xgb,
    output_path=PREDICTIONS_DIR / 'final_ensemble_predictions.csv'
)

print("\nüìã Sample predictions:")
print(df_predictions.head(10))

print(f"\n‚úÖ Predictions exported: {PREDICTIONS_DIR / 'final_ensemble_predictions.csv'}")

# Save all metrics: test-set metrics for every model, keyed by a short name.
print("\nüíæ Saving metrics to JSON...")
all_metrics = {
    'model_a': metrics_a,
    'model_b': metrics_b,
    'model_c': metrics_c,
    'model_d': metrics_d,
    'ensemble_logistic': test_metrics_ens_log,
    'ensemble_xgboost': test_metrics_ens_xgb
}

save_metrics_json(all_metrics, METRICS_DIR / 'final_metrics.json')
print(f"‚úÖ Metrics saved: {METRICS_DIR / 'final_metrics.json'}")


In [None]:
"""
Final summary and conclusions

Prints the dataset sizes, the most accurate model from the comparison table,
the XGBoost ensemble's relative accuracy gain over the best base model, and
the locations and names of the generated artefacts.
"""

print("=" * 80)
print("‚úÖ EVALUATION COMPLETE!")
print("=" * 80)

print("\nüéØ Project Summary:")
print(f"  Total samples: {len(y_train) + len(y_val) + len(y_test)}")
print(f"  Train/Val/Test: {len(y_train)}/{len(y_val)}/{len(y_test)}")

print("\nüèÜ Best Model Performance:")
print(f"  Model: {best_model}")
print(f"  Test Accuracy: {best_acc:.4f}")
print(f"  Test ROC-AUC: {df_comparison.loc[best_idx, 'ROC-AUC']:.4f}")
print(f"  Test F1-Score: {df_comparison.loc[best_idx, 'F1-Score']:.4f}")

print("\nüìä Improvement from Base Models:")
# Read the accuracies from the per-model metric dictionaries rather than by
# row position in df_comparison, so the figures cannot drift if the table's
# row order changes.
base_models_max_acc = max(
    metrics_a['accuracy'],
    metrics_b['accuracy'],
    metrics_c['accuracy'],
    metrics_d['accuracy'],
)
# The "ensemble" reported here is the XGBoost meta-learner.
ensemble_acc = test_metrics_ens_xgb['accuracy']

# Relative gain in percent; undefined (NaN) if every base model scored 0.
if base_models_max_acc > 0:
    improvement = (ensemble_acc - base_models_max_acc) / base_models_max_acc * 100
else:
    improvement = float('nan')

print(f"  Best base model accuracy: {base_models_max_acc:.4f}")
print(f"  Ensemble accuracy: {ensemble_acc:.4f}")
print(f"  Improvement: {improvement:+.2f}%")

print("\nüíæ Generated Files:")
print(f"  Models: {SAVED_MODELS_DIR}")
print(f"  Predictions: {PREDICTIONS_DIR}")
print(f"  Metrics: {METRICS_DIR}")
print(f"  Visualizations: {METRICS_DIR}")

print("\nüìÅ Key Output Files:")
# File names written by the earlier cells of this notebook.
output_files = [
    'all_models_comparison.csv',
    'final_metrics.json',
    'final_ensemble_predictions.csv',
    'model_comparison.png',
    'roc_curves_all_models.png',
    'confusion_matrices.png'
]

for f in output_files:
    print(f"  ‚úÖ {f}")

print("\nüéâ Tau Protein Misfolding Prediction Project Complete!")
print("=" * 80)
