# Model Training and Evaluation

This notebook demonstrates the complete model training pipeline:
1. Load and preprocess SMS data
2. Split into training and testing sets
3. Train three classifiers:
   - Logistic Regression
   - Multinomial Naïve Bayes
   - Linear SVM
4. Evaluate and compare model performance
5. Display metrics and confusion matrices
6. Save trained models

## 1. Import Required Libraries

In [None]:
import sys
import os

# Make the parent directory importable so that `src.*` resolves when the
# notebook is run from its own folder. The membership check keeps this cell
# idempotent: re-running it no longer stacks duplicate entries on sys.path.
project_root = os.path.abspath('..')
if project_root not in sys.path:
    sys.path.insert(0, project_root)

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
from sklearn.metrics import confusion_matrix, roc_curve, auc
import warnings

# NOTE(review): this silences *every* warning category for the whole kernel,
# including any raised while fitting the models below. Consider narrowing it
# to specific categories once the noisy ones are known.
warnings.filterwarnings('ignore')

# Import preprocessing and model functions
from src.preprocessing import preprocess_pipeline
from src.models import (
    train_models,
    save_all_models,
    compare_models
)

# Record the library versions alongside the results for reproducibility.
print("Libraries imported successfully!")
print(f"NumPy version: {np.__version__}")
print(f"Pandas version: {pd.__version__}")

## 2. Load and Preprocess Data

In [None]:
# Run the preprocessing pipeline on the sample CSV and unpack its outputs.
csv_path = '../data/sample_sms_spam.csv'

print("Running preprocessing pipeline...")
result = preprocess_pipeline(csv_path, method='tfidf', min_df=1)

# The pipeline result is read as a mapping with these four keys.
X, y, feature_names, metadata = (
    result[key] for key in ('X', 'y', 'feature_names', 'metadata')
)

print("\nPreprocessing Complete!")
print(f"Feature matrix shape: {X.shape}")

feature_count = len(feature_names)
print(f"Number of features: {feature_count}")

# Labels are reported as 0 = ham, 1 = spam.
ham_count = sum(y == 0)
spam_count = sum(y == 1)
print(f"Data split: {ham_count} ham, {spam_count} spam")

# Sparsity = share of cells that are not stored as non-zero.
# Assumes X exposes `.nnz` (e.g. a SciPy sparse matrix) — TODO confirm.
total_cells = X.shape[0] * X.shape[1]
sparsity = 1 - X.nnz / total_cells
print(f"Matrix sparsity: {sparsity:.2%}")

## 3. Train All Models

In [None]:
# Train every classifier through train_models and summarise the split sizes.
# train_models receives test_size=0.2 and random_state=42; presumably it does
# the train/test split internally — see src.models to confirm.
print("Starting model training...\n")
training_result = train_models(X, y, test_size=0.2, random_state=42)

models, evaluations, data_info = (
    training_result[key] for key in ('models', 'evaluations', 'data_info')
)

# (label shown to the reader, key prefix used inside data_info)
splits = (("Training", "train"), ("Test", "test"))

print("\nModel Training Complete!")
for label, prefix in splits:
    print(f"{label} samples: {data_info[prefix + '_size']}")
for label, prefix in splits:
    print(f"{label} split:")
    print(f"  - Ham: {data_info[prefix + '_ham']}")
    print(f"  - Spam: {data_info[prefix + '_spam']}")

## 4. Model Performance Comparison

In [None]:
# Compare models: pass the per-model evaluation results produced by the
# training cell to compare_models (imported from src.models).
# The call is the cell's final bare expression, so any non-None return value
# is also rendered by the notebook.
# NOTE(review): the output format is defined in src.models and is not visible
# from this notebook — confirm there.
compare_models(evaluations)

## 5. Detailed Evaluation Metrics

In [None]:
# Display detailed metrics for each model.
# Every entry of `evaluations` is read for the keys 'accuracy', 'precision',
# 'recall', 'f1_score' and 'classification_report'; 'roc_auc' is optional.
for model_name, eval_dict in evaluations.items():
    banner = '=' * 80
    print(f"\n{banner}")
    print(f"{model_name.upper().replace('_', ' ')}")
    print(banner)

    print(f"\nAccuracy:  {eval_dict['accuracy']:.4f}")
    print(f"Precision: {eval_dict['precision']:.4f}")
    print(f"Recall:    {eval_dict['recall']:.4f}")
    print(f"F1-Score:  {eval_dict['f1_score']:.4f}")

    # Compare against None instead of relying on truthiness: an AUC of exactly
    # 0.0 is a valid score and must still be reported. `.get` also lets an
    # evaluation without the key skip this line rather than raise KeyError.
    roc_auc = eval_dict.get('roc_auc')
    if roc_auc is not None:
        print(f"ROC-AUC:   {roc_auc:.4f}")

    print("\nClassification Report:")
    print(eval_dict['classification_report'])

