In [None]:
import os
from glob import glob

import matplotlib.gridspec as gridspec
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns
from tqdm import tqdm

# Apply seaborn's default theme to the matplotlib figures drawn in this notebook.
sns.set_theme()

# os.chdir("/app")

# (label, results directory) for every selection strategy compared side by side.
# The 4h campaign uses the same sub-directories under
# "outputs/wvcp_hard_ahead_hh_4h"; change the root below to plot it instead.
_outputs_root = "outputs/wvcp_hard_ahead_1h"
repertories = [
    (label, f"{_outputs_root}/{strategy_dir}")
    for label, strategy_dir in (
        ("Random", "random"),
        ("Roulette", "roulette_wheel"),
        ("Deleter", "deleter"),
        ("UCB", "ucb"),
        ("Pursuit", "pursuit"),
        ("NN", "neural_net_cross"),
    )
]

# NOTE(review): run_time (presumably seconds, i.e. 4h) does not match the 1h
# campaign selected above; neither run_time nor time_ls is referenced in the
# cells shown here — confirm before relying on them.
run_time = 3600 * 4
time_ls = 0.04


def get_nb_vertices(instance: str, path: str = "instances/instance_info_gcp.csv"):
    """Return the number of vertices of ``instance`` as listed in ``path``.

    ``path`` is a comma-separated file with five columns per line:
    ``instance,reduced,nb_vertices,<unused>,<unused>``. The first line whose
    first column equals ``instance`` is used; blank lines are skipped.

    Args:
        instance: instance name as written in the first column.
        path: info file to search (defaults to the GCP instance table).

    Returns:
        The vertex count as an ``int``, or ``None`` after printing a message
        when no line matches.
    """
    with open(path, "r", encoding="utf8") as info_file:
        for raw_line in info_file:
            # rstrip instead of [:-1]: the last line may lack a newline and
            # files may use CRLF endings.
            line = raw_line.rstrip("\r\n")
            if not line:
                continue
            instance_, _reduced, nb_vertices, _, _ = line.split(",")
            # Select on the instance name only. The previous test
            # (`instance_ != instance and reduced != "true"`) also accepted any
            # earlier row flagged reduced == "true", returning another
            # instance's vertex count.
            if instance_ == instance:
                return int(nb_vertices)
    print(f"instance {instance} not found in {path}")
    return None


# Output directory for saved figures (race-free create-if-missing).
os.makedirs("plots", exist_ok=True)


# One instance name per line. Blank lines (e.g. a trailing empty line) are
# skipped so the plotting loop never processes an empty instance name.
with open("DIMACS_hard.txt", "r", encoding="utf8") as f:
    instances = [line.strip() for line in f if line.strip()]

# Per-instance cap on the number of iterations plotted. 99999 acts as "no cap":
# the plotting loop lowers it to the shortest run it reads. A value of -1 makes
# the loop adopt the length of the first run read instead.
nb_iter = {instance: 99999 for instance in instances}

# Operator labels; position i is operator index i in the "selected" column.
operators = [
    "GPX-50% + ILS-TS",
    "GPX-50% + RedLS",
    "GPX-75% + ILS-TS",
    "GPX-75% + RedLS",
    "GPX-90% + ILS-TS",
    "GPX-90% + RedLS",
]
# Alternate Blues (ILS-TS) / Reds (RedLS), one shade per GPX rate, so that
# colors[i] matches operators[i].
_blues = sns.color_palette("Blues", 3)
_reds = sns.color_palette("Reds", 3)
colors = [shade for pair in zip(_blues, _reds) for shade in pair]

# Number of independent runs per (instance, strategy): files tbt/{instance}_{0..NB_RUNS-1}.csv
NB_RUNS = 20


def _cumulative_selections(file_tbt: str, nb_operators: int) -> np.ndarray:
    """Return a (turns, nb_operators) array of cumulative selection counts.

    Row t holds, for each operator index, how many times it appears in the
    ``selected`` column over turns 0..t. A cell may list several operator
    indices separated by ":"; indices outside ``range(nb_operators)`` are
    ignored.
    """
    # dtype=str: a column containing only single indices would otherwise be
    # parsed as integers and ``.split`` would fail.
    data_tbt = pd.read_csv(file_tbt, comment="#", dtype={"selected": str})
    per_turn = np.zeros((len(data_tbt), nb_operators), dtype=int)
    for turn, selected in enumerate(data_tbt.selected):
        for select in selected.split(":"):
            operator_idx = int(select)
            if 0 <= operator_idx < nb_operators:
                per_turn[turn, operator_idx] += 1
    return np.cumsum(per_turn, axis=0)


def _mean_half_std(curves: list, nb_turns: int):
    """Mean and half standard deviation across runs of the first ``nb_turns`` rows.

    Every curve must have at least ``nb_turns`` rows. The result arrays have
    shape (nb_turns, nb_operators).
    """
    stacked = np.stack([curve[:nb_turns] for curve in curves])
    return stacked.mean(axis=0), stacked.std(axis=0) / 2


