In [None]:
# Assumes a conda/mamba environment has already been configured.

# !curl -Ls https://micro.mamba.pm/api/micromamba/linux-64/latest | tar -xvj bin/micromamba


# !MAMBA_ROOT_PREFIX="$HOME/mambaforge-pypy3" eval "$(bin/micromamba shell hook --shell zsh)" && MAMBA_ROOT_PREFIX="$HOME/mambaforge-pypy3" micromamba activate bs
# !source $HOME/.zshrc && mamba activate bs
# !mkdir -p $HOME/micromamba
# !eval "$(bin/micromamba shell hook --shell xonsh)"
# !bin/micromamba shell reinit --shell 
# !bin/micromamba activate $HOME/mambaforge-pypy3/envs/bs

In [9]:
!cd ../bs-gym/ && pip install -e . -q
!cd .
!cd ../pytorch-a2c-ppo-acktr-gail/ && pip install -e . -q

In [None]:
# Environment sanity check: print which `python` and `pip` executables are
# resolved, then list the installed packages.
# NOTE(review): `!` commands run in a subshell, so these reflect the shell's
# PATH, which may differ from the kernel's interpreter — compare with
# `sys.executable` if the results look wrong.
!which python
!which pip
!pip list

In [1]:
# !export PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION=python

import torch
import numpy as np
from tqdm.notebook import tqdm
from matplotlib import pyplot as plt
from matplotlib import animation
from IPython.display import HTML
from bs_gym.gymbattlesnake import BattlesnakeEnv
from a2c_ppo_acktr.storage import RolloutStorage

from policy import SnakePolicyBase, create_policy


  _np_qint8 = np.dtype([("qint8", np.int8, 1)])
  _np_quint8 = np.dtype([("quint8", np.uint8, 1)])
  _np_qint16 = np.dtype([("qint16", np.int16, 1)])
  _np_quint16 = np.dtype([("quint16", np.uint16, 1)])
  _np_qint32 = np.dtype([("qint32", np.int32, 1)])
  np_resource = np.dtype([("resource", np.ubyte, 1)])
  _np_qint8 = np.dtype([("qint8", np.int8, 1)])
  _np_quint8 = np.dtype([("quint8", np.uint8, 1)])
  _np_qint16 = np.dtype([("qint16", np.int16, 1)])
  _np_quint16 = np.dtype([("quint16", np.uint16, 1)])
  _np_qint32 = np.dtype([("qint32", np.int32, 1)])
  np_resource = np.dtype([("resource", np.ubyte, 1)])
  "stable-baselines is in maintenance mode, please use [Stable-Baselines3 (SB3)](https://github.com/DLR-RM/stable-baselines3) for an up-to-date version. You can find a [migration guide](https://stable-baselines3.readthedocs.io/en/master/guide/migration.html) in SB3 documentation."


In [7]:

# NOTE: CONFIG

CPU_THREADS = 6
# Fall back to the CPU when no GPU is visible so the notebook still runs.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
n_envs = 208
n_steps = 600
MODEL_PATH = 'models/final_08_02_19_07_2023.pt'

# Temporary environment, created only to read the observation/action spaces.
# The spaces are captured before the environment is closed so nothing below
# touches a closed environment.
tmp_env = BattlesnakeEnv(n_threads=CPU_THREADS, n_envs=n_envs)
try:
    observation_shape = tmp_env.observation_space.shape
    action_space = tmp_env.action_space
finally:
    tmp_env.close()

# Storage for rollouts (game turns played and the rewards)
# NOTE(review): the last argument repeats n_steps — confirm this is the value
# intended for RolloutStorage's fifth parameter.
rollouts = RolloutStorage(n_steps,
                          n_envs,
                          observation_shape,
                          action_space,
                          n_steps)

policy = create_policy(observation_shape, action_space, SnakePolicyBase)
# map_location keeps a checkpoint that was saved on a GPU loadable on a
# CPU-only machine.
policy.load_state_dict(torch.load(MODEL_PATH, map_location=device))
policy.to(device)
policy.eval()

OPPONENTS = 3
# Every opponent entry refers to the same policy object as the player.
policies = [policy for _ in range(OPPONENTS)]

# Alternative: play against other saved checkpoints instead.
# policies = [
#     torch.load('models/final_06_59_19_07_2023.pt'),
#     torch.load('models/weights_06_59_19_07_2023_iter50.pt'),
#     torch.load('models/weights_06_53_19_07_2023_iter15.pt'),
# ]


