# Learning to balance

In [13]:
import gymnasium as gym
from gymnasium.wrappers import RecordVideo, TimeLimit
from stable_baselines3 import PPO
import numpy as np
import os
from pathlib import Path
from tqdm import trange


In [2]:
# Locate the project root: the directory that contains the "notebooks" folder.
# When the kernel is started from a subfolder (e.g. VS Code), the working
# directory has no "notebooks" entry, so the root is taken to be one level up.
current = Path.cwd()
PROJECT_ROOT = current if (current / "notebooks").exists() else current.parent

# Make every relative path below resolve against the project root.
os.chdir(PROJECT_ROOT)

# Folder (relative to the project root) where models and datasets are stored.
DATA_DIR = Path("data")
DATA_DIR.mkdir(exist_ok=True)

## Set up the environment

In [12]:
# Episode length cap in environment steps: 1500 steps x 0.02 s per step = 30 seconds.
N_STEPS = 1500

In [None]:
# Headless training environment; episodes are capped at N_STEPS steps.
env = gym.make("CartPole-v1", max_episode_steps=N_STEPS)

In [None]:
# Second environment used only for visual inspection: it renders RGB frames
# and shares the same N_STEPS episode cap as the training environment.
env_gui = gym.make("CartPole-v1", render_mode="rgb_array", max_episode_steps=N_STEPS)

