In [None]:
# Model Training for Protein Function Classifier

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
import sys
import os

# Make the parent directory importable so the `src` package resolves when the
# notebook is run from its own folder. The membership check keeps this cell
# idempotent: re-running it no longer appends a duplicate entry each time.
if '..' not in sys.path:
    sys.path.append('..')

# Force reload of our modules (important after updates!)
# NOTE(review): if src.model imports names from src.features, src.features
# should be reloaded first so src.model rebinds to the fresh definitions --
# confirm the dependency direction before reordering.
import importlib
import src.model
import src.features
importlib.reload(src.model)
importlib.reload(src.features)

from src.model import (
    load_data, prepare_data, get_models, 
    evaluate_model, train_all_models, compare_models,
    print_classification_report, get_feature_importance,
    save_model, EC_NAMES
)
from src.features import get_feature_names

from sklearn.metrics import confusion_matrix, ConfusionMatrixDisplay

# Matplotlib 3.6 renamed its bundled seaborn styles to 'seaborn-v0_8-*'.
# Prefer the new name, but fall back to the legacy one so the notebook still
# runs on older Matplotlib releases instead of failing in the first cell.
plot_style = 'seaborn-v0_8-whitegrid'
if plot_style not in plt.style.available:
    plot_style = 'seaborn-whitegrid'
plt.style.use(plot_style)
print("Imports complete!")

In [None]:
# Read the precomputed feature matrix and label vector from disk
X, y = load_data(
    features_path='../data/processed/X_features.npy',
    labels_path='../data/processed/y_labels.npy',
)

# Build the train/test split (includes label encoding for XGBoost).
# prepare_data hands everything back in a single dict, unpacked below.
data = prepare_data(X, y, test_size=0.2, random_state=42, scale=True)

X_train, X_test = data['X_train'], data['X_test']
y_train, y_test = data['y_train'], data['y_test']
scaler, label_encoder = data['scaler'], data['label_encoder']

# Report how many training samples fall into each class, translating the
# encoded labels back to their original EC numbers in one call.
print(f"\nClass distribution in training set:")
encoded_classes, class_counts = np.unique(y_train, return_counts=True)
ec_numbers = label_encoder.inverse_transform(encoded_classes)
for ec_number, n_samples in zip(ec_numbers, class_counts):
    print(f"  EC {ec_number} ({EC_NAMES[ec_number]}): {n_samples}")

In [None]:
# Fit every candidate model and collect one result record per model
results = train_all_models(X_train, X_test, y_train, y_test)

# Print a banner, then the side-by-side comparison table.
# compare_models is called after the banner so that anything it prints
# itself still appears below the heading.
print("\n" + "=" * 60, "MODEL COMPARISON", "=" * 60, sep="\n")
comparison_df = compare_models(results)
print(comparison_df.to_string(index=False))

In [None]:
# Visualize model comparison: train vs. test accuracy (left panel) and
# macro F1 score (right panel), one group of bars per trained model.

# savefig does not create missing directories, so make sure the target exists
os.makedirs('../figures', exist_ok=True)

fig, axes = plt.subplots(1, 2, figsize=(14, 5))

# Extract metrics for plotting
model_names = [r['model_name'] for r in results]
test_acc = [r['test_accuracy'] for r in results]
f1_scores = [r['f1_macro'] for r in results]
train_acc = [r['train_accuracy'] for r in results]

# Shared numeric bar positions for both panels
x = np.arange(len(model_names))
width = 0.35

# Leave headroom above 1.0: value labels are drawn 0.01 above each bar, so a
# bar at (or near) 1.0 would otherwise push its label outside the axes.
y_max = 1.08

# Accuracy comparison
ax1 = axes[0]
bars1 = ax1.bar(x - width/2, train_acc, width, label='Train', color='skyblue')
bars2 = ax1.bar(x + width/2, test_acc, width, label='Test', color='coral')
ax1.set_ylabel('Accuracy')
ax1.set_title('Model Accuracy Comparison')
ax1.set_xticks(x)
ax1.set_xticklabels(model_names, rotation=15, ha='right')
ax1.legend()
ax1.set_ylim(0, y_max)

# Add value labels above every accuracy bar
for bar in (*bars1, *bars2):
    ax1.text(bar.get_x() + bar.get_width()/2, bar.get_height() + 0.01,
             f'{bar.get_height():.3f}', ha='center', va='bottom', fontsize=9)

# F1 Score comparison
ax2 = axes[1]
colors = ['#2ecc71', '#3498db', '#e74c3c', '#9b59b6']
# Plot against numeric positions (not the name strings) so duplicate model
# names cannot collapse into one category, and fix the tick positions before
# assigning tick labels, as set_xticklabels requires.
bars = ax2.bar(x, f1_scores, color=colors)
ax2.set_ylabel('F1 Score (Macro)')
ax2.set_title('Model F1 Score Comparison')
ax2.set_xticks(x)
ax2.set_xticklabels(model_names, rotation=15, ha='right')
ax2.set_ylim(0, y_max)
for bar in bars:
    ax2.text(bar.get_x() + bar.get_width()/2, bar.get_height() + 0.01,
             f'{bar.get_height():.3f}', ha='center', va='bottom', fontsize=10)

plt.tight_layout()
plt.savefig('../figures/model_comparison.png', dpi=150, bbox_inches='tight')
plt.show()

In [None]:
# Pick the result record with the highest test accuracy and keep handles to
# the fitted estimator and its display name for the cells below.
best_result = max(results, key=lambda result: result['test_accuracy'])
best_model, best_name = best_result['model'], best_result['model_name']

