# Model Training & Selection

This notebook focuses on:
- Training multiple machine learning models on the engineered dataset
- Performing basic performance sanity checks
- Selecting the best-performing model
- Saving the trained model for further evaluation and deployment 

## 1. Import Required Libraries

In [None]:
import pandas as pd
import numpy as np

from sklearn.model_selection import train_test_split
from sklearn.linear_model import LogisticRegression
from sklearn.tree import DecisionTreeClassifier
from sklearn.ensemble import RandomForestClassifier

from sklearn.metrics import accuracy_score, recall_score, roc_auc_score

import joblib
import warnings
warnings.filterwarnings("ignore")

## 2. Load Engineered Dataset

In [None]:
df = pd.read_csv("../data/processed/processed_churn_data.csv")

## 3. Dataset Verification

In [None]:
print(df.shape)
print(df.dtypes.value_counts())

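This confirms the expected dimensions and shows how many columns have each dtype, so any non-numeric (e.g. `object`) columns stand out. scikit-learn estimators cannot consume those directly.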
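A more explicit readiness check can be sketched as below. It assumes the target column is `CHURN` (as in section 4). The helper `verify_model_ready` and the toy column names `tenure` and `monthly_charges` are hypothetical and used only for illustration.

```python
import pandas as pd


def verify_model_ready(df: pd.DataFrame, target: str = "CHURN") -> dict:
    """Hypothetical helper: summarize whether a DataFrame is ready for scikit-learn."""
    features = df.drop(columns=[target])
    # Columns scikit-learn cannot use without further encoding.
    non_numeric = [c for c in features.columns
                   if not pd.api.types.is_numeric_dtype(features[c])]
    return {
        "shape": df.shape,
        "non_numeric_columns": non_numeric,
        "missing_values": int(df.isna().sum().sum()),
        "target_values": sorted(df[target].unique().tolist()),
    }


# Toy frame standing in for the processed churn data (not real data).
toy = pd.DataFrame({
    "tenure": [1, 24, 60, 5],
    "monthly_charges": [70.5, 20.0, 99.9, 55.1],
    "CHURN": [1, 0, 0, 1],
})
report = verify_model_ready(toy)
print(report)
```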

## 4. Separate Features and Target

In [None]:
X = df.drop(columns=['CHURN'])
y = df['CHURN']

## 5. Train–Test Split

In [None]:
X_train, X_test, y_train, y_test = train_test_split(
    X,
    y,
    test_size=0.2,
    stratify=y,
    random_state=42
)

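Stratifying on `y` keeps the churn rate in the training and test sets the same as in the full dataset. It preserves the existing class ratio; it does not balance the classes.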
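The effect of `stratify=y` can be checked on a synthetic, imbalanced target. The 20% positive rate below is an arbitrary assumption, and variable names carry a `_demo` suffix so the notebook's own `X` and `y` are not overwritten.

```python
import numpy as np
from sklearn.model_selection import train_test_split

# Synthetic imbalanced target: 200 positives (churn) out of 1000 samples.
rng = np.random.default_rng(0)
X_demo = rng.normal(size=(1000, 3))
y_demo = np.array([1] * 200 + [0] * 800)

X_tr_demo, X_te_demo, y_tr_demo, y_te_demo = train_test_split(
    X_demo, y_demo, test_size=0.2, stratify=y_demo, random_state=42
)

# Both splits keep the original 20% churn rate; the classes stay imbalanced.
print("overall churn rate:", y_demo.mean())
print("train churn rate:  ", y_tr_demo.mean())
print("test churn rate:   ", y_te_demo.mean())
```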

## 6. Baseline Model: Logistic Regression
Logistic Regression is used as a baseline linear model to validate preprocessing and establish a reference point.
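Whether the engineered features are already scaled is not shown in this notebook. If they are not, a common pattern is to put a `StandardScaler` in front of the logistic regression so the solver converges reliably and the coefficients are on comparable scales. A minimal sketch on synthetic data (`make_classification` stands in for the churn features; this is an optional variant, not the baseline used below):

```python
from sklearn.datasets import make_classification
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import roc_auc_score
from sklearn.model_selection import train_test_split
from sklearn.pipeline import make_pipeline
from sklearn.preprocessing import StandardScaler

# Synthetic stand-in for the engineered churn features (not the real dataset).
X_demo, y_demo = make_classification(
    n_samples=2000, n_features=10, weights=[0.7, 0.3], random_state=0
)
X_tr_demo, X_te_demo, y_tr_demo, y_te_demo = train_test_split(
    X_demo, y_demo, test_size=0.2, stratify=y_demo, random_state=42
)

# The scaler is fit on the training split only, inside the pipeline,
# so test-set statistics do not leak into training.
baseline = make_pipeline(StandardScaler(), LogisticRegression(max_iter=1000))
baseline.fit(X_tr_demo, y_tr_demo)

auc = roc_auc_score(y_te_demo, baseline.predict_proba(X_te_demo)[:, 1])
print("ROC-AUC:", round(auc, 3))
```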

In [None]:
lr = LogisticRegression(max_iter=1000)
lr.fit(X_train, y_train)

lr_pred = lr.predict(X_test)
lr_prob = lr.predict_proba(X_test)[:, 1]

### Sanity Metrics (Logistic Regression)

In [None]:
print("Logistic Regression")
print("Accuracy:", accuracy_score(y_test, lr_pred))
print("Recall:", recall_score(y_test, lr_pred))
print("ROC-AUC:", roc_auc_score(y_test, lr_prob))

## 7. Decision Tree Model
Decision Tree captures non-linear patterns and interactions between features.
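In the Decision Tree cell of this section, `max_depth=6` and `min_samples_split=50` act as regularizers: without them a tree keeps splitting until every leaf is pure, which memorizes the training data. A sketch on synthetic data comparing an unconstrained tree with one using the same constraints:

```python
from sklearn.datasets import make_classification
from sklearn.tree import DecisionTreeClassifier

# Synthetic stand-in data (not the churn dataset).
X_demo, y_demo = make_classification(n_samples=2000, n_features=10, random_state=0)

unconstrained = DecisionTreeClassifier(random_state=42).fit(X_demo, y_demo)
constrained = DecisionTreeClassifier(
    max_depth=6, min_samples_split=50, random_state=42
).fit(X_demo, y_demo)

# The unconstrained tree grows until its leaves are pure (perfect training accuracy);
# the constrained tree is capped at depth 6, i.e. at most 2**6 = 64 leaves.
print("unconstrained depth:", unconstrained.get_depth(), "leaves:", unconstrained.get_n_leaves())
print("constrained depth:  ", constrained.get_depth(), "leaves:", constrained.get_n_leaves())
print("unconstrained train accuracy:", unconstrained.score(X_demo, y_demo))
print("constrained train accuracy:  ", constrained.score(X_demo, y_demo))
```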

In [None]:
dt = DecisionTreeClassifier(
    max_depth=6,
    min_samples_split=50,
    random_state=42
)

dt.fit(X_train, y_train)

dt_pred = dt.predict(X_test)
dt_prob = dt.predict_proba(X_test)[:, 1]

### Sanity Metrics (Decision Tree)

In [None]:
print("Decision Tree")
print("Accuracy:", accuracy_score(y_test, dt_pred))
print("Recall:", recall_score(y_test, dt_pred))
print("ROC-AUC:", roc_auc_score(y_test, dt_prob))

## 8. Random Forest Model
Random Forest is an ensemble model that improves generalization and robustness by combining multiple decision trees.
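In scikit-learn, a random forest's `predict_proba` is the mean of the probability estimates of its individual trees, each trained on a bootstrap sample with a random subset of features considered at each split. The sketch below, on synthetic data, reproduces the forest's output by averaging the trees in `estimators_`:

```python
import numpy as np
from sklearn.datasets import make_classification
from sklearn.ensemble import RandomForestClassifier

# Synthetic stand-in data (not the churn dataset).
X_demo, y_demo = make_classification(n_samples=500, n_features=8, random_state=0)

rf_demo = RandomForestClassifier(n_estimators=25, max_depth=10, random_state=42)
rf_demo.fit(X_demo, y_demo)

# Stack each tree's class probabilities: shape (n_trees, n_samples, n_classes).
tree_probs = np.stack([tree.predict_proba(X_demo) for tree in rf_demo.estimators_])
manual_avg = tree_probs.mean(axis=0)

print("matches forest predict_proba:", np.allclose(manual_avg, rf_demo.predict_proba(X_demo)))
```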

In [None]:
rf = RandomForestClassifier(
    n_estimators=200,
    max_depth=10,
    min_samples_split=30,
    class_weight='balanced',
    random_state=42
)

rf.fit(X_train, y_train)

rf_pred = rf.predict(X_test)
rf_prob = rf.predict_proba(X_test)[:, 1]

### Sanity Metrics (Random Forest)

In [None]:
print("Random Forest")
print("Accuracy:", accuracy_score(y_test, rf_pred))
print("Recall:", recall_score(y_test, rf_pred))
print("ROC-AUC:", roc_auc_score(y_test, rf_prob))

## 9. Model Comparison Summary

Three classification models were trained and evaluated to predict customer churn.  
Performance was compared using Accuracy, Recall (Churn = 1), and ROC–AUC.

### Evaluation Metrics
- **Accuracy**: Overall correctness of predictions
- **Recall (Churn)**: Ability to correctly identify churn customers (business-critical)
- **ROC–AUC**: Model’s ability to distinguish churn vs non-churn across thresholds
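The three metrics can be reproduced by hand on a tiny illustrative example (made-up labels and scores, not results from this dataset). Recall is TP / (TP + FN) for the churn class, and ROC–AUC equals the probability that a randomly chosen churner receives a higher score than a randomly chosen non-churner (with no tied scores, as here):

```python
from sklearn.metrics import accuracy_score, recall_score, roc_auc_score

# Tiny hand-made example: 1 = churn, 0 = stay (illustrative values only).
y_true = [1, 1, 1, 0, 0, 0, 0, 0]
y_pred = [1, 1, 0, 0, 0, 0, 1, 0]
y_prob = [0.9, 0.8, 0.4, 0.3, 0.2, 0.1, 0.6, 0.35]

tp = sum(t == 1 and p == 1 for t, p in zip(y_true, y_pred))  # churners caught
fn = sum(t == 1 and p == 0 for t, p in zip(y_true, y_pred))  # churners missed
correct = sum(t == p for t, p in zip(y_true, y_pred))

# Pairwise ROC-AUC: fraction of (churner, non-churner) pairs ranked correctly.
pos = [s for s, t in zip(y_prob, y_true) if t == 1]
neg = [s for s, t in zip(y_prob, y_true) if t == 0]
pairwise_auc = sum(p > n for p in pos for n in neg) / (len(pos) * len(neg))

print("accuracy:", correct / len(y_true), accuracy_score(y_true, y_pred))  # 0.75 0.75
print("recall:  ", tp / (tp + fn), recall_score(y_true, y_pred))
print("ROC-AUC: ", pairwise_auc, roc_auc_score(y_true, y_prob))
```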

---

### Performance Comparison

| Model               | Accuracy | Recall (Churn) | ROC–AUC |
|--------------------|----------|----------------|---------|
| Logistic Regression | 0.645    | 0.704          | 0.705   |
| Decision Tree       | 0.735    | 0.706          | 0.819   |
| Random Forest       | 0.883    | 0.871          | 0.943   |
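Instead of transcribing the printed sanity metrics into a table by hand, the comparison can be assembled as a DataFrame. In the sketch below, `compare_models` is a hypothetical helper, and synthetic data with `_demo` names stands in for the notebook's models and test split:

```python
import pandas as pd
from sklearn.datasets import make_classification
from sklearn.ensemble import RandomForestClassifier
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import accuracy_score, recall_score, roc_auc_score
from sklearn.model_selection import train_test_split
from sklearn.tree import DecisionTreeClassifier


def compare_models(models, X_eval, y_eval):
    """Hypothetical helper: one row of sanity metrics per fitted binary classifier."""
    rows = []
    for name, model in models.items():
        pred = model.predict(X_eval)
        prob = model.predict_proba(X_eval)[:, 1]  # probability of class 1 (churn)
        rows.append({
            "Model": name,
            "Accuracy": accuracy_score(y_eval, pred),
            "Recall (Churn)": recall_score(y_eval, pred),
            "ROC-AUC": roc_auc_score(y_eval, prob),
        })
    return pd.DataFrame(rows).set_index("Model").round(3)


# Synthetic stand-in data so the sketch runs on its own.
X_demo, y_demo = make_classification(
    n_samples=2000, n_features=10, weights=[0.7, 0.3], random_state=0
)
X_tr_demo, X_te_demo, y_tr_demo, y_te_demo = train_test_split(
    X_demo, y_demo, test_size=0.2, stratify=y_demo, random_state=42
)
models_demo = {
    "Logistic Regression": LogisticRegression(max_iter=1000),
    "Decision Tree": DecisionTreeClassifier(max_depth=6, min_samples_split=50, random_state=42),
    "Random Forest": RandomForestClassifier(
        n_estimators=200, max_depth=10, min_samples_split=30,
        class_weight="balanced", random_state=42,
    ),
}
for model in models_demo.values():
    model.fit(X_tr_demo, y_tr_demo)

summary = compare_models(models_demo, X_te_demo, y_te_demo)
print(summary)
```

In the notebook itself, the call would look like `compare_models({"Logistic Regression": lr, "Decision Tree": dt, "Random Forest": rf}, X_test, y_test)`.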

---

### Key Observations

- **Logistic Regression** provided a reasonable baseline (ROC–AUC 0.705, recall 0.704), confirming that the engineered features carry predictive signal, although its accuracy (0.645) was the lowest of the three models.
- **Decision Tree** improved discriminative power by capturing non-linear patterns, raising ROC–AUC from 0.705 to 0.819, while recall stayed almost unchanged (0.706).
- **Random Forest** delivered the best performance across all metrics, with:
  - High recall (0.871), identifying about 87% of churn customers in the test set
  - The highest ROC–AUC (0.943), indicating strong class separation
  - Class imbalance handled via `class_weight='balanced'`; this weighting was applied only to the Random Forest, so part of its recall advantage may come from the weighting rather than from the ensemble itself
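With `class_weight='balanced'`, scikit-learn weights each class by `n_samples / (n_classes * count(class))`, so errors on the rarer churn class cost more during training. (For `RandomForestClassifier` these weights are computed from the full training target; the `'balanced_subsample'` option recomputes them per bootstrap sample.) A quick check of the formula on an illustrative 80/20 target:

```python
import numpy as np
from sklearn.utils.class_weight import compute_class_weight

# Illustrative target: 80 non-churners (0) and 20 churners (1).
y_demo = np.array([0] * 80 + [1] * 20)

# 'balanced' weight for class c = n_samples / (n_classes * count(c))
manual = len(y_demo) / (2 * np.bincount(y_demo))
sk = compute_class_weight(class_weight="balanced", classes=np.array([0, 1]), y=y_demo)

print("manual: ", manual)  # 100 / (2 * 80) = 0.625, 100 / (2 * 20) = 2.5
print("sklearn:", sk)
```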

---

### Model Selection Decision

Given the business objective of minimizing missed churn customers, **Random Forest** was selected as the final model due to its superior recall and overall predictive performance.
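The selection rule stated above (prioritize churn recall, then overall discrimination) can be made explicit. This sketch uses the metric values from the comparison table; ranking by recall with ROC–AUC as a tie-breaker is one reasonable encoding of the stated objective, not necessarily the exact procedure used:

```python
# Metrics copied from the Performance Comparison table (test split).
results = {
    "Logistic Regression": {"accuracy": 0.645, "recall": 0.704, "roc_auc": 0.705},
    "Decision Tree": {"accuracy": 0.735, "recall": 0.706, "roc_auc": 0.819},
    "Random Forest": {"accuracy": 0.883, "recall": 0.871, "roc_auc": 0.943},
}

# Rank by churn recall first; break ties with ROC-AUC.
best_name = max(results, key=lambda name: (results[name]["recall"], results[name]["roc_auc"]))
print("Selected model:", best_name)  # Random Forest
```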

Its held-out performance suggests it generalizes well, although this evidence comes from a single train/test split; the model is evaluated in more depth in the next notebook.


## 10. Save Best Model

In [None]:
joblib.dump(rf, "../models/churn_random_forest.pkl")
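`joblib.dump` does not create missing directories, so `../models/` has to exist before the cell runs. A round-trip sketch in a temporary directory (synthetic model, `_demo` names) that creates the folder first and checks that the reloaded model predicts identically. scikit-learn's model-persistence documentation also recommends loading a pickled model with the same scikit-learn version that saved it.

```python
import os
import tempfile

import joblib
import numpy as np
from sklearn.datasets import make_classification
from sklearn.ensemble import RandomForestClassifier

# Synthetic stand-in model (not the trained churn model).
X_demo, y_demo = make_classification(n_samples=300, n_features=6, random_state=0)
model_demo = RandomForestClassifier(n_estimators=20, random_state=42).fit(X_demo, y_demo)

with tempfile.TemporaryDirectory() as tmp:
    path = os.path.join(tmp, "models", "churn_random_forest.pkl")
    # Create the target folder first; joblib.dump will not do it.
    os.makedirs(os.path.dirname(path), exist_ok=True)
    joblib.dump(model_demo, path)
    restored = joblib.load(path)

same_predictions = np.array_equal(model_demo.predict_proba(X_demo), restored.predict_proba(X_demo))
print("identical predictions after reload:", same_predictions)
```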

## 11. Save Test Split

In [None]:
X_test.to_csv("../data/processed/X_test.csv", index=False)
y_test.to_csv("../data/processed/y_test.csv", index=False)
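Only the test split is written out here. Because `y_test` is a Series named `CHURN`, `to_csv(index=False)` writes a single column with that header, so reading it back means selecting that column. A round-trip sketch with toy data (the feature names are hypothetical):

```python
import os
import tempfile

import pandas as pd

# Toy split standing in for X_test / y_test (feature names are hypothetical).
X_test_demo = pd.DataFrame({"tenure": [3, 40], "monthly_charges": [80.0, 25.5]})
y_test_demo = pd.Series([1, 0], name="CHURN")

with tempfile.TemporaryDirectory() as tmp:
    X_test_demo.to_csv(os.path.join(tmp, "X_test.csv"), index=False)
    y_test_demo.to_csv(os.path.join(tmp, "y_test.csv"), index=False)

    # y_test.csv has one column whose header is the Series name ("CHURN").
    X_loaded = pd.read_csv(os.path.join(tmp, "X_test.csv"))
    y_loaded = pd.read_csv(os.path.join(tmp, "y_test.csv"))["CHURN"]

print(X_loaded.equals(X_test_demo), y_loaded.equals(y_test_demo))
```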

## 12. Conclusion

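- Three models (Logistic Regression, Decision Tree, Random Forest) were trained on the engineered features
- Random Forest showed the strongest churn detection (recall 0.871, ROC–AUC 0.943 on the test split)
- The trained model and the test split have been saved for in-depth evaluation in the next notebook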
