In [None]:
import torch
import numpy as np
import plotly.graph_objects as go
import plotly.express as px

from nemo.global_planner import AStarGradPlanner
from nemo.nemo import Nemo
from nemo.util import wrap_angle_torch, path_metrics
from nemo.plotting import plot_surface, plot_path_3d
from nemo.planning import path_optimization

# Run on the first CUDA GPU when one is visible to torch, otherwise on the CPU.
device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")

%load_ext autoreload
%autoreload 2

## Load the heightnet

In [None]:
# Weight files handed to Nemo.load_weights (KT-22 terrain).
# Red Rocks alternative:
#nemo.load_weights('../models/redrocks_encs_relu.pth', '../models/redrocks_heightnet_relu.pth')
encs_path, heightnet_path = '../models/kt22_encs.pth', '../models/kt22_heightnet.pth'

nemo = Nemo()
nemo.load_weights(encs_path, heightnet_path)

In [None]:
# Number of samples along each axis of the square query grid
N = 64
# bounds = (xmin, xmax, ymin, ymax)
#bounds = (-0.3, 0.8, -0.45, 0.5) # red rocks
bounds = (-1, 1, -1, 1) # kt22
x_min, x_max, y_min, y_max = bounds

xs = torch.linspace(x_min, x_max, N, device=device)
ys = torch.linspace(y_min, y_max, N, device=device)

# (N, N, 2) grid of (x, y) query points, flattened to (N*N, 2) for the model
grid_x, grid_y = torch.meshgrid(xs, ys, indexing='xy')
XY_grid = torch.stack((grid_x, grid_y), dim=-1)
positions = XY_grid.reshape(-1, 2)

# Model height at every grid point
heights = nemo.get_heights(positions)

In [None]:
# Bring the sampled grid over to NumPy for Plotly
z_grid = heights.detach().cpu().numpy().reshape(N, N)
x_grid = XY_grid[..., 0].detach().cpu().numpy()
y_grid = XY_grid[..., 1].detach().cpu().numpy()

# Bare surface (no axes, no colour bar) at 1600x900
fig = plot_surface(x_grid, y_grid, z_grid, no_axes=True, showscale=False)
fig.update_layout(width=1600, height=900)
fig.show()

In [None]:
# Same surface as the previous cell, rendered taller (1600x1000). The NumPy grids
# (x_grid, y_grid, z_grid) built in the previous cell are reused instead of being
# recomputed from identical inputs. The `fig` created here is the one the
# animation cell below modifies.
fig = plot_surface(x_grid, y_grid, z_grid, no_axes=True, showscale=False)
fig.update_layout(width=1600, height=1000)
fig.show()

In [None]:
# Starting camera position for the 3D scene
x_eye, y_eye, z_eye = 1.5, -1, 1

# Arguments passed to the 'animate' method when the "Play" button is pressed
play_args = [
    None,
    dict(
        frame=dict(duration=5, redraw=True),
        transition=dict(duration=0),
        fromcurrent=True,
        mode='immediate',
    ),
]
play_button = dict(label='Play', method='animate', args=play_args)

# Button group holding the single "Play" control
play_menu = dict(
    type='buttons',
    showactive=False,
    y=1,
    x=0.8,
    xanchor='left',
    yanchor='bottom',
    pad=dict(t=45, r=10),
    buttons=[play_button],
)

fig.update_layout(
    title='Animation Test',
    width=1600,
    height=900,
    scene_camera_eye=dict(x=x_eye, y=y_eye, z=z_eye),
    updatemenus=[play_menu],
)


def rotate_z(x, y, z, theta):
    """Rotate the point (x, y, z) about the z-axis by `theta` radians.

    (x, y) is treated as the complex number x + iy and multiplied by
    e^{i*theta}; z is passed through unchanged.

    Returns:
        Tuple (x_rotated, y_rotated, z).
    """
    rotated = np.exp(1j*theta) * (x + 1j*y)
    return np.real(rotated), np.imag(rotated), z

# One animation frame per camera angle: the eye position is rotated about the
# z-axis in 0.025 rad steps over the range np.arange(0, 6.26, 0.025).
frames = [
    go.Frame(layout=dict(scene_camera_eye=dict(
        zip(('x', 'y', 'z'), rotate_z(x_eye, y_eye, z_eye, angle))
    )))
    for angle in np.arange(0, 6.26, 0.025)
]
fig.frames = frames

fig.show()

## A*

In [None]:
# Initialize the planner with a scaled heightmap.
# The scaled map gets its own name so that the `heights` tensor produced by the
# grid-sampling cell is not overwritten with a NumPy array; the surface-plot
# cells call `heights.reshape(...).detach()` and would fail when re-run otherwise.
# z_grid is already (N, N), so no reshape is needed.
astar_heightmap = 1e4 * (z_grid + 1.0)
gp = AStarGradPlanner(astar_heightmap, bounds)

# Start and end positions for path
start = (0.7, 0.7)
end = (-0.7, -0.7)
# start = (0.32, -0.21)
# end = (-0.02, -0.07)

