# **Stage 1**: Train PPO on PointMaze with standard rewards, collect data, train distance models


In [1]:
import gymnasium as gym
import gymnasium_robotics
from gymnasium.wrappers import RecordVideo
import numpy as np
import torch
from torch.utils.tensorboard import SummaryWriter
import os
import time

import sys
sys.path.insert(0, os.path.abspath(os.path.join(os.getcwd(), '..')))
from src.env_wrappers import GoalObservationWrapper, TerminateOnSuccessWrapper
from src import ppo_agent, distance_models

# Check for GPU.
# The device index and name are only queried when CUDA is actually present:
# torch.cuda.current_device() raises on a machine without a usable CUDA runtime,
# which would abort this cell before any CPU-only training could start.
cuda_available = torch.cuda.is_available()
print("CUDA available:", cuda_available)
if cuda_available:
    current_device = torch.cuda.current_device()
    print("Current device:", current_device)
    print("Device name:", torch.cuda.get_device_name(current_device))
else:
    print("Current device:", "cpu")
    print("Device name:", "CPU")

CUDA available: True
Current device: 0
Device name: NVIDIA GeForce RTX 4080 Laptop GPU


In [2]:
# Easy U-shaped maze: a 5x5 grid whose free cells form a "U" opening to the left.
# Cells marked 'c' are the candidate start/goal positions; which one is used
# for which role is drawn at random on every episode.
# (Gymnasium-Robotics maze convention: 1 = wall, 0 = free cell.)
c = 'c'
example_map = [
    [1] * 5,            # top wall
    [1, c, 0, 0, 1],    # upper corridor
    [1, 1, 1, 0, 1],    # connecting cell on the right
    [1, c, 0, 0, 1],    # lower corridor
    [1] * 5,            # bottom wall
]

In [10]:
# Training configuration.
# NOTE(review): 'PointMaze_UMaze-v3' is the *sparse*-reward registration in
# Gymnasium-Robotics; the dense-reward variant is 'PointMaze_UMazeDense-v3'.
# The id is left unchanged here -- switch it if dense reward was intended.
env_id = 'PointMaze_UMaze-v3'
total_timesteps = 50000   # total environment steps across all PPO updates
steps_per_iter = 1000     # environment steps collected per PPO update
seed = 0
torch.manual_seed(seed); np.random.seed(seed)

# Initialize environment
gym.register_envs(gymnasium_robotics)
env = gym.make(env_id, maze_map=example_map)
# Seed the environment's own RNG as well (start/goal sampling); seeding torch and
# numpy alone does not make rollouts reproducible. This is done on the base env,
# before the project wrappers are applied, so it does not depend on how those
# wrappers implement reset().
env.reset(seed=seed)
env.action_space.seed(seed)
env = GoalObservationWrapper(env)
env = TerminateOnSuccessWrapper(env)
# Flat observation/action sizes used to build the PPO networks
obs_dim = env.observation_space.shape[0]
act_dim = env.action_space.shape[0]

In [12]:
# Initialize PPO agent
agent = ppo_agent.PPOAgent(state_dim=obs_dim, action_dim=act_dim)
writer = SummaryWriter(log_dir="runs/stage1")
global_step = 0
num_updates = total_timesteps // steps_per_iter

# Start the clock *before* the first update. Starting it after update 1 (as before)
# left that update out of `elapsed` while still dividing by the full update count,
# which made the reported ETA too optimistic.
start_time = time.time()
try:
    for update in range(1, num_updates + 1):
        # One PPO iteration: roll out `steps_per_iter` steps, then optimize on them
        traj = agent.collect_trajectory(env, steps_per_iter)
        pg_loss, v_loss, ent_loss = agent.update(traj)

        # Rebuild per-episode returns from the flat reward/done streams.
        # Only episodes that *end* inside this batch are counted; the reward of an
        # episode still running at the end of the batch is dropped.
        rewards = traj["rewards"]; dones = traj["dones"]
        ep_returns = []
        cum_reward = 0.0
        for r, d in zip(rewards, dones):
            cum_reward += r
            if d:
                ep_returns.append(cum_reward)
                cum_reward = 0.0

        # Log training metrics
        if ep_returns:
            writer.add_scalar("charts/episodic_return", np.mean(ep_returns), global_step)
        writer.add_scalar("losses/policy_loss", pg_loss, global_step)
        writer.add_scalar("losses/value_loss", v_loss, global_step)
        writer.add_scalar("losses/entropy", ent_loss, global_step)
        global_step += len(rewards)

        if update % 10 == 0:
            # 0.0 is also printed when no episode finished in this batch
            avg_ret = np.mean(ep_returns) if ep_returns else 0.0
            elapsed = time.time() - start_time
            time_per_update = elapsed / update
            eta = (num_updates - update) * time_per_update
            print(f"Update {update}/{num_updates} [{int(100*update/num_updates):3d}%] | AvgReturn: {avg_ret:.2f} | ETA: {eta/60:.1f} min")
finally:
    # Flush pending TensorBoard events even if training is interrupted
    writer.close()

