In [None]:
import sys as _sys
import os

# Locate the repository root from the current working directory: everything in
# front of the first occurrence of the repository name is taken as the parent
# directory of the checkout.
repo_name = "inverse_geometric_locomotion"
current_path = os.path.abspath(os.getcwd())

split = current_path.split(repo_name)
if len(split) < 2:
    # Carry the hint in the exception itself so it shows up in the traceback
    # instead of being printed separately on stdout.
    raise ValueError("Please rename the repository 'inverse_geometric_locomotion'")
path_to_repo = os.path.join(split[0], repo_name)

# Directory paths below end with a separator (the trailing "" component).
path_to_python_scripts = os.path.join(path_to_repo, "python", "")
path_to_notifications = os.path.join(path_to_repo, "notebooks", "notifications", "")
path_to_settings = os.path.join(path_to_repo, "python", "figures", "")
path_to_cubic_splines = os.path.join(path_to_repo, "ext", "torchcubicspline", "")
path_to_output = os.path.join(path_to_repo, "output", "")
path_to_data = path_to_output  # experiment files are read from the output folder
path_to_save = os.path.join(path_to_output, "snake_ff_obstacle")

# Make the project modules importable. Each insert goes to the front, so the
# last path inserted ends up with the highest import priority.
for module_dir in (path_to_python_scripts, path_to_settings, path_to_cubic_splines):
    _sys.path.insert(0, module_dir)

In [None]:
import json
import matplotlib.pyplot as plt
import numpy as np
import shutil
import torch

from obstacle_implicits import SphereSquareImplicit
from vis_utils import produce_video_from_path, print_json_data
from vis_utils_snake import plot_animated_snake

In [None]:
# Experiment result files, given relative to `path_to_data`.
exp_file_names = [
    "snake_ff_obstacle/snake_ff_obstacle_opt_00.json",
    "snake_ff_obstacle/snake_ff_obstacle_opt_01.json",
]

# The trial number is the integer sitting between the last underscore of the
# file name and the first dot that follows it.
trial_numbers = []
for file_name in exp_file_names:
    trailing_part = file_name.rsplit('_', 1)[-1]
    trial_numbers.append(int(trailing_part.split('.', 1)[0]))
exp_names = list(map("snake_ff_obstacle_{:02d}".format, trial_numbers))

# Parse every experiment file and keep the resulting dictionaries in file order.
list_js_loads = []
for exp_file_name in exp_file_names:
    json_path = os.path.join(path_to_data, exp_file_name)
    with open(json_path) as jsonFile:
        js_load = json.load(jsonFile)

    # Echo the file name followed by a summary of its content.
    print(exp_file_name)
    print_json_data(js_load)
    list_js_loads.append(js_load)

In [None]:
# Gather the recorded objective values of every loaded experiment.
# (The trajectory array is not needed here, so it is not converted.)
objectives = [js["optimization_evolution"]["obj_values"] for js in list_js_loads]

fig = plt.figure(figsize=(6, 3))
gs = fig.add_gridspec(1, 1)
ax_tmp = fig.add_subplot(gs[0, 0])
# One curve per experiment, labelled with the experiment name so the curves can
# be told apart in the legend.
for exp_label, obj in zip(exp_names, objectives):
    ax_tmp.plot(np.array(obj), lw=3.0, zorder=0, label=exp_label)
ax_tmp.set_title("Objectives")
# Assumes one objective value is stored per optimization iteration - TODO confirm.
ax_tmp.set_xlabel("Optimization iteration")
ax_tmp.set_ylabel("Objective value")
ax_tmp.set_yscale("log")
ax_tmp.legend()
plt.show()


In [None]:
# Select the experiment shown in the static figure and unpack its content.
js_load = list_js_loads[1]
opt_settings = js_load['optimization_settings']

pos = np.array(js_load['pos'])
n_ts = pos.shape[0]  # size of the leading axis of `pos`
g = np.array(js_load['g'])
gt = np.array(opt_settings['gt'])
params_opt = np.array(opt_settings['params_opt'])
snake_length = opt_settings['snake_length']
obstacle_params = torch.tensor(opt_settings['obstacle_params'])
obstacle = SphereSquareImplicit(obstacle_params)

# Sample points on the circle outline drawn on top of the obstacle.
# The first two obstacle parameters are used as the circle centre and the last
# one as its radius. They are converted to plain floats so the arithmetic below
# stays pure NumPy instead of mixing torch tensors with NumPy arrays.
circle_center_x = float(obstacle_params[0])
circle_center_y = float(obstacle_params[1])
circle_radius = float(obstacle_params[-1])
ts_circle = np.linspace(0, 2 * np.pi, 100)
circle_x = circle_radius * np.cos(ts_circle) + circle_center_x
circle_y = circle_radius * np.sin(ts_circle) + circle_center_y

