## Run options
Set `recalc` to True to force recalculation of all results; otherwise, the saved results are loaded from file.

Set `difficulty` to `"easy"` or `"hard"` to choose which dataset to test on.

In [None]:
# When True, the "if recalc:" cells below recompute and save their results;
# the load cells that follow them always read the saved files back.
recalc = True
# Selects which dataset/prefix/graph title the notebook uses.
difficulty = "easy"  # or "hard"

# Passed as `maxfevals` to CMA-ES and as `k_max` to the entropy search; also the
# right-hand x-value of the constant "Mean" baseline line.
SIM_EVALS = 10     # PAPER VALUE: 100
# Passed as `tuning_iterations` when testing TuneNet.
tuning_iters = 10  # PAPER VALUE: 100
# Used as the numeric suffix of the checkpoint file name that is loaded.
TUNENET_EPOCHS = 10 # PAPER VALUE: 200

In [None]:
import numpy as np
import os
from tune.definitions import ROOT_DIR, OUTPUT_DIR
import matplotlib.pyplot as plt
from matplotlib.ticker import MaxNLocator
import importlib
from tune.utils import save_files, load_files, get_torch_device
from tune.train_tunenet_gt import INPUT_DIM, OUT_DIM, BATCH_SIZE
import tune.train_tunenet_gt
from tune.model_tunenet import TuneNet
import torch

%matplotlib inline
%load_ext jupyternotify

# ROOT_DIR joined with OUTPUT_DIR; model checkpoints are loaded from this directory below.
output_path = os.path.join(ROOT_DIR, OUTPUT_DIR)
# Echo the project root so it is visible in the notebook output.
print(ROOT_DIR)

In [None]:
# What kind of sim are we going to run?
#
# Reload the modules *before* pulling names out of them: importlib.reload() does not
# rebind names that were obtained earlier with ``from module import name``, so importing
# first would leave BallSim / DatasetTuneNet pointing at the pre-reload classes.
import tune.ball_sim
import tune.dataset_tunenet
importlib.reload(tune.ball_sim)
importlib.reload(tune.dataset_tunenet)
from tune.ball_sim import BallSim
from tune.dataset_tunenet import DatasetTuneNet

SimType = BallSim

# Pick the file prefix, validation data loader and plot title for the requested difficulty.
if difficulty == "easy":
    prefix = "ball_gt"
    test_loader = DatasetTuneNet.get_data_loader("tune", "ground_truth", "val", BATCH_SIZE)
    graph_title = "GT Easy"
elif difficulty == "hard":
    prefix = "ball_gt_hard"
    test_loader = DatasetTuneNet.get_data_loader("tune_hard", "ground_truth", "val", BATCH_SIZE)
    graph_title = "GT Hard"
else:
    # ValueError is a subclass of Exception, so existing handlers still catch it.
    raise ValueError("I don't know what kind of difficulty you're going for (you told me '{}')".format(difficulty))

## Load the dataset

In [None]:
import torch

# Walk the test loader once, printing per-batch statistics of zeta_batch[:, 1, 0] and
# collecting that column (as float CPU tensors) in targets_loaded.
# `sim` is opened here but not used inside the loop.
targets_loaded = []
with torch.no_grad(), SimType() as sim:
    for batch_idx, batch_data in enumerate(test_loader):
        zeta_batch, s_batch = batch_data[0], batch_data[1]
        target_column = zeta_batch[:, 1, 0]
        print("max:")
        print(target_column.max())
        print("min:")
        print(target_column.min())
        print("mean:")
        # mean absolute difference between entries [:, 1, 0] and [:, 0, 0]
        print((target_column - zeta_batch[:, 0, 0]).abs().mean())
        targets_loaded.append(target_column.detach().cpu().float())
        print("========")
print(targets_loaded)

# Use Test Mean as Estimate

