In [3]:
import datetime
import os

import matplotlib.pyplot as plt
import numpy as np
import scipy
import scipy.stats
import seaborn as sns
from matplotlib.colors import ListedColormap
from scipy.ndimage import gaussian_filter


def plot_heat(heatmap, sigma, name, vmin=0, vmax=150):
    """Smooth a 2-D histogram, draw it as a heatmap, save it as SVG and display it.

    Parameters
    ----------
    heatmap : numpy.ndarray
        2-D array of counts. It is transposed before drawing, so its first axis
        runs along X (the layout ``np.histogram2d`` returns).
    sigma : float
        Standard deviation passed to ``scipy.ndimage.gaussian_filter``.
    name : str
        Label used in the printed caption and in the output file name.
    vmin, vmax : float, optional
        Colour-scale limits applied to the smoothed values. The defaults are
        the previously hard-coded 0 and 150; values above ``vmax`` saturate.

    The figure is written to ``<cwd>/plots/plot_<name> - <timestamp>.svg``.
    After being shown it is closed, so repeated calls in a loop do not
    accumulate open figures.
    """
    fig, ax = plt.subplots()
    magma = sns.color_palette("magma", as_cmap=True)
    newcmp = ListedColormap(magma(np.linspace(0, 1, 256)))

    smoothed = gaussian_filter(heatmap, sigma=sigma)
    extent = [-0.5, 0.5, -0.5, 0.5]
    # NOTE: this changes the global matplotlib rcParams for every later figure too.
    plt.rcParams.update({'font.size': 20})
    image = ax.imshow(smoothed.T, extent=extent, origin='lower', vmin=vmin, vmax=vmax,
                      cmap=newcmp, rasterized=True)
    ax.tick_params(labelsize=20)

    ax.set_xlabel("X", fontsize=20)
    ax.set_ylabel("Y", fontsize=20)

    ticks = np.arange(-0.5, 0.6, 0.5)  # -0.5, 0.0, 0.5
    ax.set_xticks(ticks)
    ax.set_yticks(ticks)

    fig.colorbar(image, ax=ax)
    ax.set_aspect('equal', adjustable='box')
    timestamp = datetime.datetime.now().strftime("%Y_%m_%d_%H_%M_%S")

    plot_dir = os.path.join(os.getcwd(), 'plots')
    os.makedirs(plot_dir, exist_ok=True)
    # dpi matters even for SVG: the image itself is rasterized (rasterized=True).
    fig.savefig(os.path.join(plot_dir, 'plot_{} - {}.svg'.format(name, timestamp)), dpi=2000, bbox_inches="tight", format="svg")
    print(f"{name} plot:")
    plt.show()
    # fig.clf() only emptied the figure; closing releases it from pyplot's registry.
    plt.close(fig)

def select_first(x, y, steps, max):
    """Keep only the samples whose step value is strictly below ``max``.

    Parameters
    ----------
    x, y, steps : array-like of equal length
        Parallel per-sample values. ``steps`` is compared against ``max``.
    max : number
        Exclusive upper bound on ``steps``. The name shadows the builtin but is
        kept so existing keyword callers keep working.

    Returns
    -------
    tuple of numpy.ndarray
        ``(x, y, steps)``, all filtered with the same boolean mask.
    """
    steps = np.asarray(steps)
    # One vectorised comparison instead of re-evaluating it per element, three times.
    keep = steps < max
    return np.asarray(x)[keep], np.asarray(y)[keep], steps[keep]

def _load_stacked_npy(path):
    """Load every array stored back-to-back in one ``.npy`` file and stack them row-wise.

    ``np.load`` is called repeatedly until the end of the file, so the file may
    hold several consecutive ``np.save`` payloads. The chunks are collected in a
    list and stacked once, instead of re-stacking the growing array per chunk.
    A lone 1-D chunk comes back as a single row, because ``np.vstack`` promotes it.
    """
    with open(path, "rb") as f:
        file_size = os.fstat(f.fileno()).st_size
        chunks = [np.load(f)]
        while f.tell() < file_size:
            chunks.append(np.load(f))
    return np.vstack(chunks)


