In [None]:
import os
import datetime
import numpy as np
from omegaconf import OmegaConf
import matplotlib.pyplot as plt

from explore.utils.vis import play_path
from explore.env.mujoco_sim import MjSim
from explore.datasets.utils import load_trees, get_diverse_paths

# --- Configuration -----------------------------------------------------------
# Directory of one RRT experiment; the run's Hydra config lives in .hydra/.
dataset = "../data/pandasTable_exp"

config_path = os.path.join(dataset, ".hydra/config.yaml")
cfg = OmegaConf.load(config_path)

# Thresholds used when extracting paths from the loaded trees.
ERROR_THRESH = cfg.RRT.min_cost
path_diff_thresh = cfg.RRT.path_diff_thresh

# cfg.RRT.end_idx is deliberately overridden here; set END_IDX_OVERRIDE = None
# to use the value from the config instead.
END_IDX_OVERRIDE = -1
look_at_specific_start_idx = cfg.RRT.start_idx
look_at_specific_end_idx = (
    cfg.RRT.end_idx if END_IDX_OVERRIDE is None else END_IDX_OVERRIDE
)

# Forwarded to load_trees. NOTE(review): presumably -1 means "no cutoff" — confirm.
cutoff = -1

# An empty q_mask is replaced by an all-ones mask once the trees are loaded.
q_mask = np.array(cfg.RRT.q_mask)
sim_cfg = cfg.RRT.sim

# The config's XML path is joined onto ".."; presumably it is relative to the
# directory above this notebook — confirm if the notebook is moved.
mujoco_xml = os.path.join("..", sim_cfg.mujoco_xml)

print(f"Looking at start_idx {look_at_specific_start_idx} and end_idx {look_at_specific_end_idx} with error threshold {ERROR_THRESH}.")
print(f"Tau action: {sim_cfg.tau_action}; Tau sim: {sim_cfg.tau_sim}")

In [None]:
# Load the RRT trees of the experiment (cutoff is forwarded unchanged).
tree_dataset = os.path.join(dataset, "trees")
trees, tree_count, total_nodes_count = load_trees(tree_dataset, cutoff, verbose=1)
assert len(trees) > 0, f"No trees found in {tree_dataset}"

# An empty q_mask in the config means "mask in every state dimension".
# `.size` (rather than `.shape[0]`) also works for a 0-d array from a scalar value.
if q_mask.size == 0:
    # Assumes trees[k][n]["state"][1] is the state vector — inferred from the
    # indexing used throughout this notebook; confirm against load_trees.
    q_mask = np.ones_like(trees[0][0]["state"][1])

print("Loaded ", total_nodes_count, " RRT nodes.")

# time_taken.txt holds a single number of seconds; `.item()` extracts it
# without relying on NumPy's deprecated array-to-scalar conversion.
time_taken = np.loadtxt(os.path.join(dataset, "time_taken.txt")).item()
print(f"Time taken to generate tree: {datetime.timedelta(seconds=time_taken)}")

In [None]:
# Extract diverse start->goal paths, reusing the thresholds from the config cell
# instead of re-reading cfg so that both stay in sync.
MIN_PATH_LEN = 10  # NOTE(review): presumably the minimum number of nodes per path — confirm
paths, traj_pairs = get_diverse_paths(
    trees, ERROR_THRESH, q_mask, path_diff_thresh, min_len=MIN_PATH_LEN)
assert len(paths) > 0, "get_diverse_paths returned no paths"

In [None]:
# Pick one extracted path to inspect. Seeded so Restart & Run All shows the
# same path; set PATH_SEED = None to draw a different path on every run.
PATH_SEED = 0
rng = np.random.default_rng(PATH_SEED)
assert len(paths) > 0, "No paths to inspect"
path_idx = int(rng.integers(len(paths)))
selected_path = paths[path_idx]

print(traj_pairs[path_idx])
start_idx, end_idx = traj_pairs[path_idx]

# Assumes every path node is indexable with node[1] = state vector and
# node[3] = control vector (inferred from the indexing here) — confirm.
ctrl_dim = len(selected_path[0][3])
state_dim = len(selected_path[0][1])
print("ctrl_dim: ", ctrl_dim)
print("state_dim: ", state_dim)

# Control channels left out of the plot.
# NOTE(review): presumably the gripper actuators of the two arms — confirm.
SKIPPED_CTRL_IDXS = {7, 15}
# Only the first N state entries are plotted.
# NOTE(review): presumably the joint positions — confirm.
N_PLOTTED_STATES = 18

fig, ax = plt.subplots()
for ch in range(ctrl_dim):
    if ch in SKIPPED_CTRL_IDXS:
        continue
    ax.plot([node[3][ch] for node in selected_path], label=f"ctrl {ch}")
ax.set(title="Ctrls", xlabel="path node", ylabel="control value")
ax.legend()
ax.grid(True)
plt.show()

fig, ax = plt.subplots()
for ch in range(min(state_dim, N_PLOTTED_STATES)):
    ax.plot([node[1][ch] for node in selected_path], label=f"state {ch}")
ax.set(title="States", xlabel="path node", ylabel="state value")
ax.legend()
ax.grid(True)
plt.show()

In [None]:
# Replay the selected path in MuJoCo between the start and target trees.
# NOTE(review): assumes trees[k][0] is the root node of tree k and that
# play_path returns (rendered frames, visited states) — confirm.
start_state = trees[start_idx][0]["state"][1]
target_state = trees[end_idx][0]["state"][1]
sim = MjSim(
    mujoco_xml, tau_sim=sim_cfg.tau_sim, interpolate=sim_cfg.interpolate_actions,
    joints_are_same_as_ctrl=sim_cfg.joints_are_same_as_ctrl, view=False
)
# All simulator settings come from sim_cfg (== cfg.RRT.sim) for consistency.
frames, states = play_path(
    paths[path_idx], sim, start_state, target_state,
    tau_action=sim_cfg.tau_action, camera=sim_cfg.camera, reset_state=True)

In [None]:
# Plot the states visited during the replay, for comparison with the planned
# states. Re-running this cell is safe: np.asarray on an ndarray is a no-op.
states = np.asarray(states)
assert states.ndim == 2, f"expected a 2-D (step, state) array, got shape {states.shape}"

N_PLOTTED_STATES = 18  # same cutoff as the planned-state plot
# Bound by the replayed state width too, in case it differs from state_dim
# (which was measured on the planned path, not on the simulator output).
n_plotted = min(state_dim, N_PLOTTED_STATES, states.shape[1])

fig, ax = plt.subplots()
for ch in range(n_plotted):
    ax.plot(states[:, ch], label=f"state {ch}")
ax.set(title="States", xlabel="replay step", ylabel="state value")
ax.legend()
ax.grid(True)
plt.show()