In [None]:
# Concatenate the per-batch targets and use their overall mean as a constant estimate.
targets = torch.cat(targets_loaded)
print("overall target mean:")
mean = np.mean(targets.numpy())
print(mean)
# Absolute error of that constant estimate for every target, followed by its average.
mean_guess = torch.tensor(mean).expand(len(targets))
constant_diffs = (mean_guess - targets).abs()
print(np.mean(constant_diffs.cpu().detach().numpy()))

# TuneNet Estimate

In [None]:
model_name = "tunenet_gt"
# File names used when saving and re-loading the history and per-instance error results.
outnames = [
    "{}_history.pkl".format(model_name),
    "each_{}_error.pkl".format(model_name),
]

In [None]:
%%notify -m "finished testing TuneNet"
if recalc:
    # Pick up any edits to the test code; it is called through the module below.
    importlib.reload(tune.train_tunenet_gt)

    # Build the network and restore the checkpoint named "<model_name>_<TUNENET_EPOCHS>.pth".
    checkpoint_path = os.path.join(output_path, "{}_{}.pth".format(model_name, TUNENET_EPOCHS))
    model = TuneNet(INPUT_DIM, OUT_DIM).to(get_torch_device())
    model.load_state_dict(torch.load(checkpoint_path))

    # Run the test routine over the test loader with tuning_iters tuning iterations.
    with SimType() as sim:
        _, _, tunenet_history, each_tunenet_error = tune.train_tunenet_gt.test(
            1, model, sim, test_loader,
            tuning_iterations=tuning_iters,
            display_graphs=False,
        )

    # The history is converted to a numpy array before it is saved.
    tunenet_history = tunenet_history.cpu().detach().numpy()

    save_files(prefix, [tunenet_history, each_tunenet_error], outnames)

In [None]:
# Read the saved TuneNet results back from file.
tunenet_history, each_tunenet_error = load_files(prefix, outnames)
print(each_tunenet_error.shape)

# Plot rows 1..29 restricted to their first five columns; the axes are swapped so that
# each row becomes its own line.
fig, ax = plt.subplots()
trace_sample = each_tunenet_error[1:30, :5]
ax.plot(np.swapaxes(trace_sample, 0, 1))
print(np.max(each_tunenet_error))

# Direct Neural Estimate

In [None]:
model_name = "tunenet_gt_direct"
# File names used when saving and re-loading the history and per-instance error results.
outnames = [
    "{}_history.pkl".format(model_name),
    "each_{}_error.pkl".format(model_name),
]

In [None]:
%%notify -m "finished testing TuneNet Direct"
if recalc:
    # Build the network with degenerate=True and restore the checkpoint named
    # "<model_name>_<TUNENET_EPOCHS>.pth".
    model = TuneNet(INPUT_DIM, OUT_DIM, degenerate=True).to(get_torch_device())
    model.load_state_dict(torch.load(os.path.join(output_path, model_name + "_{}.pth".format(TUNENET_EPOCHS))))
    # Single tuning iteration, non-incremental.
    with SimType() as sim:
        _, _, tunenet_direct_history, each_tunenet_direct_error = \
            tune.train_tunenet_gt.test(1, model, sim, test_loader,
                                       tuning_iterations=1,
                                       display_graphs=False,
                                       incremental=False)

    tunenet_direct_history = tunenet_direct_history.cpu().detach().numpy()
    save_files(prefix, [tunenet_direct_history, each_tunenet_direct_error],
               outnames)
    # Echo the history only when it was just computed: with recalc=False the variable
    # does not exist until the following cell loads it, so printing it unconditionally
    # raised NameError on a fresh kernel.
    print(tunenet_direct_history)

In [None]:
def _extend_columns(arr, n_columns):
    """Right-pad axis 1 of ``arr`` to ``n_columns`` columns by repeating its last column."""
    return np.pad(arr, ((0, 0), (0, n_columns - arr.shape[1])), mode="edge")


