## Tic Tac Toe + Reinforcement Learning with SageMaker

**SageMaker Studio Kernel**: Data Science

In this example you'll create a simple game in which an agent plays Tic Tac Toe against a human player. The NPC (Non-Player Character), or Bot, is a trained **TensorFlow** model. The model is trained with Reinforcement Learning (OpenAI Gym + RLlib) on **Amazon SageMaker**.

After training, you'll optimize the model and convert it to a lightweight format using **Amazon SageMaker Neo**. The game then loads the model with the DLR runtime and uses its predictions as the Bot's actions.

These are the activities for this example:
- Prepare an OpenAI Gym custom environment that represents the board and the rules of the game
- Prepare a heuristic (rule-based engine) that plays against the agent so it can learn
- Train the model using SageMaker
- Load the model and play against the Bot using an IPython (widgets) based application

### 1/4) First we need a new OpenAI Gym environment that represents the board and the game rules

This is a multi-agent experiment, so we need an Env that supports two players at once. 
To make this work, besides the 9 possible positions on the board, we need an additional action that represents a player waiting for its turn. 
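
The turn-taking idea can be sketched without Gym. The block below is only an illustration under stated assumptions, not the contents of `tictactoe/tictactoe.py`: the names `WAIT`, `legal_actions`, `winner` and `step` are hypothetical, and the board is a plain list of 9 cells. It shows why a tenth action is useful: both agents submit an action on every step, and the idle agent's only legal action is to wait.

```python
# Hypothetical sketch: actions 0-8 are board cells, action 9 means "wait".
WAIT = 9
WIN_LINES = [
    (0, 1, 2), (3, 4, 5), (6, 7, 8),  # rows
    (0, 3, 6), (1, 4, 7), (2, 5, 8),  # columns
    (0, 4, 8), (2, 4, 6),             # diagonals
]

def legal_actions(board, agent, turn):
    """Return the actions `agent` may take when it is `turn`'s move."""
    if agent != turn:
        return [WAIT]  # the idle player can only wait
    return [i for i, cell in enumerate(board) if cell == " "]

def winner(board):
    """Return 'X' or 'O' if a line is complete, else None."""
    for a, b, c in WIN_LINES:
        if board[a] != " " and board[a] == board[b] == board[c]:
            return board[a]
    return None

def step(board, turn, actions):
    """Apply one joint step; `actions` maps each agent to its chosen action."""
    for agent, action in actions.items():
        if action not in legal_actions(board, agent, turn):
            raise ValueError(f"illegal action {action} for agent {agent}")
    board = list(board)
    board[actions[turn]] = turn  # only the active player marks a cell
    next_turn = "O" if turn == "X" else "X"
    return board, next_turn

board, turn = [" "] * 9, "X"
board, turn = step(board, turn, {"X": 4, "O": WAIT})
board, turn = step(board, turn, {"X": WAIT, "O": 0})
print(board, turn)
```

In a real multi-agent Env the same rule would typically be enforced inside `step`, with a penalty or an error when an agent acts out of turn.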

Run the next cell to view the env definition in Python.

In [None]:
!pygmentize tictactoe/tictactoe.py

### 2/4) Heuristics/policy that plays against the agent to make it improve

In order to learn how to play and improve its skills, the agent needs to play against a good rival.
For that purpose, we can make use of a Policy/Heuristics where we encode some rules and conditions that will simulate a human player.  

The policy has a set of rules that are adjusted stochastically over time to behave like a harder or an easier player.
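
As a rough illustration of a rule set whose strength is adjusted stochastically, the sketch below applies each rule only with probability `skill`. It is an assumption-laden toy, not the `SemiSmartTicTacToeHeuristicsPolicy` from `tictactoe/heuristics.py`: the function names, the two rules (win, then block) and the `skill` parameter are all hypothetical.

```python
import random

WIN_LINES = [
    (0, 1, 2), (3, 4, 5), (6, 7, 8),
    (0, 3, 6), (1, 4, 7), (2, 5, 8),
    (0, 4, 8), (2, 4, 6),
]

def find_completing_move(board, mark):
    """Return a cell that completes a line for `mark`, or None."""
    for line in WIN_LINES:
        marks = [board[i] for i in line]
        if marks.count(mark) == 2 and marks.count(" ") == 1:
            return line[marks.index(" ")]
    return None

def heuristic_move(board, me, opponent, skill, rng):
    """Pick a move; `skill` in [0, 1] is the chance of applying each rule."""
    empty = [i for i, cell in enumerate(board) if cell == " "]
    # Rule 1: complete our own line. Rule 2: block the opponent's line.
    for mark in (me, opponent):
        move = find_completing_move(board, mark)
        if move is not None and rng.random() < skill:
            return move
    # Otherwise (or when a rule is skipped) play a random empty cell.
    return rng.choice(empty)

rng = random.Random(0)
board = ["X", "X", " ", "O", " ", " ", " ", " ", " "]
print(heuristic_move(board, me="O", opponent="X", skill=1.0, rng=rng))
```

With `skill=1.0` the opponent always wins or blocks when it can; lowering `skill` makes it miss those moves more often, which gives the learning agent an easier rival.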

Run the next cell to visualize the policy.

In [None]:
!pygmentize tictactoe/heuristics.py

### 3/4) Training the model
Now it is time to train our agent using the SageMaker Reinforcement Learning Estimator.  

SageMaker expects you to pass a Python script to the estimator to run the training. The following script defines the whole training process using Ray + RLlib + TensorFlow 2 + OpenAI Gym.
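
The estimator's hyperparameters reach the script as command-line arguments, which is why the script parses them with `argparse`. The sketch below simulates such a command line locally. The helper `str2bool` is a hypothetical name introduced here; it is included because `argparse` with `type=bool` treats any non-empty string, including `"False"`, as `True`.

```python
import argparse

def str2bool(value):
    """Parse 'True'/'False'-style strings; bool('False') would be True."""
    if isinstance(value, bool):
        return value
    if value.lower() in ("true", "1", "yes"):
        return True
    if value.lower() in ("false", "0", "no"):
        return False
    raise argparse.ArgumentTypeError(f"expected a boolean, got {value!r}")

parser = argparse.ArgumentParser()
parser.add_argument("--record-videos", type=str2bool, default=False)
parser.add_argument("--learning-rate", type=float, default=0.001)
parser.add_argument("--refining-iter", type=int, default=4)

# Simulated command line; every value arrives as a string.
argv = ["--record-videos", "False", "--learning-rate", "0.0001", "--unknown-flag", "x"]

# parse_known_args ignores arguments the script does not declare.
args, unknown = parser.parse_known_args(argv)
print(args.record_videos, args.learning_rate, args.refining_iter, unknown)
```

Using `parse_known_args` rather than `parse_args` keeps the script from failing when extra arguments that it does not declare are passed along.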

In [None]:
%%writefile tictactoe/train.py
import sys
import subprocess
# install the SageMaker training toolkit and Ray/RLlib before importing them
subprocess.check_call([sys.executable, "-m", "pip", "install", "-U", "sagemaker-training==3.9.2", "ray[rllib]==1.2.0"])

import copy
import os
import argparse
import traceback
import random
import time
import numpy as np
import glob
import re

import ray
import ray.tune
import ray.rllib as rllib
from ray.rllib.agents.registry import get_agent_class

import gym
from gym import error, spaces, utils
from gym.utils import seeding
from gym.envs.registration import register

from shutil import copyfile

from sagemaker_training import environment, intermediate_output, logging_config, params, files
from tensorflow.python.tools import freeze_graph
from tensorflow.python.saved_model import tag_constants

from heuristics import SemiSmartTicTacToeHeuristicsPolicy

def freeze_model(saved_model_dir, output_node_names, output_filename):
    output_graph_filename = os.path.join(saved_model_dir, output_filename)
    initializer_nodes = ''
    freeze_graph.freeze_graph(
        input_saved_model_dir=saved_model_dir,
        output_graph=output_graph_filename,
        saved_model_tags = tag_constants.SERVING,
        output_node_names=output_node_names,
        initializer_nodes=initializer_nodes,
        input_graph=None,
        input_saver=False,
        input_binary=False,
        input_checkpoint=None,
        restore_op_name=None,
        filename_tensor_name=None,
        clear_devices=True,
        input_meta_graph=False,
    )

def start_file_sync(env):
    """Uploads the checkpoints to S3 in background"""
    global logger, intermediate_sync
    ## this service will copy all the files, stored in the intermediate dir, to S3
    region = os.environ.get("AWS_REGION", os.environ.get(params.REGION_NAME_ENV))
    s3_endpoint_url = os.environ.get(params.S3_ENDPOINT_URL, None)

    logger.info("Starting intermediate sync. %s: %s - %s" % (region, env.sagemaker_s3_output(), s3_endpoint_url))
    intermediate_sync = intermediate_output.start_sync(
        env.sagemaker_s3_output(), region, endpoint_url=s3_endpoint_url
    )
    
def get_latest_checkpoint(env, algo):
    """Scan the intermediate dir and get the latest checkpoint"""
    global logger
    logger.info("Latest checkpoint")
    # get the latest experiment
    experiments = glob.glob(os.path.join(env.output_intermediate_dir,'training', f'{algo}*'))
    experiments.sort(key=lambda x: [int(c) if c.isdigit() else c for c in ''.join(x.replace('-','').split('_')[-2:])])

    if len(experiments) > 0:
        exp_name = experiments[-1]

        chkpts = [c for c in glob.glob(f'{exp_name}/checkpoint*')]
        # natural sort: digit runs compare as integers (raw string avoids an invalid-escape warning)
        chkpts.sort(key=lambda x: [int(c) if c.isdigit() else c for c in re.split(r'(\d+)', x)])

        if len(chkpts) == 0: raise Exception("No checkpoint found!")
        ckpt_path=chkpts[-1]
        ckpt_meta_filename=ckpt_path.split('/')[-1].split('_')
        ckpt_meta_filename=f'{ckpt_meta_filename[0]}-{int(ckpt_meta_filename[1])}'
        logger.info(f'{ckpt_path}/{ckpt_meta_filename}')
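        return ckpt_path, ckpt_meta_filename
    # fail loudly instead of returning None, which the callers would try to unpack
    raise Exception("No experiment found!")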

def save_model(env_vars, experiment_params):
    """Load the latest checkpoint, export it as a TF 1.15 SavedModel and freeze it into frozen.pb"""
    global logger
    config = copy.deepcopy(experiment_params)['training']['config']

    config["monitor"] = False
    config["num_workers"] = 1
    config["num_gpus"] = 0
    logger.info(experiment_params)
    algo = experiment_params['training']['run']
    env_name = experiment_params['training']['env']
    logger.info(f'{algo} - {env_name}')
    cls = get_agent_class(algo)        
    agent = cls(env=env_name, config=config)
    
    ckpt_path, ckpt_meta_filename = get_latest_checkpoint(env_vars, algo)
    
    logger.info('Restoring agent...')
    agent.restore(os.path.join(ckpt_path, ckpt_meta_filename))
    logger.info(f'Exporting model to {env_vars.model_dir}...')
    
    agent.export_policy_model(os.path.join(env_vars.output_intermediate_dir,'saved_model'), 'agent_x')
    freeze_model(os.path.join(env_vars.output_intermediate_dir,'saved_model'), 'agent_x/fc_out/BiasAdd', 'frozen.pb')
    copyfile(os.path.join(env_vars.output_intermediate_dir,'saved_model', 'frozen.pb'), os.path.join(env_vars.model_dir, 'frozen.pb'))
    
if __name__ == "__main__":
    
    env_vars = environment.Environment()
    parser = argparse.ArgumentParser()
    logging_config.configure_logger(env_vars.log_level)
    
    parser.add_argument("--log-level", type=int, default=0)
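    # hyperparameters arrive as strings and bool("False") is True, so parse the text explicitly
    parser.add_argument("--record-videos", type=lambda v: str(v).lower() in ("true", "1"), default=False)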
    parser.add_argument("--num-workers", type=int, default=max(env_vars.num_cpus-1, 3))
    parser.add_argument("--num-gpus", type=int, default=env_vars.num_gpus)
    parser.add_argument("--batch-mode", type=str, default="complete_episodes")
    parser.add_argument("--episode-reward-mean", type=float, default=3.5)
    parser.add_argument("--learning-rate", type=float, default=0.001)
    parser.add_argument("--init-seed", type=int, default=-1)
    parser.add_argument("--refining-iter", type=int, default=4)
    args,unknown = parser.parse_known_args()

    seed=args.init_seed if args.init_seed != -1 else None
    
    random.seed(seed)
    np.random.seed(seed)
    
    logger = logging_config.get_logger()
    intermediate_sync = None

    env_name='TicTacToeEnv-v0'
    register(
        id=env_name,
        entry_point='tictactoe:TicTacToeEnv'
    )
    env = gym.make(env_name)
    env.seed(seed)
    
    experiment_params = {
        "training": {
            "env": env_name,
            "run": "A3C",
            "stop": {
                "episode_reward_mean": args.episode_reward_mean,
            },
            "local_dir": env_vars.output_intermediate_dir,
            "checkpoint_at_end": True,
            "checkpoint_freq": 60,
            #"export_formats": ["h5"],
            "config": {            
                "log_level": args.log_level,
                "monitor": args.record_videos,
                #"framework": "tfe",
                "lr": args.learning_rate,
                "model": {
                    # https://docs.ray.io/en/master/rllib-models.html#default-model-config-settings
                },
                "multiagent": {
                    "policies": {
                        "agent_x": (None, env.observation_space, env.action_space, {}),
                        "agent_o": (SemiSmartTicTacToeHeuristicsPolicy, env.observation_space, env.action_space, {})
                    },
                    "policy_mapping_fn": lambda x: x,
                    "policies_to_train": ["agent_x"],                
                },            
                "num_workers": args.num_workers,
                "num_gpus": args.num_gpus,
                "batch_mode": args.batch_mode,
                "seed": seed
            }
        }
    }

    try:
        start_file_sync(env_vars)
        # main program
        ray.init()
        ray.tune.register_env(env_name, lambda x: env)
        ray.tune.run_experiments(copy.deepcopy(experiment_params))
        for i in range(args.refining_iter):
            seed = int(time.time())
            random.seed(seed)
            np.random.seed(seed)    
            env.seed(seed)
            algo = experiment_params['training']['run']
            ckpt_path, ckpt_meta_filename = get_latest_checkpoint(env_vars, algo)
            experiment_params['training']['config']['seed'] = seed
            experiment_params['training']['restore'] = os.path.join(ckpt_path, ckpt_meta_filename)
            ray.tune.run_experiments(copy.deepcopy(experiment_params))
        save_model(env_vars, experiment_params)
        ray.shutdown()
        
        files.write_success_file()
        logger.info("Reporting training SUCCESS")
    except Exception as e:
        failure_msg = "framework error: \n%s\n%s" % (traceback.format_exc(), str(e))
        logger.error("Reporting training FAILURE: %s" % failure_msg)
        files.write_failure_file(failure_msg)
    finally:
        if intermediate_sync:
            intermediate_sync.join()
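
The script picks the latest checkpoint with a natural sort, so that `checkpoint_120` ranks after `checkpoint_60`. The standalone sketch below shows that idea on made-up paths. The helper names are hypothetical, and the directory and file naming (`checkpoint_<n>` containing `checkpoint-<n>`) is an assumption inferred from how `get_latest_checkpoint` splits and rejoins the names.

```python
import re

def natural_key(path):
    """Split digit runs out of `path` so they compare as integers."""
    return [int(tok) if tok.isdigit() else tok for tok in re.split(r"(\d+)", path)]

def checkpoint_file_name(ckpt_dir):
    """Map a directory like 'checkpoint_120' to the file name 'checkpoint-120'."""
    prefix, number = ckpt_dir.rstrip("/").split("/")[-1].split("_")
    return f"{prefix}-{int(number)}"

# Made-up paths for illustration only.
dirs = ["exp/checkpoint_120", "exp/checkpoint_60", "exp/checkpoint_1000"]

latest_plain = sorted(dirs)[-1]                     # plain string ordering
latest_natural = sorted(dirs, key=natural_key)[-1]  # numeric-aware ordering
print(latest_plain, latest_natural, checkpoint_file_name(latest_natural))
```

A plain string sort would pick `checkpoint_60` here, because `"6"` sorts after `"1"`; the numeric-aware key avoids restoring a stale checkpoint.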

#### Training the agent using SageMaker RL container

In [None]:
import sagemaker
import boto3
# S3 bucket
sagemaker_session = sagemaker.session.Session()
s3_bucket = sagemaker_session.default_bucket()  
s3_output_path = 's3://{}/'.format(s3_bucket)

# AWS region and execution role used by the training job
aws_region = boto3.Session().region_name
role = sagemaker.get_execution_role()
print("S3 bucket path: {}".format(s3_output_path))

In [None]:
from sagemaker.rl import RLEstimator, RLToolkit, RLFramework
import time

image_name=f"462105765813.dkr.ecr.{aws_region}.amazonaws.com/sagemaker-rl-ray-container:ray-1.1.0-tf-gpu-py36"
estimator = RLEstimator(
    image_uri=image_name,
    entry_point="train.py",
    source_dir='tictactoe',
    role=role,
    instance_type='ml.p3.2xlarge',
    max_run=60*(60 * 2),
    instance_count=1,
    output_path=s3_output_path,
    metric_definitions=RLEstimator.default_metric_definitions(RLToolkit.RAY),
    hyperparameters={
        "log-level": 20,
        "record-videos": False,
        "batch-mode": "complete_episodes",
        "episode-reward-mean": 5.0,
        "learning-rate": 0.0001,
        "init-seed": 1, # seed == 1 makes the agent learn faster but it gets biased
        "refining-iter": 5 # refining iterations are to make the agent generalize to random matches
    }
)

#### Kick-off the training job

In [None]:
estimator.fit(wait=True)
job_name = estimator.latest_training_job.job_name
print("Training job: %s" % job_name)

### 4/4) Playing against the agent

There is a simple IPython application [Game](tictactoe/game.py) that we'll use to load the model, render the board and let us play against the agent.

This application expects a model optimized with SageMaker Neo.
So, let's compile the TensorFlow model with SageMaker Neo and use it in our local application.

In [None]:
import time
import boto3
import sagemaker

role = sagemaker.get_execution_role()
sagemaker_session=sagemaker.Session()
bucket_name = sagemaker_session.default_bucket()

sm_client = boto3.client('sagemaker')

# Change the target OS + Arch to match the device on which you'll run this application
os='LINUX'
arch='X86_64' # ARM64
model_name='tic-tac-toe'
filename=f'model-{os}_{arch}.tar.gz'

s3_uri=f'{estimator.output_path}{estimator.latest_training_job.name}/output/model.tar.gz'
s3_uri_out=f's3://{sagemaker_session.default_bucket()}/{model_name}-tensorflow/optimized/{filename}'

compilation_job_name = f'{model_name}-tensorflow-{int(time.time()*1000)}'
sm_client.create_compilation_job(
    CompilationJobName=compilation_job_name,
    RoleArn=role,
    InputConfig={
        'S3Uri': s3_uri,
        'DataInputConfig': '{"agent_x/observations": [1,10]}',
        'Framework': 'TENSORFLOW'
    },
    OutputConfig={
        'S3OutputLocation': s3_uri_out,
        'TargetPlatform': { 
            'Os': os, 
            'Arch': arch,
            #'Accelerator': 'NVIDIA'  # uncomment this if your target device has an Nvidia GPU
        },
        # Uncomment and adjust the following line to match your edge device
        # Jetson Xavier: sm_72; Jetson Nano: sm_53
        #'CompilerOptions': '{"trt-ver": "7.1.3", "cuda-ver": "10.2", "gpu-code": "sm_53"}' # Jetpack 4.4.1
    },
    StoppingCondition={ 'MaxRuntimeInSeconds': 900 }
)
while True:
    resp = sm_client.describe_compilation_job(CompilationJobName=compilation_job_name)    
    if resp['CompilationJobStatus'] in ['STARTING', 'INPROGRESS']:
        print('Running...')
    else:
        print(resp['CompilationJobStatus'], compilation_job_name)
        break
    time.sleep(5)
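
The polling loop above runs until the job leaves the `STARTING`/`INPROGRESS` states. A possible variation, sketched below, wraps the loop in a helper that also gives up after a time limit. The helper `wait_for_job` and the fake `describe` function are hypothetical and exist only so the block runs without AWS access.

```python
import time

def wait_for_job(describe, job_name, poll_seconds=5, timeout_seconds=900, sleep=time.sleep):
    """Poll `describe(job_name)` until the job leaves STARTING/INPROGRESS."""
    waited = 0
    while True:
        status = describe(job_name)["CompilationJobStatus"]
        if status not in ("STARTING", "INPROGRESS"):
            return status
        if waited >= timeout_seconds:
            raise TimeoutError(f"{job_name} still {status} after {waited}s")
        sleep(poll_seconds)
        waited += poll_seconds

# Stand-in for sm_client.describe_compilation_job, returning canned statuses.
_fake_statuses = iter(["STARTING", "INPROGRESS", "COMPLETED"])

def fake_describe(job_name):
    return {"CompilationJobStatus": next(_fake_statuses)}

status = wait_for_job(fake_describe, "demo-job", sleep=lambda s: None)
print(status)
```

In the notebook the same helper could be called with `lambda name: sm_client.describe_compilation_job(CompilationJobName=name)` as the `describe` argument.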

#### Download and unpack the compiled model

In [None]:
!aws s3 cp $s3_uri_out /tmp/
!rm -rf model_neo && mkdir model_neo
!tar -xzvf /tmp/$filename -C model_neo

#### Install DLR (runtime that loads the model)

In [None]:
# install DLR, the runtime required to load the model
!pip install -U dlr==1.8.0
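
Once the runtime returns the model's output for a board, the game has to turn it into a move. The sketch below shows one plausible way to do that; it is not taken from `tictactoe/game.py`. It assumes the output is a vector of 10 scores, one per action, where indices 0-8 are board cells and index 9 is the wait action, and it simply picks the best-scoring empty cell. The function name `choose_action` is hypothetical.

```python
def choose_action(scores, board):
    """Return the empty cell (0-8) with the highest score."""
    empty = [i for i, cell in enumerate(board) if cell == " "]
    if not empty:
        raise ValueError("board is full")
    # Occupied cells and the wait action (index 9) are never candidates.
    return max(empty, key=lambda i: scores[i])

# Made-up scores: cell 1 scores highest but is already taken.
scores = [0.5, 3.0, 0.1, 0.2, 2.0, 0.0, 0.0, 0.0, 0.0, 9.0]
board = [" ", "X", " ", " ", " ", " ", " ", " ", " "]
print(choose_action(scores, board))
```

Masking occupied cells this way keeps the Bot from choosing an illegal move even when the model scores one highly.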

#### Have fun!! :)

In [None]:
from tictactoe.game import TicTacToeGame
game = TicTacToeGame()
game.run()