In [None]:
# Later time steps are drawn more opaque than earlier ones.
alphas = np.linspace(0.1, 1.0, n_ts)
fig = plt.figure(figsize=(10, 6))
gs = fig.add_gridspec(1, 1)
ax_tmp = fig.add_subplot(gs[0, 0])
for id_step in range(n_ts):
    ax_tmp.plot(pos[id_step, :, 0], pos[id_step, :, 1], lw=3.0, c='tab:blue', alpha=alphas[id_step], zorder=0)
    # Entries 4 and 5 of g are used as the x/y position of the marker.
    ax_tmp.scatter(g[id_step, 4], g[id_step, 5], marker='x', s=30.0, c='tab:blue', alpha=alphas[id_step], zorder=0)
# Highlight the initial pose. It does not depend on the loop index, so it is
# drawn a single time, fully opaque, instead of being re-plotted at every step.
if n_ts > 0:
    ax_tmp.plot(pos[0, :, 0], pos[0, :, 1], lw=3.0, c='tab:blue', alpha=1.0, zorder=0)
ax_tmp.scatter(gt[4], gt[5], marker='o', s=30.0, c='tab:orange', alpha=1.0, zorder=1)

# Hand matplotlib a plain float rather than a 0-d torch tensor.
obstacle_radius = float(obstacle_params[-1])
ax_tmp.set_xlim(-snake_length, gt[4] + snake_length)
ax_tmp.set_ylim(-1.5 * obstacle_radius, 1.5 * obstacle_radius)
# Row 0 holds the x limits, row 1 the y limits.
xy_lim = np.zeros(shape=(2, 2))
xy_lim[0] = ax_tmp.get_xlim()
xy_lim[1] = ax_tmp.get_ylim()

# Regular grid covering the visible area on which the implicit function is evaluated.
n_plot = 1000
n_levels = 15
x_min, x_max = ax_tmp.get_xlim()
y_min, y_max = ax_tmp.get_ylim()
x_plot = torch.linspace(x_min, x_max, n_plot)
y_plot = torch.linspace(y_min, y_max, n_plot)
print(ax_tmp.get_xlim())
print(ax_tmp.get_ylim())
# x varies fastest and y slowest, so reshaping to (n_plot, n_plot) below gives
# rows indexed by y and columns indexed by x, which is the layout contourf expects.
xyz_plot = torch.stack([
    torch.tile(x_plot, dims=(n_plot,)),
    torch.repeat_interleave(y_plot, repeats=n_plot, dim=0),
    torch.zeros(size=(n_plot*n_plot,))
], dim=1)

sdfs = obstacle.evaluate_implicit_function(xyz_plot).reshape(n_plot, n_plot)

# Convert to NumPy explicitly before handing the data to NumPy/matplotlib.
sdfs_np = sdfs.detach().cpu().numpy()
# Symmetric levels around zero keep the zero level at the centre of the colormap.
max_abs_sdf = float(np.abs(sdfs_np).max())
levels = np.linspace(-max_abs_sdf, max_abs_sdf, n_levels)
ax_tmp.contourf(x_plot.numpy(), y_plot.numpy(), sdfs_np, levels=levels, cmap='coolwarm', zorder=-2)

ax_tmp.plot(circle_x, circle_y, lw=3.0, c='k', zorder=0)

ax_tmp.set_aspect('equal')
ax_tmp.axis('off')
fig.savefig(os.path.join(path_to_save, "sdf_with_snake.png"), dpi=300, bbox_inches='tight')
plt.show()

In [None]:
# Experiment to animate.
exp_id = 0
js_load = list_js_loads[exp_id]
exp_name = exp_names[exp_id]
pos_plot = np.array(js_load['pos'])
g_plot = np.array(js_load['g'])
# Read the target and the obstacle from the experiment being animated, so the
# frames do not silently reuse values left over from another experiment.
settings_plot = js_load['optimization_settings']
gt_plot = np.array(settings_plot['gt'])
obstacle_plot = SphereSquareImplicit(torch.tensor(settings_plot['obstacle_params']))
broken_joint_ids = []

path_to_images_anim = os.path.join(path_to_save, "images_registered")
arrow_params = {
    "length": 0.05,
    "width": 0.02,
}

# clear existing images
if os.path.exists(path_to_images_anim):
    shutil.rmtree(path_to_images_anim)
os.makedirs(path_to_images_anim)

# NOTE: `xy_lim` is the axis range captured from the static SDF figure, so that
# figure's cell has to be executed before this one.
plot_animated_snake(
    pos_plot, path_to_images_anim,
    g=g_plot, gt=gt_plot, gcp=None,
    broken_joint_ids=broken_joint_ids, obstacle=obstacle_plot,
    exponent=1.0, xy_lim=xy_lim,
    show_orientation=False, show_snake_trail=False,
    show_g_trail=True, show_g_start=True, show_joints=False,
    arrow_params=arrow_params,
)

In [None]:
# printf-style pattern pointing at the numbered frames inside the animation
# image folder (presumably the naming used when the frames were written - verify).
fn_pattern = os.path.join(path_to_images_anim, "step_%05d.png")
video_file_name = "{}_pos.mp4".format(exp_name)

# Build the video for the selected experiment in the save folder.
produce_video_from_path(
    fn_pattern,
    path_to_save,
    video_file_name,
    overwrite_anim=True,
    transparent=False,
)