**Note**: When running this notebook on SageMaker Studio, you should make sure the 'SageMaker JumpStart Tensorflow 1.0' image/kernel is used. You can run run all cells at once or step through the notebook.
# Policy Training

This notebook outlines the steps involved in building and deploying a Battlesnake model using Ray RLlib and TensorFlow on Amazon SageMaker.

Library versions currently in use:  TensorFlow 2.1, Ray RLlib 0.8.2

The model is first trained using multi-agent PPO, and then deployed to a managed _TensorFlow Serving_ SageMaker endpoint that can be used for inference.

In [1]:
import sagemaker
from sagemaker.rl import RLEstimator, RLToolkit
import boto3
import botocore
import json

In [10]:
with open("../stack_outputs.json") as f:
    info = json.load(f)
print(info)

{'AwsAccountId': '018864217387', 'AwsRegion': 'us-west-2', 'S3Bucket': 'sagemaker-soln-bs-rbcsnake-bucket', 'SolutionPrefix': 'sagemaker-soln-bs', 'SageMakerIamRoleArn': 'arn:aws:iam::018864217387:role/sagemaker-soln-bs-us-west-2-nb-role', 'SnakeAPI': 'https://56fbd5ro6k.execute-api.us-west-2.amazonaws.com/snake/', 'EndPointS3Location': 's3://sagemaker-solutions-prod-us-west-2/sagemaker-battlesnake-ai/1.1.0/build/model-complete.tar.gz', 'SagemakerEndPointName': 'sagemaker-soln-bs-ep', 'SagemakerTrainingInstanceType': 'ml.m5.xlarge', 'SagemakerInferenceInstanceType': 'ml.t2.medium'}


## Initialise sagemaker
We need to define several parameters prior to running the training job. 

In [3]:
sm_session = sagemaker.session.Session()
s3_bucket = info["S3Bucket"]

s3_output_path = 's3://{}/'.format(s3_bucket)
print("S3 bucket path: {}".format(s3_output_path))

S3 bucket path: s3://sagemaker-soln-bs-rbcsnake-bucket/


In [4]:
job_name_prefix = info["SolutionPrefix"]+'-job-rllib'

role = info["SageMakerIamRoleArn"]
print(role)

arn:aws:iam::018864217387:role/sagemaker-soln-bs-us-west-2-nb-role


Change local_mode to True if you want to do local training within this Notebook instance

In [5]:
local_mode = False

if local_mode:
    instance_type = 'local'
else:
    instance_type = info["SagemakerTrainingInstanceType"]
    
# If training locally, do some Docker housekeeping..
if local_mode:
    !/bin/bash ./common/setup.sh

# Train your model here

In [6]:
region = sm_session.boto_region_name
device = "cpu"
image_name = '462105765813.dkr.ecr.{region}.amazonaws.com/sagemaker-rl-ray-container:ray-0.8.2-tf-{device}-py36'.format(region=region, device=device)

In [7]:
%%time

# Define and execute our training job
# Adjust hyperparameters and train_instance_count accordingly

