In [None]:
import os
os.environ['KMP_DUPLICATE_LIB_OK'] = 'TRUE'
import math
import random
import time
import numpy as np
import matplotlib.pyplot as plt

import gymnasium as gym
from gymnasium import spaces

from stable_baselines3 import PPO
from stable_baselines3.common.callbacks import BaseCallback
from stable_baselines3.common.vec_env import DummyVecEnv

import torch
import torch.nn as nn

# For rendering (optional):
try:
    import pygame
except ImportError:
    pygame = None

###############################################################################
# CUSTOM POLICY: GammaMlpPolicy
###############################################################################
from stable_baselines3.common.policies import ActorCriticPolicy

class GammaMlpPolicy(ActorCriticPolicy):
    """Actor-critic MLP policy whose Gaussian action mean is tanh-squashed.

    The previous implementation overrode only ``forward``: it sampled from
    ``Normal(tanh(mean), std)``, applied a second ``tanh`` to the sample and
    then scored that squashed sample under the un-squashed Normal (no
    change-of-variables term).  Because ``evaluate_actions`` was not
    overridden, the log-probabilities stored during rollouts came from a
    different distribution than the one used in the training update, which
    breaks the PPO probability ratio.

    The squashing is now applied in ``_get_action_dist_from_latent``, the
    hook the base class uses to build the action distribution, so sampling,
    prediction and action evaluation all share the same distribution.

    Sampled actions are Gaussian around a mean in [-1, 1] and may therefore
    fall outside that interval.
    NOTE(review): Stable-Baselines3 is expected to clip actions to the Box
    bounds before stepping the environment -- confirm for the installed
    version.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

    def _get_action_dist_from_latent(self, latent_pi):
        """Build the action distribution from the policy latent.

        :param latent_pi: latent code produced by the policy branch of the
            MLP extractor.
        :return: the action distribution with a tanh-squashed mean and a
            log-std clamped to [-20, 2].
        """
        # Local import so this class does not depend on a new file-level import.
        from stable_baselines3.common.distributions import DiagGaussianDistribution

        # The squashing only makes sense for the diagonal Gaussian used with
        # Box action spaces; defer to the base class for anything else.
        if not isinstance(self.action_dist, DiagGaussianDistribution):
            return super()._get_action_dist_from_latent(latent_pi)

        # Squash the mean so that it lies in [-1, 1].
        mean_actions = torch.tanh(self.action_net(latent_pi))
        # Keep the standard deviation in a numerically safe range.
        log_std = torch.clamp(self.log_std, -20, 2)
        return self.action_dist.proba_distribution(mean_actions, log_std)

    def forward(self, obs, deterministic=False):
        """Return ``(actions, values, log_prob)`` for a batch of observations.

        Delegates to the base implementation, which obtains its distribution
        from ``_get_action_dist_from_latent`` and reshapes the actions to the
        action-space shape.

        :param obs: batch of observations.
        :param deterministic: if True use the distribution mode instead of a
            random sample.
        """
        return super().forward(obs, deterministic=deterministic)

###############################################################################
# CONSTANTS & UTILS
###############################################################################
# Size of the playing field / pygame window as (width, height).
FULL_VIEW_SIZE = (1200, 800)
# Per-axis scale relative to a 600-unit reference, and their average, which
# is used to scale every radius and speed below.
SCALING_FACTOR_X = FULL_VIEW_SIZE[0] / 600.0
SCALING_FACTOR_Y = FULL_VIEW_SIZE[1] / 600.0
SCALING_FACTOR   = (SCALING_FACTOR_X + SCALING_FACTOR_Y) / 2

# Radii are truncated (not rounded) to whole numbers by int().
DOT_RADIUS       = int(15 * SCALING_FACTOR)
TARGET_RADIUS    = int(10 * SCALING_FACTOR)
OBSTACLE_RADIUS  = int(10 * SCALING_FACTOR)
# Extra clearance added to the obstacle radius by line_collision().
COLLISION_BUFFER = int(5  * SCALING_FACTOR)
# Distance the dot travels per environment step.
MAX_SPEED        = 3 * SCALING_FACTOR
# Standard deviation of the Gaussian noise added to the direction vector.
NOISE_MAGNITUDE  = 0.5
RENDER_FPS       = 30

# The dot starts every episode at the centre of the field.
START_POS = np.array([FULL_VIEW_SIZE[0]//2, FULL_VIEW_SIZE[1]//2], dtype=np.float32)

# RGB colours used by the renderer.
WHITE = (255, 255, 255)
GRAY  = (128, 128, 128)
YELLOW= (255, 255, 0)
BLACK = (0, 0, 0)

def distance(a, b):
    """Return the Euclidean distance between the 2-D points ``a`` and ``b``.

    Both arguments only need to support indexing at positions 0 and 1.
    """
    delta_x = a[0] - b[0]
    delta_y = a[1] - b[1]
    return math.hypot(delta_x, delta_y)

def check_line_collision(start, end, center, radius):
    """Return True if the segment ``start``-``end`` comes within ``radius`` of ``center``.

    The closest point of the segment to ``center`` is found by projecting
    ``center`` onto the segment and clamping the projection parameter to
    [0, 1]; a (near-)zero-length segment is treated as the single point
    ``start``.
    """
    seg_x = end[0] - start[0]
    seg_y = end[1] - start[1]
    rel_x = center[0] - start[0]
    rel_y = center[1] - start[1]

    seg_len_sq = seg_x*seg_x + seg_y*seg_y
    if seg_len_sq < 1e-9:
        # Degenerate segment: only the start point can touch the circle.
        return distance(start, center) <= radius

    projection = (rel_x*seg_x + rel_y*seg_y) / seg_len_sq
    t = max(0, min(1, projection))
    closest = (start[0] + t*seg_x, start[1] + t*seg_y)
    return distance(closest, center) <= radius

def line_collision(pos, new_pos, obstacles):
    """Return True if moving from ``pos`` to ``new_pos`` grazes any obstacle.

    Each obstacle is treated as a circle of radius
    ``OBSTACLE_RADIUS + COLLISION_BUFFER`` around its centre.
    """
    return any(
        check_line_collision(pos, new_pos, center, OBSTACLE_RADIUS + COLLISION_BUFFER)
        for center in obstacles
    )

def inside_obstacle(pos, obstacles):
    """Return True if a dot centred at ``pos`` overlaps any obstacle.

    Overlap means the centre distance is at most
    ``OBSTACLE_RADIUS + DOT_RADIUS``.
    """
    return any(
        distance(pos, center) <= (OBSTACLE_RADIUS + DOT_RADIUS)
        for center in obstacles
    )

def potential_field_dir(pos, goal, obstacles):
    """
    Compute a unit steering direction from a simple potential field.

    The field is the unit vector pointing from ``pos`` to ``goal`` plus an
    inverse-square push away from every obstacle centre closer than the
    repulsion radius.  A zero vector is returned when ``pos`` is (almost)
    on the goal or when attraction and repulsion cancel out.
    """
    to_goal_x = goal[0] - pos[0]
    to_goal_y = goal[1] - pos[1]
    goal_dist = math.hypot(to_goal_x, to_goal_y)
    if goal_dist < 1e-6:
        return np.zeros(2, dtype=np.float32)
    attraction = np.array([to_goal_x / goal_dist, to_goal_y / goal_dist], dtype=np.float32)

    influence_radius = 23.0 * SCALING_FACTOR
    gain = 30000.0
    push_x = 0.0
    push_y = 0.0

    for center in obstacles:
        away_x = pos[0] - center[0]
        away_y = pos[1] - center[1]
        gap = math.hypot(away_x, away_y)
        # Ignore obstacles we sit exactly on (no direction) or that are
        # outside the influence radius.
        if gap < 1e-9 or not gap < influence_radius:
            continue
        unit_x = away_x / gap
        unit_y = away_y / gap
        magnitude = gain / (gap**2)
        push_x += unit_x * magnitude
        push_y += unit_y * magnitude

    total_x = attraction[0] + push_x
    total_y = attraction[1] + push_y
    total_len = math.hypot(total_x, total_y)
    if total_len < 1e-9:
        return np.zeros(2, dtype=np.float32)
    return np.array([total_x / total_len, total_y / total_len], dtype=np.float32)

###############################################################################
# METRICS CALLBACK
###############################################################################
class MetricsCallback(BaseCallback):
    """
    Logs training metrics and saves various plots after training.

    Only the first environment of the vectorized env (index 0) is tracked.

    Plots include:
      - Episode Reward
      - Average Gamma per Episode
      - Gamma Std per Episode
      - Total Model Loss
      - Critic (Value) Loss
      - Actor (Policy) Loss
      - Entropy Loss
      - Episode Length

    A moving-average line is only drawn once a series has at least as many
    points as the averaging window.
    """

    def __init__(self, verbose=0):
        super().__init__(verbose)
        self.episode_rewards   = []
        self.episode_lengths   = []
        self.episode_mean_gammas = []
        self.episode_std_gammas  = []
        self.total_reward = 0.0
        self.ep_length    = 0
        self.current_episode_gammas = []  # store gamma each step
        self.n_collisions = 0
        self.n_episodes   = 0
        # Loss tracking:
        self.losses         = []
        self.value_losses   = []
        self.policy_losses  = []
        self.entropy_losses = []
        self.training_steps = []
        self.n_updates      = 0

    def _on_training_start(self):
        """Reset every accumulator so a reused callback starts clean."""
        self.episode_rewards.clear()
        self.episode_lengths.clear()
        self.episode_mean_gammas.clear()
        self.episode_std_gammas.clear()
        self.total_reward = 0.0
        self.ep_length    = 0
        self.n_collisions = 0
        self.n_episodes   = 0
        self.current_episode_gammas.clear()
        # The loss buffers were previously left untouched here, so a second
        # training run mixed its losses with those of the first one.
        self.losses.clear()
        self.value_losses.clear()
        self.policy_losses.clear()
        self.entropy_losses.clear()
        self.training_steps.clear()
        self.n_updates = 0

    def _on_step(self) -> bool:
        """Accumulate per-step statistics for environment 0.

        :return: always True (training is never aborted by this callback).
        """
        actions = self.locals['actions']
        rewards = self.locals['rewards']
        dones   = self.locals['dones']
        infos   = self.locals['infos']

        # Accept both scalar and 1-element array actions, and clip to the
        # action range so the logged gamma (mapping [-1,1] -> [0,1]) stays
        # inside [0, 1] even if the sampled action is out of bounds.
        raw_action = float(np.asarray(actions[0], dtype=np.float64).reshape(-1)[0])
        gamma_val = 0.5 * (float(np.clip(raw_action, -1.0, 1.0)) + 1.0)
        self.current_episode_gammas.append(gamma_val)
        self.total_reward += float(rewards[0])
        self.ep_length    += 1

        if dones[0]:
            self.episode_rewards.append(self.total_reward)
            self.episode_lengths.append(self.ep_length)
            self.episode_mean_gammas.append(np.mean(self.current_episode_gammas))
            self.episode_std_gammas.append(np.std(self.current_episode_gammas))
            self.total_reward = 0.0
            self.ep_length    = 0
            self.current_episode_gammas.clear()
            self.n_episodes  += 1
            if infos[0].get('terminal_reason') == 'collision':
                self.n_collisions += 1

        return True

    def _on_rollout_end(self):
        """Snapshot whatever ``train/*`` values the logger currently holds."""
        self.n_updates += 1
        logs = self.model.logger.name_to_value or {}
        if "train/loss" in logs:
            self.losses.append(logs["train/loss"])
            self.training_steps.append(self.n_updates)
        if "train/value_loss" in logs:
            self.value_losses.append(logs["train/value_loss"])
        if "train/policy_gradient_loss" in logs:
            self.policy_losses.append(logs["train/policy_gradient_loss"])
        if "train/entropy_loss" in logs:
            self.entropy_losses.append(logs["train/entropy_loss"])

    def _moving_average(self, data, window=10):
        """Return the ``window``-point moving average of ``data``.

        Series shorter than ``window`` are returned unchanged as an array.
        """
        if len(data) < window:
            return np.array(data)
        return np.convolve(data, np.ones(window)/window, mode='valid')

    def _plot_series(self, save_dir, filename, values, label, xlabel, ylabel,
                     title, alpha=0.6, x=None, window=10):
        """Plot one metric series (plus its moving average) and save it as a PNG.

        Nothing is written when ``values`` is empty.

        :param save_dir: directory the image is written to.
        :param filename: image file name inside ``save_dir``.
        :param values: the metric values to plot.
        :param label: legend label of the raw series.
        :param xlabel: x-axis label.
        :param ylabel: y-axis label.
        :param title: figure title.
        :param alpha: opacity of the raw series line.
        :param x: optional x coordinates (same length as ``values``);
            defaults to 0..len(values)-1.
        :param window: moving-average window size.
        """
        if not values:
            return
        x_values = list(range(len(values))) if x is None else list(x)
        fig, ax = plt.subplots(figsize=(10, 6))
        try:
            ax.plot(x_values, values, label=label, alpha=alpha)
            # Only draw the moving average when it is well defined; the
            # k-th averaged point is aligned with the last sample it covers.
            if len(values) >= window:
                smoothed = self._moving_average(values, window)
                ax.plot(x_values[window - 1:], smoothed,
                        label=f"MA({window})", color='red', linewidth=2)
            ax.set_xlabel(xlabel)
            ax.set_ylabel(ylabel)
            ax.set_title(title)
            ax.legend()
            ax.grid(True)
            fig.savefig(os.path.join(save_dir, filename))
        finally:
            # Always release the figure, even if saving fails.
            plt.close(fig)

    def save_metrics(self, save_dir="training_metrics"):
        """Write all metric plots and ``summary.txt`` into ``save_dir``.

        :param save_dir: output directory; created if it does not exist.
        """
        os.makedirs(save_dir, exist_ok=True)

        # Per-episode series (x axis: episode index).
        episode_plots = [
            ("episode_rewards.png", self.episode_rewards,
             "Episode Reward", "Reward", "Episode Rewards"),
            ("average_gamma.png", self.episode_mean_gammas,
             "Average Gamma", "Gamma (avg)", "Average Gamma per Episode"),
            ("gamma_std.png", self.episode_std_gammas,
             "Gamma Std", "Gamma Std", "Gamma Std per Episode"),
            ("episode_length.png", self.episode_lengths,
             "Episode Length", "Length (# steps)", "Episode Length"),
        ]
        for filename, values, label, ylabel, title in episode_plots:
            self._plot_series(save_dir, filename, values, label,
                              "Episode", ylabel, title, alpha=0.6)

        # Total loss is plotted against the rollout counter it was recorded at.
        self._plot_series(save_dir, "total_loss.png", self.losses,
                          "Total Model Loss", "Training Updates", "Loss",
                          "Total Model Loss Over Rollouts", alpha=0.7,
                          x=self.training_steps)

        # Remaining loss series (x axis: index of the recording).
        rollout_plots = [
            ("value_loss.png", self.value_losses,
             "Value Loss", "Value Loss", "Value (Critic) Loss"),
            ("policy_loss.png", self.policy_losses,
             "Policy Loss", "Policy Loss", "Policy (Actor) Loss"),
            ("entropy_loss.png", self.entropy_losses,
             "Entropy Loss", "Entropy Loss", "Entropy Loss (Exploration)"),
        ]
        for filename, values, label, ylabel, title in rollout_plots:
            self._plot_series(save_dir, filename, values, label,
                              "Rollout End #", ylabel, title, alpha=0.7)

        # Also save a summary text
        with open(os.path.join(save_dir, "summary.txt"), "w") as f:
            f.write(f"Total Episodes: {len(self.episode_rewards)}\n")
            if self.episode_rewards:
                avg_reward = np.mean(self.episode_rewards)
                f.write(f"Mean Episode Reward: {avg_reward:.3f}\n")
            f.write(f"Collisions Count: {self.n_collisions}\n")
            if self.episode_mean_gammas:
                mean_gamma_all = np.mean(self.episode_mean_gammas)
                f.write(f"Mean of Average-Gamma: {mean_gamma_all:.3f}\n")

###############################################################################
# A SIMPLE RENDER CALLBACK (OPTIONAL) FOR LIVE VIEW
###############################################################################
class RenderCallback(BaseCallback):
    """Renders the first training environment every ``render_freq`` calls."""

    def __init__(self, render_freq=1, verbose=0):
        super().__init__(verbose)
        self.render_freq = render_freq

    def _on_step(self) -> bool:
        """Render when the call counter is a multiple of ``render_freq``; never stops training."""
        render_now = (self.n_calls % self.render_freq) == 0
        if render_now:
            first_env = self.model.env.envs[0]
            first_env.render()
        return True

###############################################################################
# DEMO ARBITRATION ENV
###############################################################################
class DemoArbitrationEnv(gym.Env):
    """Toy shared-control environment in which the agent picks a blend factor.

    Each step the action in [-1, 1] is mapped to ``gamma`` in [0, 1]; the dot
    then moves ``MAX_SPEED`` along the normalized blend
    ``gamma * w_dir + (1 - gamma) * h_dir`` where ``w_dir`` is the potential
    field direction and ``h_dir`` is a noisy copy of it.

    Randomness:
      * Scenario layout (goals, obstacles, chosen goal) is drawn from a
        private ``random.Random`` seeded with the scenario seed, so a given
        scenario seed always produces the same layout.
      * Scenario selection and direction noise use ``self.np_random``, the
        generator seeded through ``reset(seed=...)``.
      The global ``random`` / ``np.random`` states are no longer reseeded:
      doing so every episode made the "random scenario" draw and the noise
      stream deterministic and interfered with any other user of those
      global generators.

    The ``h_dir`` reported in an observation is the one that is blended in
    the following ``step`` call (previously the two were independent noise
    samples, so the observed value carried no information).
    """

    metadata = {"render_modes": ["human"], "render_fps": RENDER_FPS}

    def __init__(self, visualize=False):
        """
        :param visualize: open a pygame window (if pygame is installed) so
            that ``render()`` draws the scene.
        """
        super().__init__()
        self.visualize = visualize
        # Observation: [dot_x, dot_y, h_dir_x, h_dir_y, goal_x, goal_y, w_dir_x, w_dir_y, dist_ratio, obs_dist_ratio]
        low  = np.array([0, 0, -1, -1, 0, 0, -1, -1, 0, 0], dtype=np.float32)
        high = np.array([FULL_VIEW_SIZE[0], FULL_VIEW_SIZE[1], 1, 1,
                         FULL_VIEW_SIZE[0], FULL_VIEW_SIZE[1], 1, 1, 1, 1], dtype=np.float32)
        self.observation_space = spaces.Box(low=low, high=high, shape=(10,), dtype=np.float32)
        self.action_space = spaces.Box(low=-1.0, high=1.0, shape=(), dtype=np.float32)
        self.dot_pos = None
        self.goal_pos = None
        self.obstacles = []
        self.goals = []
        self.step_count = 0
        self.max_steps = 300
        self.episode_reward = 0.0
        self.max_dist = math.hypot(FULL_VIEW_SIZE[0], FULL_VIEW_SIZE[1])
        self.alpha = 3.0
        self.beta  = 3.0
        self.goal_threshold = 100.0
        self.obs_threshold  = 100.0
        self.SCENARIO_SEEDS = [0, 1, 2, 58, 487]
        self.scenario_index = 0
        self.episode_counter = 0
        self.random_seed_probability = 0.3
        # Private generator for scenario layout; replaced in randomize_env().
        self._scenario_rng = random.Random()
        # Human direction shown in the latest observation (used by step()).
        self._human_dir = None
        if self.visualize and pygame is not None:
            pygame.init()
            self.window = pygame.display.set_mode(FULL_VIEW_SIZE)
            pygame.display.set_caption("Demo Arbitration Environment")
            self.clock = pygame.time.Clock()
        else:
            self.window = None
            self.clock = None

    def reset(self, seed=None, options=None):
        """Start a new episode on a fixed or a randomly seeded scenario.

        With probability ``random_seed_probability`` a fresh scenario seed is
        drawn; otherwise the next entry of ``SCENARIO_SEEDS`` is used
        (round-robin).

        :param seed: optional seed forwarded to ``gym.Env.reset``.
        :param options: unused.
        :return: ``(observation, info)`` with an empty info dict.
        """
        super().reset(seed=seed)
        self.episode_counter += 1
        if self.np_random.random() < self.random_seed_probability:
            scenario_seed = int(self.np_random.integers(0, 10_000_000))
        else:
            scenario_seed = self.SCENARIO_SEEDS[self.scenario_index]
            self.scenario_index = (self.scenario_index + 1) % len(self.SCENARIO_SEEDS)
        self.randomize_env(scenario_seed)
        self.step_count = 0
        self.episode_reward = 0.0
        self.dot_pos = START_POS.copy()
        # The target goal is part of the scenario, so it comes from the
        # scenario generator (same draw order as the layout).
        rng = self._scenario_rng
        if self.goals:
            idx = rng.randint(0, len(self.goals) - 1)
            self.goal_pos = self.goals[idx].copy()
        else:
            self.goal_pos = np.array([rng.uniform(0.2*FULL_VIEW_SIZE[0], 0.8*FULL_VIEW_SIZE[0]),
                                      rng.uniform(0.2*FULL_VIEW_SIZE[1], 0.8*FULL_VIEW_SIZE[1])],
                                     dtype=np.float32)
        return self._get_obs(), {}

    def randomize_env(self, scenario_seed):
        """Generate goals and obstacles for the scenario ``scenario_seed``.

        Goals are sampled inside a margin and at least ``min_goal_distance``
        from the start.  For a random subset of goals one obstacle is placed
        60-80% of the way from the start to the goal, shifted sideways, and
        kept only if it does not crowd the start, its goal or another
        obstacle.

        :param scenario_seed: seed of the private layout generator.
        """
        rng = random.Random(scenario_seed)
        self._scenario_rng = rng
        margin = 50 * SCALING_FACTOR
        N_GOALS = 8
        N_OBSTACLES = 5
        min_goal_distance = 300 * SCALING_FACTOR
        new_goals = []
        attempts = 0
        while len(new_goals) < N_GOALS and attempts < 2000:
            x = rng.uniform(margin, FULL_VIEW_SIZE[0] - margin)
            y = rng.uniform(margin, FULL_VIEW_SIZE[1] - margin)
            candidate = np.array([x, y], dtype=np.float32)
            if distance(candidate, START_POS) >= min_goal_distance:
                new_goals.append(candidate)
            attempts += 1
        self.goals = new_goals[:N_GOALS]
        new_obstacles = []
        if len(self.goals) > 1:
            obstacle_goals = rng.sample(self.goals, k=min(min(N_GOALS-1, N_OBSTACLES), len(self.goals)-1))
        else:
            obstacle_goals = self.goals
        for goal in obstacle_goals:
            t = rng.uniform(0.6, 0.8)
            base_point = START_POS + t*(goal - START_POS)
            vec = goal - START_POS
            vec_norm = np.linalg.norm(vec)
            if vec_norm < 1e-6:
                perp = np.array([0, 0], dtype=np.float32)
            else:
                # Unit vector perpendicular to the start->goal line.
                perp = np.array([-vec[1], vec[0]], dtype=np.float32)
                perp /= np.linalg.norm(perp)
            offset_mag = rng.uniform(20*SCALING_FACTOR, 40*SCALING_FACTOR)
            offset = perp * offset_mag * rng.choice([-1, 1])
            candidate = base_point + offset
            candidate[0] = np.clip(candidate[0], margin, FULL_VIEW_SIZE[0] - margin)
            candidate[1] = np.clip(candidate[1], margin, FULL_VIEW_SIZE[1] - margin)
            valid = True
            if distance(candidate, START_POS) < (DOT_RADIUS + OBSTACLE_RADIUS + 10):
                valid = False
            if distance(candidate, goal) < (TARGET_RADIUS + OBSTACLE_RADIUS + 20):
                valid = False
            for obs in new_obstacles:
                if distance(candidate, obs) < (2*OBSTACLE_RADIUS + 10):
                    valid = False
            if valid:
                new_obstacles.append(candidate)
        self.obstacles = new_obstacles

    def _sample_human_dir(self, w_dir):
        """Return ``w_dir`` perturbed by Gaussian noise and re-normalized."""
        noise = self.np_random.normal(0, NOISE_MAGNITUDE, size=2)
        h_dir = w_dir + noise
        hm = np.hypot(h_dir[0], h_dir[1])
        if hm > 1e-6:
            h_dir /= hm
        return h_dir

    def step(self, action):
        """Advance the simulation by one step.

        :param action: blend control in [-1, 1]; a scalar or a 1-element
            array.  Values outside the range are clipped.
        :return: ``(obs, reward, terminated, truncated, info)``;
            ``info["terminal_reason"]`` is ``"collision"``, ``"timeout"`` or
            ``None``.
        """
        # Accept scalars as well as 0-d / 1-element arrays without relying
        # on the deprecated implicit array-to-float conversion.
        raw_a = float(np.asarray(action, dtype=np.float64).reshape(-1)[0])
        raw_a = min(1.0, max(-1.0, raw_a))
        gamma_val = 0.5 * (raw_a + 1.0)  # map [-1,1] -> [0,1]
        self.step_count += 1
        w_dir = potential_field_dir(self.dot_pos, self.goal_pos, self.obstacles)
        # Blend with the human direction the agent actually observed.
        h_dir = self._human_dir
        if h_dir is None:
            h_dir = self._sample_human_dir(w_dir)
        c_dir = gamma_val * w_dir + (1 - gamma_val) * h_dir
        cm = np.hypot(c_dir[0], c_dir[1])
        if cm > 1e-6:
            c_dir /= cm
        move_vec = c_dir * MAX_SPEED
        new_pos = self.dot_pos + move_vec
        # A move whose path grazes an obstacle is rejected (the dot stays put).
        if not line_collision(self.dot_pos, new_pos, self.obstacles):
            new_pos[0] = np.clip(new_pos[0], 0, FULL_VIEW_SIZE[0])
            new_pos[1] = np.clip(new_pos[1], 0, FULL_VIEW_SIZE[1])
            self.dot_pos = new_pos
        collided = inside_obstacle(self.dot_pos, self.obstacles)
        info = {}
        if collided:
            original_reward = -2.0
            done = True
            info["terminal_reason"] = "collision"
        else:
            original_reward = 0.0
            done = False
            info["terminal_reason"] = None
        truncated = (self.step_count >= self.max_steps)
        if truncated and not done:
            info["terminal_reason"] = "timeout"
        d_goal = distance(self.dot_pos, self.goal_pos)
        if self.obstacles:
            d_obs = min(distance(self.dot_pos, obs) for obs in self.obstacles)
        else:
            d_obs = 999999.0
        # High gamma is rewarded near the goal or an obstacle and penalized
        # everywhere else.
        if (d_goal < self.goal_threshold) or (d_obs < self.obs_threshold):
            shaping_reward = self.alpha * gamma_val
        else:
            shaping_reward = -self.beta * gamma_val
        reward = original_reward + shaping_reward
        self.episode_reward += reward
        obs = self._get_obs()
        return obs, float(reward), done, truncated, info

    def _get_obs(self):
        """Build the 10-element observation and sample the next human direction."""
        to_g = self.goal_pos - self.dot_pos
        d = math.hypot(to_g[0], to_g[1])
        dist_ratio = d / self.max_dist if self.max_dist > 1e-6 else 0.0
        w_dir = potential_field_dir(self.dot_pos, self.goal_pos, self.obstacles)
        h_dir = self._sample_human_dir(w_dir)
        # Remember the sample so step() blends exactly what was observed.
        self._human_dir = h_dir
        if self.obstacles:
            min_obs_distance = min(distance(self.dot_pos, obs) for obs in self.obstacles)
        else:
            min_obs_distance = self.max_dist
        obs_dist_ratio = min_obs_distance / self.max_dist
        obs = np.concatenate([self.dot_pos, h_dir, self.goal_pos, w_dir, [dist_ratio], [obs_dist_ratio]]).astype(np.float32)
        return obs

    def _shutdown_display(self):
        """Close pygame and forget the window so later render() calls are no-ops."""
        pygame.quit()
        self.window = None
        self.clock = None

    def render(self):
        """Draw obstacles, goals, the active goal and the dot (if a window is open)."""
        if self.window is None or pygame is None:
            return
        if self.dot_pos is None or self.goal_pos is None:
            # Nothing to draw before the first reset().
            return
        for event in pygame.event.get():
            if event.type == pygame.QUIT:
                # Previously only pygame.quit() was called here, so the next
                # render() call used an uninitialized pygame.
                self._shutdown_display()
                return
        self.window.fill(WHITE)
        for obs in self.obstacles:
            pygame.draw.circle(self.window, GRAY, (int(obs[0]), int(obs[1])), OBSTACLE_RADIUS)
        for gpos in self.goals:
            pygame.draw.circle(self.window, YELLOW, (int(gpos[0]), int(gpos[1])), TARGET_RADIUS)
        pygame.draw.circle(self.window, BLACK, (int(self.goal_pos[0]), int(self.goal_pos[1])), TARGET_RADIUS+2, width=2)
        pygame.draw.circle(self.window, BLACK, (int(self.dot_pos[0]), int(self.dot_pos[1])), DOT_RADIUS, width=2)
        pygame.display.flip()
        self.clock.tick(RENDER_FPS)

    def close(self):
        """Release the pygame window (if any) and close the base environment."""
        if self.window is not None and pygame is not None:
            self._shutdown_display()
        super().close()

###############################################################################
# TRAINING FUNCTION
###############################################################################
def train_model(total_timesteps=500_000, visualize=False):
    """Train PPO with ``GammaMlpPolicy`` on ``DemoArbitrationEnv``.

    Saves the model to ``trained_models/gamma_ppo_model`` and the metric
    plots to ``training_metrics/``.  The environment is closed even when
    training fails or is interrupted.

    :param total_timesteps: number of environment steps to train for.
    :param visualize: render the environment live during training.
    """
    base_env = DemoArbitrationEnv(visualize=visualize)
    # Distinct names: the factory must keep returning the raw environment,
    # not whatever a reused variable is rebound to later.
    vec_env = DummyVecEnv([lambda: base_env])
    try:
        metrics_callback = MetricsCallback()
        callbacks = [metrics_callback]
        if visualize:
            callbacks.append(RenderCallback(render_freq=1))
        # Use our custom GammaMlpPolicy instead of the default MlpPolicy
        model = PPO(
            GammaMlpPolicy,
            vec_env,
            learning_rate=3e-4,
            n_steps=1024,
            batch_size=1024,
            n_epochs=4,
            gamma=0.99,
            gae_lambda=0.95,
            clip_range=0.2,
            verbose=1,
            tensorboard_log="./ppo_tensorboard/",
            # Separate actor / critic networks, passed as a plain dict (the
            # list-wrapped form is deprecated in recent Stable-Baselines3).
            policy_kwargs={"net_arch": {"pi": [256, 256], "vf": [256, 256]},
                           "activation_fn": nn.ReLU},
        )
        model.learn(total_timesteps=total_timesteps, callback=callbacks, log_interval=1)
        os.makedirs("trained_models", exist_ok=True)
        model_path = os.path.join("trained_models", "gamma_ppo_model")
        model.save(model_path)
        print(f"Model saved to {model_path}.zip")
        metrics_callback.save_metrics(save_dir="training_metrics")
        print("Metrics saved to 'training_metrics/'")
    finally:
        vec_env.close()

###############################################################################
# OPTIONAL: WATCH THE TRAINED MODEL
###############################################################################
def watch_trained_model(model_path="trained_models/gamma_ppo_model"):
    """Replay one deterministic episode of a saved model in a pygame window.

    The episode stops on termination, truncation, or as soon as the window
    has been closed.  The environment is always closed on exit.

    :param model_path: path of the saved PPO model (without ``.zip``).
    """
    model = PPO.load(model_path)
    env = DemoArbitrationEnv(visualize=True)
    try:
        obs, _ = env.reset()
        done, truncated = False, False
        while not (done or truncated):
            action, _states = model.predict(obs, deterministic=True)
            obs, reward, done, truncated, info = env.step(action)
            env.render()
            # Closing the window shuts pygame down; stop instead of calling
            # into an uninitialized pygame on the next frame.
            if pygame is not None and not pygame.get_init():
                break
    finally:
        env.close()

###############################################################################
# MAIN (EXAMPLE)
###############################################################################
if __name__ == "__main__":
    # Script entry point: train without a render window for one million
    # environment steps.
    train_model(total_timesteps=1_000_000, visualize=False)
    # Optionally:
    # watch_trained_model("trained_models/gamma_ppo_model")


Using cuda device
Logging to ./ppo_tensorboard/PPO_41


  raw_a = float(action)


-----------------------------
| time/              |      |
|    fps             | 759  |
|    iterations      | 1    |
|    time_elapsed    | 1    |
|    total_timesteps | 1024 |
-----------------------------
---------------------------------------
| time/                   |           |
|    fps                  | 760       |
|    iterations           | 2         |
|    time_elapsed         | 2         |
|    total_timesteps      | 2048      |
| train/                  |           |
|    approx_kl            | 3.2384377 |
|    clip_fraction        | 0.756     |
|    clip_range           | 0.2       |
|    entropy_loss         | -1.42     |
|    explained_variance   | 0.861     |
|    learning_rate        | 0.0003    |
|    loss                 | 241       |
|    n_updates            | 4         |
|    policy_gradient_loss | 0.22      |
|    std                  | 1         |
|    value_loss           | 702       |
---------------------------------------
------------------------------