In [1]:
import os
import numpy as np
import gymnasium as gym
import pandas as pd
import matplotlib.pyplot as plt
import gc
import torch
from sbx import DDPG, SAC, TD3, CrossQ
from stable_baselines3.common.monitor import Monitor
from stable_baselines3.common.logger import configure

# Gymnasium environment IDs to benchmark, visited in this order.
environments = [
    "Pendulum-v1",
    "HalfCheetah-v5",
    "Hopper-v5",
    "Humanoid-v5",
]

# Default training budget per (environment, algorithm) run, in env steps.
total_timesteps = 20000

# Constructor keyword arguments for each algorithm, keyed by algorithm name.
# Insertion order determines the order in which algorithms are trained.
hyperparameter_sets = dict(
    SAC=dict(
        learning_rate=3e-4,
        batch_size=256,
        tau=0.005,
        gamma=0.99,
        ent_coef="auto",
    ),
    TD3=dict(
        learning_rate=1e-3,
        batch_size=100,
        tau=0.005,
        gamma=0.99,
        policy_delay=2,
    ),
    DDPG=dict(
        learning_rate=1e-3,
        batch_size=64,
        tau=0.005,
        gamma=0.99,
    ),
    # No "tau" entry for CrossQ.
    CrossQ=dict(
        learning_rate=3e-4,
        batch_size=64,
        gamma=0.99,
    ),
)

# Algorithm name -> model class. Hoisted out of the loop since it never changes.
_MODEL_CLASSES = {"SAC": SAC, "TD3": TD3, "DDPG": DDPG, "CrossQ": CrossQ}
# Per-algorithm exceptions to the default budget / device.
_TIMESTEP_OVERRIDES = {"CrossQ": 10_000}
_DEVICE_OVERRIDES = {"CrossQ": "cpu"}


def _plot_learning_curve(monitor_file, algo_name, env_name, params):
    """Plot a 10-episode rolling-mean reward curve from a Monitor CSV and save it.

    Prints an error and returns without plotting when the monitor file is
    missing, has no ``r`` column, or holds fewer than two episodes.
    The PNG is written to ``./plots/<env_name>/<algo_name>_learning_curve.png``.
    """
    if not os.path.exists(monitor_file):
        print(f"ERROR: no log file found for {algo_name} on {env_name}, skipping plot")
        return

    # skiprows=1 skips the metadata line the Monitor writes before the CSV header.
    df = pd.read_csv(monitor_file, skiprows=1)
    if "r" not in df.columns or len(df) <= 1:
        print(f"ERROR: no valid rewards found for {algo_name} on {env_name}, skipping plot")
        return

    rewards = df["r"].rolling(10, min_periods=1).mean()
    episode_numbers = range(1, len(rewards) + 1)

    plot_dir = f"./plots/{env_name}/"
    os.makedirs(plot_dir, exist_ok=True)
    plot_path = os.path.join(plot_dir, f"{algo_name}_learning_curve.png")

    fig, ax = plt.subplots(figsize=(8, 5))
    try:
        ax.plot(episode_numbers, rewards, label="Mean Reward (Rolling Avg)", color='b')
        ax.set_xlabel("Episodes")
        ax.set_ylabel("Reward")
        ax.set_title(f"{algo_name} Training - {env_name}\n"
                     f"LR: {params.get('learning_rate', '?')}, "
                     f"Gamma: {params.get('gamma', '?')}, "
                     f"Batch: {params.get('batch_size', '?')}, "
                     f"Tau: {params.get('tau', 'N/A')}")
        ax.legend()
        ax.grid()
        fig.savefig(plot_path)
    finally:
        # Always release the figure, even if saving fails.
        plt.close(fig)

    print(f"plot saved: {plot_path}")


# Train every algorithm on every environment, then plot each run's learning curve.
for env_name in environments:
    for algo_name, params in hyperparameter_sets.items():
        model_cls = _MODEL_CLASSES.get(algo_name)
        if model_cls is None:
            # Unknown algorithm: skip before allocating an env or a logger.
            continue

        log_dir = f"./logs/{env_name}/"
        os.makedirs(log_dir, exist_ok=True)
        monitor_file = os.path.join(log_dir, f"{algo_name}.monitor.csv")

        adjusted_timesteps = _TIMESTEP_OVERRIDES.get(algo_name, total_timesteps)
        device = _DEVICE_OVERRIDES.get(algo_name, "auto")

        env = Monitor(gym.make(env_name), monitor_file)
        # Give each algorithm its own logger folder. configure() truncates
        # progress.csv / log.txt in its folder, so a shared folder would keep
        # only the last algorithm's training log.
        new_logger = configure(folder=os.path.join(log_dir, algo_name))

        model = None
        try:
            model = model_cls("MlpPolicy", env, **params, verbose=0, seed=50, device=device)
            model.set_logger(new_logger)
            model.learn(
                total_timesteps=adjusted_timesteps,
                log_interval=1,
                progress_bar=False
            )
        finally:
            # Release the env (flushes/closes the monitor CSV) and the logger's
            # output files even when training fails.
            env.close()
            new_logger.close()
            del model, env
            gc.collect()
            if torch.cuda.is_available():
                torch.cuda.empty_cache()

        _plot_learning_curve(monitor_file, algo_name, env_name, params)

