In [21]:
import sys
# sys.path.insert(0, "..")

In [55]:
import numpy as np
import ray
from ray.rllib.agents import ppo
from ray.tune.registry import register_env
from ray.tune.logger import UnifiedLogger
from gym.spaces import Discrete, Box

from cpr_reputation.environments import HarvestEnv

defaults_ini = {
    "num_agents": 4,
    "size": (20, 20),
    "sight_width": 5,
    "sight_dist": 10,
    "num_crosses": 4,
}

register_env("harvest", lambda config: HarvestEnv(config, **defaults_ini))

walker1 = (
    None,
    Box(
        0.0,
        1.0,
        (defaults_ini["sight_dist"], 2 * defaults_ini["sight_width"] + 1, 3),
        np.float32,
    ),  # obs
    Discrete(8),  # action
    dict(),
)

walkers = {f"Agent{k}": walker1 for k in range(defaults_ini["num_agents"])}

config = {
    "multiagent": {
        "policies": walkers,
        "policy_mapping_fn": lambda agent_id: agent_id,
        "policies_to_train": list(walkers.keys())
    },
    "framework": "torch",
    "model": {
        "dim": 3,
        "conv_filters": [
            [16, [4, 4], 1],
            [
                32,
                [defaults_ini["sight_dist"], 2 * defaults_ini["sight_width"] + 1],
                1,
            ],
        ],
    },
}



In [56]:
# ray.init()
trainer = ppo.PPOTrainer(
    env="harvest",
    config=config,
    logger_creator=lambda cfg: UnifiedLogger(cfg, "log"),
)


. In total there are 0 pending tasks and 2 pending actors on this node. This is likely due to all cluster resources being claimed by actors. To resolve the issue, consider creating fewer actors or increase the resources available to this Ray cluster. You can ignore this message if this Ray cluster is expected to auto-scale.
2021-03-28 20:39:35,770	INFO trainable.py:100 -- Trainable.setup took 17.982 seconds. If your trainable is slow to initialize, consider setting reuse_actors=True to reduce actor creation overheads.


In [57]:
results = trainer.train()



In [4]:
dir(trainer)

