In [None]:
# Enable IPython's autoreload extension; mode 2 re-imports all modules before
# each cell execution, so edits to the project's `src` package are picked up
# without restarting the kernel.
%load_ext autoreload
%autoreload 2


In [None]:
import matplotlib.pyplot as plt
import numpy as np
from cycler import cycler
from matplotlib.lines import Line2D

import __init__
from src.evaluation.data import (compute_iqm_time_series, compute_iqm_values, smooth_time_series, wandb_load_overview,
                                 wandb_load_runs, wandb_load_time_series)
from src.evaluation.plots import iqm_ci_plot, iqm_line_plot
from src.evaluation.utils import mm2in


---

#### Plotting setup

In [None]:
# Global matplotlib style for the figures in this notebook, applied in
# topic-sized groups.

# Axes: colour cycle, tight margins and padding, no top/right spines.
plt.rcParams.update({
    'axes.prop_cycle': cycler('color', [
        "#0173B2", "#DE8F05", "#029E73", "#D55E00", "#CC78BC",
        "#CA9161", "#FBAFE4", "#949494", "#ECE133", "#56B4E9",
    ]),
    'axes.titlepad': 3.0,
    'axes.xmargin': 0.025,
    'axes.ymargin': 0.025,
    'axes.titlesize': 'medium',
    'axes.labelpad': 1.0,
    'axes.spines.right': False,
    'axes.spines.top': False,
})

# Text: 7 pt Times for regular text and for mathtext, without a LaTeX backend.
plt.rcParams.update({
    'font.size': 7,
    'font.family': 'serif',
    'font.serif': 'Times',
    'mathtext.fontset': 'custom',
    'mathtext.it': 'Times',
    'mathtext.rm': 'Times',
    'text.usetex': False,
})

# Grid: faint black lines.
plt.rcParams.update({
    'grid.alpha': 0.1,
    'grid.color': '#000000',
})

# Legend: frameless and compact.
plt.rcParams.update({
    'legend.borderaxespad': 0.25,
    'legend.borderpad': 0.0,
    'legend.frameon': False,
    'legend.columnspacing': 1.0,
    'legend.handletextpad': 0.5,
    'legend.handlelength': 0.75,
})

# Lines: rounded caps and joins.
plt.rcParams.update({
    'lines.solid_capstyle': 'round',
    'lines.solid_joinstyle': 'round',
})

# Ticks: short ticks with little padding on both axes.
plt.rcParams.update({
    'xtick.major.pad': 2.0,
    'xtick.major.size': 2.0,
    'xtick.minor.size': 1.0,
    'ytick.major.pad': 2.0,
    'ytick.major.size': 2.0,
    'ytick.minor.size': 1.0,
})

# Layout: constrained layout with minimal padding and no inter-axes spacing.
plt.rcParams.update({
    'figure.constrained_layout.h_pad': 0.01,
    'figure.constrained_layout.hspace': 0.0,
    'figure.constrained_layout.use': True,
    'figure.constrained_layout.w_pad': 0.01,
    'figure.constrained_layout.wspace': 0.0,
})


In [None]:
# Display names keyed by (RL experiment, SAE experiment).
# The runs listed first have np.nan in the SAE slot
# (presumably because they use no SAE model -- confirm).
labels = {
    ('rl-full', np.nan): 'full state',
    ('rl-key-5', np.nan): 'ground truth points',
    ('rl-key-only-5', np.nan): 'ground truth points',
}

# 'rl-feat' and 'rl-feat-only' share the same per-SAE display names, so their
# entries are generated from one table instead of being spelled out twice.
labels.update({
    (rl_experiment, sae_experiment): display_name
    for rl_experiment in ('rl-feat', 'rl-feat-only')
    for sae_experiment, display_name in (
        ('sae-keynet-vel-var-bg', 'KeyNet-vel-std-bg'),
        ('sae-basic-vel-var-bg', 'Basic-vel-std-bg'),
        ('sae-basic-fp32', 'Basic-kp32'),
        ('sae-dsae', 'DSAE'),
    )
})


---

#### Load overview table with all runs

In [None]:
# Entity and project name that identify which runs `wandb_load_runs` fetches.
entity, project = 'jonasreiher', 'sae-rl-RL'


