In [1]:
from typing import Callable

import re
import os
import glob
import shutil
import subprocess
from functools import partial

import numpy as np
from PIL import Image
from tqdm.notebook import tqdm
import matplotlib.pyplot as plt
from matplotlib.pyplot import cm
from scipy.signal import lfilter

from parse_data.helper_func import load_merged_records

In [2]:
def agent_heatmap(pos: np.ndarray, width: int, height: int, radius: int=10, intensity: str="linear"):
    """Build a 2-D visitation heatmap by stamping a disc around every agent position.

    Each row of ``pos`` contributes one disc of radius ``radius`` (in grid cells),
    centred on ``round(pos[i, :2] + offset)``. Inside the disc a cell at Euclidean
    distance ``d`` from the centre receives ``radius - d`` ("linear", default) or
    ``(radius - d) ** 2`` (selected by ``intensity="exponential"``; despite the name
    this is a quadratic fall-off). Every visit contributes the same weighted disc,
    so a cell visited ``n`` times receives ``n`` times the stencil.

    Args:
        pos: array with at least two columns; column 0 is used as x, column 1 as y.
        width: map extent along x (first axis of the returned heatmap).
        height: map extent along y (second axis of the returned heatmap).
        radius: disc radius in grid cells.
        intensity: "exponential" (case-insensitive) for squared fall-off; any other
            value gives linear fall-off.

    Returns:
        ``(hmap, offset)``. ``hmap`` is a float64 array indexed ``hmap[x_idx, y_idx]``.
        ``offset`` is a float64 array ``[offset_x, offset_y]`` that was added to the
        positions before rounding them to indices.

    NOTE(review): positions far outside ``[-width/2, width/2] x [-height/2, height/2]``
    are not clipped. They raise IndexError, or wrap around via negative indices.
    """
    if intensity.lower() == "exponential":
        def f_int(z): return pow(max(radius - z, 0), 2)
    else:
        def f_int(z): return max(radius - z, 0)

    # The first axis is indexed by x (pos[:, 0]), so it must span `width`; the second spans `height`.
    # The extra 2 * radius margin keeps a full disc in bounds for positions near the map edge.
    half_x = width // 2 + 1 + 2 * radius
    half_y = height // 2 + 1 + 2 * radius
    hmap = np.zeros((half_x * 2, half_y * 2), dtype=np.float64)
    # np.asarray with an explicit dtype replaces np.asfarray, which was removed in NumPy 2.0.
    offset = np.asarray([half_x, half_y], dtype=np.float64)

    pos = np.asarray(pos)
    if pos.size == 0:
        return (hmap, offset)

    # Disc stencil: integer cell offsets (dx, dy) within `radius`, plus the weight of each.
    dxs, dys, weights = [], [], []
    for i in range(-radius, radius + 1):
        for j in range(-radius, radius + 1):
            if (z := np.linalg.norm([i, j])) <= radius:
                dxs.append(i)
                dys.append(j)
                weights.append(f_int(z))
    stencil_x = np.array(dxs, dtype=np.int64)
    stencil_y = np.array(dys, dtype=np.int64)
    stencil_w = np.array(weights, dtype=np.float64)

    # Count visits per rounded centre cell, then stamp each distinct cell once, scaled by its count.
    # Every visit therefore adds the full weighted disc, not just the first one.
    cells = np.round(pos[:, :2] + offset).astype(int)
    visits = {}
    for cell in map(tuple, cells.tolist()):
        visits[cell] = visits.get(cell, 0) + 1
    for (cx, cy), count in visits.items():
        hmap[stencil_x + cx, stencil_y + cy] += count * stencil_w

    return (hmap, offset)

In [3]:
# Header lines of UDMF blocks in the TEXTMAP files, e.g. "thing // 12".
# group(1) captures the numeric index that follows the "//".
things_pattern = re.compile(r"thing // ([0-9]+)")
vertex_pattern = re.compile(r"vertex // ([0-9]+)")
linedef_pattern = re.compile(r"linedef // ([0-9]+)")