Update 10/50 [ 20%] | AvgReturn: 0.00 | ETA: 0.7 min
Update 20/50 [ 40%] | AvgReturn: 0.00 | ETA: 0.6 min
Update 30/50 [ 60%] | AvgReturn: 0.00 | ETA: 0.4 min
Update 40/50 [ 80%] | AvgReturn: 0.00 | ETA: 0.2 min
Update 50/50 [100%] | AvgReturn: 0.00 | ETA: 0.0 min


In [None]:
# After training PPO, collect trajectories to train distance models.
# The evaluation env is built with the same maze_map as the training env so the
# collected states come from the same start/goal layout the policy was trained on
# (previously the map was omitted here, unlike the training env).
eval_env = gym.make(env_id, maze_map=example_map)
eval_env = GoalObservationWrapper(eval_env)
eval_env = TerminateOnSuccessWrapper(eval_env)
eval_episodes = 100
sup_states = []       # states from successful episodes
sup_distances = []    # steps-to-goal label for each entry of sup_states
td_transitions = []   # (state, next_state, done, success) tuples from all episodes
success_count = 0
for ep in range(eval_episodes):
    state, _ = eval_env.reset()
    ep_states = [state]
    transitions = []
    done = False
    success = False
    while not done:
        action, _, _ = agent.ac.act(state)  # use trained policy
        next_state, reward, terminated, truncated, info = eval_env.step(action)
        done = terminated or truncated
        success = bool(info.get('success', False))
        transitions.append((state, next_state, done, success))
        state = next_state
        ep_states.append(state)

    # `success` now holds the flag reported on the final step of the episode
    if success:
        success_count += 1
        step_count = len(transitions)
        # ep_states has step_count + 1 entries (initial state .. final state).
        # The label for ep_states[i] is the number of steps still to go,
        # step_count - i, so the final (goal) state gets distance 0.
        sup_states.extend(ep_states)
        sup_distances.extend(float(d) for d in range(step_count, -1, -1))
    # Add all transitions to TD dataset (failures will be handled in training)
    td_transitions.extend(transitions)
print(f"Collected data from {eval_episodes} episodes, {success_count} were successful.")

In [None]:
sup_states = np.array(sup_states, dtype=np.float32)
sup_distances = np.array(sup_distances, dtype=np.float32)

# Save the PPO policy first: it does not depend on the distance models, so it is
# persisted even if the distance-model steps below fail.
os.makedirs("models", exist_ok=True)
torch.save(agent.ac.state_dict(), "models/ppo_agent_stage1.pth")

# Supervised labels only exist for successful episodes. With none, the dataset is
# empty and there is nothing to fit; fail with a clear message instead of training
# on (and saving) a model fitted to no data.
if sup_states.shape[0] == 0:
    raise RuntimeError(
        "No successful episodes were collected, so there is no supervised "
        "distance data to train on. Train the PPO agent longer or collect "
        "more evaluation episodes."
    )

# Train distance estimators on the collected data
sup_model = distance_models.SupervisedDistanceEstimator(input_dim=obs_dim)
sup_loss = sup_model.train_from_data(sup_states, sup_distances, epochs=100)
# td_model = distance_models.TDDistanceEstimator(input_dim=obs_dim)
# td_loss = td_model.train_from_transitions(td_transitions, epochs=100)

# Compare models on the supervised dataset.
# Inference only: no autograd graph is needed; .cpu() is a no-op for CPU tensors.
with torch.no_grad():
    sup_preds = sup_model.model(torch.tensor(sup_states)).detach().cpu().numpy().flatten()
# td_preds = td_model.model(torch.tensor(sup_states)).detach().numpy().flatten()
mse_sup = np.mean((sup_preds - sup_distances)**2)
# mse_td = np.mean((td_preds - sup_distances)**2)
print(f"Supervised model MSE on training data: {mse_sup:.4f}")
# print(f"TD model MSE on training data: {mse_td:.4f}")

# Save models for Stage 2
# torch.save(td_model.state_dict(), "models/distance_model_td.pth")
torch.save(sup_model.state_dict(), "models/distance_model_sup.pth")
# Record a video of the trained agent in PointMaze.
# The same maze_map as the training env is used so the recording shows the task
# the policy was actually trained on.
video_env = gym.make(env_id, maze_map=example_map, render_mode="rgb_array")
video_env = GoalObservationWrapper(video_env)
video_env = TerminateOnSuccessWrapper(video_env)
video_env = RecordVideo(video_env, video_folder="videos/stage1", episode_trigger=lambda eid: True)
try:
    # Roll out a single episode with the trained policy
    vid_obs, _ = video_env.reset()
    done = False
    while not done:
        action, _, _ = agent.ac.act(vid_obs)
        vid_obs, _, terminated, truncated, info = video_env.step(action)
        done = terminated or truncated
finally:
    # Close every environment even if the rollout raises; the recording wrapper
    # is closed here as well so it gets the chance to finish its output.
    video_env.close()
    env.close(); eval_env.close()