{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# 2. Model Training and Evaluation\n",
    "\n",
    "**Goal:** Train a classification model on the preprocessed Kepler data and evaluate its performance."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import pandas as pd\n",
    "import numpy as np\n",
    "import matplotlib.pyplot as plt\n",
    "import seaborn as sns\n",
    "import joblib\n",
    "import os\n",
    "\n",
    "# Adjust path to import from src\n",
    "import sys\n",
    "module_path = os.path.abspath(os.path.join('..'))\n",
    "if module_path not in sys.path:\n",
    "    sys.path.append(module_path)\n",
    "\n",
    "from src import config\n",
    "from src import data_loader\n",
    "from src import preprocessor\n",
    "from src import model_trainer\n",
    "\n",
    "from sklearn.ensemble import RandomForestClassifier # Direct import for experimentation if needed\n",
    "from sklearn.linear_model import LogisticRegression # For a baseline model\n",
    "from sklearn.metrics import classification_report, confusion_matrix, roc_curve, auc, accuracy_score\n",
    "\n",
    "pd.set_option('display.max_columns', 100)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 2.1 Load and Preprocess Data\n",
    "\n",
    "We'll use the functions from our `src` modules."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "raw_df = data_loader.load_data(config.RAW_DATA_FILE)\n",
    "X, y, scaler = None, None, None # Initialize\n",
    "\n",
    "if raw_df is not None:\n",
    "    X, y, scaler = preprocessor.preprocess_data(raw_df)\n",
    "else:\n",
    "    print(\"Failed to load data.\")\n",
    "\n",
    "if X is not None and y is not None:\n",
    "    print(\"\\nData preprocessing successful.\")\n",
    "    print(f\"Shape of X: {X.shape}\")\n",
    "    print(f\"Shape of y: {y.shape}\")\n",
    "    print(f\"Target distribution:\\n{y.value_counts(normalize=True)}\")\n",
    "else:\n",
    "    print(\"Data preprocessing failed. Cannot proceed with model training.\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 2.2 Split Data into Training and Test Sets"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "X_train, X_test, y_train, y_test = None, None, None, None\n",
    "if X is not None and y is not None:\n",
    "    X_train, X_test, y_train, y_test = preprocessor.split_data(X, y)\n",
    "    if X_train is not None:\n",
    "        print(f\"X_train shape: {X_train.shape}, y_train shape: {y_train.shape}\")\n",
    "        print(f\"X_test shape: {X_test.shape}, y_test shape: {y_test.shape}\")\n",
    "else:\n",
    "    print(\"Cannot split data as X or y is None.\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 2.3 Model Training\n",
    "\n",
    "We will train two models for comparison:\n",
    "1.  **Logistic Regression** (as a simple baseline)\n",
    "2.  **Random Forest Classifier** (as configured in `src.model_trainer`)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 2.3.1 Logistic Regression (Baseline)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "log_reg_model = None\n",
    "if X_train is not None and y_train is not None:\n",
    "    print(\"Training Logistic Regression model...\")\n",
    "    log_reg_model = LogisticRegression(random_state=config.RANDOM_STATE, max_iter=1000, class_weight='balanced')\n",
    "    log_reg_model.fit(X_train, y_train)\n",
    "    print(\"Logistic Regression training complete.\")\n",
    "else:\n",
    "    print(\"Training data not available for Logistic Regression.\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 2.3.2 Random Forest Classifier"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "rf_model = None\n",
    "if X_train is not None and y_train is not None:\n",
    "    print(\"\\nTraining Random Forest model using model_trainer.py...\")\n",
    "    # train_model also saves the fitted model as configured in config\n",
    "    rf_model = model_trainer.train_model(X_train, y_train)\n",
    "    if rf_model is not None:\n",
    "        print(\"Random Forest training complete and model saved.\")\n",
    "else:\n",
    "    print(\"Training data not available for Random Forest.\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 2.4 Model Evaluation"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 2.4.1 Evaluate Logistic Regression"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "if log_reg_model is not None and X_test is not None and y_test is not None:\n",
    "    print(\"\\n--- Logistic Regression Evaluation ---\")\n",
    "    y_pred_lr = log_reg_model.predict(X_test)\n",
    "    \n",
    "    print(f\"Accuracy: {accuracy_score(y_test, y_pred_lr):.4f}\")\n",
    "    print(\"Classification Report:\")\n",
    "    print(classification_report(y_test, y_pred_lr, target_names=['Not Exoplanet (0)', 'Exoplanet (1)']))\n",
    "    \n",
    "    cm_lr = confusion_matrix(y_test, y_pred_lr)\n",
    "    plt.figure(figsize=(6,4))\n",
    "    sns.heatmap(cm_lr, annot=True, fmt='d', cmap='Blues', \n",
    "                xticklabels=['Not Exoplanet (0)', 'Exoplanet (1)'], \n",
    "                yticklabels=['Not Exoplanet (0)', 'Exoplanet (1)'])\n",
    "    plt.title('Logistic Regression Confusion Matrix')\n",
    "    plt.xlabel('Predicted')\n",
    "    plt.ylabel('Actual')\n",
    "    plt.show()\n",
    "else:\n",
    "    print(\"Logistic Regression model or test data not available for evaluation.\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 2.4.2 Evaluate Random Forest Classifier\n",
    "\n",
    "We'll use the `evaluate_model` function from `model_trainer.py`, which prints the metrics and saves a confusion matrix plot."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "if rf_model is not None and X_test is not None and y_test is not None:\n",
    "    print(\"\\n--- Random Forest Evaluation (using model_trainer.evaluate_model) ---\")\n",
    "    # This function prints the report and saves the confusion matrix plot\n",
    "    accuracy_rf, report_rf, cm_rf = model_trainer.evaluate_model(rf_model, X_test, y_test)\n",
    "    # evaluate_model saves its plot to the reports/ directory; display the matrix here as well\n",
    "    plt.figure(figsize=(6,4))\n",
    "    sns.heatmap(cm_rf, annot=True, fmt='d', cmap='Blues', \n",
    "                xticklabels=['Not Exoplanet (0)', 'Exoplanet (1)'], \n",
    "                yticklabels=['Not Exoplanet (0)', 'Exoplanet (1)'])\n",
    "    plt.title('Random Forest Confusion Matrix (from notebook)')\n",
    "    plt.xlabel('Predicted')\n",
    "    plt.ylabel('Actual')\n",
    "    plt.show()\n",
    "else:\n",
    "    print(\"Random Forest model or test data not available for evaluation.\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 2.5 ROC Curve and AUC\n",
    "\n",
    "We compare both models on the test set by plotting their ROC curves on the same axes and reporting each AUC."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "plt.figure(figsize=(8, 6))\n",
    "\n",
    "# Logistic Regression ROC\n",
    "if log_reg_model is not None and X_test is not None and y_test is not None:\n",
    "    y_pred_proba_lr = log_reg_model.predict_proba(X_test)[:, 1]\n",
    "    fpr_lr, tpr_lr, _ = roc_curve(y_test, y_pred_proba_lr)\n",
    "    auc_lr = auc(fpr_lr, tpr_lr)\n",
    "    plt.plot(fpr_lr, tpr_lr, label=f'Logistic Regression (AUC = {auc_lr:.2f})')\n",
    "\n",
    "# Random Forest ROC\n",
    "if rf_model is not None and X_test is not None and y_test is not None:\n",
    "    y_pred_proba_rf = rf_model.predict_proba(X_test)[:, 1]\n",
    "    fpr_rf, tpr_rf, _ = roc_curve(y_test, y_pred_proba_rf)\n",
    "    auc_rf = auc(fpr_rf, tpr_rf)\n",
    "    plt.plot(fpr_rf, tpr_rf, label=f'Random Forest (AUC = {auc_rf:.2f})')\n",
    "\n",
    "plt.plot([0, 1], [0, 1], 'k--', label='Random guess')  # Diagonal reference line\n",
    "plt.xlabel('False Positive Rate')\n",
    "plt.ylabel('True Positive Rate')\n",
    "plt.title('ROC Curve Comparison')\n",
    "plt.legend()\n",
    "plt.grid()\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "On this dataset, the Random Forest model generally outperforms the Logistic Regression baseline, as indicated by a higher AUC and better precision and recall, especially for the minority class (Exoplanet)."
   ]
  },
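  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The comparison above rests on AUC together with minority-class precision and recall. As a minimal, self-contained sketch, the next cell gathers those three numbers in one helper. The function `summarize_binary_metrics` and the toy labels and scores are hypothetical and not part of `src`; in this notebook one would instead pass `y_test` together with `y_pred_proba_lr` or `y_pred_proba_rf`. The 0.5 threshold is an assumption and may not suit an imbalanced dataset."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "from sklearn.metrics import precision_score, recall_score, roc_auc_score\n",
    "\n",
    "\n",
    "def summarize_binary_metrics(y_true, y_score, threshold=0.5):\n",
    "    # Hypothetical helper: ROC AUC plus positive-class precision and recall at one threshold\n",
    "    y_true = np.asarray(y_true)\n",
    "    y_score = np.asarray(y_score)\n",
    "    y_pred = (y_score >= threshold).astype(int)  # scores at or above the threshold count as class 1\n",
    "    return {\n",
    "        'roc_auc': roc_auc_score(y_true, y_score),\n",
    "        'precision': precision_score(y_true, y_pred, zero_division=0),\n",
    "        'recall': recall_score(y_true, y_pred, zero_division=0),\n",
    "    }\n",
    "\n",
    "\n",
    "# Toy labels and scores for illustration only (not Kepler data)\n",
    "toy_y = [0, 0, 0, 0, 1, 1]\n",
    "toy_scores = [0.10, 0.40, 0.35, 0.80, 0.65, 0.90]\n",
    "print(summarize_binary_metrics(toy_y, toy_scores))"
   ]
  },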
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 2.6 Feature Importances (for Random Forest)\n",
    "\n",
    "The `model_trainer.py` module provides `get_feature_importances`, which prints the top features and saves a plot."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "if rf_model is not None and X_train is not None:\n",
    "    print(\"\\n--- Random Forest Feature Importances (using model_trainer.get_feature_importances) ---\")\n",
    "    # This function prints top features and saves a plot\n",
    "    df_importances = model_trainer.get_feature_importances(rf_model, X_train.columns, top_n=20)\n",
    "    \n",
    "    # Display the plot here as well\n",
    "    if df_importances is not None:\n",
    "        plt.figure(figsize=(10, 8))\n",
    "        sns.barplot(x='importance', y='feature', data=df_importances, palette='viridis')\n",
    "        plt.title('Top 20 Feature Importances (Random Forest - from notebook)')\n",
    "        plt.tight_layout()\n",
    "        plt.show()\n",
    "else:\n",
    "    print(\"Random Forest model or training features not available for feature importance analysis.\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Features such as `koi_fpflag_nt`, `koi_fpflag_ss`, `koi_model_snr`, and `koi_prad` often rank as important. The first two are false positive flags (not transit-like and stellar eclipse, respectively), while `koi_model_snr` is the transit signal-to-noise ratio and `koi_prad` is the estimated planetary radius."
   ]
  },
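  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The source of `model_trainer.get_feature_importances` is not shown in this notebook. As a rough sketch of what such a helper might do, the next cell reads the impurity-based `feature_importances_` attribute of a fitted scikit-learn forest and returns the top rows as a DataFrame with `feature` and `importance` columns, matching the column names used in the plot above. The function `top_feature_importances` and the synthetic data are assumptions for illustration, not the project's implementation."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import pandas as pd\n",
    "from sklearn.datasets import make_classification\n",
    "from sklearn.ensemble import RandomForestClassifier\n",
    "\n",
    "\n",
    "def top_feature_importances(model, feature_names, top_n=20):\n",
    "    # Hypothetical helper: pair each feature name with its importance, largest first\n",
    "    df = pd.DataFrame({\n",
    "        'feature': list(feature_names),\n",
    "        'importance': model.feature_importances_,\n",
    "    })\n",
    "    return df.sort_values('importance', ascending=False).head(top_n).reset_index(drop=True)\n",
    "\n",
    "\n",
    "# Synthetic stand-in data for illustration only (not Kepler data)\n",
    "X_demo, y_demo = make_classification(n_samples=300, n_features=8, n_informative=3, random_state=0)\n",
    "demo_names = [f'feature_{i}' for i in range(X_demo.shape[1])]\n",
    "demo_model = RandomForestClassifier(n_estimators=50, random_state=0).fit(X_demo, y_demo)\n",
    "print(top_feature_importances(demo_model, demo_names, top_n=5))"
   ]
  },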
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 2.7 Conclusion & Next Steps\n",
    "\n",
    "The Random Forest model performs well on exoplanet classification for this dataset and outperforms the Logistic Regression baseline.\n",
    "\n",
    "**Potential Next Steps:**\n",
    "1.  **Hyperparameter Tuning:** Use GridSearchCV or RandomizedSearchCV to find optimal parameters for the Random Forest model.\n",
    "2.  **Advanced Feature Engineering:** Explore creating new features from existing ones (e.g., ratios, interactions).\n",
    "3.  **Try Other Models:** Experiment with Gradient Boosting (XGBoost, LightGBM), Support Vector Machines, or Neural Networks.\n",
    "4.  **Address Class Imbalance More Explicitly:** Techniques like SMOTE (Synthetic Minority Over-sampling Technique) or ADASYN could be explored if `class_weight='balanced'` isn't sufficient.\n",
    "5.  **Error Analysis:** Deep dive into misclassified samples to understand model weaknesses.\n",
    "6.  **Cross-Validation:** Implement k-fold cross-validation for more robust performance estimation during development.\n",
    "\n",
    "The current pipeline (data loading, preprocessing, training, evaluation, model saving) is set up in the `src` directory and orchestrated by `main.py`."
   ]
  }
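  ,
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Next steps 1 and 6 above (hyperparameter tuning and k-fold cross-validation) can be combined in a single search. The cell below is a minimal sketch using `RandomizedSearchCV` with `StratifiedKFold` on synthetic imbalanced data. The parameter grid, the number of iterations, and the `roc_auc` scoring choice are illustrative assumptions, not values taken from `src.config`; on the real data one would pass `X_train` and `y_train` instead of the stand-in arrays."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from sklearn.datasets import make_classification\n",
    "from sklearn.ensemble import RandomForestClassifier\n",
    "from sklearn.model_selection import RandomizedSearchCV, StratifiedKFold\n",
    "\n",
    "# Synthetic imbalanced stand-in data for illustration only (not Kepler data)\n",
    "X_cv, y_cv = make_classification(n_samples=400, n_features=10, weights=[0.8, 0.2], random_state=0)\n",
    "\n",
    "# Stratified folds keep the class ratio roughly constant in every split\n",
    "cv = StratifiedKFold(n_splits=5, shuffle=True, random_state=0)\n",
    "\n",
    "# Illustrative search space (assumed values, not tuned for Kepler data)\n",
    "param_distributions = {\n",
    "    'n_estimators': [50, 100, 200],\n",
    "    'max_depth': [None, 5, 10],\n",
    "    'min_samples_leaf': [1, 2, 5],\n",
    "}\n",
    "\n",
    "search = RandomizedSearchCV(\n",
    "    RandomForestClassifier(class_weight='balanced', random_state=0),\n",
    "    param_distributions=param_distributions,\n",
    "    n_iter=5,\n",
    "    scoring='roc_auc',\n",
    "    cv=cv,\n",
    "    random_state=0,\n",
    ")\n",
    "search.fit(X_cv, y_cv)\n",
    "\n",
    "print(f'Best params: {search.best_params_}')\n",
    "print(f'Mean cross-validated ROC AUC: {search.best_score_:.3f}')"
   ]
  }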
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.12"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}