# Headline metrics, printed as a left-aligned label column plus 4-decimal value
print(f"BEST MODEL: {best_name}")
for metric_label, metric_key in (
    ('Test Accuracy:', 'test_accuracy'),
    ('F1 Macro:', 'f1_macro'),
    ('Precision:', 'precision_macro'),
    ('Recall:', 'recall_macro'),
):
    print(f"{metric_label:<16}{best_result[metric_key]:.4f}")

# Detailed classification report
print_classification_report(y_test, best_result['y_pred'], label_encoder)

In [None]:
# Plot confusion matrix for best model (raw counts and row-normalized)

# savefig does not create missing directories, so make sure the target exists
os.makedirs('../figures', exist_ok=True)

best_pred = best_result['y_pred']

# Use the union of true and predicted classes as the explicit label set.
# confusion_matrix infers exactly this set when `labels` is omitted, so
# deriving the display labels from y_test alone would mismatch the matrix
# whenever the model predicts a class that is absent from y_test.
class_ids = np.unique(np.concatenate([np.ravel(y_test), np.ravel(best_pred)]))

# Get original EC class labels for display
original_labels = label_encoder.inverse_transform(class_ids)
display_labels = [f'EC {i}' for i in original_labels]

fig, axes = plt.subplots(1, 2, figsize=(16, 6))

# Raw counts
cm = confusion_matrix(y_test, best_pred, labels=class_ids)
disp1 = ConfusionMatrixDisplay(cm, display_labels=display_labels)
disp1.plot(ax=axes[0], cmap='Blues', values_format='d')
axes[0].set_title(f'Confusion Matrix - {best_name}\n(Raw Counts)', fontsize=12)

# Normalized (each row divided by the number of true samples of that class)
cm_norm = confusion_matrix(y_test, best_pred, labels=class_ids, normalize='true')
disp2 = ConfusionMatrixDisplay(cm_norm, display_labels=display_labels)
disp2.plot(ax=axes[1], cmap='Blues', values_format='.2f')
axes[1].set_title(f'Confusion Matrix - {best_name}\n(Normalized)', fontsize=12)

plt.tight_layout()
plt.savefig('../figures/confusion_matrix.png', dpi=150, bbox_inches='tight')
plt.show()

# Per-class accuracy: the diagonal of the row-normalized matrix
print("\nPer-Class Accuracy:")
print("-" * 40)
for label, acc in zip(original_labels, cm_norm.diagonal()):
    print(f"  EC {label} ({EC_NAMES[label]}): {acc:.2%}")
    

In [None]:
# Get feature importance (works for tree-based models)
feature_names = get_feature_names()

# Number of features requested; kept in one place so the call, the plot title
# and the printed heading cannot drift apart.
top_n = 25

if hasattr(best_model, 'feature_importances_'):
    importance_df = get_feature_importance(best_model, feature_names, top_n=top_n)

    # Report how many rows actually came back rather than assuming top_n:
    # fewer rows are possible when the model has fewer features.
    n_shown = len(importance_df)
    positions = range(n_shown)

    # savefig does not create missing directories, so make sure the target exists
    os.makedirs('../figures', exist_ok=True)

    # Plot
    plt.figure(figsize=(12, 8))
    plt.barh(positions, importance_df['importance'].values, color='steelblue')
    plt.yticks(positions, importance_df['feature'].values)
    plt.xlabel('Feature Importance')
    plt.ylabel('Feature')
    plt.title(f'Top {n_shown} Most Important Features - {best_name}')
    # Put the first row of importance_df at the top of the chart
    plt.gca().invert_yaxis()
    plt.tight_layout()
    plt.savefig('../figures/feature_importance.png', dpi=150, bbox_inches='tight')
    plt.show()

    # Print top features
    print(f"\nTop {n_shown} Most Important Features:")
    print("-" * 50)
    for feature, importance in zip(importance_df['feature'], importance_df['importance']):
        print(f"  {feature:30s} {importance:.4f}")
else:
    print(f"{best_name} does not have feature_importances_ attribute")

In [None]:
# Make sure the output directory exists before anything is written to it
os.makedirs('../models', exist_ok=True)

# Save the best model, scaler, and label encoder
save_model(
    best_model, 
    scaler,
    label_encoder,
    model_path='../models/best_model.pkl',
    scaler_path='../models/scaler.pkl',
    encoder_path='../models/label_encoder.pkl'
)

# Also save the comparison results
comparison_df.to_csv('../models/model_comparison.csv', index=False)
print(" Saved model comparison to models/model_comparison.csv")

# Build the list of artifacts for the summary. The feature-importance figure
# is only written when the best model exposes feature_importances_, so it is
# only listed in that case instead of being reported unconditionally.
saved_files = [
    'models/best_model.pkl',
    'models/scaler.pkl',
    'models/label_encoder.pkl',
    'models/model_comparison.csv',
    'figures/model_comparison.png',
    'figures/confusion_matrix.png',
]
if hasattr(best_model, 'feature_importances_'):
    saved_files.append('figures/feature_importance.png')
saved_files_text = "\n".join(f"- {path}" for path in saved_files)

print(f"""
Summary:
- Best Model: {best_name}
- Test Accuracy: {best_result['test_accuracy']:.2%}
- F1 Score (Macro): {best_result['f1_macro']:.2%}

Saved files:
{saved_files_text}

Next step: Build the Streamlit app!
""")