In [None]:
# Fetch the runs for the configured entity/project, then build the overview
# table `runs_all` that the following cells filter and group.
# NOTE(review): both helpers live in src.evaluation.data; presumably they query
# the Weights & Biases API, so this cell may be slow and need network access -- confirm.
runs = wandb_load_runs(entity, project)
runs_all = wandb_load_overview(runs)


---

#### Filter, group, and count runs

In [None]:
# Keep only runs whose state is 'finished' and whose global step equals
# 3,000,000, then index them by RL experiment, SAE experiment, and run name
# (sorted so that later selections operate on an ordered index).
runs_grouped = (
    runs_all
    .loc[lambda frame: (frame['state'] == 'finished') & (frame['global_step'] == 3000000)]
    .set_index(['experiment', 'training.sae_experiment', 'name'])
    .sort_index()
)

# Alternative that skips the filter and keeps every run:
# runs_grouped = runs_all.set_index(['experiment', 'training.sae_experiment', 'name']).sort_index()


In [None]:
# Number of runs per (experiment, SAE experiment) group. dropna=False keeps the
# groups whose SAE experiment is NaN instead of silently dropping them.
(
    runs_grouped
    .groupby(['experiment', 'training.sae_experiment'], dropna=False)
    [['run']]
    .count()
)


---
---

**Feature points only**

In [None]:
# Activate the feature-points-only section; the guarded cells below check `exp`.
exp = 'rl-feat-only'

# Runs to compare: the feature-points-only policies for each listed SAE
# experiment, plus the 'rl-full' and 'rl-key-only-5' runs (np.nan matches rows
# without an SAE experiment).
selection = {
    'experiment': ['rl-feat-only', 'rl-full', 'rl-key-only-5'],
    'training.sae_experiment': [
        'sae-keynet-vel-var-bg', 'sae-basic-vel-var-bg', 'sae-basic-fp32', 'sae-dsae', np.nan,
    ],
}
df = runs_grouped.filter_by(selection).sort_index()


---

IQM Success Rate over Time (FPs only)

In [None]:
# Success-rate time series for the feature-points-only selection.
if exp == 'rl-feat-only':  # skip unless the feature-points-only section is active
    variable = 'eval/success_rate'  # metric name handed to wandb_load_time_series

    # Load the metric for the selected runs, aggregate it, then smooth the result.
    # NOTE(review): "IQM" presumably stands for interquartile mean across the
    # runs of a group -- confirm in src.evaluation.data.
    time_series = wandb_load_time_series(df, variable)
    iqm_time_series = compute_iqm_time_series(time_series)
    iqm_time_series_smooth = smooth_time_series(iqm_time_series)


In [None]:
if exp == 'rl-feat-only':
    # Line plot of the smoothed IQM success rate, labelled via `labels`.
    fig, ax = iqm_line_plot(iqm_time_series_smooth, labels=labels)
    ax.set(
        ylabel='success rate',
        xlabel=r'training steps ($\times 10^6$)',
    )
    # The x label already states the scale factor, so hide the axis offset text.
    ax.xaxis.get_offset_text().set_visible(False)

    # NOTE(review): mm2in presumably converts a (width, height) in millimetres
    # to inches -- confirm in src.evaluation.utils.
    fig.set_size_inches(mm2in(122 * 0.49, 25))
    fig.savefig('../local/paper/rl_successrate_fpsonly.pdf')


In [None]:
# Build a stand-alone legend figure from the plot created in the previous cell.
# Grab the line handles and the run sets first, because `fig` and `ax` are
# rebound to the new, empty figure below.
lines = fig.axes[0].lines
run_sets = iqm_time_series_smooth.columns.droplevel(-1).unique()
legend_labels = [labels[run_set] for run_set in run_sets]

fig, ax = plt.subplots()
ax.axis('off')  # the figure should show nothing but the legend
ax.legend(lines, legend_labels, ncols=6)

fig.set_size_inches(mm2in(122, 3.5))
fig.savefig('../local/paper/rl_legend.pdf')


---

IQM Average Episode Length over Time (FPs only)

In [None]:
# Mean-episode-length time series for the feature-points-only selection.
# Note: this rebinds `time_series`, `iqm_time_series` and
# `iqm_time_series_smooth`, replacing the success-rate data from the cells above.
if exp == 'rl-feat-only':  # skip unless the feature-points-only section is active
    variable = 'eval/mean_ep_length'  # metric name handed to wandb_load_time_series

    # Load the metric for the selected runs, aggregate it, then smooth the result.
    time_series = wandb_load_time_series(df, variable)
    iqm_time_series = compute_iqm_time_series(time_series)
    iqm_time_series_smooth = smooth_time_series(iqm_time_series)


