**Note**: When running this notebook on SageMaker Studio, make sure the 'SageMaker JumpStart Tensorflow 1.0' image/kernel is selected. You can run all cells at once or step through the notebook.
# Policy Training

This notebook outlines the steps involved in building and deploying a Battlesnake model using Ray RLlib and TensorFlow on Amazon SageMaker.

Library versions currently in use: TensorFlow 2.1, Ray RLlib 0.8.2

The model is first trained using multi-agent PPO, and then deployed to a managed _TensorFlow Serving_ SageMaker endpoint that can be used for inference.

In [1]:
import sagemaker
from sagemaker.rl import RLEstimator, RLToolkit
import boto3
import botocore
import json

In [2]:
with open("../stack_outputs.json") as f:
    info = json.load(f)

## Initialise SageMaker
We need to define several parameters before running the training job.

In [3]:
sm_session = sagemaker.session.Session()
s3_bucket = info["S3Bucket"]

s3_output_path = 's3://{}/'.format(s3_bucket)
print("S3 bucket path: {}".format(s3_output_path))

S3 bucket path: s3://sagemaker-soln-bs-rbcsnake-bucket/


In [4]:
job_name_prefix = info["SolutionPrefix"]+'-job-rllib'

role = info["SageMakerIamRoleArn"]
print(role)

arn:aws:iam::018864217387:role/sagemaker-soln-bs-us-west-2-nb-role


Set `local_mode` to `True` if you want to train locally within this notebook instance.

In [5]:
local_mode = False

if local_mode:
    instance_type = 'local'
else:
    instance_type = info["SagemakerTrainingInstanceType"]
    
# If training locally, do some Docker housekeeping..
if local_mode:
    !/bin/bash ./common/setup.sh

# Train your model here

In [6]:
region = sm_session.boto_region_name
device = "cpu"
image_name = '462105765813.dkr.ecr.{region}.amazonaws.com/sagemaker-rl-ray-container:ray-0.8.2-tf-{device}-py36'.format(region=region, device=device)
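
The training cell that follows registers a list of `metric_definitions`, each pairing a metric name with a regular expression that is matched against the training log. The sketch below is one way to sanity-check such a pattern locally before launching a job. The helper `extract_metric` and the `sample_log` text are hypothetical and exist only for illustration; the numeric pattern is the one used in this notebook, and the log lines are copied from the result format shown in the training output.

```python
import re

# Numeric pattern copied from the metric definitions in this notebook.
NUMBER = r"([-+]?[0-9]*[.]?[0-9]+([eE][-+]?[0-9]+)?)"


def extract_metric(name, log_text):
    """Return every value logged as '<name>: <number>' in log_text, as floats.

    Hypothetical helper for local checks; it is not part of the SageMaker SDK.
    """
    pattern = re.compile(re.escape(name) + ": " + NUMBER)
    return [float(match.group(1)) for match in pattern.finditer(log_text)]


# A few lines in the same format as the Ray result blocks in the training log.
sample_log = """
  best_snake_episode_len_max: 15
  custom_metrics:
    Forbidden_move_max: 10
    Forbidden_move_mean: 3.825323475046211
"""

print(extract_metric("Forbidden_move_max", sample_log))
print(extract_metric("best_snake_episode_len_max", sample_log))
```

A pattern that returns an empty list here would not pick up any values from log lines in this format, which makes a mismatch between a metric name and its regex easy to spot.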

In [8]:
%%time

# Define and execute our training job
# Adjust hyperparameters and train_instance_count accordingly

# Each regex must match the log line for the metric it is named after
metric_definitions = [
    {'Name': 'training_iteration', 'Regex': 'training_iteration: ([-+]?[0-9]*[.]?[0-9]+([eE][-+]?[0-9]+)?)'},
    {'Name': 'episodes_total', 'Regex': 'episodes_total: ([-+]?[0-9]*[.]?[0-9]+([eE][-+]?[0-9]+)?)'},
    {'Name': 'num_steps_trained', 'Regex': 'num_steps_trained: ([-+]?[0-9]*[.]?[0-9]+([eE][-+]?[0-9]+)?)'},
    {'Name': 'timesteps_total', 'Regex': 'timesteps_total: ([-+]?[0-9]*[.]?[0-9]+([eE][-+]?[0-9]+)?)'},

    {'Name': 'episode_reward_max', 'Regex': 'episode_reward_max: ([-+]?[0-9]*[.]?[0-9]+([eE][-+]?[0-9]+)?)'},
    {'Name': 'episode_reward_mean', 'Regex': 'episode_reward_mean: ([-+]?[0-9]*[.]?[0-9]+([eE][-+]?[0-9]+)?)'},
    {'Name': 'episode_reward_min', 'Regex': 'episode_reward_min: ([-+]?[0-9]*[.]?[0-9]+([eE][-+]?[0-9]+)?)'},

    {'Name': 'episode_len_max', 'Regex': 'episode_len_max: ([-+]?[0-9]*[.]?[0-9]+([eE][-+]?[0-9]+)?)'},
    {'Name': 'episode_len_mean', 'Regex': 'episode_len_mean: ([-+]?[0-9]*[.]?[0-9]+([eE][-+]?[0-9]+)?)'},
    {'Name': 'episode_len_min', 'Regex': 'episode_len_min: ([-+]?[0-9]*[.]?[0-9]+([eE][-+]?[0-9]+)?)'},

    {'Name': 'best_snake_episode_len_max', 'Regex': 'best_snake_episode_len_max: ([-+]?[0-9]*[.]?[0-9]+([eE][-+]?[0-9]+)?)'},
    {'Name': 'worst_snake_episode_len_max', 'Regex': 'worst_snake_episode_len_max: ([-+]?[0-9]*[.]?[0-9]+([eE][-+]?[0-9]+)?)'},

    {'Name': 'Snake_hit_wall_max', 'Regex': 'Snake_hit_wall_max: ([-+]?[0-9]*[.]?[0-9]+([eE][-+]?[0-9]+)?)'},
    {'Name': 'Snake_was_eaten_max', 'Regex': 'Snake_was_eaten_max: ([-+]?[0-9]*[.]?[0-9]+([eE][-+]?[0-9]+)?)'},
    {'Name': 'Killed_another_snake_max', 'Regex': 'Killed_another_snake_max: ([-+]?[0-9]*[.]?[0-9]+([eE][-+]?[0-9]+)?)'},
    {'Name': 'Snake_hit_body_max', 'Regex': 'Snake_hit_body_max: ([-+]?[0-9]*[.]?[0-9]+([eE][-+]?[0-9]+)?)'},
    {'Name': 'Starved_max', 'Regex': 'Starved_max: ([-+]?[0-9]*[.]?[0-9]+([eE][-+]?[0-9]+)?)'},
    {'Name': 'Forbidden_move_max', 'Regex': 'Forbidden_move_max: ([-+]?[0-9]*[.]?[0-9]+([eE][-+]?[0-9]+)?)'}
]

algorithm = "PPO"
map_size = 11
num_agents = 5
additional_config = {
    'lambda': 0.90,
    'gamma': 0.999,
    'kl_coeff': 0.2,
    'clip_rewards': True,
    'vf_clip_param': 175.0,
    'train_batch_size': 9216,
    'sample_batch_size': 96,
    'sgd_minibatch_size': 256,
    'num_sgd_iter': 3,
    'lr': 5.0e-4,
}

estimator = RLEstimator(entry_point="train-mabs.py",
                        source_dir='training/training_src',
                        dependencies=["training/common/sagemaker_rl", "inference/inference_src/", "../BattlesnakeGym/"],
                        image_uri=image_name,
                        role=role,
                        train_instance_type=instance_type,
                        train_instance_count=1,
                        output_path=s3_output_path,
                        base_job_name=job_name_prefix,
                        metric_definitions=metric_definitions,
                        hyperparameters={
                            # See train-mabs.py to add additional hyperparameters
                            # Also see ray_launcher.py for the rl.training.* hyperparameters
                            
                            "num_iters": 10,
                            # number of snakes in the gym
                            "num_agents": num_agents,

                            "iterate_map_size": False,
                            "map_size": map_size,
                            "algorithm": algorithm,
                            "additional_configs": additional_config,
                            "use_heuristics_action_masks": False
                        }
                    )

estimator.fit()

job_name = estimator.latest_training_job.job_name
print("Training job: %s" % job_name)

train_instance_count has been renamed in sagemaker>=2.
See: https://sagemaker.readthedocs.io/en/stable/v2.html for details.
train_instance_type has been renamed in sagemaker>=2.
See: https://sagemaker.readthedocs.io/en/stable/v2.html for details.


2022-02-25 02:29:54 Starting - Starting the training job...
2022-02-25 02:29:56 Starting - Launching requested ML instances
ProfilerReport-1645756193: InProgress
.........
2022-02-25 02:31:52 Starting - Preparing the instances for training......
2022-02-25 02:32:53 Downloading - Downloading input data
2022-02-25 02:32:53 Training - Downloading the training image......
2022-02-25 02:33:53 Training - Training image download completed. Training in progress.
bash: cannot set terminal process group (-1): Inappropriate ioctl for device
bash: no job control in this shell
2022-02-25 02:33:44,450 sagemaker-containers INFO     Imported framework sagemaker_tensorflow_container.training
2022-02-25 02:33:44,457 sagemaker-containers INFO     No GPUs detected (normal if no gpus installed)
2022-02-25 02:33:44,627 sagemaker-containers INFO     Installing module with the following command:
/usr/bin/python3 -m pip install . -r requirements.txt
Proc

2022-02-25 02:33:50,179 INFO resource_spec.py:212 -- Starting Ray with 6.69 GiB memory available for workers and up to 3.36 GiB for objects. You can adjust these settings with ray.init(memory=<bytes>, object_store_memory=<bytes>).
2022-02-25 02:33:50,597 INFO services.py:1078 -- View the Ray dashboard at localhost:8265
No checkpoint path specified. Training from scratch.
Important! Ray with version <=7.2 may report "Did not find checkpoint file" even if the experiment is actually restored successfully. If restoration is expected, please check "training_iteration" in the experiment info to confirm.
== Status ==
Memory usage on this node: 2.3/15.4 GiB
Using FIFO scheduling algorithm.
Resources requested: 4/4 CPUs, 0/0 GPUs, 0.0/6.69 GiB heap, 0.0/2.29 GiB objects
Result logdir: /opt/ml/output/intermediate/training
Number of trials: 1 (1 RUNNING)
+-----------

(pid=116)   obj = yaml.load(type_)
(pid=114)   obj = yaml.load(type_)
(pid=115)   obj = yaml.load(type_)
(pid=117) 2022-02-25 02:34:33,590 INFO trainable.py:178 -- _setup took 39.392 seconds. If your trainable is slow to initialize, consider setting reuse_actors=True to reduce actor creation overheads.
Result for PPO_MultiAgentBattlesnake-v1_5e0c2aea:
  best_snake_episode_len_max: 15
  custom_metrics:
    Forbidden_move_max: 10
    Forbidden_move_mean: 3.825323475046211
    Forbidden_move_min: 0
    Killed_another_snake_max: 3
    Killed_another_snake_mean: 0.10905730129390019
    Killed_another_snake_min: 0
    Snake_hit_body_max: 4
    Snake_hit_body_mean: 0.24306839186691312
    Snake_hit_body_min: 0
    Snake_hit_wall_max: 7
    Snake_hit_wall_mean: 1.3969500924214417
    Snake_hit_wall_min: 0
    Snake_was_eaten_max: 8
    Snake_was_eaten_mean: 0.3257

Result for PPO_MultiAgentBattlesnake-v1_5e0c2aea:
  best_snake_episode_len_max: 15
  custom_metrics:
    Forbidden_move_max: 10
    Forbidden_move_mean: 4.035846072746442
    Forbidden_move_min: 0
    Killed_another_snake_max: 3
    Killed_another_snake_mean: 0.13547706905640486
    Killed_another_snake_min: 0
    Snake_hit_body_max: 4
    Snake_hit_body_mean: 0.2836056931997891
    Snake_hit_body_min: 0
    Snake_hit_wall_max: 6
    Snake_hit_wall_mean: 1.0542962572482868
    Snake_hit_wall_min: 0
    Snake_was_eaten_max: 8
    Snake_was_eaten_mean: 0.29467580390089615
    Snake_was_eaten_min: 0
    Starved_max: 0
    Starved_mean: 0.0
    Starved_min: 0
    policy0_max_len_max: 13
    policy0_max_len_mean: 2.518186610437533
    policy0_max_len_min: 0
    policy1_max_len_max: 13
    policy1_max_len_mean: 2.6389035318924616
    policy1_max_len_min: 0
    policy2_max_len_max: 15
    policy2_max_len_mean: 2.4802319451765946
    policy2_max_len_min: 0
    policy3_max_len_max: 13


Result for PPO_MultiAgentBattlesnake-v1_5e0c2aea:
  best_snake_episode_len_max: 22
  custom_metrics:
    Forbidden_move_max: 9
    Forbidden_move_mean: 4.139118457300276
    Forbidden_move_min: 0
    Killed_another_snake_max: 3
    Killed_another_snake_mean: 0.19765840220385675
    Killed_another_snake_min: 0
    Snake_hit_body_max: 6
    Snake_hit_body_mean: 0.4380165289256198
    Snake_hit_body_min: 0
    Snake_hit_wall_max: 4
    Snake_hit_wall_mean: 0.46763085399449034
    Snake_hit_wall_min: 0
    Snake_was_eaten_max: 8
    Snake_was_eaten_mean: 0.4724517906336088
    Snake_was_eaten_min: 0
    Starved_max: 0
    Starved_mean: 0.0
    Starved_min: 0
    policy0_max_len_max: 22
    policy0_max_len_mean: 3.319559228650138
    policy0_max_len_min: 0
    policy1_max_len_max: 22
    policy1_max_len_mean: 3.4056473829201104
    policy1_max_len_min: 0
    policy2_max_len_max: 22
    policy2_max_len_mean: 3.462121212121212
    policy2_max_len_min: 0
    policy3_max_len_max: 21
    po

Result for PPO_MultiAgentBattlesnake-v1_5e0c2aea:
  best_snake_episode_len_max: 24
  custom_metrics:
    Forbidden_move_max: 8
    Forbidden_move_mean: 3.949615713065756
    Forbidden_move_min: 0
    Killed_another_snake_max: 3
    Killed_another_snake_mean: 0.28095644748078563
    Killed_another_snake_min: 0
    Snake_hit_body_max: 5
    Snake_hit_body_mean: 0.5943637916310845
    Snake_hit_body_min: 0
    Snake_hit_wall_max: 4
    Snake_hit_wall_mean: 0.21007685738684884
    Snake_hit_wall_min: 0
    Snake_was_eaten_max: 8
    Snake_was_eaten_mean: 0.6643894107600341
    Snake_was_eaten_min: 0
    Starved_max: 0
    Starved_mean: 0.0
    Starved_min: 0
    policy0_max_len_max: 23
    policy0_max_len_mean: 3.987190435525192
    policy0_max_len_min: 0
    policy1_max_len_max: 21
    policy1_max_len_mean: 4.4534585824081985
    policy1_max_len_min: 0
    policy2_max_len_max: 24
    policy2_max_len_mean: 4.5687446626814685
    policy2_max_len_min: 0
    policy3_max_len_max: 23
    p

Result for PPO_MultiAgentBattlesnake-v1_5e0c2aea:
  best_snake_episode_len_max: 37
  custom_metrics:
    Forbidden_move_max: 9
    Forbidden_move_mean: 3.698717948717949
    Forbidden_move_min: 0
    Killed_another_snake_max: 4
    Killed_another_snake_mean: 0.33974358974358976
    Killed_another_snake_min: 0
    Snake_hit_body_max: 5
    Snake_hit_body_mean: 0.7649572649572649
    Snake_hit_body_min: 0
    Snake_hit_wall_max: 3
    Snake_hit_wall_mean: 0.12072649572649573
    Snake_hit_wall_min: 0
    Snake_was_eaten_max: 6
    Snake_was_eaten_mean: 0.7435897435897436
    Snake_was_eaten_min: 0
    Starved_max: 0
    Starved_mean: 0.0
    Starved_min: 0
    policy0_max_len_max: 37
    policy0_max_len_mean: 4.990384615384615
    policy0_max_len_min: 0
    policy1_max_len_max: 27
    policy1_max_len_mean: 5.64957264957265
    policy1_max_len_min: 0
    policy2_max_len_max: 29
    policy2_max_len_mean: 5.762820512820513
    policy2_max_len_min: 0
    policy3_max_len_max: 36
    poli

Saved the checkpoint file /opt/ml/output/intermediate/training/PPO_MultiAgentBattlesnake-v1_5e0c2aea_0_2022-02-25_02-33-50a3hcdivz/checkpoint_10/checkpoint-10 as /opt/ml/model/checkpoint
Saved the checkpoint file /opt/ml/output/intermediate/training/PPO_MultiAgentBattlesnake-v1_5e0c2aea_0_2022-02-25_02-33-50a3hcdivz/checkpoint_10/checkpoint-10.tune_metadata as /opt/ml/model/checkpoint.tune_metadata
2022-02-25 02:43:25,769 INFO trainer.py:420 -- Tip: set 'eager': true or the --eager flag to enable TensorFlow eager execution
2022-02-25 02:43:25,778 INFO trainer.py:580 -- Current log_level is WARN. For more information, set 'log_level': 'INFO' / 'DEBUG' or use the -v and -vv flags.
  obj = yaml.load(type_)
  obj = yaml.load(type_)
  obj = yaml.load(type_)
  obj = yaml.load(type_)
  obj = yaml.load(type_)
(pid=1044)   obj = yaml.load(type_)
(pid=1044)   obj = yaml.load(ty

In [9]:
# Where is the model stored in S3?
estimator.model_data

's3://sagemaker-soln-bs-rbcsnake-bucket/sagemaker-soln-bs-job-rllib-2022-02-25-02-29-53-333/output/model.tar.gz'
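
The value of `estimator.model_data` is a plain S3 URI. If you need the bucket and object key separately, for example to pass them to a boto3 S3 call, a minimal sketch using only the standard library could look like this. The helper name `split_s3_uri` is hypothetical; the URI is the one printed above.

```python
from urllib.parse import urlparse


def split_s3_uri(uri):
    """Split 's3://bucket/key' into (bucket, key).

    Hypothetical helper for illustration; raises ValueError for non-S3 URIs.
    """
    parsed = urlparse(uri)
    if parsed.scheme != "s3" or not parsed.netloc:
        raise ValueError("not an S3 URI: {}".format(uri))
    # The path starts with '/', which is not part of the object key.
    return parsed.netloc, parsed.path.lstrip("/")


bucket, key = split_s3_uri(
    "s3://sagemaker-soln-bs-rbcsnake-bucket/"
    "sagemaker-soln-bs-job-rllib-2022-02-25-02-29-53-333/output/model.tar.gz"
)
print(bucket)
print(key)
```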

# Create an endpoint to host the policy
First, we delete the previous endpoint, together with its endpoint configuration and model.

In [10]:
sm_client = boto3.client(service_name='sagemaker')
waiter = sm_client.get_waiter('endpoint_in_service')
waiter.wait(EndpointName=info['SagemakerEndPointName'])
try:
    sm_client.delete_endpoint(EndpointName=info['SagemakerEndPointName'])
    sm_client.delete_endpoint_config(EndpointConfigName=info['SagemakerEndPointName'])
    sm_client.delete_model(ModelName=info['SagemakerEndPointName'])
    ep_waiter = sm_client.get_waiter('endpoint_deleted')
    ep_waiter.wait(EndpointName=info['SagemakerEndPointName'])
except botocore.exceptions.ClientError:
    pass
    
# Copy the trained model to a central location
model_data = "s3://{}/pretrainedmodels/model.tar.gz".format(s3_bucket)
!aws s3 cp {estimator.model_data} {model_data}

from sagemaker.tensorflow.serving import Model

model = Model(model_data=model_data,
              role=role,
              entry_point="inference.py",
              source_dir='inference/inference_src',
              framework_version='2.1.0',
              name=info['SagemakerEndPointName'],
              code_location='s3://{}/code'.format(s3_bucket)
             )

if local_mode:
    inf_instance_type = 'local'
else:
    inf_instance_type = info["SagemakerInferenceInstanceType"]

# Deploy an inference endpoint
predictor = model.deploy(initial_instance_count=1, instance_type=inf_instance_type,
                         endpoint_name=info['SagemakerEndPointName'])

Completed 6.2 MiB/6.2 MiB (32.9 MiB/s) with 1 file(s) remaining
copy: s3://sagemaker-soln-bs-rbcsnake-bucket/sagemaker-soln-bs-job-rllib-2022-02-25-02-29-53-333/output/model.tar.gz to s3://sagemaker-soln-bs-rbcsnake-bucket/pretrainedmodels/model.tar.gz


The class sagemaker.tensorflow.serving.Model has been renamed in sagemaker>=2.
See: https://sagemaker.readthedocs.io/en/stable/v2.html for details.
update_endpoint is a no-op in sagemaker>=2.
See: https://sagemaker.readthedocs.io/en/stable/v2.html for details.


------!

# Test the endpoint

This example uses a single observation for a 5-agent environment.
The last axis is 12 because the current MultiAgentEnv concatenates 2 frames:
5 agent maps + 1 food map = 6 maps per frame, and 6 maps * 2 frames = 12.
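
The channel arithmetic above can be written out as a small sketch. The helper names `observation_channels` and `observation_shape` are hypothetical and are not part of the BattlesnakeGym code; they only restate the formula (agent maps + food map) * frames.

```python
def observation_channels(num_agents, num_frames, num_food_maps=1):
    """Size of the last axis: one map per agent plus the food map, per frame."""
    return (num_agents + num_food_maps) * num_frames


def observation_shape(map_side, num_agents, num_frames, batch_size=1):
    """Full observation shape as (batch, height, width, channels)."""
    channels = observation_channels(num_agents, num_frames)
    return (batch_size, map_side, map_side, channels)


# 5 agents and 2 concatenated frames, as described in the text.
print(observation_channels(num_agents=5, num_frames=2))
print(observation_shape(map_side=21, num_agents=5, num_frames=2))
```

With `num_frames=1` the same formula gives 6, which is the last axis of the `state` array used in the next cell.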

In [11]:
import numpy as np
from time import time

state = np.zeros(shape=(1, 21, 21, 6), dtype=np.float32).tolist()

health_dict = {0: 50, 1: 50}
# Named request_json so that it does not shadow the json module imported earlier
request_json = {"turn": 4,
                "board": {
                    "height": 11,
                    "width": 11,
                    "food": [],
                    "snakes": []
                },
                "you": {
                    "id": "snake-id-string",
                    "name": "Sneky Snek",
                    "health": 90,
                    "body": [{"x": 1, "y": 3}]
                }
               }

action_mask = np.array([1, 1, 1, 1]).tolist()

# Time only the call to the endpoint
before = time()
action = predictor.predict({"state": state, "action_mask": action_mask,
                            "prev_action": -1,
                            "prev_reward": -1, "seq_lens": -1,
                            "all_health": health_dict, "json": request_json})
elapsed = time() - before

action_to_take = action["outputs"]["heuristisc_action"]
print("Action to take {}".format(action_to_take))
print("Inference took %.2f ms" % (elapsed*1000))

Action to take 0
Inference took 752.89 ms
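
A single timed call like the one above can vary from run to run, so it may be worth repeating the request and looking at the spread. The sketch below assumes a hypothetical helper `time_calls` and uses a stand-in `fake_predict` function so that it runs without an endpoint; in the notebook you would pass `predictor.predict` and the request dictionary instead.

```python
import statistics
from time import perf_counter


def time_calls(predict_fn, payload, repeats=5):
    """Call predict_fn(payload) several times and summarise latency in ms.

    Hypothetical helper for illustration only.
    """
    latencies_ms = []
    for _ in range(repeats):
        start = perf_counter()
        predict_fn(payload)
        latencies_ms.append((perf_counter() - start) * 1000)
    return {
        "min_ms": min(latencies_ms),
        "median_ms": statistics.median(latencies_ms),
        "max_ms": max(latencies_ms),
    }


def fake_predict(payload):
    """Stand-in for predictor.predict so the sketch runs offline."""
    return {"outputs": {"heuristisc_action": 0}}


stats = time_calls(fake_predict, {"state": []})
print(stats)
```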


# Navigation
- To go back to the introduction click [here](./1_Introduction.ipynb)
- To build some heuristics click [here](./3_HeuristicsDeveloper.ipynb)