In [None]:
import os
import numpy as np
import pickle
import matplotlib.pyplot as plt
from copy import deepcopy
from ipywidgets import interact
from tqdm.notebook import tqdm
from numba import njit
import torch
import torch.nn.functional as F
import requests
import json

In [None]:
# Automatically re-import edited project modules before each cell executes,
# so changes to helper modules are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [None]:
from kaggle_environments import make
from kaggle_environments import structify

In [None]:
import lux_ai.rl_agent.rl_agent as rl_agent

In [None]:
# Replay to analyse and which player's perspective to take.
# Other replays inspected previously: (34014720, 0) and (34222068, 0).
replay_id, player_id = 33947998, 1

# Download the episode JSON; fail loudly on HTTP errors instead of trying to
# parse an error page, and don't hang forever on a stalled connection.
response = requests.get(
    f"https://www.kaggleusercontent.com/episodes/{replay_id}.json",
    timeout=60,
)
response.raise_for_status()
replay = response.json()
steps = replay["steps"]
config = replay["configuration"]
len(steps)

In [None]:
# Rebuild the Kaggle environment from the downloaded replay and render it inline.
env = make(
    "lux_ai_2021",
    steps=steps,
    info=replay["info"],
    configuration=config,
)
env.render(height=800, width=1000, mode="ipython")

In [None]:
from process_cerberus_replays import get_delta_with_cache, extract_obs

In [None]:
def get_worker_build_prb(obs, out, player=None):
    """Return the per-tile softmax probability of city-tile action index 1.

    Args:
        obs: observation dict; only ``obs["width"]`` is read, and it is used to
            crop the policy map to the board size.
        out: agent output dict. ``out["policy_logits"]["city_tile"]`` must be a
            torch tensor. Leading size-1 dims are squeezed away, the next dim is
            indexed by player, and the last dim holds action logits.
        player: index of the player whose policy is read. Defaults to the
            notebook-level ``player_id`` so existing two-argument calls keep
            working unchanged.

    Returns:
        numpy array of shape ``(width, width)`` with the probability of action
        index 1. The notebook treats this as "build worker"; TODO confirm this
        against the agent's city-tile action ordering.
    """
    if player is None:
        # Backward-compatible fallback to the notebook global.
        player = player_id
    size = obs["width"]
    logits = out["policy_logits"]["city_tile"]
    logits = logits.squeeze(0).squeeze(0)[player]
    probs = F.softmax(logits, dim=-1)
    # np.array(tensor) raises for CUDA tensors and for tensors tracking grad;
    # detach + move to CPU first so the conversion always works.
    probs = probs.detach().cpu().numpy()
    # NOTE(review): both spatial dims are cropped with the width. This assumes
    # square maps; confirm before using with non-square boards.
    return probs[:size, :size, 1]

In [None]:
# Run the agent over every step of the replay and record per-step artifacts
# for the visualisation cells.
states = list()  # deepcopy of agent.game_state, taken before each step's obs is processed
values = list()  # float(out['baseline'][0][player_id]) per step (plotted as "expected_value")
deltas = list()  # get_delta_with_cache(...) result per step
outputs = list()  # raw agent output per step
worker_build_prbs = list()  # get_worker_build_prb(obs, out) per step

obs = extract_obs(steps[0], player_id)
agent = rl_agent.RLAgent(obs, config)
for step in tqdm(steps):
    # Snapshot the state before this step's observation is fed to the agent.
    states.append(deepcopy(agent.game_state))
    obs = extract_obs(step, player_id)
    # Called before agent(...) below, which may mutate the agent; keep this order.
    # NOTE(review): with skip_uncached=True this may return a placeholder for
    # uncached steps; confirm in process_cerberus_replays.
    deltas.append(get_delta_with_cache(replay_id, agent, obs, config, skip_uncached=True))
    # The meaning of the third positional argument (True) is not visible here;
    # presumably it requests the full output dict (logits/baseline). TODO confirm.
    out = agent(obs, config, True)
    outputs.append(out)
    values.append(float(out['baseline'][0][player_id]))
    worker_build_prbs.append(get_worker_build_prb(obs, out))

In [None]:
from cerberus_viz import (
    make_figure,
    add_traces,
    plot_array,
)

In [None]:
def heatmap_function(state):
    """Heatmap for ``state.turn`` taken from ``deltas``, negated when ``state.id`` is truthy."""
    flip_sign = bool(state.id)
    delta = deltas[state.turn]
    return -delta if flip_sign else delta

def timeseries_function(state):
    """Expected-value history from turn 0 up to and including ``state.turn``."""
    upto = state.turn + 1
    return dict(expected_value=values[:upto])

def cityhighlight_function(state):
    """Worker-build probabilities for ``state.turn`` rescaled from [0, 1] to [0.5, 1]."""
    prb = worker_build_prbs[state.turn]
    return 0.5 + prb / 2

Interactive viewer: drag the `view_step` slider to step through the match turn by turn.

In [None]:
@interact(view_step=(0, len(steps) - 1))
def interactive_display(view_step=0):
    """Render the annotated replay figure for the step selected on the slider."""
    selected = states[view_step]
    figure = make_figure(selected, replay_id, player_id)
    add_traces(
        figure,
        selected,
        heatmap_function,
        timeseries_function,
        cityhighlight_function,
    )
    figure.show()

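Export the replay frames as PNG files to `cerberus_replays/pngs/`, named `<replay_id>-<step>.png` with a zero-padded step number: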

In [None]:
PNG_DIR = "cerberus_replays/pngs"
# write_image fails if the target directory does not exist yet.
os.makedirs(PNG_DIR, exist_ok=True)
# Iterate over every step (0 .. len(steps) - 1), matching the slider range
# above; the previous range(len(steps) - 1) silently dropped the final frame.
for view_step in tqdm(range(len(steps))):
    state = states[view_step]
    fig = make_figure(state, replay_id, player_id)
    add_traces(fig, state, heatmap_function, timeseries_function, cityhighlight_function)
    fig.write_image(f"{PNG_DIR}/{replay_id}-{view_step:03}.png")