In [None]:
if exp == 'rl-feat-only':
    # Line plot of the smoothed IQM episode length, labelled via `labels`.
    fig, ax = iqm_line_plot(iqm_time_series_smooth, labels=labels)
    ax.set(
        ylabel='episode length',
        xlabel=r'training steps ($\times 10^6$)',
    )
    # The x label already states the scale factor, so hide the axis offset text.
    ax.xaxis.get_offset_text().set_visible(False)
    # Anchor the y axis at zero; the upper limit stays automatic.
    ax.set_ylim(bottom=0)

    fig.set_size_inches(mm2in(122 * 0.49, 25))
    fig.savefig('../local/paper/rl_episodelength_fpsonly.pdf')


---
---

**Feature Points + End Effector**

In [None]:
# Activate the feature-points + end-effector section; the guarded cells below
# check `exp`.
exp = 'rl-feat'

# Runs to compare: the 'rl-feat' policies for each listed SAE experiment, plus
# the 'rl-full' and 'rl-key-5' runs (np.nan matches rows without an SAE
# experiment).
selection = {
    'experiment': ['rl-feat', 'rl-full', 'rl-key-5'],
    'training.sae_experiment': [
        'sae-keynet-vel-var-bg', 'sae-basic-vel-var-bg', 'sae-basic-fp32', 'sae-dsae', np.nan,
    ],
}
df = runs_grouped.filter_by(selection).sort_index()


---

In [None]:
# Success-rate time series for the feature-points + end-effector selection.
if exp == 'rl-feat':  # skip unless the feature-points + end-effector section is active
    variable = 'eval/success_rate'  # metric name handed to wandb_load_time_series

    # Load the metric for the selected runs, aggregate it, then smooth the result.
    # NOTE(review): "IQM" presumably stands for interquartile mean across the
    # runs of a group -- confirm in src.evaluation.data.
    time_series = wandb_load_time_series(df, variable)
    iqm_time_series = compute_iqm_time_series(time_series)
    iqm_time_series_smooth = smooth_time_series(iqm_time_series)


In [None]:
if exp == 'rl-feat':
    # Line plot of the smoothed IQM success rate, labelled via `labels`.
    fig, ax = iqm_line_plot(iqm_time_series_smooth, labels=labels)
    ax.set(
        ylabel='success rate',
        xlabel=r'training steps ($\times 10^6$)',
    )
    # The x label already states the scale factor, so hide the axis offset text.
    ax.xaxis.get_offset_text().set_visible(False)

    # NOTE(review): mm2in presumably converts a (width, height) in millimetres
    # to inches -- confirm in src.evaluation.utils.
    fig.set_size_inches(mm2in(122 * 0.49, 25))
    fig.savefig('../local/paper/rl_successrate_fps+ee.pdf')


---

In [None]:
# Mean-episode-length time series for the feature-points + end-effector selection.
# Note: this rebinds `time_series`, `iqm_time_series` and
# `iqm_time_series_smooth`, replacing the success-rate data from the cells above.
if exp == 'rl-feat':  # skip unless the feature-points + end-effector section is active
    variable = 'eval/mean_ep_length'  # metric name handed to wandb_load_time_series

    # Load the metric for the selected runs, aggregate it, then smooth the result.
    time_series = wandb_load_time_series(df, variable)
    iqm_time_series = compute_iqm_time_series(time_series)
    iqm_time_series_smooth = smooth_time_series(iqm_time_series)


In [None]:
if exp == 'rl-feat':
    # Line plot of the smoothed IQM episode length, labelled via `labels`.
    fig, ax = iqm_line_plot(iqm_time_series_smooth, labels=labels)
    ax.set(
        ylabel='episode length',
        xlabel=r'training steps ($\times 10^6$)',
    )
    # The x label already states the scale factor, so hide the axis offset text.
    ax.xaxis.get_offset_text().set_visible(False)
    # Anchor the y axis at zero; the upper limit stays automatic.
    ax.set_ylim(bottom=0)

    fig.set_size_inches(mm2in(122 * 0.49, 25))
    fig.savefig('../local/paper/rl_episodelength_fps+ee.pdf')
