**Note**: When running this notebook on SageMaker Studio, you should make sure the 'SageMaker JumpStart Tensorflow 1.0' image/kernel is used. You can run all cells at once or step through the notebook.

In [1]:
! pip3 install gym

from io import BytesIO
import time
import sys
sys.path.append("../BattlesnakeGym")
import json
import boto3
import botocore
import PIL.Image
import sagemaker

import numpy
import gym
from gym import wrappers
import numpy as np
import matplotlib.pyplot as plt
import tensorflow as tf
%matplotlib inline
from importlib import reload
from IPython import display
import ipywidgets as widgets
from IPython.display import display as i_display

from heuristics_utils import simulate
from battlesnake_gym.snake_gym import BattlesnakeGym




# Introduction
In this notebook, you can take the model that you've developed (or the pre-trained model provided) and develop heuristics to edit your snake's behaviour.

### Define the openAI gym
Optionally, you can define the initial game state (the situation simulator) of the snakes and food.
To use the initial state, set `USE_INITIAL_STATE = True` and enter the desired coordinates of the snake and food using the initial_state dictionary. The dictionary follows the same format as the battlesnake API.

In [2]:
# Set to True to start the simulation from the `initial_state` defined below.
# When False, `initial_state` is replaced by None before it is handed to the gym.
USE_INITIAL_STATE = False

# Sample initial state for the situation simulator.
# The layout follows the Battlesnake API: one entry per food item and one
# entry per snake (health plus the list of body coordinates).
initial_state = {
    "turn": 4,
    "board": {
        "height": 11,
        "width": 11,
        "food": [
            {"x": 1, "y": 3},
        ],
        "snakes": [
            {"health": 90, "body": [{"x": 3, "y": 0}]},
            {"health": 90, "body": [{"x": 6, "y": 0}]},
            {"health": 90, "body": [{"x": 2, "y": 5}]},
            {"health": 90, "body": [{"x": 6, "y": 4}]},
            {"health": 90, "body": [{"x": 7, "y": 3}]},
        ],
    },
}

if not USE_INITIAL_STATE:
    initial_state = None

The parameters here must match the ones provided during training (except initial_state)

In [3]:
# Gym settings: keep these identical to the values used during training
# (initial_state is the only exception).
map_size = (11, 11)
number_of_snakes = 5

# True  -> the simulator outputs random actions and the network is not used.
# False -> the trained network is loaded and used.
random_snake = True

env = BattlesnakeGym(
    map_size=map_size,
    number_of_snakes=number_of_snakes,
    observation_type="max-bordered-51s",
    initial_game_state=initial_state,
)

# Load the trained model
Load the RLlib models.

In [4]:
# Unpack the model archive into inference/output.
# Paths are given explicitly instead of using %cd, so an interrupted run cannot
# leave the kernel inside inference/, and -p lets the cell be re-run when the
# output directory already exists.
!mkdir -p inference/output
!tar -C inference/output -xvf inference/model.tar.gz

/home/ec2-user/SageMaker/RLlibEnv/inference
params.json
checkpoint.tune_metadata
1/
1/saved_model.pb
1/variables/
1/variables/variables.index
1/variables/variables.data-00000-of-00002
1/variables/variables.data-00001-of-00002
checkpoint
/home/ec2-user/SageMaker/RLlibEnv


In [5]:
# Directory of the SavedModel that was extracted from model.tar.gz (the "1/" folder).
model_filepath = "inference/output/1/"
if random_snake:
    # Random-action mode: no network is loaded.
    net = None
else:
    # Load the SavedModel and use its "serving_default" signature as the network.
    # NOTE(review): `imported` is kept as a module-level name; presumably this also
    # keeps the loaded object alive while `net` is in use -- confirm before inlining.
    imported = tf.saved_model.load(model_filepath)
    net = imported.signatures["serving_default"]

In [6]:
# Clean up: remove the directory the model archive was extracted into
!rm -r inference/output

# Simulation loop

Run a simulation with the environment using the heuristics that you wrote. 
To edit the heuristics, edit the file `RLlibEnv/inference/inference_src/battlesnake_heuristics.py`.
Note that you can track the progress of your work with git.

