In [None]:
# Install the gym's listed requirements, then the gym package itself in
# editable mode (-e) from the local battlesnake_gym/ directory.
!pip install -r battlesnake_gym/requirements.txt
!pip install -e battlesnake_gym/

In [None]:
from collections import namedtuple

import gym
from gym import wrappers
import numpy as np
import mxnet as mx
import matplotlib.pyplot as plt
%matplotlib inline
from IPython import display

from battlesnake_gym.battlesnake_gym.snake_gym import BattleSnakeGym
from battlesnake_src.networks.utils import sort_states_for_snake_id

# Define the OpenAI gym
The parameters here must match the ones provided during training.

In [None]:
# Board dimensions and number of competing snakes; both are reused by later
# cells, so they are kept as notebook-level names.
map_size = (15, 15)
number_of_snakes = 4

# Build the gym with one keyword argument per line so each setting is easy
# to compare against the configuration used during training.
env = BattleSnakeGym(
    map_size=map_size,
    number_of_snakes=number_of_snakes,
    observation_type="bordered-51s",
)

# Load the trained model

In [None]:
# Locations of the exported network: the parameter file and the symbol
# (graph) file. The previous `.format(map_size[0], map_size[1])` calls were
# no-ops because neither string contains a `{}` placeholder, so they are
# dropped; the resulting paths are unchanged.
params_name = "pretrained_models/Model/local-0000.params"
symbol_name = "pretrained_models/Model/local-symbol.json"

# Run on a GPU when MXNet reports at least one, otherwise fall back to CPU.
ctx = mx.gpu() if mx.context.num_gpus() > 0 else mx.cpu()

# Import the symbol + parameters as a Gluon block with four named inputs
# (the four arguments passed to `net(...)` later in the notebook).
net = mx.gluon.SymbolBlock.imports(symbol_name, ['data0', 'data1', 'data2', 'data3'],
                                   params_name, ctx=ctx)
net.hybridize(static_alloc=True, static_shape=True)

Net takes the following arguments:
`net(state, snake_id, turn_count, snake_health)`

`state`: *nd.array* of size (batch_size, sequence_length=2, c=3, map_size[0]+2, map_size[1]+2)
- Provides the observation space of the gym
- `batch_size` should be set to 1
- `sequence_length` is the number of timesteps the model considers. Given that `t` is the current time step, sequence index 0 refers to `t-1` and sequence index 1 refers to `t`.
- Each `c` slice refers to the *food*, *current snake*, and *other snakes* respectively.
- `map_size` is based on the size of the BattleSnakeGym +2 for the -1 border.

`snake_id`: *nd.array* of size (batch_size, sequence_length=2)
- Provides the id of the snake, an integer `i` in `[0, number_of_snakes-1]`

`turn_count`: *nd.array* of size (batch_size, sequence_length=2)
- Provides the number of turns that have elapsed (obtained from `info["current_turn"]` in the gym)

`snake_health`: *nd.array* of size (batch_size, sequence_length=2)
- Provides the health of the snake (obtained from `info["snake_health"]` in the gym)

# Visualisation loop

In [None]:
def get_action(state, snake_id, turn_count, health, memory):
    '''
    Assemble the two-step input sequence for `net` and return the network
    output for a single snake.

    Parameters:
    -----------
    `state`: np.array
    The current observation of the gym.
    NOTE(review): the transpose below moves the last axis to the channel
    position, which implies a channels-last layout of
    (map_size[0]+2, map_size[1]+2, 3) - confirm against the gym.

    `snake_id`: int
    Id of the snake to act for, in the range 0 .. number_of_snakes-1.

    `turn_count`: int
    Indicates the number of elapsed turns

    `health`: dict
    Indicates the health of all snakes in the form of {snake_id: health}

    `memory`: object with attributes `state`, `turn_count` and `health`
    Indicates the state, turn_count, and health of the previous turn.

    Returns:
    -----------
    `action`: np.array
    The network output for the first (and only) batch element. Per the
    original notes this is the expected q value of each action, i.e. the
    larger the value the better the action is.
    To get the next best action, perform np.argmax(action)
    '''
    sequence_length = 2

    # Use the `snake_id` argument here. The previous code read a global loop
    # variable `i`, which only worked when the caller's loop happened to use
    # that name and raised NameError (or used a stale id) otherwise.
    state_i = sort_states_for_snake_id(state, snake_id+1, one_versus_all=True)
    previous_state_i = sort_states_for_snake_id(memory.state, snake_id+1, one_versus_all=True)

    # Stack as [previous, current], move the last axis in front of the two
    # spatial axes, then add a leading batch axis of size 1.
    state_sequence = mx.nd.array(np.stack([previous_state_i, state_i]), ctx=ctx).transpose((0, 3, 1, 2)).expand_dims(0)

    # Scalar inputs are repeated / paired per time step and given a batch axis.
    snake_id_sequence = mx.nd.array([snake_id]*sequence_length, ctx=ctx).expand_dims(0)
    turn_count_sequence = mx.nd.array([memory.turn_count, turn_count], ctx=ctx).expand_dims(0)
    snake_health_sequence = mx.nd.array([memory.health[snake_id], health[snake_id]], ctx=ctx).expand_dims(0)

    action = net(state_sequence, snake_id_sequence, turn_count_sequence, snake_health_sequence)
    # Drop the batch axis before handing the result back as a numpy array.
    return action.asnumpy()[0]

