In [None]:
import sys
import os
import logging

import torch
import numpy as np
from bs_gym.gymbattlesnake import BattlesnakeEnv
from a2c_ppo_acktr.storage import RolloutStorage

from policy import SnakePolicyBase, create_policy
from utils import PathHelper
# from utils import device

In [None]:

# TODO: CONFIG FILE
n_envs = 1     # passed to BattlesnakeEnv as n_envs
n_steps = 600  # not referenced by the cells visible in this notebook

# torch.backends.cuda.matmul.allow_tf32 = False # Do matmul at TF32 mode.
# os.cpu_count() returns None when the count cannot be determined; fall back
# to a single thread in that case.
CPU_THREADS = os.cpu_count() or 1
# Prefer the GPU, but fall back to the CPU so the notebook still runs on
# machines without CUDA.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
MODEL_GROUP = 'test10'

# Observation layout: NUM_LAYERS planes of LAYER_WIDTH x LAYER_HEIGHT cells.
NUM_LAYERS = 17
LAYER_WIDTH = 23
LAYER_HEIGHT = 23


In [None]:
# Configure the 'inference' logger to emit DEBUG and above to stdout.
logger = logging.getLogger('inference')
logger.setLevel(logging.DEBUG)

formatter = logging.Formatter('%(asctime)s | %(levelname)s | %(message)s')
stdout_handler = logging.StreamHandler(sys.stdout)
stdout_handler.setLevel(logging.DEBUG)
stdout_handler.setFormatter(formatter)

# logging.getLogger returns the same logger object on every call, so
# re-running this cell would otherwise stack one more handler each time and
# duplicate every log line.
if not logger.handlers:
    logger.addHandler(stdout_handler)


# Resolve the most recent checkpoint for MODEL_GROUP.
MODEL_PATH = None
u = PathHelper()
u.set_modelgroup(MODEL_GROUP, read_tmp=True)
MODEL_PATH, _ = u.get_latest_model()

# Check before announcing the load. Passing a message to sys.exit still raises
# SystemExit with exit status 1, but tells the reader why execution stopped.
if MODEL_PATH is None:
    sys.exit(f'No model found for model group {MODEL_GROUP!r}')

print('Loading model from:', MODEL_PATH)

In [None]:

# A throwaway environment is created only to read the observation and action
# spaces needed to build the policy; it is closed as soon as they are read.
tmp_env = BattlesnakeEnv(n_threads=CPU_THREADS, n_envs=n_envs)
try:
    observation_shape = tmp_env.observation_space.shape
    action_space = tmp_env.action_space
finally:
    tmp_env.close()

# Load policy
policy = create_policy(observation_shape, action_space, SnakePolicyBase)
# map_location remaps tensors that were saved on a different device (e.g. a
# CUDA checkpoint opened on a CPU-only host) onto the inference device.
policy.load_state_dict(torch.load(MODEL_PATH, map_location=device))

policy.to(device)
policy.eval()

In [None]:
from tqdm.notebook import tqdm
from matplotlib import pyplot as plt
from matplotlib import animation
from IPython.display import HTML
from bs_gym.gymbattlesnake import BattlesnakeEnv

import sys
import copy
def get_offset(obs):
    """Return the index of the first cell equal to 1 in layer 5 of ``obs``.

    Cells are examined in row-major order (first index outer, second index
    inner), so the result is the match with the smallest ``x`` and, among
    those, the smallest ``y``. The whole layer is scanned regardless of its
    size.

    Args:
        obs: Observation indexable as ``obs[0][5][x][y]`` (nested lists or a
            numpy array).

    Returns:
        Tuple ``(x, y)`` of Python ints.

    Raises:
        ValueError: If no cell of layer 5 equals 1.
    """
    board = np.asarray(obs[0][5])
    # argwhere yields matching indices in row-major order, which is the same
    # order as a nested "for x / for y" scan.
    hits = np.argwhere(board == 1)
    if hits.shape[0] == 0:
        raise ValueError('observation has no cell set to 1 in layer 5')
    x, y = hits[0]
    return int(x), int(y)


