In [1]:
import ray

# Start a local Ray runtime; ignore_reinit_error=True lets this cell be re-run
# without raising when Ray has already been initialised in this kernel.
ray.init(ignore_reinit_error=True)

2023-12-11 11:44:32,689	INFO worker.py:1673 -- Started a local Ray instance.


0,1
Python version:,3.11.4
Ray version:,2.8.1


In [2]:
# Show which Ray release this kernel imported.
ray.__version__

'2.8.1'

In [3]:
from ray import train, tune
from fastcore.xtras import Path

In [4]:
from ray.rllib.algorithms.ppo import PPOConfig
from ray.rllib.algorithms.algorithm import Algorithm

from ray.tune.logger import pretty_print



In [5]:
import torch
import os
# Pick a torch backend in priority order: CUDA, then Apple MPS, then CPU.
# `device` is informational only for now; per the original note it is not consumed yet.
if torch.cuda.is_available():
    device = "cuda"
elif torch.backends.mps.is_available():
    device = "mps"
else:
    device = "cpu"

# Turn on PyTorch's MPS fallback switch only when the MPS backend was selected.
if device == "mps":
    os.environ['PYTORCH_ENABLE_MPS_FALLBACK'] = "1"

In [6]:
#!pip install -e aerospaceRL

In [7]:
import sys
# Make the local aerospaceRL checkout importable. The membership check keeps the
# cell idempotent: re-running it no longer appends a duplicate sys.path entry.
aero_pkg_dir = str(Path('.').resolve()/'aerospaceRL')
if aero_pkg_dir not in sys.path:
    sys.path.append(aero_pkg_dir)
import aero_gym

In [8]:
# No GPU support in this setup; still an open upstream issue:
# https://github.com/ray-project/ray/issues/28321
# Sketch for requesting resources later (the `0` is a placeholder for the trainable):
# tune.Tuner(tune.with_resources(0, resources={"cpu": 8, "gpu": 1}),...
# https://docs.ray.io/en/latest/tune/api/doc/ray.tune.with_resources.html
# Ask Ray which GPU ids it has assigned to this process.
ray.get_gpu_ids()

[]

In [9]:
#grid search parameters
# tune.grid_search([0.01, 0.001])


In [10]:
#custom model example: https://docs.ray.io/en/latest/rllib/rllib-models.html

In [11]:
from aero_gym.envs import dubins_aircraft#.DubinsAircraft

In [12]:
#aerospaceRL.aero_gym.envs.dubins_aircraft.DubinsAircraft

In [13]:
#customEnv=dubins_aircraft.DubinsAircraft()

In [14]:
# Register the custom environment under the name "myEnv" so configs can refer to it by string.
# NOTE(review): RLlib calls the registered creator with the env config, so passing the class
# directly assumes DubinsAircraft's constructor accepts that config as its first argument — TODO confirm.
tune.registry.register_env("myEnv", dubins_aircraft.DubinsAircraft)

In [15]:
# Agent, model, and environment setup: PPO on the registered "myEnv", torch
# framework, fixed learning rate, ReLU activations in the fully connected net.
param_space = (
    PPOConfig()
    .environment(env="myEnv")
    .framework("torch")
    .training(lr=0.01, model=dict(fcnet_activation='relu'))
)

# Run settings: experiment name, save directory, stop criterion, and how
# checkpoints are ranked.
run_config = train.RunConfig(
    name="my_tune_customEnv",
    storage_path=str(Path('.').resolve()/'saved_agents'),
    # Training ends once the mean episode reward reaches this value.
    stop={"episode_reward_mean": 100},
    checkpoint_config=train.CheckpointConfig(
        checkpoint_score_attribute="episode_reward_mean",
        checkpoint_score_order="max",
        # Periodic checkpointing is off unless a frequency is given, so request
        # a checkpoint at the end of training. Without it there is nothing for
        # the later `best_result.checkpoint` / `Algorithm.from_checkpoint` cells to load.
        checkpoint_at_end=True,
    ),
)

# Load configs into the Tuner.
tuner = tune.Tuner(
    "PPO",
    run_config=run_config,
    param_space=param_space,
)

# Train the model; `results` is consumed by the cells that pick the best trial.
results = tuner.fit()


0,1
Current time:,2023-12-11 11:44:42
Running for:,00:00:04.07
Memory:,32.1/64.0 GiB

Trial name,# failures,error file
PPO_myEnv_92c0d_00000,1,/Users/sbrewer/ray_results/my_tune_customEnv/PPO_myEnv_92c0d_00000_0_2023-12-11_11-44-38/error.txt

Trial name,status,loc
PPO_myEnv_92c0d_00000,ERROR,