# Compute path
path_xy = gp.spatial_plan(start, end)
path_xy_torch = torch.tensor(path_xy, device=device)
# Get heights along path
path_zs = nemo.get_heights(path_xy_torch)

# Save path as torch tensor with columns (x, y, z)
astar_path = torch.cat((path_xy_torch, path_zs), dim=1)

In [None]:
# Overlay the A* waypoints (with their model heights) on the terrain surface
astar_zs_np = path_zs.detach().cpu().numpy().flatten()

fig = plot_surface(x_grid, y_grid, z_grid, no_axes=True)
fig = plot_path_3d(fig=fig, x=path_xy[:,0], y=path_xy[:,1], z=astar_zs_np)
fig.show()

In [None]:
# Metrics (nemo.util.path_metrics) for the raw A* path, given as (x, y, z) rows
path_metrics(astar_path)

## Path optimization

In [None]:
# Run nemo.planning.path_optimization starting from the A* (x, y) waypoints.
# The result is indexed as x / y / z columns by the plotting cell that follows.
path_3d = path_optimization(nemo, path_xy_torch, iterations=500, lr=1e-3)

In [None]:
# Optimized path drawn on the terrain, without markers
path_3d_np = path_3d.detach().cpu().numpy()

fig = plot_surface(x_grid, y_grid, z_grid, no_axes=True)
fig = plot_path_3d(
    fig=fig,
    x=path_3d_np[:, 0],
    y=path_3d_np[:, 1],
    z=path_3d_np[:, 2],
    markers=False,
)
fig.show()

In [None]:
# Same metrics for the optimized path, for comparison with the A* result above
path_metrics(path_3d)

### Dubins with $\theta$ optimization

In [None]:
# Initial heading of each segment, from consecutive waypoint differences
segment_dxy = torch.diff(path_xy_torch, dim=0)
thetas = torch.atan2(segment_dxy[:, 1], segment_dxy[:, 0])
# The last waypoint has no outgoing segment, so it reuses the previous heading
thetas = torch.cat((thetas, thetas[-1].reshape(1)), dim=0)

# One state per waypoint: (x, y, theta)
path = torch.cat((path_xy_torch, thetas[:, None]), dim=1)

# The first and last states stay fixed; only the interior states are optimized
path_start = path[0].detach().clone()
path_end = path[-1].detach().clone()
path_opt = path[1:-1].detach().clone().requires_grad_(True)

In [None]:
# Dubins-based cost
def cost(path, dt=1.0, verbose=True):
    """Control-effort plus slope cost of a path of (x, y, theta) states.

    Args:
        path: tensor whose rows are states; columns 0-1 are read as (x, y)
            and column 2 as the heading theta.
        dt: time step between consecutive states.
        verbose: when True (default, matching the previous behaviour) the two
            cost terms are printed on every call; pass False inside long
            optimization loops to keep the output readable.

    Returns:
        Scalar tensor: controls_cost + slope_cost.
    """
    thetas = path[:, 2]
    # Turn rate from wrapped heading differences
    omegas = wrap_angle_torch(thetas.diff()) / dt
    # Speed from planar displacement between consecutive states
    path_dxy = torch.diff(path[:, :2], dim=0)
    Vs = torch.norm(path_dxy, dim=1) / dt
    controls_cost = 0.1 * (torch.abs(Vs)).nanmean() + (torch.abs(omegas)).nanmean()

    # Slope cost. Only the (x, y) columns are passed to the height model, as in
    # every other get_heights call in this notebook (theta is not a position).
    path_zs = 10 * nemo.get_heights(path[:, :2])
    # Out-of-place on purpose: an in-place `-=` here would modify a tensor that
    # autograd saved for `min()`, which can raise at backward().
    path_zs = path_zs - path_zs.min()
    path_zs = path_zs**2
    slope_cost = 1 * (torch.abs(path_zs.diff(dim=0))).nanmean()

    if verbose:
        print(f"controls_cost: {controls_cost}, slope_cost: {slope_cost}")
    return controls_cost + slope_cost

In [None]:
# Sanity check on the scale of the slope-cost terms for the current path.
# Only the (x, y) columns go to the height model (as in the other get_heights
# calls), and the intermediate gets its own name so the `path_zs` used by the
# A* plotting cell is not overwritten.
slope_zs = 10 * nemo.get_heights(path[:, :2])
slope_zs = (slope_zs - slope_zs.min())**2
print(slope_zs.min(), slope_zs.max())
costs = torch.abs(slope_zs.diff(dim=0))
print(costs.min(), costs.max())

In [None]:
# Optimize path: Adam updates the interior states, the endpoints stay fixed
opt = torch.optim.Adam([path_opt], lr=1e-3)

for it in range(500):
    opt.zero_grad()
    path = torch.cat((path_start[None], path_opt, path_end[None]), dim=0)
    c = cost(path)
    c.backward()
    opt.step()
    if it % 50 == 0:
        print(f'it: {it},  Cost: {c.item()}')