def obs_to_frame(obs, width=11, height=11):
    """Render one observation as an RGB image of shape ``(width, height, 3)``.

    The frame origin is the offset returned by ``get_offset``; every layer
    cell is drawn at ``(x - x_offset, y - y_offset)``. Cells that land outside
    the frame are skipped, so they neither wrap around through negative
    indexing nor raise ``IndexError``.

    Later rules overwrite earlier ones for the same pixel, in this order:
    snake bodies (grey level derived from layer 2), food (yellow), health
    (red), agent's head (green).

    Args:
        obs: Observation indexable as ``obs[0][layer][x][y]``.
        width: Size of the frame along the first axis.
        height: Size of the frame along the second axis.

    Returns:
        ``numpy.ndarray`` of dtype ``uint8`` and shape ``(width, height, 3)``.
    """
    output = np.zeros((width, height, 3), dtype=np.uint8)

    x_offset, y_offset = get_offset(obs)

    layers = obs[0]
    # Take the scan range from the observation instead of assuming 23x23.
    n_rows = len(layers[1])
    n_cols = len(layers[1][0]) if n_rows else 0

    # layer reference: https://github.com/cbinners/gym-battlesnake/blob/master/gym_battlesnake/src/gamewrapper.cpp#L132
    for x in range(n_rows):
        fx = x - x_offset
        if not 0 <= fx < width:
            continue
        for y in range(n_cols):
            fy = y - y_offset
            if not 0 <= fy < height:
                continue

            # Render snake bodies
            # NOTE(review): with uint8 input this arithmetic can wrap around;
            # confirm the intended brightness mapping.
            if layers[1][x][y] == 1:
                output[fx, fy] = 255 - 10*(255 - layers[2][x][y])

            # layer 4: shared food location
            if layers[4][x][y] == 1:
                output[fx, fy] = (255, 255, 0)  # yellow

            # layer 0: snake health
            if layers[0][x][y] > 0:
                output[fx, fy] = (255, 0, 0)  # red

            # layer 6: agent's head
            if layers[6][x][y] == 1:
                output[fx, fy] = (0, 255, 0)  # green

    return output

def visualize_game(obs_list):
    """Build a matplotlib animation that replays a sequence of observations.

    Each observation is rendered with ``obs_to_frame``; on every frame the
    positive values of layer 2 are drawn as white text on top of the image,
    positioned relative to the offset returned by ``get_offset``.

    Args:
        obs_list: Non-empty sequence of observations indexable as
            ``obs[0][layer][x][y]``.

    Returns:
        ``matplotlib.animation.FuncAnimation`` with one frame per observation
        and 200 ms between frames.

    Raises:
        ValueError: If ``obs_list`` is empty.
    """
    if len(obs_list) == 0:
        raise ValueError('obs_list is empty; nothing to visualize')

    # Render, adapted from here: https://stackoverflow.com/questions/57060422/fast-way-to-display-video-from-arrays-in-jupyter-lab
    video = np.array([obs_to_frame(obs) for obs in obs_list], dtype=np.uint8)
    fig, ax = plt.subplots()

    im = ax.imshow(video[0])

    def init():
        im.set_data(video[0])

    def animate(i):
        # Iterate over a copy: removing artists while walking the live
        # collection can skip entries.
        for txt in list(ax.texts):
            txt.remove()
        im.set_data(video[i])

        obs = obs_list[i]
        x_offset, y_offset = get_offset(obs)

        segments = obs[0][2]
        for x in range(len(segments)):
            for y in range(len(segments[x])):
                value = segments[x][y]
                # Only label cells that carry a value; empty labels would add
                # hundreds of invisible artists per frame.
                if value > 0:
                    ax.text(y - y_offset, x - x_offset, value,
                            ha="center", va="center", color="w")
        return im

    # Close this specific figure; the animation object keeps its own
    # reference to it.
    plt.close(fig)

    anim = animation.FuncAnimation(fig,
                                   animate, init_func=init,
                                   frames=video.shape[0],
                                   interval=200 # milliseconds per frame
                                   )
    return anim

In [None]:
# Start server to accept inference requests
from fastapi import FastAPI
from typing import List
from pydantic import BaseModel

import time
# FastAPI application instance; the route decorators below register their
# handlers on it and it is the object later handed to uvicorn.run.
app = FastAPI()

class InferenceRequest(BaseModel):
    """Request body accepted by the ``/api/predict`` endpoint."""
    # Caller-supplied identifier; predict() includes it in its timing log line.
    id: str
    # Declared in the schema but not read by predict() (the reads are
    # commented out there).
    width: int
    height: int
    # Forward reference: Frames is defined after this class and is resolved by
    # the InferenceRequest.model_rebuild() call that follows both definitions.
    input: 'Frames'

class Frames(BaseModel):
    """Observation layers of one inference request, as nested integer lists.

    predict() stacks the ten named 2-D layers in declaration order (l0..l9)
    and then appends the layers contained in ``alive_count``.
    """
    # Each of the following fields is one 2-D layer (list of rows of ints).
    l0_health: List[List[int]]
    l1_bodies: List[List[int]]
    l2_segments: List[List[int]]
    l3_snake_length: List[List[int]]
    l4_food: List[List[int]]
    l5_board: List[List[int]]
    l6_head_mask: List[List[int]]
    l7_tail_mask: List[List[int]]
    l8_bodies_gte: List[List[int]]
    l9_bodies_lt: List[List[int]]
    # A list of 2-D layers; predict() unpacks it after the ten layers above.
    alive_count: List[List[List[int]]] # 7 layers

