In [None]:
{
 "cells": [
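  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Multi-Agent Reinforcement Learning Demonstration\n",
    "This notebook demonstrates a simple multi-agent setup in which two agents learn to collaborate on Gymnasium's `Pendulum-v1` task: it trains the agents, plots the training reward, saves the learned actor and critic weights, and then evaluates and visualizes the result."
   ]
  },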
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Import required libraries\n",
    "import random\n",
    "\n",
    "import gymnasium as gym\n",
    "import matplotlib.pyplot as plt\n",
    "import numpy as np\n",
    "import torch\n",
    "\n",
    "from agents.multi_agent import MultiAgentSystem"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Seed the random number generators *before* the agents are built, so that\n",
    "# Restart Kernel -> Run All is repeatable (to the extent MultiAgentSystem\n",
    "# draws its randomness from these generators -- confirm against its source).\n",
    "SEED = 42\n",
    "random.seed(SEED)\n",
    "np.random.seed(SEED)\n",
    "torch.manual_seed(SEED)\n",
    "\n",
    "# Initialize environment and agents\n",
    "env = gym.make(\"Pendulum-v1\")\n",
    "env.action_space.seed(SEED)\n",
    "# Seeding through reset() fixes the environment's own RNG; the unseeded\n",
    "# reset() calls made later keep drawing from that seeded generator.\n",
    "env.reset(seed=SEED)\n",
    "\n",
    "# Sizes are read from the environment's observation and action spaces.\n",
    "state_size = env.observation_space.shape[0]\n",
    "action_size = env.action_space.shape[0]\n",
    "num_agents = 2\n",
    "agent_system = MultiAgentSystem(state_size, action_size, num_agents)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Training parameters\n",
    "num_episodes = 1000\n",
    "batch_size = 64      # forwarded to agent_system.train\n",
    "gamma = 0.99         # forwarded to agent_system.train (presumably the discount factor -- confirm)\n",
    "tau = 0.01           # forwarded to agent_system.train (presumably the soft-update rate -- confirm)\n",
    "noise_scale = 1.0    # noise scale handed to agent_system.act; shrinks once per episode\n",
    "noise_decay = 0.995  # per-episode multiplier applied to noise_scale (floored at 0.01 below)\n",
    "\n",
    "# Training loop\n",
    "rewards = []  # mean per-agent return of each episode, in episode order\n",
    "\n",
    "for episode in range(num_episodes):\n",
    "    state, _ = env.reset()\n",
    "    episode_rewards = np.zeros(num_agents)\n",
    "    done = False\n",
    "    \n",
    "    while not done:\n",
    "        actions = agent_system.act(state, noise_scale=noise_scale)\n",
    "        next_state, reward, terminated, truncated, _ = env.step(actions)\n",
    "        # The episode ends on either signal ...\n",
    "        done = terminated or truncated\n",
    "        \n",
    "        # ... but only a true termination is stored as terminal. A time-limit\n",
    "        # truncation is not an end state of the task, so it must not be recorded\n",
    "        # as one. NOTE(review): assumes remember() uses this flag to mask\n",
    "        # bootstrapping in the critic target -- confirm against MultiAgentSystem.\n",
    "        agent_system.remember(state, actions, reward, next_state, terminated)\n",
    "        agent_system.train(batch_size, gamma, tau)\n",
    "        \n",
    "        state = next_state\n",
    "        # The environment's reward is added to every agent's running total.\n",
    "        episode_rewards += reward\n",
    "    \n",
    "    episode_mean_reward = np.mean(episode_rewards)\n",
    "    rewards.append(episode_mean_reward)\n",
    "    # Decay the noise scale for the next episode, never dropping below 0.01.\n",
    "    noise_scale = max(0.01, noise_scale * noise_decay)\n",
    "    \n",
    "    if episode % 10 == 0:\n",
    "        print(f\"Episode: {episode + 1}, Average Reward: {episode_mean_reward:.2f}, Noise Scale: {noise_scale:.2f}\")\n",
    "\n",
    "env.close()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Plot training results\n",
    "# One point per episode: the values appended to `rewards` by the training loop.\n",
    "fig, ax = plt.subplots(figsize=(10, 6))\n",
    "ax.plot(rewards)\n",
    "ax.set_title(\"Multi-Agent Training Progress\")\n",
    "ax.set_xlabel(\"Episode\")\n",
    "ax.set_ylabel(\"Average Reward\")\n",
    "ax.grid(True)\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Save trained models\n",
    "# Each agent gets one actor file and one critic file, numbered by the\n",
    "# agent's position in agent_system.agents.\n",
    "for i, agent in enumerate(agent_system.agents):\n",
    "    actor_weights = agent.actor.state_dict()\n",
    "    critic_weights = agent.critic.state_dict()\n",
    "    torch.save(actor_weights, f\"multi_agent_actor_{i}.pth\")\n",
    "    torch.save(critic_weights, f\"multi_agent_critic_{i}.pth\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Evaluate the trained agents\n",
    "# Reload the saved weights so the evaluation runs on exactly what is on disk.\n",
    "for i, agent in enumerate(agent_system.agents):\n",
    "    agent.actor.load_state_dict(torch.load(f\"multi_agent_actor_{i}.pth\"))\n",
    "    agent.critic.load_state_dict(torch.load(f\"multi_agent_critic_{i}.pth\"))\n",
    "\n",
    "# No `evaluate` helper is defined or imported in this notebook, so the\n",
    "# rollout is done inline. A fresh environment is needed because the\n",
    "# training cell closes `env` when it finishes.\n",
    "eval_episodes = 100\n",
    "eval_env = gym.make(\"Pendulum-v1\")\n",
    "eval_returns = []\n",
    "\n",
    "for episode in range(eval_episodes):\n",
    "    state, _ = eval_env.reset()\n",
    "    episode_return = 0.0\n",
    "    done = False\n",
    "    while not done:\n",
    "        # NOTE(review): noise_scale=0.0 is assumed to switch the action noise\n",
    "        # off -- confirm against MultiAgentSystem.act.\n",
    "        actions = agent_system.act(state, noise_scale=0.0)\n",
    "        state, reward, terminated, truncated, _ = eval_env.step(actions)\n",
    "        episode_return += float(np.mean(reward))\n",
    "        done = terminated or truncated\n",
    "    eval_returns.append(episode_return)\n",
    "\n",
    "eval_env.close()\n",
    "print(f\"Evaluation over {eval_episodes} episodes: mean return {np.mean(eval_returns):.2f} +/- {np.std(eval_returns):.2f}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Visualize the agents' performance\n",
    "# No `visualize` helper is defined or imported in this notebook, so the\n",
    "# rollout is rendered inline. `rgb_array` rendering returns image frames,\n",
    "# which display inside the notebook without opening a separate window.\n",
    "vis_episodes = 5\n",
    "vis_env = gym.make(\"Pendulum-v1\", render_mode=\"rgb_array\")\n",
    "final_frames = []\n",
    "\n",
    "for episode in range(vis_episodes):\n",
    "    state, _ = vis_env.reset()\n",
    "    done = False\n",
    "    while not done:\n",
    "        # NOTE(review): noise_scale=0.0 is assumed to switch the action noise\n",
    "        # off -- confirm against MultiAgentSystem.act.\n",
    "        actions = agent_system.act(state, noise_scale=0.0)\n",
    "        state, _, terminated, truncated, _ = vis_env.step(actions)\n",
    "        done = terminated or truncated\n",
    "    # Keep the last rendered frame as a snapshot of where the episode ended.\n",
    "    final_frames.append(vis_env.render())\n",
    "\n",
    "vis_env.close()\n",
    "\n",
    "fig, axes = plt.subplots(1, vis_episodes, figsize=(3 * vis_episodes, 3))\n",
    "# atleast_1d keeps the loop valid when vis_episodes is set to 1.\n",
    "for episode, (ax, frame) in enumerate(zip(np.atleast_1d(axes), final_frames)):\n",
    "    ax.imshow(frame)\n",
    "    ax.set_title(f\"Episode {episode + 1}\")\n",
    "    ax.axis(\"off\")\n",
    "fig.suptitle(\"Final frame of each visualization episode\")\n",
    "plt.show()"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.0"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}