In [1]:
"""
Train a PPO agent in a single environment that matches your "demo":
 - Potential-field perfect movement
 - Same obstacles/goals
 - Action = scalar in [-1..1], mapped to gamma in [0..1]
 - Saves final training metrics + trained model

Now includes a 'visualize' parameter so you can disable rendering for faster training.
"""

import os
import math
import time
import random
import numpy as np
import matplotlib.pyplot as plt

import gymnasium as gym
from gymnasium import spaces

from stable_baselines3 import PPO
from stable_baselines3.common.callbacks import BaseCallback

# Try to import pygame safely. If you plan to use visualize=False only,
# you could skip installing pygame entirely.
try:
    import pygame
except ImportError:
    pygame = None


###############################################################################
# Global config (mirroring your final "demo" environment)
###############################################################################
# Window size in pixels (width, height); every coordinate below uses this frame.
FULL_VIEW_SIZE = (1200, 800)

# Sizes are specified for a 600-pixel reference canvas and scaled to the real
# window; SCALING_FACTOR is the mean of the horizontal and vertical ratios.
_REFERENCE_CANVAS = 600.0
SCALING_FACTOR_X = FULL_VIEW_SIZE[0] / _REFERENCE_CANVAS
SCALING_FACTOR_Y = FULL_VIEW_SIZE[1] / _REFERENCE_CANVAS
SCALING_FACTOR = (SCALING_FACTOR_X + SCALING_FACTOR_Y) / 2

# RGB colours used by the Pygame renderer.
WHITE, BLACK, GRAY = (255, 255, 255), (0, 0, 0), (128, 128, 128)
YELLOW, GREEN = (255, 255, 0), (0, 200, 0)
BLUE, RED = (0, 0, 255), (255, 0, 0)

# HUD font. FONT stays None unless DemoArbitrationEnv.__init__ creates it
# (visualize=True and pygame importable).
FONT_SIZE = int(24 * SCALING_FACTOR)
FONT = None

# Geometry in pixels; int() truncates after scaling.
DOT_RADIUS = int(15 * SCALING_FACTOR)
TARGET_RADIUS = int(10 * SCALING_FACTOR)
OBSTACLE_RADIUS = int(10 * SCALING_FACTOR)
COLLISION_BUFFER = int(5 * SCALING_FACTOR)
MAX_SPEED = 3 * SCALING_FACTOR  # distance moved per step along a unit direction (float)

