In [1]:
import os
import pickle

import numpy as np
import matplotlib.pyplot as plt

In [2]:
def remove_outliers(arr: np.ndarray) -> np.ndarray:
    """Clamp values lying outside Tukey's fences (1.5 x IQR beyond the quartiles).

    Outliers are not dropped: they are clipped to the nearest fence, so the
    returned array has the same shape as ``arr``.
    """
    q1, q3 = np.percentile(arr, [25, 75])
    margin = 1.5 * (q3 - q1)
    return np.clip(arr, q1 - margin, q3 + margin)

In [3]:
# NOTE(review): pickle.load can execute arbitrary code — only load these files from a trusted source.
# all_results: map_id -> model key -> dict whose 'hmap' entry is a (heatmap, offset) pair,
# as unpacked by display_hmap and the plotting loop.
with open("heatmaps.pkl", "rb") as f:
    all_results = pickle.load(f)

# map_defs: indexed by integer map number; each entry unpacks to (things_dict, edges) in draw_items.
with open("mapdefs.pkl", "rb") as f:
    map_defs = pickle.load(f)

# Presumably the heatmap cell size in map units — TODO confirm. display_hmap uses it as the
# padding margin added around the heatmap extent when setting the axis limits.
hmap_res = 10

In [4]:
# Plot styling for every item category found in map_defs.
# Each row is (category, matplotlib marker, matplotlib colour); the two lookup
# tables are derived from it so a category's marker and colour sit side by side.
_thing_styles = (
    ("ammo",        "s", "cyan"),
    ("weapon",      "X", "yellow"),
    ("medkit",      "P", "orange"),
    ("spawn point", "o", "white"),
    ("player",      "v", "k"),
)
things_marker_def = {category: marker for category, marker, _ in _thing_styles}
things_colour_def = {category: colour for category, _, colour in _thing_styles}
del _thing_styles

In [5]:
# Monospace font so the space-padded titles and labels keep their alignment.
plt.rcParams["font.family"] = "monospace"

def display_hmap(hmap_offset: tuple[np.ndarray, np.ndarray], name: str, use_alpha: bool = False,
                 cmap: str = "hot", interpolation: str = "antialiased", base_value: float = 0.02,
                 eps: float = 0.0, amplify: bool = False):
    """Plot a normalised heatmap on the current figure and attach a colorbar.

    The heatmap is scaled so that its maximum maps to 1 - ``base_value``. Every cell
    above ``eps`` is then lifted by ``base_value``, so that barely visited cells stay
    distinguishable from empty ones. The colour limits are fixed to [0, 1].

    Parameters
    ----------
    hmap_offset : (heatmap, offset)
        ``heatmap`` is a 2-D array whose first axis is plotted along x and second along y.
        ``offset[0]`` / ``offset[1]`` shift the image extent so that array index ``offset``
        lands at plot coordinate (0, 0). The caller's array is never modified.
    name : str
        Figure title; " (amplified)" is appended when ``amplify`` is set.
    use_alpha : bool
        Make cells that are zero after normalisation fully transparent.
    cmap, interpolation : str
        Forwarded to ``plt.imshow``.
    base_value : float
        Offset added to every cell above ``eps`` after scaling.
    eps : float
        Threshold, applied after scaling, above which ``base_value`` is added.
    amplify : bool
        Clip outliers with ``remove_outliers`` before normalising, so that a few hot
        cells do not wash out the rest of the map.

    Returns
    -------
    The ``AxesImage`` created by ``plt.imshow``.
    """
    heatmap, offset = hmap_offset
    # Work on a float copy. Rescaling the caller's array (an entry of all_results) in place
    # would compound the normalisation every time a plotting cell is re-run. It would also
    # fail outright for an integer-valued heatmap.
    heatmap = np.array(heatmap, dtype=np.float64)
    if amplify:
        heatmap = remove_outliers(heatmap)
        name += " (amplified)"
    peak = heatmap.max()
    # An all-zero map would otherwise turn into all-NaN. This can be a sparse map whose IQR
    # is 0, so remove_outliers clips every cell to 0.
    if peak != 0:
        heatmap /= peak
    heatmap *= 1 - base_value
    heatmap[heatmap > eps] += base_value

    xmin, xmax = -offset[0], heatmap.shape[0] - offset[0]
    ymin, ymax = -offset[1], heatmap.shape[1] - offset[1]
    # imshow draws rows top-down: transpose to (y, x) and flip so y grows upwards.
    image_data = np.flip(heatmap.T, axis=0)
    # The alpha mask must match the displayed image's shape and orientation, not the raw array's.
    alpha = (image_data > 0).astype(np.float64) if use_alpha else None
    img = plt.imshow(image_data, cmap=cmap, interpolation=interpolation, alpha=alpha,
                     extent=(xmin, xmax, ymin, ymax))
    img.set_clim(0, 1)
    # hmap_res (notebook-level constant) pads the axes around the heatmap extent.
    plt.xlim(xmin - hmap_res, xmax + hmap_res)
    plt.ylim(ymin - hmap_res, ymax + hmap_res)
    plt.colorbar(img)
    plt.title(f"   {name:42s}")
    return img