tunenet_direct_history, each_tunenet_direct_error = load_files(prefix, outnames)
# The direct model was tested with a single tuning iteration. Repeat its last column so
# both arrays end up tuning_iters + 1 columns wide (axis 1).
tunenet_direct_history = _extend_columns(tunenet_direct_history, tuning_iters + 1)
each_tunenet_direct_error = _extend_columns(each_tunenet_direct_error, tuning_iters + 1)
print(each_tunenet_direct_error.shape)

# Plot rows 1..29 restricted to their first five columns, one line per row.
fig, ax = plt.subplots()
trace_sample = each_tunenet_direct_error[1:30, :5]
ax.plot(np.swapaxes(trace_sample, 0, 1))
print(np.max(each_tunenet_direct_error))

# CMA-ES Estimate

In [None]:
%%notify -m "finished CMA-ES"
if recalc:
    # Reload the module first and import names from it afterwards: importlib.reload()
    # does not rebind names obtained earlier via ``from tune.cma import ...``, so the
    # original order called the pre-reload do_cma_over_dataset.
    import tune.cma
    importlib.reload(tune.cma)
    # do_cma, exec_sim and torch are not used in this cell; the imports are kept so the
    # notebook namespace still contains them for any other cell that relies on them.
    from tune.cma import do_cma, do_cma_over_dataset
    from tune.utils import exec_sim
    import torch

    # Run CMA-ES over the whole test set with a budget of SIM_EVALS function evaluations.
    with SimType() as sim:
        cma_evals, cma_estimates, cma_targets = do_cma_over_dataset(test_loader, sim, maxfevals=SIM_EVALS, popsize=10)

    save_files(prefix, [cma_evals, cma_estimates, cma_targets],
               ["cma_evals.pkl", "cma_estimates.pkl", "cma_targets.pkl"])

In [None]:
# Read the saved CMA-ES results back from file.
cma_file_names = ["cma_evals.pkl", "cma_estimates.pkl", "cma_targets.pkl"]
cma_evals, cma_estimates, cma_targets = load_files(prefix, cma_file_names)

# Greedy Entropy Search Estimate

In [None]:
%%notify -m "finished Greedy Entropy Search"
if recalc:
    # Reload the module first and import from it afterwards: importlib.reload() does not
    # rebind names obtained earlier via ``from tune.entsearch import ...``, so the original
    # order called the pre-reload entsearch_over_dataset.
    import tune.entsearch
    importlib.reload(tune.entsearch)
    from tune.entsearch import entsearch_over_dataset
    from tune.utils import get_timestamp

    print(get_timestamp())
    # range of possible parameter values: 50 evenly spaced points spanning the targets
    target_values = targets.numpy()
    big_theta = np.linspace(np.min(target_values), np.max(target_values), 50)
    epsilon = 0.001
    print(big_theta)
    # max number of sims
    k_max = SIM_EVALS
    # population size
    n = 200

    print("Performing entropy search...")
    entsearch_P_history, entsearch_estimates, entsearch_targets = None, None, None
    with SimType() as sim:
        entsearch_P_history, entsearch_estimates, entsearch_targets = \
            entsearch_over_dataset(test_loader, sim, big_theta, epsilon, k_max, n)
    print("Entropy search complete.")
    print(get_timestamp())

    save_files(prefix, [entsearch_P_history, entsearch_estimates, entsearch_targets],
               ["entsearch_P_history.pkl", "entsearch_estimates.pkl", "entsearch_targets.pkl"])

In [None]:
# Read the saved entropy-search results back from file.
entsearch_file_names = ["entsearch_P_history.pkl", "entsearch_estimates.pkl", "entsearch_targets.pkl"]
entsearch_P_history, entsearch_estimates, entsearch_targets = load_files(prefix, entsearch_file_names)

# Munge results

