In [None]:
import os
import numpy as np
from tqdm import tqdm
from omegaconf import OmegaConf
import matplotlib.pyplot as plt

from explore.env.mujoco_sim import MjSim
from explore.utils.vis import AdjMap, play_path
from explore.datasets.utils import cost_computation, load_trees

dataset = "../data/twoFingers_bitree"

config_path = os.path.join(dataset, ".hydra/config.yaml")
cfg = OmegaConf.load(config_path)

# A node "reaches" a config when its cost to that config's root is below this.
ERROR_THRESH = cfg.RRT.min_cost

# Start/end config to inspect; -1 means "no restriction" in the cells below.
# END_IDX_OVERRIDE replaces cfg.RRT.end_idx for look_at_specific_end_idx;
# set it to None to use the configured end_idx instead.
END_IDX_OVERRIDE = -1

look_at_specific_start_idx = cfg.RRT.start_idx
look_at_specific_end_idx = cfg.RRT.end_idx if END_IDX_OVERRIDE is None else END_IDX_OVERRIDE

q_mask = np.array(cfg.RRT.q_mask)
sim_cfg = cfg.RRT.sim

mujoco_xml = os.path.join("..", sim_cfg.mujoco_xml)

print(f"Looking at start_idx {look_at_specific_start_idx} and end_idx {look_at_specific_end_idx} with error threshold {ERROR_THRESH}.")

## Load data

In [None]:
tree_dataset = os.path.join(dataset, "trees")
trees, tree_count, total_nodes_count = load_trees(tree_dataset, 1)

# An empty q_mask from the config is replaced by an all-ones mask shaped like
# the root node's state[1] (i.e. no component is masked out).
if q_mask.shape[0] == 0:
    root_state = trees[0][0]["state"][1]
    q_mask = np.ones_like(root_state)

print("Loaded ", total_nodes_count, " RRT nodes.")

## Cost evolution over nodes

In [None]:
if look_at_specific_start_idx != -1:
    si = look_at_specific_start_idx

    # costs_over_time[i][t]: lowest cost to config t's root among the first
    # i+1 nodes of tree si (a running minimum per target).
    costs_over_time = []
    for i, node in tqdm(enumerate(trees[si]), total=len(trees[si])):
        costs = [cost_computation(trees[t][0], node, q_mask) for t in range(tree_count)]
        if i > 0:
            costs = [min(prev, c) for prev, c in zip(costs_over_time[-1], costs)]
        costs_over_time.append(costs)

    # Each row also contains the cost of tree si against its own root, so it is
    # excluded from both the mean (len-1) and the found-path count (-1).
    n_others = tree_count - 1
    mean_cost_over_time = [
        sum(costs) / n_others if n_others else float("nan") for costs in costs_over_time
    ]
    found_paths_over_time = [len([1 for c in costs if c < ERROR_THRESH]) - 1 for costs in costs_over_time]

    fig, axes = plt.subplots(1, 2, figsize=(20, 8))

    axes[0].set_title("Avg. Cost Over Time")
    axes[0].plot(mean_cost_over_time, label="Mean Cost")

    # Uses the configured target (cfg.RRT.end_idx), independent of any
    # override applied to look_at_specific_end_idx.
    if cfg.RRT.end_idx != -1:
        cost_for_target_over_time = [costs[cfg.RRT.end_idx] for costs in costs_over_time]
        axes[0].plot(cost_for_target_over_time, label=f"Target {cfg.RRT.end_idx}")

    axes[0].axhline(y=ERROR_THRESH, color="red", linestyle="--", label="Success Threshold")
    axes[0].legend()

    axes[1].set_title(f"Found Paths Over Time for Start Node {look_at_specific_start_idx}")
    axes[1].plot(found_paths_over_time)

    plt.tight_layout()
    plt.show()

## Costs for all config pairs (without clustering)

In [None]:
# For every tree `src`, find the node that comes closest to each config's root.
top_nodes = []  # top_nodes[src][tgt]: node index in trees[src] with lowest cost to config tgt
min_costs = []  # min_costs[src][tgt]: that lowest cost
for src in tqdm(range(tree_count)):
    best_cost = [float("inf")] * tree_count
    best_node = [-1] * tree_count

    for node_idx, node in enumerate(trees[src]):
        for tgt in range(tree_count):
            c = cost_computation(trees[tgt][0], node, q_mask)
            # Strict '<' keeps the earliest node on ties.
            if c < best_cost[tgt]:
                best_cost[tgt] = c
                best_node[tgt] = node_idx

    top_nodes.append(best_node)
    min_costs.append(best_cost)

min_costs = np.array(min_costs)

