# In [None]:
# 04-model-training.ipynb
# Model Training & SHAP Explainability

"""
## 04 - Model Training and Explainability

This notebook:
1. Loads `phishing_graph_features.csv` with all features & labels.
2. Splits into train/test.
3. Trains a RandomForestClassifier for phishing detection.
4. Evaluates performance (accuracy, precision, recall, F1, ROC-AUC).
5. Uses SHAP to generate global and local explanations.
6. Saves model and explanation visuals.
"""

#%%
# 1. Imports and Config
import os
import joblib
import pandas as pd
import matplotlib.pyplot as plt
from sklearn.ensemble import RandomForestClassifier
from sklearn.model_selection import train_test_split
from sklearn.metrics import classification_report, roc_auc_score, roc_curve
import shap

# Paths (all relative to the notebook's working directory)
_PARENT_DIR = '..'
_MODELS_DIR = os.path.join(_PARENT_DIR, 'models')  # shared by the model and the SHAP figure

FEATURE_CSV = os.path.join(
    _PARENT_DIR, 'data', 'processed', 'phishing_graph_features.csv')
MODEL_PATH = os.path.join(_MODELS_DIR, 'phishing_rf_model.pkl')
SHAP_SUMMARY_PLOT = os.path.join(_MODELS_DIR, 'phishing_shap_summary.png')

#%%
# 2. Load Data
# Read the feature table produced earlier in the pipeline.
df = pd.read_csv(FEATURE_CSV)

# Columns kept out of the feature matrix: the target itself plus the
# columns the original author flagged as non-numeric.
non_feature_cols = ['label', 'url', 'sender_domain', 'domain']
X = df.drop(columns=non_feature_cols)

# Target vector.
y = df['label']

#%%
# 3. Split Data
# 70/30 train/test split, stratified on the label so both partitions keep
# the same class proportions; the fixed seed makes the split reproducible.
split_options = {
    'test_size': 0.3,
    'random_state': 42,
    'stratify': y,
}
X_train, X_test, y_train, y_test = train_test_split(X, y, **split_options)

#%%
# 4. Train Random Forest
# Fit a 100-tree forest; the fixed seed keeps the fitted model reproducible.
clf = RandomForestClassifier(n_estimators=100, random_state=42)
clf.fit(X_train, y_train)

# joblib.dump does not create missing parent directories, so make sure the
# output folder exists before writing (a fresh checkout may not have it).
os.makedirs(os.path.dirname(MODEL_PATH) or '.', exist_ok=True)
joblib.dump(clf, MODEL_PATH)
print("Model saved to", MODEL_PATH)

#%%
# 5. Evaluate Performance
# Per-class precision / recall / F1 and overall accuracy on the held-out set.
y_pred = clf.predict(X_test)
print(classification_report(y_test, y_pred))

# Score the test set once and reuse the result for both the AUC and the
# ROC curve. Column 1 of predict_proba is the probability of clf.classes_[1].
y_score = clf.predict_proba(X_test)[:, 1]
roc_auc = roc_auc_score(y_test, y_score)
fpr, tpr, _ = roc_curve(y_test, y_score)
print(f"ROC AUC: {roc_auc:.3f}")

# ROC curve with the chance diagonal as a visual baseline.
plt.figure()
plt.plot(fpr, tpr, label=f'RF (AUC = {roc_auc:.2f})')
plt.plot([0,1], [0,1], 'k--')
plt.xlabel('False Positive Rate')
plt.ylabel('True Positive Rate')
plt.title('ROC Curve - Phishing Detection')
plt.legend()
plt.show()

#%%
# 6. SHAP Explainability
def _positive_class_shap(values):
    """Return the SHAP value matrix for class index 1.

    TreeExplainer.shap_values has returned two layouts for classifiers,
    depending on the installed SHAP release:

    * a list with one (n_samples, n_features) array per class, or
    * a single (n_samples, n_features, n_classes) array.

    Indexing the second layout with ``[1]`` would select *sample* 1 rather
    than *class* 1, so the layout is checked explicitly.

    Raises:
        ValueError: if ``values`` matches neither layout.
    """
    if isinstance(values, list):
        return values[1]
    if getattr(values, 'ndim', None) == 3:
        return values[:, :, 1]
    raise ValueError(
        "Unexpected SHAP output layout; expected a per-class list or a "
        "3-D (samples, features, classes) array")


explainer = shap.TreeExplainer(clf)
shap_values = explainer.shap_values(X_test)
shap_values_pos = _positive_class_shap(shap_values)

# Global summary plot
shap.summary_plot(shap_values_pos, X_test, show=False)
# savefig does not create missing parent directories.
os.makedirs(os.path.dirname(SHAP_SUMMARY_PLOT) or '.', exist_ok=True)
plt.savefig(SHAP_SUMMARY_PLOT, bbox_inches='tight')
print("SHAP summary plot saved to", SHAP_SUMMARY_PLOT)

#%%
# 7. Local Explanation Example
i = 0  # index of test instance
data_row = X_test.iloc[[i]]  # double brackets keep a one-row DataFrame
# Explain a single prediction relative to the class-1 base value.
shap.force_plot(
    explainer.expected_value[1], shap_values_pos[i], data_row, matplotlib=True)
plt.show()
