# Model Training Notebook

This notebook trains multiple classification models (Logistic Regression, Random Forest, XGBoost) for hotel cancellation prediction.

## Objectives:
1. Load processed training and test data
2. Check for class imbalance and apply SMOTE if needed
3. Train all models with cross-validation
4. Evaluate models on test set
5. Compare models and identify best performer
6. Save all trained models
7. Visualize model comparison results

## 1. Setup and Imports

In [None]:
import sys
import os
import yaml
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from pathlib import Path

# Add project root to path
# The root is taken to be the parent of the current working directory, i.e.
# this assumes the notebook is launched from a direct sub-folder of the
# project (TODO confirm). It has to run before the `src.*` imports that
# follow, otherwise those packages cannot be resolved.
project_root = Path.cwd().parent
sys.path.insert(0, str(project_root))

# Import custom modules
from src.data_processing.data_splitter import DataSplitter
from src.modeling.model_trainer import ModelTrainer
from src.modeling.imbalance_handler import ImbalanceHandler
from src.evaluation.model_evaluator import ModelEvaluator
from src.evaluation.model_comparator import ModelComparator
from src.modeling.model_registry import ModelRegistry
from src.utils.logger import get_logger

# Logger for this notebook, obtained through the project's logging helper
logger = get_logger(__name__)

# Seed NumPy's global RNG so repeated runs draw the same random numbers
np.random.seed(42)

# Plot appearance for every figure below; order kept as-is
# (matplotlib style sheet first, seaborn palette second)
plt.style.use('seaborn-v0_8-darkgrid')
sns.set_palette('husl')

print(
    "✓ All imports successful",
    f"✓ Project root: {project_root}",
    sep="\n",
)

## 2. Load Configuration

In [None]:
# Load the YAML configuration that drives every later cell
config_path = project_root / 'config' / 'config.yaml'

# Explicit encoding: without it open() uses the platform's locale encoding,
# which can mis-decode a UTF-8 config file on some systems
with open(config_path, 'r', encoding='utf-8') as f:
    config = yaml.safe_load(f)

print("Configuration loaded successfully")
print("\nEnabled models:")
for model_name, model_config in config['models'].items():
    # A model entry without an 'enabled' key is reported as disabled
    status = '✓' if model_config.get('enabled', False) else '✗'
    print(f"  {status} {model_name}")

evaluation_config = config['evaluation']
print("\nEvaluation settings:")
print(f"  Primary metric: {evaluation_config['primary_metric']}")
print(f"  Imbalance threshold: {evaluation_config['imbalance_threshold']}")

## 3. Load Processed Data

In [None]:
# Build the splitter from the split settings stored in the config
data_config = config['data']
data_splitter = DataSplitter(
    test_size=data_config['test_size'],
    random_state=data_config['random_state'],
)

# Fetch the train/test splits from the configured processed-data location
processed_data_path = project_root / data_config['processed_data_path']
X_train, X_test, y_train, y_test = data_splitter.load_splits(str(processed_data_path))

rule = "=" * 60
print(f"\n{rule}")
print("DATA LOADED SUCCESSFULLY")
print(rule)

# Shapes and sizes of what was loaded
n_train_rows = len(X_train)
n_test_rows = len(X_test)
print(f"Training set: X_train={X_train.shape}, y_train={y_train.shape}")
print(f"Test set: X_test={X_test.shape}, y_test={y_test.shape}")
print(f"Number of features: {X_train.shape[1]}")
print(f"Total samples: {n_train_rows + n_test_rows}")

## 4. Check Class Imbalance

In [None]:
# Build the imbalance handler from the configured threshold and random seed
imbalance_handler = ImbalanceHandler(
    imbalance_threshold=config['evaluation']['imbalance_threshold'],
    random_state=config['data']['random_state']
)

print("\n" + "="*60)
print("CLASS DISTRIBUTION ANALYSIS")
print("="*60)

# Count each class once; percentages are derived from those same counts
# instead of running value_counts a second time
class_counts = pd.Series(y_train).value_counts().sort_index()
class_percentages = class_counts / class_counts.sum() * 100

print("\nOriginal training set distribution:")
for class_label, count in class_counts.items():
    print(f"  Class {class_label}: {count:,} samples ({class_percentages[class_label]:.2f}%)")

# Majority-class ratio as reported by the handler, shown next to the threshold
majority_ratio = imbalance_handler.check_imbalance(y_train)
print(f"\nMajority class ratio: {majority_ratio:.3f}")
print(f"Imbalance threshold: {config['evaluation']['imbalance_threshold']}")

