In [None]:
import sys

from hytea.utils import DotDict
from hytea.bitstringdecoder import BitStringDecoder

from pathlib import Path
from yaml import safe_load
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt

from scipy.signal import savgol_filter
from scipy.interpolate import interp1d

### Set global variables

In [None]:
# Parse the YAML configuration stored under ./hytea (the path is relative to
# the current working directory) and hand the result to DotDict.from_dict.
with (Path() / 'hytea' / 'config.yaml').open('r') as f:
    CFG = DotDict.from_dict(safe_load(f))

In [None]:
# Both locations are relative to the current working directory.
DATA = Path('results', 'data')      # csv files read by load_data
PLOTS = Path('results', 'plots')    # destination of the saved figures

# Create the plot directory (and missing parents); no-op if it already exists.
PLOTS.mkdir(exist_ok=True, parents=True)

### Define function to load data

In [None]:
def load_data(name: str) -> tuple[pd.DataFrame, pd.DataFrame]:
    """ Get both dataframes for a given experiment name.

    The experiment is read from ``DATA`` as numbered csv pairs
    ``{name}1tr.csv`` / ``{name}1s.csv``, ``{name}2tr.csv`` / ``{name}2s.csv``, ...
    The number of pairs is the number of files matching ``{name}*tr.csv``.

    Args:
        name: file name prefix of the experiment (e.g. ``'ab'``).

    Returns:
        A tuple ``(trdf, sdf)``:

        * ``trdf``: all training csv files concatenated column-wise, with the
          ``__MIN`` / ``__MAX`` columns dropped and `` - train_reward``
          stripped from the column names, then transposed so that every
          remaining column becomes a row. The index is named ``runID``.
        * ``sdf``: all summary csv files concatenated row-wise and reduced to
          the hyperparameter columns, ``group_name`` and ``test_reward``.
          The index is named ``runID``.

    Raises:
        ValueError: if ``DATA`` holds no file matching ``{name}*tr.csv``.
        FileNotFoundError: if one of the numbered csv files is missing
            (raised by ``pd.read_csv``).
        KeyError: if a summary file lacks one of the selected columns.
    """
    # columns of the summary files that are kept
    summary_columns = [
        'agent.bl_sub', 'agent.ent_reg_weight', 'agent.gamma', 'group_name',
        'network.hidden_activation', 'network.hidden_size',
        'network.num_layers', 'optimizer.lr', 'optimizer.lr_decay',
        'optimizer.lr_step', 'test_reward',
    ]

    # the number of training files decides how many numbered pairs are read
    n = len(list(DATA.glob(f'{name}*tr.csv')))
    if n == 0:
        # fail with a clear message instead of pandas' generic
        # "No objects to concatenate"
        raise ValueError(
            f"no training csv files matching '{name}*tr.csv' found in {DATA}"
        )
    parts = range(1, n + 1)

    trdf = pd.concat(
        [pd.read_csv(DATA / f'{name}{i}tr.csv', index_col=0) for i in parts],
        axis=1,
    )
    # filter out columns with __MIN or __MAX in the name
    trdf = trdf.loc[:, ~trdf.columns.str.contains('__MIN|__MAX')]
    # remove " - train_reward" from the column names (literal match, so the
    # result does not depend on the pandas default for `regex`)
    trdf.columns = trdf.columns.str.replace(' - train_reward', '', regex=False)

    # transpose the dataframe: the former columns become the rows
    trdf = trdf.T
    trdf.index.name = 'runID'

    sdf = pd.concat(
        [pd.read_csv(DATA / f'{name}{i}s.csv', index_col=0) for i in parts],
        axis=0,
    )
    # create df with only the columns we are interested in
    sdf = sdf.loc[:, summary_columns]
    sdf.index.name = 'runID'

    return trdf, sdf

### Define function to plot global rewards