In [7]:
# Import the heuristics module and reload it, so that edits made to the file since
# the last run of this cell are picked up without restarting the kernel.
import inference.inference_src.battlesnake_heuristics
reload(inference.inference_src.battlesnake_heuristics)
from inference.inference_src.battlesnake_heuristics import MyBattlesnakeHeuristics

heuristics = MyBattlesnakeHeuristics()
# Run the simulation; the returned sequences are what the playback cells below read.
infos, rgb_arrays, actions, heuristics_remarks, json_array = simulate(env, net, heuristics, number_of_snakes, random_snake)

snake 0
turn 2
[[0.25586194 0.77410989 0.43122534 0.28888822]]
==wall==
[True, True, True, True]
==forbidden==
[True, True, True, True]
==food==
[True, True, True, True]
[[0.25586194 0.77410989 0.43122534 0.28888822]]
1
====
snake 1
turn 2
[[0.64623334 0.34778004 0.3753329  0.08384057]]
==wall==
[True, True, True, True]
==forbidden==
[True, True, True, True]
==food==
[True, True, True, True]
[[0.64623334 0.34778004 0.3753329  0.08384057]]
0
====
snake 2
turn 2
[[0.8702927  0.53357037 0.40588072 0.60629369]]
==wall==
[True, True, True, True]
==forbidden==
[True, True, True, True]
==food==
[True, True, True, True]
[[0.8702927  0.53357037 0.40588072 0.60629369]]
0
====
snake 3
turn 2
[[0.66383184 0.7020551  0.79637212 0.44775015]]
==wall==
[True, True, True, True]
==forbidden==
[True, True, True, True]
==food==
[True, True, True, True]
[[0.66383184 0.7020551  0.79637212 0.44775015]]
2
====
snake 4
turn 2
[[0.43187263 0.77811562 0.1385129  0.6314563 ]]
==wall==
[True, True, True, True]
==f