# Category name -> thing type IDs belonging to that category.
# The key order is the iteration order used when grouping things per map.
things_id_map = dict([
    ("ammo", [2007, 2008, 2049]),
    ("weapon", [2001]),
    ("medkit", [2012]),
    ("spawn point", [11]),
    ("player", [1]),
])

# Category name -> matplotlib marker code.
things_marker_def = dict([
    ("ammo", "s"),
    ("weapon", "X"),
    ("medkit", "P"),
    ("spawn point", "o"),
    ("player", "v"),
])

# Category name -> matplotlib colour.
things_colour_def = dict([
    ("ammo", "cyan"),
    ("weapon", "yellow"),
    ("medkit", "orange"),
    ("spawn point", "white"),
    ("player", "k"),
])

In [4]:
# Matches a UDMF assignment such as `x = 32.000;` or `v1 = 7;`, capturing (key, raw value).
udmf_assign_pattern = re.compile(r"(\w+)\s*=\s*(.*?)\s*;")


def _parse_udmf_number(text: str):
    """Convert a UDMF numeric literal to int when it parses as one, otherwise to float."""
    try:
        return int(text)
    except ValueError:
        return float(text)


def _read_udmf_block(f) -> dict:
    """Consume one `{ ... }` block from the open file ``f``; return raw ``key -> value`` strings.

    The block's header line (e.g. ``thing // 0``) must already have been read.
    Lines up to the opening ``{`` are skipped. Blank lines inside the block are
    ignored. Reading stops after the closing ``}`` line, or at end of file.
    """
    fields = {}
    for raw in iter(f.readline, ""):
        if raw.lstrip().startswith("{"):
            break
    else:
        return fields  # EOF before the block opened
    for raw in iter(f.readline, ""):
        stripped = raw.strip()
        if stripped.startswith("}"):
            break
        if assign := udmf_assign_pattern.match(stripped):
            fields[assign.group(1)] = assign.group(2)
    return fields


def _udmf_number(fields: dict, key: str):
    """Return the numeric value of ``key`` in a parsed block, or None when the key is absent."""
    return _parse_udmf_number(fields[key]) if key in fields else None


def parse_textmap(path: str):
    """Parse things, vertices and linedefs from a UDMF TEXTMAP file.

    Values are parsed as numbers; file content is never executed.

    Returns:
        things: list of ``(x, y, type)`` tuples; a missing field is None.
        verts:  dict mapping vertex index -> ``(x, y)``.
        edges:  list of ``((x1, x2), (y1, y2))``, one per linedef.
    """
    things, verts, edges = [], {}, []
    with open(path, "r", encoding="utf-8") as f:
        while line := f.readline():
            line = line.lstrip()
            if things_pattern.match(line):
                fields = _read_udmf_block(f)
                things.append((_udmf_number(fields, "x"), _udmf_number(fields, "y"), _udmf_number(fields, "type")))
            elif match := vertex_pattern.match(line):
                fields = _read_udmf_block(f)
                verts[int(match.group(1))] = (_udmf_number(fields, "x"), _udmf_number(fields, "y"))
            elif linedef_pattern.match(line):
                fields = _read_udmf_block(f)
                # Assumes every vertex appears in the file before any linedef that references it.
                v1x, v1y = verts[_udmf_number(fields, "v1")]
                v2x, v2y = verts[_udmf_number(fields, "v2")]
                edges.append(((v1x, v2x), (v1y, v2y)))
    return things, verts, edges


# Invert category -> IDs into ID -> category. It is loop-invariant, so it is built once.
id_things_map = {tid: category for category, ids in things_id_map.items() for tid in ids}

# map_defs[map_id] = (things_dict, edges).
# things_dict maps each category to an array of shape (2, n_things), transposed from the [x, y] rows.
map_defs = dict()
for map_id in range(1, 3 + 1):
    things, verts, edges = parse_textmap(f"./scenarios/TEXTMAP{map_id}.txt")
    things_dict = {category: [] for category in things_id_map.keys()}
    for tx, ty, tid in things:
        # An unknown thing type raises KeyError here; extend things_id_map if that happens.
        things_dict[id_things_map[tid]].append([tx, ty])
    things_dict = {k: np.array(v).T for k, v in things_dict.items()}

    map_defs[map_id] = (things_dict, edges)

