In [None]:
{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Model Development - Customer Churn Prediction"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "source": [
    "import os\n",
    "\n",
    "import joblib\n",
    "import matplotlib.pyplot as plt\n",
    "import numpy as np\n",
    "import pandas as pd\n",
    "import seaborn as sns\n",
    "import xgboost as xgb\n",
    "from sklearn.ensemble import RandomForestClassifier\n",
    "from sklearn.linear_model import LogisticRegression\n",
    "from sklearn.metrics import classification_report, confusion_matrix\n",
    "from sklearn.model_selection import cross_val_score"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "source": [
    "# Load the processed train/test splits from ../data/processed.\n",
    "# pd.read_csv always returns a DataFrame, so for the target files take the\n",
    "# first column directly to get a 1-D Series for scikit-learn.\n",
    "X_train = pd.read_csv('../data/processed/X_train.csv')\n",
    "X_test = pd.read_csv('../data/processed/X_test.csv')\n",
    "y_train = pd.read_csv('../data/processed/y_train.csv').iloc[:, 0]\n",
    "y_test = pd.read_csv('../data/processed/y_test.csv').iloc[:, 0]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "source": [
    "# Candidate classifiers, each constructed with a fixed random_state\n",
    "models = {\n",
    "    'Random Forest': RandomForestClassifier(n_estimators=100, random_state=42),\n",
    "    'XGBoost': xgb.XGBClassifier(random_state=42),\n",
    "    'Logistic Regression': LogisticRegression(random_state=42)\n",
    "}\n",
    "\n",
    "# Fit every candidate on the training split, then score it on the held-out test split.\n",
    "# `results` maps model name -> raw test predictions plus the classification report as a dict.\n",
    "results = {}\n",
    "for model_name, estimator in models.items():\n",
    "    print(f\"\\nTraining {model_name}...\")\n",
    "    estimator.fit(X_train, y_train)\n",
    "    test_preds = estimator.predict(X_test)\n",
    "\n",
    "    report_dict = classification_report(y_test, test_preds, output_dict=True)\n",
    "    results[model_name] = {'predictions': test_preds, 'report': report_dict}\n",
    "\n",
    "    print(f\"\\nResults for {model_name}:\")\n",
    "    print(classification_report(y_test, test_preds))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "source": [
    "# Compare models on the weighted-average precision / recall / F1 from each test-set report.\n",
    "# NOTE(review): weighted averages are pulled toward the larger class; if churners are the\n",
    "# minority, check their per-class recall in the reports printed above as well.\n",
    "metrics = ['precision', 'recall', 'f1-score']\n",
    "model_comparison = pd.DataFrame({\n",
    "    name: {metric: results[name]['report']['weighted avg'][metric]\n",
    "           for metric in metrics}\n",
    "    for name in models.keys()\n",
    "}).T\n",
    "\n",
    "# DataFrame.plot() without `ax=` creates a brand-new figure of its own, so pass an\n",
    "# explicit Axes to get the requested size instead of an extra empty figure.\n",
    "fig, ax = plt.subplots(figsize=(10, 6))\n",
    "model_comparison.plot(kind='bar', ax=ax, rot=45)\n",
    "ax.set_title('Model Comparison')\n",
    "ax.set_ylabel('Score')\n",
    "fig.tight_layout()\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "source": [
    "# Rank the fitted Random Forest's feature importances, largest first, and plot them\n",
    "forest = models['Random Forest']\n",
    "feature_importance = (\n",
    "    pd.DataFrame({'feature': X_train.columns, 'importance': forest.feature_importances_})\n",
    "    .sort_values('importance', ascending=False)\n",
    ")\n",
    "\n",
    "fig, ax = plt.subplots(figsize=(12, 6))\n",
    "sns.barplot(data=feature_importance, x='importance', y='feature', ax=ax)\n",
    "ax.set_title('Feature Importance (Random Forest)')\n",
    "fig.tight_layout()\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "source": [
    "# Select the model to persist.\n",
    "# Ranking by the test-set scores would leak the test split into model selection and make\n",
    "# the reported test scores optimistically biased, so rank candidates by mean 5-fold\n",
    "# cross-validated weighted F1 on the training split (same averaging as the comparison\n",
    "# table). cross_val_score clones each estimator, so the fitted models are not modified.\n",
    "cv_f1 = pd.Series({\n",
    "    name: cross_val_score(model, X_train, y_train, cv=5, scoring='f1_weighted').mean()\n",
    "    for name, model in models.items()\n",
    "})\n",
    "print('Mean 5-fold CV weighted F1 on the training split:')\n",
    "print(cv_f1.round(4).to_string())\n",
    "\n",
    "best_model_name = cv_f1.idxmax()\n",
    "# This estimator was already fitted on the full training split in the training cell\n",
    "best_model = models[best_model_name]\n",
    "\n",
    "# Create models directory if it doesn't exist\n",
    "os.makedirs('../models', exist_ok=True)\n",
    "\n",
    "# Save the model\n",
    "joblib.dump(best_model, '../models/best_model.pkl')\n",
    "print(f\"Best model ({best_model_name}) saved to '../models/best_model.pkl'\")"
   ]
  }
 ]
}