In [None]:
import torch
import numpy as np
import plotly.graph_objects as go
import plotly.express as px

from nemo.global_planner import GlobalPlanner2
from nemo.nemo import Nemo
from nemo.util import wrap_angle_torch, path_metrics

# Prefer the first CUDA GPU when one is available; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")

# IPython autoreload in mode 2: re-import edited modules before each cell runs,
# so changes to the project code are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2

## Load the heightnet

In [None]:
# Build the model and load pretrained weights for the Red Rocks terrain.
# The file names suggest the first checkpoint holds the encodings and the
# second the height network -- confirm against Nemo.load_weights.
nemo = Nemo()
nemo.load_weights('../models/red_rocks_encs.pth', '../models/red_rocks_height_net.pth')
# Alternative terrain (KT-22); the matching `bounds` line in the grid cell must be switched too.
#nemo.load_weights('../models/kt22_encs.pth', '../models/kt22_height_net.pth')

In [None]:
# Sample an N x N lattice of (x, y) query points spanning the terrain bounds.
N = 64
bounds = (-0.4, 0.8, -0.6, 0.6)  # (x_min, x_max, y_min, y_max) -- red rocks
#bounds = (-0.75, 0.75, -0.75, 0.75)  # kt22
xs = torch.linspace(bounds[0], bounds[1], steps=N, device=device)
ys = torch.linspace(bounds[2], bounds[3], steps=N, device=device)
# With indexing='xy', XY_grid[i, j] == (xs[j], ys[i]): rows follow y, columns follow x.
XY_grid = torch.stack(torch.meshgrid(xs, ys, indexing='xy'), dim=-1)
# Flatten to one (x, y) pair per row for batched height queries.
positions = XY_grid.reshape(N * N, 2)

In [None]:
# Query the height network at every grid position; the result is reshaped to
# (N, N) in the plotting cell, so it holds one height per (x, y) row.
heights = nemo.get_heights(positions)

In [None]:
# Bring the grid coordinates and predicted heights over to NumPy for Plotly.
x_grid, y_grid = (XY_grid[..., axis].detach().cpu().numpy() for axis in range(2))
z_grid = heights.detach().cpu().numpy().reshape(N, N)

# Render the terrain as a 3-D surface with equal data scaling on all axes.
fig = go.Figure()
fig.add_trace(go.Surface(x=x_grid, y=y_grid, z=z_grid, colorscale='Viridis'))
fig.update_layout(width=1500, height=800, scene=dict(aspectmode='data'))
fig.show()

## A*

In [None]:
def xy_to_grid_idx(point, bounds, n):
    """Return the (row, col) index of the grid node nearest to a world (x, y) point.

    The grid is assumed to be built from ``n`` evenly spaced samples per axis
    with both end points included (as ``torch.linspace`` produces), so adjacent
    nodes are ``(max - min) / (n - 1)`` apart. Rows follow y and columns follow
    x, matching ``z_grid[i, j]`` being the height at ``(xs[j], ys[i])``.

    Args:
        point: World coordinates ``(x, y)``.
        bounds: ``(x_min, x_max, y_min, y_max)`` extent of the grid.
        n: Number of grid nodes along each axis.

    Returns:
        Tuple ``(row, col)`` of ints, each clamped to ``[0, n - 1]``.
    """
    x_min, x_max, y_min, y_max = bounds
    col = round((point[0] - x_min) / (x_max - x_min) * (n - 1))
    row = round((point[1] - y_min) / (y_max - y_min) * (n - 1))
    # Clamp so points on (or slightly outside) the boundary still map to a valid node.
    return min(max(row, 0), n - 1), min(max(col, 0), n - 1)


# Planner end points in world coordinates, converted to grid indices using the
# actual `bounds` (not a hard-coded extent) so switching terrains stays correct.
start = (0.32, -0.21)
end = (-0.02, -0.07)
start_idx = xy_to_grid_idx(start, bounds, N)
end_idx = xy_to_grid_idx(end, bounds, N)
print(start_idx, end_idx)

