In [None]:
import os
import math
import time
import random
import numpy as np
import matplotlib.pyplot as plt

import gymnasium as gym
from gymnasium import spaces

from stable_baselines3 import PPO
from stable_baselines3.common.callbacks import BaseCallback

# Only needed if visualize=True
try:
    import pygame
except ImportError:
    pygame = None

###############################################################################
# ENVIRONMENT CONSTANTS
###############################################################################
# Render window size in pixels as (width, height).
FULL_VIEW_SIZE = (1200, 800)

# Per-axis ratio of the view size to 600.0, and the mean of the two ratios.
# The mean is what scales the radii, buffer and speed below.
SCALING_FACTOR_X, SCALING_FACTOR_Y = (
    FULL_VIEW_SIZE[0] / 600.0,
    FULL_VIEW_SIZE[1] / 600.0,
)
SCALING_FACTOR = (SCALING_FACTOR_X + SCALING_FACTOR_Y) / 2

# Radii and buffer are truncated to whole pixels; the speed stays fractional.
DOT_RADIUS = int(15 * SCALING_FACTOR)
TARGET_RADIUS = int(10 * SCALING_FACTOR)
OBSTACLE_RADIUS = int(10 * SCALING_FACTOR)
COLLISION_BUFFER = int(5 * SCALING_FACTOR)
MAX_SPEED = 3 * SCALING_FACTOR   # length of the move vector applied per step
NOISE_MAGNITUDE = 0.5            # std-dev of the normal noise added to the guidance direction
RENDER_FPS = 30

# Obstacle centres, in view coordinates.
OBSTACLES = [(200, 100), (300, 700), (1000, 150), (1100, 600), (200, 650)]

# Candidate goal centres, in view coordinates; one is picked on every reset.
GOALS = [
    (600, 100), (1100, 200), (1100, 700), (600, 700),
    (100, 700), (100, 200), (900, 400), (300, 400),
]

