In [None]:
import math
import matplotlib.pyplot as plt
import numpy as np
import os
import pickle

from adaptive_time.utils import set_directory_in_project

from importlib import reload
from joblib import Parallel, delayed
from tqdm.notebook import tqdm

from pprint import pprint


In [None]:
from adaptive_time import plot_utils
from adaptive_time import utils
from adaptive_time import run_lib
from adaptive_time import value_est
from adaptive_time.value_est import approx_integrators

# Re-execute the project modules so edits to their source files are picked up
# without restarting the kernel.
# NOTE(review): keep this order — `approx_integrators` is imported from the
# `value_est` package, and reloading a package re-runs its body, which may rebind
# its submodule attributes; confirm before reordering.
# NOTE(review): `importlib.reload` does not rebind references to objects created
# from the old module versions (e.g. sampler instances built in later cells must
# be re-created); `%autoreload 2` is an alternative worth considering.
approx_integrators = reload(approx_integrators)
run_lib = reload(run_lib)
value_est = reload(value_est)
plot_utils = reload(plot_utils)
utils = reload(utils)

In [None]:
# Presumably moves the working directory to the project root so that the
# relative `data_dir` below resolves — TODO confirm against adaptive_time.utils.
set_directory_in_project()
data_dir = "./data"
# One sub-directory per environment. Skip hidden entries (".DS_Store",
# ".ipynb_checkpoints", ...) and plain files, and sort so the processing order
# (and hence the key order of `all_results`) is reproducible across machines.
env_names = sorted(
    entry
    for entry in os.listdir(data_dir)
    if not entry.startswith(".") and os.path.isdir(os.path.join(data_dir, entry))
)
print(env_names)

In [None]:
# Samplers compared in this notebook, keyed by the label used in later plots.
# Adaptive-quadrature samplers are keyed "q<tolerance>".
samplers_tried = {
    name: approx_integrators.AdaptiveQuadratureIntegrator(tolerance=tol)
    for name, tol in (("q100", 100), ("q10", 10), ("q1", 1), ("q0", 0))
}
# Uniformly spaced samplers, keyed "u<...>".
# NOTE(review): the "u10" and "u100" keys are built with arguments 50 and 500,
# so the plot labels do not match the configuration — confirm which values were
# intended and align keys or arguments (later cells reference these keys).
samplers_tried.update(
    (name, approx_integrators.UniformlySpacedIntegrator(arg))
    for name, arg in (("u1", 1), ("u10", 50), ("u100", 500))
)

In [None]:
def compute_approx_integrals(
    reward_file: str,
    samplers_tried: dict,
):
    """Approximate the integral of every reward sequence in a file with each sampler.

    Args:
        reward_file: Path to a file readable by ``np.load`` holding a 2-D array;
            each row is treated as one reward sequence.
        samplers_tried: Mapping from sampler name to an object whose
            ``integrate(reward_seq)`` returns ``(integral, pivots)``.

    Returns:
        A dict with keys:
        - ``"reward_file"``: the input path,
        - ``"approx_integrals"``: sampler name -> array with one integral per row,
        - ``"num_pivots"``: sampler name -> array with ``len(pivots)`` per row.

    Raises:
        ValueError: If the loaded array is not 2-D.
    """
    print(reward_file)
    reward_sequences = np.load(reward_file)
    # Each row is handed to the samplers as one sequence, so require a
    # (num_sequences, sequence_length) array up front rather than failing later
    # with an opaque indexing error.
    if reward_sequences.ndim != 2:
        raise ValueError(
            f"Expected a 2-D array of reward sequences in {reward_file!r}, "
            f"got shape {reward_sequences.shape}."
        )

    approx_integrals = {}
    num_pivots = {}
    for sampler_name, sampler in samplers_tried.items():
        integrals = []
        pivot_counts = []
        for reward_seq in reward_sequences:
            integral, all_pivots = sampler.integrate(reward_seq)
            integrals.append(integral)
            # Only the count is kept: it is the cost charged per episode later.
            pivot_counts.append(len(all_pivots))
        approx_integrals[sampler_name] = np.array(integrals)
        num_pivots[sampler_name] = np.array(pivot_counts)

    return {
        "reward_file": reward_file,
        "approx_integrals": approx_integrals,
        "num_pivots": num_pivots,
    }