# One figure per instance, one subplot per selection strategy: mean cumulative
# number of selections of each operator per iteration, shaded by +/- std/2.
for instance, nb_turns in nb_iter.items():
    print(instance)
    fig = plt.figure(tight_layout=True, figsize=(20, 5))
    gs = gridspec.GridSpec(1, len(repertories))
    for n_criteria, (criteria, path_c) in enumerate(repertories):
        runs = []
        for rd in range(NB_RUNS):
            file_tbt = f"{path_c}/tbt/{instance}_{rd}.csv"
            # Plain existence test: glob() would treat "[", "*", "?" in
            # instance names as patterns.
            if not os.path.exists(file_tbt):
                print(f"no file for {file_tbt}")
                continue
            curve = _cumulative_selections(file_tbt, len(operators))
            # Truncate every curve to the shortest run seen so far, so the
            # per-iteration mean always averages over all runs.
            if nb_turns == -1:
                nb_turns = len(curve)
            elif nb_turns > len(curve):
                print(
                    f"error {instance} {criteria} nb selected got : {len(curve)} before : {nb_turns}"
                )
                nb_turns = len(curve)
            runs.append(curve)

        ax = fig.add_subplot(gs[0, n_criteria])
        ax.set_title(f"{instance} - {criteria}")
        if not runs:
            # Nothing to average; avoids plotting nb_turns (possibly 99999) NaNs.
            continue

        mean, half_std = _mean_half_std(runs, nb_turns)
        # Row t counts selections over iterations 1..t+1.
        x = np.arange(1, nb_turns + 1)
        for i, o in enumerate(operators):
            ax.fill_between(
                x,
                mean[:, i] - half_std[:, i],
                mean[:, i] + half_std[:, i],
                color=colors[i],
                alpha=0.5,
            )
            ax.plot(x, mean[:, i], "-", label=o, color=colors[i])

        ax.set_xlim(0, nb_turns)
        ax.set_ylim(0)
        if n_criteria == 0:
            ax.legend(loc="upper left")
            ax.set_ylabel("cumulative selections")
        ax.set_xlabel("Iterations")

    # plt.savefig(f"plots/{instance}_head_gcp.png")
    plt.show()
    plt.close(fig)


In [None]:
# Reference: registered matplotlib colormap names, usable e.g. with sns.color_palette(name, n):
# 'Accent', 'Accent_r', 'Blues', 'Blues_r', 'BrBG', 'BrBG_r', 'BuGn', 'BuGn_r', 'BuPu', 'BuPu_r', 'CMRmap', 'CMRmap_r', 'Dark2', 'Dark2_r', 'GnBu', 'GnBu_r', 'Greens', 'Greens_r', 'Greys', 'Greys_r', 'OrRd', 'OrRd_r', 'Oranges', 'Oranges_r', 'PRGn', 'PRGn_r', 'Paired', 'Paired_r', 'Pastel1', 'Pastel1_r', 'Pastel2', 'Pastel2_r', 'PiYG', 'PiYG_r', 'PuBu', 'PuBuGn', 'PuBuGn_r', 'PuBu_r', 'PuOr', 'PuOr_r', 'PuRd', 'PuRd_r', 'Purples', 'Purples_r', 'RdBu', 'RdBu_r', 'RdGy', 'RdGy_r', 'RdPu', 'RdPu_r', 'RdYlBu', 'RdYlBu_r', 'RdYlGn', 'RdYlGn_r', 'Reds', 'Reds_r', 'Set1', 'Set1_r', 'Set2', 'Set2_r', 'Set3', 'Set3_r', 'Spectral', 'Spectral_r', 'Wistia', 'Wistia_r', 'YlGn', 'YlGnBu', 'YlGnBu_r', 'YlGn_r', 'YlOrBr', 'YlOrBr_r', 'YlOrRd', 'YlOrRd_r', 'afmhot', 'afmhot_r', 'autumn', 'autumn_r', 'binary', 'binary_r', 'bone', 'bone_r', 'brg', 'brg_r', 'bwr', 'bwr_r', 'cividis', 'cividis_r', 'cool', 'cool_r', 'coolwarm', 'coolwarm_r', 'copper', 'copper_r', 'crest', 'crest_r', 'cubehelix', 'cubehelix_r', 'flag', 'flag_r', 'flare', 'flare_r', 'gist_earth', 'gist_earth_r', 'gist_gray', 'gist_gray_r', 'gist_heat', 'gist_heat_r', 'gist_ncar', 'gist_ncar_r', 'gist_rainbow', 'gist_rainbow_r', 'gist_stern', 'gist_stern_r', 'gist_yarg', 'gist_yarg_r', 'gnuplot', 'gnuplot2', 'gnuplot2_r', 'gnuplot_r', 'gray', 'gray_r', 'hot', 'hot_r', 'hsv', 'hsv_r', 'icefire', 'icefire_r', 'inferno', 'inferno_r', 'jet', 'jet_r', 'magma', 'magma_r', 'mako', 'mako_r', 'nipy_spectral', 'nipy_spectral_r', 'ocean', 'ocean_r', 'pink', 'pink_r', 'plasma', 'plasma_r', 'prism', 'prism_r', 'rainbow', 'rainbow_r', 'rocket', 'rocket_r', 'seismic', 'seismic_r', 'spring', 'spring_r', 'summer', 'summer_r', 'tab10', 'tab10_r', 'tab20', 'tab20_r', 'tab20b', 'tab20b_r', 'tab20c', 'tab20c_r', 'terrain', 'terrain_r', 'turbo', 'turbo_r', 'twilight', 'twilight_r', 'twilight_shifted', 'twilight_shifted_r', 'viridis', 'viridis_r', 'vlag', 'vlag_r', 'winter', 'winter_r'