# The dot starts every episode at the centre of the view.
START_POS = np.array([side // 2 for side in FULL_VIEW_SIZE], dtype=np.float32)

# RGB colours used when drawing.
WHITE = (255, 255, 255)
GRAY = (128, 128, 128)
YELLOW = (255, 255, 0)
BLACK = (0, 0, 0)

###############################################################################
# UTILITY
###############################################################################
def distance(a, b):
    """Return the Euclidean distance between the 2-D points ``a`` and ``b``.

    Only indices 0 and 1 of each argument are read.
    """
    delta_x = a[0] - b[0]
    delta_y = a[1] - b[1]
    return math.hypot(delta_x, delta_y)

def check_line_collision(start, end, center, radius):
    """Return True when the segment ``start``-``end`` comes within ``radius`` of ``center``.

    The closest point of the segment to ``center`` is found by projecting
    ``center`` onto the segment and clamping the projection parameter to
    [0, 1]. A segment whose squared length is below 1e-9 is treated as the
    single point ``start``.
    """
    seg_x = end[0] - start[0]
    seg_y = end[1] - start[1]
    seg_len_sq = seg_x * seg_x + seg_y * seg_y

    # Degenerate segment: fall back to a point-vs-circle test.
    if seg_len_sq < 1e-9:
        return distance(start, center) <= radius

    to_center_x = center[0] - start[0]
    to_center_y = center[1] - start[1]

    # Projection parameter along the segment, clamped to the segment's extent.
    t = min(1, (to_center_x * seg_x + to_center_y * seg_y) / seg_len_sq)
    t = max(0, t)

    closest_point = (start[0] + t * seg_x, start[1] + t * seg_y)
    return distance(closest_point, center) <= radius

def line_collision(pos, new_pos):
    """Return True if the straight move ``pos`` -> ``new_pos`` passes too close to any obstacle.

    "Too close" means within ``OBSTACLE_RADIUS + COLLISION_BUFFER`` of an
    obstacle centre listed in ``OBSTACLES``.
    """
    blocking_radius = OBSTACLE_RADIUS + COLLISION_BUFFER
    return any(
        check_line_collision(pos, new_pos, obstacle, blocking_radius)
        for obstacle in OBSTACLES
    )

def inside_obstacle(pos):
    """Return True if a dot centred at ``pos`` touches or overlaps any obstacle.

    Contact is declared when the centre-to-centre distance is at most
    ``OBSTACLE_RADIUS + DOT_RADIUS``.
    """
    contact_distance = OBSTACLE_RADIUS + DOT_RADIUS
    return any(
        distance(pos, obstacle) <= contact_distance
        for obstacle in OBSTACLES
    )

def potential_field_dir(pos, goal):
    """Return a unit direction from ``pos`` combining goal attraction and obstacle repulsion.

    The attractive term is the unit vector towards ``goal``. Every obstacle
    in ``OBSTACLES`` closer than ``150.0 * SCALING_FACTOR`` adds a push
    directly away from it with magnitude ``15000.0 / distance**2``.

    Returns:
        A float32 array of shape (2,). It is the zero vector when ``pos`` is
        within 1e-6 of ``goal`` or when the combined vector has a length
        below 1e-6.
    """
    to_goal_x = goal[0] - pos[0]
    to_goal_y = goal[1] - pos[1]
    goal_dist = math.hypot(to_goal_x, to_goal_y)
    if goal_dist < 1e-6:
        return np.zeros(2, dtype=np.float32)
    attract = np.array(
        [to_goal_x / goal_dist, to_goal_y / goal_dist], dtype=np.float32
    )

    influence_radius = 150.0 * SCALING_FACTOR
    gain = 15000.0
    push_total_x = 0.0
    push_total_y = 0.0

    for obstacle in OBSTACLES:
        away_x = pos[0] - obstacle[0]
        away_y = pos[1] - obstacle[1]
        gap = math.hypot(away_x, away_y)
        # Skip obstacles that coincide with the position (no defined
        # direction) or that lie outside the influence radius.
        if not (1e-6 <= gap < influence_radius):
            continue
        unit_x = away_x / gap
        unit_y = away_y / gap
        magnitude = gain / (gap**2)
        push_total_x += unit_x * magnitude
        push_total_y += unit_y * magnitude

    combined_x = attract[0] + push_total_x
    combined_y = attract[1] + push_total_y
    combined_len = math.hypot(combined_x, combined_y)
    if combined_len < 1e-6:
        return np.zeros(2, dtype=np.float32)
    return np.array(
        [combined_x / combined_len, combined_y / combined_len], dtype=np.float32
    )

###############################################################################
# CALLBACK
###############################################################################
class MetricsCallback(BaseCallback):
    """Training callback recording per-episode reward, mean gamma and collisions.

    Only index 0 of the vectorised ``actions`` / ``rewards`` / ``dones`` /
    ``infos`` entries in ``self.locals`` is read, so the metrics describe the
    first environment only.

    The raw action is clipped to the action-space bounds and mapped to
    ``gamma = 0.5 * (action + 1)``, mirroring what the environment does.
    """

    def __init__(self, verbose=0):
        super().__init__(verbose)
        self.episode_rewards = []          # total reward of each finished episode
        self.episode_gammas  = []          # mean gamma of each finished episode
        self.current_episode_gammas = []   # gammas of the episode in progress
        self.total_reward = 0.0            # running reward of the episode in progress

        self.n_collisions = 0              # episodes whose terminal_reason was "collision"
        self.n_episodes   = 0              # finished episodes
        self.action_low  = None            # cached on the first step
        self.action_high = None

    def _on_training_start(self):
        """Reset the episode and collision counters at the start of training."""
        self.n_collisions = 0
        self.n_episodes   = 0

    def _on_step(self) -> bool:
        """Accumulate reward and gamma for the current step; close out finished episodes.

        Returns:
            Always True, so training is never stopped by this callback.
        """
        if self.action_low is None:
            self.action_low  = float(self.model.action_space.low)
            self.action_high = float(self.model.action_space.high)

        actions = self.locals['actions']
        rewards = self.locals['rewards']
        done    = self.locals['dones'][0]
        infos   = self.locals['infos']

        raw_action = float(actions[0])
        raw_action = max(self.action_low, min(self.action_high, raw_action))
        gamma = 0.5*(raw_action + 1.0)

        self.total_reward += float(rewards[0])

        # Record gamma for EVERY step, including the one that ends the
        # episode. Previously the terminal step's gamma was dropped, which
        # biased the episode mean and made one-step episodes report 0.0.
        self.current_episode_gammas.append(gamma)

        if done:
            self.episode_rewards.append(self.total_reward)
            self.episode_gammas.append(float(np.mean(self.current_episode_gammas)))
            self.total_reward = 0.0
            self.current_episode_gammas.clear()
            self.n_episodes += 1

            if infos[0].get("terminal_reason") == "collision":
                self.n_collisions += 1

        return True

    @staticmethod
    def _save_line_plot(values, label, ylabel, title, path):
        """Plot ``values`` against their index (episode) and write the figure to ``path``."""
        plt.figure(figsize=(8,5))
        plt.plot(values, label=label)
        plt.xlabel("Episode")
        plt.ylabel(ylabel)
        plt.title(title)
        plt.grid(True)
        plt.legend()
        plt.savefig(path)
        plt.close()

    def save_metrics(self, save_dir="training_metrics"):
        """Write reward/gamma plots and a text summary into ``save_dir``.

        Creates ``save_dir`` if needed and writes ``episode_reward.png``,
        ``average_gamma.png`` and ``training_summary.txt``.
        """
        os.makedirs(save_dir, exist_ok=True)

        self._save_line_plot(
            self.episode_rewards, "Episode Reward", "Reward",
            "Episode Reward Over Time",
            os.path.join(save_dir, "episode_reward.png"),
        )
        self._save_line_plot(
            self.episode_gammas, "Average Gamma", "Gamma",
            "Average Gamma per Episode",
            os.path.join(save_dir, "average_gamma.png"),
        )

        summary_path = os.path.join(save_dir, "training_summary.txt")
        with open(summary_path, 'w', encoding="utf-8") as f:
            f.write(f"Total Episodes: {len(self.episode_rewards)}\n")
            # Means are undefined with no finished episode, so skip them.
            if self.episode_rewards:
                f.write(f"Mean Reward: {np.mean(self.episode_rewards):.3f}\n")
                f.write(f"Mean Gamma: {np.mean(self.episode_gammas):.3f}\n")
            f.write(f"Collisions Count: {self.n_collisions}\n")

###############################################################################
# ENV: TINY THRESHOLD => +alpha*g if within 10 units, else -beta*g
###############################################################################
class DemoArbitrationEnv(gym.Env):
    """
    Arbitration environment: the action sets the blend ``gamma`` between the
    potential-field direction (``w_dir``) and a noisy "human" direction
    (``h_dir``).

    The dot must be extremely close to the goal (dist<10) to get +alpha*g.
    Otherwise it sees -beta*g => big punish for high gamma away from the goal.
    Collisions => -2 end the episode.

    This ensures the agent basically never picks gamma=1 unless physically
    hugging the goal center.

    Observation (9 float32 values):
        [dot_x, dot_y, h_dir_x, h_dir_y, goal_x, goal_y, w_dir_x, w_dir_y, dist_ratio]
    The ``h_dir`` / ``w_dir`` in the observation are exactly the directions
    the next ``step`` call will blend.

    Action: a scalar in [-1, 1], mapped to ``gamma = 0.5 * (a + 1)``.

    All randomness is drawn from ``self.np_random`` so ``reset(seed=...)``
    makes episodes reproducible.
    """
    metadata = {"render_modes":["human"], "render_fps":RENDER_FPS}

    def __init__(self, visualize=False):
        """Build the spaces and, if ``visualize`` is set and pygame imported, open a window."""
        super().__init__()
        self.visualize= visualize

        low  = np.array([0,0, -1,-1, 0,0, -1,-1, 0], dtype=np.float32)
        high = np.array([
            FULL_VIEW_SIZE[0], FULL_VIEW_SIZE[1],
            1,1,
            FULL_VIEW_SIZE[0], FULL_VIEW_SIZE[1],
            1,1,
            1
        ], dtype=np.float32)
        self.observation_space = spaces.Box(low=low, high=high, shape=(9,), dtype=np.float32)

        self.action_space = spaces.Box(low=-1.0, high=1.0, shape=(), dtype=np.float32)

        self.dot_pos = None
        self.goal_pos= None
        self.step_count=0
        self.max_steps=300
        self.episode_reward=0.0

        # Directions for the CURRENT state; refreshed by reset() and after
        # every move so the observation and the next step agree.
        self._w_dir = None
        self._h_dir = None

        # View diagonal, used to normalise the goal distance in the observation.
        self.max_dist = math.hypot(FULL_VIEW_SIZE[0], FULL_VIEW_SIZE[1])

        # extremely small threshold => basically on top of goal
        self.close_threshold = 10.0
        self.alpha = 3.0
        self.beta  = 3.0

        if self.visualize and pygame is not None:
            pygame.init()
            self.window = pygame.display.set_mode(FULL_VIEW_SIZE)
            pygame.display.set_caption("Extremely Close => High Gamma")
            self.clock = pygame.time.Clock()
        else:
            self.window = None
            self.clock  = None

    def reset(self, seed=None, options=None):
        """Start a new episode at ``START_POS`` with a randomly chosen goal.

        Returns:
            ``(observation, info)`` where ``info`` is an empty dict.
        """
        super().reset(seed=seed)
        self.step_count=0
        self.episode_reward=0.0

        self.dot_pos = START_POS.copy()
        # Use the env's own seeded generator; the global np.random ignored `seed`.
        idx = int(self.np_random.integers(len(GOALS)))
        self.goal_pos= np.array(GOALS[idx], dtype=np.float32)

        self._update_directions()
        return self._get_obs(), {}

    def _update_directions(self):
        """Recompute and cache ``w_dir`` and a freshly sampled ``h_dir`` for the current state."""
        w_dir = potential_field_dir(self.dot_pos, self.goal_pos)
        noise = self.np_random.normal(0, NOISE_MAGNITUDE, size=2)
        h_dir = w_dir + noise
        hm = np.hypot(h_dir[0], h_dir[1])
        if hm > 1e-6:
            h_dir /= hm
        self._w_dir = w_dir
        self._h_dir = h_dir

    def step(self, action):
        """Blend the two directions with ``gamma``, move the dot and compute the reward.

        Returns:
            ``(obs, reward, terminated, truncated, info)``. ``terminated`` is
            True on collision, ``truncated`` once ``max_steps`` is reached.
            ``info["terminal_reason"]`` is "collision", "timeout" or None.
        """
        raw_a = float(action)
        raw_a = np.clip(raw_a, -1.0, 1.0)
        gamma = 0.5*(raw_a+1.0)

        self.step_count +=1

        # Blend the directions the agent was shown in its last observation
        # (previously a second, unrelated noise sample was drawn here).
        if self._w_dir is None or self._h_dir is None:
            self._update_directions()
        w_dir = self._w_dir
        h_dir = self._h_dir

        c_dir = gamma*w_dir + (1-gamma)*h_dir
        cm= np.hypot(c_dir[0], c_dir[1])
        if cm>1e-6:
            c_dir/=cm

        move_vec = c_dir*MAX_SPEED
        new_pos  = self.dot_pos + move_vec
        # A move whose path passes too close to an obstacle is rejected.
        if not line_collision(self.dot_pos,new_pos):
            new_pos[0] = np.clip(new_pos[0],0,FULL_VIEW_SIZE[0])
            new_pos[1] = np.clip(new_pos[1],0,FULL_VIEW_SIZE[1])
            self.dot_pos= new_pos

        collided= inside_obstacle(self.dot_pos)
        info={}
        if collided:
            original_reward= -2.0
            done= True
            info["terminal_reason"]="collision"
        else:
            original_reward=0.0
            done=False
            info["terminal_reason"]=None

        truncated= (self.step_count>=self.max_steps)
        if truncated and not done:
            info["terminal_reason"] = "timeout"

        # If dist<10 => +alpha*g, else => -beta*g
        dist_g= distance(self.dot_pos,self.goal_pos)
        if dist_g< self.close_threshold:
            shaping_reward= self.alpha*gamma
        else:
            shaping_reward= -self.beta*gamma

        reward= original_reward + shaping_reward
        self.episode_reward+= reward

        if self.visualize:
            self._render()

        # Sample the directions for the new state once; the observation and
        # the next step both use them.
        self._update_directions()
        obs= self._get_obs()
        return obs, float(reward), done, truncated, info

    def _get_obs(self):
        """Assemble the 9-element float32 observation from the current state."""
        if self._w_dir is None or self._h_dir is None:
            self._update_directions()

        to_g= self.goal_pos - self.dot_pos
        d= math.hypot(to_g[0],to_g[1])
        dist_ratio= d/self.max_dist if self.max_dist>1e-6 else 0.0

        obs= np.concatenate([
            self.dot_pos,
            self._h_dir,
            self.goal_pos,
            self._w_dir,
            [dist_ratio]
        ]).astype(np.float32)
        return obs

    def _render(self):
        """Draw obstacles, goals and the dot; a no-op when no window was created."""
        if self.window is None or pygame is None:
            return

        # Service the event queue so the OS does not flag the window as
        # unresponsive during long training runs.
        pygame.event.pump()

        self.window.fill(WHITE)
        # obstacles
        for obs in OBSTACLES:
            pygame.draw.circle(self.window, GRAY, (int(obs[0]),int(obs[1])),
                               OBSTACLE_RADIUS)
        # all goals
        for gpos in GOALS:
            pygame.draw.circle(self.window, YELLOW, (int(gpos[0]),int(gpos[1])),
                               TARGET_RADIUS)
        # highlight current
        pygame.draw.circle(self.window, BLACK,(int(self.goal_pos[0]),
                                               int(self.goal_pos[1])),
                          TARGET_RADIUS+2,width=2)

        # dot
        pygame.draw.circle(self.window, BLACK,(int(self.dot_pos[0]),
                                               int(self.dot_pos[1])),
                           DOT_RADIUS,width=2)

        pygame.display.flip()
        self.clock.tick(RENDER_FPS)

    def close(self):
        """Shut pygame down (when it was started) and release the window handles."""
        if self.visualize and pygame is not None:
            pygame.quit()
            # Drop the handles so a later _render() call is a harmless no-op.
            self.window = None
            self.clock = None
        super().close()

###############################################################################
# TRAIN
###############################################################################
def train(visualize=False, total_timesteps=300_000):
    """Train PPO on ``DemoArbitrationEnv`` and save the model and metrics.

    Args:
        visualize: Passed to the environment; enables the pygame window.
        total_timesteps: Number of environment steps passed to ``model.learn``.

    Side effects:
        Writes ``trained_models/extreme_close_gamma_ppo.zip`` and the files
        produced by ``MetricsCallback.save_metrics`` under
        ``training_metrics/``. A ``KeyboardInterrupt`` during learning is
        caught so the partial model and metrics are still saved.
    """
    from stable_baselines3.common.callbacks import CallbackList

    env = DemoArbitrationEnv(visualize=visualize)
    try:
        metrics_callback = MetricsCallback()
        callback = CallbackList([metrics_callback])

        model = PPO(
            policy="MlpPolicy",
            env=env,
            learning_rate=3e-4,
            n_steps=1024,
            batch_size=1024,
            n_epochs=4,
            gamma=0.99,
            gae_lambda=0.95,
            clip_range=0.2,
            verbose=1
        )

        try:
            print(f"Starting PPO training (visualize={visualize}) ...")
            model.learn(total_timesteps=total_timesteps, callback=callback)
        except KeyboardInterrupt:
            print("Training interrupted; saving partial model...")

        # Save INTO the directory we create; previously the model was written
        # to the working directory while the message claimed trained_models/.
        model_dir = "trained_models"
        os.makedirs(model_dir, exist_ok=True)
        model.save(os.path.join(model_dir, "extreme_close_gamma_ppo"))
        print("Model saved to trained_models/extreme_close_gamma_ppo.zip")

        metrics_callback.save_metrics("training_metrics")
        print("Metrics saved in training_metrics/")
    finally:
        # Always release the environment (and the pygame window), even when
        # model construction, training or saving raises.
        env.close()

###############################################################################
# MAIN
###############################################################################
if __name__ == "__main__":
    import sys

    # The pygame window is enabled only when the first command-line argument
    # is "visualize" (compared case-insensitively).
    cli_args = sys.argv[1:]
    vis = bool(cli_args) and cli_args[0].lower() == "visualize"
    train(visualize=vis, total_timesteps=300_000)



Using cuda device
Wrapping the env with a `Monitor` wrapper
Wrapping the env in a DummyVecEnv.
Starting PPO training (visualize=False) ...
---------------------------------
| rollout/           |          |
|    ep_len_mean     | 300      |
|    ep_rew_mean     | 30.6     |
| time/              |          |
|    fps             | 701      |
|    iterations      | 1        |
|    time_elapsed    | 1        |
|    total_timesteps | 1024     |
---------------------------------
-------------------------------------------
| rollout/                |               |
|    ep_len_mean          | 300           |
|    ep_rew_mean          | -116          |
| time/                   |               |
|    fps                  | 708           |
|    iterations           | 2             |
|    time_elapsed         | 2             |
|    total_timesteps      | 2048          |
| train/                  |               |
|    approx_kl            | 0.00019927562 |
|    clip_fraction        | 0        