# Visualize class distribution: counts on the left, proportions on the right
colors = ['#3498db', '#e74c3c']
fig, axes = plt.subplots(1, 2, figsize=(14, 5))
bar_ax, pie_ax = axes

# Bar plot
# NOTE(review): tick positions and names below assume a binary target
# encoded as 0 / 1 — confirm against the data preparation step
bar_ax.bar(class_counts.index, class_counts.values, color=colors)
bar_ax.set_xlabel('Class', fontsize=12)
bar_ax.set_ylabel('Count', fontsize=12)
bar_ax.set_title('Class Distribution (Training Set)', fontsize=14, fontweight='bold')
bar_ax.set_xticks([0, 1])
bar_ax.set_xticklabels(['Not Cancelled (0)', 'Cancelled (1)'])
# Anchor each count at its bar's own x-position (the class label) rather
# than at its ordinal position, so labels always sit on the right bar
for class_label, count in class_counts.items():
    bar_ax.text(class_label, count, f'{count:,}', ha='center', va='bottom', fontsize=11)

# Pie chart
pie_ax.pie(class_counts.values, labels=['Not Cancelled', 'Cancelled'],
           autopct='%1.1f%%', colors=colors, startangle=90)
pie_ax.set_title('Class Distribution Proportion', fontsize=14, fontweight='bold')

plt.tight_layout()
plt.show()

print(f"\n{'='*60}")

## 5. Apply SMOTE if Needed

In [None]:
# Let the handler decide whether the training data needs resampling
rule = "=" * 60
print(f"\n{rule}")
print("HANDLING CLASS IMBALANCE")
print(rule)

X_train_resampled, y_train_resampled = imbalance_handler.handle_imbalance(X_train, y_train)

# A changed row count is taken as the sign that resampling was applied
rows_before = X_train.shape[0]
rows_after = X_train_resampled.shape[0]

if rows_after == rows_before:
    print("\n✓ No resampling needed - classes are balanced")
else:
    print("\n✓ SMOTE applied successfully")
    print(f"  Original training samples: {rows_before:,}")
    print(f"  Resampled training samples: {rows_after:,}")
    print(f"  Samples added: {rows_after - rows_before:,}")

    # Class balance after resampling
    resampled_target = pd.Series(y_train_resampled)
    resampled_counts = resampled_target.value_counts().sort_index()
    resampled_percentages = resampled_target.value_counts(normalize=True).sort_index() * 100

    print("\nResampled training set distribution:")
    for class_label, n_samples in resampled_counts.items():
        print(f"  Class {class_label}: {n_samples:,} samples ({resampled_percentages[class_label]:.2f}%)")

print(f"\n{rule}")

## 6. Train All Models

In [None]:
# Trainer receives the whole 'models' section of the config and cv_folds=5
model_trainer = ModelTrainer(models_config=config['models'], cv_folds=5)

rule = "=" * 60
print(f"\n{rule}")
print("TRAINING MODELS")
print(rule)
print("\nThis may take several minutes...\n")

# Fit on the training data returned by the imbalance-handling step
training_results = model_trainer.train_all_models(X_train_resampled, y_train_resampled)

print(f"\n{rule}")
print("TRAINING COMPLETE")
print(rule)

## 7. Display Cross-Validation Results

In [None]:
# Print per-model cross-validation scores and collect the means in a table
rule = "=" * 60
print(f"\n{rule}")
print("CROSS-VALIDATION RESULTS")
print(rule)

# (printed label, key suffix inside cv_scores, column name in the summary)
metric_spec = [
    ('Accuracy', 'accuracy', 'CV Accuracy'),
    ('Precision', 'precision', 'CV Precision'),
    ('Recall', 'recall', 'CV Recall'),
    ('F1-Score', 'f1', 'CV F1-Score'),
    ('ROC-AUC', 'roc_auc', 'CV ROC-AUC'),
]

cv_results_data = []

for model_name, result in training_results.items():
    # Entries without a fitted model or without CV scores are skipped
    if result['model'] is None or result['cv_scores'] is None:
        continue

    cv_scores = result['cv_scores']
    print(f"\n{model_name.upper()}:")

    summary_row = {'Model': model_name}
    for label, key, column in metric_spec:
        mean_value = cv_scores[f'mean_{key}']
        std_value = cv_scores[f'std_{key}']
        # Pad "Label:" to 10 characters so the numbers line up in one column
        heading = f"{label}:"
        print(f"  {heading:<10} {mean_value:.4f} (+/- {std_value:.4f})")
        summary_row[column] = mean_value

    cv_results_data.append(summary_row)

