In [None]:
import os
import cma
import datetime
import numpy as np
from omegaconf import OmegaConf
import matplotlib.pyplot as plt
from scipy.interpolate import make_interp_spline

from explore.utils.vis import play_path
from explore.env.mujoco_sim import MjSim
from explore.datasets.utils import load_trees, get_diverse_paths

# Experiment folder to analyse; Hydra stored the run's configuration inside it.
dataset = "../data/fingerRamp_exp"

config_path = os.path.join(dataset, ".hydra/config.yaml")
cfg = OmegaConf.load(config_path)
rrt_cfg = cfg.RRT  # short alias for the RRT section of the config

# Thresholds from the RRT section of the config.
ERROR_THRESH = rrt_cfg.min_cost
path_diff_thresh = rrt_cfg.path_diff_thresh

look_at_specific_start_idx = rrt_cfg.start_idx
look_at_specific_end_idx = rrt_cfg.end_idx
# NOTE(review): the configured end index read above is unconditionally replaced
# by -1 on the next line; confirm this override is intentional.
look_at_specific_end_idx = -1
cutoff = -1  # forwarded to load_trees in the next cell

q_mask = np.array(rrt_cfg.q_mask)
sim_cfg = rrt_cfg.sim

# The MuJoCo model path in the config is relative to the repository root,
# one level above this notebook.
mujoco_xml = os.path.join("..", sim_cfg.mujoco_xml)

print(f"Looking at start_idx {look_at_specific_start_idx} and end_idx {look_at_specific_end_idx} with error threshold {ERROR_THRESH}.")
print(f"Tau action: {cfg.RRT.sim.tau_action}; Tau sim: {cfg.RRT.sim.tau_sim}")

In [None]:
# Load every saved RRT tree of the experiment (load_trees receives `cutoff`).
tree_dataset = os.path.join(dataset, "trees")
trees, tree_count, total_nodes_count = load_trees(tree_dataset, cutoff, verbose=1)

# An empty q_mask in the config falls back to an all-ones mask shaped like the
# configuration stored in the first node of the first tree.
if q_mask.shape[0] == 0:
    q_mask = np.ones_like(trees[0][0]["state"][1])

print("Loaded ", total_nodes_count, " RRT nodes.")

# Wall-clock time recorded by the tree-generation run.
time_taken_path = os.path.join(dataset, "time_taken.txt")
time_taken = float(np.loadtxt(time_taken_path))
print(f"Time taken to generate tree: {datetime.timedelta(seconds=time_taken)}")

In [None]:
# Forwarded as get_diverse_paths(min_len=...); presumably paths shorter than this are dropped.
MIN_PATH_LEN = 10

# Reuse the thresholds defined in the config cell so each value is read from
# the config in exactly one place. `cached_folder` presumably lets
# get_diverse_paths reuse previously computed results stored in the dataset.
paths, traj_pairs = get_diverse_paths(
    trees,
    ERROR_THRESH,
    q_mask,
    path_diff_thresh,
    min_len=MIN_PATH_LEN,
    cached_folder=dataset,
)

In [None]:
# Pick one diverse path as the reference for the spline optimisation.
# A fixed seed keeps "Restart & Run All" reproducible; change it to inspect another path.
PATH_SEED = 0
assert len(paths) > 0, "get_diverse_paths returned no paths; relax the thresholds or min_len"

path_rng = np.random.default_rng(PATH_SEED)
path_idx = int(path_rng.integers(0, len(paths)))
reference_path = paths[path_idx]
print("Start and end idxs: ", traj_pairs[path_idx])
start_idx, end_idx = traj_pairs[path_idx]

# Path nodes are indexed positionally: node[1] is the state and node[3] the control.
ctrl_dim = len(reference_path[0][3])
state_dim = len(reference_path[0][1])
print("ctrl_dim: ", ctrl_dim)
print("state_dim: ", state_dim)
print("Path length: ", len(reference_path))


def plot_path_component_traces(path, field_idx: int, n_dims: int, title: str, label: str) -> None:
    """Plot each of the first `n_dims` components of `node[field_idx]` along `path`.

    One line per component, labelled f"{label} {i}", against the node index.
    """
    fig, ax = plt.subplots()
    for i in range(n_dims):
        ax.plot([node[field_idx][i] for node in path], label=f"{label} {i}")
    ax.set(title=title, xlabel="path node index")
    ax.legend()
    ax.grid(True)
    plt.show()


plot_path_component_traces(reference_path, 3, ctrl_dim, "Ctrls", "ctrl")
plot_path_component_traces(reference_path, 1, state_dim, "States", "state")