2023-12-11 11:44:42,303	ERROR tune_controller.py:1383 -- Trial task failed for trial PPO_myEnv_92c0d_00000
Traceback (most recent call last):
  File "/Users/sbrewer/anaconda3/envs/RLlib/lib/python3.11/site-packages/ray/air/execution/_internal/event_manager.py", line 110, in resolve_future
    result = ray.get(future)
             ^^^^^^^^^^^^^^^
  File "/Users/sbrewer/anaconda3/envs/RLlib/lib/python3.11/site-packages/ray/_private/auto_init_hook.py", line 24, in auto_init_wrapper
    return fn(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^
  File "/Users/sbrewer/anaconda3/envs/RLlib/lib/python3.11/site-packages/ray/_private/client_mode_hook.py", line 103, in wrapper
    return func(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^
  File "/Users/sbrewer/anaconda3/envs/RLlib/lib/python3.11/site-packages/ray/_private/worker.py", line 2565, in get
    raise value
ray.exceptions.RayActorError: The actor died because of an error raised in its creation task, [36mray::PPO.__init__()[39m (pi

2023-12-11 11:44:42,312	ERROR tune.py:1043 -- Trials did not complete: [PPO_myEnv_92c0d_00000]
2023-12-11 11:44:42,312	INFO tune.py:1047 -- Total run time: 4.08 seconds (4.06 seconds for the tuning loop).
- PPO_myEnv_92c0d_00000: FileNotFoundError('Could not fetch metrics for PPO_myEnv_92c0d_00000: both result.json and progress.csv were not found at /Users/sbrewer/Documents/NAVAIR/RL_Scripts/saved_agents/my_tune_customEnv/PPO_myEnv_92c0d_00000_0_2023-12-11_11-44-38')
[36m(PPO pid=11122)[0m 2023-12-11 11:44:42,297	ERROR actor_manager.py:500 -- Ray error, taking actor 1 out of service. The actor died because of an error raised in its creation task, [36mray::RolloutWorker.__init__()[39m (pid=11152, ip=127.0.0.1, actor_id=16de5a947e7b99d791cee5fe01000000, repr=<ray.rllib.evaluation.rollout_worker.RolloutWorker object at 0x15d4d0490>)
[36m(PPO pid=11122)[0m   File "/Users/sbrewer/anaconda3/envs/RLlib/lib/python3.11/site-packages/ray/rllib/utils/pre_checks/env.py", line 145, in check_g

In [None]:
# Find best model: the trial with the highest mean episode reward.
best_result = results.get_best_result(metric="episode_reward_mean", mode="max")

# Get the best checkpoint corresponding to the best result.
best_checkpoint = best_result.checkpoint

# A result carries no checkpoint when the trial never saved one. Fail with an
# actionable message here instead of an AttributeError on `.path` below.
if best_checkpoint is None:
    raise RuntimeError(
        "Best trial has no checkpoint to restore; enable checkpointing "
        "(e.g. checkpoint_at_end=True) in the run's CheckpointConfig and re-run training."
    )

# Rebuild the trained algorithm from the checkpoint's path.
algo = Algorithm.from_checkpoint(best_checkpoint.path)

In [None]:
# Show where the selected checkpoint lives.
best_checkpoint.path

In [None]:
# Pull the policy out of the restored algorithm and keep a handle on its model.
policy = algo.get_policy()
#print(policy.get_weights())
model = policy.model

In [None]:
# Display the model's repr as the cell output.
model

In [None]:
# Render agent: roll the restored policy out for N episodes, record each
# episode's total reward, and draw the frames of the final episode inline.
import gym
import matplotlib.pyplot as plt
# presumably keeps SDL-based rendering from needing a real display — TODO confirm
os.environ["SDL_VIDEODRIVER"] = "dummy"
from IPython.display import clear_output

N=500
# NOTE(review): training above used the registered "myEnv", yet this rollout steps
# CartPole-v1 — confirm the observation/action spaces actually match the policy.
env_name = "CartPole-v1"
env = gym.make(env_name, render_mode="rgb_array")

reward_lst=[]
try:
    for n in range(N):
        episode_reward = 0
        obs, info = env.reset()

        while True:
            action = algo.compute_single_action(obs)
            obs, reward, done, truncated, info = env.step(action)
            episode_reward += reward

            # Only the last episode is rendered; each frame replaces the previous one.
            if n == N-1:
                clear_output(wait=True)
                plt.imshow(env.render())
                plt.title(f'Reward: {int(episode_reward)}')
                plt.show()

            # An episode ends on termination OR truncation (e.g. a time limit).
            # Honouring `truncated` stops the loop from stepping an already
            # finished episode and recording an inflated reward. The reward cap
            # stays as a backstop for environments that never truncate.
            if done or truncated or episode_reward>500:
                reward_lst.append(episode_reward)
                break
finally:
    # Release the environment even if a rollout step raises.
    env.close()

In [None]:
import pandas as pd
# Summary statistics of the per-episode rewards collected in the rollout cell.
pd.Series(reward_lst).describe()

In [None]:
# List the keys of gym's environment registry (useful for checking valid env names).
list(gym.envs.registry.keys())

In [None]:
import datetime as dt
# Stamp the notebook with the wall-clock time at which this final cell ran.
end = dt.datetime.now()
finished_date = end.strftime("%A %B %d, %Y")
finished_time = end.strftime("%H:%M")
print(f'Finished: {finished_date} at {finished_time}')