In [5]:
# Model run names (after stripping "ppo_"/"_final" and surrounding underscores) to evaluate.
names_to_eval = {"r1_r_sl_ss_5e-4", "r1_r_Dsl_ss_5e-4", "ss_1e-3", "s4_ss_1e-3", "rgb_9e-4", "ss_rgb_1e-3_best"}
# One (initially empty) results dict per log directory whose name contains "map".
all_results = {os.path.basename(log_dir): {} for log_dir in glob.glob("logs/*map*")}

In [6]:
# Radius (in grid cells) of the disc stamped per position, passed to agent_heatmap.
hmap_res = 10
# Longest status message printed so far. Used to pad the '\r' status line so a shorter
# message fully overwrites a longer previous one instead of leaving trailing residue.
last_status_len = 0
for map_name, map_dict in all_results.items():
    all_models = [os.path.basename(m) for m in glob.glob(f"logs/{map_name}/*")]
    for m in all_models:
        m_cleaned = m.strip('_').replace("ppo_", '').replace("_final", '')
        if m_cleaned not in names_to_eval:
            continue
        path_ = f"logs/{map_name}/{m}/record_*.npz"
        if not glob.glob(path_):
            continue
        _, pos, _, wpn, fps, miou = load_merged_records(path_, load_obs=False, no_tqdm=True)
        status = f"Generating heatmap: {map_name}/{m_cleaned}"
        print(status.ljust(last_status_len), end='\r')
        last_status_len = max(last_status_len, len(status))
        size_width, size_height = (2000, 2000) if "map3" in map_name else (1440, 1440)
        hmap = agent_heatmap(pos, size_width, size_height, hmap_res, "linear")
        map_dict[m_cleaned] = dict(pos=pos, wpn=wpn, fps=fps, miou=miou, hmap=hmap)
# Terminate the '\r' status line so later output starts on a fresh line.
print()

Generating heatmap: rtss_map3/s4_ss_1e-3_5e-4t4

In [7]:
# Print mean +/- std MIoU for every evaluated model on every map.
for map_name, map_dict in all_results.items():
    for model, model_dict in map_dict.items():
        miou = model_dict["miou"]
        # Only entries whose miou is a NumPy array are summarised; others are skipped.
        if not isinstance(miou, np.ndarray):
            continue
        # map_name[5:] drops a fixed 5-character prefix (e.g. "rtss_" in "rtss_map3").
        print(f"Map: {map_name[5:]:5s} | Model: {model:26s} | MIoU : {miou.mean():5.3f} +/- {miou.std():6.4f}")

Map: map1  | Model: r1_r_sl_ss_5e-4            | MIoU : 0.993 +/- 0.0008
Map: map1  | Model: s4_ss_1e-3                 | MIoU : 0.994 +/- 0.0004
Map: map1a | Model: ss_rgb_1e-3_best           | MIoU : 0.982 +/- 0.0009
Map: map1w | Model: ss_rgb_1e-3_best           | MIoU : 0.983 +/- 0.0008
Map: map2s | Model: r1_r_Dsl_ss_5e-4           | MIoU : 0.982 +/- 0.0006
Map: map3  | Model: r1_r_sl_ss_5e-4            | MIoU : 0.985 +/- 0.0019
Map: map3  | Model: s4_ss_1e-3                 | MIoU : 0.982 +/- 0.0026


In [8]:
import pickle

# Persist the computed heatmaps and the parsed map geometry.
# Only unpickle these files from trusted sources: pickle.load can execute arbitrary code.
for out_path, payload in (("heatmaps.pkl", all_results), ("mapdefs.pkl", map_defs)):
    with open(out_path, "wb") as out_file:
        pickle.dump(payload, out_file)