In [None]:
from src.deeprl_lib.rllib.jupong2d_ppo import *
from src.deeprl_lib.rllib.jupong2d_plot_ppo_data import *

class RunRLlib:
    """Set up Ray/RLlib for a registered gym environment, then train or play.

    Construction has side effects: it restarts Ray (``ray.shutdown()`` then
    ``ray.init``), registers the gym environment, builds the RLlib config and
    stop criteria via ``rllib_configurations``, and creates the result paths
    via ``create_result_paths``. When ``restore`` or ``play`` is set and the
    results directory exists and is non-empty, the latest checkpoint is looked
    up with ``get_latest_checkpoint``.

    Args:
        output_folder: Root results folder, forwarded to
            ``create_result_paths`` and ``train_model``.
        num_cpus: CPU count for ``ray.init``; a falsy value (e.g. ``0``) is
            passed as ``None``.
        env_name: Name under which the environment is registered and run.
        paddle_length: Forwarded to ``register_gym_env``,
            ``create_result_paths`` and ``test_agent``.
        session: Session id forwarded to ``create_result_paths``.
        checkpoint_frequency: Forwarded to ``train_model``.
        train_algorithm: RLlib algorithm name (default ``"PPO"``).
        num_workers: Forwarded to ``rllib_configurations``.
        env_per_worker: Forwarded to ``rllib_configurations``.
        stop_reward: Stop criterion forwarded to ``rllib_configurations`` and
            ``train_model``.
        stop_iters: Stop criterion forwarded to ``rllib_configurations``.
        stop_timesteps: Stop criterion forwarded to ``rllib_configurations``.
        restore: If True, look up the latest checkpoint so training can
            resume from it.
        as_test: Flag forwarded to ``train_model``.
        play: If True, ``start`` runs ``test_agent`` instead of training
            (a checkpoint lookup is also performed).
        play_steps: Forwarded to ``test_agent``.

    Attributes set after construction include ``config``, ``stop``,
    ``save_folder``, ``results_path``, ``latest_checkpoint_path`` and
    ``checkpoint_number``. The last two are ``None`` when no checkpoint was
    looked up or found.
    """

    def __init__(self, output_folder, num_cpus, env_name, paddle_length, session=1,
                 checkpoint_frequency=10, train_algorithm="PPO", num_workers=3,
                 env_per_worker=5, stop_reward=None, stop_iters=None, stop_timesteps=None,
                 restore=True, as_test=False, play=False, play_steps=3):
        self.num_cpus = num_cpus
        self.paddle_length = paddle_length
        self.session = session
        self.checkpoint_freq = checkpoint_frequency
        self.run_alg = train_algorithm
        self.num_workers = num_workers
        self.env_per_worker = env_per_worker
        self.stop_reward = stop_reward
        self.stop_iters = stop_iters
        self.stop_timesteps = stop_timesteps
        self.output_folder = output_folder
        self.restore = restore
        self.play = play
        self.play_steps = play_steps
        self.as_test = as_test
        self.env_name = env_name

        # Restart Ray so repeated construction (e.g. train then play in the
        # same process) starts from a clean Ray runtime.
        ray.shutdown()
        # A falsy num_cpus (such as 0) is turned into None for ray.init.
        ray.init(num_cpus=self.num_cpus or None)
        register_gym_env(self.env_name, self.paddle_length)

        self.config, self.stop = rllib_configurations(
            self.run_alg, self.env_name, self.num_workers, self.env_per_worker,
            stop_reward=self.stop_reward, stop_iters=self.stop_iters,
            stop_timesteps=self.stop_timesteps)
        self.save_folder, self.results_path = create_result_paths(
            self.output_folder, self.session, self.paddle_length)

        # Both attributes are always defined. Previously checkpoint_number
        # only existed when a checkpoint was found, so reading it otherwise
        # raised AttributeError.
        self.latest_checkpoint_path = None
        self.checkpoint_number = None
        # isdir (rather than exists) keeps listdir from failing on a
        # non-directory path. The flag check runs first, so the filesystem is
        # only touched when a checkpoint is actually wanted.
        wants_checkpoint = self.restore or self.play
        if wants_checkpoint and os.path.isdir(self.results_path) and os.listdir(self.results_path):
            self.latest_checkpoint_path, self.checkpoint_number = get_latest_checkpoint(self.results_path)
        print(f"Path to latest checkpoint: {self.latest_checkpoint_path}")

    def start(self):
        """Run training (default) or play the latest checkpoint when ``play`` is set.

        In play mode ``latest_checkpoint_path`` may be ``None`` if no
        checkpoint was found. It is passed to ``test_agent`` as-is.
        NOTE(review): confirm that test_agent handles a None checkpoint.
        """
        if not self.play:
            print("Training the model")
            train_model(self.run_alg, self.config, self.stop, self.output_folder,
                        self.checkpoint_freq, self.save_folder, self.latest_checkpoint_path,
                        self.as_test, self.stop_reward)
        else:
            print(f"Testing the model {self.latest_checkpoint_path}")
            test_agent(self.env_name, self.config, self.latest_checkpoint_path,
                       self.results_path, self.play_steps, self.paddle_length)


In [None]:
def main(results_folder="results/rllib_results", env_name="jupong2d",
         paddle_length=2.0, session=2):
    """Train briefly, play back the latest checkpoint, then plot the results.

    The defaults reproduce the previously hard-coded experiment. Callers can
    override them to run the same pipeline on another folder, environment
    name, paddle length or session without editing the function.

    Args:
        results_folder: Root folder for RLlib results and for the plotter.
        env_name: Gym environment name to register and run.
        paddle_length: Paddle length used for training and playback.
        session: Session id shared by the training and playback runs, so
            playback finds the checkpoints that training wrote.
    """
    # num_cpus=0 is falsy, so RunRLlib passes num_cpus=None to ray.init.
    train_runner = RunRLlib(results_folder, 0, env_name, paddle_length, session=session,
                            checkpoint_frequency=5, stop_iters=5)
    train_runner.start()

    play_runner = RunRLlib(results_folder, 0, env_name, paddle_length, session=session,
                           play=True, play_steps=1)
    play_runner.start()

    plotter = JuPong2D_PPO_Plot(results_folder)
    plotter.plot_paddle_length()


# Guard the entry point so importing this module does not start a training
# run. Inside a Jupyter kernel __name__ is "__main__", so the notebook still
# runs main() when this cell executes.
if __name__ == "__main__":
    main()