In [9]:


def obs_to_frame(obs):
    ''' Renders the first environment's observation as an 11x11 RGB image (uint8).

    Only ``obs[0]`` is read. Layers are indexed as ``obs[0][layer][x][y]`` over
    a 23x23 grid; layers 0, 1, 2, 4, 5 and 6 are used.
    '''
    layers = obs[0]

    # The first cell (row-major scan) flagged in layer 5 is used as the board's
    # origin; the origin stays (0, 0) when no cell is flagged.
    x_offset, y_offset = next(
        ((row, col)
         for row in range(23)
         for col in range(23)
         if layers[5][row][col] == 1),
        (0, 0),
    )

    frame = np.zeros((11, 11, 3), dtype=np.uint8)

    # Layer reference:
    # https://github.com/cbinners/gym-battlesnake/blob/master/gym_battlesnake/src/gamewrapper.cpp#L55
    # TODO: to improve?
    # The checks below run in a fixed order, so a later match overwrites an
    # earlier one for the same cell.
    for row, col in np.ndindex(23, 23):
        target = (row - x_offset, col - y_offset)

        # Snake bodies: one grey level on all three channels, derived from layer 2
        if layers[1][row][col] == 1:
            frame[target] = 255 - 10*(255 - layers[2][row][col])

        # Food: yellow
        if layers[4][row][col] == 1:
            frame[target] = (255, 255, 0)

        # Snake heads (any non-zero layer-0 cell): red
        if layers[0][row][col] > 0:
            frame[target] = (255, 0, 0)

        # Layer-6 cells: green. Presumably the player's own head, which is why
        # it is painted after the generic head colour — TODO confirm against
        # the layer reference above.
        if layers[6][row][col] == 1:
            frame[target] = (0, 255, 0)

    return frame

def visualize_game(policy, n_frames=300, interval=200):
    ''' Plays a single-environment game with ``policy`` and returns it as an animation.

    The environment is created with the module-level ``CPU_THREADS`` and uses the
    module-level ``policies`` list as opponents; observations are moved to the
    module-level ``device`` before being passed to ``policy.act``.

    Args:
        policy: object exposing ``act(inputs, rnn_hxs, masks)``; the second
            element of its return value is used as the action.
        n_frames: number of environment steps to record (one frame per step).
        interval: delay between animation frames in milliseconds.

    Returns:
        A ``matplotlib.animation.FuncAnimation`` over the recorded frames.

    Raises:
        ValueError: if ``n_frames`` is smaller than 1.
    '''
    if n_frames < 1:
        raise ValueError("n_frames must be at least 1")

    playground = BattlesnakeEnv(n_threads=CPU_THREADS, n_envs=1, fixed_orientation=True,
                                opponents=policies,
                                )

    # Keep track of game frames to render
    video = []
    try:
        # Reset the environment
        obs = playground.reset()

        # Grab a set of frames to render
        with torch.no_grad():
            for _ in tqdm(range(n_frames)):
                # Add the rendered observation to our frame stack
                video.append(obs_to_frame(obs))

                # Get the action our policy should take
                _, action, _, _ = policy.act(torch.tensor(obs, dtype=torch.float32).to(device), None, None)

                # Perform our action and update our observation
                obs, _, _, _ = playground.step(action.cpu().squeeze())
    finally:
        # Release the environment even when the rollout fails; the frames
        # collected so far live in `video` and do not depend on it.
        playground.close()

    # Render, adapted from here: https://stackoverflow.com/questions/57060422/fast-way-to-display-video-from-arrays-in-jupyter-lab
    video = np.array(video, dtype=np.uint8)
    fig = plt.figure()
    im = plt.imshow(video[0])

    def init():
        im.set_data(video[0])
        return (im,)

    def animate(i):
        im.set_data(video[i])
        return (im,)

    # Close this specific figure; the animation keeps its own reference to it.
    plt.close(fig)

    anim = animation.FuncAnimation(fig, animate, init_func=init, frames=video.shape[0], interval=interval)
    return anim


In [10]:
# Play one game with the loaded policy and keep the resulting animation object.
anim = visualize_game(policy)


  0%|          | 0/300 [00:00<?, ?it/s]

  inputs = torch.tensor(inputs, dtype=torch.float).to(device)


In [11]:
# Encode the animation as an HTML5 video and display it inline.
# NOTE(review): matplotlib needs a movie writer (ffmpeg by default) for
# to_html5_video — confirm it is installed in this environment.
HTML(anim.to_html5_video())