# Resolve the 'Frames' forward reference now that Frames is defined.
InferenceRequest.model_rebuild()

# predict() appends every observation it builds to this list so later cells
# can inspect or replay the received requests.
obss = []

@app.post("/api/predict")
def predict(req: InferenceRequest):
    """Run the policy on the observation in ``req`` and return its decision.

    The observation is also appended to the module-level ``obss`` list so it
    can be inspected or replayed after the server has stopped.

    Args:
        req: Request whose ``input`` holds the observation layers and whose
            ``id`` is used to tag the timing log line.

    Returns:
        dict with the keys ``action`` (int in 0..3, or -1 on error),
        ``value`` (float value estimate, 0.0 on error) and ``error``
        (empty string on success, otherwise a short description).
    """
    # TODO: logging with UUID

    request_id = req.id
    frames = req.input

    # Stack the ten named layers followed by the alive_count layers into one
    # observation with a leading batch axis of size 1.
    obs = np.asarray(
        [
            [
                frames.l0_health,
                frames.l1_bodies,
                frames.l2_segments,
                frames.l3_snake_length,
                frames.l4_food,
                frames.l5_board,
                frames.l6_head_mask,
                frames.l7_tail_mask,
                frames.l8_bodies_gte,
                frames.l9_bodies_lt,
                *frames.alive_count,
            ]
        ],
        dtype=np.uint8,
    )
    obss.append(obs)

    # perf_counter is monotonic and intended for measuring short durations.
    start_time = time.perf_counter()

    # execute inference on environment
    with torch.no_grad():
        inp = torch.tensor(obs, dtype=torch.float32).to(device)
        action, value = policy.predict(inp, deterministic=True)

        # Copy the results to the host inside the timed region: CUDA work is
        # queued asynchronously, so the clock should stop only once the
        # outputs are actually available.
        flattened_action: np.ndarray = action.cpu().numpy().flatten()
        flattened_value: np.ndarray = value.cpu().numpy().flatten()

    # bench time
    elapsed = time.perf_counter() - start_time
    logger.info(f"Inference took {elapsed} seconds for {request_id}")

    if flattened_action.shape != (1,) or flattened_value.shape != (1,):
        logger.error(f"Invalid action or value shape: {flattened_action.shape}, {flattened_value.shape}")
        return {
            "action": -1,
            "value": 0.0,
            "error": "Invalid action or value shape",
        }
    if flattened_action[0] not in [0, 1, 2, 3]:
        logger.error(f"Unexpected network output: {flattened_action[0]}")
        return {
            "action": -1,
            "value": 0.0,
            "error": "Unexpected network output",
        }

    return {
        "action": int(flattened_action[0]),
        # The value estimate is a float; int() would truncate it (the error
        # branches above already report it as a float).
        "value": float(flattened_value[0]),
        "error": "",
    }


@app.get("/")
def read_root():
    """Root endpoint; always answers with a fixed ``ok`` status payload."""
    payload = {"status": "ok"}
    return payload

@app.get("/health")
def health():
    """Health-check endpoint; always answers with a fixed ``ok`` status payload."""
    payload = {"status": "ok"}
    return payload

import nest_asyncio
import uvicorn

In [None]:


# Start each server run with an empty capture list. predict() looks the name
# up at call time, so it appends to this new list.
obss = []
# NOTE(review): presumably required because the notebook kernel already runs
# an event loop that uvicorn would otherwise conflict with — confirm.
nest_asyncio.apply()

# Serve the app on all network interfaces, port 7801. The cells below read
# `obss`, so they are meant to run after this call has returned.
# NOTE(review): host '0.0.0.0' exposes the endpoint beyond localhost — confirm
# this is intended for the deployment.
uvicorn.run(app, host='0.0.0.0', port=7801)

In [None]:
# Dump every layer of one captured observation for manual inspection.
frame_index = 3
if frame_index >= len(obss):
    raise IndexError(
        f'only {len(obss)} observation(s) captured; cannot show index {frame_index}'
    )
obs = obss[frame_index]
# print all frames
# Iterate over the layers actually present instead of assuming there are 17.
for f, layer in enumerate(obs[0]):
    print('i=', f)
    print(layer)

In [None]:

# Build an animation that replays every observation captured in `obss`.
anim = visualize_game(obss)


In [None]:
# Display the animation inline as an HTML5 video.
# NOTE(review): to_html5_video presumably needs a movie writer such as ffmpeg
# to be installed — confirm for the target environment.
HTML(anim.to_html5_video())