In [None]:
# Configurations stored at the roots of the trees the reference path connects.
start_state = trees[start_idx][0]["state"][1]
target_state = trees[end_idx][0]["state"][1]

# Simulator built from the RRT sim settings (view=False), then used to replay
# the reference path from start_state.
sim = MjSim(
    mujoco_xml,
    tau_sim=sim_cfg.tau_sim,
    interpolate=sim_cfg.interpolate_actions,
    joints_are_same_as_ctrl=sim_cfg.joints_are_same_as_ctrl,
    view=False,
)
frames, states = play_path(
    reference_path,
    sim,
    start_state,
    target_state,
    tau_action=sim_cfg.tau_action,
    camera=sim_cfg.camera,
    reset_state=True,
)

In [None]:
states = np.array(states)

# Only the leading components (at most ctrl_dim of them) of the replayed states are shown.
n_plotted = min(state_dim, ctrl_dim)

plt.title("States")
for k in range(n_plotted):
    plt.plot(states[:, k], label=f"state {k}")
plt.legend()
plt.grid(True)
plt.show()

## Initial Spline Guess

In [None]:
class BSpline:
    """Per-dimension interpolating spline that always starts at a fixed point.

    The spline is parameterised over ``[start_time, end_time]`` with
    ``start_time`` fixed to 0. ``compute`` prepends ``start_pos`` to the given
    via-points, spaces all data points uniformly in time and fits one
    ``scipy.interpolate.make_interp_spline`` of order ``degree`` per coordinate.
    ``end_pos`` is stored but not used by the fit.
    """

    def __init__(self, start_pos: np.ndarray, end_pos: np.ndarray, end_time: float, degree: int=6):
        self.start_pos = start_pos
        self.end_pos = end_pos
        self.start_time, self.end_time = 0., end_time
        self.degree = degree

    def compute(self, points: np.ndarray):
        """Fit the splines through ``start_pos`` followed by the rows of ``points``.

        ``points`` is expected to be 2-D with one column per coordinate.
        """
        anchored = np.concatenate((self.start_pos.reshape(1, -1), points))
        sample_times = np.linspace(self.start_time, self.end_time, anchored.shape[0])
        # anchored.T yields one column (coordinate) at a time.
        self.splines = [
            make_interp_spline(sample_times, column, k=self.degree)
            for column in anchored.T
        ]

    def eval(self, t: float):
        """Return the spline value at time ``t``, one entry per coordinate.

        Requires a prior call to ``compute``.
        """
        return np.array([coord_spline(t) for coord_spline in self.splines])

def eval_spline(points: np.ndarray, sim: MjSim, vis: bool=False,
                start_config=None, start_ctrl=None, target_config=None,
                n_steps: int=100) -> float:
    """Roll out the spline through ``points`` in ``sim`` and return the squared final-state error.

    The simulator is first reset with ``sim.pushConfig(start_config, start_ctrl)``.
    The global ``spline`` is then refit through ``points``; ``BSpline.compute``
    prepends the fixed start control itself. The spline is sampled at ``n_steps``
    evenly spaced times on ``[spline.start_time, spline.end_time]``, and every
    sample after the first is applied with one ``sim.step``. The cost is the
    squared Euclidean distance between ``sim.getState()[1]`` after the rollout
    and ``target_config``.

    Args:
        points: spline via-points, one row per point and one column per ctrl dimension.
        sim: simulator to roll out in; ``pushConfig`` is called on it first.
        vis: if True, the step duration is passed as ``view`` to ``sim.step`` (0 otherwise).
        start_config: configuration to reset to; defaults to the notebook global ``start_state``.
        start_ctrl: control passed to ``pushConfig``; defaults to ``spline.start_pos``.
        target_config: configuration to reach; defaults to the notebook global ``target_state``.
        n_steps: number of time samples including the start (>= 2).

    Returns:
        The squared error as a Python float.

    Raises:
        ValueError: if ``n_steps < 2``.

    Side effects: refits the global ``spline`` and advances ``sim``.
    """
    # NOTE(review): defaults use the reference path's start/target tree states;
    # confirm spline.start_pos is the intended initial ctrl for pushConfig.
    if start_config is None:
        start_config = start_state
    if start_ctrl is None:
        start_ctrl = spline.start_pos
    if target_config is None:
        target_config = target_state
    if n_steps < 2:
        raise ValueError(f"n_steps must be >= 2, got {n_steps}")

    sim.pushConfig(start_config, start_ctrl)
    spline.compute(points)

    ts = np.linspace(spline.start_time, spline.end_time, n_steps)
    # Step duration equals the sample spacing, so the rollout spans the whole spline.
    tau_action = ts[1] - ts[0]
    view = tau_action if vis else 0
    for t in ts[1:]:
        sim.step(tau_action, spline.eval(t), view)

    e = np.asarray(sim.getState()[1]) - np.asarray(target_config)
    return float(e @ e)