['__class__',
 '__delattr__',
 '__dict__',
 '__dir__',
 '__doc__',
 '__eq__',
 '__format__',
 '__ge__',
 '__getattribute__',
 '__getstate__',
 '__gt__',
 '__hash__',
 '__init__',
 '__init_subclass__',
 '__le__',
 '__lt__',
 '__module__',
 '__ne__',
 '__new__',
 '__reduce__',
 '__reduce_ex__',
 '__repr__',
 '__setattr__',
 '__setstate__',
 '__sizeof__',
 '__str__',
 '__subclasshook__',
 '__weakref__',
 '_allow_unknown_configs',
 '_allow_unknown_subkeys',
 '_before_evaluate',
 '_close_logfiles',
 '_create_logger',
 '_default_config',
 '_env_id',
 '_episodes_total',
 '_evaluate',
 '_experiment_id',
 '_export_model',
 '_init',
 '_is_overridden',
 '_iteration',
 '_iterations_since_restore',
 '_local_ip',
 '_log_result',
 '_logdir',
 '_make_workers',
 '_monitor',
 '_name',
 '_open_logfiles',
 '_override_all_subkeys_if_type_changes',
 '_policy_class',
 '_register_if_needed',
 '_restore',
 '_restored',
 '_result_logger',
 '_save',
 '_setup',
 '_stderr_context',
 '_stderr_fp',
 '_stderr_logging

In [65]:
trainer.export_policy_model("log", "Agent1")

NotImplementedError: 

In [None]:
results = trainer.train()


In [58]:
results["config"]["multiagent"]

{'policies': {'Agent0': (ray.rllib.policy.policy_template.PPOTorchPolicy,
   Box(0.0, 1.0, (10, 11, 3), float32),
   Discrete(8),
   {}),
  'Agent1': (ray.rllib.policy.policy_template.PPOTorchPolicy,
   Box(0.0, 1.0, (10, 11, 3), float32),
   Discrete(8),
   {}),
  'Agent2': (ray.rllib.policy.policy_template.PPOTorchPolicy,
   Box(0.0, 1.0, (10, 11, 3), float32),
   Discrete(8),
   {}),
  'Agent3': (ray.rllib.policy.policy_template.PPOTorchPolicy,
   Box(0.0, 1.0, (10, 11, 3), float32),
   Discrete(8),
   {})},
 'policy_mapping_fn': <function __main__.<lambda>(agent_id)>,
 'policies_to_train': ['Agent0', 'Agent1', 'Agent2', 'Agent3'],
 'observation_fn': None,
 'replay_mode': 'independent',
 'count_steps_by': 'env_steps'}

In [59]:
results


{'episode_reward_max': 9.0,
 'episode_reward_min': 3.0,
 'episode_reward_mean': 6.0,
 'episode_len_mean': 1002.0,
 'episodes_this_iter': 2,
 'policy_reward_min': {'Agent0': 0.0,
  'Agent1': 0.0,
  'Agent2': 3.0,
  'Agent3': 0.0},
 'policy_reward_max': {'Agent0': 2.0,
  'Agent1': 0.0,
  'Agent2': 7.0,
  'Agent3': 0.0},
 'policy_reward_mean': {'Agent0': 1.0,
  'Agent1': 0.0,
  'Agent2': 5.0,
  'Agent3': 0.0},
 'custom_metrics': {},
 'hist_stats': {'episode_reward': [9.0, 3.0],
  'episode_lengths': [1002, 1002],
  'policy_Agent0_reward': [2.0, 0.0],
  'policy_Agent1_reward': [0.0, 0.0],
  'policy_Agent2_reward': [7.0, 3.0],
  'policy_Agent3_reward': [0.0, 0.0]},
 'sampler_perf': {'mean_env_wait_ms': 1.111176417864066,
  'mean_raw_obs_processing_ms': 0.3072834801280695,
  'mean_inference_ms': 11.543739860740558,
  'mean_action_processing_ms': 0.19977010529616784},
 'off_policy_estimator': {},
 'num_healthy_workers': 2,
 'timesteps_total': 4000,
 'timers': {'sample_time_ms': 26727.227,
  's

In [69]:
import pandas as pd
!ls log

checkpoint-1
events.out.tfevents.1616797063.quinn-Latitude-3340
events.out.tfevents.1616798483.quinn-Latitude-3340
events.out.tfevents.1616941169.quinn-Latitude-3340
events.out.tfevents.1616941192.quinn-Latitude-3340
events.out.tfevents.1616943650.quinn-Latitude-3340
events.out.tfevents.1616953298.quinn-Latitude-3340
events.out.tfevents.1616957917.quinn-Latitude-3340
events.out.tfevents.1616958578.quinn-Latitude-3340
events.out.tfevents.1616959652.quinn-Latitude-3340
events.out.tfevents.1616959673.quinn-Latitude-3340
events.out.tfevents.1616960168.quinn-Latitude-3340
events.out.tfevents.1616960182.quinn-Latitude-3340
events.out.tfevents.1616960208.quinn-Latitude-3340
events.out.tfevents.1616960247.quinn-Latitude-3340
events.out.tfevents.1616960357.quinn-Latitude-3340
params.json
params.pkl
progress.csv
result.json


In [76]:
pd.read_csv("log/progress.csv", header=None)

Unnamed: 0,0,1,2,3,4,5,6,7,8,9,...,82,83,84,85,86,87,88,89,90,91
0,16.0,11.0,13.5,1002.0,2,2,4000,False,2,1,...,0.0,0.2,5e-05,0.334471,-0.051019,0.383098,0.440263,0.011957,2.065266,0.0
1,16.0,8.0,12.333333,1002.0,4,2,8000,False,6,2,...,0.0,0.2,5e-05,0.13642,-0.052705,0.185611,0.558235,0.017568,2.040155,0.0
2,16.0,8.0,11.6,1002.0,4,2,12000,False,10,3,...,0.0,0.2,5e-05,0.291126,-0.059521,0.347306,0.527838,0.016709,2.021533,0.0
3,11.0,7.0,9.0,1002.0,2,2,4000,False,2,1,...,0.0,0.2,5e-05,0.027325,-0.038971,0.064254,0.49711,0.010209,2.06639,0.0
4,11.0,7.0,9.0,1002.0,2,2,4000,False,2,1,...,0.0,0.2,5e-05,0.027325,-0.038971,0.064254,0.49711,0.010209,2.06639,0.0
5,9.0,3.0,6.0,1002.0,2,2,4000,False,2,1,...,0.0,0.2,5e-05,-0.024132,-0.029238,0.002606,0.750254,0.0125,2.063753,0.0


In [78]:
!ls -l log

total 5576
-rw-rw-r-- 1 quinn quinn 5508984 Mar 28 20:06 checkpoint-1
-rw-rw-r-- 1 quinn quinn   11530 Mar 26 22:24 events.out.tfevents.1616797063.quinn-Latitude-3340
-rw-rw-r-- 1 quinn quinn      40 Mar 26 22:41 events.out.tfevents.1616798483.quinn-Latitude-3340
-rw-rw-r-- 1 quinn quinn      40 Mar 28 15:19 events.out.tfevents.1616941169.quinn-Latitude-3340
-rw-rw-r-- 1 quinn quinn      40 Mar 28 15:19 events.out.tfevents.1616941192.quinn-Latitude-3340
-rw-rw-r-- 1 quinn quinn      40 Mar 28 16:02 events.out.tfevents.1616943650.quinn-Latitude-3340
-rw-rw-r-- 1 quinn quinn   61791 Mar 28 18:50 events.out.tfevents.1616953298.quinn-Latitude-3340
-rw-rw-r-- 1 quinn quinn   11812 Mar 28 20:04 events.out.tfevents.1616957917.quinn-Latitude-3340
-rw-rw-r-- 1 quinn quinn   11812 Mar 28 20:10 events.out.tfevents.1616958578.quinn-Latitude-3340
-rw-rw-r-- 1 quinn quinn       0 Mar 28 20:27 events.out.tfevents.1616959652.quinn-Latitude-3340
-rw-rw-r-- 1 quinn quinn       0 Mar 28 20:27 events.out.

In [84]:
import json
import io

# with io.open("log/result.json") as file: 
#    all_results = json.load(file)
    

In [115]:
from ray.rllib.agents.callbacks import DefaultCallbacks
from ray.rllib.env import BaseEnv
from ray.rllib.evaluation import MultiAgentEpisode, RolloutWorker
from ray.rllib.policy import Policy
from typing import Dict

class SaveCheckpointsCallback(DefaultCallbacks): 
    
    def on_episode_end(
        self, *, 
        worker: RolloutWorker, 
        base_env: BaseEnv,
        policies: Dict[str, Policy], 
        episode: MultiAgentEpisode,
        env_index: int, 
        **kwargs
    ):
        for agent_id, policy in policies.items(): 
            policy.export_model(f"log/{agent_id}")
            
        if self.legacy_callbacks.get("on_episode_end"):
            self.legacy_callbacks["on_episode_end"]({
                "env": base_env,
                "policy": policies,
                "episode": episode,
            })

In [122]:

config = {
    "multiagent": {
        "policies": walkers,
        "policy_mapping_fn": lambda agent_id: agent_id,
        "policies_to_train": list(walkers.keys()), 
        "checkpoint_at_end": True
    },
    "framework": "torch",
    "model": {
        "dim": 3,
        "conv_filters": [
            [16, [4, 4], 1],
            [
                32,
                [defaults_ini["sight_dist"], 2 * defaults_ini["sight_width"] + 1],
                1,
            ],
        ],
    },
    #"callbacks": SaveCheckpointsCallback
}


In [123]:

trainer = ppo.PPOTrainer(
    env="harvest",
    config=config,
    logger_creator=lambda cfg: UnifiedLogger(cfg, "log"),
)



. In total there are 0 pending tasks and 2 pending actors on this node. This is likely due to all cluster resources being claimed by actors. To resolve the issue, consider creating fewer actors or increase the resources available to this Ray cluster. You can ignore this message if this Ray cluster is expected to auto-scale.


KeyboardInterrupt: 

In [None]:

while True:
    results = trainer.train()
    

In [10]:
ls -l ~/ray_results | grep train_fn

drwxrwxr-x 3 quinn quinn 4096 Mar 29 13:15 [01;34mtrain_fn_2021-03-29_13-12-41[0m/
drwxrwxr-x 3 quinn quinn 4096 Mar 29 13:28 [01;34mtrain_fn_2021-03-29_13-28-25[0m/
drwxrwxr-x 3 quinn quinn 4096 Mar 29 13:32 [01;34mtrain_fn_2021-03-29_13-29-19[0m/
drwxrwxr-x 3 quinn quinn 4096 Mar 29 13:36 [01;34mtrain_fn_2021-03-29_13-36-22[0m/
drwxrwxr-x 3 quinn quinn 4096 Mar 29 14:03 [01;34mtrain_fn_2021-03-29_13-37-56[0m/
drwxrwxr-x 3 quinn quinn 4096 Mar 29 14:07 [01;34mtrain_fn_2021-03-29_14-04-40[0m/
drwxrwxr-x 3 quinn quinn 4096 Mar 29 14:10 [01;34mtrain_fn_2021-03-29_14-07-31[0m/
drwxrwxr-x 3 quinn quinn 4096 Mar 29 14:10 [01;34mtrain_fn_2021-03-29_14-10-21[0m/
drwxrwxr-x 3 quinn quinn 4096 Mar 29 14:22 [01;34mtrain_fn_2021-03-29_14-22-46[0m/
drwxrwxr-x 3 quinn quinn 4096 Mar 29 14:23 [01;34mtrain_fn_2021-03-29_14-23-38[0m/
drwxrwxr-x 3 quinn quinn 4096 Mar 29 14:24 [01;34mtrain_fn_2021-03-29_14-24-38[0m/
drwxrwxr-x 3 quinn quinn 4096 Mar 29 14:25 [01;34mtrain_fn_2021-

In [25]:
import os

def all_dirs_under(path):
    """Iterates through all files that are under the given path."""
    for cur_path, dirnames, filenames in os.walk(path):
        for dir_ in dirnames: 
            yield os.path.join(cur_path, dir_)

def latest_dir(path): 
    return max(all_dirs_under(path), key=os.path.getmtime)

to_restore = latest_dir("//home/quinn/ray_results")
to_restore

'//home/quinn/ray_results/train_fn_2021-03-29_17-18-24'

In [42]:
def walk_to_checkpoint(to_restore_path: str): 
    for cur_path, dirnames, _ in os.walk(to_restore_path): 
        for dir_ in dirnames: 
            if dir_.startswith("train_fn_"):
                # print(os.path.join(cur_path, dir_))
                for cur_path_, dirnames_, _ in os.walk(os.path.join(cur_path, dir_)): 
                    for dir__ in dirnames_:                     
                        if dir__.startswith("checkpoint_"): 
                            return os.path.join(cur_path_, dir__)

checkpoint = walk_to_checkpoint(
    latest_dir(
        "//home/quinn/ray_results"
    )
)
checkpoint

In [70]:
from typing import List

def retrieve_checkpoint(path: str = "//home/quinn/ray_results", prefix: str = "train_fn") -> str: 

    def all_dirs_under(path):
        """Iterates through all files that are under the given path."""
        for cur_path, dirnames, filenames in os.walk(path):
            for dir_ in dirnames: 
                yield os.path.join(cur_path, dir_)

    def retrieve_checkpoints(paths: List[str]) -> List[str]:
        checkpoints = list()
        for path in paths: 
            for cur_path, dirnames, _ in os.walk(path): 
                for dirname in dirnames: 
                    if dirname.startswith("checkpoint_"): 
                        checkpoints.append(os.path.join(cur_path, dirname))
        return checkpoints
    
    sorted_checkpoints = retrieve_checkpoints(
        sorted(
            filter(
                lambda x: x.startswith(f"{path}/{prefix}"), all_dirs_under(path)
            ), 
            key=os.path.getmtime
        )
    )[::-1]
    
    for checkpoint in sorted_checkpoints:
        if checkpoint is not None: 
            return checkpoint
    return None

retrieve_checkpoint()

'//home/quinn/ray_results/train_fn_2021-03-29_20-08-33/train_fn_28297_00000_0_2021-03-29_20-08-33/checkpoint_tmp095c77'

In [60]:
!ls -la //home/quinn/ray_results/train_fn_2021-03-29_20-08-33/train_fn_28297_00000_0_2021-03-29_20-08-33/checkpoint_tmp095c77/

total 12
drwxrwxr-x 2 quinn quinn 4096 Mar 29 20:08 .
drwxrwxr-x 4 quinn quinn 4096 Mar 29 20:08 ..
-rw-rw-r-- 1 quinn quinn    0 Mar 29 20:08 .is_checkpoint
-rw-rw-r-- 1 quinn quinn    0 Mar 29 20:08 .null_marker
-rw-rw-r-- 1 quinn quinn    0 Mar 29 20:08 .temp_marker
-rw-rw-r-- 1 quinn quinn  181 Mar 29 20:08 .tune_metadata


In [76]:
from numpy.random import permutation
permutation(list(map(lambda x: x == 5, [4,5,3])))

array([False, False,  True])

# toward video proof

In [2]:
from celluloid import Camera
from IPython.display import Video
import matplotlib.pyplot as plt
import numpy as np

from cpr_reputation import board
from cpr_reputation.utils import retrieve_checkpoint

In [5]:
checkpoint = retrieve_checkpoint("//home/quinn/ray_results")

In [7]:
# load checkpoint 
# run episode 
# with Camera calls in it
