In [None]:
import ray

GIB = 1024 ** 3  # number of bytes in one gibibyte

# Start Ray with 12 CPUs and 1 GPU, passing 10 GiB as `memory` and 30 GiB as
# `object_store_memory`.
ray.init(
    num_cpus=12,
    num_gpus=1,
    memory=10 * GIB,
    object_store_memory=30 * GIB,
)

In [None]:
from ray.tune.registry import register_env
import gym

def choose_env_for(env_config):
    """Create the satellite environment for one (worker, sub-environment) slot.

    ``env_config`` is read both by attribute (``worker_index``, ``vector_index``)
    and by key (``"num_envs_per_worker"``).

    Worker indices 0 and 1 both resolve to worker slot 0; any higher index
    resolves to ``worker_index - 1``. The satellite id is then
    ``slot * num_envs_per_worker + vector_index``.
    NOTE(review): presumably index 0 is RLlib's local worker, which is why it
    shares ids with worker 1 -- confirm.

    Returns:
        Whatever ``gym.make("satellite_gym:SatelliteEnv-v2", sat_id=...)`` returns.
    """
    # Diagnostic output: the full config, then the two indices that drive the choice.
    print(env_config)
    print("worker index is {}".format(env_config.worker_index))
    print("testing vector_index {}".format(env_config.vector_index))

    worker_index = env_config.worker_index
    # Shift positive worker indices down by one; non-positive values pass through.
    worker_slot = worker_index - 1 if worker_index > 0 else worker_index
    envs_per_worker = env_config["num_envs_per_worker"]
    return gym.make(
        "satellite_gym:SatelliteEnv-v2",
        sat_id=worker_slot * envs_per_worker + env_config.vector_index,
    )

# Register the satellite selector under the string id that is later passed as
# `env=` when the PPO trainer is built.
register_env(
    "SatelliteMultiEnv-v2",
    lambda env_config: choose_env_for(env_config),
)

In [None]:
import ray.rllib.agents.ppo as ppo
from ray.tune.logger import pretty_print

def on_train_result(info):
    """Training-result callback that pushes a phase number to every environment.

    The phase is derived from ``info["result"]["episode_reward_mean"]``:
    strictly above 42 gives phase 2, strictly above 21 gives phase 1, and
    anything else gives phase 0.

    Args:
        info: Mapping with a ``"result"`` entry (containing
            ``"episode_reward_mean"``) and a ``"trainer"`` entry whose
            ``workers.foreach_worker`` is used to reach each worker.

    Returns:
        None. The only effect is calling ``set_phase(phase)`` on each environment.
    """
    mean_reward = info["result"]["episode_reward_mean"]
    phase = 2 if mean_reward > 42 else (1 if mean_reward > 21 else 0)

    def _set_phase_on_worker(worker):
        # Apply the chosen phase to every environment this worker holds.
        return worker.foreach_env(lambda env: env.set_phase(phase))

    info["trainer"].workers.foreach_worker(_set_phase_on_worker)
    
    
# Start from PPO's defaults. `dict.copy()` is shallow, so the nested "model" and
# "env_config" dicts are still the very objects held by `ppo.DEFAULT_CONFIG`.
# They are rebuilt as fresh dicts below instead of being written into, so the
# library-wide defaults are not silently changed for later users.
config = ppo.DEFAULT_CONFIG.copy()
config["model"] = dict(config["model"], use_lstm=True, vf_share_layers=True)
# config["optimizer"]["batch_replay"] = True
config["num_workers"] = 10
# Ten workers at 0.1 GPU each add up to the single GPU given to ray.init().
config["num_gpus_per_worker"] = .1
config["seed"] = 0
config["eager"] = False

# Settings tried previously and currently disabled:
# config["clip_rewards"] = False
# config["tau"] = 1.0 # 1-tau * value_network + 1-tau/tau * target_network
# config["evaluation_interval"] = 5
# config["evaluation_num_episodes"] = 10
# config["exploration_ou_noise_scale"] = 1.0
# config["buffer_size"] = 4000000
# config["observation_filter"] = "NoFilter"
# config["train_batch_size"] = 1024
# config["sample_batch_size"] = 100
# config["num_envs_per_worker"] = 30
config["callbacks"] = {"on_train_result": on_train_result}
config["num_envs_per_worker"] = 12
# choose_env_for() reads this key from env_config to compute each env's sat_id,
# so it is mirrored from the trainer-level setting.
config["env_config"] = dict(
    config["env_config"],
    num_envs_per_worker=config["num_envs_per_worker"],
)

trainer = ppo.PPOTrainer(config=config, env="SatelliteMultiEnv-v2")

In [None]:
NUM_ITERATIONS = 201
CHECKPOINT_EVERY = 100

for i in range(NUM_ITERATIONS):
    # Run one PPO training iteration and print its result in readable form.
    result = trainer.train()
    print(pretty_print(result))

    # Only iterations 0, 100 and 200 fall through to the checkpoint step.
    if i % CHECKPOINT_EVERY != 0:
        continue
    checkpoint = trainer.save()
    print("checkpoint saved at", checkpoint)