In [None]:
from ray import train, tune
from ray.rllib.algorithms.algorithm import Algorithm
from ray.rllib.algorithms.callbacks import DefaultCallbacks
from ray.rllib.algorithms.sac import SACConfig
from ray.tune.registry import register_env

from env_2D.uav2d import UAV2D

In [None]:
def env_creator(config):
    """Build a ``UAV2D`` environment from defaults plus caller overrides.

    Args:
        config: Mapping of environment settings supplied by the caller
            (RLlib hands its ``env_config`` to registered creators).
            Keys present in it override the defaults below; ``None`` or an
            empty mapping leaves the defaults unchanged.

    Returns:
        A new ``UAV2D`` instance built from the merged settings.
    """
    env_settings = {
        "grid_size": 100,
        "num_buildings": 5,
        "num_dynamic_obstacles": 3,
    }
    # The argument used to be ignored, so values passed through `env_config`
    # silently had no effect. Merge them on top of the defaults instead.
    env_settings.update(config or {})
    return UAV2D(env_settings)


# Register the factory under the name "uav2d"; the algorithm config below
# selects the environment by this string.
register_env("uav2d", env_creator)


class CustomCallback(DefaultCallbacks):
    """Callback that records a checkpoint directory name in every train result."""

    def on_train_result(self, *, algorithm: Algorithm, result: dict, **kwargs) -> None:
        """Store ``checkpoint_dir_name`` in ``result`` as seen one index ahead.

        The storage context's ``current_checkpoint_index`` is bumped by one,
        ``checkpoint_dir_name`` is read, and the index is then restored.
        Presumably this yields the directory name the next checkpoint would
        receive -- confirm against the installed Ray version.

        Args:
            algorithm: The algorithm that produced ``result``. Its private
                ``_storage`` attribute is used; nothing happens when that
                attribute is missing or falsy.
            result: Training result dict, updated in place with the key
                ``"checkpoint_dir_name"``.
            **kwargs: Ignored.
        """
        # `_storage` is a private attribute, so do not assume it exists.
        storage = getattr(algorithm, "_storage", None)
        if not storage:
            return

        storage.current_checkpoint_index += 1
        try:
            result["checkpoint_dir_name"] = storage.checkpoint_dir_name
        finally:
            # Always undo the temporary bump, even if the lookup above fails,
            # so the shared index is never left off by one.
            storage.current_checkpoint_index -= 1

In [None]:
# Keyword groups for the builder calls below.
# NOTE(review): on RLlib's old API stack SAC appears to take its learning
# rates from `optimization_config` rather than `lr` -- confirm that the
# value below actually reaches the optimizers.
training_options = {
    "lr": 1e-5,
    "train_batch_size": 256,
    "_enable_learner_api": False,
}
rollout_options = {
    "num_rollout_workers": 5,
    "rollout_fragment_length": "auto",
    "batch_mode": "complete_episodes",
}

# Build the SAC configuration one builder call at a time (same calls and
# order as a single fluent chain).
config = SACConfig()
config = config.framework("torch")
config = config.environment("uav2d")
config = config.training(**training_options)
config = config.rollouts(**rollout_options)
config = config.resources(num_gpus=0)
config = config.callbacks(CustomCallback)
config = config.debugging(log_level="ERROR")
config = config.rl_module(_enable_rl_module_api=False)

In [None]:
# Checkpointing: frequency 1000, one more at the end, keep at most 20
# ranked by `episode_reward_mean`.
checkpoint_policy = train.CheckpointConfig(
    checkpoint_frequency=1000,
    checkpoint_at_end=True,
    num_to_keep=20,
    checkpoint_score_attribute="episode_reward_mean",
)

# Experiment-level settings: run name, stop criterion, checkpoint policy.
run_settings = train.RunConfig(
    name="UAV2D_SAC",
    stop={"timesteps_total": 1e6},
    checkpoint_config=checkpoint_policy,
)

# The tuner runs the "SAC" trainable with the config exported as a dict.
tuner = tune.Tuner(
    "SAC",
    param_space=config.to_dict(),
    run_config=run_settings,
)

In [None]:
# Launch the experiment defined by `tuner` and keep whatever `fit()` returns
# for later inspection.
results = tuner.fit()