In [None]:
# Absolute error of every estimate with respect to its own instance's target value.
# CMA-ES: column 0 of the targets against entry 0 of the last axis of the estimates.
each_cma_error = np.abs(cma_targets[:, :1] - cma_estimates[:, :, 0])
# Entropy search: the targets get a trailing axis so they broadcast across the estimates.
each_entsearch_error = np.abs(entsearch_targets[:, None] - entsearch_estimates[:, :])

# Average over instances (axis 0) to get one mean-error curve per method.
mean_cma_error = np.mean(each_cma_error, axis=0)
mean_tunenet_error = np.mean(each_tunenet_error, axis=0)
mean_tunenet_direct_error = np.mean(each_tunenet_direct_error, axis=0)
mean_entsearch_error = np.mean(each_entsearch_error, axis=0)

# The constant baseline is a single number per instance; it is duplicated into two columns
# so it can be drawn as a line between the two x-values in constant_x.
each_constant_error = constant_diffs.unsqueeze(1).repeat(1, 2).numpy()
constant_x = [0, SIM_EVALS]

# Mean of the constant baseline across all instances.
mean_constant_error = np.mean(each_constant_error, axis=0)

# Plot

In [None]:
num_plots = 20

# Per-method RGBA colours: 8-bit channel values divided by 255 to give floats in [0, 1].
tunenet_color = [channel / 255 for channel in (230, 97, 1, 255)]
tunenet_direct_color = [channel / 255 for channel in (253, 184, 99, 255)]
cma_color = [channel / 255 for channel in (178, 171, 210, 255)]
entsearch_color = [channel / 255 for channel in (94, 60, 153, 255)]
# Half-transparent mid grey for the constant baseline.
constant_color = [0.5, 0.5, 0.5, 0.5]

# Parallel lists, one entry per method and always in the same order:
# TuneNet, direct prediction, CMA-ES, entropy search, constant mean.
label_list = [
    "TuneNet",
    "Direct Prediction",
    "CMA-ES",
    "EntSearch",
    "Mean",
]
# x-values: the iteration index for the index-based methods, the first row of cma_evals
# for CMA-ES, and the two endpoints in constant_x for the baseline.
error_x_list = [
    np.arange(len(mean_tunenet_error)),
    np.arange(len(mean_tunenet_direct_error)),
    cma_evals[0],
    np.arange(len(mean_entsearch_error)),
    constant_x,
]
# mean error curves
error_y_list = [
    mean_tunenet_error,
    mean_tunenet_direct_error,
    mean_cma_error,
    mean_entsearch_error,
    mean_constant_error,
]
# per-instance error curves
error_each_list = [
    each_tunenet_error,
    each_tunenet_direct_error,
    each_cma_error,
    each_entsearch_error,
    each_constant_error,
]
color_list = [
    tunenet_color,
    tunenet_direct_color,
    cma_color,
    entsearch_color,
    constant_color,
]
        
# some quality of life functions
def lighten_value(val, amt):
    """Move ``val`` towards 1 by the fraction ``amt`` of the remaining distance."""
    headroom = 1 - val
    return val + headroom * amt


def faint(color, lighten=0.65, alpha=0.15):
    """Return a washed-out copy of ``color`` as a new ``[r, g, b, alpha]`` list.

    The first three channels of ``color`` are each lightened by ``lighten``; the
    fourth entry of the result is ``alpha`` (any alpha in ``color`` is ignored).
    """
    washed_out = [lighten_value(color[channel], lighten) for channel in (0, 1, 2)]
    washed_out.append(alpha)
    return washed_out