# Wrap it so that episodes are written as video files into the "videos" folder.
env_gui = RecordVideo(
    env_gui,
    video_folder="videos",
    episode_trigger=lambda ep: True,  # always True: record every episode
    name_prefix="balance_demo"
)

  logger.warn(


## Helpers

In [9]:
def rollout_episode(env, model, max_steps=1500, deterministic=True):
    """Run a single episode with `model` acting in `env` and log the trajectory.

    Parameters
    ----------
    env : environment whose `reset()` returns `(obs, info)` and whose
        `step(action)` returns `(obs, reward, terminated, truncated, info)`.
    model : policy object whose `predict(obs, deterministic=...)` returns
        `(action, state)`.
    max_steps : upper bound on the number of steps taken.
    deterministic : forwarded unchanged to `model.predict`.

    Returns
    -------
    (states, actions, rewards) : three NumPy arrays with one entry per step
        taken. `states[i]` is the observation the policy acted on at step i;
        the observation returned by the last `env.step` is not included.
    """
    state_log, action_log, reward_log = [], [], []

    observation, _info = env.reset()
    for _step in range(max_steps):
        # Store a copy of the observation the policy is about to act on.
        state_log.append(observation.copy())

        chosen_action, _policy_state = model.predict(
            observation, deterministic=deterministic
        )
        action_log.append(chosen_action)

        observation, step_reward, terminated, truncated, _info = env.step(chosen_action)
        reward_log.append(step_reward)

        # Stop as soon as the environment reports the end of the episode.
        if terminated or truncated:
            break

    return tuple(np.array(log) for log in (state_log, action_log, reward_log))


In [10]:
def rollout_episode_video(env, model, max_steps=1500, deterministic=True):
    """Play one episode with `model` in `env` without collecting any data.

    Intended for an environment wrapped for video recording: the function
    only drives the environment and returns nothing.

    Parameters
    ----------
    env : environment whose `reset()` returns `(obs, info)` and whose
        `step(action)` returns `(obs, reward, terminated, truncated, info)`.
    model : policy object whose `predict(obs, deterministic=...)` returns
        `(action, state)`.
    max_steps : upper bound on the number of steps taken.
    deterministic : forwarded unchanged to `model.predict`.
    """
    observation, _info = env.reset()
    steps_taken = 0

    while steps_taken < max_steps:
        chosen_action, _policy_state = model.predict(
            observation, deterministic=deterministic
        )
        observation, _reward, terminated, truncated, _info = env.step(chosen_action)
        steps_taken += 1

        # The episode is over once the environment terminates or truncates it.
        if terminated or truncated:
            break


## Getting a good policy for data generation

In [None]:
# PPO agent with the "MlpPolicy" policy, bound to the training environment.
# NOTE(review): no seed is passed here, so training results may differ between runs.
model = PPO("MlpPolicy", env, verbose=0)

In [14]:
# Train for 50,000 environment timesteps, showing a progress bar.
model.learn(total_timesteps=50000, progress_bar=True)

Output()

<stable_baselines3.ppo.ppo.PPO at 0x7dec09efcc20>

In [11]:
# Play one episode of the trained policy through the video-recording environment,
# then close that environment.
rollout_episode_video(env_gui, model)
env_gui.close()
# The recording is expected in the "videos" folder (the video_folder given to RecordVideo).

In [None]:
# Persist the trained policy under the same base name that the loading cell
# reads ("ppo_cartpole_balance"), so that a fresh Restart & Run All saves and
# then reloads the very same file instead of a stale or missing one.
MODEL_PATH = DATA_DIR / "ppo_cartpole_balance"

model.save(MODEL_PATH)
print("Model saved to:", MODEL_PATH)

Model saved to: data/ppo_cartpole_balance


In [7]:
# Reload the saved policy from DATA_DIR, passing `env` as its environment,
# and rebind `model` to the loaded instance.
MODEL_PATH = DATA_DIR / "ppo_cartpole_balance"
model = PPO.load(MODEL_PATH, env)

The policy learned to stabilize the pole. During the first 2 seconds, however, the cart drifts slightly away. We should filter this out of the training data.

## Record the data

In [8]:
DT = 0.02           # CartPole timestep in seconds (50 Hz)
RECORD_TIME = 30.0  # seconds (paper)
# round() rather than int(): a float quotient can land just below the intended
# integer (e.g. 0.3 / 0.1 == 2.9999999999999996), and int() would then drop a step.
N_RECORD = round(RECORD_TIME / DT)

WARMUP_TIME = 4.0   # seconds to let PPO stabilize before recording
N_WARMUP = round(WARMUP_TIME / DT)

print("Warmup steps:", N_WARMUP)
print("Recording steps:", N_RECORD)

Warmup steps: 200
Recording steps: 1500


In [24]:
def record_30s_episodes_with_data(
    model,
    n_episodes=5,                 # number of complete episodes to record
    video_dir="videos",
    data_dir="data",
    filename="ppo_balance_dataset",
    env_id="CartPole-v1",
    dt=0.02,
    n_steps=None,
):
    """
    Record full-length episodes of a trained policy and save them to one .npz file.

    A fresh video-recording environment is created for `env_id`, capped at
    `n_steps` steps per episode. The policy is rolled out deterministically;
    an episode is kept only if it runs for exactly `n_steps` steps without
    terminating. Episodes that end early are discarded and retried until
    `n_episodes` complete episodes have been collected.

    Notes
    -----
    - Every attempted episode is video-recorded (the trigger is always true),
      including the discarded ones.
    - There is no retry limit: a policy that never survives `n_steps` steps
      keeps this function looping until it is interrupted.
    - With the notebook defaults (N_STEPS = 1500, dt = 0.02) one episode
      corresponds to 30 seconds. `dt` is only stored as metadata.

    Parameters
    ----------
    model : policy whose `predict(obs, deterministic=True)` returns
        `(action, state)`; the action must be convertible with `int()`.
    n_episodes : number of complete episodes to collect.
    video_dir : folder for the video files (created if missing).
    data_dir : folder for the .npz dataset (created if missing).
    filename : base name used for both the videos and the .npz file.
    env_id : Gymnasium environment id.
    dt : timestep in seconds, saved alongside the data.
    n_steps : steps per episode; defaults to the notebook-level `N_STEPS`.

    The .npz file contains:
        - observations      : (T_total, obs_dim), obs_dim = 4 for CartPole
        - actions           : (T_total,)
        - episode_ids       : (T_total,)
        - episode_lens      : (n_episodes,)
        - dt
        - env_id
        - n_episodes
        - steps_per_episode

    Returns
    -------
    (observations, actions, episode_ids, episode_lens) as NumPy arrays.
    """
    if n_steps is None:
        # Fall back to the notebook-level episode cap.
        n_steps = N_STEPS

    os.makedirs(video_dir, exist_ok=True)
    os.makedirs(data_dir, exist_ok=True)

    base_env = gym.make(env_id, render_mode="rgb_array", max_episode_steps=n_steps)

    env = RecordVideo(
        base_env,
        video_folder=video_dir,
        episode_trigger=lambda ep: True,
        name_prefix=filename
    )

    all_obs = []
    all_actions = []
    all_episode_ids = []
    episode_lens = []

    ep = 0  # number of successful episodes recorded so far

    # try/finally so the recording environment is closed even when a rollout
    # raises or the (unbounded) retry loop is interrupted.
    try:
        while ep < n_episodes:

            obs, _ = env.reset()
            local_obs = []
            local_actions = []

            success = True

            for step in range(n_steps):
                local_obs.append(obs.copy())

                action, _ = model.predict(obs, deterministic=True)
                local_actions.append(int(action))

                obs, _, terminated, truncated, _ = env.step(action)

                # Any termination counts as a failed episode.
                if terminated:
                    success = False
                    break

                # Truncation is only acceptable on the very last step,
                # where the episode step cap is expected to trigger.
                if truncated and step < n_steps - 1:
                    success = False
                    break

            if success and len(local_obs) == n_steps:

                all_obs.extend(local_obs)
                all_actions.extend(local_actions)
                all_episode_ids.extend([ep] * n_steps)
                episode_lens.append(n_steps)

                print(f"✅ Episode {ep} SUCCESS: {n_steps} steps recorded")
                ep += 1

            else:
                print(f"⚠️ Episode FAILED at {len(local_obs)} steps → retrying...")
    finally:
        env.close()

    all_obs = np.array(all_obs)
    all_actions = np.array(all_actions)
    all_episode_ids = np.array(all_episode_ids)
    episode_lens = np.array(episode_lens)

    save_path = os.path.join(data_dir, f"{filename}.npz")
    np.savez(
        save_path,
        observations=all_obs,
        actions=all_actions,
        episode_ids=all_episode_ids,
        episode_lens=episode_lens,
        dt=dt,
        env_id=env_id,
        n_episodes=n_episodes,
        steps_per_episode=n_steps
    )

    print("FINAL DATASET SAVED")
    print("Observations shape :", all_obs.shape)
    print("Actions shape      :", all_actions.shape)
    print("Episode ids shape  :", all_episode_ids.shape)
    print("Episode lengths    :", episode_lens)
    print("Saved to           :", save_path)

    return all_obs, all_actions, all_episode_ids, episode_lens


In [25]:
# Collect 10 complete episodes with the trained policy.
# X: observations, U: actions, ep_ids: per-step episode index, ep_lens: per-episode length.
# The dataset is written to data/ppo_balance_clean_30s.npz; videos go to "videos".
X, U, ep_ids, ep_lens = record_30s_episodes_with_data(
    model,
    n_episodes=10,
    video_dir="videos",
    data_dir="data",
    filename="ppo_balance_clean_30s"
)

✅ Episode 0 SUCCESS: 1500 steps recorded
✅ Episode 1 SUCCESS: 1500 steps recorded
✅ Episode 2 SUCCESS: 1500 steps recorded
✅ Episode 3 SUCCESS: 1500 steps recorded
✅ Episode 4 SUCCESS: 1500 steps recorded
✅ Episode 5 SUCCESS: 1500 steps recorded
✅ Episode 6 SUCCESS: 1500 steps recorded
✅ Episode 7 SUCCESS: 1500 steps recorded
✅ Episode 8 SUCCESS: 1500 steps recorded
✅ Episode 9 SUCCESS: 1500 steps recorded
FINAL DATASET SAVED
Observations shape : (15000, 4)
Actions shape      : (15000,)
Episode ids shape  : (15000,)
Episode lengths    : [1500 1500 1500 1500 1500 1500 1500 1500 1500 1500]
Saved to           : data/ppo_balance_clean_30s.npz