metric_definitions =  [
    {'Name': 'training_iteration', 'Regex': 'training_iteration: ([-+]?[0-9]*[.]?[0-9]+([eE][-+]?[0-9]+)?)'}, 
    {'Name': 'episodes_total', 'Regex': 'episodes_total: ([-+]?[0-9]*[.]?[0-9]+([eE][-+]?[0-9]+)?)'}, 
    {'Name': 'num_steps_trained', 'Regex': 'num_steps_trained: ([-+]?[0-9]*[.]?[0-9]+([eE][-+]?[0-9]+)?)'}, 
    {'Name': 'timesteps_total', 'Regex': 'timesteps_total: ([-+]?[0-9]*[.]?[0-9]+([eE][-+]?[0-9]+)?)'},
    {'Name': 'training_iteration', 'Regex': 'training_iteration: ([-+]?[0-9]*[.]?[0-9]+([eE][-+]?[0-9]+)?)'},

    {'Name': 'episode_reward_max', 'Regex': 'episode_reward_max: ([-+]?[0-9]*[.]?[0-9]+([eE][-+]?[0-9]+)?)'}, 
    {'Name': 'episode_reward_mean', 'Regex': 'episode_reward_mean: ([-+]?[0-9]*[.]?[0-9]+([eE][-+]?[0-9]+)?)'}, 
    {'Name': 'episode_reward_min', 'Regex': 'episode_reward_min: ([-+]?[0-9]*[.]?[0-9]+([eE][-+]?[0-9]+)?)'}, 
    
    {'Name': 'episode_len_max', 'Regex': 'episode_len_mean: ([-+]?[0-9]*[.]?[0-9]+([eE][-+]?[0-9]+)?)'}, 
    {'Name': 'episode_len_mean', 'Regex': 'episode_len_mean: ([-+]?[0-9]*[.]?[0-9]+([eE][-+]?[0-9]+)?)'}, 
    {'Name': 'episode_len_min', 'Regex': 'episode_len_mean: ([-+]?[0-9]*[.]?[0-9]+([eE][-+]?[0-9]+)?)'}, 

    {'Name': 'best_snake_episode_len_max', 'Regex': 'best_snake_episode_len_max: ([-+]?[0-9]*[.]?[0-9]+([eE][-+]?[0-9]+)?)'},
    {'Name': 'worst_snake_episode_len_max', 'Regex': 'worst_snake_episode_len_max: ([-+]?[0-9]*[.]?[0-9]+([eE][-+]?[0-9]+)?)'},

    {'Name': 'Snake_hit_wall_max', 'Regex': 'Snake_hit_wall_max: ([-+]?[0-9]*[.]?[0-9]+([eE][-+]?[0-9]+)?)'},
    {'Name': 'Snake_was_eaten_max', 'Regex': 'Snake_was_eaten_max: ([-+]?[0-9]*[.]?[0-9]+([eE][-+]?[0-9]+)?)'},
    {'Name': 'Killed_another_snake_max', 'Regex': 'Killed_another_snake_max: ([-+]?[0-9]*[.]?[0-9]+([eE][-+]?[0-9]+)?)'},
    {'Name': 'Snake_hit_body_max', 'Regex': 'Snake_hit_body_max: ([-+]?[0-9]*[.]?[0-9]+([eE][-+]?[0-9]+)?)'},
    {'Name': 'Starved_max', 'Regex': 'Starved_max: ([-+]?[0-9]*[.]?[0-9]+([eE][-+]?[0-9]+)?)'},
    {'Name': 'Forbidden_move_max', 'Regex': 'Forbidden_move_max: ([-+]?[0-9]*[.]?[0-9]+([eE][-+]?[0-9]+)?)'}
] 

algorithm = "PPO"
map_size = 11
num_agents = 5
additional_config = {
    'lambda': 0.90,
    'gamma': 0.999,
    'kl_coeff': 0.2,
    'clip_rewards': True,
    'vf_clip_param': 175.0,
    'train_batch_size': 9216,
    'sample_batch_size': 96,
    'sgd_minibatch_size': 256,
    'num_sgd_iter': 3,
    'lr': 5.0e-4,
}

estimator = RLEstimator(entry_point="train-mabs.py",
                        source_dir='training/training_src',
                        dependencies=["training/common/sagemaker_rl", "inference/inference_src/", "../BattlesnakeGym/"],
                        image_uri=image_name,
                        role=role,
                        train_instance_type=instance_type,
                        train_instance_count=1,
                        output_path=s3_output_path,
                        base_job_name=job_name_prefix,
                        metric_definitions=metric_definitions,
                        hyperparameters={
                            # See train-mabs.py to add additional hyperparameters
                            # Also see ray_launcher.py for the rl.training.* hyperparameters
                            
                            "num_iters": 10,
                            # number of snakes in the gym
                            "num_agents": num_agents,

                            "iterate_map_size": False,
                            "map_size": map_size,
                            "algorithm": algorithm,
                            "additional_configs": additional_config,
                            "use_heuristics_action_masks": False
                        }
                    )

estimator.fit()

job_name = estimator.latest_training_job.job_name
print("Training job: %s" % job_name)

train_instance_count has been renamed in sagemaker>=2.
See: https://sagemaker.readthedocs.io/en/stable/v2.html for details.
train_instance_type has been renamed in sagemaker>=2.
See: https://sagemaker.readthedocs.io/en/stable/v2.html for details.