def eval_splines(candidates: np.ndarray, sim: MjSim, n_ctrl=None) -> list[float]:
    """Score a batch of flattened spline candidates (e.g. the output of ``es.ask()``).

    Each candidate is reshaped to ``(-1, n_ctrl)`` via-points and scored
    sequentially with ``eval_spline`` in the shared simulator ``sim``.

    Args:
        candidates: iterable of 1-D arrays, each a flattened ``(n_points, n_ctrl)`` array.
        sim: simulator used for every rollout.
        n_ctrl: control dimension used for the reshape; defaults to the number of
            entries in the global ``spline.start_pos``.

    Returns:
        One cost per candidate, in input order.
    """
    if n_ctrl is None:
        n_ctrl = np.size(spline.start_pos)
    return [eval_spline(np.reshape(c, (-1, n_ctrl)), sim, vis=False) for c in candidates]

In [None]:
# Time allotted to each reference-path node when it becomes a spline data point
# (same time units as sim_cfg.tau_sim).
# NOTE(review): 0.5 is hand-picked; confirm it matches the action duration used by the RRT.
NODE_DURATION = 0.5
path_duration = len(reference_path) * NODE_DURATION
spline = BSpline(reference_path[0][3], reference_path[-1][3], path_duration, degree=2)

# The first control is the fixed spline start (BSpline.compute prepends it), so only
# the remaining controls are decision variables for the optimiser.
initial_guess_points = np.array([node[3] for node in reference_path[1:]])
initial_guess = initial_guess_points.flatten()

print("Decision variables: ", initial_guess.shape)

spline.compute(initial_guess_points)
n_spline_samples = int(path_duration / sim_cfg.tau_sim)
spline_times = np.linspace(0, path_duration, n_spline_samples)
spline_path = np.array([spline.eval(t) for t in spline_times])
print(spline_path.shape)

# Plot both curves against time; the reference controls sit exactly on the times
# BSpline.compute assigns to its data points (uniform over len(reference_path) points).
node_times = np.linspace(0, path_duration, len(reference_path))
fig, ax = plt.subplots()
for i in range(ctrl_dim):
    ref_line, = ax.plot(node_times, [node[3][i] for node in reference_path], "o", label=f"ctrl {i}")
    ax.plot(spline_times, spline_path[:, i], color=ref_line.get_color(), label=f"spline {i}")
ax.set(title="Initial spline guess vs. reference ctrls", xlabel="time", ylabel="ctrl")
ax.grid(True)
ax.legend()
plt.show()

In [None]:
# CMA-ES settings.
CMA_SIGMA0 = 0.5     # initial step size of the search distribution
CMA_POPSIZE = 32     # candidates per generation
CMA_MAXFEVALS = 100  # total rollout budget: only a few generations at this popsize
CMA_SEED = 1         # fixed for reproducibility (cma treats 0/None as a time-based seed)

es = cma.CMAEvolutionStrategy(initial_guess, CMA_SIGMA0, {
    "popsize": CMA_POPSIZE,
    "maxfevals": CMA_MAXFEVALS,
    "seed": CMA_SEED,
    "verbose": -1,
})

# Standard ask/tell loop; each candidate is a flattened set of spline via-points.
while not es.stop():
    candidates = es.ask()
    results = eval_splines(candidates, sim)
    es.tell(candidates, results)
    es.disp()

print(f"Done! with cost {es.result.fbest}")

In [None]:
# Rebuild the best spline found by CMA-ES and sample its controls.
N_CTRL_SAMPLES = 100

r = es.result.xbest
best_points = r.reshape(-1, ctrl_dim)

# `points`/`times` include the fixed start control, i.e. exactly the data points
# and times BSpline.compute uses internally (handy for plotting).
points = np.concatenate((spline.start_pos.reshape(1, -1), best_points))
times = np.linspace(spline.start_time, spline.end_time, points.shape[0])

# compute() prepends start_pos itself, so pass only the optimised points; passing
# `points` would duplicate the start control and shift every data-point time,
# producing a different spline from the one CMA-ES evaluated.
spline.compute(best_points)
ts = np.linspace(spline.start_time, spline.end_time, N_CTRL_SAMPLES)
ctrls = np.array([spline.eval(t) for t in ts])