## 6. Confusion Matrices

In [None]:
# Visualize confusion matrices: one annotated heatmap per model, side by side.
# `model_names` (display labels) and `model_keys` (keys into `evaluations`)
# are parallel lists; the bar-chart cell below reuses both.
model_names = ['Logistic Regression', 'Multinomial Naïve Bayes', 'Linear SVM']
model_keys = ['logistic_regression', 'naive_bayes', 'svm']

# One column per model, derived from the list so the figure stays in sync if a
# model is added or removed. squeeze=False always yields a 2-D Axes array, so
# indexing row 0 works even when there is only a single model.
fig, axes = plt.subplots(1, len(model_keys), figsize=(16, 4), squeeze=False)

for ax, model_name, model_key in zip(axes[0], model_names, model_keys):
    conf_matrix = evaluations[model_key]['confusion_matrix']

    # fmt='d' renders the cell counts as integers.
    sns.heatmap(conf_matrix, annot=True, fmt='d', cmap='Blues',
                xticklabels=['Ham', 'Spam'], yticklabels=['Ham', 'Spam'],
                cbar=False, ax=ax)
    ax.set_title(model_name, fontsize=12, fontweight='bold')
    ax.set_ylabel('True Label')
    ax.set_xlabel('Predicted Label')

plt.tight_layout()
plt.show()

# Also dump the raw matrices as text so the numbers survive without the figure.
print("Confusion Matrices:")
print("-" * 80)
for model_name, model_key in zip(model_names, model_keys):
    print(f"\n{model_name}:")
    print(evaluations[model_key]['confusion_matrix'])

## 7. Metrics Comparison Bar Charts

In [None]:
# Build one row per model (display name + four headline metrics), then draw a
# grouped bar chart with one bar group per model.
metric_columns = {
    'Accuracy': 'accuracy',
    'Precision': 'precision',
    'Recall': 'recall',
    'F1-Score': 'f1_score',
}

metrics_df = pd.DataFrame([
    {
        'Model': display_name,
        **{column: evaluations[key][metric] for column, metric in metric_columns.items()},
    }
    for key, display_name in zip(model_keys, model_names)
])

# Grouped bars: four bars per model, centred on the model's tick position.
fig, ax = plt.subplots(figsize=(12, 5))

positions = np.arange(len(metrics_df))
bar_width = 0.2

# Slot offsets are -1.5, -0.5, +0.5 and +1.5 bar widths from the group centre.
for slot, column in enumerate(metric_columns):
    ax.bar(positions + (slot - 1.5) * bar_width, metrics_df[column], bar_width, label=column)

ax.set_ylabel('Score', fontsize=12)
ax.set_title('Model Performance Comparison', fontsize=14, fontweight='bold')
ax.set_xticks(positions)
ax.set_xticklabels(metrics_df['Model'])
ax.legend()
# A little headroom above 1.0 so full-height bars do not touch the frame.
ax.set_ylim(0, 1.05)
ax.grid(axis='y', alpha=0.3)

plt.tight_layout()
plt.show()

print("\nMetrics Dataframe:")
print(metrics_df.to_string(index=False))

## 8. Save Models

In [None]:
# Hand the fitted models to save_all_models, targeting ../models, and list the
# path reported back for each one.
print("Saving trained models...")
saved_paths = save_all_models(models, models_dir='../models')

print("\nModels saved successfully!")
print("\nSaved model paths:")
for name, location in saved_paths.items():
    print(f"  {name}: {location}")

## 9. Best Model Selection

In [None]:
# Report which model scores highest on accuracy and on F1, then close out the
# notebook with suggested follow-ups.
def _top_model(metric):
    """Return ``(model_key, score)`` for the highest value of ``metric``.

    Ties resolve to the first model in ``evaluations`` iteration order.
    """
    winner = max(evaluations, key=lambda key: evaluations[key][metric])
    return winner, evaluations[winner][metric]


print("\nBest Models by Metric:")
print("=" * 80)

best_accuracy = _top_model('accuracy')
best_f1 = _top_model('f1_score')

for heading, (winner, score) in (
    ("Best Accuracy", best_accuracy),
    ("Best F1-Score", best_f1),
):
    # Turn the dict key into a readable title, e.g. 'naive_bayes' -> 'Naive Bayes'.
    print(f"\n{heading}: {winner.replace('_', ' ').title()}")
    print(f"  Score: {score:.4f}")

print("\n" + "=" * 80)
for line in (
    "Model training and evaluation complete!",
    "\nNext steps:",
    "1. Use best model for predictions",
    "2. Fine-tune hyperparameters",
    "3. Deploy to Streamlit application",
    "4. Monitor model performance in production",
):
    print(line)