In [None]:
{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# RL Training Analysis\n",
    "\n",
    "Backtest the trained PPO and SAC agents and compare their performance."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import sys\n",
    "from pathlib import Path\n",
    "\n",
    "# Make the project's `python/` package directory importable. Resolving to an\n",
    "# absolute path keeps imports working if the working directory changes later,\n",
    "# and the membership check keeps re-runs of this cell from growing sys.path.\n",
    "PYTHON_SRC_DIR = str(Path('../python').resolve())\n",
    "if PYTHON_SRC_DIR not in sys.path:\n",
    "    sys.path.append(PYTHON_SRC_DIR)\n",
    "\n",
    "import numpy as np\n",
    "import pandas as pd\n",
    "import matplotlib.pyplot as plt\n",
    "import seaborn as sns\n",
    "\n",
    "from stable_baselines3 import PPO, SAC\n",
    "from env.market_env import MarketMakerEnv\n",
    "from backtesting.backtest import Backtester\n",
    "from backtesting.metrics import PerformanceMetrics\n",
    "\n",
    "sns.set_style('whitegrid')\n",
    "plt.rcParams['figure.figsize'] = (14, 6)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 1. Load Trained Models"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Find model paths: each training run saves its best checkpoint under\n",
    "# <logs_dir>/<algo>_*/best_model/best_model.zip.\n",
    "logs_dir = Path('../logs/tensorboard')\n",
    "\n",
    "\n",
    "def find_latest_checkpoint(pattern):\n",
    "    \"\"\"Return the most recently modified file under `logs_dir` matching `pattern`, or None.\n",
    "\n",
    "    Selection is by modification time rather than lexicographic path order:\n",
    "    sorting run names such as `ppo_9` / `ppo_10` as strings does not yield the newest run.\n",
    "    \"\"\"\n",
    "    candidates = list(logs_dir.glob(pattern))\n",
    "    if not candidates:\n",
    "        return None\n",
    "    return max(candidates, key=lambda p: p.stat().st_mtime)\n",
    "\n",
    "\n",
    "def load_latest_model(algo_cls, label, pattern):\n",
    "    \"\"\"Load the newest checkpoint matching `pattern` with `algo_cls.load`.\n",
    "\n",
    "    Returns a `(path, model)` pair; both are None when no checkpoint exists.\n",
    "    \"\"\"\n",
    "    path = find_latest_checkpoint(pattern)\n",
    "    if path is None:\n",
    "        print(f\"No {label} model found\")\n",
    "        return None, None\n",
    "    print(f\"{label} Model: {path}\")\n",
    "    return path, algo_cls.load(path)\n",
    "\n",
    "\n",
    "ppo_path, ppo_model = load_latest_model(PPO, 'PPO', 'ppo_*/best_model/best_model.zip')\n",
    "sac_path, sac_model = load_latest_model(SAC, 'SAC', 'sac_*/best_model/best_model.zip')"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 2. Backtest Both Models"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Backtest every loaded agent with the same environment and episode count so\n",
    "# the results are directly comparable.\n",
    "N_BACKTEST_EPISODES = 50\n",
    "\n",
    "env = MarketMakerEnv()\n",
    "backtester = Backtester(env, n_episodes=N_BACKTEST_EPISODES, verbose=True)\n",
    "\n",
    "results = {}\n",
    "\n",
    "for name, model in [('PPO', ppo_model), ('SAC', sac_model)]:\n",
    "    # A missing checkpoint leaves the model as None; skip it explicitly\n",
    "    # instead of relying on the truthiness of the model object.\n",
    "    if model is None:\n",
    "        continue\n",
    "    print(\"\\n\" + \"=\"*60)\n",
    "    print(f\"TESTING {name}\")\n",
    "    print(\"=\"*60)\n",
    "    results[name] = backtester.run_agent(model, f\"{name} Agent\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 3. Performance Comparison"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Collect the headline backtest statistics of each agent into one table.\n",
    "comparison = pd.DataFrame([\n",
    "    {\n",
    "        'Model': name,\n",
    "        'Mean PnL': res['mean_pnl'],\n",
    "        'Std PnL': res['std_pnl'],\n",
    "        'Sharpe': res['sharpe_ratio'],\n",
    "        'Win Rate': res['win_rate'],\n",
    "        'Max PnL': res['max_pnl'],\n",
    "        'Min PnL': res['min_pnl']\n",
    "    }\n",
    "    for name, res in results.items()\n",
    "])\n",
    "\n",
    "print(\"\\n\" + \"=\"*60)\n",
    "print(\"PERFORMANCE COMPARISON\")\n",
    "print(\"=\"*60)\n",
    "\n",
    "if comparison.empty:\n",
    "    # Nothing to rank: an empty frame has no 'Mean PnL' column, so indexing it\n",
    "    # would raise a KeyError. Record the absence of a winner explicitly.\n",
    "    best_model = None\n",
    "    print(\"No backtest results available - no trained models were loaded.\")\n",
    "else:\n",
    "    print(comparison.to_string(index=False))\n",
    "    print()\n",
    "    # Rank agents by mean PnL.\n",
    "    best_model = comparison.loc[comparison['Mean PnL'].idxmax(), 'Model']\n",
    "    print(f\"🏆 Best Model: {best_model}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Visualize comparison: plot each agent's backtest PnLs (res['pnls']) together.\n",
    "if results:\n",
    "    fig = PerformanceMetrics.plot_performance(\n",
    "        {name: res['pnls'] for name, res in results.items()},\n",
    "        title=\"RL Models Comparison\"\n",
    "    )\n",
    "    plt.show()\n",
    "else:\n",
    "    # Avoid handing an empty mapping to the plotting helper.\n",
    "    print(\"No backtest results to plot.\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 4. Detailed Metrics for Best Model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Full metric breakdown for the top-ranked agent from the comparison table.\n",
    "if best_model is None:\n",
    "    print(\"No best model selected - skipping detailed metrics.\")\n",
    "else:\n",
    "    best_results = results[best_model]\n",
    "    metrics = PerformanceMetrics.get_all_metrics(best_results['pnls'])\n",
    "\n",
    "    PerformanceMetrics.print_metrics(metrics, best_model)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 5. Learning Behavior Analysis"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Roll out a single deterministic episode to see how the agent behaves:\n",
    "# price path, inventory and cumulative PnL over time.\n",
    "BEHAVIOR_MODEL = 'SAC'  # which agent to inspect: 'SAC' or 'PPO'\n",
    "EPISODE_SEED = 42  # fixed so the plotted episode is reproducible; None = random episode\n",
    "\n",
    "\n",
    "def run_episode(model, env, seed=None):\n",
    "    \"\"\"Run one episode with deterministic actions and record per-step data.\n",
    "\n",
    "    Returns a dict of lists keyed by 'prices', 'inventories', 'pnls' (read from\n",
    "    the `info` dict returned by `env.step`) and 'actions' (the chosen actions).\n",
    "    \"\"\"\n",
    "    obs, _ = env.reset(seed=seed)\n",
    "    trajectory = {'prices': [], 'inventories': [], 'pnls': [], 'actions': []}\n",
    "    done = False\n",
    "    while not done:\n",
    "        action, _ = model.predict(obs, deterministic=True)\n",
    "        obs, _reward, terminated, truncated, info = env.step(action)\n",
    "        done = terminated or truncated\n",
    "\n",
    "        trajectory['prices'].append(info['mid_price'])\n",
    "        trajectory['inventories'].append(info['inventory'])\n",
    "        trajectory['pnls'].append(info['total_pnl'])\n",
    "        trajectory['actions'].append(action)\n",
    "    return trajectory\n",
    "\n",
    "\n",
    "def plot_trajectory(trajectory):\n",
    "    \"\"\"Plot price, inventory and PnL of a recorded episode in three stacked panels.\"\"\"\n",
    "    fig, (ax_price, ax_inv, ax_pnl) = plt.subplots(3, 1, figsize=(14, 10), sharex=True)\n",
    "\n",
    "    ax_price.plot(trajectory['prices'], linewidth=2)\n",
    "    ax_price.set_title('Price Evolution', fontweight='bold')\n",
    "    ax_price.set_ylabel('Price ($)')\n",
    "\n",
    "    ax_inv.plot(trajectory['inventories'], linewidth=2, color='orange')\n",
    "    ax_inv.axhline(y=0, color='r', linestyle='--', alpha=0.5)\n",
    "    ax_inv.set_title('Inventory Management', fontweight='bold')\n",
    "    ax_inv.set_ylabel('Inventory')\n",
    "\n",
    "    ax_pnl.plot(trajectory['pnls'], linewidth=2, color='green')\n",
    "    ax_pnl.axhline(y=0, color='r', linestyle='--', alpha=0.5)\n",
    "    ax_pnl.set_title('PnL Evolution', fontweight='bold')\n",
    "    ax_pnl.set_xlabel('Step')\n",
    "    ax_pnl.set_ylabel('PnL ($)')\n",
    "\n",
    "    for ax in (ax_price, ax_inv, ax_pnl):\n",
    "        ax.grid(True, alpha=0.3)\n",
    "\n",
    "    fig.tight_layout()\n",
    "    return fig\n",
    "\n",
    "\n",
    "behavior_model = {'PPO': ppo_model, 'SAC': sac_model}[BEHAVIOR_MODEL]\n",
    "if behavior_model is None:\n",
    "    print(f\"No {BEHAVIOR_MODEL} model loaded - skipping behavior analysis.\")\n",
    "else:\n",
    "    env_test = MarketMakerEnv()\n",
    "    trajectory = run_episode(behavior_model, env_test, seed=EPISODE_SEED)\n",
    "    plot_trajectory(trajectory)\n",
    "    plt.show()\n",
    "\n",
    "    print(f\"Final PnL: ${trajectory['pnls'][-1]:.2f}\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Summary\n",
    "\n",
    "- The RL agents trained successfully\n",
    "- SAC is typically more stable than PPO (check against the comparison table above)\n",
    "- Both agents learn to manage inventory\n",
    "- The models adapt to market dynamics"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "name": "python",
   "version": "3.11.9"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}