AdjMap(min_costs, ERROR_THRESH, min_costs.max())

## Clustering

In [None]:
# Group configs into connected components: configs i and j are linked when some
# node of tree i reaches config j within ERROR_THRESH (min_costs[i][j]); links
# are treated as undirected when merging groups.
colors = [-1 for _ in range(tree_count)]

max_color_idx = 0
for i in range(tree_count):

    if colors[i] == -1:
        colors[i] = max_color_idx
        max_color_idx += 1

    # Process i's links even if i was already colored via an earlier config;
    # skipping them would split transitively connected configs into
    # different groups.
    for j in range(tree_count):
        if i == j or min_costs[i][j] > ERROR_THRESH:
            continue
        if colors[j] == -1:
            colors[j] = colors[i]
        elif colors[j] != colors[i]:
            # Merge j's whole group into i's group.
            old_color, new_color = colors[j], colors[i]
            colors = [new_color if c == old_color else c for c in colors]

# Distinct group labels in order of first appearance, and their sizes.
groups = list(dict.fromkeys(colors))
group_sizes = [colors.count(g) for g in groups]

print("Group Count: ", len(groups))
print("Group Sizes: ", group_sizes)

## Looking at specific start and end configurations

In [None]:
start_idx = look_at_specific_start_idx
end_idx = look_at_specific_end_idx

if look_at_specific_start_idx != -1:
    # Lowest cost reached from start_idx's tree to every other config (self excluded).
    other_idxs = [i for i in range(tree_count) if i != start_idx]
    costs = [min_costs[start_idx][i] for i in other_idxs]
    labels = [str(i) for i in other_idxs]

    # Average over the configs actually in `costs` (tree_count - 1), not tree_count.
    if costs:
        print(f"Mean costs for start config {start_idx}: {sum(costs)/len(costs)}, (Min: {min(costs)}, Max {max(costs)})")

    # Named bar_colors so the cluster labels stored in `colors` are not overwritten.
    highlight_labels = {str(end_idx)}
    bar_colors = ["red" if label in highlight_labels else "blue" for label in labels]
    plt.figure(figsize=(25, 5))
    plt.bar(labels, costs, color=bar_colors)
    plt.xticks(rotation=90, ha="right")
    plt.axhline(y=ERROR_THRESH, color="red", linestyle="--")
    plt.title(f"Costs to reach each end config from start config {start_idx}")
    plt.xlabel("Config idx")
    plt.ylabel("Cost")
    plt.show()

    if end_idx != -1:
        # Index min_costs directly: `costs` omits start_idx, so costs[end_idx]
        # would be shifted by one whenever end_idx > start_idx.
        print(f"Cost for target {end_idx} with start {start_idx}: {min_costs[start_idx][end_idx]}")

## Collecting top paths

In [None]:
# Select (start, end) config pairs to reconstruct: every distinct pair reached
# within ERROR_THRESH, restricted to the specific start/end when set. When both
# a specific start and end are set, that pair is kept regardless of its cost.
both_specified = look_at_specific_start_idx != -1 and look_at_specific_end_idx != -1

top_paths_data = []  # ((start_idx, end_idx), index of best node in trees[start_idx])
top_costs = []       # min cost of each entry in top_paths_data (same order)
for i in range(tree_count):
    for j in range(tree_count):
        if i == j:
            continue
        if look_at_specific_start_idx != -1 and i != look_at_specific_start_idx:
            continue
        if look_at_specific_end_idx != -1 and j != look_at_specific_end_idx:
            continue

        if both_specified or min_costs[i][j] < ERROR_THRESH:
            top_paths_data.append(((i, j), top_nodes[i][j]))
            top_costs.append(min_costs[i][j])

# Reconstruct each path by following parent links from the best node to the root.
top_paths = []
top_paths_start = []
top_paths_goal = []
for (start_idx, end_idx), best_node_idx in top_paths_data:
    tree = trees[start_idx]

    node = tree[best_node_idx]
    path = [node]
    while node["parent"] != -1:
        node = tree[node["parent"]]
        path.append(node)
    path.reverse()
    # Identity check: comparing node dicts with == may compare array values.
    assert path[0] is tree[0]

    top_paths.append(path)
    top_paths_start.append(start_idx)
    top_paths_goal.append(end_idx)

# Per path: percentage of nodes whose target_config_idx equals the path's goal.
target_counts = [
    sum(1 for node in path if goal_idx == node["target_config_idx"])
    for path, goal_idx in zip(top_paths, top_paths_goal)
]
percs = sorted(
    (float(np.round(c / len(path) * 100)) for c, path in zip(target_counts, top_paths)),
    reverse=True,
)

