In [1]:
import ray
from ray.rllib.algorithms.ppo import PPO
import os
from ray import tune
from ray.rllib.algorithms.ppo import PPOConfig
from ray.rllib.env.wrappers.pettingzoo_env import ParallelPettingZooEnv
from ray.rllib.models import ModelCatalog
from ray.rllib.models.torch.torch_modelv2 import TorchModelV2
from ray.tune.registry import register_env
import numpy as np
import matplotlib.pyplot as plt
import pickle
from rbc_maenv import DedalusRBC_Env

In [2]:
# Put the notebook directory on PYTHONPATH, presumably so that processes started
# later by ray.init() can import the local `rbc_maenv` module — confirm.
# Prepend instead of overwriting so any PYTHONPATH the kernel was launched with
# is preserved, and skip the prepend when the directory is already listed so
# re-running this cell does not keep growing the variable.
_notebook_dir = os.path.abspath('.')
_existing_pythonpath = os.environ.get('PYTHONPATH', '')
if _notebook_dir not in _existing_pythonpath.split(os.pathsep):
    os.environ['PYTHONPATH'] = os.pathsep.join(
        entry for entry in (_notebook_dir, _existing_pythonpath) if entry
    )

In [3]:
def env_creator(nagents=10):
    """Create a fresh ``DedalusRBC_Env`` instance.

    Args:
        nagents: Number of agents passed to ``DedalusRBC_Env``. Defaults to 10,
            the value this notebook has always used. NOTE(review): presumably
            this must match the agent count the restored checkpoint was trained
            with — confirm before changing it.

    Returns:
        The unwrapped ``DedalusRBC_Env``. The rollout cell uses it directly;
        ``register_env`` wraps it in ``ParallelPettingZooEnv`` for RLlib.
    """
    return DedalusRBC_Env(nagents=nagents)

In [4]:
# Start the local Ray runtime. ignore_reinit_error=True makes the cell safe to
# re-run: Ray is then left as-is instead of raising because it is already
# initialized in this kernel.
ray.init(ignore_reinit_error=True)

# Register the environment under "ma_rbc", presumably the env name stored in
# the training config loaded below (the checkpoint directory is PPO_ma_rbc_...).
# RLlib passes an env config to the creator; it is ignored because
# env_creator() takes its settings from its own defaults.
env_name = "ma_rbc"
register_env(env_name, lambda config: ParallelPettingZooEnv(env_creator()))

2024-04-08 21:26:15,634	INFO worker.py:1752 -- Started a local Ray instance.


In [7]:
# Checkpoint from the earlier PPO training run (trial directory PPO_ma_rbc_...).
# Its parent trial directory also holds params.pkl, presumably the config the
# algorithm was trained with.
checkpoint_path = os.path.expanduser(
    '~/ray_results/PPO/PPO_ma_rbc_1d6be_00000_0_2024-04-06_13-20-54/checkpoint_000000'
)
base_dir = os.path.split(checkpoint_path)[0]
config_path = os.path.join(base_dir, 'params.pkl')

# NOTE(review): pickle.load can execute arbitrary code from the file; only load
# params.pkl files you generated yourself.
with open(config_path, mode='rb') as params_file:
    config = pickle.load(params_file)

# Override the worker count stored in the training config before the algorithm
# is rebuilt (presumably RLlib's rollout-worker count — confirm for this Ray version).
config['num_workers'] = 1

In [13]:
# Rebuild the PPO algorithm from the config loaded from params.pkl, then restore
# the trained state from the checkpoint directory into it.
# NOTE(review): the UnifiedLogger/JsonLogger/CSVLogger/TBXLogger deprecation
# warnings in this cell's output are emitted while constructing the algorithm;
# a newer restore API (e.g. Algorithm.from_checkpoint) may avoid them — confirm
# against the installed Ray version before switching.
PPOagent = PPO(config=config)
PPOagent.restore(checkpoint_path)

`UnifiedLogger` will be removed in Ray 2.7.
  return UnifiedLogger(config, logdir, loggers=None)
The `JsonLogger interface is deprecated in favor of the `ray.tune.json.JsonLoggerCallback` interface and will be removed in Ray 2.7.
  self._loggers.append(cls(self.config, self.logdir, self.trial))
The `CSVLogger interface is deprecated in favor of the `ray.tune.csv.CSVLoggerCallback` interface and will be removed in Ray 2.7.
  self._loggers.append(cls(self.config, self.logdir, self.trial))
The `TBXLogger interface is deprecated in favor of the `ray.tune.tensorboardx.TBXLoggerCallback` interface and will be removed in Ray 2.7.
  self._loggers.append(cls(self.config, self.logdir, self.trial))
2024-04-08 21:36:36,840	INFO trainable.py:164 -- Trainable.setup took 39.449 seconds. If your trainable is slow to initialize, consider setting reuse_actors=True to reduce actor creation overheads.
2024-04-08 21:36:36,852	INFO trainable.py:575 -- Restored on 10.109.52.54 from checkpoint: Checkpoint(fil

In [14]:
ROLLOUT_STEPS = 256  # maximum number of env steps to roll out
ROLLOUT_SEED = 52    # seed for env.reset, so the rollout is reproducible
FRAMES_DIR = 'figs'  # one PNG per step is written here


def save_temperature_frame(env, step, out_dir=FRAMES_DIR):
    """Save a snapshot of the environment's current field to ``<out_dir>/<step>.png``.

    Args:
        env: The environment being rolled out. Its
            ``problem.variables[1]['g']`` array is plotted, and the last entry
            of ``fp.properties['Nu']['g']`` is shown in the title.
        step: Integer step index; used as the file name.
        out_dir: Directory the frame is written to; created if it is missing.
    """
    # NOTE(review): variables[1] is assumed to be the temperature field (it was
    # plotted as `T`) — confirm against DedalusRBC_Env's problem definition.
    temperature = env.problem.variables[1]['g']
    # Most recent 'Nu' value (presumably the Nusselt number): last element of
    # the flattened property array.
    nu_latest = env.fp.properties['Nu']['g'].flatten()[-1]

    os.makedirs(out_dir, exist_ok=True)
    fig, ax = plt.subplots()
    try:
        # Fixed vmin/vmax so every frame shares one colour scale.
        image = ax.imshow(np.transpose(temperature), aspect=1/np.pi,
                          origin="lower", vmin=0., vmax=1.4)
        fig.colorbar(image)
        ax.set_title('$Nu=$' + str(np.round(nu_latest, 2)))
        fig.savefig(os.path.join(out_dir, str(step) + '.png'))
    finally:
        # Release the figure even if saving fails, so a long rollout does not
        # accumulate open figures.
        plt.close(fig)


env = env_creator()
reward_sum = 0
observation, _ = env.reset(seed=ROLLOUT_SEED)

for i in range(ROLLOUT_STEPS):
    action = PPOagent.compute_actions(observation, explore=False)
    observation, reward, termination, truncation, info = env.step(action)
    # Accumulate the per-step reward averaged over agents.
    reward_sum += np.average(list(reward.values()))
    save_temperature_frame(env, i)
    # Stop once every agent reported in this step is terminated or truncated;
    # stepping a finished parallel env is not supported. The emptiness guard
    # keeps the loop running if the env returns no per-agent flags.
    if termination and all(
        termination[agent] or truncation.get(agent, False) for agent in termination
    ):
        break

print(reward_sum)

127.58396717719734
