In [1]:
from pathlib import Path
import numpy as np
import pandas as pd

from gc_ope.evaluate.evaluation_result_container import WeightedEvaluationResultContainer
from gc_ope.evaluate.evaluator_kde import KDEEvaluator

# Repository root: two directory levels above the kernel's current working
# directory. NOTE(review): this assumes the notebook is started from a folder
# exactly two levels below the repo root — confirm, or make it configurable.
PROJECT_ROOT_DIR = Path.cwd().parent.parent
PROJECT_ROOT_DIR

PosixPath('/home/gxd/code')

In [59]:
def get_success_desired_goals(
    csv_path: Path,
    success_termination: str = "reach target",
) -> pd.DataFrame:
    """Load an evaluation-result CSV and keep only the successful episodes.

    Args:
        csv_path: Path to the CSV (anything ``pd.read_csv`` accepts works).
            The file must contain a ``termination`` column.
        success_termination: Value of the ``termination`` column that marks a
            successful episode. Defaults to ``"reach target"``.

    Returns:
        A new DataFrame with the rows whose ``termination`` equals
        ``success_termination``. It has an extra boolean ``success`` column,
        which is True for every returned row. The original row index is kept.
    """
    eval_res_df = pd.read_csv(csv_path)
    eval_res_df["success"] = eval_res_df["termination"] == success_termination
    # Return an explicit copy. Otherwise pandas flags the masked slice as a
    # possible view of eval_res_df, and later column assignments on the result
    # emit SettingWithCopyWarning.
    return eval_res_df[eval_res_df["success"]].copy()

In [60]:
# KDE-based evaluator over desired goals, backed by a
# WeightedEvaluationResultContainer that receives discounted_factor=0.9.
# NOTE(review): the fit output further down shows batch weights 0.81 / 0.9 / 1.0.
# That suggests older batches are down-weighted by 0.9 per newer batch — confirm
# in the container implementation.
kde_evaluator = KDEEvaluator(
    evaluation_result_container_class=WeightedEvaluationResultContainer,
    evaluation_result_container_kwargs={"discounted_factor": 0.9},
    kde_kernel="gaussian",
    kde_bandwidth=1.0,
)

In [61]:
# Directory holding the per-checkpoint evaluation CSVs that are pooled below.
CKPT_EVAL_DIR = PROJECT_ROOT_DIR / "checkpoints/flycraft/test_sac2"

# Add the successful desired goals of each checkpoint to the evaluator's
# result container, one batch per checkpoint, in the order listed.
# NOTE: this cell mutates kde_evaluator. Re-running it without first re-running
# the cell that constructs kde_evaluator adds every batch a second time.
for res_csv in [
    CKPT_EVAL_DIR / "best_model_eval_res.csv",
    CKPT_EVAL_DIR / "rl_model_70000_steps_eval_res.csv",
    CKPT_EVAL_DIR / "rl_model_90000_steps_eval_res.csv",
]:
    success_rows = get_success_desired_goals(csv_path=res_csv)
    n_success = len(success_rows)

    kde_evaluator.eval_res_container.add_batch(
        desired_goal_batch=success_rows[["v", "mu", "chi"]].to_numpy(),
        # Every row here is already a success, so the flags are all True.
        success_batch=[True] * n_success,
        cumulative_reward_batch=success_rows["cumulative_rewards"].to_numpy(),
        discounted_cumulative_reward_batch=success_rows["discounted_cumulative_rewards"].to_numpy(),
        # Uniform per-goal weight within a batch.
        desired_goal_weight_batch=[1.0] * n_success,
    )

In [62]:
# Fit the KDE on everything added to the container. fit_evaluator returns three
# parallel sequences: goals (scaled, per the variable name — TODO confirm the
# scaler), their weights, and the density value computed for each goal.
scaled_dgs, dg_weights, dg_densities = kde_evaluator.fit_evaluator()

# Print one line per goal: scaled goal, weight, density.
for dg_row in zip(scaled_dgs, dg_weights, dg_densities):
    print(*dg_row)

[-1.2995271   1.24143006 -1.43800877] 0.81 0.01210900545550918
[ 0.34558012  0.30174515 -0.3078048 ] 0.81 0.01823621094594546
[1.13637192 0.81203895 1.3581045 ] 0.81 0.008160752661177554
[-1.66467929  1.63013617 -1.05365596] 0.81 0.01187677262549155
[-0.8226689   0.55666964  0.23577478] 0.81 0.021298919768589348
[ 0.28260527 -1.71686491 -0.99809756] 0.81 0.011007939928891094
[ 1.12631222 -1.04956096 -0.61521175] 0.81 0.016474957225121954
[-1.095292    1.55262658 -0.55363823] 0.81 0.015456259601157705
[-1.04467516  0.82929106  0.46522558] 0.81 0.017983366711721842
[ 0.84918315 -0.70421606 -1.12299925] 0.9 0.015082943059509663
[-0.81886927  0.33576874 -0.02810466] 0.9 0.021965826579279084
[-0.84887511  0.27712163 -0.37086665] 0.9 0.02086779554491268
[ 1.12631222 -1.04956096 -0.61521175] 0.9 0.016474957225121954
[ 1.05536594 -0.96774354  0.73112629] 1.0 0.01566267010962694
[-0.38557753 -0.21360454  0.60084811] 1.0 0.018845253591111503
[ 1.18921196 -1.18247986  2.28323341] 1.0 0.0086335925