2022-03-02 20:06:52 Starting - Starting the training job...
2022-03-02 20:07:17 Starting - Launching requested ML instancesProfilerReport-1646251612: InProgress
.........
2022-03-02 20:08:37 Starting - Preparing the instances for training...
2022-03-02 20:09:19 Downloading - Downloading input data
2022-03-02 20:09:19 Training - Downloading the training image......
2022-03-02 20:10:23 Training - Training image download completed. Training in progress..[34mbash: cannot set terminal process group (-1): Inappropriate ioctl for device[0m
[34mbash: no job control in this shell[0m
[34m2022-03-02 20:10:27,965 sagemaker-containers INFO     Imported framework sagemaker_tensorflow_container.training[0m
[34m2022-03-02 20:10:27,972 sagemaker-containers INFO     No GPUs detected (normal if no gpus installed)[0m
[34m2022-03-02 20:10:28,160 sagemaker-containers INFO     Installing module with the following command:[0m
[34m/usr/bin/python3 -m pip install . -r requirements.txt[0m
[34mProces

[34m#033[2m#033[36m(pid=111)#033[0m   obj = yaml.load(type_)[0m
[34m#033[2m#033[36m(pid=111)#033[0m   obj = yaml.load(type_)[0m
[34m#033[2m#033[36m(pid=111)#033[0m   obj = yaml.load(type_)[0m
[34m#033[2m#033[36m(pid=111)#033[0m   obj = yaml.load(type_)[0m
[34m#033[2m#033[36m(pid=112)#033[0m   obj = yaml.load(type_)[0m
[34m#033[2m#033[36m(pid=113)#033[0m   obj = yaml.load(type_)[0m
[34m#033[2m#033[36m(pid=110)#033[0m   obj = yaml.load(type_)[0m
[34m#033[2m#033[36m(pid=113)#033[0m   obj = yaml.load(type_)[0m
[34m#033[2m#033[36m(pid=112)#033[0m   obj = yaml.load(type_)[0m
[34m#033[2m#033[36m(pid=110)#033[0m   obj = yaml.load(type_)[0m
[34m#033[2m#033[36m(pid=113)#033[0m   obj = yaml.load(type_)[0m
[34m#033[2m#033[36m(pid=110)#033[0m   obj = yaml.load(type_)[0m
[34m#033[2m#033[36m(pid=112)#033[0m   obj = yaml.load(type_)[0m
[34m#033[2m#033[36m(pid=113)#033[0m   obj = yaml.load(type_)[0m
[34m#033[2m#033[36m(pid=110)#033[0m   obj = yaml.load(type_)[0m
[34m#033[

[34mResult for PPO_MultiAgentBattlesnake-v1_d1764e86:
  best_snake_episode_len_max: 14
  custom_metrics:
    Forbidden_move_max: 10
    Forbidden_move_mean: 3.783977110157368
    Forbidden_move_min: 0
    Killed_another_snake_max: 3
    Killed_another_snake_mean: 0.12875536480686695
    Killed_another_snake_min: 0
    Snake_hit_body_max: 6
    Snake_hit_body_mean: 0.27801621363853124
    Snake_hit_body_min: 0
    Snake_hit_wall_max: 8
    Snake_hit_wall_mean: 1.3218884120171674
    Snake_hit_wall_min: 0
    Snake_was_eaten_max: 8
    Snake_was_eaten_mean: 0.3562231759656652
    Snake_was_eaten_min: 0
    Starved_max: 0
    Starved_mean: 0.0
    Starved_min: 0
    policy0_max_len_max: 12
    policy0_max_len_mean: 2.2460658082975677
    policy0_max_len_min: 0
    policy1_max_len_max: 11
    policy1_max_len_mean: 2.290891750119218
    policy1_max_len_min: 0
    policy2_max_len_max: 13
    policy2_max_len_mean: 2.2889842632331905
    policy2_max_len_min: 0
    policy3_max_len_max: 12
    

[34mResult for PPO_MultiAgentBattlesnake-v1_d1764e86:
  best_snake_episode_len_max: 18
  custom_metrics:
    Forbidden_move_max: 8
    Forbidden_move_mean: 4.143020594965675
    Forbidden_move_min: 0
    Killed_another_snake_max: 3
    Killed_another_snake_mean: 0.17791762013729978
    Killed_another_snake_min: 0
    Snake_hit_body_max: 5
    Snake_hit_body_mean: 0.35068649885583525
    Snake_hit_body_min: 0
    Snake_hit_wall_max: 5
    Snake_hit_wall_mean: 0.755720823798627
    Snake_hit_wall_min: 0
    Snake_was_eaten_max: 8
    Snake_was_eaten_mean: 0.40274599542334094
    Snake_was_eaten_min: 0
    Starved_max: 0
    Starved_mean: 0.0
    Starved_min: 0
    policy0_max_len_max: 15
    policy0_max_len_mean: 2.747711670480549
    policy0_max_len_min: 0
    policy1_max_len_max: 18
    policy1_max_len_mean: 2.977688787185355
    policy1_max_len_min: 0
    policy2_max_len_max: 16
    policy2_max_len_mean: 2.9302059496567505
    policy2_max_len_min: 0
    policy3_max_len_max: 18
    po

[34mResult for PPO_MultiAgentBattlesnake-v1_d1764e86:
  best_snake_episode_len_max: 37
  custom_metrics:
    Forbidden_move_max: 8
    Forbidden_move_mean: 4.031660231660232
    Forbidden_move_min: 0
    Killed_another_snake_max: 3
    Killed_another_snake_mean: 0.25096525096525096
    Killed_another_snake_min: 0
    Snake_hit_body_max: 4
    Snake_hit_body_mean: 0.47876447876447875
    Snake_hit_body_min: 0
    Snake_hit_wall_max: 4
    Snake_hit_wall_mean: 0.37065637065637064
    Snake_hit_wall_min: 0
    Snake_was_eaten_max: 8
    Snake_was_eaten_mean: 0.5691119691119692
    Snake_was_eaten_min: 0
    Starved_max: 0
    Starved_mean: 0.0
    Starved_min: 0
    policy0_max_len_max: 20
    policy0_max_len_mean: 3.7166023166023168
    policy0_max_len_min: 0
    policy1_max_len_max: 37
    policy1_max_len_mean: 4.197683397683398
    policy1_max_len_min: 0
    policy2_max_len_max: 36
    policy2_max_len_mean: 4.193050193050193
    policy2_max_len_min: 0
    policy3_max_len_max: 17
    p

[34mResult for PPO_MultiAgentBattlesnake-v1_d1764e86:
  best_snake_episode_len_max: 26
  custom_metrics:
    Forbidden_move_max: 9
    Forbidden_move_mean: 3.8449905482041586
    Forbidden_move_min: 0
    Killed_another_snake_max: 3
    Killed_another_snake_mean: 0.3393194706994329
    Killed_another_snake_min: 0
    Snake_hit_body_max: 5
    Snake_hit_body_mean: 0.6635160680529301
    Snake_hit_body_min: 0
    Snake_hit_wall_max: 3
    Snake_hit_wall_mean: 0.18714555765595464
    Snake_hit_wall_min: 0
    Snake_was_eaten_max: 8
    Snake_was_eaten_mean: 0.6427221172022685
    Snake_was_eaten_min: 0
    Starved_max: 0
    Starved_mean: 0.0
    Starved_min: 0
    policy0_max_len_max: 25
    policy0_max_len_mean: 4.6351606805293
    policy0_max_len_min: 0
    policy1_max_len_max: 26
    policy1_max_len_mean: 5.131379962192817
    policy1_max_len_min: 0
    policy2_max_len_max: 22
    policy2_max_len_mean: 4.934782608695652
    policy2_max_len_min: 0
    policy3_max_len_max: 25
    polic

[34mResult for PPO_MultiAgentBattlesnake-v1_d1764e86:
  best_snake_episode_len_max: 32
  custom_metrics:
    Forbidden_move_max: 8
    Forbidden_move_mean: 3.4829931972789114
    Forbidden_move_min: 0
    Killed_another_snake_max: 4
    Killed_another_snake_mean: 0.4557823129251701
    Killed_another_snake_min: 0
    Snake_hit_body_max: 4
    Snake_hit_body_mean: 0.8469387755102041
    Snake_hit_body_min: 0
    Snake_hit_wall_max: 2
    Snake_hit_wall_mean: 0.14058956916099774
    Snake_hit_wall_min: 0
    Snake_was_eaten_max: 8
    Snake_was_eaten_mean: 0.8458049886621315
    Snake_was_eaten_min: 0
    Starved_max: 0
    Starved_mean: 0.0
    Starved_min: 0
    policy0_max_len_max: 28
    policy0_max_len_mean: 5.4229024943310655
    policy0_max_len_min: 0
    policy1_max_len_max: 31
    policy1_max_len_mean: 6.456916099773243
    policy1_max_len_min: 0
    policy2_max_len_max: 30
    policy2_max_len_mean: 6.187074829931973
    policy2_max_len_min: 0
    policy3_max_len_max: 28
    po

[34m#033[2m#033[36m(pid=1049)#033[0m   obj = yaml.load(type_)[0m
[34m#033[2m#033[36m(pid=1049)#033[0m   obj = yaml.load(type_)[0m
[34m#033[2m#033[36m(pid=1049)#033[0m   obj = yaml.load(type_)[0m
[34m#033[2m#033[36m(pid=1049)#033[0m   obj = yaml.load(type_)[0m
[34m#033[2m#033[36m(pid=1049)#033[0m   obj = yaml.load(type_)[0m
[34m2022-03-02 20:20:50,048#011INFO trainable.py:178 -- _setup took 32.376 seconds. If your trainable is slow to initialize, consider setting reuse_actors=True to reduce actor creation overheads.[0m
[34m2022-03-02 20:20:51,965#011INFO trainable.py:416 -- Restored on 10.0.251.116 from checkpoint: /opt/ml/model/checkpoint[0m
[34m2022-03-02 20:20:51,965#011INFO trainable.py:423 -- Current state after restoring: {'_iteration': 10, '_timesteps_total': 92160, '_time_total': 530.1671938896179, '_episodes_total': 14741}[0m
[34mSaved TensorFlow serving model![0m
[34m2022-03-02 20:20:57,518 sagemaker-containers INFO     Reporting training SUCCESS[0m

2022-0

In [8]:
# Where is the model stored in S3?
estimator.model_data

's3://sagemaker-soln-bs-rbcsnake-bucket/sagemaker-soln-bs-job-rllib-2022-03-02-20-06-52-405/output/model.tar.gz'

# Create an endpoint to host the policy
Firstly, we will delete the previous endpoint and model

In [11]:
sm_client = boto3.client(service_name='sagemaker')
waiter = sm_client.get_waiter('endpoint_in_service')
waiter.wait(EndpointName=info['SagemakerEndPointName'])
try:
    sm_client.delete_endpoint(EndpointName=info['SagemakerEndPointName'])
    sm_client.delete_endpoint_config(EndpointConfigName=info['SagemakerEndPointName'])
    sm_client.delete_model(ModelName=info['SagemakerEndPointName'])
    ep_waiter = sm_client.get_waiter('endpoint_deleted')
    ep_waiter.wait(EndpointName=info['SagemakerEndPointName'])
except botocore.exceptions.ClientError:
    pass
    
# Copy the endpoint to a central location
model_data = "s3://{}/pretrainedmodels/model.tar.gz".format(s3_bucket)
!aws s3 cp {estimator.model_data} {model_data}

from sagemaker.tensorflow.serving import Model

model = Model(model_data=model_data,
              role=role,
              entry_point="inference.py",
              source_dir='inference/inference_src',
              framework_version='2.1.0',
              name=info['SagemakerEndPointName'],
              code_location='s3://{}//code'.format(s3_bucket)
             )

if local_mode:
    inf_instance_type = 'local'
else:
    inf_instance_type = info["SagemakerInferenceInstanceType"]

# Deploy an inference endpoint
predictor = model.deploy(initial_instance_count=1, instance_type=inf_instance_type,
                         endpoint_name=info['SagemakerEndPointName'])

Completed 6.2 MiB/6.2 MiB (23.1 MiB/s) with 1 file(s) remainingcopy: s3://sagemaker-soln-bs-rbcsnake-bucket/sagemaker-soln-bs-job-rllib-2022-03-02-20-06-52-405/output/model.tar.gz to s3://sagemaker-soln-bs-rbcsnake-bucket/pretrainedmodels/model.tar.gz


The class sagemaker.tensorflow.serving.Model has been renamed in sagemaker>=2.
See: https://sagemaker.readthedocs.io/en/stable/v2.html for details.
update_endpoint is a no-op in sagemaker>=2.
See: https://sagemaker.readthedocs.io/en/stable/v2.html for details.


--------!

# Test the endpoint

This example is using single observation for a 5-agent environment 
The last axis is 12 because the current MultiAgentEnv is concatenating 2 frames
5 agent maps + 1 food map = 6 maps total    6 maps * 2 frames = 12

In [12]:
import numpy as np
from time import time

state = np.zeros(shape=(1, 21, 21, 6), dtype=np.float32).tolist()

health_dict = {0: 50, 1: 50}
json = {"turn": 4,
        "board": {
                "height": 11,
                "width": 11,
                "food": [],
                "snakes": []
                },
            "you": {
                "id": "snake-id-string",
                "name": "Sneky Snek",
                "health": 90,
                "body": [{"x": 1, "y": 3}]
                }
            }

before = time()
action_mask = np.array([1, 1, 1, 1]).tolist()

action = predictor.predict({"state": state, "action_mask": action_mask,
                            "prev_action": -1, 
                           "prev_reward": -1, "seq_lens": -1,  
                           "all_health": health_dict, "json": json})
elapsed = time() - before

action_to_take = action["outputs"]["heuristisc_action"]
print("Action to take {}".format(action_to_take))
print("Inference took %.2f ms" % (elapsed*1000))

Action to take 2
Inference took 2757.06 ms


# Navigation
- To go back to the introduction click [here](./1_Introduction.ipynb)
- To build some heuristics click [here](./3_HeuristicsDeveloper.ipynb)