In [None]:
def plot_rewards(
    grouper: 'pd.core.groupby.DataFrameGroupBy', df_tr: pd.DataFrame, df_s: pd.DataFrame,
    title: str, max: bool = False,
    smooth_window: int = 51, smooth_order: int = 3
) -> plt.Figure:
    """
    Plot the rewards on a global episode scale.

    For every group of `grouper`, the rows of `df_tr` whose index appears in
    `df_s` with that `group_name` are reduced column-wise (mean or max).
    The resulting curves are drawn one after another along the x axis, each
    as a translucent raw line plus a Savitzky-Golay smoothed line.

    Args:
        grouper: groupby object (e.g. `df_s.groupby('group_name')`); only the
            keys of its `.groups` mapping are used.
        df_tr: training rewards, one row per run and one column per episode.
        df_s: run summaries sharing the index of `df_tr`; must contain a
            `group_name` column.
        title: plot title; ' (avg)' or ' (max)' is appended.
        max: if True, the maximum reward is plotted instead of the mean.
        smooth_window: preferred window length of the Savitzky-Golay filter.
            It is reduced to the largest odd length that fits the data.
        smooth_order: polynomial order of the Savitzky-Golay filter.

    Returns:
        The created matplotlib figure.
    """
    # NOTE: the parameter `max` shadows the builtin inside this function; the
    # name is part of the public signature and therefore kept.

    dfs: list[pd.DataFrame] = []
    for name in grouper.groups.keys():
        dfs.append(df_tr.loc[df_s.loc[df_s['group_name'] == name].index])

    fig, ax = plt.subplots(figsize=(10, 5))

    # Every entry of dfs is a row selection of df_tr, so all share its number
    # of columns. Reading it from df_tr also works when there are no groups.
    n_episodes = df_tr.shape[1]

    for i, df in enumerate(dfs):
        data = df.max(axis=0) if max else df.mean(axis=0)
        episodes = np.arange(i * n_episodes + 1, (i + 1) * n_episodes + 1)
        ax.plot(episodes, data, alpha=0.5, color='tab:blue')

        # savgol_filter needs window_length <= len(data) and > polyorder;
        # an odd length is used because older SciPy versions require it.
        window = min(smooth_window, len(data))
        if window % 2 == 0:
            window -= 1
        if window > smooth_order:
            smooth_data = savgol_filter(data, window, smooth_order)
            ax.plot(episodes, smooth_data, color='tab:blue')
        # otherwise the series is too short to smooth: only the raw line is drawn

    ax.set_xlabel('global episode')
    ax.xaxis.set_major_locator(plt.MaxNLocator(10))
    # show tick values in thousands, e.g. 2500 -> '2.5k'
    ax.xaxis.set_major_formatter(lambda x, pos: f'{x/1000:.1f}k')

    ax.set_ylabel('avg. reward' if not max else 'max. reward')
    ax.set_title(title + (' (max)' if max else ' (avg)'))
    fig.tight_layout()

    return fig

### Load data

In [None]:
# Collect the training ('tr') and summary ('s') dataframes of every
# experiment, keyed by the experiment's file name prefix.
db = {}
for name in ('ab', 'cp', 'll'):
    trdf, sdf = load_data(name)
    db[name] = dict(tr=trdf, s=sdf)

# Converted with DotDict.from_dict; later cells access it as e.g. DB.ab.tr.
DB: DotDict[str, DotDict[str, pd.DataFrame]] = DotDict.from_dict(db)

In [None]:
# Plot and save the average reward curve of every environment.
for exp, title, filename in (
    (DB.ab, 'AcroBot-v1', 'ab_avg_rewards.png'),
    (DB.cp, 'CartPole-v1', 'cp_avg_rewards.png'),
    (DB.ll, 'LunarLander-v2', 'll_avg_rewards.png'),
):
    fig = plot_rewards(exp.s.groupby('group_name'), exp.tr, exp.s, title)
    fig.savefig(PLOTS / filename, dpi=500)

In [None]:
# Plot and save the maximum reward curve of every environment.
for exp, title, filename in (
    (DB.ab, 'AcroBot-v1', 'ab_max_rewards.png'),
    (DB.cp, 'CartPole-v1', 'cp_max_rewards.png'),
    (DB.ll, 'LunarLander-v2', 'll_max_rewards.png'),
):
    fig = plot_rewards(exp.s.groupby('group_name'), exp.tr, exp.s, title, max=True)
    fig.savefig(PLOTS / filename, dpi=500)