def draw_items(map_id: str="map1"):
    """Overlay a map's walls and item locations on the current figure.

    Parameters
    ----------
    map_id : str
        Heatmap key such as ``"map1"``. An optional ``"rtss_"`` prefix is ignored. The run
        of digits starting at the fourth character is the map number that indexes the
        global ``map_defs``.

    Side effects
    ------------
    Draws wall edges in white. Draws one scatter per item category, skipping ``"player"``
    and empty categories, styled by ``things_marker_def`` / ``things_colour_def``. Adds a
    legend. For map types 1-3, adds a note that item positions are approximate. Map type 3
    is additionally cropped to a fixed window.

    Raises
    ------
    ValueError
        If no map number can be read from ``map_id``.
    """
    import re

    # Figure-relative (x, y) position of the disclaimer box, per map type.
    note_positions = {1: (-0.055, 0.59), 2: (-0.055, 0.63), 3: (-0.097, 0.56)}
    note_text = "Note that the position\nof items only roughly\nindicates their actual\nplacement in game."

    # Read every digit from index 3 on (the old code read one character, so maps >= 10 broke).
    digits = re.match(r"\d+", map_id.replace("rtss_", '')[3:])
    if digits is None:
        raise ValueError(f"cannot read a map number from map_id {map_id!r}")
    map_type = int(digits.group())

    things_dict, edges = map_defs[map_type]
    for xs, ys in edges:
        plt.plot(xs, ys, color='white')
    for category, coords in things_dict.items():
        # np.size also catches an empty 2xN coordinate array, whose len() would still be 2.
        if category == "player" or np.size(coords) == 0:
            continue
        plt.scatter(coords[0, :], coords[1, :], s=60, marker=things_marker_def[category],
                    c=things_colour_def[category], label=category, edgecolors="k")
    if map_type == 3:
        # Map 3 is shown in a fixed window instead of the heatmap-derived limits.
        plt.ylim(-500, 1000)
        plt.xlim(-870, 870)
    plt.legend(bbox_to_anchor=(0, 1.015))
    if map_type in note_positions:
        note_x, note_y = note_positions[map_type]
        plt.figtext(note_x, note_y, note_text, fontsize=8,
                    bbox={"facecolor": "white", "alpha": 0.95, "edgecolor": "gray"})

In [6]:
# Display label for each trained run (the model keys inside each map entry of all_results).
# Labels are space-padded so they line up in the monospace titles. On rtss maps the
# plotting loop replaces the "  SS" padding with "RTSS", which keeps the width unchanged.
# Not labelled (previously commented out): "s4_ss_5e-4" — PPO(4, SS) with halved lr.
Names = {
    "bot"               : " Best ZCajun Bot (built-in)",
    "r1_r_sl_ss_5e-4"   : "RPPO(1,   SS)",
    "r1_r_Dsl_ss_5e-4"  : "RPPO(1,   SS, double damage reward)",
    "ss_1e-3"           : " PPO(1,   SS)",
    "s4_ss_1e-3"        : " PPO(4,   SS)",
    "rgb_9e-4"          : " PPO(1, RGB)",
    "rgb_9e-4_a"        : " PPO(1, RGB, wall & floor altered)",
    "rgb_9e-4_w"        : " PPO(1, RGB, wall texture altered)",
    "ss_rgb_1e-3_best"  : " PPO(1, RGB+  SS)",
    "ss_rgb_1e-3_besta" : " PPO(1, RGB+  SS, wall & floor altered)",
    "ss_rgb_1e-3_bestw" : " PPO(1, RGB+  SS, wall texture altered)",
}

In [7]:
# Output directory for the rendered heatmaps; exist_ok makes re-running this cell a no-op.
os.makedirs("heatmaps", exist_ok=True)

In [None]:
# Render one heatmap + map overlay per (map, model) pair and save it under heatmaps/.
for map_id, map_dict in all_results.items():
    for model_key, model_dict in map_dict.items():
        if model_key not in Names:
            # Previously a KeyError aborted the whole loop. Unlabelled runs (presumably excluded
            # on purpose, like the entries removed from Names — confirm) are now skipped visibly.
            print(f"skipping {map_id}/{model_key}: no display name in Names")
            continue
        model_label = Names[model_key]
        if "rtss" in map_id:
            # "  SS" is padding wide enough to be swapped for "RTSS" without breaking alignment.
            model_label = model_label.replace("  SS", "RTSS")
        title = f"{map_id.replace('rtss_', '').rstrip('s').upper()}({model_label})"
        hmap, offset = model_dict['hmap']
        plt.clf()
        # Pass a float copy so plotting can never rescale the heatmap stored in all_results.
        display_hmap((np.array(hmap, dtype=np.float64), offset), title, cmap="hot", base_value=0.01)
        draw_items(map_id)
        plt.savefig(os.path.join("heatmaps", f"{title}.png"), bbox_inches='tight')