# A goal counts as reached once the dot centre is closer than this.
GOAL_DETECTION_RADIUS = DOT_RADIUS + TARGET_RADIUS
# Every episode starts at the centre of the window.
START_POS = np.array([side // 2 for side in FULL_VIEW_SIZE], dtype=np.float32)

NOISE_MAGNITUDE = 0.5  # std-dev of the Gaussian noise added to the "human" direction
RENDER_FPS = 30

# Obstacle centres (x, y) in pixels.
OBSTACLES = [(200, 100), (300, 700), (1000, 150), (1100, 600), (200, 650)]

# Goal centres (x, y) in pixels; list order defines the on-screen labels 1..N.
GOALS = [
    (600, 100),   # top center
    (1100, 200),  # top-right
    (1100, 700),  # bottom-right
    (600, 700),   # bottom center
    (100, 700),   # bottom-left
    (100, 200),   # top-left
    (900, 400),   # mid-right
    (300, 400),   # mid-left
]


###############################################################################
# Metrics callback: per-episode total reward and mean gamma (first env only)
###############################################################################
class MetricsCallback(BaseCallback):
    """Record per-episode total reward and mean arbitration gamma during SB3 training.

    Only the first environment of the VecEnv (index 0) is tracked. Each step's
    action is clamped to the action-space bounds and mapped to
    gamma = 0.5 * (action + 1), the same mapping DemoArbitrationEnv.step() uses.
    """

    def __init__(self, verbose=0):
        super().__init__(verbose)
        self.episode_rewards = []         # total reward of each finished episode
        self.episode_gammas = []          # mean gamma of each finished episode
        self.current_episode_gammas = []  # gammas of the episode in progress
        self.total_reward = 0.0           # reward accumulated in the episode in progress

        # Action-space bounds, read from the model on the first step.
        self.action_low = None
        self.action_high = None

    def _on_step(self) -> bool:
        """Accumulate the current step of env 0; close out the episode when it is done."""
        if self.action_low is None:
            # np.min / np.max accept both shape=() and shape=(1,) Box bounds;
            # a bare float() on a 1-element array is deprecated in NumPy.
            self.action_low = float(np.min(self.model.action_space.low))
            self.action_high = float(np.max(self.model.action_space.high))

        # NOTE(review): 'actions' may be the raw (unclipped) policy output,
        # hence the clamp. reshape(-1) handles a 0-d or 1-element entry.
        raw_action = float(np.asarray(self.locals['actions'][0]).reshape(-1)[0])
        raw_action = max(self.action_low, min(self.action_high, raw_action))

        # map raw in [-1,1] => gamma in [0,1]
        gamma = 0.5 * (raw_action + 1.0)

        # Record before the done check so the step that ends the episode is
        # included in that episode's mean gamma (it used to be dropped).
        self.current_episode_gammas.append(gamma)
        self.total_reward += float(self.locals['rewards'][0])

        if self.locals['dones'][0]:
            self.episode_rewards.append(self.total_reward)
            self.episode_gammas.append(float(np.mean(self.current_episode_gammas)))
            self.total_reward = 0.0
            self.current_episode_gammas = []

        return True

    def save_metrics(self, save_dir="training_metrics"):
        """Write reward/gamma plots and a text summary into ``save_dir`` (created if missing).

        Produces ``episode_reward.png``, ``average_gamma.png`` and
        ``training_summary.txt``. The summary statistics are written only if at
        least one episode finished.
        """
        os.makedirs(save_dir, exist_ok=True)

        self._save_series_plot(
            self.episode_rewards, os.path.join(save_dir, "episode_reward.png"),
            label="Episode Reward", title="Episode Reward Over Time", ylabel="Reward")
        self._save_series_plot(
            self.episode_gammas, os.path.join(save_dir, "average_gamma.png"),
            label="Average Gamma", title="Average Gamma per Episode", ylabel="Gamma")

        # Summary text
        with open(os.path.join(save_dir, "training_summary.txt"), 'w') as f:
            f.write(f"Total Episodes: {len(self.episode_rewards)}\n")
            if self.episode_rewards:
                f.write(f"Mean Reward: {np.mean(self.episode_rewards):.3f}\n")
                f.write(f"Mean Gamma: {np.mean(self.episode_gammas):.3f}\n")
                f.write(f"Best Episode: {max(self.episode_rewards):.3f}\n")
                f.write(f"Worst Episode: {min(self.episode_rewards):.3f}\n")

    @staticmethod
    def _save_series_plot(values, path, label, title, ylabel):
        """Plot one per-episode series as a line chart and save it to ``path``."""
        fig, ax = plt.subplots(figsize=(8, 5))
        try:
            ax.plot(values, label=label)
            ax.set_title(title)
            ax.set_xlabel("Episode")
            ax.set_ylabel(ylabel)
            ax.grid(True)
            ax.legend()
            fig.savefig(path)
        finally:
            # Always release the figure, even if saving fails.
            plt.close(fig)


###############################################################################
# Utility
###############################################################################
def distance(a, b):
    """Euclidean distance between two 2-D points given as indexable (x, y) pairs."""
    delta_x, delta_y = a[0] - b[0], a[1] - b[1]
    return math.hypot(delta_x, delta_y)

def check_line_collision(start, end, center, radius):
    """Return True if the segment start->end passes within ``radius`` of ``center``.

    ``center`` is projected onto the segment, with the projection parameter
    clamped to [0, 1], and the distance from that closest point is compared
    with ``radius`` (inclusive). A degenerate segment (squared length < 1e-9)
    is treated as the single point ``start``.
    """
    seg_x = end[0] - start[0]
    seg_y = end[1] - start[1]
    seg_len_sq = seg_x * seg_x + seg_y * seg_y

    if seg_len_sq < 1e-9:
        return math.hypot(start[0] - center[0], start[1] - center[1]) <= radius

    rel_x = center[0] - start[0]
    rel_y = center[1] - start[1]
    projection = (rel_x * seg_x + rel_y * seg_y) / seg_len_sq
    t = max(0, min(1, projection))

    closest_x = start[0] + t * seg_x
    closest_y = start[1] + t * seg_y
    return math.hypot(closest_x - center[0], closest_y - center[1]) <= radius

def line_collision(pos, new_pos):
    """True if the move pos->new_pos passes within OBSTACLE_RADIUS + COLLISION_BUFFER
    of any obstacle centre in OBSTACLES.

    Note: this clearance is smaller than the OBSTACLE_RADIUS + DOT_RADIUS contact
    radius that inside_obstacle() uses.
    """
    clearance = OBSTACLE_RADIUS + COLLISION_BUFFER
    return any(check_line_collision(pos, new_pos, centre, clearance) for centre in OBSTACLES)

def inside_obstacle(pos):
    """True if ``pos`` is within OBSTACLE_RADIUS + DOT_RADIUS (inclusive) of any
    obstacle centre in OBSTACLES."""
    contact_radius = OBSTACLE_RADIUS + DOT_RADIUS
    return any(distance(pos, centre) <= contact_radius for centre in OBSTACLES)

def potential_field_dir(pos, goal, obstacles=None, repulsion_radius=None,
                        repulsion_gain=30000.0):
    """Unit direction from ``pos`` given by an attractive/repulsive potential field.

    The attraction is the unit vector from ``pos`` towards ``goal``. Every
    obstacle closer than ``repulsion_radius`` (and not exactly at ``pos``) adds
    a push pointing directly away from it, with magnitude
    ``repulsion_gain / distance**2``. The summed vector is normalised.

    Args:
        pos, goal: (x, y) points.
        obstacles: iterable of (x, y) obstacle centres. Defaults to OBSTACLES.
        repulsion_radius: only obstacles closer than this repel. Defaults to
            150 * SCALING_FACTOR.
        repulsion_gain: numerator of the inverse-square repulsion.

    Returns:
        float32 array of shape (2,). It is all zeros when ``pos`` is
        (numerically) at ``goal`` or when attraction and repulsion cancel.
    """
    gx = goal[0] - pos[0]
    gy = goal[1] - pos[1]
    dg = math.hypot(gx, gy)
    if dg < 1e-6:
        return np.zeros(2, dtype=np.float32)
    att = np.array([gx / dg, gy / dg], dtype=np.float32)

    # Defaults are resolved lazily so explicit arguments never touch module globals.
    if obstacles is None:
        obstacles = OBSTACLES
    if repulsion_radius is None:
        repulsion_radius = 150.0 * SCALING_FACTOR

    repulse_x = 0.0
    repulse_y = 0.0
    for obs in obstacles:
        dx = pos[0] - obs[0]
        dy = pos[1] - obs[1]
        dobs = math.hypot(dx, dy)
        # Skip an obstacle exactly at pos: its push direction is undefined.
        if dobs < 1e-6:
            continue
        if dobs < repulsion_radius:
            strength = repulsion_gain / (dobs ** 2)
            repulse_x += (dx / dobs) * strength
            repulse_y += (dy / dobs) * strength

    px = att[0] + repulse_x
    py = att[1] + repulse_y
    mg = math.hypot(px, py)
    if mg < 1e-6:
        return np.zeros(2, dtype=np.float32)
    return np.array([px / mg, py / mg], dtype=np.float32)


###############################################################################
# The Real-Time Env (with visualize param)
###############################################################################
class DemoArbitrationEnv(gym.Env):
    """
    Single environment, shape=() action => scalar in [-1,1], mapped to gamma [0..1].
    Observations: 9D [dot_x, dot_y, h_dir_x, h_dir_y, goal_x, goal_y, w_dir_x, w_dir_y, dist_norm].
    If visualize=True, we open a Pygame window and do real-time rendering (slow).
    If pygame is not installed, rendering is silently disabled.
    If visualize=False, we skip all rendering for faster training.

    Dynamics: each step blends the noisy "human" direction h_dir with the
    potential-field direction w_dir as c_dir = gamma*w_dir + (1-gamma)*h_dir.
    The dot then moves MAX_SPEED along the normalised c_dir, unless the move
    segment clips an obstacle's buffer zone. The h_dir/w_dir used here are
    exactly those shown in the previous observation.

    Rewards:
      +10 for reaching the current goal; a new, different goal is then drawn.
      -20 and termination when the dot ends inside an obstacle's contact radius.
      0 otherwise.
    Episodes are truncated after ``max_steps`` steps.

    Randomness (goal choice and direction noise) comes from ``self.np_random``,
    so ``reset(seed=...)`` makes episodes reproducible.
    """
    metadata = {"render_modes": ["human"], "render_fps": RENDER_FPS}

    def __init__(self, visualize=True, max_steps=300):
        super().__init__()
        self.visualize = visualize

        # Obs space: positions in pixels, direction components in [-1, 1],
        # goal distance normalised by the window diagonal.
        low = np.array([0, 0, -1, -1, 0, 0, -1, -1, 0], dtype=np.float32)
        high = np.array([
            FULL_VIEW_SIZE[0], FULL_VIEW_SIZE[1],
            1, 1,
            FULL_VIEW_SIZE[0], FULL_VIEW_SIZE[1],
            1, 1,
            1,
        ], dtype=np.float32)
        self.observation_space = spaces.Box(
            low=low, high=high, shape=(9,), dtype=np.float32
        )

        # Action space: shape=(), in [-1..1]; step() maps it to gamma in [0..1].
        self.action_space = spaces.Box(
            low=-1.0, high=1.0, shape=(), dtype=np.float32
        )

        self.dot_pos = None
        self.goal_pos = None
        self._goal_idx = None  # index into GOALS of the current goal
        self.step_count = 0
        self.max_steps = max_steps
        self.episode_reward = 0.0
        self.max_dist = math.hypot(FULL_VIEW_SIZE[0], FULL_VIEW_SIZE[1])

        # Directions exposed in the latest observation. step() executes exactly
        # this h_dir, so the agent's gamma acts on the human input it observed.
        self._obs_w_dir = None
        self._obs_h_dir = None

        # (x, y, wall-clock time) samples for the ~3 s trail drawn by _render().
        self.recent_positions = []
        self.last_reset_time = time.time()

        # only do pygame init if visualize==True and pygame is importable
        if self.visualize and pygame is not None:
            pygame.init()
            global FONT
            FONT = pygame.font.Font(None, FONT_SIZE)
            self.window = pygame.display.set_mode(FULL_VIEW_SIZE)
            pygame.display.set_caption("PPO Real-time Env [action shape=()]")
            self.clock = pygame.time.Clock()
        else:
            self.window = None
            self.clock = None

    def _sample_goal(self, exclude=None):
        """Pick a goal uniformly from GOALS, skipping index ``exclude`` when other goals exist."""
        candidates = [i for i in range(len(GOALS)) if i != exclude]
        if not candidates:  # only one goal configured: re-use it
            candidates = [exclude]
        self._goal_idx = candidates[int(self.np_random.integers(len(candidates)))]
        self.goal_pos = np.array(GOALS[self._goal_idx], dtype=np.float32)

    def _noisy_direction(self, w_dir):
        """Return w_dir plus Gaussian noise (std NOISE_MAGNITUDE), renormalised when non-degenerate."""
        h_dir = w_dir + self.np_random.normal(0.0, NOISE_MAGNITUDE, size=2)
        hm = np.hypot(h_dir[0], h_dir[1])
        if hm > 1e-6:
            h_dir /= hm
        return h_dir

    def reset(self, seed=None, options=None):
        # Seeds self.np_random when a seed is given.
        super().reset(seed=seed)
        self.step_count = 0
        self.episode_reward = 0.0
        self.dot_pos = START_POS.copy()
        self._sample_goal()
        self.recent_positions.clear()
        self.last_reset_time = time.time()

        return self._get_obs(), {}

    def step(self, action):
        # Accept a 0-d scalar or a 1-element array; clamp, then map [-1,1] => [0,1].
        raw_a = float(np.asarray(action, dtype=np.float64).reshape(-1)[0])
        raw_a = max(-1.0, min(1.0, raw_a))
        gamma = 0.5 * (raw_a + 1.0)

        self.step_count += 1

        # The directions the agent saw in its last observation. The position and
        # goal have not changed since _get_obs() computed them.
        w_dir = self._obs_w_dir
        h_dir = self._obs_h_dir

        c_dir = gamma * w_dir + (1 - gamma) * h_dir
        cm = np.hypot(c_dir[0], c_dir[1])
        if cm > 1e-6:
            c_dir /= cm

        old_pos = self.dot_pos.copy()
        new_pos = self.dot_pos + c_dir * MAX_SPEED
        # A move whose segment clips an obstacle's buffer zone is cancelled.
        if not line_collision(self.dot_pos, new_pos):
            new_pos[0] = np.clip(new_pos[0], 0, FULL_VIEW_SIZE[0])
            new_pos[1] = np.clip(new_pos[1], 0, FULL_VIEW_SIZE[1])
            self.dot_pos = new_pos

        collided = inside_obstacle(self.dot_pos)
        if collided:
            reward = -20.0
            terminated = True
        else:
            reward = 0.0
            terminated = False
            if distance(self.dot_pos, self.goal_pos) < GOAL_DETECTION_RADIUS:
                reward = 10.0
                # Draw a *different* goal. Re-drawing the one just reached would
                # pay +10 again next step while the dot is still inside it.
                self._sample_goal(exclude=self._goal_idx)

        truncated = self.step_count >= self.max_steps
        self.episode_reward += reward

        # Only render if a window actually exists (visualize=True and pygame available).
        if self.visualize and self.window is not None:
            self._render(old_pos, w_dir, h_dir, c_dir, collided)

        obs = self._get_obs()
        info = {}
        return obs, float(reward), terminated, truncated, info

    def _get_obs(self):
        """Build the 9-D observation and cache the w_dir/h_dir it exposes for the next step()."""
        to_g = self.goal_pos - self.dot_pos
        d = math.hypot(to_g[0], to_g[1])
        dist_ratio = d / self.max_dist if self.max_dist > 1e-6 else 0.0

        w_dir = potential_field_dir(self.dot_pos, self.goal_pos)
        h_dir = self._noisy_direction(w_dir)
        self._obs_w_dir = w_dir
        self._obs_h_dir = h_dir

        obs = np.concatenate([
            self.dot_pos,
            h_dir,
            self.goal_pos,
            w_dir,
            [dist_ratio]
        ]).astype(np.float32)
        return obs

    def _render(self, old_pos, w_dir, h_dir, c_dir, collided):
        """Draw the current frame. Must only be called when self.window is not None."""
        # Service the OS event queue so the window stays responsive.
        pygame.event.pump()

        now = time.time()
        self.recent_positions.append((self.dot_pos[0], self.dot_pos[1], now))
        while len(self.recent_positions) > 1 and (now - self.recent_positions[0][2]) > 3.0:
            self.recent_positions.pop(0)

        self.window.fill(WHITE)

        # obstacles
        for obs in OBSTACLES:
            pygame.draw.circle(self.window, GRAY, obs, OBSTACLE_RADIUS)

        # goals, labelled 1..N in GOALS order
        for i, gpos in enumerate(GOALS):
            pygame.draw.circle(self.window, YELLOW, (int(gpos[0]), int(gpos[1])), TARGET_RADIUS)
            textimg = FONT.render(str(i+1), True, BLACK)
            self.window.blit(textimg, (gpos[0]-5, gpos[1]-12))

        # highlight the active goal
        pygame.draw.circle(self.window, BLACK, (int(self.goal_pos[0]), int(self.goal_pos[1])),
                           TARGET_RADIUS+2, width=2)

        # ghost path
        if len(self.recent_positions) > 1:
            for k in range(len(self.recent_positions)-1):
                x1, y1, _ = self.recent_positions[k]
                x2, y2, _ = self.recent_positions[k+1]
                pygame.draw.line(self.window, (200, 200, 200), (x1, y1), (x2, y2), 2)

        # dot
        pygame.draw.circle(self.window, BLACK, (int(self.dot_pos[0]), int(self.dot_pos[1])),
                           DOT_RADIUS, width=2)

        def draw_arrow(surf, color, start, vec):
            # Fixed-length arrow along vec's direction; zero vectors are skipped.
            dx, dy = vec
            mg = math.hypot(dx, dy)
            if mg < 1e-6:
                return
            dx /= mg
            dy /= mg
            length = int(60*SCALING_FACTOR)
            endx = start[0] + dx*length
            endy = start[1] + dy*length
            pygame.draw.line(surf, color, start, (endx, endy), width=2)
            arrow_size = 7*SCALING_FACTOR
            angle = math.atan2(dy, dx)
            ax1 = endx - arrow_size*math.cos(angle+math.pi/6)
            ay1 = endy - arrow_size*math.sin(angle+math.pi/6)
            ax2 = endx - arrow_size*math.cos(angle-math.pi/6)
            ay2 = endy - arrow_size*math.sin(angle-math.pi/6)
            pygame.draw.line(surf, color, (endx, endy), (ax1, ay1), width=2)
            pygame.draw.line(surf, color, (endx, endy), (ax2, ay2), width=2)

        # green = potential field, blue = noisy human, red = blended command
        cx, cy = int(self.dot_pos[0]), int(self.dot_pos[1])
        draw_arrow(self.window, GREEN, (cx, cy), w_dir)
        draw_arrow(self.window, BLUE,  (cx, cy), h_dir)
        draw_arrow(self.window, RED,   (cx, cy), c_dir)

        steps_txt = FONT.render(f"Steps: {self.step_count}/{self.max_steps}", True, BLACK)
        self.window.blit(steps_txt, (10, 10))
        rew_txt = FONT.render(f"Episode Reward: {self.episode_reward:.1f}", True, BLACK)
        self.window.blit(rew_txt, (10, 40))
        if collided:
            col_msg = FONT.render("Collision => -20.0 reward, done!", True, RED)
            self.window.blit(col_msg, (250, 20))

        pygame.display.flip()
        self.clock.tick(RENDER_FPS)

    def render(self):
        # no-op, we do everything in _render if visualize==True
        pass

    def close(self):
        """Shut down pygame if this env opened a window, then defer to gym.Env.close()."""
        if self.window is not None and pygame is not None:
            pygame.quit()
            # Prevents any later step() from drawing to a closed display.
            self.window = None
        super().close()


###############################################################################
# Training
###############################################################################
def train(visualize=False, total_timesteps=30_000, seed=None,
          model_path="trained_models/demo_arbitration_ppo",
          metrics_dir="training_metrics"):
    """Train PPO on DemoArbitrationEnv, then save the model and training metrics.

    Args:
        visualize: True shows a Pygame window in real time (slow); False skips
            rendering for faster training.
        total_timesteps: number of environment steps to train for.
        seed: optional seed forwarded to PPO. None keeps the previous unseeded
            behaviour.
        model_path: path passed to ``model.save``; SB3 writes ``<model_path>.zip``.
        metrics_dir: directory receiving the reward/gamma plots and summary.

    A KeyboardInterrupt stops training early. The partially trained model and
    the metrics gathered so far are still saved. The environment is always
    closed, even if training or saving raises.
    """
    env = DemoArbitrationEnv(visualize=visualize)
    try:
        model = PPO(
            policy="MlpPolicy",
            env=env,
            learning_rate=3e-4,
            n_steps=1024,
            batch_size=1024,
            n_epochs=4,
            gamma=0.99,  # PPO discount factor; unrelated to the env's arbitration gamma
            gae_lambda=0.95,
            clip_range=0.2,
            seed=seed,
            verbose=1,
        )

        metrics_callback = MetricsCallback()

        try:
            print(f"Starting PPO training (visualize={visualize}) ...")
            model.learn(total_timesteps=total_timesteps, callback=metrics_callback)
        except KeyboardInterrupt:
            print("Training interrupted; saving partial model...")

        # Save final model
        model_dir = os.path.dirname(model_path)
        if model_dir:
            os.makedirs(model_dir, exist_ok=True)
        model.save(model_path)
        print(f"Model saved to {model_path}.zip")

        # Save metrics
        metrics_callback.save_metrics(metrics_dir)
        print(f"Metrics saved in {metrics_dir}/")
    finally:
        env.close()

###############################################################################
# Main
###############################################################################
if __name__ == "__main__":
    # Headless training by default (fast). Set SHOW_WINDOW to True to watch the
    # environment live in a Pygame window, which is much slower.
    SHOW_WINDOW = False
    train(visualize=SHOW_WINDOW)


Using cuda device
Wrapping the env with a `Monitor` wrapper
Wrapping the env in a DummyVecEnv.
Starting PPO training (visualize=False) ...
---------------------------------
| rollout/           |          |
|    ep_len_mean     | 300      |
|    ep_rew_mean     | 6.67     |
| time/              |          |
|    fps             | 339      |
|    iterations      | 1        |
|    time_elapsed    | 3        |
|    total_timesteps | 1024     |
---------------------------------
-------------------------------------------
| rollout/                |               |
|    ep_len_mean          | 300           |
|    ep_rew_mean          | 11.7          |
| time/                   |               |
|    fps                  | 301           |
|    iterations           | 2             |
|    time_elapsed         | 6             |
|    total_timesteps      | 2048          |
| train/                  |               |
|    approx_kl            | 0.00024563202 |
|    clip_fraction        | 0        