# Visualization


In [None]:
# !export PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION=python

import torch
import numpy as np
from tqdm.notebook import tqdm
from matplotlib import pyplot as plt
from matplotlib import animation
from IPython.display import HTML
from bs_gym.gymbattlesnake import BattlesnakeEnv
from a2c_ppo_acktr.storage import RolloutStorage

from policy import SnakePolicyBase, create_policy
from utils import n_opponents, device
from utils import PathHelper, plot_graphs


In [None]:

u = PathHelper()

# Shared state; setup_env() fills these in.
rollouts = policy = policies = None

# Run configuration.
# TODO: (left without a description in the original) — presumably these
# values should become configurable; confirm.
n_envs = 1
n_steps = 600
CPU_THREADS = 6


def setup_env(model_path):
    """Create rollout storage and load the saved policy used for visualization.

    Side effects: sets the module-level globals ``rollouts``, ``policy`` and
    ``policies``.

    Args:
        model_path: path to a checkpoint whose ``torch.load`` result is used as
            the ``state_dict`` of the network built by ``create_policy``.
    """
    global rollouts, policy, policies

    # A throwaway env is built only to read its observation/action spaces.
    # Read them before closing, and close in ``finally`` so the env is
    # released even if anything here raises.
    tmp_env = BattlesnakeEnv(n_threads=2, n_envs=n_envs)
    try:
        obs_shape = tmp_env.observation_space.shape
        action_space = tmp_env.action_space
    finally:
        tmp_env.close()

    # NOTE(review): the final argument is n_steps, as in the original; in
    # upstream a2c_ppo_acktr that slot is recurrent_hidden_state_size —
    # confirm this is intended.
    rollouts = RolloutStorage(n_steps,
                              n_envs,
                              obs_shape,
                              action_space,
                              n_steps)

    policy = create_policy(obs_shape, action_space, SnakePolicyBase)
    # map_location remaps the saved tensors onto ``device`` so a checkpoint
    # written on a GPU can still be loaded on a CPU-only machine.
    policy.load_state_dict(torch.load(model_path, map_location=device))
    policy.to(device)
    policy.eval()

    # Every opponent slot reuses the same policy instance.
    # TODO: load multiple models, load state dict into policies
    policies = [policy] * n_opponents


In [None]:


def obs_to_frame(obs):
    """Render the first environment's observation as an 11x11 RGB uint8 image.

    The 23x23 observation planes are scanned row-major.
    * The first cell of layer 5 equal to 1 is taken as the board's top-left corner, or (0, 0) if there is none.
    * Marked cells of other layers are then drawn relative to that corner.
    * Later layers overwrite earlier ones:

      - layer 1 == 1: grey shade computed from layer 2
      - layer 4 == 1: yellow (255, 255, 0)
      - layer 0 > 0: red (255, 0, 0)
      - layer 6 == 1: green (0, 255, 0)

    Layer reference:
    https://github.com/cbinners/gym-battlesnake/blob/master/gym_battlesnake/src/gamewrapper.cpp#L55

    Args:
        obs: indexable as ``obs[env][layer][x][y]``; only ``obs[0]`` is read.
            Assumed to hold at least 7 layers of 23x23 — TODO confirm.

    Returns:
        numpy.ndarray of shape (11, 11, 3) and dtype uint8.
    """
    planes = obs[0]
    cells = range(23)

    # Board corner: first layer-5 hit in row-major order, defaulting to (0, 0).
    x0, y0 = next(
        ((cx, cy) for cx in cells for cy in cells if planes[5][cx][cy] == 1),
        (0, 0),
    )

    frame = np.zeros((11, 11, 3), dtype=np.uint8)
    for cx in cells:
        for cy in cells:
            pixel = (cx - x0, cy - y0)
            if planes[1][cx][cy] == 1:
                # NOTE(review): evaluated in obs's dtype — with uint8 input this
                # wraps modulo 256 instead of clamping, and a negative float
                # result is cast into the uint8 frame; confirm intended.
                frame[pixel] = 255 - 10*(255 - planes[2][cx][cy])
            if planes[4][cx][cy] == 1:
                frame[pixel] = (255, 255, 0)
            if planes[0][cx][cy] > 0:
                frame[pixel] = (255, 0, 0)
            if planes[6][cx][cy] == 1:
                frame[pixel] = (0, 255, 0)

    return frame

def visualize_game(policy, *, n_frames=300, interval=200):
    """Play one rollout with ``policy`` and return it as a matplotlib animation.

    The controlled snake is driven by ``policy``. Opponents come from the
    module-level ``policies`` list that ``setup_env`` populates.

    Args:
        policy: object whose ``act(obs, None, None)`` returns a 4-tuple whose
            second element is the action tensor.
        n_frames: number of environment steps to record (default 300).
        interval: delay between animation frames in milliseconds (default 200).

    Returns:
        matplotlib.animation.FuncAnimation over the recorded frames.

    Raises:
        ValueError: if ``n_frames`` < 1, since there would be nothing to draw.
    """
    if n_frames < 1:
        raise ValueError(f"n_frames must be >= 1, got {n_frames}")

    playground = BattlesnakeEnv(n_threads=CPU_THREADS, n_envs=1, fixed_orientation=True,
                                opponents=policies, device=device)

    # Rendered game frames, one per recorded step.
    frames = []
    try:
        obs = playground.reset()
        with torch.no_grad():
            for _ in tqdm(range(n_frames)):
                # Render the observation the policy is about to act on.
                frames.append(obs_to_frame(obs))

                obs_t = torch.as_tensor(obs, dtype=torch.float32, device=device)
                _, action, _, _ = policy.act(obs_t, None, None)

                obs, _, _, _ = playground.step(action.cpu().squeeze())
    finally:
        # Release the environment; the original never closed it.
        playground.close()

    # Render, adapted from here: https://stackoverflow.com/questions/57060422/fast-way-to-display-video-from-arrays-in-jupyter-lab
    video = np.array(frames, dtype=np.uint8)
    fig = plt.figure()
    im = plt.imshow(video[0])

    def init():
        im.set_data(video[0])
        return (im,)

    def animate(i):
        im.set_data(video[i])
        return (im,)

    # Close the figure so the notebook does not additionally show it as a
    # static image; the animation keeps its own reference to ``fig``.
    plt.close(fig)

    return animation.FuncAnimation(fig, animate, init_func=init,
                                   frames=video.shape[0], interval=interval)


In [None]:
# Load the most recent checkpoint of the model group and plot its training curves.
MODEL_GROUP = 'test6'
u.set_modelgroup(MODEL_GROUP, read_tmp=True)

latest_model_path, _ = u.get_latest_model()
# Fail early with a clear message instead of handing None to setup_env, where
# torch.load would fail with an unhelpful error.
if latest_model_path is None:
    raise FileNotFoundError(f"No saved model found in model group {MODEL_GROUP!r}")
model_path = latest_model_path

print('Loading model from:', model_path)
setup_env(model_path)

# Training history saved alongside the model.
data = u.load_data()
rewards = data['rewards']
value_losses = data['value_losses']
action_losses = data['action_losses']
dist_entropies = data['dist_entropies']
lengths = data['lengths']

print(data)

plot_graphs(rewards, value_losses, action_losses, dist_entropies, lengths)

In [None]:
# Roll out one game with the loaded policy and keep the resulting animation.
anim = visualize_game(policy=policy)


In [None]:
# Encode the animation as an HTML5 <video> tag using matplotlib's configured
# movie writer (ffmpeg by default). As the cell's last expression, the HTML
# object is rendered inline.
HTML(data=anim.to_html5_video())