# One row per model, one column per mean CV metric
cv_results_df = pd.DataFrame(cv_results_data)

print(f"\n{rule}")

## 8. Visualize Cross-Validation Results

In [None]:
# One bar chart per CV metric, arranged in a 2x2 grid
if cv_results_df.empty:
    print("No CV results to visualize")
else:
    fig, axes = plt.subplots(2, 2, figsize=(16, 12))

    metrics = ['CV Accuracy', 'CV Precision', 'CV Recall', 'CV F1-Score']
    colors = ['#3498db', '#e74c3c', '#2ecc71']
    bar_colors = colors[:len(cv_results_df)]

    # axes.flat walks the grid row by row, matching the metric order
    for ax, metric in zip(axes.flat, metrics):
        bars = ax.bar(cv_results_df['Model'], cv_results_df[metric], color=bar_colors)
        ax.set_xlabel('Model', fontsize=12)
        ax.set_ylabel(metric, fontsize=12)
        ax.set_title(f'{metric} Comparison', fontsize=14, fontweight='bold')
        ax.set_ylim([0, 1.0])
        ax.grid(axis='y', alpha=0.3)

        # Write each score just above its bar
        for bar in bars:
            height = bar.get_height()
            x_center = bar.get_x() + bar.get_width()/2.
            ax.text(x_center, height, f'{height:.3f}',
                    ha='center', va='bottom', fontsize=10)

    plt.tight_layout()
    plt.show()

## 9. Evaluate Models on Test Set

In [None]:
# Evaluator writes under reports/figures, comparator under reports
reports_dir = project_root / 'reports'
model_evaluator = ModelEvaluator(output_dir=str(reports_dir / 'figures'))
model_comparator = ModelComparator(output_dir=str(reports_dir))

rule = "=" * 60
print(f"\n{rule}")
print("EVALUATING MODELS ON TEST SET")
print(rule)
print("\nGenerating predictions and metrics...\n")

# Keep only the entries whose training produced a model object
trained_models = {
    name: outcome['model']
    for name, outcome in training_results.items()
    if outcome['model'] is not None
}

# Score every trained model on the held-out test data
comparison_df = model_comparator.compare_models(trained_models, X_test, y_test)

print("\n✓ Evaluation complete")
print(f"\n{rule}")

## 10. Rank and Compare Models

In [None]:
# Rank models by whichever metric the config names as primary
primary_metric = config['evaluation']['primary_metric']
ranked_df = model_comparator.rank_models(comparison_df, metric=primary_metric)

rule = "=" * 60
print(f"\n{rule}")
print("MODEL COMPARISON RESULTS")
print(rule)
print(f"\nRanked by: {primary_metric}\n")
print(ranked_df.to_string(index=False, float_format='{:.4f}'.format))

# Produce the comparison report and ask the comparator to write the CSV
report = model_comparator.generate_comparison_report(ranked_df, save_csv=True)
report_csv_path = project_root / 'reports' / 'model_comparison.csv'
print(f"\n✓ Comparison report saved to: {report_csv_path}")

# Pick the winner on the same primary metric
best_model_name, best_model, best_metric_value = model_comparator.get_best_model(
    trained_models, comparison_df, metric=primary_metric
)

print(f"\n{rule}")
print("BEST MODEL IDENTIFIED")
print(rule)
print(f"Model: {best_model_name}")
print(f"{primary_metric}: {best_metric_value:.4f}")
print(rule)

## 11. Visualize Model Comparison

In [None]:
# Test-set comparison figure: one panel per metric plus a grouped overview.

# savefig does not create missing directories, so make sure the target exists
figures_dir = project_root / 'reports' / 'figures'
figures_dir.mkdir(parents=True, exist_ok=True)

fig, axes = plt.subplots(2, 3, figsize=(18, 12))

metrics_to_plot = ['accuracy', 'precision', 'recall', 'f1_score', 'roc_auc']
metric_labels = ['Accuracy', 'Precision', 'Recall', 'F1-Score', 'ROC-AUC']
colors = ['#3498db', '#e74c3c', '#2ecc71']

