In [None]:
import sys as _sys
import os

# Locate the repository root from the current working directory and derive every
# project path from it. Everything before the first occurrence of the repository
# name in the cwd is treated as the repository's parent directory (substring match,
# so the notebook must be run from somewhere inside the repository).
REPO_NAME = "inverse_geometric_locomotion"
current_path = os.path.abspath(os.getcwd())

split = current_path.split(REPO_NAME)
if len(split) < 2:
    raise ValueError(
        "Could not find '{}' in the current directory '{}'. Run this notebook from "
        "inside the repository, and make sure the repository folder is named "
        "'inverse_geometric_locomotion'.".format(REPO_NAME, current_path)
    )
repo_root = os.path.join(split[0], REPO_NAME)

path_to_python_scripts = os.path.join(repo_root, "python/")
path_to_notifications = os.path.join(repo_root, "notebooks/notifications/")
path_to_settings = os.path.join(repo_root, "python/figures/")
path_to_cubic_splines = os.path.join(repo_root, "ext/torchcubicspline/")
path_to_output = os.path.join(repo_root, "output/")
path_to_data = path_to_output
path_to_save = os.path.join(path_to_output, "snake_teaser")

# Idempotent: re-running the cell is fine if the directory already exists.
os.makedirs(path_to_save, exist_ok=True)

# Make the project modules and the vendored torchcubicspline importable.
_sys.path.insert(0, path_to_python_scripts)
_sys.path.insert(0, path_to_settings)
_sys.path.insert(0, path_to_cubic_splines)

In [None]:
import json
import matplotlib.pyplot as plt
import numpy as np
import shutil
import torch

from torchcubicspline import (
    natural_cubic_spline_coeffs, NaturalCubicSpline
)

from obstacle_implicits import generate_siggraph_implicit, TranslateImplicit
from vis_utils import produce_video_from_path, print_json_data, plot_sdf_2d
from vis_utils_snake import plot_animated_snake

In [None]:
# Optimization result files (relative to ``path_to_data``) to analyse.
exp_file_names = [
    "snake_teaser/snake_teaser_opt_00.json",
]

# The trial number is the integer between the last underscore and the first dot
# of that segment, e.g. "..._opt_00.json" -> 0.
trial_numbers = [int(fn.rsplit('_', 1)[-1].partition('.')[0]) for fn in exp_file_names]
exp_names = ["snake_teaser_{:02d}".format(trial_number) for trial_number in trial_numbers]

# Parse each file, echo a summary of its contents and keep the parsed dict.
list_js_loads = []
for exp_file_name in exp_file_names:
    json_path = os.path.join(path_to_data, exp_file_name)
    with open(json_path) as json_file:
        js_load = json.load(json_file)
    print(exp_file_name)
    print_json_data(js_load)
    list_js_loads.append(js_load)

In [None]:
# For every loaded experiment, collect the objective history and the optimized
# wavelength profile evaluated at every time step, then plot both.
xyz_objectives = []
list_wavelengths = []
for js in list_js_loads:
    n_ts = len(js['pos'])  # number of stored time steps
    objectives = torch.tensor(js["optimization_evolution"]["obj_values"])
    # Rows are (iteration index, objective value, 0); only column 1 is plotted below.
    xyz_tmp = torch.zeros(size=(objectives.shape[0], 3))
    xyz_tmp[:, 0] = torch.arange(objectives.shape[0])
    xyz_tmp[:, 1] = objectives
    xyz_objectives.append(xyz_tmp.tolist())

    # The flat parameter vector is reshaped into rows of 3 control-point values.
    # A natural cubic spline through them, evaluated at every time step, is
    # sampled, and column 2 is kept as the wavelength.
    cps_opt = torch.tensor(js['optimization_settings']['params_opt']).reshape(-1, 3)
    ts = torch.linspace(0.0, 1.0, n_ts)
    ts_cp = torch.linspace(0.0, 1.0, cps_opt.shape[0])
    spline = NaturalCubicSpline(natural_cubic_spline_coeffs(ts_cp, cps_opt))
    list_wavelengths.append(spline.evaluate(ts)[:, 2].numpy())

fig, ax_tmp = plt.subplots(figsize=(6, 3))
for xyz_tmp in xyz_objectives:
    ax_tmp.plot(np.array(xyz_tmp)[:, 1], lw=3.0, zorder=0)
ax_tmp.set_title("Objectives")
ax_tmp.set_xlabel("Iteration")
ax_tmp.set_ylabel("Objective")
ax_tmp.set_yscale("log")
plt.show()

fig, ax_tmp = plt.subplots(figsize=(6, 3))
for wl in list_wavelengths:
    ax_tmp.plot(wl, lw=3.0, zorder=0)
ax_tmp.set_title("Optimized wavelengths")
ax_tmp.set_xlabel("Time step")
ax_tmp.set_ylabel("Wavelength")
plt.show()

In [None]:
# Select the experiment explicitly rather than relying on ``js_load`` leaking out of
# the loading loop (that loop leaves it bound to the last loaded file).
js_load = list_js_loads[-1]
opt_settings = js_load['optimization_settings']

pos = np.array(js_load['pos'])
n_ts = pos.shape[0]
g = np.array(js_load['g'])
gt = np.array(opt_settings['gt'])
params_opt = np.array(opt_settings['params_opt'])
angle_rot = opt_settings['angle_rot']
x_obstacle_offset = opt_settings['x_obstacle_offset']
x_last_target_offset = opt_settings['x_last_target_offset']
translate_implicit = torch.tensor(opt_settings['translate_implicit'])
scale_implicit = torch.tensor(opt_settings['scale_implicit'])