snake 1
turn 20
[[0.10072098 0.85770546 0.49732505 0.42559756]]
==wall==
[True, True, True, True]
==forbidden==
[True, True, False, True]
==food==
[True, True, True, True]
[[0.10072098 0.85770546 0.         0.42559756]]
1
====
snake 2
turn 20
[[0.76705094 0.90935555 0.87853788 0.85631413]]
==wall==
[True, True, True, True]
==forbidden==
[True, True, False, True]
==food==
[True, True, True, True]
[[0.76705094 0.90935555 0.         0.85631413]]
1
====
snake 3
turn 20
[[0.98386733 0.10229183 0.52420806 0.41190204]]
==wall==
[True, True, True, True]
==forbidden==
[True, True, False, True]
==food==
[True, True, True, True]
[[0.98386733 0.10229183 0.         0.41190204]]
0
====
snake 1
turn 21
[[0.12384631 0.47500107 0.13694116 0.41183824]]
==wall==
[True, True, True, True]
==forbidden==
[False, True, True, True]
==food==
[True, True, True, True]
[[0.         0.47500107 0.13694116 0.41183824]]
1
====
snake 2
turn 21
[[0.21061306 0.42589094 0.16497979 0.07980214]]
==wall==
[True, True, True, 

snake 1
turn 40
[[0.21232708 0.47405068 0.34973535 0.55220001]]
==wall==
[True, True, True, True]
==forbidden==
[True, False, True, True]
==food==
[True, True, True, True]
[[0.21232708 0.         0.34973535 0.55220001]]
3
====
snake 3
turn 40
[[0.57828    0.36827441 0.66233828 0.28841643]]
==wall==
[True, True, True, True]
==forbidden==
[True, True, True, False]
==food==
[True, True, True, True]
[[0.57828    0.36827441 0.66233828 0.        ]]
2
====
snake 1
turn 41
[[0.54127576 0.35169074 0.94366274 0.08667236]]
==wall==
[True, True, True, True]
==forbidden==
[True, True, False, True]
==food==
[True, True, True, True]
[[0.54127576 0.35169074 0.         0.08667236]]
0
====
snake 3
turn 41
[[0.39582198 0.83836207 0.76647794 0.66418659]]
==wall==
[True, True, True, True]
==forbidden==
[True, True, True, False]
==food==
[True, True, True, True]
[[0.39582198 0.83836207 0.76647794 0.        ]]
1
====


# Playback the simulation

Defines the user interface of the simulator.

In [9]:
def get_env_json():
    """Return the JSON snapshot of the gym for the turn selected on the slider.

    An empty string is returned when the slider position has no matching
    entry in ``json_array``.
    """
    if slider.value >= len(json_array):
        return ""
    return json_array[slider.value]
    
def play_simulation(_):
    """Play the recording from the current slider position to the last frame.

    Used as the click handler of the Play button; the argument (the clicked
    widget) is ignored. Each step advances the slider by one, redraws the
    frame and pauses briefly.
    """
    # The number of remaining frames is counted from the current position to
    # the last index. The bound must not subtract slider.value again, or
    # playback started mid-way stops before the end.
    for _step in range(slider.value, len(rgb_arrays) - 1):
        slider.value = slider.value + 1
        display_image(slider.value)
        time.sleep(0.2)

def on_left_button_pressed(_):
    """Click handler of the '◄' button: step one frame back and redraw.

    The slider is left untouched when it already sits at the first frame;
    the current frame is redrawn in either case.
    """
    current = slider.value
    if current > 0:
        slider.value = current - 1
    display_image(slider.value)

def on_right_button_pressed(_):
    """Click handler of the '►' button: step one frame forward and redraw.

    The slider is left untouched when it already sits at the last recorded
    frame; the current frame is redrawn in either case.
    """
    # The last valid index is len(rgb_arrays) - 1, so stop advancing there
    # instead of asking the slider for a position past the end.
    if slider.value < len(rgb_arrays) - 1:
        slider.value = slider.value + 1
    display_image(slider.value)
        
def display_image(index):
    """Redraw the simulator UI for the recorded frame at ``index``.

    Clears the cell output, then shows the navigator widgets followed by the
    board image of that frame next to a table with one entry per snake:
    colour, health, action taken, remarks reported by the gym and remarks
    reported by the heuristics. Nothing is drawn when ``index`` is past the
    last recorded frame.

    Args:
        index: Position of the frame in ``rgb_arrays``; the same position is
            read from ``infos``, ``actions`` and ``heuristics_remarks``.
    """
    if index >= len(rgb_arrays):
        return
    info = infos[index]
    action = actions[index]
    # Named `remarks` so the module-level `heuristics` object is not shadowed.
    remarks = heuristics_remarks[index]
    snake_colours = env.snakes.get_snake_colours()

    # Each list becomes one column of the info table. The first entry of every
    # column belongs to the turn-count row, the second entry is the column
    # header, and one entry per snake is appended in the loop below.
    line_0 = [widgets.Label("Turn count"), widgets.Label("Snake")]
    line_1 = [widgets.Label(""), widgets.Label("Health")]
    line_2 = [widgets.Label("{}".format(info["current_turn"])),
              widgets.Label("Action")]
    line_3 = [widgets.Label(""), widgets.Label("Gym remarks")]
    line_4 = [widgets.Label(""), widgets.Label("Heur. remarks")]

    action_convertion_dict = {0: "Up", 1: "Down", 2: "Left", 3: "Right", 4: "None"}
    for snake_id in range(number_of_snakes):
        snake_health = "{}".format(info["snake_health"][snake_id])
        snake_health_widget = widgets.Label(snake_health)
        snake_action = "{}".format(action_convertion_dict[action[snake_id]])
        snake_action_widget = widgets.Label(snake_action)

        # Coloured dot identifying the snake on the board image.
        snake_colour = snake_colours[snake_id]
        hex_colour = '#%02x%02x%02x' % (snake_colour[0], snake_colour[1], snake_colour[2])
        snake_colour_widget = widgets.HTML(
            value=f'<b><font color="{hex_colour}">⬤</font></b>')

        # Only show gym remarks that report something other than "no collision".
        gym_remarks = ""
        if snake_id in info["snake_info"]:
            if info["snake_info"][snake_id] != "Did not colide":
                gym_remarks = "{}".format(info["snake_info"][snake_id])
        gym_remarks_widget = widgets.Label(gym_remarks)

        heuris_remarks = "{}".format(remarks[snake_id])
        heuris_remarks_widget = widgets.Label(heuris_remarks)

        line_0.append(snake_colour_widget)
        line_1.append(snake_health_widget)
        line_2.append(snake_action_widget)
        line_3.append(gym_remarks_widget)
        line_4.append(heuris_remarks_widget)

    info_widget = widgets.HBox(
        [widgets.VBox(column) for column in (line_0, line_1, line_2, line_3, line_4)])

    # Encode the frame as PNG bytes for the image widget.
    image = PIL.Image.fromarray(rgb_arrays[index])
    f = BytesIO()
    image.save(f, "png")

    states_widget = widgets.Image(value=f.getvalue(), width=500)
    main_widget = widgets.HBox([states_widget, info_widget])

    # Replace the previous frame instead of stacking outputs below it.
    display.clear_output(wait=True)
    i_display(navigator)
    i_display(main_widget)
    
# Navigation controls: step back, step forward, jump with the slider, auto-play.
left_button = widgets.Button(description='◄')
right_button = widgets.Button(description='►')
play_button = widgets.Button(description='Play')
slider = widgets.IntSlider(max=len(rgb_arrays) - 1)

# Hook every button up to its handler.
left_button.on_click(on_left_button_pressed)
right_button.on_click(on_right_button_pressed)
play_button.on_click(play_simulation)

navigator = widgets.HBox([left_button, right_button, slider, play_button])

# Draw the first frame.
display_image(index=0)

HBox(children=(Button(description='◄', style=ButtonStyle()), Button(description='►', style=ButtonStyle()), Int…

HBox(children=(Image(value=b"\x89PNG\r\n\x1a\n\x00\x00\x00\rIHDR\x00\x00\x02N\x00\x00\x02N\x08\x02\x00\x00\x00…

To get a JSON representation of the gym (environment), run the following function. You can also use the output of the following function as an initial_state of the gym.

*Please provide this json array if you are reporting bugs in the gym*

In [10]:
# Show the JSON state of the gym for the turn currently selected on the slider.
get_env_json()

{'turn': 0,
 'board': {'height': 11,
  'width': 11,
  'food': [{'x': 4, 'y': 10}],
  'snakes': [{'health': 100,
    'body': [{'x': 8, 'y': 3}],
    'id': 0,
    'name': 'Snake 0'},
   {'health': 100, 'body': [{'x': 10, 'y': 4}], 'id': 1, 'name': 'Snake 1'},
   {'health': 100, 'body': [{'x': 3, 'y': 8}], 'id': 2, 'name': 'Snake 2'},
   {'health': 100, 'body': [{'x': 6, 'y': 4}], 'id': 3, 'name': 'Snake 3'},
   {'health': 100, 'body': [{'x': 7, 'y': 4}], 'id': 4, 'name': 'Snake 4'}]}}

# Deploy the SageMaker endpoint
This section will deploy your new heuristics into the SageMaker endpoint

In [11]:
# Read the deployment stack outputs (bucket, role, endpoint name, ...) into `info`.
with open("../stack_outputs.json") as stack_outputs_file:
    info = json.load(stack_outputs_file)

In [12]:
# Deployment settings taken from the stack outputs.
s3_bucket = info["S3Bucket"]
role = info["SageMakerIamRoleArn"]
endpoint_instance_type = info["SagemakerInferenceInstanceType"]

sage_session = sagemaker.session.Session()
print(f"Your sagemaker s3_bucket is s3://{s3_bucket}")

Your sagemaker s3_bucket is s3://sagemaker-soln-bs-rbcsnake-bucket


In [13]:
# Prefer the model stored under pretrainedmodels/ in the solution bucket.
model_key = "pretrainedmodels/model.tar.gz"
model_data = "s3://{}/{}".format(s3_bucket, model_key)

# Check if model_data exists
s3_client = boto3.client('s3')

try:
    # HEAD only requests the object's metadata, so existence is checked without
    # opening a download of the model archive.
    s3_client.head_object(Bucket=s3_bucket, Key=model_key)
except botocore.exceptions.ClientError:
    # The object is not available: fall back to the location recorded in the
    # stack outputs.
    model_data = info["EndPointS3Location"]

## Deploy your new heuristics
Using the new heuristics you developed, a new SageMaker endpoint will be created.

Firstly, delete the old endpoint, model and endpoint config.

In [14]:
# The endpoint, its endpoint config and its model all use the same name.
endpoint_name = info['SagemakerEndPointName']
sm_client = boto3.client(service_name='sagemaker')

# Wait until the existing endpoint is in service before deleting it.
waiter = sm_client.get_waiter('endpoint_in_service')
waiter.wait(EndpointName=endpoint_name)

# Delete each resource independently (best effort): a failure on one of them
# must not leave the others behind, and is reported instead of being hidden.
endpoint_delete_requested = False
try:
    sm_client.delete_endpoint(EndpointName=endpoint_name)
    endpoint_delete_requested = True
except botocore.exceptions.ClientError as error:
    print("Could not delete endpoint {}: {}".format(endpoint_name, error))

try:
    sm_client.delete_endpoint_config(EndpointConfigName=endpoint_name)
except botocore.exceptions.ClientError as error:
    print("Could not delete endpoint config {}: {}".format(endpoint_name, error))

try:
    sm_client.delete_model(ModelName=endpoint_name)
except botocore.exceptions.ClientError as error:
    print("Could not delete model {}: {}".format(endpoint_name, error))

# Only wait for the deletion when it was actually requested; otherwise the
# waiter would poll an endpoint that is not going away.
if endpoint_delete_requested:
    ep_waiter = sm_client.get_waiter('endpoint_deleted')
    ep_waiter.wait(EndpointName=endpoint_name)

# Run the following cells to create a new model and endpoint with the new heuristics

# NOTE(review): the output recorded for this cell reports that
# sagemaker.tensorflow.serving.Model has been renamed in sagemaker>=2.
from sagemaker.tensorflow.serving import Model

# Package the model archive together with the inference code (which includes the
# heuristics) from inference/inference_src. The model is given the same name as
# the endpoint.
model = Model(model_data=model_data,
              role=role,
              entry_point="inference.py",
              source_dir='inference/inference_src',
              framework_version='2.1.0',
              name=info['SagemakerEndPointName'],
              code_location='s3://{}//code'.format(s3_bucket)  # NOTE(review): double slash before "code" -- confirm it is intended
             )

# Deploy an inference endpoint
predictor = model.deploy(initial_instance_count=1, instance_type=endpoint_instance_type, endpoint_name=info['SagemakerEndPointName'])

The class sagemaker.tensorflow.serving.Model has been renamed in sagemaker>=2.
See: https://sagemaker.readthedocs.io/en/stable/v2.html for details.
update_endpoint is a no-op in sagemaker>=2.
See: https://sagemaker.readthedocs.io/en/stable/v2.html for details.


-----!

## Testing the new endpoint
You should see `Action to take X`, where X is the index of the chosen action.

In [15]:
# Dummy observation and game state, only used to check that the endpoint responds.
# `np` and `time` are the modules imported in the first cell. They are deliberately
# not re-imported here: `from time import time` would replace the `time` module and
# break `time.sleep` in the playback code.
state = np.zeros(shape=(1, 21, 21, 6), dtype=np.float32).tolist()

health_dict = {0: 50, 1: 50}
# Named `game_state` rather than `json`, so the `json` module stays usable when
# earlier cells are re-run.
game_state = {"turn": 4,
              "board": {
                  "height": 11,
                  "width": 11,
                  "food": [],
                  "snakes": []
              },
              "you": {
                  "id": "snake-id-string",
                  "name": "Sneky Snek",
                  "health": 90,
                  "body": [{"x": 1, "y": 3}]
              }
              }

before = time.time()
action_mask = np.array([1, 1, 1, 1]).tolist()

action = predictor.predict({"state": state, "action_mask": action_mask,
                            "prev_action": -1,
                            "prev_reward": -1, "seq_lens": -1,
                            "all_health": health_dict, "json": game_state})
elapsed = time.time() - before

# NOTE(review): the key is spelled "heuristisc_action"; presumably it must match
# the key returned by the deployed inference code -- confirm before renaming.
action_to_take = action["outputs"]["heuristisc_action"]
print("Action to take {}".format(action_to_take))
print("Inference took %.2f ms" % (elapsed*1000))

Action to take 2
Inference took 3637.13 ms


# Navigation
- To go back to the introduction click [here](./1_Introduction.ipynb)
- To train a new model click [here](./2_PolicyTraining.ipynb)