# The five metric panels fill the grid row by row; the sixth slot (axes[1, 2])
# is left for the overview drawn afterwards
for ax, metric, label in zip(axes.flat, metrics_to_plot, metric_labels):
    if metric not in ranked_df.columns:
        # Hide the panel instead of leaving an empty frame in the figure
        ax.set_visible(False)
        continue

    bars = ax.bar(ranked_df['model_name'], ranked_df[metric], color=colors[:len(ranked_df)])
    ax.set_xlabel('Model', fontsize=11)
    ax.set_ylabel(label, fontsize=11)
    ax.set_title(f'{label} on Test Set', fontsize=13, fontweight='bold')
    ax.set_ylim([0, 1.0])
    ax.grid(axis='y', alpha=0.3)
    ax.tick_params(axis='x', rotation=15)

    # Add value labels, skipping bars whose score is NaN
    for bar in bars:
        height = bar.get_height()
        if not np.isnan(height):
            ax.text(bar.get_x() + bar.get_width()/2., height,
                    f'{height:.3f}', ha='center', va='bottom', fontsize=9)

# Overall comparison in the last subplot: grouped bars, one group per model
ax = axes[1, 2]
grouped_metrics = [
    metric for metric in ['accuracy', 'precision', 'recall', 'f1_score']
    if metric in ranked_df.columns
]
x_pos = np.arange(len(ranked_df))
width = 0.2

# Offsets are based on the metrics actually present, so a missing metric
# leaves no gap inside a group
for i, metric in enumerate(grouped_metrics):
    ax.bar(x_pos + i*width, ranked_df[metric], width,
           label=metric.replace('_', ' ').title(), alpha=0.8)

ax.set_xlabel('Model', fontsize=11)
ax.set_ylabel('Score', fontsize=11)
ax.set_title('All Metrics Comparison', fontsize=13, fontweight='bold')
# Centre each tick under its group (width * 1.5 when all four metrics exist)
ax.set_xticks(x_pos + width * max(len(grouped_metrics) - 1, 0) / 2)
ax.set_xticklabels(ranked_df['model_name'], rotation=15)
if grouped_metrics:
    # legend() warns when there are no labelled artists to show
    ax.legend(fontsize=9)
ax.grid(axis='y', alpha=0.3)
ax.set_ylim([0, 1.0])

plt.tight_layout()
plt.savefig(figures_dir / 'model_comparison_all_metrics.png',
            dpi=300, bbox_inches='tight')
plt.show()

print("\n✓ Comparison visualization saved")

## 12. Display Confusion Matrices

In [None]:
# Plot one confusion-matrix heatmap per trained model, side by side.
# Imported once here rather than on every loop iteration.
from sklearn.metrics import confusion_matrix

n_models = len(trained_models)

if n_models == 0:
    # plt.subplots cannot build a grid with zero columns, so skip plotting
    print("\n⚠ No trained models available - skipping confusion matrices")
else:
    # savefig does not create missing directories
    figures_dir = project_root / 'reports' / 'figures'
    figures_dir.mkdir(parents=True, exist_ok=True)

    # squeeze=False always yields a 2-D array of axes, so a single model
    # needs no special-casing; row 0 holds one axis per model
    fig, axes = plt.subplots(1, n_models, figsize=(6*n_models, 5), squeeze=False)
    axes = axes[0]

    for ax, (model_name, model) in zip(axes, trained_models.items()):
        y_pred = model.predict(X_test)
        cm = confusion_matrix(y_test, y_pred)

        # NOTE(review): the tick names assume a binary 0/1 target — confirm
        sns.heatmap(cm, annot=True, fmt='d', cmap='Blues', ax=ax,
                    xticklabels=['Not Cancelled', 'Cancelled'],
                    yticklabels=['Not Cancelled', 'Cancelled'],
                    cbar_kws={'label': 'Count'})

        ax.set_xlabel('Predicted Label', fontsize=11)
        ax.set_ylabel('True Label', fontsize=11)
        ax.set_title(f'Confusion Matrix - {model_name}', fontsize=13, fontweight='bold')

    plt.tight_layout()
    plt.savefig(figures_dir / 'confusion_matrices_all_models.png',
                dpi=300, bbox_inches='tight')
    plt.show()

    print("\n✓ Confusion matrices displayed")

## 13. Save All Trained Models

In [None]:
# Registry that stores models under <project_root>/models
model_registry = ModelRegistry(models_dir=str(project_root / 'models'))

print("\n" + "="*60)
print("SAVING TRAINED MODELS")
print("="*60)

# Identical for every model, so computed once: a changed row count is taken
# as the sign that resampling was applied
smote_applied = X_train_resampled.shape[0] != X_train.shape[0]

