In [None]:
{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Q-Learning Demonstration\n",
    "This notebook provides a detailed walkthrough of our Q-Learning implementation."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Import required libraries\n",
    "import gymnasium as gym\n",
    "import numpy as np\n",
    "import torch\n",
    "import matplotlib.pyplot as plt\n",
    "from agents.q_agent import QAgent"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Initialize environment and agent\n",
    "SEED = 42  # fixed seed so Restart & Run All reproduces the same run\n",
    "\n",
    "# Seed every RNG the agent might draw from *before* it is constructed,\n",
    "# so network weight initialization is reproducible too.\n",
    "random.seed(SEED)\n",
    "np.random.seed(SEED)\n",
    "torch.manual_seed(SEED)\n",
    "\n",
    "env = gym.make(\"CartPole-v1\")\n",
    "# Seeding reset() once seeds the environment's RNG; later unseeded resets continue from it.\n",
    "env.reset(seed=SEED)\n",
    "env.action_space.seed(SEED)\n",
    "\n",
    "state_size = env.observation_space.shape[0]\n",
    "action_size = env.action_space.n\n",
    "agent = QAgent(state_size, action_size)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Training parameters\n",
    "num_episodes = 1000\n",
    "batch_size = 64\n",
    "gamma = 0.99  # discount factor passed to agent.train\n",
    "epsilon_start = 1.0\n",
    "epsilon_end = 0.01\n",
    "epsilon_decay = 0.995  # multiplicative decay applied once per episode\n",
    "log_every = 10  # print progress every `log_every` episodes\n",
    "\n",
    "# Training loop\n",
    "rewards = []  # total reward per episode, plotted below\n",
    "epsilon = epsilon_start\n",
    "\n",
    "for episode in range(num_episodes):\n",
    "    state, _ = env.reset()\n",
    "    total_reward = 0.0\n",
    "    done = False\n",
    "\n",
    "    while not done:\n",
    "        action = agent.act(state, epsilon)\n",
    "        next_state, reward, terminated, truncated, _ = env.step(action)\n",
    "        done = terminated or truncated\n",
    "\n",
    "        # Store `terminated`, not `done`: a time-limit truncation (CartPole-v1 caps\n",
    "        # episodes at 500 steps) is not a true terminal state. QAgent presumably uses\n",
    "        # this flag to stop bootstrapping from next_state -- confirm in QAgent.train.\n",
    "        agent.remember(state, action, reward, next_state, terminated)\n",
    "        agent.train(batch_size, gamma)\n",
    "\n",
    "        state = next_state\n",
    "        total_reward += reward\n",
    "\n",
    "    rewards.append(total_reward)\n",
    "    epsilon = max(epsilon_end, epsilon * epsilon_decay)\n",
    "\n",
    "    # Log after episodes 10, 20, ..., num_episodes (1-based, so the final episode is included).\n",
    "    if (episode + 1) % log_every == 0:\n",
    "        print(f\"Episode: {episode + 1}, Total Reward: {total_reward}, Epsilon: {epsilon:.2f}\")\n",
    "\n",
    "env.close()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Plot training results: raw per-episode reward plus a moving average,\n",
    "# since single-episode returns are very noisy under epsilon-greedy exploration.\n",
    "window = 50  # moving-average window, in episodes\n",
    "\n",
    "fig, ax = plt.subplots(figsize=(10, 6))\n",
    "ax.plot(rewards, alpha=0.4, label=\"Episode reward\")\n",
    "if len(rewards) >= window:\n",
    "    moving_avg = np.convolve(rewards, np.ones(window) / window, mode=\"valid\")\n",
    "    # Align each average with the last episode in its window.\n",
    "    ax.plot(np.arange(window - 1, len(rewards)), moving_avg, label=f\"{window}-episode moving average\")\n",
    "ax.set_title(\"Q-Learning Training Progress\")\n",
    "ax.set_xlabel(\"Episode\")\n",
    "ax.set_ylabel(\"Total Reward\")\n",
    "ax.grid(True)\n",
    "ax.legend()\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Persist only the Q-network's weights (its state_dict), not the whole agent object\n",
    "model_path = \"q_learning_model.pth\"\n",
    "q_weights = agent.q_network.state_dict()\n",
    "torch.save(q_weights, model_path)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Evaluate the trained agent with a greedy policy on fresh, seeded episodes\n",
    "num_eval_episodes = 100\n",
    "\n",
    "# Reload the saved weights to confirm the checkpoint round-trips; weights_only=True\n",
    "# restricts unpickling to tensors and plain containers.\n",
    "agent.q_network.load_state_dict(torch.load(\"q_learning_model.pth\", weights_only=True))\n",
    "\n",
    "eval_env = gym.make(env.spec.id)  # same environment id as training\n",
    "eval_rewards = []\n",
    "for episode_idx in range(num_eval_episodes):\n",
    "    state, _ = eval_env.reset(seed=episode_idx)  # per-episode seed -> reproducible evaluation\n",
    "    episode_reward = 0.0\n",
    "    done = False\n",
    "    while not done:\n",
    "        # Assumes QAgent.act(state, epsilon) is epsilon-greedy, so 0.0 selects the greedy action.\n",
    "        action = agent.act(state, 0.0)\n",
    "        state, reward, terminated, truncated, _ = eval_env.step(action)\n",
    "        episode_reward += reward\n",
    "        done = terminated or truncated\n",
    "    eval_rewards.append(episode_reward)\n",
    "eval_env.close()\n",
    "\n",
    "print(f\"Mean reward over {num_eval_episodes} episodes: {np.mean(eval_rewards):.2f} +/- {np.std(eval_rewards):.2f}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Visualize the greedy agent: roll out a few episodes with RGB rendering and\n",
    "# show snapshots from the first one (rendering CartPole requires pygame).\n",
    "num_vis_episodes = 5\n",
    "frame_stride = 25  # keep every 25th frame so memory stays bounded on long episodes\n",
    "max_panels = 6  # maximum number of snapshots drawn\n",
    "\n",
    "vis_env = gym.make(env.spec.id, render_mode=\"rgb_array\")\n",
    "vis_rewards = []\n",
    "snapshots = []  # (step, frame) pairs from the first episode\n",
    "for episode_idx in range(num_vis_episodes):\n",
    "    state, _ = vis_env.reset(seed=episode_idx)\n",
    "    episode_reward = 0.0\n",
    "    step = 0\n",
    "    done = False\n",
    "    while not done:\n",
    "        if episode_idx == 0 and step % frame_stride == 0:\n",
    "            snapshots.append((step, vis_env.render()))\n",
    "        # Assumes QAgent.act(state, epsilon) is epsilon-greedy, so 0.0 selects the greedy action.\n",
    "        action = agent.act(state, 0.0)\n",
    "        state, reward, terminated, truncated, _ = vis_env.step(action)\n",
    "        episode_reward += reward\n",
    "        step += 1\n",
    "        done = terminated or truncated\n",
    "    vis_rewards.append(episode_reward)\n",
    "vis_env.close()\n",
    "print(f\"Rewards over {num_vis_episodes} episodes: {vis_rewards}\")\n",
    "\n",
    "# Pick up to `max_panels` evenly spaced snapshots (step 0 is always captured, so this is non-empty).\n",
    "picks = np.linspace(0, len(snapshots) - 1, num=min(max_panels, len(snapshots)), dtype=int)\n",
    "fig, axes = plt.subplots(1, len(picks), figsize=(3 * len(picks), 3), squeeze=False)\n",
    "for ax, snap_idx in zip(axes[0], picks):\n",
    "    snap_step, frame = snapshots[snap_idx]\n",
    "    ax.imshow(frame)\n",
    "    ax.set_title(f\"Step {snap_step}\")\n",
    "    ax.axis(\"off\")\n",
    "fig.suptitle(\"Greedy agent, episode 1\")\n",
    "plt.show()"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.0"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}