In [None]:
# Per-environment results cache: env name -> list of per-run-file result dicts.
all_results = dict()

In [None]:
# Compute approximate integrals for every run file of every environment in
# parallel. Environments already present in `all_results` are skipped, so
# re-running this cell only processes what is still missing.
for env_name in tqdm(env_names):
    if env_name in all_results:
        continue
    print("env: {}".format(env_name))

    env_dir = os.path.join(data_dir, env_name)
    # Skip hidden entries (".DS_Store", ".ipynb_checkpoints", ...) and sort so
    # the per-run result list has a reproducible order.
    run_files = sorted(
        run_file for run_file in os.listdir(env_dir) if not run_file.startswith(".")
    )
    if not run_files:
        # joblib rejects n_jobs=0; record the empty environment and move on.
        all_results[env_name] = []
        continue
    # Store the entry only after every job has succeeded: writing a placeholder
    # first would make the skip check above treat a failed env as done on re-run.
    all_results[env_name] = Parallel(
        # One worker per file, capped at the CPU count to avoid oversubscription.
        n_jobs=min(len(run_files), os.cpu_count() or 1)
    )(
        delayed(compute_approx_integrals)(
            os.path.join(env_dir, run_file),
            samplers_tried,
        )
        for run_file in run_files
    )

In [None]:
# Persist all per-environment results. The context manager guarantees the file
# is flushed and closed even if pickling fails part-way.
with open("./mujoco_val_est.pkl", "wb") as results_file:
    pickle.dump(all_results, results_file)

In [None]:
# Presumably a deliberate stop for "Run All": the cells below read
# `reward_sequences`, `approx_integrals` and `num_pivots`, which above are only
# locals of `compute_approx_integrals`, so they cannot run on a fresh kernel.
# An explicit raise (unlike `assert`) still stops execution under `python -O`.
raise RuntimeError(
    "Stopping here: the cells below depend on variables that are not defined "
    "at notebook scope (reward_sequences, approx_integrals, num_pivots)."
)

In [None]:
# Simulate value estimation under a fixed budget of reward evaluations: for each
# sampler, repeatedly draw a trajectory index ("start state") according to
# `weights`, take that trajectory's approximate integral, fold it into a running
# mean, and charge the sampler's pivot count for that trajectory to the budget.
# NOTE(review): `reward_sequences`, `approx_integrals` and `num_pivots` are not
# defined at notebook scope (only as locals of `compute_approx_integrals`), so
# this cell relies on leftover kernel state. Bind them from one entry of
# `all_results` (e.g. `all_results[env][run]["approx_integrals"]`) first.
# NOTE(review): no random seed is set, so sampled start states differ per run.
update_budget = 100_000_000

estimated_values_by_episode = {}
number_of_pivots_by_episode = {}
all_values_by_episode = {}
# Uniform distribution over trajectories.
weights = np.ones(len(reward_sequences)) / len(reward_sequences)

# Reference value: weighted mean of the exact (full-sum) return of each trajectory.
true_value = np.sum(weights * np.sum(reward_sequences, axis=-1))

for sampler_name, sampler in tqdm(samplers_tried.items()):
    print("sampler_name:", sampler_name)
    # Update the value estimate with new samples until we run out of budget.
    used_updates = 0
    value_estimate = 0
    num_samples = 0
    all_values_by_episode[sampler_name] = []
    # empirical_state_distr = np.zeros((num_trajs))

    estimated_values_by_episode[sampler_name] = []
    number_of_pivots_by_episode[sampler_name] = []

    while used_updates < update_budget:
        num_samples += 1
        start_state = np.random.choice(len(reward_sequences), p=weights)
        # empirical_state_distr[start_state] += 1
        val_sample = approx_integrals[sampler_name][start_state]
        all_values_by_episode[sampler_name].append(val_sample)
        
        # Incremental running mean of all val_sample values drawn so far.
        value_estimate += (1.0/num_samples) * (val_sample - value_estimate)
        # Charge this trajectory's pivot count (reward evaluations) to the budget.
        used_updates += num_pivots[sampler_name][start_state]

        estimated_values_by_episode[sampler_name].append(value_estimate)
        # Cumulative pivots used so far; this is the x-axis of the plots below.
        number_of_pivots_by_episode[sampler_name].append(used_updates)
    
    # empirical_state_distr /= np.sum(empirical_state_distr)
    # empirical_value = approx_integrals[sampler_name] @ empirical_state_distr