# Save all trained models with metadata
for model_name, model in trained_models.items():
    # Test-set metrics for this model: its row in comparison_df as a dict,
    # without the name column
    model_metrics = comparison_df[comparison_df['model_name'] == model_name].iloc[0].to_dict()
    model_metrics.pop('model_name', None)

    # Hyperparameters as declared in the config
    hyperparameters = config['models'][model_name]['params']

    # Cross-validation scores recorded during training
    cv_scores = training_results[model_name]['cv_scores']

    # Save model
    model_path = model_registry.save_model(
        model=model,
        model_name=model_name,
        version="1.0.0",
        metrics=model_metrics,
        hyperparameters=hyperparameters,
        additional_metadata={
            'cv_scores': cv_scores,
            'training_samples': X_train_resampled.shape[0],
            'test_samples': X_test.shape[0],
            'n_features': X_train.shape[1],
            'smote_applied': smote_applied
        }
    )

    # Format the score only when it exists: applying ':.4f' to the 'N/A'
    # placeholder string would raise ValueError
    f1_value = model_metrics.get('f1_score')
    f1_text = 'N/A' if f1_value is None else f"{f1_value:.4f}"

    print(f"\n✓ Saved: {model_name}")
    print(f"  Path: {model_path}")
    print(f"  F1-Score: {f1_text}")

print(f"\n{'='*60}")

## 14. Save Best Model

In [None]:
# Store a second copy of the winning model under the fixed name 'best_model'
rule = "=" * 60
print(f"\n{rule}")
print("SAVING BEST MODEL")
print(rule)

# Test-set metrics of the winner: its comparison_df row minus the name column
is_best_row = comparison_df['model_name'] == best_model_name
best_model_metrics = comparison_df[is_best_row].iloc[0].to_dict()
best_model_metrics.pop('model_name', None)

# Hyperparameters from the config and CV scores from the training results
best_hyperparameters = config['models'][best_model_name]['params']
best_cv_scores = training_results[best_model_name]['cv_scores']

# Extra metadata; the original model name is kept so the copy stays traceable
n_resampled_rows = X_train_resampled.shape[0]
best_model_metadata = {
    'original_model_name': best_model_name,
    'cv_scores': best_cv_scores,
    'training_samples': n_resampled_rows,
    'test_samples': X_test.shape[0],
    'n_features': X_train.shape[1],
    'smote_applied': n_resampled_rows != X_train.shape[0],
}

best_model_path = model_registry.save_model(
    model=best_model,
    model_name='best_model',
    version="1.0.0",
    metrics=best_model_metrics,
    hyperparameters=best_hyperparameters,
    additional_metadata=best_model_metadata,
)

print(f"\n✓ Best model saved: {best_model_name}")
print(f"  Path: {best_model_path}")
print(f"  {primary_metric}: {best_metric_value:.4f}")
print(f"\n{rule}")

## 15. Summary and Next Steps

In [None]:
# Recap of the run: what was trained, which model won, where outputs went
print("\n" + "="*60)
print("TRAINING SUMMARY")
print("="*60)

print(f"\n✓ Models trained: {len(trained_models)}")
print(f"✓ Best model: {best_model_name}")
print(f"✓ Best {primary_metric}: {best_metric_value:.4f}")

# Check if minimum performance threshold is met.
# The requirement is stated in terms of F1, while best_metric_value holds the
# configured primary metric (which may be a different metric). Read the best
# model's F1 from its comparison row, and fall back to the primary-metric
# value only when no 'f1_score' column exists.
min_f1_threshold = 0.75
best_row = comparison_df[comparison_df['model_name'] == best_model_name].iloc[0]
best_f1 = best_row['f1_score'] if 'f1_score' in best_row.index else best_metric_value

if best_f1 >= min_f1_threshold:
    print(f"\n✓ Performance requirement MET: F1-score >= {min_f1_threshold}")
else:
    print(f"\n⚠ Performance requirement NOT MET: F1-score < {min_f1_threshold}")
    print("  Consider hyperparameter optimization in the next notebook")

print(f"\n✓ All models saved to: {project_root / 'models'}")
print(f"✓ Comparison report saved to: {project_root / 'reports' / 'model_comparison.csv'}")
print(f"✓ Visualizations saved to: {project_root / 'reports' / 'figures'}")

print("\n" + "="*60)
print("NEXT STEPS")
print("="*60)
print("\n1. Review model performance metrics and visualizations")
print("2. Proceed to hyperparameter optimization (notebook 04)")
print("3. Use the best model for predictions")
print("\n" + "="*60)