possible_paths = tree_count**2 - tree_count
print("Top costs: ", [float(c) for c in top_costs])
print("Found Trajectories Count: ", len(top_paths), " of ", possible_paths)
if look_at_specific_start_idx == -1:
    print("When considering full graph: ", sum([v**2 for v in group_sizes]) - tree_count, " of ", possible_paths)

# path_lens stays in top_paths order (indices match top_paths); only a sorted
# copy is printed. Defined unconditionally so it exists even with no paths.
path_lens = [len(p) for p in top_paths]

if not top_paths:
    print("No trajectories found!")

else:
    print("Percentage of reached config used as target: ", percs)
    print("Avg. use of reached config as target: ", sum(percs)/len(percs))

    print("Path lengths: ", sorted(path_lens, reverse=True))
    print("Avg. Path length: ", sum(path_lens)/len(path_lens))

## Sample a single path

In [None]:
path = []
min_path_len = 2

# Pick candidates straight from top_paths so the sampled index is valid for
# top_paths, top_paths_start and top_paths_goal alike (rather than from
# path_lens, whose order and existence depend on the previous cell).
valid_path_idxs = [i for i, p in enumerate(top_paths) if len(p) >= min_path_len]

if not valid_path_idxs:
    print("No good paths to analyse!")

else:
    path_idx = np.random.choice(valid_path_idxs)
    path = top_paths[path_idx]
    start_idx = top_paths_start[path_idx]
    end_idx = top_paths_goal[path_idx]
    # Residual between the goal config's root state and the last state on the path.
    e_vec = trees[end_idx][0]["state"][1] - path[-1]["state"][1]

    print("---- Sampled Path Data ----")
    print("Target config ids: ", [n["target_config_idx"] for n in path])
    print("Start idx: ", start_idx)
    print("End idx: ", end_idx)
    print("Cost: ", top_costs[path_idx])
    print("Error vec: ", e_vec)
    print("Error vec (with mask): ", e_vec * q_mask)
    print("Sampled Path Length: ", len(path))

In [None]:
if path:
    # Root states of the sampled path's start and goal configs.
    start_state = trees[start_idx][0]["state"][1]
    target_state = trees[end_idx][0]["state"][1]

    # Viewer-enabled simulator configured from the run's sim settings.
    sim = MjSim(
        mujoco_xml,
        tau_sim=sim_cfg.tau_sim,
        interpolate=sim_cfg.interpolate_actions,
        joints_are_same_as_ctrl=sim_cfg.joints_are_same_as_ctrl,
        view=True,
    )

    # A copy is passed, presumably so play_path cannot alter the sampled path.
    play_path(
        start_state,
        target_state,
        path.copy(),
        sim,
        tau_action=sim_cfg.tau_action,
    )

## Bi-directional analysis

In [None]:
if cfg.RRT.bidirectional and cfg.RRT.start_idx != -1 and cfg.RRT.end_idx != -1:
    si = cfg.RRT.start_idx
    ei = cfg.RRT.end_idx

    bi_tree_dataset = os.path.join(dataset, "trees")
    # NOTE(review): this is the same directory as tree_dataset, loaded without
    # the second argument used in the "Load data" cell — confirm this selects
    # the backward (bi-directional) trees.
    bi_trees, bi_tree_count, bi_total_nodes_count = load_trees(bi_tree_dataset)
    assert bi_tree_count == tree_count

    print(f"Starting bi-directional analysis for trees with sizes {len(trees[si])} and {len(bi_trees[ei])}...")

    # Connections to the target root without the bi-tree. Argument order
    # (target, node) matches the other cost_computation calls in this notebook.
    target_root = trees[ei][0]
    found_direct_nodes_ids = [
        node_id
        for node_id, node in enumerate(trees[si])
        if cost_computation(target_root, node, q_mask) <= ERROR_THRESH
    ]

    # Connections from any forward node to any node of the bi-tree rooted at ei.
    found_bi_nodes_ids = []
    for node_id, node in tqdm(enumerate(trees[si]), total=len(trees[si])):
        for bi_id, bi_node in enumerate(bi_trees[ei]):
            if cost_computation(bi_node, node, q_mask) <= ERROR_THRESH:
                found_bi_nodes_ids.append((node_id, bi_id))

    print(f"------ Bi-Tree status for path from config {si} to {ei} ------")
    print(f"Node count in bi-tree: {len(bi_trees[ei])}")
    print(f"Found paths without bi-tree: {len(found_direct_nodes_ids)}")
    print(f"Found paths with bi-tree: {len(found_bi_nodes_ids)}")

    # TODO: Play path
    # TODO: Glue path with BBO