In [None]:
def path_trace(scale, size, color, start_idx=None, end_idx=None):
    """Plan a path over the scaled height map and return it as a 3-D Plotly trace.

    The planner cost matrix is ``scale * (z_grid + 1.0)``, so ``scale`` sets how
    strongly the planner is penalized for the (shifted) terrain height.

    Args:
        scale: Multiplier applied to the shifted heights to form the cost matrix.
        size: Marker size and line width of the returned trace.
        color: Plotly color used for both markers and line.
        start_idx: Grid index pair handed to ``GlobalPlanner2.plan`` as the start.
            Defaults to ``(0, 0)``.
        end_idx: Grid index pair handed to ``GlobalPlanner2.plan`` as the goal.
            Defaults to ``(N - 1, N - 1)``.

    Returns:
        A ``go.Scatter3d`` trace of the planned path, with z taken from the
        height network at the path's (x, y) points.
    """
    # Default to a corner-to-corner plan across the whole grid.
    if start_idx is None:
        start_idx = (0, 0)
    if end_idx is None:
        end_idx = (N - 1, N - 1)

    cost_grid = scale * (z_grid + 1.0).reshape(N, N)
    gp = GlobalPlanner2(cost_grid, bounds)
    path_idxs = gp.plan(start_idx, end_idx)

    # NOTE(review): column 0 of path_idxs is used as the x index here, while
    # z_grid rows follow y -- confirm the index ordering GlobalPlanner2.plan returns.
    path_xs = xs[path_idxs[:,0]]
    path_ys = ys[path_idxs[:,1]]
    path_xy = torch.hstack((path_xs[:,None], path_ys[:,None]))
    path_zs = nemo.get_heights(path_xy)
    return go.Scatter3d(x=path_xs.detach().cpu().numpy(),
                        y=path_ys.detach().cpu().numpy(),
                        z=path_zs.detach().cpu().numpy().flatten(),
                        mode='markers+lines', marker=dict(size=size, color=color),
                        line=dict(color=color, width=size))

In [None]:
# Run A* on the cost matrix
# The matrix is stored as `astar_costs` rather than `heights` so the torch
# tensor returned by the height query is not overwritten with a NumPy array
# (the surface-plot cell calls `.detach()` on `heights` and would fail on re-run).
astar_costs = 1e6 * (z_grid + 1.0).reshape(N, N)
gp = GlobalPlanner2(astar_costs, bounds)
path_idxs = gp.plan(start_idx, end_idx)

In [None]:
# Map the planned grid indices back to world coordinates and query their heights.
# NOTE(review): start_idx/end_idx are built as (y index, x index), whereas
# column 0 of path_idxs is used to index xs here -- confirm the ordering
# GlobalPlanner2.plan returns.
path_xs, path_ys = xs[path_idxs[:, 0]], ys[path_idxs[:, 1]]
path_xy = torch.stack((path_xs, path_ys), dim=1)
path_zs = nemo.get_heights(path_xy)

In [None]:
# Append the heights as an extra column after the (x, y) columns of the A* path.
# Concatenating along dim=1 needs path_zs to be 2-D -- presumably one height
# column per waypoint; confirm against nemo.get_heights.
astar_path = torch.cat((path_xy, path_zs), dim=1)

In [None]:
# Overlay planned paths on the terrain surface, one per cost scale, to show how
# the height weighting changes the route chosen by path_trace.
fig = go.Figure()
fig.add_trace(go.Surface(x=x_grid, y=y_grid, z=z_grid, colorscale='Viridis'))
fig.add_traces([
    path_trace(cost_scale, 5, trace_color)
    for cost_scale, trace_color in ((1e1, 'yellow'), (2e2, 'orange'), (1e6, 'red'))
])
# Equal data scaling, with all three axes hidden so only the terrain and paths show.
fig.update_layout(
    width=1600,
    height=900,
    scene=dict(
        aspectmode='data',
        xaxis=dict(visible=False),
        yaxis=dict(visible=False),
        zaxis=dict(visible=False),
    ),
)
fig.show()