def _per_step_histograms(x, y, steps, max_step, bins_size):
    """Yield one 2-D position histogram over [-0.5, 0.5]^2 for each step 0..max_step-1.

    Steps are truncated to ``int`` (as ``int(step)`` does) and grouped with a
    single stable sort. Samples whose truncated step lies outside
    ``[0, max_step)`` are not counted. Positions outside the square are dropped
    by ``np.histogram2d``.
    """
    x = np.asarray(x)
    y = np.asarray(y)
    step_idx = np.asarray(steps).astype(int)
    order = np.argsort(step_idx, kind="stable")
    # order[bounds[i]:bounds[i + 1]] are the positions of the samples with step index i.
    bounds = np.searchsorted(step_idx[order], np.arange(max_step + 1))
    for i in range(max_step):
        sel = order[bounds[i]:bounds[i + 1]]
        heat, _, _ = np.histogram2d(x[sel], y[sel], bins=bins_size,
                                    range=[[-0.5, 0.5], [-0.5, 0.5]], density=False)
        yield heat


def process_plot(env, dir_pro, dir_ant, max_step=160, sigma=3, bins_size=None, individual_diff=False, plot_h_pro=False, plot_h_ant=True, plot_h_diff=True):
    """Compare the "pro" and "ant" position heatmaps of two run directories.

    Reads ``pro_pos.npy`` from ``<cwd>/results/models/custom/<env>/<dir_pro>``
    and ``ant_pos.npy`` from ``.../<env>/<dir_ant>``. Columns 0, 1 and 2 are
    used as x, y and step. Samples with ``step < max_step`` are histogrammed
    per step over [-0.5, 0.5]^2 and then summed.

    Parameters
    ----------
    env : str
        Environment sub-directory name.
    dir_pro, dir_ant : str
        Run directories for the "pro" and "ant" position files (may be equal).
    max_step : int
        Exclusive upper bound on the step column.
    sigma : float
        Gaussian smoothing used only when plotting; the returned metric is unsmoothed.
    bins_size : int or sequence of two ints, optional
        Histogram bins per axis; defaults to ``[256, 256]``.
    individual_diff : bool
        If True, the difference is the sum over steps of ``|pro_step - ant_step|``.
        If False, it is ``|sum(pro) - sum(ant)|`` of the aggregated maps.
        This now applies regardless of ``plot_h_diff``.
    plot_h_pro, plot_h_ant, plot_h_diff : bool
        Which heatmaps to plot via ``plot_heat``. ``plot_h_diff`` also prints the
        summary statistics.

    Returns
    -------
    float
        ``(overall - difference) / overall``, where ``overall`` is the total
        number of binned pro and ant samples. Returns ``nan`` when nothing was binned.
    """
    base_dir = os.path.join(os.getcwd(), "results", "models", "custom", env)
    pro_pos = _load_stacked_npy(os.path.join(base_dir, dir_pro, "pro_pos.npy"))
    ant_pos = _load_stacked_npy(os.path.join(base_dir, dir_ant, "ant_pos.npy"))

    ant_x, ant_y, ant_step = select_first(ant_pos[:, 0], ant_pos[:, 1], ant_pos[:, 2], max_step)
    pro_x, pro_y, pro_step = select_first(pro_pos[:, 0], pro_pos[:, 1], pro_pos[:, 2], max_step)

    if not bins_size:
        bins_size = [256, 256]
    elif np.isscalar(bins_size):
        bins_size = [bins_size, bins_size]
    shape = (bins_size[0], bins_size[1])

    # Accumulate on the fly rather than keeping max_step full-size heatmaps in memory.
    heatmap_plot_pro = np.zeros(shape)
    heatmap_plot_ant = np.zeros(shape)
    heatmap_plot_diff = np.zeros(shape)
    per_step = zip(
        _per_step_histograms(pro_x, pro_y, pro_step, max_step, bins_size),
        _per_step_histograms(ant_x, ant_y, ant_step, max_step, bins_size),
    )
    for heatmap_pro, heatmap_ant in per_step:
        heatmap_plot_pro += heatmap_pro
        heatmap_plot_ant += heatmap_ant
        if individual_diff:
            heatmap_plot_diff += np.abs(heatmap_pro - heatmap_ant)

    overall = np.sum(heatmap_plot_pro) + np.sum(heatmap_plot_ant)
    if not individual_diff:
        # Previously this only ran when plot_h_diff was True, so the return value
        # silently switched to the per-step difference when plotting was disabled.
        heatmap_plot_diff = np.abs(heatmap_plot_pro - heatmap_plot_ant)
    difference = np.sum(heatmap_plot_diff)
    same = overall - difference
    percentage_same = same / overall if overall else float("nan")

    if plot_h_pro:
        plot_heat(heatmap_plot_pro, sigma, "pro")
    if plot_h_ant:
        plot_heat(heatmap_plot_ant, sigma, "ant")
    if plot_h_diff:
        print("\t ========================")
        print("\t Same: ", same)
        print("\t Difference: ", difference)
        print("\t Overall: ", overall)
        print("\t Percentage Same: ", percentage_same)
        print("\t ========================")
        plot_heat(heatmap_plot_diff, sigma, "diff")
    return percentage_same