print("training done")


Training SAC on Pendulum-v1 with {'learning_rate': 0.0003, 'batch_size': 256, 'tau': 0.005, 'gamma': 0.99, 'ent_coef': 'auto'}...

Logging to ./logs/Pendulum-v1/
----------------------------------
| rollout/           |           |
|    ep_len_mean     | 200       |
|    ep_rew_mean     | -1.28e+03 |
| time/              |           |
|    episodes        | 1         |
|    fps             | 81        |
|    time_elapsed    | 2         |
|    total_timesteps | 200       |
| train/             |           |
|    actor_loss      | 8.29      |
|    critic_loss     | 1.19      |
|    ent_coef        | 0.972     |
|    ent_coef_loss   | -0.0473   |
|    n_updates       | 99        |
----------------------------------
----------------------------------
| rollout/           |           |
|    ep_len_mean     | 200       |
|    ep_rew_mean     | -1.25e+03 |
| time/              |           |
|    episodes        | 2         |
|    fps             | 90        |
|    time_elapsed    | 4        



---------------------------------
| rollout/           |          |
|    ep_len_mean     | 18       |
|    ep_rew_mean     | 83.2     |
| time/              |          |
|    episodes        | 1        |
|    fps             | 844      |
|    time_elapsed    | 0        |
|    total_timesteps | 18       |
---------------------------------
---------------------------------
| rollout/           |          |
|    ep_len_mean     | 23.5     |
|    ep_rew_mean     | 113      |
| time/              |          |
|    episodes        | 2        |
|    fps             | 958      |
|    time_elapsed    | 0        |
|    total_timesteps | 47       |
---------------------------------
---------------------------------
| rollout/           |          |
|    ep_len_mean     | 27.7     |
|    ep_rew_mean     | 132      |
| time/              |          |
|    episodes        | 3        |
|    fps             | 1079     |
|    time_elapsed    | 0        |
|    total_timesteps | 83       |
--------------



---------------------------------
| rollout/           |          |
|    ep_len_mean     | 25.5     |
|    ep_rew_mean     | 121      |
| time/              |          |
|    episodes        | 4        |
|    fps             | 101      |
|    time_elapsed    | 1        |
|    total_timesteps | 102      |
| train/             |          |
|    actor_loss      | 0        |
|    critic_loss     | 207      |
|    n_updates       | 1        |
---------------------------------
---------------------------------
| rollout/           |          |
|    ep_len_mean     | 23.8     |
|    ep_rew_mean     | 112      |
| time/              |          |
|    episodes        | 5        |
|    fps             | 56       |
|    time_elapsed    | 2        |
|    total_timesteps | 119      |
| train/             |          |
|    actor_loss      | -4.24    |
|    critic_loss     | 28.1     |
|    n_updates       | 18       |
---------------------------------
---------------------------------
| rollout/    



---------------------------------
| rollout/           |          |
|    ep_len_mean     | 18       |
|    ep_rew_mean     | 83.2     |
| time/              |          |
|    episodes        | 1        |
|    fps             | 1494     |
|    time_elapsed    | 0        |
|    total_timesteps | 18       |
---------------------------------
---------------------------------
| rollout/           |          |
|    ep_len_mean     | 23.5     |
|    ep_rew_mean     | 113      |
| time/              |          |
|    episodes        | 2        |
|    fps             | 1242     |
|    time_elapsed    | 0        |
|    total_timesteps | 47       |
---------------------------------
---------------------------------
| rollout/           |          |
|    ep_len_mean     | 27.7     |
|    ep_rew_mean     | 132      |
| time/              |          |
|    episodes        | 3        |
|    fps             | 1193     |
|    time_elapsed    | 0        |
|    total_timesteps | 83       |
--------------



---------------------------------
| rollout/           |          |
|    ep_len_mean     | 18       |
|    ep_rew_mean     | 83.2     |
| time/              |          |
|    episodes        | 1        |
|    fps             | 1674     |
|    time_elapsed    | 0        |
|    total_timesteps | 18       |
---------------------------------
---------------------------------
| rollout/           |          |
|    ep_len_mean     | 23.5     |
|    ep_rew_mean     | 113      |
| time/              |          |
|    episodes        | 2        |
|    fps             | 1427     |
|    time_elapsed    | 0        |
|    total_timesteps | 47       |
---------------------------------
---------------------------------
| rollout/           |          |
|    ep_len_mean     | 27.7     |
|    ep_rew_mean     | 132      |
| time/              |          |
|    episodes        | 3        |
|    fps             | 1329     |
|    time_elapsed    | 0        |
|    total_timesteps | 83       |
--------------

---