## Path optimization

In [None]:
# Pin the first and last waypoints of the A* path; only the interior waypoints
# become a trainable leaf tensor.
path_start, path_end = path_xy[0], path_xy[-1]
path_opt = path_xy[1:-1].detach().clone().requires_grad_()
# Full path = fixed start + trainable interior + fixed end.
path = torch.cat((path_start.unsqueeze(0), path_opt, path_end.unsqueeze(0)), dim=0)

In [None]:
def cost(path, dt=0.1):
    """Return the scalar objective minimized during path optimization.

    The objective adds a control-effort term (10x the mean squared speed plus
    the mean squared turn rate) to a slope term (mean absolute change of the
    10x-scaled height between consecutive waypoints).

    Args:
        path: 2-D tensor of waypoints with x in column 0 and y in column 1.
        dt: Time step assumed between consecutive waypoints.

    Returns:
        Scalar tensor: ``controls_cost + slope_cost``.
    """
    # Displacement between consecutive waypoints
    steps = torch.diff(path, dim=0)

    # Heading of each segment, then turn rate as the wrapped heading change
    headings = torch.atan2(steps[:, 1], steps[:, 0])
    turn_rates = wrap_angle_torch(torch.diff(headings)) / dt

    # Speed along each segment
    speeds = torch.norm(steps, dim=1) / dt

    controls_cost = 10 * torch.mean(speeds**2) + torch.mean(turn_rates**2)

    # Slope cost: penalize height changes along the path
    scaled_zs = 10 * nemo.get_heights(path)
    slope_cost = torch.mean(torch.abs(torch.diff(scaled_zs, dim=0)))
    return controls_cost + slope_cost

In [None]:
# Optimize the interior waypoints with Adam. The end points stay fixed because
# only `path_opt` is handed to the optimizer.
opt = torch.optim.Adam([path_opt], lr=1e-3)

for it in range(200):
    opt.zero_grad()
    # Reassemble the full path each iteration so gradients flow into path_opt.
    path = torch.cat((path_start[None], path_opt, path_end[None]), dim=0)
    c = cost(path)
    c.backward()
    opt.step()
    if it % 50 == 0:
        print(f'it: {it},  Cost: {c.item()}')

# `path` inside the loop is assembled before opt.step(), so after the last
# iteration it lags one update behind path_opt. Rebuild it so the plotted and
# scored path includes the final step.
path = torch.cat((path_start[None], path_opt, path_end[None]), dim=0)

In [None]:
# Heights along the optimized path, raised by 1e-3 -- presumably so the trace
# is not hidden inside the surface when rendered.
path_zs = nemo.get_heights(path) + 1e-3

# Terrain surface, the optimized path (red), and a path_trace reference (orange).
fig = go.Figure()
fig.add_traces([
    go.Surface(x=x_grid, y=y_grid, z=z_grid, colorscale='Viridis'),
    go.Scatter3d(
        x=path[:, 0].detach().cpu().numpy(),
        y=path[:, 1].detach().cpu().numpy(),
        z=path_zs.detach().cpu().numpy().flatten(),
        mode='markers+lines',
        marker=dict(size=3, color='red'),
    ),
    path_trace(1e8, 5, 'orange'),
])
# Equal data scaling, with all three axes hidden.
fig.update_layout(
    width=1600,
    height=900,
    scene=dict(
        aspectmode='data',
        xaxis=dict(visible=False),
        yaxis=dict(visible=False),
        zaxis=dict(visible=False),
    ),
)
fig.show()

In [None]:
# Build the optimized (x, y, z) path from freshly queried heights rather than
# reusing `path_zs`: that variable carries the +1e-3 plotting offset, which
# would bias the metrics relative to `astar_path` (built without any offset).
opt_path = torch.cat((path, nemo.get_heights(path)), dim=1)

In [None]:
# Metrics for the raw A* path (displayed as the cell output).
path_metrics(astar_path)

In [None]:
# Metrics for the gradient-optimized path, for comparison with the A* result.
path_metrics(opt_path)