In [None]:
{
 "cells": [
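  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Policy Gradient Demonstration\n",
    "This notebook trains a REINFORCE (Monte-Carlo policy gradient) agent on `CartPole-v1`: it rolls out full episodes, weights each action's log-probability by its normalized discounted return, plots the learning curve, saves the trained weights, and then evaluates and visualizes the trained policy."
   ]
  },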
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Import required libraries\n",
    "import gymnasium as gym\n",
    "import numpy as np\n",
    "import torch\n",
    "import torch.nn as nn\n",
    "import torch.optim as optim\n",
    "import matplotlib.pyplot as plt\n",
    "from agents.policy_agent import PolicyAgent"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Initialize environment and agent\n",
    "SEED = 42  # fixed seed so Restart & Run All gives repeatable results\n",
    "\n",
    "# Seed torch/numpy before building the agent so its weight init is repeatable\n",
    "# (assumes PolicyAgent draws from these global RNGs - TODO confirm)\n",
    "torch.manual_seed(SEED)\n",
    "np.random.seed(SEED)\n",
    "\n",
    "env = gym.make(\"CartPole-v1\")\n",
    "# Seeding the first reset fixes the env's RNG; later unseeded env.reset() calls\n",
    "# (e.g. in the training loop) continue from it deterministically.\n",
    "env.reset(seed=SEED)\n",
    "state_size = env.observation_space.shape[0]\n",
    "action_size = env.action_space.n\n",
    "agent = PolicyAgent(state_size, action_size)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Training parameters\n",
    "num_episodes = 1000\n",
    "gamma = 0.99  # discount factor for the Monte-Carlo returns\n",
    "learning_rate = 0.01\n",
    "\n",
    "# Training loop (REINFORCE: one policy update per complete episode)\n",
    "rewards = []  # total undiscounted reward per episode, plotted below\n",
    "optimizer = optim.Adam(agent.policy_network.parameters(), lr=learning_rate)\n",
    "\n",
    "for episode in range(num_episodes):\n",
    "    state, _ = env.reset()\n",
    "    log_probs = []\n",
    "    rewards_episode = []\n",
    "    done = False\n",
    "    \n",
    "    # Roll out one full episode with the current policy\n",
    "    while not done:\n",
    "        state_tensor = torch.FloatTensor(state).unsqueeze(0)  # add a batch dimension\n",
    "        action, log_prob = agent.act(state_tensor)\n",
    "        next_state, reward, terminated, truncated, _ = env.step(action)\n",
    "        done = terminated or truncated\n",
    "        \n",
    "        log_probs.append(log_prob)\n",
    "        rewards_episode.append(reward)\n",
    "        state = next_state\n",
    "    \n",
    "    # Discounted returns G_t = r_t + gamma * G_{t+1}, accumulated backwards.\n",
    "    # Append-then-reverse avoids the O(n^2) cost of repeated list.insert(0, ...).\n",
    "    discounted_rewards = []\n",
    "    cumulative_reward = 0\n",
    "    for r in reversed(rewards_episode):\n",
    "        cumulative_reward = r + gamma * cumulative_reward\n",
    "        discounted_rewards.append(cumulative_reward)\n",
    "    discounted_rewards.reverse()\n",
    "    \n",
    "    # Normalize returns to reduce gradient variance. The sample std of a single\n",
    "    # value is NaN (which would poison the weights), so skip normalizing then.\n",
    "    discounted_rewards = torch.FloatTensor(discounted_rewards)\n",
    "    if len(discounted_rewards) > 1:\n",
    "        discounted_rewards = (discounted_rewards - discounted_rewards.mean()) / (discounted_rewards.std() + 1e-9)\n",
    "    \n",
    "    # Policy loss -sum_t log pi(a_t|s_t) * G_t. torch.stack (unlike torch.cat) also\n",
    "    # accepts 0-dim log-prob tensors, so this works whether act() returns shape () or (1,).\n",
    "    policy_loss = torch.stack([\n",
    "        -log_prob * ret for log_prob, ret in zip(log_probs, discounted_rewards)\n",
    "    ]).sum()\n",
    "    \n",
    "    # Update policy\n",
    "    optimizer.zero_grad()\n",
    "    policy_loss.backward()\n",
    "    optimizer.step()\n",
    "    \n",
    "    total_reward = sum(rewards_episode)\n",
    "    rewards.append(total_reward)\n",
    "    \n",
    "    # Log every 10th episode; the 1-based label matches the cadence (10, 20, ...)\n",
    "    if (episode + 1) % 10 == 0:\n",
    "        print(f\"Episode: {episode + 1}, Total Reward: {total_reward}\")\n",
    "\n",
    "env.close()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Plot training results: raw per-episode reward plus a moving average,\n",
    "# because per-episode rewards are noisy and the trend is what matters\n",
    "ma_window = 50\n",
    "fig, ax = plt.subplots(figsize=(10, 6))\n",
    "ax.plot(rewards, alpha=0.4, label=\"Episode reward\")\n",
    "if len(rewards) >= ma_window:\n",
    "    moving_avg = np.convolve(rewards, np.ones(ma_window) / ma_window, mode=\"valid\")\n",
    "    # the 'valid' average of episodes [i - ma_window + 1, i] is plotted at episode i\n",
    "    ax.plot(np.arange(ma_window - 1, len(rewards)), moving_avg, label=f\"{ma_window}-episode moving average\")\n",
    "ax.set_title(\"Policy Gradient Training Progress\")\n",
    "ax.set_xlabel(\"Episode\")\n",
    "ax.set_ylabel(\"Total Reward\")\n",
    "ax.grid(True)\n",
    "ax.legend()\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Save trained model: only the network's learned parameters (its state_dict)\n",
    "# are written, not the agent object itself\n",
    "learned_weights = agent.policy_network.state_dict()\n",
    "torch.save(learned_weights, \"policy_gradient_model.pth\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Evaluate the trained agent: reload the saved weights and run 100 fresh episodes.\n",
    "# Self-contained on purpose - it uses only the agent and gym APIs used above.\n",
    "# weights_only=True restricts unpickling to tensors/state_dict data.\n",
    "agent.policy_network.load_state_dict(torch.load(\"policy_gradient_model.pth\", weights_only=True))\n",
    "\n",
    "eval_env = gym.make(\"CartPole-v1\")\n",
    "eval_returns = []\n",
    "with torch.no_grad():  # inference only; no autograd graph needed\n",
    "    for _ in range(100):\n",
    "        state, _ = eval_env.reset()\n",
    "        episode_return, done = 0.0, False\n",
    "        while not done:\n",
    "            action, _ = agent.act(torch.FloatTensor(state).unsqueeze(0))\n",
    "            state, reward, terminated, truncated, _ = eval_env.step(action)\n",
    "            episode_return += reward\n",
    "            done = terminated or truncated\n",
    "        eval_returns.append(episode_return)\n",
    "eval_env.close()\n",
    "\n",
    "print(f\"Mean reward over {len(eval_returns)} episodes: {np.mean(eval_returns):.1f} +/- {np.std(eval_returns):.1f}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Visualize the agent's performance: play 5 rendered episodes with the saved weights.\n",
    "# render_mode=\"human\" opens a pygame window, so this needs a local display.\n",
    "agent.policy_network.load_state_dict(torch.load(\"policy_gradient_model.pth\", weights_only=True))\n",
    "\n",
    "vis_env = gym.make(\"CartPole-v1\", render_mode=\"human\")\n",
    "with torch.no_grad():  # inference only; no autograd graph needed\n",
    "    for vis_episode in range(5):\n",
    "        state, _ = vis_env.reset()\n",
    "        done, steps = False, 0\n",
    "        while not done:\n",
    "            action, _ = agent.act(torch.FloatTensor(state).unsqueeze(0))\n",
    "            state, _, terminated, truncated, _ = vis_env.step(action)\n",
    "            done = terminated or truncated\n",
    "            steps += 1\n",
    "        print(f\"Visualization episode {vis_episode + 1}: {steps} steps\")\n",
    "vis_env.close()"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.0"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}