In [159]:
env_name = "3s_vs_3z"

# Run IDs for this map. Every condition lists six algorithms in the order
# qtran, qmix, qplex, vdn, iql, ow_qmix, each owning five consecutive IDs,
# so algorithm k of a condition starts at first_id + 5 * k.

# algo: 2524-2553
(algo_qtran_eval, algo_qmix_eval, algo_qplex_eval,
 algo_vdn_eval, algo_iql_eval, algo_ow_qmix_eval) = [
    list(range(first, first + 5)) for first in range(2524, 2554, 5)
]

# hidden: 2797-2826
(h_qtran_eval, h_qmix_eval, h_qplex_eval,
 h_vdn_eval, h_iql_eval, h_ow_qmix_eval) = [
    list(range(first, first + 5)) for first in range(2797, 2827, 5)
]

# obvious: 2827-2856
(obv_qtran_eval, obv_qmix_eval, obv_qplex_eval,
 obv_vdn_eval, obv_iql_eval, obv_ow_qmix_eval) = [
    list(range(first, first + 5)) for first in range(2827, 2857, 5)
]

# random: 3147-3176
(random_qtran, random_qmix, random_qplex,
 random_vdn, random_iql, random_ow_qmix) = [
    list(range(first, first + 5)) for first in range(3147, 3177, 5)
]

In [4]:
env_name = "5m_vs_6m"

# Run IDs for this map. Every condition lists six algorithms in the order
# qtran, qmix, qplex, vdn, iql, ow_qmix, each owning five consecutive IDs,
# so algorithm k of a condition starts at first_id + 5 * k.

# algo: 2494-2523
(algo_qtran_eval, algo_qmix_eval, algo_qplex_eval,
 algo_vdn_eval, algo_iql_eval, algo_ow_qmix_eval) = [
    list(range(first, first + 5)) for first in range(2494, 2524, 5)
]

# hidden: 2857-2886
(h_qtran_eval, h_qmix_eval, h_qplex_eval,
 h_vdn_eval, h_iql_eval, h_ow_qmix_eval) = [
    list(range(first, first + 5)) for first in range(2857, 2887, 5)
]

# obvious: 2887-2916
(obv_qtran_eval, obv_qmix_eval, obv_qplex_eval,
 obv_vdn_eval, obv_iql_eval, obv_ow_qmix_eval) = [
    list(range(first, first + 5)) for first in range(2887, 2917, 5)
]

# random: 2350-2379
(random_qtran, random_qmix, random_qplex,
 random_vdn, random_iql, random_ow_qmix) = [
    list(range(first, first + 5)) for first in range(2350, 2380, 5)
]

In [None]:
env_name = "MMM"

# Run IDs for this map. Algorithms are listed in the order
# qtran, qmix, qplex, vdn, iql, ow_qmix; most own five consecutive IDs.

# algo: 3040-3069
(algo_qtran_eval, algo_qmix_eval, algo_qplex_eval,
 algo_vdn_eval, algo_iql_eval, algo_ow_qmix_eval) = [
    list(range(first, first + 5)) for first in range(3040, 3070, 5)
]