# Traces figures: one figure per method, with every individual run drawn faintly and the
# mean over runs drawn on top in the method's full colour.
n_averaged = len(test_loader.dataset)
for label, error_each, error_x, error_y, color in \
        zip(label_list, error_each_list, error_x_list, error_y_list, color_list):
    fig, ax = plt.subplots()
    run_color = faint(color)
    for run_errors in error_each:
        ax.plot(error_x, run_errors, color=run_color)
    ax.plot(error_x, np.mean(error_each, axis=0), color=color, label=label)

    ax.tick_params(
      axis='y',          # changes apply to the y-axis
      which='both',      # both major and minor ticks are affected
      right=False,       # ticks along the right edge are off
      left=False,        # ticks along the left edge are off
      # NOTE(review): labelbottom is an x-axis label option; confirm whether it has any
      # effect together with axis='y'.
      labelbottom=False)
    ax.get_xaxis().set_visible(True)
    ax.xaxis.set_major_locator(MaxNLocator(integer=True))
    ax.set_xlabel("Simulation rollouts")
    ax.set_ylabel("COR value error")
    # The right limit is 60 times the spacing between the first two x-values.
    x_step = error_x[1] - error_x[0]
    ax.set_xlim([0, 60 * x_step])
    ax.set_ylim([0, 0.3])
    ax.legend()
    ax.set_title(label)
    fig.tight_layout()
    # File name contains the lower-cased label with '-' removed and ' ' turned into '_'.
    label_slug = label.lower().replace('-', '').replace(' ', '_')
    filename = 'iterations_vs_performance_{}_{}.pdf'.format(prefix, label_slug)
    fig.savefig(os.path.join(ROOT_DIR, OUTPUT_DIR, prefix, filename))

    
# Print one "label & value & value ..." row per method (in reverse list order), with one
# value per budget in min_over_iterations_list.
# NOTE(review): the header says "minimum value", but the value reported is the mean error
# at the selected index, not a minimum over the preceding entries. Confirm the intent.
min_over_iterations_list = [1, 5, 10, 100]
print("N = " + str(min_over_iterations_list))
print("minimum value over the first N iterations")
for label, error_each, x, error_y, color in \
        reversed(list(zip(label_list, error_each_list, error_x_list, error_y_list, color_list))):
    mins = []
    for n in min_over_iterations_list:
        # Largest index whose x-value does not exceed n, but never less than 1.
        up_thru_idx = max([1] + [idx for idx, xval in enumerate(x) if xval <= n])
        print(up_thru_idx)
        mins.append(error_y[up_thru_idx])
    # Label padded/truncated to 20 characters, then each value with four decimals.
    format_string = "{}" + " & {:.4f}" * len(mins)
    print(format_string.format(label.ljust(20)[:20], *mins))

    
plt.rcParams.update({'font.size': 14})
# Combined figure: the mean error curve of every method on a single set of axes.
# Methods are drawn in reverse list order.
fig, ax = plt.subplots()
for label, error_each, error_x, error_y, color in \
        reversed(list(zip(label_list, error_each_list, error_x_list, error_y_list, color_list))):
    ax.plot(error_x, error_y, color=color, label=label, linewidth=2)

ax.tick_params(
  axis='y',          # changes apply to the y-axis
  which='both',      # both major and minor ticks are affected
  right=True,        # ticks along the right edge are on
  left=True,         # ticks along the left edge are on
  # NOTE(review): labelbottom is an x-axis label option; confirm whether it has any
  # effect together with axis='y'.
  labelbottom=False)
ax.get_xaxis().set_visible(True)
ax.xaxis.set_major_locator(MaxNLocator(integer=True))
ax.set_xlabel("Simulation rollouts")
# The label has no placeholder, so the former .format(n_averaged) call changed nothing
# and only made this code depend on a variable set inside the traces loop.
ax.set_ylabel("Parameter MAE")
ax.set_xlim([-1, 100])
ax.legend()
ax.set_title(graph_title)

fig.tight_layout()
filename = 'error_vs_iterations_{}_all.pdf'.format(prefix)
fig.savefig(os.path.join(ROOT_DIR, OUTPUT_DIR, prefix, filename))

