Import the libraries used in this notebook: `boto3`, the SageMaker Python SDK (including the MXNet estimator and the hyperparameter tuning classes) and NumPy.

In [None]:
import boto3
import sagemaker
import numpy as np

from sagemaker.tuner import IntegerParameter, CategoricalParameter, ContinuousParameter, HyperparameterTuner
from sagemaker.mxnet.estimator import MXNet

Set the parameters of the notebook.
With `run_hpo = False`, the network is trained only once, from within this notebook.
With `run_hpo = True`, hyperparameter optimization (HPO) runs instead: the training is repeated with different hyperparameters to find the best combination (see Hyperparameter tuning jobs in SageMaker).
`map_size` sets the dimensions of the Battlesnake board.

In [None]:
run_hpo = False
map_size = (15, 15)

## Initialise SageMaker
Before running the training job, we set up the SageMaker session, the S3 output path, the instance types and the IAM role.

In [None]:
sage_session = sagemaker.session.Session()
s3_bucket = sage_session.default_bucket()
s3_output_path = 's3://{}/'.format(s3_bucket)
print("S3 bucket path: {}".format(s3_output_path))

# Run in local mode if no HPO is required
local_mode = not run_hpo
    
if local_mode:
    train_instance_type = 'local'
else:
    train_instance_type = "SAGEMAKER_TRAINING_INSTANCE_TYPE"
endpoint_instance_type = "SAGEMAKER_INFERENCE_INSTANCE_TYPE"
    
role = sagemaker.get_execution_role()
print("Using IAM role arn: {}".format(role))

## Define the attributes of the training job
Use `job_name_prefix` to identify the SageMaker training and tuning jobs started from this notebook.

In [None]:
job_name_prefix = 'Battlesnake-job-mxnet'

## Define the metrics to evaluate your training job
The regex for this metric is based on what the training script `examples/train.py` prints to its log.

In [None]:
metric_definitions = [
    {'Name': 'timesteps', 'Regex': '.*Mean timesteps ([-+]?[0-9]*[.]?[0-9]+([eE][-+]?[0-9]+)?)'},
]
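To check what the regex captures before starting a job, it can be run locally against a few log lines. The sketch below reuses the pattern from `metric_definitions`; the sample log lines are made up for illustration, and the real format is whatever the training script prints.

```python
import re

# Same pattern as in `metric_definitions`.
TIMESTEPS_REGEX = r'.*Mean timesteps ([-+]?[0-9]*[.]?[0-9]+([eE][-+]?[0-9]+)?)'

# Hypothetical log lines, not taken from the real training script.
sample_lines = [
    "Episode 10: Mean timesteps 12.5",
    "Episode 20: Mean timesteps 3e+01",
    "Episode 30: epsilon 0.5",
]


def extract_timesteps(line):
    """Return the first captured group as a float, or None if the line does not match."""
    match = re.match(TIMESTEPS_REGEX, line)
    return float(match.group(1)) if match else None


values = [extract_timesteps(line) for line in sample_lines]
print(values)
```

Lines that do not contain `Mean timesteps` followed by a number yield `None` here, so they would not contribute a metric value.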

## Define the hyperparameters of your job

In [None]:
map_size_string = "[{}, {}]".format(map_size[0], map_size[1])
static_hyperparameters = {
    'qnetwork_type': "attention",
    'seed': 111,
    'number_of_snakes': 4,
    'episodes': 10000,
    'print_score_steps': 10,
    'activation_type': "softrelu",
    'state_type': 'one_versus_all',
    'sequence_length': 2,
    'repeat_size': 3,
    'kernel_size': 3,
    'starting_channels': 6,
    'map_size': map_size_string,
    'snake_representation': 'bordered-51s',
    'save_model_every': 700,
    'eps_start': 0.99,
    'models_to_save': 'local'
}
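Note that `map_size` is passed as the string `"[15, 15]"` rather than as a tuple. A minimal sketch of how an entry script could turn it back into a tuple is shown below, assuming the hyperparameters arrive as command-line arguments. The helper `parse_map_size` is hypothetical and is not taken from the actual `train.py`.

```python
import argparse
import json


def parse_map_size(value):
    """Parse a string such as "[15, 15]" into a tuple of two ints."""
    parsed = json.loads(value)
    if len(parsed) != 2:
        raise argparse.ArgumentTypeError("map_size needs exactly two entries")
    return int(parsed[0]), int(parsed[1])


parser = argparse.ArgumentParser()
parser.add_argument("--map_size", type=parse_map_size, default=(15, 15))

# Simulate the arguments that the entry script might receive.
args = parser.parse_args(["--map_size", "[15, 15]"])
print(args.map_size)
```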

# Train your model here
The cell below defines the estimator.
If `run_hpo == False`, the training job runs immediately in local mode. Please note that this will take a couple of hours.

In [None]:
estimator = MXNet(entry_point="train.py",
                  source_dir='training/training_src',
                  dependencies=["../BattlesnakeGym/"],
                  role=role,
                  instance_type=train_instance_type,
                  instance_count=1,
                  output_path=s3_output_path,
                  framework_version="1.6.0",
                  py_version='py3',
                  base_job_name=job_name_prefix,
                  metric_definitions=metric_definitions,
                  hyperparameters=static_hyperparameters
                 )
if local_mode:
    estimator.fit()

# Running hyperparameter optimisation
Start the SageMaker hyperparameter optimisation (HPO) jobs!
HPO runs only if `run_hpo == True`. You can view the training progress in SageMaker > Training > Hyperparameter tuning jobs. Please note that this launches several training instances and could be costly.

In [None]:
hyperparameter_ranges = {
    'buffer_size': IntegerParameter(1000, 6000),
    'update_every': IntegerParameter(10, 20),
    'batch_size': IntegerParameter(16, 256),

    'lr_start': ContinuousParameter(1e-5, 1e-3),
    'lr_factor': ContinuousParameter(0.5, 1.0),
    'lr_step': IntegerParameter(5000, 30000),
    
    'tau': ContinuousParameter(1e-4, 1e-3),
    'gamma': ContinuousParameter(0.85, 0.99),
    
    'depth': IntegerParameter(10, 256),
    'depthS': IntegerParameter(10, 256),
}
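The range for `lr_start` spans two orders of magnitude. The sketch below uses only the standard library (not SageMaker) to illustrate why the sampling scale matters for such a range: with uniform sampling on a linear scale, few draws land in the lowest decade. The SageMaker SDK's `ContinuousParameter` accepts an optional `scaling_type` argument for this purpose; check the documentation of your SDK version before relying on it.

```python
import math
import random

LR_MIN, LR_MAX = 1e-5, 1e-3  # same bounds as `lr_start`


def sample_linear(rng):
    """Draw uniformly on the linear scale."""
    return rng.uniform(LR_MIN, LR_MAX)


def sample_log(rng):
    """Draw the exponent uniformly, then map back to the original scale."""
    return 10 ** rng.uniform(math.log10(LR_MIN), math.log10(LR_MAX))


rng = random.Random(0)
n = 10000
linear_small = sum(sample_linear(rng) < 1e-4 for _ in range(n)) / n
log_small = sum(sample_log(rng) < 1e-4 for _ in range(n)) / n
print("Fraction below 1e-4: linear={:.2f}, log={:.2f}".format(linear_small, log_small))
```

On the linear scale roughly a tenth of the draws fall below `1e-4`, whereas on the logarithmic scale about half do.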

In [None]:
max_jobs = 3
max_parallel_jobs = 3

tuner = HyperparameterTuner(estimator,
                            objective_metric_name='timesteps',
                            objective_type='Maximize',
                            hyperparameter_ranges=hyperparameter_ranges,
                            metric_definitions=metric_definitions,
                            max_jobs=max_jobs,
                            max_parallel_jobs=max_parallel_jobs,
                            base_tuning_job_name=job_name_prefix)
if run_hpo:
    tuner.fit()

Now wait for the hyperparameter tuner to complete. If you are running HPO, please check SageMaker > Training > Hyperparameter tuning jobs for the progress.
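Instead of checking the console, the status can also be polled from the notebook. The helper below is a minimal sketch: `wait_for_tuning_job` is a hypothetical name, and it expects a boto3 SageMaker client to be passed in. It relies on the `describe_hyper_parameter_tuning_job` call and its `HyperParameterTuningJobStatus` field.

```python
import time

# States after which the tuning job no longer changes.
TERMINAL_STATES = {"Completed", "Failed", "Stopped"}


def wait_for_tuning_job(sm_client, tuning_job_name, poll_seconds=60, sleep=time.sleep):
    """Poll the tuning job until it reaches a terminal state and return that state."""
    while True:
        response = sm_client.describe_hyper_parameter_tuning_job(
            HyperParameterTuningJobName=tuning_job_name)
        status = response["HyperParameterTuningJobStatus"]
        if status in TERMINAL_STATES:
            return status
        sleep(poll_seconds)
```

With HPO enabled, this could be called as `wait_for_tuning_job(boto3.client("sagemaker"), tuner.latest_tuning_job.name)`. The SDK's `tuner.wait()` is a simpler alternative if you only need to block until the job ends.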

# Updating your SageMaker endpoint

## Collect the target model

Once you have retrained your models, we copy the model artifacts into your SageMaker notebook and package them for a SageMaker endpoint.

First, we obtain the S3 URL of the best model.

In [None]:
if run_hpo:
    best_training_job = tuner.best_training_job()
    # s3_output_path already ends with "/", so no extra separator is needed
    best_model_path = "{}{}/output/model.tar.gz".format(s3_output_path, best_training_job)
else:
    best_model_path = estimator.model_data
model_path_key = best_model_path.replace(s3_output_path, "")
print("Best model location {}".format(best_model_path))
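The download step needs the object key rather than the full URL. As an alternative to string replacement, an S3 URL can be split into bucket and key with the standard library. This is a small sketch; `split_s3_url` is a hypothetical helper and the URL in the example is made up.

```python
from urllib.parse import urlparse


def split_s3_url(s3_url):
    """Split "s3://bucket/some/key" into ("bucket", "some/key")."""
    parsed = urlparse(s3_url)
    if parsed.scheme != "s3":
        raise ValueError("Not an S3 URL: {}".format(s3_url))
    return parsed.netloc, parsed.path.lstrip("/")


# Hypothetical URL for illustration only.
bucket, key = split_s3_url("s3://example-bucket/Battlesnake-job-mxnet-001/output/model.tar.gz")
print(bucket, key)
```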

Download the best model and put it into `inference/pretrained_models/`.

Note that the new model overwrites the existing model for the same map size; keep the models under version control if you want to retain earlier versions.

In [None]:
s3 = boto3.resource('s3')
s3.Bucket(s3_bucket).download_file(model_path_key, 'inference/pretrained_models/model.tar.gz')

model_dir = "Model-{}x{}".format(map_size[0], map_size[1])
# Replace any existing model directory for this map size
!rm -rf inference/pretrained_models/{model_dir}

!mkdir -p inference/pretrained_models/{model_dir}
!tar -xf inference/pretrained_models/model.tar.gz -C inference/pretrained_models/{model_dir}
!rm inference/pretrained_models/model.tar.gz
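The same replace-and-extract steps can be written in plain Python, which avoids shell errors when the target directory does not exist yet. The sketch below uses `tarfile` and `shutil`; `extract_model` is a hypothetical helper, and the demonstration runs in a temporary directory with a dummy file rather than a real model. Only extract archives that you trust.

```python
import os
import shutil
import tarfile
import tempfile


def extract_model(archive_path, target_dir):
    """Replace target_dir with the contents of the archive, then delete the archive."""
    shutil.rmtree(target_dir, ignore_errors=True)  # like `rm -rf`
    os.makedirs(target_dir)                        # like `mkdir -p`
    with tarfile.open(archive_path) as archive:
        archive.extractall(target_dir)
    os.remove(archive_path)


# Self-contained demonstration in a temporary directory.
with tempfile.TemporaryDirectory() as workdir:
    weights = os.path.join(workdir, "weights.params")  # hypothetical file name
    with open(weights, "w") as f:
        f.write("dummy")
    archive_path = os.path.join(workdir, "model.tar.gz")
    with tarfile.open(archive_path, "w:gz") as archive:
        archive.add(weights, arcname="weights.params")
    target = os.path.join(workdir, "Model-15x15")
    extract_model(archive_path, target)
    extracted = sorted(os.listdir(target))
    archive_removed = not os.path.exists(archive_path)

print(extracted, archive_removed)
```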

Package `pretrained_models` for the endpoint

In [None]:
!mv inference/pretrained_models Models
!tar -czf Models.tar.gz Models
!mv Models inference/pretrained_models

s3_client = boto3.client('s3')
s3_client.upload_file("Models.tar.gz", s3_bucket, 
                      "battlesnake-aws/pretrainedmodels/Models.tar.gz")
!rm Models.tar.gz

## Update the SageMaker endpoint with your new model

In [None]:
model_data = "s3://{}/battlesnake-aws/pretrainedmodels/Models.tar.gz".format(s3_bucket)
print("Make an endpoint with {}".format(model_data))

Delete the existing endpoint, endpoint configuration and model

In [None]:
sm_client = boto3.client(service_name='sagemaker')
sm_client.delete_endpoint(EndpointName='battlesnake-endpoint')
sm_client.delete_endpoint_config(EndpointConfigName='battlesnake-endpoint')
sm_client.delete_model(ModelName="battlesnake-mxnet")
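If one of these resources does not exist (for example on a first run), the corresponding delete call raises an error and the cell stops. A more tolerant variant is sketched below. `delete_endpoint_resources` is a hypothetical helper, and it makes the simplifying assumption that any `ClientError` means the resource is missing, which also hides other failures such as missing permissions.

```python
def delete_endpoint_resources(sm_client, endpoint_name, config_name, model_name):
    """Delete endpoint, endpoint config and model; return the labels that were skipped."""
    steps = [
        ("endpoint", sm_client.delete_endpoint, {"EndpointName": endpoint_name}),
        ("endpoint config", sm_client.delete_endpoint_config,
         {"EndpointConfigName": config_name}),
        ("model", sm_client.delete_model, {"ModelName": model_name}),
    ]
    skipped = []
    for label, delete_call, kwargs in steps:
        try:
            delete_call(**kwargs)
        except sm_client.exceptions.ClientError as error:
            # Assumption: the error means the resource does not exist.
            print("Skipping {}: {}".format(label, error))
            skipped.append(label)
    return skipped
```

With the names used in this notebook, the call would be `delete_endpoint_resources(sm_client, 'battlesnake-endpoint', 'battlesnake-endpoint', 'battlesnake-mxnet')`.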

Create a new endpoint with the new model

In [None]:
from sagemaker.mxnet import MXNetModel
mxnet_model = MXNetModel(model_data=model_data,
                         entry_point='predict.py',
                         role=role,
                         framework_version='1.6.0',
                         source_dir='inference/inference_src',
                         name="battlesnake-mxnet",
                         py_version='py3')
predictor = mxnet_model.deploy(initial_instance_count=1,
                               instance_type=endpoint_instance_type,
                               endpoint_name='battlesnake-endpoint')

## Test that your endpoint works
You should see `Action to take is X`, where `X` is the action returned by the endpoint.

In [None]:
data1 = np.zeros(shape=(1, 2, 3, map_size[0]+2, map_size[1]+2))
data2 = np.zeros(shape=(1, 2))
data3 = np.zeros(shape=(1, 2))
data4 = np.zeros(shape=(1, 2))
health_dict = {0: 50, 1: 50}
json = {"board": {
            "height": 15,
            "width": 15,
            "food": [],
            "snakes": []
            },
        "you": {
            "id": "snake-id-string",
            "name": "Sneky Snek",
            "health": 90,
            "body": [{"x": 1, "y": 3}]
            }
        }
action = predictor.predict({"state": data1, "snake_id": data2, 
                           "turn_count": data3, "health": data4,  
                           "all_health": health_dict, "map_width": map_size[0], "json": json})
print("Action to take is {}".format(action))