In [None]:
{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Entrenamiento del Modelo de Clasificación\n",
    "\n",
    "Este notebook entrena el modelo CNN usando los datos procesados."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "source": [
    "import json\n",
    "import sys\n",
    "\n",
    "import cv2\n",
    "import matplotlib.pyplot as plt\n",
    "import numpy as np\n",
    "from sklearn.model_selection import train_test_split\n",
    "\n",
    "# Make the project root importable before loading the local packages\n",
    "sys.path.append('..')\n",
    "\n",
    "from src.models.classifier import DocumentClassifier\n",
    "from src.features.extractor import FeatureExtractor"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "source": [
    "# Load the per-PDF page labels\n",
    "with open('../data/labels.json', 'r', encoding='utf-8') as f:\n",
    "    labels = json.load(f)\n",
    "\n",
    "# Build the training arrays: one 224x224x1 grayscale image and one binary label per page\n",
    "X = []\n",
    "y = []\n",
    "\n",
    "for pdf_name, pages in labels.items():\n",
    "    for page in range(pages['total_pages']):\n",
    "        img_path = f\"../data/processed/{pdf_name}/page_{page}.png\"\n",
    "        img = cv2.imread(img_path, cv2.IMREAD_GRAYSCALE)\n",
    "        # cv2.imread returns None (it does not raise) when the file is missing or unreadable;\n",
    "        # fail here with the offending path instead of an opaque error inside cv2.resize\n",
    "        if img is None:\n",
    "            raise FileNotFoundError(f\"Could not read page image: {img_path}\")\n",
    "        img = cv2.resize(img, (224, 224))\n",
    "        img = img / 255.0  # scale 8-bit pixel values to [0, 1]\n",
    "        X.append(img[..., np.newaxis])  # append a trailing channel axis\n",
    "        # Label is 1 when the page index is listed in 'target_pages', else 0\n",
    "        y.append(1 if page in pages['target_pages'] else 0)\n",
    "\n",
    "X = np.array(X)\n",
    "y = np.array(y)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "source": [
    "# Hold out 20% of the samples for validation (fixed seed so the split is repeatable)\n",
    "X_train, X_val, y_train, y_val = train_test_split(\n",
    "    X,\n",
    "    y,\n",
    "    test_size=0.2,\n",
    "    random_state=42,\n",
    ")\n",
    "\n",
    "# Build and compile the classifier, then fit it for 20 epochs;\n",
    "# the returned history is kept for later use\n",
    "classifier = DocumentClassifier()\n",
    "classifier.compile_model()\n",
    "\n",
    "history = classifier.train(X_train, y_train, validation_data=(X_val, y_val), epochs=20)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "source": [
    "# Plot the training curves: loss on the left panel, accuracy on the right\n",
    "fig, (ax_loss, ax_acc) = plt.subplots(1, 2, figsize=(12, 4))\n",
    "\n",
    "ax_loss.plot(history.history['loss'], label='Training Loss')\n",
    "ax_loss.plot(history.history['val_loss'], label='Validation Loss')\n",
    "ax_loss.set_title('Model Loss')\n",
    "ax_loss.set_xlabel('Epoch')\n",
    "ax_loss.set_ylabel('Loss')\n",
    "ax_loss.legend()\n",
    "\n",
    "ax_acc.plot(history.history['accuracy'], label='Training Accuracy')\n",
    "ax_acc.plot(history.history['val_accuracy'], label='Validation Accuracy')\n",
    "ax_acc.set_title('Model Accuracy')\n",
    "ax_acc.set_xlabel('Epoch')\n",
    "ax_acc.set_ylabel('Accuracy')\n",
    "ax_acc.legend()\n",
    "\n",
    "fig.tight_layout()\n",
    "plt.show()"
   ]
  }
 ]
}