In [None]:
from dotmap import DotMap
import json
import pandas as pd
import numpy as np
import matplotlib.patches as mpatches
import seaborn as sns
import neptune.new as neptune
import os

"""
High-quality comparison plots of smoothed training returns fetched from Neptune, for evaluation.

author(s): Arnold Unterauer
"""

# --- Plot configuration ------------------------------------------------------
# One inner list per curve: the Neptune run numbers (fetched as "UNITYML-<id>")
# whose returns are pooled together into that curve.
runs = [
    [644, 645, 646],
    [650, 651, 652],
    [653, 654, 655],
    [680, 681, 682],
]
# Legend entry for each curve, in the same order as ``runs``.
labels = ["Tanh", "ELU", "ReLU", "Mish"]
# Legend title, also used in the output file name. (Spelling kept: later
# cells reference this exact name.)
titel = "Activation"

# Visible axis windows as [min, max].
x_range = [-100, 3100]
y_range = [-10, 300]
# X tick positions: start, exclusive stop and step for np.arange.
x_ticks = [0, 3001]
x_tick_steps = 1000

legend_location = "upper left"
# Hex colours handed out to the curves in order.
flatui = [
    "#9b59b6",
    "#3498db",
    "#95a5a6",
    "#e74c3c",
    "#34495e",
    "#2ecc71",
]


# Neptune credentials, read from neptune_auth.json in the working directory.
# The fields read later are ``project`` and ``api_token``; DotMap gives them
# attribute-style access (neptune_auth.project).
# JSON is UTF-8 by specification, so decode it explicitly instead of relying on
# the platform's locale-dependent default encoding.
with open('neptune_auth.json', 'r', encoding='utf-8') as f:
    neptune_auth = DotMap(json.load(f))

def get_df_from_neptune(run_id, window=100):
    """Fetch one Neptune run's logged ``return`` series as a smoothed DataFrame.

    Opens run ``UNITYML-<run_id>`` of ``neptune_auth.project`` (module-level
    credentials) in read-only mode, reads every value logged under
    ``'return'``, and smooths it with a trailing rolling mean.

    Args:
        run_id: Numeric part of the Neptune run id (e.g. 644 -> "UNITYML-644").
        window: Rolling-mean window length in logged steps (default 100). With
            pandas' default ``min_periods`` the first ``window - 1`` smoothed
            values are NaN.

    Returns:
        pandas.DataFrame with columns ``x`` (1-based step index) and
        ``return`` (rolling mean of the logged values).
    """
    neptune_run = neptune.init(
        project=neptune_auth.project,
        api_token=neptune_auth.api_token,
        mode="read-only",
        run="UNITYML-{}".format(run_id),
    )
    try:
        epoch_return = neptune_run['return'].fetch_values(include_timestamp=False)['value']
    finally:
        # Release the run even if fetching fails (e.g. a bad run id or a
        # missing 'return' field), so no Neptune session is left open.
        neptune_run.stop()

    smoothed = epoch_return.rolling(window).mean()
    # Build the frame positionally so the result does not depend on the index
    # that fetch_values happened to return.
    return pd.DataFrame({
        "x": range(1, len(smoothed) + 1),
        "return": smoothed.to_numpy(),
    })

# One DataFrame per curve: the smoothed runs of a group are stacked vertically
# (they share x values), so seaborn can aggregate them at each x.
dataframes = [
    pd.concat([get_df_from_neptune(nep_id) for nep_id in run], ignore_index=True)
    for run in runs
]

# Turn the hex list into a seaborn palette; curve k gets colour k.
current_palette = sns.color_palette(flatui)

# Draw every curve onto the current axes. Each DataFrame holds several runs
# sharing the same x values; ci="sd" asks seaborn for a standard-deviation band.
# NOTE(review): seaborn >= 0.12 deprecates ci= in favour of errorbar="sd";
# left unchanged for compatibility with the installed version. Confirm.
plt = None
for curve_idx, curve_df in enumerate(dataframes):
    curve_color = current_palette[curve_idx]
    plt = sns.lineplot(
        data=curve_df,
        x="x",
        y="return",
        color=curve_color,
        err_style="band",
        ci="sd",
    )

# X ticks every ``x_tick_steps`` episodes over [x_ticks[0], x_ticks[1]).
tick_positions = np.arange(x_ticks[0], x_ticks[1], step=x_tick_steps)
plt.set_xlabel('episode')
plt.set_ylabel('reward')
plt.set_xticks(tick_positions)

# Legend handles: one coloured patch per label. With no labels configured,
# fall back to the run-id lists themselves (rendered like "[644, 645, 646]").
if not labels:
    labels = runs
# Colours are matched to labels by position, so ``labels`` must follow the
# same order as ``runs`` / ``dataframes``.
colors = [
    mpatches.Patch(color=current_palette[i], label='{}'.format(label))
    for i, label in enumerate(labels)
]

# Show the custom patches as the legend (tightly spaced) and title it with the
# name of the compared setting.
leg = plt.legend(
    handles=colors,
    loc=legend_location,
    labelspacing=0.1,
)
leg.set_title(titel)

# Fix the visible data window, then fetch the Figure that owns these axes so it
# can be saved.
plt.set_xlim(x_range[0], x_range[1])
plt.set_ylim(y_range[0], y_range[1])
fig = plt.get_figure()

# Save the figure as ./plots/plot - <titel>.png at 300 dpi, cropped tightly.
plot_dir = os.path.join(os.getcwd(), 'plots')
# exist_ok=True replaces the racy "check exists, then create" sequence.
os.makedirs(plot_dir, exist_ok=True)
fig.savefig(
    os.path.join(plot_dir, 'plot - {}.png'.format(titel)),
    dpi=300,
    bbox_inches="tight",
)