# Build the obstacle, then shift it along x by the configured offset (computed once,
# after the obstacle is generated).
siggraph_obstacle = generate_siggraph_implicit(angle_rotation=angle_rot, translation=translate_implicit, scale=scale_implicit)
translation_siggraph = torch.tensor([- translate_implicit[0] + scale_implicit[0] + x_obstacle_offset, 0.0, 0.0])
obstacle = TranslateImplicit(siggraph_obstacle, translation_siggraph)

# Trajectory: later steps are drawn more opaque. Columns 4 and 5 of g / gt are
# plotted as x / y markers.
alphas = np.linspace(0.1, 1.0, n_ts)
fig, ax_tmp = plt.subplots(figsize=(10, 6))
for id_step in range(n_ts):
    ax_tmp.plot(pos[id_step, :, 0], pos[id_step, :, 1], lw=3.0, c='tab:blue', alpha=alphas[id_step], zorder=0)
    ax_tmp.scatter(g[id_step, 4], g[id_step, 5], marker='x', s=30.0, c='tab:blue', alpha=alphas[id_step], zorder=0)
# Initial configuration, drawn once on top of the trajectory. Stacking one copy per
# step with alphas[i] composites to exactly alphas[-1], so a single draw looks the same.
ax_tmp.plot(pos[0, :, 0], pos[0, :, 1], lw=3.0, c='tab:blue', alpha=alphas[-1], zorder=0)
ax_tmp.scatter(gt[:, 4], gt[:, 5], marker='o', s=30.0, c='tab:orange', alpha=1.0, zorder=1)

# Sample the obstacle's implicit function on a regular grid covering the current view.
n_plot = 1000
n_levels = 15
x_plot = torch.linspace(ax_tmp.get_xlim()[0], ax_tmp.get_xlim()[1], n_plot)
y_plot = torch.linspace(ax_tmp.get_ylim()[0], ax_tmp.get_ylim()[1], n_plot)
# Flattened grid with x varying fastest, so reshape(n_plot, n_plot) gives rows indexed
# by y and columns by x, which is the layout contourf expects for (x_plot, y_plot).
xyz_plot = torch.stack([
    torch.tile(x_plot, dims=(n_plot,)),
    torch.repeat_interleave(y_plot, repeats=n_plot, dim=0),
    torch.zeros(size=(n_plot*n_plot,))
], dim=1)

# detach() so matplotlib can convert the values even if they carry autograd history.
sdfs = obstacle.evaluate_implicit_function(xyz_plot).reshape(n_plot, n_plot).detach()

# Symmetric levels keep the diverging colormap centred on the zero level set.
max_abs_sdf = torch.abs(sdfs).max().item()
levels = np.linspace(-max_abs_sdf, max_abs_sdf, n_levels)
ax_tmp.contourf(x_plot, y_plot, sdfs, levels=levels, cmap='coolwarm', zorder=-2)
ax_tmp.contour(x_plot, y_plot, sdfs, levels=[0.0], colors='k', linewidths=3.0, zorder=-1)

ax_tmp.set_aspect('equal')
# Remember the view so the animation frames below use the same limits.
xy_lim = np.zeros(shape=(2, 2))
xy_lim[0] = ax_tmp.get_xlim()
xy_lim[1] = ax_tmp.get_ylim()
ax_tmp.axis('off')
fig.savefig(os.path.join(path_to_save, "teaser_sdf_with_snake.png"), dpi=300, bbox_inches='tight')
plt.show()

plot_sdf_2d(x_plot, y_plot, sdfs, n_levels=n_levels, show_text=False, filename=os.path.join(path_to_save, "teaser_sdf.png"))

# Raw optimized parameters: every third entry, starting at index 2.
fig, ax_tmp = plt.subplots(figsize=(6, 3))
ax_tmp.plot(params_opt[2::3], lw=3.0, c='tab:blue', zorder=0)
ax_tmp.set_title("Optimized wavelengths")
plt.show()


In [None]:
# Render every time step of the selected experiment to PNG frames for the animation.
exp_id = 0
js_load = list_js_loads[exp_id]
exp_name = exp_names[exp_id]
pos_plot = np.array(js_load['pos'])
g_plot = np.array(js_load['g'])
# Read the targets from the selected experiment instead of reusing the global ``gt``
# from the previous cell, which need not correspond to ``exp_id``.
gt_plot = np.array(js_load['optimization_settings']['gt'])
broken_joint_ids = []

path_to_images_anim = os.path.join(path_to_save, "images_registered")
arrow_params = {
    "length": 0.05,
    "width": 0.02,
}

# Clear existing images so that only frames from this run remain in the folder.
if os.path.exists(path_to_images_anim):
    shutil.rmtree(path_to_images_anim)
os.makedirs(path_to_images_anim)

# NOTE(review): ``obstacle`` and ``xy_lim`` still come from the previous cell. They
# match this experiment only if that cell used the same entry of ``list_js_loads``.
plot_animated_snake(
    pos_plot, path_to_images_anim,
    g=g_plot, gt=gt_plot[-1], gcp=gt_plot[:-1],
    broken_joint_ids=broken_joint_ids, obstacle=obstacle,
    exponent=1.0, xy_lim=xy_lim,
    show_orientation=False, show_snake_trail=False,
    show_g_trail=True, show_g_start=True, show_joints=False,
    arrow_params=arrow_params,
)

In [None]:
# Assemble the numbered frames written by the previous cell into an mp4 that is
# saved alongside the other teaser outputs, replacing any existing video.
fn_pattern = os.path.join(path_to_images_anim, "step_%05d.png")
video_name = "{}_pos.mp4".format(exp_name)
produce_video_from_path(
    fn_pattern,
    path_to_save,
    video_name,
    overwrite_anim=True,
    transparent=False,
)