# hidden: not contiguous. qmix skips 3079 and uses 3085 instead.
# NOTE(review): presumably 3085 is a re-run replacing 3079; confirm.
h_qtran_eval = list(range(3070, 3075))
h_qmix_eval = [3075, 3076, 3077, 3078, 3085]
h_qplex_eval = list(range(3080, 3085))
h_vdn_eval = list(range(3086, 3091))
h_iql_eval = list(range(3091, 3096))
h_ow_qmix_eval = list(range(3096, 3101))

# obvious: 3101-3130
(obv_qtran_eval, obv_qmix_eval, obv_qplex_eval,
 obv_vdn_eval, obv_iql_eval, obv_ow_qmix_eval) = [
    list(range(first, first + 5)) for first in range(3101, 3131, 5)
]

# random: not contiguous. qtran uses 3010-3013 plus 3019; qmix uses 3014-3018.
# NOTE(review): confirm this interleaving is intended.
random_qtran = [3010, 3011, 3012, 3013, 3019]
random_qmix = list(range(3014, 3019))
random_qplex = list(range(3020, 3025))
random_vdn = list(range(3025, 3030))
random_iql = list(range(3030, 3035))
random_ow_qmix = list(range(3035, 3040))

In [None]:

# Compare the "pro" positions of one reference run with the "ant" positions
# of every run in eval_list. With the current selection, both come from the
# same run directory.
#
# Other selections used during exploration:
#   obvious: eval_list = [obv_qplex_eval[2]]   or   eval_list = obv_qplex_eval
#   random:  eval_list = [random_qplex[1]]
eval_list = [h_qplex_eval[4]]  # hidden condition

max_step, bins_size, sigma = 160, [128, 128], 1

dir_name_pro = f"AA-{h_qplex_eval[4]}"

for eval_i, eval_e in enumerate(eval_list):
    dir_name_ant = f"AA-{eval_e}"

    # Separator, blank line, run index, blank line.
    print("========================", "", eval_i, "", sep="\n")

    percentage = process_plot(
        env_name,
        dir_name_pro,
        dir_name_ant,
        max_step=max_step,
        sigma=sigma,
        bins_size=bins_size,
        individual_diff=False,
        plot_h_pro=True,
    )

In [None]:


# For every algorithm in the chosen condition, compare the "pro" and "ant"
# position heatmaps stored in each evaluation run directory. Print the
# "Percentage Same" metric as a LaTeX "$mean \pm 95% CI$" over that algorithm's runs.

algo_runs = [algo_qtran_eval, algo_qmix_eval, algo_qplex_eval, algo_vdn_eval, algo_iql_eval, algo_ow_qmix_eval]
h_runs = [h_qtran_eval, h_qmix_eval, h_qplex_eval, h_vdn_eval, h_iql_eval, h_ow_qmix_eval]
obv_runs = [obv_qtran_eval, obv_qmix_eval, obv_qplex_eval, obv_vdn_eval, obv_iql_eval, obv_ow_qmix_eval]
random_runs = [random_qtran, random_qmix, random_qplex, random_vdn, random_iql, random_ow_qmix]

# Condition to summarise: one of algo_runs, h_runs, obv_runs, random_runs.
# Output lines follow the algorithm order qtran, qmix, qplex, vdn, iql, ow_qmix.
runs = h_runs

max_step = 1000
bins_size = [128, 128]
sigma = 1


def m_ci(data, confidence=0.95):
    """Return the half-width of the two-sided Student-t confidence interval for the mean of ``data``."""
    return scipy.stats.sem(data) * scipy.stats.t.ppf((1 + confidence) / 2., len(data) - 1)


for run in runs:
    eval_list = run
    percentages = []

    for eval_i, eval_e in enumerate(eval_list):
        # Both the "pro" and the "ant" positions are read from the same run directory.
        dir_name_ant = f"AA-{eval_e}"
        percentage = process_plot(env_name, dir_name_ant, dir_name_ant, max_step=max_step, sigma=sigma, bins_size=bins_size, individual_diff=False, plot_h_pro=True)
        percentages.append(percentage)

    percentages = np.array(percentages)
    # Raw f-string: "\pm" is LaTeX, not a Python escape. The old non-raw "\p" is an
    # invalid escape (SyntaxWarning on 3.12+). Format with .2f once; the extra
    # round() beforehand was redundant.
    print(rf"${percentages.mean():.2f} \pm {m_ci(percentages):.2f}$")