In [None]:
import numpy as np
import plotly.graph_objects as go
import plotly.express as px

import torch
import torch.nn as nn
import torch.nn.functional as F

from torchvision.ops import MLP

from nemo.siren import Siren
from nemo.global_planner import GlobalPlanner

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# autoreload
%load_ext autoreload
%autoreload 2

# Mt Bruno SIREN

In [None]:
# Load the SIREN elevation model: 2D coordinates -> 1 output (height).
siren = Siren(in_features=2, out_features=1, hidden_features=256,
              hidden_layers=3, outermost_linear=True).to(device)

# map_location remaps stored tensors onto `device`, so a checkpoint saved on a
# GPU still loads on a CPU-only machine (and vice versa).
siren.load_state_dict(torch.load('../models/mt_bruno_siren.pt', map_location=device))
# Trailing ';' suppresses the module repr in the cell output.
siren.eval();

In [None]:
# Sample the SIREN on a regular 100 x 100 grid over [-1, 1]^2 and render it as a 3D surface.
xs = torch.linspace(-1, 1, steps=100, device=device)
ys = torch.linspace(-1, 1, steps=100, device=device)
# indexing='xy': x varies along columns and y along rows of the 2D grids.
x, y = torch.meshgrid(xs, ys, indexing='xy')
# Flatten into an (N, 2) batch of (x, y) query points.
xy = torch.stack((x.flatten(), y.flatten()), dim=1)

with torch.no_grad():
    pred = siren(xy)

# NumPy copies for plotly; the prediction is reshaped back onto the grid.
pred_np = pred.detach().cpu().numpy().reshape(100, 100)
x_np, y_np = x.cpu().numpy(), y.cpu().numpy()

fig = go.Figure(data=[go.Surface(z=pred_np, x=x_np, y=y_np)])
fig.update_layout(width=1200, height=700, scene_aspectmode='data')
fig.show()

In [None]:
# 2D heatmap of the SIREN prediction. Rows of the reshaped grid run along y and
# columns along x; origin='lower' puts ys[0] = -1 at the bottom so the image has
# the same (Cartesian) orientation as the surface plot and the cost-map heatmap.
fig = px.imshow(pred.detach().cpu().numpy().reshape(100, 100),
                x=xs.cpu().numpy(), y=ys.cpu().numpy(), origin='lower')
fig.update_layout(width=500, height=500)
fig.show()

### Continuous A*

Given a continuous 2D scalar cost function $C$, learn the 2D cost-to-go function $G$ and the backwards flow field $F$ such that $G$ solves the Eikonal equation $\|\nabla G(x)\| = C(x)$ with boundary condition $G(x_{\text{goal}}) = 0$, and $F = \nabla G$. The Eikonal equation is a first-order nonlinear PDE that can be solved numerically with the Fast Marching Method. Following the negative gradient $-\nabla G$ from any start point descends the cost-to-go to the goal, so this flow field plays the role of the optimal control for a continuous A* planner.

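In the context of dynamic programming and the Hamilton–Jacobi–Bellman (HJB) equation, the value function is $V(x_0) = \min_{u(\cdot)} \int_0^{T} C(x(t),u(t))\,dt$, where $x(t)$ obeys the dynamics $\dot{x} = f(x,u)$ with $x(0) = x_0$. The optimal control $u^*$ is the minimizer of this integral. For a time-invariant problem, $V$ solves the stationary HJB equation $\min_u \{ C(x,u) + \nabla V(x) \cdot f(x,u) \} = 0$ (the time-dependent version adds a $\partial V / \partial t$ term). With unit-speed dynamics $f(x,u) = u$, $\|u\| = 1$, and a state-only cost $C(x)$, the minimum is attained at $u = -\nabla V / \|\nabla V\|$. The HJB equation then reduces to the Eikonal equation $\|\nabla V\| = C(x)$. So $V$ is exactly the cost-to-go $G$ above, and the backwards flow field $F$ is its gradient.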

### Continuous path planning

#### Optimal control formulation

minimize $\sum_{t=0}^{T-1} C(t,x_t,u_t)$

subject to: \
 $x_{t+1} = f(t,x_t,u_t)$ \
 $x_0 = x_{\text{start}}$ \
 $x_T = x_{\text{goal}}$ \
 $u_t \in U$   (control constraints)

where $x_t$ is the state at time $t$, $u_t$ is the control at time $t$, $C(t,x_t,u_t)$ is the cost at time $t$, and $f(t,x_t,u_t)$ is the dynamics.

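We'll consider a simple discretization. A neural network maps each time $t$ to a step $u_t = \Delta x_t$, and the dynamics are a single integrator, $x_{t+1} = x_t + u_t$. The path is therefore $x_{\text{start}}$ plus the cumulative sum of the controls. The terminal constraint $x_T = x_{\text{goal}}$ is relaxed into a penalty term in the loss.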

#### Neural ODE

* Is there a way of converting some of these losses into constraints?
    * Specifically the goal reaching loss
* Try enforcing a maximum norm on the dynamics (i.e., bounding the per-step displacement)

In [None]:
# Boundary conditions and time grid for the trajectory optimization.
x_0 = torch.tensor([-1.0, -1.0], device=device)              # start point
x_f = torch.tensor([1.0, 1.0], device=device)                # goal point
T = torch.linspace(0, 1, 100, device=device).unsqueeze(-1)   # (100, 1) network input

mse = nn.MSELoss()

In [None]:
# Maps time t -> per-step displacement; the last entry of hidden_channels (2) is the output size.
dyn = MLP(in_channels=1, hidden_channels=[256] * 3 + [2]).to(device)

In [None]:
# Only the dynamics network's parameters are handed to the optimizer.
optimizer = torch.optim.Adam(params=dyn.parameters(), lr=1e-3)

In [None]:
# Optimize the time-parameterized controls. The objective is a goal-reaching
# penalty plus a penalty on |dx . grad| at each step, where grad comes from
# siren.forward_with_grad (presumably the height gradient w.r.t. (x, y), so the
# dot product approximates the elevation change of each step).
for i in range(1000):
    dx = dyn(T)                                # per-step displacements, one row per time sample
    path = torch.cumsum(dx, dim=0) + x_0       # waypoint reached after each step
    goal_loss = 1e2 * mse(path[-1], x_f)
    # Path-length term: tracked and logged, but intentionally NOT part of `loss`
    # below. Add it to the sum to also penalize long paths.
    dist_loss = 1e-3 * torch.norm(dx, dim=1).nanmean()

    _, grad = siren.forward_with_grad(path)
    cost_loss = 100 * torch.abs(torch.sum(dx * grad, dim=1)).mean()

    loss = goal_loss + cost_loss

    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

    if i % 100 == 0:
        print(f'Loss: {loss.item()}  (goal={goal_loss.item():.4g}, '
              f'cost={cost_loss.item():.4g}, dist={dist_loss.item():.4g})')

In [None]:
# Roll out the trained controls once; the network itself needs no graph here.
with torch.no_grad():
    dx = dyn(T)
path = x_0 + dx.cumsum(dim=0)

# Height and gradient along the path. This is computed outside no_grad,
# presumably because forward_with_grad relies on autograd.
heights, grad = siren.forward_with_grad(path)
path_z = heights.detach().cpu().numpy().flatten()
# Per-step |dx . grad|: the unweighted quantity that the training cost penalizes.
costs = (dx * grad).sum(dim=1).abs().detach().cpu().numpy()

In [None]:
# Overlay the optimized path on the SIREN terrain, colored by per-step cost.
fig = go.Figure(data=[
    go.Surface(z=pred_np, x=x_np, y=y_np, colorscale='Viridis'),
    go.Scatter3d(
        x=path[:, 0].detach().cpu().numpy(),
        y=path[:, 1].detach().cpu().numpy(),
        z=path_z,
        mode='lines',
        line=dict(color=costs, width=10),
        hovertext=costs,
    ),
])
fig.update_layout(width=1600, height=900, scene_aspectmode='data')
fig.show()

In [None]:
# Export the current figure as a standalone interactive HTML page.
fig.write_html(file="path.html")

### Discrete path planning

In [None]:
# Height and input gradient on the 100 x 100 grid. Column 0 of the gradient is
# presumably d(height)/dx and column 1 d(height)/dy, following xy's column order.
pred, z_xy_grad = siren.forward_with_grad(xy)
x_grad, y_grad = (z_xy_grad[:, col].detach().cpu().numpy().reshape(100, 100)
                  for col in (0, 1))

In [None]:
# Resolution of the sampled SIREN grid (the gradient arrays are reshaped to 100 x 100).
GRID_LEN = 100

# Plan between opposite corners of the [-1, 1]^2 domain: index 0 maps to -1 and
# GRID_LEN - 1 maps to +1 on both axes. For the other diagonal use
# start_idx = (0, GRID_LEN - 1) and end_idx = (GRID_LEN - 1, 0).
start_idx = (0, 0)                        # (-1, -1)
end_idx = (GRID_LEN - 1, GRID_LEN - 1)    # (1, 1)

# Per-cell traversal cost: L1 norm of the SIREN gradient (steeper = costlier).
grad_costmat = np.abs(x_grad) + np.abs(y_grad)
# Catch GRID_LEN drifting out of sync with the sampling resolution.
assert grad_costmat.shape == (GRID_LEN, GRID_LEN), grad_costmat.shape

# Visualize the cost matrix (origin='lower' puts y = -1 at the bottom).
fig = px.imshow(grad_costmat, x=xs.cpu().numpy(), y=ys.cpu().numpy(), origin='lower')
fig.update_layout(width=500, height=500)
fig.show()

In [None]:
# Run A* on the cost matrix
gp = GlobalPlanner(grad_costmat)
path = gp.plan(start_idx, end_idx)

# Map grid indices back to continuous coordinates. path[:, 0] is treated as the
# x (column) index and path[:, 1] as the y (row) index.
# NOTE(review): confirm this matches GlobalPlanner.plan's output convention.
path_xs = xs[path[:, 0]]
path_ys = ys[path[:, 1]]
path_xy = torch.hstack((path_xs[:, None], path_ys[:, None]))

# siren(...) returns a single height tensor (as in the visualization cell), so it
# must not be unpacked; unpacking raised "too many values to unpack" for any path
# with more than two points. No gradients are needed for plotting.
with torch.no_grad():
    path_zs = siren(path_xy)

In [None]:
# Show the A* waypoints (red markers) on top of the SIREN surface.
fig = go.Figure(data=[
    go.Surface(z=pred.detach().cpu().numpy().reshape(100, 100),
               x=x.cpu().numpy(), y=y.cpu().numpy()),
    go.Scatter3d(x=path_xs.detach().cpu().numpy(),
                 y=path_ys.detach().cpu().numpy(),
                 z=path_zs.detach().cpu().numpy().flatten(),
                 mode='markers',
                 marker=dict(size=3, color='red')),
])
fig.update_layout(width=1200, height=700, scene_aspectmode='data')
fig.show()

# Lunar DEM

In [None]:
dem_full = np.load('../data/lunar_dem.npy')
Z = dem_full[:1000, :1000]

# Surface plot approximately to scale. With indexing='xy' the meshgrid has shape
# (len(ys), len(xs)), so xs must have one sample per column of Z and ys one per
# row. That makes x/y match Z's shape even when the crop is not square.
xs = torch.linspace(0, 118500, steps=Z.shape[1], device=device)
ys = torch.linspace(0, 118400, steps=Z.shape[0], device=device)
x, y = torch.meshgrid(xs, ys, indexing='xy')
assert x.shape == Z.shape, (x.shape, Z.shape)

fig = go.Figure(data=[go.Surface(z=Z, x=x.cpu().numpy(), y=y.cpu().numpy())])
fig.update_layout(width=1200, height=700, scene_aspectmode='data')
fig.show()

In [None]:
# Per-grid-step slope of the DEM. np.gradient's axis 0 runs along rows (the y
# direction in the surface plot) and axis 1 along columns (the x direction).
y_grad = np.gradient(Z, axis=0)
x_grad = np.gradient(Z, axis=1)

# Per-cell traversal cost: L1 norm of the gradient. Derivatives are per index
# step (no physical spacing applied); x and y spacing are nearly equal here.
grad_costmat = np.abs(x_grad) + np.abs(y_grad)

In [None]:
# Plan between opposite corners of the DEM crop. The far corner is derived from
# the cost-map shape instead of being hard-coded to (999, 999), so a smaller DEM
# or a different crop size still yields an in-bounds goal.
# NOTE(review): indices follow this notebook's (x/column, y/row) convention, since
# the path is read back as Z[path[:, 1], path[:, 0]]. Confirm against GlobalPlanner.
n_rows, n_cols = grad_costmat.shape
start_idx = (0, 0)
end_idx = (n_cols - 1, n_rows - 1)

gp = GlobalPlanner(grad_costmat)
path = gp.plan(start_idx, end_idx)

In [None]:
# Convert planner indices to plot coordinates and DEM heights, treating
# path[:, 0] as the column (x) index and path[:, 1] as the row (y) index.
col_idx, row_idx = path[:, 0], path[:, 1]
path_xs = xs[col_idx]
path_ys = ys[row_idx]
path_zs = Z[row_idx, col_idx]

In [None]:
# Show the A* waypoints (red markers) on the lunar DEM surface.
fig = go.Figure(data=[
    go.Surface(z=Z, x=x.cpu().numpy(), y=y.cpu().numpy()),
    go.Scatter3d(x=path_xs.detach().cpu().numpy(),
                 y=path_ys.detach().cpu().numpy(),
                 z=path_zs,
                 mode='markers',
                 marker=dict(size=3, color='red')),
])
fig.update_layout(width=1200, height=700, scene_aspectmode='data')
fig.show()