In [None]:
{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# RL Agent Training\n",
    "Train a Q-learning Hangman agent guided by letter probabilities from an HMM oracle."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import sys\n",
    "import pickle\n",
    "import random\n",
    "import numpy as np\n",
    "import matplotlib.pyplot as plt\n",
    "from tqdm.notebook import tqdm\n",
    "from pathlib import Path\n",
    "\n",
    "sys.path.append('..')\n",
    "\n",
    "from src.env.hangman_env import HangmanEnv\n",
    "from src.hmm.oracle import HMMOracle\n",
    "from src.hmm.emissions import EmissionBuilder\n",
    "from src.rl.q_learning import QLearningAgent"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 1. Load Data and Create Oracle"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Read the preprocessed word lists from the pickle file\n",
    "with open('../data/processed/words_by_length.pkl', 'rb') as words_file:\n",
    "    words_by_length = pickle.load(words_file)\n",
    "\n",
    "# Load the HMM emissions and build the oracle from them plus the word lists\n",
    "emissions = EmissionBuilder.load('../models/hmm/emissions.pkl')\n",
    "oracle = HMMOracle(emissions, words_by_length)\n",
    "\n",
    "# Flatten every value of words_by_length into a single training vocabulary\n",
    "all_words = [word for bucket in words_by_length.values() for word in bucket]\n",
    "\n",
    "print(f\"Training vocabulary: {len(all_words)} words\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 2. Initialize Agent"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Hyperparameters handed to QLearningAgent as keyword arguments\n",
    "agent_params = dict(\n",
    "    learning_rate=0.1,\n",
    "    discount_factor=0.95,\n",
    "    epsilon=0.3,\n",
    "    epsilon_decay=0.9995,\n",
    "    epsilon_min=0.01,\n",
    ")\n",
    "agent = QLearningAgent(**agent_params)\n",
    "\n",
    "print(f\"Agent initialized with epsilon={agent.epsilon}\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 3. Training Loop"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Run configuration\n",
    "num_episodes = 5000\n",
    "max_lives = 6      # forwarded to HangmanEnv(max_lives=...)\n",
    "win_window = 100   # episodes per rolling win-rate sample\n",
    "log_every = 500    # episodes between progress reports\n",
    "seed = 42\n",
    "\n",
    "# Seed the RNGs so the sequence of sampled words is repeatable across runs.\n",
    "# NOTE(review): assumes the agent's exploration draws from `random` or\n",
    "# `np.random`; confirm against QLearningAgent.\n",
    "random.seed(seed)\n",
    "np.random.seed(seed)\n",
    "\n",
    "# Per-episode metrics; win_rate_history gets one entry per episode once\n",
    "# `win_window` episodes have been played.\n",
    "episode_rewards = []\n",
    "episode_wins = []\n",
    "win_rate_history = []\n",
    "\n",
    "for episode in tqdm(range(num_episodes)):\n",
    "    # Each episode plays one randomly sampled word in a fresh environment\n",
    "    word = random.choice(all_words)\n",
    "    env = HangmanEnv(word, max_lives=max_lives)\n",
    "    state = env.reset()\n",
    "\n",
    "    total_reward = 0\n",
    "    done = False\n",
    "    # Reset per episode: if the loop below never runs, the win check must not\n",
    "    # read the previous episode's info (or an undefined name on episode 0).\n",
    "    info = {}\n",
    "\n",
    "    while not done:\n",
    "        # Oracle letter probabilities for the current state\n",
    "        hmm_probs = oracle.get_letter_probs(state['mask'], state['guessed_letters'])\n",
    "\n",
    "        # Let the agent pick among the environment's valid actions\n",
    "        valid_actions = env.get_valid_actions()\n",
    "        action = agent.select_action(state, hmm_probs, valid_actions, training=True)\n",
    "\n",
    "        next_state, reward, done, info = env.step(action)\n",
    "\n",
    "        # The update also receives the oracle probabilities for the next state\n",
    "        next_hmm_probs = oracle.get_letter_probs(\n",
    "            next_state['mask'],\n",
    "            next_state['guessed_letters']\n",
    "        )\n",
    "\n",
    "        agent.update(state, action, reward, next_state, done, hmm_probs, next_hmm_probs)\n",
    "\n",
    "        state = next_state\n",
    "        total_reward += reward\n",
    "\n",
    "    # Epsilon is decayed once per episode, not once per step\n",
    "    agent.decay_epsilon()\n",
    "\n",
    "    # Track metrics\n",
    "    episode_rewards.append(total_reward)\n",
    "    episode_wins.append(1 if info.get('win', False) else 0)\n",
    "\n",
    "    # Rolling win rate over the last `win_window` episodes, once that many exist\n",
    "    if episode + 1 >= win_window:\n",
    "        win_rate = sum(episode_wins[-win_window:]) / win_window\n",
    "        win_rate_history.append(win_rate)\n",
    "\n",
    "    # Log progress\n",
    "    if (episode + 1) % log_every == 0:\n",
    "        recent_wins = sum(episode_wins[-log_every:])\n",
    "        recent_reward = sum(episode_rewards[-log_every:]) / log_every\n",
    "        print(f\"\\nEpisode {episode+1}: Win Rate={recent_wins/log_every:.3f}, \"\n",
    "              f\"Avg Reward={recent_reward:.2f}, Epsilon={agent.epsilon:.4f}\")\n",
    "\n",
    "print(\"\\nTraining complete!\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 4. Plot Training Curves"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "fig, axes = plt.subplots(2, 1, figsize=(12, 10))\n",
    "\n",
    "# Reward curve: moving average over at most 100 episodes. The window is\n",
    "# clamped to the run length because np.convolve swaps its arguments when the\n",
    "# kernel is longer than the data, which would no longer be a moving average.\n",
    "window = max(1, min(100, len(episode_rewards)))\n",
    "smoothed_rewards = np.convolve(episode_rewards, np.ones(window)/window, mode='valid')\n",
    "# Each smoothed point averages the `window` episodes ending at that episode,\n",
    "# so place it at the 1-based number of the last episode it covers.\n",
    "reward_episodes = np.arange(window, window + len(smoothed_rewards))\n",
    "\n",
    "axes[0].plot(reward_episodes, smoothed_rewards, linewidth=2)\n",
    "axes[0].set_xlabel('Episode')\n",
    "axes[0].set_ylabel('Average Reward')\n",
    "axes[0].set_title(f'Training Reward Curve ({window}-episode moving average)')\n",
    "axes[0].grid(alpha=0.3)\n",
    "\n",
    "# Win rate curve: win_rate_history only starts once the first rolling window\n",
    "# is full and then gains one entry per episode, so the length difference\n",
    "# gives the 1-based episode number of its first entry.\n",
    "first_win_episode = len(episode_wins) - len(win_rate_history) + 1\n",
    "win_episodes = np.arange(first_win_episode, first_win_episode + len(win_rate_history))\n",
    "\n",
    "axes[1].plot(win_episodes, win_rate_history, linewidth=2, color='green')\n",
    "axes[1].axhline(y=0.9, color='r', linestyle='--', label='90% target')\n",
    "axes[1].set_xlabel('Episode')\n",
    "axes[1].set_ylabel('Win Rate')\n",
    "axes[1].set_title('Rolling Win Rate (100-episode window)')\n",
    "axes[1].legend()\n",
    "axes[1].grid(alpha=0.3)\n",
    "\n",
    "plt.tight_layout()\n",
    "plt.show()\n",
    "\n",
    "# Report over the episodes actually available (at most the last 1000), so a\n",
    "# shorter run is not divided by a fixed 1000.\n",
    "final_wins = episode_wins[-1000:]\n",
    "final_win_rate = sum(final_wins) / max(len(final_wins), 1)\n",
    "print(f\"\\nFinal {len(final_wins)} episodes win rate: {final_win_rate:.3f}\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 5. Save Trained Agent"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import json\n",
    "\n",
    "\n",
    "def _json_default(value):\n",
    "    \"\"\"Fallback for json.dump: convert NumPy scalars to builtin Python numbers.\n",
    "\n",
    "    Raises TypeError for any other unsupported type, as json itself would.\n",
    "    \"\"\"\n",
    "    if isinstance(value, np.generic):\n",
    "        return value.item()\n",
    "    raise TypeError(f\"Object of type {type(value).__name__} is not JSON serializable\")\n",
    "\n",
    "\n",
    "# Save agent (create the output directory first if it does not exist)\n",
    "Path('../models/rl').mkdir(parents=True, exist_ok=True)\n",
    "agent.save('../models/rl/q_table_final.pkl')\n",
    "\n",
    "# Save training history\n",
    "history = {\n",
    "    'rewards': episode_rewards,\n",
    "    'wins': episode_wins,\n",
    "    'win_rate_history': win_rate_history\n",
    "}\n",
    "\n",
    "# `default` is only consulted for values json cannot encode natively, so\n",
    "# builtin ints/floats are written exactly as before.\n",
    "# NOTE(review): whether env rewards can be NumPy scalars is not visible here;\n",
    "# without the hook such values would make json.dump raise mid-write.\n",
    "with open('../models/rl/training_history.json', 'w') as f:\n",
    "    json.dump(history, f, default=_json_default)\n",
    "\n",
    "print(\"Agent and training history saved!\")"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "name": "python",
   "version": "3.12.6"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