In [None]:
def midpoint_integration(x, y):
    """
    Calculate the area under a piecewise-linear curve.

    Each interval between consecutive x-values contributes its width times the
    value halfway between its two endpoint y-values (i.e. the trapezoidal rule).
    :param x: vector of x-values
    :param y: vector of y-values, indexed in step with x
    :return: the summed area; 0 when x has fewer than two entries
    """
    area = 0
    for k, (x_left, x_right) in enumerate(zip(x, x[1:])):
        halfway_height = (y[k + 1] - y[k]) / 2 + y[k]
        area += halfway_height * (x_right - x_left)
    return area

# Quick sanity checks: each call prints the computed area, and the value it should equal is
# noted alongside. The results are only printed, not asserted.
print(midpoint_integration([1, 2, 3], [1, 1, 1])) # rectangle, should = 2
print(midpoint_integration([1, 2, 3], [4, 0, 4])) # two triangles, should = 4
# complex case with unevenly spaced x-values, should = 23.5
print(midpoint_integration([0, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 100], [3, 3, 1, 1, 4, 1, 2, 2, 4, 1, 0, 0]))

In [None]:
from numpy.polynomial.polynomial import polyfit
import scipy.stats

print("Area under the error curve, which is averaged over 100 datapoints, 1 run per datapoint, N simulations/run")

# For every simulation budget N:
#   * print the area under each method's mean-error curve up to x == N, and
#   * scatter every instance's initial error against its error at x == N, once on linear
#     and once on log y-axes; the linear plot also gets a least-squares fit line.
# Methods that have no data point at exactly x == N are skipped for that N.
error = {}
padlen = max(len(l) for l in label_list) + 4
for n in [1, 2, 3, 5, 10, 20, 100]:
    print("N = {} simulations:".format(n))
    _, ax_lin = plt.subplots()
    _, ax_log = plt.subplots()
    for typ, x, y, each_y, color in zip(label_list, error_x_list, error_y_list, error_each_list, color_list):
        # Some x-vectors are plain Python lists (constant_x). For a list, ``x == n`` is a
        # single bool rather than an element-wise comparison, so np.where() below found
        # nothing and indexing its result raised IndexError. Convert up front.
        x = np.asarray(x)
        if n not in x:
            print("  <skipping {}, cannot find data for exactly N={} simulations>".format(typ, n))
            continue
        # one past the index of the first x-value equal to n
        up_thru_idx = np.where(x == n)[0][0] + 1
        label_part = "  {}:".format(typ) + ' '*padlen
        print(label_part[:padlen] +
              "{:.3f}".format(midpoint_integration(x[:up_thru_idx], y[:up_thru_idx])))
        if typ not in error:
            error[typ] = {'err_initial': {}, 'err_final': {}}
        # first column = error before any tuning; column at x == n = error after n simulations
        error[typ]['err_initial'] = each_y[:, 0]
        error[typ]['err_final'][n] = each_y[:, up_thru_idx-1]

        ax_lin.set_title("N = {} simulations, y-axis linear".format(n))
        slope, intercept, r_value, p_value, std_err = \
            scipy.stats.linregress(error[typ]['err_initial'], error[typ]['err_final'][n])
        # draw the fitted line between x = 0 and x = 0.7
        endpoints = np.array([0, 0.7])
        ax_lin.plot(endpoints, slope * endpoints + intercept, '-', color=faint(color, alpha=0.5, lighten=0.0))
        fit_label = ", slope={:.2f}, $r^2$={:.2f}".format(slope, r_value**2)

        ax_log.set_title("N = {} simulations, y-axis log".format(n))
        ax_log.set_yscale('log')

        # same scatter on both axes; only the linear plot's legend carries the fit statistics
        for ax, label_suffix in zip([ax_lin, ax_log], [fit_label, ""]):
            ax.scatter(error[typ]['err_initial'], error[typ]['err_final'][n], label=typ + label_suffix, color=color, s=3)
            ax.legend()
            ax.set_xlabel("initial guess absolute error")
            ax.set_ylabel("final prediction absolute error")
    print("")
            
    # calculate mean and standard deviation of error across population

    

In [None]:
%%notify
pass