# CODE TO SAMPLE MANY TRAJECTORIES TO FIND AN EMPIRICAL DISTRIBUTION
# NOTE(review): commented-out experiment (uses undefined `num_trajs`); restore or remove.
# episode_samples = 100_000
# sampled_start_states = np.random.choice(num_trajs, size=(episode_samples,), p=weights)
# # We now have samples, we determine the empirical state distribution.
# empirical_state_distr = np.zeros((num_trajs))
# values, counts = np.unique(sampled_start_states, return_counts=True)
# empirical_state_distr[values] = counts
# empirical_state_distr /= np.sum(empirical_state_distr)


In [None]:
# Verify final means: for every sampler, the incrementally updated estimate must
# match the plain mean of all sampled values (up to floating-point drift).
for key, value in estimated_values_by_episode.items():
    mean_total = np.mean(all_values_by_episode[key])
    mean_updated = value[-1]
    print("sampler:", key, "mean_total:", mean_total, "mean_updated:", mean_updated)
    if abs(mean_total - mean_updated) > 0.01:
        # Explicit raise: `assert False, ...` would be stripped under `python -O`
        # and the check would silently pass.
        raise AssertionError(f"Means don't match for {key}: {mean_total} vs {mean_updated}")
    


In [None]:
# Absolute error of the running value estimate vs. cumulative pivots used, for
# the samplers whose names do not contain "q" (the uniformly spaced ones).
for s in samplers_tried.keys():
    if "q" in s:
        continue
    plt.plot(
        number_of_pivots_by_episode[s],
        # np.asarray: don't rely on `true_value` being a NumPy scalar for list - float.
        np.abs(np.asarray(estimated_values_by_episode[s]) - true_value),
        label=s)

plt.legend()
plt.title("Value-estimate error: uniformly spaced samplers")
plt.ylabel("Error in value estimate")
plt.ylim(-20, 800)
# x is the cumulative number of pivots (reward evaluations), not episodes.
plt.xlabel("Number of reward samples (pivots) used")
plt.show()

In [None]:
# Absolute error of the running value estimate vs. cumulative pivots used, for
# the samplers whose names contain "q" (the adaptive-quadrature ones).
for s in samplers_tried.keys():
    if "q" not in s:
        continue
    plt.plot(
        number_of_pivots_by_episode[s],
        # np.asarray: don't rely on `true_value` being a NumPy scalar for list - float.
        np.abs(np.asarray(estimated_values_by_episode[s]) - true_value),
        label=s)

plt.legend()
plt.title("Value-estimate error: adaptive-quadrature samplers")
plt.ylabel("Error in value estimate")
plt.ylim(-20, 800)
# x is the cumulative number of pivots (reward evaluations), not episodes.
plt.xlabel("Number of reward samples (pivots) used")
plt.show()

In [None]:
# Head-to-head comparison of selected samplers: solid lines for adaptive
# quadrature ("q*"), dashed lines for the others.
s_to_plot = ["q1", "q0", "u10", "u100"]
for s in s_to_plot:
    linestyle = "-" if "q" in s else "--"
    plt.plot(
        number_of_pivots_by_episode[s],
        # np.asarray: don't rely on `true_value` being a NumPy scalar for list - float.
        np.abs(np.asarray(estimated_values_by_episode[s]) - true_value),
        label=s, linestyle=linestyle)

plt.legend()
plt.title("Value-estimate error: selected samplers")
plt.ylabel("Error in value estimate")
plt.ylim(-2, 50)
# x is the cumulative number of pivots (reward evaluations), not episodes.
plt.xlabel("Number of reward samples (pivots) used")
plt.show()