In [None]:
from battlesnake_inference.battlesnake_heuristics import MyBattlesnakeHeuristics

# Snapshot of the previous turn that get_action() pairs with the current one.
Memory = namedtuple("Memory", "state turn_count health")
heuristics = MyBattlesnakeHeuristics()

# This gym's reset() returns a 4-tuple; only the observation and info are used.
state, _, _, info = env.reset()
# There is no previous frame on the first turn, so start from an all-zero state.
memory = Memory(state=np.zeros(state.shape), turn_count=info["current_turn"], health=info["snake_health"])
while True:
    # The turn counter is advanced by hand on the `info` dict from reset().
    # NOTE(review): turn_count below is passed as current_turn + 1 *after* this
    # increment, while memory stores current_turn itself - confirm this offset
    # matches what the model saw during training.
    info["current_turn"] += 1
    actions = []
    for i in range(number_of_snakes):
        action = get_action(state, snake_id=i,
                            turn_count=info["current_turn"]+1,
                            health=info["snake_health"],
                            memory=memory)
        # Add heuristics
        action = heuristics.run(state, snake_id=i,
                                turn_count=info["current_turn"]+1,
                                health=info["snake_health"],
                                action=action)
        actions.append(action)

    # Remember what the snakes saw this turn before the environment advances.
    memory = Memory(state=state, turn_count=info["current_turn"],
                    health=info["snake_health"])
    next_state, reward, dones, infos = env.step(np.array(actions))

    # Fix: carry the fresh health readings into the next iteration. Previously
    # the step result was only bound to `infos`, so `info["snake_health"]`
    # kept the values from reset() and the network/heuristics never saw the
    # snakes' health change.
    info["snake_health"] = infos["snake_health"]

    # Check if only 1 snake remains.
    # `dones` is summed as booleans; presumably True marks a snake that is
    # out of the game - verify against the gym.
    number_of_snakes_dead = sum(dones.values())
    done = number_of_snakes - number_of_snakes_dead <= 1

    state = next_state
    if done:
        print("Completed")
        break

    # Display the results
    rgb_array = env.render(mode="rgb_array")
    plt.clf()

    health_str = "".join(["s{}={} ".format(k, infos["snake_health"][k]) for k in infos["snake_health"]])
    plt.title("{}, health {}".format(info["current_turn"],
                                     health_str))
    plt.imshow(rgb_array)
    plt.axis('off')

    # Replace the previous frame in the cell output instead of appending to it.
    display.clear_output(wait=True)
    display.display(plt.gcf())

# Deploy your new heuristics

In [None]:
import sagemaker
from sagemaker.mxnet import MXNetModel

sage_session = sagemaker.session.Session()
s3_bucket = sage_session.default_bucket()

# `role` and `endpoint_instance_type` are used below but are not defined in
# any visible cell, so a fresh "Restart & Run All" would stop with NameError.
# Only fill them in when they are missing so an earlier definition still wins.
if "role" not in globals():
    role = sagemaker.get_execution_role()
if "endpoint_instance_type" not in globals():
    # NOTE(review): fallback instance type chosen for this fix - adjust to
    # the instance type / budget you actually want for the endpoint.
    endpoint_instance_type = "ml.m5.xlarge"

# S3 location of the packaged model inside the session's default bucket;
# the key encodes the map size configured earlier in the notebook.
target_key = "battlesnake-pretrainedmodels/Model-{}x{}/Model.tar.gz".format(map_size[0], map_size[1])

model_data = "s3://{}/{}".format(s3_bucket, target_key)

# The inference code (entry point predict.py) is taken from the local
# battlesnake_inference directory.
mxnet_model = MXNetModel(model_data=model_data,
                         entry_point='predict.py',
                         role=role,
                         framework_version='1.6.0',
                         source_dir='battlesnake_inference',
                         name="battlesnake-mxnet",
                         py_version='py3')
predictor = mxnet_model.deploy(initial_instance_count=1,
                               instance_type=endpoint_instance_type,
                               update_endpoint=True,
                               endpoint_name='battlesnake-endpoint')