In [1]:
from functools import partial

import gymnasium as gym
import numpy as np
import torch.nn as nn
from gymnasium.wrappers import (
    FlattenObservation,
    FrameStack,
    RecordVideo,
    RescaleAction,
    TimeLimit,
)
from rl_zoo3 import linear_schedule
from stable_baselines3 import PPO
from stable_baselines3.common.callbacks import EvalCallback
from stable_baselines3.common.monitor import Monitor
from stable_baselines3.common.vec_env import DummyVecEnv, SubprocVecEnv, VecNormalize
from wandb.integration.sb3 import WandbCallback

import wandb
from src.environments import ea
from src.train.ea_ppo import make_env, train
from src.utils import save_config
from src.wrappers import (
    LogTaskStatistics,
    PlotEpisode,
    PolishedDonkeyReward,
    RescaleObservation,
)

In [None]:
# Settings for one PPO training run; this dictionary is handed to `train`
# in a later cell. Keys are grouped by the part of the pipeline they address.
config = dict(
    # --- Environment ---
    action_mode="delta",
    max_quad_setting=30.0,
    max_quad_delta=30.0,
    max_steerer_delta=6.1782e-3,
    magnet_init_mode=np.array([10.0, -10.0, 0.0, 10.0, 0.0]),
    incoming_mode="random",
    misalignment_mode="random",
    max_misalignment=5e-4,
    target_beam_mode=np.zeros(4),
    threshold_hold=1,
    clip_magnets=True,
    # --- Reward (part of the environment as well) ---
    beam_param_transform="ClippedLinear",
    beam_param_combiner="Mean",
    beam_param_combiner_args={},
    beam_param_combiner_weights=[1, 1, 1, 1],
    magnet_change_transform="Sigmoid",
    magnet_change_combiner="Mean",
    magnet_change_combiner_args={},
    magnet_change_combiner_weights=[1, 1, 1, 1, 1],
    final_combiner="Mean",
    final_combiner_args={},
    final_combiner_weights=[3, 0.5, 0.5],
    # --- Wrappers ---
    frame_stack=1,  # a value of 1 disables frame stacking
    normalize_observation=True,
    running_obs_norm=False,
    normalize_reward=False,  # reward is already normalised by design
    rescale_action=True,
    target_threshold=None,  # estimated screen resolution is 2e-5 m
    max_episode_steps=50,
    polished_donkey_reward=False,
    # --- RL algorithm ---
    batch_size=64,
    learning_rate=0.0003,
    lr_schedule="constant",  # either "constant" or "linear"
    gamma=0.99,
    n_envs=40,
    n_steps=64,
    ent_coef=0.0,
    n_epochs=10,
    gae_lambda=0.95,
    clip_range=0.2,
    clip_range_vf=None,
    vf_coef=0.5,
    max_grad_norm=0.5,
    use_sde=False,
    sde_sample_freq=-1,
    target_kl=None,
    total_timesteps=500_000,
    # --- Policy ---
    net_arch="small",  # either "small" or "medium"
    activation_fn="Tanh",  # one of Tanh, ReLU, GELU
    ortho_init=True,  # True or False
    log_std_init=0.0,
    # --- SB3 config ---
    sb3_device="auto",
    vec_env="subproc",
)

In [3]:
# Launch a training run with the settings collected in `config`.
# `train` is imported from src.train.ea_ppo. The output recorded for this cell
# shows a Weights & Biases login followed by Stable-Baselines3 PPO progress
# tables written under log/<run-name>/.
train(config)

[34m[1mwandb[0m: Using wandb-core as the SDK backend.  Please refer to https://wandb.me/wandb-core for more information.
[34m[1mwandb[0m: Currently logged in as: [33mjank324[0m ([33mmsk-ipc[0m) to [32mhttps://api.wandb.ai[0m. Use [1m`wandb login --relogin`[0m to force relogin


Using cpu device
Logging to log/deft-pond-68/PPO_1




---------------------------------
| rollout/           |          |
|    ep_len_mean     | 50       |
|    ep_rew_mean     | -0.388   |
| time/              |          |
|    fps             | 1138     |
|    iterations      | 1        |
|    time_elapsed    | 2        |
|    total_timesteps | 2560     |
---------------------------------
-----------------------------------------
| rollout/                |             |
|    ep_len_mean          | 50          |
|    ep_rew_mean          | -0.397      |
| time/                   |             |
|    fps                  | 1381        |
|    iterations           | 2           |
|    time_elapsed         | 3           |
|    total_timesteps      | 5120        |
| train/                  |             |
|    approx_kl            | 0.019554382 |
|    clip_fraction        | 0.292       |
|    clip_range           | 0.2         |
|    entropy_loss         | 4.41        |
|    explained_variance   | 0.152       |
|    learning_rate        | 0.

  logger.warn(
  logger.warn(