# Inside the loop `path` and `c` are computed before opt.step(), so after the
# last iteration they lag one update behind `path_opt`. Rebuild both from the
# final parameters (without tracking gradients) so the cells below use the
# actual optimization result.
with torch.no_grad():
    path = torch.cat((path_start[None], path_opt, path_end[None]), dim=0)
    c = cost(path)

print(f'Finished optimization - final cost: {c.item()}')

In [None]:
# Attach model heights to the optimized (x, y) waypoints -> rows of (x, y, z)
path_xy_opt = path[:, :2]
path_zs = nemo.get_heights(path_xy_opt)
path_3d = torch.cat((path_xy_opt, path_zs), dim=1)

In [None]:
# Optimized Dubins path on the terrain.
# The calls follow the convention used by the other plotting cells in this
# notebook: plot_surface(x, y, z, ...) returns the figure, and plot_path_3d
# receives it through the `fig` keyword.
fig = plot_surface(x_grid, y_grid, z_grid, no_axes=True)
fig = plot_path_3d(fig=fig, x=path_3d[:,0].detach().cpu().numpy(),
                        y=path_3d[:,1].detach().cpu().numpy(),
                        z=path_3d[:,2].detach().cpu().numpy())
fig.show()

### Double integrator dynamics

In [None]:
# Time step between consecutive path states
dt = 0.1
# First and second finite differences of the path: velocities and accelerations
path_vs = (path[1:] - path[:-1]) / dt
path_as = (path_vs[1:] - path_vs[:-1]) / dt
# Control effort: twice the mean squared acceleration norm
controls_cost = 2 * torch.norm(path_as, dim=1).pow(2).mean()

In [None]:
def resample_path(path, rate=10, dt=0.1):
    """Resample path at higher resolution using double integrator dynamics.

    Each segment i (from waypoint i to waypoint i+1, lasting `dt`) is sampled at
    `rate` evenly spaced times tau = dt * j / rate, j = 0..rate-1, using

        p(tau) = path[i] + v_i * tau + 0.5 * a_i * tau**2

    where v_i and a_i are the first and second finite differences of `path`.
    The last segment has no following velocity to difference against, so its
    acceleration is taken as zero. The final waypoint is appended once at the
    end, giving (len(path) - 1) * rate + 1 rows.

    Note: with non-zero a_i a segment does not end exactly on waypoint i+1, so
    the resampled curve is exact only where the path is locally straight.

    Args:
        path: tensor of waypoints, one per row.
        rate: number of samples generated per segment.
        dt: time between consecutive waypoints. Defaults to 0.1, the value the
            notebook previously read from the global `dt`.

    Returns:
        Tensor of resampled waypoints, one per row. Paths with fewer than two
        waypoints are returned as a copy.
    """
    if len(path) < 2:
        return path.clone()

    path_vs = torch.diff(path, dim=0) / dt
    path_as = torch.diff(path_vs, dim=0) / dt
    # Pad with a zero row so every segment (including the last) has an acceleration
    path_as = torch.cat((path_as, torch.zeros_like(path_vs[:1])), dim=0)

    path_resampled = []
    for i in range(len(path) - 1):
        for j in range(rate):
            # Elapsed time within the segment, in the same units as dt
            tau = dt * j / rate
            path_resampled.append(path[i] + path_vs[i] * tau + 0.5 * path_as[i] * tau**2)
    path_resampled.append(path[-1])
    return torch.stack(path_resampled)

In [None]:
# Resample the current path with rate=10 samples per segment.
# The bare name on the last line makes the notebook display the whole tensor.
resampled_path = resample_path(path, rate=10)
resampled_path

In [None]:
# Double integrator dynamics
def di_cost(path, dt=0.1):
    """Mean squared speed along `path`.

    The speed of each segment is the norm of the difference between consecutive
    rows (taken over all columns of `path`) divided by `dt`. No acceleration
    term is included in the returned value.

    Args:
        path: tensor of states, one per row.
        dt: time step between consecutive states.

    Returns:
        Scalar tensor with the mean of the squared segment speeds.
    """
    path_dxy = torch.diff(path, dim=0)
    Vs = torch.norm(path_dxy, dim=1) / dt
    return torch.mean(Vs**2)

In [None]:
# Optimize the interior states against the double-integrator cost
opt = torch.optim.Adam([path_opt], lr=1e-3)

for it in range(500):
    opt.zero_grad()
    path = torch.cat((path_start[None], path_opt, path_end[None]), dim=0)
    # NOTE(review): this line used to call `dubins_cost`, which is not defined
    # anywhere in the notebook (NameError on a fresh kernel). `di_cost` is assumed
    # from the section this cell belongs to; use `cost` instead if the Dubins
    # objective was intended. di_cost differences every column of `path`, so the
    # theta column contributes to the speed term — confirm that is wanted.
    c = di_cost(path)
    c.backward()
    opt.step()
    if it % 50 == 0:
        print(f'it: {it},  Cost: {c.item()}')